Commit b2c6806a authored by Chris Jewell
Browse files

Transferred to external gemlib library

parent 2e4311a9
"""Categorical2 corrects a bug in the tfd.Categorical.log_prob"""
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import distribution_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.distributions.categorical import (
tfd = tfp.distributions
# Todo remove this class when
# is fixed
class Categorical2(tfd.Categorical):
    """Overrides the faulty log_prob in tfd.Categorical.

    NOTE(review): the original docstring ended "due to" followed by an
    upstream issue reference that was lost in transfer -- restore it.
    """

    def _log_prob(self, k):
        """Return the log probability of category indices `k`.

        :param k: integer category indices, broadcast against the logits.
        :returns: a float64 tensor of log probabilities.
        """
        with tf.name_scope("Cat2log_prob"):
            logits = self.logits_parameter()
            if self.validate_args:
                k = distribution_util.embed_check_integer_casting_closed(
                    k, target_dtype=self.dtype
                )
            k, logits = _broadcast_cat_event_and_params(
                k, logits, base_dtype=dtype_util.base_dtype(self.dtype)
            )
            # log_softmax is numerically stabler than log(softmax(...)):
            # it avoids underflow to -inf for very negative logits.
            logits_normalised = tf.math.log_softmax(logits)
            return tf.cast(tf.gather(logits_normalised, k, batch_dims=1), tf.float64)
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import reparameterization
tfd = tfp.distributions
class KCategorical(tfd.Distribution):
    def __init__(
        self,
        k,
        probs,
        validate_args=False,
        allow_nan_stats=True,
        name="KCategorical",
    ):
        """K-Categorical distribution

        Given a set of items indexed $1,...,n$ with weights $w_1,\dots,w_n$,
        sample $k$ indices without replacement.

        :param k: the number of indices to sample
        :param probs: the (normalized) probability vector
        :param validate_args: Whether to validate args
        :param allow_nan_stats: allow nan stats
        :param name: name of the distribution
        """
        parameters = dict(locals())
        self.probs = probs
        self.logits = tf.math.log(probs)
        dtype = tf.int32
        with tf.name_scope(name) as name:
            # NOTE(review): the base-class argument list was lost in
            # transfer; reconstructed from tfd.Distribution's signature
            # (integer samples are not reparameterizable) -- verify
            # against the gemlib copy.
            super(KCategorical, self).__init__(
                dtype=dtype,
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name,
            )

    def _sample_n(self, n, seed=None):
        """Draw k indices without replacement via the Gumbel-max trick."""
        g = tfd.Gumbel(
            tf.constant(0.0, dtype=self.probs.dtype),
            tf.constant(1.0, dtype=self.probs.dtype),
        ).sample(self.logits.shape, seed=seed)
        # Hack for missing float64 version
        z = tf.cast(g + self.logits, tf.float32)
        _, x = tf.nn.top_k(z, self.parameters["k"])
        return x

    def _log_prob(self, x):
        """Sequential (without-replacement) log probability of indices x."""
        n = self.logits.shape  # currently unused
        k = x.shape  # currently unused
        wz = tf.gather(self.probs, x, axis=-1)
        # W_i: total weight remaining when the i-th index is drawn.
        W = tf.cumsum(wz, reverse=True)
        # NOTE(review): suspected bug -- `wz` holds probabilities, not
        # log-probabilities, so this likely should be
        # tf.reduce_sum(tf.math.log(wz) - tf.math.log(W)).  Preserved
        # as written; confirm against the gemlib copy before changing.
        return tf.reduce_sum(wz - tf.math.log(W))
if __name__ == '__main__':
    # Smoke-test: three weighted draws without replacement.
    probs = tf.constant([1, 0, 0, 1, 1, 0, 1], dtype=tf.float32)
    probs = probs / tf.reduce_sum(probs)
    X = KCategorical(3, probs)
    x = X.sample()
    lp = X.log_prob(x)
"""The UniformInteger distribution class"""
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import reparameterization
from tensorflow_probability.python.internal import samplers
from tensorflow_probability.python.internal import tensor_util
tfd = tfp.distributions
class UniformInteger(tfd.Distribution):
    """Uniform distribution over the integers in `[low, high)`."""

    # NOTE(review): the __init__ signature and super() argument list were
    # lost in transfer; defaults below are reconstructed -- verify against
    # the gemlib copy.
    def __init__(
        self,
        low=0,
        high=1,
        validate_args=False,
        allow_nan_stats=True,
        dtype=tf.int32,
        float_dtype=tf.float64,
        name="UniformInteger",
    ):
        """Initialise a UniformInteger random variable on `[low, high)`.

        Args:
          low: Integer tensor, lower boundary of the output interval. Must have
            `low <= high`.
          high: Integer tensor, exclusive upper boundary of the output
            interval. Must have `low <= high`.
          validate_args: Python `bool`, default `False`. When `True` distribution
            parameters are checked for validity despite possibly degrading runtime
            performance. When `False` invalid inputs may silently render incorrect
            outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
            (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
            result is undefined. When `False`, an exception is raised if one or
            more of the statistic's batch members are undefined.
          dtype: the dtype of the output variates
          float_dtype: float dtype used for probability computations
          name: Python `str` name prefixed to Ops created by this class.

        Raises:
          InvalidArgument if `low > high` and `validate_args=False`.
        """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            self._low = tf.cast(low, name="low", dtype=dtype)
            self._high = tf.cast(high, name="high", dtype=dtype)
            super(UniformInteger, self).__init__(
                dtype=dtype,
                reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name,
            )
        self.float_dtype = float_dtype
        if validate_args is True:
            tf.assert_greater(self._high, self._low, "Condition low < high failed")

    @staticmethod
    def _param_shapes(sample_shape):
        return dict(
            zip(
                ("low", "high"),
                [tf.convert_to_tensor(sample_shape, dtype=tf.int32)] * 2,
            )
        )

    @classmethod
    def _params_event_ndims(cls):
        return dict(low=0, high=0)

    @property
    def low(self):
        """Lower boundary of the output interval."""
        return self._low

    @property
    def high(self):
        """Upper (exclusive) boundary of the output interval."""
        return self._high

    def range(self, name="range"):
        """`high - low`."""
        with self._name_and_control_scope(name):
            return self._range()

    def _range(self, low=None, high=None):
        low = self.low if low is None else low
        high = self.high if high is None else high
        return high - low

    def _batch_shape_tensor(self, low=None, high=None):
        return tf.broadcast_dynamic_shape(
            tf.shape(self.low if low is None else low),
            tf.shape(self.high if high is None else high),
        )

    def _batch_shape(self):
        return tf.broadcast_static_shape(self.low.shape, self.high.shape)

    def _event_shape_tensor(self):
        return tf.constant([], dtype=tf.int32)

    def _event_shape(self):
        return tf.TensorShape([])

    def _sample_n(self, n, seed=None):
        with tf.name_scope("sample_n"):
            low = tf.convert_to_tensor(self.low)
            high = tf.convert_to_tensor(self.high)
            shape = tf.concat([[n], self._batch_shape_tensor(low=low, high=high)], 0)
            samples = samplers.uniform(shape=shape, dtype=tf.float32, seed=seed)
            # Scale uniform [0,1) variates onto [low, high) and truncate
            # back to the integer dtype.
            return low + tf.cast(
                tf.cast(self._range(low=low, high=high), tf.float32) * samples,
                self.dtype,
            )

    def _prob(self, x):
        with tf.name_scope("prob"):
            low = tf.cast(self.low, self.float_dtype)
            high = tf.cast(self.high, self.float_dtype)
            x = tf.cast(x, dtype=self.float_dtype)
            # NOTE(review): the out-of-support branch of tf.where was lost
            # in transfer; zero probability outside [low, high) is the
            # standard choice -- verify against the gemlib copy.
            return tf.where(
                (x < low) | (x >= high),
                tf.zeros_like(x),
                tf.ones_like(x) / self._range(low=low, high=high),
            )

    def _log_prob(self, x):
        with tf.name_scope("log_prob"):
            res = tf.math.log(self._prob(x))
            return res
"""Debugging tools"""
import tensorflow as tf
import tensorflow_probability as tfp
class DoNotUpdate(tfp.mcmc.TransitionKernel):
    """Prevents the update of a kernel for debug purposes.

    Wraps `inner_kernel`; `one_step` returns its inputs unchanged, so the
    chain state is frozen while the surrounding MCMC machinery still runs.
    """

    def __init__(self, inner_kernel, name=None):
        """Prevents the update of a kernel for debug purposes

        :param inner_kernel: the wrapped TransitionKernel
        :param name: an optional name for the kernel
        """
        self._parameters = dict(inner_kernel=inner_kernel, name=name)

    # The accessors below are properties: bootstrap_results() reads
    # `self.inner_kernel.bootstrap_results`, which only works if
    # `inner_kernel` is an attribute-style accessor.  The decorators were
    # dropped in the transfer and are restored here.
    @property
    def inner_kernel(self):
        return self._parameters["inner_kernel"]

    @property
    def name(self):
        return self._parameters["name"]

    @property
    def is_calibrated(self):
        return True

    @property
    def parameters(self):
        return self._parameters

    def one_step(self, current_state, previous_results, seed=None):
        """Don't invoke inner_kernel.one_step, but return
        current state and results"""
        return current_state, previous_results

    def bootstrap_results(self, current_state):
        return self.inner_kernel.bootstrap_results(current_state)
"""Functions for chain binomial simulation."""
import tensorflow as tf
import tensorflow_probability as tfp
from covid.impl.util import compute_state, make_transition_matrix, transition_coords
tfd = tfp.distributions
def approx_expm(rates):
    """Approximates a full Markov transition matrix

    :param rates: un-normalised rate matrix (i.e. diagonal zero)
    :returns: approximation to Markov transition matrix
    """
    total_rates = tf.reduce_sum(rates, axis=-1, keepdims=True)
    # Reuse total_rates rather than recomputing the row sums.
    prob = 1.0 - tf.math.exp(-total_rates)
    # multiply_no_nan guards rows whose total rate is zero (0/0 -> 0).
    mt1 = tf.math.multiply_no_nan(rates / total_rates, prob)
    # Diagonal carries the probability of remaining in the current state.
    return tf.linalg.set_diag(mt1, 1.0 - tf.reduce_sum(mt1, axis=-1))
def chain_binomial_propagate(h, time_step, stoichiometry, seed=None):
    """Propagates the state of a population according to discrete time dynamics.

    :param h: a hazard rate function returning the non-row-normalised Markov transition
              rate matrix.  This function should return a tensor of dimension
              [ns, ns, nc] where ns is the number of states, and nc is the number of
              strata within the population.
    :param time_step: the time step
    :param stoichiometry: the stoichiometry matrix describing state transitions
    :param seed: optional random seed passed to the Binomial sampler
    :returns : a function that propagate `state[t]` -> `state[t+time_step]`
    """

    def propagate_fn(t, state):
        rates = h(t, state)
        rate_matrix = make_transition_matrix(
            rates, transition_coords(stoichiometry), state.shape
        )
        # Set diagonal to be the negative of the sum of other elements in each row
        markov_transition = approx_expm(rate_matrix * time_step)
        num_states = markov_transition.shape[-1]
        prev_probs = tf.zeros_like(markov_transition[..., :, 0])
        counts = tf.zeros(
            markov_transition.shape[:-1].as_list() + [0],
            dtype=markov_transition.dtype,
        )
        total_count = state
        # This for loop is ok because there are (currently) only 4 states (SEIR)
        # and we're only actually creating work for 3 of them. Even for as many
        # as a ~10 states it should probably be fine, just increasing the size
        # of the graph a bit.
        for i in range(num_states - 1):
            probs = markov_transition[..., :, i]
            # NOTE(review): Binomial's total_count argument was lost in
            # transfer; the running remainder `total_count` is the only
            # consistent choice (conditional binomial decomposition of a
            # multinomial) -- verify against the gemlib copy.
            binom = tfd.Binomial(
                total_count=total_count,
                probs=tf.clip_by_value(probs / (1.0 - prev_probs), 0.0, 1.0),
            )
            sample = binom.sample(seed=seed)
            counts = tf.concat([counts, sample[..., tf.newaxis]], axis=-1)
            total_count -= sample
            prev_probs += probs
        counts = tf.concat([counts, total_count[..., tf.newaxis]], axis=-1)
        new_state = tf.reduce_sum(counts, axis=-2)
        return counts, new_state

    return propagate_fn
def discrete_markov_simulation(
    hazard_fn, state, start, end, time_step, stoichiometry, seed=None
):
    """Simulates from a discrete time Markov state transition model using multinomial sampling
    across rows of the transition matrix.

    :param hazard_fn: hazard rate function passed to `chain_binomial_propagate`
    :param state: the initial state tensor
    :param start: simulation start time
    :param end: simulation end time (exclusive)
    :param time_step: the simulation time step
    :param stoichiometry: the stoichiometry matrix describing state transitions
    :param seed: optional random seed
    :returns: a `(times, events)` tuple of simulation times and event counts
    """
    propagate = chain_binomial_propagate(hazard_fn, time_step, stoichiometry, seed=seed)
    times = tf.range(start, end, time_step, dtype=state.dtype)
    state = tf.convert_to_tensor(state)

    output = tf.TensorArray(state.dtype, size=times.shape[0])

    cond = lambda i, *_: i < times.shape[0]

    def body(i, state, output):
        update, state = propagate(times[i], state)
        output = output.write(i, update)
        return i + 1, state, output

    _, state, output = tf.while_loop(cond, body, loop_vars=(0, state, output))
    return times, output.stack()
def discrete_markov_log_prob(
    events, init_state, init_step, time_delta, hazard_fn, stoichiometry
):
    """Calculates an unnormalised log_prob function for a discrete time epidemic model.

    :param events: a `[M, T, X]` batch of transition events for metapopulation M,
                   times `T`, and transitions `X`.
    :param init_state: a vector of shape `[M, S]` the initial state of the epidemic for
                       `M` metapopulations and `S` states
    :param init_step: the initial time step, as an offset to `range(events.shape[-2])`
    :param time_delta: the size of the time step.
    :param hazard_fn: a function that takes a state and returns a matrix of transition
                      rates.
    :param stoichiometry: a `[X, S]` matrix describing the state update for each
                          transition.
    :return: a scalar log probability for the epidemic.
    """
    num_meta = events.shape[-3]
    num_times = events.shape[-2]
    num_states = stoichiometry.shape[-1]

    state_timeseries = compute_state(init_state, events, stoichiometry)  # MxTxS
    # Time-major for the vectorized hazard evaluation below.
    tms_timeseries = tf.transpose(state_timeseries, perm=(1, 0, 2))

    def fn(elems):
        return hazard_fn(*elems)

    tx_coords = transition_coords(stoichiometry)
    rates = tf.vectorized_map(fn=fn, elems=[tf.range(num_times), tms_timeseries])
    rate_matrix = make_transition_matrix(rates, tx_coords, tms_timeseries.shape)
    probs = approx_expm(rate_matrix * time_delta)

    # [T, M, S, S] to [M, T, S, S]
    probs = tf.transpose(probs, perm=(1, 0, 2, 3))
    event_matrix = make_transition_matrix(
        events, tx_coords, [num_meta, num_times, num_states]
    )
    # Diagonal holds the individuals who remain in each state.
    event_matrix = tf.linalg.set_diag(
        event_matrix, state_timeseries - tf.reduce_sum(event_matrix, axis=-1)
    )
    logp = tfd.Multinomial(
        tf.cast(state_timeseries, dtype=tf.float32),
        probs=tf.cast(probs, dtype=tf.float32),
    ).log_prob(tf.cast(event_matrix, dtype=tf.float32))
    return tf.cast(tf.reduce_sum(logp), dtype=events.dtype)
def events_to_full_transitions(events, initial_state):
    """Creates a state tensor given matrices of transition events
    and the initial state

    :param events: a tensor of shape [t, c, s, s] for t timepoints, c metapopulations
                   and s states.
    :param initial_state: the initial state matrix of shape [c, s]
    :returns: a [t, c, s, s] tensor with the survivors on the diagonal
    """

    def f(state, events):
        # Individuals who did not transition stay on the diagonal.
        survived = tf.reduce_sum(state, axis=-2) - tf.reduce_sum(events, axis=-1)
        new_state = tf.linalg.set_diag(events, survived)
        return new_state

    return tf.scan(fn=f, elems=events, initializer=tf.linalg.diag(initial_state))
from collections import namedtuple

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.util import SeedStream

from covid import config
from covid.impl.event_time_proposal import TransitionTopology, FilteredEventTimeProposal

tfd = tfp.distributions
DTYPE = config.floatX

EventTimesKernelResults = namedtuple(
    "KernelResults", ("log_acceptance_correction", "target_log_prob", "extra")
)
def _is_within(x, low, high):
    """Returns true if low <= x < high"""
    above_lower = tf.less_equal(low, x)
    below_upper = tf.less(x, high)
    return tf.logical_and(above_lower, below_upper)
def _nonzero_rows(m):
    """Indicator (in m's dtype) of rows of `m` whose sum is positive."""
    row_totals = tf.reduce_sum(m, axis=-1)
    return tf.cast(row_totals > 0.0, m.dtype)
def _move_events(event_tensor, event_id, m, from_t, to_t, n_move):
    """Subtracts n_move from event_tensor[m, from_t, event_id]
    and adds n_move to event_tensor[m, to_t, event_id].

    :param event_tensor: shape [M, T, X]
    :param event_id: the event id to move
    :param m: the metapopulation to move
    :param from_t: the move-from time
    :param to_t: the move-to time
    :param n_move: the number of events to move
    :return: the modified event_tensor
    """
    # Todo rationalise this -- compute a delta, and add once.
    event_ids = tf.broadcast_to(event_id, m.shape)
    amount = tf.cast(n_move, event_tensor.dtype)
    # Remove the moved events from their source time...
    src_indices = tf.stack([m, from_t, event_ids], axis=-1)
    shifted = tf.tensor_scatter_nd_sub(event_tensor, src_indices, amount)
    # ...and deposit them at the destination time.
    dst_indices = tf.stack([m, to_t, event_ids], axis=-1)
    return tf.tensor_scatter_nd_add(shifted, dst_indices, amount)
def _reverse_move(move):
move["t"] = move["t"] + move["delta_t"]
move["delta_t"] = -move["delta_t"]
return move
class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
def __init__(
"""An uncalibrated random walk for event times.
:param target_log_prob_fn: the log density of the target distribution
:param target_event_id: the position in the first dimension of the events
tensor that we wish to move
:param prev_event_id: the position of the previous event in the events tensor
:param next_event_id: the position of the next event in the events tensor
:param initial_state: the initial state tensor
:param seed: a random seed
:param name: the name of the update step
self._name = name
self._parameters = dict(
self.tx_topology = TransitionTopology(
prev_event_id, target_event_id, next_event_id
self.time_offsets = tf.range(self.parameters["dmax"])
def target_log_prob_fn(self):
return self._parameters["target_log_prob_fn"]
def target_event_id(self):
return self._parameters["target_event_id"]
def prev_event_id(self):
return self._parameters["prev_event_id"]
def next_event_id(self):
return self._parameters["next_event_id"]
def seed(self):
return self._parameters["seed"]
def name(self):
return self._parameters["name"]
def parameters(self):
"""Return `dict` of ``__init__`` arguments and their values."""
return self._parameters
def is_calibrated(self):
return False
def one_step(self, current_events, previous_kernel_results, seed=None):
"""One update of event times.
:param current_events: a [T, M, X] tensor containing number of events
per time t, metapopulation m,
and transition x.
:param previous_kernel_results: an object of type
:returns: a tuple containing new_state and UncalibratedRandomWalkResults
with tf.name_scope("uncalibrated_event_times_rw/onestep"):
target_events = current_events[...,]
num_times = target_events.shape[1]
proposal = FilteredEventTimeProposal(
update = proposal.sample(seed=seed)
move = update["move"]
to_t = move["t"] + move["delta_t"]
def true_fn():
with tf.name_scope("true_fn"):
# Prob of fwd move
q_fwd = proposal.log_prob(update)
tf.debugging.assert_all_finite(q_fwd, "q_fwd is not finite")
# Propagate state
next_state = _move_events(
next_target_log_prob = self.target_log_prob_fn(next_state)
# Calculate proposal mass ratio
rev_move = _reve