Commit c5d7033e authored by Chris Jewell

Renamed module to

parent 4a43ffde
"""Calculates case exceedance probabilities"""
import numpy as np
import xarray
def case_exceedance(input_files, lag):
"""Calculates case exceedance probabilities,
i.e. Pr(pred[lag:] < observed[lag:])
:param input_files: [data pickle, prediction pickle]
:param lag: the lag for which to calculate the exceedance
"""
data_file, prediction_file = input_files
data = xarray.open_dataset(data_file, group="observations")
prediction = xarray.open_dataset(prediction_file, group="predictions")[
"events"
]
modelled_cases = np.sum(prediction[..., :lag, -1], axis=-1)
observed_cases = np.sum(data["cases"][:, -lag:], axis=-1)
exceedance = np.mean(modelled_cases < observed_cases, axis=0)
return exceedance
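# A toy sketch of the exceedance calculation (numbers illustrative): with
# four posterior samples of summed cases in two locations, the exceedance
# probability is the fraction of samples falling below the observation.
#
# >>> import numpy as np
# >>> modelled = np.array([[10, 12], [11, 15], [9, 20], [14, 18]])
# >>> observed = np.array([12, 16])
# >>> np.mean(modelled < observed, axis=0)
# array([0.75, 0.5 ])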
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser(
description="Calculates case exceedance probabilities"
)
parser.add_argument("data_file", type=str)
parser.add_argument("prediction_file", type=str)
parser.add_argument(
"-l",
"--lag",
type=int,
help="The lag for which to calculate exceedance",
default=7,
)
parser.add_argument(
"-o",
"--output",
type=str,
help="The output csv",
default="exceedance.csv",
)
args = parser.parse_args()
    exceedance = case_exceedance(
        [args.data_file, args.prediction_file], args.lag
    )
    exceedance.to_series().to_csv(args.output)
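# Example invocation (file names hypothetical):
#     python case_exceedance.py pipeline_data.nc medium_term.nc --lag 7 \
#         -o exceedance.csv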
"""Create insample plots for a given lag"""
import xarray
from pathlib import Path
import matplotlib
matplotlib.use("Agg")  # select the non-interactive backend before importing pyplot
import matplotlib.pyplot as plt
def plot_timeseries(mean, quantiles, data, dates, title):
"""Plots a predictive timeseries with data
:param prediction: a [5, T]-shaped array with first dimension
representing quantiles, and T the number of
time points.
:param data: an array of shape [T] providing the data
:param dates: an array of shape [T] of type np.datetime64
:param title: the plot title
:returns: a matplotlib axis
"""
fig = plt.figure()
# In-sample prediction
plt.fill_between(
dates, y1=quantiles[0], y2=quantiles[-1], color="lightblue", alpha=0.5
)
plt.fill_between(
dates, y1=quantiles[1], y2=quantiles[-2], color="lightblue", alpha=1.0
)
plt.plot(dates, mean, color="blue")
plt.plot(dates, data, "+", color="red")
plt.title(title)
fig.autofmt_xdate()
return fig
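# A minimal sketch of calling plot_timeseries with synthetic inputs (all
# values illustrative):
#
# >>> import numpy as np
# >>> dates = np.arange("2020-10-01", "2020-10-15", dtype="datetime64[D]")
# >>> samples = np.random.poisson(10.0, size=(200, dates.shape[0]))
# >>> quantiles = np.quantile(samples, [0.025, 0.25, 0.5, 0.75, 0.975], axis=0)
# >>> fig = plot_timeseries(
# ...     samples.mean(axis=0), quantiles, samples[0], dates, "Example"
# ... )
# >>> fig.savefig("example.png")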
def insample_predictive_timeseries(input_files, output_dir, lag):
"""Creates insample plots
:param input_files: a list of [prediction_file, data_file] (see Details)
:param output_dir: the output dir to write files to
:param lag: the number of days at the end of the case timeseries for which to
plot the in-sample prediction.
    :returns: `None`; output is written to disk.

    Details
    -------
    `data_file` is a netCDF4 file with an `observations` group containing a
    `cases` array with dimensions [`location`, `time`] giving the number of
    detected cases in each `location` at each time point, and a
    `constant_data` group naming the `locations`.

    `prediction_file` is a netCDF4 file with a `predictions` group containing
    an `events` array of shape `[K, M, T, R]` where `K` is the number of
    posterior samples, `M` is the number of locations, `T` is the number of
    timepoints, and `R` is the number of transitions in the model.  The
    prediction is assumed to start at `cases.coords['time'][-lag]`, and it
    is assumed that `T >= lag`.

    A timeseries graph (png) summarising the prediction against the
    observations for each `location` is written to `output_dir`.
    """
prediction_file, data_file = input_files
lag = int(lag)
prediction = xarray.open_dataset(prediction_file, group="predictions")[
"events"
]
prediction = prediction[..., :lag, -1] # Just removals
cases = xarray.open_dataset(data_file, group="observations")["cases"]
lads = xarray.open_dataset(data_file, group="constant_data")["locations"]
output_dir = Path(output_dir)
output_dir.mkdir(exist_ok=True)
pred_mean = prediction.mean(dim="iteration")
pred_quants = prediction.quantile(
q=[0.025, 0.25, 0.5, 0.75, 0.975],
dim="iteration",
)
for location in cases.coords["location"]:
print("Location:", location.data)
fig = plot_timeseries(
pred_mean.loc[location, :],
pred_quants.loc[:, location, :],
cases.loc[location][-lag:],
cases.coords["time"][-lag:],
lads.loc[location].data,
)
plt.savefig(output_dir.joinpath(f"{location.data}.png"))
plt.close()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_pkl", type=str, required=True, help="Pipeline data file (netCDF4)"
    )
parser.add_argument(
"--insample",
type=str,
required=True,
help="Insample prediction netCDF4",
)
parser.add_argument(
"--output-dir", type=str, required=True, help="Output directory"
)
parser.add_argument("--lag", type=int, required=True, help="Lag")
args = parser.parse_args()
insample_predictive_timeseries(
[args.insample, args.data_pkl], args.output_dir, args.lag
)
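# Example invocation (file names hypothetical):
#     python insample_predictive_timeseries.py --data_pkl pipeline_data.nc \
#         --insample insample7.nc --output-dir insample_plots --lag 7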
"""Calculates overall Rt given a posterior next generation matix"""
import numpy as np
import xarray
def overall_rt(inference_data, output_file):
    """Summarises the overall Rt at the final timepoint

    :param inference_data: a netCDF4 file with a `posterior_predictive`
                           group containing `R_t`
    :param output_file: the output .xlsx file of Rt quantiles
    """
    r_t = xarray.open_dataset(inference_data, group="posterior_predictive")[
        "R_t"
    ]
q = np.arange(0.05, 1.0, 0.05)
quantiles = r_t.isel(time=-1).quantile(q=q)
quantiles.to_dataframe().T.to_excel(output_file)
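# A toy sketch of the quantile summary (distribution illustrative): 19
# evenly spaced quantiles of the final-time Rt samples become a single
# wide row in the output spreadsheet.
#
# >>> import numpy as np
# >>> import pandas as pd
# >>> q = np.arange(0.05, 1.0, 0.05)
# >>> samples = np.random.gamma(shape=2.0, scale=0.5, size=1000)
# >>> pd.DataFrame({"R_t": np.quantile(samples, q)}, index=q).T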
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser()
    parser.add_argument(
        "input_file",
        type=str,
        help="The input netCDF4 file containing the posterior Rt",
    )
parser.add_argument(
"output_file", type=str, help="The name of the output .xlsx file"
)
args = parser.parse_args()
overall_rt(args.input_file, args.output_file)
"""Summary functions"""
import numpy as np
import xarray
import pandas as pd
from gemlib.util import compute_state
from covid.summary import mean_and_ci
from covid.model_spec import STOICHIOMETRY
SUMMARY_DAYS = np.array([1, 7, 14, 21, 28, 35, 42, 49, 56], np.int32)
def rt(input_file, output_file):
"""Reads an array of next generation matrices and
outputs mean (ci) local Rt values.
:param input_file: a pickled xarray of NGMs
:param output_file: a .csv of mean (ci) values
"""
r_it = xarray.open_dataset(input_file, group="posterior_predictive")["R_it"]
    rt = r_it.isel(time=-1).drop_vars("time")
rt_summary = mean_and_ci(rt, name="Rt")
exceed = np.mean(rt > 1.0, axis=0)
rt_summary = pd.DataFrame(
rt_summary, index=pd.Index(r_it.coords["location"], name="location")
)
rt_summary["Rt_exceed"] = exceed
rt_summary.to_csv(output_file)
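# A toy sketch of the Rt_exceed column (numbers illustrative): the
# probability that Rt exceeds 1 in each location is the fraction of
# posterior samples above 1.
#
# >>> import numpy as np
# >>> rt_samples = np.array([[0.9, 1.2], [1.1, 1.3], [0.8, 1.4]])
# >>> np.mean(rt_samples > 1.0, axis=0)
# array([0.33333333, 1.        ])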
def infec_incidence(input_file, output_file):
"""Summarises cumulative infection incidence
as a nowcast, 7, 14, 28, and 56 days.
:param input_file: a pkl of the medium term prediction
:param output_file: csv with prediction summaries
"""
prediction = xarray.open_dataset(input_file, group="predictions")["events"]
    offset = 4  # Account for recording lag
timepoints = SUMMARY_DAYS + offset
# Absolute incidence
def pred_events(events, name=None):
num_events = np.sum(events, axis=-1)
return mean_and_ci(num_events, name=name)
idx = prediction.coords["location"]
abs_incidence = pd.DataFrame(
pred_events(prediction[..., offset : (offset + 1), 2], name="cases"),
index=idx,
)
for t in timepoints[1:]:
tmp = pd.DataFrame(
pred_events(prediction[..., offset:t, 2], name=f"cases{t-offset}"),
index=idx,
)
abs_incidence = pd.concat([abs_incidence, tmp], axis="columns")
abs_incidence.to_csv(output_file)
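# A toy sketch of the incidence summary (shapes hypothetical): summing the
# case transition over a prediction window gives cumulative cases per
# posterior sample and location, which mean_and_ci then reduces to a
# point estimate with credible interval.
#
# >>> import numpy as np
# >>> events = np.random.poisson(5.0, size=(200, 3, 60))  # [iteration, location, time]
# >>> cases14 = events[..., 4:18].sum(axis=-1)  # offset 4, 14-day window
# >>> cases14.mean(axis=0)  # posterior mean per location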
def prevalence(input_files, output_file):
"""Reconstruct predicted prevalence from
original data and projection.
    :param input_files: a list of [data netCDF, prediction netCDF]
:param output_file: a csv containing prevalence summary
"""
offset = 4 # Account for recording lag
timepoints = SUMMARY_DAYS + offset
data = xarray.open_dataset(input_files[0], group="constant_data")
prediction = xarray.open_dataset(input_files[1], group="predictions")
predicted_state = compute_state(
prediction["initial_state"], prediction["events"], STOICHIOMETRY
)
def calc_prev(state, name=None):
prev = np.sum(state[..., 1:3], axis=-1) / np.array(data["N"])
return mean_and_ci(prev, name=name)
idx = prediction.coords["location"]
prev = pd.DataFrame(
calc_prev(predicted_state[..., timepoints[0], :], name="prev"),
index=idx,
)
for t in timepoints[1:]:
tmp = pd.DataFrame(
calc_prev(predicted_state[..., t, :], name=f"prev{t-offset}"),
index=idx,
)
prev = pd.concat([prev, tmp], axis="columns")
prev.to_csv(output_file)
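# A toy sketch of the prevalence calculation (numbers illustrative): for
# an SEIR state array, prevalence is the proportion of each population in
# the E and I compartments (state columns 1:3).
#
# >>> import numpy as np
# >>> state = np.array([[[90.0, 4.0, 5.0, 1.0]]])  # [iteration, location, state]
# >>> N = np.array([100.0])
# >>> np.sum(state[..., 1:3], axis=-1) / N
# array([[0.09]])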
"""Summarises posterior distribution into a geopackage"""
import numpy as np
import xarray
import pandas as pd
import geopandas as gp
def _tier_enum(design_matrix):
"""Turns a factor variable design matrix into
an enumerated column"""
df = design_matrix[-1].to_dataframe()[["value"]]
df = df[df["value"] == 1.0].reset_index()
return df["alert_level"]
def summary_geopackage(input_files, output_file, config):
"""Creates a summary geopackage file
:param input_files: a list of data file names [data pkl,
next_generation_matrix,
insample7,
insample14,
medium_term]
:param output_file: the output geopackage file
:param config: SummaryGeopackage configuration information
"""
# Read in the first input file
data = xarray.open_dataset(input_files.pop(0), group="constant_data")
# Load and filter geopackage
geo = gp.read_file(config["base_geopackage"], layer=config["base_layer"])
geo = geo[geo["lad19cd"].isin(np.array(data.coords["location"]))]
geo = geo.sort_values(by="lad19cd")
# Dump data into the geopackage
while len(input_files) > 0:
fn = input_files.pop()
print(f"Collating {fn}")
try:
columns = pd.read_csv(fn, index_col="location")
        except ValueError as e:
            raise ValueError(f"Error reading file '{fn}': {e}") from e
geo = geo.merge(
columns, how="left", left_on="lad19cd", right_index=True
)
geo.to_file(output_file, driver="GPKG")
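# A minimal sketch of the collation step on toy frames (names
# hypothetical): each summary csv, indexed by location, is left-joined
# onto the geography table so every LAD keeps its geometry even when a
# summary value is missing.
#
# >>> import pandas as pd
# >>> geo = pd.DataFrame({"lad19cd": ["E06000001", "E06000002"]})
# >>> rt = pd.DataFrame(
# ...     {"Rt": [1.1]}, index=pd.Index(["E06000001"], name="location")
# ... )
# >>> geo.merge(rt, how="left", left_on="lad19cd", right_index=True)
#       lad19cd   Rt
# 0  E06000001  1.1
# 1  E06000002  NaN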
"""Produces a long-format summary of fitted model results"""
from datetime import date
import numpy as np
import pandas as pd
import xarray
from gemlib.util import compute_state
from covid.model_spec import STOICHIOMETRY
from covid import model_spec
from covid.formats import make_dstl_template
def xarray2summarydf(arr):
mean = arr.mean(dim="iteration").to_dataset(name="value")
q = np.arange(start=0.05, stop=1.0, step=0.05)
quantiles = arr.quantile(q=q, dim="iteration").to_dataset(dim="quantile")
ds = mean.merge(quantiles).rename_vars({qi: f"{qi:.2f}" for qi in q})
return ds.to_dataframe().reset_index()
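# A minimal sketch of xarray2summarydf on a toy DataArray (names
# hypothetical): the result has one row per location with the posterior
# mean in `value` plus quantile columns "0.05" through "0.95".
#
# >>> import numpy as np
# >>> import xarray
# >>> arr = xarray.DataArray(
# ...     np.random.randn(100, 2),
# ...     coords=[np.arange(100), ["a", "b"]],
# ...     dims=["iteration", "location"],
# ... )
# >>> xarray2summarydf(arr).columns.tolist()[:3]
# ['location', 'value', '0.05']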
def prevalence(prediction, popsize):
prev = compute_state(
prediction["initial_state"], prediction["events"], STOICHIOMETRY
)
prev = xarray.DataArray(
prev.numpy(),
coords=[
np.arange(prev.shape[0]),
prediction.coords["location"],
prediction.coords["time"],
np.arange(prev.shape[-1]),
],
dims=["iteration", "location", "time", "state"],
)
prev_per_1e5 = (
prev[..., 1:3].sum(dim="state").reset_coords(drop=True)
/ np.array(popsize)[np.newaxis, :, np.newaxis]
* 100000
)
return xarray2summarydf(prev_per_1e5)
def weekly_pred_cases_per_100k(prediction, popsize):
"""Returns weekly number of cases per 100k of population"""
prediction = prediction[..., 2] # Case removals
prediction = prediction.reset_coords(drop=True)
# TODO: Find better way to sum up into weeks other than
# a list comprehension.
dates = pd.DatetimeIndex(prediction.coords["time"].data)
first_sunday_index = np.where(dates.weekday == 6)[0][0]
weeks = range(first_sunday_index, prediction.coords["time"].shape[0], 7)[
:-1
]
week_incidence = [
prediction[..., week : (week + 7)].sum(dim="time") for week in weeks
]
week_incidence = xarray.concat(
week_incidence, dim=prediction.coords["time"][weeks]
)
week_incidence = week_incidence.transpose(
*prediction.dims, transpose_coords=True
)
# Divide by population sizes
week_incidence = (
week_incidence / np.array(popsize)[np.newaxis, :, np.newaxis] * 100000
)
return xarray2summarydf(week_incidence)
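# A toy sketch of the weekly binning (dates illustrative): daily
# predictions are summed into Sunday-anchored weeks, dropping the final
# partial week.
#
# >>> import numpy as np
# >>> import pandas as pd
# >>> dates = pd.date_range("2020-10-01", periods=21)  # starts on a Thursday
# >>> first_sunday_index = np.where(dates.weekday == 6)[0][0]
# >>> list(range(first_sunday_index, len(dates), 7))[:-1]
# [3, 10]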
def summary_longformat(input_files, output_file):
"""Draws together pipeline results into a long format
csv file.
:param input_files: a list of filenames [data_pkl,
insample7_nc
insample14_nc,
medium_term_pred_nc,
ngm_nc]
:param output_file: the output CSV with columns `[date,
location,value_name,value,q0.025,q0.975]`
"""
data = xarray.open_dataset(input_files[0], group="constant_data")
cases = xarray.open_dataset(input_files[0], group="observations")["cases"]
df = cases.to_dataframe(name="value").reset_index()
df["value_name"] = "newCasesBySpecimenDate"
df["0.05"] = np.nan
df["0.5"] = np.nan
df["0.95"] = np.nan
# Insample predictive incidence
insample = xarray.open_dataset(input_files[1], group="predictions")
insample_df = xarray2summarydf(
insample["events"][..., 2].reset_coords(drop=True)
)
insample_df["value_name"] = "insample7_Cases"
df = pd.concat([df, insample_df], axis="index")
insample = xarray.open_dataset(input_files[2], group="predictions")
insample_df = xarray2summarydf(
insample["events"][..., 2].reset_coords(drop=True)
)
insample_df["value_name"] = "insample14_Cases"
df = pd.concat([df, insample_df], axis="index")
# Medium term absolute incidence
medium_term = xarray.open_dataset(input_files[3], group="predictions")
medium_df = xarray2summarydf(
medium_term["events"][..., 2].reset_coords(drop=True)
)
medium_df["value_name"] = "absolute_incidence"
df = pd.concat([df, medium_df], axis="index")
# Cumulative cases
medium_df = xarray2summarydf(
medium_term["events"][..., 2].cumsum(dim="time").reset_coords(drop=True)
)
medium_df["value_name"] = "cumulative_absolute_incidence"
df = pd.concat([df, medium_df], axis="index")
# Medium term incidence per 100k
medium_df = xarray2summarydf(
(
medium_term["events"][..., 2].reset_coords(drop=True)
/ np.array(data["N"])[np.newaxis, :, np.newaxis]
)
* 100000
)
medium_df["value_name"] = "incidence_per_100k"
df = pd.concat([df, medium_df], axis="index")
# Weekly incidence per 100k
weekly_incidence = weekly_pred_cases_per_100k(
medium_term["events"], data["N"]
)
weekly_incidence["value_name"] = "weekly_cases_per_100k"
df = pd.concat([df, weekly_incidence], axis="index")
# Medium term prevalence
prev_df = prevalence(medium_term, data["N"])
prev_df["value_name"] = "prevalence"
df = pd.concat([df, prev_df], axis="index")
# Rt
rt = xarray.load_dataset(input_files[4], group="posterior_predictive")[
"R_it"
]
rt_summary = xarray2summarydf(rt.isel(time=-1))
rt_summary["value_name"] = "R"
rt_summary["time"] = rt.coords["time"].data[-1] + np.timedelta64(1, "D")
df = pd.concat([df, rt_summary], axis="index")
quantiles = df.columns[df.columns.str.startswith("0.")]
return make_dstl_template(
group="Lancaster",
model="SpatialStochasticSEIR",
scenario="Nowcast",
creation_date=date.today(),
version=model_spec.VERSION,
age_band="All",
geography=df["location"],
value_date=df["time"],
value_type=df["value_name"],
value=df["value"],
quantiles={q: df[q] for q in quantiles},
).to_excel(output_file, index=False)
if __name__ == "__main__":
import os
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--output", "-o", type=str, required=True, help="Output file"
)
parser.add_argument(
"resultsdir",
type=str,
help="Results directory",
)
args = parser.parse_args()
input_files = [
os.path.join(args.resultsdir, d)
for d in [
"pipeline_data.pkl",
"insample7.nc",
"insample14.nc",
"medium_term.nc",
"ngm.nc",
]
]
summary_longformat(input_files, args.output)
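# Example invocation (directory name hypothetical):
#     python summary_longformat.py -o summary.xlsx results_directory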