Comparison: Variational auto-encoder

Here we compare tensorflow-probability/Edward 2, Pyro and InferPy. As a running example, we consider a variational auto-encoder (VAE) trained on the MNIST dataset of handwritten digits. For the inference, the SVI (stochastic variational inference) method will be used.

Setting up

First, we import the required packages and set the global variables. This code is common to the three frameworks:

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import inferpy as inf
import pyro
import torch
import tensorflow_probability.python.edward2 as ed

# number of components
k = 2
# size of the hidden layer in the NN
d0 = 100
# dimensionality of the data
dx = 28 * 28
# number of observations (dataset size)
N = 1000
# batch size
M = 100
# digits considered
DIG = [0, 1, 2]
# minimum scale
scale_epsilon = 0.01
# inference parameters
num_epochs = 1000
learning_rate = 0.01

tf.reset_default_graph()
tf.set_random_seed(1234)

Then, we can load and plot the MNIST dataset using the functionality provided by inferpy.data.mnist.

from inferpy.data import mnist

# load the data
(x_train, y_train), _ = mnist.load_data(num_instances=N, digits=DIG)

mnist.plot_digits(x_train, grid=[5,5])

The generated plot is shown in the figure below.

MNIST training data.
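
If the inferpy.data.mnist helper is not available, a roughly equivalent subset can be built directly from tf.keras.datasets.mnist. The following is only a sketch: the exact preprocessing applied by the helper (e.g. pixel scaling) is an assumption here.

# minimal sketch of loading an equivalent MNIST subset without the inferpy helper
# (assumes pixels are flattened and scaled to [0, 1], as the helper presumably does)
(x_all, y_all), _ = tf.keras.datasets.mnist.load_data()
mask = np.isin(y_all, DIG)
x_train = x_all[mask][:N].reshape(-1, dx).astype(np.float32) / 255.0
y_train = y_all[mask][:N]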

Model definition

The P and Q models are defined as functions creating random variables. In the case of the VAE, we must also define the neural networks for encoding and decoding; for simplicity, in InferPy and Edward 2 they are defined inside those model functions, whereas Pyro wraps them in torch.nn.Module classes. The model definitions using InferPy, Edward 2 and Pyro are shown below.

InferPy:

# P model and the decoder NN
@inf.probmodel
def vae(k, d0, dx):
    with inf.datamodel():
        z = inf.Normal(tf.ones(k), 1, name="z")

        decoder = inf.layers.Sequential([
            tf.keras.layers.Dense(d0, activation=tf.nn.relu),
            tf.keras.layers.Dense(dx)])

        x = inf.Normal(decoder(z), 1, name="x")

# Q model for making inference
@inf.probmodel
def qmodel(k, d0, dx):
    with inf.datamodel():
        x = inf.Normal(tf.ones(dx), 1, name="x")

        encoder = tf.keras.Sequential([
            tf.keras.layers.Dense(d0, activation=tf.nn.relu),
            tf.keras.layers.Dense(2 * k)])

        output = encoder(x)
        qz_loc = output[:, :k]
        qz_scale = tf.nn.softplus(output[:, k:]) + scale_epsilon
        qz = inf.Normal(qz_loc, qz_scale, name="z")

Edward 2:

def vae(k, d0, dx, N):
    z = ed.Normal(loc=tf.ones(k), scale=1., sample_shape=N, name="z")

    decoder = tf.keras.Sequential([
        tf.keras.layers.Dense(d0, activation=tf.nn.relu, name="h0"),
        tf.keras.layers.Dense(dx, name="h1")],
        name="decoder")

    x = ed.Normal(loc=decoder(z), scale=1., name="x")
    return z, x



# Q model for making inference which is parametrized by the data x.
def qmodel(k, d0, x):
    encoder = tf.keras.Sequential([
        tf.keras.layers.Dense(d0, activation=tf.nn.relu, name="h0"),
        tf.keras.layers.Dense(2 * k, name="h1")],
    name = "encoder")
    output = encoder(x)
    qz_loc = output[:, :k]
    qz_scale = tf.nn.softplus(output[:, k:]) + scale_epsilon
    qz = ed.Normal(loc=qz_loc, scale=qz_scale, name="qz")
    return qz

Pyro:

class Decoder(torch.nn.Module):
    def __init__(self, k, d0, dx):
        super(Decoder, self).__init__()
        # setup the two linear transformations used
        self.fc1 = torch.nn.Linear(k, d0)
        self.fc21 = torch.nn.Linear(d0, dx)
        # setup the non-linearities
        self.softplus = torch.nn.Softplus()
        self.sigmoid = torch.nn.Sigmoid()
        self.relu = torch.nn.ReLU()

    def forward(self, z):
        # define the forward computation on the latent z
        # first compute the hidden units
        hidden = self.relu(self.fc1(z))
        # return the location parameter of the output Normal
        # each is of size batch_size x 784
        loc_img = self.fc21(hidden)
        return loc_img


class Encoder(torch.nn.Module):
    def __init__(self, k, d0, dx):
        super(Encoder, self).__init__()
        # setup the three linear transformations used
        self.fc1 = torch.nn.Linear(dx, d0)
        self.fc21 = torch.nn.Linear(d0, k)
        self.fc22 = torch.nn.Linear(d0, k)
        # setup the non-linearities
        self.softplus = torch.nn.Softplus()

    def forward(self, x):
        # define the forward computation on the image x
        # first shape the mini-batch to have pixels in the rightmost dimension
        # then compute the hidden units
        hidden = self.softplus(self.fc1(x))
        # then return a mean vector and a (positive) square root covariance
        # each of size batch_size x k
        z_loc = self.fc21(hidden)
        z_scale = self.softplus(self.fc22(hidden))
        return z_loc, z_scale + scale_epsilon


class VAE(torch.nn.Module):
    def __init__(self, k=2, d0=100, dx=784):
        super(VAE, self).__init__()
        # create the encoder and decoder networks
        self.encoder = Encoder(k, d0, dx)
        self.decoder = Decoder(k, d0, dx)
        self.k = k

    def model(self, x):
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        with pyro.plate("data", x.shape[0]):
            # setup hyperparameters for prior p(z)
            z_loc = x.new_zeros(torch.Size((x.shape[0], self.k)))
            z_scale = x.new_ones(torch.Size((x.shape[0], self.k)))
            # sample from prior (value will be sampled by guide when computing the ELBO)
            z = pyro.sample("latent", pyro.distributions.Normal(z_loc, z_scale).to_event(1))
            # decode the latent code z
            loc_img = self.decoder.forward(z)
            # score against actual images
            pyro.sample("obs", pyro.distributions.Normal(loc_img, 1).to_event(1), obs=x)

    # define the guide (i.e. variational distribution) q(z|x)
    def guide(self, x):
        # register PyTorch module `encoder` with Pyro
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            # use the encoder to get the parameters used to define q(z|x)
            z_loc, z_scale = self.encoder.forward(x)
            # sample the latent code z
            pyro.sample("latent", pyro.distributions.Normal(z_loc, z_scale).to_event(1))

With InferPy, we do not need to specify the size of the data in the model definition (i.e., in the plate or datamodel construct); it is obtained automatically at inference time.

With InferPy and Edward 2, models are defined as functions, although InferPy additionally requires the decorator @inf.probmodel. The neural networks can be essentially the same, but in the Edward 2 code they are given explicit names, as these are later used to access the learned weights. The Pyro code (adapted from the official documentation) is quite different, as it relies on a class structure.
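
As a side note, once the Edward 2 model and variational model have been instantiated (as done in the inference section below), the layer names chosen above become prefixes of the TensorFlow variable names, which is what the weight-extraction code in the last section relies on. The following is a minimal sketch for inspecting them; the exact names (e.g. decoder/h0/kernel:0) are an expectation, not something verified here:

# list the trainable variables created by the named Keras layers;
# names such as "decoder/h0/kernel:0" are expected (assumption, not verified)
for v in tf.trainable_variables():
    print(v.name, v.shape)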

Inference

Setting up the inference and batched data

In Edward 2, before optimizing the variational parameters we must: split the data into batches; create the instances of the P and Q models; and finally build the tensor computing the variational ELBO, which is the objective to be optimized. The equivalent InferPy code is much simpler, because most of this functionality is handled transparently: we simply instantiate the P and Q models and the corresponding inference algorithm. Pyro's code also remains quite simple, as most of the inference details are encapsulated, although the user is still required to split the data into batches using torch.utils.data.DataLoader.

InferPy:

m = vae(k, d0, dx)
q = qmodel(k, d0, dx)

# set the inference algorithm
SVI = inf.inference.SVI(q, epochs=num_epochs, batch_size=M)

Edward 2:

batch = tf.data.Dataset.from_tensor_slices(x_train)\
        .shuffle(M)\
        .batch(M)\
        .repeat()\
        .make_one_shot_iterator().get_next()

qz = qmodel(k, d0, batch)

with ed.interception(ed.make_value_setter(z=qz, x=batch)):
    pz, px = vae(k, d0, dx, M)

energy = N/M*tf.reduce_sum(pz.distribution.log_prob(pz.value)) + \
         N/M*tf.reduce_sum(px.distribution.log_prob(px.value))
entropy = N/M*tf.reduce_sum(qz.distribution.log_prob(qz.value))

elbo = energy - entropy

Pyro:

## setting up batched data
vae = VAE(k, d0, dx)

# Load data and set batch_size
train_loader = torch.utils.data.DataLoader(torch.tensor(x_train, dtype=torch.float32), batch_size=M, shuffle=False)

# setup the optimizer
adam_args = {"lr": learning_rate}
optimizer = pyro.optim.Adam(adam_args)

# setup the inference algorithm
svi = pyro.infer.SVI(vae.model, vae.guide, optimizer, loss=pyro.infer.Trace_ELBO())
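
In all three frameworks, SVI maximizes a Monte Carlo estimate of the evidence lower bound (ELBO) computed on each mini-batch. With a batch of M out of N observations and a single sample z_i drawn from q(z | x_i), the estimate that the Edward 2 code above builds explicitly is

\mathrm{ELBO} \approx \frac{N}{M}\sum_{i=1}^{M}\left[\log p(x_i \mid z_i) + \log p(z_i) - \log q(z_i \mid x_i)\right], \qquad z_i \sim q(z \mid x_i).

InferPy and Pyro build an analogous objective internally (Pyro through pyro.infer.Trace_ELBO), which is why their setup code is shorter.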

Optimization loop

In variational inference, the parameters are optimized iteratively. When using Edward 2, we must first create the TensorFlow optimizer and training objects and then code the loop explicitly, as shown below. With Pyro, the optimization loop must also be coded by hand, calling svi.step() at each iteration. By contrast, with InferPy we simply invoke the method probmodel.fit(), which takes as input the data and the inference algorithm object defined previously.

InferPy:

# learn the parameters
m.fit({"x": x_train}, SVI)

Edward 2:

sess = tf.Session()
optimizer = tf.train.AdamOptimizer(learning_rate)
train = optimizer.minimize(-elbo)
init = tf.global_variables_initializer()
sess.run(init)

t = []
for i in range(num_epochs + 1):
    for j in range(N // M):
        elbo_ij, _ = sess.run([elbo, train])

        t.append(elbo_ij)
        if j == 0 and i % 200 == 0:
            print("\n {} epochs\t {}".format(i, t[-1]), end="", flush=True)
        if j == 0 and i % 20 == 0:
            print(".", end="", flush=True)

Pyro:

train_elbo = []
pyro.clear_param_store()

# training loop
for epoch in range(num_epochs):
    epoch_loss = 0.
    for x in train_loader:
        # do ELBO gradient and accumulate loss
        epoch_loss += svi.step(x)

    normalizer_train = len(train_loader.dataset)
    total_epoch_loss_train = epoch_loss / normalizer_train

    train_elbo.append(-total_epoch_loss_train)


    if (epoch % 10) == 0:
        print(total_epoch_loss_train)
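
As a quick convergence check, the ELBO traces collected above (t in the Edward 2 loop, train_elbo in the Pyro loop) can be plotted. Below is a minimal sketch using the Pyro trace; the Edward 2 trace can be plotted in the same way.

# plot the per-epoch ELBO collected during the Pyro training loop
plt.plot(train_elbo)
plt.xlabel("epoch")
plt.ylabel("ELBO")
plt.show()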

Usage of the inferred model

Once the optimization is finished, we can use the model with the inferred parameters. For example, we might obtain the hidden representation of the original data, which is done by passing the data through the encoder. With InferPy, this is done simply with the method probmodel.posterior(), as shown below. Edward 2 does not provide specific functionality for this purpose, so plain TensorFlow code is used instead. In Pyro, we invoke the method vae.encoder.forward, which performs the forward propagation through the encoder NN.

InferPy:

# extract the posterior of z
postz = np.concatenate([
    m.posterior("z", data={"x": x_train[i:i+M,:]}).sample()
    for i in range(0,N,M)])

Edward 2:

def get_tfvar(name):
    for v in tf.trainable_variables():
        if str.startswith(v.name, name):
            return v

def predictive_nn(x, beta0, alpha0, beta1, alpha1):
    h0 = tf.nn.relu(x @ beta0 + alpha0)
    output = h0 @ beta1 + alpha1

    return output

weights_encoder = [sess.run(get_tfvar("encoder/h" + name)) for name in ["0/kernel", "0/bias", "1/kernel", "1/bias",]]
postz = sess.run(predictive_nn(x_train, *weights_encoder)[:, :k])

Pyro:

# extract the posterior of z
postz = np.concatenate([
    vae.encoder.forward(x)[0].detach().numpy()
    for x in train_loader])
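
The 2-dimensional codes in postz can then be visualized, for instance, with a scatter plot coloured by the digit labels in y_train. A minimal sketch (colours and transparency are arbitrary choices):

# scatter plot of the 2-dimensional latent codes, one colour per digit
for d in DIG:
    plt.scatter(postz[y_train == d, 0], postz[y_train == d, 1], label=str(d), alpha=0.5)
plt.legend()
plt.show()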

The result of plotting the hidden representation is:

Plot of the hidden encoding.

We might also be interested in generating new digits, which amounts to passing points in the hidden space through the decoder. With InferPy, we just invoke the method probmodel.posterior_predictive(). In Edward 2, the TensorFlow helper defined above is reused with the decoder weights. In Pyro, this is done by invoking vae.decoder.forward, which performs the forward propagation through the decoder NN.

InferPy:

x_gen = m.posterior_predictive('x', data={"z": postz[:M,:]}).sample()
mnist.plot_digits(x_gen, grid=[5,5])

Edward 2:

weights_decoder = [sess.run(get_tfvar("decoder/h" + name)) for name in ["0/kernel", "0/bias", "1/kernel", "1/bias"]]
x_gen = sess.run(predictive_nn(postz, *weights_decoder)[:, :dx])

nx, ny = (3,3)
fig, ax = plt.subplots(nx, ny, figsize=(12, 12))
fig.tight_layout(pad=0.3, rect=[0, 0, 0.9, 0.9])
for x, y in [(i, j) for i in list(range(nx)) for j in list(range(ny))]:
    img_i = x_gen[x + y * nx].reshape((28, 28))
    i = (y, x) if nx > 1 else y
    ax[i].imshow(img_i, cmap='gray')
plt.show()

Pyro:

## generate new digits
x_gen = vae.decoder.forward(torch.Tensor(postz[:M,:]))
mnist.plot_digits(x_gen.detach().numpy(), grid=[5,5])

Some of the resulting images are shown below.

MNIST generated data.