model_spec.py 9.72 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
2
"""Implements the COVID SEIR model as a TFP Joint Distribution"""

3
import pandas as pd
Chris Jewell's avatar
Chris Jewell committed
4
import geopandas as gp
5
import numpy as np
6
import xarray
Chris Jewell's avatar
Chris Jewell committed
7
8
9
import tensorflow as tf
import tensorflow_probability as tfp

10
from gemlib.distributions import DiscreteTimeStateTransitionModel
Chris Jewell's avatar
Chris Jewell committed
11
12
from gemlib.distributions import BrownianMotion

Chris Jewell's avatar
Chris Jewell committed
13
from covid.util import impute_previous_cases
14
import covid.data as data
Chris Jewell's avatar
Chris Jewell committed
15
16

tfd = tfp.distributions

VERSION = "0.7.0"

# Floating-point dtype used for the model's tensors and covariates.
DTYPE = np.float64

# State-change matrix: one row per transition, one column per compartment.
# Columns are the S, E, I, R compartments; rows are the S->E, E->I and I->R
# transitions, in that order.
STOICHIOMETRY = np.array(
    [
        [-1, 1, 0, 0],  # S->E
        [0, -1, 1, 0],  # E->I
        [0, 0, -1, 1],  # I->R
    ]
)

# Time step length passed to the state transition model.
TIME_DELTA = 1.0

# E->I transition rate; held fixed (assumed known) rather than inferred.
NU = tf.constant(0.28, dtype=DTYPE)
Chris Jewell's avatar
Chris Jewell committed
24

Chris Jewell's avatar
Chris Jewell committed
25

26
def gather_data(config):
    """Load the covariate and case data consumed by the model.

    :param config: a dictionary-like configuration.  This function reads the
                   keys ``'date_range'`` (a ``(start, end)`` pair of dates),
                   ``'mobility_matrix'``, ``'population_size'``,
                   ``'commute_volume'`` and ``'geopackage'``, and also passes
                   the whole of ``config`` to ``data.AreaCodeData.process``
                   and ``data.CasesData.process``.
    :returns: a tuple ``(covariates, cases)`` of :class:`xarray.Dataset`.
              ``covariates`` holds ``C`` (mobility matrix), ``W`` (commute
              volume / traffic flow), ``N`` (population size), ``weekday``
              (1.0 for Monday-Friday, 0.0 otherwise, on a daily index from
              ``start`` up to but excluding ``end``), ``area`` (geometry
              area per ``lad19cd`` code, in the geopackage CRS's units) and
              ``locations`` (location names indexed by ``lad19cd``); all
              numeric variables are cast to ``DTYPE``.  ``cases`` holds a
              single variable ``cases``.
    """
    date_low = np.datetime64(config["date_range"][0])
    date_high = np.datetime64(config["date_range"][1])

    locations = data.AreaCodeData.process(config)
    mobility = data.read_mobility(
        config["mobility_matrix"], locations["lad19cd"]
    )
    popsize = data.read_population(
        config["population_size"], locations["lad19cd"]
    )
    commute_volume = data.read_traffic_flow(
        config["commute_volume"], date_low=date_low, date_high=date_high
    )

    # Sort by LAD code so the area vector is ordered by location code.
    geo = gp.read_file(config["geopackage"])
    geo = geo.sort_values("lad19cd")
    area = xarray.DataArray(
        geo.area,
        name="area",
        dims=["location"],
        coords=[geo["lad19cd"]],
    )

    # tier_restriction = data.TierData.process(config)[:, :, [0, 2, 3, 4]]

    # Daily index from start up to (but excluding) end.  pandas deprecated
    # `closed=` in 1.4 and removed it in 2.0 in favour of `inclusive=`, so
    # try the new spelling first and fall back for older pandas releases.
    try:
        dates = pd.date_range(*config["date_range"], inclusive="left")
    except TypeError:
        dates = pd.date_range(*config["date_range"], closed="left")
    weekday = xarray.DataArray(
        dates.weekday < 5,  # pandas weekday: Monday=0 ... Sunday=6
        name="weekday",
        dims=["time"],
        coords=[dates.to_numpy()],
    )

    cases = data.CasesData.process(config).to_xarray()

    return (
        xarray.Dataset(
            dict(
                C=mobility.astype(DTYPE),
                W=commute_volume.astype(DTYPE),
                N=popsize.astype(DTYPE),
                weekday=weekday.astype(DTYPE),
                area=area.astype(DTYPE),
                locations=xarray.DataArray(
                    locations["name"],
                    dims=["location"],
                    coords=[locations["lad19cd"]],
                ),
            )
        ),
        xarray.Dataset(dict(cases=cases)),
    )


Chris Jewell's avatar
Chris Jewell committed
85
86
87
88
def impute_censored_events(cases):
    """Impute censored S->E and E->I events from observed I->R events.

    `impute_previous_cases` (a geometric sampling algorithm) is applied
    twice, stepping back one transition at a time:

    1. E->I events are imputed from the I->R case counts;
    2. S->E events are imputed from those imputed E->I events.

    The second arguments (0.25 for the EI/IR lag, 0.5 for the SE/EI lag) are
    application-specific magic numbers.  They were chosen by experimentation
    and examination of the resulting epidemic trajectories.

    :param cases: a MxT matrix of case numbers (I->R)
    :returns: a tensor of events whose right-most dimension has size 3,
              ordered (S->E, E->I, I->R).  The first two entries are the
              imputed events; the third is ``cases`` itself.  The earlier
              time steps are zero-padded, so the time dimension is extended
              by the imputation lags.
    """
    # I->R  ==>  E->I  ==>  S->E
    imputed_ei, ei_lag = impute_previous_cases(cases, 0.25)
    imputed_se, se_lag = impute_previous_cases(imputed_ei, 0.5)

    # tf.stack needs equal shapes.  Prepend zeros along the time axis of the
    # later series so they line up with the S->E series.
    ir_padding = ((0, 0), (ei_lag + se_lag - 2, 0))
    ei_padding = ((0, 0), (se_lag - 1, 0))
    aligned_ir = np.pad(cases, ir_padding)
    aligned_ei = np.pad(imputed_ei, ei_padding)

    return tf.stack([imputed_se, aligned_ei, aligned_ir], axis=-1)


Chris Jewell's avatar
Chris Jewell committed
106
107
108
109
110
111
112
113
114
115
def conditional_gp(gp, observations, new_index_points):
    """Condition a Gaussian process on observations at its own index points.

    A ``tfd.GaussianProcessRegressionModel`` is built from the constructor
    parameters of ``gp``:

    * the original ``index_points`` become the ``observation_index_points``;
    * ``observations`` are attached to them;
    * predictions are made at ``new_index_points``.

    :param gp: a TFP Gaussian process distribution whose ``index_points`` are
               the locations at which ``observations`` were made.
    :param observations: observed function values at ``gp.index_points``.
    :param new_index_points: index points at which to predict.
    :returns: a ``tfd.GaussianProcessRegressionModel``.
    """
    import inspect

    # Work on a copy so the source distribution's parameter record is never
    # mutated, whatever `parameters` returns in the installed TFP version.
    param = dict(gp.parameters)
    param["observation_index_points"] = param["index_points"]
    param["observations"] = observations
    param["index_points"] = new_index_points

    # A GaussianProcess may record constructor arguments that the regression
    # model does not accept (e.g. `marginal_fn`, depending on the TFP
    # release).  Passing them through would raise TypeError, so keep only the
    # keywords the regression model's constructor declares.
    try:
        accepted = inspect.signature(
            tfd.GaussianProcessRegressionModel.__init__
        ).parameters
    except (TypeError, ValueError):
        accepted = None  # signature unavailable: pass everything through
    if accepted is not None and not any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in accepted.values()
    ):
        param = {k: v for k, v in param.items() if k in accepted}

    return tfd.GaussianProcessRegressionModel(**param)


116
def CovidUK(covariates, initial_state, initial_step, num_steps):
    """Build the COVID SEIR model as a TFP joint distribution.

    Priors:

    * ``alpha_0 ~ Normal(0, 10)``: log infection rate at ``initial_step``
    * ``beta_area ~ Normal(0, 1)``: coefficient on centred log area
    * ``psi ~ Gamma(3, 10)``: weight of commuting-mediated infection
    * ``alpha_t ~ BrownianMotion(x0=alpha_0, scale=0.005)`` over
      ``num_steps`` index points
    * ``gamma0, gamma1 ~ Normal(0, 100)``: log I->R rate intercept and
      weekday effect

    Transition rates for location ``i`` at time ``t`` (``S, E, I, R`` are
    columns 0..3 of the state):

    * S->E: ``exp(alpha + beta_area * log_area_i)
      * (I_i + psi * W_t * sum_j Cstar_ij * I_j / N_j) / N_i + 1e-9``
    * E->I: ``NU`` (fixed)
    * I->R: ``exp(gamma0 + gamma1 * weekday_t)``

    :param covariates: mapping providing ``'C'`` (mobility matrix),
        ``'W'`` (commute volume, indexed by time step), ``'N'`` (population
        size per location), ``'weekday'`` (weekday indicator, indexed by
        time step) and ``'area'`` (area per location; the ``1e8`` scaling
        below assumes square metres -- TODO confirm).  The first Dataset
        returned by :func:`gather_data` has these variables.
    :param initial_state: initial state passed to
        ``DiscreteTimeStateTransitionModel``.
    :param initial_step: time index of the first simulated step; ``alpha_0``
        is the log infection rate used at this step.
    :param num_steps: number of time steps to simulate.
    :returns: a ``tfd.JointDistributionNamed`` over ``alpha_0``,
        ``beta_area``, ``psi``, ``alpha_t``, ``gamma0``, ``gamma1`` and
        ``seir`` (a ``DiscreteTimeStateTransitionModel``).
    """

    def alpha_0():
        return tfd.Normal(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale=tf.constant(10.0, dtype=DTYPE),
        )

    def beta_area():
        return tfd.Normal(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale=tf.constant(1.0, dtype=DTYPE),
        )

    def psi():
        return tfd.Gamma(
            concentration=tf.constant(3.0, dtype=DTYPE),
            rate=tf.constant(10.0, dtype=DTYPE),
        )

    def alpha_t(alpha_0):
        # Brownian-motion trajectory of the log infection rate, started
        # from alpha_0.
        return BrownianMotion(
            tf.range(num_steps, dtype=DTYPE), x0=alpha_0, scale=0.005
        )

    def gamma0():
        return tfd.Normal(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale=tf.constant(100.0, dtype=DTYPE),
        )

    def gamma1():
        return tfd.Normal(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale=tf.constant(100.0, dtype=DTYPE),
        )

    def seir(psi, beta_area, alpha_0, alpha_t, gamma0, gamma1):
        psi = tf.convert_to_tensor(psi, DTYPE)
        beta_area = tf.convert_to_tensor(beta_area, DTYPE)
        alpha_t = tf.convert_to_tensor(alpha_t, DTYPE)
        gamma0 = tf.convert_to_tensor(gamma0, DTYPE)
        gamma1 = tf.convert_to_tensor(gamma1, DTYPE)

        # Mobility: drop within-location flow and symmetrise.  The diagonal
        # of Cstar is minus the column sums of C.
        C = tf.convert_to_tensor(covariates["C"], dtype=DTYPE)
        C = tf.linalg.set_diag(C, tf.zeros(C.shape[0], dtype=DTYPE))

        Cstar = C + tf.transpose(C)
        Cstar = tf.linalg.set_diag(Cstar, -tf.reduce_sum(C, axis=-2))

        W = tf.convert_to_tensor(tf.squeeze(covariates["W"]), dtype=DTYPE)
        N = tf.convert_to_tensor(tf.squeeze(covariates["N"]), dtype=DTYPE)

        # Centre the weekday indicator so that gamma0 is the log I->R rate
        # averaged over the time steps.
        weekday = tf.convert_to_tensor(covariates["weekday"], DTYPE)
        weekday = weekday - tf.reduce_mean(weekday, axis=-1)

        # Area in 100km^2 (1e8 m^2), then centred on the log scale.
        area = tf.convert_to_tensor(covariates["area"], DTYPE)
        log_area = tf.math.log(area / 100000000.0)  # log area in 100km^2
        log_area = log_area - tf.reduce_mean(log_area)

        def transition_rate_fn(t, state):
            # Time-varying covariates.  The index is clamped, so steps
            # outside the data reuse the first/last value.
            w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, W.shape[0] - 1)
            commute_volume = tf.gather(W, w_idx)

            weekday_idx = tf.clip_by_value(
                tf.cast(t, tf.int64), 0, weekday.shape[0] - 1
            )
            weekday_t = tf.gather(weekday, weekday_idx)

            # alpha_0 applies at initial_step.  Step initial_step + k + 1
            # uses alpha_t[k], clipped to the length of alpha_t.
            with tf.name_scope("Pick_alpha_t"):
                alpha_t_idx = tf.cast(t, tf.int64)
                alpha_t_ = tf.where(
                    alpha_t_idx == initial_step,
                    alpha_0,
                    tf.gather(
                        alpha_t,
                        tf.clip_by_value(
                            alpha_t_idx - initial_step - 1,
                            clip_value_min=0,
                            clip_value_max=alpha_t.shape[0] - 1,
                        ),
                    ),
                )

            # S->E: local infection plus commuting-mediated infection from
            # other locations, per head of population.  N is already 1-D
            # (squeezed above), so it is used directly.
            eta = alpha_t_ + beta_area * log_area
            infec_rate = tf.math.exp(eta) * (
                state[..., 2]
                + psi
                * commute_volume
                * tf.linalg.matvec(Cstar, state[..., 2] / N)
            )
            # Small positive offset, presumably to avoid exactly-zero rates.
            infec_rate = infec_rate / N + 0.000000001  # one rate per row

            # E->I and I->R rates are common to all rows of the state.
            ei = tf.broadcast_to([NU], shape=[state.shape[0]])
            ir = tf.broadcast_to(
                [tf.math.exp(gamma0 + gamma1 * weekday_t)],
                shape=[state.shape[0]],
            )

            return [infec_rate, ei, ir]

        return DiscreteTimeStateTransitionModel(
            transition_rates=transition_rate_fn,
            stoichiometry=STOICHIOMETRY,
            initial_state=initial_state,
            initial_step=initial_step,
            time_delta=TIME_DELTA,
            num_steps=num_steps,
        )

    return tfd.JointDistributionNamed(
        dict(
            alpha_0=alpha_0,
            beta_area=beta_area,
            psi=psi,
            alpha_t=alpha_t,
            gamma0=gamma0,
            gamma1=gamma1,
            seir=seir,
        )
    )
242
243


Chris Jewell's avatar
Chris Jewell committed
244
def next_generation_matrix_fn(covar_data, param):
    r"""Build a function computing the next generation matrix (NGM).

    For the epidemic state at time ``t``, element ``(i, j)`` of the NGM is
    the expected number of new infections in metapopulation ``i`` caused by
    one infective in metapopulation ``j``.  It is taken over a typical
    infectious period ``1 / exp(gamma0)``, using the transmission structure
    of ``CovidUK``:

      \[ A_{ij} = \frac{S_i}{N_i} \, \frac{e^{\xi_t}}{e^{\gamma_0}}
                  \left( \delta_{ij} + \psi \, w_t \, C^*_{ij} / N_j \right) \]

    Here:

    * ``xi_t`` is ``alpha_0`` when ``t == 0`` and ``alpha_t[t]`` otherwise;
    * ``w_t`` is the commute volume at ``t``;
    * ``C^*`` is the mobility matrix with zeroed diagonal plus its transpose,
      with diagonal ``-sum_i C_ij``.

    :param covar_data: a dictionary of covariate data with keys ``'C'``,
                       ``'W'`` and ``'N'``
    :param param: a dictionary of parameters with keys ``'alpha_0'``,
                  ``'alpha_t'``, ``'psi'`` and ``'gamma0'``
    :returns: a function taking arguments `t` and `state` giving the time and
              epidemic state (SEIR, compartments on the last axis; column 0
              is read as S) for which the NGM is to be calculated.  This
              function in turn returns an MxM next generation matrix.
    """

    def fn(t, state):
        # Drop within-location flow, then symmetrise.  The diagonal of Cstar
        # is minus the column sums of C (as in CovidUK).
        C = tf.convert_to_tensor(covar_data["C"], dtype=DTYPE)
        C = tf.linalg.set_diag(C, tf.zeros(C.shape[0], dtype=DTYPE))
        Cstar = C + tf.transpose(C)
        Cstar = tf.linalg.set_diag(Cstar, -tf.reduce_sum(C, axis=-2))

        W = tf.constant(covar_data["W"], dtype=DTYPE)
        # Flatten N to a length-M vector (CovidUK squeezes it too).  An (M, 1)
        # N would otherwise broadcast the NGM to a wrong 3-D result below.
        N = tf.reshape(tf.constant(covar_data["N"], dtype=DTYPE), [-1])

        w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, W.shape[0] - 1)
        commute_volume = tf.gather(W, w_idx)

        # NOTE(review): CovidUK uses alpha_t[t - initial_step - 1] after the
        # initial step, whereas this uses alpha_t[t]; confirm the intended
        # alignment of alpha_t with t.
        xi = tf.where(
            t == 0,
            param["alpha_0"],
            tf.gather(
                param["alpha_t"],
                tf.clip_by_value(
                    t,
                    clip_value_min=0,
                    clip_value_max=param["alpha_t"].shape[-1] - 1,
                ),
            ),
        )

        # NOTE(review): CovidUK's infection rate also includes the factor
        # exp(beta_area * log_area), which is absent here -- confirm whether
        # the NGM should include it.
        beta = tf.math.exp(xi)

        ngm = beta * (
            tf.eye(Cstar.shape[0], dtype=state.dtype)
            + param["psi"] * commute_volume * Cstar / N[tf.newaxis, :]
        )
        # Scale row i by S_i / N_i and by the infectious period 1/exp(gamma0).
        ngm = (
            ngm
            * state[..., 0][..., tf.newaxis]
            / (N[:, tf.newaxis] * tf.math.exp(param["gamma0"]))
        )
        return ngm

    return fn