Source code for inferpy.util.runtime

# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


"""
Module focused on evaluating tensors to make their usage easier, so that users can forget about tensors and sessions
"""

from functools import wraps
from contextlib import contextmanager

from inferpy import util


# Fallback used by tf_run_allowed wrappers when the caller passes no tf_run argument
__tf_run_default = True

# Shared state for runner_scope: how many runner scopes are currently nested
runner_context = {
    'runner_recursive_depth': 0,
}


@contextmanager
def runner_scope():
    """
    Context manager that tracks how deeply decorated functions are nested.

    Decorated functions may call other decorated functions. The counter in
    ``runner_context`` is raised on entry and lowered on exit (even when the
    body raises), so callers can tell whether they are the outermost call.
    """
    depth_key = 'runner_recursive_depth'
    runner_context[depth_key] = runner_context[depth_key] + 1
    try:
        yield
    finally:
        # always undo the increment, whatever happened inside the scope
        runner_context[depth_key] = runner_context[depth_key] - 1
def tf_run_allowed(f):
    """
    Decorator for functions whose result may need to be evaluated.

    The wrapped function accepts an extra ``tf_run`` keyword argument, which is
    removed before calling ``f``. When it is omitted, the module default set
    through ``set_tf_run`` is used. If the flag is truthy and the call is the
    outermost decorated call (recursive depth 1), the result is handed to
    ``try_run``; otherwise the raw result is returned unchanged.
    """
    @wraps(f)
    def wrapper(*args, **kwargs):
        # take the control flag out of kwargs so it never reaches f
        evaluate = kwargs.pop("tf_run", __tf_run_default)

        # the scope records the nesting level of decorated calls
        with runner_scope():
            result = f(*args, **kwargs)

            # evaluation disabled, or we are nested inside another decorated
            # call: leave the result untouched for the outer level
            if not evaluate or runner_context['runner_recursive_depth'] != 1:
                return result

            # outermost call with evaluation enabled
            return try_run(result)

    return wrapper
def tf_run_ignored(f):
    """
    Decorator for functions whose result must never be evaluated.

    The call runs inside a ``runner_scope``, so any ``tf_run_allowed`` function
    invoked from ``f`` sees a recursive depth greater than 1 and skips its own
    evaluation. The result of ``f`` itself is returned as is.
    """
    @wraps(f)
    def wrapper(*args, **kwargs):
        # entering the scope raises the depth counter for everything f calls
        with runner_scope():
            result = f(*args, **kwargs)
        return result

    return wrapper
def set_tf_run(enable):
    """
    Change the module-wide default for the ``tf_run`` flag.

    :param enable: value stored in ``__tf_run_default``; it is the flag used by
        ``tf_run_allowed`` wrappers when no explicit ``tf_run`` is passed.
    """
    global __tf_run_default
    __tf_run_default = enable
def try_run(obj):
    """
    Best-effort evaluation of ``obj`` through ``util.get_session().run``.

    :param obj: the object to evaluate.
    :returns: the value returned by ``run``, or ``obj`` itself unchanged when
        ``run`` raises RuntimeError, TypeError or ValueError.
    """
    try:
        evaluated = util.get_session().run(obj)
    except (RuntimeError, TypeError, ValueError):
        # not something the session can evaluate: hand it back untouched
        return obj
    else:
        return evaluated