Commit 6c6616e9 authored by Chris Jewell

Integrate step size in final adaptation window.

parent 119e0dbd
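
The diff below replaces the config-driven adaptation windows with a hard-coded fast/slow/final warmup schedule, traces the dual-averaged step size during warmup, and reuses it for the fixed sampling phase. As a quick worked check of the new schedule (a sketch using the values from the diff, not code from the commit), the total warmup length comes out at 1825 draws:

# Values taken from the diff below; the doubling slow windows form a geometric series.
first_window_size = 200
slow_window_size = 25
num_slow_windows = 6
last_window_size = 50

# (1 - 2**6) / (1 - 2) == 2**6 - 1 == 63 doubling-window units, i.e. 25 * 63 = 1575 draws.
slow_total = slow_window_size * ((1 - 2 ** num_slow_windows) / (1 - 2))
warmup_size = int(first_window_size + slow_total + last_window_size)
print(warmup_size)  # 200 + 1575 + 50 = 1825 warmup iterations
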
@@ -18,6 +18,7 @@ from tensorflow_probability.python.experimental.stats import sample_stats
 from gemlib.util import compute_state
 from gemlib.mcmc import Posterior
 from gemlib.mcmc import GibbsKernel
+from gemlib.distributions import BrownianMotion
 from covid.tasks.mcmc_kernel_factory import make_hmc_base_kernel
 from covid.tasks.mcmc_kernel_factory import make_hmc_fast_adapt_kernel
 from covid.tasks.mcmc_kernel_factory import make_hmc_slow_adapt_kernel
@@ -227,11 +228,16 @@ def trace_results_fn(_, results):
     results_dict = {}
     root_results = results.inner_results

+    step_size = tf.convert_to_tensor(
+        unnest.get_outermost(root_results[0], "step_size")
+    )
+
     results_dict["hmc"] = {
         "is_accepted": unnest.get_innermost(root_results[0], "is_accepted"),
         "target_log_prob": unnest.get_innermost(
             root_results[0], "target_log_prob"
         ),
+        "step_size": step_size,
     }

     def get_move_results(results):
@@ -278,24 +284,36 @@ def run_mcmc(
     output_file,
 ):
-    first_window_size, slow_window_size, last_window_size = _get_window_sizes(
-        config["num_adaptation_iterations"]
-    )
-    num_slow_windows = 4
+    # first_window_size, slow_window_size, last_window_size = _get_window_sizes(
+    #     config["num_adaptation_iterations"]
+    # )
+    first_window_size = 200
+    last_window_size = 50
+    slow_window_size = 25
+    num_slow_windows = 6
+    warmup_size = int(
+        first_window_size
+        + slow_window_size
+        * ((1 - 2 ** num_slow_windows) / (1 - 2))  # sum geometric series
+        + last_window_size
+    )

     hmc_kernel_kwargs = {
-        "step_size": 0.00001,
-        "num_leapfrog_steps": 4,
+        "step_size": 0.1,
+        "num_leapfrog_steps": 16,
         "momentum_distribution": None,
+        "store_parameters_in_results": True,
     }
     dual_averaging_kwargs = {
-        "num_adaptation_steps": first_window_size,
         "target_accept_prob": 0.75,
+        # "decay_rate": 0.80,
     }
     event_kernel_kwargs = {
         "initial_state": initial_conditions,
         "t_range": [
-            current_state[1].shape[-2] - 28,
+            current_state[1].shape[-2] - 21,
             current_state[1].shape[-2],
         ],
         "config": config,
@@ -315,7 +333,7 @@ def run_mcmc(
         output_file,
         sample_dict=draws_to_dict(draws),
         results_dict=trace,
-        num_samples=config["num_adaptation_iterations"]
+        num_samples=warmup_size
         + config["num_burst_samples"] * config["num_bursts"],
     )
     offset = 0
@@ -323,6 +341,7 @@ def run_mcmc(
     # Fast adaptation sampling
     print(f"Fast window {first_window_size}", file=sys.stderr, flush=True)
+    dual_averaging_kwargs["num_adaptation_steps"] = first_window_size
     draws, trace, step_size, running_variance = _fast_adapt_window(
         num_draws=first_window_size,
         joint_log_prob_fn=joint_log_prob_fn,
@@ -345,6 +364,7 @@ def run_mcmc(
     hmc_kernel_kwargs["step_size"] = step_size
     for slow_window_idx in range(num_slow_windows):
         window_num_draws = slow_window_size * (2 ** slow_window_idx)
+        dual_averaging_kwargs["num_adaptation_steps"] = window_num_draws
         print(f"Slow window {window_num_draws}", file=sys.stderr, flush=True)
         (
             draws,
@@ -396,7 +416,10 @@ def run_mcmc(
     # Fixed window sampling
     print("Sampling...", file=sys.stderr, flush=True)
-    hmc_kernel_kwargs["step_size"] = step_size
+    hmc_kernel_kwargs["step_size"] = tf.reduce_mean(
+        trace["hmc"]["step_size"][-last_window_size // 2 :]
+    )
+    print("Fixed kernel kwargs:", hmc_kernel_kwargs, flush=True)
     for i in tqdm.tqdm(
         range(config["num_bursts"]),
         unit_scale=config["num_burst_samples"] * config["thin"],
......
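
With the step size now traced at every draw (first file above), the fixed-window kernel no longer reuses the last raw step size; it takes the mean of the dual-averaged step sizes over the second half of the final adaptation window. A minimal sketch of that reduction, with a made-up trace standing in for the real one:

import tensorflow as tf

last_window_size = 50
# Hypothetical traced step sizes from a 50-draw final adaptation window.
trace = {"hmc": {"step_size": tf.linspace(0.08, 0.12, last_window_size)}}

# Average over the window's second half, as in the assignment above.
fixed_step_size = tf.reduce_mean(trace["hmc"]["step_size"][-last_window_size // 2 :])
print(float(fixed_step_size))  # ~0.11, smoother than taking the final value alone
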
@@ -15,6 +15,7 @@ def make_hmc_base_kernel(
     step_size,
     num_leapfrog_steps,
     momentum_distribution,
+    store_parameters_in_results,
 ):
     def fn(target_log_prob_fn, _):
         return tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
@@ -22,6 +23,7 @@ def make_hmc_base_kernel(
             step_size=step_size,
             num_leapfrog_steps=num_leapfrog_steps,
             momentum_distribution=momentum_distribution,
+            store_parameters_in_results=store_parameters_in_results,
         )

     return fn
......
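
Threading store_parameters_in_results through make_hmc_base_kernel is what makes the step size visible to trace_results_fn: when the flag is set, the HMC kernel keeps step_size inside its kernel results, where unnest can retrieve it. A standalone sketch of that mechanism with a toy target (names and values here are illustrative, not from the repo):

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import unnest


def target_log_prob(x):
    # Toy 1-D standard-normal target, purely for illustration.
    return -0.5 * tf.square(x)


hmc = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob,
    step_size=0.1,
    num_leapfrog_steps=16,
    store_parameters_in_results=True,  # keep step_size in the kernel results
)
kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
    hmc, num_adaptation_steps=200, target_accept_prob=0.75
)

results = kernel.bootstrap_results(tf.constant(0.0))
# With the flag set, the adapting step size can be pulled out of the results tree,
# which is what the new trace_results_fn does via unnest.get_outermost.
print(unnest.get_outermost(results, "step_size"))
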