Computational Bayes: Variational Inference

Author

Zach Smith

smthzch.github.io

Repo for this notebook.

import numpy as np
import scipy.stats as ss
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.animation as anim
from scipy.special import logsumexp

np.random.seed(2023)

This post continues from the previous post exploring different Bayesian inference techniques and the computations behind them. The focus is on the what and the how, not the why. Mathematical derivations of why each technique works are left for further reading; instead I hope to demonstrate what each inference technique aims to accomplish and show how to accomplish it by implementing simple solutions with minimal tooling, keeping the mechanics visible and easy to understand. I also hope to show how each technique relates to the others, providing a unified picture of what we are trying to accomplish in Bayesian inference.

This post focuses on variational inference (VI). While MCMC aims to produce samples from a potentially intractable posterior, in VI we specify an approximate distribution for the posterior and optimize the parameters to most closely fit the true posterior.

In this post we will look at two main problems in Bayesian inference:

  1. Parameter inference (global parameters)
  2. Latent variable inference (local parameters)

The parameter inference section follows the same problem as the previous post: estimating the unknown \(\mu\) parameter of a distribution. The latent variable inference section looks at the more interesting problem of inferring local latent parameters and uses VI to solve it in two different ways.

From the previous post we were trying to estimate the unknown \(\mu\) parameter for data drawn from a normal distribution with known variance.

\[ \begin{aligned} & \mu = 2 \\ & \sigma = 1 \\ & X \sim \text{Normal}(\mu, \sigma ) \end{aligned} \]

We placed a normal distribution prior on \(\mu\)

\[ \begin{aligned} & \mu_\mu = 0 \\ & \sigma_\mu = 10 \\ & \mu \sim \text{Normal}(\mu_\mu, \sigma_\mu) \\ \end{aligned} \]

This gives us the conjugate posterior distribution for \(\mu\):

\[ \begin{aligned} & \mu_{\mu|x} = 1.78 \\ & \sigma_{\mu|x} = 0.13 \\ & \mu \sim \text{Normal}(\mu_{\mu|x}, \sigma_{\mu|x}) \end{aligned} \]

prior_mu = 0
prior_sigma = 10

N = 60
mu = 2
sigma = 1
x = np.random.normal(mu, sigma, N)

estimates = pd.DataFrame(
    {
        "method": "Conjugate", 
        "mu_mu": 1.78, 
        "mu_sigma": 0.13
    },
    index=[0]
)
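
For reference, these values come from the standard closed-form normal-normal conjugate update with known variance (a quick check, not part of the original analysis; it should reproduce the hard-coded numbers above up to rounding):

post_prec = 1 / prior_sigma**2 + N / sigma**2 # posterior precision
post_sigma = np.sqrt(1 / post_prec)
post_mu = (prior_mu / prior_sigma**2 + x.sum() / sigma**2) / post_prec
print(post_mu, post_sigma) # should be close to 1.78 and 0.13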

Let’s see if we can recover this distribution using variational inference.

Variational Inference

Rather than sampling from the posterior as with MCMC, variational inference turns this into an optimization problem. To do so we specify an approximating distribution for the posterior, \(q(\mu|x)\), along with the prior distribution. This approximate posterior can be any distribution we like, but our goal is to select a distribution that can match the true posterior as closely as possible. We then optimize the parameters of the approximate posterior to minimize the KL divergence between it and the true posterior.

Because we are using conjugate priors we actually know that the true posterior is normally distributed, so first let’s pick the normal as our approximate distribution and see if we can properly recover the parameters. Next, we will try using a distribution that cannot so closely approximate the true normal posterior and see what happens.

With a normal distribution as our approximate posterior we have two parameters to optimize: \(\mu_q\) and \(\sigma_q\).

\[ q(\mu|x) \sim \text{Normal}(\mu_q, \sigma_q) \]

q_z = np.random.normal

In order to optimize these parameters we need to define a loss function. This is normally taken to be the KL divergence between the approximate posterior and the true posterior. With some work the KL divergence can be simplified to a form we can actually calculate; in practice it is generally rewritten as the ELBO, which is maximized. The full derivation is left to the reader to explore further.

\[ \begin{aligned} & KL(q(\mu|x)||p(\mu|x)) = \int_\mu q(\mu|x) \log\frac{q(\mu|x)}{p(\mu|x)}d\mu = \\ & E[\log\frac{q(\mu|x)}{p(\mu|x)}] = \\ & E[\log q(\mu|x) - \log p(\mu, x) + \log p(x)] = \\ & E[\log q(\mu|x) - \log p(x|\mu) - \log p(\mu) + \log p(x)] \end{aligned} \]

The last term \(\log p(x)\) does not depend on \(q(\mu|x)\) and is therefore constant within the optimization problem, reducing the objective to

\[ E[\log q(\mu|x) - \log p(x|\mu) - \log p(\mu)] \]
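
For reference (a standard identity rather than anything specific to this example), this quantity is exactly the negative of the ELBO, and the ELBO and KL divergence always sum to the fixed quantity \(\log p(x)\), so minimizing one is the same as maximizing the other:

\[ \text{ELBO} = E[\log p(x|\mu) + \log p(\mu) - \log q(\mu|x)], \qquad \log p(x) = \text{ELBO} + KL(q(\mu|x)||p(\mu|x)) \]

This is also why the helper below is named elbo: it computes the pointwise quantity \(\log q(\mu|x) - \log p(x|\mu) - \log p(\mu)\), whose average is the negative ELBO.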

Because the result is an expectation, we can estimate it by Monte Carlo: sample from the approximate posterior and average the resulting values.

n_draws = 100 #number of draws for the expectation in KL divergence

def elbo(z_i, posterior_mu, posterior_sigma):
    q_z = ss.norm.logpdf(z_i, posterior_mu, posterior_sigma) # log q(mu|x)
    p_x_z = ss.norm.logpdf(x, z_i, sigma).sum() # log p(x|mu)
    p_z = ss.norm.logpdf(z_i, prior_mu, prior_sigma) # log p(mu)
    return q_z - p_x_z - p_z

def kl_divergence(posterior_mu, posterior_sigma):
    # take samples from approximate posterior distribution
    z = q_z(posterior_mu, posterior_sigma, n_draws)
    # calculate elbo(s)
    kl = [elbo(z_i, posterior_mu, posterior_sigma) for z_i in z]
    return np.mean(kl) # take expectation
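
As a quick check (exact values vary between runs because of the Monte Carlo draws), a proposal centered near the true posterior mean should score a much lower divergence than one centered far away:

print(kl_divergence(2.0, 0.5), kl_divergence(0.0, 0.5)) # the first value should be far smaller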

All that is left is to optimize posterior_mu and posterior_sigma. Some form of gradient descent is generally used, but let's do a grid search over the parameter space so we can see how the KL divergence changes across it. We will calculate the KL divergence at each combination of proposed posterior_mu and posterior_sigma, select the combination that produces the lowest value, and plot it over the space of values.

# grid search over all combinations of mu_vals and sigma_vals
mu_vals = np.arange(0, 4, 0.1)
sigma_vals = np.arange(0.1, 2, 0.1)
to_search = np.array(np.meshgrid(mu_vals, sigma_vals)).reshape((2,-1)).T
to_search = pd.DataFrame(to_search, columns=["mu_mu", "mu_sigma"])
to_search["kl"] = 0.0

for row in to_search.itertuples():
    # calculate value of kl divergence at this combination and save
    to_search.loc[row.Index, "kl"] = kl_divergence(row.mu_mu, row.mu_sigma) 

# find best value in grid and report
best_val = to_search.sort_values("kl").iloc[0]
estimates = pd.concat(
    [
        estimates, 
        pd.DataFrame(
            {
                "method": "VI-Grid Search", 
                "mu_mu": best_val["mu_mu"], 
                "mu_sigma": best_val["mu_sigma"],
            },
            index=[0]
        )
    ],
    ignore_index=True
)
estimates
method mu_mu mu_sigma
0 Conjugate 1.78 0.13
1 VI-Grid Search 1.80 0.10

From our grid search we see that we have once again come very close to the true posterior values.

plt.imshow(
    to_search["kl"].values.reshape((19, 40)), 
    origin="lower",
    extent=(
        to_search["mu_mu"].min(),
        to_search["mu_mu"].max(),
        to_search["mu_sigma"].min(),
        to_search["mu_sigma"].max()
    )
);
plt.scatter(best_val["mu_mu"], best_val["mu_sigma"], c="red");
plt.xlabel("mu_mu");
plt.ylabel("mu_sigma");

While grid search is a useful way to visualize the shape of the parameter space, in practice it is generally too computationally intensive and requires very tight grid spacing to recover the optimal parameters.

Now let's perform a very simplified form of gradient descent. We will use coordinate descent, optimizing each parameter independently. This makes the updates simple to perform, though in more complex parameter spaces it will not perform well.

Also for simplicity we will compute the gradient via finite differences. Again, this is not recommended in practice but it is good enough for our purposes here.
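
Concretely, each partial derivative is approximated with a forward difference and the parameter is moved a small step downhill, e.g. for \(\mu_q\)

\[ \frac{\partial KL}{\partial \mu_q} \approx \frac{KL(\mu_q + h, \sigma_q) - KL(\mu_q, \sigma_q)}{h}, \qquad \mu_q \leftarrow \mu_q - \text{step size} \cdot \frac{\partial KL}{\partial \mu_q} \]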

# initialize the parameters
mu_mu_kl = 0.5
mu_sigma_kl = 1.5

# optimization settings
iters = 100
h = 1e-1 # perturb the parameter by this much to estimate the gradient 
step_size = 1e-3
# track how the posterior parameters change over the iterations
posterior_steps = pd.DataFrame(np.zeros((iters, 2)), columns=["mu", "sigma"])
kls = [] # collect the kl divergence values here

for i in range(iters):
    # step mu
    gradient1 = (kl_divergence(mu_mu_kl + h, mu_sigma_kl) - kl_divergence(mu_mu_kl, mu_sigma_kl)) / h # finite difference for gradient
    mu_mu_kl = mu_mu_kl - gradient1 * step_size

    # step sigma
    # note: the baseline KL is re-estimated with fresh Monte Carlo draws each time, so the gradients are noisy
    gradient2 = (kl_divergence(mu_mu_kl, mu_sigma_kl + h) - kl_divergence(mu_mu_kl, mu_sigma_kl)) / h
    mu_sigma_kl = mu_sigma_kl - gradient2 * step_size

    posterior_steps.iloc[i,:] = [mu_mu_kl, mu_sigma_kl]
    kls += [kl_divergence(mu_mu_kl, mu_sigma_kl)]

First we can check that the KL divergence between the approximate and true posterior decreased over the iterations.

plt.plot(kls)

And let's see how our approximate distribution compares to the previous estimates.

estimates = pd.concat(
    [
        estimates, 
        pd.DataFrame(
            {
                "method": "VI-Gradient Descent", 
                "mu_mu": mu_mu_kl, 
                "mu_sigma": mu_sigma_kl,
            },
            index=[0]
        )
    ],
    ignore_index=True
)
estimates
method mu_mu mu_sigma
0 Conjugate 1.780000 0.130000
1 VI-Grid Search 1.800000 0.100000
2 VI-Gradient Descent 1.727722 0.083892

Because we traced the posterior parameters as they evolved, we can plot the optimization trajectory over the parameter space.

plt.imshow(
    to_search["kl"].values.reshape((19, 40)), 
    origin="lower",
    extent=(
        to_search["mu_mu"].min(),
        to_search["mu_mu"].max(),
        to_search["mu_sigma"].min(),
        to_search["mu_sigma"].max()
    )
);
plt.plot(posterior_steps["mu"], posterior_steps["sigma"], c="red");
plt.xlabel("mu_mu");
plt.ylabel("mu_sigma");

Variational Inference Posterior Mismatch

In the previous section we had the fortune of working with a model whose posterior distribution is known, so we could match the family of the approximate distribution to the true family, giving us accurate posterior parameter estimates. But we are not restricted to using the same family for the approximate distribution. In fact, when using variational methods we are often working with models for which we don't know the true posterior distribution. To see what happens when the approximate family does not match the true posterior, we will use the same generative model as before but use an exponential distribution as the approximate distribution.

You will notice immediately that this approximate distribution is a poor fit, because it is bounded below at 0.

q_z = np.random.exponential
n_draws = 100 #number of draws for the expectation in KL divergence

def elbo(z_i, posterior_mu):
    q_z = ss.expon.logpdf(z_i, scale=1 / posterior_mu)
    p_x_z = ss.norm.logpdf(x, z_i, sigma).sum()
    p_z = ss.norm.logpdf(z_i, prior_mu, prior_sigma)
    return q_z - p_x_z - p_z

def kl_divergence(posterior_mu):
    #take samples from approximate posterior distribution
    z = q_z(1 / posterior_mu, n_draws)
    #calculate elbo(s)
    kl = [elbo(z_i, posterior_mu) for z_i in z]
    return np.mean(kl) #take expectation

The exponential distribution has only one parameter, the rate \(\beta\), that needs to be optimized. Because \(\beta > 0\) we will optimize it in unconstrained space by taking its \(\log\).
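
For reference, with rate \(\beta\) the approximating density and its mean are

\[ q(\mu|x) = \beta e^{-\beta \mu}, \quad \mu \ge 0, \qquad E[\mu] = \frac{1}{\beta} \]

which is why the comparison table further down reports \(1/\beta\) (computed as 1 / np.exp(mu_beta_kl)) as mu_mu.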

# initialize the parameters
mu_beta_kl = -1 #this is in log domain because beta is lower bounded at 0, need to exponentiate

# optimization settings
iters = 1000
h = 1e-2 # perturb the parameter by this much to estimate the gradient 
step_size = 1e-6
# track how the posterior parameters change over the iterations
posterior_steps = pd.DataFrame(np.zeros((iters, 1)), columns=["mu"])
kls = [] # collect the kl divergence values here
current_kl = kl_divergence(np.exp(mu_beta_kl))

for i in range(iters):
    # step mu
    gradient1 = (kl_divergence(np.exp(mu_beta_kl + h)) - current_kl) / h # finite difference for gradient
    mu_beta_kl = mu_beta_kl - gradient1 * step_size


    posterior_steps.loc[i, "mu"] = np.exp(mu_beta_kl)
    current_kl = kl_divergence(np.exp(mu_beta_kl))
    kls += [current_kl]
plt.plot(kls)

And let's see how our approximate distribution compares to the previous estimates.

estimates = pd.concat(
    [
        estimates, 
        pd.DataFrame(
            {
                "method": "VI-Mismatch", 
                "mu_mu": 1 / np.exp(mu_beta_kl), 
                "mu_sigma": np.nan,
            },
            index=[0]
        )
    ],
    ignore_index=True
)
estimates
method mu_mu mu_sigma
0 Conjugate 1.780000 0.130000
1 VI-Grid Search 1.800000 0.100000
2 VI-Gradient Descent 1.727722 0.083892
3 VI-Mismatch 1.661364 NaN
x1 = np.arange(-5, 10, 0.1)
plt.plot(x1,  ss.expon.pdf(x1, scale=1 / np.exp(mu_beta_kl)), label="misfit approx posterior");
plt.plot(x1,  ss.norm.pdf(x1, estimates.loc[0, "mu_mu"], estimates.loc[0, "mu_sigma"]),label="conjugate posterior");
plt.legend();

Latent Variable Variational Inference

In the previous section we used various methods to calculate or estimate the posterior distribution of a single global parameter. In this section we continue using variational inference, now to learn latent parameters.

In this context a latent parameter is an unobserved variable attached to each observation, so rather than estimating one parameter we need to estimate at least as many parameters as there are data points.

For this we will generate data from a mixture of normal distributions.

\[ \begin{aligned} & Z \sim \text{Categorical}(0.5, 0.5) \\ & \mu_1 = -1 \\ & \mu_2 = 1 \\ & \sigma = 1 \\ & Y \mid Z = z \sim \text{Normal}(\mu_z, \sigma) \end{aligned} \]

N_ = 60
N = N_ + N_

mus = [-1, 1]
sigma = 1

y_1 = np.random.normal(mus[0], sigma, N_)
y_2 = np.random.normal(mus[1], sigma, N_)
y = np.concatenate([y_1, y_2])

y_df = pd.DataFrame(
    {
        "class": np.concatenate([np.ones(N_), 2 * np.ones(N_)]), 
        "y": y
    }
)
plt.hist(y_1, alpha=0.5);
plt.hist(y_2, alpha=0.5);

For our model let us again assume that \(\sigma\) is known, so we must find the global \(\mu\) parameters for each class, as well as the latent class parameter for each data point.

Our joint model, including the normal priors on the class means, is then

\[ p(y, z, \mu) = p(\mu_1)\,p(\mu_2) \prod_{i=1}^{N} p(y_i \mid \mu_{z_i}, \sigma)\, p(z_i) \]

Now that we have our data let us specify our posterior distribution and our KL divergence loss metric.
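
Concretely, the variational family used below is mean-field: an independent categorical distribution over the class of each data point (the rows of posterior_vals) and a normal distribution over each class mean, with \(\sigma\) treated as known

\[ q(z, \mu) = \prod_{i=1}^{N} q(z_i) \prod_{k=1}^{2} q(\mu_k), \qquad q(z_i) = \text{Categorical}(p_i, 1 - p_i), \qquad q(\mu_k) = \text{Normal}(m_k, s_k) \]

where the \(p_i\) are stored in posterior_vals and the \((m_k, s_k)\) in mu_mu_params and mu_sigma_params.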

prior_prob = np.array([0.5, 0.5])
prior_mu = 0
prior_sigma = 10
ndraws = 10

def elbo(y, posterior_vals, mu_mu_params, mu_sigma_params, sigma_params):
    z = np.where(np.random.uniform(0, 1, len(y)) <= posterior_vals[:,0], 0, 1).astype(int)
    mu_params = np.random.normal(mu_mu_params, mu_sigma_params, 2)
    # calculate KL
    q_z = np.log(posterior_vals[np.arange(len(y)), z]).sum() # approximate posterior
    p_y_z = ss.norm.logpdf(y, mu_params[z], sigma_params[z]).sum() # likelihood
    p_z = np.log(prior_prob[z]).sum() # class prior
    p_mu = ss.norm.logpdf(mu_params, prior_mu, prior_sigma).sum() # mu prior (logpdf already returns the log density)
    return q_z - p_y_z - p_z - p_mu

def kl_divergence(y, posterior_vals, mu_mu_params, mu_sigma_params, sigma_params):
    # calculate elbo(s)
    kl = [elbo(y, posterior_vals, mu_mu_params, mu_sigma_params, sigma_params) for _ in range(ndraws)]
    return np.mean(kl)

# initialize parameters
posterior_vals = 0.5 * np.ones((len(y), 2))
mu_mu_params = np.array([-3, 3])
mu_sigma_params = np.log(np.ones(2)) # take log to put parameters in unconstrained space
sigma_params = np.array([1, 1])

iters = 300
# perturbation for posterior z
h = 1
# perturbation for mu values
h_ = 1e-0
h1 = np.array([1, 0])
h2 = np.array([0, 1])

step_size = 1e-1
mu_step_size = 1e-4
sigma_step_size = 1e-4

posterior_steps = np.zeros((iters, len(y)))
mu_mu_steps = np.zeros((iters, 2))
mu_sigma_steps = np.zeros((iters, 2))
kls = []

for i in range(iters):
    # expectation
    # convert p to log odds (unconstrained space) before perturbing so the probabilities stay in (0, 1)
    # note: the perturbation below is effectively -h in log-odds space and the update adds the
    # gradient, so the two sign flips cancel and this remains a descent step on the KL divergence
    offset_vals = 1 / (1 + np.exp(-np.log(posterior_vals[:,0] / posterior_vals[:,1]) + h))
    offset_vals = np.column_stack([offset_vals, 1 - offset_vals])
    for j in range(N):
        # get gradient
        gradient = (kl_divergence(y[[j]], offset_vals[[j],:], mu_mu_params, np.exp(mu_sigma_params), sigma_params) - 
                    kl_divergence(y[[j]], posterior_vals[[j],:], mu_mu_params, np.exp(mu_sigma_params), sigma_params)) / h
        posterior_vals[j, 0] = 1 / (1 + np.exp(-np.log(posterior_vals[j,0] / posterior_vals[j,1]) - gradient * step_size))
        # step
        posterior_vals[j, 1] = 1 - posterior_vals[j, 0]
    

    # maximization
    # step mu1 params
    gradient11 = (kl_divergence(y, posterior_vals, mu_mu_params + h_ * h1, np.exp(mu_sigma_params), sigma_params) - 
                    kl_divergence(y, posterior_vals, mu_mu_params, np.exp(mu_sigma_params), sigma_params)) / h_
    mu_mu_params = mu_mu_params - h1 * gradient11 * mu_step_size
    gradient12 = (kl_divergence(y, posterior_vals, mu_mu_params, np.exp(mu_sigma_params + h_ * h1), sigma_params) - 
                    kl_divergence(y, posterior_vals, mu_mu_params, np.exp(mu_sigma_params), sigma_params)) / h_
    mu_sigma_params = mu_sigma_params - h1 * gradient12 * sigma_step_size
    
    # step mu2 params
    gradient21 = (kl_divergence(y, posterior_vals, mu_mu_params + h_ * h2, np.exp(mu_sigma_params), sigma_params) - 
                    kl_divergence(y, posterior_vals, mu_mu_params, np.exp(mu_sigma_params), sigma_params)) / h_
    mu_mu_params = mu_mu_params - h2 * gradient21 * mu_step_size
    gradient22 = (kl_divergence(y, posterior_vals, mu_mu_params, np.exp(mu_sigma_params + h_ * h2), sigma_params) - 
                    kl_divergence(y, posterior_vals, mu_mu_params, np.exp(mu_sigma_params), sigma_params)) / h_
    mu_sigma_params = mu_sigma_params - h2 * gradient22 * sigma_step_size

    # save intermediate values for later
    posterior_steps[i,:] = posterior_vals[:,1]
    mu_mu_steps[i,:] = mu_mu_params
    mu_sigma_steps[i,:] = np.exp(mu_sigma_params)
    kl = kl_divergence(y, posterior_vals, mu_mu_params, np.exp(mu_sigma_params), sigma_params)
    kls += [kl]

Inspect KL divergence over training iterations.

plt.plot(kls)

x_range = np.arange(y.min(), y.max() + 0.1, 0.1)

fig, ax = plt.subplots()
xdata, ydata = [], []
sc = ax.scatter(x=y, y=posterior_steps[0], c=y_df["class"], alpha=0.5)
ln1, = plt.plot(
    x_range, 
    ss.norm.pdf(x_range, mu_mu_steps[0,0], mu_sigma_steps[0,0]), 
    color="purple"
)
ln2, = plt.plot(
    x_range, 
    ss.norm.pdf(x_range, mu_mu_steps[0,1], mu_sigma_steps[0,1]), 
    color="yellow"
)

def init():
    ax.set_xlim(y.min(), y.max())
    ax.set_ylim(0, 1)
    return sc,

def update(i):
    ydata = posterior_steps[i]
    sc.set_offsets(np.column_stack([y, ydata]))

    y1 = ss.norm.pdf(x_range, mu_mu_steps[i,0], mu_sigma_steps[i,0])
    ln1.set_data(x_range, y1)
    y2 = ss.norm.pdf(x_range, mu_mu_steps[i,1], mu_sigma_steps[i,1])
    ln2.set_data(x_range, y2)
    return sc,

ani = anim.FuncAnimation(fig, update, frames=np.arange(iters),
                    init_func=init, blit=True)

ani.save("latent_learning_py.gif", fps=30);

Amortized Variational Inference

Amortized variational inference is really where VI is set free and we can start utilizing deep learning methods within Bayesian models. It is the technique that enables variational auto-encoders.

Amortized inference is motivated by the fact that in complex latent variable models the number of parameters to estimate grows with the size of the dataset. For very large datasets even variational methods become slow.

Rather than fitting an individual parameter for each data point with a latent variable, we fit a function that outputs those parameters given a data point. If the function has fewer parameters than there are data points (latent variables), we reduce the number of parameters we need to fit. In the mixture model below the encoder has just two weights, compared with the 120 per-point class probabilities we optimized in the previous section.

For this mixture model task we will use two functions. The first (categorical_network) gives us the posterior probabilities of which group an item belongs to, given its observed value. The second (normal_network) gives us the parameters of that group's normal distribution.

\[ \begin{aligned} & (p_i, 1 - p_i) = f_1(y_i) \\ & Z_i \sim \text{Categorical}(p_i, 1 - p_i) \\ & (\mu_{z_i}, \sigma_{z_i}) = f_2(z_i) \end{aligned} \]

# encoder network, this gives us our posterior distribution for z
def categorical_network(x, params):
    # (N, 1) x (1, 2) -> (N, 2)
    z = x[:,None] @ params
    return np.exp(z - logsumexp(z, axis=1)[:,None]) # softmax


# decoder network, this gives us our likelihood parameters given z
def normal_network(z, params):
    # (N, 2) x (2, 2) -> (N, 2)
    x = z @ params
    x[:,1] = np.exp(x[:,1]) # map sigma back to positive values
    return x
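
As a quick sanity check of the shapes (using hypothetical all-zero weights rather than fitted values): with zero weights the softmax assigns 0.5/0.5 to every point, and the decoder returns \(\mu = 0\) and \(\sigma = e^0 = 1\) for both classes.

probs = categorical_network(y, np.zeros((1, 2)))            # shape (N, 2) class probabilities
class_params = normal_network(np.eye(2), np.zeros((2, 2)))  # shape (2, 2): [mu, sigma] per class
print(probs.shape, class_params)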

prior_prob = np.array([0.5, 0.5])
ndraws = 100

def elbo(y, posterior_vals, decoder_params):
    #take samples from approximate distribution
    z = np.where(np.random.uniform(0, 1, len(y)) <= posterior_vals[:,0], 0, 1).astype(int)
    # convert to one hot encoding
    z_hot = np.zeros((len(y), 2))
    z_hot[np.arange(len(y)), z] = 1
    
    # get likelihood parameters given z
    y_params = normal_network(z_hot, decoder_params)

    # calculate KL
    q_z = np.log(posterior_vals[np.arange(len(y)), z]).sum()
    p_y_z = ss.norm.logpdf(y, y_params[:,0], y_params[:,1]).sum()
    p_z = np.log(prior_prob[z]).sum()
    return q_z - p_y_z - p_z

def kl_divergence(encoder_params, decoder_params):
    posterior_vals = categorical_network(y, encoder_params)
    #calculate elbo
    kl = [elbo(y, posterior_vals, decoder_params) for i in range(ndraws)]

    return np.mean(kl)

The optimization block here is getting more complex because we optimize each parameter individually.
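
The unrolled updates below all follow the same pattern: perturb one entry of a parameter matrix by h, measure the change in KL divergence, and step that entry downhill. A compact equivalent, written as a hypothetical helper (fd_coordinate_step is not part of the original notebook, which keeps the loops unrolled for readability):

def fd_coordinate_step(params, kl_of_params, h=1e-1, step_size=1e-2):
    # perturb one entry at a time and take a finite-difference descent step on it
    for idx in np.ndindex(params.shape):
        basis = np.zeros_like(params)
        basis[idx] = 1.0
        grad = (kl_of_params(params + h * basis) - kl_of_params(params)) / h
        params = params - basis * grad * step_size
    return params

# usage would look like:
# encoder_params = fd_coordinate_step(encoder_params, lambda p: kl_divergence(p, decoder_params))
# decoder_params = fd_coordinate_step(decoder_params, lambda p: kl_divergence(encoder_params, p))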

# randomly initialize parameters
encoder_params = np.random.normal(0, 0.01, (1, 2))
decoder_params = np.random.normal(0, 0.01, (2, 2))

iters = 100
# perturbation parameters
h = 1e-1
h11 = np.array([1, 0])
h12 = np.array([0, 1])
h21 = np.array([[1, 0], [0, 0]])
h22 = np.array([[0, 1], [0, 0]])
h23 = np.array([[0, 0], [1, 0]])
h24 = np.array([[0, 0], [0, 1]])
step_size = 1e-2

# data structures for saving intermediate values
posterior_steps = np.zeros((iters, N))
mu_mu_steps = np.zeros((iters, 2))
mu_sigma_steps = np.zeros((iters, 2))
kls = []

for i in range(iters):
    # expectation
    # optimize encoder
    # get gradients and step
    gradient1 = (kl_divergence(encoder_params + h * h11, decoder_params) -
                 kl_divergence(encoder_params, decoder_params)) / h
    encoder_params = encoder_params - h11 * gradient1 * step_size

    gradient2 = (kl_divergence(encoder_params + h * h12, decoder_params) -
                 kl_divergence(encoder_params, decoder_params)) / h
    encoder_params = encoder_params - h12 * gradient2 * step_size

    # maximization
    # optimize decoder
    # get gradients and step
    gradient3 = (kl_divergence(encoder_params, decoder_params + h * h21) -
                 kl_divergence(encoder_params, decoder_params)) / h
    decoder_params = decoder_params - h21 * gradient3 * step_size

    gradient4 = (kl_divergence(encoder_params, decoder_params + h * h22) -
                 kl_divergence(encoder_params, decoder_params)) / h
    decoder_params = decoder_params - h22 * gradient4 * step_size

    gradient5 = (kl_divergence(encoder_params, decoder_params + h * h23) -
                 kl_divergence(encoder_params, decoder_params)) / h
    decoder_params = decoder_params - h23 * gradient5 * step_size

    gradient6 = (kl_divergence(encoder_params, decoder_params + h * h24) -
                 kl_divergence(encoder_params, decoder_params)) / h
    decoder_params = decoder_params - h24 * gradient6 * step_size

    # save intermediate values for later
    posterior_steps[i,:] = categorical_network(y, encoder_params)[:,0]
    posterior_params = normal_network(np.array([[1,0],[0,1]]), decoder_params)
    mu_mu_steps[i,:] = posterior_params[:,0]
    mu_sigma_steps[i,:] = posterior_params[:,1]
    kls += [kl_divergence(encoder_params, decoder_params)]
plt.plot(kls)

Looking at how the latent values are fit, you can see a coupling between the points due to the smoothness of the function, whereas previously we fit each data point individually, which made the process noisier.

fig, ax = plt.subplots()
xdata, ydata = [], []
sc = ax.scatter(x=y, y=posterior_steps[0], c=y_df["class"], alpha=0.5)
ln2, = plt.plot(
    x_range, 
    ss.norm.pdf(x_range, mu_mu_steps[0,1], mu_sigma_steps[0,1]), 
    color="purple"
)
ln1, = plt.plot(
    x_range, 
    ss.norm.pdf(x_range, mu_mu_steps[0,0], mu_sigma_steps[0,0]), 
    color="yellow"
)


def init():
    ax.set_xlim(y.min(), y.max())
    ax.set_ylim(0, 1)
    return sc,

def update(i):
    ydata = posterior_steps[i]
    sc.set_offsets(np.column_stack([y, ydata]))

    y1 = ss.norm.pdf(x_range, mu_mu_steps[i,0], mu_sigma_steps[i,0])
    ln1.set_data(x_range, y1)
    y2 = ss.norm.pdf(x_range, mu_mu_steps[i,1], mu_sigma_steps[i,1])
    ln2.set_data(x_range, y2)
    return sc,

ani = anim.FuncAnimation(fig, update, frames=np.arange(iters),
                    init_func=init, blit=True)

ani.save("amortized_latent_learning_py.gif", fps=30);

References

Bayesian Data Analysis

Variational Inference: A Review for Statisticians