VQ-VAEs were introduced in 2017 for discrete representation learning.
In a standard autoencoder, data points are mapped to arbitrary latent representations.
In variational autoencoders, each data point is mapped to a probability distribution
instead.
Usually, continuous distributions are chosen, with the most common choice being
Gaussians.
VQ-VAEs, on the other hand, encode data to discrete representations.
To achieve this, a regular autoencoder is applied first.
Next, each latent vector is mapped to the closest vector in a codebook, which
is also part of the model and learned alongside the other parameters.
This means that the latent representations are limited to combinations of the
codebook vectors.
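As a minimal sketch (plain Python, with made-up toy vectors), the quantization step simply replaces each latent vector with its nearest codebook entry:

```python
def quantize(latents, codebook):
    """Map each latent vector to its nearest codebook vector (by squared distance)."""
    def sq_dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))
    indices = [min(range(len(codebook)), key=lambda i: sq_dist(z, codebook[i]))
               for z in latents]
    quantized = [codebook[i] for i in indices]
    return indices, quantized

# Toy example: a codebook of three 2-D vectors and two latent vectors.
codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 1.0]]
latents = [[0.9, 1.2], [0.1, -0.2]]
indices, quantized = quantize(latents, codebook)
# indices → [1, 0]: each latent is replaced by its closest codebook entry
```

Note that this only shows the lookup; in a real VQ-VAE, the codebook vectors themselves are learned during training (e.g. via a codebook loss or exponential moving averages).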
For more details, please read the paper linked above.
The discrete latent space has some unique advantages over regular VAEs.
For example, we can compress high-dimensional data to a smaller discrete
representation, and then train an autoregressive model to generate such representations,
which can then be decoded back into the original space.
Such techniques are used in models like MusicGen to generate music with
high sampling rates relatively efficiently.
VQ-VAEs have a major advantage over regular autoencoders in terms of the compression they can achieve in the latent space. As an example, let’s say we are working with color images of size 256x256. Perhaps we train a convolutional autoencoder to compress these to a size of 32x32. On the surface, this seems like a reduction by a factor of 64, since that’s the reduction in the number of pixels. However, we generally have a larger number of channels in our latent representation! Color images have three channels, but our encoding may have a lot more. Let’s say we have 256 channels in the latent space. That is more than 80 times the number of input channels, completely negating the reduction in pixels! Additionally, we usually encode to floats, which are often stored at 32-bit precision, whereas images are often stored at only 8-bit precision. This means our autoencoder blows up the representation by another factor of 4! Overall, our “compressed” latent representation takes more space than the original images. Of course, we could play around with the parameters to alleviate this, but we will have a hard time achieving strong compression while retaining good reconstruction quality.
This is where VQ-VAEs shine: Since we have a limited number of codebook vectors, and we know that each “pixel” in the encoding is one of those vectors, we do not actually need to store the full encodings. Rather, we only need to store, for each pixel, the index of the codebook vector at this point. The codebook, meanwhile, needs to be stored only a single time, no matter how many images we encode. The number of bits per codebook index depends on the size of the codebook. For example, a codebook with 1024 entries would require indices from 0 to 1023, which require 10 bits to store. Thus, in this example, each pixel would only require 10 bits to store the index, down from 32*256 in the example above (32 bits per float times 256 channels). Finally, we can actually achieve proper compression!
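The index arithmetic from this example can be written out directly (the 256-channel, 32-bit figures are taken from the example above):

```python
import math

codebook_size = 1024
bits_per_index = math.ceil(math.log2(codebook_size))  # indices 0..1023 need 10 bits

# Storing the raw encoding instead: 256 float channels at 32-bit precision.
bits_per_raw_pixel = 32 * 256  # 8192 bits per "pixel" of the encoding
# 10 bits vs. 8192 bits per encoded pixel: a reduction of more than 800x
```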
When training autoencoders, there are many moving parts:
Dataset, architecture, latent space structure, reconstruction loss…
I conducted some experiments where I tried to simplify things as much as possible.
I trained a convolutional autoencoder to minimize mean squared reconstruction error
on CIFAR10.
The architecture is fixed, except for the number of channels d in the final
encoder layer, which is varied systematically.
As such, images are encoded to 4x4xd, down from the 32x32x3 input size.
The results can be seen below.
Unsurprisingly, more channels, i.e. a larger latent space, results in better
reconstructions, as the model can simply store more information in a larger
space.
I’m not sure why performance slightly degraded when using d=128; there may
have been some instability in the training as the model starts to overfit.
Subjectively, a loss of around 0.002-0.003 is where the reconstructions start
looking “acceptable”, so in this case, d=16 is a kind of lower bound for
acceptable performance.
Of course, a different (e.g. larger) architecture may be able to compress more efficiently,
and thus manage with a smaller d.
On the other hand, performance plateaus after d=64.
My goal here was not to find the best possible architecture; rather, this just
serves to provide a baseline performance for the following VQ-VAEs.
When using VQ-VAEs, we introduce another main variable: codebook size k.
That is, in addition to d, how many vectors do we allow in the codebook?
Obviously, more vectors can cover the latent space more densely, which should
be beneficial for reconstruction performance, as the quantization becomes more
precise.
Compare the two figures below; each orange cross is one codebook vector.
The first only uses 8 codebook vectors, the second one uses 512.
However, more vectors also means less compression.
As such, we only want to use as many vectors as actually necessary.
Here are the results; this time, each d is represented by a different line,
whereas the new parameter k is on the x-axis.
Things look quite a bit different this time:
While performance still improves with larger d as well as k, there is an
interaction between the two.
To be precise, codebook size k seems to put a hard limit on the benefits of
larger d.
In fact, even with a rather large codebook of 4096 entries, performance maxes out
at d=8, with no benefits from more latent channels.
Also, already at d=8, performance lags far behind the basic no-VQ autoencoder.
It seems an absurd number of codebook entries would be required to catch up.
These observations make sense: For larger d, the overall size of the space
increases exponentially.
As such, we require many more codebook entries to properly “fill out” the space.
Thus, we quickly arrive at a situation where the model cannot benefit from the
larger d, as the codebook vectors simply cannot cover the available space.
This obviously presents us with a bit of a problem. If VQ-VAEs result in unacceptable reconstruction performance, their compression advantage is not useful. Recall that I mentioned a loss of around 0.002-0.003 as a good minimal target for this task. None of the VQ models even comes close to this value; at best, they reach around 0.007. And increasing the codebook size becomes infeasible at some point.
We can understand this mathematically.
Recall that, for k=1024, each “pixel” in the encoded image requires 10 bits
to store the codebook index.
This is the “bit depth” of each sample.
Also, in our architecture, we encode each image to a 4x4 array of indices.
That is, there are 16 vectors per image.
This puts us at 16*10 = 160 bits per image.
Compare this to the original: 32x32 pixels, 3 color channels, 8 bits per value.
This puts us at 32*32*3*8 = 24576 bits per image.
We are compressing by a factor of 153.6.
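The full calculation, reproduced as a short sketch:

```python
# Numbers from the example above: k = 1024 -> 10 bits per index, 4x4 indices per image.
bits_per_index = 10
indices_per_image = 4 * 4
encoded_bits = indices_per_image * bits_per_index  # 160 bits per encoded image

# Original image: 32x32 pixels, 3 color channels, 8 bits per value.
original_bits = 32 * 32 * 3 * 8  # 24576 bits per image

compression_factor = original_bits / encoded_bits  # 153.6
```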
What if we are okay with more bits, i.e. less compression?
We can either increase the number of vectors, that is, compressed image size (say, to 8x8),
or the bit depth.
Increasing the image size can have negative consequences for downstream applications.
For example, if we wanted to train a generative model on the latent space, we only
need to generate 16 vectors for a 4x4 image, whereas we would need 64 vectors for
an 8x8 image, meaning four times more effort.
The other option is to increase bit depth, which is directly related to codebook size k.
Maybe we are okay with 2x less compression, i.e. increasing bit depth from 10 to 20.
But 20 bits correspond to 2**20 codebook entries – over a million!
This is infeasible, and hints at why increasing k is not effective:
Each doubling of codebook size adds only a single bit of information per vector.
Clearly, we need some other method to increase the capacity of our codebooks.
The basic idea of residual VQ (RVQ) is to use a series of codebooks applied in sequence. After applying one codebook, there is generally some degree of quantization error, i.e. the difference between the quantized vector and the pre-quantized encoding. RVQ then applies a second codebook to quantize the quantization error, which is just another vector. This will incur yet another quantization error, which can be quantized via a third codebook, and so on. The final quantization is the sum of all per-codebook quantizations.
This is an efficient way to achieve higher bit-depth:
If one codebook with 1024 entries takes 10 bits, then two such codebooks use 20
bits, with only 2048 vectors overall.
Recall that a single codebook would require over a million vectors for 20 bits.
There is also an intuitive way to understand this:
One codebook with 1024 entries obviously only gives 1024 options for quantized vectors.
But with two codebooks, each entry in the first can be paired with each entry in the
second, giving 1024*1024 combinations, i.e. 2**20 or over a million, the same number
of options as a single codebook with a million entries.
The effect becomes more dramatic with more codebooks, exponentially increasing
the number of possible quantizations.
All in all, this provides an efficient way to significantly increase bit depth
for our VQ-VAEs.
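To make this concrete, here is a small plain-Python sketch of the residual quantization step (the toy codebooks are made up; a real model would learn them):

```python
def rvq_quantize(vector, codebooks):
    """Residual VQ: quantize the vector, then quantize the remaining
    quantization error with the next codebook, and so on."""
    def nearest(v, codebook):
        return min(codebook, key=lambda c: sum((x - y) ** 2 for x, y in zip(v, c)))
    residual = list(vector)
    total = [0.0] * len(vector)
    for codebook in codebooks:
        q = nearest(residual, codebook)  # quantize the current residual
        total = [t + qi for t, qi in zip(total, q)]
        residual = [r - qi for r, qi in zip(residual, q)]
    return total  # sum of the per-codebook quantizations

# A "coarse" first codebook and a "fine" second one (hypothetical values).
coarse = [[0.0, 0.0], [4.0, 4.0]]
fine = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
result = rvq_quantize([3.2, 4.1], [coarse, fine])  # → [3.0, 4.0]
```

The first codebook gets the quantization roughly right ([4.0, 4.0]), and the second refines it by quantizing the leftover error.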
Here are some results for our CIFAR10 task:
Here, d=64 was fixed, as this was sufficient for optimal performance in the no-VQ
condition.
Codebook size and number of codebooks were varied.
As we can see, even with only two vectors per codebook, by using enough codebooks
we can achieve decent performance – actually better than a single codebook with
k=4096 in the previous experiment, even though with 32 codebooks there are only
64 vectors overall.
Using larger codebooks, we can finally approach no-VQ performance.
There are different ways to interpret these results. Having “number of codebooks” on the x-axis is not really fair, since the models using larger codebooks of course have many more vectors overall. We can re-order the curves to have “total number of codebook vectors” on the x-axis:
This is simply k * number_of_codebooks.
This now seems to imply that using many smaller codebooks is actually more efficient
in terms of performance.
So is the answer to just use a huge number of size-2 codebooks?
Not really.
Recall that, at the end of the day, our main concern may be the degree of compression
of the data, i.e. bit depth.
32 codebooks of size 2 may have 64 vectors overall, but the number of bits here
is also 32 – one bit per codebook.
On the other hand, a single codebook of size 64 only requires 6 bits.
Thus, it may be a better idea to sort the x-axis by number of bits required:
This reveals yet a different picture – it seems to barely matter what combination
of codebook size and number we use to achieve a given number of bits!
If anything, this implies that larger codebooks perform slightly better.
Another striking feature is the very clean functional form – it looks like a power law
could be a good fit, for example.
Using this, it may be possible to predict in advance how many bits would be required
to achieve a certain performance.
Finally, we can also use “possible number of quantizations” for the x-axis.
For example, as mentioned earlier, two codebooks of size 1024 allow for around
one million different quantizations.
This looks similar to the previous plot, and that is no surprise – it turns out, the number of bits is just the (base-2) logarithm of the number of quantizations! As such, this is really just a re-scaling of the x-axis.
To finish up, here is one more experiment: Recall that there are two ways to increase the overall number of bits: Increasing bit depth, or increasing the size of the encoded image. I wanted to see how the two relate, so I trained another set of models. These have basically the same architecture, but I cut off one set of layers to stop already at a resolution of 8x8. Results are shown below:
Here, I only tested the regular VQ-VAE, i.e. a single codebook, but with varying
number of entries k.
This implies that, with the same codebook size, the 8x8 models perform better.
But, of course, this is once again not a fair comparison:
At the same bit depth (directly related to k), the 8x8 models have four times
as many vectors in their latent space, and thus use four times more bits.
We can once again equalize the x-axis by number of bits, this time for the whole
encoded image:
Looks like the 8x8 models actually perform worse! This once again shows how important it is to use the correct information on the axes. Of course, this could just be a quirk of the architecture design, since the 8x8 models simply have fewer layers, which could lead to weaker performance. This is not supposed to be an exhaustive test – I just wanted to showcase all the different factors we can vary.
We have seen that the number of bits in the latent space is key for good performance with VQ-VAEs. With a single codebook, it can be difficult to achieve higher bit rates, as these might require too many codebook vectors to be feasible. Residual vector quantization provides an interesting workaround and seems to be a good option for achieving close-to-non-VQ performance. It remains to be seen whether such results generalize to more complex datasets and architectures. Here, other factors may start playing a confounding role, or more complex loss functions than MSE may not show the same predictable behavior. Still, it can be reassuring to see such clear and consistent behavior in the context of deep learning, where we often feel like we are stumbling through the dark when looking for improvements to our models.
One widely applicable technique to improve neural network optimization is batch normalization. When it was initially proposed, BN seemed like a simple way to massively speed up training and improve performance, with very few downsides. Over the years, however, cracks have begun to show: Why BN even works at all is debated (also see here or here or here…) and research has shown problematic behavior (e.g. here).
BN and GANs are combined in some architectures, such as the popular reference DCGAN. Although outdated, this still serves as an important go-to architecture when starting to learn about GANs, with sample code being available on both the Tensorflow and Pytorch websites. Notably, the DCGAN paper mentions that they use BN, but not in the final generator layer nor in the first discriminator layer, as this would cause unstable behavior. However, no explanations for this behavior are given, nor why removing BN from those layers specifically should fix it.
After running into strange issues and seemingly impossible behavior in my own research involving GANs, I decided to further investigate how these networks interact with BN. My findings up to this point are summarized in this blog.
Let’s train a small GAN on a simple toy dataset. See the dataset below; the goal
is to learn a generator G
that essentially transforms noise (drawn from a 2D
standard normal distribution in this case, although this could be any distribution
at all) into data samples. We will use Tensorflow/Keras, although none of the
issues discussed in this post are framework-specific.
First, we set up simple networks for G and D:
import tensorflow as tf
tfkl = tf.keras.layers

generator = tf.keras.Sequential(
    [tfkl.Dense(64),
     tfkl.LeakyReLU(alpha=0.01),
     tfkl.Dense(64),
     tfkl.LeakyReLU(alpha=0.01),
     tfkl.Dense(2)], name="generator")

discriminator = tf.keras.Sequential(
    [tfkl.Dense(64),
     tfkl.LeakyReLU(alpha=0.01),
     tfkl.Dense(64),
     tfkl.LeakyReLU(alpha=0.01),
     tfkl.Dense(1)], name="discriminator")
We can train the GAN like this:
# binary cross-entropy on logits (the final Dense layers have no activation)
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
g_opt = tf.keras.optimizers.Adam()
d_opt = tf.keras.optimizers.Adam()

def train_step(real_batch):
    n_batch = tf.shape(real_batch)[0]
    noise = tf.random.normal([n_batch, 2])
    real_labels = tf.ones([n_batch, 1])
    fake_labels = tf.zeros([n_batch, 1])
    # train g
    with tf.GradientTape() as g_tape:
        fake_batch = generator(noise, training=True)
        d_out_deception = discriminator(fake_batch, training=True)
        deception_loss = -1 * loss_fn(fake_labels, d_out_deception)
    g_grads = g_tape.gradient(deception_loss, generator.trainable_variables)
    g_opt.apply_gradients(zip(g_grads, generator.trainable_variables))
    # train d
    with tf.GradientTape() as d_tape:
        d_out_fake = discriminator(fake_batch, training=True)
        d_out_real = discriminator(real_batch, training=True)
        d_loss = 0.5 * (loss_fn(real_labels, d_out_real) + loss_fn(fake_labels, d_out_fake))
    d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
    d_opt.apply_gradients(zip(d_grads, discriminator.trainable_variables))
    return -deception_loss, d_loss
D is trained to classify real samples as 1 and fake samples as 0, guided by a
standard classification loss (binary cross-entropy). G is trained to maximize
this loss. We alternate one step of training G and one step of training D.
The goal is for the networks to reach an equilibrium point where the distribution
of samples generated by G
matches the data distribution, and D
outputs 0.5
for real and fake samples alike (that is, it “classifies” with maximum uncertainty).
We run this training for 2000 steps using Adam.
Note that we always put in
the full population (2048 samples) as the batch in each training step. Here are
the resulting samples:
As we can see, the generated samples match the data quite well. The loss for D
in this case is around ln(2) ~= 0.69, which indicates outputs of 0.5 for all samples,
as desired.
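As a quick sanity check of this equilibrium value: if D outputs probability 0.5 for every sample, binary cross-entropy gives -ln(0.5) = ln(2) regardless of the label:

```python
import math

p = 0.5  # D's output at the equilibrium point
loss_real = -math.log(p)      # label 1: -log(p)
loss_fake = -math.log(1 - p)  # label 0: -log(1 - p)
# both equal ln(2) ≈ 0.6931
```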
Now let’s add BN to our models. Of course, we don’t actually need to do this for this simple experiment, but let us assume we want to scale up our experiments to more complex data/networks, where BN could be helpful^{1}.
discriminator = tf.keras.Sequential(
    [tfkl.Dense(64),
     tfkl.BatchNormalization(),
     tfkl.LeakyReLU(alpha=0.01),
     tfkl.Dense(64),
     tfkl.BatchNormalization(),
     tfkl.LeakyReLU(alpha=0.01),
     tfkl.Dense(1)], name="discriminator")
We train the network in the exact same way as before and get this result:
It doesn’t look quite right, does it? Almost seems like the entire distribution is slightly shifted. Knowing that GANs have stability issues, we might just try again with a new initialization:
Now it’s definitely shifted in a different direction! And yet, in both cases,
the observed loss is again around ln(2), indicating that D cannot tell the
distributions apart at all.
Let’s make the issue more obvious. We slightly change the loss used to train G
like this:
deception_loss -= 0.005*tf.reduce_mean(fake_batch[:, 0])
This additional loss term encourages the generator to shift its generated samples positively on the x-axis (i.e. to the right).
The scaling is
hand-tuned: It must be large enough to actually produce relevant gradients for
the network, but if it is too large, G
will essentially just ignore the
adversarial game and infinitely decrease the loss by moving its samples further
and further to the right. Empirically, with the scaling chosen as it is, G
will not
shift the samples if this causes it to lose out in the adversarial game (i.e. D
easily tells the distributions apart since one is shifted). G
will only shift
samples if it can somehow do this without D
noticing.
Using this loss, we get the results below:
This time, the shift is very obvious. Yet D still incurs a loss of ln(2),
meaning it is completely fooled by the generated samples!
Interestingly, if we remove BN but keep the additional loss term, we get this
result:
No shift has occurred! It seems that without BN, D
picks up on the shift and
the resulting worse loss for G
causes it to keep the samples where they should be.
It seems like our D
is somehow blind to shifts in the data. The overall shape
looks good, just not the location! The issue here is actually quite obvious when
we take another look at the training code for D
:
# train d
with tf.GradientTape() as d_tape:
    d_out_fake = discriminator(fake_batch, training=True)
    d_out_real = discriminator(real_batch, training=True)
    d_loss = 0.5 * (loss_fn(real_labels, d_out_real) + loss_fn(fake_labels, d_out_fake))
d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
d_opt.apply_gradients(zip(d_grads, discriminator.trainable_variables))
Note that we put the real and fake samples into D
separately. Recall that BN
normalizes features using batch statistics. This means that the real samples
will be normalized to mean 0 and variance of 1, and the fake samples will also be
normalized to mean 0 and variance of 1^{2}.
This means that, if the fake samples are transformed by an affine-linear function
(shifted and/or scaled by a constant), this will be normalized away by D
, making
it completely insensitive to such differences between distributions! In case this
is not quite understandable yet, here is a simple example:
Say, you have two “real” and two “fake” samples with only a single feature. E.g.
a_real = [0.5, 1.5]
and a_fake = [6., 8.]
. It would be very easy to tell these
two apart. However, when each batch is normalized separately with their respective
mean and standard deviation, they both result in [-1, 1]! Thus, any model
working with the normalized data can never tell the two batches apart.
This is quite disastrous, as it means that G
does not actually have to match
the real data properly. Here, we have only seen affine-linear shifts being undetected.
However, since BN is usually applied in all layers, it may be that higher-order
differences between the distributions could also be normalized away in deeper layers
(but so far, I have not been able to show this experimentally).
In the simple example above, what if we combine the two batches into one? Using
a = [0.5, 1.5, 6., 8.] and normalizing this, we get
[-1.12815215, -0.80582296, 0.64465837, 1.28931674]. As we can see, it is still
easy to tell apart real and fake samples, e.g. with a threshold at 0.
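Both cases can be checked with a few lines of plain Python (numbers from the example above):

```python
def normalize(batch):
    """Normalize a 1-D batch to zero mean and unit variance (as BN does per feature)."""
    mean = sum(batch) / len(batch)
    var = sum((x - mean) ** 2 for x in batch) / len(batch)
    return [(x - mean) / var ** 0.5 for x in batch]

a_real = [0.5, 1.5]
a_fake = [6.0, 8.0]

# Split batches: both collapse to [-1, 1] and become indistinguishable.
split_real = normalize(a_real)  # → [-1.0, 1.0]
split_fake = normalize(a_fake)  # → [-1.0, 1.0]

# Joint batch: real samples stay clearly below the fake ones.
joint = normalize(a_real + a_fake)  # → [-1.128..., -0.806..., 0.645..., 1.289...]
```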
This implies a straightforward solution: Use joint batches to train D instead^{3}!
This is what the training step for D
would look like:
# train d
with tf.GradientTape() as d_tape:
    combined_batch = tf.concat([fake_batch, real_batch], axis=0)
    combined_labels = tf.concat([fake_labels, real_labels], axis=0)
    d_out_fake = discriminator(combined_batch, training=True)
    d_loss = loss_fn(combined_labels, d_out_fake)
d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
d_opt.apply_gradients(zip(d_grads, discriminator.trainable_variables))
We train a model (with BN and without the additional loss term) with this and get the results below.
Oh no… This looks bad. Even stranger, the losses are 6.62 for G and 0.17 for D.
Remember the equilibrium point 0.69? The loss for G is higher, indicating that
it’s “winning” the game against D. At the same time, the loss for D is lower,
indicating that it’s also winning! How can this be possible?
Clearly, the above issue must be related to joint batches somehow, as that was the
only change we made. Let us forget about those for a moment and just go back to
split batches. Instead, I want to show the issue that originally made me launch
this investigation. With the training step for D
being as before (split batches),
let’s look at the code for G
instead, specifically the role of D
:
# train g
with tf.GradientTape() as g_tape:
    fake_batch = generator(noise, training=True)
    d_out_deception = discriminator(fake_batch, training=True)
    deception_loss = -1 * loss_fn(fake_labels, d_out_deception)
g_grads = g_tape.gradient(deception_loss, generator.trainable_variables)
g_opt.apply_gradients(zip(g_grads, generator.trainable_variables))
We are calling D with training=True. But if you think about it, we aren’t really
training D here, are we? What if we just set training=False? It seems more
appropriate. Well, we get results like below:
This is even worse than above! Losses are 4.23 (G) and 0.007 (D), once again
indicating that both networks seem to be winning the game at the same time. I want
to stress that this is really not possible since we use the standard “zero-sum”
formulation of the GAN game^{4}.
The only explanation is that somehow, G
and D
must be playing different games.
Let’s finally try to fix this.
Both failures shown above have to be related to Batchnorm:

- It is the only part of D that introduces dependencies between batch elements, possibly causing different behavior between using joint or split batches.
- It is the only part of D that behaves differently depending on whether training is True or False.

Specifically, during training, Batchnorm uses batch statistics to normalize
features. During inference, however, this is undesirable as you
don’t want predictions for one element to depend on those for other elements –
you might not even have batches to run on, just single examples! For this reason,
Batchnorm also keeps a moving average of batch statistics. As such, batch statistics
are accumulated over the course of training, and when using training=False
these
accumulated statistics are used for normalization instead. This gives us a hint why
approach 2 above didn’t work: There must be some difference between batch statistics
and accumulated statistics.
When using split batches for training D, both real and generated samples are
normalized using their respective batch statistics. At the same time, these batch
statistics are used to update the moving average population statistics. However,
we only have one set of such statistics, while the real and generated statistics
will likely be quite different (especially early in training). This causes
the moving averages to end up somewhere between the real and generated statistics, as
exemplified below using the means:
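As a toy illustration (all numbers made up): suppose the real batches have feature mean 2 and the generated batches mean -1, and both alternately update the same moving average with a momentum of 0.99 (the Keras BatchNormalization default). The average settles roughly halfway between the two, matching neither distribution:

```python
real_mean, fake_mean = 2.0, -1.0   # hypothetical per-batch feature means
momentum = 0.99                    # moving-average momentum, as in Keras BN

moving_mean = 0.0
for _ in range(2000):
    for batch_mean in (real_mean, fake_mean):
        moving_mean = momentum * moving_mean + (1 - momentum) * batch_mean

# moving_mean ≈ 0.49: between the two means, representative of neither batch
```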
When training G with training=False in D, these accumulated statistics are used
to normalize the generated batch input in D. But since these statistics do not match
the generated batch statistics, the inputs are not properly normalized! This issue gets worse
in every layer that uses BN, and leads to D with training=False essentially
being a different function than with training=True. And this, finally, solves the
riddle of how both networks can “win” at the same time: G
is basically playing
a different game! Clearly, we cannot train our GAN like this.
Remember that using joint batches for D did not work, either. The issue is actually
similar: When training D, features are normalized using the statistics of the joint
batch. But when training G, we usually only feed a generated batch, which, of course,
has different statistics. Again, this causes a mismatch between the D that is being
trained, and the D that G trains on.
Clearly, we have to resolve this mismatch somehow. Funnily enough, one possible fix
is to actually combine the two approaches that did not work, i.e. D with
training=False when training G, and joint batches for training D. The code
then looks like this:
# train g
with tf.GradientTape() as g_tape:
    fake_batch = generator(noise, training=True)
    d_out_deception = discriminator(fake_batch, training=False)
    deception_loss = -1 * loss_fn(fake_labels, d_out_deception)
g_grads = g_tape.gradient(deception_loss, generator.trainable_variables)
g_opt.apply_gradients(zip(g_grads, generator.trainable_variables))

# train d
with tf.GradientTape() as d_tape:
    combined_batch = tf.concat([fake_batch, real_batch], axis=0)
    combined_labels = tf.concat([fake_labels, real_labels], axis=0)
    d_out_fake = discriminator(combined_batch, training=True)
    d_loss = loss_fn(combined_labels, d_out_fake)
d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
d_opt.apply_gradients(zip(d_grads, discriminator.trainable_variables))
Now, when training D, the batches are normalized using the joint statistics. Also,
the statistics of both real and fake samples are accumulated into the moving average.
When training G, we only pass a generated batch to D, but because we set
training=False, the accumulated statistics are used, not the ones of the batch.
Finally, the mismatch
is resolved! As we can see below, the model produces good samples, and it is
even “immune” to affine-linear shifts in the data (like the original model without BN)
due to the joint batch normalization.
Unfortunately, while the above works well for this simple toy example, I have still observed issues with both networks “winning” at the same time in larger-scale experiments. This implies that there must still be a mismatch in the normalization. Most likely, the accumulated population statistics (which are updated with a slow-moving average) lag behind the current batch statistics, especially early in training when the model parameters change rapidly. This could perhaps be fixed by making the average move faster, but this would make the population statistics more volatile and may reduce inference performance. To finish our investigation, I would like to discuss a few more possible solutions:
- Using joint batches also when calling D during the training step of G, together with training=True again. This guarantees consistent behavior in D for both training steps. The only annoyance is that this requires additional computation, as we have to put an entire real batch through D that we otherwise wouldn’t need, and that doesn’t influence the training of G directly (except through the changed statistics).
- Using a different normalization method. It is Batchnorm that causes D to ignore affine-linear shifts (at the population level) in the data. If we look at Instance Normalization, this should be even worse, as it would normalize affine-linear shifts even at the instance level. However, since IN only makes sense for data with spatial dimensions (images etc.), we cannot test it in the simple framework used for this post. This may be content for follow-up work. Layernorm, on the other hand, seems to be able to detect these shifts (why/how exactly, I still have to think about), and thus Groupnorm should, as well. Finally, Weightnorm doesn’t modify the features at all, so it should have no issues.
- Removing BN from the first layer of D. This prevents the “normalizing away” of affine-linear shifts, and the first-layer features can pick up on them and “preserve” them throughout the network. The DCGAN paper, for example, simply states that they did this (also removing BN from the final layer of G) because the training would otherwise be unstable. However, they do not offer any explanations for why BN may be harmful. The question for me is whether BN in later layers could still be problematic, since it might normalize away “higher-order” differences in features that have the same mean and variance, but differ in higher moments.
- Matching moments explicitly. Recall that our “broken” training attempts were caused by a mismatch between the distributions of the real and fake samples. Batchnorm normalizes the mean and variance. What if we force those to be the same between the distributions? Specifically, we transform our generated samples like this:
# x_samples holds the full set of real training samples
train_real_means, train_real_vars = tf.nn.moments(x_samples, axes=[0])
train_fake_means, train_fake_vars = tf.nn.moments(fake_batch, axes=[0])
fake_batch = (fake_batch - train_fake_means) / tf.math.sqrt(train_fake_vars)
fake_batch = fake_batch * tf.math.sqrt(train_real_vars) + train_real_means
Essentially, this works like a batchnorm layer with beta and gamma fixed to the
statistics of the real data. Since we now force the statistics of the generated
samples to equal the real data, there should be no more distribution differences, right?
Sort of. This actually works, if D
has BN only in the first layer. It makes
sense when you think about it: We equalized the means and variances of the real
and fake data, so they will be “appropriately” normalized by the first-layer BN.
However, the distributions are of course not identical; there will still be
differences in higher moments. BN in later layers will still normalize these
away, but because we did not equalize them between the distributions, we once
again run into the old issues of G
and D
playing different games (samples
being normalized with different statistics). To me, this indicates that BN in
later layers can still be problematic, and just removing it from the first layer
(as recommended in the DCGAN paper) may not be sufficient.
So is BN terrible and useless? Probably not. The issues discussed here only arise due to us training with two distinct populations (real and generated data). BN will likely still work well in many standard tasks in discriminative and generative modeling.
However, I think it’s important to be aware of the implications of normalizing this or that, completely removing certain information from your model in the process. In my personal experience, most normalization techniques are somewhat situational, and it is rarely clear in advance which one will work best for a certain task or model. Getting an idea of their respective quirks, however, can limit the search space and save lots of time.
And stay away from GANs, kids. ;)
Whether or not G includes BN is not relevant here. ↩
To be precise, the features after the dense layers are normalized. However, these are just linear transformations, so the (also linear) normalization happening after the first dense layer, but before the non-linearity, results in the same phenomenon. ↩
Note that reference implementations such as the Tensorflow/Pytorch DCGAN code use split batches! ↩
To be precise, for G we are only using half the loss (only on generated samples). However, since the D loss is the average between that and the loss on real samples, if the G loss is, say, 4, the D loss would have to be at least 2. ↩
I probably could have kept this post for myself, since the main reason I’m writing this is to force myself to better formalize and structure these ideas. But I also wouldn’t mind discussing these topics with other people. ;)
This will be a look at generative models for music from the perspective of a deep learning researcher. In particular, I will be taking the standpoint that such models should aim to possess some sort of creativity: Producing truly novel and “interesting” artifacts. To put it another way: They should model the process, not the data. I realize that many of the aspects I bring up are already being considered and actively researched in other communities (e.g. computational creativity). However, my main goal is to bring up specifically why some of the current research directions in deep generative modeling are (IMHO) misguided and why resources might be better spent on other problems.^{1}
Generative modeling may be summarized as: Given a set (or more generally, a domain) of data x, build a model of the probability distribution p(x). In fact, most “modern” deep generative modeling frameworks (such as GANs or VAEs)^{2} do not actually represent this distribution, focusing instead on producing samples from p(x). The implications of this (e.g. preventing many potential applications of a generative model, such as inpainting or density estimation), while certainly important, are not the topic of this discussion. Instead, I want to focus on the case where generating outputs is all we care about.
While such models are mostly built and tested in the image domain (especially faces), attempts at creating music “from nothing” are becoming more ambitious (e.g. Music Transformer, MuseNet, or most recently Jukebox). These models are certainly great accomplishments of engineering, since musical data (especially on the raw audio level) is very high-dimensional with dependencies across multiple (and very long) time spans. But what is their value, really? No matter how many layers the model has, how many units per layer, or which fancy connection schemes and training tricks are used – at the end of the day, the model will capture statistical relations between data points, because that’s what we ask it to do. Better models and/or smart conditioning on factors such as musical style, composers, lyrics etc. may mask this somewhat, but it doesn’t change the fundamental nature of these models – produce things that are like the things they have seen (or rather, been shown) before.^{3} The inadequacy of this approach will be discussed further in the following sections, where I will sometimes refer to such models as copy models (please take this as somewhat tongue-in-cheek).
Please note: I am aware that there are contexts/applications where training a generative model to “copy” a distribution is actually the goal, and there is nothing wrong with that. However, models like the ones mentioned above are usually presented “as is”, with their ability to generate an endless stream of music as the main selling point. Interestingly, they are often explicitly advertised as generating music “in a certain style”, which IMHO masks their limitations somewhat by pretending that generating hours of Mozart-like music (for example) is the whole point. Of course, there is some value here – namely, exploring/showcasing model architectures that are capable of the kind of long-term structure needed to produce music. I’m certainly not proposing to get rid of deep models altogether – their incredible expressiveness should be leveraged. But I believe that at some point, the relentless scaling-up should be stopped, or at least halted for a bit, and the insights be applied to more creative approaches to making music.
As an approach radically different from copy models, it might be possible to start generating artifacts (e.g. pieces of music) without any reference data whatsoever. This will likely result in pieces that seem very alien to us, since they are not at all grounded in our own experience. Still, I believe it could be interesting to see where, for example, purely information-theoretical approaches could take us. That is, art should be predictable (so it is “understandable”), but not too predictable, since that would be boring/not stimulating. The process could additionally be equipped with simple sensory priors (e.g. related to harmony) to make the results more familiar. Such models could be used to investigate many interesting questions, for example: Under which models/assumptions is human-like music possible to evolve? When adding more assumptions, does it eventually become inevitable? Speaking of evolution…
Given a fixed data distribution for training, a generative model will be “done” eventually. That is, it will have converged to the “best” achievable state given the architecture, data, learning goals, training procedure etc. If we then start sampling thousands upon thousands of outputs (compositions) from the model, these will all come from the exact same (albeit possibly extremely complex) distribution. Diversity can be achieved by using conditional distributions instead, but these will still be stationary.
It should be clear that this is not a reasonable model of any creative process, nor will it ever create something truly novel. On the contrary, such a process should be non-stationary, that is, always evolving. New genres develop, old ones fall out of favor. Ideas are iterated upon. New technology becomes available, fundamentally disrupting the “generative distribution”. Such things should (in my opinion) be much more interesting to model and explore than a literal jukebox.
Concretely, I believe that concepts from research on open-endedness should be very interesting to explore here. One example might be co-evolving generators along with “musical tastes” that can change over time. Speciation could lead to different genres existing in parallel. A minimal criterion approach could guarantee that only music that actually has listeners can thrive, while at the same time making sure that listeners like something instead of just rejecting everything for an easy “win”. Importantly, this allows for modeling the apparent subjectivity/taste that plays a big role in human appreciation of art, without relying on black-box conditioning procedures using opaque latent codes or similar approaches.
To expand upon the speculation on “ex nihilo” generative models at the end of the last section, it could be interesting to train a copy model to initialize an open-ended search process which is then perhaps guided by more general principles/priors. This would allow for exploring possible evolutions of existent musical genres.
Generative modeling of music is usually done at one of two levels. The first one is the symbolic level, where music is represented by MIDI, ABC notation, piano rolls or some other format. What is common to such representations is that they use (often discrete) “entities” encoding notes or note-like events in terms of important properties such as the pitch and length of a tone. Importantly, there is no direct relation to audio in these symbols – the sequences need to be interpreted first, e.g. by playing them on some instrument. This implies that the same MIDI sequence can sound vastly different when it is interpreted via two different instruments. This is arguably already a problem in itself, since widely used symbolic representations lack means of encoding many factors that are important in contemporary music (electronic music in particular). In that regard, it is quite telling that many symbolic models are trained on classical (piano) music. Here, the instrument is known and fixed, and so it can be assumed that a “sensible” sequence of symbols will sound good.^{4}
However, there is a second problem related to the interpretation of musical symbols, which is perhaps easier to miss. Namely, the symbols have absolutely no meaning by themselves. Previously I said that they usually encode factors such as the pitch or length of a tone – but the exact relationships are imposed by human interpretation. Take the typical western twelve-tone equal temperament (which MIDI note pitches are also commonly mapped to) as an example: Here, tones twelve steps apart are one octave apart (i.e. they double in fundamental frequency), tones seven steps apart (a fifth) are in a frequency relation of 3:2, etc. Generally, each tone increases in frequency by about 6% compared to the next-lower one. Such intervals undoubtedly play an important role in human perception of music. But these relations are completely absent from a symbolic note representation. For a model training on such data, there might as well be five tones to an octave, or thirteen, or… The concept of intervals does not arise from symbolic data, and thus a model trained on such data cannot learn about it.
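These numbers are easy to check: in twelve-tone equal temperament, each semitone step multiplies the fundamental frequency by 2^(1/12):

```python
semitone = 2 ** (1 / 12)

print(round(semitone, 4))        # 1.0595 -- each step raises the frequency by ~6%
print(round(semitone ** 12, 4))  # 2.0    -- twelve steps (an octave) double it
print(round(semitone ** 7, 4))   # 1.4983 -- seven steps (a fifth), close to 3:2 = 1.5
```

Note that the fifth is only approximately 3:2 in equal temperament – an approximation good enough that we perceive it as the “pure” interval.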
Then why do such models manage to produce data that sounds “good”, with harmonic intervals we find pleasing? This is simply because the models copy the data they receive during training. If the data tends to use certain intervals and avoid others, the model will do so, as well. The difference is that the training data was generated (i.e. the songs were composed) with a certain symbol-to-sound relationship in mind (e.g. twelve-tone equal temperament). However, this relationship is lost on the model, which merely copies what it has been taught without understanding the ultimate “meaning” (in terms of actual sound). In fact, this seems incredibly close to John Searle’s famous Chinese Room argument. Thus, unless one wants to view music generation as an exercise in pure symbol manipulation, the symbolic level seems unfit for any kind of music generation that is not content with merely copying existing data, although possible remedies could be to use more expressive symbols that relate more closely to the audio level (e.g. using frequencies instead of note numbers), or equip the model with strong priors informed by this relationship.
Aside from the symbolic level, it is also possible to directly generate data on the audio (waveform) level.^{5} However, this approach has so far been lagging behind symbolic models in terms of structure, coherence and audio quality. Some models were developed on the single-note level (e.g. GANSynth), with a focus on quality. While such models are interesting for creative applications in human-in-the-loop scenarios (e.g. sound design), they are obviously not capable of producing interesting musical sequences. Still, there have been examples of modeling sequences in the waveform domain with some success (e.g. DAA).
Recently, OpenAI released their Jukebox model, which scaled up waveform generative models to levels far beyond what has been seen before. The fact that a single model can produce samples of such variety, conditioned on styles and even on lyrics, is astounding. However, there are still some issues with generating at the audio level:
It may be possible to combine symbolic and waveform approaches to achieve the best of both worlds. Essentially, this means using a symbolic-level model to produce sequences of symbols, and then a waveform model that translates those symbols into sound. This preserves many advantages of symbolic models (e.g. explicit, interpretable representations and specific ways of making sound) while also allowing the model to “connect” with the domain we eventually care about (audio).
While this sounds good in theory, there are of course problems with this approach, too. The main one is probably how to formulate a joint model for symbols and sound. A major obstacle here is that it is not possible to backpropagate through discrete symbols. Since most symbolic models output soft probability distributions over symbols, this is not a problem in the pure setting. But a joint model would probably not be able to work with such soft outputs, since it would be like “pressing every piano key a little bit, but one more than the others”. Still, there are workarounds for this issue, such as vector quantization with straight through estimators – or dropping gradient-based methods entirely and using alternatives like reinforcement learning or evolutionary computing instead.
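As a sketch of the straight-through workaround mentioned above (in TensorFlow; `quantize_st` and the codebook layout are my own illustrative choices, not taken from any of the cited models): the forward pass outputs the hard nearest-codebook symbol, while the backward pass pretends the operation was the identity, so gradients can still reach the symbolic model:

```python
import tensorflow as tf

def quantize_st(z, codebook):
    # z: (batch, d) continuous outputs; codebook: (K, d) learned symbol vectors
    dists = tf.reduce_sum((z[:, None, :] - codebook[None, :, :]) ** 2, axis=-1)
    z_q = tf.gather(codebook, tf.argmin(dists, axis=-1))  # hard, non-differentiable pick
    # straight-through: forward value is z_q, gradient w.r.t. z is the identity
    return z + tf.stop_gradient(z_q - z)
```

In a full VQ-VAE-style setup this is combined with auxiliary losses pulling codebook vectors and encoder outputs towards each other; the snippet only shows the gradient trick itself.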
Besides this problem at the symbolic level, there is also one in the symbol-to-audio pipeline: This generator needs to be differentiable, too. This means we cannot simply train a symbolic model with a “real” piano (samples), since the instrument cannot be backpropagated through. Alternatively, using standard neural network architectures (e.g. Wavenet) can lead to artifacts and/or slow generation as discussed before. Personally, I am really interested in approaches like DDSP that preserve differentiability while incorporating a sensible inductive bias for audio, leading to much better quality with simpler models.
A possible hybrid approach could go like this:
Colton, Heath & Ventura argue that this third point is a key component that is lacking in many generative systems. We can find analogues, however, in some modeling frameworks: (Variational) Autoencoders have an inference network (encoder) that can process data, which would include its own outputs. However, it is not clear how one would connect this to “appreciation”, since the main point of the encoder usually is to simply map the data to a lower-dimensional representation. Still, perhaps it could be easier to work in this space than in the data space directly.
On the other hand, autoregressive models as well as flow-based models (which generalize the former) can explicitly compute probabilities for a given data point, which might be taken as a proxy for “quality”. A model could use this to reject bad samples (e.g. that resulted from an unfortunate random draw) on its own. This is troublesome, however, since it is not clear a priori what a “high” probability is, and accordingly what kind of score one should strive for. This is particularly true in the (common) case where the data is treated as continuous, and the probabilities computed are actually densities. Also, this approach seems inappropriate for judging novelty – truly novel work would likely receive a low probability and thus be difficult to differentiate from work that is simply low-quality, which would also receive a low score.
Additionally, none of these models use their “self-judging” abilities to actually iterate and improve on their own outputs. This is fairly common in a creative process: Create something (perhaps only partially), judge which parts are good/bad and improve on the ones that are lacking. Here, I find self-attention approaches such as the transformer interesting: The model can essentially take multiple turns in creating something, looking at specific parts of its own output and use this information to iterate further. However, current transformer models usually do not produce actual outputs (in data space) at each layer; instead they compute on high-dimensional hidden representations and only produce an output at the very end.
Given our evolutionary history, I believe it’s safe to say that perception came first, and the ability and desire for creativity arose out of these capacities. At the same time, generative and perceptual processes could also be tightly interlinked inside a model, e.g. using a predictive coding framework. At this time, I don’t know enough about PC to really make a judgement (or go into more detail), however.
Likely the biggest challenge in modeling (human) creativity is that art is usually “about” something, meaning that it relates in some way to the artist’s own experience in the world. As such, properly approaching this subject seems to require solving strong AI. However, there may be ways to at least make steps towards a solution via simpler methods. One example could be multi-modal representations. As humans, we are able to connect and interrelate perceptions from different senses, e.g. vision, touch and hearing. We can also relate these perceptions to memories, abstract concepts, language etc. It seems obvious that such connections inform many creative artifacts. For example, program music provides a “soundtrack” that fits a story or series of events. Such music is neither creatable nor understandable without understanding language/stories (which in turn requires general world knowledge). On a more personal level, an artist may create a piece of music that somehow mirrors a specific experience, say, “lying at night at the shore of a calm mountain lake”.
Models that simply learn to approximate a given data distribution (limited to the modality of interest) clearly cannot make such connections.^{6} However, this could be different for a model that learns about audio and vision (for example) concurrently. As long as there is some connection between modalities, e.g. via a shared conceptual space (embeddings are an extremely popular method and could be a simple way of achieving this to a first approximation), it should be possible for the model to connect the visual concepts it is learning about with the audio dimension. This, like the other proposals in this text, is obviously an extremely rough sketch with many, many details to be considered – but this requires research, not blog posts.
To summarize:
Each of these points offers several directions for future research to explore. It is possible that none of these proposed methods/directions will result in anything comparable to copy models, in terms of surface-level quality, for a very long time. However, I believe it is important to break the mould of trying to make progress by throwing humongous amounts of compute at highly complex data distributions. Instead, generative music (at least for creative purposes) should start from first principles and accept that the results might be “lame” for a while. In the long term, this has the potential to teach us about music, about creativity in general, and about ourselves. Can Jukebox do that?
Besides, including detailed reviews of CC literature would make this post excessively long. ↩
Please note, I will not be providing citations for general deep learning concepts that I would believe practitioners to be familiar with, nor a few other things – I’m a bit lazy and this is not a publication. ↩
To take a more extreme view: The only reason these models produce anything “new” is due to limited capacity and inherent stochasticity of the data. If they could literally copy everything perfectly, they would. ↩
Another reason for the preference for such data sets is likely that they are widely available without copyright issues, which is a big problem with musical data. The fact that they use a single instrument also makes them much more straightforward to model. ↩
We can also generate audio at the spectrogram level. This tends to be easier, but then the problem is how to invert the spectrograms to audio without loss of quality. ↩
This would also include the creation of emotional music; this cannot be created without “knowledge” of emotions. Except, of course, if the model learns to copy a database of existing emotional music – the common approach. ↩
Here are some facts I learned about the city:
The first day of the conference mainly consisted of two tutorial sessions. There were three topics for each slot, so you had to pick one. Since I was there with Sebastian, we decided to split up and each visit different topics. Instead of discussing the tutorials themselves, however, I would like to discuss my personal takeaways as to how you (IMHO) should structure such a session (and what you should avoid).
These aren’t necessarily the “best papers”, but colored by my own preferences or things I just found “cool”. Note that the order in this post is simply the order in which the papers were presented.
The paper by Bittner et al. is an interesting case of “meta research”, showing the phenomenon of different people working with differing versions of the “same” dataset, as well as the significant impact this can have on the results. While they only sample a small number of exemplary datasets, it makes you wonder how many research projects are/were affected by problems like this.
The authors also propose a Python library for unified distribution and verification of MIR datasets. This is great, if only for the fact that datasets that do not come from the same source often come in different formats requiring separate preprocessing procedures etc. A common repository can reduce the load of having to write new processing code for each new dataset.
This paper by Parmer et al. investigates the information-theoretic “complexity” of various western music styles and its development over time. While the methodology can be questioned – e.g. the information measure is limited and not verified with regard to human perception of complexity – there are some interesting findings about how different genres seem to “prioritize” different areas of complexity, how the different areas developed over time, and how Billboard Top 100 songs differ (or don’t) from the general population of songs.
Choi et al. tackle the problem of unsupervised transcription in this paper, based on how a human might do it: Listen to a piece of music, try to play it, and adjust your play to fix any errors on your part (i.e. differences between what you heard and what you played).
The core idea is to use an encoder-decoder approach with a fixed decoder (your “instrument”). By limiting the decoder to essentially produce impulse responses for certain instruments, the encoder is forced to produce the corresponding transcriptions. This of course has many limitations, e.g. a limited set of “sounds” the decoder can produce, as well as the decoder needing to be differentiable. Making this approach work for non-drum sounds would probably need a lot of extra work, but in my opinion it’s still a cool idea.

Another interesting thing is their use of the sparsemax activation, which I didn’t know about before – a kind of sparse (but differentiable) alternative to softmax.
A special shoutout goes to their training dataset, which apparently “was crawled from various websites”. Way to foster reproducibility!
Invertible neural networks are, essentially, able to map back from their output to the input. Kelz et al. introduce this concept to MIR tasks in their paper. The networks are very close to flow-based generative models, but there is the problem that in classification tasks, we usually don’t want the output to carry all information in the input (or even be the same size as the input), which makes inverting the network rather difficult. To fix this issue, there is an “auxiliary” output that carries the extra information needed for invertibility, but not for the actual task of interest.
Invertible neural networks primarily promise better interpretability of trained models, which is definitely needed in deep learning. For example, in the case of transcription, a given symbolic (transcribed) piece of music can be inverted to give an example sound that would be transcribed this way. This way, one can check whether the concepts that the model learned make sense intuitively. As an added bonus, we get a generative model “for free” by training a discriminative one.
Grachten et al. propose a kind of “neural equalizer” that automatically attenuates resonances in music. They show that a network working directly on raw audio performs on par with hand-crafted feature pipelines. I suppose this isn’t super impressive to most people, but I like the idea of incorporating AI/ML into music production. Possible future work includes processing a piece of music such that it adheres to some desired spectral profile (i.e., which frequencies should be present how strongly?). Check the paper.
Not my topic at all, but this dataset represents a huge effort: Thousands of videos, 40 dancers, several genres, up to nine cameras… Plus, it’s all free and open. Cheers to Tsuchida et al.!
The folks at Google Magenta presented a paper on efficient neural audio synthesis. What I find more relevant is the follow-up, currently under review for ICLR, which is basically what I wanted to do to kick off my PhD. Oh well. Check it out though, the examples are quite impressive. And yet, lots of work still to be done…
Müller et al. probably didn’t “need” to write this paper since the notebooks kind of stand on their own, but I suppose it is rather like a “companion paper” since, unfortunately, papers are still the most important thing in research. Read the paper, check out the notebooks, read the FMP book… it’s all good stuff.
At a glance, the paper by Lattner et al. is a bit too complex for me (hahaha), but this is definitely something I’ll be playing around with in the near future in order to gain some understanding. The idea is to learn invariances to several “simple” kinds of transformations such as transposition or time-shifting, although the authors also demonstrate uses in the image domain (e.g. rotation). Unfortunately, the paper leaves out most of the visualizations of the learned filters that were on the poster, which looked really intriguing.
This paper by MacKinlay et al. sounded super cool (and he was really nice to talk to at the poster!), but I have to admit that this goes way over my head mathematically. Another one I’m going to have to do some tinkering with. :) Unfortunately they don’t include sound examples in the paper, but it provides a method for “musical style transfer” (i.e. transferring the timbre of one signal onto the melody of another). While this isn’t “new” per se, the approach sounds quite sensible, plus it’s IMHO more attractive than just “letting a neural network do it”, which most style transfer solutions nowadays seem to go for.
The LBD session was for “experimental” stuff that wasn’t ready in time for the regular deadline; it was also “just” posters (and optional demos), no papers. Unfortunately, there was very little time given the sheer amount of stuff, so here are just some quick highlights:
There was an “unconference” session with some interesting topics. Unfortunately, it overlapped with the LBD session, so I ended up not going… Shame, the whole conference was set up in a “single-track” way, but it seems like they just wanted to do a bit too much on Friday, so this kinda fell by the wayside. Maybe next year!
A few more general takeaways:
Overall, ISMIR 2019 was a great time! Met plenty of nice people, ate some great food, got lots of new input… And motivation to submit something to next one so we can go to Montreal next year. :) See you there!
This tutorial is based on one that was previously found on the Tensorflow website. You can check that one for additional conceptual guidance, however the code snippets found there are intended for old Tensorflow versions (1.x). Since the official TF website is now lacking a comparable tutorial (the simple MNIST tutorials use Keras instead of low-level concepts), the following is supposed to offer an updated version working in Tensorflow 2.0. It is intended as a supplementary tutorial for Assignment 1 of our Deep Learning class and assumes that you already went through the other posts linked there.
Download this simple dataset class and put it in the same folder as your script/notebook. It’s just a wrapper for simple production of random minibatches of data.
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from datasets import MNISTDataset
We make use of the “built-in” MNIST data in Tensorflow. We plot the first training image just so we know what we’re dealing with – it should be a 5. Feel free to plot more images (and print the corresponding labels) to get to know the data! Next, we create a dataset via our simple wrapper, using a batch size of 128. Be aware that the data is originally represented as uint8 in the range [0, 255], but MNISTDataset converts it to float32 in [0, 1] by default. Also, labels are converted from uint8 to int32.
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
plt.imshow(train_images[0], cmap="Greys_r")
data = MNISTDataset(train_images.reshape([-1, 784]), train_labels,
                    test_images.reshape([-1, 784]), test_labels,
                    batch_size=128)
We decide on the number of training steps and the learning rate, and set up our weights to be trained with random initial values (and zero biases).
train_steps = 1000
learning_rate = 0.1
W = tf.Variable(np.zeros([784, 10]).astype(np.float32))
b = tf.Variable(np.zeros(10, dtype=np.float32))
The main training loop, using cross-entropy as a loss function. We regularly print the current loss and accuracy to check progress.
Note that we compute the “logits”, which is the common name for pre-softmax values. They can be interpreted as log unnormalized probabilities and represent a “score” for each class.
In computing the accuracy, notice that we have to fiddle around with dtypes quite a bit – this is unfortunately common in Tensorflow.
for step in range(train_steps):
    image_batch, label_batch = data.next_batch()
    with tf.GradientTape() as tape:
        logits = tf.matmul(image_batch, W) + b
        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=label_batch))
    gradients = tape.gradient(loss, [W, b])
    W.assign_sub(learning_rate * gradients[0])
    b.assign_sub(learning_rate * gradients[1])

    if not step % 100:
        predictions = tf.argmax(logits, axis=1, output_type=tf.int32)
        acc = tf.reduce_mean(tf.cast(tf.equal(predictions, label_batch),
                                     tf.float32))
        print("Loss: {} Accuracy: {}".format(loss, acc))
We can use the trained model to predict labels on the test set and check the model’s accuracy. You should get around 0.9 (90%) here.
test_predictions = tf.argmax(tf.matmul(data.test_data, W) + b, axis=1,
                             output_type=tf.int32)
accuracy = tf.reduce_mean(tf.cast(tf.equal(test_predictions, data.test_labels),
                                  tf.float32))
print(accuracy)
The release of Tensorflow 2.0 is supposedly around the corner (at the time of writing, the current version is rc0), and with it comes the promise of a more streamlined and intuitive API through things like full Keras integration and eager execution. However, new ways of doing things also bring new problems. In this post, I want to summarize some common issues along with ways to avoid or fix them. Most of these come from own experience or questions on Stackoverflow. It is intended mostly as a compendium for new users and people taking our classes, but perhaps others can profit as well. Note that I might update this post in the future if more things come up (or there might be a part 2 instead)!
tf.function promises to make the switch between eager and graph mode easy – develop, prototype and debug in eager execution, then slap on this decorator for production-level performance. That’s the idea – in practice, there are so many quirks with this thing that one could write a whole series of posts on this alone – and in fact people have done so already. That link points to a three-part post discussing this topic at length. I highly recommend reading it in detail, but I want to include some “highlights” here:
To understand this issue, note what a tf.function-decorated function actually
does under the hood: The first time it is called, it is compiled into a graph,
and then any other time the function will simply execute the graph instead –
the Python function is basically “ignored”. Consider this simple example:
@tf.function
def fun():
    print("hello")
fun()
fun()
>>>hello
As we can see, the print statement is only executed once even though we called the function twice – it doesn’t make it into the compiled function.
This is, however, not the full story: Actually, the function is compiled once for each input signature instead. This has dramatic implications if the function accepts Python numeric types, where each new value actually leads to a new input signature!! If this sounds a bit complicated, just consider the following example:
import time
@tf.function
def dummy_fun(x, step):
    return 5*x
start = time.time()
for step in range(1000):
    dummy_fun(0, step)
stop = time.time()
print(stop-start)
>>>12.037055969238281
The function is compiled anew every time it is called with a new step count (which is every time we call it)! This will slow down execution dramatically. Luckily, the fix is simple: Use tensors for such “changing values” instead, since different tensor values do not count as a new input signature.
start = time.time()
for step in tf.range(1000):
    dummy_fun(0, step)
stop = time.time()
print(stop-start)
>>>0.40102267265319824
So, whenever your decorated functions are conspicuously slow (especially if they become slower after adding the decorator), you might want to check your input parameters for Python numbers! Note that this is a fairly common scenario as we may use step counters to control stuff like learning rate decay, saving a model regularly etc.^{1}
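To make this concrete, here is a small sketch of the tensor-based pattern applied to learning rate decay; the function name decayed_lr and the 0.99 decay factor are made up for illustration:

```python
import tensorflow as tf

# Sketch: pass the changing step counter as a tensor, so the function is
# only traced once instead of once per step value.
@tf.function
def decayed_lr(base_lr, step):
    # simple exponential decay driven by the step counter
    return base_lr * 0.99 ** tf.cast(step, tf.float32)

lrs = [float(decayed_lr(0.1, tf.constant(s))) for s in range(3)]
print(lrs)
```

Since every `tf.constant(s)` has the same signature (a scalar int32 tensor), no retracing happens even though the value changes on each call.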
Gradient tapes are a new kind of abstraction in TF 2.0. They basically replace
tf.gradients. The idea is: With eager execution, there is no static
computational graph, meaning no way to trace computations and thus no way to do
backpropagation. Gradient tapes offer a way to temporarily trace computations as
needed so that we can still use TF’s symbolic differentiation capabilities (it
would be pretty useless otherwise!). Once again, however, it’s easy to run into
problems…
In TF 1.x, there is the concept of collections that globally keep track of
things like trainable variables. In TF 2.0, this doesn’t exist anymore: You need
to keep track of your variables yourself. Often, you will do this via
tf.keras.Model
instances, which have a convenient trainable_variables
property. Consider this:
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
var_list = model.trainable_variables
>>>ValueError: Weights for model sequential_1 have not yet been created. Weights
>>>are created when the Model is first called on inputs or `build()` is called
>>>with an `input_shape`.
Whoops! The model was never built, so it has no variables yet. In the current
version this leads to a crash, but older pre-releases actually executed
perfectly fine – with model.trainable_variables
simply empty! That would
mean that your fancy training loop just went through computing gradients for and
updating no variables at all… Thus, make sure you only ever take variable
lists from your fully-built models:
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.build((None, 784)) # for MNIST ;)
var_list = model.trainable_variables
Alternatively, always using model.trainable_variables
explicitly (instead of a
shortcut assignment like above) can also prevent mistakes, but it can be
cumbersome in some situations.
A common question is this: I want to find gradients with respect to the input to my network, e.g. to find how sensitive the predictions are to certain parts of the input, with the network itself staying fixed. I tried this but it doesn’t work:
input_ = ... # just get a tensor from somewhere
model = ... # same for the model
with tf.GradientTape() as tape:
    # let's say we are interested in the class with index 7
    logits_seven = model(input_)[:, 7]
grad_for_inp = tape.gradient(logits_seven, input_)
print(grad_for_inp)
>>>None
The issue lies with what kind of computations GradientTape
actually traces: By
default, it will store all computations related to any tf.Variable
it comes
across, and nothing else. In particular, since your network input is usually
just a Tensor
and not stored in a variable, related computations are not
traced and so no gradients can be computed.^{2} Once again, the fix is actually really simple: You
need to tell the tape what to trace (or “watch”)!
with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(input_)
    logits_seven = model(input_)[:, 7]
grad_for_inp = tape.gradient(logits_seven, input_)
print(grad_for_inp)
>>>tf.Tensor(
>>>[[[[ 0.03662432 -0.0254075 0.06999005]
>>> [ 0.00372662 -0.0231829 0.01369272]
>>> [-0.06823001 -0.0217168 -0.02034823]
>>> ...
Note that I passed an extra parameter to tell the tape not to trace the variables it comes across (i.e. all the model parameters). This isn’t necessary for correctness, but should make the whole thing a little more efficient.
Here’s the example from above again:
with tf.GradientTape() as tape:
    tape.watch(input_)
    logits = model(input_)
grad_for_eight = tape.gradient(logits[:, 7], input_)
print(grad_for_eight)
>>>None
It broke again! What happened? Note that this time, I do the indexing into the
“interesting” class outside of the gradient tape context. This means that the
tape basically loses track of where this tensor came from and cannot compute
the gradients anymore. As a rule of thumb, anything you put into tape.gradient
should come straight out of the tape context without any modifications!
Another common example would be when you have multiple losses (e.g.
classification and regularization losses) and add them to a “total loss”
outside the tape context. This won’t work!
Keras is “the” high-level interface in TF 2.0, and it is arguably much more
convenient than the cumbersome Estimator interface or chaining tf.layers.
But once again, it does not come without pitfalls…
If using batch normalization, you might get strange warnings like this one:
>>>W0924 09:53:23.799460 140659773245184 optimizer_v2.py:979] Gradients do not
>>>exist for variables ['batch_normalization_1/moving_mean:0',
>>>'batch_normalization_1/moving_variance:0'] when minimizing the loss.
Looking at the code, it may well include something like this:
grads = tape.gradient(xent, model.variables)
optimizer.apply_gradients(zip(grads, model.variables))
This can be subtle: We are using model.variables
instead of
trainable_variables
. These are different! variables
stores anything that
has a “state” that needs to be stored over the course of time. In the case of
batchnorm, this includes the “population statistics” batchnorm uses during
inference (instead of minibatch statistics). These are not used during training
and so no gradients can be computed (and you wouldn’t want this anyway!).
There are of course other cases besides batchnorm, but the root cause is often
the same: You are including variables in your optimization procedure that have
no business of being there. Often, using model.trainable_variables
can fix
this.
Keras models have functionality that allows us to easily execute the full model in a single line. How about this (re-using an example from above)?
with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(input_)
    logits_seven = model.predict(input_)[:, 7]
grad_for_inp = tape.gradient(logits_seven, input_)
print(grad_for_inp)
>>>AttributeError: 'numpy.dtype' object has no attribute 'is_floating'
Why does this fail? All we wanted to do was the forward pass of the model. It
turns out that model.predict
returns a numpy array, and this quite literally
interrupts the “tensor flow”, meaning no gradients can be computed
either.^{3}
Instead, make sure to always use Keras models as callables as in the examples
above. This still holds if you try to be smart and “just go back to Tensorflow”
again:
with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(input_)
    logits_seven = model.predict(input_)[:, 7]
    logits_seven = tf.convert_to_tensor(logits_seven)
grad_for_inp = tape.gradient(logits_seven, input_)
print(grad_for_inp)
>>>None
Do note that this applies to anything involving numpy arrays – gradients cannot propagate through these operations!! This has always been the case, but is arguably a bigger problem in TF 2.0 where it is so tempting to mix between eager execution and numpy arrays.
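Here is a distilled sketch of that failure mode, without any Keras model involved:

```python
import tensorflow as tf

# A round trip through numpy severs the tape's trace, so the gradient
# silently comes back as None.
x = tf.Variable(1.0)
with tf.GradientTape() as tape:
    y = x * 2.0
    z = tf.constant(y.numpy()) * 3.0  # leaves the "tensor flow"
grad = tape.gradient(z, x)
print(grad)  # None
```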
In this case, simply removing the decorator and using
Python’s range
actually results in the fastest performance (only 0.0006 seconds).
It seems like in this simple case, the overhead from even the first function
compilation as well as handling GPU data transfer with tf.range
is too
much. Sometimes it can pay off to stay eager! ↩
This is likely another case
where raising an error would be preferable, but currently it just returns None
as a gradient… ↩
Although to be precise, this specific error is due to a different reason. ↩