# HG changeset patch
# User sylvain.thenault@logilab.fr
# Date 1236864565 -3600
#      Thu Mar 12 14:29:25 2009 +0100
# Node ID 162274178556e46f95b385c4a3dab7b6200b3cef
# Parent  5761e339a2183c5691b44932d76a927cf033cd25
consider subquery in variables graph

diff --git a/stcheck.py b/stcheck.py
--- a/stcheck.py
+++ b/stcheck.py
@@ -17,6 +17,13 @@
 from rql.stmts import Union
 
 
+def _var_graphid(subvarname, trmap, select):
+    try:
+        return trmap[subvarname]
+    except KeyError:
+        return subvarname + str(id(select))
+
+
 class GoTo(Exception):
     """Exception used to control the visit of the tree."""
     def __init__(self, node):
@@ -89,6 +96,7 @@
     
     def visit_select(self, node, errors):
         node.vargraph = {} # graph representing links between variable
+        node.aggregated = set()
         self._visit_selectedterm(node, errors)
         
     def leave_select(self, node, errors):
@@ -104,10 +112,9 @@
                     errors.append('variable %s should be grouped' % var)
             for group in node.groupby:
                 self._check_selected(group, 'group', errors)
-        if node.distinct:
+        if node.distinct and node.orderby:
             # check that variables referenced in the given term are reachable from
             # a selected variable with only ?1 cardinalityselected
-            graph = node.vargraph
             selectidx = frozenset(vref.name for term in selected for vref in term.iget_nodes(VariableRef))
             schema = self.schema
             for sortterm in node.orderby:
@@ -115,15 +122,19 @@
                     if vref.name in selectidx:
                         continue
                     for vname in selectidx:
-                        if self.has_unique_value_path(graph, vname, vref.name):
-                            break
+                        try:
+                            if self.has_unique_value_path(node, vname, vref.name):
+                                break
+                        except KeyError:
+                            continue # unlinked variable (usually from a subquery)
                     else:
                         msg = ('can\'t sort on variable %s which is linked to a'
                                ' variable in the selection but may have different'
                                ' values for a resulting row')
                         errors.append(msg % vref.name)
 
-    def has_unique_value_path(self, graph, fromvar, tovar):
+    def has_unique_value_path(self, select, fromvar, tovar):
+        graph = select.vargraph
         path = has_path(graph, fromvar, tovar)
         if path is None:
             return False
@@ -136,7 +147,9 @@
                 cardidx = 1
             rschema = self.schema.rschema(rtype)
             for rdef in rschema.iter_rdefs():
-                if not rschema.rproperty(rdef[0], rdef[1], 'cardinality')[cardidx] in '?1':
+                # XXX aggregats handling needs much probably some enhancements...
+                if not (tovar in select.aggregated
+                        or rschema.rproperty(rdef[0], rdef[1], 'cardinality')[cardidx] in '?1'):
                     return False
             fromvar = tovar
         return True
@@ -166,8 +179,31 @@
     
     def visit_subquery(self, node, errors):
         pass
+    
     def leave_subquery(self, node, errors):
-        pass 
+        # copy graph information we're interested in
+        pgraph = node.parent.vargraph
+        for select in node.query.children:
+            # map subquery variable names to outer query variable names
+            trmap = {} 
+            for i, vref in enumerate(node.aliases):
+                subvref = select.selection[i]
+                if isinstance(subvref, VariableRef):
+                    trmap[subvref.name] = vref.name
+                elif (isinstance(subvref, Function) and subvref.descr().aggregat
+                      and len(subvref.children) == 1
+                      and isinstance(subvref.children[0], VariableRef)):
+                    # XXX ok for MIN, MAX, but what about COUNT, AVG...
+                    trmap[subvref.children[0].name] = vref.name
+                    node.parent.aggregated.add(vref.name)
+            for key, val in select.vargraph.iteritems():
+                if isinstance(key, tuple):
+                    key = (_var_graphid(key[0], trmap, select),
+                           _var_graphid(key[1], trmap, select))
+                    pgraph[key] = val
+                else:
+                    values = pgraph.setdefault(_var_graphid(key, trmap, select), [])
+                    values += [_var_graphid(v, trmap, select) for v in val]
     
     def visit_sortterm(self, sortterm, errors):
         term = sortterm.term
diff --git a/test/unittest_stcheck.py b/test/unittest_stcheck.py
--- a/test/unittest_stcheck.py
+++ b/test/unittest_stcheck.py
@@ -152,6 +152,22 @@
             ):
             yield self._test_rewrite, rql, expected
 
+    def test_subquery_graphdict(self):
+        # test two things:
+        # * we get graph information from subquery
+        # * we see that we can sort on VCS (eg we have a unique value path from VF to VCD)
+        rqlst = self.parse(('DISTINCT Any VF ORDERBY VCD DESC WHERE '
+                            'VC work_for S, S name "draft" '
+                            'WITH VF, VC, VCD BEING (Any VF, MAX(VC), VCD GROUPBY VF, VCD '
+                            '                        WHERE VC connait VF, VC creation_date VCD)'))
+        self.assertEquals(rqlst.children[0].vargraph,
+                          {'VCD': ['VC'], 'VF': ['VC'], 'S': ['VC'], 'VC': ['S', 'VF', 'VCD'],
+                           ('VC', 'S'): 'work_for',
+                           ('VC', 'VF'): 'connait',
+                           ('VC', 'VCD'): 'creation_date'})
+        self.assertEquals(rqlst.children[0].aggregated, set(('VC',)))
+        
+            
 ##     def test_rewriten_as_string(self):
 ##         rqlst = self.parse('Any X WHERE X eid 12')
 ##         self.assertEquals(rqlst.as_string(), 'Any X WHERE X eid 12')