"""Produces a long-format summary of fitted model results"""

import pickle as pkl
from datetime import date
import numpy as np
import pandas as pd
import xarray

from gemlib.util import compute_state
from covid.model_spec import STOICHIOMETRY
from covid import model_spec
from covid.formats import make_dstl_template


def xarray2summarydf(arr):
    """Summarise an xarray DataArray of posterior draws.

    Reduces over the "iteration" dimension, returning a long-format
    DataFrame with the posterior mean (`value`) and the 0.05, 0.5,
    and 0.95 quantiles as columns.
    """
    mean = arr.mean(dim="iteration").to_dataset(name="value")
    quantiles = arr.quantile(q=[0.05, 0.5, 0.95], dim="iteration").to_dataset(
        dim="quantile"
    )
    ds = mean.merge(quantiles).rename_vars(
        {0.05: "0.05", 0.5: "0.5", 0.95: "0.95"}
    )
    return ds.to_dataframe().reset_index()
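

# Illustrative sketch (not part of the pipeline): `_example_summarydf` and its
# toy DataArray are hypothetical, showing the input shape `xarray2summarydf`
# expects and the columns of the DataFrame it returns.
def _example_summarydf():
    """Summarise 200 hypothetical posterior draws for 3 locations."""
    draws = xarray.DataArray(
        np.random.gamma(2.0, 1.0, size=(200, 3)),
        coords=[np.arange(200), ["loc_a", "loc_b", "loc_c"]],
        dims=["iteration", "location"],
    )
    # Resulting columns: location, value (posterior mean), "0.05", "0.5", "0.95"
    return xarray2summarydf(draws)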


def prevalence(events, popsize):
    """Reconstruct infection prevalence per 100,000 people.

    :param events: a DataArray of transition events with dims
                   [iteration, location, time, event] and an
                   `initial_state` attribute
    :param popsize: a 1-D array of per-location population sizes
    :returns: a long-format summary DataFrame (see `xarray2summarydf`)
    """
    # Reconstruct the epidemic state timeseries from the event counts
    prev = compute_state(events.attrs["initial_state"], events, STOICHIOMETRY)
    prev = xarray.DataArray(
        prev.numpy(),
        coords=[
            np.arange(prev.shape[0]),
            events.coords["location"],
            events.coords["time"],
            np.arange(prev.shape[-1]),
        ],
        dims=["iteration", "location", "time", "state"],
    )
    # States 1 and 2 are the E and I compartments of the SEIR model; their
    # sum, scaled by population size, gives prevalence per 100,000.
    prev_per_1e5 = (
        prev[..., 1:3].sum(dim="state").reset_coords(drop=True)
        / popsize[np.newaxis, :, np.newaxis]
        * 100000
    )
    return xarray2summarydf(prev_per_1e5)
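
# A minimal broadcasting sketch (hypothetical numbers, not pipeline data):
# because `popsize` is 1-D over locations, indexing it with
# [np.newaxis, :, np.newaxis] aligns it against the
# (iteration, location, time) dimensions of the summed prevalence array:
#
#     popsize = np.array([1.0e5, 2.0e5])           # two locations
#     prev_sum = np.ones((3, 2, 4))                 # (iteration, location, time)
#     per_1e5 = prev_sum / popsize[np.newaxis, :, np.newaxis] * 100000
#     per_1e5.shape                                 # -> (3, 2, 4)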


def summary_longformat(input_files, output_file):
    """Draws together pipeline results into a long format
       csv file.

    :param input_files: a list of filenames [data_pkl,
                                             insample14_pkl,
                                             medium_term_pred_pkl,
                                             ngm_pkl]
    :param output_file: the output CSV with columns `[date,
                        location,value_name,value,q0.025,q0.975]`
    """

    with open(input_files[0], "rb") as f:
        data = pkl.load(f)
    da = data["cases"].rename({"date": "time"})
    df = da.to_dataframe(name="value").reset_index()
    df["value_name"] = "newCasesBySpecimenDate"
    df["0.05"] = np.nan
    df["0.5"] = np.nan
    df["0.95"] = np.nan

    # Insample predictive incidence
    with open(input_files[1], "rb") as f:
        insample = pkl.load(f)
    insample_df = xarray2summarydf(insample[..., 2].reset_coords(drop=True))
    insample_df["value_name"] = "insample14_Cases"
    df = pd.concat([df, insample_df], axis="index")

    # Medium term incidence
    with open(input_files[2], "rb") as f:
        medium_term = pkl.load(f)
    medium_df = xarray2summarydf(medium_term[..., 2].reset_coords(drop=True))
    medium_df["value_name"] = "Cases"
    df = pd.concat([df, medium_df], axis="index")

    # Medium term prevalence
    prev_df = prevalence(medium_term, data["N"])
    prev_df["value_name"] = "prevalence"
    df = pd.concat([df, prev_df], axis="index")

    # Rt: summing the next generation matrix over destination locations
    # gives the reproduction number for each source location.
    with open(input_files[3], "rb") as f:
        ngms = pkl.load(f)
    rt = ngms.sum(dim="dest")
    rt = rt.rename({"src": "location"})
    rt_summary = xarray2summarydf(rt)
    rt_summary["value_name"] = "R"
    # R is reported at the last date of the analysis window
    rt_summary["time"] = data["date_range"][1]
    df = pd.concat([df, rt_summary], axis="index")

    # Assemble the combined summary into the DSTL reporting template
    # and write it out as an Excel file.
    return make_dstl_template(
        group="Lancaster",
        model="SpatialStochasticSEIR",
        scenario="Nowcast",
        creation_date=date.today(),
        version=model_spec.VERSION,
        age_band="All",
        geography=df["location"],
        value_date=df["time"],
        value_type=df["value_name"],
        quantiles={
            "0.05": df["0.05"],
            "0.5": df["0.5"],
            "0.95": df["0.95"],
        },
    ).to_excel(output_file, index=False)
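

if __name__ == "__main__":
    # Hedged usage sketch: the filenames below are hypothetical placeholders
    # for the upstream pipeline outputs; only the argument order matters
    # (data, in-sample predictions, medium-term predictions, next generation
    # matrices).
    summary_longformat(
        input_files=[
            "pipeline_data.pkl",
            "insample14.pkl",
            "medium_term.pkl",
            "ngm.pkl",
        ],
        output_file="summary_longformat.xlsx",
    )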