Discussion: June 24th
Deadline: June 23rd, 23:59
In this multi-part assignment, we will try to implement VQ-VAEs and combine them with autoregressive models.
Note: Some “starter code” for this is available in the course repository! Also, feel free to refer back to the paper at any time.
Let’s start with a relatively simple VQ-VAE. Build an autoencoder on a dataset of your choice; this should be easy for you at this point (just take old code if you have it!). Now, turn it into a VQ-VAE as follows:
- Create a codebook: a K x d matrix, where d is the latent dimension of the autoencoder, and K is the number of discrete codes you want – another hyperparameter to tune!
- Compute the (squared) distances between each encoder output and each codebook entry, giving a b x K matrix (b is batch size). This is efficiently computable using broadcasting.
- Apply tf.argmin over the codebook axis. This gives you, for each batch element, the index of the codebook entry that is closest to the encoder result.
- Look up the corresponding codebook entries with tf.gather and feed these to the decoder instead of the raw encoder outputs.

Obviously this limits the model somewhat, since it can only produce K distinct outputs. You can view this as a kind of clustering of the dataset into K classes.
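The lookup steps above can be sketched in NumPy (np.argmin and fancy indexing standing in for tf.argmin and tf.gather; all shapes and values are illustrative stand-ins):

```python
import numpy as np

rng = np.random.default_rng(0)

b, d, K = 4, 16, 256                 # batch size, latent dim, codebook size (example values)
codebook = rng.normal(size=(K, d))   # the K x d codebook matrix
encodings = rng.normal(size=(b, d))  # stand-in encoder outputs

# Squared distances via broadcasting:
# (b, 1, d) - (1, K, d) -> (b, K, d), summed over d -> (b, K).
dists = np.sum((encodings[:, None, :] - codebook[None, :, :]) ** 2, axis=-1)

# Index of the nearest codebook entry per batch element (tf.argmin equivalent).
indices = np.argmin(dists, axis=-1)  # shape (b,)

# Look up the chosen entries (tf.gather equivalent); these go to the decoder.
codes = codebook[indices]            # shape (b, d)
```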
Does the above work? No? Something about no gradients for the encoder?
…The problem is of course that the discrete argmin operation does not allow
for gradient flow. We can fix this via the “straight-through estimator”.
This can be implemented as x + stop_gradient(y - x).
Here, y is the result of the codebook lookup and x is the encoder output: the expression evaluates to y, but its gradient is that of x – the encoder output.

Finally, we need to train the VQ-VAE. The loss has several components:
- The usual reconstruction loss between input and decoder output.
- The codebook loss: square(stop_gradient(encodings) - codes), where encodings are the encoder outputs and codes the corresponding codebook entries. This draws the codebook entries closer to the encoder outputs without modifying the encoder.
- The commitment loss: square(encodings - stop_gradient(codes)), which is supposed to encourage the encoder to “stick with” a codebook entry and not jump around too much.

As usual, you will need to balance these losses reasonably well. The paper proposes scaling the commitment loss lower than the codebook loss. If your net doesn’t train properly, you might need to scale up the reconstruction loss.
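Putting the straight-through estimator and the loss terms together, here is a NumPy forward-pass sketch. The stop_gradient below is a stand-in identity (in TensorFlow you would use tf.stop_gradient, which additionally blocks gradients in the backward pass), and beta = 0.25 is the commitment weight suggested in the paper:

```python
import numpy as np

def stop_gradient(x):
    # Forward identity; in a real framework (tf.stop_gradient) this would
    # also cut the gradient, which a NumPy sketch cannot model.
    return x

rng = np.random.default_rng(1)
b, d = 4, 16
encodings = rng.normal(size=(b, d))  # encoder outputs (stand-ins)
codes = rng.normal(size=(b, d))      # nearest codebook entries (stand-ins)

# Straight-through estimator: the value equals codes, while the gradient
# (in a real framework) would flow to encodings.
quantized = encodings + stop_gradient(codes - encodings)

# Codebook loss: moves codebook entries toward the (frozen) encoder outputs.
codebook_loss = np.mean(np.square(stop_gradient(encodings) - codes))

# Commitment loss: keeps the encoder close to its (frozen) codebook entries.
commitment_loss = np.mean(np.square(encodings - stop_gradient(codes)))

beta = 0.25  # commitment weight from the paper; a hyperparameter
loss = codebook_loss + beta * commitment_loss  # plus the reconstruction loss
```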
If the model trains reasonably well, generation is simple: Draw a random vector from the codebook and put it through the decoder. This approach has issues, though. Compressing each image into a single quantized code doesn’t work well – you can only generate as many different outputs as you have codebook entries! With a reasonable number (say, 256), this severely limits model capacity and thus output quality. If you tried it like this, you likely got very blurry results.
Instead, VQ-VAEs are typically used to encode data into a grid (or sequence) of codes – e.g. use a CNN encoder, but do not flatten the data at any point. That way, you get a grid of codes. Say you have 32x32 inputs, an overall stride of 8x8, and a final convolutional layer with 64 channels – then you end up with a 4x4 “image” where each pixel is a 64-d code. Now, the idea is to quantize each code separately – allowing for far more variety in the outputs, as we now have 16 pixels with 256 possible values each (instead of just 256 possible values overall – still sticking with the example of 256 codebook entries). You should be able to reach similar reconstruction quality to a normal autoencoder!
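To make the running example concrete, here is a NumPy sketch of per-position quantization – the unflattened encoder output is random here, and the shapes match the 4x4x64 example:

```python
import numpy as np

rng = np.random.default_rng(2)
b, h, w, d, K = 2, 4, 4, 64, 256         # batch, grid height/width, channels, codebook size
codebook = rng.normal(size=(K, d))
features = rng.normal(size=(b, h, w, d)) # stand-in CNN encoder output (not flattened)

# Quantize every grid position independently: treat the (b, h, w, d) tensor
# as b*h*w separate d-dimensional codes.
flat = features.reshape(-1, d)                                    # (b*h*w, d)
dists = np.sum((flat[:, None, :] - codebook[None]) ** 2, axis=-1) # (b*h*w, K)
indices = np.argmin(dists, axis=-1).reshape(b, h, w)              # grid of code indices
quantized = codebook[indices]                                     # (b, h, w, d), to the decoder
```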
…how do we generate new images now? In principle, you would need to sample one code for each of the 4x4 pixels and send this “image” through the decoder. The problem here is that, by sampling each pixel independently, the resulting images will lack global coherence. We need to somehow sample codes in a dependent manner. Which brings us to…
Recall that we trained autoregressive models directly on (image) data. Turns out, since the VQ-VAE output above is essentially an image with discrete elements (codebook indices), we can also train autoregressive models on those compressed images! This means that the autoregressive model learns to generate codes that “fit together” well, and then the VQ-VAE decompresses them back to image space.
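A minimal sketch of such dependent, raster-scan sampling – dummy_logits is a placeholder for a trained autoregressive model (e.g. a PixelCNN) and here just returns uniform logits, purely to show the dependency structure:

```python
import numpy as np

rng = np.random.default_rng(3)
h, w, K = 4, 4, 256   # code-grid size and codebook size (example values)

def dummy_logits(grid_so_far, i, j):
    # Placeholder: a trained autoregressive model would return logits over
    # the K codes conditioned on all indices sampled before position (i, j).
    return np.zeros(K)

def sample_index_grid():
    grid = np.full((h, w), -1)
    for i in range(h):
        for j in range(w):
            logits = dummy_logits(grid, i, j)
            probs = np.exp(logits - logits.max())
            probs /= probs.sum()
            grid[i, j] = rng.choice(K, p=probs)  # condition on earlier samples
    return grid

indices = sample_index_grid()  # a 4x4 grid of code indices
# To get an image: look up the codebook entries for `indices` and
# run them through the VQ-VAE decoder.
```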
One more tip: Do not feed the autoregressive model the codebook vectors themselves – instead, map the code indices through the model’s own Embedding layer. This would allow the autoregressive model to choose its own “interpretation” of what the code indices stand for, instead of being forced to work on the original codes.
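A sketch of this lookup, assuming a separate embedding matrix trained jointly with the autoregressive model (in Keras this would be a tf.keras.layers.Embedding layer; emb_dim is a free choice, independent of the codebook dimension):

```python
import numpy as np

rng = np.random.default_rng(4)
K, emb_dim = 256, 32                      # codebook size; embedding width (a free choice)
embedding = rng.normal(size=(K, emb_dim)) # would be trained with the AR model

indices = rng.integers(0, K, size=(4, 4)) # a grid of code indices from the VQ-VAE
tokens = embedding[indices]               # (4, 4, emb_dim) inputs for the AR model
```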