Commit 3ef67e70 authored by Chris Jewell's avatar Chris Jewell
Browse files

Make NGM and prediction (almost) model agnostic

parent 6274bc57
......@@ -22,36 +22,23 @@ def calc_posterior_ngm(samples, covar_data):
"""
def r_fn(args):
beta1_, beta2_, beta3_, sigma_, xi_, gamma0_, events_ = args
par = tf.nest.pack_sequence_as(samples, args)
t = events_.shape[-2] - 1
state = compute_state(
samples["init_state"], events_, model_spec.STOICHIOMETRY
samples["init_state"], par['seir'], model_spec.STOICHIOMETRY
)
state = tf.gather(state, t, axis=-2) # State on final inference day
par = dict(
beta1=beta1_,
beta2=beta2_,
beta3=beta3_,
sigma=sigma_,
gamma0=gamma0_,
xi=xi_,
)
del par['seir']
ngm_fn = model_spec.next_generation_matrix_fn(covar_data, par)
ngm = ngm_fn(t, state)
return ngm
return tf.vectorized_map(
r_fn,
elems=(
samples["beta1"],
samples["beta2"],
samples["beta3"],
samples["sigma"],
samples["xi"],
samples["gamma0"],
samples["seir"],
),
elems=tf.nest.flatten(samples),
)
......
......@@ -9,7 +9,7 @@ from covid import model_spec
from covid.util import copy_nc_attrs
from gemlib.util import compute_state
@tf.function
def predicted_incidence(posterior_samples, covar_data, init_step, num_steps):
"""Runs the simulation forward in time from `init_state` at time `init_time`
for `num_steps`.
......@@ -21,18 +21,19 @@ def predicted_incidence(posterior_samples, covar_data, init_step, num_steps):
transitions
"""
@tf.function
posterior_state = compute_state(
posterior_samples["init_state"],
posterior_samples["seir"],
model_spec.STOICHIOMETRY,
)
posterior_samples['init_state_'] = posterior_state[..., init_step, :]
del posterior_samples['seir']
def sim_fn(args):
beta1_, beta2_, sigma_, xi_, gamma0_, gamma1_, init_ = args
par = dict(
beta1=beta1_,
beta2=beta2_,
sigma=sigma_,
xi=xi_,
gamma0=gamma0_,
gamma1=gamma1_,
)
par = tf.nest.pack_sequence_as(posterior_samples, args)
init_ = par['init_state_']
del par['init_state_']
model = model_spec.CovidUK(
covar_data,
initial_state=init_,
......@@ -42,24 +43,9 @@ def predicted_incidence(posterior_samples, covar_data, init_step, num_steps):
sim = model.sample(**par)
return sim["seir"]
posterior_state = compute_state(
posterior_samples["init_state"],
posterior_samples["seir"],
model_spec.STOICHIOMETRY,
)
init_state = posterior_state[..., init_step, :]
events = tf.map_fn(
sim_fn,
elems=(
posterior_samples["beta1"],
posterior_samples["beta2"],
posterior_samples["sigma"],
posterior_samples["xi"],
posterior_samples["gamma0"],
posterior_samples["gamma1"],
init_state,
),
elems=tf.nest.flatten(posterior_samples),
fn_output_signature=(tf.float64),
)
return init_state, events
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment