# HG changeset patch
# User Sylvain Thénault <sylvain.thenault@logilab.fr>
# Date 1275574304 -7200
#      Thu Jun 03 16:11:44 2010 +0200
# Node ID 9d0085e2bde2249586a083b6e016b51e8de05527
# Parent  28c79b9641985f8c230cf39fd9002941ebd5dc42
fix grammar bug in HAVING clause: should allow arbitrary expressions, + fix to deal with IN() hack

diff --git a/nodes.py b/nodes.py
--- a/nodes.py
+++ b/nodes.py
@@ -259,6 +259,10 @@
 class Not(Node):
     """a logical NOT node (unary)"""
     __slots__ = ()
+    def __init__(self, expr=None):
+        Node.__init__(self)
+        if expr is not None:
+            self.append(expr)
 
     def as_string(self, encoding=None, kwargs=None):
         if isinstance(self.children[0], (Exists, Relation)):
diff --git a/parser.g b/parser.g
--- a/parser.g
+++ b/parser.g
@@ -176,10 +176,7 @@
 rule groupby<<S>>: GROUPBY variables<<S>> {{ S.set_groupby(variables); return True }}
                  |
 
-rule having<<S>>: HAVING               {{ nodes = [] }}
-                   expr_cmp<<S>>       {{ nodes.append(expr_cmp) }}
-                   ( ',' expr_cmp<<S>> {{ nodes.append(expr_cmp) }}
-                   )*                  {{ S.set_having(nodes) }}
+rule having<<S>>: HAVING logical_expr<<S>> {{ S.set_having([logical_expr]) }}
                 |
 
 rule orderby<<S>>: ORDERBY              {{ nodes = [] }}
@@ -198,11 +195,6 @@
                     BEING r"\(" union<<Union()>> r"\)" {{ node.set_query(union); return node }}
 
 
-rule expr_cmp<<S>>: expr_add<<S>>  {{ c1 = expr_add }}
-                    CMP_OP         {{ cmp = Comparison(CMP_OP.upper(), c1) }}
-                    expr_add<<S>>  {{ cmp.append(expr_add); return cmp }}
-
-
 rule sort_term<<S>>: expr_add<<S>> sort_meth {{ return SortTerm(expr_add, sort_meth) }}
 
 
@@ -241,7 +233,7 @@
                     (  AND rels_not<<S>> {{ node = And(node, rels_not) }}
                     )*                   {{ return node }}
 
-rule rels_not<<S>>: NOT rel<<S>> {{ node = Not(); node.append(rel); return node }}
+rule rels_not<<S>>: NOT rel<<S>> {{ return Not(rel) }}
                   | rel<<S>>     {{ return rel }}
 
 rule rel<<S>>: rel_base<<S>>                {{ return rel_base }}
@@ -259,6 +251,31 @@
 rule opt_right<<S>>: QMARK  {{ return 'right' }}
                    |
 
+#// restriction expressions ####################################################
+
+rule logical_expr<<S>>: exprs_or<<S>>       {{ node = exprs_or }}
+                        ( ',' exprs_or<<S>> {{ node = And(node, exprs_or) }}
+                        )*                  {{ return node }}
+
+rule exprs_or<<S>>: exprs_and<<S>>      {{ node = exprs_and }}
+                    ( OR exprs_and<<S>> {{ node = Or(node, exprs_and) }}
+                    )*                  {{ return node }}
+
+rule exprs_and<<S>>: exprs_not<<S>>        {{ node = exprs_not }}
+                     (  AND exprs_not<<S>> {{ node = And(node, exprs_not) }}
+                     )*                    {{ return node }}
+
+rule exprs_not<<S>>: NOT balanced_expr<<S>> {{ return Not(balanced_expr) }}
+                   | balanced_expr<<S>>     {{ return balanced_expr }}
+
+rule balanced_expr<<S>>: expr_add<<S>> expr_op<<S>>       {{ expr_op.insert(0, expr_add); return expr_op }}
+                       | r"\(" logical_expr<<S>> r"\)" {{ return logical_expr }}
+
+# // cannot use expr<<S>> here without introducing some ambiguities
+rule expr_op<<S>>: CMP_OP expr_add<<S>> {{ return Comparison(CMP_OP.upper(), expr_add) }}
+                 | in_expr<<S>>      {{ return Comparison('=', in_expr) }}
+
+
 #// common statements ###########################################################
 
 rule variables<<S>>:                   {{ vars = [] }}
@@ -307,6 +324,13 @@
                    )?
                 r"\)"                 {{ return F }}
 
+rule in_expr<<S>>: 'IN' r"\("        {{ F = Function('IN') }}
+                   ( expr_add<<S>> (     {{ F.append(expr_add) }}
+                      ',' expr_add<<S>>
+                     )*                  {{ F.append(expr_add) }}
+                   )?
+                r"\)"                 {{ return F }}
+
 
 rule var<<S>>: VARIABLE {{ return VariableRef(S.get_variable(VARIABLE)) }}
 
diff --git a/parser.py b/parser.py
--- a/parser.py
+++ b/parser.py
@@ -77,6 +77,7 @@
 
 class HerculeScanner(runtime.Scanner):
     patterns = [
+        ("'IN'", re.compile('IN')),
         ("','", re.compile(',')),
         ('r"\\)"', re.compile('\\)')),
         ('r"\\("', re.compile('\\(')),
@@ -271,14 +272,8 @@
         _token = self._peek('HAVING', 'WITH', 'GROUPBY', 'ORDERBY', 'LIMIT', 'OFFSET', 'WHERE', "';'", 'r"\\)"', context=_context)
         if _token == 'HAVING':
             HAVING = self._scan('HAVING', context=_context)
-            nodes = []
-            expr_cmp = self.expr_cmp(S, _context)
-            nodes.append(expr_cmp)
-            while self._peek("','", 'WITH', 'GROUPBY', 'ORDERBY', 'LIMIT', 'OFFSET', 'WHERE', 'HAVING', "';'", 'r"\\)"', context=_context) == "','":
-                self._scan("','", context=_context)
-                expr_cmp = self.expr_cmp(S, _context)
-                nodes.append(expr_cmp)
-            S.set_having(nodes)
+            logical_expr = self.logical_expr(S, _context)
+            S.set_having([logical_expr])
         elif 1:
             pass
         else:
@@ -330,15 +325,6 @@
         self._scan('r"\\)"', context=_context)
         node.set_query(union); return node
 
-    def expr_cmp(self, S, _parent=None):
-        _context = self.Context(_parent, self._scanner, 'expr_cmp', [S])
-        expr_add = self.expr_add(S, _context)
-        c1 = expr_add
-        CMP_OP = self._scan('CMP_OP', context=_context)
-        cmp = Comparison(CMP_OP.upper(), c1)
-        expr_add = self.expr_add(S, _context)
-        cmp.append(expr_add); return cmp
-
     def sort_term(self, S, _parent=None):
         _context = self.Context(_parent, self._scanner, 'sort_term', [S])
         expr_add = self.expr_add(S, _context)
@@ -431,7 +417,7 @@
         if _token == 'NOT':
             NOT = self._scan('NOT', context=_context)
             rel = self.rel(S, _context)
-            node = Not(); node.append(rel); return node
+            return Not(rel)
         else: # in ['r"\\("', 'EXISTS', 'VARIABLE']
             rel = self.rel(S, _context)
             return rel
@@ -489,6 +475,65 @@
         else:
             pass
 
+    def logical_expr(self, S, _parent=None):
+        _context = self.Context(_parent, self._scanner, 'logical_expr', [S])
+        exprs_or = self.exprs_or(S, _context)
+        node = exprs_or
+        while self._peek("','", 'r"\\)"', 'WITH', 'GROUPBY', 'ORDERBY', 'LIMIT', 'OFFSET', 'WHERE', 'HAVING', "';'", context=_context) == "','":
+            self._scan("','", context=_context)
+            exprs_or = self.exprs_or(S, _context)
+            node = And(node, exprs_or)
+        return node
+
+    def exprs_or(self, S, _parent=None):
+        _context = self.Context(_parent, self._scanner, 'exprs_or', [S])
+        exprs_and = self.exprs_and(S, _context)
+        node = exprs_and
+        while self._peek('OR', "','", 'r"\\)"', 'WITH', 'GROUPBY', 'ORDERBY', 'LIMIT', 'OFFSET', 'WHERE', 'HAVING', "';'", context=_context) == 'OR':
+            OR = self._scan('OR', context=_context)
+            exprs_and = self.exprs_and(S, _context)
+            node = Or(node, exprs_and)
+        return node
+
+    def exprs_and(self, S, _parent=None):
+        _context = self.Context(_parent, self._scanner, 'exprs_and', [S])
+        exprs_not = self.exprs_not(S, _context)
+        node = exprs_not
+        while self._peek('AND', 'OR', "','", 'r"\\)"', 'WITH', 'GROUPBY', 'ORDERBY', 'LIMIT', 'OFFSET', 'WHERE', 'HAVING', "';'", context=_context) == 'AND':
+            AND = self._scan('AND', context=_context)
+            exprs_not = self.exprs_not(S, _context)
+            node = And(node, exprs_not)
+        return node
+
+    def exprs_not(self, S, _parent=None):
+        _context = self.Context(_parent, self._scanner, 'exprs_not', [S])
+        _token = self._peek('NOT', 'r"\\("', 'NULL', 'DATE', 'DATETIME', 'TRUE', 'FALSE', 'FLOAT', 'INT', 'STRING', 'SUBSTITUTE', 'VARIABLE', 'E_TYPE', 'FUNCTION', context=_context)
+        if _token == 'NOT':
+            NOT = self._scan('NOT', context=_context)
+            balanced_expr = self.balanced_expr(S, _context)
+            return Not(balanced_expr)
+        else:
+            balanced_expr = self.balanced_expr(S, _context)
+            return balanced_expr
+
+    def balanced_expr(self, S, _parent=None):
+        _context = self.Context(_parent, self._scanner, 'balanced_expr', [S])
+        _token = self._peek('r"\\("', 'NULL', 'DATE', 'DATETIME', 'TRUE', 'FALSE', 'FLOAT', 'INT', 'STRING', 'SUBSTITUTE', 'VARIABLE', 'E_TYPE', 'FUNCTION', context=_context)
+        expr_add = self.expr_add(S, _context)
+        expr_op = self.expr_op(S, _context)
+        expr_op.insert(0, expr_add); return expr_op
+
+    def expr_op(self, S, _parent=None):
+        _context = self.Context(_parent, self._scanner, 'expr_op', [S])
+        _token = self._peek('CMP_OP', "'IN'", context=_context)
+        if _token == 'CMP_OP':
+            CMP_OP = self._scan('CMP_OP', context=_context)
+            expr_add = self.expr_add(S, _context)
+            return Comparison(CMP_OP.upper(), expr_add)
+        else: # == "'IN'"
+            in_expr = self.in_expr(S, _context)
+            return Comparison('=', in_expr)
+
     def variables(self, S, _parent=None):
         _context = self.Context(_parent, self._scanner, 'variables', [S])
         vars = []
@@ -504,7 +549,7 @@
         _context = self.Context(_parent, self._scanner, 'decl_vars', [R])
         E_TYPE = self._scan('E_TYPE', context=_context)
         var = self.var(R, _context)
-        while self._peek("','", 'R_TYPE', 'QMARK', 'WHERE', '":"', 'HAVING', "';'", 'MUL_OP', 'BEING', 'WITH', 'GROUPBY', 'ORDERBY', 'ADD_OP', 'LIMIT', 'OFFSET', 'r"\\)"', 'CMP_OP', 'SORT_DESC', 'SORT_ASC', 'AND', 'OR', context=_context) == "','":
+        while self._peek("','", 'R_TYPE', 'QMARK', 'WHERE', '":"', 'HAVING', "';'", 'MUL_OP', 'BEING', 'WITH', 'GROUPBY', 'ORDERBY', 'ADD_OP', 'LIMIT', 'OFFSET', 'r"\\)"', 'SORT_DESC', 'SORT_ASC', 'CMP_OP', "'IN'", 'AND', 'OR', context=_context) == "','":
             R.add_main_variable(E_TYPE, var)
             self._scan("','", context=_context)
             E_TYPE = self._scan('E_TYPE', context=_context)
@@ -543,7 +588,7 @@
         _context = self.Context(_parent, self._scanner, 'expr_add', [S])
         expr_mul = self.expr_mul(S, _context)
         node = expr_mul
-        while self._peek('ADD_OP', 'r"\\)"', "','", 'CMP_OP', 'SORT_DESC', 'SORT_ASC', 'GROUPBY', 'QMARK', 'ORDERBY', 'WHERE', 'LIMIT', 'OFFSET', 'HAVING', 'WITH', "';'", 'AND', 'OR', context=_context) == 'ADD_OP':
+        while self._peek('ADD_OP', 'r"\\)"', "','", 'SORT_DESC', 'SORT_ASC', 'CMP_OP', "'IN'", 'GROUPBY', 'QMARK', 'ORDERBY', 'WHERE', 'LIMIT', 'OFFSET', 'HAVING', "';'", 'WITH', 'AND', 'OR', context=_context) == 'ADD_OP':
             ADD_OP = self._scan('ADD_OP', context=_context)
             expr_mul = self.expr_mul(S, _context)
             node = MathExpression( ADD_OP, node, expr_mul )
@@ -553,7 +598,7 @@
         _context = self.Context(_parent, self._scanner, 'expr_mul', [S])
         expr_base = self.expr_base(S, _context)
         node = expr_base
-        while self._peek('MUL_OP', 'ADD_OP', 'r"\\)"', "','", 'CMP_OP', 'SORT_DESC', 'SORT_ASC', 'GROUPBY', 'QMARK', 'ORDERBY', 'WHERE', 'LIMIT', 'OFFSET', 'HAVING', 'WITH', "';'", 'AND', 'OR', context=_context) == 'MUL_OP':
+        while self._peek('MUL_OP', 'ADD_OP', 'r"\\)"', "','", 'SORT_DESC', 'SORT_ASC', 'CMP_OP', "'IN'", 'GROUPBY', 'QMARK', 'ORDERBY', 'WHERE', 'LIMIT', 'OFFSET', 'HAVING', "';'", 'WITH', 'AND', 'OR', context=_context) == 'MUL_OP':
             MUL_OP = self._scan('MUL_OP', context=_context)
             expr_base = self.expr_base(S, _context)
             node = MathExpression( MUL_OP, node, expr_base)
@@ -587,7 +632,22 @@
         F = Function(FUNCTION)
         if self._peek('r"\\)"', 'r"\\("', 'NULL', 'DATE', 'DATETIME', 'TRUE', 'FALSE', 'FLOAT', 'INT', 'STRING', 'SUBSTITUTE', 'VARIABLE', 'E_TYPE', 'FUNCTION', context=_context) != 'r"\\)"':
             expr_add = self.expr_add(S, _context)
-            while self._peek("','", 'r"\\)"', 'CMP_OP', 'SORT_DESC', 'SORT_ASC', 'GROUPBY', 'QMARK', 'ORDERBY', 'WHERE', 'LIMIT', 'OFFSET', 'HAVING', 'WITH', "';'", 'AND', 'OR', context=_context) == "','":
+            while self._peek("','", 'r"\\)"', 'SORT_DESC', 'SORT_ASC', 'CMP_OP', "'IN'", 'GROUPBY', 'QMARK', 'ORDERBY', 'WHERE', 'LIMIT', 'OFFSET', 'HAVING', "';'", 'WITH', 'AND', 'OR', context=_context) == "','":
+                F.append(expr_add)
+                self._scan("','", context=_context)
+                expr_add = self.expr_add(S, _context)
+            F.append(expr_add)
+        self._scan('r"\\)"', context=_context)
+        return F
+
+    def in_expr(self, S, _parent=None):
+        _context = self.Context(_parent, self._scanner, 'in_expr', [S])
+        self._scan("'IN'", context=_context)
+        self._scan('r"\\("', context=_context)
+        F = Function('IN')
+        if self._peek('r"\\)"', 'r"\\("', 'NULL', 'DATE', 'DATETIME', 'TRUE', 'FALSE', 'FLOAT', 'INT', 'STRING', 'SUBSTITUTE', 'VARIABLE', 'E_TYPE', 'FUNCTION', context=_context) != 'r"\\)"':
+            expr_add = self.expr_add(S, _context)
+            while self._peek("','", 'r"\\)"', 'SORT_DESC', 'SORT_ASC', 'CMP_OP', "'IN'", 'GROUPBY', 'QMARK', 'ORDERBY', 'WHERE', 'LIMIT', 'OFFSET', 'HAVING', "';'", 'WITH', 'AND', 'OR', context=_context) == "','":
                 F.append(expr_add)
                 self._scan("','", context=_context)
                 expr_add = self.expr_add(S, _context)
diff --git a/test/unittest_parser.py b/test/unittest_parser.py
--- a/test/unittest_parser.py
+++ b/test/unittest_parser.py
@@ -116,7 +116,23 @@
     ' WITH T1,T2 BEING ('
     '      (Any X,N WHERE X name N, X transition_of E, E name %(name)s)'
     '       UNION '
-    '      (Any X,N WHERE X name N, X state_of E, E name %(name)s))',
+    '      (Any X,N WHERE X name N, X state_of E, E name %(name)s));',
+
+
+    'Any T2'
+    ' GROUPBY T2'
+    ' WHERE T1 relation T2'
+    ' HAVING COUNT(T1) IN (1,2);',
+
+    'Any T2'
+    ' GROUPBY T2'
+    ' WHERE T1 relation T2'
+    ' HAVING COUNT(T1) IN (1,2) OR COUNT(T1) IN (3,4);',
+
+    'Any T2'
+    ' GROUPBY T2'
+    ' WHERE T1 relation T2'
+    ' HAVING 1 < COUNT(T1) OR COUNT(T1) IN (3,4);',
     )
 
 class ParserHercule(TestCase):