Commit c9415c4d authored by Chris Jewell

First refactor such that Covid19UK contains only model-dependent code.

parent c5d7033e
"""Covid19UK model and associated inference/prediction algorithms"""
from covid19uk.data.assemble import assemble_data
from covid19uk.inference.inference import mcmc
from covid19uk.posterior.thin import thin_posterior
from covid19uk.posterior.reproduction_number import reproduction_number
from covid19uk.posterior.predict import predict
from covid19uk.posterior.within_between import within_between
__all__ = [
    "assemble_data",
    "mcmc",
    "thin_posterior",
    "reproduction_number",
    "predict",
    "within_between",
]
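For orientation, a minimal sketch of how these exports chain together. The file paths and the `Mcmc` config section are illustrative assumptions (only `ProcessData` appears in this commit's pipeline script), not package defaults:

import yaml
from covid19uk import assemble_data, mcmc

# Hypothetical driver: config keys and paths below are placeholders
with open("pipeline_config.yaml") as f:
    config = yaml.safe_load(f)
assemble_data("inferencedata.nc", config["ProcessData"])  # build model inputs
mcmc("inferencedata.nc", "posterior.hd5", config["Mcmc"])  # sample the posterior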
"""General argument parsing for all scripts"""
import argparse
def cli_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, help="configuration file")
    parser.add_argument(
        "-r",
        "--results",
        type=str,
        default=None,
        help="override config file results dir",
    )
    args = parser.parse_args(args)
    return args
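A quick usage sketch: passing an explicit argv list (paths illustrative) is convenient in tests, while `cli_args()` with no arguments falls back to `sys.argv`:

# Parse an explicit argument list instead of sys.argv
args = cli_args(["--config", "config.yaml", "--results", "/tmp/results"])
print(args.config, args.results)  # config.yaml /tmp/results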
"""Tensorflow configuration options"""
import tensorflow as tf
import numpy as np
floatX = np.float64
@@ -2,7 +2,7 @@
 to instantiate the COVID19 model"""
 import os
-from covid.model_spec import gather_data
+from covid19uk.model_spec import gather_data
 def assemble_data(filename, config):
...
"""Provides functions to format data"""
import pandas as pd
def _expand_quantiles(q_dict):
    """Expand a dictionary of quantiles"""
    q_str = [
        "0.05",
        "0.1",
        "0.15",
        "0.2",
        "0.25",
        "0.3",
        "0.35",
        "0.4",
        "0.45",
        "0.5",
        "0.55",
        "0.6",
        "0.65",
        "0.7",
        "0.75",
        "0.8",
        "0.85",
        "0.9",
        "0.95",
    ]
    quantiles = {f"Quantile {q}": None for q in q_str}
    if q_dict is None:
        return quantiles
    for k, v in q_dict.items():
        q_key = (
            f"Quantile {float(k)}"  # Coerce back to float to strip trailing 0s
        )
        if q_key not in quantiles.keys():
            raise KeyError(f"quantile '{k}' not compatible with template form")
        quantiles[q_key] = v
    return [
        pd.Series(v, name=k).reset_index(drop=True)
        for k, v in quantiles.items()
    ]
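# Example (illustrative values): unsupplied quantiles come back as empty
# Series, and keys are coerced through float(), so "0.50" and "0.5" both
# map onto the "Quantile 0.5" column.
#
#     series = _expand_quantiles({"0.05": [0.8], "0.5": [1.0], "0.95": [1.3]})
#     assert len(series) == 19  # one Series per template quantile
#     assert series[0].name == "Quantile 0.05"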
def _split_dates(dates):
    if dates is None:
        return {"day": None, "month": None, "year": None}
    if hasattr(dates, "__iter__"):
        dx = pd.DatetimeIndex(dates)
    else:
        dx = pd.DatetimeIndex([dates])
    return {"day": dx.day, "month": dx.month, "year": dx.year}
def make_dstl_template(
    group=None,
    model=None,
    scenario=None,
    model_type=None,
    version=None,
    creation_date=None,
    value_date=None,
    age_band=None,
    geography=None,
    value_type=None,
    value=None,
    quantiles=None,
):
    """Formats a DSTL-type Excel results template"""
    # Process date
    creation_date_parts = _split_dates(creation_date)
    value_date_parts = _split_dates(value_date)
    quantile_series = _expand_quantiles(quantiles)
    fields = {
        "Group": group,
        "Model": model,
        "Scenario": scenario,
        "ModelType": model_type,
        "Version": version,
        "Creation Day": creation_date_parts["day"],
        "Creation Month": creation_date_parts["month"],
        "Creation Year": creation_date_parts["year"],
        "Day of Value": value_date_parts["day"],
        "Month of Value": value_date_parts["month"],
        "Year of Value": value_date_parts["year"],
        "AgeBand": age_band,
        "Geography": geography,
        "ValueType": value_type,
        "Value": value,
    }
    fields = [
        pd.Series(v, name=k).reset_index(drop=True) for k, v in fields.items()
    ]
    return pd.concat(fields + quantile_series, axis="columns").ffill(
        axis="index"
    )
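For example, a two-quantile, three-day set of R estimates could be templated as below; all values and the model name are fabricated for illustration. Scalar fields are forward-filled down the frame by the final `ffill`:

template = make_dstl_template(
    group="Lancaster",
    model="SpatialSEIR",  # fabricated model name
    creation_date=pd.Timestamp("2020-12-01"),
    value_date=pd.date_range("2020-12-02", periods=3),
    geography="England",
    value_type="R",
    quantiles={"0.05": [0.9, 1.0, 1.0], "0.5": [1.1, 1.2, 1.2]},
)
print(template[["Day of Value", "Quantile 0.5"]])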
@@ -19,12 +19,12 @@ from gemlib.util import compute_state
 from gemlib.mcmc import Posterior
 from gemlib.mcmc import GibbsKernel
 from gemlib.distributions import BrownianMotion
-from covid.tasks.mcmc_kernel_factory import make_hmc_base_kernel
-from covid.tasks.mcmc_kernel_factory import make_hmc_fast_adapt_kernel
-from covid.tasks.mcmc_kernel_factory import make_hmc_slow_adapt_kernel
-from covid.tasks.mcmc_kernel_factory import make_event_multiscan_gibbs_step
-import covid.model_spec as model_spec
+from covid19uk.tasks.mcmc_kernel_factory import make_hmc_base_kernel
+from covid19uk.tasks.mcmc_kernel_factory import make_hmc_fast_adapt_kernel
+from covid19uk.tasks.mcmc_kernel_factory import make_hmc_slow_adapt_kernel
+from covid19uk.tasks.mcmc_kernel_factory import make_event_multiscan_gibbs_step
+import covid19uk.model_spec as model_spec
 tfd = tfp.distributions
 tfb = tfp.bijectors
@@ -354,8 +354,7 @@ def run_mcmc(
     current_state = [s[-1] for s in draws]
     draws[0] = param_bijector.inverse(draws[0])
     posterior.write_samples(
-        draws_to_dict(draws),
-        first_dim_offset=offset,
+        draws_to_dict(draws), first_dim_offset=offset,
     )
     posterior.write_results(trace, first_dim_offset=offset)
     offset += first_window_size
@@ -387,8 +386,7 @@ def run_mcmc(
     current_state = [s[-1] for s in draws]
     draws[0] = param_bijector.inverse(draws[0])
     posterior.write_samples(
-        draws_to_dict(draws),
-        first_dim_offset=offset,
+        draws_to_dict(draws), first_dim_offset=offset,
     )
     posterior.write_results(trace, first_dim_offset=offset)
     offset += window_num_draws
@@ -408,8 +406,7 @@ def run_mcmc(
     current_state = [s[-1] for s in draws]
     draws[0] = param_bijector.inverse(draws[0])
     posterior.write_samples(
-        draws_to_dict(draws),
-        first_dim_offset=offset,
+        draws_to_dict(draws), first_dim_offset=offset,
     )
     posterior.write_results(trace, first_dim_offset=offset)
     offset += last_window_size
@@ -435,12 +432,10 @@ def run_mcmc(
     current_state = [state_part[-1] for state_part in draws]
     draws[0] = param_bijector.inverse(draws[0])
     posterior.write_samples(
-        draws_to_dict(draws),
-        first_dim_offset=offset,
+        draws_to_dict(draws), first_dim_offset=offset,
     )
     posterior.write_results(
-        trace,
-        first_dim_offset=offset,
+        trace, first_dim_offset=offset,
     )
     offset += config["num_burst_samples"]
@@ -534,11 +529,7 @@ def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True):
     tf.concat(
         [
             np.array([0.1, 0.0, 0.0, 0.0], dtype=DTYPE),
-            np.full(
-                events.shape[1],
-                -1.75,
-                dtype=DTYPE,
-            ),
+            np.full(events.shape[1], -1.75, dtype=DTYPE,),
         ],
         axis=0,
     ),
@@ -560,8 +551,7 @@ def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True):
     )
     posterior._file.create_dataset("initial_state", data=initial_state)
     posterior._file.create_dataset(
-        "time",
-        data=np.array(dates).astype(str).astype(h5py.string_dtype()),
+        "time", data=np.array(dates).astype(str).astype(h5py.string_dtype()),
     )
     print(f"Acceptance theta: {posterior['results/hmc/is_accepted'][:].mean()}")
...
@@ -10,12 +10,11 @@ import tensorflow_probability as tfp
 from gemlib.distributions import DiscreteTimeStateTransitionModel
 from gemlib.distributions import BrownianMotion
-from covid.util import impute_previous_cases
-import covid.data as data
+from covid19uk.util import impute_previous_cases
+import covid19uk.data.loaders as data
 tfd = tfp.distributions
-VERSION = "0.7.1"
 DTYPE = np.float64
 STOICHIOMETRY = np.array([[-1, 1, 0, 0], [0, -1, 1, 0], [0, 0, -1, 1]])
@@ -46,12 +45,9 @@ def gather_data(config):
     )
     geo = gp.read_file(config["geopackage"])
     geo = geo.sort_values("lad19cd")
-    geo = geo[geo['lad19cd'].isin(locations['lad19cd'])]
+    geo = geo[geo["lad19cd"].isin(locations["lad19cd"])]
     area = xarray.DataArray(
-        geo.area,
-        name="area",
-        dims=["location"],
-        coords=[geo["lad19cd"]],
+        geo.area, name="area", dims=["location"], coords=[geo["lad19cd"]],
     )
     # tier_restriction = data.TierData.process(config)[:, :, [0, 2, 3, 4]]
...
"""A Ruffus-ised pipeline for COVID-19 analysis"""
import os
from os.path import expandvars
import warnings
import yaml
import datetime
import s3fs
import ruffus as rf
from covid.ruffus_pipeline import run_pipeline
def _import_global_config(config_file):
    with open(config_file, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    return config
if __name__ == "__main__":

    # Ruffus wrapper around argparse used to give us ruffus
    # cmd line switches as well as our own config
    argparser = rf.cmdline.get_argparse(description="COVID-19 pipeline")
    data_args = argparser.add_argument_group(
        "Data options", "Options controlling input data"
    )
    data_args.add_argument(
        "-c",
        "--config",
        type=str,
        help="global configuration file",
        required=True,
    )
    data_args.add_argument(
        "-r",
        "--results-directory",
        type=str,
        help="pipeline results directory",
        required=True,
    )
    data_args.add_argument(
        "--date-range",
        type=lambda s: datetime.datetime.strptime(s, "%Y-%m-%d"),
        nargs=2,
        help="Date range [low high)",
        metavar="ISO8601",
    )
    data_args.add_argument(
        "--reported-cases", type=str, help="Path to case file"
    )
    data_args.add_argument(
        "--commute-volume", type=str, help="Path to commute volume file"
    )
    data_args.add_argument(
        "--case-date-type",
        type=str,
        help="Case date type (specimen | report)",
        choices=["specimen", "report"],
    )
    data_args.add_argument(
        "--pillar", type=str, help="Pillar", choices=["both", "1", "2"]
    )
    data_args.add_argument("--aws", action="store_true", help="Push to AWS")
    cli_options = argparser.parse_args()
    global_config = _import_global_config(cli_options.config)

    if cli_options.date_range is not None:
        global_config["ProcessData"]["date_range"][0] = cli_options.date_range[0]
        global_config["ProcessData"]["date_range"][1] = cli_options.date_range[1]

    if cli_options.reported_cases is not None:
        global_config["ProcessData"]["CasesData"]["address"] = expandvars(
            cli_options.reported_cases
        )
    if cli_options.commute_volume is not None:
        global_config["ProcessData"]["commute_volume"] = expandvars(
            cli_options.commute_volume
        )
    if cli_options.case_date_type is not None:
        global_config["ProcessData"]["case_date_type"] = cli_options.case_date_type

    if cli_options.pillar is not None:
        opts = {
            "both": ["Pillar 1", "Pillar 2"],
            "1": ["Pillar 1"],
            "2": ["Pillar 2"],
        }
        global_config["ProcessData"]["CasesData"]["pillars"] = opts[
            cli_options.pillar
        ]

    run_pipeline(global_config, cli_options.results_directory, cli_options)
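An illustrative invocation (script and file names hypothetical): `python pipeline.py --config global_config.yaml --results-directory results/ --date-range 2020-10-01 2020-12-01 --pillar both`. The CLI overrides above are plain dict surgery on the parsed YAML; for instance, `--pillar 2` reduces to:

# Minimal sketch of the --pillar override, with a hypothetical config skeleton
global_config = {"ProcessData": {"CasesData": {"pillars": None}}}
opts = {"both": ["Pillar 1", "Pillar 2"], "1": ["Pillar 1"], "2": ["Pillar 2"]}
global_config["ProcessData"]["CasesData"]["pillars"] = opts["2"]
assert global_config["ProcessData"]["CasesData"]["pillars"] == ["Pillar 2"]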
"Plot functions for Covid-19 data brick"
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfs = tfp.stats
def plot_prediction(prediction_period, sims, case_reports):
    sims = tf.reduce_sum(sims, axis=-2)  # Sum over all meta-populations
    quantiles = [2.5, 50, 97.5]
    dates = np.arange(
        prediction_period[0], prediction_period[1], np.timedelta64(1, "D")
    )
    total_infected = tfs.percentile(
        tf.reduce_sum(sims[:, :, 1:3], axis=2), q=quantiles, axis=0
    )
    removed = tfs.percentile(sims[:, :, 3], q=quantiles, axis=0)
    removed_observed = tfs.percentile(removed * 0.1, q=quantiles, axis=0)

    fig = plt.figure()
    filler = plt.fill_between(
        dates, total_infected[0, :], total_infected[2, :],
        color="lightgray", alpha=0.8, label="95% credible interval",
    )
    plt.fill_between(
        dates, removed[0, :], removed[2, :], color="lightgray", alpha=0.8
    )
    plt.fill_between(
        dates, removed_observed[0, :], removed_observed[2, :],
        color="lightgray", alpha=0.8,
    )
    ti_line = plt.plot(
        dates, total_infected[1, :], "-", color="red", alpha=0.4,
        label="Infected",
    )
    rem_line = plt.plot(dates, removed[1, :], "-", color="blue", label="Removed")
    ro_line = plt.plot(
        dates, removed_observed[1, :], "-", color="orange",
        label="Predicted detections",
    )
    data_range = [
        case_reports["DateVal"].to_numpy().min(),
        case_reports["DateVal"].to_numpy().max(),
    ]
    one_day = np.timedelta64(1, "D")
    data_dates = np.arange(data_range[0], data_range[1] + one_day, one_day)
    marks = plt.plot(
        data_dates, case_reports["CumCases"].to_numpy(), "+",
        label="Observed cases",
    )
    plt.legend(
        [ti_line[0], rem_line[0], ro_line[0], filler, marks[0]],
        ["Infected", "Removed", "Predicted detections",
         "95% credible interval", "Observed counts"],
    )
    plt.grid(color="lightgray", linestyle="dotted")
    plt.xlabel("Date")
    plt.ylabel("Individuals")
    fig.autofmt_xdate()
    plt.show()
def plot_case_incidence(date_range, sims):
    # Number of new cases per day
    dates = np.arange(date_range[0], date_range[1])
    new_cases = sims[:, :, :, 3].sum(axis=2)
    new_cases = new_cases[:, 1:] - new_cases[:, :-1]
    new_cases = tfs.percentile(new_cases, q=[2.5, 50, 97.5], axis=0) / 10000.0
    fig = plt.figure()
    plt.fill_between(
        dates[:-1], new_cases[0, :], new_cases[2, :],
        color="lightgray", label="95% credible interval",
    )
    plt.plot(dates[:-1], new_cases[1, :], "-", alpha=0.2, label="New cases")
    plt.grid(color="lightgray", linestyle="dotted")
    plt.xlabel("Date")
    plt.ylabel("Incidence per 10,000")
    fig.autofmt_xdate()
    plt.show()
\ No newline at end of file
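Both plotters expect simulation output shaped `(draw, time, meta-population, state)`; a smoke test with fabricated draws (state 3 treated as cumulative removals, an assumption consistent with the differencing above):

import numpy as np
from numpy.random import default_rng

# 200 fabricated draws, 61 days, 10 meta-populations, 4 states (S, E, I, R)
rng = default_rng(42)
sims = rng.gamma(2.0, 50.0, size=(200, 61, 10, 4)).cumsum(axis=1)
plot_case_incidence(
    [np.datetime64("2020-03-01"), np.datetime64("2020-05-01")], sims
)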
@@ -6,8 +6,8 @@ import pickle as pkl
 import pandas as pd
 import tensorflow as tf
-from covid import model_spec
-from covid.util import copy_nc_attrs
+from covid19uk import model_spec
+from covid19uk.util import copy_nc_attrs
 from gemlib.util import compute_state
@@ -30,9 +30,7 @@ def predicted_incidence(
     """
     posterior_state = compute_state(
-        init_state,
-        posterior_samples["seir"],
-        model_spec.STOICHIOMETRY,
+        init_state, posterior_samples["seir"], model_spec.STOICHIOMETRY,
     )
     posterior_samples["new_init_state"] = posterior_state[..., init_step, :]
     del posterior_samples["seir"]
@@ -166,14 +164,10 @@ if __name__ == "__main__":
     )
     parser.add_argument("data_pkl", type=str, help="Covariate data pickle")
     parser.add_argument(
-        "posterior_samples_pkl",
-        type=str,
-        help="Posterior samples pickle",
+        "posterior_samples_pkl", type=str, help="Posterior samples pickle",
     )
     parser.add_argument(
-        "output_file",
-        type=str,
-        help="Output pkl file",
+        "output_file", type=str, help="Output pkl file",
     )
     args = parser.parse_args()
...
@@ -5,8 +5,8 @@ import numpy as np
 import xarray
 import tensorflow as tf
-from covid import model_spec
-from covid.util import copy_nc_attrs
+from covid19uk import model_spec
+from covid19uk.util import copy_nc_attrs
 from gemlib.util import compute_state
@@ -41,10 +41,7 @@ def calc_posterior_rit(samples, initial_state, times, covar_data):
         ngm = tf.vectorized_map(fn, elems=times)
         return tf.reduce_sum(ngm, axis=-2)  # sum over destinations
-    return tf.vectorized_map(
-        r_fn,
-        elems=tf.nest.flatten(samples),
-    )
+    return tf.vectorized_map(r_fn, elems=tf.nest.flatten(samples),)
 CHUNKSIZE = 50
@@ -97,16 +94,10 @@ if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument(
-        "samples",
-        type=str,
-        help="A pickle file with MCMC samples",
+        "samples", type=str, help="A pickle file with MCMC samples",
     )
     parser.add_argument(
-        "-d",
-        "--data",
-        type=str,
-        help="A data glob pickle file",
-        required=True,
+        "-d", "--data", type=str, help="A data glob pickle file", required=True,
     )
     parser.add_argument(
         "-o", "--output", type=str, help="The output file", required=True
...
@@ -7,7 +7,7 @@ import xarray
 import tensorflow as tf
 from gemlib.util import compute_state