model.py 5.86 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
2
"""Functions for infection rates"""
import tensorflow as tf
Chris Jewell's avatar
Chris Jewell committed
3
import tensorflow_probability as tfp
Chris Jewell's avatar
Chris Jewell committed
4
from tensorflow_probability.python.internal import dtype_util
Chris Jewell's avatar
Chris Jewell committed
5
import numpy as np
Chris Jewell's avatar
Chris Jewell committed
6

7
from covid import config
8
from covid.impl.util import make_transition_matrix
9
10
from covid.rdata import load_age_mixing
from covid.pydata import load_commute_volume, load_mobility_matrix, load_population
Chris Jewell's avatar
Chris Jewell committed
11
12
13
14
from covid.impl.discrete_markov import (
    discrete_markov_simulation,
    discrete_markov_log_prob,
)
Chris Jewell's avatar
Chris Jewell committed
15

Chris Jewell's avatar
Chris Jewell committed
16
17
# Short alias for TensorFlow's linear-algebra namespace.
tla = tf.linalg

# Default dtype for model data, taken from the project configuration
# (`config.floatX`).
DTYPE = config.floatX
Chris Jewell's avatar
Chris Jewell committed
19

Chris Jewell's avatar
Chris Jewell committed
20

Chris Jewell's avatar
Chris Jewell committed
21
def power_iteration(A, tol=1e-3, max_iter=None):
    """Estimates the dominant eigenvector of a matrix by power iteration.

    Starting from a random normal vector, repeatedly multiplies by `A` and
    renormalises until the squared Euclidean distance between successive
    iterates (taken up to sign) is no greater than `tol`.

    :param A: a matrix of shape [n, n].
    :param tol: convergence threshold on the squared distance between
                successive normalised iterates.
    :param max_iter: optional cap on the number of iterations.  `None` (the
                     default) iterates until convergence; without a cap the
                     loop can still run indefinitely if the iteration does not
                     converge (e.g. two dominant eigenvalues of equal modulus).
    :returns: a tuple `(b, i)` where `b` is a unit-norm column vector of shape
              [n, 1] approximating the dominant eigenvector and `i` is the
              number of iterations performed.
    """
    b_k = tf.random.normal([A.shape[1], 1], dtype=A.dtype)
    epsilon = tf.constant(1.0, dtype=A.dtype)
    i = 0
    while tf.greater(epsilon, tol):
        if max_iter is not None and i >= max_iter:
            break
        b_k1 = tf.matmul(A, b_k)
        b_k1_norm = tf.linalg.norm(b_k1)
        b_k_new = b_k1 / b_k1_norm
        # An eigenvector is only defined up to sign.  When the dominant
        # eigenvalue is negative the iterate flips sign on every step, so
        # measure the distance to the nearer of +b_k and -b_k; comparing
        # against +b_k alone would never converge in that case.
        epsilon = tf.minimum(
            tf.reduce_sum(tf.pow(b_k_new - b_k, 2)),
            tf.reduce_sum(tf.pow(b_k_new + b_k, 2)),
        )
        b_k = b_k_new
        i += 1
    return b_k, i

Chris Jewell's avatar
Chris Jewell committed
34

Chris Jewell's avatar
Chris Jewell committed
35
def rayleigh_quotient(A, b):
    """Computes the Rayleigh quotient (b . A b) / (b . b).

    :param A: a matrix of shape [..., n, n].
    :param b: a vector of shape [..., n], or a column vector of shape
              [..., n, 1] such as the one returned by `power_iteration`.
    :returns: the Rayleigh quotient, reduced over the last axis of `b`.
    """
    b = tf.convert_to_tensor(b)
    rank = b.shape.rank
    if rank is None or (rank > 1 and b.shape[-1] is None):
        # Shape not statically known: squeeze all unit axes as a best effort.
        b = tf.squeeze(b)
    elif rank > 1 and b.shape[-1] == 1:
        # Drop only the trailing column axis.  A bare `tf.squeeze(b)` would
        # also drop other unit axes, e.g. turning the [1, 1] column vector of
        # a 1x1 matrix into a scalar, which the einsum below cannot accept.
        b = tf.squeeze(b, axis=-1)
    numerator = tf.einsum("...i,...i->...", b, tf.linalg.matvec(A, b))
    denominator = tf.einsum("...i,...i->...", b, b)
    return numerator / denominator
40

Chris Jewell's avatar
Chris Jewell committed
41

42
43
def dense_to_block_diagonal(A, n_blocks):
    """Represents `n_blocks` copies of `A` as a block-diagonal linear operator.

    The operator is the Kronecker product I_{n_blocks} (x) A, so every block
    on the diagonal equals `A` and all off-diagonal blocks are zero.

    :param A: a dense matrix (tensor-like exposing a `dtype` attribute).
    :param n_blocks: the number of diagonal blocks.
    :returns: a `tf.linalg.LinearOperatorKronecker`.
    """
    dense_op = tf.linalg.LinearOperatorFullMatrix(A)
    identity_op = tf.linalg.LinearOperatorIdentity(n_blocks, dtype=A.dtype)
    return tf.linalg.LinearOperatorKronecker([identity_op, dense_op])
Chris Jewell's avatar
Chris Jewell committed
47

48

Chris Jewell's avatar
Chris Jewell committed
49
def load_data(paths, settings, dtype=DTYPE):
    """Loads the model's input data and casts the numeric arrays to `dtype`.

    :param paths: mapping from data-set names to locations.  Keys read:
                  "age_mixing_matrix_term", "age_mixing_matrix_hol",
                  "mobility_matrix", "commute_volume", "population_size".
    :param settings: mapping containing "inference_period"; its first two
                     elements are passed to `load_commute_volume` as the
                     period to load.
    :param dtype: dtype that all returned numeric arrays are cast to
                  (defaults to `DTYPE`).
    :returns: a dict with keys
        - "M_tt", "M_hh": age-mixing matrices loaded from the "_term" and
          "_hol" paths respectively (presumably term-time and holiday),
        - "C": the mobility matrix as an array with its diagonal zeroed,
        - "la_names": the mobility matrix's index labels,
        - "age_groups": the labels returned by `load_age_mixing` for "M_tt",
        - "W": the "percent" column of the commute-volume data,
        - "pop": the loaded population sizes.
    """
    M_tt, age_groups = load_age_mixing(paths["age_mixing_matrix_term"])
    M_hh, _ = load_age_mixing(paths["age_mixing_matrix_hol"])

    C = load_mobility_matrix(paths["mobility_matrix"])
    la_names = C.index.to_numpy()

    w_period = [settings["inference_period"][0], settings["inference_period"][1]]
    W = load_commute_volume(paths["commute_volume"], w_period)["percent"]

    pop = load_population(paths["population_size"])

    # Honour the caller's `dtype` rather than the module-wide DTYPE.
    M_tt = M_tt.astype(dtype)
    M_hh = M_hh.astype(dtype)
    C = C.to_numpy().astype(dtype)
    # Zero the self-to-self entries.  astype() returns a copy, so the loaded
    # mobility DataFrame itself is not modified.
    np.fill_diagonal(C, 0.0)
    W = W.to_numpy().astype(dtype)
    pop = pop.to_numpy().astype(dtype)

    return {
        "M_tt": M_tt,
        "M_hh": M_hh,
        "C": C,
        "la_names": la_names,
        "age_groups": age_groups,
        "W": W,
        "pop": pop,
    }
Chris Jewell's avatar
Chris Jewell committed
77
78


79
class CovidUKStochastic:
    """Discrete-time Markov jump process for a state transition epidemic model."""

    def __init__(
        self,
        transition_rates,
        stoichiometry,
        initial_state,
        initial_step,
        time_delta,
        num_steps,
    ):
        """Implements a discrete-time Markov jump process for a state transition model.

        :param transition_rates: a function of the form `fn(t, state)` taking
            the current time `t` and state tensor `state`.  This function
            returns a tensor which broadcasts to the first dimension of
            `stoichiometry`.
        :param stoichiometry: the stoichiometry matrix for the state transition
            model, with rows representing transitions and columns representing
            states.
        :param initial_state: an initial state tensor whose inner dimension
            equals the number of states, i.e. the second dimension (columns)
            of `stoichiometry`.
        :param initial_step: an offset giving the time `t` of the first
            timestep in the model.
        :param time_delta: the size of the time step to be used.
        :param num_steps: the number of time steps across which the model runs.
        """
        self.transition_rates = transition_rates
        self.stoichiometry = stoichiometry
        self.initial_state = initial_state
        self.initial_step = initial_step
        self.time_delta = time_delta
        self.num_steps = num_steps

    def ngm(self, t, state, param):
        """Computes a next generation matrix G.

        With w = W[t] ** omega and beta = beta1 * exp(xi[xi_select[t]]),

            G[i, j] = beta * (delta_ij + beta2 * w * C[i, j] / N[j])
                      * state[i, 0] / (N[i] * gamma)

        where the time index is clipped to [0, max_t].

        NOTE(review): this method reads `self.W`, `self.xi_select`,
        `self.max_t`, `self.C` and `self.N`, none of which are set by
        `__init__`; they must be assigned elsewhere before calling it.

        :param t: the time step.
        :param state: a tensor of shape [M, S] for S states and M population
                      strata.  States are S, E, I, R; only column 0 is read.
        :param param: a dict of parameters with keys "beta1", "beta2",
                      "omega", "xi" and "gamma".
        :return: a tensor of shape [M, M] giving the expected number of new
                 cases of disease that individuals in each metapopulation
                 give rise to.
        """
        t_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, self.max_t)
        commute_volume = tf.pow(tf.gather(self.W, t_idx), param["omega"])
        xi_idx = tf.gather(self.xi_select, t_idx)
        xi = tf.gather(param["xi"], xi_idx)
        beta = param["beta1"] * tf.math.exp(xi)

        ngm = beta * (
            tf.eye(self.C.shape[0], dtype=state.dtype)
            + param["beta2"] * commute_volume * self.C / self.N[tf.newaxis, :]
        )
        # Scale row i by the column-0 (S) count of stratum i relative to
        # N[i] * gamma.
        ngm = (
            ngm
            * state[..., 0][..., tf.newaxis]
            / (self.N[:, tf.newaxis] * param["gamma"])
        )
        return ngm

    def sample(self, seed=None):
        """Runs a simulation from the epidemic model.

        The simulation starts from `initial_state` at `initial_step` and runs
        to `initial_step + num_steps * time_delta` in steps of `time_delta`.

        :param seed: optional random seed forwarded to
                     `discrete_markov_simulation`.
        :returns: a tuple of times and simulated states, as returned by
                  `discrete_markov_simulation`.
        """
        t, sim = discrete_markov_simulation(
            hazard_fn=self.transition_rates,
            state=self.initial_state,
            start=self.initial_step,
            end=self.initial_step + self.num_steps * self.time_delta,
            time_step=self.time_delta,
            seed=seed,
        )
        return t, sim

    def log_prob(self, y):
        """Calculates the log probability of observing epidemic events `y`.

        :param y: tensor-like epidemic events.  It is converted to a single
                  tensor of the dtype common to `y` and `initial_state`
                  (falling back to `DTYPE`) and passed to
                  `discrete_markov_log_prob` (expected layout: TODO confirm
                  against that function).
        :returns: a scalar giving the log probability of the epidemic, as
                  computed by `discrete_markov_log_prob`.
        """
        with tf.name_scope("CovidUKStochastic.log_prob"):
            dtype = dtype_util.common_dtype(
                [y, self.initial_state], dtype_hint=DTYPE
            )
            y = tf.convert_to_tensor(y, dtype)
            return discrete_markov_log_prob(
                y,
                self.initial_state,
                self.transition_rates,
                self.time_delta,
                self.stoichiometry,
            )