Commit 909bd164 authored by Chris Jewell's avatar Chris Jewell
Browse files

Big interface change in semantic analysis. See gem.interface.py.

Still not a great interface to symbol declaration and resolution phases.  Answers on a postcard...
parent bb33c5b7
Pipeline #249 passed with stage
in 4 minutes and 10 seconds
......@@ -37,8 +37,12 @@ class ASTWalker:
4. Whatever an `OnEntry_*` or `OnExit_*` method returns, it replaces the node in the children list of the
parent node in the AST.
:param in=place: Should the AST be modified in place? (default False)
"""
def __init__(self, in_place=False):
self.__in_place = in_place
def __call(self, when, ast_node):
attr = "{}_{}".format(when, ast_node.__class__.__name__)
action = getattr(self, attr, self.__default__)
......@@ -61,17 +65,19 @@ class ASTWalker:
def __default__(self, ast_node):
return ast_node
def walk(self, ast_node):
def _walk(self, ast_node):
"""Walks an AST, dispatching methods as appropriate.
:param ast_node: an ASTNode instance.
:return an object according to the return values of the OnExit_* and OnEntry_* methods.
"""
ast_node = copy.copy(ast_node)
if self.__in_place is not True:
ast_node = copy.copy(ast_node)
ast_node = self.__callOnEntry(ast_node)
if isinstance(ast_node,
ASTNode): # Trap to allow re-writing to something other than an AST
ast_node.children = [self.walk(child) for child in
ast_node.children = [self._walk(child) for child in
ast_node.children]
ast_node = self.__callOnExit(ast_node)
return ast_node
......@@ -6,7 +6,7 @@ from gem.gemlang.symbol import PlaceholderSymbol
from gem.gemlang.tf_output import convert_to_maths_layer
class InjectData(ASTWalker):
class DataAttacher(ASTWalker):
"""InjectData does three things:
1. It converts raw data to tf.Tensor
2. It looks up data placeholders in `data`, removing declaration nodes from AST
......@@ -17,9 +17,14 @@ class InjectData(ASTWalker):
:return a tuple of (ASTNode, SymbolTable)
"""
def __init__(self, data):
super().__init__()
def __init__(self):
super().__init__(in_place=True)
self.__data = None
def attach(self, ast, data):
self.__data = convert_to_maths_layer(data)
self._walk(ast)
self.__data = None
def onExit_GEMProgram(self, ast_node):
# N.B. global scope is pushed by __init__
......
......@@ -39,11 +39,11 @@ class CodeGenerator(ASTWalker):
"""
# pylint: disable=invalid-name,missing-docstring
def __init__(self, symtab):
def __init__(self):
super().__init__()
self.__workspace_stack = []
self.vars = [symb.name for symb in symtab.get_by_type(VariableSymbol)]
self.random_vars = get_random_vars(symtab)
self.vars = []
self.random_vars = []
@property
def current_workspace(self):
......@@ -56,6 +56,11 @@ class CodeGenerator(ASTWalker):
def pop_ws(self):
return self.__workspace_stack.pop()
def generate(self, ast, symbol_table):
self.vars = [symb.name for symb in symbol_table.get_by_type(VariableSymbol)]
self.random_vars = get_random_vars(symbol_table)
return self._walk(ast)
def onEnter_GEMProgram(self, output_obj):
self.push_ws()
......@@ -82,8 +87,8 @@ class CodeGenerator(ASTWalker):
def onEnter_Call(self, call):
fsymb = call.children[0].symbol
fname = self.walk(call.children[0])
args = self.walk(call.children[1])
fname = self._walk(call.children[0])
args = self._walk(call.children[1])
if isinstance(fsymb, BuiltInFunctionSymbol):
return out.builtin_function[str(fname)](args)
return out.Function(str(fname), args)
......@@ -149,8 +154,8 @@ class CodeGenerator(ASTWalker):
lhs = ast_node.children[0]
rhs = ast_node.children[1]
varname = lhs.value
fname = self.walk(rhs.children[0])
args = self.walk(rhs.children[1])
fname = self._walk(rhs.children[0])
args = self._walk(rhs.children[1])
observed = getattr(rhs, 'observed', None)
if str(
fname) in out.builtin_function.keys(): # TODO find a better solution to translating GEM distros to TF
......
......@@ -50,9 +50,9 @@ class GEMParser(Transformer):
def epidecl(self, args, meta):
if isinstance(args[0], ArgList):
return EpiDecl(args[1], args[0], args[2], args[3], meta=meta)
return EpiDecl(str(args[1]), args[0], args[2], args[3], meta=meta)
else:
return EpiDecl(args[0], ArgList([]), args[1], args[2], meta=meta)
return EpiDecl(str(args[0]), ArgList([]), args[1], args[2], meta=meta)
def transitiondef(self, args, meta):
return TransitionDef(args[0], args[1], meta=meta)
......
......@@ -309,7 +309,7 @@ class STMSymbol(ProbabilityDistributionSymbol):
def __init__(self, name, ast_ref, argnames, enclosing_scope=None):
assert enclosing_scope is not None
super().__init__(name, ast_ref, argnames, enclosing_scope)
super().__init__(str(name), ast_ref, argnames, enclosing_scope)
def _init_builtins(self):
state = BuiltInFunctionSymbol('State', ['init'],
......
......@@ -9,7 +9,7 @@ from gem.gemlang.symbol import *
class SemanticAnalysis(ASTWalker):
def __init__(self):
super().__init__()
super().__init__(in_place=True)
self._scope_stack = []
@property
......@@ -25,21 +25,28 @@ class SemanticAnalysis(ASTWalker):
return self._scope_stack.pop()
class ParseDeclarations(SemanticAnalysis):
class SymbolDeclarer(SemanticAnalysis):
def declare(self, ast_node):
"""Walks the AST and constructs a symbol table of declared symbols.
Each symbol is annotated with a pointer to its respective declaration
in the AST.
def __init__(self):
super().__init__()
:param ast_node: an AST root node
:returns: a populated SymbolTable corresponding to ast_node
"""
self._walk(ast_node)
return self._pop_scope()
def onEnter_GEMProgram(self, ast_node):
self._push_scope(GlobalScope(ast_node))
def onExit_GEMProgram(self, ast_node):
return ast_node, self._pop_scope()
return ast_node
def onExit_AssignExpr(self, assign):
varname = str(assign.children[0].value)
expr = assign.children[1]
sym = VariableSymbol(name=varname, definition=assign.children[0])
sym = VariableSymbol(name=varname, definition=assign)
# TODO Set type here
try:
self.current_scope.declare(sym)
......@@ -53,7 +60,7 @@ class ParseDeclarations(SemanticAnalysis):
raise NotImplementedError(
"Nested random variable declarations are currently not supported.")
varname = str(assign.children[0].value)
sym = RandomVariableSymbol(varname, assign.children[0])
sym = RandomVariableSymbol(varname, assign)
try:
self.current_scope.declare(sym)
except NameError as e:
......@@ -62,9 +69,10 @@ class ParseDeclarations(SemanticAnalysis):
assign.children[0].symbol = sym
def onEnter_EpiDecl(self, epidecl):
sym = STMSymbol(epidecl.name, epidecl.children[0],
[arg.children[0].value for arg in
epidecl.children[1].children], self.current_scope)
sym = STMSymbol(name=epidecl.name,
ast_ref=epidecl,
argnames=[arg.children[0].value for arg in epidecl.children[1].children],
enclosing_scope=self.current_scope)
self.current_scope.declare(sym)
self._push_scope(sym)
epidecl.children[0].symbol = sym
......@@ -82,20 +90,29 @@ class ParseDeclarations(SemanticAnalysis):
self._pop_scope()
class ResolveSymbols(SemanticAnalysis):
def __init__(self, symtable):
super().__init__()
self._push_scope(symtable)
class SymbolResolver(SemanticAnalysis):
def resolve(self, ast_node, symbol_table):
"""Resolves symbols in ast_node against symbol_table.
ast_node is modified in-place, each reference to a symbol being
annotated with a reference to the symbol in the symbol table.
def onExit_GEMProgram(self, ast_node):
# N.B. global scope is pushed by __init__
return ast_node, self._pop_scope()
:param ast_node: an AST root node
:param symbol_table: a SymbolTable to resolve ast_node against
:returns: Nothing. ast_node modified in place.
"""
self._push_scope(symbol_table)
self._walk(ast_node)
def onEnter_EpiDecl(self, ast_node):
self._push_scope(ast_node.children[0].symbol)
def onExit_EpiDecl(self, ast_node):
self._pop_scope()
try:
sym = self.current_scope.resolve(ast_node.name)
except NameError as e:
raise NameError(f"{e} at line {ast_node.meta.line} column {ast_node.meta.column}")
ast_node.symbol = sym
def onEnter_Block(self, ast_node):
self._push_scope(ast_node.symbol)
......@@ -185,3 +202,4 @@ def get_random_vars(symtab):
else:
free_rvs.append(symbol.name)
return tuple(free_rvs)
......@@ -28,9 +28,10 @@ from tensorflow_probability import edward2 as ed
from gem.gemlang import gemparse
from gem.gemlang.model_generator import CodeGenerator
from gem.gemlang.symbol_resolve import ParseDeclarations, ResolveSymbols, \
from gem.gemlang.symbol import GlobalScope
from gem.gemlang.symbol_resolve import SymbolDeclarer, SymbolResolver, \
get_random_vars
from gem.gemlang.inject_data import InjectData
from gem.gemlang.inject_data import DataAttacher
from gem.model.edward2_extn import make_value_setter, TransformedRVBijector
......@@ -92,10 +93,10 @@ class GEM:
def __parse__(self, gemprog):
try:
ast = gemparse(gemprog)
ast, self.__symtab = ParseDeclarations().walk(ast)
ast, self.__symtab = ResolveSymbols(self.__symtab).walk(ast)
ast = InjectData(self.gem_external).walk(ast)
self.__pyprog = CodeGenerator(self.__symtab).walk(ast)
self.__symtab = SymbolDeclarer().declare(ast)
SymbolResolver().resolve(ast, self.__symtab)
DataAttacher().attach(ast, self.gem_external)
self.__pyprog = CodeGenerator().generate(ast, self.__symtab)
except Exception as e:
print(f"An error occurred during model compilation: {e}")
raise e
......
......@@ -72,6 +72,8 @@ class TestASTVisitor(unittest.TestCase):
def setUp(self):
class Walker(ASTWalker):
def walk(self, node):
return self._walk(node)
def onEnter_AddExpr(self, node):
pass
......@@ -120,6 +122,6 @@ class TestASTWalkerTransform(unittest.TestCase):
addnode = AddExpr(x, y)
subnode = SubExpr(z, addnode)
subnode = self.visitor.walk(subnode)
subnode = self.visitor._walk(subnode)
self.assertEqual('(SubExpr (Number 2.3) (Number 4.5))',
serialize(subnode))
......@@ -23,7 +23,7 @@ import unittest
from gem.gemlang import serialize
from gem.gemlang.parse_gemlang import gemparse
from gem.gemlang.ast.ast_base import GEMProgram
from gem.gemlang.symbol_resolve import ParseDeclarations, ResolveSymbols
from gem.gemlang.symbol_resolve import SymbolDeclarer, SymbolResolver
class TestInterpreter(unittest.TestCase):
......@@ -49,8 +49,8 @@ class TestInterpreter(unittest.TestCase):
ast = gemparse(model_str)
self.assertIsInstance(ast, GEMProgram)
declarer = ParseDeclarations()
ast, symtable = declarer.walk(ast)
declarer = SymbolDeclarer()
symtable = declarer.declare(ast)
resolver = ResolveSymbols(symtable)
ode1, symtable1 = resolver.walk(ast)
resolver = SymbolResolver()
resolver.resolve(ast, symtable)
......@@ -20,7 +20,7 @@
import unittest
from gem.gemlang import gemparse
from gem.gemlang.symbol_resolve import ParseDeclarations, ResolveSymbols
from gem.gemlang.symbol_resolve import SymbolDeclarer, SymbolResolver
from gem.gemlang.symbol import *
......@@ -30,26 +30,26 @@ class TestParseDeclarations(unittest.TestCase):
pi = 3.14157
"""
ast = gemparse(prog)
ast, symtab = ParseDeclarations().walk(ast)
symtab = SymbolDeclarer().declare(ast)
assign = ast.children[0]
pi_ref = assign.children[0]
self.assertTrue(symtab.resolve('pi'))
self.assertIsInstance(symtab.resolve('pi'), VariableSymbol)
self.assertIs(pi_ref.symbol, symtab.resolve('pi'))
self.assertIs(symtab.resolve('pi').definition, pi_ref)
self.assertIs(symtab.resolve('pi').definition, assign)
def test_decl_stochastic(self):
prog = """
rv ~ Gamma(0.1, 0.1)
"""
ast = gemparse(prog)
ast, symtab = ParseDeclarations().walk(ast)
symtab = SymbolDeclarer().declare(ast)
assign = ast.children[0]
rv_ref = assign.children[0]
self.assertTrue(symtab.resolve('rv'))
self.assertIsInstance(symtab.resolve('rv'), RandomVariableSymbol)
self.assertIs(rv_ref.symbol, symtab.resolve('rv'))
self.assertIs(symtab.resolve('rv').definition, rv_ref)
self.assertIs(symtab.resolve('rv').definition, assign)
class TestResolveSymbols(unittest.TestCase):
......@@ -59,8 +59,8 @@ class TestResolveSymbols(unittest.TestCase):
pi * 2
"""
ast = gemparse(prog)
ast, symtab = ParseDeclarations().walk(ast)
ast, symtab = ResolveSymbols(symtab).walk(ast)
symtab = SymbolDeclarer().declare(ast)
SymbolResolver().resolve(ast, symtab)
mul_expr = ast.children[1]
pi_ref = mul_expr.children[0]
self.assertIs(symtab.resolve('pi'), pi_ref.symbol)
......@@ -71,8 +71,10 @@ class TestResolveSymbols(unittest.TestCase):
gamma * 2
"""
ast = gemparse(prog)
ast, symtab = ParseDeclarations().walk(ast)
ast, symtab = ResolveSymbols(symtab).walk(ast)
declarer2 = SymbolDeclarer()
declarer = SymbolDeclarer()
symtab = declarer.declare(ast)
SymbolResolver().resolve(ast, symtab)
mul_expr = ast.children[1]
gamma_ref = mul_expr.children[0]
self.assertIs(symtab.resolve('gamma'), gamma_ref.symbol)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment