"""Run predictions for COVID-19 model"""

import numpy as np
import xarray
import pickle as pkl
import tensorflow as tf

from covid import model_spec
from covid.util import copy_nc_attrs
from gemlib.util import compute_state


def predicted_incidence(
    posterior_samples, init_state, covar_data, init_step, num_steps
):
    """Runs the simulation forward in time from `init_state` at time `init_time`
       for `num_steps`.
    :param param: a dictionary of model parameters
    :covar_data: a dictionary of model covariate data
    :param init_step: the initial time step
    :param num_steps: the number of steps to simulate
    :returns: a tensor of srt_quhape [B, M, num_steps, X] where X is the number of state
              transitions
    """

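    # Reconstruct the full state timeseries implied by the posterior event
    # samples, so that the state at `init_step` can be extracted below.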
    posterior_state = compute_state(
        init_state,
        posterior_samples["seir"],
        model_spec.STOICHIOMETRY,
    )
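
    # Use the state at `init_step` as the new initial condition and drop the
    # event samples, leaving only parameter draws (plus the new initial state)
    # to be mapped over in the simulation below.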
    posterior_samples["new_init_state"] = posterior_state[..., init_step, :]
    del posterior_samples["seir"]

    @tf.function
    def do_sim():
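        # Simulate one trajectory for a single posterior draw: rebuild the
        # parameter dictionary for that draw, instantiate the model, and draw
        # a fresh sequence of transition events.  `tf.map_fn` below applies
        # this to every posterior draw.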
        def sim_fn(args):
            par = tf.nest.pack_sequence_as(posterior_samples, args)
            init_ = par["new_init_state"]
            del par["new_init_state"]

            model = model_spec.CovidUK(
                covar_data,
                initial_state=init_,
                initial_step=init_step,
                num_steps=num_steps,
            )
            sim = model.sample(**par)
            return sim["seir"]

        return tf.map_fn(
            sim_fn,
            elems=tf.nest.flatten(posterior_samples),
            fn_output_signature=tf.float64,
        )

    return posterior_samples["new_init_state"], do_sim()


def read_pkl(filename):
    with open(filename, "rb") as f:
        return pkl.load(f)


def predict(data, posterior_samples, output_file, initial_step, num_steps):
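    """Predicts COVID-19 incidence from posterior samples and writes it to a NetCDF file.

    :param data: path to the input NetCDF file containing `constant_data` and
                 `observations` groups
    :param posterior_samples: path to a pickle file of posterior samples
    :param output_file: path to the output NetCDF file
    :param initial_step: the first time step of the prediction window; negative
                         values are counted back from the end of the inference period
    :param num_steps: the number of time steps to predict
    """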

    covar_data = xarray.open_dataset(data, group="constant_data")
    cases = xarray.open_dataset(data, group="observations")

    samples = read_pkl(posterior_samples)
    initial_state = samples["initial_state"]
    del samples["initial_state"]

    if initial_step < 0:
        initial_step = samples["seir"].shape[-2] + initial_step

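    # Calendar dates spanned by the prediction window: `num_steps` days
    # starting `initial_step` days after the first observation date.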
    origin_date = np.array(cases.coords["time"][0])
    dates = np.arange(
        origin_date + np.timedelta64(initial_step, "D"),
        origin_date + np.timedelta64(initial_step + num_steps, "D"),
        np.timedelta64(1, "D"),
    )

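    # Run the epidemic model forward over the prediction window, once per
    # posterior sample.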
    estimated_init_state, predicted_events = predicted_incidence(
        samples, initial_state, covar_data, initial_step, num_steps
    )

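    # Label the simulated events with posterior iteration, location, date,
    # and transition event coordinates.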
    prediction = xarray.DataArray(
        predicted_events.numpy(),
        coords=[
            np.arange(predicted_events.shape[0]),
            covar_data.coords["location"],
            dates,
            np.arange(predicted_events.shape[3]),
        ],
        dims=("iteration", "location", "time", "event"),
    )
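    # Record the per-sample state used to initialise the forward simulation.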
    estimated_init_state = xarray.DataArray(
        estimated_init_state.numpy(),
        coords=[
            np.arange(estimated_init_state.shape[0]),
            covar_data.coords["location"],
            np.arange(estimated_init_state.shape[-1]),
        ],
        dims=("iteration", "location", "state"),
    )
    ds = xarray.Dataset(
        {"events": prediction, "initial_state": estimated_init_state}
    )
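    # Write predictions and the estimated initial state to the "predictions"
    # group of the output file, then copy NetCDF attributes over from the
    # input data file.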
    ds.to_netcdf(output_file, group="predictions")
    ds.close()
    copy_nc_attrs(data, output_file)


if __name__ == "__main__":

    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument(
        "-i", "--initial-step", type=int, default=0, help="Initial step"
    )
    parser.add_argument(
        "-n", "--num-steps", type=int, default=1, help="Number of steps"
    )
    parser.add_argument("data_pkl", type=str, help="Covariate data pickle")
    parser.add_argument(
        "posterior_samples_pkl",
        type=str,
        help="Posterior samples pickle",
    )
    parser.add_argument(
        "output_file",
        type=str,
        help="Output pkl file",
    )
    args = parser.parse_args()

    predict(
        args.data_pkl,
        args.posterior_samples_pkl,
        args.output_file,
        args.initial_step,
        args.num_steps,
    )