This repo contains code to mirror other repos. It also contains the code that is getting mirrored.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

402 lines
13 KiB

  1. """This module implements a CYK parser."""
  2. from collections import defaultdict
  3. import itertools
  4. import re
  5. from ..common import ParseError, Terminal, Terminal_Regexp
  6. from ..lexer import Token
  7. from ..tree import Tree
  8. def TypeName(x):
  9. return type(x).__name__
  10. class Symbol(object):
  11. """Any grammar symbol."""
  12. def __init__(self, s):
  13. self.s = s
  14. def __repr__(self):
  15. return '%s(%s)' % (TypeName(self), str(self))
  16. def __str__(self):
  17. return str(self.s)
  18. def __eq__(self, other):
  19. return str(self) == str(other)
  20. def __ne__(self, other):
  21. return not self.__eq__(other)
  22. def __hash__(self):
  23. return hash(TypeName(self) + '&' + self.__str__())
  24. class T(Symbol):
  25. """Terminal."""
  26. def __init__(self, s):
  27. super(T, self).__init__(s)
  28. self.regexp = re.compile(s)
  29. def match(self, s):
  30. m = self.regexp.match(s)
  31. return bool(m) and len(m.group(0)) == len(s)
  32. def __eq__(self, other):
  33. return super(T, self).__eq__(other) and isinstance(other, T)
  34. class NT(Symbol):
  35. """Non-terminal."""
  36. def __eq__(self, other):
  37. return super(NT, self).__eq__(other) and isinstance(other, NT)
  38. class Rule(object):
  39. """Context-free grammar rule."""
  40. def __init__(self, lhs, rhs, weight, alias):
  41. super(Rule, self).__init__()
  42. assert isinstance(lhs, NT), lhs
  43. assert all(isinstance(x, NT) or isinstance(x, T) for x in rhs), rhs
  44. self.lhs = lhs
  45. self.rhs = rhs
  46. self.weight = weight
  47. self.alias = alias
  48. def __str__(self):
  49. return '%s -> %s' % (str(self.lhs), ' '.join(str(x) for x in self.rhs))
  50. def __repr__(self):
  51. return str(self)
  52. def __hash__(self):
  53. return hash(self.__repr__())
  54. def __eq__(self, other):
  55. return self.lhs == other.lhs and self.rhs == other.rhs
  56. def __ne__(self, other):
  57. return not self.__eq__(other)
  58. class Grammar(object):
  59. """Context-free grammar."""
  60. def __init__(self, rules):
  61. super(Grammar, self).__init__()
  62. self.rules = sorted(rules, key=lambda x: str(x))
  63. def __eq__(self, other):
  64. return set(self.rules) == set(other.rules)
  65. def __str__(self):
  66. return '\n' + '\n'.join(sorted(x.__repr__() for x in self.rules)) + '\n'
  67. def __repr__(self):
  68. return str(self)
  69. # Parse tree data structures
  70. class RuleNode(object):
  71. """A node in the parse tree, which also contains the full rhs rule."""
  72. def __init__(self, rule, children, weight=0):
  73. super(RuleNode, self).__init__()
  74. self.rule = rule
  75. self.children = children
  76. self.weight = weight
  77. def __repr__(self):
  78. return 'RuleNode(%s, [%s])' % (repr(self.rule.lhs), ', '.join(
  79. str(x) for x in self.children))
  80. def __hash__(self):
  81. return hash(self.__repr__())
  82. class Node(object):
  83. """A node in the parse tree."""
  84. def __init__(self, lhs, children):
  85. super(Node, self).__init__()
  86. self.lhs = lhs
  87. self.children = children
  88. def __repr__(self):
  89. return 'Node(%s, [%s])' % (repr(self.lhs), ', '.join(
  90. str(x) for x in self.children))
  91. def __hash__(self):
  92. return hash(self.__repr__())
class Parser(object):
    """Parser wrapper.

    Converts lark rules into this module's Rule objects, normalizes the
    grammar to CNF once at construction, and translates CYK parse trees
    back into lark Trees.
    """

    def __init__(self, rules, start):
        super(Parser, self).__init__()
        # Original lark rules, keyed by alias, so _ToTree can recover the
        # lark-side rule for each parsed RuleNode.
        self.orig_rules = {rule.alias: rule for rule in rules}
        rules = [self._ToRule(rule) for rule in rules]
        # The CNF conversion is done once up front; parse() reuses it.
        self.grammar = ToCnf(Grammar(rules))
        self.start = NT(start)

    def _ToRule(self, lark_rule):
        """Converts a lark rule, (lhs, rhs, callback, options), to a Rule."""
        # Terminals become T (regexp) symbols; anything else in the
        # expansion is treated as a non-terminal name.
        # NOTE(review): priority defaults to 0 when options are absent or
        # falsy — confirm this matches lark's priority semantics.
        return Rule(
            NT(lark_rule.origin), [
                T(x.data) if (isinstance(x, Terminal_Regexp) or
                              isinstance(x, Terminal)) else NT(x)
                for x in lark_rule.expansion
            ], weight=lark_rule.options.priority if lark_rule.options and lark_rule.options.priority else 0, alias=lark_rule.alias)

    def parse(self, tokenized):  # pylint: disable=invalid-name
        """Parses input, which is a list of tokens.

        Raises:
          ParseError: if no start-symbol rule spans the whole input.
        """
        table, trees = _Parse(tokenized, self.grammar)
        # Check if the parse succeeded: some rule covering the full span
        # (0, len-1) must have the start symbol on its left-hand side.
        if all(r.lhs != self.start for r in table[(0, len(tokenized) - 1)]):
            raise ParseError('Parsing failed.')
        parse = trees[(0, len(tokenized) - 1)][NT(self.start)]
        return self._ToTree(RevertCnf(parse))

    def _ToTree(self, rule_node):
        """Converts a RuleNode parse tree to a lark Tree."""
        orig_rule = self.orig_rules[rule_node.rule.alias]
        children = []
        for i, child in enumerate(rule_node.children):
            if isinstance(child, RuleNode):
                children.append(self._ToTree(child))
            elif isinstance(child, Terminal_Regexp):
                # NOTE(review): leaf children produced by RevertCnf look
                # like T symbols carrying the matched text in .s — confirm
                # whether this Terminal_Regexp branch is still reachable.
                children.append(Token(orig_rule.expansion[i].name, child.s))
            else:
                children.append(Token(orig_rule.expansion[i], child.s))
        return Tree(orig_rule.origin, children, rule=orig_rule)
  129. def PrintParse(node, indent=0):
  130. if isinstance(node, RuleNode):
  131. print(' ' * (indent * 2) + str(node.rule.lhs))
  132. for child in node.children:
  133. PrintParse(child, indent + 1)
  134. else:
  135. print(' ' * (indent * 2) + str(node.s))
  136. def _Parse(s, g):
  137. """Parses sentence 's' using CNF grammar 'g'."""
  138. # The CYK table. Indexed with a 2-tuple: (start pos, end pos)
  139. table = defaultdict(set)
  140. # Top-level structure is similar to the CYK table. Each cell is a dict from
  141. # rule name to the best (lightest) tree for that rule.
  142. trees = defaultdict(dict)
  143. # Populate base case with existing terminal production rules
  144. for i, w in enumerate(s):
  145. for terminal, rules in g.terminal_rules.iteritems():
  146. if terminal.match(w):
  147. for rule in rules:
  148. table[(i, i)].add(rule)
  149. if (rule.lhs not in trees[(i, i)] or
  150. rule.weight < trees[(i, i)][rule.lhs].weight):
  151. trees[(i, i)][rule.lhs] = RuleNode(rule, [T(w)], weight=rule.weight)
  152. # Iterate over lengths of sub-sentences
  153. for l in xrange(2, len(s) + 1):
  154. # Iterate over sub-sentences with the given length
  155. for i in xrange(len(s) - l + 1):
  156. # Choose partition of the sub-sentence in [1, l)
  157. for p in xrange(i + 1, i + l):
  158. span1 = (i, p - 1)
  159. span2 = (p, i + l - 1)
  160. for r1, r2 in itertools.product(table[span1], table[span2]):
  161. for rule in g.nonterminal_rules.get((r1.lhs, r2.lhs), []):
  162. table[(i, i + l - 1)].add(rule)
  163. r1_tree = trees[span1][r1.lhs]
  164. r2_tree = trees[span2][r2.lhs]
  165. rule_total_weight = rule.weight + r1_tree.weight + r2_tree.weight
  166. if (rule.lhs not in trees[(i, i + l - 1)] or
  167. rule_total_weight < trees[(i, i + l - 1)][rule.lhs].weight):
  168. trees[(i, i + l - 1)][rule.lhs] = RuleNode(rule, [r1_tree, r2_tree], weight=rule_total_weight)
  169. return table, trees
  170. # This section implements context-free grammar converter to Chomsky normal form.
  171. # It also implements a conversion of parse trees from its CNF to the original
  172. # grammar.
  173. # Overview:
  174. # Applies the following operations in this order:
  175. # * TERM: Eliminates non-solitary terminals from all rules
  176. # * BIN: Eliminates rules with more than 2 symbols on their right-hand-side.
  177. # * UNIT: Eliminates non-terminal unit rules
  178. #
  179. # The following grammar characteristics aren't featured:
  180. # * Start symbol appears on RHS
  181. # * Empty rules (epsilon rules)
  182. class CnfWrapper(object):
  183. """CNF wrapper for grammar.
  184. Validates that the input grammar is CNF and provides helper data structures.
  185. """
  186. def __init__(self, grammar):
  187. super(CnfWrapper, self).__init__()
  188. self.grammar = grammar
  189. self.rules = grammar.rules
  190. self.terminal_rules = defaultdict(list)
  191. self.nonterminal_rules = defaultdict(list)
  192. for r in self.rules:
  193. # Validate that the grammar is CNF and populate auxiliary data structures.
  194. assert isinstance(r.lhs, NT), r
  195. assert len(r.rhs) in [1, 2], r
  196. if len(r.rhs) == 1 and isinstance(r.rhs[0], T):
  197. self.terminal_rules[r.rhs[0]].append(r)
  198. elif len(r.rhs) == 2 and all(isinstance(x, NT) for x in r.rhs):
  199. self.nonterminal_rules[tuple(r.rhs)].append(r)
  200. else:
  201. assert False, r
  202. def __eq__(self, other):
  203. return self.grammar == other.grammar
  204. def __repr__(self):
  205. return self.grammar.__repr__()
  206. class UnitSkipRule(Rule):
  207. """A rule that records NTs that were skipped during transformation."""
  208. def __init__(self, lhs, rhs, skipped_rules, weight, alias):
  209. super(UnitSkipRule, self).__init__(lhs, rhs, weight, alias)
  210. self.skipped_rules = skipped_rules
  211. def __eq__(self, other):
  212. return (super(UnitSkipRule, self).__eq__(other) and
  213. isinstance(other, type(self)) and
  214. self.skipped_rules == other.skipped_rules)
  215. def BuildUnitSkipRule(unit_rule, target_rule):
  216. skipped_rules = []
  217. if isinstance(unit_rule, UnitSkipRule):
  218. skipped_rules += unit_rule.skipped_rules
  219. skipped_rules.append(target_rule)
  220. if isinstance(target_rule, UnitSkipRule):
  221. skipped_rules += target_rule.skipped_rules
  222. return UnitSkipRule(unit_rule.lhs, target_rule.rhs, skipped_rules,
  223. weight=unit_rule.weight + target_rule.weight, alias=unit_rule.alias)
  224. def GetAnyNtUnitRule(g):
  225. """Returns a non-terminal unit rule from 'g', or None if there is none."""
  226. for rule in g.rules:
  227. if len(rule.rhs) == 1 and isinstance(rule.rhs[0], NT):
  228. return rule
  229. return None
  230. def RemoveUnitRule(g, rule):
  231. """Removes 'rule' from 'g' without changing the langugage produced by 'g'."""
  232. new_rules = [x for x in g.rules if x != rule]
  233. refs = [x for x in g.rules if x.lhs == rule.rhs[0]]
  234. for ref in refs:
  235. new_rules.append(BuildUnitSkipRule(rule, ref))
  236. return Grammar(new_rules)
  237. def Split(rule):
  238. """Splits a rule whose len(rhs) > 2 into shorter rules."""
  239. # if len(rule.rhs) <= 2:
  240. # return [rule]
  241. rule_str = str(rule.lhs) + '__' + '_'.join(str(x) for x in rule.rhs)
  242. rule_name = '__SP_%s' % (rule_str) + '_%d'
  243. new_rules = [Rule(rule.lhs, [rule.rhs[0], NT(rule_name % 1)], weight=rule.weight, alias=rule.alias)]
  244. for i in xrange(1, len(rule.rhs) - 2):
  245. new_rules.append(
  246. Rule(NT(rule_name % i),
  247. [rule.rhs[i], NT(rule_name % (i + 1))], weight=0, alias='Split'))
  248. new_rules.append(Rule(NT(rule_name % (len(rule.rhs) - 2)), rule.rhs[-2:], weight=0, alias='Split'))
  249. return new_rules
  250. def Term(g):
  251. """Applies the TERM rule on 'g' (see top comment)."""
  252. all_t = {x for rule in g.rules for x in rule.rhs if isinstance(x, T)}
  253. t_rules = {t: Rule(NT('__T_%s' % str(t)), [t], weight=0, alias='Term') for t in all_t}
  254. new_rules = []
  255. for rule in g.rules:
  256. if len(rule.rhs) > 1 and any(isinstance(x, T) for x in rule.rhs):
  257. new_rhs = [t_rules[x].lhs if isinstance(x, T) else x for x in rule.rhs]
  258. new_rules.append(Rule(rule.lhs, new_rhs, weight=rule.weight, alias=rule.alias))
  259. new_rules.extend(v for k, v in t_rules.iteritems() if k in rule.rhs)
  260. else:
  261. new_rules.append(rule)
  262. return Grammar(new_rules)
  263. def Bin(g):
  264. """Applies the BIN rule to 'g' (see top comment)."""
  265. new_rules = []
  266. for rule in g.rules:
  267. if len(rule.rhs) > 2:
  268. new_rules.extend(Split(rule))
  269. else:
  270. new_rules.append(rule)
  271. return Grammar(new_rules)
  272. def Unit(g):
  273. """Applies the UNIT rule to 'g' (see top comment)."""
  274. nt_unit_rule = GetAnyNtUnitRule(g)
  275. while nt_unit_rule:
  276. g = RemoveUnitRule(g, nt_unit_rule)
  277. nt_unit_rule = GetAnyNtUnitRule(g)
  278. return g
  279. def ToCnf(g):
  280. """Creates a CNF grammar from a general context-free grammar 'g'."""
  281. g = Unit(Bin(Term(g)))
  282. return CnfWrapper(g)
  283. def UnrollUnitSkipRule(lhs, orig_rhs, skipped_rules, children, weight, alias):
  284. if not skipped_rules:
  285. return RuleNode(Rule(lhs, orig_rhs, weight=weight, alias=alias), children, weight=weight)
  286. else:
  287. weight = weight - skipped_rules[0].weight
  288. return RuleNode(
  289. Rule(lhs, [skipped_rules[0].lhs], weight=weight, alias=alias), [
  290. UnrollUnitSkipRule(skipped_rules[0].lhs, orig_rhs,
  291. skipped_rules[1:], children,
  292. skipped_rules[0].weight, skipped_rules[0].alias)
  293. ], weight=weight)
  294. def RevertCnf(node):
  295. """Reverts a parse tree (RuleNode) to its original non-CNF form (Node)."""
  296. if isinstance(node, T):
  297. return node
  298. # Reverts TERM rule.
  299. if node.rule.lhs.s.startswith('__T_'):
  300. return node.children[0]
  301. else:
  302. children = []
  303. reverted_children = [RevertCnf(x) for x in node.children]
  304. for child in reverted_children:
  305. # Reverts BIN rule.
  306. if isinstance(child, RuleNode) and child.rule.lhs.s.startswith('__SP_'):
  307. children.extend(child.children)
  308. else:
  309. children.append(child)
  310. # Reverts UNIT rule.
  311. if isinstance(node.rule, UnitSkipRule):
  312. return UnrollUnitSkipRule(node.rule.lhs, node.rule.rhs,
  313. node.rule.skipped_rules, children,
  314. node.rule.weight, node.rule.alias)
  315. else:
  316. return RuleNode(node.rule, children)