Skip to content

Improve efficiency of stochastic model.

Chris Jewell requested to merge github/fork/csuter/stochastic into stochastic

Created by: csuter

Big changes:

  1. replace python for loop with tf.while_loop
  2. work with a transposed state tensor shape
  • instead of [4, nlads * nages], use [nlads * nages, 4]
  • this made it pretty easy to eliminate some transposes in propagate_fn (there were comments there seemingly contemplating this shape arrangement)
  • this feels a little more natural to me, too; in TFP we'd call the 4 SEIR states components of the "event shape" of the system, and the nlads * nages part a "batch shape" (although one could reasonably also combine these together into one big matrix "event shape")
  • anyway, this allowed elimination of 3 transpose ops which makes for simpler code and avoids some memcpys
  • I also made an effort to update surrounding code to use the same data layout, but it seems like mcmc.py and covid_ode.py are broken right now anyway, due to other changes made in support of stochastic mode, so I couldn't confirm that my changes were sufficient.
  1. switch off XLA (which didn't yield any clear improvement, although it also didn't really hurt), and disable autograph (which tries to do things like rewrite python for loops into TF graph code but tends to produce less performant than manually optimized code like what I've done here)

Merge request reports

Loading