util.py 4.22 KB
Newer Older
Chris Jewell's avatar
Chris Jewell committed
1
2
3
4
"""Utility functions for model implementation code"""

import numpy as np
import tensorflow as tf
5
from tensorflow_probability.python.mcmc.internal import util as mcmc_util
Chris Jewell's avatar
Chris Jewell committed
6
from tensorflow_probability.python.internal import prefer_static as ps
Chris Jewell's avatar
Chris Jewell committed
7

8
9
10

def which(predicate):
    """Return the positions of the True entries of ``predicate``.

    ``predicate`` is expected to be 1-D with a statically known length ``n``
    (its ``shape[0]`` is read below). The result is an int32 tensor of static
    length ``n - 1``. Its leading entries are the positions of the True
    entries in ascending order; the remaining entries are zero.

    NOTE(review): if *every* entry of ``predicate`` is True, the largest
    scatter slot equals ``n``, which is out of range for the length-``n``
    scatter target. Per the ``tf.scatter_nd`` docs that is an error on CPU
    and silently ignored on GPU. Confirm callers never pass an all-True mask.
    """
    with tf.name_scope("which"):
        mask = tf.cast(predicate, dtype=tf.int32)
        length = mask.shape[0]
        # Each True entry gets its 1-based rank among the True entries;
        # every False entry is routed to slot 0.
        slots = mask * tf.cumsum(mask)
        scattered = tf.scatter_nd(
            slots[:, tf.newaxis], tf.range(length), mask.shape
        )
        # Slot 0 accumulates the (summed) positions of the False entries,
        # so it carries no useful information and is discarded.
        return scattered[1:]


19
def _gen_index(state_shape, trm_coords):
    """Returns a tensor of indices indexing coordinates ``trm_coords`` into a
    ``state_shape + [state_shape[-1]]`` tensor, once per batch member.

    :param state_shape: shape of the state tensor, ``batch_shape + [ns]``,
                        as a list or tuple of Python ints.
    :param trm_coords: array-like of shape ``[R, 2]`` holding (row, column)
                       coordinates inside an ``[ns, ns]`` matrix.
    :returns: integer tensor of shape
              ``batch_shape + [R, len(state_shape) + 1]`` whose entry
              ``[b..., r, :]`` is ``(b..., *trm_coords[r])``. Its dtype is
              that of ``trm_coords`` after conversion to a tensor.
    """
    # Accept tuples etc.; list concatenation below needs a real list.
    state_shape = list(state_shape)
    trm_coords = tf.convert_to_tensor(trm_coords)

    batch_shape = state_shape[:-1]
    num_coords = trm_coords.shape[0]
    i_shp = batch_shape + [num_coords, len(state_shape) + 1]

    batch_size = int(np.prod(batch_shape, dtype=np.int64))
    # Row-major multi-index of every batch member: [batch_size, batch_rank].
    # Vectorised (no Python loop over np.ndindex), and also valid for an
    # empty batch_shape or a zero-sized batch dimension.
    b_idx = np.indices(batch_shape).reshape(len(batch_shape), batch_size).T
    # Each batch index is repeated once per coordinate pair, then cast to the
    # coordinate dtype so tf.concat never sees mixed int32/int64 inputs.
    b_idx = np.repeat(b_idx, num_coords, axis=0).astype(
        trm_coords.dtype.as_numpy_dtype
    )
    # Coordinate pairs cycle fastest, matching the repeat order of b_idx.
    m_idx = tf.tile(trm_coords, [batch_size, 1])

    idx = tf.concat([b_idx, m_idx], axis=-1)
    return tf.reshape(idx, i_shp)


35
def make_transition_matrix(rates, rate_coords, state_shape):
    """Create a transition rate matrix for every batch member.

    :param rates: batched transition rate tensors [b1, b2, n_rates], or a list
                  of length n_rates of batched tensors [b1, b2] (stacked along
                  a new last axis). The leading (batch) dimensions must equal
                  ``state_shape[:-1]``.
    :param rate_coords: [n_rates, 2] (row, column) coordinates of each rate in
                        the resulting transition matrix.
    :param state_shape: the shape of the state tensor with ns states, e.g.
                        [b1, b2, ns], as a list or tuple of ints.
    :returns: a tensor of shape ``state_shape + [ns]``, i.e. [..., ns, ns].
              Entries not named in ``rate_coords`` are zero. Rates sharing a
              coordinate are summed (``tf.scatter_nd`` semantics).
    """
    # Normalise so `+ [...]` concatenates instead of failing (tuple) or
    # broadcasting (numpy array).
    state_shape = list(state_shape)
    indices = _gen_index(state_shape, rate_coords)

    if mcmc_util.is_list_like(rates):
        rates = tf.stack(rates, axis=-1)

    output_shape = state_shape + [state_shape[-1]]
    return tf.scatter_nd(
        indices=indices, updates=rates, shape=output_shape, name="build_markov_matrix"
    )
Chris Jewell's avatar
Chris Jewell committed
51
52
53
54
55
56
57
58
59
60
61
62


def compute_state(initial_state, events, stoichiometry):
    """Computes a state tensor from initial state and event tensor

    :param initial_state: a tensor of shape [M, S]
    :param events: a tensor of shape [M, T, X]
    :param stoichiometry: a stoichiometry matrix of shape [X, S] describing
                          how transitions update the state. It is cast (if
                          already a ``tf.Tensor``) or converted to
                          ``events.dtype``.
    :return: a tensor of shape [M, T, S] describing the state of the
             system for each batch M at time T. The state at time t is the
             initial state plus the effect of all events strictly before t
             (exclusive cumulative sum), so the first time step equals
             ``initial_state``.
    """
    events_dtype = events.dtype
    if not isinstance(stoichiometry, tf.Tensor):
        stoich = tf.convert_to_tensor(stoichiometry, dtype=events_dtype)
    else:
        stoich = tf.cast(stoichiometry, dtype=events_dtype)

    # Per-step change of each state: contract the event axis X of `events`
    # with the X axis of the stoichiometry matrix (mtx,xs->mts).
    per_step_delta = tf.tensordot(events, stoich, axes=[[-1], [-2]])

    # Exclusive cumsum over T: events at step t do not affect the state at t.
    running_delta = tf.cumsum(per_step_delta, axis=-2, exclusive=True)
    return running_delta + tf.expand_dims(initial_state, axis=-2)
Chris Jewell's avatar
Chris Jewell committed
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113


def transition_coords(stoichiometry):
    """Return the (source, destination) state indices of each transition.

    :param stoichiometry: array-like of shape [X, S]. In each transition row
                          the source state is marked -1 and the destination
                          state +1.
    :returns: integer array of shape [X', 2]. Row i is (src, dest) for the
              i-th row containing a -1/+1 pair; all-zero rows are skipped.
    :raises ValueError: if the -1 and +1 entries do not pair up row by row
                        (e.g. a row with only a +1). Without this check,
                        sources and destinations from different rows would be
                        silently mis-paired.
    """
    # Plain nested lists would otherwise compare as a single scalar False.
    stoichiometry = np.asarray(stoichiometry)
    src_rows, src = np.nonzero(stoichiometry == -1)
    dest_rows, dest = np.nonzero(stoichiometry == 1)
    if not np.array_equal(src_rows, dest_rows):
        raise ValueError(
            "each transition row of stoichiometry must contain matching -1 "
            "(source) and +1 (destination) entries"
        )
    return np.stack([src, dest], axis=-1)


def batch_gather(tensor, indices):
    """Gather the same non-batch coordinates from every batch member.

    Written by Chris Suter (c) 2020
    Modified by Chris Jewell, 2020

    :param tensor: tensor of shape ``B + E`` (batch dims followed by event dims).
    :param indices: integer array-like of shape ``[N, r(E)]``; each row is a
                    coordinate into the event dims ``E``.
    :returns: tensor of shape ``B + [N]`` with
              ``out[b..., n] == tensor[b..., *indices[n]]``.

    The batch shape ``B`` is enumerated with numpy, so it must be statically
    known.
    """
    indices = tf.convert_to_tensor(indices)
    tensor_shape = ps.shape(tensor)  # B + E
    tensor_rank = ps.rank(tensor)
    indices_shape = ps.shape(indices)  # [N, r(E)]
    num_outputs = indices_shape[0]
    non_batch_rank = indices_shape[1]  # r(E)
    batch_rank = tensor_rank - non_batch_rank

    batch_shape = tensor_shape[:batch_rank]
    batch_size = int(np.prod(batch_shape))
    # Row-major multi-index of each batch member: [batch_size, batch_rank].
    # The shape is passed positionally; the old `dims=` keyword is not part
    # of np.unravel_index's current signature. Cast so tf.concat below does
    # not mix int64 with the dtype of `indices`.
    batch_indices = np.reshape(
        np.transpose(np.unravel_index(np.arange(batch_size), batch_shape)),
        [batch_size, batch_rank],
    ).astype(indices.dtype.as_numpy_dtype)

    # Repeat each batch index once per requested output: [batch_size * N, batch_rank].
    batch_indices_tiled = tf.reshape(
        tf.tile(batch_indices, multiples=[1, num_outputs]),
        [batch_size * num_outputs, batch_rank],
    )

    # Event coordinates cycle fastest, matching the repeat order above.
    batched_output_indices = tf.tile(indices, multiples=[batch_size, 1])
    full_indices = tf.concat([batch_indices_tiled, batched_output_indices], axis=-1)

    output_shape = ps.concat([batch_shape, [num_outputs]], axis=0)
    return tf.reshape(tf.gather_nd(tensor, full_indices), output_shape)