mcmc.py 8.79 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
import os
2
import tqdm
3
import pickle as pkl
4
5
import yaml
import h5py
Chris Jewell's avatar
Chris Jewell committed
6

7
import numpy as np
8
9
10
11
12
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

13
14
from covid import config
from covid.model import load_data, CovidUKStochastic
15
from covid.util import sanitise_parameter, sanitise_settings, seed_areas
16
from covid.impl.util import make_transition_matrix
17
from covid.impl.mcmc import UncalibratedLogRandomWalk, random_walk_mvnorm_fn
18
from covid.impl.event_time_mh import EventTimesUpdate
19
20
21
22
23
24
25
26
27
28
29
30
31

# Floating-point precision used for every tensor and dataset built below.
# Must be read before `config` is rebound to the parsed YAML dictionary.
DTYPE = config.floatX

# Random moves of events.  What invalidates an epidemic, how can we test for it?
with open('ode_config.yaml', 'r') as f:
    # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
    # and raises TypeError from 6.0; safe_load builds plain Python objects only.
    # NOTE: from here on `config` is the parsed dictionary, not `covid.config`.
    config = yaml.safe_load(f)

# Model parameters, converted to constant tensors of the working precision.
param = sanitise_parameter(config['parameter'])
param = {k: tf.constant(v, dtype=DTYPE) for k, v in param.items()}

settings = sanitise_settings(config['settings'])

data = load_data(config['data'], settings, DTYPE)
Chris Jewell's avatar
Chris Jewell committed
32
# Collapse the population table onto its first index level.
# `.sum(level=0)` was deprecated in pandas 1.3 and removed in 2.0; grouping on
# the level with sort=False is the documented replacement and keeps the row
# order the old call produced.
# NOTE(review): assumes data['pop'] is a pandas object with a MultiIndex.
data['pop'] = data['pop'].groupby(level=0, sort=False).sum()

# NOTE(review): the contact matrix is truncated to its first 10 rows/columns and
# the population is hard-coded to 1000 per unit (the real population vector is
# left commented out) -- this looks like a reduced test configuration.
model = CovidUKStochastic(C=data['C'][:10, :10],
                          N=[1000]*10, #data['pop']['n'].to_numpy(),
                          W=data['W'],
                          date_range=settings['inference_period'],
                          holidays=settings['holiday'],
                          lockdown=settings['lockdown'],
                          time_step=1.)

42
# Load data
# NOTE(review): unpickling executes arbitrary code; only load simulation files
# that come from a trusted source.
with open('stochastic_sim_medium.pkl', 'rb') as f:
    example_sim = pkl.load(f)

event_tensor = example_sim['events']  # shape [T, M, S, S]
num_times = event_tensor.shape[0]
num_meta = event_tensor.shape[1]
state_init = example_sim['state_init']

# Slice out the event counts for each transition, indexed by
# (source state, destination state) on the last two axes.
se_events = event_tensor[:, :, 0, 1]
ei_events = event_tensor[:, :, 1, 2]
ir_events = event_tensor[:, :, 2, 3]
Chris Jewell's avatar
Chris Jewell committed
53

54
def logp(par, events):
    """Return the unnormalised log posterior of parameters and events.

    Args:
        par: indexable of two values; ``par[0]`` is used as ``beta1`` and
            ``par[1]`` as ``gamma``.
        events: event tensor handed to ``make_transition_matrix`` together
            with the transition index pairs ``[0, 1]``, ``[1, 2]``, ``[2, 3]``.
            The caller in this file stacks S->E, E->I and I->R counts on the
            last axis.

    Returns:
        Sum of the Gamma prior log densities of ``beta1`` and ``gamma`` and
        the summed model log probability of the events.
    """
    # Work on a copy: the original aliased the module-level `param` dict and
    # overwrote its entries on every call, leaking per-call tensors into
    # global state.
    p = dict(param)
    p['beta1'] = tf.convert_to_tensor(par[0], dtype=DTYPE)
    #p['beta2'] = tf.convert_to_tensor(par[1], dtype=DTYPE)
    #p['beta3'] = tf.convert_to_tensor(par[2], dtype=DTYPE)
    p['gamma'] = tf.convert_to_tensor(par[1], dtype=DTYPE)

    # Priors
    beta1_logp = tfd.Gamma(concentration=tf.constant(1., dtype=DTYPE),
                           rate=tf.constant(1., dtype=DTYPE)).log_prob(p['beta1'])
    #beta2_logp = tfd.Gamma(concentration=tf.constant(1., dtype=DTYPE),
    #                       rate=tf.constant(1., dtype=DTYPE)).log_prob(p['beta2'])
    #beta3_logp = tfd.Gamma(concentration=tf.constant(2., dtype=DTYPE),
    #                       rate=tf.constant(2., dtype=DTYPE)).log_prob(p['beta3'])
    gamma_logp = tfd.Gamma(concentration=tf.constant(100., dtype=DTYPE),
                           rate=tf.constant(400., dtype=DTYPE)).log_prob(p['gamma'])

    # Likelihood of the events under the epidemic model
    with tf.name_scope('main_log_p'):
        transition_events = make_transition_matrix(events,
                                                   [[0, 1], [1, 2], [2, 3]],
                                                   [num_times, num_meta, 4])
        y_logp = tf.reduce_sum(model.log_prob(transition_events, p, state_init))

    total_logp = beta1_logp + gamma_logp + y_logp
    return total_logp


def trace_fn(state, prev_results):
    """Pick the traced quantities out of a kernel-results structure.

    Args:
        state: current chain state (not used).
        prev_results: kernel results exposing ``is_accepted`` and
            ``accepted_results.target_log_prob``.

    Returns:
        Tuple ``(is_accepted, target_log_prob of the accepted results)``.
    """
    accepted_flag = prev_results.is_accepted
    accepted_log_prob = prev_results.accepted_results.target_log_prob
    return accepted_flag, accepted_log_prob


# Pavel's suggestion for a Gibbs kernel requires
# kernel factory functions.
def make_parameter_kernel(scale, bounded_convergence):
    """Build a factory for the parameter-update Metropolis-Hastings kernel.

    Args:
        scale: passed as the first argument to ``random_walk_mvnorm_fn``.
        bounded_convergence: passed to ``random_walk_mvnorm_fn`` as ``p_u``.

    Returns:
        A function taking a target log-probability function and returning a
        ``tfp.mcmc.MetropolisHastings`` kernel named ``'parameter_update'``
        that wraps an ``UncalibratedLogRandomWalk`` proposal.
    """
    def kernel_func(logp):
        # The kernel is rebuilt per call so that it can close over a target
        # density conditioned on the current value of the other state parts.
        proposal_kernel = UncalibratedLogRandomWalk(
            target_log_prob_fn=logp,
            new_state_fn=random_walk_mvnorm_fn(scale, p_u=bounded_convergence))
        return tfp.mcmc.MetropolisHastings(inner_kernel=proposal_kernel,
                                           name='parameter_update')

    return kernel_func


94
def make_events_step(target_event_id, prev_event_id=None, next_event_id=None):
    """Build a factory for an event-time update kernel.

    Args:
        target_event_id: forwarded to ``EventTimesUpdate`` as the event to move.
        prev_event_id: forwarded as ``prev_event_id`` (default ``None``).
        next_event_id: forwarded as ``next_event_id`` (default ``None``).

    Returns:
        A function taking a target log-probability function and returning an
        ``EventTimesUpdate`` kernel configured with ``dmax=1``, ``mmax=1``,
        ``nmax=20`` and the module-level ``state_init`` as initial state.
    """
    def kernel_func(logp):
        return EventTimesUpdate(
            target_log_prob_fn=logp,
            target_event_id=target_event_id,
            prev_event_id=prev_event_id,
            next_event_id=next_event_id,
            dmax=1,
            mmax=1,
            nmax=20,
            initial_state=state_init)

    return kernel_func


def is_accepted(result):
    """Return the acceptance flag of a (possibly nested) kernel result.

    Follows ``inner_results`` until an object exposing ``is_accepted`` is
    found, then casts that flag to ``DTYPE``.

    Args:
        result: kernel results, or a wrapper holding them in ``inner_results``.

    Returns:
        The ``is_accepted`` value cast to ``DTYPE``.

    Raises:
        AttributeError: if an object in the chain has neither ``is_accepted``
            nor ``inner_results``.
    """
    # Iterative form of the original recursion: unwrap until the flag appears.
    while not hasattr(result, 'is_accepted'):
        result = result.inner_results
    return tf.cast(result.is_accepted, DTYPE)


114
115
116
117
118
119
120
def trace_results_fn(results):
    """Flatten one kernel result into a single rank-1 trace row.

    Args:
        results: kernel results exposing ``proposed_results.target_log_prob``
            and ``proposed_results.extra``, and accepted by ``is_accepted``.

    Returns:
        Concatenation along axis 0 of the proposed log probability, the
        acceptance flag, and the ``extra`` vector of the proposal.
    """
    proposal = results.proposed_results
    head = [[proposal.target_log_prob], [is_accepted(results)]]
    return tf.concat(head + [proposal.extra], axis=0)


121
#@tf.function #(autograph=False, experimental_compile=True)
122
123
124
def sample(n_samples, init_state, par_scale):
    """Run ``n_samples`` sweeps of the Metropolis-within-Gibbs sampler.

    Each sweep updates the parameters, then the S->E event times, then the
    E->I event times, every step conditioning on the latest value of the
    other state part.

    Args:
        n_samples: number of sweeps to run and store.
        init_state: two-element list ``[parameters, events]``; a shallow copy
            is taken so the caller's list is not modified.
        par_scale: scale handed to ``make_parameter_kernel`` for the
            parameter random-walk proposal.

    Returns:
        Tuple ``(samples, results)``: ``samples`` is a list with one stacked
        tensor per state part, ``results`` a list of three stacked trace
        tensors (parameter, S->E and E->I updates) as produced by
        ``trace_results_fn``.
    """
    init_state = init_state.copy()
    par_func = make_parameter_kernel(par_scale, 0.95)
    se_func = make_events_step(0, None, 1)
    ei_func = make_events_step(target_event_id=1, prev_event_id=0, next_event_id=2)

    # Based on Gibbs idea posted by Pavel Sountsov https://github.com/tensorflow/probability/issues/495
    # The first parameter step receives results bootstrapped from the E->I
    # kernel, matching what it receives from the previous sweep thereafter.
    initial_results = ei_func(
        lambda s: logp(init_state[0], s)).bootstrap_results(init_state[1])

    samples_arr = [tf.TensorArray(s.dtype, size=n_samples) for s in init_state]
    # One trace array per update step: parameters, S->E, E->I.
    results_arr = [tf.TensorArray(DTYPE, size=n_samples) for _ in range(3)]

    def body(i, state, prev_results, samples, results):
        # Parameters
        def par_logp(par_state):
            state[0] = par_state  # close over state from outer scope
            return logp(*state)
        state[0], par_results = par_func(par_logp).one_step(state[0], prev_results)

        # States
        def state_logp(event_state):
            state[1] = event_state
            return logp(*state)
        state[1], se_results = se_func(state_logp).one_step(state[1], par_results)
        state[1], ei_results = ei_func(state_logp).one_step(state[1], se_results)

        samples = [samples[k].write(i, s) for k, s in enumerate(state)]
        results = [results[k].write(i, trace_results_fn(r))
                   for k, r in enumerate([par_results, se_results, ei_results])]
        # The E->I results are carried into the next sweep's parameter step.
        return i+1, state, ei_results, samples, results

    def cond(i, _1, _2, _3, _4):
        return i < n_samples

    _1, _2, _3, samples, results = tf.while_loop(
        cond=cond, body=body,
        loop_vars=[0, init_state, initial_results, samples_arr, results_arr])

    return [s.stack() for s in samples], [r.stack() for r in results]


if __name__=='__main__':

    num_loop_iterations = 1000
    num_loop_samples = 100
    total_samples = num_loop_iterations * num_loop_samples

    # Chain state: [parameter vector (beta1, gamma), stacked S->E/E->I/I->R events].
    current_state = [np.array([0.15, 0.25], dtype=DTYPE),
                     tf.stack([se_events, ei_events, ir_events], axis=-1)]

    # The context manager closes the HDF5 file even if sampling fails
    # part-way, so samples already written are not left in an unclosed file.
    with h5py.File(os.path.expandvars(config['output']['posterior']), 'w') as posterior:
        event_size = [total_samples] + list(current_state[1].shape)
        par_samples = posterior.create_dataset(
            'samples/parameter',
            [total_samples, current_state[0].shape[0]],
            dtype=np.float64)
        se_samples = posterior.create_dataset('samples/events', event_size, dtype=DTYPE)
        # NOTE(review): 22 is the hard-coded width of a `trace_results_fn` row
        # (log prob, acceptance flag, then the proposal's `extra` values);
        # confirm it matches the kernels' `extra` length.
        par_results = posterior.create_dataset('acceptance/parameter', (total_samples, 22), dtype=DTYPE)
        se_results = posterior.create_dataset('acceptance/S->E', (total_samples, 22), dtype=DTYPE)
        ei_results = posterior.create_dataset('acceptance/E->I', (total_samples, 22), dtype=DTYPE)

        print("Initial logpi:", logp(*current_state))
        par_scale = tf.linalg.diag(tf.ones(current_state[0].shape, dtype=current_state[0].dtype) * 0.1)

        # We loop over successive calls to sample because we have to dump results
        #   to disc, or else end OOM (even on a 32GB system).
        #with tf.profiler.experimental.Profile('/tmp/tf_logdir'):
        for i in tqdm.tqdm(range(num_loop_iterations), unit_scale=num_loop_samples):
            samples, results = sample(num_loop_samples, init_state=current_state, par_scale=par_scale)
            # Restart the next batch from the last state of this one.
            current_state = [s[-1] for s in samples]
            batch_end = i*num_loop_samples + num_loop_samples
            batch = slice(i*num_loop_samples, batch_end)
            par_samples[batch, ...] = samples[0].numpy()

            # Adapt the proposal scale from the covariance of all log-parameter
            # samples drawn so far; keep the previous scale if it is not finite.
            cov = np.cov(np.log(par_samples[:batch_end, ...]), rowvar=False)
            print(current_state[0].numpy())
            print(cov)
            if np.all(np.isfinite(cov)):
                par_scale = 2.38**2 * cov / 2.

            se_samples[batch, ...] = samples[1].numpy()
            par_results[batch, ...] = results[0].numpy()
            se_results[batch, ...] = results[1].numpy()
            ei_results[batch, ...] = results[2].numpy()

            # Column 1 of each trace row is the acceptance flag.
            print("Acceptance0:", tf.reduce_mean(tf.cast(results[0][:, 1], tf.float32)))
            print("Acceptance1:", tf.reduce_mean(tf.cast(results[1][:, 1], tf.float32)))
            print("Acceptance2:", tf.reduce_mean(tf.cast(results[2][:, 1], tf.float32)))

        print(f'Acceptance param: {par_results[:, 1].mean()}')
        print(f'Acceptance S->E: {se_results[:, 1].mean()}')
        print(f'Acceptance E->I: {ei_results[:, 1].mean()}')
Chris Jewell's avatar
Chris Jewell committed
208