# HG changeset patch # User Sylvain Thénault <sylvain.thenault@logilab.fr> # Date 1269420510 -3600 # Wed Mar 24 09:48:30 2010 +0100 # Node ID 74eba13308fec99d2e3cf4ec1340f1222a2ec629 # Parent 3bfc7423b041546462e12ac65e0d3bdddd44d498 move some checks done in the annotator to the checker diff --git a/__init__.py b/__init__.py --- a/__init__.py +++ b/__init__.py @@ -45,7 +45,7 @@ if uid_func_mapping: for key in uid_func_mapping: special_relations[key] = 'uid' - self._checker = RQLSTChecker(schema) + self._checker = RQLSTChecker(schema, special_relations) self._annotator = RQLSTAnnotator(schema, special_relations) self._analyser_lock = threading.Lock() if resolver_class is None: diff --git a/stcheck.py b/stcheck.py --- a/stcheck.py +++ b/stcheck.py @@ -30,6 +30,25 @@ def __init__(self, node): self.node = node +VAR_SELECTED = 1 +VAR_HAS_TYPE_REL = 2 +VAR_HAS_UID_REL = 4 +VAR_HAS_REL = 8 + +class STCheckState(object): + def __init__(self): + self.errors = [] + self.under_not = [] + self.var_info = {} + + def error(self, msg): + self.errors.append(msg) + + def add_var_info(self, var, vi): + try: + self.var_info[var] |= vi + except KeyError: + self.var_info[var] = vi class RQLSTChecker(object): """Check a RQL syntax tree for errors not detected on parsing. @@ -42,37 +61,38 @@ errors due to a bad rql input """ - def __init__(self, schema): + def __init__(self, schema, special_relations=None): self.schema = schema + self.special_relations = special_relations or {} def check(self, node): - errors = [] - self._visit(node, errors) - if errors: - raise BadRQLQuery('%s\n** %s' % (node, '\n** '.join(errors))) + state = STCheckState() + self._visit(node, state) + if state.errors: + raise BadRQLQuery('%s\n** %s' % (node, '\n** '.join(state.errors))) #if node.TYPE == 'select' and \ # not node.defined_vars and not node.get_restriction(): # result = [] # for term in node.selected_terms(): # result.append(term.eval(kwargs)) - def _visit(self, node, errors): + def _visit(self, node, state): try: - node.accept(self, errors) + node.accept(self, state) except GoTo, ex: - self._visit(ex.node, errors) + self._visit(ex.node, state) else: for c in node.children: - self._visit(c, errors) - node.leave(self, errors) + self._visit(c, state) + node.leave(self, state) - def _visit_selectedterm(self, node, errors): + def _visit_selectedterm(self, node, state): for i, term in enumerate(node.selection): # selected terms are not included by the default visit, # accept manually each of them - self._visit(term, errors) + self._visit(term, state) - def _check_selected(self, term, termtype, errors): + def _check_selected(self, term, termtype, state): """check that variables referenced in the given term are selected""" for vref in variable_refs(term): # no stinfo yet, use references @@ -82,37 +102,44 @@ break else: msg = 'variable %s used in %s is not referenced by any relation' - errors.append(msg % (vref.name, termtype)) + state.error(msg % (vref.name, termtype)) # statement nodes ######################################################### - def visit_union(self, node, errors): + def visit_union(self, node, state): nbselected = len(node.children[0].selection) for select in node.children[1:]: if not len(select.selection) == nbselected: - errors.append('when using union, all subqueries should have ' + state.error('when using union, all subqueries should have ' 'the same number of selected terms') - def leave_union(self, node, errors): + def leave_union(self, node, state): pass - def visit_select(self, node, errors): + def visit_select(self, node, state): node.vargraph = {} # graph representing links between variable node.aggregated = set() - self._visit_selectedterm(node, errors) + self._visit_selectedterm(node, state) - def leave_select(self, node, errors): + def leave_select(self, node, state): selected = node.selection # check selected variable are used in restriction if node.where is not None or len(selected) > 1: for term in selected: - self._check_selected(term, 'selection', errors) + self._check_selected(term, 'selection', state) + for vref in term.iget_nodes(VariableRef): + state.add_var_info(vref.variable, VAR_SELECTED) + for var in node.defined_vars.itervalues(): + vinfo = state.var_info.get(var, 0) + if not (vinfo & VAR_HAS_REL) and (vinfo & VAR_HAS_TYPE_REL) \ + and not (vinfo & VAR_SELECTED): + raise BadRQLQuery('unbound variable %s (%s)' % (var.name, selected)) if node.groupby: # check that selected variables are used in groups for var in node.selection: if isinstance(var, VariableRef) and not var in node.groupby: - errors.append('variable %s should be grouped' % var) + state.error('variable %s should be grouped' % var) for group in node.groupby: - self._check_selected(group, 'group', errors) + self._check_selected(group, 'group', state) if node.distinct and node.orderby: # check that variables referenced in the given term are reachable from # a selected variable with only ?1 cardinalityselected @@ -132,7 +159,7 @@ msg = ('can\'t sort on variable %s which is linked to a' ' variable in the selection but may have different' ' values for a resulting row') - errors.append(msg % vref.name) + state.error(msg % vref.name) def has_unique_value_path(self, select, fromvar, tovar): graph = select.vargraph @@ -156,32 +183,32 @@ return True - def visit_insert(self, insert, errors): - self._visit_selectedterm(insert, errors) - def leave_insert(self, node, errors): + def visit_insert(self, insert, state): + self._visit_selectedterm(insert, state) + def leave_insert(self, node, state): pass - def visit_delete(self, delete, errors): - self._visit_selectedterm(delete, errors) - def leave_delete(self, node, errors): + def visit_delete(self, delete, state): + self._visit_selectedterm(delete, state) + def leave_delete(self, node, state): pass - def visit_set(self, update, errors): - self._visit_selectedterm(update, errors) - def leave_set(self, node, errors): + def visit_set(self, update, state): + self._visit_selectedterm(update, state) + def leave_set(self, node, state): pass # tree nodes ############################################################## - def visit_exists(self, node, errors): + def visit_exists(self, node, state): pass - def leave_exists(self, node, errors): + def leave_exists(self, node, state): pass - def visit_subquery(self, node, errors): + def visit_subquery(self, node, state): pass - def leave_subquery(self, node, errors): + def leave_subquery(self, node, state): # copy graph information we're interested in pgraph = node.parent.vargraph for select in node.query.children: @@ -191,7 +218,7 @@ try: subvref = select.selection[i] except IndexError: - errors.append('subquery "%s" has only %s selected terms, needs %s' + state.error('subquery "%s" has only %s selected terms, needs %s' % (select, len(select.selection), len(node.aliases))) continue if isinstance(subvref, VariableRef): @@ -211,12 +238,12 @@ values = pgraph.setdefault(_var_graphid(key, trmap, select), []) values += [_var_graphid(v, trmap, select) for v in val] - def visit_sortterm(self, sortterm, errors): + def visit_sortterm(self, sortterm, state): term = sortterm.term if isinstance(term, Constant): for select in sortterm.root.children: if len(select.selection) < term.value: - errors.append('order column out of bound %s' % term.value) + state.error('order column out of bound %s' % term.value) else: stmt = term.stmt for tvref in variable_refs(term): @@ -225,17 +252,17 @@ break else: msg = 'sort variable %s is not referenced any where else' - errors.append(msg % tvref.name) + state.error(msg % tvref.name) - def leave_sortterm(self, node, errors): + def leave_sortterm(self, node, state): pass - def visit_and(self, et, errors): + def visit_and(self, et, state): pass #assert len(et.children) == 2, len(et.children) - def leave_and(self, node, errors): + def leave_and(self, node, state): pass - def visit_or(self, ou, errors): + def visit_or(self, ou, state): #assert len(ou.children) == 2, len(ou.children) # simplify Ored expression of a symmetric relation r1, r2 = ou.children[0], ou.children[1] @@ -256,80 +283,93 @@ raise GoTo(r1) except AttributeError: pass - def leave_or(self, node, errors): - pass - - def visit_not(self, not_, errors): - pass - def leave_not(self, not_, errors): + def leave_or(self, node, state): pass - def visit_relation(self, relation, errors): - if relation.optional and relation.neged(): - errors.append("can use optional relation under NOT (%s)" - % relation.as_string()) - # special case "X identity Y" - if relation.r_type == 'identity': - lhs, rhs = relation.children - #assert not isinstance(relation.parent, Not) - #assert rhs.operator == '=' - elif relation.r_type == 'is': + def visit_not(self, not_, state): + state.under_not.append(True) + def leave_not(self, not_, state): + state.under_not.pop() + + def visit_relation(self, relation, state): + if relation.optional and state.under_not: + state.error("can't use optional relation under NOT (%s)" + % relation.as_string()) + lhsvar = relation.children[0].variable + if relation.is_types_restriction(): + if relation.optional: + state.error('can\'t use optional relation on "%s"' + % relation.as_string()) + if state.var_info.get(lhsvar, 0) & VAR_HAS_TYPE_REL: + state.error('can only one type restriction per variable (use ' + 'IN for %s if desired)' % lhsvar.name) + else: + state.add_var_info(lhsvar, VAR_HAS_TYPE_REL) # special case "C is NULL" - if relation.children[1].operator == 'IS': - lhs, rhs = relation.children - #assert isinstance(lhs, VariableRef), lhs - #assert isinstance(rhs.children[0], Constant) - #assert rhs.operator == 'IS', rhs.operator - #assert rhs.children[0].type == None + # if relation.children[1].operator == 'IS': + # lhs, rhs = relation.children + # #assert isinstance(lhs, VariableRef), lhs + # #assert isinstance(rhs.children[0], Constant) + # #assert rhs.operator == 'IS', rhs.operator + # #assert rhs.children[0].type == None else: + state.add_var_info(lhsvar, VAR_HAS_REL) + rtype = relation.r_type try: - rschema = self.schema.rschema(relation.r_type) + rschema = self.schema.rschema(rtype) except KeyError: - errors.append('unknown relation `%s`' % relation.r_type) + state.error('unknown relation `%s`' % rtype) else: if relation.optional and rschema.final: - errors.append("shouldn't use optional on final relation `%s`" - % relation.r_type) + state.error("shouldn't use optional on final relation `%s`" + % relation.r_type) + if self.special_relations.get(rtype) == 'uid': + if state.var_info.get(lhsvar, 0) & VAR_HAS_UID_REL: + state.error('can only one uid restriction per variable ' + '(use IN for %s if desired)' % lhsvar.name) + else: + state.add_var_info(lhsvar, VAR_HAS_UID_REL) + for vref in relation.children[1].get_nodes(VariableRef): + state.add_var_info(vref.variable, VAR_HAS_REL) try: vargraph = relation.stmt.vargraph rhsvarname = relation.children[1].children[0].variable.name - lhsvarname = relation.children[0].name except AttributeError: pass else: - vargraph.setdefault(lhsvarname, []).append(rhsvarname) - vargraph.setdefault(rhsvarname, []).append(lhsvarname) - vargraph[(lhsvarname, rhsvarname)] = relation.r_type + vargraph.setdefault(lhsvar.name, []).append(rhsvarname) + vargraph.setdefault(rhsvarname, []).append(lhsvar.name) + vargraph[(lhsvar.name, rhsvarname)] = relation.r_type - def leave_relation(self, relation, errors): + def leave_relation(self, relation, state): pass #assert isinstance(lhs, VariableRef), '%s: %s' % (lhs.__class__, # relation) - def visit_comparison(self, comparison, errors): + def visit_comparison(self, comparison, state): pass #assert len(comparison.children) in (1,2), len(comparison.children) - def leave_comparison(self, node, errors): + def leave_comparison(self, node, state): pass - def visit_mathexpression(self, mathexpr, errors): + def visit_mathexpression(self, mathexpr, state): pass #assert len(mathexpr.children) == 2, len(mathexpr.children) - def leave_mathexpression(self, node, errors): + def leave_mathexpression(self, node, state): pass - def visit_function(self, function, errors): + def visit_function(self, function, state): try: funcdescr = function_description(function.name) - except UnknownFunction: - errors.append('unknown function "%s"' % function.name) + except UnknownFunction: + state.error('unknown function "%s"' % function.name) else: try: funcdescr.check_nbargs(len(function.children)) except BadRQLQuery, ex: - errors.append(str(ex)) + state.error(str(ex)) if funcdescr.aggregat: if isinstance(function.children[0], Function) and \ function.children[0].descr().aggregat: - errors.append('can\'t nest aggregat functions') + state.error('can\'t nest aggregat functions') if funcdescr.name == 'IN': #assert function.parent.operator == '=' if len(function.children) == 1: @@ -337,10 +377,11 @@ function.parent.remove(function) #else: # assert len(function.children) >= 1 - def leave_function(self, node, errors): + + def leave_function(self, node, state): pass - def visit_variableref(self, variableref, errors): + def visit_variableref(self, variableref, state): #assert len(variableref.children)==0 #assert not variableref.parent is variableref ## try: @@ -350,19 +391,19 @@ ## raise Exception((variableref.root(), variableref.variable)) pass - def leave_variableref(self, node, errors): + def leave_variableref(self, node, state): pass - def visit_constant(self, constant, errors): + def visit_constant(self, constant, state): #assert len(constant.children)==0 if constant.type == 'etype': if constant.relation().r_type not in ('is', 'is_instance_of'): msg ='using an entity type in only allowed with "is" relation' - errors.append(msg) + state.error(msg) if not constant.value in self.schema: - errors.append('unknown entity type %s' % constant.value) + state.error('unknown entity type %s' % constant.value) - def leave_constant(self, node, errors): + def leave_constant(self, node, state): pass @@ -419,16 +460,6 @@ for term in node.groupby: for vref in term.get_nodes(VariableRef): vref.variable.stinfo['blocsimplification'].add(term) - for var in node.defined_vars.itervalues(): - if not var.stinfo['relations'] and var.stinfo['typerels'] and not var.stinfo['selected']: - raise BadRQLQuery('unbound variable %s (%s)' % (var.name, var.stmt.root)) - if len(var.stinfo['uidrels']) > 1: - uidrels = iter(var.stinfo['uidrels']) - val = getattr(uidrels.next().get_variable_parts()[1], 'value', object()) - for uidrel in uidrels: - if getattr(uidrel.get_variable_parts()[1], 'value', None) != val: - # XXX should check OR branch and check simplify in that case as well - raise BadRQLQuery('conflicting eid values for %s' % var.name) def rewrite_shared_optional(self, exists, var): """if variable is shared across multiple scopes, need some tree @@ -532,10 +563,7 @@ # may have been rewritten as well pass rtype = relation.r_type - try: - rschema = self.schema.rschema(rtype) - except KeyError: - raise BadRQLQuery('no relation %s' % rtype) + rschema = self.schema.rschema(rtype) if lhsvar is not None: lhsvar.set_scope(scope) lhsvar.set_sqlscope(sqlscope) diff --git a/test/unittest_stcheck.py b/test/unittest_stcheck.py --- a/test/unittest_stcheck.py +++ b/test/unittest_stcheck.py @@ -1,6 +1,8 @@ from logilab.common.testlib import TestCase, unittest_main + +from rql import RQLHelper, BadRQLQuery, stmts, nodes + from unittest_analyze import DummySchema -from rql import RQLHelper, BadRQLQuery, stmts, nodes BAD_QUERIES = ( 'Any X, Y GROUPBY X',