discrete_markov.py 4.1 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
"""Functions for chain binomial simulation."""
2
import numpy as np
Chris Jewell's avatar
Chris Jewell committed
3
4
5
6
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

Chris Jewell's avatar
Chris Jewell committed
7
from covid.impl.util import make_transition_rate_matrix
Chris Jewell's avatar
Chris Jewell committed
8

def chain_binomial_propagate(h, time_step, seed=None):
    """Propagates the state of a population according to discrete time dynamics.

    Within one time step, the individuals in each source state are split
    across destination states by a chain of conditional binomial draws, which
    is equivalent to one multinomial draw per source state using the one-step
    transition probabilities ``expm(rate_matrix * time_step)``.

    :param h: a hazard rate function ``h(t, state)`` returning the
              non-row-normalised Markov transition rate matrix.  The last two
              axes index (source state, destination state), i.e. the tensor has
              shape ``[..., ns, ns]`` (e.g. ``[nc, ns, ns]`` for ``ns`` states
              and ``nc`` strata).  The diagonal is replaced by minus the row
              sum, and that sum includes the incoming diagonal, so the diagonal
              is presumably expected to be zero -- TODO confirm with callers.
    :param time_step: the time step
    :param seed: optional Python integer seed (or None) for the binomial
                 sampling; a distinct seed is derived from it for every draw.
    :returns: a function ``propagate_fn(t, state)`` that propagates
              ``state[t]`` -> ``state[t+time_step]``.  It returns
              ``(counts, new_state)`` where ``counts[..., i, j]`` is the number
              of individuals moving from state ``i`` to state ``j`` and
              ``new_state`` is ``counts`` summed over the source-state axis.
    """

    def propagate_fn(t, state):
        rate_matrix = h(t, state)
        # Set diagonal to be the negative of the sum of the elements in each
        # row, so that each row of the generator matrix sums to zero.
        rate_matrix = tf.linalg.set_diag(
            rate_matrix, -tf.reduce_sum(rate_matrix, axis=-1))

        # One-step Markov transition probability matrix:
        # markov_transition[..., i, j] = Pr(move from state i to state j).
        markov_transition = tf.linalg.expm(rate_matrix * time_step)
        num_states = markov_transition.shape[-1]

        # Re-using a single op-level seed for several sampling ops makes them
        # produce identical random streams in graph mode, so derive a
        # distinct seed per draw.  SeedStream returns None if seed is None.
        seed_stream = tfp.util.SeedStream(seed, salt="chain_binomial_propagate")

        prev_probs = tf.zeros_like(markov_transition[..., :, 0])
        total_count = state
        samples = []
        # A Python loop is fine for a small number of states (e.g. SEIR): it
        # unrolls into num_states - 1 sampling ops in the graph.
        for i in range(num_states - 1):
            probs = markov_transition[..., :, i]
            # Probability of moving to state i conditional on not having moved
            # to any of states 0..i-1.  When a row's mass is already exhausted
            # (1 - prev_probs == 0, e.g. a state with no outgoing hazard) a
            # plain division gives 0/0 = NaN; divide_no_nan yields 0 instead.
            cond_probs = tf.clip_by_value(
                tf.math.divide_no_nan(probs, 1. - prev_probs), 0., 1.)
            sample = tfd.Binomial(
                total_count=total_count,
                probs=cond_probs).sample(seed=seed_stream())
            samples.append(sample)
            total_count -= sample
            prev_probs += probs

        # Whoever has not moved to states 0..ns-2 goes to the final state.
        samples.append(total_count)
        counts = tf.stack(samples, axis=-1)
        new_state = tf.reduce_sum(counts, axis=-2)
        return counts, new_state

    return propagate_fn


def discrete_markov_simulation(hazard_fn, state, start, end, time_step, seed=None):
    """Simulates from a discrete time Markov state transition model.

    Transitions within each time step are sampled by
    ``chain_binomial_propagate``, i.e. by a chain of binomial draws across the
    rows (source states) of the one-step Markov transition matrix.

    :param hazard_fn: a function ``hazard_fn(t, state)`` returning the
                      non-row-normalised transition rate matrix.  ``t`` is the
                      integer index of the time step (0, 1, 2, ...), not the
                      time value from ``times``.
    :param state: the initial state of the population.
    :param start: the start time.
    :param end: the end time (exclusive, as in ``tf.range``).
    :param time_step: the time step.
    :param seed: optional integer seed for the random sampling.
    :returns: a tuple ``(times, events)`` where ``times`` is
              ``tf.range(start, end, time_step)`` and ``events`` stacks the
              per-step transition counts returned by the propagate function
              along a new leading time axis.
    """
    propagate = chain_binomial_propagate(hazard_fn, time_step, seed=seed)
    times = tf.range(start, end, time_step)
    state = tf.convert_to_tensor(state)
    # Dynamic length: times.shape[0] is None when start/end are tensors whose
    # values are not known at trace time.
    num_steps = tf.size(times)

    output = tf.TensorArray(state.dtype, size=num_steps)

    def cond(i, *_):
        return i < num_steps

    def body(i, state, output):
        update, state = propagate(i, state)
        output = output.write(i, update)
        return i + 1, state, output

    _, _, output = tf.while_loop(cond, body, loop_vars=(0, state, output))
    return times, output.stack()


def discrete_markov_log_prob(events, init_state, hazard_fn, time_step):
    """Calculates an unnormalised log_prob function for a discrete time epidemic model.

    :param events: a [n_t, n_c, n_s, n_s] batch of transition events for all
                   times t, metapopulations c, and states s;
                   ``events[t, c, i, j]`` is the number of transitions from
                   state i to state j during step t.
    :param init_state: a tensor of shape [n_c, n_s], the initial state of the
                       epidemic for s states and c metapopulations.  Must have
                       the same dtype as ``events``.
    :param hazard_fn: a function ``hazard_fn(t, state)`` returning the
                      non-row-normalised matrix of transition rates, where
                      ``t`` is the integer index of the time step.
    :param time_step: the time step
    :returns: a scalar tensor: the sum over all time steps of the multinomial
              log-probabilities of ``events`` given the state at the start of
              each step.
    """
    # State at the start of each step: the initial state followed by the state
    # after each step (events summed over source states), dropping the state
    # after the final step.
    states = tf.concat([[init_state], tf.reduce_sum(events, axis=-2)], axis=-3)[:-1]
    # Integer step indices matching the axis that foldl iterates over (axis 0).
    t = tf.range(tf.shape(states)[0])

    def log_prob_t(a, elems):
        t, event, state = elems
        rate_matrix = hazard_fn(t, state)
        # Diagonal = minus the row sum, so each generator row sums to zero.
        rate_matrix = tf.linalg.set_diag(rate_matrix,
                                         -tf.reduce_sum(rate_matrix, axis=-1))
        markov_transition = tf.linalg.expm(rate_matrix * time_step)
        # One multinomial per (metapopulation, source state): the individuals
        # in `state` are distributed over destination states.
        logp = tfd.Multinomial(state, probs=markov_transition).log_prob(event)
        return a + tf.reduce_sum(logp)

    return tf.foldl(log_prob_t, (t, events, states),
                    initializer=tf.constant(0., events.dtype))