VAE-GAN Part7: Batch KLD
In the previous post, I hypothesised that the explosion of the KL-Divergence when optimizing feature matching occurs because the output of the feature extraction part of the discriminator is not regularized, so the gradient values can become large. I then thought that simply adding Batch Normalization at the end of the feature extraction would fix it. It did not, however. With Batch Normalization the feature matching error is kept in check, but it no longer seems to provide a useful gradient back to the encoder/decoder.
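For reference, the change I tried was roughly the following. This is a minimal sketch, not the actual discriminator used in this series; feat_dim and the layer layout are placeholders.

import torch.nn as nn

# Hypothetical sketch: normalize the feature-matching output by appending
# BatchNorm1d to the feature extraction head. feat_dim is a placeholder.
feat_dim = 256
feature_head = nn.Sequential(
    nn.Linear(feat_dim, feat_dim),
    nn.ReLU(),
    nn.BatchNorm1d(feat_dim),  # the added normalization layer
)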
While looking for a way to suppress the explosion of the feature matching error so that the KL-Divergence behaves better, I also came to wonder whether computing the KLD over the batch could help. To investigate this, I ran a couple of experiments.
Batch KLD
First, on top of the previous definition of Batch KLD, I incorporated a moving average of the batch statistics.
# Get mu and logvar from encoder
z_mu, z_logvar = encoder(input_image)
# Generate latent sample
z_std = torch.exp(0.5 * z_logvar)
sample = z_mu + z_std * torch.randn_like(z_std)
# Compute mean and variance of the samples over the batch
sample_mean = torch.mean(sample, dim=0)
sample_var = torch.var(sample, dim=0)
# Apply moving average
mean = momentum * current_mean + (1 - momentum) * sample_mean
var = momentum * current_var + (1 - momentum) * sample_var
# Cache moving stats
current_mean = mean.detach()
current_var = var.detach()
# Compute KLD
logvar = torch.log(var.clamp(min=1e-12))
kld = - 0.5 * (1 + logvar - mean.pow(2) - var)
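For completeness, here is one way the above could be packaged as a module, with the running statistics (current_mean, current_var) kept as buffers. This is a minimal sketch under those assumptions, not my actual implementation, and the class name BatchKLD is made up.

import torch
import torch.nn as nn

class BatchKLD(nn.Module):
    """Sketch of the Batch KLD term with moving-average statistics."""
    def __init__(self, latent_dim, momentum=0.9):
        super().__init__()
        self.momentum = momentum
        # Running statistics carried over between iterations
        self.register_buffer("current_mean", torch.zeros(latent_dim))
        self.register_buffer("current_var", torch.ones(latent_dim))

    def forward(self, sample):
        # Statistics of the latent samples over the batch dimension
        sample_mean = torch.mean(sample, dim=0)
        sample_var = torch.var(sample, dim=0)
        # Moving average, as in the snippet above
        mean = self.momentum * self.current_mean + (1 - self.momentum) * sample_mean
        var = self.momentum * self.current_var + (1 - self.momentum) * sample_var
        # Cache the moving statistics without tracking gradients
        self.current_mean = mean.detach()
        self.current_var = var.detach()
        # KLD against the standard normal, per latent dimension
        logvar = torch.log(var.clamp(min=1e-12))
        return -0.5 * (1 + logvar - mean.pow(2) - var)

As in the snippet above, the returned KLD is per latent dimension; how it is reduced and weighted by $\beta$ follows the previous posts.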
In contrast, the typical KLD computation (referred to as Single Point KLD hereafter) is as follows:
mu, logvar = encoder(input_image)
var = logvar.exp()
kld = - 0.5 * (1 + logvar - mu.pow(2) - var)
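Both snippets evaluate the same closed-form expression, the KL divergence from a diagonal Gaussian to the standard normal, per latent dimension:

$$D_{KL}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, 1)\big) = -\frac{1}{2}\left(1 + \log \sigma^2 - \mu^2 - \sigma^2\right)$$

The only difference is whether $\mu$ and $\sigma^2$ are the per-sample outputs of the encoder (Single Point KLD) or the moving-average statistics of the latent samples over the batch (Batch KLD).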
Experiment 1
I ran two sets of experiments with different parameter sets. In the first experiment, the parameters for the adjustment of $\beta$ are as follows, the same as in the previous post:
beta_step = 0.1
initial_beta = 10.0
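The adjustment rule itself is the one from the previous posts; schematically it is something like the sketch below, where target_kld and the exact update logic are assumptions on my part rather than a restatement of the actual code.

# Hypothetical sketch of the adaptive beta update; the real rule is in
# the previous post. target_kld is an assumed parameter.
def update_beta(beta, kld_value, target_kld, beta_step):
    if kld_value > target_kld:
        return beta + beta_step  # KLD above target: strengthen the penalty
    return max(beta - beta_step, 0.0)  # below target: relax the penalty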
In this experiment, I changed how the KLD is computed: Single Point KLD, and Batch KLD with momentum 0.9, momentum 0.1, and momentum 0.0.
Observations
1. The reconstruction error improved with the Batch KLD constraint.
2. When the momentum value for Batch KLD is small, the KLD values for the test cases deviate from the target KLD.
3. The feature matching error values become lower with Batch KLD.
4. $\beta$ grows larger with Single Point KLD and with Batch KLD at momentum 0.9.
In addition to the above, I recorded some statistics of the latent variables produced by the encoder, Z_MEAN and Z_STDDEV. For Z_MEAN, Batch KLD has a broader spread than Single Point KLD. For Single Point KLD, the standard deviation of the latent points (Z_STDDEV) is distributed near 1.0, but that does not happen for Batch KLD. This is expected because, with Batch KLD, the distribution of the latent samples as a whole is optimized towards the normal distribution, rather than each sample individually. However, the Z_STDDEV values are so small that sampling from the latent distribution makes virtually no difference.
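Recording these statistics amounts to something like the following sketch; the logging backend and tag names are assumptions, and encoder, input_image, and step refer to the same objects and training iteration counter as in the Batch KLD snippet above.

import torch
from torch.utils.tensorboard import SummaryWriter

# Hypothetical logging sketch: record the latent mean and standard
# deviation produced by the encoder at the current training step.
writer = SummaryWriter()
z_mu, z_logvar = encoder(input_image)
z_std = torch.exp(0.5 * z_logvar)
writer.add_histogram("Z_MEAN", z_mu.detach(), global_step=step)
writer.add_histogram("Z_STDDEV", z_std.detach(), global_step=step)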
Experiment 2
For the second set of experiments, the parameters for the adjustment of $\beta$ are as follows:
beta_step = 0.01
initial_beta = 1.0
I changed the KLD computation in the same manner as in Experiment 1: Single Point KLD, and Batch KLD with momentum 0.9, momentum 0.1, and momentum 0.0.
In this experiment, observations 1-4 from Experiment 1 also hold.
For this experiment, the distribution of the latent parameters shows a somewhat different trend: Z_STDDEV at the beginning of training (before the fake samples start to collapse) is more spread out.