Commit 6c6616e9 authored by Chris Jewell's avatar Chris Jewell
Browse files

Integrate step size in final adaptation window.

parent 119e0dbd
......@@ -18,6 +18,7 @@ from tensorflow_probability.python.experimental.stats import sample_stats
from gemlib.util import compute_state
from gemlib.mcmc import Posterior
from gemlib.mcmc import GibbsKernel
from gemlib.distributions import BrownianMotion
from covid.tasks.mcmc_kernel_factory import make_hmc_base_kernel
from covid.tasks.mcmc_kernel_factory import make_hmc_fast_adapt_kernel
from covid.tasks.mcmc_kernel_factory import make_hmc_slow_adapt_kernel
......@@ -227,11 +228,16 @@ def trace_results_fn(_, results):
results_dict = {}
root_results = results.inner_results
step_size = tf.convert_to_tensor(
unnest.get_outermost(root_results[0], "step_size")
)
results_dict["hmc"] = {
"is_accepted": unnest.get_innermost(root_results[0], "is_accepted"),
"target_log_prob": unnest.get_innermost(
root_results[0], "target_log_prob"
),
"step_size": step_size,
}
def get_move_results(results):
......@@ -278,24 +284,36 @@ def run_mcmc(
output_file,
):
first_window_size, slow_window_size, last_window_size = _get_window_sizes(
config["num_adaptation_iterations"]
# first_window_size, slow_window_size, last_window_size = _get_window_sizes(
# config["num_adaptation_iterations"]
# )
first_window_size = 200
last_window_size = 50
slow_window_size = 25
num_slow_windows = 6
warmup_size = int(
first_window_size
+ slow_window_size
* ((1 - 2 ** num_slow_windows) / (1 - 2)) # sum geometric series
+ last_window_size
)
num_slow_windows = 4
hmc_kernel_kwargs = {
"step_size": 0.00001,
"num_leapfrog_steps": 4,
"step_size": 0.1,
"num_leapfrog_steps": 16,
"momentum_distribution": None,
"store_parameters_in_results": True,
}
dual_averaging_kwargs = {
"num_adaptation_steps": first_window_size,
"target_accept_prob": 0.75,
# "decay_rate": 0.80,
}
event_kernel_kwargs = {
"initial_state": initial_conditions,
"t_range": [
current_state[1].shape[-2] - 28,
current_state[1].shape[-2] - 21,
current_state[1].shape[-2],
],
"config": config,
......@@ -315,7 +333,7 @@ def run_mcmc(
output_file,
sample_dict=draws_to_dict(draws),
results_dict=trace,
num_samples=config["num_adaptation_iterations"]
num_samples=warmup_size
+ config["num_burst_samples"] * config["num_bursts"],
)
offset = 0
......@@ -323,6 +341,7 @@ def run_mcmc(
# Fast adaptation sampling
print(f"Fast window {first_window_size}", file=sys.stderr, flush=True)
dual_averaging_kwargs["num_adaptation_steps"] = first_window_size
draws, trace, step_size, running_variance = _fast_adapt_window(
num_draws=first_window_size,
joint_log_prob_fn=joint_log_prob_fn,
......@@ -345,6 +364,7 @@ def run_mcmc(
hmc_kernel_kwargs["step_size"] = step_size
for slow_window_idx in range(num_slow_windows):
window_num_draws = slow_window_size * (2 ** slow_window_idx)
dual_averaging_kwargs["num_adaptation_steps"] = window_num_draws
print(f"Slow window {window_num_draws}", file=sys.stderr, flush=True)
(
draws,
......@@ -396,7 +416,10 @@ def run_mcmc(
# Fixed window sampling
print("Sampling...", file=sys.stderr, flush=True)
hmc_kernel_kwargs["step_size"] = step_size
hmc_kernel_kwargs["step_size"] = tf.reduce_mean(
trace["hmc"]["step_size"][-last_window_size // 2 :]
)
print("Fixed kernel kwargs:", hmc_kernel_kwargs, flush=True)
for i in tqdm.tqdm(
range(config["num_bursts"]),
unit_scale=config["num_burst_samples"] * config["thin"],
......
......@@ -15,6 +15,7 @@ def make_hmc_base_kernel(
step_size,
num_leapfrog_steps,
momentum_distribution,
store_parameters_in_results,
):
def fn(target_log_prob_fn, _):
return tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
......@@ -22,6 +23,7 @@ def make_hmc_base_kernel(
step_size=step_size,
num_leapfrog_steps=num_leapfrog_steps,
momentum_distribution=momentum_distribution,
store_parameters_in_results=store_parameters_in_results,
)
return fn
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment