Commit 26329184 by Chris Jewell

### Fixed a silly bug in random_walk_mvnorm_fn

parent 899551a9
 ... @@ -9,8 +9,8 @@ import numpy as np ... @@ -9,8 +9,8 @@ import numpy as np from covid.impl.chainbinom_simulate import chain_binomial_simulate from covid.impl.chainbinom_simulate import chain_binomial_simulate def power_iteration(A, tol=1e-3): def power_iteration(A, tol=1e-3): b_k = tf.random.normal([A.shape[1], 1]) b_k = tf.random.normal([A.shape[1], 1], dtype=tf.float64) epsilon = 1. epsilon = tf.constant(1., dtype=tf.float64) i = 0 i = 0 while tf.greater(epsilon, tol): while tf.greater(epsilon, tol): b_k1 = tf.matmul(A, b_k) b_k1 = tf.matmul(A, b_k) ... @@ -178,7 +178,7 @@ class CovidUKODE: # TODO: add background case importation rate to the UK, e.g. ... @@ -178,7 +178,7 @@ class CovidUKODE: # TODO: add background case importation rate to the UK, e.g. def eval_R0(self, param, tol=1e-8): def eval_R0(self, param, tol=1e-8): ngm = self.ngm(param) ngm = self.ngm(param) # Dominant eigen value by power iteration # Dominant eigen value by power iteration dom_eigen_vec, i = power_iteration(ngm, tol=tol) dom_eigen_vec, i = power_iteration(ngm, tol=tf.cast(tol, tf.float64)) R0 = rayleigh_quotient(ngm, dom_eigen_vec) R0 = rayleigh_quotient(ngm, dom_eigen_vec) return tf.squeeze(R0), i return tf.squeeze(R0), i ... ...
 ... @@ -92,7 +92,7 @@ if __name__ == '__main__': ... @@ -92,7 +92,7 @@ if __name__ == '__main__': p = param p = param p['epsilon'] = par[0] p['epsilon'] = par[0] p['beta1'] = par[1] p['beta1'] = par[1] #p['gamma'] = par[2] p['gamma'] = par[2] epsilon_logp = tfd.Gamma(concentration=tf.constant(1., tf.float64), rate=tf.constant(1., tf.float64)).log_prob(p['epsilon']) epsilon_logp = tfd.Gamma(concentration=tf.constant(1., tf.float64), rate=tf.constant(1., tf.float64)).log_prob(p['epsilon']) beta_logp = tfd.Gamma(concentration=tf.constant(1., tf.float64), rate=tf.constant(1., tf.float64)).log_prob(p['beta1']) beta_logp = tfd.Gamma(concentration=tf.constant(1., tf.float64), rate=tf.constant(1., tf.float64)).log_prob(p['beta1']) gamma_logp = tfd.Gamma(concentration=tf.constant(100., tf.float64), rate=tf.constant(400., tf.float64)).log_prob(p['gamma']) gamma_logp = tfd.Gamma(concentration=tf.constant(100., tf.float64), rate=tf.constant(400., tf.float64)).log_prob(p['gamma']) ... @@ -109,10 +109,10 @@ if __name__ == '__main__': ... @@ -109,10 +109,10 @@ if __name__ == '__main__': unconstraining_bijector = [tfb.Exp()] unconstraining_bijector = [tfb.Exp()] initial_mcmc_state = np.array([0.001, 0.036], dtype=np.float64) initial_mcmc_state = np.array([0.001, 0.036, 0.25], dtype=np.float64) print("Initial log likelihood:", logp(initial_mcmc_state)) print("Initial log likelihood:", logp(initial_mcmc_state)) #@tf.function @tf.function(experimental_compile=True) def sample(n_samples, init_state, scale): def sample(n_samples, init_state, scale): return tfp.mcmc.sample_chain( return tfp.mcmc.sample_chain( num_results=n_samples, num_results=n_samples, ... @@ -124,19 +124,19 @@ if __name__ == '__main__': ... @@ -124,19 +124,19 @@ if __name__ == '__main__': new_state_fn=random_walk_mvnorm_fn(scale) new_state_fn=random_walk_mvnorm_fn(scale) ), ), bijector=unconstraining_bijector), bijector=unconstraining_bijector), trace_fn=lambda _, pkr: pkr) trace_fn=lambda _, pkr: pkr.inner_results.is_accepted) with tf.device("/CPU:0"): with tf.device("/CPU:0"): cov = np.diag([0.00001, 0.00001]) cov = np.diag([0.00001, 0.00001, 0.00001]) start = time.perf_counter() start = time.perf_counter() joint_posterior, results = sample(50, init_state=initial_mcmc_state, scale=cov) joint_posterior, results = sample(50, init_state=initial_mcmc_state, scale=cov) for i in range(20): for i in range(200): cov = tfp.stats.covariance(tf.math.log(joint_posterior)) * 2.38**2 / joint_posterior.shape[1] cov = tfp.stats.covariance(tf.math.log(joint_posterior)) * 2.38**2 / joint_posterior.shape[1] print(cov.numpy()) print(cov.numpy()) posterior_new, results = sample(50, joint_posterior[-1, :], cov) posterior_new, results = sample(50, joint_posterior[-1, :], cov) joint_posterior = tf.concat([joint_posterior, posterior_new], axis=0) joint_posterior = tf.concat([joint_posterior, posterior_new], axis=0) posterior_new, results = sample(1000, init_state=joint_posterior[-1, :], scale=cov) #posterior_new, results = sample(2000, init_state=joint_posterior[-1, :], scale=cov) joint_posterior = tf.concat([joint_posterior, posterior_new], axis=0) #joint_posterior = tf.concat([joint_posterior, posterior_new], axis=0) end = time.perf_counter() end = time.perf_counter() print(f"Simulation complete in {end-start} seconds") print(f"Simulation complete in {end-start} seconds") print("Acceptance: ", np.mean(results.numpy())) print("Acceptance: ", np.mean(results.numpy())) ... @@ -145,7 +145,7 @@ if __name__ == '__main__': ... @@ -145,7 +145,7 @@ if __name__ == '__main__': fig, ax = plt.subplots(1, 3) fig, ax = plt.subplots(1, 3) ax[0].plot(joint_posterior[:, 0]) ax[0].plot(joint_posterior[:, 0]) ax[1].plot(joint_posterior[:, 1]) ax[1].plot(joint_posterior[:, 1]) #ax[2].plot(joint_posterior[:, 2]) ax[2].plot(joint_posterior[:, 2]) plt.show() plt.show() print(f"Posterior mean: {np.mean(joint_posterior, axis=0)}") print(f"Posterior mean: {np.mean(joint_posterior, axis=0)}") ... ...
 ... @@ -2,7 +2,8 @@ ... @@ -2,7 +2,8 @@ data: data: age_mixing_matrix_term: data/polymod_normal_df.rds age_mixing_matrix_term: data/polymod_normal_df.rds age_mixing_matrix_hol: data/polymod_no_school_df.rds #age_mixing_matrix_hol: data/polymod_no_school_df.rds age_mixing_matrix_hol: data/polymod_weekend_df.rds mobility_matrix: data/movement.rds mobility_matrix: data/movement.rds population_size: data/pop.rds population_size: data/pop.rds reported_cases: data/DailyConfirmedCases.csv reported_cases: data/DailyConfirmedCases.csv ... @@ -18,8 +19,8 @@ settings: ... @@ -18,8 +19,8 @@ settings: start: 2020-02-04 start: 2020-02-04 end: 2020-04-01 end: 2020-04-01 holiday: holiday: - 2020-04-06 - 2020-03-23 - 2020-04-17 - 2020-10-01 bg_max_time: 2020-03-01 bg_max_time: 2020-03-01 time_step: 1. time_step: 1. ... ...
 ... @@ -67,26 +67,23 @@ if __name__ == '__main__': ... @@ -67,26 +67,23 @@ if __name__ == '__main__': state_init = simulator.create_initial_state(init_matrix=seeding) state_init = simulator.create_initial_state(init_matrix=seeding) @tf.function @tf.function def prediction(epsilon, beta): def prediction(epsilon, beta, gamma): sims = tf.TensorArray(tf.float32, size=beta.shape[0]) sims = tf.TensorArray(tf.float64, size=beta.shape[0]) R0 = tf.TensorArray(tf.float32, size=beta.shape[0]) R0 = tf.TensorArray(tf.float64, size=beta.shape[0]) #d_time = tf.TensorArray(tf.float32, size=beta.shape[0]) for i in tf.range(beta.shape[0]): for i in tf.range(beta.shape[0]): p = param p = param p['epsilon'] = epsilon[i] p['epsilon'] = epsilon[i] p['beta1'] = beta[i] p['beta1'] = beta[i] p['gamma'] = gamma[i] t, sim, solver_results = simulator.simulate(p, state_init) t, sim, solver_results = simulator.simulate(p, state_init) r = simulator.eval_R0(p) r = simulator.eval_R0(p) R0 = R0.write(i, r[0]) R0 = R0.write(i, r[0]) #d_time = d_time.write(i, doubling_time(t, sim, '2002-03-01', '2002-04-01')) #sim_aggr = tf.reduce_sum(sim, axis=2) sims = sims.write(i, sim) sims = sims.write(i, sim) return sims.gather(range(beta.shape[0])), R0.gather(range(beta.shape[0])) return sims.gather(range(beta.shape[0])), R0.gather(range(beta.shape[0])) draws = [pi_beta[0].numpy()[np.arange(500, pi_beta[0].shape[0], 10)], draws = pi_beta.numpy()[np.arange(5000, pi_beta.shape[0], 10), :] pi_beta[1].numpy()[np.arange(500, pi_beta[1].shape[0], 10)]] with tf.device('/CPU:0'): with tf.device('/CPU:0'): sims, R0 = prediction(draws[0], draws[1]) sims, R0 = prediction(draws[:, 0], draws[:, 1], draws[:, 2]) sims = tf.stack(sims) # shape=[n_sims, n_times, n_states, n_metapops] sims = tf.stack(sims) # shape=[n_sims, n_times, n_states, n_metapops] save_sims(sims, la_names, age_groups, 'pred_2020-03-15.h5') save_sims(sims, la_names, age_groups, 'pred_2020-03-15.h5') ... @@ -104,18 +101,18 @@ if __name__ == '__main__': ... @@ -104,18 +101,18 @@ if __name__ == '__main__': removed_observed = tfs.percentile(removed * 0.1, q=[2.5, 50, 97.5], axis=0) removed_observed = tfs.percentile(removed * 0.1, q=[2.5, 50, 97.5], axis=0) fig = plt.figure() fig = plt.figure() filler = plt.fill_between(dates, total_infected[0, :], total_infected[2, :], color='lightgray', label="95% credible interval") filler = plt.fill_between(dates, total_infected[0, :], total_infected[2, :], color='lightgray', alpha=0.8, label="95% credible interval") plt.fill_between(dates, removed[0, :], removed[2, :], color='lightgray') plt.fill_between(dates, removed[0, :], removed[2, :], color='lightgray', alpha=0.8) plt.fill_between(dates, removed_observed[0, :], removed_observed[2, :], color='lightgray') plt.fill_between(dates, removed_observed[0, :], removed_observed[2, :], color='lightgray', alpha=0.8) ti_line = plt.plot(dates, total_infected[1, :], '-', color='red', alpha=0.4, label="Infected") ti_line = plt.plot(dates, total_infected[1, :], '-', color='red', alpha=0.4, label="Infected") rem_line = plt.plot(dates, removed[1, :], '-', color='blue', label="Removed") rem_line = plt.plot(dates, removed[1, :], '-', color='blue', label="Removed") ro_line = plt.plot(dates, removed_observed[1, :], '-', color='orange', label='Predicted detections') ro_line = plt.plot(dates, removed_observed[1, :], '-', color='orange', label='Predicted detections') marks = plt.plot(data_dates, y, '+', label='Observed cases') marks = plt.plot(data_dates, y, '+', label='Observed cases') plt.legend([ti_line[0], rem_line[0], ro_line[0], filler, marks[0]], plt.legend([ti_line[0], rem_line[0], ro_line[0], filler, marks[0]], ["Infected", "Removed", "Predicted detections", "95% credible interval", "Observed counts"]) ["Infected", "Removed", "Predicted detections", "95% credible interval", "Observed counts"]) plt.grid() plt.grid(color='lightgray', linestyle='dotted') plt.xlabel("Date") plt.xlabel("Date") plt.ylabel("\$10^7\$ individuals") plt.ylabel("Individuals") fig.autofmt_xdate() fig.autofmt_xdate() plt.show() plt.show() ... @@ -124,7 +121,7 @@ if __name__ == '__main__': ... @@ -124,7 +121,7 @@ if __name__ == '__main__': fig = plt.figure() fig = plt.figure() plt.fill_between(dates[:-1], new_cases[0, :], new_cases[2, :], color='lightgray', label="95% credible interval") plt.fill_between(dates[:-1], new_cases[0, :], new_cases[2, :], color='lightgray', label="95% credible interval") plt.plot(dates[:-1], new_cases[1, :], '-', alpha=0.2, label='New cases') plt.plot(dates[:-1], new_cases[1, :], '-', alpha=0.2, label='New cases') plt.grid() plt.grid(color='lightgray', linestyle='dotted') plt.xlabel("Date") plt.xlabel("Date") plt.ylabel("Incidence per 10,000") plt.ylabel("Incidence per 10,000") fig.autofmt_xdate() fig.autofmt_xdate() ... @@ -141,3 +138,7 @@ if __name__ == '__main__': ... @@ -141,3 +138,7 @@ if __name__ == '__main__': # Doubling time # Doubling time dub_ci = tfs.percentile(dub_time, q=[2.5, 50, 97.5]) dub_ci = tfs.percentile(dub_time, q=[2.5, 50, 97.5]) print("Doubling time:", dub_ci) print("Doubling time:", dub_ci) # Infectious period ip = tfs.percentile(1./pi_beta[3000:, 2], q=[2.5, 50, 97.5]) print("Infectious period:", ip)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!