Commit 2a381548 authored by Chris Jewell's avatar Chris Jewell
Browse files

Workaround for TFP issue #1127

parent 0eb4326b
......@@ -27,4 +27,4 @@ class Categorical2(tfd.Categorical):
k, logits, base_dtype=dtype_util.base_dtype(self.dtype)
)
logits_normalised = tf.math.log(tf.math.softmax(logits))
return tf.gather(logits_normalised, k, batch_dims=1)
return tf.cast(tf.gather(logits_normalised, k, batch_dims=1), tf.float64)
......@@ -69,7 +69,6 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
dmax,
mmax,
nmax,
seed=None,
name=None,
):
"""An uncalibrated random walk for event times.
......@@ -82,7 +81,6 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
:param seed: a random seed
:param name: the name of the update step
"""
self._seed_stream = SeedStream(seed, salt="UncalibratedEventTimesUpdate")
self._name = name
self._parameters = dict(
target_log_prob_fn=target_log_prob_fn,
......@@ -93,7 +91,6 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
dmax=dmax,
mmax=mmax,
nmax=nmax,
seed=seed,
name=name,
)
self.tx_topology = TransitionTopology(
......@@ -134,7 +131,7 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
def is_calibrated(self):
return False
def one_step(self, current_events, previous_kernel_results):
def one_step(self, current_events, previous_kernel_results, seed=None):
"""One update of event times.
:param current_events: a [T, M, X] tensor containing number of events
per time t, metapopulation m,
......@@ -155,7 +152,7 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
d_max=self.parameters["dmax"],
n_max=self.parameters["nmax"],
)
update = proposal.sample()
update = proposal.sample(seed=seed)
move = update["move"]
to_t = move["t"] + move["delta_t"]
......
......@@ -132,7 +132,7 @@ def EventTimeProposal(
def t():
with tf.name_scope("t"):
# Waiting for fixed tf.nn.sparse_softmax_cross_entropy_with_logits
x = tf.cast(target_events > 0, dtype=tf.float64) # [M, T]
x = tf.cast(target_events > 0, dtype=tf.float32) # [M, T]
return Categorical2(logits=tf.math.log(x), name="event_coords")
def x_star(t, delta_t):
......
......@@ -66,7 +66,7 @@ def DelOccultProposal(
)
hot_meta = tf.cast(tf.transpose(hot_meta), dtype=events.dtype)
logits = tf.math.log(hot_meta)
X = Categorical2(logits=logits, dtype=dtype, name="m")
X = Categorical2(logits=tf.cast(logits, tf.float32), dtype=dtype, name="m")
return X
def t(m):
......@@ -80,7 +80,9 @@ def DelOccultProposal(
)
hot_times = tf.cast(hot_times, dtype=events.dtype)
logits = tf.math.log(hot_times)
return Categorical2(logits=logits, dtype=dtype, name="t")
return Categorical2(
logits=tf.cast(logits, tf.float32), dtype=dtype, name="t"
)
def x_star(m, t):
"""Draw num to delete"""
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment