"""Run predictions for COVID-19 model"""

import pickle as pkl

import numpy as np
import pandas as pd
import tensorflow as tf
import xarray

from covid import model_spec
from covid.util import copy_nc_attrs
from gemlib.util import compute_state

def predicted_incidence(
    posterior_samples, init_state, covar_data, init_step, num_steps
):
    """Runs the simulation forward in time from `init_state` at time `init_step`
       for `num_steps`.
    :param posterior_samples: a dictionary of posterior samples, containing a
        "seir" entry of sampled state-transition events.
        NOTE(review): this dict is mutated in place -- "seir" is deleted and
        "new_init_state" is added.
    :param init_state: the initial epidemic state passed to `compute_state`
    :param covar_data: a dictionary of model covariate data
    :param init_step: the initial time step
    :param num_steps: the number of steps to simulate
    :returns: a tuple of (state at `init_step` per sample, simulated events);
              the events are a tensor of shape [B, M, num_steps, X] where X is
              the number of state transitions
    """

    # Reconstruct the full state timeseries from the initial state and the
    # sampled events, then extract the state at the requested start step so
    # the forward simulation can begin mid-timeseries.
    posterior_state = compute_state(
        init_state,
        posterior_samples["seir"],
        model_spec.STOICHIOMETRY,
    )
    posterior_samples["new_init_state"] = posterior_state[..., init_step, :]
    # Remove the event samples: the remaining entries are the per-sample
    # parameters fed to `model.sample` below.
    del posterior_samples["seir"]

    @tf.function
    def do_sim():
        def sim_fn(args):
            # `args` is one posterior draw, flattened by tf.map_fn; rebuild
            # the dict structure so entries can be addressed by name.
            par = tf.nest.pack_sequence_as(posterior_samples, args)
            init_ = par["new_init_state"]
            del par["new_init_state"]

            model = model_spec.CovidUK(
                covar_data,
                initial_state=init_,
                initial_step=init_step,
                num_steps=num_steps,
            )
            # Remaining entries of `par` are model parameters for this draw.
            sim = model.sample(**par)
            return sim["seir"]

        # One forward simulation per posterior draw.
        return tf.map_fn(
            sim_fn,
            elems=tf.nest.flatten(posterior_samples),
            fn_output_signature=(tf.float64),
        )

    return posterior_samples["new_init_state"], do_sim()


def read_pkl(filename):
    """Deserialize and return the object stored in a pickle file.

    :param filename: path of the pickle file to read
    :returns: the unpickled Python object

    NOTE(review): `pickle.load` executes arbitrary code during
    deserialization -- only use on trusted files.
    """
    infile = open(filename, "rb")
    try:
        contents = pkl.load(infile)
    finally:
        infile.close()
    return contents


def predict(data, posterior_samples, output_file, initial_step, num_steps):
    """Run forward predictions from posterior samples and write a netCDF file.

    :param data: path of a netCDF file with "constant_data" and
        "observations" groups
    :param posterior_samples: path of a pickle file of posterior samples,
        containing "initial_state" and "seir" entries
    :param output_file: path of the netCDF file to write a "predictions"
        group to
    :param initial_step: index of the time step to start predicting from;
        negative values count back from the end of the inference period
    :param num_steps: number of time steps to predict
    """

    covar_data = xarray.open_dataset(data, group="constant_data")
    cases = xarray.open_dataset(data, group="observations")

    samples = read_pkl(posterior_samples)
    initial_state = samples["initial_state"]
    del samples["initial_state"]

    # Negative initial_step indexes backwards from the end of the inferred
    # event timeseries (axis -2 of the "seir" samples is time).
    if initial_step < 0:
        initial_step = samples["seir"].shape[-2] + initial_step

    # Daily dates covering inference period plus the prediction window,
    # anchored at the first observation date.
    origin_date = np.array(cases.coords["time"][0])
    dates = np.arange(
        origin_date,
        origin_date + np.timedelta64(initial_step + num_steps, "D"),
        np.timedelta64(1, "D"),
    )

    # Weekday indicator covariate (Mon-Fri = 1) for the prediction dates.
    covar_data["weekday"] = xarray.DataArray(
        (pd.to_datetime(dates).weekday < 5).astype(model_spec.DTYPE),
        coords=[dates],
        dims=["prediction_time"],
    )

    # NOTE: predicted_incidence mutates `samples` in place.
    estimated_init_state, predicted_events = predicted_incidence(
        samples, initial_state, covar_data, initial_step, num_steps
    )

    # Predicted events: [iteration, location, time, event]; the time axis
    # covers only the predicted window, hence dates[initial_step:].
    prediction = xarray.DataArray(
        predicted_events.numpy(),
        coords=[
            np.arange(predicted_events.shape[0]),
            covar_data.coords["location"],
            dates[initial_step:],
            np.arange(predicted_events.shape[3]),
        ],
        dims=("iteration", "location", "time", "event"),
    )
    # State at the prediction start, per posterior draw and location.
    estimated_init_state = xarray.DataArray(
        estimated_init_state.numpy(),
        coords=[
            np.arange(estimated_init_state.shape[0]),
            covar_data.coords["location"],
            np.arange(estimated_init_state.shape[-1]),
        ],
        dims=("iteration", "location", "state"),
    )
    ds = xarray.Dataset(
        {"events": prediction, "initial_state": estimated_init_state}
    )
    ds.to_netcdf(output_file, group="predictions")
    ds.close()
    # Propagate global netCDF attributes from the input to the output file.
    copy_nc_attrs(data, output_file)


if __name__ == "__main__":

    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument(
        "-i", "--initial-step", type=int, default=0, help="Initial step"
    )
    parser.add_argument(
        "-n", "--num-steps", type=int, default=1, help="Number of steps"
    )
    # Help text fixed: this file is opened with xarray.open_dataset, i.e. it
    # is a netCDF dataset, not a pickle.
    parser.add_argument("data_pkl", type=str, help="Covariate data netCDF file")
    parser.add_argument(
        "posterior_samples_pkl",
        type=str,
        help="Posterior samples pickle",
    )
    # Help text fixed: predictions are written with Dataset.to_netcdf.
    parser.add_argument(
        "output_file",
        type=str,
        help="Output netCDF file",
    )
    args = parser.parse_args()

    predict(
        args.data_pkl,
        args.posterior_samples_pkl,
        args.output_file,
        args.initial_step,
        args.num_steps,
    )