"""Produces a long-format summary of fitted model results"""

import pickle as pkl
from datetime import date
import numpy as np
import pandas as pd
import xarray

from gemlib.util import compute_state
from covid.model_spec import STOICHIOMETRY
from covid import model_spec
from covid.formats import make_dstl_template


def xarray2summarydf(arr):
    """Summarise a DataArray over its "iteration" (MCMC sample) dimension.

    Computes the posterior mean together with the 0.05, 0.10, ..., 0.95
    quantiles, and flattens the result into a long-format DataFrame.

    :param arr: an xarray DataArray with an "iteration" dimension
    :returns: a pandas DataFrame with one row per remaining coordinate
        combination, a "value" column holding the mean, and one column
        per quantile named "0.05", "0.10", ..., "0.95"
    """
    probs = np.arange(start=0.05, stop=1.0, step=0.05)
    posterior_mean = arr.mean(dim="iteration").to_dataset(name="value")
    posterior_quantiles = arr.quantile(q=probs, dim="iteration").to_dataset(
        dim="quantile"
    )
    # Name each quantile variable by its probability, e.g. 0.05 -> "0.05"
    summary = posterior_mean.merge(posterior_quantiles)
    summary = summary.rename_vars({p: f"{p:.2f}" for p in probs})
    return summary.to_dataframe().reset_index()


def prevalence(prediction, popsize):
    """Summarise epidemic prevalence per 100k of population.

    Reconstructs the full compartmental state timeseries from the
    predicted events, then summarises the sum of state columns 1 and 2
    (presumably the E and I compartments -- confirm against
    model_spec.STOICHIOMETRY) scaled to per-100000-population.

    :param prediction: a Dataset containing "initial_state" and "events"
        variables, with "location" and "time" coordinates
    :param popsize: per-location population sizes aligned with the
        "location" coordinate
    :returns: a long-format summary DataFrame (see xarray2summarydf)
    """
    state_tensor = compute_state(
        prediction["initial_state"], prediction["events"], STOICHIOMETRY
    )
    # Wrap the raw tensor with labelled coordinates so we can reduce by name
    state = xarray.DataArray(
        state_tensor.numpy(),
        coords=[
            np.arange(state_tensor.shape[0]),
            prediction.coords["location"],
            prediction.coords["time"],
            np.arange(state_tensor.shape[-1]),
        ],
        dims=["iteration", "location", "time", "state"],
    )
    infected = state[..., 1:3].sum(dim="state").reset_coords(drop=True)
    per_100k = infected / popsize[np.newaxis, :, np.newaxis] * 100000
    return xarray2summarydf(per_100k)


def weekly_pred_cases_per_100k(prediction, popsize):
    """Returns weekly number of cases per 100k of population.

    Weeks start on the first Sunday appearing in the "time" coordinate
    and run for 7 days each; the final start index is discarded
    (`[:-1]`), so a trailing possibly-incomplete week is not reported.

    :param prediction: an events DataArray indexed
        [iteration, location, time, event]; event column 2 is taken as
        case removals
    :param popsize: per-location population sizes
    :returns: a long-format summary DataFrame (see xarray2summarydf)
    """
    removals = prediction[..., 2]  # Case removals
    removals = removals.reset_coords(drop=True)

    # TODO: Find better way to sum up into weeks other than
    # a list comprehension.
    dates = pd.DatetimeIndex(removals.coords["time"].data)
    first_sunday = np.where(dates.weekday == 6)[0][0]
    num_times = removals.coords["time"].shape[0]
    week_starts = range(first_sunday, num_times, 7)[:-1]

    # Sum each 7-day window, labelling the week by its start date
    weekly = xarray.concat(
        [
            removals[..., start : (start + 7)].sum(dim="time")
            for start in week_starts
        ],
        dim=removals.coords["time"][week_starts],
    )
    weekly = weekly.transpose(*removals.dims, transpose_coords=True)

    # Divide by population sizes
    weekly = weekly / popsize[np.newaxis, :, np.newaxis] * 100000
    return xarray2summarydf(weekly)


def summary_longformat(input_files, output_file):
    """Draws together pipeline results into a long format
       spreadsheet.

    :param input_files: a list of filenames [data_pkl,
                                             insample14_pkl,
                                             medium_term_pred_pkl,
                                             ngm_pkl]
    :param output_file: the output file, written via
        ``DataFrame.to_excel`` in DSTL template format, with one row per
        (value_date, geography, value_type) and columns for the mean
        ("value") and the 0.05, 0.10, ..., 0.95 quantiles
    """

    # Observed case data.  NB: pickle is only safe on trusted input;
    # these files come from earlier stages of this pipeline.
    with open(input_files[0], "rb") as f:
        data = pkl.load(f)
    da = data["cases"].rename({"date": "time"})
    df = da.to_dataframe(name="value").reset_index()
    df["value_name"] = "newCasesBySpecimenDate"
    # Observed data has no posterior quantiles: pre-create NaN columns
    # matching the names emitted by xarray2summarydf (f"{q:.2f}") so
    # they line up on concat.  "0.50" (not "0.5") is the median column;
    # a plain "0.5" would silently create a second, never-filled column.
    df["0.05"] = np.nan
    df["0.50"] = np.nan
    df["0.95"] = np.nan

    # Insample predictive incidence (event column 2 = case removals)
    insample = xarray.open_dataset(input_files[1])
    insample_df = xarray2summarydf(
        insample["events"][..., 2].reset_coords(drop=True)
    )
    insample_df["value_name"] = "insample14_Cases"
    df = pd.concat([df, insample_df], axis="index")

    # Medium term incidence
    medium_term = xarray.open_dataset(input_files[2])
    medium_df = xarray2summarydf(
        medium_term["events"][..., 2].reset_coords(drop=True)
    )
    medium_df["value_name"] = "Cases"
    df = pd.concat([df, medium_df], axis="index")

    # Weekly incidence per 100k
    weekly_incidence = weekly_pred_cases_per_100k(
        medium_term["events"], data["N"]
    )
    weekly_incidence["value_name"] = "weekly_cases_per_100k"
    df = pd.concat([df, weekly_incidence], axis="index")

    # Medium term prevalence
    prev_df = prevalence(medium_term, data["N"])
    prev_df["value_name"] = "prevalence"
    df = pd.concat([df, prev_df], axis="index")

    # Rt: row sums of the next generation matrix over destinations
    ngms = xarray.load_dataset(input_files[3])["ngm"]
    rt = ngms.sum(dim="dest")
    rt = rt.rename({"src": "location"})
    rt_summary = xarray2summarydf(rt)
    rt_summary["value_name"] = "R"
    # R is reported at the end of the inference window
    rt_summary["time"] = data["date_range"][1]
    df = pd.concat([df, rt_summary], axis="index")

    # All quantile columns share the "0." name prefix
    quantiles = df.columns[df.columns.str.startswith("0.")]

    return make_dstl_template(
        group="Lancaster",
        model="SpatialStochasticSEIR",
        scenario="Nowcast",
        creation_date=date.today(),
        version=model_spec.VERSION,
        age_band="All",
        geography=df["location"],
        value_date=df["time"],
        value_type=df["value_name"],
        value=df["value"],
        quantiles={q: df[q] for q in quantiles},
    ).to_excel(output_file, index=False)