Commit 40ee1ced authored by Chris Jewell's avatar Chris Jewell
Browse files

Switched storage order in "States" representation of an epidemic. We now have...

Switched storage order in "States" representation of an epidemic.  We now have TxPxS for Time, Population (size), States.
parent bd3eaae2
Pipeline #276 passed with stage
in 4 minutes and 32 seconds
......@@ -110,7 +110,7 @@ class CodeGenerator(ASTWalker):
return arglist.children
def onExit_KwArg(self, kwarg):
return str(kwarg.children[1])
return kwarg.children[1]
def onExit_TruedivExpr(self, op):
return out.Truediv(op.children[0], op.children[1])
......
......@@ -128,11 +128,10 @@ class STMDef(Outputter):
"""Writes the initial values matrix"""
s = "["
for init in self.inits:
init = np.atleast_1d(init)
#s += "["
for i in init:
if not isinstance(init, Array):
init = Array([init])
for i in [init]:
s += str(i) + ","
#s += "], "
s += "]"
return s
......@@ -144,8 +143,8 @@ class STMDef(Outputter):
f"with tf.compat.v1.name_scope('{self.name}_h') as scope:")
self.indent()
f_repr += self.line("state = tf.cast(state, tf.float32)")
for i, state in enumerate(self.states):
f_repr += self.line(f"{state} = state[{i}]")
f_repr += self.line("state = tf.unstack(state, axis=-1)")
f_repr += self.line(f"{','.join(self.states)} = state")
f_repr += self.line("return tf.stack([")
self.indent()
for tx in self.transitions:
......@@ -174,7 +173,7 @@ class STMDef(Outputter):
s += self.line(
f"self.stoichiometry = tf.constant({self.stoichiometry()}, dtype=tf.int64)")
s += self.line(
f"self.initial = tf.convert_to_tensor({self.initial()}, dtype=tf.int64)")
f"self.initial = tf.transpose(tf.convert_to_tensor({self.initial()}, dtype=tf.int64))")
self.dedent()
s += self.line("")
s += self.hazard_rates()
......
......@@ -38,7 +38,7 @@ def cont_markov_ilm_logp(h, states, events):
states = tf.cast(states[1], dtype=tf.float32)
delta_t = t[1:] - t[:-1]
hazards = h(states[:, :-1, :]) # Hx(T-1)xN
hazards = h(states[:-1, :, :]) # Hx(T-1)xN
to_gather = tf.stack(
(event_trans_id, tf.range(n_events - 1), event_uid), -1)
......
......@@ -36,6 +36,9 @@ def stop_cond(end):
def cont_markov_ilm_propagate(h, stoichiometry):
update_fn = update_state(stoichiometry)
def propagate(old_state, accum_time, accum_event, i):
"""Propagates a state vector.
......@@ -54,18 +57,16 @@ def cont_markov_ilm_propagate(h, stoichiometry):
tf.random.gamma([1], alpha=1., beta=total_rate, name='drawTTNE'))
# Draw event type
event = tf.squeeze(
tf.random.categorical([tf.math.log(rates)], 1, name='drawEvent'))
if state.shape[
1] > 1: # This is a workaround for a suspected bug in TF when state.shape[1] = 1
uid = tf.squeeze(event % state.shape[1], name='packId')
event_enum = tf.squeeze(event // state.shape[1],
event = tf.random.categorical([tf.math.log(rates)], 1, name='drawEvent')
# if state.shape[
# 1] > 1: # This is a workaround for a suspected bug in TF when state.shape[1] = 1
uid = tf.squeeze(event % state.shape[0], name='packId')
event_enum = tf.squeeze(event // state.shape[0],
name='packEvent_type')
else:
uid = 0
event_enum = event
update_fn = update_state(stoichiometry)
state = update_fn(state, (uid, event_enum))
# else:
# uid = 0
# event_enum = event
state = update_fn(state, ([[uid]], event_enum))
accum_time = accum_time.write(i, t + ttne)
accum_event = accum_event.write(i, [uid, event_enum])
new_state = t + ttne, state, total_rate
......
......@@ -26,18 +26,11 @@ import tensorflow as tf
def update_state(stoichiometry):
def fn(state, event):
event_uid, event_trans_id = event
event_uid = tf.ensure_shape(event_uid, shape=(None, 1))
with tf.compat.v1.name_scope('update_state') as scope:
if state.shape[1] > 1:
indices = tf.stack((tf.range(state.shape[0], dtype=tf.int64),
tf.broadcast_to(event_uid,
[state.shape[0]])), axis=1,
name='indices')
state = tf.tensor_scatter_add(state, indices,
stoichiometry[event_trans_id, :],
name='updateStates')
else:
state = state + tf.reshape(stoichiometry[event_trans_id, :],
state.shape)
state = tf.tensor_scatter_nd_add(state, event_uid,
[stoichiometry[event_trans_id]],
name='updateStates')
return state
return fn
......@@ -47,12 +40,13 @@ def eventlist2states(events, state0, time_origin, stoichiometry):
"""Converts a list of events into a list of states immediately after each event.
:param event_list: a tuple of (event_time, event_uid, event_trans_id) sorted by time
:param state0: initial state from which to propagate
:param state0: an N x S initial state from which to propagate where N is the number of
individuals and S is the number of states
:param stoichiometry: the stoichiometry of the state transition model
:param time_origin: the initial time
:returns: a tuple of tensors containing time and state after each event, including the initial state. Shape of state
is (S, T, N) where S is the number of states respecting the enumerated event list, T is the number of times, and N
is the number of individuals.
:returns: a tuple of tensors containing time and state after each event, including the initial state.
Shape of state is (T, N, S) where T is the number of times, N is the number of individuals, and
S is the number of states respecting the enumerated event list.
"""
with tf.compat.v1.name_scope('eventlist2states'):
state0 = tf.convert_to_tensor(state0)
......@@ -63,16 +57,16 @@ def eventlist2states(events, state0, time_origin, stoichiometry):
transition = tf.gather(stoichiometry, indices=events.trans_id[0])
n_states = state0.shape[0]
n_times = events.t.shape[0]
n_indiv = state0.shape[1]
n_indiv = state0.shape[0]
n_states = state0.shape[1]
n_times = events.t[0].shape[0]
increments = tf.scatter_nd(indices, transition, shape=[n_states, n_times, n_indiv])
indices = tf.stack((tf.range(n_times), events.uid[0]), axis=1) # Strided access here
increments = tf.scatter_nd(indices, transition, shape=[n_times, n_indiv, n_states])
incr = tf.math.cumsum(transition, axis=0)
incr = tf.expand_dims(incr, axis=2)
incr = tf.math.cumsum(increments, axis=0)
state = state0 + incr
state = tf.concat([[state0], state], axis=0)
state = tf.transpose(state, [1, 0, 2])
state = tf.concat(([state0], state), axis=0)
#state = tf.transpose(state, [1, 0, 2])
t = tf.concat((time_origin, events.t[0]), -1)
return t, state
......@@ -202,6 +202,7 @@ class TestInference(unittest.TestCase):
uid=[[0, 0, 0, 0, 0, 0, 0, 0, 0]],
trans_id=[[0, 1, 0, 1, 0, 0, 1, 1, 1]])
model = GEM(prog)
print(model.pyprog)
self.assertAlmostEqual(
model.log_prob(beta=0.00014, gamma=0.14, epi=epidata), -26.802853,
places=5)
......
......@@ -19,6 +19,7 @@
#
import unittest
import numpy as np
import numpy.testing as npt
......@@ -29,13 +30,13 @@ class TestEnumeratedEventList(unittest.TestCase):
def setUp(self):
self.t = [0.75577045, 0.08644191, 0.60005492, 0.94603795, 0.35167295]
self.pid = range(5)
self.uid = range(5)
self.transition = [0, 0, 0, 1, 1]
self.eventList = EnumeratedEventList(self.t, self.pid, self.transition, sort=False)
self.eventList = EnumeratedEventList(self.t, self.uid, self.transition, sort=False)
def test_constructor(self):
npt.assert_array_equal(self.eventList.t[0], np.array(self.t, dtype=np.float32))
npt.assert_array_equal(self.eventList.uid[0], np.array(self.pid, dtype=np.int32))
npt.assert_array_equal(self.eventList.uid[0], np.array(self.uid, dtype=np.int32))
npt.assert_array_equal(self.eventList.trans_id[0], np.array(self.transition, dtype=np.int32))
def test_sort(self):
......@@ -55,30 +56,40 @@ class TestEnumeratedEventList(unittest.TestCase):
npt.assert_array_equal(self.eventList.trans_id[0], np.array([0, 1, 0, 0, 1], dtype=np.int32))
def test_eventlist2states(self):
model_prog = """
Epidemic SIR() {
S = State(init=5)
I = State(init=1)
R = State(init=0)
[S -> I] = 1.
[I -> R] = 1.
}
epi ~ SIR()
"""
import tensorflow as tf
from gem import GEM
from gem.model.impl.tf_utils import eventlist2states
model = GEM(model_prog)
ss = eventlist2states(self.eventList, model.random_variables['epi'].process.initial,
model.random_variables['epi'].process.time_origin,
model.random_variables['epi'].process.stoichiometry)
state0 = np.array([[999, 1, 0],
[999, 1, 0]])
stoichiometry = np.array([[-1, 1, 0],
[0, -1, 1]])
event_list = EnumeratedEventList(t=[0.1, 0.2, 0.3, 0.4, 0.5],
uid=[0, 0, 0, 0, 0],
trans_id=[0, 0, 0, 1, 1])
ss = eventlist2states(event_list, state0, 0., stoichiometry)
with tf.Session() as sess:
ss_np = sess.run(ss)
print(ss_np)
npt.assert_array_equal(ss_np[1],
np.array([[[5], [4], [3], [2], [2], [2]],
[[1], [2], [3], [4], [3], [2]],
[[0], [0], [0], [0], [1], [2]]]))
np.array([[[999, 1, 0],
[999, 1, 0]],
[[998, 2, 0],
[999, 1, 0]],
[[997, 3, 0],
[999, 1, 0]],
[[996, 4, 0],
[999, 1, 0]],
[[996, 3, 1],
[999, 1, 0]],
[[996, 2, 2],
[999, 1, 0]]]))
class TestEventList(unittest.TestCase):
......
......@@ -27,7 +27,7 @@ import tensorflow as tf
class TestContMarkovILMLogProb(unittest.TestCase):
def setUp(self):
self.initial = np.array([[999], [1], [0]], dtype=np.int32)
self.initial = np.array([[999, 1, 0]], dtype=np.int32)
self.stoichiometry = np.array([[-1, 1, 0],
[0, -1, 1],
[0, 0, 0]], dtype=np.int32)
......@@ -35,17 +35,16 @@ class TestContMarkovILMLogProb(unittest.TestCase):
def h(self, state):
state = tf.cast(state, tf.float32)
S = state[0]
I = state[1]
R = state[2]
state = tf.unstack(state, axis=-1)
S, I, R = state
return tf.stack([
tf.multiply(state[0], tf.multiply(self.params['beta'], I)),
tf.multiply(state[1], self.params['gamma']),
], name='MyEpidemic_h')
def numpy_ll_t(self, delta_t, event_id, state):
h = np.array([state[0] * self.params['beta'] * state[1],
state[1] * self.params['gamma']])
h = np.array([state[:, 0] * self.params['beta'] * state[:, 1],
state[:, 1] * self.params['gamma']])
log_survivor = -h * delta_t
ll_t_numpy = log_survivor.sum() + np.log(h[event_id])[0]
return ll_t_numpy
......@@ -63,7 +62,7 @@ class TestContMarkovILMLogProb(unittest.TestCase):
from gem.model.impl.tf_utils import eventlist2states
from gem.model.impl.tf_cont_ilm_logp import cont_markov_ilm_logp
events = EnumeratedEventList(t=[0.95, 1.5], uid=[0, 0], trans_id=[0, 1])
states = eventlist2states(events, tf.convert_to_tensor(self.initial), [0.0], self.stoichiometry)
states = eventlist2states(events, self.initial, 0.0, self.stoichiometry)
ll_t = cont_markov_ilm_logp(self.h, states, events)
if tf.executing_eagerly():
res = ll_t.numpy()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment