A Variational Autoencoder (VAE) is a type of autoencoder: a neural network trained with an unsupervised technique. VAEs are widely used in image generation, mainly in latent diffusion-based and GAN-based image generation models.
Suppose the input is a 1024x1024 image. That is roughly 1 million pixels (about 3 million values once the three RGB channels are counted), and feeding all of them directly into a neural network is very computationally expensive. We therefore need another neural network that can understand the image and represent its pixels as a much smaller set of features. These features are called latents.
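To put these numbers in perspective, here is a quick back-of-the-envelope calculation. It assumes a latent of shape 4x128x128 (8x downsampling per side, 4 channels), which is the configuration of the VAE used in Stable Diffusion; other models use different latent shapes, so treat the ratio as one illustrative case.

```python
# Size of a 1024x1024 RGB image versus a compressed latent.
# Assumption: latent shape 4 x 128 x 128 (8x downsampling per side, 4 channels),
# as in the Stable Diffusion VAE.
height, width, channels = 1024, 1024, 3
pixels = height * width                            # pixel positions
image_values = pixels * channels                   # numbers once RGB is counted
latent_values = 4 * (height // 8) * (width // 8)   # numbers in the latent
ratio = image_values // latent_values              # compression factor
print(pixels, image_values, latent_values, ratio)  # → 1048576 3145728 65536 48
```

Under this assumption the network downstream works with 48 times fewer numbers than the raw image.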
That's where autoencoders come in: they are used for dimensionality reduction and data representation. An autoencoder learns through unsupervised learning, so it doesn't need labeled data. It has two main blocks, an encoder and a decoder. The encoder compresses the image into a latent, and the decoder uses that latent to reconstruct the original image. The loss is computed with MSE (mean squared error) or another distance metric that measures how close the decoded image is to the input. This is called the reconstruction loss, and the training goal is to minimize it.
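A minimal sketch of such an autoencoder in PyTorch follows. The layer sizes, the 784-dimensional input (a flattened 28x28 grayscale image, such as a handwritten digit) and the random stand-in batch are illustrative assumptions, not details from the text. Note that training uses only the images themselves: the target of the MSE reconstruction loss is the input, so no labels are involved.

```python
import torch
from torch import nn
import torch.nn.functional as F


class Autoencoder(nn.Module):
    """Minimal fully connected autoencoder (sizes are illustrative assumptions)."""

    def __init__(self, input_dim: int = 784, latent_dim: int = 16):
        super().__init__()
        # Encoder: compress the flattened image into a small latent vector
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128), nn.ReLU(), nn.Linear(128, latent_dim)
        )
        # Decoder: reconstruct the image (pixel values in [0, 1]) from the latent
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128), nn.ReLU(),
            nn.Linear(128, input_dim), nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(self.encoder(x))


torch.manual_seed(0)
model = Autoencoder()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
images = torch.rand(32, 784)  # stand-in batch of flattened images, no labels

losses = []
for step in range(100):
    reconstruction = model(images)
    loss = F.mse_loss(reconstruction, images)  # reconstruction loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```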
However, this kind of autoencoder doesn't capture how images vary. Suppose you want to write the number "7": it can be written in many different ways. Because each input is mapped to a single fixed point, the latent space has no structure that captures these differences, which is especially limiting when we want to use it to generate new images. This is the problem the Variational Autoencoder (VAE) solves.
Unlike a normal autoencoder, which encodes each input as a single point in the latent space, a VAE encodes each input as a probability distribution. This helps the model form many pools within the latent space: similar latents (images) sit close together in the same pool, while dissimilar ones are far apart or in another pool.
This figure shows how the latents are distributed across different pools, where each pool holds the variations of images that belong to the same class. Likewise, similar pools lie close to each other, while very distinct ones are far apart.
As mentioned above, VAEs encode images as probability distributions and learn these distributions during training, so they can capture uncertainty and variability in the data.
How does the model learn probability distribution?
Image → Probability Distribution (μ, σ)
μ: mean
σ: standard deviation
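μ and σ are not fixed values: they are produced by the encoder, whose weights are learned through gradient descent, and this is how the model learns the probability distribution. Once the model is fully trained, it can differentiate between variations of images. More technically, when we generate embeddings of images, we split each embedding into two halves, μ and σ, which represent the mean and standard deviation. (In practice the encoder usually predicts the log-variance log σ² instead of σ, as the KL divergence code later in this section does.) To create the final latent, we sample random noise ε from a standard normal distribution N(0, I), scale σ by it, and add the result to μ. This is known as the reparameterization trick: the randomness lives in ε, so gradients can still flow through μ and σ during training.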
z = μ + ε * σ, where ε ~ N(0, I)
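The sampling step can be sketched in PyTorch as follows. `VAEEncoderHead` and `reparameterize` are hypothetical names chosen for this example, and the head predicts the log-variance log σ² rather than σ directly (a common convention that keeps σ positive), recovering σ as exp(0.5 * log σ²). Calling `reparameterize` twice on the same μ and σ gives two different latents for the same image, which is how a VAE represents variation.

```python
import torch
from torch import nn


class VAEEncoderHead(nn.Module):
    """Hypothetical encoder head mapping features to mean and log-variance."""

    def __init__(self, feature_dim: int = 128, latent_dim: int = 16):
        super().__init__()
        # One layer outputs 2 * latent_dim numbers: first half mean, second half log-variance
        self.to_stats = nn.Linear(feature_dim, 2 * latent_dim)

    def forward(self, features: torch.Tensor):
        mean, log_variance = torch.chunk(self.to_stats(features), 2, dim=1)
        return mean, log_variance


def reparameterize(mean: torch.Tensor, log_variance: torch.Tensor) -> torch.Tensor:
    std = torch.exp(0.5 * log_variance)  # sigma = exp(log(sigma^2) / 2)
    eps = torch.randn_like(std)          # fresh noise from N(0, I) on every call
    return mean + eps * std              # z = mu + eps * sigma


torch.manual_seed(0)
head = VAEEncoderHead()
features = torch.randn(4, 128)           # stand-in encoder features for 4 images
mean, log_variance = head(features)
z1 = reparameterize(mean, log_variance)
z2 = reparameterize(mean, log_variance)  # same images, different samples
```

Because ε is drawn outside the learned parameters, calling `backward()` on a loss computed from `z1` still produces gradients for the head's weights.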
We then use this latent to reconstruct the original image. The loss function has two parts. First, we calculate the reconstruction loss, which measures the difference between the original image and the reconstructed image. Then we add the KL divergence multiplied by a constant weight β (beta): loss = reconstruction loss + β * KL divergence. KL divergence is a metric that measures how one probability distribution differs from another. Here it measures how far the encoder's distribution N(μ, σ²) is from the standard normal N(0, I); since we are dealing with distributions, this term pulls each encoded distribution toward N(0, I), which keeps the latent space smooth and easy to sample from. This is how we can compute the KL divergence:
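```python
import torch
mean, log_variance = torch.chunk(encoded, 2, dim=1)  # `encoded`: encoder output holding μ and log σ² along dim 1
kl_div = -0.5 * torch.sum(1 + log_variance - mean.pow(2) - log_variance.exp())  # KL(N(μ, σ²) || N(0, I))
```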
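As a sanity check, the closed-form expression above can be compared against PyTorch's built-in `torch.distributions.kl_divergence` for two `Normal` distributions, then combined with a reconstruction term into the β-weighted loss. `kl_term` and `vae_loss` are hypothetical helper names, and summing (rather than averaging) the MSE is an assumption made here so that both terms are on the same scale as the summed KL.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence


def kl_term(mean: torch.Tensor, log_variance: torch.Tensor) -> torch.Tensor:
    # Closed-form KL(N(mean, sigma^2) || N(0, 1)), summed over all elements
    return -0.5 * torch.sum(1 + log_variance - mean.pow(2) - log_variance.exp())


def vae_loss(reconstruction, target, mean, log_variance, beta: float = 1.0):
    # Hypothetical helper: total loss = reconstruction loss + beta * KL divergence
    reconstruction_loss = F.mse_loss(reconstruction, target, reduction="sum")
    return reconstruction_loss + beta * kl_term(mean, log_variance)


torch.manual_seed(0)
mean = torch.randn(4, 16)
log_variance = torch.randn(4, 16)

# Cross-check the closed form against PyTorch's Normal-Normal KL
posterior = Normal(mean, torch.exp(0.5 * log_variance))
prior = Normal(torch.zeros_like(mean), torch.ones_like(mean))
reference = kl_divergence(posterior, prior).sum()
print(torch.allclose(kl_term(mean, log_variance), reference, atol=1e-4))
```

A β above 1 weights the KL term more heavily, pushing the encoded distributions closer to N(0, I) at the cost of reconstruction quality; β = 1 gives the standard VAE objective.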
Here's what the loss estimation looks like: