inference.py 18.8 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
"""MCMC Test Rig for COVID-19 UK model"""
2
3
# pylint: disable=E402

Chris Jewell's avatar
Chris Jewell committed
4
import sys
5
6
7

import h5py
import xarray
8
9
import tqdm
import yaml
10
import numpy as np
11
12
import tensorflow as tf
import tensorflow_probability as tfp
Chris Jewell's avatar
Chris Jewell committed
13

14
15
16
17
from tensorflow_probability.python.internal import unnest
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.experimental.stats import sample_stats

18
from gemlib.util import compute_state
Chris Jewell's avatar
Chris Jewell committed
19
from gemlib.mcmc import Posterior
20
from gemlib.mcmc import GibbsKernel
21
from gemlib.distributions import BrownianMotion
22
23
24
25
from covid.tasks.mcmc_kernel_factory import make_hmc_base_kernel
from covid.tasks.mcmc_kernel_factory import make_hmc_fast_adapt_kernel
from covid.tasks.mcmc_kernel_factory import make_hmc_slow_adapt_kernel
from covid.tasks.mcmc_kernel_factory import make_event_multiscan_gibbs_step
26

27
import covid.model_spec as model_spec
28

29
30
tfd = tfp.distributions
tfb = tfp.bijectors
31
DTYPE = model_spec.DTYPE
32
33


34
def get_weighted_running_variance(draws):
    """Initialise an online variance accumulator from a trace of draws.

    Only the most recent half of `draws` is used, to reduce bias from the
    early transient of the chain.

    :param draws: tensor of MCMC samples, sample dimension first
    :returns: a `sample_stats.RunningVariance` seeded with the mean and
              variance of the latter half of `draws`
    """
    # NB: `-draws.shape[0] // 2` floors the *negative* count, so for an odd
    # number of draws this takes the ceiling-half most recent samples.
    recent = draws[-draws.shape[0] // 2 :]
    mean, variance = tf.nn.moments(recent, axes=[0])
    count = tf.cast(
        draws.shape[0] / 2,
        dtype=dtype_util.common_dtype([mean, variance], tf.float32),
    )
    return sample_stats.RunningVariance.from_stats(
        num_samples=count, mean=mean, variance=variance
    )


48
49
50
51
52
53
54
55
56
def _get_window_sizes(num_adaptation_steps):
    slow_window_size = num_adaptation_steps // 21
    first_window_size = 3 * slow_window_size
    last_window_size = (
        num_adaptation_steps - 15 * slow_window_size - first_window_size
    )
    return first_window_size, slow_window_size, last_window_size


Chris Jewell's avatar
Chris Jewell committed
57
@tf.function(jit_compile=False)
def _fast_adapt_window(
    num_draws,
    joint_log_prob_fn,
    initial_position,
    hmc_kernel_kwargs,
    dual_averaging_kwargs,
    event_kernel_kwargs,
    trace_fn=None,
    seed=None,
):
    """Run the fast step-size adaptation window.

    An HMC kernel wrapped in `DualAveragingStepSizeAdaptation` is combined
    with the event/occult Gibbs step and run for `num_draws` iterations.

    :param num_draws: Number of MCMC draws in window
    :param joint_log_prob_fn: joint log posterior function
    :param initial_position: initial state of the Markov chain
    :param hmc_kernel_kwargs: `HamiltonianMonteCarlo` kernel keywords args
    :param dual_averaging_kwargs: `DualAveragingStepSizeAdaptation` keyword args
    :param event_kernel_kwargs: EventTimesMH and Occult kernel args
    :param trace_fn: function to trace kernel results
    :param seed: optional random seed.
    :returns: draws, kernel results, the adapted HMC step size, and variance
              accumulator
    """
    steps = [
        (
            0,
            make_hmc_fast_adapt_kernel(
                hmc_kernel_kwargs=hmc_kernel_kwargs,
                dual_averaging_kwargs=dual_averaging_kwargs,
            ),
        ),
        (1, make_event_multiscan_gibbs_step(**event_kernel_kwargs)),
    ]
    gibbs = GibbsKernel(
        target_log_prob_fn=joint_log_prob_fn,
        kernel_list=steps,
        name="fast_adapt",
    )

    draws, trace, final_results = tfp.mcmc.sample_chain(
        num_draws,
        current_state=initial_position,
        kernel=gibbs,
        return_final_kernel_results=True,
        trace_fn=trace_fn,
        seed=seed,
    )

    # Seed the variance accumulator from the parameter draws and recover the
    # adapted step size from the HMC sub-kernel's final results.
    running_variance = get_weighted_running_variance(draws[0])
    step_size = unnest.get_outermost(
        final_results.inner_results[0], "step_size"
    )
    return draws, trace, step_size, running_variance


Chris Jewell's avatar
Chris Jewell committed
115
@tf.function(jit_compile=False)
def _slow_adapt_window(
    num_draws,
    joint_log_prob_fn,
    initial_position,
    initial_running_variance,
    hmc_kernel_kwargs,
    dual_averaging_kwargs,
    event_kernel_kwargs,
    trace_fn=None,
    seed=None,
):
    """In the slow adaptation phase, we adapt the HMC
    step size and mass matrix together.

    :param num_draws: number of MCMC iterations
    :param joint_log_prob_fn: the joint posterior density function
    :param initial_position: initial Markov chain state
    :param initial_running_variance: initial variance accumulator
    :param hmc_kernel_kwargs: `HamiltonianMonteCarlo` kernel kwargs
    :param dual_averaging_kwargs: `DualAveragingStepSizeAdaptation` kwargs
    :param event_kernel_kwargs: EventTimesMH and Occults kwargs
    :param trace_fn: result trace function
    :param seed: optional random seed
    :returns: draws, kernel results, adapted step size, the variance accumulator,
              and "learned" momentum distribution for the HMC.
    """
    kernel_list = [
        (
            0,
            make_hmc_slow_adapt_kernel(
                initial_running_variance,
                hmc_kernel_kwargs,
                dual_averaging_kwargs,
            ),
        ),
        (1, make_event_multiscan_gibbs_step(**event_kernel_kwargs)),
    ]

    kernel = GibbsKernel(
        target_log_prob_fn=joint_log_prob_fn,
        kernel_list=kernel_list,
        name="slow_adapt",
    )

    # Fix: `seed` was previously accepted but never forwarded to
    # `sample_chain` (unlike `_fast_adapt_window`/`_fixed_window`), so a
    # supplied seed had no effect on this window.
    draws, trace, fkr = tfp.mcmc.sample_chain(
        num_draws,
        current_state=initial_position,
        kernel=kernel,
        return_final_kernel_results=True,
        trace_fn=trace_fn,
        seed=seed,
    )

    # Recover the adapted step size and "learned" momentum (mass matrix)
    # distribution from the HMC sub-kernel's final results.
    step_size = unnest.get_outermost(fkr.inner_results[0], "step_size")
    momentum_distribution = unnest.get_outermost(
        fkr.inner_results[0], "momentum_distribution"
    )

    # Seed the next window's variance accumulator from these draws.
    weighted_running_variance = get_weighted_running_variance(draws[0])

    return (
        draws,
        trace,
        step_size,
        weighted_running_variance,
        momentum_distribution,
    )


Chris Jewell's avatar
Chris Jewell committed
184
@tf.function(jit_compile=False)
def _fixed_window(
    num_draws,
    joint_log_prob_fn,
    initial_position,
    hmc_kernel_kwargs,
    event_kernel_kwargs,
    trace_fn=None,
    seed=None,
):
    """Fixed step size and mass matrix HMC.

    :param num_draws: number of MCMC iterations
    :param joint_log_prob_fn: joint log posterior density function
    :param initial_position: initial Markov chain state
    :param hmc_kernel_kwargs: `HamiltonianMonteCarlo` kwargs
    :param event_kernel_kwargs: Event and Occults kwargs
    :param trace_fn: results trace function
    :param seed: optional random seed
    :returns: (draws, trace, final_kernel_results)
    """
    # Alternate a non-adaptive HMC step for the parameters with the
    # event/occult Gibbs step.
    steps = [
        (0, make_hmc_base_kernel(**hmc_kernel_kwargs)),
        (1, make_event_multiscan_gibbs_step(**event_kernel_kwargs)),
    ]
    gibbs = GibbsKernel(
        target_log_prob_fn=joint_log_prob_fn,
        kernel_list=steps,
        name="fixed",
    )
    return tfp.mcmc.sample_chain(
        num_draws,
        current_state=initial_position,
        kernel=gibbs,
        return_final_kernel_results=True,
        trace_fn=trace_fn,
        seed=seed,
    )


def trace_results_fn(_, results):
    """Pack per-iteration kernel results into a flat dictionary.

    :param _: the chain state (unused)
    :param results: Gibbs kernel results for one MCMC iteration
    :returns: dict of HMC and event/occult move diagnostics
    """

    def _summarise_move(kr):
        # Metropolis-Hastings diagnostics for one event/occult move kernel.
        accepted = kr.accepted_results
        return {
            "is_accepted": kr.is_accepted,
            "target_log_prob": accepted.target_log_prob,
            "proposed_delta": tf.stack(
                [accepted.m, accepted.t, accepted.delta_t, accepted.x_star]
            ),
        }

    hmc_results = results.inner_results[0]
    event_results = results.inner_results[1].inner_results

    packed = {
        "hmc": {
            "is_accepted": unnest.get_innermost(hmc_results, "is_accepted"),
            "target_log_prob": unnest.get_innermost(
                hmc_results, "target_log_prob"
            ),
            "step_size": tf.convert_to_tensor(
                unnest.get_outermost(hmc_results, "step_size")
            ),
        }
    }
    for name, kr in zip(
        ["move/S->E", "move/E->I", "occult/S->E", "occult/E->I"],
        event_results,
    ):
        packed[name] = _summarise_move(kr)

    return packed


def draws_to_dict(draws):
    """Unpack MCMC draws into a dict of named parameter traces.

    :param draws: two-element sequence: ``draws[0]`` is the
        (samples, params) global parameter matrix, ``draws[1]`` the event
        tensor
    :returns: dict mapping parameter names to their sample traces
    """
    params, events = draws[0], draws[1]
    # Columns 0-4 are scalar parameters; the remainder is the alpha_t path.
    scalar_names = ["psi", "beta_area", "gamma0", "gamma1", "alpha_0"]
    out = {name: params[:, i] for i, name in enumerate(scalar_names)}
    out["alpha_t"] = params[:, 5:]
    out["seir"] = events
    return out


def run_mcmc(
    joint_log_prob_fn,
    current_state,
    param_bijector,
    initial_conditions,
    config,
    output_file,
):
    """Run the full warmup-then-sample MCMC schedule.

    The schedule follows a Stan-like warmup: one fast step-size adaptation
    window, a sequence of doubling slow windows that also adapt the mass
    matrix, a final fast window, then fixed-kernel sampling in bursts.

    :param joint_log_prob_fn: joint log posterior density function
    :param current_state: initial Markov chain state [params, events]
    :param param_bijector: bijector mapping unconstrained parameters back to
        the constrained space for storage
    :param initial_conditions: initial state for the event/occult kernels
    :param config: MCMC configuration dictionary
    :param output_file: path of the posterior output file
    :returns: a `Posterior` object wrapping the output file
    """
    # Hard-coded warmup window sizes (see `_get_window_sizes` for a derived
    # alternative based on a total adaptation budget).
    first_window_size = 200
    last_window_size = 50
    slow_window_size = 25
    num_slow_windows = 6

    # Slow windows double in length each time, so their total size is a
    # geometric series.
    warmup_size = int(
        first_window_size
        + slow_window_size * ((1 - 2 ** num_slow_windows) / (1 - 2))
        + last_window_size
    )

    hmc_kernel_kwargs = {
        "step_size": 0.1,
        "num_leapfrog_steps": 16,
        "momentum_distribution": None,
        "store_parameters_in_results": True,
    }
    dual_averaging_kwargs = {
        "target_accept_prob": 0.75,
    }
    # Event moves operate on the most recent 21 timepoints only.
    num_timepoints = current_state[1].shape[-2]
    event_kernel_kwargs = {
        "initial_state": initial_conditions,
        "t_range": [
            num_timepoints - 21,
            num_timepoints,
        ],
        "config": config,
    }

    # Run a single draw to establish output shapes, then allocate the
    # posterior output file sized for the whole run.
    print("Initialising output...", end="", flush=True, file=sys.stderr)
    draws, trace, _ = _fixed_window(
        num_draws=1,
        joint_log_prob_fn=joint_log_prob_fn,
        initial_position=current_state,
        hmc_kernel_kwargs=hmc_kernel_kwargs,
        event_kernel_kwargs=event_kernel_kwargs,
        trace_fn=trace_results_fn,
    )
    posterior = Posterior(
        output_file,
        sample_dict=draws_to_dict(draws),
        results_dict=trace,
        num_samples=warmup_size
        + config["num_burst_samples"] * config["num_bursts"],
    )
    offset = 0
    print("Done", flush=True, file=sys.stderr)

    def _record(draws, trace, num_draws):
        # Persist a window of draws/results, advance the chain state to the
        # final draw, and bump the write offset.  NB: the chain state is
        # taken *before* mapping draws back to the constrained space.
        nonlocal current_state, offset
        current_state = [part[-1] for part in draws]
        draws[0] = param_bijector.inverse(draws[0])
        posterior.write_samples(draws_to_dict(draws), first_dim_offset=offset)
        posterior.write_results(trace, first_dim_offset=offset)
        offset += num_draws

    # Fast adaptation window: adapt step size only.
    print(f"Fast window {first_window_size}", file=sys.stderr, flush=True)
    dual_averaging_kwargs["num_adaptation_steps"] = first_window_size
    draws, trace, step_size, running_variance = _fast_adapt_window(
        num_draws=first_window_size,
        joint_log_prob_fn=joint_log_prob_fn,
        initial_position=current_state,
        hmc_kernel_kwargs=hmc_kernel_kwargs,
        dual_averaging_kwargs=dual_averaging_kwargs,
        event_kernel_kwargs=event_kernel_kwargs,
        trace_fn=trace_results_fn,
    )
    _record(draws, trace, first_window_size)

    # Slow adaptation windows: adapt step size and mass matrix together,
    # doubling the window length each iteration.
    hmc_kernel_kwargs["step_size"] = step_size
    for slow_window_idx in range(num_slow_windows):
        window_num_draws = slow_window_size * (2 ** slow_window_idx)
        dual_averaging_kwargs["num_adaptation_steps"] = window_num_draws
        print(f"Slow window {window_num_draws}", file=sys.stderr, flush=True)
        (
            draws,
            trace,
            step_size,
            running_variance,
            momentum_distribution,
        ) = _slow_adapt_window(
            num_draws=window_num_draws,
            joint_log_prob_fn=joint_log_prob_fn,
            initial_position=current_state,
            initial_running_variance=running_variance,
            hmc_kernel_kwargs=hmc_kernel_kwargs,
            dual_averaging_kwargs=dual_averaging_kwargs,
            event_kernel_kwargs=event_kernel_kwargs,
            trace_fn=trace_results_fn,
        )
        hmc_kernel_kwargs["step_size"] = step_size
        hmc_kernel_kwargs["momentum_distribution"] = momentum_distribution
        _record(draws, trace, window_num_draws)

    # Final fast window: re-adapt the step size for the learned mass matrix.
    print(f"Fast window {last_window_size}", file=sys.stderr, flush=True)
    dual_averaging_kwargs["num_adaptation_steps"] = last_window_size
    draws, trace, step_size, _ = _fast_adapt_window(
        num_draws=last_window_size,
        joint_log_prob_fn=joint_log_prob_fn,
        initial_position=current_state,
        hmc_kernel_kwargs=hmc_kernel_kwargs,
        dual_averaging_kwargs=dual_averaging_kwargs,
        event_kernel_kwargs=event_kernel_kwargs,
        trace_fn=trace_results_fn,
    )
    _record(draws, trace, last_window_size)

    # Fixed-kernel sampling: freeze the step size at the mean over the
    # second half of the final fast window, then sample in bursts.
    print("Sampling...", file=sys.stderr, flush=True)
    hmc_kernel_kwargs["step_size"] = tf.reduce_mean(
        trace["hmc"]["step_size"][-last_window_size // 2 :]
    )
    print("Fixed kernel kwargs:", hmc_kernel_kwargs, flush=True)

    for _burst in tqdm.tqdm(
        range(config["num_bursts"]),
        unit_scale=config["num_burst_samples"] * config["thin"],
    ):
        draws, trace, _ = _fixed_window(
            num_draws=config["num_burst_samples"],
            joint_log_prob_fn=joint_log_prob_fn,
            initial_position=current_state,
            hmc_kernel_kwargs=hmc_kernel_kwargs,
            event_kernel_kwargs=event_kernel_kwargs,
            trace_fn=trace_results_fn,
        )
        _record(draws, trace, config["num_burst_samples"])

    return posterior


450
def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True):
    """Constructs and runs the MCMC.

    :param data_file: dataset containing "constant_data" and "observations"
        groups
    :param output_file: posterior output file path
    :param config: MCMC configuration dictionary
    :param use_autograph: unused; retained for interface compatibility
    :param use_xla: unused; retained for interface compatibility
    """
    print("Using GPU" if tf.test.gpu_device_name() else "Using CPU")

    data = xarray.open_dataset(data_file, group="constant_data")
    cases = xarray.open_dataset(data_file, group="observations")[
        "cases"
    ].astype(DTYPE)
    dates = cases.coords["time"]

    # Impute censored events.  The last week of data is repeated a further
    # 3 times to get a better occult initialisation.
    extra_cases = tf.tile(cases[:, -7:], [1, 3])
    cases = tf.concat([cases, extra_cases], axis=-1)
    events = model_spec.impute_censored_events(cases).numpy()

    # Initial conditions are calculated by computing the state at the
    # beginning of the inference period.  Imputed censored events that
    # pre-date the first I-R events in the cases dataset are discarded:
    # they are only used to set up a sensible initial state.
    state = compute_state(
        initial_state=tf.concat(
            [
                tf.constant(data["N"], DTYPE)[:, tf.newaxis],
                tf.zeros_like(events[:, 0, :]),
            ],
            axis=-1,
        ),
        events=events,
        stoichiometry=model_spec.STOICHIOMETRY,
    )
    start_time = state.shape[1] - cases.shape[1]
    initial_state = state[:, start_time, :]
    events = events[:, start_time:-21, :]  # Clip off the "extra" events

    # Model and parameter bijector
    model = model_spec.CovidUK(
        covariates=data,
        initial_state=initial_state,
        initial_step=0,
        num_steps=events.shape[1],
    )

    # Forward transform unconstrains the parameters: the first parameter is
    # positive (Softplus), the remainder are unconstrained.
    param_bij = tfb.Invert(
        tfb.Blockwise(
            [
                tfb.Softplus(low=dtype_util.eps(DTYPE)),
                tfb.Identity(),
                tfb.Identity(),
            ],
            block_sizes=[1, 3, events.shape[1]],
        )
    )

    def joint_log_prob(unconstrained_params, events):
        # Log posterior in the unconstrained parameter space, including the
        # Jacobian correction for the change of variables.
        params = param_bij.inverse(unconstrained_params)
        logp = model.log_prob(
            dict(
                psi=params[0],
                beta_area=params[1],
                gamma0=params[2],
                gamma1=params[3],
                alpha_0=params[4],
                alpha_t=params[5:],
                seir=events,
            )
        )
        return logp + param_bij.inverse_log_det_jacobian(
            unconstrained_params, event_ndims=1
        )

    # Initial chain state: modest positive psi, zeros for the remaining
    # scalar parameters, and a constant alpha trajectory.
    current_chain_state = [
        tf.concat(
            [
                np.array([0.1, 0.0, 0.0, 0.0], dtype=DTYPE),
                np.full(events.shape[1], -1.75, dtype=DTYPE),
            ],
            axis=0,
        ),
        events,
    ]
    print("Num time steps:", events.shape[1], flush=True)
    print("alpha_t shape", model.event_shape["alpha_t"], flush=True)
    print("Initial chain state:", current_chain_state[0], flush=True)
    print("Initial logpi:", joint_log_prob(*current_chain_state), flush=True)

    # Run the MCMC and annotate the output file with the initial state and
    # the observation dates.
    posterior = run_mcmc(
        joint_log_prob_fn=joint_log_prob,
        current_state=current_chain_state,
        param_bijector=param_bij,
        initial_conditions=initial_state,
        config=config,
        output_file=output_file,
    )
    posterior._file.create_dataset("initial_state", data=initial_state)
    posterior._file.create_dataset(
        "time",
        data=np.array(dates).astype(str).astype(h5py.string_dtype()),
    )

    # Report acceptance rates for each kernel.
    print(f"Acceptance theta: {posterior['results/hmc/is_accepted'][:].mean()}")
    for name in ("move/S->E", "move/E->I", "occult/S->E", "occult/E->I"):
        label = name.replace("/", " ")
        rate = posterior[f"results/{name}/is_accepted"][:].mean()
        print(f"Acceptance {label}: {rate}")

    del posterior
582
583
584
585


if __name__ == "__main__":

    from argparse import ArgumentParser

    # Command-line interface: config and output are required options; the
    # data file is a positional argument.
    parser = ArgumentParser(description="Run MCMC inference algorithm")
    parser.add_argument(
        "-c", "--config", type=str, help="Config file", required=True
    )
    parser.add_argument(
        "-o", "--output", type=str, help="Output file", required=True
    )
    parser.add_argument("data_file", type=str, help="Data pickle file")
    args = parser.parse_args()

    # Load the YAML configuration and run the inference with its "Mcmc"
    # section.
    with open(args.config, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    mcmc(args.data_file, args.output, config["Mcmc"])