# Source code for inferpy.util.session


import tensorflow as tf
import warnings


"""
Module to manage the single global TensorFlow session in which graphs are executed.
"""

# Module-global session shared by every function in this module. It starts
# as None and is created lazily by get_session() on first use.
__session = None


def new_session(gpu_memory_fraction=0.0):
    """Create a fresh session and install it as the global session.

    Args:
        gpu_memory_fraction: Fraction of GPU memory the new session may use.
            A value ``<= 0`` creates a plain ``tf.Session()`` with no config.
            Any other value (intended to lie in ``(0, 1]``) creates a session
            whose ``gpu_options.per_process_gpu_memory_fraction`` is set to
            it, with device-placement logging enabled.

    The previously installed global session, if any, is handed to
    ``set_session``, which closes it with a warning.

    NOTE(review): the original comment said the default does not use the
    GPU, but a config-less ``tf.Session()`` does not by itself hide GPU
    devices. Confirm the intended default behavior.
    """
    if gpu_memory_fraction <= 0.0:
        fresh = tf.Session()
    else:
        session_config = tf.ConfigProto(log_device_placement=True)
        # memory_fraction [0-1]
        session_config.gpu_options.per_process_gpu_memory_fraction = gpu_memory_fraction
        fresh = tf.Session(config=session_config)
    set_session(fresh)
def get_session():
    """Return the global session, creating a default one on first use.

    If no (truthy) session is currently stored, a new ``tf.Session()`` is
    created, stored as the global session, and returned.
    """
    global __session
    if __session:
        return __session
    __session = tf.Session()
    return __session
def set_session(session):
    """Install ``session`` as the global session.

    If a *different* session is currently installed, a warning is emitted
    and that session is closed before being replaced. Installing the session
    that is already active is a no-op: without this guard, the active
    session would be closed and then stored as the (now unusable) global
    session.

    Args:
        session: The session object to make global.
    """
    global __session
    if __session is session:
        # Re-installing the current session must not close it.
        return
    if __session:
        # stacklevel=2 attributes the warning to the caller of set_session.
        warnings.warn("Running session closed to use the provided session instead",
                      stacklevel=2)
        __session.close()
    __session = session
def swap_session(new_session):
    """Install ``new_session`` as the global session and return the old one.

    Unlike ``set_session``, the previously installed session is not closed,
    so the caller can restore it later with another ``swap_session`` call.

    Returns:
        Whatever session was installed before this call.
    """
    global __session
    previous, __session = __session, new_session
    return previous
def clear_session():
    """Close the current global session, if any, and install a fresh ``tf.Session()``.

    A new session is always created, even when none was installed before.
    """
    global __session
    stale = __session
    if stale:
        stale.close()
    __session = tf.Session()
def init_uninit_vars():
    """Initialize only the variables the global session reports as uninitialized.

    Variables that are already initialized keep their current values, so this
    can be called again after new variables have been added to the graph.

    ``tf.report_uninitialized_variables()`` inspects both global and local
    variables by default. Both collections are therefore searched when
    building the initializer, so every variable it reports actually gets
    initialized.
    """
    session = get_session()
    # The report yields names as bytes, without the ":0" output suffix that
    # appears in ``Variable.name``; normalize before comparing.
    uninit_names = set(session.run(tf.report_uninitialized_variables()))
    if not uninit_names:
        return
    candidates = tf.global_variables() + tf.local_variables()
    session.run(tf.variables_initializer(
        [v for v in candidates
         if v.name.split(':')[0].encode('UTF-8') in uninit_names]
    ))