Commit 588d479e authored by Chris Jewell

Pulled dates out of CovidUKStochastic class

Changes:

1. Dates are pulled out of CovidUKStochastic.
2. CovidUKStochastic now behaves more like a tfd.Distribution (see the usage sketch below):
    * CovidUKStochastic is now instantiated with an initial time, a number of time steps, and a time step size.
    * CovidUKStochastic is now instantiated with the initial state.
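
For reference, a minimal sketch of the new calling convention (not part of this commit). The covariate arrays, `params` dict, `initial_state`, and `events` tensor are assumed to be prepared as in inference.py, and the import path is an assumption:

```python
# Sketch only: covar_data, params, initial_state and events are assumed to be
# prepared as in inference.py; the import path is an assumption.
from covid.model import CovidUKStochastic

model = CovidUKStochastic(
    C=covar_data["C"],            # inter-LAD mobility matrix
    N=covar_data["pop"],          # population sizes
    W=covar_data["W"],            # commute volume time series
    xi_freq=14,                   # xi changes every 14 time steps
    params=params,                # dict with beta1, beta2, xi, nu, gamma
    initial_state=initial_state,  # [S, E, I, R] state at initial_time
    initial_time=0.0,
    time_step=1.0,
    num_steps=events.shape[1],
)

# Distribution-like interface: draw a trajectory, then score observed events.
times, sim = model.sample(seed=42)
lp = model.log_prob(events)
```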
parent 2d1895bf
@@ -68,7 +68,7 @@ def discrete_markov_simulation(hazard_fn, state, start, end, time_step, seed=None):
"""Simulates from a discrete time Markov state transition model using multinomial sampling
across rows of the transition probability matrix."""
propagate = chain_binomial_propagate(hazard_fn, time_step, seed=seed)
times = tf.range(start, end, time_step)
times = tf.range(start, end, time_step, dtype=state.dtype)
state = tf.convert_to_tensor(state)
output = tf.TensorArray(state.dtype, size=times.shape[0])
@@ -79,13 +79,15 @@ def load_data(paths, settings, dtype=DTYPE):
class CovidUK:
def __init__(
self,
initial_state: np.float64,
W: np.float64,
C: np.float64,
N: np.float64,
date_range: list,
xi_freq: int,
params: dict,
initial_state: np.float64,
initial_time: np.float64,
time_step: np.int64,
num_steps: np.int64,
):
"""Represents a CovidUK ODE model
@@ -97,9 +99,10 @@ class CovidUK:
:param time_step: a time step to use in the discrete time simulation
"""
self.initial_state = initial_state
dtype = dtype_util.common_dtype([W, C, N, initial_state], dtype_hint=DTYPE)
self.initial_state = tf.convert_to_tensor(initial_state, dtype=dtype)
self.initial_time = initial_time
self.n_lads = C.shape[0]
C = tf.convert_to_tensor(C, dtype=dtype)
@@ -109,24 +112,15 @@ class CovidUK:
self.N = tf.constant(N, dtype=dtype)
self.time_step = time_step
self.times = np.arange(
date_range[0], date_range[1], np.timedelta64(int(time_step), "D")
)
xi_freq = np.int32(xi_freq)
self.xi_select = np.arange(self.times.shape[0], dtype=np.int32) // xi_freq
self.xi_freq = xi_freq
self.max_t = self.xi_select.shape[0] - 1
self.num_steps = num_steps
@property
def xi_times(self):
"""Returns the time indices for beta in units of time_step"""
return np.unique(self.xi_select) * self.xi_freq
self.params = {
k: tf.convert_to_tensor(v, dtype=dtype) for k, v in params.items()
}
@property
def num_xi(self):
"""Return the number of distinct betas"""
return tf.cast(self.xi_select[-1] + 1, tf.int32)
self.xi_freq = np.int32(xi_freq)
self.xi_select = np.arange(self.num_steps, dtype=np.int32) // self.xi_freq
self.max_t = self.xi_select.shape[0] - 1
def create_initial_state(self, init_matrix=None):
I = tf.convert_to_tensor(init_matrix, dtype=DTYPE)
@@ -137,17 +131,22 @@ class CovidUK:
class CovidUKStochastic(CovidUK):
stoichiometry = tf.constant(
[[-1, 1, 0, 0], [0, -1, 1, 0], [0, 0, -1, 1]], dtype=DTYPE
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.stoichiometry = tf.constant(
[[-1, 1, 0, 0], [0, -1, 1, 0], [0, 0, -1, 1]], dtype=DTYPE
)
def make_h(self, param):
def make_h(self, param=None):
"""Constructs a function that takes `state` and outputs a
transition rate matrix (with 0 diagonal).
"""
if param is None:
param = self.params
def h(t, state):
"""Computes a transition rate matrix
@@ -156,25 +155,28 @@ class CovidUKStochastic(CovidUK):
contiguously in memory for fast calculation below.
:returns: a tensor of shape [M, M, S] containing a transition matrix for each i=0,...,(c-1)
"""
t_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, self.max_t)
commute_volume = tf.pow(tf.gather(self.W, t_idx), param["omega"])
xi_idx = tf.gather(self.xi_select, t_idx)
xi = tf.gather(param["xi"], xi_idx)
beta = param["beta1"] * tf.math.exp(xi)
w_idx = tf.clip_by_value(tf.cast(t, tf.int64), 0, self.W.shape[0] - 1)
commute_volume = tf.gather(self.W, w_idx)
xi_idx = tf.cast(
tf.clip_by_value(t // self.xi_freq, 0, self.params["xi"].shape[0] - 1),
dtype=tf.int64,
)
xi = tf.gather(self.params["xi"], xi_idx)
beta = self.params["beta1"] * tf.math.exp(xi)
infec_rate = beta * (
state[..., 2]
+ param["beta2"]
+ self.params["beta2"]
* commute_volume
* tf.linalg.matvec(self.C, state[..., 2] / self.N)
)
infec_rate = infec_rate / self.N + 0.000000001 # Vector of length nc
ei = tf.broadcast_to(
[param["nu"]], shape=[state.shape[0]]
[self.params["nu"]], shape=[state.shape[0]]
) # Vector of length nc
ir = tf.broadcast_to(
[param["gamma"]], shape=[state.shape[0]]
[self.params["gamma"]], shape=[state.shape[0]]
) # Vector of length nc
return [infec_rate, ei, ir]
@@ -206,40 +208,37 @@ class CovidUKStochastic(CovidUK):
)
return ngm
def simulate(self, param, state_init, date_range: np.datetime64 = None):
def sample(self, seed=None):
"""Runs a simulation from the epidemic model
:param param: a dictionary of model parameters
:param state_init: the initial state
:returns: a tuple of times and simulated states.
"""
param = {k: tf.convert_to_tensor(v, dtype=tf.float64) for k, v in param.items()}
hazard = self.make_h(param)
if date_range is not None:
start = DTYPE(date_range[0] - self.times[0])
end = DTYPE(date_range[1] - self.times[0])
print(start)
print(end)
else:
start = DTYPE(self.times[0] - self.times[0])
end = DTYPE(self.times[-1] - self.times[0])
hazard = self.make_h()
t, sim = discrete_markov_simulation(
hazard, state_init, start, end, self.time_step,
hazard_fn=hazard,
state=self.initial_state,
start=self.initial_time,
end=self.initial_time + self.num_steps * self.time_step,
time_step=self.time_step,
seed=seed,
)
return t, sim
def log_prob(self, y, param, state_init):
def log_prob(self, y):
"""Calculates the log probability of observing epidemic events y
:param y: a list of tensors. The first is of shape [n_times] containing times,
the second is of shape [n_times, n_states, n_states] containing event matrices.
:param param: a list of parameters
:returns: a scalar giving the log probability of the epidemic
"""
dtype = dtype = dtype_util.common_dtype([y, state_init], dtype_hint=DTYPE)
dtype = dtype_util.common_dtype(
[y, self.initial_state], dtype_hint=DTYPE
)
y = tf.convert_to_tensor(y, dtype)
state_init = tf.convert_to_tensor(state_init, dtype)
with tf.name_scope("CovidUKStochastic.log_prob"):
hazard = self.make_h(param)
hazard = self.make_h()
return discrete_markov_log_prob(
y, state_init, hazard, self.time_step, self.stoichiometry
y, self.initial_state, hazard, self.time_step, self.stoichiometry
)
@@ -77,31 +77,31 @@ state = compute_state(
initial_state=tf.concat(
[covar_data["pop"][:, tf.newaxis], tf.zeros_like(events[:, 0, :])], axis=-1
),
events=events, # [:, 1:, :],
stoichiometry=tf.constant(
[[-1, 1, 0, 0], [0, -1, 1, 0], [0, 0, -1, 1]], dtype=DTYPE
),
events=events,
stoichiometry=CovidUKStochastic.stoichiometry,
)
start_time = state.shape[1] - cases.shape[1]
initial_state = state[:, start_time, :]
events = events[:, start_time:, :]
xi_freq = 14
num_xi = events.shape[1] // xi_freq
num_metapop = covar_data["pop"].shape[0]
# Create initial state, truncate events
# initial_state = tf.concat([covar_data["pop"][:, tf.newaxis], events[:, 0, :]], axis=-1)
# events = events[:, 1:, :]
# Create the model
model = CovidUKStochastic(
C=covar_data["C"],
N=covar_data["pop"],
W=covar_data["W"],
date_range=settings["inference_period"],
xi_freq=14,
initial_state=initial_state,
time_step=1.0,
)
print("Xi_select:", model.xi_select, flush=True)
def build_epidemic(param):
return CovidUKStochastic(
C=covar_data["C"],
N=covar_data["pop"],
W=covar_data["W"],
xi_freq=xi_freq,
params=param,
initial_state=initial_state,
initial_time=0.0,
time_step=1.0,
num_steps=events.shape[1],
)
##########################
# Log p and MCMC kernels #
@@ -113,29 +113,35 @@ def logp(theta, xi, events):
p["gamma"] = tf.convert_to_tensor(theta[2], dtype=DTYPE)
p["xi"] = tf.convert_to_tensor(xi, dtype=DTYPE)
beta1_logp = tfd.Gamma(
beta1 = tfd.Gamma(
concentration=tf.constant(1.0, dtype=DTYPE), rate=tf.constant(1.0, dtype=DTYPE)
).log_prob(p["beta1"])
)
sigma = tf.constant(0.1, dtype=DTYPE)
phi = tf.constant(12.0, dtype=DTYPE)
kernel = tfp.math.psd_kernels.MaternThreeHalves(sigma, phi)
xi_logp = tfd.GaussianProcess(
kernel, index_points=tf.cast(model.xi_times[:, tf.newaxis], DTYPE)
).log_prob(p["xi"])
idx_pts = tf.cast(tf.range(events.shape[1] // xi_freq) * xi_freq, dtype=DTYPE)
xi = tfd.GaussianProcess(kernel, index_points=idx_pts[:, tf.newaxis])
spatial_beta_logp = tfd.Gamma(
spatial_beta = tfd.Gamma(
concentration=tf.constant(3.0, dtype=DTYPE), rate=tf.constant(10.0, dtype=DTYPE)
).log_prob(p["beta2"])
)
gamma_logp = tfd.Gamma(
gamma = tfd.Gamma(
concentration=tf.constant(100.0, dtype=DTYPE),
rate=tf.constant(400.0, dtype=DTYPE),
).log_prob(p["gamma"])
)
with tf.name_scope("epidemic_log_posterior"):
y_logp = model.log_prob(events, p, initial_state)
logp = beta1_logp + spatial_beta_logp + gamma_logp + xi_logp + y_logp
return logp
seir = build_epidemic(p)
return (
beta1.log_prob(p["beta1"])
+ xi.log_prob(p["xi"])
+ spatial_beta.log_prob(p["beta2"])
+ gamma.log_prob(p["gamma"])
+ seir.log_prob(events)
)
# Pavel's suggestion for a Gibbs kernel requires
@@ -230,7 +236,7 @@ def trace_results_fn(_, results):
return recurse(f, results)
@tf.function(autograph=False, experimental_compile=True)
@tf.function # (autograph=False, experimental_compile=True)
def sample(n_samples, init_state, sigma_theta, sigma_xi, num_event_updates):
with tf.name_scope("main_mcmc_sample_loop"):
@@ -278,7 +284,7 @@ tf.random.set_seed(2)
current_state = [
np.array([0.85, 0.3, 0.25], dtype=DTYPE),
np.zeros(model.num_xi, dtype=DTYPE),
np.zeros(num_xi, dtype=DTYPE),
events,
]
@@ -315,10 +321,10 @@ output_results = [
posterior.create_dataset("results/theta", (NUM_SAVED_SAMPLES, 3), dtype=DTYPE,),
posterior.create_dataset("results/xi", (NUM_SAVED_SAMPLES, 3), dtype=DTYPE,),
posterior.create_dataset(
"results/move/S->E", (NUM_SAVED_SAMPLES, 3 + model.N.shape[0]), dtype=DTYPE,
"results/move/S->E", (NUM_SAVED_SAMPLES, 3 + num_metapop), dtype=DTYPE,
),
posterior.create_dataset(
"results/move/E->I", (NUM_SAVED_SAMPLES, 3 + model.N.shape[0]), dtype=DTYPE,
"results/move/E->I", (NUM_SAVED_SAMPLES, 3 + num_metapop), dtype=DTYPE,
),
posterior.create_dataset(
"results/occult/S->E", (NUM_SAVED_SAMPLES, 6), dtype=DTYPE
@@ -341,7 +347,7 @@ theta_scale = tf.constant(
)
theta_scale = theta_scale * 0.2 / theta_scale.shape[0]
xi_scale = tf.eye(model.num_xi, dtype=DTYPE)
xi_scale = tf.eye(current_state[1].shape[0], dtype=DTYPE)
xi_scale = xi_scale * 0.001 / xi_scale.shape[0]
# We loop over successive calls to sample because we have to dump results