mcmc.py 11.4 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
"""MCMC Test Rig for COVID-19 UK model"""
2
import optparse
Chris Jewell's avatar
Chris Jewell committed
3
import os
4
import pickle as pkl
5
from collections import OrderedDict
6
from time import perf_counter
Chris Jewell's avatar
Chris Jewell committed
7

Chris Jewell's avatar
Chris Jewell committed
8
import h5py
9
import numpy as np
10
11
import tensorflow as tf
import tensorflow_probability as tfp
Chris Jewell's avatar
Chris Jewell committed
12
13
14
import tqdm
import yaml

15
16
from covid import config
from covid.model import load_data, CovidUKStochastic
17
18
from covid.pydata import phe_case_data
from covid.util import sanitise_parameter, sanitise_settings, impute_previous_cases
19
from covid.impl.mcmc import UncalibratedLogRandomWalk, random_walk_mvnorm_fn
20
from covid.impl.event_time_mh import UncalibratedEventTimesUpdate
21
from covid.impl.occult_events_mh import UncalibratedOccultUpdate
22
from covid.impl.gibbs import GibbsKernel, GibbsStep
23

Chris Jewell's avatar
Chris Jewell committed
24
25
26
###########
# TF Bits #
###########
27

Chris Jewell's avatar
Chris Jewell committed
28
29
30
# Short aliases for the TensorFlow Probability namespaces used below.
tfd = tfp.distributions
tfb = tfp.bijectors

# dtype used for every tensor built in this script, read from covid.config.
DTYPE = config.floatX

# Debugging switches, left disabled:
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# os.environ["XLA_FLAGS"] = '--xla_dump_to=xla_dump --xla_dump_hlo_pass_re=".*"'

# Report whether TensorFlow found a GPU device.
print("Using GPU" if tf.test.gpu_device_name() else "Using CPU")

41
# Read in settings: the only command-line option is the YAML config path.
parser = optparse.OptionParser()
parser.add_option(
    "--config",
    "-c",
    dest="config",
    default="example_config.yaml",
    help="configuration file",
)
options, cmd_args = parser.parse_args()
print("Loading config file:", options.config)

# NOTE: from here on `config` is the parsed YAML mapping; it shadows the
# `covid.config` module imported above (DTYPE has already been read from it).
# NOTE(review): yaml.load with FullLoader assumes the config file is trusted.
with open(options.config, "r") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

settings = sanitise_settings(config["settings"])

# Model parameters, each wrapped as a constant tensor of the working dtype.
param = {
    name: tf.constant(value, dtype=DTYPE)
    for name, value in sanitise_parameter(config["parameter"]).items()
}

covar_data = load_data(config["data"], settings, DTYPE)

cases = phe_case_data(
    config["data"]["reported_cases"], settings["inference_period"]
)

# Impute E->I events from the reported cases, then S->E events from those.
ei_events, lag_ei = impute_previous_cases(cases, 0.44)
se_events, lag_se = impute_previous_cases(ei_events, 2.0)

# Left-pad axis 1 with zeros; presumably this aligns all three event arrays
# on a common time axis -- confirm against impute_previous_cases.
ir_events = np.pad(cases, [(0, 0), (lag_ei + lag_se - 2, 0)])
ei_events = np.pad(ei_events, [(0, 0), (lag_se - 1, 0)])

model = CovidUKStochastic(
    C=covar_data["C"],
    N=covar_data["pop"],
    W=covar_data["W"],
    date_range=settings["inference_period"],
    holidays=settings["holiday"],
    xi_freq=14,
    time_step=1.0,
)
79

Chris Jewell's avatar
Chris Jewell committed
80
81
82
##########################
# Log p and MCMC kernels #
##########################
83
def logp(theta, xi, events, occult_events):
    """Return the joint log posterior of the parameters and event tensors.

    The result is the sum of:
      * Gamma(1, 1) log prior on beta1 (`theta[0]`),
      * Gamma(3, 10) log prior on beta2 (`theta[1]`),
      * Gamma(100, 400) log prior on gamma (`theta[2]`),
      * a Gaussian-process log prior on `xi`, indexed at `model.xi_times`
        with a MaternThreeHalves(0.1, 12.0) kernel,
      * `model.log_prob` of `events + occult_events` given the parameters
        and the module-level `state_init`.

    :param theta: indexable holding (beta1, beta2, gamma) in that order
    :param xi: values scored by the Gaussian-process prior
    :param events: event tensor
    :param occult_events: occult event tensor, added to `events`
    :returns: the summed log density as a tensor
    """
    # Work on a shallow copy: writing into the module-level `param` dict
    # would leak per-call (possibly symbolic) tensors into global state.
    p = dict(param)
    p["beta1"] = tf.convert_to_tensor(theta[0], dtype=DTYPE)
    p["beta2"] = tf.convert_to_tensor(theta[1], dtype=DTYPE)
    p["gamma"] = tf.convert_to_tensor(theta[2], dtype=DTYPE)
    p["xi"] = tf.convert_to_tensor(xi, dtype=DTYPE)

    beta1_logp = tfd.Gamma(
        concentration=tf.constant(1.0, dtype=DTYPE), rate=tf.constant(1.0, dtype=DTYPE)
    ).log_prob(p["beta1"])

    sigma = tf.constant(0.1, dtype=DTYPE)
    phi = tf.constant(12.0, dtype=DTYPE)
    kernel = tfp.math.psd_kernels.MaternThreeHalves(sigma, phi)
    xi_logp = tfd.GaussianProcess(
        kernel, index_points=tf.cast(model.xi_times[:, tf.newaxis], DTYPE)
    ).log_prob(p["xi"])

    spatial_beta_logp = tfd.Gamma(
        concentration=tf.constant(3.0, dtype=DTYPE), rate=tf.constant(10.0, dtype=DTYPE)
    ).log_prob(p["beta2"])

    gamma_logp = tfd.Gamma(
        concentration=tf.constant(100.0, dtype=DTYPE),
        rate=tf.constant(400.0, dtype=DTYPE),
    ).log_prob(p["gamma"])

    with tf.name_scope("epidemic_log_posterior"):
        y_logp = model.log_prob(events + occult_events, p, state_init)

    # Returned directly so no local name shadows this function.
    return beta1_logp + spatial_beta_logp + gamma_logp + xi_logp + y_logp


# Pavel's suggestion for a Gibbs kernel requires
# kernel factory functions.
117
def make_theta_kernel(scale, bounded_convergence):
    """Build the Gibbs step (index 0, named "update_theta") for theta.

    :param scale: passed to `random_walk_mvnorm_fn` as the proposal scale
    :param bounded_convergence: passed to `random_walk_mvnorm_fn` as `p_u`
    :returns: a GibbsStep wrapping a Metropolis-Hastings log random walk
              targeting `logp`
    """
    proposal_fn = random_walk_mvnorm_fn(scale, p_u=bounded_convergence)
    log_random_walk = UncalibratedLogRandomWalk(
        target_log_prob_fn=logp, new_state_fn=proposal_fn
    )
    mh_kernel = tfp.mcmc.MetropolisHastings(inner_kernel=log_random_walk)
    return GibbsStep(0, mh_kernel, name="update_theta")
128
129
130


def make_xi_kernel(scale, bounded_convergence):
    """Build the Gibbs step (index 1, named "xi_update") for xi.

    :param scale: passed to `random_walk_mvnorm_fn` as the proposal scale
    :param bounded_convergence: passed to `random_walk_mvnorm_fn` as `p_u`
    :returns: a GibbsStep wrapping a RandomWalkMetropolis kernel
              targeting `logp`
    """
    proposal_fn = random_walk_mvnorm_fn(scale, p_u=bounded_convergence)
    rwm_kernel = tfp.mcmc.RandomWalkMetropolis(
        target_log_prob_fn=logp, new_state_fn=proposal_fn
    )
    return GibbsStep(1, rwm_kernel, name="xi_update")
139
140


141
def make_events_step(target_event_id, prev_event_id=None, next_event_id=None):
    """Build the Gibbs step (index 2, named "event_update") for event times.

    The `dmax`, `m` and `nmax` entries of `config["mcmc"]` are forwarded to
    the update kernel as `dmax`, `mmax` and `nmax`; the module-level
    `state_init` is passed as its initial state.

    :param target_event_id: forwarded to UncalibratedEventTimesUpdate
    :param prev_event_id: forwarded to UncalibratedEventTimesUpdate
    :param next_event_id: forwarded to UncalibratedEventTimesUpdate
    :returns: a GibbsStep wrapping a Metropolis-Hastings event-time update
              targeting `logp`
    """
    mcmc_config = config["mcmc"]
    event_update = UncalibratedEventTimesUpdate(
        target_log_prob_fn=logp,
        target_event_id=target_event_id,
        prev_event_id=prev_event_id,
        next_event_id=next_event_id,
        initial_state=state_init,
        dmax=mcmc_config["dmax"],
        mmax=mcmc_config["m"],
        nmax=mcmc_config["nmax"],
    )
    mh_kernel = tfp.mcmc.MetropolisHastings(inner_kernel=event_update)
    return GibbsStep(2, mh_kernel, name="event_update")
158
159


160
def make_occults_step(target_event_id):
    """Build the Gibbs step (index 3, named "occult_update") for occults.

    The update's `t_range` runs from 22 before to 1 before the length of
    axis 1 of `se_events`; `nmax` comes from `config["mcmc"]["occult_nmax"]`.

    :param target_event_id: forwarded to UncalibratedOccultUpdate
    :returns: a GibbsStep wrapping a Metropolis-Hastings occult update
              targeting `logp`
    """
    num_times = se_events.shape[1]
    occult_update = UncalibratedOccultUpdate(
        target_log_prob_fn=logp,
        target_event_id=target_event_id,
        nmax=config["mcmc"]["occult_nmax"],
        t_range=(num_times - 22, num_times - 1),
    )
    mh_kernel = tfp.mcmc.MetropolisHastings(inner_kernel=occult_update)
    return GibbsStep(3, mh_kernel, name="occult_update")
173
174


175
def is_accepted(result):
    """Return the acceptance flag of a (possibly nested) kernel result.

    Descends through `inner_results` until an object exposing
    `is_accepted` is found, then casts that flag to DTYPE.
    """
    while not hasattr(result, "is_accepted"):
        result = result.inner_results
    return tf.cast(result.is_accepted, DTYPE)
179
180


181
182
def trace_results_fn(_, results):
    """Returns log_prob, accepted, q_ratio for each kernel result.

    One 1-D tensor is produced per entry of `results`. If the proposed
    results carry an `extra` attribute, it is cast to the log-prob dtype
    and appended after the three leading values.
    """

    def summarise(kernel_result):
        proposed_results = kernel_result.proposed_results
        log_prob = proposed_results.target_log_prob
        pieces = [
            [log_prob],
            [is_accepted(kernel_result)],
            [proposed_results.log_acceptance_correction],
        ]
        if hasattr(proposed_results, "extra"):
            pieces.append(tf.cast(proposed_results.extra, log_prob.dtype))
        return tf.concat(pieces, axis=0)

    return [summarise(kernel_result) for kernel_result in results]
194

195

196
@tf.function(autograph=False, experimental_compile=True)
def sample(n_samples, init_state, scale_theta, scale_xi, num_event_updates):
    """Run the Gibbs sampler for `n_samples` iterations from `init_state`.

    The kernel applies, in order: the theta update, the xi update, two
    event-time updates (event ids 0 and 1) and two occult updates (event
    ids 0 and 1).

    :param n_samples: number of iterations passed to `sample_chain`
    :param init_state: list of state parts; copied before use
    :param scale_theta: proposal scale for the theta kernel
    :param scale_xi: proposal scale for the xi kernel
    :param num_event_updates: currently unused by the body
    :returns: `(samples, results)` from `tfp.mcmc.sample_chain`, with
              results traced through `trace_results_fn`
    """
    with tf.name_scope("main_mcmc_sample_loop"):

        init_state = init_state.copy()

        # Use the scales passed in by the caller. Reading module-level
        # globals here would bake their trace-time values into the compiled
        # function and silently ignore any scale the caller adapts later.
        kernel = GibbsKernel(
            [
                make_theta_kernel(scale_theta, 0.0),
                make_xi_kernel(scale_xi, 0.0),
                make_events_step(0, None, 1),
                make_events_step(1, 0, 2),
                make_occults_step(0),
                make_occults_step(1),
            ],
            name="gibbs_kernel",
        )

        samples, results = tfp.mcmc.sample_chain(
            n_samples, init_state, kernel=kernel, trace_fn=trace_results_fn
        )

        return samples, results
219
220


Chris Jewell's avatar
Chris Jewell committed
221
222
223
224
##################
# MCMC loop here #
##################

225
# MCMC Control: burst structure and thinning, all read from the config file.
NUM_BURSTS = config["mcmc"]["num_bursts"]
NUM_BURST_SAMPLES = config["mcmc"]["num_burst_samples"]
NUM_EVENT_TIME_UPDATES = config["mcmc"]["num_event_time_updates"]
# Samples kept per burst after thinning, and in total across all bursts.
THIN_BURST_SAMPLES = NUM_BURST_SAMPLES // config["mcmc"]["thin"]
NUM_SAVED_SAMPLES = NUM_BURSTS * THIN_BURST_SAMPLES

# RNG stuff
tf.random.set_seed(2)

# Initial state.  NB [M, T, X] layout for events.
events = tf.stack([se_events, ei_events, ir_events], axis=-1)
# The first time slice of the events, with the population sizes prepended
# as an extra leading column, forms the initial state; it is then dropped
# from the event tensor itself.
state_init = tf.concat([model.N[:, tf.newaxis], events[:, 0]], axis=-1)
events = events[:, 1:]
# State parts in order: theta, xi, events, occult events (initially zero).
current_state = [
    np.array([0.85, 0.3, 0.25], dtype=DTYPE),
    np.zeros(model.num_xi, dtype=DTYPE),
    events,
    tf.zeros_like(events),
]
245
246

# Output Files: one HDF5 file holding the thinned samples and kernel traces.
posterior = h5py.File(
    os.path.expandvars(config["output"]["posterior"]),
    "w",
    rdcc_nbytes=1024 ** 2 * 400,
    rdcc_nslots=100000,
)
# Saved-sample axis followed by the shape of the event tensor.
event_size = [NUM_SAVED_SAMPLES, *current_state[2].shape]

theta_samples = posterior.create_dataset(
    "samples/theta",
    [NUM_SAVED_SAMPLES, current_state[0].shape[0]],
    dtype=np.float64,
)
xi_samples = posterior.create_dataset(
    "samples/xi",
    [NUM_SAVED_SAMPLES, current_state[1].shape[0]],
    dtype=np.float64,
)
# Event and occult samples share shape, chunking and compression settings.
event_samples, occult_samples = (
    posterior.create_dataset(
        dataset_name,
        event_size,
        dtype=DTYPE,
        chunks=(32, 64, 64, 1),
        compression="szip",
        compression_opts=("nn", 16),
    )
    for dataset_name in ("samples/events", "samples/occults")
)

# One trace dataset per Gibbs step, in the order the steps are applied.
output_results = [
    posterior.create_dataset(
        dataset_name, (NUM_SAVED_SAMPLES, num_columns), dtype=DTYPE
    )
    for dataset_name, num_columns in (
        ("results/theta", 3),
        ("results/xi", 3),
        ("results/move/S->E", 3 + model.N.shape[0]),
        ("results/move/E->I", 3 + model.N.shape[0]),
        ("results/occult/S->E", 6),
        ("results/occult/E->I", 6),
    )
]
294

Chris Jewell's avatar
Chris Jewell committed
295
# Evaluate the log posterior once at the starting state as a sanity check.
print("Initial logpi:", logp(*current_state))

# Initial random-walk proposal scales: diagonal matrices for theta and xi.
theta_scale = tf.linalg.diag(
    tf.constant([0.1, 0.8, 0.1], dtype=current_state[0].dtype)
)
xi_scale = tf.linalg.diag(
    tf.constant(
        [0.1 for _ in range(model.num_xi.numpy())], dtype=current_state[1].dtype
    )
)
Chris Jewell's avatar
Chris Jewell committed
303
304
305

# We loop over successive calls to sample because we have to dump results
#   to disc, or else end OOM (even on a 32GB system).
# with tf.profiler.experimental.Profile("/tmp/tf_logdir"):

# Within-burst indices of the samples to keep. Loop-invariant, and built to
# hold exactly THIN_BURST_SAMPLES entries so that it always matches the
# output slice, even when the burst length is not a multiple of `thin`.
idx = tf.constant(
    range(
        0,
        THIN_BURST_SAMPLES * config["mcmc"]["thin"],
        config["mcmc"]["thin"],
    ),
    dtype=tf.int32,
)

# Per-step labels, in the same order as the Gibbs steps / `output_results`.
_ACCEPTANCE_LABELS = (
    "theta",
    "xi",
    "move S->E",
    "move E->I",
    "occult S->E",
    "occult E->I",
)

# try/finally guarantees the HDF5 file is closed even if a burst fails.
try:
    for i in tqdm.tqdm(range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES):
        samples, results = sample(
            NUM_BURST_SAMPLES,
            init_state=current_state,
            scale_theta=theta_scale,
            scale_xi=xi_scale,
            num_event_updates=tf.constant(NUM_EVENT_TIME_UPDATES, tf.int32),
        )
        # The last draw of each state part seeds the next burst.
        current_state = [part[-1] for part in samples]

        # Rows of the output datasets that this burst fills.
        num_saved = (i + 1) * THIN_BURST_SAMPLES
        rows = slice(i * THIN_BURST_SAMPLES, num_saved)

        theta_samples[rows, ...] = tf.gather(samples[0], idx)
        xi_samples[rows, ...] = tf.gather(samples[1], idx)

        # Covariance of log(theta) over the rows written so far. The dataset
        # is thinned, so the window is counted in saved rows; counting raw
        # iterations would include unwritten zero rows and give log(0).
        cov = np.cov(np.log(theta_samples[:num_saved, ...]), rowvar=False)
        print(current_state[0].numpy(), flush=True)
        print(cov, flush=True)
        if (i * NUM_BURST_SAMPLES) > 1000 and np.all(np.isfinite(cov)):
            # NOTE(review): the usual 2.38^2 / d rule would divide by the
            # number of parameters; the original divisor of 2.0 is kept.
            # Cast to the theta dtype so it matches the initial scale.
            theta_scale = tf.constant(
                2.38 ** 2 * cov / 2.0, dtype=current_state[0].dtype
            )

        start = perf_counter()
        event_samples[rows, ...] = tf.gather(samples[2], idx)
        occult_samples[rows, ...] = tf.gather(samples[3], idx)
        end = perf_counter()

        # Separate names here so the burst index `i` is not shadowed.
        for dataset, result in zip(output_results, results):
            dataset[rows, ...] = tf.gather(result, idx)

        print("Storage time:", end - start, "seconds")
        # Column 1 of each trace is the acceptance flag.
        for label, result in zip(_ACCEPTANCE_LABELS, results):
            print(
                f"Acceptance {label}:",
                tf.reduce_mean(tf.cast(result[:, 1], tf.float32)),
            )

    # Overall acceptance rates. `output_results` has six entries (xi is at
    # index 1), so each label is paired with its own dataset.
    print(f"Acceptance param: {output_results[0][:, 1].mean()}")
    print(f"Acceptance xi: {output_results[1][:, 1].mean()}")
    print(f"Acceptance move S->E: {output_results[2][:, 1].mean()}")
    print(f"Acceptance move E->I: {output_results[3][:, 1].mean()}")
    print(f"Acceptance occult S->E: {output_results[4][:, 1].mean()}")
    print(f"Acceptance occult E->I: {output_results[5][:, 1].mean()}")
finally:
    posterior.close()