# Source code for inferpy.util.tf_graph

from collections import defaultdict
import networkx as nx
import tensorflow as tf


"""
The private helper functions below obtain the dependencies between Random Variables
by analyzing the tensors of the Random Variables and their relations. Finally, the
public get_graph function builds a networkx graph that represents these dependencies.
"""


def _get_varname(op):
    op_name = op.name
    idx = op_name.find('/')  # Use the first part of the operation name (until slash) as name
    if idx != -1 and '/Assign' not in op_name:  # Special case for tf.Variables
        return op_name[:idx]
    else:
        return op_name


def _children(op):
    """Return the set of variable names of the operations consuming ``op``'s outputs.

    Each consumer operation is mapped through ``_get_varname``, so several
    consumers belonging to the same variable collapse into a single name.
    """
    consumers = (consumer for tensor in op.outputs for consumer in tensor.consumers())
    return {_get_varname(consumer) for consumer in consumers}


def _clean_graph(G, varnames):
    """Remove from ``G`` every node whose name is not in ``varnames``.

    ``G`` is a networkx directed graph and is modified in place. Each removed
    node is normally bypassed first: every predecessor gets an edge to every
    successor, so dependencies that flowed through the node are kept.

    Special case for tf.Variables used by inf.Parameters: a node whose name
    contains ``'/Assign'`` and whose prefix up to the last slash (the real
    tf.Variable name) is in ``varnames`` is not bypassed. If it has exactly two
    predecessors, one edge is added between them. The edge points to the first
    predecessor when that one is the variable itself, and to the second one
    otherwise. The node's successors are not reconnected.

    Args:
        G: networkx directed graph to clean.
        varnames: set or dict (its keys are used) of names to keep.

    Returns:
        The same graph object ``G``.

    Raises:
        AssertionError: if a special-case assign node has more than two
            predecessors (only when assertions are enabled).
    """
    # Snapshot the node names up front, because the loop removes nodes from G.
    for node in list(G.nodes):
        if node in varnames:
            continue

        owner = node[:node.rfind('/')]  # real name of the tf.Variable
        if '/Assign' in node and owner in varnames:
            parents = list(G.predecessors(node))
            assert len(parents) <= 2  # At most, it should have two predecessors
            if len(parents) == 2:
                first, second = parents
                if first == owner:
                    G.add_edge(second, first)
                else:
                    G.add_edge(first, second)
        else:
            # Bypass the node: link each of its parents to each of its children.
            for parent in G.predecessors(node):
                for child in G.successors(node):
                    G.add_edge(parent, child)

        G.remove_node(node)
    return G


def get_graph(varnames):
    """Build a networkx graph of the dependencies between Random Variables.

    The graph is computed from the operations of the current TensorFlow
    default graph. Every operation contributes a node named by
    ``_get_varname``, with edges to the variable names that consume its
    outputs (excluding itself). Then every node whose name is not in
    ``varnames`` is removed by ``_clean_graph``.

    Args:
        varnames (dict or set): the var names of the Random Variables. For a
            dict, the keys are the names.

    Returns:
        networkx.DiGraph: dependency graph whose nodes are all in ``varnames``.

    Raises:
        TypeError: if ``varnames`` is neither a dict nor a set.
    """
    if not isinstance(varnames, (dict, set)):
        raise TypeError("The type of varnames must be dict or set, not {}".format(type(varnames)))

    # Build a dictionary {node: {child1, child2, ..}, ..} for the current
    # TensorFlow default graph. The result is compatible with networkx/toposort.
    ops = tf.get_default_graph().get_operations()
    dependencies = defaultdict(set)
    for op in ops:
        # In tensorflow_probability, the tensor named *sample_shape* is an op in
        # child-parent order. As we want to capture only the parent-child
        # relations, skip these op names.
        if 'sample_shape' in op.name:
            continue
        children = _children(op)
        op_name = _get_varname(op)
        children.discard(op_name)  # avoid name references to itself
        dependencies[op_name].update(children)

    # Create the networkx graph.
    G = nx.DiGraph(dependencies)
    # Clean names to keep just the ed2 RV var names (in-place modification).
    _clean_graph(G, varnames)
    return G
def get_empty_graph():
    """Return a new, empty networkx directed graph (``nx.DiGraph``)."""
    return nx.DiGraph()