Commit ca7e464c authored by Chris Jewell's avatar Chris Jewell
Browse files

Merged in fix code for tf seeding issue.

parent e80ed802
......@@ -27,4 +27,4 @@ class Categorical2(tfd.Categorical):
k, logits, base_dtype=dtype_util.base_dtype(self.dtype)
)
logits_normalised = tf.math.log(tf.math.softmax(logits))
return tf.gather(logits_normalised, k, batch_dims=1)
return tf.cast(tf.gather(logits_normalised, k, batch_dims=1), tf.float64)
......@@ -69,7 +69,6 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
dmax,
mmax,
nmax,
seed=None,
name=None,
):
"""An uncalibrated random walk for event times.
......@@ -82,7 +81,6 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
:param seed: a random seed
:param name: the name of the update step
"""
self._seed_stream = SeedStream(seed, salt="UncalibratedEventTimesUpdate")
self._name = name
self._parameters = dict(
target_log_prob_fn=target_log_prob_fn,
......@@ -93,7 +91,6 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
dmax=dmax,
mmax=mmax,
nmax=nmax,
seed=seed,
name=name,
)
self.tx_topology = TransitionTopology(
......@@ -134,7 +131,7 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
def is_calibrated(self):
return False
def one_step(self, current_events, previous_kernel_results):
def one_step(self, current_events, previous_kernel_results, seed=None):
"""One update of event times.
:param current_events: a [T, M, X] tensor containing number of events
per time t, metapopulation m,
......@@ -155,7 +152,7 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
d_max=self.parameters["dmax"],
n_max=self.parameters["nmax"],
)
update = proposal.sample()
update = proposal.sample(seed=seed)
move = update["move"]
to_t = move["t"] + move["delta_t"]
......
......@@ -132,7 +132,7 @@ def EventTimeProposal(
def t():
with tf.name_scope("t"):
# Waiting for fixed tf.nn.sparse_softmax_cross_entropy_with_logits
x = tf.cast(target_events > 0, dtype=tf.float64) # [M, T]
x = tf.cast(target_events > 0, dtype=tf.float32) # [M, T]
return Categorical2(logits=tf.math.log(x), name="event_coords")
def x_star(t, delta_t):
......
......@@ -63,51 +63,12 @@ class UncalibratedLogRandomWalk(tfp.mcmc.UncalibratedRandomWalk):
for s in current_state_parts
]
# Seed handling complexity is due to users possibly expecting an old-style
# stateful seed to be passed to `self.new_state_fn`.
# In other words:
# - If we were given a seed, we sanitize it to stateless, and
# if the `new_state_fn` doesn't like that, we crash and propagate
# the error. Rationale: The contract is stateless sampling given
# seed, and doing otherwise would not meet it.
# - If we were not given a seed, we try `new_state_fn` with a stateless
# seed. Rationale: This is the future.
# - If it fails with a seed incompatibility problem (as best we can
# detect from here), we issue a warning and try it again with a
# stateful-style seed. Rationale: User code that didn't set seeds
# shouldn't suddenly break.
# TODO(b/159636942): Clean up after 2020-09-20.
if seed is not None:
force_stateless = True
seed = samplers.sanitize_seed(seed)
else:
force_stateless = False
if self._seed_stream.original_seed is not None:
warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
stateful_seed = self._seed_stream()
seed = samplers.sanitize_seed(stateful_seed)
try:
# Log random walk
next_state_parts = self.new_state_fn( # pylint: disable=not-callable
[tf.zeros_like(s) for s in current_state_parts], seed,
)
except TypeError as e:
if (
"Expected int for argument" not in str(e)
and TENSOR_SEED_MSG_PREFIX not in str(e)
) or force_stateless:
raise
msg = (
"Falling back to `int` seed for `new_state_fn` {}. Please update "
"to use `tf.random.stateless_*` RNGs. "
"This fallback may be removed after 10-Sep-2020. ({})"
)
warnings.warn(msg.format(self.new_state_fn, str(e)))
seed = None
next_state_parts = self.new_state_fn(
[tf.zeros_lik(s) for s in current_state_parts], stateful_seed
)
seed = samplers.sanitize_seed(seed)
# Log random walk
next_state_parts = self.new_state_fn( # pylint: disable=not-callable
[tf.zeros_like(s) for s in current_state_parts], seed,
)
next_state_parts = [
cs * tf.exp(ns) for cs, ns in zip(current_state_parts, next_state_parts)
]
......
......@@ -66,7 +66,7 @@ def DelOccultProposal(
)
hot_meta = tf.cast(tf.transpose(hot_meta), dtype=events.dtype)
logits = tf.math.log(hot_meta)
X = Categorical2(logits=logits, dtype=dtype, name="m")
X = Categorical2(logits=tf.cast(logits, tf.float32), dtype=dtype, name="m")
return X
def t(m):
......@@ -80,7 +80,9 @@ def DelOccultProposal(
)
hot_times = tf.cast(hot_times, dtype=events.dtype)
logits = tf.math.log(hot_times)
return Categorical2(logits=logits, dtype=dtype, name="t")
return Categorical2(
logits=tf.cast(logits, tf.float32), dtype=dtype, name="t"
)
def x_star(m, t):
"""Draw num to delete"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment