Commit 34e1623a authored by Chris Jewell
Browse files

Fix for prediction time origins

This commit fixes a bug in the predictive time-series where
the stochastic process for the force of infection was misaligned
with the stochastic process for the events.
parent 7029d4d0
...@@ -192,12 +192,12 @@ def CovidUK(covariates, initial_state, initial_step, num_steps): ...@@ -192,12 +192,12 @@ def CovidUK(covariates, initial_state, initial_step, num_steps):
b_t = alpha_0 + tf.cumsum(alpha_t) b_t = alpha_0 + tf.cumsum(alpha_t)
alpha_t_idx = tf.cast(t, tf.int64) alpha_t_idx = tf.cast(t, tf.int64)
alpha_t_ = tf.where( alpha_t_ = tf.where(
alpha_t_idx == initial_step, alpha_t_idx == 0,
alpha_0, alpha_0,
tf.gather( tf.gather(
b_t, b_t,
tf.clip_by_value( tf.clip_by_value(
alpha_t_idx - initial_step - 1, alpha_t_idx - 1,
clip_value_min=0, clip_value_min=0,
clip_value_max=alpha_t.shape[0] - 1, clip_value_max=alpha_t.shape[0] - 1,
), ),
......
...@@ -116,6 +116,7 @@ def run_pipeline(global_config, results_directory, cli_options): ...@@ -116,6 +116,7 @@ def run_pipeline(global_config, results_directory, cli_options):
output_file=output_file, output_file=output_file,
initial_step=-7, initial_step=-7,
num_steps=28, num_steps=28,
out_of_sample=True,
) )
@rf.transform( @rf.transform(
...@@ -130,6 +131,7 @@ def run_pipeline(global_config, results_directory, cli_options): ...@@ -130,6 +131,7 @@ def run_pipeline(global_config, results_directory, cli_options):
output_file=output_file, output_file=output_file,
initial_step=-14, initial_step=-14,
num_steps=28, num_steps=28,
out_of_sample=True,
) )
# Medium-term prediction # Medium-term prediction
...@@ -144,7 +146,8 @@ def run_pipeline(global_config, results_directory, cli_options): ...@@ -144,7 +146,8 @@ def run_pipeline(global_config, results_directory, cli_options):
posterior_samples=input_files[1], posterior_samples=input_files[1],
output_file=output_file, output_file=output_file,
initial_step=-1, initial_step=-1,
num_steps=61, num_steps=84,
out_of_sample=True,
) )
# Summarisation # Summarisation
......
...@@ -12,7 +12,12 @@ from gemlib.util import compute_state ...@@ -12,7 +12,12 @@ from gemlib.util import compute_state
def predicted_incidence( def predicted_incidence(
posterior_samples, init_state, covar_data, init_step, num_steps posterior_samples,
init_state,
covar_data,
init_step,
num_steps,
out_of_sample=False,
): ):
"""Runs the simulation forward in time from `init_state` at time `init_time` """Runs the simulation forward in time from `init_state` at time `init_time`
for `num_steps`. for `num_steps`.
...@@ -32,6 +37,18 @@ def predicted_incidence( ...@@ -32,6 +37,18 @@ def predicted_incidence(
posterior_samples["new_init_state"] = posterior_state[..., init_step, :] posterior_samples["new_init_state"] = posterior_state[..., init_step, :]
del posterior_samples["seir"] del posterior_samples["seir"]
# For out-of-sample prediction, we have to re-simulate the
# alpha_t trajectory given the starting point.
if out_of_sample is True:
alpha_t = posterior_samples["alpha_0"][:, tf.newaxis] + tf.cumsum(
posterior_samples["alpha_t"], axis=-1
)
if init_step > 0:
posterior_samples["alpha_0"] = alpha_t[:, init_step - 1]
# Remove alpha_t from the posterior to make TFP re-simulate it.
del posterior_samples["alpha_t"]
@tf.function @tf.function
def do_sim(): def do_sim():
def sim_fn(args): def sim_fn(args):
...@@ -62,7 +79,14 @@ def read_pkl(filename): ...@@ -62,7 +79,14 @@ def read_pkl(filename):
return pkl.load(f) return pkl.load(f)
def predict(data, posterior_samples, output_file, initial_step, num_steps): def predict(
data,
posterior_samples,
output_file,
initial_step,
num_steps,
out_of_sample=False,
):
covar_data = xarray.open_dataset(data, group="constant_data") covar_data = xarray.open_dataset(data, group="constant_data")
cases = xarray.open_dataset(data, group="observations") cases = xarray.open_dataset(data, group="observations")
...@@ -88,7 +112,12 @@ def predict(data, posterior_samples, output_file, initial_step, num_steps): ...@@ -88,7 +112,12 @@ def predict(data, posterior_samples, output_file, initial_step, num_steps):
) )
estimated_init_state, predicted_events = predicted_incidence( estimated_init_state, predicted_events = predicted_incidence(
samples, initial_state, covar_data, initial_step, num_steps samples,
initial_state,
covar_data,
initial_step,
num_steps,
out_of_sample,
) )
prediction = xarray.DataArray( prediction = xarray.DataArray(
...@@ -129,6 +158,12 @@ if __name__ == "__main__": ...@@ -129,6 +158,12 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"-n", "--num-steps", type=int, default=1, help="Number of steps" "-n", "--num-steps", type=int, default=1, help="Number of steps"
) )
parser.add_argument(
"-o",
"--out-of-sample",
action="store_true",
help="Out of sample prediction (sample alpha_t)",
)
parser.add_argument("data_pkl", type=str, help="Covariate data pickle") parser.add_argument("data_pkl", type=str, help="Covariate data pickle")
parser.add_argument( parser.add_argument(
"posterior_samples_pkl", "posterior_samples_pkl",
...@@ -148,4 +183,5 @@ if __name__ == "__main__": ...@@ -148,4 +183,5 @@ if __name__ == "__main__":
args.output_file, args.output_file,
args.initial_step, args.initial_step,
args.num_steps, args.num_steps,
args.out_of_sample,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment