Commit 13cb5068 authored by Chris Jewell's avatar Chris Jewell
Browse files

Inline documentation in inference.py

parent 1fe159c2
......@@ -30,6 +30,7 @@ DTYPE = model_spec.DTYPE
def get_weighted_running_variance(draws):
"""Initialises online variance accumulator"""
prev_mean, prev_var = tf.nn.moments(draws[-draws.shape[0] // 2 :], axes=[0])
num_samples = tf.cast(
......@@ -62,7 +63,22 @@ def _fast_adapt_window(
trace_fn=None,
seed=None,
):
"""
In the fast adaptation window, we use the
`DualAveragingStepSizeAdaptation` kernel
to wrap an HMC kernel.
:param num_draws: Number of MCMC draws in window
:param joint_log_prob_fn: joint log posterior function
:param initial_position: initial state of the Markov chain
:param hmc_kernel_kwargs: `HamiltonianMonteCarlo` kernel keywords args
:param dual_averaging_kwargs: `DualAveragingStepSizeAdaptation` keyword args
:param event_kernel_kwargs: EventTimesMH and Occult kernel args
:param trace_fn: function to trace kernel results
:param seed: optional random seed.
:returns: draws, kernel results, the adapted HMC step size, and variance
accumulator
"""
kernel_list = [
(
0,
......@@ -106,6 +122,21 @@ def _slow_adapt_window(
trace_fn=None,
seed=None,
):
"""In the slow adaptation phase, we adapt the HMC
step size and mass matrix together.
:param num_draws: number of MCMC iterations
:param joint_log_prob_fn: the joint posterior density function
:param initial_position: initial Markov chain state
:param initial_running_variance: initial variance accumulator
:param hmc_kernel_kwargs: `HamiltonianMonteCarlo` kernel kwargs
:param dual_averaging_kwargs: `DualAveragingStepSizeAdaptation` kwargs
:param event_kernel_kwargs: EventTimesMH and Occults kwargs
:param trace_fn: result trace function
:param seed: optional random seed
:returns: draws, kernel results, adapted step size, the variance accumulator,
and "learned" momentum distribution for the HMC.
"""
kernel_list = [
(
0,
......@@ -158,8 +189,15 @@ def _fixed_window(
trace_fn=None,
seed=None,
):
"""Fixed step size HMC.
"""Fixed step size and mass matrix HMC.
:param num_draws: number of MCMC iterations
:param joint_log_prob_fn: joint log posterior density function
:param initial_position: initial Markov chain state
:param hmc_kernel_kwargs: `HamiltonianMonteCarlo` kwargs
:param event_kernel_kwargs: Event and Occults kwargs
:param trace_fn: results trace function
:param seed: optional random seed
:returns: (draws, trace, final_kernel_results)
"""
kernel_list = [
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment