mcmc.py 8.9 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
2
"""MCMC Test Rig for COVID-19 UK model"""

Chris Jewell's avatar
Chris Jewell committed
3
import os
4
import pickle as pkl
Chris Jewell's avatar
Chris Jewell committed
5

Chris Jewell's avatar
Chris Jewell committed
6
import h5py
7
import numpy as np
8
9
import tensorflow as tf
import tensorflow_probability as tfp
Chris Jewell's avatar
Chris Jewell committed
10
11
12
import tqdm
import yaml

13
14
from covid import config
from covid.model import load_data, CovidUKStochastic
Chris Jewell's avatar
Chris Jewell committed
15
from covid.util import sanitise_parameter, sanitise_settings
16
from covid.impl.util import make_transition_matrix
17
from covid.impl.mcmc import UncalibratedLogRandomWalk, random_walk_mvnorm_fn
18
from covid.impl.event_time_mh import EventTimesUpdate
19

20
21
22
23
24

#############
## TF Bits ##
#############

tfd = tfp.distributions
tfb = tfp.bijectors

# Floating-point dtype shared by the tensors and output datasets in this script.
DTYPE = config.floatX

# Presumably intended to hide CUDA devices so the run stays on the CPU;
# the check below reports which device TensorFlow actually sees.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

if tf.test.gpu_device_name():
    print("Using GPU")
else:
    print("Using CPU")

# Random moves of events.  What invalidates an epidemic, how can we test for it?
# NOTE: from here on `config` is the parsed YAML mapping and shadows the
# imported `covid.config` module (which is why DTYPE is read above).
with open("ode_config.yaml", "r") as f:
    # safe_load: a bare yaml.load(f) needs an explicit Loader on PyYAML >= 6
    # and can construct arbitrary Python objects on older releases.
    config = yaml.safe_load(f)

param = sanitise_parameter(config["parameter"])
param = {k: tf.constant(v, dtype=DTYPE) for k, v in param.items()}

settings = sanitise_settings(config["settings"])

data = load_data(config["data"], settings, DTYPE)
# Aggregate over the first index level.  `.sum(level=0)` was removed in
# pandas 2.0; groupby(level=0, sort=False).sum() is the equivalent call and
# keeps the groups in their original order (assumes data["pop"] is a pandas
# object with a MultiIndex -- TODO confirm against load_data).
data["pop"] = data["pop"].groupby(level=0, sort=False).sum()

model = CovidUKStochastic(
    C=data["C"],
    N=data["pop"]["n"].to_numpy(),
    W=data["W"],
    date_range=settings["inference_period"],
    holidays=settings["holiday"],
    lockdown=settings["lockdown"],
    time_step=1.0,
)

# Load data
# NOTE(review): unpickling runs arbitrary code; only load trusted files here.
with open("stochastic_sim_covid.pkl", "rb") as f:
    example_sim = pkl.load(f)

event_tensor = example_sim["events"]  # shape [T, M, S, S]
num_times = event_tensor.shape[0]
num_meta = event_tensor.shape[1]
state_init = example_sim["state_init"]
# Slice out the 0->1, 1->2 and 2->3 transition counts from the last two axes.
se_events = event_tensor[:, :, 0, 1]
ei_events = event_tensor[:, :, 1, 2]
ir_events = event_tensor[:, :, 2, 3]
Chris Jewell's avatar
Chris Jewell committed
72

Chris Jewell's avatar
Chris Jewell committed
73

74
75
76
77
78
############################
## Log p and MCMC kernels ##
############################


79
def logp(par, events):
    """Return the log posterior density for parameters and events.

    The result is the sum of a Gamma(1, 1) log prior on ``beta1``, a
    Gamma(100, 400) log prior on ``gamma`` and the summed output of
    ``model.log_prob``.

    :param par: indexable parameter state; ``par[0]`` is used as ``beta1``
                and ``par[1]`` as ``gamma``.
    :param events: event state handed to ``make_transition_matrix``.
    :returns: ``beta1`` prior + ``gamma`` prior + summed ``model.log_prob``.
    """
    # Work on a shallow copy: the original aliased the module-level `param`
    # dict, so every call silently overwrote the shared parameter values.
    p = dict(param)
    p["beta1"] = tf.convert_to_tensor(par[0], dtype=DTYPE)
    # p['beta2'] = tf.convert_to_tensor(par[1], dtype=DTYPE)
    # p['beta3'] = tf.convert_to_tensor(par[2], dtype=DTYPE)
    p["gamma"] = tf.convert_to_tensor(par[1], dtype=DTYPE)

    beta1_logp = tfd.Gamma(
        concentration=tf.constant(1.0, dtype=DTYPE), rate=tf.constant(1.0, dtype=DTYPE)
    ).log_prob(p["beta1"])
    # beta2_logp = tfd.Gamma(concentration=tf.constant(1., dtype=DTYPE),
    #                       rate=tf.constant(1., dtype=DTYPE)).log_prob(p['beta2'])
    # beta3_logp = tfd.Gamma(concentration=tf.constant(2., dtype=DTYPE),
    #                       rate=tf.constant(2., dtype=DTYPE)).log_prob(p['beta3'])
    gamma_logp = tfd.Gamma(
        concentration=tf.constant(100.0, dtype=DTYPE),
        rate=tf.constant(400.0, dtype=DTYPE),
    ).log_prob(p["gamma"])

    with tf.name_scope("epidemic_log_posterior"):
        # Local name differs from the module-level `event_tensor` on purpose,
        # to avoid shadowing it.
        transitions = make_transition_matrix(
            events, [[0, 1], [1, 2], [2, 3]], [num_meta, num_times, 4]
        )
        y_logp = tf.reduce_sum(model.log_prob(transitions, p, state_init))

    return beta1_logp + gamma_logp + y_logp


# Pavel's suggestion for a Gibbs kernel requires
# kernel factory functions.
def make_parameter_kernel(scale, bounded_convergence):
    """Return a factory that builds the parameter-update kernel.

    :param scale: forwarded to ``random_walk_mvnorm_fn`` as its first argument.
    :param bounded_convergence: forwarded to ``random_walk_mvnorm_fn`` as ``p_u``.
    :returns: ``kernel_func(logp)``, which wraps an ``UncalibratedLogRandomWalk``
              targeting ``logp`` in a ``tfp.mcmc.MetropolisHastings`` kernel
              named ``"parameter_update"``.
    """

    def kernel_func(logp):
        proposal_fn = random_walk_mvnorm_fn(scale, p_u=bounded_convergence)
        random_walk = UncalibratedLogRandomWalk(
            target_log_prob_fn=logp, new_state_fn=proposal_fn
        )
        return tfp.mcmc.MetropolisHastings(
            inner_kernel=random_walk, name="parameter_update"
        )

    return kernel_func


120
def make_events_step(target_event_id, prev_event_id=None, next_event_id=None):
    """Return a factory that builds an ``EventTimesUpdate`` kernel.

    :param target_event_id: forwarded to ``EventTimesUpdate`` unchanged.
    :param prev_event_id: forwarded to ``EventTimesUpdate`` unchanged.
    :param next_event_id: forwarded to ``EventTimesUpdate`` unchanged.
    :returns: ``kernel_func(logp)``, which constructs the kernel with ``logp``
              as its target, fixed ``dmax=2``, ``mmax=2``, ``nmax=10`` and the
              module-level ``state_init`` as ``initial_state``.
    """

    def kernel_func(logp):
        kernel_kwargs = dict(
            target_log_prob_fn=logp,
            target_event_id=target_event_id,
            prev_event_id=prev_event_id,
            next_event_id=next_event_id,
            dmax=2,
            mmax=2,
            nmax=10,
            initial_state=state_init,
        )
        return EventTimesUpdate(**kernel_kwargs)

    return kernel_func


def is_accepted(result):
    """Return the ``is_accepted`` flag of a kernel-results object, cast to DTYPE.

    If ``result`` has no ``is_accepted`` attribute, its ``inner_results`` chain
    is followed until an object carrying one is found.

    :param result: kernel results, possibly nested through ``inner_results``.
    :returns: ``tf.cast(<found>.is_accepted, DTYPE)``.
    """
    node = result
    # Walk down the nesting instead of recursing.
    while not hasattr(node, "is_accepted"):
        node = node.inner_results
    return tf.cast(node.is_accepted, DTYPE)


143
144
145
def trace_results_fn(results):
    """Flatten one kernel-results object into a single 1-D trace tensor.

    The output is the concatenation, along axis 0, of the proposed target
    log-probability, the acceptance flag (see ``is_accepted``), the log
    acceptance correction, and ``proposed_results.extra``.

    :param results: kernel results exposing ``proposed_results`` with
                    ``target_log_prob``, ``log_acceptance_correction`` and
                    ``extra`` fields.
    :returns: the concatenated tensor.
    """
    proposal = results.proposed_results
    pieces = [
        [proposal.target_log_prob],
        [is_accepted(results)],
        [proposal.log_acceptance_correction],
        proposal.extra,
    ]
    return tf.concat(pieces, axis=0)
149
150


151
@tf.function  # (experimental_compile=True)
def sample(n_samples, init_state, par_scale):
    """Run ``n_samples`` sweeps of the three-kernel sampler.

    Each sweep applies, in order, the parameter kernel, the event kernel
    built by ``make_events_step(0, None, 1)`` and the event kernel built by
    ``make_events_step(1, 0, 2)``.  Each kernel receives the results of the
    kernel that ran before it.

    :param n_samples: number of sweeps, and the size of every output buffer.
    :param init_state: list ``[parameters, events]``; copied, not modified.
    :param par_scale: passed to ``make_parameter_kernel`` as ``scale``.
    :returns: ``(samples, results)``: a list of stacked state samples (one
              per entry of ``init_state``) and a list of three stacked traces
              built by ``trace_results_fn`` (parameter, first event kernel,
              second event kernel).
    """
    with tf.name_scope("main_mcmc_sample_loop"):
        init_state = init_state.copy()
        update_par = make_parameter_kernel(par_scale, 0.95)
        update_se = make_events_step(0, None, 1)
        update_ei = make_events_step(1, 0, 2)

        # Based on Gibbs idea posted by Pavel Sountsov
        # https://github.com/tensorflow/probability/issues/495
        def initial_logp(event_state):
            return logp(init_state[0], event_state)

        bootstrap = update_ei(initial_logp).bootstrap_results(init_state[1])

        sample_buffers = [
            tf.TensorArray(component.dtype, size=n_samples) for component in init_state
        ]
        trace_buffers = [tf.TensorArray(DTYPE, size=n_samples) for _ in range(3)]

        def body(step, state, prev_results, samples, traces):
            # Parameters
            def par_logp(par_state):
                state[0] = par_state  # close over state from outer scope
                return logp(*state)

            state[0], par_results = update_par(par_logp).one_step(
                state[0], prev_results
            )

            # States
            def state_logp(event_state):
                state[1] = event_state
                return logp(*state)

            state[1], se_results = update_se(state_logp).one_step(
                state[1], par_results
            )
            state[1], ei_results = update_ei(state_logp).one_step(
                state[1], se_results
            )

            # Record the post-sweep state and one trace row per kernel.
            samples = [buf.write(step, value) for buf, value in zip(samples, state)]
            sweep_results = (par_results, se_results, ei_results)
            traces = [
                buf.write(step, trace_results_fn(res))
                for buf, res in zip(traces, sweep_results)
            ]
            return step + 1, state, ei_results, samples, traces

        def cond(step, _state, _kernel_results, _samples, _traces):
            return step < n_samples

        _, _, _, final_samples, final_traces = tf.while_loop(
            cond=cond,
            body=body,
            loop_vars=[0, init_state, bootstrap, sample_buffers, trace_buffers],
        )

        stacked_samples = [buf.stack() for buf in final_samples]
        stacked_traces = [buf.stack() for buf in final_traces]
        return stacked_samples, stacked_traces
201
202


Chris Jewell's avatar
Chris Jewell committed
203
204
205
206
##################
# MCMC loop here #
##################

# MCMC Control
NUM_LOOP_ITERATIONS = 1000
NUM_LOOP_SAMPLES = 100
# Total number of rows written to every output dataset.
NUM_TOTAL_SAMPLES = NUM_LOOP_ITERATIONS * NUM_LOOP_SAMPLES

# Initial States
tf.random.set_seed(2)
current_state = [
    np.array([0.6, 0.25], dtype=DTYPE),
    tf.transpose(tf.stack([se_events, ei_events, ir_events], axis=-1), perm=(1, 0, 2)),
]

# Output Files
# The context manager guarantees the HDF5 file is flushed and closed even if
# sampling raises part-way through the run.
with h5py.File(os.path.expandvars(config["output"]["posterior"]), "w") as posterior:
    event_size = [NUM_TOTAL_SAMPLES] + list(current_state[1].shape)
    par_samples = posterior.create_dataset(
        "samples/parameter",
        [NUM_TOTAL_SAMPLES, current_state[0].shape[0]],
        dtype=np.float64,
    )
    se_samples = posterior.create_dataset("samples/events", event_size, dtype=DTYPE)
    # NOTE(review): 152 must equal the length of the vector returned by
    # trace_results_fn (3 scalars plus the proposal's `extra`) -- confirm.
    par_results = posterior.create_dataset(
        "acceptance/parameter", (NUM_TOTAL_SAMPLES, 152), dtype=DTYPE,
    )
    se_results = posterior.create_dataset(
        "acceptance/S->E", (NUM_TOTAL_SAMPLES, 152), dtype=DTYPE
    )
    ei_results = posterior.create_dataset(
        "acceptance/E->I", (NUM_TOTAL_SAMPLES, 152), dtype=DTYPE
    )

    print("Initial logpi:", logp(*current_state))
    par_scale = tf.linalg.diag(
        tf.ones(current_state[0].shape, dtype=current_state[0].dtype) * 0.1
    )

    # We loop over successive calls to sample because we have to dump results
    #   to disk, or else we run out of memory (even on a 32GB system).
    for i in tqdm.tqdm(range(NUM_LOOP_ITERATIONS), unit_scale=NUM_LOOP_SAMPLES):
        # with tf.profiler.experimental.Profile("/tmp/tf_logdir"):
        samples, results = sample(
            NUM_LOOP_SAMPLES, init_state=current_state, par_scale=par_scale
        )
        # The last draw of this batch seeds the next call.
        current_state = [s[-1] for s in samples]

        start = i * NUM_LOOP_SAMPLES
        stop = start + NUM_LOOP_SAMPLES
        rows = slice(start, stop)

        par_samples[rows, ...] = samples[0].numpy()
        # Covariance of the log-parameters over every draw stored so far.
        cov = np.cov(np.log(par_samples[:stop, ...]), rowvar=False)
        print(current_state[0].numpy())
        print(cov)
        # Only replace the proposal scale when the estimate is usable.
        if np.all(np.isfinite(cov)):
            par_scale = 2.38 ** 2 * cov / 2.0

        se_samples[rows, ...] = samples[1].numpy()
        par_results[rows, ...] = results[0].numpy()
        se_results[rows, ...] = results[1].numpy()
        ei_results[rows, ...] = results[2].numpy()

        # Column 1 of each trace is the acceptance flag (see trace_results_fn).
        print("Acceptance0:", tf.reduce_mean(tf.cast(results[0][:, 1], tf.float32)))
        print("Acceptance1:", tf.reduce_mean(tf.cast(results[1][:, 1], tf.float32)))
        print("Acceptance2:", tf.reduce_mean(tf.cast(results[2][:, 1], tf.float32)))

    print(f"Acceptance param: {par_results[:, 1].mean()}")
    print(f"Acceptance S->E: {se_results[:, 1].mean()}")
    print(f"Acceptance E->I: {ei_results[:, 1].mean()}")