Commit 3eda7d75 authored by Chris Jewell's avatar Chris Jewell
Browse files

Made small adjustments for basic tensorflow 2 migration. However,...

Made small adjustments for basic tensorflow 2 migration.  However, incompatibilities between Edward2, Tensorflow probability [>=0.80], and TF2 are hindering progress.
parent c868302e
Pipeline #266 failed with stage
in 2 minutes and 17 seconds
......@@ -160,9 +160,12 @@ class DAG(ASTWalker):
def onEnter_IdRef(self, node):
if isinstance(node.symbol, VariableSymbol) and node.symbol.type == SymbolTable._random_variable:
parent = self.__dag[node.value]
child = self.__current_rv[-1]
child.add_parent(parent)
parent.add_child(child)
if len(self.__current_rv) > 0:
child = self.__current_rv[-1]
child.add_parent(parent)
parent.add_child(child)
elif isinstance(node.symbol, STMSymbol):
self._walk(node.symbol.ast_def.children[2])
raise StopWalk()
def dot(self):
......
......@@ -20,7 +20,6 @@
# Parse GEM Language to AST
import os
import inspect
from lark import Lark, Transformer, v_args
......
......@@ -74,8 +74,6 @@ class GEMProgram(Outputter):
self.line("from tensorflow_probability import edward2 as ed") + \
self.line(
"from gem.model import TFContinuousMarkovILM, make_epidemic_process")
if tf.executing_eagerly():
str += self.line("tf.enable_eager_execution()")
return str
def __gen_footer(self):
......
......@@ -64,7 +64,7 @@ class GEM:
:type const_data: dict
:return a module object implementing the model.
"""
if not tf.executing_eagerly():
if not tf.compat.v1.executing_eagerly():
self.tf_session = tf.compat.v1.Session() if tf_session is None else tf_session
self.gem_external = const_data
......@@ -131,10 +131,7 @@ class GEM:
"""
condition_vars = as_numpy_dict(condition_vars)
with ed.interception(make_value_setter(**condition_vars)):
if tf.executing_eagerly():
res_dict = self.model_impl()
else:
res_dict = self.tf_session.run(self.model_impl())
res_dict = self.model_impl()
return res_dict
def fit(self, observed, n_samples, init=None, tune=None, burnin=0,
......
......@@ -26,8 +26,8 @@ import inspect
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.python.client import session as tfc
from tensorflow_probability.python.edward2.interceptor import interceptable, \
interception
from edward2 import traceable, trace
from gem.model.data_structures import EnumeratedEventList
......@@ -133,7 +133,7 @@ class EpidemicProcess:
def make_epidemic_process(process_cls):
"""Factory function to make random variable given distribution class."""
@interceptable
@traceable
@functools.wraps(process_cls, assigned=("__module__", "__name__"))
def func(*args, **kwargs):
sample_shape = kwargs.pop("sample_shape", ())
......@@ -153,7 +153,7 @@ def make_value_setter(**model_kwargs):
kwargs["value"] = model_kwargs[name]
kwargs["sample_shape"] = model_kwargs[name].shape
try:
return interceptable(f)(*args, **kwargs)
return traceable(f)(*args, **kwargs)
except ValueError as e:
raise ValueError(f"ValueError in variable '{name}': {e}")
......@@ -229,7 +229,7 @@ def make_log_joint_fn(model):
"""
log_probs = []
def interceptor(rv_constructor, *rv_args, **rv_kwargs):
def tracer(rv_constructor, *rv_args, **rv_kwargs):
"""Overrides a random variable's `value` and accumulates its log-prob."""
# Set value to keyword argument indexed by `name` (an input tensor).
rv_name = rv_kwargs.get("name")
......@@ -256,7 +256,7 @@ def make_log_joint_fn(model):
return rv
model_kwargs = _get_function_inputs(model, kwargs)
with interception(interceptor):
with trace(tracer):
model(*args, **model_kwargs)
log_prob = sum(log_probs)
return log_prob
......
......@@ -213,5 +213,6 @@ class TestGEMLang(unittest.TestCase):
epsilon ~ Gamma(1.0, 1.0)
"""
model = GEM(program)
print(model.pyprog)
res = model.sample(condition_vars={'epsilon': [1.0]})
self.assertEqual(1.0, res['epsilon'])
self.assertEqual(1.0, res['epsilon'].numpy())
......@@ -73,9 +73,7 @@ class TestLogProb(unittest.TestCase):
dtype=np.float32)
model = GEM(prog, const_data={'x': x})
print(model.pyprog)
with tf.Session() as s:
res = s.run(
model.log_prob_fn(alpha=1.4, beta=0.5, sigma=0.01, y=y))
res = model.log_prob_fn(alpha=1.4, beta=0.5, sigma=0.01, y=y).numpy()
self.assertAlmostEqual(res, 15.916683, places=5)
def test_epidemic_simple(self):
......
......@@ -48,7 +48,35 @@ class TestSimpleDAGGenerator(unittest.TestCase):
self.assertEqual(dag.dot().source, self.expected_graph)
@unittest.skip("Work in progress")
class TestHierarchyDAGGenerator(unittest.TestCase):
def setUp(self):
self.expected_graph = "digraph {\n\tbeta\n\tgamma\n\tu\n\tepi\n\tbeta -> u\n\tgamma -> u\n\tgamma -> epi\n\tu -> epi\n}"
prog = """
beta ~ Gamma(1, 1)
gamma ~ Gamma(1, 1)
sigma ~ Gamma(1, 1)
phi ~ Gamma(1, 1)
s ~ Normal(0., sigma*phi)
Epidemic SIR() {
S = State(init=999)
I = State(init=1)
R = State(init=0)
[S -> I] = beta * exp(s) * I
[I -> R] = gamma
}
epi ~ SIR()
"""
self.model = GEM(prog)
def test_dag_generator(self):
from gem.gemlang.dag import DAG
dag = DAG(self.model.ast)
dot = dag.dot()
#self.assertEqual(dot.source, self.expected_graph)
dot.render()
class TestEpiDAGGenerator(unittest.TestCase):
def setUp(self):
self.expected_graph = "digraph {\n\tbeta\n\tgamma\n\tepi\n\tbeta -> epi\n\tgamma -> epi\n}"
......@@ -65,13 +93,10 @@ class TestEpiDAGGenerator(unittest.TestCase):
}
epi ~ SIR()
"""
ast = gemparse(prog)
symtab = SymbolDeclarer().visit(ast)
SymbolResolver().visit(ast, symtab)
self.ast = ast
self.symtab = symtab
self.model = GEM(prog)
def test_dag_generator(self):
dag = generate_dag(self.symtab)
dot = dag2dot(dag)
from gem.gemlang.dag import DAG
dag = DAG(self.model.ast)
dot = dag.dot()
self.assertEqual(dot.source, self.expected_graph)
......@@ -72,8 +72,8 @@ class TestEnumeratedEventList(unittest.TestCase):
ss = eventlist2states(self.eventList, model.random_variables['epi'].process.initial,
model.random_variables['epi'].process.time_origin,
model.random_variables['epi'].process.stoichiometry)
with tf.Session() as sess:
ss_np = sess.run(ss)
ss_np = [x.numpy() for x in ss]
npt.assert_array_equal(ss_np[1],
np.array([[[5], [4], [3], [2], [2], [2]],
[[1], [2], [3], [4], [3], [2]],
......
......@@ -33,8 +33,7 @@ class TestEdward2(unittest.TestCase):
rv = ed.Gamma(concentration=1.0, rate=1.0, name='rv')
return rv
with tf.Session() as sess:
res = sess.run(model())
res = model().numpy()
self.assertIs(type(res), np.float32)
def test_rv_log_prob(self):
......@@ -43,6 +42,6 @@ class TestEdward2(unittest.TestCase):
return rv
log_prob_fn = ed.make_log_joint_fn(model)
with tf.Session() as sess:
lp = sess.run(log_prob_fn(rv=1.0))
lp = log_prob_fn(rv=1.0).numpy()
self.assertEqual(lp, -1.0)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment