Commit b618a841 authored by Chris Jewell

Random scan algorithm for event times

Changes:

1. Fixed bug in inner tf.while_loop in sample()
2. Made the number of event time updates per sweep a config file parameter.
parent 47c402dd
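For orientation before the diff: the random scan applies the SE and EI event-time kernels num_event_updates times per MCMC sweep through an inner tf.while_loop. Below is a minimal sketch of that looping pattern only; it is not the repository code, and propose_update / run_event_updates are hypothetical stand-ins for the event-time kernels (se_func / ei_func in the diff).

```python
import tensorflow as tf

def propose_update(state):
    # Hypothetical stand-in for one event-time kernel step; the real code
    # runs Metropolis-Hastings kernels (se_func / ei_func) over event times.
    return state + tf.random.normal(tf.shape(state), stddev=0.1)

def run_event_updates(state, num_event_updates):
    # Random-scan sweep: apply the update kernel a fixed number of times.
    def cond(j, state):
        return j < num_event_updates

    def body(j, state):
        return j + 1, propose_update(state)

    _, state = tf.while_loop(
        cond, body, loop_vars=[tf.constant(0, tf.int32), state]
    )
    return state

state = run_event_updates(tf.zeros([10]), tf.constant(4, tf.int32))
```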
@@ -43,15 +43,15 @@ parser.add_option(
dest="config",
default="ode_config.yaml",
help="configuration file",
)
)
options, args = parser.parse_args()
print("Loading config file:", options.config)
with open(options.config, "r") as f:
config = yaml.load(f)
print("Config:",config)
print("Config:", config)
param = sanitise_parameter(config["parameter"])
param = {k: tf.constant(v, dtype=DTYPE) for k, v in param.items()}
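A side note on the config read above: newer PyYAML releases warn about, or require a Loader argument for, a bare yaml.load(f). A small sketch using yaml.safe_load instead, assuming a config file shaped like the YAML further down (the keys are the ones this commit references):

```python
import yaml

# "ode_config.yaml" is the parser default above; the actual file may differ.
with open("ode_config.yaml", "r") as f:
    config = yaml.safe_load(f)

NUM_BURSTS = config["mcmc"]["num_bursts"]
NUM_BURST_SAMPLES = config["mcmc"]["num_burst_samples"]
NUM_EVENT_TIME_UPDATES = config["mcmc"]["num_event_time_updates"]
```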
@@ -169,7 +169,7 @@ def forward_results(prev_results, next_results):
@tf.function(autograph=False, experimental_compile=True)
-def sample(n_samples, init_state, par_scale):
+def sample(n_samples, init_state, par_scale, num_event_updates):
with tf.name_scope("main_mcmc_sample_loop"):
init_state = init_state.copy()
par_func = make_parameter_kernel(par_scale, 0.95)
@@ -203,16 +203,31 @@ def sample(n_samples, init_state, par_scale):
)
# States
-def state_logp(event_state):
-    state[1] = event_state
-    return logp(*state)
-state[1], results[1] = se_func(state_logp).one_step(
-    state[1], forward_results(results[0], results[1])
-)
-state[1], results[2] = ei_func(state_logp).one_step(
-    state[1], forward_results(results[1], results[2])
-)
+results[2] = forward_results(results[0], results[2])
+def infec_body(j, state, results):
+    def state_logp(event_state):
+        state[1] = event_state
+        return logp(*state)
+    state[1], results[1] = se_func(state_logp).one_step(
+        state[1], forward_results(results[2], results[1])
+    )
+    state[1], results[2] = ei_func(state_logp).one_step(
+        state[1], forward_results(results[1], results[2])
+    )
+    j += 1
+    return j, state, results
+def infec_cond(j, state, results):
+    return j < num_event_updates
+_, state, results = tf.while_loop(
+    infec_cond,
+    infec_body,
+    loop_vars=[tf.constant(0, tf.int32), state, results],
+)
sample_accum = [sample_accum[k].write(i, s) for k, s in enumerate(state)]
results_accum = [
results_accum[k].write(i, trace_results_fn(r))
@@ -239,6 +254,7 @@ def sample(n_samples, init_state, par_scale):
# MCMC Control
NUM_BURSTS = config["mcmc"]["num_bursts"]
NUM_BURST_SAMPLES = config["mcmc"]["num_burst_samples"]
+NUM_EVENT_TIME_UPDATES = config["mcmc"]["num_event_time_updates"]
# RNG stuff
tf.random.set_seed(2)
@@ -284,7 +300,10 @@ par_scale = tf.linalg.diag(
# with tf.profiler.experimental.Profile("/tmp/tf_logdir"):
for i in tqdm.tqdm(range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES):
samples, results = sample(
-    NUM_BURST_SAMPLES, init_state=current_state, par_scale=par_scale
+    NUM_BURST_SAMPLES,
+    init_state=current_state,
+    par_scale=par_scale,
+    num_event_updates=tf.constant(NUM_EVENT_TIME_UPDATES, tf.int32),
)
current_state = [s[-1] for s in samples]
s = slice(i * NUM_BURST_SAMPLES, i * NUM_BURST_SAMPLES + NUM_BURST_SAMPLES)
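One plausible reason the update count is passed as tf.constant(NUM_EVENT_TIME_UPDATES, tf.int32) rather than a plain Python int: tf.function bakes Python scalars into the traced (and here XLA-compiled) graph, so every distinct value would trigger a retrace, whereas a tensor argument of fixed dtype reuses a single graph. An illustrative sketch (count_to is a hypothetical function, not part of the repository):

```python
import tensorflow as tf

@tf.function
def count_to(n):
    # Count from 0 to n with a tf.while_loop; the loop bound n may be a
    # Python int (baked into the graph) or a tensor (kept symbolic).
    i = tf.constant(0, tf.int32)
    return tf.while_loop(lambda i: i < n, lambda i: (i + 1,), [i])

count_to(tf.constant(5, tf.int32))    # traced once
count_to(tf.constant(50, tf.int32))   # reuses the same traced graph
count_to(5)                           # Python int: a new trace per value
```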
@@ -296,7 +315,7 @@ for i in tqdm.tqdm(range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES):
print(current_state[0].numpy())
print(cov)
if np.all(np.isfinite(cov)):
-    par_scale = 2.38 ** 2 * cov / 2.0
+    par_scale = 2.0 ** 2 * cov / 2.0
se_samples[s, ...] = samples[1].numpy()
par_results[s, ...] = results[0].numpy()
......
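The par_scale update changed in the hunk above is the usual adaptive random-walk scaling: an empirical covariance estimated from recent samples is rescaled by factor**2 / d before being used as the proposal covariance, and this commit lowers the classic Roberts-Rosenthal factor of 2.38 to 2.0. A hedged sketch under the assumption that the divisor 2.0 is the parameter dimension d; adapt_proposal_scale and theta are hypothetical names:

```python
import numpy as np

def adapt_proposal_scale(theta, factor=2.0):
    # theta: array of draws with shape [num_samples, d]
    d = theta.shape[1]
    cov = np.cov(theta, rowvar=False)   # d x d empirical covariance
    return factor ** 2 * cov / d        # next proposal covariance

theta = np.random.randn(200, 2)         # stand-in for a burst of samples
par_scale = adapt_proposal_scale(theta)
```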
@@ -34,12 +34,13 @@ settings:
- 2020-08-01
mcmc:
-dmax: 25
-nmax: 15
-m: 3
+dmax: 32
+nmax: 160
+m: 1
+num_event_time_updates: 1
num_bursts: 100
num_burst_samples: 100
output:
-posterior: posterior_medium2.h5
+posterior: posterior_medium_upd1.h5
simulation: covid_ode.hd5