ruffus_pipeline.py 5.1 KB
Newer Older
1
"""Represents the analytic pipeline as a ruffus chain"""
Chris Jewell's avatar
Chris Jewell committed
2
3
4

import os
import yaml
Chris Jewell's avatar
Chris Jewell committed
5
import pandas as pd
Chris Jewell's avatar
Chris Jewell committed
6
7
import ruffus as rf

8

Chris Jewell's avatar
Chris Jewell committed
9
from covid.tasks import (
10
11
    assemble_data,
    mcmc,
Chris Jewell's avatar
Chris Jewell committed
12
13
14
15
    thin_posterior,
    next_generation_matrix,
    overall_rt,
    predict,
16
17
18
19
    summarize,
    within_between,
    case_exceedance,
    summary_geopackage,
20
    insample_predictive_timeseries,
Chris Jewell's avatar
Chris Jewell committed
21
22
)

23
# Only the pipeline entry point is part of this module's public API.
__all__ = ["run_pipeline"]
Chris Jewell's avatar
Chris Jewell committed
24
25


26
27
def _make_append_work_dir(work_dir):
    return lambda filename: os.path.expandvars(os.path.join(work_dir, filename))
Chris Jewell's avatar
Chris Jewell committed
28
29


30
def run_pipeline(global_config, results_directory, cli_options):
    """Build the ruffus task graph for the analysis and run it.

    Every output is written beneath ``results_directory``. Environment
    variables in that path are expanded (``os.path.expandvars``) both for the
    directory that is created and for each output file path.

    :param global_config: configuration mapping. The whole mapping is saved
        to ``config.yaml``; its ``"ProcessData"``, ``"Mcmc"``,
        ``"ThinPosterior"`` and ``"Geopackage"`` entries are handed to the
        corresponding tasks.
    :param results_directory: directory in which all outputs are placed.
    :param cli_options: forwarded unchanged to ``ruffus.cmdline.run``
        (presumably the parsed ruffus command-line options -- confirm
        against the caller).
    """
    wd = _make_append_work_dir(results_directory)
    # wd() runs expandvars on every output path, so the directory created for
    # them must be expanded the same way. Otherwise a path such as
    # "$SCRATCH/run" would create a literal "$SCRATCH" folder while outputs
    # are written under the expanded location.
    results_dir = os.path.expandvars(results_directory)

    # Pipeline starts here
    @rf.mkdir(results_dir)
    @rf.originate(wd("config.yaml"), global_config)
    def save_config(output_file, config):
        """Write the configuration used for this run to ``output_file`` as YAML."""
        with open(output_file, "w") as f:
            yaml.dump(config, f)

    @rf.transform(
        input=save_config,
        filter=rf.formatter(),
        output=wd("pipeline_data.pkl"),
        extras=[global_config],
    )
    def process_data(input_file, output_file, config):
        """Assemble the input data set from the ``ProcessData`` config section.

        ``input_file`` (the saved config) is not read here; it is declared as
        the input so this task runs after ``save_config``.
        """
        assemble_data(output_file, config["ProcessData"])

    @rf.transform(
        input=process_data,
        filter=rf.formatter(),
        output=wd("posterior.hd5"),
        extras=[global_config],
    )
    def run_mcmc(input_file, output_file, config):
        """Run MCMC on the assembled data using the ``Mcmc`` config section."""
        mcmc(input_file, output_file, config["Mcmc"])

    @rf.transform(
        input=run_mcmc,
        filter=rf.formatter(),
        output=wd("thin_samples.pkl"),
        extras=[global_config],
    )
    def thin_samples(input_file, output_file, config):
        """Thin the posterior using the ``ThinPosterior`` config section."""
        thin_posterior(input_file, output_file, config["ThinPosterior"])

    # Rt related steps.
    # A doubly-nested input list hands all listed upstream outputs to a single
    # job together (cf. input_files[0] / input_files[1] in the tasks below).
    rf.transform(
        input=[[process_data, thin_samples]],
        filter=rf.formatter(),
        output=wd("ngm.pkl"),
    )(next_generation_matrix)

    rf.transform(
        input=next_generation_matrix,
        filter=rf.formatter(),
        output=wd("national_rt.xlsx"),
    )(overall_rt)

    # In-sample prediction
    @rf.transform(
        input=[[process_data, thin_samples]],
        filter=rf.formatter(),
        output=wd("insample7.pkl"),
    )
    def insample7(input_files, output_file):
        """In-sample prediction used for the 7-day exceedance summary."""
        # NOTE(review): this 7-day task uses initial_step=-8 whereas the
        # 14-day task uses -14; confirm the extra step is intentional.
        return predict(
            data=input_files[0],
            posterior_samples=input_files[1],
            output_file=output_file,
            initial_step=-8,
            num_steps=28,
        )

    @rf.transform(
        input=[[process_data, thin_samples]],
        filter=rf.formatter(),
        output=wd("insample14.pkl"),
    )
    def insample14(input_files, output_file):
        """In-sample prediction used for the 14-day exceedance summary."""
        return predict(
            data=input_files[0],
            posterior_samples=input_files[1],
            output_file=output_file,
            initial_step=-14,
            num_steps=28,
        )

    # Medium-term prediction
    @rf.transform(
        input=[[process_data, thin_samples]],
        filter=rf.formatter(),
        output=wd("medium_term.pkl"),
    )
    def medium_term(input_files, output_file):
        """Medium-term prediction: 61 steps starting from ``initial_step=-1``."""
        return predict(
            data=input_files[0],
            posterior_samples=input_files[1],
            output_file=output_file,
            initial_step=-1,
            num_steps=61,
        )

    # Summarisation
    rf.transform(
        input=next_generation_matrix,
        filter=rf.formatter(),
        output=wd("rt_summary.csv"),
    )(summarize.rt)

    rf.transform(
        input=medium_term,
        filter=rf.formatter(),
        output=wd("infec_incidence_summary.csv"),
    )(summarize.infec_incidence)

    rf.transform(
        input=[[process_data, thin_samples, medium_term]],
        filter=rf.formatter(),
        output=wd("prevalence_summary.csv"),
    )(summarize.prevalence)

    rf.transform(
        input=[[process_data, thin_samples]],
        filter=rf.formatter(),
        output=wd("within_between_summary.csv"),
    )(within_between)

    @rf.transform(
        input=[[process_data, insample7, insample14]],
        filter=rf.formatter(),
        output=wd("exceedance_summary.csv"),
    )
    def exceedance(input_files, output_file):
        """Write Pr(pred<obs) per location for the 7- and 14-day predictions.

        ``input_files`` is (data, insample7 output, insample14 output).
        """
        exceed7 = case_exceedance((input_files[0], input_files[1]), 7)
        exceed14 = case_exceedance((input_files[0], input_files[2]), 14)
        df = pd.DataFrame(
            {"Pr(pred<obs)_7": exceed7, "Pr(pred<obs)_14": exceed14},
            index=exceed7.coords["location"],
        )
        df.to_csv(output_file)

    # Plot in-sample
    @rf.transform(
        input=[insample7, insample14],
        # Raw string: "\d" in a plain literal is an invalid escape sequence
        # (DeprecationWarning; SyntaxWarning from Python 3.12). The dot
        # before "pkl" is escaped so it only matches a literal ".".
        filter=rf.formatter(r".+/insample(?P<LAG>\d+)\.pkl"),
        add_inputs=rf.add_inputs(process_data),
        output="{path[0]}/insample_plots{LAG[0]}",
        extras=["{LAG[0]}"],
    )
    def plot_insample_predictive_timeseries(input_files, output_dir, lag):
        """Plot in-sample predictive time series for one lag.

        ``lag`` is the digit string captured from the input file name.
        """
        insample_predictive_timeseries(input_files, output_dir, lag)

    # Geopackage
    rf.transform(
        input=[
            [
                process_data,
                summarize.rt,
                summarize.infec_incidence,
                summarize.prevalence,
                within_between,
                exceedance,
            ]
        ],
        filter=rf.formatter(),
        output=wd("prediction.gpkg"),
        extras=[global_config["Geopackage"]],
    )(summary_geopackage)

    rf.cmdline.run(cli_options)