Commit e3fded50 authored by Chris Jewell's avatar Chris Jewell
Browse files

Fixed bug in DAG Generator. Now appears to output sensible graphs using graphviz/dot.

parent c868302e
Pipeline #268 passed with stage
in 6 minutes and 50 seconds
......@@ -25,8 +25,8 @@ if have_graphviz:
from graphviz import Digraph
from gem.gemlang.ast_walker import ASTWalker, StopWalk
from gem.gemlang.ast.ast_base import ASTNode, NullNode
from gem.gemlang.ast.ast_statement import Statement
from gem.gemlang.ast.ast_base import ASTNode
from gem.gemlang.ast.ast_statement import EpiDecl
from gem.gemlang.ast.ast_expression import IdRef
from gem.gemlang.semantics.symbol_table import SymbolTable
from gem.gemlang.semantics.symbol import VariableSymbol, STMSymbol
......@@ -96,7 +96,7 @@ def get_dependencies(ast_node):
def dag2dot(dag):
if not have_graphviz:
Warning("Install graphviz to generate DAG visualisations")
print(Warning("Install graphviz to generate DAG visualisations"))
return None
dot = Digraph()
......@@ -152,18 +152,27 @@ class DAG(ASTWalker):
rv_decl = node.children[0]
dependent_node = DAGNode(rv_decl.value)
self.__dag[rv_decl.value] = dependent_node
print(f"Push {rv_decl.value}")
self.__current_rv.append(dependent_node)
self._walk(node.children[1])
print(f"Pop {rv_decl.value}")
self.__current_rv.pop()
raise StopWalk()
def onEnter_IdRef(self, node):
if isinstance(node.symbol, VariableSymbol) and node.symbol.type == SymbolTable._random_variable:
parent = self.__dag[node.value]
child = self.__current_rv[-1]
child.add_parent(parent)
parent.add_child(child)
if len(self.__current_rv) > 0:
child = self.__current_rv[-1]
child.add_parent(parent)
parent.add_child(child)
raise StopWalk()
def onEnter_Call(self, call):
symb = call.children[0].symbol
fn_def = symb.ast_def
if isinstance(fn_def, EpiDecl):
self._walk(fn_def.children[2])
def dot(self):
return dag2dot(self.__dag)
\ No newline at end of file
......@@ -89,8 +89,8 @@ class CodeGenerator(ASTWalker):
def onExit_EpiDecl(self, epidecl):
name = epidecl.name
config = epidecl.children[1]
param = epidecl.children[2]
config = epidecl.children[0]
param = epidecl.children[1]
stm = out.STM(name, config, param, self.current_workspace['states'],
self.current_workspace['transitions'],
self.current_workspace['initial'])
......@@ -103,6 +103,9 @@ class CodeGenerator(ASTWalker):
args = self._walk(call.children[1])
return out.Function(str(fname), args)
def onExit_FArgList(self, farglist):
return self.onExit_ArgList(farglist)
def onExit_ArgList(self, arglist):
args = [arg for arg in arglist.children]
return args
......
......@@ -102,6 +102,8 @@ class STM(Outputter):
**kwargs):
super().__init__(**kwargs)
self.name = name
self.config = config
self.param = param
self.states = states
self.transitions = transitions
self.inits = initial_states
......@@ -155,10 +157,16 @@ class STM(Outputter):
self.dedent()
return f_repr
def fargs(self):
if len(self.param) == 0:
return ""
str_args = [str(arg) for arg in self.param]
return ", ".join(str_args) + ','
def __str__(self):
s = f"class {self.name}_(TFContinuousMarkovILM):\n"
self.indent()
s += self.line("def __init__(self, name=None):")
s += self.line(f"def __init__(self, {self.fargs()} name=None):")
self.indent()
s += self.line(f"super().__init__(name=name)")
s += self.line("self.time_origin = 0.")
......@@ -356,10 +364,11 @@ class STMInstance(Function):
def __str__(self):
s = f"{self.fname}("
if self.args:
s += f"{self.args}, "
s += f"{self.args} "
s += f"value={self.value}, name='{self.name}')"
return s
class Number(Outputter):
def __init__(self, value):
self.value = value
......
......@@ -20,7 +20,7 @@
import unittest
from gem import GEM
from gem.gemlang.dag import dag2dot
from gem.gemlang.dag import DAG
......@@ -38,17 +38,13 @@ class TestSimpleDAGGenerator(unittest.TestCase):
x = [1.,2.,3.,4.]
self.model = GEM(prog, {'x': x})
def test_dag_generation(self):
from gem.gemlang.dag import DAG
DAG(self.model.ast)
def test_dot_generation(self):
from gem.gemlang.dag import DAG
dag = DAG(self.model.ast)
self.assertEqual(dag.dot().source, self.expected_graph)
dot = dag.dot()
self.assertEqual(dot.source, self.expected_graph)
@unittest.skip("Work in progress")
class TestEpiDAGGenerator(unittest.TestCase):
def setUp(self):
self.expected_graph = "digraph {\n\tbeta\n\tgamma\n\tepi\n\tbeta -> epi\n\tgamma -> epi\n}"
......@@ -65,13 +61,9 @@ class TestEpiDAGGenerator(unittest.TestCase):
}
epi ~ SIR()
"""
ast = gemparse(prog)
symtab = SymbolDeclarer().visit(ast)
SymbolResolver().visit(ast, symtab)
self.ast = ast
self.symtab = symtab
self.model = GEM(prog)
def test_dag_generator(self):
dag = generate_dag(self.symtab)
dot = dag2dot(dag)
dag = DAG(self.model.ast)
dot = dag.dot()
self.assertEqual(dot.source, self.expected_graph)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment