event_time_mh.py 13.2 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
2
from pprint import pprint

3
4
5
6
7
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.util import SeedStream

8
from covid import config
Chris Jewell's avatar
Chris Jewell committed
9
from covid.impl.event_time_proposal import TransitionTopology, FilteredEventTimeProposal
10
11
from covid.impl.mcmc import KernelResults

12
tfd = tfp.distributions

# Floating-point dtype used for tensors created in this module.
DTYPE = config.floatX


def _is_within(x, low, high):
    """Elementwise test for membership of the half-open interval [low, high).

    :param x: tensor of values to test
    :param low: inclusive lower bound
    :param high: exclusive upper bound
    :returns: boolean tensor, True wherever low <= x < high
    """
    not_below_low = tf.less_equal(low, x)
    below_high = tf.less(x, high)
    return tf.logical_and(not_below_low, below_high)


21
def _nonzero_rows(m):
    """Flags the rows of ``m`` whose sum over the last axis is positive.

    :param m: a tensor; its last axis is summed
    :returns: a tensor of shape m.shape[:-1] and dtype m.dtype holding 1 where
              the row sum is > 0 and 0 elsewhere
    """
    row_sums = tf.reduce_sum(m, axis=-1)
    return tf.cast(row_sums > 0.0, m.dtype)
23
24


Chris Jewell's avatar
Chris Jewell committed
25
26
27
def _max_free_events(
    events, initial_state, target_t, target_id, constraint_t, constraint_id
):
    """Returns the maximum number of free events to move in the target events,
    constrained by the constraining events.

    :param events: a [T, M, X] tensor of transition events
    :param initial_state: a [M, X] tensor of the constraining initial state
    :param target_t: the target time
    :param target_id: the Xth index of the target event
    :param constraint_t: the Tth times of the constraint
    :param constraint_id: the Xth index of the constraining event, -1 implies
                          no constraint
    :returns: a tensor of shape constraint_t.shape[0] + [M] of max free events,
              dtype=events.dtype
    """

    def true_fn():
        target_events_ = tf.gather(events, target_id, axis=-1)  # TxM
        target_cumsum = tf.cumsum(target_events_, axis=0)  # TxM

        constraining_events = tf.gather(events, constraint_id, axis=-1)  # TxM
        constraining_cumsum = tf.cumsum(constraining_events, axis=0)  # TxM
        # NOTE(review): the initial state is indexed at constraint_id + 1,
        # presumably because its columns are compartments offset by one from
        # the transition index -- confirm against the model definition.
        constraining_init_state = tf.gather(initial_state, constraint_id + 1, axis=-1)

        n1 = tf.gather(target_cumsum, constraint_t, axis=0)
        n2 = tf.gather(constraining_cumsum, constraint_t, axis=0)
        free_events = tf.abs(n1 - n2) + constraining_init_state

        # Never allow more than the number of target events present at target_t.
        max_free_events = tf.minimum(
            free_events, tf.gather(target_events_, target_t, axis=0)
        )
        return max_free_events

    # Manual broadcasting of n_events_t is required here so that the XLA
    # compiler can guarantee that the output shapes of true_fn() and
    # false_fn() are equal.  Known shape information can thus be
    # propagated right through the algorithm, so the return value has known shape.
    def false_fn():
        n_events_t = tf.gather(events[..., target_id], target_t, axis=0)
        return tf.broadcast_to(
            [n_events_t], [constraint_t.shape[0]] + [n_events_t.shape[0]]
        )

    ret_val = tf.cond(constraint_id != -1, true_fn, false_fn)
    return ret_val
65
66


Chris Jewell's avatar
Chris Jewell committed
67
68
69
70
71
72
73
74
75
76
77
78
79
def _move_events(event_tensor, event_id, m, from_t, to_t, n_move):
    """Subtracts n_move from event_tensor[m, from_t, event_id]
    and adds n_move to event_tensor[m, to_t, event_id].

    :param event_tensor: shape [M, T, X]
    :param event_id: the event id to move
    :param m: the metapopulation(s) to move
    :param from_t: the move-from time(s), same shape as m
    :param to_t: the move-to time(s), same shape as m
    :param n_move: the number of events to move, same shape as m
    :return: the modified event_tensor
    """
    # Pair every selected metapopulation with the event id.  The dynamic shape
    # is used so this also works when m's length is not known statically.
    event_ids = tf.broadcast_to(event_id, tf.shape(m))
    from_indices = tf.stack([m, from_t, event_ids], axis=-1)
    to_indices = tf.stack([m, to_t, event_ids], axis=-1)

    n_move = tf.cast(n_move, event_tensor.dtype)

    # Apply the whole move as a single scatter: -n_move at the
    # [m, from_t, event_id] entries and +n_move at the [m, to_t, event_id]
    # entries.  tensor_scatter_nd_add sums contributions at duplicate indices,
    # so from_t == to_t leaves the tensor unchanged, exactly like a
    # subtract-then-add.  Flattening makes the concatenation valid for any
    # rank of m (including a scalar).
    indices = tf.concat(
        [tf.reshape(from_indices, [-1, 3]), tf.reshape(to_indices, [-1, 3])], axis=0
    )
    deltas = tf.concat([-tf.reshape(n_move, [-1]), tf.reshape(n_move, [-1])], axis=0)
    return tf.tensor_scatter_nd_add(event_tensor, indices, deltas)
90
91


92
class EventTimesUpdate(tfp.mcmc.TransitionKernel):
    """Metropolis-Hastings kernel for event times.

    Wraps UncalibratedEventTimesUpdate in tfp.mcmc.MetropolisHastings so that
    proposed event-time moves are accepted or rejected.
    """

    def __init__(
        self,
        target_log_prob_fn,
        target_event_id,
        prev_event_id,
        next_event_id,
        initial_state,
        dmax,
        mmax,
        nmax,
        seed=None,
        name=None,
    ):
        """A random walk Metropolis Hastings for event times.
        :param target_log_prob_fn: the log density of the target distribution
        :param target_event_id: the index in the last (transition) dimension of
                                the events tensor that we wish to move
        :param prev_event_id: the index of the previous event in the events tensor
        :param next_event_id: the index of the next event in the events tensor
        :param initial_state: the initial state tensor
        :param dmax: maximum distance to move in time
        :param mmax: number of metapopulations to move
        :param nmax: max number of events to move
        :param seed: a random seed
        :param name: the name of the update step
        """
        self._seed_stream = SeedStream(seed, salt="EventTimesUpdate")
        self._impl = tfp.mcmc.MetropolisHastings(
            inner_kernel=UncalibratedEventTimesUpdate(
                target_log_prob_fn=target_log_prob_fn,
                target_event_id=target_event_id,
                prev_event_id=prev_event_id,
                next_event_id=next_event_id,
                dmax=dmax,
                mmax=mmax,
                nmax=nmax,
                initial_state=initial_state,
                # Forward seed and name so that `parameters` and `name`
                # (both read from the inner kernel) reflect what was passed.
                seed=seed,
                name=name,
            )
        )
        self._parameters = self._impl.inner_kernel.parameters.copy()

    @property
    def target_log_prob_fn(self):
        return self._impl.inner_kernel.target_log_prob_fn

    @property
    def name(self):
        return self._impl.inner_kernel.name

    @property
    def parameters(self):
        """Return `dict` of ``__init__`` arguments and their values."""
        return self._parameters

    @property
    def is_calibrated(self):
        return True

    def one_step(self, current_state, previous_kernel_results):
        """Performs one step of an event times update.
        :param current_state: the current state tensor [TxMxX]
        :param previous_kernel_results: a named tuple of results.
        :returns: (next_state, kernel_results)
        """
        next_state, kernel_results = self._impl.one_step(
            current_state, previous_kernel_results
        )
        return next_state, kernel_results

    def bootstrap_results(self, init_state):
        kernel_results = self._impl.bootstrap_results(init_state)
        return kernel_results


Chris Jewell's avatar
Chris Jewell committed
167
def _reverse_move(move):
Chris Jewell's avatar
Chris Jewell committed
168
169
    move["t"] = move["t"] + move["delta_t"]
    move["delta_t"] = -move["delta_t"]
Chris Jewell's avatar
Chris Jewell committed
170
171
172
    return move


173
class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
    """UncalibratedEventTimesUpdate

    A random-walk proposal kernel for event times.  ``one_step`` proposes a
    move with FilteredEventTimeProposal and reports the proposed state's
    target log density together with the log proposal-mass ratio
    (reverse minus forward) as ``log_acceptance_correction``.  It performs no
    accept/reject step itself (``is_calibrated`` is False); wrap it in
    ``tfp.mcmc.MetropolisHastings`` to obtain a calibrated kernel.
    """

    def __init__(
        self,
        target_log_prob_fn,
        target_event_id,
        prev_event_id,
        next_event_id,
        initial_state,
        dmax,
        mmax,
        nmax,
        seed=None,
        name=None,
    ):
        """An uncalibrated random walk for event times.
        :param target_log_prob_fn: the log density of the target distribution
        :param target_event_id: the index in the last (transition) dimension of
                                the events tensor that we wish to move
        :param prev_event_id: the index of the previous event in the events tensor
        :param next_event_id: the index of the next event in the events tensor
        :param initial_state: the initial state tensor
        :param dmax: maximum distance to move in time
        :param mmax: number of metapopulations to move
        :param nmax: max number of events to move
        :param seed: a random seed
        :param name: the name of the update step
        """
        self._target_log_prob_fn = target_log_prob_fn
        # NOTE(review): this seed stream is not consumed anywhere in this
        # class; proposal.sample() in one_step is called without a seed.
        self._seed_stream = SeedStream(seed, salt="UncalibratedEventTimesUpdate")
        self._name = name
        self._parameters = dict(
            target_log_prob_fn=target_log_prob_fn,
            target_event_id=target_event_id,
            prev_event_id=prev_event_id,
            next_event_id=next_event_id,
            initial_state=initial_state,
            dmax=dmax,
            mmax=mmax,
            nmax=nmax,
            seed=seed,
            name=name,
        )
        self.tx_topology = TransitionTopology(
            prev_event_id, target_event_id, next_event_id
        )
        self.time_offsets = tf.range(self.parameters["dmax"])

    @property
    def target_log_prob_fn(self):
        return self._parameters["target_log_prob_fn"]

    @property
    def target_event_id(self):
        return self._parameters["target_event_id"]

    @property
    def prev_event_id(self):
        return self._parameters["prev_event_id"]

    @property
    def next_event_id(self):
        return self._parameters["next_event_id"]

    @property
    def seed(self):
        return self._parameters["seed"]

    @property
    def name(self):
        return self._parameters["name"]

    @property
    def parameters(self):
        """Return `dict` of ``__init__`` arguments and their values."""
        return self._parameters

    @property
    def is_calibrated(self):
        return False

    def one_step(self, current_events, previous_kernel_results):
        """One update of event times.
        :param current_events: a [T, M, X] tensor containing number of events
                               per time t, metapopulation m,
                               and transition x.
        :param previous_kernel_results: an object of type KernelResults
                                        (not used by this kernel).
        :returns: a list [next_state, KernelResults], where next_state is a
                  [T, M, X] tensor.
        """
        with tf.name_scope("uncalibrated_event_times_rw/onestep"):
            # Work internally in [M, T, X] layout.
            current_events = tf.transpose(current_events, perm=(1, 0, 2))
            target_events = current_events[..., self.tx_topology.target]
            num_times = target_events.shape[1]

            proposal = FilteredEventTimeProposal(
                events=current_events,
                initial_state=self.parameters["initial_state"],
                topology=self.tx_topology,
                m_max=self.parameters["mmax"],
                d_max=self.parameters["dmax"],
                n_max=self.parameters["nmax"],
            )
            update = proposal.sample()

            q_fwd = proposal.log_prob(update)
            tf.debugging.assert_all_finite(q_fwd, "q_fwd is not finite")

            move = update["move"]
            to_t = move["t"] + move["delta_t"]

            # Moves to times outside the range [0, num_times) are illegal.
            # Todo: address potential issue in the proposal if
            #       dmax accesses indices outside this range.
            def true_fn():
                next_state = _move_events(
                    event_tensor=current_events,
                    event_id=self.tx_topology.target,
                    m=update["m"],
                    from_t=move["t"],
                    to_t=to_t,
                    n_move=move["x_star"],
                )

                # Back to [T, M, X] for the target density and the caller.
                next_state_tr = tf.transpose(next_state, perm=(1, 0, 2))
                next_target_log_prob = self._target_log_prob_fn(next_state_tr)

                # Calculate proposal mass ratio: the reverse proposal is built
                # on the proposed state and evaluated at the reverse move.
                rev_move = _reverse_move(move.copy())
                rev_update = dict(m=update["m"], move=rev_move)
                Q_rev = FilteredEventTimeProposal(  # pylint: disable=invalid-name
                    events=next_state,
                    initial_state=self.parameters["initial_state"],
                    topology=self.tx_topology,
                    m_max=self.parameters["mmax"],
                    d_max=self.parameters["dmax"],
                    n_max=self.parameters["nmax"],
                )
                q_rev = Q_rev.log_prob(rev_update)
                log_acceptance_correction = tf.reduce_sum(q_rev - q_fwd)

                return (next_target_log_prob, log_acceptance_correction, next_state_tr)

            def false_fn():
                # An illegal move is given log density -inf so that a
                # Metropolis-Hastings wrapper rejects it; the state is
                # returned unchanged (in [T, M, X] layout).
                next_target_log_prob = tf.constant(-np.inf, dtype=current_events.dtype)
                log_acceptance_correction = tf.constant(0.0, dtype=current_events.dtype)
                return (
                    next_target_log_prob,
                    log_acceptance_correction,
                    tf.transpose(current_events, perm=(1, 0, 2)),
                )

            # Trap out-of-bounds moves that go outside [0, num_times)
            (next_target_log_prob, log_acceptance_correction, next_state) = tf.cond(
                tf.reduce_all(_is_within(to_t, 0, num_times)),
                true_fn=true_fn,
                false_fn=false_fn,
            )

            # Per-metapopulation |x_star * delta_t| of the proposed move,
            # reported through KernelResults.extra.
            x_star_results = tf.scatter_nd(
                update["m"][:, tf.newaxis],
                tf.abs(move["x_star"] * move["delta_t"]),
                [current_events.shape[0]],
            )

            return [
                next_state,
                KernelResults(
                    log_acceptance_correction=log_acceptance_correction,
                    target_log_prob=next_target_log_prob,
                    extra=tf.cast(x_star_results, current_events.dtype),
                ),
            ]

    def bootstrap_results(self, init_state):
        """Builds initial KernelResults for a [T, M, X] state tensor.

        :param init_state: the initial state, converted to DTYPE
        :returns: KernelResults with zero log_acceptance_correction, the
                  target log density of init_state, and a zero ``extra``
                  vector of length init_state.shape[-2]
        """
        with tf.name_scope("uncalibrated_event_times_rw/bootstrap_results"):
            init_state = tf.convert_to_tensor(init_state, dtype=DTYPE)
            init_target_log_prob = self.target_log_prob_fn(init_state)
            return KernelResults(
                log_acceptance_correction=tf.constant(0.0, dtype=DTYPE),
                target_log_prob=init_target_log_prob,
                extra=tf.zeros(init_state.shape[-2], dtype=DTYPE),
            )