# Categorical2.py
"""Categorical2 corrects a bug in the tfd.Categorical.log_prob"""

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import distribution_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.distributions.categorical import (
    _broadcast_cat_event_and_params,
)

tfd = tfp.distributions

# TODO: remove this class once https://github.com/tensorflow/tensorflow/issues/40606
#   is fixed.
class Categorical2(tfd.Categorical):
    """Categorical distribution with a replacement `log_prob`.

    Only `_log_prob` is overridden; it avoids the faulty `tfd.Categorical`
    implementation reported in
    https://github.com/tensorflow/tensorflow/issues/40606.
    Everything else is inherited from `tfd.Categorical`.
    """

    def _log_prob(self, k):
        """Return the log-probability of category index `k`.

        Args:
          k: Category indices, broadcastable against the distribution's
            parameters. When `validate_args` is set, `k` is checked with
            `embed_check_integer_casting_closed` against `self.dtype`.

        Returns:
          Log-probabilities of `k`, shaped like `k` after it has been
          broadcast against the distribution's logits.
        """
        with tf.name_scope("Cat2log_prob"):
            logits = self.logits_parameter()
            if self.validate_args:
                k = distribution_util.embed_check_integer_casting_closed(
                    k, target_dtype=self.dtype
                )
            # Assumes (per TFP's implementation) that `k` comes back with
            # some shape S and `logits` with shape S + [num_categories].
            k, logits = _broadcast_cat_event_and_params(
                k, logits, base_dtype=dtype_util.base_dtype(self.dtype)
            )
            # Numerically stable form of log(softmax(logits)): the naive form
            # underflows to log(0) = -inf for strongly negative logits and
            # produces non-finite gradients there.
            log_probs = tf.nn.log_softmax(logits)
            # Collapse S into a single batch axis so the gather is correct for
            # any rank of S (including a scalar event), then restore S.
            # A fixed `batch_dims=1` only works when S has rank exactly 1.
            num_categories = tf.shape(log_probs)[-1]
            flat_log_probs = tf.reshape(log_probs, [-1, num_categories])
            flat_k = tf.reshape(k, [-1])
            flat_result = tf.gather(flat_log_probs, flat_k, batch_dims=1)
            return tf.reshape(flat_result, tf.shape(k))