covid_ode.py 6.8 KB
Newer Older
1
2
import optparse
import time
Chris Jewell's avatar
Chris Jewell committed
3
4

import h5py
5
import matplotlib.pyplot as plt
Chris Jewell's avatar
Chris Jewell committed
6
7
import tensorflow as tf
import yaml
Chris Jewell's avatar
Chris Jewell committed
8

9
10
from covid.model import CovidUKODE
from covid.rdata import *
Chris Jewell's avatar
Chris Jewell committed
11
12


13
14
def sanitise_parameter(par_dict):
    """Sanitises a dictionary of parameters.

    Picks the model parameters ``beta1``, ``beta2``, ``nu`` and ``gamma``
    out of ``par_dict`` and converts each to ``np.float64``.  Any other
    keys in ``par_dict`` are ignored.

    :param par_dict: mapping containing at least the four parameter keys.
    :returns: a new dict with exactly those four keys and float64 values.
    :raises KeyError: if one of the four parameters is absent.
    """
    par = ['beta1', 'beta2', 'nu', 'gamma']
    d = {key: np.float64(par_dict[key]) for key in par}
    return d
Chris Jewell's avatar
Chris Jewell committed
18

Chris Jewell's avatar
Chris Jewell committed
19

20
def sanitise_settings(par_dict):
    """Sanitises a dictionary of simulation settings.

    :param par_dict: mapping with keys ``start``, ``end``, ``time_step``
                     and ``holiday`` (an iterable of dates).
    :returns: a new dict in which ``start`` and ``end`` are
              ``np.datetime64`` values, ``time_step`` is a Python float
              and ``holiday`` is a NumPy array of ``np.datetime64``.
    :raises KeyError: if one of the four settings is absent.
    """
    d = {'start': np.datetime64(par_dict['start']),
         'end': np.datetime64(par_dict['end']),
         'time_step': float(par_dict['time_step']),
         'holiday': np.array([np.datetime64(date) for date in par_dict['holiday']])}
    return d
Chris Jewell's avatar
Chris Jewell committed
26

Chris Jewell's avatar
Chris Jewell committed
27

Chris Jewell's avatar
Chris Jewell committed
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
def seed_areas(N, names, age_group=8, num_la=152, num_age=17, n_seed=30.):
    """Build an initial-infection matrix seeded in four fixed areas.

    ``n_seed`` infections are shared out, proportionally to population
    size and rounded to whole numbers, between the rows whose area name is
    Inner London, Outer London, West Midlands (Met County) or Greater
    Manchester (Met County).  Only column ``age_group`` is seeded.

    :param N: flat population vector of length ``num_la * num_age``.
    :param names: table with an ``'Area.name.2'`` column (``.to_numpy()``
                  is called on it) laid out in the same order as ``N``.
    :param age_group: column index of the age group to seed.
    :param num_la: number of rows (areas) in the reshaped matrices.
    :param num_age: number of columns (age groups) in the reshaped matrices.
    :param n_seed: total number of seeds to distribute (before rounding).
    :returns: array of shape ``[num_la, num_age]`` with the dtype of ``N``,
              zero everywhere except the seeded entries.
    """
    areas = ['Inner London',
             'Outer London',
             'West Midlands (Met County)',
             'Greater Manchester (Met County)']

    names_matrix = names['Area.name.2'].to_numpy().reshape([num_la, num_age])

    # np.isin supersedes the deprecated np.in1d; the mask is named so that
    # it no longer shadows this function.
    is_seed_la = np.isin(names_matrix[:, age_group], areas)
    N_matrix = N.reshape([num_la, num_age])  # LA x Age

    pop_size_sub = N_matrix[is_seed_la, age_group]  # Gather
    n = np.round(n_seed * pop_size_sub / pop_size_sub.sum())

    seeding = np.zeros_like(N_matrix)
    seeding[is_seed_la, age_group] = n  # Scatter
    return seeding


def sum_age_groups(sim, num_la=152, num_age=17):
    """Sum compartment 2 of a simulation over age groups.

    :param sim: array indexed ``[time, compartment, la * age]``; the last
                axis must have length ``num_la * num_age``.
    :param num_la: number of local authorities in the flattened last axis.
    :param num_age: number of age groups in the flattened last axis.
    :returns: array of shape ``[time, num_la]``.
    """
    infec = sim[:, 2, :]
    infec = infec.reshape([infec.shape[0], num_la, num_age])
    return infec.sum(axis=2)


def sum_la(sim, num_la=152, num_age=17):
    """Sum compartment 2 of a simulation over local authorities.

    :param sim: array indexed ``[time, compartment, la * age]``; the last
                axis must have length ``num_la * num_age``.
    :param num_la: number of local authorities in the flattened last axis.
    :param num_age: number of age groups in the flattened last axis.
    :returns: array of shape ``[time, num_age]``.
    """
    infec = sim[:, 2, :]
    infec = infec.reshape([infec.shape[0], num_la, num_age])
    return infec.sum(axis=1)


def sum_total_removals(sim):
    """Return, per time point, the total of compartment 3 over all units.

    ``sim`` is indexed ``[time, compartment, unit]``; the result has one
    entry per time point.
    """
    return sim[:, 3, :].sum(axis=1)


def final_size(sim, num_la=152, num_age=17):
    """Final value of compartment 3 per age group, summed over areas.

    :param sim: array indexed ``[time, compartment, la * age]``; the last
                axis must have length ``num_la * num_age``.
    :param num_la: number of local authorities in the flattened last axis.
    :param num_age: number of age groups in the flattened last axis.
    :returns: array of length ``num_age`` taken from the last time point.
    """
    remove = sim[:, 3, :]
    remove = remove.reshape([remove.shape[0], num_la, num_age])
    # Last time point only, collapsed over the local-authority axis.
    return remove[-1, :, :].sum(axis=0)


def write_hdf5(filename, param, t, sim):
    """Write a simulation run to an HDF5 file, overwriting ``filename``.

    The file receives a ``simulation`` dataset holding ``sim``, a ``time``
    dataset holding ``t`` and a ``parameter`` group with one length-1
    dataset per entry of ``param``.  Every dataset is created with
    dtype ``'f'``.
    """
    with h5py.File(filename, "w") as h5file:
        h5file.create_dataset("simulation", sim.shape, dtype='f')[:] = sim
        h5file.create_dataset("time", t.shape, dtype='f')[:] = t
        parameters = h5file.create_group("parameter")
        for name in param:
            parameters.create_dataset(name, [1], dtype='f')[()] = param[name]



def plot_total_curve(sim, t0=np.datetime64('2020-02-20')):
    """Plot total infected and removed counts on the current pyplot axes.

    :param sim: simulation array as accepted by ``sum_la`` and
                ``sum_total_removals``.
    :param t0: date assigned to the first row of ``sim``; each further row
               is placed one unit of ``t0`` later (one day for the default).
               Previously this date was hard-coded in the body.
    """
    infec_uk = sum_la(sim).sum(axis=1)
    removals = sum_total_removals(sim)
    times = t0 + np.arange(removals.shape[0])
    plt.plot(times, infec_uk, 'r-', label='Infected')
    plt.plot(times, removals, 'b-', label='Removed')
    plt.title('UK total cases')
    plt.xlabel('Date')
    plt.ylabel('Num infected or removed')
    plt.legend()


def plot_by_age(sim, labels, t0=np.datetime64('2020-02-20'), ax=None):
    """Plot one infection curve per age group, plus their mean.

    :param sim: simulation array as accepted by ``sum_la``.
    :param labels: indexable of legend labels, one per age group.
    :param t0: date assigned to the first row of ``sim``.
    :param ax: axes to draw on; a new figure is created when ``None``.
    :returns: the axes that were drawn on.
    """
    if ax is None:
        ax = plt.figure().gca()
    infec_uk = sum_la(sim)
    total_uk = infec_uk.mean(axis=1)
    t = t0 + np.arange(infec_uk.shape[0])
    colours = plt.cm.viridis(np.linspace(0., 1., infec_uk.shape[1]))
    for i in range(infec_uk.shape[1]):
        # Line style only in the format string: the colour comes from the
        # ``color`` keyword, and naming it twice ('r-' plus color=) makes
        # matplotlib warn about a redundant definition.
        ax.plot(t, infec_uk[:, i], '-', alpha=0.4, color=colours[i], label=labels[i])
    ax.plot(t, total_uk, '-', color='black', label='Mean')
    return ax


def plot_by_la(sim, labels, t0=np.datetime64('2020-02-20'), ax=None):
    """Plot one infection curve per local authority, plus their mean.

    :param sim: simulation array as accepted by ``sum_age_groups``.
    :param labels: indexable of legend labels, one per local authority.
    :param t0: date assigned to the first row of ``sim``.
    :param ax: axes to draw on; a new figure is created when ``None``.
    :returns: the axes that were drawn on.
    """
    if ax is None:
        ax = plt.figure().gca()
    infec_uk = sum_age_groups(sim)
    total_uk = infec_uk.mean(axis=1)
    t = t0 + np.arange(infec_uk.shape[0])
    colours = plt.cm.viridis(np.linspace(0., 1., infec_uk.shape[1]))
    for i in range(infec_uk.shape[1]):
        # Line style only in the format string: the colour comes from the
        # ``color`` keyword, and naming it twice ('r-' plus color=) makes
        # matplotlib warn about a redundant definition.
        ax.plot(t, infec_uk[:, i], '-', alpha=0.4, color=colours[i], label=labels[i])
    ax.plot(t, total_uk, '-', color='black', label='Mean')
    return ax


def draw_figs(sim, N):
    """Print attack rates and draw, save and show three figures.

    NOTE: besides its arguments this function reads the module-level names
    ``la_names`` and ``age_groups``; in this file they are assigned in the
    ``__main__`` section, so the function only works after that has run.

    :param sim: simulation array passed on to ``final_size`` and the
                plotting helpers.
    :param N: flat population vector of length 152 * 17.
    """
    # Attack rate
    # Population per age group: reshape to [152, 17] and sum over axis 0.
    N = N.reshape([152, 17]).sum(axis=0)
    fs = final_size(sim)
    attack_rate = fs / N
    print("Attack rate:", attack_rate)
    print("Overall attack rate: ", np.sum(fs)/np.sum(N))

    # Total UK epidemic curve
    plot_total_curve(sim)
    plt.xticks(rotation=45, horizontalalignment="right")
    plt.savefig('total_uk_curve.pdf')
    plt.show()

    # TotalUK epidemic curve by age-group
    # Left panel: one curve per local authority; right panel: one per age group.
    fig, ax = plt.subplots(1, 2, figsize=[24, 12])
    plot_by_la(sim, la_names, ax=ax[0])
    plot_by_age(sim, age_groups, ax=ax[1])
    ax[1].legend()
    plt.xticks(rotation=45, horizontalalignment="right")
    fig.autofmt_xdate()
    plt.savefig('la_age_infec_curves.pdf')
    plt.show()

    # Plot attack rate
    plt.figure(figsize=[4, 2])
    plt.plot(age_groups, attack_rate, 'o-')
    plt.xticks(rotation=90)
    plt.title('Age-specific attack rate')
    plt.savefig('age_attack_rate.pdf')
    plt.show()


158
if __name__ == '__main__':
    # Script entry point: read a YAML configuration, run the term /
    # holiday / term simulation, draw the figures and optionally write the
    # result to HDF5.

    parser = optparse.OptionParser()
    parser.add_option("--config", "-c", dest="config",
                      help="configuration file")
    options, args = parser.parse_args()
    if options.config is None:
        # Without this check open(None) fails with an opaque TypeError.
        parser.error("a configuration file is required (--config FILE)")
    with open(options.config, 'r') as ymlfile:
        # safe_load: plain yaml.load() without a Loader can construct
        # arbitrary Python objects and is rejected by PyYAML >= 6.
        config = yaml.safe_load(ymlfile)

    K_tt, age_groups = load_age_mixing(config['data']['age_mixing_matrix_term'])
    K_hh, _ = load_age_mixing(config['data']['age_mixing_matrix_hol'])

    T, la_names = load_mobility_matrix(config['data']['mobility_matrix'])
    np.fill_diagonal(T, 0.)

    N, n_names = load_population(config['data']['population_size'])

    param = sanitise_parameter(config['parameter'])
    settings = sanitise_settings(config['settings'])

    # Straight, no school closures
    model_term = CovidUKODE(K_tt, T, N)
    model_holiday = CovidUKODE(K_hh/2., T/2., N)

    seeding = seed_areas(N, n_names)  # Seed 40-44 age group, 30 seeds by popn size
    state_init = model_term.create_initial_state(init_matrix=seeding)

    print('R_term=', model_term.eval_R0(param))
    print('R_holiday=', model_holiday.eval_R0(param))

    # School holidays and closures
    @tf.function
    def simulate():
        """Run term -> holiday -> term, each leg starting from the last
        state of the previous one, and concatenate times and states."""
        t0, sim_0 = model_term.simulate(param, state_init,
                                        settings['start'], settings['holiday'][0],
                                        settings['time_step'])
        t1, sim_1 = model_holiday.simulate(param, sim_0[-1, :, :],
                                           settings['holiday'][0], settings['holiday'][1],
                                           settings['time_step'])
        t2, sim_2 = model_term.simulate(param, sim_1[-1, :, :],
                                        settings['holiday'][1], settings['end'],
                                        settings['time_step'])
        t = tf.concat([t0, t1, t2], axis=0)
        sim = tf.concat([tf.expand_dims(state_init, axis=0), sim_0, sim_1, sim_2], axis=0)
        return t, sim

    start = time.perf_counter()
    t, sim = simulate()
    end = time.perf_counter()
    print(f'Complete in {end-start} seconds')

    draw_figs(sim.numpy(), N)

    # Tolerate a configuration with no (or an empty) 'output' section.
    if 'simulation' in (config.get('output') or {}):
        write_hdf5(config['output']['simulation'], param, t, sim)

    print(f"Complete in {end-start} seconds")