Commit e1646659 authored by Chris Jewell

Rationalised forwarding of target_log_prob in Gibbs sampler.

parent 0291aa34
from pprint import pprint
from collections import namedtuple
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
@@ -7,12 +6,16 @@ from tensorflow_probability.python.util import SeedStream
from covid import config
from covid.impl.event_time_proposal import TransitionTopology, FilteredEventTimeProposal
from covid.impl.mcmc import KernelResults
tfd = tfp.distributions
DTYPE = config.floatX
EventTimesKernelResults = namedtuple(
"KernelResults", ("log_acceptance_correction", "target_log_prob", "extra")
)
def _is_within(x, low, high):
"""Returns true if low <= x < high"""
return tf.logical_and(tf.less_equal(low, x), tf.less(x, high))
@@ -304,7 +307,7 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
return [
next_state,
KernelResults(
EventTimesKernelResults(
log_acceptance_correction=log_acceptance_correction,
target_log_prob=next_target_log_prob,
extra=tf.cast(x_star_results, current_events.dtype),
@@ -315,7 +318,7 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
with tf.name_scope("uncalibrated_event_times_rw/bootstrap_results"):
init_state = tf.convert_to_tensor(init_state, dtype=DTYPE)
init_target_log_prob = self.target_log_prob_fn(init_state)
return KernelResults(
return EventTimesKernelResults(
log_acceptance_correction=tf.constant(0.0, dtype=DTYPE),
target_log_prob=init_target_log_prob,
extra=tf.zeros(init_state.shape[-3], dtype=DTYPE),
@@ -4,6 +4,9 @@ from collections import namedtuple
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.mcmc.internal import util as mcmc_util
from tensorflow_probability.python.mcmc.random_walk_metropolis import (
UncalibratedRandomWalkResults,
)
import covid.config
@@ -12,20 +15,6 @@ DTYPE = covid.config.floatX
tfd = tfp.distributions
# The kernel result NamedTuple mechanism in TFP is too restrictive for our use here
# since the tfp MetropolisHastings class enforces that
# the types of the previous and current results are the same.
#
# See https://github.com/tensorflow/probability/blob/f051e03dd3cc847d31061803c2b31c564562a993/tensorflow_probability/python/mcmc/metropolis_hastings.py#L233)
#
# In order to trace the internals of the samplers used, we need an 'extra'
# field in the Results structure, so we implement and use our own Results tuple.
KernelResults = namedtuple(
"KernelResults", ("log_acceptance_correction", "target_log_prob", "extra")
)
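The comment block removed above is what made the local KernelResults tuple necessary: TFP's MetropolisHastings selects element-wise between the proposed and the previously accepted results, which only works when both carry the same namedtuple structure. With each Gibbs block wrapping its own kernel, the parameter block can now return the stock UncalibratedRandomWalkResults while the event-time kernels keep their own tuple. A toy illustration of the structural constraint (stand-in tuples, not TFP's actual code path):

import collections
import tensorflow as tf

# Element-wise selection between two differently shaped result tuples fails,
# which is essentially what MetropolisHastings does when choosing between the
# proposed results and the previously accepted results.
Plain = collections.namedtuple("Plain", ("target_log_prob",))
WithExtra = collections.namedtuple("WithExtra", ("target_log_prob", "extra"))

try:
    tf.nest.map_structure(
        lambda p, q: tf.where(True, p, q),
        Plain(tf.constant(0.0)),
        WithExtra(tf.constant(0.0), tf.constant(1.0)),
    )
except (TypeError, ValueError) as err:
    print("structure mismatch:", err)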
def random_walk_mvnorm_fn(covariance, p_u=0.95, name=None):
"""Returns callable that adds Multivariate Normal noise to the input
:param covariance: the covariance matrix of the mvnorm proposal
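The body of random_walk_mvnorm_fn is elided from this hunk. As a rough guide to what the covariance and p_u arguments suggest, here is a minimal sketch of a Roberts-and-Rosenthal-style mixture proposal in the shape TFP expects of a new_state_fn; it is an assumption, not the repository's implementation:

import numpy as np
import tensorflow as tf

def mixture_mvnorm_proposal(covariance, p_u=0.95, name=None):
    # With probability p_u perturb using the scaled empirical covariance,
    # otherwise take a small isotropic step (hypothetical sketch only).
    covariance = tf.convert_to_tensor(covariance)
    dtype = covariance.dtype
    d = covariance.shape[-1]
    chol_big = tf.linalg.cholesky(tf.constant(2.38 ** 2 / d, dtype) * covariance)
    chol_small = tf.constant(0.1 / np.sqrt(d), dtype) * tf.eye(d, dtype=dtype)

    def _fn(state_parts, seed):
        with tf.name_scope(name or "mixture_mvnorm_proposal"):
            chol = tf.cond(
                tf.random.uniform([], dtype=dtype, seed=seed) < p_u,
                lambda: chol_big,
                lambda: chol_small,
            )
            return [
                s + tf.linalg.matvec(
                    chol, tf.random.normal(tf.shape(s), dtype=dtype, seed=seed)
                )
                for s in state_parts
            ]

    return _fn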
@@ -96,10 +85,9 @@ class UncalibratedLogRandomWalk(tfp.mcmc.UncalibratedRandomWalk):
return [
maybe_flatten(next_state_parts),
KernelResults(
UncalibratedRandomWalkResults(
log_acceptance_correction=log_acceptance_correction,
target_log_prob=next_target_log_prob,
extra=tf.zeros(149, dtype=DTYPE),
),
]
@@ -125,7 +125,7 @@ def make_events_step(target_event_id, prev_event_id=None, next_event_id=None):
prev_event_id=prev_event_id,
next_event_id=next_event_id,
dmax=2,
mmax=2,
mmax=1,
nmax=10,
initial_state=state_init,
)
@@ -144,11 +144,21 @@ def trace_results_fn(results):
log_prob = results.proposed_results.target_log_prob
accepted = is_accepted(results)
q_ratio = results.proposed_results.log_acceptance_correction
proposed = results.proposed_results.extra
return tf.concat([[log_prob], [accepted], [q_ratio], proposed], axis=0)
if hasattr(results.proposed_results, "extra"):
proposed = results.proposed_results.extra
return tf.concat([[log_prob], [accepted], [q_ratio], proposed], axis=0)
else:
return tf.concat([[log_prob], [accepted], [q_ratio]], axis=0)
def forward_results(prev_results, next_results):
accepted_results = next_results.accepted_results._replace(
target_log_prob=prev_results.accepted_results.target_log_prob
)
return next_results._replace(accepted_results=accepted_results)
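forward_results is the substance of this commit: each Gibbs block caches the target_log_prob it last accepted, so before a block takes its step that cached value is overwritten with the one from the block that has just moved; otherwise the kernel would compute its acceptance ratio against a stale log density. A toy run with stand-in namedtuples (TFP's real MetropolisHastingsKernelResults carries more fields, but only accepted_results.target_log_prob matters here):

import collections

Accepted = collections.namedtuple("Accepted", ("target_log_prob",))
MHResults = collections.namedtuple("MHResults", ("accepted_results", "is_accepted"))

prev_results = MHResults(Accepted(target_log_prob=-100.0), is_accepted=True)   # block that just moved
next_results = MHResults(Accepted(target_log_prob=-120.0), is_accepted=False)  # stale cache for the next block

refreshed = forward_results(prev_results, next_results)
print(refreshed.accepted_results.target_log_prob)  # -100.0: the stale value has been replaced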
@tf.function # (experimental_compile=True)
# @tf.function(autograph=False, experimental_compile=True)
def sample(n_samples, init_state, par_scale):
with tf.name_scope("main_mcmc_sample_loop"):
init_state = init_state.copy()
@@ -158,35 +168,49 @@ def sample(n_samples, init_state, par_scale):
# Based on Gibbs idea posted by Pavel Sountsov
# https://github.com/tensorflow/probability/issues/495
results = ei_func(lambda s: logp(init_state[0], s)).bootstrap_results(
par_results = par_func(lambda p: logp(p, init_state[1])).bootstrap_results(
init_state[0]
)
se_results = se_func(lambda s: logp(init_state[0], s)).bootstrap_results(
init_state[1]
)
ei_results = ei_func(lambda s: logp(init_state[0], s)).bootstrap_results(
init_state[1]
)
results = [par_results, se_results, ei_results]
samples_arr = [tf.TensorArray(s.dtype, size=n_samples) for s in init_state]
results_arr = [tf.TensorArray(DTYPE, size=n_samples) for r in range(3)]
def body(i, state, prev_results, samples, results):
def body(i, state, results, sample_accum, results_accum):
# Parameters
def par_logp(par_state):
state[0] = par_state # close over state from outer scope
return logp(*state)
state[0], par_results = par_func(par_logp).one_step(state[0], prev_results)
state[0], results[0] = par_func(par_logp).one_step(
state[0], forward_results(results[2], results[0])
)
print("par ", is_accepted(results[0]))
# States
def state_logp(event_state):
state[1] = event_state
return logp(*state)
state[1], se_results = se_func(state_logp).one_step(state[1], par_results)
state[1], ei_results = ei_func(state_logp).one_step(state[1], se_results)
samples = [samples[k].write(i, s) for k, s in enumerate(state)]
results = [
results[k].write(i, trace_results_fn(r))
for k, r in enumerate([par_results, se_results, ei_results])
state[1], results[1] = se_func(state_logp).one_step(
state[1], forward_results(results[0], results[1])
)
print("se ", is_accepted(results[1]))
state[1], results[2] = ei_func(state_logp).one_step(
state[1], forward_results(results[1], results[2])
)
print("ei ", is_accepted(results[2]))
sample_accum = [sample_accum[k].write(i, s) for k, s in enumerate(state)]
results_accum = [
results_accum[k].write(i, trace_results_fn(r))
for k, r in enumerate(results)
]
return i + 1, state, ei_results, samples, results
return i + 1, state, results, sample_accum, results_accum
def cond(i, _1, _2, _3, _4):
return i < n_samples
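The call that drives cond and body sits outside this hunk; presumably it is the usual tf.while_loop pattern, roughly as below (the variable names on the left and the stack() post-processing are assumptions, not code from this file):

# Hypothetical driver for the Gibbs sweep defined by cond/body above.
_, final_state, final_results, samples_arr, results_arr = tf.while_loop(
    cond,
    body,
    loop_vars=(tf.constant(0), init_state, results, samples_arr, results_arr),
)
# Collapse the per-iteration TensorArrays into dense tensors for tracing.
samples = [s.stack() for s in samples_arr]
traces = [r.stack() for r in results_arr]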
......@@ -208,8 +232,10 @@ def sample(n_samples, init_state, par_scale):
NUM_LOOP_ITERATIONS = 1000
NUM_LOOP_SAMPLES = 100
# Initial States
# RNG stuff
tf.random.set_seed(2)
# Initial state. NB [M, T, X] layout for events.
current_state = [
np.array([0.6, 0.25], dtype=DTYPE),
tf.transpose(tf.stack([se_events, ei_events, ir_events], axis=-1), perm=(1, 0, 2)),
@@ -226,13 +252,17 @@ par_samples = posterior.create_dataset(
)
se_samples = posterior.create_dataset("samples/events", event_size, dtype=DTYPE)
par_results = posterior.create_dataset(
"acceptance/parameter", (NUM_LOOP_ITERATIONS * NUM_LOOP_SAMPLES, 152), dtype=DTYPE,
"acceptance/parameter", (NUM_LOOP_ITERATIONS * NUM_LOOP_SAMPLES, 3), dtype=DTYPE,
)
se_results = posterior.create_dataset(
"acceptance/S->E", (NUM_LOOP_ITERATIONS * NUM_LOOP_SAMPLES, 152), dtype=DTYPE
"acceptance/S->E",
(NUM_LOOP_ITERATIONS * NUM_LOOP_SAMPLES, 3 + model.N.shape[0]),
dtype=DTYPE,
)
ei_results = posterior.create_dataset(
"acceptance/E->I", (NUM_LOOP_ITERATIONS * NUM_LOOP_SAMPLES, 152), dtype=DTYPE
"acceptance/E->I",
(NUM_LOOP_ITERATIONS * NUM_LOOP_SAMPLES, 3 + model.N.shape[0]),
dtype=DTYPE,
)
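The old hard-coded width of 152 was evidently the 3 trace scalars plus the 149-element extra; after this change the parameter trace keeps only the 3 scalars, and the event traces size their extra from the number of metapopulations. A quick check of that reading (an assumption, not code from the file):

# Assumed row layout from trace_results_fn: [target_log_prob, is_accepted, log_acceptance_correction, *extra]
n_meta = model.N.shape[0]
assert par_results.shape[1] == 3
assert se_results.shape[1] == 3 + n_meta
assert ei_results.shape[1] == 3 + n_meta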
@@ -243,8 +273,8 @@ par_scale = tf.linalg.diag(
# We loop over successive calls to sample because we have to dump results
# to disc, or else we end up OOM (even on a 32GB system).
# with tf.profiler.experimental.Profile("/tmp/tf_logdir"):
for i in tqdm.tqdm(range(NUM_LOOP_ITERATIONS), unit_scale=NUM_LOOP_SAMPLES):
# with tf.profiler.experimental.Profile("/tmp/tf_logdir"):
samples, results = sample(
NUM_LOOP_SAMPLES, init_state=current_state, par_scale=par_scale
)
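The write-to-disk step that motivates this outer loop is elided below; a hypothetical version of it, sitting inside the same for loop, assuming the h5py datasets created above and that sample() returns its traces in the order [parameters, S->E, E->I]:

# Hypothetical dump of one outer iteration into the HDF5 datasets (real code elided).
idx = slice(i * NUM_LOOP_SAMPLES, (i + 1) * NUM_LOOP_SAMPLES)
par_samples[idx, ...] = samples[0].numpy()
se_samples[idx, ...] = samples[1].numpy()
par_results[idx, :] = results[0].numpy()
se_results[idx, :] = results[1].numpy()
ei_results[idx, :] = results[2].numpy()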