Commit 95d5f5b4 authored by Chris Jewell's avatar Chris Jewell
Browse files

Added country code filter to PosteriorMethods

parent 552ef36b
......@@ -50,16 +50,27 @@ def _events2xarray(samples, constant_data):
],
dims=["location", "state"],
)
spatial_samples = xarray.DataArray(
samples["spatial_effect"],
coords=[
np.arange(samples["spatial_effect"].shape[0]),
constant_data.coords["location"],
],
dims=["iteration", "location"],
)
return xarray.Dataset(
{"seir": event_samples, "initial_state": initial_state}
{
"seir": event_samples,
"initial_state": initial_state,
"spatial_effect": spatial_samples,
}
)
class PosteriorMetrics:
def __init__(
self,
results_directory,
aggregate=False,
self, results_directory, aggregate=False, filter_location=None
):
"""Summarizer attaches to a `covid19uk` results directory and summarizes
the output.
......@@ -70,14 +81,27 @@ class PosteriorMetrics:
"""
self._path = Path(results_directory)
self._aggregate = True if aggregate is True else False
self._filter_location_value = filter_location
def _filter_location(self, data_array):
if self._filter_location is None:
return data_array
is_location = data_array.coords["location"].str.startswith(
self._filter_location_value
)
return data_array.sel(location=is_location)
@property
def _constant_data(self):
def _constant_data_unfiltered(self):
"""Returns an `xarray.Dataset` of covariate data"""
return xarray.open_dataset(
str(self._path / CONSTANT_DATA), group=GROUPS["CONSTANT_DATA"]
)
@property
def _constant_data(self):
return self._filter_location(self._constant_data_unfiltered)
@property
def _observations(self):
"""Returns an `xarray.Dataset` of observation data"""
......@@ -99,11 +123,14 @@ class PosteriorFunctions(PosteriorMetrics):
"""Returns a dictionary of MCMC samples"""
with open(self._path / SAMPLES, "rb") as f:
samples = pkl.load(f)
samples = _events2xarray(samples, self._constant_data_unfiltered)
samples = self._filter_location(samples)
return samples
def _compute_state(self):
"""Reconstructs the state of the population a each timepoint"""
event_samples = _events2xarray(self._samples, self._constant_data)
event_samples = self._samples
state = gl_compute_state(
event_samples["initial_state"],
event_samples["seir"],
......@@ -139,7 +166,7 @@ class PosteriorFunctions(PosteriorMetrics):
def absolute_incidence(self):
"""Return `xarray.Dataset` with daily absolute
infection incidence samples"""
event_samples = _events2xarray(self._samples, self._constant_data)
event_samples = self._samples
infection_events = (
event_samples["seir"].sel(event=0).reset_coords(drop=True)
)
......@@ -219,9 +246,10 @@ class PosteriorPredictiveFunctions(PosteriorFunctions):
@property
def _predicted(self):
return xarray.open_dataset(
arr = xarray.open_dataset(
str(self._path / PREDICTIVE_CASES), group=GROUPS["PREDICTIVE_CASES"]
)
return self._filter_location(arr)
def _compute_state(self):
state = gl_compute_state(
......@@ -283,6 +311,7 @@ def make_summary(
summary_xarr.to_dataframe() # Format at pandas dataframe
```
"""
def fn(samples):
data_arrays = {}
if mean is True:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment