dtype bug in DiscreteTimeStateTransitionModel class log_prob()
I think there's a bug in the log_prob function of your DiscreteTimeStateTransitionModel class. The code below fails with a dtype error raised from log_prob when I include a seasonality term of the form tf.math.sin(t / 3.14) / 2 + 1 in txrates(). It works if I change the seasonality term to tf.math.sin(tf.cast(t, dtype) / 3.14) / 2 + 1, so I suspect this is a bug.
However the seasonality term is written, I can still plot the time series, so eventlist and compute_state() appear, on the face of it, to be fine. The problem therefore seems to be specific to the log_prob function.
import tensorflow as tf
from gemlib.distributions.discrete_time_state_transition_model import DiscreteTimeStateTransitionModel
from gemlib.util import compute_state
dtype = tf.float32
# Initial state, counts per compartment (S, I, R), for one population
initial_state = tf.constant([[99, 1, 0]], dtype)
# Stoichiometry matrix; columns are S, I, R
stoichiometry = tf.constant([[-1, 1, 0],   # S->I
                             [0, -1, 1]],  # I->R
                            dtype)
def txrates(t, state):
    """Transition rate per individual corresponding to each row of the stoichiometry matrix.
    Args:
        t: Python float representing the current time e.g. seasonality in S->I transition could
            be driven by tensors of the following form:
                seasonality = tf.math.sin(t / 3.14) / 2 + 1
                si = seasonality * beta * state[:, 1] / tf.reduce_sum(state)
        state: `Tensor` representing the current state (count of individuals in each compartment).
    Returns: List of `Tensor`(s) each of which corresponds to a transition.
    """
    # log_prob errors with this term but doesn't error if 't' is replaced by 'tf.cast(t, dtype)'
    aa = tf.math.sin(t / 3.14) / 2 + 1
    beta, gamma = 0.28, 0.14  # note R0 = beta / gamma
    si = aa * beta * state[:, 1] / tf.reduce_sum(state)  # S->I transition rate
    ir = tf.constant([gamma], dtype)  # I->R transition rate
    return [si, ir]
# Instantiate model
sir = DiscreteTimeStateTransitionModel(
    transition_rates=txrates,
    stoichiometry=stoichiometry,
    initial_state=initial_state,
    initial_step=0,
    time_delta=1.0,
    num_steps=100,
)
@tf.function
def simulate_one(elems):
    """One realisation of the epidemic process."""
    return sir.sample()
nsim = 15 # Number of realisations of the epidemic process
eventlist = tf.map_fn(simulate_one, tf.ones([nsim, stoichiometry.shape[0]]))
# Log prob of observing the eventlist, of first simulation, given the model
print('Log prob:', sir.log_prob(eventlist[0, ...]))
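To illustrate what I think is going on, here is a minimal sketch that does not use gemlib at all. It assumes (and this is only my guess, I have not checked the gemlib source) that log_prob hands t to txrates() as a tensor whose dtype is not float32; I use float64 here purely as a stand-in. The helper names seasonality_uncast and seasonality_cast are hypothetical and only mirror the two forms of the seasonality term above.

```python
import tensorflow as tf

dtype = tf.float32
beta = tf.constant(0.28, dtype)


def seasonality_uncast(t):
    # Same form as in txrates(): the result inherits the dtype of t.
    return tf.math.sin(t / 3.14) / 2 + 1


def seasonality_cast(t):
    # Workaround form: force t to the model dtype first.
    return tf.math.sin(tf.cast(t, dtype) / 3.14) / 2 + 1


# ASSUMPTION: stand-in for whatever dtype log_prob uses for t internally.
t64 = tf.constant(5.0, tf.float64)

# TensorFlow does not promote float64 * float32, so this is expected to raise.
try:
    seasonality_uncast(t64) * beta
    mismatch_error = None
except (TypeError, ValueError, tf.errors.InvalidArgumentError) as err:
    mismatch_error = err

# With the cast, everything stays in float32 and the product is fine.
rate_ok = seasonality_cast(t64) * beta

print("uncast error:", type(mismatch_error).__name__)
print("cast dtype:", rate_ok.dtype)
```

If that guess is right, it would also explain why the uncast term only fails in log_prob: with a plain Python float for t, the uncast expression mixes with float32 tensors without complaint.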