mcmc.py 13.8 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
"""MCMC Test Rig for COVID-19 UK model"""
2
import optparse
Chris Jewell's avatar
Chris Jewell committed
3
import os
4
import pickle as pkl
5
from collections import OrderedDict
Chris Jewell's avatar
Chris Jewell committed
6

Chris Jewell's avatar
Chris Jewell committed
7
import h5py
8
import numpy as np
9
10
import tensorflow as tf
import tensorflow_probability as tfp
Chris Jewell's avatar
Chris Jewell committed
11
12
13
import tqdm
import yaml

14
15
from covid import config
from covid.model import load_data, CovidUKStochastic
Chris Jewell's avatar
Chris Jewell committed
16
from covid.util import sanitise_parameter, sanitise_settings
17
from covid.impl.mcmc import UncalibratedLogRandomWalk, random_walk_mvnorm_fn
18
from covid.impl.event_time_mh import UncalibratedEventTimesUpdate
19
from covid.impl.occult_events_mh import UncalibratedOccultUpdate
20

Chris Jewell's avatar
Chris Jewell committed
21
22
23
###########
# TF Bits #
###########

# Short aliases for the TensorFlow Probability sub-modules used below.
tfd = tfp.distributions
tfb = tfp.bijectors

# Working floating-point dtype, taken from the project's ``covid.config`` module.
DTYPE = config.floatX

# Debug toggles (uncomment to force CPU / dump XLA HLO passes):
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# os.environ["XLA_FLAGS"] = '--xla_dump_to=xla_dump --xla_dump_hlo_pass_re=".*"'

# Report whether TensorFlow found a GPU device.
print("Using GPU" if tf.test.gpu_device_name() else "Using CPU")

# Read in settings
parser = optparse.OptionParser()
parser.add_option(
    "--config",
    "-c",
    dest="config",
    default="ode_config.yaml",
    help="configuration file",
)
options, args = parser.parse_args()
print("Loading config file:", options.config)

# NB: from here on ``config`` is the parsed YAML mapping.  This rebinds the
# name that previously referred to the imported ``covid.config`` module,
# which has already been used above to set DTYPE.
# yaml.safe_load: a configuration file needs only plain YAML types, and
# yaml.load() without an explicit Loader is unsafe (and a TypeError in
# PyYAML >= 6).
with open(options.config, "r") as f:
    config = yaml.safe_load(f)

print("Config:", config)

# Model parameters, converted to TF constants of the working dtype.
param = sanitise_parameter(config["parameter"])
param = {k: tf.constant(v, dtype=DTYPE) for k, v in param.items()}

settings = sanitise_settings(config["settings"])

data = load_data(config["data"], settings, DTYPE)
# Sum rows that share the same level-0 index label.  ``.sum(level=...)`` was
# deprecated in pandas 1.3 and removed in 2.0; ``groupby(level=...).sum()``
# is the documented equivalent.
data["pop"] = data["pop"].groupby(level=0).sum()

model = CovidUKStochastic(
    C=data["C"],
    N=data["pop"]["n"].to_numpy(),
    W=data["W"],
    date_range=settings["inference_period"],
    holidays=settings["holiday"],
    lockdown=settings["lockdown"],
    time_step=1.0,
)

# Load data
# NB(review): pickle.load can execute arbitrary code -- only load trusted files.
with open("stochastic_sim_covid1.pkl", "rb") as f:
    example_sim = pkl.load(f)

event_tensor = example_sim["events"]  # shape [T, M, S, S]
event_tensor = event_tensor[:60, ...]  # keep only the first 60 time steps
num_times = event_tensor.shape[0]
num_meta = event_tensor.shape[1]
state_init = example_sim["state_init"]

# Only the I->R (state 2 -> 3) events are taken from the simulation; the S->E
# and E->I events are synthesised from them below.  (Extracting the
# simulated S->E / E->I slices here was dead code: they were overwritten.)
ir_events = event_tensor[:, :, 2, 3]  # 2-D: [T, M]

# Prepend 4 empty time steps, then place each E->I event 2 steps and each
# S->E event 4 steps before the corresponding I->R event.  np.roll wraps
# around, so the wrapped-in tail rows are explicitly zeroed.
ir_events = np.pad(ir_events, ((4, 0), (0, 0)), mode="constant", constant_values=0.0)
ei_events = np.roll(ir_events, shift=-2, axis=0)
se_events = np.roll(ir_events, shift=-4, axis=0)
ei_events[-2:, ...] = 0.0
se_events[-4:, ...] = 0.0

##########################
# Log p and MCMC kernels #
##########################


def logp(par, events, occult_events):
    """Unnormalised log posterior density.

    :param par: parameter vector; ``par[0]`` is used as ``beta1`` and
        ``par[1]`` as ``gamma``.  All other model parameters (e.g. beta2,
        beta3) keep their values from the module-level ``param`` dict.
    :param events: event tensor passed (after adding ``occult_events``) to
        ``model.log_prob``.
    :param occult_events: occult events, added element-wise to ``events``.
    :returns: Gamma(1, 1) log-prior of ``beta1`` + Gamma(100, 400) log-prior
        of ``gamma`` + ``model.log_prob(events + occult_events, p, state_init)``.
    """
    # Copy rather than alias the module-level dict: assigning the (possibly
    # symbolic, when traced inside tf.function) beta1/gamma values into the
    # shared ``param`` would leak them into global state.
    p = dict(param)
    p["beta1"] = tf.convert_to_tensor(par[0], dtype=DTYPE)
    p["gamma"] = tf.convert_to_tensor(par[1], dtype=DTYPE)

    beta1_logp = tfd.Gamma(
        concentration=tf.constant(1.0, dtype=DTYPE), rate=tf.constant(1.0, dtype=DTYPE)
    ).log_prob(p["beta1"])
    gamma_logp = tfd.Gamma(
        concentration=tf.constant(100.0, dtype=DTYPE),
        rate=tf.constant(400.0, dtype=DTYPE),
    ).log_prob(p["gamma"])

    with tf.name_scope("epidemic_log_posterior"):
        y_logp = model.log_prob(events + occult_events, p, state_init)

    return beta1_logp + gamma_logp + y_logp


# Pavel's suggestion for a Gibbs kernel requires
# kernel factory functions.
def make_parameter_kernel(scale, bounded_convergence):
    """Build a factory for the parameter-update Metropolis-Hastings kernel.

    The returned callable accepts a target log-density function and returns a
    ``tfp.mcmc.MetropolisHastings`` kernel named ``"parameter_update"`` that
    wraps an ``UncalibratedLogRandomWalk``.  Proposals are generated by
    ``random_walk_mvnorm_fn(scale, p_u=bounded_convergence)``.
    """

    def kernel_func(logp):
        proposal_fn = random_walk_mvnorm_fn(scale, p_u=bounded_convergence)
        walker = UncalibratedLogRandomWalk(
            target_log_prob_fn=logp, new_state_fn=proposal_fn
        )
        return tfp.mcmc.MetropolisHastings(inner_kernel=walker, name="parameter_update")

    return kernel_func


def make_events_step(target_event_id, prev_event_id=None, next_event_id=None):
    """Build a factory for the event-time Metropolis-Hastings kernel.

    The returned callable accepts a target log-density function and returns a
    ``tfp.mcmc.MetropolisHastings`` kernel named ``"event_update"`` wrapping
    ``UncalibratedEventTimesUpdate`` for event type ``target_event_id``, with
    neighbouring event types ``prev_event_id`` / ``next_event_id``.  The
    tuning values ``dmax``, ``m`` (passed as ``mmax``) and ``nmax`` are read
    from ``config["mcmc"]`` when the kernel is built.
    """

    def kernel_func(logp):
        mcmc_cfg = config["mcmc"]
        event_mover = UncalibratedEventTimesUpdate(
            target_log_prob_fn=logp,
            target_event_id=target_event_id,
            prev_event_id=prev_event_id,
            next_event_id=next_event_id,
            initial_state=state_init,
            dmax=mcmc_cfg["dmax"],
            mmax=mcmc_cfg["m"],
            nmax=mcmc_cfg["nmax"],
        )
        return tfp.mcmc.MetropolisHastings(inner_kernel=event_mover, name="event_update")

    return kernel_func


def make_occults_step(target_event_id, window=21):
    """Build a factory for the occult-event Metropolis-Hastings kernel.

    The returned callable accepts a target log-density function and returns a
    ``tfp.mcmc.MetropolisHastings`` kernel named ``"occult_update"`` wrapping
    ``UncalibratedOccultUpdate`` for event type ``target_event_id``, with
    ``nmax`` read from ``config["mcmc"]["occult_nmax"]``.

    :param target_event_id: index of the event type whose occults are updated.
    :param window: number of trailing time steps, passed as
        ``t_range=[T - window, T]`` with ``T = se_events.shape[0]``.
        Defaults to 21, the previously hard-coded value.  Presumably the time
        range in which occult events may be added/removed -- confirm against
        ``UncalibratedOccultUpdate``.
    """

    def kernel_func(logp):
        t_end = se_events.shape[0]
        return tfp.mcmc.MetropolisHastings(
            inner_kernel=UncalibratedOccultUpdate(
                target_log_prob_fn=logp,
                target_event_id=target_event_id,
                nmax=config["mcmc"]["occult_nmax"],
                t_range=[t_end - window, t_end],
            ),
            name="occult_update",
        )

    return kernel_func


def is_accepted(result):
    """Return the ``is_accepted`` flag of a possibly nested kernel result.

    Walks down the chain of ``inner_results`` attributes until reaching an
    object that has an ``is_accepted`` attribute, and returns that flag cast
    to ``DTYPE``.
    """
    node = result
    while not hasattr(node, "is_accepted"):
        node = node.inner_results
    return tf.cast(node.is_accepted, DTYPE)


def trace_results_fn(results):
    """Flatten one Metropolis-Hastings step's results into a trace vector.

    The vector is ``[proposed target_log_prob, accepted, log_acceptance_correction]``.
    If the proposed results carry an ``extra`` field, its values are appended,
    cast to the dtype of the proposed log-probability.
    """
    proposed_results = results.proposed_results
    log_prob = proposed_results.target_log_prob
    pieces = [
        [log_prob],
        [is_accepted(results)],
        [proposed_results.log_acceptance_correction],
    ]
    if hasattr(proposed_results, "extra"):
        pieces.append(tf.cast(proposed_results.extra, log_prob.dtype))
    return tf.concat(pieces, axis=0)


def forward_results(prev_results, next_results):
    """Carry the current log-density from one kernel's results to the next.

    Returns a copy of ``next_results`` whose ``accepted_results.target_log_prob``
    is replaced by ``prev_results.accepted_results.target_log_prob``; every
    other field is left as in ``next_results``.  Neither input is modified
    (``_replace`` builds new tuples).
    """
    carried_log_prob = prev_results.accepted_results.target_log_prob
    patched_accepted = next_results.accepted_results._replace(
        target_log_prob=carried_log_prob
    )
    return next_results._replace(accepted_results=patched_accepted)


@tf.function(autograph=False, experimental_compile=True)
def sample(n_samples, init_state, par_scale, num_event_updates):
    """Run ``n_samples`` iterations of the Metropolis-within-Gibbs sampler.

    One iteration is a parameter update followed by ``num_event_updates``
    sweeps of: S->E event-time move, E->I event-time move, S->E occult update
    and E->I occult update.

    :param n_samples: number of iterations to run and record.
    :param init_state: list ``[parameters, events, occult_events]``.  The list
        is shallow-copied, so the caller's list is not mutated.
    :param par_scale: proposal scale passed to ``make_parameter_kernel``.
    :param num_event_updates: number of event/occult sweeps per iteration.
    :returns: ``(samples, results)``: a list with one stacked tensor per state
        component (leading dimension ``n_samples``), and a list of five stacked
        ``trace_results_fn`` traces in kernel order: parameter, S->E move,
        E->I move, S->E occult, E->I occult.
    """
    with tf.name_scope("main_mcmc_sample_loop"):
        init_state = init_state.copy()
        par_func = make_parameter_kernel(par_scale, 0.95)
        se_func = make_events_step(0, None, 1)
        ei_func = make_events_step(1, 0, 2)
        se_occult = make_occults_step(0)
        ei_occult = make_occults_step(1)

        # Based on Gibbs idea posted by Pavel Sountsov
        # https://github.com/tensorflow/probability/issues/495
        results = [
            par_func(lambda p: logp(p, init_state[1], init_state[2])).bootstrap_results(
                init_state[0]
            ),
            se_func(lambda s: logp(init_state[0], s, init_state[2])).bootstrap_results(
                init_state[1]
            ),
            ei_func(lambda s: logp(init_state[0], s, init_state[2])).bootstrap_results(
                init_state[1]
            ),
            se_occult(
                lambda s: logp(init_state[0], init_state[1], s)
            ).bootstrap_results(init_state[2]),
            ei_occult(
                lambda s: logp(init_state[0], init_state[1], s)
            ).bootstrap_results(init_state[2]),
        ]

        samples_arr = [tf.TensorArray(s.dtype, size=n_samples) for s in init_state]
        results_arr = [tf.TensorArray(DTYPE, size=n_samples) for r in range(5)]

        def body(i, state, results, sample_accum, results_accum):
            # Parameters
            def par_logp(par_state):
                state[0] = par_state  # close over state from outer scope
                return logp(*state)

            # Each kernel starts from the cached target_log_prob of the kernel
            # that ran immediately before it.  The previous iteration ended
            # with the E->I occult update (or, if num_event_updates == 0, with
            # results[4] forwarded from the parameter step), so take it from
            # results[4].  results[2] (the E->I move) would be stale once an
            # occult update has been accepted.
            state[0], results[0] = par_func(par_logp).one_step(
                state[0], forward_results(results[4], results[0])
            )

            # States
            results[4] = forward_results(results[0], results[4])

            def infec_body(j, state, results):
                def state_logp(event_state):
                    state[1] = event_state
                    return logp(*state)

                def occult_logp(occult_state):
                    state[2] = occult_state
                    return logp(*state)

                state[1], results[1] = se_func(state_logp).one_step(
                    state[1], forward_results(results[4], results[1])
                )
                state[1], results[2] = ei_func(state_logp).one_step(
                    state[1], forward_results(results[1], results[2])
                )
                state[2], results[3] = se_occult(occult_logp).one_step(
                    state[2], forward_results(results[2], results[3])
                )
                state[2], results[4] = ei_occult(occult_logp).one_step(
                    state[2], forward_results(results[3], results[4])
                )
                j += 1
                return j, state, results

            def infec_cond(j, state, results):
                return j < num_event_updates

            _, state, results = tf.while_loop(
                infec_cond,
                infec_body,
                loop_vars=[tf.constant(0, tf.int32), state, results],
            )

            sample_accum = [sample_accum[k].write(i, s) for k, s in enumerate(state)]
            results_accum = [
                results_accum[k].write(i, trace_results_fn(r))
                for k, r in enumerate(results)
            ]
            return i + 1, state, results, sample_accum, results_accum

        def cond(i, _1, _2, _3, _4):
            return i < n_samples

        _1, _2, _3, samples, results = tf.while_loop(
            cond=cond,
            body=body,
            loop_vars=[0, init_state, results, samples_arr, results_arr],
        )

        return [s.stack() for s in samples], [r.stack() for r in results]


##################
# MCMC loop here #
##################

# MCMC Control: run lengths read from the "mcmc" section of the YAML config.
_mcmc_cfg = config["mcmc"]
NUM_BURSTS = _mcmc_cfg["num_bursts"]
NUM_BURST_SAMPLES = _mcmc_cfg["num_burst_samples"]
NUM_EVENT_TIME_UPDATES = _mcmc_cfg["num_event_time_updates"]

# RNG stuff: fixed global seed for reproducibility.
tf.random.set_seed(2)

# Initial state.  NB [M, T, X] layout for events: the three [T, M] event
# arrays (S->E, E->I, I->R) are stacked on a new last axis, giving [T, M, 3],
# and the first two axes are then swapped.
_stacked_events = tf.stack([se_events, ei_events, ir_events], axis=-1)
events = tf.transpose(_stacked_events, perm=(1, 0, 2))
# Starting parameter vector; logp reads it as [beta1, gamma].
_init_params = np.array([0.6, 0.25], dtype=DTYPE)
current_state = [_init_params, events, tf.zeros_like(events)]

# Output Files
# The HDF5 file is opened with a 2 GiB raw-data chunk cache (rdcc_nbytes).
posterior = h5py.File(
    os.path.expandvars(config["output"]["posterior"]), "w", rdcc_nbytes=1024 ** 3 * 2,
)
total_samples = NUM_BURSTS * NUM_BURST_SAMPLES
event_size = [total_samples] + list(current_state[1].shape)
# Chunk up to 10 samples along the sample axis.  h5py rejects a chunk shape
# larger than a fixed-size dataset's shape, so clamp it for short runs.
event_chunks = (min(10, total_samples),) + tuple(current_state[1].shape)

par_samples = posterior.create_dataset(
    "samples/parameter",
    [total_samples, current_state[0].shape[0]],
    dtype=np.float64,
)
event_samples = posterior.create_dataset(
    "samples/events",
    event_size,
    dtype=DTYPE,
    chunks=event_chunks,
    compression="gzip",
    compression_opts=1,
)
occult_samples = posterior.create_dataset(
    "samples/occults",
    event_size,
    dtype=DTYPE,
    chunks=event_chunks,
    compression="gzip",
    compression_opts=1,
)

# Per-kernel traces, one row per sample.  The first 3 columns are the
# [log_prob, accepted, q_ratio] values produced by trace_results_fn; any further
# columns hold that kernel's ``extra`` proposal values.  NOTE(review): the
# extra widths (model.N.shape[0] for moves, 3 for occults) are presumably the
# sizes of those extras -- confirm against the update kernels.
output_results = [
    posterior.create_dataset(
        "results/parameter", (total_samples, 3), dtype=DTYPE,
    ),
    posterior.create_dataset(
        "results/move/S->E",
        (total_samples, 3 + model.N.shape[0]),
        dtype=DTYPE,
    ),
    posterior.create_dataset(
        "results/move/E->I",
        (total_samples, 3 + model.N.shape[0]),
        dtype=DTYPE,
    ),
    posterior.create_dataset(
        "results/occult/S->E", (total_samples, 6), dtype=DTYPE
    ),
    posterior.create_dataset(
        "results/occult/E->I", (total_samples, 6), dtype=DTYPE
    ),
]

print("Initial logpi:", logp(*current_state))
# Initial proposal scale for the parameter random walk: an identity matrix
# with one row/column per parameter.
_par0 = current_state[0]
par_scale = tf.linalg.diag(tf.ones(_par0.shape, dtype=_par0.dtype) * 1.0)

# We loop over successive calls to sample because we have to dump results
#   to disc, or else end OOM (even on a 32GB system).
# with tf.profiler.experimental.Profile("/tmp/tf_logdir"):
for i in tqdm.tqdm(range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES):
    samples, results = sample(
        NUM_BURST_SAMPLES,
        init_state=current_state,
        par_scale=par_scale,
        num_event_updates=tf.constant(NUM_EVENT_TIME_UPDATES, tf.int32),
    )
    # The final draw of this burst seeds the next one.
    current_state = [s[-1] for s in samples]
    burst_slice = slice(
        i * NUM_BURST_SAMPLES, i * NUM_BURST_SAMPLES + NUM_BURST_SAMPLES
    )
    par_samples[burst_slice, ...] = samples[0].numpy()

    # Adapt the parameter proposal from the covariance of the log-parameters
    # over every draw so far (read back from the HDF5 dataset).
    cov = np.cov(
        np.log(par_samples[: (i * NUM_BURST_SAMPLES + NUM_BURST_SAMPLES), ...]),
        rowvar=False,
    )
    print(current_state[0].numpy())
    print(cov)
    # np.cov is non-finite with too few draws (e.g. one); keep the old scale.
    if np.all(np.isfinite(cov)):
        par_scale = 2.0 ** 2 * cov / 2.0

    event_samples[burst_slice, ...] = samples[1].numpy()
    occult_samples[burst_slice, ...] = samples[2].numpy()
    # Use ``k`` here: reusing ``i`` would shadow the burst counter.
    for k, ro in enumerate(output_results):
        ro[burst_slice, ...] = results[k].numpy()

    # Per-burst acceptance rates (column 1 of each trace is the accept flag).
    for label, res in zip(
        ("par", "move S->E", "move E->I", "occult S->E", "occult E->I"), results
    ):
        print(f"Acceptance {label}:", tf.reduce_mean(tf.cast(res[:, 1], tf.float32)))

# Overall acceptance rates across all bursts (column 1 of each stored trace),
# then close the HDF5 output file.
for label, ds in zip(
    ("param", "move S->E", "move E->I", "occult S->E", "occult E->I"),
    output_results,
):
    print(f"Acceptance {label}: {ds[:, 1].mean()}")

posterior.close()