Commit b442d03f authored by Chris Jewell's avatar Chris Jewell
Browse files

Pylint tidying

parent 7eb901e8
Pipeline #271 passed with stage
in 1 minute and 8 seconds
......@@ -18,9 +18,9 @@
# #
#
"""GEM domain-specific language and analysis for epidemic modelling."""
from gem.interface import GEM
import gem.version as version
"""GEM domain-specific language and analysis for epidemic modelling."""
__version__ = f"{version.MAJOR}.{version.MINOR}.{version.MAINTENANCE}.{version.RELEASE}"
......@@ -17,6 +17,8 @@
# IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
"""The gemlang source-source translation chain."""
from gem.gemlang.ast.utils import serialize
from gem.gemlang.ast_walker import ASTWalker
from gem.gemlang.parse_gemlang import GEMParser
......
......@@ -19,15 +19,14 @@
"""AST Expression classes"""
import operator
__all__ = ['Expr', 'UnaryExpr', 'BinaryExpr', 'Call', 'KwArg',
'IdRef', 'ArrayLiteral', 'Number', 'Integer', 'Float', 'String', 'TransitionDef', 'ArgList']
import numpy as np
import operator
from gem.gemlang.ast.ast_base import Statement, ASTNode, KwRef
from gem.gemlang.semantics.scope import Scope
__all__ = ['Expr', 'UnaryExpr', 'BinaryExpr', 'Call', 'KwArg',
'IdRef', 'ArrayLiteral', 'Number', 'Integer', 'Float', 'String', 'TransitionDef', 'ArgList']
_BINARY_OPS = ['truediv', 'mul', 'matmul', 'add', 'sub', 'pow', 'lt', 'gt',
'le', 'ge', 'eq', 'ne']
......
......@@ -17,6 +17,15 @@
# IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
"""
AST Walker module
=================
The ASTWalker provides a base class to generate algorithms that walk the AST do do stuff like
symbol declaration, resolution, type-checking, and finally code generation.
"""
import copy
from gem.gemlang.ast.ast_base import ASTNode
......@@ -24,10 +33,9 @@ from gem.gemlang.ast.ast_base import ASTNode
class StopWalk(Exception):
"""Raise this exception in an action method to abort a tree walk."""
pass
class ASTWalker:
class ASTWalker: # pylint: disable=too-few-public-methods
"""Base class for all AST walking.
This class is essentially a dispatcher for calling node-specific methods
......@@ -53,24 +61,24 @@ class ASTWalker:
attr = "{}_{}".format(when, cls.__name__)
action = getattr(self, attr, None)
if action is not None:
return action(ast_node)
return action(ast_node) # pylint: disable=not-callable
return self.__default__(ast_node)
def __callOnEntry(self, ast_node):
def __callOnEntry(self, ast_node): # pylint: disable=invalid-name
new_node = self.__call('onEnter', ast_node)
if new_node is None: # Option if a method has a "void" return type
return ast_node
else:
return new_node
def __callOnExit(self, ast_node):
def __callOnExit(self, ast_node): # pylint: disable=invalid-name
new_node = self.__call('onExit', ast_node)
if new_node is None:
return ast_node
else:
return new_node
def __default__(self, ast_node):
def __default__(self, ast_node): # pylint: disable=no-self-use
return ast_node
def _walk(self, ast_node):
......@@ -89,6 +97,6 @@ class ASTWalker:
ast_node.children = [self._walk(child) for child in
ast_node.children]
ast_node = self.__callOnExit(ast_node)
except StopWalk: # Allows us to abort a walk at the end of an onEnter method if topdown parsing only required
except StopWalk: # Allows us to abort a walk at the end of an onEnter method for topdown parsing
pass
return ast_node
......@@ -16,20 +16,21 @@
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
# IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
import importlib
gv_spec = importlib.util.find_spec("graphviz")
have_graphviz = gv_spec is not None
"""DAG generation"""
if have_graphviz:
from graphviz import Digraph
import importlib
from gem.gemlang.ast_walker import ASTWalker, StopWalk
from gem.gemlang.ast.ast_base import ASTNode
from gem.gemlang.ast.ast_statement import EpiDecl
from gem.gemlang.ast.ast_expression import IdRef
from gem.gemlang.semantics.symbol_table import SymbolTable
from gem.gemlang.semantics.symbol import VariableSymbol, STMSymbol
from gem.gemlang.semantics.symbol import VariableSymbol
HAVE_GRAPHVIZ = importlib.util.find_spec("graphviz") is not None
if HAVE_GRAPHVIZ:
from graphviz import Digraph
class DAGNode(ASTNode):
......@@ -48,54 +49,41 @@ class DAGNode(ASTNode):
@property
def name(self):
"""Returns the name of the DAGNode"""
return self.__name
@property
def is_observed(self):
"""Is the DAGNode observed?"""
return self.__observed
@property
def parents(self):
"""Returns a list of the DAGNode's parent DAGNodes."""
return self.__parents
def add_parent(self, dag_node):
"""Add a parent DAGNode"""
assert isinstance(dag_node, DAGNode), f"'{dag_node}' is not a DAGNode"
self.__parents.append(dag_node)
def __repr__(self):
s = f"<{self.__class__.__name__} {self.name}:"
string = f"<{self.__class__.__name__} {self.name}:"
for parent in self.parents:
s += f" {parent}"
return s+">"
string += f" {parent}"
return string+">"
def __str__(self):
return self.name
def get_dependencies(ast_node):
"""Depth-first traversal of a sub-AST, returning
symbols of random variables.
def dag2dot(dag):
"""Takes a DAGNode tree and returns a string with dot code representation.
:param ast_node: a sub-AST root node
:returns: a set of IdRef AST nodes
:param dag: a DAGNode tree
:returns: a string of dot code.
"""
if isinstance(ast_node, IdRef):
node_symbol = ast_node.symbol
if ast_node.evalType == SymbolTable._random_variable:
return {ast_node.value}
elif isinstance(node_symbol, STMSymbol):
return get_dependencies(node_symbol.definition.children[2])
elif isinstance(node_symbol, VariableSymbol) and isinstance(node_symbol.definition, Statement):
return get_dependencies(node_symbol.definition.children[1])
random_vars = set()
for child in ast_node.children:
random_vars = random_vars.union(get_dependencies(child))
return random_vars
def dag2dot(dag):
if not have_graphviz:
if not HAVE_GRAPHVIZ:
print(Warning("Install graphviz to generate DAG visualisations"))
return None
......@@ -109,35 +97,20 @@ def dag2dot(dag):
return dot
# Algorithm
# =========
# 1. Collect all RandomVariableSymbol symbols from symbol table
# 2. For each RandomVariableSymbol:
# a. Visit declaration AST
# b. Recursively follow RHS, stop at IdRef nodes
# i. If IdRef is RandomVariableSymbol return symbol
# ii. If IdRef is VariableSymbol, follow definition and go to (b)
def generate_dag(symbol_table):
"""Generates a DAG representation of a probability model embedded in GEM code.
:param symbol_table: a symbol table generated by :class:`ParseDeclarations <gem.gemlang.parse_gemlang.ParseDeclarations>`
:returns: a dictionary of nodes of type DAGNode
class DAG(ASTWalker):
"""Creates a DAG representation of the probability model, with possible graphviz representation.
Algorithm
=========
1. Walk AST
2. For each StochasticAssignExpr:
a. push declaration to a dict
b. Recursively follow RHS, stop at IdRef nodes
i. If IdRef is a VariableSymbol of type RandomVariable, return symbol
ii. If IdRef is a MethodSymbol, follow definition and go to (b)
"""
random_var_symbols = symbol_table.get_by_type(VariableSymbol)
dag_dict = {rv_symbol.name: DAGNode(rv_symbol.name) for rv_symbol in random_var_symbols}
for rv_symbol in random_var_symbols:
ast = rv_symbol.definition
deps = get_dependencies(ast.children[1])
for dep in deps:
dag_dict[rv_symbol.name].add_parent(dag_dict[dep.name])
dag_dict[dep.name].add_child(dag_dict[rv_symbol.name])
return dag_dict
class DAG(ASTWalker):
# pylint: disable=missing-function-docstring,invalid-name
def __init__(self, ast):
super().__init__()
......@@ -146,9 +119,10 @@ class DAG(ASTWalker):
self._walk(ast)
def get_dag(self):
"""Returns the DAG"""
return self.__dag
def onEnter_StochasticAssignExpr(self, node):
def onEnter_StochasticAssignExpr(self, node): # pylint: disable=invalid-name
rv_decl = node.children[0]
dependent_node = DAGNode(rv_decl.value)
self.__dag[rv_decl.value] = dependent_node
......@@ -175,4 +149,4 @@ class DAG(ASTWalker):
self._walk(fn_def.children[2])
def dot(self):
return dag2dot(self.__dag)
\ No newline at end of file
return dag2dot(self.__dag)
......@@ -18,13 +18,13 @@
# #
#
"""GEM gemlang ode generator stage"""
import gem.gemlang.tf_output as out
from gem.gemlang.ast.ast_base import NullNode
from gem.gemlang.ast.ast_expression import Call
from gem.gemlang.ast_walker import ASTWalker
from gem.gemlang.semantics.symbol_table import SymbolTable
from gem.gemlang.semantics.symbol import *
import gem.gemlang.semantics.symbol
from gem.gemlang.semantics.symbol_resolve import get_random_vars
......@@ -38,7 +38,7 @@ class CodeGenerator(ASTWalker):
:param symtab: a fully populated symbol table
:returns: a CodeGenerator object
"""
# pylint: disable=invalid-name,missing-docstring
# pylint: disable=invalid-name,missing-docstring,no-self-use,too-many-public-methods
def __init__(self):
super().__init__()
......@@ -58,19 +58,20 @@ class CodeGenerator(ASTWalker):
def pop_ws(self):
return self.__workspace_stack.pop()
def generate(self, ast, symbol_table, external_data={}):
self.vars = [symb.name for symb in symbol_table.globals.get_by_type(VariableSymbol)]
def generate(self, ast, symbol_table, external_data=None):
external_data = {} if external_data is None else external_data
self.vars = [symb.name for symb in symbol_table.globals.get_by_type(gem.gemlang.semantics.symbol.VariableSymbol)]
self.random_vars = get_random_vars(symbol_table)
self.external_data = external_data
return self._walk(ast)
def onEnter_GEMProgram(self, output_obj):
def onEnter_GEMProgram(self, program): # pylint: disable=unused-argument
self.push_ws()
def onExit_GEMProgram(self, output_obj):
stmts = [out.AssignExternal(k) for k in self.external_data.keys()]
stmts.extend([stmt for stmt in output_obj.children if
not isinstance(stmt, NullNode)])
not isinstance(stmt, NullNode)])
self.pop_ws()
return str(out.GEMProgram(self.vars, stmts))
......@@ -98,7 +99,6 @@ class CodeGenerator(ASTWalker):
return stm
def onEnter_Call(self, call):
fsymb = call.children[0].symbol
fname = self._walk(call.children[0])
args = self._walk(call.children[1])
return out.Function(str(fname), args)
......@@ -107,8 +107,7 @@ class CodeGenerator(ASTWalker):
return self.onExit_ArgList(farglist)
def onExit_ArgList(self, arglist):
args = [arg for arg in arglist.children]
return args
return arglist.children
def onExit_KwArg(self, kwarg):
return str(kwarg.children[1])
......@@ -165,6 +164,7 @@ class CodeGenerator(ASTWalker):
self.current_workspace['states'].append(lhs.value)
self.current_workspace['initial'].append(initial[0])
return NullNode('state reference removed')
return ast_node
def onExit_AssignExpr(self, ast_node):
lhs = ast_node.children[0]
......@@ -179,9 +179,9 @@ class CodeGenerator(ASTWalker):
args = self._walk(rhs.children[1])
observed = getattr(rhs, 'observed', None)
if str(
fname) in out.builtins_.keys(): # TODO find a better solution to translating GEM distros to TF
fname) in out.builtins_.keys(): # TODO find better solution to translating GEM distros to TF
rv_obj = out.builtins_[str(fname)](args, value=observed,
name=varname)
name=varname)
else:
rv_obj = out.STMInstance(fname, args, value=observed, name=varname)
return out.Assign(varname, rv_obj)
......
......@@ -20,19 +20,24 @@
# Parse GEM Language to AST
import os
import inspect
from lark import Lark, Transformer, v_args
from gem.gemlang.ast.ast_base import GEMProgram, Block, FArgList, FArg, KwRef, Statement
from gem.gemlang.ast.ast_statement import EpiDecl, AssignExpr, AssignTransition, StochasticAssignExpr
from gem.gemlang.ast.ast_expression import *
from gem.gemlang.ast.ast_expression import * # pylint: disable=unused-wildcard-import
class GrammarError(Exception):
pass
@v_args(meta=True)
class GEMParser(Transformer):
"""GEMParser parse-tree to AST transformer."""
# pylint: disable=missing-function-docstring,no-self-use,unused-argument,undefined-variable,invalid-name
# pylint: disable=too-many-public-methods
def gem(self, stmts, meta):
# Filter out NEWLINE tokens.
......@@ -68,7 +73,7 @@ class GEMParser(Transformer):
def assign_sojourn_density(self, args, meta):
raise NotImplementedError("Non-Markovian models not implemented yet.")
return StochasticAssignExpr(args[0], args[1], meta=meta)
# return StochasticAssignExpr(args[0], args[1], meta=meta)
def expr(self, ex, meta):
if len(ex) > 1:
......
......@@ -43,6 +43,9 @@ class SymbolDeclarer(ScopedWalker):
:param ast_node: an AST root node
:returns: a populated SymbolTable corresponding to ast_node
"""
# pylint: disable=missing-function-docstring,no-self-use,unused-argument
self._push_scope(symbol_table.globals)
self._walk(ast_node)
......
......@@ -21,10 +21,10 @@
from collections import namedtuple
from gem.gemlang.ast.ast_expression import *
from gem.gemlang.ast.ast_expression import * # pylint: disable=unused-wildcard-import
from gem.gemlang.semantics.scoped_ast_walker import ScopedWalker
from gem.gemlang.semantics.symbol_table import SymbolTable
from gem.gemlang.semantics.symbol import *
import gem.gemlang.semantics.symbol
class SymbolResolver(ScopedWalker):
......@@ -37,6 +37,9 @@ class SymbolResolver(ScopedWalker):
:param symbol_table: a SymbolTable to resolve ast_node against
:returns: Nothing. ast_node modified in place.
"""
# pylint: disable=missing-function-docstring,no-self-use,unused-argument,invalid-name
self._push_scope(symbol_table.globals)
self._walk(ast_node)
......@@ -147,7 +150,7 @@ def get_random_vars(symtab):
free_rvs = []
observed_rvs = []
symbols = [sym for sym in symtab.globals.get_by_type(VariableSymbol) if sym.type == SymbolTable._random_variable]
symbols = [sym for sym in symtab.globals.get_by_type(gem.gemlang.semantics.symbol.VariableSymbol) if sym.type == SymbolTable._random_variable]
for symbol in symbols:
ast = symbol.ast_def
assert ast is not None, f"symbol {symbol} is not annotated with an ASTNode"
......
......@@ -18,9 +18,11 @@
# #
#
"""GEM version information"""
MAJOR = '0'
MINOR = '1'
MAINTENANCE = '0'
RELEASE = 'dev0'
RELEASE = 'pre-alpha'
VERSION_STRING = f"{MAJOR}.{MINOR}.{MAINTENANCE}{RELEASE}"
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment