Commit 7adf0699 authored by Chris Jewell

Implemented kernel-ised Gibbs sampler

Changes:

1. Implemented GibbsStep and GibbsKernel classes
2. Modified mcmc.sample function to use Gibbs sampler
3. Fixed bugs in event_time_mh.py and occult_proposal.py (edge cases where tf.gather
indices overshot the bounds of the data structures; not triggered on a GPU but raised
as errors on a CPU).
parent 7741981b
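Note: the new covid/impl/gibbs.py module itself is not shown in this diff. As orientation, the sketch below illustrates the interface implied by the call sites further down: GibbsStep pairs an inner TFP kernel with the index of the state part it updates, and GibbsKernel sweeps over those steps in turn. This is a minimal illustration, not the actual implementation; in particular, the real kernel must also recondition each inner kernel's target log-density on the current values of the other state parts and carry the cached target_log_prob between steps (the job done by the removed get_tlp/put_tlp helpers below).

    import tensorflow_probability as tfp


    class GibbsStep:
        """Pairs an inner kernel with the index of the state part it updates."""

        def __init__(self, state_part_idx, inner_kernel, name=None):
            self.state_part_idx = state_part_idx
            self.inner_kernel = inner_kernel
            self.name = name


    class GibbsKernel(tfp.mcmc.TransitionKernel):
        """Applies each GibbsStep in turn, one state part at a time."""

        def __init__(self, steps, name=None):
            self._steps = steps
            self._name = name

        @property
        def is_calibrated(self):
            # Each inner kernel applies its own Metropolis-Hastings correction.
            return True

        def one_step(self, current_state, previous_kernel_results):
            state = list(current_state)
            results = []
            for step, prev_results in zip(self._steps, previous_kernel_results):
                idx = step.state_part_idx
                # The real implementation must rebuild the inner kernel's
                # target_log_prob_fn here so it closes over the current values
                # of the other state parts, e.g. for the xi step:
                #     lambda xi: logp(state[0], xi, state[2], state[3])
                state[idx], new_results = step.inner_kernel.one_step(
                    state[idx], prev_results
                )
                results.append(new_results)
            return state, results

        def bootstrap_results(self, init_state):
            return [
                step.inner_kernel.bootstrap_results(init_state[step.state_part_idx])
                for step in self._steps
            ]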
@@ -82,7 +82,6 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
         :param seed: a random seed
         :param name: the name of the update step
         """
-        self._target_log_prob_fn = target_log_prob_fn
         self._seed_stream = SeedStream(seed, salt="UncalibratedEventTimesUpdate")
         self._name = name
         self._parameters = dict(
@@ -177,7 +176,7 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
                 n_move=move["x_star"],
             )
-            next_target_log_prob = self._target_log_prob_fn(next_state)
+            next_target_log_prob = self.target_log_prob_fn(next_state)
             # Calculate proposal mass ratio
             rev_move = _reverse_move(move.copy())
...
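The removal of the duplicate self._target_log_prob_fn attribute above (and the matching change in the occult update below) follows the usual TFP kernel convention: constructor arguments live in self._parameters and are exposed as read-only properties, so call sites use self.target_log_prob_fn instead. A sketch of the assumed property, in the style of TFP's own kernels:

    @property
    def target_log_prob_fn(self):
        return self._parameters["target_log_prob_fn"]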
@@ -58,7 +58,6 @@ class UncalibratedOccultUpdate(tfp.mcmc.TransitionKernel):
         :param seed: a random seed
         :param name: the name of the update step
         """
-        self._target_log_prob_fn = target_log_prob_fn
         self._seed_stream = SeedStream(seed, salt="UncalibratedOccultUpdate")
         self._name = name
         self._parameters = dict(
@@ -107,7 +106,7 @@ class UncalibratedOccultUpdate(tfp.mcmc.TransitionKernel):
         """
         with tf.name_scope("occult_rw/onestep"):
-            def true_fn():
+            def add_occult_fn():
                 with tf.name_scope("true_fn"):
                     proposal = AddOccultProposal(
                         events=current_events,
@@ -128,7 +127,7 @@ class UncalibratedOccultUpdate(tfp.mcmc.TransitionKernel):
                     log_acceptance_correction = q_rev - q_fwd
                     return update, next_state, log_acceptance_correction

-            def false_fn():
+            def del_occult_fn():
                 with tf.name_scope("false_fn"):
                     proposal = DelOccultProposal(current_events, self.tx_topology)
                     update = proposal.sample()
@@ -152,7 +151,13 @@ class UncalibratedOccultUpdate(tfp.mcmc.TransitionKernel):
             u = tfd.Uniform().sample()
             delta, next_state, log_acceptance_correction = tf.cond(
-                u < 0.5, true_fn, false_fn
+                (u < 0.5)
+                & (
+                    tf.math.count_nonzero(current_events[..., self.tx_topology.target])
+                    > 0
+                ),
+                del_occult_fn,
+                add_occult_fn,
             )
             # tf.debugging.assert_non_negative(
             #     next_state, message="Negative occults occurred"
...
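The reworked tf.cond above guards the delete branch so that a deletion is only proposed when at least one occult event exists; otherwise an addition is proposed. A toy illustration of the same guard pattern (names here are illustrative, not from the codebase):

    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions

    events = tf.constant([0, 2, 0, 1])
    u = tfd.Uniform().sample()
    # Delete only if the coin flip says so AND there is something to delete.
    branch = tf.cond(
        (u < 0.5) & (tf.math.count_nonzero(events) > 0),
        lambda: tf.constant("delete"),
        lambda: tf.constant("add"),
    )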
@@ -23,7 +23,7 @@ def AddOccultProposal(events, n_max, t_range=None, dtype=tf.int32, name=None):
     def x_star():
         """Draw num to add"""
-        return UniformInteger(low=[0], high=[n_max + 1], dtype=dtype)
+        return UniformInteger(low=[1], high=[n_max + 1], dtype=dtype)

     return tfd.JointDistributionNamed(dict(m=m, t=t, x_star=x_star), name=name)
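Changing low from [0] to [1] means x_star, the number of events to add, is now drawn from {1, ..., n_max} (assuming the usual exclusive upper bound), so the proposal can no longer be a no-op. UniformInteger is the project's own distribution; the support change can be mimicked with stock TensorFlow (n_max here is a placeholder):

    import tensorflow as tf

    n_max = 10
    old_draw = tf.random.uniform([], minval=0, maxval=n_max + 1, dtype=tf.int32)  # could be 0
    new_draw = tf.random.uniform([], minval=1, maxval=n_max + 1, dtype=tf.int32)  # at least 1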
...
@@ -19,6 +19,7 @@ from covid.util import sanitise_parameter, sanitise_settings, impute_previous_cases
 from covid.impl.mcmc import UncalibratedLogRandomWalk, random_walk_mvnorm_fn
 from covid.impl.event_time_mh import UncalibratedEventTimesUpdate
 from covid.impl.occult_events_mh import UncalibratedOccultUpdate
+from covid.impl.gibbs import GibbsKernel, GibbsStep

 ###########
 # TF Bits #
@@ -85,7 +86,7 @@ def logp(theta, xi, events, occult_events):
     p["beta2"] = tf.convert_to_tensor(theta[1], dtype=DTYPE)
     p["gamma"] = tf.convert_to_tensor(theta[2], dtype=DTYPE)
     p["xi"] = tf.convert_to_tensor(xi, dtype=DTYPE)
+    print("XI: ", p["xi"])
     beta1_logp = tfd.Gamma(
         concentration=tf.constant(1.0, dtype=DTYPE), rate=tf.constant(1.0, dtype=DTYPE)
     ).log_prob(p["beta1"])
@@ -114,32 +115,33 @@ def logp(theta, xi, events, occult_events):
 # Pavel's suggestion for a Gibbs kernel requires
 # kernel factory functions.
 def make_theta_kernel(scale, bounded_convergence):
-    def kernel_func(logp):
-        return tfp.mcmc.MetropolisHastings(
+    return GibbsStep(
+        0,
+        tfp.mcmc.MetropolisHastings(
             inner_kernel=UncalibratedLogRandomWalk(
                 target_log_prob_fn=logp,
                 new_state_fn=random_walk_mvnorm_fn(scale, p_u=bounded_convergence),
-            ),
-            name="theta_update",
-        )
-    return kernel_func
+            )
+        ),
+        name="update_theta",
+    )


 def make_xi_kernel(scale, bounded_convergence):
-    def kernel_func(logp):
-        return tfp.mcmc.RandomWalkMetropolis(
+    return GibbsStep(
+        1,
+        tfp.mcmc.RandomWalkMetropolis(
             target_log_prob_fn=logp,
             new_state_fn=random_walk_mvnorm_fn(scale, p_u=bounded_convergence),
-            name="xi_update",
-        )
-    return kernel_func
+        ),
+        name="xi_update",
+    )
 def make_events_step(target_event_id, prev_event_id=None, next_event_id=None):
-    def kernel_func(logp):
-        return tfp.mcmc.MetropolisHastings(
+    return GibbsStep(
+        2,
+        tfp.mcmc.MetropolisHastings(
             inner_kernel=UncalibratedEventTimesUpdate(
                 target_log_prob_fn=logp,
                 target_event_id=target_event_id,
@@ -149,26 +151,25 @@ def make_events_step(target_event_id, prev_event_id=None, next_event_id=None):
                 dmax=config["mcmc"]["dmax"],
                 mmax=config["mcmc"]["m"],
                 nmax=config["mcmc"]["nmax"],
-            ),
-            name="event_update",
-        )
-    return kernel_func
+            )
+        ),
+        name="event_update",
+    )
 def make_occults_step(target_event_id):
-    def kernel_func(logp):
-        return tfp.mcmc.MetropolisHastings(
+    return GibbsStep(
+        3,
+        tfp.mcmc.MetropolisHastings(
             inner_kernel=UncalibratedOccultUpdate(
                 target_log_prob_fn=logp,
                 target_event_id=target_event_id,
                 nmax=config["mcmc"]["occult_nmax"],
-                t_range=[se_events.shape[1] - 21, se_events.shape[1]],
-            ),
-            name="occult_update",
-        )
-    return kernel_func
+                t_range=(se_events.shape[1] - 22, se_events.shape[1] - 1),
+            )
+        ),
+        name="occult_update",
+    )


 def is_accepted(result):
@@ -177,152 +178,44 @@ def is_accepted(result):
     return is_accepted(result.inner_results)


-def trace_results_fn(results):
-    log_prob = results.proposed_results.target_log_prob
-    accepted = is_accepted(results)
-    q_ratio = results.proposed_results.log_acceptance_correction
-    if hasattr(results.proposed_results, "extra"):
-        proposed = tf.cast(results.proposed_results.extra, log_prob.dtype)
-        return tf.concat([[log_prob], [accepted], [q_ratio], proposed], axis=0)
-    return tf.concat([[log_prob], [accepted], [q_ratio]], axis=0)
-
-
-def get_tlp(results):
-    return results.accepted_results.target_log_prob
-
-
-def put_tlp(results, target_log_prob):
-    accepted_results = results.accepted_results._replace(
-        target_log_prob=target_log_prob
-    )
-    return results._replace(accepted_results=accepted_results)
-
-
-def invoke_one_step(kernel, state, previous_results, target_log_prob):
-    current_results = put_tlp(previous_results, target_log_prob)
-    new_state, new_results = kernel.one_step(state, current_results)
-    return new_state, new_results, get_tlp(new_results)
+def trace_results_fn(_, results):
+    """Returns log_prob, accepted, q_ratio"""
+
+    def f(result):
+        log_prob = result.proposed_results.target_log_prob
+        accepted = is_accepted(result)
+        q_ratio = result.proposed_results.log_acceptance_correction
+        if hasattr(result.proposed_results, "extra"):
+            proposed = tf.cast(result.proposed_results.extra, log_prob.dtype)
+            return tf.concat([[log_prob], [accepted], [q_ratio], proposed], axis=0)
+        return tf.concat([[log_prob], [accepted], [q_ratio]], axis=0)
+
+    return [f(result) for result in results]


 @tf.function(autograph=False, experimental_compile=True)
-def sample(n_samples, init_state, theta_scale, xi_scale, num_event_updates):
+def sample(n_samples, init_state, scale_theta, scale_xi, num_event_updates):
     with tf.name_scope("main_mcmc_sample_loop"):

         init_state = init_state.copy()

-        theta_func = make_theta_kernel(theta_scale, 0.0)
-        xi_func = make_xi_kernel(xi_scale, 0.0)
-        se_func = make_events_step(0, None, 1)
-        ei_func = make_events_step(1, 0, 2)
-        se_occult = make_occults_step(0)
-        ei_occult = make_occults_step(1)
-
         # Based on Gibbs idea posted by Pavel Sountsov
         # https://github.com/tensorflow/probability/issues/495
-        results = [
-            theta_func(
-                lambda p: logp(p, init_state[1], init_state[2], init_state[3])
-            ).bootstrap_results(init_state[0]),
-            xi_func(
-                lambda p: logp(init_state[0], p, init_state[2], init_state[3])
-            ).bootstrap_results(init_state[1]),
-            se_func(
-                lambda s: logp(init_state[0], init_state[1], s, init_state[3])
-            ).bootstrap_results(init_state[2]),
-            ei_func(
-                lambda s: logp(init_state[0], init_state[1], s, init_state[3])
-            ).bootstrap_results(init_state[2]),
-            se_occult(
-                lambda s: logp(init_state[0], init_state[1], init_state[2], s)
-            ).bootstrap_results(init_state[3]),
-            ei_occult(
-                lambda s: logp(init_state[0], init_state[1], init_state[2], s)
-            ).bootstrap_results(init_state[3]),
-        ]
-
-        samples_arr = [tf.TensorArray(s.dtype, size=n_samples) for s in init_state]
-        results_arr = [
-            tf.TensorArray(DTYPE, size=n_samples) for r in range(len(results))
-        ]
-
-        def body(i, state, results, target_log_prob, sample_accum, results_accum):
-            # Parameters
-            def theta_logp(par_state):
-                state[0] = par_state  # close over state from outer scope
-                return logp(*state)
-
-            state[0], results[0], target_log_prob = invoke_one_step(
-                theta_func(theta_logp), state[0], results[0], target_log_prob,
-            )
-
-            def xi_logp(xi_state):
-                state[1] = xi_state
-                return logp(*state)
-
-            state[1], results[1], target_log_prob = invoke_one_step(
-                xi_func(xi_logp), state[1], results[1], target_log_prob,
-            )
-
-            def infec_body(j, state, results, target_log_prob):
-                def state_logp(event_state):
-                    state[2] = event_state
-                    return logp(*state)
-
-                def occult_logp(occult_state):
-                    state[3] = occult_state
-                    return logp(*state)
-
-                state[2], results[2], target_log_prob = invoke_one_step(
-                    se_func(state_logp), state[2], results[2], target_log_prob
-                )
-
-                state[2], results[3], target_log_prob = invoke_one_step(
-                    ei_func(state_logp), state[2], results[3], target_log_prob
-                )
-
-                state[3], results[4], target_log_prob = invoke_one_step(
-                    se_occult(occult_logp), state[3], results[4], target_log_prob
-                )
-
-                state[3], results[5], target_log_prob = invoke_one_step(
-                    ei_occult(occult_logp), state[3], results[5], target_log_prob
-                )
-                j += 1
-                return j, state, results, target_log_prob
-
-            def infec_cond(j, state, results, target_log_prob):
-                return j < num_event_updates
-
-            _, state, results, target_log_prob = tf.while_loop(
-                infec_cond,
-                infec_body,
-                loop_vars=[tf.constant(0, tf.int32), state, results, target_log_prob],
-            )
-
-            init_state = init_state.copy()
-            sample_accum = [sample_accum[k].write(i, s) for k, s in enumerate(state)]
-            results_accum = [
-                results_accum[k].write(i, trace_results_fn(r))
-                for k, r in enumerate(results)
-            ]
-            return i + 1, state, results, target_log_prob, sample_accum, results_accum
-
-        def cond(i, *_):
-            return i < n_samples
-
-        _1, _2, _3, target_log_prob, samples, results = tf.while_loop(
-            cond=cond,
-            body=body,
-            loop_vars=[
-                0,
-                init_state,
-                results,
-                logp(*init_state),
-                samples_arr,
-                results_arr,
-            ],
-        )
-
-        return [s.stack() for s in samples], [r.stack() for r in results]
+        kernel = GibbsKernel(
+            [
+                make_theta_kernel(theta_scale, 0.0),
+                make_xi_kernel(xi_scale, 0.0),
+                make_events_step(0, None, 1),
+                make_events_step(1, 0, 2),
+                make_occults_step(0),
+                make_occults_step(1),
+            ],
+            name="gibbs_kernel",
+        )
+        samples, results = tfp.mcmc.sample_chain(
+            n_samples, init_state, kernel=kernel, trace_fn=trace_results_fn
+        )
+
+        return samples, results
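With the Gibbs kernel in place, the hand-rolled tf.while_loop above reduces to a single tfp.mcmc.sample_chain call, and trace_results_fn packs [log_prob, accepted, q_ratio, *extras] per Gibbs step. A sketch of summarising one step's trace after a run (assumes results[0] has at least three columns):

    log_prob, accepted, q_ratio = tf.unstack(results[0][:, :3], axis=-1)
    theta_acceptance_rate = tf.reduce_mean(accepted)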
##################
@@ -415,8 +308,8 @@ for i in tqdm.tqdm(range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES):
     samples, results = sample(
         NUM_BURST_SAMPLES,
         init_state=current_state,
-        theta_scale=theta_scale,
-        xi_scale=xi_scale,
+        scale_theta=theta_scale,
+        scale_xi=xi_scale,
         num_event_updates=tf.constant(NUM_EVENT_TIME_UPDATES, tf.int32),
     )
     current_state = [s[-1] for s in samples]
...