mcmc.py 9.47 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
import os
2
import pickle as pkl
Chris Jewell's avatar
Chris Jewell committed
3

Chris Jewell's avatar
Chris Jewell committed
4
import h5py
5
import numpy as np
6
7
import tensorflow as tf
import tensorflow_probability as tfp
Chris Jewell's avatar
Chris Jewell committed
8
9
10
import tqdm
import yaml

11
12
13
tfd = tfp.distributions
tfb = tfp.bijectors

14
15
from covid import config
from covid.model import load_data, CovidUKStochastic
Chris Jewell's avatar
Chris Jewell committed
16
from covid.util import sanitise_parameter, sanitise_settings
17
from covid.impl.util import make_transition_matrix
18
from covid.impl.mcmc import UncalibratedLogRandomWalk, random_walk_mvnorm_fn
19
from covid.impl.event_time_mh import EventTimesUpdate
20
21
22

DTYPE = config.floatX

# Hide every CUDA device from TensorFlow.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

tf.random.set_seed(10)

# Random moves of events.  What invalidates an epidemic, how can we test for it?
with open('ode_config.yaml', 'r') as f:
    # NOTE: from here on `config` is the parsed YAML mapping, no longer the
    # `covid.config` module imported above (DTYPE was read from the module
    # before this rebinding).
    # `safe_load` is used because a bare `yaml.load(f)` without a Loader
    # warns on PyYAML 5.1+ and is rejected by PyYAML 6+.
    config = yaml.safe_load(f)

param = sanitise_parameter(config['parameter'])
param = {k: tf.constant(v, dtype=DTYPE) for k, v in param.items()}

settings = sanitise_settings(config['settings'])

data = load_data(config['data'], settings, DTYPE)
# NOTE(review): assumes data['pop'] is a pandas object; the `level` argument
# of `sum` is deprecated in recent pandas releases -- confirm the pinned
# pandas version before upgrading.
data['pop'] = data['pop'].sum(level=0)

# The model is built on the leading 10x10 corner of C with a fixed
# population of 1000 per unit, rather than on the loaded population.
model = CovidUKStochastic(C=data['C'][:10, :10],
                          N=[1000] * 10,  # data['pop']['n'].to_numpy(),
                          W=data['W'],
                          date_range=settings['inference_period'],
                          holidays=settings['holiday'],
                          lockdown=settings['lockdown'],
                          time_step=1.)

# Load data
# NOTE(review): unpickling can execute arbitrary code; only load a
# simulation file that comes from a trusted source.
with open('stochastic_sim_medium.pkl', 'rb') as f:
    example_sim = pkl.load(f)

event_tensor = example_sim['events']  # shape [T, M, S, S]
num_times = event_tensor.shape[0]
num_meta = event_tensor.shape[1]
state_init = example_sim['state_init']
# Per-transition event counts, taken from the [from, to] entries of the
# last two axes.
se_events = event_tensor[:, :, 0, 1]
ei_events = event_tensor[:, :, 1, 2]
ir_events = event_tensor[:, :, 2, 3]
Chris Jewell's avatar
Chris Jewell committed
58

Chris Jewell's avatar
Chris Jewell committed
59

60
def logp(par, events):
    """Return the log target density for parameters `par` and `events`.

    The value is the sum of a Gamma(1, 1) log prior density evaluated at
    beta1, a Gamma(100, 400) log prior density evaluated at gamma, and the
    model log-probability of the events summed over all of its elements.

    Args:
        par: indexable whose element 0 is used as `beta1` and element 1 as
            `gamma`.
        events: event array handed to `make_transition_matrix` together
            with the transition index pairs [0, 1], [1, 2], [2, 3] and the
            target shape [num_times, num_meta, 4].
            NOTE(review): presumably shaped [num_times, num_meta, 3], as
            stacked in `__main__` -- confirm.

    Returns:
        The sum of the two prior terms and the reduced model log-probability.
    """
    # Work on a shallow copy: assigning into `param` itself would overwrite
    # the module-level dict on every call, including calls traced inside a
    # tf.function.
    p = dict(param)
    p['beta1'] = tf.convert_to_tensor(par[0], dtype=DTYPE)
    # beta2 and beta3 are not sampled at present; they keep their configured
    # values.
    p['gamma'] = tf.convert_to_tensor(par[1], dtype=DTYPE)

    beta1_prior = tfd.Gamma(concentration=tf.constant(1., dtype=DTYPE),
                            rate=tf.constant(1., dtype=DTYPE))
    beta1_logp = beta1_prior.log_prob(p['beta1'])

    gamma_prior = tfd.Gamma(concentration=tf.constant(100., dtype=DTYPE),
                            rate=tf.constant(400., dtype=DTYPE))
    gamma_logp = gamma_prior.log_prob(p['gamma'])

    with tf.name_scope('main_log_p'):
        transition_events = make_transition_matrix(
            events,
            [[0, 1], [1, 2], [2, 3]],
            [num_times, num_meta, 4])
        y_logp = tf.reduce_sum(
            model.log_prob(transition_events, p, state_init))

    return beta1_logp + gamma_logp + y_logp


# Pavel's suggestion for a Gibbs kernel requires kernel factory functions:
# each factory returns a function that builds a kernel for a given target
# log-probability function.
def make_parameter_kernel(scale, bounded_convergence):
    """Build a factory for the parameter-update kernel.

    Args:
        scale: passed as the first argument of `random_walk_mvnorm_fn`.
        bounded_convergence: passed to `random_walk_mvnorm_fn` as `p_u`.

    Returns:
        A function taking a target log-probability function and returning a
        `tfp.mcmc.MetropolisHastings` kernel named 'parameter_update' that
        wraps an `UncalibratedLogRandomWalk` inner kernel.
    """

    def kernel_func(logp):
        proposal_fn = random_walk_mvnorm_fn(scale, p_u=bounded_convergence)
        random_walk = UncalibratedLogRandomWalk(target_log_prob_fn=logp,
                                                new_state_fn=proposal_fn)
        return tfp.mcmc.MetropolisHastings(inner_kernel=random_walk,
                                           name='parameter_update')

    return kernel_func


99
def make_events_step(target_event_id, prev_event_id=None, next_event_id=None,
                     *, dmax=2, mmax=2, nmax=10):
    """Build a factory for an `EventTimesUpdate` kernel.

    Args:
        target_event_id: forwarded to `EventTimesUpdate` as
            `target_event_id`.
        prev_event_id: forwarded as `prev_event_id` (default None).
        next_event_id: forwarded as `next_event_id` (default None).
        dmax: forwarded as `dmax`; defaults to the previously hard-coded 2.
        mmax: forwarded as `mmax`; defaults to the previously hard-coded 2.
        nmax: forwarded as `nmax`; defaults to the previously hard-coded 10.
            NOTE(review): the width of the traced results (13 columns in
            `__main__`) presumably depends on these settings -- confirm
            before changing them.

    Returns:
        A function taking a target log-probability function and returning
        an `EventTimesUpdate` kernel initialised with the module-level
        `state_init`.
    """

    def kernel_func(logp):
        return EventTimesUpdate(target_log_prob_fn=logp,
                                target_event_id=target_event_id,
                                prev_event_id=prev_event_id,
                                next_event_id=next_event_id,
                                dmax=dmax,
                                mmax=mmax,
                                nmax=nmax,
                                initial_state=state_init)

    return kernel_func


def is_accepted(result):
    """Return the acceptance flag of a kernel result, cast to DTYPE.

    Follows the chain of `inner_results` attributes until an object that
    has an `is_accepted` attribute is reached.
    """
    current = result
    while not hasattr(current, 'is_accepted'):
        current = current.inner_results
    return tf.cast(current.is_accepted, DTYPE)


120
121
122
def trace_results_fn(results):
    """Pack selected fields of a kernel result into one 1-D tensor.

    The layout is: proposed target log-probability, acceptance flag, log
    acceptance correction, followed by the entries of
    `results.proposed_results.extra`.
    """
    proposal = results.proposed_results
    leading = [
        [proposal.target_log_prob],
        [is_accepted(results)],
        [proposal.log_acceptance_correction],
    ]
    return tf.concat(leading + [proposal.extra], axis=0)
126
127


Chris Jewell's avatar
Chris Jewell committed
128
@tf.function(autograph=False, experimental_compile=True)
def sample(n_samples, init_state, par_scale):
    """Run `n_samples` sweeps of parameter, S->E and E->I updates.

    Each sweep applies, in order, the parameter kernel, the event kernel
    built with ids (0, None, 1) and the event kernel built with ids
    (1, 0, 2), passing each step's kernel results on to the next step.

    Args:
        n_samples: number of sweeps; also the size of every TensorArray.
        init_state: two-element list [parameters, events]; a copy is made
            so the caller's list is not rebound.
        par_scale: scale handed to `make_parameter_kernel`.

    Returns:
        A pair `(samples, results)`: `samples` holds the stacked parameter
        and event states, `results` holds the stacked `trace_results_fn`
        output for the parameter, S->E and E->I steps.
    """
    init_state = init_state.copy()
    par_func = make_parameter_kernel(par_scale, 0.95)
    se_func = make_events_step(0, None, 1)
    ei_func = make_events_step(1, 0, 2)

    # Based on Gibbs idea posted by Pavel Sountsov
    # https://github.com/tensorflow/probability/issues/495
    def bootstrap_logp(event_state):
        return logp(init_state[0], event_state)

    # The loop is seeded with results bootstrapped from the E->I kernel.
    initial_results = ei_func(bootstrap_logp).bootstrap_results(
        init_state[1])

    samples_arr = [tf.TensorArray(part.dtype, size=n_samples)
                   for part in init_state]
    results_arr = [tf.TensorArray(DTYPE, size=n_samples) for _ in range(3)]

    def body(step, chain_state, last_results, sample_store, trace_store):
        # Parameters
        def par_logp(par_state):
            chain_state[0] = par_state  # close over state from outer scope
            return logp(*chain_state)

        par_kernel = par_func(par_logp)
        chain_state[0], par_results = par_kernel.one_step(chain_state[0],
                                                          last_results)

        # States
        def state_logp(event_state):
            chain_state[1] = event_state
            return logp(*chain_state)

        se_kernel = se_func(state_logp)
        chain_state[1], se_results = se_kernel.one_step(chain_state[1],
                                                        par_results)
        ei_kernel = ei_func(state_logp)
        chain_state[1], ei_results = ei_kernel.one_step(chain_state[1],
                                                        se_results)

        sample_store = [store.write(step, value)
                        for store, value in zip(sample_store, chain_state)]
        step_results = (par_results, se_results, ei_results)
        trace_store = [store.write(step, trace_results_fn(kernel_results))
                       for store, kernel_results in zip(trace_store,
                                                        step_results)]
        return step + 1, chain_state, ei_results, sample_store, trace_store

    def cond(step, *_):
        return step < n_samples

    loop_vars = [0, init_state, initial_results, samples_arr, results_arr]
    _, _, _, samples, results = tf.while_loop(cond=cond, body=body,
                                              loop_vars=loop_vars)

    stacked_samples = [store.stack() for store in samples]
    stacked_results = [store.stack() for store in results]
    return stacked_samples, stacked_results


Chris Jewell's avatar
Chris Jewell committed
179
if __name__ == '__main__':
    # Script entry point: run the sampler in chunks and stream every chunk
    # into the HDF5 file named by config['output']['posterior'].

    if tf.test.gpu_device_name():
        print('Using GPU')
    else:
        print("Using CPU")

    num_loop_iterations = 1000
    num_loop_samples = 100
    total_samples = num_loop_iterations * num_loop_samples

    current_state = [np.array([0.15, 0.25], dtype=DTYPE),
                     tf.stack([se_events, ei_events, ir_events], axis=-1)]

    # The context manager closes the output file even if sampling raises.
    posterior_path = os.path.expandvars(config['output']['posterior'])
    with h5py.File(posterior_path, 'w') as posterior:
        event_size = [total_samples] + list(current_state[1].shape)
        par_samples = posterior.create_dataset(
            'samples/parameter',
            [total_samples, current_state[0].shape[0]],
            dtype=np.float64)
        se_samples = posterior.create_dataset('samples/events', event_size,
                                              dtype=DTYPE)

        # 13 columns per traced result row.
        # NOTE(review): presumably 3 scalars plus 10 `extra` entries from
        # trace_results_fn -- confirm against EventTimesUpdate.
        trace_shape = (total_samples, 13)
        par_results = posterior.create_dataset('acceptance/parameter',
                                               trace_shape, dtype=DTYPE)
        se_results = posterior.create_dataset('acceptance/S->E',
                                              trace_shape, dtype=DTYPE)
        ei_results = posterior.create_dataset('acceptance/E->I',
                                              trace_shape, dtype=DTYPE)

        print("Initial logpi:", logp(*current_state))
        par_scale = tf.linalg.diag(
            tf.ones(current_state[0].shape,
                    dtype=current_state[0].dtype) * 0.1)

        # We loop over successive calls to sample because results have to be
        # dumped to disk, or else we run out of memory (even on a 32GB
        # system).
        # with tf.profiler.experimental.Profile('/tmp/tf_logdir'):
        for i in tqdm.tqdm(range(num_loop_iterations),
                           unit_scale=num_loop_samples):
            samples, results = sample(num_loop_samples,
                                      init_state=current_state,
                                      par_scale=par_scale)
            current_state = [chain[-1] for chain in samples]
            rows = slice(i * num_loop_samples, (i + 1) * num_loop_samples)
            par_samples[rows, ...] = samples[0].numpy()

            # Covariance of the log-parameters over everything sampled so
            # far; used to adapt the proposal scale when it is finite.
            cov = np.cov(np.log(par_samples[:rows.stop, ...]), rowvar=False)
            print(current_state[0].numpy())
            print(cov)
            if np.all(np.isfinite(cov)):
                par_scale = 2.38 ** 2 * cov / 2.

            se_samples[rows, ...] = samples[1].numpy()
            par_results[rows, ...] = results[0].numpy()
            se_results[rows, ...] = results[1].numpy()
            ei_results[rows, ...] = results[2].numpy()

            # Column 1 of each traced row is the acceptance flag.
            print("Acceptance0:",
                  tf.reduce_mean(tf.cast(results[0][:, 1], tf.float32)))
            print("Acceptance1:",
                  tf.reduce_mean(tf.cast(results[1][:, 1], tf.float32)))
            print("Acceptance2:",
                  tf.reduce_mean(tf.cast(results[2][:, 1], tf.float32)))

        print(f'Acceptance param: {par_results[:, 1].mean()}')
        print(f'Acceptance S->E: {se_results[:, 1].mean()}')
        print(f'Acceptance E->I: {ei_results[:, 1].mean()}')