Created December 9, 2019 16:36
-
-
Save edofic/d0252cf9a5d6973c2590a6d35eb64437 to your computer and use it in GitHub Desktop.
Generator-based recursive-descent parsers
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import doctest | |
from collections import namedtuple | |
import sys | |
def parse_string(target):
    """Return a parser that matches the literal prefix *target*.

    The returned parser yields a single ``(target, rest)`` pair when its
    input starts with *target*, and yields nothing otherwise.

    >>> list(parse_string('foo')('foobar'))
    [('foo', 'bar')]
    >>> list(parse_string('foo')('bar'))
    []
    """
    def literal(s):
        if not s.startswith(target):
            return
        yield target, s[len(target):]

    return literal
def parse_end(s):
    """Succeed without consuming anything, but only when *s* is empty.

    On success the single result is the marker pair ``((), "")``.

    >>> list(parse_end(''))
    [((), '')]
    >>> list(parse_end(' '))
    []
    """
    if s != "":
        return
    yield (), ""
def parse_one_of(*parsers):
    """Return a parser that runs every alternative in *parsers* on the same input.

    Results are produced lazily in order. All results of the first parser
    come first, then all results of the second, and so on. Every
    alternative's results are produced, not only the first success.

    >>> list(parse_one_of(parse_string('foo'), parse_string('bar'))('foobar'))
    [('foo', 'bar')]
    >>> list(parse_one_of(parse_string('foo'), parse_string('bar'))('barfoo'))
    [('bar', 'foo')]
    """
    def choice(s):
        return (result for alternative in parsers for result in alternative(s))

    return choice
# One literal parser per decimal digit, tried in order "0" through "9".
parse_digit = parse_one_of(*map(parse_string, "0123456789"))
parse_digit.__doc__ = """
Parse exactly one ASCII decimal digit character.

>>> list(parse_digit('7'))
[('7', '')]
"""
def parse_all(*parsers):
    """Return a parser that runs *parsers* in sequence, each on the rest of the input.

    Every combination of results is produced. A result is a pair whose
    first item is a tuple holding one value per parser, in order, and whose
    second item is the input left after the last parser. The value is always
    a tuple, including when there is a single parser. With no parsers at all,
    the result is the identity parser: it yields ``((), s)``.

    >>> list(parse_all(parse_digit, parse_digit)('12'))
    [(('1', '2'), '')]
    >>> list(parse_all(parse_digit)('1'))
    [(('1',), '')]
    >>> list(parse_all()('x'))
    [((), 'x')]
    """
    if not parsers:
        # Base case: succeed once, consuming nothing, with an empty value tuple.
        def nothing(s):
            yield (), s

        return nothing

    head = parsers[0]
    # Build the combinator for the remaining parsers once, not once per head
    # result as before; constructing it does not run any parser.
    tail = parse_all(*parsers[1:])

    def sequence(s):
        for value, after_head in head(s):
            for values, after_tail in tail(after_head):
                yield (value,) + values, after_tail

    return sequence
def parse_map(f, p):
    """Return a parser that behaves like *p* but passes each value through *f*.

    The remaining input of each result is left untouched. Results keep
    *p*'s order and are produced lazily. *p* is not called until iteration
    starts.

    >>> list(parse_map(int, parse_digit)('7'))
    [(7, '')]
    """
    def mapped(s):
        yield from ((f(value), rest) for value, rest in p(s))

    return mapped
def parse_plus(parser):
    """Return a parser for one or more repetitions of *parser* (maximal munch).

    For every result of the first repetition, the parser tries to continue
    on the remaining input. A shorter match is reported only when it cannot
    be extended at all. Values are collected into a list. Recursion stops
    when the remaining input is empty.

    >>> list(parse_plus(parse_digit)('7'))
    [(['7'], '')]
    >>> list(parse_plus(parse_digit)('713a'))
    [(['7', '1', '3'], 'a')]
    """
    def repeated(s):
        for head, rest in parser(s):
            longer = repeated(rest) if rest else ()
            extended = False
            for tail, remaining in longer:
                extended = True
                yield [head, *tail], remaining
            if not extended:
                yield [head], rest

    return repeated
# Binary expression-tree nodes: ``n`` is the left operand and ``m`` the right.
# Each operand is either an int or another Sum/Product node.
Sum = namedtuple("Sum", "n m")
Product = namedtuple("Product", "n m")
def _digits_to_int(digits):
    """Concatenate a list of digit characters and convert the result to an int."""
    return int("".join(digits))


# A number is the longest run of digits, converted to an int.
parse_num = parse_map(_digits_to_int, parse_plus(parse_digit))
parse_num.__doc__ = """
Parse a non-empty run of decimal digits into an int.

>>> list(parse_num('123'))
[(123, '')]
"""
def parse_product(s):
    """Parse a right-nested product ``num * (num | product)`` at the start of *s*.

    Every successful prefix parse is yielded as ``(Product, remaining)``. The
    plain-number right operand is tried before a nested product, so shorter
    parses come first.

    >>> list(parse_product('1*2'))
    [(Product(n=1, m=2), '')]
    >>> list(parse_product('1*2*3'))
    [(Product(n=1, m=2), '*3'), (Product(n=1, m=Product(n=2, m=3)), '')]
    """
    right_operand = parse_one_of(parse_num, parse_product)
    factors = parse_all(parse_num, parse_string("*"), right_operand)

    def build(parts):
        # parts is (left, "*", right); the operator token is discarded
        return Product(parts[0], parts[2])

    return parse_map(build, factors)(s)
def parse_sum(s):
    """Parse a right-nested sum ``term + (term | sum)`` at the start of *s*.

    A term is a product or a plain number, with the product tried first.
    That is why products end up nested inside sums, e.g. ``1+2*3`` becomes
    ``Sum(1, Product(2, 3))``. Every successful prefix parse is yielded as
    ``(Sum, remaining)``.

    >>> list(parse_sum('1+2'))
    [(Sum(n=1, m=2), '')]
    >>> list(parse_sum('1+2+3'))
    [(Sum(n=1, m=2), '+3'), (Sum(n=1, m=Sum(n=2, m=3)), '')]
    >>> list(parse_sum('1+2*3'))
    [(Sum(n=1, m=Product(n=2, m=3)), ''), (Sum(n=1, m=2), '*3')]
    """
    term = parse_one_of(parse_product, parse_num)
    right_side = parse_one_of(term, parse_sum)
    pieces = parse_all(term, parse_string("+"), right_side)

    def build(parts):
        # parts is (left, "+", right); the operator token is discarded
        return Sum(parts[0], parts[2])

    return parse_map(build, pieces)(s)
# Try a sum, then a product, then a bare number, and require that the whole
# input is consumed (parse_end only succeeds on an empty remainder).
_expr_then_end = parse_all(parse_one_of(parse_sum, parse_product, parse_num), parse_end)
# parse_all yields (expression, ()) value pairs; keep only the expression.
parse_expr = parse_map(lambda pair: pair[0], _expr_then_end)
parse_expr.__doc__ = """
Parse a complete expression, yielding every parse that consumes all input.

>>> list(parse_expr('1'))
[(1, '')]
>>> list(parse_expr('1*2'))
[(Product(n=1, m=2), '')]
>>> list(parse_expr('1+2'))
[(Sum(n=1, m=2), '')]
>>> list(parse_expr('1+2*3'))
[(Sum(n=1, m=Product(n=2, m=3)), '')]
"""
def eval_expr(e):
    """Evaluate an expression tree built from ints, ``Sum`` and ``Product`` nodes.

    Raises ValueError for any node that is not one of those types.

    >>> eval_expr(Sum(1, Product(2, 3)))
    7
    """
    if isinstance(e, int):
        return e
    # Each node type maps to the operation applied to its evaluated operands;
    # the left operand (n) is evaluated before the right one (m).
    operations = (
        (Sum, lambda left, right: left + right),
        (Product, lambda left, right: left * right),
    )
    for node_type, combine in operations:
        if isinstance(e, node_type):
            return combine(eval_expr(e.n), eval_expr(e.m))
    raise ValueError("{}:{} must be a int|Sum|Product tree".format(e, type(e)))
def run_repl():
    """Interactively read, parse, evaluate and print expressions until told to stop.

    All whitespace is removed from each line before parsing. Entering ``q``,
    or sending end-of-file, ends the loop. Input that cannot be parsed (or is
    too long or deep to parse recursively) is reported, and the loop
    continues instead of crashing.
    """
    while True:
        print('\nEnter expression to evaluate, "q" to quit')
        try:
            raw = input("> ")
        except EOFError:
            # Ctrl-D / closed stdin: exit quietly rather than with a traceback.
            print()
            return
        # The grammar only has digits, "+" and "*", so every whitespace
        # character (not just spaces) can be dropped.
        stripped = "".join(raw.split())
        if stripped == "q":
            return
        print(stripped)
        try:
            # parse_expr yields every parse that consumes the whole input;
            # None means there is none (previously next() raised StopIteration).
            parsed = next(parse_expr(stripped), None)
            if parsed is None:
                print("Could not parse expression: {!r}".format(stripped))
                continue
            expr, rem = parsed
            print(expr, rem)
            print(eval_expr(expr))
        except RecursionError:
            # Parsing and evaluation are recursive; very long input can
            # exceed the interpreter's recursion limit.
            print("Expression is too long or deeply nested to evaluate")
def print_usage():
    """Print the command-line usage line, followed by a blank line, to stdout."""
    program = sys.argv[0]
    print(f"Usage: {program} test|repl\n")
if __name__ == "__main__":
    # Exactly one sub-command is expected: "test" runs the doctests, "repl"
    # starts the interactive evaluator. Anything else prints usage and exits
    # with a non-zero status so scripts can detect the misuse.
    command = sys.argv[1] if len(sys.argv) == 2 else None
    if command == "test":
        # testmod() returns (failed, attempted); surface failures through the
        # exit status instead of always exiting 0.
        failed, _attempted = doctest.testmod()
        sys.exit(1 if failed else 0)
    elif command == "repl":
        run_repl()
    else:
        print_usage()
        sys.exit(2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.