Commit ddada779 authored by Chris Jewell's avatar Chris Jewell
Browse files

Merge branch 'ode_model'

parents b52f05d2 5a434328
"""Functions for chain binomial simulation."""
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
def update_state(update, state, stoichiometry):
update = tf.expand_dims(update, 1) # Rx1xN
update *= tf.expand_dims(stoichiometry, -1) # RxSx1
update = tf.reduce_sum(update, axis=0) # SxN
return state + update
def chain_binomial_propagate(h, time_step, stoichiometry):
def propagate_fn(state):
state_idx, rates = h(state)
probs = 1 - tf.exp(-rates*time_step) # RxN
state_mult = tf.scatter_nd(state_idx[:, None], state,
shape=[state_idx.shape[0], state.shape[1], state.shape[2]])
update = tfd.Binomial(state_mult, probs=probs).sample() # RxN
update = tf.expand_dims(update, 1) # Rx1xN
upd_shape = tf.concat([stoichiometry.shape, tf.fill([tf.rank(state)-1], 1)], axis=0)
update *= tf.reshape(stoichiometry, upd_shape) # RxSx1
update = tf.reduce_sum(update, axis=0)
state = state + update
return state
return propagate_fn
def chain_binomial_propagate(h, time_step):
"""Propagates the state of a population according to discrete time dynamics.
:param h: a hazard rate function returning the non-row-normalised Markov transition rate matrix
This function should return a tensor of dimension [ns, ns, nc] where ns is the number of
states, and nc is the number of strata within the population.
:param time_step: the time step
:returns : a function that propagate `state[t]` -> `state[t+time_step]`
"""
def propagate_fn(t, state):
rate_matrix = h(t, state)
# Set diagonal to be the negative of the sum of other elements in each row
rate_matrix = tf.linalg.set_diag(rate_matrix,
-tf.reduce_sum(rate_matrix, axis=-1))
# Calculate Markov transition probability matrix
markov_transition = tf.linalg.expm(rate_matrix*time_step)
num_states = markov_transition.shape[-1]
prev_probs = tf.zeros_like(markov_transition[..., :, 0])
counts = tf.zeros(markov_transition.shape[:-1].as_list() + [0],
dtype=markov_transition.dtype)
total_count = state
# This for loop is ok because there are (currently) only 4 states (SEIR)
# and we're only actually creating work for 3 of them. Even for as many
# as a ~10 states it should probably be fine, just increasing the size
# of the graph a bit.
for i in range(num_states - 1):
probs = markov_transition[..., :, i]
binom = tfd.Binomial(
total_count=total_count,
probs=tf.clip_by_value(probs / (1. - prev_probs), 0., 1.))
sample = binom.sample()
counts = tf.concat([counts, sample[..., tf.newaxis]], axis=-1)
total_count -= sample
prev_probs += probs
def chain_binomial_simulate(hazard_fn, state, start, end, time_step, stoichiometry):
counts = tf.concat([counts, total_count[..., tf.newaxis]], axis=-1)
new_state = tf.reduce_sum(counts, axis=-2)
return new_state
return propagate_fn
propagate = chain_binomial_propagate(hazard_fn, time_step, stoichiometry)
def chain_binomial_simulate(hazard_fn, state, start, end, time_step):
"""Simulates from a discrete time Markov state transition model using multinomial sampling
across rows of the """
propagate = chain_binomial_propagate(hazard_fn, time_step)
times = tf.range(start, end, time_step)
output = tf.TensorArray(tf.float64, size=times.shape[0])
output = tf.TensorArray(state.dtype, size=times.shape[0])
output = output.write(0, state)
for i in tf.range(1, times.shape[0]):
state = propagate(state)
output = output.write(i, state)
with tf.device("/CPU:0"):
sim = output.gather(tf.range(times.shape[0]))
return times, sim
cond = lambda i, *_: i < times.shape[0]
def body(i, state, output):
state = propagate(i, state)
output = output.write(i, state)
return i + 1, state, output
_, state, output = tf.while_loop(cond, body, loop_vars=(0, state, output))
return times, output.stack()
......@@ -38,7 +38,7 @@ def dense_to_block_diagonal(A, n_blocks):
return A_block
class CovidUKODE: # TODO: add background case importation rate to the UK, e.g. \epsilon term.
class CovidUK:
def __init__(self,
M_tt: np.float64,
M_hh: np.float64,
......@@ -47,7 +47,8 @@ class CovidUKODE: # TODO: add background case importation rate to the UK, e.g.
N: np.float64,
date_range: list,
holidays: list,
t_step: np.int64):
lockdown: list,
time_step: np.int64):
"""Represents a CovidUK ODE model
:param K_tt: a MxM matrix of age group mixing in term time
......@@ -91,59 +92,71 @@ class CovidUKODE: # TODO: add background case importation rate to the UK, e.g.
N_sum = N_sum[:, None] * tf.ones([1, self.n_ages], dtype=dtype)
self.N_sum = tf.reshape(N_sum, [-1])
self.times = np.arange(date_range[0], date_range[1], np.timedelta64(t_step, 'D'))
self.time_step = time_step
self.times = np.arange(date_range[0], date_range[1], np.timedelta64(int(time_step), 'D'))
self.m_select = np.int64((self.times >= holidays[0]) &
(self.times < holidays[1]))
self.lockdown_select = np.int64((self.times >= lockdown[0]) &
(self.times < lockdown[1]))
self.max_t = self.m_select.shape[0] - 1
def create_initial_state(self, init_matrix=None):
if init_matrix is None:
I = np.zeros(self.N.shape, dtype=np.float64)
I[149*17+10] = 30. # Middle-aged in Surrey
else:
np.testing.assert_array_equal(init_matrix.shape, [self.n_lads, self.n_ages],
err_msg=f"init_matrix does not have shape [<num lads>,<num ages>] \
({self.n_lads},{self.n_ages})")
I = init_matrix.flatten()
S = self.N - I
E = np.zeros(self.N.shape, dtype=np.float64)
R = np.zeros(self.N.shape, dtype=np.float64)
return np.stack([S, E, I, R], axis=-1)
class CovidUKODE(CovidUK):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.solver = tode.DormandPrince()
def make_h(self, param):
def h_fn(t, state):
state = tf.unstack(state, axis=0)
S, E, I, R = state
state_ = tf.transpose(state)
S, E, I, R = tf.unstack(state_, axis=0)
# Integrator may produce time values outside the range desired, so
# we clip, implicitly assuming the outside dates have the same
# holiday status as their nearest neighbors in the desired range.
t_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, self.max_t)
m_switch = tf.gather(self.m_select, t_idx)
commute_volume = tf.pow(tf.gather(self.W, t_idx), param['omega'])
lockdown = tf.gather(self.lockdown_select, t_idx)
beta = tf.where(lockdown == 0, param['beta1'], param['beta1']*param['beta3'])
infec_rate = param['beta1'] * (
infec_rate = beta * (
tf.gather(self.M.matvec(I), m_switch) +
param['beta2'] * self.Kbar * commute_volume * self.C.matvec(I / self.N_sum))
infec_rate = S / self.N * infec_rate
infec_rate = S * infec_rate / self.N
dS = -infec_rate
dE = infec_rate - param['nu'] * E
dI = param['nu'] * E - param['gamma'] * I
dR = param['gamma'] * I
df = tf.stack([dS, dE, dI, dR])
df = tf.stack([dS, dE, dI, dR], axis=-1)
return df
return h_fn
def create_initial_state(self, init_matrix=None):
if init_matrix is None:
I = np.zeros(self.N.shape, dtype=np.float64)
I[149*17+10] = 30. # Middle-aged in Surrey
else:
np.testing.assert_array_equal(init_matrix.shape, [self.n_lads, self.n_ages],
err_msg=f"init_matrix does not have shape [<num lads>,<num ages>] \
({self.n_lads},{self.n_ages})")
I = init_matrix.flatten()
S = self.N - I
E = np.zeros(self.N.shape, dtype=np.float64)
R = np.zeros(self.N.shape, dtype=np.float64)
return np.stack([S, E, I, R])
def simulate(self, param, state_init, solver_state=None):
h = self.make_h(param)
t = np.arange(self.times.shape[0])
results = self.solver.solve(ode_fn=h, initial_time=t[0], initial_state=state_init,
results = self.solver.solve(ode_fn=h, initial_time=t[0], initial_state=state_init * param['I0'],
solution_times=t, previous_solver_internal_state=solver_state)
return results.times, results.states, results.solver_internal_state
......@@ -163,9 +176,72 @@ class CovidUKODE: # TODO: add background case importation rate to the UK, e.g.
return tf.squeeze(R0), i
def covid19uk_logp(y, sim, phi):
def covid19uk_logp(y, sim, phi, r):
# Sum daily increments in removed
r_incr = sim[1:, 3, :] - sim[:-1, 3, :]
r_incr = tf.reduce_sum(r_incr, axis=1)
y_ = tfp.distributions.Poisson(rate=phi*r_incr)
r_incr = sim[1:, :, 3] - sim[:-1, :, 3]
r_incr = tf.reduce_sum(r_incr, axis=-1)
# Poisson(\lambda) = \lim{r\rightarrow \infty} NB(r, \frac{\lambda}{r + \lambda})
#y_ = tfp.distributions.Poisson(rate=phi*r_incr)
lambda_ = r_incr * phi
y_ = tfp.distributions.NegativeBinomial(r, probs=lambda_/(r+lambda_))
return y_.log_prob(y)
class CovidUKStochastic(CovidUK):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def make_h(self, param):
"""Constructs a function that takes `state` and outputs a
transition rate matrix (with 0 diagonal).
"""
def h(t, state):
"""Computes a transition rate matrix
:param state: a tensor of shape [ns, nc] for ns states and nc population strata. States
are S, E, I, R. We arrange the state like this because the state vectors are then arranged
contiguously in memory for fast calculation below.
:return a tensor of shape [ns, ns, nc] containing transition matric for each i=0,...,(c-1)
"""
t_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, self.max_t)
m_switch = tf.gather(self.m_select, t_idx)
commute_volume = tf.pow(tf.gather(self.W, t_idx), param['omega'])
infec_rate = param['beta1'] * (
tf.gather(self.M.matvec(state[:, 2]), m_switch) +
param['beta2'] * self.Kbar * commute_volume * self.C.matvec(state[:, 2] / self.N_sum))
infec_rate = infec_rate / self.N
ei = tf.broadcast_to([param['nu']], shape=[state.shape[0]])
ir = tf.broadcast_to([param['gamma']], shape=[state.shape[0]])
# Scatter rates into a [ns, ns, nc] tensor
n = state.shape[0]
b = tf.stack([tf.range(n),
tf.zeros(n, dtype=tf.int32),
tf.ones(n, dtype=tf.int32)], axis=-1)
indices = tf.stack([b, b + [0, 1, 1], b + [0, 2, 2]], axis=-2)
# Un-normalised rate matrix (diag is 0 here)
rate_matrix = tf.scatter_nd(indices=indices,
updates=tf.stack([infec_rate, ei, ir], axis=-1),
shape=[state.shape[0],
state.shape[1],
state.shape[1]])
return rate_matrix
return h
@tf.function(autograph=False, experimental_compile=True)
def simulate(self, param, state_init):
"""Runs a simulation from the epidemic model
:param param: a dictionary of model parameters
:param state_init: the initial state
:returns: a tuple of times and simulated states.
"""
param = {k: tf.constant(v, dtype=tf.float64) for k, v in param.items()}
hazard = self.make_h(param)
t, sim = chain_binomial_simulate(hazard, state_init, np.float64(0.),
np.float64(self.times.shape[0]), self.time_step)
return t, sim
......@@ -9,8 +9,7 @@ tfs = tfp.stats
def plot_prediction(prediction_period, sims, case_reports):
# Sum over country
sims = tf.reduce_sum(sims, axis=3)
sims = tf.reduce_sum(sims, axis=-2) # Sum over all meta-populations
quantiles = [2.5, 50, 97.5]
......@@ -29,8 +28,9 @@ def plot_prediction(prediction_period, sims, case_reports):
rem_line = plt.plot(dates, removed[1, :], '-', color='blue', label="Removed")
ro_line = plt.plot(dates, removed_observed[1, :], '-', color='orange', label='Predicted detections')
data_range = [case_reports['DateVal'].min(), case_reports['DateVal'].max()]
data_dates = np.linspace(data_range[0], data_range[1], np.timedelta64(1, 'D'))
data_range = [case_reports['DateVal'].to_numpy().min(), case_reports['DateVal'].to_numpy().max()]
one_day = np.timedelta64(1, 'D')
data_dates = np.arange(data_range[0], data_range[1]+one_day, one_day)
marks = plt.plot(data_dates, case_reports['CumCases'].to_numpy(), '+', label='Observed cases')
plt.legend([ti_line[0], rem_line[0], ro_line[0], filler, marks[0]],
["Infected", "Removed", "Predicted detections", "95% credible interval", "Observed counts"])
......@@ -44,7 +44,7 @@ def plot_prediction(prediction_period, sims, case_reports):
def plot_case_incidence(dates, sims):
# Number of new cases per day
new_cases = sims[:, :, 3, :].sum(axis=2)
new_cases = sims[:, :, :, 3].sum(axis=2)
new_cases = new_cases[:, 1:] - new_cases[:, :-1]
new_cases = tfs.percentile(new_cases, q=[2.5, 50, 97.5], axis=0)/10000.
......
......@@ -10,8 +10,7 @@ tfs = tfp.stats
def sanitise_parameter(par_dict):
"""Sanitises a dictionary of parameters"""
par = ['omega', 'beta1', 'beta2', 'nu', 'gamma']
d = {key: np.float64(par_dict[key]) for key in par}
d = {key: np.float64(val) for key, val in par_dict.items()}
return d
......@@ -19,7 +18,8 @@ def sanitise_settings(par_dict):
d = {'inference_period': np.array(par_dict['inference_period'], dtype=np.datetime64),
'prediction_period': np.array(par_dict['prediction_period'], dtype=np.datetime64),
'time_step': float(par_dict['time_step']),
'holiday': np.array([np.datetime64(date) for date in par_dict['holiday']])}
'holiday': np.array([np.datetime64(date) for date in par_dict['holiday']]),
'lockdown': np.array([np.datetime64(date) for date in par_dict['lockdown']])}
return d
......@@ -124,7 +124,7 @@ def extract_locs(in_file: str, out_file: str, loc: list):
la_names = f['la_names'][:].astype(str)
la_loc = np.isin(la_names, loc)
extract = f['prediction'][:, :, :, la_loc]
extract = f['prediction'][:, :, la_loc, :]
save_sims(f['date'][:], extract, f['la_names'][la_loc],
f['age_names'][la_loc], out_file)
......
......@@ -13,26 +13,26 @@ from covid.util import sanitise_parameter, sanitise_settings, seed_areas
def sum_age_groups(sim):
infec = sim[:, 2, :]
infec = sim[:, :, 2]
infec = infec.reshape([infec.shape[0], 152, 17])
infec_uk = infec.sum(axis=2)
return infec_uk
def sum_la(sim):
infec = sim[:, 2, :]
infec = sim[:, :, 2]
infec = infec.reshape([infec.shape[0], 152, 17])
infec_uk = infec.sum(axis=1)
return infec_uk
def sum_total_removals(sim):
remove = sim[:, 3, :]
remove = sim[:, :, 3]
return remove.sum(axis=1)
def final_size(sim):
remove = sim[:, 3, :]
remove = sim[:, :, 3]
remove = remove.reshape([remove.shape[0], 152, 17])
fs = remove[-1, :, :].sum(axis=0)
return fs
......@@ -52,7 +52,7 @@ def write_hdf5(filename, param, t, sim):
def plot_total_curve(sim):
infec_uk = sum_la(sim)
infec_uk = infec_uk.sum(axis=1)
infec_uk = infec_uk.sum(axis=-1)
removals = sum_total_removals(sim)
times = np.datetime64('2020-02-20') + np.arange(removals.shape[0])
plt.plot(times, infec_uk, 'r-', label='Infected')
......@@ -66,7 +66,7 @@ def plot_total_curve(sim):
def plot_infec_curve(ax, sim, label):
infec_uk = sum_la(sim)
infec_uk = infec_uk.sum(axis=1)
infec_uk = infec_uk.sum(axis=-1)
times = np.datetime64('2020-02-20') + np.arange(infec_uk.shape[0])
ax.plot(times, infec_uk, '-', label=label)
......@@ -75,7 +75,7 @@ def plot_by_age(sim, labels, t0=np.datetime64('2020-02-20'), ax=None):
if ax is None:
ax = plt.figure().gca()
infec_uk = sum_la(sim)
total_uk = infec_uk.mean(axis=1)
total_uk = infec_uk.mean(axis=-1)
t = t0 + np.arange(infec_uk.shape[0])
colours = plt.cm.viridis(np.linspace(0., 1., infec_uk.shape[1]))
for i in range(infec_uk.shape[1]):
......@@ -88,7 +88,7 @@ def plot_by_la(sim, labels, t0=np.datetime64('2020-02-20'), ax=None):
if ax is None:
ax = plt.figure().gca()
infec_uk = sum_age_groups(sim)
total_uk = infec_uk.mean(axis=1)
total_uk = infec_uk.mean(axis=-1)
t = t0 + np.arange(infec_uk.shape[0])
colours = plt.cm.viridis(np.linspace(0., 1., infec_uk.shape[1]))
for i in range(infec_uk.shape[1]):
......
import optparse
import time
import tensorflow as tf
import matplotlib.pyplot as plt
import yaml
from covid.model import CovidUKStochastic
from covid.rdata import *
from covid.pydata import load_commute_volume
from covid.util import sanitise_parameter, sanitise_settings, seed_areas
DTYPE = np.float64
def sum_age_groups(sim):
infec = sim[:, 2, :]
infec = infec.reshape([infec.shape[0], 152, 17])
infec_uk = infec.sum(axis=2)
return infec_uk
def sum_la(sim):
infec = sim[:, :, 2]
infec = infec.reshape([infec.shape[0], 152, 17])
infec_uk = infec.sum(axis=1)
return infec_uk
def sum_total_removals(sim):
remove = sim[:, 3, :]
return remove.sum(axis=1)
def final_size(sim):
remove = sim[:, :, 3]
remove = remove.reshape([remove.shape[0], 152, 17])
fs = remove[-1, :, :].sum(axis=0)
return fs
def plot_total_curve(sim):
infec_uk = sum_la(sim)
infec_uk = infec_uk.sum(axis=1)
removals = sum_total_removals(sim)
times = np.datetime64('2020-02-20') + np.arange(removals.shape[0])
plt.plot(times, infec_uk, 'r-', label='Infected')
plt.plot(times, removals, 'b-', label='Removed')
plt.title('UK total cases')
plt.xlabel('Date')
plt.ylabel('Num infected or removed')
plt.grid()
plt.legend()
def plot_infec_curve(ax, sim, label):
infec_uk = sum_la(sim)
infec_uk = infec_uk.sum(axis=1)
times = np.datetime64('2020-02-20') + np.arange(infec_uk.shape[0])
ax.plot(times, infec_uk, '-', label=label)
def plot_by_age(sim, labels, t0=np.datetime64('2020-02-20'), ax=None):
if ax is None:
ax = plt.figure().gca()
infec_uk = sum_la(sim)
total_uk = infec_uk.mean(axis=1)
t = t0 + np.arange(infec_uk.shape[0])
colours = plt.cm.viridis(np.linspace(0., 1., infec_uk.shape[1]))
for i in range(infec_uk.shape[1]):
ax.plot(t, infec_uk[:, i], 'r-', alpha=0.4, color=colours[i], label=labels[i])
ax.plot(t, total_uk, '-', color='black', label='Mean')
return ax
def plot_by_la(sim, labels, t0=np.datetime64('2020-02-20'), ax=None):
if ax is None:
ax = plt.figure().gca()
infec_uk = sum_age_groups(sim)
total_uk = infec_uk.mean(axis=1)
t = t0 + np.arange(infec_uk.shape[0])
colours = plt.cm.viridis(np.linspace(0., 1., infec_uk.shape[1]))
for i in range(infec_uk.shape[1]):
ax.plot(t, infec_uk[:, i], 'r-', alpha=0.4, color=colours[i], label=labels[i])
ax.plot(t, total_uk, '-', color='black', label='Mean')
return ax
def draw_figs(sim, N):
# Attack rate
N = N.reshape([152, 17]).sum(axis=0)
fs = final_size(sim)
attack_rate = fs / N
print("Attack rate:", attack_rate)
print("Overall attack rate: ", np.sum(fs) / np.sum(N))
# Total UK epidemic curve
plot_total_curve(sim)
plt.xticks(rotation=45, horizontalalignment="right")
plt.savefig('total_uk_curve.pdf')
plt.show()
# TotalUK epidemic curve by age-group
fig, ax = plt.subplots(1, 2, figsize=[24, 12])
plot_by_la(sim, la_names, ax=ax[0])
plot_by_age(sim, age_groups, ax=ax[1])
ax[1].legend()
plt.xticks(rotation=45, horizontalalignment="right")
fig.autofmt_xdate()
plt.savefig('la_age_infec_curves.pdf')
plt.show()
# Plot attack rate
plt.figure(figsize=[4, 2])
plt.plot(age_groups, attack_rate, 'o-')
plt.xticks(rotation=90)
plt.title('Age-specific attack rate')
plt.savefig('age_attack_rate.pdf')
plt.show()
def doubling_time(t, sim, t1, t2):
t1 = np.where(t == np.datetime64(t1))[0]
t2 = np.where(t == np.datetime64(t2))[0]
delta = t2 - t1
r = sum_total_removals(sim)
q1 = r[t1]
q2 = r[t2]
return delta * np.log(2) / np.log(q2 / q1)
def plot_age_attack_rate(ax, sim, N, label):
Ns = N.reshape([152, 17]).sum(axis=0)
fs = final_size(sim.numpy())
attack_rate = fs / Ns
ax.plot(age_groups, attack_rate, 'o-', label=label)
if __name__ == '__main__':
parser = optparse.OptionParser()
parser.add_option("--config", "-c", dest="config", default="ode_config.yaml",
help="configuration file")
options, args = parser.parse_args()
with open(options.config, 'r') as ymlfile:
config = yaml.load(ymlfile)
param = sanitise_parameter(config['parameter'])
settings = sanitise_settings(config['settings'])
parser = optparse.OptionParser()
parser.add_option("--config", "-c", dest="config", default="ode_config.yaml",
help="configuration file")
options, args = parser.parse_args()
with open(options.config, 'r') as ymlfile:
config = yaml.load(ymlfile)
param = sanitise_parameter(config['parameter'])
settings = sanitise_settings(config['settings'])
M_tt, age_groups = load_age_mixing(config['data']['age_mixing_matrix_term'])
M_hh, _ = load_age_mixing(config['data']['age_mixing_matrix_hol'])
C, la_names = load_mobility_matrix(config['data']['mobility_matrix'])
np.fill_diagonal(C, 0.)
W = load_commute_volume(config['data']['commute_volume'], settings['inference_period'])['percent']
N, n_names = load_population(config['data']['population_size'])
M_tt = M_tt.astype(DTYPE)
M_hh = M_hh.astype(DTYPE)
W = W.to_numpy().astype(DTYPE)
C = C.astype(DTYPE)
N = N.astype(DTYPE)