# HG changeset patch # User sylvain.thenault@logilab.fr # Date 1236864565 -3600 # Thu Mar 12 14:29:25 2009 +0100 # Node ID 162274178556e46f95b385c4a3dab7b6200b3cef # Parent 5761e339a2183c5691b44932d76a927cf033cd25 consider subquery in variables graph diff --git a/stcheck.py b/stcheck.py --- a/stcheck.py +++ b/stcheck.py @@ -17,6 +17,13 @@ from rql.stmts import Union +def _var_graphid(subvarname, trmap, select): + try: + return trmap[subvarname] + except KeyError: + return subvarname + str(id(select)) + + class GoTo(Exception): """Exception used to control the visit of the tree.""" def __init__(self, node): @@ -89,6 +96,7 @@ def visit_select(self, node, errors): node.vargraph = {} # graph representing links between variable + node.aggregated = set() self._visit_selectedterm(node, errors) def leave_select(self, node, errors): @@ -104,10 +112,9 @@ errors.append('variable %s should be grouped' % var) for group in node.groupby: self._check_selected(group, 'group', errors) - if node.distinct: + if node.distinct and node.orderby: # check that variables referenced in the given term are reachable from # a selected variable with only ?1 cardinalityselected - graph = node.vargraph selectidx = frozenset(vref.name for term in selected for vref in term.iget_nodes(VariableRef)) schema = self.schema for sortterm in node.orderby: @@ -115,15 +122,19 @@ if vref.name in selectidx: continue for vname in selectidx: - if self.has_unique_value_path(graph, vname, vref.name): - break + try: + if self.has_unique_value_path(node, vname, vref.name): + break + except KeyError: + continue # unlinked variable (usually from a subquery) else: msg = ('can\'t sort on variable %s which is linked to a' ' variable in the selection but may have different' ' values for a resulting row') errors.append(msg % vref.name) - def has_unique_value_path(self, graph, fromvar, tovar): + def has_unique_value_path(self, select, fromvar, tovar): + graph = select.vargraph path = has_path(graph, fromvar, tovar) if path is None: return False @@ -136,7 +147,9 @@ cardidx = 1 rschema = self.schema.rschema(rtype) for rdef in rschema.iter_rdefs(): - if not rschema.rproperty(rdef[0], rdef[1], 'cardinality')[cardidx] in '?1': + # XXX aggregats handling needs much probably some enhancements... + if not (tovar in select.aggregated + or rschema.rproperty(rdef[0], rdef[1], 'cardinality')[cardidx] in '?1'): return False fromvar = tovar return True @@ -166,8 +179,31 @@ def visit_subquery(self, node, errors): pass + def leave_subquery(self, node, errors): - pass + # copy graph information we're interested in + pgraph = node.parent.vargraph + for select in node.query.children: + # map subquery variable names to outer query variable names + trmap = {} + for i, vref in enumerate(node.aliases): + subvref = select.selection[i] + if isinstance(subvref, VariableRef): + trmap[subvref.name] = vref.name + elif (isinstance(subvref, Function) and subvref.descr().aggregat + and len(subvref.children) == 1 + and isinstance(subvref.children[0], VariableRef)): + # XXX ok for MIN, MAX, but what about COUNT, AVG... + trmap[subvref.children[0].name] = vref.name + node.parent.aggregated.add(vref.name) + for key, val in select.vargraph.iteritems(): + if isinstance(key, tuple): + key = (_var_graphid(key[0], trmap, select), + _var_graphid(key[1], trmap, select)) + pgraph[key] = val + else: + values = pgraph.setdefault(_var_graphid(key, trmap, select), []) + values += [_var_graphid(v, trmap, select) for v in val] def visit_sortterm(self, sortterm, errors): term = sortterm.term diff --git a/test/unittest_stcheck.py b/test/unittest_stcheck.py --- a/test/unittest_stcheck.py +++ b/test/unittest_stcheck.py @@ -152,6 +152,22 @@ ): yield self._test_rewrite, rql, expected + def test_subquery_graphdict(self): + # test two things: + # * we get graph information from subquery + # * we see that we can sort on VCS (eg we have a unique value path from VF to VCD) + rqlst = self.parse(('DISTINCT Any VF ORDERBY VCD DESC WHERE ' + 'VC work_for S, S name "draft" ' + 'WITH VF, VC, VCD BEING (Any VF, MAX(VC), VCD GROUPBY VF, VCD ' + ' WHERE VC connait VF, VC creation_date VCD)')) + self.assertEquals(rqlst.children[0].vargraph, + {'VCD': ['VC'], 'VF': ['VC'], 'S': ['VC'], 'VC': ['S', 'VF', 'VCD'], + ('VC', 'S'): 'work_for', + ('VC', 'VF'): 'connait', + ('VC', 'VCD'): 'creation_date'}) + self.assertEquals(rqlst.children[0].aggregated, set(('VC',))) + + ## def test_rewriten_as_string(self): ## rqlst = self.parse('Any X WHERE X eid 12') ## self.assertEquals(rqlst.as_string(), 'Any X WHERE X eid 12')