"""Implements the COVID SEIR model as a TFP Joint Distribution"""

import pandas as pd
import geopandas as gp
import numpy as np
import xarray
import tensorflow as tf
import tensorflow_probability as tfp

from gemlib.distributions import DiscreteTimeStateTransitionModel
from gemlib.distributions import BrownianMotion

from covid.util import impute_previous_cases
import covid.data as data

tfd = tfp.distributions

VERSION = "0.7.0"
DTYPE = np.float64

# One row per transition (S->E, E->I, I->R); one column per compartment
# (S, E, I, R).  Each row removes an individual from one compartment and
# adds it to the next.
STOICHIOMETRY = np.array([[-1, 1, 0, 0], [0, -1, 1, 0], [0, 0, -1, 1]])
# Length of one model time step (passed as `time_delta` to the state
# transition model; presumably one day, matching the daily date range).
TIME_DELTA = 1.0
NU = tf.constant(0.28, dtype=DTYPE)  # E->I rate assumed known.

def gather_data(config):
    """Loads covariate and case data for the model.

    :param config: a configuration mapping.  Keys read directly here are
                   ``date_range`` (a ``[start, end]`` pair of dates),
                   ``mobility_matrix``, ``population_size``,
                   ``commute_volume`` and ``geopackage``; the whole mapping
                   is also passed to ``data.AreaCodeData.process`` and
                   ``data.CasesData.process``.
    :returns: a tuple ``(covariates, cases)`` of :class:`xarray.Dataset`.
              ``covariates`` holds ``C`` (mobility matrix), ``W`` (commute
              volume), ``N`` (population size), ``weekday`` (1.0 on
              Monday-Friday, 0.0 otherwise, over the half-open date range
              ``[start, end)``), ``area`` (geometry area of each location,
              in the units of the geopackage's CRS) and ``locations``
              (location names indexed by ``lad19cd`` code); ``cases`` holds
              the processed case data.  Numeric covariates are cast to
              ``DTYPE``.
    """
    date_low = np.datetime64(config["date_range"][0])
    date_high = np.datetime64(config["date_range"][1])

    locations = data.AreaCodeData.process(config)
    mobility = data.read_mobility(
        config["mobility_matrix"], locations["lad19cd"]
    )
    popsize = data.read_population(
        config["population_size"], locations["lad19cd"]
    )
    commute_volume = data.read_traffic_flow(
        config["commute_volume"], date_low=date_low, date_high=date_high
    )

    geo = gp.read_file(config["geopackage"])
    geo = geo.sort_values("lad19cd")
    area = xarray.DataArray(
        geo.area,
        name="area",
        dims=["location"],
        coords=[geo["lad19cd"]],
    )

    # Half-open date range [start, end).  `closed=` was deprecated in
    # pandas 1.4 in favour of `inclusive=` and removed in pandas 2.0;
    # fall back to `closed=` only on pandas versions that predate
    # `inclusive=` (they reject it with a TypeError).
    try:
        dates = pd.date_range(*config["date_range"], inclusive="left")
    except TypeError:
        dates = pd.date_range(*config["date_range"], closed="left")
    weekday = xarray.DataArray(
        dates.weekday < 5,  # pandas weekday: Monday=0 ... Sunday=6
        name="weekday",
        dims=["time"],
        coords=[dates.to_numpy()],
    )

    cases = data.CasesData.process(config).to_xarray()

    return (
        xarray.Dataset(
            dict(
                C=mobility.astype(DTYPE),
                W=commute_volume.astype(DTYPE),
                N=popsize.astype(DTYPE),
                weekday=weekday.astype(DTYPE),
                area=area.astype(DTYPE),
                locations=xarray.DataArray(
                    locations["name"],
                    dims=["location"],
                    coords=[locations["lad19cd"]],
                ),
            )
        ),
        xarray.Dataset(dict(cases=cases)),
    )


def impute_censored_events(cases):
    """Imputes the censored S->E and E->I events behind observed cases.

    Imputation runs backwards in two stages using `impute_previous_cases`
    (a geometric sampling algorithm): E->I events are imputed from the
    observed I->R cases, then S->E events from those imputed E->I events.

    The second arguments passed to `impute_previous_cases` (0.25 and 0.5)
    are application-specific magic numbers.  They were chosen by
    experimentation and inspection of the resulting epidemic trajectories
    to give a sensible lag between EI and IR events, and between SE and EI
    events, respectively.

    :param cases: a MxT matrix of case numbers (I->R)
    :returns: a tensor stacking [S->E, E->I, I->R] events along its last
              axis (size 3); the first two are imputed, the third is the
              input cases.  The time axis is extended backwards by the
              imputation lags, so it is generally longer than T; the E->I
              and I->R series are left-padded with zeros to line up with
              the imputed S->E series.
    """
    # Stage 1: E->I events (and their lag) from observed I->R cases.
    ei_imputed, ei_lag = impute_previous_cases(cases, 0.25)
    # Stage 2: S->E events (and their lag) from the imputed E->I events.
    se_imputed, se_lag = impute_previous_cases(ei_imputed, 0.5)

    # Zero-pad only the start of the time axis so every series matches
    # the length of `se_imputed`.
    ir_pad_width = ei_lag + se_lag - 2
    ei_pad_width = se_lag - 1
    ir_padded = np.pad(cases, [(0, 0), (ir_pad_width, 0)])
    ei_padded = np.pad(ei_imputed, [(0, 0), (ei_pad_width, 0)])

    events = [se_imputed, ei_padded, ir_padded]
    return tf.stack(events, axis=-1)


def conditional_gp(gp, observations, new_index_points):
    """Conditions a Gaussian process on observations at its index points.

    Builds a ``tfd.GaussianProcessRegressionModel`` from the construction
    parameters of ``gp`` (kernel, mean function, noise variance, ...).
    ``gp``'s index points become the observation locations of
    ``observations``, and the result predicts at ``new_index_points``.

    :param gp: a ``tfd.GaussianProcess`` (or a distribution exposing
               compatible ``parameters`` including ``index_points``).
    :param observations: observed values at ``gp``'s index points.
    :param new_index_points: index points at which to predict.
    :returns: a ``tfd.GaussianProcessRegressionModel``.
    :raises KeyError: if ``gp.parameters`` has no ``index_points`` entry.
    """
    import inspect

    # NOTE: the parameter name `gp` shadows the module-level geopandas
    # alias within this function; it is kept for caller compatibility.

    # Work on a copy so nothing owned by `gp` is ever mutated.
    param = dict(gp.parameters)
    param["observation_index_points"] = param.pop("index_points")
    param["observations"] = observations
    param["index_points"] = new_index_points

    # Depending on the TFP release, a GaussianProcess may record
    # constructor arguments (e.g. `marginal_fn` in some releases) that
    # GaussianProcessRegressionModel does not accept, which would raise a
    # TypeError.  Keep only the arguments the regression model accepts,
    # unless its constructor takes arbitrary keyword arguments.
    accepted = inspect.signature(tfd.GaussianProcessRegressionModel).parameters
    takes_var_kwargs = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in accepted.values()
    )
    if not takes_var_kwargs:
        param = {k: v for k, v in param.items() if k in accepted}

    return tfd.GaussianProcessRegressionModel(**param)


def CovidUK(covariates, initial_state, initial_step, num_steps):
    """Builds the COVID SEIR model as a TFP joint distribution.

    :param covariates: a mapping providing ``C`` (mobility matrix), ``W``
                       (commute volume per time step), ``N`` (population
                       size), ``weekday`` (per time step) and ``area``
                       (per location), such as the covariate Dataset
                       produced by ``gather_data``.
    :param initial_state: initial compartment state passed to
                          ``DiscreteTimeStateTransitionModel``.
    :param initial_step: time index of the first simulated step (also
                         passed to ``DiscreteTimeStateTransitionModel``).
    :param num_steps: number of time steps to simulate.
    :returns: a ``tfd.JointDistributionNamed`` over ``alpha_0``,
              ``beta_area``, ``psi``, ``alpha_t``, ``gamma0``, ``gamma1``
              and the epidemic events ``seir``.
    """

    def alpha_0():
        return tfd.Normal(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale=tf.constant(10.0, dtype=DTYPE),
        )

    def beta_area():
        return tfd.Normal(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale=tf.constant(1.0, dtype=DTYPE),
        )

    def psi():
        return tfd.Gamma(
            concentration=tf.constant(3.0, dtype=DTYPE),
            rate=tf.constant(10.0, dtype=DTYPE),
        )

    def alpha_t():
        # Increments of a Gaussian random walk on the log-scale infection
        # rate: the path used by `seir` is alpha_0 + cumsum(alpha_t).
        return tfd.MultivariateNormalDiag(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale_diag=tf.fill(
                [num_steps - 1], tf.constant(0.005, dtype=DTYPE)
            ),
        )

    def gamma0():
        return tfd.Normal(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale=tf.constant(100.0, dtype=DTYPE),
        )

    def gamma1():
        return tfd.Normal(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale=tf.constant(100.0, dtype=DTYPE),
        )

    # NOTE: argument names must match the keys of the joint distribution
    # dict below; JointDistributionNamed wires dependencies by name.
    def seir(psi, beta_area, alpha_0, alpha_t, gamma0, gamma1):
        psi = tf.convert_to_tensor(psi, DTYPE)
        beta_area = tf.convert_to_tensor(beta_area, DTYPE)
        alpha_t = tf.convert_to_tensor(alpha_t, DTYPE)
        gamma0 = tf.convert_to_tensor(gamma0, DTYPE)
        gamma1 = tf.convert_to_tensor(gamma1, DTYPE)

        # Mobility operator: zero C's diagonal, symmetrise, then place
        # minus the column sums of C on the diagonal.
        C = tf.convert_to_tensor(covariates["C"], dtype=DTYPE)
        C = tf.linalg.set_diag(C, tf.zeros(C.shape[0], dtype=DTYPE))

        Cstar = C + tf.transpose(C)
        Cstar = tf.linalg.set_diag(Cstar, -tf.reduce_sum(C, axis=-2))

        W = tf.convert_to_tensor(tf.squeeze(covariates["W"]), dtype=DTYPE)
        N = tf.convert_to_tensor(tf.squeeze(covariates["N"]), dtype=DTYPE)

        # Centred weekday indicator.
        weekday = tf.convert_to_tensor(covariates["weekday"], DTYPE)
        weekday = weekday - tf.reduce_mean(weekday, axis=-1)

        # Centred log area in units of 100km^2 (assumes `area` is in m^2,
        # since 100km^2 == 1e8 m^2).
        area = tf.convert_to_tensor(covariates["area"], DTYPE)
        log_area = tf.math.log(area / 100000000.0)
        log_area = log_area - tf.reduce_mean(log_area)

        # Random-walk path of the log-scale infection rate; b_t[k] applies
        # at time initial_step + k + 1.  Computed once here rather than on
        # every call of `transition_rate_fn`, which only reads it.
        b_t = alpha_0 + tf.cumsum(alpha_t)

        def transition_rate_fn(t, state):
            t_idx = tf.cast(t, tf.int64)

            # Time-indexed covariates, clamped to their valid range.
            commute_volume = tf.gather(
                W, tf.clip_by_value(t_idx, 0, W.shape[0] - 1)
            )
            weekday_t = tf.gather(
                weekday, tf.clip_by_value(t_idx, 0, weekday.shape[0] - 1)
            )

            with tf.name_scope("Pick_alpha_t"):
                # alpha_0 on the first step, the random-walk path after.
                alpha_t_ = tf.where(
                    t_idx == initial_step,
                    alpha_0,
                    tf.gather(
                        b_t,
                        tf.clip_by_value(
                            t_idx - initial_step - 1,
                            clip_value_min=0,
                            clip_value_max=alpha_t.shape[0] - 1,
                        ),
                    ),
                )

            eta = alpha_t_ + beta_area * log_area
            # state[..., 2] is the I compartment (see STOICHIOMETRY).
            infec_rate = tf.math.exp(eta) * (
                state[..., 2]
                + psi
                * commute_volume
                * tf.linalg.matvec(Cstar, state[..., 2] / N)
            )
            # Small constant keeps the rate strictly positive.
            infec_rate = infec_rate / N + 0.000000001  # Vector of length nc

            ei = tf.broadcast_to(
                [NU], shape=[state.shape[0]]
            )  # Vector of length nc
            ir = tf.broadcast_to(
                [tf.math.exp(gamma0 + gamma1 * weekday_t)],
                shape=[state.shape[0]],
            )  # Vector of length nc

            return [infec_rate, ei, ir]

        return DiscreteTimeStateTransitionModel(
            transition_rates=transition_rate_fn,
            stoichiometry=STOICHIOMETRY,
            initial_state=initial_state,
            initial_step=initial_step,
            time_delta=TIME_DELTA,
            num_steps=num_steps,
        )

    return tfd.JointDistributionNamed(
        dict(
            alpha_0=alpha_0,
            beta_area=beta_area,
            psi=psi,
            alpha_t=alpha_t,
            gamma0=gamma0,
            gamma1=gamma1,
            seir=seir,
        )
    )


def next_generation_matrix_fn(covar_data, param):
    r"""Builds a function computing the next generation matrix (NGM).

    For time ``t`` and epidemic state ``state`` the returned function
    computes

      \[ A_{ij} = \exp(\xi_t) \left( \delta_{ij}
                  + \psi \, w_t \, C^*_{ij} / N_j \right) S_i / (N_i \gamma) \]

    where :math:`1/\gamma`, with :math:`\gamma = \exp(\gamma_0)`, is the
    typical infectious period. :math:`\xi_t` is ``param["alpha_0"]`` when
    ``t == 0`` and ``param["alpha_t"][t]`` otherwise (index clipped to the
    last element). :math:`w_t` is ``covar_data["W"][t]`` (index clipped),
    and :math:`S_i` is ``state[..., 0]``. :math:`C^*` is ``C`` with its
    diagonal zeroed, symmetrised as :math:`C + C^T`, and with minus the
    column sums of ``C`` placed on its diagonal. Row ``i`` therefore
    describes the force of infection on metapopulation ``i`` from
    infectives in each metapopulation ``j``, over one infectious period.

    :param covar_data: a dictionary of covariate data with keys ``C``,
                       ``W`` and ``N``
    :param param: a dictionary of parameters with keys ``alpha_0``,
                  ``alpha_t``, ``psi`` and ``gamma0``
    :returns: a function taking arguments `t` and `state` giving the time and
              epidemic state (SEIR) for which the NGM is to be calculated.  This
              function in turn returns an MxM next generation matrix.
    """

    def fn(t, state):
        # Mobility operator, built exactly as in CovidUK.
        C = tf.convert_to_tensor(covar_data["C"], dtype=DTYPE)
        C = tf.linalg.set_diag(C, tf.zeros(C.shape[0], dtype=DTYPE))
        Cstar = C + tf.transpose(C)
        Cstar = tf.linalg.set_diag(Cstar, -tf.reduce_sum(C, axis=-2))

        W = tf.convert_to_tensor(covar_data["W"], dtype=DTYPE)
        N = tf.convert_to_tensor(covar_data["N"], dtype=DTYPE)

        w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, W.shape[0] - 1)
        commute_volume = tf.gather(W, w_idx)

        # NOTE(review): CovidUK's transmission rate uses the random-walk
        # path alpha_0 + cumsum(alpha_t) (indexed t - initial_step - 1)
        # plus beta_area * log_area, whereas this gathers alpha_t[t]
        # directly and has no area term.  Confirm whether callers pass the
        # cumulative path in param["alpha_t"]; otherwise this is stale.
        xi = tf.where(
            t == 0,
            param["alpha_0"],
            tf.gather(
                param["alpha_t"],
                tf.clip_by_value(
                    t,
                    clip_value_min=0,
                    clip_value_max=param["alpha_t"].shape[-1] - 1,
                ),
            ),
        )

        beta = tf.math.exp(xi)

        # beta * (delta_ij + psi * w_t * C*_ij / N_j)
        ngm = beta * (
            tf.eye(Cstar.shape[0], dtype=state.dtype)
            + param["psi"] * commute_volume * Cstar / N[tf.newaxis, :]
        )

        # Scale row i by S_i / N_i and by the infectious period 1/gamma.
        ngm = (
            ngm
            * state[..., 0][..., tf.newaxis]
            / (N[:, tf.newaxis] * tf.math.exp(param["gamma0"]))
        )
        return ngm

    return fn