model_spec.py 8.82 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
2
"""Implements the COVID SEIR model as a TFP Joint Distribution"""

3
import pandas as pd
4
import geopandas as gp
5
import numpy as np
Chris Jewell's avatar
Chris Jewell committed
6
7
8
import tensorflow as tf
import tensorflow_probability as tfp

9
from gemlib.distributions import DiscreteTimeStateTransitionModel
Chris Jewell's avatar
Chris Jewell committed
10
from covid.util import impute_previous_cases
11
import covid.data as data
Chris Jewell's avatar
Chris Jewell committed
12
13

tfd = tfp.distributions

DTYPE = np.float64

# Rows are the S->E, E->I and I->R transitions; columns are the S, E, I and R
# compartments of the epidemic state.
STOICHIOMETRY = np.array([[-1, 1, 0, 0], [0, -1, 1, 0], [0, 0, -1, 1]])
# Length of one time step, passed to DiscreteTimeStateTransitionModel.
TIME_DELTA = 1.0
XI_FREQ = 14  # baseline transmission changes every 14 days
NU = tf.constant(0.28, dtype=DTYPE)  # E->I rate assumed known.

def read_covariates(paths, date_low, date_high):
    """Loads covariate data

    :param paths: a dictionary of paths to data with keys
                  {'mobility_matrix', 'population_size', 'commute_volume',
                  'geopackage', 'tier_restriction_csv'}
    :param date_low: start date of the analysis period
    :param date_high: end date of the analysis period; the weekday
                      indicator covers date_low..date_high inclusive
    :returns: a dictionary of covariate information to be consumed by the
              model, every entry cast to DTYPE:
              {'C': commute_matrix, 'W': traffic_flow,
               'N': population_size, 'L': tier_restriction,
               'weekday': weekday_indicator}
    """
    mobility = data.read_mobility(paths["mobility_matrix"])
    popsize = data.read_population(paths["population_size"])
    commute_volume = data.read_traffic_flow(
        paths["commute_volume"], date_low=date_low, date_high=date_high
    )

    # Keep only LADs whose lad19cd code starts with "E" (presumably the
    # English local authorities -- confirm against the geopackage).
    geo = gp.read_file(paths["geopackage"])
    geo = geo.loc[geo["lad19cd"].str.startswith("E")]
    tier_restriction = data.read_tier_restriction_data(
        paths["tier_restriction_csv"],
        geo[["lad19cd", "lad19nm"]],
        date_low,
        date_high,
    )

    # One boolean per day in [date_low, date_high]: True Monday-Friday
    # (pandas weekday 0-4), False on Saturday/Sunday.
    weekday = pd.date_range(date_low, date_high).weekday < 5

    return dict(
        C=mobility.to_numpy().astype(DTYPE),
        W=commute_volume.to_numpy().astype(DTYPE),
        N=popsize.to_numpy().astype(DTYPE),
        # NOTE(review): assumes read_tier_restriction_data returns an
        # array-like supporting .astype directly (no .to_numpy()) -- confirm.
        L=tier_restriction.astype(DTYPE),
        weekday=weekday.astype(DTYPE),
    )


def impute_censored_events(cases):
    """Impute the censored S->E and E->I events preceding observed cases.

    Events are back-imputed in two passes with the geometric sampling
    algorithm in `impute_previous_cases`. E->I events are imputed first,
    from the I->R case counts, and S->E events are then imputed from those
    E->I events.

    The two numbers passed to `impute_previous_cases` are
    application-specific magic numbers. 0.21 controls the lag between EI and
    IR events, and 0.28 the lag between SE and EI events. Both were chosen
    by experimentation and examination of the resulting epidemic
    trajectories.

    :param cases: a MxT matrix of case numbers (I->R)
    :returns: a tensor with events stacked along a new right-most dimension
              in the order [S->E, E->I, I->R]; the first two slices of that
              dimension contain the imputed events.
    """
    imputed_ei, ei_lag = impute_previous_cases(cases, 0.21)
    imputed_se, se_lag = impute_previous_cases(imputed_ei, 0.28)

    # Left-pad the later event series along the time axis by the imputation
    # lags (presumably so all three series share the S->E time axis).
    ir_pad = ei_lag + se_lag - 2
    ei_pad = se_lag - 1
    padded_ir = np.pad(cases, ((0, 0), (ir_pad, 0)))
    padded_ei = np.pad(imputed_ei, ((0, 0), (ei_pad, 0)))

    stacked = [imputed_se, padded_ei, padded_ir]
    return tf.stack(stacked, axis=-1)


def CovidUK(covariates, initial_state, initial_step, num_steps, priors):
    """Build the spatial SEIR model as a TFP joint distribution.

    The returned `tfd.JointDistributionNamed` has the random variables
    `beta1`, `beta2`, `beta3`, `xi`, `gamma0`, `gamma1` and `seir`. TFP
    resolves dependencies between them by matching each factory's argument
    names to these keys. The argument names of the nested factory functions
    below must therefore not be changed.

    :param covariates: a dictionary of covariate data with keys 'C', 'W',
                       'N', 'L' and 'weekday' (as built by `read_covariates`)
    :param initial_state: initial epidemic state, passed to
                          `DiscreteTimeStateTransitionModel`
    :param initial_step: passed to `DiscreteTimeStateTransitionModel` as
                         `initial_step`
    :param num_steps: number of time steps of the epidemic model
    :param priors: currently unused; kept for interface compatibility
    :returns: a `tfd.JointDistributionNamed` over the model parameters and
              the SEIR event series
    """

    def beta1():
        """Prior on the mean of the baseline log transmission process `xi`."""
        return tfd.Normal(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale=tf.constant(1000.0, dtype=DTYPE),
        )

    def beta2():
        """Prior on the weight of commuting-driven infection pressure."""
        return tfd.Gamma(
            concentration=tf.constant(3.0, dtype=DTYPE),
            rate=tf.constant(10.0, dtype=DTYPE),
        )

    def beta3():
        """Prior on the coefficients of the tier-restriction covariates L."""
        return tfd.Sample(
            tfd.Normal(
                loc=tf.constant(0.0, dtype=DTYPE),
                scale=tf.constant(100.0, dtype=DTYPE),
            ),
            sample_shape=covariates["L"].shape[-1],
        )

    def xi(beta1):
        """Gaussian-process prior on the baseline log transmission rate.

        The process is evaluated at one index point every XI_FREQ steps and
        has constant mean `beta1`.
        """
        sigma = tf.constant(0.4, dtype=DTYPE)
        phi = tf.constant(24.0, dtype=DTYPE)
        kernel = tfp.math.psd_kernels.MaternThreeHalves(sigma, phi)
        # NOTE(review): num_steps // XI_FREQ points; if num_steps < XI_FREQ
        # this is empty and the tf.gather on xi in transition_rate_fn fails.
        idx_pts = tf.cast(tf.range(num_steps // XI_FREQ) * XI_FREQ, dtype=DTYPE)
        return tfd.GaussianProcess(
            kernel,
            mean_fn=lambda idx: beta1,
            index_points=idx_pts[:, tf.newaxis],
        )

    def gamma0():
        """Prior on the log I->R rate at the (centred) average weekday.

        NOTE: `priors` is not used here. An earlier version drew gamma from a
        Gamma(priors["gamma"]["concentration"], priors["gamma"]["rate"]) prior.
        """
        return tfd.Normal(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale=tf.constant(100.0, dtype=DTYPE),
        )

    def gamma1():
        """Prior on the weekday effect on the log I->R rate."""
        return tfd.Normal(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale=tf.constant(100.0, dtype=DTYPE),
        )

    def seir(beta2, beta3, xi, gamma0, gamma1):
        """Discrete-time SEIR event model conditional on the parameters."""
        beta2 = tf.convert_to_tensor(beta2, DTYPE)
        beta3 = tf.convert_to_tensor(beta3, DTYPE)
        xi = tf.convert_to_tensor(xi, DTYPE)
        gamma0 = tf.convert_to_tensor(gamma0, DTYPE)
        gamma1 = tf.convert_to_tensor(gamma1, DTYPE)

        # Centre the tier-restriction covariates over axis 0.
        L = tf.convert_to_tensor(covariates["L"], DTYPE)
        L = L - tf.reduce_mean(L, axis=0)

        # Centre the weekday indicator, so that gamma0 is the log I->R rate
        # at the average day.
        weekday = tf.convert_to_tensor(covariates["weekday"], DTYPE)
        weekday = weekday - tf.reduce_mean(weekday, axis=-1)

        # These covariates are the same at every time step. Build them once
        # here rather than on every call to transition_rate_fn.
        C = tf.convert_to_tensor(covariates["C"], dtype=DTYPE)
        # Symmetrise the mobility matrix and zero its diagonal.
        C = tf.linalg.set_diag(
            C + tf.transpose(C), tf.zeros(C.shape[0], dtype=DTYPE)
        )
        W = tf.constant(np.squeeze(covariates["W"]), dtype=DTYPE)
        N = tf.constant(np.squeeze(covariates["N"]), dtype=DTYPE)

        def transition_rate_fn(t, state):
            """Return the [S->E, E->I, I->R] rates at time step `t`."""
            # Time-indexed covariates are clamped to their available range.
            w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, W.shape[0] - 1)
            commute_volume = tf.gather(W, w_idx)

            # xi is piecewise constant, with one value per XI_FREQ steps.
            xi_idx = tf.cast(
                tf.clip_by_value(t // XI_FREQ, 0, xi.shape[0] - 1),
                dtype=tf.int64,
            )
            xi_ = tf.gather(xi, xi_idx)

            L_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, L.shape[0] - 1)
            Lt = tf.gather(L, L_idx)
            xB = tf.linalg.matvec(Lt, beta3)

            weekday_idx = tf.clip_by_value(
                tf.cast(t, tf.int64), 0, weekday.shape[0] - 1
            )
            weekday_t = tf.gather(weekday, weekday_idx)

            # S->E: the local infectives (state[..., 2]) plus beta2 times the
            # commuting-weighted infective prevalence in other units.
            infec_rate = tf.math.exp(xi_ + xB) * (
                state[..., 2]
                + beta2
                * commute_volume
                * tf.linalg.matvec(C, state[..., 2] / N)
            )
            # Per-susceptible rate plus a small positive offset (1e-9).
            infec_rate = infec_rate / N + 0.000000001  # Vector of length nc

            ei = tf.broadcast_to(
                [NU], shape=[state.shape[0]]
            )  # Vector of length nc
            ir = tf.broadcast_to(
                [tf.math.exp(gamma0 + gamma1 * weekday_t)],
                shape=[state.shape[0]],
            )  # Vector of length nc

            return [infec_rate, ei, ir]

        return DiscreteTimeStateTransitionModel(
            transition_rates=transition_rate_fn,
            stoichiometry=STOICHIOMETRY,
            initial_state=initial_state,
            initial_step=initial_step,
            time_delta=TIME_DELTA,
            num_steps=num_steps,
        )

    return tfd.JointDistributionNamed(
        dict(
            beta1=beta1,
            beta2=beta2,
            beta3=beta3,
            xi=xi,
            gamma0=gamma0,
            gamma1=gamma1,
            seir=seir,
        )
    )


def next_generation_matrix_fn(covar_data, param):
    r"""Build a function computing the next generation matrix (NGM).

    At time ``t`` and epidemic state ``state`` the returned function computes

    .. math::

        A_{ij} = \frac{\beta_{t,i} S_i}{N_i \gamma}
                 \left( \delta_{ij} + \beta_2 w_t \frac{C_{ij}}{N_j} \right)

    The symbols are as follows:

    * :math:`\beta_{t,i} = \exp(\xi_t + L_{t,i} \cdot \beta_3)`.
    * :math:`\gamma = \exp(\gamma_0)`.
    * :math:`w_t` is the commute volume at time ``t``.
    * :math:`C` is the symmetrised mobility matrix with a zero diagonal.
    * :math:`S` is ``state[..., 0]``.

    This matches the infection rate used in `CovidUK`. Entry (i, j) is the
    expected number of new infections in metapopulation i caused by one
    infective in metapopulation j over a mean infectious period
    :math:`1/\gamma`. The weekday effect ``gamma1`` is not included.

    :param covar_data: a dictionary of covariate data with keys 'C', 'W',
                       'N' and 'L'
    :param param: a dictionary of parameters with keys 'xi', 'beta2',
                  'beta3' and 'gamma0'
    :returns: a function taking arguments `t` and `state` giving the time and
              epidemic state (SEIR) for which the NGM is to be calculated.  This
              function in turn returns an MxM next generation matrix.
    """

    def fn(t, state):
        L = tf.convert_to_tensor(covar_data["L"], DTYPE)
        L = L - tf.reduce_mean(L, axis=0)

        C = tf.convert_to_tensor(covar_data["C"], dtype=DTYPE)
        # Symmetrise the mobility matrix and zero its diagonal.
        C = tf.linalg.set_diag(
            C + tf.transpose(C), tf.zeros(C.shape[0], dtype=DTYPE)
        )
        # Squeeze to 1-D exactly as CovidUK does. Otherwise a single-column
        # N would make N[tf.newaxis, :] / N[:, tf.newaxis] below broadcast
        # to the wrong shapes.
        W = tf.constant(np.squeeze(covar_data["W"]), dtype=DTYPE)
        N = tf.constant(np.squeeze(covar_data["N"]), dtype=DTYPE)

        # Time-indexed quantities are clamped to their available range.
        w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, W.shape[0] - 1)
        commute_volume = tf.gather(W, w_idx)
        xi_idx = tf.cast(
            tf.clip_by_value(t // XI_FREQ, 0, param["xi"].shape[0] - 1),
            dtype=tf.int64,
        )
        xi = tf.gather(param["xi"], xi_idx)

        L_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, L.shape[0] - 1)
        Lt = tf.gather(L, L_idx)
        xB = tf.linalg.matvec(Lt, param["beta3"])

        beta = tf.math.exp(xi + xB)

        # Row i is scaled by beta_i; column j of C is divided by N_j.
        ngm = beta[:, tf.newaxis] * (
            tf.eye(C.shape[0], dtype=state.dtype)
            + param["beta2"] * commute_volume * C / N[tf.newaxis, :]
        )
        # Row i is scaled by S_i / N_i, then divided by gamma = exp(gamma0).
        ngm = (
            ngm
            * state[..., 0][..., tf.newaxis]
            / (N[:, tf.newaxis] * tf.math.exp(param["gamma0"]))
        )
        return ngm

    return fn