# HG changeset patch
# User Sylvain Thénault <sylvain.thenault@logilab.fr>
# Date 1269420510 -3600
#      Wed Mar 24 09:48:30 2010 +0100
# Node ID 74eba13308fec99d2e3cf4ec1340f1222a2ec629
# Parent  3bfc7423b041546462e12ac65e0d3bdddd44d498
move some checks done in the annotator to the checker

diff --git a/__init__.py b/__init__.py
--- a/__init__.py
+++ b/__init__.py
@@ -45,7 +45,7 @@
         if uid_func_mapping:
             for key in uid_func_mapping:
                 special_relations[key] = 'uid'
-        self._checker = RQLSTChecker(schema)
+        self._checker = RQLSTChecker(schema, special_relations)
         self._annotator = RQLSTAnnotator(schema, special_relations)
         self._analyser_lock = threading.Lock()
         if resolver_class is None:
diff --git a/stcheck.py b/stcheck.py
--- a/stcheck.py
+++ b/stcheck.py
@@ -30,6 +30,25 @@
     def __init__(self, node):
         self.node = node
 
+VAR_SELECTED = 1
+VAR_HAS_TYPE_REL = 2
+VAR_HAS_UID_REL = 4
+VAR_HAS_REL = 8
+
+class STCheckState(object):
+    def __init__(self):
+        self.errors = []
+        self.under_not = []
+        self.var_info = {}
+
+    def error(self, msg):
+        self.errors.append(msg)
+
+    def add_var_info(self, var, vi):
+        try:
+            self.var_info[var] |= vi
+        except KeyError:
+            self.var_info[var] = vi
 
 class RQLSTChecker(object):
     """Check a RQL syntax tree for errors not detected on parsing.
@@ -42,37 +61,38 @@
     errors due to a bad rql input
     """
 
-    def __init__(self, schema):
+    def __init__(self, schema, special_relations=None):
         self.schema = schema
+        self.special_relations = special_relations or {}
 
     def check(self, node):
-        errors = []
-        self._visit(node, errors)
-        if errors:
-            raise BadRQLQuery('%s\n** %s' % (node, '\n** '.join(errors)))
+        state = STCheckState()
+        self._visit(node, state)
+        if state.errors:
+            raise BadRQLQuery('%s\n** %s' % (node, '\n** '.join(state.errors)))
         #if node.TYPE == 'select' and \
         #       not node.defined_vars and not node.get_restriction():
         #    result = []
         #    for term in node.selected_terms():
         #        result.append(term.eval(kwargs))
 
-    def _visit(self, node, errors):
+    def _visit(self, node, state):
         try:
-            node.accept(self, errors)
+            node.accept(self, state)
         except GoTo, ex:
-            self._visit(ex.node, errors)
+            self._visit(ex.node, state)
         else:
             for c in node.children:
-                self._visit(c, errors)
-            node.leave(self, errors)
+                self._visit(c, state)
+            node.leave(self, state)
 
-    def _visit_selectedterm(self, node, errors):
+    def _visit_selectedterm(self, node, state):
         for i, term in enumerate(node.selection):
             # selected terms are not included by the default visit,
             # accept manually each of them
-            self._visit(term, errors)
+            self._visit(term, state)
 
-    def _check_selected(self, term, termtype, errors):
+    def _check_selected(self, term, termtype, state):
         """check that variables referenced in the given term are selected"""
         for vref in variable_refs(term):
             # no stinfo yet, use references
@@ -82,37 +102,44 @@
                     break
             else:
                 msg = 'variable %s used in %s is not referenced by any relation'
-                errors.append(msg % (vref.name, termtype))
+                state.error(msg % (vref.name, termtype))
 
     # statement nodes #########################################################
 
-    def visit_union(self, node, errors):
+    def visit_union(self, node, state):
         nbselected = len(node.children[0].selection)
         for select in node.children[1:]:
             if not len(select.selection) == nbselected:
-                errors.append('when using union, all subqueries should have '
+                state.error('when using union, all subqueries should have '
                               'the same number of selected terms')
-    def leave_union(self, node, errors):
+    def leave_union(self, node, state):
         pass
 
-    def visit_select(self, node, errors):
+    def visit_select(self, node, state):
         node.vargraph = {} # graph representing links between variable
         node.aggregated = set()
-        self._visit_selectedterm(node, errors)
+        self._visit_selectedterm(node, state)
 
-    def leave_select(self, node, errors):
+    def leave_select(self, node, state):
         selected = node.selection
         # check selected variable are used in restriction
         if node.where is not None or len(selected) > 1:
             for term in selected:
-                self._check_selected(term, 'selection', errors)
+                self._check_selected(term, 'selection', state)
+                for vref in term.iget_nodes(VariableRef):
+                    state.add_var_info(vref.variable, VAR_SELECTED)
+        for var in node.defined_vars.itervalues():
+            vinfo = state.var_info.get(var, 0)
+            if not (vinfo & VAR_HAS_REL) and (vinfo & VAR_HAS_TYPE_REL) \
+                   and not (vinfo & VAR_SELECTED):
+                raise BadRQLQuery('unbound variable %s (%s)' % (var.name, selected))
         if node.groupby:
             # check that selected variables are used in groups
             for var in node.selection:
                 if isinstance(var, VariableRef) and not var in node.groupby:
-                    errors.append('variable %s should be grouped' % var)
+                    state.error('variable %s should be grouped' % var)
             for group in node.groupby:
-                self._check_selected(group, 'group', errors)
+                self._check_selected(group, 'group', state)
         if node.distinct and node.orderby:
             # check that variables referenced in the given term are reachable from
             # a selected variable with only ?1 cardinalityselected
@@ -132,7 +159,7 @@
                         msg = ('can\'t sort on variable %s which is linked to a'
                                ' variable in the selection but may have different'
                                ' values for a resulting row')
-                        errors.append(msg % vref.name)
+                        state.error(msg % vref.name)
 
     def has_unique_value_path(self, select, fromvar, tovar):
         graph = select.vargraph
@@ -156,32 +183,32 @@
         return True
 
 
-    def visit_insert(self, insert, errors):
-        self._visit_selectedterm(insert, errors)
-    def leave_insert(self, node, errors):
+    def visit_insert(self, insert, state):
+        self._visit_selectedterm(insert, state)
+    def leave_insert(self, node, state):
         pass
 
-    def visit_delete(self, delete, errors):
-        self._visit_selectedterm(delete, errors)
-    def leave_delete(self, node, errors):
+    def visit_delete(self, delete, state):
+        self._visit_selectedterm(delete, state)
+    def leave_delete(self, node, state):
         pass
 
-    def visit_set(self, update, errors):
-        self._visit_selectedterm(update, errors)
-    def leave_set(self, node, errors):
+    def visit_set(self, update, state):
+        self._visit_selectedterm(update, state)
+    def leave_set(self, node, state):
         pass
 
     # tree nodes ##############################################################
 
-    def visit_exists(self, node, errors):
+    def visit_exists(self, node, state):
         pass
-    def leave_exists(self, node, errors):
+    def leave_exists(self, node, state):
         pass
 
-    def visit_subquery(self, node, errors):
+    def visit_subquery(self, node, state):
         pass
 
-    def leave_subquery(self, node, errors):
+    def leave_subquery(self, node, state):
         # copy graph information we're interested in
         pgraph = node.parent.vargraph
         for select in node.query.children:
@@ -191,7 +218,7 @@
                 try:
                     subvref = select.selection[i]
                 except IndexError:
-                    errors.append('subquery "%s" has only %s selected terms, needs %s'
+                    state.error('subquery "%s" has only %s selected terms, needs %s'
                                   % (select, len(select.selection), len(node.aliases)))
                     continue
                 if isinstance(subvref, VariableRef):
@@ -211,12 +238,12 @@
                     values = pgraph.setdefault(_var_graphid(key, trmap, select), [])
                     values += [_var_graphid(v, trmap, select) for v in val]
 
-    def visit_sortterm(self, sortterm, errors):
+    def visit_sortterm(self, sortterm, state):
         term = sortterm.term
         if isinstance(term, Constant):
             for select in sortterm.root.children:
                 if len(select.selection) < term.value:
-                    errors.append('order column out of bound %s' % term.value)
+                    state.error('order column out of bound %s' % term.value)
         else:
             stmt = term.stmt
             for tvref in variable_refs(term):
@@ -225,17 +252,17 @@
                         break
                 else:
                     msg = 'sort variable %s is not referenced any where else'
-                    errors.append(msg % tvref.name)
+                    state.error(msg % tvref.name)
 
-    def leave_sortterm(self, node, errors):
+    def leave_sortterm(self, node, state):
         pass
 
-    def visit_and(self, et, errors):
+    def visit_and(self, et, state):
         pass #assert len(et.children) == 2, len(et.children)
-    def leave_and(self, node, errors):
+    def leave_and(self, node, state):
         pass
 
-    def visit_or(self, ou, errors):
+    def visit_or(self, ou, state):
         #assert len(ou.children) == 2, len(ou.children)
         # simplify Ored expression of a symmetric relation
         r1, r2 = ou.children[0], ou.children[1]
@@ -256,80 +283,93 @@
                     raise GoTo(r1)
             except AttributeError:
                 pass
-    def leave_or(self, node, errors):
-        pass
-
-    def visit_not(self, not_, errors):
-        pass
-    def leave_not(self, not_, errors):
+    def leave_or(self, node, state):
         pass
 
-    def visit_relation(self, relation, errors):
-        if relation.optional and relation.neged():
-            errors.append("can use optional relation under NOT (%s)"
-                          % relation.as_string())
-        # special case "X identity Y"
-        if relation.r_type == 'identity':
-            lhs, rhs = relation.children
-            #assert not isinstance(relation.parent, Not)
-            #assert rhs.operator == '='
-        elif relation.r_type == 'is':
+    def visit_not(self, not_, state):
+        state.under_not.append(True)
+    def leave_not(self, not_, state):
+        state.under_not.pop()
+
+    def visit_relation(self, relation, state):
+        if relation.optional and state.under_not:
+            state.error("can't use optional relation under NOT (%s)"
+                        % relation.as_string())
+        lhsvar = relation.children[0].variable
+        if relation.is_types_restriction():
+            if relation.optional:
+                state.error('can\'t use optional relation on "%s"'
+                            % relation.as_string())
+            if state.var_info.get(lhsvar, 0) & VAR_HAS_TYPE_REL:
+                state.error('can only one type restriction per variable (use '
+                            'IN for %s if desired)' % lhsvar.name)
+            else:
+                state.add_var_info(lhsvar, VAR_HAS_TYPE_REL)
             # special case "C is NULL"
-            if relation.children[1].operator == 'IS':
-                lhs, rhs = relation.children
-                #assert isinstance(lhs, VariableRef), lhs
-                #assert isinstance(rhs.children[0], Constant)
-                #assert rhs.operator == 'IS', rhs.operator
-                #assert rhs.children[0].type == None
+            # if relation.children[1].operator == 'IS':
+            #     lhs, rhs = relation.children
+            #     #assert isinstance(lhs, VariableRef), lhs
+            #     #assert isinstance(rhs.children[0], Constant)
+            #     #assert rhs.operator == 'IS', rhs.operator
+            #     #assert rhs.children[0].type == None
         else:
+            state.add_var_info(lhsvar, VAR_HAS_REL)
+            rtype = relation.r_type
             try:
-                rschema = self.schema.rschema(relation.r_type)
+                rschema = self.schema.rschema(rtype)
             except KeyError:
-                errors.append('unknown relation `%s`' % relation.r_type)
+                state.error('unknown relation `%s`' % rtype)
             else:
                 if relation.optional and rschema.final:
-                    errors.append("shouldn't use optional on final relation `%s`"
-                                  % relation.r_type)
+                    state.error("shouldn't use optional on final relation `%s`"
+                                % relation.r_type)
+                if self.special_relations.get(rtype) == 'uid':
+                    if state.var_info.get(lhsvar, 0) & VAR_HAS_UID_REL:
+                        state.error('can only one uid restriction per variable '
+                                    '(use IN for %s if desired)' % lhsvar.name)
+                    else:
+                        state.add_var_info(lhsvar, VAR_HAS_UID_REL)
+            for vref in relation.children[1].get_nodes(VariableRef):
+                state.add_var_info(vref.variable, VAR_HAS_REL)
         try:
             vargraph = relation.stmt.vargraph
             rhsvarname = relation.children[1].children[0].variable.name
-            lhsvarname = relation.children[0].name
         except AttributeError:
             pass
         else:
-            vargraph.setdefault(lhsvarname, []).append(rhsvarname)
-            vargraph.setdefault(rhsvarname, []).append(lhsvarname)
-            vargraph[(lhsvarname, rhsvarname)] = relation.r_type
+            vargraph.setdefault(lhsvar.name, []).append(rhsvarname)
+            vargraph.setdefault(rhsvarname, []).append(lhsvar.name)
+            vargraph[(lhsvar.name, rhsvarname)] = relation.r_type
 
-    def leave_relation(self, relation, errors):
+    def leave_relation(self, relation, state):
         pass
         #assert isinstance(lhs, VariableRef), '%s: %s' % (lhs.__class__,
         #                                                       relation)
 
-    def visit_comparison(self, comparison, errors):
+    def visit_comparison(self, comparison, state):
         pass #assert len(comparison.children) in (1,2), len(comparison.children)
-    def leave_comparison(self, node, errors):
+    def leave_comparison(self, node, state):
         pass
 
-    def visit_mathexpression(self, mathexpr, errors):
+    def visit_mathexpression(self, mathexpr, state):
         pass #assert len(mathexpr.children) == 2, len(mathexpr.children)
-    def leave_mathexpression(self, node, errors):
+    def leave_mathexpression(self, node, state):
         pass
 
-    def visit_function(self, function, errors):
+    def visit_function(self, function, state):
         try:
             funcdescr = function_description(function.name)
-        except UnknownFunction:
-            errors.append('unknown function "%s"' % function.name)
+         except UnknownFunction:
+            state.error('unknown function "%s"' % function.name)
         else:
             try:
                 funcdescr.check_nbargs(len(function.children))
             except BadRQLQuery, ex:
-                errors.append(str(ex))
+                state.error(str(ex))
             if funcdescr.aggregat:
                 if isinstance(function.children[0], Function) and \
                        function.children[0].descr().aggregat:
-                    errors.append('can\'t nest aggregat functions')
+                    state.error('can\'t nest aggregat functions')
             if funcdescr.name == 'IN':
                 #assert function.parent.operator == '='
                 if len(function.children) == 1:
@@ -337,10 +377,11 @@
                     function.parent.remove(function)
                 #else:
                 #    assert len(function.children) >= 1
-    def leave_function(self, node, errors):
+
+    def leave_function(self, node, state):
         pass
 
-    def visit_variableref(self, variableref, errors):
+    def visit_variableref(self, variableref, state):
         #assert len(variableref.children)==0
         #assert not variableref.parent is variableref
 ##         try:
@@ -350,19 +391,19 @@
 ##             raise Exception((variableref.root(), variableref.variable))
         pass
 
-    def leave_variableref(self, node, errors):
+    def leave_variableref(self, node, state):
         pass
 
-    def visit_constant(self, constant, errors):
+    def visit_constant(self, constant, state):
         #assert len(constant.children)==0
         if constant.type == 'etype':
             if constant.relation().r_type not in ('is', 'is_instance_of'):
                 msg ='using an entity type in only allowed with "is" relation'
-                errors.append(msg)
+                state.error(msg)
             if not constant.value in self.schema:
-                errors.append('unknown entity type %s' % constant.value)
+                state.error('unknown entity type %s' % constant.value)
 
-    def leave_constant(self, node, errors):
+    def leave_constant(self, node, state):
         pass
 
 
@@ -419,16 +460,6 @@
             for term in node.groupby:
                 for vref in term.get_nodes(VariableRef):
                     vref.variable.stinfo['blocsimplification'].add(term)
-        for var in node.defined_vars.itervalues():
-            if not var.stinfo['relations'] and var.stinfo['typerels'] and not var.stinfo['selected']:
-                raise BadRQLQuery('unbound variable %s (%s)' % (var.name, var.stmt.root))
-            if len(var.stinfo['uidrels']) > 1:
-                uidrels = iter(var.stinfo['uidrels'])
-                val = getattr(uidrels.next().get_variable_parts()[1], 'value', object())
-                for uidrel in uidrels:
-                    if getattr(uidrel.get_variable_parts()[1], 'value', None) != val:
-                        # XXX should check OR branch and check simplify in that case as well
-                        raise BadRQLQuery('conflicting eid values for %s' % var.name)
 
     def rewrite_shared_optional(self, exists, var):
         """if variable is shared across multiple scopes, need some tree
@@ -532,10 +563,7 @@
                 # may have been rewritten as well
                 pass
         rtype = relation.r_type
-        try:
-            rschema = self.schema.rschema(rtype)
-        except KeyError:
-            raise BadRQLQuery('no relation %s' % rtype)
+        rschema = self.schema.rschema(rtype)
         if lhsvar is not None:
             lhsvar.set_scope(scope)
             lhsvar.set_sqlscope(sqlscope)
diff --git a/test/unittest_stcheck.py b/test/unittest_stcheck.py
--- a/test/unittest_stcheck.py
+++ b/test/unittest_stcheck.py
@@ -1,6 +1,8 @@
 from logilab.common.testlib import TestCase, unittest_main
+
+from rql import RQLHelper, BadRQLQuery, stmts, nodes
+
 from unittest_analyze import DummySchema
-from rql import RQLHelper, BadRQLQuery, stmts, nodes
 
 BAD_QUERIES = (
     'Any X, Y GROUPBY X',