Commit d248d1e6 authored by Chris Jewell's avatar Chris Jewell
Browse files

Merge branch 'perf-xla-compile-time' into 'master'

Merge performance enhancements enabling feasible XLA compilation

See merge request !44
parents 44c560e1 2f56eb3d
......@@ -56,7 +56,7 @@ def _get_window_sizes(num_adaptation_steps):
return first_window_size, slow_window_size, last_window_size
@tf.function # (autograph=False, jit_compile=False)
@tf.function # (autograph=False, jit_compile=False)
def _fast_adapt_window(
num_draws,
joint_log_prob_fn,
......@@ -100,21 +100,28 @@ def _fast_adapt_window(
name="fast_adapt",
)
draws, trace, fkr = tfp.mcmc.sample_chain(
num_draws,
initial_position,
kernel=kernel,
return_final_kernel_results=True,
trace_fn=trace_fn,
seed=seed,
)
pkr = kernel.bootstrap_results(initial_position)
@tf.function(jit_compile=True)
def sample(current_state, previous_kernel_results):
return tfp.mcmc.sample_chain(
num_draws,
current_state=current_state,
kernel=kernel,
previous_kernel_results=previous_kernel_results,
return_final_kernel_results=True,
trace_fn=trace_fn,
seed=seed,
)
draws, trace, fkr = sample(initial_position, pkr)
weighted_running_variance = get_weighted_running_variance(draws[0])
step_size = unnest.get_outermost(fkr.inner_results[0], "step_size")
return draws, trace, step_size, weighted_running_variance
@tf.function # (autograph=False, jit_compile=False)
@tf.function # (autograph=False, jit_compile=False)
def _slow_adapt_window(
num_draws,
joint_log_prob_fn,
......@@ -159,14 +166,20 @@ def _slow_adapt_window(
name="slow_adapt",
)
draws, trace, fkr = tfp.mcmc.sample_chain(
num_draws,
current_state=initial_position,
kernel=kernel,
return_final_kernel_results=True,
trace_fn=trace_fn,
)
pkr = kernel.bootstrap_results(initial_position)
@tf.function(jit_compile=True)
def sample(current_state, previous_kernel_results):
return tfp.mcmc.sample_chain(
num_draws,
current_state=current_state,
kernel=kernel,
previous_kernel_results=pkr,
return_final_kernel_results=True,
trace_fn=trace_fn,
)
draws, trace, fkr = sample(initial_position, pkr)
step_size = unnest.get_outermost(fkr.inner_results[0], "step_size")
momentum_distribution = unnest.get_outermost(
fkr.inner_results[0], "momentum_distribution"
......@@ -183,15 +196,14 @@ def _slow_adapt_window(
)
@tf.function # (autograph=False, jit_compile=False)
def _fixed_window(
def make_fixed_window_sampler(
num_draws,
joint_log_prob_fn,
initial_position,
hmc_kernel_kwargs,
event_kernel_kwargs,
trace_fn=None,
seed=None,
jit_compile=False,
):
"""Fixed step size and mass matrix HMC.
......@@ -215,14 +227,19 @@ def _fixed_window(
name="fixed",
)
return tfp.mcmc.sample_chain(
num_draws,
current_state=initial_position,
kernel=kernel,
return_final_kernel_results=True,
trace_fn=trace_fn,
seed=seed,
)
@tf.function(jit_compile=jit_compile)
def sample_fn(current_state, previous_kernel_results=None):
return tfp.mcmc.sample_chain(
num_draws,
current_state=current_state,
kernel=kernel,
return_final_kernel_results=True,
previous_kernel_results=previous_kernel_results,
trace_fn=trace_fn,
seed=seed,
)
return sample_fn, kernel
def trace_results_fn(_, results):
......@@ -292,10 +309,6 @@ def run_mcmc(
output_file,
):
# first_window_size, slow_window_size, last_window_size = _get_window_sizes(
# config["num_adaptation_iterations"]
# )
first_window_size = 200
last_window_size = 50
slow_window_size = 25
......@@ -329,14 +342,13 @@ def run_mcmc(
# Set up posterior
print("Initialising output...", end="", flush=True, file=sys.stderr)
draws, trace, _ = _fixed_window(
draws, trace, _ = make_fixed_window_sampler(
num_draws=1,
joint_log_prob_fn=joint_log_prob_fn,
initial_position=current_state,
hmc_kernel_kwargs=hmc_kernel_kwargs,
event_kernel_kwargs=event_kernel_kwargs,
trace_fn=trace_results_fn,
)
)[0](current_state)
posterior = Posterior(
output_file,
sample_dict=draws_to_dict(draws),
......@@ -427,19 +439,22 @@ def run_mcmc(
hmc_kernel_kwargs["step_size"] = tf.reduce_mean(
trace["hmc"]["step_size"][-last_window_size // 2 :]
)
print("Fixed kernel kwargs:", hmc_kernel_kwargs, flush=True)
fixed_sample, kernel = make_fixed_window_sampler(
config["num_burst_samples"],
joint_log_prob_fn=joint_log_prob_fn,
hmc_kernel_kwargs=hmc_kernel_kwargs,
event_kernel_kwargs=event_kernel_kwargs,
trace_fn=trace_results_fn,
jit_compile=True,
)
pkr = kernel.bootstrap_results(current_state)
for i in tqdm.tqdm(
range(config["num_bursts"]),
unit_scale=config["num_burst_samples"] * config["thin"],
):
draws, trace, _ = _fixed_window(
num_draws=config["num_burst_samples"],
joint_log_prob_fn=joint_log_prob_fn,
initial_position=current_state,
hmc_kernel_kwargs=hmc_kernel_kwargs,
event_kernel_kwargs=event_kernel_kwargs,
trace_fn=trace_results_fn,
)
draws, trace, pkr = fixed_sample(current_state, pkr)
current_state = [state_part[-1] for state_part in draws]
draws[0] = param_bijector.inverse(draws[0])
posterior.write_samples(
......@@ -559,9 +574,6 @@ def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True):
),
events,
]
print("Num time steps:", events.shape[1], flush=True)
print("alpha_t shape", model.event_shape["alpha_t"], flush=True)
print("Initial chain state:", current_chain_state[0], flush=True)
print("Initial logpi:", joint_log_prob(*current_chain_state), flush=True)
# Output file
......
"""Returns tuple of semantic version"""
import re
import pkg_resources
......@@ -17,7 +18,9 @@ def _version():
ver = pkg_resources.get_distribution("covid19uk").version
except pkg_resources.DistributionNotFound:
ver = _get_version_from_pyproject()
return ver.split(".")
regex = re.compile("^(\d)\.(\d)\.(.*)")
version_crumbs = regex.match(ver).groups()
return version_crumbs
MAJOR, MINOR, PATCH = _version()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment