model_spec.py 7.86 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
2
"""Implements the COVID SEIR model as a TFP Joint Distribution"""

3
import geopandas as gp
4
import numpy as np
Chris Jewell's avatar
Chris Jewell committed
5
6
7
import tensorflow as tf
import tensorflow_probability as tfp

8
from gemlib.distributions import DiscreteTimeStateTransitionModel
Chris Jewell's avatar
Chris Jewell committed
9
from covid.util import impute_previous_cases
10
import covid.data as data
Chris Jewell's avatar
Chris Jewell committed
11
12

tfd = tfp.distributions

# Floating-point type used for all model tensors.
DTYPE = np.float64

# Stoichiometry of the state transition model: one row per transition,
# one column per compartment (S, E, I, R).
STOICHIOMETRY = np.array(
    [
        [-1, 1, 0, 0],  # S->E
        [0, -1, 1, 0],  # E->I
        [0, 0, -1, 1],  # I->R
    ]
)

# Size of one model time step, passed to the state transition model.
TIME_DELTA = 1.0

# The baseline transmission process `xi` takes a new value every
# XI_FREQ time steps (14 days).
XI_FREQ = 14

# E->I transition rate; treated as known rather than estimated.
NU = tf.constant(0.5, dtype=DTYPE)

Chris Jewell's avatar
Chris Jewell committed
20

21
def read_covariates(paths, date_low, date_high):
    """Loads the covariate data consumed by the model.

    :param paths: a dictionary of paths to data with keys
                  {'mobility_matrix', 'population_size', 'commute_volume',
                  'geopackage', 'tier_restriction_csv'}
    :param date_low: lower bound of the analysis window, passed to the
                     commute-volume and tier-restriction readers
    :param date_high: upper bound of the analysis window, passed to the
                      same readers
    :returns: a dictionary of covariate arrays cast to DTYPE:
              {'C': mobility_matrix, 'W': commute_volume,
               'N': population_size, 'L': tier_restriction}
    """
    mobility_df = data.read_mobility(paths["mobility_matrix"])
    population_df = data.read_population(paths["population_size"])
    traffic_df = data.read_traffic_flow(
        paths["commute_volume"], date_low=date_low, date_high=date_high
    )

    # Keep only local authority districts whose `lad19cd` code begins with
    # "E" (presumably the English LADs) for the tier-restriction reader.
    lads = gp.read_file(paths["geopackage"])
    selected = lads["lad19cd"].str.startswith("E")
    lads = lads.loc[selected]
    tiers = data.read_tier_restriction_data(
        paths["tier_restriction_csv"],
        lads[["lad19cd", "lad19nm"]],
        date_low,
        date_high,
    )

    def as_model_array(frame):
        # Pandas object -> numpy array in the model's float type.
        return frame.to_numpy().astype(DTYPE)

    return {
        "C": as_model_array(mobility_df),
        "W": as_model_array(traffic_df),
        "N": as_model_array(population_df),
        "L": tiers.astype(DTYPE),
    }


Chris Jewell's avatar
Chris Jewell committed
51
52
53
54
def impute_censored_events(cases, ei_prob=0.21, se_prob=0.28):
    """Imputes censored S->E and E->I events using the geometric
    sampling algorithm in `impute_previous_cases`.

    The default values of `ei_prob` and `se_prob` are application-specific
    magic numbers.  They were chosen by experimentation and examination of
    the resulting epidemic trajectories to get the right lag between EI and
    IR events, and between SE and EI events, respectively.  They are exposed
    as keyword arguments so they can be tuned without editing this function.

    :param cases: a MxT matrix of case numbers (I->R)
    :param ei_prob: value passed to `impute_previous_cases` when imputing
                    E->I events from the I->R cases (default 0.21)
    :param se_prob: value passed to `impute_previous_cases` when imputing
                    S->E events from the imputed E->I events (default 0.28)
    :returns: a MxTx3 tensor of events, ordered [S->E, E->I, I->R] along
              the right-most dimension, where the first two indices of the
              right-most dimension contain the imputed event times.  The
              time axis is extended at its start by `lag_ei + lag_se - 2`
              steps relative to `cases`.
    """
    ei_events, lag_ei = impute_previous_cases(cases, ei_prob)
    se_events, lag_se = impute_previous_cases(ei_events, se_prob)

    # Left-pad the later event series with zeros along the time axis so the
    # three series can be stacked.  Assumes `impute_previous_cases` extends
    # its input by `lag - 1` steps -- TODO confirm against its definition.
    ir_events = np.pad(cases, ((0, 0), (lag_ei + lag_se - 2, 0)))
    ei_events = np.pad(ei_events, ((0, 0), (lag_se - 1, 0)))
    return tf.stack([se_events, ei_events, ir_events], axis=-1)


72
def CovidUK(covariates, initial_state, initial_step, num_steps, priors):
    """Builds the COVID-19 metapopulation SEIR model as a joint distribution.

    :param covariates: dictionary of covariate arrays with keys 'C'
                       (mobility matrix), 'W' (commute volume by time step),
                       'N' (population sizes) and 'L' (tier restrictions),
                       as returned by `read_covariates`.
    :param initial_state: initial epidemic state for the state transition
                          model (one row per metapopulation; columns assumed
                          ordered S, E, I, R to match STOICHIOMETRY).
    :param initial_step: first time step of the simulated epidemic.
    :param num_steps: number of time steps to simulate.
    :param priors: dictionary of prior hyperparameters; only
                   priors["gamma"]["concentration"] and
                   priors["gamma"]["rate"] are read here.
    :returns: a `tfd.JointDistributionNamed` over beta1, beta2, beta3, xi,
              gamma and seir.

    Note: JointDistributionNamed wires dependencies by argument name, so the
    parameter names of the inner distribution-making functions are part of
    the model definition and must match the dictionary keys.
    """

    def beta1():
        # Mean of the baseline transmission process `xi`.
        return tfd.Normal(
            loc=tf.constant(0.0, dtype=DTYPE),
            scale=tf.constant(1000.0, dtype=DTYPE),
        )

    def beta2():
        # Weight of commuting-driven infection pressure.
        return tfd.Gamma(
            concentration=tf.constant(3.0, dtype=DTYPE),
            rate=tf.constant(10.0, dtype=DTYPE),
        )

    def beta3():
        # Two regression coefficients on the tier-restriction covariates.
        return tfd.Sample(
            tfd.Normal(
                loc=tf.constant(0.0, dtype=DTYPE),
                scale=tf.constant(100.0, dtype=DTYPE),
            ),
            sample_shape=2,
        )

    def xi(beta1):
        # Gaussian process over the log baseline transmission rate, with one
        # index point every XI_FREQ time steps and constant mean beta1.
        sigma = tf.constant(0.1, dtype=DTYPE)
        phi = tf.constant(24.0, dtype=DTYPE)
        kernel = tfp.math.psd_kernels.MaternThreeHalves(sigma, phi)
        idx_pts = tf.cast(tf.range(num_steps // XI_FREQ) * XI_FREQ, dtype=DTYPE)
        return tfd.GaussianProcess(
            kernel,
            mean_fn=lambda idx: beta1,
            index_points=idx_pts[:, tf.newaxis],
        )

    def gamma():
        # I->R rate, with user-supplied Gamma prior hyperparameters.
        return tfd.Gamma(
            concentration=tf.constant(
                priors["gamma"]["concentration"], dtype=DTYPE
            ),
            rate=tf.constant(priors["gamma"]["rate"], dtype=DTYPE),
        )

    def seir(beta2, beta3, xi, gamma):
        beta2 = tf.convert_to_tensor(beta2, DTYPE)
        beta3 = tf.convert_to_tensor(beta3, DTYPE)
        xi = tf.convert_to_tensor(xi, DTYPE)
        gamma = tf.convert_to_tensor(gamma, DTYPE)

        # The covariates do not depend on `t` or `state`, so they are
        # prepared once here rather than on every call of the rate function.
        # Tier restrictions are centred on their mean over the leading axis.
        L = tf.convert_to_tensor(covariates["L"], DTYPE)
        L = L - tf.reduce_mean(L, axis=0)

        # Symmetrised mobility matrix with a zeroed diagonal.
        C = tf.convert_to_tensor(covariates["C"], dtype=DTYPE)
        C = tf.linalg.set_diag(
            C + tf.transpose(C), tf.zeros(C.shape[0], dtype=DTYPE)
        )
        W = tf.constant(np.squeeze(covariates["W"]), dtype=DTYPE)
        N = tf.constant(np.squeeze(covariates["N"]), dtype=DTYPE)

        def transition_rate_fn(t, state):
            # Time-indexed covariates are clamped to their valid range.
            w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, W.shape[0] - 1)
            commute_volume = tf.gather(W, w_idx)
            xi_idx = tf.cast(
                tf.clip_by_value(t // XI_FREQ, 0, xi.shape[0] - 1),
                dtype=tf.int64,
            )
            xi_ = tf.gather(xi, xi_idx)

            L_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, L.shape[0] - 1)
            Lt = tf.gather(L, L_idx)
            xB = tf.linalg.matvec(Lt, beta3)

            # S->E rate: local infectives (column 2) plus commuting-weighted
            # prevalence elsewhere, scaled by the local transmission rate.
            infec_rate = tf.math.exp(xi_ + xB) * (
                state[..., 2]
                + beta2
                * commute_volume
                * tf.linalg.matvec(C, state[..., 2] / N)
            )
            # Per-capita rate; the tiny constant keeps the rate strictly
            # positive.  Vector of length nc.
            infec_rate = infec_rate / N + 0.000000001

            ei = tf.broadcast_to(
                [NU], shape=[state.shape[0]]
            )  # Vector of length nc
            ir = tf.broadcast_to(
                [gamma], shape=[state.shape[0]]
            )  # Vector of length nc

            return [infec_rate, ei, ir]

        return DiscreteTimeStateTransitionModel(
            transition_rates=transition_rate_fn,
            stoichiometry=STOICHIOMETRY,
            initial_state=initial_state,
            initial_step=initial_step,
            time_delta=TIME_DELTA,
            num_steps=num_steps,
        )

    return tfd.JointDistributionNamed(
        dict(
            beta1=beta1, beta2=beta2, beta3=beta3, xi=xi, gamma=gamma, seir=seir
        )
    )
176
177


Chris Jewell's avatar
Chris Jewell committed
178
def next_generation_matrix_fn(covar_data, param):
    r"""The next generation matrix gives the expected number of new
    infections in metapopulation i caused by a single infective in
    metapopulation j during a typical infectious period (1/gamma), i.e.

      \[ A_{ij} = S_i \beta_i ( \delta_{ij} + \beta_2 w_t C_{ij} / N_j ) / N_i / \gamma \]

    where \beta_i = \exp(\xi_t + L_{ti} \beta_3) is the transmission rate of
    the susceptible location i, matching the S->E rate in `CovidUK`.  C is
    the symmetrised mobility matrix with zero diagonal, and L is the
    tier-restriction covariate centred on its mean over the leading axis.

    :param covar_data: a dictionary of covariate data with keys 'C', 'W',
                       'N' and 'L' (as returned by `read_covariates`)
    :param param: a dictionary of parameters with keys 'xi', 'beta2',
                  'beta3' and 'gamma'
    :returns: a function taking arguments `t` and `state` giving the time and
              epidemic state (SEIR) for which the NGM is to be calculated.  This
              function in turn returns an MxM next generation matrix.
    """
    # The covariates do not depend on `t` or `state`; prepare them once.
    L = tf.convert_to_tensor(covar_data["L"], DTYPE)
    L = L - tf.reduce_mean(L, axis=0)

    C = tf.convert_to_tensor(covar_data["C"], dtype=DTYPE)
    C = tf.linalg.set_diag(
        C + tf.transpose(C), tf.zeros(C.shape[0], dtype=DTYPE)
    )

    # Squeeze exactly as the model does, so that N is a vector of population
    # sizes (a trailing unit axis would break N[tf.newaxis, :] below).
    W = tf.constant(np.squeeze(covar_data["W"]), dtype=DTYPE)
    N = tf.constant(np.squeeze(covar_data["N"]), dtype=DTYPE)

    def fn(t, state):
        w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, W.shape[0] - 1)
        commute_volume = tf.gather(W, w_idx)
        xi_idx = tf.cast(
            tf.clip_by_value(t // XI_FREQ, 0, param["xi"].shape[0] - 1),
            dtype=tf.int64,
        )
        xi = tf.gather(param["xi"], xi_idx)

        L_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, L.shape[0] - 1)
        Lt = tf.gather(L, L_idx)
        xB = tf.linalg.matvec(Lt, param["beta3"])

        # Per-metapopulation transmission rate beta_i.
        beta = tf.math.exp(xi + xB)

        # Rows index the infectee location i, so beta is broadcast along rows
        # (beta_i), as in the model's S->E rate; columns index the infector
        # location j, whose commuting contacts are scaled by 1/N_j.
        ngm = beta[:, tf.newaxis] * (
            tf.eye(C.shape[0], dtype=state.dtype)
            + param["beta2"] * commute_volume * C / N[tf.newaxis, :]
        )
        # Scale row i by S_i / (N_i * gamma).
        ngm = (
            ngm
            * state[..., 0][..., tf.newaxis]
            / (N[:, tf.newaxis] * param["gamma"])
        )
        return ngm

    return fn