Rewrote discrete_markov_log_prob to batch its calls to tfd.Multinomial.log_prob into a single call, reducing the time spent in lgamma.