Comparison: Logistic Regression¶
Here, the InferPy code is compared with that of two similar frameworks: TensorFlow Probability (Edward 2) and Pyro. A Bayesian logistic regression model is considered.
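Concretely, the model under study (matching the code below) is

w_0 \sim \mathcal{N}(0, 1), \qquad w_j \sim \mathcal{N}(0, 1), \qquad y_n \mid x_n \sim \mathrm{Bernoulli}\big(\sigma(w_0 + x_n^\top w)\big), \qquad n = 1, \dots, N,

where \sigma is the logistic function and the predictive attributes x_{nj} \sim \mathcal{N}(0, 2) (scale 2) are simulated as toy data.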
Setting up¶
First, the required packages are imported. Variable d is the number of predictive attributes, while N is the number of observations.
# InferPy
import inferpy as inf
import numpy as np
import tensorflow as tf

d = 2
N = 10000
# TFP/Edward 2
from tensorflow_probability import edward2 as ed
import tensorflow as tf

d = 2
N = 50000
# Pyro
import torch
import pyro
from pyro.distributions import Normal, Binomial
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro.contrib.autoguide import AutoDiagonalNormal

d = 2
N = 1000
Model definition¶
Models are defined as functions. In the case of InferPy, these must be decorated with @inf.probmodel. Inspired by Pyro, InferPy uses the construct inf.datamodel to simplify the definition of the variables' dimensions. A minimal illustration of inf.datamodel is given next, followed by the P and Q models in each framework.
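As a quick sketch of the semantics (the names toy_model, mu and z are illustrative only, not part of the comparison), variables created inside the inf.datamodel context are replicated once per data point, with the batch size fixed later by the observed data or an explicit sample size:

@inf.probmodel
def toy_model():
    mu = inf.Normal(0., 1., name="mu")       # global variable: a single scalar
    with inf.datamodel():
        z = inf.Normal(mu, 1., name="z")     # local variable: one z per data point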
# InferPy: P and Q models
@inf.probmodel
def log_reg(d):
    w0 = inf.Normal(0., 1., name="w0")
    w = inf.Normal(np.zeros([d, 1]), np.ones([d, 1]), name="w")

    with inf.datamodel():
        x = inf.Normal(np.zeros(d), 2., name="x")  # the scale is broadcast to shape [d] because of loc
        y = inf.Bernoulli(logits=w0 + x @ w, name="y")

@inf.probmodel
def qmodel(d):
    qw0_loc = inf.Parameter(0., name="qw0_loc")
    qw0_scale = tf.math.softplus(inf.Parameter(1., name="qw0_scale"))  # softplus keeps the scale positive
    qw0 = inf.Normal(qw0_loc, qw0_scale, name="w0")

    qw_loc = inf.Parameter(tf.zeros([d, 1]), name="qw_loc")
    qw_scale = tf.math.softplus(inf.Parameter(tf.ones([d, 1]), name="qw_scale"))
    qw = inf.Normal(qw_loc, qw_scale, name="w")
# TFP/Edward 2: P and Q models
def log_reg(d, N, w_init=(1, 1), x_init=(0, 1)):
    w = ed.Normal(loc=tf.ones([d], dtype="float32") * w_init[0], scale=1. * w_init[1], name="w")
    w0 = ed.Normal(loc=1. * w_init[0], scale=1. * w_init[1], name="w0")
    x = ed.Normal(loc=tf.ones([N, d], dtype="float32") * x_init[0], scale=1. * x_init[1], name="x")
    y = ed.Bernoulli(logits=tf.tensordot(x, w, axes=[[1], [0]]) + w0, name="y")
    return x, y, (w, w0)

def qmodel(d, N):
    qw_loc = tf.Variable(tf.ones([d]))
    qw_scale = tf.math.softplus(tf.Variable(tf.ones([d])))
    qw0_loc = tf.Variable(1.)
    qw0_scale = tf.math.softplus(tf.Variable(1.))

    qw = ed.Normal(loc=qw_loc, scale=qw_scale, name="qw")
    qw0 = ed.Normal(loc=qw0_loc, scale=qw0_scale, name="qw0")
    return qw, qw0
# Pyro: P model
def log_reg(x_data=None, y_data=None):
    w = pyro.sample("w", Normal(torch.zeros(d), torch.ones(d)))
    w0 = pyro.sample("w0", Normal(0., 1.))

    with pyro.plate("map", N):
        x = pyro.sample("x", Normal(torch.zeros(d), 2.).to_event(1), obs=x_data)
        logits = (w0 + x @ w).squeeze(-1)   # w is already a tensor; no cast needed
        y = pyro.sample("pred", Binomial(logits=logits), obs=y_data)
    return x, y

# Pyro: Q model, a diagonal-Gaussian guide generated automatically
qmodel = AutoDiagonalNormal(log_reg)
Sample from the prior model¶
Now we can sample from the P model, in which the global parameters are fixed. As can be observed below, this is more complex in TFP.
# InferPy
# instance of the model
m = log_reg(d)

# create toy data by fixing the global parameters
data = m.prior(["x", "y"], data={"w0": 0, "w": [[2], [1]]}).sample(N)
x_train = data["x"]
y_train = data["y"]
# TFP/Edward 2
def set_values(**model_kwargs):
    """Creates a value-setting interceptor."""
    def interceptor(f, *args, **kwargs):
        """Sets random variable values to its aligned value."""
        name = kwargs.get("name")
        if name in model_kwargs:
            kwargs["value"] = model_kwargs[name]
        else:
            print(f"set_values not interested in {name}.")
        return ed.interceptable(f)(*args, **kwargs)
    return interceptor

with ed.interception(set_values(w=[2, 1], w0=0)):
    generate = log_reg(d, N, x_init=(2, 10))

with tf.Session() as sess:
    x_train, y_train, _ = sess.run(generate)
# Pyro
sampler = pyro.condition(log_reg, data={"w0": torch.tensor(0.), "w": torch.tensor([2., 1.])})
x_train, y_train = sampler()
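Whichever framework is used, a quick sanity check of the simulated data can be run (a minimal sketch, assuming the InferPy arrays above, where x_train has shape [N, d] and y_train has shape [N, 1]):

print(x_train.shape, y_train.shape)   # expected: (N, d) and (N, 1)
print(np.mean(y_train))               # fraction of positive labels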
Inference¶
Using the generated data, variational inference can be run as follows. This is quite simple with our package, while TFP and Pyro require the user to implement the optimization loop.
# InferPy
VI = inf.inference.VI(qmodel(d), epochs=10000)
m.fit({"x": x_train, "y": y_train}, VI)
# TFP/Edward 2
qw, qw0 = qmodel(d, N)

with ed.interception(set_values(w=qw, w0=qw0, x=x_train, y=y_train)):
    post_x, post_y, (post_w, post_w0) = log_reg(d, N)

# expected log-joint under the variational distribution
energy = tf.reduce_sum(post_x.distribution.log_prob(post_x.value)) + \
         tf.reduce_sum(post_y.distribution.log_prob(y_train)) + \
         tf.reduce_sum(post_w.distribution.log_prob(qw.value)) + \
         tf.reduce_sum(post_w0.distribution.log_prob(qw0.value))

entropy = -(tf.reduce_sum(qw.distribution.log_prob(qw.value)) + \
            tf.reduce_sum(qw0.distribution.log_prob(qw0.value)))

# ELBO definition
elbo = energy + entropy

# Optimization loop
optimizer = tf.train.AdamOptimizer(learning_rate=0.05)
train = optimizer.minimize(-elbo)
init = tf.global_variables_initializer()

t = []
num_epochs = 10000
with tf.Session() as sess:
    sess.run(init)
    for i in range(num_epochs):
        sess.run(train)
        if i % 5 == 0:
            t.append(sess.run([elbo]))
        if i % 50 == 0:
            print(sess.run(elbo))

    w_post = sess.run(qw.distribution.loc)
    w0_post = sess.run(qw0.distribution.loc)
# Pyro
optim = Adam({"lr": 0.1})
svi = SVI(log_reg, qmodel, optim, loss=Trace_ELBO(), num_samples=10)

num_iterations = 10000
pyro.clear_param_store()
for j in range(num_iterations):
    # calculate the loss and take a gradient step
    loss = svi.step(x_train, y_train)
    if j % 50 == 0:   # report the loss occasionally
        print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(x_train)))
Usage of the inferred model¶
Finally, the posterior distributions of the global parameters w and w0 can be shown. From the posterior predictive distribution, samples can be generated as follows.
# InferPy
# Print the parameters
w_post = m.posterior("w").parameters()["loc"]
w0_post = m.posterior("w0").parameters()["loc"]
print(w_post, w0_post)

# Sample from the posterior predictive distribution
post_sample = m.posterior_predictive(["x", "y"], data={"w": w_post, "w0": w0_post}).sample()
x_gen = post_sample["x"]
y_gen = post_sample["y"]
print(x_gen, y_gen)
# TFP/Edward 2
# Print the parameters
print(w_post, w0_post)

# Sample from the posterior
with ed.interception(set_values(w=w_post, w0=w0_post)):
    generate = log_reg(d, N, x_init=(0, 0))

with tf.Session() as sess:
    x_gen, y_gen, _ = sess.run(generate)
print(x_gen, y_gen)
# Pyro
# Print the parameters (a single draw from the fitted guide)
posterior = qmodel()
w_post = posterior["w"]
w0_post = posterior["w0"]
print(w_post, w0_post)

# Sample from the posterior
sampler_post = pyro.condition(log_reg, data={"w0": w0_post, "w": w_post})
x_gen, y_gen = sampler_post()
print(x_gen, y_gen)
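As a final, framework-agnostic illustration, the inferred parameters can be used to score new inputs directly (a minimal sketch; predict_proba is a hypothetical helper, not part of any of the three APIs, and assumes w_post and w0_post are numpy values):

import numpy as np

def predict_proba(x_new, w, w0):
    """p(y = 1 | x) under the inferred weights (hypothetical helper)."""
    logits = x_new @ np.reshape(w, (-1, 1)) + w0
    return 1. / (1. + np.exp(-logits))

# Example: score two new observations with the inferred parameters.
x_new = np.array([[1., -1.], [0.5, 2.]])
print(predict_proba(x_new, w_post, w0_post))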