<p><em>Omar’s Blog: a blog with topics on Artificial Intelligence. Omar Elbaghdadi, omarelb[at]gmail[dot]com.</em></p>
<h1 id="dqn-investigation">Divergence in Deep Q-Learning: Two Tricks Are Better Than (N)one</h1>
<p><em>Published 2020-10-24 at <a href="https://omarelb.github.io/dqn-investigation">https://omarelb.github.io/dqn-investigation</a>.</em></p>
<p><em>By <a href="https://github.com/edudev/">Emil Dudev</a>, <a href="https://github.com/AmanDaVinci">Aman Hussain</a>, <a href="https://omarelb.github.io">Omar Elbaghdadi</a>, and <a href="https://www.linkedin.com/in/ivan-bardarov">Ivan Bardarov</a>.</em></p>
<p>Deep Q Networks (DQN) revolutionized the Reinforcement Learning world. It was the first algorithm able to learn a successful strategy in a complex environment directly from high-dimensional image inputs. In this blog post, we investigate how some of the techniques introduced in the original paper contributed to its success. Specifically, we investigate to what extent <strong>memory replay</strong> and <strong>target networks</strong> help prevent <strong>divergence</strong> in the learning process.</p>
<!--more-->
<p>Reinforcement Learning (RL) has already been around for a while, but it is not even close to being solved yet. While <em>supervised learning</em> can already be quite difficult, RL methods also need to deal with changes in the data distribution, huge state spaces, partial observability, and various other challenges. In 2013, the paper <a href="https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf">Playing Atari with Deep Reinforcement Learning (Mnih et al.)</a> introduced <strong>DQN, the first RL method to successfully learn good policies directly from high-dimensional inputs using neural networks</strong>. The algorithm performs better than human experts in several Atari games, learning directly from image input.</p>
<!--- ![pic alt](/assets/images/posts/2020/dqn-investigation/space_invaders_games_2.png "opt title") -->
<div class="figure">
<figure>
<img class="align-center" src="/assets/images/posts/2020/dqn-investigation/space_invaders_games_2.png" alt="image" style="width:; align:center;margin-bottom: 0.4em;" />
<figcaption class="align-center">
Screenshots from three Atari 2600 Games: (Left-to-right) Pong, Breakout, and Space Invaders.
</figcaption>
</figure>
</div>
<p>The DQN authors improve on DQN in their <a href="https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf">2015 paper</a>, introducing additional techniques to stabilize the learning process. In this post, we take a look at the two key innovations of DQN, <strong>memory replay</strong> and <strong>target networks</strong>. We run our own experiments, investigating to what degree each of these techniques helps avoid <strong>divergence</strong> in the learning process. When divergence occurs, the quality of the learned strategy has a high chance of being destroyed, which we want to avoid. Studying the conditions of divergence also allows us to get a better insight into the learning dynamics of \(Q\)-learning with neural network function approximation.</p>
<p>The rest of this post is outlined as follows:</p>
<ul>
<li>We first develop a little bit of the <strong>background</strong>, briefly going into RL, \(Q\)-learning, function approximation with neural networks, and the DQN algorithm.</li>
<li>We then give a definition of <strong>divergence</strong>, which we use in our experiments.</li>
<li>We describe the <strong>experimental setup</strong>,</li>
<li>after which we <strong>discuss</strong> the results.</li>
</ul>
<!--- - While RL has been around for a while, first time shown to work well with high-dimensional sensory input in 2013 (or 2015) by DQN paper.
- they did this by successfully playing multiple Atari games using the same learning framework, even beating human expert players in some of them.
- function approximation with neural networks had been around for a while, but never succeeded
- DQN introduced some tricks that helped: Experience replay memory, and target networks
- In this blog post, we explore to what extent each of the techniques introduced by DQN contributed to its success
- Specifically, we investigate to what extent each of the techniques avoid divergence in the learning process.
- This in turn gives us insight into the learning dynamics of Q-learning with neural network function approximation. A better understanding of these learning dynamics allows us to focus research on the most promising methods and give us insight into the more important aspects of learning. -->
<h2 id="background">Background</h2>
<p>In this post, we will just give a brief overview of the main techniques, and not go too deep into all the background theory. If you want to dig deeper, we suggest checking out <a href="https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf">the original paper</a>.</p>
<h3 id="reinforcement-learning">Reinforcement Learning</h3>
<!--- - 2 or 3 sentences about RL in general
In Reinforcement Learning (RL), an agent learns to take good actions by optimizing a scalar reward given by its environment. The agent learns to map the current state of the world to a probability distribution over its actions, which we call a policy. -->
<p>In RL, we study an <strong>agent</strong> interacting with some <strong>environment</strong>. The agent learns to take good actions by optimizing a <strong>scalar reward</strong> given by its environment. The agent learns to map the current state of the world, \(s\), to a probability distribution over its actions \(\pi(a \mid s)\), which we call a <strong>policy</strong>. In an Atari game, the game is the environment, and the player is the agent who is trying to maximize their score by learning a good policy.</p>
<p>The environment provides us with a reward signal at every point in time. We care about getting the maximum cumulative reward over time, the <strong>return</strong>. At any timestep \(t\), we can define the future return as:
\(\begin{align}
G_{t} :=\ &r_t + \gamma r_{t+1} + \ldots + \gamma^{T - t}r_T = \sum_{t'=t}^T \gamma^{t'-t}r_{t'}\\
=\ &r_t + \gamma G_{t + 1}, & (1)
\end{align}\)</p>
<p>where \(r_t\) is the reward at time \(t\), \(T\) is the time-step where the <strong>episode</strong> terminates, and \(0 \leq \gamma \leq 1\) is the <strong>discount rate</strong>. The discount rate is used to control how much we care about future rewards, with higher values looking farther into the future. An episode can be seen as one instance of learning. In the Atari world, an episode is one round of playing before a game over. Equation 1 provides us with a very important identity for learning later on.</p>
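<p>The recursion in Equation 1 is easy to verify numerically. The following sketch (ours, with a made-up reward sequence) computes the discounted return at every timestep of an episode by working backwards from the terminal step:</p>

```python
# Discounted return G_t for every timestep of an episode, computed
# backwards via the recursion G_t = r_t + gamma * G_{t+1} (Equation 1).
def discounted_returns(rewards, gamma=0.99):
    returns = [0.0] * len(rewards)
    g = 0.0
    for t in reversed(range(len(rewards))):
        g = rewards[t] + gamma * g
        returns[t] = g
    return returns

# Hypothetical 3-step episode with gamma = 0.9:
# G_2 = 1, G_1 = 0 + 0.9 * 1 = 0.9, G_0 = 1 + 0.9 * 0.9 = 1.81
print(discounted_returns([1.0, 0.0, 1.0], gamma=0.9))
```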
<p>Since we don’t know what rewards we are going to get in the future, we have to work with the <strong>expected</strong> future (discounted) return. This leads us to \(Q\)-values, defined as the expected future return, given that we take action \(a\) in state \(s\) and follow policy \(\pi\) afterwards:</p>
\[Q^\pi(s, a) := \mathbb{E}_\pi[G_t \mid S_t = s, A_t = a].\]
<p>The expectation is with respect to \(\pi\), since it determines (along with the environment) which states are visited, and in turn which rewards are obtained.</p>
<h3 id="q-learning">Q-Learning</h3>
<p>If we can learn these \(Q\)-values, we know which actions yield the best returns, allowing us to optimize our policy. One technique based on this principle is <strong>\(Q\)-Learning</strong>. In \(Q\)-learning, we learn the optimal \(Q\)-values directly from experienced environment transitions \((s, a, r, s')\), where \(s'\) is the state following \(s\) after taking action \(a\). The following update rule is used:</p>
\[Q(s, a) \leftarrow Q(s, a) + \alpha (r + \gamma \max_{a'} Q(s', a') - Q(s, a)), (2)\]
<p>where \(\alpha\) is a learning rate parameter controlling learning speed. This update pushes the current \(Q\)-values \(Q(s, a)\) towards their <strong>bootstrap targets</strong> \(r + \gamma \max_{a'}Q(s', a')\). The sample transitions can be generated using <em>any</em> policy, such as an <a href="https://medium.com/analytics-vidhya/the-epsilon-greedy-algorithm-for-reinforcement-learning-5fe6f96dc870">epsilon-greedy policy</a>, making \(Q\)-learning an <a href="https://stats.stackexchange.com/questions/184657/what-is-the-difference-between-off-policy-and-on-policy-learning"><strong>off-policy</strong></a> method.</p>
<!--- In most realistic scenarios such as playing Atari games, we can't store $$Q$$-values for every possible state, as the state space is too large. It is therefore usually necessary to **approximate** the $$Q$$-values. -->
<!--- - q learning
- definitions return, q-learning objective -->
<h3 id="function-approximation">Function Approximation</h3>
<p>In most realistic scenarios, the state space is too large to store \(Q\)-values for. Imagine mapping an Atari game state to a \(Q\)-value directly from image data. Assuming 24-bit RGB pixel values (\(256^3\) possible colors per pixel) and an 84x84 pixel screen, we would need to store \((256^3)^{84\cdot84}\) values, one for each possible screen configuration. Besides this impracticality, a tabular representation would also generalize poorly between similar states, since it treats every distinct pixel configuration as unrelated and captures no latent structure.</p>
<p>Therefore, <strong>function approximation</strong> is used to predict \(Q\)-values using some learned function, given a state or state-action pair. This allows \(Q\)-values to be represented in a compressed form (the parameters) and generalization over similar states.</p>
<p>In DQN, the \(Q\) update is a little bit different than described in Equation 2, since it uses function approximation with parameters \(\theta\), i.e. \(Q(s,a) = Q(s, a; \theta)\). It is roughly equivalent<sup id="fnref:semi-gradient" role="doc-noteref"><a href="#fn:semi-gradient" class="footnote">1</a></sup> to minimizing the mean squared error between the target \(r + \gamma \max_{a'} Q(s', a')\) and the current \(Q\)-value using <a href="https://towardsdatascience.com/stochastic-gradient-descent-clearly-explained-53d239905d31">stochastic gradient descent</a>:</p>
\[\begin{align*}
\theta^{t+1} &\leftarrow \theta^t +
\\
&\alpha [(r + \gamma \max_{a'} Q(s', a'; \theta^t) - Q(s, a; \theta^t)) \nabla_{\theta^t} Q(s, a; \theta^t)], & (3)
\end{align*}\]
<p>where \(Q\) is implemented as a neural network. While neural networks can learn very complex dynamics, they are also notoriously unstable. This instability prevented neural networks (and other complex function approximators) from being used successfully in RL for quite some time. That is, until DQN proposed several techniques to combat this instability, including <strong>experience replay</strong> and <strong>target networks</strong>.</p>
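<p>To make the semi-gradient nature of Equation 3 concrete, here is a sketch of ours using a linear approximator \(Q(s, a; \theta) = \theta_a^\top s\) as a toy stand-in for the neural network. Note that no gradient flows through the bootstrap target:</p>

```python
import numpy as np

# Semi-gradient update (Equation 3) with a linear approximator
# Q(s, a; theta) = theta[a] @ s. The bootstrap target is treated as a
# constant: no gradient flows through Q(s', .), as noted in footnote 1.
def semi_gradient_step(theta, s, a, r, s_next, alpha=0.01, gamma=0.99):
    target = r + gamma * np.max(theta @ s_next)  # held fixed during the update
    td_error = target - theta[a] @ s
    theta[a] += alpha * td_error * s  # gradient of Q(s, a; theta) w.r.t. theta[a] is s
    return td_error

theta = np.zeros((2, 4))  # 2 actions, 4 state features
rng = np.random.default_rng(0)
s, s_next = rng.normal(size=4), rng.normal(size=4)
semi_gradient_step(theta, s, a=0, r=1.0, s_next=s_next)
```

With \(\theta = 0\), the TD error is exactly \(r = 1\), so only the row for the taken action is nudged in the direction of the state features.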
<!--- For a long time, linear models were the go-to function approximator, since they are theoretically relatively straightforward to study. However, these models are in many cases too simple to accurately capture complex system dynamics. A next obvious option was using **neural networks**. -->
<h3 id="experience-replay">Experience Replay</h3>
<p>We’ve seen that DQN learns \(Q\)-values using neural networks. This can be seen as supervised learning. In this paradigm, a key assumption is that data is independently and identically distributed (i.i.d.). In RL, however, this does not hold. Subsequent states are highly correlated, and the data distribution changes as the agent learns. To deal with this, DQN saves the last \(N\) experienced transitions in a memory of finite capacity \(N\). When performing a \(Q\)-value update, it uses experiences sampled at random from this memory.</p>
<p>The idea of sampling randomly is to <strong>break the correlation</strong> between updated experiences, increasing sample efficiency and reducing variance. The authors also argue that the technique helps by avoiding unwanted feedback loops, and that averaging the behavior distribution over many previous states smooths out learning and helps avoid divergence.</p>
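<p>A minimal replay memory can be sketched as follows (our illustration, not the original implementation); a deque with a maximum length gives the finite-capacity, last-\(N\) behavior for free:</p>

```python
import random
from collections import deque

# Minimal replay memory: keep the last N transitions and sample
# uniformly at random to break temporal correlations between updates.
class ReplayMemory:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)  # oldest transitions are evicted

    def push(self, s, a, r, s_next, done):
        self.buffer.append((s, a, r, s_next, done))

    def sample(self, batch_size):
        return random.sample(list(self.buffer), batch_size)

    def __len__(self):
        return len(self.buffer)

memory = ReplayMemory(capacity=3)
for t in range(5):  # push 5 transitions; only the last 3 are kept
    memory.push(t, 0, 1.0, t + 1, False)
```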
<h3 id="target-networks">Target Networks</h3>
<p>In the parameter update given by Equation 3, the \(Q\) network produces both the \(Q\)-value for the current state and the <strong>target</strong>: \(r + \gamma \max_{a'} Q(s', a'; \theta^t)\). However, after the parameters of the network are updated, the target value changes as well. This is like asking the network to learn to throw a bull’s eye, but then moving the dart board somewhere else. This leads to instability.</p>
<p>To tackle this problem, DQN proposes using a <strong>target network</strong>. The idea is to compute the target using a (target) network that is not updated for some amount of time-steps. That way, the targets don’t “move” during training. Every \(C\) time-steps, the target network is synchronized with the current \(Q\) network.</p>
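<p>The mechanism boils down to keeping a frozen copy of the parameters and refreshing it every \(C\) steps. A sketch of ours, with parameters represented as plain dictionaries:</p>

```python
import copy

# Target-network bookkeeping: the target parameters are a frozen copy of
# the Q-network parameters, refreshed (hard update) every C steps.
def maybe_sync_target(q_params, target_params, step, C):
    if step % C == 0:
        return copy.deepcopy(q_params)  # synchronize with the current Q network
    return target_params                # leave the frozen copy untouched

q_params = {"w": [0.5, -0.2]}
target_params = {"w": [0.0, 0.0]}
target_params = maybe_sync_target(q_params, target_params, step=400, C=400)
```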
<!--- - function approximation
- neural networks
- dqn's
- what is different about dqn's
- way of modelling
- stack last 4 frames
- most importantly: 2 techniques
- experience replay
- store samples in memory and sample.
- Why? Break correlations. ML methods require iid data.
- target networks
- one network stays fixed for some period of time, this is the target network
- why? this stabilizes learning -->
<h2 id="divergence">Divergence</h2>
<p>Our goal is to find out to what extent the two techniques mentioned above help deal with divergence in the learning process. Divergence occurs when the \(Q\)-function approximator learns unrealistically high values for state-action pairs, in turn destroying the quality of the greedy control policy derived from \(Q\) <a href="http://arxiv.org/abs/1812.02648">(Van Hasselt et al.)</a>.</p>
<p>For most environments, we don’t know the true Q-values. How do we know when divergence occurs then? <a href="http://arxiv.org/abs/1812.02648">Van Hasselt et al.</a> use a clever trick to define <strong>soft divergence</strong>, a proxy for divergence. To avoid instability, DQN clips all rewards to the range \([-1, 1]\). Thus, the future return at some state is bounded by:</p>
\[\sum_{t'=t}^T \gamma^{t'-t}|r_{t'}| \leq \sum_{t'=t}^\infty \gamma^{t'-t}|r_{t'}| \leq \sum_{t'=t}^\infty \gamma^{t'-t} = \frac{1}{1-\gamma}, (4)\]
<p>where the last equality is a general result for geometric series. This means that any \(Q\)-value is theoretically bounded by (4). <strong>If the maximum absolute \(Q\)-value exceeds this bound, we say that soft divergence occurs.</strong></p>
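<p>The resulting check is a one-liner. For \(\gamma = 0.99\), the bound in (4) is \(1 / (1 - 0.99) = 100\) (a sketch of ours):</p>

```python
# Soft-divergence check: with rewards clipped to [-1, 1], any |Q| is
# bounded by 1 / (1 - gamma) (Equation 4); exceeding the bound flags
# soft divergence.
def soft_diverged(max_abs_q, gamma=0.99):
    return max_abs_q > 1.0 / (1.0 - gamma)  # bound is 100 for gamma = 0.99

print(soft_diverged(85.0), soft_diverged(250.0))
```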
<!--- - goal of this blog post is to find out to what extent each of these techniques help to deal with divergence
- if the networks diverge, we are most likely not learning anything meaningful.
- Every state-action value is assumed to exist and be finite. If the algo doesn't converge, it means we are not in
a local or global optimum.
- Defining divergence
- For most environments, we don't know the true Q-values. How do we know when divergence occurs then?
- intuition: if some state-action pairs get assigned unrealistically high values, we say there is **soft divergence**.
- when are values too high? Show discount_factor / max q value calculation.
- reward clipping -->
<h2 id="experimental-setup">Experimental setup</h2>
<p>We try to follow the experimental setup of the <a href="https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf">DQN paper</a> wherever possible. Even though the authors use a convolutional neural network to play Atari games, we limit ourselves to a simpler architecture given our computation and time constraints. We use a <strong>fully-connected</strong> neural network with <strong>a single hidden layer (excluding input and output layers) of size 128</strong>, mapping from input states to a discrete set of actions. We use <strong>ReLU</strong> activation functions at each layer before the output layer.</p>
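<p>As an illustration, the forward pass of this architecture can be sketched in NumPy (our sketch; the initialization scheme shown is an arbitrary choice for the example, not the one used in our experiments):</p>

```python
import numpy as np

# Forward pass of the network used in our experiments: fully connected,
# one hidden layer of 128 ReLU units, linear output with one Q-value
# per action.
def init_params(state_dim, n_actions, hidden=128, seed=1):
    rng = np.random.default_rng(seed)
    return {
        "W1": rng.normal(0.0, 0.1, (hidden, state_dim)), "b1": np.zeros(hidden),
        "W2": rng.normal(0.0, 0.1, (n_actions, hidden)), "b2": np.zeros(n_actions),
    }

def q_values(params, s):
    h = np.maximum(0.0, params["W1"] @ s + params["b1"])  # ReLU hidden layer
    return params["W2"] @ h + params["b2"]                # Q(s, a) for every action a

params = init_params(state_dim=4, n_actions=2)  # e.g. Cart Pole: 4 features, 2 actions
q = q_values(params, np.ones(4))
```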
<p>Given our constraints, we only look at relatively simple, computationally inexpensive environments. <strong>We explore three environments</strong>, so that our results are not specific to any one of them. Each environment is chosen to be of a different (estimated) difficulty, as we consider this an important distinction in our context. We consider classical control theory environments made available by <a href="https://gym.openai.com/envs/#classic_control">OpenAI gym</a>:</p>
<h3 id="environment-1-cart-pole">Environment 1: <a href="https://gym.openai.com/envs/CartPole-v1/">Cart Pole</a></h3>
<p><img src="/assets/images/posts/2020/dqn-investigation/cartpole.gif" alt="cartpole" /></p>
<p>In the Cart Pole environment, the agent tries to balance a pole on a cart by applying a rightward or a leftward force. For every time step the pole remains upright (less than 15 degrees from vertical), the agent receives a reward of +1. Since this problem is considered relatively easy to solve, we chose it as a representative of problems with low difficulty.</p>
<h3 id="environment-2-acrobot">Environment 2: <a href="https://gym.openai.com/envs/Acrobot-v1/">Acrobot</a></h3>
<p><img src="/assets/images/posts/2020/dqn-investigation/acrobot.gif" alt="acrobot" /></p>
<p>In the Acrobot environment, the agent tries to swing up a two-link robot arm above the base by applying a clockwise or anti-clockwise torque. This problem is considered more difficult than the previous one, so we select it as a representative of problems with mid-level difficulty.</p>
<h3 id="environment-3-mountain-car">Environment 3: <a href="https://gym.openai.com/envs/MountainCar-v0/">Mountain Car</a></h3>
<p><img src="/assets/images/posts/2020/dqn-investigation/mountaincar.gif" alt="mountaincar" /></p>
<p>In the Mountain Car environment, the agent starts a car at the bottom of a valley and tries to drive it up the right hill. However, the car’s engine is not strong enough to do so in a single pass. Instead, it has to go back and forth between the left and right hills to build momentum. This problem is quite challenging, so we choose it as a representative of problems with high-level difficulty.</p>
<h3 id="experiments-and-hyperparameters">Experiments and hyperparameters</h3>
<p>Since divergence can now be quantified, we use it as a metric to compare which algorithms exhibit more divergence than others. <strong>We say an algorithm exhibits more divergence if the fraction of runs in which soft divergence occurs is higher.</strong> We refer to memory replay and target networks as DQN’s “tricks”. The improvement that each of the tricks brings to DQN is measured against the <strong>baseline</strong> model, DQN without tricks, or <em>vanilla</em> DQN.
<em>We thus compare 4 different setups for each environment</em>: without tricks (vanilla agent), with memory replay (memory agent), with a target network (target agent), and with both tricks (DQN / memory+target agent).</p>
<p>We run each experiment with <strong>random seeds from 1 to 25</strong> to achieve more statistically sound results, while taking into account our computational budget. If the maximal absolute Q-value in any of the last 20 training episodes is above the threshold \(\frac{1}{1-\gamma}\), we say soft divergence occurs.
<!--- At the end, we compare the configurations by counting how many times each of them has diverged. --></p>
<p>All the agents are <strong>trained for 700 episodes</strong> which we found to be enough for them to learn to win the games. For better exploration, we use an <strong>\(\epsilon\)-greedy</strong> strategy which is <strong>linearly annealed from 1 to 0.1 during the first 400 episodes</strong>, and kept fixed after. The discount factor is <strong>\(\gamma = 0.99\)</strong> for all the environments.</p>
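<p>The annealing schedule described above can be written as a small function (our sketch):</p>

```python
# Epsilon-greedy exploration schedule: epsilon is annealed linearly from
# 1.0 to 0.1 over the first 400 episodes, then held fixed at 0.1.
def epsilon(episode, eps_start=1.0, eps_end=0.1, anneal_episodes=400):
    frac = min(episode / anneal_episodes, 1.0)
    return eps_start + frac * (eps_end - eps_start)

print(epsilon(0), epsilon(200), epsilon(400), epsilon(700))
```

Halfway through annealing (episode 200) this gives 0.55, the midpoint of the two endpoints.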
<p>Another hyperparameter is the <strong>frequency of target network updates</strong> (whenever the technique is used). We empirically find <strong>400, 2000, 2000</strong> to work well for <strong>Mountain Car, Cart Pole and Acrobot respectively</strong>. No extensive hyperparameter search was done, since the focus of our work is not state-of-the-art performance but comparing the importance of the tricks. The values of <strong>the parameters are selected manually for the configuration with no tricks</strong> and kept fixed for all other configurations of the respective environment.</p>
<p>Similar to the original paper, we use the <strong>mean-squared error (MSE)</strong> loss between the predicted and bootstrap \(Q\)-values. Clipping the error term to the range \([-1, 1]\) has been reported to improve the training stability of DQN. We do this for all environments except Cart Pole, which achieves better results without the clipping. The error is optimized by <a href="https://arxiv.org/pdf/1412.6980.pdf">Adam</a> with a <strong>learning rate \(\alpha = 0.001\)</strong>. The choice of optimizer deviates from the original paper but has shown great success in deep learning recently. Additional experiments with different values of the learning rate and the contribution of error clipping are left for future work.</p>
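<p>Our reading of the clipping step can be sketched as follows (ours; the scaling factor is an arbitrary choice for the example): the TD error term is clipped to \([-1, 1]\) before being squared, so a single large error contributes a bounded loss:</p>

```python
# Sketch of error clipping: the TD error is clipped to [-1, 1] before
# being squared, so one large error cannot dominate the loss.
def clipped_td_loss(q_pred, q_target):
    td_error = max(-1.0, min(1.0, q_target - q_pred))
    return 0.5 * td_error ** 2

print(clipped_td_loss(0.0, 5.0))  # large error, clipped: loss is 0.5
```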
<p><em>The code used in all our experiments <a href="https://github.com/VaniOFX/DQN-divergence">can be found on Github</a>.</em></p>
<!--- - Evaluating the different techniques
- how we evaluate the techniques
- run each setup for X runs
- measure the fraction of times that soft divergence occurs
- we do this by tracking the max absolute q value. If this is larger than X, we say divergence occurs
- we do this for ? runs, because ??
- the less divergence occurs, the more we say a technique helps avoiding divergence
- explanation on environments
- we need to do many runs to get some statistically significant results
- we don't have the time and resources to investigate computationally expensive atari games
- therefore, we investigate environments that are relatively simple and computationally inexpensive
- we want enough environments such that we have divergence and convergence on each setup
- the following hyperparameters are important
- optimizer type
- learning rate
- discount factor
- reward clipping
- gradient clipping
- …
We run each setup X times, and report the fraction of runs at which soft divergence occurs. We set the amount of runs to ?? to ensure statistically significant results, while taking our computational budget into account.
We try to make sure our experimental setup coincides with the DQN implementation as much as possible. Due to computational constraints, we unfortunately can't run any experiments on Atari games. Instead, we investigate the following simpler environments: Cart-Pole, Mountain Car, Inverse Pendulum, ... . We want enough environments such that we have divergence and convergence on each setup.
We use the following hyperparameter settings in all our experiments:
- We use an epsilon-greedy exploration strategy, where epsilon is linearly annealed over ?? steps to 0.05, after which it stays at that level.
- Learning rate $$\alpha = x$$
- Adam optimizer
- reward clipping to range [-1, 1]
- gradient clipping to x
- discount factor x
- we try to stick to the original paper as much as possible -->
<h2 id="results">Results</h2>
<p>Our main results can be summarized by the figures below.
Each figure displays a scatter plot for one environment, where <strong>each point represents one training run</strong>. Each point’s x-coordinate is given by its max |\(Q\)|, which can be used to identify soft divergence. The y-coordinate shows its average <em>return</em> in the last 20 episodes, indicating the performance achieved in that run.
This allows us to analyze the effect of the tricks on divergence and overall performance, as well as how these interact, at the same time.
We first discuss the obtained results for each environment separately, from which we draw more general conclusions.</p>
<h3 id="mountain-car">Mountain Car</h3>
<!--- <div style="display: flex; width: 100%; height: 100%;">
<img style="width:33%;" src="./img/MountainCar-v0_rewards_q_values.png">
<img style="width:33%;" src="./img/Acrobot-v1_rewards_q_values.png">
<img style="width:33%;" src="./img/CartPole-v1_rewards_q_values.png">
</div> -->
<!--- TODO: add titles to the plots, identifying the experiment -->
<p>To begin with, let’s look at the Mountain Car results below.</p>
<p><img src="/assets/images/posts/2020/dqn-investigation/MountainCar-v0_rewards_q_values.png" alt="image" title="Mountain Car results" /></p>
<p>The vanilla agent diverges and fails miserably at learning a good policy.
The memory agent also performs badly for most runs, but does learn a good policy in a small number of runs.
Specifically <strong>for the runs where the memory agent does not diverge, it actually obtains a good overall return.</strong>
This is an interesting observation, as it suggests that our measure of divergence is indeed predictive of final performance for this environment.</p>
<p>The target agent has managed to eliminate divergence completely, but the policy it learns is poor. <strong>Not diverging is clearly not a guarantee for good performance.</strong>
As one would expect, the network with both tricks enabled performs best.
It does not diverge and consistently achieves high rewards.
However, even the DQN agent has runs on which it doesn’t learn anything.
This goes to show that out of the tasks we explore, Mountain Car is relatively difficult.</p>
<h3 id="acrobot">Acrobot</h3>
<p>We now go over the results for the Acrobot environment. For clarity, we use a log scale for the Q values here.</p>
<p><img src="/assets/images/posts/2020/dqn-investigation/Acrobot-v1_rewards_q_values.png" alt="image" title="Acrobot results" /></p>
<p>As with Mountain Car, the vanilla network is the worst out of all configurations here.
Again, it diverges heavily and doesn’t learn any meaningful policy.
On the other hand, <strong>we observe that the memory agent manages to find good policies, despite exhibiting soft divergence</strong>. The variance of its return is higher than that of the other methods, indicating that the learning process is not that stable.
<strong>This suggests that the amount of soft divergence, our proxy for divergence, is not fully indicative of how well an algorithm learns.</strong></p>
<p>We see again that using both tricks alleviates divergence and leads to high returns. If just the target network is used, divergence is again controlled, but the learned policy is still worse than that of using both tricks.</p>
<h3 id="cart-pole">Cart Pole</h3>
<p>The last environment we look at is the Cart Pole environment.</p>
<p><img src="/assets/images/posts/2020/dqn-investigation/CartPole-v1_rewards_q_values.png" alt="image" title="Cart Pole results" /></p>
<p>Despite both the vanilla and memory agents exhibiting soft divergence, they still manage to learn good policies. Interestingly, although the memory agent shows the most divergence, it achieves a higher average return than the other settings.</p>
<p>In line with the previous results, having a target network greatly reduces soft divergence. However, its average return is now even lower than that of the vanilla agent.
Once more, using both tricks controls soft divergence and allows learning good policies, but the memory agent does perform better in this case.</p>
<h3 id="putting-things-into-perspective">Putting Things into Perspective</h3>
<p>So what did we learn from our experiments?
In each of the three environments we explore, the <strong>vanilla agent (soft) diverges every single time.
The target network trick significantly helps in reducing this divergence</strong> as well as the variance of the max |\(Q\)|.
In fact, not a single run diverged when making use of a target network.
<strong>Without the target network, divergence seems almost inevitable.</strong>
This is made especially clear by the below figure, which zooms in on the distributions of the max |\(Q\)| (in log scale). The dotted line indicates the soft divergence threshold.</p>
<div style="display: flex; width: 100%; margin-bottom: 1cm;">
<img style="width:33%;" src="/assets/images/posts/2020/dqn-investigation/violinplot_q_divergence_MountainCar-v0_0.99.png" />
<img style="width:33%;" src="/assets/images/posts/2020/dqn-investigation/violinplot_q_divergence_Acrobot-v1_0.99.png" />
<img style="width:33%;" src="/assets/images/posts/2020/dqn-investigation/violinplot_q_divergence_CartPole-v1_0.99_no_error_clamping.png" />
</div>
<p><strong>For the Acrobot environment, the memory agent is able to learn good policies even when it shows divergence. The same holds for the memory and vanilla agents in the Cart Pole environment.</strong> This contrasts with the findings in the Mountain Car environment, where the memory agent only learns a good policy when it doesn’t diverge. It appears that divergence has a larger impact on performance in some environments than in others. There are many possible explanations for this, among which:</p>
<ul>
<li>We hypothesize that the <strong>difficulty of a task</strong> is an important factor in this process. In the simplest environment, Cart Pole, divergence doesn’t seem to be an issue in terms of performance. In the harder environments, however, divergence does seem to affect the quality of the policies. In Acrobot, the variance of the memory agent is very high, and its performance is lower compared to the DQN agent as well. <strong>In the Mountain Car environment, the agent didn’t manage to learn anything for every single run that diverged.</strong> It might be that as the task grows more difficult, having accurate Q value estimates becomes more important.</li>
<li>Another possibility is that our proxy metric for measuring divergence, max |\(Q\)|, is too noisy. It is calculated by tracking this quantity over each update transition encountered during the last 20 episodes. <strong>Taking the maximum is not robust to outliers</strong>. If a single high value is encountered in one state, while most of the states are well behaved, this may give a very skewed picture of the divergence in a training run.</li>
</ul>
<p>Another important insight is that <strong>adding memory replay improves performance in all our experiments</strong>. The target agent is always improved by adding the memory replay mechanism (resulting in the DQN agent). This corroborates the findings of the original DQN paper, which say that memory replay leads to a better realization of the i.i.d. data assumption, subsequently allowing gradient descent to find a better optimum.</p>
<p><strong>In short, target networks prevent divergence in the learning process. While memory replay does not prevent divergence, it is an important technique that guides the search towards good policies. Combining both tricks gives us the best of both worlds — a controlled divergence setup with good Q-value estimates.</strong></p>
<h2 id="some-final-remarks">Some Final Remarks</h2>
<p>It is always good practice to look critically at obtained results. In this final section, we highlight some <strong>limitations</strong> of our approach:</p>
<ul>
<li>Given our constraints on computation and time, we do not do an exhaustive <strong>hyperparameter search</strong> over our 3 chosen environments. We focused on varying the discount factor and target network update frequency, yet even for those we considered only a few values. This means that the observed behavior might be different had we chosen different sets of hyperparameters. Ideally, we would want to average results over more hyperparameter settings.</li>
<li>Relating to the previous point, we <strong>only use a pretty shallow neural network of 2 layers in all our experiments</strong>. This might cause all methods to have an even harder time learning a difficult task such as the Mountain Car task.</li>
<li>We evaluate 25 seeds per setup. While this is better than 1, we would ideally want to have more seeds to base conclusions on, given the high variance of reinforcement learning methods.</li>
<li>We choose to use a proxy for divergence, soft divergence. <strong>Despite this proxy being theoretically well-motivated, it is still a proxy.</strong> We don’t know how it relates exactly to “actual” divergence.</li>
<li>As mentioned in the previous section, our metric for soft divergence might not be very robust to outliers. Future studies could look at more robust versions of the metric.</li>
</ul>
<p>The conclusion that we come to above is not completely unexpected, but the fact that memory replay doesn’t prevent divergence is definitely an interesting insight. Thank you for reading!</p>
<!--- We also looked at the effects of clipping the error term during training, and
understood that netither of tricks is useful without the clipping in the simple environments we tested. -->
<p><strong>Footnotes</strong></p>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:semi-gradient" role="doc-endnote">
<p>The true gradient contains an extra term, and usually does not work very well. Instead, semi-gradient methods, which don’t backpropagate through the target Q function \(Q(s', \cdot)\), are usually found to work better. <a href="#fnref:semi-gradient" class="reversefootnote" role="doc-backlink">↩</a></p>
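<p>To make the semi-gradient idea concrete, here is a minimal NumPy sketch for a linear Q-function. The function and shapes are illustrative, not from the original paper:</p>

```python
import numpy as np

def semi_gradient_update(w, s, a, r, s_next, gamma=0.99, lr=0.1):
    """One semi-gradient Q-learning step for a linear Q(s, a) = w[a] @ s.

    The target r + gamma * max_a' Q(s', a') is treated as a constant:
    we do NOT differentiate through Q(s', .), which is the "semi" part.
    """
    target = r + gamma * np.max(w @ s_next)  # no gradient flows through this
    td_error = target - w[a] @ s
    w = w.copy()
    w[a] += lr * td_error * s                # gradient only through Q(s, a)
    return w

w = np.zeros((2, 3))                         # 2 actions, 3 state features
s = np.array([1.0, 0.0, 0.0])
s_next = np.array([0.0, 1.0, 0.0])
w = semi_gradient_update(w, s, a=0, r=1.0, s_next=s_next)
```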
</li>
</ol>
</div>Omar Elbaghdadiomarelb[at]gmail[dot]comBy Emil Dudev, Aman Hussain, Omar Elbaghdadi, and Ivan Bardarov. Deep Q Networks (DQN) revolutionized the Reinforcement Learning world. It was the first algorithm able to learn a successful strategy in a complex environment immediately from high-dimensional image inputs. In this blog post, we investigate how some of the techniques introduced in the original paper contributed to its success. Specifically, we investigate to what extent memory replay and target networks help prevent divergence in the learning process.Self-Explaining Neural Networks: A Review2020-03-25T00:00:00+00:002020-03-25T00:00:00+00:00https://omarelb.github.io/self-explaining-neural-networks<p>For many applications, understanding <em>why</em> a predictive model makes a certain prediction can be of crucial importance. In the paper <a href="http://papers.nips.cc/paper/8003-towards-robust-interpretability-with-self-explaining-neural-networks">“Towards Robust Interpretability with Self-Explaining Neural Networks”</a>, David Alvarez-Melis and Tommi Jaakkola propose a neural network model that takes interpretability of predictions into account <em>by design</em>. In this post, we will look at how this model works, how reproducible the paper’s results are, and how the framework can be extended.</p>
<!--more-->
<p>First, a bit of context. This blog post is a by-product of a student project that was
done with a group of 4 AI grad students: <a href="https://github.com/AmanDaVinci">Aman Hussain</a>, <a href="https://github.com/ChristophHoenes">Chris Hoenes</a>, <a href="https://www.linkedin.com/in/ivan-bardarov/">Ivan Bardarov</a>, and me. The goal of
the project was to find out how reproducible the results of the aforementioned paper are, and what ideas could be extended. To
do this, we re-implemented the framework from scratch. We have released this
implementation as a package <a href="https://github.com/AmanDaVinci/SENN">here</a>. In
general, we find that the explanations generated by the framework are not very
interpretable. However, the authors provide several valuable ideas, which we use to propose and evaluate several improvements.</p>
<p>Before we dive into the nitty gritty details of the model, let’s talk about why
we care about explaining our predictions in the first place.</p>
<h2 id="transparency-in-ai">Transparency in AI</h2>
<p>As ML systems become more omnipresent in societal applications such as banking and healthcare, it’s crucial that they satisfy important criteria such as safety, non-discrimination, and the ability to explain algorithmic decisions. The last criterion is enforced by data policies such as the <a href="https://en.wikipedia.org/wiki/General_Data_Protection_Regulation">GDPR</a>, which states that people subject to algorithmic decisions have a right to an explanation of those decisions.</p>
<p>These criteria are often hard to quantify. Instead, a proxy notion is regularly made use of: <em>interpretability</em>. The idea is that if we can explain <em>why</em> a model is making the predictions it does, we can check whether that reasoning is <em>reliable</em>. Currently, there is not much agreement on what the definition of interpretability should be or how to evaluate it.</p>
<p>Much recent work has focused on interpretability methods that try to understand a model’s inner workings <strong>after</strong> it has been trained. Well known examples are <a href="https://github.com/marcotcr/lime">LIME</a> and <a href="https://github.com/slundberg/shap">SHAP</a>. Most of these methods make no assumptions about the model to be explained, and instead treat them like a black box. This means they can be used on <em>any</em> predictive model.</p>
<p>Another approach is to design models that are transparent <strong>by design</strong>. This approach is also taken by Alvarez-Melis and Jaakkola, who I will be referring to as ‘‘the authors’’ from now on. They propose a self-explaining neural network (SENN) that optimises for transparency <strong>during the learning process</strong>, (hopefully) without sacrificing too much modeling power.</p>
<h2 id="self-explaining-neural-networks-senn">Self Explaining Neural Networks (SENN)</h2>
<p>Before we design a model that generates explanations, we must first think about
what properties we want these explanations to have. The authors suggest that
explanations should have the following properties:</p>
<ul>
<li><strong>Explicitness</strong>: The explanations should be <em>immediate</em> and <em>understandable</em>.</li>
<li><strong>Faithfulness</strong>: Importance scores assigned to features should be <em>indicative of “true” importance</em>.</li>
<li><strong>Stability</strong>: Similar examples should yield <em>similar explanations</em>.</li>
</ul>
<h3 id="linearity">Linearity</h3>
<p>The authors start out with a linear regression model and generalize from there.
The plan is to generalize the linear model by allowing it to be more complex, while
retaining the interpretable properties of a linear model.
This approach is motivated by arguing that a linear model is inherently interpretable.
Despite this claim’s enduring popularity, it is not always entirely valid<sup id="fnref:2" role="doc-noteref"><a href="#fn:2" class="footnote">1</a></sup>.
Regardless, we continue with the linear model.</p>
<p>Say we have input features \(x_1, \dots, x_n\)
and parameters \(\theta_1, \dots, \theta_n\). The linear model (omitting bias for clarity) returns the following
prediction:
\(f(x) = \sum_{i=1}^{n} \theta_i x_{i}.\)</p>
<h3 id="basis-concepts">Basis Concepts</h3>
<p>The first step towards interpretability is taken by first computing
<em>interpretable feature representations</em> \(h(x)\) of the input \(x\), which are called
<strong>basis concepts</strong> (or concepts). Instead of acting on the input directly, the model acts on these basis concepts:
\(f(x) = \sum_{i=1}^{k} \theta_i h(x)_{i}.\)</p>
<!-- reasoning --------------------------------- -->
<!-- Raw input features are the natural basis for interpretability when the input is low-dimensional and -->
<!-- individual features are meaningful. For high-dimensional inputs, raw features (such as individual -->
<!-- pixels in images) often lead to noisy explanations that are sensitive to imperceptible artifacts in the -->
<!-- data, tend to be hard to analyze coherently and not robust to simple transformations such as constant -->
<!-- shifts [9]. Furthermore, the lack of robustness of methods that relies on raw inputs is amplified for -->
<!-- high-dimensional inputs, as shown in the next section. To avoid some of these shortcomings, we can -->
<!-- instead operate on higher level features. In the context of images, we might be interested in the effect -->
<!-- of textures or shapes—rather than single pixels—on predictions. -->
<p>We see that the final prediction is some linear combination of interpretable
concepts. The \(\theta_i\) can now be interpreted as importance or <strong>relevance
scores</strong> for a certain concept \(h(x)_i\).</p>
<p>Say we are given an image \(x\) of a digit and we want to detect which digit it
is. Then each concept \(h(x)_i\) might for instance encode stroke width,
global orientation, position, roundness, and so on.</p>
<p><img src="/assets/images/posts/2020/self-explaining-neural-networks/concepts.png" alt="concepts" title="opt title" /></p>
<p>Concepts like this could be generated by domain experts, but this is expensive and in many cases infeasible. An alternative approach is to <a href="https://arxiv.org/abs/1711.11279">learn the concepts directly</a>. It cannot be stressed enough that <em>the interpretability of the whole framework depends on how interpretable the concepts are</em>. Therefore, the authors propose <strong>three properties concepts should have</strong> and how to enforce them:</p>
<ol>
<li>
<p><em>Fidelity</em>: The representation of \(x\) in terms of concepts should <strong>preserve relevant information</strong>.</p>
<p>This is enforced by learning the concepts \(h(x)\) as the latent encoding of an <a href="https://www.jeremyjordan.me/autoencoders/">autoencoder</a>. An autoencoder is a neural network that learns to map an input \(x\) to itself by first encoding it into a <strong>lower dimensional representation</strong> with an encoder network \(h\) and then creating a reconstruction \(\hat{x}\) with a decoder network \(h_\mathrm{dec}\), i.e. \(\hat{x} = h_\mathrm{dec}(h(x))\). The lower dimensional representation \(h(x)\), which we call its <em>latent representation</em>, is a vector in a space we call the latent space. <strong>It therefore needs to capture the most important information contained in \(x\)</strong>. This can be thought of as a nonlinear version of <a href="http://setosa.io/ev/principal-component-analysis/">PCA</a>.</p>
</li>
<li>
<p><em>Diversity</em>: Inputs should be representable with <strong>few non-overlapping concepts</strong>.</p>
<p>This is enforced by making the autoencoder mentioned above <strong>sparse</strong>. A <a href="https://medium.com/@syoya/what-happens-in-sparse-autencoder-b9a5a69da5c6">sparse autoencoder</a> is one in which <strong>only a relatively small subset of the latent dimensions activate for any given input</strong>. While this indeed forces an input to be representable with <em>few</em> concepts, it does not really guarantee that they should be <em>non-overlapping</em>.</p>
</li>
<li>
<p><em>Grounding</em>: Concepts should have an <strong>immediate human-understandable interpretation</strong>.</p>
<p>This is a more subjective criterion. The authors aim to do this by providing interpretations for concepts by <a href="https://arxiv.org/abs/1710.04806"><strong>prototyping</strong></a>. For image data, they find a set of observations that <strong>maximally activate</strong> a certain concept and use those as the <em>representation</em> for that concept. While this may seem reasonable at first, we will see later that this approach is quite problematic.</p>
</li>
</ol>
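<p>The fidelity and diversity criteria can be sketched together as a tiny sparse autoencoder loss. This is an illustrative NumPy sketch with made-up shapes and random weights, not the authors' implementation:</p>

```python
import numpy as np

rng = np.random.default_rng(0)

def encode(x, W_enc):
    """h(x): map inputs to k concept activations."""
    return np.tanh(x @ W_enc)

def decode(h, W_dec):
    """h_dec: reconstruct the input from the concepts."""
    return h @ W_dec

def concept_loss(x, W_enc, W_dec, sparsity=1e-2):
    """Reconstruction error (fidelity) plus an L1 penalty that pushes
    most concept activations toward zero (sparsity, aimed at diversity)."""
    h = encode(x, W_enc)
    x_hat = decode(h, W_dec)
    recon = np.mean((x - x_hat) ** 2)
    return recon + sparsity * np.mean(np.abs(h))

x = rng.standard_normal((8, 16))      # batch of 8 inputs, 16 features
W_enc = rng.standard_normal((16, 5))  # k = 5 concepts
W_dec = rng.standard_normal((5, 16))
loss = concept_loss(x, W_enc, W_dec)
```

<p>In practice both weight matrices would be learned by minimizing this loss with gradient descent; the sketch only shows how the two terms combine.</p>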
<p>As mentioned above, the approaches taken by the authors to achieve diversity and grounding have
quite some problems. The extension that we introduce aims to mitigate these.</p>
<h3 id="keeping-it-linear">Keeping it Linear</h3>
<p>In the last section, we introduced basis concepts. However, the model is still
too simple, since the relevance scores \(\theta_i\) are fixed. We therefore
make the next generalization: the relevance scores are now also a function of
the input. This leads us to the final model<sup id="fnref:most_general" role="doc-noteref"><a href="#fn:most_general" class="footnote">2</a></sup>:</p>
\[f(x) = \sum_{i=1}^{k} \theta(x)_i h(x)_{i}.\]
<p>To make the model sufficiently complex, the function computing relevance scores
\(\theta\) is actualized by a neural network. A problem with neural networks,
however, is that they are not very stable. A small change to the input may lead
to a large change in relevance scores. This goes against the <em>stability</em>
criterion for explanations introduced earlier. To combat this, the authors propose
to <strong>make \(\theta\) behave linearly in local regions, while still being sufficiently
complex globally</strong>. This means that a small change in \(h\) should lead to only a small
change in \(\theta\). To do this, they add a regularization term, which we call the
robustness loss<sup id="fnref:robustness_loss" role="doc-noteref"><a href="#fn:robustness_loss" class="footnote">3</a></sup>, to the loss function.</p>
<p>In a classification setting with multiple classes, relevance
scores \(\theta\) are estimated for each class separately. An <strong>explanation</strong> is
then given by the concepts and their corresponding relevance scores<sup id="fnref:product" role="doc-noteref"><a href="#fn:product" class="footnote">4</a></sup>.
The figure below shows an example of such an explanation.</p>
<p><img src="/assets/images/posts/2020/self-explaining-neural-networks/senn_explanation.jpg" alt="pic alt" title="opt title" /></p>
<p>The authors interpret such an explanation in the following way. Looking at the
first input, a “9”, we see that concept 3 has a positive contribution to the
prediction that it is a 9. Concept 3 seems to represent a horizontal dash
together with a diagonal stroke, which seems to be present in the image of a 9.
As you may have noticed, these explanations also require some imagination.</p>
<h2 id="implementation">Implementation</h2>
<p>Although a public implementation of SENN <em>is</em> <a href="https://github.com/dmelis/senn">available</a>, the authors have not officially released that code with the paper. There also seems to be a major bug in this code. Therefore, we re-implement the framework with the original paper as ground truth.</p>
<p>On a high level, the SENN model consists of three main building blocks: a <em>parameterizer</em> \(\theta\), a <em>conceptizer</em> \(h\), and an <em>aggregator</em> \(g\), which is just the sum function in this case. The parameterizer is actualized by a neural network, and the conceptizer is actualized by an autoencoder. The specific implementations of these networks may vary. For tabular data, we may use fully connected networks and for image data, we may use convolutional networks. Here is an overview of the model:
<img src="/assets/images/posts/2020/self-explaining-neural-networks/teaser.png" alt="pic alt" title="opt title" /></p>
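<p>Putting the pieces together, a forward pass can be sketched as follows. The toy stand-ins for the learned conceptizer and parameterizer (and all names and shapes) are ours, purely for illustration:</p>

```python
import numpy as np

def senn_forward(x, conceptizer, parameterizer):
    """SENN prediction: sum the concepts h(x) weighted by the
    input-dependent relevances theta(x) (aggregator g = sum)."""
    h = conceptizer(x)         # shape (k,): concept activations
    theta = parameterizer(x)   # shape (n_classes, k): relevances per class
    logits = theta @ h         # class c score: sum_i theta(x)_{ci} * h(x)_i
    return logits, (h, theta)  # concepts + relevances form the explanation

# toy stand-ins for the learned networks (illustrative only)
conceptizer = lambda x: np.tanh(x[:3])                 # k = 3 concepts
parameterizer = lambda x: np.outer(x[:2], np.ones(3))  # 2 classes

x = np.array([0.5, -1.0, 2.0])
logits, (h, theta) = senn_forward(x, conceptizer, parameterizer)
```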
<p>To train the model, the following loss function is minimized:</p>
\[\begin{align}
\mathcal{L} := \mathcal{L}_y(f(x), y) + \lambda \mathcal{L}_\theta(f(x)) + \xi \mathcal{L}_h(x), \label{eq:total_loss}
\end{align}\]
<p>where</p>
<ul>
<li>\(\mathcal{L}_y(f(x), y)\) is the <em>classification loss</em>, i.e. how well the model predicts the ground truth label.</li>
<li>\(\mathcal{L}_\theta(f(x))\) is the <em>robustness loss</em>. \(\lambda\) is a regularization parameter controlling how heavily robustness is enforced.</li>
<li>\(\mathcal{L}_h(x)\) is the <em>concept loss</em>. The concept loss is a sum of 2 different losses: <em>reconstruction loss</em> and <em>sparsity loss</em>. \(\xi\) is a regularization parameter on the concept loss.</li>
</ul>
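<p>In code, the combined objective is just a weighted sum of these terms. The loss values and regularization weights below are illustrative, not the paper's values:</p>

```python
def senn_loss(class_loss, robustness_loss, recon_loss, sparsity_loss,
              lam=1e-4, xi=1.0):
    """Total SENN objective: classification loss, plus lambda times the
    robustness loss, plus xi times the concept loss
    (reconstruction + sparsity)."""
    concept_loss = recon_loss + sparsity_loss
    return class_loss + lam * robustness_loss + xi * concept_loss

total = senn_loss(class_loss=0.7, robustness_loss=2.0,
                  recon_loss=0.1, sparsity_loss=0.05)
```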
<h2 id="reproducibility">Reproducibility</h2>
<p>If a paper’s results cannot be reproduced, it is hard to verify their validity. Even if the results are valid, it
will be hard to build upon that work. Reproducibility has therefore recently been an
important topic in many scientific disciplines. In AI research, there
have been initiatives to move towards more reproducible research. NeurIPS, one of the
biggest conferences in AI, introduced a <a href="https://www.cs.mcgill.ca/~jpineau/ReproducibilityChecklist.pdf">reproducibility
checklist</a> last
year. Although it is not mandatory, it is a step in the right direction.</p>
<h3 id="results">Results</h3>
<p>The main goal of our project was to reproduce a subset of the results
presented by the authors. We find that while we are able to reproduce some results,
<em>we are not able to reproduce all of them</em>.</p>
<p>In particular, we are able to <strong>achieve similar test accuracies</strong> on the Compas and MNIST
datasets, a tabular and image dataset respectively, which seems to validate the authors’
claim that SENN models have high modelling capacity. That said,
the MNIST and Compas datasets are of low complexity. The relatively high performance
on these datasets is therefore <em>not sufficient to show that SENN models are on par
with non-explaining state-of-the-art models</em>.</p>
<p>We also look at the <strong>trade-off between robustness of explanations and model
performance</strong>. We see that classification accuracy decreases as robustness is enforced more heavily.
This behavior matches up with the authors’ findings. However, we find
<strong>the accuracy drop for increasing regularization to be significantly larger
than reported by the authors</strong>.</p>
<p>Assessing the <strong>quality of explanations</strong> is inherently subjective. It is therefore difficult
to link the quality of explanations to reproducibility. However, we can still
partially judge reproducibility by qualitative analysis. In general, we see that <strong>finding
an example whose explanation “makes sense” is difficult</strong> and that such an example is
<em>not representative of the generated examples</em>. Therefore, we conclude that obtaining
good explanations is, in that sense, not reproducible.</p>
<h3 id="interpretation">Interpretation</h3>
<p>How do we interpret this failure to reproduce? Although it is tempting to
say that the initial results are invalid, this might be too harsh and
unfair. There are many factors that influence an experiment’s results.
Although our results are discouraging, more experiments are needed
to better assess the validity of the framework.
Even so, the authors provide
a useful starting point for further research. We take a first step in this
direction by improving the learned concepts.</p>
<h2 id="disentangled-senn-disenn">Disentangled SENN (DiSENN)</h2>
<p>As we have seen, the current SENN framework hardly generates interpretable explanations.
An important part of the problem is <strong>the way concepts are represented and learned</strong>.
A single concept is represented by a set of data samples that maximizes that concept’s activation.
The authors reason that we can read off what the concept represents by examining these samples. However, this approach is
problematic.</p>
<p>Firstly, only showcasing the samples that maximize an encoder’s activation is quite
arbitrary. One could just as well showcase samples that <em>minimize</em> the activation
instead, or use any other method. These different approaches lead to <em>different interpretations
even if the learned concepts are the same</em>. The authors also hypothesize that a sparse
autoencoder will lead to <strong>diverse</strong> concepts. However, enforcing only a subset
of dimensions to activate for an input <strong>does not explicitly enforce that these concepts
should be non-overlapping</strong>. For digit recognition, one concept
may for example represent both thickness and roundness at the same time. That is, each
single concept may represent a mix of features that are <em>entangled</em> in some complex
way. This greatly increases the complexity of interpreting any concept on its own. An image, for example, is formed from the complex interaction of light sources, object shapes, and their material properties. We call each of these component features a <em>generative factor</em>.</p>
<p>To enhance concept interpretability, we therefore propose to explicitly enforce <strong>disentangling
the factors of variation in the data (generative factors)</strong> and using these as concepts instead.
Different explanatory factors of the data tend to change independently
of each other in the input distribution, and only a few at a time
tend to change when one considers a sequence of consecutive
real-world inputs. Matching a single generative factor to a single latent dimension
allows for easier human interpretation<sup id="fnref:representation_learning" role="doc-noteref"><a href="#fn:representation_learning" class="footnote">5</a></sup>.</p>
<p>We enforce disentanglement by using
a <a href="https://openreview.net/forum?id=Sy2fzU9gl">\(\beta\)-VAE</a>, a variant
of the <a href="https://arxiv.org/abs/1312.6114">Variational Autoencoder (VAE)</a>,
to learn the concepts. If you are unfamiliar with VAEs, I suggest reading <a href="https://lilianweng.github.io/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html">this blog post</a>.
Whereas a normal autoencoder maps an input to a single point in the latent
space, a VAE maps an input to a <em>distribution</em> in the latent space. We want,
and enforce, this latent distribution to look “nice” in some sense. A unit Gaussian
distribution is generally the go-to choice for this. This is done by minimizing the <a href="https://www.countbayesie.com/blog/2017/5/9/kullback-leibler-divergence-explained">KL-divergence</a> between
a Gaussian and the learned latent distribution.</p>
<p>\(\beta\)-VAE introduces a hyperparameter \(\beta\) that enables a heavier regularization on the
latent distribution (i.e. higher KL-divergence penalty). The higher \(\beta\), the more the latent space will
be encouraged to look like a unit Gaussian. Since the dimensions
of a unit Gaussian are independent, these latent factors will be
encouraged to be independent as well.</p>
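<p>Concretely, for a diagonal Gaussian posterior the KL term has a closed form, and \(\beta\) simply scales it. A small NumPy sketch (illustrative, not our training code):</p>

```python
import numpy as np

def beta_vae_kl(mu, logvar, beta=4.0):
    """beta * KL(N(mu, sigma^2) || N(0, I)) for a diagonal Gaussian.

    Per latent dimension the KL is 0.5 * (mu^2 + sigma^2 - log sigma^2 - 1);
    beta > 1 pushes the posterior harder toward the unit Gaussian, whose
    independent dimensions encourage disentangled latent factors.
    """
    kl_per_dim = 0.5 * (mu ** 2 + np.exp(logvar) - logvar - 1.0)
    return beta * kl_per_dim.sum()

# An encoding that already matches the prior incurs zero penalty:
beta_vae_kl(np.zeros(2), np.zeros(2))  # → 0.0
```

<p>In a full \(\beta\)-VAE objective this term is added to the reconstruction loss, with \(\beta\) chosen as a hyperparameter.</p>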
<p>Each latent space dimension corresponds to a concept. Since interpolation in the latent
space of a VAE can be done meaningfully, a concept can be represented by
varying its corresponding dimension’s values and plugging it into the decoder<sup id="fnref:disenn" role="doc-noteref"><a href="#fn:disenn" class="footnote">6</a></sup>.
It is easier to understand visually:</p>
<p align="center">
<img width="460" height="300" src="/assets/images/posts/2020/self-explaining-neural-networks/numbers_small.gif" />
</p>
<p>We see here that the first row represents the width of the number, while the
second row represents a number’s angle. <a href="https://github.com/YannDubs/disentangling-vae">This
figure</a> is created by mapping
an image to the latent space, changing one dimension in this latent space and
then seeing what happens to the reconstruction.</p>
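<p>The traversal itself is simple to sketch. With a trained decoder in hand (here replaced by a toy function of our own making), we vary one latent dimension around the encoding and decode each point:</p>

```python
import numpy as np

def latent_traversal(decoder, mu, dim, span=2.0, steps=5):
    """Vary one latent dimension around its encoded value while keeping
    the others fixed, decoding each point to visualize that concept."""
    values = np.linspace(mu[dim] - span, mu[dim] + span, steps)
    frames = []
    for v in values:
        z = mu.copy()
        z[dim] = v                # move along a single concept dimension
        frames.append(decoder(z))
    return np.stack(frames)

toy_decoder = lambda z: np.concatenate([z, z ** 2])  # stand-in decoder
frames = latent_traversal(toy_decoder, mu=np.zeros(3), dim=1)
```

<p>Plotting the returned frames side by side produces figures like the animation above.</p>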
<p>We thus have a principled way to generate prototypes, since meaningful latent space
traversal is an inherent property of VAEs. Another advantage is that prototypes are
not constrained to observed data points. The prototypes
generated by DiSENN are more complete than highest-activation prototypes,
since they showcase a much larger portion of a concept dimension’s latent space. Seeing
the transitions in concept space provides a more intuitive idea of what the concept
means.</p>
<p>We call a disentangled SENN model with \(\beta\)-VAE as the conceptizer a <strong>DiSENN</strong>
model. The following figure shows an overview of the DiSENN model. Only the
conceptizer subnetwork has really changed.</p>
<p><img src="/assets/images/posts/2020/self-explaining-neural-networks/disenn.png" alt="pic alt" title="opt title" /></p>
<p>We now examine the DiSENN explanations by analyzing a generated DiSENN
explanation for the digit 7.</p>
<p><img src="/assets/images/posts/2020/self-explaining-neural-networks/DiSENN_explanation.png" alt="pic alt" title="opt title" /></p>
<p>The contribution of concept \(i\) to the prediction of a class \(c\) is given by the product
of the corresponding relevance and concept activation \(\theta_{ic} \cdot h_i\). First, we look
at how the concept prototypes are interpreted. To see what a concept encodes, we observe
the changes in the prototypes in the same row. Taking the second row as an example,
we see a circular blob slowly disconnect at the left corner to form a 7, and then
morph into a diagonal stroke. This explains the characteristic diagonal stroke of
a 7 connected with the horizontal stroke at the right top corner but disconnected
otherwise. As expected, this concept has a positive contribution to the prediction
for the real class, digit 7, and a negative contribution to that of another incorrect
class, digit 5.</p>
<p>However, despite the hope that disentanglement encourages diversity, we observe that concepts still demonstrate overlap. This can be seen from concepts 1 and 2 in the previous figure. This means that the concepts are still not disentangled enough, and the problem of interpretability, although alleviated, remains. Progress toward good explanations within the DiSENN framework therefore depends on progress in disentanglement research.</p>
<h2 id="what-next">What next?</h2>
<p>In this post, we’ve talked about the need for transparency in AI. Self Explaining Neural Networks
have been proposed as one way to achieve transparency even with highly complex
models. We’ve reviewed this framework and the reproducibility of its results,
and find that there is still a lot of work to be done. One promising idea for
extension is to use disentanglement to create more interpretable feature
representations. Interesting further work would be to test how well this
approach works on datasets for which we <em>know</em> the ground truth latent
generative factors, such as the
<a href="https://github.com/deepmind/dsprites-dataset">dSprites dataset</a>.</p>
<p>This blog post is largely based on the project report for this course, which goes into more technical detail. The interested reader can find it <a href="https://github.com/uva-fact-ai-course/uva-fact-ai-course/blob/master/SelfExplainingNNs/report.pdf">here</a>.</p>
<p>I thank Chris, Ivan, and Aman for being incredible project teammates and Simon Passenheim for his guidance throughout the project. Thanks for reading!</p>
<p><strong>Footnotes</strong></p>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:2" role="doc-endnote">
<p>According to <a href="https://arxiv.org/abs/1606.03490">Zachary Lipton</a>: “When choosing between linear and deep models, we must often make a trade-off between <em>algorithmic transparency</em> and <em>decomposability</em>. This is because deep neural networks tend to operate on raw or lightly processed features. So if nothing else, the features are intuitively meaningful, and post-hoc reasoning is sensible. However, in order to get comparable performance, linear models often must operate on heavily hand-engineered features (which may not be very interpretable).” <a href="#fnref:2" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:most_general" role="doc-endnote">
<p>The authors actually generalize the model one step further by introducing an <em>aggregation function</em> \(g\) such that the final model is given by</p>
\[f(x) = g(\theta(x)_1h(x)_1, \ldots, \theta(x)_kh(x)_k),\]
<p>where \(g\) has properties such that interpretability is maintained. However, for all intents and purposes, it is most reasonable to use the sum function, which is what the authors do in all their experiments as well. <a href="#fnref:most_general" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:robustness_loss" role="doc-endnote">
<p>The robustness loss is given by</p>
\[\begin{equation}
\mathcal{L}_\theta := ||\nabla_x f(x) - \theta(x)^{\mathrm{T}} J_x^{h}(x)||,
\end{equation}\]
<p>where \(J_x^h(x)\) is the Jacobian of \(h\) with respect to \(x\). The idea
is that we want \(\theta(x_0)\) to behave as the derivative of \(f\) with respect to the concepts \(h(x)\)
around \(x_0\), i.e., we seek \(\theta(x_0) \approx \nabla_{h} f\). For more
detailed reasoning, see the paper. <a href="#fnref:robustness_loss" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
</li>
<li id="fn:product" role="doc-endnote">
<p>It actually does not make sense to look only at relevance scores. We have to take into account the product \(\theta_i\cdot h_i\), since it’s this product that determines the contribution to the class prediction. If an \(h_i\) has a negative activation, then a positive relevance leads to a negative overall contribution. <a href="#fnref:product" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:representation_learning" role="doc-endnote">
<p>A disentangled representation may be viewed as a concise representation of the variation in data we care about most – the generative factors. See <a href="https://arxiv.org/abs/1206.5538">a review of representation learning</a> for more info. <a href="#fnref:representation_learning" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:disenn" role="doc-endnote">
<p>Let an input \(x\) produce the Gaussian encoding distribution for a single concept \(h(x)_i = \mathcal{N}(\mu_i, \sigma_i)\). The concept’s activation for this input is then given by \(\mu_i\). We then vary a single latent dimension’s values around \(\mu_i\) while keeping the others fixed, call it \(\mu_c\). If the concepts are disentangled, a single concept should encode only a single generative factor of the data. The changes in the reconstructions \(\mathrm{decoder}(\mu_c)\) will show which generative factor that latent dimension represents. We plot these changes in the reconstructed input space to visualize this. \(\mu_c\) is sampled linearly in the interval \([\mu_i - q, \mu_i + q]\), where \(q\) is some quantile of \(h(x)_i\). <a href="#fnref:disenn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>Omar Elbaghdadiomarelb[at]gmail[dot]comFor many applications, understanding why a predictive model makes a certain prediction can be of crucial importance. In the paper “Towards Robust Interpretability with Self-Explaining Neural Networks”, David Alvarez-Melis and Tommi Jaakkola propose a neural network model that takes interpretability of predictions into account by design. In this post, we will look at how this model works, how reproducible the paper’s results are, and how the framework can be extended.Variational Bayesian Inference: A Fast Bayesian Take on Big Data.2019-08-22T00:00:00+00:002019-08-22T00:00:00+00:00https://omarelb.github.io/variational-bayes<p>Compared to the frequentist paradigm, <a href="https://en.wikipedia.org/wiki/Bayesian_inference">Bayesian inference</a> allows more readily for dealing with and interpreting uncertainty, and for easier incorporation of prior beliefs.</p>
<p>A big problem for traditional Bayesian inference methods, however, is that they are <strong>computationally expensive</strong>. In many cases, computation takes too much time to be used reasonably in research and application. This problem gets increasingly apparent in today’s world, where we would like to make good use of the <strong>large amounts of data</strong> that may be available to us.</p>
<!--more-->
<p>Enter our savior: <strong>variational inference (VI)</strong>—a much faster method than those used traditionally. This is great, but as usual, there is no such thing as a free lunch, and the method has some caveats. But all in due time.</p>
<p>This write-up is mostly based on the first part of the <a href="https://www.youtube.com/watch?v=DYRK0-_K2UU">fantastic 2018 ICML tutorial session on the topic</a> by professor <a href="https://people.csail.mit.edu/tbroderick/">Tamara Broderick</a>. If you like video format, I would recommend checking it out.</p>
<h1 id="overview">Overview</h1>
<p>The post is outlined as follows:</p>
<ul>
<li>What is Bayesian inference and why do we use it in the first place</li>
<li>How Bayesian inference works—a quick overview</li>
<li>The problem, a solution, and a faster solution</li>
<li>Variational Inference and the Mean Field Variational Bayes (MFVB) framework</li>
<li>When can we trust our method</li>
<li>Conclusion</li>
</ul>
<h3 id="a-birds-eye-view">A bird’s-eye view:</h3>
<p>We need Bayesian inference whenever we want to know the <strong>uncertainty of our estimates</strong>. Bayesian inference works by specifying some <strong>prior belief distribution</strong>, and <strong>updating our beliefs</strong> about that distribution with data, based on the <strong>likelihood</strong> of observing that data. We need <strong>approximate algorithms</strong> because standard algorithms need too much time to give usable estimates. <strong>Variational inference</strong> uses <strong>optimization</strong> instead of sampling to <strong>approximate</strong> the true distribution. We get results <strong>much more quickly</strong>, but they are <strong>not always correct</strong>. We have to find out <strong>when we can trust the obtained results</strong>.</p>
<p class="notice--info">A note on notation: \(P(\cdot)\) is used to describe both probabilities and probability distributions.</p>
<h1 id="what-is-bayesian-inference-and-why-do-we-use-it-in-the-first-place">What is Bayesian inference and why do we use it in the first place</h1>
<p>Probability theory is a mathematical framework for reasoning about <strong>uncertainty</strong>. Within the subject exist two major schools of thought, or paradigms: the <strong>frequentist</strong> and the <strong>Bayesian</strong> paradigms. In the frequentist paradigm, probabilities are interpreted as average outcomes of random repeatable events, while the Bayesian paradigm provides a way to reason about probability as <strong>a measure of uncertainty</strong>.</p>
<p><strong>Inference</strong> is the process of finding properties of a population or probability distribution from data. Most of the time, these properties are encoded by <strong>parameters</strong> that govern our model of the world.</p>
<p>In the frequentist paradigm, a parameter is assumed to be a fixed quantity unknown to us. Then, a method such as <strong>maximum likelihood (ML)</strong> is used to obtain a <strong>point estimate</strong> (a single number) of the parameter. In the Bayesian paradigm, parameters are not seen as fixed quantities, but as random variables themselves. The <strong>uncertainty in the parameters</strong> is then specified by a <strong>probability distribution</strong> over its values. Our job is to find this probability distribution over parameters <strong>given our data and prior beliefs.</strong></p>
<p>The frequentist and Bayesian paradigms both have their pros and cons, but there are multiple reasons why we might want to use a Bayesian approach. The following reasons are given in a <a href="https://www.quora.com/What-are-some-good-resources-to-learn-about-Bayesian-probability-for-machine-learning-and-how-should-I-structure-my-learning">Quora answer by Peadar Coyle</a>. To summarize:</p>
<ul>
<li>
<p><strong>Explicitly modelling your data generating process</strong>: You are forced to <strong>think carefully about your assumptions</strong>, which are often implicit in other methods.</p>
</li>
<li>
<p><strong>No need to derive estimators</strong>: Being able to treat model fitting as an abstraction is great for <strong>analytical productivity</strong>.</p>
</li>
<li>
<p><strong>Estimating a distribution</strong>: You <strong>deeply understand uncertainty</strong> and get a full-featured input into any downstream decision you need to make.</p>
</li>
<li>
<p><strong>Borrowing strength / sharing information</strong>: A common feature of Bayesian analysis is <strong>leveraging multiple sources of data</strong> (from different groups, times, or geographies) to share related parameters through a prior. This can help enormously with precision.</p>
</li>
<li>
<p><strong>Model checking as a core activity</strong>: There are <strong>principled, practical procedures</strong> for considering a wide range of <strong>models that vary in assumptions and flexibility</strong>.</p>
</li>
<li>
<p><strong>Interpretability of posteriors</strong>: What a posterior means <strong>makes more intuitive sense</strong> to people than most statistical tests.</p>
</li>
</ul>
<p>The debate between frequentists and Bayesians about which is better can get quite intense. I personally believe that neither point of view is better in every situation. We need to think carefully and apply the method that is most appropriate for the given situation, be it frequentist or Bayesian. One infamous <a href="https://www.xkcd.com/">xkcd</a> comic, given below, addresses this debate.</p>
<div class="figure">
<figure>
<img class="align-center" src="/assets/images/posts/2019/variational-bayes/frequentists-vs-bayesians.png" alt="image" style="width:400px; align:center;margin-bottom: 0.4em;" />
<figcaption class="align-center">
xkcd comic on frequentist vs Bayesian views. The comic was quite controversial itself. Many thought that the frequentist was treated unfairly. The artist himself <a href="http://web.archive.org/web/20130117080920/http://andrewgelman.com/2012/11/16808/#comment-109366">later commented</a>:
<blockquote>
I meant this as a jab at the kind of shoddy misapplications of statistics I keep running into in things like cancer screening (which is an emotionally wrenching subject full of poorly-applied probability) and political forecasting. I wasn’t intending to characterize the merits of the two sides of what turns out to be a much more involved and ongoing academic debate than I realized.
A sincere thank you for the gentle corrections; I’ve taken them to heart, and you can be confident I will avoid such mischaracterizations in the future!
</blockquote>
Another discussion can be found <a href="https://www.lesswrong.com/posts/mpTEEffWYE6ZAs7id/xkcd-frequentist-vs-bayesians">here</a>.
</figcaption>
</figure>
</div>
<h2 id="what-problems-is-it-used-for">What problems is it used for?</h2>
<p>There are many cases in which we care not only about our estimates, but also how confident we are in those estimates. Some examples:</p>
<ul>
<li>
<p><strong>Finding a wreckage</strong>: In 2009, a passenger plane crashed into the Atlantic Ocean. For two years, investigators were unable to find the wreckage of the plane. In the third year, after Bayesian analysis was brought in, the wreckage was found within one week of undersea search <a href="https://arxiv.org/pdf/1405.4720.pdf">(Stone et al 2014)</a>!</p>
</li>
<li>
<p><strong>Routing</strong>: Understanding the time it takes for vehicles to get from point A to point B. This could be ambulance routing for instance. Knowing the uncertainty in the estimates is important when planning <a href="https://people.orie.cornell.edu/woodard/WoodNogiKoch17.pdf">(Woodard et al 2017)</a>.</p>
</li>
<li>
<p><strong>Microcredit</strong>: Is microcredit actually helping? Knowing the extent of the microcredit effect and our certainty about it may be used to make decisions such as making an investment <a href="https://economics.mit.edu/files/11443">(Meager et al 2019)</a>.</p>
</li>
</ul>
<p>These are just some of the applications. There are many more. Hopefully you are convinced that we are doing something useful. Let’s move on to the actual techniques.</p>
<h1 id="how-bayesian-inference-worksa-quick-overview">How Bayesian inference works—a quick overview</h1>
<p>The <strong>first step</strong> in any inference job is defining a <strong>model</strong>. We might for example model heights in a population of, say, penguins, as being generated by a Gaussian distribution with mean \(\mu\) and variance \(\sigma^2\).</p>
<p>Our goal is then to find a probability distribution over the parameters in our model, \(\mu\) and \(\sigma^2\), given the data that we have collected. This distribution, also called the <strong>posterior</strong>, is given by</p>
\[P(\theta \vert y_{1:n}) = \frac{P(y_{1:n} \vert \theta) P(\theta)}{P(y_{1:n})},\]
<p>where \(y_{1:n}\) represents the dataset containing \(n\) observations and \(\theta\) represents the parameters. This identity is known as <strong>Bayes’ Theorem</strong>.</p>
<p>In words, the posterior is given by a product of the <strong>likelihood</strong> \(P(y_{1:n} \vert\theta)\) and the <strong>prior</strong> \(P(\theta)\), <strong>normalized by the evidence</strong> \(P(y_{1:n})\).</p>
<div class="figure">
<figure>
<img class="align-center" src="/assets/images/posts/2019/variational-bayes/teaser.png" alt="image" style="width:; align:center;margin-bottom: 0.4em;" />
<figcaption class="align-center"> Sketch of the Bayesian update. The posterior (blue) is obtained after multiplying the prior (red) with the likelihood (black). In a sequential estimation procedure, the new prior is the posterior obtained in the previous step. </figcaption>
</figure>
</div>
<p>The likelihood \(P(y_{1:n} \vert\theta)\) is often seen as a function of \(\theta\), and tells us how likely it is to have observed our data given a specific setting of the parameters. <strong>The prior \(P(\theta)\) encapsulates beliefs</strong> we have about the parameters before observing any data. <strong>We update our beliefs about the parameter distribution after observing data</strong> according to Bayes’ rule.</p>
<p>After obtaining the posterior distribution, <strong>we would like to report a summary</strong> of it. We usually do this by providing a <strong>point estimate and an uncertainty</strong> surrounding the estimate. A point estimate may for example be given by the <strong>posterior mean or mode</strong>, and uncertainty by the <strong>posterior (co)variances</strong>.</p>
<p>In summary, there are <strong>three major steps</strong> involved in doing Bayesian inference:</p>
<ol>
<li><strong>Choose a model</strong> i.e. choose a prior and likelihood.</li>
<li><strong>Compute the posterior</strong>.</li>
<li><strong>Report a summary</strong>: point estimate and uncertainty, e.g. posterior means and (co)variances.</li>
</ol>
<p>In this post, we will not be concerned with the first step, and instead focus on the last two steps: computing the posterior and reporting a summary. This means that we assume someone has already done the modeling for us and has asked us to report back something useful.</p>
<h1 id="the-problem-a-solution-and-a-faster-solution">The problem, a solution, and a faster solution</h1>
<p>Computing the posterior and reporting summaries is generally not a simple task. To see why, consider again the equation we use to compute the posterior:</p>
\[P(\theta \vert y_{1:n}) = \frac{P(y_{1:n} \vert \theta) P(\theta)}{P(y_{1:n})}.\]
<p>One issue is that typically, <strong>there is no nice closed-form solution for this expression</strong>. Usually, we can’t solve it analytically or find a standard easy-to-use distribution. Another issue is that to find a point estimate such as the mean, or to compute the normalizing constant \(P(y_{1:n})\), <strong>we need to integrate over a high-dimensional space</strong>. This is a fundamentally difficult problem, also known as the <a href="https://en.wikipedia.org/wiki/Curse_of_dimensionality">curse of dimensionality</a>. The problem is exacerbated by the large and high-dimensional datasets we have these days. (Why does calculating \(P(y_{1:n})\) involve integration? Because it is a marginal obtained by integrating out the parameters: \(P(y_{1:n}) = \int P(y_{1:n}, \theta) d\theta\).)</p>
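<p>To make the integration problem concrete, here is a sketch that computes the evidence for a one-dimensional beta-binomial model (\(k\) successes in \(n\) trials, \(Beta(a, b)\) prior) by brute-force quadrature and checks it against the known closed form \(\binom{n}{k} B(a+k,\ b+n-k)/B(a, b)\). In one dimension a grid works fine; with \(d\) parameters, a grid of \(G\) points per axis costs \(G^d\) evaluations, which is exactly the curse of dimensionality. Model and values are illustrative:</p>

```python
import math

def log_beta(x, z):
    """log of the Beta function via log-gamma, for numerical stability."""
    return math.lgamma(x) + math.lgamma(z) - math.lgamma(x + z)

a, b, n, k = 2.0, 2.0, 10, 7

# Closed-form evidence for the conjugate beta-binomial model.
evidence_exact = math.comb(n, k) * math.exp(log_beta(a + k, b + n - k) - log_beta(a, b))

# Midpoint-rule quadrature over theta in (0, 1): feasible only because d = 1.
G = 10_000
h = 1.0 / G
evidence_quad = 0.0
for i in range(G):
    theta = (i + 0.5) * h
    likelihood = math.comb(n, k) * theta ** k * (1 - theta) ** (n - k)
    prior = theta ** (a - 1) * (1 - theta) ** (b - 1) / math.exp(log_beta(a, b))
    evidence_quad += likelihood * prior * h
```

<p>With \(d\) parameters the same grid would need \(G^d\) evaluations — already \(10^{40}\) for ten parameters — which is why generic numerical integration is hopeless here.</p>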
<p>This is where <strong>approximate Bayesian inference</strong> comes in. The gold standard in this area has been <strong>Markov Chain Monte Carlo (MCMC)</strong>. It is extremely widely used and has been called one of the top 10 most influential algorithms of the 20th century. <strong>MCMC is eventually accurate, but it is slow</strong>. In many applications, we simply don’t have the time to wait until the computation is finished.</p>
<p>A faster method is called <strong>Variational Inference (VI)</strong>. In this post, we’ll take a deeper dive into how this works.</p>
<h1 id="variational-inference-and-the-mean-field-variational-bayes-mfvb-framework">Variational Inference and the Mean Field Variational Bayes (MFVB) framework</h1>
<p>The main idea in VI is that <strong>instead of trying to find the real posterior distribution \(p(\cdot \vert y)\), we approximate it with a distribution \(q\)</strong>. Of course, not every distribution would be useful as an approximation. For some distributions, measures that we’d like to report, like the mean and variance, can’t be found. We therefore restrict our search to \(Q\), the space of distributions that have certain “nice” properties. We’ll discuss later what “nice” means exactly. <strong>In this space, we search for a distribution \(q^*\) that minimizes a certain measure of dissimilarity to \(p\)</strong>. Mathematically:</p>
\[q^* = \underset{q\in Q}{\operatorname{argmin}}\ f(q(\cdot), p(\cdot \vert y)),\]
<p>where \(f\) is some measure of dissimilarity.</p>
<h2 id="kl-divergence">KL-divergence</h2>
<p>There are many measures of dissimilarity to choose from, but one is particularly useful: the <strong>Kullback-Leibler divergence, or <a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence">KL-divergence</a></strong>.</p>
<p>KL-divergence is a concept from <strong>Information Theory</strong>. For distributions \(p\) and \(q\), the KL-divergence is given by</p>
\[KL(p\ \vert\vert\ q) = \int p(x)\ln \frac{p(x)}{q(x)}dx.\]
<p><a href="https://www.cs.cmu.edu/~epxing/Class/10708-17/notes-17/10708-scribe-lecture13.pdf">Intuitively, there are three cases of importance</a>:</p>
<ul>
<li>If <strong>\(p\)</strong> is <strong>high</strong> and <strong>\(q\)</strong> is <strong>high</strong>, then we are <strong>happy</strong> i.e. low KL-divergence.</li>
<li>If <strong>\(p\)</strong> is <strong>high</strong> and <strong>\(q\)</strong> is <strong>low</strong> then we <strong>pay a price</strong> i.e. high KL-divergence.</li>
<li>If <strong>\(p\)</strong> is <strong>low</strong> then <strong>we don’t care</strong> i.e. also low KL-divergence, <strong>regardless of \(q\)</strong>.</li>
</ul>
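<p>These three cases are easy to verify numerically. A small sketch with two hand-picked discrete distributions:</p>

```python
import math

def kl(p, q):
    """Discrete KL-divergence KL(p || q) = sum_i p_i * ln(p_i / q_i)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

p       = [0.70, 0.20, 0.10]
q_close = [0.60, 0.25, 0.15]   # q high where p is high -> small divergence
q_far   = [0.10, 0.20, 0.70]   # q low where p is high  -> large divergence

kl_close, kl_far = kl(p, q_close), kl(p, q_far)
# The first term of kl_far, 0.7 * ln(0.7 / 0.1) ~ 1.36, dominates: we "pay a
# price" exactly where p is high and q is low. Where p is low (the last
# outcome), even a big mismatch in q contributes little.
```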
<p>The following figure illustrates KL-divergence for two normal distributions \(\pi_1\) and \(\pi_2\). A couple of things to note: divergence is indeed high when \(p\) is high and \(q\) is low; divergence is 0 when \(p = q\); and the complete KL-divergence is given by the area under the green curve.</p>
<div class="figure">
<figure>
<img class="align-center" src="/assets/images/posts/2019/variational-bayes/KL-example.png" alt="image" style="width:; align:center;margin-bottom: 0.4em;" />
<figcaption class="align-center">
KL divergence between two normal distributions. In this example \(\pi_1\) is a standard normal distribution and \(\pi_2\) is a normal distribution with a mean of 1 and a variance of 1. The value of the KL divergence is equal to the area under the curve of the function. <a href="https://www.researchgate.net/publication/319662351_Using_the_Data_Agreement_Criterion_to_Rank_Experts'_Beliefs">(Source)</a> </figcaption>
</figure>
</div>
<p>Loosely speaking, KL-divergence can be interpreted as <strong>the amount of information that is lost when \(q\) is used to approximate \(p\)</strong>. I won’t be going much deeper into it, but it has a couple of properties that are interesting for our purposes. The KL-divergence is:</p>
<ul>
<li>
<p><strong>Not symmetric</strong>: \(KL(p\ \vert\vert\ q) \neq KL(q\ \vert\vert\ p)\) in general. It therefore cannot be interpreted as a distance measure, since a distance is required to be symmetric.</p>
<p>We will be using the KL-divergence \(KL(q\ \vert\vert\ p)\). It is possible to use the <strong>reverse KL-divergence</strong> \(KL(p\ \vert\vert\ q)\) as well. Let’s examine the differences.</p>
<p>In practical applications, the true posterior will often be a multimodal distribution. Minimizing KL-divergence leads to <strong>mode-seeking</strong> behavior, which means that most probability mass of the approximating distribution \(q\) <strong>is centered around a mode of \(p\)</strong>. Minimizing reverse KL-divergence leads to <strong>mean-seeking</strong> behavior, which means that \(q\) would <strong>average across all of the modes</strong>. This would typically lead to poor predictive performance, since the average of two good parameter values is usually not a good parameter value itself. This is illustrated in the following figure.</p>
</li>
</ul>
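<p>The KL-divergence between two univariate Gaussians has a standard closed form, which makes the asymmetry easy to check numerically. A small sketch (values chosen for illustration):</p>

```python
import math

def kl_normal(m1, s1, m2, s2):
    """Closed-form KL( N(m1, s1^2) || N(m2, s2^2) )."""
    return math.log(s2 / s1) + (s1 ** 2 + (m1 - m2) ** 2) / (2 * s2 ** 2) - 0.5

# Same means, different spreads: the two directions disagree.
kl_fwd = kl_normal(0, 1, 0, 2)   # ~0.318
kl_rev = kl_normal(0, 2, 0, 1)   # ~0.807: clearly not symmetric

# Sanity check against the earlier figure's setup, N(0,1) vs N(1,1): KL = 0.5.
kl_fig = kl_normal(0, 1, 1, 1)
```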
<div class="figure">
<figure>
<img class="align-center" src="/assets/images/posts/2019/variational-bayes/KL-inclusive-exclusive.png" alt="image" style="width:; align:center;margin-bottom: 0.4em;" />
<figcaption class="align-center">
Minimizing \(KL(q\ \vert\vert\ p)\) versus \(KL(p\ \vert\vert\ q)\). The first (exclusive) leads to mode-seeking behavior, while the latter (inclusive) leads to mean-seeking behavior. (Source <a href="https://timvieira.github.io/blog/post/2014/10/06/kl-divergence-as-an-objective-function">Tim Vieira's blog</a>, figure by <a href="http://www.johnwinn.org/">John Winn</a>.)
</figcaption>
</figure>
</div>
<p>For a more in-depth discussion, see Tim Vieira’s blog post <a href="https://timvieira.github.io/blog/post/2014/10/06/kl-divergence-as-an-objective-function">KL-divergence as an objective function</a>.</p>
<ul>
<li><strong>Always \(\geq 0\)</strong>, with equality only when \(p = q\). Lower KL-divergence thus implies higher similarity.</li>
</ul>
<p>Most useful for us though is the following. We are optimizing \(q\) to be as close as possible to the real distribution \(p\), but we don’t actually know \(p\). <strong>How do we find a distribution close to \(p\) if we don’t even know what \(p\) itself is?</strong> It turns out that we can solve this problem by doing some algebraic manipulation. This is huge. Let’s derive the necessary expression:</p>
\[\begin{align}
KL(q\ \vert\vert\ p(\cdot \vert y)) :=& \int q(\theta)\log \frac{q(\theta)}{p(\theta \vert y)}d\theta \\
=& \int q(\theta)\log \frac{q(\theta)p(y)}{p(\theta , y)} d\theta \\
=&\ \log p(y)\int q(\theta)d\theta - \int q(\theta)\log \frac{p(y, \theta)}{q(\theta)} d\theta\\
=&\ \log p(y) - \int q(\theta)\log \frac{p(y \vert \theta) p(\theta)}{q(\theta)} d\theta.
\end{align}\]
<p>Here we use Bayes’ theorem to substitute out \(p(\theta\vert y)\) in the second line. Then, we use the property of logarithms \(\log(ab) = \log(a) + \log(b)\), together with the fact that \(p(y)\) doesn’t depend on \(\theta\), and that \(\int q(\theta) d\theta = 1\) since \(q(\theta)\) is a probability distribution over \(\theta\), to arrive at the result. Phew, that was quite a mouthful.</p>
<p>Since \(p(y)\) is fixed, we only need to consider the second term, which has a name: the <strong>Evidence Lower Bound (ELBO)</strong>. We can see from the last equation why it is called this way. \(KL(q\ \vert\vert\ p) \geq 0\) implies \(\log p(y) \geq \text{ ELBO}\). It is thus a lower bound on the log evidence \(\log p(y)\).</p>
<p>To minimize KL-divergence, we thus need to maximize the ELBO. <strong>The ELBO depends on \(p\) only through the likelihood and prior, which we already know!</strong> This is something we can actually compute without having to know the real distribution!</p>
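<p>We can sanity-check this with a toy conjugate model where the log evidence is known in closed form: \(\theta \sim \mathcal{N}(0, 1)\), \(y \vert \theta \sim \mathcal{N}(\theta, 1)\), so that for a single observation \(p(y) = \mathcal{N}(y \vert 0, 2)\) and the true posterior is \(\mathcal{N}(y/2, 1/2)\). The sketch below estimates the ELBO by Monte Carlo for a Gaussian \(q\); all names are illustrative:</p>

```python
import math
import random

def log_normal_pdf(x, mean, var):
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)

y = 1.0
log_evidence = log_normal_pdf(y, 0.0, 2.0)   # log p(y) = log N(y | 0, 2)

def elbo(m, s2, n_samples=50_000):
    """MC estimate of E_q[log p(y, theta) - log q(theta)] for q = N(m, s2)."""
    rng = random.Random(42)
    total = 0.0
    for _ in range(n_samples):
        theta = rng.gauss(m, math.sqrt(s2))
        total += (log_normal_pdf(y, theta, 1.0)      # log likelihood
                  + log_normal_pdf(theta, 0.0, 1.0)  # log prior
                  - log_normal_pdf(theta, m, s2))    # - log q
    return total / n_samples

elbo_bad = elbo(0.0, 1.0)    # a mismatched q: strictly below log p(y)
elbo_opt = elbo(y / 2, 0.5)  # q = true posterior: ELBO equals log p(y)
```

<p>The gap \(\log p(y) - \text{ELBO}\) is exactly the KL-divergence we are minimizing, and it vanishes when \(q\) is the true posterior.</p>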
<h2 id="mean-field-variational-bayes-mvfb">Mean Field Variational Bayes (MFVB)</h2>
<p>I promised to tell you what kinds of distributions we think of as “nice”. Firstly, we want to be able to report a mean and a variance, so these must exist. We then make <strong>the MFVB assumption</strong>, also known as <strong>Mean-Field Approximation</strong>. The approximation is a simplifying assumption for our distribution \(q\), which <strong>factorizes the distribution into independent parts</strong>:</p>
\[q(\theta) = \prod_i q_i(\theta_i).\]
<p>From a statistical physics point of view, “mean-field” refers to the relaxation of a difficult optimization problem to a simpler one which ignores second-order effects. The optimization problem becomes easier to solve.</p>
<p>Note that this is <strong>not a modeling assumption</strong>. We are <strong>not</strong> saying that the parameters in our model are <strong>independent</strong>, which would limit us only to uninteresting models. We are only saying that the parameters are independent in our <strong>approximation</strong> of the posterior.</p>
<p>We often also assume a distribution from the <strong>exponential family</strong>, since these have nice properties that make life easier.</p>
<p>Now that we have defined a space and metric to optimize over, <strong>we have a clearly defined optimization problem</strong>. At this point, we can use any optimization technique we’d like to find \(q^*\). Typically, <a href="https://en.wikipedia.org/wiki/Coordinate_descent">coordinate descent</a> is used.</p>
<h1 id="when-can-we-trust-our-method">When can we trust our method</h1>
<p>Since Variational Inference is an approximate method, we’d like to know <strong>how accurate</strong> the approximation actually is. In other words, when we can trust it. If we schedule an ambulance based on the prediction that it will take 10 minutes to arrive, we have to be damn sure that our <strong>confidence in the prediction is justified</strong>.</p>
<p>One way to check whether the method works is to consider a simple example that we know the correct answer to. We can then see how well MFVB approximates that answer. To do this, we consider the (rather randomly chosen) problem of estimating midge wing length.</p>
<h2 id="estimating-midge-wing-length">Estimating midge wing length</h2>
<p><em>Note that understanding all the details in this example is not required for an understanding of the big picture.</em></p>
<p>Midge is a term used to refer to many species of small flies. Before we can compute a posterior, we need a model. As before, we assume that we are given a model, which is determined by a likelihood and a prior. We only need to worry about computing a posterior and reporting back a summary. The model given to us is the following.</p>
<p>Assume midge wing length is <strong>normally distributed</strong> with unknown mean \(\mu\) and unknown precision \(\tau\). (We use precision, the inverse of variance, because it’s mathematically more convenient.) Let \(y\) be the midge wing length. We care about finding the <strong>posterior</strong></p>
\[p(\mu, \tau \vert y_{1:N}) \propto p(y_{1:N} \vert \mu, \tau) p(\mu, \tau).\]
<p>The <strong>likelihood</strong> is then given by</p>
\[p(y_{1:N} \vert \mu, \tau) = \prod_i \mathcal{N}(y_i \vert \mu, \tau^{-1}),\]
<p>where \(\mathcal{N}(\cdot)\) denotes the normal distribution.</p>
<p>The <a href="https://en.wikipedia.org/wiki/Conjugate_prior"><strong>conjugate prior</strong></a> for a Gaussian with unknown mean and variance is a <a href="https://en.wikipedia.org/wiki/Normal-gamma_distribution">Gaussian-Gamma distribution</a> given by</p>
\[p(\mu, \tau) = \mathcal{N}(\mu \vert \mu_0, (\beta\tau)^{-1})Gamma(\tau \vert a, b),\]
<p>where \(Gamma\) is the <a href="https://en.wikipedia.org/wiki/Gamma_distribution">Gamma distribution</a> and \(\mu_0, \beta, a, b\) are <strong>hyperparameters</strong>.</p>
<p>To start solving the problem, we first use the <strong>mean-field assumption</strong> resulting in the factorization:</p>
\[q^*(\mu, \tau) = q_\mu^*(\mu)q_\tau^*(\tau) = \underset{q\in Q_{MFVB}}{\operatorname{argmin}}\ KL(q(\cdot)\ \vert\vert\ p(\cdot \vert y)).\]
<p>The factors \(q_\mu^*(\mu)\) and \(q_\tau^*(\tau)\) can be derived [Bishop 2006, Sec. 10.1.3]:</p>
\[\begin{align}
q_\mu^*(\mu) &= \mathcal{N}(\mu \vert m_\mu, \rho_\mu^2)\\
q_\tau^*(\tau) &= Gamma(\tau \vert a_\tau, b_\tau),
\end{align}\]
<p>where “variational parameters” \(m_\mu, \rho_\mu^2, a_\tau, \text{ and } b_\tau\) determine the approximating distribution.</p>
<p>We then <strong>iterate</strong>. First, we make an initial guess for the variational parameters. Then, we cycle through each factor. We find the approximating distribution of \(\mu\) given the distribution of \(\tau\) in one step and the approximating distribution of \(\tau\) given the distribution of \(\mu\) in another step:</p>
\[\begin{align}
(m_\mu, \rho_\mu^2) &= f(a_\tau, b_\tau)\\
(a_\tau, b_\tau) &= g(m_\mu, \rho_\mu^2).
\end{align}\]
<p>We repeat this procedure until convergence.</p>
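<p>A minimal sketch of these coordinate updates, following the form of the updates derived in [Bishop 2006, Sec. 10.1.3]; the hyperparameter values and the synthetic "wing length" data are my own illustrative choices:</p>

```python
import math
import random

random.seed(0)
y = [random.gauss(2.0, 1.0) for _ in range(200)]   # synthetic wing lengths
N = len(y)
ybar = sum(y) / N
ysq = sum(v * v for v in y)

# Hyperparameters of the normal-gamma prior (deliberately weak).
mu0, beta, a0, b0 = 0.0, 0.1, 1e-3, 1e-3

E_tau = 1.0   # initial guess for E_q[tau]
for _ in range(100):
    # Update q(mu) = N(mu | m, rho2), given the current q(tau).
    m = (beta * mu0 + N * ybar) / (beta + N)
    rho2 = 1.0 / ((beta + N) * E_tau)
    # Update q(tau) = Gamma(tau | a, b), given the current q(mu).
    a = a0 + (N + 1) / 2.0
    e_sq = ysq - 2 * m * sum(y) + N * (m * m + rho2)   # E_q[sum (y_i - mu)^2]
    b = b0 + 0.5 * (e_sq + beta * ((m - mu0) ** 2 + rho2))
    E_tau = a / b
```

<p>After convergence, the variational mean \(m\) sits essentially on the sample mean, and \(E_q[\tau] = a/b\) is close to the inverse of the sample variance, as we would hope for this well-behaved problem.</p>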
<p>The following figure shows how our approximation (blue) of the real posterior (green) gets more and more accurate by applying coordinate descent, resulting in quite a good approximation (red).</p>
<div class="figure">
<figure>
<img class="align-center" src="/assets/images/posts/2019/variational-bayes/variational_bayes_fig1.png" alt="image" style="width:; align:center;margin-bottom: 0.4em;" />
<figcaption class="align-center">The process of variational approximation to the Gaussian-gamma distribution. Our approximation (blue) of the real posterior (green) gets more and more accurate by applying coordinate descent, resulting in quite a good approximation (red) (Source: PRML, Bishop 2006)</figcaption>
</figure>
</div>
<h2 id="variance-underestimation">Variance underestimation</h2>
<p>One of the major problems that shows up is that <strong>the variational distribution often underestimates the variance of the real posterior</strong>. This is a result of minimizing the <strong>KL-divergence</strong>, which encourages a small value of the approximating distribution when the true distribution has a small value. This is showcased in the next figure, where MFVB is used to fit a multivariate Gaussian. The mean is correctly captured by the approximation, but the variance is severely underestimated. This gets progressively worse as the correlation between the two variables increases.</p>
<div class="figure">
<figure>
<img class="align-center" src="/assets/images/posts/2019/variational-bayes/variational_bayes_fig2.png" alt="image" style="width:; align:center;margin-bottom: 0.4em;" />
<figcaption class="align-center">The MFVB approximation of the true distribution, a multivariate Gaussian, severely underestimates the true variance. This gets worse as the correlation increases (Source: <a href="http://www.gatsby.ucl.ac.uk/~maneesh/papers/turner-sahani-2010-ildn.pdf">Turner, Sahani 2010</a>).</figcaption>
</figure>
</div>
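<p>For a bivariate Gaussian target this underestimation can even be computed in closed form: the KL-optimal factorized Gaussian has marginal precision equal to the corresponding diagonal entry of the target's precision matrix \(\Lambda = \Sigma^{-1}\) [Bishop 2006, Sec. 10.1.2]. With unit variances and correlation \(\rho\), that gives a mean-field variance of \(1 - \rho^2\) instead of the true marginal variance of 1. A tiny sketch:</p>

```python
def mean_field_var(rho):
    # Precision matrix of [[1, rho], [rho, 1]] has diagonal 1 / (1 - rho**2),
    # so the optimal factorized Gaussian has marginal variance 1 - rho**2.
    return 1.0 - rho ** 2

for rho in [0.0, 0.5, 0.9, 0.99]:
    print(f"rho={rho}: true marginal var = 1.0, "
          f"mean-field var = {mean_field_var(rho):.4f}")
```

<p>At \(\rho = 0.99\) the mean-field variance is about 2% of the truth — the same qualitative collapse shown in the figure above.</p>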
<p>Another way to test the validity of our approximations is to compare them to the answers of <strong>a method that we know works: MCMC</strong>. We can use this for more complex problems for which we cannot find an analytical solution. One real-life application deals with <strong>microcredit</strong>.</p>
<p>Microcredit is an initiative that helps impoverished people become self-employed or start a business by giving them extremely small loans. Of course, we’d like to know if this approach actually has a positive effect, how large this effect is, and <strong>how certain</strong> we are of that.</p>
<p>The next figure shows once more that MFVB estimates of microcredit effect variance indeed underestimate the true variance.</p>
<div class="figure">
<figure>
<img class="align-center" src="/assets/images/posts/2019/variational-bayes/variational_bayes_fig4.png" alt="image" style="width:; align:center;margin-bottom: 0.4em;" />
<figcaption class="align-center">Microcredit effect variance estimates given by MCMC versus MFVB. MFVB underestimates variance (Source: Giordano, Broderick, Jordan 2016).</figcaption>
</figure>
</div>
<h2 id="mean-estimates">Mean estimates</h2>
<p>MFVB not always getting the variance right raises another question: can estimates of the mean be incorrect too? The answer is yes, as demonstrated by the following two figures. In the first figure, MFVB estimates for the mean of a parameter \(\nu\) disagree with MCMC estimates. The same happens in the second figure, where some MFVB estimates even fall outside the 95% credible interval obtained by MCMC.</p>
<div class="figure">
<figure>
<img class="align-center" src="/assets/images/posts/2019/variational-bayes/variational_bayes_fig3.png" alt="image" style="width:60%; align:center;margin-bottom: 0.4em;" />
<figcaption class="align-center">MFVB estimates for the mean of a parameter disagree with MCMC estimate (Source: Giordano, Broderick, Jordan 2015).</figcaption>
</figure>
</div>
<div class="figure">
<figure>
<img class="align-center" src="/assets/images/posts/2019/variational-bayes/variational_bayes_fig5.png" alt="image" style="width:; align:center;margin-bottom: 0.4em;" />
<figcaption class="align-center">MFVB estimates disagree with MCMC estimates (Source: <a href="https://digital.lib.washington.edu/researchworks/bitstream/handle/1773/24305/Fosdick_washington_0250E_12238.pdf?sequence=1&isAllowed=y">Fosdick 2013)</a>.</figcaption>
</figure>
</div>
<h2 id="what-can-we-do">What can we do?</h2>
<p>We’ve seen that MFVB doesn’t always produce accurate approximations. What do we do then? Some major lines of research to alleviate this problem are:</p>
<ul>
<li>
<p><strong>Reliable diagnostics</strong>: <strong>Fast procedures</strong> that tell us <strong>after the fact</strong> whether the approximation is good. One way to achieve this might be a fast way to compute the KL-divergence of our approximation, since we know it is bounded below by 0. Usually, we only have access to the ELBO, for which we have no such bound.</p>
</li>
<li>
<p><strong>Richer “nice” set</strong>: In the MFVB framework, we only consider optimizing over a set of functions that factorize. Considering a richer set of functions might help. It turns out though that having a richer nice set doesn’t necessarily yield better approximations. We’d have to make other assumptions that complicate the problem as well.</p>
</li>
<li>
<p><strong>Alternative divergences</strong>: Minimizing other divergences than the KL-divergence might help, but has similar difficulties as the above point.</p>
</li>
<li>
<p><strong>Data compression</strong>: Before using an inference algorithm, we can consider doing a <strong>preprocessing step in which the data is compressed</strong>. We would like to have theoretical guarantees on the quality of our inference methods on this compressed dataset.</p>
</li>
</ul>
<p>Until we have a foolproof way to test for the reliability of the estimates obtained by MFVB, it is important to be wary of the results obtained, as they may not always be correct.</p>
<h1 id="summary">Summary</h1>
<p>We’ve discussed how Bayesian Variational Inference works. By framing the problem as an optimization problem, we can find results much faster compared to the classic MCMC algorithm. This comes at a price: we don’t always know when our approximation is accurate. This is still very much an open problem that researchers are working on.</p>
<h2 id="related-publications">Related Publications</h2>
<p>If you would like to read further, here are some <strong>related publications</strong>:</p>
<ul>
<li>Bishop. Pattern Recognition and Machine Learning, Ch 10. 2006.</li>
<li>Blei, Kucukelbir, McAuliffe. Variational inference: A review for statisticians, JASA2016.</li>
<li>MacKay. Information Theory, Inference, and Learning Algorithms, Ch 33. 2003.</li>
<li>Murphy. Machine Learning: A Probabilistic Perspective, Ch 21. 2012.</li>
<li>Ormerod, Wand. Explaining Variational Approximations. Amer Stat 2010.</li>
<li>Turner, Sahani. Two problems with variational expectation maximisation for time-series models. In Bayesian Time Series Models, 2011.</li>
<li>Wainwright, Jordan. Graphical models, exponential families, and variational inference. Foundations and Trends in Machine Learning, 2008.</li>
</ul>
<p><strong>More Experiments</strong>:</p>
<ul>
<li>RJ Giordano, T Broderick, and MI Jordan. Linear response methods for accurate covariance estimates from mean field variational Bayes. NIPS 2015.</li>
<li>RJ Giordano, T Broderick, R Meager, J Huggins, and MI Jordan. Fast robustness quantification with variational Bayes. ICML Data4Good Workshop 2016.</li>
<li>RJ Giordano, T Broderick, and MI Jordan. Covariances, robustness, and variational Bayes, 2017. Under review. ArXiv:1709.02536.</li>
</ul>
<p><em>Thanks for reading! Any thoughts? Leave them in the comments below!</em></p>

Hello World!
2018-10-28T00:00:00+00:00
https://omarelb.github.io/hello-world

<p>Hi there! This is my first blog post ever! I want to start writing about topics in artificial intelligence and machine learning.</p>
<p>I got inspired to start blogging by several other bloggers, including <a href="https://medium.com/@racheltho/why-you-yes-you-should-blog-7d2544ac1045">Rachel Thomas</a>, <a href="https://machinelearningmastery.com/about/">Jason Brownlee</a>, and <a href="https://ayearofai.com/0-2016-is-the-year-i-venture-into-artificial-intelligence-d702d65eb919">Rohan and Lenny</a>. The key reasons for doing this are:</p>
<ul>
<li>Building up a portfolio,</li>
<li>deepening my understanding by explaining things clearly.</li>
</ul>
<p>Around a year ago, I began to grow increasingly interested in Artificial Intelligence and Computer Science. So much so, that I decided to apply for a master’s degree in the subject. My background in Econometrics and Operations Research has given me a solid base in statistical modelling and optimization, which I believe is of core importance to many applications in AI.</p>
<p>There are still far too many subjects, however, that I don’t know much about. By writing about them and explaining concepts as simply as I can, I hope to deepen my knowledge, improve my technical writing, and help make things clear for others at the same time.</p>
<p>I want to make learning about Artificial Intelligence a project, a journey. There are so many exciting advances in the field, so there is a lot to catch up with. Real world applications are becoming increasingly widespread as well, even hitting headlines regularly, making it a science that’s interesting even at parties. I guess it’s the perfect subject for the phrase: <em>“Let’s make nerds cool again.”</em></p>
<p>This post is the start of a journey for me, let’s make it an interesting ride! Also, I love meeting new people, so don’t hesitate to contact me!</p>