Commit 4f17748d authored by Chris Jewell's avatar Chris Jewell
Browse files

Changes necessary for Tiered model summary.

parent 7be08193
......@@ -4,7 +4,7 @@ import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from gemlib_tfp_extra.util import compute_state
from gemlib.util import compute_state
def mean_and_ci(arr, q=(0.025, 0.975), axis=0, name=None):
......
......@@ -281,8 +281,7 @@ if __name__ == "__main__":
NUM_BURSTS = config["mcmc"]["num_bursts"]
NUM_BURST_SAMPLES = config["mcmc"]["num_burst_samples"]
NUM_EVENT_TIME_UPDATES = config["mcmc"]["num_event_time_updates"]
THIN_BURST_SAMPLES = NUM_BURST_SAMPLES // config["mcmc"]["thin"]
NUM_SAVED_SAMPLES = THIN_BURST_SAMPLES * NUM_BURSTS
NUM_SAVED_SAMPLES = NUM_BURST_SAMPLES * NUM_BURSTS
# RNG stuff
tf.random.set_seed(2)
......@@ -300,7 +299,7 @@ if __name__ == "__main__":
os.path.expandvars(config["output"]["results_dir"]),
config["output"]["posterior"],
),
{"theta": samples[0], "xi": samples[1]},
{"theta": samples[0], "xi": samples[1], "events": samples[2]},
results,
NUM_SAVED_SAMPLES,
)
......@@ -313,7 +312,9 @@ if __name__ == "__main__":
# to disc, or else end OOM (even on a 32GB system).
# with tf.profiler.experimental.Profile("/tmp/tf_logdir"):
final_results = None
for i in tqdm.tqdm(range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES):
for i in tqdm.tqdm(
range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES * config["mcmc"]["thin"]
):
samples, results, final_results = sample(
NUM_BURST_SAMPLES,
init_state=current_state,
......@@ -325,7 +326,7 @@ if __name__ == "__main__":
start = perf_counter()
posterior.write_samples(
{"theta": samples[0], "xi": samples[1]},
{"theta": samples[0], "xi": samples[1], "events": samples[2]},
first_dim_offset=i * NUM_BURST_SAMPLES,
)
posterior.write_results(results, first_dim_offset=i * NUM_BURST_SAMPLES)
......
......@@ -62,8 +62,8 @@ def impute_censored_events(cases):
:returns: a MxTx3 tensor of events where the first two indices of
the right-most dimension contain the imputed event times.
"""
ei_events, lag_ei = impute_previous_cases(cases, 0.25)
se_events, lag_se = impute_previous_cases(ei_events, 0.5)
ei_events, lag_ei = impute_previous_cases(cases, 0.21)
se_events, lag_se = impute_previous_cases(ei_events, 0.28)
ir_events = np.pad(cases, ((0, 0), (lag_ei + lag_se - 2, 0)))
ei_events = np.pad(ei_events, ((0, 0), (lag_se - 1, 0)))
return tf.stack([se_events, ei_events, ir_events], axis=-1)
......@@ -190,6 +190,9 @@ def next_generation_matrix_fn(covar_data, param):
"""
def fn(t, state):
L = tf.convert_to_tensor(covar_data["L"], DTYPE)
L = L - tf.reduce_mean(L, axis=0)
C = tf.convert_to_tensor(covar_data["C"], dtype=DTYPE)
C = tf.linalg.set_diag(
C + tf.transpose(C), tf.zeros(C.shape[0], dtype=DTYPE)
......@@ -204,9 +207,14 @@ def next_generation_matrix_fn(covar_data, param):
dtype=tf.int64,
)
xi = tf.gather(param["xi"], xi_idx)
beta = param["beta1"] * tf.math.exp(xi)
ngm = beta * (
L_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, L.shape[0] - 1)
Lt = tf.gather(L, L_idx)
xB = tf.linalg.matvec(Lt, param["beta3"])
beta = tf.math.exp(xi + xB)
ngm = beta[tf.newaxis, :] * (
tf.eye(C.shape[0], dtype=state.dtype)
+ param["beta2"] * commute_volume * C / N[tf.newaxis, :]
)
......
......@@ -17,7 +17,7 @@ xlrd = "^1.2.0"
tqdm = "^4.50.2"
openpyxl = "^3.0.5"
h5py = "^2.10.0"
tf-nightly = "2.4.0.dev20201021"
tf-nightly = "2.5.0.dev20201027"
gemlib = {git = "http://fhm-chicas-code.lancs.ac.uk/GEM/gemlib.git"}
xarray = "^0.16.1"
seaborn = "^0.11.0"
......
......@@ -8,13 +8,13 @@ import pandas as pd
import geopandas as gp
import tensorflow as tf
from gemlib.util import compute_state
from covid.cli_arg_parse import cli_args
from covid.model import (
from covid.summary import (
rayleigh_quotient,
power_iteration,
)
from covid.impl.util import compute_state
from covid.summary import mean_and_ci
import model_spec
......@@ -23,6 +23,7 @@ DTYPE = model_spec.DTYPE
GIS_TEMPLATE = "data/UK2019mod_pop.gpkg"
# Reproduction number calculation
def calc_R_it(theta, xi, events, init_state, covar_data):
"""Calculates effective reproduction number for batches of metapopulations
......@@ -41,7 +42,13 @@ def calc_R_it(theta, xi, events, init_state, covar_data):
state = compute_state(init_state, events_, model_spec.STOICHIOMETRY)
state = tf.gather(state, t - 1, axis=-2) # State on final inference day
par = dict(beta1=theta_[0], beta2=theta_[1], gamma=theta_[2], xi=xi_)
par = dict(
beta1=xi_[0],
beta2=theta_[0],
beta3=xi_[1:3],
gamma=theta_[1],
xi=xi_[3:],
)
ngm_fn = model_spec.next_generation_matrix_fn(covar_data, par)
ngm = ngm_fn(t, state)
......@@ -51,7 +58,7 @@ def calc_R_it(theta, xi, events, init_state, covar_data):
@tf.function
def predicted_incidence(theta, xi, init_state, init_step, num_steps):
def predicted_incidence(theta, xi, init_state, init_step, num_steps, priors):
"""Runs the simulation forward in time from `init_state` at time `init_time`
for `num_steps`.
:param theta: a tensor of batched theta parameters [B] + theta.shape
......@@ -59,26 +66,30 @@ def predicted_incidence(theta, xi, init_state, init_step, num_steps):
:param events: a [B, M, S] batched state tensor
:param init_step: the initial time step
:param num_steps: the number of steps to simulate
:returns: a tensor of shape [B, M, num_steps, X] where X is the number of state
:param priors: the priors for gamma
:returns: a tensor of shape [B, M, num_steps, X] where X is the number of state
transitions
"""
def sim_fn(args):
theta_, xi_, init_ = args
par = dict(beta1=theta_[0], beta2=theta_[1], gamma=theta_[2], xi=xi_)
par = dict(beta1=xi_[0], beta2=theta_[0], gamma=theta_[1], xi=xi_[1:])
model = model_spec.CovidUK(
covar_data,
initial_state=init_,
initial_step=init_step,
num_steps=num_steps,
priors=priors,
)
sim = model.sample(**par)
return sim["seir"]
events = tf.map_fn(
sim_fn, elems=(theta, xi, init_state), fn_output_signature=(tf.float64),
sim_fn,
elems=(theta, xi, init_state),
fn_output_signature=(tf.float64),
)
return events
......@@ -117,13 +128,17 @@ if __name__ == "__main__":
# Load covariate data
covar_data = model_spec.read_covariates(
config["data"], date_low=inference_period[0], date_high=inference_period[1]
config["data"],
date_low=inference_period[0],
date_high=inference_period[1],
)
# Load posterior file
posterior = h5py.File(
os.path.expandvars(
os.path.join(config["output"]["results_dir"], config["output"]["posterior"])
os.path.join(
config["output"]["results_dir"], config["output"]["posterior"]
)
),
"r",
rdcc_nbytes=1024 ** 3,
......@@ -136,11 +151,17 @@ if __name__ == "__main__":
xi = posterior["samples/xi"][idx]
events = posterior["samples/events"][idx]
init_state = posterior["initial_state"][:]
state_timeseries = compute_state(init_state, events, model_spec.STOICHIOMETRY)
state_timeseries = compute_state(
init_state, events, model_spec.STOICHIOMETRY
)
# Build model
model = model_spec.CovidUK(
covar_data, initial_state=init_state, initial_step=0, num_steps=events.shape[1],
covar_data,
initial_state=init_state,
initial_step=0,
num_steps=events.shape[1],
priors=config["mcmc"]["prior"],
)
ngms = calc_R_it(theta, xi, events, init_state, covar_data)
......@@ -148,7 +169,9 @@ if __name__ == "__main__":
rt = rayleigh_quotient(ngms, b)
q = np.arange(0.05, 1.0, 0.05)
rt_quantiles = pd.DataFrame({"Rt": np.quantile(rt, q)}, index=q).T.to_excel(
os.path.join(config["output"]["results_dir"], config["output"]["national_rt"]),
os.path.join(
config["output"]["results_dir"], config["output"]["national_rt"]
),
)
# Prediction requires simulation from the last available timepoint for 28 + 4 + 1 days
......@@ -160,13 +183,16 @@ if __name__ == "__main__":
init_state=state_timeseries[..., -1, :],
init_step=state_timeseries.shape[-2] - 1,
num_steps=70,
priors=config["mcmc"]["prior"],
)
predicted_state = compute_state(
state_timeseries[..., -1, :], prediction, model_spec.STOICHIOMETRY
)
# Prevalence now
prev_now = prevalence(predicted_state[..., 4, :], covar_data["N"], name="prev")
prev_now = prevalence(
predicted_state[..., 4, :], covar_data["N"], name="prev"
)
# Incidence of detections now
cases_now = predicted_events(prediction[..., 4:5, 2], name="cases")
......@@ -179,11 +205,21 @@ if __name__ == "__main__":
cases_56 = predicted_events(prediction[..., 4:60, 2], name="cases56")
# Prevalence at day 7
prev_7 = prevalence(predicted_state[..., 11, :], covar_data["N"], name="prev7")
prev_14 = prevalence(predicted_state[..., 18, :], covar_data["N"], name="prev14")
prev_21 = prevalence(predicted_state[..., 25, :], covar_data["N"], name="prev21")
prev_28 = prevalence(predicted_state[..., 32, :], covar_data["N"], name="prev28")
prev_56 = prevalence(predicted_state[..., 60, :], covar_data["N"], name="prev56")
prev_7 = prevalence(
predicted_state[..., 11, :], covar_data["N"], name="prev7"
)
prev_14 = prevalence(
predicted_state[..., 18, :], covar_data["N"], name="prev14"
)
prev_21 = prevalence(
predicted_state[..., 25, :], covar_data["N"], name="prev21"
)
prev_28 = prevalence(
predicted_state[..., 32, :], covar_data["N"], name="prev28"
)
prev_56 = prevalence(
predicted_state[..., 60, :], covar_data["N"], name="prev56"
)
def geosummary(geodata, summaries):
for summary in summaries:
......@@ -226,6 +262,8 @@ if __name__ == "__main__":
),
]
ltla.to_file(
os.path.join(config["output"]["results_dir"], config["output"]["geopackage"]),
os.path.join(
config["output"]["results_dir"], config["output"]["geopackage"]
),
driver="GPKG",
)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment