"""Run predictions for COVID-19 model."""

import numpy as np
import xarray
5
6
7
8
import pickle as pkl
import tensorflow as tf

from covid import model_spec
9
from covid.util import copy_nc_attrs
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from gemlib.util import compute_state


def predicted_incidence(posterior_samples, covar_data, init_step, num_steps):
    """Simulate the model forward from the posterior state at `init_step`.

    The posterior state trajectory is computed with ``compute_state`` from
    the posterior "init_state" and "seir" entries and
    ``model_spec.STOICHIOMETRY``. The state at time index `init_step` is then
    used as the initial state of a ``model_spec.CovidUK`` simulation of
    `num_steps` steps, run once per posterior draw via ``tf.map_fn``.

    :param posterior_samples: mapping of posterior draws with keys "beta1",
        "beta2", "sigma", "xi", "gamma0", "gamma1", "init_state" and "seir";
        ``tf.map_fn`` iterates over the leading axis of the parameter entries
        and of the derived initial state
    :param covar_data: model covariate data, passed to ``model_spec.CovidUK``
    :param init_step: time index on the posterior state trajectory from which
        to start simulating; also passed as the model's ``initial_step``
    :param num_steps: the number of steps to simulate
    :returns: a tuple ``(init_state, events)``. ``init_state`` is the
        posterior state at `init_step` for each draw. ``events`` is a float64
        tensor of simulated "seir" events, expected shape
        [B, M, num_steps, X] where X is the number of state transitions.
    """

    @tf.function
    def sim_fn(args):
        # One posterior draw: model parameters plus the state at init_step.
        beta1_, beta2_, sigma_, xi_, gamma0_, gamma1_, init_ = args

        par = dict(
            beta1=beta1_,
            beta2=beta2_,
            sigma=sigma_,
            xi=xi_,
            gamma0=gamma0_,
            gamma1=gamma1_,
        )
        model = model_spec.CovidUK(
            covar_data,
            initial_state=init_,
            initial_step=init_step,
            num_steps=num_steps,
        )
        sim = model.sample(**par)
        return sim["seir"]

    posterior_state = compute_state(
        posterior_samples["init_state"],
        posterior_samples["seir"],
        model_spec.STOICHIOMETRY,
    )
    # Select time index `init_step` (second-to-last axis) for every draw.
    init_state = posterior_state[..., init_step, :]

    events = tf.map_fn(
        sim_fn,
        elems=(
            posterior_samples["beta1"],
            posterior_samples["beta2"],
            posterior_samples["sigma"],
            posterior_samples["xi"],
            posterior_samples["gamma0"],
            posterior_samples["gamma1"],
            init_state,
        ),
        fn_output_signature=tf.float64,
    )
    return init_state, events


def read_pkl(filename):
    """Load and return the single object pickled in `filename`.

    Note: unpickling can execute arbitrary code, so only call this on
    files from a trusted source.
    """
    with open(filename, mode="rb") as handle:
        loaded = pkl.load(handle)
    return loaded


def predict(data, posterior_samples, output_file, initial_step, num_steps):
    """Simulate forward from posterior samples and write the predictions.

    :param data: path to a netCDF file with "constant_data" (covariates) and
        "observations" groups
    :param posterior_samples: path to a pickle of posterior samples, as
        consumed by ``predicted_incidence``
    :param output_file: path of the netCDF file to write; the results are
        written to its "predictions" group, after which
        ``copy_nc_attrs(data, output_file)`` is called
    :param initial_step: time step at which to start the simulation; a
        negative value counts back from the end of the posterior "seir" time
        axis (its second-to-last axis)
    :param num_steps: number of daily time steps to simulate
    :raises ValueError: if a negative `initial_step` reaches back before the
        start of the posterior time axis
    """
    # Only the first observation date is needed; np.array materialises it
    # before the file handle is closed.
    with xarray.open_dataset(data, group="observations") as cases:
        origin_date = np.array(cases.coords["time"][0])

    samples = read_pkl(posterior_samples)

    if initial_step < 0:
        num_times = samples["seir"].shape[-2]
        if initial_step < -num_times:
            raise ValueError(
                f"initial_step={initial_step} reaches before the start of "
                f"the posterior, which has {num_times} time steps"
            )
        initial_step += num_times

    dates = np.arange(
        origin_date + np.timedelta64(initial_step, "D"),
        origin_date + np.timedelta64(initial_step + num_steps, "D"),
        np.timedelta64(1, "D"),
    )

    # Keep the covariate file open for the whole simulation and write,
    # since the model and the location coordinate are built from it.
    with xarray.open_dataset(data, group="constant_data") as covar_data:
        estimated_init_state, predicted_events = predicted_incidence(
            samples, covar_data, initial_step, num_steps
        )

        prediction = xarray.DataArray(
            predicted_events,
            coords=[
                np.arange(predicted_events.shape[0]),
                covar_data.coords["location"],
                dates,
                np.arange(predicted_events.shape[3]),
            ],
            dims=("iteration", "location", "time", "event"),
        )
        estimated_init_state = xarray.DataArray(
            estimated_init_state,
            coords=[
                np.arange(estimated_init_state.shape[0]),
                covar_data.coords["location"],
                np.arange(estimated_init_state.shape[-1]),
            ],
            dims=("iteration", "location", "state"),
        )
        ds = xarray.Dataset(
            {"events": prediction, "initial_state": estimated_init_state}
        )
        ds.to_netcdf(output_file, group="predictions")
        ds.close()

    copy_nc_attrs(data, output_file)


if __name__ == "__main__":

    from argparse import ArgumentParser

    # argparse takes per-argument text via `help=`; `description=` is only
    # valid on ArgumentParser itself and makes add_argument raise TypeError.
    parser = ArgumentParser(description="Run predictions for COVID-19 model")
    parser.add_argument(
        "-i",
        "--initial-step",
        type=int,
        default=0,
        help="Initial step (negative values count back from the end of "
        "the posterior)",
    )
    parser.add_argument(
        "-n", "--num-steps", type=int, default=1, help="Number of steps"
    )
    parser.add_argument(
        "data_pkl", type=str, help="Covariate data netCDF file"
    )
    parser.add_argument(
        "posterior_samples_pkl",
        type=str,
        help="Posterior samples pickle",
    )
    parser.add_argument(
        "output_file",
        type=str,
        help="Output netCDF file",
    )
    args = parser.parse_args()

    predict(
        args.data_pkl,
        args.posterior_samples_pkl,
        args.output_file,
        args.initial_step,
        args.num_steps,
    )