inference.py 13 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
"""MCMC Test Rig for COVID-19 UK model"""
2
3
# pylint: disable=E402

Chris Jewell's avatar
Chris Jewell committed
4
import os
5
from time import perf_counter
6
7
import tqdm
import yaml
Chris Jewell's avatar
Chris Jewell committed
8
import h5py
9
import numpy as np
10
11
import tensorflow as tf
import tensorflow_probability as tfp
Chris Jewell's avatar
Chris Jewell committed
12

13
14
15
16
17
18
from gemlib.util import compute_state
from gemlib.mcmc import UncalibratedEventTimesUpdate
from gemlib.mcmc import UncalibratedOccultUpdate, TransitionTopology
from gemlib.mcmc import GibbsKernel
from gemlib.mcmc import MultiScanKernel
from gemlib.mcmc import AdaptiveRandomWalkMetropolis
Chris Jewell's avatar
Chris Jewell committed
19
from gemlib.mcmc import Posterior
20

21
from covid.data import read_phe_cases
Chris Jewell's avatar
Chris Jewell committed
22
from covid.cli_arg_parse import cli_args
23

24
import model_spec
25

26
27
28
29
30
# Report which device TensorFlow will place ops on.
print("Using GPU" if tf.test.gpu_device_name() else "Using CPU")

# Short aliases for the TFP namespaces, and the floating-point dtype shared
# with the model specification.
tfd = tfp.distributions
tfb = tfp.bijectors
DTYPE = model_spec.DTYPE
if __name__ == "__main__":

    # ------------------------------------------------------------------
    # Settings
    # ------------------------------------------------------------------
    args = cli_args()

    # NOTE(review): yaml.load with FullLoader is only appropriate for a
    # trusted, locally-authored config file; prefer yaml.safe_load if the
    # config could ever come from an untrusted source.
    with open(args.config, "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    # Inference window as numpy datetimes; only the first two entries
    # (start, end) are used below.
    inference_period = [
        np.datetime64(bound) for bound in config["settings"]["inference_period"]
    ]
    date_low, date_high = inference_period[0], inference_period[1]

    data_config = config["data"]
    covar_data = model_spec.read_covariates(
        data_config, date_low=date_low, date_high=date_high
    )

    # Cases are loaded (and missing infections imputed) before anything else
    # because they fix the time epoch being analysed.
    cases = read_phe_cases(
        data_config["reported_cases"],
        date_low=date_low,
        date_high=date_high,
        date_type=data_config["case_date_type"],
        pillar=data_config["pillar"],
    ).astype(DTYPE)

    # Impute censored events from the case series.
    events = model_spec.impute_censored_events(cases)

    # The initial condition is the state at the beginning of the inference
    # period.  Imputed censored events that pre-date the first I-R events in
    # the case data are only used to build a sensible initial state, and are
    # discarded afterwards.
    initial_compartments = tf.concat(
        [covar_data["N"][:, tf.newaxis], tf.zeros_like(events[:, 0, :])],
        axis=-1,
    )
    state = compute_state(
        initial_state=initial_compartments,
        events=events,
        stoichiometry=model_spec.STOICHIOMETRY,
    )

    # Align everything to the first time step covered by the case data.
    start_time = state.shape[1] - cases.shape[1]
    initial_state = state[:, start_time, :]
    events = events[:, start_time:, :]
    num_metapop = covar_data["N"].shape[0]

    ########################################################
    # Build the model, and then construct the MCMC kernels #
    ########################################################
    def convert_priors(node):
        """Recursively coerce every leaf of a prior-specification tree to float.

        PyYAML follows YAML 1.1, under which exponent literals without a
        decimal point (e.g. ``1e-3``) load as strings; presumably that is why
        every leaf is passed through ``float`` here.

        Args:
            node: a (possibly nested) dict, list, tuple, or scalar leaf.

        Returns:
            A new structure of the same shape in which every leaf has been
            converted with ``float``.  The input is left unmodified, so the
            loaded ``config`` keeps its original values.

        Raises:
            ValueError, TypeError: if a leaf cannot be converted by ``float``.
        """
        if isinstance(node, dict):
            return {key: convert_priors(value) for key, value in node.items()}
        if isinstance(node, list):
            return [convert_priors(value) for value in node]
        if isinstance(node, tuple):
            return tuple(convert_priors(value) for value in node)
        return float(node)
    # Instantiate the model over the time steps remaining after alignment,
    # with prior hyperparameters coerced to floats.
    num_steps = events.shape[1]
    prior_params = convert_priors(config["mcmc"]["prior"])
    model = model_spec.CovidUK(
        covariates=covar_data,
        initial_state=initial_state,
        initial_step=0,
        num_steps=num_steps,
        priors=prior_params,
    )

    # Full joint log posterior distribution
    # $\pi(\theta, \xi, y^{se}, y^{ei} | y^{ir})$
    def logp(block0, block1, events):
        """Evaluate the model's joint log density at one sampler state.

        Args:
            block0: vector laid out as ``[beta2, gamma]``.
            block1: vector laid out as ``[beta1, beta3[0], beta3[1], xi...]``.
            events: event tensor, passed to the model as ``seir``.

        Returns:
            ``model.log_prob`` of the assembled parameter dictionary.
        """
        beta2, gamma = block0[0], block0[1]
        beta1, beta3, xi = block1[0], block1[1:3], block1[3:]
        return model.log_prob(
            {
                "beta2": beta2,
                "gamma": gamma,
                "beta1": beta1,
                "beta3": beta3,
                "xi": xi,
                "seir": events,
            }
        )

    # Build Metropolis within Gibbs sampler
    #
    # Kernels are:
    #     Q(\theta, \theta^\prime)
    #     Q(\xi, \xi^\prime)
    #     Q(Z^{se}, Z^{se\prime}) (partially-censored)
    #     Q(Z^{ei}, Z^{ei\prime}) (partially-censored)
    #     Q(Z^{se}, Z^{se\prime}) (occult)
    #     Q(Z^{ei}, Z^{ei\prime}) (occult)
    def make_blk0_kernel(shape, name):
        """Return a kernel factory for the first parameter block.

        The factory wraps an adaptive random-walk Metropolis kernel in a
        ``TransformedTransitionKernel`` with an ``Exp`` bijector, so the
        random walk runs on the log scale of the block.

        Args:
            shape: shape of the block; ``shape[0]`` sizes the initial
                (diagonal, 0.1 * identity) proposal covariance.
            name: name given to the transformed kernel.

        Returns:
            ``fn(target_log_prob_fn, _)`` building the kernel; its second
            argument is ignored.
        """

        def fn(target_log_prob_fn, _):
            random_walk = AdaptiveRandomWalkMetropolis(
                target_log_prob_fn=target_log_prob_fn,
                initial_covariance=0.1 * np.eye(shape[0], dtype=DTYPE),
                covariance_burnin=200,
            )
            return tfp.mcmc.TransformedTransitionKernel(
                inner_kernel=random_walk,
                bijector=tfb.Exp(),
                name=name,
            )

        return fn

    def make_blk1_kernel(shape, name):
        """Return a kernel factory for the second parameter block.

        The factory builds an adaptive random-walk Metropolis kernel acting
        directly on the block's values (no bijector).

        Args:
            shape: shape of the block; ``shape[0]`` sizes the initial
                (diagonal, 0.1 * identity) proposal covariance.
            name: name given to the kernel.

        Returns:
            ``fn(target_log_prob_fn, _)`` building the kernel; its second
            argument is ignored.
        """

        def fn(target_log_prob_fn, _):
            start_cov = 0.1 * np.eye(shape[0], dtype=DTYPE)
            return AdaptiveRandomWalkMetropolis(
                target_log_prob_fn=target_log_prob_fn,
                initial_covariance=start_cov,
                covariance_burnin=200,
                name=name,
            )

        return fn

    def make_partially_observed_step(
        target_event_id, prev_event_id=None, next_event_id=None, name=None
    ):
        """Return a factory for a Metropolis-Hastings event-time update.

        The proposal is ``UncalibratedEventTimesUpdate`` for events of type
        ``target_event_id``. Its neighbouring event types are given as
        ``prev_event_id`` / ``next_event_id`` (``None`` when absent).
        Proposal limits ``dmax``, ``mmax`` and ``nmax`` come from the
        ``mcmc`` section of the config.

        Returns:
            ``fn(target_log_prob_fn, _)`` building the kernel; its second
            argument is ignored.
        """

        def fn(target_log_prob_fn, _):
            mcmc_settings = config["mcmc"]
            proposal = UncalibratedEventTimesUpdate(
                target_log_prob_fn=target_log_prob_fn,
                target_event_id=target_event_id,
                prev_event_id=prev_event_id,
                next_event_id=next_event_id,
                initial_state=initial_state,
                dmax=mcmc_settings["dmax"],
                mmax=mcmc_settings["m"],
                nmax=mcmc_settings["nmax"],
            )
            return tfp.mcmc.MetropolisHastings(inner_kernel=proposal, name=name)

        return fn

    def make_occults_step(
        prev_event_id, target_event_id, next_event_id, name, occult_window=21
    ):
        """Return a factory for a Metropolis-Hastings occult-event update.

        The proposal is ``UncalibratedOccultUpdate`` over the transition
        topology ``(prev_event_id, target_event_id, next_event_id)``.

        Args:
            prev_event_id, target_event_id, next_event_id: event-type indices
                forming the ``TransitionTopology``.
            name: name given to both the proposal and the MH kernel.
            occult_window: length of the final stretch of time steps given
                to the proposal as ``t_range``. The default of 21 matches
                the previously hard-coded value.

        Returns:
            ``fn(target_log_prob_fn, _)`` building the kernel; its second
            argument is ignored.
        """

        def fn(target_log_prob_fn, _):
            num_times = events.shape[1]
            # Clip at 0 so an inference period shorter than the window never
            # produces a negative lower bound for the proposal's time range.
            t_range = (max(0, num_times - occult_window), num_times)
            return tfp.mcmc.MetropolisHastings(
                inner_kernel=UncalibratedOccultUpdate(
                    target_log_prob_fn=target_log_prob_fn,
                    topology=TransitionTopology(
                        prev_event_id, target_event_id, next_event_id
                    ),
                    cumulative_event_offset=initial_state,
                    nmax=config["mcmc"]["occult_nmax"],
                    t_range=t_range,
                    name=name,
                ),
                name=name,
            )

        return fn

    def make_event_multiscan_kernel(target_log_prob_fn, _):
        """Build the repeated Gibbs sweep over event-time updates.

        One sweep applies, in order, the partially-observed S->E and E->I
        moves, then the occult S->E and E->I moves. Each is registered on
        state part 0 of the target passed in here. The sweep is repeated
        ``config["mcmc"]["num_event_time_updates"]`` times per call.

        The count is coerced with ``int()``, matching
        ``NUM_EVENT_TIME_UPDATES``, so a YAML value such as ``10.0`` does not
        reach ``MultiScanKernel`` as a float.

        Args:
            target_log_prob_fn: log-density of the event tensor given the
                other blocks.
            _: ignored.

        Returns:
            A ``MultiScanKernel`` wrapping a ``GibbsKernel`` named "gibbs1".
        """
        num_updates = int(config["mcmc"]["num_event_time_updates"])
        event_kernels = [
            (0, make_partially_observed_step(0, None, 1, "se_events")),
            (0, make_partially_observed_step(1, 0, 2, "ei_events")),
            (0, make_occults_step(None, 0, 1, "se_occults")),
            (0, make_occults_step(0, 1, 2, "ei_occults")),
        ]
        return MultiScanKernel(
            num_updates,
            GibbsKernel(
                target_log_prob_fn=target_log_prob_fn,
                kernel_list=event_kernels,
                name="gibbs1",
            ),
        )

    # MCMC tracing functions
    def trace_results_fn(_, results):
        """Packs results into a dictionary.

        ``results.inner_results`` holds the outer Gibbs results in the order
        the kernels are registered in ``sample``:
        index 0 is the transformed ``theta`` kernel (so its acceptance sits
        one ``inner_results`` deeper), 1 the ``xi`` kernel, and 2 the
        event-time multi-scan kernel, whose ``inner_results`` list the four
        event moves in registration order.

        Returns:
            dict keyed by kernel name, each mapping to ``is_accepted`` and
            ``target_log_prob`` (plus ``proposed_delta`` for event moves).
        """

        def _event_move_summary(move_results):
            # Stack the proposal summary fields into a single tensor.
            accepted = move_results.accepted_results
            return {
                "is_accepted": move_results.is_accepted,
                "target_log_prob": accepted.target_log_prob,
                "proposed_delta": tf.stack(
                    [accepted.m, accepted.t, accepted.delta_t, accepted.x_star]
                ),
            }

        outer = results.inner_results
        theta_res = outer[0].inner_results
        xi_res = outer[1]
        moves = outer[2].inner_results

        return {
            "theta": {
                "is_accepted": theta_res.is_accepted,
                "target_log_prob": (
                    theta_res.accepted_results.target_log_prob
                ),
            },
            "xi": {
                "is_accepted": xi_res.is_accepted,
                "target_log_prob": xi_res.accepted_results.target_log_prob,
            },
            "move/S->E": _event_move_summary(moves[0]),
            "move/E->I": _event_move_summary(moves[1]),
            "occult/S->E": _event_move_summary(moves[2]),
            "occult/E->I": _event_move_summary(moves[3]),
        }

    # Build MCMC algorithm here.  This will be run in bursts for memory economy
    @tf.function(autograph=False, experimental_compile=True)
    def sample(n_samples, init_state, thin=0, previous_results=None):
        """Run one burst of the Metropolis-within-Gibbs chain.

        Args:
            n_samples: number of states to keep.
            init_state: list ``[block0, block1, events]`` to start from.
            thin: passed to ``sample_chain`` as ``num_steps_between_results``.
            previous_results: kernel results from the previous burst, or
                ``None`` on the first call.

        Returns:
            ``(samples, results, final_results)``: the chain states, the
            output of ``trace_results_fn`` per kept step, and the final
            kernel results for resuming the next burst.
        """
        with tf.name_scope("main_mcmc_sample_loop"):

            # Shallow copy of the state list so the chain does not hold the
            # caller's list object.
            state_parts = init_state.copy()

            outer_gibbs = GibbsKernel(
                target_log_prob_fn=logp,
                kernel_list=[
                    (0, make_blk0_kernel(state_parts[0].shape, "theta")),
                    (1, make_blk1_kernel(state_parts[1].shape, "xi")),
                    (2, make_event_multiscan_kernel),
                ],
                name="gibbs0",
            )

            chain_states, trace, last_kernel_results = tfp.mcmc.sample_chain(
                n_samples,
                state_parts,
                kernel=outer_gibbs,
                num_steps_between_results=thin,
                previous_kernel_results=previous_results,
                return_final_kernel_results=True,
                trace_fn=trace_results_fn,
            )

            return chain_states, trace, last_kernel_results

    ####################################
    # Construct bursted MCMC loop here #
    ####################################

    # MCMC Control
    NUM_BURSTS = int(config["mcmc"]["num_bursts"])
    NUM_BURST_SAMPLES = int(config["mcmc"]["num_burst_samples"])
    # NOTE(review): not referenced in this section of the script.
    NUM_EVENT_TIME_UPDATES = int(config["mcmc"]["num_event_time_updates"])
    NUM_SAVED_SAMPLES = NUM_BURST_SAMPLES * NUM_BURSTS

    # RNG stuff
    tf.random.set_seed(2)

    # Starting state: [beta2, gamma], then [beta1, beta3 (2), xi...] at zero,
    # then the imputed event tensor.
    current_state = [
        np.array([0.65, 0.48], dtype=DTYPE),
        np.zeros(model.model["xi"](0.0).event_shape[-1] + 3, dtype=DTYPE),
        events,
    ]

    # Output file.  A one-sample run provides example samples/results whose
    # structure is used to lay out the posterior file.
    samples, results, _ = sample(1, current_state)

    # Second element of each sample_dict entry is presumably the HDF5 chunk
    # shape.  h5py rejects chunks larger than the dataset, so clip the
    # events chunk to the (metapopulation, time) extent of the event tensor.
    events_chunks = (
        NUM_BURST_SAMPLES,
        min(64, events.shape[0]),
        min(64, events.shape[1]),
        1,
    )
    posterior = Posterior(
        os.path.join(
            os.path.expandvars(config["output"]["results_dir"]),
            config["output"]["posterior"],
        ),
        sample_dict={
            "theta": (samples[0], (NUM_BURST_SAMPLES, 1)),
            "xi": (samples[1], (NUM_BURST_SAMPLES, 1)),
            "events": (samples[2], events_chunks),
        },
        results_dict=results,
        num_samples=NUM_SAVED_SAMPLES,
    )
    # NOTE(review): writes through Posterior's private `_file` handle.
    posterior._file.create_dataset("initial_state", data=initial_state)
    posterior._file.create_dataset("config", data=yaml.dump(config))

    print("Initial logpi:", logp(*current_state))

    # We loop over successive calls to sample because we have to dump results
    #   to disc, or else end OOM (even on a 32GB system).
    # with tf.profiler.experimental.Profile("/tmp/tf_logdir"):
    final_results = None
    thin = config["mcmc"]["thin"]

    # (printed label, key in the trace_results_fn dictionary)
    acceptance_diagnostics = (
        ("theta", "theta"),
        ("xi", "xi"),
        ("move S->E", "move/S->E"),
        ("move E->I", "move/E->I"),
        ("occult S->E", "occult/S->E"),
        ("occult E->I", "occult/E->I"),
    )

    for i in tqdm.tqdm(
        range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES * thin
    ):
        samples, results, final_results = sample(
            NUM_BURST_SAMPLES,
            init_state=current_state,
            thin=thin - 1,
            previous_results=final_results,
        )
        # Resume the next burst from the last kept state.
        current_state = [s[-1] for s in samples]
        print(current_state[0].numpy(), flush=True)

        start = perf_counter()
        posterior.write_samples(
            {"theta": samples[0], "xi": samples[1], "events": samples[2]},
            first_dim_offset=i * NUM_BURST_SAMPLES,
        )
        posterior.write_results(results, first_dim_offset=i * NUM_BURST_SAMPLES)
        end = perf_counter()

        print("Storage time:", end - start, "seconds")
        # Each diagnostic reads its own kernel's acceptance flags (previously
        # "xi" mistakenly reported theta's acceptance rate).
        for label, key in acceptance_diagnostics:
            rate = tf.reduce_mean(
                tf.cast(results[key]["is_accepted"], tf.float32)
            )
            print(f"Acceptance {label}:", rate)

    # Overall acceptance rates across every burst, read back from the
    # posterior file.
    for label, key in (
        ("theta", "theta"),
        ("xi", "xi"),
        ("move S->E", "move/S->E"),
        ("move E->I", "move/E->I"),
        ("occult S->E", "occult/S->E"),
        ("occult E->I", "occult/E->I"),
    ):
        accepted = posterior[f"results/{key}/is_accepted"][:]
        print(f"Acceptance {label}: {accepted.mean()}")

    del posterior