Commit c868302e authored by Chris Jewell's avatar Chris Jewell
Browse files

Re-write of the DAG generation algorithm after rewrite of semantic analyser. ...

Re-write of the DAG generation algorithm after rewrite of semantic analyser.  Need to implement for STMs, I think.
parent 70d3de2f
......@@ -35,13 +35,46 @@ gem.gemlang.parse\_gemlang module
:undoc-members:
:show-inheritance:
gem.gemlang.semantics.symbol_table module
-----------------------------------------
.. inheritance-diagram:: gem.gemlang.semantics.symbol_table
:parts: 1
.. automodule:: gem.gemlang.semantics.symbol_table
:members:
:undoc-members:
:show-inheritance:
gem.gemlang.semantics.scope module
----------------------------------
.. inheritance-diagram:: gem.gemlang.semantics.scope
:parts: 1
.. automodule:: gem.gemlang.semantics.scope
:members:
:undoc-members:
:show-inheritance:
gem.gemlang.symbol module
-------------------------
.. inheritance-diagram:: gem.gemlang.symbol
.. inheritance-diagram:: gem.gemlang.semantics.symbol
:parts: 1
.. automodule:: gem.gemlang.semantics.symbol
:members:
:undoc-members:
:show-inheritance:
gem.gemlang.symbol\_declare module
----------------------------------
.. inheritance-diagram:: gem.gemlang.semantics.symbol_declare
:parts: 1
.. automodule:: gem.gemlang.symbol
.. automodule:: gem.gemlang.semantics.symbol_declare
:members:
:undoc-members:
:show-inheritance:
......@@ -49,10 +82,10 @@ gem.gemlang.symbol module
gem.gemlang.symbol\_resolve module
----------------------------------
.. inheritance-diagram:: gem.gemlang.symbol_resolve
.. inheritance-diagram:: gem.gemlang.semantics.symbol_resolve
:parts: 1
.. automodule:: gem.gemlang.symbol_resolve
.. automodule:: gem.gemlang.semantics.symbol_resolve
:members:
:undoc-members:
:show-inheritance:
......
......@@ -22,6 +22,11 @@ import copy
from gem.gemlang.ast.ast_base import ASTNode
class StopWalk(Exception):
"""Raise this exception in an action method to abort a tree walk."""
pass
class ASTWalker:
"""Base class for all AST walking.
......@@ -53,7 +58,7 @@ class ASTWalker:
def __callOnEntry(self, ast_node):
new_node = self.__call('onEnter', ast_node)
if new_node is None:
if new_node is None: # Option if a method has a "void" return type
return ast_node
else:
return new_node
......@@ -77,10 +82,13 @@ class ASTWalker:
if self.__in_place is not True:
ast_node = copy.copy(ast_node)
ast_node = self.__callOnEntry(ast_node)
if isinstance(ast_node,
ASTNode): # Trap to allow re-writing to something other than an AST
ast_node.children = [self._walk(child) for child in
ast_node.children]
ast_node = self.__callOnExit(ast_node)
try:
ast_node = self.__callOnEntry(ast_node)
if isinstance(ast_node,
ASTNode): # Trap to allow re-writing to something other than an AST
ast_node.children = [self._walk(child) for child in
ast_node.children]
ast_node = self.__callOnExit(ast_node)
except StopWalk: # Allows us to abort a walk at the end of an onEnter method if topdown parsing only required
pass
return ast_node
......@@ -24,7 +24,8 @@ have_graphviz = gv_spec is not None
if have_graphviz:
from graphviz import Digraph
from gem.gemlang.ast import ASTNode
from gem.gemlang.ast_walker import ASTWalker, StopWalk
from gem.gemlang.ast.ast_base import ASTNode, NullNode
from gem.gemlang.ast.ast_statement import Statement
from gem.gemlang.ast.ast_expression import IdRef
from gem.gemlang.semantics.symbol_table import SymbolTable
......@@ -59,7 +60,16 @@ class DAGNode(ASTNode):
def add_parent(self, dag_node):
assert isinstance(dag_node, DAGNode), f"'{dag_node}' is not a DAGNode"
self.__parents.append(DAGNode)
self.__parents.append(dag_node)
def __repr__(self):
s = f"<{self.__class__.__name__} {self.name}:"
for parent in self.parents:
s += f" {parent}"
return s+">"
def __str__(self):
return self.name
def get_dependencies(ast_node):
......@@ -77,7 +87,7 @@ def get_dependencies(ast_node):
return get_dependencies(node_symbol.definition.children[2])
elif isinstance(node_symbol, VariableSymbol) and isinstance(node_symbol.definition, Statement):
return get_dependencies(node_symbol.definition.children[1])
random_vars = set()
for child in ast_node.children:
random_vars = random_vars.union(get_dependencies(child))
......@@ -125,3 +135,35 @@ def generate_dag(symbol_table):
dag_dict[dep.name].add_child(dag_dict[rv_symbol.name])
return dag_dict
class DAG(ASTWalker):
def __init__(self, ast):
super().__init__()
self.__dag = {}
self.__current_rv = []
self._walk(ast)
def get_dag(self):
return self.__dag
def onEnter_StochasticAssignExpr(self, node):
rv_decl = node.children[0]
dependent_node = DAGNode(rv_decl.value)
self.__dag[rv_decl.value] = dependent_node
self.__current_rv.append(dependent_node)
self._walk(node.children[1])
self.__current_rv.pop()
raise StopWalk()
def onEnter_IdRef(self, node):
if isinstance(node.symbol, VariableSymbol) and node.symbol.type == SymbolTable._random_variable:
parent = self.__dag[node.value]
child = self.__current_rv[-1]
child.add_parent(parent)
parent.add_child(child)
raise StopWalk()
def dot(self):
return dag2dot(self.__dag)
\ No newline at end of file
......@@ -76,8 +76,7 @@ class SymbolDeclarer(ScopedWalker):
def onExit_StochasticAssignExpr(self, assign):
var = assign.children[0]
varname = str(var.value)
sym = VariableSymbol(varname, assign)
sym.attrib |= SymbolTable.ATTRIB.RANDOM
sym = VariableSymbol(varname, SymbolTable._random_variable)
try:
self.current_scope.declare(sym)
except NameError as e:
......
......@@ -147,7 +147,7 @@ def get_random_vars(symtab):
free_rvs = []
observed_rvs = []
symbols = [sym for sym in symtab.globals.get_by_type(VariableSymbol) if sym.attrib & SymbolTable.ATTRIB.RANDOM]
symbols = [sym for sym in symtab.globals.get_by_type(VariableSymbol) if sym.type == SymbolTable._random_variable]
for symbol in symbols:
ast = symbol.ast_def
assert ast is not None, f"symbol {symbol} is not annotated with an ASTNode"
......
......@@ -101,6 +101,7 @@ class GEM:
self.__create_symbol_table()
SymbolDeclarer().visit(ast, self.__symtab)
SymbolResolver().visit(ast, self.__symtab)
self.ast = ast
self.__pyprog = CodeGenerator().generate(ast, self.__symtab, self.gem_external)
except Exception as e:
print(f"An error occurred during model compilation: {e}")
......
......@@ -19,26 +19,15 @@
import unittest
from gem.gemlang import gemparse
from gem.gemlang.ast.ast_expression import Call
from gem.gemlang.semantics.symbol_resolve import SymbolResolver
from gem.gemlang.semantics.symbol_declare import SymbolDeclarer
from gem.gemlang.dag import generate_dag, get_dependencies, dag2dot
from gem import GEM
from gem.gemlang.dag import dag2dot
@unittest.skip("Work in progress")
class TestSimpleDAGGenerator(unittest.TestCase):
def setUp(self):
self.graph_result = """digraph {
alpha
beta
sigma
y
alpha -> y
beta -> y
sigma -> y
}"""
self.expected_graph = "digraph {\n\talpha\n\tbeta\n\tsigma\n\ty\n\talpha -> y\n\tbeta -> y\n\tsigma -> y\n}"
prog = """
x = Vector()
alpha ~ Normal(0, 1000)
......@@ -46,24 +35,18 @@ class TestSimpleDAGGenerator(unittest.TestCase):
sigma ~ Gamma(1, 1)
y ~ Normal(beta * x + alpha, sigma)
"""
ast = gemparse(prog)
symtab = SymbolDeclarer().visit(ast)
SymbolResolver().visit(ast, symtab)
self.ast = ast
self.symtab = symtab
def test_get_dependencies(self):
y_rhs = self.ast.children[4].children[1]
deps = get_dependencies(y_rhs)
expected = {'alpha','beta','sigma'}
dep_names = {dep.name for dep in deps}
self.assertIsInstance(y_rhs, Call)
self.assertEqual(expected, dep_names)
x = [1.,2.,3.,4.]
self.model = GEM(prog, {'x': x})
def test_dag_generation(self):
dag = generate_dag(self.symtab)
dot = dag2dot(dag)
self.assertEqual(dot.source, self.graph_result)
from gem.gemlang.dag import DAG
DAG(self.model.ast)
def test_dot_generation(self):
from gem.gemlang.dag import DAG
dag = DAG(self.model.ast)
self.assertEqual(dag.dot().source, self.expected_graph)
@unittest.skip("Work in progress")
class TestEpiDAGGenerator(unittest.TestCase):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment