Commit fea8ed43 authored by Chris Jewell's avatar Chris Jewell
Browse files

Further changes for dependency on gemlib

parent b2c6806a
"""MCMC Test Rig for COVID-19 UK model"""
# pylint: disable=E402
import argparse
import os
# Uncomment to block GPU use
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
from time import perf_counter
import tqdm
......@@ -19,16 +14,15 @@ import tensorflow_probability as tfp
from tensorflow_probability.python.experimental import unnest
from covid.impl.util import compute_state
from covid.impl.mcmc import UncalibratedLogRandomWalk, random_walk_mvnorm_fn
from covid.impl.event_time_mh import UncalibratedEventTimesUpdate
from covid.impl.occult_events_mh import UncalibratedOccultUpdate, TransitionTopology
from covid.impl.gibbs import flatten_results
from covid.impl.gibbs_kernel import GibbsKernel, GibbsKernelResults
from covid.impl.multi_scan_kernel import MultiScanKernel
from covid.impl.adaptive_random_walk_metropolis import (
AdaptiveRandomWalkMetropolisHastings,
)
from gemlib.util import compute_state
from gemlib.mcmc import UncalibratedEventTimesUpdate
from gemlib.mcmc import UncalibratedOccultUpdate, TransitionTopology
from gemlib.mcmc import GibbsKernel
from gemlib.mcmc.gibbs_kernel import GibbsKernelResults
from gemlib.mcmc.gibbs_kernel import flatten_results
from gemlib.mcmc import MultiScanKernel
from gemlib.mcmc import AdaptiveRandomWalkMetropolis
from covid.data import read_phe_cases
from covid.cli_arg_parse import cli_args
......@@ -58,7 +52,9 @@ if __name__ == "__main__":
]
covar_data = model_spec.read_covariates(
config["data"], date_low=inference_period[0], date_high=inference_period[1],
config["data"],
date_low=inference_period[0],
date_high=inference_period[1],
)
# We load in cases and impute missing infections first, since this sets the
......@@ -82,7 +78,8 @@ if __name__ == "__main__":
# to set up a sensible initial state.
state = compute_state(
initial_state=tf.concat(
[covar_data["N"][:, tf.newaxis], tf.zeros_like(events[:, 0, :])], axis=-1
[covar_data["N"][:, tf.newaxis], tf.zeros_like(events[:, 0, :])],
axis=-1,
),
events=events,
stoichiometry=model_spec.STOICHIOMETRY,
......@@ -106,7 +103,13 @@ if __name__ == "__main__":
# $\pi(\theta, \xi, y^{se}, y^{ei} | y^{ir})$
def logp(theta, xi, events):
return model.log_prob(
dict(beta1=theta[0], beta2=theta[1], gamma=theta[2], xi=xi, seir=events,)
dict(
beta1=theta[0],
beta2=theta[1],
gamma=theta[2],
xi=xi,
seir=events,
)
)
# Build Metropolis within Gibbs sampler
......@@ -121,7 +124,7 @@ if __name__ == "__main__":
def make_theta_kernel(shape, name):
def fn(target_log_prob_fn, state):
return tfp.mcmc.TransformedTransitionKernel(
inner_kernel=AdaptiveRandomWalkMetropolisHastings(
inner_kernel=AdaptiveRandomWalkMetropolis(
target_log_prob_fn=target_log_prob_fn,
initial_state=tf.zeros(shape, dtype=model_spec.DTYPE),
initial_covariance=[np.eye(shape[0]) * 1e-1],
......@@ -135,7 +138,7 @@ if __name__ == "__main__":
def make_xi_kernel(shape, name):
def fn(target_log_prob_fn, state):
return AdaptiveRandomWalkMetropolisHastings(
return AdaptiveRandomWalkMetropolis(
target_log_prob_fn=target_log_prob_fn,
initial_state=tf.ones(shape, dtype=model_spec.DTYPE),
initial_covariance=[np.eye(shape[0]) * 1e-1],
......@@ -211,7 +214,9 @@ if __name__ == "__main__":
q_ratio = proposed_results.log_acceptance_correction
if hasattr(proposed_results, "extra"):
proposed = tf.cast(proposed_results.extra, log_prob.dtype)
return tf.concat([[log_prob], [accepted], [q_ratio], proposed], axis=0)
return tf.concat(
[[log_prob], [accepted], [q_ratio], proposed], axis=0
)
return tf.concat([[log_prob], [accepted], [q_ratio]], axis=0)
def recurse(f, results):
......@@ -293,7 +298,9 @@ if __name__ == "__main__":
dtype=np.float64,
)
xi_samples = posterior.create_dataset(
"samples/xi", [NUM_SAVED_SAMPLES, current_state[1].shape[0]], dtype=np.float64,
"samples/xi",
[NUM_SAVED_SAMPLES, current_state[1].shape[0]],
dtype=np.float64,
)
event_samples = posterior.create_dataset(
"samples/events",
......@@ -305,13 +312,25 @@ if __name__ == "__main__":
)
output_results = [
posterior.create_dataset("results/theta", (NUM_SAVED_SAMPLES, 3), dtype=DTYPE,),
posterior.create_dataset("results/xi", (NUM_SAVED_SAMPLES, 3), dtype=DTYPE,),
posterior.create_dataset(
"results/move/S->E", (NUM_SAVED_SAMPLES, 3 + num_metapop), dtype=DTYPE,
"results/theta",
(NUM_SAVED_SAMPLES, 3),
dtype=DTYPE,
),
posterior.create_dataset(
"results/xi",
(NUM_SAVED_SAMPLES, 3),
dtype=DTYPE,
),
posterior.create_dataset(
"results/move/E->I", (NUM_SAVED_SAMPLES, 3 + num_metapop), dtype=DTYPE,
"results/move/S->E",
(NUM_SAVED_SAMPLES, 3 + num_metapop),
dtype=DTYPE,
),
posterior.create_dataset(
"results/move/E->I",
(NUM_SAVED_SAMPLES, 3 + num_metapop),
dtype=DTYPE,
),
posterior.create_dataset(
"results/occult/S->E", (NUM_SAVED_SAMPLES, 6), dtype=DTYPE
......@@ -330,10 +349,14 @@ if __name__ == "__main__":
final_results = None
for i in tqdm.tqdm(range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES):
samples, results, final_results = sample(
NUM_BURST_SAMPLES, init_state=current_state, previous_results=final_results,
NUM_BURST_SAMPLES,
init_state=current_state,
previous_results=final_results,
)
current_state = [s[-1] for s in samples]
s = slice(i * THIN_BURST_SAMPLES, i * THIN_BURST_SAMPLES + THIN_BURST_SAMPLES)
s = slice(
i * THIN_BURST_SAMPLES, i * THIN_BURST_SAMPLES + THIN_BURST_SAMPLES
)
idx = tf.constant(range(0, NUM_BURST_SAMPLES, config["mcmc"]["thin"]))
theta_samples[s, ...] = tf.gather(samples[0], idx)
xi_samples[s, ...] = tf.gather(samples[1], idx)
......@@ -361,7 +384,8 @@ if __name__ == "__main__":
tf.reduce_mean(tf.cast(flat_results[0][:, 1], tf.float32)),
)
print(
"Acceptance xi:", tf.reduce_mean(tf.cast(flat_results[1][:, 1], tf.float32))
"Acceptance xi:",
tf.reduce_mean(tf.cast(flat_results[1][:, 1], tf.float32)),
)
print(
"Acceptance move S->E:",
......
"""Implements the COVID SEIR model as a TFP Joint Distribution"""
import pandas as pd
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from covid.model import DiscreteTimeStateTransitionModel
from gemlib.distributions import DiscreteTimeStateTransitionModel
from covid.util import impute_previous_cases
import covid.data as data
......@@ -104,7 +103,8 @@ def CovidUK(covariates, initial_state, initial_step, num_steps):
w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, W.shape[0] - 1)
commute_volume = tf.gather(W, w_idx)
xi_idx = tf.cast(
tf.clip_by_value(t // XI_FREQ, 0, xi.shape[0] - 1), dtype=tf.int64,
tf.clip_by_value(t // XI_FREQ, 0, xi.shape[0] - 1),
dtype=tf.int64,
)
xi_ = tf.gather(xi, xi_idx)
beta = beta1 * tf.math.exp(xi_)
......@@ -115,10 +115,16 @@ def CovidUK(covariates, initial_state, initial_step, num_steps):
* commute_volume
* tf.linalg.matvec(C, state[..., 2] / tf.squeeze(N))
)
infec_rate = infec_rate / tf.squeeze(N) + 0.000000001 # Vector of length nc
infec_rate = (
infec_rate / tf.squeeze(N) + 0.000000001
) # Vector of length nc
ei = tf.broadcast_to([NU], shape=[state.shape[0]]) # Vector of length nc
ir = tf.broadcast_to([gamma], shape=[state.shape[0]]) # Vector of length nc
ei = tf.broadcast_to(
[NU], shape=[state.shape[0]]
) # Vector of length nc
ir = tf.broadcast_to(
[gamma], shape=[state.shape[0]]
) # Vector of length nc
return [infec_rate, ei, ir]
......@@ -137,29 +143,32 @@ def CovidUK(covariates, initial_state, initial_step, num_steps):
def next_generation_matrix_fn(covar_data, param):
"""The next generation matrix calculates the force of infection from
individuals in metapopulation i to all other metapopulations j during
a typical infectious period (1/gamma). i.e.
\[ A_{ij} = S_j * \beta_1 ( 1 + \beta_2 * w_t * C_{ij} / N_i) / N_j / gamma \]
:param covar_data: a dictionary of covariate data
:param param: a dictionary of parameters
:returns: a function taking arguments `t` and `state` giving the time and
epidemic state (SEIR) for which the NGM is to be calculated. This
function in turn returns an MxM next generation matrix.
"""The next generation matrix calculates the force of infection from
individuals in metapopulation i to all other metapopulations j during
a typical infectious period (1/gamma). i.e.
\[ A_{ij} = S_j * \beta_1 ( 1 + \beta_2 * w_t * C_{ij} / N_i) / N_j / gamma \]
:param covar_data: a dictionary of covariate data
:param param: a dictionary of parameters
:returns: a function taking arguments `t` and `state` giving the time and
epidemic state (SEIR) for which the NGM is to be calculated. This
function in turn returns an MxM next generation matrix.
"""
def fn(t, state):
C = tf.convert_to_tensor(covar_data["C"], dtype=DTYPE)
C = tf.linalg.set_diag(C + tf.transpose(C), tf.zeros(C.shape[0], dtype=DTYPE))
C = tf.linalg.set_diag(
C + tf.transpose(C), tf.zeros(C.shape[0], dtype=DTYPE)
)
W = tf.constant(covar_data["W"], dtype=DTYPE)
N = tf.constant(covar_data["N"], dtype=DTYPE)
w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, W.shape[0] - 1)
commute_volume = tf.gather(W, w_idx)
xi_idx = tf.cast(
tf.clip_by_value(t // XI_FREQ, 0, param["xi"].shape[0] - 1), dtype=tf.int64,
tf.clip_by_value(t // XI_FREQ, 0, param["xi"].shape[0] - 1),
dtype=tf.int64,
)
xi = tf.gather(param["xi"], xi_idx)
beta = param["beta1"] * tf.math.exp(xi)
......@@ -168,7 +177,11 @@ def next_generation_matrix_fn(covar_data, param):
tf.eye(C.shape[0], dtype=state.dtype)
+ param["beta2"] * commute_volume * C / N[tf.newaxis, :]
)
ngm = ngm * state[..., 0][..., tf.newaxis] / (N[:, tf.newaxis] * param["gamma"])
ngm = (
ngm
* state[..., 0][..., tf.newaxis]
/ (N[:, tf.newaxis] * param["gamma"])
)
return ngm
return fn
......@@ -323,23 +323,23 @@ optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
[[package]]
name = "gemlib-tfp-extra"
name = "gemlib"
version = "0.1.0"
description = "Add on classes for Tensorflow Probability used in GEM"
description = "GEMlib scientific compute library for epidemic modelling"
category = "main"
optional = false
python-versions = "^3.7"
[package.dependencies]
numpy = "^1.18.5"
tf-nightly = "^2.4.0-alpha.20201021"
tfp-nightly = "^0.12.0-alpha.20201021"
tf-nightly = "2.4.0-alpha.20201021"
tfp-nightly = "0.12.0-alpha.20201021"
[package.source]
type = "git"
url = "ssh://git@fhm-chicas-code.lancs.ac.uk/GEM/gemlib-tfp-extra.git"
url = "http://fhm-chicas-code.lancs.ac.uk/GEM/gemlib.git"
reference = "master"
resolved_reference = "39f5ba71e16feedc2f29ea6018f972534b7ac73b"
resolved_reference = "85b0050780a3042a3527fed12857c290878d5b00"
[[package]]
name = "geopandas"
......@@ -1348,7 +1348,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
[[package]]
name = "tb-nightly"
version = "2.4.0a20201021"
version = "2.4.0a20201022"
description = "TensorBoard lets you watch Tensors Flow"
category = "main"
optional = false
......@@ -1610,7 +1610,7 @@ testing = ["pytest (>=3.5,<3.7.3 || >3.7.3)", "pytest-checkdocs (>=1.2.3)", "pyt
[metadata]
lock-version = "1.1"
python-versions = "^3.7"
content-hash = "bbfab17e05f19cd163dc5f9d343ece67b4ed8e1cbd0dd90e0c18ee27d7d5886b"
content-hash = "395b276b10f26a5b382015a38a3f6e7465fde3b785efe0ff59b4dd79f01f57a6"
[metadata.files]
absl-py = [
......@@ -1803,7 +1803,7 @@ gast = [
{file = "gast-0.3.3-py2.py3-none-any.whl", hash = "sha256:8f46f5be57ae6889a4e16e2ca113b1703ef17f2b0abceb83793eaba9e1351a45"},
{file = "gast-0.3.3.tar.gz", hash = "sha256:b881ef288a49aa81440d2c5eb8aeefd4c2bb8993d5f50edae7413a85bfdb3b57"},
]
gemlib-tfp-extra = []
gemlib = []
geopandas = [
{file = "geopandas-0.8.1-py2.py3-none-any.whl", hash = "sha256:ef90a7f5ce9337f412c1dab9e014b4076b639fb7fc0edcf8b57c252c91a096c4"},
{file = "geopandas-0.8.1.tar.gz", hash = "sha256:e28a729e44ac53c1891b54b1aca60e3bc0bb9e88ad0f2be8e301a03b9510f6e2"},
......@@ -2510,7 +2510,7 @@ six = [
{file = "six-1.15.0.tar.gz", hash = "sha256:30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259"},
]
tb-nightly = [
{file = "tb_nightly-2.4.0a20201021-py3-none-any.whl", hash = "sha256:c2bb215344f162350b19edf0d9c148f2c919b3fab7f03efe57f35f1a04044756"},
{file = "tb_nightly-2.4.0a20201022-py3-none-any.whl", hash = "sha256:46269dbf19664fe060de937f058a7b3eeae61f9a0858936cadd3067fbda32b15"},
]
tensorboard-plugin-wit = [
{file = "tensorboard_plugin_wit-1.7.0-py3-none-any.whl", hash = "sha256:ee775f04821185c90d9a0e9c56970ee43d7c41403beb6629385b39517129685b"},
......
......@@ -9,7 +9,6 @@ license = "MIT"
python = "^3.7"
pandas = "^1.1.3"
geopandas = "^0.8.1"
gemlib-tfp-extra = {git = "ssh://git@fhm-chicas-code.lancs.ac.uk/GEM/gemlib-tfp-extra.git"}
mapclassify = "^2.3.0"
PyYAML = "^5.3.1"
descartes = "^1.1.0"
......@@ -18,6 +17,7 @@ xlrd = "^1.2.0"
tqdm = "^4.50.2"
openpyxl = "^3.0.5"
h5py = "^2.10.0"
gemlib = {git = "http://fhm-chicas-code.lancs.ac.uk/GEM/gemlib.git"}
[tool.poetry.dev-dependencies]
ipython = "^7.18.1"
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment