inference.py 12.8 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
"""MCMC Test Rig for COVID-19 UK model"""
2
3
# pylint: disable=E402

Chris Jewell's avatar
Chris Jewell committed
4
import os
5
from time import perf_counter
6
7
import tqdm
import yaml
Chris Jewell's avatar
Chris Jewell committed
8
import h5py
9
import numpy as np
10
11
import tensorflow as tf
import tensorflow_probability as tfp
Chris Jewell's avatar
Chris Jewell committed
12

13
14
15
16
17
18
from gemlib.util import compute_state
from gemlib.mcmc import UncalibratedEventTimesUpdate
from gemlib.mcmc import UncalibratedOccultUpdate, TransitionTopology
from gemlib.mcmc import GibbsKernel
from gemlib.mcmc import MultiScanKernel
from gemlib.mcmc import AdaptiveRandomWalkMetropolis
Chris Jewell's avatar
Chris Jewell committed
19
from gemlib.mcmc import Posterior
20

21
from covid.data import read_phe_cases
Chris Jewell's avatar
Chris Jewell committed
22
from covid.cli_arg_parse import cli_args
23

24
import model_spec
25

26
27
28
29
30
# Report which device TensorFlow will run on: gpu_device_name() is
# truthy only when a GPU device is available.
if tf.test.gpu_device_name():
    print("Using GPU")
else:
    print("Using CPU")

# Short aliases for the TFP namespaces, and the floating-point type
# shared with the model specification.
tfd = tfp.distributions
tfb = tfp.bijectors
DTYPE = model_spec.DTYPE
35
36
37
38
39


if __name__ == "__main__":

    # Read in settings
    args = cli_args()

    # NOTE(review): yaml.FullLoader is more permissive than yaml.safe_load.
    # That is acceptable for a trusted, locally-authored config file, but
    # this must not be pointed at untrusted input.
    with open(args.config, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # Two-element list: element 0 is used as the lower date bound and
    # element 1 as the upper date bound throughout.
    inference_period = [
        np.datetime64(x) for x in config["settings"]["inference_period"]
    ]

    covar_data = model_spec.read_covariates(
        config["data"],
        date_low=inference_period[0],
        date_high=inference_period[1],
    )

    # We load in cases and impute missing infections first, since this sets the
    # time epoch which we are analysing.
    cases = read_phe_cases(
        config["data"]["reported_cases"],
        date_low=inference_period[0],
        date_high=inference_period[1],
        date_type=config["data"]["case_date_type"],
        pillar=config["data"]["pillar"],
    ).astype(DTYPE)

    # Impute censored events, return cases
    events = model_spec.impute_censored_events(cases)

    # Initial conditions are calculated by calculating the state
    # at the beginning of the inference period
    #
    # Imputed censored events that pre-date the first I-R events
    # in the cases dataset are discarded.  They are only used
    # to set up a sensible initial state.
    state = compute_state(
        initial_state=tf.concat(
            [covar_data["N"][:, tf.newaxis], tf.zeros_like(events[:, 0, :])],
            axis=-1,
        ),
        events=events,
        stoichiometry=model_spec.STOICHIOMETRY,
    )
    # The imputed event series may be longer than the observed case series;
    # the difference in length along axis 1 is the index at which the
    # inference window starts.
    start_time = state.shape[1] - cases.shape[1]
    initial_state = state[:, start_time, :]
    events = events[:, start_time:, :]
    num_metapop = covar_data["N"].shape[0]

    ########################################################
    # Build the model, and then construct the MCMC kernels #
    ########################################################
90
91
92
93
94
95
96
    def convert_priors(node):
        """Recursively coerce every leaf of a nested prior spec to float.

        Dictionaries are updated in place (and returned); anything that is
        not a dict is passed through ``float()``.
        """
        if not isinstance(node, dict):
            return float(node)
        # Only values are reassigned, so iterating over the keys is safe.
        for key in node:
            node[key] = convert_priors(node[key])
        return node

97
98
99
100
101
    # Instantiate the model over the trimmed event window: it starts at
    # step 0 from `initial_state` and runs for as many steps as `events`
    # has along axis 1.  Prior hyperparameters come from the config file,
    # coerced to floats.
    model = model_spec.CovidUK(
        covariates=covar_data,
        initial_state=initial_state,
        initial_step=0,
        num_steps=events.shape[1],
        priors=convert_priors(config["mcmc"]["prior"]),
    )
104

105
106
    # Full joint log posterior distribution
    # $\pi(\theta, \xi, y^{se}, y^{ei} | y^{ir})$
107
    def logp(block0, block1, events):
        """Evaluate the model's joint log density for one chain state.

        The chain state is held as three blocks, which are unpacked here
        into the named variables that ``model.log_prob`` expects.

        :param block0: indexable of length 2 holding ``(beta2, gamma)``
        :param block1: 1-D vector laid out as ``beta1`` (index 0),
                       ``beta3`` (indices 1-2), then ``xi`` (index 3 onward)
        :param events: event tensor, passed to the model as ``seir``
        :returns: the value of ``model.log_prob`` for the assembled state
        """
        return model.log_prob(
            dict(
                beta2=block0[0],
                gamma=block0[1],
                beta1=block1[0],
                beta3=block1[1:3],
                xi=block1[3:],
                seir=events,
            )
        )
118

119
120
121
122
123
124
125
126
127
    # Build Metropolis within Gibbs sampler
    #
    # Kernels are:
    #     Q(\theta, \theta^\prime)
    #     Q(\xi, \xi^\prime)
    #     Q(Z^{se}, Z^{se\prime}) (partially-censored)
    #     Q(Z^{ei}, Z^{ei\prime}) (partially-censored)
    #     Q(Z^{se}, Z^{se\prime}) (occult)
    #     Q(Z^{ei}, Z^{ei\prime}) (occult)
128
    def make_blk0_kernel(shape, name):
        """Return a kernel-builder for the first parameter block.

        :param shape: shape of the parameter block; only ``shape[0]`` is
                      used, to size the initial proposal covariance
        :param name: name given to the outer (transformed) kernel
        :returns: a function ``fn(target_log_prob_fn, _)`` building the
                  kernel; its second argument is ignored
        """

        def fn(target_log_prob_fn, _):
            # The adaptive random walk is wrapped in an Exp bijector, so the
            # inner kernel proposes in log space.
            return tfp.mcmc.TransformedTransitionKernel(
                inner_kernel=AdaptiveRandomWalkMetropolis(
                    target_log_prob_fn=target_log_prob_fn,
                    # Start from an isotropic covariance of 0.1 * I.
                    initial_covariance=np.eye(shape[0], dtype=model_spec.DTYPE)
                    * 1e-1,
                    covariance_burnin=200,
                ),
                bijector=tfp.bijectors.Exp(),
                name=name,
            )

        return fn

143
    def make_blk1_kernel(shape, name):
        """Return a kernel-builder for the second parameter block.

        Unlike the first block's kernel, no bijector is applied here: the
        adaptive random walk acts on the block directly.

        :param shape: shape of the parameter block; only ``shape[0]`` is
                      used, to size the initial proposal covariance
        :param name: name given to the kernel
        :returns: a function ``fn(target_log_prob_fn, _)`` building the
                  kernel; its second argument is ignored
        """

        def fn(target_log_prob_fn, _):
            return AdaptiveRandomWalkMetropolis(
                target_log_prob_fn=target_log_prob_fn,
                # Start from an isotropic covariance of 0.1 * I.
                initial_covariance=np.eye(shape[0], dtype=model_spec.DTYPE)
                * 1e-1,
                covariance_burnin=200,
                name=name,
            )

        return fn
Chris Jewell's avatar
Chris Jewell committed
154

155
156
157
    def make_partially_observed_step(
        target_event_id, prev_event_id=None, next_event_id=None, name=None
    ):
        """Return a kernel-builder for an event-time update step.

        The proposal is an ``UncalibratedEventTimesUpdate`` wrapped in a
        Metropolis-Hastings kernel.  Tuning constants ``dmax``, ``m`` and
        ``nmax`` are read from ``config["mcmc"]``.

        :param target_event_id: id of the transition whose events are moved
        :param prev_event_id: id of the preceding transition, or None
        :param next_event_id: id of the following transition, or None
        :param name: name given to the Metropolis-Hastings kernel
        :returns: a function ``fn(target_log_prob_fn, _)`` building the
                  kernel; its second argument is ignored
        """

        def fn(target_log_prob_fn, _):
            return tfp.mcmc.MetropolisHastings(
                inner_kernel=UncalibratedEventTimesUpdate(
                    target_log_prob_fn=target_log_prob_fn,
                    target_event_id=target_event_id,
                    prev_event_id=prev_event_id,
                    next_event_id=next_event_id,
                    initial_state=initial_state,
                    dmax=config["mcmc"]["dmax"],
                    mmax=config["mcmc"]["m"],
                    nmax=config["mcmc"]["nmax"],
                ),
                name=name,
            )

        return fn
174

175
    def make_occults_step(prev_event_id, target_event_id, next_event_id, name):
        """Return a kernel-builder for an occult-event update step.

        The proposal is an ``UncalibratedOccultUpdate`` wrapped in a
        Metropolis-Hastings kernel; ``name`` is given to both.

        :param prev_event_id: id of the preceding transition, or None
        :param target_event_id: id of the transition being updated
        :param next_event_id: id of the following transition
        :param name: name given to the inner and outer kernels
        :returns: a function ``fn(target_log_prob_fn, _)`` building the
                  kernel; its second argument is ignored
        """

        def fn(target_log_prob_fn, _):
            return tfp.mcmc.MetropolisHastings(
                inner_kernel=UncalibratedOccultUpdate(
                    target_log_prob_fn=target_log_prob_fn,
                    topology=TransitionTopology(
                        prev_event_id, target_event_id, next_event_id
                    ),
                    cumulative_event_offset=initial_state,
                    nmax=config["mcmc"]["occult_nmax"],
                    # Index range covering the final 21 steps of the event
                    # window.  NOTE(review): 21 is hard-coded here rather
                    # than read from config.
                    t_range=(events.shape[1] - 21, events.shape[1]),
                    name=name,
                ),
                name=name,
            )

        return fn

Chris Jewell's avatar
Chris Jewell committed
193
    def make_event_multiscan_kernel(target_log_prob_fn, _):
        """Build the kernel that updates the event tensor.

        A Gibbs kernel cycling through four steps (S->E and E->I event-time
        moves, then S->E and E->I occult updates) is wrapped in a
        ``MultiScanKernel`` configured with
        ``config["mcmc"]["num_event_time_updates"]``.

        :param target_log_prob_fn: log density handed to the inner kernels
        :param _: ignored
        :returns: the ``MultiScanKernel`` instance
        """
        return MultiScanKernel(
            config["mcmc"]["num_event_time_updates"],
            GibbsKernel(
                target_log_prob_fn=target_log_prob_fn,
                # Every step is registered against index 0; presumably the
                # single state part (the events) this inner Gibbs kernel
                # sees -- confirm against GibbsKernel.
                kernel_list=[
                    (0, make_partially_observed_step(0, None, 1, "se_events")),
                    (0, make_partially_observed_step(1, 0, 2, "ei_events")),
                    (0, make_occults_step(None, 0, 1, "se_occults")),
                    (0, make_occults_step(0, 1, 2, "ei_occults")),
                ],
                name="gibbs1",
            ),
        )

208
209
    # MCMC tracing functions
    def trace_results_fn(_, results):
        """Packs results into a dictionary

        :param _: ignored (the chain state passed by the trace callback)
        :param results: kernel results of the outer Gibbs kernel
        :returns: dict keyed by ``"theta"``, ``"xi"``, ``"move/S->E"``,
                  ``"move/E->I"``, ``"occult/S->E"`` and ``"occult/E->I"``,
                  each holding that step's acceptance flag and target
                  log-probability (plus proposal details for event steps)
        """
        results_dict = {}
        # Indexed in the order the kernels are listed in the outer Gibbs
        # kernel: 0 = theta, 1 = xi, 2 = events.
        res0 = results.inner_results

        # The theta kernel is wrapped in a transformed kernel, hence the
        # extra `.inner_results` hop compared with xi.
        results_dict["theta"] = {
            "is_accepted": res0[0].inner_results.is_accepted,
            "target_log_prob": res0[
                0
            ].inner_results.accepted_results.target_log_prob,
        }
        results_dict["xi"] = {
            "is_accepted": res0[1].is_accepted,
            "target_log_prob": res0[1].accepted_results.target_log_prob,
        }

        def get_move_results(move_results):
            """Extract acceptance, log-prob and proposal info for one step."""
            accepted = move_results.accepted_results
            return {
                "is_accepted": move_results.is_accepted,
                "target_log_prob": accepted.target_log_prob,
                "proposed_delta": tf.stack(
                    [
                        accepted.m,
                        accepted.t,
                        accepted.delta_t,
                        accepted.x_star,
                    ]
                ),
            }

        # Indexed in the order of the event kernel's own step list.
        res1 = res0[2].inner_results
        results_dict["move/S->E"] = get_move_results(res1[0])
        results_dict["move/E->I"] = get_move_results(res1[1])
        results_dict["occult/S->E"] = get_move_results(res1[2])
        results_dict["occult/E->I"] = get_move_results(res1[3])

        return results_dict
246
247

    # Build MCMC algorithm here.  This will be run in bursts for memory economy
Chris Jewell's avatar
Chris Jewell committed
248
249
    # NOTE(review): `experimental_compile` is the older spelling of
    # `jit_compile` in newer TensorFlow releases -- confirm against the
    # TensorFlow version this project pins before renaming.
    @tf.function(autograph=False, experimental_compile=True)
    def sample(n_samples, init_state, thin=0, previous_results=None):
        """Run one burst of the Metropolis-within-Gibbs sampler.

        :param n_samples: number of results to return
        :param init_state: list ``[block0, block1, events]`` to start from;
                           a shallow copy is taken before use
        :param thin: passed to ``sample_chain`` as
                     ``num_steps_between_results``
        :param previous_results: kernel results from a previous burst, or
                                 None to initialise afresh
        :returns: tuple ``(samples, results, final_results)`` where
                  ``results`` is the output of ``trace_results_fn`` and
                  ``final_results`` can be fed to the next burst
        """
        with tf.name_scope("main_mcmc_sample_loop"):

            init_state = init_state.copy()

            # One kernel per state part: 0 = first parameter block,
            # 1 = second parameter block, 2 = events.
            gibbs_schema = GibbsKernel(
                target_log_prob_fn=logp,
                kernel_list=[
                    (0, make_blk0_kernel(init_state[0].shape, "theta")),
                    (1, make_blk1_kernel(init_state[1].shape, "xi")),
                    (2, make_event_multiscan_kernel),
                ],
                name="gibbs0",
            )

            samples, results, final_results = tfp.mcmc.sample_chain(
                n_samples,
                init_state,
                kernel=gibbs_schema,
                num_steps_between_results=thin,
                previous_kernel_results=previous_results,
                return_final_kernel_results=True,
                trace_fn=trace_results_fn,
            )

            return samples, results, final_results
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290

    ####################################
    # Construct bursted MCMC loop here #
    ####################################

    # MCMC Control
    NUM_BURSTS = config["mcmc"]["num_bursts"]
    NUM_BURST_SAMPLES = config["mcmc"]["num_burst_samples"]
    NUM_EVENT_TIME_UPDATES = config["mcmc"]["num_event_time_updates"]
    THIN_BURST_SAMPLES = NUM_BURST_SAMPLES // config["mcmc"]["thin"]
    # NOTE(review): the output is sized for THIN_BURST_SAMPLES rows per
    # burst, but the burst loop below requests NUM_BURST_SAMPLES results
    # per burst and writes them at offsets of i * NUM_BURST_SAMPLES.  These
    # only agree when config["mcmc"]["thin"] == 1 -- confirm the intended
    # meaning of `num_burst_samples` and how Posterior handles the sizes.
    NUM_SAVED_SAMPLES = THIN_BURST_SAMPLES * NUM_BURSTS

    # RNG stuff
    tf.random.set_seed(2)

    # Initial chain state: [block0, block1, events].
    current_state = [
        np.array([0.65, 0.48], dtype=DTYPE),
        # block1 packs beta1 (1 value) and beta3 (2 values) ahead of xi,
        # hence the 3 extra slots on top of xi's own length.
        np.zeros(model.model["xi"](0.0).event_shape[-1] + 3, dtype=DTYPE),
        events,
    ]

    # Output file
    # A single draw is taken first and handed to Posterior together with
    # the total number of samples to be saved.
    samples, results, _ = sample(1, current_state)
    posterior = Posterior(
        os.path.join(
            os.path.expandvars(config["output"]["results_dir"]),
            config["output"]["posterior"],
        ),
        {"theta": samples[0], "xi": samples[1]},
        results,
        NUM_SAVED_SAMPLES,
    )
    # Store the initial state and the full run configuration alongside the
    # samples, so the output file is self-describing.
    posterior._file.create_dataset("initial_state", data=initial_state)
    posterior._file.create_dataset("config", data=yaml.dump(config))

    print("Initial logpi:", logp(*current_state))

    # We loop over successive calls to sample because we have to dump results
    #   to disc, or else we run out of memory (even on a 32GB system).
    # with tf.profiler.experimental.Profile("/tmp/tf_logdir"):
315
    final_results = None

    # (printed label, key in the traced results / posterior file) for every
    # update step whose acceptance rate is reported.
    acceptance_traces = (
        ("theta", "theta"),
        ("xi", "xi"),
        ("move S->E", "move/S->E"),
        ("move E->I", "move/E->I"),
        ("occult S->E", "occult/S->E"),
        ("occult E->I", "occult/E->I"),
    )

    for i in tqdm.tqdm(range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES):
        # Carry the kernel results across bursts so each burst continues
        # from the previous one instead of re-initialising the kernels.
        samples, results, final_results = sample(
            NUM_BURST_SAMPLES,
            init_state=current_state,
            thin=config["mcmc"]["thin"] - 1,
            previous_results=final_results,
        )
        # The last draw of this burst seeds the next one.
        current_state = [s[-1] for s in samples]
        print(current_state[0].numpy(), flush=True)

        start = perf_counter()
        posterior.write_samples(
            {"theta": samples[0], "xi": samples[1]},
            first_dim_offset=i * NUM_BURST_SAMPLES,
        )
        posterior.write_results(results, first_dim_offset=i * NUM_BURST_SAMPLES)
        end = perf_counter()

        print("Storage time:", end - start, "seconds")
        # Per-burst acceptance rates.  Each line reads its own trace; the
        # "xi" line previously reported the theta trace by mistake.
        for label, key in acceptance_traces:
            print(
                f"Acceptance {label}:",
                tf.reduce_mean(
                    tf.cast(results[key]["is_accepted"], tf.float32)
                ),
            )

    # Overall acceptance rates, read back from the stored results.
    for label, key in acceptance_traces:
        rate = posterior[f"results/{key}/is_accepted"][:].mean()
        print(f"Acceptance {label}: {rate}")

    del posterior