Generative Adversarial Networks
Classifiers
The most common kind of ML model, and the one most people are probably familiar with today, is the 'classifier'. Say you train a model on hundreds of thousands of labelled pictures, some of horses and some not, and it gets good at taking an unseen picture of anything and saying whether that picture contains a horse or not.
So, somewhere within that training process, the network has built an internal model of what a horse looks like, which it subsequently uses to classify other pictures. But you can only really use this model to classify things. What it cannot do is give a decent answer to 'Draw me a picture of a horse that I haven't seen before'. In other words, given a particular distribution, a classifier isn't good at generating new samples that would belong to that distribution.
But strangely, that is the exact premise of websites like thispersondoesnotexist.com, which uses an ML model to generate 'synthetic' faces of people that, well, probably do not exist in real life.
Another example of a GAN: one that turns line art into realistic footwear
So how does it even do that?
Adversarial Training
Adversarial training is a way of training ML systems which involves focusing on the system's weaknesses, and this forms the basis for generative networks.
Let us take our aforementioned horse classifier. Since its job is to recognize horses vs not-horses, it would be really nice if it could identify every breed of horse. However, not all breeds of horses look that similar to each other. Take for instance the Falabella Pony, with its characteristic stunted frame and non-prominent mane.
Now, it would be natural for even a human, let alone a neural network, to not be able to correctly identify whether this is indeed a horse, given that it looks similar to other four-legged creatures such as donkeys and albino zebras.
So the first step of training our classifier would naturally be to get the first 10000 or so results from Google Images for the keyword 'Horse', and feed them to our network as sample data. However, in the next step, if we are to be responsible trainers, we ought to also train the model on such 'obscure edge cases', where it is currently bound to be weak.
So we keep force-feeding the model with these obscure breeds for hours upon hours, hoping to iron out its weaknesses. However, doing the same while teaching a human student, is, well, purely evil because it is bound to demotivate them and in turn, hamper the student's overall learning ability.
Therefore, this kind of inherently 'inhumane' training that focuses on 'repeated crunching where the model is currently failing' is called 'Adversarial training' (or in human terms, JEE coaching ;-;). But luckily for us, ✨ machines don't have feelings ✨ ... yet. And adversarial training works wonders on refining classifiers and making them robust.
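To make the 'crunch on whatever the model currently gets wrong' idea concrete, here is a minimal toy sketch. It is only an illustration under loud assumptions: the two features (mane length, height), the numbers in `TOY_DATA`, and the helper names `predict` and `train_on_failures` are all made up here, and the model is a plain perceptron-style linear classifier rather than a real image network.

```python
# Toy, centred features per picture: (mane_length, height).
# Label 1 = horse, 0 = not a horse. All numbers are invented for illustration.
TOY_DATA = [
    ([2.0, 1.0], 1),
    ([1.5, 2.0], 1),
    ([-1.0, -1.5], 0),
    ([-2.0, -0.5], 0),
]

def predict(w, b, x):
    """Linear classifier: 1 if the weighted sum is positive, else 0."""
    score = sum(wi * xi for wi, xi in zip(w, x)) + b
    return 1 if score > 0 else 0

def train_on_failures(data, max_rounds=20, lr=0.1):
    """Repeatedly retrain only on the examples the model currently gets wrong."""
    w = [0.0] * len(data[0][0])
    b = 0.0
    for _ in range(max_rounds):
        failures = [(x, y) for x, y in data if predict(w, b, x) != y]
        if not failures:
            break  # nothing left that the model fails on
        for x, y in failures:
            # Re-check the example: an earlier update may already have fixed it.
            error = y - predict(w, b, x)
            w = [wi + lr * error * xi for wi, xi in zip(w, x)]
            b += lr * error
    return w, b

w, b = train_on_failures(TOY_DATA)
print(all(predict(w, b, x) == y for x, y in TOY_DATA))
```

The point of the sketch is the `failures` list: every round, the model is drilled only on its current weak spots, which is the spirit of the training described above.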
Catch Me If You CGan
In simple terms, Generative Adversarial Networks (GANs) work using two components, a Generator and a Discriminator: two networks that work against each other but also, in a sense, 'together' through training.
The Generator creates new data based on either random noise, or previous training data that it has been supplied.
The Discriminator then takes this newly formed data and tries to identify if this is real data, i.e. actual data that was collected or fake data, i.e. data that was generated by the generator.
Essentially, the Generator is trying to fool the Discriminator, while the Discriminator is trying not to be fooled by data from the Generator. Hence the ‘Adversarial’ in GANs, where we try to fool models by giving them deceptive data as inputs. This cycle of generating and testing data continues on and on, and both networks keep learning from their mistakes: the Generator gets better at producing data that resembles the training data, and the Discriminator gets better at telling real data from generated data.
Coming back to our example, the generator would take say, random line art, and try to produce a picture of a horse, while the discriminator is our aforementioned horse-or-not-horse classifier.
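The back-and-forth between the two networks can be sketched as an alternating training loop. What follows is a deliberately tiny, hedged illustration, not how a real image GAN is built: the 'data' is a single number drawn from a bell curve centred at 4, the generator is the line `a * z + b`, the discriminator is a one-input logistic unit, and the gradients are written out by hand for the usual log-probability objective. Every name and constant here is an assumption made for the example.

```python
import math
import random

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def discriminate(d, x):
    """D(x): the discriminator's estimated probability that x is real."""
    w, c = d
    return sigmoid(w * x + c)

def generate(g, z):
    """G(z): turn random noise z into a fake sample."""
    a, b = g
    return a * z + b

def discriminator_step(d, real, fake, lr):
    """One gradient-ascent step on mean log D(real) + mean log(1 - D(fake))."""
    w, c = d
    gw = gc = 0.0
    for x in real:
        p = discriminate(d, x)
        gw += (1.0 - p) * x / len(real)
        gc += (1.0 - p) / len(real)
    for x in fake:
        p = discriminate(d, x)
        gw -= p * x / len(fake)
        gc -= p / len(fake)
    return (w + lr * gw, c + lr * gc)

def generator_step(g, d, noise, lr):
    """One gradient-descent step on mean log(1 - D(G(z)))."""
    a, b = g
    w, _ = d
    ga = gb = 0.0
    for z in noise:
        p = discriminate(d, generate(g, z))
        dx = -p * w  # derivative of log(1 - D(x)) with respect to x at x = G(z)
        ga += dx * z / len(noise)
        gb += dx / len(noise)
    return (a - lr * ga, b - lr * gb)

def train(steps=200, batch=32, lr=0.05, seed=0):
    rng = random.Random(seed)
    g, d = (1.0, 0.0), (0.0, 0.0)
    for _ in range(steps):
        real = [rng.gauss(4.0, 1.0) for _ in range(batch)]   # "genuine" data
        noise = [rng.gauss(0.0, 1.0) for _ in range(batch)]
        fake = [generate(g, z) for z in noise]
        d = discriminator_step(d, real, fake, lr)  # discriminator learns to spot fakes
        g = generator_step(g, d, noise, lr)        # generator learns to fool it
    return g, d
```

Calling `train()` runs the alternation; each discriminator update is immediately followed by a generator update against the freshly improved discriminator, which is the adversarial loop described above.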
The Slightly Nerdy Stuff
The mathematical way to look at the dynamics of GANs would be to say that it is two networks fighting over one number. One of them wants the number to be high, the other wants it to be low. And that number is none other than the error rate of the discriminator. The discriminator's job is to look at an image that could have come from the genuine dataset or from the generator, and reliably tell which of the two it is. The generator, on the other hand, wants to produce convincing horse images, so whenever it produces an image that the discriminator cannot tell is fake, it is 'rewarded' accordingly.
Hence, we can say that the Generator and Discriminator play a zero-sum minimax game, wherein one participant’s gain is exactly balanced by the other participant’s loss. A crucial quantity here is the loss function, a measure of the distance between the current output of the algorithm and the expected output. Thus, a basic loss function for a GAN, the standard minimax objective, could be
$$\min_G \max_D V(D, G) = \mathbb{E}_{x}[\log D(x)] + \mathbb{E}_{z}[\log(1 - D(G(z)))]$$
Here, D(x) and G(z) are the discriminator’s and generator’s outputs. The discriminator gives the probability that its input is a real horse picture, while the generator produces a picture of a horse given random noise z. Essentially, the Generator tries to minimize this function while the Discriminator tries to maximize it.
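As a quick sanity check on the standard minimax value, the average of log D(x) over real samples plus the average of log(1 - D(G(z))) over generated ones, it can be computed directly from a batch of discriminator outputs. The function name `gan_value` and the probabilities below are made up for illustration.

```python
import math

def gan_value(d_real, d_fake):
    """Estimate V(D, G) from discriminator outputs.

    d_real: D(x) for a batch of real samples
    d_fake: D(G(z)) for a batch of generated samples
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term

# Discriminator is winning: real images score high, fakes score low.
confident = gan_value([0.9, 0.95], [0.1, 0.05])

# Discriminator is completely fooled and can only guess 50/50.
fooled = gan_value([0.5, 0.5], [0.5, 0.5])

print(confident, fooled)
```

The value is close to 0 when the discriminator is confidently right, and drops to -log 4 (about -1.386) when it can only guess. That is the tug of war in one number: the discriminator pushes the value up, the generator pushes it down.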
CycleGANs
Unfortunately, the process of properly training a GAN is rather tedious, because plenty of problems can arise owing to the delicate balance between the generator and the discriminator.
One of these is what is called 'Mode Collapse'. Say we have our HorseGAN again. Usually we want our GAN to produce a wide variety of outputs, i.e. a distinct horse image per distinct random input.
However, if the generator produces an especially plausible output, it may just learn to produce only that output. In fact, the generator is always trying to find the one output that seems most plausible to the discriminator. If the generator starts producing the same output (or a small set of outputs) over and over again, the discriminator's best strategy is to learn to always reject that output. But if the discriminator fails to do that, then it's too easy for the next generator iteration to find the most plausible output for the discriminator.
Each iteration of the generator over-optimizes for a particular discriminator, and the discriminator never manages to learn its way out of the trap. As a result, the generators rotate through a small set of output types. This form of GAN failure is called mode collapse.
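One rough way to spot mode collapse in a toy setting is to count how many of the known 'kinds' of real data the generator's samples actually land near. The sketch below assumes one-dimensional samples and a hand-picked list of mode centres; `mode_coverage`, the radius, and all numbers are hypothetical and only meant to illustrate the idea.

```python
def mode_coverage(samples, centres, radius=0.5):
    """Fraction of known modes that have at least one generated sample nearby."""
    hit = set()
    for s in samples:
        # Find the mode centre closest to this sample.
        nearest = min(range(len(centres)), key=lambda i: abs(s - centres[i]))
        if abs(s - centres[nearest]) <= radius:
            hit.add(nearest)
    return len(hit) / len(centres)

centres = [0.0, 5.0, 10.0]  # pretend the real data has three distinct "breeds"

healthy = mode_coverage([0.1, 4.8, 10.2, 5.1], centres)     # all three covered
collapsed = mode_coverage([5.0, 5.1, 4.9, 5.05], centres)   # stuck on one mode

print(healthy, collapsed)
```

A healthy generator scores 1.0 here, while the collapsed one covers only one of the three modes. Real images have no labelled mode centres, so in practice this kind of check needs some proxy for 'kinds of output'.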
However, a clever way to deal with this, and to also make our GAN more robust overall, is to add yet another converter at the end of the pipeline that maps the generated image back to the input it came from. A clearer example of this is the so-called 'style-transfer' GANs, which take in an image and output some stylized image, such as taking in a picture of a zebra and 'programmatically erasing' its stripes to leave behind a horse. Checking whether or not the generated horse can be converted back into the original zebra (the 'cycle' in CycleGAN) is a great way to work against mode collapse, because a generator that always outputs the same horse throws away the information needed to rebuild each zebra.
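The 'convert it back and compare' check can be written as a cycle-consistency loss: the average absolute difference between the original input and its round-trip reconstruction. In this hedged sketch an 'image' is just a short list of pixel values, and the converters are stand-in lambdas rather than trained networks; all names are assumptions for the example.

```python
def cycle_consistency_loss(x, forward, backward):
    """Mean absolute difference between x and backward(forward(x))."""
    reconstructed = backward(forward(x))
    return sum(abs(r - v) for r, v in zip(reconstructed, x)) / len(x)

zebra = [0.2, 0.8, 0.4]  # made-up pixel values

# A well-behaved pair: the "zebra -> horse" map can be undone exactly.
to_horse = lambda img: [v * 0.5 for v in img]
to_zebra = lambda img: [v * 2.0 for v in img]

# A collapsed generator: every zebra becomes the same flat "horse".
collapsed_to_horse = lambda img: [0.5] * len(img)

print(cycle_consistency_loss(zebra, to_horse, to_zebra))            # round trip is exact
print(cycle_consistency_loss(zebra, collapsed_to_horse, to_zebra))  # round trip fails
```

The invertible pair gives a loss of exactly 0, while the collapsed generator gives a positive loss, since nothing about the original zebra survives in its output. Adding this loss to the training objective is what penalises the 'one output for everything' shortcut.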
Here are a couple of fun examples of style-transfer GANs
AnimeGANv2
Anti-toonification by Pixel2Style2Pixel (pSp), a StyleGAN encoder
Sources
Common Problems in GANs, Google Developers
Minimax Principle, Encyclopedia of Mathematics, European Mathematical Society
A Gentle Introduction to Generative Adversarial Network Loss Functions, Jason Brownlee
AnimeGANv2, akhaliq, HuggingFace
Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation, Richardson et al.