Commit a42675e7 by Christopher Suter

### Rewrite Multinomial as hand rolled iterated binomial.

parent 864fcf65
 ... ... @@ -19,18 +19,33 @@ def chain_binomial_propagate(h, time_step): -tf.reduce_sum(rate_matrix, axis=-1)) # Calculate Markov transition probability matrix markov_transition = tf.linalg.expm(rate_matrix*time_step) # Sample new state new_state = tfd.Multinomial(total_count=state, probs=markov_transition).sample() new_state = tf.reduce_sum(new_state, axis=-1) num_states = markov_transition.shape[-1] prev_prob = tf.zeros_like(markov_transition[..., :, 0]) counts = tf.zeros(markov_transition.shape[:-1].as_list() + [0], dtype=markov_transition.dtype) total_count = state # This for loop is ok because there are (currently) only 4 states (SEIR) # and we're only actually creating work for 3 of them. Even for as many # as a ~10 states it should probably be fine, just increasing the size # of the graph a bit. for i in range(num_states - 1): binom = tfd.Binomial( total_count=total_count, probs=markov_transition[..., :, i] / (1. - prev_prob)) sample = binom.sample() counts = tf.concat([counts, sample[..., tf.newaxis]], axis=-1) total_count -= sample prev_prob = binom.probs counts = tf.concat([counts, total_count[..., tf.newaxis]], axis=-1) new_state = tf.reduce_sum(counts, axis=-2) return new_state return propagate_fn @tf.function(autograph=False) # Algorithm runs super slow if uncommented. Weird! def chain_binomial_simulate(hazard_fn, state, start, end, time_step): propagate = chain_binomial_propagate(hazard_fn, time_step) times = tf.range(start, end, time_step) print(times.shape[0]) output = tf.TensorArray(state.dtype, size=times.shape[0]) output = output.write(0, state) ... ...
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!