diff --git a/lark/load_grammar.py b/lark/load_grammar.py index 786b4a3..01739a8 100644 --- a/lark/load_grammar.py +++ b/lark/load_grammar.py @@ -24,6 +24,8 @@ inline_args = v_args(inline=True) __path__ = os.path.dirname(__file__) IMPORT_PATHS = [os.path.join(__path__, 'grammars')] +EXT = '.lark' + _RE_FLAGS = 'imslux' def is_terminal(sym): @@ -139,16 +141,16 @@ RULES = { 'ignore': ['_IGNORE expansions _NL'], 'declare': ['_DECLARE _declare_args _NL'], 'import': ['_IMPORT _import_path _NL', - '_IMPORT _import_path _LPAR list_name _RPAR _NL', + '_IMPORT _import_path _LPAR name_list _RPAR _NL', '_IMPORT _import_path _TO TERMINAL _NL'], - '_import_path': ['import_common', 'import_rel'], - 'import_common': ['_import_args'], + '_import_path': ['import_lib', 'import_rel'], + 'import_lib': ['_import_args'], 'import_rel': ['_DOT _import_args'], '_import_args': ['name', '_import_args _DOT name'], - 'list_name': ['_list_name'], - '_list_name': ['name', '_list_name _COMMA name'], + 'name_list': ['_name_list'], + '_name_list': ['name', '_name_list _COMMA name'], '_declare_args': ['name', '_declare_args name'], 'literal': ['REGEXP', 'STRING'], @@ -506,25 +508,19 @@ class Grammar: _imported_grammars = {} -def import_grammar(grammar_path, base_path=None): +def import_grammar(grammar_path, base_paths=[]): if grammar_path not in _imported_grammars: - if base_path is None: - import_paths = IMPORT_PATHS - else: - import_paths = [base_path] + IMPORT_PATHS - found = False + import_paths = base_paths + IMPORT_PATHS for import_path in import_paths: - try: + with suppress(IOError): with open(os.path.join(import_path, grammar_path)) as f: text = f.read() grammar = load_grammar(text, grammar_path) _imported_grammars[grammar_path] = grammar - found = True break - except FileNotFoundError: - pass - if not found: - raise FileNotFoundError(grammar_path) + else: + open(grammar_path) + assert False return _imported_grammars[grammar_path] @@ -640,18 +636,25 @@ class GrammarLoader: t ,= stmt.children ignore.append(t) elif stmt.data == 'import': - dotted_path = stmt.children[0].children + if len(stmt.children) > 1: + path_node, arg1 = stmt.children + else: + path_node ,= stmt.children + arg1 = None + + dotted_path = path_node.children - if len(stmt.children) > 1 and hasattr(stmt.children[1], 'children'): # Multi import - names = stmt.children[1].children + if isinstance(arg1, Tree): # Multi import + names = arg1.children aliases = names # Can't have aliased multi import, so all aliases will be the same as names - grammar_path = os.path.join(*dotted_path) + '.lark' else: # Single import names = [dotted_path[-1]] # Get name from dotted path - aliases = [stmt.children[1] if len(stmt.children) > 1 else dotted_path[-1]] # Aliases if exist - grammar_path = os.path.join(*dotted_path[:-1]) + '.lark' # Exclude name from grammar path + aliases = [arg1] if arg1 else names # Aliases if exist + dotted_path = dotted_path[:-1] + + grammar_path = os.path.join(*dotted_path) + EXT - if stmt.children[0].data == 'import_common': # Regular import + if path_node.data == 'import_lib': # Import from library g = import_grammar(grammar_path) else: # Relative import if grammar_name == '': # Import relative to script file path if grammar is coded in script @@ -659,7 +662,7 @@ class GrammarLoader: else: base_file = grammar_name # Import relative to grammar file path if external grammar file base_path = os.path.split(base_file)[0] - g = import_grammar(grammar_path, base_path=base_path) + g = import_grammar(grammar_path, base_paths=[base_path]) for name, alias in zip(names, aliases): token_options = dict(g.token_defs)[name] diff --git a/tests/test_parser.py b/tests/test_parser.py index 0cb74e2..c4d7147 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -976,6 +976,20 @@ def _make_parser_test(LEXER, PARSER): x = l.parse('12 capybaras') self.assertEqual(x.children, ['12', 'capybaras']) + def test_import_errors(self): + grammar = """ + start: NUMBER WORD + + %import .grammars.bad_test.NUMBER + """ + self.assertRaises(IOError, _Lark, grammar) + + grammar = """ + start: NUMBER WORD + + %import bad_test.NUMBER + """ + self.assertRaises(IOError, _Lark, grammar) @unittest.skipIf(PARSER != 'earley', "Currently only Earley supports priority in rules") def test_earley_prioritization(self):