covid19uk, commit 5c2058da

Ruffus-style pipeline implemented

Authored Jan 07, 2021 by Chris Jewell
Parent: 74aaa3dd
Changes: 18 files
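The commit message refers to a Ruffus-style pipeline: each task below is refactored into a plain function that reads named input files and writes a named output file, so stages can be chained by filename in the manner of Ruffus pipeline stages. A minimal sketch of that chaining, assuming illustrative file names (the pipeline driver itself is not part of this commit; config["ProcessData"] and config["Mcmc"] are the subsections the tasks below actually read):

    import yaml

    from covid.tasks import assemble_data, mcmc, case_exceedance

    with open("config.yaml") as f:  # hypothetical global config file
        config = yaml.load(f, Loader=yaml.FullLoader)

    # Stage 1: bundle all input data into a single pickle.
    assemble_data("pipeline_data.pkl", config["ProcessData"])

    # Stage 2: run MCMC inference against the bundle, writing the posterior.
    mcmc("pipeline_data.pkl", "posterior.hd5", config["Mcmc"])

    # Later stages consume earlier outputs by filename; "prediction.pkl"
    # would come from the predict task, and the lag value is made up.
    exceedance = case_exceedance(["pipeline_data.pkl", "prediction.pkl"], lag=7)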
covid/data/area_code.py

@@ -66,16 +66,6 @@ class AreaCodeData:
         if settings["format"] == "ons":
             print("Retrieving Area Code data from the ONS")
             data = response.json()
-            if config["GenerateOutput"]["storeInputs"]:
-                fn = format_output_filename(
-                    config["GenerateOutput"]["scrapedDataDir"]
-                    + "AreaCodeData_ONS.json",
-                    config,
-                )
-                with open(fn, "w") as f:
-                    json.dump(data, f)
             df = AreaCodeData.getJSON(json.dumps(data))

         return df

@@ -162,28 +152,22 @@ class AreaCodeData:
         """
         Adapt the area codes to the desired dataframe format
         """
-        output_settings = config["GenerateOutput"]
         settings = config["AreaCodeData"]
-        output = settings["output"]
         regions = settings["regions"]

         if settings["input"] == "processed":
             return df

         if settings["format"].lower() == "ons":
-            df = AreaCodeData.adapt_ons(df, regions, output, config)
+            df = AreaCodeData.adapt_ons(df, regions)

         # if we have a predefined list of LADs, filter them down
         if "lad19cds" in config:
             df = df[[x in config["lad19cds"] for x in df.lad19cd.values]]

-        if output_settings["storeProcessedInputs"] and output != "None":
-            output = format_output_filename(output, config)
-            df.to_csv(output, index=False)
-
         return df

-    def adapt_ons(df, regions, output, config):
+    def adapt_ons(df, regions):
         colnames = ["lad19cd", "name"]
         df.columns = colnames
         filters = df["lad19cd"].str.contains(str.join("|", regions))
covid/data/area_code_test.py

@@ -14,12 +14,6 @@ def test_url():
             "output": "processed_data/processed_lad19cd.csv",
             "regions": ["E"],
         },
-        "GenerateOutput": {
-            "storeInputs": True,
-            "scrapedDataDir": "scraped_data",
-            "storeProcessedInputs": True,
-        },
-        "Global": {"prependID": False, "prependDate": False},
     }

     df = AreaCodeData.process(config)
covid/data/util.py

@@ -55,16 +55,8 @@ def merge_lad_values(df):
 def get_date_low_high(config):
-    if "dates" in config:
-        low = config["dates"]["low"]
-        high = config["dates"]["high"]
-    else:
-        inference_period = [
-            np.datetime64(x) for x in config["Global"]["inference_period"]
-        ]
-        low = inference_period[0]
-        high = inference_period[1]
-    return (low, high)
+    date_range = [np.datetime64(x) for x in config["date_range"]]
+    return tuple(date_range)


 def check_date_format(df):
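get_date_low_high now reads a single top-level date_range entry instead of the old dates / Global.inference_period fallbacks. A quick sketch of the new behaviour, with made-up dates:

    import numpy as np

    def get_date_low_high(config):
        date_range = [np.datetime64(x) for x in config["date_range"]]
        return tuple(date_range)

    low, high = get_date_low_high({"date_range": ["2020-10-01", "2021-01-04"]})
    # low and high are numpy.datetime64 bounds: 2020-10-01 and 2021-01-04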
covid/model_spec.py

"""Implements the COVID SEIR model as a TFP Joint Distribution"""

import pandas as pd
import geopandas as gp
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

@@ -19,7 +18,7 @@ XI_FREQ = 14  # baseline transmission changes every 14 days
 NU = tf.constant(0.28, dtype=DTYPE)  # E->I rate assumed known.


-def read_covariates(config):
+def gather_data(config):
     """Loads covariate data

     :param paths: a dictionary of paths to data with keys {'mobility_matrix',
@@ -27,31 +26,36 @@ def read_covariates(config):
     :returns: a dictionary of covariate information to be consumed by the model
       {'C': commute_matrix, 'W': traffic_flow, 'N': population_size}
     """
-    paths = config["data"]
-    date_low = np.datetime64(config["Global"]["inference_period"][0])
-    date_high = np.datetime64(config["Global"]["inference_period"][1])
-    mobility = data.read_mobility(paths["mobility_matrix"])
-    popsize = data.read_population(paths["population_size"])
+    date_low = np.datetime64(config["date_range"][0])
+    date_high = np.datetime64(config["date_range"][1])
+    mobility = data.read_mobility(config["mobility_matrix"])
+    popsize = data.read_population(config["population_size"])
     commute_volume = data.read_traffic_flow(
-        paths["commute_volume"], date_low=date_low, date_high=date_high
+        config["commute_volume"], date_low=date_low, date_high=date_high
     )
-    geo = gp.read_file(paths["geopackage"])
-    geo = geo.loc[geo["lad19cd"].str.startswith("E")]
-    # tier_restriction = data.read_challen_tier_restriction(
-    #     paths["tier_restriction_csv"],
-    #     date_low,
-    #     date_high,
-    # )
+    locations = data.AreaCodeData.process(config)
     tier_restriction = data.TierData.process(config)[:, :, 2:]
     date_range = [date_low, date_high]
     weekday = pd.date_range(date_low, date_high).weekday < 5
+    cases = data.read_phe_cases(
+        config["reported_cases"],
+        date_low,
+        date_high,
+        pillar=config["pillar"],
+        date_type=config["case_date_type"],
+    )
     return dict(
         C=mobility.to_numpy().astype(DTYPE),
         W=commute_volume.to_numpy().astype(DTYPE),
         N=popsize.to_numpy().astype(DTYPE),
         L=tier_restriction.astype(DTYPE),
         weekday=weekday.astype(DTYPE),
+        date_range=date_range,
+        locations=locations,
+        cases=cases,
     )

@@ -143,6 +147,15 @@ def CovidUK(covariates, initial_state, initial_step, num_steps):
     gamma0 = tf.convert_to_tensor(gamma0, DTYPE)
     gamma1 = tf.convert_to_tensor(gamma1, DTYPE)

+    C = tf.convert_to_tensor(covariates["C"], dtype=DTYPE)
+    C = tf.linalg.set_diag(C, tf.zeros(C.shape[0], dtype=DTYPE))
+    Cstar = C + tf.transpose(C)
+    Cstar = tf.linalg.set_diag(Cstar, -tf.reduce_sum(C, axis=-2))
+
+    W = tf.convert_to_tensor(tf.squeeze(covariates["W"]), dtype=DTYPE)
+    N = tf.convert_to_tensor(tf.squeeze(covariates["N"]), dtype=DTYPE)
+
+    L = tf.convert_to_tensor(covariates["L"], DTYPE)
+    L = L - tf.reduce_mean(L, axis=(0, 1))

@@ -150,14 +163,6 @@ def CovidUK(covariates, initial_state, initial_step, num_steps):
     weekday = weekday - tf.reduce_mean(weekday, axis=-1)

     def transition_rate_fn(t, state):
-        C = tf.convert_to_tensor(covariates["C"], dtype=DTYPE)
-        C = tf.linalg.set_diag(C, tf.zeros(C.shape[0], dtype=DTYPE))
-        Cstar = C + tf.transpose(C)
-        Cstar = tf.linalg.set_diag(Cstar, -tf.reduce_sum(C, axis=-2))
-        W = tf.constant(np.squeeze(covariates["W"]), dtype=DTYPE)
-        N = tf.constant(np.squeeze(covariates["N"]), dtype=DTYPE)

         w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, W.shape[0] - 1)
         commute_volume = tf.gather(W, w_idx)

@@ -166,7 +171,6 @@ def CovidUK(covariates, initial_state, initial_step, num_steps):
             dtype=tf.int64,
         )
         xi_ = tf.gather(xi, xi_idx)
         L_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, L.shape[0] - 1)
         Lt = tf.gather(L, L_idx)
         xB = tf.linalg.matvec(Lt, beta3)
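Hoisting C, Cstar, W, N and L out of transition_rate_fn means they are built once per model rather than on every rate evaluation. The two set_diag/transpose steps symmetrise the commute matrix and place negated column sums on the diagonal; a toy numeric check of those steps (matrix values made up):

    import tensorflow as tf

    C = tf.constant([[0.0, 2.0], [1.0, 0.0]])  # toy commute matrix
    C = tf.linalg.set_diag(C, tf.zeros(C.shape[0]))  # zero self-commuting
    Cstar = C + tf.transpose(C)  # flow in either direction couples two areas
    Cstar = tf.linalg.set_diag(Cstar, -tf.reduce_sum(C, axis=-2))
    # Cstar == [[-1., 3.], [3., -2.]]: off-diagonals are total between-area
    # flow, while the diagonal holds the negated column sums of C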
covid/tasks/__init__.py

"""Import tasks"""

from covid.tasks.assemble_data import assemble_data
from covid.tasks.inference import mcmc
from covid.tasks.thin_posterior import thin_posterior
from covid.tasks.next_generation_matrix import next_generation_matrix
from covid.tasks.overall_rt import overall_rt
from covid.tasks.predict import predict
import covid.tasks.summarize as summarize
from covid.tasks.within_between import within_between
from covid.tasks.case_exceedance import case_exceedance
from covid.tasks.summary_geopackage import summary_geopackage

__all__ = [
    "assemble_data",
    "mcmc",
    "thin_posterior",
    "next_generation_matrix",
    "overall_rt",
    "predict",
    "summarize",
    "within_between",
    "case_exceedance",
    "summary_geopackage",
]
covid/tasks/assemble_data.py

@@ -2,6 +2,29 @@
 to instantiate the COVID19 model"""

 import pickle as pkl

+from covid.model_spec import gather_data
+
+
+def assemble_data(output_file, config):
+
+    covar_data = {}
+    all_data = gather_data(config)
+
+    with open(output_file, "wb") as f:
+        pkl.dump(all_data, f)
+
+
+if __name__ == "__main__":
+
+    from argparse import ArgumentParser
+    import yaml
+
+    parser = ArgumentParser(description="Bundle data into a pickled dictionary")
+    parser.add_argument("config_file", help="Global config file")
+    parser.add_argument("output_file", help="Data bundle pkl file")
+    args = parser.parse_args()
+
+    with open(args.config_file, "r") as f:
+        global_config = yaml.load(f, Loader=yaml.FullLoader)
+
+    assemble_data(args.output_file, global_config["ProcessData"])
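assemble_data is the first pipeline stage: it writes the gather_data bundle to a pickle that every later task reloads. Reading the bundle back, assuming the illustrative file name from the sketch above:

    import pickle as pkl

    with open("pipeline_data.pkl", "rb") as f:
        data = pkl.load(f)

    # per gather_data above, expect keys such as:
    # C, W, N, L, weekday, date_range, locations, cases
    print(sorted(data.keys()))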
covid/tasks/case_exceedance.py (new file, mode 100644)

"""Calculates case exceedance probabilities"""

import numpy as np
import pickle as pkl
import pandas as pd


def case_exceedance(input_files, lag):
    """Calculates case exceedance probabilities,
    i.e. Pr(pred[lag:] < observed[lag:])

    :param input_files: [data pickle, prediction pickle]
    :param lag: the lag for which to calculate the exceedance
    """

    with open(input_files[0], "rb") as f:
        data = pkl.load(f)
    with open(input_files[1], "rb") as f:
        prediction = pkl.load(f)

    modelled_cases = np.sum(prediction[..., :lag, -1], axis=-1)
    observed_cases = np.sum(data["cases"].to_numpy()[:, -lag:], axis=-1)
    exceedance = np.mean(modelled_cases < observed_cases, axis=0)

    df = pd.Series(
        exceedance,
        index=pd.Index(data["locations"]["lad19cd"], name="location"),
    )

    return df
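The exceedance series is the posterior probability that predicted case totals fall below what was actually observed over the last lag days. A self-contained toy version of the same arithmetic, assuming (as in the code above) the prediction is shaped samples x locations x time x transitions with the final transition axis holding case events; all numbers are made up:

    import numpy as np

    rng = np.random.default_rng(42)
    prediction = rng.poisson(10.0, size=(200, 3, 7, 4))  # posterior predictive draws
    observed = np.tile([12, 9, 11, 8, 10, 13, 9], (3, 1))  # 3 locations x 7 days

    lag = 7
    modelled_cases = np.sum(prediction[..., :lag, -1], axis=-1)  # (200, 3)
    observed_cases = np.sum(observed[:, -lag:], axis=-1)  # (3,)
    exceedance = np.mean(modelled_cases < observed_cases, axis=0)  # Pr per location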
covid/tasks/hotspot_detection.py

@@ -92,7 +92,7 @@ if __name__ == "__main__":
     ]

     # Load covariate data
-    covar_data = model_spec.read_covariates(config)
+    covar_data = model_spec.gather_data(config)
     output_folder_path = config["output"]["results_dir"]
     geopackage_path = os.path.expandvars(

@@ -111,9 +111,7 @@ if __name__ == "__main__":
     )
     print("Using posterior:", posterior_path)
     posterior = h5py.File(
-        os.path.expandvars(
-            posterior_path,
-        ),
+        os.path.expandvars(posterior_path,),
         "r",
         rdcc_nbytes=1024 ** 3,
         rdcc_nslots=1e6,

@@ -125,9 +123,7 @@ if __name__ == "__main__":
         beta1=posterior["samples/beta1"][idx],
         beta2=posterior["samples/beta2"][idx],
         beta3=posterior["samples/beta3"][idx],
-        sigma=posterior["samples/sigma"][
-            idx,
-        ],
+        sigma=posterior["samples/sigma"][idx,],
         xi=posterior["samples/xi"][idx],
         gamma0=posterior["samples/gamma0"][idx],
         gamma1=posterior["samples/gamma1"][idx],
covid/tasks/inference.py

@@ -2,6 +2,8 @@
 # pylint: disable=E402

 import os
+import h5py
+import pickle as pkl
 from time import perf_counter
 import tqdm
 import yaml

@@ -9,7 +11,6 @@ import numpy as np
 import tensorflow as tf
 import tensorflow_probability as tfp

-from covid.data import AreaCodeData
 from gemlib.util import compute_state
 from gemlib.mcmc import UncalibratedEventTimesUpdate
 from gemlib.mcmc import UncalibratedOccultUpdate, TransitionTopology

@@ -18,9 +19,6 @@ from gemlib.mcmc import MultiScanKernel
 from gemlib.mcmc import AdaptiveRandomWalkMetropolis
 from gemlib.mcmc import Posterior
-from covid.data import read_phe_cases
-from covid.cli_arg_parse import cli_args
 import covid.model_spec as model_spec

 tfd = tfp.distributions

@@ -28,7 +26,7 @@ tfb = tfp.bijectors
 DTYPE = model_spec.DTYPE


-def run_mcmc(config):
+def mcmc(data_file, output_file, config):
     """Constructs and runs the MCMC"""

     if tf.test.gpu_device_name():

@@ -36,24 +34,13 @@ def run_mcmc(config):
     else:
         print("Using CPU")

-    inference_period = [
-        np.datetime64(x) for x in config["Global"]["inference_period"]
-    ]
-
-    covar_data = model_spec.read_covariates(config)
+    with open(data_file, "rb") as f:
+        data = pkl.load(f)

     # We load in cases and impute missing infections first, since this sets the
     # time epoch which we are analysing.
-    cases = read_phe_cases(
-        config["data"]["reported_cases"],
-        date_low=inference_period[0],
-        date_high=inference_period[1],
-        date_type=config["data"]["case_date_type"],
-        pillar=config["data"]["pillar"],
-    ).astype(DTYPE)

     # Impute censored events, return cases
-    events = model_spec.impute_censored_events(cases)
+    events = model_spec.impute_censored_events(data["cases"].astype(DTYPE))

@@ -63,13 +50,13 @@ def run_mcmc(config):
     # to set up a sensible initial state.
     state = compute_state(
         initial_state=tf.concat(
-            [covar_data["N"][:, tf.newaxis], tf.zeros_like(events[:, 0, :])],
+            [data["N"][:, tf.newaxis], tf.zeros_like(events[:, 0, :])],
             axis=-1,
         ),
         events=events,
         stoichiometry=model_spec.STOICHIOMETRY,
     )
-    start_time = state.shape[1] - cases.shape[1]
+    start_time = state.shape[1] - data["cases"].shape[1]
     initial_state = state[:, start_time, :]
     events = events[:, start_time:, :]

@@ -77,7 +64,7 @@ def run_mcmc(config):
     # Construct the MCMC kernels #
     ########################################################
     model = model_spec.CovidUK(
-        covariates=covar_data,
+        covariates=data,
         initial_state=initial_state,
         initial_step=0,
         num_steps=events.shape[1],

@@ -152,9 +139,9 @@ def run_mcmc(config):
                 prev_event_id=prev_event_id,
                 next_event_id=next_event_id,
                 initial_state=initial_state,
-                dmax=config["mcmc"]["dmax"],
-                mmax=config["mcmc"]["m"],
-                nmax=config["mcmc"]["nmax"],
+                dmax=config["dmax"],
+                mmax=config["m"],
+                nmax=config["nmax"],
             ),
             name=name,
         )

@@ -170,7 +157,7 @@ def run_mcmc(config):
                     prev_event_id, target_event_id, next_event_id
                 ),
                 cumulative_event_offset=initial_state,
-                nmax=config["mcmc"]["occult_nmax"],
+                nmax=config["occult_nmax"],
                 t_range=(events.shape[1] - 21, events.shape[1]),
                 name=name,
             ),

@@ -181,7 +168,7 @@ def run_mcmc(config):
     def make_event_multiscan_kernel(target_log_prob_fn, _):
         return MultiScanKernel(
-            config["mcmc"]["num_event_time_updates"],
+            config["num_event_time_updates"],
             GibbsKernel(
                 target_log_prob_fn=target_log_prob_fn,
                 kernel_list=[

@@ -234,7 +221,7 @@ def run_mcmc(config):
         return results_dict

     # Build MCMC algorithm here.  This will be run in bursts for memory economy
-    @tf.function(autograph=False, experimental_compile=True)
+    @tf.function  # (autograph=False, experimental_compile=True)
     def sample(n_samples, init_state, thin=0, previous_results=None):
         with tf.name_scope("main_mcmc_sample_loop"):

@@ -265,9 +252,8 @@ def run_mcmc(config):
     ###############################
     # Construct bursted MCMC loop #
     ###############################
-    NUM_BURSTS = int(config["mcmc"]["num_bursts"])
-    NUM_BURST_SAMPLES = int(config["mcmc"]["num_burst_samples"])
-    NUM_EVENT_TIME_UPDATES = int(config["mcmc"]["num_event_time_updates"])
+    NUM_BURSTS = int(config["num_bursts"])
+    NUM_BURST_SAMPLES = int(config["num_burst_samples"])
     NUM_SAVED_SAMPLES = NUM_BURST_SAMPLES * NUM_BURSTS

     # RNG stuff

@@ -286,10 +272,7 @@ def run_mcmc(config):
     # Output file
     samples, results, _ = sample(1, current_state)
     posterior = Posterior(
-        os.path.join(
-            os.path.expandvars(config["output"]["results_dir"]),
-            config["output"]["posterior"],
-        ),
+        output_file,
         sample_dict={
             "beta2": (samples[0][:, 0], (NUM_BURST_SAMPLES,)),
             "gamma0": (samples[0][:, 1], (NUM_BURST_SAMPLES,)),

@@ -307,19 +290,21 @@ def run_mcmc(config):
         num_samples=NUM_SAVED_SAMPLES,
     )
     posterior._file.create_dataset("initial_state", data=initial_state)
+    posterior._file.create_dataset("config", data=yaml.dump(config))
+    posterior._file.create_dataset(
+        "date_range",
+        data=np.array(data["date_range"]).astype(h5py.string_dtype()),
+    )

     # We loop over successive calls to sample because we have to dump results
     # to disc, or else end OOM (even on a 32GB system).
     # with tf.profiler.experimental.Profile("/tmp/tf_logdir"):
     final_results = None
     for i in tqdm.tqdm(
-        range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES * config["mcmc"]["thin"]
+        range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES * config["thin"]
     ):
         samples, results, final_results = sample(
             NUM_BURST_SAMPLES,
             init_state=current_state,
-            thin=config["mcmc"]["thin"] - 1,
+            thin=config["thin"] - 1,
             previous_results=final_results,
         )
         current_state = [s[-1] for s in samples]

@@ -343,7 +328,8 @@ def run_mcmc(config):
         end = perf_counter()
         print("Storage time:", end - start, "seconds")
-        for k, v in results:
+        print("Results type: ", type(results))
+        for k, v in results.items():
             print(
                 f"Acceptance {k}:",
                 tf.reduce_mean(tf.cast(v["is_accepted"], tf.float32)),

@@ -371,10 +357,21 @@ def run_mcmc(config):
 if __name__ == "__main__":

     # Read in settings
-    args = cli_args()
+    from argparse import ArgumentParser
+
+    parser = ArgumentParser(description="Run MCMC inference algorithm")
+    parser.add_argument(
+        "-c", "--config", type=str, help="Config file", required=True
+    )
+    parser.add_argument(
+        "-o", "--output", type=str, help="Output file", required=True
+    )
+    parser.add_argument(
+        "data_file", type=str, help="Data pickle file", required=True
+    )
+    args = parser.parse_args()

     with open(args.config, "r") as f:
         config = yaml.load(f, Loader=yaml.FullLoader)

-    run_mcmc(config)
+    mcmc(args.data_file, args.output, config["Mcmc"])
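With this refactor, mcmc receives the Mcmc subsection of the global config directly, so the kernel and burst settings are looked up at the top level of that dict rather than under a nested "mcmc" key. A sketch of the assumed shape of that subsection (the key names are exactly those read above; every value here is made up):

    mcmc_config = {
        # event-time and occult update bounds (values illustrative)
        "dmax": 21,
        "m": 1,
        "nmax": 50,
        "occult_nmax": 15,
        "num_event_time_updates": 35,
        # burst loop controls (values illustrative)
        "num_bursts": 200,
        "num_burst_samples": 50,
        "thin": 20,
    }
    # mcmc("pipeline_data.pkl", "posterior.hd5", mcmc_config)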
covid/tasks/insample_predictive_timeseries.py

@@ -8,7 +8,7 @@ import pickle as pkl
 import matplotlib.pyplot as plt

 from covid.cli_arg_parse import cli_args
-from model_spec import read_covariates
+from model_spec import gather_data
 from covid.data import read_phe_cases
 from covid.data import AreaCodeData

@@ -88,10 +88,7 @@ def main(config):
     for i in range(cases.shape[0]):
         title = lads["name"].iloc[i]
         plot_timeseries(
-            pred_quants[:, i, :14],
-            cases.iloc[i, -14:],
-            dates,
-            title,
+            pred_quants[:, i, :14], cases.iloc[i, -14:], dates, title,
         )
         plt.savefig(results_dir.joinpath(f"{lads['lad19cd'].iloc[i]}.png"))
covid/tasks/next_generation_matrix.py

"""Calculates and saves a next generation matrix"""

import argparse
import yaml
import numpy as np
import pickle as pkl
import xarray
import tensorflow as tf

@@ -10,7 +10,7 @@ from covid import model_spec
 from gemlib.util import compute_state


-def calc_posterior_ngm(param, events, init_state, covar_data):
+def calc_posterior_ngm(samples, covar_data):
     """Calculates effective reproduction number for batches of metapopulations

     :param theta: a tensor of batched theta parameters [B] + theta.shape
     :param xi: a tensor of batched xi parameters [B] + xi.shape

@@ -23,15 +23,10 @@ def calc_posterior_ngm(param, events, init_state, covar_data):
     def r_fn(args):
         beta1_, beta2_, beta3_, sigma_, xi_, gamma0_, events_ = args
         t = events_.shape[-2] - 1
-        state = compute_state(init_state, events_, model_spec.STOICHIOMETRY)
-        state = tf.gather(state, t, axis=-2)  # State on final inference day
-        model = model_spec.CovidUK(
-            covariates=covar_data,
-            initial_state=init_state,
-            initial_step=0,
-            num_steps=events_.shape[-2],
+        state = compute_state(
+            samples["init_state"], events_, model_spec.STOICHIOMETRY
+        )
+        state = tf.gather(state, t, axis=-2)  # State on final inference day

         par = dict(
             beta1=beta1_,

@@ -48,13 +43,13 @@ def calc_posterior_ngm(param, events, init_state, covar_data):
     return tf.vectorized_map(
         r_fn,
         elems=(
-            param["beta1"],
-            param["beta2"],
-            param["beta3"],
-            param["sigma"],
-            param["xi"],
-            param["gamma0"],
-            events,
+            samples["beta1"],
+            samples["beta2"],
+            samples["beta3"],
+            samples["sigma"],
+            samples["xi"],
+            samples["gamma0"],
+            samples["seir"],
         ),
     )

@@ -65,13 +60,19 @@ def next_generation_matrix(input_files, output_file):
         covar_data = pkl.load(f)
     with open(input_files[1], "rb") as f:
-        param = pkl.load(f)
+        samples = pkl.load(f)

     # Compute ngm posterior
-    ngm = calc_posterior_ngm(
-        param, param["events"], param["init_state"], covar_data
-    ).numpy()
+    ngm = calc_posterior_ngm(samples, covar_data).numpy()
     ngm = xarray.DataArray(
         ngm,
         coords=[
             np.arange(ngm.shape[0]),
             covar_data["locations"]["lad19cd"],
             covar_data["locations"]["lad19cd"],
         ],