The Long Walk Home: Statistical Distances, Part 2

To recap, in the setting of implicit generative modeling (IGM), we wish to train a model to produce data that appears to be drawn from a target distribution $p$, which might be images of cats or whatever else we’re interested in. The distribution $p$ is unknown, and we only have indirect access to it through finite data samples. The model that we are training produces data with distribution $q_t$ at time $t$, also unknown, which early on will not look anything like the target data. Our goal is to move $q_t$ closer and closer to $p$ until the distributions overlap and the data they generate is indistinguishable in terms of its content and style. We don’t want to simply reproduce the training data, though. So if the target data is cat pictures, we want our model to start churning out new cat pictures that match the overall look of our training data.

In Part 1, I introduced the concept of a statistical distance $d(p, q_t)$ to measure the separation of two distributions, which tells us how far we need to go to achieve our goal. Our training objective, then, is for $d(p, q_{t'}) \leq d(p, q_t)$ when $t' > t$. That is, we want to get progressively closer to our target as time goes on. I provided examples of statistical distances based on the moment generating function, $M_p(\mb{t})$, and the characteristic function, $\varphi_p(\mb{t})$. These functions are equal at every argument $\mb{t} \in \mathbb{R}^d$ for two distributions $p$ and $q_t$ if and only if $p = q_t$ (provided, for the moment generating function, that it exists in a neighborhood of the origin). They therefore give us a straightforward way to measure how far apart two distributions are: simply compare their values over a range of arguments $\mb{t}$, and add up or average the amount by which they differ. This quantity should get closer to zero as our model gets better at capturing the essence of distribution $p$.
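To make the recap concrete, here is a minimal one-dimensional sketch of a characteristic-function distance in Python. It assumes empirical characteristic functions computed from samples, a hypothetical grid of arguments $t \in [-3, 3]$, and an unweighted average of squared gaps; Part 1 may weight or choose the arguments differently.

```python
import cmath
import random

def empirical_cf(samples, t):
    """Empirical characteristic function: the sample average of exp(i * t * x)."""
    return sum(cmath.exp(1j * t * x) for x in samples) / len(samples)

def cf_distance(xs_p, xs_q, ts):
    """Average of |phi_p(t) - phi_q(t)|^2 over a grid of arguments t (1-D case)."""
    gaps = [abs(empirical_cf(xs_p, t) - empirical_cf(xs_q, t)) ** 2 for t in ts]
    return sum(gaps) / len(gaps)

rng = random.Random(0)
xs_p = [rng.gauss(0.0, 1.0) for _ in range(2000)]      # stand-in for target samples
xs_same = [rng.gauss(0.0, 1.0) for _ in range(2000)]   # model samples when q_t = p
xs_shift = [rng.gauss(2.0, 1.0) for _ in range(2000)]  # model samples when q_t is shifted
ts = [-3.0 + 0.25 * k for k in range(25)]              # hypothetical grid, t in [-3, 3]

d_same = cf_distance(xs_p, xs_same, ts)
d_shift = cf_distance(xs_p, xs_shift, ts)
print(d_same, d_shift)
```

The first value reflects only sampling noise, while the second reflects a genuine difference between the distributions, so it should be much larger.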

Classifier-Based Tests and GANs

The distances I have described so far are just two of many possibilities, and they’re not even the most commonly used for the IGM problem. Besides computing a statistical distance directly, we can also test whether two distributions are equal by training a classifier on examples from them. That is, we train the classifier to identify when data has come from distribution $p$ and when it has come from distribution $q_t$. Intuitively, the more dissimilar the distributions are, the easier the classifier’s job will be and the better it will perform. If the distributions are identical, however, the classifier should fail to distinguish them in any meaningful way, even if it overfits slightly to its training data. After all, how could the classifier succeed if the distributions are the same?

An idea due to David Lopez-Paz and Maxime Oquab, the classifier two-sample test, is to pool samples from both distributions, label each by its source, and separate the labeled data into training and testing buckets. A classifier $f_{\psi}: \mathbb{R}^d \to (0,1)$ is then trained in the usual way on the training bucket, such that $$f_{\psi}(\mb{x}) \approx \begin{cases} 1 & \mb{x} \sim p \\ 0 & \mb{x} \sim q_t \end{cases}.$$ We then apply the trained classifier to the $M$ test points $\mb{x}_1, \ldots, \mb{x}_M$ and compute the statistic $$T = \frac{1}{M} \sum_{i=1}^M \mathbb{I} \left[ \mathbb{I} \left( f_{\psi}(\mb{x}_i) > \frac{1}{2} \right) = \mathbb{I} \left( \mb{x}_i \sim p \right) \right],$$ where $\mathbb{I}$ is an indicator function that returns 1 if the condition in its argument is true and 0 otherwise.

Translating the above equation into words, we’re measuring how often the classifier gets it right. First, inside the brackets, we check whether we have classified a test point $\mb{x}$ as coming from $p$, which we do if $f_{\psi}(\mb{x}) > \frac{1}{2}$ (true or false; 1 or 0). Then we check whether it really did come from $p$ (true or false; 1 or 0). Finally, we count all the times these truth values agree ($1 = 1$ or $0 = 0$) and average over the test set. This is the classification accuracy. Under the null hypothesis, $p = q_t$, the classifier will have chance performance, meaning that its accuracy will hover around 50 percent, assuming that we have an equal number of samples from $p$ and $q_t$. How far it strays from 50 percent is described by the sampling distribution under the null hypothesis: the number of correct test classifications, $MT$, follows a binomial distribution, $\mbox{BIN}(M, \frac{1}{2})$, so for large $M$ the statistic $T$ is approximately $\mathcal{N}(\frac{1}{2}, \frac{1}{4M})$.
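The test can be sketched in Python as a toy, one-dimensional version under stated assumptions. The hypothetical classifier simply thresholds at the midpoint between the two class means estimated on the training bucket, a stand-in for a trained $f_{\psi}$ whose decision $f_{\psi}(x) > \frac{1}{2}$ it mimics, and the p-value uses the normal approximation to the null distribution rather than the exact binomial.

```python
import math
import random

def c2st(xs_p, xs_q, train_frac=0.5, seed=0):
    """Toy classifier two-sample test on 1-D samples; returns (T, M)."""
    rng = random.Random(seed)
    data = [(x, 1) for x in xs_p] + [(x, 0) for x in xs_q]  # label 1 means x ~ p
    rng.shuffle(data)
    n_train = int(train_frac * len(data))
    train, test = data[:n_train], data[n_train:]

    # Hypothetical "training": estimate each class mean on the training bucket only.
    p_train = [x for x, y in train if y == 1]
    q_train = [x for x, y in train if y == 0]
    mean_p, mean_q = sum(p_train) / len(p_train), sum(q_train) / len(q_train)
    midpoint = 0.5 * (mean_p + mean_q)
    direction = 1.0 if mean_p >= mean_q else -1.0

    def predicts_p(x):
        # Same decision as f(x) > 1/2 for f(x) = sigmoid(direction * (x - midpoint)).
        return direction * (x - midpoint) > 0

    # T = (1/M) * sum_i I[ I(f(x_i) > 1/2) == I(x_i ~ p) ] on the held-out bucket.
    M = len(test)
    T = sum(1 for x, y in test if predicts_p(x) == (y == 1)) / M
    return T, M

def p_value(T, M):
    """One-sided p-value from the null approximation T ~ N(1/2, 1/(4M))."""
    z = (T - 0.5) / math.sqrt(1.0 / (4.0 * M))
    return 0.5 * math.erfc(z / math.sqrt(2.0))

rng = random.Random(1)
xs_p = [rng.gauss(0.0, 1.0) for _ in range(2000)]
xs_same = [rng.gauss(0.0, 1.0) for _ in range(2000)]   # q_t = p
xs_shift = [rng.gauss(3.0, 1.0) for _ in range(2000)]  # q_t far from p

T_same, M = c2st(xs_p, xs_same)
T_shift, _ = c2st(xs_p, xs_shift)
print(T_same, p_value(T_same, M))
print(T_shift, p_value(T_shift, M))
```

With these settings, `T_same` should hover near 1/2 with an unremarkable p-value, while `T_shift` should sit well above 1/2 with a p-value that is essentially zero. A real test would use a flexible classifier, such as a neural network, in place of the midpoint rule.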

In the IGM setting, though, classifiers are perhaps best known for their role in generative adversarial networks (GANs), where they are more commonly referred to as discriminators. Here they provide a training signal to a generator, $g_{\theta}$, under the so-called non-saturating GAN loss, $$\mathcal{L}(\theta) = -\log f_{\psi}(g_{\theta}(\mb{z})),$$ whose minimization amounts to maximizing $f_{\psi}(g_{\theta}(\mb{z}))$ for generated data $g_{\theta}(\mb{z}) \sim q_t$, where $\mb{z}$ is a random input vector drawn from a convenient, easy-to-sample prior distribution, often the standard Gaussian. This is consistent with the common lore about GANs, namely that the generator is trained to “fool” the discriminator into classifying generated (fake) images as real.
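A minimal sketch of the non-saturating loss, assuming a hypothetical one-dimensional generator $g_{\theta}(z) = \theta + z$ and a hypothetical fixed discriminator whose logit is $2(x - 1)$, so that it considers larger $x$ more “real.” As is usual in practice, the loss is averaged over a batch of draws $z \sim \mathcal{N}(0, 1)$.

```python
import math
import random

def neg_log_sigmoid(u):
    """Numerically stable -log(sigmoid(u)) = log(1 + exp(-u))."""
    return math.log1p(math.exp(-u)) if u >= 0 else -u + math.log1p(math.exp(u))

def discriminator_logit(x):
    # Hypothetical fixed discriminator: it believes real data sits at larger x.
    return 2.0 * (x - 1.0)

def generator(theta, z):
    # Hypothetical toy generator g_theta(z) = theta + z (a shifted copy of the prior).
    return theta + z

def non_saturating_loss(theta, n=5000, seed=0):
    """Monte Carlo estimate of E_z[-log f_psi(g_theta(z))] with z ~ N(0, 1)."""
    rng = random.Random(seed)
    zs = [rng.gauss(0.0, 1.0) for _ in range(n)]
    return sum(neg_log_sigmoid(discriminator_logit(generator(theta, z))) for z in zs) / n

losses = {theta: non_saturating_loss(theta) for theta in (-1.0, 0.0, 1.0, 2.0)}
print(losses)
```

Because the same draws of $z$ are reused for every $\theta$, the estimated loss falls steadily as the generator shifts its output toward the region the discriminator calls real, which is the “fooling” behavior described above.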

Under certain conditions, optimizing the original minimax GAN objective, of which the non-saturating loss is a popular variant, can be shown to minimize the Jensen-Shannon divergence (JSD), another type of statistical distance, which is defined by $$\mbox{JSD}(p \| q_t) = \frac{1}{2} D(p \| m) + \frac{1}{2} D(q_t \| m),$$ where $$D(p \| q) = \E_p \left[ \log \frac{p(\mb{x})}{q(\mb{x})} \right]$$ is the Kullback-Leibler (KL) divergence, and $m = \frac{1}{2}(p + q_t)$ is a 50-50 mixture of the distributions $p$ and $q_t$. While the KL divergence equals zero if and only if $p = q_t$, it is not symmetric, so it is not a true distance (a distance must satisfy $d(p, q) = d(q, p)$, among other conditions). The JSD, however, is symmetric, which makes it behave as a distance.1 Naturally, for distance-like measures $d(p, q_t)$ that equal zero if and only if $p = q_t$, any method that makes one of them vanish will make all of them vanish. In other words, we need not get too hung up on the Jensen-Shannon formulation, especially since it relies on assumptions that do not typically hold in practice.
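For discrete distributions, the KL divergence and the JSD can be computed directly from their definitions. The sketch below uses hypothetical probability vectors to show the KL divergence’s asymmetry and the JSD’s symmetry; it assumes $q(x) > 0$ wherever $p(x) > 0$ so that the KL divergence is finite.

```python
import math

def kl(p, q):
    """KL divergence D(p || q) = sum_x p(x) log(p(x) / q(x)) in nats (needs q > 0 where p > 0)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def jsd(p, q):
    """Jensen-Shannon divergence: average KL divergence of p and q to their 50-50 mixture m."""
    m = [0.5 * (pi + qi) for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

p = [0.7, 0.2, 0.1]  # hypothetical distributions over three outcomes
q = [0.1, 0.3, 0.6]
print(kl(p, q), kl(q, p))    # KL divergence: the two directions differ
print(jsd(p, q), jsd(q, p))  # JSD: the two directions agree
print(math.sqrt(jsd(p, q)))  # square root of the JSD, the Jensen-Shannon distance
```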

Something to keep in mind is that if our model training works, the data we generate and the real training data will eventually be indistinguishable. That means that running Lopez-Paz and Oquab’s test will reveal chance classification performance, and we will be unable to reject the hypothesis that the generated data distribution equals the target distribution. But it also means that in the process of training a GAN with a discriminator, we ultimately have to destroy the discriminator’s ability to discriminate.

An Alternative View of GANs

Although GAN generator losses are defined with respect to the generator parameters, recent work of mine shows that it may be more informative to split each training step on these losses into two sub-problems that, taken together, are mathematically equivalent to the original step. The first sub-problem adjusts the generator output, which lives in $\mathbb{R}^d$ (“data space”). The second adjusts the generator parameters by regressing this adjusted output on the generator input, $\mb{z}$. In short, GAN generator training can be thought of as a target-generation step followed by a regression step.

It is fairly straightforward to show that this is the case. Let $\tilde{\mb{x}}(\theta) = g_{\theta}(\mb{z})$ be an output of our generator.3 Via the chain rule, we can decompose the gradient of the GAN loss as $$\frac{\partial \mathcal{L}(\theta)}{\partial \theta} = \frac{\partial \mathcal{L}(\theta)}{\partial \tilde{\mb{x}}(\theta)} \frac{\partial \tilde{\mb{x}}(\theta)}{\partial \theta}.$$ From here on, unless we need to remind ourselves of a variable’s dependency on $\theta$, we will suppress this argument to make the notation easier.

Assume that we first adjust $\tilde{\mb{x}}$ by taking a step of size $\lambda_1$ in the direction of the negative gradient, $-\nabla_{\tilde{\mb{x}}} \mathcal{L} = -\left( \frac{\partial \mathcal{L}}{\partial \tilde{\mb{x}}} \right)^\top$, getting $$\tilde{\mb{x}}' = \tilde{\mb{x}} - \lambda_1 \left( \frac{\partial \mathcal{L}}{\partial \tilde{\mb{x}}} \right)^\top.$$ Reminding ourselves that $\tilde{\mb{x}}$ is the output of a function that depends on $\theta$, and treating $\tilde{\mb{x}}'$ as a fixed target that does not vary with $\theta$, we now define the regression loss $$\mathcal{J}(\theta) = \frac{1}{2} \| \tilde{\mb{x}}(\theta) - \tilde{\mb{x}}' \|^2,$$ with derivative $$\frac{\partial \mathcal{J}}{\partial \theta} = (\tilde{\mb{x}}(\theta) - \tilde{\mb{x}}')^\top \frac{\partial \tilde{\mb{x}}(\theta)}{\partial \theta} = \lambda_1 \frac{\partial \mathcal{L}}{\partial \tilde{\mb{x}}(\theta)} \frac{\partial \tilde{\mb{x}}(\theta)}{\partial \theta} = \lambda_1 \frac{\partial \mathcal{L}}{\partial \theta},$$ where the second equality follows from rearranging the update equation as $\tilde{\mb{x}} - \tilde{\mb{x}}' = \lambda_1 \left( \frac{\partial \mathcal{L}}{\partial \tilde{\mb{x}}} \right)^\top$ and transposing, and the third follows from the chain rule. This shows that the sub-problem formulation is equivalent to the standard version when $\lambda_1 = 1$; more generally, a gradient step on $\mathcal{J}$ is a gradient step on $\mathcal{L}$ with its learning rate scaled by $\lambda_1$. Beyond this, the sub-problem formulation admits additional control over the regression problem, such as the introduction of a separate learning rate, $\lambda_2$. This result also holds for any GAN loss, not just the non-saturating loss.
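The equivalence can be checked numerically. The sketch below assumes a hypothetical linear generator with output in $\mathbb{R}^2$ and a hypothetical fixed logistic discriminator, uses central finite differences for every gradient, and compares $\partial \mathcal{J}/\partial \theta$ (with the target $\tilde{\mb{x}}'$ held fixed) against $\lambda_1\, \partial \mathcal{L}/\partial \theta$.

```python
import math

def sigmoid(u):
    return 1.0 / (1.0 + math.exp(-u))

# Hypothetical toy generator: theta = [w1, b1, w2, b2], output x in R^2.
def generator(theta, z):
    return [theta[0] * z + theta[1], theta[2] * z + theta[3]]

# Hypothetical fixed discriminator logit h(x) = a . x + c; non-saturating loss L = -log sigmoid(h).
A, C = [1.5, -0.7], 0.2

def gan_loss(x):
    return -math.log(sigmoid(A[0] * x[0] + A[1] * x[1] + C))

def num_grad(fn, v, eps=1e-6):
    """Central-difference gradient of the scalar function fn at the list v."""
    grad = []
    for i in range(len(v)):
        up, down = list(v), list(v)
        up[i] += eps
        down[i] -= eps
        grad.append((fn(up) - fn(down)) / (2.0 * eps))
    return grad

def direct_grad(theta, z):
    """dL/dtheta for the standard, single-step formulation."""
    return num_grad(lambda th: gan_loss(generator(th, z)), theta)

def two_step_grad(theta, z, lam1):
    """dJ/dtheta for the target-generation + regression formulation."""
    x = generator(theta, z)
    gx = num_grad(gan_loss, x)                          # dL/dx at the current output
    target = [xi - lam1 * gi for xi, gi in zip(x, gx)]  # x' = x - lam1 * grad_x L, held fixed

    def regression_loss(th):
        return 0.5 * sum((xi - ti) ** 2 for xi, ti in zip(generator(th, z), target))

    return num_grad(regression_loss, theta)

theta, z, lam1 = [0.3, -0.5, 1.1, 0.4], 0.8, 0.1
g_direct = direct_grad(theta, z)
g_two_step = two_step_grad(theta, z, lam1)
print([a / b for a, b in zip(g_two_step, g_direct)])  # each ratio should be close to lam1
```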

The main takeaway from the theory developed above is that it is often helpful to think of what a loss function is doing (at least virtually) to the data a model generates in $\mathbb{R}^d$, since the adjusted data serve as new targets for the generator that will push its distribution $q_t$ closer to $p$. That is, during training, the generator generates something, and we use the signal from the discriminator to find something better that it could have generated instead. We then adjust the generator by regressing that “something better” on its original input, $\mb{z}$.

For those familiar with them, we are essentially regressing the generator on adversarial examples until there are no more of them to be found, which is what happens when the discriminator’s discrimination ability is ultimately destroyed. There is no signal left from the discriminator and therefore nowhere else to go in data space to find a better target for the generator.

The Ubiquitous “Score Difference”

This viewpoint of training in “data space” will be helpful when we consider state-of-the-art generative methods such as diffusion models, which operate directly in $\mathbb{R}^d$.2 With this in mind, let us take another look at our discriminator, $f_{\psi}$. Regardless of the architecture used, the last layer of this model will almost always be a logistic sigmoid activation that keeps the output in the range $(0,1)$. The logistic sigmoid is defined by $$f_{\psi}(\mb{x}) = \frac{1}{1 + \exp (-h_{\psi}(\mb{x}))},$$ where $h_{\psi}(\mb{x})$ is the pre-activation output of the model, namely the argument to the logistic sigmoid function. If the discriminator is optimal,4 and real and generated samples are equally represented, then it is not hard to show using Bayes’s theorem that $$f_{\psi}(\mb{x}) = \frac{p(\mb{x})}{p(\mb{x}) + q_t(\mb{x})}.$$ Equating these two expressions gives $1 + \exp(-h_{\psi}(\mb{x})) = 1 + q_t(\mb{x})/p(\mb{x})$, from which we see that $$h_{\psi}(\mb{x}) = \log \left( \frac{p(\mb{x})}{q_t(\mb{x})} \right) = \log p(\mb{x}) - \log q_t(\mb{x}).$$
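The identity $\sigma(\log p - \log q_t) = p/(p + q_t)$ is easy to confirm numerically. The sketch assumes two hypothetical Gaussians, $p = \mathcal{N}(1, 1)$ and $q_t = \mathcal{N}(-1, 1.5^2)$, whose densities are known in closed form, so the optimal discriminator can be written down directly.

```python
import math

def gauss_pdf(x, mu, sd):
    return math.exp(-0.5 * ((x - mu) / sd) ** 2) / (sd * math.sqrt(2.0 * math.pi))

# Hypothetical target p = N(1, 1) and model q_t = N(-1, 1.5^2).
MU_P, SD_P, MU_Q, SD_Q = 1.0, 1.0, -1.0, 1.5

def optimal_f(x):
    """Bayes-optimal discriminator under equal class priors: p / (p + q_t)."""
    px, qx = gauss_pdf(x, MU_P, SD_P), gauss_pdf(x, MU_Q, SD_Q)
    return px / (px + qx)

def optimal_logit(x):
    """Pre-activation h(x) = log p(x) - log q_t(x)."""
    return math.log(gauss_pdf(x, MU_P, SD_P)) - math.log(gauss_pdf(x, MU_Q, SD_Q))

def sigmoid(u):
    return 1.0 / (1.0 + math.exp(-u))

for x in (-2.0, 0.0, 0.5, 3.0):
    print(x, optimal_f(x), sigmoid(optimal_logit(x)))  # the two columns agree
```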

Leveraging the sub-problem interpretation described above, let us examine the effect of the non-saturating loss on generated data $\tilde{\mb{x}} = g_{\theta}(\mb{z})$: $$-\nabla_{\tilde{\mb{x}}} \log f_{\psi}(\tilde{\mb{x}}) = -\frac{\nabla_{\tilde{\mb{x}}} f_{\psi}(\tilde{\mb{x}})}{f_{\psi}(\tilde{\mb{x}})} = -(1 - f_{\psi}(\tilde{\mb{x}})) \nabla_{\tilde{\mb{x}}} h_{\psi}(\tilde{\mb{x}}),$$ where the last equality follows from the sigmoid derivative identity $\nabla_{\tilde{\mb{x}}} f_{\psi}(\tilde{\mb{x}}) = f_{\psi}(\tilde{\mb{x}}) (1 - f_{\psi}(\tilde{\mb{x}})) \nabla_{\tilde{\mb{x}}} h_{\psi}(\tilde{\mb{x}})$.
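A quick finite-difference check of this gradient identity, assuming a hypothetical smooth logit $h(\mb{x})$ on $\mathbb{R}^2$ chosen only for illustration; any differentiable logit would do.

```python
import math

def sigmoid(u):
    return 1.0 / (1.0 + math.exp(-u))

def h(x):
    # Hypothetical discriminator logit on R^2, chosen only for illustration.
    return math.sin(x[0]) + 0.5 * x[0] * x[1] - 0.3 * x[1] ** 2

def grad_h(x):
    # Analytic gradient of h: (cos x0 + 0.5 x1, 0.5 x0 - 0.6 x1).
    return [math.cos(x[0]) + 0.5 * x[1], 0.5 * x[0] - 0.6 * x[1]]

def loss(x):
    # Non-saturating loss viewed as a function of the generated point x.
    return -math.log(sigmoid(h(x)))

def num_grad(fn, x, eps=1e-6):
    """Central-difference gradient of the scalar function fn at x."""
    grad = []
    for i in range(len(x)):
        up, down = list(x), list(x)
        up[i] += eps
        down[i] -= eps
        grad.append((fn(up) - fn(down)) / (2.0 * eps))
    return grad

x = [0.4, -1.2]
f = sigmoid(h(x))
lhs = num_grad(loss, x)                    # -grad_x log f(x), numerically
rhs = [-(1.0 - f) * g for g in grad_h(x)]  # -(1 - f(x)) grad_x h(x), analytically
print(lhs, rhs)
```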

Since minimizing the loss moves us opposite this gradient, and since $f_{\psi}(\tilde{\mb{x}}) < 1$ makes the factor $1 - f_{\psi}(\tilde{\mb{x}})$ strictly positive, we can see that (assuming, as above, an optimal discriminator) we are moving in the direction $$\nabla_{\tilde{\mb{x}}} h_{\psi}(\tilde{\mb{x}}) = \nabla_{\tilde{\mb{x}}} \log p(\tilde{\mb{x}}) - \nabla_{\tilde{\mb{x}}} \log q_t(\tilde{\mb{x}})$$ when we update the generator output $\tilde{\mb{x}}$ to form $\tilde{\mb{x}}'$, creating a new regression target for the generator. That regression problem then updates the generator parameters, $\theta$.
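To see this update direction at work, the sketch below uses the hypothetical Gaussians $p = \mathcal{N}(1, 1)$ and $q_t = \mathcal{N}(-1, 1.5^2)$, where each score $\nabla \log \mathcal{N}(x; \mu, s^2) = -(x - \mu)/s^2$ is known exactly. It nudges a few generated points along $\nabla h$ with a hypothetical step size $\eta$, which absorbs the positive factor $\lambda_1 (1 - f_{\psi})$, and checks that the optimal logit $h$ increases.

```python
import math

MU_P, SD_P, MU_Q, SD_Q = 1.0, 1.0, -1.0, 1.5  # hypothetical p = N(1, 1), q_t = N(-1, 1.5^2)

def log_gauss(x, mu, sd):
    return -0.5 * ((x - mu) / sd) ** 2 - math.log(sd * math.sqrt(2.0 * math.pi))

def score(x, mu, sd):
    """d/dx log N(x; mu, sd^2) = -(x - mu) / sd^2."""
    return -(x - mu) / sd ** 2

def score_difference(x):
    """grad log p(x) - grad log q_t(x): the update direction for a generated point x."""
    return score(x, MU_P, SD_P) - score(x, MU_Q, SD_Q)

def h(x):
    """Optimal discriminator logit log p(x) - log q_t(x)."""
    return log_gauss(x, MU_P, SD_P) - log_gauss(x, MU_Q, SD_Q)

eta = 0.05  # hypothetical step size
points = (-2.0, -1.0, 0.0)
for x in points:
    x_new = x + eta * score_difference(x)  # new regression target for this point
    print(x, x_new, h(x), h(x_new))
```

Each of these points moves toward larger $x$, where $p$ is relatively denser than $q_t$.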

This generator output update direction is the first appearance of a quantity that will come up often in future posts, so much so that we will give it a name: the score difference. Soon we will discuss the key role that this quantity plays in score matching (minimizing the Fisher divergence, yet another statistical distance), diffusion models (the machinery behind DALL·E 2 and its cousins), and various flavors of probability flow dynamics found in stochastic and ordinary differential equations.

Notes and References
  1. Technically, the square root of the JSD is a metric, which is known as the Jensen-Shannon distance.
  2. An arguable exception is latent diffusion, which operates on the latent space of a pre-trained autoencoder. But that space winds up being the “data space” of the diffusion problem, since the autoencoder is pre-trained. It’s a subtle point that we need not occupy ourselves with now.
  3. The generator output $\tilde{\mb{x}}(\theta)$ is of course also a function of $\mb{z}$, but since $\mb{z}$ is considered fixed, we will not list it as an argument in this analysis.
  4. This is a heavy assumption, but it’s one frequently made in the analysis of GANs.

