model_spec.py 8.84 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
2
"""Implements the COVID SEIR model as a TFP Joint Distribution"""

3
import pandas as pd
4
import numpy as np
5
import xarray
Chris Jewell's avatar
Chris Jewell committed
6
7
8
import tensorflow as tf
import tensorflow_probability as tfp

9
from gemlib.distributions import DiscreteTimeStateTransitionModel
Chris Jewell's avatar
Chris Jewell committed
10
from covid.util import impute_previous_cases
11
import covid.data as data
Chris Jewell's avatar
Chris Jewell committed
12
13

tfd = tfp.distributions

VERSION = "0.5.0"
DTYPE = np.float64

# Stoichiometry of the three transitions (rows) acting on the four SEIR
# compartments (columns).  Row order matches the list of rates returned by
# the transition-rate function in `CovidUK`.
STOICHIOMETRY = np.array(
    [
        [-1, 1, 0, 0],  # S -> E
        [0, -1, 1, 0],  # E -> I
        [0, 0, -1, 1],  # I -> R
    ]
)

# Length of one model time step.
TIME_DELTA = 1.0

# Baseline transmission changes every 14 days.
XI_FREQ = 14

# E->I rate, assumed known.
NU = tf.constant(0.28, dtype=DTYPE)
Chris Jewell's avatar
Chris Jewell committed
22

Chris Jewell's avatar
Chris Jewell committed
23

24
def gather_data(config):
    """Load covariate and case data for the model.

    :param config: a configuration mapping.  This function reads the keys
                   'date_range' (a ``(start, end)`` pair of dates),
                   'mobility_matrix', 'population_size' and
                   'commute_volume'; the whole mapping is also passed to
                   `data.AreaCodeData.process` and `data.CasesData.process`.
    :returns: a tuple ``(covariates, cases)`` of `xarray.Dataset` objects.
              ``covariates`` holds ``C`` (mobility matrix), ``W`` (commute
              volume), ``N`` (population size), ``weekday`` (1.0 for
              Monday to Friday, else 0.0, for each day in ``[start, end)``)
              and ``locations`` (area names indexed by ``lad19cd`` code).
              ``cases`` holds the single variable ``cases``.
    """
    date_low = np.datetime64(config["date_range"][0])
    date_high = np.datetime64(config["date_range"][1])

    locations = data.AreaCodeData.process(config)
    mobility = data.read_mobility(
        config["mobility_matrix"], locations["lad19cd"]
    )
    popsize = data.read_population(
        config["population_size"], locations["lad19cd"]
    )
    commute_volume = data.read_traffic_flow(
        config["commute_volume"], date_low=date_low, date_high=date_high
    )

    # Daily dates on [start, end).  The end point is dropped explicitly
    # rather than via `closed="left"`, which pandas deprecated in 1.4 (in
    # favour of `inclusive="left"`) and removed in 2.0; the filter gives the
    # same index on every pandas version.
    dates = pd.date_range(*config["date_range"])
    dates = dates[dates < pd.Timestamp(date_high)]
    weekday = xarray.DataArray(
        dates.weekday < 5,  # Monday == 0, so True for Monday..Friday
        name="weekday",
        dims=["time"],
        coords=[dates.to_numpy()],
    )

    cases = data.CasesData.process(config).to_xarray()

    covariates = xarray.Dataset(
        dict(
            C=mobility.astype(DTYPE),
            W=commute_volume.astype(DTYPE),
            N=popsize.astype(DTYPE),
            weekday=weekday.astype(DTYPE),
            locations=xarray.DataArray(
                locations["name"],
                dims=["location"],
                coords=[locations["lad19cd"]],
            ),
        )
    )
    return covariates, xarray.Dataset(dict(cases=cases))


Chris Jewell's avatar
Chris Jewell committed
74
75
76
77
def impute_censored_events(cases):
    """Impute censored S->E and E->I events from observed I->R counts.

    Both censored event series are drawn with the geometric sampling
    algorithm in `impute_previous_cases`: first E->I events from the I->R
    cases, then S->E events from those imputed E->I events.  The second
    arguments passed to it (0.25 and 0.5) are application-specific magic
    numbers.  They were chosen by experimentation and examination of the
    resulting epidemic trajectories to get the right lag between EI and IR
    events, and between SE and EI events, respectively.

    :param cases: a MxT matrix of case numbers (I->R)
    :returns: a MxTx3 tensor of events where the first two indices of
              the right-most dimension contain the imputed event times.
    """
    ei_imputed, ei_lag = impute_previous_cases(cases, 0.25)
    se_imputed, se_lag = impute_previous_cases(ei_imputed, 0.5)

    # Left-pad the I->R and E->I series along the time axis (no padding on
    # the location axis); tf.stack below requires all three series to
    # have the same shape.
    ir_padding = ((0, 0), (ei_lag + se_lag - 2, 0))
    ei_padding = ((0, 0), (se_lag - 1, 0))

    event_series = [
        se_imputed,
        np.pad(ei_imputed, ei_padding),
        np.pad(cases, ir_padding),
    ]
    return tf.stack(event_series, axis=-1)


Chris Jewell's avatar
Chris Jewell committed
95
96
97
98
99
100
101
102
103
104
def conditional_gp(gp, observations, new_index_points):
    """Condition a Gaussian process on observations at its index points.

    Every constructor parameter recorded in ``gp.parameters`` is reused.
    ``gp``'s current index points become the observation locations for
    ``observations``, and the returned model predicts at
    ``new_index_points``.

    :param gp: a TFP Gaussian-process distribution whose ``parameters``
               include ``index_points``.
    :param observations: values observed at ``gp``'s index points.
    :param new_index_points: index points at which to predict.
    :returns: a `tfd.GaussianProcessRegressionModel`.
    """
    gprm_kwargs = dict(gp.parameters)
    gprm_kwargs.update(
        observation_index_points=gprm_kwargs["index_points"],
        observations=observations,
        index_points=new_index_points,
    )
    return tfd.GaussianProcessRegressionModel(**gprm_kwargs)


105
def CovidUK(covariates, initial_state, initial_step, num_steps):
    """Build the COVID SEIR model as a TFP `JointDistributionNamed`.

    :param covariates: a mapping with keys 'C' (mobility matrix), 'W'
                       (commute volume series), 'N' (population sizes) and
                       'weekday' (weekday indicator series), as in the
                       covariate Dataset returned by `gather_data`.
    :param initial_state: initial epidemic state, passed to
                          `DiscreteTimeStateTransitionModel`.
    :param initial_step: index of the first time step.
    :param num_steps: number of time steps in the epidemic.
    :returns: a `tfd.JointDistributionNamed` over ``beta1``, ``beta2``,
              ``sigma``, ``xi``, ``gamma0``, ``gamma1`` and ``seir``.
    """

    def beta1():
        # Vague prior on the constant mean of the `xi` Gaussian process.
        return tfd.Normal(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale=tf.constant(1000.0, dtype=DTYPE),
        )

    def beta2():
        # Prior on the weight of the commute-driven term in `seir`.
        return tfd.Gamma(
            concentration=tf.constant(3.0, dtype=DTYPE),
            rate=tf.constant(10.0, dtype=DTYPE),
        )

    def sigma():
        # Prior on the amplitude of the `xi` GP kernel.
        return tfd.Gamma(
            concentration=tf.constant(2.0, dtype=DTYPE),
            rate=tf.constant(20.0, dtype=DTYPE),
        )

    def xi(beta1, sigma):
        # Log baseline transmission multiplier: a GP over time with mean
        # `beta1` and a Matern-3/2 kernel (amplitude `sigma`, fixed length
        # scale `phi` in time steps).
        phi = tf.constant(24.0, dtype=DTYPE)
        kernel = tfp.math.psd_kernels.MaternThreeHalves(sigma, phi)
        # num_steps // XI_FREQ index points at t = 0, XI_FREQ, 2*XI_FREQ, ...
        idx_pts = tf.cast(tf.range(num_steps // XI_FREQ) * XI_FREQ, dtype=DTYPE)
        return tfd.GaussianProcessRegressionModel(
            kernel,
            mean_fn=lambda idx: beta1,
            index_points=idx_pts[:, tf.newaxis],
        )

    def gamma0():
        # Prior on the baseline log I->R rate.
        return tfd.Normal(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale=tf.constant(100.0, dtype=DTYPE),
        )

    def gamma1():
        # Prior on the weekday effect on the log I->R rate.
        return tfd.Normal(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale=tf.constant(100.0, dtype=DTYPE),
        )

    def seir(beta2, xi, gamma0, gamma1):
        beta2 = tf.convert_to_tensor(beta2, DTYPE)
        xi = tf.convert_to_tensor(xi, DTYPE)
        gamma0 = tf.convert_to_tensor(gamma0, DTYPE)
        gamma1 = tf.convert_to_tensor(gamma1, DTYPE)

        # Connectivity: zero the diagonal of C, add its transpose, then set
        # the diagonal to minus the column sums of C.
        C = tf.convert_to_tensor(covariates["C"], dtype=DTYPE)
        C = tf.linalg.set_diag(C, tf.zeros(C.shape[0], dtype=DTYPE))
        Cstar = C + tf.transpose(C)
        Cstar = tf.linalg.set_diag(Cstar, -tf.reduce_sum(C, axis=-2))

        W = tf.convert_to_tensor(tf.squeeze(covariates["W"]), dtype=DTYPE)
        N = tf.convert_to_tensor(tf.squeeze(covariates["N"]), dtype=DTYPE)

        # Centre the weekday indicator so that, averaged over the series,
        # the log I->R rate gamma0 + gamma1 * weekday_t equals gamma0.
        weekday = tf.convert_to_tensor(covariates["weekday"], DTYPE)
        weekday = weekday - tf.reduce_mean(weekday, axis=-1)

        def transition_rate_fn(t, state):
            # Time-indexed lookups are clipped to the available range, so
            # steps beyond a series reuse its last (or first) value.
            w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, W.shape[0] - 1)
            commute_volume = tf.gather(W, w_idx)
            xi_idx = tf.cast(
                tf.clip_by_value(t // XI_FREQ, 0, xi.shape[0] - 1),
                dtype=tf.int64,
            )
            xi_ = tf.gather(xi, xi_idx)
            weekday_idx = tf.clip_by_value(
                tf.cast(t, tf.int64), 0, weekday.shape[0] - 1
            )
            weekday_t = tf.gather(weekday, weekday_idx)

            # S->E: within-location infectives (state[..., 2]) plus the
            # commute-weighted contribution from other locations via Cstar.
            infec_rate = tf.math.exp(xi_) * (
                state[..., 2]
                + beta2
                * commute_volume
                * tf.linalg.matvec(Cstar, state[..., 2] / N)
            )
            # Small offset, presumably to keep the rate strictly positive.
            infec_rate = infec_rate / N + 0.000000001  # Vector of length nc

            # E->I and I->R rates are shared by all locations; broadcast
            # them to the state's leading (non-compartment) shape so they
            # match `infec_rate` for unbatched and batched states alike.
            rate_shape = tf.shape(state)[:-1]
            ei = tf.broadcast_to(NU, rate_shape)
            ir = tf.broadcast_to(
                tf.math.exp(gamma0 + gamma1 * weekday_t), rate_shape
            )

            return [infec_rate, ei, ir]

        return DiscreteTimeStateTransitionModel(
            transition_rates=transition_rate_fn,
            stoichiometry=STOICHIOMETRY,
            initial_state=initial_state,
            initial_step=initial_step,
            time_delta=TIME_DELTA,
            num_steps=num_steps,
        )

    return tfd.JointDistributionNamed(
        dict(
            beta1=beta1,
            beta2=beta2,
            sigma=sigma,
            xi=xi,
            gamma0=gamma0,
            gamma1=gamma1,
            seir=seir,
        )
    )
218
219


Chris Jewell's avatar
Chris Jewell committed
220
def next_generation_matrix_fn(covar_data, param):
    r"""Build a function computing the next generation matrix (NGM).

    The returned function evaluates, at time ``t`` for epidemic state
    ``state`` (``state[..., 0]`` is used as the susceptible count S),

      \[ A_{ij} = \frac{S_i \left(1 - \exp\left(-\frac{e^{\xi_t}}{N_i}
         \left(\delta_{ij} + \beta_2 w_t \frac{C^*_{ij}}{N_j}\right)
         \right)\right)}{1 - e^{-e^{\gamma_0}}} \]

    where ``xi_t`` is the element of ``param["xi"]`` for the XI_FREQ-step
    period containing ``t``, ``w_t`` is the commute volume at ``t`` and
    ``C^*`` is ``C`` with its diagonal zeroed, plus its transpose, with the
    diagonal replaced by minus the column sums of ``C``.  Rows index the
    metapopulation being infected (scaled by ``S_i`` and ``1/N_i``) and
    columns the source metapopulation.  The denominator uses only the
    baseline I->R rate ``exp(gamma0)``; ``gamma1`` is not used.

    :param covar_data: a mapping with keys 'C' (mobility matrix), 'W'
                       (commute volume series) and 'N' (population sizes).
    :param param: a mapping with keys 'xi', 'beta2' and 'gamma0'.
    :returns: a function ``fn(t, state)`` returning the MxM next generation
              matrix at time ``t`` for epidemic state ``state``.
    """
    # These depend only on the covariates, so compute them once here rather
    # than on every call of `fn`.  convert_to_tensor (unlike tf.constant)
    # also accepts tensors, matching how `CovidUK` reads covariates.
    C = tf.convert_to_tensor(covar_data["C"], dtype=DTYPE)
    C = tf.linalg.set_diag(C, tf.zeros(C.shape[0], dtype=DTYPE))
    Cstar = C + tf.transpose(C)
    Cstar = tf.linalg.set_diag(Cstar, -tf.reduce_sum(C, axis=-2))

    W = tf.convert_to_tensor(covar_data["W"], dtype=DTYPE)
    N = tf.convert_to_tensor(covar_data["N"], dtype=DTYPE)
    eye = tf.eye(Cstar.shape[0], dtype=DTYPE)

    def fn(t, state):
        # Clip time-indexed lookups to the available range.
        w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, W.shape[0] - 1)
        commute_volume = tf.gather(W, w_idx)
        xi_idx = tf.cast(
            tf.clip_by_value(t // XI_FREQ, 0, param["xi"].shape[0] - 1),
            dtype=tf.int64,
        )
        beta = tf.math.exp(tf.gather(param["xi"], xi_idx))

        # Per-infective infection rate from source column j into row i.
        rate = (
            beta
            * (eye + param["beta2"] * commute_volume * Cstar / N[tf.newaxis, :])
            / N[:, tf.newaxis]
        )

        # Per-step infection probability times the susceptibles S_i of
        # each row.
        ngm = (1.0 - tf.math.exp(-rate)) * state[..., 0][..., tf.newaxis]

        # Divide by the per-step removal probability 1 - exp(-exp(gamma0)),
        # presumably giving the mean number of steps spent infectious.
        return ngm / (1 - tf.math.exp(-tf.math.exp(param["gamma0"])))

    return fn