predict.py 3.95 KB
Newer Older
1
2
"""Run predictions for COVID-19 model"""

3
4
import numpy as np
import xarray
5
6
7
8
import pickle as pkl
import tensorflow as tf

from covid import model_spec
9
from covid.util import copy_nc_attrs
10
11
from gemlib.util import compute_state

12
@tf.function
def predicted_incidence(posterior_samples, covar_data, init_step, num_steps):
    """Run the model forward in time from the posterior state at `init_step`.

    For each posterior sample, the state at time step `init_step` is
    reconstructed from the sample's "init_state" and "seir" event history
    (via `compute_state`). The model is then simulated forward from that
    state for `num_steps` steps.

    :param posterior_samples: a dictionary of posterior samples. It must
        contain the keys "init_state" and "seir". Every entry other than
        "seir" is forwarded to `model_spec.CovidUK.sample`. The dictionary
        itself is not modified.
    :param covar_data: model covariate data, passed to `model_spec.CovidUK`
    :param init_step: the time step from which to start the simulation
    :param num_steps: the number of steps to simulate
    :returns: a tuple `(init_state, events)`. `init_state` is the
        reconstructed state at `init_step` for each sample. `events` is the
        simulated "seir" tensor of shape [B, M, num_steps, X], where X is the
        number of state transitions.
    """
    # Work on a shallow copy so the argument dictionary is left untouched.
    samples = dict(posterior_samples)

    posterior_state = compute_state(
        samples["init_state"],
        samples["seir"],
        model_spec.STOICHIOMETRY,
    )
    init_state = posterior_state[..., init_step, :]
    samples["init_state_"] = init_state
    # "seir" is not passed to `model.sample`; it is read back from the
    # simulation result instead (see `sim["seir"]` below).
    del samples["seir"]
    # NOTE(review): "init_state" is still forwarded to `model.sample`;
    # confirm that CovidUK accepts it.

    def sim_fn(args):
        # Rebuild the per-sample parameter dictionary from the flat tensors.
        par = tf.nest.pack_sequence_as(samples, args)
        init_ = par.pop("init_state_")

        model = model_spec.CovidUK(
            covar_data,
            initial_state=init_,
            initial_step=init_step,
            num_steps=num_steps,
        )
        sim = model.sample(**par)
        return sim["seir"]

    events = tf.map_fn(
        sim_fn,
        elems=tf.nest.flatten(samples),
        fn_output_signature=tf.float64,
    )
    return init_state, events


def read_pkl(filename):
    """Load and return the object pickled in the file at ``filename``.

    NOTE: unpickling can execute arbitrary code, so only use this on
    trusted files.
    """
    with open(filename, mode="rb") as stream:
        loaded = pkl.load(stream)
    return loaded


def predict(data, posterior_samples, output_file, initial_step, num_steps):
    """Simulate forward from posterior samples and save the predictions.

    :param data: path to the input netCDF file. Its "constant_data" group
        provides the covariates and the "location" coordinate. The first
        value of the "time" coordinate in its "observations" group is used as
        the origin date.
    :param posterior_samples: path to a pickle file holding a dictionary of
        posterior samples (including "init_state" and "seir")
    :param output_file: path of the netCDF file to write. The predicted
        events and initial states are written to its "predictions" group,
        then the attributes of `data` are copied over via `copy_nc_attrs`.
    :param initial_step: time step from which to start predicting. A
        negative value is counted back from the length of the second-to-last
        axis of the "seir" samples.
    :param num_steps: number of time steps to predict
    """
    samples = read_pkl(posterior_samples)

    if initial_step < 0:
        initial_step = samples["seir"].shape[-2] + initial_step

    # Only the origin date is needed from the observations, so release the
    # file handle as soon as it has been read.
    with xarray.open_dataset(data, group="observations") as cases:
        origin_date = np.array(cases.coords["time"][0])
    dates = np.arange(
        origin_date + np.timedelta64(initial_step, "D"),
        origin_date + np.timedelta64(initial_step + num_steps, "D"),
        np.timedelta64(1, "D"),
    )

    # Keep the covariate file open only while it is needed for simulation
    # and for building/writing the output dataset.
    with xarray.open_dataset(data, group="constant_data") as covar_data:
        estimated_init_state, predicted_events = predicted_incidence(
            samples, covar_data, initial_step, num_steps
        )
        # Convert TensorFlow results to NumPy arrays before handing to xarray.
        predicted_events = np.asarray(predicted_events)
        estimated_init_state = np.asarray(estimated_init_state)
        locations = covar_data.coords["location"]

        prediction = xarray.DataArray(
            predicted_events,
            coords=[
                np.arange(predicted_events.shape[0]),
                locations,
                dates,
                np.arange(predicted_events.shape[3]),
            ],
            dims=("iteration", "location", "time", "event"),
        )
        initial_state = xarray.DataArray(
            estimated_init_state,
            coords=[
                np.arange(estimated_init_state.shape[0]),
                locations,
                np.arange(estimated_init_state.shape[-1]),
            ],
            dims=("iteration", "location", "state"),
        )
        ds = xarray.Dataset(
            {"events": prediction, "initial_state": initial_state}
        )
        ds.to_netcdf(output_file, group="predictions")
        ds.close()

    copy_nc_attrs(data, output_file)


def build_parser():
    """Build the command-line argument parser for this script.

    `ArgumentParser.add_argument` takes its description text via `help=`.
    """
    from argparse import ArgumentParser

    parser = ArgumentParser(description="Run predictions for COVID-19 model")
    parser.add_argument(
        "-i", "--initial-step", type=int, default=0, help="Initial step"
    )
    parser.add_argument(
        "-n", "--num-steps", type=int, default=1, help="Number of steps"
    )
    # The data file is opened with xarray.open_dataset, i.e. it is netCDF.
    parser.add_argument(
        "data_pkl",
        type=str,
        help="Input netCDF data file (constant_data and observations groups)",
    )
    parser.add_argument(
        "posterior_samples_pkl",
        type=str,
        help="Posterior samples pickle",
    )
    # The output is written with Dataset.to_netcdf.
    parser.add_argument(
        "output_file",
        type=str,
        help="Output netCDF file",
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    predict(
        args.data_pkl,
        args.posterior_samples_pkl,
        args.output_file,
        args.initial_step,
        args.num_steps,
    )