"""MCMC Test Rig for COVID-19 UK model"""

# pylint: disable=E402

import os
import h5py
import pickle as pkl
from time import perf_counter

import tqdm
import yaml
import numpy as np

import tensorflow as tf
import tensorflow_probability as tfp

from gemlib.util import compute_state
from gemlib.mcmc import UncalibratedEventTimesUpdate
from gemlib.mcmc import UncalibratedOccultUpdate, TransitionTopology
from gemlib.mcmc import GibbsKernel
from gemlib.mcmc import MultiScanKernel
from gemlib.mcmc import AdaptiveRandomWalkMetropolis
from gemlib.mcmc import Posterior

import covid.model_spec as model_spec

# Short aliases for the TensorFlow Probability sub-packages.
tfd = tfp.distributions
tfb = tfp.bijectors

# Floating point type shared with the model specification.
DTYPE = model_spec.DTYPE


def mcmc(data_file, output_file, config):
    """Constructs and runs the MCMC.

    The sampler is a Metropolis-within-Gibbs scheme over three state parts:
    two parameter blocks and the event tensor.  It is run in bursts, with
    samples and kernel results written to `output_file` after each burst.

    :param data_file: path to a pickle file holding the model data.  The
        unpickled object must provide the keys "cases", "N" and
        "date_range"; it is also passed whole to the model as `covariates`.
    :param output_file: passed to `Posterior` as the output destination.
    :param config: mapping of MCMC settings.  Keys read here are "dmax",
        "m", "nmax", "occult_nmax", "num_event_time_updates", "num_bursts",
        "num_burst_samples" and "thin".
    :returns: None.  Results are written to `output_file` and progress is
        printed to stdout.
    """
    if tf.test.gpu_device_name():
        print("Using GPU")
    else:
        print("Using CPU")

    # NOTE(review): unpickling executes arbitrary code embedded in the file;
    # only run this on data files from a trusted source.
    with open(data_file, "rb") as f:
        data = pkl.load(f)

    # We load in cases and impute missing infections first, since this sets the
    # time epoch which we are analysing.
    # Impute censored events, return cases
    events = model_spec.impute_censored_events(data["cases"].astype(DTYPE))

    # Initial conditions are calculated by calculating the state
    # at the beginning of the inference period
    #
    # Imputed censored events that pre-date the first I-R events
    # in the cases dataset are discarded.  They are only used
    # to set up a sensible initial state.
    state = compute_state(
        initial_state=tf.concat(
            [data["N"][:, tf.newaxis], tf.zeros_like(events[:, 0, :])],
            axis=-1,
        ),
        events=events,
        stoichiometry=model_spec.STOICHIOMETRY,
    )
    # The state timeline is longer than the cases timeline by the number of
    # imputed time points; that difference is where inference starts.
    start_time = state.shape[1] - data["cases"].shape[1]
    initial_state = state[:, start_time, :]
    events = events[:, start_time:, :]

    ########################################################
    # Construct the MCMC kernels #
    ########################################################
    model = model_spec.CovidUK(
        covariates=data,
        initial_state=initial_state,
        initial_step=0,
        num_steps=events.shape[1],
    )

    def joint_log_prob(block0, block1, events):
        """Unpacks the three state parts and evaluates `model.log_prob`.

        `block0` holds beta2, gamma0, gamma1, sigma and then beta3;
        `block1` holds beta1 followed by xi.
        """
        return model.log_prob(
            dict(
                beta2=block0[0],
                gamma0=block0[1],
                gamma1=block0[2],
                sigma=block0[3],
                beta3=block0[4:],
                beta1=block1[0],
                xi=block1[1:],
                seir=events,
            )
        )

    # Build Metropolis within Gibbs sampler
    #
    # Kernels are:
    #     Q(\theta, \theta^\prime)
    #     Q(\xi, \xi^\prime)
    #     Q(Z^{se}, Z^{se\prime}) (partially-censored)
    #     Q(Z^{ei}, Z^{ei\prime}) (partially-censored)
    #     Q(Z^{se}, Z^{se\prime}) (occult)
    #     Q(Z^{ei}, Z^{ei\prime}) (occult)
    def make_blk0_kernel(shape, name):
        """Returns a kernel factory for parameter block 0.

        The random walk runs on a transformed scale: the bijector block
        sizes [1, 2, 1, 4] apply Exp to elements 0 and 3 (beta2 and sigma in
        `joint_log_prob`) and Identity to the remaining six.
        """

        def fn(target_log_prob_fn, _):
            return tfp.mcmc.TransformedTransitionKernel(
                inner_kernel=AdaptiveRandomWalkMetropolis(
                    target_log_prob_fn=target_log_prob_fn,
                    initial_covariance=np.eye(shape[0], dtype=model_spec.DTYPE)
                    * 1e-1,
                    covariance_burnin=200,
                ),
                bijector=tfp.bijectors.Blockwise(
                    bijectors=[
                        tfp.bijectors.Exp(),
                        tfp.bijectors.Identity(),
                        tfp.bijectors.Exp(),
                        tfp.bijectors.Identity(),
                    ],
                    block_sizes=[1, 2, 1, 4],
                ),
                name=name,
            )

        return fn

    def make_blk1_kernel(shape, name):
        """Returns a kernel factory for parameter block 1 (no transform)."""

        def fn(target_log_prob_fn, _):
            return AdaptiveRandomWalkMetropolis(
                target_log_prob_fn=target_log_prob_fn,
                initial_covariance=np.eye(shape[0], dtype=model_spec.DTYPE)
                * 1e-1,
                covariance_burnin=200,
                name=name,
            )

        return fn

    def make_partially_observed_step(
        target_event_id, prev_event_id=None, next_event_id=None, name=None
    ):
        """Returns a kernel factory for a Metropolis-Hastings event time move.

        The event ids are forwarded unchanged to
        `UncalibratedEventTimesUpdate`; the tuning constants come from
        `config`.
        """

        def fn(target_log_prob_fn, _):
            return tfp.mcmc.MetropolisHastings(
                inner_kernel=UncalibratedEventTimesUpdate(
                    target_log_prob_fn=target_log_prob_fn,
                    target_event_id=target_event_id,
                    prev_event_id=prev_event_id,
                    next_event_id=next_event_id,
                    initial_state=initial_state,
                    dmax=config["dmax"],
                    mmax=config["m"],
                    nmax=config["nmax"],
                ),
                name=name,
            )

        return fn

    def make_occults_step(prev_event_id, target_event_id, next_event_id, name):
        """Returns a kernel factory for a Metropolis-Hastings occult move.

        Proposals are limited to the final 21 time points of the event
        tensor via `t_range`.
        """

        def fn(target_log_prob_fn, _):
            return tfp.mcmc.MetropolisHastings(
                inner_kernel=UncalibratedOccultUpdate(
                    target_log_prob_fn=target_log_prob_fn,
                    topology=TransitionTopology(
                        prev_event_id, target_event_id, next_event_id
                    ),
                    cumulative_event_offset=initial_state,
                    nmax=config["occult_nmax"],
                    t_range=(events.shape[1] - 21, events.shape[1]),
                    name=name,
                ),
                name=name,
            )

        return fn

    def make_event_multiscan_kernel(target_log_prob_fn, _):
        """Builds the event-update kernel.

        The four event moves are grouped in an inner Gibbs kernel, which is
        wrapped in a `MultiScanKernel` configured with
        `config["num_event_time_updates"]`.  Every move addresses index 0
        because the inner kernel sees only the event tensor.
        """
        return MultiScanKernel(
            config["num_event_time_updates"],
            GibbsKernel(
                target_log_prob_fn=target_log_prob_fn,
                kernel_list=[
                    (0, make_partially_observed_step(0, None, 1, "se_events")),
                    (0, make_partially_observed_step(1, 0, 2, "ei_events")),
                    (0, make_occults_step(None, 0, 1, "se_occults")),
                    (0, make_occults_step(0, 1, 2, "ei_occults")),
                ],
                name="gibbs1",
            ),
        )

    # MCMC tracing functions
    def trace_results_fn(_, results):
        """Packs results into a dictionary"""
        results_dict = {}
        res0 = results.inner_results

        # Block 0 is wrapped in a TransformedTransitionKernel, hence the
        # extra `.inner_results` level compared with block 1.
        results_dict["block0"] = {
            "is_accepted": res0[0].inner_results.is_accepted,
            "target_log_prob": res0[
                0
            ].inner_results.accepted_results.target_log_prob,
        }
        results_dict["block1"] = {
            "is_accepted": res0[1].is_accepted,
            "target_log_prob": res0[1].accepted_results.target_log_prob,
        }

        def get_move_results(results):
            return {
                "is_accepted": results.is_accepted,
                "target_log_prob": results.accepted_results.target_log_prob,
                "proposed_delta": tf.stack(
                    [
                        results.accepted_results.m,
                        results.accepted_results.t,
                        results.accepted_results.delta_t,
                        results.accepted_results.x_star,
                    ]
                ),
            }

        # Same order as the kernel_list in make_event_multiscan_kernel.
        res1 = res0[2].inner_results
        results_dict["move/S->E"] = get_move_results(res1[0])
        results_dict["move/E->I"] = get_move_results(res1[1])
        results_dict["occult/S->E"] = get_move_results(res1[2])
        results_dict["occult/E->I"] = get_move_results(res1[3])

        return results_dict

    # Build MCMC algorithm here.  This will be run in bursts for memory economy
    @tf.function  # (autograph=False, experimental_compile=True)
    def sample(n_samples, init_state, thin=0, previous_results=None):
        """Runs one burst and returns (samples, traced results, final results).

        `previous_results` lets a burst continue from the kernel results of
        the preceding burst.
        """
        with tf.name_scope("main_mcmc_sample_loop"):

            init_state = init_state.copy()

            gibbs_schema = GibbsKernel(
                target_log_prob_fn=joint_log_prob,
                kernel_list=[
                    (0, make_blk0_kernel(init_state[0].shape, "block0")),
                    (1, make_blk1_kernel(init_state[1].shape, "block1")),
                    (2, make_event_multiscan_kernel),
                ],
                name="gibbs0",
            )

            samples, results, final_results = tfp.mcmc.sample_chain(
                n_samples,
                init_state,
                kernel=gibbs_schema,
                num_steps_between_results=thin,
                previous_kernel_results=previous_results,
                return_final_kernel_results=True,
                trace_fn=trace_results_fn,
            )

            return samples, results, final_results

    ###############################
    # Construct bursted MCMC loop #
    ###############################
    NUM_BURSTS = int(config["num_bursts"])
    NUM_BURST_SAMPLES = int(config["num_burst_samples"])
    NUM_SAVED_SAMPLES = NUM_BURST_SAMPLES * NUM_BURSTS

    # RNG stuff
    tf.random.set_seed(2)

    # Starting values: block 0, block 1 (beta1 plus one entry per element
    # of xi), and the imputed event tensor.
    current_state = [
        np.array([0.6, 0.0, 0.0, 0.1, 0.0, 0.0, 0.0, 0.0], dtype=DTYPE),
        np.zeros(
            model.model["xi"](0.0, 0.1).event_shape[-1] + 1,
            dtype=DTYPE,
        ),
        events,
    ]
    print("Initial logpi:", joint_log_prob(*current_state))

    # Output file
    # A single draw supplies example samples and results to Posterior.
    # NOTE(review): the second tuple element looks like a per-dataset chunk
    # shape - confirm against gemlib.  The "events" entry is hard-coded.
    samples, results, _ = sample(1, current_state)
    posterior = Posterior(
        output_file,
        sample_dict={
            "beta2": (samples[0][:, 0], (NUM_BURST_SAMPLES,)),
            "gamma0": (samples[0][:, 1], (NUM_BURST_SAMPLES,)),
            "gamma1": (samples[0][:, 2], (NUM_BURST_SAMPLES,)),
            "sigma": (samples[0][:, 3], (NUM_BURST_SAMPLES,)),
            "beta3": (samples[0][:, 4:], (NUM_BURST_SAMPLES, 2)),
            "beta1": (samples[1][:, 0], (NUM_BURST_SAMPLES,)),
            "xi": (
                samples[1][:, 1:],
                (NUM_BURST_SAMPLES, samples[1].shape[1] - 1),
            ),
            "events": (samples[2], (NUM_BURST_SAMPLES, 64, 64, 1)),
        },
        results_dict=results,
        num_samples=NUM_SAVED_SAMPLES,
    )
    posterior._file.create_dataset("initial_state", data=initial_state)
    posterior._file.create_dataset(
        "date_range",
        data=np.array(data["date_range"]).astype(h5py.string_dtype()),
    )

    # We loop over successive calls to sample because we have to dump results
    #   to disc, or else end OOM (even on a 32GB system).
    # with tf.profiler.experimental.Profile("/tmp/tf_logdir"):
    final_results = None
    for i in tqdm.tqdm(
        range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES * config["thin"]
    ):
        samples, results, final_results = sample(
            NUM_BURST_SAMPLES,
            init_state=current_state,
            # num_steps_between_results counts the steps skipped between
            # retained draws, so keeping every thin-th step means thin - 1.
            thin=config["thin"] - 1,
            previous_results=final_results,
        )
        # The next burst starts from the last draw of this one.
        current_state = [s[-1] for s in samples]
        print(current_state[0].numpy(), flush=True)

        start = perf_counter()
        posterior.write_samples(
            {
                "beta2": samples[0][:, 0],
                "gamma0": samples[0][:, 1],
                "gamma1": samples[0][:, 2],
                "sigma": samples[0][:, 3],
                "beta3": samples[0][:, 4:],
                "beta1": samples[1][:, 0],
                "xi": samples[1][:, 1:],
                "events": samples[2],
            },
            first_dim_offset=i * NUM_BURST_SAMPLES,
        )
        posterior.write_results(results, first_dim_offset=i * NUM_BURST_SAMPLES)
        end = perf_counter()

        print("Storage time:", end - start, "seconds")
        print("Results type: ", type(results))
        for k, v in results.items():
            print(
                f"Acceptance {k}:",
                tf.reduce_mean(tf.cast(v["is_accepted"], tf.float32)),
            )

    # Overall acceptance rates, read back from the stored results.
    print(
        f"Acceptance theta: {posterior['results/block0/is_accepted'][:].mean()}"
    )
    print(f"Acceptance xi: {posterior['results/block1/is_accepted'][:].mean()}")
    print(
        f"Acceptance move S->E: {posterior['results/move/S->E/is_accepted'][:].mean()}"
    )
    print(
        f"Acceptance move E->I: {posterior['results/move/E->I/is_accepted'][:].mean()}"
    )
    print(
        f"Acceptance occult S->E: {posterior['results/occult/S->E/is_accepted'][:].mean()}"
    )
    print(
        f"Acceptance occult E->I: {posterior['results/occult/E->I/is_accepted'][:].mean()}"
    )

    # Drop the reference to the Posterior object.
    # NOTE(review): presumably this lets it close its backing file - confirm
    # against gemlib.
    del posterior


def parse_args(argv=None):
    """Parses the command line for the MCMC inference script.

    :param argv: list of argument strings; `None` makes argparse read
        `sys.argv[1:]`.
    :returns: namespace with the attributes `config`, `output` and
        `data_file`.
    """
    from argparse import ArgumentParser

    parser = ArgumentParser(description="Run MCMC inference algorithm")
    parser.add_argument(
        "-c", "--config", type=str, help="Config file", required=True
    )
    parser.add_argument(
        "-o", "--output", type=str, help="Output file", required=True
    )
    # Positional arguments are mandatory by default; argparse raises
    # TypeError if `required` is passed for one.
    parser.add_argument("data_file", type=str, help="Data pickle file")
    return parser.parse_args(argv)


if __name__ == "__main__":

    args = parse_args()

    # NOTE(review): FullLoader can construct more than plain YAML types;
    # consider yaml.safe_load if the config file may come from an
    # untrusted source.
    with open(args.config, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # Only the "Mcmc" section of the config is handed to the sampler.
    mcmc(args.data_file, args.output, config["Mcmc"])