model_spec.py 6.62 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
2
"""Implements the COVID SEIR model as a TFP Joint Distribution"""

3
4
import pandas as pd
import numpy as np
Chris Jewell's avatar
Chris Jewell committed
5
6
7
8
import tensorflow as tf
import tensorflow_probability as tfp

from covid.model import DiscreteTimeStateTransitionModel
Chris Jewell's avatar
Chris Jewell committed
9
from covid.util import impute_previous_cases
Chris Jewell's avatar
Chris Jewell committed
10
11

# Short alias for the TensorFlow Probability distributions namespace.
tfd = tfp.distributions
Chris Jewell's avatar
Chris Jewell committed
12
# Floating-point type used in this module when casting covariates and when
# building constant and parameter tensors.
DTYPE = np.float64
Chris Jewell's avatar
Chris Jewell committed
13
14
15

# One row per transition and one column per compartment: each row removes one
# unit from a compartment (-1) and adds it to the next compartment along (+1).
# With the S, E, I, R ordering named in the module docstring the rows read
# S->E, E->I, I->R, matching the three rates returned by the transition rate
# function in `CovidUK`.
STOICHIOMETRY = tf.constant([[-1, 1, 0, 0], [0, -1, 1, 0], [0, 0, -1, 1]])
# Length of one time step; handed to the state transition model as
# `time_delta`.  Presumably measured in days (see XI_FREQ) -- TODO confirm.
TIME_DELTA = 1.0
Chris Jewell's avatar
Chris Jewell committed
16
17
18
# Spacing between the index points of the `xi` Gaussian process; at time t the
# model reads element `t // XI_FREQ` of `xi` (clipped to the valid range).
XI_FREQ = 14  # baseline transmission changes every 14 days
# Fixed (not inferred) rate used for every E->I transition.
NU = tf.constant(0.5, dtype=DTYPE)  # E->I rate assumed known.

Chris Jewell's avatar
Chris Jewell committed
19

20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def read_covariates(paths):
    """Reads the covariate CSV files and returns them as numeric arrays.

    Each file is read with its first column as the index, converted to a
    NumPy array and cast to ``DTYPE``.

    :param paths: a dictionary of paths to data with keys {'mobility_matrix',
                  'population_size', 'commute_volume'}
    :returns: a dictionary of covariate arrays to be consumed by the model,
              {'C': mobility matrix, 'W': commute volume,
               'N': population size}
    """
    # Read in a fixed order so a missing key or file fails at the same point
    # regardless of how the result dictionary is assembled below.
    frames = {
        key: pd.read_csv(paths[key], index_col=0)
        for key in ("mobility_matrix", "population_size", "commute_volume")
    }

    return {
        "C": frames["mobility_matrix"].to_numpy().astype(DTYPE),
        "W": frames["commute_volume"].to_numpy().astype(DTYPE),
        "N": frames["population_size"].to_numpy().astype(DTYPE),
    }


def read_cases(path):
    """Loads tidy case data from a CSV file and reshapes it to wide format.

    :param path: path to a CSV file with (at least) the columns 'lad19cd',
                 'date' and 'cases'
    :returns: a DataFrame with one row per 'lad19cd' value, one column per
              'date' value, and the corresponding 'cases' entries as values
    """
    long_format = pd.read_csv(path)
    return long_format.pivot(index="lad19cd", columns="date", values="cases")


Chris Jewell's avatar
Chris Jewell committed
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def impute_censored_events(cases):
    """Imputes censored S->E and E->I events using the geometric
       sampling algorithm in `impute_previous_cases`

    The numeric arguments passed to `impute_previous_cases` (0.44 and 2.0)
    and the `- 2` / `- 1` padding offsets are application-specific magic
    numbers.  They were chosen by experimentation and examination of the
    resulting epidemic trajectories, to get the right lag between EI and IR
    events, and between SE and EI events respectively.

    :param cases: a MxT matrix of case numbers (I->R)
    :returns: a MxTx3 tensor of events, stacked along the right-most
              dimension in the order (S->E, E->I, I->R); the first two
              entries of that dimension are the imputed events.
    """
    # E->I events are imputed from the observed cases; S->E events are then
    # imputed from the (unpadded) E->I events.
    ei_raw, ei_lag = impute_previous_cases(cases, 0.44)
    se_events, se_lag = impute_previous_cases(ei_raw, 2.0)

    # Left-pad the later event series along the time axis with zeros so all
    # three can be stacked (presumably aligning them on a common time origin
    # -- TODO confirm against `impute_previous_cases`).
    ir_padded = np.pad(cases, ((0, 0), (ei_lag + se_lag - 2, 0)))
    ei_padded = np.pad(ei_raw, ((0, 0), (se_lag - 1, 0)))

    return tf.stack([se_events, ei_padded, ir_padded], axis=-1)


67
def CovidUK(covariates, initial_state, initial_step, num_steps):
    """Builds the COVID SEIR model as a TFP named joint distribution.

    The joint distribution has five components: priors for ``beta1``,
    ``beta2``, ``xi`` and ``gamma``, and ``seir``, a
    ``DiscreteTimeStateTransitionModel`` conditional on those parameters.

    :param covariates: a dictionary with keys 'C', 'W' and 'N' (as returned
                       by `read_covariates`)
    :param initial_state: forwarded to the state transition model
    :param initial_step: forwarded to the state transition model
    :param num_steps: number of time steps; forwarded to the state
                      transition model and also used to size the index
                      points of the `xi` Gaussian process
    :returns: a `tfd.JointDistributionNamed` instance
    """

    def _gamma_prior(concentration, rate):
        # Gamma distribution with DTYPE-typed hyperparameters.
        return tfd.Gamma(
            concentration=tf.constant(concentration, dtype=DTYPE),
            rate=tf.constant(rate, dtype=DTYPE),
        )

    def beta1():
        return _gamma_prior(1.0, 1.0)

    def beta2():
        return _gamma_prior(3.0, 10.0)

    def xi():
        # Gaussian process prior evaluated every XI_FREQ time steps.
        amplitude = tf.constant(0.01, dtype=DTYPE)
        length_scale = tf.constant(24.0, dtype=DTYPE)
        kernel = tfp.math.psd_kernels.MaternThreeHalves(amplitude, length_scale)
        index_points = tf.cast(
            tf.range(num_steps // XI_FREQ) * XI_FREQ, dtype=DTYPE
        )
        return tfd.GaussianProcess(
            kernel, index_points=index_points[:, tf.newaxis]
        )

    def gamma():
        return _gamma_prior(100.0, 400.0)

    # NOTE: the argument names of `seir` must match the keys of the dictionary
    # given to JointDistributionNamed below; that is how the parameter draws
    # are wired into the state transition model.
    def seir(beta1, beta2, xi, gamma):
        beta1 = tf.convert_to_tensor(beta1, DTYPE)
        beta2 = tf.convert_to_tensor(beta2, DTYPE)
        xi = tf.convert_to_tensor(xi, DTYPE)
        gamma = tf.convert_to_tensor(gamma, DTYPE)

        def transition_rate_fn(t, state):
            # Symmetrise the mobility matrix and zero its diagonal.
            mobility = tf.convert_to_tensor(covariates["C"], dtype=DTYPE)
            mobility = tf.linalg.set_diag(
                mobility + tf.transpose(mobility),
                tf.zeros(mobility.shape[0], dtype=DTYPE),
            )
            commute_series = tf.constant(
                np.squeeze(covariates["W"]), dtype=DTYPE
            )
            popsize = tf.squeeze(
                tf.constant(np.squeeze(covariates["N"]), dtype=DTYPE)
            )

            # Commute volume at time t, clipped to the available series.
            w_idx = tf.clip_by_value(
                tf.cast(t, tf.int64), 0, commute_series.shape[0] - 1
            )
            commute_volume = tf.gather(commute_series, w_idx)

            # Element of xi in force at time t, clipped to the last element.
            xi_idx = tf.cast(
                tf.clip_by_value(t // XI_FREQ, 0, xi.shape[0] - 1),
                dtype=tf.int64,
            )
            beta = beta1 * tf.math.exp(tf.gather(xi, xi_idx))

            # Column 2 of the state (the I compartment under S, E, I, R
            # ordering) drives infection, both within a region and, via the
            # mobility matrix, between regions.
            infectious = state[..., 2]
            between_region = (
                beta2
                * commute_volume
                * tf.linalg.matvec(mobility, infectious / popsize)
            )
            infec_rate = beta * (infectious + between_region)
            # The small constant keeps the rate strictly positive.
            infec_rate = infec_rate / popsize + 1e-9  # Vector of length nc

            num_regions = state.shape[0]
            ei = tf.broadcast_to([NU], shape=[num_regions])  # Vector of length nc
            ir = tf.broadcast_to([gamma], shape=[num_regions])  # Vector of length nc

            return [infec_rate, ei, ir]

        return DiscreteTimeStateTransitionModel(
            transition_rates=transition_rate_fn,
            stoichiometry=STOICHIOMETRY,
            initial_state=initial_state,
            initial_step=initial_step,
            time_delta=TIME_DELTA,
            num_steps=num_steps,
        )

    return tfd.JointDistributionNamed(
        {"beta1": beta1, "beta2": beta2, "xi": xi, "gamma": gamma, "seir": seir}
    )
Chris Jewell's avatar
Chris Jewell committed
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178


def next_generation_matrix_fn(covar_data, param):
    r"""The next generation matrix calculates the force of infection from
       individuals in metapopulation i to all other metapopulations j during
       a typical infectious period (1/gamma). i.e.

         \[ A_{ij} = S_j * \beta_1 ( 1 + \beta_2 * w_t * C_{ij} / N_i) / N_j / gamma \]

       In the implementation the baseline rate is
       `param["beta1"] * exp(xi_t)`, where `xi_t` is the element of
       `param["xi"]` in force at time `t`, and the `1` applies to the
       diagonal (within-metapopulation) term only.

       :param covar_data: a dictionary of covariate data with keys 'C', 'W'
                          and 'N'
       :param param: a dictionary of parameters with keys 'beta1', 'beta2',
                     'xi' and 'gamma'
       :returns: a function taking arguments `t` and `state` giving the time and
                 epidemic state (SEIR) for which the NGM is to be calculated.  This
                 function in turn returns an MxM next generation matrix.
    """
    # The docstring is a raw string on purpose: it contains LaTeX backslashes
    # (`\[`, `\beta`) that must not be read as Python escape sequences.

    def fn(t, state):
        # Symmetrise the mobility matrix and zero its diagonal.
        C = tf.convert_to_tensor(covar_data["C"], dtype=DTYPE)
        C = tf.linalg.set_diag(C + tf.transpose(C), tf.zeros(C.shape[0], dtype=DTYPE))

        # Flatten W and N to 1-D.  `read_covariates` yields column-shaped
        # arrays; left 2-D, the `N[tf.newaxis, :]` / `N[:, tf.newaxis]`
        # indexing below would broadcast the result to rank 3 instead of MxM.
        W = tf.reshape(tf.constant(covar_data["W"], dtype=DTYPE), [-1])
        N = tf.reshape(tf.constant(covar_data["N"], dtype=DTYPE), [-1])

        # Commute volume at time t, clipped to the available series.
        w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, W.shape[0] - 1)
        commute_volume = tf.gather(W, w_idx)

        # Element of xi in force at time t, clipped to the last element.
        xi_idx = tf.cast(
            tf.clip_by_value(t // XI_FREQ, 0, param["xi"].shape[0] - 1), dtype=tf.int64,
        )
        xi = tf.gather(param["xi"], xi_idx)
        beta = param["beta1"] * tf.math.exp(xi)

        # NOTE(review): as implemented, the coupling term in entry [i, j] is
        # divided by N[j], and row i is scaled by state[i, 0] / N[i]; confirm
        # this index convention against the formula in the docstring.
        ngm = beta * (
            tf.eye(C.shape[0], dtype=state.dtype)
            + param["beta2"] * commute_volume * C / N[tf.newaxis, :]
        )
        ngm = ngm * state[..., 0][..., tf.newaxis] / (N[:, tf.newaxis] * param["gamma"])
        return ngm

    return fn