| @@ -9,6 +9,7 @@ _R = TypeVar('_R') | |||||
| _FUNC = Callable[..., _T] | _FUNC = Callable[..., _T] | ||||
| _DECORATED = Union[_FUNC, type] | _DECORATED = Union[_FUNC, type] | ||||
| class Transformer(ABC, Generic[_T]): | class Transformer(ABC, Generic[_T]): | ||||
| def __init__(self, visit_tokens: bool = True) -> None: | def __init__(self, visit_tokens: bool = True) -> None: | ||||
| @@ -38,6 +39,14 @@ class Transformer_InPlace(Transformer): | |||||
| pass | pass | ||||
| class Transformer_NonRecursive(Transformer): | |||||
| pass | |||||
| class Transformer_InPlaceRecursive(Transformer): | |||||
| pass | |||||
| class VisitorBase: | class VisitorBase: | ||||
| pass | pass | ||||
| @@ -73,10 +82,10 @@ _InterMethod = Callable[[Type[Interpreter], _T], _R] | |||||
| def v_args( | def v_args( | ||||
| inline: bool = False, | |||||
| meta: bool = False, | |||||
| tree: bool = False, | |||||
| wrapper: Callable = None | |||||
| inline: bool = False, | |||||
| meta: bool = False, | |||||
| tree: bool = False, | |||||
| wrapper: Callable = None | |||||
| ) -> Callable[[_DECORATED], _DECORATED]: | ) -> Callable[[_DECORATED], _DECORATED]: | ||||
| ... | ... | ||||
| @@ -218,6 +218,8 @@ class Transformer_NonRecursive(Transformer): | |||||
| else: | else: | ||||
| args = [] | args = [] | ||||
| stack.append(self._call_userfunc(x, args)) | stack.append(self._call_userfunc(x, args)) | ||||
| elif self.__visit_tokens__ and isinstance(x, Token): | |||||
| stack.append(self._call_userfunc_token(x)) | |||||
| else: | else: | ||||
| stack.append(x) | stack.append(x) | ||||
| @@ -8,7 +8,8 @@ import functools | |||||
| from lark.tree import Tree | from lark.tree import Tree | ||||
| from lark.lexer import Token | from lark.lexer import Token | ||||
| from lark.visitors import Visitor, Visitor_Recursive, Transformer, Interpreter, visit_children_decor, v_args, Discard | |||||
| from lark.visitors import Visitor, Visitor_Recursive, Transformer, Interpreter, visit_children_decor, v_args, Discard, Transformer_InPlace, \ | |||||
| Transformer_InPlaceRecursive, Transformer_NonRecursive | |||||
| class TestTrees(TestCase): | class TestTrees(TestCase): | ||||
| @@ -232,6 +233,20 @@ class TestTrees(TestCase): | |||||
| x = MyTransformer().transform( t ) | x = MyTransformer().transform( t ) | ||||
| self.assertEqual(x, t2) | self.assertEqual(x, t2) | ||||
| def test_transformer_variants(self): | |||||
| tree = Tree('start', [Tree('add', [Token('N', '1'), Token('N', '2')]), Tree('add', [Token('N', '3'), Token('N', '4')])]) | |||||
| for base in (Transformer, Transformer_InPlace, Transformer_NonRecursive, Transformer_InPlaceRecursive): | |||||
| class T(base): | |||||
| def add(self, children): | |||||
| return sum(children) | |||||
| def N(self, token): | |||||
| return int(token) | |||||
| copied = copy.deepcopy(tree) | |||||
| result = T().transform(copied) | |||||
| self.assertEqual(result, Tree('start', [3, 7])) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||