Commit 9fe819d7 authored by Chris Jewell

Pipeline bugfixes

CHANGES:

1. Replaced `read_phe_cases` with `CasesData`
2. Column-name fix in `case_exceedance`
3. Fix column heading in `summarize.infec_incidence`
4. Add Rt_exceed to `summarize.rt`
5. Add layer option to `summary_geopackage`
6. Update `ruffus_pipeline.py` to reflect above changes
7. Update `template_config.yaml` to reflect above changes
parent 676b1d7c
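As a quick orientation for change (1), here is a minimal sketch of the new call pattern, assuming `config` holds the relevant `ProcessData`/`CasesData` settings from `template_config.yaml` (the exact section passed is not shown in this commit):

from covid import data

# Old API, removed by this commit:
# cases = data.read_phe_cases(
#     config["reported_cases"], date_low, date_high,
#     pillar=config["pillar"], date_type=config["case_date_type"],
# )

# New API: CasesData reads everything it needs from the config dict and,
# as the name suggests, converts to an xarray object of case counts.
cases = data.CasesData.process(config).to_xarray()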
"""Covid data adaptors and support code""" """Covid data adaptors and support code"""
from covid.data.data import ( from covid.data.data import (
read_phe_cases,
read_mobility, read_mobility,
read_population, read_population,
read_traffic_flow, read_traffic_flow,
) )
from covid.data.tiers import TierData from covid.data.tiers import TierData
from covid.data.area_code import AreaCodeData from covid.data.area_code import AreaCodeData
from covid.data.case_data import CasesData
__all__ = [ __all__ = [
"TierData", "TierData",
"AreaCodeData", "AreaCodeData",
"read_phe_cases", "CasesData",
"read_mobility", "read_mobility",
"read_population", "read_population",
"read_traffic_flow", "read_traffic_flow",
......
@@ -40,13 +40,14 @@ def gather_data(config):
     date_range = [date_low, date_high]
     weekday = pd.date_range(date_low, date_high).weekday < 5
 
-    cases = data.read_phe_cases(
-        config["reported_cases"],
-        date_low,
-        date_high,
-        pillar=config["pillar"],
-        date_type=config["case_date_type"],
-    )
+    cases = data.CasesData.process(config).to_xarray()
+    # cases = data.read_phe_cases(
+    #     config['reported_cases'],
+    #     date_low,
+    #     date_high,
+    #     pillar=config['pillar'],
+    #     date_type=config['case_date_type'],
+    # )
     return dict(
         C=mobility.to_numpy().astype(DTYPE),
         W=commute_volume.to_numpy().astype(DTYPE),
...
@@ -12,20 +12,47 @@ def case_exceedance(input_files, lag):
     :param input_files: [data pickle, prediction pickle]
     :param lag: the lag for which to calculate the exceedance
     """
+    data_file, prediction_file = input_files
 
-    with open(input_files[0], "rb") as f:
+    with open(data_file, "rb") as f:
         data = pkl.load(f)
-    with open(input_files[1], "rb") as f:
+    with open(prediction_file, "rb") as f:
         prediction = pkl.load(f)
 
     modelled_cases = np.sum(prediction[..., :lag, -1], axis=-1)
-    observed_cases = np.sum(data["cases"].to_numpy()[:, -lag:], axis=-1)
+    observed_cases = np.sum(data["cases"][:, -lag:], axis=-1)
+    if observed_cases.dims[0] == "lad19cd":
+        observed_cases = observed_cases.rename({"lad19cd": "location"})
     exceedance = np.mean(modelled_cases < observed_cases, axis=0)
-    df = pd.Series(
-        exceedance,
-        index=pd.Index(data["locations"]["lad19cd"], name="location"),
-    )
-    return df
+    return exceedance
+
+
+if __name__ == "__main__":
+    from argparse import ArgumentParser
+
+    parser = ArgumentParser(
+        description="Calculates case exceedance probabilities"
+    )
+    parser.add_argument("data_file", type=str)
+    parser.add_argument("prediction_file", type=str)
+    parser.add_argument(
+        "-l",
+        "--lag",
+        type=int,
+        help="The lag for which to calculate exceedance",
+        default=7,
+    )
+    parser.add_argument(
+        "-o",
+        "--output",
+        type=str,
+        help="The output csv",
+        default="exceedance.csv",
+    )
+    args = parser.parse_args()
+
+    df = case_exceedance([args.data_file, args.prediction_file], args.lag)
+    df.to_csv(args.output)
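For context, the exceedance probability computed above is simply the fraction of posterior-predictive draws that fall below the observed case count over the last `lag` days. A self-contained illustration with made-up numbers (not data or code from this pipeline):

import numpy as np

# 200 hypothetical posterior-predictive draws of 7-day case totals for 3 locations
rng = np.random.default_rng(0)
modelled_cases = rng.poisson(lam=[50, 100, 200], size=(200, 3))
observed_cases = np.array([60, 95, 300])

# Pr(predicted < observed) per location; values near 1 mean the observations sit
# in the upper tail of the predictions, i.e. the model under-predicts there.
exceedance = np.mean(modelled_cases < observed_cases, axis=0)
print(exceedance)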
"""MCMC Test Rig for COVID-19 UK model""" """MCMC Test Rig for COVID-19 UK model"""
# pylint: disable=E402 # pylint: disable=E402
import os
import h5py import h5py
import pickle as pkl import pickle as pkl
from time import perf_counter from time import perf_counter
...@@ -40,6 +39,7 @@ def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True): ...@@ -40,6 +39,7 @@ def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True):
# We load in cases and impute missing infections first, since this sets the # We load in cases and impute missing infections first, since this sets the
# time epoch which we are analysing. # time epoch which we are analysing.
# Impute censored events, return cases # Impute censored events, return cases
print("Data shape:", data['cases'].shape)
events = model_spec.impute_censored_events(data["cases"].astype(DTYPE)) events = model_spec.impute_censored_events(data["cases"].astype(DTYPE))
# Initial conditions are calculated by calculating the state # Initial conditions are calculated by calculating the state
...@@ -85,14 +85,6 @@ def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True): ...@@ -85,14 +85,6 @@ def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True):
) )
# Build Metropolis within Gibbs sampler # Build Metropolis within Gibbs sampler
#
# Kernels are:
# Q(\theta, \theta^\prime)
# Q(\xi, \xi^\prime)
# Q(Z^{se}, Z^{se\prime}) (partially-censored)
# Q(Z^{ei}, Z^{ei\prime}) (partially-censored)
# Q(Z^{se}, Z^{se\prime}) (occult)
# Q(Z^{ei}, Z^{ei\prime}) (occult)
def make_blk0_kernel(shape, name): def make_blk0_kernel(shape, name):
def fn(target_log_prob_fn, _): def fn(target_log_prob_fn, _):
return tfp.mcmc.TransformedTransitionKernel( return tfp.mcmc.TransformedTransitionKernel(
......
@@ -9,7 +9,6 @@ from covid import model_spec
 from gemlib.util import compute_state
 
 
-@tf.function
 def predicted_incidence(posterior_samples, covar_data, init_step, num_steps):
     """Runs the simulation forward in time from `init_state` at time `init_time`
     for `num_steps`.
...
@@ -21,6 +20,7 @@ def predicted_incidence(posterior_samples, covar_data, init_step, num_steps):
     transitions
     """
 
+    @tf.function
     def sim_fn(args):
         beta1_, beta2_, beta3_, sigma_, xi_, gamma0_, gamma1_, init_ = args
...
@@ -22,10 +22,12 @@ def rt(input_file, output_file):
 
     rt = np.sum(ngm, axis=-2)
     rt_summary = mean_and_ci(rt, name="Rt")
+    exceed = np.mean(rt > 1.0, axis=0)
 
     rt_summary = pd.DataFrame(
         rt_summary, index=pd.Index(ngm.coords["dest"], name="location")
     )
+    rt_summary['Rt_exceed'] = exceed
     rt_summary.to_csv(output_file)
...
@@ -55,7 +57,7 @@ def infec_incidence(input_file, output_file):
     )
     for t in timepoints[1:]:
         tmp = pd.DataFrame(
-            pred_events(prediction[..., offset:t, 2], name=f"cases{t}"),
+            pred_events(prediction[..., offset:t, 2], name=f"cases{t-offset}"),
             index=idx,
         )
         abs_incidence = pd.concat([abs_incidence, tmp], axis="columns")
...
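The new `Rt_exceed` column is the posterior probability that the reproduction number exceeds 1 in each location, i.e. the fraction of posterior samples with Rt > 1. A small sketch with synthetic samples (column names and location codes here are purely illustrative, not the output of `mean_and_ci`):

import numpy as np
import pandas as pd

rt_samples = np.array([[0.8, 1.2], [0.9, 1.4], [1.1, 1.3], [0.7, 0.9]])  # 4 draws, 2 locations

summary = pd.DataFrame(
    {
        "Rt_mean": rt_samples.mean(axis=0),
        "Rt_exceed": np.mean(rt_samples > 1.0, axis=0),  # Pr(Rt > 1) per location
    },
    index=pd.Index(["E06000001", "E06000002"], name="location"),
)
print(summary)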
@@ -30,7 +30,7 @@ def summary_geopackage(input_files, output_file, config):
         data = pkl.load(f)
 
     # Load and filter geopackage
-    geo = gp.read_file(config["base_geopackage"])
+    geo = gp.read_file(config["base_geopackage"], layer=config["base_layer"])
     geo = geo[geo["lad19cd"].isin(data["locations"]["lad19cd"])]
     geo = geo.sort_values(by="lad19cd")
...
@@ -4,7 +4,9 @@ import h5py
 import pickle as pkl
 
 
-def thin_posterior(input_file, output_file, thin_idx):
+def thin_posterior(input_file, output_file, config):
+    thin_idx = range(config['start'], config['end'], config['by'])
 
     f = h5py.File(input_file, "r", rdcc_nbytes=1024 ** 3, rdcc_nslots=1e6)
     output_dict = dict(
...
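With the new signature, the thinning indices come from the `ThinPosterior` block of the config rather than being hard-coded by the caller. A quick sanity check of what the values in `template_config.yaml` select (a sketch, not pipeline code):

config = {"start": 6000, "end": 10000, "by": 10}

thin_idx = range(config["start"], config["end"], config["by"])
# Keeps samples 6000, 6010, ..., 9990: 400 posterior draws after burn-in.
assert len(thin_idx) == 400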
"""Covid analysis utility functions""" """Covid analysis utility functions"""
import yaml
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import h5py
import tensorflow as tf import tensorflow as tf
import tensorflow_probability as tfp import tensorflow_probability as tfp
from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import dtype_util
tfd = tfp.distributions tfd = tfp.distributions
import h5py
tfs = tfp.stats tfs = tfp.stats
def load_config(config_filename):
with open(config_filename, "r") as f:
return yaml.load(f, Loader=yaml.FullLoader)
def sanitise_parameter(par_dict): def sanitise_parameter(par_dict):
"""Sanitises a dictionary of parameters""" """Sanitises a dictionary of parameters"""
d = {key: np.float64(val) for key, val in par_dict.items()} d = {key: np.float64(val) for key, val in par_dict.items()}
@@ -20,13 +25,19 @@ def sanitise_parameter(par_dict):
 
 def sanitise_settings(par_dict):
     d = {
-        "inference_period": np.array(par_dict["inference_period"], dtype=np.datetime64),
+        "inference_period": np.array(
+            par_dict["inference_period"], dtype=np.datetime64
+        ),
         "prediction_period": np.array(
             par_dict["prediction_period"], dtype=np.datetime64
         ),
         "time_step": float(par_dict["time_step"]),
-        "holiday": np.array([np.datetime64(date) for date in par_dict["holiday"]]),
-        "lockdown": np.array([np.datetime64(date) for date in par_dict["lockdown"]]),
+        "holiday": np.array(
+            [np.datetime64(date) for date in par_dict["holiday"]]
+        ),
+        "lockdown": np.array(
+            [np.datetime64(date) for date in par_dict["lockdown"]]
+        ),
     }
     return d
@@ -199,7 +210,11 @@ def extract_locs(in_file: str, out_file: str, loc: list):
     extract = f["prediction"][:, :, la_loc, :]
     save_sims(
-        f["date"][:], extract, f["la_names"][la_loc], f["age_names"][la_loc], out_file
+        f["date"][:],
+        extract,
+        f["la_names"][la_loc],
+        f["age_names"][la_loc],
+        out_file,
     )
     f.close()
     return extract
@@ -255,7 +270,9 @@ def initialise_previous_events_one_time(events, rate):
             events.index.get_level_values(2).unique(),
         ]
     )
-    past_events = pd.Series(past_events.numpy().flatten(), index=new_index, name="n")
+    past_events = pd.Series(
+        past_events.numpy().flatten(), index=new_index, name="n"
+    )
     print(".", flush=True, end="")
     return past_events
@@ -304,8 +321,16 @@ def jump_summary(posterior_file):
     f.close()
     return {
-        "S->E": {"sjd": np.mean(sjd_se), "accept": accept_se, "p_null": p_null_se},
-        "E->I": {"sjd": np.mean(sjd_ei), "accept": accept_ei, "p_null": p_null_ei},
+        "S->E": {
+            "sjd": np.mean(sjd_se),
+            "accept": accept_se,
+            "p_null": p_null_se,
+        },
+        "E->I": {
+            "sjd": np.mean(sjd_ei),
+            "accept": accept_ei,
+            "p_null": p_null_ei,
+        },
     }
@@ -327,7 +352,9 @@ def plot_event_posterior(posterior, simulation, metapopulation=0):
     )
     ax[0][1].plot(
-        np.cumsum(posterior["samples/events"][idx, metapopulation, :, 0].T, axis=0),
+        np.cumsum(
+            posterior["samples/events"][idx, metapopulation, :, 0].T, axis=0
+        ),
         color="lightblue",
         alpha=0.1,
     )
@@ -349,7 +376,9 @@ def plot_event_posterior(posterior, simulation, metapopulation=0):
     )
     ax[1][1].plot(
-        np.cumsum(posterior["samples/events"][idx, metapopulation, :, 1].T, axis=0),
+        np.cumsum(
+            posterior["samples/events"][idx, metapopulation, :, 1].T, axis=0
+        ),
         color="lightblue",
         alpha=0.1,
     )
@@ -387,11 +416,13 @@ def distribute_geom(events, rate, delta_t=1.0):
         return i, events_ - failures, accum_
 
     def cond(_1, events_, _2):
-        return tf.reduce_sum(events_) > tf.constant(0, dtype=events.dtype)
+        res = tf.reduce_sum(events_) > tf.constant(0, dtype=events.dtype)
+        return res
 
     _1, _2, accum = tf.while_loop(cond, body, loop_vars=[1, events, accum])
+    accum = accum.stack()
 
-    return tf.transpose(accum.stack(), perm=(1, 0, 2))
+    return tf.transpose(accum, perm=(1, 0, 2))
@@ -425,15 +456,20 @@ def impute_previous_cases(events, rate, delta_t=1.0):
     num_zero_days = total_events.shape[-1] - tf.math.count_nonzero(
         tf.cumsum(total_events, axis=-1)
     )
-    return prev_cases[..., num_zero_days:], prev_case_distn.shape[-2] - num_zero_days
+    return (
+        prev_cases[..., num_zero_days:],
+        prev_case_distn.shape[-2] - num_zero_days,
+    )
 
 
 def mean_sojourn(in_events, out_events, init_state):
     """Calculated the mean sojourn time for individuals in a state
     within `in_events` and `out_events` given initial state `init_state`"""
 
     # state.shape = [..., M, T]
-    state = tf.cumsum(in_events - out_events, axis=-1, exclusive=True) + init_state
+    state = (
+        tf.cumsum(in_events - out_events, axis=-1, exclusive=True) + init_state
+    )
     state = tf.reduce_sum(state, axis=(-2, -1))
     events = tf.reduce_sum(out_events, axis=(-2, -1))
@@ -460,7 +496,10 @@ def regularize_occults(events, occults, init_state, stoichiometry):
         first_neg_state_idx = tf.gather(
             neg_state_idx,
             tf.concat(
-                [[[0]], tf.where(neg_state_idx[:-1, 0] - neg_state_idx[1:, 0]) + 1],
+                [
+                    [[0]],
+                    tf.where(neg_state_idx[:-1, 0] - neg_state_idx[1:, 0]) + 1,
+                ],
                 axis=0,
             ),
         )
@@ -474,7 +513,9 @@ def regularize_occults(events, occults, init_state, stoichiometry):
         new_occults = tf.clip_by_value(
             occults_ - delta_occults, clip_value_min=0.0, clip_value_max=1.0e6
         )
-        new_state = compute_state(init_state, events + new_occults, stoichiometry)
+        new_state = compute_state(
+            init_state, events + new_occults, stoichiometry
+        )
         return new_state, new_occults
 
     def cond(state_, _):
...
@@ -81,7 +81,7 @@ if __name__ == "__main__":
         output=work_dir("thin_samples.pkl"),
     )
     def thin_samples(input_file, output_file):
-        thin_posterior(input_file, output_file, range(100))
+        thin_posterior(input_file, output_file, config["ThinPosterior"])
 
     # Rt related steps
     rf.transform(
...
@@ -174,7 +174,8 @@ if __name__ == "__main__":
         exceed7 = case_exceedance((input_files[0], input_files[1]), 7)
         exceed14 = case_exceedance((input_files[0], input_files[2]), 14)
         df = pd.DataFrame(
-            {"Pr(pred<obs)_7": exceed7, "Pr(pred<obs)_14": exceed14}
+            {"Pr(pred<obs)_7": exceed7, "Pr(pred<obs)_14": exceed14},
+            index=exceed7.coords["location"],
         )
         df.to_csv(output_file)
...
@@ -2,15 +2,24 @@
 ProcessData:
   date_range:
-    - 2020-10-11
-    - 2021-01-04
+    - 2020-10-09
+    - 2021-01-01
   mobility_matrix: data/mergedflows.csv
   population_size: data/c2019modagepop.csv
   commute_volume: data/201231_OFF_SEN_COVID19_road_traffic_national_table.xlsx
-  reported_cases: data/Anonymised Combined Line List 20201231.csv
+  reported_cases: data/Anonymised Combined Line List 20210104.csv
   case_date_type: specimen
   pillar: both
 
+  CasesData:
+    input: csv
+    address: data/Anonymised Combined Line List 20210104.csv
+    pillars:
+      - Pillar 1
+      - Pillar 2
+    measure: specimen
+    format: phe
+
   AreaCodeData:
     input: json
     address: "https://services1.arcgis.com/ESMARspQHYMw9BZ9/arcgis/rest/services/LAD_APR_2019_UK_NC/FeatureServer/0/query?where=1%3D1&outFields=LAD19CD,LAD19NM&returnGeometry=false&returnDistinctValues=true&orderByFields=LAD19CD&outSR=4326&f=json"
...
@@ -28,12 +37,15 @@ Mcmc:
   m: 1
   occult_nmax: 15
   num_event_time_updates: 35
-  num_bursts: 1
-  num_burst_samples: 100
-  thin: 1
+  num_bursts: 200
+  num_burst_samples: 50
+  thin: 20
+
+ThinPosterior:
+  start: 6000
+  end: 10000
+  by: 10
 
 Geopackage:
   base_geopackage: data/UK2019mod_pop.gpkg
+  base_layer: UK2019mod_pop_xgen
 
-NextGenerationMatrix:
\ No newline at end of file
+NextGenerationMatrix:
+  output_file: next_generation_matrix.pkl
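A quick consistency check on the new sampler settings, assuming each burst contributes `num_burst_samples` draws to the posterior file (as the names suggest; within-burst thinning is handled by `thin`):

num_bursts = 200
num_burst_samples = 50

total_draws = num_bursts * num_burst_samples   # 10000 samples written out,
                                               # matching ThinPosterior end: 10000
kept = len(range(6000, 10000, 10))             # ThinPosterior keeps 400 of them
print(total_draws, kept)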