Commit 3e15f0cf authored by Chris Jewell's avatar Chris Jewell
Browse files

Added contraint-obeying occult deletion.

parent 2bb80563
......@@ -117,7 +117,12 @@ class UncalibratedOccultUpdate(tfp.mcmc.TransitionKernel):
x=self.tx_topology.target,
x_star=tf.cast(update["x_star"], current_events.dtype),
)
reverse = DelOccultProposal(next_state, self.tx_topology)
reverse = DelOccultProposal(
events=next_state,
topology=self.tx_topology,
t_range=self.parameters["t_range"],
n_max=self.parameters["nmax"],
)
q_fwd = tf.reduce_sum(proposal.log_prob(update))
q_rev = tf.reduce_sum(reverse.log_prob(update))
log_acceptance_correction = q_rev - q_fwd
......@@ -125,7 +130,12 @@ class UncalibratedOccultUpdate(tfp.mcmc.TransitionKernel):
def del_occult_fn():
with tf.name_scope("false_fn"):
proposal = DelOccultProposal(current_events, self.tx_topology)
proposal = DelOccultProposal(
events=current_events,
topology=self.tx_topology,
t_range=self.parameters["t_range"],
n_max=self.parameters["nmax"],
)
update = proposal.sample()
next_state = _add_events(
events=current_events,
......
......@@ -44,13 +44,16 @@ def AddOccultProposal(events, topology, n_max, t_range=None, dtype=tf.int32, nam
return tfd.JointDistributionNamed(dict(m=m, t=t, x_star=x_star), name=name)
def DelOccultProposal(events, topology, dtype=tf.int32, name=None):
def DelOccultProposal(events, topology, n_max, t_range=None, dtype=tf.int32, name=None):
if t_range is None:
t_range = [0, events.shape[-2]]
def m():
"""Select a metapopulation"""
with tf.name_scope("m"):
hot_meta = (
tf.math.count_nonzero(
events[..., topology.target], axis=1, keepdims=True
events[..., slice(*t_range), topology.target], axis=1, keepdims=True
)
> 0
)
......@@ -63,7 +66,11 @@ def DelOccultProposal(events, topology, dtype=tf.int32, name=None):
"""Draw timepoint"""
with tf.name_scope("t"):
metapops = tf.gather(events, m)
hot_times = metapops[..., topology.target] > 0
hot_times = (
(metapops[..., topology.target] > 0)
& (t_range[0] <= tf.range(events.shape[-2]))
& (tf.range(events.shape[-2]) < t_range[1])
)
hot_times = tf.cast(hot_times, dtype=events.dtype)
logits = tf.math.log(hot_times)
return Categorical2(logits=logits, dtype=dtype, name="t")
......@@ -71,10 +78,21 @@ def DelOccultProposal(events, topology, dtype=tf.int32, name=None):
def x_star(m, t):
"""Draw num to delete"""
with tf.name_scope("x_star"):
indices = tf.stack([m, t, [topology.target]], axis=-1)
max_occults = tf.gather_nd(events, indices)
return UniformInteger(
low=0, high=max_occults + 1, dtype=dtype, name="x_star"
)
if topology.next is not None:
mask = ( # Mask out times prior to t
tf.cast(tf.range(events.shape[-2]) < t[0], events.dtype)
* events.dtype.max
)
m_events = tf.gather(events, m, axis=-3)
diff = m_events[..., topology.target] - m_events[..., topology.next]
diff = tf.cumsum(diff, axis=-1)
diff = diff + mask
bound = tf.cast(tf.reduce_min(diff, axis=-1), dtype=tf.int32)
bound = tf.minimum(n_max, bound)
else:
bound = tf.broadcast_to(n_max, m.shape)
return UniformInteger(low=0, high=bound + 1, dtype=dtype, name="x_star")
return tfd.JointDistributionNamed(dict(m=m, t=t, x_star=x_star), name=name)
......@@ -173,7 +173,7 @@ def make_occults_step(prev_event_id, target_event_id, next_event_id, name):
prev_event_id, target_event_id, next_event_id
),
nmax=config["mcmc"]["occult_nmax"],
t_range=(se_events.shape[1] - 22, se_events.shape[1] - 1),
t_range=(se_events.shape[1] - 21, se_events.shape[1]),
name=name,
)
),
......@@ -319,7 +319,7 @@ theta_scale = tf.constant(
theta_scale = theta_scale * 0.2 / theta_scale.shape[0]
xi_scale = tf.eye(model.num_xi, dtype=DTYPE)
xi_scale = xi_scale * 0.0001 / xi_scale.shape[0]
xi_scale = xi_scale * 0.001 / xi_scale.shape[0]
# We loop over successive calls to sample because we have to dump results
# to disc, or else end OOM (even on a 32GB system).
......@@ -376,10 +376,11 @@ for i in tqdm.tqdm(range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES):
tf.reduce_mean(tf.cast(flat_results[5][:, 1], tf.float32)),
)
print(f"Acceptance param: {output_results[0][:, 1].mean()}")
print(f"Acceptance move S->E: {output_results[1][:, 1].mean()}")
print(f"Acceptance move E->I: {output_results[2][:, 1].mean()}")
print(f"Acceptance occult S->E: {output_results[3][:, 1].mean()}")
print(f"Acceptance occult E->I: {output_results[4][:, 1].mean()}")
print(f"Acceptance theta: {output_results[0][:, 1].mean()}")
print(f"Acceptance xi: {output_results[1][:, 1].mean()}")
print(f"Acceptance move S->E: {output_results[2][:, 1].mean()}")
print(f"Acceptance move E->I: {output_results[3][:, 1].mean()}")
print(f"Acceptance occult S->E: {output_results[4][:, 1].mean()}")
print(f"Acceptance occult E->I: {output_results[5][:, 1].mean()}")
posterior.close()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment