Commit 63519eb3 authored by Chris Jewell

Switched to NetCDF4 storage of predictive output

Changes:

1. Where xarrays were pickled to disc, we now use xarray.to_netcdf
2. Where xarrays were read from Python pickles, we now use xarray.load_dataset
parent cd339446
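For context, a minimal sketch of the round-trip this commit adopts (file and variable names are illustrative, not taken from the pipeline); it assumes a NetCDF backend such as the netCDF4 package is installed:

# Sketch: xarray's NetCDF round-trip replacing pickle.dump/pickle.load.
import numpy as np
import xarray

ds = xarray.Dataset({"events": (("iteration", "location"), np.zeros((10, 3)))})
ds.to_netcdf("example.nc")               # previously: pickle.dump(ds, f)
ds2 = xarray.load_dataset("example.nc")  # previously: pickle.load(f)
assert ds2["events"].shape == (10, 3)

Unlike a pickle, the resulting file is self-describing and readable by any NetCDF-aware tool, not just Python code that can import the writing classes.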
@@ -70,7 +70,7 @@ def run_pipeline(global_config, results_directory, cli_options):
     rf.transform(
         input=[[process_data, thin_samples]],
         filter=rf.formatter(),
-        output=wd("ngm.pkl"),
+        output=wd("ngm.nc"),
     )(next_generation_matrix)

     rf.transform(
@@ -83,7 +83,7 @@ def run_pipeline(global_config, results_directory, cli_options):
     @rf.transform(
         input=[[process_data, thin_samples]],
         filter=rf.formatter(),
-        output=wd("insample7.pkl"),
+        output=wd("insample7.nc"),
     )
     def insample7(input_files, output_file):
         predict(
@@ -97,7 +97,7 @@ def run_pipeline(global_config, results_directory, cli_options):
     @rf.transform(
         input=[[process_data, thin_samples]],
         filter=rf.formatter(),
-        output=wd("insample14.pkl"),
+        output=wd("insample14.nc"),
     )
     def insample14(input_files, output_file):
         return predict(
@@ -112,7 +112,7 @@ def run_pipeline(global_config, results_directory, cli_options):
     @rf.transform(
         input=[[process_data, thin_samples]],
         filter=rf.formatter(),
-        output=wd("medium_term.pkl"),
+        output=wd("medium_term.nc"),
     )
     def medium_term(input_files, output_file):
         return predict(
@@ -137,7 +137,7 @@ def run_pipeline(global_config, results_directory, cli_options):
     )(summarize.infec_incidence)

     rf.transform(
-        input=[[process_data, thin_samples, medium_term]],
+        input=[[process_data, medium_term]],
        filter=rf.formatter(),
         output=wd("prevalence_summary.csv"),
     )(summarize.prevalence)
@@ -165,7 +165,7 @@ def run_pipeline(global_config, results_directory, cli_options):
     # Plot in-sample
     @rf.transform(
         input=[insample7, insample14],
-        filter=rf.formatter(".+/insample(?P<LAG>\d+).pkl"),
+        filter=rf.formatter(".+/insample(?P<LAG>\d+).nc"),
         add_inputs=rf.add_inputs(process_data),
         output="{path[0]}/insample_plots{LAG[0]}",
         extras=["{LAG[0]}"],
......
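The only non-mechanical change in the pipeline file is the `rf.formatter` regex above, whose suffix must now match `.nc`. A sketch of how that filter works (paths hypothetical):

# Sketch: ruffus formatter with a named capture group. LAG is extracted
# from the input filename and substituted into output/extras, so the
# migration only required changing ".pkl" to ".nc" in the regex.
import ruffus as rf

@rf.transform(
    input=["results/insample7.nc", "results/insample14.nc"],
    filter=rf.formatter(r".+/insample(?P<LAG>\d+)\.nc"),
    output="{path[0]}/insample_plots{LAG[0]}",
    extras=["{LAG[0]}"],
)
def make_plots(input_file, output_dir, lag):
    ...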
@@ -2,7 +2,7 @@
 import numpy as np
 import pickle as pkl
 import pandas as pd
+import xarray


 def case_exceedance(input_files, lag):
@@ -17,8 +17,7 @@ def case_exceedance(input_files, lag):

     with open(data_file, "rb") as f:
         data = pkl.load(f)
-    with open(prediction_file, "rb") as f:
-        prediction = pkl.load(f)
+    prediction = xarray.open_dataset(prediction_file)["events"]

     modelled_cases = np.sum(prediction[..., :lag, -1], axis=-1)
     observed_cases = np.sum(data["cases"][:, -lag:], axis=-1)
......
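Note that the two readers used in this commit behave differently: `xarray.open_dataset` is lazy and keeps the file handle open, while `xarray.load_dataset` reads eagerly and closes the file. A sketch (file name illustrative):

# Sketch: lazy vs eager NetCDF reads.
import xarray

events = xarray.open_dataset("insample7.nc")["events"]        # lazy; handle stays open
events_eager = xarray.load_dataset("insample7.nc")["events"]  # eager; file is closed

Eager loading is the safer default when the file might be rewritten by a rerun of an upstream task.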
"""Create insample plots for a given lag"""
import pickle as pkl
import numpy as np
import xarray
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use("Agg")
def plot_timeseries(prediction, data, dates, title):
def plot_timeseries(mean, quantiles, data, dates, title):
"""Plots a predictive timeseries with data
:param prediction: a [5, T]-shaped array with first dimension
@@ -21,12 +24,12 @@ def plot_timeseries(prediction, data, dates, title):
     # In-sample prediction
     plt.fill_between(
-        dates, y1=prediction[0], y2=prediction[-1], color="lightblue", alpha=0.5
+        dates, y1=quantiles[0], y2=quantiles[-1], color="lightblue", alpha=0.5
     )
     plt.fill_between(
-        dates, y1=prediction[1], y2=prediction[-2], color="lightblue", alpha=1.0
+        dates, y1=quantiles[1], y2=quantiles[-2], color="lightblue", alpha=1.0
     )
-    plt.plot(dates, prediction[2], color="blue")
+    plt.plot(dates, mean, color="blue")
     plt.plot(dates, data, "+", color="red")
     plt.title(title)
@@ -35,7 +38,6 @@ def plot_timeseries(prediction, data, dates, title):
     return fig


 def insample_predictive_timeseries(input_files, output_dir, lag):
     """Creates insample plots
@@ -48,49 +50,51 @@ def insample_predictive_timeseries(input_files, output_dir, lag):
     Details
     -------
     `data_file` is a pickled Python `dict` of data. It should have a member `cases`
-    which is a `xarray` with dimensions [`location`, `date`] giving the number of
+    which is a `xarray` with dimensions [`location`, `date`] giving the number of
     detected cases in each `location` on each `date`.

-    `prediction_file` is assumed to be a pickled `xarray` of shape
+    `prediction_file` is assumed to be a pickled `xarray` of shape
     `[K,M,T,R]` where `K` is the number of posterior samples, `M` is the number
     of locations, `T` is the number of timepoints, `R` is the number of transitions
     in the model. The prediction is assumed to start at `cases.coords['date'][-1] - lag`.
     It is assumed that `T >= lag`.

     A timeseries graph (png) summarizing for each `location` the prediction against the
     observations is written to `output_dir`
     """
     prediction_file, data_file = input_files
     lag = int(lag)

-    with open(prediction_file, "rb") as f:
-        prediction = pkl.load(f)[..., :lag, -1]  # removals
+    prediction = xarray.open_dataset(prediction_file)["events"]
+    prediction = prediction[..., :lag, -1]  # Just removals

     with open(data_file, "rb") as f:
         data = pkl.load(f)
-    cases = data['cases']
-    lads = data['locations']
+    cases = data["cases"]
+    lads = data["locations"]

     # TODO remove legacy code!
-    if 'lad19cd' in cases.dims:
-        cases = cases.rename({'lad19cd': 'location'})
+    if "lad19cd" in cases.dims:
+        cases = cases.rename({"lad19cd": "location"})

     output_dir = Path(output_dir)
     output_dir.mkdir(exist_ok=True)

-    pred_mean = prediction.mean(dim='iteration')
+    pred_mean = prediction.mean(dim="iteration")
     pred_quants = prediction.quantile(
-        q=[0.025, 0.25, 0.5, 0.75, 0.975], dim='iteration',
+        q=[0.025, 0.25, 0.5, 0.75, 0.975],
+        dim="iteration",
     )

-    for location in cases.coords['location']:
+    for location in cases.coords["location"]:
         print("Location:", location.data)
         fig = plot_timeseries(
+            pred_mean.loc[location, :],
             pred_quants.loc[:, location, :],
             cases.loc[location][-lag:],
-            cases.coords['date'][-lag:],
-            lads.loc[lads['lad19cd'] == location, 'name'].iloc[0],
+            cases.coords["date"][-lag:],
+            lads.loc[lads["lad19cd"] == location, "name"].iloc[0],
         )
         plt.savefig(output_dir.joinpath(f"{location.data}.png"))
         plt.close()
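The refactored call passes the posterior mean and the [5, T] quantile array separately. A condensed sketch of the fan-plot pattern with synthetic draws:

# Sketch: mean line plus nested 95%/50% quantile bands, as plot_timeseries
# now draws them. "iteration" indexes posterior samples.
import numpy as np
import xarray
import matplotlib.pyplot as plt

draws = xarray.DataArray(
    np.random.gamma(2.0, 1.0, size=(200, 30)), dims=("iteration", "time")
)
mean = draws.mean(dim="iteration")
quantiles = draws.quantile(q=[0.025, 0.25, 0.5, 0.75, 0.975], dim="iteration")

t = np.arange(30)
plt.fill_between(t, quantiles[0], quantiles[-1], color="lightblue", alpha=0.5)
plt.fill_between(t, quantiles[1], quantiles[-2], color="lightblue", alpha=1.0)
plt.plot(t, mean, color="blue")
plt.savefig("fan.png")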
@@ -61,9 +61,10 @@ def next_generation_matrix(input_files, output_file):
         ],
         dims=["iteration", "dest", "src"],
     )
+    ngm = xarray.Dataset({"ngm": ngm})

     # Output
-    with open(output_file, "wb") as f:
-        pkl.dump(ngm, f)
+    ngm.to_netcdf(output_file)


 if __name__ == "__main__":
......
"""Calculates overall Rt given a posterior next generation matix"""
import numpy as np
import pickle as pkl
import xarray
import pandas as pd
from covid.summary import (
......@@ -12,9 +12,7 @@ from covid.summary import (
def overall_rt(next_generation_matrix, output_file):
with open(next_generation_matrix, "rb") as f:
ngms = pkl.load(f)
ngms = xarray.open_dataset(next_generation_matrix)["ngm"]
b, _ = power_iteration(ngms)
rt = rayleigh_quotient(ngms, b)
q = np.arange(0.05, 1.0, 0.05)
......
@@ -85,10 +85,19 @@ def predict(data, posterior_samples, output_file, initial_step, num_steps):
         ],
         dims=("iteration", "location", "time", "event"),
     )
-    prediction.attrs["initial_state"] = estimated_init_state
-    with open(output_file, "wb") as f:
-        pkl.dump(prediction, f)
+    estimated_init_state = xarray.DataArray(
+        estimated_init_state,
+        coords=[
+            np.arange(estimated_init_state.shape[0]),
+            covar_data["locations"]["lad19cd"],
+            np.arange(estimated_init_state.shape[-1]),
+        ],
+        dims=("iteration", "location", "state"),
+    )
+    ds = xarray.Dataset(
+        {"events": prediction, "initial_state": estimated_init_state}
+    )
+    ds.to_netcdf(output_file)


 if __name__ == "__main__":
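A likely motivation for the restructuring above: `initial_state` was previously stashed in `prediction.attrs`, but NetCDF attributes are metadata fields and cannot reliably hold a multi-dimensional array, so it is promoted to a proper Dataset variable with named dims. A sketch of the pattern (shapes illustrative):

# Sketch: bundling events and initial state as two variables of one
# Dataset so both survive the NetCDF round-trip.
import numpy as np
import xarray

events = xarray.DataArray(
    np.zeros((100, 5, 14, 4)), dims=("iteration", "location", "time", "event")
)
init_state = xarray.DataArray(
    np.zeros((100, 5, 4)), dims=("iteration", "location", "state")
)
ds = xarray.Dataset({"events": events, "initial_state": init_state})
ds.to_netcdf("prediction.nc")  # an ndarray in ds.attrs would fail to serialize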
@@ -2,6 +2,7 @@
 import numpy as np
 import pickle as pkl
+import xarray
 import pandas as pd

 from gemlib.util import compute_state
@@ -19,8 +20,7 @@ def rt(input_file, output_file):

     :param output_file: a .csv of mean (ci) values
     """
-    with open(input_file, "rb") as f:
-        ngm = pkl.load(f)
+    ngm = xarray.open_dataset(input_file)["ngm"]

     rt = np.sum(ngm, axis=-2)
     rt_summary = mean_and_ci(rt, name="Rt")
@@ -41,8 +41,7 @@ def infec_incidence(input_file, output_file):

     :param output_file: csv with prediction summaries
     """
-    with open(input_file, "rb") as f:
-        prediction = pkl.load(f)
+    prediction = xarray.open_dataset(input_file)["events"]

     offset = 4
     timepoints = SUMMARY_DAYS + offset
@@ -72,7 +71,7 @@ def prevalence(input_files, output_file):
     """Reconstruct predicted prevalence from
        original data and projection.

-    :param input_files: a list of [data pickle, samples pickle, prediction pickle]
+    :param input_files: a list of [data pickle, prediction netCDF]
     :param output_file: a csv containing prevalence summary
     """
     offset = 4  # Account for recording lag
@@ -81,17 +80,10 @@ def prevalence(input_files, output_file):
     with open(input_files[0], "rb") as f:
         data = pkl.load(f)
-    with open(input_files[1], "rb") as f:
-        samples = pkl.load(f)
+    prediction = xarray.open_dataset(input_files[1])
-    with open(input_files[2], "rb") as f:
-        prediction = pkl.load(f)
-
-    insample_state = compute_state(
-        samples["init_state"], samples["seir"], STOICHIOMETRY
-    )

     predicted_state = compute_state(
-        insample_state[..., -1, :], prediction, STOICHIOMETRY
+        prediction["initial_state"], prediction["events"], STOICHIOMETRY
     )

     def calc_prev(state, name=None):
......
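`compute_state` (from `gemlib.util`) reconstructs the state timeseries by accumulating events onto an initial state through the stoichiometry matrix, which is why the prediction file must now carry `initial_state` alongside `events`. A rough numpy equivalent, an assumption about its semantics rather than gemlib's actual code:

# Sketch: state reconstruction from transition events in an SEIR model.
import numpy as np

# Each row maps one event type (S->E, E->I, I->R) to state increments.
STOICHIOMETRY = np.array(
    [[-1, 1, 0, 0], [0, -1, 1, 0], [0, 0, -1, 1]]
)

def compute_state_np(init_state, events):
    # init_state: [M, S] initial counts; events: [M, T, R] event counts
    increments = events @ STOICHIOMETRY  # [M, T, S] per-step state changes
    return init_state[:, None, :] + np.cumsum(increments, axis=1)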
@@ -20,14 +20,16 @@ def xarray2summarydf(arr):
     return ds.to_dataframe().reset_index()


-def prevalence(events, popsize):
-    prev = compute_state(events.attrs["initial_state"], events, STOICHIOMETRY)
+def prevalence(prediction, popsize):
+    prev = compute_state(
+        prediction["initial_state"], prediction["events"], STOICHIOMETRY
+    )
     prev = xarray.DataArray(
         prev.numpy(),
         coords=[
             np.arange(prev.shape[0]),
-            events.coords["location"],
-            events.coords["time"],
+            prediction.coords["location"],
+            prediction.coords["time"],
             np.arange(prev.shape[-1]),
         ],
         dims=["iteration", "location", "time", "state"],
@@ -42,8 +44,8 @@ def prevalence(events, popsize):

 def weekly_pred_cases_per_100k(prediction, popsize):
     """Returns weekly number of cases per 100k of population"""

-    prediction = prediction[..., 2]  # Case removals
+    prediction = prediction[..., 2]  # Case removals
     prediction = prediction.reset_coords(drop=True)

     # TODO: Find better way to sum up into weeks other than
@@ -87,21 +89,25 @@ def summary_longformat(input_files, output_file):
         df["0.95"] = np.nan

     # Insample predictive incidence
-    with open(input_files[1], "rb") as f:
-        insample = pkl.load(f)
-    insample_df = xarray2summarydf(insample[..., 2].reset_coords(drop=True))
+    insample = xarray.open_dataset(input_files[1])
+    insample_df = xarray2summarydf(
+        insample["events"][..., 2].reset_coords(drop=True)
+    )
     insample_df["value_name"] = "insample14_Cases"
     df = pd.concat([df, insample_df], axis="index")

     # Medium term incidence
-    with open(input_files[2], "rb") as f:
-        medium_term = pkl.load(f)
-    medium_df = xarray2summarydf(medium_term[..., 2].reset_coords(drop=True))
+    medium_term = xarray.open_dataset(input_files[2])
+    medium_df = xarray2summarydf(
+        medium_term["events"][..., 2].reset_coords(drop=True)
+    )
     medium_df["value_name"] = "Cases"
     df = pd.concat([df, medium_df], axis="index")

     # Weekly incidence per 100k
-    weekly_incidence = weekly_pred_cases_per_100k(medium_term, data["N"])
+    weekly_incidence = weekly_pred_cases_per_100k(
+        medium_term["events"], data["N"]
+    )
     weekly_incidence["value_name"] = "weekly_cases_per_100k"
     df = pd.concat([df, weekly_incidence], axis="index")
@@ -111,8 +117,7 @@ def summary_longformat(input_files, output_file):
     df = pd.concat([df, prev_df], axis="index")

     # Rt
-    with open(input_files[3], "rb") as f:
-        ngms = pkl.load(f)
+    ngms = xarray.load_dataset(input_files[3])["ngm"]
     rt = ngms.sum(dim="dest")
     rt = rt.rename({"src": "location"})
     rt_summary = xarray2summarydf(rt)
......
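A practical payoff of the new format: pipeline outputs can now be inspected without unpickling, and therefore without importing (or trusting) the code that wrote them. For example (path illustrative):

# Sketch: NetCDF files are self-describing.
import xarray

ds = xarray.open_dataset("medium_term.nc")
print(ds)                 # lists variables, dims, coords
print(ds["events"].dims)  # e.g. ("iteration", "location", "time", "event")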