Commit 9b2ad613 authored by Chris Jewell's avatar Chris Jewell
Browse files

Simplified symbol table slightly, and fixed bug in dag generator caused by...

Simplified symbol table slightly, and fixed bug in dag generator caused by IdRef definition pointing to IdRef nodes rather than *Assign nodes for formal arguments.
parent f31b3365
Pipeline #253 passed with stage
in 4 minutes and 17 seconds
......@@ -25,6 +25,7 @@ if have_graphviz:
from graphviz import Digraph
from gem.gemlang.ast import ASTNode
from gem.gemlang.ast.ast_statement import Statement
from gem.gemlang.ast.ast_expression import IdRef
from gem.gemlang.symbol import RandomVariableSymbol, VariableSymbol, STMSymbol
......@@ -67,14 +68,13 @@ def get_dependencies(ast_node):
:param ast_node: a sub-AST root node
:returns: a set of IdRef AST nodes
"""
if isinstance(ast_node, IdRef):
node_symbol = ast_node.symbol
if isinstance(node_symbol, RandomVariableSymbol):
return {node_symbol}
elif isinstance(node_symbol, STMSymbol):
return get_dependencies(node_symbol.definition.children[2])
elif isinstance(node_symbol, VariableSymbol):
elif isinstance(node_symbol, VariableSymbol) and isinstance(node_symbol.definition, Statement):
return get_dependencies(node_symbol.definition.children[1])
random_vars = set()
......@@ -118,7 +118,6 @@ def generate_dag(symbol_table):
dag_dict = {rv_symbol.name: DAGNode(rv_symbol.name) for rv_symbol in random_var_symbols}
for rv_symbol in random_var_symbols:
ast = rv_symbol.definition
print(ast.children[0])
deps = get_dependencies(ast.children[1])
for dep in deps:
dag_dict[rv_symbol.name].add_parent(dag_dict[dep.name])
......
......@@ -9,7 +9,7 @@ __all__ = ['Symbol', 'GlobalScope', 'LocalScope', 'ExternalDataSymbol',
'RandomVariableSymbol',
'PlaceholderSymbol', 'BuiltInFunctionSymbol',
'ProbabilityDistributionSymbol',
'BuiltinProbabilityDistributionSymbol', 'FArgSymbol', 'STMSymbol', 'Scope',
'BuiltinProbabilityDistributionSymbol', 'STMSymbol', 'Scope',
'ScopedSymbol']
_anon_scope_count = 0
......@@ -188,7 +188,7 @@ class VariableSymbol(Symbol):
:param type: the variable type
"""
def __init__(self, name, definition, type='None'):
def __init__(self, name, definition, type='None', is_parameter=False):
super().__init__(name, definition)
self.type = type
......@@ -222,11 +222,6 @@ class ScopedSymbol(Scope, Symbol):
super().__init__(name, enclosing_scope=enclosing_scope)
class FArgSymbol(Symbol):
def __init__(self, name, definition):
super().__init__(name, definition)
class FunctionSymbol(ScopedSymbol):
"""Represents a function scope.
......@@ -239,13 +234,11 @@ class FunctionSymbol(ScopedSymbol):
assert isinstance(definition, (ASTNode,
str)), f"definition={definition} is of type {type(definition)}"
super().__init__(name, definition, enclosing_scope=enclosing_scope)
self.__params = Scope()
@property
def argcount(self):
"""Returns the number of required arguments"""
symbols = [k for k in self._symbols.keys() if k[0] != '_']
return len(symbols)
return len(self.get_by_type(VariableSymbol))
class PlaceholderSymbol(FunctionSymbol):
......@@ -297,21 +290,17 @@ class STMSymbol(ProbabilityDistributionSymbol):
:param ast_ref: reference to an ASTNode defining the STM.
:param enclosing_scope: the enclosing scope.
"""
def __init__(self, name, ast_ref, enclosing_scope=None):
assert enclosing_scope is not None
super().__init__(str(name), ast_ref, enclosing_scope)
self.config = Scope(name='options', enclosing_scope=self)
self.config.declare(VariableSymbol('class', self.definition))
self.config.declare(VariableSymbol('time_origin', self.definition))
self.config.declare(VariableSymbol('time_step', self.definition))
def _init_builtins(self):
self.declare(make_builtin_symbol(BuiltInFunctionSymbol, 'State', ['init'],
enclosing_scope=self)) # State only allowable in an STM
self.declare(VariableSymbol('class', self.definition))
self.declare(VariableSymbol('time_origin', self.definition))
self.declare(VariableSymbol('time_step', self.definition))
@property
def argcount(self):
return len(self.get_by_type(FArgSymbol))
......
......@@ -80,7 +80,7 @@ class SymbolDeclarer(SemanticAnalysis):
self._pop_scope()
def onExit_FArg(self, argdef):
sym = FArgSymbol(name=argdef.value, definition=argdef)
sym = VariableSymbol(name=argdef.value, definition=argdef, is_parameter=True)
self.current_scope.declare(sym)
def onEnter_Block(self, block):
......@@ -173,13 +173,11 @@ class SymbolResolver(SemanticAnalysis):
try:
sym = self.current_scope.resolve(ast_node.value)
except NameError as e:
print(self.current_scope)
raise NameError(
"{} at line {} column {}".format(e, ast_node.meta.line,
ast_node.meta.column))
ast_node.symbol = sym
RandomVariableList = namedtuple("RandomVariableList", ['free', 'observed'])
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment