Commit 2f56eb3d authored by Chris Jewell's avatar Chris Jewell
Browse files

JIT-compile window functions

parent 6b65a5f4
......@@ -100,14 +100,21 @@ def _fast_adapt_window(
name="fast_adapt",
)
draws, trace, fkr = tfp.mcmc.sample_chain(
num_draws,
initial_position,
kernel=kernel,
return_final_kernel_results=True,
trace_fn=trace_fn,
seed=seed,
)
pkr = kernel.bootstrap_results(initial_position)
@tf.function(jit_compile=True)
def sample(current_state, previous_kernel_results):
return tfp.mcmc.sample_chain(
num_draws,
current_state=current_state,
kernel=kernel,
previous_kernel_results=previous_kernel_results,
return_final_kernel_results=True,
trace_fn=trace_fn,
seed=seed,
)
draws, trace, fkr = sample(initial_position, pkr)
weighted_running_variance = get_weighted_running_variance(draws[0])
step_size = unnest.get_outermost(fkr.inner_results[0], "step_size")
......@@ -159,14 +166,20 @@ def _slow_adapt_window(
name="slow_adapt",
)
draws, trace, fkr = tfp.mcmc.sample_chain(
num_draws,
current_state=initial_position,
kernel=kernel,
return_final_kernel_results=True,
trace_fn=trace_fn,
)
pkr = kernel.bootstrap_results(initial_position)
@tf.function(jit_compile=True)
def sample(current_state, previous_kernel_results):
return tfp.mcmc.sample_chain(
num_draws,
current_state=current_state,
kernel=kernel,
previous_kernel_results=pkr,
return_final_kernel_results=True,
trace_fn=trace_fn,
)
draws, trace, fkr = sample(initial_position, pkr)
step_size = unnest.get_outermost(fkr.inner_results[0], "step_size")
momentum_distribution = unnest.get_outermost(
fkr.inner_results[0], "momentum_distribution"
......@@ -296,10 +309,6 @@ def run_mcmc(
output_file,
):
# first_window_size, slow_window_size, last_window_size = _get_window_sizes(
# config["num_adaptation_iterations"]
# )
first_window_size = 200
last_window_size = 50
slow_window_size = 25
......@@ -441,7 +450,6 @@ def run_mcmc(
)
pkr = kernel.bootstrap_results(current_state)
tf.profiler.experimental.start("tf_logdir")
for i in tqdm.tqdm(
range(config["num_bursts"]),
unit_scale=config["num_burst_samples"] * config["thin"],
......@@ -458,7 +466,6 @@ def run_mcmc(
first_dim_offset=offset,
)
offset += config["num_burst_samples"]
tf.profiler.experimental.stop()
return posterior
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment