stmts.py 48 KB
Newer Older
1
# copyright 2004-2021 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of rql.
#
# rql is free software: you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option)
# any later version.
#
# rql is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with rql. If not, see <http://www.gnu.org/licenses/>.
18
19
"""Construction and manipulation of RQL syntax trees.

This module defines only first level nodes (i.e. statements). Child nodes are
defined in the nodes module.
"""
Sylvain Thenault's avatar
Sylvain Thenault committed
23
from copy import deepcopy
24
from warnings import warn
Sylvain Thenault's avatar
Sylvain Thenault committed
25

26
from logilab.common.deprecation import callable_deprecated
Nicolas Chauvat's avatar
Nicolas Chauvat committed
27

28
from rql import BadRQLQuery, CoercionError, nodes
29
from rql.base import BaseNode, Node
30
from rql.utils import rqlvar_maker
31
from rql import rqltypes as rt
Nicolas Chauvat's avatar
Nicolas Chauvat committed
32

33
from typing import (
34
35
36
37
38
39
40
41
42
43
44
    TYPE_CHECKING,
    Dict,
    List,
    Union as Union_,
    Any,
    Iterable,
    Optional,
    cast,
    Set as Set_,
    Tuple,
    Iterator,
45
)
46

47
_MARKER: object = object()
48

49
__docformat__: str = "restructuredtext en"
50

51
52
53
if TYPE_CHECKING:
    import rql

54
55
56
Solution = Dict[str, str]
SolutionsList = List[Solution]

57
58

def _check_references(
    defined: Dict[str, Union_["rql.nodes.Variable", "rql.nodes.ColumnAlias"]],
    varrefs: Iterable[Union_["rql.nodes.VariableRef", "rql.base.BaseNode"]],
) -> bool:
    """Check that variable definitions and tree references agree (test helper).

    :param defined: mapping of names to objects exposing ``references()``
    :param varrefs: the references actually found in the tree
    :return: ``True`` when both views are consistent
    :raises AssertionError: when a definition lists a reference missing from
      `varrefs`, or when `varrefs` holds a reference no definition lists
    """
    # `varrefs` is walked more than once, so materialise it: a one-shot
    # iterator would otherwise look empty after the first pass
    varrefs = list(varrefs)
    # be careful, Variable and VariableRef define __cmp__: compare by identity
    # only. `varrefs` keeps every object alive, so ids are stable here.
    in_tree = {id(vref) for vref in varrefs}
    referenced = set()
    for var in defined.values():
        for vref in var.references():
            if id(vref) not in in_tree:
                raise AssertionError("vref %r is not in the tree" % vref)
            referenced.add(id(vref))
    for vref in varrefs:
        if id(vref) not in referenced:
            raise AssertionError("vref %r is not referenced (%r)" % (vref, vref.stmt))
    return True
73

74

75
class undo_modification:
    """context manager saving the state of `select` on entry and recovering
    it on exit
    """

    def __init__(self, select):
        self.select = select

    def __enter__(self):
        self.select.save_state()

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        # the `with` statement always passes the three exception arguments;
        # they default to None so a direct `__exit__()` call keeps working.
        # Nothing truthy is returned, hence exceptions raised in the body
        # propagate once the state has been recovered.
        self.select.recover()
84

85

86
class ScopeNode(BaseNode):
87
    def __init__(self):
        """initialise the empty containers and default clauses of the scope"""
        # clause nodes
        self.where: Optional["rql.base.Node"] = None  # where clause node
        self.having: Iterable["rql.base.Node"] = ()  # XXX now a single node
        self.with_: List["rql.nodes.SubQuery"] = []
        # dictionary of defined variables in the original RQL syntax tree
        self.defined_vars: Dict[str, "rql.nodes.Variable"] = {}
        # declared here so that every ScopeNode has an `aliases` attribute
        self.aliases: Dict[str, "rql.nodes.ColumnAlias"] = {}
        # list of possible solutions for used variables
        self.solutions: SolutionsList = []
        # variable names generator, built when necessary
        self._varmaker = None
        # declared here so that every ScopeNode has a `schema` attribute
        self.schema: Optional[Any] = None

    # "ScopeNode" has no attribute "undo_manager"
102
103
104
105
106
107
    @property
    def undo_manager(self):
        """the selection manager of this node, created on first access and
        then kept on the `_undo_manager` attribute
        """
        if not hasattr(self, "_undo_manager"):
            # imported lazily, only when a manager actually has to be built
            from rql.undo import SelectionManager

            self._undo_manager = SelectionManager(self)
        return self._undo_manager
111

112
113
114
115
    @property
    def should_register_op(self):
        """always None here: a bare scope node never registers undo
        operations
        """
        return None

116
117
    def get_selected_variables(self):
        """return whatever `selected_terms()` yields for this node"""
        selected = self.selected_terms()
        return selected
118

119
    def set_where(self, node: "rql.base.Node") -> None:
        """install `node` as the where clause and make this scope its parent"""
        node.parent = self
        self.where = node
122

123
    def set_having(self, terms: Iterable["rql.base.Node"]) -> None:
        """install `terms` as the having clause and make this scope the parent
        of each of them

        When `should_register_op` is true, the previous having value is first
        recorded on the undo manager.
        """
        if self.should_register_op:
            from rql.undo import SetHavingOperation

            previous = self.having
            self.undo_manager.add_operation(SetHavingOperation(self, previous))
        self.having = terms
        for term in terms:
            term.parent = self

134
135
    # Signature of "copy" incompatible with supertype "BaseNode"
    # either add a parameter that won't be used, either ignore
136
137
138
139
    def copy(
        self,
        stmt: Optional["rql.stmts.Statement"] = None,
        copy_solutions: bool = True,
        solutions: Optional[SolutionsList] = None,
    ) -> "rql.base.BaseNode":
        """return a new node of the same class carrying over schema and
        solutions

        :param stmt: not used by this implementation
        :param copy_solutions: when true and no explicit `solutions` is given,
          deep-copy the solutions of this node onto the new one
        :param solutions: solutions to set as-is on the new node; takes
          precedence over `copy_solutions`
        """
        new = self.__class__()
        if self.schema is not None:
            new.schema = self.schema
        if solutions is not None:
            # explicitly given solutions are shared, not copied
            new.solutions = solutions
        elif copy_solutions and self.solutions:
            new.solutions = deepcopy(self.solutions)
        return new
150

Nicolas Chauvat's avatar
Nicolas Chauvat committed
151
    # construction helper methods #############################################
152
    def get_etype(self, name: str) -> "rql.nodes.Constant":
        """return an "etype" constant node for the given entity's type name

        NOTE(review): this implementation performs no lookup and raises
        nothing itself for unknown type names.
        """
        etype_node = nodes.Constant(name, "etype")
        return etype_node
158

159
160
161
    def get_variable(
        self, name: str
    ) -> Union_["rql.nodes.Variable", "rql.nodes.ColumnAlias"]:
        """get a variable instance from its name

        the variable is created if it doesn't exist yet, attached to this
        node through its `stmt` attribute and stored in `defined_vars`
        """
        try:
            return self.defined_vars[name]
        except KeyError:
            # only a missing name means "create it"; any other error is a
            # real problem and must not be hidden
            var = nodes.Variable(name)
            var.stmt = self
            self.defined_vars[name] = var
            return var
Sylvain's avatar
Sylvain committed
172

173
    def allocate_varname(self) -> str:
        """return the next name produced by the variable names generator

        The generator is built on first use from the defined variables and,
        when the node has some, its aliases.
        """
        maker = self._varmaker
        if maker is None:
            # XXX aliases only on Select node
            known_aliases = getattr(self, "aliases", None)
            maker = rqlvar_maker(defined=self.defined_vars, aliases=known_aliases)
            self._varmaker = maker
        return next(maker)
182

183
    def make_variable(self) -> "rql.nodes.Variable":
        """create a new variable, named by `allocate_varname()`

        The creation is recorded on the undo manager when
        `should_register_op` is true.
        """
        fresh_name = self.allocate_varname()
        var = self.get_variable(fresh_name)
        if self.should_register_op:
            from rql.undo import MakeVarOperation

            operation = MakeVarOperation(var)
            self.undo_manager.add_operation(operation)
        return cast("rql.nodes.Variable", var)
191

192
193
    def set_possible_types(
        self,
        solutions: SolutionsList,
        kwargs: Optional[Union_[object, Dict[str, str]]] = _MARKER,
        key: str = "possibletypes",
    ) -> None:
        """store on each defined variable the set of types it takes in
        `solutions`

        The set is written to ``var.stinfo[key]``. With the default `key`,
        `solutions` also becomes the solutions of this node. `kwargs` is not
        used by this implementation.
        """
        if key == "possibletypes":
            self.solutions = solutions
        for var in self.defined_vars.values():
            types = set()
            var.stinfo[key] = types
            types.update(solution[var.name] for solution in solutions)
        # for debugging
        # for sol in solutions:
        #    for vname in sol:
        #        assert vname in self.defined_vars or vname in self.aliases
209

210
    def check_references(self) -> bool:
        """test function: check the variable references of this scope

        The scope is printed before re-raising when the check fails.
        """
        try:
            known = cast(
                Dict[str, Union_["rql.nodes.ColumnAlias", "rql.nodes.Variable"]],
                self.aliases.copy(),
            )
        except AttributeError:
            # no usable aliases: only regular variables are defined
            known = cast(
                Dict[str, Union_["rql.nodes.ColumnAlias", "rql.nodes.Variable"]],
                self.defined_vars.copy(),
            )
        else:
            known.update(self.defined_vars)
            for subquery in self.with_:
                subquery.query.check_references()
        # only consider references belonging to this very scope
        own_refs = []
        for vref in self.get_nodes(nodes.VariableRef):
            if vref.stmt is self:
                own_refs.append(vref)
        try:
            _check_references(known, own_refs)
        except Exception:
            print(repr(self))
            raise
        return True
236

237

238
class Statement:
    """base class for statement nodes"""

    # default values for optional instance attributes, set on the instance when
    # used
    annotated: bool = False  # set by the annotator
    schema: Optional["rql.interfaces.ISchema"] = None

    if TYPE_CHECKING:

        def get_variable(self, name, column=None):
            raise NotImplementedError()

    # navigation helper methods #############################################

    @property
    def scope(self):
        """a statement is its own scope"""
        return self

    @property
    def stmt(self):
        """a statement is its own statement"""
        return self

    @property
    def root(self):
        """return the root node of the tree"""
        return self

    def neged(
        self, traverse_scope: bool = False, _fromnode: Optional["rql.nodes.Or"] = None
    ) -> Optional["rql.nodes.Not"]:
        """always None for a statement"""
        return None

    def ored(
        self, traverse_scope: bool = False, _fromnode: Optional["rql.nodes.And"] = None
    ) -> Optional["rql.nodes.Or"]:
        """always None for a statement"""
        return None

Nicolas Chauvat's avatar
Nicolas Chauvat committed
276

277
class Union(Statement, Node):
Sylvain's avatar
Sylvain committed
278
279
280
    """the select node is the root of the syntax tree for selection statement
    using UNION
    """
281
282

    TYPE: str = "select"
283
284
    # default values for optional instance attributes, set on the instance when
    # used
285
    undoing: bool = False  # used to prevent from memorizing when undoing !
286
    memorizing: int = 0  # recoverable modification attributes
287
    children: List["rql.stmts.Select"]
Sylvain's avatar
Sylvain committed
288

289
    def wrap_selects(self) -> None:
        """move every select of this union into a sub-query of a single new
        select, which becomes the only child of this union

        The new select gets one fresh variable per selected term; the count
        is taken from the last select moved.
        """
        inner = Union()
        for select in self.children[:]:
            inner.append(select)
            self.remove_select(select)
        outer: "rql.stmts.Select" = Select()
        # `select` is the last child moved above
        aliases: List["rql.nodes.VariableRef"] = [
            nodes.VariableRef(outer.make_variable())
            for _ in range(len(select.selection))
        ]
        outer.add_subquery(nodes.SubQuery(aliases, inner), check=False)
        for alias_ref in aliases:
            outer.append_selected(nodes.VariableRef(alias_ref.variable))
        self.append_select(outer)
303

304
    def _get_offset(self) -> int:
        """deprecated accessor: offset of the last child select"""
        warn("offset is now a Select node attribute", DeprecationWarning, stacklevel=2)
        return self.children[-1].offset
308

309
    def set_offset(self, offset: int) -> None:
        """set the offset of the whole union

        With a single child, the offset is set on that select. Otherwise the
        selects are first wrapped into one new select, which receives it.
        """
        if len(self.children) == 1:
            self.children[-1].set_offset(offset)
            # stop here: the tree must not be wrapped when there is a
            # single select (same as set_limit)
            return None
        # we have to introduce a new root
        # XXX not undoable since a new root has to be introduced
        self.wrap_selects()
        self.children[0].set_offset(offset)
        return None
318

319
    offset = property(_get_offset, set_offset)
320

Sylvain Thenault's avatar
Sylvain Thenault committed
321
    def _get_limit(self):
        """deprecated accessor: limit of the last child select"""
        warn("limit is now a Select node attribute", DeprecationWarning, stacklevel=2)
        last_child = self.children[-1]
        return last_child.limit
324

325
    def set_limit(self, limit: int) -> None:
        """set the limit of the whole union

        With a single child, the limit is set on that select. Otherwise the
        selects are first wrapped into one new select, which receives it.
        """
        if len(self.children) != 1:
            self.wrap_selects()
            target = self.children[0]
        else:
            target = self.children[-1]
        target.set_limit(limit)
        return None

333
    limit = property(_get_limit, set_limit)
334
335

    @property
    def root(self):
        """return the root node of the tree: this union when it has no
        parent, the root of its parent otherwise
        """
        parent = self.parent
        return self if parent is None else parent.root
341

342
    def get_description(
        self,
        mainindex: Optional[int] = None,
        tr: Optional[rt.TranslationFunction] = None,
    ) -> List[List[str]]:
        """return the description of each child select, in order

        `mainindex`:
          selection index to consider as main column, useful to get smarter
          results
        `tr`:
          optional translation function taking a string as argument and
          returning a string; messages are left untouched when omitted
        """
        if tr is None:

            def tr(msg, context=None):
                return msg

        descriptions = []
        for select in self.children:
            descriptions.append(select.get_description(mainindex, tr))
        return descriptions
Sylvain's avatar
Sylvain committed
361

362
    # repr / as_string / copy #################################################
363

364
    def __repr__(self) -> str:
        """repr of every child select, separated by UNION lines"""
        parts = [repr(select) for select in self.children]
        return "\nUNION\n".join(parts)
366

367
    def as_string(self, kwargs: Optional[Dict] = None) -> str:
        """return the tree as an encoded rql string

        A single select is returned bare; several selects are each put in
        parentheses and joined with UNION.
        """
        rendered: List[str] = []
        for select in self.children:
            rendered.append(select.as_string(kwargs=kwargs))
        if len(rendered) == 1:
            (only,) = rendered
            return only
        return " UNION ".join(["(%s)" % part for part in rendered])
375

376
377
378
    def copy(
        self, stmt: Optional["rql.stmts.Statement"] = None, copy_children: bool = True
    ) -> "rql.stmts.Union":
        """return a new union sharing the schema of this one

        Children are copied and appended to the new union unless
        `copy_children` is false. `stmt` is not used by this implementation.
        """
        duplicate: "rql.stmts.Union" = Union()
        if self.schema is not None:
            duplicate.schema = self.schema
        if not copy_children:
            return duplicate
        for child in self.children:
            duplicate.append(child.copy())
            # appending must have re-parented the copy
            assert duplicate.children[-1].parent is duplicate
        return duplicate
Nicolas Chauvat's avatar
Nicolas Chauvat committed
387

388
    # union specific methods ##################################################
389

390
    # XXX for bw compat, should now use get_variable_indices (cw > 3.8.4)
391
392
    def get_variable_variables(self) -> Set_[str]:
        """return the names of the variables referenced, in the first select,
        by the selected terms whose index is in `get_variable_indices()`
        """
        change: Set_[str] = set()
        for idx in self.get_variable_indices():
            # only the first child is inspected
            term = self.children[0].selection[idx]
            change.update(vref.name for vref in term.iget_nodes(nodes.VariableRef))
        return change

401
    def get_variable_indices(self) -> Set_[int]:
        """return the set of selection indexes which take different types
        according to the solutions
        """
        types_by_index: Dict[int, Set_] = {}
        for select in self.children:
            for descr in select.get_selection_solutions():
                for index, etype in enumerate(descr):
                    if index not in types_by_index:
                        types_by_index[index] = set()
                    types_by_index[index].add(etype)
        # an index varies as soon as more than one type was seen for it
        return {index for index, etypes in types_by_index.items() if len(etypes) > 1}
415

416
417
418
    def _locate_subquery(
        self, col: int, etype: str, kwargs: Optional[Dict[Any, Any]] = None
    ) -> Tuple:
        """return the (select, column) pair where column `col` has type
        `etype`, following aliased sub-queries

        Raises a plain Exception when no child select matches.
        """
        first = self.children[0]
        single_select = len(self.children) == 1
        no_subquery = not first.with_
        if single_select and no_subquery:
            # nothing to search: the only select is the answer
            return self.children[0], col
        for select in self.children:
            term = select.selection[col]
            try:
                if term.name in select.aliases:
                    column_alias = select.aliases[term.name]
                    return column_alias.query._locate_subquery(
                        column_alias.colnum, etype, kwargs
                    )
            except AttributeError:
                # term has no 'name' attribute
                pass
            if any(
                term.get_type(solution, kwargs) == etype
                for solution in select.solutions
            ):
                return select, col
        raise Exception(f"internal error, {etype} not found on col {col}")
437

438
439
440
    def locate_subquery(
        self, col: int, etype: str, kwargs: Optional[Dict] = None
    ) -> Any:
        """return a select node and associated selection index where root
        variable at column `col` is of type `etype`

        Results are cached per (col, etype) on the `_subq_cache` attribute,
        which is created on first use.
        """
        if not hasattr(self, "_subq_cache"):
            self._subq_cache = {}
        cache_key = (col, etype)
        if cache_key not in self._subq_cache:
            self._subq_cache[cache_key] = self._locate_subquery(col, etype, kwargs)
        return self._subq_cache[cache_key]
453

454
    def subquery_selection_index(self, subselect: Any, col: int) -> int:
        """given a select sub-query and a column index in the root query, return
        the selection index for this column in the sub-query
        """
        # climb from the sub-select up to the outermost select, remembering
        # every enclosing select met on the way (innermost first)
        enclosing: List = []
        current = subselect
        while current.parent.parent is not None:
            current = current.parent.parent.parent
            enclosing.append(current)
        # then translate the column going back down, outermost select first
        for select in reversed(enclosing):
            col = select.selection[col].variable.colnum
        return col
466

467
    # recoverable modification methods ########################################
Sylvain's avatar
explain    
Sylvain committed
468
469
470

    # don't use @cached: we want to be able to disable it while this must still
    # be cached
471
    @property
    def undo_manager(self) -> "rql.undo.SelectionManager":
        """the selection manager of this union, built on first access and
        kept on the `_undo_manager` attribute
        """
        from rql.undo import SelectionManager

        existing = getattr(self, "_undo_manager", None)
        if not existing:
            self._undo_manager = SelectionManager(self)
            return self._undo_manager
        return existing
480

481
482
483
484
    @property
    def should_register_op(self):
        """true when modifications are being memorized and no undo is in
        progress
        """
        memorizing = self.memorizing
        if not memorizing:
            # falsy value returned as-is
            return memorizing
        return not self.undoing

485
    def undo_modification(self) -> "rql.stmts.undo_modification":
        """return an `undo_modification` context manager bound to this union"""
        manager = undo_modification(self)
        return manager

488
    def save_state(self) -> None:
        """save the current tree: push a state on the undo manager and
        increase the memorizing counter
        """
        manager = self.undo_manager
        manager.push_state()
        self.memorizing = self.memorizing + 1

493
    def recover(self) -> None:
        """reverts the tree as it was when save_state() was last called"""
        remaining = self.memorizing - 1
        self.memorizing = remaining
        # more recover() than save_state() calls is a programming error
        assert remaining >= 0
        self.undo_manager.recover()
498

499
    def check_references(self) -> bool:
        """test function: run the references check of every child select"""
        children = self.children
        for child in children:
            child.check_references()
        return True

505
    def append_select(self, select: "rql.stmts.Select") -> None:
        """add `select` at the end of the children, recording the operation
        on the undo manager when `should_register_op` is true
        """
        if self.should_register_op:
            from rql.undo import AppendSelectOperation

            operation = AppendSelectOperation(self, select)
            self.undo_manager.add_operation(operation)
        self.children.append(select)
511

512
    def remove_select(self, select: "rql.stmts.Select") -> None:
        """remove `select` from the children, recording the operation (with
        the position it had) on the undo manager when `should_register_op`
        is true
        """
        position: int = self.children.index(select)
        if self.should_register_op:
            from rql.undo import RemoveSelectOperation

            operation = RemoveSelectOperation(self, select, position)
            self.undo_manager.add_operation(operation)
        del self.children[position]

520

521
class Select(Statement, nodes.EditableMixIn, ScopeNode):
522
523
524
    """the select node is the base statement of the syntax tree for selection
    statement, always child of a UNION root.
    """
525

526
    vargraph: rt.Graph = {}
527
    parent = None
528
    distinct: bool = False
529
    # limit / offset
530
531
532
533
534
    limit: Optional[int] = None
    offset: int = 0
    # already defined inside ScopeNode right?
    # But py3 fails when I change anything here
    # RecursionError: maximum recursion depth excedeed
535
    # set by the annotator
536
    has_aggregat: bool = False
537

538
539
    def __init__(self):
        """initialise both bases, then the select specific containers"""
        Statement.__init__(self)
        ScopeNode.__init__(self)
        # select clauses
        self.selection: List = []
        self.groupby: List[Any] = []
        self.orderby: List[Any] = []
        # subqueries alias
        self.aliases: Dict[str, "rql.nodes.ColumnAlias"] = {}
        # syntax tree meta-information
        self.stinfo: Dict[str, Dict] = {"rewritten": {}}

551
    @property
    def root(self):
        """return the root node of the tree, taken to be the parent of this
        select
        """
        parent = self.parent
        return parent
555

556
557
558
559
560
    def get_description(
        self,
        mainindex: Optional[int] = None,
        tr: Optional[rt.TranslationFunction] = None,
    ) -> List[str]:
        """return the list of types or relations (if not found) associated to
        selected variables.
        mainindex is an optional selection index which should be considered as
        'pivot' entity.

        "Any" is used for a term whose description is empty or raises
        CoercionError.
        """
        descr: List[str] = []
        for term in self.selection:
            try:
                label = term.get_description(mainindex, tr) or "Any"
            except CoercionError:
                label = "Any"
            # don't translate Any
            descr.append(label)
        return descr

    @property
    def children(self):
        """a new list with the nodes of every clause: selection, groupby,
        orderby, where, having and with, in that order
        """
        collected = list(self.selection)
        for terms in (self.groupby, self.orderby):
            if terms:
                collected.extend(terms)
        if self.where:
            collected.append(self.where)
        for terms in (self.having, self.with_):
            if terms:
                collected.extend(terms)
        return collected

    # repr / as_string / copy #################################################
591

592
    def __repr__(self) -> str:
        """same as `as_string()`, with every node rendered through repr"""
        text = self.as_string(userepr=True)
        return text
594

595
    def as_string(self, kwargs: Optional[Dict] = None, userepr: bool = False) -> str:
        """return the tree as an encoded rql string

        Nodes are rendered with repr when `userepr` is true, with their own
        `as_string(kwargs=kwargs)` otherwise.
        """
        if userepr:
            render = repr
        else:

            def render(node):
                return node.as_string(kwargs=kwargs)

        def clause(keyword, terms):
            # `keyword` already carries its trailing space
            return keyword + ",".join(render(term) for term in terms)

        parts = [clause("", self.selection)]
        if self.groupby:
            parts.append(clause("GROUPBY ", self.groupby))
        if self.orderby:
            parts.append(clause("ORDERBY ", self.orderby))
        if self.limit is not None:
            parts.append("LIMIT %s" % self.limit)
        if self.offset:
            parts.append("OFFSET %s" % self.offset)
        if self.where is not None:
            parts.append("WHERE " + render(self.where))
        if self.having:
            parts.append(clause("HAVING ", self.having))
        if self.with_:
            parts.append(clause("WITH ", self.with_))
        body = " ".join(parts)
        if self.distinct:
            return "DISTINCT Any " + body
        return "Any " + body

    def copy(
        self,
        stmt: Optional["rql.stmts.Statement"] = None,
        copy_solutions: bool = True,
        solutions: Optional[SolutionsList] = None,
    ) -> "rql.stmts.Select":
        """return a copy of this select

        Every clause node is copied for the new select; `distinct`, `limit`
        and `offset` are carried over and `vargraph` is shared, not copied.
        `copy_solutions` and `solutions` are handed to the base class copy.
        """
        # the base class copy returns a node of this very class
        new = cast("rql.stmts.Select", super().copy(self, copy_solutions, solutions))

        if self.with_:
            subqueries = [subquery.copy(new) for subquery in self.with_]
            new.set_with(subqueries, check=False)
        for term in self.selection:
            new.append_selected(term.copy(new))
        if self.groupby:
            new.set_groupby([term.copy(new) for term in self.groupby])
        if self.orderby:
            new.set_orderby([term.copy(new) for term in self.orderby])
        if self.where:
            new.set_where(self.where.copy(new))
        if self.having:
            new.set_having([term.copy(new) for term in self.having])

        new.distinct = self.distinct
        new.limit = self.limit
        new.offset = self.offset
        new.vargraph = self.vargraph
        return new
653

654
    # select specific methods #################################################
655

656
657
    def set_possible_types(
        self,
        solutions: SolutionsList,
        kwargs: Optional[Union_[object, Dict[str, str]]] = _MARKER,
        key: str = "possibletypes",
    ) -> None:
        """set the possible types of variables, then of column aliases

        When `kwargs` is given, the types are also propagated to the selects
        of each aliased sub-query: a select left without any compatible
        solution is removed, the others are restricted to their compatible
        solutions.
        """
        super(Select, self).set_possible_types(solutions, kwargs, key)
        for ca in self.aliases.values():
            alias_types = set()
            ca.stinfo[key] = alias_types
            alias_types.update(solution[ca.name] for solution in solutions)
            if kwargs is _MARKER:
                continue
            # propagate to subqueries in case we're introducing additional
            # type constraints; iterate on a copy since selects may be removed
            for stmt in list(ca.query.children):
                term: Any = stmt.selection[ca.colnum]
                compatible: List = []
                for sol in stmt.solutions:
                    if term.get_type(sol, kwargs) in alias_types:
                        compatible.append(sol)
                if compatible:
                    stmt.set_possible_types(compatible)
                else:
                    ca.query.remove_select(stmt)
681

682
    def set_statement_type(self, etype: str) -> None:
        """set the statement type for this selection
        this method must be called last (i.e. once selected variables have
        been added)

        Nothing is done for "Any"; otherwise a type restriction is added for
        every selected variable and BadRQLQuery is raised when there is none.
        """
        assert self.selection
        if etype == "Any":
            return
        # Person P  ->  Any P where P is Person
        variables: List["rql.nodes.VariableRef"] = list(self.get_selected_variables())
        if not variables:
            raise BadRQLQuery(
                "Setting type in selection is only allowed "
                "when some variable is selected"
            )
        for vref in variables:
            self.add_type_restriction(vref.variable, etype)
700

701
    def set_distinct(self, value: bool) -> None:
        """mark DISTINCT query

        A change of value is recorded on the undo manager when
        `should_register_op` is true.
        """
        if self.should_register_op:
            if value != self.distinct:
                from rql.undo import SetDistinctOperation

                previous = self.distinct
                self.undo_manager.add_operation(SetDistinctOperation(previous, self))
        self.distinct = value
708

709
    def set_limit(self, limit: int) -> None:
        """set the limit of this select

        None and strictly positive integers are accepted, anything else
        raises BadRQLQuery. A change of value is recorded on the undo manager
        when `should_register_op` is true.
        """
        if limit is not None:
            if not isinstance(limit, int) or limit <= 0:
                raise BadRQLQuery("bad limit %s" % limit)
        if self.should_register_op:
            if limit != self.limit:
                from rql.undo import SetLimitOperation

                previous = self.limit
                self.undo_manager.add_operation(SetLimitOperation(previous, self))
        self.limit = limit

718
    def set_offset(self, offset: int) -> None:
        """set the offset of this select

        None and non negative integers are accepted, anything else raises
        BadRQLQuery. A change of value is recorded on the undo manager when
        `should_register_op` is true.
        """
        if offset is not None:
            if not isinstance(offset, int) or offset < 0:
                raise BadRQLQuery("bad offset %s" % offset)
        if self.should_register_op:
            if offset != self.offset:
                from rql.undo import SetOffsetOperation

                previous = self.offset
                self.undo_manager.add_operation(SetOffsetOperation(previous, self))
        self.offset = offset
726

727
    def set_orderby(self, terms: List["rql.nodes.SortTerm"]) -> None:
        """install `terms` as the orderby clause and make this select the
        parent of each of them
        """
        self.orderby = terms
        for sort_term in terms:
            sort_term.parent = self

732
733
734
    def set_groupby(
        self, terms: List[Union_["rql.nodes.Function", "rql.nodes.VariableRef"]]
    ) -> None:
        """install `terms` as the groupby clause and make this select the
        parent of each of them
        """
        self.groupby = terms
        for group_term in terms:
            group_term.parent = self

739
    def set_with(self, terms: List["rql.nodes.SubQuery"], check: bool = True) -> None:
        """reset the sub-queries list, then add each of `terms` through
        `add_subquery()`, handing `check` over
        """
        self.with_ = []
        for subquery in terms:
            self.add_subquery(subquery, check)
743

744
    def add_subquery(self, node: "rql.nodes.SubQuery", check: bool = True) -> None:
        """attach the sub-query `node` to this select and bind each of its
        aliases to the sub-query

        When `check` is true, BadRQLQuery is raised if the number of aliases
        differs from the number of terms selected by the first select of the
        sub-query, or if an alias name is already used.
        """
        assert node.query
        node.parent = self
        self.with_.append(node)
        if check:
            first_select = cast("rql.stmts.Select", node.query.children[0])
            if len(node.aliases) != len(first_select.selection):
                raise BadRQLQuery(
                    "Should have the same number of aliases than "
                    "selected terms in sub-query"
                )
        for colnum, alias in enumerate(node.aliases):
            if check and alias.name in self.aliases:
                raise BadRQLQuery("Duplicated alias %s" % alias)
            # a column number is given, hence a column alias is expected back
            column_alias = cast(
                "rql.nodes.ColumnAlias", self.get_variable(alias.name, colnum)
            )
            column_alias.query = node.query
767

    def remove_subquery(self, node: "rql.nodes.SubQuery") -> None:
        """detach the sub-query `node` from this select and forget each of
        its aliases
        """
        self.with_.remove(node)
        node.parent = None
        for alias in node.aliases:
            del self.aliases[alias.name]
772

773
    # Signature of "get_variable" incompatible with supertype "ScopeNode"  [override]
774
775
776
    def get_variable(
        self, name: str, colnum: Optional[int] = None
    ) -> Union_["rql.nodes.Variable", "rql.nodes.ColumnAlias"]:
Sylvain Thenault's avatar
Sylvain Thenault committed
777
        """get a variable instance from its name
778

779
780
        the variable is created if it doesn't exist yet
        """
781
        if name in self.aliases.keys():
782
            return self.aliases[name]
783
        if colnum is not None:  # take care, may be 0
784
            self.aliases[name] = calias = nodes.ColumnAlias(name, colnum)
785
            calias.stmt = self
Sylvain's avatar
Sylvain committed
786
787
            # alias may already have been used as a regular variable, replace it
            if name in self.defined_vars: