inference.py 13.8 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
"""MCMC Test Rig for COVID-19 UK model"""
2
3
4
# pylint: disable=E402

import argparse
Chris Jewell's avatar
Chris Jewell committed
5
import os
6
7
8

# Uncomment to block GPU use

Chris Jewell's avatar
Chris Jewell committed
9
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
10

11
from time import perf_counter
Chris Jewell's avatar
Chris Jewell committed
12

13
14
import tqdm
import yaml
Chris Jewell's avatar
Chris Jewell committed
15
import h5py
16
import numpy as np
17
18
import tensorflow as tf
import tensorflow_probability as tfp
Chris Jewell's avatar
Chris Jewell committed
19

20
21
from tensorflow_probability.python.experimental import unnest

Chris Jewell's avatar
Chris Jewell committed
22
from covid.impl.util import compute_state
23
from covid.impl.mcmc import UncalibratedLogRandomWalk, random_walk_mvnorm_fn
24
from covid.impl.event_time_mh import UncalibratedEventTimesUpdate
25
from covid.impl.occult_events_mh import UncalibratedOccultUpdate, TransitionTopology
26
27
from covid.impl.gibbs import flatten_results
from covid.impl.gibbs_kernel import GibbsKernel, GibbsKernelResults
28
from covid.impl.multi_scan_kernel import MultiScanKernel
29
30
31
from covid.impl.adaptive_random_walk_metropolis import (
    AdaptiveRandomWalkMetropolisHastings,
)
32
from covid.data import read_phe_cases
Chris Jewell's avatar
Chris Jewell committed
33
from covid.cli_arg_parse import cli_args
34

35
import model_spec
36

37
38
39
40
41
# Report at startup whether TensorFlow has found a usable GPU.
print("Using GPU" if tf.test.gpu_device_name() else "Using CPU")

42

43
44
# Convenience aliases for the TensorFlow Probability namespaces.
tfd = tfp.distributions
tfb = tfp.bijectors

# Model-wide floating-point precision, defined once in the model spec.
DTYPE = model_spec.DTYPE
46
47
48
49
50


if __name__ == "__main__":

    # Parse command-line arguments and load the YAML configuration.
    args = cli_args()

    with open(args.config, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # Analysis window as [start, end] numpy datetimes.
    inference_period = list(
        map(np.datetime64, config["settings"]["inference_period"])
    )

    covar_data = model_spec.read_covariates(
        config["data"],
        date_low=inference_period[0],
        date_high=inference_period[1],
    )
63
64
65

    # We load in cases and impute missing infections first, since this sets the
    # time epoch which we are analysing.
    data_cfg = config["data"]
    cases = read_phe_cases(
        data_cfg["reported_cases"],
        date_low=inference_period[0],
        date_high=inference_period[1],
        date_type=data_cfg["case_date_type"],
        pillar=data_cfg["pillar"],
    ).astype(DTYPE)

    # Impute censored events, return cases
    events = model_spec.impute_censored_events(cases)
76
77
78
79
80
81
82
83
84

    # Initial conditions are calculated by calculating the state
    # at the beginning of the inference period.
    #
    # Imputed censored events that pre-date the first I-R events
    # in the cases dataset are discarded.  They are only used
    # to set up a sensible initial state.
    population = covar_data["N"][:, tf.newaxis]
    state = compute_state(
        initial_state=tf.concat(
            [population, tf.zeros_like(events[:, 0, :])], axis=-1
        ),
        events=events,
        stoichiometry=model_spec.STOICHIOMETRY,
    )

    # Align the event tensor with the reported-cases epoch and take the
    # epidemic state at that time point as the initial condition.
    start_time = state.shape[1] - cases.shape[1]
    initial_state = state[:, start_time, :]
    events = events[:, start_time:, :]
    num_metapop = covar_data["N"].shape[0]

    ########################################################
    # Build the model, and then construct the MCMC kernels #
    ########################################################
    model = model_spec.CovidUK(
        covariates=covar_data,
        initial_state=initial_state,
        initial_step=0,
        num_steps=events.shape[1],
    )
104

105
106
107
108
    # Full joint log posterior distribution
    # $\pi(\theta, \xi, y^{se}, y^{ei} | y^{ir})$
    def logp(theta, xi, events):
        """Evaluate the joint log posterior at (theta, xi, events)."""
        param = {
            "beta1": theta[0],
            "beta2": theta[1],
            "gamma": theta[2],
            "xi": xi,
            "seir": events,
        }
        return model.log_prob(param)
111

112
113
114
115
116
117
118
119
120
    # Build Metropolis within Gibbs sampler
    #
    # Kernels are:
    #     Q(\theta, \theta^\prime)
    #     Q(\xi, \xi^\prime)
    #     Q(Z^{se}, Z^{se\prime}) (partially-censored)
    #     Q(Z^{ei}, Z^{ei\prime}) (partially-censored)
    #     Q(Z^{se}, Z^{se\prime}) (occult)
    #     Q(Z^{ei}, Z^{ei\prime}) (occult)
    def make_theta_kernel(shape, name):
        """Factory for an adaptive random-walk MH kernel on log(theta)."""

        def fn(target_log_prob_fn, state):
            # Proposal covariance adapts after a burn-in of 200 samples.
            walker = AdaptiveRandomWalkMetropolisHastings(
                target_log_prob_fn=target_log_prob_fn,
                initial_state=tf.zeros(shape, dtype=model_spec.DTYPE),
                initial_covariance=[np.eye(shape[0]) * 1e-1],
                covariance_burnin=200,
            )
            # Walk on the log scale; Exp maps back to the positive axis.
            return tfp.mcmc.TransformedTransitionKernel(
                inner_kernel=walker, bijector=tfp.bijectors.Exp(), name=name
            )

        return fn

    def make_xi_kernel(shape, name):
        """Factory for an adaptive random-walk MH kernel on xi (unconstrained)."""

        def fn(target_log_prob_fn, state):
            kernel_kwargs = dict(
                target_log_prob_fn=target_log_prob_fn,
                initial_state=tf.ones(shape, dtype=model_spec.DTYPE),
                initial_covariance=[np.eye(shape[0]) * 1e-1],
                covariance_burnin=200,
                name=name,
            )
            return AdaptiveRandomWalkMetropolisHastings(**kernel_kwargs)

        return fn
Chris Jewell's avatar
Chris Jewell committed
147

148
149
150
    def make_partially_observed_step(
        target_event_id, prev_event_id=None, next_event_id=None, name=None
    ):
        """Factory for an MH step that moves partially-censored event times."""

        def fn(target_log_prob_fn, state):
            mcmc_cfg = config["mcmc"]
            proposal = UncalibratedEventTimesUpdate(
                target_log_prob_fn=target_log_prob_fn,
                target_event_id=target_event_id,
                prev_event_id=prev_event_id,
                next_event_id=next_event_id,
                initial_state=initial_state,
                dmax=mcmc_cfg["dmax"],
                mmax=mcmc_cfg["m"],
                nmax=mcmc_cfg["nmax"],
            )
            return tfp.mcmc.MetropolisHastings(inner_kernel=proposal, name=name)

        return fn
167

168
    def make_occults_step(prev_event_id, target_event_id, next_event_id, name):
        """Factory for an MH step that adds/deletes occult (unobserved) events."""

        def fn(target_log_prob_fn, state):
            # Occults are only proposed within the final 21 timesteps.
            time_window = (events.shape[1] - 21, events.shape[1])
            proposal = UncalibratedOccultUpdate(
                target_log_prob_fn=target_log_prob_fn,
                topology=TransitionTopology(
                    prev_event_id, target_event_id, next_event_id
                ),
                cumulative_event_offset=initial_state,
                nmax=config["mcmc"]["occult_nmax"],
                t_range=time_window,
                name=name,
            )
            return tfp.mcmc.MetropolisHastings(inner_kernel=proposal, name=name)

        return fn

    def make_event_multiscan_kernel(target_log_prob_fn, state):
        """Run several scans of the event-time Gibbs sweep per global step."""
        event_sweep = GibbsKernel(
            target_log_prob_fn=target_log_prob_fn,
            kernel_list=[
                (0, make_partially_observed_step(0, None, 1, "se_events")),
                (0, make_partially_observed_step(1, 0, 2, "ei_events")),
                (0, make_occults_step(None, 0, 1, "se_occults")),
                (0, make_occults_step(0, 1, 2, "ei_occults")),
            ],
            name="gibbs1",
        )
        return MultiScanKernel(
            config["mcmc"]["num_event_time_updates"], event_sweep
        )

201
202
203
204
205
    # MCMC tracing functions
    def trace_results_fn(_, results):
        """Returns log_prob, accepted, q_ratio"""

        def f(result):
            # Pull the innermost proposal diagnostics out of nested kernels.
            proposed_results = unnest.get_innermost(result, "proposed_results")
            log_prob = proposed_results.target_log_prob
            accepted = tf.cast(
                unnest.get_innermost(result, "is_accepted"), log_prob.dtype
            )
            q_ratio = proposed_results.log_acceptance_correction
            trace = [[log_prob], [accepted], [q_ratio]]
            # Some kernels attach extra per-step diagnostics.
            if hasattr(proposed_results, "extra"):
                trace.append(tf.cast(proposed_results.extra, log_prob.dtype))
            return tf.concat(trace, axis=0)

        def recurse(f, results):
            # Walk the Gibbs kernel tree, applying f at each leaf kernel.
            if isinstance(results, GibbsKernelResults):
                return [recurse(f, x) for x in results.inner_results]
            return f(results)

        return recurse(f, results)

    # Build MCMC algorithm here.  This will be run in bursts for memory economy
    @tf.function(autograph=False, experimental_compile=True)
    def sample(n_samples, init_state, previous_results=None):
        """Draw n_samples from the Gibbs scheme, optionally resuming from
        a previous burst's kernel results."""
        with tf.name_scope("main_mcmc_sample_loop"):

            init_state = init_state.copy()

            # State-index / kernel-factory pairs for the Gibbs sweep.
            kernels = [
                (0, make_theta_kernel(init_state[0].shape, "theta")),
                (1, make_xi_kernel(init_state[1].shape, "xi")),
                (2, make_event_multiscan_kernel),
            ]
            gibbs_schema = GibbsKernel(
                target_log_prob_fn=logp, kernel_list=kernels, name="gibbs0"
            )

            # Returns (samples, traced results, final kernel results).
            return tfp.mcmc.sample_chain(
                n_samples,
                init_state,
                kernel=gibbs_schema,
                previous_kernel_results=previous_results,
                return_final_kernel_results=True,
                trace_fn=trace_results_fn,
            )
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265

    ####################################
    # Construct bursted MCMC loop here #
    ####################################

    # MCMC Control
    mcmc_cfg = config["mcmc"]
    NUM_BURSTS = mcmc_cfg["num_bursts"]
    NUM_BURST_SAMPLES = mcmc_cfg["num_burst_samples"]
    NUM_EVENT_TIME_UPDATES = mcmc_cfg["num_event_time_updates"]
    THIN_BURST_SAMPLES = NUM_BURST_SAMPLES // mcmc_cfg["thin"]
    NUM_SAVED_SAMPLES = THIN_BURST_SAMPLES * NUM_BURSTS

    # RNG stuff
    tf.random.set_seed(2)

    # Initial chain position: theta, xi, and the imputed event tensor.
    current_state = [
        np.array([0.45, 0.65, 0.48], dtype=DTYPE),
        np.zeros(model.model["xi"]().event_shape[-1], dtype=DTYPE),
        events,
    ]

    # Output Files
    output_path = os.path.join(
        os.path.expandvars(config["output"]["results_dir"]),
        config["output"]["posterior"],
    )
    # Large chunk cache + latest libver so the file supports SWMR streaming.
    posterior = h5py.File(
        output_path,
        "w",
        rdcc_nbytes=1024 ** 2 * 400,
        rdcc_nslots=100000,
        libver="latest",
    )
282
283
284
    event_size = [NUM_SAVED_SAMPLES, *current_state[2].shape]

    posterior.create_dataset("initial_state", data=initial_state)

    # Ideally we insert the inference period into the posterior file
    # as this allows us to post-attribute it to the data.  Maybe better
    # to simply save the data into it as well.
    posterior.create_dataset("config", data=yaml.dump(config))
290
291
292
293
    theta_shape = [NUM_SAVED_SAMPLES, current_state[0].shape[0]]
    theta_samples = posterior.create_dataset(
        "samples/theta", theta_shape, dtype=np.float64
    )
    xi_shape = [NUM_SAVED_SAMPLES, current_state[1].shape[0]]
    xi_samples = posterior.create_dataset(
        "samples/xi", xi_shape, dtype=np.float64
    )
    # Events dominate storage: chunk and szip-compress to keep the file small.
    event_samples = posterior.create_dataset(
        "samples/events",
        event_size,
        dtype=DTYPE,
        chunks=(32, 32, 32, 1),
        compression="szip",
        compression_opts=("nn", 16),
    )
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329

    # Per-kernel MCMC diagnostics: columns are log_prob, accepted, q_ratio,
    # plus kernel-specific extras.
    def _results_ds(name, ncols):
        return posterior.create_dataset(
            name, (NUM_SAVED_SAMPLES, ncols), dtype=DTYPE
        )

    output_results = [
        _results_ds("results/theta", 3),
        _results_ds("results/xi", 3),
        _results_ds("results/move/S->E", 3 + num_metapop),
        _results_ds("results/move/E->I", 3 + num_metapop),
        _results_ds("results/occult/S->E", 6),
        _results_ds("results/occult/E->I", 6),
    ]
    posterior.swmr_mode = True

    print("Initial logpi:", logp(*current_state))

    # We loop over successive calls to sample because we have to dump results
    #   to disc, or else end OOM (even on a 32GB system).
    # with tf.profiler.experimental.Profile("/tmp/tf_logdir"):
    final_results = None
    for burst in tqdm.tqdm(range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES):
        samples, results, final_results = sample(
            NUM_BURST_SAMPLES,
            init_state=current_state,
            previous_results=final_results,
        )
        # Restart the next burst from the last draw of this one.
        current_state = [chain[-1] for chain in samples]

        # Thin the burst and write it into this burst's slot of each dataset.
        s = slice(
            burst * THIN_BURST_SAMPLES,
            burst * THIN_BURST_SAMPLES + THIN_BURST_SAMPLES,
        )
        idx = tf.constant(range(0, NUM_BURST_SAMPLES, config["mcmc"]["thin"]))
        theta_samples[s, ...] = tf.gather(samples[0], idx)
        xi_samples[s, ...] = tf.gather(samples[1], idx)
        print(current_state[0].numpy(), flush=True)

        start = perf_counter()
        event_samples[s, ...] = tf.gather(samples[2], idx)
        end = perf_counter()

        flat_results = flatten_results(results)
        # BUG FIX: this inner loop previously reused `i`, shadowing the
        # burst index of the enclosing loop; use a distinct name.
        for k, ro in enumerate(output_results):
            ro[s, ...] = tf.gather(flat_results[k], idx)

        posterior.flush()
        print("Storage time:", end - start, "seconds")
        print(
            "Acceptance theta:",
            tf.reduce_mean(tf.cast(flat_results[0][:, 1], tf.float32)),
        )
        print(
            "Acceptance xi:", tf.reduce_mean(tf.cast(flat_results[1][:, 1], tf.float32))
        )
        print(
            "Acceptance move S->E:",
            tf.reduce_mean(tf.cast(flat_results[2][:, 1], tf.float32)),
        )
        print(
            "Acceptance move E->I:",
            tf.reduce_mean(tf.cast(flat_results[3][:, 1], tf.float32)),
        )
        print(
            "Acceptance occult S->E:",
            tf.reduce_mean(tf.cast(flat_results[4][:, 1], tf.float32)),
        )
        print(
            "Acceptance occult E->I:",
            tf.reduce_mean(tf.cast(flat_results[5][:, 1], tf.float32)),
        )
Chris Jewell's avatar
Chris Jewell committed
382

383
384
385
386
387
388
    # Overall acceptance rates across the whole run (column 1 is `accepted`).
    kernel_labels = [
        "theta",
        "xi",
        "move S->E",
        "move E->I",
        "occult S->E",
        "occult E->I",
    ]
    for label, dataset in zip(kernel_labels, output_results):
        print(f"Acceptance {label}: {dataset[:, 1].mean()}")

    posterior.close()