Created by: csuter
The for loop based approach, for reasons I don't yet understand, incurs XLA compilation times that scale poorly with inference period (at least linearly, maybe as much as quadratically; haven't measured). We are working on incorporating the Multinomial code I added here into TFP's Multinomial distribution, but until that is checked in (should be soon), I wanted to get this in here.
We do take a hit here on the post-compilation iteration times, but I suspect we can improve further. I'm seeing the following numbers; this is for an inference period of length 132 (sorry for the weird number...)
Run 1: 165.3 seconds Run 2: 1.796 seconds ~= 0.0136 per iter
Run 1: 18.224 seconds Run 2: 6.886 seconds ~= 0.052 per iter
If the long compile and faster iteration time is preferable, feel free not to merge this PR!