# occult_events_mh.py -- MCMC transition kernel that proposes adding or deleting occult events.
from collections import namedtuple
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.util import SeedStream

from covid import config
from covid.impl.event_time_proposal import TransitionTopology, FilteredEventTimeProposal
from covid.impl.occult_proposal import AddOccultProposal, DelOccultProposal

DTYPE = config.floatX  # default float dtype, taken from the project's `covid.config`
tfd = tfp.distributions  # short alias for the TFP distributions namespace


# Kernel-results record returned by `UncalibratedOccultUpdate`.  The namedtuple
# typename must equal the module attribute name so instances can be pickled
# (pickle resolves the class by `module.typename`) and so reprs are not
# confused with other kernels' `KernelResults` types.
OccultKernelResults = namedtuple(
    "OccultKernelResults", ("log_acceptance_correction", "target_log_prob", "extra")
)


def _nonzero_rows(m):
    """Return a mask that is 1 where the sum over the last axis of `m` is > 0, else 0.

    The mask is cast back to `m.dtype` so it can be used in numeric expressions.
    """
    row_sums = tf.reduce_sum(m, axis=-1)
    is_positive = row_sums > 0.0
    return tf.cast(is_positive, m.dtype)


def _maybe_expand_dims(x):
    """Convert `x` to a tensor, promoting a rank-0 value to a length-1 vector.

    Inputs that already have rank >= 1 (or an unknown rank) are returned as
    converted, without reshaping.
    """
    tensor = tf.convert_to_tensor(x)
    if tensor.shape.rank == 0:
        tensor = tf.expand_dims(tensor, axis=0)
    return tensor


def _add_events(events, m, t, x, x_star):
    """Scatter-add `x_star` into `events` at index tuples (`m`, `t`, `x`).

    `x` may be a scalar transition id; it is promoted to a length-1 vector so it
    can be stacked with `m` and `t` along a new last axis.

    NOTE(review): the index order is (m, t, x), i.e. `events` is addressed
    metapopulation-first, whereas `UncalibratedOccultUpdate.one_step` documents
    the tensor as [T, M, X] -- confirm which axis order is correct.

    :returns: a new tensor; `events` itself is not modified.
    """
    coords = [m, t, _maybe_expand_dims(x)]
    scatter_indices = tf.stack(coords, axis=-1)
    return tf.tensor_scatter_nd_add(events, scatter_indices, x_star)


class UncalibratedOccultUpdate(tfp.mcmc.TransitionKernel):
    """UncalibratedOccultUpdate: propose adding or deleting occult events.

    On each `one_step`, with probability 1/2 events of the target transition
    are deleted (drawn from `DelOccultProposal`); otherwise events are added
    (drawn from `AddOccultProposal`).  If the current state has no non-zero
    entries for the target transition, the add move is always chosen.

    The kernel is uncalibrated (`is_calibrated` is False): it returns the
    proposed state and a `log_acceptance_correction` (reverse minus forward
    proposal log-density, including the move-choice probabilities), so it
    must be wrapped in an accept/reject kernel such as
    `tfp.mcmc.MetropolisHastings`.
    """

    def __init__(
        self,
        target_log_prob_fn,
        target_event_id,
        nmax,
        t_range=None,
        seed=None,
        name=None,
    ):
        """Construct the occult add/delete update.

        :param target_log_prob_fn: the log density of the target distribution
        :param target_event_id: the position in the last dimension of the events
                                tensor that we wish to move
        :param nmax: forwarded to `AddOccultProposal` as `n_max` (presumably the
                     largest number of events one add move may create -- TODO
                     confirm against `AddOccultProposal`)
        :param t_range: a tuple containing earliest and latest times between which
                        to update occults.
        :param seed: a random seed; every random draw in `one_step` is seeded
                     from a `SeedStream` built on it
        :param name: the name of the update step
        """
        self._seed_stream = SeedStream(seed, salt="UncalibratedOccultUpdate")
        self._name = name
        self._parameters = dict(
            target_log_prob_fn=target_log_prob_fn,
            target_event_id=target_event_id,
            nmax=nmax,
            t_range=t_range,
            seed=seed,
            name=name,
        )
        # Only `.target` of the topology is read in this file; the other
        # slots are left as None.
        self.tx_topology = TransitionTopology(None, target_event_id, None)

    @property
    def target_log_prob_fn(self):
        return self._parameters["target_log_prob_fn"]

    @property
    def target_event_id(self):
        return self._parameters["target_event_id"]

    @property
    def seed(self):
        return self._parameters["seed"]

    @property
    def name(self):
        return self._parameters["name"]

    @property
    def parameters(self):
        """Return `dict` of ``__init__`` arguments and their values."""
        return self._parameters

    @property
    def is_calibrated(self):
        return False

    def _has_occults(self, events):
        """Scalar bool tensor: True iff `events` has any non-zero entry for the
        target transition -- the condition under which a delete may be chosen."""
        return tf.math.count_nonzero(events[..., self.tx_topology.target]) > 0

    def _log_add_move_prob(self, events, dtype):
        """Log-probability that `one_step` chooses the add move from `events`:
        log(1/2) if deletes are possible, otherwise log(1) = 0."""
        return tf.where(
            self._has_occults(events),
            tf.constant(np.log(0.5), dtype=dtype),
            tf.constant(0.0, dtype=dtype),
        )

    def _log_del_move_prob(self, events, dtype):
        """Log-probability that `one_step` chooses the delete move from `events`:
        log(1/2) if deletes are possible, otherwise log(0) = -inf."""
        return tf.where(
            self._has_occults(events),
            tf.constant(np.log(0.5), dtype=dtype),
            tf.constant(-np.inf, dtype=dtype),
        )

    def one_step(self, current_events, previous_kernel_results):
        """One add-or-delete update of the occult events.

        :param current_events: events tensor whose last axis indexes transitions.
            NOTE(review): previously documented as [T, M, X], but `_add_events`
            indexes it as (m, t, x) -- confirm the axis order.
        :param previous_kernel_results: an `OccultKernelResults`; not read here.
        :returns: a list `[next_state, OccultKernelResults]`, where `extra` holds
                  the concatenated `m`, `t` and `x_star` of the proposed move.
        """
        with tf.name_scope("occult_rw/onestep"):

            def add_occult_fn():
                with tf.name_scope("true_fn"):
                    proposal = AddOccultProposal(
                        events=current_events,
                        n_max=self.parameters["nmax"],
                        t_range=self.parameters["t_range"],
                    )
                    update = proposal.sample(seed=self._seed_stream())
                    next_state = _add_events(
                        events=current_events,
                        m=update["m"],
                        t=update["t"],
                        x=[self.tx_topology.target],
                        x_star=tf.cast(update["x_star"], current_events.dtype),
                    )
                    reverse = DelOccultProposal(next_state, self.tx_topology)
                    q_fwd = tf.reduce_sum(proposal.log_prob(update))
                    q_rev = tf.reduce_sum(reverse.log_prob(update))
                    # Detailed balance also needs the probability of choosing
                    # each move: adding from an occult-free state is forced
                    # (prob 1) while the reverse delete has prob 1/2.
                    log_choice_ratio = self._log_del_move_prob(
                        next_state, q_fwd.dtype
                    ) - self._log_add_move_prob(current_events, q_fwd.dtype)
                    log_acceptance_correction = q_rev - q_fwd + log_choice_ratio
                return update, next_state, log_acceptance_correction

            def del_occult_fn():
                with tf.name_scope("false_fn"):
                    proposal = DelOccultProposal(current_events, self.tx_topology)
                    update = proposal.sample(seed=self._seed_stream())
                    next_state = _add_events(
                        events=current_events,
                        m=update["m"],
                        t=update["t"],
                        x=[self.tx_topology.target],
                        x_star=tf.cast(-update["x_star"], current_events.dtype),
                    )
                    reverse = AddOccultProposal(
                        events=next_state,
                        n_max=self.parameters["nmax"],
                        t_range=self.parameters["t_range"],
                    )
                    q_fwd = tf.reduce_sum(proposal.log_prob(update))
                    q_rev = tf.reduce_sum(reverse.log_prob(update))
                    # Deleting the last occult leaves a state from which the
                    # reverse add is forced (prob 1), versus 1/2 for this delete.
                    log_choice_ratio = self._log_add_move_prob(
                        next_state, q_fwd.dtype
                    ) - self._log_del_move_prob(current_events, q_fwd.dtype)
                    log_acceptance_correction = q_rev - q_fwd + log_choice_ratio
                return update, next_state, log_acceptance_correction

            u = tfd.Uniform().sample(seed=self._seed_stream())
            delta, next_state, log_acceptance_correction = tf.cond(
                (u < 0.5) & self._has_occults(current_events),
                del_occult_fn,
                add_occult_fn,
            )
            # NOTE(review): next_state is not checked for negative entries;
            # presumably DelOccultProposal never removes more events than
            # exist -- confirm.

            next_target_log_prob = self.target_log_prob_fn(next_state)
            return [
                next_state,
                OccultKernelResults(
                    log_acceptance_correction=log_acceptance_correction,
                    target_log_prob=next_target_log_prob,
                    extra=tf.concat([delta["m"], delta["t"], delta["x_star"]], axis=0),
                ),
            ]

    def bootstrap_results(self, init_state):
        """Build the initial `OccultKernelResults` for `init_state`.

        `init_state` is converted to `DTYPE`; the correction starts at 0 and
        `extra` at an int32 `[0, 0, 0]` placeholder for (m, t, x_star).
        """
        with tf.name_scope("uncalibrated_event_times_rw/bootstrap_results"):
            init_state = tf.convert_to_tensor(init_state, dtype=DTYPE)
            init_target_log_prob = self.target_log_prob_fn(init_state)
            return OccultKernelResults(
                log_acceptance_correction=tf.constant(0.0, dtype=DTYPE),
                target_log_prob=init_target_log_prob,
                extra=tf.constant([0, 0, 0], dtype=tf.int32),
            )