# next_generation_matrix.py
"""Calculates posterior reproduction numbers (R_it, R_t) from the next generation matrix and saves them"""

import pickle as pkl
4
import numpy as np
5
import xarray
6
7
8
import tensorflow as tf

from covid import model_spec
9
from covid.util import copy_nc_attrs
10
11
12
from gemlib.util import compute_state


def calc_posterior_rit(samples, initial_state, times, covar_data):
    """Calculates per-location reproduction numbers for a batch of posterior samples.

    For each posterior sample, the epidemic state trajectory is reconstructed
    from ``initial_state`` and the sample's ``"seir"`` events, the next
    generation matrix (NGM) is evaluated at every time index in ``times``,
    and the NGM is summed over its destination axis.

    :param samples: a dict of batched posterior samples; every entry shares a
        leading batch dimension [B].  Must contain the key ``"seir"`` (the
        events passed to ``compute_state``); all remaining entries are
        forwarded as parameters to ``model_spec.next_generation_matrix_fn``.
    :param initial_state: the initial state of the epidemic at the earliest
        inference date (not batched; used for every sample)
    :param times: a 1-D sequence of integer time indices at which to evaluate
        the NGM
    :param covar_data: the covariate data passed to
        ``model_spec.next_generation_matrix_fn``
    :return: a batched tensor of R_it estimates with one leading entry per
        posterior sample and one entry per element of ``times``
        (presumably [B, len(times), M] -- confirm against ``model_spec``)
    """
    times = tf.convert_to_tensor(times)

    def r_fn(args):
        # tf.vectorized_map hands us a flat list of per-sample tensors;
        # rebuild the dict structure of `samples` from it.
        par = tf.nest.pack_sequence_as(samples, args)

        # Reconstruct this sample's full state trajectory, then drop the
        # events so that only model parameters reach the NGM constructor.
        state = compute_state(
            initial_state, par["seir"], model_spec.STOICHIOMETRY
        )
        del par["seir"]

        # The NGM function depends on the sample's parameters but not on the
        # time index, so build it once per sample rather than once per time.
        ngm_fn = model_spec.next_generation_matrix_fn(covar_data, par)

        def fn(t):
            state_t = tf.gather(state, t, axis=-2)  # state at time index t
            return ngm_fn(t, state_t)

        ngm = tf.vectorized_map(fn, elems=times)
        return tf.reduce_sum(ngm, axis=-2)  # sum over destinations

    return tf.vectorized_map(r_fn, elems=tf.nest.flatten(samples))


50
51
52
CHUNKSIZE = 50  # posterior samples per calc_posterior_rit call (bounds memory)


def reproduction_number(input_files, output_file):
    """Computes posterior R_it and R_t and writes them to a netCDF file.

    :param input_files: a two-element sequence ``[data_file, samples_file]``.
        ``data_file`` is a dataset readable by ``xarray.open_dataset`` whose
        ``constant_data`` group provides ``time`` and ``location``
        coordinates and a variable ``N`` used for population weighting.
        ``samples_file`` is a pickle holding a dict of posterior samples
        containing at least the keys ``"seir"`` and ``"initial_state"``.
    :param output_file: path of the netCDF file to write.  ``R_it``
        (dims iteration, time, location) and ``R_t`` (``R_it`` averaged over
        locations with weights proportional to ``N``) are written to its
        ``posterior_predictive`` group, after which ``copy_nc_attrs`` is
        applied from ``data_file`` to ``output_file``.
    :raises ValueError: if the samples file contains no posterior samples.
    """
    # Use a context manager so the input dataset is closed once the output
    # has been written (the original handle was never closed).
    with xarray.open_dataset(input_files[0], group="constant_data") as covar_data:
        # NOTE: pickle.load can execute arbitrary code -- only load sample
        # files from a trusted source (e.g. this pipeline's own MCMC output).
        with open(input_files[1], "rb") as f:
            samples = pkl.load(f)

        num_samples = samples["seir"].shape[0]
        if num_samples == 0:
            raise ValueError(f"No posterior samples found in {input_files[1]}")

        # The initial state is passed whole to every chunk rather than being
        # sliced per sample like the remaining entries.
        initial_state = samples.pop("initial_state")

        times = np.arange(covar_data.coords["time"].shape[0])

        # Compute the NGM posterior in chunks to bound peak memory use.
        r_its = []
        for start in range(0, num_samples, CHUNKSIZE):
            end = min(start + CHUNKSIZE, num_samples)
            print(f"Chunk {start}:{end}", flush=True)
            subsamples = {k: v[start:end] for k, v in samples.items()}
            r_its.append(
                calc_posterior_rit(subsamples, initial_state, times, covar_data)
            )

        r_it = xarray.DataArray(
            tf.concat(r_its, axis=0),
            coords=[
                np.arange(num_samples),
                covar_data.coords["time"][times],
                covar_data.coords["location"],
            ],
            dims=["iteration", "time", "location"],
        )

        # Population-weighted average over locations.
        weight = covar_data["N"] / covar_data["N"].sum()
        r_t = (r_it * weight).sum(dim="location")
        ds = xarray.Dataset({"R_it": r_it, "R_t": r_t})

        # Output
        ds.to_netcdf(output_file, group="posterior_predictive")

    copy_nc_attrs(input_files[0], output_file)


if __name__ == "__main__":
    # Command-line entry point: compute posterior reproduction numbers from
    # an MCMC samples pickle and a covariate dataset.
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument(
        "samples",
        type=str,
        help="A pickle file with MCMC samples",
    )
    parser.add_argument(
        "-d",
        "--data",
        type=str,
        # reproduction_number opens this with xarray.open_dataset(
        # group="constant_data"), so it is a netCDF dataset, not a pickle.
        help="A netCDF data file containing a 'constant_data' group",
        required=True,
    )
    parser.add_argument(
        "-o", "--output", type=str, help="The output file", required=True
    )
    args = parser.parse_args()

    reproduction_number([args.data, args.samples], args.output)