Commit 2ee6c880 authored by Chris Jewell's avatar Chris Jewell
Browse files

Re-factoring of fixed window to speed up XLA compilation

parent 07dfb0ee
......@@ -183,14 +183,14 @@ def _slow_adapt_window(
)
def _fixed_window(
def make_fixed_window_sampler(
num_draws,
joint_log_prob_fn,
initial_position,
hmc_kernel_kwargs,
event_kernel_kwargs,
trace_fn=None,
seed=None,
jit_compile=False,
):
"""Fixed step size and mass matrix HMC.
......@@ -214,16 +214,19 @@ def _fixed_window(
name="fixed",
)
results = tfp.mcmc.sample_chain(
num_draws,
current_state=initial_position,
kernel=kernel,
return_final_kernel_results=True,
trace_fn=trace_fn,
seed=seed,
)
@tf.function(jit_compile=jit_compile)
def sample_fn(current_state, previous_kernel_results=None):
return tfp.mcmc.sample_chain(
num_draws,
current_state=current_state,
kernel=kernel,
return_final_kernel_results=True,
previous_kernel_results=previous_kernel_results,
trace_fn=trace_fn,
seed=seed,
)
return results
return sample_fn, kernel
def trace_results_fn(_, results):
......@@ -297,9 +300,9 @@ def run_mcmc(
# config["num_adaptation_iterations"]
# )
first_window_size = 200
first_window_size = 20 # 200
last_window_size = 50
slow_window_size = 25
slow_window_size = 2 # 25
num_slow_windows = 4
warmup_size = int(
......@@ -330,14 +333,13 @@ def run_mcmc(
# Set up posterior
print("Initialising output...", end="", flush=True, file=sys.stderr)
draws, trace, _ = _fixed_window(
draws, trace, _ = make_fixed_window_sampler(
num_draws=1,
joint_log_prob_fn=joint_log_prob_fn,
initial_position=current_state,
hmc_kernel_kwargs=hmc_kernel_kwargs,
event_kernel_kwargs=event_kernel_kwargs,
trace_fn=trace_results_fn,
)
)[0](current_state)
posterior = Posterior(
output_file,
sample_dict=draws_to_dict(draws),
......@@ -429,27 +431,22 @@ def run_mcmc(
trace["hmc"]["step_size"][-last_window_size // 2 :]
)
@tf.function(
fixed_sample, kernel = make_fixed_window_sampler(
config["num_burst_samples"],
joint_log_prob_fn=joint_log_prob_fn,
hmc_kernel_kwargs=hmc_kernel_kwargs,
event_kernel_kwargs=event_kernel_kwargs,
trace_fn=trace_results_fn,
jit_compile=True,
input_signature=[tf.TensorSpec.from_tensor(s) for s in current_state],
)
def fixed_window_closure(params, events):
return _fixed_window(
num_draws=config["num_burst_samples"],
joint_log_prob_fn=joint_log_prob_fn,
initial_position=[params, events],
hmc_kernel_kwargs=hmc_kernel_kwargs,
event_kernel_kwargs=event_kernel_kwargs,
trace_fn=trace_results_fn,
)
pkr = kernel.bootstrap_results(current_state)
tf.profiler.experimental.start("tf_logdir")
for i in tqdm.tqdm(
range(config["num_bursts"]),
unit_scale=config["num_burst_samples"] * config["thin"],
):
print("Current_state:", current_state)
draws, trace, _ = fixed_window_closure(*current_state)
draws, trace, pkr = fixed_sample(current_state, pkr)
current_state = [state_part[-1] for state_part in draws]
draws[0] = param_bijector.inverse(draws[0])
posterior.write_samples(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment