# HG changeset patch # User Sylvain <syt@logilab.fr> # Date 1208856390 -7200 # Tue Apr 22 11:26:30 2008 +0200 # Node ID 016e237462c1d0e39afdc1e1de0eb49a03265bfa # Parent b66c3bc2e92a42b5aa05dbf4fbe4fcfc92b55d6c more test and fixes diff --git a/analyze.py b/analyze.py --- a/analyze.py +++ b/analyze.py @@ -135,7 +135,7 @@ visit_delete = visit_insert - def visit_update(self, node): + def visit_set(self, node): if not node.defined_vars: node.set_possible_types([{}]) return diff --git a/nodes.py b/nodes.py --- a/nodes.py +++ b/nodes.py @@ -247,6 +247,12 @@ class SubQuery(BaseNode): """WITH clause""" __slots__ = ('aliases', 'query') + def __init__(self, aliases=None, query=None): + if aliases is not None: + self.set_aliases(aliases) + if query is not None: + self.set_query(query) + def set_aliases(self, aliases): self.aliases = aliases for node in aliases: @@ -257,9 +263,8 @@ node.parent = self def copy(self, stmt): - self.set_aliases([v.copy(stmt) for v in self.aliases]) - self.set_query(self.query.copy()) - + return SubQuery([v.copy(stmt) for v in self.aliases], self.query.copy()) + @property def children(self): return self.aliases + [self.query] @@ -331,6 +336,10 @@ if restriction is not None: self.set_where(restriction) + def copy(self, stmt): + new = self.query.copy(stmt) + return Exists(new) + @property def children(self): return (self.query,) @@ -352,6 +361,11 @@ @property def where(self): return self.query + + def replace(self, oldnode, newnode): + assert oldnode is self.query + self.query = newnode + newnode.parent = self @property def scope(self): @@ -721,7 +735,10 @@ def initargs(self, stmt): """return list of arguments to give to __init__ to clone this node""" - newvar = stmt.get_variable(self.name) + if isinstance(self.variable, ColumnAlias): + newvar = stmt.get_variable(self.name, self.variable.colnum) + else: + newvar = stmt.get_variable(self.name) newvar.init_copy(self.variable) return (newvar,) diff --git a/parser.py b/parser.py --- a/parser.py +++ b/parser.py @@ -47,8 +47,8 @@ from warnings import warn from rql.stmts import Union, Select, Delete, Insert, Set from rql.nodes import * - - +def warn(*args): + raise Exception() def unquote(string): """Remove quotes from a string.""" if string.startswith('"'): diff --git a/stcheck.py b/stcheck.py --- a/stcheck.py +++ b/stcheck.py @@ -39,6 +39,7 @@ errors = [] self._visit(node, errors) if errors: + print node raise BadRQLQuery('%s\n** %s' % (node, '\n** '.join(errors))) #if node.TYPE == 'select' and \ # not node.defined_vars and not node.get_restriction(): @@ -133,7 +134,6 @@ def visit_set(self, update, errors): self._visit_selectedterm(update, errors) - assert len(update.children) <= 1 def leave_set(self, node, errors): pass diff --git a/stmts.py b/stmts.py --- a/stmts.py +++ b/stmts.py @@ -343,7 +343,7 @@ elif copy_solutions and self.solutions is not None: new.solutions = deepcopy(self.solutions) if self.with_: - new.set_with([sq.copy(new) for sq in self.with_]) + new.set_with([sq.copy(new) for sq in self.with_], check=False) for child in self.selection: new.append_selected(child.copy(new)) if self.groupby: @@ -393,30 +393,38 @@ node.parent = self def set_with(self, terms, check=True): - self.with_ = terms + self.with_ = [] for node in terms: - node.parent = self - if check and len(node.aliases) != len(node.query.children[0].selection): - raise BadRQLQuery('Should have the same number of aliases than ' - 'selected terms in sub-query') - for i, alias in enumerate(node.aliases): - alias = alias.name - if alias in self.aliases: - raise BadRQLQuery('Duplicated alias %s' % alias) - self.aliases[alias] = nodes.ColumnAlias(alias, i, node.query) - # alias may already have been used as a regular variable, replace it - if alias in self.defined_vars: - for vref in self.defined_vars.pop(alias).references(): - vref.variable = self.aliases[alias] - - - def get_variable(self, name): + self.add_subquery(node, check) + + def add_subquery(self, node, check=True): + assert node.query + node.parent = self + self.with_.append(node) + if check and len(node.aliases) != len(node.query.children[0].selection): + raise BadRQLQuery('Should have the same number of aliases than ' + 'selected terms in sub-query') + for i, alias in enumerate(node.aliases): + alias = alias.name + if check and alias in self.aliases: + raise BadRQLQuery('Duplicated alias %s' % alias) + ca = self.get_variable(alias, i) + ca.query = node.query + + def get_variable(self, name, colnum=None): """get a variable instance from its name the variable is created if it doesn't exist yet """ if name in self.aliases: return self.aliases[name] + if colnum is not None: # take care, may be 0 + self.aliases[name] = nodes.ColumnAlias(name, colnum) + # alias may already have been used as a regular variable, replace it + if name in self.defined_vars: + for vref in self.defined_vars.pop(name).references(): + vref.variable = self.aliases[name] + return self.aliases[name] return super(Select, self).get_variable(name) def clean_solutions(self, solutions=None): @@ -548,7 +556,7 @@ """add var in 'orderby' constraints asc is a boolean indicating the group order (ascendent or descendent) """ - if self.groupby is None: + if not self.groupby: self.groupby = [] vref = nodes.variable_ref(var) vref.register_reference() @@ -578,7 +586,7 @@ self.add_sort_term(term) def add_sort_term(self, term): - if self.orderby is None: + if not self.orderby: self.orderby = [] self.orderby.append(term) term.parent = self @@ -608,8 +616,12 @@ self.orderby = None def select_only_variables(self): - self.selection = [vref for term in self.selection - for vref in term.iget_nodes(nodes.VariableRef)] + selection = [] + for term in self.selection: + for vref in term.iget_nodes(nodes.VariableRef): + if not vref in selection: + selection.append(vref) + self.selection = selection class Delete(Statement, ScopeNode): diff --git a/test/unittest_analyze.py b/test/unittest_analyze.py --- a/test/unittest_analyze.py +++ b/test/unittest_analyze.py @@ -146,8 +146,6 @@ if DEBUG: print rql node = self.helper.parse(rql) - print rql - print node self.assertRaises(TypeResolverException, self.helper.compute_solutions, node, debug=DEBUG) @@ -270,12 +268,6 @@ sols = sorted(node.children[0].solutions) self.assertEqual(sols, [{'E2': 'Person', 'E1': 'String'}]) - def test_insert_1(self): - node = self.helper.parse('INSERT Person X : X name "toto", X work_for Y WHERE Y name "logilab"') - self.helper.compute_solutions(node, debug=DEBUG) - sols = sorted(node.solutions) - self.assertEqual(sols, [{'X': 'Person', 'Y': 'Company'}]) - def test_relation_eid(self): node = self.helper.parse('Any E2 WHERE E2 work_for E1, E2 eid 2') self.helper.compute_solutions(node, debug=DEBUG) @@ -348,6 +340,25 @@ self.assertEqual(node.children[0].with_[0].query.children[0].solutions, [{'X': 'Person'}]) self.assertEqual(node.children[0].solutions, [{'X': 'Person', 'Y': 'Person', 'L': 'Address'}]) + + def test_insert(self): + node = self.helper.parse('INSERT Person X : X name "toto", X work_for Y WHERE Y name "logilab"') + self.helper.compute_solutions(node, debug=DEBUG) + sols = sorted(node.solutions) + self.assertEqual(sols, [{'X': 'Person', 'Y': 'Company'}]) + + def test_delete(self): + node = self.helper.parse('DELETE Person X WHERE X name "toto", X work_for Y') + self.helper.compute_solutions(node, debug=DEBUG) + sols = sorted(node.solutions) + self.assertEqual(sols, [{'X': 'Person', 'Y': 'Company'}]) + + def test_set(self): + node = self.helper.parse('SET X name "toto", X work_for Y WHERE Y name "logilab"') + self.helper.compute_solutions(node, debug=DEBUG) + sols = sorted(node.solutions) + self.assertEqual(sols, [{'X': 'Person', 'Y': 'Company'}]) + def test_nongrer_not_u_ownedby_u(self): node = self.helper.parse('Any U WHERE NOT U owned_by U') diff --git a/test/unittest_nodes.py b/test/unittest_nodes.py --- a/test/unittest_nodes.py +++ b/test/unittest_nodes.py @@ -238,6 +238,9 @@ tree = self._parse("Any X WHERE X name 1.0 LIMIT 10 OFFSET 10") self.assertEqual(tree.limit, 10) self.assertEqual(tree.offset, 10) + + def test_exists(self): + tree = self._simpleparse("Any X,N WHERE X is Person, X name N, EXISTS(X work_for Y)") def test_copy(self): tree = self._parse("Any X,LOWER(Y) GROUPBY N ORDERBY N WHERE X is Person, X name N, X date >= TODAY") @@ -254,7 +257,7 @@ annotator.annotate(tree) self.assertEquals(tree.defined_vars['X'].selected_index(), 0) self.assertEquals(tree.defined_vars['N'].selected_index(), None) - + # insertion tests ######################################################### def test_insert_base_1(self):