"""MCMC Test Rig for COVID-19 UK model"""
import optparse
Chris Jewell's avatar
Chris Jewell committed
3
import os
4
from time import perf_counter
Chris Jewell's avatar
Chris Jewell committed
5

Chris Jewell's avatar
Chris Jewell committed
6
import h5py
7
import numpy as np
8
9
import tensorflow as tf
import tensorflow_probability as tfp
Chris Jewell's avatar
Chris Jewell committed
10
11
12
import tqdm
import yaml

13
14
from covid import config
from covid.model import load_data, CovidUKStochastic
15
16
from covid.pydata import phe_case_data
from covid.util import sanitise_parameter, sanitise_settings, impute_previous_cases
Chris Jewell's avatar
Chris Jewell committed
17
from covid.impl.util import compute_state
18
from covid.impl.mcmc import UncalibratedLogRandomWalk, random_walk_mvnorm_fn
19
from covid.impl.event_time_mh import UncalibratedEventTimesUpdate
20
from covid.impl.occult_events_mh import UncalibratedOccultUpdate, TransitionTopology
21
22
from covid.impl.gibbs import DeterministicScanKernel, GibbsStep, flatten_results
from covid.impl.multi_scan_kernel import MultiScanKernel
23

Chris Jewell's avatar
Chris Jewell committed
24
25
26
###########
# TF Bits #
###########
27

Chris Jewell's avatar
Chris Jewell committed
28
29
30
# Short aliases for the TensorFlow Probability namespaces used below.
tfd, tfb = tfp.distributions, tfp.bijectors

# Floating-point type used for every tensor created by this script.
DTYPE = config.floatX

# One row per transition, one column per state: each row is the change a
# single event of that transition applies to the state vector.
# NOTE(review): given the event ordering used later in this file the columns
# look like S, E, I, R -- confirm against the model definition.
STOICHIOMETRY = tf.constant(
    [
        [-1, 1, 0, 0],
        [0, -1, 1, 0],
        [0, 0, -1, 1],
    ],
    dtype=DTYPE,
)
33

34
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
Chris Jewell's avatar
Chris Jewell committed
35
# os.environ["XLA_FLAGS"] = '--xla_dump_to=xla_dump --xla_dump_hlo_pass_re=".*"'
Chris Jewell's avatar
Chris Jewell committed
36

37
38
39
40
41
# Report the device in use: an empty GPU device name is falsy, in which case
# the CPU message is printed.
print("Using GPU" if tf.test.gpu_device_name() else "Using CPU")

42
# Read in settings
43
44
45
46
47
# Command line: one optional flag naming the YAML configuration file.
parser = optparse.OptionParser()
parser.add_option(
    "--config",
    "-c",
    dest="config",
    default="example_config.yaml",
    help="configuration file",
)
options, cmd_args = parser.parse_args()
print("Loading config file:", options.config)

# NOTE: from here on ``config`` is the parsed YAML mapping and no longer the
# ``covid.config`` module imported at the top of the file; DTYPE has already
# been read from the module by this point.
# NOTE(review): FullLoader can build more than plain data types; prefer
# yaml.safe_load if the configuration file may come from an untrusted source.
with open(options.config, "r") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

settings = sanitise_settings(config["settings"])

# Every sanitised parameter is stored as a constant tensor of type DTYPE.
param = {
    name: tf.constant(value, dtype=DTYPE)
    for name, value in sanitise_parameter(config["parameter"]).items()
}

Chris Jewell's avatar
Chris Jewell committed
62
63
64
65
# Covariate data named in the "data" section of the configuration.
covar_data = load_data(config["data"], settings, DTYPE)

# Reported cases are read first because they fix the time epoch under
# analysis; the earlier events are then imputed backwards from them.
cases = phe_case_data(
    config["data"]["reported_cases"],
    date_range=settings["inference_period"],
    date_type="report",
)
ei_events, lag_ei = impute_previous_cases(cases, 0.44)
se_events, lag_se = impute_previous_cases(ei_events, 2.0)

# Zero-pad the front of the second axis of the later event series so that
# the three series can be stacked into one tensor.
ir_events = np.pad(cases, ((0, 0), (lag_ei + lag_se - 2, 0)))
ei_events = np.pad(ei_events, ((0, 0), (lag_se - 1, 0)))
events = tf.stack([se_events, ei_events, ir_events], axis=-1)

# The state at the start of the inference period is obtained by computing
# the whole state trajectory from a seed state made of the population sizes
# followed by zeros.
seed_state = tf.concat(
    [covar_data["pop"][:, tf.newaxis], tf.zeros_like(events[:, 0, :])], axis=-1
)
state = compute_state(
    initial_state=seed_state, events=events, stoichiometry=STOICHIOMETRY,
)

# Keep only the final ``cases.shape[1]`` time points of the trajectory.
start_time = state.shape[1] - cases.shape[1]
initial_state = state[:, start_time, :]
events = events[:, start_time:, :]

# xi holds one value per block of ``xi_freq`` time steps.
xi_freq = 14
num_xi = events.shape[1] // xi_freq
num_metapop = covar_data["pop"].shape[0]
Chris Jewell's avatar
Chris Jewell committed
90

91

92
# Create the epidemic model given parameters
93
def build_epidemic(param):
    """Build the stochastic epidemic model for one set of parameter values.

    :param param: mapping from parameter name to tensor; the entries
                  "beta1", "beta2", "xi", "nu" and "gamma" are read.
    :returns: a ``CovidUKStochastic`` model that starts from the module-level
              ``initial_state`` and runs for ``events.shape[1]`` steps of
              length 1.0.
    """

    def transition_rates(t, state):
        """Return the list of transition rates at time ``t`` given ``state``.

        The list has one entry per row of ``STOICHIOMETRY``: the infection
        rate, then the constant rates ``nu`` and ``gamma``.
        """
        C = tf.convert_to_tensor(covar_data["C"], dtype=DTYPE)
        # Symmetrise the matrix and zero its diagonal.
        C = tf.linalg.set_diag(C + tf.transpose(C), tf.zeros(C.shape[0], dtype=DTYPE))
        W = tf.constant(covar_data["W"], dtype=DTYPE)
        N = tf.constant(covar_data["pop"], dtype=DTYPE)

        # Clamp the time index so that times beyond the end of W reuse its
        # last entry.
        w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, W.shape[0] - 1)
        commute_volume = tf.gather(W, w_idx)

        # xi is constant within blocks of ``xi_freq`` time steps.  The
        # module-level ``xi_freq`` is used (rather than a literal 14) so the
        # block length always matches the index points that ``logp`` uses for
        # the xi prior.
        xi_idx = tf.cast(
            tf.clip_by_value(t // xi_freq, 0, param["xi"].shape[0] - 1),
            dtype=tf.int64,
        )
        xi = tf.gather(param["xi"], xi_idx)
        beta = param["beta1"] * tf.math.exp(xi)

        infec_rate = beta * (
            state[..., 2]
            + param["beta2"] * commute_volume * tf.linalg.matvec(C, state[..., 2] / N)
        )
        # The small constant keeps the rate strictly positive.
        infec_rate = infec_rate / N + 0.000000001  # Vector of length nc

        ei = tf.broadcast_to(
            [param["nu"]], shape=[state.shape[0]]
        )  # Vector of length nc
        ir = tf.broadcast_to(
            [param["gamma"]], shape=[state.shape[0]]
        )  # Vector of length nc

        return [infec_rate, ei, ir]

    return CovidUKStochastic(
        transition_rates=transition_rates,
        stoichiometry=STOICHIOMETRY,
        initial_state=initial_state,
        initial_step=0,
        time_delta=1.0,
        num_steps=events.shape[1],
    )

133

Chris Jewell's avatar
Chris Jewell committed
134
135
136
##########################
# Log p and MCMC kernels #
##########################
137
def logp(theta, xi, events):
    """Return the joint log density of the parameters and the events.

    :param theta: indexable whose first three entries are read as
                  (beta1, beta2, gamma)
    :param xi: values scored under the Gaussian-process prior and passed to
               the epidemic model as parameter "xi"
    :param events: event tensor handed to the epidemic model's ``log_prob``;
                   its second dimension sets the number of GP index points
    :returns: sum of the four prior log densities and the epidemic
              log-probability of ``events``
    """
    # Work on a shallow copy: writing into the module-level ``param`` would
    # change shared state on every call and, when called inside a
    # ``tf.function``, leave traced tensors behind in a global.
    p = dict(param)
    p["beta1"] = tf.convert_to_tensor(theta[0], dtype=DTYPE)
    p["beta2"] = tf.convert_to_tensor(theta[1], dtype=DTYPE)
    p["gamma"] = tf.convert_to_tensor(theta[2], dtype=DTYPE)
    p["xi"] = tf.convert_to_tensor(xi, dtype=DTYPE)

    beta1_prior = tfd.Gamma(
        concentration=tf.constant(1.0, dtype=DTYPE), rate=tf.constant(1.0, dtype=DTYPE)
    )

    # Gaussian-process prior on xi with a Matern 3/2 kernel, evaluated at one
    # index point per block of ``xi_freq`` time steps.
    sigma = tf.constant(0.01, dtype=DTYPE)
    phi = tf.constant(12.0, dtype=DTYPE)
    kernel = tfp.math.psd_kernels.MaternThreeHalves(sigma, phi)
    idx_pts = tf.cast(tf.range(events.shape[1] // xi_freq) * xi_freq, dtype=DTYPE)
    xi_prior = tfd.GaussianProcess(kernel, index_points=idx_pts[:, tf.newaxis])

    beta2_prior = tfd.Gamma(
        concentration=tf.constant(3.0, dtype=DTYPE), rate=tf.constant(10.0, dtype=DTYPE)
    )

    gamma_prior = tfd.Gamma(
        concentration=tf.constant(100.0, dtype=DTYPE),
        rate=tf.constant(400.0, dtype=DTYPE),
    )

    with tf.name_scope("epidemic_log_posterior"):
        seir = build_epidemic(p)

    return (
        beta1_prior.log_prob(p["beta1"])
        + xi_prior.log_prob(p["xi"])
        + beta2_prior.log_prob(p["beta2"])
        + gamma_prior.log_prob(p["gamma"])
        + seir.log_prob(events)
    )
173
174
175
176


# Pavel's suggestion for a Gibbs kernel requires
# kernel factory functions.
177
def make_theta_kernel(scale, bounded_convergence, name):
    """Return the Gibbs step built around a log random-walk proposal.

    :param scale: forwarded to ``random_walk_mvnorm_fn`` as its first argument
    :param bounded_convergence: forwarded to ``random_walk_mvnorm_fn`` as ``p_u``
    :param name: name given to the Gibbs step
    :returns: ``GibbsStep`` with index 0 (the position of theta in the state
              list built further down this file) wrapping a
              Metropolis-Hastings kernel targeting ``logp``
    """
    proposal = UncalibratedLogRandomWalk(
        target_log_prob_fn=logp,
        new_state_fn=random_walk_mvnorm_fn(scale, p_u=bounded_convergence),
    )
    mh_kernel = tfp.mcmc.MetropolisHastings(inner_kernel=proposal)
    return GibbsStep(0, mh_kernel, name=name)
188
189


190
def make_xi_kernel(scale, bounded_convergence, name):
    """Return the Gibbs step built around a random-walk Metropolis kernel.

    :param scale: forwarded to ``random_walk_mvnorm_fn`` as its first argument
    :param bounded_convergence: forwarded to ``random_walk_mvnorm_fn`` as ``p_u``
    :param name: name given to the Gibbs step
    :returns: ``GibbsStep`` with index 1 (the position of xi in the state
              list built further down this file) wrapping a
              ``RandomWalkMetropolis`` kernel targeting ``logp``
    """
    rwm_kernel = tfp.mcmc.RandomWalkMetropolis(
        target_log_prob_fn=logp,
        new_state_fn=random_walk_mvnorm_fn(scale, p_u=bounded_convergence),
    )
    return GibbsStep(1, rwm_kernel, name=name)
199
200


201
202
203
def make_events_step(
    target_event_id, prev_event_id=None, next_event_id=None, name=None
):
    """Return the Gibbs step that proposes updates to event times.

    :param target_event_id: passed to ``UncalibratedEventTimesUpdate``
    :param prev_event_id: passed to ``UncalibratedEventTimesUpdate``
    :param next_event_id: passed to ``UncalibratedEventTimesUpdate``
    :param name: name given to the Gibbs step
    :returns: ``GibbsStep`` with index 2 (the position of the events in the
              state list built further down this file) wrapping a
              Metropolis-Hastings kernel targeting ``logp``

    The proposal limits ``dmax``, ``mmax`` and ``nmax`` come from the
    "dmax", "m" and "nmax" entries of the "mcmc" configuration section.
    """
    mcmc_settings = config["mcmc"]
    proposal = UncalibratedEventTimesUpdate(
        target_log_prob_fn=logp,
        target_event_id=target_event_id,
        prev_event_id=prev_event_id,
        next_event_id=next_event_id,
        initial_state=initial_state,
        dmax=mcmc_settings["dmax"],
        mmax=mcmc_settings["m"],
        nmax=mcmc_settings["nmax"],
    )
    mh_kernel = tfp.mcmc.MetropolisHastings(inner_kernel=proposal)
    return GibbsStep(2, mh_kernel, name=name)
220
221


222
def make_occults_step(prev_event_id, target_event_id, next_event_id, name):
    """Return the Gibbs step that proposes updates to occult events.

    :param prev_event_id: first field of the ``TransitionTopology``
    :param target_event_id: second field of the ``TransitionTopology``
    :param next_event_id: third field of the ``TransitionTopology``
    :param name: name given both to the proposal and to the Gibbs step
    :returns: ``GibbsStep`` with index 2 wrapping a Metropolis-Hastings kernel
              targeting ``logp``

    The proposal's ``t_range`` covers the last 21 time steps of the
    module-level ``events`` tensor.
    """
    num_times = events.shape[1]
    proposal = UncalibratedOccultUpdate(
        target_log_prob_fn=logp,
        topology=TransitionTopology(prev_event_id, target_event_id, next_event_id),
        cumulative_event_offset=initial_state,
        nmax=config["mcmc"]["occult_nmax"],
        t_range=(num_times - 21, num_times),
        name=name,
    )
    mh_kernel = tfp.mcmc.MetropolisHastings(inner_kernel=proposal)
    return GibbsStep(2, mh_kernel, name=name)
239
240


241
def is_accepted(result):
    """Return the acceptance flag of a kernel result, cast to DTYPE.

    Descends through nested ``inner_results`` attributes until an object that
    has an ``is_accepted`` attribute is found.
    """
    while not hasattr(result, "is_accepted"):
        result = result.inner_results
    return tf.cast(result.is_accepted, DTYPE)
245
246


247
248
def trace_results_fn(_, results):
    """Returns log_prob, accepted, q_ratio

    ``results`` may be a single kernel result or an arbitrarily nested list of
    them; the returned value has the same nesting, with each result replaced
    by one 1-D tensor.  When a result's ``proposed_results`` has an ``extra``
    attribute, it is cast to the log-prob dtype and appended.
    """

    def summarise(result):
        proposed_results = result.proposed_results
        log_prob = proposed_results.target_log_prob
        parts = [
            [log_prob],
            [is_accepted(result)],
            [proposed_results.log_acceptance_correction],
        ]
        if hasattr(proposed_results, "extra"):
            parts.append(tf.cast(proposed_results.extra, log_prob.dtype))
        return tf.concat(parts, axis=0)

    def walk(node):
        if isinstance(node, list):
            return [walk(child) for child in node]
        return summarise(node)

    return walk(results)
265

266

Chris Jewell's avatar
Chris Jewell committed
267
@tf.function(autograph=False, experimental_compile=True)
def sample(n_samples, init_state, sigma_theta, sigma_xi):
    """Draw ``n_samples`` states from the combined Gibbs kernel.

    :param n_samples: number of samples requested from ``sample_chain``
    :param init_state: list holding the starting state; a copy is used
    :param sigma_theta: proposal scale passed to ``make_theta_kernel``
    :param sigma_xi: proposal scale passed to ``make_xi_kernel``
    :returns: tuple ``(samples, results)`` from ``tfp.mcmc.sample_chain``,
              with ``results`` traced by ``trace_results_fn``
    """
    with tf.name_scope("main_mcmc_sample_loop"):

        init_state = init_state.copy()

        theta_step = make_theta_kernel(sigma_theta, 1.0, "theta_kernel")
        xi_step = make_xi_kernel(sigma_xi, 1.0, "xi_kernel")

        # The event-time and occult steps form one scan that is repeated the
        # configured number of times per iteration of the outer kernel.
        num_event_time_updates = config["mcmc"]["num_event_time_updates"]
        event_scan = DeterministicScanKernel(
            [
                make_events_step(0, None, 1, "se_events"),
                make_events_step(1, 0, 2, "ei_events"),
                make_occults_step(None, 0, 1, "se_occults"),
                make_occults_step(0, 1, 2, "ei_occults"),
            ]
        )
        repeated_event_scan = MultiScanKernel(num_event_time_updates, event_scan)

        kernel = DeterministicScanKernel(
            [theta_step, xi_step, repeated_event_scan], name="gibbs_kernel",
        )

        samples, results = tfp.mcmc.sample_chain(
            n_samples, init_state, kernel=kernel, trace_fn=trace_results_fn
        )

        return samples, results
297
298


Chris Jewell's avatar
Chris Jewell committed
299
300
301
302
##################
# MCMC loop here #
##################

303
# MCMC control: run sizes taken from the "mcmc" section of the configuration.
NUM_BURSTS, NUM_BURST_SAMPLES, NUM_EVENT_TIME_UPDATES = (
    config["mcmc"][key]
    for key in ("num_bursts", "num_burst_samples", "num_event_time_updates")
)
# Number of samples kept per burst after thinning, and in total.
THIN_BURST_SAMPLES = NUM_BURST_SAMPLES // config["mcmc"]["thin"]
NUM_SAVED_SAMPLES = THIN_BURST_SAMPLES * NUM_BURSTS

# Fixed seed for TensorFlow's random number generator.
tf.random.set_seed(2)

# Chain state: [theta, xi, events].
initial_theta = np.array([0.85, 0.3, 0.25], dtype=DTYPE)
initial_xi = np.zeros(num_xi, dtype=DTYPE)
current_state = [initial_theta, initial_xi, events]
318
319

# Output Files
Chris Jewell's avatar
Chris Jewell committed
320
# HDF5 file receiving the posterior; environment variables in the configured
# path are expanded.
posterior_path = os.path.expandvars(config["output"]["posterior"])
posterior = h5py.File(
    posterior_path,
    "w",
    rdcc_nbytes=1024 ** 2 * 400,
    rdcc_nslots=100000,
    libver="latest",
)
event_size = [NUM_SAVED_SAMPLES] + list(current_state[2].shape)

posterior.create_dataset("initial_state", data=initial_state)
inference_period_ds = posterior.create_dataset(
    "inference_period", data=settings["inference_period"].astype("S")
)
inference_period_ds.attrs["description"] = "inference period [start, end)"

# One row per saved sample in each of the sample datasets.
theta_samples = posterior.create_dataset(
    "samples/theta", [NUM_SAVED_SAMPLES, current_state[0].shape[0]], dtype=np.float64,
)
xi_samples = posterior.create_dataset(
    "samples/xi", [NUM_SAVED_SAMPLES, current_state[1].shape[0]], dtype=np.float64,
)
event_samples = posterior.create_dataset(
    "samples/events",
    event_size,
    dtype=DTYPE,
    chunks=(32, 64, 64, 1),
    compression="szip",
    compression_opts=("nn", 16),
)

# Kernel-result datasets, listed as (path, row width) in the order in which
# the sampling loop indexes the flattened results.
result_layout = (
    ("results/theta", 3),
    ("results/xi", 3),
    ("results/move/S->E", 3 + num_metapop),
    ("results/move/E->I", 3 + num_metapop),
    ("results/occult/S->E", 6),
    ("results/occult/E->I", 6),
)
output_results = [
    posterior.create_dataset(path, (NUM_SAVED_SAMPLES, width), dtype=DTYPE)
    for path, width in result_layout
]
# All datasets exist now, so single-writer/multiple-reader mode is enabled.
posterior.swmr_mode = True
365

Chris Jewell's avatar
Chris Jewell committed
366
# Log density of the starting state.
print("Initial logpi:", logp(*current_state))

# Proposal scale for theta: a fixed 3x3 matrix multiplied by 0.2 and divided
# by its dimension.
theta_base = tf.constant(
    [
        [1.12e-3, 1.67e-4, 1.61e-4],
        [1.67e-4, 7.41e-4, 4.68e-5],
        [1.61e-4, 4.68e-5, 1.28e-4],
    ],
    dtype=DTYPE,
)
theta_scale = theta_base * 0.2 / theta_base.shape[0]

# Proposal scale for xi: the identity multiplied by 0.0001 and divided by
# its dimension.
xi_identity = tf.eye(current_state[1].shape[0], dtype=DTYPE)
xi_scale = xi_identity * 0.0001 / xi_identity.shape[0]
Chris Jewell's avatar
Chris Jewell committed
380
381
382

# We loop over successive calls to sample because results have to be written
#   to disk between bursts; otherwise we run out of memory (even on a 32GB
#   system).
383
# with tf.profiler.experimental.Profile("/tmp/tf_logdir"):
384
# Positions within one burst of the samples that are kept after thinning.
# Exactly THIN_BURST_SAMPLES positions are generated so that the gathered
# rows always match the destination slice, even when the thinning interval
# does not divide NUM_BURST_SAMPLES.  The explicit integer dtype keeps the
# index tensor valid when it is empty.
thin_interval = config["mcmc"]["thin"]
thin_idx = tf.constant(
    list(range(0, THIN_BURST_SAMPLES * thin_interval, thin_interval)),
    dtype=tf.int32,
)

# Labels for the six result streams, in the order of ``output_results``.
acceptance_labels = (
    "theta",
    "xi",
    "move S->E",
    "move E->I",
    "occult S->E",
    "occult E->I",
)

try:
    for i in tqdm.tqdm(range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES):
        samples, results = sample(
            NUM_BURST_SAMPLES,
            init_state=current_state,
            sigma_theta=theta_scale,
            sigma_xi=xi_scale,
        )
        # The next burst starts from the last state of this one.
        current_state = [part[-1] for part in samples]

        # Rows of the output datasets that belong to this burst.
        rows = slice(i * THIN_BURST_SAMPLES, (i + 1) * THIN_BURST_SAMPLES)
        theta_samples[rows, ...] = tf.gather(samples[0], thin_idx)
        xi_samples[rows, ...] = tf.gather(samples[1], thin_idx)
        # cov = np.cov(
        #     np.log(theta_samples[: (i * NUM_BURST_SAMPLES + NUM_BURST_SAMPLES), ...]),
        #     rowvar=False,
        # )
        print(current_state[0].numpy(), flush=True)
        # print(cov, flush=True)
        # if (i * NUM_BURST_SAMPLES) > 1000 and np.all(np.isfinite(cov)):
        #     theta_scale = 2.38 ** 2 * cov / 2.0

        # Only the write of the event samples is timed.
        start = perf_counter()
        event_samples[rows, ...] = tf.gather(samples[2], thin_idx)
        end = perf_counter()

        flat_results = flatten_results(results)
        # A separate counter is used here so the burst index ``i`` is not
        # overwritten.
        for j, ro in enumerate(output_results):
            ro[rows, ...] = tf.gather(flat_results[j], thin_idx)

        posterior.flush()
        print("Storage time:", end - start, "seconds")
        # Column 1 of each result stream is the acceptance indicator.
        for j, label in enumerate(acceptance_labels):
            print(
                f"Acceptance {label}:",
                tf.reduce_mean(tf.cast(flat_results[j][:, 1], tf.float32)),
            )

    # Acceptance rates over every saved sample.
    for label, ro in zip(acceptance_labels, output_results):
        print(f"Acceptance {label}: {ro[:, 1].mean()}")
finally:
    # Close the HDF5 file even if sampling or writing fails part-way.
    posterior.close()