inference.py 12 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
"""MCMC Test Rig for COVID-19 UK model"""
2
import optparse
Chris Jewell's avatar
Chris Jewell committed
3
import os
4
from time import perf_counter
Chris Jewell's avatar
Chris Jewell committed
5

Chris Jewell's avatar
Chris Jewell committed
6
import h5py
7
import numpy as np
8
9
import tensorflow as tf
import tensorflow_probability as tfp
Chris Jewell's avatar
Chris Jewell committed
10
11
12
import tqdm
import yaml

13
from covid import config
Chris Jewell's avatar
Chris Jewell committed
14
from covid.model import load_data, DiscreteTimeStateTransitionModel
15
16
from covid.pydata import phe_case_data
from covid.util import sanitise_parameter, sanitise_settings, impute_previous_cases
Chris Jewell's avatar
Chris Jewell committed
17
from covid.impl.util import compute_state
18
from covid.impl.mcmc import UncalibratedLogRandomWalk, random_walk_mvnorm_fn
19
from covid.impl.event_time_mh import UncalibratedEventTimesUpdate
20
from covid.impl.occult_events_mh import UncalibratedOccultUpdate, TransitionTopology
21
22
from covid.impl.gibbs import DeterministicScanKernel, GibbsStep, flatten_results
from covid.impl.multi_scan_kernel import MultiScanKernel
23

Chris Jewell's avatar
Chris Jewell committed
24
25
from model_spec import CovidUK

###########
# TF Bits #
###########

tfd = tfp.distributions
tfb = tfp.bijectors

# Working floating-point dtype, taken from the covid.config module.
DTYPE = config.floatX

# One row per transition type, in the same order the event series are
# stacked below (S->E, E->I, I->R); one column per compartment (presumably
# S, E, I, R -- the first column is seeded with the population below).
STOICHIOMETRY = tf.constant(
    [
        [-1, 1, 0, 0],
        [0, -1, 1, 0],
        [0, 0, -1, 1],
    ],
    dtype=DTYPE,
)

# Debugging toggles: uncomment to hide GPUs, or to dump XLA HLO passes.
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# os.environ["XLA_FLAGS"] = '--xla_dump_to=xla_dump --xla_dump_hlo_pass_re=".*"'

# Report whether TensorFlow can see a GPU device.
print("Using GPU" if tf.test.gpu_device_name() else "Using CPU")

# Read in settings
parser = optparse.OptionParser()
parser.add_option(
    "--config",
    "-c",
    dest="config",
    default="example_config.yaml",
    help="configuration file",
)
options, cmd_args = parser.parse_args()
print("Loading config file:", options.config)

# NOTE(review): this rebinds the name `config`, shadowing the `covid.config`
# module imported above (DTYPE has already been read from it by this point).
# yaml.FullLoader is kept for compatibility; yaml.safe_load would suffice if
# the config file only ever holds plain YAML data.
with open(options.config, "r") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

settings = sanitise_settings(config["settings"])

# Sanitised model parameters, each converted to a TF constant of DTYPE.
param = {
    name: tf.constant(value, dtype=DTYPE)
    for name, value in sanitise_parameter(config["parameter"]).items()
}

# Load in covariate data
covar_data = load_data(config["data"], settings, DTYPE)

# We load in cases and impute missing infections first, since this sets the
# time epoch which we are analysing.
cases = phe_case_data(
    config["data"]["reported_cases"],
    date_range=settings["inference_period"],
    date_type="report",
)

# Derive ei_events from the reported cases, then se_events from ei_events.
# NOTE(review): 0.44 and 2.0 are passed straight to impute_previous_cases;
# their meaning (presumably rates or mean delays) is not visible here.
ei_events, lag_ei = impute_previous_cases(cases, 0.44)
se_events, lag_se = impute_previous_cases(ei_events, 2.0)

# Left-pad the later series with zeros along axis 1 (time) so that all three
# series have matching shapes for stacking.
ir_pad = lag_ei + lag_se - 2
ei_pad = lag_se - 1
ir_events = np.pad(cases, ((0, 0), (ir_pad, 0)))
ei_events = np.pad(ei_events, ((0, 0), (ei_pad, 0)))

# Event types along the last axis, ordered S->E, E->I, I->R.
events = tf.stack([se_events, ei_events, ir_events], axis=-1)

# Initial conditions: rebuild the state trajectory from the population (all
# in the first compartment) plus the event history, then take the state at
# the start of the inference period.
state = compute_state(
    initial_state=tf.concat(
        [covar_data["pop"][:, tf.newaxis], tf.zeros_like(events[:, 0, :])], axis=-1
    ),
    events=events,
    stoichiometry=STOICHIOMETRY,
)
# Imputed events extend before the first reported case; drop that prefix so
# the analysed time axis matches the case data.
start_time = state.shape[1] - cases.shape[1]
initial_state = state[:, start_time, :]
events = events[:, start_time:, :]

# One xi value per `xi_freq` time steps; num_xi sizes the xi chain state.
# The same xi_freq is passed to the model so the two can never disagree.
xi_freq = 14
num_xi = events.shape[1] // xi_freq
num_metapop = covar_data["pop"].shape[0]

model = CovidUK(
    covariates=covar_data,
    xi_freq=xi_freq,
    initial_state=initial_state,
    initial_step=0,
    num_steps=events.shape[1],
)

##########################
# Log p and MCMC kernels #
##########################
def logp(theta, xi, events):
    """Joint log density of the CovidUK model at the given chain state.

    Args:
        theta: indexable holding (beta1, beta2, gamma) at positions 0, 1, 2.
        xi: the xi parameter vector.
        events: event tensor, passed to the model as ``seir``.

    ``nu`` is not sampled; it is fixed at the configured parameter value.
    """
    beta1, beta2, gamma = theta[0], theta[1], theta[2]
    return model.log_prob(
        {
            "beta1": beta1,
            "beta2": beta2,
            "gamma": gamma,
            "xi": xi,
            "nu": param["nu"],
            "seir": events,
        }
    )


# Pavel's suggestion for a Gibbs kernel requires
# kernel factory functions.
def make_theta_kernel(scale, bounded_convergence, name):
    """Build the Gibbs step that updates chain component 0 (theta).

    An ``UncalibratedLogRandomWalk`` proposal, driven by
    ``random_walk_mvnorm_fn(scale, p_u=bounded_convergence)``, is wrapped in
    a Metropolis-Hastings accept/reject step.
    """
    proposal = UncalibratedLogRandomWalk(
        target_log_prob_fn=logp,
        new_state_fn=random_walk_mvnorm_fn(scale, p_u=bounded_convergence),
    )
    mh = tfp.mcmc.MetropolisHastings(inner_kernel=proposal)
    return GibbsStep(0, mh, name=name)


def make_xi_kernel(scale, bounded_convergence, name):
    """Build the Gibbs step that updates chain component 1 (xi).

    Uses TFP's random-walk Metropolis with steps drawn by
    ``random_walk_mvnorm_fn(scale, p_u=bounded_convergence)``.
    """
    step_fn = random_walk_mvnorm_fn(scale, p_u=bounded_convergence)
    rwm = tfp.mcmc.RandomWalkMetropolis(target_log_prob_fn=logp, new_state_fn=step_fn)
    return GibbsStep(1, rwm, name=name)


def make_events_step(
    target_event_id, prev_event_id=None, next_event_id=None, name=None
):
    """Build a Gibbs step that moves event times in chain component 2.

    The event type being moved is ``target_event_id``; its neighbouring event
    types are ``prev_event_id`` / ``next_event_id`` (``None`` if absent).
    Tuning values dmax, mmax (config key ``m``) and nmax are read from the
    ``mcmc`` section of the config.
    """
    mcmc_cfg = config["mcmc"]
    move = UncalibratedEventTimesUpdate(
        target_log_prob_fn=logp,
        target_event_id=target_event_id,
        prev_event_id=prev_event_id,
        next_event_id=next_event_id,
        initial_state=initial_state,
        dmax=mcmc_cfg["dmax"],
        mmax=mcmc_cfg["m"],
        nmax=mcmc_cfg["nmax"],
    )
    mh = tfp.mcmc.MetropolisHastings(inner_kernel=move)
    return GibbsStep(2, mh, name=name)


def make_occults_step(prev_event_id, target_event_id, next_event_id, name):
    """Build a Gibbs step applying ``UncalibratedOccultUpdate`` to component 2.

    The update is restricted to the final 21 time steps of ``events`` and
    limited by ``config["mcmc"]["occult_nmax"]``; the transition is described
    by ``TransitionTopology(prev_event_id, target_event_id, next_event_id)``.
    """
    num_times = events.shape[1]
    occult = UncalibratedOccultUpdate(
        target_log_prob_fn=logp,
        topology=TransitionTopology(prev_event_id, target_event_id, next_event_id),
        cumulative_event_offset=initial_state,
        nmax=config["mcmc"]["occult_nmax"],
        t_range=(num_times - 21, num_times),
        name=name,
    )
    mh = tfp.mcmc.MetropolisHastings(inner_kernel=occult)
    return GibbsStep(2, mh, name=name)


def is_accepted(result):
    """Return the acceptance flag of a (possibly wrapped) kernel result.

    Follows ``inner_results`` down the chain of wrapped results until one
    exposing ``is_accepted`` is found, and returns it cast to DTYPE.
    """
    node = result
    while not hasattr(node, "is_accepted"):
        node = node.inner_results
    return tf.cast(node.is_accepted, DTYPE)


def trace_results_fn(_, results):
    """Trace function for ``tfp.mcmc.sample_chain``.

    For each kernel result (lists are walked recursively, preserving their
    nesting) returns the vector
    ``[target_log_prob, accepted, log_acceptance_correction, *extra]``,
    where ``extra`` is appended only if the proposed results carry it.
    """

    def summarise(kernel_result):
        proposed = kernel_result.proposed_results
        log_prob = proposed.target_log_prob
        pieces = [
            [log_prob],
            [is_accepted(kernel_result)],
            [proposed.log_acceptance_correction],
        ]
        if hasattr(proposed, "extra"):
            pieces.append(tf.cast(proposed.extra, log_prob.dtype))
        return tf.concat(pieces, axis=0)

    def walk(node):
        if isinstance(node, list):
            return [walk(child) for child in node]
        return summarise(node)

    return walk(results)


@tf.function(autograph=False, experimental_compile=True)
def sample(n_samples, init_state, sigma_theta, sigma_xi):
    """Run ``n_samples`` Gibbs sweeps starting from ``init_state``.

    Each sweep applies, in order: the theta kernel (scale ``sigma_theta``),
    the xi kernel (scale ``sigma_xi``), and a block of event-time and occult
    moves wrapped in ``MultiScanKernel`` with
    ``config["mcmc"]["num_event_time_updates"]``.

    Returns:
        ``(samples, results)`` from ``tfp.mcmc.sample_chain``, with kernel
        results traced through ``trace_results_fn``.
    """
    with tf.name_scope("main_mcmc_sample_loop"):
        # Shallow-copy the state list before handing it to the sampler.
        init_state = init_state.copy()

        theta_kernel = make_theta_kernel(sigma_theta, 1.0, "theta_kernel")
        xi_kernel = make_xi_kernel(sigma_xi, 1.0, "xi_kernel")

        event_moves = DeterministicScanKernel(
            [
                make_events_step(0, None, 1, "se_events"),
                make_events_step(1, 0, 2, "ei_events"),
                make_occults_step(None, 0, 1, "se_occults"),
                make_occults_step(0, 1, 2, "ei_occults"),
            ]
        )
        event_kernel = MultiScanKernel(
            config["mcmc"]["num_event_time_updates"], event_moves
        )

        kernel = DeterministicScanKernel(
            [theta_kernel, xi_kernel, event_kernel], name="gibbs_kernel"
        )

        samples, results = tfp.mcmc.sample_chain(
            n_samples, init_state, kernel=kernel, trace_fn=trace_results_fn
        )
        return samples, results


##################
# MCMC loop here #
##################

# MCMC Control
_mcmc_cfg = config["mcmc"]
NUM_BURSTS = _mcmc_cfg["num_bursts"]
NUM_BURST_SAMPLES = _mcmc_cfg["num_burst_samples"]
NUM_EVENT_TIME_UPDATES = _mcmc_cfg["num_event_time_updates"]
# Draws kept per burst after thinning, and in total across all bursts.
THIN_BURST_SAMPLES = NUM_BURST_SAMPLES // _mcmc_cfg["thin"]
NUM_SAVED_SAMPLES = NUM_BURSTS * THIN_BURST_SAMPLES

# RNG stuff
tf.random.set_seed(2)

# Starting point of the chain: [theta, xi, events].
current_state = [
    np.array([0.85, 0.3, 0.25], dtype=DTYPE),
    np.zeros(num_xi, dtype=DTYPE),
    events,
]

# Output Files
posterior = h5py.File(
    os.path.expandvars(config["output"]["posterior"]),
    "w",
    rdcc_nbytes=1024 ** 2 * 400,
    rdcc_nslots=100000,
    libver="latest",
)
event_size = [NUM_SAVED_SAMPLES] + list(current_state[2].shape)

posterior.create_dataset("initial_state", data=initial_state)
posterior.create_dataset(
    "inference_period", data=settings["inference_period"].astype("S")
).attrs["description"] = "inference period [start, end)"

theta_samples = posterior.create_dataset(
    "samples/theta", [NUM_SAVED_SAMPLES, current_state[0].shape[0]], dtype=np.float64
)
xi_samples = posterior.create_dataset(
    "samples/xi", [NUM_SAVED_SAMPLES, current_state[1].shape[0]], dtype=np.float64
)

# h5py rejects a chunk shape larger than the dataset in any dimension, so the
# preferred (32, 64, 64, 1) chunking is clamped to the actual event array
# size; otherwise short runs (< 32 saved samples) or small problems
# (< 64 metapopulations or time steps) fail at dataset creation.
event_chunks = tuple(
    min(chunk, dim) for chunk, dim in zip((32, 64, 64, 1), event_size)
)
event_samples = posterior.create_dataset(
    "samples/events",
    event_size,
    dtype=DTYPE,
    chunks=event_chunks,
    compression="szip",
    compression_opts=("nn", 16),
)

# Per-kernel traced diagnostics, one dataset per kernel in sampler order.
# Column layout follows trace_results_fn:
# [log_prob, accepted, q_ratio, *extra].
output_results = [
    posterior.create_dataset("results/theta", (NUM_SAVED_SAMPLES, 3), dtype=DTYPE),
    posterior.create_dataset("results/xi", (NUM_SAVED_SAMPLES, 3), dtype=DTYPE),
    posterior.create_dataset(
        "results/move/S->E", (NUM_SAVED_SAMPLES, 3 + num_metapop), dtype=DTYPE
    ),
    posterior.create_dataset(
        "results/move/E->I", (NUM_SAVED_SAMPLES, 3 + num_metapop), dtype=DTYPE
    ),
    posterior.create_dataset(
        "results/occult/S->E", (NUM_SAVED_SAMPLES, 6), dtype=DTYPE
    ),
    posterior.create_dataset(
        "results/occult/E->I", (NUM_SAVED_SAMPLES, 6), dtype=DTYPE
    ),
]
# Enable single-writer/multiple-reader mode only after every dataset exists.
posterior.swmr_mode = True

print("Initial logpi:", logp(*current_state))

# Proposal scale for theta: a fixed 3x3 matrix shrunk by 0.2 / dimension.
theta_scale = tf.constant(
    [
        [1.12e-3, 1.67e-4, 1.61e-4],
        [1.67e-4, 7.41e-4, 4.68e-5],
        [1.61e-4, 4.68e-5, 1.28e-4],
    ],
    dtype=DTYPE,
)
theta_scale = theta_scale * 0.2 / theta_scale.shape[0]

# Proposal scale for xi: identity matrix scaled by 0.0001 / dimension.
num_xi_params = current_state[1].shape[0]
xi_scale = tf.eye(num_xi_params, dtype=DTYPE) * 0.0001 / num_xi_params

# We loop over successive calls to sample because we have to dump results
#   to disc, or else end OOM (even on a 32GB system).

# Positions of the retained (thinned) draws within each burst. This is loop
# invariant, so build it once. range(0, N, thin) yields ceil(N / thin)
# indices, but only THIN_BURST_SAMPLES = N // thin rows per burst are
# allocated in the output datasets; truncate so the two always agree even
# when thin does not divide NUM_BURST_SAMPLES.
idx = tf.constant(
    list(range(0, NUM_BURST_SAMPLES, config["mcmc"]["thin"]))[:THIN_BURST_SAMPLES],
    dtype=tf.int32,
)

# Labels for the per-kernel acceptance diagnostics, in the same order as
# flatten_results(results) and output_results.
acceptance_labels = (
    "theta",
    "xi",
    "move S->E",
    "move E->I",
    "occult S->E",
    "occult E->I",
)

# with tf.profiler.experimental.Profile("/tmp/tf_logdir"):
for i in tqdm.tqdm(range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES):
    samples, results = sample(
        NUM_BURST_SAMPLES,
        init_state=current_state,
        sigma_theta=theta_scale,
        sigma_xi=xi_scale,
    )
    # The next burst continues from the last draw of this one.
    current_state = [s[-1] for s in samples]

    # Rows of the output datasets reserved for this burst.
    s = slice(i * THIN_BURST_SAMPLES, (i + 1) * THIN_BURST_SAMPLES)

    theta_samples[s, ...] = tf.gather(samples[0], idx)
    xi_samples[s, ...] = tf.gather(samples[1], idx)
    # cov = np.cov(
    #     np.log(theta_samples[: (i * NUM_BURST_SAMPLES + NUM_BURST_SAMPLES), ...]),
    #     rowvar=False,
    # )
    print(current_state[0].numpy(), flush=True)
    # print(cov, flush=True)
    # if (i * NUM_BURST_SAMPLES) > 1000 and np.all(np.isfinite(cov)):
    #     theta_scale = 2.38 ** 2 * cov / 2.0

    start = perf_counter()
    event_samples[s, ...] = tf.gather(samples[2], idx)
    end = perf_counter()

    flat_results = flatten_results(results)
    # Distinct index name so the burst counter `i` is not shadowed.
    for j, result_ds in enumerate(output_results):
        result_ds[s, ...] = tf.gather(flat_results[j], idx)

    posterior.flush()
    print("Storage time:", end - start, "seconds")

    # Column 1 of each traced result is the acceptance indicator
    # (see trace_results_fn).
    for label, kernel_results in zip(acceptance_labels, flat_results):
        print(
            f"Acceptance {label}:",
            tf.reduce_mean(tf.cast(kernel_results[:, 1], tf.float32)),
        )

# Whole-run acceptance rates, read back from column 1 of the stored results.
for label, result_ds in zip(
    ("theta", "xi", "move S->E", "move E->I", "occult S->E", "occult E->I"),
    output_results,
):
    print(f"Acceptance {label}: {result_ds[:, 1].mean()}")

posterior.close()