Commit 5974ccac authored by Chris Jewell's avatar Chris Jewell
Browse files

Bugfixes in PosteriorFunctions and PosteriorPredictiveFunctions

parent 5557139a
......@@ -25,8 +25,6 @@ GROUPS = {
"OBSERVATION_DATA": "observations",
"REPRODUCTION_NUMBER": "posterior_predictive",
"PREDICTIVE_CASES": "predictions",
"INSAMPLE7": "predictions",
"INSAMPLE14": "predictions",
"SAMPLES": None,
"WITHIN_BETWEEN": None,
}
......@@ -88,7 +86,7 @@ class PosteriorMetrics:
self._filter_location_value = filter_location
def _filter_location(self, data_array):
if self._filter_location is None:
if self._filter_location_value is None:
return data_array
is_location = data_array.coords["location"].str.startswith(
self._filter_location_value
......@@ -206,48 +204,30 @@ class PosteriorFunctions(PosteriorMetrics):
return abs_incidence / population_size
def case_exceedance(self):
"""Calculates the probability of observed cases exceeding the
"""Calculates the probability of observed cases exceeding the
predicted I->R events in the last 7 and 14 days of the analysis
time window.
:param lag: Either 7 or 14 denoting the required lag
:returns: `xarray.Dataset` containing 7 and 14 day exceedance
"""
observed7 = self._maybe_aggregate(
self._observations["cases"]
.isel(time=slice(-7, None))
.sum(dim="time")
)
observed14 = self._maybe_aggregate(
self._observations["cases"]
.isel(time=slice(-14, None))
.sum(dim="time")
)
predicted7 = (
xarray.open_dataset(
self._path / INSAMPLE7, group=GROUPS["INSAMPLE7"]
)["events"]
.isel(time=slice(-7, None))
.sum(dim="time")
)
predicted14 = (
xarray.open_dataset(
self._path / INSAMPLE14, group=GROUPS["INSAMPLE14"]
)["events"]
.isel(time=slice(-14, None))
.sum(dim="time")
)
darrays = {}
for dsrc, lag in zip([INSAMPLE7, INSAMPLE14], [7, 14]):
times = self._observations.coords["time"][-lag:]
observed = self._maybe_aggregate(
self._observations["cases"].sel(time=times).sum(dim="time")
)
return xarray.Dataset(
{
"Pr(pred<obs)_7": (predicted7 < observed7).mean(
dim="iteration"
),
"Pr(pred<obs)_14": (predicted14 < observed14).mean(
dim="iteration"
),
}
)
predicted = (
xarray.open_dataset(
self._path / dsrc, group=GROUPS["PREDICTIVE_CASES"]
)["events"]
.sel(time=times, event=2)
.sum(dim="time")
)
darrays[f"Pr(pred<obs)_{lag}"] = (predicted < observed).mean(
dim="iteration"
)
return xarray.Dataset(darrays).reset_coords(drop=True)
def spatial_mean(self):
"""Returns the mean of the spatial random effect"""
......@@ -296,34 +276,39 @@ class PosteriorPredictiveFunctions(PosteriorFunctions):
return self._filter_location(arr)
def _compute_state(self):
predicted = self._predicted
state = gl_compute_state(
self._predicted["initial_state"],
self._predicted["events"],
predicted["initial_state"],
predicted["events"],
model_spec.STOICHIOMETRY,
).numpy()
return xarray.DataArray(
state,
coords=[
np.arange(state.shape[0]),
self._constant_data.coords["location"],
self._constant_data.coords["time"],
predicted.coords["location"],
predicted.coords["time"],
np.arange(state.shape[-1]),
],
dims=["iteration", "location", "time", "state"],
)
def absolute_incidence(self):
def absolute_incidence(self, event=2):
"""Returns absolute predicted absolute incidence"""
return self._maybe_aggregate(self._predicted["events"].sel(event=0))
return self._maybe_aggregate(
self._predicted["events"].sel(event=event)
).reset_coords(drop=True)
def cumulative_absolute_incidence(self):
def cumulative_absolute_incidence(self, event=2):
"""Returns predicted cumulative absolute incidence"""
return self._maybe_aggregate(self.absolute_incidence.cumsum(dim="time"))
return self._maybe_aggregate(
self.absolute_incidence(event).cumsum(dim="time")
)
def relative_incidence(self):
def relative_incidence(self, event=2):
"""Returns predicted relative incidence"""
population_size = self._maybe_aggregate(self._constant_data["N"])
return self.absolute_incidence() / population_size
return self.absolute_incidence(event) / population_size
def prevalence(self):
state = self._compute_state()
......@@ -334,7 +319,9 @@ class PosteriorPredictiveFunctions(PosteriorFunctions):
def make_summary(
mean=True, quantiles=(0.025, 0.975), dim="iteration",
mean=True,
quantiles=(0.025, 0.975),
dim="iteration",
):
"""Make a summarisation function
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment