Commit 511d23be authored by Chris Jewell's avatar Chris Jewell

Code tidy

parent 4cddaded
"""Functions for infection rates"""
from warnings import warn
import tensorflow as tf
import tensorflow_probability as tfp
tode = tfp.math.ode
import numpy as np
from covid.impl.chainbinom_simulate import chain_binomial_simulate
tode = tfp.math.ode
tla = tf.linalg
def power_iteration(A, tol=1e-3):
b_k = tf.random.normal([A.shape[1], 1], dtype=tf.float64)
epsilon = tf.constant(1., dtype=tf.float64)
......@@ -21,64 +22,13 @@ def power_iteration(A, tol=1e-3):
i += 1
return b_k, i
#@tf.function
def rayleigh_quotient(A, b):
b = tf.reshape(b, [b.shape[0], 1])
numerator = tf.matmul(tf.transpose(b), tf.matmul(A, b))
denominator = tf.matmul(tf.transpose(b), b)
return numerator / denominator
class CovidUK:
def __init__(self, K, T, W):
self.K = K
self.T = T
self.W = W
self.stoichiometry = [[-1, 1, 0, 0],
[0, -1, 1, 0],
[0, 0, -1, 1]]
def h(self, state):
state = tf.unstack(state, axis=0)
S, E, I, R = state
hazard_rates = tf.stack([
self.param['beta1'] * tf.dot(self.T, tf.dot(self.K, I))/self.K.shape[0],
self.param['nu'],
self.param['gamma']
])
return hazard_rates
#@tf.function
def sample(self, initial_state, time_lims, param):
self.param = param
return chain_binomial_simulate(self.h, initial_state, time_lims[0],
time_lims[1], 1., self.stoichiometry)
class Homogeneous:
def __init__(self):
self.stoichiometry = tf.constant([[-1, 1, 0, 0],
[0, -1, 1, 0],
[0, 0, -1, 1]], dtype=tf.float32)
def h(self, state):
state = tf.unstack(state, axis=0)
S, E, I, R = state
hazard_rates = tf.stack([
self.param['beta'] * I / tf.reduce_sum(state),
self.param['nu'] * tf.ones_like(I),
self.param['gamma'] * tf.ones_like(I)
])
return hazard_rates
@tf.function
def sample(self, initial_state, time_lims, param):
self.param = param
return chain_binomial_simulate(self.h, initial_state, time_lims[0],
time_lims[1], 1., self.stoichiometry)
def dense_to_block_diagonal(A, n_blocks):
A_dense = tf.linalg.LinearOperatorFullMatrix(A)
......@@ -100,15 +50,15 @@ class CovidUKODE: # TODO: add background case importation rate to the UK, e.g.
"""
self.n_ages = M_tt.shape[0]
self.n_lads = C.shape[0]
self.M_tt = tf.convert_to_tensor(M_tt, dtype=tf.float64)
self.M_hh = tf.convert_to_tensor(M_hh, dtype=tf.float64)
self.Kbar = tf.reduce_mean(tf.cast(M_tt, tf.float64))
self.M = tf.tuple([dense_to_block_diagonal(tf.cast(M_tt, tf.float64), self.n_lads),
dense_to_block_diagonal(tf.cast(M_hh, tf.float64), self.n_lads)])
self.Kbar = tf.reduce_mean(self.M_tt)
C = tf.cast(C, tf.float64)
self.C = tf.linalg.LinearOperatorFullMatrix(C + tf.transpose(C))
shp = tf.linalg.LinearOperatorFullMatrix(np.ones_like(M_tt, dtype=np.float64))
self.C = tf.linalg.LinearOperatorKronecker([self.C, shp])
self.C = tla.LinearOperatorFullMatrix(C + tf.transpose(C))
shp = tla.LinearOperatorFullMatrix(np.ones_like(M_tt, dtype=np.float64))
self.C = tla.LinearOperatorKronecker([self.C, shp])
self.N = tf.constant(N, dtype=tf.float64)
N_matrix = tf.reshape(self.N, [self.n_lads, self.n_ages])
......@@ -117,9 +67,10 @@ class CovidUKODE: # TODO: add background case importation rate to the UK, e.g.
self.N_sum = tf.reshape(N_sum, [-1])
self.times = np.arange(start, end, np.timedelta64(t_step, 'D'))
m_select = (np.less_equal(holidays[0], self.times) & np.less(self.times, holidays[1])).astype(np.int64)
self.m_select = tf.constant(m_select, dtype=tf.int64)
self.bg_select = tf.constant(np.less(self.times, bg_max_t), dtype=tf.int64)
self.school_hols = [tf.constant((holidays[0] - start) // np.timedelta64(1, 'D'), dtype=tf.float64),
tf.constant((holidays[1] - start) // np.timedelta64(1, 'D'), dtype=tf.float64)]
self.bg_max_t = tf.convert_to_tensor(bg_max_t, dtype=tf.float64)
self.solver = tode.DormandPrince()
def make_h(self, param):
......@@ -127,14 +78,16 @@ class CovidUKODE: # TODO: add background case importation rate to the UK, e.g.
def h_fn(t, state):
state = tf.unstack(state, axis=0)
S, E, I, R = state
t = tf.clip_by_value(tf.cast(t, tf.int64), 0, self.m_select.shape[0]-1)
m_switch = tf.gather(self.m_select, t)
epsilon = param['epsilon'] * tf.cast(tf.gather(self.bg_select, t), tf.float64)
if m_switch == 0:
infec_rate = param['beta1'] * tf.linalg.matvec(self.M[0], I)
else:
infec_rate = param['beta1'] * tf.linalg.matvec(self.M[1], I)
infec_rate += param['beta1'] * param['beta2'] * self.Kbar * tf.linalg.matvec(self.C, I / self.N_sum)
M = tf.where(tf.less_equal(self.school_hols[0], t) & tf.less(t, self.school_hols[1]),
self.M_hh, self.M_tt)
M = dense_to_block_diagonal(M, self.n_lads)
epsilon = tf.where(t < self.bg_max_t, param['epsilon'], tf.constant(0., dtype=tf.float64))
infec_rate = param['beta1'] * tla.matvec(M, I)
infec_rate += param['beta1'] * param['beta2'] * self.Kbar * tla.matvec(self.C, I / self.N_sum)
infec_rate = S / self.N * (infec_rate + epsilon)
dS = -infec_rate
......@@ -161,7 +114,6 @@ class CovidUKODE: # TODO: add background case importation rate to the UK, e.g.
R = np.zeros(self.N.shape, dtype=np.float64)
return np.stack([S, E, I, R])
@tf.function
def simulate(self, param, state_init, solver_state=None):
h = self.make_h(param)
t = np.arange(self.times.shape[0])
......@@ -170,7 +122,7 @@ class CovidUKODE: # TODO: add background case importation rate to the UK, e.g.
return results.times, results.states, results.solver_internal_state
def ngm(self, param):
infec_rate = param['beta1'] * self.M[0].to_dense()
infec_rate = param['beta1'] * dense_to_block_diagonal(self.M_tt, self.n_lads).to_dense()
infec_rate += param['beta1'] * param['beta2'] * self.Kbar * self.C.to_dense() / self.N_sum[None, :]
ngm = infec_rate / param['gamma']
return ngm
......
......@@ -97,11 +97,11 @@ def plot_total_curve(sim):
plt.legend()
def plot_infec_curve(ax, sim, label, color):
def plot_infec_curve(ax, sim, label):
infec_uk = sum_la(sim)
infec_uk = infec_uk.sum(axis=1)
times = np.datetime64('2020-02-20') + np.arange(infec_uk.shape[0])
ax.plot(times, infec_uk, '-', label=label, color=color)
ax.plot(times, infec_uk, '-', label=label)
def plot_by_age(sim, labels, t0=np.datetime64('2020-02-20'), ax=None):
......@@ -173,11 +173,11 @@ def doubling_time(t, sim, t1, t2):
return delta * np.log(2) / np.log(q2 / q1)
def plot_age_attack_rate(ax, sim, N, label, color):
def plot_age_attack_rate(ax, sim, N, label):
Ns = N.reshape([152, 17]).sum(axis=0)
fs = final_size(sim.numpy())
attack_rate = fs / Ns
ax.plot(age_groups, attack_rate, 'o-', color=color, label=label)
ax.plot(age_groups, attack_rate, 'o-', label=label)
if __name__ == '__main__':
......@@ -200,47 +200,31 @@ if __name__ == '__main__':
param = sanitise_parameter(config['parameter'])
settings = sanitise_settings(config['settings'])
model = CovidUKODE(K_tt, K_hh, T, N, settings['start'], settings['end'], settings['holiday'],
settings['bg_max_time'], 1)
# Effect of cocooning on elderly (70+)
def cocooning_model(cocooning_ratio=1.):
seeding = seed_areas(N, n_names) # Seed 40-44 age group, 30 seeds by popn size
state_init = model.create_initial_state(init_matrix=seeding)
K1_tt = K_tt.copy()
K1_hh = K_hh.copy()
print('R0_term=', model.eval_R0(param))
K1_tt[14:, :] *= cocooning_ratio
K1_tt[:, 14:] *= cocooning_ratio
K1_hh[14:, :] *= cocooning_ratio
K1_hh[:, 14:] *= cocooning_ratio
start = time.perf_counter()
t, sim, _ = model.simulate(param, state_init)
end = time.perf_counter()
print(f'Complete in {end - start} seconds')
model = CovidUKODE(K1_tt, K_hh, T, N, settings['start'], settings['end'], settings['holiday'],
settings['bg_max_time'], 1)
dates = settings['start'] + t.numpy().astype(np.timedelta64)
dt = doubling_time(dates, sim.numpy(), '2020-03-01', '2020-03-31')
print(f"Doubling time: {dt}")
seeding = seed_areas(N, n_names) # Seed 40-44 age group, 30 seeds by popn size
state_init = model.create_initial_state(init_matrix=seeding)
print('R_term=', model.eval_R0(param))
#print('R_holiday=', model_holiday.eval_R0(param))
start = time.perf_counter()
t, sim, _ = model.simulate(param, state_init)
end = time.perf_counter()
print(f'Complete in {end - start} seconds')
dates = settings['start'] + t.numpy().astype(np.timedelta64)
dt = doubling_time(dates, sim.numpy(), '2020-03-01', '2020-04-01')
print(f"Doubling time: {dt}")
return t, sim
fig_attack = plt.figure()
fig_uk = plt.figure()
cocoon_ratios = [1.]
for i, r in enumerate(cocoon_ratios):
print(f"Simulation, r={r}", flush=True)
t, sim = cocooning_model(r)
plot_age_attack_rate(fig_attack.gca(), sim, N, f"{1 - r}", f"C{i}")
# fig_attack.gca().legend(title="Contact ratio")
plot_infec_curve(fig_uk.gca(), sim.numpy(), f"{1 - r}", f"C{i}")
# fig_uk.gca().legend(title="Contact ratio")
plot_age_attack_rate(fig_attack.gca(), sim, N, "Attack Rate")
fig_attack.suptitle("Attack Rate")
plot_infec_curve(fig_uk.gca(), sim.numpy(), "Infections")
fig_uk.suptitle("UK Infections")
fig_attack.autofmt_xdate()
fig_uk.autofmt_xdate()
......@@ -248,5 +232,3 @@ if __name__ == '__main__':
fig_uk.gca().grid(True)
plt.show()
# if 'simulation' in config['output']:
# write_hdf5(config['output']['simulation'], param, t, sim)
......@@ -48,7 +48,7 @@ def random_walk_mvnorm_fn(covariance, name=None):
def _fn(state_parts, seed):
with tf.name_scope(name or 'random_walk_mvnorm_fn'):
new_state_parts = rv.sample() + state_parts
new_state_parts = [rv.sample() + state_part for state_part in state_parts]
return new_state_parts
return _fn
......@@ -132,10 +132,8 @@ if __name__ == '__main__':
for i in range(200):
cov = tfp.stats.covariance(tf.math.log(joint_posterior)) * 2.38**2 / joint_posterior.shape[1]
print(cov.numpy())
posterior_new, results = sample(50, joint_posterior[-1, :], cov)
posterior_new, results = sample(50, joint_posterior[-1, :].numpy(), cov)
joint_posterior = tf.concat([joint_posterior, posterior_new], axis=0)
#posterior_new, results = sample(2000, init_state=joint_posterior[-1, :], scale=cov)
#joint_posterior = tf.concat([joint_posterior, posterior_new], axis=0)
end = time.perf_counter()
print(f"Simulation complete in {end-start} seconds")
print("Acceptance: ", np.mean(results.numpy()))
......
......@@ -2,8 +2,7 @@
data:
age_mixing_matrix_term: data/polymod_normal_df.rds
#age_mixing_matrix_hol: data/polymod_no_school_df.rds
age_mixing_matrix_hol: data/polymod_weekend_df.rds
age_mixing_matrix_hol: data/polymod_no_school_df.rds
mobility_matrix: data/movement.rds
population_size: data/pop.rds
reported_cases: data/DailyConfirmedCases.csv
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment