model.py 6.67 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
2
"""Functions for infection rates"""
import tensorflow as tf
Chris Jewell's avatar
Chris Jewell committed
3
import tensorflow_probability as tfp
Chris Jewell's avatar
Chris Jewell committed
4
from tensorflow_probability.python.internal import dtype_util
Chris Jewell's avatar
Chris Jewell committed
5
6
from tensorflow_probability.python.internal import reparameterization
from tensorflow_probability.python.internal import prefer_static as ps
Chris Jewell's avatar
Chris Jewell committed
7
import numpy as np
Chris Jewell's avatar
Chris Jewell committed
8

9
from covid import config
Chris Jewell's avatar
Chris Jewell committed
10
from covid.impl.util import make_transition_matrix, batch_gather, transition_coords
11
12
from covid.rdata import load_age_mixing
from covid.pydata import load_commute_volume, load_mobility_matrix, load_population
Chris Jewell's avatar
Chris Jewell committed
13
14
15
16
from covid.impl.discrete_markov import (
    discrete_markov_simulation,
    discrete_markov_log_prob,
)

# Short-hand aliases for the TensorFlow namespaces used in this module.
tla, tfd = tf.linalg, tfp.distributions

# Floating-point dtype for model arrays and tensors, taken from the project config.
DTYPE = config.floatX


def power_iteration(A, tol=1e-3, max_iter=None):
    """Approximate the dominant eigenvector of a square matrix by power iteration.

    Starting from a random Gaussian column vector, repeatedly multiply by `A`
    and re-normalise until the squared Euclidean distance between successive
    normalised iterates is no greater than `tol`.

    :param A: a square 2-D matrix tensor whose number of columns is statically
              known (it is used to size the random starting vector).
    :param tol: convergence threshold on the squared distance between
                successive normalised iterates.
    :param max_iter: optional upper bound on the number of iterations.  The
                     default, `None`, iterates until convergence (the original
                     behaviour).  Set it to guard against non-convergence, e.g.
                     when the dominant eigenvalue is negative (the iterate then
                     flips sign every step) or when the two largest eigenvalues
                     have equal magnitude.
    :returns: a tuple `(b_k, i)` where `b_k` is the unit-norm iterate of shape
              `[A.shape[1], 1]` and `i` is the number of iterations performed.
    """
    b_k = tf.random.normal([A.shape[1], 1], dtype=A.dtype)
    # Start above any sensible tolerance so at least one iteration runs.
    epsilon = tf.constant(1.0, dtype=A.dtype)
    i = 0
    while tf.greater(epsilon, tol):
        if max_iter is not None and i >= max_iter:
            break
        b_k1 = tf.matmul(A, b_k)
        b_k_new = b_k1 / tf.linalg.norm(b_k1)
        epsilon = tf.reduce_sum(tf.pow(b_k_new - b_k, 2))
        b_k = b_k_new
        i += 1
    return b_k, i


def rayleigh_quotient(A, b):
    """Return the Rayleigh quotient ``(b . A b) / (b . b)``.

    :param A: a matrix (or batch of matrices) accepted by `tf.linalg.matvec`.
    :param b: a vector.  Every size-1 dimension is squeezed away first, so a
              column vector of shape `[n, 1]` is accepted.
              NOTE(review): for n == 1 the squeeze yields a scalar and the
              einsum below will fail -- confirm callers never pass that.
    :returns: the quotient, with the last axis reduced away.
    """
    vec = tf.squeeze(b)
    a_vec = tf.linalg.matvec(A, vec)
    # Dot products over the trailing axis; leading axes broadcast.
    num = tf.einsum("...i,...i->...", vec, a_vec)
    den = tf.einsum("...i,...i->...", vec, vec)
    return num / den


def dense_to_block_diagonal(A, n_blocks):
    """Return a lazy block-diagonal operator with `n_blocks` copies of `A`.

    The operator is the Kronecker product kron(I_{n_blocks}, A), so the dense
    block-diagonal matrix is never materialised.

    :param A: the dense matrix to repeat along the diagonal.
    :param n_blocks: number of diagonal blocks.
    :returns: a `tf.linalg.LinearOperatorKronecker`.
    """
    return tf.linalg.LinearOperatorKronecker(
        [
            tf.linalg.LinearOperatorIdentity(n_blocks, dtype=A.dtype),
            tf.linalg.LinearOperatorFullMatrix(A),
        ]
    )


def load_data(paths, settings, dtype=DTYPE):
    """Load the model's input data and convert it to numpy arrays.

    :param paths: mapping with keys "age_mixing_matrix_term",
                  "age_mixing_matrix_hol", "mobility_matrix",
                  "commute_volume" and "population_size" giving the location
                  of each input.
    :param settings: mapping whose "inference_period" entry supplies a start
                     (index 0) and end (index 1), passed to
                     `load_commute_volume`.
    :param dtype: floating-point dtype for every numeric array returned.
                  Defaults to the module-level `DTYPE`.
    :returns: a dict with keys
        "M_tt", "M_hh": age-mixing matrices (presumably term-time and holiday,
                        going by the path keys),
        "C": mobility matrix with its diagonal set to zero,
        "la_names": index labels of the mobility matrix,
        "age_groups": age-group labels returned by `load_age_mixing`,
        "W": the "percent" column of the commute-volume data,
        "pop": population sizes.
    """
    M_tt, age_groups = load_age_mixing(paths["age_mixing_matrix_term"])
    M_hh, _ = load_age_mixing(paths["age_mixing_matrix_hol"])

    C = load_mobility_matrix(paths["mobility_matrix"])
    la_names = C.index.to_numpy()

    w_period = [settings["inference_period"][0], settings["inference_period"][1]]
    W = load_commute_volume(paths["commute_volume"], w_period)["percent"]

    pop = load_population(paths["population_size"])

    # Cast to the caller-requested dtype (the `dtype` argument).
    M_tt = M_tt.astype(dtype)
    M_hh = M_hh.astype(dtype)
    C = C.to_numpy().astype(dtype)
    # In-place; safe because `astype` returned a fresh copy.
    np.fill_diagonal(C, 0.0)
    W = W.to_numpy().astype(dtype)
    pop = pop.to_numpy().astype(dtype)

    return {
        "M_tt": M_tt,
        "M_hh": M_hh,
        "C": C,
        "la_names": la_names,
        "age_groups": age_groups,
        "W": W,
        "pop": pop,
    }


class DiscreteTimeStateTransitionModel(tfd.Distribution):
    """Distribution over event counts of a discrete-time state transition model."""

    def __init__(
        self,
        transition_rates,
        stoichiometry,
        initial_state,
        initial_step,
        time_delta,
        num_steps,
        validate_args=False,
        allow_nan_stats=True,
        name="DiscreteTimeStateTransitionModel",
    ):
        """Implements a discrete-time Markov jump process for a state transition model.

        :param transition_rates: a function of the form `fn(t, state)` taking the current time `t` and state tensor `state`.  This function returns a tensor which broadcasts to the first dimension of `stoichiometry`.
        :param stoichiometry: the stoichiometry matrix for the state transition model, with rows representing transitions and columns representing states.
        :param initial_state: an initial state tensor whose inner dimension equals the number of states, i.e. the second dimension (columns) of `stoichiometry`.
        :param initial_step: an offset giving the time `t` of the first timestep in the model.
        :param time_delta: the size of the time step to be used.
        :param num_steps: the number of time steps across which the model runs.
        :param validate_args: passed through to `tfd.Distribution`.
        :param allow_nan_stats: passed through to `tfd.Distribution`.
        :param name: name scope for this distribution.
        """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            self._transition_rates = transition_rates
            self._stoichiometry = np.array(stoichiometry, dtype=DTYPE)
            self._initial_state = initial_state
            self._initial_step = initial_step
            self._time_delta = time_delta
            self._num_steps = num_steps

            # `super().__init__` also sets the distribution's `dtype`
            # (a read-only property), so it is not assigned separately.
            super().__init__(
                dtype=initial_state.dtype,
                # Samples are discrete event counts, not a differentiable
                # transform of the parameters, so they are not reparameterized.
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name,
            )

    @property
    def transition_rates(self):
        """The `fn(t, state)` transition-rate function."""
        return self._transition_rates

    @property
    def stoichiometry(self):
        """Stoichiometry matrix as a numpy array of dtype `DTYPE`."""
        return self._stoichiometry

    @property
    def initial_state(self):
        """Initial state tensor."""
        return self._initial_state

    @property
    def initial_step(self):
        """Time `t` of the first timestep."""
        return self._initial_step

    @property
    def time_delta(self):
        """Size of each time step."""
        return self._time_delta

    @property
    def num_steps(self):
        """Number of time steps simulated."""
        return self._num_steps

    def _batch_shape(self):
        # Scalar (un-batched) distribution.
        return tf.TensorShape([])

    def _event_shape(self):
        # [initial_state.shape[0], num_steps, number of transitions (rows of
        # the stoichiometry)].  The time dimension is unknown (None) when
        # `num_steps` is not statically known.
        return tf.TensorShape(
            [
                self.initial_state.shape[0],
                tf.get_static_value(self._num_steps),
                self._stoichiometry.shape[0],
            ]
        )

    def _sample_n(self, n, seed=None):
        """Simulate the epidemic forward from `initial_state`.

        :param n: requested number of samples.
                  TODO(review): `n` is currently ignored -- exactly one
                  simulation is run and returned with a leading sample axis of
                  size 1.
        :param seed: random seed forwarded to `discrete_markov_simulation`.
        :returns: a tensor of shape `[1] + event_shape` of simulated events.
        """
        with tf.name_scope("DiscreteTimeStateTransitionModel.sample_n"):
            _, sim = discrete_markov_simulation(
                hazard_fn=self.transition_rates,
                state=self.initial_state,
                start=self.initial_step,
                end=self.initial_step + self.num_steps * self.time_delta,
                time_step=self.time_delta,
                seed=seed,
            )
            # Presumably selects, per stoichiometry row, the matching event
            # counts from the simulation output -- confirm in covid.impl.util.
            indices = transition_coords(self.stoichiometry)
            sim = batch_gather(sim, indices)
            # Swap the two leading axes so the layout matches `_event_shape`
            # (assumed [time, unit, transition] -> [unit, time, transition]).
            sim = tf.transpose(sim, perm=(1, 0, 2))
            return tf.expand_dims(sim, 0)

    def _log_prob(self, y, **kwargs):
        """Log probability of an observed event tensor `y`.

        :param y: event counts; converted to a tensor of the common dtype of
                  `y` and `initial_state`.  Presumably laid out as in
                  `_event_shape` -- TODO confirm against
                  `discrete_markov_log_prob`.
        :returns: the log probability computed by `discrete_markov_log_prob`.
        """
        dtype = dtype_util.common_dtype([y, self.initial_state], dtype_hint=DTYPE)
        y = tf.convert_to_tensor(y, dtype)
        with tf.name_scope("DiscreteTimeStateTransitionModel.log_prob"):
            return discrete_markov_log_prob(
                y,
                self.initial_state,
                self.transition_rates,
                self.time_delta,
                self.stoichiometry,
            )