Commit 1fe159c2 authored by Chris Jewell's avatar Chris Jewell
Browse files

Set size of output dataset automatically

CHANGES:

1. Size of posterior dataset is calculated automatically;
2. Require `num_adaptation_iterations` in config file;
3. Set progress bar to measure iterations.
parent 1676ca5d
......@@ -42,6 +42,15 @@ def get_weighted_running_variance(draws):
return weighted_running_variance
def _get_window_sizes(num_adaptation_steps):
slow_window_size = num_adaptation_steps // 21
first_window_size = 3 * slow_window_size
last_window_size = (
num_adaptation_steps - 15 * slow_window_size - first_window_size
)
return first_window_size, slow_window_size, last_window_size
@tf.function
def _fast_adapt_window(
num_draws,
......@@ -226,9 +235,9 @@ def run_mcmc(
joint_log_prob_fn, current_state, initial_conditions, config, output_file
):
num_draws = config["num_bursts"] * config["num_burst_samples"]
fast_window_size = 75
slow_window_size = 25
first_window_size, slow_window_size, last_window_size = _get_window_sizes(
config["num_adaptation_iterations"]
)
num_slow_windows = 4
hmc_kernel_kwargs = {
......@@ -237,7 +246,7 @@ def run_mcmc(
"momentum_distribution": None,
}
dual_averaging_kwargs = {
"num_adaptation_steps": fast_window_size,
"num_adaptation_steps": first_window_size,
"target_accept_prob": 0.75,
}
event_kernel_kwargs = {
......@@ -262,14 +271,15 @@ def run_mcmc(
output_file,
sample_dict=draws_to_dict(draws),
results_dict=trace,
num_samples=5000 + 75 + 25 + 50 + 100 + 200 + 75,
num_samples=config["num_adaptation_iterations"]
+ config["num_burst_samples"] * config["num_bursts"],
)
offset = 0
# Fast adaptation sampling
print(f"Fast window {fast_window_size}")
print(f"Fast window {first_window_size}")
draws, trace, step_size, running_variance = _fast_adapt_window(
num_draws=fast_window_size,
num_draws=first_window_size,
joint_log_prob_fn=joint_log_prob_fn,
initial_position=current_state,
hmc_kernel_kwargs=hmc_kernel_kwargs,
......@@ -282,7 +292,7 @@ def run_mcmc(
first_dim_offset=offset,
)
posterior.write_results(trace, first_dim_offset=offset)
offset += fast_window_size
offset += first_window_size
current_state = [s[-1] for s in draws]
# Slow adaptation sampling
......@@ -317,9 +327,10 @@ def run_mcmc(
offset += window_num_draws
# Fast adaptation sampling
print(f"Fast window {fast_window_size}")
print(f"Fast window {last_window_size}")
dual_averaging_kwargs["num_adaptation_steps"] = last_window_size
draws, trace, step_size, weighted_running_variance = _fast_adapt_window(
num_draws=fast_window_size,
num_draws=last_window_size,
joint_log_prob_fn=joint_log_prob_fn,
initial_position=current_state,
hmc_kernel_kwargs=hmc_kernel_kwargs,
......@@ -333,12 +344,15 @@ def run_mcmc(
first_dim_offset=offset,
)
posterior.write_results(trace, first_dim_offset=offset)
offset += fast_window_size
offset += last_window_size
# Fixed window sampling
print("Sampling...")
hmc_kernel_kwargs["step_size"] = step_size
for i in tqdm.tqdm(range(config["num_bursts"])):
for i in tqdm.tqdm(
range(config["num_bursts"]),
unit_scale=config["num_burst_samples"] * config["thin"],
):
draws, trace, _ = _fixed_window(
num_draws=config["num_burst_samples"],
joint_log_prob_fn=joint_log_prob_fn,
......
......@@ -31,6 +31,7 @@ Mcmc:
num_bursts: 50 # Number of MCMC bursts of `num_burst_samples`
num_burst_samples: 100 # Number of MCMC samples per burst
thin: 1 # Thin MCMC samples every `thin` iterations
num_adaptation_iterations: 1000
ThinPosterior: # Post-process further chain thinning HDF5 -> .pkl.
start: 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment