Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Chris Jewell
covid19uk
Commits
13cb5068
Commit
13cb5068
authored
Mar 08, 2021
by
Chris Jewell
Browse files
Inline documentation in inference.py
parent
1fe159c2
Changes
1
Hide whitespace changes
Inline
Side-by-side
covid/tasks/inference.py
View file @
13cb5068
...
...
@@ -30,6 +30,7 @@ DTYPE = model_spec.DTYPE
def
get_weighted_running_variance
(
draws
):
"""Initialises online variance accumulator"""
prev_mean
,
prev_var
=
tf
.
nn
.
moments
(
draws
[
-
draws
.
shape
[
0
]
//
2
:],
axes
=
[
0
])
num_samples
=
tf
.
cast
(
...
...
@@ -62,7 +63,22 @@ def _fast_adapt_window(
trace_fn
=
None
,
seed
=
None
,
):
"""
In the fast adaptation window, we use the
`DualAveragingStepSizeAdaptation` kernel
to wrap an HMC kernel.
:param num_draws: Number of MCMC draws in window
:param joint_log_prob_fn: joint log posterior function
:param initial_position: initial state of the Markov chain
:param hmc_kernel_kwargs: `HamiltonianMonteCarlo` kernel keywords args
:param dual_averaging_kwargs: `DualAveragingStepSizeAdaptation` keyword args
:param event_kernel_kwargs: EventTimesMH and Occult kernel args
:param trace_fn: function to trace kernel results
:param seed: optional random seed.
:returns: draws, kernel results, the adapted HMC step size, and variance
accumulator
"""
kernel_list
=
[
(
0
,
...
...
@@ -106,6 +122,21 @@ def _slow_adapt_window(
trace_fn
=
None
,
seed
=
None
,
):
"""In the slow adaptation phase, we adapt the HMC
step size and mass matrix together.
:param num_draws: number of MCMC iterations
:param joint_log_prob_fn: the joint posterior density function
:param initial_position: initial Markov chain state
:param initial_running_variance: initial variance accumulator
:param hmc_kernel_kwargs: `HamiltonianMonteCarlo` kernel kwargs
:param dual_averaging_kwargs: `DualAveragingStepSizeAdaptation` kwargs
:param event_kernel_kwargs: EventTimesMH and Occults kwargs
:param trace_fn: result trace function
:param seed: optional random seed
:returns: draws, kernel results, adapted step size, the variance accumulator,
and "learned" momentum distribution for the HMC.
"""
kernel_list
=
[
(
0
,
...
...
@@ -158,8 +189,15 @@ def _fixed_window(
trace_fn
=
None
,
seed
=
None
,
):
"""Fixed step size HMC.
"""Fixed step size and mass matrix HMC.
:param num_draws: number of MCMC iterations
:param joint_log_prob_fn: joint log posterior density function
:param initial_position: initial Markov chain state
:param hmc_kernel_kwargs: `HamiltonianMonteCarlo` kwargs
:param event_kernel_kwargs: Event and Occults kwargs
:param trace_fn: results trace function
:param seed: optional random seed
:returns: (draws, trace, final_kernel_results)
"""
kernel_list
=
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment