Source code for inferpy.queries.query

import numpy as np
import functools

from inferpy import contextmanager
from inferpy import util


def flatten_result(f):
    """Decorate ``f`` so that a single-entry result is returned unwrapped.

    The wrapper accepts one extra keyword argument, ``simplify_result``
    (default ``True``), which is removed from ``kwargs`` before ``f`` is
    called. When the flag is truthy and the result of ``f`` contains exactly
    one entry, the value stored under that only key is returned. In every
    other case the result of ``f`` is returned unchanged.

    Args:
        f: callable whose result supports ``len()``, ``keys()`` and item
            access by key (e.g. a dict).

    Returns:
        The wrapped callable.
    """
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # consume the flag here so that it never reaches the wrapped function
        unwrap_single = kwargs.pop('simplify_result', True)
        outcome = f(*args, **kwargs)

        # nothing to simplify: flag disabled, or zero / several entries
        if not unwrap_single or len(outcome) != 1:
            return outcome

        only_key = next(iter(outcome.keys()))
        return outcome[only_key]

    return wrapper
class Query:
    """Evaluate log-probabilities, samples and parameters of random variables.

    Args:
        variables: dict mapping variable names to random variables. All of
            them are passed to ``contextmanager.observe`` together with
            ``data`` whenever a query is evaluated.
        target_names: a single name or a list of names selecting the
            variables the query results refer to. If ``None`` (or empty),
            every variable in ``variables`` is a target.
        data: dict passed to ``contextmanager.observe`` as the observed data.
            Defaults to a new empty dict for each instance.
        enable_interceptor_variables: tuple unpacked as the arguments of
            ``util.interceptor.enable_interceptor``; it allows intercepting
            global and local hidden variables independently.

    Raises:
        ValueError: if ``target_names`` contains a name not in ``variables``.
    """

    def __init__(self, variables, target_names=None, data=None, enable_interceptor_variables=(None, None)):
        # if provided a single name, create a list with only one item
        if isinstance(target_names, str):
            target_names = [target_names]

        # raise an error if target_names is not None and contains variable names not in variables
        if target_names and any(name not in variables for name in target_names):
            raise ValueError("Target names must correspond to variable names")

        self.target_variables = variables if not target_names else \
            {k: v for k, v in variables.items() if k in target_names}
        self.observed_variables = variables
        # a fresh dict per instance: a `{}` default in the signature would be
        # shared (and mutated) across every Query created without `data`
        self.data = {} if data is None else data
        self.enable_interceptor_variables = enable_interceptor_variables

    @flatten_result
    @util.tf_run_ignored
    def log_prob(self):
        """ Computes the log probabilities of a (set of) sample(s).

        Returns:
            A dict with one entry per target variable. Because of the
            `flatten_result` decorator, a single-entry dict is returned as
            its only value unless `simplify_result=False` is passed.
        """
        with util.interceptor.enable_interceptor(*self.enable_interceptor_variables):
            with contextmanager.observe(self.observed_variables, self.data):
                result = util.runtime.try_run({k: v.log_prob(v.value) for k, v in self.target_variables.items()})

        return result

    def sum_log_prob(self):
        """ Computes the sum of the log probabilities (evaluated) of a (set of) sample(s).

        The log probabilities of each target variable are averaged with
        `np.mean` and the per-variable means are added up with `np.sum`.
        """
        # The decorator is not needed here because this function returns a single value;
        # simplify_result=False guarantees a dict even with a single target variable
        return np.sum([np.mean(lp) for lp in self.log_prob(simplify_result=False).values()])

    @flatten_result
    @util.tf_run_ignored
    def sample(self, size=1):
        """ Generates a sample for each variable in the model.

        Args:
            size: number of times the target variables are evaluated.

        Returns:
            If `size` is 1, the result of a single evaluation. Otherwise a
            dict with, for each target variable, a numpy array stacking the
            `size` evaluations. A single-entry dict is returned as its only
            value unless `simplify_result=False` is passed.
        """
        with util.interceptor.enable_interceptor(*self.enable_interceptor_variables):
            with contextmanager.observe(self.observed_variables, self.data):
                # each iteration for `size` run the dict in the session, so if there are dependencies among random vars
                # they are computed in the same graph operations, and reflected in the results
                samples = [util.runtime.try_run(self.target_variables) for _ in range(size)]

        if size == 1:
            result = samples[0]
        else:
            # compact all samples in one single dict
            result = {k: np.array([sample[k] for sample in samples]) for k in self.target_variables.keys()}

        return result

    @flatten_result
    @util.tf_run_ignored
    def parameters(self, names=None):
        """ Return the parameters of the Random Variables of the model.

        If `names` is None, then return all the parameters of all the Random Variables.
        If `names` is a list, then return the parameters specified in the list (if exists) for all the Random Variables.
        If `names` is a dict, then return all the parameters specified (value) for each Random Variable (key).

        Note:
            If `tf_run=True`, but any of the returned parameters is not a Tensor (and therefore cannot be evaluated),
            this returns a not evaluated dict (because the evaluation will raise an Exception).

        Args:
            names: A list, a dict or None. Specify the parameters for the Random Variables to be obtained.

        Returns:
            A dict, where the keys are the names of the Random Variables and the values a dict of parameters
            (name-value). A single-entry dict is returned as its only value unless `simplify_result=False`
            is passed.

        Raises:
            TypeError: if `names` is not None, a list or a dict.
        """
        # argument type checking
        if not (names is None or isinstance(names, (list, dict))):
            raise TypeError("The argument 'names' must be None, a list or a dict, not {}.".format(type(names)))
        # now we can assume that names is None, a list or a dict

        # function to filter the parameters for each Random Variable
        def filter_parameters(varname, parameters):
            if names is None:
                # use all the parameters
                selected_parameters = set(parameters)
            elif isinstance(names, list):
                # the same selection applies to every Random Variable
                selected_parameters = set(names)
            else:
                # names is a dict; if the variable is not a key, use all its parameters
                selected_parameters = set(names.get(varname, parameters))

            # only the selected parameters are evaluated
            return {k: util.runtime.try_run(v) for k, v in parameters.items() if k in selected_parameters}

        with contextmanager.observe(self.observed_variables, self.data):
            result = {k: filter_parameters(k, v.parameters) for k, v in self.target_variables.items()}

        return result