model.py 5.61 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
2
"""Functions for infection rates"""
import tensorflow as tf
Chris Jewell's avatar
Chris Jewell committed
3
4
import tensorflow_probability as tfp
import numpy as np
Chris Jewell's avatar
Chris Jewell committed
5
6
7

from covid.impl.chainbinom_simulate import chain_binomial_simulate

Chris Jewell's avatar
Chris Jewell committed
8
9
10
11
tode = tfp.math.ode
tla = tf.linalg


Chris Jewell's avatar
Chris Jewell committed
12
def power_iteration(A, tol=1e-3, max_iter=None):
    """Estimate the dominant eigenvector of a matrix by power iteration.

    Starting from a random Gaussian vector, repeatedly applies ``A`` and
    re-normalises until the squared Euclidean distance between successive
    unit-norm iterates is no greater than ``tol``.

    :param A: a square matrix (tensor or array-like); its dtype is used for
        all internal computations.
    :param tol: convergence threshold on the squared change between
        consecutive normalised iterates.
    :param max_iter: optional cap on the number of iterations.  ``None``
        (the default) iterates until convergence.  A cap guards against
        non-termination when two eigenvalues share the largest modulus
        (e.g. a negative dominant eigenvalue makes the iterate flip sign).
    :returns: tuple ``(b_k, i)`` where ``b_k`` is an (n x 1) unit vector
        approximating the dominant eigenvector and ``i`` is the number of
        iterations performed.
    """
    A = tf.convert_to_tensor(A)
    # Work in A's dtype so non-float64 matrices do not hit a matmul dtype clash.
    tol = tf.cast(tol, A.dtype)
    b_k = tf.random.normal([A.shape[1], 1], dtype=A.dtype)
    # Initial value only serves to enter the loop (when tol < 1).
    epsilon = tf.constant(1., dtype=A.dtype)
    i = 0
    while tf.greater(epsilon, tol):
        if max_iter is not None and i >= max_iter:
            break
        b_k1 = tf.matmul(A, b_k)
        b_k_new = b_k1 / tf.linalg.norm(b_k1)
        epsilon = tf.reduce_sum(tf.pow(b_k_new - b_k, 2))
        b_k = b_k_new
        i += 1
    return b_k, i

Chris Jewell's avatar
Chris Jewell committed
25

Chris Jewell's avatar
Chris Jewell committed
26
27
28
29
30
def rayleigh_quotient(A, b):
    """Return the Rayleigh quotient ``b^T A b / b^T b``.

    :param A: an (n x n) matrix
    :param b: a length-n vector (or an (n x 1) column); it is reshaped to a
        column before use
    :returns: a 1 x 1 tensor holding the quotient
    """
    col = tf.reshape(b, [b.shape[0], 1])
    col_t = tf.transpose(col)
    return tf.matmul(col_t, tf.matmul(A, col)) / tf.matmul(col_t, col)
31

Chris Jewell's avatar
Chris Jewell committed
32

33
34
def dense_to_block_diagonal(A, n_blocks):
    """Represent ``n_blocks`` copies of ``A`` arranged along a block diagonal.

    The result is the lazy Kronecker product ``I_{n_blocks} (x) A`` as a
    ``tf.linalg.LinearOperatorKronecker``; call ``.to_dense()`` on it to
    materialise the full matrix.

    :param A: a matrix (tensor or array-like)
    :param n_blocks: number of diagonal copies of ``A``
    :returns: a ``LinearOperatorKronecker`` whose dimensions are ``n_blocks``
        times those of ``A``
    """
    A_dense = tf.linalg.LinearOperatorFullMatrix(A)
    # Kronecker operands must share a dtype; take it from A rather than
    # hard-coding float64 so float32 (etc.) inputs also work.
    eye = tf.linalg.LinearOperatorIdentity(n_blocks, dtype=A_dense.dtype)
    return tf.linalg.LinearOperatorKronecker([eye, A_dense])
Chris Jewell's avatar
Chris Jewell committed
38

39
40

class CovidUKODE:  # TODO: add background case importation rate to the UK, e.g. \epsilon term.
    """Age- and LAD-structured SEIR ODE model.

    The state is a stack of four compartments ``[S, E, I, R]``.  Each is a
    vector of length ``n_lads * n_ages``, ordered LAD-major: the ``n_ages``
    age groups of one LAD are contiguous.  Transmission combines two terms:

    * within-LAD age mixing, using ``M_tt`` in term time and ``M_hh`` during
      the school holiday;
    * between-LAD mixing driven by the commuting matrix ``C``.
    """
    # NOTE(review): make_h already adds a background importation term
    # (param['epsilon'] while t < bg_max_t); the TODO above may be stale.

    def __init__(self, M_tt, M_hh, C, N, start, end, holidays, bg_max_t, t_step):
        """Represents a CovidUK ODE model

        :param M_tt: a MxM matrix of age group mixing in term time
        :param M_hh: a MxM matrix of age group mixing in holiday time
        :param C: a n_lads x n_lads matrix of inter-LAD commuting; it is
            symmetrised internally as ``C + C^T``
        :param N: a vector of population sizes of length ``n_lads * M``,
            ordered LAD-major (each LAD's M age groups are contiguous)
        :param start: simulation start date (a ``numpy.datetime64``)
        :param end: simulation end date (exclusive)
        :param holidays: a pair ``(first_day, end_day)`` of dates.  ``M_hh``
            is used for ODE times ``t`` (days since ``start``) with
            ``first_day <= t < end_day``.
        :param bg_max_t: ODE time (days since ``start``) before which the
            background importation rate ``param['epsilon']`` is applied
        :param t_step: spacing of the output times, in days
        """
        self.n_ages = M_tt.shape[0]
        self.n_lads = C.shape[0]

        self.M_tt = tf.convert_to_tensor(M_tt, dtype=tf.float64)
        self.M_hh = tf.convert_to_tensor(M_hh, dtype=tf.float64)

        # Mean term-time mixing intensity; scales the between-LAD term.
        self.Kbar = tf.reduce_mean(self.M_tt)

        C = tf.cast(C, tf.float64)
        # Symmetrise commuting flows, then expand each LAD-to-LAD entry into an
        # n_ages x n_ages block of ones. Every age group in one LAD then mixes
        # with every age group in the other.
        self.C = tla.LinearOperatorFullMatrix(C + tf.transpose(C))
        shp = tla.LinearOperatorFullMatrix(np.ones_like(M_tt, dtype=np.float64))
        self.C = tla.LinearOperatorKronecker([self.C, shp])

        self.N = tf.constant(N, dtype=tf.float64)
        # Total population of each LAD, repeated for each of its age groups.
        N_matrix = tf.reshape(self.N, [self.n_lads, self.n_ages])
        N_sum = tf.reduce_sum(N_matrix, axis=1)
        N_sum = N_sum[:, None] * tf.ones([1, self.n_ages], dtype=tf.float64)
        self.N_sum = tf.reshape(N_sum, [-1])

        self.times = np.arange(start, end, np.timedelta64(t_step, 'D'))

        # Holiday window expressed in days since `start` (the ODE time unit).
        self.school_hols = [tf.constant((holidays[0] - start) // np.timedelta64(1, 'D'), dtype=tf.float64),
                            tf.constant((holidays[1] - start) // np.timedelta64(1, 'D'), dtype=tf.float64)]

        self.bg_max_t = tf.convert_to_tensor(bg_max_t, dtype=tf.float64)

        self.solver = tode.DormandPrince()

    def make_h(self, param):
        """Build the ODE right-hand side for the parameter mapping ``param``.

        :param param: mapping with the following keys:

            * ``'beta1'`` -- within-LAD transmission rate;
            * ``'beta2'`` -- multiplier on ``beta1`` for the between-LAD term;
            * ``'epsilon'`` -- background importation rate, applied only
              while ``t < bg_max_t``;
            * ``'nu'`` -- E -> I rate;
            * ``'gamma'`` -- I -> R rate.
        :returns: a function ``h_fn(t, state)`` returning the time derivative
            stacked as ``[dS, dE, dI, dR]``
        """

        def h_fn(t, state):
            state = tf.unstack(state, axis=0)
            S, E, I, R = state

            # Holiday mixing inside [school_hols[0], school_hols[1]), term time otherwise.
            M = tf.where(tf.less_equal(self.school_hols[0], t) & tf.less(t, self.school_hols[1]),
                         self.M_hh, self.M_tt)

            M = dense_to_block_diagonal(M, self.n_lads)

            epsilon = tf.where(t < self.bg_max_t, param['epsilon'], tf.constant(0., dtype=tf.float64))

            infec_rate = param['beta1'] * tla.matvec(M, I)
            infec_rate += param['beta1'] * param['beta2'] * self.Kbar * tla.matvec(self.C, I / self.N_sum)
            infec_rate = S / self.N * (infec_rate + epsilon)

            dS = -infec_rate
            dE = infec_rate - param['nu'] * E
            dI = param['nu'] * E - param['gamma'] * I
            dR = param['gamma'] * I

            df = tf.stack([dS, dE, dI, dR])
            return df

        return h_fn

    def create_initial_state(self, init_matrix=None):
        """Build the initial ``[S, E, I, R]`` state.

        :param init_matrix: optional ``[n_lads, n_ages]`` array of initial
            infectives.  If omitted, 30 infectives are seeded in LAD index 149,
            age group 10 ("Middle-aged in Surrey").
        :returns: a numpy array of shape ``[4, n_lads * n_ages]`` with
            ``S = N - I`` and ``E = R = 0``
        :raises AssertionError: if ``init_matrix`` has the wrong shape
        """
        if init_matrix is None:
            I = np.zeros(self.N.shape, dtype=np.float64)
            # States are LAD-major (see the reshape of N in __init__), so the
            # (lad, age) stratum lives at index lad * n_ages + age.
            I[149 * self.n_ages + 10] = 30.  # Middle-aged in Surrey
        else:
            # Cast to float64 so integer input can be subtracted from the float64 N.
            init_matrix = np.asarray(init_matrix, dtype=np.float64)
            np.testing.assert_array_equal(
                init_matrix.shape, [self.n_lads, self.n_ages],
                err_msg=(f"init_matrix does not have shape [<num lads>,<num ages>] "
                         f"({self.n_lads},{self.n_ages})"))
            I = init_matrix.flatten()
        S = self.N - I
        E = np.zeros(self.N.shape, dtype=np.float64)
        R = np.zeros(self.N.shape, dtype=np.float64)
        return np.stack([S, E, I, R])

    def simulate(self, param, state_init, solver_state=None):
        """Integrate the ODE over ``self.times``.

        :param param: parameter mapping (see ``make_h``)
        :param state_init: initial state, e.g. from ``create_initial_state``
        :param solver_state: optional internal state of a previous solve, to resume from
        :returns: ``(times, states, solver_internal_state)`` from the solver.
            ``states`` is presumably stacked with time as the leading axis;
            confirm against the TFP solver documentation.
        """
        h = self.make_h(param)
        # ODE time is days since `start`, the unit used by `school_hols` and
        # `bg_max_t`. Output times follow `self.times`, i.e. they are `t_step`
        # days apart (0, 1, 2, ... when t_step == 1).
        t = (self.times - self.times[0]) // np.timedelta64(1, 'D')
        results = self.solver.solve(ode_fn=h, initial_time=t[0], initial_state=state_init,
                                    solution_times=t, previous_solver_internal_state=solver_state)
        return results.times, results.states, results.solver_internal_state

    def ngm(self, param):
        """Next-generation matrix under term-time mixing.

        Uses the same two transmission terms as ``make_h``, with the ``S / N``
        factor taken as 1 (a fully susceptible population) and no ``epsilon``
        term. The result is divided by ``param['gamma']``.

        :returns: a dense ``[n_lads * n_ages, n_lads * n_ages]`` tensor
        """
        infec_rate = param['beta1'] * dense_to_block_diagonal(self.M_tt, self.n_lads).to_dense()
        infec_rate += param['beta1'] * param['beta2'] * self.Kbar * self.C.to_dense() / self.N_sum[None, :]
        ngm = infec_rate / param['gamma']
        return ngm

    def eval_R0(self, param, tol=1e-8):
        """Estimate R0 as the dominant eigenvalue of the next-generation matrix.

        :param param: parameter mapping (see ``make_h``)
        :param tol: convergence tolerance passed to ``power_iteration``
        :returns: ``(R0, i)``, the scalar eigenvalue estimate and the number
            of power iterations used
        """
        ngm = self.ngm(param)
        # Dominant eigenvector by power iteration; the Rayleigh quotient then
        # gives the corresponding eigenvalue.
        dom_eigen_vec, i = power_iteration(ngm, tol=tf.cast(tol, tf.float64))
        R0 = rayleigh_quotient(ngm, dom_eigen_vec)
        return tf.squeeze(R0), i
136
137
138
139
140
141
142
143


def covid19uk_logp(y, sim, phi):
    """Poisson log-likelihood of observed counts given a simulated epidemic.

    :param y: observed counts, one per step after the first simulated time
    :param sim: simulated states indexed ``[time, compartment, stratum]``.
        Compartment 3 is the removed (R) class.
    :param phi: multiplier applied to the simulated increments in R,
        presumably an ascertainment fraction (confirm with callers)
    :returns: the elementwise Poisson log-probability of ``y``
    """
    removed = sim[:, 3, :]
    # Step-to-step increase in removed individuals, summed over all strata.
    new_removals = tf.reduce_sum(removed[1:] - removed[:-1], axis=1)
    return tfp.distributions.Poisson(rate=phi * new_removals).log_prob(y)