Commit ce643600 authored by Chris Jewell's avatar Chris Jewell
Browse files

Merge branch 'master' of fhm-chicas-code.lancs.ac.uk:jewell/covid-pipeline

parents 9e6bb98f 95d5f5b4
"""Covid19UK pipeline"""
from summarizer import (
from covid_pipeline.summarizer import (
PosteriorFunctions,
PosteriorPredictiveFunctions,
make_summary,
......
......@@ -50,16 +50,27 @@ def _events2xarray(samples, constant_data):
],
dims=["location", "state"],
)
spatial_samples = xarray.DataArray(
samples["spatial_effect"],
coords=[
np.arange(samples["spatial_effect"].shape[0]),
constant_data.coords["location"],
],
dims=["iteration", "location"],
)
return xarray.Dataset(
{"seir": event_samples, "initial_state": initial_state}
{
"seir": event_samples,
"initial_state": initial_state,
"spatial_effect": spatial_samples,
}
)
class PosteriorMetrics:
def __init__(
self,
results_directory,
aggregate=False,
self, results_directory, aggregate=False, filter_location=None
):
"""Summarizer attaches to a `covid19uk` results directory and summarizes
the output.
......@@ -70,14 +81,27 @@ class PosteriorMetrics:
"""
self._path = Path(results_directory)
self._aggregate = True if aggregate is True else False
self._filter_location_value = filter_location
def _filter_location(self, data_array):
    """Restrict `data_array` to locations matching the configured prefix.

    :param data_array: an xarray object carrying a string-valued
                       `location` coordinate.
    :returns: `data_array` unchanged if no `filter_location` value was
              supplied at construction; otherwise the subset whose
              `location` coordinate starts with that value.
    """
    # Test the stored filter value, not this bound method: the method
    # object is never None, so checking it would make this guard dead
    # code and pass None on to `startswith` below.
    if self._filter_location_value is None:
        return data_array
    is_location = data_array.coords["location"].str.startswith(
        self._filter_location_value
    )
    return data_array.sel(location=is_location)
@property
def _constant_data(self):
def _constant_data_unfiltered(self):
"""Returns an `xarray.Dataset` of covariate data"""
return xarray.open_dataset(
str(self._path / CONSTANT_DATA), group=GROUPS["CONSTANT_DATA"]
)
# Covariate data restricted by location: reads the unfiltered constant
# data and passes it through `self._filter_location`.
@property
def _constant_data(self):
return self._filter_location(self._constant_data_unfiltered)
@property
def _observations(self):
"""Returns an `xarray.Dataset` of observation data"""
......@@ -99,11 +123,14 @@ class PosteriorFunctions(PosteriorMetrics):
"""Returns a dictionary of MCMC samples"""
with open(self._path / SAMPLES, "rb") as f:
samples = pkl.load(f)
samples = _events2xarray(samples, self._constant_data_unfiltered)
samples = self._filter_location(samples)
return samples
def _compute_state(self):
"""Reconstructs the state of the population at each timepoint"""
event_samples = _events2xarray(self._samples, self._constant_data)
event_samples = self._samples
state = gl_compute_state(
event_samples["initial_state"],
event_samples["seir"],
......@@ -139,7 +166,7 @@ class PosteriorFunctions(PosteriorMetrics):
def absolute_incidence(self):
"""Return `xarray.Dataset` with daily absolute
infection incidence samples"""
event_samples = _events2xarray(self._samples, self._constant_data)
event_samples = self._samples
infection_events = (
event_samples["seir"].sel(event=0).reset_coords(drop=True)
)
......@@ -219,9 +246,10 @@ class PosteriorPredictiveFunctions(PosteriorFunctions):
@property
def _predicted(self):
return xarray.open_dataset(
arr = xarray.open_dataset(
str(self._path / PREDICTIVE_CASES), group=GROUPS["PREDICTIVE_CASES"]
)
return self._filter_location(arr)
def _compute_state(self):
state = gl_compute_state(
......@@ -283,6 +311,7 @@ def make_summary(
summary_xarr.to_dataframe() # Format as pandas dataframe
```
"""
def fn(samples):
data_arrays = {}
if mean is True:
......
......@@ -36,7 +36,9 @@ def _events2xarray(samples, constant_data):
],
dims=["location", "state"],
)
return xarray.Dataset({"seir": event_samples, "initial_state": initial_state})
return xarray.Dataset(
{"seir": event_samples, "initial_state": initial_state}
)
def _xarray2dstl(xarr, value_type, geography):
......@@ -63,10 +65,10 @@ def _xarray2dstl(xarr, value_type, geography):
def _has_ewsni(locations):
"""Checks to see if all England, Wales, Scotland, NI DAs are
present in `locations`. Return True if so, False otherwise.
present in `locations`. Return True if so, False otherwise.
"""
country_codes = locations.astype(str).str[0]
if xarray.DataArray(['E','W','S','N']).isin(country_codes).all():
if xarray.DataArray(["E", "W", "S", "N"]).isin(country_codes).all():
return True
return False
......@@ -74,7 +76,10 @@ def _has_ewsni(locations):
def incidence(event_samples, popsize):
"""Select infection events, aggregate over location, divide by total popsize"""
infection_events = (
event_samples["seir"].sel(event=0).sum(dim="location").reset_coords(drop=True)
event_samples["seir"]
.sel(event=0)
.sum(dim="location")
.reset_coords(drop=True)
)
return infection_events
......@@ -82,7 +87,9 @@ def incidence(event_samples, popsize):
def prevalence(event_samples, popsize):
"""Prevalence in percentage units"""
state = compute_state(
event_samples["initial_state"], event_samples["seir"], model_spec.STOICHIOMETRY
event_samples["initial_state"],
event_samples["seir"],
model_spec.STOICHIOMETRY,
).numpy()
state = state[..., 1:3].sum(axis=-1).sum(axis=1) # Sum E+I and location
state = xarray.DataArray(
......@@ -113,7 +120,9 @@ def summarize_supergeography(event_samples, rt, population, geography_name):
prev_xarr = prevalence(event_samples, population)
# Rt
rt_summary = (rt["R_it"] * population / population.sum()).sum(dim="location")
rt_summary = (rt["R_it"] * population / population.sum()).sum(
dim="location"
)
df = pd.concat(
[
......@@ -132,7 +141,7 @@ def crystalcast_output(input_files, output):
:param input_files: a list of [inferencedata, thin_samples, reproduction_number]
:param output: name of output XLSX file
"""
path = Path(input_files[0]).parent
constant_data = xarray.open_dataset(input_files[0], group="constant_data")
with open(input_files[1], "rb") as f:
samples = pkl.load(f)
......@@ -141,17 +150,18 @@ def crystalcast_output(input_files, output):
# xarray-ify events tensor for ease of reduction
event_samples = _events2xarray(samples, constant_data)
# Clip off first week for initial conditions burnin
event_samples = event_samples.isel(time=slice(7, None, None))
# population
population = constant_data["N"]
# If all DAs are present, create UK summary
df = []
if _has_ewsni(rt.coords['location']):
df.append(summarize_supergeography(event_samples, rt, population, "United Kingdom"))
if _has_ewsni(rt.coords["location"]):
df.append(
summarize_supergeography(
event_samples, rt, population, "United Kingdom"
)
)
# DAs
for country in ["England", "Scotland", "Wales", "Northern Ireland"]:
regions = event_samples.coords["location"].str.startswith(country[0])
......
[tool.poetry]
name = "covid-pipeline"
version = "0.1.1-alpha.5"
version = "0.1.1-alpha.6"
description = "COVID19 daily production"
authors = ["Chris Jewell <c.jewell@lancaster.ac.uk>"]
license = "MIT"
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment