Commit 558a45d9 authored by Chris Jewell's avatar Chris Jewell
Browse files

MCMC updaters modified

MCMC update kernels in `interface.py` changed to reflect new
AdaptiveRandomWalkMetropolis interface.
parent f361a5ae
......@@ -101,7 +101,13 @@ if __name__ == "__main__":
# $\pi(\theta, \xi, y^{se}, y^{ei} | y^{ir})$
def logp(theta, xi, events):
return model.log_prob(
dict(beta1=xi[0], beta2=theta[0], gamma=theta[1], xi=xi[1:], seir=events,)
dict(
beta1=xi[0],
beta2=theta[0],
gamma=theta[1],
xi=xi[1:],
seir=events,
)
)
# Build Metropolis within Gibbs sampler
......@@ -114,12 +120,13 @@ if __name__ == "__main__":
# Q(Z^{se}, Z^{se\prime}) (occult)
# Q(Z^{ei}, Z^{ei\prime}) (occult)
def make_theta_kernel(shape, name):
def fn(target_log_prob_fn, state):
def fn(target_log_prob_fn, _):
return tfp.mcmc.TransformedTransitionKernel(
inner_kernel=AdaptiveRandomWalkMetropolis(
target_log_prob_fn=target_log_prob_fn,
initial_state=tf.zeros(shape, dtype=model_spec.DTYPE),
initial_covariance=[np.eye(shape[0]) * 1e-1],
initial_covariance=[
np.eye(shape[0], dtype=model_spec.DTYPE) * 1e-1
],
covariance_burnin=200,
),
bijector=tfp.bijectors.Exp(),
......@@ -129,11 +136,12 @@ if __name__ == "__main__":
return fn
def make_xi_kernel(shape, name):
def fn(target_log_prob_fn, state):
def fn(target_log_prob_fn, _):
return AdaptiveRandomWalkMetropolis(
target_log_prob_fn=target_log_prob_fn,
initial_state=tf.ones(shape, dtype=model_spec.DTYPE),
initial_covariance=[np.eye(shape[0]) * 1e-1],
initial_covariance=[
np.eye(shape[0], dtype=model_spec.DTYPE) * 1e-1
],
covariance_burnin=200,
name=name,
)
......@@ -143,7 +151,7 @@ if __name__ == "__main__":
def make_partially_observed_step(
target_event_id, prev_event_id=None, next_event_id=None, name=None
):
def fn(target_log_prob_fn, state):
def fn(target_log_prob_fn, _):
return tfp.mcmc.MetropolisHastings(
inner_kernel=UncalibratedEventTimesUpdate(
target_log_prob_fn=target_log_prob_fn,
......@@ -161,7 +169,7 @@ if __name__ == "__main__":
return fn
def make_occults_step(prev_event_id, target_event_id, next_event_id, name):
def fn(target_log_prob_fn, state):
def fn(target_log_prob_fn, _):
return tfp.mcmc.MetropolisHastings(
inner_kernel=UncalibratedOccultUpdate(
target_log_prob_fn=target_log_prob_fn,
......@@ -178,7 +186,7 @@ if __name__ == "__main__":
return fn
def make_event_multiscan_kernel(target_log_prob_fn, state):
def make_event_multiscan_kernel(target_log_prob_fn, _):
return MultiScanKernel(
config["mcmc"]["num_event_time_updates"],
GibbsKernel(
......@@ -261,7 +269,7 @@ if __name__ == "__main__":
current_state = [
np.array([0.65, 0.48], dtype=DTYPE),
np.zeros(model.model["xi"](0.).event_shape[-1]+1, dtype=DTYPE),
np.zeros(model.model["xi"](0.0).event_shape[-1] + 1, dtype=DTYPE),
events,
]
......@@ -290,7 +298,9 @@ if __name__ == "__main__":
dtype=np.float64,
)
xi_samples = posterior.create_dataset(
"samples/xi", [NUM_SAVED_SAMPLES, current_state[1].shape[0]], dtype=np.float64,
"samples/xi",
[NUM_SAVED_SAMPLES, current_state[1].shape[0]],
dtype=np.float64,
)
event_samples = posterior.create_dataset(
"samples/events",
......@@ -303,14 +313,10 @@ if __name__ == "__main__":
output_results = [
posterior.create_dataset(
"results/theta",
(NUM_SAVED_SAMPLES, 3),
dtype=DTYPE,
"results/theta", (NUM_SAVED_SAMPLES, 3), dtype=DTYPE,
),
posterior.create_dataset(
"results/xi",
(NUM_SAVED_SAMPLES, 3),
dtype=DTYPE,
"results/xi", (NUM_SAVED_SAMPLES, 3), dtype=DTYPE,
),
posterior.create_dataset(
"results/move/S->E",
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment