"""Fit the CovidUKODE model to reported case counts with adaptive random-walk Metropolis MCMC."""
import optparse
import pickle as pkl
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow_probability as tfp
import yaml
from tensorflow_probability import bijectors as tfb
from tensorflow_probability import distributions as tfd

from covid.model import CovidUKODE, covid19uk_logp, load_data
# NOTE(review): `tf`, `sanitise_parameter`, `sanitise_settings` and `seed_areas`
# are presumably provided by this star import -- confirm and import explicitly.
from covid.util import *

# Floating-point type used for the prior hyperparameter constants and for the
# buffer that accumulates posterior samples.
DTYPE = np.float64

15

16
17
def random_walk_mvnorm_fn(covariance, name=None):
    """Return a callable that adds multivariate normal noise to its input.

    The returned callable has the ``(state_parts, seed)`` signature and is
    passed as ``new_state_fn`` to ``tfp.mcmc.RandomWalkMetropolis`` in this
    file.

    Args:
        covariance: square float64 covariance matrix of the proposal noise.
            Only ``covariance.shape[0]`` is read to size the distribution.
        name: optional name for the ``tf.name_scope`` in which the proposal
            ops are created; defaults to ``'random_walk_mvnorm_fn'``.

    Returns:
        A callable ``_fn(state_parts, seed)`` that returns a list holding,
        for each element of ``state_parts``, that element plus one fresh
        zero-mean multivariate normal draw.
    """
    dim = covariance.shape[0]
    # Add a tiny multiple of the identity before factorising so that a
    # covariance that is only positive semi-definite does not make the
    # Cholesky decomposition fail.
    jitter = tf.eye(dim, dtype=tf.float64) * 1.e-9
    scale_tril = tf.linalg.cholesky(covariance + jitter)
    rv = tfp.distributions.MultivariateNormalTriL(
        loc=tf.zeros(dim, dtype=tf.float64),
        scale_tril=scale_tril)

    def _fn(state_parts, seed):
        # NOTE(review): `seed` is accepted to satisfy the new_state_fn
        # signature but is not forwarded to `rv.sample()`, so proposals are
        # not reproducible from the kernel seed. Wiring it through needs
        # per-part seed splitting that depends on the TFP version in use.
        with tf.name_scope(name or 'random_walk_mvnorm_fn'):
            return [rv.sample() + state_part for state_part in state_parts]

    return _fn

30
31


32
33
34
if __name__ == '__main__':
    # ---- Configuration ------------------------------------------------------
    parser = optparse.OptionParser()
    parser.add_option("--config", "-c", dest="config", default="ode_config.yaml",
                      help="configuration file")
    options, args = parser.parse_args()
    with open(options.config, 'r') as ymlfile:
        # safe_load instead of a bare yaml.load(): calling load() without an
        # explicit Loader can construct arbitrary Python objects on old PyYAML
        # and raises TypeError on PyYAML >= 6.
        config = yaml.safe_load(ymlfile)

    param = sanitise_parameter(config['parameter'])
    settings = sanitise_settings(config['settings'])

    # ---- Data ---------------------------------------------------------------
    data = load_data(config['data'], settings)

    case_reports = pd.read_csv(config['data']['reported_cases'])
    case_reports['DateVal'] = pd.to_datetime(case_reports['DateVal'])
    case_reports = case_reports[case_reports['DateVal'] >= '2020-02-19']
    date_range = [case_reports['DateVal'].min(), case_reports['DateVal'].max()]
    y = case_reports['CumCases'].to_numpy()
    # Day-on-day increments of the cumulative count, scaled by 0.8 and rounded.
    # NOTE(review): the meaning of the 0.8 factor is not visible here -- confirm.
    y_incr = np.round((y[1:] - y[:-1]) * 0.8)

    # ---- Model --------------------------------------------------------------
    # The simulation window starts one day before the first retained report.
    simulator = CovidUKODE(M_tt=data['M_tt'],
                           M_hh=data['M_hh'],
                           C=data['C'],
                           N=data['pop']['n'].to_numpy(),
                           W=data['W'].to_numpy(),
                           date_range=[date_range[0] - np.timedelta64(1, 'D'), date_range[1]],
                           holidays=settings['holiday'],
                           lockdown=settings['lockdown'],
                           time_step=int(settings['time_step']))

    seeding = seed_areas(data['pop']['n'].to_numpy(), data['pop']['Area.name.2'])  # Seed 40-44 age group, 30 seeds by popn size
    state_init = simulator.create_initial_state(init_matrix=seeding)

    def gamma_prior(concentration, rate):
        """Return a Gamma distribution whose hyperparameters are DTYPE constants."""
        return tfd.Gamma(concentration=tf.constant(concentration, dtype=DTYPE),
                         rate=tf.constant(rate, dtype=DTYPE))

    def logp(par):
        """Return the unnormalised log posterior density at `par`.

        `par` is indexed as [beta1, beta3, gamma, I0, r].
        """
        # Work on a copy: assigning into `param` itself would write tensors
        # from this evaluation into the shared configuration dict.
        p = dict(param)
        p['beta1'] = par[0]
        p['beta3'] = par[1]
        p['gamma'] = par[2]
        p['I0'] = par[3]
        p['r'] = par[4]

        beta_logp = gamma_prior(1., 1.).log_prob(p['beta1'])
        beta3_logp = gamma_prior(200., 200.).log_prob(p['beta3'])
        gamma_logp = gamma_prior(100., 400.).log_prob(p['gamma'])
        I0_logp = gamma_prior(1.5, 0.05).log_prob(p['I0'])
        # The prior on r must be evaluated at r (it was evaluated at gamma).
        r_logp = gamma_prior(0.1, 0.1).log_prob(p['r'])

        # The initial state is the seeding matrix scaled by I0.
        state_init = simulator.create_initial_state(init_matrix=seeding * p['I0'])
        t, sim, solve = simulator.simulate(p, state_init)
        y_logp = covid19uk_logp(y_incr, sim, 0.1, p['r'])
        return (beta_logp + beta3_logp + gamma_logp + I0_logp + r_logp
                + tf.reduce_sum(y_logp))

    # ---- Sampler ------------------------------------------------------------
    # Exp bijector: the random walk runs on the log scale of each parameter.
    unconstraining_bijector = [tfb.Exp()]
    initial_mcmc_state = np.array([0.05, 1.0, 0.25, 1.0, 50.], dtype=np.float64)  # beta1, beta3, gamma, I0, r
    print("Initial log likelihood:", logp(initial_mcmc_state))

    @tf.function(autograph=False, experimental_compile=True)
    def sample(n_samples, init_state, scale, num_burnin_steps=0):
        """Run a random-walk Metropolis chain; return (samples, is_accepted)."""
        return tfp.mcmc.sample_chain(
            num_results=n_samples,
            num_burnin_steps=num_burnin_steps,
            current_state=init_state,
            kernel=tfp.mcmc.TransformedTransitionKernel(
                    inner_kernel=tfp.mcmc.RandomWalkMetropolis(
                        target_log_prob_fn=logp,
                        new_state_fn=random_walk_mvnorm_fn(scale)
                    ),
                    bijector=unconstraining_bijector),
            trace_fn=lambda _, pkr: pkr.inner_results.is_accepted)

    # Accumulates every draw from every adaptation round and the final run.
    joint_posterior = tf.zeros([0] + list(initial_mcmc_state.shape), dtype=DTYPE)

    scale = np.diag([0.1, 0.1, 0.1, 0.1, 1.])

    num_covariance_estimation_iterations = 50
    num_covariance_estimation_samples = 50
    num_final_samples = 10000

    # ---- Proposal adaptation ------------------------------------------------
    # Each round runs a short chain, then re-estimates the proposal covariance
    # from the log of all samples so far, scaled by 2.38^2 / dimension.
    start = time.perf_counter()
    for i in range(num_covariance_estimation_iterations):
        step_start = time.perf_counter()
        samples, results = sample(num_covariance_estimation_samples,
                                  initial_mcmc_state,
                                  scale)
        step_end = time.perf_counter()
        print(f'{i} time {step_end - step_start}')
        print("Acceptance: ", results.numpy().mean())
        joint_posterior = tf.concat([joint_posterior, samples], axis=0)
        cov = tfp.stats.covariance(tf.math.log(joint_posterior))
        print(cov.numpy())
        scale = cov * 2.38**2 / joint_posterior.shape[1]
        # The next round continues from the last sampled state.
        initial_mcmc_state = joint_posterior[-1, :]

    # ---- Final sampling run -------------------------------------------------
    step_start = time.perf_counter()
    samples, results = sample(num_final_samples,
                              init_state=joint_posterior[-1, :], scale=scale,)
    joint_posterior = tf.concat([joint_posterior, samples], axis=0)
    step_end = time.perf_counter()
    print(f'Sampling step time {step_end - step_start}')
    end = time.perf_counter()
    print(f"Simulation complete in {end-start} seconds")
    print("Acceptance: ", np.mean(results.numpy()))
    print(tfp.stats.covariance(tf.math.log(joint_posterior)))

    # ---- Diagnostics and output ---------------------------------------------
    # One trace plot per parameter.
    fig, ax = plt.subplots(1, joint_posterior.shape[1])
    for i in range(joint_posterior.shape[1]):
        ax[i].plot(joint_posterior[:, i])

    plt.show()
    print(f"Posterior mean: {np.mean(joint_posterior, axis=0)}")

    with open('pi_beta_2020-03-29.pkl', 'wb') as f:
        pkl.dump(joint_posterior, f)