model.py 7.32 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
"""Functions for infection rates"""
2
from warnings import warn
Chris Jewell's avatar
Chris Jewell committed
3
import tensorflow as tf
Chris Jewell's avatar
Chris Jewell committed
4
import tensorflow_probability as tfp
5

Chris Jewell's avatar
Chris Jewell committed
6
7
tode = tfp.math.ode
import numpy as np
Chris Jewell's avatar
Chris Jewell committed
8
9
10

from covid.impl.chainbinom_simulate import chain_binomial_simulate

Chris Jewell's avatar
Chris Jewell committed
11
def power_iteration(A, tol=1e-3, max_iter=None):
    """Estimate the dominant eigenvector of ``A`` by power iteration.

    Starts from a random standard-normal column vector and repeatedly
    applies ``A``, renormalising to unit Euclidean norm, until the squared
    L2 change between successive iterates is no greater than ``tol``.

    :param A: float64 matrix (tensor) with shape [n, n]; it must be float64
        because the start vector is float64 and ``tf.matmul`` needs matching
        dtypes
    :param tol: convergence threshold on the squared change between iterates
    :param max_iter: optional cap on the number of iterations. ``None`` (the
        default) iterates until convergence. If the cap is reached, a warning
        is issued and the current estimate is returned.
    :returns: tuple ``(b_k, i)``: the unit-norm eigenvector estimate with
        shape [n, 1] and the number of iterations performed
    """
    b_k = tf.random.normal([A.shape[1], 1], dtype=tf.float64)
    epsilon = tf.constant(1., dtype=tf.float64)
    i = 0
    while tf.greater(epsilon, tol):
        # Without a cap this loop never ends when the iterates oscillate,
        # e.g. if A has two dominant eigenvalues of equal modulus.
        if max_iter is not None and i >= max_iter:
            warn(f"power_iteration did not converge within {max_iter} iterations")
            break
        b_k1 = tf.matmul(A, b_k)
        b_k1_norm = tf.linalg.norm(b_k1)
        b_k_new = b_k1 / b_k1_norm
        epsilon = tf.reduce_sum(tf.pow(b_k_new - b_k, 2))
        b_k = b_k_new
        i += 1
    return b_k, i

def rayleigh_quotient(A, b):
    """Return the Rayleigh quotient ``(b^T A b) / (b^T b)``.

    :param A: matrix with shape [n, n]
    :param b: vector with n entries; it is reshaped to a column [n, 1]
    :returns: a [1, 1] tensor holding the quotient
    """
    col = tf.reshape(b, [b.shape[0], 1])
    row = tf.transpose(col)
    a_times_b = tf.matmul(A, col)
    return tf.matmul(row, a_times_b) / tf.matmul(row, col)
30

Chris Jewell's avatar
Chris Jewell committed
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class CovidUK:
    """SEIR model simulated with ``chain_binomial_simulate``.

    ``stoichiometry`` has one row per transition (S->E, E->I, I->R) and one
    column per compartment (S, E, I, R). Parameters are supplied per call
    to :meth:`sample` and read back by the hazard function :meth:`h`.
    """

    def __init__(self, K, T, W):
        """
        :param K: matrix applied to the infectious vector inside :meth:`h`
            (NOTE(review): presumably an age-mixing matrix -- confirm)
        :param T: matrix applied to ``K @ I`` inside :meth:`h`
            (NOTE(review): presumably a spatial/transport matrix -- confirm)
        :param W: stored as ``self.W``; not used by this class
        """
        self.K = K
        self.T = T
        self.W = W

        self.stoichiometry = [[-1, 1, 0, 0],
                              [0, -1, 1, 0],
                              [0, 0, -1, 1]]

    def h(self, state):
        """Return per-transition hazard rates for ``state``.

        :param state: tensor whose leading axis indexes S, E, I, R
        :returns: stacked rates for S->E, E->I and I->R, each shaped like I
        """
        S, E, I, R = tf.unstack(state, axis=0)

        # TensorFlow has no ``tf.dot``; these are matrix-vector products
        # T @ (K @ I), normalised by the size of K.
        infection = self.param['beta1'] * tf.linalg.matvec(
            self.T, tf.linalg.matvec(self.K, I)) / self.K.shape[0]

        # nu and gamma are scalars: broadcast them to I's shape so that
        # tf.stack receives three equally shaped tensors.
        hazard_rates = tf.stack([
            infection,
            self.param['nu'] * tf.ones_like(I),
            self.param['gamma'] * tf.ones_like(I),
        ])
        return hazard_rates

    def sample(self, initial_state, time_lims, param):
        """Simulate from ``time_lims[0]`` to ``time_lims[1]`` with step 1.

        :param initial_state: starting S, E, I, R state
        :param time_lims: pair ``(start, end)`` of simulation times
        :param param: dict with keys 'beta1', 'nu', 'gamma'; stored on
            ``self.param`` for :meth:`h` to read
        :returns: whatever ``chain_binomial_simulate`` returns
        """
        self.param = param
        return chain_binomial_simulate(self.h, initial_state, time_lims[0],
                                       time_lims[1], 1., self.stoichiometry)


class Homogeneous:
    """Homogeneously mixing SEIR model simulated with ``chain_binomial_simulate``.

    Parameters (``beta``, ``nu``, ``gamma``) are supplied per call to
    :meth:`sample` and read back by the hazard function :meth:`h`.
    """

    def __init__(self):
        # Rows: transitions S->E, E->I, I->R. Columns: compartments S, E, I, R.
        transitions = [
            [-1, 1, 0, 0],
            [0, -1, 1, 0],
            [0, 0, -1, 1],
        ]
        self.stoichiometry = tf.constant(transitions, dtype=tf.float32)

    def h(self, state):
        """Return per-transition hazard rates for ``state``.

        :param state: tensor whose leading axis indexes S, E, I, R
        :returns: stacked rates for S->E, E->I and I->R, each shaped like I
        """
        compartments = tf.unstack(state, axis=0)
        _, _, infectious, _ = compartments
        ones = tf.ones_like(infectious)
        # Infection rate scales with I divided by the sum over every entry
        # of the state.
        infection = self.param['beta'] * infectious / tf.reduce_sum(compartments)
        return tf.stack([infection, self.param['nu'] * ones, self.param['gamma'] * ones])

    @tf.function
    def sample(self, initial_state, time_lims, param):
        """Simulate from ``time_lims[0]`` to ``time_lims[1]`` with step 1.

        :param initial_state: starting S, E, I, R state
        :param time_lims: pair ``(start, end)`` of simulation times
        :param param: dict with keys 'beta', 'nu', 'gamma'; stored on
            ``self.param`` for :meth:`h` to read
        :returns: whatever ``chain_binomial_simulate`` returns
        """
        self.param = param
        t_start, t_end = time_lims[0], time_lims[1]
        return chain_binomial_simulate(self.h, initial_state, t_start, t_end,
                                       1., self.stoichiometry)
Chris Jewell's avatar
Chris Jewell committed
81
82


83
84
def dense_to_block_diagonal(A, n_blocks):
    """Return a block-diagonal linear operator holding ``n_blocks`` copies of ``A``.

    The operator is the Kronecker product ``kron(I_{n_blocks}, A)``, kept as
    a lazy ``tf.linalg.LinearOperatorKronecker`` rather than a dense matrix.

    :param A: dense matrix (tensor or array) forming each diagonal block
    :param n_blocks: number of diagonal copies of ``A``
    :returns: ``tf.linalg.LinearOperatorKronecker`` of ``[identity, A]``
    """
    A_dense = tf.linalg.LinearOperatorFullMatrix(A)
    # The Kronecker operator requires all factors to share one dtype, so
    # take the identity's dtype from A rather than hard-coding float64.
    eye = tf.linalg.LinearOperatorIdentity(n_blocks, dtype=A_dense.dtype)
    A_block = tf.linalg.LinearOperatorKronecker([eye, A_dense])
    return A_block
Chris Jewell's avatar
Chris Jewell committed
88

89
90

class CovidUKODE:
    """Age- and LAD-structured SEIR ODE model.

    The state has shape ``[4, n_lads * n_ages]`` (compartments S, E, I, R).
    Within each compartment the entries are ordered LAD-major with age
    groups innermost, matching the reshape of ``N`` to ``[n_lads, n_ages]``.

    Infection pressure combines three parts:

    * within-LAD age mixing, using the term-time or holiday matrix;
    * between-LAD commuting;
    * a background importation rate ``param['epsilon']`` that is only
      active at time-grid points before ``bg_max_t``.
    """

    def __init__(self, M_tt, M_hh, C, N, start, end, holidays, bg_max_t, t_step):
        """Represents a CovidUK ODE model

        :param M_tt: a MxM matrix of age group mixing in term time
        :param M_hh: a MxM matrix of age group mixing in holiday time
        :param C: a n_ladsxn_lads matrix of inter-LAD commuting; it is
            symmetrised as ``C + C^T`` before use
        :param N: a vector of population sizes of length n_lads * M, ordered
            LAD-major with age groups innermost
        :param start: first date of the time grid (numpy datetime64-compatible)
        :param end: end date of the time grid (exclusive)
        :param holidays: a length-2 sequence ``(first, last)`` of dates; grid
            times t with ``first <= t < last`` use the holiday mixing matrix
        :param bg_max_t: date before which background importation is active
        :param t_step: spacing of the time grid, in days
        """
        self.n_ages = M_tt.shape[0]
        self.n_lads = C.shape[0]

        # Age-mixing matrices repeated block-diagonally, one block per LAD.
        # Index 0 is term time and index 1 is holidays. A plain tuple is used
        # because tf.tuple expects Tensors, not LinearOperators.
        self.Kbar = tf.reduce_mean(tf.cast(M_tt, tf.float64))
        self.M = (dense_to_block_diagonal(tf.cast(M_tt, tf.float64), self.n_lads),
                  dense_to_block_diagonal(tf.cast(M_hh, tf.float64), self.n_lads))

        # Symmetrise the commuting flows, then spread each LAD-to-LAD entry
        # over every age-group pair via a Kronecker product with a ones matrix.
        C = tf.cast(C, tf.float64)
        self.C = tf.linalg.LinearOperatorFullMatrix(C + tf.transpose(C))
        shp = tf.linalg.LinearOperatorFullMatrix(np.ones_like(M_tt, dtype=np.float64))
        self.C = tf.linalg.LinearOperatorKronecker([self.C, shp])

        self.N = tf.constant(N, dtype=tf.float64)

        # Total population of each LAD, repeated for every age group so that
        # it lines up element-wise with the flattened state vector.
        N_matrix = tf.reshape(self.N, [self.n_lads, self.n_ages])
        N_sum = tf.reduce_sum(N_matrix, axis=1)
        N_sum = N_sum[:, None] * tf.ones([1, self.n_ages], dtype=tf.float64)
        self.N_sum = tf.reshape(N_sum, [-1])

        self.times = np.arange(start, end, np.timedelta64(t_step, 'D'))
        # 1 where a grid time falls inside the holiday window, else 0.
        m_select = (np.less_equal(holidays[0], self.times) & np.less(self.times, holidays[1])).astype(np.int64)
        self.m_select = tf.constant(m_select, dtype=tf.int64)
        # 1 while background importation is active (grid times before bg_max_t).
        bg_select = np.less(self.times, bg_max_t).astype(np.int64)
        self.bg_select = tf.constant(bg_select, dtype=tf.int64)
        self.solver = tode.DormandPrince()

    def make_h(self, param):
        """Build the ODE right-hand side ``h_fn(t, state)`` for ``param``.

        :param param: dict with keys 'beta1', 'beta2', 'nu', 'gamma', 'epsilon'
        :returns: function returning d(state)/dt, stacked as [dS, dE, dI, dR]
        """

        def h_fn(t, state):
            S, E, I, R = tf.unstack(state, axis=0)

            # Truncate solver time to an index into the per-grid-point switch
            # arrays. It is clamped so that times outside the grid reuse the
            # first or last entry.
            t = tf.clip_by_value(tf.cast(t, tf.int64), 0, self.m_select.shape[0] - 1)
            m_switch = tf.gather(self.m_select, t)
            epsilon = param['epsilon'] * tf.cast(tf.gather(self.bg_select, t), tf.float64)

            # Within-LAD age mixing: term-time matrix unless in holidays.
            if m_switch == 0:
                infec_rate = param['beta1'] * tf.linalg.matvec(self.M[0], I)
            else:
                infec_rate = param['beta1'] * tf.linalg.matvec(self.M[1], I)
            # Between-LAD commuting, with I expressed as a fraction of LAD population.
            infec_rate += param['beta1'] * param['beta2'] * self.Kbar * tf.linalg.matvec(self.C, I / self.N_sum)
            infec_rate = S / self.N * (infec_rate + epsilon)

            dS = -infec_rate
            dE = infec_rate - param['nu'] * E
            dI = param['nu'] * E - param['gamma'] * I
            dR = param['gamma'] * I

            df = tf.stack([dS, dE, dI, dR])
            return df

        return h_fn

    def create_initial_state(self, init_matrix=None):
        """Return a ``[4, n_lads * n_ages]`` array of initial S, E, I, R.

        :param init_matrix: optional ``[n_lads, n_ages]`` array of initial
            infectives. If ``None``, 30 infectives are seeded in LAD index 149,
            age-group index 10 (the original comment describes this as
            "Middle-aged in Surrey").
        :returns: stacked ``[S, E, I, R]`` with ``S = N - I`` and E, R zero
        """
        if init_matrix is None:
            I = np.zeros(self.N.shape, dtype=np.float64)
            # Flattened index of (LAD 149, age 10) in the LAD-major layout.
            I[149 * self.n_ages + 10] = 30.
        else:
            np.testing.assert_array_equal(init_matrix.shape, [self.n_lads, self.n_ages],
                                          err_msg=f"init_matrix does not have shape [<num lads>,<num ages>] \
                                          ({self.n_lads},{self.n_ages})")
            I = init_matrix.flatten()
        S = self.N - I
        E = np.zeros(self.N.shape, dtype=np.float64)
        R = np.zeros(self.N.shape, dtype=np.float64)
        return np.stack([S, E, I, R])

    @tf.function
    def simulate(self, param, state_init, solver_state=None):
        """Solve the ODE at integer times ``0 .. len(self.times) - 1``.

        :param param: parameter dict (see :meth:`make_h`)
        :param state_init: initial state (see :meth:`create_initial_state`)
        :param solver_state: optional value forwarded to the solver as
            ``previous_solver_internal_state``
        :returns: tuple ``(times, states, solver_internal_state)`` from the
            DormandPrince solution
        """
        h = self.make_h(param)
        t = np.arange(self.times.shape[0])
        results = self.solver.solve(ode_fn=h, initial_time=t[0], initial_state=state_init,
                                    solution_times=t, previous_solver_internal_state=solver_state)
        return results.times, results.states, results.solver_internal_state

    def ngm(self, param):
        """Return the dense next-generation matrix using term-time mixing.

        This matches the infection term in :meth:`make_h`, with ``S / N``
        taken as 1 and without the background term, divided by ``gamma``.
        """
        infec_rate = param['beta1'] * self.M[0].to_dense()
        infec_rate += param['beta1'] * param['beta2'] * self.Kbar * self.C.to_dense() / self.N_sum[None, :]
        ngm = infec_rate / param['gamma']
        return ngm

    def eval_R0(self, param, tol=1e-8):
        """Estimate R0 as the dominant eigenvalue of :meth:`ngm`.

        :param param: parameter dict (see :meth:`make_h`)
        :param tol: convergence tolerance passed to ``power_iteration``
        :returns: tuple ``(R0, n_iterations)``
        """
        ngm = self.ngm(param)
        # Dominant eigenvector by power iteration, then its eigenvalue via
        # the Rayleigh quotient.
        dom_eigen_vec, i = power_iteration(ngm, tol=tf.cast(tol, tf.float64))
        R0 = rayleigh_quotient(ngm, dom_eigen_vec)
        return tf.squeeze(R0), i
184
185
186
187
188
189
190
191


def covid19uk_logp(y, sim, phi):
    """Poisson log-likelihood of observed counts given a simulated trajectory.

    :param y: observed counts, one per time increment of ``sim``
    :param sim: simulated states indexed ``[time, compartment, unit]``.
        Compartment 3 is R (removed), since states are stacked S, E, I, R.
    :param phi: multiplier on simulated removals (NOTE(review): presumably a
        reporting/ascertainment rate -- confirm)
    :returns: elementwise Poisson log-probabilities of ``y``
    """
    removed = sim[:, 3, :]
    # New removals per time step, summed over all units.
    new_removals = tf.reduce_sum(removed[1:] - removed[:-1], axis=1)
    likelihood = tfp.distributions.Poisson(rate=phi * new_removals)
    return likelihood.log_prob(y)