model_spec.py 9.51 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
2
"""Implements the COVID SEIR model as a TFP Joint Distribution"""

3
import pandas as pd
4
import numpy as np
Chris Jewell's avatar
Chris Jewell committed
5
6
7
import tensorflow as tf
import tensorflow_probability as tfp

8
from gemlib.distributions import DiscreteTimeStateTransitionModel
Chris Jewell's avatar
Chris Jewell committed
9
from covid.util import impute_previous_cases
10
import covid.data as data
Chris Jewell's avatar
Chris Jewell committed
11
12

# Short alias for the TensorFlow Probability distributions namespace.
tfd = tfp.distributions
# Floating-point type used for the arrays and tensors built in this module.
DTYPE = np.float64

# Stoichiometry matrix handed to the state transition model: one row per
# transition (S->E, E->I, I->R) and one column per state (S, E, I, R).  Each
# row removes one unit from the source state and adds one to the next state.
STOICHIOMETRY = np.array([[-1, 1, 0, 0], [0, -1, 1, 0], [0, 0, -1, 1]])
# Passed as `time_delta` to the DiscreteTimeStateTransitionModel
# (presumably measured in days, cf. XI_FREQ -- TODO confirm).
TIME_DELTA = 1.0
XI_FREQ = 14  # baseline transmission changes every 14 days
NU = tf.constant(0.28, dtype=DTYPE)  # E->I rate assumed known.
Chris Jewell's avatar
Chris Jewell committed
19

Chris Jewell's avatar
Chris Jewell committed
20

21
def gather_data(config):
    """Loads the covariate and case data described by `config`

    :param config: a dictionary-like configuration.  This function reads the
                   keys 'date_range' (entries 0 and 1 are the start and end
                   dates), 'mobility_matrix', 'population_size' and
                   'commute_volume'; `config` is also handed whole to the
                   `AreaCodeData`, `TierData` and `CasesData` processors.
    :returns: a dictionary of covariate information to be consumed by the
              model, with keys
              {'C': mobility matrix, 'W': commute volume,
               'N': population size, 'L': tier restrictions,
               'weekday': weekday indicator,
               'date_range': [date_low, date_high],
               'locations': area code data, 'cases': cases as xarray}
    """
    start, end = (np.datetime64(config["date_range"][i]) for i in (0, 1))

    mobility_matrix = data.read_mobility(config["mobility_matrix"])
    population = data.read_population(config["population_size"])
    traffic = data.read_traffic_flow(
        config["commute_volume"], date_low=start, date_high=end
    )

    area_codes = data.AreaCodeData.process(config)
    # Only entries 2 onwards of the last axis of the tier data are kept.
    tiers = data.TierData.process(config)[:, :, 2:]
    # True for Monday-Friday (pandas numbers Monday as 0), False otherwise.
    is_weekday = pd.date_range(start, end).weekday < 5
    case_counts = data.CasesData.process(config).to_xarray()

    return {
        "C": mobility_matrix.to_numpy().astype(DTYPE),
        "W": traffic.to_numpy().astype(DTYPE),
        "N": population.to_numpy().astype(DTYPE),
        "L": tiers.astype(DTYPE),
        "weekday": is_weekday.astype(DTYPE),
        "date_range": [start, end],
        "locations": area_codes,
        "cases": case_counts,
    }


Chris Jewell's avatar
Chris Jewell committed
63
64
65
66
def impute_censored_events(cases):
    """Imputes censored S->E and E->I events using the geometric
       sampling algorithm in `impute_previous_cases`

    The probabilities 0.25 and 0.5 passed to `impute_previous_cases` are
    application-specific magic numbers.  They were chosen by experimentation
    and examination of the resulting epidemic trajectories, to get the
    right lag between EI and IR events, and SE and EI events respectively.

    :param cases: a MxT matrix of case numbers (I->R)
    :returns: a MxTx3 tensor of events stacked as [S->E, E->I, I->R] along
              the right-most dimension; the first two entries of that
              dimension contain the imputed event times.
    """
    imputed_ei, ei_lag = impute_previous_cases(cases, 0.25)
    imputed_se, se_lag = impute_previous_cases(imputed_ei, 0.5)

    # Left-pad the time axis (axis 1) with zeros so that the observed and
    # E->I series are as long as the S->E series; rows are never padded.
    rows_untouched = (0, 0)
    padded_ir = np.pad(cases, (rows_untouched, (ei_lag + se_lag - 2, 0)))
    padded_ei = np.pad(imputed_ei, (rows_untouched, (se_lag - 1, 0)))

    return tf.stack([imputed_se, padded_ei, padded_ir], axis=-1)


Chris Jewell's avatar
Chris Jewell committed
84
85
86
87
88
89
90
91
92
93
def conditional_gp(gp, observations, new_index_points):
    """Conditions a Gaussian process on observations at its index points

    The constructor parameters of `gp` are reused: its current index points
    become the observation index points, `observations` become the observed
    values, and `new_index_points` become the index points of the result.

    :param gp: a distribution exposing a `parameters` mapping that contains
               an 'index_points' entry and is otherwise accepted by
               `tfd.GaussianProcessRegressionModel`
    :param observations: values observed at the index points of `gp`
    :param new_index_points: index points of the returned distribution
    :returns: a `tfd.GaussianProcessRegressionModel` instance
    """
    # Work on a shallow copy so that the mapping handed back by
    # `gp.parameters` is never modified, whether or not it is already a copy.
    param = dict(gp.parameters)
    param["observation_index_points"] = param["index_points"]
    param["observations"] = observations
    param["index_points"] = new_index_points

    return tfd.GaussianProcessRegressionModel(**param)


94
def CovidUK(covariates, initial_state, initial_step, num_steps):
    """Builds the joint distribution of the parameters and the SEIR process

    :param covariates: a dictionary of covariate data; the keys 'C', 'W',
                       'N', 'L' and 'weekday' are read.  'W', 'L' and
                       'weekday' are indexed by time along their first axis
                       and the last axis of 'L' must match the length-4
                       `beta3`.
    :param initial_state: passed to `DiscreteTimeStateTransitionModel`
    :param initial_step: passed to `DiscreteTimeStateTransitionModel`
    :param num_steps: number of time steps; also sets the number of `xi`
                      index points (`num_steps // XI_FREQ`)
    :returns: a `tfd.JointDistributionNamed` over 'beta1', 'beta2', 'beta3',
              'sigma', 'xi', 'gamma0', 'gamma1' and 'seir'
    """

    def _const(value):
        # Constant tensor in the module-wide floating point type.
        return tf.constant(value, dtype=DTYPE)

    def _at_time(series, t):
        # Slice of `series` at integer time `t`, clipped to the valid range
        # of the first axis.
        last = series.shape[0] - 1
        return tf.gather(
            series, tf.clip_by_value(tf.cast(t, tf.int64), 0, last)
        )

    # NOTE: JointDistributionNamed wires the components together through
    # the argument names of the functions below, so these must not change.

    def beta1():
        return tfd.Normal(loc=_const(0.0), scale=_const(1000.0))

    def beta2():
        return tfd.Gamma(concentration=_const(3.0), rate=_const(10.0))

    def beta3():
        elementwise = tfd.Normal(
            loc=_const([0.0] * 4), scale=_const([1.0] * 4)
        )
        return tfd.Independent(elementwise, reinterpreted_batch_ndims=1)

    def sigma():
        return tfd.Gamma(concentration=_const(2.0), rate=_const(20.0))

    def xi(beta1, sigma):
        # Matern 3/2 kernel with arguments (sigma, 24.0); one index point
        # every XI_FREQ time steps and constant mean `beta1`.
        matern = tfp.math.psd_kernels.MaternThreeHalves(sigma, _const(24.0))
        change_points = tf.cast(
            tf.range(num_steps // XI_FREQ) * XI_FREQ, dtype=DTYPE
        )
        return tfd.GaussianProcessRegressionModel(
            matern,
            mean_fn=lambda _: beta1,
            index_points=change_points[:, tf.newaxis],
        )

    def gamma0():
        return tfd.Normal(loc=_const(0.0), scale=_const(100.0))

    def gamma1():
        return tfd.Normal(loc=_const(0.0), scale=_const(100.0))

    def seir(beta2, beta3, xi, gamma0, gamma1):
        beta2, beta3, xi, gamma0, gamma1 = (
            tf.convert_to_tensor(value, DTYPE)
            for value in (beta2, beta3, xi, gamma0, gamma1)
        )

        # Mobility matrix with its diagonal zeroed, then symmetrised; the
        # diagonal of the symmetrised matrix is minus the column sums.
        mobility = tf.convert_to_tensor(covariates["C"], dtype=DTYPE)
        mobility = tf.linalg.set_diag(
            mobility, tf.zeros(mobility.shape[0], dtype=DTYPE)
        )
        sym_mobility = tf.linalg.set_diag(
            mobility + tf.transpose(mobility),
            -tf.reduce_sum(mobility, axis=-2),
        )

        traffic = tf.convert_to_tensor(
            tf.squeeze(covariates["W"]), dtype=DTYPE
        )
        popsize = tf.convert_to_tensor(
            tf.squeeze(covariates["N"]), dtype=DTYPE
        )

        # Tier and weekday covariates are centred on their means.
        tiers = tf.convert_to_tensor(covariates["L"], DTYPE)
        tiers = tiers - tf.reduce_mean(tiers, axis=(0, 1))

        is_weekday = tf.convert_to_tensor(covariates["weekday"], DTYPE)
        is_weekday = is_weekday - tf.reduce_mean(is_weekday, axis=-1)

        def transition_rate_fn(t, state):
            num_units = state.shape[0]
            infectious = state[..., 2]

            commute_volume = _at_time(traffic, t)
            weekday_t = _at_time(is_weekday, t)
            tier_effect = tf.linalg.matvec(_at_time(tiers, t), beta3)

            # xi only changes every XI_FREQ steps; clip to the last block.
            xi_block = tf.cast(
                tf.clip_by_value(t // XI_FREQ, 0, xi.shape[0] - 1),
                dtype=tf.int64,
            )
            xi_t = tf.gather(xi, xi_block)

            # `popsize` is already squeezed above, so it is used directly.
            commuting = (
                beta2
                * commute_volume
                * tf.linalg.matvec(sym_mobility, infectious / popsize)
            )
            se = tf.math.exp(xi_t + tier_effect) * (infectious + commuting)
            se = se / popsize + 0.000000001  # Vector of length nc

            ei = tf.broadcast_to([NU], shape=[num_units])
            ir = tf.broadcast_to(
                [tf.math.exp(gamma0 + gamma1 * weekday_t)],
                shape=[num_units],
            )

            return [se, ei, ir]

        return DiscreteTimeStateTransitionModel(
            transition_rates=transition_rate_fn,
            stoichiometry=STOICHIOMETRY,
            initial_state=initial_state,
            initial_step=initial_step,
            time_delta=TIME_DELTA,
            num_steps=num_steps,
        )

    return tfd.JointDistributionNamed(
        {
            "beta1": beta1,
            "beta2": beta2,
            "beta3": beta3,
            "sigma": sigma,
            "xi": xi,
            "gamma0": gamma0,
            "gamma1": gamma1,
            "seir": seir,
        }
    )
225
226


Chris Jewell's avatar
Chris Jewell committed
227
def next_generation_matrix_fn(covar_data, param):
    r"""Builds a function that evaluates the next generation matrix (NGM)

    For time `t` and epidemic state `state` the returned function computes

      \[ A_{ij} = \frac{S_i \exp(\xi_t + x_i^T \beta_3)}{N_i \exp(\gamma_0)}
                  \left( \delta_{ij} + \beta_2 w_t C^*_{ij} / N_j \right) \]

    where $S_i$ is `state[..., 0]`, $x_i$ is row i of the mean-centred tier
    covariate at its last timepoint, $w_t$ is the commute volume at time t
    and $C^*$ is the symmetrised mobility matrix whose diagonal holds minus
    the column sums of the zero-diagonal mobility matrix.  Only `gamma0`
    enters the removal rate here.

    NOTE(review): the row index carries the susceptible count, so rows
    appear to index the receiving metapopulation -- confirm the intended
    orientation before relying on it.

    :param covar_data: a dictionary of covariate data; the keys 'L', 'C',
                       'W' and 'N' are read
    :param param: a dictionary of parameters; the keys 'xi', 'beta2',
                  'beta3' and 'gamma0' are read
    :returns: a function taking arguments `t` and `state` giving the time and
              epidemic state (SEIR) for which the NGM is to be calculated.
              This function in turn returns an MxM next generation matrix.
    """

    def fn(t, state):
        L = tf.convert_to_tensor(covar_data["L"], DTYPE)
        L = L - tf.reduce_mean(L, axis=(0, 1))

        C = tf.convert_to_tensor(covar_data["C"], dtype=DTYPE)
        C = tf.linalg.set_diag(C, tf.zeros(C.shape[0], dtype=DTYPE))
        Cstar = C + tf.transpose(C)
        Cstar = tf.linalg.set_diag(Cstar, -tf.reduce_sum(C, axis=-2))

        W = tf.constant(covar_data["W"], dtype=DTYPE)
        # Flatten N so that an (M, 1) population array broadcasts exactly
        # like a length-M vector below (CovidUK squeezes N for this reason).
        N = tf.reshape(tf.constant(covar_data["N"], dtype=DTYPE), [-1])

        w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, W.shape[0] - 1)
        commute_volume = tf.gather(W, w_idx)
        # xi only changes every XI_FREQ steps; clip to the last block.
        xi_idx = tf.cast(
            tf.clip_by_value(t // XI_FREQ, 0, param["xi"].shape[0] - 1),
            dtype=tf.int64,
        )
        xi = tf.gather(param["xi"], xi_idx)

        Lt = L[-1]  # Last timepoint
        xB = tf.linalg.matvec(Lt, param["beta3"])

        beta = tf.math.exp(xi + xB)

        ngm = beta[:, tf.newaxis] * (
            tf.eye(Cstar.shape[0], dtype=state.dtype)
            + param["beta2"] * commute_volume * Cstar / N[tf.newaxis, :]
        )
        # Scale row i by S_i / N_i and divide by the removal rate.
        ngm = (
            ngm
            * state[..., 0][..., tf.newaxis]
            / (N[:, tf.newaxis] * tf.math.exp(param["gamma0"]))
        )
        return ngm

    return fn