Commit 6137beb8 authored by Chris Jewell
Browse files

Implemented Posterior output

CHANGES:

1. Re-implemented trace_fn
2. Implemented new Posterior class in `gemlib`
parent edfb5b84
......@@ -10,16 +10,13 @@ import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.experimental import unnest
from gemlib.util import compute_state
from gemlib.mcmc import UncalibratedEventTimesUpdate
from gemlib.mcmc import UncalibratedOccultUpdate, TransitionTopology
from gemlib.mcmc import GibbsKernel
from gemlib.mcmc.gibbs_kernel import GibbsKernelResults
from gemlib.mcmc.gibbs_kernel import flatten_results
from gemlib.mcmc import MultiScanKernel
from gemlib.mcmc import AdaptiveRandomWalkMetropolis
from gemlib.mcmc import Posterior
from covid.data import read_phe_cases
from covid.cli_arg_parse import cli_args
......@@ -210,32 +207,46 @@ if __name__ == "__main__":
# MCMC tracing functions
def trace_results_fn(_, results):
"""Returns log_prob, accepted, q_ratio"""
def f(result):
    """Flatten one kernel result into a single tensor.

    Reads the innermost ``proposed_results`` and ``is_accepted`` fields of
    ``result`` and concatenates, along axis 0, the target log-probability,
    the acceptance flag (cast to the log-probability dtype) and the log
    acceptance correction. When the proposed results expose an ``extra``
    attribute, it is cast to the same dtype and appended after those three.
    """
    proposal = unnest.get_innermost(result, "proposed_results")
    target = proposal.target_log_prob
    dtype = target.dtype
    accept_flag = tf.cast(unnest.get_innermost(result, "is_accepted"), dtype)
    correction = proposal.log_acceptance_correction

    # The three scalar-like entries are wrapped in lists so they can be
    # concatenated with the (optional) `extra` tensor along axis 0.
    pieces = [[target], [accept_flag], [correction]]
    if hasattr(proposal, "extra"):
        pieces.append(tf.cast(proposal.extra, dtype))
    return tf.concat(pieces, axis=0)
"""Packs results into a dictionary"""
results_dict = {}
res0 = results.inner_results
results_dict["theta"] = {
"is_accepted": res0[0].inner_results.is_accepted,
"target_log_prob": res0[
0
].inner_results.accepted_results.target_log_prob,
}
results_dict["xi"] = {
"is_accepted": res0[1].is_accepted,
"target_log_prob": res0[1].accepted_results.target_log_prob,
}
def get_move_results(results):
    """Pack the results of one event-time/occult update into a dictionary.

    Returns a dict with the acceptance flag, the accepted target
    log-probability, and a stacked tensor of the ``m``, ``t``, ``delta_t``
    and ``x_star`` fields of ``results.accepted_results`` (in that order).
    """
    accepted = results.accepted_results
    # NOTE(review): the key is called "proposed_delta" but the values are
    # taken from accepted_results -- confirm this is the intended source.
    delta_fields = [
        accepted.m,
        accepted.t,
        accepted.delta_t,
        accepted.x_star,
    ]
    return {
        "is_accepted": results.is_accepted,
        "target_log_prob": accepted.target_log_prob,
        "proposed_delta": tf.stack(delta_fields),
    }
def recurse(f, results):
    """Apply ``f`` to every non-Gibbs leaf of a nested kernel-results tree.

    A ``GibbsKernelResults`` node is expanded into a list holding the
    recursive result for each of its ``inner_results``; anything else is
    treated as a leaf and passed straight to ``f``.
    """
    if not isinstance(results, GibbsKernelResults):
        return f(results)
    return [recurse(f, inner) for inner in results.inner_results]
res1 = res0[2].inner_results
results_dict["move/S->E"] = get_move_results(res1[0])
results_dict["move/E->I"] = get_move_results(res1[1])
results_dict["occult/S->E"] = get_move_results(res1[2])
results_dict["occult/E->I"] = get_move_results(res1[3])
return recurse(f, results)
return results_dict
# Build MCMC algorithm here. This will be run in bursts for memory economy
@tf.function # (autograph=False, experimental_compile=True)
def sample(n_samples, init_state, previous_results=None):
@tf.function(autograph=False, experimental_compile=True)
def sample(n_samples, init_state, thin=0, previous_results=None):
with tf.name_scope("main_mcmc_sample_loop"):
init_state = init_state.copy()
......@@ -249,10 +260,12 @@ if __name__ == "__main__":
],
name="gibbs0",
)
samples, results, final_results = tfp.mcmc.sample_chain(
n_samples,
init_state,
kernel=gibbs_schema,
num_steps_between_results=thin,
previous_kernel_results=previous_results,
return_final_kernel_results=True,
trace_fn=trace_results_fn,
......@@ -280,69 +293,19 @@ if __name__ == "__main__":
events,
]
# Output Files
posterior = h5py.File(
# Output file
samples, results, _ = sample(1, current_state)
posterior = Posterior(
os.path.join(
os.path.expandvars(config["output"]["results_dir"]),
config["output"]["posterior"],
),
"w",
rdcc_nbytes=1024 ** 2 * 400,
rdcc_nslots=100000,
libver="latest",
)
event_size = [NUM_SAVED_SAMPLES] + list(current_state[2].shape)
posterior.create_dataset("initial_state", data=initial_state)
# Ideally we insert the inference period into the posterior file
# as this allows us to post-attribute it to the data. Maybe better
# to simply save the data into it as well.
posterior.create_dataset("config", data=yaml.dump(config))
theta_samples = posterior.create_dataset(
"samples/theta",
[NUM_SAVED_SAMPLES, current_state[0].shape[0]],
dtype=np.float64,
)
xi_samples = posterior.create_dataset(
"samples/xi",
[NUM_SAVED_SAMPLES, current_state[1].shape[0]],
dtype=np.float64,
)
event_samples = posterior.create_dataset(
"samples/events",
event_size,
dtype=DTYPE,
chunks=(32, 32, 32, 1),
compression="szip",
compression_opts=("nn", 16),
{"theta": samples[0], "xi": samples[1]},
results,
NUM_SAVED_SAMPLES,
)
output_results = [
posterior.create_dataset(
"results/theta", (NUM_SAVED_SAMPLES, 3), dtype=DTYPE,
),
posterior.create_dataset(
"results/xi", (NUM_SAVED_SAMPLES, 3), dtype=DTYPE,
),
posterior.create_dataset(
"results/move/S->E",
(NUM_SAVED_SAMPLES, 3 + num_metapop),
dtype=DTYPE,
),
posterior.create_dataset(
"results/move/E->I",
(NUM_SAVED_SAMPLES, 3 + num_metapop),
dtype=DTYPE,
),
posterior.create_dataset(
"results/occult/S->E", (NUM_SAVED_SAMPLES, 6), dtype=DTYPE
),
posterior.create_dataset(
"results/occult/E->I", (NUM_SAVED_SAMPLES, 6), dtype=DTYPE
),
]
posterior.swmr_mode = True
posterior._file.create_dataset("initial_state", data=initial_state)
posterior._file.create_dataset("config", data=yaml.dump(config))
print("Initial logpi:", logp(*current_state))
......@@ -354,64 +317,73 @@ if __name__ == "__main__":
samples, results, final_results = sample(
NUM_BURST_SAMPLES,
init_state=current_state,
thin=config["mcmc"]["thin"] - 1,
previous_results=final_results,
)
current_state = [s[-1] for s in samples]
s = slice(
i * THIN_BURST_SAMPLES, i * THIN_BURST_SAMPLES + THIN_BURST_SAMPLES
)
idx = tf.constant(range(0, NUM_BURST_SAMPLES, config["mcmc"]["thin"]))
theta_samples[s, ...] = tf.gather(samples[0], idx)
xi_samples[s, ...] = tf.gather(samples[1], idx)
# cov = np.cov(
# np.log(theta_samples[: (i * NUM_BURST_SAMPLES + NUM_BURST_SAMPLES), ...]),
# rowvar=False,
# )
print(current_state[0].numpy(), flush=True)
# print(cov, flush=True)
# if (i * NUM_BURST_SAMPLES) > 1000 and np.all(np.isfinite(cov)):
# theta_scale = 2.38 ** 2 * cov / 2.0
start = perf_counter()
event_samples[s, ...] = tf.gather(samples[2], idx)
posterior.write_samples(
{"theta": samples[0], "xi": samples[1]},
first_dim_offset=i * NUM_BURST_SAMPLES,
)
posterior.write_results(results, first_dim_offset=i * NUM_BURST_SAMPLES)
end = perf_counter()
flat_results = flatten_results(results)
for i, ro in enumerate(output_results):
ro[s, ...] = tf.gather(flat_results[i], idx)
posterior.flush()
print("Storage time:", end - start, "seconds")
print(
"Acceptance theta:",
tf.reduce_mean(tf.cast(flat_results[0][:, 1], tf.float32)),
tf.reduce_mean(
tf.cast(results["theta"]["is_accepted"], tf.float32)
),
)
# Acceptance rate for the xi update: the dictionary-based value averages the
# is_accepted flags stored under the "xi" key of the traced results, so the
# printed figure matches its "Acceptance xi:" label.
print(
    "Acceptance xi:",
    tf.reduce_mean(tf.cast(flat_results[1][:, 1], tf.float32)),
    tf.reduce_mean(
        tf.cast(results["xi"]["is_accepted"], tf.float32),
    ),
)
print(
"Acceptance move S->E:",
tf.reduce_mean(tf.cast(flat_results[2][:, 1], tf.float32)),
tf.reduce_mean(
tf.cast(results["move/S->E"]["is_accepted"], tf.float32)
),
)
print(
"Acceptance move E->I:",
tf.reduce_mean(tf.cast(flat_results[3][:, 1], tf.float32)),
tf.reduce_mean(
tf.cast(results["move/E->I"]["is_accepted"], tf.float32)
),
)
print(
"Acceptance occult S->E:",
tf.reduce_mean(tf.cast(flat_results[4][:, 1], tf.float32)),
tf.reduce_mean(
tf.cast(results["occult/S->E"]["is_accepted"], tf.float32)
),
)
print(
"Acceptance occult E->I:",
tf.reduce_mean(tf.cast(flat_results[5][:, 1], tf.float32)),
tf.reduce_mean(
tf.cast(results["occult/E->I"]["is_accepted"], tf.float32)
),
)
print(f"Acceptance theta: {output_results[0][:, 1].mean()}")
print(f"Acceptance xi: {output_results[1][:, 1].mean()}")
print(f"Acceptance move S->E: {output_results[2][:, 1].mean()}")
print(f"Acceptance move E->I: {output_results[3][:, 1].mean()}")
print(f"Acceptance occult S->E: {output_results[4][:, 1].mean()}")
print(f"Acceptance occult E->I: {output_results[5][:, 1].mean()}")
print(
f"Acceptance theta: {posterior['results/theta/is_accepted'][:].mean()}"
)
print(f"Acceptance xi: {posterior['results/xi/is_accepted'][:].mean()}")
print(
f"Acceptance move S->E: {posterior['results/move/S->E/is_accepted'][:].mean()}"
)
print(
f"Acceptance move E->I: {posterior['results/move/E->I/is_accepted'][:].mean()}"
)
print(
f"Acceptance occult S->E: {posterior['results/occult/S->E/is_accepted'][:].mean()}"
)
print(
f"Acceptance occult E->I: {posterior['results/occult/E->I/is_accepted'][:].mean()}"
)
posterior.close()
del posterior
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment