Commit 35f6c6cc authored by Chris Jewell
Browse files

Built rudimentary DAG generator. Need to embellish this!

parent 909bd164
Pipeline #250 failed with stage
in 4 minutes and 11 seconds
......@@ -4,3 +4,4 @@ numpy
tensorflow>=1.14
tensorflow-probability>=0.7.0
lark-parser
graphviz
# Copyright 2019 Chris Jewell <c.jewell@lancaster.ac.uk>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software
# and associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS
# OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
# IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
# Import the submodule explicitly: a bare ``import importlib`` does not
# guarantee that ``importlib.util`` has been loaded.
import importlib.util

# graphviz is an optional dependency, so probe for it without importing it.
gv_spec = importlib.util.find_spec("graphviz")
have_graphviz = gv_spec is not None
if have_graphviz:
    from graphviz import Digraph
from gem.gemlang.ast import ASTNode
from gem.gemlang.ast.ast_expression import IdRef
from gem.gemlang.symbol import RandomVariableSymbol, VariableSymbol, STMSymbol
class DAGNode(ASTNode):
    """Represents a random variable in a probabilistic DAG.

    Parent links are stored on this class; child links are presumably
    provided by the :class:`ASTNode` base class (``children`` /
    ``add_child``) -- confirm against ``gem.gemlang.ast``.

    :param name: the name of the node
    :param is_observed: is the random variable observed?
    :type is_observed: bool
    :returns: an object of type DAGNode
    """

    def __init__(self, name, is_observed=False):
        super().__init__()
        self.__name = name
        self.__observed = is_observed
        self.__parents = []

    @property
    def name(self):
        """The name of the node, as passed to the constructor."""
        return self.__name

    @property
    def is_observed(self):
        """Whether the node was flagged as observed at construction."""
        return self.__observed

    @property
    def parents(self):
        """The list of parent nodes registered via :meth:`add_parent`."""
        return self.__parents

    def add_parent(self, dag_node):
        """Register ``dag_node`` as a parent of this node.

        :param dag_node: the parent node
        :type dag_node: DAGNode
        :raises AssertionError: if ``dag_node`` is not a DAGNode
        """
        assert isinstance(dag_node, DAGNode), f"'{dag_node}' is not a DAGNode"
        # Store the instance that was passed in, not the DAGNode class itself.
        self.__parents.append(dag_node)
def get_dependencies(ast_node):
    """Collect the random variables that a sub-AST depends on.

    The sub-AST is walked depth-first. An ``IdRef`` that resolves to a
    ``RandomVariableSymbol`` ends the walk on that branch; one that resolves
    to an ``STMSymbol`` or a ``VariableSymbol`` is followed into a child of
    the symbol's definition instead.

    :param ast_node: a sub-AST root node
    :returns: a set of symbols of type RandomVariableSymbol
    """
    if isinstance(ast_node, IdRef):
        symbol = ast_node.symbol
        if isinstance(symbol, RandomVariableSymbol):
            return {symbol}
        if isinstance(symbol, STMSymbol):
            # Child 2 of the definition is followed here; presumably the
            # body of the model -- confirm against the grammar.
            return get_dependencies(symbol.definition.children[2])
        if isinstance(symbol, VariableSymbol):
            # Child 1 of the definition is followed here; presumably the
            # right-hand side of the assignment -- confirm against the grammar.
            return get_dependencies(symbol.definition.children[1])

    # Any other node (or an IdRef with some other kind of symbol):
    # merge whatever the children depend on.
    return set().union(*(get_dependencies(child) for child in ast_node.children))
def dag2dot(dag):
    """Render a DAG as a graphviz ``Digraph``.

    :param dag: a dictionary whose values are nodes exposing ``name`` and
                ``children`` (as returned by :func:`generate_dag`)
    :returns: a ``graphviz.Digraph`` with one node per DAG node and one edge
              per parent -> child link, or ``None`` (after emitting a
              warning) if graphviz is not installed
    """
    if not have_graphviz:
        # Merely constructing ``Warning(...)`` emits nothing; actually
        # issue the warning.
        import warnings

        warnings.warn("Install graphviz to generate DAG visualisations")
        return None

    dot = Digraph()
    # Declare every node first so isolated nodes still appear in the graph
    # and node declarations precede the edges in the generated source.
    for node in dag.values():
        dot.node(node.name)
    for node in dag.values():
        for child in node.children:
            dot.edge(node.name, child.name)
    return dot
# Algorithm
# =========
# 1. Collect all RandomVariableSymbol symbols from the symbol table
# 2. For each RandomVariableSymbol:
#    a. Visit its declaration AST
#    b. Recursively follow the RHS, stopping at IdRef nodes
#       i.   If the IdRef is a RandomVariableSymbol, return the symbol
#       ii.  If the IdRef is an STMSymbol, follow its definition and go to (b)
#       iii. If the IdRef is a VariableSymbol, follow its definition and go to (b)
def generate_dag(symbol_table):
    """Generates a DAG representation of a probability model embedded in GEM code.

    One :class:`DAGNode` is created per random variable symbol. For every
    random variable, the second child of its definition is searched for the
    random variables it depends on, and parent/child links are added in both
    directions.

    :param symbol_table: a symbol table generated by :class:`ParseDeclarations <gem.gemlang.parse_gemlang.ParseDeclarations>`
    :returns: a dictionary of nodes of type DAGNode, keyed by variable name
    """
    # Materialise the result: it is iterated twice below, which would
    # silently yield nothing the second time if get_by_type returns an iterator.
    random_var_symbols = list(symbol_table.get_by_type(RandomVariableSymbol))
    dag_dict = {rv_symbol.name: DAGNode(rv_symbol.name) for rv_symbol in random_var_symbols}

    for rv_symbol in random_var_symbols:
        node = dag_dict[rv_symbol.name]
        for dep in get_dependencies(rv_symbol.definition.children[1]):
            parent = dag_dict[dep.name]
            node.add_parent(parent)
            parent.add_child(node)
    return dag_dict
......@@ -38,7 +38,8 @@ class DataAttacher(ASTWalker):
def onExit_AssignExpr(self, assign):
lhs = assign.children[0]
rhs = assign.children[1]
if isinstance(rhs.symbol, PlaceholderSymbol):
rhs_symb = rhs.children[0].symbol if rhs.has_children() else None
if isinstance(rhs_symb, PlaceholderSymbol):
# ensure lhs is represented in the data
if not lhs.value in self.__data.keys():
raise KeyError(
......
......@@ -131,7 +131,6 @@ class SymbolResolver(SemanticAnalysis):
raise NameError(
f"Undefined function '{f_idref.value}' at line {call.meta.line} column {call.meta.column}.")
f_idref.symbol = fsym
call.symbol = fsym
if len(arglist.children) != fsym.argcount:
raise TypeError(
......
......@@ -7,3 +7,4 @@ tensorflow-probability>=0.7.0
tabulate
matplotlib
lark-parser
graphviz
......@@ -72,6 +72,7 @@ class TestLogProb(unittest.TestCase):
1.5820221, 1.58739755, 1.86379043, 1.66634023, 1.74729067],
dtype=np.float32)
model = GEM(prog, const_data={'x': x})
print(model.pyprog)
with tf.Session() as s:
res = s.run(
model.log_prob_fn(alpha=1.4, beta=0.5, sigma=0.01, y=y))
......
# Copyright 2019 Chris Jewell <c.jewell@lancaster.ac.uk>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software
# and associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS
# OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
# IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
import unittest
from gem.gemlang import gemparse
from gem.gemlang.ast.ast_expression import Call
from gem.gemlang.symbol import RandomVariableSymbol
from gem.gemlang.symbol_resolve import SymbolDeclarer, SymbolResolver
from gem.gemlang.dag import generate_dag, get_dependencies, dag2dot
class TestSimpleDAGGenerator(unittest.TestCase):
    """DAG generation for a simple linear regression model."""

    def setUp(self):
        # Expected dot source, written with explicit escapes so the
        # tab indentation of the graph body is visible and cannot be lost to
        # whitespace normalisation (same format as TestEpiDAGGenerator).
        self.graph_result = (
            "digraph {\n"
            "\talpha\n"
            "\tbeta\n"
            "\tsigma\n"
            "\ty\n"
            "\talpha -> y\n"
            "\tbeta -> y\n"
            "\tsigma -> y\n"
            "}"
        )
        prog = (
            "\n"
            "x = Vector()\n"
            "alpha ~ Normal(0, 1000)\n"
            "beta ~ Normal(0, 1000)\n"
            "sigma ~ Gamma(1, 1)\n"
            "y ~ Normal(beta * x + alpha, sigma)\n"
        )
        ast = gemparse(prog)
        symtab = SymbolDeclarer().declare(ast)
        SymbolResolver().resolve(ast, symtab)
        self.ast = ast
        self.symtab = symtab

    def test_get_dependencies(self):
        """The RHS of ``y`` depends on exactly alpha, beta and sigma."""
        # Fifth statement of the program is ``y ~ ...``; child 1 is its RHS.
        y_rhs = self.ast.children[4].children[1]
        # Check the node type first so a wrong index fails with a clear message.
        self.assertIsInstance(y_rhs, Call)
        deps = get_dependencies(y_rhs)
        expected = {'alpha', 'beta', 'sigma'}
        dep_names = {dep.name for dep in deps}
        self.assertEqual(expected, dep_names)

    def test_dag_generation(self):
        """The generated DAG renders to the expected dot source."""
        dag = generate_dag(self.symtab)
        dot = dag2dot(dag)
        if dot is None:
            # dag2dot returns None when graphviz is unavailable.
            self.skipTest("graphviz is not installed")
        self.assertEqual(dot.source, self.graph_result)
class TestEpiDAGGenerator(unittest.TestCase):
    """DAG generation for a model containing an Epidemic (SIR) block."""

    def setUp(self):
        self.expected_graph = "digraph {\n\tbeta\n\tgamma\n\tepi\n\tbeta -> epi\n\tgamma -> epi\n}"
        prog = (
            "\n"
            "beta ~ Gamma(1, 1)\n"
            "gamma ~ Gamma(1, 1)\n"
            "Epidemic SIR() {\n"
            "S = State(init=999)\n"
            "I = State(init=1)\n"
            "R = State(init=0)\n"
            "[S -> I] = beta * I\n"
            "[I -> R] = gamma\n"
            "}\n"
            "epi ~ SIR()\n"
        )
        ast = gemparse(prog)
        symtab = SymbolDeclarer().declare(ast)
        SymbolResolver().resolve(ast, symtab)
        self.ast = ast
        self.symtab = symtab

    def test_dag_generator(self):
        """``epi`` should depend on beta and gamma via the SIR definition."""
        dag = generate_dag(self.symtab)
        dot = dag2dot(dag)
        if dot is None:
            # dag2dot returns None when graphviz is unavailable; skip rather
            # than fail with an AttributeError on ``None.source``.
            self.skipTest("graphviz is not installed")
        self.assertEqual(dot.source, self.expected_graph)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment