Commit 864fcf65 authored by Christopher Suter
Browse files

Improve efficiency of stochastic model.

Big changes:
 1. replace python for loop with tf.while_loop
 2. work with a transposed state tensor shape
   - instead of [4, nlads * nages], use [nlads * nages, 4]
   - this made it pretty easy to eliminate some transposes in
     propagate_fn (there were comments there seemingly contemplating
     this shape arrangement)
   - this feels a little more natural to me, too; in TFP we'd call the 4
     SEIR states components of the "event shape" of the system, and the
     nlads * nages part a "batch shape" (although one could reasonably
     also combine these together into one big matrix "event shape")
   - anyway, this allowed elimination of 3 transpose ops which makes for
     simpler code and avoids some memcpys
   - I also made an effort to update surrounding code to use the same data
     layout, but it seems like mcmc.py and covid_ode.py are broken right
     now anyway, due to other changes made in support of stochastic mode,
     so I couldn't confirm that my changes were sufficient.
 3. switch off XLA (which didn't yield any clear improvement, although
    it also didn't really hurt), and disable autograph (which tries to
    do things like rewrite Python for loops into TF graph code, but tends
    to produce less performant code than manually optimized code like
    what I've done here)
parent c5590098
......@@ -13,38 +13,32 @@ def chain_binomial_propagate(h, time_step):
:returns : a function that propagate `state[t]` -> `state[t+time_step]`
"""
def propagate_fn(t, state):
# State is assumed to be of shape [s, n] where s is the number of states
# and n is the number of population strata.
# TODO: having state as [s, n] means we have to do some funky transposition. It may be better
# to have state.shape = [n, s] which avoids transposition below, but may lead to slower
# rate calculations.
rate_matrix = h(t, state)
rate_matrix = tf.transpose(rate_matrix, perm=[2, 0, 1])
# Set diagonal to be the negative of the sum of other elements in each row
rate_matrix = tf.linalg.set_diag(rate_matrix, -tf.reduce_sum(rate_matrix, axis=2))
rate_matrix = tf.linalg.set_diag(rate_matrix,
-tf.reduce_sum(rate_matrix, axis=-1))
# Calculate Markov transition probability matrix
markov_transition = tf.linalg.expm(rate_matrix*time_step)
# Sample new state
new_state = tfd.Multinomial(total_count=tf.transpose(state),
new_state = tfd.Multinomial(total_count=state,
probs=markov_transition).sample()
new_state = tf.reduce_sum(new_state, axis=1)
return tf.transpose(new_state)
new_state = tf.reduce_sum(new_state, axis=-1)
return new_state
return propagate_fn
@tf.function(autograph=False) # Algorithm runs super slow if uncommented. Weird!
def chain_binomial_simulate(hazard_fn, state, start, end, time_step):
propagate = chain_binomial_propagate(hazard_fn, time_step)
times = tf.range(start, end, time_step)
output = tf.TensorArray(tf.float64, size=times.shape[0])
output = tf.TensorArray(state.dtype, size=times.shape[0])
output = output.write(0, state)
for i in range(1, times.shape[0]):
state = propagate(i, state)
output = output.write(i, state)
sim = output.gather(tf.range(1, times.shape[0]))
return times, sim
cond = lambda i, *_: i < times.shape[0]
def body(i, state, output):
state = propagate(i, state)
output = output.write(i, state)
return i + 1, state, output
_, state, output = tf.while_loop(cond, body, loop_vars=(0, state, output))
return times, output.stack()
......@@ -110,7 +110,7 @@ class CovidUK:
S = self.N - I
E = np.zeros(self.N.shape, dtype=np.float64)
R = np.zeros(self.N.shape, dtype=np.float64)
return np.stack([S, E, I, R])
return np.stack([S, E, I, R], axis=-1)
class CovidUKODE(CovidUK):
......@@ -124,7 +124,7 @@ class CovidUKODE(CovidUK):
def h_fn(t, state):
S, E, I, R = state
S, E, I, R = tf.unstack(state, axis=-1)
# Integrator may produce time values outside the range desired, so
# we clip, implicitly assuming the outside dates have the same
# holiday status as their nearest neighbors in the desired range.
......@@ -141,7 +141,7 @@ class CovidUKODE(CovidUK):
EI = param['nu']
IR = param['gamma']
p = 1 - tf.exp([SE, EI, IR])
p = 1 - tf.exp(tf.stack([SE, EI, IR], axis=-1))
return p
return h_fn
......@@ -171,8 +171,8 @@ class CovidUKODE(CovidUK):
def covid19uk_logp(y, sim, phi):
# Sum daily increments in removed
r_incr = sim[1:, 3, :] - sim[:-1, 3, :]
r_incr = tf.reduce_sum(r_incr, axis=1)
r_incr = sim[1:, :, 3] - sim[:-1, :, 3]
r_incr = tf.reduce_sum(r_incr, axis=-1)
y_ = tfp.distributions.Poisson(rate=phi*r_incr)
return y_.log_prob(y)
......@@ -201,22 +201,30 @@ class CovidUKStochastic(CovidUK):
commute_volume = tf.pow(tf.gather(self.W, t_idx), param['omega'])
infec_rate = param['beta1'] * (
tf.gather(self.M.matvec(state[2]), m_switch) +
param['beta2'] * self.Kbar * commute_volume * self.C.matvec(state[2] / self.N_sum))
tf.gather(self.M.matvec(state[:, 2]), m_switch) +
param['beta2'] * self.Kbar * commute_volume * self.C.matvec(state[:, 2] / self.N_sum))
infec_rate = infec_rate / self.N
ei = tf.broadcast_to([param['nu']], shape=[state.shape[1]])
ir = tf.broadcast_to([param['gamma']], shape=[state.shape[1]])
ei = tf.broadcast_to([param['nu']], shape=[state.shape[0]])
ir = tf.broadcast_to([param['gamma']], shape=[state.shape[0]])
# Scatter rates into a [ns, ns, nc] tensor
rates = [infec_rate, ei, ir]
rates = tf.scatter_nd(updates=rates,
indices=[[0, 1], [1, 2], [2, 3]],
shape=[state.shape[0], state.shape[0], state.shape[1]])
return rates
n = state.shape[0]
b = tf.stack([tf.range(n),
tf.zeros(n, dtype=tf.int32),
tf.ones(n, dtype=tf.int32)], axis=-1)
indices = tf.stack([b, b + [0, 1, 1], b + [0, 2, 2]], axis=-2)
# Un-normalised rate matrix (diag is 0 here)
rate_matrix = tf.scatter_nd(indices=indices,
updates=tf.stack([infec_rate, ei, ir], axis=-1),
shape=[state.shape[0],
state.shape[1],
state.shape[1]])
return rate_matrix
return h
@tf.function(experimental_compile=True)
@tf.function(autograph=False)
def simulate(self, param, state_init):
"""Runs a simulation from the epidemic model
......
......@@ -13,26 +13,26 @@ from covid.util import sanitise_parameter, sanitise_settings, seed_areas
def sum_age_groups(sim):
infec = sim[:, 2, :]
infec = sim[:, :, 2]
infec = infec.reshape([infec.shape[0], 152, 17])
infec_uk = infec.sum(axis=2)
return infec_uk
def sum_la(sim):
infec = sim[:, 2, :]
infec = sim[:, :, 2]
infec = infec.reshape([infec.shape[0], 152, 17])
infec_uk = infec.sum(axis=1)
return infec_uk
def sum_total_removals(sim):
remove = sim[:, 3, :]
remove = sim[:, :, 3]
return remove.sum(axis=1)
def final_size(sim):
remove = sim[:, 3, :]
remove = sim[:, :, 3]
remove = remove.reshape([remove.shape[0], 152, 17])
fs = remove[-1, :, :].sum(axis=0)
return fs
......@@ -52,7 +52,7 @@ def write_hdf5(filename, param, t, sim):
def plot_total_curve(sim):
infec_uk = sum_la(sim)
infec_uk = infec_uk.sum(axis=1)
infec_uk = infec_uk.sum(axis=-1)
removals = sum_total_removals(sim)
times = np.datetime64('2020-02-20') + np.arange(removals.shape[0])
plt.plot(times, infec_uk, 'r-', label='Infected')
......@@ -66,7 +66,7 @@ def plot_total_curve(sim):
def plot_infec_curve(ax, sim, label):
infec_uk = sum_la(sim)
infec_uk = infec_uk.sum(axis=1)
infec_uk = infec_uk.sum(axis=-1)
times = np.datetime64('2020-02-20') + np.arange(infec_uk.shape[0])
ax.plot(times, infec_uk, '-', label=label)
......@@ -75,7 +75,7 @@ def plot_by_age(sim, labels, t0=np.datetime64('2020-02-20'), ax=None):
if ax is None:
ax = plt.figure().gca()
infec_uk = sum_la(sim)
total_uk = infec_uk.mean(axis=1)
total_uk = infec_uk.mean(axis=-1)
t = t0 + np.arange(infec_uk.shape[0])
colours = plt.cm.viridis(np.linspace(0., 1., infec_uk.shape[1]))
for i in range(infec_uk.shape[1]):
......@@ -88,7 +88,7 @@ def plot_by_la(sim, labels, t0=np.datetime64('2020-02-20'), ax=None):
if ax is None:
ax = plt.figure().gca()
infec_uk = sum_age_groups(sim)
total_uk = infec_uk.mean(axis=1)
total_uk = infec_uk.mean(axis=-1)
t = t0 + np.arange(infec_uk.shape[0])
colours = plt.cm.viridis(np.linspace(0., 1., infec_uk.shape[1]))
for i in range(infec_uk.shape[1]):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment