Commit fb384975 authored by Chris Jewell

Merge branch 'refactor-pipeline' into 'master'

Major pipeline refactor

See merge request !1
parents c8546479 251439e6
*~
test_config.yaml
.ruffus_history.sqlite
__pycache__
FROM nvidia/cuda:11.0-cudnn8-runtime-ubuntu18.04
ENV PYTHONFAULTHANDLER=1 \
PYTHONHASHSEED=random \
PYTHONUNBUFFERED=1
WORKDIR /app
VOLUME /results
COPY . /app
ENV PIP_DEFAULT_TIMEOUT=100 \
PIP_DISABLE_PIP_VERSION_CHECK=1 \
PIP_NO_CACHE_DIR=1
# System utils (a single layer keeps the image smaller)
RUN apt-get update && apt-get install -y \
        curl \
        git \
        python3.7 \
        python3-distutils \
    && rm -rf /var/lib/apt/lists/*
# Upgrade to latest pip to ensure PEP517 compatibility
RUN curl -O https://bootstrap.pypa.io/get-pip.py
RUN /usr/bin/python3.7 get-pip.py
# Install covid_pipeline
RUN python3.7 -m pip install .
ENTRYPOINT ["/usr/bin/python3.7", "-m", "covid_pipeline.pipeline", "-c", "config.yaml", "-r", "/results"]
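# Example build/run (illustrative tag and host paths; --gpus assumes the
# NVIDIA container toolkit is installed on the host):
#
#   docker build -t covid-pipeline .
#   docker run --gpus all -v /path/to/results:/results covid-pipeline
#
# The ENTRYPOINT reads the config.yaml baked into /app by the COPY above and
# writes pipeline output to the /results volume.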
# Covid stochastic model configuration

ProcessData:
  date_range:
    - 2020-10-09
    - 2021-01-01
  mobility_matrix: data/mergedflows.csv
  population_size: data/c2019modagepop.csv
  commute_volume:  # Can be replaced by DfT traffic flow data - contact authors <c.jewell@lancaster.ac.uk>
  geopackage: data/UK2019mod_pop.gpkg

  CasesData:
    input: url
    address: https://api.coronavirus.data.gov.uk/v2/data?areaType=ltla&metric=newCasesBySpecimenDate&format=json
    pillars: None  # Capability to filter Pillar 1 and 2 testing data from the PHE confidential line listing
    measure: None  # Capability to filter by date of test report from the PHE confidential line listing
    format: gov

  AreaCodeData:
    input: json
    address: "https://services1.arcgis.com/ESMARspQHYMw9BZ9/arcgis/rest/services/LAD_APR_2019_UK_NC/FeatureServer/0/query?where=1%3D1&outFields=LAD19CD,LAD19NM&returnGeometry=false&returnDistinctValues=true&orderByFields=LAD19CD&outSR=4326&f=json"
    format: ons
    regions:
      - S  # Scotland
      - E  # England
      - W  # Wales
      - N  # Northern Ireland

Mcmc:
  dmax: 84  # Max distance to move events
  nmax: 25  # Max number of events per metapopulation/time to move
  m: 1  # Number of metapopulations to move
  occult_nmax: 8  # Max number of occults to add/delete per metapopulation/time
  num_event_time_updates: 380  # Number of event and occult updates per sweep of the Gibbs MCMC sampler
  num_bursts: 100  # Number of MCMC bursts of `num_burst_samples`
  num_burst_samples: 50  # Number of MCMC samples per burst
  thin: 20  # Thin MCMC samples every `thin` iterations
  num_adaptation_iterations: 1000

ThinPosterior:  # Post-processing: further thinning of the chain, HDF5 -> .pkl
  start: 3000
  end:
  by: 4

Geopackage:  # covid_pipeline.tasks.summary_geopackage
  base_geopackage: data/UK2019mod_pop.gpkg
  base_layer: UK2019mod_pop_xgen
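# Note: ProcessData.date_range can be overridden at runtime via the
# pipeline's --date-range switch (see covid_pipeline.pipeline), e.g.
#   python -m covid_pipeline.pipeline -c config.yaml -r /results \
#       --date-range 2020-10-09 2021-01-01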
"""Provides functions to format data"""
import pandas as pd
def _expand_quantiles(q_dict):
"""Expand a dictionary of quantiles"""
q_str = [
"0.05",
"0.1",
"0.15",
"0.2",
"0.25",
"0.3",
"0.35",
"0.4",
"0.45",
"0.5",
"0.55",
"0.6",
"0.65",
"0.7",
"0.75",
"0.8",
"0.85",
"0.9",
"0.95",
]
quantiles = {f"Quantile {q}": None for q in q_str}
if q_dict is None:
return quantiles
for k, v in q_dict.items():
q_key = f"Quantile {float(k)}" # Coerce back to float to strip trailing 0s
if q_key not in quantiles.keys():
raise KeyError(f"quantile '{k}' not compatible with template form")
quantiles[q_key] = v
return [pd.Series(v, name=k).reset_index(drop=True) for k, v in quantiles.items()]
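# For example (illustrative values):
#   _expand_quantiles({"0.05": [1.0], "0.95": [2.0]})
# returns 19 pd.Series named "Quantile 0.05" ... "Quantile 0.95"; the two
# supplied quantiles are populated and the remainder are left empty.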
def _split_dates(dates):
if dates is None:
return {"day": None, "month": None, "year": None}
if hasattr(dates, "__iter__"):
dx = pd.DatetimeIndex(dates)
else:
dx = pd.DatetimeIndex([dates])
return {"day": dx.day, "month": dx.month, "year": dx.year}
def make_dstl_template(
group=None,
model=None,
scenario=None,
model_type=None,
version=None,
creation_date=None,
value_date=None,
age_band=None,
geography=None,
value_type=None,
value=None,
quantiles=None,
):
"""Formats a DSTL-type Excel results template"""
# Process date
creation_date_parts = _split_dates(creation_date)
value_date_parts = _split_dates(value_date)
quantile_series = _expand_quantiles(quantiles)
    # DSTL requires only the MAJOR and MINOR version numbers
version = ".".join(version.split(".")[:2])
fields = {
"Group": group,
"Model": model,
"Scenario": scenario,
"ModelType": model_type,
"Version": version,
"Creation Day": creation_date_parts["day"],
"Creation Month": creation_date_parts["month"],
"Creation Year": creation_date_parts["year"],
"Day of Value": value_date_parts["day"],
"Month of Value": value_date_parts["month"],
"Year of Value": value_date_parts["year"],
"AgeBand": age_band,
"Geography": geography,
"ValueType": value_type,
"Value": value,
}
fields = [pd.Series(v, name=k).reset_index(drop=True) for k, v in fields.items()]
return pd.concat(fields + quantile_series, axis="columns").ffill(axis="index")
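if __name__ == "__main__":
    # Minimal smoke test of the formatter; all values below are illustrative,
    # not real model output.
    from datetime import date

    template = make_dstl_template(
        group="Lancaster",
        model="StochSpatMetaPopSEIR",
        scenario="Nowcast",
        model_type="Pillar Two Testing",
        version="0.5.1",
        creation_date=date.today(),
        value_date=pd.date_range("2021-01-01", periods=2),
        age_band="All",
        geography="England",
        value_type="incidence",
        value=[10.0, 12.0],
        quantiles={"0.05": [8.0, 9.0], "0.95": [13.0, 15.0]},
    )
    print(template.head())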
"""A Ruffus-ised pipeline for COVID-19 analysis"""
import yaml
import datetime
import ruffus as rf
from covid_pipeline.ruffus_pipeline import run_pipeline
def _import_global_config(config_file):
with open(config_file, "r") as f:
config = yaml.load(f, Loader=yaml.FullLoader)
return config
if __name__ == "__main__":
    # Ruffus's wrapper around argparse gives us the standard Ruffus
    # command-line switches in addition to our own options
argparser = rf.cmdline.get_argparse(description="COVID-19 pipeline")
data_args = argparser.add_argument_group(
"Data options", "Options controlling input data"
)
data_args.add_argument(
"-c", "--config", type=str, help="global configuration file", required=True,
)
data_args.add_argument(
"-r",
"--results-directory",
type=str,
help="pipeline results directory",
required=True,
)
data_args.add_argument(
"--date-range",
type=lambda s: datetime.datetime.strptime(s, "%Y-%m-%d"),
nargs=2,
help="Date range [low high)",
        metavar="ISO8601",
)
cli_options = argparser.parse_args()
global_config = _import_global_config(cli_options.config)
if cli_options.date_range is not None:
global_config["ProcessData"]["date_range"][0] = cli_options.date_range[0]
global_config["ProcessData"]["date_range"][1] = cli_options.date_range[1]
run_pipeline(global_config, cli_options.results_directory, cli_options)
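# Example invocation (illustrative paths; rf.cmdline.get_argparse also adds
# the standard Ruffus switches such as --verbose and --just_print):
#
#   python -m covid_pipeline.pipeline -c config.yaml -r /results \
#       --date-range 2020-10-09 2021-01-01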
"""Represents the analytic pipeline as a ruffus chain"""
import os
from datetime import datetime
from uuid import uuid1
import json
import yaml
import netCDF4 as nc
import pandas as pd
import ruffus as rf
from covid19uk import (
assemble_data,
mcmc,
thin_posterior,
predict,
reproduction_number,
within_between,
__version__ as covid19uk_version,
)
from covid_pipeline.tasks import (
overall_rt,
summarize,
case_exceedance,
summary_geopackage,
summary_longformat,
crystalcast_output,
)
__all__ = ["run_pipeline"]
def _make_append_work_dir(work_dir):
return lambda filename: os.path.expandvars(os.path.join(work_dir, filename))
def _create_metadata(config):
return dict(
pipeline_id=uuid1().hex,
created_at=str(datetime.now()),
inference_library="GEM",
inference_library_version="0.1.1-alpha.1",
model_version=covid19uk_version,
pipeline_config=json.dumps(config, default=str),
)
def _create_nc_file(output_file, meta_data_dict):
nc_file = nc.Dataset(output_file, "w", format="NETCDF4")
for k, v in meta_data_dict.items():
setattr(nc_file, k, v)
nc_file.close()
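# The metadata attributes written above can be read back from the pipeline's
# netCDF output, e.g. (illustrative):
#
#   ds = nc.Dataset("inferencedata.nc", "r")
#   print(ds.pipeline_id, ds.created_at, ds.model_version)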
def run_pipeline(global_config, results_directory, cli_options):
wd = _make_append_work_dir(results_directory)
pipeline_meta = _create_metadata(global_config)
# Pipeline starts here
@rf.mkdir(results_directory)
@rf.originate(wd("config.yaml"), global_config)
def save_config(output_file, config):
with open(output_file, "w") as f:
yaml.dump(config, f)
@rf.follows(save_config)
@rf.originate(wd("inferencedata.nc"), global_config)
def process_data(output_file, config):
_create_nc_file(output_file, pipeline_meta)
assemble_data(output_file, config["ProcessData"])
@rf.transform(
process_data, rf.formatter(), wd("posterior.hd5"), global_config,
)
def run_mcmc(input_file, output_file, config):
mcmc(input_file, output_file, config["Mcmc"])
@rf.transform(
input=run_mcmc,
filter=rf.formatter(),
output=wd("thin_samples.pkl"),
extras=[global_config],
)
def thin_samples(input_file, output_file, config):
thin_posterior(input_file, output_file, config["ThinPosterior"])
# Rt related steps
rf.transform(
input=[[process_data, thin_samples]],
filter=rf.formatter(),
output=wd("reproduction_number.nc"),
)(reproduction_number)
rf.transform(
input=reproduction_number, filter=rf.formatter(), output=wd("national_rt.xlsx"),
)(overall_rt)
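    # Note: rf.transform(...)(fn) is the functional equivalent of decorating
    # fn with @rf.transform(...).  It registers imported task functions such
    # as reproduction_number and overall_rt with the pipeline without needing
    # local wrapper functions.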
# In-sample prediction
@rf.transform(
input=[[process_data, thin_samples]],
filter=rf.formatter(),
output=wd("insample7.nc"),
)
def insample7(input_files, output_file):
predict(
data=input_files[0],
posterior_samples=input_files[1],
output_file=output_file,
initial_step=-7,
num_steps=28,
out_of_sample=True,
)
@rf.transform(
input=[[process_data, thin_samples]],
filter=rf.formatter(),
output=wd("insample14.nc"),
)
def insample14(input_files, output_file):
return predict(
data=input_files[0],
posterior_samples=input_files[1],
output_file=output_file,
initial_step=-14,
num_steps=28,
out_of_sample=True,
)
# Medium-term prediction
@rf.transform(
input=[[process_data, thin_samples]],
filter=rf.formatter(),
output=wd("medium_term.nc"),
)
def medium_term(input_files, output_file):
return predict(
data=input_files[0],
posterior_samples=input_files[1],
output_file=output_file,
initial_step=-1,
num_steps=84,
out_of_sample=True,
)
# Summarisation
rf.transform(
input=reproduction_number, filter=rf.formatter(), output=wd("rt_summary.csv"),
)(summarize.rt)
rf.transform(
input=medium_term,
filter=rf.formatter(),
output=wd("infec_incidence_summary.csv"),
)(summarize.infec_incidence)
rf.transform(
input=[[process_data, medium_term]],
filter=rf.formatter(),
output=wd("prevalence_summary.csv"),
)(summarize.prevalence)
rf.transform(
input=[[process_data, thin_samples]],
filter=rf.formatter(),
output=wd("within_between_summary.csv"),
)(within_between)
@rf.transform(
input=[[process_data, insample7, insample14]],
filter=rf.formatter(),
output=wd("exceedance_summary.csv"),
)
def exceedance(input_files, output_file):
exceed7 = case_exceedance((input_files[0], input_files[1]), 7)
exceed14 = case_exceedance((input_files[0], input_files[2]), 14)
df = pd.DataFrame(
{"Pr(pred<obs)_7": exceed7, "Pr(pred<obs)_14": exceed14},
index=exceed7.coords["location"],
)
df.to_csv(output_file)
# Plot in-sample
# @rf.transform(
# input=[insample7, insample14],
# filter=rf.formatter(".+/insample(?P<LAG>\d+).nc"),
# add_inputs=rf.add_inputs(process_data),
# output="{path[0]}/insample_plots{LAG[0]}",
# extras=["{LAG[0]}"],
# )
# def plot_insample_predictive_timeseries(input_files, output_dir, lag):
# insample_predictive_timeseries(input_files, output_dir, lag)
# Geopackage
rf.transform(
[
[
process_data,
summarize.rt,
summarize.infec_incidence,
summarize.prevalence,
within_between,
exceedance,
]
],
rf.formatter(),
wd("prediction.gpkg"),
global_config["Geopackage"],
)(summary_geopackage)
rf.transform(
input=[[process_data, thin_samples, reproduction_number]],
filter=rf.formatter(),
output=wd("crystalcast.xlsx"),
)(crystalcast_output)
# DSTL Summary
rf.transform(
[[process_data, insample7, insample14, medium_term, reproduction_number,]],
rf.formatter(),
wd("summary_longformat.xlsx"),
)(summary_longformat)
rf.cmdline.run(cli_options)
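# Task graph (sketch):
#   save_config -> process_data -> run_mcmc -> thin_samples
#   (process_data, thin_samples) -> reproduction_number -> overall_rt, summarize.rt
#   (process_data, thin_samples) -> insample7, insample14 -> exceedance
#   (process_data, thin_samples) -> medium_term -> incidence/prevalence summaries
#   (process_data, summaries, exceedance) -> summary_geopackage
#   (process_data, thin_samples, reproduction_number) -> crystalcast_output
#   (process_data, predictions, reproduction_number) -> summary_longformat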
"""covid_pipeline provides a pipeline and downstream results-summarisation methods"""
from covid_pipeline.tasks.overall_rt import overall_rt
from covid_pipeline.tasks.case_exceedance import case_exceedance
from covid_pipeline.tasks.insample_predictive_timeseries import (
insample_predictive_timeseries,
)
from covid_pipeline.tasks.summary_geopackage import summary_geopackage
from covid_pipeline.tasks.summary_longformat import summary_longformat
import covid_pipeline.tasks.summarize as summarize
from covid_pipeline.tasks.crystalcast_output import crystalcast_output
__all__ = [
"overall_rt",
"case_exceedance",
"insample_predictive_timeseries",
"summary_geopackage",
"summary_longformat",
"summarize",
"crystalcast_output",
]
"""Calculates case exceedance probabilities"""
import numpy as np
import xarray
def case_exceedance(input_files, lag):
"""Calculates case exceedance probabilities,
i.e. Pr(pred[lag:] < observed[lag:])
:param input_files: [data pickle, prediction pickle]
:param lag: the lag for which to calculate the exceedance
"""
data_file, prediction_file = input_files
data = xarray.open_dataset(data_file, group="observations")
prediction = xarray.open_dataset(prediction_file, group="predictions")[
"events"
]
modelled_cases = np.sum(prediction[..., :lag, -1], axis=-1)
observed_cases = np.sum(data["cases"][:, -lag:], axis=-1)
exceedance = np.mean(modelled_cases < observed_cases, axis=0)
return exceedance
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser(
description="Calculates case exceedance probabilities"
)
parser.add_argument("data_file", type=str)
parser.add_argument("prediction_file", type=str)
parser.add_argument(
"-l",
"--lag",
type=int,
help="The lag for which to calculate exceedance",
default=7,
)
parser.add_argument(
"-o",
"--output",
type=str,
help="The output csv",
default="exceedance.csv",
)
args = parser.parse_args()
    exceedance = case_exceedance([args.data_file, args.prediction_file], args.lag)
    # case_exceedance returns an xarray.DataArray, which has no to_csv method;
    # convert to pandas first
    exceedance.to_pandas().to_csv(args.output)
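# Example (illustrative file names, matching the pipeline's outputs):
#
#   python -m covid_pipeline.tasks.case_exceedance \
#       inferencedata.nc insample7.nc --lag 7 -o exceedance7.csv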
"""Creates CrystalCast formatted report for incidence and prevalence"""
from pathlib import Path
from datetime import date
import pickle as pkl
import numpy as np
import xarray
import pandas as pd
from gemlib.util import compute_state
import covid19uk
from covid19uk import model_spec
from covid_pipeline.formats import make_dstl_template
QUANTILES = (0.05, 0.25, 0.5, 0.75, 0.95)
def _events2xarray(samples, constant_data):
event_samples = xarray.DataArray(
samples["seir"],
coords=[
np.arange(samples["seir"].shape[0]),
constant_data.coords["location"],
constant_data.coords["time"],
np.arange(samples["seir"].shape[-1]),
],
dims=["iteration", "location", "time", "event"],
)
initial_state = xarray.DataArray(
samples["initial_state"],
coords=[
constant_data.coords["location"],
np.arange(samples["initial_state"].shape[1]),
],
dims=["location", "state"],
)
return xarray.Dataset({"seir": event_samples, "initial_state": initial_state})
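# The returned Dataset holds the posterior event samples with dimensions
# (iteration, location, time, event) and the initial state with dimensions
# (location, state).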
def _xarray2dstl(xarr, value_type, geography):
quantiles = xarr.quantile(q=QUANTILES, dim="iteration")
quantiles = {qi: v for qi, v in zip(QUANTILES, quantiles)}
mean = xarr.mean(dim="iteration")
return make_dstl_template(
group="Lancaster",
model="StochSpatMetaPopSEIR",
model_type="Pillar Two Testing",
scenario="Nowcast",
version=covid19uk.__version__,
creation_date=date.today(),
value_date=xarr.coords["time"].data,
age_band="All",
geography=geography,
value_type=value_type,
value=mean,
quantiles=quantiles,
)