discrete_markov.py 5.95 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
2
3
"""Functions for chain binomial simulation."""
import tensorflow as tf
import tensorflow_probability as tfp
4

Chris Jewell's avatar
Chris Jewell committed
5
6
from covid.impl.util import compute_state, make_transition_matrix

Chris Jewell's avatar
Chris Jewell committed
7
8
# Short alias for the TensorFlow Probability distributions module.
tfd = tfp.distributions

Chris Jewell's avatar
Chris Jewell committed
9

10
11
12
13
14
15
def approx_expm(rates):
    """Approximates a full Markov transition matrix from a rate matrix.

    For each row (source state) the total rate ``r`` becomes a probability of
    leaving that state within the step, ``1 - exp(-r)``. That probability is
    shared among the destinations in proportion to their individual rates,
    and the diagonal receives the remaining probability of staying put.

    :param rates: un-normalised rate matrix (i.e. diagonal zero), already
                  scaled by the time step. The last axis indexes destination
                  states, so a row sum is the total rate out of a state.
    :returns: approximation to the Markov transition matrix, with the same
              shape as `rates`. Given a zero diagonal in `rates`, each row
              sums to one.
    """
    total_rates = tf.reduce_sum(rates, axis=-1, keepdims=True)
    # -expm1(-x) == 1 - exp(-x), but stays accurate when x is tiny (small time
    # steps or low rates), where 1 - exp(-x) rounds to zero in floating point.
    prob = -tf.math.expm1(-total_rates)
    # Rows with a zero total rate give 0/0 = NaN in the ratio. `prob` is 0 for
    # those rows, and multiply_no_nan returns 0 whenever its second argument
    # is 0, so such rows become all-zero here.
    off_diag = tf.math.multiply_no_nan(rates / total_rates, prob)
    return tf.linalg.set_diag(off_diag, 1.0 - tf.reduce_sum(off_diag, axis=-1))
19
20


Chris Jewell's avatar
Chris Jewell committed
21
def chain_binomial_propagate(h, time_step, seed=None):
    """Propagates the state of a population according to discrete time dynamics.

    Individuals in each source state are allocated to destination states by a
    chain of binomial draws. The draw for destination ``i`` uses the
    probability of moving to ``i`` conditional on not having been allocated to
    any of destinations ``0..i-1``. Whoever is left after the last draw goes
    to the final state.

    :param h: a hazard rate function called as ``h(t, state)``. Its return
              value is passed to ``make_transition_matrix`` together with the
              transitions ``0->1``, ``1->2`` and ``2->3``.
              NOTE(review): earlier docs described this return value as an
              ``[ns, ns, nc]`` rate matrix; confirm the format expected by
              ``make_transition_matrix``.
    :param time_step: the time step
    :param seed: optional seed forwarded to every binomial draw.
    :returns: a function ``propagate_fn(t, state)`` that propagates
              ``state[t]`` -> ``state[t+time_step]``. It returns a tuple
              ``(counts, new_state)``:
              - ``counts[..., j, k]`` is the number of individuals moving from
                state ``j`` to state ``k`` during the step.
              - ``new_state`` is ``counts`` summed over source states.
    """

    def propagate_fn(t, state):
        rates = h(t, state)
        rate_matrix = make_transition_matrix(
            rates, [[0, 1], [1, 2], [2, 3]], state.shape
        )
        # Convert rates into per-step transition probabilities. The diagonal
        # holds the probability of remaining in the current state.
        markov_transition = approx_expm(rate_matrix * time_step)
        num_states = markov_transition.shape[-1]

        # Probability mass (per source row) already assigned to earlier
        # destination columns.
        prev_probs = tf.zeros_like(markov_transition[..., :, 0])
        total_count = state
        draws = []
        # This for loop is ok because there are (currently) only 4 states (SEIR)
        # and we're only actually creating work for 3 of them. Even for as many
        # as a ~10 states it should probably be fine, just increasing the size
        # of the graph a bit.
        for i in range(num_states - 1):
            probs = markov_transition[..., :, i]
            # Conditional probability of destination i given that an individual
            # was not sent to destinations 0..i-1. Once a row's mass is used up
            # (1 - prev_probs == 0, e.g. a state with zero total rate), plain
            # division gives 0/0 = NaN. divide_no_nan yields 0 instead, and the
            # remaining count for that row should already be 0.
            cond_probs = tf.math.divide_no_nan(probs, 1.0 - prev_probs)
            binom = tfd.Binomial(
                total_count=total_count,
                probs=tf.clip_by_value(cond_probs, 0.0, 1.0),
            )
            # NOTE(review): the same `seed` is passed to every draw; confirm
            # this yields independent draws for the TFP sampling mode in use.
            sample = binom.sample(seed=seed)
            draws.append(sample)
            total_count -= sample
            prev_probs += probs

        # Individuals not allocated by any draw move to the last state.
        draws.append(total_count)
        counts = tf.stack(draws, axis=-1)
        new_state = tf.reduce_sum(counts, axis=-2)
        return counts, new_state

    return propagate_fn


Chris Jewell's avatar
Chris Jewell committed
67
def discrete_markov_simulation(hazard_fn, state, start, end, time_step, seed=None):
    """Simulates from a discrete time Markov state transition model.

    Each step uses chain-binomial sampling across the rows of the per-step
    transition matrix (see `chain_binomial_propagate`).

    :param hazard_fn: hazard rate function called as ``hazard_fn(t, state)``.
    :param state: the initial state; anything ``tf.convert_to_tensor`` accepts.
    :param start: the first simulation time.
    :param end: the end time (exclusive, as in ``tf.range``).
    :param time_step: the time step.
    :param seed: optional seed forwarded to the binomial draws.
    :returns: a tuple ``(times, events)``:
              - ``times`` is ``tf.range(start, end, time_step)`` in the
                state's dtype.
              - ``events`` stacks the per-step transition count matrices
                along a new leading axis.
    """
    propagate = chain_binomial_propagate(hazard_fn, time_step, seed=seed)
    # Convert first so that `state` may be any tensor-like value, not only one
    # that already carries a `.dtype`.
    state = tf.convert_to_tensor(state)
    times = tf.range(start, end, time_step, dtype=state.dtype)
    # Dynamic length: also valid when start/end/time_step are tensors whose
    # values (and hence the static shape of `times`) are unknown at trace time.
    num_times = tf.size(times)

    output = tf.TensorArray(state.dtype, size=num_times)

    def cond(i, *_):
        return i < num_times

    def body(i, state, output):
        update, state = propagate(times[i], state)
        output = output.write(i, update)
        return i + 1, state, output

    _, state, output = tf.while_loop(cond, body, loop_vars=(0, state, output))
    return times, output.stack()
Chris Jewell's avatar
Chris Jewell committed
85
86


Chris Jewell's avatar
Chris Jewell committed
87
def discrete_markov_log_prob(events, init_state, hazard_fn, time_step, stoichiometry):
    """Calculates an unnormalised log_prob function for a discrete time epidemic model.

    :param events: a `[M, T, X]` batch of transition events for metapopulation M,
                   times `T`, and transitions `X`.
    :param init_state: a vector of shape `[M, S]` the initial state of the epidemic for
                       `M` metapopulations and `S` states
    :param hazard_fn: a function called as ``hazard_fn(t, state)``, where ``t`` is
                      the integer time index and ``state`` is the `[M, S]` state at
                      that time. It returns transition rates in the form consumed
                      by ``make_transition_matrix``.
                      NOTE(review): ``discrete_markov_simulation`` passes the time
                      value ``start + i * time_step`` instead of the index ``i``;
                      confirm ``hazard_fn`` treats both consistently.
    :param time_step: the size of the time step.
    :param stoichiometry: a `[X, S]` matrix describing the state update for each
                          transition.
    :return: a scalar log probability for the epidemic.
    """
    num_meta = events.shape[-3]
    num_times = events.shape[-2]
    num_states = stoichiometry.shape[-1]

    state_timeseries = compute_state(init_state, events, stoichiometry)  # MxTxS
    # vectorized_map iterates over the leading axis, so put time first: TxMxS.
    tms_timeseries = tf.transpose(state_timeseries, perm=(1, 0, 2))

    def fn(elems):
        return hazard_fn(*elems)

    rates = tf.vectorized_map(fn=fn, elems=[tf.range(num_times), tms_timeseries])
    rate_matrix = make_transition_matrix(
        rates, [[0, 1], [1, 2], [2, 3]], tms_timeseries.shape
    )
    probs = approx_expm(rate_matrix * time_step)

    # [T, M, S, S] to [M, T, S, S]
    probs = tf.transpose(probs, perm=(1, 0, 2, 3))

    # Off-diagonal entries are the observed event counts. The diagonal is set to
    # occupancy minus row total, i.e. the individuals remaining in each state
    # (assuming make_transition_matrix leaves the diagonal at zero).
    event_matrix = make_transition_matrix(
        events, [[0, 1], [1, 2], [2, 3]], [num_meta, num_times, num_states]
    )
    event_matrix = tf.linalg.set_diag(
        event_matrix, state_timeseries - tf.reduce_sum(event_matrix, axis=-1)
    )

    # NOTE(review): the likelihood is always evaluated in float32 and cast back
    # to `events.dtype`; confirm the precision loss is acceptable for float64.
    logp = tfd.Multinomial(
        tf.cast(state_timeseries, dtype=tf.float32),
        probs=tf.cast(probs, dtype=tf.float32),
        name="log_prob",
    ).log_prob(tf.cast(event_matrix, dtype=tf.float32))

    return tf.cast(tf.reduce_sum(logp), dtype=events.dtype)
132
133
134
135
136
137
138
139
140


def events_to_full_transitions(events, initial_state):
    """Creates full transition matrices given matrices of transition events
    and the initial state.

    At each time point the diagonal of the event matrix is filled with the
    number of individuals remaining in each state. That is the occupancy
    implied by the previous full transition matrix (its column sums) minus
    the off-diagonal events leaving each state.

    :param events: a tensor of shape [t, c, s, s] for t timepoints, c metapopulations
                   and s states. Entry ``[..., i, j]`` counts transitions from
                   state ``i`` to state ``j``; any values on the diagonal are
                   ignored.
    :param initial_state: the initial state matrix of shape [c, s]
    :returns: a tensor of shape [t, c, s, s]: `events` with each diagonal
              replaced by the number of individuals remaining in that state.
    """

    def fill_diagonal(prev_transitions, step_events):
        # Column sums of the previous full matrix give the current occupancy.
        occupancy = tf.reduce_sum(prev_transitions, axis=-2)
        # Only off-diagonal entries are departures. The input diagonal is
        # overwritten below, so it must not be counted as leaving the state.
        departures = tf.reduce_sum(step_events, axis=-1) - tf.linalg.diag_part(
            step_events
        )
        return tf.linalg.set_diag(step_events, occupancy - departures)

    # A diagonal initial matrix makes the first occupancy equal initial_state.
    return tf.scan(
        fn=fill_diagonal, elems=events, initializer=tf.linalg.diag(initial_state)
    )