# HG changeset patch
# User Sylvain <syt@logilab.fr>
# Date 1208856390 -7200
#      Tue Apr 22 11:26:30 2008 +0200
# Node ID 016e237462c1d0e39afdc1e1de0eb49a03265bfa
# Parent  b66c3bc2e92a42b5aa05dbf4fbe4fcfc92b55d6c
more test and fixes

diff --git a/analyze.py b/analyze.py
--- a/analyze.py
+++ b/analyze.py
@@ -135,7 +135,7 @@
         
     visit_delete = visit_insert
     
-    def visit_update(self, node):
+    def visit_set(self, node):
         if not node.defined_vars:
             node.set_possible_types([{}])
             return
diff --git a/nodes.py b/nodes.py
--- a/nodes.py
+++ b/nodes.py
@@ -247,6 +247,12 @@
 class SubQuery(BaseNode):
     """WITH clause"""
     __slots__ = ('aliases', 'query')
+    def __init__(self, aliases=None, query=None):
+        if aliases is not None:
+            self.set_aliases(aliases)
+        if query is not None:
+            self.set_query(query)
+            
     def set_aliases(self, aliases):
         self.aliases = aliases
         for node in aliases:
@@ -257,9 +263,8 @@
         node.parent = self
 
     def copy(self, stmt):
-        self.set_aliases([v.copy(stmt) for v in self.aliases])
-        self.set_query(self.query.copy())
-        
+        return SubQuery([v.copy(stmt) for v in self.aliases], self.query.copy())
+    
     @property
     def children(self):
         return self.aliases + [self.query]
@@ -331,6 +336,10 @@
         if restriction is not None:
             self.set_where(restriction)
 
+    def copy(self, stmt):
+        new = self.query.copy(stmt)
+        return Exists(new)
+    
     @property
     def children(self):
         return (self.query,)
@@ -352,6 +361,11 @@
     @property
     def where(self):
         return self.query
+    
+    def replace(self, oldnode, newnode):
+        assert oldnode is self.query
+        self.query = newnode
+        newnode.parent = self
 
     @property
     def scope(self):
@@ -721,7 +735,10 @@
 
     def initargs(self, stmt):
         """return list of arguments to give to __init__ to clone this node"""
-        newvar = stmt.get_variable(self.name)
+        if isinstance(self.variable, ColumnAlias):
+            newvar = stmt.get_variable(self.name, self.variable.colnum)
+        else:
+            newvar = stmt.get_variable(self.name)
         newvar.init_copy(self.variable)
         return (newvar,)
 
diff --git a/parser.py b/parser.py
--- a/parser.py
+++ b/parser.py
@@ -47,8 +47,8 @@
 from warnings import warn
 from rql.stmts import Union, Select, Delete, Insert, Set
 from rql.nodes import *
-
-
+def warn(*args):
+    raise Exception()
 def unquote(string):
     """Remove quotes from a string."""
     if string.startswith('"'):
diff --git a/stcheck.py b/stcheck.py
--- a/stcheck.py
+++ b/stcheck.py
@@ -39,6 +39,7 @@
         errors = []
         self._visit(node, errors)
         if errors:
+            print node
             raise BadRQLQuery('%s\n** %s' % (node, '\n** '.join(errors)))
         #if node.TYPE == 'select' and \
         #       not node.defined_vars and not node.get_restriction():
@@ -133,7 +134,6 @@
         
     def visit_set(self, update, errors):
         self._visit_selectedterm(update, errors)
-        assert len(update.children) <= 1                
     def leave_set(self, node, errors):
         pass                
 
diff --git a/stmts.py b/stmts.py
--- a/stmts.py
+++ b/stmts.py
@@ -343,7 +343,7 @@
         elif copy_solutions and self.solutions is not None:
             new.solutions = deepcopy(self.solutions)
         if self.with_:
-            new.set_with([sq.copy(new) for sq in self.with_])
+            new.set_with([sq.copy(new) for sq in self.with_], check=False)
         for child in self.selection:
             new.append_selected(child.copy(new))
         if self.groupby:
@@ -393,30 +393,38 @@
             node.parent = self
             
     def set_with(self, terms, check=True):
-        self.with_ = terms
+        self.with_ = []
         for node in terms:
-            node.parent = self
-            if check and len(node.aliases) != len(node.query.children[0].selection):
-                raise BadRQLQuery('Should have the same number of aliases than '
-                                  'selected terms in sub-query')
-            for i, alias in enumerate(node.aliases):
-                alias = alias.name
-                if alias in self.aliases:
-                    raise BadRQLQuery('Duplicated alias %s' % alias)
-                self.aliases[alias] = nodes.ColumnAlias(alias, i, node.query)
-                # alias may already have been used as a regular variable, replace it
-                if alias in self.defined_vars:
-                    for vref in self.defined_vars.pop(alias).references():
-                        vref.variable = self.aliases[alias]
-
-
-    def get_variable(self, name):
+            self.add_subquery(node, check)
+            
+    def add_subquery(self, node, check=True):
+        assert node.query
+        node.parent = self
+        self.with_.append(node)
+        if check and len(node.aliases) != len(node.query.children[0].selection):
+            raise BadRQLQuery('Should have the same number of aliases than '
+                              'selected terms in sub-query')
+        for i, alias in enumerate(node.aliases):
+            alias = alias.name
+            if check and alias in self.aliases:
+                raise BadRQLQuery('Duplicated alias %s' % alias)
+            ca = self.get_variable(alias, i)
+            ca.query = node.query
+            
+    def get_variable(self, name, colnum=None):
         """get a variable instance from its name
         
         the variable is created if it doesn't exist yet
         """
         if name in self.aliases:
             return self.aliases[name]
+        if colnum is not None: # take care, may be 0
+            self.aliases[name] = nodes.ColumnAlias(name, colnum)
+            # alias may already have been used as a regular variable, replace it
+            if name in self.defined_vars:
+                for vref in self.defined_vars.pop(name).references():
+                    vref.variable = self.aliases[name]
+            return self.aliases[name]
         return super(Select, self).get_variable(name)
     
     def clean_solutions(self, solutions=None):
@@ -548,7 +556,7 @@
         """add var in 'orderby' constraints
         asc is a boolean indicating the group order (ascendent or descendent)
         """
-        if self.groupby is None:
+        if not self.groupby:
             self.groupby = []
         vref = nodes.variable_ref(var)
         vref.register_reference()
@@ -578,7 +586,7 @@
         self.add_sort_term(term)
         
     def add_sort_term(self, term):
-        if self.orderby is None:
+        if not self.orderby:
             self.orderby = []
         self.orderby.append(term)
         term.parent = self
@@ -608,8 +616,12 @@
             self.orderby = None
 
     def select_only_variables(self):
-        self.selection = [vref for term in self.selection
-                         for vref in term.iget_nodes(nodes.VariableRef)]
+        selection = []
+        for term in self.selection:
+            for vref in term.iget_nodes(nodes.VariableRef):
+                if not vref in selection:
+                    selection.append(vref)
+        self.selection = selection
 
     
 class Delete(Statement, ScopeNode):
diff --git a/test/unittest_analyze.py b/test/unittest_analyze.py
--- a/test/unittest_analyze.py
+++ b/test/unittest_analyze.py
@@ -146,8 +146,6 @@
             if DEBUG:
                 print rql
             node = self.helper.parse(rql)
-            print rql
-            print node
             self.assertRaises(TypeResolverException,
                               self.helper.compute_solutions, node, debug=DEBUG)
         
@@ -270,12 +268,6 @@
         sols = sorted(node.children[0].solutions)
         self.assertEqual(sols, [{'E2': 'Person', 'E1': 'String'}])
 
-    def test_insert_1(self):
-        node = self.helper.parse('INSERT Person X : X name "toto", X work_for Y WHERE Y name "logilab"')
-        self.helper.compute_solutions(node, debug=DEBUG)
-        sols = sorted(node.solutions)
-        self.assertEqual(sols, [{'X': 'Person', 'Y': 'Company'}])
-
     def test_relation_eid(self):
         node = self.helper.parse('Any E2 WHERE E2 work_for E1, E2 eid 2')
         self.helper.compute_solutions(node, debug=DEBUG)
@@ -348,6 +340,25 @@
         self.assertEqual(node.children[0].with_[0].query.children[0].solutions, [{'X': 'Person'}])
         self.assertEqual(node.children[0].solutions, [{'X': 'Person', 'Y': 'Person',
                                                        'L': 'Address'}])
+
+    def test_insert(self):
+        node = self.helper.parse('INSERT Person X : X name "toto", X work_for Y WHERE Y name "logilab"')
+        self.helper.compute_solutions(node, debug=DEBUG)
+        sols = sorted(node.solutions)
+        self.assertEqual(sols, [{'X': 'Person', 'Y': 'Company'}])
+
+    def test_delete(self):
+        node = self.helper.parse('DELETE Person X WHERE X name "toto", X work_for Y')
+        self.helper.compute_solutions(node, debug=DEBUG)
+        sols = sorted(node.solutions)
+        self.assertEqual(sols, [{'X': 'Person', 'Y': 'Company'}])
+
+    def test_set(self):
+        node = self.helper.parse('SET X name "toto", X work_for Y WHERE Y name "logilab"')
+        self.helper.compute_solutions(node, debug=DEBUG)
+        sols = sorted(node.solutions)
+        self.assertEqual(sols, [{'X': 'Person', 'Y': 'Company'}])
+
         
     def test_nongrer_not_u_ownedby_u(self):
         node = self.helper.parse('Any U WHERE NOT U owned_by U')
diff --git a/test/unittest_nodes.py b/test/unittest_nodes.py
--- a/test/unittest_nodes.py
+++ b/test/unittest_nodes.py
@@ -238,6 +238,9 @@
         tree = self._parse("Any X WHERE X name 1.0 LIMIT 10 OFFSET 10")
         self.assertEqual(tree.limit, 10)
         self.assertEqual(tree.offset, 10)
+
+    def test_exists(self):
+        tree = self._simpleparse("Any X,N WHERE X is Person, X name N, EXISTS(X work_for Y)")
         
     def test_copy(self):
         tree = self._parse("Any X,LOWER(Y) GROUPBY N ORDERBY N WHERE X is Person, X name N, X date >= TODAY")
@@ -254,7 +257,7 @@
         annotator.annotate(tree)
         self.assertEquals(tree.defined_vars['X'].selected_index(), 0)
         self.assertEquals(tree.defined_vars['N'].selected_index(), None)
-            
+        
     # insertion tests #########################################################
 
     def test_insert_base_1(self):