From 215a0ee0b9c1702a2f00b8b2d440d7faebe43bbd Mon Sep 17 00:00:00 2001 From: Chris Jewell Date: Tue, 16 Mar 2021 23:05:47 +0000 Subject: [PATCH 1/3] Add time axis to Rt CHANGES: 1. The posterior NGM is now not saved, only the sum over all destinations to give $R_{it}$ 2. Replace Rt as the dominant eigenvalue of the NGM with weighted average by population size. 3. ngm.nc now contains both $R_{it}$ and $R_t$ posterior predictive values. --- covid/tasks/next_generation_matrix.py | 26 +++++++++++++------------- covid/tasks/overall_rt.py | 22 +++++++++++----------- covid/tasks/summarize.py | 6 +++--- covid/tasks/summary_longformat.py | 10 ++++------ 4 files changed, 31 insertions(+), 33 deletions(-) diff --git a/covid/tasks/next_generation_matrix.py b/covid/tasks/next_generation_matrix.py index dabefca..96c6456 100644 --- a/covid/tasks/next_generation_matrix.py +++ b/covid/tasks/next_generation_matrix.py @@ -10,7 +10,7 @@ from covid.util import copy_nc_attrs from gemlib.util import compute_state -def calc_posterior_ngm(samples, initial_state, times, covar_data): +def calc_posterior_rit(samples, initial_state, times, covar_data): """Calculates effective reproduction number for batches of metapopulations :param theta: a tensor of batched theta parameters [B] + theta.shape :param xi: a tensor of batched xi parameters [B] + xi.shape @@ -38,7 +38,8 @@ def calc_posterior_ngm(samples, initial_state, times, covar_data): ngm = ngm_fn(t, state_) return ngm - return tf.vectorized_map(fn, elems=times) + ngm = tf.vectorized_map(fn, elems=times) + return tf.reduce_sum(ngm, axis=-2) # sum over destinations return tf.vectorized_map( r_fn, @@ -56,26 +57,25 @@ def next_generation_matrix(input_files, output_file): initial_state = samples["initial_state"] del samples["initial_state"] - times = [ - samples["seir"].shape[-2] - 1, - ] + times = np.arange(covar_data.coords["time"].shape[0]) # Compute ngm posterior - ngm = calc_posterior_ngm(samples, initial_state, times, covar_data) - ngm = xarray.DataArray( - ngm, + r_it = calc_posterior_rit(samples, initial_state, times, covar_data) + r_it = xarray.DataArray( + r_it, coords=[ - np.arange(ngm.shape[0]), + np.arange(r_it.shape[0]), covar_data.coords["time"][times], covar_data.coords["location"], - covar_data.coords["location"], ], - dims=["iteration", "time", "dest", "src"], + dims=["iteration", "time", "location"], ) - ngm = xarray.Dataset({"ngm": ngm}) + weight = covar_data["N"] / covar_data["N"].sum() + r_t = (r_it * weight).sum(dim="location") + ds = xarray.Dataset({"R_it": r_it, "R_t": r_t}) # Output - ngm.to_netcdf(output_file, group="posterior_predictive") + ds.to_netcdf(output_file, group="posterior_predictive") copy_nc_attrs(input_files[0], output_file) diff --git a/covid/tasks/overall_rt.py b/covid/tasks/overall_rt.py index 6a6c5e6..941d662 100644 --- a/covid/tasks/overall_rt.py +++ b/covid/tasks/overall_rt.py @@ -10,18 +10,18 @@ from covid.summary import ( ) -def overall_rt(next_generation_matrix, output_file): - - ngms = xarray.open_dataset( - next_generation_matrix, group="posterior_predictive" - )["ngm"] - ngms = ngms[:, 0, :, :].drop("time") - b, _ = power_iteration(ngms) - rt = rayleigh_quotient(ngms, b) +def overall_rt(inference_data, output_file): + + r_t = xarray.open_dataset(inference_data, group="posterior_predictive")[ + "R_t" + ] + q = np.arange(0.05, 1.0, 0.05) - rt_quantiles = pd.DataFrame( - {"Rt": np.quantile(rt, q, axis=-1)}, index=q - ).T.to_excel(output_file) + quantiles = r_t.isel(time=-1).quantile(q=q) + quantiles.to_dataframe().T.to_excel(output_file) + # pd.DataFrame({"Rt": np.quantile(r_t, q, axis=-1)}, index=q).T.to_excel( + # output_file + # ) if __name__ == "__main__": diff --git a/covid/tasks/summarize.py b/covid/tasks/summarize.py index a85cde7..be2fc7d 100644 --- a/covid/tasks/summarize.py +++ b/covid/tasks/summarize.py @@ -20,14 +20,14 @@ def rt(input_file, output_file): :param output_file: a .csv of mean (ci) values """ - ngm = xarray.open_dataset(input_file, group="posterior_predictive")["ngm"] + r_it = xarray.open_dataset(input_file, group="posterior_predictive")["R_it"] - rt = ngm.sum(dim="dest").isel(time=-1).drop("time") + rt = r_it.isel(time=-1).drop("time") rt_summary = mean_and_ci(rt, name="Rt") exceed = np.mean(rt > 1.0, axis=0) rt_summary = pd.DataFrame( - rt_summary, index=pd.Index(ngm.coords["dest"], name="location") + rt_summary, index=pd.Index(r_it.coords["location"], name="location") ) rt_summary["Rt_exceed"] = exceed rt_summary.to_csv(output_file) diff --git a/covid/tasks/summary_longformat.py b/covid/tasks/summary_longformat.py index 30fe58c..af169c4 100644 --- a/covid/tasks/summary_longformat.py +++ b/covid/tasks/summary_longformat.py @@ -147,14 +147,12 @@ def summary_longformat(input_files, output_file): df = pd.concat([df, prev_df], axis="index") # Rt - ngms = xarray.load_dataset(input_files[4], group="posterior_predictive")[ - "ngm" + rt = xarray.load_dataset(input_files[4], group="posterior_predictive")[ + "R_it" ] - rt = ngms.sum(dim="dest") - rt = rt.rename({"src": "location"}) - rt_summary = xarray2summarydf(rt) + rt_summary = xarray2summarydf(rt.isel(time=-1)) rt_summary["value_name"] = "R" - rt_summary["time"] = cases.coords["time"].data[-1] + np.timedelta64(1, "D") + rt_summary["time"] = rt.coords["time"].data[-1] + np.timedelta64(1, "D") df = pd.concat([df, rt_summary], axis="index") quantiles = df.columns[df.columns.str.startswith("0.")] -- GitLab From 4afc60f0e357f799a4438668329a2b641fd14325 Mon Sep 17 00:00:00 2001 From: Chris Jewell Date: Wed, 17 Mar 2021 09:20:25 +0000 Subject: [PATCH 2/3] `next_generation_matrix` --> `reproduction_number` --- covid/ruffus_pipeline.py | 12 ++++++------ covid/tasks/__init__.py | 4 ++-- covid/tasks/next_generation_matrix.py | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/covid/ruffus_pipeline.py b/covid/ruffus_pipeline.py index fe576f1..ca96b7c 100644 --- a/covid/ruffus_pipeline.py +++ b/covid/ruffus_pipeline.py @@ -14,7 +14,7 @@ from covid.tasks import ( assemble_data, mcmc, thin_posterior, - next_generation_matrix, + reproduction_number, overall_rt, predict, summarize, @@ -91,11 +91,11 @@ def run_pipeline(global_config, results_directory, cli_options): rf.transform( input=[[process_data, thin_samples]], filter=rf.formatter(), - output=wd("ngm.nc"), - )(next_generation_matrix) + output=wd("reproduction_number.nc"), + )(reproduction_number) rf.transform( - input=next_generation_matrix, + input=reproduction_number, filter=rf.formatter(), output=wd("national_rt.xlsx"), )(overall_rt) @@ -146,7 +146,7 @@ def run_pipeline(global_config, results_directory, cli_options): # Summarisation rf.transform( - input=next_generation_matrix, + input=reproduction_number, filter=rf.formatter(), output=wd("rt_summary.csv"), )(summarize.rt) @@ -221,7 +221,7 @@ def run_pipeline(global_config, results_directory, cli_options): insample7, insample14, medium_term, - next_generation_matrix, + reproduction_number, ] ], rf.formatter(), diff --git a/covid/tasks/__init__.py b/covid/tasks/__init__.py index 2612990..9d8bf6d 100644 --- a/covid/tasks/__init__.py +++ b/covid/tasks/__init__.py @@ -3,7 +3,7 @@ from covid.tasks.assemble_data import assemble_data from covid.tasks.inference import mcmc from covid.tasks.thin_posterior import thin_posterior -from covid.tasks.next_generation_matrix import next_generation_matrix +from covid.tasks.next_generation_matrix import reproduction_number from covid.tasks.overall_rt import overall_rt from covid.tasks.predict import predict import covid.tasks.summarize as summarize @@ -18,7 +18,7 @@ __all__ = [ "assemble_data", "mcmc", "thin_posterior", - "next_generation_matrix", + "reproduction_number", "overall_rt", "predict", "summarize", diff --git a/covid/tasks/next_generation_matrix.py b/covid/tasks/next_generation_matrix.py index 96c6456..3169459 100644 --- a/covid/tasks/next_generation_matrix.py +++ b/covid/tasks/next_generation_matrix.py @@ -47,7 +47,7 @@ def calc_posterior_rit(samples, initial_state, times, covar_data): ) -def next_generation_matrix(input_files, output_file): +def reproduction_number(input_files, output_file): covar_data = xarray.open_dataset(input_files[0], group="constant_data") @@ -101,4 +101,4 @@ if __name__ == "__main__": ) args = parser.parse_args() - next_generation_matrix([args.data, args.samples], args.output) + reproduction_number([args.data, args.samples], args.output) -- GitLab From 2a5043c09740daadb9c6a380f66f830ea030020e Mon Sep 17 00:00:00 2001 From: Chris Jewell Date: Wed, 17 Mar 2021 10:27:20 +0000 Subject: [PATCH 3/3] Chunked calculation of Rt metrics to avoid OOM on the GPU. --- covid/tasks/next_generation_matrix.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/covid/tasks/next_generation_matrix.py b/covid/tasks/next_generation_matrix.py index 3169459..db3fe24 100644 --- a/covid/tasks/next_generation_matrix.py +++ b/covid/tasks/next_generation_matrix.py @@ -47,24 +47,36 @@ def calc_posterior_rit(samples, initial_state, times, covar_data): ) +CHUNKSIZE = 50 + + def reproduction_number(input_files, output_file): covar_data = xarray.open_dataset(input_files[0], group="constant_data") with open(input_files[1], "rb") as f: samples = pkl.load(f) + num_samples = samples["seir"].shape[0] initial_state = samples["initial_state"] del samples["initial_state"] times = np.arange(covar_data.coords["time"].shape[0]) - # Compute ngm posterior - r_it = calc_posterior_rit(samples, initial_state, times, covar_data) + # Compute ngm posterior in chunks to prevent over-memory + r_its = [] + for i in range(0, num_samples, CHUNKSIZE): + start = i + end = np.minimum(i + CHUNKSIZE, num_samples) + print(f"Chunk {start}:{end}", flush=True) + subsamples = {k: v[start:end] for k, v in samples.items()} + r_it = calc_posterior_rit(subsamples, initial_state, times, covar_data) + r_its.append(r_it) + r_it = xarray.DataArray( - r_it, + tf.concat(r_its, axis=0), coords=[ - np.arange(r_it.shape[0]), + np.arange(num_samples), covar_data.coords["time"][times], covar_data.coords["location"], ], -- GitLab