"""MCMC Test Rig for COVID-19 UK model"""

# pylint: disable=E402

import os
from time import perf_counter

import tqdm
import yaml
import numpy as np

import tensorflow as tf
import tensorflow_probability as tfp

from gemlib.util import compute_state
from gemlib.mcmc import UncalibratedEventTimesUpdate
from gemlib.mcmc import UncalibratedOccultUpdate, TransitionTopology
from gemlib.mcmc import GibbsKernel
from gemlib.mcmc import MultiScanKernel
from gemlib.mcmc import AdaptiveRandomWalkMetropolis
from gemlib.mcmc import Posterior

from covid.data import read_phe_cases
from covid.cli_arg_parse import cli_args

import model_spec

# Report whether TensorFlow can see a GPU device.
if tf.test.gpu_device_name():
    print("Using GPU")
else:
    print("Using CPU")

tfd = tfp.distributions
tfb = tfp.bijectors
DTYPE = model_spec.DTYPE


if __name__ == "__main__":

    # Read in settings
    args = cli_args()

    # NOTE(review): ``yaml.FullLoader`` is retained for compatibility with
    # existing config files; prefer ``yaml.safe_load`` if a config may ever
    # come from an untrusted source.
    with open(args.config, "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # First and last dates of the inference window.
    inference_period = [
        np.datetime64(x) for x in config["Global"]["inference_period"]
    ]

    covar_data = model_spec.read_covariates(config)

    # We load in cases and impute missing infections first, since this sets the
    # time epoch which we are analysing.
    cases = read_phe_cases(
        config["data"]["reported_cases"],
        date_low=inference_period[0],
        date_high=inference_period[1],
        date_type=config["data"]["case_date_type"],
        pillar=config["data"]["pillar"],
    ).astype(DTYPE)

    # Impute censored events from the case data.
    events = model_spec.impute_censored_events(cases)

    # Initial conditions are calculated by calculating the state
    # at the beginning of the inference period.
    #
    # Imputed censored events that pre-date the first I-R events
    # in the cases dataset are discarded.  They are only used
    # to set up a sensible initial state.
    #
    # The starting state puts the population N in the first compartment
    # and zeros in the remaining ones, then ``compute_state`` rolls it
    # forward through the imputed events.
    state = compute_state(
        initial_state=tf.concat(
            [covar_data["N"][:, tf.newaxis], tf.zeros_like(events[:, 0, :])],
            axis=-1,
        ),
        events=events,
        stoichiometry=model_spec.STOICHIOMETRY,
    )
    # Axis 1 of ``state``/``events`` is treated as time: skip the leading
    # imputed steps so the retained window has as many steps as ``cases``.
    start_time = state.shape[1] - cases.shape[1]
    initial_state = state[:, start_time, :]
    events = events[:, start_time:, :]

    ########################################################
    # Build the model, and then construct the MCMC kernels #
    ########################################################

    def convert_priors(node):
        """Recursively coerce every leaf of a nested prior spec to ``float``.

        Nested dictionaries are updated in place and the same dictionary
        object is returned; any non-dict value is passed through ``float``
        (e.g. numeric strings become floats).
        """
        if not isinstance(node, dict):
            return float(node)
        for key in node:
            # Only values of existing keys are replaced, so iterating the
            # dict while assigning is safe.
            node[key] = convert_priors(node[key])
        return node
    # Construct the model over the truncated event window.  The prior
    # hyperparameters from the config are coerced to floats first (this
    # also updates config["mcmc"]["prior"] in place).
    model_priors = convert_priors(config["mcmc"]["prior"])
    model = model_spec.CovidUK(
        covariates=covar_data,
        initial_state=initial_state,
        initial_step=0,
        num_steps=events.shape[1],
        priors=model_priors,
    )
    # Full joint log posterior distribution
    # $\pi(\theta, \xi, y^{se}, y^{ei} | y^{ir})$
    def logp(block0, block1, events):
        """Evaluate the model's joint log density at the unpacked state.

        Args:
            block0: vector laid out as ``[beta2, gamma0, gamma1, sigma,
                beta3...]``; every entry from index 4 onward is ``beta3``.
            block1: vector laid out as ``[beta1, xi...]``; every entry from
                index 1 onward is ``xi``.
            events: event tensor, passed to the model as ``seir``.

        Returns:
            ``model.log_prob`` of the assembled parameter dictionary.
        """
        parameters = {
            "beta2": block0[0],
            "gamma0": block0[1],
            "gamma1": block0[2],
            "sigma": block0[3],
            "beta3": block0[4:],
            "beta1": block1[0],
            "xi": block1[1:],
            "seir": events,
        }
        return model.log_prob(parameters)
    # Build Metropolis within Gibbs sampler
    #
    # Kernels are:
    #     Q(\theta, \theta^\prime)
    #     Q(\xi, \xi^\prime)
    #     Q(Z^{se}, Z^{se\prime}) (partially-censored)
    #     Q(Z^{ei}, Z^{ei\prime}) (partially-censored)
    #     Q(Z^{se}, Z^{se\prime}) (occult)
    #     Q(Z^{ei}, Z^{ei\prime}) (occult)
    def make_blk0_kernel(shape, name):
        """Return a Gibbs-step factory for the ``block0`` parameter vector.

        The factory ``fn(target_log_prob_fn, _)`` builds an adaptive
        random-walk Metropolis kernel with initial covariance ``0.1 * I`` and
        200 covariance burn-in steps.  The kernel is wrapped in a blockwise
        bijector with block sizes ``[1, 2, 1, 4]``.  Following ``logp``'s
        layout, ``beta2`` and ``sigma`` go through ``Exp`` (so they are
        sampled on the log scale).  ``gamma0``, ``gamma1`` and the four
        ``beta3`` entries go through ``Identity``.

        Args:
            shape: shape of the block0 state; ``shape[0]`` sizes the
                proposal covariance.
            name: name of the resulting transformed kernel.
        """

        def fn(target_log_prob_fn, _):
            rwm_kernel = AdaptiveRandomWalkMetropolis(
                target_log_prob_fn=target_log_prob_fn,
                initial_covariance=np.eye(shape[0], dtype=model_spec.DTYPE)
                * 1e-1,
                covariance_burnin=200,
            )
            # Block order: beta2 (1), gamma0/gamma1 (2), sigma (1), beta3 (4).
            to_constrained = tfp.bijectors.Blockwise(
                bijectors=[
                    tfp.bijectors.Exp(),
                    tfp.bijectors.Identity(),
                    tfp.bijectors.Exp(),
                    tfp.bijectors.Identity(),
                ],
                block_sizes=[1, 2, 1, 4],
            )
            return tfp.mcmc.TransformedTransitionKernel(
                inner_kernel=rwm_kernel,
                bijector=to_constrained,
                name=name,
            )

        return fn
    def make_blk1_kernel(shape, name):
        """Return a Gibbs-step factory for the ``block1`` (beta1, xi) vector.

        The factory ``fn(target_log_prob_fn, _)`` builds an unconstrained
        adaptive random-walk Metropolis kernel with initial covariance
        ``0.1 * I`` of size ``shape[0]`` and 200 covariance burn-in steps.
        """

        def fn(target_log_prob_fn, _):
            initial_cov = 1e-1 * np.eye(shape[0], dtype=model_spec.DTYPE)
            return AdaptiveRandomWalkMetropolis(
                target_log_prob_fn=target_log_prob_fn,
                initial_covariance=initial_cov,
                covariance_burnin=200,
                name=name,
            )

        return fn
    def make_partially_observed_step(
        target_event_id, prev_event_id=None, next_event_id=None, name=None
    ):
        """Return a Gibbs-step factory that moves partially-censored events.

        The factory wraps ``UncalibratedEventTimesUpdate`` in a
        Metropolis-Hastings correction.  The update targets event type
        ``target_event_id``, bounded by ``prev_event_id``/``next_event_id``.
        Proposal tuning (``dmax``, ``m`` -> ``mmax``, ``nmax``) is read from
        ``config["mcmc"]``, and the script-level ``initial_state`` is passed
        through.
        """

        def fn(target_log_prob_fn, _):
            mcmc_cfg = config["mcmc"]
            proposal = UncalibratedEventTimesUpdate(
                target_log_prob_fn=target_log_prob_fn,
                target_event_id=target_event_id,
                prev_event_id=prev_event_id,
                next_event_id=next_event_id,
                initial_state=initial_state,
                dmax=mcmc_cfg["dmax"],
                mmax=mcmc_cfg["m"],
                nmax=mcmc_cfg["nmax"],
            )
            return tfp.mcmc.MetropolisHastings(inner_kernel=proposal, name=name)

        return fn
    def make_occults_step(
        prev_event_id, target_event_id, next_event_id, name, occult_window=21
    ):
        """Return a Gibbs-step factory that updates occult events.

        The factory wraps ``UncalibratedOccultUpdate`` in a
        Metropolis-Hastings correction, using the transition topology
        ``(prev_event_id, target_event_id, next_event_id)``.  The update is
        given ``t_range = (T - occult_window, T)``, where ``T`` is
        ``events.shape[1]``.  ``nmax`` comes from
        ``config["mcmc"]["occult_nmax"]``.

        Args:
            prev_event_id: id of the preceding event type, or ``None``.
            target_event_id: id of the event type being updated.
            next_event_id: id of the following event type.
            name: name given to both the proposal and the MH kernel.
            occult_window: width of ``t_range``, counted back from ``T``.
                Defaults to 21, the value previously hard-coded here.
        """

        def fn(target_log_prob_fn, _):
            num_steps = events.shape[1]
            return tfp.mcmc.MetropolisHastings(
                inner_kernel=UncalibratedOccultUpdate(
                    target_log_prob_fn=target_log_prob_fn,
                    topology=TransitionTopology(
                        prev_event_id, target_event_id, next_event_id
                    ),
                    cumulative_event_offset=initial_state,
                    nmax=config["mcmc"]["occult_nmax"],
                    t_range=(num_steps - occult_window, num_steps),
                    name=name,
                ),
                name=name,
            )

        return fn
    def make_event_multiscan_kernel(target_log_prob_fn, _):
        """Return the kernel used for the ``events`` Gibbs block.

        The kernel is an inner Gibbs sweep ("gibbs1") over four moves, in
        order:
          * the S->E partially-observed step;
          * the E->I partially-observed step;
          * the S->E occult step;
          * the E->I occult step.

        ``MultiScanKernel`` repeats the sweep
        ``config["mcmc"]["num_event_time_updates"]`` times.
        """
        event_moves = [
            make_partially_observed_step(0, None, 1, "se_events"),
            make_partially_observed_step(1, 0, 2, "ei_events"),
            make_occults_step(None, 0, 1, "se_occults"),
            make_occults_step(0, 1, 2, "ei_occults"),
        ]
        # All four moves act on state index 0 of the inner Gibbs kernel --
        # presumably the events tensor, the only state this block sees.
        inner_gibbs = GibbsKernel(
            target_log_prob_fn=target_log_prob_fn,
            kernel_list=[(0, move) for move in event_moves],
            name="gibbs1",
        )
        return MultiScanKernel(
            config["mcmc"]["num_event_time_updates"], inner_gibbs
        )
    # MCMC tracing functions
    def trace_results_fn(_, results):
        """Packs results into a dictionary.

        The returned dict has the keys ``block0``, ``block1``,
        ``move/S->E``, ``move/E->I``, ``occult/S->E`` and ``occult/E->I``.
        Each value holds ``is_accepted`` and ``target_log_prob``.  The four
        event moves additionally hold ``proposed_delta``, a stack of the
        accepted ``m``, ``t``, ``delta_t`` and ``x_star``.
        """

        def _acceptance(kernel_results):
            # Acceptance flag plus accepted target log-prob of an MH-style
            # kernel result.
            return {
                "is_accepted": kernel_results.is_accepted,
                "target_log_prob": (
                    kernel_results.accepted_results.target_log_prob
                ),
            }

        def _event_move(kernel_results):
            # Acceptance summary extended with the proposal's move details.
            summary = _acceptance(kernel_results)
            accepted = kernel_results.accepted_results
            summary["proposed_delta"] = tf.stack(
                [accepted.m, accepted.t, accepted.delta_t, accepted.x_star]
            )
            return summary

        gibbs_parts = results.inner_results
        # block0's adaptive RWM sits inside a TransformedTransitionKernel.
        traced = {
            "block0": _acceptance(gibbs_parts[0].inner_results),
            "block1": _acceptance(gibbs_parts[1]),
        }

        event_parts = gibbs_parts[2].inner_results
        move_labels = ("move/S->E", "move/E->I", "occult/S->E", "occult/E->I")
        for idx, label in enumerate(move_labels):
            traced[label] = _event_move(event_parts[idx])

        return traced

    # Build MCMC algorithm here.  This will be run in bursts for memory economy
    @tf.function(autograph=False, experimental_compile=True)
    def sample(n_samples, init_state, thin=0, previous_results=None):
        """Run one burst of the Metropolis-within-Gibbs chain.

        Args:
            n_samples: number of retained samples.
            init_state: list ``[block0, block1, events]`` of starting values.
            thin: passed to ``sample_chain`` as
                ``num_steps_between_results``.
            previous_results: kernel results from an earlier burst, passed
                as ``previous_kernel_results``; ``None`` on the first call.

        Returns:
            ``(samples, results, final_results)``: the retained states, the
            trace produced by ``trace_results_fn``, and the final kernel
            results.
        """
        with tf.name_scope("main_mcmc_sample_loop"):
            # Shallow copy of the caller's state list.
            current = init_state.copy()

            kernel_list = [
                (0, make_blk0_kernel(current[0].shape, "block0")),
                (1, make_blk1_kernel(current[1].shape, "block1")),
                (2, make_event_multiscan_kernel),
            ]
            gibbs_schema = GibbsKernel(
                target_log_prob_fn=logp,
                kernel_list=kernel_list,
                name="gibbs0",
            )

            samples, results, final_results = tfp.mcmc.sample_chain(
                num_results=n_samples,
                current_state=current,
                kernel=gibbs_schema,
                num_steps_between_results=thin,
                previous_kernel_results=previous_results,
                return_final_kernel_results=True,
                trace_fn=trace_results_fn,
            )

        return samples, results, final_results

    ####################################
    # Construct bursted MCMC loop here #
    ####################################

    # MCMC control: burst layout, read from the "mcmc" config section.
    mcmc_settings = config["mcmc"]
    NUM_BURSTS = int(mcmc_settings["num_bursts"])
    NUM_BURST_SAMPLES = int(mcmc_settings["num_burst_samples"])
    # NOTE(review): not referenced elsewhere in this script; the event-time
    # kernel reads config["mcmc"]["num_event_time_updates"] directly.
    NUM_EVENT_TIME_UPDATES = int(mcmc_settings["num_event_time_updates"])
    # Total number of samples the output file is sized for.
    NUM_SAVED_SAMPLES = NUM_BURSTS * NUM_BURST_SAMPLES

    # Fix TensorFlow's global random seed.
    tf.random.set_seed(2)

    # Starting values for the three Gibbs blocks (layouts as unpacked in
    # ``logp``):
    #   block0 = [beta2, gamma0, gamma1, sigma, beta3 x 4]
    #   block1 = [beta1, xi...], all zeros; xi is sized from the event shape
    #            of ``model.model["xi"]``
    #   events = the truncated imputed event tensor
    xi_event_size = model.model["xi"](0.0, 0.1).event_shape[-1]
    block0_init = np.array(
        [0.6, 0.0, 0.0, 0.1, 0.0, 0.0, 0.0, 0.0], dtype=DTYPE
    )
    block1_init = np.zeros(xi_event_size + 1, dtype=DTYPE)  # +1 for beta1
    current_state = [block0_init, block1_init, events]
    print("Initial logpi:", logp(*current_state))
    # Output file.  A single-sample call to ``sample`` yields example
    # samples/results, which are used to lay out the ``Posterior`` storage.
    samples, results, _ = sample(1, current_state)
    theta_draws, block1_draws = samples[0], samples[1]
    posterior_path = os.path.join(
        os.path.expandvars(config["output"]["results_dir"]),
        config["output"]["posterior"],
    )
    # NOTE(review): "beta3" declares (NUM_BURST_SAMPLES, 2) although
    # block0[4:] holds 4 columns -- confirm this matches what Posterior
    # expects for the second tuple element (chunk shape?).
    posterior = Posterior(
        posterior_path,
        sample_dict={
            "beta2": (theta_draws[:, 0], (NUM_BURST_SAMPLES,)),
            "gamma0": (theta_draws[:, 1], (NUM_BURST_SAMPLES,)),
            "gamma1": (theta_draws[:, 2], (NUM_BURST_SAMPLES,)),
            "sigma": (theta_draws[:, 3], (NUM_BURST_SAMPLES,)),
            "beta3": (theta_draws[:, 4:], (NUM_BURST_SAMPLES, 2)),
            "beta1": (block1_draws[:, 0], (NUM_BURST_SAMPLES,)),
            "xi": (
                block1_draws[:, 1:],
                (NUM_BURST_SAMPLES, block1_draws.shape[1] - 1),
            ),
            "events": (samples[2], (NUM_BURST_SAMPLES, 64, 64, 1)),
        },
        results_dict=results,
        num_samples=NUM_SAVED_SAMPLES,
    )
    # Store the starting state and the config (whose priors were converted
    # to floats in place) next to the samples.
    # NOTE(review): reaches into the private ``_file`` attribute of Posterior.
    posterior._file.create_dataset("initial_state", data=initial_state)
    posterior._file.create_dataset("config", data=yaml.dump(config))

    # We loop over successive calls to sample because we have to dump results
    #   to disc, or else end OOM (even on a 32GB system).
    # with tf.profiler.experimental.Profile("/tmp/tf_logdir"):

    # ``sample`` receives ``thin - 1`` as num_steps_between_results, so
    # anything below 1 would be an invalid (negative) step count.
    thin = int(config["mcmc"]["thin"])
    if thin < 1:
        raise ValueError(
            f"config['mcmc']['thin'] must be a positive integer, got {thin}"
        )

    # (printed label, key in the trace produced by trace_results_fn)
    acceptance_keys = (
        ("theta", "block0"),
        ("xi", "block1"),
        ("move S->E", "move/S->E"),
        ("move E->I", "move/E->I"),
        ("occult S->E", "occult/S->E"),
        ("occult E->I", "occult/E->I"),
    )

    final_results = None
    for i in tqdm.tqdm(range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES * thin):
        samples, results, final_results = sample(
            NUM_BURST_SAMPLES,
            init_state=current_state,
            thin=thin - 1,
            previous_results=final_results,
        )
        # The next burst resumes from the last retained state.
        current_state = [s[-1] for s in samples]
        print(current_state[0].numpy(), flush=True)

        start = perf_counter()
        posterior.write_samples(
            {
                "beta2": samples[0][:, 0],
                "gamma0": samples[0][:, 1],
                "gamma1": samples[0][:, 2],
                "sigma": samples[0][:, 3],
                "beta3": samples[0][:, 4:],
                "beta1": samples[1][:, 0],
                "xi": samples[1][:, 1:],
                "events": samples[2],
            },
            first_dim_offset=i * NUM_BURST_SAMPLES,
        )
        posterior.write_results(results, first_dim_offset=i * NUM_BURST_SAMPLES)
        end = perf_counter()

        print("Storage time:", end - start, "seconds")
        # Per-burst acceptance rate of each Gibbs component.
        for label, key in acceptance_keys:
            print(
                f"Acceptance {label}:",
                tf.reduce_mean(tf.cast(results[key]["is_accepted"], tf.float32)),
            )
    # Whole-run acceptance rates, read back from the stored results.
    for label, key in (
        ("theta", "block0"),
        ("xi", "block1"),
        ("move S->E", "move/S->E"),
        ("move E->I", "move/E->I"),
        ("occult S->E", "occult/S->E"),
        ("occult E->I", "occult/E->I"),
    ):
        rate = posterior[f"results/{key}/is_accepted"][:].mean()
        print(f"Acceptance {label}: {rate}")

    # Drop the last reference to the posterior store (presumably this lets
    # it flush and close its file -- confirm against Posterior).
    del posterior