inference.py 12.2 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
"""MCMC Test Rig for COVID-19 UK model"""
2
3
# pylint: disable=E402

4
import pickle as pkl
5
from time import perf_counter
6
7
8

import h5py
import xarray
9
10
import tqdm
import yaml
11
import numpy as np
12
13
import tensorflow as tf
import tensorflow_probability as tfp
Chris Jewell's avatar
Chris Jewell committed
14

15
16
17
18
19
20
from gemlib.util import compute_state
from gemlib.mcmc import UncalibratedEventTimesUpdate
from gemlib.mcmc import UncalibratedOccultUpdate, TransitionTopology
from gemlib.mcmc import GibbsKernel
from gemlib.mcmc import MultiScanKernel
from gemlib.mcmc import AdaptiveRandomWalkMetropolis
Chris Jewell's avatar
Chris Jewell committed
21
from gemlib.mcmc import Posterior
22

23
import covid.model_spec as model_spec
24

25
26
# Short aliases for the TensorFlow Probability sub-namespaces.
tfd = tfp.distributions
tfb = tfp.bijectors

# Data type taken from the model specification so that tensors built in this
# module agree with those built by the model.
DTYPE = model_spec.DTYPE
28
29


30
def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True):
    """Build the Metropolis-within-Gibbs sampler and run it in bursts.

    The sampler is run ``num_bursts`` times; after every burst the draws and
    the traced kernel results are written to ``output_file`` before the next
    burst starts.

    Args:
        data_file: path handed to ``xarray.open_dataset``; the groups
            ``constant_data`` and ``observations`` are read from it.
        output_file: destination handed to ``Posterior``.
        config: mapping of MCMC settings.  The keys read here are ``dmax``,
            ``m``, ``nmax``, ``occult_nmax``, ``num_event_time_updates``,
            ``num_bursts``, ``num_burst_samples`` and ``thin``.
        use_autograph: forwarded to ``tf.function(autograph=...)``.
        use_xla: forwarded to ``tf.function(experimental_compile=...)``.
    """
    print("Using GPU" if tf.test.gpu_device_name() else "Using CPU")

    covariates = xarray.open_dataset(data_file, group="constant_data")
    cases = xarray.open_dataset(data_file, group="observations")["cases"]

    # Cases are loaded and the censored events imputed before anything else,
    # because this fixes the time epoch that is analysed.
    print("Data shape:", cases.shape)
    events = model_spec.impute_censored_events(cases.astype(DTYPE))

    # The initial conditions are the state at the beginning of the inference
    # period.  Imputed censored events that pre-date the first I-R events in
    # the cases dataset are discarded afterwards: they are only used to set
    # up a sensible initial state.
    population = tf.constant(covariates["N"], DTYPE)[:, tf.newaxis]
    no_prior_events = tf.zeros_like(events[:, 0, :])
    state_trajectory = compute_state(
        initial_state=tf.concat([population, no_prior_events], axis=-1),
        events=events,
        stoichiometry=model_spec.STOICHIOMETRY,
    )
    first_step = state_trajectory.shape[1] - cases.shape[1]
    initial_state = state_trajectory[:, first_step, :]
    events = events[:, first_step:, :]

    ##############################
    # Construct the MCMC kernels #
    ##############################
    model = model_spec.CovidUK(
        covariates=covariates,
        initial_state=initial_state,
        initial_step=0,
        num_steps=events.shape[1],
    )

    def joint_log_prob(block0, block1, events):
        """Evaluate the model log density for the three state parts."""
        parameters = {
            "beta2": block0[0],
            "gamma0": block0[1],
            "gamma1": block0[2],
            "sigma": block0[3],
            "beta1": block1[0],
            "xi": block1[1:],
            "seir": events,
        }
        return model.log_prob(parameters)

    # Build Metropolis within Gibbs sampler
    def make_blk0_kernel(shape, name):
        """Return a kernel factory for the first parameter block."""

        def build_kernel(target_log_prob_fn, _):
            random_walk = AdaptiveRandomWalkMetropolis(
                target_log_prob_fn=target_log_prob_fn,
                initial_covariance=np.eye(shape[0], dtype=DTYPE) * 1e-1,
                covariance_burnin=200,
            )
            # Elements 0 and 3 go through Exp, elements 1-2 are left as is.
            transform = tfb.Blockwise(
                bijectors=[tfb.Exp(), tfb.Identity(), tfb.Exp()],
                block_sizes=[1, 2, 1],
            )
            return tfp.mcmc.TransformedTransitionKernel(
                inner_kernel=random_walk,
                bijector=transform,
                name=name,
            )

        return build_kernel

    def make_blk1_kernel(shape, name):
        """Return a kernel factory for the second parameter block."""

        def build_kernel(target_log_prob_fn, _):
            starting_covariance = np.eye(shape[0], dtype=DTYPE) * 1e-1
            return AdaptiveRandomWalkMetropolis(
                target_log_prob_fn=target_log_prob_fn,
                initial_covariance=starting_covariance,
                covariance_burnin=200,
                name=name,
            )

        return build_kernel

    def make_partially_observed_step(
        target_event_id, prev_event_id=None, next_event_id=None, name=None
    ):
        """Return a kernel factory for a Metropolis-Hastings event time move."""

        def build_kernel(target_log_prob_fn, _):
            proposal = UncalibratedEventTimesUpdate(
                target_log_prob_fn=target_log_prob_fn,
                target_event_id=target_event_id,
                prev_event_id=prev_event_id,
                next_event_id=next_event_id,
                initial_state=initial_state,
                dmax=config["dmax"],
                mmax=config["m"],
                nmax=config["nmax"],
            )
            return tfp.mcmc.MetropolisHastings(inner_kernel=proposal, name=name)

        return build_kernel

    def make_occults_step(prev_event_id, target_event_id, next_event_id, name):
        """Return a kernel factory for a Metropolis-Hastings occult move."""

        def build_kernel(target_log_prob_fn, _):
            num_steps = events.shape[1]
            proposal = UncalibratedOccultUpdate(
                target_log_prob_fn=target_log_prob_fn,
                topology=TransitionTopology(
                    prev_event_id, target_event_id, next_event_id
                ),
                cumulative_event_offset=initial_state,
                nmax=config["occult_nmax"],
                # Occult moves are limited to the final 21 time steps.
                t_range=(num_steps - 21, num_steps),
                name=name,
            )
            return tfp.mcmc.MetropolisHastings(inner_kernel=proposal, name=name)

        return build_kernel

    def make_event_multiscan_kernel(target_log_prob_fn, _):
        """Wrap the four event kernels in a repeated Gibbs scan."""
        event_kernels = [
            (0, make_partially_observed_step(0, None, 1, "se_events")),
            (0, make_partially_observed_step(1, 0, 2, "ei_events")),
            (0, make_occults_step(None, 0, 1, "se_occults")),
            (0, make_occults_step(0, 1, 2, "ei_occults")),
        ]
        event_scan = GibbsKernel(
            target_log_prob_fn=target_log_prob_fn,
            kernel_list=event_kernels,
            name="gibbs1",
        )
        return MultiScanKernel(config["num_event_time_updates"], event_scan)

    # MCMC tracing functions
    def trace_results_fn(_, results):
        """Collect the per-kernel diagnostics into a dictionary."""

        def get_move_results(move):
            accepted = move.accepted_results
            return {
                "is_accepted": move.is_accepted,
                "target_log_prob": accepted.target_log_prob,
                "proposed_delta": tf.stack(
                    [accepted.m, accepted.t, accepted.delta_t, accepted.x_star]
                ),
            }

        block_results = results.inner_results
        # Block 0 is wrapped in a TransformedTransitionKernel, hence the
        # additional ``inner_results`` hop compared with block 1.
        theta_results = block_results[0].inner_results
        xi_results = block_results[1]
        event_results = block_results[2].inner_results

        return {
            "block0": {
                "is_accepted": theta_results.is_accepted,
                "target_log_prob": (
                    theta_results.accepted_results.target_log_prob
                ),
            },
            "block1": {
                "is_accepted": xi_results.is_accepted,
                "target_log_prob": xi_results.accepted_results.target_log_prob,
            },
            "move/S->E": get_move_results(event_results[0]),
            "move/E->I": get_move_results(event_results[1]),
            "occult/S->E": get_move_results(event_results[2]),
            "occult/E->I": get_move_results(event_results[3]),
        }

    # Build MCMC algorithm here.  This will be run in bursts for memory economy
    @tf.function(autograph=use_autograph, experimental_compile=use_xla)
    def sample(n_samples, init_state, thin=0, previous_results=None):
        """Draw ``n_samples`` from the chain, resuming from earlier results."""
        with tf.name_scope("main_mcmc_sample_loop"):
            init_state = init_state.copy()

            block_kernels = [
                (0, make_blk0_kernel(init_state[0].shape, "block0")),
                (1, make_blk1_kernel(init_state[1].shape, "block1")),
                (2, make_event_multiscan_kernel),
            ]
            gibbs_schema = GibbsKernel(
                target_log_prob_fn=joint_log_prob,
                kernel_list=block_kernels,
                name="gibbs0",
            )

            draws, trace, last_kernel_results = tfp.mcmc.sample_chain(
                num_results=n_samples,
                current_state=init_state,
                kernel=gibbs_schema,
                num_steps_between_results=thin,
                previous_kernel_results=previous_results,
                return_final_kernel_results=True,
                trace_fn=trace_results_fn,
            )

            return draws, trace, last_kernel_results

    def samples_by_name(samples):
        """Split the three sampled state parts into named parameter arrays."""
        block0_draws, block1_draws, event_draws = (
            samples[0],
            samples[1],
            samples[2],
        )
        return {
            "beta2": block0_draws[:, 0],
            "gamma0": block0_draws[:, 1],
            "gamma1": block0_draws[:, 2],
            "sigma": block0_draws[:, 3],
            "beta1": block1_draws[:, 0],
            "xi": block1_draws[:, 1:],
            "seir": event_draws,
        }

    ###############################
    # Construct bursted MCMC loop #
    ###############################
    num_bursts = int(config["num_bursts"])
    burst_size = int(config["num_burst_samples"])
    num_saved_samples = burst_size * num_bursts

    # RNG stuff
    tf.random.set_seed(2)

    block0_start = tf.constant([0.6, 0.0, 0.0, 0.1], dtype=DTYPE)
    # One extra element in front of the xi components holds beta1.
    block1_size = model.model["xi"](0.0, 0.1).event_shape[-1] + 1
    block1_start = tf.zeros(block1_size, dtype=DTYPE)
    current_state = [block0_start, block1_start, events]
    print("Initial logpi:", joint_log_prob(*current_state))

    # Output file.  A single draw is taken so that ``Posterior`` receives one
    # example of the samples and of the traced results.
    samples, results, _ = sample(1, current_state)
    posterior = Posterior(
        output_file,
        sample_dict=samples_by_name(samples),
        results_dict=results,
        num_samples=num_saved_samples,
    )
    posterior._file.create_dataset("initial_state", data=initial_state)

    # We loop over successive calls to sample because we have to dump results
    # to disc, or else end up OOM (even on a 32GB system).
    final_results = None
    progress = tqdm.tqdm(
        range(num_bursts), unit_scale=burst_size * config["thin"]
    )
    for burst in progress:
        samples, results, final_results = sample(
            burst_size,
            init_state=current_state,
            thin=config["thin"] - 1,
            previous_results=final_results,
        )
        # The last draw of this burst is the starting point of the next one.
        current_state = [part[-1] for part in samples]
        print(current_state[0].numpy(), flush=True)

        write_offset = burst * burst_size
        start = perf_counter()
        posterior.write_samples(
            samples_by_name(samples),
            first_dim_offset=write_offset,
        )
        posterior.write_results(results, first_dim_offset=write_offset)
        elapsed = perf_counter() - start

        print("Storage time:", elapsed, "seconds")
        for k, v in results.items():
            acceptance = tf.reduce_mean(tf.cast(v["is_accepted"], tf.float32))
            print(f"Acceptance {k}:", acceptance)

    print(
        f"Acceptance theta: {posterior['results/block0/is_accepted'][:].mean()}"
    )
    print(f"Acceptance xi: {posterior['results/block1/is_accepted'][:].mean()}")
    print(
        f"Acceptance move S->E: {posterior['results/move/S->E/is_accepted'][:].mean()}"
    )
    print(
        f"Acceptance move E->I: {posterior['results/move/E->I/is_accepted'][:].mean()}"
    )
    print(
        f"Acceptance occult S->E: {posterior['results/occult/S->E/is_accepted'][:].mean()}"
    )
    print(
        f"Acceptance occult E->I: {posterior['results/occult/E->I/is_accepted'][:].mean()}"
    )

    del posterior
345
346
347
348


if __name__ == "__main__":
    # Command-line entry point: read the YAML config and run the sampler.
    from argparse import ArgumentParser

    cli = ArgumentParser(description="Run MCMC inference algorithm")
    # Both named options are mandatory strings and differ only in their
    # flags and help text.
    for short_flag, long_flag, help_text in (
        ("-c", "--config", "Config file"),
        ("-o", "--output", "Output file"),
    ):
        cli.add_argument(
            short_flag, long_flag, type=str, help=help_text, required=True
        )
    cli.add_argument("data_file", type=str, help="Data NetCDF file")
    args = cli.parse_args()

    # NOTE(review): FullLoader can build more than plain data types; if the
    # config file may come from an untrusted source, consider yaml.safe_load.
    with open(args.config, "r") as config_stream:
        config = yaml.load(config_stream, Loader=yaml.FullLoader)

    # Only the "Mcmc" section of the config file is passed to the sampler.
    mcmc(args.data_file, args.output, config["Mcmc"])