Commit ca7e464c authored by Chris Jewell's avatar Chris Jewell
Browse files

Merged in fix code for tf seeding issue.

parent e80ed802
...@@ -27,4 +27,4 @@ class Categorical2(tfd.Categorical): ...@@ -27,4 +27,4 @@ class Categorical2(tfd.Categorical):
k, logits, base_dtype=dtype_util.base_dtype(self.dtype) k, logits, base_dtype=dtype_util.base_dtype(self.dtype)
) )
logits_normalised = tf.math.log(tf.math.softmax(logits)) logits_normalised = tf.math.log(tf.math.softmax(logits))
return tf.gather(logits_normalised, k, batch_dims=1) return tf.cast(tf.gather(logits_normalised, k, batch_dims=1), tf.float64)
...@@ -69,7 +69,6 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel): ...@@ -69,7 +69,6 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
dmax, dmax,
mmax, mmax,
nmax, nmax,
seed=None,
name=None, name=None,
): ):
"""An uncalibrated random walk for event times. """An uncalibrated random walk for event times.
...@@ -82,7 +81,6 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel): ...@@ -82,7 +81,6 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
:param seed: a random seed :param seed: a random seed
:param name: the name of the update step :param name: the name of the update step
""" """
self._seed_stream = SeedStream(seed, salt="UncalibratedEventTimesUpdate")
self._name = name self._name = name
self._parameters = dict( self._parameters = dict(
target_log_prob_fn=target_log_prob_fn, target_log_prob_fn=target_log_prob_fn,
...@@ -93,7 +91,6 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel): ...@@ -93,7 +91,6 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
dmax=dmax, dmax=dmax,
mmax=mmax, mmax=mmax,
nmax=nmax, nmax=nmax,
seed=seed,
name=name, name=name,
) )
self.tx_topology = TransitionTopology( self.tx_topology = TransitionTopology(
...@@ -134,7 +131,7 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel): ...@@ -134,7 +131,7 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
def is_calibrated(self): def is_calibrated(self):
return False return False
def one_step(self, current_events, previous_kernel_results): def one_step(self, current_events, previous_kernel_results, seed=None):
"""One update of event times. """One update of event times.
:param current_events: a [T, M, X] tensor containing number of events :param current_events: a [T, M, X] tensor containing number of events
per time t, metapopulation m, per time t, metapopulation m,
...@@ -155,7 +152,7 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel): ...@@ -155,7 +152,7 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
d_max=self.parameters["dmax"], d_max=self.parameters["dmax"],
n_max=self.parameters["nmax"], n_max=self.parameters["nmax"],
) )
update = proposal.sample() update = proposal.sample(seed=seed)
move = update["move"] move = update["move"]
to_t = move["t"] + move["delta_t"] to_t = move["t"] + move["delta_t"]
......
...@@ -132,7 +132,7 @@ def EventTimeProposal( ...@@ -132,7 +132,7 @@ def EventTimeProposal(
def t(): def t():
with tf.name_scope("t"): with tf.name_scope("t"):
# Waiting for fixed tf.nn.sparse_softmax_cross_entropy_with_logits # Waiting for fixed tf.nn.sparse_softmax_cross_entropy_with_logits
x = tf.cast(target_events > 0, dtype=tf.float64) # [M, T] x = tf.cast(target_events > 0, dtype=tf.float32) # [M, T]
return Categorical2(logits=tf.math.log(x), name="event_coords") return Categorical2(logits=tf.math.log(x), name="event_coords")
def x_star(t, delta_t): def x_star(t, delta_t):
......
...@@ -63,51 +63,12 @@ class UncalibratedLogRandomWalk(tfp.mcmc.UncalibratedRandomWalk): ...@@ -63,51 +63,12 @@ class UncalibratedLogRandomWalk(tfp.mcmc.UncalibratedRandomWalk):
for s in current_state_parts for s in current_state_parts
] ]
# Seed handling complexity is due to users possibly expecting an old-style seed = samplers.sanitize_seed(seed)
# stateful seed to be passed to `self.new_state_fn`.
# In other words:
# - If we were given a seed, we sanitize it to stateless, and
# if the `new_state_fn` doesn't like that, we crash and propagate
# the error. Rationale: The contract is stateless sampling given
# seed, and doing otherwise would not meet it.
# - If we were not given a seed, we try `new_state_fn` with a stateless
# seed. Rationale: This is the future.
# - If it fails with a seed incompatibility problem (as best we can
# detect from here), we issue a warning and try it again with a
# stateful-style seed. Rationale: User code that didn't set seeds
# shouldn't suddenly break.
# TODO(b/159636942): Clean up after 2020-09-20.
if seed is not None:
force_stateless = True
seed = samplers.sanitize_seed(seed)
else:
force_stateless = False
if self._seed_stream.original_seed is not None:
warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
stateful_seed = self._seed_stream()
seed = samplers.sanitize_seed(stateful_seed)
try:
# Log random walk
next_state_parts = self.new_state_fn( # pylint: disable=not-callable
[tf.zeros_like(s) for s in current_state_parts], seed,
)
except TypeError as e:
if (
"Expected int for argument" not in str(e)
and TENSOR_SEED_MSG_PREFIX not in str(e)
) or force_stateless:
raise
msg = (
"Falling back to `int` seed for `new_state_fn` {}. Please update "
"to use `tf.random.stateless_*` RNGs. "
"This fallback may be removed after 10-Sep-2020. ({})"
)
warnings.warn(msg.format(self.new_state_fn, str(e)))
seed = None
next_state_parts = self.new_state_fn(
[tf.zeros_lik(s) for s in current_state_parts], stateful_seed
)
# Log random walk
next_state_parts = self.new_state_fn( # pylint: disable=not-callable
[tf.zeros_like(s) for s in current_state_parts], seed,
)
next_state_parts = [ next_state_parts = [
cs * tf.exp(ns) for cs, ns in zip(current_state_parts, next_state_parts) cs * tf.exp(ns) for cs, ns in zip(current_state_parts, next_state_parts)
] ]
......
...@@ -66,7 +66,7 @@ def DelOccultProposal( ...@@ -66,7 +66,7 @@ def DelOccultProposal(
) )
hot_meta = tf.cast(tf.transpose(hot_meta), dtype=events.dtype) hot_meta = tf.cast(tf.transpose(hot_meta), dtype=events.dtype)
logits = tf.math.log(hot_meta) logits = tf.math.log(hot_meta)
X = Categorical2(logits=logits, dtype=dtype, name="m") X = Categorical2(logits=tf.cast(logits, tf.float32), dtype=dtype, name="m")
return X return X
def t(m): def t(m):
...@@ -80,7 +80,9 @@ def DelOccultProposal( ...@@ -80,7 +80,9 @@ def DelOccultProposal(
) )
hot_times = tf.cast(hot_times, dtype=events.dtype) hot_times = tf.cast(hot_times, dtype=events.dtype)
logits = tf.math.log(hot_times) logits = tf.math.log(hot_times)
return Categorical2(logits=logits, dtype=dtype, name="t") return Categorical2(
logits=tf.cast(logits, tf.float32), dtype=dtype, name="t"
)
def x_star(m, t): def x_star(m, t):
"""Draw num to delete""" """Draw num to delete"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment