Hello guys! Here we are again to have some fun with deep learning. In the previous post, we trained a model to generate text with Tensorflow. Today, I am gonna show you how we can do it with Pytorch. Let’s go!
Overview
In this blog post, what we are going to do is pretty much the same as what we did in the last post. We will create a model which can learn to generate some meaningful text like the passage below:
“I am sure we have studied Hogwarts and saw that the magical appearance of Dumbledore was starting to fear that the innocent” she said. Harry had just given the illusion how stars had lunged in at the same moment they were staring into a corner, the faint wall had obliged in the ground, he tried, but the detritus of magical creature lay in the air in front and spitefully the bond about a celebrated of green and brown, that weapons began weight magicked the wounds quickly; Dolohov.
Prerequisites
To get the most out of today’s post, I suggest that you have:
- Python installed (Python 3 is definitely recommended)
- Pytorch installed (at least version 1.0)
- Some experience with Python and knowledge of how RNNs and word embeddings work
- Read my previous post (link here)
Regarding the last point, though: the logic behind how things work remains the same regardless of whether your code is written in Tensorflow or Pytorch, so this post will focus on the Pytorch implementation only.
Dataset
The data processing code from the last post is not Tensorflow-dependent, which means that we can use it as-is without any modifications.
Firstly, let’s import the packages we need for today:
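A minimal set that covers everything we will use below looks like this (Counter will help us build the vocabulary in a moment):

```python
import argparse
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
```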
Obviously we can’t use tf.app.flags, but we always have argparse at our back to do the job.
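Here is a sketch of the flag setup. The hyperparameter values and the file path below are illustrative choices of mine, so tweak them to your liking:

```python
# A simple stand-in for tf.app.flags: a plain Namespace holding hyperparameters.
flags = argparse.Namespace(
    train_file='data/harry.txt',   # path to your raw text file (illustrative)
    seq_size=32,                   # length of each training sequence
    batch_size=16,
    embedding_size=64,
    lstm_size=64,
    gradients_norm=5,              # for gradient clipping later
    initial_words=['Lord', 'Voldemort'],
    predict_top_k=5,
    checkpoint_path='checkpoint_pt',
)
```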
Next, we need a function to process the raw data. You can check the implementation detail in the Dataset session of the last post. Here I only show you the complete code:
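A compact version of that function might look like the sketch below; the name get_data_from_file and the exact return values are just the conventions I will stick to for the rest of the post:

```python
def get_data_from_file(train_file, batch_size, seq_size):
    with open(train_file, 'r', encoding='utf-8') as f:
        text = f.read().split()

    # Build the vocabulary, most frequent words first
    word_counts = Counter(text)
    sorted_vocab = sorted(word_counts, key=word_counts.get, reverse=True)
    int_to_vocab = {k: w for k, w in enumerate(sorted_vocab)}
    vocab_to_int = {w: k for k, w in int_to_vocab.items()}
    n_vocab = len(int_to_vocab)

    int_text = [vocab_to_int[w] for w in text]

    # Keep only what divides evenly into batches, then shift everything
    # by one word to create the targets
    num_batches = len(int_text) // (seq_size * batch_size)
    in_text = int_text[:num_batches * batch_size * seq_size]
    out_text = np.zeros_like(in_text)
    out_text[:-1] = in_text[1:]
    out_text[-1] = in_text[0]

    in_text = np.reshape(in_text, (batch_size, -1))
    out_text = np.reshape(out_text, (batch_size, -1))
    return int_to_vocab, vocab_to_int, n_vocab, in_text, out_text
```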
And finally, we must define a function to generate batches for training:
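A simple generator does the job here; it just slices the two arrays above into seq_size-long chunks:

```python
def get_batches(in_text, out_text, batch_size, seq_size):
    num_batches = np.prod(in_text.shape) // (seq_size * batch_size)
    for i in range(0, num_batches * seq_size, seq_size):
        yield in_text[:, i:i + seq_size], out_text[:, i:i + seq_size]
```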
That is all we need for this step. Phew! It is not always that easy, but we should just make things simple where they can be simple, right?
Model
Creating a network in Pytorch is very straightforward. All we have to do is create a subclass of torch.nn.Module, define the necessary layers in the __init__ method, and implement the forward pass within the forward method.
Let’s recall a little bit. We need an embedding layer, an LSTM layer, and a dense layer, so here is the __init__ method:
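Here is a sketch of it; the class name RNNModule and the single-layer LSTM are my choices for this post:

```python
class RNNModule(nn.Module):
    def __init__(self, n_vocab, seq_size, embedding_size, lstm_size):
        super(RNNModule, self).__init__()
        self.seq_size = seq_size
        self.lstm_size = lstm_size
        self.embedding = nn.Embedding(n_vocab, embedding_size)
        self.lstm = nn.LSTM(embedding_size, lstm_size, batch_first=True)
        self.dense = nn.Linear(lstm_size, n_vocab)
```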
The next method, forward, will take an input sequence and the previous states and produce the output together with the states of the current timestep:
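It can be as short as this:

```python
    def forward(self, x, prev_state):
        embed = self.embedding(x)                     # (batch, seq, embedding_size)
        output, state = self.lstm(embed, prev_state)  # (batch, seq, lstm_size)
        logits = self.dense(output)                   # (batch, seq, n_vocab)
        return logits, state
```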
Because we have to reset the states at the beginning of every epoch, we will define one more method to help us set all states to zero:
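One straightforward way to do it is to return a pair of all-zero tensors (the leading 1 is the number of LSTM layers):

```python
    def zero_state(self, batch_size):
        return (torch.zeros(1, batch_size, self.lstm_size),
                torch.zeros(1, batch_size, self.lstm_size))
```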
That may look strange to some of you: since an LSTM’s state consists of two separate states, the hidden state and the memory state (denoted as state_h and state_c respectively), we have to return a pair of tensors rather than one. Remember this difference when using LSTM units.
Loss
We are done with the network. Now we need a loss function and a training op. Defining the two is surprisingly simple in Pytorch:
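For instance, cross-entropy loss together with the Adam optimizer will do; the learning rate below is just a sensible default:

```python
def get_loss_and_train_op(net, lr=0.001):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    return criterion, optimizer
```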
“We’re not doing gradient clipping this time?”, you may ask. So glad that you pointed it out. Of course we will, but not here. You will see in a second.
Training
We are ready to train the network. Here we will come across one thing that some may like while others may not favor at all: manually managing data transfer between devices.
If your machine doesn’t have a GPU, you are somewhat lucky here. For those who have one, just don’t forget to keep track of where your tensors are. Here are some tips of mine:
- If the training is slow, you might have forgotten to move the data to the GPU
- You can move everything to the GPU first, then fix the errors one by one until things work
Okay, let’s code. First, we will get the device information, load the training data, and create the network, the loss function, and the training op. And don’t forget to transfer the network to the GPU:
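Putting it all together in a main function might look like this sketch, which relies on the helpers and flags defined above:

```python
def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    int_to_vocab, vocab_to_int, n_vocab, in_text, out_text = get_data_from_file(
        flags.train_file, flags.batch_size, flags.seq_size)

    net = RNNModule(n_vocab, flags.seq_size,
                    flags.embedding_size, flags.lstm_size)
    net = net.to(device)  # don't forget this one!

    criterion, optimizer = get_loss_and_train_op(net, 0.001)
```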
Next, for each epoch, we will loop through the batches to compute loss values and update the network’s parameters. A typical set of training steps in Pytorch looks like this:
- Call the train() method on the network’s instance (this switches the network into training mode; it does not execute the training itself)
- Reset all gradients
- Compute output, loss value, accuracy, etc
- Perform back-propagation
- Update the network’s parameters
Here is what it looks like in code:
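Below is a sketch of the loop, continuing inside main(), with the five steps above marked as comments (the epoch count of 50 is an arbitrary choice):

```python
    iteration = 0
    for e in range(50):
        batches = get_batches(in_text, out_text,
                              flags.batch_size, flags.seq_size)

        # Reset states at the beginning of every epoch
        state_h, state_c = net.zero_state(flags.batch_size)
        state_h = state_h.to(device)
        state_c = state_c.to(device)

        for x, y in batches:
            iteration += 1

            # 1. Tell the network we are in training mode
            net.train()

            # 2. Reset all gradients
            optimizer.zero_grad()

            x = torch.tensor(x, dtype=torch.long).to(device)
            y = torch.tensor(y, dtype=torch.long).to(device)

            # 3. Compute output and loss; CrossEntropyLoss wants the
            #    class dimension second, hence the transpose
            logits, (state_h, state_c) = net(x, (state_h, state_c))
            loss = criterion(logits.transpose(1, 2), y)
            loss_value = loss.item()

            # Remove the states from the graph (explained below)
            state_h = state_h.detach()
            state_c = state_c.detach()

            # 4. Perform back-propagation
            loss.backward()

            # 5. Update the network's parameters
            optimizer.step()
```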
You may notice the detach() thing. Whenever we want to use something that belongs to the computational graph for other operations, we must remove it from the graph by calling the detach() method. The reason is that Pytorch keeps track of the tensors’ flow to perform back-propagation through a mechanism called autograd; if we fed the old states back in without detaching them, back-propagation would try to run through all of the previous batches as well.
Is there anything I have missed? Oh, the gradient clipping! While it may not be as intuitive, it only requires one line of code. We just need to put it after calling loss.backward() and before optimizer.step() like this:
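With clip_grad_norm_ from torch.nn.utils, the relevant part of the loop becomes (gradients_norm is the flag we defined at the top):

```python
            loss.backward()

            # Clip gradients to keep them from blowing up
            _ = torch.nn.utils.clip_grad_norm_(
                net.parameters(), flags.gradients_norm)

            optimizer.step()
```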
Finally, we will add code to print the loss value to the console and have the model generate some text for us during training:
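Something like this, still inside the batch loop, will do. The logging intervals and the checkpoint file name are arbitrary choices of mine (and the checkpoint folder is assumed to exist already):

```python
            if iteration % 100 == 0:
                print('Epoch: {}/{}'.format(e, 50),
                      'Iteration: {}'.format(iteration),
                      'Loss: {}'.format(loss_value))

            if iteration % 1000 == 0:
                predict(device, net, flags.initial_words, n_vocab,
                        vocab_to_int, int_to_vocab,
                        top_k=flags.predict_top_k)
                torch.save(net.state_dict(),
                           '{}/model-{}.pth'.format(flags.checkpoint_path,
                                                    iteration))
```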
That is the training loop. The only thing left is to define the predict method.
Inference
We finally reached the last and most fun part: implementing the predict method. What we are going to do can be illustrated in the figure below:
[Figure: the initial words “Lord” and “Voldemort” are fed into the network one at a time, and the network predicts the next word, “is”.]
Assuming that we have some initial words (“Lord” and “Voldemort” in this case), we will use them as input to compute the final output, which is the word “is”. The code is as follows; don’t forget to tell the network that we are about to evaluate by calling the eval() method and, of course, remember to move your stuff to the GPU:
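Here is a sketch of the first half of predict. Note the assumption baked into vocab_to_int[w]: every initial word must have appeared in the training data, otherwise we get a KeyError:

```python
def predict(device, net, words, n_vocab, vocab_to_int, int_to_vocab, top_k=5):
    # Tell the network we are evaluating, not training
    net.eval()

    words = list(words)  # copy so we don't mutate the caller's list
    state_h, state_c = net.zero_state(1)
    state_h = state_h.to(device)
    state_c = state_c.to(device)

    # Feed the initial words one by one to warm up the states
    for w in words:
        ix = torch.tensor([[vocab_to_int[w]]]).to(device)
        output, (state_h, state_c) = net(ix, (state_h, state_c))

    # Sample the next word from the top-k candidates
    _, top_ix = torch.topk(output[0], k=top_k)
    choices = top_ix.tolist()
    choice = np.random.choice(choices[0])
    words.append(int_to_vocab[choice])
```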
Next, we will use that final output as input for the next time step and continue doing so until we have a sequence of the length we want. Finally, we simply print out the resulting sequence to the console:
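And the second half, continuing the same function; generating 100 words with top-k sampling is, again, an arbitrary choice:

```python
    # Keep feeding the last prediction back in as input
    for _ in range(100):
        ix = torch.tensor([[choice]]).to(device)
        output, (state_h, state_c) = net(ix, (state_h, state_c))

        _, top_ix = torch.topk(output[0], k=top_k)
        choices = top_ix.tolist()
        choice = np.random.choice(choices[0])
        words.append(int_to_vocab[choice])

    print(' '.join(words))
```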
We can now hit the run button and of course, don’t forget to get yourself a cup of coffee. Enjoy your machine’s creativity!
Final word
So in today’s post, we have created a model which can learn from any raw text source and generate some interesting content for us.
We have done it with ease by using Pytorch, a deep learning library which has gained a lot of attention in recent years. All the code and training data can be found at my repo (the Pytorch scripts have a _pt postfix).
That’s it for today, guys! Thank you so much for reading, and I will definitely be seeing you soon.
Comments
I received a lot of emails when I published my previous blog post, asking for Harry Potter’s text files. I’m sorry to disappoint you guys, but I can’t share them (you know the reason why).
Still, there’s a lot of free text out there for you to experiment with. So, enjoy your network 😀
Hi, can you show how we can calculate a score (like perplexity) for a sentence, to show how good the sentence is based on this trained language model?
Thanks for the nice tutorial! I have got a problem with the UTF-8 encoding: I get some weird strings for certain characters, even though they are valid UTF-8. Here is an example: b’I am too beautiful snowy owl, scar. \xe2\x80\x98You\xe2\x80\x99ve already broken his legs was no good garden was standing there into his hands out there and a huge chessboard, and the door slid open up \xe2\x80\x93 Professor flying by a small package was still standing getting all the stranger. Think he said, \xe2\x80\x98Don\xe2\x80\x99 mind you\xe2\x80\x99re nervous. Go on!\xe2\x80\x99 from under Crabbe they\xe2\x80\x99d stepped over a dirty behind him in her hand. He laid them started to the Gryffindor team,\xe2\x80\x99 Filch was. And it, because the Stone\xe2\x80\x99s the even seen in loud If we were the Muggles started lookin\xe2\x80\x99 had to send Norbert and threw’
I completely cloned your repo and still got the error.
Hi, thanks for your help. I ran it on Colab but got an error; can you help me? Thanks
It might seem obvious, but it’s worth noting that this will break if any of the initial words aren’t in the training data set.
Also, the naming of the checkpoint folder is inconsistent in the post: you have it as checkpoint and as checkpoint_pt.
The checkpoint path is different in your arguments up top vs. in your main() function, which causes the code to break. Also, if the words passed to the predict() function are not present in the dataset, the code will break.