Skip to content

Switch tf.while_loop based Multinomial after all.

Chris Jewell requested to merge github/fork/csuter/stochastic into stochastic

Created by: csuter

The for loop based approach, for reasons I don't yet understand, incurs XLA compilation times that scale poorly with inference period (at least linearly, maybe as much as quadratically; haven't measured). We are working on incorporating the Multinomial code I added here into TFP's Multinomial distribution, but until that is checked in (should be soon), I wanted to get this in here.

We do take a hit here on the post-compilation iteration times, but I suspect we can improve further. I'm seeing the following numbers; this is for an inference period of length 132 (sorry for the weird number...)

For-loop method:

Run 1: 165.3 seconds Run 2: 1.796 seconds ~= 0.0136 per iter

tf.while_loop method:

Run 1: 18.224 seconds Run 2: 6.886 seconds ~= 0.052 per iter

If the long compile and faster iteration time is preferable, feel free not to merge this PR!

Merge request reports

Loading