model_spec.py 10.5 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
2
"""Implements the COVID SEIR model as a TFP Joint Distribution"""

3
import pandas as pd
Chris Jewell's avatar
Chris Jewell committed
4
import geopandas as gp
5
import numpy as np
6
import xarray
Chris Jewell's avatar
Chris Jewell committed
7
8
9
import tensorflow as tf
import tensorflow_probability as tfp

10
from gemlib.distributions import DiscreteTimeStateTransitionModel
Chris Jewell's avatar
Chris Jewell committed
11
12
from gemlib.distributions import BrownianMotion

Chris Jewell's avatar
Chris Jewell committed
13
from covid.util import impute_previous_cases
14
import covid.data as data
Chris Jewell's avatar
Chris Jewell committed
15
16

tfd = tfp.distributions

VERSION = "0.7.1"
DTYPE = np.float64

# Rows are the S->E, E->I and I->R transitions (in the order the transition
# rates are returned by `CovidUK`); columns are the S, E, I, R compartments.
# Each transition removes one individual from its source compartment (-1)
# and adds one to its destination compartment (+1).
STOICHIOMETRY = np.array([[-1, 1, 0, 0], [0, -1, 1, 0], [0, 0, -1, 1]])
# Length of one model time step, passed to DiscreteTimeStateTransitionModel.
TIME_DELTA = 1.0
NU = tf.constant(0.28, dtype=DTYPE)  # E->I rate assumed known.
Chris Jewell's avatar
Chris Jewell committed
24

Chris Jewell's avatar
Chris Jewell committed
25

26
def gather_data(config):
    """Load the covariate and case data consumed by the model.

    :param config: a configuration mapping.  Keys read directly here are
                   ``date_range`` (a ``(start, end)`` pair, treated as the
                   half-open interval ``[start, end)``), ``mobility_matrix``,
                   ``population_size``, ``commute_volume`` and
                   ``geopackage``; the whole mapping is also passed to
                   ``data.AreaCodeData.process`` and
                   ``data.CasesData.process``.
    :returns: a tuple ``(covariates, cases)`` of :class:`xarray.Dataset`.
              ``covariates`` holds ``C`` (mobility matrix), ``W`` (commute
              volume), ``N`` (population size), ``weekday`` (1.0 on
              Monday-Friday, otherwise 0.0), ``area`` (geometry area of each
              location) and ``locations`` (location names indexed by
              ``lad19cd`` code); numeric variables are cast to ``DTYPE``.
              ``cases`` holds a single variable ``cases``.
    """
    date_low = np.datetime64(config["date_range"][0])
    date_high = np.datetime64(config["date_range"][1])

    locations = data.AreaCodeData.process(config)
    mobility = data.read_mobility(
        config["mobility_matrix"], locations["lad19cd"]
    )
    popsize = data.read_population(
        config["population_size"], locations["lad19cd"]
    )
    commute_volume = data.read_traffic_flow(
        config["commute_volume"], date_low=date_low, date_high=date_high
    )

    geo = gp.read_file(config["geopackage"])
    geo = geo.sort_values("lad19cd")
    geo = geo[geo["lad19cd"].isin(locations["lad19cd"])]
    # NOTE(review): `.area` is expressed in the units of the layer's CRS; the
    # model divides it by 1e8 as if it were in square metres -- confirm the
    # geopackage uses a projected, metre-based CRS.
    area = xarray.DataArray(
        geo.area,
        name="area",
        dims=["location"],
        coords=[geo["lad19cd"]],
    )

    # Daily dates over the half-open interval [start, end).  Filtering on
    # `< end` replaces `closed="left"`, which pandas deprecated in 1.4 and
    # removed in 2.0 (renamed to `inclusive`), so this works on any version.
    dates = pd.date_range(*config["date_range"])
    dates = dates[dates < pd.Timestamp(config["date_range"][1])]
    weekday = xarray.DataArray(
        dates.weekday < 5,  # pandas weekday: Monday == 0 ... Sunday == 6
        name="weekday",
        dims=["time"],
        coords=[dates.to_numpy()],
    )

    cases = data.CasesData.process(config).to_xarray()

    covariates = xarray.Dataset(
        dict(
            C=mobility.astype(DTYPE),
            W=commute_volume.astype(DTYPE),
            N=popsize.astype(DTYPE),
            weekday=weekday.astype(DTYPE),
            area=area.astype(DTYPE),
            locations=xarray.DataArray(
                locations["name"],
                dims=["location"],
                coords=[locations["lad19cd"]],
            ),
        )
    )
    return covariates, xarray.Dataset(dict(cases=cases))


Chris Jewell's avatar
Chris Jewell committed
86
87
88
89
def impute_censored_events(cases, *, ei_rate=0.25, se_rate=0.5):
    """Impute censored S->E and E->I events from observed I->R events.

    Events are imputed backwards in two stages using the geometric sampling
    algorithm in `impute_previous_cases`: E->I events are imputed from
    `cases`, then S->E events from the imputed E->I events.

    The defaults of `ei_rate` and `se_rate` are application-specific magic
    numbers.  They were chosen by experimentation and examination of the
    resulting epidemic trajectories to get the right lag between EI and IR
    events, and between SE and EI events, respectively.

    :param cases: a MxT matrix of case numbers (I->R events)
    :param ei_rate: second argument passed to `impute_previous_cases` when
                    imputing E->I events from `cases`
    :param se_rate: second argument passed to `impute_previous_cases` when
                    imputing S->E events from the imputed E->I events
    :returns: a MxT'x3 tensor of events stacked along the last axis in the
              order S->E, E->I, I->R, where
              ``T' = T + lag_ei + lag_se - 2`` with the lags reported by
              `impute_previous_cases`.  The first two slices of the
              right-most dimension contain the imputed events.
    """
    ei_events, lag_ei = impute_previous_cases(cases, ei_rate)
    se_events, lag_se = impute_previous_cases(ei_events, se_rate)

    # Left-pad the time axis with zeros so all three event arrays have the
    # same shape for `tf.stack`.
    ir_events = np.pad(cases, ((0, 0), (lag_ei + lag_se - 2, 0)))
    ei_events = np.pad(ei_events, ((0, 0), (lag_se - 1, 0)))
    return tf.stack([se_events, ei_events, ir_events], axis=-1)


Chris Jewell's avatar
Chris Jewell committed
107
108
109
110
111
112
113
114
115
116
def conditional_gp(gp, observations, new_index_points):
    """Condition a Gaussian process on observations at its index points.

    Builds a `tfd.GaussianProcessRegressionModel` from the constructor
    parameters of `gp`.  The original ``index_points`` become the
    ``observation_index_points``, `observations` are the values observed
    there, and the returned model predicts at `new_index_points`.

    :param gp: a distribution exposing ``parameters`` with at least an
               ``index_points`` entry, e.g. a `tfd.GaussianProcess`.  (Inside
               this function the name shadows the module-level ``geopandas``
               alias.)
    :param observations: values observed at ``gp``'s index points.
    :param new_index_points: index points at which to predict.
    :returns: a `tfd.GaussianProcessRegressionModel`.
    """
    import inspect

    # Work on an explicit copy so nothing belonging to `gp` is mutated.
    param = dict(gp.parameters)
    param["observation_index_points"] = param["index_points"]
    param["observations"] = observations
    param["index_points"] = new_index_points

    # `GaussianProcess` takes constructor arguments (e.g. `marginal_fn` in
    # some TFP releases) that `GaussianProcessRegressionModel` does not, and
    # forwarding them raises TypeError.  Keep only accepted keywords.  If the
    # constructor signature accepts **kwargs, forward everything unchanged.
    accepted = inspect.signature(
        tfd.GaussianProcessRegressionModel.__init__
    ).parameters
    takes_var_kwargs = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in accepted.values()
    )
    if not takes_var_kwargs:
        param = {k: v for k, v in param.items() if k in accepted}

    return tfd.GaussianProcessRegressionModel(**param)


117
def CovidUK(covariates, initial_state, initial_step, num_steps):
    """Build the spatial SEIR model as a `tfd.JointDistributionNamed`.

    :param covariates: mapping with keys ``C`` (square mobility matrix),
                       ``W`` (commute volume, indexed by time step), ``N``
                       (population size per location), ``weekday`` (weekday
                       indicator, indexed by time step) and ``area``
                       (area of each location).
    :param initial_state: initial epidemic state passed to
                          `DiscreteTimeStateTransitionModel`.
    :param initial_step: first time step, passed to
                         `DiscreteTimeStateTransitionModel`.
    :param num_steps: number of time steps to simulate.
    :returns: a `tfd.JointDistributionNamed` over ``alpha_0``,
              ``beta_area``, ``psi``, ``alpha_t``, ``gamma0``, ``gamma1``
              and the event tensor ``seir``.
    """

    def alpha_0():
        # Prior on the log infection-rate intercept used at t == 0.
        return tfd.Normal(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale=tf.constant(10.0, dtype=DTYPE),
        )

    def beta_area():
        # Prior on the coefficient of centred log area.
        return tfd.Normal(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale=tf.constant(1.0, dtype=DTYPE),
        )

    def psi():
        # Prior on the weight of mobility-driven (between-location) pressure.
        return tfd.Gamma(
            concentration=tf.constant(3.0, dtype=DTYPE),
            rate=tf.constant(10.0, dtype=DTYPE),
        )

    def alpha_t():
        # Random-walk increments of the log infection rate, one per step
        # after the first: the rate at step t uses alpha_0 + sum(alpha_t[:t]).
        return tfd.MultivariateNormalDiag(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale_diag=tf.fill(
                [num_steps - 1], tf.constant(0.005, dtype=DTYPE)
            ),
        )

    def gamma0():
        # Prior on the time-averaged log I->R rate (weekday is centred).
        return tfd.Normal(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale=tf.constant(100.0, dtype=DTYPE),
        )

    def gamma1():
        # Prior on the weekday effect on the log I->R rate.
        return tfd.Normal(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale=tf.constant(100.0, dtype=DTYPE),
        )

    def seir(psi, beta_area, alpha_0, alpha_t, gamma0, gamma1):
        # Argument names must match the keys of the dict passed to
        # JointDistributionNamed below, which resolves dependencies by name.
        psi = tf.convert_to_tensor(psi, DTYPE)
        beta_area = tf.convert_to_tensor(beta_area, DTYPE)
        # Converted like the other parameters so `tf.where` below never
        # mixes a Python float with a DTYPE tensor.
        alpha_0 = tf.convert_to_tensor(alpha_0, DTYPE)
        alpha_t = tf.convert_to_tensor(alpha_t, DTYPE)
        gamma0 = tf.convert_to_tensor(gamma0, DTYPE)
        gamma1 = tf.convert_to_tensor(gamma1, DTYPE)

        # Symmetrised mobility: off-diagonal [i, j] is C[i, j] + C[j, i]
        # (self-flows zeroed first); the diagonal is minus the column sums
        # of C.
        C = tf.convert_to_tensor(covariates["C"], dtype=DTYPE)
        C = tf.linalg.set_diag(C, tf.zeros(C.shape[0], dtype=DTYPE))
        Cstar = C + tf.transpose(C)
        Cstar = tf.linalg.set_diag(Cstar, -tf.reduce_sum(C, axis=-2))

        W = tf.convert_to_tensor(tf.squeeze(covariates["W"]), dtype=DTYPE)
        N = tf.convert_to_tensor(tf.squeeze(covariates["N"]), dtype=DTYPE)

        # Centre the weekday indicator so that gamma0 is the time-averaged
        # log I->R rate.
        weekday = tf.convert_to_tensor(covariates["weekday"], DTYPE)
        weekday = weekday - tf.reduce_mean(weekday, axis=-1)

        # Centred log area.  Dividing by 1e8 expresses area in 100km^2 if it
        # is in m^2; the scaling cancels after centring.
        area = tf.convert_to_tensor(covariates["area"], DTYPE)
        log_area = tf.math.log(area / 100000000.0)
        log_area = log_area - tf.reduce_mean(log_area)

        # Random-walk log infection rate for steps 1..num_steps-1.  Computed
        # once here rather than (O(num_steps) each time) on every call to
        # `transition_rate_fn`.
        b_t = alpha_0 + tf.cumsum(alpha_t)

        def transition_rate_fn(t, state):
            t_idx = tf.cast(t, tf.int64)

            # Time-varying covariates, with t clipped to the available range.
            commute_volume = tf.gather(
                W, tf.clip_by_value(t_idx, 0, W.shape[0] - 1)
            )
            weekday_t = tf.gather(
                weekday, tf.clip_by_value(t_idx, 0, weekday.shape[0] - 1)
            )

            with tf.name_scope("Pick_alpha_t"):
                # alpha_0 at t == 0, otherwise b_t[t - 1].
                alpha_t_ = tf.where(
                    t_idx == 0,
                    alpha_0,
                    tf.gather(
                        b_t,
                        tf.clip_by_value(
                            t_idx - 1,
                            clip_value_min=0,
                            clip_value_max=alpha_t.shape[0] - 1,
                        ),
                    ),
                )

            # S->E: local infectives plus mobility-weighted infective
            # prevalence elsewhere, scaled by the log-linear rate exp(eta).
            eta = alpha_t_ + beta_area * log_area
            infec_rate = tf.math.exp(eta) * (
                state[..., 2]
                + psi
                * commute_volume
                * tf.linalg.matvec(Cstar, state[..., 2] / N)
            )
            # Per-susceptible rate; the tiny constant presumably keeps the
            # rate away from exactly zero.
            infec_rate = infec_rate / N + 0.000000001  # Vector of length nc

            # E->I: known constant rate NU.
            ei = tf.broadcast_to(
                [NU], shape=[state.shape[0]]
            )  # Vector of length nc
            # I->R: log-linear in the centred weekday indicator.
            ir = tf.broadcast_to(
                [tf.math.exp(gamma0 + gamma1 * weekday_t)],
                shape=[state.shape[0]],
            )  # Vector of length nc

            return [infec_rate, ei, ir]

        return DiscreteTimeStateTransitionModel(
            transition_rates=transition_rate_fn,
            stoichiometry=STOICHIOMETRY,
            initial_state=initial_state,
            initial_step=initial_step,
            time_delta=TIME_DELTA,
            num_steps=num_steps,
        )

    return tfd.JointDistributionNamed(
        dict(
            alpha_0=alpha_0,
            beta_area=beta_area,
            psi=psi,
            alpha_t=alpha_t,
            gamma0=gamma0,
            gamma1=gamma1,
            seir=seir,
        )
    )
249
250


Chris Jewell's avatar
Chris Jewell committed
251
def next_generation_matrix_fn(covar_data, param):
    r"""Build a function computing the model's next generation matrix (NGM).

    Entry ``[i, j]`` of the NGM is the expected number of new infections in
    metapopulation ``i`` attributable to one infective in metapopulation
    ``j``, given the epidemic state at time ``t``.  It uses the same rate
    model as `CovidUK`:

    .. math::

        \lambda_{ij} = \frac{e^{\eta_i}}{N_i} \left( \delta_{ij}
            + \psi \, w_t \, \frac{C^{*}_{ij}}{N_j} \right),
        \qquad
        \eta_i = a_t + \beta_{\mathrm{area}} \, \ell_i,

        A_{ij} = \frac{S_i \left( 1 - e^{-\lambda_{ij}} \right)}
                      {1 - e^{-e^{\gamma_0}}}.

    Here :math:`a_t` is ``alpha_0`` at ``t == 0`` and
    ``alpha_0 + sum(alpha_t[:t])`` afterwards.  :math:`w_t` is the commute
    volume at step ``t`` and :math:`C^{*}` is the symmetrised mobility
    matrix.  :math:`\ell_i` is the centred log area, and :math:`S_i` is
    ``state[i, 0]``.  The reciprocal of the denominator is used as the
    expected infectious period.

    :param covar_data: a dictionary of covariate data with keys ``C``,
                       ``W``, ``N`` and ``area`` (as used by `CovidUK`)
    :param param: a dictionary of parameters with keys ``alpha_0``,
                  ``alpha_t``, ``beta_area``, ``psi`` and ``gamma0``
    :returns: a function taking arguments `t` and `state` giving the time and
              epidemic state (SEIR) for which the NGM is to be calculated.  This
              function in turn returns an MxM next generation matrix.
    """

    def fn(t, state):
        # Symmetrised mobility matrix, built exactly as in `CovidUK`.
        C = tf.convert_to_tensor(covar_data["C"], dtype=DTYPE)
        C = tf.linalg.set_diag(C, tf.zeros(C.shape[0], dtype=DTYPE))
        Cstar = C + tf.transpose(C)
        Cstar = tf.linalg.set_diag(Cstar, -tf.reduce_sum(C, axis=-2))

        W = tf.constant(covar_data["W"], dtype=DTYPE)
        # Flatten N (CovidUK squeezes it) so the N[tf.newaxis, :] and
        # N[:, tf.newaxis] broadcasts below yield MxM, not higher rank.
        N = tf.reshape(tf.constant(covar_data["N"], dtype=DTYPE), [-1])

        # Centred log area; the 1e8 (m^2 -> 100km^2) scaling cancels after
        # centring.
        area = tf.convert_to_tensor(covar_data["area"], DTYPE)
        log_area = tf.math.log(area / 100000000.0)  # log area in 100km^2
        log_area = log_area - tf.reduce_mean(log_area)

        t_idx = tf.cast(t, tf.int64)
        w_idx = tf.clip_by_value(t_idx, 0, W.shape[0] - 1)
        commute_volume = tf.gather(W, w_idx)

        # Same indexing as `CovidUK`: alpha_0 at t == 0, otherwise
        # b_t[t - 1] = alpha_0 + sum(alpha_t[:t]).  Gathering b_t[t] (the
        # previous behaviour) ran one random-walk increment ahead of the
        # model's transition rate.
        alpha_0 = tf.convert_to_tensor(param["alpha_0"], DTYPE)
        alpha_t = tf.convert_to_tensor(param["alpha_t"], DTYPE)
        b_t = alpha_0 + tf.cumsum(alpha_t)
        alpha_t_ = tf.where(
            t_idx == 0,
            alpha_0,
            tf.gather(
                b_t,
                tf.clip_by_value(
                    t_idx - 1,
                    clip_value_min=0,
                    clip_value_max=alpha_t.shape[-1] - 1,
                ),
            ),
        )

        # lambda[i, j]: hazard on a susceptible in i from one infective in j.
        eta = alpha_t_ + param["beta_area"] * log_area[:, tf.newaxis]
        infec_rate = (
            tf.math.exp(eta)
            * (
                tf.eye(Cstar.shape[0], dtype=DTYPE)
                + param["psi"] * commute_volume * Cstar / N[tf.newaxis, :]
            )
            / N[:, tf.newaxis]
        )
        infec_prob = 1.0 - tf.math.exp(-infec_rate)

        # Scale row i by the susceptibles in i (column 0 of `state`).
        susceptible = tf.cast(state[..., 0], DTYPE)
        expected_new_infec = infec_prob * susceptible[..., tf.newaxis]

        expected_infec_period = 1.0 / (
            1.0 - tf.math.exp(-tf.math.exp(param["gamma0"]))
        )
        ngm = expected_new_infec * expected_infec_period
        return ngm

    return fn