# event_time_mh.py: uncalibrated event-time update kernel for MCMC.
from collections import namedtuple
2
3
4
5
6
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.util import SeedStream

7
from covid import config
8
from covid.impl.event_time_proposal import TransitionTopology, FilteredEventTimeProposal
9

10
# Project-wide dtype, taken from covid.config.floatX.
DTYPE = config.floatX

# Short alias for the TFP distributions namespace.
tfd = tfp.distributions


14
15
16
17
18
# Kernel results produced by UncalibratedEventTimesUpdate.  Note that the
# namedtuple's typename (visible in its repr) is "KernelResults", not the
# name of the variable it is bound to.
EventTimesKernelResults = namedtuple(
    typename="KernelResults",
    field_names=[
        "log_acceptance_correction",  # log q(reverse move) - log q(forward move)
        "target_log_prob",  # target log density of the proposed state
        "extra",  # per-metapopulation |x_star * delta_t| of the last move
    ],
)


19
20
21
22
23
def _is_within(x, low, high):
    """Element-wise membership test for the half-open interval [low, high).

    :param x: values to test
    :param low: inclusive lower bound
    :param high: exclusive upper bound
    :return: boolean tensor that is True where ``low <= x < high``
    """
    at_or_above_low = tf.less_equal(low, x)
    below_high = tf.less(x, high)
    return tf.logical_and(at_or_above_low, below_high)


24
def _nonzero_rows(m):
    """Flag the rows of ``m`` whose entries sum to a positive value.

    NOTE(review): despite the name, rows with a negative sum are flagged 0.

    :param m: a tensor; the sum is taken over its last axis
    :return: a tensor of ``m.dtype`` holding 1 where the row sum is > 0 and
             0 elsewhere
    """
    row_totals = tf.reduce_sum(m, axis=-1)
    is_positive = row_totals > 0.0
    return tf.cast(is_positive, m.dtype)
26
27


28
29
30
31
32
33
34
35
36
37
38
39
40
def _move_events(event_tensor, event_id, m, from_t, to_t, n_move):
    """Subtracts n_move from event_tensor[m, from_t, event_id]
    and adds n_move to event_tensor[m, to_t, event_id].

    Both halves of the move are applied with one scatter.  The "from"
    indices are paired with ``-n_move`` and the "to" indices with
    ``+n_move``.  ``tf.tensor_scatter_nd_add`` sums updates that share an
    index, so a move with ``from_t == to_t`` nets out to zero (up to
    floating-point rounding).

    :param event_tensor: shape [M, T, X]
    :param event_id: the event id to move
    :param m: the metapopulation(s) to move
    :param from_t: the move-from time(s); same shape as ``m``
    :param to_t: the move-to time(s); same shape as ``m``
    :param n_move: the number of events to move; same shape as ``m``
    :return: the modified event_tensor
    """
    # Broadcast with the dynamic shape so this also works when ``m`` has an
    # unknown static shape (e.g. when traced inside a tf.function).
    event_ids = tf.broadcast_to(event_id, tf.shape(m))

    # Flatten to [K, 3] so scalar and multi-dimensional ``m`` are handled
    # uniformly when the two index sets are concatenated below.
    from_indices = tf.reshape(tf.stack([m, from_t, event_ids], axis=-1), [-1, 3])
    to_indices = tf.reshape(tf.stack([m, to_t, event_ids], axis=-1), [-1, 3])

    n_move = tf.reshape(tf.cast(n_move, event_tensor.dtype), [-1])

    indices = tf.concat([from_indices, to_indices], axis=0)
    delta = tf.concat([-n_move, n_move], axis=0)
    return tf.tensor_scatter_nd_add(event_tensor, indices, delta)
51
52


53
def _reverse_move(move):
Chris Jewell's avatar
Chris Jewell committed
54
55
    move["t"] = move["t"] + move["delta_t"]
    move["delta_t"] = -move["delta_t"]
Chris Jewell's avatar
Chris Jewell committed
56
57
58
    return move


59
class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
    """UncalibratedEventTimesUpdate

    An uncalibrated transition kernel that proposes moving events of one
    transition type between time points.  Each step:

    1. samples a move from ``FilteredEventTimeProposal``;
    2. applies it to the events tensor; and
    3. returns the log proposal-density ratio ``log q(rev) - log q(fwd)`` as
       ``log_acceptance_correction``.

    Because ``is_calibrated`` is False it is presumably meant to be wrapped
    in an accept/reject kernel such as ``tfp.mcmc.MetropolisHastings``.
    """

    def __init__(
        self,
        target_log_prob_fn,
        target_event_id,
        prev_event_id,
        next_event_id,
        initial_state,
        dmax,
        mmax,
        nmax,
        seed=None,
        name=None,
    ):
        """An uncalibrated random walk for event times.

        :param target_log_prob_fn: the log density of the target distribution
        :param target_event_id: the position in the last dimension of the
                                events tensor of the transition we wish to move
        :param prev_event_id: the position of the previous event in the events tensor
        :param next_event_id: the position of the next event in the events tensor
        :param initial_state: the initial state tensor
        :param dmax: passed to the proposal as ``d_max`` and used to build
                     ``self.time_offsets``; presumably the maximum time shift
                     of a move
        :param mmax: passed to the proposal as ``m_max``; presumably the
                     maximum number of metapopulations updated per step
        :param nmax: passed to the proposal as ``n_max``; presumably the
                     maximum number of events moved per step
        :param seed: a random seed; each step's proposal draw is seeded from
                     a SeedStream built from it
        :param name: the name of the update step
        """
        self._seed_stream = SeedStream(seed, salt="UncalibratedEventTimesUpdate")
        self._name = name
        self._parameters = dict(
            target_log_prob_fn=target_log_prob_fn,
            target_event_id=target_event_id,
            prev_event_id=prev_event_id,
            next_event_id=next_event_id,
            initial_state=initial_state,
            dmax=dmax,
            mmax=mmax,
            nmax=nmax,
            seed=seed,
            name=name,
        )
        self.tx_topology = TransitionTopology(
            prev_event_id, target_event_id, next_event_id
        )
        self.time_offsets = tf.range(self.parameters["dmax"])

    @property
    def target_log_prob_fn(self):
        return self._parameters["target_log_prob_fn"]

    @property
    def target_event_id(self):
        return self._parameters["target_event_id"]

    @property
    def prev_event_id(self):
        return self._parameters["prev_event_id"]

    @property
    def next_event_id(self):
        return self._parameters["next_event_id"]

    @property
    def seed(self):
        return self._parameters["seed"]

    @property
    def name(self):
        return self._parameters["name"]

    @property
    def parameters(self):
        """Return `dict` of ``__init__`` arguments and their values."""
        return self._parameters

    @property
    def is_calibrated(self):
        return False

    def _make_proposal(self, events):
        """Build the event-time proposal distribution conditional on ``events``."""
        return FilteredEventTimeProposal(
            events=events,
            initial_state=self.parameters["initial_state"],
            topology=self.tx_topology,
            m_max=self.parameters["mmax"],
            d_max=self.parameters["dmax"],
            n_max=self.parameters["nmax"],
        )

    def one_step(self, current_events, previous_kernel_results):
        """One update of event times.

        :param current_events: a [M, T, X] tensor containing number of events
                               per metapopulation m, time t, and transition x.
        :param previous_kernel_results: an EventTimesKernelResults object
                                        (not used by this kernel).
        :returns: a list containing the new state and an
                  EventTimesKernelResults
        """
        with tf.name_scope("uncalibrated_event_times_rw/onestep"):
            target_events = current_events[..., self.tx_topology.target]
            num_times = target_events.shape[1]

            proposal = self._make_proposal(current_events)
            # Seed each draw from the kernel's seed stream.  Previously the
            # stream was created but never used, so ``seed`` had no effect.
            update = proposal.sample(seed=self._seed_stream())

            move = update["move"]
            to_t = move["t"] + move["delta_t"]

            def true_fn():
                with tf.name_scope("true_fn"):
                    # Log probability of the forward move.  Use the returned
                    # tensor so the finiteness check stays in the graph.
                    q_fwd = proposal.log_prob(update)
                    q_fwd = tf.debugging.assert_all_finite(
                        q_fwd, "q_fwd is not finite"
                    )

                    # Propagate state
                    next_state = _move_events(
                        event_tensor=current_events,
                        event_id=self.tx_topology.target,
                        m=update["m"],
                        from_t=move["t"],
                        to_t=to_t,
                        n_move=move["x_star"],
                    )

                    next_target_log_prob = self.target_log_prob_fn(next_state)

                    # Probability of the reverse move, evaluated under the
                    # proposal conditioned on the proposed state.
                    rev_move = _reverse_move(move.copy())
                    rev_update = dict(m=update["m"], move=rev_move)
                    reverse_proposal = self._make_proposal(next_state)
                    q_rev = reverse_proposal.log_prob(rev_update)

                    # Proposal mass ratio log q(rev) - log q(fwd)
                    log_acceptance_correction = tf.reduce_sum(q_rev - q_fwd)

                    return (
                        next_target_log_prob,
                        log_acceptance_correction,
                        next_state,
                    )

            def false_fn():
                with tf.name_scope("false_fn"):
                    # Out-of-range proposal: give it zero target density and
                    # leave the state unchanged.
                    next_target_log_prob = tf.constant(
                        -np.inf, dtype=current_events.dtype
                    )
                    log_acceptance_correction = tf.constant(
                        0.0, dtype=current_events.dtype
                    )
                    return (
                        next_target_log_prob,
                        log_acceptance_correction,
                        current_events,
                    )

            # Trap out-of-bounds moves that go outside [0, num_times)
            next_target_log_prob, log_acceptance_correction, next_state = tf.cond(
                tf.reduce_all(_is_within(to_t, 0, num_times)),
                true_fn=true_fn,
                false_fn=false_fn,
            )

            # Per-metapopulation magnitude of the proposed move, reported in
            # the ``extra`` field of the kernel results.
            x_star_results = tf.scatter_nd(
                update["m"][:, tf.newaxis],
                tf.abs(move["x_star"] * move["delta_t"]),
                [current_events.shape[0]],
            )

            return [
                next_state,
                EventTimesKernelResults(
                    log_acceptance_correction=log_acceptance_correction,
                    target_log_prob=next_target_log_prob,
                    extra=tf.cast(x_star_results, current_events.dtype),
                ),
            ]

    def bootstrap_results(self, init_state):
        """Create initial kernel results for ``init_state``.

        :param init_state: the initial events tensor; converted to ``DTYPE``
        :returns: an EventTimesKernelResults with zero acceptance correction,
                  the target log density of ``init_state``, and an ``extra``
                  vector of zeros of length ``init_state.shape[-3]``
        """
        with tf.name_scope("uncalibrated_event_times_rw/bootstrap_results"):
            init_state = tf.convert_to_tensor(init_state, dtype=DTYPE)
            init_target_log_prob = self.target_log_prob_fn(init_state)
            return EventTimesKernelResults(
                log_acceptance_correction=tf.constant(0.0, dtype=DTYPE),
                target_log_prob=init_target_log_prob,
                extra=tf.zeros(init_state.shape[-3], dtype=DTYPE),
            )