Commit 660672c7 authored by Chris Jewell's avatar Chris Jewell
Browse files

Merge branch 'fix-prediction-times' into 'master'

Fix for prediction time origins

See merge request !39
parents 7029d4d0 34e1623a
......@@ -192,12 +192,12 @@ def CovidUK(covariates, initial_state, initial_step, num_steps):
b_t = alpha_0 + tf.cumsum(alpha_t)
alpha_t_idx = tf.cast(t, tf.int64)
alpha_t_ = tf.where(
alpha_t_idx == initial_step,
alpha_t_idx == 0,
alpha_0,
tf.gather(
b_t,
tf.clip_by_value(
alpha_t_idx - initial_step - 1,
alpha_t_idx - 1,
clip_value_min=0,
clip_value_max=alpha_t.shape[0] - 1,
),
......
......@@ -116,6 +116,7 @@ def run_pipeline(global_config, results_directory, cli_options):
output_file=output_file,
initial_step=-7,
num_steps=28,
out_of_sample=True,
)
@rf.transform(
......@@ -130,6 +131,7 @@ def run_pipeline(global_config, results_directory, cli_options):
output_file=output_file,
initial_step=-14,
num_steps=28,
out_of_sample=True,
)
# Medium-term prediction
......@@ -144,7 +146,8 @@ def run_pipeline(global_config, results_directory, cli_options):
posterior_samples=input_files[1],
output_file=output_file,
initial_step=-1,
num_steps=61,
num_steps=84,
out_of_sample=True,
)
# Summarisation
......
......@@ -12,7 +12,12 @@ from gemlib.util import compute_state
def predicted_incidence(
posterior_samples, init_state, covar_data, init_step, num_steps
posterior_samples,
init_state,
covar_data,
init_step,
num_steps,
out_of_sample=False,
):
"""Runs the simulation forward in time from `init_state` at time `init_time`
for `num_steps`.
......@@ -32,6 +37,18 @@ def predicted_incidence(
posterior_samples["new_init_state"] = posterior_state[..., init_step, :]
del posterior_samples["seir"]
# For out-of-sample prediction, we have to re-simulate the
# alpha_t trajectory given the starting point.
if out_of_sample is True:
alpha_t = posterior_samples["alpha_0"][:, tf.newaxis] + tf.cumsum(
posterior_samples["alpha_t"], axis=-1
)
if init_step > 0:
posterior_samples["alpha_0"] = alpha_t[:, init_step - 1]
# Remove alpha_t from the posterior to make TFP re-simulate it.
del posterior_samples["alpha_t"]
@tf.function
def do_sim():
def sim_fn(args):
......@@ -62,7 +79,14 @@ def read_pkl(filename):
return pkl.load(f)
def predict(data, posterior_samples, output_file, initial_step, num_steps):
def predict(
data,
posterior_samples,
output_file,
initial_step,
num_steps,
out_of_sample=False,
):
covar_data = xarray.open_dataset(data, group="constant_data")
cases = xarray.open_dataset(data, group="observations")
......@@ -88,7 +112,12 @@ def predict(data, posterior_samples, output_file, initial_step, num_steps):
)
estimated_init_state, predicted_events = predicted_incidence(
samples, initial_state, covar_data, initial_step, num_steps
samples,
initial_state,
covar_data,
initial_step,
num_steps,
out_of_sample,
)
prediction = xarray.DataArray(
......@@ -129,6 +158,12 @@ if __name__ == "__main__":
parser.add_argument(
"-n", "--num-steps", type=int, default=1, help="Number of steps"
)
parser.add_argument(
"-o",
"--out-of-sample",
action="store_true",
help="Out of sample prediction (sample alpha_t)",
)
parser.add_argument("data_pkl", type=str, help="Covariate data pickle")
parser.add_argument(
"posterior_samples_pkl",
......@@ -148,4 +183,5 @@ if __name__ == "__main__":
args.output_file,
args.initial_step,
args.num_steps,
args.out_of_sample,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment