Commit e52283b6 authored by Chris Jewell's avatar Chris Jewell
Browse files

Refactored model specification

Changes:

1. Created a TFP JointDistribution to represent full probability model;
2. Renamed CovidUKStochastic --> DiscreteTimeStateTransitionModel;
3. DiscreteTimeStateTransitionModel now inherits from tfp.Distribution.
parent 9120a8ec
......@@ -17,13 +17,14 @@ class Categorical2(tfd.Categorical):
https://github.com/tensorflow/tensorflow/issues/40606"""
def _log_prob(self, k):
logits = self.logits_parameter()
if self.validate_args:
k = distribution_util.embed_check_integer_casting_closed(
k, target_dtype=self.dtype
with tf.name_scope("Cat2log_prob"):
logits = self.logits_parameter()
if self.validate_args:
k = distribution_util.embed_check_integer_casting_closed(
k, target_dtype=self.dtype
)
k, logits = _broadcast_cat_event_and_params(
k, logits, base_dtype=dtype_util.base_dtype(self.dtype)
)
k, logits = _broadcast_cat_event_and_params(
k, logits, base_dtype=dtype_util.base_dtype(self.dtype)
)
logits_normalised = tf.math.log(tf.math.softmax(logits))
return tf.gather(logits_normalised, k, batch_dims=1)
logits_normalised = tf.math.log(tf.math.softmax(logits))
return tf.gather(logits_normalised, k, batch_dims=1)
......@@ -101,7 +101,6 @@ def discrete_markov_log_prob(events, init_state, hazard_fn, time_step, stoichiom
num_times = events.shape[-2]
num_events = events.shape[-1]
num_states = stoichiometry.shape[-1]
state_timeseries = compute_state(init_state, events, stoichiometry) # MxTxS
tms_timeseries = tf.transpose(state_timeseries, perm=(1, 0, 2))
......
......@@ -3,6 +3,7 @@
import numpy as np
import tensorflow as tf
from tensorflow_probability.python.mcmc.internal import util as mcmc_util
from tensorflow_probability.python.internal import prefer_static as ps
def which(predicate):
......@@ -59,8 +60,52 @@ def compute_state(initial_state, events, stoichiometry):
:return: a tensor of shape [M, T, S] describing the state of the
system for each batch M at time T.
"""
stoichiometry = tf.convert_to_tensor(stoichiometry, dtype=events.dtype)
increments = tf.tensordot(events, stoichiometry, axes=[[-1], [-2]]) # mtx,xs->mts
cum_increments = tf.cumsum(increments, axis=-2, exclusive=True)
state = cum_increments + tf.expand_dims(initial_state, axis=-2)
return state
def transition_coords(stoichiometry):
src = np.where(stoichiometry == -1)[1]
dest = np.where(stoichiometry == 1)[1]
return np.stack([src, dest], axis=-1)
def batch_gather(tensor, indices):
"""Written by Chris Suter (c) 2020
Modified by Chris Jewell, 2020
"""
tensor_shape = ps.shape(tensor) # B + E
tensor_rank = ps.rank(tensor)
indices_shape = ps.shape(indices) # [N, E]
num_outputs = indices_shape[0]
non_batch_rank = indices_shape[1] # r(E)
batch_rank = tensor_rank - non_batch_rank
# batch_shape = tf.cast(tensor_shape[:batch_rank], dtype=tf.int64)
# batch_size = tf.reduce_prod(batch_shape)
# Create indices into batch_shape, of shape [batch_size, batch_rank]
# batch_indices = tf.transpose(
# tf.unravel_index(tf.range(batch_size), dims=batch_shape)
# )
batch_shape = tensor_shape[:batch_rank]
batch_size = np.prod(batch_shape)
batch_indices = np.transpose(
np.unravel_index(np.arange(batch_size), dims=batch_shape)
)
# Tile the batch indices num_outputs times
batch_indices_tiled = tf.reshape(
tf.tile(batch_indices, multiples=[1, num_outputs]),
[batch_size * num_outputs, -1],
)
batched_output_indices = tf.tile(indices, multiples=[batch_size, 1])
full_indices = tf.concat([batch_indices_tiled, batched_output_indices], axis=-1)
output_shape = ps.concat([batch_shape, [num_outputs]], axis=0)
return tf.reshape(tf.gather_nd(tensor, full_indices), output_shape)
......@@ -2,10 +2,12 @@
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import reparameterization
from tensorflow_probability.python.internal import prefer_static as ps
import numpy as np
from covid import config
from covid.impl.util import make_transition_matrix
from covid.impl.util import make_transition_matrix, batch_gather, transition_coords
from covid.rdata import load_age_mixing
from covid.pydata import load_commute_volume, load_mobility_matrix, load_population
from covid.impl.discrete_markov import (
......@@ -14,6 +16,7 @@ from covid.impl.discrete_markov import (
)
tla = tf.linalg
tfd = tfp.distributions
DTYPE = config.floatX
......@@ -76,7 +79,7 @@ def load_data(paths, settings, dtype=DTYPE):
}
class CovidUKStochastic:
class DiscreteTimeStateTransitionModel(tfd.Distribution):
def __init__(
self,
transition_rates,
......@@ -85,6 +88,9 @@ class CovidUKStochastic:
initial_step,
time_delta,
num_steps,
validate_args=False,
allow_nan_stats=True,
name="DiscreteTimeStateTransitionModel",
):
"""Implements a discrete-time Markov jump process for a state transition model.
......@@ -96,68 +102,92 @@ class CovidUKStochastic:
:param num_steps: the number of time steps across which the model runs.
"""
self.transition_rates = transition_rates
self.stoichiometry = stoichiometry
self.initial_state = initial_state
self.initial_step = initial_step
self.time_delta = time_delta
self.num_steps = num_steps
def ngm(self, t, state, param):
"""Computes a next generation matrix -- pressure from i to j is G_{ij}
:param t: the time step
:param state: a tensor of shape [M, S] for S states and M population strata.
States are S, E, I, R.
:return: a tensor of shape [M, M] giving the expected number of new cases of
disease individuals in each metapopulation give rise to.
"""
w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, self.W.shape[0] - 1)
commute_volume = tf.gather(self.W, w_idx)
xi_idx = tf.cast(
tf.clip_by_value(t // self.xi_freq, 0, self.params["xi"].shape[0] - 1),
dtype=tf.int64,
)
xi = tf.gather(self.params["xi"], xi_idx)
beta = param["beta1"] * tf.math.exp(xi)
parameters = dict(locals())
with tf.name_scope(name) as name:
self._transition_rates = transition_rates
self._stoichiometry = np.array(stoichiometry, dtype=DTYPE)
self._initial_state = initial_state
self._initial_step = initial_step
self._time_delta = time_delta
self._num_steps = num_steps
super().__init__(
dtype=initial_state.dtype,
reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
name=name,
)
ngm = beta * (
tf.eye(self.C.shape[0], dtype=state.dtype)
+ param["beta2"] * commute_volume * self.C / self.N[tf.newaxis, :]
)
ngm = (
ngm
* state[..., 0][..., tf.newaxis]
/ (self.N[:, tf.newaxis] * param["gamma"])
self.dtype = initial_state.dtype
@property
def transition_rates(self):
return self._transition_rates
@property
def stoichiometry(self):
return self._stoichiometry
@property
def initial_state(self):
return self._initial_state
@property
def initial_step(self):
return self._initial_step
@property
def time_delta(self):
return self._time_delta
@property
def num_steps(self):
return self._num_steps
def _batch_shape(self):
return tf.TensorShape([])
def _event_shape(self):
shape = tf.TensorShape(
[
self.initial_state.shape[0],
tf.get_static_value(self._num_steps),
self._stoichiometry.shape[0],
]
)
return ngm
return shape
def sample(self, seed=None):
def _sample_n(self, n, seed=None):
"""Runs a simulation from the epidemic model
:param param: a dictionary of model parameters
:param state_init: the initial state
:returns: a tuple of times and simulated states.
"""
t, sim = discrete_markov_simulation(
hazard_fn=self.transition_rates,
state=self.initial_state,
start=self.initial_step,
end=self.initial_step + self.num_steps * self.time_delta,
time_step=self.time_delta,
seed=seed,
)
return t, sim
with tf.name_scope("DiscreteTimeStateTransitionModel.log_prob"):
t, sim = discrete_markov_simulation(
hazard_fn=self.transition_rates,
state=self.initial_state,
start=self.initial_step,
end=self.initial_step + self.num_steps * self.time_delta,
time_step=self.time_delta,
seed=seed,
)
indices = transition_coords(self.stoichiometry)
sim = batch_gather(sim, indices)
sim = tf.transpose(sim, perm=(1, 0, 2))
return tf.expand_dims(sim, 0)
def log_prob(self, y):
def _log_prob(self, y, **kwargs):
"""Calculates the log probability of observing epidemic events y
:param y: a list of tensors. The first is of shape [n_times] containing times,
the second is of shape [n_times, n_states, n_states] containing event matrices.
:param param: a list of parameters
:returns: a scalar giving the log probability of the epidemic
"""
dtype = dtype = dtype_util.common_dtype(
[y, self.initial_state], dtype_hint=DTYPE
)
dtype = dtype_util.common_dtype([y, self.initial_state], dtype_hint=DTYPE)
y = tf.convert_to_tensor(y, dtype)
with tf.name_scope("CovidUKStochastic.log_prob"):
hazard = self.transition_rates
......
......@@ -129,7 +129,6 @@ def phe_case_data(linelisting_file, date_range=None, date_type="specimen", pilla
index = pd.MultiIndex.from_product(
[full_dates, all_regions], names=["date", "region_code"]
)
print(index)
case_counts = ts.reindex(index)
case_counts.loc[case_counts.isna()] = 0.0
......
......@@ -11,7 +11,7 @@ import tqdm
import yaml
from covid import config
from covid.model import load_data, CovidUKStochastic
from covid.model import load_data, DiscreteTimeStateTransitionModel
from covid.pydata import phe_case_data
from covid.util import sanitise_parameter, sanitise_settings, impute_previous_cases
from covid.impl.util import compute_state
......@@ -21,6 +21,8 @@ from covid.impl.occult_events_mh import UncalibratedOccultUpdate, TransitionTopo
from covid.impl.gibbs import DeterministicScanKernel, GibbsStep, flatten_results
from covid.impl.multi_scan_kernel import MultiScanKernel
from model_spec import CovidUK
###########
# TF Bits #
###########
......@@ -65,7 +67,11 @@ covar_data = load_data(config["data"], settings, DTYPE)
# We load in cases and impute missing infections first, since this sets the
# time epoch which we are analysing.
cases = phe_case_data(config["data"]["reported_cases"], date_range=settings["inference_period"], date_type='report')
cases = phe_case_data(
config["data"]["reported_cases"],
date_range=settings["inference_period"],
date_type="report",
)
ei_events, lag_ei = impute_previous_cases(cases, 0.44)
se_events, lag_se = impute_previous_cases(ei_events, 2.0)
ir_events = np.pad(cases, ((0, 0), (lag_ei + lag_se - 2, 0)))
......@@ -89,86 +95,28 @@ num_xi = events.shape[1] // xi_freq
num_metapop = covar_data["pop"].shape[0]
# Create the epidemic model given parameters
def build_epidemic(param):
def transition_rates(t, state):
C = tf.convert_to_tensor(covar_data["C"], dtype=DTYPE)
C = tf.linalg.set_diag(C + tf.transpose(C), tf.zeros(C.shape[0], dtype=DTYPE))
W = tf.constant(covar_data["W"], dtype=DTYPE)
N = tf.constant(covar_data["pop"], dtype=DTYPE)
w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, W.shape[0] - 1)
commute_volume = tf.gather(W, w_idx)
xi_idx = tf.cast(
tf.clip_by_value(t // 14, 0, param["xi"].shape[0] - 1), dtype=tf.int64,
)
xi = tf.gather(param["xi"], xi_idx)
beta = param["beta1"] * tf.math.exp(xi)
infec_rate = beta * (
state[..., 2]
+ param["beta2"] * commute_volume * tf.linalg.matvec(C, state[..., 2] / N)
)
infec_rate = infec_rate / N + 0.000000001 # Vector of length nc
ei = tf.broadcast_to(
[param["nu"]], shape=[state.shape[0]]
) # Vector of length nc
ir = tf.broadcast_to(
[param["gamma"]], shape=[state.shape[0]]
) # Vector of length nc
return [infec_rate, ei, ir]
return CovidUKStochastic(
transition_rates=transition_rates,
stoichiometry=STOICHIOMETRY,
initial_state=initial_state,
initial_step=0,
time_delta=1.0,
num_steps=events.shape[1],
)
model = CovidUK(
covariates=covar_data,
xi_freq=14,
initial_state=initial_state,
initial_step=0,
num_steps=events.shape[1],
)
##########################
# Log p and MCMC kernels #
##########################
def logp(theta, xi, events):
p = param
p["beta1"] = tf.convert_to_tensor(theta[0], dtype=DTYPE)
p["beta2"] = tf.convert_to_tensor(theta[1], dtype=DTYPE)
p["gamma"] = tf.convert_to_tensor(theta[2], dtype=DTYPE)
p["xi"] = tf.convert_to_tensor(xi, dtype=DTYPE)
beta1 = tfd.Gamma(
concentration=tf.constant(1.0, dtype=DTYPE), rate=tf.constant(1.0, dtype=DTYPE)
)
sigma = tf.constant(0.01, dtype=DTYPE)
phi = tf.constant(12.0, dtype=DTYPE)
kernel = tfp.math.psd_kernels.MaternThreeHalves(sigma, phi)
idx_pts = tf.cast(tf.range(events.shape[1] // xi_freq) * xi_freq, dtype=DTYPE)
xi = tfd.GaussianProcess(kernel, index_points=idx_pts[:, tf.newaxis])
spatial_beta = tfd.Gamma(
concentration=tf.constant(3.0, dtype=DTYPE), rate=tf.constant(10.0, dtype=DTYPE)
)
gamma = tfd.Gamma(
concentration=tf.constant(100.0, dtype=DTYPE),
rate=tf.constant(400.0, dtype=DTYPE),
)
with tf.name_scope("epidemic_log_posterior"):
seir = build_epidemic(p)
return (
beta1.log_prob(p["beta1"])
+ xi.log_prob(p["xi"])
+ spatial_beta.log_prob(p["beta2"])
+ gamma.log_prob(p["gamma"])
+ seir.log_prob(events)
return model.log_prob(
dict(
beta1=theta[0],
beta2=theta[1],
gamma=theta[2],
xi=xi,
nu=param["nu"],
seir=events,
)
)
......
"""Implements the COVID SEIR model as a TFP Joint Distribution"""
import tensorflow as tf
import tensorflow_probability as tfp
from covid.config import floatX
from covid.model import DiscreteTimeStateTransitionModel
tfd = tfp.distributions
DTYPE = floatX
STOICHIOMETRY = tf.constant([[-1, 1, 0, 0], [0, -1, 1, 0], [0, 0, -1, 1]])
TIME_DELTA = 1.0
def CovidUK(covariates, xi_freq, initial_state, initial_step, num_steps):
def beta1():
return tfd.Gamma(
concentration=tf.constant(1.0, dtype=DTYPE),
rate=tf.constant(1.0, dtype=DTYPE),
)
def beta2():
return tfd.Gamma(
concentration=tf.constant(3.0, dtype=DTYPE),
rate=tf.constant(10.0, dtype=DTYPE),
)
def xi():
sigma = tf.constant(0.01, dtype=DTYPE)
phi = tf.constant(12.0, dtype=DTYPE)
kernel = tfp.math.psd_kernels.MaternThreeHalves(sigma, phi)
idx_pts = tf.cast(tf.range(num_steps // xi_freq) * xi_freq, dtype=DTYPE)
return tfd.GaussianProcess(kernel, index_points=idx_pts[:, tf.newaxis])
def nu():
return tfd.Gamma(
concentration=tf.constant(1.0, dtype=DTYPE),
rate=tf.constant(1.0, dtype=DTYPE),
)
def gamma():
return tfd.Gamma(
concentration=tf.constant(100.0, dtype=DTYPE),
rate=tf.constant(400.0, dtype=DTYPE),
)
def seir(beta1, beta2, xi, nu, gamma):
beta1 = tf.convert_to_tensor(beta1, DTYPE)
beta2 = tf.convert_to_tensor(beta2, DTYPE)
xi = tf.convert_to_tensor(xi, DTYPE)
nu = tf.convert_to_tensor(nu, DTYPE)
gamma = tf.convert_to_tensor(gamma, DTYPE)
def transition_rate_fn(t, state):
C = tf.convert_to_tensor(covariates["C"], dtype=DTYPE)
C = tf.linalg.set_diag(
C + tf.transpose(C), tf.zeros(C.shape[0], dtype=DTYPE)
)
W = tf.constant(covariates["W"], dtype=DTYPE)
N = tf.constant(covariates["pop"], dtype=DTYPE)
w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, W.shape[0] - 1)
commute_volume = tf.gather(W, w_idx)
xi_idx = tf.cast(
tf.clip_by_value(t // 14, 0, xi.shape[0] - 1), dtype=tf.int64,
)
xi_ = tf.gather(xi, xi_idx)
beta = beta1 * tf.math.exp(xi_)
infec_rate = beta * (
state[..., 2]
+ beta2 * commute_volume * tf.linalg.matvec(C, state[..., 2] / N)
)
infec_rate = infec_rate / N + 0.000000001 # Vector of length nc
ei = tf.broadcast_to([nu], shape=[state.shape[0]]) # Vector of length nc
ir = tf.broadcast_to([gamma], shape=[state.shape[0]]) # Vector of length nc
return [infec_rate, ei, ir]
return DiscreteTimeStateTransitionModel(
transition_rates=transition_rate_fn,
stoichiometry=STOICHIOMETRY,
initial_state=initial_state,
initial_step=initial_step,
time_delta=TIME_DELTA,
num_steps=num_steps,
)
return tfd.JointDistributionNamed(
dict(beta1=beta1, beta2=beta2, xi=xi, nu=nu, gamma=gamma, seir=seir)
)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment