mcmc.py 13.8 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
"""MCMC Test Rig for COVID-19 UK model"""
2
import optparse
Chris Jewell's avatar
Chris Jewell committed
3
import os
4
import pickle as pkl
5
from collections import OrderedDict
6
from time import perf_counter
Chris Jewell's avatar
Chris Jewell committed
7

Chris Jewell's avatar
Chris Jewell committed
8
import h5py
9
import numpy as np
10
11
import tensorflow as tf
import tensorflow_probability as tfp
Chris Jewell's avatar
Chris Jewell committed
12
13
14
import tqdm
import yaml

15
16
from covid import config
from covid.model import load_data, CovidUKStochastic
17
18
from covid.pydata import phe_case_data
from covid.util import sanitise_parameter, sanitise_settings, impute_previous_cases
19
from covid.impl.mcmc import UncalibratedLogRandomWalk, random_walk_mvnorm_fn
20
from covid.impl.event_time_mh import UncalibratedEventTimesUpdate
21
from covid.impl.occult_events_mh import UncalibratedOccultUpdate
22

Chris Jewell's avatar
Chris Jewell committed
23
24
25
###########
# TF Bits #
###########
26

Chris Jewell's avatar
Chris Jewell committed
27
28
29
# Short aliases for the TFP sub-modules used throughout the script.
tfd, tfb = tfp.distributions, tfp.bijectors

# Floating-point dtype for all model tensors, taken from covid.config.
DTYPE = config.floatX

# Debugging toggles (uncomment to force CPU / dump XLA HLO):
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# os.environ["XLA_FLAGS"] = '--xla_dump_to=xla_dump --xla_dump_hlo_pass_re=".*"'

# Report which device TensorFlow will use.
print("Using GPU" if tf.test.gpu_device_name() else "Using CPU")

40
# Read in settings.
# NOTE(review): the config path is hard-coded; an optparse `--config`/`-c`
# option (default "ode_config.yaml") used to select it and could be restored.
# Rebinding `config` here also shadows the `covid.config` module imported at
# the top of the file (DTYPE has already been read from it by this point).
with open("ode_config.yaml", "r") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

settings = sanitise_settings(config["settings"])

# Model parameters from the config, converted to DTYPE constants.
param = {
    name: tf.constant(value, dtype=DTYPE)
    for name, value in sanitise_parameter(config["parameter"]).items()
}

covar_data = load_data(config["data"], settings, DTYPE)

# Reported cases over the inference period, then imputed earlier E->I and
# S->E events.  impute_previous_cases returns the events plus a second value
# (a lag) that is used below as a left-padding width.
cases = phe_case_data(config["data"]["reported_cases"], settings["inference_period"])
ei_events, lag_ei = impute_previous_cases(cases, 0.25)
se_events, lag_se = impute_previous_cases(ei_events, 0.25)

# Left-pad along the time axis (presumably so that the three event series
# line up for the later tf.stack).
ir_events = np.pad(cases, pad_width=((0, 0), (lag_ei + lag_se - 2, 0)))
ei_events = np.pad(ei_events, pad_width=((0, 0), (lag_se - 1, 0)))

model = CovidUKStochastic(
    C=covar_data["C"],
    N=covar_data["pop"],
    W=covar_data["W"],
    date_range=settings["inference_period"],
    holidays=settings["holiday"],
    lockdown=settings["lockdown"],
    time_step=1.0,
)
79

Chris Jewell's avatar
Chris Jewell committed
80
81
82
##########################
# Log p and MCMC kernels #
##########################
83
84


85
def logp(par, events, occult_events):
    """Joint log-posterior of the parameters and the (observed + occult) events.

    Args:
      par: parameter vector; ``par[0]`` is ``beta1`` and ``par[1]`` is ``gamma``.
      events: event tensor; summed elementwise with ``occult_events`` before
        being scored by ``model.log_prob``.
      occult_events: occult-event tensor (presumably the same shape as
        ``events``, since the two are added).

    Returns:
      Gamma(concentration=1, rate=1) log-prior of ``beta1`` plus
      Gamma(concentration=100, rate=400) log-prior of ``gamma`` plus
      ``model.log_prob`` of the combined events, given the parameters and the
      module-level ``state_init``.
    """
    # Work on a copy of the module-level `param` dict.  The previous code
    # aliased it (`p = param`), so every call mutated shared global state.
    # When traced inside `sample`'s tf.function, that also left graph tensors
    # stored in the global.
    p = dict(param)
    p["beta1"] = tf.convert_to_tensor(par[0], dtype=DTYPE)
    p["gamma"] = tf.convert_to_tensor(par[1], dtype=DTYPE)

    beta1_logp = tfd.Gamma(
        concentration=tf.constant(1.0, dtype=DTYPE), rate=tf.constant(1.0, dtype=DTYPE)
    ).log_prob(p["beta1"])
    gamma_logp = tfd.Gamma(
        concentration=tf.constant(100.0, dtype=DTYPE),
        rate=tf.constant(400.0, dtype=DTYPE),
    ).log_prob(p["gamma"])

    with tf.name_scope("epidemic_log_posterior"):
        y_logp = model.log_prob(events + occult_events, p, state_init)

    # Named `total` so the local no longer shadows this function's own name.
    total = beta1_logp + gamma_logp + y_logp
    return total


# Pavel's suggestion for a Gibbs kernel requires kernel factory functions:
# each factory returns a callable that builds a fresh kernel around whatever
# target log-probability closure the Gibbs sweep supplies.
def make_parameter_kernel(scale, bounded_convergence):
    """Return a factory for the parameter Metropolis-Hastings kernel.

    Args:
      scale: proposal scale, forwarded to ``random_walk_mvnorm_fn``.
      bounded_convergence: forwarded to ``random_walk_mvnorm_fn`` as ``p_u``.

    Returns:
      A function that takes a target log-probability function and returns a
      ``tfp.mcmc.MetropolisHastings`` kernel named ``"parameter_update"``
      wrapping an ``UncalibratedLogRandomWalk``.
    """

    def kernel_func(target_log_prob_fn):
        walk = UncalibratedLogRandomWalk(
            target_log_prob_fn=target_log_prob_fn,
            new_state_fn=random_walk_mvnorm_fn(scale, p_u=bounded_convergence),
        )
        return tfp.mcmc.MetropolisHastings(inner_kernel=walk, name="parameter_update")

    return kernel_func


123
def make_events_step(target_event_id, prev_event_id=None, next_event_id=None):
    """Return a factory for a Metropolis-Hastings event-time move kernel.

    Args:
      target_event_id: index of the event type whose times are moved.
      prev_event_id: index of the preceding event type, or None.
      next_event_id: index of the following event type, or None.

    Returns:
      A function that takes a target log-probability function and returns a
      ``tfp.mcmc.MetropolisHastings`` kernel named ``"event_update"`` wrapping
      ``UncalibratedEventTimesUpdate``.  The proposal limits are read from the
      ``"mcmc"`` section of the module-level ``config`` (``dmax``, ``m``,
      ``nmax``).  The module-level ``state_init`` is passed as the initial
      state.
    """

    def kernel_func(target_log_prob_fn):
        mcmc_config = config["mcmc"]
        event_move = UncalibratedEventTimesUpdate(
            target_log_prob_fn=target_log_prob_fn,
            target_event_id=target_event_id,
            prev_event_id=prev_event_id,
            next_event_id=next_event_id,
            initial_state=state_init,
            dmax=mcmc_config["dmax"],
            mmax=mcmc_config["m"],
            nmax=mcmc_config["nmax"],
        )
        return tfp.mcmc.MetropolisHastings(inner_kernel=event_move, name="event_update")

    return kernel_func


142
143
144
145
146
147
148
def make_occults_step(target_event_id, t_window=21):
    """Return a factory for a Metropolis-Hastings occult-event update kernel.

    Args:
      target_event_id: index of the event type whose occult (unobserved)
        events are updated (this script uses 0 for S->E and 1 for E->I).
      t_window: width of the trailing time window passed to
        ``UncalibratedOccultUpdate`` as ``t_range``.  The default of 21 is the
        value that was previously hard-coded.

    Returns:
      A function that takes a target log-probability function and returns a
      ``tfp.mcmc.MetropolisHastings`` kernel named ``"occult_update"``.
      ``nmax`` is read from ``config["mcmc"]["occult_nmax"]``.
    """

    def kernel_func(target_log_prob_fn):
        # NOTE(review): `se_events` has one more time point than the event
        # tensor carried by the chain, because the initial-state construction
        # drops the first time slice (`events[:, 1:, :]`).  Confirm whether the
        # upper bound of `t_range` is meant to be exclusive.
        num_times = se_events.shape[1]
        return tfp.mcmc.MetropolisHastings(
            inner_kernel=UncalibratedOccultUpdate(
                target_log_prob_fn=target_log_prob_fn,
                target_event_id=target_event_id,
                nmax=config["mcmc"]["occult_nmax"],
                t_range=[num_times - t_window, num_times],
            ),
            name="occult_update",
        )

    return kernel_func
155
156


157
def is_accepted(result):
    """Return the ``is_accepted`` flag of a (possibly nested) kernel result, cast to DTYPE.

    Walks down ``inner_results`` until a result object exposing
    ``is_accepted`` is found.  An ``AttributeError`` is raised if the chain
    ends without one.
    """
    node = result
    while not hasattr(node, "is_accepted"):
        node = node.inner_results
    return tf.cast(node.is_accepted, DTYPE)
161
162


163
164
165
def trace_results_fn(results):
    """Flatten one kernel step's diagnostics into a 1-D tensor.

    The layout is ``[target_log_prob, accepted, log_acceptance_correction]``,
    taken from the proposed results.  When the proposed results carry an
    ``extra`` field, its values are appended after being cast to the
    log-prob dtype.
    """
    proposed = results.proposed_results
    log_prob = proposed.target_log_prob
    columns = [
        [log_prob],
        [is_accepted(results)],
        [proposed.log_acceptance_correction],
    ]
    if hasattr(proposed, "extra"):
        columns.append(tf.cast(proposed.extra, log_prob.dtype))
    return tf.concat(columns, axis=0)
171

172

173
174
175
176
177
def forward_results(prev_results, next_results):
    """Carry the accepted target log-prob from one kernel's results into another's.

    Returns a copy of ``next_results`` whose ``accepted_results.target_log_prob``
    is replaced by ``prev_results.accepted_results.target_log_prob``.  Neither
    input is modified; namedtuple ``_replace`` returns new objects.
    """
    carried_log_prob = prev_results.accepted_results.target_log_prob
    refreshed = next_results.accepted_results._replace(target_log_prob=carried_log_prob)
    return next_results._replace(accepted_results=refreshed)
178

179

180
@tf.function(autograph=False, experimental_compile=True)
def sample(n_samples, init_state, par_scale, num_event_updates):
    """Run ``n_samples`` Metropolis-within-Gibbs sweeps.

    Each sweep does one parameter update, followed by ``num_event_updates``
    rounds of: S->E event-time move, E->I event-time move, S->E occult update
    and E->I occult update.

    Args:
      n_samples: number of sweeps to run and record.
      init_state: list ``[parameters, events, occult_events]`` to start from.
      par_scale: proposal scale for the parameter random walk (passed to
        ``make_parameter_kernel``).
      num_event_updates: int32 tensor; event/occult rounds per sweep.

    Returns:
      ``(samples, results)``:
        - ``samples`` is a list of stacked tensors, one per element of
          ``init_state``, each with leading dimension ``n_samples``.
        - ``results`` is a list of stacked ``trace_results_fn`` outputs, one
          per kernel, in the order parameter, S->E move, E->I move, S->E
          occult, E->I occult.
    """
    with tf.name_scope("main_mcmc_sample_loop"):
        # Shallow copy of the caller's state list.
        init_state = init_state.copy()
        par_func = make_parameter_kernel(par_scale, 0.0)

        # Event-time moves take (target, prev, next) event ids.  Event id 0 is
        # S->E and 1 is E->I, i.e. the order in which the events are stacked.
        se_func = make_events_step(0, None, 1)
        ei_func = make_events_step(1, 0, 2)

        # Occult-event updates for the same two event types.
        se_occult = make_occults_step(0)
        ei_occult = make_occults_step(1)

        # Based on Gibbs idea posted by Pavel Sountsov
        # https://github.com/tensorflow/probability/issues/495
        # Initial kernel results, one per Gibbs component.  Each is
        # bootstrapped with the other components held at their initial values.
        results = [
            par_func(lambda p: logp(p, init_state[1], init_state[2])).bootstrap_results(
                init_state[0]
            ),
            se_func(lambda s: logp(init_state[0], s, init_state[2])).bootstrap_results(
                init_state[1]
            ),
            ei_func(lambda s: logp(init_state[0], s, init_state[2])).bootstrap_results(
                init_state[1]
            ),
            se_occult(
                lambda s: logp(init_state[0], init_state[1], s)
            ).bootstrap_results(init_state[2]),
            ei_occult(
                lambda s: logp(init_state[0], init_state[1], s)
            ).bootstrap_results(init_state[2]),
        ]

        # One TensorArray per state component, and one per kernel trace.  The
        # trace arrays are sized from `results` rather than a hard-coded 5, so
        # they always match the number of kernels.
        samples_arr = [tf.TensorArray(s.dtype, size=n_samples) for s in init_state]
        results_arr = [tf.TensorArray(DTYPE, size=n_samples) for _ in results]

        def body(i, state, results, sample_accum, results_accum):
            # Parameters
            def par_logp(par_state):
                state[0] = par_state  # close over state from outer scope
                return logp(*state)

            # The parameter kernel is re-bootstrapped at the current state on
            # every sweep (presumably because events/occults have changed
            # since its last step).
            par_kernel = par_func(par_logp)
            state[0], results[0] = par_kernel.one_step(
                state[0], par_kernel.bootstrap_results(state[0])
            )

            # States
            # forward_results copies the accepted target_log_prob of the
            # kernel that ran last into the next kernel's results.  That way
            # each kernel starts from the log-prob of the current joint state.
            results[4] = forward_results(results[0], results[4])

            def infec_body(j, state, results):
                def state_logp(event_state):
                    state[1] = event_state
                    return logp(*state)

                def occult_logp(occult_state):
                    state[2] = occult_state
                    return logp(*state)

                state[1], results[1] = se_func(state_logp).one_step(
                    state[1], forward_results(results[4], results[1])
                )
                state[1], results[2] = ei_func(state_logp).one_step(
                    state[1], forward_results(results[1], results[2])
                )
                state[2], results[3] = se_occult(occult_logp).one_step(
                    state[2], forward_results(results[2], results[3])
                )
                state[2], results[4] = ei_occult(occult_logp).one_step(
                    state[2], forward_results(results[3], results[4])
                )
                j += 1
                return j, state, results

            def infec_cond(j, state, results):
                return j < num_event_updates

            _, state, results = tf.while_loop(
                infec_cond,
                infec_body,
                loop_vars=[tf.constant(0, tf.int32), state, results],
            )

            # Record this sweep's state and per-kernel diagnostics at index i.
            sample_accum = [sample_accum[k].write(i, s) for k, s in enumerate(state)]
            results_accum = [
                results_accum[k].write(i, trace_results_fn(r))
                for k, r in enumerate(results)
            ]
            return i + 1, state, results, sample_accum, results_accum

        def cond(i, _1, _2, _3, _4):
            return i < n_samples

        _1, _2, _3, samples, results = tf.while_loop(
            cond=cond,
            body=body,
            loop_vars=[0, init_state, results, samples_arr, results_arr],
        )

        return [s.stack() for s in samples], [r.stack() for r in results]
279
280


Chris Jewell's avatar
Chris Jewell committed
281
282
283
284
##################
# MCMC loop here #
##################

285
# MCMC Control: burst structure and event-time updates per sweep, all read
# from the "mcmc" section of the config.
NUM_BURSTS, NUM_BURST_SAMPLES, NUM_EVENT_TIME_UPDATES = (
    config["mcmc"][key]
    for key in ("num_bursts", "num_burst_samples", "num_event_time_updates")
)

# RNG stuff: fixed global TensorFlow seed.
tf.random.set_seed(2)

# Initial state.  NB [M, T, X] layout for events, with X ordered
# S->E, E->I, I->R (the stacking order below).
events = tf.stack([se_events, ei_events, ir_events], axis=-1)
# The first time slice, concatenated after model.N, becomes `state_init`.
# The chain then carries only the remaining time points.
state_init = tf.concat([model.N[:, tf.newaxis], events[:, 0, :]], axis=-1)
events = events[:, 1:, :]
current_state = [
    np.array([0.85, 0.25], dtype=DTYPE),  # initial [beta1, gamma]
    events,
    tf.zeros_like(events),  # occult events start at zero
]
302
303

# Output Files
posterior = h5py.File(
    os.path.expandvars(config["output"]["posterior"]),
    "w",
    rdcc_nbytes=1024 ** 2 * 400,  # 400 MiB raw-data chunk cache
    rdcc_nslots=100000,
)
event_size = [NUM_BURSTS * NUM_BURST_SAMPLES] + list(current_state[1].shape)

# The preferred event chunk is (1024, 64, 64, n_event_types).  h5py rejects a
# chunk that is larger than a fixed-size dataset in any dimension, so each
# chunk dimension is clamped to the dataset size.  Without the clamp, runs
# with fewer than 1024 total draws, or fewer than 64 along the M or T axes,
# would fail here.
event_chunks = tuple(
    min(int(chunk), int(dim))
    for chunk, dim in zip((1024, 64, 64, current_state[1].shape[-1]), event_size)
)

par_samples = posterior.create_dataset(
    "samples/parameter",
    [NUM_BURSTS * NUM_BURST_SAMPLES, current_state[0].shape[0]],
    dtype=np.float64,
)
event_samples = posterior.create_dataset(
    "samples/events",
    event_size,
    dtype=DTYPE,
    chunks=event_chunks,
    compression="lzf",
)
occult_samples = posterior.create_dataset(
    "samples/occults",
    event_size,
    dtype=DTYPE,
    chunks=event_chunks,
    compression="lzf",
)

# One trace dataset per kernel, in the order `sample` returns its results.
# Each row holds 3 columns (log-prob, accepted flag, log acceptance
# correction), followed by any `extra` columns the kernel reports.
output_results = [
    posterior.create_dataset(
        name, (NUM_BURSTS * NUM_BURST_SAMPLES, width), dtype=DTYPE
    )
    for name, width in (
        ("results/parameter", 3),
        ("results/move/S->E", 3 + model.N.shape[0]),
        ("results/move/E->I", 3 + model.N.shape[0]),
        ("results/occult/S->E", 6),
        ("results/occult/E->I", 6),
    )
]
354

Chris Jewell's avatar
Chris Jewell committed
355
356
print("Initial logpi:", logp(*current_state))
par_scale = tf.linalg.diag(
    tf.ones(current_state[0].shape, dtype=current_state[0].dtype) * 0.1
)

# We loop over successive calls to `sample` because results have to be dumped
# to disk between bursts; otherwise we run out of memory (even on a 32GB
# system).  The try/finally ensures the HDF5 file is closed even if the run
# is interrupted.
# with tf.profiler.experimental.Profile("/tmp/tf_logdir"):
try:
    for i in tqdm.tqdm(range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES):
        samples, results = sample(
            NUM_BURST_SAMPLES,
            init_state=current_state,
            par_scale=par_scale,
            num_event_updates=tf.constant(NUM_EVENT_TIME_UPDATES, tf.int32),
        )
        # The next burst starts from the final draw of this one.
        current_state = [draws[-1] for draws in samples]
        burst = slice(i * NUM_BURST_SAMPLES, (i + 1) * NUM_BURST_SAMPLES)

        # Write this burst to disk.  The results loop uses its own names, so
        # it no longer rebinds the burst counter `i`.  The timer now covers
        # every write, not just the event/occult samples.
        start = perf_counter()
        par_samples[burst, ...] = samples[0].numpy()
        event_samples[burst, ...] = samples[1].numpy()
        occult_samples[burst, ...] = samples[2].numpy()
        for dataset, result in zip(output_results, results):
            dataset[burst, ...] = result.numpy()
        end = perf_counter()

        # Adapt the parameter proposal from the covariance of all
        # log-parameter draws so far.  The scaling is 2.38^2 / d, with d the
        # number of parameters; the previous hard-coded 2.0 equalled d here.
        cov = np.cov(np.log(par_samples[: burst.stop, ...]), rowvar=False)
        print(current_state[0].numpy())
        print(cov)
        if (i * NUM_BURST_SAMPLES) > 1000 and np.all(np.isfinite(cov)):
            par_scale = 2.38 ** 2 * cov / par_samples.shape[1]

        print("Storage time:", end - start, "seconds")

        # Column 1 of every trace is the accepted flag (see trace_results_fn).
        for label, result in zip(
            ("par", "move S->E", "move E->I", "occult S->E", "occult E->I"),
            results,
        ):
            print(
                f"Acceptance {label}:",
                tf.reduce_mean(tf.cast(result[:, 1], tf.float32)),
            )

    print(f"Acceptance param: {output_results[0][:, 1].mean()}")
    print(f"Acceptance move S->E: {output_results[1][:, 1].mean()}")
    print(f"Acceptance move E->I: {output_results[2][:, 1].mean()}")
    print(f"Acceptance occult S->E: {output_results[3][:, 1].mean()}")
    print(f"Acceptance occult E->I: {output_results[4][:, 1].mean()}")
finally:
    posterior.close()