"""Calculates within- and between-location contributions to infection pressure"""

3
import pickle as pkl
4
import numpy as np
5
import pandas as pd
6
import xarray
7
8
import tensorflow as tf

Chris Jewell's avatar
Chris Jewell committed
9
from gemlib.util import compute_state
10
from covid import model_spec
11
12


13
def make_within_rate_fns(covariates, beta2):
    """Builds closures evaluating the within- and between-location rates.

    :param covariates: mapping providing ``"C"`` (a matrix whose diagonal
                       is zeroed before use), ``"W"`` and ``"N"`` (both
                       squeezed before use)
    :param beta2: multiplier applied to the terms involving ``C``
    :returns: a tuple ``(within_fn, between_fn)``; each closure takes
              ``(t, state)`` and returns a rate tensor
    """
    dtype = model_spec.DTYPE

    # Zero the diagonal so only off-diagonal entries of C contribute.
    conn = tf.convert_to_tensor(covariates["C"], dtype=dtype)
    conn = tf.linalg.set_diag(conn, tf.zeros(conn.shape[0], dtype=dtype))

    weights = tf.convert_to_tensor(tf.squeeze(covariates["W"]), dtype=dtype)
    popsize = tf.convert_to_tensor(tf.squeeze(covariates["N"]), dtype=dtype)

    def _weight_at(t):
        # Cast t to an integer index and clip it into the valid range of W.
        idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, weights.shape[0] - 1)
        return tf.gather(weights, idx)

    def within_fn(t, state):
        # NOTE(review): index 2 is presumably the infectious compartment --
        # confirm against model_spec.
        comp2 = state[..., 2]
        scaled = beta2 * comp2 / popsize * _weight_at(t)
        return comp2 - scaled * tf.reduce_sum(conn, axis=-2)

    def between_fn(t, state):
        symmetric = conn + tf.transpose(conn)
        coupled = tf.linalg.matvec(symmetric, state[..., 2] / popsize)
        return beta2 * _weight_at(t) * coupled

    return within_fn, between_fn


Chris Jewell's avatar
Chris Jewell committed
46
# @tf.function
47
def calc_pressure_components(covariates, beta2, state):
    """Splits the summed rate into within and between fractions.

    For each paired element of ``beta2`` and ``state`` (paired along their
    leading axis by ``tf.vectorized_map``), evaluates the within and
    between rate functions and divides each by their sum.

    :param covariates: mapping passed through to ``make_within_rate_fns``;
                       ``covariates["W"].shape[0]`` is used as the time
                       argument
    :param beta2: batch of ``beta2`` values, one per element of ``state``
    :param state: batch of states, one per element of ``beta2``
    :returns: a tuple ``(within_fraction, between_fraction)``
    """
    # The rate closures clip the time index into W's valid range, so
    # passing the length of W is safe.
    t_eval = covariates["W"].shape[0]

    def _fractions(sample):
        beta2_i, state_i = sample
        within_fn, between_fn = make_within_rate_fns(covariates, beta2_i)
        within_rate = within_fn(t_eval, state_i)
        between_rate = between_fn(t_eval, state_i)
        denom = within_rate + between_rate
        return within_rate / denom, between_rate / denom

    return tf.vectorized_map(_fractions, elems=(beta2, state))
57
58


59
60
def within_between(input_files, output_file):
    """Calculates PAF for within- and between-location infection.

    :param input_files: a list of [data file, posterior samples pickle].
                        The data file is opened with
                        ``xarray.open_dataset`` (group ``constant_data``);
                        the samples file is loaded with ``pickle`` and must
                        provide the keys ``psi``, ``seir`` and
                        ``initial_state``.
    :param output_file: a csv with within/between summary, one row per
                        location with columns ``within_mean``,
                        ``between_mean`` and ``p_within_gt_between``
    """

    covar_data = xarray.open_dataset(input_files[0], group="constant_data")

    # NOTE: unpickling can execute arbitrary code -- only load posterior
    # sample files from a trusted source.
    with open(input_files[1], "rb") as f:
        samples = pkl.load(f)

    psi = samples["psi"]
    events = samples["seir"]
    init_state = samples["initial_state"]
    state_timeseries = compute_state(
        init_state, events, model_spec.STOICHIOMETRY
    )

    # Index -1 on the second-to-last axis: presumably the final timepoint
    # of the state timeseries -- confirm against gemlib's compute_state.
    within, between = calc_pressure_components(
        covar_data, psi, state_timeseries[..., -1, :]
    )

    df = pd.DataFrame(
        dict(
            within_mean=np.mean(within, axis=0),
            between_mean=np.mean(between, axis=0),
            # Reduce over axis 0 like the means above, so each location
            # gets its own value; without an axis the mean collapses to a
            # single scalar that is broadcast to every row.
            p_within_gt_between=np.mean(within > between, axis=0),
        ),
        index=pd.Index(
            covar_data["locations"].coords["location"], name="location"
        ),
    )
    df.to_csv(output_file)


def parse_args(argv=None):
    """Parses the command-line arguments of this script.

    :param argv: list of argument strings; ``None`` makes argparse read
                 ``sys.argv[1:]``
    :returns: a namespace with attributes ``datafile``, ``samples`` and
              ``output``
    """
    from argparse import ArgumentParser

    parser = ArgumentParser()
    # The keyword was previously misspelled ``requied``, which argparse
    # rejects with a TypeError before any argument is parsed.
    parser.add_argument(
        "-d", "--datafile", type=str, help="Data pickle file", required=True
    )
    parser.add_argument(
        "-s",
        "--samples",
        type=str,
        help="Posterior samples pickle",
        required=True,
    )
    parser.add_argument("-o", "--output", type=str, help="Output csv")
    return parser.parse_args(argv)


if __name__ == "__main__":

    args = parse_args()

    within_between([args.datafile, args.samples], args.output)