Commit 80bbe105 authored by Chris Jewell's avatar Chris Jewell
Browse files

Addition of insample posterior check plots

parent 9fe819d7
......@@ -10,6 +10,7 @@ import covid.tasks.summarize as summarize
from covid.tasks.within_between import within_between
from covid.tasks.case_exceedance import case_exceedance
from covid.tasks.summary_geopackage import summary_geopackage
from covid.tasks.insample_predictive_timeseries import insample_predictive_timeseries
__all__ = [
"assemble_data",
......@@ -22,4 +23,5 @@ __all__ = [
"within_between",
"case_exceedance",
"summary_geopackage",
"insample_predictive_timeseries",
]
"""Outputs a predictive timeseries for each LAD"""
"""Create insample plots for a given lag"""
import os
from pathlib import Path
import numpy as np
import yaml
import pickle as pkl
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from covid.cli_arg_parse import cli_args
from model_spec import gather_data
from covid.data import read_phe_cases
from covid.data import AreaCodeData
def get_dates(config):
return [np.datetime64(x) for x in config["Global"]["inference_period"]]
def load_cases(config):
return read_phe_cases(
config["data"]["reported_cases"],
*get_dates(config),
pillar=config["data"]["pillar"],
date_type=config["data"]["case_date_type"],
)
def load_prediction(config):
prediction_file = os.path.expandvars(
os.path.join(
config["output"]["results_dir"],
config["output"]["insample_prediction"],
)
)
with open(prediction_file, "rb") as f:
prediction = pkl.load(f)
return prediction
def plot_timeseries(prediction, data, dates, title):
"""Plots a predictive timeseries with data
......@@ -64,43 +31,66 @@ def plot_timeseries(prediction, data, dates, title):
plt.title(title)
fig.autofmt_xdate()
return fig
return fig
def main(config):
date_low, date_high = get_dates(config)
cases = load_cases(config)
prediction = load_prediction(config)[..., -1] # KxMxTxR
lads = AreaCodeData.process(config)
pred_mean = np.mean(prediction, axis=0)
pred_quants = np.quantile(
prediction, q=[0.025, 0.25, 0.5, 0.75, 0.975], axis=0
def insample_predictive_timeseries(input_files, output_dir, lag):
"""Creates insample plots
:param input_files: a list of [prediction_file, data_file] (see Details)
:param output_dir: the output dir to write files to
:param lag: the number of days at the end of the case timeseries for which to
plot the in-sample prediction.
:returns: `None` as output written to disc.
Details
-------
`data_file` is a pickled Python `dict` of data. It should have a member `cases`
which is a `xarray` with dimensions [`location`, `date`] giving the number of
detected cases in each `location` on each `date`.
`prediction_file` is assumed to be a pickled `xarray` of shape
`[K,M,T,R]` where `K` is the number of posterior samples, `M` is the number
of locations, `T` is the number of timepoints, `R` is the number of transitions
in the model. The prediction is assumed to start at `cases.coords['date'][-1] - lag`.
It is assumed that `T >= lag`.
A timeseries graph (png) summarizing for each `location` the prediction against the
observations is written to `output_dir`
"""
prediction_file, data_file = input_files
lag = int(lag)
with open(prediction_file, "rb") as f:
prediction = pkl.load(f)[..., :lag, -1] # removals
with open(data_file, "rb") as f:
data = pkl.load(f)
cases = data['cases']
lads = data['locations']
# TODO remove legacy code!
if 'lad19cd' in cases.dims:
cases = cases.rename({'lad19cd': 'location'})
output_dir = Path(output_dir)
output_dir.mkdir(exist_ok=True)
pred_mean = prediction.mean(dim='iteration')
pred_quants = prediction.quantile(
q=[0.025, 0.25, 0.5, 0.75, 0.975], dim='iteration',
)
pred_quants[2] = pred_mean
dates = np.arange(date_high - 14, date_high)
results_dir = Path(
os.path.join(config["output"]["results_dir"], "pred_ts_14day")
)
results_dir.mkdir(parents=False, exist_ok=True)
for i in range(cases.shape[0]):
title = lads["name"].iloc[i]
plot_timeseries(
pred_quants[:, i, :14], cases.iloc[i, -14:], dates, title,
for location in cases.coords['location']:
print("Location:", location.data)
fig = plot_timeseries(
pred_quants.loc[:, location, :],
cases.loc[location][-lag:],
cases.coords['date'][-lag:],
lads.loc[lads['lad19cd'] == location, 'name'].iloc[0],
)
plt.savefig(results_dir.joinpath(f"{lads['lad19cd'].iloc[i]}.png"))
if __name__ == "__main__":
args = cli_args()
with open(args.config, "r") as f:
config = yaml.load(f, Loader=yaml.FullLoader)
# Override config file results dir if necessary
if args.results is not None:
config["output"]["results_dir"] = args.results
main(config)
plt.savefig(output_dir.joinpath(f"{location.data}.png"))
plt.close()
......@@ -16,7 +16,7 @@ from covid.tasks import (
within_between,
case_exceedance,
summary_geopackage,
# lancs_spreadsheet,
insample_predictive_timeseries,
)
......@@ -179,12 +179,17 @@ if __name__ == "__main__":
)
df.to_csv(output_file)
# @rf.transform(
# input=[[process_data, insample7, insample14, medium_term]],
# filter=rf.formatter(),
# output=work_dir("total_predictive_timeseries.pdf")
# )(total_predictive_timeseries)
# Plot in-sample
@rf.transform(
input=[insample7, insample14],
filter=rf.formatter(".+/insample(?P<LAG>\d+).pkl"),
add_inputs=rf.add_inputs(process_data),
output="{path[0]}/insample_plots{LAG[0]}",
extras=["{LAG[0]}"],
)
def plot_insample_predictive_timeseries(input_files, output_dir, lag):
insample_predictive_timeseries(input_files, output_dir, lag)
# Geopackage
rf.transform(
[
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment