# discrete_markov.py
 # Chain binomial simulation of discrete time Markov state transition models.
"""Functions for chain binomial simulation."""
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from covid.impl.util import make_transition_rate_matrix

tfd = tfp.distributions


def chain_binomial_propagate(h, time_step, seed=None):
    """Propagates the state of a population according to discrete time dynamics.

    :param h: a hazard rate function ``h(t, state)`` returning the
        non-row-normalised Markov transition rate matrix. The last two
        dimensions of the returned tensor must be [ns, ns] (source state,
        destination state), e.g. [nc, ns, ns] where ns is the number of states
        and nc is the number of strata within the population. Any existing
        diagonal is overwritten.
    :param time_step: the time step
    :param seed: optional seed passed to the binomial draws
    :returns: a function ``propagate_fn(t, state)`` that propagates
        ``state[t]`` -> ``state[t + time_step]``. It returns a tuple
        ``(counts, new_state)`` where ``counts[..., j, k]`` is the number of
        individuals moving from state ``j`` to state ``k`` during the step and
        ``new_state`` is ``counts`` summed over source states.
    """

    def propagate_fn(t, state):
        rate_matrix = h(t, state)
        # Set the diagonal to the negative of the row sum so that every row
        # of the generator matrix sums to zero.
        rate_matrix = tf.linalg.set_diag(
            rate_matrix, -tf.reduce_sum(rate_matrix, axis=-1)
        )
        # Markov transition probability matrix for one time step.
        markov_transition = tf.linalg.expm(rate_matrix * time_step)
        num_states = markov_transition.shape[-1]

        # Each row's multinomial draw is decomposed into a chain of
        # conditional binomials: destination i receives
        # Binomial(remaining, p_i / (1 - sum_{k<i} p_k)) of the individuals
        # not yet allocated, and the last destination takes the remainder.
        # This Python loop only adds num_states - 1 ops to the graph, which is
        # fine for small compartmental models.
        prev_probs = tf.zeros_like(markov_transition[..., :, 0])
        total_count = state
        samples = []
        for i in range(num_states - 1):
            probs = markov_transition[..., :, i]
            # divide_no_nan: once a row's probability mass is exhausted
            # (prev_probs == 1, probs == 0) a plain division yields 0/0 = NaN,
            # which clip_by_value does not remove.
            cond_probs = tf.math.divide_no_nan(probs, 1.0 - prev_probs)
            binom = tfd.Binomial(
                total_count=total_count,
                probs=tf.clip_by_value(cond_probs, 0.0, 1.0),
            )
            # NOTE(review): the same `seed` is passed to every draw and every
            # call; whether draws stay independent depends on how the installed
            # TFP version treats integer seeds -- confirm.
            sample = binom.sample(seed=seed)
            samples.append(sample)
            # Rebind rather than `-=` so a caller's numpy `state` is never
            # mutated in place.
            total_count = total_count - sample
            prev_probs = prev_probs + probs
        # Individuals not allocated above go to the final destination state.
        samples.append(total_count)
        counts = tf.stack(samples, axis=-1)
        new_state = tf.reduce_sum(counts, axis=-2)
        return counts, new_state

    return propagate_fn


def discrete_markov_simulation(hazard_fn, state, start, end, time_step, seed=None):
    """Simulates from a discrete time Markov state transition model using
    multinomial sampling across rows of the transition probability matrix.

    The per-row multinomial draws are carried out as chains of binomials by
    ``chain_binomial_propagate``.

    :param hazard_fn: a function ``hazard_fn(t, state)`` returning the
        non-row-normalised transition rate matrix. ``t`` is the integer step
        index (0, 1, 2, ...), not the corresponding value in ``times``.
    :param state: the initial state of the population
    :param start: the simulation start time
    :param end: the simulation end time (exclusive, as for ``tf.range``)
    :param time_step: the time step
    :param seed: optional seed for the binomial draws
    :returns: a tuple ``(times, events)`` where ``times`` is
        ``tf.range(start, end, time_step)`` and ``events`` stacks the
        per-step transition count matrices along a new leading axis.
    """
    propagate = chain_binomial_propagate(hazard_fn, time_step, seed=seed)
    times = tf.range(start, end, time_step)
    state = tf.convert_to_tensor(state)
    # Dynamic size: also works when start/end are tensors whose values are
    # unknown at graph-construction time (times.shape[0] would be None).
    num_steps = tf.size(times)

    output = tf.TensorArray(state.dtype, size=num_steps)

    def cond(i, *_):
        return i < num_steps

    def body(i, state, output):
        update, state = propagate(i, state)
        output = output.write(i, update)
        return i + 1, state, output

    _, state, output = tf.while_loop(cond, body, loop_vars=(0, state, output))
    return times, output.stack()
def discrete_markov_log_prob(events, init_state, hazard_fn, time_step):
    """Calculates the log probability of a batch of transition events under a
    discrete time epidemic model.

    :param events: a [n_t, n_c, n_s, n_s] batch of transition events for all
        times t, metapopulations c, and states s
    :param init_state: a tensor of shape [n_c, n_s], the initial state of the
        epidemic for s states and c metapopulations
    :param hazard_fn: a function ``hazard_fn(t, state)`` taking the integer
        time index and the current [n_c, n_s] state, and returning a matrix of
        transition rates (its diagonal is overwritten)
    :param time_step: the time step
    :returns: a scalar, the sum over all times and rows of the multinomial
        log probabilities of ``events``
    """
    events = tf.convert_to_tensor(events)
    # Cast so that init_state can be concatenated with the summed events.
    init_state = tf.cast(init_state, events.dtype)
    # The state at t+1 is the sum of events[t] over source states. Prepend
    # the initial state and drop the final one, so states[t] is the state
    # from which events[t] occur.
    states = tf.concat(
        [init_state[tf.newaxis], tf.reduce_sum(events, axis=-2)], axis=-3
    )[:-1]
    # Dynamic shape: works even if the number of time points is not known
    # statically.
    t = tf.range(tf.shape(states)[-3])

    def log_prob_t(a, elems):
        t, event, state = elems
        rate_matrix = hazard_fn(t, state)
        # Set the diagonal to the negative row sum so each generator row
        # sums to zero.
        rate_matrix = tf.linalg.set_diag(
            rate_matrix, -tf.reduce_sum(rate_matrix, axis=-1)
        )
        markov_transition = tf.linalg.expm(rate_matrix * time_step)
        # One multinomial per source state: state[..., j] individuals are
        # distributed over destinations with probabilities
        # markov_transition[..., j, :].
        logp = tfd.Multinomial(state, probs=markov_transition).log_prob(event)
        return a + tf.reduce_sum(logp)

    return tf.foldl(
        log_prob_t,
        (t, events, states),
        initializer=tf.constant(0.0, events.dtype),
    )