"""Functions for chain binomial simulation."""
import tensorflow as tf
import tensorflow_probability as tfp

from covid.impl.util import compute_state, make_transition_matrix

tfd = tfp.distributions

# (source state, destination state) index pairs of the permitted transitions.
# These are the only off-diagonal entries filled in every rate/event matrix
# built in this module (the states are S, E, I, R per the comment in
# `chain_binomial_propagate`).
_TRANSITIONS = [[0, 1], [1, 2], [2, 3]]


def approx_expm(rates):
    """Approximates a full Markov transition matrix from a rate matrix.

    For each row with total outgoing rate ``lambda``, the probability of
    leaving that state is ``1 - exp(-lambda)``.  This probability is split
    between destination states in proportion to their rates, and the diagonal
    receives the probability of remaining in the state.

    :param rates: un-normalised rate matrix (i.e. diagonal zero), the last two
                  dimensions indexing (from, to) states
    :returns: approximation to the Markov transition matrix; given a zero
              diagonal in `rates`, each row sums to 1
    """
    total_rates = tf.reduce_sum(rates, axis=-1, keepdims=True)
    prob = 1.0 - tf.math.exp(-total_rates)
    # Rows with zero total rate give 0/0 = NaN in `rates / total_rates`;
    # multiply_no_nan maps those to 0 because `prob` is exactly 0 there.
    mt1 = tf.math.multiply_no_nan(rates / total_rates, prob)
    return tf.linalg.set_diag(mt1, 1.0 - tf.reduce_sum(mt1, axis=-1))


def chain_binomial_propagate(h, time_step, seed=None):
    """Propagates the state of a population according to discrete time dynamics.

    :param h: a hazard function ``h(t, state)`` returning the rates of the
              transitions in ``_TRANSITIONS`` (0->1, 1->2, 2->3).  These are
              assembled into a non-row-normalised rate matrix by
              ``make_transition_matrix``.  NOTE(review): the exact layout
              expected for the returned rates is defined by
              ``make_transition_matrix`` -- confirm there.
    :param time_step: the time step
    :param seed: optional seed passed to every Binomial draw
    :returns: a function ``propagate_fn(t, state)`` that propagates
              ``state[t]`` -> ``state[t + time_step]``.  It returns a tuple
              ``(counts, new_state)``: ``counts[..., i, j]`` is the number of
              individuals moving from state ``i`` to state ``j``, and
              ``new_state`` is the resulting occupancy of each state.
    """

    def propagate_fn(t, state):
        rates = h(t, state)
        rate_matrix = make_transition_matrix(rates, _TRANSITIONS, state.shape)
        # Turn rates over one time step into row-stochastic transition
        # probabilities; the diagonal is 1 minus the sum of the other entries
        # in each row (probability of remaining in the state).
        markov_transition = approx_expm(rate_matrix * time_step)
        num_states = markov_transition.shape[-1]
        prev_probs = tf.zeros_like(markov_transition[..., :, 0])
        counts = tf.zeros(
            markov_transition.shape[:-1].as_list() + [0], dtype=markov_transition.dtype
        )
        total_count = state
        # Each row's multinomial is sampled as a chain of conditional
        # binomials: destination i receives
        # Binomial(remaining, p_i / (1 - sum_{k<i} p_k)).
        # This for loop is ok because there are (currently) only 4 states (SEIR)
        # and we're only actually creating work for 3 of them. Even for as many
        # as ~10 states it should probably be fine, just increasing the size
        # of the graph a bit.
        # NOTE(review): the same `seed` is given to every draw; depending on
        # the TFP version this may correlate the draws -- confirm.
        for i in range(num_states - 1):
            probs = markov_transition[..., :, i]
            # divide_no_nan: once earlier destinations have absorbed all of a
            # row's probability mass (1 - prev_probs == 0), the conditional
            # probability is 0 instead of 0/0 = NaN.
            cond_probs = tf.clip_by_value(
                tf.math.divide_no_nan(probs, 1.0 - prev_probs), 0.0, 1.0
            )
            binom = tfd.Binomial(total_count=total_count, probs=cond_probs)
            sample = binom.sample(seed=seed)
            counts = tf.concat([counts, sample[..., tf.newaxis]], axis=-1)
            total_count -= sample
            prev_probs += probs

        # Whatever remains in each row goes to the last destination state.
        counts = tf.concat([counts, total_count[..., tf.newaxis]], axis=-1)
        new_state = tf.reduce_sum(counts, axis=-2)
        return counts, new_state

    return propagate_fn


def discrete_markov_simulation(hazard_fn, state, start, end, time_step, seed=None):
    """Simulates from a discrete time Markov state transition model.

    At each step every row of the approximate transition matrix is sampled
    multinomially (implemented as a chain of binomials by
    ``chain_binomial_propagate``).

    :param hazard_fn: hazard function ``hazard_fn(t, state)``; see
                      ``chain_binomial_propagate``
    :param state: the initial state
    :param start: simulation start time
    :param end: simulation end time (exclusive)
    :param time_step: the time step
    :param seed: optional seed for the Binomial draws
    :returns: a tuple ``(times, events)``.  ``times`` are the step start times,
              and ``events`` stacks the per-step transition count matrices
              along a new leading time axis.
    """
    propagate = chain_binomial_propagate(hazard_fn, time_step, seed=seed)
    # Convert first so `state.dtype` is also available for array-like inputs.
    state = tf.convert_to_tensor(state)
    times = tf.range(start, end, time_step, dtype=state.dtype)
    num_steps = times.shape[0]

    output = tf.TensorArray(state.dtype, size=num_steps)

    def cond(i, *_):
        return i < num_steps

    def body(i, state, output):
        update, state = propagate(times[i], state)
        output = output.write(i, update)
        return i + 1, state, output

    _, _, output = tf.while_loop(cond, body, loop_vars=(0, state, output))
    return times, output.stack()


def discrete_markov_log_prob(events, init_state, hazard_fn, time_step, stoichiometry):
    """Calculates an unnormalised log_prob function for a discrete time epidemic model.

    :param events: a `[M, T, X]` batch of transition events for metapopulation M,
                   times `T`, and transitions `X`.
    :param init_state: a matrix of shape `[M, S]` giving the initial state of the
                       epidemic for `M` metapopulations and `S` states
    :param hazard_fn: a function ``hazard_fn(t, state)`` taking a time index
                      ``t`` and the `[M, S]` state slice for that time, and
                      returning the rates of the transitions 0->1, 1->2 and 2->3.
    :param time_step: the size of the time step.
    :param stoichiometry: a `[X, S]` matrix describing the state update for each
                          transition.
    :return: a scalar log probability for the epidemic, cast to ``events.dtype``.
    """
    num_meta = events.shape[-3]
    num_times = events.shape[-2]
    num_states = stoichiometry.shape[-1]
    state_timeseries = compute_state(init_state, events, stoichiometry)  # MxTxS

    # [M, T, S] -> [T, M, S] so that hazard_fn can be mapped over time.
    tms_timeseries = tf.transpose(state_timeseries, perm=(1, 0, 2))

    def hazard_at(elems):
        return hazard_fn(*elems)

    rates = tf.vectorized_map(
        fn=hazard_at, elems=[tf.range(num_times), tms_timeseries]
    )
    rate_matrix = make_transition_matrix(rates, _TRANSITIONS, tms_timeseries.shape)
    probs = approx_expm(rate_matrix * time_step)

    # [T, M, S, S] to [M, T, S, S]
    probs = tf.transpose(probs, perm=(1, 0, 2, 3))
    event_matrix = make_transition_matrix(
        events, _TRANSITIONS, [num_meta, num_times, num_states]
    )
    # Diagonal = individuals in each state that did not make a transition.
    event_matrix = tf.linalg.set_diag(
        event_matrix, state_timeseries - tf.reduce_sum(event_matrix, axis=-1)
    )
    # NOTE(review): the likelihood is evaluated in float32 whatever the input
    # dtype, and only the final sum is cast back to `events.dtype`; confirm the
    # precision loss is acceptable for float64 inputs.
    logp = tfd.Multinomial(
        tf.cast(state_timeseries, dtype=tf.float32),
        probs=tf.cast(probs, dtype=tf.float32),
        name="log_prob",
    ).log_prob(tf.cast(event_matrix, dtype=tf.float32))

    return tf.cast(tf.reduce_sum(logp), dtype=events.dtype)


def events_to_full_transitions(events, initial_state):
    """Creates full transition matrices from transition events and the initial state.

    Each output matrix is that time point's event matrix with its diagonal
    replaced by the number of individuals remaining in each state.  This is
    the occupancy at the start of the step minus the row sums of the events.

    :param events: a tensor of shape [t, c, s, s] for t timepoints, c metapopulations
                   and s states.
    :param initial_state: the initial state matrix of shape [c, s]
    :returns: a tensor of shape [t, c, s, s]
    """

    def fill_diagonal(prev_transitions, step_events):
        # Column sums of the previous full matrix give the current occupancy.
        occupancy = tf.reduce_sum(prev_transitions, axis=-2)
        survived = occupancy - tf.reduce_sum(step_events, axis=-1)
        return tf.linalg.set_diag(step_events, survived)

    return tf.scan(
        fn=fill_diagonal, elems=events, initializer=tf.linalg.diag(initial_state)
    )