Generative Adversarial Network

Implicit Generative Model

  • Goal: a sampler \(g(\cdot)\) to generate images
  • A simple generator \(g(z; \theta)\)
    • \(z \sim \mathcal{N}(0, I)\)
    • \(x = g(z; \theta)\)
  • Likelihood-free learning
    • Goal: the distribution of \(g(z; \theta)\) approximates \(p_{\text{data}}\)
    • Idea: minimize \(D(g(z; \theta), p_{\text{data}})\)
      • \(D\) is some distance metric
      • \(D\) does not involve likelihood
  • Choose a distance measure \(D\)
    • Goal: choose a differentiable \(D(x; \phi)\) for \(g(z; \theta)\)
    • Objective: \(L(\theta) = D(g(z; \theta); \phi)\)
      • Since \(D\) is differentiable, we can optimize \(\theta\) by backpropagating through \(D\) into \(g\)
    • How to get \(\phi\)?
      • Goal: \(D(x; \phi)\) measures how likely \(x\) is from \(p_{\text{data}}\)
      • A binary classification problem!
        • \(D(x; \phi) = 1\): \(x \sim p_{\text{data}}\)
        • \(D(x; \phi) = 0\): \(x\) not from \(p_{\text{data}}\)
      • Let’s train a neural classifier!
        • How to choose the negative samples?
          • Random samples are too easy…
          • We already have a generator: use its samples as negatives!
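The implicit-model idea above can be sketched in a few lines of Python. This is a minimal illustration, not a real image model: the 1-D linear generator \(g(z; \theta) = \theta_0 + \theta_1 z\), the Gaussian stand-in for \(p_{\text{data}}\), and the helper names are all hypothetical. The point is that we can draw samples from \(g\) and label them as negatives without ever evaluating a density.

```python
import random

def generator(z, theta):
    """Hypothetical 1-D implicit model g(z; theta) = theta[0] + theta[1] * z."""
    return theta[0] + theta[1] * z

def sample_generator(theta, n, rng):
    # z ~ N(0, 1), x = g(z; theta): sampling is easy, but p_G(x) is never computed
    return [generator(rng.gauss(0.0, 1.0), theta) for _ in range(n)]

def make_classifier_dataset(real, theta, rng):
    """Real data gets label 1; generator samples serve as the negatives (label 0)."""
    fake = sample_generator(theta, len(real), rng)
    return [(x, 1) for x in real] + [(x, 0) for x in fake]

rng = random.Random(0)
real = [rng.gauss(3.0, 1.0) for _ in range(1000)]   # stand-in for p_data
data = make_classifier_dataset(real, theta=(0.0, 1.0), rng=rng)
print(len(data), sum(label for _, label in data))   # 2000 1000
```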

GAN: The Minimax Game

  • Generator \(G(z; \theta)\) (\(z \sim p(z) = \mathcal{N}(0, I)\))
    • generate realistic images
  • Discriminator \(D(x; \phi)\)
    • Classify whether \(x\) comes from \(p_{\text{data}}\) or from \(G\)
  • Objective \[\min_{\theta} \max_{\phi} L(\theta, \phi), \quad L(\theta, \phi) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x; \phi)] + \mathbb{E}_{\hat{x} \sim G}[\log(1 - D(\hat{x}; \phi))]\]
  • Training procedure
    • Collect dataset \(\{(x, 1) \mid x \sim p_{\text{data}}\} \cup \{(\hat{x}, 0) \mid \hat{x} \sim g(z; \theta)\}\)
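    • Train discriminator \(D\): maximize \(L(\phi) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x; \phi)] + \mathbb{E}_{\hat{x} \sim G}[\log(1 - D(\hat{x}; \phi))]\) (binary cross-entropy on the dataset above)
    • Train generator \(G\): maximize \(L(\theta) = \mathbb{E}_{z \sim p(z)}[\log D(G(z; \theta); \phi)]\) (the non-saturating alternative to minimizing \(\mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z; \theta); \phi))]\), which gives stronger gradients early in training)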
    • Repeat
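The alternating procedure above can be sketched end to end on a 1-D toy problem. Everything here is an assumption chosen to keep the example dependency-free: a logistic discriminator \(D(x; \phi) = \sigma(wx + c)\), a generator \(G(z; \theta) = \mu + sz\), hand-derived gradients, and plain gradient ascent. A linear logistic \(D\) can essentially only detect a mean shift, so this shows the structure of the loop, not a working image GAN.

```python
import math
import random

def sigmoid(u):
    return 1.0 / (1.0 + math.exp(-u))

def d_objective(phi, real, fake):
    """Mean log D(x) over real plus mean log(1 - D(x_hat)) over fake."""
    w, c = phi
    return (sum(math.log(sigmoid(w * x + c)) for x in real) / len(real)
            + sum(math.log(1.0 - sigmoid(w * x + c)) for x in fake) / len(fake))

def d_step(phi, real, fake, lr):
    """One gradient-ascent step on d_objective for D(x) = sigmoid(w * x + c)."""
    w, c = phi
    gw = gc = 0.0
    for x in real:                 # d/du log sigmoid(u) = 1 - sigmoid(u)
        g = (1.0 - sigmoid(w * x + c)) / len(real)
        gw, gc = gw + g * x, gc + g
    for x in fake:                 # d/du log(1 - sigmoid(u)) = -sigmoid(u)
        g = -sigmoid(w * x + c) / len(fake)
        gw, gc = gw + g * x, gc + g
    return (w + lr * gw, c + lr * gc)

def g_objective(theta, phi, zs):
    """Non-saturating generator objective: mean log D(G(z)), G(z) = mu + s * z."""
    (mu, s), (w, c) = theta, phi
    return sum(math.log(sigmoid(w * (mu + s * z) + c)) for z in zs) / len(zs)

def g_step(theta, phi, zs, lr):
    """One gradient-ascent step on g_objective with D held fixed."""
    (mu, s), (w, c) = theta, phi
    gmu = gs = 0.0
    for z in zs:
        g = (1.0 - sigmoid(w * (mu + s * z) + c)) * w / len(zs)  # (d log D / dx) / n
        gmu, gs = gmu + g, gs + g * z
    return (mu + lr * gmu, s + lr * gs)

def train(steps=200, batch=64, lr=0.05, seed=0):
    rng = random.Random(seed)
    theta, phi = (0.0, 1.0), (0.0, 0.0)
    for _ in range(steps):
        real = [rng.gauss(3.0, 1.0) for _ in range(batch)]   # x ~ p_data (toy)
        zs = [rng.gauss(0.0, 1.0) for _ in range(batch)]     # z ~ p(z)
        fake = [theta[0] + theta[1] * z for z in zs]         # x_hat = G(z; theta)
        phi = d_step(phi, real, fake, lr)                    # train D
        theta = g_step(theta, phi, zs, lr)                   # train G, then repeat
    return theta, phi
```

Each iteration takes one ascent step for \(D\) on the log-likelihood objective, then one ascent step for \(G\) on the non-saturating objective \(\mathbb{E}_z[\log D(G(z))]\) using the freshly updated \(D\).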

Optimal Discriminator

  • GAN Objective \[\min_{\theta} \max_{\phi} L(\theta, \phi), \quad L(\theta, \phi) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x; \phi)] + \mathbb{E}_{\hat{x} \sim G}[\log(1 - D(\hat{x}; \phi))]\]
  • Let’s analyze the optimal solution \(D^*\) and \(g^*\) under \(L(\theta, \phi)\)
  • Optimal \(D(x; \phi^*)\): for each fixed \(x\), maximize the integrand \[L(D) = p_{\text{data}}(x) \log D + p_G(x) \log(1 - D)\]
    • Consider \(\frac{\partial L}{\partial D} = 0\)
    • We have \(\frac{p_{\text{data}}(x)}{D^*} - \frac{p_G(x)}{1 - D^*} = 0\)
    • So \((1 - D^*)p_{\text{data}}(x) = p_G(x)D^*\)
    • \[D^* = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_G(x)}\]
  • Remark: with a perfect generator (\(p_G = p_{\text{data}}\)), \(D^* = 0.5\) everywhere
  • Optimal discriminator \(D^*(x) = p_{\text{data}}(x) / (p_{\text{data}}(x) + p_G(x))\)
  • Optimal generator \(g(z; \theta)\) with \(\phi^*\)
    • \[L(\theta, \phi^*) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log \frac{p_{\text{data}}(x)}{p_{\text{data}}(x)+p_G(x)}\right] + \mathbb{E}_{x \sim p_G}\left[\log \frac{p_G(x)}{p_{\text{data}}(x)+p_G(x)}\right]\]
    • \[= \mathbb{E}_{x \sim p_{\text{data}}}\left[\log \frac{p_{\text{data}}(x)}{\frac{p_{\text{data}}(x)+p_G(x)}{2}}\right] + \mathbb{E}_{x \sim p_G}\left[\log \frac{p_G(x)}{\frac{p_{\text{data}}(x)+p_G(x)}{2}}\right] - \log 4\]
    • \[= KL\left(p_{\text{data}} \;\Big\|\; \frac{1}{2}(p_{\text{data}} + p_G)\right) + KL\left(p_G \;\Big\|\; \frac{1}{2}(p_{\text{data}} + p_G)\right) - \log 4\]
  • \(= 2 \cdot JSD(p_{\text{data}} \mid\mid p_G) - \log 4\), minimized (value \(-\log 4\)) exactly when \(p_G = p_{\text{data}}\)
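The derivation can be checked numerically on a finite support, where expectations become sums. This is a minimal sketch with made-up three-point distributions (all probabilities strictly positive, so every logarithm is defined): it plugs \(D^*\) into the objective, compares the result with \(2 \cdot JSD - \log 4\), and confirms that nudging \(D\) away from \(D^*\) lowers the objective.

```python
import math

def kl(p, q):
    """KL(p || q) for discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def jsd(p, q):
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * (kl(p, m) + kl(q, m))

def optimal_d(p_data, p_g):
    """Pointwise D*(x) = p_data(x) / (p_data(x) + p_G(x))."""
    return [pd / (pd + pg) for pd, pg in zip(p_data, p_g)]

def gan_value(p_data, p_g, d):
    """L = E_{p_data}[log D] + E_{p_G}[log(1 - D)]; assumes 0 < D < 1 everywhere."""
    return sum(pd * math.log(di) + pg * math.log(1 - di)
               for pd, pg, di in zip(p_data, p_g, d))

p_data = [0.5, 0.3, 0.2]
p_g = [0.2, 0.2, 0.6]
d_star = optimal_d(p_data, p_g)
# Plugging D* into L recovers 2 * JSD - log 4
print(gan_value(p_data, p_g, d_star), 2 * jsd(p_data, p_g) - math.log(4))
```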

GAN Mechanics: Forward vs. Reverse KL and JSD Behavior

  • \(KL(p \mid\mid q) = \mathbb{E}_{x \sim p}[\log (p(x) / q(x))]\)
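  • Asymmetric measure: in general \(KL(p \mid\mid q) \ne KL(q \mid\mid p)\)
    • \(KL(p \mid\mid q)\): forward KL (inclusive, mass-covering when \(p\) is the data and \(q\) the model)
    • \(KL(q \mid\mid p)\): reverse KL (exclusive, mode-seeking)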
  • \(JSD(p \mid\mid q) = \frac{1}{2}\left(KL\left(p \mid\mid \frac{1}{2}(p + q)\right) + KL\left(q \mid\mid \frac{1}{2}(p + q)\right)\right)\)
  • Properties
    • Symmetric: \(JSD(p \mid\mid q) = JSD(q \mid\mid p)\)
    • \(JSD(p \mid\mid q) \ge 0\) and \(JSD(p \mid\mid q) = 0 \iff p = q\)
    • Jensen-Shannon distance: \(\sqrt{JSD(p \mid\mid q)}\) satisfies the triangle inequality (it is a metric)
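The properties above are easy to probe on discrete distributions. In this sketch the distributions are arbitrary examples, and `kl` returns infinity when \(q\) assigns zero mass where \(p\) does not. The last case uses disjoint supports: KL is infinite, while JSD equals its maximum \(\log 2\) (in nats) no matter how far apart the supports are.

```python
import math

def kl(p, q):
    """KL(p || q) in nats for discrete distributions given as lists."""
    total = 0.0
    for pi, qi in zip(p, q):
        if pi > 0:
            if qi == 0:
                return math.inf   # p puts mass where q has none
            total += pi * math.log(pi / qi)
    return total

def jsd(p, q):
    """Jensen-Shannon divergence: average KL to the mixture m = (p + q) / 2."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def js_distance(p, q):
    return math.sqrt(jsd(p, q))

p = [0.8, 0.1, 0.1]
q = [0.4, 0.3, 0.3]
r = [0.1, 0.1, 0.8]
print(kl(p, q), kl(q, p))      # forward and reverse KL differ
print(jsd(p, q), jsd(q, p))    # JSD is symmetric
print(js_distance(p, r) <= js_distance(p, q) + js_distance(q, r))  # True

a = [0.5, 0.5, 0.0]            # disjoint supports
b = [0.0, 0.0, 1.0]
print(kl(a, b), jsd(a, b), math.log(2))  # inf, then JSD equal to log 2
```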

Evaluation of GANs

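  • Inception Score \[IS = \exp\left(\mathbb{E}_{x \sim G}\left[KL(f(y \mid x) \mid\mid p_f(y))\right]\right)\]
    • \(f(y \mid x)\): class posterior of a pretrained Inception classifier; \(p_f(y) = \mathbb{E}_{x \sim G}[f(y \mid x)]\): marginal class distribution
    • Higher is better: confident predictions per sample and diverse classes overall
  • IS only measures the quality (and class diversity) of generated samples; it never compares them with \(p_{\text{data}}\)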
    • We also want \(G\) to fully cover \(p_{\text{data}}\)
    • Idea: statistics of \(G(z)\) should be similar to \(p_{\text{data}}\)
      • Statistics: compare the distributions of features extracted from \(p_{\text{data}}\) and from \(G(z)\)
  • Fréchet Inception Distance
    • Compute \(\mu_p, \Sigma_p\) and \(\mu_G, \Sigma_G\) for \(p_{\text{data}}\) and \(G(z)\) from Inception-v3 pool3 features (2048-d)
    • Compute the Wasserstein-2 (Fréchet) distance between \(\mathcal{N}(\mu_p, \Sigma_p)\) and \(\mathcal{N}(\mu_G, \Sigma_G)\) (more on this later); FID is its square
    • \[FID = \|\mu_p - \mu_G\|^2 + \text{trace}\left(\Sigma_p + \Sigma_G - 2(\Sigma_p\Sigma_G)^{1/2}\right)\]
    • Lower is better
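Both metrics reduce to short computations once class posteriors or feature statistics are available. In this sketch the Inception network is replaced by given inputs: `probs` stands in for the classifier outputs \(f(y \mid x)\), and the Gaussian statistics are passed in directly. To stay dependency-free, `fid_diagonal` assumes diagonal covariances, where \(\text{trace}((\Sigma_p \Sigma_G)^{1/2}) = \sum_i \sqrt{\sigma_{p,i}^2 \sigma_{G,i}^2}\); the real FID uses full 2048×2048 covariances and a matrix square root (for example `scipy.linalg.sqrtm`).

```python
import math

def inception_score(probs):
    """IS = exp(E_x[KL(f(y|x) || p_f(y))]).

    probs: one row per generated sample, each row a class distribution f(y|x).
    """
    n, k = len(probs), len(probs[0])
    marginal = [sum(row[j] for row in probs) / n for j in range(k)]   # p_f(y)
    mean_kl = sum(
        sum(pj * math.log(pj / mj) for pj, mj in zip(row, marginal) if pj > 0)
        for row in probs
    ) / n
    return math.exp(mean_kl)

def fid_diagonal(mu_p, var_p, mu_g, var_g):
    """FID restricted to diagonal covariances Sigma = diag(var).

    Diagonal matrices commute, so trace((Sigma_p Sigma_G)^{1/2}) = sum sqrt(var_p * var_g).
    """
    mean_term = sum((a - b) ** 2 for a, b in zip(mu_p, mu_g))
    cov_term = sum(vp + vg - 2.0 * math.sqrt(vp * vg) for vp, vg in zip(var_p, var_g))
    return mean_term + cov_term
```

As sanity checks, one-hot posteriors spread evenly over \(K\) classes give the maximum \(IS = K\), identical posteriors for every sample give \(IS = 1\), and identical Gaussians give \(FID = 0\).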

GAN Techniques

Deep Convolutional GAN

  • Deep Convolutional GAN (DCGAN, Radford et al, ICLR2016)
    • The first milestone paper to make GANs work reliably with deep convolutional networks
    • Trick Suggestions
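      • Use an all-convolutional network: strided convolutions instead of pooling, no fully connected hidden layers
      • Use batch normalization in both \(G\) and \(D\) (except the \(G\) output layer and the \(D\) input layer)
      • Avoid ReLU in the discriminator: use LeakyReLU in \(D\) (ReLU in \(G\), Tanh at its output)
      • Small learning rate and reduced momentum (Adam with learning rate 0.0002, \(\beta_1 = 0.5\))
    • DCGAN learns meaningful features and latent structure (e.g., vector arithmetic on \(z\))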
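The DCGAN recipe can be summarized as a small config plus the layer arithmetic of a fully convolutional generator. The hyperparameter values below are the ones reported in the DCGAN paper; the generator layout (project \(z\) to a 4×4 map, then four transposed convolutions with kernel 4, stride 2, padding 1) is an assumption borrowed from common reimplementations, not the paper's exact architecture. `DCGANConfig`, `conv_transpose_out`, and `generator_sizes` are hypothetical names.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class DCGANConfig:
    """Training hyperparameters reported in the DCGAN paper."""
    lr: float = 2e-4          # Adam learning rate (0.001 was found too high)
    beta1: float = 0.5        # reduced Adam momentum term (default 0.9)
    batch_size: int = 128     # minibatch size
    leaky_slope: float = 0.2  # LeakyReLU slope in the discriminator
    init_std: float = 0.02    # weights drawn from N(0, 0.02)

def conv_transpose_out(size, kernel=4, stride=2, padding=1):
    """Output size of a transposed convolution (no dilation, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel

def generator_sizes(start=4, n_layers=4):
    """Spatial sizes through a generator that upsamples a start x start map."""
    sizes = [start]
    for _ in range(n_layers):
        sizes.append(conv_transpose_out(sizes[-1]))
    return sizes

print(DCGANConfig())
print(generator_sizes())  # [4, 8, 16, 32, 64]
```

Each stride-2 transposed convolution doubles the spatial size, so four of them turn a 4×4 map into a 64×64 image without any pooling or fully connected hidden layers.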

Improved Training Techniques for GAN

  • Improved Training Techniques for GAN (Salimans et al, NIPS 2016)
    • A paper primarily about GAN training tricks
      • It also introduced the Inception Score (IS) and GAN-based semi-supervised learning (more on these later)
    • Trick Suggestions
      • Feature matching
      • Minibatch discrimination
      • Historical averaging
      • One-sided label smoothing
      • Virtual batch normalization
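Three of these tricks are simple loss terms, sketched below on plain Python lists. Assumptions: features and discriminator probabilities are given as inputs (no network is built), the smoothed real target 0.9 is an illustrative choice, and the function names are hypothetical. Minibatch discrimination and virtual batch normalization change the architecture and are not shown.

```python
import math

def feature_matching_loss(real_feats, fake_feats):
    """||E[f(x)] - E[f(G(z))]||^2 over batches of intermediate D features."""
    dim = len(real_feats[0])
    mean_r = [sum(f[j] for f in real_feats) / len(real_feats) for j in range(dim)]
    mean_f = [sum(f[j] for f in fake_feats) / len(fake_feats) for j in range(dim)]
    return sum((a - b) ** 2 for a, b in zip(mean_r, mean_f))

def bce(target, pred):
    """Binary cross-entropy for one predicted probability in (0, 1)."""
    return -(target * math.log(pred) + (1 - target) * math.log(1 - pred))

def d_loss_one_sided_smoothing(d_real, d_fake, real_target=0.9):
    """Smooth only the real labels (1 -> real_target); fake labels stay at 0."""
    loss_real = sum(bce(real_target, p) for p in d_real) / len(d_real)
    loss_fake = sum(bce(0.0, p) for p in d_fake) / len(d_fake)
    return loss_real + loss_fake

def historical_average_penalty(params, history):
    """||theta - (1/t) * sum_i theta_i||^2: penalize drift from past parameter values."""
    avg = [sum(h[j] for h in history) / len(history) for j in range(len(params))]
    return sum((p - a) ** 2 for p, a in zip(params, avg))
```

Feature matching asks the generator to match the mean of an intermediate discriminator layer; one-sided smoothing softens only the real targets so that \(D\) is not pushed toward extreme confidence on real data; historical averaging adds a penalty to either player's loss.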

Wasserstein GAN

  • WGAN (Arjovsky et al, ICML 2017)
    • Objective
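      • Critic: maximize \(L(\phi) = \mathbb{E}_{x \sim p_{\text{data}}}[f(x; \phi)] - \mathbb{E}_{x \sim p_G}[f(x; \phi)]\)
      • Generator: maximize \(L(\theta) = \mathbb{E}_{x \sim G(z; \theta)}[f(x; \phi)]\)
      • Subject to \(\|f(\cdot; \phi)\|_L \le 1\) (\(f\) is 1-Lipschitz in \(x\))
    • Tricks to enforce the Lipschitz constraint (in practice \(f\) is only \(K\)-Lipschitz for some \(K\), which rescales the estimated distance)
      • For each \(\phi_i\), \(\phi_i \leftarrow \text{clip}(\phi_i, -c, c)\) (e.g., with \(c = 0.01\))
      • No momentum: momentum-based optimizers such as Adam made training unstable; RMSProp is suggested (learning rate \(\alpha = 5 \times 10^{-5}\))
      • Update the critic for \(n_{\text{critic}}\) batches before each update of \(G\) (\(n_{\text{critic}} = 5\))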
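The WGAN loop with weight clipping and \(n_{\text{critic}}\) critic steps per generator step can be sketched on a 1-D toy. Assumptions: a linear critic \(f(x) = wx\) (whose Lipschitz constant is \(|w|\), so clipping to \([-c, c]\) makes it \(c\)-Lipschitz), a shift-only generator \(G(z) = \mu + z\), data from \(\mathcal{N}(2, 1)\), and a hand-written RMSProp. The clip value \(c = 0.1\) and learning rate \(0.01\) are deliberately larger than the paper's settings so the toy moves within a few hundred steps.

```python
import math
import random

class RMSProp:
    """Minimal scalar RMSProp used for gradient ascent (no momentum term)."""
    def __init__(self, lr, rho=0.9, eps=1e-8):
        self.lr, self.rho, self.eps, self.v = lr, rho, eps, 0.0

    def ascend(self, param, grad):
        self.v = self.rho * self.v + (1 - self.rho) * grad ** 2
        return param + self.lr * grad / (math.sqrt(self.v) + self.eps)

def train_wgan(iters=500, n_critic=5, c=0.1, lr=0.01, batch=128, seed=0):
    rng = random.Random(seed)
    w, mu = 0.0, 0.0                 # critic f(x) = w * x, generator G(z) = mu + z
    opt_w, opt_mu = RMSProp(lr), RMSProp(lr)
    max_abs_w = 0.0
    for _ in range(iters):
        for _ in range(n_critic):    # several critic updates per generator update
            real = [rng.gauss(2.0, 1.0) for _ in range(batch)]
            fake = [mu + rng.gauss(0.0, 1.0) for _ in range(batch)]
            # gradient of E[f(real)] - E[f(fake)] with respect to w
            grad_w = sum(real) / batch - sum(fake) / batch
            w = opt_w.ascend(w, grad_w)
            w = max(-c, min(c, w))   # weight clipping
            max_abs_w = max(max_abs_w, abs(w))
        # generator ascends E[f(G(z))] = w * (mu + E[z]); its gradient in mu is w
        mu = opt_mu.ascend(mu, w)
    return w, mu, max_abs_w
```

The generator's gradient is just the critic slope \(w\), so \(\mu\) moves toward the data mean while the critic tracks which side of it the generator is on.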

Improved Training of Wasserstein GANs

  • Improved Training of Wasserstein GANs (Gulrajani et al, NIPS2017)
    • Wasserstein GAN objective \[\min_{G} \max_{f:\|f\|_L \le 1} \mathbb{E}_{x \sim p_{\text{data}}}[f(x)] - \mathbb{E}_{\hat{x} \sim p_G}[f(\hat{x})]\]
    • Corollary: the optimal critic \(f^*\) has \(\|\nabla f^*\| = 1\) almost everywhere under \(p_{\text{data}}\) and \(p_G\)
    • WGAN-GP (Gradient Penalty)
      • \(\tilde{x} \leftarrow (1 - \epsilon)x + \epsilon \cdot \hat{x}\) with \(\epsilon \sim \text{Unif}(0,1)\)
      • Critic maximizes \(L(\phi) = \mathbb{E}_{x \sim p_{\text{data}}}[f(x; \phi)] - \mathbb{E}_{\hat{x} \sim p_G}[f(\hat{x}; \phi)] - \lambda \cdot \mathbb{E}_{\tilde{x}}\left[(\|\nabla_{\tilde{x}} f(\tilde{x}; \phi)\|_2 - 1)^2\right]\) (\(\lambda = 10\) in the paper)
    • Practical issue
      • No batch norm in \(f(x; \phi)\): batch norm makes each output depend on the whole batch, while the penalty constrains \(\nabla_{\tilde{x}} f(\tilde{x})\) for each individual \(\tilde{x}\)
      • Layer norm or instance norm recommended as a drop-in replacement
    • Remark:
      • Stable training with various architectures (even ResNets) and Adam
      • But more expensive: the penalty requires a gradient of a gradient (double backpropagation)
      • May also be unstable when the learning rate is high, since \(\tilde{x}\) comes from a heuristic sampling distribution
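The gradient-penalty objective can be sketched on a 1-D toy critic where \(\nabla_{\tilde{x}} f\) has a closed form. Assumptions: the critic \(f(x) = a \tanh(wx + b)\) and all function names are hypothetical, and the input gradient is derived by hand. In a deep-learning framework the same quantity comes from autodiff, and differentiating the penalty with respect to \(\phi\) then needs a second backward pass through that gradient (in PyTorch, `torch.autograd.grad(..., create_graph=True)`), which is where the extra cost comes from.

```python
import math
import random

def critic(x, phi):
    """Hypothetical 1-D critic f(x; phi) = a * tanh(w * x + b)."""
    a, w, b = phi
    return a * math.tanh(w * x + b)

def critic_grad_x(x, phi):
    """Hand-derived df/dx = a * w * (1 - tanh(w * x + b)^2)."""
    a, w, b = phi
    return a * w * (1.0 - math.tanh(w * x + b) ** 2)

def gradient_penalty(phi, real, fake, rng):
    """E[(|f'(x_tilde)| - 1)^2] at random points between paired real and fake samples."""
    total = 0.0
    for x, x_hat in zip(real, fake):
        eps = rng.random()                       # eps ~ Unif(0, 1)
        x_tilde = (1.0 - eps) * x + eps * x_hat  # interpolate real and fake
        total += (abs(critic_grad_x(x_tilde, phi)) - 1.0) ** 2
    return total / len(real)

def wgan_gp_objective(phi, real, fake, rng, lam=10.0):
    """Critic objective to maximize: E[f(x)] - E[f(x_hat)] - lam * penalty."""
    est = (sum(critic(x, phi) for x in real) / len(real)
           - sum(critic(x, phi) for x in fake) / len(fake))
    return est - lam * gradient_penalty(phi, real, fake, rng)
```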

BigGAN

  • BigGAN (DeepMind, ICLR2019)
    • Large-Scale Training of GAN!
      • \(\sim\) 100M params
      • TPU training (48h)
      • Large batch size
    • Trick Suggestions
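      • Synced cross-replica, class-conditional batch norm (BN gains and biases linearly projected from a class embedding)
      • Large batch size (as large as you can afford) and wider models!
      • Truncation trick: at sampling time, draw \(z\) from a truncated Gaussian \(\mathcal{N}(0, 1, -c, c)\) (resample entries outside \([-c, c]\)); smaller \(c\) trades diversity for fidelity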
      • Orthogonal initialization & spectral normalization for weights
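      • Hinge loss \(l(z) = \max(0, 1 - t \cdot z)\), with label \(t = +1\) (real) or \(t = -1\) (fake) and discriminator output \(z\):
        • Samples that \(D\) already classifies correctly with high confidence (\(t \cdot z \ge 1\)) get zero loss and are ignored
        • Adapt the hinge margin so that enough training samples still contribute gradients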
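The truncation trick and the hinge losses are short enough to sketch directly. Assumptions: discriminator outputs are given as plain lists of scores, the function names are hypothetical, and the generator loss \(-\mathbb{E}[D(G(z))]\) is the pairing commonly used with the hinge discriminator loss. The truncation sampler resamples each entry with \(|z_i| > c\), and it is meant for sampling from an already trained model.

```python
import random

def truncated_normal(dim, c, rng):
    """Truncation trick: resample each entry of z ~ N(0, 1) until |z_i| <= c."""
    z = []
    for _ in range(dim):
        v = rng.gauss(0.0, 1.0)
        while abs(v) > c:
            v = rng.gauss(0.0, 1.0)
        z.append(v)
    return z

def d_hinge_loss(d_real, d_fake):
    """mean max(0, 1 - D(x)) + mean max(0, 1 + D(G(z))), i.e. t = +1 real, t = -1 fake."""
    real_term = sum(max(0.0, 1.0 - s) for s in d_real) / len(d_real)
    fake_term = sum(max(0.0, 1.0 + s) for s in d_fake) / len(d_fake)
    return real_term + fake_term

def g_hinge_loss(d_fake):
    """Generator loss commonly paired with the hinge D loss: -mean D(G(z))."""
    return -sum(d_fake) / len(d_fake)
```

With real scores above \(+1\) and fake scores below \(-1\), the discriminator loss is exactly zero, so those confidently classified samples contribute no gradient.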