Commit c4e550c7 authored by Chris Jewell's avatar Chris Jewell
Browse files

Implemented a gemlib math library as a thin wrapper around Tensorflow. This...

Implemented a gemlib math library as a thin wrapper around Tensorflow.  This allows for contextual argument checking and dispatching of specialised TF operations based on dimension of input arguments (see e.g. gem.gemlib.math.matmul).
parent a635a36d
Pipeline #291 passed with stage
in 3 minutes and 56 seconds
......@@ -23,6 +23,7 @@ import inspect
from gem.gemlang.semantics.scope import GlobalScope
from gem.gemlang.semantics.symbol import BuiltinTypeSymbol, SYMBOL_ATTR, _make_method_symbol
from gem.gemlib import distribution as gemdist
from gem.gemlib import math as gemmath
class SymbolTable:
......@@ -57,18 +58,15 @@ class SymbolTable:
# External data
for data_symbol in ['Scalar', 'Vector', 'Matrix', 'NDArray']:
symbol = _make_method_symbol(data_symbol, SymbolTable.EXTERNAL_DATA, [],
symbol.type = SymbolTable.EXTERNAL_DATA
# Builtin functions
global_scope.declare(_make_method_symbol('exp', None,
['x'], enclosing_scope=global_scope))
# Parse builtin RandomVariable distributions
def predicate(x):
return inspect.isclass(x) and \
......@@ -77,6 +75,11 @@ class SymbolTable:
for name, dist in dist_obj:
self.__register_function(dist, SymbolTable.RANDOM_VARIABLE)
# Parse builtin maths functions
funcs = inspect.getmembers(gemmath, inspect.isfunction)
for name, func in funcs:
self.__register_function(func, None)
def __register_function(self, function, return_type=None):
signature = inspect.signature(function)
args = [arg for arg in signature.parameters.keys() if arg is not 'kwargs']
......@@ -72,6 +72,7 @@ class GEMProgram(Outputter):
str = self.line("### Start of GEM-generated code ###") + \
self.line("import tensorflow as tf") + \
self.line("from gem.gemlib.distribution import *") + \
self.line("import gem.gemlib.math as gm") + \
"from gem.model import TFContinuousMarkovILM, make_epidemic_process")
if tf.executing_eagerly():
......@@ -219,7 +220,7 @@ class Truediv(BinaryOp):
super().__init__(lhs, rhs, **kwargs)
def __str__(self):
return f"tf.divide({self.lhs}, {self.rhs})"
return f"gm.truediv({self.lhs}, {self.rhs})"
class Mul(BinaryOp):
......@@ -227,7 +228,7 @@ class Mul(BinaryOp):
super().__init__(lhs, rhs, **kwargs)
def __str__(self):
return f"tf.multiply({self.lhs}, {self.rhs})"
return f"gm.multiply({self.lhs}, {self.rhs})"
class Matmul(BinaryOp):
......@@ -235,7 +236,7 @@ class Matmul(BinaryOp):
super().__init__(lhs, rhs, **kwargs)
def __str__(self):
return f"tf.matmul({self.lhs}, {self.rhs})"
return f"gm.matmul({self.lhs}, {self.rhs})"
class Add(BinaryOp):
......@@ -243,7 +244,7 @@ class Add(BinaryOp):
super().__init__(lhs, rhs, **kwargs)
def __str__(self):
return f"tf.add({self.lhs}, {self.rhs})"
return f"gm.add({self.lhs}, {self.rhs})"
class Sub(BinaryOp):
......@@ -251,7 +252,7 @@ class Sub(BinaryOp):
super().__init__(lhs, rhs, **kwargs)
def __str__(self):
return f"tf.subtract({self.lhs}, {self.rhs})"
return f"gm.subtract({self.lhs}, {self.rhs})"
class Pow(BinaryOp):
......@@ -259,7 +260,7 @@ class Pow(BinaryOp):
super().__init__(lhs, rhs, **kwargs)
def __str__(self):
return f"tf.pow({self.lhs}, {self.rhs})"
return f"gm.pow({self.lhs}, {self.rhs})"
class Lt(BinaryOp):
......@@ -267,7 +268,7 @@ class Lt(BinaryOp):
super().__init__(lhs, rhs, **kwargs)
def __str__(self):
return f"tf.less({self.lhs}, {self.rhs})"
return f"gm.less({self.lhs}, {self.rhs})"
class Gt(BinaryOp):
......@@ -275,7 +276,7 @@ class Gt(BinaryOp):
super().__init__(lhs, rhs, **kwargs)
def __str__(self):
return f"tf.greater({self.lhs}, {self.rhs})"
return f"gm.greater({self.lhs}, {self.rhs})"
class Leq(BinaryOp):
......@@ -283,7 +284,7 @@ class Leq(BinaryOp):
super().__init__(lhs, rhs, **kwargs)
def __str__(self):
return f"tf.less_equal({self.lhs}, {self.rhs})"
return f"gm.less_equal({self.lhs}, {self.rhs})"
class Geq(BinaryOp):
......@@ -291,7 +292,7 @@ class Geq(BinaryOp):
super().__init__(lhs, rhs, **kwargs)
def __str__(self):
return f"tf.greater_equal({self.lhs}, {self.rhs})"
return f"gm.greater_equal({self.lhs}, {self.rhs})"
class Eq(BinaryOp):
......@@ -299,7 +300,7 @@ class Eq(BinaryOp):
super().__init__(lhs, rhs, **kwargs)
def __str__(self):
return f"tf.equal({self.lhs}, {self.rhs})"
return f"gm.equal({self.lhs}, {self.rhs})"
class Ne(BinaryOp):
......@@ -307,7 +308,7 @@ class Ne(BinaryOp):
super().__init__(lhs, rhs, **kwargs)
def __str__(self):
return f"tf.not_equal({self.lhs}, {self.rhs})"
return f"gm.not_equal({self.lhs}, {self.rhs})"
class UnaryOp(Outputter):
......@@ -19,21 +19,63 @@
"""gemlib maths functions, such as operators, exponents, and square roots."""
import tensorflow as tf
from tensorflow import math as tfm
class Math:
def __init__(self, lib):
self.lib = lib
def add(self, a, b):
return self.lib.add(a, b)
def exp(x):
return tfm.exp(x)
def matmul(a, b):
if len(b.shape) == 1:
return tf.linalg.matvec(a, b)
return tf.linalg.matmul(a, b)
def add(a, b):
return tf.add(a, b)
def subtract(a, b):
return tf.subtract(a, b)
def multiply(a, b):
return tf.multiply(a, b)
def truediv(a, b):
return tf.divide(a, b)
def pow(a, b):
return tf.pow(a, b)
def less(a, b):
return tf.less(a, b)
def greater(a, b):
return tf.greater(a, b)
def less_equal(a, b):
return tf.less_equal(a, b)
def greater_equal(a, b):
return tf.greater_equal(a, b)
def equal(a, b):
return tf.equal(a, b)
def not_equal(a, b):
return tf.not_equal(a, b)
def mul(self, a, b):
return self.lib.mul(a, b)
def sub(self, a, b):
return self.lib.sub(a, b)
def truediv(self, a, b):
return self.lib.truediv(a, b)
......@@ -294,13 +294,11 @@ class TestInference(unittest.TestCase):
self.assertAlmostEqual(res / 1006.799, 1.0, delta=1e-5)
end = time.perf_counter()
print("Time log_prob hom_sinr_1000:", end - start, "seconds")
post, accept ={'epi': epidata}, n_samples=10000,
post, accept ={'epi': epidata}, n_samples=5000,
init=[0.1, 0.5, 0.1, 0.1], burnin=2000,
class TestMetaPopulation(unittest.TestCase):
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment