diff --git a/lark/common.py b/lark/common.py
index e1ec220..7103d14 100644
--- a/lark/common.py
+++ b/lark/common.py
@@ -20,6 +20,7 @@ class LexerConf(Serialize):
 
 class ParserConf:
     def __init__(self, rules, callbacks, start):
+        assert isinstance(start, list)
         self.rules = rules
         self.callbacks = callbacks
         self.start = start
diff --git a/lark/exceptions.py b/lark/exceptions.py
index f781968..4207589 100644
--- a/lark/exceptions.py
+++ b/lark/exceptions.py
@@ -52,7 +52,7 @@ class UnexpectedInput(LarkError):
 
 
 class UnexpectedCharacters(LexError, UnexpectedInput):
-    def __init__(self, seq, lex_pos, line, column, allowed=None, considered_tokens=None, state=None):
+    def __init__(self, seq, lex_pos, line, column, allowed=None, considered_tokens=None, state=None, token_history=None):
         message = "No terminal defined for '%s' at line %d col %d" % (seq[lex_pos], line, column)
 
         self.line = line
@@ -65,6 +65,8 @@ class UnexpectedCharacters(LexError, UnexpectedInput):
         message += '\n\n' + self.get_context(seq)
         if allowed:
             message += '\nExpecting: %s\n' % allowed
+        if token_history:
+            message += '\nPrevious tokens: %s\n' % ', '.join(repr(t) for t in token_history)
 
         super(UnexpectedCharacters, self).__init__(message)
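For context, a minimal sketch of what the new token_history argument buys. The toy grammar and input below are hypothetical, not part of the patch; only the "Previous tokens" line comes from the change above:

    # Hypothetical example: the lexer fails on '!'. Because the lexer now forwards
    # the last matched token as token_history, the error message gains a
    # "Previous tokens: ..." line on top of the usual context and expected-terminals info.
    from lark import Lark
    from lark.exceptions import UnexpectedCharacters

    parser = Lark('start: "a" "b"', parser='lalr')
    try:
        parser.parse('a!')        # 'a' lexes fine, '!' matches no terminal
    except UnexpectedCharacters as e:
        print(e)                  # message ends with: Previous tokens: Token(...)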
diff --git a/lark/lark.py b/lark/lark.py
index 87f7137..82cf76a 100644
--- a/lark/lark.py
+++ b/lark/lark.py
@@ -85,6 +85,9 @@ class LarkOptions(Serialize):
 
             options[name] = value
 
+        if isinstance(options['start'], str):
+            options['start'] = [options['start']]
+
         self.__dict__['options'] = options
 
         assert self.parser in ('earley', 'lalr', 'cyk', None)
@@ -287,8 +290,8 @@ class Lark(Serialize):
             return self.options.postlex.process(stream)
         return stream
 
-    def parse(self, text):
+    def parse(self, text, start=None):
         "Parse the given text, according to the options provided. Returns a tree, unless specified otherwise."
-        return self.parser.parse(text)
+        return self.parser.parse(text, start=start)
 
 ###}
diff --git a/lark/lexer.py b/lark/lexer.py
index bdf635d..3e881f8 100644
--- a/lark/lexer.py
+++ b/lark/lexer.py
@@ -149,6 +149,7 @@ class _Lex:
         newline_types = frozenset(newline_types)
         ignore_types = frozenset(ignore_types)
         line_ctr = LineCounter()
+        last_token = None
 
         while line_ctr.char_pos < len(stream):
             lexer = self.lexer
@@ -166,6 +167,7 @@ class _Lex:
                         t = lexer.callback[t.type](t)
                         if not isinstance(t, Token):
                             raise ValueError("Callbacks must return a token (returned %r)" % t)
+                    last_token = t
                     yield t
                 else:
                     if type_ in lexer.callback:
@@ -180,7 +182,7 @@ class _Lex:
                 break
             else:
                 allowed = {v for m, tfi in lexer.mres for v in tfi.values()}
-                raise UnexpectedCharacters(stream, line_ctr.char_pos, line_ctr.line, line_ctr.column, allowed=allowed, state=self.state)
+                raise UnexpectedCharacters(stream, line_ctr.char_pos, line_ctr.line, line_ctr.column, allowed=allowed, state=self.state, token_history=last_token and [last_token])
 
 
 class UnlessCallback:
diff --git a/lark/load_grammar.py b/lark/load_grammar.py
index 8bda118..f7b1011 100644
--- a/lark/load_grammar.py
+++ b/lark/load_grammar.py
@@ -554,7 +554,8 @@ class Grammar:
                             for s in r.expansion
                             if isinstance(s, NonTerminal)
                             and s != r.origin}
-            compiled_rules = [r for r in compiled_rules if r.origin.name==start or r.origin in used_rules]
+            used_rules |= {NonTerminal(s) for s in start}
+            compiled_rules = [r for r in compiled_rules if r.origin in used_rules]
             if len(compiled_rules) == c:
                 break
 
@@ -690,7 +691,7 @@ class GrammarLoader:
         callback = ParseTreeBuilder(rules, ST).create_callback()
         lexer_conf = LexerConf(terminals, ['WS', 'COMMENT'])
 
-        parser_conf = ParserConf(rules, callback, 'start')
+        parser_conf = ParserConf(rules, callback, ['start'])
         self.parser = LALR_TraditionalLexer(lexer_conf, parser_conf)
 
         self.canonize_tree = CanonizeTree()
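A small sketch of what the LarkOptions normalization above means for callers (toy one-rule grammar, LALR parser assumed): passing start as a plain string keeps working, because it is wrapped into a one-element list before anything downstream sees it.

    from lark import Lark

    p1 = Lark('a: "x"', parser='lalr', start='a')     # old-style string form
    p2 = Lark('a: "x"', parser='lalr', start=['a'])   # equivalent list form

    assert p1.parse('x') == p2.parse('x')             # both yield Tree('a', [])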
diff --git a/lark/parser_frontends.py b/lark/parser_frontends.py
index 0634814..1b55fe1 100644
--- a/lark/parser_frontends.py
+++ b/lark/parser_frontends.py
@@ -44,18 +44,28 @@ def get_frontend(parser, lexer):
         raise ValueError('Unknown parser: %s' % parser)
 
 
+class _ParserFrontend(Serialize):
+    def _parse(self, input, start, *args):
+        if start is None:
+            start = self.start
+            if len(start) > 1:
+                raise ValueError("Lark initialized with more than 1 possible start rule. Must specify which start rule to parse", start)
+            start ,= start
+        return self.parser.parse(input, start, *args)
 
-class WithLexer(Serialize):
+class WithLexer(_ParserFrontend):
     lexer = None
     parser = None
     lexer_conf = None
+    start = None
 
-    __serialize_fields__ = 'parser', 'lexer_conf'
+    __serialize_fields__ = 'parser', 'lexer_conf', 'start'
     __serialize_namespace__ = LexerConf,
 
     def __init__(self, lexer_conf, parser_conf, options=None):
         self.lexer_conf = lexer_conf
+        self.start = parser_conf.start
         self.postlex = lexer_conf.postlex
 
     @classmethod
@@ -73,10 +83,10 @@ class WithLexer(Serialize):
         stream = self.lexer.lex(text)
         return self.postlex.process(stream) if self.postlex else stream
 
-    def parse(self, text):
+    def parse(self, text, start=None):
         token_stream = self.lex(text)
         sps = self.lexer.set_parser_state
-        return self.parser.parse(token_stream, *[sps] if sps is not NotImplemented else [])
+        return self._parse(token_stream, start, *[sps] if sps is not NotImplemented else [])
 
     def init_traditional_lexer(self):
         self.lexer = TraditionalLexer(self.lexer_conf.tokens, ignore=self.lexer_conf.ignore, user_callbacks=self.lexer_conf.callbacks)
@@ -135,9 +145,10 @@ class Earley(WithLexer):
         return term.name == token.type
 
 
-class XEarley:
+class XEarley(_ParserFrontend):
     def __init__(self, lexer_conf, parser_conf, options=None, **kw):
         self.token_by_name = {t.name:t for t in lexer_conf.tokens}
+        self.start = parser_conf.start
 
         self._prepare_match(lexer_conf)
         resolve_ambiguity = options.ambiguity == 'resolve'
@@ -167,8 +178,8 @@
             self.regexps[t.name] = re.compile(regexp)
 
-    def parse(self, text):
-        return self.parser.parse(text)
+    def parse(self, text, start):
+        return self._parse(text, start)
 
 
 class XEarley_CompleteLex(XEarley):
     def __init__(self, *args, **kw):
@@ -187,7 +198,7 @@ class CYK(WithLexer):
 
         self.callbacks = parser_conf.callbacks
 
-    def parse(self, text):
+    def parse(self, text, start):
         tokens = list(self.lex(text))
         parse = self._parser.parse(tokens)
         parse = self._transform(parse)
diff --git a/lark/parsers/cyk.py b/lark/parsers/cyk.py
index 2121449..52584a7 100644
--- a/lark/parsers/cyk.py
+++ b/lark/parsers/cyk.py
@@ -89,7 +89,7 @@ class Parser(object):
         self.orig_rules = {rule: rule for rule in rules}
         rules = [self._to_rule(rule) for rule in rules]
         self.grammar = to_cnf(Grammar(rules))
-        self.start = NT(start)
+        self.start = NT(start[0])
 
     def _to_rule(self, lark_rule):
         """Converts a lark rule, (lhs, rhs, callback, options), to a Rule."""
diff --git a/lark/parsers/earley.py b/lark/parsers/earley.py
index 0518174..4d6201b 100644
--- a/lark/parsers/earley.py
+++ b/lark/parsers/earley.py
@@ -273,8 +273,9 @@ class Parser:
         ## Column is now the final column in the parse.
         assert i == len(columns)-1
 
-    def parse(self, stream, start_symbol=None):
-        start_symbol = NonTerminal(start_symbol or self.parser_conf.start)
+    def parse(self, stream, start):
+        assert start, start
+        start_symbol = NonTerminal(start)
 
         columns = [set()]
         to_scan = set()     # The scan buffer. 'Q' in E.Scott's paper.
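The start-resolution rule in _ParserFrontend._parse above, seen from the caller's side (a sketch reusing the grammar from the new test, LALR parser assumed): omitting start is only valid when the grammar was built with a single start rule.

    from lark import Lark

    parser = Lark('''
        a: "x" "a"?
        b: "x" "b"?
    ''', parser='lalr', start=['a', 'b'])

    parser.parse('xa', start='a')    # explicit start rule: fine
    try:
        parser.parse('xa')           # ambiguous: two possible start rules
    except ValueError as e:
        print(e)                     # "Lark initialized with more than 1 possible start rule. ..."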
diff --git a/lark/parsers/grammar_analysis.py b/lark/parsers/grammar_analysis.py
index ab84efb..086349c 100644
--- a/lark/parsers/grammar_analysis.py
+++ b/lark/parsers/grammar_analysis.py
@@ -109,8 +109,10 @@ class GrammarAnalyzer(object):
     def __init__(self, parser_conf, debug=False):
         self.debug = debug
 
-        root_rule = Rule(NonTerminal('$root'), [NonTerminal(parser_conf.start), Terminal('$END')])
-        rules = parser_conf.rules + [root_rule]
+        root_rules = {start: Rule(NonTerminal('$root_' + start), [NonTerminal(start), Terminal('$END')])
+                      for start in parser_conf.start}
+
+        rules = parser_conf.rules + list(root_rules.values())
         self.rules_by_origin = classify(rules, lambda r: r.origin)
 
         if len(rules) != len(set(rules)):
@@ -122,10 +124,11 @@
                 if not (sym.is_term or sym in self.rules_by_origin):
                     raise GrammarError("Using an undefined rule: %s" % sym)  # TODO test validation
 
-        self.start_state = self.expand_rule(root_rule.origin)
+        self.start_states = {start: self.expand_rule(root_rule.origin)
+                             for start, root_rule in root_rules.items()}
 
-        end_rule = RulePtr(root_rule, len(root_rule.expansion))
-        self.end_state = fzset({end_rule})
+        self.end_states = {start: fzset({RulePtr(root_rule, len(root_rule.expansion))})
+                           for start, root_rule in root_rules.items()}
 
         self.FIRST, self.FOLLOW, self.NULLABLE = calculate_sets(rules)
 
diff --git a/lark/parsers/lalr_analysis.py b/lark/parsers/lalr_analysis.py
index ee2f75c..eef1f9b 100644
--- a/lark/parsers/lalr_analysis.py
+++ b/lark/parsers/lalr_analysis.py
@@ -29,10 +29,10 @@ Shift = Action('Shift')
 Reduce = Action('Reduce')
 
 class ParseTable:
-    def __init__(self, states, start_state, end_state):
+    def __init__(self, states, start_states, end_states):
         self.states = states
-        self.start_state = start_state
-        self.end_state = end_state
+        self.start_states = start_states
+        self.end_states = end_states
 
     def serialize(self, memo):
         tokens = Enumerator()
@@ -47,8 +47,8 @@
         return {
             'tokens': tokens.reversed(),
             'states': states,
-            'start_state': self.start_state,
-            'end_state': self.end_state,
+            'start_states': self.start_states,
+            'end_states': self.end_states,
         }
 
     @classmethod
@@ -59,7 +59,7 @@
                     for token, (action, arg) in actions.items()}
             for state, actions in data['states'].items()
         }
-        return cls(states, data['start_state'], data['end_state'])
+        return cls(states, data['start_states'], data['end_states'])
 
 
 class IntParseTable(ParseTable):
@@ -76,9 +76,9 @@
             int_states[ state_to_idx[s] ] = la
 
 
-        start_state = state_to_idx[parse_table.start_state]
-        end_state = state_to_idx[parse_table.end_state]
-        return cls(int_states, start_state, end_state)
+        start_states = {start:state_to_idx[s] for start, s in parse_table.start_states.items()}
+        end_states = {start:state_to_idx[s] for start, s in parse_table.end_states.items()}
+        return cls(int_states, start_states, end_states)
 
 ###}
 
@@ -124,10 +124,10 @@ class LALR_Analyzer(GrammarAnalyzer):
 
             self.states[state] = {k.name:v[0] for k, v in lookahead.items()}
 
-        for _ in bfs([self.start_state], step):
+        for _ in bfs(self.start_states.values(), step):
             pass
 
-        self._parse_table = ParseTable(self.states, self.start_state, self.end_state)
+        self._parse_table = ParseTable(self.states, self.start_states, self.end_states)
 
         if self.debug:
             self.parse_table = self._parse_table
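An illustrative sketch (plain dicts and tuples, not lark's Rule/NonTerminal objects) of the data shape introduced above: every start symbol gets its own '$root_<start>' rule, and both the analyzer and the parse table keep one start state and one end state per entry point.

    # Not lark internals verbatim; just the shape the two hunks above build.
    starts = ['a', 'b']
    root_rules = {s: ('$root_' + s, [s, '$END']) for s in starts}
    print(root_rules)
    # {'a': ('$root_a', ['a', '$END']), 'b': ('$root_b', ['b', '$END'])}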
diff --git a/lark/parsers/lalr_parser.py b/lark/parsers/lalr_parser.py
index aea75ca..39dd5f3 100644
--- a/lark/parsers/lalr_parser.py
+++ b/lark/parsers/lalr_parser.py
@@ -39,19 +39,22 @@ class LALR_Parser(object):
 
 class _Parser:
     def __init__(self, parse_table, callbacks):
         self.states = parse_table.states
-        self.start_state = parse_table.start_state
-        self.end_state = parse_table.end_state
+        self.start_states = parse_table.start_states
+        self.end_states = parse_table.end_states
         self.callbacks = callbacks
 
-    def parse(self, seq, set_state=None):
+    def parse(self, seq, start, set_state=None):
         token = None
         stream = iter(seq)
         states = self.states
-        state_stack = [self.start_state]
+        start_state = self.start_states[start]
+        end_state = self.end_states[start]
+
+        state_stack = [start_state]
         value_stack = []
 
-        if set_state: set_state(self.start_state)
+        if set_state: set_state(start_state)
 
         def get_action(token):
             state = state_stack[-1]
@@ -81,7 +84,7 @@
         for token in stream:
             while True:
                 action, arg = get_action(token)
-                assert arg != self.end_state
+                assert arg != end_state
 
                 if action is Shift:
                     state_stack.append(arg)
@@ -95,7 +98,7 @@
         while True:
             _action, arg = get_action(token)
             if _action is Shift:
-                assert arg == self.end_state
+                assert arg == end_state
                 val ,= value_stack
                 return val
             else:
diff --git a/tests/test_parser.py b/tests/test_parser.py
index d582878..3238ead 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -1523,6 +1523,15 @@ def _make_parser_test(LEXER, PARSER):
             parser3 = Lark.deserialize(d, namespace, m)
 
             self.assertEqual(parser3.parse('ABC'), Tree('start', [Tree('b', [])]) )
 
+        def test_multi_start(self):
+            parser = _Lark('''
+                a: "x" "a"?
+                b: "x" "b"?
+            ''', start=['a', 'b'])
+
+            self.assertEqual(parser.parse('xa', 'a'), Tree('a', []))
+            self.assertEqual(parser.parse('xb', 'b'), Tree('b', []))
+
 
     _NAME = "Test" + PARSER.capitalize() + LEXER.capitalize()
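The same behaviour the new test exercises, as a standalone usage sketch (Lark used directly instead of the test suite's parametrized _Lark helper; LALR parser assumed):

    from lark import Lark, Tree

    parser = Lark('''
        a: "x" "a"?
        b: "x" "b"?
    ''', parser='lalr', start=['a', 'b'])

    assert parser.parse('xa', start='a') == Tree('a', [])
    assert parser.parse('xb', start='b') == Tree('b', [])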