Commit 1676ca5d authored by Chris Jewell

Adaptive Hamiltonian Monte Carlo within Gibbs implementation

CHANGES:

1. Kernel builder functions moved to `mcmc_kernel_factory.py`
2. Windowed adaptive MCMC a la Stan implemented in `inference.py` (see the adaptation-schedule sketch below)
3. Prior on beta1 tightened to improve stability.
4. Depends on the `gemlib`@develop branch for tf-nightly and tfp-nightly compatibility
parent 5a862f51
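For orientation, the warmup follows the Stan-style schedule implemented in `run_mcmc` below: one fast dual-averaging window that adapts the HMC step size, a series of doubling slow windows that re-estimate the momentum distribution from a weighted running variance, a final fast window, and then fixed-step sampling bursts. The following is a minimal sketch of that schedule, using the window sizes appearing in the diff; the helper names `_fast_adapt_window`, `_slow_adapt_window` and `_fixed_window` are the functions added in this commit, and the sketch itself is illustrative only, not part of the change.

    # Sketch of the windowed warmup schedule used by run_mcmc (illustrative only).
    fast_window_size = 75    # dual-averaging adaptation of the HMC step size
    slow_window_size = 25    # first momentum-distribution (mass matrix) window
    num_slow_windows = 4     # slow windows double in length each time

    schedule = (
        [("fast", fast_window_size)]
        + [("slow", slow_window_size * 2 ** i) for i in range(num_slow_windows)]
        + [("fast", fast_window_size)]
    )
    # -> [("fast", 75), ("slow", 25), ("slow", 50), ("slow", 100), ("slow", 200), ("fast", 75)]
    # after which sampling proceeds in fixed-step bursts of num_burst_samples draws.

The `num_samples=5000 + 75 + 25 + 50 + 100 + 200 + 75` passed to `Posterior` in `run_mcmc` is exactly this schedule plus the main sampling draws.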
@@ -95,7 +95,7 @@ def CovidUK(covariates, initial_state, initial_step, num_steps):
def beta1():
return tfd.Normal(
loc=tf.constant(0.0, dtype=DTYPE),
- scale=tf.constant(1000.0, dtype=DTYPE),
scale=tf.constant(1.0, dtype=DTYPE),
)
def beta2():
@@ -106,8 +106,8 @@ def CovidUK(covariates, initial_state, initial_step, num_steps):
def sigma():
return tfd.Gamma(
- concentration=tf.constant(2.0, dtype=DTYPE),
- rate=tf.constant(20.0, dtype=DTYPE),
concentration=tf.constant(20.0, dtype=DTYPE),
rate=tf.constant(200.0, dtype=DTYPE),
)
def xi(beta1, sigma):
@@ -10,13 +10,17 @@ import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import unnest
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.experimental.stats import sample_stats
from gemlib.util import compute_state
from gemlib.mcmc import UncalibratedEventTimesUpdate
from gemlib.mcmc import UncalibratedOccultUpdate, TransitionTopology
from gemlib.mcmc import GibbsKernel
from gemlib.mcmc import MultiScanKernel
from gemlib.mcmc import AdaptiveRandomWalkMetropolis
from gemlib.mcmc import Posterior
from covid.tasks.mcmc_kernel_factory import make_hmc_base_kernel
from covid.tasks.mcmc_kernel_factory import make_hmc_fast_adapt_kernel
from covid.tasks.mcmc_kernel_factory import make_hmc_slow_adapt_kernel
from covid.tasks.mcmc_kernel_factory import make_event_multiscan_gibbs_step
import covid.model_spec as model_spec
@@ -25,168 +29,161 @@ tfb = tfp.bijectors
DTYPE = model_spec.DTYPE
- def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True):
- """Constructs and runs the MCMC"""
- if tf.test.gpu_device_name():
- print("Using GPU")
- else:
- print("Using CPU")
- with open(data_file, "rb") as f:
- data = pkl.load(f)
- # We load in cases and impute missing infections first, since this sets the
- # time epoch which we are analysing.
- # Impute censored events, return cases
- print("Data shape:", data["cases"].shape)
- events = model_spec.impute_censored_events(data["cases"].astype(DTYPE))
- # Initial conditions are calculated by calculating the state
- # at the beginning of the inference period
- #
- # Imputed censored events that pre-date the first I-R events
- # in the cases dataset are discarded. They are only used to
- # to set up a sensible initial state.
- state = compute_state(
- initial_state=tf.concat(
- [data["N"][:, tf.newaxis], tf.zeros_like(events[:, 0, :])],
- axis=-1,
- ),
- events=events,
- stoichiometry=model_spec.STOICHIOMETRY,
- )
- start_time = state.shape[1] - data["cases"].shape[1]
- initial_state = state[:, start_time, :]
- events = events[:, start_time:, :]
- ########################################################
- # Construct the MCMC kernels #
- ########################################################
- model = model_spec.CovidUK(
- covariates=data,
- initial_state=initial_state,
- initial_step=0,
- num_steps=events.shape[1],
- )

def get_weighted_running_variance(draws):
prev_mean, prev_var = tf.nn.moments(draws[-draws.shape[0] // 2 :], axes=[0])
num_samples = tf.cast(
draws.shape[0] / 2,
dtype=dtype_util.common_dtype([prev_mean, prev_var], tf.float32),
)
weighted_running_variance = sample_stats.RunningVariance.from_stats(
num_samples=num_samples, mean=prev_mean, variance=prev_var
)
return weighted_running_variance
- def joint_log_prob(block0, block1, events):
- return model.log_prob(
- dict(
- beta2=block0[0],
- gamma0=block0[1],
- gamma1=block0[2],
- sigma=block0[3],
- beta1=block1[0],
- xi=block1[1:],
- seir=events,
- )
- )

- # Build Metropolis within Gibbs sampler
- def make_blk0_kernel(shape, name):
- def fn(target_log_prob_fn, _):
- return tfp.mcmc.TransformedTransitionKernel(
- inner_kernel=AdaptiveRandomWalkMetropolis(
- target_log_prob_fn=target_log_prob_fn,
- initial_covariance=np.eye(shape[0], dtype=model_spec.DTYPE) * 1e-1,
- covariance_burnin=200,
- ),
- bijector=tfp.bijectors.Blockwise(
- bijectors=[
- tfp.bijectors.Exp(),
- tfp.bijectors.Identity(),
- tfp.bijectors.Exp(),
- # tfp.bijectors.Identity(),
- ],
- block_sizes=[1, 2, 1],  # , 5],
- ),
- name=name,
- )
- return fn

- def make_blk1_kernel(shape, name):
- def fn(target_log_prob_fn, _):
- return AdaptiveRandomWalkMetropolis(
- target_log_prob_fn=target_log_prob_fn,
- initial_covariance=np.eye(shape[0], dtype=model_spec.DTYPE) * 1e-1,
- covariance_burnin=200,
- name=name,
- )
- return fn

- def make_partially_observed_step(
- target_event_id, prev_event_id=None, next_event_id=None, name=None
- ):
- def fn(target_log_prob_fn, _):
- return tfp.mcmc.MetropolisHastings(
- inner_kernel=UncalibratedEventTimesUpdate(
- target_log_prob_fn=target_log_prob_fn,
- target_event_id=target_event_id,
- prev_event_id=prev_event_id,
- next_event_id=next_event_id,
- initial_state=initial_state,
- dmax=config["dmax"],
- mmax=config["m"],
- nmax=config["nmax"],
- ),
- name=name,
- )
- return fn

- def make_occults_step(prev_event_id, target_event_id, next_event_id, name):
- def fn(target_log_prob_fn, _):
- return tfp.mcmc.MetropolisHastings(
- inner_kernel=UncalibratedOccultUpdate(
- target_log_prob_fn=target_log_prob_fn,
- topology=TransitionTopology(
- prev_event_id, target_event_id, next_event_id
- ),
- cumulative_event_offset=initial_state,
- nmax=config["occult_nmax"],
- t_range=(events.shape[1] - 21, events.shape[1]),
- name=name,
- ),
- name=name,
- )
- return fn

@tf.function
def _fast_adapt_window(
num_draws,
joint_log_prob_fn,
initial_position,
hmc_kernel_kwargs,
dual_averaging_kwargs,
event_kernel_kwargs,
trace_fn=None,
seed=None,
):
kernel_list = [
(
0,
make_hmc_fast_adapt_kernel(
hmc_kernel_kwargs=hmc_kernel_kwargs,
dual_averaging_kwargs=dual_averaging_kwargs,
),
),
(1, make_event_multiscan_gibbs_step(**event_kernel_kwargs)),
]
kernel = GibbsKernel(
target_log_prob_fn=joint_log_prob_fn,
kernel_list=kernel_list,
name="fast_adapt",
)
draws, trace, fkr = tfp.mcmc.sample_chain(
num_draws,
initial_position,
kernel=kernel,
return_final_kernel_results=True,
trace_fn=trace_fn,
seed=seed,
)
weighted_running_variance = get_weighted_running_variance(draws[0])
step_size = unnest.get_outermost(fkr.inner_results[0], "step_size")
return draws, trace, step_size, weighted_running_variance

@tf.function
def _slow_adapt_window(
num_draws,
joint_log_prob_fn,
initial_position,
initial_running_variance,
hmc_kernel_kwargs,
dual_averaging_kwargs,
event_kernel_kwargs,
trace_fn=None,
seed=None,
):
kernel_list = [
(
0,
make_hmc_slow_adapt_kernel(
initial_running_variance,
hmc_kernel_kwargs,
dual_averaging_kwargs,
),
),
(1, make_event_multiscan_gibbs_step(**event_kernel_kwargs)),
]
kernel = GibbsKernel(
target_log_prob_fn=joint_log_prob_fn,
kernel_list=kernel_list,
name="slow_adapt",
)
draws, trace, fkr = tfp.mcmc.sample_chain(
num_draws,
current_state=initial_position,
kernel=kernel,
return_final_kernel_results=True,
trace_fn=trace_fn,
)
step_size = unnest.get_outermost(fkr.inner_results[0], "step_size")
momentum_distribution = unnest.get_outermost(
fkr.inner_results[0], "momentum_distribution"
)
weighted_running_variance = get_weighted_running_variance(draws[0])
return (
draws,
trace,
step_size,
weighted_running_variance,
momentum_distribution,
)
- def make_event_multiscan_kernel(target_log_prob_fn, _):
- return MultiScanKernel(
- config["num_event_time_updates"],
- GibbsKernel(
- target_log_prob_fn=target_log_prob_fn,
- kernel_list=[
- (0, make_partially_observed_step(0, None, 1, "se_events")),
- (0, make_partially_observed_step(1, 0, 2, "ei_events")),
- (0, make_occults_step(None, 0, 1, "se_occults")),
- (0, make_occults_step(0, 1, 2, "ei_occults")),
- ],
- name="gibbs1",
- ),
- )

@tf.function  # (experimental_compile=True)
def _fixed_window(
num_draws,
joint_log_prob_fn,
initial_position,
hmc_kernel_kwargs,
event_kernel_kwargs,
trace_fn=None,
seed=None,
):
"""Fixed step size HMC.
:returns: (draws, trace, final_kernel_results)
"""
kernel_list = [
(0, make_hmc_base_kernel(**hmc_kernel_kwargs)),
(1, make_event_multiscan_gibbs_step(**event_kernel_kwargs)),
]
kernel = GibbsKernel(
target_log_prob_fn=joint_log_prob_fn,
kernel_list=kernel_list,
name="fixed",
)
return tfp.mcmc.sample_chain(
num_draws,
current_state=initial_position,
kernel=kernel,
return_final_kernel_results=True,
trace_fn=trace_fn,
seed=seed,
)
# MCMC tracing functions
- res0 = results.inner_results
- results_dict["block0"] = {
- "is_accepted": res0[0].inner_results.is_accepted,
- "target_log_prob": res0[0].inner_results.accepted_results.target_log_prob,
- }
- results_dict["block1"] = {
- "is_accepted": res0[1].is_accepted,
- "target_log_prob": res0[1].accepted_results.target_log_prob,
- }

def trace_results_fn(_, results):
"""Packs results into a dictionary"""
results_dict = {}
root_results = results.inner_results
results_dict["hmc"] = {
"is_accepted": unnest.get_innermost(root_results[0], "is_accepted"),
"target_log_prob": unnest.get_innermost(
root_results[0], "target_log_prob"
),
}
def get_move_results(results):
@@ -203,7 +200,7 @@ def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True):
),
}
- res1 = res0[2].inner_results
res1 = root_results[1].inner_results
results_dict["move/S->E"] = get_move_results(res1[0])
results_dict["move/E->I"] = get_move_results(res1[1])
results_dict["occult/S->E"] = get_move_results(res1[2])
@@ -211,134 +208,267 @@ def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True):
return results_dict
- # Build MCMC algorithm here. This will be run in bursts for memory economy
- @tf.function(autograph=use_autograph, experimental_compile=use_xla)
- def sample(n_samples, init_state, thin=0, previous_results=None):
- with tf.name_scope("main_mcmc_sample_loop"):
- init_state = init_state.copy()
- gibbs_schema = GibbsKernel(
- target_log_prob_fn=joint_log_prob,
- kernel_list=[
- (0, make_blk0_kernel(init_state[0].shape, "block0")),
- (1, make_blk1_kernel(init_state[1].shape, "block1")),
- (2, make_event_multiscan_kernel),
- ],
- name="gibbs0",
- )
- samples, results, final_results = tfp.mcmc.sample_chain(
- n_samples,
- init_state,
- kernel=gibbs_schema,
- num_steps_between_results=thin,
- previous_kernel_results=previous_results,
- return_final_kernel_results=True,
- trace_fn=trace_results_fn,
- )
- return samples, results, final_results

def draws_to_dict(draws):
return {
"beta2": draws[0][:, 0],
"gamma0": draws[0][:, 1],
"gamma1": draws[0][:, 2],
"sigma": draws[0][:, 3],
"beta3": tf.zeros([1, 5], dtype=DTYPE),
"beta1": draws[0][:, 4],
"xi": draws[0][:, 5:],
"events": draws[1],
}

def run_mcmc(
joint_log_prob_fn, current_state, initial_conditions, config, output_file
):
num_draws = config["num_bursts"] * config["num_burst_samples"]
fast_window_size = 75
slow_window_size = 25
num_slow_windows = 4
hmc_kernel_kwargs = {
"step_size": 0.00001,
"num_leapfrog_steps": 4,
"momentum_distribution": None,
}
dual_averaging_kwargs = {
"num_adaptation_steps": fast_window_size,
"target_accept_prob": 0.75,
}
event_kernel_kwargs = {
"initial_state": initial_conditions,
"t_range": [
current_state[1].shape[-2] - 21,
current_state[1].shape[-2],
],
"config": config,
}
# Set up posterior
draws, trace, _ = _fixed_window(
num_draws=1,
joint_log_prob_fn=joint_log_prob_fn,
initial_position=current_state,
hmc_kernel_kwargs=hmc_kernel_kwargs,
event_kernel_kwargs=event_kernel_kwargs,
trace_fn=trace_results_fn,
)
posterior = Posterior(
output_file,
sample_dict=draws_to_dict(draws),
results_dict=trace,
num_samples=5000 + 75 + 25 + 50 + 100 + 200 + 75,
)
offset = 0
# Fast adaptation sampling
print(f"Fast window {fast_window_size}")
draws, trace, step_size, running_variance = _fast_adapt_window(
num_draws=fast_window_size,
joint_log_prob_fn=joint_log_prob_fn,
initial_position=current_state,
hmc_kernel_kwargs=hmc_kernel_kwargs,
dual_averaging_kwargs=dual_averaging_kwargs,
event_kernel_kwargs=event_kernel_kwargs,
trace_fn=trace_results_fn,
)
posterior.write_samples(
draws_to_dict(draws),
first_dim_offset=offset,
)
posterior.write_results(trace, first_dim_offset=offset)
offset += fast_window_size
current_state = [s[-1] for s in draws]
# Slow adaptation sampling
hmc_kernel_kwargs["step_size"] = step_size
for slow_window_idx in range(num_slow_windows):
window_num_draws = slow_window_size * (2 ** slow_window_idx)
print(f"Slow window {window_num_draws}")
(
draws,
trace,
step_size,
running_variance,
momentum_distribution,
) = _slow_adapt_window(
num_draws=window_num_draws,
joint_log_prob_fn=joint_log_prob_fn,
initial_position=current_state,
initial_running_variance=running_variance,
hmc_kernel_kwargs=hmc_kernel_kwargs,
dual_averaging_kwargs=dual_averaging_kwargs,
event_kernel_kwargs=event_kernel_kwargs,
trace_fn=trace_results_fn,
)
hmc_kernel_kwargs["step_size"] = step_size
hmc_kernel_kwargs["momentum_distribution"] = momentum_distribution
current_state = [s[-1] for s in draws]
posterior.write_samples(
draws_to_dict(draws),
first_dim_offset=offset,
)
posterior.write_results(trace, first_dim_offset=offset)
offset += window_num_draws
# Fast adaptation sampling
print(f"Fast window {fast_window_size}")
draws, trace, step_size, weighted_running_variance = _fast_adapt_window(
num_draws=fast_window_size,
joint_log_prob_fn=joint_log_prob_fn,
initial_position=current_state,
hmc_kernel_kwargs=hmc_kernel_kwargs,
dual_averaging_kwargs=dual_averaging_kwargs,
event_kernel_kwargs=event_kernel_kwargs,
trace_fn=trace_results_fn,
)
current_state = [s[-1] for s in draws]
posterior.write_samples(
draws_to_dict(draws),
first_dim_offset=offset,
)
posterior.write_results(trace, first_dim_offset=offset)
offset += fast_window_size
# Fixed window sampling
print("Sampling...")
hmc_kernel_kwargs["step_size"] = step_size
for i in tqdm.tqdm(range(config["num_bursts"])):
draws, trace, _ = _fixed_window(
num_draws=config["num_burst_samples"],
joint_log_prob_fn=joint_log_prob_fn,
initial_position=current_state,
hmc_kernel_kwargs=hmc_kernel_kwargs,
event_kernel_kwargs=event_kernel_kwargs,
trace_fn=trace_results_fn,
)
current_state = [state_part[-1] for state_part in draws]
posterior.write_samples(
draws_to_dict(draws),
first_dim_offset=offset,
)
posterior.write_results(
trace,
first_dim_offset=offset,
)
offset += config["num_burst_samples"]
return posterior
def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True):
"""Constructs and runs the MCMC"""
if tf.test.gpu_device_name():
print("Using GPU")
else:
print("Using CPU")
with open(data_file, "rb") as f:
data = pkl.load(f)
# We load in cases and impute missing infections first, since this sets the
# time epoch which we are analysing.
# Impute censored events, return cases
events = model_spec.impute_censored_events(data["cases"].astype(DTYPE))
# Initial conditions are calculated by computing the state
# at the beginning of the inference period
#
# Imputed censored events that pre-date the first I-R events
# in the cases dataset are discarded.  They are only used
# to set up a sensible initial state.
state = compute_state(
initial_state=tf.concat(
[data["N"][:, tf.newaxis], tf.zeros_like(events[:, 0, :])],
axis=-1,
),
events=events,
stoichiometry=model_spec.STOICHIOMETRY,
)
start_time = state.shape[1] - data["cases"].shape[1]
initial_state = state[:, start_time, :]
events = events[:, start_time:, :]
########################################################
# Construct the MCMC kernels #
########################################################
model = model_spec.CovidUK(
covariates=data,
initial_state=initial_state,
initial_step=0,
num_steps=events.shape[1],
)
def joint_log_prob(unconstrained_params, events):
bij = tfb.Invert( # Forward transform unconstrains params
tfb.Blockwise(
[
tfb.Softplus(low=dtype_util.eps(DTYPE)),
tfb.Identity(),
tfb.Softplus(low=dtype_util.eps(DTYPE)),
tfb.Identity(),
],
block_sizes=[1, 2, 1, unconstrained_params.shape[-1] - 4],
)
)
params = bij.inverse(unconstrained_params)
return (
model.log_prob(
dict(
beta2=params[0],
gamma0=params[1],
gamma1=params[2],
sigma=params[3],
beta1=params[4],
xi=params[5:],
seir=events,
)
)
+ bij.inverse_log_det_jacobian(unconstrained_params, event_ndims=1)
)
# MCMC tracing functions
###############################
# Construct bursted MCMC loop #
###############################
NUM_BURSTS = int(config["num_bursts"])
NUM_BURST_SAMPLES = int(config["num_burst_samples"])
NUM_SAVED_SAMPLES = NUM_BURST_SAMPLES * NUM_BURSTS
# RNG stuff
tf.random.set_seed(2)
current_state = [
np.array(
[0.6, 0.0, 0.0, 0.1], dtype=DTYPE
), # , 0.0, 0.0, 0.0, 0.0, 0.0], dtype=DT