Commit 5bf61083 authored by Chris Jewell's avatar Chris Jewell
Browse files

Non-centering of spatial effect variance

parent 7883c91b
......@@ -27,9 +27,20 @@ NU = tf.constant(0.28, dtype=DTYPE) # E->I rate assumed known.
def _compute_adjacency_matrix(geom, names, tol=200):
mat = geom.apply(lambda x: geom.distance(x) < tol)
mat = geom.apply(lambda x: geom.distance(x) < tol).to_numpy()
np.fill_diagonal(mat, False)
# Fix for islands > tol apart
num_neighbours = mat.sum(axis=-1)
islands = np.where(num_neighbours == 0)[0]
closest_neighbour = [
geom.distance(geom.iloc[i]).argsort()[1] for i in islands
]
mat[islands, closest_neighbour] = True
mat = mat | mat.T # Ensure symmetry
return xarray.DataArray(
mat.to_numpy().astype(DTYPE),
mat.astype(DTYPE), # Coerce to global float type
coords=[names, names],
dims=["location_dest", "location_src"],
)
......@@ -157,28 +168,22 @@ def CovidUK(covariates, initial_state, initial_step, num_steps):
"""Variance of CAR prior on space"""
return tfd.HalfNormal(scale=tf.constant(0.1, dtype=DTYPE))
def spatial_effect(sigma_space):
# W = covariates["adjacency"]
# W = tf.linalg.set_diag(W, tf.zeros(W.shape[0], dtype=W.dtype))
# Dw = tf.linalg.diag(tf.reduce_sum(W, axis=-1))
# rho = 0.5
# tf.print("Sigma space:", sigma_space)
# precision = (
# 1.0 / sigma_space * (Dw - rho * W)
# + tf.eye(Dw.shape[0], dtype=DTYPE) * 1e-3
# )
# precision_factor = tf.linalg.cholesky(precision)
# return tfp.experimental.distributions.MultivariateNormalPrecisionFactorLinearOperator(
# loc=tf.constant(0.0, DTYPE),
# precision_factor=tf.linalg.LinearOperatorFullMatrix(
# precision_factor
# ),
# )
return tfd.MultivariateNormalDiag(
loc=tf.constant(0.0, dtype=DTYPE),
scale_diag=tf.ones(covariates["adjacency"].shape[0], dtype=DTYPE),
def spatial_effect():
W = tf.convert_to_tensor(covariates["adjacency"])
Dw = tf.linalg.diag(tf.reduce_sum(W, axis=-1)) # row sums
rho = 0.5
precision = Dw - rho * W
cov = tf.linalg.inv(precision)
scale = tf.linalg.cholesky(cov)
return tfd.MultivariateNormalTriL(
loc=tf.constant(0.0, DTYPE),
scale_tril=scale,
)
# return tfd.MultivariateNormalDiag(
# loc=tf.constant(0.0, dtype=DTYPE),
# scale_diag=tf.ones(covariates["adjacency"].shape[0], dtype=DTYPE)
# * sigma_space,
# )
def gamma0():
return tfd.Normal(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment