ruffus_pipeline.py 6.11 KB
Newer Older
1
"""Represents the analytic pipeline as a ruffus chain"""
Chris Jewell's avatar
Chris Jewell committed
2
3
4

import os
import yaml
5
6
7
8
from datetime import datetime
from uuid import uuid1
import json
import netCDF4 as nc
Chris Jewell's avatar
Chris Jewell committed
9
import pandas as pd
Chris Jewell's avatar
Chris Jewell committed
10
11
import ruffus as rf

12

Chris Jewell's avatar
Chris Jewell committed
13
from covid.tasks import (
14
15
    assemble_data,
    mcmc,
Chris Jewell's avatar
Chris Jewell committed
16
17
18
19
    thin_posterior,
    next_generation_matrix,
    overall_rt,
    predict,
20
21
22
23
    summarize,
    within_between,
    case_exceedance,
    summary_geopackage,
24
    insample_predictive_timeseries,
Chris Jewell's avatar
Chris Jewell committed
25
    summary_longformat,
Chris Jewell's avatar
Chris Jewell committed
26
27
)

28
__all__ = ["run_pipeline"]
Chris Jewell's avatar
Chris Jewell committed
29
30


31
32
def _make_append_work_dir(work_dir):
    """Return a callable that resolves a file name inside ``work_dir``.

    The returned callable joins ``work_dir`` and the given file name with
    ``os.path.join`` and then expands any environment variables in the
    joined path with ``os.path.expandvars``.

    :param work_dir: directory that file names are joined onto
    :returns: a one-argument function mapping ``filename`` to a path string
    """

    def append_work_dir(filename):
        joined = os.path.join(work_dir, filename)
        return os.path.expandvars(joined)

    return append_work_dir
Chris Jewell's avatar
Chris Jewell committed
33
34


35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def _create_metadata(config):
    """Build the metadata dictionary describing one pipeline run.

    :param config: pipeline configuration; it is serialised to a JSON string
                   and stored under ``pipeline_config``
    :returns: a dict with the keys ``pipeline_id`` (hex string of a freshly
              generated ``uuid1``), ``created_at`` (``str`` of the current
              local ``datetime.now()``), ``inference_library``,
              ``inference_library_version`` and ``pipeline_config``
    """
    run_id = uuid1().hex
    timestamp = str(datetime.now())
    # default=str: any value json cannot encode natively is stored as str(value)
    config_json = json.dumps(config, default=str)

    return {
        "pipeline_id": run_id,
        "created_at": timestamp,
        "inference_library": "GEM",
        "inference_library_version": "0.1.alpha0",
        "pipeline_config": config_json,
    }


def _create_nc_file(output_file, meta_data_dict):
    """Create a NETCDF4 file and set the given metadata as attributes on it.

    The file is opened in ``"w"`` mode, and each key/value pair of
    ``meta_data_dict`` is set as an attribute on the dataset object.

    :param output_file: path of the netCDF file to create
    :param meta_data_dict: mapping of attribute name to attribute value
    """
    # The context manager closes the dataset even if setting an attribute
    # raises (e.g. an unsupported value type), so the handle is never leaked.
    with nc.Dataset(output_file, "w", format="NETCDF4") as nc_file:
        for key, value in meta_data_dict.items():
            setattr(nc_file, key, value)


52
def run_pipeline(global_config, results_directory, cli_options):
    """Declare the analytic pipeline as ruffus tasks and run it.

    Every task is registered first; ``rf.cmdline.run`` is then invoked exactly
    once, so the run sees the complete task graph.

    :param global_config: pipeline configuration.  The sections
                          ``"ProcessData"``, ``"Mcmc"``, ``"ThinPosterior"``
                          and ``"Geopackage"`` are read here; the whole object
                          is also dumped to ``config.yaml`` and embedded in
                          the run metadata.
    :param results_directory: directory that all outputs are written under
                              (created by the first task if necessary)
    :param cli_options: options object passed straight to
                        ``rf.cmdline.run`` -- presumably the parsed ruffus
                        command-line options; confirm against the caller
    """

    wd = _make_append_work_dir(results_directory)

    pipeline_meta = _create_metadata(global_config)

    # Pipeline starts here
    @rf.mkdir(results_directory)
    @rf.originate(wd("inferencedata.nc"), global_config)
    def process_data(output_file, config):
        # Create the file with the run metadata before the data is assembled
        _create_nc_file(output_file, pipeline_meta)
        assemble_data(output_file, config["ProcessData"])

    @rf.transform(
        process_data,
        rf.formatter(),
        wd("config.yaml"),
        global_config,
    )
    def save_config(input_file, output_file, config):
        # input_file is unused; it only makes this task depend on process_data
        with open(output_file, "w") as f:
            yaml.dump(config, f)

    @rf.transform(
        process_data,
        rf.formatter(),
        wd("posterior.hd5"),
        global_config,
    )
    def run_mcmc(input_file, output_file, config):
        mcmc(input_file, output_file, config["Mcmc"])

    @rf.transform(
        input=run_mcmc,
        filter=rf.formatter(),
        output=wd("thin_samples.pkl"),
        extras=[global_config],
    )
    def thin_samples(input_file, output_file, config):
        thin_posterior(input_file, output_file, config["ThinPosterior"])

    # Rt related steps
    rf.transform(
        input=[[process_data, thin_samples]],
        filter=rf.formatter(),
        output=wd("ngm.nc"),
    )(next_generation_matrix)

    rf.transform(
        input=next_generation_matrix,
        filter=rf.formatter(),
        output=wd("national_rt.xlsx"),
    )(overall_rt)

    # In-sample prediction
    @rf.transform(
        input=[[process_data, thin_samples]],
        filter=rf.formatter(),
        output=wd("insample7.nc"),
    )
    def insample7(input_files, output_file):
        predict(
            data=input_files[0],
            posterior_samples=input_files[1],
            output_file=output_file,
            initial_step=-7,
            num_steps=28,
        )

    @rf.transform(
        input=[[process_data, thin_samples]],
        filter=rf.formatter(),
        output=wd("insample14.nc"),
    )
    def insample14(input_files, output_file):
        return predict(
            data=input_files[0],
            posterior_samples=input_files[1],
            output_file=output_file,
            initial_step=-14,
            num_steps=28,
        )

    # Medium-term prediction
    @rf.transform(
        input=[[process_data, thin_samples]],
        filter=rf.formatter(),
        output=wd("medium_term.nc"),
    )
    def medium_term(input_files, output_file):
        return predict(
            data=input_files[0],
            posterior_samples=input_files[1],
            output_file=output_file,
            initial_step=-1,
            num_steps=61,
        )

    # Summarisation
    rf.transform(
        input=next_generation_matrix,
        filter=rf.formatter(),
        output=wd("rt_summary.csv"),
    )(summarize.rt)

    rf.transform(
        input=medium_term,
        filter=rf.formatter(),
        output=wd("infec_incidence_summary.csv"),
    )(summarize.infec_incidence)

    rf.transform(
        input=[[process_data, medium_term]],
        filter=rf.formatter(),
        output=wd("prevalence_summary.csv"),
    )(summarize.prevalence)

    rf.transform(
        input=[[process_data, thin_samples]],
        filter=rf.formatter(),
        output=wd("within_between_summary.csv"),
    )(within_between)

    @rf.transform(
        input=[[process_data, insample7, insample14]],
        filter=rf.formatter(),
        output=wd("exceedance_summary.csv"),
    )
    def exceedance(input_files, output_file):
        # input_files is (data, 7-step in-sample, 14-step in-sample)
        exceed7 = case_exceedance((input_files[0], input_files[1]), 7)
        exceed14 = case_exceedance((input_files[0], input_files[2]), 14)
        df = pd.DataFrame(
            {"Pr(pred<obs)_7": exceed7, "Pr(pred<obs)_14": exceed14},
            index=exceed7.coords["location"],
        )
        df.to_csv(output_file)

    # Plot in-sample
    @rf.transform(
        input=[insample7, insample14],
        # Raw string: "\d" is an invalid escape sequence in a normal literal
        filter=rf.formatter(r".+/insample(?P<LAG>\d+).nc"),
        add_inputs=rf.add_inputs(process_data),
        output="{path[0]}/insample_plots{LAG[0]}",
        extras=["{LAG[0]}"],
    )
    def plot_insample_predictive_timeseries(input_files, output_dir, lag):
        insample_predictive_timeseries(input_files, output_dir, lag)

    # Geopackage
    rf.transform(
        [
            [
                process_data,
                summarize.rt,
                summarize.infec_incidence,
                summarize.prevalence,
                within_between,
                exceedance,
            ]
        ],
        rf.formatter(),
        wd("prediction.gpkg"),
        global_config["Geopackage"],
    )(summary_geopackage)

    # DSTL Summary
    rf.transform(
        [
            [
                process_data,
                insample7,
                insample14,
                medium_term,
                next_generation_matrix,
            ]
        ],
        rf.formatter(),
        wd("summary_longformat.xlsx"),
    )(summary_longformat)

    # Run only once, after every task (including the DSTL summary) has been
    # registered; an earlier run would execute an incomplete task graph.
    rf.cmdline.run(cli_options)