model.py 11.4 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
2
"""Functions for infection rates"""
import tensorflow as tf
Chris Jewell's avatar
Chris Jewell committed
3
import tensorflow_probability as tfp
Chris Jewell's avatar
Chris Jewell committed
4
from tensorflow_probability.python.internal import dtype_util
Chris Jewell's avatar
Chris Jewell committed
5
import numpy as np
Chris Jewell's avatar
Chris Jewell committed
6

7
from covid.impl.util import make_transition_rate_matrix
Chris Jewell's avatar
Chris Jewell committed
8
9
from covid.rdata import load_mobility_matrix, load_population, load_age_mixing
from covid.pydata import load_commute_volume
10
from covid.impl.discrete_markov import discrete_markov_simulation, discrete_markov_log_prob
Chris Jewell's avatar
Chris Jewell committed
11

Chris Jewell's avatar
Chris Jewell committed
12
13
14
# Default floating-point precision for model data and tensors.
DTYPE = np.float64

# Short aliases for frequently used TensorFlow(-Probability) namespaces.
tode = tfp.math.ode  # ODE solvers; CovidUKODE uses tode.DormandPrince
tla = tf.linalg


def power_iteration(A, tol=1e-3, max_iter=10000):
    """Estimate the dominant eigenvector of a square matrix by power iteration.

    Starting from a random normal vector, repeatedly multiplies by ``A`` and
    renormalises until the squared Euclidean distance between successive
    unit-norm iterates is no greater than ``tol``. This relies on Python
    control flow over tensor values, so it is intended for eager execution.

    :param A: a square ``[n, n]`` tensor.
    :param tol: convergence threshold on the squared distance between
        successive normalised iterates.
    :param max_iter: maximum number of iterations. This guards against an
        endless loop when the iteration cannot converge, e.g. when several
        eigenvalues share the largest modulus. ``None`` removes the limit.
    :return: tuple ``(b, i)``. ``b`` is the ``[n, 1]`` unit-norm eigenvector
        estimate and ``i`` is the number of iterations performed
        (``i == max_iter`` signals the limit was hit).
    """
    b_k = tf.random.normal([A.shape[1], 1], dtype=A.dtype)
    epsilon = tf.constant(1., dtype=A.dtype)
    i = 0
    while tf.greater(epsilon, tol) and (max_iter is None or i < max_iter):
        b_k1 = tf.matmul(A, b_k)
        b_k_new = b_k1 / tf.linalg.norm(b_k1)
        # Measure the change up to sign. With a negative dominant eigenvalue
        # the iterate flips sign every step while still converging in
        # direction. For non-negative matrices the "+" term is >= 2, so this
        # reduces to the plain difference.
        epsilon = tf.minimum(
            tf.reduce_sum(tf.square(b_k_new - b_k)),
            tf.reduce_sum(tf.square(b_k_new + b_k)))
        b_k = b_k_new
        i += 1
    return b_k, i


def rayleigh_quotient(A, b):
    """Compute the Rayleigh quotient ``b^T A b / (b^T b)``.

    :param A: a square ``[n, n]`` tensor.
    :param b: a length-``n`` vector (or ``[n, 1]`` column); it is reshaped to
        an ``[n, 1]`` column before use.
    :return: a ``[1, 1]`` tensor holding the quotient.
    """
    col = tf.reshape(b, [b.shape[0], 1])
    row = tf.transpose(col)
    return tf.matmul(row, tf.matmul(A, col)) / tf.matmul(row, col)


def dense_to_block_diagonal(A, n_blocks):
    """Return ``kron(I_{n_blocks}, A)`` as a structured linear operator.

    The result is block-diagonal, with ``n_blocks`` copies of ``A`` on the
    diagonal. If ``A`` carries leading batch dimensions (e.g. ``[2, M, M]``),
    the operator is batched accordingly.

    :param A: a dense matrix (or batch of matrices); must expose ``.dtype``.
    :param n_blocks: number of diagonal copies of ``A``.
    :return: a ``tf.linalg.LinearOperatorKronecker``.
    """
    dense_op = tf.linalg.LinearOperatorFullMatrix(A)
    identity_op = tf.linalg.LinearOperatorIdentity(n_blocks, dtype=A.dtype)
    return tf.linalg.LinearOperatorKronecker([identity_op, dense_op])


def load_data(paths, settings, dtype=DTYPE):
    """Load the model's input data and cast the numeric parts to ``dtype``.

    :param paths: mapping of input file paths; reads the keys
        'age_mixing_matrix_term', 'age_mixing_matrix_hol', 'mobility_matrix',
        'commute_volume' and 'population_size'.
    :param settings: mapping with 'inference_period' and 'prediction_period',
        each indexable as ``[start, end]``. Commute volume is loaded for
        ``[inference_period[0], prediction_period[1]]``.
    :param dtype: floating-point dtype for the returned matrices, commute
        volume and population counts (default ``DTYPE``).
    :return: dict with keys 'M_tt', 'M_hh', 'C', 'la_names', 'age_groups',
        'W' and 'pop'.
    """
    M_tt, age_groups = load_age_mixing(paths['age_mixing_matrix_term'])
    M_hh, _ = load_age_mixing(paths['age_mixing_matrix_hol'])

    C, la_names = load_mobility_matrix(paths['mobility_matrix'])
    # Zero the diagonal so a LAD's flow to itself is excluded.
    np.fill_diagonal(C, 0.)

    w_period = [settings['inference_period'][0], settings['prediction_period'][1]]
    W = load_commute_volume(paths['commute_volume'], w_period)['percent']

    pop = load_population(paths['population_size'])

    # Honour the requested dtype (previously the module DTYPE was always used).
    M_tt = M_tt.astype(dtype)
    M_hh = M_hh.astype(dtype)
    C = C.astype(dtype)
    W = W.astype(dtype)
    pop['n'] = pop['n'].astype(dtype)

    return {'M_tt': M_tt, 'M_hh': M_hh,
            'C': C, 'la_names': la_names,
            'age_groups': age_groups,
            'W': W, 'pop': pop}


class CovidUK:
    """Base class holding the data and derived operators shared by the CovidUK models.

    Population strata are ordered LAD-major: stratum ``(lad, age)`` lives at
    flat index ``lad * n_ages + age``.
    """

    def __init__(self,
                 M_tt: np.ndarray,
                 M_hh: np.ndarray,
                 W: np.ndarray,
                 C: np.ndarray,
                 N: np.ndarray,
                 date_range: list,
                 holidays: list,
                 lockdown: list,
                 time_step: np.int64):
        """Represents a CovidUK ODE model

        :param M_tt: a MxM matrix of age group mixing in term time
        :param M_hh: a MxM matrix of age group mixing in holiday time
        :param W: Commuting volume, indexed by simulation time index
        :param C: a n_ladsxn_lads matrix of inter-LAD commuting
        :param N: a vector of population sizes for each LAD x age stratum,
            LAD-major, of length n_lads * M
        :param date_range: a time range [start, end)
        :param holidays: a length-2 sequence [start, end) bounding the holiday period
        :param lockdown: a length-2 sequence [start, end) bounding lockdown measures
        :param time_step: a time step, in days, to use in the discrete time simulation
        """
        dtype = dtype_util.common_dtype([M_tt, M_hh, W, C, N], dtype_hint=np.float64)

        self.n_ages = M_tt.shape[0]
        self.n_lads = C.shape[0]

        # Use the common dtype (not a hard-coded float64) so these agree with
        # every other tensor built below.
        self.M_tt = tf.convert_to_tensor(M_tt, dtype=dtype)
        self.M_hh = tf.convert_to_tensor(M_hh, dtype=dtype)

        # Create one linear operator comprising both the term and holiday
        # matrices. This is nice because
        #   - the dense "M" parts will take up memory of shape [2, M, M]
        #   - the identity matrix will only take up memory of shape [M]
        #   - matmuls/matvecs will be quite efficient because of the
        #     LinearOperatorKronecker structure and diagonal structure of the
        #     identity piece thereof.
        # It should be sufficiently efficient that we can just get rid of the
        # control flow switching between the two operators, and instead just do
        # both matmuls in one big (vectorized!) pass, and pull out what we want
        # after the fact with tf.gather.
        self.M = dense_to_block_diagonal(
            np.stack([M_tt, M_hh], axis=0), self.n_lads)

        # Mean term-time mixing; scales the between-LAD (commuting) term.
        self.Kbar = tf.reduce_mean(M_tt)

        # Symmetrised commuting matrix expanded over age strata:
        # kron(C + C^T, ones(M, M)).
        lad_commute = tf.linalg.LinearOperatorFullMatrix(C + tf.transpose(C))
        age_ones = tf.linalg.LinearOperatorFullMatrix(tf.ones_like(M_tt, dtype=dtype))
        self.C = tf.linalg.LinearOperatorKronecker([lad_commute, age_ones])

        self.W = tf.constant(W, dtype=dtype)
        self.N = tf.constant(N, dtype=dtype)

        # Total population of each LAD, repeated once per age stratum.
        N_matrix = tf.reshape(self.N, [self.n_lads, self.n_ages])
        N_sum = tf.reduce_sum(N_matrix, axis=1)
        N_sum = N_sum[:, None] * tf.ones([1, self.n_ages], dtype=dtype)
        self.N_sum = tf.reshape(N_sum, [-1])

        self.time_step = time_step
        self.times = np.arange(date_range[0], date_range[1], np.timedelta64(int(time_step), 'D'))

        # 0/1 indicator per time point of being inside [start, end) of the
        # holiday (1 selects M_hh from self.M) or of lockdown, respectively.
        self.m_select = np.int64((self.times >= holidays[0]) &
                                 (self.times < holidays[1]))
        self.lockdown_select = np.int64((self.times >= lockdown[0]) &
                                        (self.times < lockdown[1]))
        self.max_t = self.m_select.shape[0] - 1

    def create_initial_state(self, init_matrix=None):
        """Build the initial state tensor with columns S, E, I, R.

        :param init_matrix: optional ``[n_lads, n_ages]`` array of initial
            infectious counts. If ``None``, 30 infectious individuals are
            seeded in LAD index 149, age-group index 10 ("Middle-aged in
            Surrey").
        :return: a ``[n_lads * n_ages, 4]`` tensor with S = N - I and E = R = 0.
        :raises AssertionError: if ``init_matrix`` does not have shape
            ``[n_lads, n_ages]``.
        """
        dtype = self.N.dtype
        if init_matrix is None:
            I = np.zeros(self.N.shape, dtype=dtype.as_numpy_dtype)
            # Strata are LAD-major (matching the reshape of N in __init__), so
            # derive the flat index from n_ages instead of assuming 17 groups.
            I[149 * self.n_ages + 10] = 30.  # Middle-aged in Surrey
        else:
            np.testing.assert_array_equal(
                init_matrix.shape, [self.n_lads, self.n_ages],
                err_msg=f"init_matrix does not have shape [<num lads>,<num ages>] "
                        f"({self.n_lads},{self.n_ages})")
            I = tf.cast(tf.reshape(init_matrix, [-1]), dtype)
        S = self.N - I
        E = tf.zeros(self.N.shape, dtype=dtype)
        R = tf.zeros(self.N.shape, dtype=dtype)
        return tf.stack([S, E, I, R], axis=-1)


class CovidUKODE(CovidUK):
    """Deterministic ODE formulation of the CovidUK SEIR model.

    State tensors carry the S, E, I and R compartments along their last axis
    and are integrated with TFP's Dormand-Prince solver.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.solver = tode.DormandPrince()

    def make_h(self, param):
        """Return the ODE right-hand side ``h_fn(t, state) -> d(state)/dt``.

        :param param: dict of model parameters. ``h_fn`` reads 'beta1',
            'beta2', 'beta3', 'omega', 'nu' and 'gamma' on every call.
        """

        def h_fn(t, state):
            susceptible, exposed, infectious, _ = tf.unstack(state, axis=-1)

            # The integrator may evaluate times outside the simulated range.
            # Clamping treats such times as having the holiday/lockdown status
            # (and commuting volume) of the nearest in-range time index.
            day = tf.clip_by_value(tf.cast(t, tf.int64), 0, self.max_t)
            mixing_idx = tf.gather(self.m_select, day)
            commute_volume = tf.pow(tf.gather(self.W, day), param['omega'])
            in_lockdown = tf.gather(self.lockdown_select, day)
            beta = tf.where(in_lockdown == 0,
                            param['beta1'],
                            param['beta1'] * param['beta3'])

            # self.M holds [term, holiday] operators; mixing_idx picks one result.
            within = tf.gather(self.M.matvec(infectious), mixing_idx)
            # Commuting contribution, driven by infectious per LAD population.
            between = (param['beta2'] * self.Kbar * commute_volume *
                       self.C.matvec(infectious / self.N_sum))
            force = susceptible * (beta * (within + between)) / self.N

            progression = param['nu'] * exposed
            removal = param['gamma'] * infectious
            return tf.stack([-force,
                             force - progression,
                             progression - removal,
                             removal], axis=-1)

        return h_fn

    def simulate(self, param, state_init, solver_state=None):
        """Integrate the ODE at every time index ``0 .. len(self.times) - 1``.

        :param param: dict of model parameters (see ``make_h``).
        :param state_init: initial state tensor.
        :param solver_state: optional solver internal state to resume from.
        :return: tuple of solution times, solution states and the solver's
            internal state.
        """
        rhs = self.make_h(param)
        day_grid = np.arange(self.times.shape[0])
        solution = self.solver.solve(
            ode_fn=rhs,
            initial_time=day_grid[0],
            initial_state=state_init,
            solution_times=day_grid,
            previous_solver_internal_state=solver_state,
        )
        return solution.times, solution.states, solution.solver_internal_state

    def ngm(self, param):
        """Return the next-generation matrix ``beta1 * (M_term + commuting) / gamma``.

        Compared with ``make_h``, this uses ``beta1`` without the lockdown
        factor and term-time mixing (batch index 0 of ``self.M``), and omits
        the commuting-volume factor.
        """
        term_mixing = self.M.to_dense()[0, ...]
        spatial = param['beta2'] * self.Kbar * self.C.to_dense() / self.N_sum[np.newaxis, :]
        return param['beta1'] * (term_mixing + spatial) / param['gamma']

    def eval_R0(self, param, tol=1e-8):
        """Estimate R0 as the dominant eigenvalue of the next-generation matrix.

        :return: tuple ``(R0, n_iter)`` where ``n_iter`` is the number of
            power iterations performed.
        """
        next_gen = self.ngm(param)
        # Dominant eigenvector by power iteration; its Rayleigh quotient gives
        # the dominant eigenvalue.
        eigvec, n_iter = power_iteration(next_gen, tol=tf.cast(tol, tf.float64))
        return tf.squeeze(rayleigh_quotient(next_gen, eigvec)), n_iter


def covid19uk_logp(y, sim, phi, r):
    """Negative-binomial log-likelihood of observations given a simulation.

    The expected count at each step is ``phi`` times the increase in the
    removed compartment (state index 3), summed over strata. The model is
    ``NB(total_count=r, probs=lam / (r + lam))``, which has mean ``lam`` and
    tends to ``Poisson(lam)`` as ``r -> inf``.

    :param y: observed counts; presumably one per consecutive pair of
        simulated times (length T - 1) — confirm against callers.
    :param sim: simulated states indexed ``[time, stratum, state]``.
    :param phi: scaling from new removals to expected observations.
    :param r: negative-binomial total count (overdispersion).
    :return: elementwise log-probabilities of ``y``.
    """
    removed = sim[:, :, 3]
    new_removals = tf.reduce_sum(removed[1:] - removed[:-1], axis=-1)
    expected = new_removals * phi
    observation = tfp.distributions.NegativeBinomial(
        total_count=r, probs=expected / (r + expected))
    return observation.log_prob(y)


class CovidUKStochastic(CovidUK):
    """Discrete-time stochastic (Markov chain) formulation of the CovidUK model.

    Transitions S->E, E->I and I->R are simulated or scored with step size
    ``self.time_step`` via ``discrete_markov_simulation`` and
    ``discrete_markov_log_prob``.
    """

    @staticmethod
    def _as_float64(param):
        """Convert every parameter value to a float64 tensor.

        ``tf.convert_to_tensor`` is used instead of ``tf.constant`` because it
        also accepts (symbolic) tensors, e.g. when called inside a traced
        function.
        """
        return {k: tf.convert_to_tensor(v, dtype=tf.float64) for k, v in param.items()}

    def make_h(self, param):
        """Constructs a function that takes `state` and outputs a
        transition rate matrix (with 0 diagonal).

        :param param: dict of model parameters. The returned function reads
            'beta1', 'beta2', 'beta3', 'omega', 'nu' and 'gamma'.
        :return: a function ``h(t, state)``.
        """

        def h(t, state):
            """Computes a transition rate matrix

            :param state: a tensor of shape [nc, ns] for ns states and nc population strata. States
              are S, E, I, R.  We arrange the state like this because the state vectors are then arranged
              contiguously in memory for fast calculation below.
            :return a tensor of shape [nc, ns, ns] containing transition matric for each i=0,...,(c-1)
            """
            # Clamp times outside the simulated range to the nearest time index.
            t_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, self.max_t)
            m_switch = tf.gather(self.m_select, t_idx)
            commute_volume = tf.pow(tf.gather(self.W, t_idx), param['omega'])
            lockdown = tf.gather(self.lockdown_select, t_idx)
            beta = tf.where(lockdown == 0, param['beta1'], param['beta1'] * param['beta3'])

            infec_rate = beta * (
                tf.gather(self.M.matvec(state[..., 2]), m_switch) +
                param['beta2'] * self.Kbar * commute_volume * self.C.matvec(state[..., 2] / self.N_sum))
            # Vector of length nc. Not multiplied by S here; presumably the
            # S->E rate is applied per susceptible downstream — confirm in
            # make_transition_rate_matrix.
            infec_rate = infec_rate / self.N

            ei = tf.broadcast_to([tf.convert_to_tensor(param['nu'])], shape=[state.shape[0]])  # Vector of length nc
            ir = tf.broadcast_to([tf.convert_to_tensor(param['gamma'])], shape=[state.shape[0]])  # Vector of length nc

            rate_matrix = make_transition_rate_matrix([infec_rate, ei, ir], [[0, 1], [1, 2], [2, 3]], state)
            return rate_matrix

        return h

    @tf.function(autograph=False, experimental_compile=True)
    def simulate(self, param, state_init):
        """Runs a simulation from the epidemic model

        :param param: a dictionary of model parameters
        :param state_init: the initial state
        :returns: a tuple of times and simulated states.
        """
        param = self._as_float64(param)
        hazard = self.make_h(param)
        t, sim = discrete_markov_simulation(hazard, state_init, np.float64(0.),
                                            np.float64(self.times.shape[0]), self.time_step)
        return t, sim

    def log_prob(self, y, param, state_init):
        """Calculates the log probability of observing epidemic events y

        :param y: a list of tensors.  The first is of shape [n_times] containing times,
                  the second is of shape [n_times, n_states, n_states] containing event matrices.
        :param param: a dictionary of model parameters (converted to float64
                  tensors, as in ``simulate``, so python floats are accepted)
        :param state_init: the initial state
        :returns: a scalar giving the log probability of the epidemic
        """
        param = self._as_float64(param)
        hazard = self.make_h(param)
        return discrete_markov_log_prob(y, state_init, hazard, self.time_step)