Commit 215a0ee0 authored by Chris Jewell's avatar Chris Jewell
Browse files

Add time axis to Rt

CHANGES:

1. The posterior NGM is now not saved, only the sum over all destinations to give $R_{it}$
2. Replace Rt as the dominant eigenvalue of the NGM with weighted average by population size.
3. ngm.nc now contains both $R_{it}$ and $R_t$ posterior predictive values.
parent 264cea5c
......@@ -10,7 +10,7 @@ from covid.util import copy_nc_attrs
from gemlib.util import compute_state
def calc_posterior_ngm(samples, initial_state, times, covar_data):
def calc_posterior_rit(samples, initial_state, times, covar_data):
"""Calculates effective reproduction number for batches of metapopulations
:param theta: a tensor of batched theta parameters [B] + theta.shape
:param xi: a tensor of batched xi parameters [B] + xi.shape
......@@ -38,7 +38,8 @@ def calc_posterior_ngm(samples, initial_state, times, covar_data):
ngm = ngm_fn(t, state_)
return ngm
return tf.vectorized_map(fn, elems=times)
ngm = tf.vectorized_map(fn, elems=times)
return tf.reduce_sum(ngm, axis=-2) # sum over destinations
return tf.vectorized_map(
r_fn,
......@@ -56,26 +57,25 @@ def next_generation_matrix(input_files, output_file):
initial_state = samples["initial_state"]
del samples["initial_state"]
times = [
samples["seir"].shape[-2] - 1,
]
times = np.arange(covar_data.coords["time"].shape[0])
# Compute ngm posterior
ngm = calc_posterior_ngm(samples, initial_state, times, covar_data)
ngm = xarray.DataArray(
ngm,
r_it = calc_posterior_rit(samples, initial_state, times, covar_data)
r_it = xarray.DataArray(
r_it,
coords=[
np.arange(ngm.shape[0]),
np.arange(r_it.shape[0]),
covar_data.coords["time"][times],
covar_data.coords["location"],
covar_data.coords["location"],
],
dims=["iteration", "time", "dest", "src"],
dims=["iteration", "time", "location"],
)
ngm = xarray.Dataset({"ngm": ngm})
weight = covar_data["N"] / covar_data["N"].sum()
r_t = (r_it * weight).sum(dim="location")
ds = xarray.Dataset({"R_it": r_it, "R_t": r_t})
# Output
ngm.to_netcdf(output_file, group="posterior_predictive")
ds.to_netcdf(output_file, group="posterior_predictive")
copy_nc_attrs(input_files[0], output_file)
......
......@@ -10,18 +10,18 @@ from covid.summary import (
)
def overall_rt(next_generation_matrix, output_file):
ngms = xarray.open_dataset(
next_generation_matrix, group="posterior_predictive"
)["ngm"]
ngms = ngms[:, 0, :, :].drop("time")
b, _ = power_iteration(ngms)
rt = rayleigh_quotient(ngms, b)
def overall_rt(inference_data, output_file):
r_t = xarray.open_dataset(inference_data, group="posterior_predictive")[
"R_t"
]
q = np.arange(0.05, 1.0, 0.05)
rt_quantiles = pd.DataFrame(
{"Rt": np.quantile(rt, q, axis=-1)}, index=q
).T.to_excel(output_file)
quantiles = r_t.isel(time=-1).quantile(q=q)
quantiles.to_dataframe().T.to_excel(output_file)
# pd.DataFrame({"Rt": np.quantile(r_t, q, axis=-1)}, index=q).T.to_excel(
# output_file
# )
if __name__ == "__main__":
......
......@@ -20,14 +20,14 @@ def rt(input_file, output_file):
:param output_file: a .csv of mean (ci) values
"""
ngm = xarray.open_dataset(input_file, group="posterior_predictive")["ngm"]
r_it = xarray.open_dataset(input_file, group="posterior_predictive")["R_it"]
rt = ngm.sum(dim="dest").isel(time=-1).drop("time")
rt = r_it.isel(time=-1).drop("time")
rt_summary = mean_and_ci(rt, name="Rt")
exceed = np.mean(rt > 1.0, axis=0)
rt_summary = pd.DataFrame(
rt_summary, index=pd.Index(ngm.coords["dest"], name="location")
rt_summary, index=pd.Index(r_it.coords["location"], name="location")
)
rt_summary["Rt_exceed"] = exceed
rt_summary.to_csv(output_file)
......
......@@ -147,14 +147,12 @@ def summary_longformat(input_files, output_file):
df = pd.concat([df, prev_df], axis="index")
# Rt
ngms = xarray.load_dataset(input_files[4], group="posterior_predictive")[
"ngm"
rt = xarray.load_dataset(input_files[4], group="posterior_predictive")[
"R_it"
]
rt = ngms.sum(dim="dest")
rt = rt.rename({"src": "location"})
rt_summary = xarray2summarydf(rt)
rt_summary = xarray2summarydf(rt.isel(time=-1))
rt_summary["value_name"] = "R"
rt_summary["time"] = cases.coords["time"].data[-1] + np.timedelta64(1, "D")
rt_summary["time"] = rt.coords["time"].data[-1] + np.timedelta64(1, "D")
df = pd.concat([df, rt_summary], axis="index")
quantiles = df.columns[df.columns.str.startswith("0.")]
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment