Commit 90fb9795 authored by Chris Jewell

Updated ODE model with overall number of starting individuals parameter, and daily confirmed cases up to 27th March (PHE, whole of England).
parent 43bf974a
......@@ -9,6 +9,8 @@ def chain_binomial_propagate(h, time_step):
"""Propagates the state of a population according to discrete time dynamics.
:param h: a hazard rate function returning the non-row-normalised Markov transition rate matrix
This function should return a tensor of dimension [ns, ns, nc] where ns is the number of
states, and nc is the number of strata within the population.
:param time_step: the time step
:returns: a function that propagates `state[t]` -> `state[t+time_step]`
"""
......@@ -45,6 +47,8 @@ def chain_binomial_propagate(h, time_step):
def chain_binomial_simulate(hazard_fn, state, start, end, time_step):
"""Simulates from a discrete time Markov state transition model using multinomial sampling
across rows of the """
propagate = chain_binomial_propagate(hazard_fn, time_step)
times = tf.range(start, end, time_step)
......
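A hedged usage sketch for chain_binomial_simulate, reusing the sir_hazard sketch above; the initial-state shape, dtype, and the structure of the returned value are assumptions.

init_state = tf.constant([[999., 999.],   # S in two strata
                          [  1.,   1.],   # I
                          [  0.,   0.]],  # R
                         dtype=tf.float64)
sim = chain_binomial_simulate(sir_hazard, init_state, start=0., end=50., time_step=1.)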
......@@ -124,7 +124,8 @@ class CovidUKODE(CovidUK):
def h_fn(t, state):
S, E, I, R = tf.unstack(state, axis=-1)
state_ = tf.transpose(state)
S, E, I, R = tf.unstack(state_, axis=0)
# Integrator may produce time values outside the range desired, so
# we clip, implicitly assuming the outside dates have the same
# holiday status as their nearest neighbors in the desired range.
......@@ -135,21 +136,22 @@ class CovidUKODE(CovidUK):
infec_rate = param['beta1'] * (
tf.gather(self.M.matvec(I), m_switch) +
param['beta2'] * self.Kbar * commute_volume * self.C.matvec(I / self.N_sum))
infec_rate = infec_rate / self.N
infec_rate = S * infec_rate / self.N
SE = infec_rate
EI = param['nu']
IR = param['gamma']
dS = -infec_rate
dE = infec_rate - param['nu'] * E
dI = param['nu'] * E - param['gamma'] * I
dR = param['gamma'] * I
p = 1 - tf.exp(tf.stack([SE, EI, IR], axis=-1))
return p
df = tf.stack([dS, dE, dI, dR], axis=-1)
return df
return h_fn
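Read as equations, the new h_fn appears to implement the following metapopulation SEIR system (notation assumed: i indexes metapopulations, w(t) is the clipped commute volume, M and C the mixing and commuting matrices):

\begin{aligned}
\lambda_i(t) &= \frac{\beta_1}{N_i}\Big[(M I)_i + \beta_2\,\bar{K}\,w(t)\,\big(C\,I/N_\Sigma\big)_i\Big],\\
dS_i/dt &= -\lambda_i(t)\,S_i,\\
dE_i/dt &= \lambda_i(t)\,S_i - \nu E_i,\\
dI_i/dt &= \nu E_i - \gamma I_i,\\
dR_i/dt &= \gamma I_i.
\end{aligned}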
def simulate(self, param, state_init, solver_state=None):
h = self.make_h(param)
t = np.arange(self.times.shape[0])
results = self.solver.solve(ode_fn=h, initial_time=t[0], initial_state=state_init,
results = self.solver.solve(ode_fn=h, initial_time=t[0], initial_state=state_init * param['I0'],
solution_times=t, previous_solver_internal_state=solver_state)
return results.times, results.states, results.solver_internal_state
......@@ -169,15 +171,17 @@ class CovidUKODE(CovidUK):
return tf.squeeze(R0), i
def covid19uk_logp(y, sim, phi):
def covid19uk_logp(y, sim, phi, r):
# Sum daily increments in removed
r_incr = sim[1:, :, 3] - sim[:-1, :, 3]
r_incr = tf.reduce_sum(r_incr, axis=-1)
y_ = tfp.distributions.Poisson(rate=phi*r_incr)
# Poisson(\lambda) = \lim_{r \rightarrow \infty} NB(r, \frac{\lambda}{r + \lambda})
#y_ = tfp.distributions.Poisson(rate=phi*r_incr)
lambda_ = r_incr * phi
y_ = tfp.distributions.NegativeBinomial(r, probs=lambda_/(r+lambda_))
return y_.log_prob(y)
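A quick numerical check of the limiting relationship noted in the comment above; the values of lambda and r are arbitrary illustrations, not model settings.

import numpy as np
import tensorflow_probability as tfp
tfd = tfp.distributions

lam = 50.0
y = np.arange(101, dtype=np.float32)
pois_lp = tfd.Poisson(rate=lam).log_prob(y)
for r in [10.0, 100.0, 10000.0]:
    nb_lp = tfd.NegativeBinomial(total_count=r, probs=lam / (r + lam)).log_prob(y)
    # Maximum absolute difference from the Poisson log-pmf shrinks as r grows
    print(r, np.abs(nb_lp.numpy() - pois_lp.numpy()).max())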
class CovidUKStochastic(CovidUK):
def __init__(self, *args, **kwargs):
......
......@@ -9,8 +9,7 @@ tfs = tfp.stats
def plot_prediction(prediction_period, sims, case_reports):
# Sum over country
sims = tf.reduce_sum(sims, axis=3)
sims = tf.reduce_sum(sims, axis=-2) # Sum over all meta-populations
quantiles = [2.5, 50, 97.5]
......@@ -29,8 +28,9 @@ def plot_prediction(prediction_period, sims, case_reports):
rem_line = plt.plot(dates, removed[1, :], '-', color='blue', label="Removed")
ro_line = plt.plot(dates, removed_observed[1, :], '-', color='orange', label='Predicted detections')
data_range = [case_reports['DateVal'].min(), case_reports['DateVal'].max()]
data_dates = np.linspace(data_range[0], data_range[1], np.timedelta64(1, 'D'))
data_range = [case_reports['DateVal'].to_numpy().min(), case_reports['DateVal'].to_numpy().max()]
one_day = np.timedelta64(1, 'D')
data_dates = np.arange(data_range[0], data_range[1]+one_day, one_day)
marks = plt.plot(data_dates, case_reports['CumCases'].to_numpy(), '+', label='Observed cases')
plt.legend([ti_line[0], rem_line[0], ro_line[0], filler, marks[0]],
["Infected", "Removed", "Predicted detections", "95% credible interval", "Observed counts"])
......@@ -44,7 +44,7 @@ def plot_prediction(prediction_period, sims, case_reports):
def plot_case_incidence(dates, sims):
# Number of new cases per day
new_cases = sims[:, :, 3, :].sum(axis=2)
new_cases = sims[:, :, :, 3].sum(axis=2)
new_cases = new_cases[:, 1:] - new_cases[:, :-1]
new_cases = tfs.percentile(new_cases, q=[2.5, 50, 97.5], axis=0)/10000.
......
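A minimal sketch (dates assumed) of the date-axis construction used in the fix above: np.arange accepts a timedelta64 step, whereas np.linspace's third argument is a sample count, which is why the original line was broken.

import numpy as np

start = np.datetime64('2020-01-31')
end = np.datetime64('2020-03-29')
one_day = np.timedelta64(1, 'D')
dates = np.arange(start, end + one_day, one_day)  # daily grid, inclusive of `end`
print(dates[0], dates[-1], len(dates))            # 2020-01-31 2020-03-29 59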
......@@ -10,8 +10,7 @@ tfs = tfp.stats
def sanitise_parameter(par_dict):
"""Sanitises a dictionary of parameters"""
par = ['omega', 'beta1', 'beta2', 'nu', 'gamma']
d = {key: np.float64(par_dict[key]) for key in par}
d = {key: np.float64(val) for key, val in par_dict.items()}
return d
......
......@@ -193,10 +193,10 @@ if __name__ == '__main__':
print(f'Run 1 Complete in {end - start} seconds')
start = time.perf_counter()
for i in range(10):
for i in range(1):
t, sim = model.simulate(param, state_init)
end = time.perf_counter()
print(f'Run 2 Complete in {(end - start)/10.} seconds')
print(f'Run 2 Complete in {(end - start)/1.} seconds')
# Plotting functions
fig_attack = plt.figure()
......
DateVal,CMODateCount,CumCases,DailyDeaths,CumDeaths
2020-01-31,2.00,2.00,,
2020-02-01,0.00,2.00,,
2020-02-02,0.00,2.00,,
2020-02-03,0.00,2.00,,
2020-02-04,0.00,2.00,,
2020-02-05,0.00,2.00,,
2020-02-06,1.00,3.00,,
2020-02-07,0.00,3.00,,
2020-02-08,0.00,3.00,,
2020-02-09,1.00,4.00,,
2020-02-10,4.00,8.00,,
2020-02-11,0.00,8.00,,
2020-02-12,0.00,8.00,,
2020-02-13,1.00,9.00,,
2020-02-14,0.00,9.00,,
2020-02-15,0.00,9.00,,
2020-02-16,0.00,9.00,,
2020-02-17,0.00,9.00,,
2020-02-18,0.00,9.00,,
2020-02-19,0.00,9.00,,
2020-02-20,0.00,9.00,,
2020-02-21,0.00,9.00,,
2020-02-22,0.00,9.00,,
2020-02-23,0.00,9.00,,
2020-02-24,4.00,13.00,,
2020-02-25,0.00,13.00,,
2020-02-26,0.00,13.00,,
2020-02-27,0.00,13.00,,
2020-02-28,6.00,19.00,,
2020-02-29,4.00,23.00,,
2020-03-01,12.00,35.00,,
2020-03-02,5.00,40.00,,
2020-03-03,11.00,51.00,,
2020-03-04,34.00,85.00,,
2020-03-05,29.00,114.00,,
2020-03-06,46.00,160.00,,
2020-03-07,46.00,206.00,,
2020-03-08,65.00,271.00,,
2020-03-09,50.00,321.00,,
2020-03-10,52.00,373.00,,6.00
2020-03-11,83.00,456.00,,
2020-03-12,139.00,590.00,,8.00
2020-03-13,207.00,797.00,,
2020-03-14,264.00,1061.00,,21.00
2020-03-15,330.00,1391.00,14.00,35.00
2020-03-16,152.00,1543.00,20.00,55.00
2020-03-17,407.00,1950.00,16.00,71.00
2020-03-18,676.00,2626.00,32.00,103.00
2020-03-19,643.00,3269.00,41.00,144.00
2020-03-20,714.00,3983.00,33.00,177.00
2020-03-21,1035.00,5018.00,56.00,233.00
2020-03-22,665.00,5683.00,48.00,281.00
2020-03-23,967.00,6650.00,54.00,335.00
2020-03-24,1427.00,8077.00,87.00,422.00
2020-03-25,1452.00,9529.00,41.00,463.00
2020-03-26,2129.00,11658.00,115.00,578.00
2020-03-27,2885.00,14543.00,181.00,759.00
2020-03-28,2546.00,17089.00,260.00,"1,019.00"
2020-03-29,2433.00,19522.00,209.00,"1,228.00"
date,percent
11/01/2020,1.05
12/01/2020,1.07
13/01/2020,1.05
14/01/2020,1.05
15/01/2020,1.05
16/01/2020,1.06
17/01/2020,1.05
18/01/2020,1.06
19/01/2020,1.05
20/01/2020,1.06
21/01/2020,1.06
22/01/2020,1.07
23/01/2020,1.07
24/01/2020,1.07
25/01/2020,1.06
26/01/2020,1.06
27/01/2020,1.06
28/01/2020,1.06
29/01/2020,1.05
30/01/2020,1.03
31/01/2020,1.04
01/02/2020,1.05
02/02/2020,1.04
03/02/2020,1.02
04/02/2020,1.01
05/02/2020,1.02
06/02/2020,1.04
07/02/2020,1.03
08/02/2020,1.02
09/02/2020,1.01
10/02/2020,1.02
11/02/2020,1.02
12/02/2020,1.02
13/02/2020,1.01
14/02/2020,1.03
15/02/2020,1.01
16/02/2020,1.02
17/02/2020,1.02
18/02/2020,1.01
19/02/2020,1.02
20/02/2020,1.00
21/02/2020,0.98
22/02/2020,0.99
23/02/2020,1.00
24/02/2020,1.01
25/02/2020,1.01
26/02/2020,0.99
27/02/2020,0.98
28/02/2020,0.97
29/02/2020,0.96
01/03/2020,0.96
02/03/2020,0.94
03/03/2020,0.93
04/03/2020,0.93
05/03/2020,0.94
06/03/2020,0.95
07/03/2020,0.95
08/03/2020,0.94
09/03/2020,0.94
10/03/2020,0.91
11/03/2020,0.89
12/03/2020,0.85
13/03/2020,0.80
14/03/2020,0.77
15/03/2020,0.73
16/03/2020,0.64
17/03/2020,0.52
18/03/2020,0.42
19/03/2020,0.32
20/03/2020,0.24
21/03/2020,0.18
22/03/2020,0.14
23/03/2020,0.06
24/03/2020,0.02
25/03/2020,-0.03
26/03/2020,-0.07
27/03/2020,-0.09
......@@ -76,7 +76,7 @@ if __name__ == '__main__':
N=N,
date_range=[date_range[0]-np.timedelta64(1,'D'), date_range[1]],
holidays=settings['holiday'],
t_step=int(settings['time_step']))
time_step=int(settings['time_step']))
seeding = seed_areas(N, n_names) # Seed 40-44 age group, 30 seeds by popn size
state_init = simulator.create_initial_state(init_matrix=seeding)
......@@ -85,11 +85,15 @@ if __name__ == '__main__':
p = param
p['beta1'] = par[0]
p['gamma'] = par[1]
beta_logp = tfd.Gamma(concentration=tf.constant(1., tf.float64), rate=tf.constant(1., tf.float64)).log_prob(p['beta1'])
gamma_logp = tfd.Gamma(concentration=tf.constant(100., tf.float64), rate=tf.constant(400., tf.float64)).log_prob(p['gamma'])
p['I0'] = par[2]
p['r'] = par[3]
beta_logp = tfd.Gamma(concentration=tf.constant(1., dtype=DTYPE), rate=tf.constant(1., dtype=DTYPE)).log_prob(p['beta1'])
gamma_logp = tfd.Gamma(concentration=tf.constant(100., dtype=DTYPE), rate=tf.constant(400., dtype=DTYPE)).log_prob(p['gamma'])
I0_logp = tfd.Gamma(concentration=tf.constant(1.5, dtype=DTYPE), rate=tf.constant(0.05, dtype=DTYPE)).log_prob(p['I0'])
r_logp = tfd.Gamma(concentration=tf.constant(0.1, dtype=DTYPE), rate=tf.constant(0.1, dtype=DTYPE)).log_prob(p['r'])
t, sim, solve = simulator.simulate(p, state_init)
y_logp = covid19uk_logp(y_incr, sim, 0.1)
logp = beta_logp + gamma_logp + tf.reduce_sum(y_logp)
y_logp = covid19uk_logp(y_incr, sim, 0.1, p['r'])
logp = beta_logp + gamma_logp + I0_logp + r_logp + tf.reduce_sum(y_logp)
return logp
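Up to an additive constant, the log posterior evaluated by logp appears to be the following (Gamma priors written as Ga(concentration, rate); \Delta R_t is the daily increment in the removed compartment summed over metapopulations, and \phi = 0.1 the assumed reporting fraction):

\begin{aligned}
\log \pi(\beta_1, \gamma, I_0, r \mid y)
  ={}& \log \mathrm{Ga}(\beta_1; 1, 1) + \log \mathrm{Ga}(\gamma; 100, 400) \\
  &+ \log \mathrm{Ga}(I_0; 1.5, 0.05) + \log \mathrm{Ga}(r; 0.1, 0.1) \\
  &+ \sum_t \log \mathrm{NB}\!\left(y_t;\; r,\; \frac{\phi\,\Delta R_t}{r + \phi\,\Delta R_t}\right) + \text{const}.
\end{aligned}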
def trace_fn(_, pkr):
......@@ -100,7 +104,7 @@ if __name__ == '__main__':
unconstraining_bijector = [tfb.Exp()]
initial_mcmc_state = np.array([0.05, 0.25], dtype=np.float64)
initial_mcmc_state = np.array([0.05, 0.25, 1.0, 50.], dtype=np.float64)  # beta1, gamma, I0, r
print("Initial log likelihood:", logp(initial_mcmc_state))
@tf.function(autograph=False, experimental_compile=True)
......@@ -119,7 +123,7 @@ if __name__ == '__main__':
joint_posterior = tf.zeros([0] + list(initial_mcmc_state.shape), dtype=DTYPE)
scale = np.diag([0.1, 0.1])
scale = np.diag([0.1, 0.1, 0.1, 1.])
overall_start = time.perf_counter()
num_covariance_estimation_iterations = 50
......@@ -151,9 +155,10 @@ if __name__ == '__main__':
print("Acceptance: ", np.mean(results.numpy()))
print(tfp.stats.covariance(tf.math.log(joint_posterior)))
fig, ax = plt.subplots(1, 3)
ax[0].plot(joint_posterior[:, 0])
ax[1].plot(joint_posterior[:, 1])
fig, ax = plt.subplots(1, joint_posterior.shape[1])
for i in range(joint_posterior.shape[1]):
ax[i].plot(joint_posterior[:, i])
plt.show()
print(f"Posterior mean: {np.mean(joint_posterior, axis=0)}")
......
......@@ -5,20 +5,22 @@ data:
age_mixing_matrix_hol: data/polymod_no_school_df.rds
mobility_matrix: data/movement.rds
population_size: data/pop.rds
commute_volume: data/commute_vol_2020-03-20.csv
reported_cases: data/DailyConfirmedCases_2020-03-20.csv
commute_volume: data/commute_vol_2020-03-27.csv
reported_cases: data/DailyConfirmedCases_2020-03-29.csv
parameter:
beta1: 0.1 # R0 2.4
beta2: 0.33 # Contact with commuters 1/3rd of the time
omega: 1.0
nu: 0.25
gamma: 0.25
beta1: 0.1 # R0 2.4
beta2: 0.33 # Contact with commuters 1/3rd of the time
omega: 1.0 # Non-linearity parameter for commuting volume
nu: 0.25 # E -> I transition rate
gamma: 0.25 # I -> R transition rate
I0: 0.0 # number of individuals at start of epidemic
r: 30.0 # Negative binomial overdispersion parameter
settings:
inference_period:
- 2020-02-19
- 2020-04-01
- 2020-03-30
holiday:
- 2020-03-23
- 2020-10-01
......
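As a quick sanity check on the transition rates above (assuming exponentially distributed sojourn times), the implied mean latent and infectious periods are:

1/\nu = 1/\gamma = 1/0.25\ \mathrm{day}^{-1} = 4\ \mathrm{days}.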
......@@ -71,18 +71,20 @@ if __name__ == '__main__':
date_range=[settings['prediction_period'][0],
settings['prediction_period'][1]],
holidays=settings['holiday'],
t_step=1)
time_step=1)
seeding = seed_areas(N, n_names) # Seed 40-44 age group, 30 seeds by popn size
state_init = simulator.create_initial_state(init_matrix=seeding)
@tf.function
def prediction(beta, gamma):
def prediction(beta, gamma, I0, r_):
sims = tf.TensorArray(tf.float64, size=beta.shape[0])
R0 = tf.TensorArray(tf.float64, size=beta.shape[0])
for i in tf.range(beta.shape[0]):
p = param
p['beta1'] = beta[i]
p['gamma'] = gamma[i]
p['I0'] = I0[i]
p['r'] = r_[i]
t, sim, solver_results = simulator.simulate(p, state_init)
r = simulator.eval_R0(p)
R0 = R0.write(i, r[0])
......@@ -91,8 +93,8 @@ if __name__ == '__main__':
draws = pi_beta.numpy()[np.arange(5000, pi_beta.shape[0], 30), :]
with tf.device('/CPU:0'):
sims, R0 = prediction(draws[:, 0], draws[:, 1])
sims = tf.stack(sims) # shape=[n_sims, n_times, n_states, n_metapops]
sims, R0 = prediction(draws[:, 0], draws[:, 1], draws[:, 2], draws[:, 3])
sims = tf.stack(sims) # shape=[n_sims, n_times, n_metapops, n_states]
save_sims(simulator.times, sims, la_names, age_groups, 'pred_2020-03-23.h5')
dub_time = [doubling_time(simulator.times, sim, '2020-03-01', '2020-04-01') for sim in sims.numpy()]
......