Fun with ML on Splatoon 2

VAE-GAN Part5: VAE-GAN with adaptive β

In the previous post, I successfully mapped the latent distribution around the origin while reducing the reconstruction error, and as a result I could create fake images from randomly sampled latent vectors. However, the adoption of Batch Normalization still felt like cheating, so I looked for a way to achieve similar performance without Batch Normalization. Along the way I found this paper by Peng et al. via this Medium article. The model discussed in the paper is the Variational GAN, which is different from VAE-GAN, but it has exactly the same requirement I am handling: enforcing the distribution of an intermediate output to be Gaussian. The paper proposes adjusting $\beta$ so that the KL Divergence always stays close to a target value.

... enforcing a specific mutual information budget between $\mathbf{x}$ and $\mathbf{z}$ is critical for good performance. We therefore adaptively update $\beta$ via dual gradient descent to enforce a specific constraint $I_{c}$ on the mutual information. $$ \beta \leftarrow \max\left(0, \beta + \alpha_{\beta} \left(\mathbb{E}_{\mathbf{x} \sim \tilde{p}(\mathbf{x})}\left[KL\left[E(\mathbf{z}|\mathbf{x}) \,\|\, r(\mathbf{z})\right]\right] - I_c\right)\right) $$

which is implemented as follows in the official code:

new_beta = beta - self.beta_step * (self.target_kl - avg_kl)
new_beta = max(new_beta, 0)
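To see the update rule in isolation, here is a minimal stdlib-only Python sketch. `update_beta` is a hypothetical helper that mirrors the two lines above. `toy_kl(beta) = 1 / (1 + beta)` is an invented stand-in for "the average KL you get when training with this β" (a larger penalty gives a smaller KL), not something measured from the model. With a target of 0.2, β should settle where `toy_kl(beta) == 0.2`, i.e. at β = 4.

```python
def update_beta(beta, avg_kl, target_kl, beta_step):
    # Same rule as the official code: beta rises when the measured KL is above
    # the target, falls when it is below, and is clipped at 0.
    new_beta = beta - beta_step * (target_kl - avg_kl)
    return max(new_beta, 0.0)


def toy_kl(beta):
    # HYPOTHETICAL response of the average KL to beta (larger penalty, smaller KL).
    # Chosen only for illustration; it is not measured from the VAE-GAN.
    return 1.0 / (1.0 + beta)


def run_toy(steps, target_kl=0.2, beta_step=1.0, beta=0.0):
    # Alternate "measure the KL" and "update beta", as a training loop would.
    for _ in range(steps):
        beta = update_beta(beta, toy_kl(beta), target_kl, beta_step)
    return beta


beta = run_toy(1000)
print(round(beta, 4), round(toy_kl(beta), 4))  # prints 4.0 0.2
```

In this toy model β climbs monotonically from 0 toward 4 and never overshoots. In real training the KL also depends on the network weights, so β keeps drifting as the model changes.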

I adopted this method in my VAE-GAN and ran a couple of experiments. I also modified the KL Divergence loss computation so that KLD is computed over the batch samples for each latent dimension, like Batch Normalization. The typical implementation of KLD looks somewhat weird to me. The following code is from PyTorch's official VAE example.

# mu and logvar are encoder outputs
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
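For reference, the line above is the closed-form KL divergence between a diagonal Gaussian and the standard normal, summed over the batch index $i$ and the latent dimension $j$ (writing $\log\sigma^2$ for `logvar`):

$$ \sum_{i}\sum_{j} KL\left[\mathcal{N}(\mu_{ij}, \sigma_{ij}^2) \,\|\, \mathcal{N}(0, 1)\right] = -\frac{1}{2}\sum_{i}\sum_{j}\left(1 + \log\sigma_{ij}^2 - \mu_{ij}^2 - \sigma_{ij}^2\right) $$

Each term is non-negative and vanishes only when $\mu_{ij} = 0$ and $\sigma_{ij}^2 = 1$, so every individual posterior is pulled toward $\mathcal{N}(0, 1)$.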

In this implementation, KLD is 0 when $\mu$ is 0 and $\sigma^2$ is 1. But $\mu$ and $\sigma^2$ (`mu` and `logvar.exp()`) are encoder outputs, and together they describe the distribution of the latent vector for one specific input image. This implementation therefore pushes every per-sample distribution toward the standard normal, instead of making the distribution of latent samples across inputs follow the standard normal. I do not know how this implementation is justified, but so far I have not found any alternative implementation. In my new implementation, I modified it to the following.

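import torch


def reparameterize(mu, logvar):
    # z = mu + eps * std with eps ~ N(0, I), so gradients flow to mu and logvar
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std


# mu and logvar are encoder outputs of shape (batch, latent_dim)
sample = reparameterize(mu, logvar)


def kld(samples):
    # Fit a Gaussian per latent dimension over the batch (dim 0 is the batch dimension)
    mean = samples.mean(dim=0)
    # Clamp out-of-place to avoid log(0)
    var = samples.var(dim=0).clamp(min=1e-12)
    logvar = torch.log(var)
    # KL(N(mean, var) || N(0, 1)) for each latent dimension, shape (latent_dim,)
    return -0.5 * (1 + logvar - mean.pow(2) - var)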
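To make the difference between the two formulas concrete, here is a small stdlib-only Python sketch (no PyTorch) that mirrors both on a single latent dimension. The data are invented for illustration: 1,000 hypothetical encoder outputs whose means are drawn from $\mathcal{N}(0, 1)$ and whose standard deviation is a fixed 0.1. The aggregate of the reparameterized samples is then close to $\mathcal{N}(0, 1)$, so the batch-wise KLD is near zero, while the per-sample KLD stays large because no single posterior is standard normal.

```python
import math
import random


def per_sample_kld(mus, logvars):
    # Mirrors the PyTorch-example formula, averaged over the batch:
    # KL(N(mu_i, var_i) || N(0, 1)) for every sample i.
    terms = [-0.5 * (1 + lv - mu ** 2 - math.exp(lv)) for mu, lv in zip(mus, logvars)]
    return sum(terms) / len(terms)


def batch_kld(samples):
    # Mirrors kld() above for one latent dimension: fit a Gaussian to the batch
    # (unbiased variance, like torch.var) and compare it with N(0, 1).
    n = len(samples)
    mean = sum(samples) / n
    var = max(sum((s - mean) ** 2 for s in samples) / (n - 1), 1e-12)
    return -0.5 * (1 + math.log(var) - mean ** 2 - var)


random.seed(0)
# Hypothetical encoder outputs: means spread like N(0, 1), small fixed std of 0.1.
mus = [random.gauss(0.0, 1.0) for _ in range(1000)]
logvars = [math.log(0.1 ** 2)] * len(mus)
# Reparameterized samples z = mu + eps * std, as in reparameterize().
samples = [mu + random.gauss(0.0, 1.0) * math.exp(0.5 * lv) for mu, lv in zip(mus, logvars)]

print(f"per-sample KLD: {per_sample_kld(mus, logvars):.3f}")
print(f"batch KLD:      {batch_kld(samples):.3f}")
```

Note that the batch-wise version only matches the first two moments of the batch, so it cannot distinguish a standard normal from any other distribution with zero mean and unit variance.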

Result

The KL-Divergence is distributed around the target value, and the value of $\beta$ grows as training progresses, which matches the observations so far. When running experiments, I observed that a large initial $\beta$ value helps training, which is also stated in the paper (interesting, but I have no idea why).

The reconstruction loss on the test set goes below 1.0. It could go down further, but fake image generation then starts to collapse.

The latent samples encoded from input images are distributed around the origin.

The feature-matching errors grow as training proceeds.

Observations

Reconstruction Quality

There is no noticeable difference between target KL-Divergence values of 0.1 and 0.2.

Target KLD: 0.1
Target KLD: 0.2

Mode Collapse

As in the previous experiment, fake images collapse to a black screen.

Model & code

Code and model are available here