"""MCMC Test Rig for COVID-19 UK model"""
# pylint: disable=E402

import os
from time import perf_counter

import tqdm
import yaml
import h5py
import numpy as np

import tensorflow as tf
import tensorflow_probability as tfp

from tensorflow_probability.python.experimental import unnest

from gemlib.util import compute_state
from gemlib.mcmc import UncalibratedEventTimesUpdate
from gemlib.mcmc import UncalibratedOccultUpdate, TransitionTopology
from gemlib.mcmc import GibbsKernel
from gemlib.mcmc.gibbs_kernel import GibbsKernelResults
from gemlib.mcmc.gibbs_kernel import flatten_results
from gemlib.mcmc import MultiScanKernel
from gemlib.mcmc import AdaptiveRandomWalkMetropolis

from covid.data import read_phe_cases
from covid.cli_arg_parse import cli_args

import model_spec

if tf.test.gpu_device_name():
    print("Using GPU")
else:
    print("Using CPU")

tfd = tfp.distributions
tfb = tfp.bijectors

DTYPE = model_spec.DTYPE

if __name__ == "__main__":

    # Read in settings
    args = cli_args()

    with open(args.config, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    inference_period = [
        np.datetime64(x) for x in config["settings"]["inference_period"]
    ]

    covar_data = model_spec.read_covariates(
        config["data"],
        date_low=inference_period[0],
        date_high=inference_period[1],
    )

    # We load in cases and impute missing infections first, since this sets the
    # time epoch which we are analysing.
    cases = read_phe_cases(
        config["data"]["reported_cases"],
        date_low=inference_period[0],
        date_high=inference_period[1],
        date_type=config["data"]["case_date_type"],
        pillar=config["data"]["pillar"],
    ).astype(DTYPE)

    # Impute censored events from the reported cases
    events = model_spec.impute_censored_events(cases)

    # Initial conditions are calculated from the state
    # at the beginning of the inference period.
    #
    # Imputed censored events that pre-date the first I-R events
    # in the cases dataset are discarded.  They are only used
    # to set up a sensible initial state.
    state = compute_state(
        initial_state=tf.concat(
            [covar_data["N"][:, tf.newaxis], tf.zeros_like(events[:, 0, :])],
            axis=-1,
        ),
        events=events,
        stoichiometry=model_spec.STOICHIOMETRY,
    )

    start_time = state.shape[1] - cases.shape[1]
    initial_state = state[:, start_time, :]
    events = events[:, start_time:, :]
    num_metapop = covar_data["N"].shape[0]

    ########################################################
    # Build the model, and then construct the MCMC kernels #
    ########################################################
    def convert_priors(node):
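        """Recursively convert all leaf values of a nested dict of prior settings to float."""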
        if isinstance(node, dict):
            for k, v in node.items():
                node[k] = convert_priors(v)
            return node
        return float(node)

    model = model_spec.CovidUK(
        covariates=covar_data,
        initial_state=initial_state,
        initial_step=0,
        num_steps=events.shape[1],
        priors=convert_priors(config["mcmc"]["prior"]),
    )

    # Full joint log posterior distribution
    # $\pi(\theta, \xi, y^{se}, y^{ei} | y^{ir})$
    def logp(theta, xi, events):
        return model.log_prob(
            dict(
                beta1=xi[0],
                beta2=theta[0],
                gamma=theta[1],
                xi=xi[1:],
                seir=events,
            )
        )

    # Build Metropolis within Gibbs sampler
    #
    # Kernels are:
    #     Q(\theta, \theta^\prime)
    #     Q(\xi, \xi^\prime)
    #     Q(Z^{se}, Z^{se\prime}) (partially-censored)
    #     Q(Z^{ei}, Z^{ei\prime}) (partially-censored)
    #     Q(Z^{se}, Z^{se\prime}) (occult)
    #     Q(Z^{ei}, Z^{ei\prime}) (occult)
    def make_theta_kernel(shape, name):
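        """Adaptive random-walk Metropolis on log(theta), via a TransformedTransitionKernel with an Exp bijector."""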
        def fn(target_log_prob_fn, _):
            return tfp.mcmc.TransformedTransitionKernel(
                inner_kernel=AdaptiveRandomWalkMetropolis(
                    target_log_prob_fn=target_log_prob_fn,
                    initial_covariance=[
                        np.eye(shape[0], dtype=model_spec.DTYPE) * 1e-1
                    ],
                    covariance_burnin=200,
                ),
                bijector=tfp.bijectors.Exp(),
                name=name,
            )

        return fn

    def make_xi_kernel(shape, name):
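        """Adaptive random-walk Metropolis update for the xi coefficients."""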
        def fn(target_log_prob_fn, _):
            return AdaptiveRandomWalkMetropolis(
                target_log_prob_fn=target_log_prob_fn,
                initial_covariance=[
                    np.eye(shape[0], dtype=model_spec.DTYPE) * 1e-1
                ],
                covariance_burnin=200,
                name=name,
            )

        return fn

    def make_partially_observed_step(
        target_event_id, prev_event_id=None, next_event_id=None, name=None
    ):
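        """Metropolis-Hastings update that moves partially-censored event times for the target transition."""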
        def fn(target_log_prob_fn, _):
            return tfp.mcmc.MetropolisHastings(
                inner_kernel=UncalibratedEventTimesUpdate(
                    target_log_prob_fn=target_log_prob_fn,
                    target_event_id=target_event_id,
                    prev_event_id=prev_event_id,
                    next_event_id=next_event_id,
                    initial_state=initial_state,
                    dmax=config["mcmc"]["dmax"],
                    mmax=config["mcmc"]["m"],
                    nmax=config["mcmc"]["nmax"],
                ),
                name=name,
            )

        return fn

    def make_occults_step(prev_event_id, target_event_id, next_event_id, name):
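        """Metropolis-Hastings update that adds or removes occult (unobserved) events within the last 21 time steps."""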
        def fn(target_log_prob_fn, _):
            return tfp.mcmc.MetropolisHastings(
                inner_kernel=UncalibratedOccultUpdate(
                    target_log_prob_fn=target_log_prob_fn,
                    topology=TransitionTopology(
                        prev_event_id, target_event_id, next_event_id
                    ),
                    cumulative_event_offset=initial_state,
                    nmax=config["mcmc"]["occult_nmax"],
                    t_range=(events.shape[1] - 21, events.shape[1]),
                    name=name,
                ),
                name=name,
            )

        return fn

    def make_event_multiscan_kernel(target_log_prob_fn, _):
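        """Scan the event-time and occult updates num_event_time_updates times per Gibbs sweep."""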
        return MultiScanKernel(
            config["mcmc"]["num_event_time_updates"],
            GibbsKernel(
                target_log_prob_fn=target_log_prob_fn,
                kernel_list=[
                    (0, make_partially_observed_step(0, None, 1, "se_events")),
                    (0, make_partially_observed_step(1, 0, 2, "ei_events")),
                    (0, make_occults_step(None, 0, 1, "se_occults")),
                    (0, make_occults_step(0, 1, 2, "ei_occults")),
                ],
                name="gibbs1",
            ),
        )

    # MCMC tracing functions
    def trace_results_fn(_, results):
        """Returns log_prob, accepted, q_ratio"""

        def f(result):
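            # Extract [log_prob, accepted, q_ratio] (plus any kernel-specific
            # extras) from a single kernel's results.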
            proposed_results = unnest.get_innermost(result, "proposed_results")
            log_prob = proposed_results.target_log_prob
            accepted = tf.cast(
                unnest.get_innermost(result, "is_accepted"), log_prob.dtype
            )
            q_ratio = proposed_results.log_acceptance_correction
            if hasattr(proposed_results, "extra"):
                proposed = tf.cast(proposed_results.extra, log_prob.dtype)
                return tf.concat(
                    [[log_prob], [accepted], [q_ratio], proposed], axis=0
                )
            return tf.concat([[log_prob], [accepted], [q_ratio]], axis=0)

        def recurse(f, results):
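            # Walk nested GibbsKernelResults, applying f to each leaf kernel's results.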
            if isinstance(results, GibbsKernelResults):
                return [recurse(f, x) for x in results.inner_results]
            return f(results)

        return recurse(f, results)

    # Build MCMC algorithm here.  This will be run in bursts for memory economy
    @tf.function(autograph=False, experimental_compile=True)
    def sample(n_samples, init_state, previous_results=None):
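        """Draw a burst of n_samples from the Metropolis-within-Gibbs sampler, optionally continuing from previous_results."""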
        with tf.name_scope("main_mcmc_sample_loop"):

            init_state = init_state.copy()

            gibbs_schema = GibbsKernel(
                target_log_prob_fn=logp,
                kernel_list=[
                    (0, make_theta_kernel(init_state[0].shape, "theta")),
                    (1, make_xi_kernel(init_state[1].shape, "xi")),
                    (2, make_event_multiscan_kernel),
                ],
                name="gibbs0",
            )
            samples, results, final_results = tfp.mcmc.sample_chain(
                n_samples,
                init_state,
                kernel=gibbs_schema,
                previous_kernel_results=previous_results,
                return_final_kernel_results=True,
                trace_fn=trace_results_fn,
            )

            return samples, results, final_results

    ####################################
    # Construct bursted MCMC loop here #
    ####################################

    # MCMC Control
    NUM_BURSTS = config["mcmc"]["num_bursts"]
    NUM_BURST_SAMPLES = config["mcmc"]["num_burst_samples"]
    NUM_EVENT_TIME_UPDATES = config["mcmc"]["num_event_time_updates"]
    THIN_BURST_SAMPLES = NUM_BURST_SAMPLES // config["mcmc"]["thin"]
    NUM_SAVED_SAMPLES = THIN_BURST_SAMPLES * NUM_BURSTS

    # Fix the random seed for reproducibility
    tf.random.set_seed(2)

    current_state = [
        np.array([0.65, 0.48], dtype=DTYPE),
        np.zeros(model.model["xi"](0.0).event_shape[-1] + 1, dtype=DTYPE),
        events,
    ]

    # Output Files
    posterior = h5py.File(
        os.path.join(
            os.path.expandvars(config["output"]["results_dir"]),
            config["output"]["posterior"],
        ),
        "w",
        rdcc_nbytes=1024 ** 2 * 400,
        rdcc_nslots=100000,
        libver="latest",
    )
    event_size = [NUM_SAVED_SAMPLES] + list(current_state[2].shape)

    posterior.create_dataset("initial_state", data=initial_state)

    # Ideally we would insert the inference period into the posterior file,
    # as this allows us to attribute it to the data afterwards.  It may be
    # better simply to save the data into it as well.
    posterior.create_dataset("config", data=yaml.dump(config))
    theta_samples = posterior.create_dataset(
        "samples/theta",
        [NUM_SAVED_SAMPLES, current_state[0].shape[0]],
        dtype=np.float64,
    )
    xi_samples = posterior.create_dataset(
        "samples/xi",
        [NUM_SAVED_SAMPLES, current_state[1].shape[0]],
        dtype=np.float64,
    )
    event_samples = posterior.create_dataset(
        "samples/events",
        event_size,
        dtype=DTYPE,
        chunks=(32, 32, 32, 1),
        compression="szip",
        compression_opts=("nn", 16),
    )
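
    # Per-kernel diagnostics traced by trace_results_fn:
    # [log_prob, accepted, q_ratio] plus any kernel-specific extras.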

    output_results = [
        posterior.create_dataset(
            "results/theta", (NUM_SAVED_SAMPLES, 3), dtype=DTYPE,
        ),
        posterior.create_dataset(
            "results/xi", (NUM_SAVED_SAMPLES, 3), dtype=DTYPE,
        ),
        posterior.create_dataset(
            "results/move/S->E",
            (NUM_SAVED_SAMPLES, 3 + num_metapop),
            dtype=DTYPE,
        ),
        posterior.create_dataset(
            "results/move/E->I",
            (NUM_SAVED_SAMPLES, 3 + num_metapop),
            dtype=DTYPE,
        ),
        posterior.create_dataset(
            "results/occult/S->E", (NUM_SAVED_SAMPLES, 6), dtype=DTYPE
        ),
        posterior.create_dataset(
            "results/occult/E->I", (NUM_SAVED_SAMPLES, 6), dtype=DTYPE
        ),
    ]
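    # Single-writer/multiple-reader mode allows other processes to read the file while we write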
    posterior.swmr_mode = True

    print("Initial logpi:", logp(*current_state))

    # We loop over successive calls to sample because we have to dump results
    #   to disc, or else we run out of memory (even on a 32GB system).
    # with tf.profiler.experimental.Profile("/tmp/tf_logdir"):
    final_results = None
    for i in tqdm.tqdm(range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES):
        samples, results, final_results = sample(
            NUM_BURST_SAMPLES,
            init_state=current_state,
            previous_results=final_results,
        )
        current_state = [s[-1] for s in samples]
        s = slice(
            i * THIN_BURST_SAMPLES, i * THIN_BURST_SAMPLES + THIN_BURST_SAMPLES
        )
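        # Thin the burst samples and write the retained draws into this burst's slice of each dataset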
        idx = tf.constant(range(0, NUM_BURST_SAMPLES, config["mcmc"]["thin"]))
        theta_samples[s, ...] = tf.gather(samples[0], idx)
        xi_samples[s, ...] = tf.gather(samples[1], idx)
        # cov = np.cov(
        #     np.log(theta_samples[: (i * NUM_BURST_SAMPLES + NUM_BURST_SAMPLES), ...]),
        #     rowvar=False,
        # )
        print(current_state[0].numpy(), flush=True)
        # print(cov, flush=True)
        # if (i * NUM_BURST_SAMPLES) > 1000 and np.all(np.isfinite(cov)):
        #     theta_scale = 2.38 ** 2 * cov / 2.0

        start = perf_counter()
        event_samples[s, ...] = tf.gather(samples[2], idx)
        end = perf_counter()

        flat_results = flatten_results(results)
        for j, ro in enumerate(output_results):
            ro[s, ...] = tf.gather(flat_results[j], idx)

        posterior.flush()
        print("Storage time:", end - start, "seconds")
        print(
            "Acceptance theta:",
            tf.reduce_mean(tf.cast(flat_results[0][:, 1], tf.float32)),
        )
        print(
            "Acceptance xi:",
            tf.reduce_mean(tf.cast(flat_results[1][:, 1], tf.float32)),
        )
        print(
            "Acceptance move S->E:",
            tf.reduce_mean(tf.cast(flat_results[2][:, 1], tf.float32)),
        )
        print(
            "Acceptance move E->I:",
            tf.reduce_mean(tf.cast(flat_results[3][:, 1], tf.float32)),
        )
        print(
            "Acceptance occult S->E:",
            tf.reduce_mean(tf.cast(flat_results[4][:, 1], tf.float32)),
        )
        print(
            "Acceptance occult E->I:",
            tf.reduce_mean(tf.cast(flat_results[5][:, 1], tf.float32)),
        )
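
    # Overall acceptance rates for the whole run (column 1 of each results dataset is the acceptance flag)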

    print(f"Acceptance theta: {output_results[0][:, 1].mean()}")
    print(f"Acceptance xi: {output_results[1][:, 1].mean()}")
    print(f"Acceptance move S->E: {output_results[2][:, 1].mean()}")
    print(f"Acceptance move E->I: {output_results[3][:, 1].mean()}")
    print(f"Acceptance occult S->E: {output_results[4][:, 1].mean()}")
    print(f"Acceptance occult E->I: {output_results[5][:, 1].mean()}")

    posterior.close()