Text Generation With Pytorch

Reading Time: 3 minutes

Hello guys! Here we are again to have some fun with deep learning. In the previous post, we trained a model to generate text with Tensorflow. Today, I am going to show you how we can do it with Pytorch. Let's go!


In this blog post, what we are going to do is pretty much the same as what we did in the last post. We will create a model which can learn to generate some meaningful content like the sample below:

“I am sure we have studied Hogwarts and saw that the magical appearance of Dumbledore was starting to fear that the innocent” she said. Harry had just given the illusion how stars had lunged in at the same moment they were staring into a corner, the faint wall had obliged in the ground, he tried, but the detritus of magical creature lay in the air in front and spitefully the bond about a celebrated of green and brown, that weapons began weight magicked the wounds quickly; Dolohov.


To get the most out of today’s post, I suggest that you have:

  • Python installed (Python3 is definitely recommended)
  • Pytorch installed (at least version 1.0)
  • Some experience with Python and knowledge of how RNNs and word embeddings work
  • Read my previous post (link here)

About the last point though, the logic behind how things work remains the same regardless of whether your code is written in Tensorflow or Pytorch, so this post will focus on the Pytorch implementation only.


The data processing code from the last post is not Tensorflow-dependent, which means that we can use it as-is, without any modifications.

Firstly, let’s import the packages we need for today:

Obviously we can't use tf.app.flags, but we always have argparse to do the job.
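A minimal sketch of what that argparse replacement could look like; the flag names and default values here are my assumptions, not necessarily the ones used in the original script:

```python
import argparse

def get_args(argv=None):
    # Hypothetical flags mirroring the tf.app.flags from the previous post.
    parser = argparse.ArgumentParser()
    parser.add_argument('--seq-size', type=int, default=32)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--embedding-size', type=int, default=64)
    parser.add_argument('--lstm-size', type=int, default=64)
    parser.add_argument('--gradients-norm', type=float, default=5.0)
    return parser.parse_args(argv)
```

Passing `argv=None` makes the function read `sys.argv` when run as a script, while still being easy to call with an explicit list in tests.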

Next, we need a function to process the raw data. You can check the implementation details in the Dataset section of the last post. Here I will only show the complete code:
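A runnable sketch of that processing step, under the assumption that it tokenizes on whitespace, builds the vocabulary lookups, and shifts the targets one word to the left (the function name and signature here are illustrative, not the original ones):

```python
from collections import Counter

import numpy as np

def get_data_from_text(text, batch_size, seq_size):
    # Hypothetical signature: takes the raw text directly instead of a file path.
    words = text.split()
    word_counts = Counter(words)
    sorted_vocab = sorted(word_counts, key=word_counts.get, reverse=True)
    int_to_word = {k: w for k, w in enumerate(sorted_vocab)}
    word_to_int = {w: k for k, w in int_to_word.items()}
    n_vocab = len(int_to_word)

    int_text = [word_to_int[w] for w in words]
    num_batches = len(int_text) // (seq_size * batch_size)
    in_text = np.array(int_text[:num_batches * batch_size * seq_size])
    # Targets are the inputs shifted one step to the left (wrapping the last word)
    out_text = np.zeros_like(in_text)
    out_text[:-1] = in_text[1:]
    out_text[-1] = in_text[0]
    in_text = np.reshape(in_text, (batch_size, -1))
    out_text = np.reshape(out_text, (batch_size, -1))
    return int_to_word, word_to_int, n_vocab, in_text, out_text
```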

And finally, we must define a function to generate batches for training:
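Assuming the data has already been reshaped to `(batch_size, -1)` as above, the batch generator only needs to slide over the second axis one sequence-length at a time. A sketch:

```python
import numpy as np

def get_batches(in_text, out_text, batch_size, seq_size):
    # Yield (input, target) windows of seq_size columns each.
    num_batches = np.prod(in_text.shape) // (seq_size * batch_size)
    for i in range(0, num_batches * seq_size, seq_size):
        yield in_text[:, i:i + seq_size], out_text[:, i:i + seq_size]
```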

That is all we need for this step. Phew! It is not always that easy, but
let's keep things simple where they can be simple, right?


Creating a network in Pytorch is very straightforward. All we have to do is create a subclass of torch.nn.Module, define the necessary layers in the __init__ method, and implement the forward pass within the forward method.

Let’s recall a little bit. We need an embedding layer, an LSTM layer, and a dense layer, so here is the __init__ method:

The next method, forward, will take an input sequence and the previous states, and produce the output together with the states of the current timestep:

Because we need to reset the states at the beginning of every epoch, we define one more method to help us set all states to zero:

That may look strange to some of you: an LSTM's state consists of two separate tensors, the hidden state and the memory (cell) state (denoted state_h and state_c respectively). Remember this difference when using LSTM units.
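Putting the three methods together, here is a runnable sketch of the module; the class name and the layer sizes are illustrative defaults, not fixed by the post:

```python
import torch
import torch.nn as nn

class RNNModule(nn.Module):
    def __init__(self, n_vocab, seq_size=32, embedding_size=64, lstm_size=64):
        super(RNNModule, self).__init__()
        self.seq_size = seq_size
        self.lstm_size = lstm_size
        # Embedding layer -> LSTM layer -> dense layer, as described above
        self.embedding = nn.Embedding(n_vocab, embedding_size)
        self.lstm = nn.LSTM(embedding_size, lstm_size, batch_first=True)
        self.dense = nn.Linear(lstm_size, n_vocab)

    def forward(self, x, prev_state):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.dense(output)
        return logits, state

    def zero_state(self, batch_size):
        # The LSTM state is a pair: (hidden state, memory/cell state)
        return (torch.zeros(1, batch_size, self.lstm_size),
                torch.zeros(1, batch_size, self.lstm_size))
```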


We are done with the network. Now we need a loss function and a training op. Defining the two is surprisingly simple in Pytorch:
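For example, a helper along these lines; cross-entropy loss with the Adam optimizer is a common choice for this task, and the learning rate default is my assumption:

```python
import torch
import torch.nn as nn

def get_loss_and_train_op(net, lr=0.001):
    # CrossEntropyLoss applies log-softmax internally,
    # so the network can output raw logits.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    return criterion, optimizer
```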

“We’re not doing gradient clipping this time?”, you may ask. So glad that you pointed it out. Of course we will, but not here. You will see in a second.


We are ready to train the network. Here we will come across one thing that some may like while others may not favor at all: manually managing the data transfer between devices.

If your machine doesn't have a GPU, you are, in a way, lucky. For those who have one, just don't forget to keep track of where your tensors are. Here are some tips of mine:

  • If the training is slow, you might have forgotten to move data to GPU
  • You can move everything to the GPU first, then fix the errors as they appear until everything works.

Okay, let's code. First, we will get the device information, get the training data, and create the network, the loss function, and the training op. And don't forget to transfer the network to the GPU:
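A sketch of that setup, with a stand-in layer in place of the full RNN module so it stays self-contained:

```python
import torch
import torch.nn as nn

# Pick the GPU if one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net = nn.Linear(8, 8)        # stand-in for RNNModule(n_vocab, ...) from above
net = net.to(device)         # transfer the network to the chosen device
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
```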

Next, for each epoch, we will loop through the batches to compute loss values and update the network's parameters. A typical set of steps for training in Pytorch is:

  • Call the train() method on the network's instance (this switches the network to training mode; it does not execute the training itself)
  • Reset all gradients
  • Compute output, loss value, accuracy, etc
  • Perform back-propagation
  • Update the network’s parameters

Here is how it looks in code:
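What follows is a minimal runnable sketch of those five steps rather than the full script; the tiny layers and random batches are stand-ins for the real network and get_batches:

```python
import torch
import torch.nn as nn

n_vocab, batch_size, seq_size, lstm_size = 10, 2, 4, 16
embedding = nn.Embedding(n_vocab, 8)
lstm = nn.LSTM(8, lstm_size, batch_first=True)
dense = nn.Linear(lstm_size, n_vocab)
params = (list(embedding.parameters()) + list(lstm.parameters())
          + list(dense.parameters()))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params, lr=0.01)

state_h = torch.zeros(1, batch_size, lstm_size)   # reset states for the epoch
state_c = torch.zeros(1, batch_size, lstm_size)

for _ in range(3):                                # stand-in for the batch loop
    x = torch.randint(0, n_vocab, (batch_size, seq_size))
    y = torch.randint(0, n_vocab, (batch_size, seq_size))

    lstm.train()                                  # 1. switch to training mode
    optimizer.zero_grad()                         # 2. reset all gradients
    out, (state_h, state_c) = lstm(embedding(x), (state_h, state_c))
    logits = dense(out)
    loss = criterion(logits.transpose(1, 2), y)   # 3. compute output and loss
    state_h = state_h.detach()                    # keep the states for the next
    state_c = state_c.detach()                    # batch, minus the old graph
    loss.backward()                               # 4. back-propagation
    optimizer.step()                              # 5. update the parameters
```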

You may notice the detach() calls. Whenever we want to reuse a tensor that belongs to the computational graph, such as the LSTM states carried over to the next batch, we must first remove it from the graph by calling the detach() method. The reason is that Pytorch tracks the flow of tensors to perform back-propagation through a mechanism called autograd; if we kept the states attached, autograd would try to back-propagate through the entire history of batches, and the backward pass would fail.

Is there anything I have missed? Oh, the gradient clipping! While it may be less intuitive, it only requires one line of code. We just need to put it after calling loss.backward() and before optimizer.step(), like this:
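Concretely, clip_grad_norm_ rescales all gradients in place so their total norm does not exceed a threshold; the stand-in layer and the threshold of 5 are illustrative:

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 2)                         # stand-in for the RNN module
loss = net(torch.randn(8, 4)).pow(2).sum()
loss.backward()
# The one line, placed after loss.backward() and before optimizer.step();
# 5 is a hypothetical threshold (e.g. the gradients_norm flag).
nn.utils.clip_grad_norm_(net.parameters(), 5)
```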

Finally, we will add code to print the loss value to the console and have the model generate some text for us during training:
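The reporting could be as simple as the helper below; the message format and cadence are my guesses, and every so many iterations one would also call the predict routine to sample some text:

```python
def progress_line(e, num_epochs, iteration, loss_value):
    # Hypothetical format for the periodic console report.
    return 'Epoch: {}/{} Iteration: {} Loss: {:.4f}'.format(
        e, num_epochs, iteration, loss_value)
```

Inside the batch loop, one would then do something like `if iteration % 100 == 0: print(progress_line(e, num_epochs, iteration, loss.item()))`.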

That is the training loop. The only thing left is to define the predict method.


We have finally reached the last and most fun part: implementing the predict method. What we are going to do is illustrated in the figure below:

Fig. 1: the inference process

Assuming that we have some initial words (“Lord” and “Voldemort” in this case), we will use them as input to compute the final output, which is the word “is”. The code is as follows; don't forget to tell the network that we are about to evaluate by calling the eval() method, and of course, remember to move your stuff to the GPU:

Next, we will use that final output as the input for the next time step and continue doing so until we have a sequence of the desired length. Finally, we simply print the resulting sequence to the console:
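The two steps above can be sketched together as a single predict function. The small RNNModule is repeated here so the sketch runs on its own, and the top-k sampling with k=5 plus the length parameter are my assumptions:

```python
import numpy as np
import torch
import torch.nn as nn

class RNNModule(nn.Module):
    def __init__(self, n_vocab, embedding_size=8, lstm_size=16):
        super().__init__()
        self.lstm_size = lstm_size
        self.embedding = nn.Embedding(n_vocab, embedding_size)
        self.lstm = nn.LSTM(embedding_size, lstm_size, batch_first=True)
        self.dense = nn.Linear(lstm_size, n_vocab)

    def forward(self, x, prev_state):
        out, state = self.lstm(self.embedding(x), prev_state)
        return self.dense(out), state

    def zero_state(self, batch_size):
        return (torch.zeros(1, batch_size, self.lstm_size),
                torch.zeros(1, batch_size, self.lstm_size))

def predict(device, net, initial_words, word_to_int, int_to_word,
            top_k=5, length=20):
    net.eval()                                    # tell the network we evaluate
    state_h, state_c = net.zero_state(1)
    state_h, state_c = state_h.to(device), state_c.to(device)

    words = list(initial_words)
    with torch.no_grad():
        for w in words:                           # warm up on the initial words
            ix = torch.tensor([[word_to_int[w]]]).to(device)
            output, (state_h, state_c) = net(ix, (state_h, state_c))

        # Sample the next word from the top-k most likely candidates
        choice = int(np.random.choice(torch.topk(output[0], k=top_k)[1][0].tolist()))
        words.append(int_to_word[choice])

        for _ in range(length - 1):               # feed each output back in
            ix = torch.tensor([[choice]]).to(device)
            output, (state_h, state_c) = net(ix, (state_h, state_c))
            choice = int(np.random.choice(torch.topk(output[0], k=top_k)[1][0].tolist()))
            words.append(int_to_word[choice])

    print(' '.join(words))
    return words
```

Sampling from the top-k candidates instead of always taking the argmax keeps the generated text from looping over the same few words.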

We can now hit the run button and of course, don’t forget to get yourself a cup of coffee. Enjoy your machine’s creativity!

Final word

So in today’s post, we have created a model which can learn from any raw text source and generate some interesting content for us.

We have done it with ease by using Pytorch, a deep learning library which has gained a lot of attention in recent years. All the code and training data can be found in my repo (the Pytorch scripts have the _pt postfix).

That's it for today, guys! Thank you so much for reading. And I will definitely be seeing you soon.


References

  1. Text generation with Tensorflow: link
  2. Colah’s excellent blog post about LSTM: link
  3. Intro to RNN’s tutorial from Mat, Udacity: link
  4. Donald Trump’s full speech: link
  5. Oliver Twist: link

Trung Tran is a software developer + AI engineer. He also works on networking & cybersecurity on the side. He loves blogging about new technologies and all posts are from his own experiences and opinions.

8 comments On Text Generation With Pytorch

  • I received a lot of emails when I published my old blog post asking for Harry Potter’s text files. I’m sorry for disappointing you guys but I can’t share them (you know the reason why).

    Still, there’s a lot of free stuff out there for you to experiment. So, enjoy your network 😀

  • Hi, Can you show how can we calculate a score(like perplexity) for a sentence, to show how good the sentence is based on this trained language model?

  • Thanks for the nice tutorial! I have got a problem with the UTF-8 encoding. I get some weird string for certain characters, even if there are in the UTF-8 encoding. Here an example: b’I am too beautiful snowy owl, scar. \xe2\x80\x98You\xe2\x80\x99ve already broken his legs was no good garden was standing there into his hands out there and a huge chessboard, and the door slid open up \xe2\x80\x93 Professor flying by a small package was still standing getting all the stranger. Think he said, \xe2\x80\x98Don\xe2\x80\x99 mind you\xe2\x80\x99re nervous. Go on!\xe2\x80\x99 from under Crabbe they\xe2\x80\x99d stepped over a dirty behind him in her hand. He laid them started to the Gryffindor team,\xe2\x80\x99 Filch was. And it, because the Stone\xe2\x80\x99s the even seen in loud If we were the Muggles started lookin\xe2\x80\x99 had to send Norbert and threw’
    I completely cloned your repo and still got the error.

  • Hi, thanks for your help, but I ran it on Colab and got an error. Can you help me? Thanks

  • It might seem obvious, but it’s worth noting that this will break if any of the initial words aren’t in the initial data set

    • Also, the naming of the checkpoint folder is inconsistent in the post, you have it as checkpoint, and checkpoint_pt

  • Checkpoint path is different in your arguments up top vs in your main() function, which causes the code to break. Also, if the words in predict() function are not present in the dataset, the code will break.

  • Michail Strijov

    Thank you so much for adding this, I was looking for the reason it was throwing errors at me (was training on a dataset in cyrillic with initial words in English…not so obvious for me :D)
