Comparison: Variational auto-encoder
Here we compare TensorFlow Probability/Edward 2, Pyro and InferPy. As a running example, we consider a variational auto-encoder (VAE) trained on the MNIST dataset of handwritten digits. Inference is performed with stochastic variational inference (SVI).
Setting up
First, we import the required packages and set the global variables. This code is common to the three frameworks:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import inferpy as inf
import pyro
import torch
import tensorflow_probability.python.edward2 as ed
# dimensionality of the latent space z
k = 2
# size of the hidden layer in the encoder and decoder NNs
d0 = 100
# dimensionality of the data (flattened 28x28 images)
dx = 28 * 28
# number of observations (dataset size)
N = 1000
# batch size
M = 100
# digits considered
DIG = [0, 1, 2]
# minimum scale: added to the softplus output so that q(z|x) never has a zero scale
scale_epsilon = 0.01
# inference parameters
num_epochs = 1000
learning_rate = 0.01
# random seed shared by all the frameworks
seed = 1234

tf.reset_default_graph()
tf.set_random_seed(seed)
# The TensorFlow seed does not affect Pyro/PyTorch, so seed them as well to
# make the Pyro run reproducible too.
pyro.set_rng_seed(seed)
Then, we can load and plot the MNIST dataset using the functionality provided by inferpy.data.mnist.
from inferpy.data import mnist
# load N training instances restricted to the digits listed in DIG
train_split, _unused = mnist.load_data(num_instances=N, digits=DIG)
x_train, y_train = train_split
# show a 5x5 grid of the loaded digits
mnist.plot_digits(x_train, grid=[5, 5])
The generated plot is shown in the figure below.
Model definition
The P and Q models are defined as functions that create random variables. In the case of the VAE, we must also define the neural networks for encoding and decoding; for simplicity, they are defined together with the models. The model definitions using InferPy, Edward 2 and Pyro are shown below.
# P model (generative model) together with the decoder NN
@inf.probmodel
def vae(k, d0, dx):
    """InferPy generative model p(x|z) p(z) of the VAE.

    Variables created inside ``inf.datamodel()`` are defined per data
    instance; the number of instances is not given here but obtained at
    inference time. The latent ``z`` has a Normal prior with loc ones(k)
    and scale 1, and the decoder NN maps it to the location of a
    unit-scale Normal over the ``dx`` observed pixels.
    """
    with inf.datamodel():
        z = inf.Normal(tf.ones(k), 1, name="z")
        hidden_layer = tf.keras.layers.Dense(d0, activation=tf.nn.relu)
        output_layer = tf.keras.layers.Dense(dx)
        decoder = inf.layers.Sequential([hidden_layer, output_layer])
        x_loc = decoder(z)
        x = inf.Normal(x_loc, 1, name="x")
# Q model (variational approximation) together with the encoder NN
@inf.probmodel
def qmodel(k, d0, dx):
    """InferPy variational model q(z|x) amortized by an encoder NN.

    The encoder maps ``x`` to 2*k outputs: the first k give the location of
    ``z`` and the last k, after softplus plus ``scale_epsilon``, give its
    (strictly positive) scale.
    """
    with inf.datamodel():
        # presumably bound to the observed data at inference time (the name
        # "x" matches the P model) -- confirm against InferPy's SVI docs
        x = inf.Normal(tf.ones(dx), 1, name="x")
        encoder = tf.keras.Sequential([
            tf.keras.layers.Dense(d0, activation=tf.nn.relu),
            tf.keras.layers.Dense(2 * k),
        ])
        params = encoder(x)
        loc = params[:, :k]
        raw_scale = params[:, k:]
        scale = tf.nn.softplus(raw_scale) + scale_epsilon
        qz = inf.Normal(loc, scale, name="z")
def vae(k, d0, dx, N):
    """Edward 2 generative model p(x|z) p(z) for a batch of N instances.

    Args:
        k: dimensionality of the latent space.
        d0: size of the decoder's hidden layer.
        dx: dimensionality of the data.
        N: number of instances (sample_shape of z).

    Returns:
        (z, x): the latent Normal variables (loc ones(k), scale 1, sample
        shape N) and the observed unit-scale Normal whose location is the
        decoder output.
    """
    z = ed.Normal(loc=tf.ones(k), scale=1., sample_shape=N, name="z")
    # The model name "decoder" and the layer names "h0"/"h1" are used later to
    # retrieve the learned weights by variable-name prefix.
    decoder = inf.layers.Sequential([
        tf.keras.layers.Dense(d0, activation=tf.nn.relu, name="h0"),
        tf.keras.layers.Dense(dx, name="h1")],
        name="decoder")
    # The decoder takes only the latent code: extra positional arguments to a
    # Keras model call are forwarded to its `training`/`mask` parameters.
    x = ed.Normal(loc=decoder(z), scale=1., name="x")
    return z, x
# Q model for making inference which is parametrized by the data x.
def qmodel(k, d0, x):
    """Edward 2 variational model q(z|x) amortized by an encoder NN.

    The encoder (named "encoder", layers "h0"/"h1") maps each row of ``x``
    to 2*k outputs. The first k are the location of the returned Normal
    "qz". The last k, passed through softplus and shifted by
    ``scale_epsilon``, are its scale.
    """
    hidden = tf.keras.layers.Dense(d0, activation=tf.nn.relu, name="h0")
    head = tf.keras.layers.Dense(2 * k, name="h1")
    encoder = tf.keras.Sequential([hidden, head], name="encoder")
    params = encoder(x)
    loc, raw_scale = params[:, :k], params[:, k:]
    return ed.Normal(loc=loc, scale=tf.nn.softplus(raw_scale) + scale_epsilon, name="qz")
class Decoder(torch.nn.Module):
    """Decoder NN of the Pyro VAE: maps a latent code z to the mean of p(x|z).

    Architecture: Linear(k, d0) -> ReLU -> Linear(d0, dx), the same as the
    decoders of the InferPy and Edward 2 versions.
    """

    def __init__(self, k, d0, dx):
        super(Decoder, self).__init__()
        # hidden and output linear transformations
        self.fc1 = torch.nn.Linear(k, d0)
        self.fc21 = torch.nn.Linear(d0, dx)
        self.relu = torch.nn.ReLU()

    def forward(self, z):
        """Decode latent codes.

        Args:
            z: latent codes; presumably of shape (batch_size, k).

        Returns:
            Tensor whose last dimension is dx. VAE.model uses it as the
            location of a unit-scale Normal likelihood.
        """
        hidden = self.relu(self.fc1(z))
        # No squashing non-linearity: the output is the mean of a Normal,
        # not a Bernoulli probability.
        return self.fc21(hidden)
class Encoder(torch.nn.Module):
    """Encoder NN of the Pyro VAE: maps an image x to the parameters of q(z|x).

    Architecture: Linear(dx, d0) -> ReLU, followed by two Linear(d0, k)
    heads for the location and the (softplus-transformed) scale. The ReLU
    hidden activation matches the encoders of the InferPy and Edward 2
    versions.
    """

    def __init__(self, k, d0, dx):
        super(Encoder, self).__init__()
        self.fc1 = torch.nn.Linear(dx, d0)
        self.fc21 = torch.nn.Linear(d0, k)  # location head
        self.fc22 = torch.nn.Linear(d0, k)  # scale head (before softplus)
        self.relu = torch.nn.ReLU()
        self.softplus = torch.nn.Softplus()

    def forward(self, x):
        """Encode a mini-batch.

        Args:
            x: flattened images; assumes pixels are already in the last
                dimension (size dx), as no reshaping is done here.

        Returns:
            (z_loc, z_scale): location and strictly positive scale of
            q(z|x), each with last dimension k.
        """
        hidden = self.relu(self.fc1(x))
        z_loc = self.fc21(hidden)
        # softplus keeps the scale positive; scale_epsilon bounds it away from 0
        z_scale = self.softplus(self.fc22(hidden)) + scale_epsilon
        return z_loc, z_scale
class VAE(torch.nn.Module):
    """Pyro VAE bundling the encoder/decoder NNs with the model and the guide."""

    def __init__(self, k=2, d0=100, dx=784):
        super(VAE, self).__init__()
        # create the encoder and decoder networks
        self.encoder = Encoder(k, d0, dx)
        self.decoder = Decoder(k, d0, dx)
        self.k = k

    def model(self, x):
        """Generative model p(x|z) p(z) for a mini-batch x."""
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        batch_size = x.shape[0]
        with pyro.plate("data", batch_size):
            # prior p(z) = Normal(0, I)
            # NOTE(review): the InferPy and Edward 2 versions use a prior
            # location of ones (tf.ones(k)); confirm which one is intended.
            z_loc = x.new_zeros((batch_size, self.k))
            z_scale = x.new_ones((batch_size, self.k))
            # sample from prior (value will be sampled by guide when computing the ELBO)
            z = pyro.sample("latent", pyro.distributions.Normal(z_loc, z_scale).to_event(1))
            # decode the latent code z; calling the module (not .forward) runs its hooks
            loc_img = self.decoder(z)
            # score against actual images
            pyro.sample("obs", pyro.distributions.Normal(loc_img, 1).to_event(1), obs=x)

    def guide(self, x):
        """Variational distribution q(z|x) computed by the encoder."""
        # register PyTorch module `encoder` with Pyro
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            # use the encoder to get the parameters used to define q(z|x)
            z_loc, z_scale = self.encoder(x)
            # sample the latent code z
            pyro.sample("latent", pyro.distributions.Normal(z_loc, z_scale).to_event(1))
With InferPy we do not need to specify the size of the data (e.g., in a plate or datamodel construct); instead, it is obtained automatically at inference time.
With InferPy and Edward 2, models are defined as functions, though InferPy requires the decorator @inf.probmodel. On the
other hand, even though the neural networks are the same, in the Edward 2 code they are given names, because these names
will later be used to access the learned weights. The code in Pyro (adapted from the one in the official documentation)
is quite different, as a class structure is used.
Inference
Setting up the inference and batched data
In Edward 2, before optimizing the variational parameters, we must: split the data into batches; create the instances of the P and Q
models; and finally build the tensor for computing the variational ELBO, which is the objective to be maximized.
The equivalent code using InferPy is much simpler, because most of the functionality is handled transparently to the user:
we simply instantiate the P and Q models and the corresponding inference algorithm. Pyro's code also remains quite simple because
most of the inference details are also encapsulated. However, the user is required to split the data into batches using
torch.utils.data.DataLoader.
# instantiate the P and Q models
m = vae(k, d0, dx)
q = qmodel(k, d0, dx)
# set the inference algorithm; use the shared num_epochs setting instead of a
# hard-coded value so that all frameworks train for the same number of epochs
SVI = inf.inference.SVI(q, epochs=num_epochs, batch_size=M)
# Endless stream of shuffled mini-batches of size M. A shuffle buffer of N
# (the whole dataset) gives a uniform shuffle; a buffer of only M would just
# mix instances within a sliding window of M.
batch = tf.data.Dataset.from_tensor_slices(x_train)\
    .shuffle(N)\
    .batch(M)\
    .repeat()\
    .make_one_shot_iterator().get_next()
# Q model parametrized by the current batch
qz = qmodel(k, d0, batch)
# Run the P model with z set to the variational sample and x set to the data.
with ed.interception(ed.make_value_setter(z=qz, x=batch)):
    pz, px = vae(k, d0, dx, M)
# Every batch term is scaled by N/M to estimate the full-dataset ELBO.
energy = N / M * tf.reduce_sum(pz.distribution.log_prob(pz.value)) + \
    N / M * tf.reduce_sum(px.distribution.log_prob(px.value))
# log q(z|x) term (the negative entropy), subtracted in the ELBO
entropy = N / M * tf.reduce_sum(qz.distribution.log_prob(qz.value))
elbo = energy - entropy
## setting up batched data
vae = VAE(k, d0, dx)
# Load data and set batch_size. The dtype of x_train is not visible here, so
# convert explicitly to float32 to match the default nn.Linear weights.
# shuffle=False keeps the batches in the same order as x_train, so values
# computed by iterating over the loader stay aligned with x_train/y_train.
train_loader = torch.utils.data.DataLoader(
    torch.tensor(x_train, dtype=torch.float32), batch_size=M, shuffle=False)
# setup the optimizer
adam_args = {"lr": learning_rate}
optimizer = pyro.optim.Adam(adam_args)
# setup the inference algorithm
svi = pyro.infer.SVI(vae.model, vae.guide, optimizer, loss=pyro.infer.Trace_ELBO())
Optimization loop
In variational inference, parameters are optimized iteratively. When using Edward 2, we must first specify the TensorFlow optimizer and training operation; then the loop is coded explicitly as shown below. With Pyro, the optimization loop must also be coded, calling svi.step() at each iteration. By contrast, with InferPy, we simply invoke the method probmodel.fit(), which takes as input the data and the previously defined inference algorithm object.
# learn the parameters: fit the P model to the observed data with SVI
observed = {"x": x_train}
m.fit(observed, SVI)
sess = tf.Session()
optimizer = tf.train.AdamOptimizer(learning_rate)
# maximize the ELBO by minimizing its negative
train = optimizer.minimize(-elbo)
init = tf.global_variables_initializer()
sess.run(init)
# ELBO value after every training step
t = []
num_batches = N // M
# run exactly num_epochs epochs, matching the InferPy and Pyro versions
for i in range(num_epochs):
    for j in range(num_batches):
        elbo_ij, _ = sess.run([elbo, train])
        t.append(elbo_ij)
        if j == 0 and i % 200 == 0:
            print("\n {} epochs\t {}".format(i, t[-1]), end="", flush=True)
        if j == 0 and i % 20 == 0:
            print(".", end="", flush=True)
# per-epoch ELBO estimates (negated loss, normalized by the dataset size)
train_elbo = []
pyro.clear_param_store()
# the dataset size does not change during training: compute it once
normalizer_train = len(train_loader.dataset)
# training loop
for epoch in range(num_epochs):
    epoch_loss = 0.
    for x in train_loader:
        # do ELBO gradient and accumulate loss
        epoch_loss += svi.step(x)
    total_epoch_loss_train = epoch_loss / normalizer_train
    train_elbo.append(-total_epoch_loss_train)
    if epoch % 10 == 0:
        print(total_epoch_loss_train)
Usage of the inferred model
Once optimization is finished, we can use the model with the inferred parameters. For example, we
might obtain the hidden representation of the original data, which is done by passing such data through the encoder.
Edward 2 does not provide any functionality for this purpose, so we use plain TensorFlow code. With InferPy, this is done by simply using the method probmodel.posterior(), as follows.
For this, Pyro requires invoking the class method vae.encoder.forward, which performs the forward propagation in the encoder NN.
# sample z from the posterior, one batch of M instances at a time
batch_samples = []
for start in range(0, N, M):
    batch_data = {"x": x_train[start:start + M, :]}
    batch_samples.append(m.posterior("z", data=batch_data).sample())
postz = np.concatenate(batch_samples)
def get_tfvar(name, variables=None):
    """Return the first variable whose name starts with ``name``.

    Args:
        name: variable-name prefix, e.g. "encoder/h0/kernel".
        variables: optional iterable of objects with a ``.name`` attribute
            to search; defaults to ``tf.trainable_variables()``.

    Returns:
        The first matching variable.

    Raises:
        ValueError: if no variable matches. Returning None instead would
            only surface later as an obscure failure in ``sess.run``.
    """
    if variables is None:
        variables = tf.trainable_variables()
    for v in variables:
        if v.name.startswith(name):
            return v
    raise ValueError("no trainable variable with name prefix {!r}".format(name))
def predictive_nn(x, beta0, alpha0, beta1, alpha1):
    """Forward pass of a one-hidden-layer ReLU network with given weights.

    Computes ``relu(x @ beta0 + alpha0) @ beta1 + alpha1``, i.e. the two
    Dense layers of the encoder/decoder evaluated with weight values
    fetched from the session. Returns a TensorFlow tensor.
    """
    pre_activation = x @ beta0 + alpha0
    hidden = tf.nn.relu(pre_activation)
    return hidden @ beta1 + alpha1
# fetch the learned encoder weights in the order: h0 kernel, h0 bias, h1 kernel, h1 bias
encoder_var_names = ["encoder/h{}/{}".format(layer, param)
                     for layer in (0, 1) for param in ("kernel", "bias")]
weights_encoder = [sess.run(get_tfvar(var_name)) for var_name in encoder_var_names]
# the first k encoder outputs are the location of q(z|x)
encoder_output = predictive_nn(x_train, *weights_encoder)
postz = sess.run(encoder_output[:, :k])
# extract the posterior of z: the encoder's location output for every batch,
# concatenated in loader order
loc_batches = []
for x_batch in train_loader:
    z_loc_batch = vae.encoder.forward(x_batch)[0]
    loc_batches.append(z_loc_batch.detach().numpy())
postz = np.concatenate(loc_batches)
The result of plotting the hidden representation is:
We might also be interested in generating new digits, which implies passing some points of the hidden space
through the decoder. With InferPy we just invoke the method probmodel.posterior_predictive().
In Pyro this is done by invoking the class method vae.decoder.forward, which performs the forward propagation in the decoder NN.
# generate new digits by feeding the first M posterior values of z to the P model
z_values = {"z": postz[:M, :]}
x_gen = m.posterior_predictive('x', data=z_values).sample()
mnist.plot_digits(x_gen, grid=[5, 5])
# fetch the learned decoder weights and decode the posterior locations
weights_decoder = [sess.run(get_tfvar("decoder/h" + name)) for name in ["0/kernel", "0/bias", "1/kernel", "1/bias",]]
x_gen = sess.run(predictive_nn(postz, *weights_decoder)[:, :dx])
# plot the first nx * ny generated digits in a grid of ny rows and nx columns
nx, ny = (3, 3)
# Create ny rows so that row index y < ny is always valid. squeeze=False
# always returns a 2-D array of axes, so ax[y, x] works for any grid shape,
# including a single row or column.
fig, ax = plt.subplots(ny, nx, figsize=(12, 12), squeeze=False)
fig.tight_layout(pad=0.3, rect=[0, 0, 0.9, 0.9])
for x in range(nx):
    for y in range(ny):
        img_i = x_gen[x + y * nx].reshape((28, 28))
        ax[y, x].imshow(img_i, cmap='gray')
plt.show()
## generate new digits
z_subset = torch.Tensor(postz[:M, :])
x_gen = vae.decoder.forward(z_subset)
# detach from the autograd graph before converting to numpy for plotting
x_gen_np = x_gen.detach().numpy()
mnist.plot_digits(x_gen_np, grid=[5, 5])
Some of the resulting images are shown below.