Commit 79f51e8d by Chris Jewell Committed by GitHub

### Merge pull request #1 from csuter/stochastic

`Improve efficiency of stochastic model.`
parents c5590098 864fcf65
 ... ... @@ -13,38 +13,32 @@ def chain_binomial_propagate(h, time_step): :returns : a function that propagate `state[t]` -> `state[t+time_step]` """ def propagate_fn(t, state): # State is assumed to be of shape [s, n] where s is the number of states # and n is the number of population strata. # TODO: having state as [s, n] means we have to do some funky transposition. It may be better # to have state.shape = [n, s] which avoids transposition below, but may lead to slower # rate calculations. rate_matrix = h(t, state) rate_matrix = tf.transpose(rate_matrix, perm=[2, 0, 1]) # Set diagonal to be the negative of the sum of other elements in each row rate_matrix = tf.linalg.set_diag(rate_matrix, -tf.reduce_sum(rate_matrix, axis=2)) rate_matrix = tf.linalg.set_diag(rate_matrix, -tf.reduce_sum(rate_matrix, axis=-1)) # Calculate Markov transition probability matrix markov_transition = tf.linalg.expm(rate_matrix*time_step) # Sample new state new_state = tfd.Multinomial(total_count=tf.transpose(state), new_state = tfd.Multinomial(total_count=state, probs=markov_transition).sample() new_state = tf.reduce_sum(new_state, axis=1) return tf.transpose(new_state) new_state = tf.reduce_sum(new_state, axis=-1) return new_state return propagate_fn @tf.function(autograph=False) # Algorithm runs super slow if uncommented. Weird! def chain_binomial_simulate(hazard_fn, state, start, end, time_step): propagate = chain_binomial_propagate(hazard_fn, time_step) times = tf.range(start, end, time_step) output = tf.TensorArray(tf.float64, size=times.shape[0]) output = tf.TensorArray(state.dtype, size=times.shape[0]) output = output.write(0, state) for i in range(1, times.shape[0]): state = propagate(i, state) output = output.write(i, state) sim = output.gather(tf.range(1, times.shape[0])) return times, sim cond = lambda i, *_: i < times.shape[0] def body(i, state, output): state = propagate(i, state) output = output.write(i, state) return i + 1, state, output _, state, output = tf.while_loop(cond, body, loop_vars=(0, state, output)) return times, output.stack()
 ... ... @@ -110,7 +110,7 @@ class CovidUK: S = self.N - I E = np.zeros(self.N.shape, dtype=np.float64) R = np.zeros(self.N.shape, dtype=np.float64) return np.stack([S, E, I, R]) return np.stack([S, E, I, R], axis=-1) class CovidUKODE(CovidUK): ... ... @@ -124,7 +124,7 @@ class CovidUKODE(CovidUK): def h_fn(t, state): S, E, I, R = state S, E, I, R = tf.unstack(state, axis=-1) # Integrator may produce time values outside the range desired, so # we clip, implicitly assuming the outside dates have the same # holiday status as their nearest neighbors in the desired range. ... ... @@ -141,7 +141,7 @@ class CovidUKODE(CovidUK): EI = param['nu'] IR = param['gamma'] p = 1 - tf.exp([SE, EI, IR]) p = 1 - tf.exp(tf.stack([SE, EI, IR], axis=-1)) return p return h_fn ... ... @@ -171,8 +171,8 @@ class CovidUKODE(CovidUK): def covid19uk_logp(y, sim, phi): # Sum daily increments in removed r_incr = sim[1:, 3, :] - sim[:-1, 3, :] r_incr = tf.reduce_sum(r_incr, axis=1) r_incr = sim[1:, :, 3] - sim[:-1, :, 3] r_incr = tf.reduce_sum(r_incr, axis=-1) y_ = tfp.distributions.Poisson(rate=phi*r_incr) return y_.log_prob(y) ... ... @@ -201,22 +201,30 @@ class CovidUKStochastic(CovidUK): commute_volume = tf.pow(tf.gather(self.W, t_idx), param['omega']) infec_rate = param['beta1'] * ( tf.gather(self.M.matvec(state[2]), m_switch) + param['beta2'] * self.Kbar * commute_volume * self.C.matvec(state[2] / self.N_sum)) tf.gather(self.M.matvec(state[:, 2]), m_switch) + param['beta2'] * self.Kbar * commute_volume * self.C.matvec(state[:, 2] / self.N_sum)) infec_rate = infec_rate / self.N ei = tf.broadcast_to([param['nu']], shape=[state.shape[1]]) ir = tf.broadcast_to([param['gamma']], shape=[state.shape[1]]) ei = tf.broadcast_to([param['nu']], shape=[state.shape[0]]) ir = tf.broadcast_to([param['gamma']], shape=[state.shape[0]]) # Scatter rates into a [ns, ns, nc] tensor rates = [infec_rate, ei, ir] rates = tf.scatter_nd(updates=rates, indices=[[0, 1], [1, 2], [2, 3]], shape=[state.shape[0], state.shape[0], state.shape[1]]) return rates n = state.shape[0] b = tf.stack([tf.range(n), tf.zeros(n, dtype=tf.int32), tf.ones(n, dtype=tf.int32)], axis=-1) indices = tf.stack([b, b + [0, 1, 1], b + [0, 2, 2]], axis=-2) # Un-normalised rate matrix (diag is 0 here) rate_matrix = tf.scatter_nd(indices=indices, updates=tf.stack([infec_rate, ei, ir], axis=-1), shape=[state.shape[0], state.shape[1], state.shape[1]]) return rate_matrix return h @tf.function(experimental_compile=True) @tf.function(autograph=False) def simulate(self, param, state_init): """Runs a simulation from the epidemic model ... ...
 ... ... @@ -13,26 +13,26 @@ from covid.util import sanitise_parameter, sanitise_settings, seed_areas def sum_age_groups(sim): infec = sim[:, 2, :] infec = sim[:, :, 2] infec = infec.reshape([infec.shape[0], 152, 17]) infec_uk = infec.sum(axis=2) return infec_uk def sum_la(sim): infec = sim[:, 2, :] infec = sim[:, :, 2] infec = infec.reshape([infec.shape[0], 152, 17]) infec_uk = infec.sum(axis=1) return infec_uk def sum_total_removals(sim): remove = sim[:, 3, :] remove = sim[:, :, 3] return remove.sum(axis=1) def final_size(sim): remove = sim[:, 3, :] remove = sim[:, :, 3] remove = remove.reshape([remove.shape[0], 152, 17]) fs = remove[-1, :, :].sum(axis=0) return fs ... ... @@ -52,7 +52,7 @@ def write_hdf5(filename, param, t, sim): def plot_total_curve(sim): infec_uk = sum_la(sim) infec_uk = infec_uk.sum(axis=1) infec_uk = infec_uk.sum(axis=-1) removals = sum_total_removals(sim) times = np.datetime64('2020-02-20') + np.arange(removals.shape[0]) plt.plot(times, infec_uk, 'r-', label='Infected') ... ... @@ -66,7 +66,7 @@ def plot_total_curve(sim): def plot_infec_curve(ax, sim, label): infec_uk = sum_la(sim) infec_uk = infec_uk.sum(axis=1) infec_uk = infec_uk.sum(axis=-1) times = np.datetime64('2020-02-20') + np.arange(infec_uk.shape[0]) ax.plot(times, infec_uk, '-', label=label) ... ... @@ -75,7 +75,7 @@ def plot_by_age(sim, labels, t0=np.datetime64('2020-02-20'), ax=None): if ax is None: ax = plt.figure().gca() infec_uk = sum_la(sim) total_uk = infec_uk.mean(axis=1) total_uk = infec_uk.mean(axis=-1) t = t0 + np.arange(infec_uk.shape[0]) colours = plt.cm.viridis(np.linspace(0., 1., infec_uk.shape[1])) for i in range(infec_uk.shape[1]): ... ... @@ -88,7 +88,7 @@ def plot_by_la(sim, labels, t0=np.datetime64('2020-02-20'), ax=None): if ax is None: ax = plt.figure().gca() infec_uk = sum_age_groups(sim) total_uk = infec_uk.mean(axis=1) total_uk = infec_uk.mean(axis=-1) t = t0 + np.arange(infec_uk.shape[0]) colours = plt.cm.viridis(np.linspace(0., 1., infec_uk.shape[1])) for i in range(infec_uk.shape[1]): ... ...
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!