my take on variational inference with an interactive demo
The goal of this post is to show that variational inference is a natural way of thinking about Bayesian inference and not some shady approximate method.
Let’s start with the textbook definitions. We have a target distribution
\[p^\star(\theta) = \frac{\widetilde{p}(\theta)}{\mathcal{Z}},\]which we know up to its normalization constant \(\mathcal{Z}\). At the core, variational inference is a way to approximate \(p^\star(\theta)\) having only the ability to evaluate the unnormalized target \(\widetilde{p}(\theta)\). The target can be continuous or discrete (or mixed), there are no restrictions!
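To make this concrete, here is a minimal Python sketch of what "knowing the target up to a constant" means in practice. The target below is a hypothetical toy example (a two-bump density scaled by an arbitrary constant standing in for the unknown \(\mathcal{Z}\)), not the one used in the widget later in the post.

```python
import numpy as np

def log_p_tilde(theta):
    """Log of a hypothetical unnormalized target: two Gaussian bumps
    times an arbitrary constant (the unknown normalization Z)."""
    log_bump1 = -0.5 * ((theta + 2.0) / 0.8) ** 2                # bump near -2
    log_bump2 = np.log(0.5) - 0.5 * ((theta - 3.0) / 1.5) ** 2   # smaller bump near +3
    # log-sum-exp of the two bumps, plus an arbitrary offset we never need to know
    return np.logaddexp(log_bump1, log_bump2) + 4.7

# We can evaluate the unnormalized log density anywhere we like...
print(log_p_tilde(np.array([-2.0, 0.0, 3.0])))
# ...but this alone does not give us means, variances, or samples.
```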
If we go on reading a textbook, it will tell us that variational inference “approximates” the target with a (simpler) distribution \(q_\psi(\theta)\) parameterized by \(\psi\).
For example, if \(q\) is a multivariate normal, \(\psi\) could be the mean and covariance matrix of the distribution, \(\psi = (\mu, \Sigma)\). Please note that while normal distributions are a common choice in variational inference, they are not the only one – you could choose \(q\) to be any distribution family of your choice!
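For instance, a Gaussian variational family could be wrapped up as in the sketch below (a hypothetical helper class written just for illustration, not part of any library), where \(\psi = (\mu, \Sigma)\) and both evaluating the log density and drawing samples are cheap:

```python
import numpy as np
from scipy.stats import multivariate_normal

class GaussianVariationalFamily:
    """A multivariate normal q_psi(theta) with parameters psi = (mu, Sigma)."""

    def __init__(self, mu, Sigma):
        self.mu = np.asarray(mu, dtype=float)
        self.Sigma = np.asarray(Sigma, dtype=float)
        self._dist = multivariate_normal(mean=self.mu, cov=self.Sigma)

    def log_q(self, theta):
        # Evaluate the (normalized!) log density of q at theta.
        return self._dist.logpdf(theta)

    def sample(self, n, rng=None):
        # Draw n samples from q -- easy, unlike sampling from the target.
        return self._dist.rvs(size=n, random_state=rng)

# A 2D Gaussian approximation with a full covariance matrix.
q = GaussianVariationalFamily(mu=[0.0, 1.0], Sigma=[[1.0, 0.3], [0.3, 2.0]])
print(q.log_q([0.0, 1.0]), q.sample(3, rng=np.random.default_rng(0)).shape)
```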
Why do we need an approximation at all? Why can't we just use the target as is? That is a great question. The short answer: because we can't.
Yes, we may be able to evaluate \(\widetilde{p}(\theta)\) for any chosen value of \(\theta\), but that alone does not tell us much: we still cannot easily compute means, variances, probabilities of regions, or other expectations under the target.
One way to compute some of these values (and not even all of them) might be to get samples from the target… but how do we get those? How do we draw samples from the target if we only know an unnormalized \(\widetilde{p}(\theta)\)?
In short, we have our largely-unusable target and we would like to replace it with something that is easy to use and compute with for all the quantities we care about. There is a fancy word for that: we want a distribution which is tractable.
This is the magic of what variational inference does: it takes an intractable target distribution and it gives back a tractable approximation \(q\), belonging to a class of our choice. We are using tractable here in a loose sense, meaning that we expect the minimal properties of a respectable probability distribution: that we can evaluate its (normalized) density, draw samples from it, and compute expectations and other summary statistics.
There are more precise and nuanced definitions of tractability based on the specific types of probabilistic queries we can compute in polynomial time (e.g., marginals, conditionals, expectations), and you are encouraged to read Choi et al. (2020) if you want the full story.
So, how does \(q\) approximate the target? Intuitively, we want \(q\) to be as similar as possible to the normalized target \(p^\star\).
We can thus take a measure of discrepancy between two distributions and ask for that discrepancy to be as small as possible. Traditionally, variational inference chooses the reverse Kullback-Leibler (KL) divergence as its discrepancy function:
\[\text{KL}(q_\psi(\theta) \,\mid\mid\, p^\star(\theta)) = \int q_\psi(\theta) \log \frac{q_\psi(\theta)}{p^\star(\theta)} \, d\theta.\]
This measures how the approximation \(q_\psi(\theta)\) diverges (differs) from the normalized target distribution \(p^\star(\theta)\). It is reverse because we put the approximation \(q\) first (the KL is not symmetric). The direct KL divergence would have the “real” target distribution \(p^\star\) first.
So for a given family of approximating distributions \(q_\psi(\theta)\), variational inference chooses the best value of the parameters \(\psi\) that make \(q_\psi\) “as close as possible” to \(p^\star\) by minimizing the KL divergence between \(q_\psi\) and \(p^\star\).
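If you want to see the asymmetry between the reverse and direct KL for yourself, here is a small toy computation (univariate Gaussians with the closed-form KL expression; just an illustrative sketch, not part of the widget below):

```python
import numpy as np

def kl_gauss_1d(mu1, sigma1, mu2, sigma2):
    """KL( N(mu1, sigma1^2) || N(mu2, sigma2^2) ), closed form."""
    return (np.log(sigma2 / sigma1)
            + (sigma1**2 + (mu1 - mu2)**2) / (2.0 * sigma2**2)
            - 0.5)

# A narrow approximation q and a wider "target" p.
mu_q, sigma_q = 0.0, 1.0
mu_p, sigma_p = 1.0, 3.0

print("reverse KL(q || p):", kl_gauss_1d(mu_q, sigma_q, mu_p, sigma_p))
print("direct  KL(p || q):", kl_gauss_1d(mu_p, sigma_p, mu_q, sigma_q))
# The two numbers differ: the KL divergence is not symmetric.
```

Variational inference uses the reverse direction because it only requires expectations under \(q\), which we can evaluate or sample from at will.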
Done? Not quite yet.
There is a caveat to the logic above: remember that we only have the unnormalized \(\widetilde{p}\), we do not have \(p^\star\)! However, it turns out that this is no problem at all. First, we present the main result; a full derivation follows afterwards for the interested reader.
Minimizing the KL divergence between \(q_\psi\) and \(p^\star\) can be achieved by maximizing the so-called ELBO, or Evidence Lower BOund, defined as:
\[\text{ELBO}(q_\psi) = \underbrace{\int q_\psi(\theta) \log \widetilde{p}(\theta) \, d\theta}_{\text{Negative cross-entropy}} \; \underbrace{- \int q_\psi(\theta) \log q_\psi(\theta) \, d\theta}_{\text{Entropy}}.\]First, note that the ELBO only depends on \(q_\psi\) and \(\widetilde{p}\). The ELBO takes its name from the fact that it is indeed a lower bound on the log normalization constant (the "evidence"), that is \(\log \mathcal{Z} \ge \text{ELBO}(q_\psi)\).
The ELBO is composed of two terms, a cross-entropy term between \(q\) and \(\widetilde{p}\) and the entropy of \(q\). The two terms represent opposing forces: the negative cross-entropy term pushes \(q\) to concentrate its mass where the unnormalized target \(\widetilde{p}\) is large, while the entropy term pushes \(q\) to spread out, preventing it from collapsing onto a single high-density point.
In conclusion, in variational inference we want to tweak the parameters \(\psi\) of \(q\) so that the approximation \(q_\psi\) is as close as possible to \(p^\star\), according to the ELBO and, equivalently, to the KL divergence.
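As a rough sketch of how the ELBO could be estimated in practice (the widget below uses grid integration instead, and the function names and toy target here are made up for illustration), the two terms can be approximated with Monte Carlo samples from \(q\):

```python
import numpy as np

def elbo_monte_carlo(log_p_tilde, sample_q, log_q, n_samples=10_000, rng=None):
    """Simple Monte Carlo estimate of the ELBO.

    log_p_tilde : callable returning the unnormalized target log density
    sample_q    : callable drawing n samples from q_psi
    log_q       : callable returning the log density of q_psi
    """
    rng = np.random.default_rng(rng)
    theta = sample_q(n_samples, rng)
    neg_cross_entropy = np.mean(log_p_tilde(theta))  # E_q[log p_tilde]
    entropy = -np.mean(log_q(theta))                 # -E_q[log q]
    return neg_cross_entropy + entropy

# Toy example: q is N(0, 1), target is an unnormalized N(1, 2^2), scaled so
# that log Z = log(5) by construction.
log_p_tilde = lambda th: (np.log(5.0) - 0.5 * np.log(2 * np.pi * 4.0)
                          - (th - 1.0) ** 2 / 8.0)
log_q = lambda th: -0.5 * np.log(2 * np.pi) - th ** 2 / 2.0
sample_q = lambda n, rng: rng.standard_normal(n)

print(elbo_monte_carlo(log_p_tilde, sample_q, log_q, rng=0))  # <= log(5) ~= 1.609
```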
This is the full derivation of the ELBO, courtesy of o1-mini and gpt-4o, with just a sprinkle of human magic.
The reverse Kullback-Leibler (KL) divergence between \(q_\psi(\theta)\) and the normalized target \(p^\star(\theta)\) is:
\[\text{KL}(q_\psi(\theta) \,\mid\mid\, p^\star(\theta)) = \int q_\psi(\theta) \log \frac{q_\psi(\theta)}{p^\star(\theta)} \, d\theta\]The normalized target \(p^\star(\theta)\) is related to the unnormalized target \(\widetilde{p}(\theta)\) through the normalization constant \(\mathcal{Z}\):
\[p^\star(\theta) = \frac{\widetilde{p}(\theta)}{\mathcal{Z}}, \quad \text{where} \quad \mathcal{Z} = \int \widetilde{p}(\theta) \, d\theta.\]Substitute this expression for \(p^\star(\theta)\) into the KL divergence:
\[\text{KL}(q_\psi(\theta) \,\mid\mid\, p^\star(\theta)) = \int q_\psi(\theta) \log \left( q_\psi(\theta) \cdot \frac{\mathcal{Z}}{\widetilde{p}(\theta)} \right) \, d\theta\]Using the property of logarithms \(\log(ab) = \log(a) + \log(b)\), split the term inside the integral:
\[\text{KL}(q_\psi(\theta) \,\mid\mid\, p^\star(\theta)) = \int q_\psi(\theta) \big( \log q_\psi(\theta) + \log \mathcal{Z} - \log \widetilde{p}(\theta) \big) \, d\theta\]Distribute \(q_\psi(\theta)\) over the sum:
\[\text{KL}(q_\psi(\theta) \,\mid\mid\, p^\star(\theta)) = \int q_\psi(\theta) \log q_\psi(\theta) \, d\theta + \int q_\psi(\theta) \log \mathcal{Z} \, d\theta - \int q_\psi(\theta) \log \widetilde{p}(\theta) \, d\theta\]Since \(\mathcal{Z}\) is a constant, \(\log \mathcal{Z}\) is also constant and can be factored out of the integral:
\[\int q_\psi(\theta) \log \mathcal{Z} \, d\theta = \log \mathcal{Z} \int q_\psi(\theta) \, d\theta\]Because \(q_\psi(\theta)\) is a valid, normalized probability distribution, \(\int q_\psi(\theta) \, d\theta = 1\). Therefore:
\[\int q_\psi(\theta) \log \mathcal{Z} \, d\theta = \log \mathcal{Z}\]Substitute this simplification back into the KL divergence:
\[\text{KL}(q_\psi(\theta) \,\mid\mid\, p^\star(\theta)) = \int q_\psi(\theta) \log q_\psi(\theta) \, d\theta + \log \mathcal{Z} - \int q_\psi(\theta) \log \widetilde{p}(\theta) \, d\theta\]Rearrange the equation to isolate \(\log \mathcal{Z}\), grouping terms related to \(q_\psi(\theta)\):
\[\log \mathcal{Z} = \text{KL}(q_\psi(\theta) \,\mid\mid\, p^\star(\theta)) + \left( \int q_\psi(\theta) \log \widetilde{p}(\theta) \, d\theta - \int q_\psi(\theta) \log q_\psi(\theta) \, d\theta \right)\]The ELBO is defined as:
\[\text{ELBO}(q_\psi) = \int q_\psi(\theta) \log \widetilde{p}(\theta) \, d\theta - \int q_\psi(\theta) \log q_\psi(\theta) \, d\theta\]Substitute this into the equation for \(\log \mathcal{Z}\):
\[\log \mathcal{Z} = \text{KL}(q_\psi(\theta) \,\mid\mid\, p^\star(\theta)) + \text{ELBO}(q_\psi)\]Rearranging to isolate \(\text{ELBO}(q_\psi)\):
\[\text{ELBO}(q_\psi) = \log \mathcal{Z} - \text{KL}(q_\psi(\theta) \,\mid\mid\, p^\star(\theta))\]Thus, minimizing the KL divergence is equivalent to maximizing the ELBO.
Moreover, since the KL divergence is non-negative, and zero if and only if \(q_\psi = p^\star\), we immediately get that \(\text{ELBO}(q_\psi) \le \log \mathcal{Z}\), with equality exactly when the approximation matches the normalized target. This is why the ELBO is a lower bound on the log evidence.
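If you want to convince yourself of this identity numerically, here is a quick sanity check (a toy 1D example with simple quadrature, unrelated to the widget below):

```python
import numpy as np
from scipy.integrate import quad

# Toy 1D example: the unnormalized target is a Gaussian N(1, 1.5^2) multiplied
# by an arbitrary constant, so log Z = log(7.3) by construction.
log_Z_true = np.log(7.3)
log_p_tilde = lambda th: (log_Z_true
                          - 0.5 * np.log(2 * np.pi * 1.5 ** 2)
                          - 0.5 * ((th - 1.0) / 1.5) ** 2)

# A (deliberately mismatched) variational approximation q = N(0, 1).
log_q = lambda th: -0.5 * np.log(2 * np.pi) - 0.5 * th ** 2
q = lambda th: np.exp(log_q(th))

# ELBO and reverse KL via 1D quadrature (finite limits keep things stable).
elbo, _ = quad(lambda th: q(th) * (log_p_tilde(th) - log_q(th)), -30, 30)
kl, _ = quad(lambda th: q(th) * (log_q(th) - (log_p_tilde(th) - log_Z_true)), -30, 30)

print(log_Z_true, elbo + kl)   # both ~1.988: log Z = ELBO + KL
print(elbo <= log_Z_true)      # True: the ELBO is a lower bound on log Z
```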
Ignoring the fact that \(\widetilde{p}(\theta)\) is not normalized, we can obtain the ELBO purely algebraically.
First, let's (improperly) write the ELBO in terms of the KL divergence between \(q\) and the unnormalized \(\widetilde{p}\):
\[\text{ELBO}(q) = -\text{KL}(q \,\mid\mid\, \widetilde{p})\]Then we have four steps:

1. Expand the definition: \(-\text{KL}(q \,\mid\mid\, \widetilde{p}) = \int q(\theta) \log \widetilde{p}(\theta)\, d\theta - \int q(\theta) \log q(\theta)\, d\theta\), which is exactly the ELBO as defined above.
2. Substitute \(\widetilde{p}(\theta) = \mathcal{Z}\, p^\star(\theta)\) inside the logarithm.
3. Pull out the constant: \(\int q(\theta) \log \mathcal{Z}\, d\theta = \log \mathcal{Z}\), since \(q\) integrates to one.
4. Recognize what remains as \(-\text{KL}(q \,\mid\mid\, p^\star)\), so that \(\text{ELBO}(q) = \log \mathcal{Z} - \text{KL}(q \,\mid\mid\, p^\star)\).
The much longer “full derivation” in the tab above is to avoid using the KL divergence for an unnormalized pdf, which is improper; but it is the same thing.
The ELBO can also famously be derived using Jensen's inequality, but that route adds an unnecessary and potentially misleading "approximation" step when the inequality is applied. I prefer the almost trivial derivation above, which shows the relationship between the ELBO and the KL divergence purely algebraically.
(You still need Jensen’s to show that the KL divergence is non-negative; but subjectively that feels just a property of the KL instead of being the ELBO doing something shady.)
While variational inference can be performed for any generic target density \(\widetilde{p}(\theta)\), the common scenario is that our target density is a posterior distribution:
\[{p^\star}(\theta) \equiv p(\theta \mid \mathcal{D}) = \frac{p(\mathcal{D} \mid \theta) \pi(\theta)}{p(\mathcal{D})}\]where you should recognize on the right-hand side good old Bayes’ theorem, with \(p(\mathcal{D} \mid \theta)\) the likelihood and \(\pi(\theta)\) the prior.
In essentially all practical scenarios we never know the normalization constant, but we can instead compute the unnormalized posterior:
\[\widetilde{p}(\theta) = p(\mathcal{D} \mid \theta) \pi(\theta).\]In this typical use case for variational inference, the ELBO reads:
\[\text{ELBO}(q_\psi) = \mathbb{E}_{q_\psi(\theta)}\left[ \log \left( p(\mathcal{D} \mid \theta)\, \pi(\theta) \right) \right] - \mathbb{E}_{q_\psi(\theta)}\left[\log q_\psi(\theta)\right]\]where we simply replaced \(\widetilde{p}\) with the unnormalized posterior, and we switched to expectation notation instead of integrals, just to show you what that would look like.
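To ground this in a concrete (entirely made-up) example, the sketch below uses Gaussian observations with unknown mean \(\theta\), a Gaussian prior, and a Gaussian \(q_\psi\). The unnormalized posterior is just likelihood times prior, and the ELBO is estimated from samples of \(q\):

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(42)
data = rng.normal(loc=2.0, scale=1.0, size=20)   # observed dataset D
sigma_noise, prior_mu, prior_sigma = 1.0, 0.0, 5.0

def log_p_tilde(theta):
    """Unnormalized log posterior: log likelihood + log prior."""
    log_lik = norm.logpdf(data[:, None], loc=theta, scale=sigma_noise).sum(axis=0)
    log_prior = norm.logpdf(theta, loc=prior_mu, scale=prior_sigma)
    return log_lik + log_prior

def elbo(mu_q, sigma_q, n_samples=5_000):
    """Monte Carlo ELBO for a Gaussian q_psi with psi = (mu_q, sigma_q)."""
    theta = rng.normal(mu_q, sigma_q, size=n_samples)
    log_q = norm.logpdf(theta, loc=mu_q, scale=sigma_q)
    return np.mean(log_p_tilde(theta) - log_q)

# A q centered near the sample mean scores a much higher ELBO than a bad one.
print(elbo(np.mean(data), 0.25), elbo(-3.0, 0.25))
```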
In conclusion, variational inference reduces Bayesian inference to an optimization problem. You have a candidate solution \(q\), and you shake it and twist it and spread it around until you maximize the ELBO. Variational inference per se is nothing more than this.
You may have seen other introductions to or formulations of variational inference that may seem way more complicated. The point is that most variational inference algorithms focus on: how to compute or approximate the expectations in the ELBO (and their gradients), for example with Monte Carlo samples from \(q\); how to make the optimization efficient and scalable (stochastic gradients, mini-batching the data, and so on); and which family of approximating distributions \(q_\psi\) to use.
But don’t get confused: these are all implementation details. To reiterate, in principle you can just compute the expectation in the ELBO however you want and however it is convenient for you (e.g., by numerical integration, as we will do below), and move things around such that you maximize the ELBO. There is nothing more to it.
Of course, there are many clever things that can be done in various special cases (including exploiting variational calculus, hence the name), but none of those are necessary to understand variational inference. See Blei et al. (2017) for a review.
For the reasons mentioned above, I believe that variational inference is possibly the most natural way of thinking about Bayesian inference: computing the posterior is not some esoteric procedure, but we are just trying to find the distribution that best matches the true target posterior, which we know up to a constant.
Variational inference is often seen as "just an approximation method" – as opposed to a true technique for performing Bayesian inference – because historically we were forced to use very simple approximation families (factorized, simple Gaussians, etc.). However, we have been able to use much more flexible distributions for quite a while now, starting for example with the advent of normalizing flows in the 2010s. See the excellent review paper by Papamakarios et al. (2021).
But even old-school distributions such as mixtures of Gaussians can go a long way, as long as you use enough components; the difficulty there is fitting them accurately.
In the widget below (app page) you can see variational inference at work for yourself. It works best on a computer; some aspects are not ideal on mobile.
You can select the target density as well as the family of variational posterior, from a single Gaussian with different constraints (isotropic, diagonal covariance, full covariance) to various mixtures of Gaussians.
Your job: click and drag around the distributions and change their parameters – or just lazily press Optimize – and see the ELBO value go up, getting closer and closer to the true \(\log \mathcal{Z}\), as far as the chosen posterior family allows.
It is very satisfying.
In the widget above, the ELBO is calculated via numerical integration on a grid centered around each Gaussian component, and the gradient used for optimization is calculated via finite differences. That’s it, nothing fancy.
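The widget itself is written in JavaScript, but here is a rough Python sketch of the same recipe (grid integration for the ELBO, finite differences for the gradient, plain gradient ascent), with a made-up bimodal target, just to show that there is indeed nothing fancy going on:

```python
import numpy as np

def elbo_grid(psi, log_p_tilde, n_grid=400, n_sigmas=8.0):
    """ELBO for a 1D Gaussian q with psi = (mu, log_sigma), computed by
    numerical integration on a grid centered on the Gaussian."""
    mu, sigma = psi[0], np.exp(psi[1])
    theta = np.linspace(mu - n_sigmas * sigma, mu + n_sigmas * sigma, n_grid)
    log_q = -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((theta - mu) / sigma) ** 2
    integrand = np.exp(log_q) * (log_p_tilde(theta) - log_q)
    return np.sum(integrand) * (theta[1] - theta[0])   # simple Riemann sum

def finite_diff_grad(f, psi, eps=1e-4):
    """Central finite-difference gradient of f at psi."""
    grad = np.zeros_like(psi)
    for i in range(len(psi)):
        e = np.zeros_like(psi)
        e[i] = eps
        grad[i] = (f(psi + e) - f(psi - e)) / (2 * eps)
    return grad

# A made-up bimodal unnormalized target with bumps at -2 and +2.
log_p_tilde = lambda th: np.logaddexp(-0.5 * ((th + 2) / 0.7) ** 2,
                                      -0.5 * ((th - 2) / 0.7) ** 2)

# Plain gradient ascent on the ELBO (the "press Optimize" part).
psi = np.array([0.5, 0.0])   # start at mu = 0.5, sigma = 1
for _ in range(200):
    psi += 0.05 * finite_diff_grad(lambda p: elbo_grid(p, log_p_tilde), psi)

print("mu =", psi[0], "sigma =", np.exp(psi[1]), "ELBO =", elbo_grid(psi, log_p_tilde))
```

With a single Gaussian \(q\) and a bimodal target, the optimization settles on one of the modes; you can reproduce that mode-seeking behavior of the reverse KL in the widget by picking a multimodal target and a single-Gaussian family.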
Incidentally, I spent way too much time coding up this widget, even with the help of Claude. Still, I am pretty happy with the result given that I knew zero JavaScript when I started; and I would not have done it if I had to learn JavaScript just for this. I will probably write a blog post about the process at some point.
I will be hiring postdocs in early 2025 to work on extending Variational Bayesian Monte Carlo and related topics. If interested, please get in touch – we can also meet at NeurIPS in Vancouver!