Commit cb9c71af authored by Chris Jewell's avatar Chris Jewell
Browse files

Removed redundant EventTimesUpdate

This class was a thin wrapper around tfp.mcmc.MetropolisHastings, and was therefore just
extra code cruft.
parent b618a841
......@@ -50,83 +50,6 @@ def _move_events(event_tensor, event_id, m, from_t, to_t, n_move):
return new_state
class EventTimesUpdate(tfp.mcmc.TransitionKernel):
def __init__(
self,
target_log_prob_fn,
target_event_id,
prev_event_id,
next_event_id,
initial_state,
dmax,
mmax,
nmax,
seed=None,
name=None,
):
"""A random walk Metropolis Hastings for event times.
:param target_log_prob_fn: the log density of the target distribution
:param target_event_id: the position in the first dimension of the events tensor that we wish to move
:param prev_event_id: the position of the previous event in the events tensor
:param next_event_id: the position of the next event in the events tensor
:param initial_state: the initial state tensor
:param dmax: maximum distance to move in time
:param mmax: number of metapopulations to move
:param nmax: max number of events to move
:param seed: a random seed
:param name: the name of the update step
"""
self._seed_stream = SeedStream(seed, salt="EventTimesUpdate")
self._impl = tfp.mcmc.MetropolisHastings(
inner_kernel=UncalibratedEventTimesUpdate(
target_log_prob_fn=target_log_prob_fn,
target_event_id=target_event_id,
prev_event_id=prev_event_id,
next_event_id=next_event_id,
dmax=dmax,
mmax=mmax,
nmax=nmax,
initial_state=initial_state,
)
)
self._parameters = self._impl.inner_kernel.parameters.copy()
self._parameters["seed"] = seed
@property
def target_log_prob_fn(self):
return self._impl.inner_kernel.target_log_prob_fn
@property
def name(self):
return self._impl.inner_kernel.name
@property
def parameters(self):
"""Return `dict` of ``__init__`` arguments and their values."""
return self._parameters
@property
def is_calibrated(self):
return True
def one_step(self, current_state, previous_kernel_results):
"""Performs one step of an event times update.
:param current_state: the current state tensor [TxMxX]
:param previous_kernel_results: a named tuple of results.
:returns: (next_state, kernel_results)
"""
with tf.name_scope("EventTimesUpdate/one_step"):
next_state, kernel_results = self._impl.one_step(
current_state, previous_kernel_results
)
return next_state, kernel_results
def bootstrap_results(self, init_state):
with tf.name_scope("EventTimesUpdate/bootstrap_results"):
kernel_results = self._impl.bootstrap_results(init_state)
return kernel_results
def _reverse_move(move):
move["t"] = move["t"] + move["delta_t"]
move["delta_t"] = -move["delta_t"]
......@@ -151,7 +74,7 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
):
"""An uncalibrated random walk for event times.
:param target_log_prob_fn: the log density of the target distribution
:param target_event_id: the position in the first dimension of the events
:param target_event_id: the position in the first dimension of the events
tensor that we wish to move
:param prev_event_id: the position of the previous event in the events tensor
:param next_event_id: the position of the next event in the events tensor
......
......@@ -13,9 +13,8 @@ import yaml
from covid import config
from covid.model import load_data, CovidUKStochastic
from covid.util import sanitise_parameter, sanitise_settings
from covid.impl.util import make_transition_matrix
from covid.impl.mcmc import UncalibratedLogRandomWalk, random_walk_mvnorm_fn
from covid.impl.event_time_mh import EventTimesUpdate
from covid.impl.event_time_mh import UncalibratedEventTimesUpdate
###########
......@@ -129,25 +128,31 @@ def make_parameter_kernel(scale, bounded_convergence):
def make_events_step(target_event_id, prev_event_id=None, next_event_id=None):
def kernel_func(logp):
return EventTimesUpdate(
target_log_prob_fn=logp,
target_event_id=target_event_id,
prev_event_id=prev_event_id,
next_event_id=next_event_id,
dmax=config["mcmc"]["dmax"],
mmax=config["mcmc"]["m"],
nmax=config["mcmc"]["nmax"],
initial_state=state_init,
return tfp.mcmc.MetropolisHastings(
inner_kernel=UncalibratedEventTimesUpdate(
target_log_prob_fn=logp,
target_event_id=target_event_id,
prev_event_id=prev_event_id,
next_event_id=next_event_id,
initial_state=state_init,
dmax=config["mcmc"]["dmax"],
mmax=config["mcmc"]["m"],
nmax=config["mcmc"]["nmax"],
),
name="event_update",
)
return kernel_func
def make_occults_step():
pass
def is_accepted(result):
if hasattr(result, "is_accepted"):
return tf.cast(result.is_accepted, DTYPE)
else:
return is_accepted(result.inner_results)
return is_accepted(result.inner_results)
def trace_results_fn(results):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment