Commit 3462054d authored by Chris Jewell's avatar Chris Jewell
Browse files

Add area to transmission model

parent 73fcac5c
"""Implements the COVID SEIR model as a TFP Joint Distribution"""
import pandas as pd
import geopandas as gp
import numpy as np
import xarray
import tensorflow as tf
import tensorflow_probability as tfp
from gemlib.distributions import DiscreteTimeStateTransitionModel
from gemlib.distributions import BrownianMotion
from covid.util import impute_previous_cases
import covid.data as data
tfd = tfp.distributions
VERSION = "0.5.0"
VERSION = "0.7.0"
DTYPE = np.float64
STOICHIOMETRY = np.array([[-1, 1, 0, 0], [0, -1, 1, 0], [0, 0, -1, 1]])
TIME_DELTA = 1.0
XI_FREQ = 14 # baseline transmission changes every 14 days
NU = tf.constant(0.28, dtype=DTYPE) # E->I rate assumed known.
......@@ -42,6 +44,14 @@ def gather_data(config):
commute_volume = data.read_traffic_flow(
config["commute_volume"], date_low=date_low, date_high=date_high
)
geo = gp.read_file(config["geopackage"])
geo = geo.sort_values("lad19cd")
area = xarray.DataArray(
geo.area,
name="area",
dims=["location"],
coords=[geo["lad19cd"]],
)
# tier_restriction = data.TierData.process(config)[:, :, [0, 2, 3, 4]]
dates = pd.date_range(*config["date_range"], closed="left")
......@@ -60,6 +70,7 @@ def gather_data(config):
W=commute_volume.astype(DTYPE),
N=popsize.astype(DTYPE),
weekday=weekday.astype(DTYPE),
area=area.astype(DTYPE),
locations=xarray.DataArray(
locations["name"],
dims=["location"],
......@@ -103,32 +114,27 @@ def conditional_gp(gp, observations, new_index_points):
def CovidUK(covariates, initial_state, initial_step, num_steps):
def beta1():
def alpha_0():
    """Prior on the initial baseline transmission parameter: Normal(0, 10)."""
    mean = tf.constant(0.0, dtype=DTYPE)
    sd = tf.constant(10.0, dtype=DTYPE)
    return tfd.Normal(loc=mean, scale=sd)
def beta_area():
    """Prior on the log-area regression coefficient: a standard Normal."""
    zero = tf.constant(0.0, dtype=DTYPE)
    one = tf.constant(1.0, dtype=DTYPE)
    return tfd.Normal(loc=zero, scale=one)
def beta2():
def psi():
    """Prior on the commute-volume effect: Gamma(concentration=3, rate=10)."""
    conc = tf.constant(3.0, dtype=DTYPE)
    rate = tf.constant(10.0, dtype=DTYPE)
    return tfd.Gamma(concentration=conc, rate=rate)
def sigma():
    """Prior Gamma(concentration=20, rate=200) — concentrated near 0.1."""
    conc = tf.constant(20.0, dtype=DTYPE)
    rate = tf.constant(200.0, dtype=DTYPE)
    return tfd.Gamma(concentration=conc, rate=rate)
def xi(beta1, sigma):
phi = tf.constant(24.0, dtype=DTYPE)
kernel = tfp.math.psd_kernels.MaternThreeHalves(sigma, phi)
idx_pts = tf.cast(tf.range(num_steps // XI_FREQ) * XI_FREQ, dtype=DTYPE)
return tfd.GaussianProcessRegressionModel(
kernel,
mean_fn=lambda idx: beta1,
index_points=idx_pts[:, tf.newaxis],
def alpha_t(alpha_0):
    """Brownian-motion random walk for the time-varying transmission
    parameter, started at ``alpha_0`` with innovation scale 0.01,
    indexed on the model's discrete time steps."""
    times = tf.range(num_steps, dtype=DTYPE)
    return BrownianMotion(times, x0=alpha_0, scale=0.01)
def gamma0():
......@@ -143,9 +149,10 @@ def CovidUK(covariates, initial_state, initial_step, num_steps):
scale=tf.constant(100.0, dtype=DTYPE),
)
def seir(beta2, xi, gamma0, gamma1):
beta2 = tf.convert_to_tensor(beta2, DTYPE)
xi = tf.convert_to_tensor(xi, DTYPE)
def seir(psi, beta_area, alpha_0, alpha_t, gamma0, gamma1):
psi = tf.convert_to_tensor(psi, DTYPE)
beta_area = tf.convert_to_tensor(beta_area, DTYPE)
alpha_t = tf.convert_to_tensor(alpha_t, DTYPE)
gamma0 = tf.convert_to_tensor(gamma0, DTYPE)
gamma1 = tf.convert_to_tensor(gamma1, DTYPE)
......@@ -161,24 +168,40 @@ def CovidUK(covariates, initial_state, initial_step, num_steps):
weekday = tf.convert_to_tensor(covariates["weekday"], DTYPE)
weekday = weekday - tf.reduce_mean(weekday, axis=-1)
# Area in 100km^2
area = tf.convert_to_tensor(covariates["area"], DTYPE)
log_area = tf.math.log(area / 100000000.0) # log area in 100km^2
log_area = log_area - tf.reduce_mean(log_area)
def transition_rate_fn(t, state):
w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, W.shape[0] - 1)
commute_volume = tf.gather(W, w_idx)
xi_idx = tf.cast(
tf.clip_by_value(t // XI_FREQ, 0, xi.shape[0] - 1),
dtype=tf.int64,
)
xi_ = tf.gather(xi, xi_idx)
weekday_idx = tf.clip_by_value(
tf.cast(t, tf.int64), 0, weekday.shape[0] - 1
)
weekday_t = tf.gather(weekday, weekday_idx)
infec_rate = tf.math.exp(xi_) * (
with tf.name_scope("Pick_alpha_t"):
alpha_t_idx = tf.cast(t, tf.int64)
alpha_t_ = tf.where(
alpha_t_idx == initial_step,
alpha_0,
tf.gather(
alpha_t,
tf.clip_by_value(
alpha_t_idx - initial_step - 1,
clip_value_min=0,
clip_value_max=alpha_t.shape[0] - 1,
),
),
)
eta = alpha_t_ + beta_area * log_area
infec_rate = tf.math.exp(eta) * (
state[..., 2]
+ beta2
+ psi
* commute_volume
* tf.linalg.matvec(Cstar, state[..., 2] / tf.squeeze(N))
)
......@@ -207,10 +230,10 @@ def CovidUK(covariates, initial_state, initial_step, num_steps):
return tfd.JointDistributionNamed(
dict(
beta1=beta1,
beta2=beta2,
sigma=sigma,
xi=xi,
alpha_0=alpha_0,
beta_area=beta_area,
psi=psi,
alpha_t=alpha_t,
gamma0=gamma0,
gamma1=gamma1,
seir=seir,
......@@ -244,17 +267,24 @@ def next_generation_matrix_fn(covar_data, param):
w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, W.shape[0] - 1)
commute_volume = tf.gather(W, w_idx)
xi_idx = tf.cast(
tf.clip_by_value(t // XI_FREQ, 0, param["xi"].shape[0] - 1),
dtype=tf.int64,
xi = tf.where(
t == 0,
param["alpha_0"],
tf.gather(
param["alpha_t"],
tf.clip_by_value(
t,
clip_value_min=0,
clip_value_max=param["alpha_t"].shape[-1] - 1,
),
),
)
xi = tf.gather(param["xi"], xi_idx)
beta = tf.math.exp(xi)
ngm = beta * (
tf.eye(Cstar.shape[0], dtype=state.dtype)
+ param["beta2"] * commute_volume * Cstar / N[tf.newaxis, :]
+ param["psi"] * commute_volume * Cstar / N[tf.newaxis, :]
)
ngm = (
ngm
......
"""MCMC Test Rig for COVID-19 UK model"""
# pylint: disable=E402
import pickle as pkl
from time import perf_counter
import sys
import h5py
import xarray
......@@ -260,14 +259,13 @@ def trace_results_fn(_, results):
def draws_to_dict(draws):
    """Map flat MCMC draws onto named model quantities.

    The scraped diff had merged the pre- and post-commit dict literals,
    leaving duplicate keys (``gamma0``/``gamma1``) and stale entries for
    the removed parameters (``beta2``, ``sigma``, ``beta3``, ``beta1``,
    ``xi``, ``events``); only the post-commit mapping is kept here.

    :param draws: a 2-element sequence where ``draws[0]`` is a
        ``[num_samples, num_params]`` array of global parameters ordered
        ``(psi, beta_area, gamma0, gamma1, alpha_0, alpha_t...)`` and
        ``draws[1]`` is the sampled event tensor.
    :returns: dict of named parameter samples.
    """
    return {
        "psi": draws[0][:, 0],
        "beta_area": draws[0][:, 1],
        "gamma0": draws[0][:, 2],
        "gamma1": draws[0][:, 3],
        "alpha_0": draws[0][:, 4],
        "alpha_t": draws[0][:, 5:],
        "seir": draws[1],
    }
......@@ -297,13 +295,14 @@ def run_mcmc(
event_kernel_kwargs = {
"initial_state": initial_conditions,
"t_range": [
current_state[1].shape[-2] - 21,
current_state[1].shape[-2] - 28,
current_state[1].shape[-2],
],
"config": config,
}
# Set up posterior
print("Initialising output...", end="", flush=True, file=sys.stderr)
draws, trace, _ = _fixed_window(
num_draws=1,
joint_log_prob_fn=joint_log_prob_fn,
......@@ -320,9 +319,10 @@ def run_mcmc(
+ config["num_burst_samples"] * config["num_bursts"],
)
offset = 0
print("Done", flush=True, file=sys.stderr)
# Fast adaptation sampling
print(f"Fast window {first_window_size}")
print(f"Fast window {first_window_size}", file=sys.stderr, flush=True)
draws, trace, step_size, running_variance = _fast_adapt_window(
num_draws=first_window_size,
joint_log_prob_fn=joint_log_prob_fn,
......@@ -335,7 +335,8 @@ def run_mcmc(
current_state = [s[-1] for s in draws]
draws[0] = param_bijector.inverse(draws[0])
posterior.write_samples(
draws_to_dict(draws), first_dim_offset=offset,
draws_to_dict(draws),
first_dim_offset=offset,
)
posterior.write_results(trace, first_dim_offset=offset)
offset += first_window_size
......@@ -344,7 +345,7 @@ def run_mcmc(
hmc_kernel_kwargs["step_size"] = step_size
for slow_window_idx in range(num_slow_windows):
window_num_draws = slow_window_size * (2 ** slow_window_idx)
print(f"Slow window {window_num_draws}")
print(f"Slow window {window_num_draws}", file=sys.stderr, flush=True)
(
draws,
trace,
......@@ -366,13 +367,14 @@ def run_mcmc(
current_state = [s[-1] for s in draws]
draws[0] = param_bijector.inverse(draws[0])
posterior.write_samples(
draws_to_dict(draws), first_dim_offset=offset,
draws_to_dict(draws),
first_dim_offset=offset,
)
posterior.write_results(trace, first_dim_offset=offset)
offset += window_num_draws
# Fast adaptation sampling
print(f"Fast window {last_window_size}")
print(f"Fast window {last_window_size}", file=sys.stderr, flush=True)
dual_averaging_kwargs["num_adaptation_steps"] = last_window_size
draws, trace, step_size, _ = _fast_adapt_window(
num_draws=last_window_size,
......@@ -386,13 +388,14 @@ def run_mcmc(
current_state = [s[-1] for s in draws]
draws[0] = param_bijector.inverse(draws[0])
posterior.write_samples(
draws_to_dict(draws), first_dim_offset=offset,
draws_to_dict(draws),
first_dim_offset=offset,
)
posterior.write_results(trace, first_dim_offset=offset)
offset += last_window_size
# Fixed window sampling
print("Sampling...")
print("Sampling...", file=sys.stderr, flush=True)
hmc_kernel_kwargs["step_size"] = step_size
for i in tqdm.tqdm(
range(config["num_bursts"]),
......@@ -409,10 +412,12 @@ def run_mcmc(
current_state = [state_part[-1] for state_part in draws]
draws[0] = param_bijector.inverse(draws[0])
posterior.write_samples(
draws_to_dict(draws), first_dim_offset=offset,
draws_to_dict(draws),
first_dim_offset=offset,
)
posterior.write_results(
trace, first_dim_offset=offset,
trace,
first_dim_offset=offset,
)
offset += config["num_burst_samples"]
......@@ -428,12 +433,17 @@ def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True):
print("Using CPU")
data = xarray.open_dataset(data_file, group="constant_data")
cases = xarray.open_dataset(data_file, group="observations")["cases"]
cases = xarray.open_dataset(data_file, group="observations")[
"cases"
].astype(DTYPE)
dates = cases.coords["time"]
# We load in cases and impute missing infections first, since this sets the
# time epoch which we are analysing.
# Impute censored events, return cases
events = model_spec.impute_censored_events(cases.astype(DTYPE))
# Take the last week of data, and repeat a further 3 times
# to get a better occult initialisation.
extra_cases = tf.tile(cases[:, -7:], [1, 3])
cases = tf.concat([cases, extra_cases], axis=-1)
events = model_spec.impute_censored_events(cases).numpy()
# Initial conditions are calculated by calculating the state
# at the beginning of the inference period
......@@ -454,7 +464,7 @@ def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True):
)
start_time = state.shape[1] - cases.shape[1]
initial_state = state[:, start_time, :]
events = events[:, start_time:, :]
events = events[:, start_time:-21, :] # Clip off the "extra" events
########################################################
# Construct the MCMC kernels #
......@@ -471,10 +481,9 @@ def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True):
[
tfb.Softplus(low=dtype_util.eps(DTYPE)),
tfb.Identity(),
tfb.Softplus(low=dtype_util.eps(DTYPE)),
tfb.Identity(),
],
block_sizes=[1, 2, 1, model.event_shape["xi"][0] + 1],
block_sizes=[1, 3, events.shape[1]],
)
)
......@@ -482,12 +491,12 @@ def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True):
params = param_bij.inverse(unconstrained_params)
return model.log_prob(
dict(
beta2=params[0],
gamma0=params[1],
gamma1=params[2],
sigma=params[3],
beta1=params[4],
xi=params[5:],
psi=params[0],
beta_area=params[1],
gamma0=params[2],
gamma1=params[3],
alpha_0=params[4],
alpha_t=params[5:],
seir=events,
)
) + param_bij.inverse_log_det_jacobian(
......@@ -501,9 +510,10 @@ def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True):
current_chain_state = [
tf.concat(
[
np.array([0.6, 0.0, 0.0, 0.1], dtype=DTYPE),
np.zeros(
model.model["xi"](0.0, 0.1).event_shape[-1] + 1,
np.array([0.1, 0.0, 0.0, 0.0], dtype=DTYPE),
np.full(
events.shape[1],
-1.75,
dtype=DTYPE,
),
],
......@@ -511,7 +521,10 @@ def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True):
),
events,
]
print("Initial logpi:", joint_log_prob(*current_chain_state))
print("Num time steps:", events.shape[1], flush=True)
print("alpha_t shape", model.event_shape["alpha_t"], flush=True)
print("Initial chain state:", current_chain_state[0], flush=True)
print("Initial logpi:", joint_log_prob(*current_chain_state), flush=True)
# Output file
posterior = run_mcmc(
......@@ -525,9 +538,7 @@ def mcmc(data_file, output_file, config, use_autograph=False, use_xla=True):
posterior._file.create_dataset("initial_state", data=initial_state)
posterior._file.create_dataset(
"time",
data=np.array(cases.coords["time"])
.astype(str)
.astype(h5py.string_dtype()),
data=np.array(dates).astype(str).astype(h5py.string_dtype()),
)
print(f"Acceptance theta: {posterior['results/hmc/is_accepted'][:].mean()}")
......@@ -559,7 +570,9 @@ if __name__ == "__main__":
"-o", "--output", type=str, help="Output file", required=True
)
parser.add_argument(
"data_file", type=str, help="Data pickle file",
"data_file",
type=str,
help="Data pickle file",
)
args = parser.parse_args()
......
[tool.poetry]
name = "covid19uk"
version = "0.7.0-alpha.0"
description = "Spatial stochastic SEIR analysis of COVID-19 in the UK"
authors = ["Chris Jewell <c.jewell@lancaster.ac.uk>"]
license = "MIT"
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment