Commit 9e3a9e78 authored by Chris Jewell
Browse files

Fixed bug in setup of tiers data structure

parent 8423f467
...@@ -69,38 +69,39 @@ class TierData: ...@@ -69,38 +69,39 @@ class TierData:
""" """
tiers["date"] = pd.to_datetime(tiers["date"], format="%Y-%m-%d") tiers["date"] = pd.to_datetime(tiers["date"], format="%Y-%m-%d")
tiers["lad19cd"] = merge_lad_codes(tiers["areaCode"]) tiers["lad19cd"] = merge_lad_codes(tiers["areaCode"])
tiers["alert_level"] = tiers["alertLevel"] tiers["alert_level"] = tiers["alertLevel"].astype(int)
tiers = tiers[["date", "lad19cd", "alert_level"]] tiers = tiers[["date", "lad19cd", "alert_level"]]
if len(lads) > 0: if len(lads) > 0:
tiers = tiers[tiers["lad19cd"].isin(lads)] tiers = tiers[tiers["lad19cd"].isin(lads)]
tiers = tiers.drop_duplicates()
date_range = pd.date_range(date_low, date_high - np.timedelta64(1, "D")) tiers_wide = tiers.pivot(
index="date", columns="lad19cd", values="alert_level"
def interpolate(df): )
df.index = pd.Index(pd.to_datetime(df["date"]), name="date") tiers_wide = tiers_wide.sort_index()
df = df.drop(columns="date").sort_index()
df = df.reindex(date_range)
df["alert_level"] = (
df["alert_level"].ffill().backfill().astype("int")
)
return df[["alert_level"]]
tiers = tiers.groupby(["lad19cd"]).apply(interpolate) # Fill in time gaps
date_range = pd.date_range(date_low, date_high - np.timedelta64(1, "D"))
tiers_wide = tiers_wide.reindex(
pd.Index(date_range, name="date"), method="ffill"
)
tiers_wide = tiers_wide.backfill()
tiers = tiers_wide.melt(
value_name="alert_level", var_name="lad19cd", ignore_index=False
)
tiers = tiers.reset_index() tiers = tiers.reset_index()
tiers.columns = ["lad19cd", "date", "alert_level"] tiers["alert_level"] = tiers["alert_level"].astype(np.int32)
# Convert to xarray to create table of factors
index = pd.MultiIndex.from_frame(tiers) index = pd.MultiIndex.from_frame(tiers)
index = index.sort_values() index = index.sort_values()
index = index[~index.duplicated()] ser = pd.Series(1, index=index, name="value").astype(int)
ser = pd.Series(1, index=index, name="value")
ser = ser.loc[
pd.IndexSlice[:, date_low : (date_high - np.timedelta64(1, "D")), :]
]
xarr = ser.to_xarray() xarr = ser.to_xarray()
xarr.data[np.isnan(xarr.data)] = 0.0 xarr = xarr.fillna(0.0)
# return [T, M, V] structure # return [date, location, alert_level] structure
return xarr.transpose("date", "lad19cd", "alert_level") return xarr
def adapt_xarray(tiers, date_low, date_high, lads, settings): def adapt_xarray(tiers, date_low, date_high, lads, settings):
""" """
......
...@@ -36,18 +36,13 @@ def gather_data(config): ...@@ -36,18 +36,13 @@ def gather_data(config):
) )
locations = data.AreaCodeData.process(config) locations = data.AreaCodeData.process(config)
tier_restriction = data.TierData.process(config)[:, :, 2:] tier_restriction = data.TierData.process(config)[:, :, [0, 2, 3, 4, 5]]
date_range = [date_low, date_high] date_range = [date_low, date_high]
weekday = pd.date_range(date_low, date_high).weekday < 5 weekday = (
pd.date_range(date_low, date_high - np.timedelta64(1, "D")).weekday < 5
)
cases = data.CasesData.process(config).to_xarray() cases = data.CasesData.process(config).to_xarray()
# cases = data.read_phe_cases(
# config['reported_cases'],
# date_low,
# date_high,
# pillar=config['pillar'],
# date_type=config['case_date_type'],
# )
return dict( return dict(
C=mobility.to_numpy().astype(DTYPE), C=mobility.to_numpy().astype(DTYPE),
W=commute_volume.to_numpy().astype(DTYPE), W=commute_volume.to_numpy().astype(DTYPE),
...@@ -107,8 +102,10 @@ def CovidUK(covariates, initial_state, initial_step, num_steps): ...@@ -107,8 +102,10 @@ def CovidUK(covariates, initial_state, initial_step, num_steps):
def beta3(): def beta3():
return tfd.Independent( return tfd.Independent(
tfd.Normal( tfd.Normal(
loc=tf.constant([0.0] * 4, dtype=DTYPE), loc=tf.constant([0.0] * covariates["L"].shape[-1], dtype=DTYPE),
scale=tf.constant([1.0] * 4, dtype=DTYPE), scale=tf.constant(
[1.0] * covariates["L"].shape[-1], dtype=DTYPE
),
), ),
reinterpreted_batch_ndims=1, reinterpreted_batch_ndims=1,
) )
......
...@@ -101,7 +101,7 @@ def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True): ...@@ -101,7 +101,7 @@ def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True):
tfp.bijectors.Exp(), tfp.bijectors.Exp(),
tfp.bijectors.Identity(), tfp.bijectors.Identity(),
], ],
block_sizes=[1, 2, 1, 4], block_sizes=[1, 2, 1, 5],
), ),
name=name, name=name,
) )
...@@ -252,7 +252,7 @@ def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True): ...@@ -252,7 +252,7 @@ def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True):
tf.random.set_seed(2) tf.random.set_seed(2)
current_state = [ current_state = [
np.array([0.6, 0.0, 0.0, 0.1, 0.0, 0.0, 0.0, 0.0], dtype=DTYPE), np.array([0.6, 0.0, 0.0, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=DTYPE),
np.zeros( np.zeros(
model.model["xi"](0.0, 0.1).event_shape[-1] + 1, model.model["xi"](0.0, 0.1).event_shape[-1] + 1,
dtype=DTYPE, dtype=DTYPE,
......
"""Calculates and saves a next generation matrix""" """Calculates and saves a next generation matrix"""
import numpy as np
import pickle as pkl import pickle as pkl
import numpy as np
import xarray import xarray
import tensorflow as tf import tensorflow as tf
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment