Commit edfb5b84 authored by Chris Jewell

Added facility for incorporating tier efficacy parameters beta3

parent d23152c1
@@ -3,10 +3,10 @@
 import argparse


-def cli_args():
+def cli_args(args=None):
     parser = argparse.ArgumentParser()
     parser.add_argument("-c", "--config", type=str, help="configuration file")
-    args = parser.parse_args()
+    args = parser.parse_args(args)
     return args
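With `args=None`, argparse falls back to `sys.argv[1:]` when no list is supplied, so the same function serves the command line and unit tests. A minimal usage sketch (the config filename is illustrative):

args = cli_args()                                     # parses sys.argv[1:]
args = cli_args(["--config", "example_config.yaml"])  # parses an explicit list
print(args.config)                                    # "example_config.yaml"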
@@ -5,7 +5,12 @@ from warnings import warn
 import numpy as np
 import pandas as pd

-__all__ = ["read_mobility", "read_population", "read_traffic_flow", "read_phe_cases"]
+__all__ = [
+    "read_mobility",
+    "read_population",
+    "read_traffic_flow",
+    "read_phe_cases",
+]


 def read_mobility(path):
@@ -19,7 +24,8 @@ def read_mobility(path):
     """
     mobility = pd.read_csv(path)
     mobility = mobility[
-        mobility["From"].str.startswith("E") & mobility["To"].str.startswith("E")
+        mobility["From"].str.startswith("E")
+        & mobility["To"].str.startswith("E")
     ]
     mobility = mobility.sort_values(["From", "To"])
     mobility = mobility.groupby(["From", "To"]).agg({"Flow": sum}).reset_index()
@@ -40,7 +46,9 @@ def read_population(path):
     return pop


-def read_traffic_flow(path: str, date_low: np.datetime64, date_high: np.datetime64):
+def read_traffic_flow(
+    path: str, date_low: np.datetime64, date_high: np.datetime64
+):
     """Read traffic flow data, returning a timeseries between dates.
     :param path: path to a traffic flow CSV with <date>,<Car> columns
     :returns: a Pandas timeseries
@@ -50,8 +58,12 @@ def read_traffic_flow(path: str, date_low: np.datetime64, date_high: np.datetime
     )
     commute_raw.index = pd.to_datetime(commute_raw.index, format="%Y-%m-%d")
     commute_raw.sort_index(axis=0, inplace=True)
-    commute = pd.DataFrame(index=np.arange(date_low, date_high, np.timedelta64(1, "D")))
-    commute = commute.merge(commute_raw, left_index=True, right_index=True, how="left")
+    commute = pd.DataFrame(
+        index=np.arange(date_low, date_high, np.timedelta64(1, "D"))
+    )
+    commute = commute.merge(
+        commute_raw, left_index=True, right_index=True, how="left"
+    )
     commute[commute.index < commute_raw.index[0]] = commute_raw.iloc[0, 0]
     commute[commute.index > commute_raw.index[-1]] = commute_raw.iloc[-1, 0]
     commute["Cars"] = commute["Cars"] / 100.0
@@ -87,14 +99,18 @@ def read_phe_cases(
     line_listing["lad19cd"] = _merge_ltla(line_listing["lad19cd"])

     # Select dates
-    line_listing["date"] = pd.to_datetime(line_listing["date"], format="%d/%m/%Y")
+    line_listing["date"] = pd.to_datetime(
+        line_listing["date"], format="%d/%m/%Y"
+    )
     line_listing = line_listing[
         (date_low <= line_listing["date"]) & (line_listing["date"] < date_high)
     ]

     # Choose pillar
     if pillar_map[pillar] is not None:
-        line_listing = line_listing.loc[line_listing["pillar"] == pillar_map[pillar]]
+        line_listing = line_listing.loc[
+            line_listing["pillar"] == pillar_map[pillar]
+        ]

     # Drop na rows
     orig_len = line_listing.shape[0]
@@ -112,8 +128,45 @@ due to missing values ({100. * (orig_len - line_listing.shape[0])/orig_len}%)"
     dates = pd.date_range(date_low, date_high, closed="left")
     if ltlas is None:
         ltlas = case_counts.index.levels[1]
-    index = pd.MultiIndex.from_product([dates, ltlas], names=["date", "lad19cd"])
+    index = pd.MultiIndex.from_product(
+        [dates, ltlas], names=["date", "lad19cd"]
+    )
     case_counts = case_counts.reindex(index, fill_value=0)
     return case_counts.reset_index().pivot(
         index="lad19cd", columns="date", values="count"
     )
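Re-indexing against the full (date, LTLA) product inserts explicit zeros for area-days with no reported cases, and the pivot returns an LTLA-by-date matrix. A toy version with a single observed count:

import pandas as pd

dates = pd.date_range("2020-03-01", periods=2)
ltlas = ["E06000001", "E06000002"]
case_counts = pd.Series(
    [5],
    name="count",
    index=pd.MultiIndex.from_tuples(
        [(dates[0], "E06000001")], names=["date", "lad19cd"]
    ),
)
index = pd.MultiIndex.from_product([dates, ltlas], names=["date", "lad19cd"])
case_counts = case_counts.reindex(index, fill_value=0)
wide = case_counts.reset_index().pivot(
    index="lad19cd", columns="date", values="count"
)
print(wide.shape)   # (2, 2): one row per LTLA, one column per day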
+
+
+def read_tier_restriction_data(
+    tier_restriction_csv, lad19cd_lookup, date_low, date_high
+):
+    data = pd.read_csv(tier_restriction_csv)
+
+    # Group merged ltlas
+    london = ["City of London", "Westminster"]
+    corn_scilly = ["Cornwall", "Isles of Scilly"]
+    data.loc[data["ltla"].isin(london), "ltla"] = ":".join(london)
+    data.loc[data["ltla"].isin(corn_scilly), "ltla"] = ":".join(corn_scilly)
+
+    # Fix up dodgy names
+    data.loc[
+        data["ltla"] == "Blackburn With Darwen", "ltla"
+    ] = "Blackburn with Darwen"
+
+    # Merge
+    data = lad19cd_lookup.merge(
+        data, how="left", left_on="lad19nm", right_on="ltla"
+    )
+
+    # Re-index
+    data.index = pd.MultiIndex.from_frame(data[["date", "lad19cd"]])
+    data = data[["tier_2", "tier_3"]]
+    data = data[~data.index.duplicated()]
+    dates = pd.date_range(date_low, date_high - pd.Timedelta(1, "D"))
+    lad19cd = lad19cd_lookup["lad19cd"].sort_values().unique()
+    new_index = pd.MultiIndex.from_product([dates, lad19cd])
+    data = data.reindex(new_index, fill_value=0.0)
+
+    # Pack into [T, M, V] array.
+    arr_data = data.to_xarray().to_array()
+    return np.transpose(arr_data, axes=[1, 2, 0])
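`to_xarray().to_array()` stacks the `tier_2` and `tier_3` columns into a leading "variable" axis over the (date, lad19cd) grid, and the transpose reorders the result to [T, M, V] = (time, metapopulation, variable). A self-contained sketch with dummy tier data (requires xarray):

import numpy as np
import pandas as pd

dates = pd.date_range("2020-10-01", periods=3)   # T = 3
lads = ["E06000001", "E06000002"]                # M = 2
idx = pd.MultiIndex.from_product([dates, lads], names=["date", "lad19cd"])
df = pd.DataFrame({"tier_2": 0.0, "tier_3": 0.0}, index=idx)

arr = df.to_xarray().to_array()          # dims: (variable=2, date=3, lad19cd=2)
out = np.transpose(arr, axes=[1, 2, 0])  # reorder to (date, lad19cd, variable)
print(out.shape)                         # (3, 2, 2) == [T, M, V]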
@@ -102,18 +102,19 @@ if __name__ == "__main__":
         initial_state=initial_state,
         initial_step=0,
         num_steps=events.shape[1],
-        priors=convert_priors(config['mcmc']['prior']),
+        priors=convert_priors(config["mcmc"]["prior"]),
     )

     # Full joint log posterior distribution
     # $\pi(\theta, \xi, y^{se}, y^{ei} | y^{ir})$
-    def logp(theta, xi, events):
+    def logp(block0, block1, events):
         return model.log_prob(
             dict(
-                beta1=xi[0],
-                beta2=theta[0],
-                gamma=theta[1],
-                xi=xi[1:],
+                beta1=block1[0],
+                beta2=block0[0],
+                beta3=block1[1:3],
+                gamma=block0[1],
+                xi=block1[3:],
                 seir=events,
             )
         )
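The Gibbs blocks therefore split as block0 = [beta2, gamma] (positive-constrained) and block1 = [beta1, beta3, xi] (unconstrained), with the two new tier-efficacy coefficients at indices 1:3 of block1. A quick shape check with dummy values (the knot count K is illustrative):

import numpy as np

K = 10                            # illustrative number of xi knots
block0 = np.array([0.65, 0.48])   # [beta2, gamma]
block1 = np.zeros(K + 3)          # [beta1, beta3_0, beta3_1, xi_0..xi_{K-1}]

beta1 = block1[0]
beta3 = block1[1:3]               # two tier-efficacy coefficients
gamma = block0[1]
xi = block1[3:]
assert beta3.shape == (2,) and xi.shape == (K,)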
@@ -127,14 +128,13 @@ if __name__ == "__main__":
     # Q(Z^{ei}, Z^{ei\prime}) (partially-censored)
     # Q(Z^{se}, Z^{se\prime}) (occult)
     # Q(Z^{ei}, Z^{ei\prime}) (occult)
-    def make_theta_kernel(shape, name):
+    def make_blk0_kernel(shape, name):
         def fn(target_log_prob_fn, _):
             return tfp.mcmc.TransformedTransitionKernel(
                 inner_kernel=AdaptiveRandomWalkMetropolis(
                     target_log_prob_fn=target_log_prob_fn,
-                    initial_covariance=[
-                        np.eye(shape[0], dtype=model_spec.DTYPE) * 1e-1
-                    ],
+                    initial_covariance=np.eye(shape[0], dtype=model_spec.DTYPE)
+                    * 1e-1,
                     covariance_burnin=200,
                 ),
                 bijector=tfp.bijectors.Exp(),
@@ -143,13 +143,12 @@ if __name__ == "__main__":
         return fn

-    def make_xi_kernel(shape, name):
+    def make_blk1_kernel(shape, name):
         def fn(target_log_prob_fn, _):
             return AdaptiveRandomWalkMetropolis(
                 target_log_prob_fn=target_log_prob_fn,
-                initial_covariance=[
-                    np.eye(shape[0], dtype=model_spec.DTYPE) * 1e-1
-                ],
+                initial_covariance=np.eye(shape[0], dtype=model_spec.DTYPE)
+                * 1e-1,
                 covariance_burnin=200,
                 name=name,
             )
@@ -235,7 +234,7 @@ if __name__ == "__main__":
         return recurse(f, results)

     # Build MCMC algorithm here. This will be run in bursts for memory economy
-    @tf.function(autograph=False, experimental_compile=True)
+    @tf.function  # (autograph=False, experimental_compile=True)
     def sample(n_samples, init_state, previous_results=None):
         with tf.name_scope("main_mcmc_sample_loop"):
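Dropping `experimental_compile=True` trades XLA compilation for plain graph execution. The function is still intended to run in bursts, threading results from one burst into the next so the adaptive kernels keep their state. A hedged driver sketch, assuming `sample` returns a (samples, results) pair as is conventional for TFP sampling loops:

# Hypothetical burst loop; burst count and size are illustrative.
results = None
state = current_state
for burst in range(10):
    samples, results = sample(
        n_samples=100, init_state=state, previous_results=results
    )
    # ... persist `samples` to disk here and update `state` from the
    # last draw, keeping only one burst in memory at a time ...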
@@ -244,8 +243,8 @@ if __name__ == "__main__":
             gibbs_schema = GibbsKernel(
                 target_log_prob_fn=logp,
                 kernel_list=[
-                    (0, make_theta_kernel(init_state[0].shape, "theta")),
-                    (1, make_xi_kernel(init_state[1].shape, "xi")),
+                    (0, make_blk0_kernel(init_state[0].shape, "theta")),
+                    (1, make_blk1_kernel(init_state[1].shape, "xi")),
                     (2, make_event_multiscan_kernel),
                 ],
                 name="gibbs0",
@@ -277,7 +276,7 @@ if __name__ == "__main__":
     current_state = [
         np.array([0.65, 0.48], dtype=DTYPE),
-        np.zeros(model.model["xi"](0.0).event_shape[-1] + 1, dtype=DTYPE),
+        np.zeros(model.model["xi"](0.0).event_shape[-1] + 3, dtype=DTYPE),
         events,
     ]
 """Implements the COVID SEIR model as a TFP Joint Distribution"""

+import geopandas as gp
 import numpy as np
 import tensorflow as tf
 import tensorflow_probability as tfp
@@ -11,7 +12,7 @@ import covid.data as data
 tfd = tfp.distributions

 DTYPE = np.float64

-STOICHIOMETRY = tf.constant([[-1, 1, 0, 0], [0, -1, 1, 0], [0, 0, -1, 1]])
+STOICHIOMETRY = np.array([[-1, 1, 0, 0], [0, -1, 1, 0], [0, 0, -1, 1]])
 TIME_DELTA = 1.0
 XI_FREQ = 14  # baseline transmission changes every 14 days
 NU = tf.constant(0.5, dtype=DTYPE)  # E->I rate assumed known.
@@ -31,10 +32,19 @@ def read_covariates(paths, date_low, date_high):
         paths["commute_volume"], date_low=date_low, date_high=date_high
     )
+    geo = gp.read_file(paths["geopackage"])
+    geo = geo.loc[geo["lad19cd"].str.startswith("E")]
+    tier_restriction = data.read_tier_restriction_data(
+        paths["tier_restriction_csv"],
+        geo[["lad19cd", "lad19nm"]],
+        date_low,
+        date_high,
+    )
     return dict(
         C=mobility.to_numpy().astype(DTYPE),
         W=commute_volume.to_numpy().astype(DTYPE),
         N=popsize.to_numpy().astype(DTYPE),
+        L=tier_restriction.astype(DTYPE),
    )
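`read_covariates` now also expects a `tier_restriction_csv` entry in `paths` and returns an extra `L` array of shape [T, M, 2]. A hedged call sketch; apart from the three keys visible above, the path keys and file names are placeholders:

covar_data = read_covariates(
    dict(
        mobility="mobility.csv",           # placeholder key and file name
        population="population.csv",       # placeholder key and file name
        commute_volume="traffic_flow.csv",
        geopackage="uk_lads.gpkg",         # placeholder file name
        tier_restriction_csv="tier_restrictions.csv",
    ),
    date_low=np.datetime64("2020-08-01"),
    date_high=np.datetime64("2020-10-31"),
)
assert covar_data["L"].shape[-1] == 2   # tier_2 and tier_3 indicators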
@@ -72,6 +82,15 @@ def CovidUK(covariates, initial_state, initial_step, num_steps, priors):
         rate=tf.constant(10.0, dtype=DTYPE),
     )

+    def beta3():
+        return tfd.Sample(
+            tfd.Normal(
+                loc=tf.constant(0.0, dtype=DTYPE),
+                scale=tf.constant(1000.0, dtype=DTYPE),
+            ),
+            sample_shape=2,
+        )
+
     def xi(beta1):
         sigma = tf.constant(0.1, dtype=DTYPE)
         phi = tf.constant(24.0, dtype=DTYPE)
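The beta3 prior added above is two independent, effectively flat Normal(0, 1000) coefficients, one per tier-indicator column of L. A standalone check of its event shape:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
DTYPE = np.float64

beta3_prior = tfd.Sample(
    tfd.Normal(
        loc=tf.constant(0.0, dtype=DTYPE),
        scale=tf.constant(1000.0, dtype=DTYPE),
    ),
    sample_shape=2,
)
print(beta3_prior.event_shape)   # (2,) -- one coefficient per tier column
print(beta3_prior.sample())      # two draws from the wide Normal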
@@ -91,12 +110,16 @@ def CovidUK(covariates, initial_state, initial_step, num_steps, priors):
         rate=tf.constant(priors["gamma"]["rate"], dtype=DTYPE),
     )

-    def seir(beta2, xi, gamma):
+    def seir(beta2, beta3, xi, gamma):
         beta2 = tf.convert_to_tensor(beta2, DTYPE)
+        beta3 = tf.convert_to_tensor(beta3, DTYPE)
         xi = tf.convert_to_tensor(xi, DTYPE)
         gamma = tf.convert_to_tensor(gamma, DTYPE)

+        L = tf.convert_to_tensor(covariates["L"], DTYPE)
+        L = L - tf.reduce_mean(L, axis=0)
+
         def transition_rate_fn(t, state):
             C = tf.convert_to_tensor(covariates["C"], dtype=DTYPE)
             C = tf.linalg.set_diag(
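Centering L along the time axis removes each LAD's average tier exposure, so beta3 is identified by within-area changes in tier status rather than by level differences that other terms in the model absorb. A numeric illustration of the centering:

import numpy as np

# Toy L: T=4 days, M=1 LAD, V=2 tier indicators; the LAD enters
# tier 2 for the last two days and never enters tier 3.
L = np.array([[[0.0, 0.0]], [[0.0, 0.0]], [[1.0, 0.0]], [[1.0, 0.0]]])
Lc = L - L.mean(axis=0)
print(Lc[..., 0].ravel())   # [-0.5 -0.5  0.5  0.5] -- centered tier-2 signal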
@@ -113,7 +136,11 @@ def CovidUK(covariates, initial_state, initial_step, num_steps, priors):
             )
             xi_ = tf.gather(xi, xi_idx)

-            infec_rate = tf.math.exp(xi_) * (
+            L_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, L.shape[0] - 1)
+            Lt = tf.gather(L, L_idx)
+            xB = tf.linalg.matvec(Lt, beta3)
+
+            infec_rate = tf.math.exp(xi_ + xB) * (
                 state[..., 2]
                 + beta2
                 * commute_volume
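In model terms, the change multiplies the baseline hazard by $\exp(L_i(t)^\top \beta_3)$. Allowing for the terms truncated from the hunk above, the per-LAD infection pressure has roughly the shape (a hedged reconstruction, not the verbatim code):

$\lambda_i(t) \propto \exp\big(\xi(t) + L_i(t)^\top \beta_3\big)\Big(I_i(t) + \beta_2 \, w(t) \sum_j \bar{C}_{ij} I_j(t)\Big)$

where $I_i(t)$ is state[..., 2], $w(t)$ is the commute volume, and $\bar{C}$ is the mobility matrix with adjusted diagonal.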
@@ -142,27 +169,10 @@ def CovidUK(covariates, initial_state, initial_step, num_steps, priors):
         )

     return tfd.JointDistributionNamed(
-        dict(beta1=beta1, beta2=beta2, xi=xi, gamma=gamma, seir=seir)
+        dict(
+            beta1=beta1, beta2=beta2, beta3=beta3, xi=xi, gamma=gamma, seir=seir
+        )
     )
-
-
-def marginalized_log_prob(model):
-    """Joint log_prob function with baseline hazard
-    rates marginalized out.
-    """
-
-    def log_prob(beta2, xi, seir):
-        lp_beta2 = model.model.modules["beta"].log_prob(beta2)
-        lp_xi = model.model.modules["xi"].log_prob(xi)
-        seir_marginal = DiscreteTimeStateTransitionMarginalModel(
-            *model.model.modules["seir"]._parameters
-        )
-        lp_seir_marginal = seir_marginal(seir)
-        return lp_beta2 + lp_xi + lp_seir_marginal


 def next_generation_matrix_fn(covar_data, param):