Skip to content

dtype bug in DiscreteTimeStateTransitionModel class log_prob()

I think there's a bug in your log_prob function in DiscreteTimeStateTransitionModel class. The code below fails (with dtype error from log_prob) when I insert a seasonality term in txrates() of the form tf.math.sin(t / 3.14) / 2 + 1. However the code works if I change the seasonality term to tf.math.sin(tf.cast(t,dtype) / 3.14) / 2 + 1 I suggest this is a bug!

No matter how the seasonality term is written I can plot the timeseries hence eventlist and compute_state() seem on the face of it to be ok. So the problem does appear to be specific to the log_prob function.


import tensorflow as tf
from gemlib.distributions.discrete_time_state_transition_model import DiscreteTimeStateTransitionModel
from gemlib.util import compute_state

dtype = tf.float32

# Initial state, counts per compartment (S, I, R), for one population
initial_state = tf.constant([[99, 1, 0]], dtype)

# Stoichiometry matrix      #  S, I, R
stoichiometry = tf.constant([[-1, 1, 0],  # S->I
                             [0, -1, 1]],  # I->R
                            dtype)


def txrates(t, state):
    """Transition rate per individual corresponding to each row of the stoichiometry matrix.
        Args:
            state: `Tensor` representing the current state (count of individuals in each compartment).
            t: Python float representing the current time e.g. seasonality in S->I transition could
                be driven by tensors of the following form:
                    seasonality = tf.math.sin(t / 3.14) / 2 + 1
                    si = seasonality * beta * state[:, 1] / tf.reduce_sum(state)
        Returns: List of `Tensor`(s) each of which corresponds to a transition.
    """
    aa = tf.math.sin(t / 3.14) / 2 + 1 # log prob errors with this term but doesn't error if 't' is replaced by 'tf.cast(t,dtype)'
    beta, gamma = 0.28, 0.14  # note R0=beta/gamma
    si = aa * beta * state[:, 1] / tf.reduce_sum(state)  # S->I transition rate
    ir = tf.constant([gamma], dtype)  # I->R transition rate
    return [si, ir]


# Instantiate model
sir = DiscreteTimeStateTransitionModel(
    transition_rates=txrates,
    stoichiometry=stoichiometry,
    initial_state=initial_state,
    initial_step=0,
    time_delta=1.0,
    num_steps=100,
)


@tf.function
def simulate_one(elems):
    """One realisation of the epidemic process."""
    return sir.sample()


nsim = 15  # Number of realisations of the epidemic process
eventlist = tf.map_fn(simulate_one, tf.ones([nsim, stoichiometry.shape[0]]))

# Log prob of observing the eventlist, of first simulation, given the model
print('Log prob:', sir.log_prob(eventlist[0, ...]))