Commit 6800fea5 authored by Chris Jewell's avatar Chris Jewell
Browse files

Small tweaks to ASTNode, adding fields to record a node's current scope, and...

Small tweaks to ASTNode, adding fields to record a node's current scope, and evalType to record what type an expression returns.

Deleted now redundant inject_data.py file.
parent 8ec15bb5
Pipeline #255 failed with stage
in 4 minutes and 20 seconds
......@@ -29,12 +29,12 @@ class ASTNode:
:param meta: line/col metainformation
"""
__children__ = []
def __init__(self, meta=None):
self.__children__ = []
self.__meta__ = meta
self.symbol = None
self.scope = None
self.evalType = None
@property
def children(self):
......@@ -75,7 +75,6 @@ class NullNode(ASTNode):
:param value: a string used to describe the NullNode.
"""
def __init__(self, value, *args, **kwargs):
super().__init__(*args, **kwargs)
self.value = value
......@@ -92,7 +91,6 @@ class GEMProgram(ASTNode):
:param statements: list of objects of type Statement
"""
def __init__(self, statements, *args, **kwargs):
super().__init__(*args, **kwargs)
for stmt in statements:
......@@ -105,7 +103,6 @@ class Block(GEMProgram):
:param statements: list of objects of type Statement
"""
def __init__(self, statements, **kwargs):
statements = [s for s in statements if isinstance(s, ASTNode)]
super().__init__(statements, **kwargs)
......@@ -114,7 +111,6 @@ class Block(GEMProgram):
class Statement(ASTNode):
"""Statement base class. This is not meant to be used as part of the AST API.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
......@@ -150,7 +146,6 @@ class KwRef(ASTNode):
:param name: keyword name
"""
def __init__(self, name, **kwargs):
super().__init__(**kwargs)
self.value = name
......
......@@ -27,7 +27,8 @@ if have_graphviz:
from gem.gemlang.ast import ASTNode
from gem.gemlang.ast.ast_statement import Statement
from gem.gemlang.ast.ast_expression import IdRef
from gem.gemlang.semantics.symbol import RandomVariableSymbol, VariableSymbol, STMSymbol
from gem.gemlang.semantics.symbol_table import SymbolTable
from gem.gemlang.semantics.symbol import VariableSymbol, STMSymbol
class DAGNode(ASTNode):
......@@ -70,8 +71,8 @@ def get_dependencies(ast_node):
"""
if isinstance(ast_node, IdRef):
node_symbol = ast_node.symbol
if isinstance(node_symbol, RandomVariableSymbol):
return {node_symbol}
if ast_node.evalType == SymbolTable._random_variable:
return {ast_node.value}
elif isinstance(node_symbol, STMSymbol):
return get_dependencies(node_symbol.definition.children[2])
elif isinstance(node_symbol, VariableSymbol) and isinstance(node_symbol.definition, Statement):
......@@ -114,7 +115,7 @@ def generate_dag(symbol_table):
:param symbol_table: a symbol table generated by :class:`ParseDeclarations <gem.gemlang.parse_gemlang.ParseDeclarations>`
:returns: a dictionary of nodes of type DAGNode
"""
random_var_symbols = symbol_table.get_by_type(RandomVariableSymbol)
random_var_symbols = symbol_table.get_by_type(VariableSymbol)
dag_dict = {rv_symbol.name: DAGNode(rv_symbol.name) for rv_symbol in random_var_symbols}
for rv_symbol in random_var_symbols:
ast = rv_symbol.definition
......
from gem.gemlang import ASTWalker
from gem.gemlang.ast.ast_base import NullNode
from gem.gemlang.ast.ast_expression import IdRef
from gem.gemlang.ast.ast_statement import AssignExpr
from gem.gemlang.tf_output import convert_to_maths_layer
from gem.gemlang.semantics.symbol import *
class DataAttacher(ASTWalker):
"""InjectData does three things:
1. It converts raw data to tf.Tensor
2. It looks up data placeholders in `data`, removing declaration nodes from AST
3. It looks up random variables in the `data` dictionary, adding them as a 'value' field if observations exist
:param symtab: a SymbolTable instance
:param data: a dictionary of data to inject into the model
:return a tuple of (ASTNode, SymbolTable)
"""
def __init__(self):
super().__init__(in_place=True)
self.__data = None
def attach(self, ast, data):
self.__data = convert_to_maths_layer(data)
self._walk(ast)
self.__data = None
def onExit_GEMProgram(self, ast_node):
# N.B. global scope is pushed by __init__
children = ast_node.children
for k, v in self.__data.items():
children.insert(0, AssignExpr(IdRef(k),
IdRef(f"gem_external['{k}']")))
ast_node.children = children
def onExit_AssignExpr(self, assign):
lhs = assign.children[0]
rhs = assign.children[1]
rhs_symb = rhs.children[0].symbol if rhs.has_children() else None
if isinstance(rhs_symb, ExternalDataSymbol):
# ensure lhs is represented in the data
if not lhs.value in self.__data.keys():
raise KeyError(
f"Data for external data reference '{lhs.value}' not supplied at line {lhs.meta.line} column {lhs.meta.column}")
# TODO ensure types match
return NullNode("Placeholder removed")
def onExit_StochasticAssignExpr(self, assign):
lhs = assign.children[0]
rhs = assign.children[1]
if lhs.value in self.__data.keys():
raise SyntaxError(
f"Redefinition of symbol '{lhs.value}' as constant data")
......@@ -69,8 +69,9 @@ class SymbolDeclarer(ScopedWalker):
sym = VariableSymbol(name=varname)
self.current_scope.declare(sym)
sym.ast_def = assign
sym.ast_def = var
var.symbol = sym
var.scope = self.current_scope
def onExit_StochasticAssignExpr(self, assign):
var = assign.children[0]
......@@ -82,8 +83,9 @@ class SymbolDeclarer(ScopedWalker):
except NameError as e:
raise SyntaxError(
f"Duplicate declaration of '{varname}' at line {assign.meta.line} column {assign.meta.column}")
sym.ast_def = assign
sym.ast_def = var
var.symbol = sym
var.scope = self.current_scope
def onEnter_EpiDecl(self, epidecl):
sym = STMSymbol(name=epidecl.name,
......@@ -92,7 +94,7 @@ class SymbolDeclarer(ScopedWalker):
self._push_scope(sym)
sym.ast_def = epidecl
epidecl.symbol = sym
epidecl.children[0].symbol = sym
epidecl.scope = self.current_scope
def onExit_EpiDecl(self, epidecl):
self._pop_scope()
......@@ -100,14 +102,16 @@ class SymbolDeclarer(ScopedWalker):
def onExit_FArg(self, argdef):
sym = VariableSymbol(name=argdef.value)
sym.ast_def = argdef
argdef.symbol = sym
self.current_scope.declare(sym)
argdef.symbol = sym
argdef.scope = self.current_scope
def onEnter_Block(self, block):
scope = LocalScope(enclosing_scope=self.current_scope)
scope.ast_def = block
block.symbol = scope
block.scope = self.current_scope
self._push_scope(scope)
def onExit_Block(self, block):
self._pop_scope()
\ No newline at end of file
self._pop_scope()
......@@ -41,7 +41,7 @@ class SymbolResolver(ScopedWalker):
self._walk(ast_node)
def onEnter_EpiDecl(self, ast_node):
self._push_scope(ast_node.children[0].symbol)
self._push_scope(ast_node.symbol)
def onExit_EpiDecl(self, ast_node):
self._pop_scope()
......@@ -57,18 +57,18 @@ class SymbolResolver(ScopedWalker):
def onExit_Block(self, ast_node):
self._pop_scope()
def onEnter_Call(self, call):
def onExit_Call(self, call):
# 1. Keyword argument names have to be resolved in the function call's scope
# 2. Keyword argument values have to be resolved in the current enclosing scope
f_idref = call.children[0]
arglist = call.children[1]
try:
fsym = self.current_scope.resolve(f_idref.value)
fsym = f_idref.scope.resolve(f_idref.value)
except NameError as err:
raise NameError(
f"Undefined function '{f_idref.value}' at line {call.meta.line} column {call.meta.column}.")
f_idref.symbol = fsym
call.symbol = fsym
call.evalType = fsym.type
if len(arglist.children) != fsym.argcount:
raise SyntaxError(
......@@ -85,19 +85,15 @@ class SymbolResolver(ScopedWalker):
f"Unexpected keyword argument '{argname}' at line {call.meta.line} column {call.meta.line}")
arg.children[0].symbol = kwsym
def onExit_Call(self, call):
call.type = call.symbol.type
pass
def onExit_AssignExpr(self, assign):
lhs = assign.children[0]
rhs = assign.children[1]
if isinstance(rhs, Call) and (rhs.symbol.attrib & SymbolTable.ATTRIB.RANDOM):
if rhs.evalType == SymbolTable._random_variable:
raise ValueError(
f"Probability distribution cannot be assigned in a '=' statement at line {assign.meta.line}"
)
lhs.type = rhs.type
lhs.type = rhs.evalType
def onExit_StochasticAssignExpr(self, assign):
lhs = assign.children[0]
......@@ -120,19 +116,20 @@ class SymbolResolver(ScopedWalker):
"{} at line {} column {}".format(e, ast_node.meta.line,
ast_node.meta.column))
ast_node.symbol = sym
ast_node.type = sym.type
ast_node.scope = self.current_scope
ast_node.evalType = sym.type
def onExit_Integer(self, integer):
integer.type = self.current_scope.resolve('int')
integer.evalType = self.current_scope.resolve('int')
def onExit_Float(self, node):
node.type = self.current_scope.resolve('float')
node.evalType = self.current_scope.resolve('float')
def onExit_BinaryExpr(self, node):
node.type = node.children[0].type
node.type = node.children[0].evalType
def onExit_UnaryExpr(self, node):
node.type = node.children[0].type
node.type = node.children[0].evalType
RandomVariableList = namedtuple("RandomVariableList", ['free', 'observed'])
......
......@@ -67,3 +67,6 @@ class SymbolTable:
sym = VariableSymbol(name=k)
self.globals.declare(sym)
sym.attrib |= SymbolTable.ATTRIB.EXTERN
def __str__(self):
return str(self.globals)
\ No newline at end of file
......@@ -26,6 +26,7 @@ from gem.gemlang.semantics.symbol_declare import SymbolDeclarer
from gem.gemlang.dag import generate_dag, get_dependencies, dag2dot
@unittest.skip("Work in progress")
class TestSimpleDAGGenerator(unittest.TestCase):
def setUp(self):
......
......@@ -39,7 +39,7 @@ class TestParseDeclarations(unittest.TestCase):
self.assertTrue(symtab.globals.resolve('pi'))
self.assertIsInstance(symtab.globals.resolve('pi'), VariableSymbol)
self.assertIs(pi_ref.symbol, symtab.globals.resolve('pi'))
self.assertIs(symtab.globals.resolve('pi').ast_def, assign)
self.assertIs(symtab.globals.resolve('pi').ast_def, pi_ref)
def test_decl_stochastic(self):
prog = """
......@@ -53,7 +53,7 @@ class TestParseDeclarations(unittest.TestCase):
self.assertTrue(symtab.globals.resolve('rv'))
self.assertIsInstance(symtab.globals.resolve('rv'), VariableSymbol)
self.assertIs(rv_ref.symbol, symtab.globals.resolve('rv'))
self.assertIs(symtab.globals.resolve('rv').ast_def, assign)
self.assertIs(symtab.globals.resolve('rv').ast_def, rv_ref)
class TestResolveSymbols(unittest.TestCase):
......
......@@ -19,6 +19,7 @@
import unittest
from gem.gemlang.semantics.symbol_table import SymbolTable
from gem.gemlang.semantics.symbol import *
from gem.gemlang.semantics.scope import GlobalScope, LocalScope
......@@ -56,3 +57,29 @@ class TestScopedSymbol(unittest.TestCase):
# Tests ability to resolve up the scope
local = LocalScope(enclosing_scope=m)
self.assertIsInstance(local.resolve('bar'), VariableSymbol)
class TestSymbolTable(unittest.TestCase):
def test_symbol_table(self):
symtab = SymbolTable()
print(symtab)
def test_symbol_sir(self):
from gem import GEM
prog = """
beta ~ Gamma(1.0, 1.0)
gamma ~ Gamma(1.0, 1.0)
Epidemic SIR() {
S = State(init=999)
I = State(init=1)
R = State(init=0)
[S -> I] = beta * I
[I -> R] = gamma
}
epi ~ SIR()
"""
model = GEM(prog)
print(model.symtab)
print(model.pyprog)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment