<h1 id="self-explaining-neural-networks-a-review">Self-Explaining Neural Networks: A Review</h1>
<p><em>Omar Elbaghdadi, March 25, 2020</em></p>
<p>For many applications, understanding <em>why</em> a predictive model makes a certain prediction can be of crucial importance. In the paper <a href="http://papers.nips.cc/paper/8003-towards-robust-interpretability-with-self-explaining-neural-networks">“Towards Robust Interpretability with Self-Explaining Neural Networks”</a>, David Alvarez-Melis and Tommi Jaakkola propose a neural network model that takes interpretability of predictions into account <em>by design</em>. In this post, we will look at how this model works, how reproducible the paper’s results are, and how the framework can be extended.</p>
<!--more-->
<p>First, a bit of context. This blog post is a by-product of a student project that was
done with a group of 4 AI grad students: <a href="https://github.com/AmanDaVinci">Aman Hussain</a>, <a href="https://github.com/ChristophHoenes">Chris Hoenes</a>, <a href="https://www.linkedin.com/in/ivan-bardarov/">Ivan Bardarov</a>, and me. The goal of
the project was to find out how reproducible the results of the aforementioned paper are, and what ideas could be extended. To
do this, we re-implemented the framework from scratch. We have released this
implementation as a package <a href="https://github.com/AmanDaVinci/SENN">here</a>. In
general, we find that the explanations generated by the framework are not very
interpretable. However, the authors provide several valuable ideas, which we use to propose and evaluate several improvements.</p>
<p>Before we dive into the nitty gritty details of the model, let’s talk about why
we care about explaining our predictions in the first place.</p>
<h2 id="transparency-in-ai">Transparency in AI</h2>
<p>As ML systems become more prevalent in societal applications such as banking and healthcare, it’s crucial that they satisfy certain criteria: being safe, not discriminating against certain groups, and being able to explain their decisions. The last criterion is enforced by data policies such as the <a href="https://en.wikipedia.org/wiki/General_Data_Protection_Regulation">GDPR</a>, which gives people subject to algorithmic decisions a right to an explanation of those decisions.</p>
<p>These criteria are often hard to quantify. Instead, a proxy notion is regularly made use of: <em>interpretability</em>. The idea is that if we can explain <em>why</em> a model is making the predictions it does, we can check whether that reasoning is <em>reliable</em>. Currently, there is not much agreement on what the definition of interpretability should be or how to evaluate it.</p>
<p>Much recent work has focused on interpretability methods that try to understand a model’s inner workings <strong>after</strong> it has been trained. Well known examples are <a href="https://github.com/marcotcr/lime">LIME</a> and <a href="https://github.com/slundberg/shap">SHAP</a>. Most of these methods make no assumptions about the model to be explained, and instead treat them like a black box. This means they can be used on <em>any</em> predictive model.</p>
<p>Another approach is to design models that are transparent <strong>by design</strong>. This approach is taken by Alvarez-Melis and Jaakkola, whom I will refer to as “the authors” from now on. They propose a self-explaining neural network (SENN) that optimises for transparency <strong>during the learning process</strong>, (hopefully) without sacrificing too much modeling power.</p>
<h2 id="self-explaining-neural-networks-senn">Self Explaining Neural Networks (SENN)</h2>
<p>Before we design a model that generates explanations, we must first think about
what properties we want these explanations to have. The authors suggest that
explanations should have the following properties:</p>
<ul>
<li><strong>Explicitness</strong>: The explanations should be <em>immediate</em> and <em>understandable</em>.</li>
<li><strong>Faithfulness</strong>: Importance scores assigned to features should be <em>indicative of “true” importance</em>.</li>
<li><strong>Stability</strong>: Similar examples should yield <em>similar explanations</em>.</li>
</ul>
<h3 id="linearity">Linearity</h3>
<p>The authors start out with a linear regression model and generalize from there.
The plan is to let the model become more complex while retaining the interpretable
properties of a linear model.
This approach rests on the argument that a linear model is inherently interpretable.
Despite this claim’s enduring popularity, it is not always entirely valid<sup id="fnref:2"><a href="#fn:2" class="footnote">1</a></sup>.
Regardless, we continue with the linear model.</p>
<p>Say we have input features <script type="math/tex">x_1, \dots, x_n</script>
and parameters <script type="math/tex">\theta_1, \dots, \theta_n</script>. The linear model (omitting the bias for clarity) returns the following
prediction:
<script type="math/tex">f(x) = \sum_{i=1}^{n} \theta_i x_{i}.</script></p>
<h3 id="basis-concepts">Basis Concepts</h3>
<p>The first step towards interpretability is to compute
<em>interpretable feature representations</em> <script type="math/tex">h(x)</script> of the input <script type="math/tex">x</script>, which are called
<strong>basis concepts</strong> (or simply concepts). Instead of acting on the input directly, the model acts on these basis concepts:
<script type="math/tex">f(x) = \sum_{i=1}^{k} \theta_i h(x)_{i},</script>
where <script type="math/tex">k</script> is the number of concepts.</p>
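To make this relationship concrete, here is a minimal NumPy sketch of a prediction formed from concepts. The concept activations and relevance scores below are made-up values, purely for illustration:

```python
import numpy as np

# Hypothetical values for k = 3 concepts of a single input x.
h_x = np.array([0.8, -0.2, 0.5])    # concept activations h(x)_i
theta = np.array([1.5, 0.3, -0.7])  # relevance scores theta_i

# The prediction is a linear combination of interpretable concepts.
f_x = float(np.dot(theta, h_x))     # 1.5*0.8 + 0.3*(-0.2) - 0.7*0.5 = 0.79
```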
<p>We see that the final prediction is some linear combination of interpretable
concepts. The <script type="math/tex">\theta_i</script> can now be interpreted as importance or <strong>relevance
scores</strong> for a certain concept <script type="math/tex">h(x)_i</script>.</p>
<p>Say we are given an image <script type="math/tex">x</script> of a digit and we want to detect which digit it
is. Then each concept <script type="math/tex">h(x)_i</script> might for instance encode stroke width,
global orientation, position, roundness, and so on.</p>
<p><img src="/assets/images/posts/2020/self-explaining-neural-networks/concepts.png" alt="concepts" title="opt title" /></p>
<p>Concepts like this could be generated by domain experts, but this is expensive and in many cases infeasible. An alternative approach is to <a href="https://arxiv.org/abs/1711.11279">learn the concepts directly</a>. It cannot be stressed enough that <em>the interpretability of the whole framework depends on how interpretable the concepts are</em>. Therefore, the authors propose <strong>three properties concepts should have</strong> and how to enforce them:</p>
<ol>
<li>
<p><em>Fidelity</em>: The representation of <script type="math/tex">x</script> in terms of concepts should <strong>preserve relevant information</strong>.</p>
<p>This is enforced by learning the concepts <script type="math/tex">h(x)</script> as the latent encoding of an <a href="https://www.jeremyjordan.me/autoencoders/">autoencoder</a>. An autoencoder is a neural network that learns to map an input <script type="math/tex">x</script> to itself by first encoding it into a <strong>lower dimensional representation</strong> with an encoder network <script type="math/tex">h</script> and then creating a reconstruction <script type="math/tex">\hat{x}</script> with a decoder network <script type="math/tex">h_\mathrm{dec}</script>, i.e. <script type="math/tex">\hat{x} = h_\mathrm{dec}(h(x))</script>. The lower dimensional representation <script type="math/tex">h(x)</script>, which we call its <em>latent representation</em>, is a vector in a space we call the latent space. <strong>It therefore needs to capture the most important information contained in <script type="math/tex">x</script></strong>. This can be thought of as a nonlinear version of <a href="http://setosa.io/ev/principal-component-analysis/">PCA</a>.</p>
</li>
<li>
<p><em>Diversity</em>: Inputs should be representable with <strong>few non-overlapping concepts</strong>.</p>
<p>This is enforced by making the autoencoder mentioned above <strong>sparse</strong>. A <a href="https://medium.com/@syoya/what-happens-in-sparse-autencoder-b9a5a69da5c6">sparse autoencoder</a> is one in which <strong>only a relatively small subset of the latent dimensions activate for any given input</strong>. While this indeed forces an input to be representable with <em>few</em> concepts, it does not really guarantee that they should be <em>non-overlapping</em>.</p>
</li>
<li>
<p><em>Grounding</em>: Concepts should have an <strong>immediate human-understandable interpretation</strong>.</p>
<p>This is a more subjective criterion. The authors aim to do this by providing interpretations for concepts by <a href="https://arxiv.org/abs/1710.04806"><strong>prototyping</strong></a>. For image data, they find a set of observations that <strong>maximally activate</strong> a certain concept and use those as the <em>representation</em> for that concept. While this may seem reasonable at first, we will see later that this approach is quite problematic.</p>
</li>
</ol>
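As a rough sketch of how fidelity and diversity translate into a training objective, one common formulation combines a reconstruction term with an L1 sparsity penalty on the concept activations. The exact penalty and weight used in the paper may differ; the `sparsity_weight` here is hypothetical:

```python
import numpy as np

def sparse_autoencoder_loss(x, x_hat, h, sparsity_weight=1e-3):
    # Fidelity: the reconstruction x_hat = decoder(h) should preserve
    # the information contained in x.
    reconstruction = float(np.mean((x - x_hat) ** 2))
    # Diversity (partially): an L1 penalty pushes most concept
    # activations toward zero, so few concepts fire for any input.
    sparsity = float(np.mean(np.abs(h)))
    return reconstruction + sparsity_weight * sparsity
```

With a perfect reconstruction, only the sparsity term remains, so the loss directly measures how many concepts are active.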
<p>As mentioned above, the approaches taken by the authors to achieve diversity and grounding have
quite some problems. The extension that we introduce aims to mitigate these.</p>
<h3 id="keeping-it-linear">Keeping it Linear</h3>
<p>In the last section, we introduced basis concepts. However, the model is still
too simple, since the relevance scores <script type="math/tex">\theta_i</script> are fixed. We therefore
make the next generalization: the relevance scores are now also a function of
the input. This leads us to the final model<sup id="fnref:most_general"><a href="#fn:most_general" class="footnote">2</a></sup>:</p>
<script type="math/tex; mode=display">f(x) = \sum_{i=1}^{k} \theta(x)_i h(x)_{i}.</script>
<p>To make the model sufficiently complex, the function computing relevance scores
<script type="math/tex">\theta</script> is implemented as a neural network. A problem with neural networks,
however, is that they are not very stable: a small change to the input may lead
to a large change in relevance scores. This goes against the <em>stability</em>
criterion for explanations introduced earlier. To combat this, the authors propose
to <strong>make <script type="math/tex">\theta</script> behave linearly in local regions, while still being sufficiently
complex globally</strong>. This means that a small change in <script type="math/tex">h</script> should lead to only a small
change in <script type="math/tex">\theta</script>. To do this, they add a regularization term, which we call the
robustness loss<sup id="fnref:robustness_loss"><a href="#fn:robustness_loss" class="footnote">3</a></sup>, to the loss function.</p>
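To make the regularizer concrete, here is a toy numerical sketch of the penalty <script type="math/tex">\mathcal{L}_\theta = \lVert\nabla_x f(x) - \theta(x)^{\mathrm{T}} J_x^{h}(x)\rVert</script>. The conceptizer and parameterizer below are made-up linear/constant functions, and the derivatives are estimated with finite differences; in SENN both are neural networks and the gradients come from autograd:

```python
import numpy as np

# Toy setup: h is linear and theta is constant, so f is exactly
# locally linear and the penalty should be ~0 for this model.
W_h = np.array([[1.0, 0.5], [0.0, 2.0]])

def h(x):
    return W_h @ x                    # concept activations h(x)

def theta(x):
    return np.array([0.7, -0.3])      # relevance scores theta(x)

def f(x):
    return float(theta(x) @ h(x))     # f(x) = sum_i theta(x)_i h(x)_i

def robustness_penalty(x, eps=1e-5):
    # || grad_x f(x) - theta(x)^T J_h(x) ||, via central differences.
    n = x.size
    grad_f = np.array([
        (f(x + eps * np.eye(n)[i]) - f(x - eps * np.eye(n)[i])) / (2 * eps)
        for i in range(n)
    ])
    J_h = np.column_stack([
        (h(x + eps * np.eye(n)[i]) - h(x - eps * np.eye(n)[i])) / (2 * eps)
        for i in range(n)
    ])
    return float(np.linalg.norm(grad_f - theta(x) @ J_h))

penalty = robustness_penalty(np.array([0.3, -1.2]))  # ~0 for this toy model
```

A complex, nonlinear <code>theta</code> would make this penalty nonzero, and minimizing it is what pushes the relevance scores to vary slowly.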
<p>In a classification setting with multiple classes, relevance
scores <script type="math/tex">\theta</script> are estimated for each class separately. An <strong>explanation</strong> is
then given by the concepts and their corresponding relevance scores<sup id="fnref:product"><a href="#fn:product" class="footnote">4</a></sup>.
The figure below shows an example of such an explanation.</p>
<p><img src="/assets/images/posts/2020/self-explaining-neural-networks/senn_explanation.jpg" alt="pic alt" title="opt title" /></p>
<p>The authors interpret such an explanation in the following way. Looking at the
first input, a “9”, we see that concept 3 has a positive contribution to the
prediction that it is a 9. Concept 3 seems to represent a horizontal dash
together with a diagonal stroke, which seems to be present in the image of a 9.
As you may have noticed, these explanations also require some imagination.</p>
<h2 id="implementation">Implementation</h2>
<p>Although a public implementation of SENN <em>is</em> <a href="https://github.com/dmelis/senn">available</a>, the authors have not officially released that code with the paper. There also seems to be a major bug in this code. Therefore, we re-implement the framework with the original paper as ground truth.</p>
<p>On a high level, the SENN model consists of three main building blocks: a <em>parameterizer</em> <script type="math/tex">\theta</script>, a <em>conceptizer</em> <script type="math/tex">h</script>, and an <em>aggregator</em> <script type="math/tex">g</script>, which is simply the sum function in this case. The parameterizer is implemented as a neural network, and the conceptizer as an autoencoder. The specific architectures may vary: for tabular data, we may use fully connected networks, and for image data, convolutional networks. Here is an overview of the model:
<img src="/assets/images/posts/2020/self-explaining-neural-networks/teaser.png" alt="pic alt" title="opt title" /></p>
<p>To train the model, the following loss function is minimized:</p>
<script type="math/tex; mode=display">\begin{align}
\mathcal{L} := \mathcal{L}_y(f(x), y) + \lambda \mathcal{L}_\theta(f(x)) + \xi \mathcal{L}_h(x), \label{eq:total_loss}
\end{align}</script>
<p>where</p>
<ul>
<li><script type="math/tex">\mathcal{L}_y(f(x), y)</script> is the <em>classification loss</em>, i.e. how well the model predicts the ground truth label.</li>
<li><script type="math/tex">\mathcal{L}_\theta(f(x))</script> is the <em>robustness loss</em>. <script type="math/tex">\lambda</script> is a regularization parameter controlling how heavily robustness is enforced.</li>
<li><script type="math/tex">\mathcal{L}_h(x)</script> is the <em>concept loss</em>. The concept loss is a sum of 2 different losses: <em>reconstruction loss</em> and <em>sparsity loss</em>. <script type="math/tex">\xi</script> is a regularization parameter on the concept loss.</li>
</ul>
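Putting the pieces together, the objective is a weighted sum of the three terms. A minimal sketch (the values of <code>lam</code> and <code>xi</code> here are hypothetical, not the paper's settings):

```python
def senn_loss(classification_loss, robustness_loss, concept_loss,
              lam=2e-4, xi=1.0):
    # L = L_y + lambda * L_theta + xi * L_h
    return classification_loss + lam * robustness_loss + xi * concept_loss
```

Tuning <code>lam</code> trades classification accuracy against explanation robustness, which is exactly the trade-off examined in the reproducibility section below.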
<h2 id="reproducibility">Reproducibility</h2>
<p>If it is impossible to reproduce a paper’s results, it is hard to
verify the validity of the obtained results. Even if the results are valid, it
will be hard to build upon that work. Reproducibility has therefore recently been an
important topic in many scientific disciplines. In AI research, there
have been initiatives to move towards more reproducible research. NeurIPS, one of the
biggest conferences in AI, introduced a <a href="https://www.cs.mcgill.ca/~jpineau/ReproducibilityChecklist.pdf">reproducibility
checklist</a> last
year. Although it is not mandatory, it is a step in the right direction.</p>
<h3 id="results">Results</h3>
<p>The main goal of our project was to reproduce a subset of the results
presented by the authors. We find that while we are able to reproduce some results,
<em>we are not able to reproduce all of them</em>.</p>
<p>In particular, we are able to <strong>achieve similar test accuracies</strong> on the COMPAS and MNIST
datasets, a tabular and an image dataset respectively, which seems to validate the authors’
claim that SENN models have high modelling capacity. However, both MNIST and COMPAS
are datasets of low complexity. The relatively high performance
on these datasets is therefore <em>not sufficient to show that SENN models are on par
with non-explaining state-of-the-art models</em>.</p>
<p>We also look at the <strong>trade-off between robustness of explanations and model
performance</strong>. We see that enforcing robustness more strongly decreases classification accuracy.
This behavior matches the authors’ findings. However, we find
<strong>the accuracy drop for increasing regularization to be significantly larger
than reported by the authors</strong>.</p>
<p>Assessing the <strong>quality of explanations</strong> is inherently subjective. It is therefore difficult
to link the quality of explanations to reproducibility. However, we can still
partially judge reproducibility by qualitative analysis. In general, we see that <strong>finding
an example whose explanation “makes sense” is difficult</strong> and that such an example is
<em>not representative of the generated examples</em>. Therefore, we conclude that obtaining
good explanations is, in that sense, not reproducible.</p>
<h3 id="interpretation">Interpretation</h3>
<p>How should we interpret this failure to reproduce? Although it is tempting to
say that the initial results are invalid, this might be too harsh and
unfair. Many factors influence an experiment’s results.
Although our results are discouraging, more experiments
are needed to better estimate the validity of the framework.
Even so, the authors provide
a useful starting point for further research. We take a first step, which
involves improving the learned concepts.</p>
<h2 id="disentangled-senn-disenn">Disentangled SENN (DiSENN)</h2>
<p>As we have seen, the current SENN framework hardly generates interpretable explanations.
One important reason is <strong>the way concepts are represented and learned</strong>.
A single concept is represented by the set of data samples that maximize that concept’s activation.
The authors reason that we can read off what the concept represents by examining these samples. However, this approach is
problematic.</p>
<p>Firstly, only showcasing the samples that maximize an encoder’s activation is quite
arbitrary. One could just as well showcase samples that <em>minimize</em> the activation
instead, or use any other method. These different approaches lead to <em>different interpretations
even if the learned concepts are the same</em>. The authors also hypothesize that a sparse
autoencoder will lead to <strong>diverse</strong> concepts. However, enforcing only a subset
of dimensions to activate for an input <strong>does not explicitly enforce that these concepts
should be non-overlapping</strong>. For digit recognition, one concept
may for example represent both thickness and roundness at the same time. That is, each
single concept may represent a mix of features that are <em>entangled</em> in some complex
way. This greatly increases the complexity of interpreting any concept on its own. An image, for example, is formed from the complex interaction of light sources, object shapes, and their material properties. We call each of these component features a <em>generative factor</em>.</p>
<p>To enhance concept interpretability, we therefore propose to explicitly enforce <strong>disentangling
the factors of variation in the data (generative factors)</strong> and using these as concepts instead.
Different explanatory factors of the data tend to change independently
of each other in the input distribution, and only a few at a time
tend to change when one considers a sequence of consecutive
real-world inputs. Matching a single generative factor to a single latent dimension
allows for easier human interpretation<sup id="fnref:representation_learning"><a href="#fn:representation_learning" class="footnote">5</a></sup>.</p>
<p>We enforce disentanglement by using
a <a href="https://openreview.net/forum?id=Sy2fzU9gl"><script type="math/tex">\beta</script>-VAE</a>, a variant
of the <a href="https://arxiv.org/abs/1312.6114">Variational Autoencoder (VAE)</a>,
to learn the concepts. If you are unfamiliar with VAEs, I suggest reading <a href="https://lilianweng.github.io/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html">this blog post</a>.
Whereas a normal autoencoder maps an input to a single point in the latent
space, a VAE maps an input to a <em>distribution</em> in the latent space. We want,
and enforce, this latent distribution to look “nice” in some sense. A unit Gaussian
distribution is generally the go-to choice for this. This is done by minimizing the <a href="https://www.countbayesie.com/blog/2017/5/9/kullback-leibler-divergence-explained">KL-divergence</a> between
a Gaussian and the learned latent distribution.</p>
<p><script type="math/tex">\beta</script>-VAE introduces a hyperparameter <script type="math/tex">\beta</script> that enables a heavier regularization on the
latent distribution (i.e. higher KL-divergence penalty). The higher <script type="math/tex">\beta</script>, the more the latent space will
be encouraged to look like a unit Gaussian. Since the dimensions
of a unit Gaussian are independent, these latent factors will be
encouraged to be independent as well.</p>
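A minimal sketch of this objective for a diagonal Gaussian posterior, using the closed-form KL divergence to the unit Gaussian. The value <code>beta=4.0</code> is just an illustrative choice, not a recommended setting:

```python
import numpy as np

def kl_to_unit_gaussian(mu, log_var):
    # Closed-form KL( N(mu, diag(sigma^2)) || N(0, I) ).
    return -0.5 * float(np.sum(1 + log_var - mu ** 2 - np.exp(log_var)))

def beta_vae_loss(x, x_hat, mu, log_var, beta=4.0):
    # beta > 1 penalizes the KL term more heavily, pushing the latent
    # distribution toward the unit Gaussian and (hopefully) toward
    # independent, disentangled latent dimensions.
    reconstruction = float(np.sum((x - x_hat) ** 2))
    return reconstruction + beta * kl_to_unit_gaussian(mu, log_var)
```

When the encoder outputs exactly the unit Gaussian (<code>mu = 0</code>, <code>log_var = 0</code>), the KL term vanishes and only the reconstruction term remains.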
<p>Each latent space dimension corresponds to a concept. Since interpolation in the latent
space of a VAE can be done meaningfully, a concept can be represented by
varying its corresponding dimension’s values and plugging it into the decoder<sup id="fnref:disenn"><a href="#fn:disenn" class="footnote">6</a></sup>.
It is easier to understand visually:</p>
<p align="center">
<img width="460" height="300" src="/assets/images/posts/2020/self-explaining-neural-networks/numbers_small.gif" />
</p>
<p>We see here that the first row represents the width of the number, while the
second row represents a number’s angle. <a href="https://github.com/YannDubs/disentangling-vae">This
figure</a> is created by mapping
an image to the latent space, changing one dimension in this latent space and
then seeing what happens to the reconstruction.</p>
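The traversal procedure behind such figures can be sketched as follows. The decoder here is a stand-in lambda just to show the mechanics; in practice it would be the trained VAE decoder:

```python
import numpy as np

def latent_traversal(decoder, mu, dim, q=1.5, steps=7):
    """Decode points obtained by varying one latent dimension around
    its encoded mean while keeping all other dimensions fixed."""
    values = np.linspace(mu[dim] - q, mu[dim] + q, steps)
    frames = []
    for v in values:
        z = mu.copy()
        z[dim] = v            # move along a single concept dimension
        frames.append(decoder(z))
    return frames

# Identity "decoder", purely to demonstrate the mechanics:
frames = latent_traversal(lambda z: z, np.zeros(3), dim=1)
```

Plotting the decoded frames side by side produces one row of the figure above; repeating the traversal for each latent dimension produces the full grid.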
<p>We thus have a principled way to generate prototypes, since meaningful latent space
traversal is an inherent property of VAEs. Another advantage is that prototypes are
not constrained to samples from the training data. The prototypes
generated by DiSENN are more complete than maximum-activation prototypes,
since they showcase a much larger portion of a concept dimension’s latent space. Seeing
the transitions in concept space gives a more intuitive idea of what the concept
means.</p>
<p>We call a disentangled SENN model with <script type="math/tex">\beta</script>-VAE as the conceptizer a <strong>DiSENN</strong>
model. The following figure shows an overview of the DiSENN model. Only the
conceptizer subnetwork has really changed.</p>
<p><img src="/assets/images/posts/2020/self-explaining-neural-networks/disenn.png" alt="pic alt" title="opt title" /></p>
<p>We now examine the DiSENN explanations by analyzing a generated DiSENN
explanation for the digit 7.</p>
<p><img src="/assets/images/posts/2020/self-explaining-neural-networks/DiSENN_explanation.png" alt="pic alt" title="opt title" /></p>
<p>The contribution of concept <script type="math/tex">i</script> to the prediction of a class <script type="math/tex">c</script> is given by the product
of the corresponding relevance and concept activation <script type="math/tex">\theta_{ic} \cdot h_i</script>. First, we look
at how the concept prototypes are interpreted. To see what a concept encodes, we observe
the changes in the prototypes in the same row. Taking the second row as an example,
we see a circular blob slowly disconnect at the left corner to form a 7, and then
morph into a diagonal stroke. This captures the characteristic diagonal stroke of
a 7, connected to the horizontal stroke at the top right corner but disconnected
otherwise. As expected, this concept has a positive contribution to the prediction
for the real class, digit 7, and a negative contribution to that of another incorrect
class, digit 5.</p>
<p>However, despite the hope that disentanglement encourages diversity, we observe that concepts still overlap, as can be seen from concepts 1 and 2 in the previous figure. This means that the concepts are still not disentangled enough, and the problem of interpretability, although alleviated, remains. Progress towards good explanations with the DiSENN framework therefore depends on progress in disentanglement research.</p>
<h2 id="what-next">What next?</h2>
<p>In this post, we’ve talked about the need for transparency in AI. Self Explaining Neural Networks
have been proposed as one way to achieve transparency even with highly complex
models. We’ve reviewed this framework and the reproducibility of its results,
and find that there is still a lot of work to be done. One promising idea for
extension is to use disentanglement to create more interpretable feature
representations. Interesting further work would be to test how well this
approach works on datasets for which we <em>know</em> the ground truth latent
generative factors, such as the
<a href="https://github.com/deepmind/dsprites-dataset">dSprites dataset</a>.</p>
<p>This blog post is largely based on our project report, which goes into more technical detail. The interested reader can find it <a href="https://github.com/uva-fact-ai-course/uva-fact-ai-course/blob/master/SelfExplainingNNs/report.pdf">here</a>.</p>
<p>I thank Chris, Ivan, and Aman for being incredible project teammates and Simon Passenheim for his guidance throughout the project. Thanks for reading!</p>
<p><strong>Footnotes</strong></p>
<div class="footnotes">
<ol>
<li id="fn:2">
<p>According to <a href="https://arxiv.org/abs/1606.03490">Zachary Lipton</a>: “When choosing between linear and deep models, we must often make a trade-off between <em>algorithmic transparency</em> and <em>decomposability</em>. This is because deep neural networks tend to operate on raw or lightly processed features. So if nothing else, the features are intuitively meaningful, and post-hoc reasoning is sensible. However, in order to get comparable performance, linear models often must operate on heavily hand-engineered features (which may not be very interpretable).” <a href="#fnref:2" class="reversefootnote">↩</a></p>
</li>
<li id="fn:most_general">
<p>The authors actually generalize the model one step further by introducing an <em>aggregation function</em> <script type="math/tex">g</script> such that the final model is given by</p>
<script type="math/tex; mode=display">f(x) = g(\theta(x)_1h(x)_1, \ldots, \theta(x)_kh(x)_k),</script>
<p>where <script type="math/tex">g</script> has properties such that interpretability is maintained. However, for all intents and purposes, it is most reasonable to use the sum function, which is what the authors do in all their experiments as well. <a href="#fnref:most_general" class="reversefootnote">↩</a></p>
</li>
<li id="fn:robustness_loss">
<p>The robustness loss is given by</p>
<script type="math/tex; mode=display">\begin{equation}
\mathcal{L}_\theta := ||\nabla_x f(x) - \theta(x)^{\mathrm{T}} J_x^{h}(x)||,
\end{equation}</script>
<p>where <script type="math/tex">J_x^h(x)</script> is the Jacobian of <script type="math/tex">h</script> with respect to <script type="math/tex">x</script>. The idea
is that we want <script type="math/tex">\theta(x_0)</script> to behave as the derivative of <script type="math/tex">f</script> with respect to <script type="math/tex">h(x)</script>
around <script type="math/tex">x_0</script>, i.e., we seek <script type="math/tex">\theta(x_0) \approx \nabla_z f</script>, where <script type="math/tex">z = h(x)</script>. For more
detailed reasoning, see the paper. <a href="#fnref:robustness_loss" class="reversefootnote">↩</a></p>
</li>
<li id="fn:product">
<p>It actually does not make sense to look only at relevance scores. We have to take into account the product <script type="math/tex">\theta_i\cdot h_i</script>, since it’s this product that determines the contribution to the class prediction. If an <script type="math/tex">h_i</script> has a negative activation, then a positive relevance leads to a negative overall contribution. <a href="#fnref:product" class="reversefootnote">↩</a></p>
</li>
<li id="fn:representation_learning">
<p>A disentangled representation may be viewed as a concise representation of the variation in data we care about most – the generative factors. See <a href="https://arxiv.org/abs/1206.5538">a review of representation learning</a> for more info. <a href="#fnref:representation_learning" class="reversefootnote">↩</a></p>
</li>
<li id="fn:disenn">
<p>Let an input <script type="math/tex">x</script> produce the Gaussian encoding distribution for a single concept, <script type="math/tex">h(x)_i = \mathcal{N}(\mu_i, \sigma_i)</script>. The concept’s activation for this input is then given by <script type="math/tex">\mu_i</script>. We then vary this latent dimension’s value around <script type="math/tex">\mu_i</script> while keeping the others fixed; call the varied value <script type="math/tex">\mu_c</script>. If the concepts are disentangled, a single concept should encode only a single generative factor of the data. The changes in the reconstructions <script type="math/tex">\mathrm{decoder}(\mu_c)</script> will then show which generative factor that latent dimension represents. We plot these changes in the reconstructed input space to visualize this. <script type="math/tex">\mu_c</script> is sampled linearly in the interval <script type="math/tex">[\mu_i - q, \mu_i + q]</script>, where <script type="math/tex">q</script> is some quantile of <script type="math/tex">h(x)_i</script>. <a href="#fnref:disenn" class="reversefootnote">↩</a></p>
</li>
</ol>
</div>
<h1 id="variational-bayesian-inference">Variational Bayesian Inference: A Fast Bayesian Take on Big Data</h1>
<p><em>Omar Elbaghdadi, August 22, 2019</em></p>
<p>Compared to the frequentist paradigm, <a href="https://en.wikipedia.org/wiki/Bayesian_inference">Bayesian inference</a> allows more readily for dealing with and interpreting uncertainty, and for easier incorporation of prior beliefs.</p>
<p>A big problem for traditional Bayesian inference methods, however, is that they are <strong>computationally expensive</strong>. In many cases, computation takes too much time to be used reasonably in research and application. This problem gets increasingly apparent in today’s world, where we would like to make good use of the <strong>large amounts of data</strong> that may be available to us.</p>
<!--more-->
<p>Enter our savior: <strong>variational inference (VI)</strong>—a much faster method than those used traditionally. This is great, but as usual, there is no such thing as a free lunch, and the method has some caveats. But all in due time.</p>
<p>This write-up is mostly based on the first part of the <a href="https://www.youtube.com/watch?v=DYRK0-_K2UU">fantastic 2018 ICML tutorial session on the topic</a> by professor <a href="https://people.csail.mit.edu/tbroderick/">Tamara Broderick</a>. If you like video format, I would recommend checking it out.</p>
<h1 id="overview">Overview</h1>
<p>The post is outlined as follows:</p>
<ul>
<li>What is Bayesian inference and why do we use it in the first place</li>
<li>How Bayesian inference works—a quick overview</li>
<li>The problem, a solution, and a faster solution</li>
<li>Variational Inference and the Mean Field Variational Bayes (MFVB) framework</li>
<li>When can we trust our method</li>
<li>Conclusion</li>
</ul>
<h3 id="a-birds-eye-view">A bird’s-eye view:</h3>
<p>We need Bayesian inference whenever we want to know the <strong>uncertainty of our estimates</strong>. Bayesian inference works by specifying some <strong>prior belief distribution</strong>, and <strong>updating our beliefs</strong> about that distribution with data, based on the <strong>likelihood</strong> of observing that data. We need <strong>approximate algorithms</strong> because standard algorithms need too much time to give usable estimates. <strong>Variational inference</strong> uses <strong>optimization</strong> instead of estimation to <strong>approximate</strong> the true distribution. We get results <strong>much more quickly</strong>, but they are <strong>not always correct</strong>. We have to find out <strong>when we can trust the obtained results</strong>.</p>
<p class="notice--info">A note on notation: <script type="math/tex">P(\cdot)</script> is used to describe both probabilities and probability distributions.</p>
<h1 id="what-is-bayesian-inference-and-why-do-we-use-it-in-the-first-place">What is Bayesian inference and why do we use it in the first place</h1>
<p>Probability theory is a mathematical framework for reasoning about <strong>uncertainty</strong>. Within the subject exist two major schools of thought, or paradigms: the <strong>frequentist</strong> and the <strong>Bayesian</strong> paradigms. In the frequentist paradigm, probabilities are interpreted as average outcomes of random repeatable events, while the Bayesian paradigm provides a way to reason about probability as <strong>a measure of uncertainty</strong>.</p>
<p><strong>Inference</strong> is the process of finding properties of a population or probability distribution from data. Most of the time, these properties are encoded by <strong>parameters</strong> that govern our model of the world.</p>
<p>In the frequentist paradigm, a parameter is assumed to be a fixed quantity unknown to us. Then, a method such as <strong>maximum likelihood (ML)</strong> is used to obtain a <strong>point estimate</strong> (a single number) of the parameter. In the Bayesian paradigm, parameters are not seen as fixed quantities, but as random variables themselves. The <strong>uncertainty in the parameters</strong> is then specified by a <strong>probability distribution</strong> over its values. Our job is to find this probability distribution over parameters <strong>given our data and prior beliefs.</strong></p>
<p>The frequentist and Bayesian paradigms both have their pros and cons, but there are multiple reasons why we might want to use a Bayesian approach. The following reasons are given in a <a href="https://www.quora.com/What-are-some-good-resources-to-learn-about-Bayesian-probability-for-machine-learning-and-how-should-I-structure-my-learning">Quora answer by Peadar Coyle</a>. To summarize:</p>
<ul>
<li>
<p><strong>Explicitly modelling your data generating process</strong>: You are forced to <strong>think carefully about your assumptions</strong>, which are often implicit in other methods.</p>
</li>
<li>
<p><strong>No need to derive estimators</strong>: Being able to treat model fitting as an abstraction is great for <strong>analytical productivity</strong>.</p>
</li>
<li>
<p><strong>Estimating a distribution</strong>: You <strong>deeply understand uncertainty</strong> and get a full-featured input into any downstream decision you need to make.</p>
</li>
<li>
<p><strong>Borrowing strength / sharing information</strong>: A common feature of Bayesian analysis is <strong>leveraging multiple sources of data</strong> (from different groups, times, or geographies) to share related parameters through a prior. This can help enormously with precision.</p>
</li>
<li>
<p><strong>Model checking as a core activity</strong>: There are <strong>principled, practical procedures</strong> for considering a wide range of <strong>models that vary in assumptions and flexibility</strong>.</p>
</li>
<li>
<p><strong>Interpretability of posteriors</strong>: What a posterior means <strong>makes more intuitive sense</strong> to people than most statistical tests.</p>
</li>
</ul>
<p>The debate between frequentists and Bayesians about which is better can get quite intense. I personally believe that neither point of view is better in every situation. We need to think carefully and apply the method that is most appropriate for the problem at hand, be it frequentist or Bayesian. One infamous <a href="https://www.xkcd.com/">xkcd</a> comic, given below, addresses this debate.</p>
<div class="figure">
<figure>
<img class="align-center" src="/assets/images/posts/2019/variational-bayes/frequentists-vs-bayesians.png" alt="image" style="width:400px; align:center;margin-bottom: 0.4em;" />
<figcaption class="align-center">
xkcd comic on frequentist vs Bayesian views. The comic was quite controversial itself. Many thought that the frequentist was treated unfairly. The artist himself <a href="http://web.archive.org/web/20130117080920/http://andrewgelman.com/2012/11/16808/#comment-109366">later commented</a>:
<blockquote>
I meant this as a jab at the kind of shoddy misapplications of statistics I keep running into in things like cancer screening (which is an emotionally wrenching subject full of poorly-applied probability) and political forecasting. I wasn’t intending to characterize the merits of the two sides of what turns out to be a much more involved and ongoing academic debate than I realized.
A sincere thank you for the gentle corrections; I’ve taken them to heart, and you can be confident I will avoid such mischaracterizations in the future!
</blockquote>
Another discussion can be found <a href="https://www.lesswrong.com/posts/mpTEEffWYE6ZAs7id/xkcd-frequentist-vs-bayesians">here</a>.
</figcaption>
</figure>
</div>
<h2 id="what-problems-is-it-used-for">What problems is it used for?</h2>
<p>There are many cases in which we care not only about our estimates, but also how confident we are in those estimates. Some examples:</p>
<ul>
<li>
<p><strong>Finding a wreckage</strong>: In 2009, Air France Flight 447 crashed into the Atlantic Ocean. For two years, investigators were unable to find the wreckage. In the third year, after Bayesian search analysis was brought in, the wreckage was found within one week of undersea search <a href="https://arxiv.org/pdf/1405.4720.pdf">(Stone et al 2014)</a>!</p>
</li>
<li>
<p><strong>Routing</strong>: Understanding the time it takes for vehicles to get from point A to point B. This could be ambulance routing for instance. Knowing the uncertainty in the estimates is important when planning <a href="https://people.orie.cornell.edu/woodard/WoodNogiKoch17.pdf">(Woodard et al 2017)</a>.</p>
</li>
<li>
<p><strong>Microcredit</strong>: Is microcredit actually helping? Knowing the extent of the microcredit effect and our certainty about it may be used to make decisions such as making an investment <a href="https://economics.mit.edu/files/11443">(Meager et al 2019)</a>.</p>
</li>
</ul>
<p>These are just some of the applications. There are many more. Hopefully you are convinced that we are doing something useful. Let’s move on to the actual techniques.</p>
<h1 id="how-bayesian-inference-worksa-quick-overview">How Bayesian inference works—a quick overview</h1>
<p>The <strong>first step</strong> in any inference job is defining a <strong>model</strong>. We might for example model heights in a population of, say, penguins, as being generated by a Gaussian distribution with mean <script type="math/tex">\mu</script> and variance <script type="math/tex">\sigma^2</script>.</p>
<p>Our goal is then to find a probability distribution over the parameters in our model, <script type="math/tex">\mu</script> and <script type="math/tex">\sigma^2</script>, given the data that we have collected. This distribution, also called the <strong>posterior</strong>, is given by</p>
<script type="math/tex; mode=display">P(\theta \vert y_{1:n}) = \frac{P(y_{1:n} \vert \theta) P(\theta)}{P(y_{1:n})},</script>
<p>where <script type="math/tex">y_{1:n}</script> represents the dataset containing <script type="math/tex">n</script> observations and <script type="math/tex">\theta</script> represents the parameters. This identity is known as <strong>Bayes’ Theorem</strong>.</p>
<p>In words, the posterior is given by a product of the <strong>likelihood</strong> <script type="math/tex">P(y_{1:n} \vert\theta)</script> and the <strong>prior</strong> <script type="math/tex">P(\theta)</script>, <strong>normalized by the evidence</strong> <script type="math/tex">P(y_{1:n})</script>.</p>
<div class="figure">
<figure>
<img class="align-center" src="/assets/images/posts/2019/variational-bayes/teaser.png" alt="image" style="width:; align:center;margin-bottom: 0.4em;" />
<figcaption class="align-center"> Sketch of the Bayesian update. The posterior (blue) is obtained after multiplying the prior (red) with the likelihood (black). In a sequential estimation procedure, the new prior is the posterior obtained in the previous step. </figcaption>
</figure>
</div>
<p>The likelihood <script type="math/tex">P(y_{1:n} \vert\theta)</script> is often seen as a function of <script type="math/tex">\theta</script>, and tells us how likely it is to have observed our data given a specific setting of the parameters. <strong>The prior <script type="math/tex">P(\theta)</script> encapsulates beliefs</strong> we have about the parameters before observing any data. <strong>We update our beliefs about the parameter distribution after observing data</strong> according to Bayes’ rule.</p>
<p>After obtaining the posterior distribution, <strong>we would like to report a summary</strong> of it. We usually do this by providing a <strong>point estimate and an uncertainty</strong> surrounding the estimate. A point estimate may for example be given by the <strong>posterior mean or mode</strong>, and uncertainty by the <strong>posterior (co)variances</strong>.</p>
<p>In summary, there are <strong>three major steps</strong> involved in doing Bayesian inference:</p>
<ol>
<li><strong>Choose a model</strong> i.e. choose a prior and likelihood.</li>
<li><strong>Compute the posterior</strong>.</li>
<li><strong>Report a summary</strong>: point estimate and uncertainty, e.g. posterior means and (co)variances.</li>
</ol>
<p>In this post, we will not be concerned with the first step, and instead focus on the last two steps: computing the posterior and reporting a summary. This means that we assume someone has already done the modeling for us and has asked us to report back something useful.</p>
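<p>As a concrete sketch of steps 2 and 3, here is a brute-force computation for a one-dimensional version of the penguin-height model above. All numbers are made up, and the variance is treated as known so that <script type="math/tex">\theta = \mu</script> stays one-dimensional:</p>

```python
import numpy as np

rng = np.random.default_rng(0)

# Step 1 (assumed given to us): heights ~ N(mu, sigma^2) with sigma
# known for simplicity, and a Gaussian prior mu ~ N(mu0, tau0^2).
sigma, mu0, tau0 = 5.0, 45.0, 10.0
y = rng.normal(50.0, sigma, size=30)  # synthetic "penguin heights"

# Step 2: compute the posterior over mu on a grid. This brute-force
# integration is only feasible because theta is one-dimensional --
# exactly the thing that breaks down in high dimensions.
grid = np.linspace(20.0, 80.0, 2001)
dx = grid[1] - grid[0]
log_prior = -0.5 * ((grid - mu0) / tau0) ** 2
log_lik = np.array([-0.5 * np.sum(((y - m) / sigma) ** 2) for m in grid])
log_post = log_prior + log_lik
post = np.exp(log_post - log_post.max())
post /= post.sum() * dx  # normalizing plays the role of dividing by P(y)

# Step 3: report a summary -- posterior mean and standard deviation.
post_mean = np.sum(grid * post) * dx
post_sd = np.sqrt(np.sum((grid - post_mean) ** 2 * post) * dx)
print(post_mean, post_sd)
```

<p>For this conjugate Gaussian model the grid answer matches the known closed-form posterior; the point of the sketch is only to make the three steps concrete.</p>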
<h1 id="the-problem-a-solution-and-a-faster-solution">The problem, a solution, and a faster solution</h1>
<p>Computing the posterior and reporting summaries is generally not a simple task. To see why, consider again the equation we use to compute the posterior:</p>
<script type="math/tex; mode=display">P(\theta \vert y_{1:n}) = \frac{P(y_{1:n} \vert \theta) P(\theta)}{P(y_{1:n})}.</script>
<p>One issue is that typically, <strong>there is no nice closed-form solution for this expression</strong>. Usually, we can’t solve it analytically or find a standard easy-to-use distribution. Another issue is that to find a point estimate such as the mean, or to compute the normalizing constant <script type="math/tex">P(y_{1:n})</script>, <strong>we need to integrate over a high-dimensional space</strong>. This is a fundamentally difficult problem, and is also known as the <a href="https://en.wikipedia.org/wiki/Curse_of_dimensionality">curse of dimensionality</a>. The problem is exacerbated by the large and high-dimensional datasets we have these days. (Why does calculating <script type="math/tex">P(y_{1:n})</script> involve integration? Because it is a marginal obtained by integrating out the parameters: <script type="math/tex">P(y_{1:n}) = \int P(y_{1:n}, \theta) d\theta</script>.)</p>
<p>This is where <strong>approximate Bayesian inference</strong> comes in. The gold standard in this area has been <strong>Markov Chain Monte Carlo (MCMC)</strong>. It is extremely widely used and has been called one of the top 10 most influential algorithms of the 20th century. <strong>MCMC is eventually accurate, but it is slow</strong>. In many applications, we simply don’t have the time to wait until the computation is finished.</p>
<p>A faster method is called <strong>Variational Inference (VI)</strong>. In this post, we’ll take a deeper dive into how this works.</p>
<h1 id="variational-inference-and-the-mean-field-variational-bayes-mfvb-framework">Variational Inference and the Mean Field Variational Bayes (MFVB) framework</h1>
<p>The main idea in VI is that <strong>instead of trying to find the real posterior distribution <script type="math/tex">p(\cdot \vert y)</script>, we approximate it with a distribution <script type="math/tex">q</script></strong>. Of course, not every distribution would be useful as an approximation. For some distributions, measures that we’d like to report, like the mean and variance, can’t be found. We therefore restrict our search to <script type="math/tex">Q</script>, the space of distributions that have certain “nice” properties. We’ll discuss later what “nice” means exactly. <strong>In this space, we search for a distribution <script type="math/tex">q^*</script> that minimizes a certain measure of dissimilarity to <script type="math/tex">p</script></strong>. Mathematically:</p>
<script type="math/tex; mode=display">q^* = argmin_{q\in Q} f(q(\cdot), p(\cdot \vert y)),</script>
<p>where <script type="math/tex">f</script> is some measure of dissimilarity.</p>
<h2 id="kl-divergence">KL-divergence</h2>
<p>There are many measures of dissimilarity to choose from, but one is particularly useful: the <strong>Kullback-Leibler divergence, or <a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence">KL-divergence</a></strong>.</p>
<p>KL-divergence is a concept from <strong>Information Theory</strong>. For distributions <script type="math/tex">p</script> and <script type="math/tex">q</script>, the KL-divergence is given by</p>
<script type="math/tex; mode=display">KL(p\ \vert\vert\ q) = \int p(x)\ln \frac{p(x)}{q(x)}dx.</script>
<p><a href="https://www.cs.cmu.edu/~epxing/Class/10708-17/notes-17/10708-scribe-lecture13.pdf">Intuitively, there are three cases of importance</a>:</p>
<ul>
<li>If <strong><script type="math/tex">p</script></strong> is <strong>high</strong> and <strong><script type="math/tex">q</script></strong> is <strong>high</strong>, then we are <strong>happy</strong> i.e. low KL-divergence.</li>
<li>If <strong><script type="math/tex">p</script></strong> is <strong>high</strong> and <strong><script type="math/tex">q</script></strong> is <strong>low</strong> then we <strong>pay a price</strong> i.e. high KL-divergence.</li>
<li>If <strong><script type="math/tex">p</script></strong> is <strong>low</strong> then <strong>we don’t care</strong> i.e. also low KL-divergence, <strong>regardless of <script type="math/tex">q</script></strong>.</li>
</ul>
<p>The following figure illustrates KL-divergence for two normal distributions <script type="math/tex">\pi_1</script> and <script type="math/tex">\pi_2</script>. A couple of things to note: divergence is indeed high when <script type="math/tex">p</script> is high and <script type="math/tex">q</script> is low; divergence is 0 when <script type="math/tex">p = q</script>; and the complete KL-divergence is given by the area under the green curve.</p>
<div class="figure">
<figure>
<img class="align-center" src="/assets/images/posts/2019/variational-bayes/KL-example.png" alt="image" style="width:; align:center;margin-bottom: 0.4em;" />
<figcaption class="align-center">
KL divergence between two normal distributions. In this example \(\pi_1\) is a standard normal distribution and \(\pi_2\) is a normal distribution with a mean of 1 and a variance of 1. The value of the KL divergence is equal to the area under the curve of the function. <a href="https://www.researchgate.net/publication/319662351_Using_the_Data_Agreement_Criterion_to_Rank_Experts'_Beliefs">(Source)</a> </figcaption>
</figure>
</div>
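<p>The value in the figure can also be checked numerically. A small sketch approximating the integral on a grid for <script type="math/tex">p = \mathcal{N}(0, 1)</script> and <script type="math/tex">q = \mathcal{N}(1, 1)</script>; for two equal-variance Gaussians the analytic value is <script type="math/tex">(\mu_p - \mu_q)^2 / 2\sigma^2 = 1/2</script>:</p>

```python
import numpy as np

# KL(p || q) for p = N(0, 1), q = N(1, 1), approximated on a grid.
x = np.linspace(-10.0, 11.0, 200001)
dx = x[1] - x[0]
p = np.exp(-0.5 * x**2) / np.sqrt(2 * np.pi)
q = np.exp(-0.5 * (x - 1) ** 2) / np.sqrt(2 * np.pi)
kl = np.sum(p * np.log(p / q)) * dx
print(kl)  # close to the analytic value of 0.5
```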
<p>Loosely speaking, KL-divergence can be interpreted as <strong>the amount of information that is lost when <script type="math/tex">q</script> is used to approximate <script type="math/tex">p</script></strong>. I won’t be going much deeper into it, but it has a couple of properties that are interesting for our purposes. The KL-divergence is:</p>
<ul>
<li>
<p><strong>Not symmetric</strong>: <script type="math/tex">KL(p\ \vert\vert\ q) \neq KL(q\ \vert\vert\ p)</script> in general. It can therefore not be interpreted as a distance measure, which is required to be symmetric.</p>
<p>We will be using the KL-divergence <script type="math/tex">KL(q\ \vert\vert\ p)</script>. It is possible to use the <strong>reverse KL-divergence</strong> <script type="math/tex">KL(p\ \vert\vert\ q)</script> as well. Let’s examine the differences.</p>
<p>In practical applications, the true posterior will often be a multimodal distribution. Minimizing KL-divergence leads to <strong>mode-seeking</strong> behavior, which means that most probability mass of the approximating distribution <script type="math/tex">q</script> <strong>is centered around a mode of <script type="math/tex">p</script></strong>. Minimizing reverse KL-divergence leads to <strong>mean-seeking</strong> behavior, which means that <script type="math/tex">q</script> would <strong>average across all of the modes</strong>. This would typically lead to poor predictive performance, since the average of two good parameter values is usually not a good parameter value itself. This is illustrated in the following figure.</p>
</li>
</ul>
<div class="figure">
<figure>
<img class="align-center" src="/assets/images/posts/2019/variational-bayes/KL-inclusive-exclusive.png" alt="image" style="width:; align:center;margin-bottom: 0.4em;" />
<figcaption class="align-center">
Minimizing \(KL(q\ \vert\vert\ p)\) versus \(KL(p\ \vert\vert\ q)\). The first (exclusive) leads to mode-seeking behavior, while the latter (inclusive) leads to mean-seeking behavior. (Source <a href="https://timvieira.github.io/blog/post/2014/10/06/kl-divergence-as-an-objective-function">Tim Vieira's blog</a>, figure by <a href="http://www.johnwinn.org/">John Winn</a>.)
</figcaption>
</figure>
</div>
<p>For a more in-depth discussion, see Tim Vieira’s blog post <a href="https://timvieira.github.io/blog/post/2014/10/06/kl-divergence-as-an-objective-function">KL-divergence as an objective function</a>.</p>
<ul>
<li><strong>Always <script type="math/tex">\geq 0</script></strong>, with equality only when <script type="math/tex">p = q</script>. Lower KL-divergence thus implies higher similarity.</li>
</ul>
<p>Most useful for us though is the following. We are optimizing <script type="math/tex">q</script> to be as close as possible to the real distribution <script type="math/tex">p</script>, but we don’t actually know <script type="math/tex">p</script>. <strong>How do we find a distribution close to <script type="math/tex">p</script> if we don’t even know what <script type="math/tex">p</script> itself is?</strong> It turns out that we can solve this problem by doing some algebraic manipulation. This is huge. Let’s derive the necessary expression:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{align}
KL(q\ \vert\vert\ p(\cdot \vert y)) :=& \int q(\theta)\log \frac{q(\theta)}{p(\theta \vert y)}d\theta \\
=& \int q(\theta)\log \frac{q(\theta)p(y)}{p(\theta , y)} d\theta \\
=&\ \log p(y)\int q(\theta)d\theta - \int q(\theta)\log \frac{p(y, \theta)}{q(\theta)} d\theta \\
=&\ \log p(y) - \int q(\theta)\log \frac{p(y \vert \theta) p(\theta)}{q(\theta)} d\theta.
\end{align} %]]></script>
<p>Here we use Bayes’ theorem to substitute out <script type="math/tex">p(\theta\vert y)</script> in the second line. Then, we use the property of logarithms <script type="math/tex">\log(ab) = \log(a) + \log(b)</script>, together with the fact that <script type="math/tex">p(y)</script> doesn’t depend on <script type="math/tex">\theta</script>, and that <script type="math/tex">\int q(\theta) d\theta = 1</script> since <script type="math/tex">q(\theta)</script> is a probability distribution over <script type="math/tex">\theta</script>, to arrive at the result. Phew, that was a whole mouthful.</p>
<p>Since <script type="math/tex">p(y)</script> is fixed, we only need to consider the second term, which has a name: the <strong>Evidence Lower Bound (ELBO)</strong>. We can see from the last equation why it is called this way. <script type="math/tex">KL(q\ \vert\vert\ p) \geq 0</script> implies <script type="math/tex">\log p(y) \geq \text{ ELBO}</script>. It is thus a lower bound on the log evidence <script type="math/tex">\log p(y)</script>.</p>
<p>To minimize KL-divergence, we thus need to maximize the ELBO. <strong>The ELBO depends on <script type="math/tex">p</script> only through the likelihood and prior, which we already know!</strong> This is something we can actually compute without having to know the real distribution!</p>
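<p>The identity <script type="math/tex">\log p(y) = \text{ELBO} + KL(q\ \vert\vert\ p(\cdot\vert y))</script> can be sanity-checked on a toy problem with a two-valued parameter (the joint probabilities below are made up):</p>

```python
import numpy as np

# Toy check of log p(y) = ELBO + KL(q || p(.|y)) for a 2-valued theta.
p_joint = np.array([0.3, 0.1])   # p(y, theta=0), p(y, theta=1): made up
p_y = p_joint.sum()              # evidence p(y)
p_post = p_joint / p_y           # exact posterior p(theta | y)

q = np.array([0.6, 0.4])         # an arbitrary approximating distribution
elbo = np.sum(q * np.log(p_joint / q))
kl = np.sum(q * np.log(q / p_post))
print(np.log(p_y), elbo + kl)    # equal: maximizing the ELBO minimizes KL
```

<p>Since <script type="math/tex">\log p(y)</script> is fixed, pushing the ELBO up must push the KL-divergence down by exactly the same amount.</p>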
<h2 id="mean-field-variational-bayes-mvfb">Mean Field Variational Bayes (MFVB)</h2>
<p>I promised to tell you what kinds of distributions we think of as “nice”. Firstly, we want to be able to report a mean and a variance, so these must exist. We then make <strong>the MFVB assumption</strong>, also known as <strong>Mean-Field Approximation</strong>. The approximation is a simplifying assumption for our distribution <script type="math/tex">q</script>, which <strong>factorizes the distribution into independent parts</strong>:</p>
<script type="math/tex; mode=display">q(\theta) = \prod_i q_i(\theta_i).</script>
<p>From a statistical physics point of view, “mean-field” refers to the relaxation of a difficult optimization problem to a simpler one which ignores second-order effects. The optimization problem becomes easier to solve.</p>
<p>Note that this is <strong>not a modeling assumption</strong>. We are <strong>not</strong> saying that the parameters in our model are <strong>independent</strong>, which would limit us only to uninteresting models. We are only saying that the parameters are independent in our <strong>approximation</strong> of the posterior.</p>
<p>We often also assume a distribution from the <strong>exponential family</strong>, since these have nice properties that make life easier.</p>
<p>Now that we have defined a space and metric to optimize over, <strong>we have a clearly defined optimization problem</strong>. At this point, we can use any optimization technique we’d like to find <script type="math/tex">q^*</script>. Typically, <a href="https://en.wikipedia.org/wiki/Coordinate_descent">coordinate descent</a> is used.</p>
<h1 id="when-can-we-trust-our-method">When can we trust our method</h1>
<p>Since Variational Inference is an approximate method, we’d like to know <strong>how accurate</strong> the approximation actually is. In other words, when we can trust it. If we schedule an ambulance based on the prediction that it will take 10 minutes to arrive, we have to be damn sure that our <strong>confidence in the prediction is justified</strong>.</p>
<p>One way to check whether the method works is to consider a simple example that we know the correct answer to. We can then see how well MFVB approximates that answer. To do this, we consider the (rather randomly chosen) problem of estimating midge wing length.</p>
<h2 id="estimating-midge-wing-length">Estimating midge wing length</h2>
<p><em>Note that understanding all the details in this example is not required for an understanding of the big picture.</em></p>
<p>Midge is a term used to refer to many species of small flies. Before we can compute a posterior, we need a model. As before, we assume that we are given a model, which is determined by a likelihood and a prior. We only need to worry about computing a posterior and reporting back a summary. The model given to us is the following.</p>
<p>Assume midge wing length is <strong>normally distributed</strong> with unknown mean <script type="math/tex">\mu</script> and unknown precision <script type="math/tex">\tau</script>. (We use precision, the inverse of variance, because it’s mathematically more convenient.) Let <script type="math/tex">y</script> be the midge wing length. We care about finding the <strong>posterior</strong></p>
<script type="math/tex; mode=display">p(\mu, \tau \vert y_{1:N}) \propto p(y_{1:N} \vert \mu, \tau) p(\mu, \tau).</script>
<p>The <strong>likelihood</strong> is then given by</p>
<script type="math/tex; mode=display">p(y_{1:N} \vert \mu, \tau) = \prod_i \mathcal{N}(y_i \vert \mu, \tau^{-1}),</script>
<p>where <script type="math/tex">\mathcal{N}(\cdot)</script> denotes the normal distribution.</p>
<p>The <a href="https://en.wikipedia.org/wiki/Conjugate_prior"><strong>conjugate prior</strong></a> for a Gaussian with unknown mean and variance is a <a href="https://en.wikipedia.org/wiki/Normal-gamma_distribution">Gaussian-Gamma distribution</a> given by</p>
<script type="math/tex; mode=display">p(\mu, \tau) = \mathcal{N}(\mu \vert \mu_0, (\beta\tau)^{-1})Gamma(\tau \vert a, b),</script>
<p>where <script type="math/tex">Gamma</script> is the <a href="https://en.wikipedia.org/wiki/Gamma_distribution">Gamma distribution</a> and <script type="math/tex">\mu_0, \beta, a, b</script> are <strong>hyperparameters</strong>.</p>
<p>To start solving the problem, we first use the <strong>mean-field assumption</strong> resulting in the factorization:</p>
<script type="math/tex; mode=display">q^*(\mu, \tau) = q_\mu^*(\mu)q_\tau^*(\tau) = argmin_{q\in Q_{MFVB}} KL(q(\cdot)\ \vert\vert\ p(\cdot \vert y)).</script>
<p>The factors <script type="math/tex">q_\mu^*(\mu)</script> and <script type="math/tex">q_\tau^*(\tau)</script> can be derived [Bishop 2006, Sec. 10.1.3]:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{align}
q_\mu^*(\mu) &= \mathcal{N}(\mu \vert m_\mu, \rho_\mu^2)\\
q_\tau^*(\tau) &= Gamma(\tau \vert a_\tau, b_\tau),
\end{align} %]]></script>
<p>where “variational parameters” <script type="math/tex">m_\mu, \rho_\mu^2, a_\tau, \text{ and } b_\tau</script> determine the approximating distribution.</p>
<p>We then <strong>iterate</strong>. First, we make an initial guess for the variational parameters. Then, we cycle through each factor. We find the approximating distribution of <script type="math/tex">\mu</script> given the distribution of <script type="math/tex">\tau</script> in one step and the approximating distribution of <script type="math/tex">\tau</script> given the distribution of <script type="math/tex">\mu</script> in another step:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{align}
(m_\mu, \rho_\mu^2) &= f(a_\tau, b_\tau)\\
(a_\tau, b_\tau) &= g(m_\mu, \rho_\mu^2).
\end{align} %]]></script>
<p>We repeat this procedure until convergence.</p>
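<p>A compact sketch of this coordinate loop, following the update equations in Bishop (2006, Sec. 10.1.3); the data, initialization, and (weak) hyperparameter values here are my own made-up choices:</p>

```python
import numpy as np

rng = np.random.default_rng(1)
y = rng.normal(2.0, 0.5, size=200)  # synthetic wing lengths: true tau = 4
N, ybar = len(y), y.mean()

# Made-up, weak hyperparameters of the Gaussian-Gamma prior.
mu0, beta, a0, b0 = 0.0, 1e-3, 1e-3, 1e-3

a_t, b_t = 1.0, 1.0  # initial guess for the parameters of q_tau
for _ in range(50):
    # (m, rho2) = f(a_t, b_t): update q_mu given E[tau] = a_t / b_t.
    m = (beta * mu0 + N * ybar) / (beta + N)
    rho2 = 1.0 / ((beta + N) * (a_t / b_t))
    # (a_t, b_t) = g(m, rho2): update q_tau given q_mu.
    a_t = a0 + (N + 1) / 2
    b_t = b0 + 0.5 * (np.sum((y - m) ** 2) + N * rho2
                      + beta * ((m - mu0) ** 2 + rho2))

print("E[mu] =", m, " E[tau] =", a_t / b_t)
```

<p>Note that in this conjugate setup <script type="math/tex">m_\mu</script> happens not to depend on <script type="math/tex">q_\tau</script>, so only <script type="math/tex">\rho_\mu^2, a_\tau,</script> and <script type="math/tex">b_\tau</script> actually change across iterations; in general, all variational parameters get updated in turn.</p>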
<p>The following figure shows how our approximation (blue) of the real posterior (green) gets more and more accurate by applying coordinate descent, resulting in quite a good approximation (red).</p>
<div class="figure">
<figure>
<img class="align-center" src="/assets/images/posts/2019/variational-bayes/variational_bayes_fig1.png" alt="image" style="width:; align:center;margin-bottom: 0.4em;" />
<figcaption class="align-center">The process of variational approximation to the Gaussian-gamma distribution. Our approximation (blue) of the real posterior (green) gets more and more accurate by applying coordinate descent, resulting in quite a good approximation (red) (Source: PRML, Bishop 2006)</figcaption>
</figure>
</div>
<h2 id="variance-underestimation">Variance underestimation</h2>
<p>One of the major problems that shows up is that <strong>the variational distribution often underestimates the variance of the real posterior</strong>. This is a result of minimizing the <strong>KL-divergence</strong>, which encourages a small value of the approximating distribution when the true distribution has a small value. This is showcased in the next figure, where MFVB is used to fit a multivariate Gaussian. The mean is correctly captured by the approximation, but the variance is severely underestimated. This gets progressively worse as the correlation between the two variables increases.</p>
<div class="figure">
<figure>
<img class="align-center" src="/assets/images/posts/2019/variational-bayes/variational_bayes_fig2.png" alt="image" style="width:; align:center;margin-bottom: 0.4em;" />
<figcaption class="align-center">The MFVB approximation of the true distribution, a multivariate Gaussian, severely underestimates the true variance. This gets worse as the correlation increases (Source: <a href="http://www.gatsby.ucl.ac.uk/~maneesh/papers/turner-sahani-2010-ildn.pdf">Turner, Sahani 2010</a>).</figcaption>
</figure>
</div>
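<p>For a bivariate Gaussian this effect can be written down exactly: with <script type="math/tex">KL(q\ \vert\vert\ p)</script> and a factorized Gaussian <script type="math/tex">q</script>, the optimal marginal variance is the inverse of the corresponding diagonal element of the precision matrix (a standard result, see e.g. Bishop 2006, Sec. 10.1.2), which shrinks as the correlation grows. A small sketch:</p>

```python
import numpy as np

# True posterior: zero-mean bivariate Gaussian with unit marginal
# variances and correlation rho. The optimal factorized-Gaussian q
# (under KL(q || p)) has marginal variance 1 / Lambda_ii, where
# Lambda = inv(Sigma) is the precision matrix.
for rho in [0.0, 0.5, 0.9, 0.99]:
    Sigma = np.array([[1.0, rho], [rho, 1.0]])
    q_var = 1.0 / np.linalg.inv(Sigma)[0, 0]
    print(f"rho={rho}: true marginal var = 1.0, MFVB var = {q_var:.4f}")
```

<p>Here the MFVB variance works out to <script type="math/tex">1 - \rho^2</script>, so at <script type="math/tex">\rho = 0.99</script> the true marginal variance of 1 is underestimated by a factor of about 50.</p>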
<p>Another way to test the validity of our approximations is to compare them to the answers of <strong>a method that we know works: MCMC</strong>. We can use this for more complex problems for which we cannot find an analytical solution. One real-life application deals with <strong>microcredit</strong>.</p>
<p>Microcredit is an initiative that helps impoverished people become self-employed or start a business by giving them extremely small loans. Of course, we’d like to know if this approach actually has a positive effect, how large this effect is, and <strong>how certain</strong> we are of that.</p>
<p>The next figure shows once more that MFVB estimates of microcredit effect variance indeed underestimate the true variance.</p>
<div class="figure">
<figure>
<img class="align-center" src="/assets/images/posts/2019/variational-bayes/variational_bayes_fig4.png" alt="image" style="width:; align:center;margin-bottom: 0.4em;" />
<figcaption class="align-center">Microcredit effect variance estimates given by MCMC versus MFVB. MFVB underestimates variance (Source: Giordano, Broderick, Jordan 2016).</figcaption>
</figure>
</div>
<h2 id="mean-estimates">Mean estimates</h2>
<p>MFVB not always getting the variance right raises another question: can estimates of the mean be incorrect too? The answer is yes, as demonstrated by the following two figures. In the first figure, MFVB estimates for the mean of a parameter <script type="math/tex">\nu</script> disagree with MCMC estimates. The same happens in the second figure, where some MFVB estimates even fall outside the 95% credible interval obtained by MCMC.</p>
<div class="figure">
<figure>
<img class="align-center" src="/assets/images/posts/2019/variational-bayes/variational_bayes_fig3.png" alt="image" style="width:60%; align:center;margin-bottom: 0.4em;" />
<figcaption class="align-center">MFVB estimates for the mean of a parameter disagree with MCMC estimate (Source: Giordano, Broderick, Jordan 2015).</figcaption>
</figure>
</div>
<div class="figure">
<figure>
<img class="align-center" src="/assets/images/posts/2019/variational-bayes/variational_bayes_fig5.png" alt="image" style="width:; align:center;margin-bottom: 0.4em;" />
<figcaption class="align-center">MFVB estimates disagree with MCMC estimates (Source: <a href="https://digital.lib.washington.edu/researchworks/bitstream/handle/1773/24305/Fosdick_washington_0250E_12238.pdf?sequence=1&isAllowed=y">Fosdick 2013)</a>.</figcaption>
</figure>
</div>
<h2 id="what-can-we-do">What can we do?</h2>
<p>We’ve seen that MFVB doesn’t always produce accurate approximations. What do we do then? Some major lines of research to alleviate this problem are:</p>
<ul>
<li>
<p><strong>Reliable diagnostics</strong>: <strong>Fast procedures</strong> that tell us <strong>after the fact</strong> whether the approximation is good. One way to achieve this might be a fast way to compute the KL-divergence of our approximation, since we know it is bounded below by 0. Usually, we only have access to the ELBO, for which we have no such bound.</p>
</li>
<li>
<p><strong>Richer “nice” set</strong>: In the MFVB framework, we only optimize over distributions that factorize. Considering a richer set of distributions might help. It turns out, though, that a richer nice set doesn’t necessarily yield better approximations, and it forces us to make other assumptions that complicate the problem.</p>
</li>
<li>
<p><strong>Alternative divergences</strong>: Minimizing divergences other than the KL-divergence might help, but this faces difficulties similar to those of enriching the nice set.</p>
</li>
<li>
<p><strong>Data compression</strong>: Before using an inference algorithm, we can consider doing a <strong>preprocessing step in which the data is compressed</strong>. We would like to have theoretical guarantees on the quality of our inference methods on this compressed dataset.</p>
</li>
</ul>
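<p>To make the diagnostics idea concrete: in the rare case where the posterior is tractable, the KL-divergence can be computed exactly, which shows the size of the gap a practical diagnostic would need to detect. The sketch below is an illustration under the assumption of a Gaussian posterior, not a general-purpose diagnostic — it evaluates the closed-form Gaussian KL for the optimal mean-field approximation of a correlated bivariate Gaussian.</p>

```python
import numpy as np

# Assumed setup: true "posterior" is a zero-mean correlated bivariate
# Gaussian; q is the optimal mean-field Gaussian (matching means, factor
# precisions equal to the diagonal of the precision matrix).
rho = 0.9
d = 2
Sigma = np.array([[1.0, rho],
                  [rho, 1.0]])               # true covariance
Lambda = np.linalg.inv(Sigma)                # true precision
Sigma_q = np.diag(1.0 / np.diag(Lambda))     # mean-field covariance

# Closed-form KL(q || p) for two zero-mean Gaussians:
# 0.5 * (tr(Lambda @ Sigma_q) - d + ln det(Sigma) - ln det(Sigma_q))
kl = 0.5 * (np.trace(Lambda @ Sigma_q) - d
            + np.log(np.linalg.det(Sigma))
            - np.log(np.linalg.det(Sigma_q)))
print(kl)
```

<p>For <script type="math/tex">\rho = 0.9</script> this simplifies to <script type="math/tex">-\tfrac{1}{2}\ln(1 - \rho^2) \approx 0.83</script> — a substantial gap, yet one we could not read off the ELBO alone, since the log evidence is unknown in practice.</p>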
<p>Until we have a foolproof way to test for the reliability of the estimates obtained by MFVB, it is important to be wary of the results obtained, as they may not always be correct.</p>
<h1 id="summary">Summary</h1>
<p>We’ve discussed how Bayesian variational inference works. By framing inference as an optimization problem, we can obtain results much faster than with classical MCMC methods. This speed comes at a price: we don’t always know when our approximation is accurate. That remains very much an open problem that researchers are working on.</p>
<h2 id="related-publications">Related Publications</h2>
<p>If you would like to read further, here are some <strong>related publications</strong>:</p>
<ul>
<li>Bishop. Pattern Recognition and Machine Learning, Ch 10. 2006.</li>
<li>Blei, Kucukelbir, McAuliffe. Variational Inference: A Review for Statisticians. JASA, 2016.</li>
<li>MacKay. Information Theory, Inference, and Learning Algorithms, Ch 33. 2003.</li>
<li>Murphy. Machine Learning: A Probabilistic Perspective, Ch 21. 2012.</li>
<li>Ormerod, Wand. Explaining Variational Approximations. The American Statistician, 2010.</li>
<li>Turner, Sahani. Two problems with variational expectation maximisation for time-series models. In Bayesian Time Series Models, 2011.</li>
<li>Wainwright, Jordan. Graphical models, exponential families, and variational inference. Foundations and Trends in Machine Learning, 2008.</li>
</ul>
<p><strong>More Experiments</strong>:</p>
<ul>
<li>RJ Giordano, T Broderick, and MI Jordan. Linear response methods for accurate covariance estimates from mean field variational Bayes. NIPS 2015.</li>
<li>RJ Giordano, T Broderick, R Meager, J Huggins, and MI Jordan. Fast robustness quantification with variational Bayes. ICML Data4Good Workshop 2016.</li>
<li>RJ Giordano, T Broderick, and MI Jordan. Covariances, robustness, and variational Bayes. Under review, arXiv:1709.02536, 2017.</li>
</ul>
<p><em>Thanks for reading! Any thoughts? Leave them in the comments below!</em></p>
<h1 id="hello-world">Hello World!</h1>
<p><em>2018-10-28, <a href="https://omarelb.github.io/hello-world">https://omarelb.github.io/hello-world</a></em></p>
<p>Hi there! This is my first blog post ever! I want to start writing about topics in artificial intelligence and machine learning.</p>
<p>I was inspired to start blogging by several other bloggers, including <a href="https://medium.com/@racheltho/why-you-yes-you-should-blog-7d2544ac1045">Rachel Thomas</a>, <a href="https://machinelearningmastery.com/about/">Jason Brownlee</a>, and <a href="https://ayearofai.com/0-2016-is-the-year-i-venture-into-artificial-intelligence-d702d65eb919">Rohan and Lenny</a>. The key reasons for doing this are:</p>
<ul>
<li>Building up a portfolio,</li>
<li>deepening my own understanding.</li>
</ul>
<p>Around a year ago, I began to grow increasingly interested in Artificial Intelligence and Computer Science — so much so that I decided to apply for a master’s degree in the field. My background in Econometrics and Operations Research has given me a solid base in statistical modelling and optimization, which I believe is of core importance to many applications in AI.</p>
<p>There are still far too many subjects, however, that I don’t know much about. By writing about them and explaining concepts as simply as I can, I hope to truly deepen my knowledge, improve my technical writing, and help make things clear for others at the same time.</p>
<p>I want to make learning about Artificial Intelligence a project, a journey. There are so many exciting advances in the field that there is a lot to catch up on. Real-world applications are becoming increasingly widespread as well, even hitting headlines regularly, making it a science that’s interesting even at parties. I guess it’s the perfect subject for the phrase: <em>“Let’s make nerds cool again.”</em></p>
<p>This post is the start of a journey for me, so let’s make it an interesting ride! Also, I love meeting new people, so don’t hesitate to contact me!</p>