Unverified Commit 3cf012e6 authored by Chris Jewell's avatar Chris Jewell Committed by GitHub
Browse files

Merge pull request #2 from csuter/stochastic

Rewrite Multinomial as hand rolled iterated binomial.
parents 79f51e8d a42675e7
......@@ -19,18 +19,33 @@ def chain_binomial_propagate(h, time_step):
-tf.reduce_sum(rate_matrix, axis=-1))
# Calculate Markov transition probability matrix
markov_transition = tf.linalg.expm(rate_matrix*time_step)
# Sample new state
new_state = tfd.Multinomial(total_count=state,
new_state = tf.reduce_sum(new_state, axis=-1)
num_states = markov_transition.shape[-1]
prev_prob = tf.zeros_like(markov_transition[..., :, 0])
counts = tf.zeros(markov_transition.shape[:-1].as_list() + [0],
total_count = state
# This for loop is ok because there are (currently) only 4 states (SEIR)
# and we're only actually creating work for 3 of them. Even for as many
# as a ~10 states it should probably be fine, just increasing the size
# of the graph a bit.
for i in range(num_states - 1):
binom = tfd.Binomial(
probs=markov_transition[..., :, i] / (1. - prev_prob))
sample = binom.sample()
counts = tf.concat([counts, sample[..., tf.newaxis]], axis=-1)
total_count -= sample
prev_prob = binom.probs
counts = tf.concat([counts, total_count[..., tf.newaxis]], axis=-1)
new_state = tf.reduce_sum(counts, axis=-2)
return new_state
return propagate_fn
@tf.function(autograph=False) # Algorithm runs super slow if uncommented. Weird!
def chain_binomial_simulate(hazard_fn, state, start, end, time_step):
propagate = chain_binomial_propagate(hazard_fn, time_step)
times = tf.range(start, end, time_step)
output = tf.TensorArray(state.dtype, size=times.shape[0])
output = output.write(0, state)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment