summary.py 9.4 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
"""Calculate Rt given a posterior"""
Chris Jewell's avatar
Chris Jewell committed
2
3
import argparse
import os
Chris Jewell's avatar
Chris Jewell committed
4
5
6
import yaml
import h5py
import numpy as np
Chris Jewell's avatar
Chris Jewell committed
7
import pandas as pd
Chris Jewell's avatar
Chris Jewell committed
8
9
10
import geopandas as gp

import tensorflow as tf
11
from gemlib.util import compute_state
Chris Jewell's avatar
Chris Jewell committed
12

Chris Jewell's avatar
Chris Jewell committed
13
from covid.cli_arg_parse import cli_args
14
from covid.summary import (
Chris Jewell's avatar
Chris Jewell committed
15
16
17
18
19
20
21
22
23
    rayleigh_quotient,
    power_iteration,
)
from covid.summary import mean_and_ci

import model_spec

DTYPE = model_spec.DTYPE

Chris Jewell's avatar
Chris Jewell committed
24
GIS_TEMPLATE = "data/UK2019mod_pop.gpkg"
Chris Jewell's avatar
Chris Jewell committed
25

26

Chris Jewell's avatar
Chris Jewell committed
27
# Reproduction number calculation
Chris Jewell's avatar
Chris Jewell committed
28
def calc_R_it(param, events, init_state, covar_data, priors):
    """Evaluates the next generation matrix on the final inference day for
       each posterior draw.

    Despite the name, this returns whatever
    `model_spec.next_generation_matrix_fn` produces for each draw; reducing
    that to a reproduction number is left to the caller.

    :param param: a dict of batched posterior draws. The keys read here are
                  "beta1", "beta2", "beta3", "sigma", "xi" and "gamma0"; each
                  value carries the batch on its leading axis.
    :param events: a [B, M, T, X] batched events tensor
    :param init_state: the initial state of the epidemic at earliest inference date
    :param covar_data: the covariate data
    :param priors: no longer used by this function; kept in the signature so
                   that existing callers continue to work
    :return: the per-draw results of the next generation matrix function,
             stacked along a leading batch axis by `tf.vectorized_map`
    """

    def r_fn(args):
        # One posterior draw: the leading batch axis has been stripped.
        beta1_, beta2_, beta3_, sigma_, xi_, gamma0_, events_ = args

        # Index of the last time step present in this draw's events.
        t = events_.shape[-2] - 1
        state = compute_state(init_state, events_, model_spec.STOICHIOMETRY)
        state = tf.gather(state, t, axis=-2)  # State on final inference day

        par = dict(
            beta1=beta1_,
            beta2=beta2_,
            beta3=beta3_,
            sigma=sigma_,
            gamma0=gamma0_,
            xi=xi_,
        )
        ngm_fn = model_spec.next_generation_matrix_fn(covar_data, par)
        return ngm_fn(t, state)

    return tf.vectorized_map(
        r_fn,
        elems=(
            param["beta1"],
            param["beta2"],
            param["beta3"],
            param["sigma"],
            param["xi"],
            param["gamma0"],
            events,
        ),
    )
Chris Jewell's avatar
Chris Jewell committed
85
86
87


@tf.function
def predicted_incidence(param, init_state, init_step, num_steps, priors):
    """Runs the simulation forward in time from `init_state` at time `init_step`
       for `num_steps`, once per posterior draw.

    NOTE: `covar_data` is not a parameter. It is resolved as a module-level
    name when the function runs, and this file only defines it inside the
    `__main__` section, so the function cannot be called before that
    assignment has happened.

    :param param: a dict of batched posterior draws. The keys read here are
                  "beta1", "beta2", "beta3", "sigma", "xi", "gamma0" and
                  "gamma1"; each value carries the batch on its leading axis.
    :param init_state: the batched state to start each simulation from, one
                       entry per draw (assumed [B, M, S] -- TODO confirm)
    :param init_step: the initial time step
    :param num_steps: the number of steps to simulate
    :param priors: the priors handed to `model_spec.CovidUK`
    :returns: a tensor of shape [B, M, num_steps, X] where X is the number of state
              transitions
    """

    def sim_fn(args):
        # One posterior draw: the leading batch axis has been stripped.
        beta1_, beta2_, beta3_, sigma_, xi_, gamma0_, gamma1_, init_ = args

        # sigma is forwarded together with the other draws, as calc_R_it does.
        # Presumably the model would otherwise draw its own value for it
        # instead of using the posterior sample -- verify against model_spec.
        par = dict(
            beta1=beta1_,
            beta2=beta2_,
            beta3=beta3_,
            sigma=sigma_,
            gamma0=gamma0_,
            gamma1=gamma1_,
            xi=xi_,
        )

        model = model_spec.CovidUK(
            covar_data,
            initial_state=init_,
            initial_step=init_step,
            num_steps=num_steps,
            priors=priors,
        )
        sim = model.sample(**par)
        return sim["seir"]

    events = tf.map_fn(
        sim_fn,
        elems=(
            param["beta1"],
            param["beta2"],
            param["beta3"],
            param["sigma"],
            param["xi"],
            param["gamma0"],
            param["gamma1"],
            init_state,
        ),
        fn_output_signature=(tf.float64),
    )
    return events


# Today's prevalence
def prevalence(predicted_state, population_size, name=None):
    """Summarises the fraction of each population in state columns 1 and 2
       (documented in this file as the E and I compartments).

    :param predicted_state: the state at a particular timepoint [batch, M, S]
    :param population_size: the size of the population; singleton axes are
                            squeezed away before dividing
    :param name: passed through unchanged to `mean_and_ci`
    :returns: the result of `mean_and_ci` applied to the per-person
              prevalence (a dict of mean and 95% credibility intervals,
              per the original documentation)
    """
    # Columns 1:3 of the last axis are added together for each batch/region.
    infected = tf.reduce_sum(predicted_state[:, :, 1:3], axis=-1)
    per_person = infected / tf.squeeze(population_size)
    return mean_and_ci(per_person, name=name)


def predicted_events(events, name=None):
    """Sums `events` over its last axis and summarises the totals.

    :param events: a tensor of event counts; the last axis is summed out
    :param name: passed through unchanged to `mean_and_ci`
    :returns: the result of `mean_and_ci` applied to the summed counts
    """
    return mean_and_ci(tf.reduce_sum(events, axis=-1), name=name)


if __name__ == "__main__":

    args = cli_args()

    # Get general config
    # NOTE(review): yaml.FullLoader can build more than plain data types.
    # If the config file can ever come from an untrusted source, switch to
    # yaml.safe_load after confirming the config uses no custom tags.
    with open(args.config, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    inference_period = [
        np.datetime64(x) for x in config["settings"]["inference_period"]
    ]

    # Load covariate data
    covar_data = model_spec.read_covariates(
        config["data"],
        date_low=inference_period[0],
        date_high=inference_period[1],
    )

    # Load posterior file
    posterior_path = os.path.join(
        config["output"]["results_dir"], config["output"]["posterior"]
    )
    print("Using posterior:", posterior_path)

    # Pre-determined thinning of posterior (better done in MCMC?)
    idx = range(6000, 10000, 10)

    # Everything needed is copied into memory inside the `with` block so the
    # HDF5 file is closed again as soon as the reads are done.
    with h5py.File(
        os.path.expandvars(posterior_path),
        "r",
        rdcc_nbytes=1024 ** 3,
        rdcc_nslots=10 ** 6,  # a slot count, so pass an int rather than 1e6
    ) as posterior:
        param = {
            key: posterior[f"samples/{key}"][idx]
            for key in (
                "beta1",
                "beta2",
                "beta3",
                "sigma",
                "xi",
                "gamma0",
                "gamma1",
            )
        }
        events = posterior["samples/events"][idx]
        init_state = posterior["initial_state"][:]

    state_timeseries = compute_state(
        init_state, events, model_spec.STOICHIOMETRY
    )

    ngms = calc_R_it(
        param, events, init_state, covar_data, config["mcmc"]["prior"]
    )
    b, _ = power_iteration(ngms)
    rt = rayleigh_quotient(ngms, b)
    q = np.arange(0.05, 1.0, 0.05)
    # Build the table first and write it second: to_excel returns None, so
    # chaining it onto the constructor would leave rt_quantiles as None.
    rt_quantiles = pd.DataFrame(
        {"Rt": np.quantile(rt, q, axis=-1)}, index=q
    ).T
    rt_quantiles.to_excel(
        os.path.join(
            config["output"]["results_dir"], config["output"]["national_rt"]
        ),
    )

    # Prediction simulates 70 steps from the last available timepoint; the
    # longest summary below reads index lag + 56 of the predicted state.
    # Note a 4 day recording lag in the case timeseries data requires that
    # now = state_timeseries.shape[-2] + 4
    prediction = predicted_incidence(
        param,
        init_state=state_timeseries[..., -1, :],
        init_step=state_timeseries.shape[-2] - 1,
        num_steps=70,
        priors=config["mcmc"]["prior"],
    )
    predicted_state = compute_state(
        state_timeseries[..., -1, :], prediction, model_spec.STOICHIOMETRY
    )

    lag = 4  # offset of "now" within the prediction (recording lag, in days)
    horizons = (7, 14, 21, 28, 56)  # days ahead of "now" to summarise

    # Prevalence now
    prev_now = prevalence(
        predicted_state[..., lag, :], covar_data["N"], name="prev"
    )

    # Incidence of detections now (transition index 2 of the prediction)
    cases_now = predicted_events(
        prediction[..., lag : lag + 1, 2], name="cases"
    )

    # Cumulative incidence from now to now + h for each horizon h
    cases_ahead = [
        predicted_events(prediction[..., lag : lag + h, 2], name=f"cases{h}")
        for h in horizons
    ]

    # Prevalence at day now + h for each horizon h
    prev_ahead = [
        prevalence(
            predicted_state[..., lag + h, :], covar_data["N"], name=f"prev{h}"
        )
        for h in horizons
    ]

    def geosummary(geodata, summaries):
        """Writes every key/value pair of each summary dict into `geodata`
        as a column, converting TensorFlow tensors to numpy arrays first."""
        for summary in summaries:
            for k, v in summary.items():
                arr = v
                if isinstance(v, tf.Tensor):
                    arr = v.numpy()
                geodata[k] = arr

    ## GIS here
    ltla = gp.read_file(GIS_TEMPLATE, layer="UK2019mod_pop_xgen")
    ltla = ltla[ltla["lad19cd"].str.startswith("E")]  # England only, for now.
    ltla = ltla.sort_values("lad19cd")
    rti = tf.reduce_sum(ngms, axis=-2)

    # Same column order as before: Rt, current values, then all prevalence
    # horizons followed by all case horizons.
    geosummary(
        ltla,
        (
            mean_and_ci(rti, name="Rt"),
            prev_now,
            cases_now,
            *prev_ahead,
            *cases_ahead,
        ),
    )

    # Fraction of posterior draws in which the per-area value exceeds 1.
    ltla["Rt_exceed"] = np.mean(rti > 1.0, axis=0)
    ltla = ltla.loc[
        :,
        ltla.columns.str.contains(
            "(lad19cd|lad19nm$|prev|cases|Rt|popsize|geometry)", regex=True
        ),
    ]
    ltla.to_file(
        os.path.join(
            config["output"]["results_dir"], config["output"]["geopackage"]
        ),
        driver="GPKG",
    )