ruffus_pipeline.py 6.86 KB
Newer Older
1
"""Represents the analytic pipeline as a ruffus chain"""
Chris Jewell's avatar
Chris Jewell committed
2
3

import os
4
import warnings
Chris Jewell's avatar
Chris Jewell committed
5
import yaml
6
7
8
9
from datetime import datetime
from uuid import uuid1
import json
import netCDF4 as nc
10
import s3fs
Chris Jewell's avatar
Chris Jewell committed
11
import pandas as pd
Chris Jewell's avatar
Chris Jewell committed
12
13
import ruffus as rf

14

Chris Jewell's avatar
Chris Jewell committed
15
from covid.tasks import (
16
17
    assemble_data,
    mcmc,
Chris Jewell's avatar
Chris Jewell committed
18
    thin_posterior,
19
    reproduction_number,
Chris Jewell's avatar
Chris Jewell committed
20
21
    overall_rt,
    predict,
22
23
24
25
    summarize,
    within_between,
    case_exceedance,
    summary_geopackage,
26
    insample_predictive_timeseries,
Chris Jewell's avatar
Chris Jewell committed
27
    summary_longformat,
Chris Jewell's avatar
Chris Jewell committed
28
29
)

30
# Only the pipeline entry point is part of this module's public API.
__all__ = ["run_pipeline"]
Chris Jewell's avatar
Chris Jewell committed
31
32


33
34
def _make_append_work_dir(work_dir):
    return lambda filename: os.path.expandvars(os.path.join(work_dir, filename))
Chris Jewell's avatar
Chris Jewell committed
35
36


37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def _create_metadata(config):
    return dict(
        pipeline_id=uuid1().hex,
        created_at=str(datetime.now()),
        inference_library="GEM",
        inference_library_version="0.1.alpha0",
        pipeline_config=json.dumps(config, default=str),
    )


def _create_nc_file(output_file, meta_data_dict):
    """Create a NetCDF4 file that contains only global attributes.

    :param output_file: path of the file to create. Opened in mode ``"w"``,
        so an existing file at this path is overwritten.
    :param meta_data_dict: mapping of attribute name to value, written as
        global attributes of the new dataset.
    """
    # The context manager guarantees the dataset is closed even if writing
    # an attribute fails (previously the handle leaked on error).
    # ``setncatts`` writes NetCDF attributes directly rather than going
    # through ``Dataset.__setattr__``, which rejects names reserved by the
    # Python object itself.
    with nc.Dataset(output_file, "w", format="NETCDF4") as nc_file:
        nc_file.setncatts(meta_data_dict)


54
def run_pipeline(global_config, results_directory, cli_options):
    """Assemble the analysis as a ruffus task graph and execute it.

    Every task is defined as a closure over ``global_config`` and writes its
    output into ``results_directory``. The graph is run exactly once, after
    all tasks (including the summary and optional S3 upload tasks) have been
    registered.

    :param global_config: full pipeline configuration mapping. Tasks read
        its ``ProcessData``, ``Mcmc``, ``ThinPosterior``, ``Geopackage`` and
        ``AWSS3`` sections. The whole mapping is also saved to
        ``config.yaml`` and embedded in the NetCDF metadata.
    :param results_directory: directory for all outputs. Environment
        variables in it are expanded.
    :param cli_options: options object passed to ``rf.cmdline.run``. Its
        ``aws`` attribute switches the S3 upload task on or off.
    """
    wd = _make_append_work_dir(results_directory)
    pipeline_meta = _create_metadata(global_config)

    # Pipeline starts here.
    # ``wd`` expands environment variables in every output path, so the
    # directory ruffus creates must be expanded the same way. Otherwise a
    # results_directory such as "$SCRATCH/run" would be created literally
    # while the outputs are written somewhere else.
    @rf.mkdir(os.path.expandvars(results_directory))
    @rf.originate(wd("config.yaml"), global_config)
    def save_config(output_file, config):
        with open(output_file, "w") as f:
            yaml.dump(config, f)

    @rf.follows(save_config)
    @rf.originate(wd("inferencedata.nc"), global_config)
    def process_data(output_file, config):
        # Write the provenance attributes first, then hand the same file to
        # the data-assembly step (presumably it adds the data to this file;
        # confirm in covid.tasks.assemble_data).
        _create_nc_file(output_file, pipeline_meta)
        assemble_data(output_file, config["ProcessData"])

    @rf.transform(
        process_data,
        rf.formatter(),
        wd("posterior.hd5"),
        global_config,
    )
    def run_mcmc(input_file, output_file, config):
        mcmc(input_file, output_file, config["Mcmc"])

    @rf.transform(
        input=run_mcmc,
        filter=rf.formatter(),
        output=wd("thin_samples.pkl"),
        extras=[global_config],
    )
    def thin_samples(input_file, output_file, config):
        thin_posterior(input_file, output_file, config["ThinPosterior"])

    # Rt related steps
    rf.transform(
        input=[[process_data, thin_samples]],
        filter=rf.formatter(),
        output=wd("reproduction_number.nc"),
    )(reproduction_number)

    rf.transform(
        input=reproduction_number,
        filter=rf.formatter(),
        output=wd("national_rt.xlsx"),
    )(overall_rt)

    # In-sample prediction: start 7 / 14 steps back and predict 28 steps.
    @rf.transform(
        input=[[process_data, thin_samples]],
        filter=rf.formatter(),
        output=wd("insample7.nc"),
    )
    def insample7(input_files, output_file):
        return predict(
            data=input_files[0],
            posterior_samples=input_files[1],
            output_file=output_file,
            initial_step=-7,
            num_steps=28,
        )

    @rf.transform(
        input=[[process_data, thin_samples]],
        filter=rf.formatter(),
        output=wd("insample14.nc"),
    )
    def insample14(input_files, output_file):
        return predict(
            data=input_files[0],
            posterior_samples=input_files[1],
            output_file=output_file,
            initial_step=-14,
            num_steps=28,
        )

    # Medium-term prediction: start at the last step and predict 61 steps.
    @rf.transform(
        input=[[process_data, thin_samples]],
        filter=rf.formatter(),
        output=wd("medium_term.nc"),
    )
    def medium_term(input_files, output_file):
        return predict(
            data=input_files[0],
            posterior_samples=input_files[1],
            output_file=output_file,
            initial_step=-1,
            num_steps=61,
        )

    # Summarisation
    rf.transform(
        input=reproduction_number,
        filter=rf.formatter(),
        output=wd("rt_summary.csv"),
    )(summarize.rt)

    rf.transform(
        input=medium_term,
        filter=rf.formatter(),
        output=wd("infec_incidence_summary.csv"),
    )(summarize.infec_incidence)

    rf.transform(
        input=[[process_data, medium_term]],
        filter=rf.formatter(),
        output=wd("prevalence_summary.csv"),
    )(summarize.prevalence)

    rf.transform(
        input=[[process_data, thin_samples]],
        filter=rf.formatter(),
        output=wd("within_between_summary.csv"),
    )(within_between)

    @rf.transform(
        input=[[process_data, insample7, insample14]],
        filter=rf.formatter(),
        output=wd("exceedance_summary.csv"),
    )
    def exceedance(input_files, output_file):
        # Compare observed data against the 7- and 14-step in-sample
        # predictions and tabulate both results per location.
        exceed7 = case_exceedance((input_files[0], input_files[1]), 7)
        exceed14 = case_exceedance((input_files[0], input_files[2]), 14)
        df = pd.DataFrame(
            {"Pr(pred<obs)_7": exceed7, "Pr(pred<obs)_14": exceed14},
            index=exceed7.coords["location"],
        )
        df.to_csv(output_file)

    # Plot in-sample
    # NOTE: currently disabled; kept for reference.
    # @rf.transform(
    #     input=[insample7, insample14],
    #     filter=rf.formatter(".+/insample(?P<LAG>\d+).nc"),
    #     add_inputs=rf.add_inputs(process_data),
    #     output="{path[0]}/insample_plots{LAG[0]}",
    #     extras=["{LAG[0]}"],
    # )
    # def plot_insample_predictive_timeseries(input_files, output_dir, lag):
    #     insample_predictive_timeseries(input_files, output_dir, lag)

    # Geopackage
    rf.transform(
        [
            [
                process_data,
                summarize.rt,
                summarize.infec_incidence,
                summarize.prevalence,
                within_between,
                exceedance,
            ]
        ],
        rf.formatter(),
        wd("prediction.gpkg"),
        global_config["Geopackage"],
    )(summary_geopackage)

    # DSTL Summary
    rf.transform(
        [
            [
                process_data,
                insample7,
                insample14,
                medium_term,
                reproduction_number,
            ]
        ],
        rf.formatter(),
        wd("summary_longformat.xlsx"),
    )(summary_longformat)

    @rf.active_if(cli_options.aws)
    @rf.transform(
        input=[
            process_data,
            run_mcmc,
            insample7,
            insample14,
            medium_term,
            reproduction_number,
        ],
        filter=rf.formatter(),
        output="{subdir[0][0]}/{basename[0]}{ext[0]}",
        extras=[global_config["AWSS3"]],
    )
    def upload_to_aws(input_file, output_file, config):
        """Copy one result file to S3 unless the target object already exists.

        ``output_file`` is "<parent directory name>/<file name>" (from the
        formatter pattern above) and is placed under ``config['bucket']``.
        """
        obj_path = f"{config['bucket']}/{output_file}"
        s3 = s3fs.S3FileSystem(profile=config["profile"])
        if not s3.exists(obj_path):
            print(f"Copy {input_file} to {obj_path}", flush=True)
            s3.put(input_file, obj_path)
        else:
            warnings.warn(f"Path '{obj_path}' already exists, not uploading.")

    # Run exactly once, after every task has been registered. A run placed
    # earlier would execute the graph without the later tasks.
    rf.cmdline.run(cli_options)