Commit bf52d237 authored by Chris Jewell's avatar Chris Jewell
Browse files

Merge branch 'mod-time-rt' into 'master'

Modified Rt calculation

See merge request !33
parents 264cea5c 2a5043c0
......@@ -14,7 +14,7 @@ from covid.tasks import (
assemble_data,
mcmc,
thin_posterior,
next_generation_matrix,
reproduction_number,
overall_rt,
predict,
summarize,
......@@ -91,11 +91,11 @@ def run_pipeline(global_config, results_directory, cli_options):
rf.transform(
input=[[process_data, thin_samples]],
filter=rf.formatter(),
output=wd("ngm.nc"),
)(next_generation_matrix)
output=wd("reproduction_number.nc"),
)(reproduction_number)
rf.transform(
input=next_generation_matrix,
input=reproduction_number,
filter=rf.formatter(),
output=wd("national_rt.xlsx"),
)(overall_rt)
......@@ -146,7 +146,7 @@ def run_pipeline(global_config, results_directory, cli_options):
# Summarisation
rf.transform(
input=next_generation_matrix,
input=reproduction_number,
filter=rf.formatter(),
output=wd("rt_summary.csv"),
)(summarize.rt)
......@@ -221,7 +221,7 @@ def run_pipeline(global_config, results_directory, cli_options):
insample7,
insample14,
medium_term,
next_generation_matrix,
reproduction_number,
]
],
rf.formatter(),
......
......@@ -3,7 +3,7 @@
from covid.tasks.assemble_data import assemble_data
from covid.tasks.inference import mcmc
from covid.tasks.thin_posterior import thin_posterior
from covid.tasks.next_generation_matrix import next_generation_matrix
from covid.tasks.next_generation_matrix import reproduction_number
from covid.tasks.overall_rt import overall_rt
from covid.tasks.predict import predict
import covid.tasks.summarize as summarize
......@@ -18,7 +18,7 @@ __all__ = [
"assemble_data",
"mcmc",
"thin_posterior",
"next_generation_matrix",
"reproduction_number",
"overall_rt",
"predict",
"summarize",
......
......@@ -10,7 +10,7 @@ from covid.util import copy_nc_attrs
from gemlib.util import compute_state
def calc_posterior_ngm(samples, initial_state, times, covar_data):
def calc_posterior_rit(samples, initial_state, times, covar_data):
"""Calculates effective reproduction number for batches of metapopulations
:param theta: a tensor of batched theta parameters [B] + theta.shape
:param xi: a tensor of batched xi parameters [B] + xi.shape
......@@ -38,7 +38,8 @@ def calc_posterior_ngm(samples, initial_state, times, covar_data):
ngm = ngm_fn(t, state_)
return ngm
return tf.vectorized_map(fn, elems=times)
ngm = tf.vectorized_map(fn, elems=times)
return tf.reduce_sum(ngm, axis=-2) # sum over destinations
return tf.vectorized_map(
r_fn,
......@@ -46,36 +47,47 @@ def calc_posterior_ngm(samples, initial_state, times, covar_data):
)
def next_generation_matrix(input_files, output_file):
CHUNKSIZE = 50
def reproduction_number(input_files, output_file):
covar_data = xarray.open_dataset(input_files[0], group="constant_data")
with open(input_files[1], "rb") as f:
samples = pkl.load(f)
num_samples = samples["seir"].shape[0]
initial_state = samples["initial_state"]
del samples["initial_state"]
times = [
samples["seir"].shape[-2] - 1,
]
times = np.arange(covar_data.coords["time"].shape[0])
# Compute ngm posterior
ngm = calc_posterior_ngm(samples, initial_state, times, covar_data)
ngm = xarray.DataArray(
ngm,
# Compute ngm posterior in chunks to prevent over-memory
r_its = []
for i in range(0, num_samples, CHUNKSIZE):
start = i
end = np.minimum(i + CHUNKSIZE, num_samples)
print(f"Chunk {start}:{end}", flush=True)
subsamples = {k: v[start:end] for k, v in samples.items()}
r_it = calc_posterior_rit(subsamples, initial_state, times, covar_data)
r_its.append(r_it)
r_it = xarray.DataArray(
tf.concat(r_its, axis=0),
coords=[
np.arange(ngm.shape[0]),
np.arange(num_samples),
covar_data.coords["time"][times],
covar_data.coords["location"],
covar_data.coords["location"],
],
dims=["iteration", "time", "dest", "src"],
dims=["iteration", "time", "location"],
)
ngm = xarray.Dataset({"ngm": ngm})
weight = covar_data["N"] / covar_data["N"].sum()
r_t = (r_it * weight).sum(dim="location")
ds = xarray.Dataset({"R_it": r_it, "R_t": r_t})
# Output
ngm.to_netcdf(output_file, group="posterior_predictive")
ds.to_netcdf(output_file, group="posterior_predictive")
copy_nc_attrs(input_files[0], output_file)
......@@ -101,4 +113,4 @@ if __name__ == "__main__":
)
args = parser.parse_args()
next_generation_matrix([args.data, args.samples], args.output)
reproduction_number([args.data, args.samples], args.output)
......@@ -10,18 +10,18 @@ from covid.summary import (
)
def overall_rt(next_generation_matrix, output_file):
ngms = xarray.open_dataset(
next_generation_matrix, group="posterior_predictive"
)["ngm"]
ngms = ngms[:, 0, :, :].drop("time")
b, _ = power_iteration(ngms)
rt = rayleigh_quotient(ngms, b)
def overall_rt(inference_data, output_file):
r_t = xarray.open_dataset(inference_data, group="posterior_predictive")[
"R_t"
]
q = np.arange(0.05, 1.0, 0.05)
rt_quantiles = pd.DataFrame(
{"Rt": np.quantile(rt, q, axis=-1)}, index=q
).T.to_excel(output_file)
quantiles = r_t.isel(time=-1).quantile(q=q)
quantiles.to_dataframe().T.to_excel(output_file)
# pd.DataFrame({"Rt": np.quantile(r_t, q, axis=-1)}, index=q).T.to_excel(
# output_file
# )
if __name__ == "__main__":
......
......@@ -20,14 +20,14 @@ def rt(input_file, output_file):
:param output_file: a .csv of mean (ci) values
"""
ngm = xarray.open_dataset(input_file, group="posterior_predictive")["ngm"]
r_it = xarray.open_dataset(input_file, group="posterior_predictive")["R_it"]
rt = ngm.sum(dim="dest").isel(time=-1).drop("time")
rt = r_it.isel(time=-1).drop("time")
rt_summary = mean_and_ci(rt, name="Rt")
exceed = np.mean(rt > 1.0, axis=0)
rt_summary = pd.DataFrame(
rt_summary, index=pd.Index(ngm.coords["dest"], name="location")
rt_summary, index=pd.Index(r_it.coords["location"], name="location")
)
rt_summary["Rt_exceed"] = exceed
rt_summary.to_csv(output_file)
......
......@@ -147,14 +147,12 @@ def summary_longformat(input_files, output_file):
df = pd.concat([df, prev_df], axis="index")
# Rt
ngms = xarray.load_dataset(input_files[4], group="posterior_predictive")[
"ngm"
rt = xarray.load_dataset(input_files[4], group="posterior_predictive")[
"R_it"
]
rt = ngms.sum(dim="dest")
rt = rt.rename({"src": "location"})
rt_summary = xarray2summarydf(rt)
rt_summary = xarray2summarydf(rt.isel(time=-1))
rt_summary["value_name"] = "R"
rt_summary["time"] = cases.coords["time"].data[-1] + np.timedelta64(1, "D")
rt_summary["time"] = rt.coords["time"].data[-1] + np.timedelta64(1, "D")
df = pd.concat([df, rt_summary], axis="index")
quantiles = df.columns[df.columns.str.startswith("0.")]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment