
VAE-GAN Part4: AE-GAN with BN Cheating

Timelapse of fake images generated from the same random samples. Many samples collapse after a while.

In the previous post, I realized that what I thought was a VAE-GAN was not actually a VAE-GAN, and that all the strange phenomena were caused by the latent samples not following a normal distribution (an exploding KL divergence).
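For reference, the KL term in question is the usual closed-form KL divergence between the encoder's diagonal Gaussian and the standard normal prior. A minimal PyTorch sketch (the function name is mine, not necessarily what is used in the repository):

```python
import torch

def kl_divergence(mu, logvar):
    """Closed-form KL(N(mu, diag(sigma^2)) || N(0, I)) for a diagonal Gaussian.

    mu, logvar: (batch_size, latent_dim) tensors produced by the encoder.
    Returns the KL summed over latent dimensions and averaged over the batch.
    """
    return (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=1)).mean()
```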

After a couple of trials and a code review, I realized that the encoder is optimized to reduce two objective functions: the KL-divergence and the feature matching error. Assuming that the KL-divergence implementation is correct, the only explanation for the exploding KL-divergence was this feature matching optimization. I tweaked the optimization process so that when optimizing the feature matching error, only the decoder parameters are updated. After training the model a couple of times, the KL-divergence was no longer exploding, but the reconstruction error after the same amount of training as before was not as good. In fact, the GAN part was no longer contributing to the training, because the discriminator loss converges to zero very quickly.
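Concretely, the tweak amounts to giving each loss its own optimizer and blocking gradients so that the feature matching error only moves the decoder. This is a rough sketch of the idea, not the actual training code; the toy layer sizes and the single-layer "networks" are placeholders:

```python
import torch
import torch.nn as nn

# Toy stand-ins for the real networks; shapes and architectures are illustrative only.
latent_dim = 64
encoder = nn.Linear(784, 2 * latent_dim)       # outputs (mu, logvar) concatenated
decoder = nn.Linear(latent_dim, 784)
disc_features = nn.Linear(784, 128)            # an intermediate discriminator layer

# Separate optimizers so each loss only updates the intended sub-network.
opt_enc = torch.optim.Adam(encoder.parameters(), lr=1e-4)
opt_dec = torch.optim.Adam(decoder.parameters(), lr=1e-4)

x = torch.rand(32, 784)                        # a dummy batch of flattened images

# Encoder step: KL-divergence only (the reconstruction term is omitted here).
mu, logvar = encoder(x).chunk(2, dim=1)
kl = (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=1)).mean()  # same closed-form KL as above
opt_enc.zero_grad()
kl.backward()
opt_enc.step()

# Decoder step: feature matching error. Detaching z blocks gradients from
# flowing back into the encoder, so only the decoder is updated by this loss.
z = (mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)).detach()
x_fake = decoder(z)
fm = (disc_features(x) - disc_features(x_fake)).pow(2).mean()
opt_dec.zero_grad()
fm.backward()
opt_dec.step()
```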

I tried the following things, but in the end I could not successfully train the VAE-GAN while keeping the KL-divergence small.

  • Optimizing the reconstruction error in place of, or alongside, the feature matching error.
  • Using BCE instead of MSE for the reconstruction error.
  • Balancing discriminator/generator updates based on their losses.
  • Using LeakyReLU in the generator (decoder) and the discriminator, as described in GAN Hacks.
  • Using SNGAN (spectral normalization and hinge loss); see the sketch after this list. I need to revisit this.
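For completeness, this is roughly what the SNGAN attempt refers to: a spectrally normalized discriminator trained with the hinge loss. It is a generic sketch assuming 64x64 RGB inputs, not the code from the repository:

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

# Spectral normalization wraps each discriminator layer to bound its Lipschitz constant.
# The architecture below is illustrative only.
discriminator = nn.Sequential(
    spectral_norm(nn.Conv2d(3, 64, 4, stride=2, padding=1)),    # 64x64 -> 32x32
    nn.LeakyReLU(0.1),
    spectral_norm(nn.Conv2d(64, 128, 4, stride=2, padding=1)),  # 32x32 -> 16x16
    nn.LeakyReLU(0.1),
    nn.Flatten(),
    spectral_norm(nn.Linear(128 * 16 * 16, 1)),
)

def d_hinge_loss(d_real, d_fake):
    # Hinge loss for the discriminator, as used in SNGAN.
    return torch.relu(1.0 - d_real).mean() + torch.relu(1.0 + d_fake).mean()

def g_hinge_loss(d_fake):
    # Hinge (non-saturating) loss for the generator.
    return -d_fake.mean()
```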

At some point, I realized that for the GAN part of the model to work as a GAN, the input samples must follow a normal distribution, and that the KL-divergence, seen as a regularization term, is not powerful enough to enforce this. So I tried using Batch Normalization as the last layer of the encoder, which pushes the distribution of the encoded samples closer to a normal distribution. Looking back at the past few weeks, in which I tried so many different techniques to regularize the KLD, adopting Batch Normalization like this felt like cheating, so I personally decided to call this technique Batch Normalization Cheating.
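Roughly, the idea looks like this. The layer sizes are illustrative and assume 64x64 RGB inputs; they are not the actual architecture from the repository:

```python
import torch.nn as nn

latent_dim = 64

# Encoder whose final layer is a BatchNorm over the latent code, so that each
# latent dimension is pushed toward zero mean and unit variance within a batch.
encoder = nn.Sequential(
    nn.Conv2d(3, 32, 4, stride=2, padding=1),   # 64x64 -> 32x32
    nn.LeakyReLU(0.2),
    nn.Conv2d(32, 64, 4, stride=2, padding=1),  # 32x32 -> 16x16
    nn.LeakyReLU(0.2),
    nn.Flatten(),
    nn.Linear(64 * 16 * 16, latent_dim),
    nn.BatchNorm1d(latent_dim, affine=False),   # the "cheating" layer
)
```

Here `affine=False` keeps the output strictly normalized; whether to keep the learnable scale and shift is a design choice I have not verified against the actual code.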

With this technique, the pixel error dropped below 0.1, which is a kind of milestone.

KL-Divergence of each batch. It is, of course, bounded now.

The following are some fake images generated from random samples, taken before the generator collapses.

The code and model are available here.

References