Skip to content

Commit ce59f80

Browse files
committed
Add support for array operators
1 parent 8ee1b7f commit ce59f80

File tree

3 files changed

+44
-7
lines changed

3 files changed

+44
-7
lines changed

CHANGELOG

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
* Add support for array operators
12
* Remove the parentheses around the unary and binary operators
23
* Use the ordinal number as aliases for GROUP BY
34
* Check the coherence of the aliases of GROUP BY and ORDER BY expressions

sql/operators.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ def _format(self, operand, param=None):
4949
if param is None:
5050
param = Flavor.get().param
5151
if (isinstance(operand, Expression)
52-
and not isinstance(operand, Operator)):
52+
and (not isinstance(operand, Operator)
53+
or isinstance(operand, UnaryOperator))):
5354
return str(operand)
5455
elif isinstance(operand, (Expression, Select, CombiningQuery)):
5556
return '(%s)' % operand
@@ -458,15 +459,32 @@ class Exists(UnaryOperator):
458459
_operator = 'EXISTS'
459460

460461

461-
class Any(UnaryOperator):
462+
class _ArrayOperator(UnaryOperator):
463+
__slots__ = ()
464+
465+
@property
466+
def params(self):
467+
if isinstance(self.operand, (list, tuple, array)):
468+
return (list(self.operand),)
469+
return super().params
470+
471+
def _format(self, operand, param=None):
472+
if param is None:
473+
param = Flavor.get().param
474+
if isinstance(operand, (list, tuple, array)):
475+
return '(%s)' % param
476+
return super()._format(operand, param=param)
477+
478+
479+
class Any(_ArrayOperator):
462480
__slots__ = ()
463481
_operator = 'ANY'
464482

465483

466484
Some = Any
467485

468486

469-
class All(UnaryOperator):
487+
class All(_ArrayOperator):
470488
__slots__ = ()
471489
_operator = 'ALL'
472490

sql/tests/test_operators.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66

77
from sql import Flavor, Literal, Null, Table
88
from sql.operators import (
9-
Abs, And, Between, Div, Equal, Exists, FloorDiv, Greater, GreaterEqual,
10-
ILike, In, Is, IsDistinct, IsNot, IsNotDistinct, Less, LessEqual, Like,
11-
LShift, Mod, Mul, Neg, Not, NotBetween, NotEqual, NotILike, NotIn, NotLike,
12-
Operator, Or, Pos, Pow, RShift, Sub)
9+
Abs, And, Any, Between, Div, Equal, Exists, FloorDiv, Greater,
10+
GreaterEqual, ILike, In, Is, IsDistinct, IsNot, IsNotDistinct, Less,
11+
LessEqual, Like, LShift, Mod, Mul, Neg, Not, NotBetween, NotEqual,
12+
NotILike, NotIn, NotLike, Operator, Or, Pos, Pow, RShift, Sub)
1313

1414

1515
class TestOperators(unittest.TestCase):
@@ -418,3 +418,21 @@ def test_floordiv(self):
418418
self.assertIn(
419419
'FloorDiv operator is deprecated, use Div function',
420420
str(w[-1].message))
421+
422+
def test_any(self):
423+
any_ = Any(self.table.select(self.table.c1, where=self.table.c2 == 1))
424+
self.assertEqual(str(any_),
425+
'ANY (SELECT "a"."c1" FROM "t" AS "a" WHERE "a"."c2" = %s)')
426+
self.assertEqual(any_.params, (1,))
427+
428+
for value in [[1, 2, 3], (1, 2, 3), array('l', [1, 2, 3])]:
429+
with self.subTest(value=value):
430+
any_ = Any(value)
431+
self.assertEqual(str(any_), 'ANY (%s)')
432+
self.assertEqual(any_.params, ([1, 2, 3],))
433+
434+
def test_binary_unary(self):
435+
operator = Equal(self.table.c1, Any([1, 2, 3]))
436+
437+
self.assertEqual(str(operator), '"c1" = ANY (%s)')
438+
self.assertEqual(operator.params, ([1, 2, 3],))

0 commit comments

Comments
 (0)