"""MCMC Test Rig for COVID-19 UK model"""

# pylint: disable=E402

Chris Jewell's avatar
Chris Jewell committed
4
import os
5
from time import perf_counter
6
7
import tqdm
import yaml
8
import numpy as np
9
10
import tensorflow as tf
import tensorflow_probability as tfp
Chris Jewell's avatar
Chris Jewell committed
11

12
13
14
15
16
17
from gemlib.util import compute_state
from gemlib.mcmc import UncalibratedEventTimesUpdate
from gemlib.mcmc import UncalibratedOccultUpdate, TransitionTopology
from gemlib.mcmc import GibbsKernel
from gemlib.mcmc import MultiScanKernel
from gemlib.mcmc import AdaptiveRandomWalkMetropolis
Chris Jewell's avatar
Chris Jewell committed
18
from gemlib.mcmc import Posterior
19

20
from covid.data import read_phe_cases
Chris Jewell's avatar
Chris Jewell committed
21
from covid.cli_arg_parse import cli_args
22

23
import model_spec
24

25
26
27
28
29
# Report which device TensorFlow found, so run logs show where the chain ran.
if tf.test.gpu_device_name():
    print("Using GPU")
else:
    print("Using CPU")

# Short aliases for the TensorFlow Probability namespaces.
tfd = tfp.distributions
tfb = tfp.bijectors
# Floating-point dtype shared with the model specification.
DTYPE = model_spec.DTYPE
34
35
36
37
38


if __name__ == "__main__":

    # Read in settings
Chris Jewell's avatar
Chris Jewell committed
39
    args = cli_args()

    # NOTE(review): yaml.FullLoader can build more than plain scalars/maps.
    # If the config file may come from an untrusted source, prefer
    # yaml.safe_load -- confirm the config uses no custom tags first.
    with open(args.config, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # [start, end] of the inference window as numpy datetimes.
    inference_period = [
        np.datetime64(x) for x in config["settings"]["inference_period"]
    ]

    covar_data = model_spec.read_covariates(
        config["data"],
        date_low=inference_period[0],
        date_high=inference_period[1],
    )

    # We load in cases and impute missing infections first, since this sets the
    # time epoch which we are analysing.
    cases = read_phe_cases(
        config["data"]["reported_cases"],
        date_low=inference_period[0],
        date_high=inference_period[1],
        date_type=config["data"]["case_date_type"],
        pillar=config["data"]["pillar"],
    ).astype(DTYPE)

    # Impute censored events, return cases
    events = model_spec.impute_censored_events(cases)

    # Initial conditions are calculated by calculating the state
    # at the beginning of the inference period
    #
    # Imputed censored events that pre-date the first I-R events
    # in the cases dataset are discarded.  They are only used
    # to set up a sensible initial state.
    state = compute_state(
        initial_state=tf.concat(
            [covar_data["N"][:, tf.newaxis], tf.zeros_like(events[:, 0, :])],
            axis=-1,
        ),
        events=events,
        stoichiometry=model_spec.STOICHIOMETRY,
    )
    # The imputed event tensor is longer (along axis 1) than the observed
    # cases; the difference is the index at which the inference period starts.
    start_time = state.shape[1] - cases.shape[1]
    initial_state = state[:, start_time, :]
    events = events[:, start_time:, :]
    num_metapop = covar_data["N"].shape[0]

    ########################################################
    # Build the model, and then construct the MCMC kernels #
    ########################################################
89
90
91
92
93
94
95
    def convert_priors(node):
        """Coerce every leaf of a (possibly nested) prior spec to ``float``.

        Dictionaries are walked recursively and updated *in place*; the same
        dictionary object is returned.  Any non-dict value is passed through
        ``float()``, so numeric strings such as ``"1e-3"`` become floats.
        """
        if not isinstance(node, dict):
            return float(node)
        # Re-assigning existing keys does not resize the dict, so iterating
        # over it while updating values is safe.
        for key in node:
            node[key] = convert_priors(node[key])
        return node

96
97
98
99
100
    # Joint model over parameters and events, starting from the state at the
    # beginning of the inference period and spanning every retained timestep.
    model = model_spec.CovidUK(
        covariates=covar_data,
        initial_state=initial_state,
        initial_step=0,
        num_steps=events.shape[1],
        priors=convert_priors(config["mcmc"]["prior"]),
    )
103

104
105
    # Full joint log posterior distribution
    # $\pi(\theta, \xi, y^{se}, y^{ei} | y^{ir})$
106
    def logp(block0, block1, events):
        """Evaluate the joint log density of the model at one chain state.

        Args:
            block0: indexable parameter block; elements 0, 1, 2 are passed to
                the model as ``beta2``, ``gamma0`` and ``gamma1``.
            block1: indexable parameter block; element 0 is ``beta1`` and the
                remaining elements are ``xi``.
            events: event tensor passed to the model as ``seir``.

        Returns:
            Whatever ``model.log_prob`` returns for this state.
        """
        return model.log_prob(
            dict(
                beta2=block0[0],
                gamma0=block0[1],
                gamma1=block0[2],
                beta1=block1[0],
                # beta3 is currently held fixed at zero.  To sample it, take
                # it from block1[1:4] and move xi to block1[4:].
                beta3=[0.0, 0.0, 0.0],
                xi=block1[1:],
                seir=events,
            )
        )
118

119
120
121
122
123
124
125
126
127
    # Build Metropolis within Gibbs sampler
    #
    # Kernels are:
    #     Q(\theta, \theta^\prime)
    #     Q(\xi, \xi^\prime)
    #     Q(Z^{se}, Z^{se\prime}) (partially-censored)
    #     Q(Z^{ei}, Z^{ei\prime}) (partially-censored)
    #     Q(Z^{se}, Z^{se\prime}) (occult)
    #     Q(Z^{ei}, Z^{ei\prime}) (occult)
128
    def make_blk0_kernel(shape, name):
        """Return a kernel factory for the first parameter block.

        Args:
            shape: shape of the block's state; ``shape[0]`` sets the size of
                the initial proposal covariance.  The bijector below is
                hard-wired to block sizes ``[1, 2]``, i.e. three elements.
            name: name given to the resulting kernel.

        Returns:
            A function ``fn(target_log_prob_fn, _)`` that builds an adaptive
            random-walk Metropolis kernel wrapped in a
            ``TransformedTransitionKernel``.
        """

        def fn(target_log_prob_fn, _):
            return tfp.mcmc.TransformedTransitionKernel(
                inner_kernel=AdaptiveRandomWalkMetropolis(
                    target_log_prob_fn=target_log_prob_fn,
                    # Start from a diagonal covariance of 0.1.
                    initial_covariance=np.eye(shape[0], dtype=model_spec.DTYPE)
                    * 1e-1,
                    covariance_burnin=200,
                ),
                # Exp on the first element, Identity on the remaining two.
                bijector=tfp.bijectors.Blockwise(
                    bijectors=[tfp.bijectors.Exp(), tfp.bijectors.Identity()],
                    block_sizes=[1, 2],
                ),
                name=name,
            )

        return fn

146
    def make_blk1_kernel(shape, name):
        """Return a kernel factory for the second parameter block.

        Args:
            shape: shape of the block's state; ``shape[0]`` sets the size of
                the initial proposal covariance.
            name: name given to the resulting kernel.

        Returns:
            A function ``fn(target_log_prob_fn, _)`` that builds an
            ``AdaptiveRandomWalkMetropolis`` kernel (no bijector is applied).
        """

        def fn(target_log_prob_fn, _):
            return AdaptiveRandomWalkMetropolis(
                target_log_prob_fn=target_log_prob_fn,
                # Start from a diagonal covariance of 0.1.
                initial_covariance=np.eye(shape[0], dtype=model_spec.DTYPE)
                * 1e-1,
                covariance_burnin=200,
                name=name,
            )

        return fn
Chris Jewell's avatar
Chris Jewell committed
157

158
159
160
    def make_partially_observed_step(
        target_event_id, prev_event_id=None, next_event_id=None, name=None
    ):
        """Return a kernel factory for an event-times update.

        Args:
            target_event_id: id of the transition whose event times are moved.
            prev_event_id: id of the preceding transition, or ``None``.
            next_event_id: id of the following transition, or ``None``.
            name: name given to the Metropolis-Hastings kernel.

        Returns:
            A function ``fn(target_log_prob_fn, _)`` that builds a
            ``MetropolisHastings`` kernel around an
            ``UncalibratedEventTimesUpdate`` proposal.
        """

        def fn(target_log_prob_fn, _):
            return tfp.mcmc.MetropolisHastings(
                inner_kernel=UncalibratedEventTimesUpdate(
                    target_log_prob_fn=target_log_prob_fn,
                    target_event_id=target_event_id,
                    prev_event_id=prev_event_id,
                    next_event_id=next_event_id,
                    initial_state=initial_state,
                    # Proposal tuning limits come from the config file; their
                    # exact semantics are defined by gemlib.
                    dmax=config["mcmc"]["dmax"],
                    mmax=config["mcmc"]["m"],
                    nmax=config["mcmc"]["nmax"],
                ),
                name=name,
            )

        return fn
177

178
    def make_occults_step(prev_event_id, target_event_id, next_event_id, name):
        """Return a kernel factory for an occult-event update.

        Args:
            prev_event_id: id of the preceding transition, or ``None``.
            target_event_id: id of the transition whose occults are updated.
            next_event_id: id of the following transition.
            name: name given to both the proposal and the wrapping kernel.

        Returns:
            A function ``fn(target_log_prob_fn, _)`` that builds a
            ``MetropolisHastings`` kernel around an
            ``UncalibratedOccultUpdate`` proposal.
        """

        def fn(target_log_prob_fn, _):
            return tfp.mcmc.MetropolisHastings(
                inner_kernel=UncalibratedOccultUpdate(
                    target_log_prob_fn=target_log_prob_fn,
                    topology=TransitionTopology(
                        prev_event_id, target_event_id, next_event_id
                    ),
                    cumulative_event_offset=initial_state,
                    nmax=config["mcmc"]["occult_nmax"],
                    # Restrict the update to the final 21 timesteps.
                    t_range=(events.shape[1] - 21, events.shape[1]),
                    name=name,
                ),
                name=name,
            )

        return fn

Chris Jewell's avatar
Chris Jewell committed
196
    def make_event_multiscan_kernel(target_log_prob_fn, _):
        """Build the kernel that updates the event tensor.

        A ``GibbsKernel`` over four steps (S->E and E->I event-time moves,
        then S->E and E->I occult updates) is wrapped in a ``MultiScanKernel``
        configured with ``config["mcmc"]["num_event_time_updates"]``.

        Args:
            target_log_prob_fn: log density handed down by the outer kernel.
            _: unused second argument of the kernel-factory signature.

        Returns:
            The assembled ``MultiScanKernel``.
        """
        return MultiScanKernel(
            config["mcmc"]["num_event_time_updates"],
            GibbsKernel(
                target_log_prob_fn=target_log_prob_fn,
                # Every step targets index 0 of the state seen by this inner
                # Gibbs kernel.
                kernel_list=[
                    (0, make_partially_observed_step(0, None, 1, "se_events")),
                    (0, make_partially_observed_step(1, 0, 2, "ei_events")),
                    (0, make_occults_step(None, 0, 1, "se_occults")),
                    (0, make_occults_step(0, 1, 2, "ei_occults")),
                ],
                name="gibbs1",
            ),
        )

211
212
    # MCMC tracing functions
    def trace_results_fn(_, results):
        """Packs results into a dictionary

        Args:
            _: current chain state (unused).
            results: kernel results of the outer Gibbs kernel.

        Returns:
            A dict keyed by ``"block0"``, ``"block1"``, ``"move/S->E"``,
            ``"move/E->I"``, ``"occult/S->E"`` and ``"occult/E->I"``, each
            holding ``is_accepted`` and ``target_log_prob`` (the move/occult
            entries also hold ``proposed_delta``).
        """
        results_dict = {}
        # NOTE(review): indices 0/1/2 presumably follow the order of the outer
        # kernel_list built in `sample` -- confirm against GibbsKernel.
        res0 = results.inner_results

        # block0 is wrapped in a TransformedTransitionKernel, so its
        # Metropolis results sit one level further down.
        results_dict["block0"] = {
            "is_accepted": res0[0].inner_results.is_accepted,
            "target_log_prob": res0[
                0
            ].inner_results.accepted_results.target_log_prob,
        }
        results_dict["block1"] = {
            "is_accepted": res0[1].is_accepted,
            "target_log_prob": res0[1].accepted_results.target_log_prob,
        }

        def get_move_results(results):
            """Extract acceptance, log-prob and move details for one step."""
            return {
                "is_accepted": results.is_accepted,
                "target_log_prob": results.accepted_results.target_log_prob,
                "proposed_delta": tf.stack(
                    [
                        results.accepted_results.m,
                        results.accepted_results.t,
                        results.accepted_results.delta_t,
                        results.accepted_results.x_star,
                    ]
                ),
            }

        # Results of the event kernel; entries 0..3 are read in the same
        # order as the steps listed in make_event_multiscan_kernel.
        res1 = res0[2].inner_results
        results_dict["move/S->E"] = get_move_results(res1[0])
        results_dict["move/E->I"] = get_move_results(res1[1])
        results_dict["occult/S->E"] = get_move_results(res1[2])
        results_dict["occult/E->I"] = get_move_results(res1[3])

        return results_dict
249
250

    # Build MCMC algorithm here.  This will be run in bursts for memory economy
Chris Jewell's avatar
Chris Jewell committed
251
252
    @tf.function(autograph=False, experimental_compile=True)
    def sample(n_samples, init_state, thin=0, previous_results=None):
        """Run one burst of the Metropolis-within-Gibbs sampler.

        Args:
            n_samples: number of samples to return.
            init_state: list ``[block0, block1, events]`` to start from.
            thin: passed to ``sample_chain`` as ``num_steps_between_results``.
            previous_results: kernel results from a previous burst, or
                ``None`` to start afresh.

        Returns:
            Tuple ``(samples, results, final_results)``: the sampled states,
            the output of ``trace_results_fn`` for each sample, and the final
            kernel results (to be fed back in as ``previous_results``).
        """
        with tf.name_scope("main_mcmc_sample_loop"):

            # Shallow-copy so the list handed to TFP is not the caller's.
            init_state = init_state.copy()

            gibbs_schema = GibbsKernel(
                target_log_prob_fn=logp,
                kernel_list=[
                    (0, make_blk0_kernel(init_state[0].shape, "block0")),
                    (1, make_blk1_kernel(init_state[1].shape, "block1")),
                    (2, make_event_multiscan_kernel),
                ],
                name="gibbs0",
            )

            samples, results, final_results = tfp.mcmc.sample_chain(
                n_samples,
                init_state,
                kernel=gibbs_schema,
                num_steps_between_results=thin,
                previous_kernel_results=previous_results,
                return_final_kernel_results=True,
                trace_fn=trace_results_fn,
            )

            return samples, results, final_results
278
279
280
281
282
283

    ####################################
    # Construct bursted MCMC loop here #
    ####################################

    # MCMC Control
Chris Jewell's avatar
Chris Jewell committed
284
285
286
    NUM_BURSTS = int(config["mcmc"]["num_bursts"])
    NUM_BURST_SAMPLES = int(config["mcmc"]["num_burst_samples"])
    NUM_EVENT_TIME_UPDATES = int(config["mcmc"]["num_event_time_updates"])
    # Total number of samples written to the output file across all bursts.
    NUM_SAVED_SAMPLES = NUM_BURST_SAMPLES * NUM_BURSTS

    # RNG stuff
    tf.random.set_seed(2)

    # Initial chain state: [block0, block1, events].
    current_state = [
        np.array([0.2, 0.0, 0.0], dtype=DTYPE),
        # One slot for beta1 plus one per element of xi.
        np.zeros(
            model.model["xi"](0.0).event_shape[-1]
            # + model.model["beta3"]().event_shape[-1]
            + 1,
            dtype=DTYPE,
        ),
        events,
    ]
    print("Initial logpi:", logp(*current_state))
303

Chris Jewell's avatar
Chris Jewell committed
304
305
306
    # Output file
    # Draw a single sample so the output datasets can be laid out from the
    # structure of real sampler output.
    samples, results, _ = sample(1, current_state)
    posterior = Posterior(
        os.path.join(
            os.path.expandvars(config["output"]["results_dir"]),
            config["output"]["posterior"],
        ),
        # NOTE(review): the second element of each tuple looks like a
        # per-dataset chunk shape -- confirm against gemlib.mcmc.Posterior.
        # The (…, 64, 64, 1) shape for "events" is hard-coded.
        sample_dict={
            "beta2": (samples[0][:, 0], (NUM_BURST_SAMPLES,)),
            "gamma0": (samples[0][:, 1], (NUM_BURST_SAMPLES,)),
            "gamma1": (samples[0][:, 2], (NUM_BURST_SAMPLES,)),
            "beta1": (samples[1][:, 0], (NUM_BURST_SAMPLES,)),
            "xi": (
                samples[1][:, 1:],
                (NUM_BURST_SAMPLES, samples[1].shape[1] - 1),
            ),
            "events": (samples[2], (NUM_BURST_SAMPLES, 64, 64, 1)),
        },
        results_dict=results,
        num_samples=NUM_SAVED_SAMPLES,
    )
    # Store the initial state and the full config alongside the samples.
    # NOTE(review): this reaches into the private `_file` attribute.
    posterior._file.create_dataset("initial_state", data=initial_state)
    posterior._file.create_dataset("config", data=yaml.dump(config))
327
328
329
330

    # We loop over successive calls to sample because we have to dump results
    #   to disc, or else end OOM (even on a 32GB system).
    # with tf.profiler.experimental.Profile("/tmp/tf_logdir"):
331
    # (printed label, key in the traced results) for every update step.
    acceptance_keys = (
        ("theta", "block0"),
        ("xi", "block1"),
        ("move S->E", "move/S->E"),
        ("move E->I", "move/E->I"),
        ("occult S->E", "occult/S->E"),
        ("occult E->I", "occult/E->I"),
    )

    final_results = None
    for i in tqdm.tqdm(
        range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES * config["mcmc"]["thin"]
    ):
        # Continue the chain from where the previous burst stopped; the
        # kernel results are threaded through via `final_results`.
        samples, results, final_results = sample(
            NUM_BURST_SAMPLES,
            init_state=current_state,
            thin=config["mcmc"]["thin"] - 1,
            previous_results=final_results,
        )
        current_state = [s[-1] for s in samples]
        print(current_state[0].numpy(), flush=True)

        # Time how long it takes to write this burst to the output file.
        start = perf_counter()
        posterior.write_samples(
            {
                "beta2": samples[0][:, 0],
                "gamma0": samples[0][:, 1],
                "gamma1": samples[0][:, 2],
                "beta1": samples[1][:, 0],
                "xi": samples[1][:, 1:],
                "events": samples[2],
            },
            first_dim_offset=i * NUM_BURST_SAMPLES,
        )
        posterior.write_results(results, first_dim_offset=i * NUM_BURST_SAMPLES)
        end = perf_counter()

        print("Storage time:", end - start, "seconds")
        # Acceptance rates for this burst only.
        for label, key in acceptance_keys:
            print(
                f"Acceptance {label}:",
                tf.reduce_mean(
                    tf.cast(results[key]["is_accepted"], tf.float32)
                ),
            )

    # Acceptance rates over the whole run, read back from the output file.
    for label, key in acceptance_keys:
        rate = posterior[f"results/{key}/is_accepted"][:].mean()
        print(f"Acceptance {label}: {rate}")

    del posterior