Commit 58f6963d authored by Chris Jewell's avatar Chris Jewell
Browse files

Switch to `tfd_e.MultivariateNormalPrecisionFactorLinearOperator`

Potential speed improvement, cutting out an explicit matrix
inverse.
parent bd5ca69d
......@@ -18,6 +18,7 @@ from covid19uk.data import read_population
from covid19uk.data import read_traffic_flow
tfd = tfp.distributions
tfd_e = tfp.experimental.distributions
DTYPE = np.float64
......@@ -70,10 +71,7 @@ def gather_data(config):
adjacency = _compute_adjacency_matrix(geo.geometry, geo["lad19cd"], 200)
area = xarray.DataArray(
geo.area,
name="area",
dims=["location"],
coords=[geo["lad19cd"]],
geo.area, name="area", dims=["location"], coords=[geo["lad19cd"]],
)
dates = pd.date_range(*config["date_range"], closed="left")
......@@ -173,11 +171,12 @@ def CovidUK(covariates, initial_state, initial_step, num_steps):
Dw = tf.linalg.diag(tf.reduce_sum(W, axis=-1)) # row sums
rho = 0.25
precision = Dw - rho * W
cov = tf.linalg.inv(precision)
scale = tf.linalg.cholesky(cov)
return tfd.MultivariateNormalTriL(
precision_factor = tf.linalg.cholesky(precision)
return tfd_e.MultivariateNormalPrecisionFactorLinearOperator(
loc=tf.constant(0.0, DTYPE),
scale_tril=scale,
precision_factor=tf.linalg.LinearOperatorFullMatrix(
precision_factor
),
)
# return tfd.MultivariateNormalDiag(
# loc=tf.constant(0.0, dtype=DTYPE),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment