Commit 552ef36b authored by Chris Jewell's avatar Chris Jewell
Browse files

Removed first week clipping for CrystalCast incidence and prevalence

parent 5483515b
......@@ -36,7 +36,9 @@ def _events2xarray(samples, constant_data):
],
dims=["location", "state"],
)
return xarray.Dataset({"seir": event_samples, "initial_state": initial_state})
return xarray.Dataset(
{"seir": event_samples, "initial_state": initial_state}
)
def _xarray2dstl(xarr, value_type, geography):
......@@ -63,10 +65,10 @@ def _xarray2dstl(xarr, value_type, geography):
def _has_ewsni(locations):
"""Checks to see if all England, Wales, Scotland, NI DAs are
present in `locations`. Return True if so, False otherwise.
present in `locations`. Return True if so, False otherwise.
"""
country_codes = locations.astype(str).str[0]
if xarray.DataArray(['E','W','S','N']).isin(country_codes).all():
if xarray.DataArray(["E", "W", "S", "N"]).isin(country_codes).all():
return True
return False
......@@ -74,7 +76,10 @@ def _has_ewsni(locations):
def incidence(event_samples, popsize):
"""Select infection events, aggregate over location, divide by total popsize"""
infection_events = (
event_samples["seir"].sel(event=0).sum(dim="location").reset_coords(drop=True)
event_samples["seir"]
.sel(event=0)
.sum(dim="location")
.reset_coords(drop=True)
)
return infection_events
......@@ -82,7 +87,9 @@ def incidence(event_samples, popsize):
def prevalence(event_samples, popsize):
"""Prevalence in percentage units"""
state = compute_state(
event_samples["initial_state"], event_samples["seir"], model_spec.STOICHIOMETRY
event_samples["initial_state"],
event_samples["seir"],
model_spec.STOICHIOMETRY,
).numpy()
state = state[..., 1:3].sum(axis=-1).sum(axis=1) # Sum E+I and location
state = xarray.DataArray(
......@@ -113,7 +120,9 @@ def summarize_supergeography(event_samples, rt, population, geography_name):
prev_xarr = prevalence(event_samples, population)
# Rt
rt_summary = (rt["R_it"] * population / population.sum()).sum(dim="location")
rt_summary = (rt["R_it"] * population / population.sum()).sum(
dim="location"
)
df = pd.concat(
[
......@@ -132,7 +141,7 @@ def crystalcast_output(input_files, output):
:param input_files: a list of [inferencedata, thin_samples, reproduction_number]
:param output: name of output XLSX file
"""
path = Path(input_files[0]).parent
constant_data = xarray.open_dataset(input_files[0], group="constant_data")
with open(input_files[1], "rb") as f:
samples = pkl.load(f)
......@@ -141,17 +150,18 @@ def crystalcast_output(input_files, output):
# xarray-ify events tensor for ease of reduction
event_samples = _events2xarray(samples, constant_data)
# Clip off first week for initial conditions burnin
event_samples = event_samples.isel(time=slice(7, None, None))
# population
population = constant_data["N"]
# If all DAs are present, create UK summary
df = []
if _has_ewsni(rt.coords['location']):
df.append(summarize_supergeography(event_samples, rt, population, "United Kingdom"))
if _has_ewsni(rt.coords["location"]):
df.append(
summarize_supergeography(
event_samples, rt, population, "United Kingdom"
)
)
# DAs
for country in ["England", "Scotland", "Wales", "Northern Ireland"]:
regions = event_samples.coords["location"].str.startswith(country[0])
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment