Generative Adversarial Network
Implicit Generative Model
TipImplicit Generative Model Basics
- Goal: a sampler \(g(\cdot)\) to generate images
- A simple generator \(g(z; \theta)\)
- \(z \sim \mathcal{N}(0, I)\)
- \(x = g(z; \theta)\)
- Likelihood-free learning
- Goal: the distribution of samples \(g(z; \theta)\), denoted \(p_G\), should approximate \(p_{\text{data}}\)
- Idea: minimize \(D(g(z; \theta), p_{\text{data}})\)
- \(D\) is some distance metric
- \(D\) does not involve likelihood
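The sampler above fits in a few lines of plain Python. This is a minimal sketch that assumes a hypothetical one-hidden-layer generator `g` with hand-picked weights `theta`. Sampling only needs \(z \sim \mathcal{N}(0, I)\) and a forward pass. No density \(p_G(x)\) is ever evaluated, which is exactly why training has to be likelihood-free.

```python
import math
import random

def g(z, theta):
    """Hypothetical generator g(z; theta): a tiny 1-hidden-layer MLP.
    z is a list of floats; theta = (W1, b1, W2, b2) given as plain nested lists."""
    W1, b1, W2, b2 = theta
    h = [math.tanh(sum(w * zi for w, zi in zip(row, z)) + b) for row, b in zip(W1, b1)]
    return [sum(w * hi for w, hi in zip(row, h)) + b for row, b in zip(W2, b2)]

def sample(theta, n, latent_dim, rng):
    """Draw n samples x = g(z; theta) with z ~ N(0, I); no likelihood is computed."""
    out = []
    for _ in range(n):
        z = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
        out.append(g(z, theta))
    return out

# Illustrative weights: 2-d latent -> 3 hidden units -> 2-d "image"
theta = (
    [[0.5, -0.3], [0.8, 0.1], [-0.2, 0.7]],  # W1: 3 x 2
    [0.0, 0.1, -0.1],                        # b1
    [[1.0, -1.0, 0.5], [0.3, 0.2, -0.4]],    # W2: 2 x 3
    [0.0, 0.0],                              # b2
)
xs = sample(theta, n=5, latent_dim=2, rng=random.Random(0))
```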
TipAdversarial Distance Metric & Binary Classification
- Choose a distance measure \(D\)
- Goal: choose a differentiable \(D(x; \phi)\) so that gradients can flow from \(D\) back to \(\theta\) through \(g(z; \theta)\)
- Objective: \(L(\theta) = D(g(z; \theta); \phi)\)
- We can optimize \(\theta\) by gradient descent
- How to get \(\phi\)?
- Goal: \(D(x; \phi)\) measures how likely \(x\) is from \(p_{\text{data}}\)
- A binary classification problem!
- \(D(x; \phi) = 1\): \(x \sim p_{\text{data}}\)
- \(D(x; \phi) = 0\): \(x\) not from \(p_{\text{data}}\)
- Let’s train a neural classifier!
- How to choose the negative samples?
- Random samples are too easy…
- We have a generator! Use its own samples as the negatives
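Here is a minimal sketch of the classification view. It assumes the discriminator outputs are already probabilities \(D(x; \phi) \in (0, 1)\). Take the binary cross-entropy with label 1 for real samples and label 0 for generated negatives. That loss is exactly the negative of \(\mathbb{E}[\log D(x)] + \mathbb{E}[\log(1 - D(\hat{x}))]\), so training the classifier and maximizing that quantity are the same thing. The helper names are hypothetical.

```python
import math

def bce(p, label):
    """Binary cross-entropy for one prediction p = D(x) in (0, 1) and a label in {0, 1}."""
    return -(label * math.log(p) + (1 - label) * math.log(1.0 - p))

def discriminator_bce(d_real, d_fake):
    """Average BCE over real samples (label 1) plus generated negatives (label 0).
    d_real / d_fake: lists of D(x; phi) outputs on real and generated samples."""
    loss_real = sum(bce(p, 1) for p in d_real) / len(d_real)
    loss_fake = sum(bce(p, 0) for p in d_fake) / len(d_fake)
    return loss_real + loss_fake

def gan_d_objective(d_real, d_fake):
    """E[log D(x)] + E[log(1 - D(x_hat))], the quantity the discriminator maximizes."""
    return (sum(math.log(p) for p in d_real) / len(d_real)
            + sum(math.log(1.0 - p) for p in d_fake) / len(d_fake))
```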
GAN: The Minimax Game
TipGenerative Adversarial Networks
- Generator \(G(z; \theta)\) (\(z \sim p(z) = \mathcal{N}(0, I)\))
- generate realistic images
- Discriminator \(D(x; \phi)\)
- Classifies whether \(x\) comes from \(p_{\text{data}}\) or from \(G\)
- Objective \[\min_{\theta} \max_{\phi} L(\theta, \phi), \quad L(\theta, \phi) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x; \phi)] + \mathbb{E}_{\hat{x} \sim G}[\log(1 - D(\hat{x}; \phi))]\]
- Training procedure
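- Collect dataset \(\{(x, 1) \mid x \sim p_{\text{data}}\} \cup \{(\hat{x}, 0) \mid \hat{x} \sim G(z; \theta)\}\)
- Train discriminator \(D\): maximize \(L(\phi) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x; \phi)] + \mathbb{E}_{\hat{x} \sim G}[\log(1 - D(\hat{x}; \phi))]\)
- Train generator \(G\): maximize \(L(\theta) = \mathbb{E}_{z \sim p(z)}[\log D(G(z; \theta); \phi)]\). This is the non-saturating alternative to minimizing \(\mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z; \theta); \phi))]\), which gives weak gradients early in training when \(D\) easily rejects generated samples.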
- Repeat
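The training procedure above can be sketched on a toy 1-d problem. The setup is hypothetical and chosen only so that gradients can be written by hand: a linear generator \(G(z) = a z + b\) and a logistic discriminator \(D(x) = \sigma(w x + c)\). One call performs one discriminator ascent step. It then performs one generator ascent step on the non-saturating objective, using the updated discriminator.

```python
import math

def sigmoid(u):
    return 1.0 / (1.0 + math.exp(-u))

def gan_step(theta, phi, real_batch, z_batch, lr_d=0.05, lr_g=0.05):
    """One alternating GAN update on a toy 1-d problem (hypothetical setup).
    Generator:     G(z) = a * z + b,          theta = (a, b)
    Discriminator: D(x) = sigmoid(w * x + c), phi = (w, c)
    """
    a, b = theta
    w, c = phi
    fake = [a * z + b for z in z_batch]

    # 1) Discriminator: gradient ASCENT on E[log D(x)] + E[log(1 - D(x_hat))]
    gw = gc = 0.0
    for x in real_batch:             # d/du log sigmoid(u) = 1 - sigmoid(u)
        s = 1.0 - sigmoid(w * x + c)
        gw += s * x / len(real_batch)
        gc += s / len(real_batch)
    for xh in fake:                  # d/du log(1 - sigmoid(u)) = -sigmoid(u)
        s = -sigmoid(w * xh + c)
        gw += s * xh / len(fake)
        gc += s / len(fake)
    w, c = w + lr_d * gw, c + lr_d * gc

    # 2) Generator: gradient ASCENT on the non-saturating E[log D(G(z))]
    ga = gb = 0.0
    for z in z_batch:
        s = (1.0 - sigmoid(w * (a * z + b) + c)) * w   # d log D(x) / dx at x = G(z)
        ga += s * z / len(z_batch)
        gb += s / len(z_batch)
    a, b = a + lr_g * ga, b + lr_g * gb
    return (a, b), (w, c)

# Usage: real data near 3.0, generator initially centered at 0.0
theta, phi = gan_step((1.0, 0.0), (0.0, 0.0), [3.0, 3.5, 2.5], [-1.0, 0.0, 1.0])
```

In the usage line, the real batch lies to the right of the generated batch. The updated \(w\) is therefore positive, and the generator offset \(b\) moves toward the data. Real GANs use deep networks, autodiff, and optimizers such as Adam. The manual gradients here only keep the sketch dependency-free.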
Optimal Discriminator
NoteFinding the Optimal Discriminator \(D^*\)
- GAN Objective \[\min_{\theta} \max_{\phi} L(\theta, \phi), \quad L(\theta, \phi) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x; \phi)] + \mathbb{E}_{\hat{x} \sim G}[\log(1 - D(\hat{x}; \phi))]\]
- Let’s analyze the optimal solution \(D^*\) and \(g^*\) under \(L(\theta, \phi)\)
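- Optimal \(D(x; \phi^*)\) for each fixed \(x\): the objective equals \(\int \left[p_{\text{data}}(x) \log D(x) + p_G(x) \log(1 - D(x))\right] dx\), so it can be maximized pointwise \[L(D) = p_{\text{data}}(x) \log D + p_G(x) \log(1 - D)\]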
- Consider \(\frac{\partial L}{\partial D} = 0\)
- We have \(\frac{p_{\text{data}}(x)}{D^*} - \frac{p_G(x)}{1 - D^*} = 0\)
- So \((1 - D^*)p_{\text{data}}(x) = p_G(x)D^*\)
- \[D^* = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_G(x)}\]
- Remark: with a perfect generator (\(p_G = p_{\text{data}}\)), \(D^*(x) = 1/2\) everywhere
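Here is a quick numerical check of the derivation. It assumes \(p_{\text{data}}(x)\) and \(p_G(x)\) at some fixed \(x\) are known numbers; the values 0.3 and 0.1 are chosen for illustration. A brute-force grid search over \(D \in (0, 1)\) should land on the closed form \(D^*\).

```python
import math

def pointwise_objective(d, p_data_x, p_g_x):
    """L(D) = p_data(x) log D + p_G(x) log(1 - D) at a single x."""
    return p_data_x * math.log(d) + p_g_x * math.log(1.0 - d)

def optimal_d(p_data_x, p_g_x):
    """Closed form D*(x) = p_data(x) / (p_data(x) + p_G(x))."""
    return p_data_x / (p_data_x + p_g_x)

# Compare the closed form against a grid search over D in (0, 1)
p_data_x, p_g_x = 0.3, 0.1
grid = [i / 10000 for i in range(1, 10000)]
d_grid = max(grid, key=lambda d: pointwise_objective(d, p_data_x, p_g_x))
d_star = optimal_d(p_data_x, p_g_x)
```

Because \(L(D)\) is concave in \(D\), the grid maximum sits next to \(D^* = 0.3 / 0.4 = 0.75\).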
NoteGlobal Minimum and JS Divergence Reduction
- GAN Objective \[\min_{\theta} \max_{\phi} L(\theta, \phi), \quad L(\theta, \phi) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x; \phi)] + \mathbb{E}_{\hat{x} \sim G}[\log(1 - D(\hat{x}; \phi))]\]
- Optimal discriminator \(D^*(x) = p_{\text{data}}(x) / (p_{\text{data}}(x) + p_G(x))\)
- Optimal generator \(g(z; \theta)\) given the optimal discriminator \(\phi^*\)
- \[L(\theta, \phi^*) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log \frac{p_{\text{data}}(x)}{p_{\text{data}}(x)+p_G(x)}\right] + \mathbb{E}_{x \sim p_G}\left[\log \frac{p_G(x)}{p_{\text{data}}(x)+p_G(x)}\right]\]
- \[= \mathbb{E}_{x \sim p_{\text{data}}}\left[\log \frac{p_{\text{data}}(x)}{\frac{p_{\text{data}}(x)+p_G(x)}{2}}\right] + \mathbb{E}_{x \sim p_G}\left[\log \frac{p_G(x)}{\frac{p_{\text{data}}(x)+p_G(x)}{2}}\right] - \log 4\]
- \[= KL\left(p_{\text{data}} \;\Big\|\; \frac{1}{2}(p_{\text{data}} + p_G)\right) + KL\left(p_G \;\Big\|\; \frac{1}{2}(p_{\text{data}} + p_G)\right) - \log 4\]
- \(= 2 \cdot JSD(p_{\text{data}} \mid\mid p_G) - \log 4\)
- Since \(JSD \ge 0\), with equality iff the two distributions match, the global minimum \(-\log 4\) is reached iff \(p_G = p_{\text{data}}\)
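The reduction can be checked numerically on small discrete distributions; the two distributions below are arbitrary illustrative choices. Plugging \(D^*\) into the objective gives exactly \(2 \cdot JSD - \log 4\). The value drops to \(-\log 4\) when \(p_G = p_{\text{data}}\).

```python
import math

def kl(p, q):
    """KL(p || q) for discrete distributions given as lists; 0 log 0 = 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def jsd(p, q):
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * (kl(p, m) + kl(q, m))

def value_at_optimal_d(p_data, p_g):
    """L(theta, phi*) = E_data[log D*] + E_G[log(1 - D*)], D* = p_data / (p_data + p_G)."""
    total = 0.0
    for pd, pg in zip(p_data, p_g):
        if pd > 0:
            total += pd * math.log(pd / (pd + pg))
        if pg > 0:
            total += pg * math.log(pg / (pd + pg))   # 1 - D* = p_G / (p_data + p_G)
    return total

p_data = [0.5, 0.3, 0.2, 0.0]
p_g = [0.1, 0.2, 0.3, 0.4]
lhs = value_at_optimal_d(p_data, p_g)
rhs = 2 * jsd(p_data, p_g) - math.log(4)
```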
GAN Mechanics: Forward vs. Reverse KL and JSD Behavior
TipKullback–Leibler Divergence
- \(KL(p \mid\mid q) = \mathbb{E}_{x \sim p}[\log (p(x) / q(x))]\)
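- Asymmetric: in general \(KL(p \mid\mid q) \ne KL(q \mid\mid p)\)
- \(KL(p \mid\mid q)\): forward KL (inclusive, mass-covering: \(q\) is penalized for missing regions where \(p > 0\))
- \(KL(q \mid\mid p)\): reverse KL (exclusive, mode-seeking: \(q\) is penalized for placing mass where \(p \approx 0\))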
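Here is a small discrete example of the asymmetry; the distributions are illustrative assumptions. \(q\) covers both modes of \(p\) but also puts mass on an outcome that \(p\) rules out. As a result, the forward KL is finite while the reverse KL is infinite.

```python
import math

def kl(p, q):
    """KL(p || q) = sum_x p(x) log(p(x) / q(x)) for discrete distributions."""
    total = 0.0
    for pi, qi in zip(p, q):
        if pi == 0:
            continue             # 0 * log(0 / q) = 0 by convention
        if qi == 0:
            return math.inf      # p puts mass where q has none
        total += pi * math.log(pi / qi)
    return total

p = [0.5, 0.5, 0.0]     # "data": two modes, third outcome impossible
q = [0.45, 0.45, 0.1]   # "model": covers both modes plus extra mass
forward = kl(p, q)      # finite: q > 0 wherever p > 0
reverse = kl(q, p)      # infinite: q puts mass where p(x) = 0
```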
TipJensen–Shannon Divergence (JSD)
- \(JSD(p \mid\mid q) = \frac{1}{2}\left(KL\left(p \mid\mid \frac{1}{2}(p + q)\right) + KL\left(q \mid\mid \frac{1}{2}(p + q)\right)\right)\)
- Properties
- Symmetric: \(JSD(p \mid\mid q) = JSD(q \mid\mid p)\)
- \(JSD(p \mid\mid q) \ge 0\) and \(JSD(p \mid\mid q) = 0 \iff p = q\)
- Jensen-Shannon distance: \(\sqrt{JSD(p \mid\mid q)}\) satisfies the triangle inequality (it is a metric)
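The listed properties can be checked numerically on a few arbitrary discrete distributions. This is a sanity check, not a proof. JSD stays finite, and at most \(\log 2\), even when the supports differ, because the mixture \(\frac{1}{2}(p + q)\) is positive wherever \(p\) or \(q\) is.

```python
import math

def kl(p, q):
    """KL(p || q) for discrete distributions; 0 log 0 = 0.
    Only called with q = mixture, so q > 0 wherever p > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def jsd(p, q):
    """JSD(p || q) = 1/2 KL(p || m) + 1/2 KL(q || m), with m = (p + q) / 2."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * (kl(p, m) + kl(q, m))

p = [0.5, 0.5, 0.0]
q = [0.0, 0.1, 0.9]
r = [0.2, 0.3, 0.5]
```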
Evaluation of GANs
TipEvaluation of GAN
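- Inception Score \[IS = \exp\left(\mathbb{E}_{x \sim G}\left[KL(f(y \mid x) \mid\mid p_f(y))\right]\right)\] where \(f(y \mid x)\) is the class distribution predicted by a pretrained Inception classifier and \(p_f(y) = \mathbb{E}_{x \sim G}[f(y \mid x)]\)
- IS only evaluates generated samples (confident, diverse class predictions); it never compares them against \(p_{\text{data}}\)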
- We also want \(G\) to fully cover \(p_{\text{data}}\)
- Idea: statistics of \(G(z)\) should be similar to \(p_{\text{data}}\)
- Statistics: the distribution of extracted features, computed for both \(p_{\text{data}}\) and \(G(z)\)
- Fréchet Inception Distance
- Compute \(\mu_p, \Sigma_p\) and \(\mu_G, \Sigma_G\) for \(p_{\text{data}}\) and \(G(z)\) from Inception-v3 pool3 features (2048-d)
- Compute the squared 2-Wasserstein (Fréchet) distance between the two fitted Gaussians (more on this later)
- \[FID = \|\mu_p - \mu_G\|^2 + \text{trace}\left(\Sigma_p + \Sigma_G - 2(\Sigma_p\Sigma_G)^{1/2}\right)\]
- Lower is better
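Both metrics can be sketched in plain Python. The Inception Score takes per-sample class probabilities as input; in practice these come from a pretrained Inception-v3 classifier, which is not reproduced here. For FID, the sketch assumes both covariances are diagonal. In that case the trace term reduces to \(\sum_i (\sigma_{p,i} - \sigma_{G,i})^2\) over per-dimension standard deviations. The real metric uses full \(2048 \times 2048\) covariances and a matrix square root.

```python
import math

def inception_score(probs):
    """IS = exp(E_x[KL(f(y|x) || p_f(y))]).
    probs: list of class-probability vectors f(y|x), one per generated sample."""
    n, k = len(probs), len(probs[0])
    marginal = [sum(p[j] for p in probs) / n for j in range(k)]   # p_f(y)
    mean_kl = sum(
        sum(pj * math.log(pj / mj) for pj, mj in zip(p, marginal) if pj > 0)
        for p in probs
    ) / n
    return math.exp(mean_kl)

def fid_diagonal(mu_p, var_p, mu_g, var_g):
    """FID under the simplifying ASSUMPTION of diagonal covariances (variances var_*).
    Then trace(Sigma_p + Sigma_G - 2 (Sigma_p Sigma_G)^{1/2})
    reduces to sum_i (sqrt(var_p[i]) - sqrt(var_g[i]))^2."""
    mean_term = sum((a - b) ** 2 for a, b in zip(mu_p, mu_g))
    cov_term = sum((math.sqrt(a) - math.sqrt(b)) ** 2 for a, b in zip(var_p, var_g))
    return mean_term + cov_term
```

Two one-hot predictions on different classes give IS = 2, the number of classes covered. Identical predictions give IS = 1.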
GAN Techniques
Deep Convolutional GAN
TipGAN Techniques Formulation
- Deep Convolutional GAN (DCGAN, Radford et al, ICLR2016)
- The first milestone paper to make GANs train reliably with deep convolutional networks
- Trick Suggestions
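- Use a fully convolutional network (strided convolutions instead of pooling, no fully connected hidden layers)
- Use batch normalization in both \(G\) and \(D\) (except the generator output layer and the discriminator input layer)
- Avoid ReLU in the discriminator; use LeakyReLU (slope 0.2) instead
- Small learning rate and reduced momentum (Adam with learning rate 0.0002 and \(\beta_1 = 0.5\))
- DCGAN learns meaningful features and patterns, e.g., vector arithmetic on \(z\): "smiling woman" − "neutral woman" + "neutral man" ≈ "smiling man"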
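The fully convolutional generator upsamples a small projected feature map with stride-2 transposed convolutions. The sketch below only tracks spatial sizes. The kernel 4, stride 2, padding 1 setting is an assumption borrowed from common reimplementations, not necessarily the paper's exact layer configuration. The hyperparameter dictionary records values reported in the DCGAN paper.

```python
def conv_transpose_out(size, kernel=4, stride=2, padding=1):
    """Output spatial size of a transposed convolution (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel

def dcgan_generator_sizes(start=4, n_layers=4):
    """Spatial sizes as a 4x4 projected feature map is upsampled layer by layer."""
    sizes = [start]
    for _ in range(n_layers):
        sizes.append(conv_transpose_out(sizes[-1]))
    return sizes

# Training hyperparameters reported in the DCGAN paper
DCGAN_HPARAMS = {
    "optimizer": "Adam",
    "lr": 2e-4,
    "beta1": 0.5,             # reduced from Adam's default 0.9 for stability
    "batch_size": 128,
    "leaky_relu_slope": 0.2,  # LeakyReLU in the discriminator
    "init_std": 0.02,         # weights ~ N(0, 0.02)
}
```

With these settings, `dcgan_generator_sizes()` doubles the resolution at every layer: 4, 8, 16, 32, 64.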
Improved Training Techniques for GAN
TipImproved Training Techniques for GAN
- Improved Training Techniques for GAN (Salimans et al, NIPS 2016)
- Mainly a collection of GAN training tricks
- Also introduces the Inception Score (IS) and GAN-based semi-supervised learning (more on these later)
- Trick Suggestions
- Feature matching
- Minibatch discrimination
- Historical averaging
- One-sided label smoothing
- Virtual batch normalization
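Three of the tricks are simple enough to sketch directly; the function names are hypothetical. The feature vectors stand in for activations of an intermediate discriminator layer. Feature matching trains \(G\) to match the mean features of real data. One-sided label smoothing softens only the real targets (e.g., to 0.9). Historical averaging penalizes distance from the average of past parameters.

```python
import math

def feature_matching_loss(feat_real, feat_fake):
    """Generator loss || E[f(x)] - E[f(G(z))] ||^2 on intermediate discriminator features.
    feat_real / feat_fake: lists of feature vectors (lists of floats)."""
    dim = len(feat_real[0])
    mean_r = [sum(f[j] for f in feat_real) / len(feat_real) for j in range(dim)]
    mean_f = [sum(f[j] for f in feat_fake) / len(feat_fake) for j in range(dim)]
    return sum((a - b) ** 2 for a, b in zip(mean_r, mean_f))

def d_loss_one_sided_smoothing(d_real, d_fake, smooth=0.9):
    """Discriminator BCE with real targets smoothed to `smooth`; fake targets stay 0."""
    loss_r = -sum(smooth * math.log(p) + (1 - smooth) * math.log(1 - p)
                  for p in d_real) / len(d_real)
    loss_f = -sum(math.log(1 - p) for p in d_fake) / len(d_fake)
    return loss_r + loss_f

def historical_averaging_penalty(theta, past_thetas):
    """|| theta - (1/t) sum_i theta_i ||^2, added to each player's cost."""
    t = len(past_thetas)
    avg = [sum(th[j] for th in past_thetas) / t for j in range(len(theta))]
    return sum((a - b) ** 2 for a, b in zip(theta, avg))
```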
Wasserstein GAN
TipWasserstein GAN (WGAN)
- WGAN (Arjovsky et al, ICML 2017)
- Objective
- Critic: maximize \(L(\phi) = \mathbb{E}_{x \sim p_{\text{data}}}[f(x; \phi)] - \mathbb{E}_{x \sim p_G}[f(x; \phi)]\)
- Generator: maximize \(L(\theta) = \mathbb{E}_{x \sim G(z; \theta)}[f(x; \phi)]\)
- Subject to \(\|f(x; \phi)\|_L \le 1\) (\(f\) is 1-Lipschitz)
- Tricks to enforce \(\|f(x; \phi)\|_L \le 1\)
- For each \(\phi_i\), \(\phi_i \leftarrow \text{clip}(\phi_i, -c, c)\) (e.g., with \(c = 0.01\))
- No momentum! RMSProp suggested (learning rate \(\alpha = 5 \times 10^{-5}\))
- Update the critic for \(n_{\text{critic}}\) batches before each update of \(G\) (\(n_{\text{critic}} = 5\))
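Here is a minimal sketch of the two mechanical WGAN ingredients: weight clipping and the critic/generator update schedule. It uses the values from the slide (\(c = 0.01\), \(n_{\text{critic}} = 5\)). The critic scores are plain numbers standing in for \(f(x; \phi)\) on a batch.

```python
def clip_weights(params, c=0.01):
    """WGAN weight clipping: phi_i <- clip(phi_i, -c, c) after every critic update."""
    return [max(-c, min(c, p)) for p in params]

def critic_objective(f_real, f_fake):
    """L(phi) = E_data[f(x)] - E_G[f(x)], maximized by the critic.
    f_real / f_fake: critic scores on a real batch and a generated batch."""
    return sum(f_real) / len(f_real) - sum(f_fake) / len(f_fake)

def wgan_schedule(n_generator_steps, n_critic=5):
    """Order of updates: n_critic critic steps before each generator step."""
    steps = []
    for _ in range(n_generator_steps):
        steps.extend(["critic"] * n_critic)
        steps.append("generator")
    return steps
```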
Improved Training of Wasserstein GANs
TipWGAN-GP (Gulrajani et al., NIPS 2017) Formulation
- Improved Training of Wasserstein GANs (Gulrajani et al, NIPS2017)
- Wasserstein GAN objective \[\min_{G} \max_{f:\|f\|_L \le 1} \mathbb{E}_{x \sim p_{\text{data}}}[f(x)] - \mathbb{E}_{\hat{x} \sim p_G}[f(\hat{x})]\]
- Corollary: the optimal critic \(f^*\) has \(\|\nabla f^*\| = 1\) almost everywhere under \(p_{\text{data}}\) and \(p_G\)
- WGAN-GP (Gradient Penalty)
- \(\tilde{x} \leftarrow (1 - \epsilon)x + \epsilon \cdot \hat{x}\) with \(\epsilon \sim \text{Unif}(0,1)\)
- Critic minimizes \(L(\phi) = \mathbb{E}_{\hat{x} \sim p_G}[f(\hat{x}; \phi)] - \mathbb{E}_{x \sim p_{\text{data}}}[f(x; \phi)] + \lambda \cdot \mathbb{E}_{\tilde{x}}\left[(\|\nabla_{\tilde{x}} f(\tilde{x}; \phi)\|_2 - 1)^2\right]\) (\(\lambda = 10\) in the paper)
- Practical issue
- No batch norm in \(f(x; \phi)\): batch norm couples the samples in a batch, while the penalty constrains \(\nabla_{\tilde{x}} f(\tilde{x})\) for each \(\tilde{x}\) separately
- Layer norm or instance norm recommended as a drop-in replacement
- Remark:
- Stable training with various architectures (even ResNets) using Adam
- But expensive: the penalty needs second-order gradients (the gradient of a gradient norm)
- Can still be unstable at high learning rates, since the constraint is only enforced on the heuristic distribution of \(\tilde{x}\)
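The following sketches the penalty term. It assumes a hypothetical toy critic \(f(x) = w \cdot x + \frac{k}{2}\|x\|^2\), whose input gradient \(w + kx\) is known in closed form. Real implementations obtain \(\nabla_{\tilde{x}} f(\tilde{x})\) with autodiff and backpropagate through it, which is where the second-order cost comes from. Real and generated samples are paired elementwise, with one \(\epsilon\) per pair.

```python
import math
import random

def critic(x, w, k):
    """Hypothetical toy critic f(x) = w . x + (k / 2) ||x||^2."""
    return sum(wi * xi for wi, xi in zip(w, x)) + 0.5 * k * sum(xi * xi for xi in x)

def critic_grad(x, w, k):
    """Closed-form input gradient: grad_x f(x) = w + k x."""
    return [wi + k * xi for wi, xi in zip(w, x)]

def gradient_penalty(real, fake, w, k, lam=10.0, rng=random):
    """lam * E[(||grad f(x_tilde)||_2 - 1)^2], x_tilde = (1 - eps) x + eps x_hat, eps ~ U(0, 1)."""
    total = 0.0
    for x, x_hat in zip(real, fake):
        eps = rng.random()
        x_tilde = [(1 - eps) * a + eps * b for a, b in zip(x, x_hat)]
        g = critic_grad(x_tilde, w, k)
        total += (math.sqrt(sum(gi * gi for gi in g)) - 1.0) ** 2
    return lam * total / len(real)

def critic_loss_gp(real, fake, w, k, lam=10.0, rng=random):
    """Critic loss to MINIMIZE: E[f(x_hat)] - E[f(x)] + gradient penalty."""
    mean_fake = sum(critic(x, w, k) for x in fake) / len(fake)
    mean_real = sum(critic(x, w, k) for x in real) / len(real)
    return mean_fake - mean_real + gradient_penalty(real, fake, w, k, lam, rng)
```

A linear critic with \(\|w\|_2 = 1\) (e.g., \(w = (0.6, 0.8)\), \(k = 0\)) pays no penalty. A constant critic (\(w = 0\), \(k = 0\)) pays the full \(\lambda\).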
BigGAN
TipBigGAN (Brock et al., ICLR 2019) Formulation
- BigGAN (DeepMind, ICLR2019)
- Large-Scale Training of GAN!
- \(\sim\) 100M params
- TPU training (48h)
- Large batch size
- Trick Suggestions
- Synced cross-replica class-conditioned batch-norm (linear projected from class id)
- Large batch size (as much as you can) and wider model!
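- Truncation trick: at sampling time, draw \(z\) from a truncated Gaussian \(\mathcal{N}(0, 1, -c, c)\) (trades diversity for fidelity; training still uses \(\mathcal{N}(0, I)\))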
- Orthogonal initialization & spectral normalization for weights
- Hinge loss \(l(z) = \max(0, 1 - t \cdot z)\):
- Samples that \(D\) already classifies correctly with high confidence (\(t \cdot z \ge 1\)) are ignored: they contribute zero loss and zero gradient
- Adaptive hinge loss margin \(t\) to include sufficient training data
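Here is a sketch of two BigGAN-style ingredients, assuming raw (unsquashed) discriminator scores. Truncation is implemented here by resampling any value outside \([-c, c]\). The hinge losses read \(t\) in \(l(z) = \max(0, 1 - t \cdot z)\) as a \(\pm 1\) label: \(t = +1\) for real samples and \(t = -1\) for generated ones. The generator minimizes \(-\mathbb{E}[D(G(z))]\), the form commonly used in hinge-loss GANs.

```python
import random

def truncated_normal(n, c, rng=random):
    """Truncation trick via resampling: draw z_i ~ N(0, 1), redraw any value with |z_i| > c."""
    z = []
    for _ in range(n):
        v = rng.gauss(0.0, 1.0)
        while abs(v) > c:
            v = rng.gauss(0.0, 1.0)
        z.append(v)
    return z

def hinge_d_loss(d_real, d_fake):
    """E[max(0, 1 - D(x))] + E[max(0, 1 + D(G(z)))]; samples beyond the margin contribute 0."""
    return (sum(max(0.0, 1.0 - s) for s in d_real) / len(d_real)
            + sum(max(0.0, 1.0 + s) for s in d_fake) / len(d_fake))

def hinge_g_loss(d_fake):
    """Generator loss -E[D(G(z))]."""
    return -sum(d_fake) / len(d_fake)
```

A smaller `c` gives samples closer to the mode (higher fidelity, lower diversity).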