Commit bd3eaae2 authored by Chris Jewell's avatar Chris Jewell
Browse files

Array logic implemented. Does not yet work with events2states

parent 10895b39
......@@ -20,7 +20,7 @@
"""AST Expression classes"""
__all__ = ['Expr', 'UnaryExpr', 'BinaryExpr', 'Call', 'KwArg',
'IdRef', 'ArrayLiteral', 'Number', 'Integer', 'Float', 'String', 'TransitionDef', 'ArgList']
'IdRef', 'Array', 'Number', 'Integer', 'Float', 'String', 'TransitionDef', 'ArgList']
import operator
......@@ -169,20 +169,23 @@ class IdRef(Expr):
# Data literals
class ArrayLiteral(Expr):
class Array(Expr):
"""Represents an n-dimensional array literal.
:param elements: a list of array elements
"""
def __init__(self, elements, *args, **kwargs):
super().__init__(*args, **kwargs)
self.value = elements
for element in elements:
self.add_child(element)
def __str__(self):
return str(self.value)
str_children = [str(child) for child in self.children]
return f"[{', '.join(str_children)}]"
def __repr__(self):
return f"{self.__class__.__name__} '{str(self.value)}'"
repr_children = [repr(child) for child in self.children]
return f"{self.__class__.__name__} '{repr_children}'"
class Number(Expr):
......
......@@ -92,9 +92,9 @@ class CodeGenerator(ASTWalker):
name = epidecl.name
config = epidecl.children[0]
param = epidecl.children[1]
stm = out.STM(name, config, param, self.current_workspace['states'],
self.current_workspace['transitions'],
self.current_workspace['initial'])
stm = out.STMDef(name, config, param, self.current_workspace['states'],
self.current_workspace['transitions'],
self.current_workspace['initial'])
self.pop_ws()
return stm
......@@ -199,6 +199,9 @@ class CodeGenerator(ASTWalker):
self.current_workspace['transitions'].append(tx)
return tx
def onExit_Array(self, array):
return out.Array(array.children)
def onExit_Number(self, n):
return out.Number(n.value)
......
......@@ -145,7 +145,7 @@ class ParseTree2AST(Transformer):
return IdRef(str(args[0]), meta=meta)
def array(self, args, meta):
return ArrayLiteral(args, meta=meta)
return Array(args, meta=meta)
def number(self, args, meta):
if args[0].type == 'INT':
......
......@@ -97,7 +97,7 @@ class GEMProgram(Outputter):
return s
class STM(Outputter):
class STMDef(Outputter):
def __init__(self, name, config, param, states, transitions, initial_states,
**kwargs):
super().__init__(**kwargs)
......@@ -129,10 +129,10 @@ class STM(Outputter):
s = "["
for init in self.inits:
init = np.atleast_1d(init)
s += "["
#s += "["
for i in init:
s += str(i) + ","
s += "], "
#s += "], "
s += "]"
return s
......@@ -350,7 +350,11 @@ class Assign(Outputter):
self.rhs = rhs
def __str__(self):
return f"{self.lhs} = {self.rhs}"
if isinstance(self.rhs, (Number, Array)):
rhs_expr = f"tf.convert_to_tensor({self.rhs})"
else:
rhs_expr = f"{self.rhs}"
return f"{self.lhs} = {rhs_expr}"
class STMInstance(Function):
......@@ -369,13 +373,27 @@ class STMInstance(Function):
return s
class Array(Outputter):
def __init__(self, elems):
self.value = elems
def __str__(self):
out_elems = []
for elem in self.value:
if isinstance(elem, Number):
out_elems.append(str(elem.value))
else:
out_elems.append(str(elem))
return f"[{', '.join(out_elems)}]"
class Number(Outputter):
def __init__(self, value):
self.value = value
def __str__(self):
return f"tf.constant({self.value})"
#return f"tf.constant({self.value})"
return f"{self.value}"
# Builtin gemlang funcs
class BuiltinGamma(Function):
......
......@@ -88,7 +88,7 @@ class EnumeratedEventList:
return tf.tuple(self.t, self.uid, self.trans_id)
def __repr__(self):
return f"{self.__class__}(t={self.t}, id={self.uid}, event={self.trans_id})"
return f"{self.__class__}(t={self.t}, uid={self.uid}, event={self.trans_id})"
def enumerated_event_list_to_tensor(value, dtype=None, name=None,
......
......@@ -59,7 +59,7 @@ def cont_markov_ilm_propagate(h, stoichiometry):
if state.shape[
1] > 1: # This is a workaround for a suspected bug in TF when state.shape[1] = 1
uid = tf.squeeze(event % state.shape[1], name='packId')
event_type = tf.squeeze(event // state.shape[1],
event_enum = tf.squeeze(event // state.shape[1],
name='packEvent_type')
else:
uid = 0
......
......@@ -50,7 +50,7 @@ def eventlist2states(events, state0, time_origin, stoichiometry):
:param state0: initial state from which to propagate
:param stoichiometry: the stoichiometry of the state transition model
:param time_origin: the initial time
:return a tuple of tensors containing time and state after each event, including the initial state. Shape of state
:returns: a tuple of tensors containing time and state after each event, including the initial state. Shape of state
is (S, T, N) where S is the number of states respecting the enumerated event list, T is the number of times, and N
is the number of individuals.
"""
......@@ -62,6 +62,13 @@ def eventlist2states(events, state0, time_origin, stoichiometry):
stoichiometry = tf.convert_to_tensor(stoichiometry)
transition = tf.gather(stoichiometry, indices=events.trans_id[0])
n_states = state0.shape[0]
n_times = events.t.shape[0]
n_indiv = state0.shape[1]
increments = tf.scatter_nd(indices, transition, shape=[n_states, n_times, n_indiv])
incr = tf.math.cumsum(transition, axis=0)
incr = tf.expand_dims(incr, axis=2)
state = state0 + incr
......
......@@ -40,6 +40,24 @@ class TestLogProb(unittest.TestCase):
res = model.log_prob(epsilon=1.0)
self.assertEqual(res, -1.0)
def test_bivariate_free(self):
prog = """
epsilon ~ Gamma([1.0, 2.0], 1.0)
"""
model = GEM(prog)
print(model.pyprog)
res = model.sample()
self.assertTrue(len(res['epsilon']) == 2)
def test_nvariate_free(self):
prog = """
epsilon ~ Gamma([[1.0, 2.0], [3.0, 4.0]], 1.0)
"""
model = GEM(prog)
print(model.pyprog)
res = model.sample()
self.assertTrue(res['epsilon'].shape == (2, 2))
def test_vector(self):
"""Tests uninitialized distribution with a vector -- i.e. tests sum of log probs."""
prog = """
......@@ -282,3 +300,30 @@ class TestInference(unittest.TestCase):
plt.figure()
plt.plot(post['beta_1'], post['beta_2'])
plt.show()
class TestMetaPopulation(unittest.TestCase):
def test_epidemic_meta(self):
"""Test an epidemic metapopulation model."""
prog = """
beta = 0.002
gamma = 0.14
Epidemic MyEpidemic() {
S = State(init=999)
I = State(init=1)
R = State(init=0)
[S -> I] = beta * I
[I -> R] = gamma
}
epi ~ MyEpidemic()
"""
events = EnumeratedEventList(
t=np.array([0.048, 0.69, 1.10], dtype=np.float32),
uid=np.array([0] * 3, dtype=np.int32),
trans_id=np.array([0, 0, 1], dtype=np.int32))
model = GEM(prog)
res = model.log_prob(epi=events)
self.assertAlmostEqual(-4.2611294, res, places=5)
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment