summarize.py 3.04 KB
Newer Older
1
2
3
4
"""Summary functions"""

import numpy as np
import pickle as pkl
5
import xarray
6
7
8
9
10
11
import pandas as pd

from gemlib.util import compute_state
from covid.summary import mean_and_ci
from covid.model_spec import STOICHIOMETRY

12
# Horizon days (relative to the nowcast date) at which incidence and
# prevalence summaries are produced.
SUMMARY_DAYS = np.array([1, 7, 14, 21, 28, 35, 42, 49, 56], np.int32)
def rt(input_file, output_file):
    """Reads an array of next generation matrices and
       outputs mean (ci) local Rt values.

    :param input_file: a pickled xarray of NGMs
    :param output_file: a .csv of mean (ci) values
    """
    r_it = xarray.open_dataset(input_file, group="posterior_predictive")["R_it"]

    # Keep only the most recent timepoint; the time coordinate is then
    # redundant.  (Renamed from `rt` to avoid shadowing this function.)
    latest_rt = r_it.isel(time=-1).drop("time")

    summary = pd.DataFrame(
        mean_and_ci(latest_rt, name="Rt"),
        index=pd.Index(r_it.coords["location"], name="location"),
    )
    # Fraction of posterior samples (axis 0) with Rt above 1.
    summary["Rt_exceed"] = np.mean(latest_rt > 1.0, axis=0)

    summary.to_csv(output_file)


def infec_incidence(input_file, output_file):
    """Summarises cumulative infection incidence at each horizon
       in SUMMARY_DAYS (1, 7, 14, ..., 56 days after the nowcast).

    :param input_file: a netCDF file containing the medium term
        prediction (group "predictions", variable "events")
    :param output_file: csv with prediction summaries
    """
    prediction = xarray.open_dataset(input_file, group="predictions")["events"]

    offset = 4  # account for recording lag
    timepoints = SUMMARY_DAYS + offset

    # Absolute incidence: events of type 2 summed over the prediction
    # window (event-type axis is last; summing collapses time).
    def pred_events(events, name=None):
        num_events = np.sum(events, axis=-1)
        return mean_and_ci(num_events, name=name)

    idx = prediction.coords["location"]

    # Build one summary column per horizon and join them with a single
    # concat -- repeated concat inside the loop is quadratic.
    # The first horizon keeps its historical column name "cases".
    frames = [
        pd.DataFrame(
            pred_events(prediction[..., offset : (offset + 1), 2], name="cases"),
            index=idx,
        )
    ]
    frames.extend(
        pd.DataFrame(
            pred_events(prediction[..., offset:t, 2], name=f"cases{t-offset}"),
            index=idx,
        )
        for t in timepoints[1:]
    )
    abs_incidence = pd.concat(frames, axis="columns")

    abs_incidence.to_csv(output_file)


def prevalence(input_files, output_file):
    """Reconstruct predicted prevalence from
       original data and projection.

    :param input_files: a list of [data file, prediction file], both
        readable by ``xarray.open_dataset`` (groups "constant_data"
        and "predictions" respectively)
    :param output_file: a csv containing prevalence summary
    """
    offset = 4  # Account for recording lag
    timepoints = SUMMARY_DAYS + offset

    data = xarray.open_dataset(input_files[0], group="constant_data")
    prediction = xarray.open_dataset(input_files[1], group="predictions")

    # Reconstruct the full state timeseries from the initial state and
    # the predicted event timeseries.
    predicted_state = compute_state(
        prediction["initial_state"], prediction["events"], STOICHIOMETRY
    )

    def calc_prev(state, name=None):
        # Prevalence = sum of state compartments 1 and 2 divided by
        # population size N.  NOTE(review): compartment order inferred
        # from the 1:3 slice -- confirm against covid.model_spec.
        prev = np.sum(state[..., 1:3], axis=-1) / np.array(data["N"])
        return mean_and_ci(prev, name=name)

    idx = prediction.coords["location"]

    # One summary column per horizon, joined with a single concat --
    # repeated concat inside the loop is quadratic.  The first horizon
    # keeps its historical column name "prev".
    frames = [
        pd.DataFrame(
            calc_prev(predicted_state[..., timepoints[0], :], name="prev"),
            index=idx,
        )
    ]
    frames.extend(
        pd.DataFrame(
            calc_prev(predicted_state[..., t, :], name=f"prev{t-offset}"),
            index=idx,
        )
        for t in timepoints[1:]
    )
    prev = pd.concat(frames, axis="columns")

    prev.to_csv(output_file)