Commit 5c2058da authored by Chris Jewell

Ruffus-style pipeline implemented

parent 74aaa3dd
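The commit reorganises the code into file-in/file-out task functions (the new covid/tasks package below) that can be chained into a Ruffus-style pipeline. As a rough sketch of the intent, not part of the commit itself, the first two stages might be driven like this; the file paths and the "Mcmc" config key are illustrative assumptions:

```python
# Hypothetical driver for the first two pipeline stages; the pickle/HDF5 paths
# and the "Mcmc" config section name are placeholders, not taken from this commit.
import yaml

from covid.tasks import assemble_data, mcmc

with open("config.yaml", "r") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# Stage 1: gather covariates and case data, bundle them into a single pickle.
assemble_data("data_bundle.pkl", config["ProcessData"])

# Stage 2: run MCMC inference, reading the bundle and writing the posterior.
mcmc("data_bundle.pkl", "posterior.hd5", config["Mcmc"])
```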
@@ -66,16 +66,6 @@ class AreaCodeData:
        if settings["format"] == "ons":
            print("Retrieving Area Code data from the ONS")
            data = response.json()
-            if config["GenerateOutput"]["storeInputs"]:
-                fn = format_output_filename(
-                    config["GenerateOutput"]["scrapedDataDir"]
-                    + "AreaCodeData_ONS.json",
-                    config,
-                )
-                with open(fn, "w") as f:
-                    json.dump(data, f)
            df = AreaCodeData.getJSON(json.dumps(data))
            return df
@@ -162,28 +152,22 @@ class AreaCodeData:
        """
        Adapt the area codes to the desired dataframe format
        """
-        output_settings = config["GenerateOutput"]
        settings = config["AreaCodeData"]
-        output = settings["output"]
        regions = settings["regions"]
        if settings["input"] == "processed":
            return df
        if settings["format"].lower() == "ons":
-            df = AreaCodeData.adapt_ons(df, regions, output, config)
+            df = AreaCodeData.adapt_ons(df, regions)
        # if we have a predefined list of LADs, filter them down
        if "lad19cds" in config:
            df = df[[x in config["lad19cds"] for x in df.lad19cd.values]]
-        if output_settings["storeProcessedInputs"] and output != "None":
-            output = format_output_filename(output, config)
-            df.to_csv(output, index=False)
        return df

-    def adapt_ons(df, regions, output, config):
+    def adapt_ons(df, regions):
        colnames = ["lad19cd", "name"]
        df.columns = colnames
        filters = df["lad19cd"].str.contains(str.join("|", regions))
...
@@ -14,12 +14,6 @@ def test_url():
            "output": "processed_data/processed_lad19cd.csv",
            "regions": ["E"],
        },
-        "GenerateOutput": {
-            "storeInputs": True,
-            "scrapedDataDir": "scraped_data",
-            "storeProcessedInputs": True,
-        },
-        "Global": {"prependID": False, "prependDate": False},
    }

    df = AreaCodeData.process(config)
...
@@ -55,16 +55,8 @@ def merge_lad_values(df):
def get_date_low_high(config):
-    if "dates" in config:
-        low = config["dates"]["low"]
-        high = config["dates"]["high"]
-    else:
-        inference_period = [
-            np.datetime64(x) for x in config["Global"]["inference_period"]
-        ]
-        low = inference_period[0]
-        high = inference_period[1]
-    return (low, high)
+    date_range = [np.datetime64(x) for x in config["date_range"]]
+    return tuple(date_range)


def check_date_format(df):
...
"""Implements the COVID SEIR model as a TFP Joint Distribution""" """Implements the COVID SEIR model as a TFP Joint Distribution"""
import pandas as pd import pandas as pd
import geopandas as gp
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import tensorflow_probability as tfp import tensorflow_probability as tfp
@@ -19,7 +18,7 @@ XI_FREQ = 14  # baseline transmission changes every 14 days
NU = tf.constant(0.28, dtype=DTYPE)  # E->I rate assumed known.


-def read_covariates(config):
+def gather_data(config):
    """Loads covariate data

    :param paths: a dictionary of paths to data with keys {'mobility_matrix',
@@ -27,31 +26,36 @@ def read_covariates(config):
    :returns: a dictionary of covariate information to be consumed by the model
      {'C': commute_matrix, 'W': traffic_flow, 'N': population_size}
    """
-    paths = config["data"]
-    date_low = np.datetime64(config["Global"]["inference_period"][0])
-    date_high = np.datetime64(config["Global"]["inference_period"][1])
-    mobility = data.read_mobility(paths["mobility_matrix"])
-    popsize = data.read_population(paths["population_size"])
+    date_low = np.datetime64(config["date_range"][0])
+    date_high = np.datetime64(config["date_range"][1])
+    mobility = data.read_mobility(config["mobility_matrix"])
+    popsize = data.read_population(config["population_size"])
    commute_volume = data.read_traffic_flow(
-        paths["commute_volume"], date_low=date_low, date_high=date_high
+        config["commute_volume"], date_low=date_low, date_high=date_high
    )
-    geo = gp.read_file(paths["geopackage"])
-    geo = geo.loc[geo["lad19cd"].str.startswith("E")]
-    # tier_restriction = data.read_challen_tier_restriction(
-    #     paths["tier_restriction_csv"],
-    #     date_low,
-    #     date_high,
-    # )
+    locations = data.AreaCodeData.process(config)
    tier_restriction = data.TierData.process(config)[:, :, 2:]
+    date_range = [date_low, date_high]
    weekday = pd.date_range(date_low, date_high).weekday < 5
+    cases = data.read_phe_cases(
+        config["reported_cases"],
+        date_low,
+        date_high,
+        pillar=config["pillar"],
+        date_type=config["case_date_type"],
+    )

    return dict(
        C=mobility.to_numpy().astype(DTYPE),
        W=commute_volume.to_numpy().astype(DTYPE),
        N=popsize.to_numpy().astype(DTYPE),
        L=tier_restriction.astype(DTYPE),
        weekday=weekday.astype(DTYPE),
+        date_range=date_range,
+        locations=locations,
+        cases=cases,
    )
@@ -143,6 +147,15 @@ def CovidUK(covariates, initial_state, initial_step, num_steps):
    gamma0 = tf.convert_to_tensor(gamma0, DTYPE)
    gamma1 = tf.convert_to_tensor(gamma1, DTYPE)

+    C = tf.convert_to_tensor(covariates["C"], dtype=DTYPE)
+    C = tf.linalg.set_diag(C, tf.zeros(C.shape[0], dtype=DTYPE))
+    Cstar = C + tf.transpose(C)
+    Cstar = tf.linalg.set_diag(Cstar, -tf.reduce_sum(C, axis=-2))
+
+    W = tf.convert_to_tensor(tf.squeeze(covariates["W"]), dtype=DTYPE)
+    N = tf.convert_to_tensor(tf.squeeze(covariates["N"]), dtype=DTYPE)
+
    L = tf.convert_to_tensor(covariates["L"], DTYPE)
    L = L - tf.reduce_mean(L, axis=(0, 1))
@@ -150,14 +163,6 @@ def CovidUK(covariates, initial_state, initial_step, num_steps):
    weekday = weekday - tf.reduce_mean(weekday, axis=-1)

    def transition_rate_fn(t, state):
-        C = tf.convert_to_tensor(covariates["C"], dtype=DTYPE)
-        C = tf.linalg.set_diag(C, tf.zeros(C.shape[0], dtype=DTYPE))
-        Cstar = C + tf.transpose(C)
-        Cstar = tf.linalg.set_diag(Cstar, -tf.reduce_sum(C, axis=-2))
-        W = tf.constant(np.squeeze(covariates["W"]), dtype=DTYPE)
-        N = tf.constant(np.squeeze(covariates["N"]), dtype=DTYPE)

        w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, W.shape[0] - 1)
        commute_volume = tf.gather(W, w_idx)
@@ -166,7 +171,6 @@ def CovidUK(covariates, initial_state, initial_step, num_steps):
            dtype=tf.int64,
        )
        xi_ = tf.gather(xi, xi_idx)
        L_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, L.shape[0] - 1)
        Lt = tf.gather(L, L_idx)
        xB = tf.linalg.matvec(Lt, beta3)
...
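For orientation, the dictionary returned by the new gather_data bundles covariates and observations together. A minimal sketch of its contents follows; the shape/meaning comments are assumptions inferred from how the entries are consumed elsewhere in this commit, and config.yaml is a placeholder path:

```python
# Sketch of the data bundle produced by gather_data; comments are inferred, not from the diff.
import yaml

from covid.model_spec import gather_data

with open("config.yaml", "r") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

bundle = gather_data(config["ProcessData"])

bundle["C"]           # commute/mobility matrix between areas
bundle["W"]           # commute volume time series over the date range
bundle["N"]           # population size per area
bundle["L"]           # tier restriction covariate (time x area x level)
bundle["weekday"]     # weekday indicator per day
bundle["date_range"]  # [date_low, date_high]
bundle["locations"]   # area codes from data.AreaCodeData.process(config)
bundle["cases"]       # reported cases from data.read_phe_cases(...)
```

The new covid/tasks package (next file) collects the pipeline stages as importable task functions.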
"""Import tasks"""
from covid.tasks.assemble_data import assemble_data
from covid.tasks.inference import mcmc
from covid.tasks.thin_posterior import thin_posterior
from covid.tasks.next_generation_matrix import next_generation_matrix
from covid.tasks.overall_rt import overall_rt
from covid.tasks.predict import predict
import covid.tasks.summarize as summarize
from covid.tasks.within_between import within_between
from covid.tasks.case_exceedance import case_exceedance
from covid.tasks.summary_geopackage import summary_geopackage
__all__ = [
"assemble_data",
"mcmc",
"thin_posterior",
"next_generation_matrix",
"overall_rt",
"predict",
"summarize",
"within_between",
"case_exceedance",
"summary_geopackage",
]
@@ -2,6 +2,29 @@
to instantiate the COVID19 model"""

+import pickle as pkl
+
+from covid.model_spec import gather_data
+

def assemble_data(output_file, config):
-    covar_data = {}
+    all_data = gather_data(config)
+    with open(output_file, "wb") as f:
+        pkl.dump(all_data, f)
+
+
+if __name__ == "__main__":
+
+    from argparse import ArgumentParser
+    import yaml
+
+    parser = ArgumentParser(description="Bundle data into a pickled dictionary")
+    parser.add_argument("config_file", help="Global config file")
+    parser.add_argument("output_file", help="Data bundle pkl file")
+    args = parser.parse_args()
+
+    with open(args.config_file, "r") as f:
+        global_config = yaml.load(f, Loader=yaml.FullLoader)
+
+    assemble_data(args.output_file, global_config["ProcessData"])
"""Calculates case exceedance probabilities"""
import numpy as np
import pickle as pkl
import pandas as pd
def case_exceedance(input_files, lag):
"""Calculates case exceedance probabilities,
i.e. Pr(pred[lag:] < observed[lag:])
:param input_files: [data pickle, prediction pickle]
:param lag: the lag for which to calculate the exceedance
"""
with open(input_files[0], "rb") as f:
data = pkl.load(f)
with open(input_files[1], "rb") as f:
prediction = pkl.load(f)
modelled_cases = np.sum(prediction[..., :lag, -1], axis=-1)
observed_cases = np.sum(data["cases"].to_numpy()[:, -lag:], axis=-1)
exceedance = np.mean(modelled_cases < observed_cases, axis=0)
df = pd.Series(
exceedance,
index=pd.Index(data["locations"]["lad19cd"], name="location"),
)
return df
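A short usage sketch for the new task (the pickle paths and the lag value are placeholders): it reads the data bundle and a prediction pickle from earlier pipeline stages and returns a per-lad19cd series of exceedance probabilities.

```python
# Illustrative call; the input pickles would be produced by earlier pipeline stages.
from covid.tasks import case_exceedance

exceedance = case_exceedance(["data_bundle.pkl", "insample_prediction.pkl"], lag=7)
# Highest values mark areas where observed cases most exceed the model's predictions.
print(exceedance.sort_values().tail())
```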
@@ -92,7 +92,7 @@ if __name__ == "__main__":
    ]

    # Load covariate data
-    covar_data = model_spec.read_covariates(config)
+    covar_data = model_spec.gather_data(config)

    output_folder_path = config["output"]["results_dir"]
    geopackage_path = os.path.expandvars(
@@ -111,9 +111,7 @@ if __name__ == "__main__":
    )
    print("Using posterior:", posterior_path)
    posterior = h5py.File(
-        os.path.expandvars(
-            posterior_path,
-        ),
+        os.path.expandvars(posterior_path,),
        "r",
        rdcc_nbytes=1024 ** 3,
        rdcc_nslots=1e6,
@@ -125,9 +123,7 @@ if __name__ == "__main__":
        beta1=posterior["samples/beta1"][idx],
        beta2=posterior["samples/beta2"][idx],
        beta3=posterior["samples/beta3"][idx],
-        sigma=posterior["samples/sigma"][
-            idx,
-        ],
+        sigma=posterior["samples/sigma"][idx,],
        xi=posterior["samples/xi"][idx],
        gamma0=posterior["samples/gamma0"][idx],
        gamma1=posterior["samples/gamma1"][idx],
...
@@ -2,6 +2,8 @@
# pylint: disable=E402

import os
+import h5py
+import pickle as pkl
from time import perf_counter

import tqdm
import yaml
@@ -9,7 +11,6 @@ import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

-from covid.data import AreaCodeData
from gemlib.util import compute_state
from gemlib.mcmc import UncalibratedEventTimesUpdate
from gemlib.mcmc import UncalibratedOccultUpdate, TransitionTopology
@@ -18,9 +19,6 @@ from gemlib.mcmc import MultiScanKernel
from gemlib.mcmc import AdaptiveRandomWalkMetropolis
from gemlib.mcmc import Posterior

-from covid.data import read_phe_cases
-from covid.cli_arg_parse import cli_args
import covid.model_spec as model_spec

tfd = tfp.distributions
@@ -28,7 +26,7 @@ tfb = tfp.bijectors
DTYPE = model_spec.DTYPE


-def run_mcmc(config):
+def mcmc(data_file, output_file, config):
    """Constructs and runs the MCMC"""

    if tf.test.gpu_device_name():
@@ -36,24 +34,13 @@ def run_mcmc(config):
    else:
        print("Using CPU")

-    inference_period = [
-        np.datetime64(x) for x in config["Global"]["inference_period"]
-    ]
-    covar_data = model_spec.read_covariates(config)
+    with open(data_file, "rb") as f:
+        data = pkl.load(f)

    # We load in cases and impute missing infections first, since this sets the
    # time epoch which we are analysing.
-    cases = read_phe_cases(
-        config["data"]["reported_cases"],
-        date_low=inference_period[0],
-        date_high=inference_period[1],
-        date_type=config["data"]["case_date_type"],
-        pillar=config["data"]["pillar"],
-    ).astype(DTYPE)

    # Impute censored events, return cases
-    events = model_spec.impute_censored_events(cases)
+    events = model_spec.impute_censored_events(data["cases"].astype(DTYPE))

    # Initial conditions are calculated by calculating the state
    # at the beginning of the inference period
@@ -63,13 +50,13 @@ def run_mcmc(config):
    # to set up a sensible initial state.
    state = compute_state(
        initial_state=tf.concat(
-            [covar_data["N"][:, tf.newaxis], tf.zeros_like(events[:, 0, :])],
+            [data["N"][:, tf.newaxis], tf.zeros_like(events[:, 0, :])],
            axis=-1,
        ),
        events=events,
        stoichiometry=model_spec.STOICHIOMETRY,
    )
-    start_time = state.shape[1] - cases.shape[1]
+    start_time = state.shape[1] - data["cases"].shape[1]
    initial_state = state[:, start_time, :]
    events = events[:, start_time:, :]
@@ -77,7 +64,7 @@ def run_mcmc(config):
    # Construct the MCMC kernels #
    ########################################################
    model = model_spec.CovidUK(
-        covariates=covar_data,
+        covariates=data,
        initial_state=initial_state,
        initial_step=0,
        num_steps=events.shape[1],
@@ -152,9 +139,9 @@ def run_mcmc(config):
                prev_event_id=prev_event_id,
                next_event_id=next_event_id,
                initial_state=initial_state,
-                dmax=config["mcmc"]["dmax"],
-                mmax=config["mcmc"]["m"],
-                nmax=config["mcmc"]["nmax"],
+                dmax=config["dmax"],
+                mmax=config["m"],
+                nmax=config["nmax"],
            ),
            name=name,
        )
@@ -170,7 +157,7 @@ def run_mcmc(config):
                    prev_event_id, target_event_id, next_event_id
                ),
                cumulative_event_offset=initial_state,
-                nmax=config["mcmc"]["occult_nmax"],
+                nmax=config["occult_nmax"],
                t_range=(events.shape[1] - 21, events.shape[1]),
                name=name,
            ),
@@ -181,7 +168,7 @@ def run_mcmc(config):
    def make_event_multiscan_kernel(target_log_prob_fn, _):
        return MultiScanKernel(
-            config["mcmc"]["num_event_time_updates"],
+            config["num_event_time_updates"],
            GibbsKernel(
                target_log_prob_fn=target_log_prob_fn,
                kernel_list=[
@@ -234,7 +221,7 @@ def run_mcmc(config):
        return results_dict

    # Build MCMC algorithm here. This will be run in bursts for memory economy
-    @tf.function(autograph=False, experimental_compile=True)
+    @tf.function  # (autograph=False, experimental_compile=True)
    def sample(n_samples, init_state, thin=0, previous_results=None):
        with tf.name_scope("main_mcmc_sample_loop"):
@@ -265,9 +252,8 @@ def run_mcmc(config):
    ###############################
    # Construct bursted MCMC loop #
    ###############################
-    NUM_BURSTS = int(config["mcmc"]["num_bursts"])
-    NUM_BURST_SAMPLES = int(config["mcmc"]["num_burst_samples"])
-    NUM_EVENT_TIME_UPDATES = int(config["mcmc"]["num_event_time_updates"])
+    NUM_BURSTS = int(config["num_bursts"])
+    NUM_BURST_SAMPLES = int(config["num_burst_samples"])
    NUM_SAVED_SAMPLES = NUM_BURST_SAMPLES * NUM_BURSTS

    # RNG stuff
@@ -286,10 +272,7 @@ def run_mcmc(config):
    # Output file
    samples, results, _ = sample(1, current_state)
    posterior = Posterior(
-        os.path.join(
-            os.path.expandvars(config["output"]["results_dir"]),
-            config["output"]["posterior"],
-        ),
+        output_file,
        sample_dict={
            "beta2": (samples[0][:, 0], (NUM_BURST_SAMPLES,)),
            "gamma0": (samples[0][:, 1], (NUM_BURST_SAMPLES,)),
@@ -307,19 +290,21 @@ def run_mcmc(config):
        num_samples=NUM_SAVED_SAMPLES,
    )
    posterior._file.create_dataset("initial_state"