"""MCMC Test Rig for COVID-19 UK model"""

# Standard library
import optparse
import os
import pickle as pkl
from collections import OrderedDict
from time import perf_counter

# Third-party
import h5py
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import tqdm
import yaml

# Project-local
from covid import config
from covid.model import load_data, CovidUKStochastic
from covid.pydata import phe_case_data
from covid.util import sanitise_parameter, sanitise_settings, impute_previous_cases
from covid.impl.mcmc import UncalibratedLogRandomWalk, random_walk_mvnorm_fn
from covid.impl.event_time_mh import UncalibratedEventTimesUpdate
from covid.impl.occult_events_mh import UncalibratedOccultUpdate


###########
# TF Bits #
###########

# Short aliases for the TFP submodules used throughout the script.
tfd = tfp.distributions
tfb = tfp.bijectors

# Working floating-point dtype, taken from the project-wide ``covid.config``
# module. NB: the name ``config`` is rebound to the parsed YAML dict further
# down, so ``floatX`` must be read here, before that happens.
DTYPE = config.floatX

# Developer toggles (disabled):
#   force CPU execution
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
#   dump XLA HLO for every compiler pass
# os.environ["XLA_FLAGS"] = '--xla_dump_to=xla_dump --xla_dump_hlo_pass_re=".*"'

# Report which kind of device TensorFlow found.
print("Using GPU" if tf.test.gpu_device_name() else "Using CPU")

# Read in settings: the YAML configuration file is named on the command line
# via --config / -c.
parser = optparse.OptionParser()
parser.add_option(
    "--config",
    "-c",
    default="example_config.yaml",
    dest="config",
    help="configuration file",
)
options, cmd_args = parser.parse_args()
print("Loading config file:", options.config)

# NOTE(review): FullLoader can construct some tagged Python objects; if the
# config only ever holds plain data, ``yaml.safe_load`` would be safer.
# NB: this rebinds ``config`` (previously the ``covid.config`` module) to the
# parsed dict; everything below indexes it as a dict.
with open(options.config) as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Model settings, plus initial parameter values held as TF constants of the
# working dtype.
settings = sanitise_settings(config["settings"])
param = {
    name: tf.constant(value, dtype=DTYPE)
    for name, value in sanitise_parameter(config["parameter"]).items()
}

# Covariate data and reported cases over the inference period.
covar_data = load_data(config["data"], settings, DTYPE)
cases = phe_case_data(config["data"]["reported_cases"], settings["inference_period"])

# Impute earlier event series by stepping back from the reported cases.
# NOTE(review): the meaning of 0.44 and 2.0 is defined by
# impute_previous_cases (presumably rates); confirm there before changing.
ei_events, lag_ei = impute_previous_cases(cases, 0.44)
se_events, lag_se = impute_previous_cases(ei_events, 2.0)

# Prepend zero columns along axis 1 (time), presumably so that the three
# series line up; the pad widths come from the imputation lags.
ir_pad = lag_ei + lag_se - 2
ei_pad = lag_se - 1
ir_events = np.pad(cases, ((0, 0), (ir_pad, 0)))
ei_events = np.pad(ei_events, ((0, 0), (ei_pad, 0)))

model = CovidUKStochastic(
    C=covar_data["C"],
    N=covar_data["pop"],
    W=covar_data["W"],
    date_range=settings["inference_period"],
    holidays=settings["holiday"],
    xi_freq=14,
    time_step=1.0,
)


##########################
# Log p and MCMC kernels #
##########################
def logp(theta, xi, events, occult_events):
    """Unnormalised log posterior of the model.

    Args:
        theta: indexable of three values ``(beta1, beta2, gamma)``.
        xi: values of the time-varying xi term, one per entry of
            ``model.xi_times``.
        events: event tensor; it is added to ``occult_events`` before being
            passed to ``model.log_prob``.
        occult_events: occult events, same shape as ``events``.

    Returns:
        Sum of the Gamma prior log-densities for beta1, beta2 and gamma, the
        Gaussian-process prior log-density for xi, and the epidemic
        log-likelihood returned by ``model.log_prob``.
    """
    # Shallow-copy the module-level parameter dict so repeated calls (and
    # tracing inside tf.function) do not overwrite its entries.
    p = dict(param)
    p["beta1"] = tf.convert_to_tensor(theta[0], dtype=DTYPE)
    p["beta2"] = tf.convert_to_tensor(theta[1], dtype=DTYPE)
    p["gamma"] = tf.convert_to_tensor(theta[2], dtype=DTYPE)
    p["xi"] = tf.convert_to_tensor(xi, dtype=DTYPE)

    # Gamma(1, 1) prior on beta1.
    beta1_logp = tfd.Gamma(
        concentration=tf.constant(1.0, dtype=DTYPE), rate=tf.constant(1.0, dtype=DTYPE)
    ).log_prob(p["beta1"])

    # Matern-3/2 Gaussian-process prior on xi over model.xi_times
    # (amplitude 0.1, length scale 12.0).
    sigma = tf.constant(0.1, dtype=DTYPE)
    phi = tf.constant(12.0, dtype=DTYPE)
    kernel = tfp.math.psd_kernels.MaternThreeHalves(sigma, phi)
    xi_logp = tfd.GaussianProcess(
        kernel, index_points=tf.cast(model.xi_times[:, tf.newaxis], DTYPE)
    ).log_prob(p["xi"])

    # Gamma(3, 10) prior on beta2.
    spatial_beta_logp = tfd.Gamma(
        concentration=tf.constant(3.0, dtype=DTYPE), rate=tf.constant(10.0, dtype=DTYPE)
    ).log_prob(p["beta2"])

    # Gamma(100, 400) prior on gamma (mean 0.25).
    gamma_logp = tfd.Gamma(
        concentration=tf.constant(100.0, dtype=DTYPE),
        rate=tf.constant(400.0, dtype=DTYPE),
    ).log_prob(p["gamma"])

    with tf.name_scope("epidemic_log_posterior"):
        y_logp = model.log_prob(events + occult_events, p, state_init)

    # Named differently from the function itself to avoid shadowing ``logp``.
    total_logp = beta1_logp + spatial_beta_logp + gamma_logp + xi_logp + y_logp
    return total_logp


# Pavel's suggestion for a Gibbs kernel requires
# kernel factory functions.
def make_theta_kernel(scale, bounded_convergence):
    """Return a factory for the Metropolis-Hastings kernel that updates theta.

    ``scale`` and ``bounded_convergence`` are forwarded to
    ``random_walk_mvnorm_fn`` (the latter as ``p_u``). The returned callable
    takes a target log-probability function and wraps an
    ``UncalibratedLogRandomWalk`` proposal in ``tfp.mcmc.MetropolisHastings``.
    """

    def kernel_func(target_fn):
        proposal = UncalibratedLogRandomWalk(
            target_log_prob_fn=target_fn,
            new_state_fn=random_walk_mvnorm_fn(scale, p_u=bounded_convergence),
        )
        return tfp.mcmc.MetropolisHastings(inner_kernel=proposal, name="theta_update")

    return kernel_func


def make_xi_kernel(scale, bounded_convergence):
    """Return a factory for the random-walk Metropolis kernel that updates xi.

    The returned callable takes a target log-probability function and builds a
    ``tfp.mcmc.RandomWalkMetropolis`` whose proposal comes from
    ``random_walk_mvnorm_fn(scale, p_u=bounded_convergence)``.
    """

    def kernel_func(target_fn):
        step_fn = random_walk_mvnorm_fn(scale, p_u=bounded_convergence)
        return tfp.mcmc.RandomWalkMetropolis(
            target_log_prob_fn=target_fn, new_state_fn=step_fn, name="xi_update"
        )

    return kernel_func


def make_events_step(target_event_id, prev_event_id=None, next_event_id=None):
    """Return a factory for the event-time Metropolis-Hastings kernel.

    The returned callable takes a target log-probability function and wraps an
    ``UncalibratedEventTimesUpdate`` for ``target_event_id`` (with its
    neighbouring ``prev_event_id`` / ``next_event_id``) in
    ``tfp.mcmc.MetropolisHastings``. It reads the module-level ``state_init``
    and the ``dmax``, ``m`` and ``nmax`` entries of ``config["mcmc"]`` when
    the kernel is built.
    """

    def kernel_func(target_fn):
        mcmc_cfg = config["mcmc"]
        proposal = UncalibratedEventTimesUpdate(
            target_log_prob_fn=target_fn,
            target_event_id=target_event_id,
            prev_event_id=prev_event_id,
            next_event_id=next_event_id,
            initial_state=state_init,
            dmax=mcmc_cfg["dmax"],
            mmax=mcmc_cfg["m"],
            nmax=mcmc_cfg["nmax"],
        )
        return tfp.mcmc.MetropolisHastings(inner_kernel=proposal, name="event_update")

    return kernel_func


def make_occults_step(target_event_id):
    """Return a factory for the occult-event Metropolis-Hastings kernel.

    The returned callable takes a target log-probability function and wraps an
    ``UncalibratedOccultUpdate`` for ``target_event_id`` in
    ``tfp.mcmc.MetropolisHastings``. ``t_range`` spans the last 21 positions
    of axis 1 of the module-level ``se_events``; ``nmax`` is read from
    ``config["mcmc"]["occult_nmax"]``.
    """

    def kernel_func(target_fn):
        n_times = se_events.shape[1]
        proposal = UncalibratedOccultUpdate(
            target_log_prob_fn=target_fn,
            target_event_id=target_event_id,
            nmax=config["mcmc"]["occult_nmax"],
            t_range=[n_times - 21, n_times],
        )
        return tfp.mcmc.MetropolisHastings(inner_kernel=proposal, name="occult_update")

    return kernel_func


def is_accepted(result):
    """Return the ``is_accepted`` flag of a (possibly nested) kernel result.

    Walks down ``inner_results`` until a result carrying ``is_accepted`` is
    found, and returns that flag cast to ``DTYPE``.
    """
    node = result
    while not hasattr(node, "is_accepted"):
        node = node.inner_results
    return tf.cast(node.is_accepted, DTYPE)


def trace_results_fn(results):
    """Flatten one Metropolis-Hastings step into a 1-D trace vector.

    Layout: ``[target_log_prob, accepted, log_acceptance_correction]``
    followed, when the proposed results carry an ``extra`` field, by that
    field cast to the log-prob dtype.
    """
    proposed = results.proposed_results
    log_prob = proposed.target_log_prob
    columns = [
        [log_prob],
        [is_accepted(results)],
        [proposed.log_acceptance_correction],
    ]
    if hasattr(proposed, "extra"):
        columns.append(tf.cast(proposed.extra, log_prob.dtype))
    return tf.concat(columns, axis=0)


def get_tlp(results):
    """Return the target log-probability stored in ``results.accepted_results``."""
    accepted = results.accepted_results
    return accepted.target_log_prob


def put_tlp(results, target_log_prob):
    """Return a copy of ``results`` with ``accepted_results.target_log_prob`` replaced.

    Both levels are namedtuple-style results, so ``_replace`` builds new
    objects and the input is left untouched.
    """
    refreshed = results.accepted_results._replace(target_log_prob=target_log_prob)
    return results._replace(accepted_results=refreshed)


def invoke_one_step(kernel, state, previous_results, target_log_prob):
    """Advance ``kernel`` by one step, threading the current target log-prob.

    The supplied ``target_log_prob`` is written into the previous results
    first, so each Gibbs sub-kernel starts from the log-prob left by the
    preceding sub-kernel. Returns ``(new_state, new_results, new_tlp)``.
    """
    primed_results = put_tlp(previous_results, target_log_prob)
    next_state, next_results = kernel.one_step(state, primed_results)
    return next_state, next_results, get_tlp(next_results)


@tf.function(autograph=False, experimental_compile=True)
def sample(n_samples, init_state, theta_scale, xi_scale, num_event_updates):
    """Run ``n_samples`` iterations of the Metropolis-within-Gibbs sampler.

    Each iteration updates theta, then xi, then performs
    ``num_event_updates`` rounds of: S->E event-time move, E->I event-time
    move, S->E occult move, E->I occult move.

    Args:
        n_samples: number of iterations to run.
        init_state: list ``[theta, xi, events, occult_events]``; a shallow
            copy is taken on entry.
        theta_scale: proposal scale forwarded to ``make_theta_kernel``.
        xi_scale: proposal scale forwarded to ``make_xi_kernel``.
        num_event_updates: number of event/occult rounds per iteration,
            compared against an int32 loop counter.

    Returns:
        ``(samples, results)``: one stacked tensor per state component and
        one stacked tensor of ``trace_results_fn`` rows per kernel, each with
        iterations along the leading axis.
    """
    with tf.name_scope("main_mcmc_sample_loop"):
        init_state = init_state.copy()

        # Kernel factories; each is later called with a target log-prob fn.
        theta_func = make_theta_kernel(theta_scale, 0.0)
        xi_func = make_xi_kernel(xi_scale, 0.0)
        # Event-time moves: (target, prev, next) event ids.
        se_func = make_events_step(0, None, 1)
        ei_func = make_events_step(1, 0, 2)
        # Occult-event moves.
        se_occult = make_occults_step(0)
        ei_occult = make_occults_step(1)

        # Based on Gibbs idea posted by Pavel Sountsov
        # https://github.com/tensorflow/probability/issues/495
        # Bootstrap every kernel with the other components held at their
        # initial values. The order matches the updates in ``body``:
        # theta, xi, S->E move, E->I move, S->E occult, E->I occult.
        results = [
            theta_func(
                lambda p: logp(p, init_state[1], init_state[2], init_state[3])
            ).bootstrap_results(init_state[0]),
            xi_func(
                lambda p: logp(init_state[0], p, init_state[2], init_state[3])
            ).bootstrap_results(init_state[1]),
            se_func(
                lambda s: logp(init_state[0], init_state[1], s, init_state[3])
            ).bootstrap_results(init_state[2]),
            ei_func(
                lambda s: logp(init_state[0], init_state[1], s, init_state[3])
            ).bootstrap_results(init_state[2]),
            se_occult(
                lambda s: logp(init_state[0], init_state[1], init_state[2], s)
            ).bootstrap_results(init_state[3]),
            ei_occult(
                lambda s: logp(init_state[0], init_state[1], init_state[2], s)
            ).bootstrap_results(init_state[3]),
        ]

        # One accumulator per state component and one per kernel trace.
        samples_arr = [tf.TensorArray(s.dtype, size=n_samples) for s in init_state]
        results_arr = [tf.TensorArray(DTYPE, size=n_samples) for _ in results]

        def body(i, state, results, target_log_prob, sample_accum, results_accum):
            # Parameters
            # Each ``*_logp`` closure writes the value being evaluated into its
            # slot of ``state`` and calls ``logp`` with all other components at
            # their current values; ``invoke_one_step`` then overwrites that
            # slot with the state returned by the kernel.
            def theta_logp(par_state):
                state[0] = par_state  # close over state from outer scope
                return logp(*state)

            state[0], results[0], target_log_prob = invoke_one_step(
                theta_func(theta_logp), state[0], results[0], target_log_prob,
            )

            def xi_logp(xi_state):
                state[1] = xi_state
                return logp(*state)

            state[1], results[1], target_log_prob = invoke_one_step(
                xi_func(xi_logp), state[1], results[1], target_log_prob,
            )

            # Inner loop: event-time and occult-event moves.
            def infec_body(j, state, results, target_log_prob):
                def state_logp(event_state):
                    state[2] = event_state
                    return logp(*state)

                def occult_logp(occult_state):
                    state[3] = occult_state
                    return logp(*state)

                state[2], results[2], target_log_prob = invoke_one_step(
                    se_func(state_logp), state[2], results[2], target_log_prob
                )

                state[2], results[3], target_log_prob = invoke_one_step(
                    ei_func(state_logp), state[2], results[3], target_log_prob
                )

                state[3], results[4], target_log_prob = invoke_one_step(
                    se_occult(occult_logp), state[3], results[4], target_log_prob
                )

                state[3], results[5], target_log_prob = invoke_one_step(
                    ei_occult(occult_logp), state[3], results[5], target_log_prob
                )

                j += 1
                return j, state, results, target_log_prob

            def infec_cond(j, state, results, target_log_prob):
                return j < num_event_updates

            _, state, results, target_log_prob = tf.while_loop(
                infec_cond,
                infec_body,
                loop_vars=[tf.constant(0, tf.int32), state, results, target_log_prob],
            )

            # Record this iteration's state and per-kernel diagnostics.
            sample_accum = [sample_accum[k].write(i, s) for k, s in enumerate(state)]
            results_accum = [
                results_accum[k].write(i, trace_results_fn(r))
                for k, r in enumerate(results)
            ]
            return i + 1, state, results, target_log_prob, sample_accum, results_accum

        def cond(i, *_):
            return i < n_samples

        _, _, _, target_log_prob, samples, results = tf.while_loop(
            cond=cond,
            body=body,
            loop_vars=[
                0,
                init_state,
                results,
                logp(*init_state),
                samples_arr,
                results_arr,
            ],
        )

        return [s.stack() for s in samples], [r.stack() for r in results]


##################
# MCMC loop here #
##################

# MCMC Control
mcmc_config = config["mcmc"]
NUM_BURSTS = mcmc_config["num_bursts"]
NUM_BURST_SAMPLES = mcmc_config["num_burst_samples"]
NUM_EVENT_TIME_UPDATES = mcmc_config["num_event_time_updates"]
# Draws kept per burst after thinning, and in total over all bursts.
THIN_BURST_SAMPLES = NUM_BURST_SAMPLES // mcmc_config["thin"]
NUM_SAVED_SAMPLES = NUM_BURSTS * THIN_BURST_SAMPLES

# RNG stuff
tf.random.set_seed(2)

# Initial state.  NB [M, T, X] layout for events.
events = tf.stack([se_events, ei_events, ir_events], axis=-1)
# ``state_init`` is model.N (as a column) concatenated with the first time
# slice of the events; that slice is then dropped from ``events``.
state_init = tf.concat([model.N[:, tf.newaxis], events[:, 0, :]], axis=-1)
events = events[:, 1:, :]

# Chain state: [theta, xi, events, occult_events].
theta_init = np.array([0.85, 0.3, 0.25], dtype=DTYPE)
xi_init = np.zeros(model.num_xi, dtype=DTYPE)
current_state = [theta_init, xi_init, events, tf.zeros_like(events)]

# Output Files
posterior = h5py.File(
    os.path.expandvars(config["output"]["posterior"]),
    "w",
    rdcc_nbytes=1024 ** 2 * 400,
    rdcc_nslots=100000,
)
event_size = [NUM_SAVED_SAMPLES] + list(current_state[2].shape)
# h5py refuses chunk shapes larger than the dataset in any dimension, so clamp
# the preferred (32, 64, 64, 1) chunking to the actual event dataset shape;
# for datasets at least that large this is unchanged.
event_chunks = tuple(
    min(chunk, int(dim)) for chunk, dim in zip((32, 64, 64, 1), event_size)
)

theta_samples = posterior.create_dataset(
    "samples/theta", [NUM_SAVED_SAMPLES, current_state[0].shape[0]], dtype=np.float64,
)
xi_samples = posterior.create_dataset(
    "samples/xi", [NUM_SAVED_SAMPLES, current_state[1].shape[0]], dtype=np.float64,
)
event_samples = posterior.create_dataset(
    "samples/events",
    event_size,
    dtype=DTYPE,
    chunks=event_chunks,
    compression="szip",
    compression_opts=("nn", 16),
)
occult_samples = posterior.create_dataset(
    "samples/occults",
    event_size,
    dtype=DTYPE,
    chunks=event_chunks,
    compression="szip",
    compression_opts=("nn", 16),
)

# One trace dataset per Gibbs kernel, in the same order as ``sample`` returns
# them: theta, xi, S->E move, E->I move, S->E occult, E->I occult.
output_results = [
    posterior.create_dataset("results/theta", (NUM_SAVED_SAMPLES, 3), dtype=DTYPE,),
    posterior.create_dataset("results/xi", (NUM_SAVED_SAMPLES, 3), dtype=DTYPE,),
    posterior.create_dataset(
        "results/move/S->E", (NUM_SAVED_SAMPLES, 3 + model.N.shape[0]), dtype=DTYPE,
    ),
    posterior.create_dataset(
        "results/move/E->I", (NUM_SAVED_SAMPLES, 3 + model.N.shape[0]), dtype=DTYPE,
    ),
    posterior.create_dataset(
        "results/occult/S->E", (NUM_SAVED_SAMPLES, 6), dtype=DTYPE
    ),
    posterior.create_dataset(
        "results/occult/E->I", (NUM_SAVED_SAMPLES, 6), dtype=DTYPE
    ),
]

print("Initial logpi:", logp(*current_state))

# Initial random-walk proposal scales: diagonal matrices for theta and xi.
theta_scale = tf.linalg.diag(
    tf.constant([0.1, 0.8, 0.1], dtype=current_state[0].dtype)
)
xi_scale = tf.linalg.diag(
    tf.constant([0.1] * int(model.num_xi.numpy()), dtype=current_state[1].dtype)
)

# We loop over successive calls to sample because we have to dump results
#   to disc, or else end OOM (even on a 32GB system).
# with tf.profiler.experimental.Profile("/tmp/tf_logdir"):

# Indices of the draws kept from each burst. Taking exactly
# THIN_BURST_SAMPLES of them keeps the count equal to the length of the
# output slice even when NUM_BURST_SAMPLES is not a multiple of ``thin``
# (range(0, N, thin) would yield one extra index and the HDF5 write would
# fail). When N is a multiple of ``thin`` the indices are identical.
thin = config["mcmc"]["thin"]
idx = tf.constant(np.arange(THIN_BURST_SAMPLES) * thin)

acceptance_labels = [
    "theta",
    "xi",
    "move S->E",
    "move E->I",
    "occult S->E",
    "occult E->I",
]

for i in tqdm.tqdm(range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES):
    samples, results = sample(
        NUM_BURST_SAMPLES,
        init_state=current_state,
        theta_scale=theta_scale,
        xi_scale=xi_scale,
        num_event_updates=tf.constant(NUM_EVENT_TIME_UPDATES, tf.int32),
    )
    # The next burst resumes from the last state of this one.
    current_state = [s[-1] for s in samples]

    # Output rows occupied by this burst's thinned draws.
    s = slice(i * THIN_BURST_SAMPLES, (i + 1) * THIN_BURST_SAMPLES)
    theta_samples[s, ...] = tf.gather(samples[0], idx)
    xi_samples[s, ...] = tf.gather(samples[1], idx)

    # Adapt the theta proposal from the log-theta draws saved so far. Only
    # the first (i + 1) * THIN_BURST_SAMPLES rows have been written; reading
    # past them would pull in unwritten rows (HDF5 fill value, 0 by default),
    # give log(0) = -inf and a non-finite covariance, and so silently disable
    # adaptation whenever thin > 1.
    cov = np.cov(
        np.log(theta_samples[: (i + 1) * THIN_BURST_SAMPLES, ...]),
        rowvar=False,
    )
    print(current_state[0].numpy(), flush=True)
    print(cov, flush=True)
    if (i * NUM_BURST_SAMPLES) > 1000 and np.all(np.isfinite(cov)):
        # Keep the scale a tensor of the state dtype, matching how the
        # initial theta_scale was built.
        # NOTE(review): the usual 2.38**2 / d rule would divide by 3 here
        # (theta has three components), not 2 — confirm intent.
        theta_scale = tf.constant(
            2.38 ** 2 * cov / 2.0, dtype=current_state[0].dtype
        )

    start = perf_counter()
    event_samples[s, ...] = tf.gather(samples[2], idx)
    occult_samples[s, ...] = tf.gather(samples[3], idx)
    end = perf_counter()

    # ``k`` (not ``i``) so the burst counter is not clobbered.
    for k, ro in enumerate(output_results):
        ro[s, ...] = tf.gather(results[k], idx)

    print("Storage time:", end - start, "seconds")

    # Column 1 of each trace is the acceptance indicator.
    for label, res in zip(acceptance_labels, results):
        print(f"Acceptance {label}:", tf.reduce_mean(tf.cast(res[:, 1], tf.float32)))

# Overall acceptance rates over all saved draws. ``output_results`` is ordered
# theta, xi, S->E move, E->I move, S->E occult, E->I occult, and column 1 of
# each dataset is the acceptance indicator produced by ``trace_results_fn``.
print(f"Acceptance param: {output_results[0][:, 1].mean()}")
print(f"Acceptance xi: {output_results[1][:, 1].mean()}")
print(f"Acceptance move S->E: {output_results[2][:, 1].mean()}")
print(f"Acceptance move E->I: {output_results[3][:, 1].mean()}")
print(f"Acceptance occult S->E: {output_results[4][:, 1].mean()}")
print(f"Acceptance occult E->I: {output_results[5][:, 1].mean()}")

posterior.close()