Computational Bayes (R)

Author

Zach Smith

smthzch.github.io

Repo for this notebook.

options(tidyverse.quiet = TRUE)
library(tidyverse)
library(gganimate)

set.seed(2021)

This notebook was developed to provide an intuitive understanding of different Bayesian inference techniques and the actual computations behind them. The focus is on the what and the how, not the why. Mathematical derivations for why a technique works are left for further research. Instead, I hope to demonstrate what each inference technique aims to accomplish and to show how it works by implementing simple solutions with minimal tooling, so that the mechanics are visible and easily understood. I also hope to show how the techniques relate to one another, to give a unified picture of what we are trying to accomplish in Bayesian inference.

This notebook is composed of two main sections representing two common problems faced in Bayesian inference.

  1. Parameter inference (global parameters)
  2. Latent variable inference (local parameters)

Suppose we have a random variable \(X\) that is distributed

\[ \begin{aligned} & \mu = 2 \\ & \sigma = 1 \\ & X \sim Normal(\mu, \sigma ) \end{aligned} \]

N <- 60
mu <- 2
sigma <- 1
x <- rnorm(N, mu, sigma)

hist(x)

If \(\mu\) is unknown, how do we perform inference? In Bayesian inference we hope to find the posterior of a parameter given some data, \(p(\mu|X=x)\). We can accomplish this using Bayes' rule:

\[ p(\mu|x) = \frac{p(x|\mu)p(\mu)}{\int_{\mu} p(x,\mu)d\mu} = \frac{p(x|\mu)p(\mu)}{p(x)} \]
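To make the roles of the numerator and the normalizing integral concrete, here is a minimal sketch that evaluates Bayes' rule on a grid of candidate \(\mu\) values. The data are re-simulated so the block runs on its own, and the Normal(0, 10) prior, the grid range, and every name ending in `_demo` or starting with `grid_` are assumptions made for illustration only.

```r
# Grid approximation of Bayes' rule for the mean of a normal with known sd.
# Data are simulated here so the block is self-contained.
set.seed(1)
x_demo <- rnorm(60, mean = 2, sd = 1)

mu_grid <- seq(-5, 5, by = 0.001)   # candidate values of mu
step <- mu_grid[2] - mu_grid[1]

# log p(x | mu) + log p(mu) at each grid point (unnormalized log posterior)
log_joint <- sapply(mu_grid, function(m) {
  sum(dnorm(x_demo, m, 1, log = TRUE)) + dnorm(m, 0, 10, log = TRUE)
})

# subtract the max before exponentiating to avoid underflow, then normalize;
# the Riemann sum in the denominator plays the role of the integral p(x)
unnormalized <- exp(log_joint - max(log_joint))
posterior_density <- unnormalized / sum(unnormalized * step)

# posterior mean and sd computed from the gridded density
grid_mean <- sum(mu_grid * posterior_density * step)
grid_sd <- sqrt(sum((mu_grid - grid_mean)^2 * posterior_density * step))
c(mean = grid_mean, sd = grid_sd)
```

A grid like this is only practical for one or two parameters, which is part of the motivation for the methods that follow.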

By specifying a prior over \(\mu\) we can then update this distribution given some data to produce a posterior distribution. Three common methods for solving this problem follow.

  1. Conjugate Prior
  2. MCMC
  3. Variational Inference

Once we get to variational inference we will also look at two ways of solving latent variable models.

  1. Latent Variable Variational Inference
  2. Amortized Variational Inference

Conjugate Prior

By matching the likelihood distribution with a conjugate prior we can solve for \(p(x)\) analytically and recover a posterior distribution whose parameters can be estimated analytically as well.

A normal distribution is the conjugate prior for \(\mu\) when the data likelihood is also a normal distribution and the standard deviation is known.

\[ \mu \sim Normal(\mu_\mu, \sigma_\mu) \]

The parameter estimates for the posterior distribution of \(\mu\) are:

\[ \begin{aligned} & \mu_{\mu|x} = \frac{\frac{\mu_\mu}{\sigma_\mu^2} + \frac{\sum X}{\sigma^2}}{\frac{1}{\sigma_\mu^2} + \frac{N}{\sigma^2}} \\ & \sigma_{\mu|x} = \sqrt{\frac{1}{\frac{1}{\sigma_\mu^2} + \frac{N}{\sigma^2}}} \end{aligned} \]

In this case let's set the hyperparameters for \(\mu\) as:

\[ \begin{aligned} & \mu_\mu = 0 \\ & \sigma_\mu = 10 \\ & \mu \sim N(\mu_\mu, \sigma_\mu) \\ & \sigma = 1 \\ & X \sim N(\mu, \sigma ) \end{aligned} \]

prior_mu <- 0
prior_sigma <- 10

mu_mu_conj <- (prior_mu / prior_sigma^2 + sum(x) / sigma^2) / (1 / prior_sigma^2 + N / sigma^2) # posterior mean for mu
mu_sigma_conj <- sqrt(1 / (1 / prior_sigma^2 + N / sigma^2)) # posterior sd for mu

estimates <- data.frame(method = "Conjugate", mu_mu = mu_mu_conj, mu_sigma = mu_sigma_conj)
knitr::kable(estimates)

method       mu_mu      mu_sigma
Conjugate    1.900653   0.1290887

This gives us an estimate of the mean of \(\mu\) of about 1.90 with a standard deviation of about 0.13, so our posterior \(p(\mu|x)\) is distributed:

\[ \begin{aligned} & \mu_{\mu|x} = 1.90 \\ & \sigma_{\mu|x} = 0.13 \\ & \mu \sim Normal(\mu_{\mu|x}, \sigma_{\mu|x}) \end{aligned} \]
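Because the posterior is a fully specified normal distribution, summaries such as credible intervals follow directly from its quantile and distribution functions. The sketch below plugs in the posterior mean and standard deviation reported in the table above. The names `post_mean`, `post_sd`, `ci_95`, and `p_above`, and the threshold 1.5, are illustrative assumptions.

```r
# Posterior parameters as reported in the table above
post_mean <- 1.900653
post_sd <- 0.1290887

# Central 95% credible interval for mu under the normal posterior
ci_95 <- qnorm(c(0.025, 0.975), mean = post_mean, sd = post_sd)
ci_95

# Posterior probability that mu exceeds 1.5 (an arbitrary threshold for illustration)
p_above <- pnorm(1.5, mean = post_mean, sd = post_sd, lower.tail = FALSE)
p_above
```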

MCMC

In some cases we may choose a prior for \(\mu\) for which we are unable or unwilling to solve for \(p(x)\) analytically. In that case we cannot analytically recover the posterior \(p(\mu|x)\). However, using Markov chain Monte Carlo (MCMC) we can still sample from the posterior.

If we use the same prior distribution and parameters as before \(\mu_\mu=0,\sigma_\mu=10\), that means we wish to draw samples that are normally distributed with mean 1.90 and standard deviation 0.13 without actually knowing those parameters.

There are different methods for performing this sampling (Gibbs, Hamiltonian, Metropolis Hastings), but the basic idea behind all of them is that we generate a series of random samples (Monte Carlo), with each sample being dependent on the previous one (Markov Chain). By being particular about which of the samples in this series we keep we can shape an arbitrary proposal distribution into producing samples that match our posterior distribution.

For this we will use one of the simplest MCMC sampling methods, Metropolis Hastings. It requires two things.

  1. A proposal distribution. This is an arbitrary distribution which we use to draw samples from.

For this we will use a normal distribution centered at the previous sample location (Markov Chain) with a standard deviation of 1.

\[ \mu_1 \sim Normal(\mu_0, 1) \]

# we need a proposal distribution 
# very tiny jumps == high acceptance rate but slow mixing
jump_sigma <- 1
rproposal <- function(mu_) rnorm(1, mu_, jump_sigma) 

The proposal distribution can be any arbitrary distribution we want. The choice of a symmetric distribution, where \(p(\mu_1|\mu_0)=p(\mu_0|\mu_1)\), simplifies the acceptance rule for us. The choice of a standard deviation of 1 is also fairly arbitrary. In theory it will not affect the results, but in practice it influences how quickly the Markov chain explores the posterior (a larger sd will jump to locations farther away in the posterior) as well as our overall acceptance probabilities (a larger sd will result in lower acceptance probabilities), both of which influence how long we need to run our chain.

  2. An acceptance rule. This is a rule dependent on the previous sample and current sample to determine whether or not to keep the current sample.

For Metropolis Hastings with a symmetric proposal distribution this rule is the ratio of the posterior density at the proposed \(\mu\) over the posterior density at the previous \(\mu\), capped at 1.

\[ p(Accept) = \min\left(1, \frac{p(\mu_1|x)}{p(\mu_0|x)}\right) = \min\left(1, \frac{p(x|\mu_1)p(\mu_1|\mu_\mu,\sigma_\mu)}{p(x|\mu_0)p(\mu_0|\mu_\mu,\sigma_\mu)}\right) \]
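The reason this ratio can be computed without knowing \(p(x)\) is that the evidence appears in both the numerator and the denominator and cancels. As a sketch of the quantities an implementation would work with, writing \(p(\mu)\) as shorthand for the prior \(p(\mu|\mu_\mu,\sigma_\mu)\) and \(r\) for the ratio (both notational assumptions introduced here), the log form is:

\[ \begin{aligned} & r = \frac{p(x|\mu_1)p(\mu_1)/p(x)}{p(x|\mu_0)p(\mu_0)/p(x)} = \frac{p(x|\mu_1)p(\mu_1)}{p(x|\mu_0)p(\mu_0)} \\ & \log r = \left[\log p(x|\mu_1) + \log p(\mu_1)\right] - \left[\log p(x|\mu_0) + \log p(\mu_0)\right] \\ & \text{accept } \mu_1 \text{ if } u \le r, \quad u \sim Uniform(0, 1) \end{aligned} \]

Since \(u \le 1\), a proposal with \(r \ge 1\) is always accepted, so comparing \(u\) against \(r\) directly gives the same decisions as comparing against \(\min(1, r)\).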

pacceptance <- function(mu1, mu0) {
  # calculations done in log for numerical stability (prevent underflows)
  log_likelihood1 <- dnorm(mu1, prior_mu, prior_sigma, log = TRUE) + sum(dnorm(x, mu1, sigma, log = TRUE))
  log_likelihood0 <- dnorm(mu0, prior_mu, prior_sigma, log = TRUE) + sum(dnorm(x, mu0, sigma, log = TRUE))
  exp(log_likelihood1 - log_likelihood0)
}

Now that we have a proposal distribution and an acceptance rule we can run our MCMC algorithm.

mu0 <- rnorm(1, prior_mu, prior_sigma) # select initial value from the prior (can be done other ways)
mus <- c() # collect the markov chain values here
# run Metropolis Hastings MCMC
iters <- 10000
for(i in 1:iters){
  mu1 <- rproposal(mu0) # propose new mu based on current mu
  acceptance_ratio <- pacceptance(mu1, mu0) # calculate acceptance probability
  # accept or reject with probability equal to acceptance_ratio
  u <- runif(1)
  mu0 <- ifelse(u <= acceptance_ratio, mu1, mu0) # if accept current mu it becomes previous mu for next step
  mus <- c(mus, mu0) # add current sample value to chain
}

# drop warmup samples
mus <- mus[100:length(mus)]

By calculating summary statistics on the sampled values we can see if they match the analytical solution.

mu_mu_mcmc <- mean(mus)
mu_sigma_mcmc <- sd(mus)

estimates <- rbind(estimates, data.frame(method = "MCMC", mu_mu = mu_mu_mcmc, mu_sigma = mu_sigma_mcmc))
knitr::kable(estimates)

method       mu_mu      mu_sigma
Conjugate    1.900653   0.1290887
MCMC         1.890423   0.1298439

We see they are fairly close, showing that we were able to sample from the posterior distribution. Let's plot the sampled distribution against the true posterior.

x1 <- seq(1, 3, by = 0.01)
hist(mus, freq = FALSE, main = "Histogram of MCMC Draws")
lines(y = dnorm(x1, mu_mu_conj, mu_sigma_conj ), x = x1)
legend("topleft", legend=c("Conjugate Posterior"), lty=c(1))
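The earlier remark that the proposal standard deviation trades acceptance rate against jump size can be checked empirically. The following is a minimal sketch, not the notebook's sampler: `run_chain` and `log_post` are hypothetical helpers, the data are re-simulated, and the chain is started at the sample mean so that the comparison reflects behavior near the posterior rather than the initial climb.

```r
set.seed(42)
x_demo <- rnorm(60, mean = 2, sd = 1)   # simulated data, known sd = 1

# unnormalized log posterior: Normal(0, 10) prior, Normal(mu, 1) likelihood
log_post <- function(m) {
  dnorm(m, 0, 10, log = TRUE) + sum(dnorm(x_demo, m, 1, log = TRUE))
}

# run_chain is a hypothetical helper: random-walk Metropolis with a normal proposal
run_chain <- function(jump_sd, iters = 5000, init = mean(x_demo)) {
  current <- init
  draws <- numeric(iters)   # preallocate the chain
  accepted <- 0
  for (i in seq_len(iters)) {
    proposal <- rnorm(1, current, jump_sd)
    # accept when log(u) <= log acceptance ratio
    if (log(runif(1)) <= log_post(proposal) - log_post(current)) {
      current <- proposal
      accepted <- accepted + 1
    }
    draws[i] <- current
  }
  list(draws = draws, acceptance_rate = accepted / iters)
}

jump_sds <- c(0.01, 0.1, 1, 10)
rates <- sapply(jump_sds, function(s) run_chain(s)$acceptance_rate)
data.frame(jump_sd = jump_sds, acceptance_rate = rates)
```

A very high acceptance rate is not a goal in itself: with tiny jumps almost every proposal is accepted, but consecutive samples are nearly identical and the chain explores slowly.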

Variational Inference

Perhaps we are impatient with MCMC, so we would like to turn this into an optimization problem. We can do this using variational inference. In this scenario we will specify an approximating posterior distribution \(q(\mu|x)\) along with the prior distribution. This approximate posterior can be any distribution we like, but our goal is to select a distribution that can match the true posterior as closely as possible. We then optimize the parameters of the approximate posterior to do this.

Because we are using conjugate priors we actually know that the true posterior is normally distributed, so first let’s pick the normal as our approximate distribution and see if we can properly recover the parameters. Next, we will try using a distribution that cannot so closely approximate the true normal posterior and see what happens.

With a normal distribution as our posterior we have two parameters to optimize for it, \(\mu_q\) and \(\sigma_q\).

\[ q(\mu|x) \sim Normal(\mu_q, \sigma_q) \]

q_z <- rnorm

In order to optimize these parameters we need to define our loss function. This is normally taken to be the KL divergence between the approximate posterior and the true posterior. With some work the KL divergence can be simplified to a form which we can actually calculate. In practice this is generally converted to the evidence lower bound (ELBO), which is maximized. The details are left to the reader to explore further.

\[ \begin{aligned} & KL(q(\mu|x)||p(\mu|x)) = \int_\mu q(\mu|x) \log\frac{q(\mu|x)}{p(\mu|x)}d\mu = \\ & E[\log\frac{q(\mu|x)}{p(\mu|x)}] = \\ & E[\log q(\mu|x) - \log p(\mu, x) + \log p(x)] = \\ & E[\log q(\mu|x) - \log p(x|\mu) - \log p(\mu) + \log p(x)] \end{aligned} \]

The last term \(\log p(x)\) does not depend on \(q(\mu|x)\) and is therefore constant within the optimization problem, so minimizing the KL divergence is equivalent to minimizing

\[ E[\log q(\mu|x) - \log p(x|\mu) - \log p(\mu)] \]
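For readers who want the connection to the ELBO spelled out, the expression above is its negative. Rearranging the KL derivation gives the following standard identity, written here with \(E_q\) to make explicit that the expectation is taken under \(q(\mu|x)\):

\[ \begin{aligned} & ELBO(q) = E_q[\log p(x|\mu) + \log p(\mu) - \log q(\mu|x)] \\ & \log p(x) = ELBO(q) + KL(q(\mu|x)||p(\mu|x)) \\ & KL(q(\mu|x)||p(\mu|x)) \ge 0 \Rightarrow ELBO(q) \le \log p(x) \end{aligned} \]

Because \(\log p(x)\) is fixed, raising the ELBO by some amount lowers the KL divergence by exactly the same amount.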

Because the result is an expectation under \(q\), we can estimate it by drawing Monte Carlo samples from the approximate posterior and averaging the resulting values.

n_draws = 100 #number of draws for the expectation in KL divergence

kl_divergence <- function(posterior_mu, posterior_sigma){
  # take samples from the approximate posterior distribution
  z <- q_z(n_draws, posterior_mu, posterior_sigma)
  # evaluate log q - log p(x|mu) - log p(mu) at each draw (the negative ELBO integrand)
  kl <- sapply(z, function(z_i){
    log_q <- dnorm(z_i, posterior_mu, posterior_sigma, log = TRUE) # approximate posterior
    log_p_x <- sum(dnorm(x, z_i, sigma, log = TRUE)) # likelihood
    log_p_z <- dnorm(z_i, prior_mu, prior_sigma, log = TRUE) # prior
    log_q - log_p_x - log_p_z
  })
  mean(kl) # Monte Carlo estimate of the expectation
}
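For this particular model the expectation also has a closed form, which gives a way to see how noisy the Monte Carlo estimate is. The sketch below compares the two at a single point. It uses the identity \(E[(a - z)^2] = (a - m)^2 + s^2\) for \(z \sim Normal(m, s)\). The function names, the re-simulated data, and the evaluation point (1.9, 0.13) are assumptions for illustration.

```r
set.seed(7)
x_demo <- rnorm(60, mean = 2, sd = 1)
sigma_demo <- 1       # known likelihood sd
prior_mu_demo <- 0
prior_sd_demo <- 10

# Monte Carlo estimate of E_q[log q - log p(x|mu) - log p(mu)], q = Normal(m, s)
neg_elbo_mc <- function(m, s, n_draws) {
  z <- rnorm(n_draws, m, s)
  vals <- sapply(z, function(z_i) {
    dnorm(z_i, m, s, log = TRUE) -
      sum(dnorm(x_demo, z_i, sigma_demo, log = TRUE)) -
      dnorm(z_i, prior_mu_demo, prior_sd_demo, log = TRUE)
  })
  mean(vals)
}

# Closed form of the same expectation
neg_elbo_exact <- function(m, s) {
  e_log_q <- -0.5 * log(2 * pi * s^2) - 0.5
  e_log_lik <- sum(-0.5 * log(2 * pi * sigma_demo^2) -
                     ((x_demo - m)^2 + s^2) / (2 * sigma_demo^2))
  e_log_prior <- -0.5 * log(2 * pi * prior_sd_demo^2) -
    ((m - prior_mu_demo)^2 + s^2) / (2 * prior_sd_demo^2)
  e_log_q - e_log_lik - e_log_prior
}

c(mc_100 = neg_elbo_mc(1.9, 0.13, 100),
  mc_10000 = neg_elbo_mc(1.9, 0.13, 10000),
  exact = neg_elbo_exact(1.9, 0.13))
```

Closed forms like this are rarely available outside of simple conjugate models, which is why the sampling-based estimate is the general-purpose tool.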

All that is left is to optimize posterior_mu and posterior_sigma. Some sort of gradient descent is generally used, but let's do a grid search over the parameter space so we can look at how the KL divergence changes across this space. We will calculate the KL divergence at each combination of proposed posterior_mu and posterior_sigma, select the combination that produces the lowest value, and plot it on the space of values.

# grid search over all combinations of mu_vals and sigma_vals
mu_vals <- seq(-2, 8, by = 0.1)
sigma_vals <- seq(0.1, 10, by = 0.1)
to_search <- expand.grid(mu_vals, sigma_vals)
to_search$kl <- 0
for(i in 1:nrow(to_search)){
  to_search[i, "kl"] <- kl_divergence(to_search[i,1], to_search[i,2]) # calculate value of kl divergence at this combination and save
}
to_search <- setNames(to_search, c("mu_mu", "mu_sigma", "kl"))

# find best value in grid and report
best_val <- to_search[which.min(to_search$kl),]
estimates <- rbind(estimates, data.frame(method = "VI-Grid Search", mu_mu = best_val$mu_mu, mu_sigma = best_val$mu_sigma))
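knitr::kable(estimates)

method            mu_mu      mu_sigma
Conjugate         1.900653   0.1290887
MCMC              1.890423   0.1298439
VI-Grid Search    1.900000   0.1000000

From our grid search we see that we have once again come close to the true posterior values. The selected values are the grid points nearest the conjugate solution, and the 0.1 grid spacing limits how closely mu_sigma can match 0.129.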

to_search %>% 
  ggplot(aes(x = mu_mu, y = mu_sigma, fill =kl)) +
  geom_tile() +
  geom_point(data = best_val, aes(x = mu_mu, y = mu_sigma), col="red")

While grid search is a useful way to visualize the shape of the parameter space, in practice it is generally too computationally intensive and requires very tight grid spacing to recover the optimal parameters.

Now let's perform a very simplified form of gradient descent. We will be performing coordinate descent, where we update one parameter at a time while holding the other fixed. This makes the updates simple to perform, though in more complex parameter spaces it will not perform well.

Also for simplicity we will compute the gradient via finite differences. Again this is not recommended in practice but is good enough for our purposes here.
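One practical wrinkle is that a finite difference of a Monte Carlo objective subtracts two noisy numbers and divides by a small \(h\), which amplifies the noise. A common remedy, sketched here as an aside rather than as what this notebook does, is to reuse the same random draws for both evaluations (common random numbers). `noisy_objective` and `fd_gradient` are hypothetical helpers, the data are re-simulated, and note that calling `set.seed` inside a function resets the global random number stream.

```r
set.seed(11)
x_demo <- rnorm(60, mean = 2, sd = 1)

# noisy objective: Monte Carlo negative ELBO for q = Normal(m, s),
# with likelihood sd 1 and a Normal(0, 10) prior
noisy_objective <- function(m, s, n_draws = 100) {
  z <- rnorm(n_draws, m, s)
  log_q <- dnorm(z, m, s, log = TRUE)
  log_lik <- sapply(z, function(z_i) sum(dnorm(x_demo, z_i, 1, log = TRUE)))
  log_prior <- dnorm(z, 0, 10, log = TRUE)
  mean(log_q - log_lik - log_prior)
}

# forward-difference gradient in m; when a seed is given, both evaluations
# are computed from the same underlying random draws
fd_gradient <- function(m, s, h = 1e-1, seed = NULL) {
  if (!is.null(seed)) set.seed(seed)
  f_plus <- noisy_objective(m + h, s)
  if (!is.null(seed)) set.seed(seed)
  f_base <- noisy_objective(m, s)
  (f_plus - f_base) / h
}

# spread of the gradient estimate at the starting point (0, 7)
independent <- replicate(200, fd_gradient(0, 7))
shared <- sapply(1:200, function(k) fd_gradient(0, 7, seed = k))
c(sd_independent = sd(independent), sd_shared = sd(shared))
```

With independent draws the spread of the gradient estimate is dominated by Monte Carlo noise, which is one reason a small step size is needed for the descent to behave.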

# initialize the parameters
mu_mu_kl <- 0
mu_sigma_kl <- 7

# optimization settings
iters <- 1000
h <- 1e-1 # perturb the parameter by this much to estimate the gradient 
step_size <- 1e-4
# track how the posterior parameters change over the iterations
posterior_steps <- as.data.frame(matrix(0, nrow = iters, ncol = 2)) %>%  setNames(c("mu", "sigma"))
kls <- c() # collect the kl divergence values here

for(i in 1:iters){
  # step mu
  gradient1 <- (kl_divergence(mu_mu_kl + h, mu_sigma_kl) - kl_divergence(mu_mu_kl, mu_sigma_kl)) / h # finite difference for gradient
  mu_mu_kl <- mu_mu_kl - gradient1 * step_size
  
  # step sigma
  gradient2 <- (kl_divergence(mu_mu_kl, mu_sigma_kl + h) - kl_divergence(mu_mu_kl, mu_sigma_kl)) / h
  mu_sigma_kl <- mu_sigma_kl - gradient2 * step_size
  
  
  posterior_steps[i,] <- c(mu_mu_kl, mu_sigma_kl)
  kls <- c(kls, kl_divergence(mu_mu_kl, mu_sigma_kl))
}

First we can inspect to make sure the KL divergence between the approximate and true posterior was reduced over iterations.

plot(kls, type = "l")

And let's see how our approximate distribution compares to the previous estimates.

estimates <- rbind(estimates, data.frame(method = "VI-Gradient Descent", mu_mu = mu_mu_kl, mu_sigma = mu_sigma_kl))
knitr::kable(estimates)

method                 mu_mu      mu_sigma
Conjugate              1.900653   0.1290887
MCMC                   1.890423   0.1298439
VI-Grid Search         1.900000   0.1000000
VI-Gradient Descent    1.850746   0.0842574

Because we traced the posterior parameters as they evolved, we can plot the optimization trajectory over the parameter space.

to_search %>% 
  ggplot() +
  geom_tile(aes(x = mu_mu, y = mu_sigma, fill = kl)) +
  geom_path(data = posterior_steps, aes(x = mu, y = sigma), col="red")

Variational Inference Posterior Mismatch

In the previous section we had the fortune of working with a model that had a known posterior distribution, so we could match the family of the approximate distribution to this known family, which gave us accurate posterior parameter estimates. But we are not restricted to using the same family of distribution for the approximate distribution. In fact, when using variational methods we are often working on models in which we don’t know the true posterior distribution. To see what happens when the approximate distribution family does not match the true posterior, we will use the same generative model as before, but use an exponential distribution as the approximate distribution.

You will notice immediately that this approximate distribution is a poor fit, because it is bounded below at 0 and cannot take the symmetric shape of the true posterior.

q_z <- rexp
n_draws = 100 #number of draws for the expectation in KL divergence

kl_divergence <- function(posterior_rate){
  # take samples from the approximate posterior distribution (exponential with the given rate)
  z <- q_z(n_draws, posterior_rate)
  # evaluate log q - log p(x|mu) - log p(mu) at each draw (the negative ELBO integrand)
  kl <- sapply(z, function(z_i){
    log_q <- dexp(z_i, posterior_rate, log = TRUE) # approximate posterior
    log_p_x <- sum(dnorm(x, z_i, sigma, log = TRUE)) # likelihood
    log_p_z <- dnorm(z_i, prior_mu, prior_sigma, log = TRUE) # prior
    log_q - log_p_x - log_p_z
  })
  mean(kl) # Monte Carlo estimate of the expectation
}

The exponential distribution has only one parameter, the rate \(\beta\), that needs to be optimized. Because \(\beta>0\) we will optimize it in unconstrained space by working with \(\log\beta\).

# initialize the parameters
mu_beta_kl <- -1.5 #this is in log domain because beta is lower bounded at 0, need to exponentiate

# optimization settings
iters <- 2000
h <- 1e-2 # perturb the parameter by this much to estimate the gradient 
step_size <- 1e-6
# track how the posterior parameters change over the iterations
posterior_steps <- as.data.frame(matrix(0, nrow = iters, ncol = 1)) %>%  setNames(c("mu"))
kls <- c() # collect the kl divergence values here

for(i in 1:iters){
  # step mu
  gradient1 <- (kl_divergence(exp(mu_beta_kl + h)) - kl_divergence(exp(mu_beta_kl))) / h # finite difference for gradient
  mu_beta_kl <- mu_beta_kl - gradient1 * step_size
    
  
  posterior_steps[i,] <- exp(mu_beta_kl)
  kls <- c(kls, kl_divergence(exp(mu_beta_kl)))
}
plot(kls, type = "l")

And let's see how our approximate distribution compares to the previous estimates.

estimates <- rbind(estimates, data.frame(method = "VI-Mismatch", mu_mu = 1 / exp(mu_beta_kl), mu_sigma = NA))
knitr::kable(estimates)

method                 mu_mu      mu_sigma
Conjugate              1.900653   0.1290887
MCMC                   1.890423   0.1298439
VI-Grid Search         1.900000   0.1000000
VI-Gradient Descent    1.850746   0.0842574
VI-Mismatch            1.580656   NA

x1 <- seq(-5, 10, by = 0.1)
plot(y = dexp(x1, exp(mu_beta_kl)), x = x1, type = "l", lty = 1) # dexp takes the rate, which is exp(mu_beta_kl)
lines(y = dnorm(x1, mu_mu_conj, mu_sigma_conj ), x = x1, lty = 2)
legend("topright", legend=c("Approximate Posterior", "Conjugate Posterior"), lty=c(1, 2))
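For an exponential \(q\) with rate \(\beta\) the objective can also be written in closed form, using \(E[z] = 1/\beta\), \(E[z^2] = 2/\beta^2\), and \(E[\log q(z)] = \log\beta - 1\). Writing \(t = 1/\beta\) for the mean of \(q\) and assuming a prior mean of 0, the objective is, up to a constant, \(-\log t + (N/\sigma^2 + 1/\sigma_\mu^2)t^2 - t\sum X/\sigma^2\). The sketch below minimizes it numerically and via the root of its derivative. All names and the re-simulated data are assumptions for illustration, and the result can serve as a reference point for judging how far a stochastic run has progressed.

```r
set.seed(3)
x_demo <- rnorm(60, mean = 2, sd = 1)
sigma_demo <- 1
prior_sd_demo <- 10   # prior mean assumed to be 0

a <- length(x_demo) / sigma_demo^2 + 1 / prior_sd_demo^2
b <- sum(x_demo) / sigma_demo^2

# objective as a function of t = 1 / rate (the mean of q), up to an additive constant
objective <- function(t) -log(t) + a * t^2 - b * t

# numerical minimizer over a bounded interval
t_numeric <- optimize(objective, interval = c(1e-3, 10))$minimum

# setting the derivative -1/t + 2 a t - b to zero gives 2 a t^2 - b t - 1 = 0;
# the positive root is the minimizer
t_closed <- (b + sqrt(b^2 + 8 * a)) / (4 * a)

c(numeric = t_numeric, closed_form = t_closed, sample_mean = mean(x_demo))
```

The optimal mean sits below the sample mean because an exponential distribution's standard deviation equals its mean, so a larger mean forces more spread than the narrow true posterior can tolerate.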

Latent Variable Variational Inference

In the previous section we used various methods to calculate or estimate the posterior distribution of a single, global parameter. In this section we continue using variational inference to learn latent parameters.

In this context a latent parameter is an unobserved variable related to each observation, so rather than estimating one parameter we need to estimate at least as many parameters as there are data points.

For this we will generate data from a mixture of normal distributions.

\[ \begin{aligned} & Z \sim Categorical(0.5, 0.5) \\ & \mu_1 = -1 \\ & \mu_2 = 1 \\ & \sigma = 1 \\ & Y | Z=z \sim Normal(\mu_z, \sigma) \end{aligned} \]

N_ <- 60
N <- N_ + N_

mus <- c(-1, 1)
sigma <- 1

y_1 <- rnorm(N_, mus[1], sigma)
y_2 <- rnorm(N_, mus[2], sigma)
y <- c(y_1, y_2)

y_df <- data.frame(class = as.factor(c(rep(1, N_), rep(2, N_))), y = y)
y_df %>%
  ggplot(aes(x = y, fill = class, group = class))+
  geom_histogram(position = "identity", alpha = 0.5)

For our model let us again assume that \(\sigma\) is known, so we must find the global \(\mu\) parameters for each class, as well as the latent class parameter for each data point.

Our joint model is then

\[ p(y,z,\mu) = p(\mu_1)p(\mu_2)\prod_{n=1}^{N}{p(y_n|\mu_{z_n},\sigma)p(z_n)} \]
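If the class means were known, Bayes' rule would give each observation's class posterior directly, and this is the quantity the per-observation variational parameters are trying to approximate. A minimal sketch, where `class_posterior` is a hypothetical helper and the default arguments mirror the generating values above:

```r
# posterior class probabilities for one observation when the class means are known
class_posterior <- function(y_val, mus = c(-1, 1), sigma = 1, prior = c(0.5, 0.5)) {
  joint <- prior * dnorm(y_val, mus, sigma)   # p(y | z) p(z) for z = 1, 2
  joint / sum(joint)                          # normalize over z
}

class_posterior(0)     # equidistant from both means, so both classes get 0.5
class_posterior(1.5)   # closer to the second mean, so class 2 is favored
```

Points near 0 remain genuinely ambiguous no matter how well the means are estimated, because the two components overlap heavily there.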

Now that we have our data let us specify our posterior distribution and our KL divergence loss metric.

prior_prob <- c(0.5, 0.5)
prior_mu <- 0
prior_sigma <- 10
ndraws <- 100

kl_divergence <- function(y, posterior_vals, mu_mu_params, mu_sigma_params, sigma_params){
  # Monte Carlo estimate of the negative ELBO
  kl <- sapply(1:ndraws, function(i){
    # take samples from the approximate distributions
    z <- ifelse(runif(length(y)) <= posterior_vals[,1], 1, 2)
    mu_params <- rnorm(2, mu_mu_params, mu_sigma_params)
    # calculate KL terms
    q_z <- sum(log(posterior_vals[cbind(1:length(y), z)])) # approximate posterior of z
    q_mu <- sum(dnorm(mu_params, mu_mu_params, mu_sigma_params, log = TRUE)) # approximate posterior of mu
    p_y_z <- sum(dnorm(y, mu_params[z], sigma_params[z], log = TRUE)) # likelihood
    p_z <- sum(log(prior_prob[z])) # class prior
    p_mu <- sum(dnorm(mu_params, prior_mu, prior_sigma, log = TRUE)) # mu prior
    # log q over all sampled unknowns minus log joint
    q_z + q_mu - p_y_z - p_z - p_mu
  })
  mean(kl)
}
# initialize parameters
posterior_vals <- matrix(0.5, nrow = length(y), ncol = 2)
mu_mu_params <- c(-3, 3)
mu_sigma_params <- log(c(1, 1)) # take log to put parameters in unconstrained space
sigma_params <- c(1, 1)

iters <- 300
# perturbation for posterior z
h <- 1
# perturbation for mu values
h_ <- 1e-1
h1 <- c(1, 0)
h2 <- c(0, 1)

step_size <- 1e-1
mu_step_size <- 1e-4
sigma_step_size <- 1e-4

posterior_steps <- as.data.frame(matrix(0, nrow = iters, ncol = length(y)))
mu_mu_steps <- as.data.frame(matrix(0, nrow = iters, ncol = 2)) %>% setNames(c("mu1", "mu2"))
mu_sigma_steps <- as.data.frame(matrix(0, nrow = iters, ncol = 2)) %>% setNames(c("sigma1", "sigma2"))
kls <- c()

for(i in 1:iters){  
  # expectation
  # perturb each point's class probability in log-odds (unconstrained) space so it stays in (0, 1);
  # the expression below shifts the log odds by -h, so `gradient` estimates the negative slope
  offset_vals <- 1 / (1 + exp(-log(posterior_vals[,1] / posterior_vals[,2]) + h))
  offset_vals <- cbind(offset_vals, 1 - offset_vals)
  for(j in 1:N){
    # finite-difference gradient for observation j
    gradient <- (kl_divergence(y[j], offset_vals[j,,drop=FALSE], mu_mu_params, exp(mu_sigma_params), sigma_params) - 
                  kl_divergence(y[j], posterior_vals[j,,drop=FALSE], mu_mu_params, exp(mu_sigma_params), sigma_params)) / h
    # step in log-odds space (adding the negative slope moves downhill), then map back to a probability
    posterior_vals[j, 1] <- 1 / (1 + exp(-log(posterior_vals[j,1] / posterior_vals[j,2]) - gradient * step_size))
    posterior_vals[j, 2] <- 1 - posterior_vals[j, 1]
  }

  # maximization
  # step mu1 params
  gradient11 <- (kl_divergence(y, posterior_vals, mu_mu_params + h_ * h1, exp(mu_sigma_params), sigma_params) - 
                kl_divergence(y, posterior_vals, mu_mu_params, exp(mu_sigma_params), sigma_params)) / h_
  mu_mu_params <- mu_mu_params - h1 * gradient11 * mu_step_size
  gradient12 <- (kl_divergence(y, posterior_vals, mu_mu_params, exp(mu_sigma_params + h_ * h1), sigma_params) - 
                kl_divergence(y, posterior_vals, mu_mu_params, exp(mu_sigma_params), sigma_params)) / h_
  mu_sigma_params <- mu_sigma_params - h1 * gradient12 * sigma_step_size
  
  # step mu2 params
  gradient21 <- (kl_divergence(y, posterior_vals, mu_mu_params + h_ * h2, exp(mu_sigma_params), sigma_params) - 
                kl_divergence(y, posterior_vals, mu_mu_params, exp(mu_sigma_params), sigma_params)) / h_
  mu_mu_params <- mu_mu_params - h2 * gradient21 * mu_step_size
  gradient22 <- (kl_divergence(y, posterior_vals, mu_mu_params, exp(mu_sigma_params + h_ * h2), sigma_params) - 
                kl_divergence(y, posterior_vals, mu_mu_params, exp(mu_sigma_params), sigma_params)) / h_
  mu_sigma_params <- mu_sigma_params - h2 * gradient22 * sigma_step_size

  # save intermediate values for later
  posterior_steps[i,] <- posterior_vals[,2]
  mu_mu_steps[i,] <- mu_mu_params
  mu_sigma_steps[i,] <- exp(mu_sigma_params)
  kl <- kl_divergence(y, posterior_vals, mu_mu_params, exp(mu_sigma_params), sigma_params)
  kls <- c(kls, kl)
}

Inspect KL divergence over training iterations.

plot(kls, type = "l")

steps_df <- posterior_steps %>%
  mutate(ix = row_number()) %>%
  pivot_longer(-ix, names_to = "dim", values_to = "value") %>%
  mutate(
    class = as.factor(rep(c(rep(1,N_), rep(2,N_)), iters)),
    y = rep(y, iters)
  )

# these steps convert the posterior parameters for mu into a density curve for the plots
mu_df <- mu_mu_steps %>%
  mutate(ix = row_number())
sigma_df <- mu_sigma_steps %>%
  mutate(ix = row_number())
mu_sigma_df <- mu_df %>%
  inner_join(sigma_df, by="ix") %>%
  group_by(ix) %>%
  nest() %>%
  mutate(
    density = map(data, ~{
      x = seq(min(y), max(y), by=0.01)
      data.frame(
        x = x, 
        mu1 = dnorm(x, .x$mu1, .x$sigma1),
        mu2 = dnorm(x, .x$mu2, .x$sigma2)
      )
    })
  ) %>%
  unnest(density) %>%
  select(-data) %>%
  pivot_longer(starts_with("mu")) %>%
  mutate(value = value / max(value))

learning_plot <- ggplot(steps_df, aes(y = value, x = y, color = class)) +
  geom_point() +
  geom_line(data = mu_sigma_df, aes(x = x, color = name, y = value)) +
  ggtitle("Posterior Probability of Class 2") +
  ylab("p(c=2)") +
  transition_manual(ix)

anim_save("latent_learning.gif", learning_plot)
knitr::include_graphics("latent_learning.gif")

Amortized Variational Inference

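Amortized variational inference is where VI really opens up, because it lets us use deep learning methods within Bayesian models. This is the approach behind variational autoencoders.

Amortized inference is motivated by the fact that in complex latent variable models the number of parameters to be estimated grows with the size of the dataset, so for very large datasets even variational methods will be slow. Instead of optimizing a separate set of variational parameters for every data point, we learn the parameters of a function (the encoder) that maps each observation to its approximate posterior, so the number of parameters no longer depends on the number of observations.

# encoder network, this gives us our posterior distribution for z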
categorical_network <- function(x, params){
  z <- x %*% params
  z <- exp(z) / rowSums(exp(z))
  z
}

# decoder network, this gives us our likelihood parameters given z
normal_network <- function(z, params){
  x <- z %*% params
  x[,2] <- exp(x[,2])
  x
}
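To see why so small an encoder can be enough for this problem, note that for two equal-variance normal components with means \(-1\) and \(1\), \(\sigma = 1\), and equal class priors, the exact log odds of class 2 versus class 1 work out to \(2y\), which is linear in \(y\). A softmax over a linear map can therefore reproduce the exact class posterior with just two weights. The sketch below checks this. `softmax_encoder` and `exact_posterior` are hypothetical helpers, and the weights are set by hand rather than learned.

```r
# softmax over a linear map of a scalar input, mirroring the encoder's structure
softmax_encoder <- function(y_vals, weights) {
  scores <- matrix(y_vals, ncol = 1) %*% matrix(weights, nrow = 1)
  exp(scores) / rowSums(exp(scores))
}

# exact class posterior for means (-1, 1), sd 1, equal class priors
exact_posterior <- function(y_vals) {
  joint <- cbind(0.5 * dnorm(y_vals, -1, 1), 0.5 * dnorm(y_vals, 1, 1))
  joint / rowSums(joint)
}

y_new <- c(-2, -0.5, 0, 0.5, 2)   # points that need not come from the training data
amortized <- softmax_encoder(y_new, weights = c(-1, 1))
exact <- exact_posterior(y_new)
max(abs(amortized - exact))   # difference is at floating point precision
```

The same two weights apply to any new observation, which is the sense in which the cost of inference is amortized across data points.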

prior_prob <- c(0.5, 0.5)
ndraws <- 100

kl_divergence <- function(encoder_params, decoder_params){
  posterior_vals <- categorical_network(y, encoder_params)
  #calculate elbo
  kl <- sapply(1:ndraws, function(i){
    #take samples from approximate distribution
    z <- ifelse(runif(length(y)) <= posterior_vals[,1], 1, 2)
    # convert to one hot encoding
    z_hot <- matrix(0, nrow = length(y), ncol = 2)
    z_hot[cbind(1:length(y), z)] <- 1
    
    # get likelihood parameters given z
    y_params <- normal_network(z_hot, decoder_params)

    # calculate KL
    q_z <- sum(log(posterior_vals[cbind(1:length(y), z)]))
    p_y_z <- sum(dnorm(y, y_params[,1], y_params[,2], log = TRUE))
    p_z <- sum(log(prior_prob[z]))
    q_z - p_y_z - p_z
  })

  mean(kl)
}
# randomly initialize parameters
encoder_params <- matrix(rnorm(2, 0, 0.01), nrow = 1)
decoder_params <- matrix(rnorm(4, 0, 0.01), nrow = 2)

iters <- 100
# perturbation parameters
h <- 1e-1
h11 <- c(1, 0)
h12 <- c(0, 1)
h21 <- matrix(c(1, 0, 0, 0), nrow = 2, byrow = TRUE)
h22 <- matrix(c(0, 1, 0, 0), nrow = 2, byrow = TRUE)
h23 <- matrix(c(0, 0, 1, 0), nrow = 2, byrow = TRUE)
h24 <- matrix(c(0, 0, 0, 1), nrow = 2, byrow = TRUE)
step_size <- 1e-2

# data structures for saving intermediate values
posterior_steps <- as.data.frame(matrix(0, nrow = iters, ncol = N))
mu_mu_steps <- as.data.frame(matrix(0, nrow = iters, ncol = 2)) %>% setNames(c("mu1", "mu2"))
mu_sigma_steps <- as.data.frame(matrix(0, nrow = iters, ncol = 2)) %>% setNames(c("sigma1", "sigma2"))
kls <- c()

for(i in 1:iters){
  # expectation
  # optimize encoder
  # get gradients and step
  gradient1 <- (kl_divergence(encoder_params + h * h11, decoder_params) - 
    kl_divergence(encoder_params, decoder_params)) / h
  encoder_params <- encoder_params - h11 * gradient1 * step_size
  
  gradient2 <- (kl_divergence(encoder_params + h * h12, decoder_params) - 
    kl_divergence(encoder_params, decoder_params)) / h
  encoder_params <- encoder_params - h12 * gradient2 * step_size 
  
  # maximization
  # optimize decoder
  # get gradients and step
  gradient3 <- (kl_divergence(encoder_params, decoder_params + h * h21) - 
    kl_divergence(encoder_params, decoder_params)) / h
  decoder_params <- decoder_params - h21 * gradient3 * step_size
  
  gradient4 <- (kl_divergence(encoder_params, decoder_params + h * h22) - 
    kl_divergence(encoder_params, decoder_params)) / h
  decoder_params <- decoder_params - h22 * gradient4 * step_size
  
  gradient5 <- (kl_divergence(encoder_params, decoder_params + h * h23) - 
    kl_divergence(encoder_params, decoder_params)) / h
  decoder_params <- decoder_params - h23 * gradient5 * step_size
  
  gradient6 <- (kl_divergence(encoder_params, decoder_params + h * h24) - 
    kl_divergence(encoder_params, decoder_params)) / h
  decoder_params <- decoder_params - h24 * gradient6 * step_size
  
  # save intermediate values for later
  posterior_steps[i,] <- categorical_network(y, encoder_params)[,1]
  mu_mu_steps[i,] <- normal_network(matrix(c(1,0,0,1), nrow=2), decoder_params)[,1]
  mu_sigma_steps[i,] <- normal_network(matrix(c(1,0,0,1), nrow=2), decoder_params)[,2]
  kls <- c(kls, kl_divergence(encoder_params, decoder_params))
}
plot(kls, type = "l")

steps_df <- posterior_steps %>%
  mutate(ix = row_number()) %>%
  pivot_longer(-ix, names_to = "dim", values_to = "value") %>%
  mutate(
    class = as.factor(rep(c(rep(1,N_), rep(2,N_)), iters)),
    y = rep(y, iters)
  )

# these steps convert the posterior parameters for mu into a density curve for the plots
mu_df <- mu_mu_steps %>%
  mutate(ix = row_number())
sigma_df <- mu_sigma_steps %>%
  mutate(ix = row_number())
mu_sigma_df <- mu_df %>%
  inner_join(sigma_df, by="ix") %>%
  group_by(ix) %>%
  nest() %>%
  mutate(
    density = map(data, ~{
      x = seq(min(y), max(y), by=0.01)
      data.frame(
        x = x, 
        mu1 = dnorm(x, .x$mu1, .x$sigma1),
        mu2 = dnorm(x, .x$mu2, .x$sigma2)
      )
    })
  ) %>%
  unnest(density) %>%
  select(-data) %>%
  pivot_longer(starts_with("mu")) %>%
  mutate(value = value / max(value))

learning_plot <- ggplot(steps_df, aes(y = value, x = y, color = class)) +
  geom_point() +
  geom_line(data = mu_sigma_df, aes(x = x, color = name, y = value)) +
  ggtitle("Posterior Probability of Class 2") +
  ylab("p(c=2)") +
  transition_manual(ix)

anim_save("amortized_latent_learning.gif", learning_plot)
knitr::include_graphics("amortized_latent_learning.gif")
