diff --git a/parser.py b/parser.py index 7d3056ce37d78decc40c1b236ed6d9708e320f4e_cGFyc2VyLnB5..0d030328411294df042e9fbf1c3a8cdf22bc0c19_cGFyc2VyLnB5 100644 --- a/parser.py +++ b/parser.py @@ -47,8 +47,7 @@ from warnings import warn from rql.stmts import Union, Select, Delete, Insert, Set from rql.nodes import * -def warn(*args): - raise Exception() + def unquote(string): """Remove quotes from a string.""" if string.startswith('"'): diff --git a/stcheck.py b/stcheck.py index 7d3056ce37d78decc40c1b236ed6d9708e320f4e_c3RjaGVjay5weQ==..0d030328411294df042e9fbf1c3a8cdf22bc0c19_c3RjaGVjay5weQ== 100644 --- a/stcheck.py +++ b/stcheck.py @@ -282,18 +282,9 @@ def annotate(self, node): #assert not node.annotated node.accept(self) - #node.annotated = True - - visit_insert = visit_delete = visit_set = lambda s,n: None - - def visit_union(self, node): - for select in node.children: - self.visit_select(select) - - def visit_select(self, node): - if node.with_ is not None: - for subquery in node.with_: - self.visit_union(subquery.query) + node.annotated = True + + def _visit_stmt(self, node): for i, term in enumerate(node.selection): for func in term.iget_nodes(Function): if func.descr().aggregat: @@ -305,6 +296,18 @@ vref.variable.set_scope(node) if node.where is not None: node.where.accept(self, node) + + visit_insert = visit_delete = visit_set = _visit_stmt + + def visit_union(self, node): + for select in node.children: + self.visit_select(select) + + def visit_select(self, node): + if node.with_ is not None: + for subquery in node.with_: + self.visit_union(subquery.query) + self._visit_stmt(node) def rewrite_shared_optional(self, exists, var): """if variable is shared across multiple scopes, need some tree diff --git a/stmts.py b/stmts.py index 7d3056ce37d78decc40c1b236ed6d9708e320f4e_c3RtdHMucHk=..0d030328411294df042e9fbf1c3a8cdf22bc0c19_c3RtdHMucHk= 100644 --- a/stmts.py +++ b/stmts.py @@ -109,7 +109,8 @@ # default values for optional instance attributes, set on the instance when # used - schema = None # ISchema + schema = None # ISchema + annotated = False # set by the annotator # def __init__(self): # Node.__init__(self) @@ -204,6 +205,23 @@ return new # union specific methods ################################################## + + def locate_subquery(self, col, etype, kwargs=None): + if len(self.children) == 1: + return self.children[0] + try: + return self._subq_cache[(col, etype)] + except AttributeError: + self._subq_cache = {} + except KeyError: + pass + for select in self.children: + term = select.selection[col] + for i, solution in enumerate(select.solutions): + if term.get_type(solution, kwargs) == etype: + self._subq_cache[(col, etype)] = select + return select + raise Exception('internal error, %s not found on col %s' % (etype, col)) def set_limit(self, limit): if limit is not None and (not isinstance(limit, (int, long)) or limit <= 0):