"""Inference on stochastic models"""

import optparse
import time
import pickle as pkl

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

import numpy as np
import matplotlib.pyplot as plt
import yaml

from covid.model import CovidUKStochastic, load_data
from covid.util import sanitise_parameter, sanitise_settings, seed_areas

DTYPE = np.float64


def random_walk_mvnorm_fn(covariance, name=None):
    """Return a proposal function that adds multivariate Normal noise to the input.

    The returned callable has the ``new_state_fn(state_parts, seed)``
    signature used by ``tfp.mcmc.RandomWalkMetropolis``.

    Args:
        covariance: square proposal covariance matrix (numpy array or Tensor).
            A jitter of 1e-9 is added to its diagonal before the Cholesky
            factorisation, so a numerically singular estimate does not make
            the factorisation fail.
        name: optional name scope for the ops created by the returned function
            (defaults to ``'random_walk_mvnorm_fn'``).

    Returns:
        A callable ``_fn(state_parts, seed)``. It returns a new list where each
        element of ``state_parts`` has an independent draw from
        ``MVN(0, covariance)`` added to it. ``seed`` is forwarded to the
        sampler so that seeded chains are reproducible.
    """
    covariance = tf.convert_to_tensor(covariance)
    # Follow the dtype of the supplied covariance instead of hard-coding float64.
    dtype = covariance.dtype
    dim = covariance.shape[0]
    # Small diagonal jitter keeps the Cholesky factorisation stable.
    covariance = covariance + tf.eye(dim, dtype=dtype) * 1.e-9
    scale_tril = tf.linalg.cholesky(covariance)
    rv = tfp.distributions.MultivariateNormalTriL(loc=tf.zeros(dim, dtype=dtype),
                                                  scale_tril=scale_tril)

    def _fn(state_parts, seed):
        with tf.name_scope(name or 'random_walk_mvnorm_fn'):
            # Draw one independent noise vector per state part in a single
            # seeded call. The original ignored `seed`, which made seeded
            # runs irreproducible.
            noise = rv.sample(len(state_parts), seed=seed)
            return [noise[i] + part for i, part in enumerate(state_parts)]

    return _fn


if __name__ == '__main__':

    parser = optparse.OptionParser()
    parser.add_option("--config", "-c", dest="config", default="ode_config.yaml",
                      help="configuration file")
    options, args = parser.parse_args()

    # safe_load: the config only needs plain YAML types. yaml.load() without
    # an explicit Loader is unsafe and is a TypeError on PyYAML >= 6.
    with open(options.config, 'r') as ymlfile:
        config = yaml.safe_load(ymlfile)

    param = sanitise_parameter(config['parameter'])
    settings = sanitise_settings(config['settings'])

    data = load_data(config['data'], settings, DTYPE)

    model = CovidUKStochastic(M_tt=data['M_tt'],
                              M_hh=data['M_hh'],
                              C=data['C'],
                              N=data['pop']['n'].to_numpy(),
                              W=data['W'],
                              date_range=settings['inference_period'],
                              holidays=settings['holiday'],
                              lockdown=settings['lockdown'],
                              time_step=1.)

    # Simulated events and initial state, used as the data for inference.
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open('stochastic_sim.pkl', 'rb') as f:
        sim = pkl.load(f)

    events = tf.convert_to_tensor(sim['events'], dtype=DTYPE)
    state_init = tf.convert_to_tensor(sim['state_init'], dtype=DTYPE)

    param = {k: tf.constant(v, dtype=DTYPE) for k, v in param.items()}

    def logp(par):
        """Unnormalised log posterior of ``par = [beta1, beta3, gamma]``.

        Sums Gamma log-priors for the three parameters and the model's
        log-likelihood of ``events``. All other parameters come from ``param``.
        """
        print("Tracing logp")
        # Work on a copy. Writing into the shared `param` dict would mutate it
        # on every call and leak traced (graph) tensors out of tf.function.
        p = dict(param)
        p['beta1'] = par[0]
        p['beta3'] = par[1]
        p['gamma'] = par[2]

        # Gamma(concentration, rate) priors; prior means are 1, 1 and 0.25.
        beta_logp = tfd.Gamma(concentration=tf.constant(1., dtype=DTYPE),
                              rate=tf.constant(1., dtype=DTYPE)).log_prob(p['beta1'])
        beta3_logp = tfd.Gamma(concentration=tf.constant(20., dtype=DTYPE),
                               rate=tf.constant(20., dtype=DTYPE)).log_prob(p['beta3'])
        gamma_logp = tfd.Gamma(concentration=tf.constant(100., dtype=DTYPE),
                               rate=tf.constant(400., dtype=DTYPE)).log_prob(p['gamma'])
        y_logp = model.log_prob(events, p, state_init)
        return beta_logp + beta3_logp + gamma_logp + y_logp

    # The Exp bijector lets the random walk run on log(parameters), which
    # keeps the sampled parameters positive.
    unconstraining_bijector = [tfb.Exp()]
    initial_mcmc_state = tf.constant([0.05, 0.5, 0.25], dtype=DTYPE)  # beta1, beta3, gamma
    print("Initial log likelihood:", logp(initial_mcmc_state))

    @tf.function(experimental_compile=True)
    def sample(n_samples, init_state, scale, num_burnin_steps=0):
        """Run a random-walk Metropolis chain with proposal covariance `scale`.

        Returns ``(samples, is_accepted)`` from ``tfp.mcmc.sample_chain``.
        """
        return tfp.mcmc.sample_chain(
            num_results=n_samples,
            num_burnin_steps=num_burnin_steps,
            current_state=init_state,
            kernel=tfp.mcmc.TransformedTransitionKernel(
                inner_kernel=tfp.mcmc.RandomWalkMetropolis(
                    target_log_prob_fn=logp,
                    new_state_fn=random_walk_mvnorm_fn(scale)
                ),
                bijector=unconstraining_bijector),
            trace_fn=lambda _, pkr: pkr.inner_results.is_accepted)

    joint_posterior = tf.zeros([0] + list(initial_mcmc_state.shape), dtype=DTYPE)

    # Initial diagonal proposal covariance (on the log scale).
    scale = np.diag([0.1, 0.1, 0.1])

    num_covariance_estimation_iterations = 20
    num_covariance_estimation_samples = 50
    num_final_samples = 10000

    start = time.perf_counter()
    # Adaptive phase: run short chains and re-estimate the proposal covariance
    # from all draws so far. The estimate is taken on the log scale, where the
    # walk runs, and scaled by the 2.38^2 / d rule of thumb.
    for i in range(num_covariance_estimation_iterations):
        step_start = time.perf_counter()
        samples, results = sample(num_covariance_estimation_samples,
                                  initial_mcmc_state,
                                  scale)
        step_end = time.perf_counter()
        print(f'{i} time {step_end - step_start}')
        print("Acceptance: ", results.numpy().mean())
        joint_posterior = tf.concat([joint_posterior, samples], axis=0)
        cov = tfp.stats.covariance(tf.math.log(joint_posterior))
        print(cov.numpy())
        scale = cov * 2.38**2 / joint_posterior.shape[1]
        initial_mcmc_state = joint_posterior[-1, :]

    step_start = time.perf_counter()
    samples, results = sample(num_final_samples,
                              init_state=joint_posterior[-1, :], scale=scale,)
    # NOTE(review): the saved chain also contains the adaptive-phase draws,
    # which were generated with a changing proposal; consider discarding them.
    joint_posterior = tf.concat([joint_posterior, samples], axis=0)
    step_end = time.perf_counter()
    print(f'Sampling step time {step_end - step_start}')
    end = time.perf_counter()
    print(f"Simulation complete in {end-start} seconds")
    print("Acceptance: ", np.mean(results.numpy()))
    print(tfp.stats.covariance(tf.math.log(joint_posterior)))

    # Persist the chain before plotting. With interactive backends plt.show()
    # blocks until the window is closed, so saving first keeps the results
    # even if the session is interrupted.
    with open('stochastic_posterior.pkl', 'wb') as f:
        pkl.dump(joint_posterior, f)

    print(f"Posterior mean: {np.mean(joint_posterior, axis=0)}")

    fig, ax = plt.subplots(1, joint_posterior.shape[1])
    for i in range(joint_posterior.shape[1]):
        ax[i].plot(joint_posterior[:, i])
    plt.show()