Commit 1676ca5d authored by Chris Jewell's avatar Chris Jewell
Browse files

Adaptive Hamiltonian Monte Carlo within Gibbs implementation

CHANGES:

1. Kernel builder functions moved to `mcmc_kernel_factory.py`
2. Windowed adaptive MCMC a la STAN implemented in `inference.py`
3. Prior on beta1 tightened to improve stability.
4. Depends on `gemlib`@develop branch for tf-nightly and tfp-nightly
parent 5a862f51
......@@ -95,7 +95,7 @@ def CovidUK(covariates, initial_state, initial_step, num_steps):
def beta1():
return tfd.Normal(
loc=tf.constant(0.0, dtype=DTYPE),
scale=tf.constant(1000.0, dtype=DTYPE),
scale=tf.constant(1.0, dtype=DTYPE),
)
def beta2():
......@@ -106,8 +106,8 @@ def CovidUK(covariates, initial_state, initial_step, num_steps):
def sigma():
return tfd.Gamma(
concentration=tf.constant(2.0, dtype=DTYPE),
rate=tf.constant(20.0, dtype=DTYPE),
concentration=tf.constant(20.0, dtype=DTYPE),
rate=tf.constant(200.0, dtype=DTYPE),
)
def xi(beta1, sigma):
......
This diff is collapsed.
"""MCMC kernel builder functions"""
import tensorflow_probability as tfp
from gemlib.mcmc import UncalibratedEventTimesUpdate
from gemlib.mcmc import UncalibratedOccultUpdate
from gemlib.mcmc import TransitionTopology
from gemlib.mcmc import MultiScanKernel
from gemlib.mcmc import GibbsKernel
# Kernels
# Build Metropolis within Gibbs sampler with windowed HMC
def make_hmc_base_kernel(
step_size,
num_leapfrog_steps,
momentum_distribution,
):
def fn(target_log_prob_fn, _):
return tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
target_log_prob_fn=target_log_prob_fn,
step_size=step_size,
num_leapfrog_steps=num_leapfrog_steps,
momentum_distribution=momentum_distribution,
)
return fn
def make_hmc_fast_adapt_kernel(
hmc_kernel_kwargs,
dual_averaging_kwargs,
):
def fn(target_log_prob_fn, state):
return tfp.mcmc.DualAveragingStepSizeAdaptation(
make_hmc_base_kernel(
**hmc_kernel_kwargs,
)(target_log_prob_fn, state),
**dual_averaging_kwargs,
)
return fn
def make_hmc_slow_adapt_kernel(
initial_running_variance,
hmc_kernel_kwargs,
dual_averaging_kwargs,
):
def fn(target_log_prob_fn, state):
return tfp.experimental.mcmc.DiagonalMassMatrixAdaptation(
make_hmc_fast_adapt_kernel(
hmc_kernel_kwargs, dual_averaging_kwargs
)(target_log_prob_fn, state),
initial_running_variance=initial_running_variance,
)
return fn
def make_partially_observed_step(
initial_state,
target_event_id,
prev_event_id,
next_event_id,
config,
name=None,
):
def fn(target_log_prob_fn, _):
return tfp.mcmc.MetropolisHastings(
inner_kernel=UncalibratedEventTimesUpdate(
target_log_prob_fn=target_log_prob_fn,
target_event_id=target_event_id,
prev_event_id=prev_event_id,
next_event_id=next_event_id,
initial_state=initial_state,
dmax=config["dmax"],
mmax=config["m"],
nmax=config["nmax"],
),
name=name,
)
return fn
def make_occults_step(
initial_state,
t_range,
prev_event_id,
target_event_id,
next_event_id,
config,
name,
):
def fn(target_log_prob_fn, _):
return tfp.mcmc.MetropolisHastings(
inner_kernel=UncalibratedOccultUpdate(
target_log_prob_fn=target_log_prob_fn,
topology=TransitionTopology(
prev_event_id, target_event_id, next_event_id
),
cumulative_event_offset=initial_state,
nmax=config["occult_nmax"],
t_range=t_range,
name=name,
),
name=name,
)
return fn
def make_event_multiscan_gibbs_step(
initial_state,
t_range,
config,
):
def make_kernel_fn(target_log_prob_fn, _):
return MultiScanKernel(
config["num_event_time_updates"],
GibbsKernel(
target_log_prob_fn=target_log_prob_fn,
kernel_list=[
(
0,
make_partially_observed_step(
initial_state, 0, None, 1, config, "se_events"
),
),
(
0,
make_partially_observed_step(
initial_state, 1, 0, 2, config, "ei_events"
),
),
(
0,
make_occults_step(
initial_state,
t_range,
None,
0,
1,
config,
"se_occults",
),
),
(
0,
make_occults_step(
initial_state,
t_range,
0,
1,
2,
config,
"ei_occults",
),
),
],
name="gibbs1",
),
)
return make_kernel_fn
......@@ -2,8 +2,8 @@
ProcessData:
date_range:
- 2020-10-09
- 2021-01-01
- 2021-02-02
mobility_matrix: data/mergedflows.csv
population_size: data/c2019modagepop.csv
commute_volume: # Can be replaced by DfT traffic flow data - contact authors <c.jewell@lancaster.ac.uk>
......@@ -20,25 +20,22 @@ ProcessData:
address: "https://services1.arcgis.com/ESMARspQHYMw9BZ9/arcgis/rest/services/LAD_APR_2019_UK_NC/FeatureServer/0/query?where=1%3D1&outFields=LAD19CD,LAD19NM&returnGeometry=false&returnDistinctValues=true&orderByFields=LAD19CD&outSR=4326&f=json"
format: ons
regions:
- S # Scotland
- E # England
- W # Wales
- NI # Northern Ireland
- N # Northern Ireland
Mcmc:
dmax: 84 # Max distance to move events
nmax: 50 # Max num events per metapopulation/time to move
m: 1 # Number of metapopulations to move
nmax: 25 # Max num events per metapopulation/time to move
m: 2 # Number of metapopulations to move
occult_nmax: 15 # Max number of occults to add/delete per metapop/time
num_event_time_updates: 35 # Num event and occult updates per sweep of Gibbs MCMC sampler.
num_bursts: 200 # Number of MCMC bursts of `num_burst_samples`
num_burst_samples: 50 # Number of MCMC samples per burst
thin: 20 # Thin MCMC samples every `thin` iterations
num_event_time_updates: 5 # Num event and occult updates per sweep of Gibbs MCMC sampler.
num_bursts: 50 # Number of MCMC bursts of `num_burst_samples`
num_burst_samples: 100 # Number of MCMC samples per burst
thin: 1 # Thin MCMC samples every `thin` iterations
ThinPosterior: # Post-process further chain thinning HDF5 -> .pkl.
start: 6000
end: 10000
by: 10
start: 0
end: 5000
by: 1
Geopackage: # covid.tasks.summary_geopackage
base_geopackage: data/UK2019mod_pop.gpkg
......
......@@ -19,13 +19,15 @@ matplotlib = "^3.3.2"
xlrd = "^1.2.0"
tqdm = "^4.50.2"
openpyxl = "^3.0.5"
h5py = "^2.10.0"
gemlib = {git = "http://fhm-chicas-code.lancs.ac.uk/GEM/gemlib.git"}
h5py = "^3.1.0"
gemlib = {git = "http://fhm-chicas-code.lancs.ac.uk/GEM/gemlib.git", branch="develop"}
xarray = "^0.16.1"
seaborn = "^0.11.0"
ruffus = "^2.8.4"
tensorflow = "^2.4.0"
jedi = "^0.17.2"
psycopg2 = "^2.8.6"
pymongo = "^3.11.3"
xarray-mongodb = "^0.2.1"
[tool.poetry.dev-dependencies]
ipython = "^7.18.1"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment