# prediction.py
1
"""Prediction functions"""
2
import optparse
3
4

import seaborn
5
import yaml
6
7
8
import pickle as pkl
import numpy as np
import pandas as pd
9
import tensorflow as tf
10
from tensorflow_probability import stats as tfs
11
import matplotlib.pyplot as plt
12
import h5py
13
14

from covid.model import CovidUKODE
15
from covid.rdata import load_age_mixing, load_mobility_matrix, load_population
16
17
18
19
20
21
22
23
24
25
26
27
from covid.util import sanitise_settings, sanitise_parameter, seed_areas, doubling_time


def save_sims(sims, la_names, age_groups, filename):
    """Write simulation output and its dimension labels to an HDF5 file.

    Datasets written:

    * ``prediction`` -- ``sims`` as given;
    * ``dimnames`` -- the labels ``sim_id``, ``t``, ``state``, ``la_age``
      (presumably the axis order of ``sims``; confirm against the caller);
    * ``la_names`` / ``age_names`` -- one entry per (LA, age group) pair in
      LA-major order: each LA name is repeated once per age group, and the
      age-group labels are tiled once per LA.

    :param sims: array-like simulation results (anything h5py can store).
    :param la_names: 1-D numpy array of local-authority names.
    :param age_groups: 1-D numpy array of age-group labels.
    :param filename: path of the HDF5 file; truncated if it already exists.
    """
    # np.bytes_ is the fixed-width bytes dtype; the np.string_ alias for it
    # was removed in NumPy 2.0.
    la_long = np.repeat(la_names, age_groups.shape[0]).astype(np.bytes_)
    age_long = np.tile(age_groups, la_names.shape[0]).astype(np.bytes_)
    # The context manager guarantees the file is closed even if a write fails.
    with h5py.File(filename, 'w') as f:
        f.create_dataset('prediction', data=sims)
        f.create_dataset("dimnames", data=[b'sim_id', b't', b'state', b'la_age'])
        f.create_dataset('la_names', data=la_long)
        f.create_dataset('age_names', data=age_long)
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47


if __name__ == '__main__':
    # Command-line driver: build the model from the YAML config, simulate
    # predictive trajectories for thinned (epsilon, beta1) samples loaded from
    # 'pi_beta_2020-03-15.pkl', save them to 'pred_2020-03-15.h5', then plot
    # summaries and print R0 and doubling-time intervals.

    parser = optparse.OptionParser()
    parser.add_option("--config", "-c", dest="config",
                      help="configuration file")
    options, _ = parser.parse_args()
    if options.config is None:
        # Give a usage message instead of a TypeError from open(None).
        parser.error("--config is required")
    with open(options.config, 'r') as ymlfile:
        # yaml.load() without an explicit Loader is a TypeError on PyYAML >= 6
        # and can construct arbitrary Python objects on older versions.
        config = yaml.safe_load(ymlfile)

    K_tt, age_groups = load_age_mixing(config['data']['age_mixing_matrix_term'])
    K_hh, _ = load_age_mixing(config['data']['age_mixing_matrix_hol'])

    T, la_names = load_mobility_matrix(config['data']['mobility_matrix'])
    np.fill_diagonal(T, 0.)  # zero the diagonal of the mobility matrix

    N, n_names = load_population(config['data']['population_size'])

    param = sanitise_parameter(config['parameter'])
    # Baseline value; the per-draw copy inside prediction() overrides it.
    param['epsilon'] = 0.0
    settings = sanitise_settings(config['settings'])

    case_reports = pd.read_csv(config['data']['reported_cases'])
    case_reports['DateVal'] = pd.to_datetime(case_reports['DateVal'])
    date_range = [case_reports['DateVal'].min(), case_reports['DateVal'].max()]
    y = case_reports['CumCases'].to_numpy()
    # Day-on-day increments of the cumulative counts (currently unused).
    # Previously y[1:] - y[-1:], which subtracted the final total from every
    # entry instead of the preceding day's value.
    y_incr = y[1:] - y[:-1]

    # NOTE(review): pickle.load can execute arbitrary code -- only load a
    # trusted file. The path is hard-coded; consider moving it into the config.
    with open('pi_beta_2020-03-15.pkl', 'rb') as f:
        pi_beta = pkl.load(f)

    # Predictive distribution of epidemic spread
    one_day = np.timedelta64(1, 'D')
    end_date = np.datetime64('2020-09-01')
    data_dates = np.arange(date_range[0], date_range[1] + one_day, one_day)
    simulator = CovidUKODE(K_tt, K_hh, T, N, date_range[0] - one_day,
                           end_date, settings['holiday'],
                           settings['bg_max_time'], 1)
    seeding = seed_areas(N, n_names)  # Seed 40-44 age group, 30 seeds by popn size
    state_init = simulator.create_initial_state(init_matrix=seeding)

    @tf.function
    def prediction(epsilon, beta):
        """Run one simulation per (epsilon[i], beta[i]) pair.

        :param epsilon: 1-D array of epsilon draws, same length as ``beta``.
        :param beta: 1-D array of beta1 draws.
        :returns: ``(sims, R0)`` -- the simulator outputs stacked along a new
            leading axis, and the first element of ``simulator.eval_R0`` for
            each draw.
        """
        n_draws = beta.shape[0]
        sims = tf.TensorArray(tf.float32, size=n_draws)
        R0 = tf.TensorArray(tf.float32, size=n_draws)
        for i in tf.range(n_draws):
            # Copy so per-draw (graph-tensor) values are not written back into
            # the shared `param` dict.
            p = dict(param)
            p['epsilon'] = epsilon[i]
            p['beta1'] = beta[i]
            _, sim, _ = simulator.simulate(p, state_init)
            r = simulator.eval_R0(p)
            R0 = R0.write(i, r[0])
            sims = sims.write(i, sim)
        # stack() is equivalent to gather(range(size)) on a fully written array.
        return sims.stack(), R0.stack()

    # Thin each chain: skip the first 500 samples (presumably burn-in) and
    # keep every 10th sample after that.
    draws = [pi_beta[0].numpy()[500::10],
             pi_beta[1].numpy()[500::10]]
    with tf.device('/CPU:0'):
        sims, R0 = prediction(draws[0], draws[1])
    # sims shape=[n_sims, n_times, n_states, n_metapops]

    save_sims(sims, la_names, age_groups, 'pred_2020-03-15.h5')

    dub_time = [doubling_time(simulator.times, sim, '2020-03-01', '2020-04-01')
                for sim in sims.numpy()]

    # Sum over country
    sims = tf.reduce_sum(sims, axis=3)

    print("Plotting...", flush=True)
    dates = np.arange(date_range[0] - one_day, end_date, one_day)
    quantiles = [2.5, 50, 97.5]
    total_infected = tfs.percentile(tf.reduce_sum(sims[:, :, 1:3], axis=2),
                                    q=quantiles, axis=0)
    removed_sims = sims[:, :, 3]
    removed = tfs.percentile(removed_sims, q=quantiles, axis=0)
    # 10% of removals (presumably an assumed detection rate). Percentiles
    # commute with positive scaling, so scale the quantiles directly rather
    # than re-taking percentiles of the already-summarised quantile rows.
    removed_observed = removed * 0.1

    fig = plt.figure()
    filler = plt.fill_between(dates, total_infected[0, :], total_infected[2, :], color='lightgray', label="95% credible interval")
    plt.fill_between(dates, removed[0, :], removed[2, :], color='lightgray')
    plt.fill_between(dates, removed_observed[0, :], removed_observed[2, :], color='lightgray')
    ti_line = plt.plot(dates, total_infected[1, :], '-', color='red', alpha=0.4, label="Infected")
    rem_line = plt.plot(dates, removed[1, :], '-', color='blue', label="Removed")
    ro_line = plt.plot(dates, removed_observed[1, :], '-', color='orange', label='Predicted detections')
    marks = plt.plot(data_dates, y, '+', label='Observed cases')
    plt.legend([ti_line[0], rem_line[0], ro_line[0], filler, marks[0]],
               ["Infected", "Removed", "Predicted detections", "95% credible interval", "Observed counts"])
    plt.grid()
    plt.xlabel("Date")
    plt.ylabel("$10^7$ individuals")
    fig.autofmt_xdate()
    plt.show()

    # Number of new cases per day: percentiles across simulations of each
    # simulation's day-on-day change in removals. (Previously this differenced
    # the quantile rows of `removed`, which mixes quantiles across days.)
    daily_removed = removed_sims[:, 1:] - removed_sims[:, :-1]
    new_cases = tfs.percentile(daily_removed, q=quantiles, axis=0) / 10000.
    fig = plt.figure()
    plt.fill_between(dates[:-1], new_cases[0, :], new_cases[2, :], color='lightgray', label="95% credible interval")
    plt.plot(dates[:-1], new_cases[1, :], '-', alpha=0.2, label='New cases')
    plt.grid()
    plt.xlabel("Date")
    plt.ylabel("Incidence per 10,000")
    fig.autofmt_xdate()
    plt.show()

    # R0
    R0_ci = tfs.percentile(R0, q=quantiles)
    print("R0:", R0_ci)
    fig = plt.figure()
    seaborn.kdeplot(R0.numpy(), ax=fig.gca())
    plt.title("R0")
    plt.show()

    # Doubling time
    dub_ci = tfs.percentile(dub_time, q=quantiles)
    print("Doubling time:", dub_ci)