mcmc.py 11.1 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
"""MCMC Test Rig for COVID-19 UK model"""
2
import optparse
Chris Jewell's avatar
Chris Jewell committed
3
import os
4
import pickle as pkl
Chris Jewell's avatar
Chris Jewell committed
5

Chris Jewell's avatar
Chris Jewell committed
6
import h5py
7
import numpy as np
8
9
import tensorflow as tf
import tensorflow_probability as tfp
Chris Jewell's avatar
Chris Jewell committed
10
11
12
import tqdm
import yaml

13
14
from covid import config
from covid.model import load_data, CovidUKStochastic
Chris Jewell's avatar
Chris Jewell committed
15
from covid.util import sanitise_parameter, sanitise_settings
16
from covid.impl.mcmc import UncalibratedLogRandomWalk, random_walk_mvnorm_fn
17
from covid.impl.event_time_mh import UncalibratedEventTimesUpdate
18

19

Chris Jewell's avatar
Chris Jewell committed
20
21
22
###########
# TF Bits #
###########
23

Chris Jewell's avatar
Chris Jewell committed
24
25
26
# Short aliases for the TFP distribution and bijector namespaces.
tfd, tfb = tfp.distributions, tfp.bijectors

# Floating-point dtype used throughout this script, taken from covid.config.
DTYPE = config.floatX

# Debugging toggles (uncomment as needed):
#   os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # hide GPUs
#   os.environ["XLA_FLAGS"] = '--xla_dump_to=xla_dump --xla_dump_hlo_pass_re=".*"'

# Report which device TensorFlow found.
print("Using " + ("GPU" if tf.test.gpu_device_name() else "CPU"))

37
38
39
40
41
42
43
44
# Read in settings
parser = optparse.OptionParser()
parser.add_option(
    "--config",
    "-c",
    dest="config",
    default="ode_config.yaml",
    help="configuration file",
)
options, args = parser.parse_args()
print("Loading config file:", options.config)

# NB: this rebinds the module-level name ``config`` from the ``covid.config``
# module to the parsed YAML mapping; code below reads config["mcmc"],
# config["output"], etc. from it.
with open(options.config, "r") as f:
    # yaml.load() without an explicit Loader is unsafe and raises TypeError
    # on PyYAML >= 6; safe_load is sufficient for a plain config file.
    config = yaml.safe_load(f)

print("Config:", config)

Chris Jewell's avatar
Chris Jewell committed
54
# Model parameters from the config, converted to constant tensors of DTYPE.
param = sanitise_parameter(config["parameter"])
param = {k: tf.constant(v, dtype=DTYPE) for k, v in param.items()}

settings = sanitise_settings(config["settings"])

data = load_data(config["data"], settings, DTYPE)
# Aggregate the population table over its first index level.
# ``DataFrame.sum(level=0)`` was deprecated in pandas 1.3 and removed in 2.0;
# the groupby form is the documented replacement.  sort=False keeps the
# first-seen group order that ``sum(level=...)`` produced, which presumably
# matters for aligning ``N`` with ``C`` and ``W`` below.
data["pop"] = data["pop"].groupby(level=0, sort=False).sum()

model = CovidUKStochastic(
    C=data["C"],
    N=data["pop"]["n"].to_numpy(),
    W=data["W"],
    date_range=settings["inference_period"],
    holidays=settings["holiday"],
    lockdown=settings["lockdown"],
    time_step=1.0,
)
71
72


73
# Load data: a previously simulated epidemic stored as a pickled dict.
with open("stochastic_sim_covid.pkl", "rb") as f:
    example_sim = pkl.load(f)

# Event tensor, shape [T, M, S, S].
event_tensor = example_sim["events"]
num_times, num_meta = event_tensor.shape[0], event_tensor.shape[1]
state_init = example_sim["state_init"]

# Per-transition [T, M] slices: state k -> state k + 1 for k = 0, 1, 2
# (S->E, E->I, I->R going by the variable names).
se_events, ei_events, ir_events = (
    event_tensor[:, :, src, src + 1] for src in range(3)
)
Chris Jewell's avatar
Chris Jewell committed
84

Chris Jewell's avatar
Chris Jewell committed
85

Chris Jewell's avatar
Chris Jewell committed
86
87
88
##########################
# Log p and MCMC kernels #
##########################
89
90


91
def logp(par, events, occult_events):
    """Joint log posterior of the parameters and the (observed + occult) events.

    Args:
        par: parameter vector; ``par[0]`` is ``beta1`` and ``par[1]`` is
            ``gamma``.  Other model parameters keep their values from the
            module-level ``param`` dict.
        events: event tensor passed (after adding ``occult_events``) to
            ``model.log_prob``.
        occult_events: tensor added elementwise to ``events``.

    Returns:
        Gamma(1, 1) log prior of ``beta1`` + Gamma(100, 400) log prior of
        ``gamma`` + ``model.log_prob`` of ``events + occult_events``.
    """
    # Copy rather than alias the module-level ``param`` dict.  Writing into it
    # directly would store per-call tensors (symbolic ones when traced inside
    # tf.function) in global state shared across every call.
    p = dict(param)
    p["beta1"] = tf.convert_to_tensor(par[0], dtype=DTYPE)
    p["gamma"] = tf.convert_to_tensor(par[1], dtype=DTYPE)
    # beta2 / beta3 are not sampled here; they keep their configured values.

    beta1_logp = tfd.Gamma(
        concentration=tf.constant(1.0, dtype=DTYPE), rate=tf.constant(1.0, dtype=DTYPE)
    ).log_prob(p["beta1"])
    gamma_logp = tfd.Gamma(
        concentration=tf.constant(100.0, dtype=DTYPE),
        rate=tf.constant(400.0, dtype=DTYPE),
    ).log_prob(p["gamma"])
    with tf.name_scope("epidemic_log_posterior"):
        y_logp = model.log_prob(events + occult_events, p, state_init)
    return beta1_logp + gamma_logp + y_logp


# Pavel's suggestion for a Gibbs kernel requires
# kernel factory functions.
def make_parameter_kernel(scale, bounded_convergence):
    """Return a factory that builds the parameter-block MH kernel.

    The returned callable takes a target log-density function and wraps an
    ``UncalibratedLogRandomWalk`` (proposal from ``random_walk_mvnorm_fn``
    with ``scale`` and ``p_u=bounded_convergence``) in a
    ``tfp.mcmc.MetropolisHastings`` kernel named "parameter_update".
    """

    def build(target_fn):
        proposal = UncalibratedLogRandomWalk(
            target_log_prob_fn=target_fn,
            new_state_fn=random_walk_mvnorm_fn(scale, p_u=bounded_convergence),
        )
        return tfp.mcmc.MetropolisHastings(
            inner_kernel=proposal, name="parameter_update"
        )

    return build


129
def make_events_step(target_event_id, prev_event_id=None, next_event_id=None):
    """Return a factory that builds an event-time MH kernel.

    The returned callable takes a target log-density function and wraps an
    ``UncalibratedEventTimesUpdate`` for ``target_event_id`` (with its
    neighbouring event ids) in a ``tfp.mcmc.MetropolisHastings`` kernel named
    "event_update".  Tuning values ``dmax``, ``m`` and ``nmax`` are read from
    ``config["mcmc"]`` each time a kernel is built; ``state_init`` is the
    module-level initial state.
    """

    def build(target_fn):
        mcmc_cfg = config["mcmc"]
        proposal = UncalibratedEventTimesUpdate(
            target_log_prob_fn=target_fn,
            target_event_id=target_event_id,
            prev_event_id=prev_event_id,
            next_event_id=next_event_id,
            initial_state=state_init,
            dmax=mcmc_cfg["dmax"],
            mmax=mcmc_cfg["m"],
            nmax=mcmc_cfg["nmax"],
        )
        return tfp.mcmc.MetropolisHastings(inner_kernel=proposal, name="event_update")

    return build


148
149
150
151
def make_occults_step():
    """Placeholder for an occult-event update kernel factory (not implemented).

    Currently does nothing and returns ``None``.
    """
    return None


152
def is_accepted(result):
    """Return the ``is_accepted`` flag of a possibly nested kernel result.

    Follows ``inner_results`` until an object with an ``is_accepted``
    attribute is found, then returns that flag cast to ``DTYPE``.  Raises
    AttributeError if the chain ends without such an attribute.
    """
    node = result
    while not hasattr(node, "is_accepted"):
        node = node.inner_results
    return tf.cast(node.is_accepted, DTYPE)
156
157


158
159
160
def trace_results_fn(results):
    """Flatten one MH step's results into a 1-D trace row.

    The row is ``[target_log_prob, accepted, log_acceptance_correction]``,
    taken from ``results.proposed_results``.  If the proposed results carry an
    ``extra`` field, its entries are appended after those three values.
    """
    proposed = results.proposed_results
    pieces = [
        [proposed.target_log_prob],
        [is_accepted(results)],
        [proposed.log_acceptance_correction],
    ]
    if hasattr(proposed, "extra"):
        pieces.append(proposed.extra)
    return tf.concat(pieces, axis=0)
166

167

168
169
170
171
172
def forward_results(prev_results, next_results):
    """Return ``next_results`` with its accepted log-prob taken from ``prev_results``.

    Only ``accepted_results.target_log_prob`` is replaced; every other field of
    ``next_results`` (and of its ``accepted_results``) is kept.  Neither input
    is modified; ``_replace`` builds new tuples.
    """
    carried_log_prob = prev_results.accepted_results.target_log_prob
    patched_accepted = next_results.accepted_results._replace(
        target_log_prob=carried_log_prob
    )
    return next_results._replace(accepted_results=patched_accepted)
173

174

175
@tf.function(autograph=False, experimental_compile=True)
def sample(n_samples, init_state, par_scale, num_event_updates):
    """Run ``n_samples`` iterations of the Metropolis-within-Gibbs sampler.

    Each iteration does one parameter update, then ``num_event_updates``
    rounds of S->E followed by E->I event-time updates.  All kernels target
    the joint log posterior ``logp``.  The function is traced with
    ``autograph=False`` and ``experimental_compile=True`` (XLA requested).

    Args:
        n_samples: number of Gibbs iterations to run in this call.
        init_state: list ``[parameters, events, occult_events]``.  A shallow
            copy is taken, so the caller's list is not modified.  Only
            elements 0 and 1 are updated here; occult events stay fixed.
        par_scale: scale passed to ``random_walk_mvnorm_fn`` for the
            parameter proposal.
        num_event_updates: int32 scalar tensor giving the number of
            event-update rounds per iteration.

    Returns:
        ``(samples, results)``.  ``samples`` is a list with one stacked
        tensor per element of ``init_state`` (leading dimension
        ``n_samples``).  ``results`` is a list of three stacked traces
        (parameter, S->E, E->I kernels); each row comes from
        ``trace_results_fn``.
    """
    with tf.name_scope("main_mcmc_sample_loop"):
        init_state = init_state.copy()
        # NOTE(review): 0.95 is forwarded as ``p_u`` to random_walk_mvnorm_fn;
        # its meaning is defined in covid.impl.mcmc.
        par_func = make_parameter_kernel(par_scale, 0.95)
        se_func = make_events_step(0, None, 1)
        ei_func = make_events_step(1, 0, 2)

        # Based on Gibbs idea posted by Pavel Sountsov
        # https://github.com/tensorflow/probability/issues/495
        # Bootstrap each kernel with the other state components held at their
        # initial values.  Both event kernels act on init_state[1].
        par_results = par_func(
            lambda p: logp(p, init_state[1], init_state[2])
        ).bootstrap_results(init_state[0])
        se_results = se_func(
            lambda s: logp(init_state[0], s, init_state[2])
        ).bootstrap_results(init_state[1])
        ei_results = ei_func(
            lambda s: logp(init_state[0], s, init_state[2])
        ).bootstrap_results(init_state[1])
        results = [par_results, se_results, ei_results]

        # Per-iteration accumulators: one TensorArray per state component and
        # one per kernel trace.
        samples_arr = [tf.TensorArray(s.dtype, size=n_samples) for s in init_state]
        results_arr = [tf.TensorArray(DTYPE, size=n_samples) for r in range(3)]

        def body(i, state, results, sample_accum, results_accum):
            # Parameters
            # Target for the parameter kernel: write the candidate into
            # ``state`` so logp is evaluated with the current events.
            def par_logp(par_state):
                state[0] = par_state  # close over state from outer scope
                return logp(*state)

            # Start from the log-prob last accepted by the E->I kernel, so the
            # parameter kernel sees the current joint posterior value.
            state[0], results[0] = par_func(par_logp).one_step(
                state[0], forward_results(results[2], results[0])
            )

            # States
            # Carry the parameter update's log-prob into the E->I results;
            # infec_body forwards it from there to the S->E kernel.
            results[2] = forward_results(results[0], results[2])

            def infec_body(j, state, results):
                # Target for both event kernels: write candidate events into
                # ``state`` and evaluate the joint log posterior.
                def state_logp(event_state):
                    state[1] = event_state
                    return logp(*state)

                # Each kernel starts from the log-prob accepted by the other,
                # keeping the chained target_log_prob consistent.
                state[1], results[1] = se_func(state_logp).one_step(
                    state[1], forward_results(results[2], results[1])
                )
                state[1], results[2] = ei_func(state_logp).one_step(
                    state[1], forward_results(results[1], results[2])
                )
                j += 1
                return j, state, results

            def infec_cond(j, state, results):
                return j < num_event_updates

            _, state, results = tf.while_loop(
                infec_cond,
                infec_body,
                loop_vars=[tf.constant(0, tf.int32), state, results],
            )

            # Record this iteration's state and one trace row per kernel.
            sample_accum = [sample_accum[k].write(i, s) for k, s in enumerate(state)]
            results_accum = [
                results_accum[k].write(i, trace_results_fn(r))
                for k, r in enumerate(results)
            ]
            return i + 1, state, results, sample_accum, results_accum

        def cond(i, _1, _2, _3, _4):
            return i < n_samples

        _1, _2, _3, samples, results = tf.while_loop(
            cond=cond,
            body=body,
            loop_vars=[0, init_state, results, samples_arr, results_arr],
        )

        return [s.stack() for s in samples], [r.stack() for r in results]
252
253


Chris Jewell's avatar
Chris Jewell committed
254
255
256
257
##################
# MCMC loop here #
##################

258
# MCMC control, read from the "mcmc" section of the YAML config.
NUM_BURSTS = config["mcmc"]["num_bursts"]
NUM_BURST_SAMPLES = config["mcmc"]["num_burst_samples"]
NUM_EVENT_TIME_UPDATES = config["mcmc"]["num_event_time_updates"]

# Fix the global TensorFlow RNG seed.
tf.random.set_seed(2)

# Initial state.  NB [M, T, X] layout for events: the three transition series
# are stacked on a new trailing axis ([T, M, 3]), then the first two axes are
# swapped.
events = tf.transpose(
    tf.stack([se_events, ei_events, ir_events], axis=-1),
    perm=(1, 0, 2),
)
current_state = [
    np.array([0.6, 0.25], dtype=DTYPE),  # starting (beta1, gamma)
    events,
    tf.zeros_like(events),  # occult events start at zero
]
Chris Jewell's avatar
Chris Jewell committed
271

272
273

# Output Files
# HDF5 posterior file; rdcc_nbytes sets a 2 GiB raw-data chunk cache.
posterior = h5py.File(
    os.path.expandvars(config["output"]["posterior"]), "w", rdcc_nbytes=1024 ** 3 * 2,
)
num_total_samples = NUM_BURSTS * NUM_BURST_SAMPLES
event_size = [num_total_samples] + list(current_state[1].shape)

par_samples = posterior.create_dataset(
    "samples/parameter",
    [num_total_samples, current_state[0].shape[0]],
    dtype=np.float64,
)
# Chunk along the sample axis, but never longer than the dataset itself:
# h5py rejects chunk shapes larger than a fixed-size dataset, so a hard-coded
# 1000 failed for runs with fewer than 1000 total samples.
se_samples = posterior.create_dataset(
    "samples/events",
    event_size,
    dtype=DTYPE,
    chunks=(min(1000, num_total_samples),) + tuple(event_size[1:]),
    compression="gzip",
    compression_opts=1,
)
# Trace rows from trace_results_fn: [log_prob, accepted, q_ratio] for the
# parameter kernel.  For the event kernels, 3 + model.N.shape[0] columns
# (presumably the proposal's ``extra`` output per population unit — confirm).
par_results = posterior.create_dataset(
    "acceptance/parameter", (num_total_samples, 3), dtype=DTYPE,
)
se_results = posterior.create_dataset(
    "acceptance/S->E",
    (num_total_samples, 3 + model.N.shape[0]),
    dtype=DTYPE,
)
ei_results = posterior.create_dataset(
    "acceptance/E->I",
    (num_total_samples, 3 + model.N.shape[0]),
    dtype=DTYPE,
)

307

Chris Jewell's avatar
Chris Jewell committed
308
309
310
311
312
313
314
print("Initial logpi:", logp(*current_state))
# Initial proposal scale: diagonal matrix with 0.1 per parameter.
par_scale = tf.linalg.diag(
    tf.ones(current_state[0].shape, dtype=current_state[0].dtype) * 0.1
)

# We loop over successive calls to sample because we have to dump results
#   to disc, or else end OOM (even on a 32GB system).
# with tf.profiler.experimental.Profile("/tmp/tf_logdir"):
try:
    for i in tqdm.tqdm(range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES):
        samples, results = sample(
            NUM_BURST_SAMPLES,
            init_state=current_state,
            par_scale=par_scale,
            num_event_updates=tf.constant(NUM_EVENT_TIME_UPDATES, tf.int32),
        )
        # The next burst resumes from the final state of this one.
        current_state = [s[-1] for s in samples]
        s = slice(i * NUM_BURST_SAMPLES, i * NUM_BURST_SAMPLES + NUM_BURST_SAMPLES)

        par_samples[s, ...] = samples[0].numpy()
        # Adapt the parameter proposal from the empirical covariance of
        # log(parameters) over every sample written so far.
        cov = np.cov(
            np.log(par_samples[: (i * NUM_BURST_SAMPLES + NUM_BURST_SAMPLES), ...]),
            rowvar=False,
        )
        print(current_state[0].numpy())
        print(cov)
        if np.all(np.isfinite(cov)):  # e.g. NaN with a single sample
            # NOTE(review): 2.0**2 / 2.0 scaling kept from the original; it
            # resembles the 2.38**2 / d adaptive-Metropolis rule — confirm.
            # Cast to DTYPE so `sample` always receives a scale with the same
            # dtype as the initial one (np.cov returns float64).
            par_scale = (2.0 ** 2 * cov / 2.0).astype(DTYPE)

        se_samples[s, ...] = samples[1].numpy()
        par_results[s, ...] = results[0].numpy()
        se_results[s, ...] = results[1].numpy()
        ei_results[s, ...] = results[2].numpy()

        # Column 1 of each trace row is the acceptance indicator.
        print("Acceptance0:", tf.reduce_mean(tf.cast(results[0][:, 1], tf.float32)))
        print("Acceptance1:", tf.reduce_mean(tf.cast(results[1][:, 1], tf.float32)))
        print("Acceptance2:", tf.reduce_mean(tf.cast(results[2][:, 1], tf.float32)))

    print(f"Acceptance param: {par_results[:, 1].mean()}")
    print(f"Acceptance S->E: {se_results[:, 1].mean()}")
    print(f"Acceptance E->I: {ei_results[:, 1].mean()}")
finally:
    # Always flush and close the HDF5 file, even if sampling fails part way,
    # so the bursts already written stay readable.
    posterior.close()