Commit cad36259 authored by Chris Jewell

Added occult inference capability

Changes:

1. Added an occult Metropolis-Hastings update.
2. Factored out the Categorical2 distribution for use by both the event time move and the occult update.
3. Refactored the mcmc.py script for HDF5 output.
4. Applied compression to the HDF5 output file (see the illustrative sketch below).
parent 72773df9
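For context only (this block is not part of the diff): change 4 amounts to creating the HDF5 output datasets with gzip compression via h5py, roughly as in the sketch below. The file name, dataset name, and shapes are placeholders; the compression arguments mirror those used later in the mcmc script.
import h5py
import numpy as np

# Hypothetical, chunked, gzip-compressed dataset for MCMC output.
with h5py.File("example_posterior.h5", "w") as f:
    f.create_dataset(
        "samples/events",
        shape=(1000, 80, 10, 3),
        dtype=np.float64,
        chunks=(100, 80, 10, 3),
        compression="gzip",
        compression_opts=1,
    )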
"""Categorical2 corrects a bug in the tfd.Categorical.log_prob"""
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import distribution_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.distributions.categorical import (
_broadcast_cat_event_and_params,
)
tfd = tfp.distributions
# TODO: remove this class when https://github.com/tensorflow/tensorflow/issues/40606
# is fixed
class Categorical2(tfd.Categorical):
"""Done to override the faulty log_prob in tfd.Categorical due to
https://github.com/tensorflow/tensorflow/issues/40606"""
def _log_prob(self, k):
logits = self.logits_parameter()
if self.validate_args:
k = distribution_util.embed_check_integer_casting_closed(
k, target_dtype=self.dtype
)
k, logits = _broadcast_cat_event_and_params(
k, logits, base_dtype=dtype_util.base_dtype(self.dtype)
)
logits_normalised = tf.math.log(tf.math.softmax(logits))
return tf.gather(logits_normalised, k, batch_dims=1)
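# Illustrative usage (added for clarity, not part of the commit): Categorical2
# is a drop-in replacement for tfd.Categorical.  The occult proposals in this
# commit build logits as tf.math.log of a 0/1 indicator, so masked-out
# categories carry -inf logits; _log_prob above normalises the logits
# explicitly before gathering.
if __name__ == "__main__":
    probs = tf.constant([[0.0, 0.5, 0.5]])  # first category masked out
    cat = Categorical2(logits=tf.math.log(probs), dtype=tf.int32)
    k = cat.sample()                        # only ever 1 or 2
    print(cat.log_prob(k))                  # approximately log(0.5)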
......@@ -4,13 +4,10 @@ import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import distribution_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.distributions.categorical import (
_broadcast_cat_event_and_params,
)
from covid.impl.UniformInteger import UniformInteger
from covid.impl.KCategorical import KCategorical
from covid.impl.Categorical2 import Categorical2
tfd = tfp.distributions
......@@ -110,23 +107,6 @@ class Deterministic2(tfd.Deterministic):
return tf.constant(1, dtype=self.log_prob_dtype)
class Categorical2(tfd.Categorical):
"""Done to override the faulty log_prob in tfd.Categorical due to
https://github.com/tensorflow/tensorflow/issues/40606"""
def _log_prob(self, k):
logits = self.logits_parameter()
if self.validate_args:
k = distribution_util.embed_check_integer_casting_closed(
k, target_dtype=self.dtype
)
k, logits = _broadcast_cat_event_and_params(
k, logits, base_dtype=dtype_util.base_dtype(self.dtype)
)
logits_normalised = tf.math.log(tf.math.softmax(logits))
return tf.gather(logits_normalised, k, batch_dims=1)
def EventTimeProposal(
events, initial_state, topology, d_max, n_max, dtype=tf.int32, name=None
):
......
......@@ -39,7 +39,7 @@ def random_walk_mvnorm_fn(covariance, p_u=0.95, name=None):
def proposal():
rv = tf.stack([rv_fix.sample(), rv_adapt.sample()])
uv = u.sample()
uv = u.sample(seed=seed)
return tf.gather(rv, uv)
new_state_parts = [proposal() + state_part for state_part in state_parts]
......
from collections import namedtuple
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.util import SeedStream
from covid import config
from covid.impl.event_time_proposal import TransitionTopology, FilteredEventTimeProposal
from covid.impl.occult_proposal import AddOccultProposal, DelOccultProposal
tfd = tfp.distributions
DTYPE = config.floatX
OccultKernelResults = namedtuple(
"KernelResults", ("log_acceptance_correction", "target_log_prob", "extra")
)
def _nonzero_rows(m):
return tf.cast(tf.reduce_sum(m, axis=-1) > 0.0, m.dtype)
def _maybe_expand_dims(x):
"""If x is a scalar, give it at least 1 dimension"""
x = tf.convert_to_tensor(x)
if x.shape == ():
return tf.expand_dims(x, axis=0)
return x
def _add_events(events, m, t, x, x_star):
"""Adds `x_star` events to metapopulation `m`,
time `t`, transition `x` in `events`."""
x = _maybe_expand_dims(x)
indices = tf.stack([m, t, x], axis=-1)
return tf.tensor_scatter_nd_add(events, indices, x_star)
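# Worked example (added for clarity, not in the commit): with an [M, T, X]
# events tensor, _add_events(events, m=[0], t=[1], x=2, x_star=[3.0]) builds
# the index [[0, 1, 2]] and returns a copy of `events` with events[0, 1, 2]
# incremented by 3 via tf.tensor_scatter_nd_add; passing a negative x_star,
# as in the deletion branch below, removes events instead.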
class UncalibratedOccultUpdate(tfp.mcmc.TransitionKernel):
"""UncalibratedEventTimesUpdate"""
def __init__(
self, target_log_prob_fn, target_event_id, nmax, seed=None, name=None,
):
"""An uncalibrated random walk for event times.
:param target_log_prob_fn: the log density of the target distribution
:param target_event_id: the position in the last dimension of the events
tensor that we wish to move
:param seed: a random seed
:param name: the name of the update step
"""
self._target_log_prob_fn = target_log_prob_fn
self._seed_stream = SeedStream(seed, salt="UncalibratedOccultUpdate")
self._name = name
self._parameters = dict(
target_log_prob_fn=target_log_prob_fn,
target_event_id=target_event_id,
nmax=nmax,
seed=seed,
name=name,
)
self.tx_topology = TransitionTopology(None, target_event_id, None)
@property
def target_log_prob_fn(self):
return self._parameters["target_log_prob_fn"]
@property
def target_event_id(self):
return self._parameters["target_event_id"]
@property
def seed(self):
return self._parameters["seed"]
@property
def name(self):
return self._parameters["name"]
@property
def parameters(self):
"""Return `dict` of ``__init__`` arguments and their values."""
return self._parameters
@property
def is_calibrated(self):
return False
def one_step(self, current_events, previous_kernel_results):
"""One update of event times.
:param current_events: a [T, M, X] tensor containing number of events
per time t, metapopulation m,
and transition x.
:param previous_kernel_results: an object of type
UncalibratedRandomWalkResults.
:returns: a tuple containing new_state and UncalibratedRandomWalkResults
"""
with tf.name_scope("occult_rw/onestep"):
def true_fn():
with tf.name_scope("true_fn"):
proposal = AddOccultProposal(
current_events, self.parameters["nmax"]
)
update = proposal.sample()
next_state = _add_events(
events=current_events,
m=update["m"],
t=update["t"],
x=self.tx_topology.target,
x_star=tf.cast(update["x_star"], current_events.dtype),
)
reverse = DelOccultProposal(next_state, self.tx_topology)
q_fwd = tf.reduce_sum(proposal.log_prob(update))
q_rev = tf.reduce_sum(reverse.log_prob(update))
log_acceptance_correction = q_rev - q_fwd
return update, next_state, log_acceptance_correction
def false_fn():
with tf.name_scope("false_fn"):
proposal = DelOccultProposal(current_events, self.tx_topology)
update = proposal.sample()
next_state = _add_events(
events=current_events,
m=update["m"],
t=update["t"],
x=[self.tx_topology.target],
x_star=tf.cast(-update["x_star"], current_events.dtype),
)
reverse = AddOccultProposal(next_state, self.parameters["nmax"])
q_fwd = tf.reduce_sum(proposal.log_prob(update))
q_rev = tf.reduce_sum(reverse.log_prob(update))
log_acceptance_correction = q_rev - q_fwd
return update, next_state, log_acceptance_correction
u = tfd.Uniform().sample()
delta, next_state, log_acceptance_correction = tf.cond(
u < 0.5, true_fn, false_fn
)
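# Added for clarity (not in the commit): with probability 1/2 the kernel
# proposes adding occult events (true_fn) and otherwise deleting them
# (false_fn).  log_acceptance_correction is the log Hastings ratio,
# q(reverse move) - q(forward move); the tfp.mcmc.MetropolisHastings wrapper
# used in the mcmc script combines it with the change in target_log_prob to
# form the acceptance probability.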
# tf.debugging.assert_non_negative(
# next_state, message="Negative occults occurred"
# )
next_target_log_prob = self.target_log_prob_fn(next_state)
return [
next_state,
OccultKernelResults(
log_acceptance_correction=log_acceptance_correction,
target_log_prob=next_target_log_prob,
extra=tf.concat([delta["m"], delta["t"], delta["x_star"]], axis=0),
),
]
def bootstrap_results(self, init_state):
with tf.name_scope("uncalibrated_event_times_rw/bootstrap_results"):
init_state = tf.convert_to_tensor(init_state, dtype=DTYPE)
init_target_log_prob = self.target_log_prob_fn(init_state)
return OccultKernelResults(
log_acceptance_correction=tf.constant(0.0, dtype=DTYPE),
target_log_prob=init_target_log_prob,
extra=tf.constant([0, 0, 0], dtype=tf.int32),
)
import tensorflow as tf
import tensorflow_probability as tfp
from covid.impl.UniformInteger import UniformInteger
from covid.impl.Categorical2 import Categorical2
tfd = tfp.distributions
def AddOccultProposal(events, n_max, dtype=tf.int32, name=None):
def m():
"""Select a metapopulation"""
with tf.name_scope("m"):
return UniformInteger(low=[0], high=[events.shape[0]], dtype=dtype)
def t():
"""Select a timepoint"""
with tf.name_scope("t"):
return UniformInteger(low=[0], high=[events.shape[1]], dtype=dtype)
def x_star():
"""Draw num to add"""
return UniformInteger(low=[0], high=[n_max + 1], dtype=dtype)
return tfd.JointDistributionNamed(dict(m=m, t=t, x_star=x_star), name=name)
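# Illustrative usage (added for clarity, not part of the commit), assuming an
# [M, T, X] events tensor:
#
#   events = tf.zeros([10, 50, 3])   # hypothetical: 10 metapops, 50 days, 3 transitions
#   proposal = AddOccultProposal(events, n_max=4)
#   update = proposal.sample()       # dict with keys "m", "t", "x_star"
#   log_q = tf.reduce_sum(proposal.log_prob(update))
#
# which mirrors how UncalibratedOccultUpdate draws and scores an addition move.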
def DelOccultProposal(events, topology, dtype=tf.int32, name=None):
def m():
"""Select a metapopulation"""
with tf.name_scope("m"):
hot_meta = (
tf.math.count_nonzero(
events[..., topology.target], axis=1, keepdims=True
)
> 0
)
hot_meta = tf.cast(tf.transpose(hot_meta), dtype=events.dtype)
logits = tf.math.log(hot_meta)
X = Categorical2(logits=logits, dtype=dtype, name="m")
return X
def t(m):
"""Draw timepoint"""
with tf.name_scope("t"):
metapops = tf.gather(events, m)
hot_times = metapops[..., topology.target] > 0
hot_times = tf.cast(hot_times, dtype=events.dtype)
logits = tf.math.log(hot_times)
return Categorical2(logits=logits, dtype=dtype, name="t")
def x_star(m, t):
"""Draw num to delete"""
with tf.name_scope("x_star"):
indices = tf.stack([m, t, [topology.target]], axis=-1)
max_occults = tf.gather_nd(events, indices)
return UniformInteger(
low=0, high=max_occults + 1, dtype=dtype, name="x_star"
)
return tfd.JointDistributionNamed(dict(m=m, t=t, x_star=x_star), name=name)
......@@ -155,7 +155,7 @@ class CovidUKStochastic(CovidUK):
* commute_volume
* tf.linalg.matvec(self.C, state[..., 2] / self.N)
)
infec_rate = infec_rate / self.N # + 1.0e-6 # Vector of length nc
infec_rate = infec_rate / self.N # Vector of length nc
ei = tf.broadcast_to(
[param["nu"]], shape=[state.shape[0]]
......
......@@ -2,6 +2,7 @@
import optparse
import os
import pickle as pkl
from collections import OrderedDict
import h5py
import numpy as np
......@@ -15,7 +16,7 @@ from covid.model import load_data, CovidUKStochastic
from covid.util import sanitise_parameter, sanitise_settings
from covid.impl.mcmc import UncalibratedLogRandomWalk, random_walk_mvnorm_fn
from covid.impl.event_time_mh import UncalibratedEventTimesUpdate
from covid.impl.occult_events_mh import UncalibratedOccultUpdate
###########
# TF Bits #
......@@ -75,13 +76,19 @@ with open("stochastic_sim_covid.pkl", "rb") as f:
example_sim = pkl.load(f)
event_tensor = example_sim["events"] # shape [T, M, S, S]
event_tensor = event_tensor[:80, ...]
num_times = event_tensor.shape[0]
num_meta = event_tensor.shape[1]
state_init = example_sim["state_init"]
se_events = event_tensor[:, :, 0, 1]
ei_events = event_tensor[:, :, 1, 2]
ir_events = event_tensor[:, :, 2, 3]
se_events = event_tensor[:, :, 0, 1] # S->E events, shape [T, M]
ei_events = event_tensor[:, :, 1, 2] # E->I events, shape [T, M]
ir_events = event_tensor[:, :, 2, 3] # I->R events, shape [T, M]
ir_events = np.pad(ir_events, ((4, 0), (0, 0)), mode="constant", constant_values=0.0)
ei_events = np.roll(ir_events, shift=-2, axis=0)
se_events = np.roll(ir_events, shift=-4, axis=0)
ei_events[-2:, ...] = 0.0
se_events[-4:, ...] = 0.0
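# Note (added for clarity, not in the commit): the padding and rolling above
# appear to synthesise the unobserved S->E and E->I event series from the
# observed I->R series, assuming each case's S->E occurs 4 days and its E->I
# 2 days before its I->R event; the trailing rows are zeroed so no events are
# rolled around past the end of the time window.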
##########################
# Log p and MCMC kernels #
......@@ -145,8 +152,18 @@ def make_events_step(target_event_id, prev_event_id=None, next_event_id=None):
return kernel_func
def make_occults_step():
pass
def make_occults_step(target_event_id):
def kernel_func(logp):
return tfp.mcmc.MetropolisHastings(
inner_kernel=UncalibratedOccultUpdate(
target_log_prob_fn=logp,
target_event_id=target_event_id,
nmax=config["mcmc"]["occult_nmax"],
),
name="occult_update",
)
return kernel_func
def is_accepted(result):
......@@ -160,7 +177,7 @@ def trace_results_fn(results):
accepted = is_accepted(results)
q_ratio = results.proposed_results.log_acceptance_correction
if hasattr(results.proposed_results, "extra"):
proposed = results.proposed_results.extra
proposed = tf.cast(results.proposed_results.extra, log_prob.dtype)
return tf.concat([[log_prob], [accepted], [q_ratio], proposed], axis=0)
return tf.concat([[log_prob], [accepted], [q_ratio]], axis=0)
......@@ -179,22 +196,31 @@ def sample(n_samples, init_state, par_scale, num_event_updates):
par_func = make_parameter_kernel(par_scale, 0.95)
se_func = make_events_step(0, None, 1)
ei_func = make_events_step(1, 0, 2)
se_occult = make_occults_step(0)
ei_occult = make_occults_step(1)
# Based on Gibbs idea posted by Pavel Sountsov
# https://github.com/tensorflow/probability/issues/495
par_results = par_func(
lambda p: logp(p, init_state[1], init_state[2])
).bootstrap_results(init_state[0])
se_results = se_func(
lambda s: logp(init_state[0], s, init_state[2])
).bootstrap_results(init_state[1])
ei_results = ei_func(
lambda s: logp(init_state[0], s, init_state[2])
).bootstrap_results(init_state[1])
results = [par_results, se_results, ei_results]
results = [
par_func(lambda p: logp(p, init_state[1], init_state[2])).bootstrap_results(
init_state[0]
),
se_func(lambda s: logp(init_state[0], s, init_state[2])).bootstrap_results(
init_state[1]
),
ei_func(lambda s: logp(init_state[0], s, init_state[2])).bootstrap_results(
init_state[1]
),
se_occult(
lambda s: logp(init_state[0], init_state[1], s)
).bootstrap_results(init_state[2]),
ei_occult(
lambda s: logp(init_state[0], init_state[1], s)
).bootstrap_results(init_state[2]),
]
samples_arr = [tf.TensorArray(s.dtype, size=n_samples) for s in init_state]
results_arr = [tf.TensorArray(DTYPE, size=n_samples) for r in range(3)]
results_arr = [tf.TensorArray(DTYPE, size=n_samples) for r in range(5)]
def body(i, state, results, sample_accum, results_accum):
# Parameters
......@@ -207,19 +233,31 @@ def sample(n_samples, init_state, par_scale, num_event_updates):
)
# States
results[2] = forward_results(results[0], results[2])
results[4] = forward_results(results[0], results[4])
def infec_body(j, state, results):
def state_logp(event_state):
state[1] = event_state
return logp(*state)
def occult_logp(occult_state):
state[2] = occult_state
return logp(*state)
state[1], results[1] = se_func(state_logp).one_step(
state[1], forward_results(results[2], results[1])
state[1], forward_results(results[4], results[1])
)
state[1], results[2] = ei_func(state_logp).one_step(
state[1], forward_results(results[1], results[2])
)
state[2], results[3] = se_occult(occult_logp).one_step(
state[2], forward_results(results[2], results[3])
)
# results[3] = forward_results(results[2], results[3])
state[2], results[4] = ei_occult(occult_logp).one_step(
state[2], forward_results(results[3], results[4])
)
# results[4] = forward_results(results[3], results[4])
j += 1
return j, state, results
......@@ -282,32 +320,48 @@ par_samples = posterior.create_dataset(
[NUM_BURSTS * NUM_BURST_SAMPLES, current_state[0].shape[0]],
dtype=np.float64,
)
se_samples = posterior.create_dataset(
event_samples = posterior.create_dataset(
"samples/events",
event_size,
dtype=DTYPE,
chunks=(1000,) + tuple(event_size[1:]),
chunks=(min(NUM_BURSTS * NUM_BURST_SAMPLES, 1000),) + tuple(event_size[1:]),
compression="gzip",
compression_opts=1,
)
par_results = posterior.create_dataset(
"acceptance/parameter", (NUM_BURSTS * NUM_BURST_SAMPLES, 3), dtype=DTYPE,
)
se_results = posterior.create_dataset(
"acceptance/S->E",
(NUM_BURSTS * NUM_BURST_SAMPLES, 3 + model.N.shape[0]),
dtype=DTYPE,
)
ei_results = posterior.create_dataset(
"acceptance/E->I",
(NUM_BURSTS * NUM_BURST_SAMPLES, 3 + model.N.shape[0]),
occult_samples = posterior.create_dataset(
"samples/occults",
event_size,
dtype=DTYPE,
chunks=(min(NUM_BURSTS * NUM_BURST_SAMPLES, 1000),) + tuple(event_size[1:]),
compression="gzip",
compression_opts=1,
)
output_results = [
posterior.create_dataset(
"results/parameter", (NUM_BURSTS * NUM_BURST_SAMPLES, 3), dtype=DTYPE,
),
posterior.create_dataset(
"results/move/S->E",
(NUM_BURSTS * NUM_BURST_SAMPLES, 3 + model.N.shape[0]),
dtype=DTYPE,
),
posterior.create_dataset(
"results/move/E->I",
(NUM_BURSTS * NUM_BURST_SAMPLES, 3 + model.N.shape[0]),
dtype=DTYPE,
),
posterior.create_dataset(
"results/occult/S->E", (NUM_BURSTS * NUM_BURST_SAMPLES, 6), dtype=DTYPE
),
posterior.create_dataset(
"results/occult/E->I", (NUM_BURSTS * NUM_BURST_SAMPLES, 6), dtype=DTYPE
),
]
print("Initial logpi:", logp(*current_state))
par_scale = tf.linalg.diag(
tf.ones(current_state[0].shape, dtype=current_state[0].dtype) * 0.1
tf.ones(current_state[0].shape, dtype=current_state[0].dtype) * 0.0000001
)
# We loop over successive calls to sample because we have to dump results
......@@ -328,21 +382,34 @@ for i in tqdm.tqdm(range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES):
rowvar=False,
)
print(current_state[0].numpy())
print(cov)
if np.all(np.isfinite(cov)):
par_scale = 2.0 ** 2 * cov / 2.0
se_samples[s, ...] = samples[1].numpy()
par_results[s, ...] = results[0].numpy()
se_results[s, ...] = results[1].numpy()
ei_results[s, ...] = results[2].numpy()
event_samples[s, ...] = samples[1].numpy()
occult_samples[s, ...] = samples[2].numpy()
for k, ro in enumerate(output_results):
ro[s, ...] = results[k]
print("Acceptance0:", tf.reduce_mean(tf.cast(results[0][:, 1], tf.float32)))
print("Acceptance1:", tf.reduce_mean(tf.cast(results[1][:, 1], tf.float32)))
print("Acceptance2:", tf.reduce_mean(tf.cast(results[2][:, 1], tf.float32)))
print("Acceptance par:", tf.reduce_mean(tf.cast(results[0][:, 1], tf.float32)))
print(
"Acceptance move S->E:", tf.reduce_mean(tf.cast(results[1][:, 1], tf.float32))
)
print(
"Acceptance move E->I:", tf.reduce_mean(tf.cast(results[2][:, 1], tf.float32))
)
print(
"Acceptance occult S->E:", tf.reduce_mean(tf.cast(results[3][:, 1], tf.float32))
)
print(
"Acceptance occult E->I:", tf.reduce_mean(tf.cast(results[4][:, 1], tf.float32))
)
print(f"Acceptance param: {par_results[:, 1].mean()}")
print(f"Acceptance S->E: {se_results[:, 1].mean()}")
print(f"Acceptance E->I: {ei_results[:, 1].mean()}")
print(f"Acceptance param: {output_results[0][:, 1].mean()}")
print(f"Acceptance move S->E: {output_results[1][:, 1].mean()}")
print(f"Acceptance move E->I: {output_results[2][:, 1].mean()}")
print(f"Acceptance occult S->E: {output_results[3][:, 1].mean()}")
print(f"Acceptance occult E->I: {output_results[4][:, 1].mean()}")
posterior.close()
......@@ -34,13 +34,14 @@ settings:
- 2020-08-01
mcmc:
dmax: 32
dmax: 16
nmax: 160
m: 1
num_event_time_updates: 1
occult_nmax: 250
num_event_time_updates: 100
num_bursts: 100
num_burst_samples: 100
output:
posterior: posterior_medium_upd1.h5
posterior: posterior_medium_xla.h5
simulation: covid_ode.hd5