Unverified Commit 553c79da authored by Chris Jewell's avatar Chris Jewell Committed by GitHub
Browse files

Merge pull request #12 from csuter/stochastic-da-dev

Pair programming changes from 2020-06-12
parents 843ad749 15b44690
......@@ -11,7 +11,7 @@ TransitionTopology = collections.namedtuple('TransitionTopology',
def _abscumdiff(events, initial_state,
target_t, target_id,
bound_t, bound_id):
bound_t, bound_id, int_dtype=tf.int32):
"""Returns the number of free events to move in target_events
bounded by max([N_{target_id}(t)-N_{bound_id}(t)]_{bound_t}).
:param events: a [(M), T, X] tensor of transition events
......@@ -41,12 +41,13 @@ def _abscumdiff(events, initial_state,
cumdiff = tf.abs(tf.cumsum(diff, axis=-1))
# TODO: check the validity of bound_id+1
bound_init_state = tf.gather(initial_state,
target_id + tf.cast(bound_id > target_id, dtype=tf.int32),
bound_init_state = tf.gather(
target_id + tf.cast(bound_id > target_id, dtype=int_dtype),
indices = tf.stack([
tf.repeat(tf.range(bound_t.shape[0], dtype=tf.int64), [bound_t.shape[1]]),
tf.repeat(tf.range(bound_t.shape[0], dtype=tf.int32), [bound_t.shape[1]]),
tf.reshape(bound_t, [-1])
], axis=-1)
indices = tf.reshape(indices, bound_t.shape + [2])
......@@ -57,58 +58,47 @@ def _abscumdiff(events, initial_state,
# compiler can guarantee that the output shapes of true_fn() and
# false_fn() are equal.
def false_fn():
return tf.zeros([events.shape[0]] + [bound_t.shape[0]])
return int_dtype.max * tf.ones([events.shape[0]] + [bound_t.shape[1]],
ret_val = tf.cond(bound_id != -1, true_fn, false_fn)
return ret_val
def TimeDelta(dmax, p_pos=0.5, name=None):
def u():
return tfd.Bernoulli(probs=p_pos)
outcomes = tf.concat([-tf.range(1, dmax + 1), tf.range(1, dmax + 1)], axis=0)
logits = tf.ones_like(outcomes, dtype=tf.float64)
return tfd.FiniteDiscrete(outcomes=outcomes, logits=logits, name=name)
def magnitude():
return tfd.Categorical(logits=tf.ones(dmax)) # Issue if dmax unknown at compile time
def delta_t(u, magnitude):
return tfd.Deterministic((2 * u - 1) * (magnitude+1), name='delta_t')
return tfd.JointDistributionNamed(dict(u=u,
delta_t=delta_t), name=name)
def EventTimeProposal(events, initial_state, topology, d_max, n_max, name=None):
def EventTimeProposal(events, initial_state, topology, d_max, n_max, dtype=tf.int32, name=None):
"""Draws an event time move proposal.
:param events: a [M, T, K] tensor of event times (M number of metapopulations,
T number of timepoints, K number of transitions)
:param initial_state: a [M, S] tensor of initial metapopulation x state counts
:param topology: a 3-element tuple of (previous_transition, target_transition,
:param d_max: the maximum distance over which to move
next_transition), eg "(s->e, e->i, i->r)"
(assuming we are interested presently in e->i, `None` for boundaries)
:param d_max: the maximum distance over which to move (in time)
:param n_max: the maximum number of events to move
target_events = tf.gather(events, topology.target, axis=-1)
time_interval = tf.range(d_max, dtype=tf.int64)
def one_hot_rows():
# OneHotCategorical on each batch
x = tf.cast(target_events > 0, dtype=tf.float32)
logits = tf.math.log(x)
return tfd.OneHotCategorical(logits=logits, name='one_hot_rows')
time_interval = tf.range(d_max, dtype=dtype)
def t(one_hot_rows):
col_idx = tf.argmax(one_hot_rows, axis=-1)
return tfd.Deterministic(col_idx[..., None], name='event_coords')
def t():
x = tf.cast(target_events > 0, dtype=tf.float64) # [M, T]
return tfd.Categorical(logits=tf.math.log(x), name='event_coords')
def time_delta():
return TimeDelta(d_max, name='TimeDelta')
def x_star(t, time_delta):
delta_t = time_delta['delta_t']
delta_t = time_delta
# Compute bounds
# The limitations of XLA mean that we must calculate bounds for
# intervals [t, t+delta_t) if delta_t > 0, and [t+delta_t, t) if
# delta_t is < 0.
t = t[..., tf.newaxis]
bound_interval = tf.where(delta_t < 0,
t - time_interval - 1, # [t+delta_t, t)
t + time_interval) # [t, t+delta_t)
......@@ -117,29 +107,32 @@ def EventTimeProposal(events, initial_state, topology, d_max, n_max, name=None):
topology.prev or -1,
topology.next or -1)
free_events = _abscumdiff(events=events, initial_state=initial_state,
free_events = _abscumdiff(events=events,
# Mask out bits of the interval we don't need for our delta_t
inf_mask = tf.cumsum(tf.one_hot(tf.math.abs(delta_t),
d_max, dtype=tf.int32)) * tf.int32.max
free_events = tf.reduce_min(tf.cast(inf_mask,
free_events.dtype) + free_events, axis=-1)
free_events = tf.reduce_min(inf_mask + free_events, axis=-1)
indices = tf.stack([
tf.range(events.shape[0], dtype=tf.int64),
tf.range(events.shape[0], dtype=dtype),
], axis=-1)
available_events = tf.gather_nd(events[..., topology.target], indices)
available_events = tf.gather_nd(target_events, indices)
max_events = tf.minimum(free_events, available_events)
max_events = tf.clip_by_value(max_events, clip_value_min=0,
# Draw x_star
# TODO: needs correct sample/log_prob (mass)
return tfd.Uniform(low=0, high=max_events, name='x_star')
return tfd.JointDistributionNamed(dict(one_hot_rows=one_hot_rows,
return tfd.JointDistributionNamed(dict(t=t,
x_star=x_star), name=name)
from pprint import pprint
import unittest
import pickle as pkl
import numpy as np
......@@ -6,11 +7,10 @@ import tensorflow as tf
from covid.impl.event_time_proposal import _abscumdiff, EventTimeProposal, TransitionTopology
class TestAbsCumDiff(unittest.TestCase):
def setUp(self):
with open('../stochastic_sim_small.pkl','rb') as f:
with open('./stochastic_sim_small.pkl','rb') as f:
sim = pkl.load(f)
self.events = np.stack([sim['events'][..., 0, 1],
sim['events'][..., 1, 2],
......@@ -61,10 +61,9 @@ class TestAbsCumDiff(unittest.TestCase):
bound_t=t, bound_id=-1).numpy()
np.testing.assert_array_equal(n_max, np.zeros([self.events.shape[0], 1]))
class TestEventTimeProposal(unittest.TestCase):
def setUp(self):
with open('../stochastic_sim_small.pkl', 'rb') as f:
with open('./stochastic_sim_small.pkl', 'rb') as f:
sim = pkl.load(f)
self.events = np.stack([sim['events'][..., 0, 1], # S->E
sim['events'][..., 1, 2], # E->I
......@@ -81,7 +80,7 @@ class TestEventTimeProposal(unittest.TestCase):
def test_event_time_proposal_sample(self):
q = self.Q.sample()
if __name__ == '__main__':
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment