Commit e3744cc2 by Chris Jewell

### Chain binomial simulation is now running, but seems very slow.

parent cc975ee0
 ... ... @@ -12,13 +12,13 @@ def chain_binomial_propagate(h, time_step): :param time_step: the time step :returns : a function that propagate state[t] -> state[t+time_step] """ def propagate_fn(state): def propagate_fn(t, state): # State is assumed to be of shape [s, n] where s is the number of states # and n is the number of population strata. # TODO: having state as [s, n] means we have to do some funky transposition. It may be better # to have state.shape = [n, s] which avoids transposition below, but may lead to slower # rate calculations. rate_matrix = h(state) rate_matrix = h(t, state) rate_matrix = tf.transpose(rate_matrix, perm=[2, 0, 1]) # Set diagonal to be the negative of the sum of other elements in each row rate_matrix = tf.linalg.set_diag(rate_matrix, -tf.reduce_sum(rate_matrix, axis=2)) ... ... @@ -41,7 +41,7 @@ def chain_binomial_simulate(hazard_fn, state, start, end, time_step): output = output.write(0, state) for i in tf.range(1, times.shape[0]): state = propagate(state) state = propagate(i, state) output = output.write(i, state) sim = output.gather(tf.range(times.shape[0])) ... ...
 ... ... @@ -38,7 +38,7 @@ def dense_to_block_diagonal(A, n_blocks): return A_block class CovidUKODE: # TODO: add background case importation rate to the UK, e.g. \epsilon term. class CovidUK: def __init__(self, M_tt: np.float64, M_hh: np.float64, ... ... @@ -47,7 +47,7 @@ class CovidUKODE: # TODO: add background case importation rate to the UK, e.g. N: np.float64, date_range: list, holidays: list, t_step: np.int64): time_step: np.int64): """Represents a CovidUK ODE model :param K_tt: a MxM matrix of age group mixing in term time ... ... @@ -91,14 +91,35 @@ class CovidUKODE: # TODO: add background case importation rate to the UK, e.g. N_sum = N_sum[:, None] * tf.ones([1, self.n_ages], dtype=dtype) self.N_sum = tf.reshape(N_sum, [-1]) self.times = np.arange(date_range[0], date_range[1], np.timedelta64(t_step, 'D')) self.time_step = time_step self.times = np.arange(date_range[0], date_range[1], np.timedelta64(int(time_step), 'D')) self.m_select = np.int64((self.times >= holidays[0]) & (self.times < holidays[1])) self.max_t = self.m_select.shape[0] - 1 def create_initial_state(self, init_matrix=None): if init_matrix is None: I = np.zeros(self.N.shape, dtype=np.float64) I[149*17+10] = 30. # Middle-aged in Surrey else: np.testing.assert_array_equal(init_matrix.shape, [self.n_lads, self.n_ages], err_msg=f"init_matrix does not have shape [,] \ ({self.n_lads},{self.n_ages})") I = init_matrix.flatten() S = self.N - I E = np.zeros(self.N.shape, dtype=np.float64) R = np.zeros(self.N.shape, dtype=np.float64) return np.stack([S, E, I, R]) class CovidUKODE(CovidUK): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.solver = tode.DormandPrince() def make_h(self, param): def h_fn(t, state): ... ... @@ -125,20 +146,6 @@ class CovidUKODE: # TODO: add background case importation rate to the UK, e.g. return h_fn def create_initial_state(self, init_matrix=None): if init_matrix is None: I = np.zeros(self.N.shape, dtype=np.float64) I[149*17+10] = 30. 
# Middle-aged in Surrey else: np.testing.assert_array_equal(init_matrix.shape, [self.n_lads, self.n_ages], err_msg=f"init_matrix does not have shape [,] \ ({self.n_lads},{self.n_ages})") I = init_matrix.flatten() S = self.N - I E = np.zeros(self.N.shape, dtype=np.float64) R = np.zeros(self.N.shape, dtype=np.float64) return np.stack([S, E, I, R]) def simulate(self, param, state_init, solver_state=None): h = self.make_h(param) t = np.arange(self.times.shape[0]) ... ... @@ -168,3 +175,57 @@ def covid19uk_logp(y, sim, phi): r_incr = tf.reduce_sum(r_incr, axis=1) y_ = tfp.distributions.Poisson(rate=phi*r_incr) return y_.log_prob(y) class CovidUKStochastic(CovidUK): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def make_h(self, param): """Constructs a function that takes state and outputs a transition rate matrix (with 0 diagonal). """ def h(t, state): """Computes a transition rate matrix :param state: a tensor of shape [ns, nc] for ns states and nc population strata. States are S, E, I, R. We arrange the state like this because the state vectors are then arranged contiguously in memory for fast calculation below. 
:return a tensor of shape [ns, ns, nc] containing transition matric for each i=0,...,(c-1) """ t_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, self.max_t) m_switch = tf.gather(self.m_select, t_idx) commute_volume = tf.pow(tf.gather(self.W, t_idx), param['omega']) infec_rate = param['beta1'] * ( tf.gather(self.M.matvec(state[2]), m_switch) + param['beta2'] * self.Kbar * commute_volume * self.C.matvec(state[2] / self.N_sum)) infec_rate = infec_rate / self.N ei = tf.broadcast_to([param['nu']], shape=[state.shape[1]]) ir = tf.broadcast_to([param['gamma']], shape=[state.shape[1]]) # Scatter rates into a [ns, ns, nc] tensor rates = [infec_rate, ei, ir] rates = tf.scatter_nd(updates=rates, indices=[[0, 1], [1, 2], [2, 3]], shape=[state.shape[0], state.shape[0], state.shape[1]]) return rates return h @tf.function def simulate(self, param, state_init): """Runs a simulation from the epidemic model :param param: a dictionary of model parameters :param state_init: the initial state :returns: a tuple of times and simulated states. """ param = {k: tf.constant(v, dtype=tf.float64) for k, v in param.items()} hazard = self.make_h(param) t, sim = chain_binomial_simulate(hazard, state_init, np.float64(0.), np.float64(self.times.shape[0]), self.time_step) return t, sim
 # Stochastic CovidUK simulation driver (re-flowed from a flattened diff view).
"""Run a chain-binomial (stochastic) CovidUK simulation and plot the results.

The script reads a YAML configuration file, builds a ``CovidUKStochastic``
model, runs one simulation on the CPU, prints the wall-clock time and the
doubling time, and shows the age-specific attack rate and the UK-wide
infection curve.

Simulation output handled by the helpers below is indexed as
``sim[time, state, stratum]``; state index 2 is read as "infected" and state
index 3 as "removed".  The stratum axis is reshaped to
``[N_LADS, N_AGES]`` wherever a per-area or per-age summary is needed.
"""

import optparse
import time

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import yaml

from covid.model import CovidUKStochastic
# NOTE(review): the loaders used in the __main__ section (load_age_mixing,
# load_mobility_matrix, load_population) are not imported by name anywhere in
# this script, so they presumably come from this star import -- confirm.
from covid.rdata import *
from covid.pydata import load_commute_volume
from covid.util import sanitise_parameter, sanitise_settings, seed_areas

DTYPE = np.float64

# Dimensions used to unflatten the stratum axis of a simulation.
N_LADS = 152
N_AGES = 17

# Indices of the compartments read from the state axis.
INFECTED = 2
REMOVED = 3

# Date assigned to the first simulated time point in the plots.
PLOT_START_DATE = np.datetime64('2020-02-20')


def _by_lad_and_age(sim, state_index):
    """Return one compartment of ``sim`` reshaped to ``[time, N_LADS, N_AGES]``.

    :param sim: array indexed as ``[time, state, stratum]``
    :param state_index: index of the compartment on the state axis
    :returns: array of shape ``[time, N_LADS, N_AGES]``
    """
    compartment = sim[:, state_index, :]
    return compartment.reshape([compartment.shape[0], N_LADS, N_AGES])


def sum_age_groups(sim):
    """Sum the infected compartment over the age axis.

    :param sim: array indexed as ``[time, state, stratum]``
    :returns: array of shape ``[time, N_LADS]``
    """
    return _by_lad_and_age(sim, INFECTED).sum(axis=2)


def sum_la(sim):
    """Sum the infected compartment over the local-authority axis.

    :param sim: array indexed as ``[time, state, stratum]``
    :returns: array of shape ``[time, N_AGES]``
    """
    return _by_lad_and_age(sim, INFECTED).sum(axis=1)


def sum_total_removals(sim):
    """Sum the removed compartment over all strata.

    :param sim: array indexed as ``[time, state, stratum]``
    :returns: array of shape ``[time]``
    """
    return sim[:, REMOVED, :].sum(axis=1)


def final_size(sim):
    """Return the removed count at the last time point, per age group.

    :param sim: array indexed as ``[time, state, stratum]``
    :returns: array of shape ``[N_AGES]``
    """
    return _by_lad_and_age(sim, REMOVED)[-1, :, :].sum(axis=0)


def plot_total_curve(sim):
    """Plot total infected and total removed against date on the current axes.

    :param sim: array indexed as ``[time, state, stratum]``
    """
    infec_uk = sum_la(sim).sum(axis=1)
    removals = sum_total_removals(sim)
    times = PLOT_START_DATE + np.arange(removals.shape[0])
    plt.plot(times, infec_uk, 'r-', label='Infected')
    plt.plot(times, removals, 'b-', label='Removed')
    plt.title('UK total cases')
    plt.xlabel('Date')
    plt.ylabel('Num infected or removed')
    plt.grid()
    plt.legend()


def plot_infec_curve(ax, sim, label):
    """Plot the total infected curve on ``ax``.

    :param ax: matplotlib axes to draw on
    :param sim: array indexed as ``[time, state, stratum]``
    :param label: legend label for the curve
    """
    infec_uk = sum_la(sim).sum(axis=1)
    times = PLOT_START_DATE + np.arange(infec_uk.shape[0])
    ax.plot(times, infec_uk, '-', label=label)


def _plot_strata(infec, labels, t0, ax):
    """Plot one curve per column of ``infec`` plus the across-column mean.

    :param infec: array of shape ``[time, n_strata]``
    :param labels: sequence of legend labels, one per column of ``infec``
    :param t0: date of the first time point
    :param ax: matplotlib axes, or ``None`` to create a new figure
    :returns: the axes that were drawn on
    """
    if ax is None:
        ax = plt.figure().gca()
    mean_curve = infec.mean(axis=1)
    t = t0 + np.arange(infec.shape[0])
    colours = plt.cm.viridis(np.linspace(0., 1., infec.shape[1]))
    for i, colour in enumerate(colours):
        # Colour comes from the keyword only; a colour in the format string
        # as well makes matplotlib warn about a redundant definition.
        ax.plot(t, infec[:, i], '-', alpha=0.4, color=colour,
                label=labels[i])
    ax.plot(t, mean_curve, '-', color='black', label='Mean')
    return ax


def plot_by_age(sim, labels, t0=np.datetime64('2020-02-20'), ax=None):
    """Plot the infected curve of every age group and their mean.

    :param sim: array indexed as ``[time, state, stratum]``
    :param labels: legend labels, one per age group
    :param t0: date of the first time point
    :param ax: matplotlib axes, or ``None`` to create a new figure
    :returns: the axes that were drawn on
    """
    return _plot_strata(sum_la(sim), labels, t0, ax)


def plot_by_la(sim, labels, t0=np.datetime64('2020-02-20'), ax=None):
    """Plot the infected curve of every local authority and their mean.

    :param sim: array indexed as ``[time, state, stratum]``
    :param labels: legend labels, one per local authority
    :param t0: date of the first time point
    :param ax: matplotlib axes, or ``None`` to create a new figure
    :returns: the axes that were drawn on
    """
    return _plot_strata(sum_age_groups(sim), labels, t0, ax)


def draw_figs(sim, N):
    """Print attack rates and save/show the standard set of figures.

    Reads the module-level names ``la_names`` and ``age_groups``, which are
    only defined when this file is run as a script.

    :param sim: array indexed as ``[time, state, stratum]``
    :param N: flat population array, reshapeable to ``[N_LADS, N_AGES]``
    """
    # Attack rate
    N = N.reshape([N_LADS, N_AGES]).sum(axis=0)
    fs = final_size(sim)
    attack_rate = fs / N
    print("Attack rate:", attack_rate)
    print("Overall attack rate: ", np.sum(fs) / np.sum(N))

    # Total UK epidemic curve
    plot_total_curve(sim)
    plt.xticks(rotation=45, horizontalalignment="right")
    plt.savefig('total_uk_curve.pdf')
    plt.show()

    # TotalUK epidemic curve by age-group
    fig, ax = plt.subplots(1, 2, figsize=[24, 12])
    plot_by_la(sim, la_names, ax=ax[0])
    plot_by_age(sim, age_groups, ax=ax[1])
    ax[1].legend()
    plt.xticks(rotation=45, horizontalalignment="right")
    fig.autofmt_xdate()
    plt.savefig('la_age_infec_curves.pdf')
    plt.show()

    # Plot attack rate
    plt.figure(figsize=[4, 2])
    plt.plot(age_groups, attack_rate, 'o-')
    plt.xticks(rotation=90)
    plt.title('Age-specific attack rate')
    plt.savefig('age_attack_rate.pdf')
    plt.show()


def doubling_time(t, sim, t1, t2):
    """Estimate the doubling time of total removals between two dates.

    Computed as ``(i2 - i1) * log(2) / log(r[i2] / r[i1])`` where ``i1`` and
    ``i2`` are the positions of ``t1`` and ``t2`` in ``t`` and ``r`` is the
    total removed count.  Positions come from ``np.where``, so the result is
    an array; it is empty if either date does not occur in ``t``.

    :param t: array of dates, one per time point of ``sim``
    :param sim: array indexed as ``[time, state, stratum]``
    :param t1: first date, anything accepted by ``np.datetime64``
    :param t2: second date, anything accepted by ``np.datetime64``
    :returns: array holding the doubling time in units of time steps
    """
    t1 = np.where(t == np.datetime64(t1))[0]
    t2 = np.where(t == np.datetime64(t2))[0]
    delta = t2 - t1
    r = sum_total_removals(sim)
    q1 = r[t1]
    q2 = r[t2]
    return delta * np.log(2) / np.log(q2 / q1)


def plot_age_attack_rate(ax, sim, N, label):
    """Plot the age-specific attack rate on ``ax``.

    Reads the module-level name ``age_groups``, which is only defined when
    this file is run as a script.

    :param ax: matplotlib axes to draw on
    :param sim: object whose ``.numpy()`` gives an array indexed as
        ``[time, state, stratum]`` (unlike the other helpers, which take the
        array directly)
    :param N: flat population array, reshapeable to ``[N_LADS, N_AGES]``
    :param label: legend label for the curve
    """
    Ns = N.reshape([N_LADS, N_AGES]).sum(axis=0)
    fs = final_size(sim.numpy())
    attack_rate = fs / Ns
    ax.plot(age_groups, attack_rate, 'o-', label=label)


if __name__ == '__main__':

    parser = optparse.OptionParser()
    parser.add_option("--config", "-c", dest="config",
                      default="ode_config.yaml",
                      help="configuration file")
    options, args = parser.parse_args()
    with open(options.config, 'r') as ymlfile:
        # safe_load: plain yaml.load without a Loader can construct arbitrary
        # objects and is rejected outright by PyYAML >= 6.
        config = yaml.safe_load(ymlfile)

    param = sanitise_parameter(config['parameter'])
    settings = sanitise_settings(config['settings'])

    M_tt, age_groups = load_age_mixing(config['data']['age_mixing_matrix_term'])
    M_hh, _ = load_age_mixing(config['data']['age_mixing_matrix_hol'])

    C, la_names = load_mobility_matrix(config['data']['mobility_matrix'])
    np.fill_diagonal(C, 0.)

    W = load_commute_volume(config['data']['commute_volume'],
                            settings['inference_period'])['percent']

    N, n_names = load_population(config['data']['population_size'])

    M_tt = M_tt.astype(DTYPE)
    M_hh = M_hh.astype(DTYPE)
    W = W.to_numpy().astype(DTYPE)
    C = C.astype(DTYPE)
    N = N.astype(DTYPE)

    model = CovidUKStochastic(M_tt=M_tt,
                              M_hh=M_hh,
                              C=C,
                              N=N,
                              W=W,
                              date_range=settings['prediction_period'],
                              holidays=settings['holiday'],
                              time_step=1.)

    seeding = seed_areas(N, n_names)  # Seed 40-44 age group, 30 seeds by popn size
    state_init = model.create_initial_state(init_matrix=seeding)

    with tf.device('CPU'):
        start = time.perf_counter()
        t, sim = model.simulate(param, state_init)
        end = time.perf_counter()
        print(f'Complete in {end - start} seconds')

    # Plotting functions
    # NOTE(review): assumes sanitise_settings provides a 'start' date -- confirm.
    dates = settings['start'] + t.numpy().astype(np.timedelta64)
    dt = doubling_time(dates, sim.numpy(), '2020-03-01', '2020-03-31')
    print(f"Doubling time: {dt}")

    fig_attack = plt.figure()
    fig_uk = plt.figure()

    plot_age_attack_rate(fig_attack.gca(), sim, N, "Attack Rate")
    fig_attack.suptitle("Attack Rate")
    plot_infec_curve(fig_uk.gca(), sim.numpy(), "Infections")
    fig_uk.suptitle("UK Infections")

    fig_attack.autofmt_xdate()
    fig_uk.autofmt_xdate()
    fig_attack.gca().grid(True)
    fig_uk.gca().grid(True)
    plt.show()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!