Source code for inferpy.contextmanager.layer_registry
from contextlib import contextmanager
"""This context must be used by the prob model before calling the builder function.
This is used to store the layers.sequential.Sequential objects...
"""
def _restart_properties():
    """Reset the module-level registry state to its default.

    Rebinds the global ``_properties`` to a brand-new dict holding an empty
    ``_sequentials`` list and ``enabled`` set to False. A new list is created
    on every call, so nothing registered earlier survives a reset.
    """
    global _properties
    _properties = {
        "_sequentials": [],
        "enabled": False,
    }


# Establish the default (disabled, empty) registry state at import time.
_restart_properties()
def add_sequential(sequential):
    """Record ``sequential`` in the registry, but only while it is enabled.

    When ``_properties["enabled"]`` is False the call does nothing. When it
    is True, the object is appended to ``_properties["_sequentials"]``.

    Args:
        sequential: the object to register.
    """
    if not _properties["enabled"]:
        # Registry inactive: nothing is recorded.
        return
    _properties["_sequentials"].append(sequential)
def get_losses():
    """Return the sum of the losses of every registered sequential.

    Iterates the ``losses`` attribute of each object in
    ``_properties["_sequentials"]`` and adds all of them together with
    ``sum``.

    Returns:
        The summed losses, or None if no registered object has any losses.

    Raises:
        AssertionError: if the registry is not enabled.
    """
    assert _properties["enabled"]
    collected = []
    for seq in _properties["_sequentials"]:
        collected.extend(seq.losses)
    if not collected:
        return None
    return sum(collected)
@contextmanager
def init(graph=None):
    """Enable the layer registry for the duration of a ``with`` block.

    On entry it asserts that the registry is not already enabled and that
    no sequentials are left over. It then sets ``_properties["enabled"]``,
    so that ``add_sequential`` starts appending objects. On exit, whether
    normal or through an exception, ``_restart_properties`` resets the
    registry to its default disabled, empty state.

    Args:
        graph: accepted but not used by this function (NOTE(review):
            presumably kept for interface compatibility -- confirm).

    Raises:
        AssertionError: on entry, if the registry is already enabled or
            still holds sequentials.
    """
    # Only reading (not rebinding) _properties here, so no `global` needed.
    assert not _properties["enabled"]
    assert _properties["_sequentials"] == []
    try:
        # From now on, registered sequentials are stored in the shared list.
        _properties["enabled"] = True
        yield
    finally:
        # Always restore the default state, even if the body raised.
        _restart_properties()