Normalizing Flow & Variational Autoencoder

Normalizing Flow

  • Generative Models
    • Generator: transform a simple distribution to the data distribution
    • Generation: sample \(z_0 \sim \pi(z)\), \(x = G(z_0)\)
  • Change of Variables
    • For an invertible network \(G\), the change-of-variables formula gives \[p(x) = \pi(z) \left| \det(J_{G^{-1}}(x)) \right|, \quad z = G^{-1}(x)\]
    • \(J_{G^{-1}}\): Jacobian matrix of \(G^{-1}\)
    • Intuition: \[p(x')\Delta x = \pi(z')\Delta z \implies p(x') = \pi(z')\frac{\Delta z}{\Delta x} \implies p(x') = \pi(z')\left|\frac{dz}{dx}\right|\]
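
The change-of-variables formula above can be checked numerically in one dimension. The sketch below assumes an illustrative generator \(G(z) = e^z\) with a standard normal base density; the function names are hypothetical and not part of the notes.

```python
import math

def standard_normal_pdf(z):
    # Base density pi(z) = N(0, 1).
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)

def G_inv(x):
    # Inverse of the illustrative generator G(z) = exp(z); defined for x > 0.
    return math.log(x)

def dG_inv_dx(x):
    # Derivative of G^{-1}(x) = log(x), i.e. the 1-D Jacobian.
    return 1.0 / x

def p_x(x):
    # p(x) = pi(G^{-1}(x)) * |d G^{-1}(x) / dx|
    return standard_normal_pdf(G_inv(x)) * abs(dG_inv_dx(x))

def integrate(f, a, b, n=200000):
    # Midpoint rule; never evaluates f at the endpoints.
    h = (b - a) / n
    return sum(f(a + (i + 0.5) * h) for i in range(n)) * h

# A valid density must integrate to (approximately) one.
total_mass = integrate(p_x, 0.0, 200.0)
print(total_mass)
```

Without the Jacobian factor, `p_x` would not integrate to one, which is the point of the \(\Delta z / \Delta x\) intuition.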
  • Normalizing Flows
    • Training objective (maximize the log-likelihood; the loss is its negative): \[\log p(x) = \log \pi(G^{-1}(x)) + \log \left| \det(J_{G^{-1}}(x)) \right|\]
    • Decompose \(G\) into a composition of many sub-networks: \[z_K = f_{\theta^K} \circ f_{\theta^{K-1}} \circ \dots \circ f_{\theta^1}(z_0)\]
    • Chain rule (using \(\left|\det(J_{G^{-1}})\right| = 1/\left|\det(J_G)\right|\)), with \(z_k = f_{\theta^k}(z_{k-1})\): \[\log p(z_K) = \log \pi(z_0) - \sum_{k=1}^K \log \left| \det\left(\frac{\partial f_{\theta^k}}{\partial z_{k-1}}\right) \right|\]
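
The chain rule over layers can be sketched with scalar flows, where each Jacobian determinant is an ordinary derivative evaluated at the layer input. The layer choices (`make_affine`, `make_sinh`) are illustrative assumptions, not layers named in the notes.

```python
import math

def log_standard_normal(z):
    return -0.5 * z * z - 0.5 * math.log(2.0 * math.pi)

# Each layer is a pair: (forward map, log|derivative| at the layer input).
def make_affine(a, b):
    return (lambda z: a * z + b, lambda z: math.log(abs(a)))

def make_sinh():
    # sinh is strictly increasing on the real line, hence invertible.
    return (lambda z: math.sinh(z), lambda z: math.log(math.cosh(z)))

def flow_log_prob(z0, layers):
    """Push z0 through the layers and accumulate log p(z_K)."""
    z = z0
    log_p = log_standard_normal(z0)
    for forward, log_abs_det in layers:
        log_p -= log_abs_det(z)  # Jacobian evaluated at z_{k-1}
        z = forward(z)
    return z, log_p

layers = [make_affine(2.0, 1.0), make_sinh(), make_affine(-0.5, 3.0)]
z_K, log_p = flow_log_prob(0.3, layers)
print(z_K, log_p)
```

For a single affine layer \(z_1 = 3 z_0 + 1\) the result should match the log-density of \(\mathcal N(1, 3^2)\) at \(z_1\), which gives a quick sanity check.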
  • Coupling Layers
    • \(f(x) = \begin{bmatrix} x^A \\ \hat{f}(x^B \mid \theta(x^A)) \end{bmatrix}\)
    • Jacobian \[J_f = \begin{bmatrix} I & 0 \\ \partial\hat{f}(x^B \mid \theta(x^A)) / \partial x^A & J_{\hat{f}} \end{bmatrix}\]
    • Determinant \[\det(J_f) = \det(J_{\hat{f}})\]
  • Coupling Layers
    • \(\theta(x^A)\) can be an arbitrary network, because it never has to be inverted.
    • How the input is split into \(x^A\) and \(x^B\) matters for the expressivity of the invertible network.
    • Coupling Transform
      • Additive [NICE, Dinh et al. 2014] \[\hat{f}(x \mid t) = x + t\]
      • Affine [RealNVP, Dinh et al. 2016] \[\hat{f}(x \mid s, t) = s \odot x + t\]
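
A minimal sketch of an affine coupling layer follows, assuming a hand-written conditioner `theta` in place of a neural network (any function of \(x^A\) would do). It shows that the inverse only needs to recompute \(s\) and \(t\) from the unchanged half, and that the log-determinant is the sum of \(\log|s|\).

```python
import math

def theta(x_a):
    # Hypothetical conditioner standing in for an arbitrary network:
    # returns a positive scale s and a shift t for each x^B coordinate.
    h = sum(x_a)
    s = [math.exp(math.tanh(h)), math.exp(math.tanh(0.5 * h))]
    t = [math.sin(h), h * h]
    return s, t

def coupling_forward(x_a, x_b):
    s, t = theta(x_a)
    y_b = [si * xi + ti for si, xi, ti in zip(s, x_b, t)]
    # J_f is block lower triangular, so det(J_f) = prod(s).
    log_det = sum(math.log(abs(si)) for si in s)
    return x_a, y_b, log_det

def coupling_inverse(y_a, y_b):
    s, t = theta(y_a)  # y_a equals x_a, so s and t can be recomputed
    x_b = [(yi - ti) / si for yi, si, ti in zip(y_b, s, t)]
    return y_a, x_b

x_a, x_b = [0.2, -1.0], [1.5, 0.7]
y_a, y_b, log_det = coupling_forward(x_a, x_b)
rec_a, rec_b = coupling_inverse(y_a, y_b)
print(y_b, log_det, rec_b)
```

Setting every scale to one recovers the additive (NICE) case with zero log-determinant.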
  • Glow
    • Invertible 1×1 convolution: input tensor \(\mathbf{h}\) of shape \(c \times h \times w\), weight \(W\) of shape \(c \times c\), output of shape \(c \times h \times w\) \[\log \left| \det \left( \frac{d \text{ conv2D}(\mathbf{h}; W)}{d \mathbf{h}} \right) \right| = h \cdot w \cdot \log |\det(W)|\]
    • Computing \(|\det(W)|\) directly costs \(\mathcal{O}(c^3)\) (vs. \(\mathcal{O}(hwc^2)\) for the convolution itself)
    • Reparametrize \(W\) in LU decomposition \[W = P L (U + \text{diag}(s))\] \[\log |\det(W)| = \text{sum}(\log |s|)\]
    • Components of LU decomposition:
      • \(P\) (fixed): Permutation matrix
      • \(L\): A lower triangular matrix with diagonal elements equal to 1.
      • \(U\): An upper triangular matrix with diagonal elements equal to 0.
    • Complexity of computing \(\log |\det(W)|\) from \(s\): \(\mathcal{O}(c)\)
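
The LU reparametrization can be verified on a small example. The sketch below uses a hand-picked \(3 \times 3\) case (all matrix entries are illustrative assumptions) and compares the direct determinant with \(\text{sum}(\log|s|)\).

```python
import math

def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def det3(m):
    # Cofactor expansion for a 3x3 matrix.
    return (m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
            - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
            + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]))

P = [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]    # fixed permutation
L = [[1.0, 0.0, 0.0], [0.5, 1.0, 0.0], [-0.3, 0.2, 1.0]]   # unit lower triangular
U = [[0.0, 0.7, -1.2], [0.0, 0.0, 0.4], [0.0, 0.0, 0.0]]   # strictly upper triangular
s = [2.0, -0.5, 3.0]                                       # learned diagonal

U_plus_diag = [[U[i][j] + (s[i] if i == j else 0.0) for j in range(3)]
               for i in range(3)]
W = matmul(P, matmul(L, U_plus_diag))

log_abs_det_direct = math.log(abs(det3(W)))          # O(c^3) in general
log_abs_det_lu = sum(math.log(abs(si)) for si in s)  # O(c)

# Log-determinant of the full 1x1 convolution on a c x h x w input.
height, width = 4, 5
conv_log_det = height * width * log_abs_det_lu
print(log_abs_det_direct, log_abs_det_lu, conv_log_det)
```

The permutation contributes \(|\det P| = 1\) and \(L\) contributes \(1\), so only \(s\) enters the log-determinant.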
  • Continuous-time Normalizing Flows
    • Neural ODE \[\frac{dz(t)}{dt} = f(z(t), t, \theta)\] \[\text{forward: } z(t_1) = z(t_0) + \int_{t_0}^{t_1} f(z(t), t, \theta)\, dt \quad (\text{Euler discretization} = \text{ResNet})\] \[\text{reverse: } z(t_0) = z(t_1) + \int_{t_1}^{t_0} f(z(t), t, \theta)\, dt\]
    • The flow map of an ODE is invertible whenever solutions are unique (e.g. \(f\) Lipschitz in \(z\)): integrate backward in time to invert.
    • Instantaneous Change of Variables \[\frac{d\log p(z(t))}{dt} = -\text{Tr}\left(\frac{df}{dz(t)}\right) = -\text{div}(f)\]
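
The instantaneous change-of-variables formula can be illustrated by integrating \(z\) and \(\log p\) together. This sketch assumes scalar linear dynamics \(f(z, t) = A z\) (so the trace is just \(A\)) and plain Euler steps; both are simplifications chosen so the answer is known in closed form.

```python
import math

A = 0.5  # illustrative dynamics parameter (an assumption)

def log_standard_normal(z):
    return -0.5 * z * z - 0.5 * math.log(2.0 * math.pi)

def f(z, t):
    return A * z

def df_dz(z, t):
    # In 1-D the trace of the Jacobian is the ordinary derivative.
    return A

def integrate_flow(z0, t0, t1, steps=2000):
    """Euler-integrate the state z and log p(z(t)) along the flow."""
    dt = (t1 - t0) / steps
    z, t = z0, t0
    log_p = log_standard_normal(z0)
    for _ in range(steps):
        log_p += -df_dz(z, t) * dt  # d log p / dt = -Tr(df/dz)
        z += f(z, t) * dt
        t += dt
    return z, log_p

z0, T = 0.8, 1.0
z_T, log_p_T = integrate_flow(z0, 0.0, T)
print(z_T, log_p_T)
```

Here \(z(T) = z_0 e^{AT}\), so \(z(T) \sim \mathcal N(0, e^{2AT})\) and the integrated `log_p_T` should agree with that Gaussian log-density up to discretization error.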
  • Continuous-time Normalizing Flows
    • Loss \[L(z(t)) = -\log p(z(t)) = -\log p(z(0)) + \int_{0}^{t} \text{Tr}\left(\frac{\partial f}{\partial z(s)}\right) ds\] \[\underbrace{\begin{bmatrix} z_0 \\ \log p(\mathbf{x}) - \log p_{z_0}(\mathbf{z}_0) \end{bmatrix}}_{\text{solutions}} = \int_{t_1}^{t_0} \underbrace{\begin{bmatrix} f(\mathbf{z}(t), t; \theta) \\ -\text{Tr}\left(\frac{\partial f}{\partial\mathbf{z}(t)}\right) \end{bmatrix}}_{\text{dynamics}} dt, \quad \underbrace{\begin{bmatrix} \mathbf{z}(t_1) \\ \log p(\mathbf{x}) - \log p(\mathbf{z}(t_1)) \end{bmatrix} = \begin{bmatrix} \mathbf{x} \\ 0 \end{bmatrix}}_{\text{initial values}}\]
    • Dynamics of adjoint \(\mathbf{a}(t) = \partial L / \partial\mathbf{z}(t)\) \[\frac{d\mathbf{a}(t)}{dt} = -\mathbf{a}(t)^T \frac{\partial f(\mathbf{z}(t), t, \theta)}{\partial\mathbf{z}}\] \[\frac{dL}{d\theta} = -\int_{t_1}^{t_0} \left(\frac{\partial L}{\partial\mathbf{z}(t)}\right)^T \frac{\partial f(\mathbf{z}(t), t; \theta)}{\partial\theta} dt\]
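
The adjoint equations can be checked on a scalar example. The sketch assumes dynamics \(f(z, \theta) = \theta z\) and loss \(L = z(t_1)\), for which \(dL/d\theta = (t_1 - t_0)\, z_0\, e^{\theta (t_1 - t_0)}\) is known exactly; the Euler solver and all names are illustrative.

```python
import math

def f(z, theta):
    return theta * z

def df_dz(z, theta):
    return theta

def df_dtheta(z, theta):
    return z

def solve_forward(z0, theta, t0, t1, steps):
    dt = (t1 - t0) / steps
    z = z0
    for _ in range(steps):
        z += f(z, theta) * dt
    return z

def adjoint_gradient(z1, theta, t0, t1, steps):
    """Integrate [z, a, dL/dtheta] backward from t1 to t0 for L = z(t1)."""
    dt = (t0 - t1) / steps      # negative step: time runs backward
    z, a, grad = z1, 1.0, 0.0   # a(t1) = dL/dz(t1) = 1
    for _ in range(steps):
        dz = f(z, theta) * dt
        da = -a * df_dz(z, theta) * dt         # da/dt = -a * df/dz
        dgrad = -a * df_dtheta(z, theta) * dt  # accumulates -int a * df/dtheta dt
        z, a, grad = z + dz, a + da, grad + dgrad
    return grad

theta, z0, t0, t1, steps = 0.7, 1.5, 0.0, 1.0, 10000
z1 = solve_forward(z0, theta, t0, t1, steps)
grad = adjoint_gradient(z1, theta, t0, t1, steps)
exact = (t1 - t0) * z0 * math.exp(theta * (t1 - t0))
print(grad, exact)
```

The backward pass reconstructs \(z(t)\) by solving the ODE in reverse, so no intermediate activations from the forward pass are stored.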

Variational Autoencoders

  • Gaussian Mixture Model
    • \(z\): latent variable
    • If \(z\) can only take a finite number of values, e.g. \[p(z = i) = \pi_i, \quad i = 1, 2, \dots, N.\]
    • and \(p(x \mid z)\) is Gaussian, then \[p(x) = \sum_{i=1}^N \pi_i \mathcal{N}(x \mid \mu_i, \Sigma_i)\]
    • Gaussian Mixture Model (Generalized K-means with soft assignments)
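  • Recap: K-means
    • K-means solves the following problem by alternating minimization over \(\gamma\) and \(c\), where \(S\) is the set of one-hot assignments (\(\gamma_{ij} \in \{0, 1\}\), \(\sum_j \gamma_{ij} = 1\)) \[\min_{\gamma_{ij} \in S, c_j} \frac{1}{2N} \sum_{ij} \gamma_{ij} \|x_i - c_j\|_2^2\]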
    • K-means Clustering Algorithm
      • Input: \(\mathcal{D} = \{x_i\}_{i=1}^N\), \(x_i \in \mathbb{R}^d\), K (number of clusters)
      • Initialize: \(C = (c_1, c_2, \dots, c_K)\).
      • Iterate until convergence:
        • Update \(\gamma\): \[\gamma_{ij} = \begin{cases} 1 & \text{if } j = \text{argmin}_k \|x_i - c_k\|_2^2 \\ 0 & \text{else} \end{cases}\]
        • Update \(C\): \[c_j = \left(\sum_i \gamma_{ij} x_i\right) \Big/ \sum_i \gamma_{ij}\]
      • Output: Cluster centers C and cluster assignment \(\gamma\).
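
The K-means algorithm above can be sketched directly from the two update rules. The stopping rule (centers unchanged), the handling of empty clusters, and the toy data are assumptions added for the example.

```python
def squared_distance(p, q):
    return sum((a - b) ** 2 for a, b in zip(p, q))

def kmeans(points, centers, max_iters=100):
    centers = [tuple(c) for c in centers]
    assignment = []
    for _ in range(max_iters):
        # Update gamma: hard-assign each point to its nearest center.
        assignment = [
            min(range(len(centers)), key=lambda j: squared_distance(x, centers[j]))
            for x in points
        ]
        # Update C: each center becomes the mean of its assigned points.
        new_centers = []
        for j, old in enumerate(centers):
            members = [x for x, a in zip(points, assignment) if a == j]
            if members:
                new_centers.append(tuple(
                    sum(m[d] for m in members) / len(members)
                    for d in range(len(old))
                ))
            else:
                new_centers.append(old)  # keep the center of an empty cluster
        if new_centers == centers:       # converged: nothing moved
            break
        centers = new_centers
    return centers, assignment

points = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0),
          (10.0, 10.0), (10.0, 11.0), (11.0, 10.0)]
centers, assignment = kmeans(points, centers=[(0.0, 0.0), (10.0, 10.0)])
print(centers, assignment)
```

Each of the two updates can only lower the objective, which is why the alternation converges, though possibly to a local minimum that depends on the initial centers.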
  • Gaussian Mixture Model
    • GMM optimization
    • update assignment \(\gamma_{ij} = p(z_i = j \mid x_i)\) (posterior) \[\gamma_{ij} := p(z_i = j \mid x_i) = \frac{p(z_i = j) p(x_i \mid z_i = j)}{p(x_i)} = \frac{p(z_i = j) p(x_i \mid z_i = j)}{\sum_{j'} p(z_i = j') p(x_i \mid z_i = j')} = \frac{\pi_j^{(k)} \mathcal{N}(x_i \mid \mu_j^{(k)}, \Sigma_j^{(k)})}{\sum_{j'} \pi_{j'}^{(k)} \mathcal{N}(x_i \mid \mu_{j'}^{(k)}, \Sigma_{j'}^{(k)})}\]
    • update center \((\mu, \Sigma)\) and \(\pi_j = p(z = j)\) (prior) \[\pi_j = p(z = j) = \int p(z = j \mid x) p(x) dx \approx \frac{\sum_{i=1}^N p(z = j \mid x_i)}{N} = \frac{\sum_i \gamma_{ij}}{N}\]
  • Expectation Maximization (EM) Algorithm (Variational Inference) \[\sum_{i} \log \left( \sum_{j} \gamma_{ij} \pi_j \mathcal{N}(x_i \mid \mu_j, \Sigma_j) \big/ \gamma_{ij} \right) \ge \sum_{i} \left( \sum_{j} \gamma_{ij} \log \frac{\pi_j \mathcal{N}(x_i \mid \mu_j, \Sigma_j)}{\gamma_{ij}} \right) \quad \text{Jensen's Inequality}\]
  • Expectation Step: \[\gamma_{ij}^{(k+1)} = \arg\max_{\{\gamma_{ij} : \sum_j \gamma_{ij} = 1, \gamma_{ij} \in [0, 1]\}} \sum_{i} \sum_{j} \left( \gamma_{ij} \log \left[ \pi_j^{(k)} \mathcal{N}\left(x_i \mid \mu_j^{(k)}, \Sigma_j^{(k)}\right) \right] - \gamma_{ij} \log \gamma_{ij} \right)\]
  • Maximization Step: \[\left(\pi_j^{(k+1)}, \mu_j^{(k+1)}, \Sigma_j^{(k+1)}\right) = \arg\max_{\theta} \sum_{i} \sum_{j} \gamma_{ij}^{(k+1)} \log \pi_j \mathcal{N}(x_i \mid \mu_j, \Sigma_j)\]
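
The E and M steps can be sketched for a one-dimensional mixture, where the M-step has the usual weighted-mean and weighted-variance closed forms. The data, the initial values, and the small variance floor are assumptions added for the example.

```python
import math

def normal_pdf(x, mu, var):
    return math.exp(-0.5 * (x - mu) ** 2 / var) / math.sqrt(2.0 * math.pi * var)

def log_likelihood(data, pis, mus, variances):
    return sum(
        math.log(sum(p * normal_pdf(x, m, v) for p, m, v in zip(pis, mus, variances)))
        for x in data
    )

def em_step(data, pis, mus, variances):
    K, N = len(pis), len(data)
    # E-step: gamma_ij = posterior p(z_i = j | x_i) under current parameters.
    gamma = []
    for x in data:
        weights = [pis[j] * normal_pdf(x, mus[j], variances[j]) for j in range(K)]
        total = sum(weights)
        gamma.append([w / total for w in weights])
    # M-step: weighted maximum-likelihood updates.
    new_pis, new_mus, new_vars = [], [], []
    for j in range(K):
        n_j = sum(g[j] for g in gamma)
        mu_j = sum(g[j] * x for g, x in zip(gamma, data)) / n_j
        var_j = sum(g[j] * (x - mu_j) ** 2 for g, x in zip(gamma, data)) / n_j
        new_pis.append(n_j / N)
        new_mus.append(mu_j)
        new_vars.append(max(var_j, 1e-6))  # variance floor: an added safeguard
    return new_pis, new_mus, new_vars

data = [-2.1, -1.9, -2.4, -1.6, -2.0, 2.2, 1.8, 2.5, 1.7, 2.0, 1.9]
params = ([0.5, 0.5], [-1.0, 1.0], [1.0, 1.0])
history = [log_likelihood(data, *params)]
for _ in range(30):
    params = em_step(data, *params)
    history.append(log_likelihood(data, *params))
pis, mus, variances = params
print(mus, history[0], history[-1])
```

The recorded `history` should be non-decreasing, since each step maximizes a lower bound that is tight at the current parameters.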
  • Jensen’s Inequality \[\log \int p(z)f(z)dz \ge \int p(z)\log f(z)dz\]

  • Evidence Lower Bound (ELBO) \[\log p(x) = \log \int p(z)p(x \mid z)dz = \log \int q_x(z) \frac{p(z)p(x \mid z)}{q_x(z)} dz\] \[\ge \int q_x(z) \log \frac{p(z)p(x \mid z)}{q_x(z)} dz\] \[= \int q_x(z) \log p(x \mid z)dz - KL(q_x(z) \parallel p(z))\]

  • Gap between \(\log p(x)\) and ELBO \[\log p(x) - \left( \int q_x(z)\log p(x \mid z)dz - KL(q_x(z) \parallel p(z)) \right)\] \[= \log p(x) - \int q_x(z)\log p(x \mid z)dz + KL(q_x(z) \parallel p(z))\] \[= \int q_x(z)\log p(x)dz - \int q_x(z)\log p(x \mid z)dz + KL(q_x(z) \parallel p(z))\] \[= \int q_x(z) \left[ \log p(x) - \log p(x \mid z) - \log p(z) \right] dz + \int q_x(z)\log q_x(z)dz\] \[= -\int q_x(z) \log \frac{p(x \mid z)p(z)}{p(x)} dz + \int q_x(z)\log q_x(z)dz\] \[= -\int q_x(z) \log p(z \mid x)dz + \int q_x(z)\log q_x(z)dz = KL(q_x(z) \parallel p(z \mid x))\]

  • When \(q_x(z) = p(z \mid x)\), the gap is closed.

  • (In GMM, \(q_{x_i}(z = j) = \gamma_{ij} = p(z = j \mid x_i)\))
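
The identity \(\log p(x) - \text{ELBO} = KL(q_x(z) \parallel p(z \mid x))\) can be verified exactly when \(z\) is discrete. The probabilities below are arbitrary illustrative numbers for a single fixed observation \(x\).

```python
import math

prior = [0.5, 0.3, 0.2]          # p(z)
likelihood = [0.10, 0.40, 0.70]  # p(x | z) for one fixed x
q = [0.2, 0.5, 0.3]              # an arbitrary variational distribution q_x(z)

evidence = sum(p * l for p, l in zip(prior, likelihood))           # p(x)
posterior = [p * l / evidence for p, l in zip(prior, likelihood)]  # p(z | x)

def kl(a, b):
    return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)

def elbo(qz):
    reconstruction = sum(qi * math.log(li) for qi, li in zip(qz, likelihood))
    return reconstruction - kl(qz, prior)

gap = math.log(evidence) - elbo(q)
print(gap, kl(q, posterior))                  # the two values should coincide
print(math.log(evidence) - elbo(posterior))   # gap closes at q = p(z | x)
```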

  • EM Algorithm: Alternately Maximize ELBO

    \[ \max_{\phi,\psi} \int q_{\phi(x)}(z)\log p_\psi(x \mid z)dz - KL(q_{\phi(x)}(z) \parallel p(z)) \]

    • Expectation: \[ \max_\phi ELBO \Rightarrow q_{\phi(x)}(z) = p(z \mid x) \]

    However, \(p(z \mid x)\) is intractable in general, so we instead choose a parameterized family \(q_{\phi(x)}(z)\) that is easy to sample from and evaluate, and use it to approximate \(p(z \mid x)\).

    • Maximization: \[ \max_\psi ELBO \]
  • For Gaussian Mixture Models, the negative log-likelihood (NLL) is non-convex, and directly minimizing it can lead to poor performance due to local optima. In contrast, each EM iteration is guaranteed not to increase the NLL, although EM may still converge to a local optimum.

  • Parametrize \[ q_x(z)=\mathcal N(\mu_\theta(x), \text{diag}(\sigma_\theta^2(x))) \] and \[ p_\psi(x \mid z)=\mathcal N(\psi(z), \beta I) \]

  • ELBO: \[ ELBO = \int q_x(z)\log p(x \mid z)dz - KL(q_x(z)\parallel p(z)) \]

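    \[ = -\frac{1}{2\beta} \mathbb E_{q_x(z)} \left\| \psi(z)-x \right\|_2^2 - KL(q_x(z)\parallel p(z)) + c \]

    \[ = -\frac{1}{2\beta} \mathbb E_{\epsilon\sim\mathcal N(0,I)} \left\| \psi(z)-x \right\|_2^2 - KL(q_x(z)\parallel p(z)) + c \]

    where \[ z=\mu_\theta(x)+\sigma_\theta(x)\odot\epsilon \] (the reparameterization trick) and \(c\) is the Gaussian normalizing constant of \(p_\psi(x \mid z)\), which does not depend on \(\theta\) or \(\psi\).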
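
The reparameterized expectation can be estimated by sampling \(\epsilon\) and pushing it through \(z = \mu + \sigma\epsilon\). The sketch assumes a scalar latent and a hypothetical linear decoder \(\psi(z) = wz + b\), chosen because the expectation then has a closed form to compare against.

```python
import random

W_DEC, B_DEC = 0.5, 0.1  # hypothetical decoder parameters

def decoder(z):
    # Stand-in for psi(z); a real VAE would use a neural network here.
    return W_DEC * z + B_DEC

def reconstruction_term(x, mu, sigma, beta, num_samples, rng):
    """Monte Carlo estimate of -(1 / (2 beta)) E_eps ||psi(mu + sigma eps) - x||^2."""
    total = 0.0
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)  # noise is independent of mu and sigma
        z = mu + sigma * eps       # reparameterized sample from q_x(z)
        total += (decoder(z) - x) ** 2
    return -total / (2.0 * beta * num_samples)

x, mu, sigma, beta = 1.0, 0.8, 1.2, 0.5
rng = random.Random(0)
estimate = reconstruction_term(x, mu, sigma, beta, 100000, rng)

# Closed form for the linear decoder: E(w z + b - x)^2 = (w mu + b - x)^2 + (w sigma)^2.
exact = -((W_DEC * mu + B_DEC - x) ** 2 + (W_DEC * sigma) ** 2) / (2.0 * beta)
print(estimate, exact)
```

Because the randomness enters only through \(\epsilon\), the sample \(z\) is a differentiable function of \(\mu\) and \(\sigma\), which is what lets gradients flow to the encoder in an autodiff framework.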

  • If \[ p(z)=\mathcal N(0,I), \]

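    then \[ KL(q_x(z)\parallel p(z)) = \frac{1}{2}\|\mu_\theta(x)\|_2^2 + \frac{1}{2}\|\sigma_\theta(x)\|_2^2 - \sum_{i} \log \sigma_{\theta,i}(x) + c, \qquad c = -\frac{d}{2}, \] where \(d\) is the dimension of \(z\) and \(\sigma_\theta(x)\) is the vector of standard deviations.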

  • Gaussian KL formula:

    Let \[ p(x)=\mathcal N(\mu_1,\sigma_1^2) \] and \[ q(x)=\mathcal N(\mu_2,\sigma_2^2). \]

    \[ KL(p \parallel q) = \log \frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2} \]
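
The Gaussian KL formula can be checked against numerical integration of \(\int p \log(p/q)\). In this sketch \(\sigma_1\) and \(\sigma_2\) are standard deviations, and the integration range and test values are arbitrary choices.

```python
import math

def gaussian_kl(mu1, sigma1, mu2, sigma2):
    """Closed-form KL between two scalar Gaussians (sigma = std. deviation)."""
    return (math.log(sigma2 / sigma1)
            + (sigma1 ** 2 + (mu1 - mu2) ** 2) / (2.0 * sigma2 ** 2) - 0.5)

def log_normal_pdf(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)

def gaussian_kl_numeric(mu1, sigma1, mu2, sigma2, n=100000):
    # Midpoint rule for the integral of p(x) * (log p(x) - log q(x)).
    lo, hi = mu1 - 10.0 * sigma1, mu1 + 10.0 * sigma1
    h = (hi - lo) / n
    total = 0.0
    for i in range(n):
        x = lo + (i + 0.5) * h
        log_p = log_normal_pdf(x, mu1, sigma1)
        log_q = log_normal_pdf(x, mu2, sigma2)
        total += math.exp(log_p) * (log_p - log_q) * h
    return total

closed = gaussian_kl(0.5, 0.8, -1.0, 1.5)
numeric = gaussian_kl_numeric(0.5, 0.8, -1.0, 1.5)
print(closed, numeric)

# Special case used for the VAE prior N(0, 1):
mu, sigma = 0.5, 0.8
vae_kl = 0.5 * mu ** 2 + 0.5 * sigma ** 2 - math.log(sigma) - 0.5
print(vae_kl, gaussian_kl(mu, sigma, 0.0, 1.0))
```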

Hierarchical latent space decomposition

  • Ordinary VAE uses one latent variable: \[ z \sim p(z), \qquad x \sim p(x \mid z) \]

  • Hierarchical PixelVAE uses multiple latent variables: \[ z_L \rightarrow z_{L-1} \rightarrow \cdots \rightarrow z_1 \rightarrow x \]

  • Generative model: \[ p(x,z_1,\dots,z_L) = p(z_L) \prod_{l=1}^{L-1}p(z_l\mid z_{l+1}) p(x\mid z_1) \]

  • Encoder / inference model: \[ x \rightarrow h_1 \rightarrow h_2 \rightarrow \cdots \rightarrow h_L \]

    The encoder produces approximate posteriors: \[ q(z_l\mid x) \]

  • Reconstruction term: \[ -\mathbb E_{z_1\sim q(z_1\mid x)} \log p(x\mid z_1) \]

  • KL terms at different latent levels (each conditional prior depends on \(z_{l+1}\), which is sampled from the encoder): \[ D_{KL}(q(z_L\mid x)\parallel p(z_L)) \] \[ \mathbb E_{z_L\sim q(z_L\mid x)} D_{KL}(q(z_{L-1}\mid x)\parallel p(z_{L-1}\mid z_L)) \] \[ \cdots \] \[ \mathbb E_{z_2\sim q(z_2\mid x)} D_{KL}(q(z_1\mid x)\parallel p(z_1\mid z_2)) \]

  • Training loss: \[ \mathcal L = -\mathbb E_q[\log p(x\mid z_1)] + D_{KL}(q(z_L\mid x)\parallel p(z_L)) + \sum_{l=1}^{L-1} \mathbb E_{z_{l+1}\sim q(z_{l+1}\mid x)} D_{KL}(q(z_l\mid x)\parallel p(z_l\mid z_{l+1})) \]

  • ELBO form: \[ ELBO = \mathbb E_q[\log p(x\mid z_1)] - D_{KL}(q(z_L\mid x)\parallel p(z_L)) - \sum_{l=1}^{L-1} \mathbb E_{z_{l+1}\sim q(z_{l+1}\mid x)} D_{KL}(q(z_l\mid x)\parallel p(z_l\mid z_{l+1})) \]

  • Intuition: high-level latent variables model global structure, while low-level latent variables model local details.

  • PixelVAE combines hierarchical latent variables with an autoregressive PixelCNN-style decoder.

Hierarchical Architecture

  • Hierarchical latent variables: \[ z=(z_1,z_2,\dots,z_L) \]

  • Exact chain-rule decomposition (as in the earlier sections, \(p\) is the prior and \(q\) the approximate posterior): \[ p(z)=\prod_{l=1}^{L} p(z_l\mid z_{<l}) \] \[ q(z\mid x)=\prod_{l=1}^{L} q(z_l\mid z_{<l},x) \]

    This is an exact decomposition and does not introduce additional assumptions.

  • Conditional prior at layer \(l\): \[ p(z_l\mid z_{<l}) = \mathcal N \left( z_l; \mu(z_{<l}), \sigma^2(z_{<l}) \right) \]

  • Approximate posterior at layer \(l\): \[ q(z_l\mid z_{<l},x) = \mathcal N \left( z_l; \mu(z_{<l})+\Delta\mu(z_{<l},x), \sigma^2(z_{<l})\odot\Delta\sigma^2(z_{<l},x) \right) \]

  • The posterior is parameterized as a residual correction to the prior: \[ \mu_{\text{post}} = \mu_{\text{prior}} + \Delta \mu \] \[ \sigma^2_{\text{post}} = \sigma^2_{\text{prior}} \odot \Delta\sigma^2 \]

  • Layer-wise KL decomposition: \[ KL(q(z\mid x)\parallel p(z)) = KL(q(z_1\mid x)\parallel p(z_1)) + \sum_{l=2}^{L} \mathbb E_{q(z_{<l}\mid x)} \left[ KL( q(z_l\mid z_{<l},x) \parallel p(z_l\mid z_{<l}) ) \right] \]

  • For diagonal Gaussian distributions, each layer has closed-form KL, where \(\sigma_i\) is the prior standard deviation and \(\Delta\sigma_i^2\) the relative variance: \[ KL( q(z_l\mid z_{<l},x) \parallel p(z_l\mid z_{<l}) ) = \frac{1}{2} \sum_{i=1}^{|z_l|} \left( \frac{\Delta\mu_i^2}{\sigma_i^2} + \Delta\sigma_i^2 - \log \Delta\sigma_i^2 - 1 \right) \]
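
The closed-form layer KL can be checked against the general Gaussian KL formula applied dimension by dimension. In this sketch `delta_sigma` holds the relative standard deviations (so the posterior standard deviation is `sigma_prior * delta_sigma`), and all numbers are illustrative.

```python
import math

def gaussian_kl(mu1, sigma1, mu2, sigma2):
    """KL(N(mu1, sigma1^2) || N(mu2, sigma2^2)) for scalars."""
    return (math.log(sigma2 / sigma1)
            + (sigma1 ** 2 + (mu1 - mu2) ** 2) / (2.0 * sigma2 ** 2) - 0.5)

def residual_kl(delta_mu, sigma_prior, delta_sigma):
    """Layer KL written in terms of the residual parameters only."""
    return 0.5 * sum(
        dm ** 2 / s ** 2 + ds ** 2 - math.log(ds ** 2) - 1.0
        for dm, s, ds in zip(delta_mu, sigma_prior, delta_sigma)
    )

mu_prior = [0.3, -1.2, 2.0]
sigma_prior = [0.5, 1.5, 1.0]
delta_mu = [0.1, -0.4, 0.0]
delta_sigma = [0.8, 1.3, 1.0]  # last dimension: posterior equals prior

kl_residual = residual_kl(delta_mu, sigma_prior, delta_sigma)
kl_direct = sum(
    gaussian_kl(m + dm, s * ds, m, s)  # posterior first, prior second
    for m, s, dm, ds in zip(mu_prior, sigma_prior, delta_mu, delta_sigma)
)
print(kl_residual, kl_direct)
```

The prior mean cancels out of the residual form, and a dimension with \(\Delta\mu = 0\) and \(\Delta\sigma = 1\) contributes exactly zero KL.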