Commit e72cfe76 authored by Chris Jewell's avatar Chris Jewell
Browse files

Changes:

1. Formatting -- black style!
2. After profiling, implemented a better workaround for the
tfd.Categorical bug.  Replaced with an inherited Categorical2 class,
with overloaded log_prob function.
3. Fixed a bug in EventTimeProposal where tfd.FiniteDiscrete was
being used on a non-strictly-increasing outcomes vector.
4. Increased number of meta-populations to 149 to mirror UK UTLA regions.
parent 705ce40b
......@@ -6,8 +6,7 @@ import tensorflow_probability as tfp
from tensorflow_probability.python.util import SeedStream
from covid import config
from covid.impl.event_time_proposal import TransitionTopology, \
FilteredEventTimeProposal
from covid.impl.event_time_proposal import TransitionTopology, FilteredEventTimeProposal
from covid.impl.mcmc import KernelResults
tfd = tfp.distributions
......@@ -20,12 +19,12 @@ def _is_within(x, low, high):
def _nonzero_rows(m):
    """Return a mask of rows of ``m`` with a non-zero sum.

    :param m: a [..., N] tensor.
    :return: a tensor of ``m.dtype`` with shape ``m.shape[:-1]``, 1 where the
        row-sum along the last axis is strictly positive, 0 otherwise.
    """
    # NOTE(review): diff dump contained both the pre- and post-black lines;
    # this is the post-commit version (0.0 literal).
    return tf.cast(tf.reduce_sum(m, axis=-1) > 0.0, m.dtype)
def _max_free_events(events, initial_state,
target_t, target_id,
constraint_t, constraint_id):
def _max_free_events(
events, initial_state, target_t, target_id, constraint_t, constraint_id
):
"""Returns the maximum number of free events to move in target_events constrained by
constraining_events.
:param events: a [T, M, X] tensor of transition events
......@@ -42,14 +41,13 @@ def _max_free_events(events, initial_state,
target_cumsum = tf.cumsum(target_events_, axis=0)
constraining_events = tf.gather(events, constraint_id, axis=-1) # TxM
constraining_cumsum = tf.cumsum(constraining_events, axis=0) # TxM
constraining_init_state = tf.gather(initial_state, constraint_id + 1,
axis=-1)
constraining_init_state = tf.gather(initial_state, constraint_id + 1, axis=-1)
n1 = tf.gather(target_cumsum, constraint_t, axis=0)
n2 = tf.gather(constraining_cumsum, constraint_t, axis=0)
free_events = tf.abs(n1 - n2) + constraining_init_state
max_free_events = tf.minimum(free_events,
tf.gather(target_events_, target_t,
axis=0))
max_free_events = tf.minimum(
free_events, tf.gather(target_events_, target_t, axis=0)
)
return max_free_events
# Manual broadcasting of n_events_t is required here so that the XLA
......@@ -58,8 +56,9 @@ def _max_free_events(events, initial_state,
# propagated right through the algorithm, so the return value has known shape.
def false_fn():
n_events_t = tf.gather(events[..., target_id], target_t, axis=0)
return tf.broadcast_to([n_events_t],
[constraint_t.shape[0]] + [n_events_t.shape[0]])
return tf.broadcast_to(
[n_events_t], [constraint_t.shape[0]] + [n_events_t.shape[0]]
)
ret_val = tf.cond(constraint_id != -1, true_fn, false_fn)
return ret_val
......@@ -78,32 +77,32 @@ def _move_events(event_tensor, event_id, m, from_t, to_t, n_move):
:return: the modified event_tensor
"""
# Todo rationalise this -- compute a delta, and add once.
indices = tf.stack([m, # All meta-populations
from_t,
tf.broadcast_to(event_id, m.shape)], axis=-1) # Event
indices = tf.stack(
[m, from_t, tf.broadcast_to(event_id, m.shape)], axis=-1 # All meta-populations
) # Event
# Subtract x_star from the [from_t, :, event_id] row of the state tensor
n_move = tf.cast(n_move, event_tensor.dtype)
new_state = tf.tensor_scatter_nd_sub(event_tensor, indices, n_move)
indices = tf.stack([m,
to_t,
tf.broadcast_to(event_id, m.shape)], axis=-1)
indices = tf.stack([m, to_t, tf.broadcast_to(event_id, m.shape)], axis=-1)
# Add x_star to the [to_t, :, event_id] row of the state tensor
new_state = tf.tensor_scatter_nd_add(new_state, indices, n_move)
return new_state
class EventTimesUpdate(tfp.mcmc.TransitionKernel):
def __init__(self,
target_log_prob_fn,
target_event_id,
prev_event_id,
next_event_id,
initial_state,
dmax,
mmax,
nmax,
seed=None,
name=None):
def __init__(
self,
target_log_prob_fn,
target_event_id,
prev_event_id,
next_event_id,
initial_state,
dmax,
mmax,
nmax,
seed=None,
name=None,
):
"""A random walk Metropolis Hastings for event times.
:param target_log_prob_fn: the log density of the target distribution
:param target_event_id: the position in the first dimension of the events tensor that we wish to move
......@@ -116,7 +115,7 @@ class EventTimesUpdate(tfp.mcmc.TransitionKernel):
:param seed: a random seed
:param name: the name of the update step
"""
self._seed_stream = SeedStream(seed, salt='EventTimesUpdate')
self._seed_stream = SeedStream(seed, salt="EventTimesUpdate")
self._impl = tfp.mcmc.MetropolisHastings(
inner_kernel=UncalibratedEventTimesUpdate(
target_log_prob_fn=target_log_prob_fn,
......@@ -126,9 +125,11 @@ class EventTimesUpdate(tfp.mcmc.TransitionKernel):
dmax=dmax,
mmax=mmax,
nmax=nmax,
initial_state=initial_state))
initial_state=initial_state,
)
)
self._parameters = self._impl.inner_kernel.parameters.copy()
self._parameters['seed'] = seed
self._parameters["seed"] = seed
@property
def target_log_prob_fn(self):
......@@ -153,8 +154,9 @@ class EventTimesUpdate(tfp.mcmc.TransitionKernel):
:param previous_kernel_results: a named tuple of results.
:returns: (next_state, kernel_results)
"""
next_state, kernel_results = self._impl.one_step(current_state,
previous_kernel_results)
next_state, kernel_results = self._impl.one_step(
current_state, previous_kernel_results
)
return next_state, kernel_results
def bootstrap_results(self, init_state):
......@@ -163,30 +165,31 @@ class EventTimesUpdate(tfp.mcmc.TransitionKernel):
def _reverse_move(move):
move['t'] = move['t'] + move['delta_t']
# Todo remove this hack once tfd.Categorical is working correctly.
move['t_'] = tf.one_hot(indices=move['t'],
depth=move['t_'].shape[1],
dtype=move['t_'].dtype)
move['delta_t'] = -move['delta_t']
move["t"] = move["t"] + move["delta_t"]
move["delta_t"] = -move["delta_t"]
return move
class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
def __init__(self,
target_log_prob_fn,
target_event_id,
prev_event_id,
next_event_id,
initial_state,
dmax,
mmax,
nmax,
seed=None,
name=None):
"""UncalibratedEventTimesUpdate"""
def __init__(
self,
target_log_prob_fn,
target_event_id,
prev_event_id,
next_event_id,
initial_state,
dmax,
mmax,
nmax,
seed=None,
name=None,
):
"""An uncalibrated random walk for event times.
:param target_log_prob_fn: the log density of the target distribution
:param target_event_id: the position in the first dimension of the events tensor that we wish to move
:param target_event_id: the position in the first dimension of the events
tensor that we wish to move
:param prev_event_id: the position of the previous event in the events tensor
:param next_event_id: the position of the next event in the events tensor
:param initial_state: the initial state tensor
......@@ -194,8 +197,7 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
:param name: the name of the update step
"""
self._target_log_prob_fn = target_log_prob_fn
self._seed_stream = SeedStream(seed,
salt='UncalibratedEventTimesUpdate')
self._seed_stream = SeedStream(seed, salt="UncalibratedEventTimesUpdate")
self._name = name
self._parameters = dict(
target_log_prob_fn=target_log_prob_fn,
......@@ -207,34 +209,36 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
mmax=mmax,
nmax=nmax,
seed=seed,
name=name)
self.tx_topology = TransitionTopology(prev_event_id, target_event_id,
next_event_id)
self.time_offsets = tf.range(self.parameters['dmax'])
name=name,
)
self.tx_topology = TransitionTopology(
prev_event_id, target_event_id, next_event_id
)
self.time_offsets = tf.range(self.parameters["dmax"])
@property
def target_log_prob_fn(self):
    """Return the target log-density function supplied at construction."""
    return self._parameters["target_log_prob_fn"]
@property
def target_event_id(self):
    """Return the index of the transition being moved."""
    return self._parameters["target_event_id"]
@property
def prev_event_id(self):
    """Return the index of the transition preceding the target."""
    return self._parameters["prev_event_id"]
@property
def next_event_id(self):
    """Return the index of the transition following the target."""
    return self._parameters["next_event_id"]
@property
def seed(self):
    """Return the random seed supplied at construction (may be None)."""
    return self._parameters["seed"]
@property
def name(self):
    """Return the name of this update step (may be None)."""
    return self._parameters["name"]
@property
def parameters(self):
......@@ -254,93 +258,96 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
UncalibratedRandomWalkResults.
:returns: a tuple containing new_state and UncalibratedRandomWalkResults
"""
with tf.name_scope('uncalibrated_event_times_rw/onestep'):
with tf.name_scope("uncalibrated_event_times_rw/onestep"):
current_events = tf.transpose(current_events, perm=(1, 0, 2))
target_events = current_events[..., self.tx_topology.target]
num_times = target_events.shape[1]
proposal = FilteredEventTimeProposal(events=current_events,
initial_state=self.parameters[
'initial_state'],
topology=self.tx_topology,
m_max=self.parameters['mmax'],
d_max=self.parameters['dmax'],
n_max=self.parameters['nmax'],
direction='fwd')
proposal = FilteredEventTimeProposal(
events=current_events,
initial_state=self.parameters["initial_state"],
topology=self.tx_topology,
m_max=self.parameters["mmax"],
d_max=self.parameters["dmax"],
n_max=self.parameters["nmax"],
)
update = proposal.sample()
move = update['move']
to_t = move['t'] + move['delta_t']
q_fwd = proposal.log_prob(update)
tf.debugging.assert_all_finite(q_fwd, "q_fwd is not finite")
move = update["move"]
to_t = move["t"] + move["delta_t"]
# Moves outside the range [0, num_times] are illegal
# Todo: address potential issue in the proposal if
# dmax accesses indices outside this range.
def true_fn():
next_state = _move_events(event_tensor=current_events,
event_id=self.tx_topology.target,
m=update['m'],
from_t=move['t'],
to_t=to_t,
n_move=move['x_star'])
next_state = _move_events(
event_tensor=current_events,
event_id=self.tx_topology.target,
m=update["m"],
from_t=move["t"],
to_t=to_t,
n_move=move["x_star"],
)
next_state_tr = tf.transpose(next_state, perm=(1, 0, 2))
next_target_log_prob = self._target_log_prob_fn(next_state_tr)
# Calculate proposal mass ratio
q_fwd = proposal.log_prob(update)
rev_move = _reverse_move(move.copy())
rev_update = dict(m=update['m'],
move=rev_move)
Q_rev = FilteredEventTimeProposal(
rev_update = dict(m=update["m"], move=rev_move)
Q_rev = FilteredEventTimeProposal( # pylint: disable-invalid-name
events=next_state,
initial_state=self.parameters[
'initial_state'],
initial_state=self.parameters["initial_state"],
topology=self.tx_topology,
m_max=self.parameters['mmax'],
d_max=self.parameters['dmax'],
n_max=self.parameters[
'nmax'],
direction='rev')
m_max=self.parameters["mmax"],
d_max=self.parameters["dmax"],
n_max=self.parameters["nmax"],
)
q_rev = Q_rev.log_prob(rev_update)
log_acceptance_correction = tf.reduce_sum(q_rev - q_fwd)
return (next_target_log_prob,
log_acceptance_correction,
next_state_tr)
return (next_target_log_prob, log_acceptance_correction, next_state_tr)
def false_fn():
next_target_log_prob = tf.constant(-np.inf,
dtype=current_events.dtype)
log_acceptance_correction = tf.constant(0.0,
dtype=current_events.dtype)
return (next_target_log_prob,
log_acceptance_correction,
tf.transpose(current_events, perm=(1, 0, 2)))
next_target_log_prob = tf.constant(-np.inf, dtype=current_events.dtype)
log_acceptance_correction = tf.constant(0.0, dtype=current_events.dtype)
return (
next_target_log_prob,
log_acceptance_correction,
tf.transpose(current_events, perm=(1, 0, 2)),
)
# Trap out-of-bounds moves that go outside [0, num_times)
(next_target_log_prob,
log_acceptance_correction,
next_state) = tf.cond(
(next_target_log_prob, log_acceptance_correction, next_state) = tf.cond(
tf.reduce_all(_is_within(to_t, 0, num_times)),
true_fn=true_fn,
false_fn=false_fn)
false_fn=false_fn,
)
x_star_results = tf.scatter_nd(update['m'][:, tf.newaxis],
tf.abs(move['x_star']*move['delta_t']),
[current_events.shape[0]])
x_star_results = tf.scatter_nd(
update["m"][:, tf.newaxis],
tf.abs(move["x_star"] * move["delta_t"]),
[current_events.shape[0]],
)
return [next_state,
KernelResults(
log_acceptance_correction=log_acceptance_correction,
target_log_prob=next_target_log_prob,
extra=tf.cast(x_star_results, current_events.dtype)
)]
return [
next_state,
KernelResults(
log_acceptance_correction=log_acceptance_correction,
target_log_prob=next_target_log_prob,
extra=tf.cast(x_star_results, current_events.dtype),
),
]
def bootstrap_results(self, init_state):
with tf.name_scope('uncalibrated_event_times_rw/bootstrap_results'):
with tf.name_scope("uncalibrated_event_times_rw/bootstrap_results"):
init_state = tf.convert_to_tensor(init_state, dtype=DTYPE)
init_target_log_prob = self.target_log_prob_fn(init_state)
return KernelResults(
log_acceptance_correction=tf.constant(0., dtype=DTYPE),
log_acceptance_correction=tf.constant(0.0, dtype=DTYPE),
target_log_prob=init_target_log_prob,
extra=tf.zeros(init_state.shape[-2], dtype=DTYPE)
extra=tf.zeros(init_state.shape[-2], dtype=DTYPE),
)
"""Mechanism for proposing event times to move"""
import collections
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from pprint import pprint
from tensorflow_probability.python.internal import distribution_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.distributions.categorical import (
_broadcast_cat_event_and_params,
)
from covid.impl.UniformInteger import UniformInteger
from covid.impl.KCategorical import KCategorical
tfd = tfp.distributions
# Named triple of transition indices within the events tensor: the
# transition preceding the target, the target transition being moved, and
# the one following it (e.g. s->e, e->i, i->r).  Boundary positions with no
# adjacent transition are represented by None.
TransitionTopology = collections.namedtuple(
    "TransitionTopology", ("prev", "target", "next")
)
def _events_or_inf(events, transition_id):
    """Gather one transition's event series, or +inf for a missing boundary.

    :param events: a [..., X] tensor of event counts, last axis indexing
        transitions.
    :param transition_id: index of the transition in the last axis, or
        ``None`` when the topology has no transition at this position.
    :return: ``events[..., transition_id]``, or a tensor of +inf with shape
        ``events.shape[:-1]`` when ``transition_id`` is None (so the caller's
        bound computation is never the binding constraint at a boundary).
    """
    if transition_id is None:
        return tf.fill(events.shape[:-1], tf.constant(np.inf, dtype=events.dtype))
    return tf.gather(events, transition_id, axis=-1)
def _abscumdiff(events, initial_state, topology, t, delta_t, bound_times,
int_dtype=tf.int32):
def _abscumdiff(
events, initial_state, topology, t, delta_t, bound_times, int_dtype=tf.int32
):
"""Returns the number of free events to move in target_events
bounded by max([N_{target_id}(t)-N_{bound_id}(t)]_{bound_t}).
:param events: a [(M), T, X] tensor of transition events
......@@ -39,16 +41,15 @@ def _abscumdiff(events, initial_state, topology, t, delta_t, bound_times,
# This line prevents negative indices. However, we must have
# a contract that the output of the algorithm is invalid!
bound_times = tf.clip_by_value(bound_times,
clip_value_min=0,
clip_value_max=events.shape[-2])
bound_times = tf.clip_by_value(
bound_times, clip_value_min=0, clip_value_max=events.shape[-2] - 1
)
# Maybe replace with pad to avoid unstack/stack
prev_events = _events_or_inf(events, topology.prev)
target_events = tf.gather(events, topology.target, axis=-1)
next_events = _events_or_inf(events, topology.next)
event_tensor = tf.stack([prev_events, target_events, next_events],
axis=-1)
event_tensor = tf.stack([prev_events, target_events, next_events], axis=-1)
# Compute the absolute cumulative difference between event times
diff = event_tensor[..., 1:] - event_tensor[..., :-1] # [m, T, 2]
......@@ -56,38 +57,44 @@ def _abscumdiff(events, initial_state, topology, t, delta_t, bound_times,
# Create indices into cumdiff [m, d_max, 2]. Last dimension selects
# the bound for either the previous or next event.
indices = tf.stack([
tf.repeat(tf.range(events.shape[0], dtype=int_dtype),
[bound_times.shape[1]]),
tf.reshape(bound_times, [-1]),
tf.repeat(tf.where(delta_t < 0, 0, 1), [bound_times.shape[1]])
], axis=-1)
indices = tf.stack(
[
tf.repeat(
tf.range(events.shape[0], dtype=int_dtype), [bound_times.shape[1]]
),
tf.reshape(bound_times, [-1]),
tf.repeat(tf.where(delta_t < 0, 0, 1), [bound_times.shape[1]]),
],
axis=-1,
)
indices = tf.reshape(indices, [events.shape[-3], bound_times.shape[1], 3])
free_events = tf.gather_nd(cumdiff, indices)
# Add on initial state
indices = tf.stack([tf.range(events.shape[0]),
tf.where(delta_t[:, 0] < 0,
topology.target,
topology.target+1)],
axis=-1)
bound_init_state = tf.gather_nd(initial_state,
indices)
indices = tf.stack(
[
tf.range(events.shape[0]),
tf.where(delta_t[:, 0] < 0, topology.target, topology.target + 1),
],
axis=-1,
)
bound_init_state = tf.gather_nd(initial_state, indices)
free_events += bound_init_state[..., tf.newaxis]
return free_events
class Deterministic2(tfd.Deterministic):
def __init__(self,
loc,
atol=None,
rtol=None,
validate_args=False,
allow_nan_stats=True,
log_prob_dtype=tf.float32,
name='Deterministic'):
def __init__(
self,
loc,
atol=None,
rtol=None,
validate_args=False,
allow_nan_stats=True,
log_prob_dtype=tf.float32,
name="Deterministic",
):
parameters = dict(locals())
super(Deterministic2, self).__init__(
loc,
......@@ -95,7 +102,7 @@ class Deterministic2(tfd.Deterministic):
rtol=rtol,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
name=name
name=name,
)
self.log_prob_dtype = log_prob_dtype
......@@ -103,41 +110,61 @@ class Deterministic2(tfd.Deterministic):
return tf.constant(1, dtype=self.log_prob_dtype)
def EventTimeProposal(events, initial_state, topology, d_max, n_max, direction,
dtype=tf.int32, name=None):
class Categorical2(tfd.Categorical):
    """Categorical distribution with a corrected ``log_prob``.

    Overrides the faulty log_prob in tfd.Categorical, see
    https://github.com/tensorflow/tensorflow/issues/40606
    """

    def _log_prob(self, k):
        # Raw (possibly unnormalised) logits of the distribution.
        logits = self.logits_parameter()
        if self.validate_args:
            # Check k holds integers representable in this distribution's dtype.
            k = distribution_util.embed_check_integer_casting_closed(
                k, target_dtype=self.dtype
            )
        # Broadcast event indices and logits to a common batch shape.
        k, logits = _broadcast_cat_event_and_params(
            k, logits, base_dtype=dtype_util.base_dtype(self.dtype)
        )
        # Normalise explicitly: log(softmax(logits)) gives per-category
        # log-probabilities regardless of whether logits were normalised.
        logits_normalised = tf.math.log(tf.math.softmax(logits))
        # Pick out the log-probability of outcome k within each batch row.
        return tf.gather(logits_normalised, k, batch_dims=1)
def EventTimeProposal(
events, initial_state, topology, d_max, n_max, dtype=tf.int32, name=None
):
"""Draws an event time move proposal.
:param events: a [M, T, K] tensor of event times (M number of metapopulations,
T number of timepoints, K number of transitions)
:param initial_state: a [M, S] tensor of initial metapopulation x state counts
:param topology: a 3-element tuple of (previous_transition, target_transition,
next_transition), eg "(s->e, e->i, i->r)"
(assuming we are interested presently in e->i, `None` for boundaries)
(assuming we are interested presently in e->i,
`None` for boundaries)
:param d_max: the maximum distance over which to move (in time)
:param n_max: the maximum number of events to move
"""
target_events = tf.gather(events, topology.target, axis=-1)
time_interval = tf.range(d_max, dtype=dtype)
def t_():
x = tf.cast(target_events > 0, dtype=tf.float64)
logits = tf.math.log(x)
return tfd.Multinomial(total_count=1, logits=logits, name='t_')
# def t_():
# x = tf.cast(target_events > 0, dtype=tf.float32)
# logits = tf.math.log(x)
# # return tfd.Multinomial(total_count=1, logits=logits, name="t_")
# # print("logits dtype:", logits.dtype)
# return tfd.OneHotCategorical(logits=logits, name="t_")
def delta_t():
outcomes = tf.concat([-tf.range(1, d_max + 1), tf.range(1, d_max + 1)],
axis=0)
outcomes = tf.concat([tf.range(-d_max, 0), tf.range(1, d_max + 1)], axis=0)
logits = tf.ones([events.shape[-3]] + outcomes.shape, dtype=tf.float64)
return tfd.FiniteDiscrete(outcomes=outcomes, logits=logits,
name='delta_t')
return tfd.FiniteDiscrete(outcomes=outcomes, logits=logits, name="delta_t")
def t(t_):
def t():