model_spec.py 8.52 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
2
"""Implements the COVID SEIR model as a TFP Joint Distribution"""

3
import pandas as pd
4
import numpy as np
Chris Jewell's avatar
Chris Jewell committed
5
6
7
import tensorflow as tf
import tensorflow_probability as tfp

8
from gemlib.distributions import DiscreteTimeStateTransitionModel
Chris Jewell's avatar
Chris Jewell committed
9
from covid.util import impute_previous_cases
10
import covid.data as data
Chris Jewell's avatar
Chris Jewell committed
11
12

tfd = tfp.distributions

# Floating-point dtype shared by every tensor in the model.
DTYPE = np.float64

# Stoichiometry of the SEIR model: one row per transition, one column per
# compartment (S, E, I, R).  Each transition removes one individual from a
# compartment and adds it to the next.
STOICHIOMETRY = np.array(
    [
        [-1, 1, 0, 0],  # S -> E
        [0, -1, 1, 0],  # E -> I
        [0, 0, -1, 1],  # I -> R
    ]
)

# Length of one discrete time step of the state-transition model.
TIME_DELTA = 1.0

# Number of days over which the baseline transmission (xi) stays constant.
XI_FREQ = 14

# E->I transition rate; treated as known rather than estimated.
NU = tf.constant(0.28, dtype=DTYPE)
Chris Jewell's avatar
Chris Jewell committed
19

Chris Jewell's avatar
Chris Jewell committed
20

21
def gather_data(config):
    """Load the covariate and case data consumed by the model.

    :param config: a configuration mapping.  This function reads the keys
                   ``"date_range"`` (a ``[start, end]`` pair accepted by
                   ``np.datetime64``), ``"mobility_matrix"``,
                   ``"population_size"`` and ``"commute_volume"``.  The whole
                   mapping is also passed to ``data.AreaCodeData.process``
                   and ``data.CasesData.process``, which may read further
                   keys.
    :returns: a dictionary with keys

              - ``"C"``: mobility matrix between locations, as a ``DTYPE``
                array;
              - ``"W"``: commute volume (traffic flow) read for
                ``date_low``..``date_high``, as a ``DTYPE`` array;
              - ``"N"``: population size per location, as a ``DTYPE``
                array;
              - ``"weekday"``: 1.0 for Monday-Friday and 0.0 for weekends,
                one entry per day in ``[start, end)``;
              - ``"date_range"``: ``[start, end]`` as ``np.datetime64``;
              - ``"locations"``: the area-code data returned by
                ``data.AreaCodeData.process``;
              - ``"cases"``: case data from ``data.CasesData.process``,
                converted with ``to_xarray()``.
    """
    date_low = np.datetime64(config["date_range"][0])
    date_high = np.datetime64(config["date_range"][1])

    locations = data.AreaCodeData.process(config)
    mobility = data.read_mobility(
        config["mobility_matrix"], locations["lad19cd"]
    )
    popsize = data.read_population(
        config["population_size"], locations["lad19cd"]
    )
    commute_volume = data.read_traffic_flow(
        config["commute_volume"], date_low=date_low, date_high=date_high
    )

    # One flag per day of the half-open interval [date_low, date_high).
    # pandas numbers days Monday=0 ... Sunday=6, so `< 5` selects Mon-Fri.
    weekday = (
        pd.date_range(date_low, date_high - np.timedelta64(1, "D")).weekday < 5
    )

    cases = data.CasesData.process(config).to_xarray()

    return dict(
        C=mobility.to_numpy().astype(DTYPE),
        W=commute_volume.to_numpy().astype(DTYPE),
        N=popsize.to_numpy().astype(DTYPE),
        weekday=weekday.astype(DTYPE),
        date_range=[date_low, date_high],
        locations=locations,
        cases=cases,
    )


Chris Jewell's avatar
Chris Jewell committed
61
62
63
64
def impute_censored_events(cases):
    """Impute censored S->E and E->I events from observed I->R cases.

    The geometric sampling algorithm in `impute_previous_cases` is applied
    twice: first to impute E->I events from the cases, then to impute S->E
    events from those E->I events.  The two arguments passed to it (0.25 and
    0.5) are application-specific magic numbers.  They were chosen by
    experimentation and inspection of the resulting epidemic trajectories to
    give a sensible lag between EI and IR events, and between SE and EI
    events, respectively.

    :param cases: a MxT matrix of case numbers (I->R)
    :returns: a MxTx3 tensor of events stacked along the last axis as
              (S->E, E->I, I->R).  The first two slices hold the imputed
              events.  The I->R series is left-padded with zeros by
              ``lag_ei + lag_se - 2`` steps and the E->I series by
              ``lag_se - 1`` steps along the time axis.
    """
    ei_arg = 0.25  # argument for imputing E->I from I->R
    se_arg = 0.5  # argument for imputing S->E from E->I

    imputed_ei, ei_lag = impute_previous_cases(cases, ei_arg)
    imputed_se, se_lag = impute_previous_cases(imputed_ei, se_arg)

    # Zero-pad the start of the later series along time; tf.stack below
    # requires all three series to end up with the same shape.
    ir_offset = ei_lag + se_lag - 2
    ei_offset = se_lag - 1
    padded_ir = np.pad(cases, ((0, 0), (ir_offset, 0)))
    padded_ei = np.pad(imputed_ei, ((0, 0), (ei_offset, 0)))

    return tf.stack([imputed_se, padded_ei, padded_ir], axis=-1)


Chris Jewell's avatar
Chris Jewell committed
82
83
84
85
86
87
88
89
90
91
def conditional_gp(gp, observations, new_index_points):
    """Condition a Gaussian process distribution on observed values.

    The constructor parameters of `gp` are reused to build a
    ``tfd.GaussianProcessRegressionModel``:

    - `gp`'s own index points become the observation index points;
    - `observations` are the values observed there;
    - predictions are made at `new_index_points`.

    :param gp: a TFP Gaussian-process distribution whose ``parameters``
               contain an ``"index_points"`` entry (presumably the ``xi``
               component built in ``CovidUK``; verify against callers).
    :param observations: observed values at ``gp``'s index points.
    :param new_index_points: index points at which to predict.
    :returns: a ``tfd.GaussianProcessRegressionModel``.
    """
    kwargs = gp.parameters
    # The right-hand values are evaluated before update() runs, so
    # kwargs["index_points"] below is still gp's original index points.
    kwargs.update(
        observation_index_points=kwargs["index_points"],
        observations=observations,
        index_points=new_index_points,
    )
    return tfd.GaussianProcessRegressionModel(**kwargs)


92
def CovidUK(covariates, initial_state, initial_step, num_steps):
    """Build the spatial SEIR model as a TFP joint distribution.

    :param covariates: a dictionary with keys ``"C"`` (mobility matrix),
                       ``"W"`` (commute volume indexed by time step),
                       ``"N"`` (population size per location) and
                       ``"weekday"`` (weekday indicator indexed by time
                       step), e.g. as returned by ``gather_data``.
    :param initial_state: initial epidemic state, passed to
                          ``DiscreteTimeStateTransitionModel``.  Column 2 is
                          used as the infective (I) count; presumably an
                          M x 4 (S, E, I, R) array -- TODO confirm.
    :param initial_step: first time step, passed to
                         ``DiscreteTimeStateTransitionModel``.
    :param num_steps: number of time steps to simulate.  It also fixes the
                      number of ``xi`` values at ``num_steps // XI_FREQ``.
    :returns: a ``tfd.JointDistributionNamed`` over ``beta1``, ``beta2``,
              ``sigma``, ``xi``, ``gamma0``, ``gamma1`` and ``seir``.
    """

    # NOTE: the argument names of the nested functions below are significant.
    # JointDistributionNamed passes each function the values of the
    # like-named variables in the dict at the bottom of this function.

    def beta1():
        # Prior for beta1, the constant mean of the xi Gaussian process.
        return tfd.Normal(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale=tf.constant(1000.0, dtype=DTYPE),
        )

    def beta2():
        # Prior for beta2, which scales the mobility (commuting) term of
        # the infection rate.
        return tfd.Gamma(
            concentration=tf.constant(3.0, dtype=DTYPE),
            rate=tf.constant(10.0, dtype=DTYPE),
        )

    def sigma():
        # Prior for sigma, the amplitude of the xi Gaussian-process kernel.
        return tfd.Gamma(
            concentration=tf.constant(2.0, dtype=DTYPE),
            rate=tf.constant(20.0, dtype=DTYPE),
        )

    def xi(beta1, sigma):
        # Log baseline transmission: a Gaussian process with mean beta1 and
        # one value per XI_FREQ-step period.
        # Kernel length scale, in the same units as the index points
        # (time steps).
        phi = tf.constant(24.0, dtype=DTYPE)
        kernel = tfp.math.psd_kernels.MaternThreeHalves(sigma, phi)
        # Index points 0, XI_FREQ, 2*XI_FREQ, ...; a trailing partial period
        # gets no point of its own (transition_rate_fn clips to the last).
        idx_pts = tf.cast(tf.range(num_steps // XI_FREQ) * XI_FREQ, dtype=DTYPE)
        # No observations are given, so this is the GP prior.
        return tfd.GaussianProcessRegressionModel(
            kernel,
            mean_fn=lambda idx: beta1,
            index_points=idx_pts[:, tf.newaxis],
        )

    def gamma0():
        # Prior for gamma0, the intercept of the log I->R rate.
        return tfd.Normal(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale=tf.constant(100.0, dtype=DTYPE),
        )

    def gamma1():
        # Prior for gamma1, the weekday coefficient of the log I->R rate.
        return tfd.Normal(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale=tf.constant(100.0, dtype=DTYPE),
        )

    def seir(beta2, xi, gamma0, gamma1):
        # Epidemic process conditional on the parameters above.
        beta2 = tf.convert_to_tensor(beta2, DTYPE)
        xi = tf.convert_to_tensor(xi, DTYPE)
        gamma0 = tf.convert_to_tensor(gamma0, DTYPE)
        gamma1 = tf.convert_to_tensor(gamma1, DTYPE)

        # Mobility matrix with its diagonal zeroed.
        C = tf.convert_to_tensor(covariates["C"], dtype=DTYPE)
        C = tf.linalg.set_diag(C, tf.zeros(C.shape[0], dtype=DTYPE))

        # Symmetrised mobility C + C^T.  Its diagonal is replaced by minus
        # the column sums of C.
        Cstar = C + tf.transpose(C)
        Cstar = tf.linalg.set_diag(Cstar, -tf.reduce_sum(C, axis=-2))

        W = tf.convert_to_tensor(tf.squeeze(covariates["W"]), dtype=DTYPE)
        N = tf.convert_to_tensor(tf.squeeze(covariates["N"]), dtype=DTYPE)

        # Centre the weekday indicator so that gamma0 is the log I->R rate
        # at the mean weekday value.
        weekday = tf.convert_to_tensor(covariates["weekday"], DTYPE)
        weekday = weekday - tf.reduce_mean(weekday, axis=-1)

        def transition_rate_fn(t, state):
            # Returns per-location rates for the three transitions, in the
            # row order of STOICHIOMETRY (S->E, E->I, I->R).

            # Time-indexed covariates are looked up at int(t), clipped to
            # the available range.
            w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, W.shape[0] - 1)
            commute_volume = tf.gather(W, w_idx)
            # xi holds one value per XI_FREQ-step period.
            xi_idx = tf.cast(
                tf.clip_by_value(t // XI_FREQ, 0, xi.shape[0] - 1),
                dtype=tf.int64,
            )
            xi_ = tf.gather(xi, xi_idx)

            weekday_idx = tf.clip_by_value(
                tf.cast(t, tf.int64), 0, weekday.shape[0] - 1
            )
            weekday_t = tf.gather(weekday, weekday_idx)

            # S->E rate per susceptible:
            # exp(xi_t) * (I_i + beta2 * w_t * sum_j Cstar_ij I_j / N_j) / N_i
            infec_rate = tf.math.exp(xi_) * (
                state[..., 2]
                + beta2
                * commute_volume
                * tf.linalg.matvec(Cstar, state[..., 2] / tf.squeeze(N))
            )
            # Adding 1e-9 keeps the rate strictly positive even when there
            # are no infectives.
            infec_rate = (
                infec_rate / tf.squeeze(N) + 0.000000001
            )  # Vector of length nc

            # E->I rate: the fixed constant NU at every location.
            ei = tf.broadcast_to(
                [NU], shape=[state.shape[0]]
            )  # Vector of length nc
            # I->R rate: exp(gamma0 + gamma1 * centred weekday indicator),
            # the same at every location.
            ir = tf.broadcast_to(
                [tf.math.exp(gamma0 + gamma1 * weekday_t)],
                shape=[state.shape[0]],
            )  # Vector of length nc

            return [infec_rate, ei, ir]

        return DiscreteTimeStateTransitionModel(
            transition_rates=transition_rate_fn,
            stoichiometry=STOICHIOMETRY,
            initial_state=initial_state,
            initial_step=initial_step,
            time_delta=TIME_DELTA,
            num_steps=num_steps,
        )

    return tfd.JointDistributionNamed(
        dict(
            beta1=beta1,
            beta2=beta2,
            sigma=sigma,
            xi=xi,
            gamma0=gamma0,
            gamma1=gamma1,
            seir=seir,
        )
    )
206
207


Chris Jewell's avatar
Chris Jewell committed
208
def next_generation_matrix_fn(covar_data, param):
    r"""Return a function computing the next generation matrix (NGM).

    For a time ``t`` and epidemic state ``state``, the returned function
    computes the M x M matrix

    .. math::

        A_{ij} = \frac{S_i}{N_i \exp(\gamma_0)} \exp(\xi_t)
                 \left( \delta_{ij} + \beta_2 w_t \frac{C^*_{ij}}{N_j} \right)

    where

    - ``S = state[..., 0]`` is the number of susceptibles per
      metapopulation;
    - ``w_t = W[t]`` is the commute volume, with ``t`` clipped to the
      available range;
    - ``xi_t = param["xi"][t // XI_FREQ]``, also clipped;
    - ``C^*`` is ``C + C^T``, where ``C`` has its diagonal zeroed, and the
      diagonal of ``C^*`` is replaced by minus the column sums of ``C``.

    Row ``i`` is the metapopulation in which infections occur and column
    ``j`` is the metapopulation of the infective.  ``A_ij`` is the
    derivative of ``S_i`` times the S->E rate in ``CovidUK`` with respect
    to ``I_j``, multiplied by an infectious period of ``1 / exp(gamma0)``.
    The weekday effect ``gamma1`` is ignored.

    :param covar_data: a dictionary of covariate data with keys ``"C"``
                       (mobility matrix), ``"W"`` (commute volume indexed by
                       time step) and ``"N"`` (population size per
                       metapopulation).
    :param param: a dictionary of parameters with keys ``"xi"``, ``"beta2"``
                  and ``"gamma0"``.
    :returns: a function taking arguments `t` and `state` giving the time and
              epidemic state (SEIR) for which the NGM is to be calculated.
              This function in turn returns an MxM next generation matrix.
    """
    # The covariate tensors do not depend on (t, state), so build them once
    # here rather than on every call of `fn`.
    C = tf.convert_to_tensor(covar_data["C"], dtype=DTYPE)
    C = tf.linalg.set_diag(C, tf.zeros(C.shape[0], dtype=DTYPE))
    Cstar = C + tf.transpose(C)
    Cstar = tf.linalg.set_diag(Cstar, -tf.reduce_sum(C, axis=-2))

    W = tf.constant(covar_data["W"], dtype=DTYPE)
    # N must be a length-M vector for the row/column broadcasting below.
    # Flatten an e.g. (M, 1) array; CovidUK likewise squeezes N.
    N = tf.reshape(tf.constant(covar_data["N"], dtype=DTYPE), [-1])

    def fn(t, state):
        w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, W.shape[0] - 1)
        commute_volume = tf.gather(W, w_idx)
        # xi holds one value per XI_FREQ-step period.
        xi_idx = tf.cast(
            tf.clip_by_value(t // XI_FREQ, 0, param["xi"].shape[0] - 1),
            dtype=tf.int64,
        )
        xi = tf.gather(param["xi"], xi_idx)

        beta = tf.math.exp(xi)

        # Dividing by N[tf.newaxis, :] scales column j by 1 / N_j.
        ngm = beta * (
            tf.eye(Cstar.shape[0], dtype=state.dtype)
            + param["beta2"] * commute_volume * Cstar / N[tf.newaxis, :]
        )
        # Scale row i by S_i / N_i and by the infectious period 1/exp(gamma0).
        ngm = (
            ngm
            * state[..., 0][..., tf.newaxis]
            / (N[:, tf.newaxis] * tf.math.exp(param["gamma0"]))
        )
        return ngm

    return fn