inference.py 12.9 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
"""MCMC Test Rig for COVID-19 UK model"""
2
3
# pylint: disable=E402

Chris Jewell's avatar
Chris Jewell committed
4
import os
5
from time import perf_counter
6
7
import tqdm
import yaml
8
import numpy as np
9
10
import tensorflow as tf
import tensorflow_probability as tfp
Chris Jewell's avatar
Chris Jewell committed
11

12
from covid.data import AreaCodeData
13
14
15
16
17
18
from gemlib.util import compute_state
from gemlib.mcmc import UncalibratedEventTimesUpdate
from gemlib.mcmc import UncalibratedOccultUpdate, TransitionTopology
from gemlib.mcmc import GibbsKernel
from gemlib.mcmc import MultiScanKernel
from gemlib.mcmc import AdaptiveRandomWalkMetropolis
Chris Jewell's avatar
Chris Jewell committed
19
from gemlib.mcmc import Posterior
20

21
from covid.data import read_phe_cases
Chris Jewell's avatar
Chris Jewell committed
22
from covid.cli_arg_parse import cli_args
23

24
import covid.model_spec as model_spec
25

26
27
# Short aliases for the TensorFlow Probability distributions and bijectors
# submodules.
tfd = tfp.distributions
tfb = tfp.bijectors
28
# dtype taken from the model specification; arrays built in this file are
# created with (or cast to) it.
DTYPE = model_spec.DTYPE
29
30


31
32
def run_mcmc(config):
    """Constructs and runs the MCMC

    Loads covariates and case data, derives the initial state and the event
    tensor for the inference period, builds a Metropolis-within-Gibbs
    sampler over two parameter blocks plus the event tensor, then runs the
    sampler in bursts, writing every burst to the posterior output file and
    printing acceptance rates.

    :param config: nested mapping of settings.  Keys read directly here:
        ``config["Global"]["inference_period"]`` (two dates);
        ``config["data"]`` -> ``reported_cases``, ``case_date_type``,
        ``pillar``;
        ``config["mcmc"]`` -> ``dmax``, ``m``, ``nmax``, ``occult_nmax``,
        ``num_event_time_updates``, ``num_bursts``, ``num_burst_samples``,
        ``thin``;
        ``config["output"]`` -> ``results_dir``, ``posterior``.
        The whole mapping is also handed to ``model_spec.read_covariates``
        and dumped as YAML into the output file.
    :returns: None.  Results are written to the file named by
        ``config["output"]``.
    """
    if tf.test.gpu_device_name():
        print("Using GPU")
    else:
        print("Using CPU")

    inference_period = [
        np.datetime64(x) for x in config["Global"]["inference_period"]
    ]

    covar_data = model_spec.read_covariates(config)

    # We load in cases and impute missing infections first, since this sets the
    # time epoch which we are analysing.
    cases = read_phe_cases(
        config["data"]["reported_cases"],
        date_low=inference_period[0],
        date_high=inference_period[1],
        date_type=config["data"]["case_date_type"],
        pillar=config["data"]["pillar"],
    ).astype(DTYPE)

    # Impute censored events, return cases
    events = model_spec.impute_censored_events(cases)

    # Initial conditions are calculated by calculating the state
    # at the beginning of the inference period
    #
    # Imputed censored events that pre-date the first I-R events
    # in the cases dataset are discarded.  They are only used
    # to set up a sensible initial state.
    state = compute_state(
        initial_state=tf.concat(
            [covar_data["N"][:, tf.newaxis], tf.zeros_like(events[:, 0, :])],
            axis=-1,
        ),
        events=events,
        stoichiometry=model_spec.STOICHIOMETRY,
    )
    # The imputed event tensor is longer than the observed cases along
    # axis 1; the difference is the offset at which the cases begin.
    start_time = state.shape[1] - cases.shape[1]
    initial_state = state[:, start_time, :]
    events = events[:, start_time:, :]

    ########################################################
    # Construct the MCMC kernels #
    ########################################################
    model = model_spec.CovidUK(
        covariates=covar_data,
        initial_state=initial_state,
        initial_step=0,
        num_steps=events.shape[1],
    )

    def joint_log_prob(block0, block1, events):
        """Returns ``model.log_prob`` for the three sampler state parts.

        ``block0`` is laid out as ``[beta2, gamma0, gamma1, sigma, beta3...]``
        and ``block1`` as ``[beta1, xi...]``; ``events`` is passed through
        as ``seir``.
        """
        return model.log_prob(
            dict(
                beta2=block0[0],
                gamma0=block0[1],
                gamma1=block0[2],
                sigma=block0[3],
                beta3=block0[4:],
                beta1=block1[0],
                xi=block1[1:],
                seir=events,
            )
        )

    # Build Metropolis within Gibbs sampler
    #
    # Kernels are:
    #     Q(\theta, \theta^\prime)
    #     Q(\xi, \xi^\prime)
    #     Q(Z^{se}, Z^{se\prime}) (partially-censored)
    #     Q(Z^{ei}, Z^{ei\prime}) (partially-censored)
    #     Q(Z^{se}, Z^{se\prime}) (occult)
    #     Q(Z^{ei}, Z^{ei\prime}) (occult)
    def make_blk0_kernel(shape, name):
        """Returns a kernel factory for parameter block 0.

        The factory wraps an adaptive random walk Metropolis kernel in a
        ``TransformedTransitionKernel``.  The blockwise bijector applies
        ``Exp`` to elements 0 and 3 (``beta2`` and ``sigma`` in
        ``joint_log_prob``) and ``Identity`` to the remaining six.

        :param shape: shape of the block; ``shape[0]`` sizes the initial
            proposal covariance.
        :param name: name given to the transformed kernel.
        """

        def fn(target_log_prob_fn, _):
            # The second positional argument supplied by the caller is unused.
            return tfp.mcmc.TransformedTransitionKernel(
                inner_kernel=AdaptiveRandomWalkMetropolis(
                    target_log_prob_fn=target_log_prob_fn,
                    initial_covariance=np.eye(shape[0], dtype=DTYPE) * 1e-1,
                    covariance_burnin=200,
                ),
                bijector=tfb.Blockwise(
                    bijectors=[
                        tfb.Exp(),
                        tfb.Identity(),
                        tfb.Exp(),
                        tfb.Identity(),
                    ],
                    block_sizes=[1, 2, 1, 4],
                ),
                name=name,
            )

        return fn

    def make_blk1_kernel(shape, name):
        """Returns a kernel factory for parameter block 1.

        The factory builds an untransformed adaptive random walk Metropolis
        kernel.

        :param shape: shape of the block; ``shape[0]`` sizes the initial
            proposal covariance.
        :param name: name given to the kernel.
        """

        def fn(target_log_prob_fn, _):
            # The second positional argument supplied by the caller is unused.
            return AdaptiveRandomWalkMetropolis(
                target_log_prob_fn=target_log_prob_fn,
                initial_covariance=np.eye(shape[0], dtype=DTYPE) * 1e-1,
                covariance_burnin=200,
                name=name,
            )

        return fn

    def make_partially_observed_step(
        target_event_id, prev_event_id=None, next_event_id=None, name=None
    ):
        """Returns a kernel factory for an event-times update.

        The factory wraps ``UncalibratedEventTimesUpdate`` in a
        Metropolis-Hastings kernel, configured from ``config["mcmc"]``
        (``dmax``, ``m``, ``nmax``) and the enclosing ``initial_state``.
        """

        def fn(target_log_prob_fn, _):
            # The second positional argument supplied by the caller is unused.
            return tfp.mcmc.MetropolisHastings(
                inner_kernel=UncalibratedEventTimesUpdate(
                    target_log_prob_fn=target_log_prob_fn,
                    target_event_id=target_event_id,
                    prev_event_id=prev_event_id,
                    next_event_id=next_event_id,
                    initial_state=initial_state,
                    dmax=config["mcmc"]["dmax"],
                    mmax=config["mcmc"]["m"],
                    nmax=config["mcmc"]["nmax"],
                ),
                name=name,
            )

        return fn

    def make_occults_step(prev_event_id, target_event_id, next_event_id, name):
        """Returns a kernel factory for an occult-event update.

        The factory wraps ``UncalibratedOccultUpdate`` in a
        Metropolis-Hastings kernel.  ``t_range`` is set to the last 21 steps
        of the event tensor's axis 1.
        """

        def fn(target_log_prob_fn, _):
            # The second positional argument supplied by the caller is unused.
            return tfp.mcmc.MetropolisHastings(
                inner_kernel=UncalibratedOccultUpdate(
                    target_log_prob_fn=target_log_prob_fn,
                    topology=TransitionTopology(
                        prev_event_id, target_event_id, next_event_id
                    ),
                    cumulative_event_offset=initial_state,
                    nmax=config["mcmc"]["occult_nmax"],
                    t_range=(events.shape[1] - 21, events.shape[1]),
                    name=name,
                ),
                name=name,
            )

        return fn

    def make_event_multiscan_kernel(target_log_prob_fn, _):
        """Builds the event-tensor kernel.

        A ``MultiScanKernel`` configured with
        ``config["mcmc"]["num_event_time_updates"]`` wraps an inner Gibbs
        kernel holding the four event moves (S->E, E->I, and their occult
        counterparts).
        """
        return MultiScanKernel(
            config["mcmc"]["num_event_time_updates"],
            GibbsKernel(
                target_log_prob_fn=target_log_prob_fn,
                kernel_list=[
                    (0, make_partially_observed_step(0, None, 1, "se_events")),
                    (0, make_partially_observed_step(1, 0, 2, "ei_events")),
                    (0, make_occults_step(None, 0, 1, "se_occults")),
                    (0, make_occults_step(0, 1, 2, "ei_occults")),
                ],
                name="gibbs1",
            ),
        )

    # MCMC tracing functions
    def trace_results_fn(_, results):
        """Packs results into a dictionary"""
        results_dict = {}
        res0 = results.inner_results

        # Block 0 is wrapped in a transformed kernel, hence the extra
        # ``inner_results`` hop compared with block 1.
        results_dict["block0"] = {
            "is_accepted": res0[0].inner_results.is_accepted,
            "target_log_prob": res0[
                0
            ].inner_results.accepted_results.target_log_prob,
        }
        results_dict["block1"] = {
            "is_accepted": res0[1].is_accepted,
            "target_log_prob": res0[1].accepted_results.target_log_prob,
        }

        def get_move_results(results):
            """Extracts acceptance, log prob and proposal details of a move."""
            return {
                "is_accepted": results.is_accepted,
                "target_log_prob": results.accepted_results.target_log_prob,
                "proposed_delta": tf.stack(
                    [
                        results.accepted_results.m,
                        results.accepted_results.t,
                        results.accepted_results.delta_t,
                        results.accepted_results.x_star,
                    ]
                ),
            }

        # Entries follow the kernel_list order of the inner "gibbs1" kernel.
        res1 = res0[2].inner_results
        results_dict["move/S->E"] = get_move_results(res1[0])
        results_dict["move/E->I"] = get_move_results(res1[1])
        results_dict["occult/S->E"] = get_move_results(res1[2])
        results_dict["occult/E->I"] = get_move_results(res1[3])

        return results_dict

    # Build MCMC algorithm here.  This will be run in bursts for memory economy
    @tf.function(autograph=False, experimental_compile=True)
    def sample(n_samples, init_state, thin=0, previous_results=None):
        """Runs one burst of the Gibbs sampler.

        :param n_samples: number of samples to return.
        :param init_state: list ``[block0, block1, events]`` to start from.
        :param thin: passed to ``sample_chain`` as
            ``num_steps_between_results``.
        :param previous_results: kernel results of a previous burst, or
            ``None`` on the first call.
        :returns: tuple ``(samples, results, final_results)`` where
            ``results`` is the structure built by ``trace_results_fn`` and
            ``final_results`` are the final kernel results.
        """
        with tf.name_scope("main_mcmc_sample_loop"):

            init_state = init_state.copy()

            gibbs_schema = GibbsKernel(
                target_log_prob_fn=joint_log_prob,
                kernel_list=[
                    (0, make_blk0_kernel(init_state[0].shape, "block0")),
                    (1, make_blk1_kernel(init_state[1].shape, "block1")),
                    (2, make_event_multiscan_kernel),
                ],
                name="gibbs0",
            )

            samples, results, final_results = tfp.mcmc.sample_chain(
                n_samples,
                init_state,
                kernel=gibbs_schema,
                num_steps_between_results=thin,
                previous_kernel_results=previous_results,
                return_final_kernel_results=True,
                trace_fn=trace_results_fn,
            )

            return samples, results, final_results

    ###############################
    # Construct bursted MCMC loop #
    ###############################
    NUM_BURSTS = int(config["mcmc"]["num_bursts"])
    NUM_BURST_SAMPLES = int(config["mcmc"]["num_burst_samples"])
    NUM_SAVED_SAMPLES = NUM_BURST_SAMPLES * NUM_BURSTS

    # RNG stuff
    tf.random.set_seed(2)

    current_state = [
        np.array([0.6, 0.0, 0.0, 0.1, 0.0, 0.0, 0.0, 0.0], dtype=DTYPE),
        np.zeros(
            model.model["xi"](0.0, 0.1).event_shape[-1] + 1,
            dtype=DTYPE,
        ),
        events,
    ]
    print("Initial logpi:", joint_log_prob(*current_state))

    # Output file
    # A single-sample run provides the example samples and results handed
    # to the Posterior constructor below.
    samples, results, _ = sample(1, current_state)
    posterior = Posterior(
        os.path.join(
            os.path.expandvars(config["output"]["results_dir"]),
            config["output"]["posterior"],
        ),
        sample_dict={
            "beta2": (samples[0][:, 0], (NUM_BURST_SAMPLES,)),
            "gamma0": (samples[0][:, 1], (NUM_BURST_SAMPLES,)),
            "gamma1": (samples[0][:, 2], (NUM_BURST_SAMPLES,)),
            "sigma": (samples[0][:, 3], (NUM_BURST_SAMPLES,)),
            "beta3": (samples[0][:, 4:], (NUM_BURST_SAMPLES, 2)),
            "beta1": (samples[1][:, 0], (NUM_BURST_SAMPLES,)),
            "xi": (
                samples[1][:, 1:],
                (NUM_BURST_SAMPLES, samples[1].shape[1] - 1),
            ),
            "events": (samples[2], (NUM_BURST_SAMPLES, 64, 64, 1)),
        },
        results_dict=results,
        num_samples=NUM_SAVED_SAMPLES,
    )
    posterior._file.create_dataset("initial_state", data=initial_state)
    posterior._file.create_dataset("config", data=yaml.dump(config))

    # We loop over successive calls to sample because we have to dump results
    #   to disc, or else end OOM (even on a 32GB system).
    final_results = None
    for i in tqdm.tqdm(
        range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES * config["mcmc"]["thin"]
    ):
        samples, results, final_results = sample(
            NUM_BURST_SAMPLES,
            init_state=current_state,
            thin=config["mcmc"]["thin"] - 1,
            previous_results=final_results,
        )
        # The next burst restarts from the last draw of this one.
        current_state = [s[-1] for s in samples]
        print(current_state[0].numpy(), flush=True)

        burst_offset = i * NUM_BURST_SAMPLES
        start = perf_counter()
        posterior.write_samples(
            {
                "beta2": samples[0][:, 0],
                "gamma0": samples[0][:, 1],
                "gamma1": samples[0][:, 2],
                "sigma": samples[0][:, 3],
                "beta3": samples[0][:, 4:],
                "beta1": samples[1][:, 0],
                "xi": samples[1][:, 1:],
                "events": samples[2],
            },
            first_dim_offset=burst_offset,
        )
        posterior.write_results(results, first_dim_offset=burst_offset)
        end = perf_counter()

        print("Storage time:", end - start, "seconds")

        # ``results`` is the dict produced by ``trace_results_fn``; iterate
        # its items (iterating the dict itself yields only the string keys,
        # which cannot be unpacked into a key/value pair).
        for k, v in results.items():
            print(
                f"Acceptance {k}:",
                tf.reduce_mean(tf.cast(v["is_accepted"], tf.float32)),
            )

    # Overall acceptance rates, read back from the output file.
    acceptance_summaries = (
        ("theta", "block0"),
        ("xi", "block1"),
        ("move S->E", "move/S->E"),
        ("move E->I", "move/E->I"),
        ("occult S->E", "occult/S->E"),
        ("occult E->I", "occult/E->I"),
    )
    for label, group in acceptance_summaries:
        accepted = posterior[f"results/{group}/is_accepted"][:]
        print(f"Acceptance {label}: {accepted.mean()}")

    del posterior
370
371
372
373
374
375
376
377
378
379
380


if __name__ == "__main__":
    # Script entry point: parse the command line, load the YAML settings
    # file it names, then launch the sampler with those settings.
    args = cli_args()
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    run_mcmc(config)