# chainbinom_simulate.py
"""Functions for chain binomial simulation."""

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions


def update_state(update, state, stoichiometry):
    """Add stoichiometry-weighted event counts to the state.

    Shapes below follow the inline annotations of the original code.

    Args:
        update: event counts, R x N.
        state: current state, S x N.
        stoichiometry: R x S matrix of weights; entry [r, s] scales the
            contribution of row r of `update` to row s of `state`.

    Returns:
        `state` plus the weighted sum of `update` over its first axis (S x N).
    """
    events = tf.expand_dims(update, 1)  # R x 1 x N
    weights = tf.expand_dims(stoichiometry, -1)  # R x S x 1
    # Broadcast to R x S x N, then collapse the leading R axis.
    net_change = tf.reduce_sum(events * weights, axis=0)  # S x N
    return state + net_change


def chain_binomial_propagate(h, time_step, stoichiometry):
    """Build a function that advances the state by one chain binomial step.

    Args:
        h: callable taking `state` and returning a `(state_idx, rates)` pair.
            `rates` are converted to per-step event probabilities;
            `state_idx` gives the row positions to which the rows of `state`
            are scattered to form the binomial trial counts (presumably the
            source state of each transition -- confirm against the hazard
            function in use).
        time_step: length of one step; multiplies `rates`.
        stoichiometry: array/tensor with a static `.shape`, annotated R x S
            in the original code. It is reshaped so that it broadcasts
            against the sampled event counts.

    Returns:
        A function `propagate_fn(state)` that returns the state after one
        time step.
    """

    def propagate_fn(state):
        state_idx, rates = h(state)
        # Convert rates to probabilities over the step: p = 1 - exp(-rate*dt).
        probs = 1 - tf.exp(-rates * time_step)  # RxN

        # Scatter the rows of `state` to the positions in `state_idx`; the
        # result is the trial count of the binomial draw. Trailing dimensions
        # are taken from `state` itself so the output shape is not tied to a
        # rank-3 state (for rank 3 this is the same list as before).
        state_mult = tf.scatter_nd(
            state_idx[:, None],
            state,
            shape=[state_idx.shape[0], *state.shape[1:]],
        )
        update = tfd.Binomial(state_mult, probs=probs).sample()  # RxN

        update = tf.expand_dims(update, 1)  # Rx1xN
        # Pad the stoichiometry shape with one singleton axis per trailing
        # dimension of `state` so it broadcasts against `update`.
        upd_shape = tf.concat(
            [stoichiometry.shape, tf.fill([tf.rank(state) - 1], 1)], axis=0
        )
        update *= tf.reshape(stoichiometry, upd_shape)  # RxSx1
        # Sum contributions over the leading axis to get the net change.
        update = tf.reduce_sum(update, axis=0)
        state = state + update
        return state

    return propagate_fn


def chain_binomial_simulate(hazard_fn, state, start, end, time_step, stoichiometry):
    """Simulate a chain binomial process on a regular time grid.

    Args:
        hazard_fn: callable handed to `chain_binomial_propagate`; given the
            state it returns a `(state_idx, rates)` pair.
        state: initial state; written unchanged as the first output entry.
        start: first time point of the grid.
        end: upper limit of the grid (excluded, as with `tf.range`).
        time_step: spacing of the grid and length of each propagation step.
        stoichiometry: forwarded to `chain_binomial_propagate`.

    Returns:
        A tuple `(times, sim)`: `times` is `tf.range(start, end, time_step)`
        and `sim` holds the state at each of those times, stacked along a
        new leading axis.
    """
    propagate = chain_binomial_propagate(hazard_fn, time_step, stoichiometry)
    times = tf.range(start, end, time_step)
    num_times = times.shape[0]

    # NOTE: the output buffer is float64, so every state written to it must
    # be float64 (or convertible to it).
    output = tf.TensorArray(tf.float64, size=num_times)
    output = output.write(0, state)

    for i in tf.range(1, num_times):
        state = propagate(state)
        output = output.write(i, state)

    # Assemble the full trajectory on the CPU.
    with tf.device("/CPU:0"):
        sim = output.gather(tf.range(num_times))
    return times, sim