100 Languages Speedrun: Episode 75: Abstract Syntax Trees with Python ANTLR 4

The role of a parser is turning text into an Abstract Syntax Tree, with a minimum of hassle. In the previous episode, ANTLR 4 failed us utterly, as it merely generates a Concrete Syntax Tree.

So we need to do something about it. Here are a few different design patterns you can use to get either an Abstract Syntax Tree, or at least a Better Concrete Syntax Tree.

As far as I know, none of this is documented anywhere, as the ANTLR 4 documentation is painfully Java-only, and many techniques I'll be using are Python-specific and wouldn't even work in Java.

Abstract and Concrete Syntax Trees

So what are different types of Syntax Trees, and why does it matter?

Let's take a very simple language for just mathematical expressions, with 2 + 3 * 4 as our program.

Here's a possible Abstract Syntax Tree for that (ignoring namespaces for these classes, debug info like line numbers, etc.):

Add(
  Number(2),
  Multiply(
    Number(3),
    Number(4)
  )
)

If we arrive at a tree like that, the parsing part is done, and we can move on with the rest of the program. That's the Abstract Syntax Tree - all the irrelevant details are removed, just meaningful structure remains, and the structure of the tree matches the logical structure of the program.

A Concrete Syntax Tree, on the other hand, very closely follows the structure of the grammar. The Concrete Syntax Tree for the same program might look like this:

Expr(
  Expr(
    Term(
      Factor(
        Number("2")
      ),
    ),
  ),
  Operator("+"),
  Term(
    Term(
      Factor(
        Number("3")
      )
    ),
    Operator("*"),
    Factor(
      Number("4")
    )
  )
)

This is definitely not what we want! Nearly every parser generator, including older versions of ANTLR, generates AST-style results, because an AST is always the end goal. A CST is at most an intermediate representation created along the way, and most parsers don't even bother constructing one.

ANTLR 4 somehow made the insane decision not to do that, and to only provide the CST, leaving the conversion to an AST up to us. This means any program using ANTLR 4 needs to include a lot of tedious boilerplate for CST to AST conversion, or operate directly on the CST, which is completely impractical except maybe for some toy cases.

Even for this simple program the CST was already quite bad, and the more complex the grammar, the more complex the CST gets, even for the same AST. If your table of operators has 15 precedence levels (like C), that means 15 extra levels of pointless trivial nodes in the CST! Totally crazy!

Annotated Grammar

Back in ANTLR 1-3, the way to get an AST was to put some actions in the grammar itself. Typically an action would do some part of AST construction. As AST construction closely follows the grammar rules, this usually works beautifully.

ANTLR 4 makes it quite hard to put actions in the grammar, but it provides a few partial replacement features which we can use to at least get a better CST.

So we take our Math.g4:

grammar Math;

expr : expr ('+' | '-') term
     | term;

term : term ('*' | '/') factor
     | factor;

factor : '(' expr ')'
       | number
       | identifier;

number: NUM;
identifier: ID;

ID : [a-zA-Z_] [a-zA-Z0-9_]* ;
NUM : '-'? [0-9]+ ( '.' [0-9]*)?;

WS : [ \t\r\n]+ -> skip; // ignore all whitespace

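Then we add #-labels to the alternative branches, and a=/b= labels to the parts of each match (ANTLR calls these alternative labels and element labels). The separate number and identifier rules get folded into factor as labeled alternatives. The result: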

grammar Math;

expr : a=expr '+' b=term # Add
     | a=expr '-' b=term # Sub
     | a=term # TrivialExpr
     ;

term : a=term '*' b=factor # Mul
     | a=term '/' b=factor # Div
     | a=factor # TrivialTerm
     ;

factor : '(' a=expr ')' # TrivialParensExpr
       | a=NUM # Number
       | a=ID  # Identifier
       ;

ID : [a-zA-Z_] [a-zA-Z0-9_]* ;
NUM : '-'? [0-9]+ ( '.' [0-9]*)?;

WS : [ \t\r\n]+ -> skip; // ignore all whitespace

Let's see the resulting CST, and how much it's improved:

Add(
  TrivialExpr(
    TrivialTerm(
      Number("2")
    ),
  ),
  Mul(
    TrivialTerm(
      Number("3")
    ),
    Number("4")
  )
)

This isn't quite an AST yet, as the Number nodes still hold their text instead of converting it to a float, and we have a few Trivial* nodes left, but we took a good few steps towards our goal.

Let's update both the Listener version and the Visitor version to use the new CST.

Updated Listener with Better CSTs

#!/usr/bin/env python3

from antlr4 import *
from MathLexer import MathLexer
from MathParser import MathParser
from MathListener import MathListener
import sys

class MathProgram(MathListener):
  def exitNumber(self, node):
    value = float(node.getText())
    self.stack.append(value)

  def exitIdentifier(self, node):
    value = self.getVar(node.getText())
    self.stack.append(value)

  def exitAdd(self, node):
    b = self.stack.pop()
    a = self.stack.pop()
    self.stack.append(a + b)

  def exitSub(self, node):
    b = self.stack.pop()
    a = self.stack.pop()
    self.stack.append(a - b)

  def exitMul(self, node):
    b = self.stack.pop()
    a = self.stack.pop()
    self.stack.append(a * b)

  def exitDiv(self, node):
    b = self.stack.pop()
    a = self.stack.pop()
    self.stack.append(a / b)

  def getVar(self, name):
    if name not in self.vars:
      self.vars[name] = float(input(f"Enter value for {name}: "))
    return self.vars[name]

  def run(self, node):
    self.stack = []
    self.vars = {}
    ParseTreeWalker().walk(self, node)
    result = self.stack[0]
    print(result)

def parseFile(path):
  lexer = MathLexer(FileStream(path))
  stream = CommonTokenStream(lexer)
  parser = MathParser(stream)
  tree = parser.expr()
  MathProgram().run(tree)

if __name__ == "__main__":
  path = sys.argv[1]
  parseFile(path)

This is more code, but it's so much cleaner than what we had before. There are zero grammar checks like len(node.children) == 3 or node.children[1].getText() == "*". We know which branch matched, and the branches we previously had to check for (all the Trivial ones) can now be ignored, as their default action is to do nothing anyway.

This didn't require any extra work on the Python side, and we didn't even use the a= and b= labels here (they're only needed by the visitor version).

One issue still remains: the listener pattern is highly specific to the kind of processing we want to do. For our math program we could get away with simply using self.stack and processing things in order, but this won't be so easy in general.
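The stack works because ParseTreeWalker calls exit methods in post-order: a node exits only after all of its children have. For 2 + 3 * 4 the non-trivial exits arrive as Number 2, Number 3, Number 4, Mul, Add, which is simply reverse Polish notation. Here's a small sketch that replays such an event list against a stack, the same way the exit methods above do. The event list is written by hand from that post-order reasoning rather than captured from a real ANTLR run, and events, OPERATORS and replay are hypothetical names:

```python
# Hand-written list of the non-trivial exit events the walker should fire
# for "2 + 3 * 4" (Trivial* exits are left out, as they do nothing).
events = [("Number", "2"), ("Number", "3"), ("Number", "4"), ("Mul", None), ("Add", None)]

OPERATORS = {
    "Add": lambda a, b: a + b,
    "Sub": lambda a, b: a - b,
    "Mul": lambda a, b: a * b,
    "Div": lambda a, b: a / b,
}

def replay(events):
    stack = []
    for kind, text in events:
        if kind == "Number":
            # same as exitNumber: push the parsed value
            stack.append(float(text))
        else:
            # same as exitAdd etc.: the right operand is on top of the stack
            b = stack.pop()
            a = stack.pop()
            stack.append(OPERATORS[kind](a, b))
    return stack[0]

print(replay(events))  # 14.0
```

It also shows where the limitation comes from: the walker always visits every child, in order, so anything that needs to skip or repeat a subtree (like an if or a loop) doesn't fit this shape.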

Updated Visitor with Better CSTs

The visitor will need some Python meta-programming. We'll need to implement a slightly magical MathProgram.eval that routes each node to the right method depending on its type, and automatically skips trivial nodes.

But the rest of MathProgram is actually really nice, and looks almost exactly like what an AST-processing MathProgram would.

#!/usr/bin/env python3

from antlr4 import *
from MathLexer import MathLexer
from MathParser import MathParser
import sys

class MathProgram:
  def __init__(self, program):
    self.program = program

  def evalAdd(self, node):
    return self.eval(node.a) + self.eval(node.b)

  def evalSub(self, node):
    return self.eval(node.a) - self.eval(node.b)

  def evalMul(self, node):
    return self.eval(node.a) * self.eval(node.b)

  def evalDiv(self, node):
    return self.eval(node.a) / self.eval(node.b)

  def evalNumber(self, node):
    return float(node.getText())

  def evalIdentifier(self, node):
    return self.getVar(node.getText())

  def getVar(self, name):
    if name not in self.vars:
      self.vars[name] = float(input(f"Enter value for {name}: "))
    return self.vars[name]

  def eval(self, node):
    if not isinstance(node, ParserRuleContext):
      raise Exception(f"{node} must be a node, not a {type(node)}")
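    # strip the 7-character "Context" suffix, e.g. AddContext -> evalAdd
    name = "eval" + type(node).__name__[:-7]
    # "evalTrivial" is 11 characters; Trivial* nodes just wrap one child in .a
    if name[:11] == "evalTrivial":
      return self.eval(node.a)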
    return self.__getattribute__(name)(node)

  def run(self):
    self.vars = {}
    result = self.eval(self.program)
    print(result)

def parseFile(path):
  lexer = MathLexer(FileStream(path))
  stream = CommonTokenStream(lexer)
  parser = MathParser(stream)
  tree = parser.expr()
  MathProgram(tree).run()

if __name__ == "__main__":
  path = sys.argv[1]
  parseFile(path)

OK, what's going on here:

  • before, we knew the type of each subexpression (like Term), so we could call evalTerm from evalExpr. But these were all stupid types, and we don't want them. Now that Term has been split into Mul, Div, and TrivialTerm, evalExpr can't possibly know what node.b is going to be.
  • so to get the right method called, we use the eval method, which checks the type of the node and calls whatever it needs to
  • MathProgram.eval is completely unrelated to Python's global eval function, so the name is perhaps slightly confusing, even if it's not ambiguous to Python (self.eval vs eval)
  • the evalAdd etc. methods need to recursively call self.eval(node.a), but we really don't care about specific types, so this is much more readable than self.evalExpr(node.a) + self.evalTerm(node.b) would be
  • if we had an AST, evalAdd, evalSub, evalMul, and evalDiv would likely look just like they do now, huge win!
  • evalNumber is not perfect, as it needs to do float(node.getText()) instead of having the number converted to a float during parsing as an AST would, but it's good enough
  • evalIdentifier also needs to call .getText(), which does some calculations to build that string; an AST would just have .name or such
  • we could implement a bunch of methods like evalTrivialExpr(self, node): return self.eval(node.a), but I made the eval method handle all of these automatically - this deals with the biggest difference between our "improved CST", which has those extra nodes, and a true "AST", which wouldn't.
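The name mangling in eval relies on hand-counted magic numbers ([:-7] for "Context", 11 for "evalTrivial"). A minimal sketch of the same dispatch using getattr, str.removesuffix (Python 3.9+) and str.startswith instead, so nothing needs counting. To run without ANTLR, the *Context classes below are hypothetical stand-ins that only mimic the generated class names and a/b fields, so the ParserRuleContext check is left out:

```python
# Hypothetical stand-ins that only mimic the names and a/b fields of the
# ANTLR-generated contexts, so this runs without ANTLR.
class AddContext:
    def __init__(self, a, b):
        self.a, self.b = a, b

class MulContext:
    def __init__(self, a, b):
        self.a, self.b = a, b

class TrivialExprContext:
    def __init__(self, a):
        self.a = a

class TrivialTermContext:
    def __init__(self, a):
        self.a = a

class NumberContext:
    def __init__(self, text):
        self.text = text

    def getText(self):
        return self.text

class MathProgram:
    def evalAdd(self, node):
        return self.eval(node.a) + self.eval(node.b)

    def evalMul(self, node):
        return self.eval(node.a) * self.eval(node.b)

    def evalNumber(self, node):
        return float(node.getText())

    def eval(self, node):
        # "AddContext" -> "Add", "TrivialTermContext" -> "TrivialTerm"
        kind = type(node).__name__.removesuffix("Context")
        if kind.startswith("Trivial"):
            # every Trivial* node just wraps its single child in .a
            return self.eval(node.a)
        method = getattr(self, "eval" + kind, None)
        if method is None:
            raise TypeError(f"no eval method for {type(node).__name__}")
        return method(node)

# The "2 + 3 * 4" CST from above
tree = AddContext(
    TrivialExprContext(TrivialTermContext(NumberContext("2"))),
    MulContext(TrivialTermContext(NumberContext("3")), NumberContext("4")),
)
print(MathProgram().eval(tree))  # 14.0
```

A missing method now raises a TypeError naming the offending class, rather than a bare AttributeError.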

Interestingly, in some languages you can implement multiple methods with the same name that dispatch on the dynamic type of the argument, so we could implement def eval(self, node : DivContext) etc. and it would pick the right one. This is not actually that common: most languages don't support it at all, or only support the static version (overloading), which doesn't help here, as we don't know the argument type statically.
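Python's standard library actually gets close: functools.singledispatchmethod (Python 3.8+) picks an implementation by the runtime class of the first argument after self, and register can read that class from the type annotation. A minimal sketch with hypothetical stand-in node classes (Add, Mul, Number); registering the generated MathParser.AddContext and friends the same way should work too, but that's an untested assumption:

```python
from dataclasses import dataclass
from functools import singledispatchmethod

# Hypothetical stand-ins for the generated context classes
@dataclass
class Add:
    a: object
    b: object

@dataclass
class Mul:
    a: object
    b: object

@dataclass
class Number:
    text: str

class MathProgram:
    @singledispatchmethod
    def eval(self, node):
        # fallback when no implementation is registered for type(node)
        raise TypeError(f"cannot eval {type(node).__name__}")

    @eval.register
    def _(self, node: Add):
        return self.eval(node.a) + self.eval(node.b)

    @eval.register
    def _(self, node: Mul):
        return self.eval(node.a) * self.eval(node.b)

    @eval.register
    def _(self, node: Number):
        return float(node.text)

print(MathProgram().eval(Add(Number("2"), Mul(Number("3"), Number("4")))))  # 14.0
```

Trivial nodes would still need registering, one class at a time, since dispatch goes by class rather than by name.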

And in some other languages we could easily add extra methods to those generated classes, which would be another way to do it more cleanly. This is even sort of doable in Python, but it's a bit messy.
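A minimal sketch of the "sort of doable" version: Python lets you assign plain functions onto an existing class, and they then behave as methods on every instance. The classes here are hypothetical stand-ins so it runs without ANTLR; doing this to the real generated classes is an untested assumption:

```python
# Hypothetical stand-ins; in the real program these would be the classes
# ANTLR generated inside MathParser, e.g. MathParser.AddContext.
class AddContext:
    def __init__(self, a, b):
        self.a, self.b = a, b

class TrivialTermContext:
    def __init__(self, a):
        self.a = a

class NumberContext:
    def __init__(self, text):
        self.text = text

    def getText(self):
        return self.text

# Bolt an eval method onto each existing class after the fact.
def add_eval(self, context):
    return self.a.eval(context) + self.b.eval(context)

def trivial_eval(self, context):
    return self.a.eval(context)

def number_eval(self, context):
    return float(self.getText())

AddContext.eval = add_eval
TrivialTermContext.eval = trivial_eval
NumberContext.eval = number_eval

tree = AddContext(TrivialTermContext(NumberContext("2")), NumberContext("40"))
print(tree.eval(None))  # 42.0
```

The mess is visible even at this size: the methods live far away from the classes they extend, every Trivial* class needs its own assignment, and a misspelled method name only shows up as an AttributeError at runtime.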

There are ways to do this with less aggressive meta-programming, like a bunch of isinstance checks, or a dictionary keyed by class, but I think this is the most concise.
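For comparison, a sketch of the dictionary-keyed-by-class variant, shown on hypothetical namedtuple stand-ins for brevity. With ANTLR, the keys would presumably be MathParser.AddContext and so on, and every Trivial* class could map to one shared handler:

```python
from collections import namedtuple

# Hypothetical stand-in node classes
AddNode = namedtuple("AddNode", ["a", "b"])
SubNode = namedtuple("SubNode", ["a", "b"])
NumberNode = namedtuple("NumberNode", ["value"])

class MathProgram:
    def __init__(self):
        # explicit routing table: node class -> bound handler method
        self.handlers = {
            AddNode: self.evalAdd,
            SubNode: self.evalSub,
            NumberNode: self.evalNumber,
        }

    def evalAdd(self, node):
        return self.eval(node.a) + self.eval(node.b)

    def evalSub(self, node):
        return self.eval(node.a) - self.eval(node.b)

    def evalNumber(self, node):
        return node.value

    def eval(self, node):
        # exact-class lookup: a subclass would need its own entry
        handler = self.handlers.get(type(node))
        if handler is None:
            raise TypeError(f"no handler for {type(node).__name__}")
        return handler(node)

# 10 - (2 + 3)
print(MathProgram().eval(SubNode(NumberNode(10.0), AddNode(NumberNode(2.0), NumberNode(3.0)))))  # 5.0
```

It's more typing than the name-based eval, but every route is explicit and easy to search for.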

Abstract Syntax Tree

Well, how about we go beyond the Better Concrete Syntax Tree idea and just build a whole new AST? This is the best way if you need to do a lot of processing: you can leave the CST behind and work with just a very nice AST.

Here's one way to do this:

#!/usr/bin/env python3

from antlr4 import *
from MathLexer import MathLexer
from MathParser import MathParser
from collections import namedtuple
import sys

AddNode = namedtuple("AddNode", ["a", "b"])
SubNode = namedtuple("SubNode", ["a", "b"])
MulNode = namedtuple("MulNode", ["a", "b"])
DivNode = namedtuple("DivNode", ["a", "b"])
NumberNode = namedtuple("NumberNode", ["value"])
IdentifierNode = namedtuple("IdentifierNode", ["name"])

class MathAstBuilder:
  def buildAdd(self, node):
    return AddNode(self.build(node.a), self.build(node.b))

  def buildSub(self, node):
    return SubNode(self.build(node.a), self.build(node.b))

  def buildMul(self, node):
    return MulNode(self.build(node.a), self.build(node.b))

  def buildDiv(self, node):
    return DivNode(self.build(node.a), self.build(node.b))

  def buildNumber(self, node):
    return NumberNode(float(node.getText()))

  def buildIdentifier(self, node):
    return IdentifierNode(node.getText())

  def build(self, node):
    if not isinstance(node, ParserRuleContext):
      raise Exception(f"{node} must be a node, not a {type(node)}")
    name = "build" + type(node).__name__[:-7]
    if name[:12] == "buildTrivial":
      return self.build(node.a)
    return self.__getattribute__(name)(node)

class MathProgram:
  def __init__(self, program):
    self.program = program

  def evalAdd(self, node):
    return self.eval(node.a) + self.eval(node.b)

  def evalSub(self, node):
    return self.eval(node.a) - self.eval(node.b)

  def evalMul(self, node):
    return self.eval(node.a) * self.eval(node.b)

  def evalDiv(self, node):
    return self.eval(node.a) / self.eval(node.b)

  def evalNumber(self, node):
    return node.value

  def evalIdentifier(self, node):
    return self.getVar(node.name)

  def getVar(self, name):
    if name not in self.vars:
      self.vars[name] = float(input(f"Enter value for {name}: "))
    return self.vars[name]

  def eval(self, node):
    name = "eval" + type(node).__name__[:-4]
    return self.__getattribute__(name)(node)

  def run(self):
    self.vars = {}
    result = self.eval(self.program)
    print(result)

def parseFile(path):
  lexer = MathLexer(FileStream(path))
  stream = CommonTokenStream(lexer)
  parser = MathParser(stream)
  cst = parser.expr()
  ast = MathAstBuilder().build(cst)
  MathProgram(ast).run()

if __name__ == "__main__":
  path = sys.argv[1]
  parseFile(path)

Step by step:

  • we need to define every AST node type - there are usually a lot fewer of them than CST node types (typically around half)
  • as these are super simple we can use collections.namedtuple to define a lot of them at once
  • we could also use dictionaries to represent nodes etc. - AST can be anything you want
  • MathAstBuilder.build does AST construction, using the same metaprogramming as we had before - it's mostly independent of the kind of language we have, we just need to dig down the CST, skip trivial nodes, and call the appropriate build method recursively
  • then we have MathProgram - in this case it's the same tree of eval as we had before, but that's just because it's such a simple kind of program. Normally this part would be the most complex part of your program, and contain all the actual logic.
  • MathProgram still uses some metaprogramming in eval but we can get rid of that...

Abstract Syntax Tree with proper classes

If you don't like MathProgram.eval and want your program to use proper classes and no magic beyond parsing, here's another way to do the AST. All the logic lives in plain OOP methods:

#!/usr/bin/env python3

from antlr4 import *
from MathLexer import MathLexer
from MathParser import MathParser
import sys

class AddNode:
  def __init__(self, a, b):
    self.a = a
    self.b = b

  def eval(self, context):
    return self.a.eval(context) + self.b.eval(context)

class SubNode:
  def __init__(self, a, b):
    self.a = a
    self.b = b

  def eval(self, context):
    return self.a.eval(context) - self.b.eval(context)

class MulNode:
  def __init__(self, a, b):
    self.a = a
    self.b = b

  def eval(self, context):
    return self.a.eval(context) * self.b.eval(context)

class DivNode:
  def __init__(self, a, b):
    self.a = a
    self.b = b

  def eval(self, context):
    return self.a.eval(context) / self.b.eval(context)

class NumberNode:
  def __init__(self, value):
    self.value = value

  def eval(self, context):
    return self.value

class IdentifierNode:
  def __init__(self, name):
    self.name = name

  def eval(self, context):
    return context.getVar(self.name)

class MathAstBuilder:
  def buildAdd(self, node):
    return AddNode(self.build(node.a), self.build(node.b))

  def buildSub(self, node):
    return SubNode(self.build(node.a), self.build(node.b))

  def buildMul(self, node):
    return MulNode(self.build(node.a), self.build(node.b))

  def buildDiv(self, node):
    return DivNode(self.build(node.a), self.build(node.b))

  def buildNumber(self, node):
    return NumberNode(float(node.getText()))

  def buildIdentifier(self, node):
    return IdentifierNode(node.getText())

  def build(self, node):
    if not isinstance(node, ParserRuleContext):
      raise Exception(f"{node} must be a node, not a {type(node)}")
    name = "build" + type(node).__name__[:-7]
    if name[:12] == "buildTrivial":
      return self.build(node.a)
    return self.__getattribute__(name)(node)

class MathProgram:
  def __init__(self, program):
    self.program = program

  def getVar(self, name):
    if name not in self.vars:
      self.vars[name] = float(input(f"Enter value for {name}: "))
    return self.vars[name]

  def run(self):
    self.vars = {}
    result = self.program.eval(self)
    print(result)

def parseFile(path):
  lexer = MathLexer(FileStream(path))
  stream = CommonTokenStream(lexer)
  parser = MathParser(stream)
  cst = parser.expr()
  ast = MathAstBuilder().build(cst)
  MathProgram(ast).run()

if __name__ == "__main__":
  path = sys.argv[1]
  parseFile(path)

Abstractify Concrete Syntax Tree

And for the last design pattern, how about we don't build a new AST, but take the existing CST and "abstractify" it by removing any trivial nodes and doing any node-specific fixes, like converting strings to numbers? This is very easy in Python, as it's very dynamic. It would of course be pretty much impossible in Java with its rigid static type system.

#!/usr/bin/env python3

from antlr4 import *
from MathLexer import MathLexer
from MathParser import MathParser
import sys

class MathAbstractify:
  def abstractifyAdd(self, node):
    node.a = self.abstractify(node.a)
    node.b = self.abstractify(node.b)

  def abstractifySub(self, node):
    node.a = self.abstractify(node.a)
    node.b = self.abstractify(node.b)

  def abstractifyMul(self, node):
    node.a = self.abstractify(node.a)
    node.b = self.abstractify(node.b)

  def abstractifyDiv(self, node):
    node.a = self.abstractify(node.a)
    node.b = self.abstractify(node.b)

  def abstractifyNumber(self, node):
    node.value = float(node.getText())

  def abstractifyIdentifier(self, node):
    node.name = node.getText()

  def abstractify(self, node):
    if not isinstance(node, ParserRuleContext):
      raise Exception(f"{node} must be a node, not a {type(node)}")
    name = "abstractify" + type(node).__name__[:-7]
    if name[:18] == "abstractifyTrivial":
      return self.abstractify(node.a)
    else:
      return self.__getattribute__(name)(node) or node

class MathProgram:
  def __init__(self, program):
    self.program = program

  def evalAdd(self, node):
    return self.eval(node.a) + self.eval(node.b)

  def evalSub(self, node):
    return self.eval(node.a) - self.eval(node.b)

  def evalMul(self, node):
    return self.eval(node.a) * self.eval(node.b)

  def evalDiv(self, node):
    return self.eval(node.a) / self.eval(node.b)

  def evalNumber(self, node):
    return node.value

  def evalIdentifier(self, node):
    return self.getVar(node.name)

  def getVar(self, name):
    if name not in self.vars:
      self.vars[name] = float(input(f"Enter value for {name}: "))
    return self.vars[name]

  def eval(self, node):
    if not isinstance(node, ParserRuleContext):
      raise Exception(f"{node} must be a node, not a {type(node)}")
    name = "eval" + type(node).__name__[:-7]
    return self.__getattribute__(name)(node)

  def run(self):
    self.vars = {}
    result = self.eval(self.program)
    print(result)

def parseFile(path):
  lexer = MathLexer(FileStream(path))
  stream = CommonTokenStream(lexer)
  parser = MathParser(stream)
  tree = parser.expr()
  tree = MathAbstractify().abstractify(tree)
  MathProgram(tree).run()

if __name__ == "__main__":
  path = sys.argv[1]
  parseFile(path)

Step by step:

  • The program has two parts now - MathAbstractify to convert CST into basically AST, and MathProgram to execute it
  • MathAbstractify.abstractifyX calls MathAbstractify.abstractify on the node's children and stores the results back on the node; these methods return None, so abstractify falls back to returning the (now updated) node itself
  • for trivial nodes, MathAbstractify.abstractify simply removes the node and returns MathAbstractify.abstractify of its child instead - so the resulting tree has no trivial nodes
  • for nodes that require special processing, like Number and Identifier, we do that processing and assign the results to appropriate fields
  • the abstractifyX methods are very repetitive, and there are many ways to make them a lot more concise, like having a dictionary with each node's fields to update, but I didn't want to complicate this even further
  • after that our MathProgram is pretty much the same as our first Abstract Syntax Tree version - no need to deal with trivial nodes, and everything on the nodes is precomputed

This pattern is arguably the hackiest of them all, but we get most of the advantages of ASTs without having to define all the node types.
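The dictionary idea from the list above can be sketched like this: two hypothetical tables, CHILD_FIELDS and TEXT_FIELDS, replace all the abstractifyX methods. The *Context classes are stand-ins (built on a hypothetical Ctx helper) so it runs without ANTLR, and removesuffix needs Python 3.9+:

```python
# Hypothetical stand-ins mimicking the generated contexts: keyword arguments
# become child fields, and text is what getText() returns.
class Ctx:
    def __init__(self, text=None, **children):
        self.text = text
        for field, child in children.items():
            setattr(self, field, child)

    def getText(self):
        return self.text

class AddContext(Ctx): pass
class MulContext(Ctx): pass
class TrivialTermContext(Ctx): pass
class NumberContext(Ctx): pass
class IdentifierContext(Ctx): pass

class MathAbstractify:
    # per node kind: which child fields to abstractify recursively
    CHILD_FIELDS = {
        "Add": ("a", "b"), "Sub": ("a", "b"),
        "Mul": ("a", "b"), "Div": ("a", "b"),
    }
    # per node kind: (new field name, conversion applied to getText())
    TEXT_FIELDS = {"Number": ("value", float), "Identifier": ("name", str)}

    def abstractify(self, node):
        kind = type(node).__name__.removesuffix("Context")
        if kind.startswith("Trivial"):
            # drop the wrapper and return its abstractified child instead
            return self.abstractify(node.a)
        for field in self.CHILD_FIELDS.get(kind, ()):
            setattr(node, field, self.abstractify(getattr(node, field)))
        if kind in self.TEXT_FIELDS:
            field, convert = self.TEXT_FIELDS[kind]
            setattr(node, field, convert(node.getText()))
        return node

tree = AddContext(
    a=TrivialTermContext(a=NumberContext("2")),
    b=MulContext(a=NumberContext("3"), b=IdentifierContext("x")),
)
tree = MathAbstractify().abstractify(tree)
print(type(tree.a).__name__, tree.a.value, tree.b.b.name)  # NumberContext 2.0 x
```

One trade-off: a node kind missing from both tables now passes through unchanged, where the method-per-node version fails loudly with an AttributeError.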

Should you use Python ANTLR 4 now?

I think with these design patterns, using ANTLR 4 from Python is a much better experience than with the official Java-style code the documentation recommends.

ANTLR 4 is still likely the most powerful parser generator out there, and it deserves a much better API than the one it got.

None of the design patterns I presented here are as clean as what we could have if ANTLR 4 supported direct AST generation, but each is worth considering, depending on the type of program you're writing. I definitely do not recommend following either of the official Java-style patterns I showed in the previous episode.

Code

All code examples for the series will be in this repository.

Code for the Abstract Syntax Trees with Python ANTLR 4 episode is available here.