mcmc.py 10.1 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
"""MCMC Test Rig for COVID-19 UK model"""
2
import optparse
Chris Jewell's avatar
Chris Jewell committed
3
import os
4
import pickle as pkl
Chris Jewell's avatar
Chris Jewell committed
5

Chris Jewell's avatar
Chris Jewell committed
6
import h5py
7
import numpy as np
8
9
import tensorflow as tf
import tensorflow_probability as tfp
Chris Jewell's avatar
Chris Jewell committed
10
11
12
import tqdm
import yaml

13
14
from covid import config
from covid.model import load_data, CovidUKStochastic
Chris Jewell's avatar
Chris Jewell committed
15
from covid.util import sanitise_parameter, sanitise_settings
16
from covid.impl.util import make_transition_matrix
17
from covid.impl.mcmc import UncalibratedLogRandomWalk, random_walk_mvnorm_fn
18
from covid.impl.event_time_mh import EventTimesUpdate
19

20

Chris Jewell's avatar
Chris Jewell committed
21
22
23
###########
# TF Bits #
###########

# Short aliases for the TensorFlow Probability sub-packages used below.
tfd = tfp.distributions
tfb = tfp.bijectors

# Floating-point dtype shared by the tensors built in this script.
DTYPE = config.floatX

# Debugging switches, left disabled:
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# os.environ["XLA_FLAGS"] = '--xla_dump_to=xla_dump --xla_dump_hlo_pass_re=".*"'

# Report the device class: a falsy GPU device name selects the CPU message.
print("Using GPU" if tf.test.gpu_device_name() else "Using CPU")

38
39
40
41
42
43
44
45
46
47
48
49
50
# Read in settings
parser = optparse.OptionParser()
parser.add_option(
    "--config",
    "-c",
    dest="config",
    default="ode_config.yaml",
    help="configuration file",
)
options, args = parser.parse_args()
print("Loading config file:", options.config)

# NOTE: from here on ``config`` is the parsed YAML mapping and no longer the
# ``covid.config`` module imported at the top of the file.
# ``yaml.safe_load`` is used because a bare ``yaml.load(f)`` (no Loader) is
# rejected by PyYAML >= 6 and is unsafe on old releases.
with open(options.config, "r") as f:
    config = yaml.safe_load(f)

print("Config:", config)
    
Chris Jewell's avatar
Chris Jewell committed
55
# Model parameters: sanitised, then converted to constant tensors of DTYPE.
param = {
    name: tf.constant(value, dtype=DTYPE)
    for name, value in sanitise_parameter(config["parameter"]).items()
}

settings = sanitise_settings(config["settings"])

data = load_data(config["data"], settings, DTYPE)
# Sum the population table over its first index level.  The former
# ``.sum(level=0)`` spelling was removed in pandas 2.0; ``sort=False`` keeps
# the groups in order of first appearance, as ``sum(level=...)`` did.
# NOTE(review): assumes every column of data["pop"] is numeric -- confirm.
data["pop"] = data["pop"].groupby(level=0, sort=False).sum()

# Stochastic epidemic model over the inference period, stepped with
# time_step=1.0.
model = CovidUKStochastic(
    C=data["C"],
    N=data["pop"]["n"].to_numpy(),
    W=data["W"],
    date_range=settings["inference_period"],
    holidays=settings["holiday"],
    lockdown=settings["lockdown"],
    time_step=1.0,
)
72
73


74
# Load data
# NOTE(review): unpickling can execute arbitrary code -- only point this at a
# simulation file from a trusted source.
with open("stochastic_sim_covid.pkl", "rb") as f:
    example_sim = pkl.load(f)

event_tensor = example_sim["events"]  # shape [T, M, S, S]
num_times, num_meta = event_tensor.shape[0], event_tensor.shape[1]
state_init = example_sim["state_init"]

# One [T, M] slice per transition; the last two indices are presumably
# (source state, destination state): 0->1, 1->2 and 2->3.
se_events, ei_events, ir_events = (
    event_tensor[:, :, src, src + 1] for src in range(3)
)
Chris Jewell's avatar
Chris Jewell committed
85

Chris Jewell's avatar
Chris Jewell committed
86

Chris Jewell's avatar
Chris Jewell committed
87
88
89
##########################
# Log p and MCMC kernels #
##########################
90
91


92
def logp(par, events):
    """Return the joint log density of the parameters and the events.

    The result is the sum of a Gamma(1, 1) log-prior on ``beta1``, a
    Gamma(100, 400) log-prior on ``gamma`` and ``model.log_prob`` evaluated
    on ``events`` with the updated parameter dictionary.

    :param par: indexable; ``par[0]`` is used as ``beta1`` and ``par[1]`` as
        ``gamma``
    :param events: event tensor handed to ``model.log_prob``
    :returns: ``beta1_logp + gamma_logp + y_logp``
    """
    # Work on a shallow copy: aliasing the module-level ``param`` dict would
    # overwrite its entries on every call and could leave tensors created
    # while tracing a ``tf.function`` in global state.
    p = dict(param)
    p["beta1"] = tf.convert_to_tensor(par[0], dtype=DTYPE)
    # p['beta2'] = tf.convert_to_tensor(par[1], dtype=DTYPE)
    # p['beta3'] = tf.convert_to_tensor(par[2], dtype=DTYPE)
    p["gamma"] = tf.convert_to_tensor(par[1], dtype=DTYPE)

    beta1_logp = tfd.Gamma(
        concentration=tf.constant(1.0, dtype=DTYPE), rate=tf.constant(1.0, dtype=DTYPE)
    ).log_prob(p["beta1"])
    # beta2_logp = tfd.Gamma(concentration=tf.constant(1., dtype=DTYPE),
    #                       rate=tf.constant(1., dtype=DTYPE)).log_prob(p['beta2'])
    # beta3_logp = tfd.Gamma(concentration=tf.constant(2., dtype=DTYPE),
    #                       rate=tf.constant(2., dtype=DTYPE)).log_prob(p['beta3'])
    gamma_logp = tfd.Gamma(
        concentration=tf.constant(100.0, dtype=DTYPE),
        rate=tf.constant(400.0, dtype=DTYPE),
    ).log_prob(p["gamma"])

    with tf.name_scope("epidemic_log_posterior"):
        y_logp = model.log_prob(events, p, state_init)

    return beta1_logp + gamma_logp + y_logp


# Pavel's suggestion for a Gibbs kernel requires
# kernel factory functions.
def make_parameter_kernel(scale, bounded_convergence):
    """Return a factory that turns a log-density function into a parameter kernel.

    The kernel built by the factory is a ``tfp.mcmc.MetropolisHastings``
    wrapper named ``"parameter_update"`` around an
    ``UncalibratedLogRandomWalk`` whose proposals come from
    ``random_walk_mvnorm_fn(scale, p_u=bounded_convergence)``.

    :param scale: first argument of ``random_walk_mvnorm_fn``
    :param bounded_convergence: passed to ``random_walk_mvnorm_fn`` as ``p_u``
    :returns: ``kernel_func(logp)`` producing the kernel
    """

    def kernel_func(logp):
        proposal_fn = random_walk_mvnorm_fn(scale, p_u=bounded_convergence)
        random_walk = UncalibratedLogRandomWalk(
            target_log_prob_fn=logp, new_state_fn=proposal_fn
        )
        return tfp.mcmc.MetropolisHastings(
            inner_kernel=random_walk, name="parameter_update"
        )

    return kernel_func


130
def make_events_step(target_event_id, prev_event_id=None, next_event_id=None):
    """Return a factory that turns a log-density function into an event kernel.

    The factory builds an ``EventTimesUpdate`` for the given event ids.  Its
    ``dmax``, ``mmax`` and ``nmax`` settings are read from the ``"mcmc"``
    section of the loaded configuration (keys ``dmax``, ``m``, ``nmax``) each
    time the factory is called, and ``state_init`` is used as initial state.

    :param target_event_id: forwarded to ``EventTimesUpdate``
    :param prev_event_id: forwarded to ``EventTimesUpdate`` (default ``None``)
    :param next_event_id: forwarded to ``EventTimesUpdate`` (default ``None``)
    :returns: ``kernel_func(logp)`` producing the kernel
    """

    def kernel_func(logp):
        mcmc_cfg = config["mcmc"]
        kernel_kwargs = dict(
            target_log_prob_fn=logp,
            target_event_id=target_event_id,
            prev_event_id=prev_event_id,
            next_event_id=next_event_id,
            dmax=mcmc_cfg["dmax"],
            mmax=mcmc_cfg["m"],
            nmax=mcmc_cfg["nmax"],
            initial_state=state_init,
        )
        return EventTimesUpdate(**kernel_kwargs)

    return kernel_func


def is_accepted(result):
    """Return the acceptance flag of a kernel-results object, cast to DTYPE.

    Follows the ``inner_results`` chain until an object exposing
    ``is_accepted`` is found.

    :param result: kernel results, possibly wrapping further results in
        ``inner_results``
    :returns: ``is_accepted`` of the first level that has it, as DTYPE
    """
    level = result
    while not hasattr(level, "is_accepted"):
        level = level.inner_results
    return tf.cast(level.is_accepted, DTYPE)


153
154
155
def trace_results_fn(results):
    """Flatten one kernel's results into a single rank-1 trace tensor.

    The trace holds, in order, the proposed target log-probability, the
    acceptance flag and the log acceptance correction; when the proposed
    results carry an ``extra`` field it is appended after those three.

    :param results: kernel results exposing ``proposed_results``
    :returns: the concatenated trace tensor
    """
    proposal = results.proposed_results
    pieces = [
        [proposal.target_log_prob],
        [is_accepted(results)],
        [proposal.log_acceptance_correction],
    ]
    if hasattr(proposal, "extra"):
        pieces.append(proposal.extra)
    return tf.concat(pieces, axis=0)

163

164
165
166
167
168
def forward_results(prev_results, next_results):
    """Carry the accepted target log-probability from one kernel to the next.

    :param prev_results: results whose ``accepted_results.target_log_prob``
        is read
    :param next_results: results to copy; neither argument is modified
    :returns: a copy of ``next_results`` whose
        ``accepted_results.target_log_prob`` is the value from
        ``prev_results``
    """
    carried_log_prob = prev_results.accepted_results.target_log_prob
    patched_accepted = next_results.accepted_results._replace(
        target_log_prob=carried_log_prob
    )
    return next_results._replace(accepted_results=patched_accepted)
169

170

171
@tf.function(autograph=False, experimental_compile=True)
def sample(n_samples, init_state, par_scale):
    """Run ``n_samples`` sweeps of the three-kernel Gibbs-style sampler.

    Each sweep applies, in order, the parameter kernel, the event kernel
    built by ``make_events_step(0, None, 1)`` and the event kernel built by
    ``make_events_step(1, 0, 2)``.  Before each kernel runs, the accepted
    log-probability of the kernel that ran before it is forwarded into its
    results via ``forward_results``.

    :param n_samples: number of sweeps to run
    :param init_state: list ``[parameters, events]`` to start from
    :param par_scale: scale handed to ``make_parameter_kernel``
    :returns: ``(samples, results)`` -- a list with one stacked tensor per
        state component and a list with one stacked trace (see
        ``trace_results_fn``) per kernel
    """
    with tf.name_scope("main_mcmc_sample_loop"):
        init_state = init_state.copy()
        par_func = make_parameter_kernel(par_scale, 0.95)
        se_func = make_events_step(0, None, 1)
        ei_func = make_events_step(1, 0, 2)

        # Based on Gibbs idea posted by Pavel Sountsov
        # https://github.com/tensorflow/probability/issues/495
        def initial_par_logp(par_state):
            return logp(par_state, init_state[1])

        def initial_event_logp(event_state):
            return logp(init_state[0], event_state)

        # Bootstrap in the same order the kernels are applied in a sweep.
        kernel_results = [
            par_func(initial_par_logp).bootstrap_results(init_state[0]),
            se_func(initial_event_logp).bootstrap_results(init_state[1]),
            ei_func(initial_event_logp).bootstrap_results(init_state[1]),
        ]

        samples_arr = [
            tf.TensorArray(component.dtype, size=n_samples)
            for component in init_state
        ]
        results_arr = [
            tf.TensorArray(DTYPE, size=n_samples) for _ in kernel_results
        ]

        def sweep(step, state, results, sample_accum, results_accum):
            # Parameters: conditional on the current events.
            def par_logp(par_state):
                return logp(par_state, state[1])

            par_kernel = par_func(par_logp)
            state[0], results[0] = par_kernel.one_step(
                state[0], forward_results(results[2], results[0])
            )

            # States: conditional on the freshly updated parameters.
            def state_logp(event_state):
                return logp(state[0], event_state)

            se_kernel = se_func(state_logp)
            state[1], results[1] = se_kernel.one_step(
                state[1], forward_results(results[0], results[1])
            )
            ei_kernel = ei_func(state_logp)
            state[1], results[2] = ei_kernel.one_step(
                state[1], forward_results(results[1], results[2])
            )

            # Record this sweep's state and per-kernel traces at index `step`.
            sample_accum = [
                accum.write(step, component)
                for accum, component in zip(sample_accum, state)
            ]
            results_accum = [
                accum.write(step, trace_results_fn(kernel_result))
                for accum, kernel_result in zip(results_accum, results)
            ]
            return step + 1, state, results, sample_accum, results_accum

        def not_finished(step, *_unused):
            return step < n_samples

        loop_out = tf.while_loop(
            cond=not_finished,
            body=sweep,
            loop_vars=[0, init_state, kernel_results, samples_arr, results_arr],
        )
        final_samples, final_results = loop_out[3], loop_out[4]

        stacked_samples = [accum.stack() for accum in final_samples]
        stacked_results = [accum.stack() for accum in final_results]
        return stacked_samples, stacked_results
233
234


Chris Jewell's avatar
Chris Jewell committed
235
236
237
238
##################
# MCMC loop here #
##################

# MCMC Control: number of calls to sample() and sweeps per call.
_mcmc_settings = config["mcmc"]
NUM_BURSTS = _mcmc_settings["num_bursts"]
NUM_BURST_SAMPLES = _mcmc_settings["num_burst_samples"]

# RNG stuff
tf.random.set_seed(2)

# Initial state.  NB [M, T, X] layout for events.
_initial_par = np.array([0.6, 0.25], dtype=DTYPE)
# Stack the three event series on a new last axis, then swap the first two
# axes to obtain the [M, T, X] layout.
_stacked_events = tf.stack([se_events, ei_events, ir_events], axis=-1)
_initial_events = tf.transpose(_stacked_events, perm=(1, 0, 2))
current_state = [_initial_par, _initial_events]

252
253

# Output Files
# The HDF5 file is opened in a ``with`` block so that it is closed even when
# sampling raises part-way through.
with h5py.File(os.path.expandvars(config["output"]["posterior"]), "w") as posterior:
    event_size = [NUM_BURSTS * NUM_BURST_SAMPLES] + list(current_state[1].shape)
    par_samples = posterior.create_dataset(
        "samples/parameter",
        [NUM_BURSTS * NUM_BURST_SAMPLES, current_state[0].shape[0]],
        dtype=np.float64,
    )
    se_samples = posterior.create_dataset("samples/events", event_size, dtype=DTYPE)
    # Trace rows are laid out by trace_results_fn: log-prob, accepted flag and
    # q-ratio, followed by any extra proposal information.
    par_results = posterior.create_dataset(
        "acceptance/parameter", (NUM_BURSTS * NUM_BURST_SAMPLES, 3), dtype=DTYPE,
    )
    se_results = posterior.create_dataset(
        "acceptance/S->E",
        (NUM_BURSTS * NUM_BURST_SAMPLES, 3 + model.N.shape[0]),
        dtype=DTYPE,
    )
    ei_results = posterior.create_dataset(
        "acceptance/E->I",
        (NUM_BURSTS * NUM_BURST_SAMPLES, 3 + model.N.shape[0]),
        dtype=DTYPE,
    )

    print("Initial logpi:", logp(*current_state))
    # Start from an isotropic proposal scale of 0.1 per parameter.
    par_scale = tf.linalg.diag(
        tf.ones(current_state[0].shape, dtype=current_state[0].dtype) * 0.1
    )

    # We loop over successive calls to sample because we have to dump results
    #   to disk, or else run out of memory (even on a 32GB system).
    # with tf.profiler.experimental.Profile("/tmp/tf_logdir"):
    for i in tqdm.tqdm(range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES):
        samples, results = sample(
            NUM_BURST_SAMPLES, init_state=current_state, par_scale=par_scale
        )
        # The next burst resumes from the last state of this one.
        current_state = [s[-1] for s in samples]
        s = slice(i * NUM_BURST_SAMPLES, i * NUM_BURST_SAMPLES + NUM_BURST_SAMPLES)
        par_samples[s, ...] = samples[0].numpy()

        # Re-tune the proposal from the covariance of all log-parameter
        # samples so far; a non-finite estimate leaves the scale unchanged.
        cov = np.cov(
            np.log(par_samples[: (i * NUM_BURST_SAMPLES + NUM_BURST_SAMPLES), ...]),
            rowvar=False,
        )
        print(current_state[0].numpy())
        print(cov)
        if np.all(np.isfinite(cov)):
            par_scale = 2.38 ** 2 * cov / 2.0

        se_samples[s, ...] = samples[1].numpy()
        par_results[s, ...] = results[0].numpy()
        se_results[s, ...] = results[1].numpy()
        ei_results[s, ...] = results[2].numpy()

        print("Acceptance0:", tf.reduce_mean(tf.cast(results[0][:, 1], tf.float32)))
        print("Acceptance1:", tf.reduce_mean(tf.cast(results[1][:, 1], tf.float32)))
        print("Acceptance2:", tf.reduce_mean(tf.cast(results[2][:, 1], tf.float32)))

    # Overall acceptance rates, read back from the datasets before the file
    # is closed on leaving the ``with`` block.
    print(f"Acceptance param: {par_results[:, 1].mean()}")
    print(f"Acceptance S->E: {se_results[:, 1].mean()}")
    print(f"Acceptance E->I: {ei_results[:, 1].mean()}")