Unverified Commit 79f51e8d authored by Chris Jewell's avatar Chris Jewell Committed by GitHub
Browse files

Merge pull request #1 from csuter/stochastic

Improve efficiency of stochastic model.
parents c5590098 864fcf65
......@@ -13,38 +13,32 @@ def chain_binomial_propagate(h, time_step):
:returns : a function that propagate `state[t]` -> `state[t+time_step]`
"""
def propagate_fn(t, state):
# State is assumed to be of shape [s, n] where s is the number of states
# and n is the number of population strata.
# TODO: having state as [s, n] means we have to do some funky transposition. It may be better
# to have state.shape = [n, s] which avoids transposition below, but may lead to slower
# rate calculations.
rate_matrix = h(t, state)
rate_matrix = tf.transpose(rate_matrix, perm=[2, 0, 1])
# Set diagonal to be the negative of the sum of other elements in each row
rate_matrix = tf.linalg.set_diag(rate_matrix, -tf.reduce_sum(rate_matrix, axis=2))
rate_matrix = tf.linalg.set_diag(rate_matrix,
-tf.reduce_sum(rate_matrix, axis=-1))
# Calculate Markov transition probability matrix
markov_transition = tf.linalg.expm(rate_matrix*time_step)
# Sample new state
new_state = tfd.Multinomial(total_count=tf.transpose(state),
new_state = tfd.Multinomial(total_count=state,
probs=markov_transition).sample()
new_state = tf.reduce_sum(new_state, axis=1)
return tf.transpose(new_state)
new_state = tf.reduce_sum(new_state, axis=-1)
return new_state
return propagate_fn
@tf.function(autograph=False) # Algorithm runs super slow if uncommented. Weird!
def chain_binomial_simulate(hazard_fn, state, start, end, time_step):
propagate = chain_binomial_propagate(hazard_fn, time_step)
times = tf.range(start, end, time_step)
output = tf.TensorArray(tf.float64, size=times.shape[0])
output = tf.TensorArray(state.dtype, size=times.shape[0])
output = output.write(0, state)
for i in range(1, times.shape[0]):
state = propagate(i, state)
output = output.write(i, state)
sim = output.gather(tf.range(1, times.shape[0]))
return times, sim
cond = lambda i, *_: i < times.shape[0]
def body(i, state, output):
state = propagate(i, state)
output = output.write(i, state)
return i + 1, state, output
_, state, output = tf.while_loop(cond, body, loop_vars=(0, state, output))
return times, output.stack()
......@@ -110,7 +110,7 @@ class CovidUK:
S = self.N - I
E = np.zeros(self.N.shape, dtype=np.float64)
R = np.zeros(self.N.shape, dtype=np.float64)
return np.stack([S, E, I, R])
return np.stack([S, E, I, R], axis=-1)
class CovidUKODE(CovidUK):
......@@ -124,7 +124,7 @@ class CovidUKODE(CovidUK):
def h_fn(t, state):
S, E, I, R = state
S, E, I, R = tf.unstack(state, axis=-1)
# Integrator may produce time values outside the range desired, so
# we clip, implicitly assuming the outside dates have the same
# holiday status as their nearest neighbors in the desired range.
......@@ -141,7 +141,7 @@ class CovidUKODE(CovidUK):
EI = param['nu']
IR = param['gamma']
p = 1 - tf.exp([SE, EI, IR])
p = 1 - tf.exp(tf.stack([SE, EI, IR], axis=-1))
return p
return h_fn
......@@ -171,8 +171,8 @@ class CovidUKODE(CovidUK):
def covid19uk_logp(y, sim, phi):
# Sum daily increments in removed
r_incr = sim[1:, 3, :] - sim[:-1, 3, :]
r_incr = tf.reduce_sum(r_incr, axis=1)
r_incr = sim[1:, :, 3] - sim[:-1, :, 3]
r_incr = tf.reduce_sum(r_incr, axis=-1)
y_ = tfp.distributions.Poisson(rate=phi*r_incr)
return y_.log_prob(y)
......@@ -201,22 +201,30 @@ class CovidUKStochastic(CovidUK):
commute_volume = tf.pow(tf.gather(self.W, t_idx), param['omega'])
infec_rate = param['beta1'] * (
tf.gather(self.M.matvec(state[2]), m_switch) +
param['beta2'] * self.Kbar * commute_volume * self.C.matvec(state[2] / self.N_sum))
tf.gather(self.M.matvec(state[:, 2]), m_switch) +
param['beta2'] * self.Kbar * commute_volume * self.C.matvec(state[:, 2] / self.N_sum))
infec_rate = infec_rate / self.N
ei = tf.broadcast_to([param['nu']], shape=[state.shape[1]])
ir = tf.broadcast_to([param['gamma']], shape=[state.shape[1]])
ei = tf.broadcast_to([param['nu']], shape=[state.shape[0]])
ir = tf.broadcast_to([param['gamma']], shape=[state.shape[0]])
# Scatter rates into a [ns, ns, nc] tensor
rates = [infec_rate, ei, ir]
rates = tf.scatter_nd(updates=rates,
indices=[[0, 1], [1, 2], [2, 3]],
shape=[state.shape[0], state.shape[0], state.shape[1]])
return rates
n = state.shape[0]
b = tf.stack([tf.range(n),
tf.zeros(n, dtype=tf.int32),
tf.ones(n, dtype=tf.int32)], axis=-1)
indices = tf.stack([b, b + [0, 1, 1], b + [0, 2, 2]], axis=-2)
# Un-normalised rate matrix (diag is 0 here)
rate_matrix = tf.scatter_nd(indices=indices,
updates=tf.stack([infec_rate, ei, ir], axis=-1),
shape=[state.shape[0],
state.shape[1],
state.shape[1]])
return rate_matrix
return h
@tf.function(experimental_compile=True)
@tf.function(autograph=False)
def simulate(self, param, state_init):
"""Runs a simulation from the epidemic model
......
......@@ -13,26 +13,26 @@ from covid.util import sanitise_parameter, sanitise_settings, seed_areas
def sum_age_groups(sim):
infec = sim[:, 2, :]
infec = sim[:, :, 2]
infec = infec.reshape([infec.shape[0], 152, 17])
infec_uk = infec.sum(axis=2)
return infec_uk
def sum_la(sim):
infec = sim[:, 2, :]
infec = sim[:, :, 2]
infec = infec.reshape([infec.shape[0], 152, 17])
infec_uk = infec.sum(axis=1)
return infec_uk
def sum_total_removals(sim):
remove = sim[:, 3, :]
remove = sim[:, :, 3]
return remove.sum(axis=1)
def final_size(sim):
remove = sim[:, 3, :]
remove = sim[:, :, 3]
remove = remove.reshape([remove.shape[0], 152, 17])
fs = remove[-1, :, :].sum(axis=0)
return fs
......@@ -52,7 +52,7 @@ def write_hdf5(filename, param, t, sim):
def plot_total_curve(sim):
infec_uk = sum_la(sim)
infec_uk = infec_uk.sum(axis=1)
infec_uk = infec_uk.sum(axis=-1)
removals = sum_total_removals(sim)
times = np.datetime64('2020-02-20') + np.arange(removals.shape[0])
plt.plot(times, infec_uk, 'r-', label='Infected')
......@@ -66,7 +66,7 @@ def plot_total_curve(sim):
def plot_infec_curve(ax, sim, label):
infec_uk = sum_la(sim)
infec_uk = infec_uk.sum(axis=1)
infec_uk = infec_uk.sum(axis=-1)
times = np.datetime64('2020-02-20') + np.arange(infec_uk.shape[0])
ax.plot(times, infec_uk, '-', label=label)
......@@ -75,7 +75,7 @@ def plot_by_age(sim, labels, t0=np.datetime64('2020-02-20'), ax=None):
if ax is None:
ax = plt.figure().gca()
infec_uk = sum_la(sim)
total_uk = infec_uk.mean(axis=1)
total_uk = infec_uk.mean(axis=-1)
t = t0 + np.arange(infec_uk.shape[0])
colours = plt.cm.viridis(np.linspace(0., 1., infec_uk.shape[1]))
for i in range(infec_uk.shape[1]):
......@@ -88,7 +88,7 @@ def plot_by_la(sim, labels, t0=np.datetime64('2020-02-20'), ax=None):
if ax is None:
ax = plt.figure().gca()
infec_uk = sum_age_groups(sim)
total_uk = infec_uk.mean(axis=1)
total_uk = infec_uk.mean(axis=-1)
t = t0 + np.arange(infec_uk.shape[0])
colours = plt.cm.viridis(np.linspace(0., 1., infec_uk.shape[1]))
for i in range(infec_uk.shape[1]):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment