Discussion: July 6th
In this multi-part assignment, we will try to implement simple versions of autoregressive models like PixelCNN as well as VQ-VAEs, and combine the two.
Note: Some “starter code” for this is available in the course repository! Also, feel free to refer back to the paper at any time.
Let’s start with a relatively simple VQ-VAE. Build an autoencoder on a dataset of your choice; this should be easy for you at this point (just take old code if you have it!). Now, turn it into a VQ-VAE as follows:
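- Define a codebook: a K x d matrix, where d is the latent dimension of the autoencoder, and K is the number of discrete codes you want (another hyperparameter to tune!).
- Compute the distance between every encoder output and every codebook entry. This gives a b x K matrix (b is batch size) and is efficiently computable using broadcasting.
- Take the argmin over the codebook dimension with tf.argmin. This gives you, for each batch element, the index of the codebook entry that is closest to the encoder result.
- Retrieve the selected codebook entries with tf.gather.
- Feed these codebook entries to the decoder in place of the encoder outputs.

Obviously this limits the model somewhat, since it can only produce K distinct outputs. You can view this as a kind of clustering of the dataset into K classes.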
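The quantization steps above can be sketched in TensorFlow roughly as follows. This is a minimal sketch, not the course's starter code: the function name quantize, the use of squared Euclidean distance, and the sizes K = 256, d = 64, b = 8 are illustrative assumptions.

```python
import tensorflow as tf


def quantize(encodings, codebook):
    """Replace each encoder output by its nearest codebook entry.

    encodings: float tensor of shape (b, d), the encoder outputs.
    codebook:  float tensor or variable of shape (K, d), one code per row.
    Returns the (b,) indices and the (b, d) selected codebook entries.
    """
    # Broadcasting (b, 1, d) against (1, K, d) gives all pairwise differences;
    # summing the squares over d yields a (b, K) distance matrix.
    diffs = tf.expand_dims(encodings, 1) - tf.expand_dims(codebook, 0)
    distances = tf.reduce_sum(tf.square(diffs), axis=-1)
    # Index of the closest codebook entry for every batch element, shape (b,).
    indices = tf.argmin(distances, axis=-1)
    # Look up those rows of the codebook, shape (b, d).
    codes = tf.gather(codebook, indices)
    return indices, codes


# Illustrative sizes: K = 256 codes of dimension d = 64, batch size b = 8.
K, d, b = 256, 64, 8
codebook = tf.Variable(tf.random.uniform((K, d), -1.0, 1.0), name="codebook")
encodings = tf.random.normal((b, d))
indices, codes = quantize(encodings, codebook)
print(indices.shape, codes.shape)  # (8,) (8, 64)
```

For large b or K, the (b, K, d) intermediate can get big; expanding the squared distance as ||e||^2 - 2 e·c + ||c||^2 avoids materializing it.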
Does the above work? No? Something about no gradients for the encoder?

…The problem is, of course, that the discrete argmin operation does not allow for gradient flow. We can fix this via the “straight-through estimator”. This can be implemented as x + stop_gradient(y - x), where y is the result of the codebook lookup and x is the encoder output.

- In the forward pass, stop_gradient acts as the identity, so the expression evaluates to x + (y - x) = y: the decoder receives the codebook entries.
- In the backward pass, the stop_gradient term contributes no gradient, so the gradient arriving at the decoder input is copied unchanged to x, the encoder output.

Finally, we need to train the VQ-VAE. The loss has several components:

- The reconstruction loss, just as in a regular autoencoder.
- The codebook loss, square(stop_gradient(encodings) - codes), where encodings are the encoder outputs and codes the corresponding codebook entries. This draws the codebook entries closer to the encoder outputs without modifying the encoder.
- The commitment loss, square(encodings - stop_gradient(codes)), which is supposed to encourage the encoder to “stick with” a codebook entry and not jump around too much.

As usual, you will need to balance these losses reasonably well. The paper proposes scaling the commitment loss lower than the codebook loss. If your net doesn’t train properly, you might need to scale up the reconstruction loss.
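Putting the straight-through estimator and the two auxiliary losses together could look like the sketch below. It is a hedged illustration, not the paper's code: the helper names straight_through and vq_losses are made up, the squared errors are averaged with reduce_mean, and beta = 0.25 is the commitment weight reported in the VQ-VAE paper. The toy "encoder" is a single trainable variable, and the reconstruction target of 3.0 is arbitrary, chosen only so the gradients are easy to check by hand.

```python
import tensorflow as tf


def straight_through(x, y):
    # Forward pass: x + (y - x) == y, so the decoder sees the codebook entries.
    # Backward pass: stop_gradient blocks the (y - x) branch, so the gradient
    # w.r.t. the output is passed unchanged to x (the encoder output).
    return x + tf.stop_gradient(y - x)


def vq_losses(encodings, codes, beta=0.25):
    # Codebook loss: pulls the codes toward the frozen encoder outputs.
    codebook_loss = tf.reduce_mean(tf.square(tf.stop_gradient(encodings) - codes))
    # Commitment loss: pulls the encoder outputs toward the frozen codes.
    commitment_loss = tf.reduce_mean(tf.square(encodings - tf.stop_gradient(codes)))
    return codebook_loss + beta * commitment_loss


# Toy setup: an "encoder" with one 2-d output and a codebook with K = 2 entries.
encoder_weights = tf.Variable([[1.0, 2.0]])
codebook = tf.Variable([[0.0, 0.0], [1.0, 1.0]])

with tf.GradientTape() as tape:
    encodings = encoder_weights * 1.0  # stands in for encoder(x)
    diffs = tf.expand_dims(encodings, 1) - tf.expand_dims(codebook, 0)
    indices = tf.argmin(tf.reduce_sum(tf.square(diffs), axis=-1), axis=-1)
    codes = tf.gather(codebook, indices)  # nearest entry is [1.0, 1.0]
    decoder_input = straight_through(encodings, codes)
    # A real model would decode decoder_input; here it is compared to 3.0.
    reconstruction_loss = tf.reduce_mean(tf.square(decoder_input - 3.0))
    loss = reconstruction_loss + vq_losses(encodings, codes)

grad_encoder, grad_codebook = tape.gradient(loss, [encoder_weights, codebook])
# grad_codebook may be IndexedSlices (from tf.gather); densify it for inspection.
grad_codebook = tf.convert_to_tensor(grad_codebook)
```

The encoder receives the reconstruction gradient through the straight-through path plus the (down-weighted) commitment term, while the codebook is updated only by the codebook loss.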
If the model trains reasonably well, generation is simple: Draw a random entry from the codebook and put it through the decoder. This approach has issues, though. Mapping every image to a single quantized code doesn’t work well: you can only generate as many different outputs as you have codebook entries! With a reasonable number (say, 256), this severely limits model capacity and thus output quality. If you tried it like this, you likely got very blurry results.
Instead, VQ-VAEs are typically used to encode data into a grid (or sequence) of codes, e.g. by using a CNN encoder that never flattens the data. That way, you get a grid of codes. Say we have 32x32 inputs, an overall stride of 8 in each direction, and a final convolutional layer with 64 channels: you end up with a 4x4 “image” where each pixel is a 64-d code. Now, the idea is to quantize each code separately. This allows for far more variety in the outputs, since we now have 16 pixels with 256 possible values each, i.e. 256^16 possible combinations (instead of just 256 possible values overall, still sticking with the example of 256 entries in the codebook). You should be able to reach reconstruction quality similar to that of a normal autoencoder!
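Quantizing a grid of codes mostly changes the bookkeeping: flatten the spatial positions into the batch dimension, reuse the same nearest-neighbour lookup, and reshape back. The sketch below assumes a hypothetical encoder of three stride-2 convolutions (8x8 overall stride) on random 32x32 RGB inputs and a 256-entry codebook; none of these choices come from the course code.

```python
import tensorflow as tf


def quantize_grid(features, codebook):
    """Quantize every spatial position of an unflattened feature map separately.

    features: (b, h, w, d) encoder output; codebook: (K, d), shared by all positions.
    Returns (b, h, w) code indices and the (b, h, w, d) quantized grid.
    """
    shape = tf.shape(features)
    flat = tf.reshape(features, (-1, codebook.shape[-1]))          # (b*h*w, d)
    diffs = tf.expand_dims(flat, 1) - tf.expand_dims(codebook, 0)  # (b*h*w, K, d)
    indices = tf.argmin(tf.reduce_sum(tf.square(diffs), axis=-1), axis=-1)
    codes = tf.gather(codebook, indices)                           # (b*h*w, d)
    return tf.reshape(indices, shape[:-1]), tf.reshape(codes, shape)


# Hypothetical encoder: 32x32 -> 16 -> 8 -> 4 via stride-2 convolutions, no Flatten.
encoder = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    tf.keras.layers.Conv2D(32, 4, strides=2, padding="same", activation="relu"),
    tf.keras.layers.Conv2D(64, 4, strides=2, padding="same", activation="relu"),
    tf.keras.layers.Conv2D(64, 4, strides=2, padding="same"),
])
codebook = tf.Variable(tf.random.normal((256, 64)), name="codebook")

images = tf.random.uniform((8, 32, 32, 3))
features = encoder(images)                         # (8, 4, 4, 64)
indices, codes = quantize_grid(features, codebook)
print(indices.shape, codes.shape)                  # (8, 4, 4) (8, 4, 4, 64)
# 16 positions with 256 choices each, instead of 256 outputs overall.
print(256 ** (4 * 4) == 2 ** 128)                  # True
```

The (b, h, w) index grid is exactly the small “image” of discrete codes that a generative model then has to produce.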
…how do we generate new images now? In principle, you would need to sample one code for each of the 4x4 pixels and send this “image” through the decoder. The problem here is that, by sampling each pixel independently, the resulting images will lack global coherence. We need to somehow sample codes in a dependent manner. Which brings us to…
It is possible to view an image as a sequence of pixel values, and set up a generative process accordingly. The inductive bias of this model is questionable, but the practical results are rather strong. In the course repository, you can find a notebook that applies this idea in a very naive fashion to the MNIST dataset. This model ignores the fact that images have rows and columns and thus tends to generate outputs in the wrong place. This also means that, for example, the pixel directly above another one is treated as farther away than the one directly to the left: when a 28x28 MNIST image is read row by row, the pixel above is 28 steps back in the sequence, the one to the left only 1. Overall, it’s just not a good model! Instead, let’s use ideas from the PixelRNN paper, in particular PixelCNN.
PixelCNN is very simple to implement using the kernel_constraint functionality in Keras layers. Here, we can supply a mask that is multiplied with the kernel after each training step, zeroing out the affected components. We need to do this because the network must not look into the future, i.e. it must not make use of pixels that have not been generated yet (according to whatever generation order is being used). See figure 4 in the paper, or figure 1 in the follow-up.
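A masked convolution via kernel_constraint could be sketched as below, for single-channel images and raster (row-by-row) generation order. It follows the two mask types described in the PixelRNN paper: mask "A" (first layer) also hides the centre position, mask "B" (later layers) keeps it. The names pixelcnn_mask and MaskConstraint are made up for this sketch. Keras applies constraints only after optimizer updates, so the freshly initialized kernel is also masked once by hand.

```python
import numpy as np
import tensorflow as tf


def pixelcnn_mask(kernel_size, in_channels, out_channels, mask_type):
    """Causal 0/1 mask for a Keras Conv2D kernel of shape (kh, kw, in, out)."""
    mask = np.ones((kernel_size, kernel_size, in_channels, out_channels), np.float32)
    c = kernel_size // 2
    mask[c, c + 1:] = 0.0   # centre row: hide everything right of the centre
    mask[c + 1:] = 0.0      # hide all rows below the centre row
    if mask_type == "A":
        mask[c, c] = 0.0    # first layer: hide the pixel being predicted
    return mask


class MaskConstraint(tf.keras.constraints.Constraint):
    """Multiplies the kernel by a fixed 0/1 mask (Keras runs this after each update)."""

    def __init__(self, mask):
        self.mask = tf.constant(mask)

    def __call__(self, w):
        return w * self.mask


# First PixelCNN layer: 7x7 kernel, mask "A", 1 input channel, 32 filters.
first_layer = tf.keras.layers.Conv2D(
    32, 7, padding="same", activation="relu",
    kernel_constraint=MaskConstraint(pixelcnn_mask(7, 1, 32, "A")))
first_layer.build((None, 28, 28, 1))
# Constraints only run after updates, so mask the initial kernel once by hand.
first_layer.kernel.assign(first_layer.kernel_constraint(first_layer.kernel))
```

Later layers would use pixelcnn_mask(..., "B") with their own channel counts; 1x1 convolutions need no mask because they never mix positions.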
You can find sample code in the course repository. Use this (or your own implementation) to build a stack of convolutional layers. The final layer should have 256 output units and act as the softmax predictor, just like in the RNN code. Training is straightforward: the target is equal to the input (but note that inputs should be scaled to [0, 1], whereas targets should be category indices from 0 to 255)! Because the layers are masked such that the prediction for each pixel never looks at that pixel itself, this works out just right.
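A small PixelCNN stack and its training setup could be sketched as follows. It is an illustration, not the repository's sample code: CausalMask is a hypothetical constraint that derives the mask from the kernel shape (centre hidden in the first layer, kept in later ones), the layer sizes are arbitrary, and random 8x8 uint8 arrays stand in for a real dataset. The final layer outputs raw logits and the loss uses from_logits=True, which is equivalent to a softmax output layer but numerically more stable. Since Keras applies constraints only after optimizer updates, the initial kernels are masked once by hand.

```python
import numpy as np
import tensorflow as tf


class CausalMask(tf.keras.constraints.Constraint):
    """Hypothetical constraint: zero kernel weights that look at 'future' pixels."""

    def __init__(self, include_centre):
        self.include_centre = include_centre  # False = mask "A", True = mask "B"

    def __call__(self, w):
        kh, kw, cin, cout = w.shape
        mask = np.ones((kh, kw, cin, cout), np.float32)
        mask[kh // 2, kw // 2 + int(self.include_centre):] = 0.0
        mask[kh // 2 + 1:] = 0.0
        return w * mask


def build_pixelcnn(height, width, filters=32, num_layers=4):
    inputs = tf.keras.Input(shape=(height, width, 1))  # values in [0, 1]
    x = tf.keras.layers.Conv2D(filters, 7, padding="same", activation="relu",
                               kernel_constraint=CausalMask(False))(inputs)
    for _ in range(num_layers):
        x = tf.keras.layers.Conv2D(filters, 3, padding="same", activation="relu",
                                   kernel_constraint=CausalMask(True))(x)
    # 1x1 convolutions never mix positions, so they need no mask.
    logits = tf.keras.layers.Conv2D(256, 1)(x)         # 256 classes per pixel
    model = tf.keras.Model(inputs, logits)
    # Constraints only run after updates, so mask the initial kernels too.
    for layer in model.layers:
        if getattr(layer, "kernel_constraint", None) is not None:
            layer.kernel.assign(layer.kernel_constraint(layer.kernel))
    return model


# Stand-in data: random uint8 "images" in place of e.g. MNIST.
images = np.random.randint(0, 256, size=(16, 8, 8), dtype=np.uint8)
inputs = images[..., None].astype(np.float32) / 255.0  # scaled to [0, 1]
targets = images.astype(np.int32)                       # class indices 0..255

model = build_pixelcnn(8, 8)
model.compile(optimizer="adam",
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
model.fit(inputs, targets, batch_size=8, epochs=1, verbose=0)
```

For MNIST you would call build_pixelcnn(28, 28) and pass the real images in the same two forms: scaled floats as inputs, integer intensities as targets.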
With the model trained (which can take a while), you can use it for generation! This proceeds sequentially, one pixel after the other, and is very slow, since a naive implementation runs the whole network once per generated pixel. There is sample code in the repo, but this is not optimized very well, so you may be able to improve on it. Hopefully, your results are better than with the RNN!
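The sequential generation loop can be sketched in plain NumPy, independent of how the model is built. This is a hedged sketch, not the repository's sampler: sample_pixelcnn is a hypothetical name, and it assumes a callable that maps a batch of images scaled to [0, 1] to per-pixel logits of shape (n, h, w, 256).

```python
import numpy as np


def sample_pixelcnn(predict_logits, num_samples, height, width, rng=None):
    """Generate images one pixel at a time, in row-major (raster) order.

    predict_logits: callable mapping a float batch (n, h, w, 1) with values in
        [0, 1] to per-pixel logits (n, h, w, 256); for a trained Keras model this
        could be lambda x: model.predict(x, verbose=0) (an assumption).
    Returns uint8 images of shape (n, h, w).
    """
    rng = np.random.default_rng() if rng is None else rng
    images = np.zeros((num_samples, height, width), dtype=np.uint8)
    for i in range(height):
        for j in range(width):
            # The whole network runs once per pixel: this is why sampling is slow.
            inputs = images[..., None].astype(np.float32) / 255.0
            logits = np.asarray(predict_logits(inputs))[:, i, j, :]  # (n, 256)
            # Softmax over the 256 intensities (shifted for numerical stability).
            probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
            probs /= probs.sum(axis=-1, keepdims=True)
            # Inverse-CDF sampling: one intensity per sample, all samples at once.
            cdf = np.cumsum(probs, axis=-1)
            u = rng.random((num_samples, 1))
            images[:, i, j] = np.minimum((cdf <= u).sum(axis=-1), 255)
    return images
```

With a trained 28x28 model, sample_pixelcnn(lambda x: model.predict(x, verbose=0), 16, 28, 28) would need 784 forward passes for 16 samples; batching many samples per pass is the cheapest speed-up.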
Note: If you get bad results, consider using residual connections (see figure 5 in the paper), which should improve performance considerably. You can also try other changes from the follow-up (linked above) or… the follow-follow-up.
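A residual block in the spirit of figure 5 of the PixelRNN paper might look like the sketch below: the block carries 2h channels, squeezes to h with a 1x1 convolution, applies a masked 3x3 convolution (mask "B"), expands back to 2h, and adds the input. This is one reading of the figure, not a verified reproduction; CausalMask is a hypothetical constraint that builds the mask from the kernel shape, and h = 32 and three blocks are arbitrary choices.

```python
import numpy as np
import tensorflow as tf


class CausalMask(tf.keras.constraints.Constraint):
    """Hypothetical constraint: zero kernel weights that look at 'future' pixels."""

    def __init__(self, include_centre):
        self.include_centre = include_centre  # False = mask "A", True = mask "B"

    def __call__(self, w):
        kh, kw, cin, cout = w.shape
        mask = np.ones((kh, kw, cin, cout), np.float32)
        mask[kh // 2, kw // 2 + int(self.include_centre):] = 0.0
        mask[kh // 2 + 1:] = 0.0
        return w * mask


def residual_block(x, h):
    """x has 2h channels; the output has the same shape as x."""
    y = tf.keras.layers.ReLU()(x)
    y = tf.keras.layers.Conv2D(h, 1)(y)              # 1x1: no mask needed
    y = tf.keras.layers.ReLU()(y)
    y = tf.keras.layers.Conv2D(h, 3, padding="same",
                               kernel_constraint=CausalMask(True))(y)
    y = tf.keras.layers.ReLU()(y)
    y = tf.keras.layers.Conv2D(2 * h, 1)(y)
    return tf.keras.layers.Add()([x, y])            # the residual connection


h = 32
inputs = tf.keras.Input(shape=(28, 28, 1))
x = tf.keras.layers.Conv2D(2 * h, 7, padding="same",
                           kernel_constraint=CausalMask(False))(inputs)
for _ in range(3):
    x = residual_block(x, h)
logits = tf.keras.layers.Conv2D(256, 1)(tf.keras.layers.ReLU()(x))
model = tf.keras.Model(inputs, logits)
for layer in model.layers:  # constraints run only after updates: mask once now
    if getattr(layer, "kernel_constraint", None) is not None:
        layer.kernel.assign(layer.kernel_constraint(layer.kernel))
```

Because the skip path adds tensors position by position, it cannot leak information from future pixels; only the masked 3x3 convolution mixes positions inside each block.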
Also: Transferring this to color datasets is a bit more complicated because of how the masks interact with the color channels. In the PixelRNN paper, every color value is generated one by one (R first, then G given R, then B given R and G), which requires the masks to also split the channels into groups. However, it should also be possible to sample all color values of one pixel independently in parallel, although technically this weakens the model: the three color values of a pixel are then conditionally independent given the previous pixels.
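For the simpler variant, where all color values of a pixel are sampled independently in parallel, the ordinary spatial masks suffice (the first layer hides the whole centre position, i.e. all three channels), and mainly the output head changes: it predicts 256 logits per channel. A hedged sketch follows; color_head and sample_color_pixel are hypothetical names, and the random features tensor stands in for the output of a masked convolution stack.

```python
import tensorflow as tf


def color_head(features, height, width):
    """Map masked-stack features to logits of shape (b, h, w, 3, 256)."""
    logits = tf.keras.layers.Conv2D(3 * 256, 1)(features)  # 1x1: no mask needed
    return tf.keras.layers.Reshape((height, width, 3, 256))(logits)


def sample_color_pixel(logits, i, j):
    """Sample R, G and B of pixel (i, j) independently, given all earlier pixels."""
    pixel_logits = logits[:, i, j]                          # (b, 3, 256)
    flat = tf.reshape(pixel_logits, (-1, 256))              # (b * 3, 256)
    samples = tf.random.categorical(flat, num_samples=1)    # (b * 3, 1), int64
    return tf.reshape(samples, (-1, 3))                     # (b, 3)


features_in = tf.keras.Input(shape=(4, 4, 16))
head = tf.keras.Model(features_in, color_head(features_in, 4, 4))
logits = head(tf.random.normal((2, 4, 4, 16)))             # (2, 4, 4, 3, 256)
# Training: targets are integer color values of shape (b, h, w, 3).
targets = tf.random.uniform((2, 4, 4, 3), maxval=256, dtype=tf.int32)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(targets, logits)
print(sample_color_pixel(logits, 1, 2).shape)              # (2, 3)
```

Sampling all three channels from one forward pass is what makes this variant simpler; the price is that, say, G cannot react to the R value just drawn for the same pixel.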
Hopefully, your autoregressive model can already produce decent samples. But recall that we also have the VQ-VAE to worry about! If the components function reasonably well, you can put them together as follows: