summary.py 7.94 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
"""Calculate Rt given a posterior"""
Chris Jewell's avatar
Chris Jewell committed
2
3
import argparse
import os
Chris Jewell's avatar
Chris Jewell committed
4
5
6
import yaml
import h5py
import numpy as np
Chris Jewell's avatar
Chris Jewell committed
7
import pandas as pd
Chris Jewell's avatar
Chris Jewell committed
8
9
10
import geopandas as gp

import tensorflow as tf
11
from gemlib.util import compute_state
Chris Jewell's avatar
Chris Jewell committed
12

Chris Jewell's avatar
Chris Jewell committed
13
from covid.cli_arg_parse import cli_args
14
from covid.summary import (
Chris Jewell's avatar
Chris Jewell committed
15
16
17
18
19
20
21
22
23
    rayleigh_quotient,
    power_iteration,
)
from covid.summary import mean_and_ci

import model_spec

# Numeric dtype re-exported from the model specification module.
DTYPE = model_spec.DTYPE

# GeoPackage template read in the __main__ block; its "UK2019mod_pop_xgen"
# layer supplies the LTLA geometries that the per-area summaries are
# attached to before being written out.
GIS_TEMPLATE = "data/UK2019mod_pop.gpkg"


# Reproduction number calculation
def calc_R_it(theta, xi, events, init_state, covar_data):
    """Computes next-generation matrices for batches of metapopulations.

    For each posterior sample, the epidemic state is reconstructed from
    `init_state` and that sample's `events`, and the model's next-generation
    matrix is evaluated near the end of the inference period. Reproduction
    numbers are obtained by reducing these matrices downstream (e.g. via
    power iteration / Rayleigh quotient, or by summing over the last axis).

    :param theta: a tensor of batched theta parameters [B] + theta.shape
    :param xi: a tensor of batched xi parameters [B] + xi.shape
    :param events: a [B, M, T, X] batched events tensor
    :param init_state: the initial state of the epidemic at earliest inference date
    :param covar_data: the covariate data
    :return: a batched tensor of next-generation matrices, one per sample
    """

    def r_fn(args):
        theta_, xi_, events_ = args
        # Index of the last time step along the time axis (axis -2).
        t = events_.shape[-2] - 1
        state = compute_state(init_state, events_, model_spec.STOICHIOMETRY)
        # NOTE(review): this takes the state at index t - 1 while the NGM
        # below is evaluated at time t; confirm this alignment against
        # compute_state's time convention.
        state = tf.gather(state, t - 1, axis=-2)

        # Unpack the flat posterior vectors into named model parameters.
        par = dict(
            beta1=xi_[0],
            beta2=theta_[0],
            beta3=xi_[1:3],
            gamma=theta_[1],
            xi=xi_[3:],
        )

        ngm_fn = model_spec.next_generation_matrix_fn(covar_data, par)
        return ngm_fn(t, state)

    return tf.vectorized_map(r_fn, elems=(theta, xi, events))


@tf.function
def predicted_incidence(theta, xi, init_state, init_step, num_steps, priors):
    """Runs the simulation forward in time from `init_state` at time `init_step`
       for `num_steps`.

    The model is built from the module-level `covar_data`, which is only
    defined when this file is run as a script (in the `__main__` block).

    :param theta: a tensor of batched theta parameters [B] + theta.shape
    :param xi: a tensor of batched xi parameters [B] + xi.shape
    :param init_state: a [B, M, S] batched state tensor
    :param init_step: the initial time step
    :param num_steps: the number of steps to simulate
    :param priors: the prior configuration passed to `model_spec.CovidUK`
    :returns: a tensor of shape [B, M, num_steps, X] where X is the number of state
              transitions
    """

    def sim_fn(args):
        theta_, xi_, init_ = args

        # Unpack the flat posterior vectors exactly as calc_R_it does, so the
        # forecast uses the same parameter layout as the Rt calculation
        # (beta3 occupies xi_[1:3]; the remaining xi_ entries are `xi`).
        par = dict(
            beta1=xi_[0],
            beta2=theta_[0],
            beta3=xi_[1:3],
            gamma=theta_[1],
            xi=xi_[3:],
        )

        model = model_spec.CovidUK(
            covar_data,
            initial_state=init_,
            initial_step=init_step,
            num_steps=num_steps,
            priors=priors,
        )
        sim = model.sample(**par)
        return sim["seir"]

    # One forward simulation per posterior sample along the batch axis.
    return tf.map_fn(
        sim_fn,
        elems=(theta, xi, init_state),
        fn_output_signature=tf.float64,
    )


# Prevalence of E and I individuals at a single time point
def prevalence(predicted_state, population_size, name=None):
    """Computes prevalence of E and I individuals

    :param predicted_state: the state at a particular timepoint [batch, M, S];
        the state components at indices 1 and 2 (E and I) are summed
    :param population_size: the size of the population (presumably one entry
        per unit M; it is squeezed before dividing)
    :param name: optional name passed through to `mean_and_ci`
    :returns: a dict of mean and 95% credibility intervals for prevalence
              in units of infections per person
    """
    # `...` keeps any number of leading batch axes; only the state axis is
    # sliced, so a [batch, M, S] input behaves exactly as before.
    infected = tf.reduce_sum(predicted_state[..., 1:3], axis=-1)
    return mean_and_ci(infected / tf.squeeze(population_size), name=name)


def predicted_events(events, name=None):
    """Summarises event counts after summing out their last axis.

    :param events: a tensor of event counts; the last axis (in this script,
        a window of time steps for a single transition) is summed out
    :param name: optional name passed through to `mean_and_ci`
    :returns: the `mean_and_ci` summary of the summed counts
    """
    return mean_and_ci(tf.reduce_sum(events, axis=-1), name=name)


if __name__ == "__main__":

    # Recording lag (days) in the case timeseries: "now" lies this many steps
    # after the end of the inference period, i.e.
    # now = state_timeseries.shape[-2] + RECORDING_LAG.
    RECORDING_LAG = 4
    # Forecast horizons, in days after "now", at which to summarise.
    HORIZONS = (7, 14, 21, 28, 56)
    # Steps simulated forward; must exceed RECORDING_LAG + max(HORIZONS).
    NUM_PREDICTION_STEPS = 70
    # Pre-determined thinning of posterior (better done in MCMC?)
    THIN = slice(6000, 10000, 10)

    args = cli_args()

    # Get general config (plain YAML, so the safe loader suffices)
    with open(args.config, "r") as f:
        config = yaml.safe_load(f)

    def output_path(key):
        """Returns results_dir joined with config["output"][key], with
        environment variables expanded."""
        return os.path.expandvars(
            os.path.join(
                config["output"]["results_dir"], config["output"][key]
            )
        )

    inference_period = [
        np.datetime64(x) for x in config["settings"]["inference_period"]
    ]

    # Load covariate data.  NB: `predicted_incidence` reads `covar_data` as a
    # module-level global, so it must remain at this scope.
    covar_data = model_spec.read_covariates(
        config["data"],
        date_low=inference_period[0],
        date_high=inference_period[1],
    )

    # Load the thinned posterior into memory; the file is closed on exit from
    # the `with` block because every read below materialises an array.
    with h5py.File(
        output_path("posterior"),
        "r",
        rdcc_nbytes=1024 ** 3,
        rdcc_nslots=int(1e6),  # slot count as an integer
    ) as posterior:
        theta = posterior["samples/theta"][THIN]
        xi = posterior["samples/xi"][THIN]
        events = posterior["samples/events"][THIN]
        init_state = posterior["initial_state"][:]

    state_timeseries = compute_state(
        init_state, events, model_spec.STOICHIOMETRY
    )

    # National Rt: Rayleigh quotient of each sample's next-generation matrix
    # with the vector from power iteration (a dominant-eigenvalue estimate).
    ngms = calc_R_it(theta, xi, events, init_state, covar_data)
    b, _ = power_iteration(ngms)
    rt = rayleigh_quotient(ngms, b)
    q = np.arange(0.05, 1.0, 0.05)
    pd.DataFrame({"Rt": np.quantile(rt, q)}, index=q).T.to_excel(
        output_path("national_rt")
    )

    # Simulate forward from the last inferred state for NUM_PREDICTION_STEPS
    # days; prediction index RECORDING_LAG corresponds to "now".
    prediction = predicted_incidence(
        theta,
        xi,
        init_state=state_timeseries[..., -1, :],
        init_step=state_timeseries.shape[-2] - 1,
        num_steps=NUM_PREDICTION_STEPS,
        priors=config["mcmc"]["prior"],
    )
    predicted_state = compute_state(
        state_timeseries[..., -1, :], prediction, model_spec.STOICHIOMETRY
    )

    now = RECORDING_LAG

    # Prevalence and incidence of detections (transition index 2) now
    prev_now = prevalence(
        predicted_state[..., now, :], covar_data["N"], name="prev"
    )
    cases_now = predicted_events(prediction[..., now : now + 1, 2], name="cases")

    # Prevalence at now + h, and cumulative detections over [now, now + h)
    horizon_prev = [
        prevalence(
            predicted_state[..., now + h, :],
            covar_data["N"],
            name="prev{}".format(h),
        )
        for h in HORIZONS
    ]
    horizon_cases = [
        predicted_events(
            prediction[..., now : now + h, 2], name="cases{}".format(h)
        )
        for h in HORIZONS
    ]

    def geosummary(geodata, summaries):
        """Adds every key/value of each summary dict as a column of geodata."""
        for summary in summaries:
            for k, v in summary.items():
                geodata[k] = v.numpy() if isinstance(v, tf.Tensor) else v

    ## GIS here
    ltla = gp.read_file(GIS_TEMPLATE, layer="UK2019mod_pop_xgen")
    ltla = ltla[ltla["lad19cd"].str.startswith("E")]  # England only, for now.
    ltla = ltla.sort_values("lad19cd")
    # Per-area Rt: sum of each next-generation matrix over its last axis
    rti = tf.reduce_sum(ngms, axis=-1)

    # Column order: Rt, prev, cases, prev7..prev56, cases7..cases56
    geosummary(
        ltla,
        [mean_and_ci(rti, name="Rt"), prev_now, cases_now]
        + horizon_prev
        + horizon_cases,
    )

    # Fraction of posterior samples in which each area's Rt exceeds 1
    ltla["Rt_exceed"] = np.mean(rti > 1.0, axis=0)
    ltla = ltla.loc[
        :,
        ltla.columns.str.contains(
            "(lad19cd|lad19nm$|prev|cases|Rt|popsize|geometry)", regex=True
        ),
    ]
    ltla.to_file(output_path("geopackage"), driver="GPKG")