Commit e3744cc2 authored by Chris Jewell's avatar Chris Jewell
Browse files

Chain binomial simulation is now running, but seems very slow.

parent cc975ee0
......@@ -12,13 +12,13 @@ def chain_binomial_propagate(h, time_step):
:param time_step: the time step
:returns : a function that propagate `state[t]` -> `state[t+time_step]`
"""
def propagate_fn(state):
def propagate_fn(t, state):
# State is assumed to be of shape [s, n] where s is the number of states
# and n is the number of population strata.
# TODO: having state as [s, n] means we have to do some funky transposition. It may be better
# to have state.shape = [n, s] which avoids transposition below, but may lead to slower
# rate calculations.
rate_matrix = h(state)
rate_matrix = h(t, state)
rate_matrix = tf.transpose(rate_matrix, perm=[2, 0, 1])
# Set diagonal to be the negative of the sum of other elements in each row
rate_matrix = tf.linalg.set_diag(rate_matrix, -tf.reduce_sum(rate_matrix, axis=2))
......@@ -41,7 +41,7 @@ def chain_binomial_simulate(hazard_fn, state, start, end, time_step):
output = output.write(0, state)
for i in tf.range(1, times.shape[0]):
state = propagate(state)
state = propagate(i, state)
output = output.write(i, state)
sim = output.gather(tf.range(times.shape[0]))
......
......@@ -38,7 +38,7 @@ def dense_to_block_diagonal(A, n_blocks):
return A_block
class CovidUKODE: # TODO: add background case importation rate to the UK, e.g. \epsilon term.
class CovidUK:
def __init__(self,
M_tt: np.float64,
M_hh: np.float64,
......@@ -47,7 +47,7 @@ class CovidUKODE: # TODO: add background case importation rate to the UK, e.g.
N: np.float64,
date_range: list,
holidays: list,
t_step: np.int64):
time_step: np.int64):
"""Represents a CovidUK ODE model
:param K_tt: a MxM matrix of age group mixing in term time
......@@ -91,14 +91,35 @@ class CovidUKODE: # TODO: add background case importation rate to the UK, e.g.
N_sum = N_sum[:, None] * tf.ones([1, self.n_ages], dtype=dtype)
self.N_sum = tf.reshape(N_sum, [-1])
self.times = np.arange(date_range[0], date_range[1], np.timedelta64(t_step, 'D'))
self.time_step = time_step
self.times = np.arange(date_range[0], date_range[1], np.timedelta64(int(time_step), 'D'))
self.m_select = np.int64((self.times >= holidays[0]) &
(self.times < holidays[1]))
self.max_t = self.m_select.shape[0] - 1
def create_initial_state(self, init_matrix=None):
if init_matrix is None:
I = np.zeros(self.N.shape, dtype=np.float64)
I[149*17+10] = 30. # Middle-aged in Surrey
else:
np.testing.assert_array_equal(init_matrix.shape, [self.n_lads, self.n_ages],
err_msg=f"init_matrix does not have shape [<num lads>,<num ages>] \
({self.n_lads},{self.n_ages})")
I = init_matrix.flatten()
S = self.N - I
E = np.zeros(self.N.shape, dtype=np.float64)
R = np.zeros(self.N.shape, dtype=np.float64)
return np.stack([S, E, I, R])
class CovidUKODE(CovidUK):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.solver = tode.DormandPrince()
def make_h(self, param):
def h_fn(t, state):
......@@ -125,20 +146,6 @@ class CovidUKODE: # TODO: add background case importation rate to the UK, e.g.
return h_fn
def create_initial_state(self, init_matrix=None):
if init_matrix is None:
I = np.zeros(self.N.shape, dtype=np.float64)
I[149*17+10] = 30. # Middle-aged in Surrey
else:
np.testing.assert_array_equal(init_matrix.shape, [self.n_lads, self.n_ages],
err_msg=f"init_matrix does not have shape [<num lads>,<num ages>] \
({self.n_lads},{self.n_ages})")
I = init_matrix.flatten()
S = self.N - I
E = np.zeros(self.N.shape, dtype=np.float64)
R = np.zeros(self.N.shape, dtype=np.float64)
return np.stack([S, E, I, R])
def simulate(self, param, state_init, solver_state=None):
h = self.make_h(param)
t = np.arange(self.times.shape[0])
......@@ -168,3 +175,57 @@ def covid19uk_logp(y, sim, phi):
r_incr = tf.reduce_sum(r_incr, axis=1)
y_ = tfp.distributions.Poisson(rate=phi*r_incr)
return y_.log_prob(y)
class CovidUKStochastic(CovidUK):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def make_h(self, param):
"""Constructs a function that takes `state` and outputs a
transition rate matrix (with 0 diagonal).
"""
def h(t, state):
"""Computes a transition rate matrix
:param state: a tensor of shape [ns, nc] for ns states and nc population strata. States
are S, E, I, R. We arrange the state like this because the state vectors are then arranged
contiguously in memory for fast calculation below.
:return a tensor of shape [ns, ns, nc] containing transition matric for each i=0,...,(c-1)
"""
t_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, self.max_t)
m_switch = tf.gather(self.m_select, t_idx)
commute_volume = tf.pow(tf.gather(self.W, t_idx), param['omega'])
infec_rate = param['beta1'] * (
tf.gather(self.M.matvec(state[2]), m_switch) +
param['beta2'] * self.Kbar * commute_volume * self.C.matvec(state[2] / self.N_sum))
infec_rate = infec_rate / self.N
ei = tf.broadcast_to([param['nu']], shape=[state.shape[1]])
ir = tf.broadcast_to([param['gamma']], shape=[state.shape[1]])
# Scatter rates into a [ns, ns, nc] tensor
rates = [infec_rate, ei, ir]
rates = tf.scatter_nd(updates=rates,
indices=[[0, 1], [1, 2], [2, 3]],
shape=[state.shape[0], state.shape[0], state.shape[1]])
return rates
return h
@tf.function
def simulate(self, param, state_init):
"""Runs a simulation from the epidemic model
:param param: a dictionary of model parameters
:param state_init: the initial state
:returns: a tuple of times and simulated states.
"""
param = {k: tf.constant(v, dtype=tf.float64) for k, v in param.items()}
hazard = self.make_h(param)
t, sim = chain_binomial_simulate(hazard, state_init, np.float64(0.),
np.float64(self.times.shape[0]), self.time_step)
return t, sim
import optparse
import time
import tensorflow as tf
import matplotlib.pyplot as plt
import yaml
from covid.model import CovidUKStochastic
from covid.rdata import *
from covid.pydata import load_commute_volume
from covid.util import sanitise_parameter, sanitise_settings, seed_areas
DTYPE = np.float64
def sum_age_groups(sim):
infec = sim[:, 2, :]
infec = infec.reshape([infec.shape[0], 152, 17])
infec_uk = infec.sum(axis=2)
return infec_uk
def sum_la(sim):
infec = sim[:, 2, :]
infec = infec.reshape([infec.shape[0], 152, 17])
infec_uk = infec.sum(axis=1)
return infec_uk
def sum_total_removals(sim):
remove = sim[:, 3, :]
return remove.sum(axis=1)
def final_size(sim):
remove = sim[:, 3, :]
remove = remove.reshape([remove.shape[0], 152, 17])
fs = remove[-1, :, :].sum(axis=0)
return fs
def plot_total_curve(sim):
infec_uk = sum_la(sim)
infec_uk = infec_uk.sum(axis=1)
removals = sum_total_removals(sim)
times = np.datetime64('2020-02-20') + np.arange(removals.shape[0])
plt.plot(times, infec_uk, 'r-', label='Infected')
plt.plot(times, removals, 'b-', label='Removed')
plt.title('UK total cases')
plt.xlabel('Date')
plt.ylabel('Num infected or removed')
plt.grid()
plt.legend()
def plot_infec_curve(ax, sim, label):
infec_uk = sum_la(sim)
infec_uk = infec_uk.sum(axis=1)
times = np.datetime64('2020-02-20') + np.arange(infec_uk.shape[0])
ax.plot(times, infec_uk, '-', label=label)
def plot_by_age(sim, labels, t0=np.datetime64('2020-02-20'), ax=None):
if ax is None:
ax = plt.figure().gca()
infec_uk = sum_la(sim)
total_uk = infec_uk.mean(axis=1)
t = t0 + np.arange(infec_uk.shape[0])
colours = plt.cm.viridis(np.linspace(0., 1., infec_uk.shape[1]))
for i in range(infec_uk.shape[1]):
ax.plot(t, infec_uk[:, i], 'r-', alpha=0.4, color=colours[i], label=labels[i])
ax.plot(t, total_uk, '-', color='black', label='Mean')
return ax
def plot_by_la(sim, labels, t0=np.datetime64('2020-02-20'), ax=None):
if ax is None:
ax = plt.figure().gca()
infec_uk = sum_age_groups(sim)
total_uk = infec_uk.mean(axis=1)
t = t0 + np.arange(infec_uk.shape[0])
colours = plt.cm.viridis(np.linspace(0., 1., infec_uk.shape[1]))
for i in range(infec_uk.shape[1]):
ax.plot(t, infec_uk[:, i], 'r-', alpha=0.4, color=colours[i], label=labels[i])
ax.plot(t, total_uk, '-', color='black', label='Mean')
return ax
def draw_figs(sim, N):
# Attack rate
N = N.reshape([152, 17]).sum(axis=0)
fs = final_size(sim)
attack_rate = fs / N
print("Attack rate:", attack_rate)
print("Overall attack rate: ", np.sum(fs) / np.sum(N))
# Total UK epidemic curve
plot_total_curve(sim)
plt.xticks(rotation=45, horizontalalignment="right")
plt.savefig('total_uk_curve.pdf')
plt.show()
# TotalUK epidemic curve by age-group
fig, ax = plt.subplots(1, 2, figsize=[24, 12])
plot_by_la(sim, la_names, ax=ax[0])
plot_by_age(sim, age_groups, ax=ax[1])
ax[1].legend()
plt.xticks(rotation=45, horizontalalignment="right")
fig.autofmt_xdate()
plt.savefig('la_age_infec_curves.pdf')
plt.show()
# Plot attack rate
plt.figure(figsize=[4, 2])
plt.plot(age_groups, attack_rate, 'o-')
plt.xticks(rotation=90)
plt.title('Age-specific attack rate')
plt.savefig('age_attack_rate.pdf')
plt.show()
def doubling_time(t, sim, t1, t2):
t1 = np.where(t == np.datetime64(t1))[0]
t2 = np.where(t == np.datetime64(t2))[0]
delta = t2 - t1
r = sum_total_removals(sim)
q1 = r[t1]
q2 = r[t2]
return delta * np.log(2) / np.log(q2 / q1)
def plot_age_attack_rate(ax, sim, N, label):
Ns = N.reshape([152, 17]).sum(axis=0)
fs = final_size(sim.numpy())
attack_rate = fs / Ns
ax.plot(age_groups, attack_rate, 'o-', label=label)
if __name__ == '__main__':
parser = optparse.OptionParser()
parser.add_option("--config", "-c", dest="config", default="ode_config.yaml",
help="configuration file")
options, args = parser.parse_args()
with open(options.config, 'r') as ymlfile:
config = yaml.load(ymlfile)
param = sanitise_parameter(config['parameter'])
settings = sanitise_settings(config['settings'])
parser = optparse.OptionParser()
parser.add_option("--config", "-c", dest="config", default="ode_config.yaml",
help="configuration file")
options, args = parser.parse_args()
with open(options.config, 'r') as ymlfile:
config = yaml.load(ymlfile)
param = sanitise_parameter(config['parameter'])
settings = sanitise_settings(config['settings'])
M_tt, age_groups = load_age_mixing(config['data']['age_mixing_matrix_term'])
M_hh, _ = load_age_mixing(config['data']['age_mixing_matrix_hol'])
C, la_names = load_mobility_matrix(config['data']['mobility_matrix'])
np.fill_diagonal(C, 0.)
W = load_commute_volume(config['data']['commute_volume'], settings['inference_period'])['percent']
N, n_names = load_population(config['data']['population_size'])
M_tt = M_tt.astype(DTYPE)
M_hh = M_hh.astype(DTYPE)
W = W.to_numpy().astype(DTYPE)
C = C.astype(DTYPE)
N = N.astype(DTYPE)
model = CovidUKStochastic(M_tt=M_tt,
M_hh=M_hh,
C=C,
N=N,
W=W,
date_range=settings['prediction_period'],
holidays=settings['holiday'],
time_step=1.)
seeding = seed_areas(N, n_names) # Seed 40-44 age group, 30 seeds by popn size
state_init = model.create_initial_state(init_matrix=seeding)
with tf.device('CPU'):
start = time.perf_counter()
t, sim = model.simulate(param, state_init)
end = time.perf_counter()
print(f'Complete in {end - start} seconds')
# Plotting functions
dates = settings['start'] + t.numpy().astype(np.timedelta64)
dt = doubling_time(dates, sim.numpy(), '2020-03-01', '2020-03-31')
print(f"Doubling time: {dt}")
fig_attack = plt.figure()
fig_uk = plt.figure()
plot_age_attack_rate(fig_attack.gca(), sim, N, "Attack Rate")
fig_attack.suptitle("Attack Rate")
plot_infec_curve(fig_uk.gca(), sim.numpy(), "Infections")
fig_uk.suptitle("UK Infections")
fig_attack.autofmt_xdate()
fig_uk.autofmt_xdate()
fig_attack.gca().grid(True)
fig_uk.gca().grid(True)
plt.show()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment