@@ -17,9 +17,6 @@ class LexerConf(Serialize):
         self.skip_validation = skip_validation
         self.use_bytes = use_bytes

-    def _deserialize(self):
-        self.callbacks = {} # TODO
-
 ###}

 class ParserConf:
@@ -11,7 +11,7 @@ from .common import LexerConf, ParserConf

 from .lexer import Lexer, TraditionalLexer, TerminalDef, UnexpectedToken
 from .parse_tree_builder import ParseTreeBuilder
-from .parser_frontends import get_frontend
+from .parser_frontends import get_frontend, _get_lexer_callbacks
 from .grammar import Rule

 import re
@@ -278,12 +278,10 @@ class Lark(Serialize):
                     rule.options.priority = None

         # TODO Deprecate lexer_callbacks?
-        lexer_callbacks = dict(self.options.lexer_callbacks)
-        if self.options.transformer:
-            t = self.options.transformer
-            for term in self.terminals:
-                if hasattr(t, term.name):
-                    lexer_callbacks[term.name] = getattr(t, term.name)
+        lexer_callbacks = (_get_lexer_callbacks(self.options.transformer, self.terminals)
+                           if self.options.transformer
+                           else {})
+        lexer_callbacks.update(self.options.lexer_callbacks)

         self.lexer_conf = LexerConf(self.terminals, re_module, self.ignore_tokens, self.options.postlex, lexer_callbacks, self.options.g_regex_flags, use_bytes=self.options.use_bytes)

@@ -344,7 +342,14 @@ class Lark(Serialize):
         self.rules = [Rule.deserialize(r, memo) for r in data['rules']]
         self.source = '<deserialized>'
         self._prepare_callbacks()
-        self.parser = self.parser_class.deserialize(data['parser'], memo, self._callbacks, self.options.postlex, re_module)
+        self.parser = self.parser_class.deserialize(
+            data['parser'],
+            memo,
+            self._callbacks,
+            self.options.postlex,
+            self.options.transformer,
+            re_module
+        )
         return self

     @classmethod
@@ -1,6 +1,6 @@
 from .utils import get_regexp_width, Serialize
 from .parsers.grammar_analysis import GrammarAnalyzer
-from .lexer import TraditionalLexer, ContextualLexer, Lexer, Token
+from .lexer import TraditionalLexer, ContextualLexer, Lexer, Token, TerminalDef
 from .parsers import earley, xearley, cyk
 from .parsers.lalr_parser import LALR_Parser
 from .grammar import Rule
@@ -58,6 +58,15 @@ class _ParserFrontend(Serialize):
         return self.parser.parse(input, start, *args)


+def _get_lexer_callbacks(transformer, terminals):
+    result = {}
+    for terminal in terminals:
+        callback = getattr(transformer, terminal.name, None)
+        if callback is not None:
+            result[terminal.name] = callback
+    return result
+
+
 class WithLexer(_ParserFrontend):
     lexer = None
     parser = None
@@ -73,13 +82,18 @@ class WithLexer(_ParserFrontend):
         self.postlex = lexer_conf.postlex

     @classmethod
-    def deserialize(cls, data, memo, callbacks, postlex, re_module):
+    def deserialize(cls, data, memo, callbacks, postlex, transformer, re_module):
         inst = super(WithLexer, cls).deserialize(data, memo)
+
         inst.postlex = postlex
         inst.parser = LALR_Parser.deserialize(inst.parser, memo, callbacks)
+
+        terminals = [item for item in memo.values() if isinstance(item, TerminalDef)]
+        inst.lexer_conf.callbacks = _get_lexer_callbacks(transformer, terminals)
         inst.lexer_conf.re_module = re_module
         inst.lexer_conf.skip_validation=True
         inst.init_lexer()
+
         return inst

     def _serialize(self, data, memo):
@@ -229,4 +243,3 @@ class CYK(WithLexer):

     def _apply_callback(self, tree):
         return self.callbacks[tree.rule](tree.children)
-
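
Note: _get_lexer_callbacks (added above) simply collects the transformer methods
whose names match terminal names. A minimal, self-contained sketch of that
behavior (FakeTerminal is a stand-in for lark's TerminalDef; only .name is used):

    # Stand-in for lark's TerminalDef; only the .name attribute matters here.
    class FakeTerminal:
        def __init__(self, name):
            self.name = name

    class MyTransformer:
        def NUMBER(self, token):  # name matches a terminal, so it is collected
            return token

    def _get_lexer_callbacks(transformer, terminals):
        result = {}
        for terminal in terminals:
            callback = getattr(transformer, terminal.name, None)
            if callback is not None:
                result[terminal.name] = callback
        return result

    callbacks = _get_lexer_callbacks(MyTransformer(), [FakeTerminal("NUMBER"), FakeTerminal("WORD")])
    assert list(callbacks) == ["NUMBER"]  # no WORD method on the transformer, so it is skipped
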
@@ -106,6 +106,33 @@ class TestStandalone(TestCase):
         x = l.parse('(\n)\n')
         self.assertEqual(x, Tree('start', []))

+    def test_transformer(self):
+        grammar = r"""
+        start: some_rule "(" SOME_TERMINAL ")"
+        some_rule: SOME_TERMINAL
+        SOME_TERMINAL: /[A-Za-z_][A-Za-z0-9_]*/
+        """
+        context = self._create_standalone(grammar)
+        _Lark = context["Lark_StandAlone"]
+        _Token = context["Token"]
+        _Tree = context["Tree"]
+
+        class MyTransformer(context["Transformer"]):
+            def SOME_TERMINAL(self, token):
+                return _Token("SOME_TERMINAL", "token is transformed")
+
+            def some_rule(self, children):
+                return _Tree("rule_is_transformed", [])
+
+        parser = _Lark(transformer=MyTransformer())
+        self.assertEqual(
+            parser.parse("FOO(BAR)"),
+            _Tree("start", [
+                _Tree("rule_is_transformed", []),
+                _Token("SOME_TERMINAL", "token is transformed")
+            ])
+        )
+


 if __name__ == '__main__':
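
Note: the new test drives the same flow you would use by hand. Roughly, as a
sketch (my_grammar.lark and my_parser.py are placeholder names, and the grammar
file is assumed to contain the grammar from the test above):

    # 1) Generate a standalone parser module from the grammar file:
    #      python -m lark.tools.standalone my_grammar.lark > my_parser.py
    #
    # 2) Pass a transformer to the generated module. Terminal-named methods
    #    now run as lexer callbacks; rule-named methods transform matched
    #    rules inline during the parse.
    from my_parser import Lark_StandAlone, Transformer, Token, Tree

    class MyTransformer(Transformer):
        def SOME_TERMINAL(self, token):
            return Token("SOME_TERMINAL", token.upper())

        def some_rule(self, children):
            return Tree("rule_is_transformed", children)

    parser = Lark_StandAlone(transformer=MyTransformer())
    print(parser.parse("foo(bar)"))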