Commit 043fa4f5 by Chris Jewell

### Working MCMC for known event times.

parent 2d423d9e
 ... ... @@ -77,11 +77,13 @@ def discrete_markov_log_prob(events, init_state, hazard_fn, time_step): states = tf.concat([[init_state], tf.reduce_sum(events, axis=-2)], axis=-3)[:-1] t = tf.range(states.shape[-3]) rate_matrix = hazard_fn(t, states) rate_matrix = tf.linalg.set_diag(rate_matrix, -tf.reduce_sum(rate_matrix, axis=-1)) markov_transition = tf.linalg.expm(rate_matrix*time_step) log_mt = tf.math.log(markov_transition) idx = events > 0. # Todo: probably should check for x>0 when p==0. logp = tf.reduce_sum(events[idx] * log_mt[idx]) return logp def log_prob_t(a, elems): t, event, state = elems rate_matrix = hazard_fn(t, state) rate_matrix = tf.linalg.set_diag(rate_matrix, -tf.reduce_sum(rate_matrix, axis=-1)) markov_transition = tf.linalg.expm(rate_matrix*time_step) logp = tfd.Multinomial(state, probs=markov_transition).log_prob(event) return a + tf.reduce_sum(logp) return tf.foldl(log_prob_t, (t, events, states), initializer=tf.constant(0., events.dtype))
 ... ... @@ -4,9 +4,10 @@ import tensorflow_probability as tfp from tensorflow_probability.python.internal import dtype_util import numpy as np from covid.impl.util import make_transition_rate_matrix from covid.rdata import load_mobility_matrix, load_population, load_age_mixing from covid.pydata import load_commute_volume from covid.impl.discrete_markov import discrete_markov_simulation from covid.impl.discrete_markov import discrete_markov_simulation, discrete_markov_log_prob tode = tfp.math.ode tla = tf.linalg ... ... @@ -236,27 +237,18 @@ class CovidUKStochastic(CovidUK): t_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, self.max_t) m_switch = tf.gather(self.m_select, t_idx) commute_volume = tf.pow(tf.gather(self.W, t_idx), param['omega']) lockdown = tf.gather(self.lockdown_select, t_idx) beta = tf.where(lockdown == 0, param['beta1'], param['beta1'] * param['beta3']) infec_rate = param['beta1'] * ( tf.gather(self.M.matvec(state[:, 2]), m_switch) + param['beta2'] * self.Kbar * commute_volume * self.C.matvec(state[:, 2] / self.N_sum)) infec_rate = beta * ( tf.gather(self.M.matvec(state[..., 2]), m_switch) + param['beta2'] * self.Kbar * commute_volume * self.C.matvec(state[..., 2] / self.N_sum)) infec_rate = infec_rate / self.N # Vector of length nc ei = tf.broadcast_to([param['nu']], shape=[state.shape[0]]) # Vector of length nc ir = tf.broadcast_to([param['gamma']], shape=[state.shape[0]]) # Vector of length nc # Scatter rates into a [nc, ns, ns] tensor n = state.shape[0] b = tf.stack([tf.range(n), tf.zeros(n, dtype=tf.int32), tf.ones(n, dtype=tf.int32)], axis=-1) indices = tf.stack([b, b + [0, 1, 1], b + [0, 2, 2]], axis=-2) # Un-normalised rate matrix (diag is 0 here) rate_matrix = tf.scatter_nd(indices=indices, updates=tf.stack([infec_rate, ei, ir], axis=-1), shape=[state.shape[0], state.shape[1], state.shape[1]]) # Tensor of dim [nc, ns, ns] ei = tf.broadcast_to([tf.convert_to_tensor(param['nu'])], shape=[state.shape[0]]) # Vector of length nc ir = tf.broadcast_to([tf.convert_to_tensor(param['gamma'])], shape=[state.shape[0]]) # Vector of length nc rate_matrix = make_transition_rate_matrix([infec_rate, ei, ir], [[0, 1], [1, 2], [2, 3]], state) return rate_matrix return h ... ... @@ -273,3 +265,13 @@ class CovidUKStochastic(CovidUK): t, sim = discrete_markov_simulation(hazard, state_init, np.float64(0.), np.float64(self.times.shape[0]), self.time_step) return t, sim def log_prob(self, y, param, state_init): """Calculates the log probability of observing epidemic events y :param y: a list of tensors. The first is of shape [n_times] containing times, the second is of shape [n_times, n_states, n_states] containing event matrices. :param param: a list of parameters :returns: a scalar giving the log probability of the epidemic """ hazard = self.make_h(param) return discrete_markov_log_prob(y, state_init, hazard, self.time_step)
 import optparse import time import pickle as pkl import tensorflow as tf import tensorflow_probability as tfp tfd = tfp.distributions tfb = tfp.bijectors import numpy as np import matplotlib.pyplot as plt import yaml ... ... @@ -11,6 +16,21 @@ from covid.util import sanitise_parameter, sanitise_settings, seed_areas DTYPE = np.float64 def random_walk_mvnorm_fn(covariance, name=None): """Returns callable that adds Multivariate Normal noise to the input""" covariance = covariance + tf.eye(covariance.shape[0], dtype=tf.float64) * 1.e-9 scale_tril = tf.linalg.cholesky(covariance) rv = tfp.distributions.MultivariateNormalTriL(loc=tf.zeros(covariance.shape[0], dtype=tf.float64), scale_tril=scale_tril) def _fn(state_parts, seed): with tf.name_scope(name or 'random_walk_mvnorm_fn'): new_state_parts = [rv.sample() + state_part for state_part in state_parts] return new_state_parts return _fn def sum_age_groups(sim): infec = sim[:, 2, :] infec = infec.reshape([infec.shape[0], 152, 17]) ... ... @@ -199,3 +219,5 @@ if __name__ == '__main__': fig_uk.gca().grid(True) plt.show() with open('stochastic_sim.pkl', 'wb') as f: pkl.dump({'events': upd, 'state_init': state_init}, f)