diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 139a54e..e027c05 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,10 +1,8 @@ workflow: rules: - - if: $CI_PIPELINE_SOURCE == "trigger" - - if: $CI_PIPELINE_SOURCE == "merge_request_event" - - if: $CI_COMMIT_BRANCH && $CI_OPEN_MERGE_REQUESTS + - if: $CI_COMMIT_BRANCH =~ /^topic\/.*/ && $CI_PIPELINE_SOURCE == "push" when: never - - if: $CI_COMMIT_BRANCH =~ /^branch\/.*/ + - when: always stages: - check @@ -12,33 +10,28 @@ stages: .check: stage: check - rules: - - if: $CI_MERGE_REQUEST_ID != null - when: always image: ${CI_DEPENDENCY_PROXY_GROUP_IMAGE_PREFIX}/tryton/ci check-flake8: extends: .check script: - - hg diff --rev s0 | flake8 --diff + - flake8 check-isort: extends: .check script: - - isort -m VERTICAL_GRID -p trytond -c `hg status --no-status --added --modified --rev s0` + - isort -m VERTICAL_GRID -c . check-dist: extends: .check before_script: - - pip install twine + - pip install build twine script: - - python setup.py sdist + - pyproject-build - twine check dist/* .test: stage: test - rules: - - when: always .test-tox: extends: .test @@ -49,18 +42,25 @@ check-dist: - .cache/pip before_script: - pip install tox + coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/' + artifacts: + reports: + junit: ${CI_PROJECT_DIR}/junit.xml + coverage_report: + coverage_format: cobertura + path: ${CI_PROJECT_DIR}/coverage.xml test-tox-python: extends: .test-tox image: ${CI_DEPENDENCY_PROXY_GROUP_IMAGE_PREFIX}/python:${PYTHON_VERSION} script: - - tox -e "py${PYTHON_VERSION/./}" + - tox -e "py${PYTHON_VERSION/./}" -- -v --output-file "${CI_PROJECT_DIR}/junit.xml" parallel: matrix: - - PYTHON_VERSION: ["3.5", "3.6", "3.7", "3.8", "3.9", "3.10"] + - PYTHON_VERSION: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] test-tox-pypy: extends: .test-tox image: ${CI_DEPENDENCY_PROXY_GROUP_IMAGE_PREFIX}/pypy:3 script: - - tox -e pypy3 + - tox -e pypy3 -- -v --output-file "${CI_PROJECT_DIR}/junit.xml" diff --git a/.hgignore b/.hgignore new file mode 100644 index 0000000..19f0bee --- /dev/null +++ b/.hgignore @@ -0,0 +1,7 @@ +syntax: glob +*.py[cdo] +*.egg-info +dist/ +build/ +.tox/ +.coverage diff --git a/.hgtags b/.hgtags index aefc730..517fffd 100644 --- a/.hgtags +++ b/.hgtags @@ -14,3 +14,13 @@ b2bcc0f71f6881316c11330c07de34113f088888 1.2.1 1c38ffeacbb82a9ff6ae3568cdc017dbbeddff5d 1.2.2 edc03ee84f0ac96d403d8f984d59fffa3274cd2f 1.3.0 a317c40a4d60089ba9e465fbd64b78df24f9e890 1.4.0 +e71bbae3398cb6a0e72f97a0cada9fcdee2bddea 1.4.1 +fcb64787b51db2068061eb4aa13825abc1134916 1.4.2 +111e3e86865360f83a65c04fa48c55f3d2957ee3 1.4.3 +6f9066b83fe3a8c4699a8555ad1bc406f18974ff 1.5.0 +79a69b0bbbd35a8d95e1b754ed3feb03df23fb70 1.5.1 +41b0aaa68f5e5bab3889fa1ef57ef44c6c21cacf 1.5.2 +475502ba46eba3b7e141e8fbceaf495b545bcddb 1.6.0 +231ce10b975e41027c6121f9bb9033d786553b90 1.7.0 +a1db1b7c55132372b933242b2f07cb353b973b29 1.8.0 +80c2d7b2dc49be9ee71a4c5c2b54b00003ef67d1 1.8.1 diff --git a/CHANGELOG b/CHANGELOG index a583b16..64a2b69 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,55 @@ + +Version 1.8.1 - 2026-04-03 +-------------------------- +* Bug fixes (see mercurial logs for details) + + +Version 1.8.0 - 2026-03-21 +-------------------------- +* Bug fixes (see mercurial logs for details) +* Upgrade to pyproject +* Add support for array operators +* Remove the parentheses around the unary and binary operators +* Use the ordinal number as aliases for GROUP BY +* Check the coherence of the aliases of GROUP BY and ORDER BY expressions +* Do not use parameter for EXTRACT field +* Remove support for Python older than 3.9 + +Version 1.7.0 - 2025-11-24 +* Add support for Python 3.14 +* Do not use parameters for COUNT(*) + +Version 1.6.0 - 2025-05-02 +* Fix position of order_by parameters in Select query +* Add support for weak reference on SQL objects +* Add support for Python 3.13 + +Version 1.5.2 - 2024-09-30 +* Use parameter for unary operator +* Support default values when inserting not matched merge +* Replace assert by ValueError + +Version 1.5.1 - 2024-05-28 +* Use parameter for start and end of WINDOW FRAME +* Use parameter for limit and offset + +Version 1.5.0 - 2024-05-13 +* Skip alias on INSERT without ON CONFLICT or RETURNING +* Add MERGE +* Support UPSERT +* Remove default escape char on LIKE and ILIKE +* Add GROUPING SETS, CUBE, and ROLLUP + +Version 1.4.3 - 2023-12-30 +* Render common table expression in combining query +* Add support for Python 3.12 + +Version 1.4.2 - 2023-06-25 +* Restore usage of alias in returning expression + +Version 1.4.1 - 2023-06-16 +* Do not use alias in returning expression + Version 1.4.0 - 2022-05-02 * Use unittest discover * Use only column name for INSERT and UPDATE diff --git a/COPYRIGHT b/COPYRIGHT index 2b1ca76..e0e9536 100644 --- a/COPYRIGHT +++ b/COPYRIGHT @@ -1,6 +1,6 @@ -Copyright (c) 2011-2022, Cédric Krier -Copyright (c) 2013-2021, Nicolas Évrard -Copyright (c) 2011-2022, B2CK +Copyright (c) 2011-2026 Cédric Krier +Copyright (c) 2013-2025 Nicolas Évrard +Copyright (c) 2011-2026 B2CK SRL All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index b1a2e8f..0000000 --- a/MANIFEST.in +++ /dev/null @@ -1,3 +0,0 @@ -include COPYRIGHT -include README -include CHANGELOG diff --git a/README b/README.rst similarity index 76% rename from README rename to README.rst index fc87231..e53aaef 100644 --- a/README +++ b/README.rst @@ -39,14 +39,14 @@ Select with where condition:: >>> select.where = user.name == 'foo' >>> tuple(select) - ('SELECT "a"."id", "a"."name" FROM "user" AS "a" WHERE ("a"."name" = %s)', ('foo',)) + ('SELECT "a"."id", "a"."name" FROM "user" AS "a" WHERE "a"."name" = %s', ('foo',)) >>> select.where = (user.name == 'foo') & (user.active == True) >>> tuple(select) - ('SELECT "a"."id", "a"."name" FROM "user" AS "a" WHERE (("a"."name" = %s) AND ("a"."active" = %s))', ('foo', True)) + ('SELECT "a"."id", "a"."name" FROM "user" AS "a" WHERE ("a"."name" = %s) AND ("a"."active" = %s)', ('foo', True)) >>> select.where = user.name == user.login >>> tuple(select) - ('SELECT "a"."id", "a"."name" FROM "user" AS "a" WHERE ("a"."name" = "a"."login")', ()) + ('SELECT "a"."id", "a"."name" FROM "user" AS "a" WHERE "a"."name" = "a"."login"', ()) Select with join:: @@ -54,7 +54,7 @@ Select with join:: >>> join.condition = join.right.user == user.id >>> select = join.select(user.name, join.right.group) >>> tuple(select) - ('SELECT "a"."name", "b"."group" FROM "user" AS "a" INNER JOIN "user_group" AS "b" ON ("b"."user" = "a"."id")', ()) + ('SELECT "a"."name", "b"."group" FROM "user" AS "a" INNER JOIN "user_group" AS "b" ON "b"."user" = "a"."id"', ()) Select with multiple joins:: @@ -93,9 +93,9 @@ Select with sub-select:: ... where=user_group.active == True) >>> user = Table('user') >>> tuple(user.select(user.id, where=user.id.in_(subselect))) - ('SELECT "a"."id" FROM "user" AS "a" WHERE ("a"."id" IN (SELECT "b"."user" FROM "user_group" AS "b" WHERE ("b"."active" = %s)))', (True,)) + ('SELECT "a"."id" FROM "user" AS "a" WHERE "a"."id" IN (SELECT "b"."user" FROM "user_group" AS "b" WHERE "b"."active" = %s)', (True,)) >>> tuple(subselect.select(subselect.user)) - ('SELECT "a"."user" FROM (SELECT "b"."user" FROM "user_group" AS "b" WHERE ("b"."active" = %s)) AS "a"', (True,)) + ('SELECT "a"."user" FROM (SELECT "b"."user" FROM "user_group" AS "b" WHERE "b"."active" = %s) AS "a"', (True,)) Select on other schema:: @@ -106,43 +106,43 @@ Select on other schema:: Insert query with default values:: >>> tuple(user.insert()) - ('INSERT INTO "user" AS "a" DEFAULT VALUES', ()) + ('INSERT INTO "user" DEFAULT VALUES', ()) Insert query with values:: >>> tuple(user.insert(columns=[user.name, user.login], ... values=[['Foo', 'foo']])) - ('INSERT INTO "user" AS "a" ("name", "login") VALUES (%s, %s)', ('Foo', 'foo')) + ('INSERT INTO "user" ("name", "login") VALUES (%s, %s)', ('Foo', 'foo')) >>> tuple(user.insert(columns=[user.name, user.login], ... values=[['Foo', 'foo'], ['Bar', 'bar']])) - ('INSERT INTO "user" AS "a" ("name", "login") VALUES (%s, %s), (%s, %s)', ('Foo', 'foo', 'Bar', 'bar')) + ('INSERT INTO "user" ("name", "login") VALUES (%s, %s), (%s, %s)', ('Foo', 'foo', 'Bar', 'bar')) Insert query with query:: >>> passwd = Table('passwd') >>> select = passwd.select(passwd.login, passwd.passwd) >>> tuple(user.insert(values=select)) - ('INSERT INTO "user" AS "b" SELECT "a"."login", "a"."passwd" FROM "passwd" AS "a"', ()) + ('INSERT INTO "user" SELECT "a"."login", "a"."passwd" FROM "passwd" AS "a"', ()) Update query with values:: >>> tuple(user.update(columns=[user.active], values=[True])) ('UPDATE "user" AS "a" SET "active" = %s', (True,)) >>> tuple(invoice.update(columns=[invoice.total], values=[invoice.amount + invoice.tax])) - ('UPDATE "invoice" AS "a" SET "total" = ("a"."amount" + "a"."tax")', ()) + ('UPDATE "invoice" AS "a" SET "total" = "a"."amount" + "a"."tax"', ()) Update query with where condition:: >>> tuple(user.update(columns=[user.active], values=[True], ... where=user.active == False)) - ('UPDATE "user" AS "a" SET "active" = %s WHERE ("a"."active" = %s)', (True, False)) + ('UPDATE "user" AS "a" SET "active" = %s WHERE "a"."active" = %s', (True, False)) Update query with from list:: >>> group = Table('user_group') >>> tuple(user.update(columns=[user.active], values=[group.active], ... from_=[group], where=user.id == group.user)) - ('UPDATE "user" AS "b" SET "active" = "a"."active" FROM "user_group" AS "a" WHERE ("b"."id" = "a"."user")', ()) + ('UPDATE "user" AS "b" SET "active" = "a"."active" FROM "user_group" AS "a" WHERE "b"."id" = "a"."user"', ()) Delete query:: @@ -152,13 +152,13 @@ Delete query:: Delete query with where condition:: >>> tuple(user.delete(where=user.name == 'foo')) - ('DELETE FROM "user" WHERE ("name" = %s)', ('foo',)) + ('DELETE FROM "user" WHERE "name" = %s', ('foo',)) Delete query with sub-query:: >>> tuple(user.delete( ... where=user.id.in_(user_group.select(user_group.user)))) - ('DELETE FROM "user" WHERE ("id" IN (SELECT "a"."user" FROM "user_group" AS "a"))', ()) + ('DELETE FROM "user" WHERE "id" IN (SELECT "a"."user" FROM "user_group" AS "a")', ()) Flavors:: @@ -166,26 +166,26 @@ Flavors:: >>> select.offset = 10 >>> Flavor.set(Flavor()) >>> tuple(select) - ('SELECT * FROM "user" AS "a" OFFSET 10', ()) + ('SELECT * FROM "user" AS "a" OFFSET %s', (10,)) >>> Flavor.set(Flavor(max_limit=18446744073709551615)) >>> tuple(select) - ('SELECT * FROM "user" AS "a" LIMIT 18446744073709551615 OFFSET 10', ()) + ('SELECT * FROM "user" AS "a" LIMIT 18446744073709551615 OFFSET %s', (10,)) >>> Flavor.set(Flavor(max_limit=-1)) >>> tuple(select) - ('SELECT * FROM "user" AS "a" LIMIT -1 OFFSET 10', ()) + ('SELECT * FROM "user" AS "a" LIMIT -1 OFFSET %s', (10,)) Limit style:: >>> select = user.select(limit=10, offset=20) >>> Flavor.set(Flavor(limitstyle='limit')) >>> tuple(select) - ('SELECT * FROM "user" AS "a" LIMIT 10 OFFSET 20', ()) + ('SELECT * FROM "user" AS "a" LIMIT %s OFFSET %s', (10, 20)) >>> Flavor.set(Flavor(limitstyle='fetch')) >>> tuple(select) - ('SELECT * FROM "user" AS "a" OFFSET (20) ROWS FETCH FIRST (10) ROWS ONLY', ()) + ('SELECT * FROM "user" AS "a" OFFSET (%s) ROWS FETCH FIRST (%s) ROWS ONLY', (20, 10)) >>> Flavor.set(Flavor(limitstyle='rownum')) >>> tuple(select) - ('SELECT "a".* FROM (SELECT "b".*, ROWNUM AS "rnum" FROM (SELECT * FROM "user" AS "c") AS "b" WHERE (ROWNUM <= %s)) AS "a" WHERE ("rnum" > %s)', (30, 20)) + ('SELECT "a".* FROM (SELECT "b".*, ROWNUM AS "rnum" FROM (SELECT * FROM "user" AS "c") AS "b" WHERE ROWNUM <= %s) AS "a" WHERE "rnum" > %s', (30, 20)) qmark style:: @@ -193,7 +193,7 @@ qmark style:: >>> select = user.select() >>> select.where = user.name == 'foo' >>> tuple(select) - ('SELECT * FROM "user" AS "a" WHERE ("a"."name" = ?)', ('foo',)) + ('SELECT * FROM "user" AS "a" WHERE "a"."name" = ?', ('foo',)) numeric style:: @@ -201,4 +201,4 @@ numeric style:: >>> select = user.select() >>> select.where = user.name == 'foo' >>> format2numeric(*select) - ('SELECT * FROM "user" AS "a" WHERE ("a"."name" = :0)', ('foo',)) + ('SELECT * FROM "user" AS "a" WHERE "a"."name" = :0', ('foo',)) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..c6545ec --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,38 @@ +[build-system] +requires = ['hatchling >= 1', 'hatch-tryton'] +build-backend = 'hatchling.build' + +[project] +name = 'python-sql' +dynamic = ['version', 'authors'] +requires-python = '>=3.9' +maintainers = [ + {name = "Tryton", email = "foundation@tryton.org"}, + ] +description = "Library to write SQL queries" +readme = 'README.rst' +license = 'BSD-3-Clause' +license-files = ['COPYRIGHT'] +keywords = ["SQL", "database", "query"] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Topic :: Database", + "Topic :: Software Development :: Libraries :: Python Modules", + ] + +[project.urls] +homepage = "https://www.tryton.org/" +changelog = "https://code.tryton.org/python-sql/-/blob/branch/default/CHANGELOG" +forum = "https://discuss.tryton.org/tags/python-sql" +issues = "https://bugs.tryton.org/python-sql" +repository = "https://code.tryton.org/python-sql" + +[tool.hatch.version] +path = 'sql/__init__.py' + +[tool.hatch.build] +packages = ['sql'] + +[tool.hatch.metadata.hooks.tryton] +copyright = 'COPYRIGHT' diff --git a/setup.py b/setup.py deleted file mode 100644 index 9264e56..0000000 --- a/setup.py +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/env python -# This file is part of python-sql. The COPYRIGHT file at the top level of -# this repository contains the full copyright notices and license terms. -import codecs -import os -import re - -from setuptools import find_packages, setup - - -def read(fname): - return codecs.open( - os.path.join(os.path.dirname(__file__), fname), 'r', 'utf-8').read() - - -def get_version(): - init = read(os.path.join('sql', '__init__.py')) - return re.search("__version__ = '([0-9.]*)'", init).group(1) - - -setup(name='python-sql', - version=get_version(), - description='Library to write SQL queries', - long_description=read('README'), - author='Tryton', - author_email='python-sql@tryton.org', - url='https://pypi.org/project/python-sql/', - download_url='https://downloads.tryton.org/python-sql/', - project_urls={ - "Bug Tracker": 'https://python-sql.tryton.org/', - "Forum": 'https://discuss.tryton.org/tags/python-sql', - "Source Code": 'https://code.tryton.org/python-sql', - }, - keywords='SQL database query', - packages=find_packages(), - python_requires='>=3.5', - classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: BSD License', - 'Operating System :: OS Independent', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Topic :: Database', - 'Topic :: Software Development :: Libraries :: Python Modules', - ], - license='BSD', - ) diff --git a/sql/__init__.py b/sql/__init__.py index 868ef45..c5357d2 100644 --- a/sql/__init__.py +++ b/sql/__init__.py @@ -7,9 +7,13 @@ from itertools import chain from threading import current_thread, local -__version__ = '1.4.1' -__all__ = ['Flavor', 'Table', 'Values', 'Literal', 'Column', 'Join', - 'Asc', 'Desc', 'NullsFirst', 'NullsLast', 'format2numeric'] +__version__ = '1.8.2' +__all__ = [ + 'Flavor', 'Table', 'Values', 'Literal', 'Column', 'Grouping', 'Conflict', + 'Matched', 'MatchedUpdate', 'MatchedDelete', + 'NotMatched', 'NotMatchedInsert', + 'Rollup', 'Cube', 'Excluded', 'Join', 'Asc', 'Desc', 'NullsFirst', + 'NullsLast', 'format2numeric'] def _escape_identifier(name): @@ -58,17 +62,23 @@ class Flavor(object): def __init__(self, limitstyle='limit', max_limit=None, paramstyle='format', ilike=False, no_as=False, no_boolean=False, null_ordering=True, function_mapping=None, filter_=False, escape_empty=False): - assert limitstyle in ['fetch', 'limit', 'rownum'] + if limitstyle not in {'fetch', 'limit', 'rownum'}: + raise ValueError("unsupported limitstyle: %r" % limitstyle) self.limitstyle = limitstyle + if (max_limit is not None + and not isinstance(max_limit, numbers.Integral)): + raise ValueError("unsupported max_limit: %r" % max_limit) self.max_limit = max_limit + if paramstyle not in {'format', 'qmark'}: + raise ValueError("unsupported paramstyle: %r" % paramstyle) self.paramstyle = paramstyle - self.ilike = ilike - self.no_as = no_as - self.no_boolean = no_boolean - self.null_ordering = null_ordering - self.function_mapping = function_mapping or {} - self.filter_ = filter_ - self.escape_empty = escape_empty + self.ilike = bool(ilike) + self.no_as = bool(no_as) + self.no_boolean = bool(no_boolean) + self.null_ordering = bool(null_ordering) + self.function_mapping = dict(function_mapping or {}) + self.filter_ = bool(filter_) + self.escape_empty = bool(escape_empty) @property def param(self): @@ -172,7 +182,7 @@ def format2numeric(query, params): class Query(object): - __slots__ = () + __slots__ = ('__weakref__',) @property def params(self): @@ -209,7 +219,8 @@ def with_(self, value): if value is not None: if isinstance(value, With): value = [value] - assert all(isinstance(w, With) for w in value) + if any(not isinstance(w, With) for w in value): + raise ValueError("invalid with: %r" % value) self._with = value def _with_str(self): @@ -232,7 +243,7 @@ def _with_params(self): class FromItem(object): - __slots__ = () + __slots__ = ('__weakref__',) @property def alias(self): @@ -248,7 +259,8 @@ def __getattr__(self, name): return Column(self, name) def __add__(self, other): - assert isinstance(other, FromItem) + if not isinstance(other, FromItem): + return NotImplemented return From((self, other)) def select(self, *args, **kwargs): @@ -344,7 +356,8 @@ def order_by(self, value): if value is not None: if isinstance(value, Expression): value = [value] - assert all(isinstance(col, Expression) for col in value) + if any(not isinstance(col, Expression) for col in value): + raise ValueError("invalid order by: %r" % value) self._order_by = value @property @@ -361,7 +374,8 @@ def limit(self): @limit.setter def limit(self, value): if value is not None: - assert isinstance(value, numbers.Integral) + if not isinstance(value, numbers.Integral): + raise ValueError("invalid limit: %r" % value) self._limit = value @property @@ -371,18 +385,20 @@ def offset(self): @offset.setter def offset(self, value): if value is not None: - assert isinstance(value, numbers.Integral) + if not isinstance(value, numbers.Integral): + raise ValueError("invalid offset: %r" % value) self._offset = value @property def _limit_offset_str(self): + param = Flavor.get().param if Flavor.get().limitstyle == 'limit': offset = '' if self.offset: - offset = ' OFFSET %s' % self.offset + offset = ' OFFSET %s' % param limit = '' if self.limit is not None: - limit = ' LIMIT %s' % self.limit + limit = ' LIMIT %s' % param elif self.offset: max_limit = Flavor.get().max_limit if max_limit: @@ -391,12 +407,27 @@ def _limit_offset_str(self): else: offset = '' if self.offset: - offset = ' OFFSET (%s) ROWS' % self.offset + offset = ' OFFSET (%s) ROWS' % param fetch = '' if self.limit is not None: - fetch = ' FETCH FIRST (%s) ROWS ONLY' % self.limit + fetch = ' FETCH FIRST (%s) ROWS ONLY' % param return offset + fetch + @property + def _limit_offset_params(self): + p = [] + if Flavor.get().limitstyle == 'limit': + if self.limit is not None: + p.append(self.limit) + if self.offset: + p.append(self.offset) + else: + if self.offset: + p.append(self.offset) + if self.limit is not None: + p.append(self.limit) + return tuple(p) + def as_(self, output_name): return As(self, output_name) @@ -444,7 +475,8 @@ def distinct_on(self, value): if value is not None: if isinstance(value, Expression): value = [value] - assert all(isinstance(col, Expression) for col in value) + if any(not isinstance(col, Expression) for col in value): + raise ValueError("invalid distinct on: %r" % value) self._distinct_on = value @property @@ -453,7 +485,10 @@ def columns(self): @columns.setter def columns(self, value): - assert all(isinstance(col, (Expression, SelectQuery)) for col in value) + if any( + not isinstance(col, (Expression, SelectQuery)) + for col in value): + raise ValueError("invalid columns: %r" % value) self._columns = tuple(value) @property @@ -464,7 +499,8 @@ def where(self): def where(self, value): from sql.operators import And, Or if value is not None: - assert isinstance(value, (Expression, And, Or)) + if not isinstance(value, (Expression, And, Or)): + raise ValueError("invalid where: %r" % value) self._where = value @property @@ -476,7 +512,8 @@ def group_by(self, value): if value is not None: if isinstance(value, Expression): value = [value] - assert all(isinstance(col, Expression) for col in value) + if any(not isinstance(col, Expression) for col in value): + raise ValueError("invalid group by: %r" % value) self._group_by = value @property @@ -487,7 +524,8 @@ def having(self): def having(self, value): from sql.operators import And, Or if value is not None: - assert isinstance(value, (Expression, And, Or)) + if not isinstance(value, (Expression, And, Or)): + raise ValueError("invalid having: %r" % value) self._having = value @property @@ -499,7 +537,8 @@ def for_(self, value): if value is not None: if isinstance(value, For): value = [value] - assert all(isinstance(f, For) for f in value) + if any(not isinstance(f, For) for f in value): + raise ValueError("invalid for: %r" % value) self._for_ = value @property @@ -527,7 +566,8 @@ def windows(self): @windows.setter def windows(self, value): if value is not None: - assert all(isinstance(w, Window) for w in value) + if any(not isinstance(w, Window) for w in value): + raise ValueError("invalid windows: %r" % value) self._windows = value @staticmethod @@ -589,6 +629,27 @@ def __str__(self): and (self.limit is not None or self.offset is not None)): return self._rownum(str) + ordinals = {} + for expression in chain( + self.group_by or [], + self.order_by or []): + if not isinstance(expression, As): + continue + for i, column in enumerate(self.columns, start=1): + if not isinstance(column, As): + continue + if column.output_name != expression.output_name: + continue + if (str(column.expression) != str(expression.expression) + or column.params != expression.params): + raise ValueError("%r != %r" % (expression, column)) + ordinals[column.output_name] = i + + def str_or_ordinal(expression): + if isinstance(expression, As): + expression = ordinals.get(expression.output_name, expression) + return str(expression) + with AliasManager(): if self.from_ is not None: from_ = ' FROM %s' % self.from_ @@ -617,7 +678,8 @@ def __str__(self): where = ' WHERE ' + str(self.where) group_by = '' if self.group_by: - group_by = ' GROUP BY ' + ', '.join(map(str, self.group_by)) + group_by = ' GROUP BY ' + ', '.join( + map(str_or_ordinal, self.group_by)) having = '' if self.having: having = ' HAVING ' + str(self.having) @@ -652,28 +714,32 @@ def params(self): if self.group_by: for expression in self.group_by: p.extend(expression.params) - if self.order_by: - for expression in self.order_by: - p.extend(expression.params) if self.having: p.extend(self.having.params) for window in self.windows: p.extend(window.params) + if self.order_by: + for expression in self.order_by: + p.extend(expression.params) + p.extend(self._limit_offset_params) return tuple(p) class Insert(WithQuery): - __slots__ = ('_table', '_columns', '_values', '_returning') + __slots__ = ('_table', '_columns', '_values', '_on_conflict', '_returning') - def __init__(self, table, columns=None, values=None, returning=None, - **kwargs): + def __init__( + self, table, columns=None, values=None, returning=None, + on_conflict=None, **kwargs): self._table = None self._columns = None self._values = None + self._on_conflict = None self._returning = None self.table = table self.columns = columns self.values = values + self.on_conflict = on_conflict self.returning = returning super(Insert, self).__init__(**kwargs) @@ -683,7 +749,8 @@ def table(self): @table.setter def table(self, value): - assert isinstance(value, Table) + if not isinstance(value, Table): + raise ValueError("invalid table: %r" % value) self._table = value @property @@ -693,8 +760,10 @@ def columns(self): @columns.setter def columns(self, value): if value is not None: - assert all(isinstance(col, Column) for col in value) - assert all(col.table == self.table for col in value) + if any( + not isinstance(col, Column) or col.table != self.table + for col in value): + raise ValueError("invalid columns: %r" % value) self._columns = value @property @@ -704,11 +773,23 @@ def values(self): @values.setter def values(self, value): if value is not None: - assert isinstance(value, (list, Select)) + if not isinstance(value, (list, Select)): + raise ValueError("invalid values: %r" % value) if isinstance(value, list): value = Values(value) self._values = value + @property + def on_conflict(self): + return self._on_conflict + + @on_conflict.setter + def on_conflict(self, value): + if value is not None: + if not isinstance(value, Conflict) or value.table != self.table: + raise ValueError("invalid on conflict: %r" % value) + self._on_conflict = value + @property def returning(self): return self._returning @@ -716,7 +797,8 @@ def returning(self): @returning.setter def returning(self, value): if value is not None: - assert isinstance(value, list) + if not isinstance(value, list): + raise ValueError("invalid returning: %r" % value) self._returning = value @staticmethod @@ -743,13 +825,20 @@ def __str__(self): # TODO manage DEFAULT elif self.values is None: values = ' DEFAULT VALUES' + on_conflict = '' + if self.on_conflict: + on_conflict = ' %s' % self.on_conflict returning = '' if self.returning: returning = ' RETURNING ' + ', '.join( map(self._format, self.returning)) + if on_conflict or returning: + table = '%s AS "%s"' % (self.table, self.table.alias) + else: + table = str(self.table) return (self._with_str() - + 'INSERT INTO %s AS "%s"' % (self.table, self.table.alias) - + columns + values + returning) + + 'INSERT INTO %s' % table + + columns + values + on_conflict + returning) @property def params(self): @@ -757,12 +846,157 @@ def params(self): p.extend(self._with_params()) if isinstance(self.values, Query): p.extend(self.values.params) + if self.on_conflict: + p.extend(self.on_conflict.params) if self.returning: for exp in self.returning: p.extend(exp.params) return tuple(p) +class Conflict(object): + __slots__ = ( + '_table', '_indexed_columns', '_index_where', '_columns', '_values', + '_where') + + def __init__( + self, table, indexed_columns=None, index_where=None, + columns=None, values=None, where=None): + self._table = None + self._indexed_columns = None + self._index_where = None + self._columns = None + self._values = None + self._where = None + self.table = table + self.indexed_columns = indexed_columns + self.index_where = index_where + self.columns = columns + self.values = values + self.where = where + + @property + def table(self): + return self._table + + @table.setter + def table(self, value): + if not isinstance(value, Table): + raise ValueError("invalid table: %r" % value) + self._table = value + + @property + def indexed_columns(self): + return self._indexed_columns + + @indexed_columns.setter + def indexed_columns(self, value): + if value is not None: + if any( + not isinstance(col, Column) or col.table != self.table + for col in value): + raise ValueError("invalid indexed columns: %r" % value) + self._indexed_columns = value + + @property + def index_where(self): + return self._index_where + + @index_where.setter + def index_where(self, value): + from sql.operators import And, Or + if value is not None: + if not isinstance(value, (Expression, And, Or)): + raise ValueError("invalid index where: %r" % value) + self._index_where = value + + @property + def columns(self): + return self._columns + + @columns.setter + def columns(self, value): + if value is not None: + if any( + not isinstance(col, Column) or col.table != self.table + for col in value): + raise ValueError("invalid columns: %r" % value) + self._columns = value + + @property + def values(self): + return self._values + + @values.setter + def values(self, value): + if value is not None: + if not isinstance(value, (list, Select)): + raise ValueError("invalid values: %r" % value) + if isinstance(value, list): + value = Values([value]) + self._values = value + + @property + def where(self): + return self._where + + @where.setter + def where(self, value): + from sql.operators import And, Or + if value is not None: + if not isinstance(value, (Expression, And, Or)): + raise ValueError("invalid where: %r" % value) + self._where = value + + def __str__(self): + indexed_columns = '' + if self.indexed_columns: + assert all(c.table == self.table for c in self.indexed_columns) + # Get columns without alias + indexed_columns = ', '.join( + c.column_name for c in self.indexed_columns) + indexed_columns = ' (' + indexed_columns + ')' + if self.index_where: + indexed_columns += ' WHERE ' + str(self.index_where) + else: + assert not self.index_where + do = '' + if not self.columns: + assert not self.values + assert not self.where + do = 'NOTHING' + else: + assert all(c.table == self.table for c in self.columns) + # Get columns without alias + do = ', '.join(c.column_name for c in self.columns) + # TODO manage DEFAULT + values = str(self.values) + if values.startswith('VALUES'): + values = values[len('VALUES'):] + else: + values = ' (' + values + ')' + if len(self.columns) == 1: + # PostgreSQL would require ROW expression + # with single column with parenthesis + do = 'UPDATE SET ' + do + ' =' + values + else: + do = 'UPDATE SET (' + do + ') =' + values + if self.where: + do += ' WHERE %s' % self.where + return 'ON CONFLICT' + indexed_columns + ' DO ' + do + + @property + def params(self): + p = [] + if self.index_where: + p.extend(self.index_where.params) + if self.values: + p.extend(self.values.params) + if self.where: + p.extend(self.where.params) + return p + + class Update(Insert): __slots__ = ('_where', '_values', 'from_') @@ -782,7 +1016,8 @@ def values(self): def values(self, value): if isinstance(value, Select): value = [value] - assert isinstance(value, list) + if not isinstance(value, list): + raise ValueError("invalid values: %r" % value) self._values = value @property @@ -793,9 +1028,14 @@ def where(self): def where(self, value): from sql.operators import And, Or if value is not None: - assert isinstance(value, (Expression, And, Or)) + if not isinstance(value, (Expression, And, Or)): + raise ValueError("invalid where: %r" % value) self._where = value + @staticmethod + def _format_column(value): + return Select._format_column(value) + def __str__(self): assert all(col.table == self.table for col in self.columns) # Get columns without alias @@ -813,7 +1053,7 @@ def __str__(self): returning = '' if self.returning: returning = ' RETURNING ' + ', '.join( - map(self._format, self.returning)) + map(self._format_column, self.returning)) return (self._with_str() + 'UPDATE %s AS "%s" SET ' % (self.table, self.table.alias) + values + from_ + where + returning) @@ -858,7 +1098,8 @@ def table(self): @table.setter def table(self, value): - assert isinstance(value, Table) + if not isinstance(value, Table): + raise ValueError("invalid table: %r" % value) self._table = value @property @@ -869,7 +1110,8 @@ def where(self): def where(self, value): from sql.operators import And, Or if value is not None: - assert isinstance(value, (Expression, And, Or)) + if not isinstance(value, (Expression, And, Or)): + raise ValueError("invalid where: %r" % value) self._where = value @property @@ -879,9 +1121,16 @@ def returning(self): @returning.setter def returning(self, value): if value is not None: - assert isinstance(value, list) + if any( + not isinstance(col, (Expression, SelectQuery)) + for col in value): + raise ValueError("invalid returning: %r" % value) self._returning = value + @staticmethod + def _format(value): + return Select._format_column(value) + def __str__(self): with AliasManager(exclude=[self.table]): only = ' ONLY' if self.only else '' @@ -890,7 +1139,8 @@ def __str__(self): where = ' WHERE ' + str(self.where) returning = '' if self.returning: - returning = ' RETURNING ' + ', '.join(map(str, self.returning)) + returning = ' RETURNING ' + ', '.join( + map(self._format, self.returning)) return (self._with_str() + 'DELETE FROM%s %s' % (only, self.table) + where + returning) @@ -907,12 +1157,220 @@ def params(self): return tuple(p) +class Merge(WithQuery): + __slots__ = ('_target', '_source', '_condition', '_whens') + + def __init__(self, target, source, condition, *whens, **kwargs): + self._target = None + self._source = None + self._condition = None + self._whens = None + self.target = target + self.source = source + self.condition = condition + self.whens = whens + super().__init__(**kwargs) + + @property + def target(self): + return self._target + + @target.setter + def target(self, value): + if not isinstance(value, Table): + raise ValueError("invalid target: %r" % value) + self._target = value + + @property + def source(self): + return self._source + + @source.setter + def source(self, value): + if not isinstance(value, (Table, SelectQuery, Values)): + raise ValueError("invalid source: %r" % value) + self._source = value + + @property + def condition(self): + return self._condition + + @condition.setter + def condition(self, value): + if not isinstance(value, Expression): + raise ValueError("invalid condition: %r" % value) + self._condition = value + + @property + def whens(self): + return self._whens + + @whens.setter + def whens(self, value): + if any(not isinstance(w, Matched) for w in value): + raise ValueError("invalid whens: %r" % value) + self._whens = tuple(value) + + def __str__(self): + with AliasManager(): + if isinstance(self.source, (Select, Values)): + source = '(%s)' % self.source + else: + source = self.source + condition = 'ON %s' % self.condition + return (self._with_str() + + 'MERGE INTO %s AS "%s" ' % (self.target, self.target.alias) + + 'USING %s AS "%s" ' % (source, self.source.alias) + + condition + ' ' + ' '.join(map(str, self.whens))) + + @property + def params(self): + p = [] + p.extend(self._with_params()) + if isinstance(self.source, (SelectQuery, Values)): + p.extend(self.source.params) + if self.condition: + p.extend(self.condition.params) + for match in self.whens: + p.extend(match.params) + return tuple(p) + + +class Matched(object): + __slots__ = ('_condition',) + _when = 'MATCHED' + + def __init__(self, condition=None): + self._condition = None + self.condition = condition + + @property + def condition(self): + return self._condition + + @condition.setter + def condition(self, value): + if value is not None: + if not isinstance(value, Expression): + raise ValueError("invalid condition: %r" % value) + self._condition = value + + def _then_str(self): + return 'DO NOTHING' + + def __str__(self): + if self.condition is not None: + condition = ' AND ' + str(self.condition) + else: + condition = '' + return 'WHEN ' + self._when + condition + ' THEN ' + self._then_str() + + @property + def params(self): + p = [] + if self.condition: + p.extend(self.condition.params) + return tuple(p) + + +class _MatchedValues(Matched): + __slots__ = ('_columns', '_values') + + def __init__(self, columns, values, **kwargs): + self._columns = columns + self._values = values + self.columns = columns + self.values = values + super().__init__(**kwargs) + + @property + def columns(self): + return self._columns + + @columns.setter + def columns(self, value): + if any(not isinstance(col, Column) for col in value): + raise ValueError("invalid columns: %r" % value) + self._columns = value + + +class MatchedUpdate(_MatchedValues, Matched): + __slots__ = () + + @property + def values(self): + return self._values + + @values.setter + def values(self, value): + self._values = value + + def _then_str(self): + columns = [c.column_name for c in self.columns] + return 'UPDATE SET ' + ', '.join( + '%s = %s' % (c, Update._format(v)) + for c, v in zip(columns, self.values)) + + @property + def params(self): + p = list(super().params) + for value in self.values: + if isinstance(value, (Expression, Select)): + p.extend(value.params) + else: + p.append(value) + return tuple(p) + + +class MatchedDelete(Matched): + __slots__ = () + + def _then_str(self): + return 'DELETE' + + +class NotMatched(Matched): + __slots__ = () + _when = 'NOT MATCHED' + + +class NotMatchedInsert(_MatchedValues, NotMatched): + __slots__ = () + + @property + def values(self): + return self._values + + @values.setter + def values(self, value): + if value is not None: + value = Values([value]) + self._values = value + + def _then_str(self): + columns = ', '.join(c.column_name for c in self.columns) + columns = '(' + columns + ')' + if self.values is None: + values = ' DEFAULT VALUES' + else: + values = ' ' + str(self.values) + return 'INSERT ' + columns + values + + @property + def params(self): + p = list(super().params) + if self.values: + p.extend(self.values.params) + return tuple(p) + + class CombiningQuery(FromItem, SelectQuery): __slots__ = ('queries', 'all_') _operator = '' def __init__(self, *queries, **kwargs): - assert all(isinstance(q, Query) for q in queries) + if any(not isinstance(q, Query) for q in queries): + raise ValueError("invalid queries: %r" % (queries,)) self.queries = queries self.all_ = kwargs.pop('all_', False) super(CombiningQuery, self).__init__(**kwargs) @@ -920,17 +1378,21 @@ def __init__(self, *queries, **kwargs): def __str__(self): with AliasManager(): operator = ' %s %s' % (self._operator, 'ALL ' if self.all_ else '') - return (operator.join(map(str, self.queries)) + self._order_by_str + return ( + self._with_str() + + operator.join(map(str, self.queries)) + self._order_by_str + self._limit_offset_str) @property def params(self): p = [] - for q in self.queries: - p.extend(q.params) - if self.order_by: - for expression in self.order_by: - p.extend(expression.params) + with AliasManager(): + p.extend(self._with_params()) + for q in self.queries: + p.extend(q.params) + if self.order_by: + for expression in self.order_by: + p.extend(expression.params) return tuple(p) @@ -973,9 +1435,11 @@ def __str__(self): def params(self): return () - def insert(self, columns=None, values=None, returning=None, with_=None): + def insert( + self, columns=None, values=None, returning=None, with_=None, + on_conflict=None): return Insert(self, columns=columns, values=values, - returning=returning, with_=with_) + on_conflict=on_conflict, returning=returning, with_=with_) def update(self, columns, values, from_=None, where=None, returning=None, with_=None): @@ -987,6 +1451,25 @@ def delete(self, only=False, using=None, where=None, returning=None, return Delete(self, only=only, using=using, where=where, returning=returning, with_=with_) + def merge(self, source, condition, *whens, with_=None): + return Merge(self, source, condition, *whens, with_=with_) + + +class _Excluded(Table): + def __init__(self): + super().__init__('EXCLUDED') + + @property + def alias(self): + return 'EXCLUDED' + + @property + def has_alias(self): + return False + + +Excluded = _Excluded() + class Join(FromItem): __slots__ = ('_left', '_right', '_condition', '_type_') @@ -1007,7 +1490,8 @@ def left(self): @left.setter def left(self, value): - assert isinstance(value, FromItem) + if not isinstance(value, FromItem): + raise ValueError("invalid left: %r" % value) self._left = value @property @@ -1016,7 +1500,8 @@ def right(self): @right.setter def right(self, value): - assert isinstance(value, FromItem) + if not isinstance(value, FromItem): + raise ValueError("invalid right: %r" % value) self._right = value @property @@ -1027,7 +1512,8 @@ def condition(self): def condition(self, value): from sql.operators import And, Or if value is not None: - assert isinstance(value, (Expression, And, Or)) + if not isinstance(value, (Expression, And, Or)): + raise ValueError("invalid condition: %r" % value) self._condition = value @property @@ -1037,8 +1523,10 @@ def type_(self): @type_.setter def type_(self, value): value = value.upper() - assert value in ('INNER', 'LEFT', 'LEFT OUTER', - 'RIGHT', 'RIGHT OUTER', 'FULL', 'FULL OUTER', 'CROSS') + if value not in { + 'INNER', 'LEFT', 'LEFT OUTER', 'RIGHT', 'RIGHT OUTER', 'FULL', + 'FULL OUTER', 'CROSS'}: + raise ValueError("invalid type: %r" % value) self._type_ = value def __str__(self): @@ -1054,9 +1542,8 @@ def __str__(self): def params(self): p = [] for item in (self.left, self.right): - if hasattr(item, 'params'): - p.extend(item.params) - if hasattr(self.condition, 'params'): + p.extend(item.params) + if self.condition: p.extend(self.condition.params) return tuple(p) @@ -1113,8 +1600,10 @@ def params(self): return tuple(p) def __add__(self, other): - assert isinstance(other, FromItem) - assert not isinstance(other, CombiningQuery) + if not isinstance(other, FromItem): + return NotImplemented + elif isinstance(other, CombiningQuery): + return NotImplemented return From(super(From, self).__add__([other])) @@ -1137,7 +1626,7 @@ def format_(value): @property def params(self): - p = [] + p = list(super().params) for values in self: for value in values: if isinstance(value, Expression): @@ -1148,7 +1637,7 @@ def params(self): class Expression(object): - __slots__ = () + __slots__ = ('__weakref__',) def __str__(self): raise NotImplementedError @@ -1402,21 +1891,35 @@ def params(self): class Collate(Expression): - __slots__ = ('expression', 'collation') + __slots__ = ('_expression', '_collation') def __init__(self, expression, collation): super(Collate, self).__init__() self.expression = expression self.collation = collation + @property + def expression(self): + return self._expression + + @expression.setter + def expression(self, value): + self._expression = value + + @property + def collation(self): + return self._collation + + @collation.setter + def collation(self, value): + self._collation = value + def __str__(self): if isinstance(self.expression, Expression): value = self.expression else: value = Flavor.get().param - if '"' in self.collation: - raise ValueError("Wrong collation %s" % self.collation) - return '%s COLLATE "%s"' % (value, self.collation) + return '%s COLLATE %s' % (value, _escape_identifier(self.collation)) @property def params(self): @@ -1426,9 +1929,87 @@ def params(self): return (self.expression,) +class Grouping(Expression): + __slots__ = ('_sets',) + + def __init__(self, *sets): + super().__init__() + self.sets = sets + + @property + def sets(self): + return self._sets + + @sets.setter + def sets(self, value): + if any( + not isinstance(col, Expression) + for cols in value + for col in cols): + raise ValueError("invalid sets: %r" % value) + self._sets = tuple(tuple(cols) for cols in value) + + def __str__(self): + return 'GROUPING SETS (%s)' % ( + ', '.join( + '(%s)' % ', '.join(str(col) for col in cols) + for cols in self.sets)) + + @property + def params(self): + return sum((col.params for cols in self.sets for col in cols), ()) + + +class Rollup(Expression): + __slots__ = ('_expressions',) + + def __init__(self, *expressions): + super().__init__() + self.expressions = expressions + + @property + def expressions(self): + return self._expressions + + @expressions.setter + def expressions(self, value): + if not all( + isinstance(col, Expression) + or all(isinstance(c, Expression) for c in col) + for col in value): + raise ValueError("invalid expressions: %r" % value) + self._expressions = tuple(value) + + def __str__(self): + def format(col): + if isinstance(col, Expression): + return str(col) + else: + return '(%s)' % ', '.join(str(c) for c in col) + return '%s (%s)' % ( + self.__class__.__name__.upper(), + ', '.join(format(col) for col in self.expressions)) + + @property + def params(self): + p = [] + for col in self.expressions: + if isinstance(col, Expression): + p.extend(col.params) + else: + for c in col: + p.extend(c.params) + return tuple(p) + + +class Cube(Rollup): + pass + + class Window(object): __slots__ = ( - '_partition', '_order_by', '_frame', '_start', '_end', '_exclude') + '_partition', '_order_by', '_frame', '_start', '_end', '_exclude', + '__weakref__') def __init__(self, partition, order_by=None, frame=None, start=None, end=0, exclude=None): @@ -1451,7 +2032,8 @@ def partition(self): @partition.setter def partition(self, value): - assert all(isinstance(e, Expression) for e in value) + if any(not isinstance(e, Expression) for e in value): + raise ValueError("invalid partition: %r" % value) self._partition = value @property @@ -1463,7 +2045,8 @@ def order_by(self, value): if value is not None: if isinstance(value, Expression): value = [value] - assert all(isinstance(col, Expression) for col in value) + if any(not isinstance(col, Expression) for col in value): + raise ValueError("invalid order by: %r" % value) self._order_by = value @property @@ -1473,7 +2056,8 @@ def frame(self): @frame.setter def frame(self, value): if value: - assert value in ['RANGE', 'ROWS', 'GROUPS'] + if value not in {'RANGE', 'ROWS', 'GROUPS'}: + raise ValueError("invalid frame: %r" % value) self._frame = value @property @@ -1483,7 +2067,8 @@ def start(self): @start.setter def start(self, value): if value: - assert isinstance(value, numbers.Integral) + if not isinstance(value, numbers.Integral): + raise ValueError("invalid start: %r" % value) self._start = value @property @@ -1493,7 +2078,8 @@ def end(self): @end.setter def end(self, value): if value: - assert isinstance(value, numbers.Integral) + if not isinstance(value, numbers.Integral): + raise ValueError("invalid end: %r" % value) self._end = value @property @@ -1503,7 +2089,8 @@ def exclude(self): @exclude.setter def exclude(self, value): if value: - assert value in ['CURRENT ROW', 'GROUP', 'TIES'] + if value not in {'CURRENT ROW', 'GROUP', 'TIES'}: + raise ValueError("invalid exclude: %r" % value) self._exclude = value @property @@ -1515,6 +2102,7 @@ def has_alias(self): return AliasManager.contains(self) def __str__(self): + param = Flavor.get().param partition = '' if self.partition: partition = 'PARTITION BY ' + ', '.join(map(str, self.partition)) @@ -1528,9 +2116,9 @@ def format(frame, direction): elif not frame: return 'CURRENT ROW' elif frame < 0: - return '%s PRECEDING' % -frame + return '%s PRECEDING' % param elif frame > 0: - return '%s FOLLOWING' % frame + return '%s FOLLOWING' % param frame = '' if self.frame: @@ -1551,6 +2139,11 @@ def params(self): if self.order_by: for expression in self.order_by: p.extend(expression.params) + if self.frame: + if self.start: + p.append(abs(self.start)) + if self.end: + p.append(abs(self.end)) return tuple(p) @@ -1570,7 +2163,8 @@ def expression(self): @expression.setter def expression(self, value): - assert isinstance(value, (Expression, SelectQuery)) + if not isinstance(value, (Expression, SelectQuery)): + raise ValueError("invalid expression: %r" % value) self._expression = value def __str__(self): @@ -1673,7 +2267,8 @@ def type_(self): @type_.setter def type_(self, value): value = value.upper() - assert value in ('UPDATE', 'SHARE') + if value not in {'UPDATE', 'SHARE'}: + raise ValueError("invalid type: %r" % value) self._type_ = value def __str__(self): diff --git a/sql/aggregate.py b/sql/aggregate.py index 3cbc28a..dca5394 100644 --- a/sql/aggregate.py +++ b/sql/aggregate.py @@ -4,11 +4,10 @@ __all__ = ['Avg', 'BitAnd', 'BitOr', 'BoolAnd', 'BoolOr', 'Count', 'Every', 'Max', 'Min', 'Stddev', 'Sum', 'Variance'] -_sentinel = object() class Aggregate(Expression): - __slots__ = ('expression', '_distinct', '_order_by', '_within', + __slots__ = ('_expression', '_distinct', '_order_by', '_within', '_filter', '_window') _sql = '' @@ -22,13 +21,24 @@ def __init__(self, expression, distinct=False, order_by=None, within=None, self.filter_ = filter_ self.window = window + @property + def expression(self): + return self._expression + + @expression.setter + def expression(self, value): + if not isinstance(value, Expression): + raise ValueError("invalid expression: %r" % value) + self._expression = value + @property def distinct(self): return self._distinct @distinct.setter def distinct(self, value): - assert isinstance(value, bool) + if not isinstance(value, bool): + raise ValueError("invalid distinct: %r" % value) self._distinct = value @property @@ -40,7 +50,8 @@ def order_by(self, value): if value is not None: if isinstance(value, Expression): value = [value] - assert all(isinstance(col, Expression) for col in value) + if any(not isinstance(col, Expression) for col in value): + raise ValueError("invalid order by: %r" % value) self._order_by = value @property @@ -52,7 +63,8 @@ def within(self, value): if value is not None: if isinstance(value, Expression): value = [value] - assert all(isinstance(col, Expression) for col in value) + if any(not isinstance(col, Expression) for col in value): + raise ValueError("invalid within: %r" % value) self._within = value @property @@ -63,7 +75,8 @@ def filter_(self): def filter_(self, value): from sql.operators import And, Or if value is not None: - assert isinstance(value, (Expression, And, Or)) + if not isinstance(value, (Expression, And, Or)): + raise ValueError("invalid filter: %r" % value) self._filter = value @property @@ -73,7 +86,8 @@ def window(self): @window.setter def window(self, value): if value: - assert isinstance(value, Window) + if not isinstance(value, Window): + raise ValueError("invalid window: %r" % value) self._window = value @property @@ -154,20 +168,31 @@ class BoolOr(Aggregate): _sql = 'BOOL_OR' +class _Star(Expression): + __slots__ = () + + def __str__(self): + return '*' + + @property + def params(self): + return () + + class Count(Aggregate): __slots__ = () _sql = 'COUNT' - def __init__(self, expression=_sentinel, **kwargs): - if expression is _sentinel: - expression = Literal('*') + def __init__(self, expression=_Star(), **kwargs): super().__init__(expression, **kwargs) @property def _case_expression(self): expression = super(Count, self)._case_expression - if (isinstance(self.expression, Literal) - and expression.value == '*'): + if (isinstance(self.expression, _Star) + # Keep testing Literal('*') for backward compatibility + or (isinstance(self.expression, Literal) + and expression.value == '*')): expression = Literal(1) return expression diff --git a/sql/functions.py b/sql/functions.py index 2d19db5..18b39ef 100644 --- a/sql/functions.py +++ b/sql/functions.py @@ -1,5 +1,7 @@ # This file is part of python-sql. The COPYRIGHT file at the top level of # this repository contains the full copyright notices and license terms. + +from enum import Enum, auto from itertools import chain from sql import CombiningQuery, Expression, Flavor, FromItem, Select, Window @@ -39,7 +41,8 @@ def columns_definitions(self): @columns_definitions.setter def columns_definitions(self, value): - assert isinstance(value, list) + if not isinstance(value, list): + raise ValueError("invalid columns definitions: %r" % value) self._columns_definitions = value @staticmethod @@ -84,7 +87,7 @@ def __str__(self): return (self._function + '(' + ' '.join(chain(*zip( self._keywords, - map(self._format, self.args))))[1:] + map(self._format, self.args)))).strip() + ')') @@ -286,7 +289,8 @@ class Trim(Function): _function = 'TRIM' def __init__(self, string, position='BOTH', characters=' '): - assert position.upper() in ('LEADING', 'TRAILING', 'BOTH') + if position.upper() not in {'LEADING', 'TRAILING', 'BOTH'}: + raise ValueError("invalid position: %r" % position) self.position = position.upper() self.characters = characters self.string = string @@ -315,7 +319,7 @@ def params(self): for arg in (self.characters, self.string): if isinstance(arg, str): p.append(arg) - elif hasattr(arg, 'params'): + else: p.extend(arg.params) return tuple(p) @@ -381,9 +385,67 @@ class DateTrunc(Function): class Extract(FunctionKeyword): - __slots__ = () + __slots__ = ('_field',) _function = 'EXTRACT' - _keywords = ('', 'FROM') + + class Fields(str, Enum): + def _generate_next_value_(name, start, count, last_values): + return name.upper() + + CENTURY = auto() + DAY = auto() + DECADE = auto() + DOW = auto() + DOY = auto() + EPOCH = auto() + HOUR = auto() + ISODOW = auto() + ISOYEAR = auto() + JULIAN = auto() + MICROSECONDS = auto() + MILLENNIUM = auto() + MILLISECONDS = auto() + MINUTE = auto() + MONTH = auto() + QUARTER = auto() + SECOND = auto() + TIMEZONE = auto() + TIMEZONE_HOUR = auto() + TIMEZONE_MINUTE = auto() + WEEK = auto() + YEAR = auto() + + def __init__(self, field, *args, **kwargs): + super().__init__(*args, **kwargs) + self.field = field + + @property + def field(self): + return self._field + + @field.setter + def field(self, value): + value = value.upper() + if not hasattr(self.Fields, value): + raise ValueError("invalid field: %r" % value) + self._field = value + + @property + def _keywords(self): + return ('%s FROM' % self.field,) + + def __str__(self): + Mapping = Flavor.get().function_mapping.get(self.__class__) + if Mapping: + return str(Mapping(self.field, *self.args)) + return super().__str__() + + @property + def params(self): + Mapping = Flavor.get().function_mapping.get(self.__class__) + if Mapping: + return Mapping(self.field, *self.args).params + return super().params class Isfinite(Function): @@ -483,7 +545,8 @@ def filter_(self): def filter_(self, value): from sql.operators import And, Or if value is not None: - assert isinstance(value, (Expression, And, Or)) + if not isinstance(value, (Expression, And, Or)): + raise ValueError("invalid filter: %r" % value) self._filter = value @property @@ -493,7 +556,8 @@ def window(self): @window.setter def window(self, value): if value: - assert isinstance(value, Window) + if not isinstance(value, Window): + raise ValueError("invalid window: %r" % value) self._window = value def __str__(self): diff --git a/sql/operators.py b/sql/operators.py index f0feda5..8026b94 100644 --- a/sql/operators.py +++ b/sql/operators.py @@ -48,9 +48,11 @@ def convert(operands): def _format(self, operand, param=None): if param is None: param = Flavor.get().param - if isinstance(operand, Expression): + if (isinstance(operand, Expression) + and (not isinstance(operand, Operator) + or isinstance(operand, UnaryOperator))): return str(operand) - elif isinstance(operand, (Select, CombiningQuery)): + elif isinstance(operand, (Expression, Select, CombiningQuery)): return '(%s)' % operand elif isinstance(operand, (list, tuple)): return '(' + ', '.join(self._format(o, param) @@ -88,7 +90,7 @@ def _operands(self): return (self.operand,) def __str__(self): - return '(%s %s)' % (self._operator, self._format(self.operand)) + return '%s %s' % (self._operator, self._format(self.operand)) class BinaryOperator(Operator): @@ -105,7 +107,7 @@ def _operands(self): def __str__(self): left, right = self._operands - return '(%s %s %s)' % (self._format(left), self._operator, + return '%s %s %s' % (self._format(left), self._operator, self._format(right)) def __invert__(self): @@ -121,7 +123,7 @@ def _operands(self): return self def __str__(self): - return '(' + (' %s ' % self._operator).join(map(str, self)) + ')' + return (' %s ' % self._operator).join(map(self._format, self)) class And(NaryOperator): @@ -183,9 +185,9 @@ def _operands(self): def __str__(self): if self.left is Null: - return '(%s IS NULL)' % self.right + return '%s IS NULL' % self.right elif self.right is Null: - return '(%s IS NULL)' % self.left + return '%s IS NULL' % self.left return super(Equal, self).__str__() @@ -195,9 +197,9 @@ class NotEqual(Equal): def __str__(self): if self.left is Null: - return '(%s IS NOT NULL)' % self.right + return '%s IS NOT NULL' % self.right elif self.right is Null: - return '(%s IS NOT NULL)' % self.left + return '%s IS NOT NULL' % self.left return super(Equal, self).__str__() @@ -219,7 +221,7 @@ def __str__(self): operator = self._operator if self.symmetric: operator += ' SYMMETRIC' - return '(%s %s %s AND %s)' % ( + return '%s %s %s AND %s' % ( self._format(self.operand), operator, self._format(self.left), self._format(self.right)) @@ -248,7 +250,8 @@ class Is(BinaryOperator): _operator = 'IS' def __init__(self, left, right): - assert right in [None, True, False] + if right not in {None, True, False}: + raise ValueError("invalid right: %r" % right) super(Is, self).__init__(left, right) @property @@ -257,12 +260,12 @@ def _operands(self): def __str__(self): if self.right is None: - return '(%s %s UNKNOWN)' % ( + return '%s %s UNKNOWN' % ( self._format(self.left), self._operator) elif self.right is True: - return '(%s %s TRUE)' % (self._format(self.left), self._operator) + return '%s %s TRUE' % (self._format(self.left), self._operator) elif self.right is False: - return '(%s %s FALSE)' % (self._format(self.left), self._operator) + return '%s %s FALSE' % (self._format(self.left), self._operator) class IsNot(Is): @@ -377,26 +380,27 @@ class Like(BinaryOperator): __slots__ = 'escape' _operator = 'LIKE' - def __init__(self, left, right, escape='\\'): + def __init__(self, left, right, escape=None): super().__init__(left, right) - assert not escape or len(escape) == 1 + if escape and len(escape) != 1: + raise ValueError("invalid escape: %r" % escape) self.escape = escape @property def params(self): params = super().params if self.escape or Flavor().get().escape_empty: - params += (self.escape,) + params += (self.escape or '',) return params def __str__(self): left, right = self._operands if self.escape or Flavor().get().escape_empty: - return '(%s %s %s ESCAPE %s)' % ( + return '%s %s %s ESCAPE %s' % ( self._format(left), self._operator, self._format(right), self._format(self.escape or '')) else: - return '(%s %s %s)' % ( + return '%s %s %s' % ( self._format(left), self._operator, self._format(right)) def __invert__(self): @@ -455,7 +459,24 @@ class Exists(UnaryOperator): _operator = 'EXISTS' -class Any(UnaryOperator): +class _ArrayOperator(UnaryOperator): + __slots__ = () + + @property + def params(self): + if isinstance(self.operand, (list, tuple, array)): + return (list(self.operand),) + return super().params + + def _format(self, operand, param=None): + if param is None: + param = Flavor.get().param + if isinstance(operand, (list, tuple, array)): + return '(%s)' % param + return super()._format(operand, param=param) + + +class Any(_ArrayOperator): __slots__ = () _operator = 'ANY' @@ -463,7 +484,7 @@ class Any(UnaryOperator): Some = Any -class All(UnaryOperator): +class All(_ArrayOperator): __slots__ = () _operator = 'ALL' diff --git a/sql/tests/__init__.py b/sql/tests/__init__.py index e1048d4..099392d 100644 --- a/sql/tests/__init__.py +++ b/sql/tests/__init__.py @@ -6,7 +6,7 @@ import sql here = os.path.dirname(__file__) -readme = os.path.normpath(os.path.join(here, '..', '..', 'README')) +readme = os.path.normpath(os.path.join(here, '..', '..', 'README.rst')) def load_tests(loader, tests, pattern): diff --git a/sql/tests/test_aggregate.py b/sql/tests/test_aggregate.py index e70fc93..3e8dbc1 100644 --- a/sql/tests/test_aggregate.py +++ b/sql/tests/test_aggregate.py @@ -3,23 +3,47 @@ import unittest from sql import AliasManager, Flavor, Literal, Table, Window -from sql.aggregate import Avg, Count +from sql.aggregate import Aggregate, Avg, Count class TestAggregate(unittest.TestCase): table = Table('t') + def test_invalid_expression(self): + with self.assertRaises(ValueError): + Aggregate('foo') + + def test_invalid_distinct(self): + with self.assertRaises(ValueError): + Aggregate(self.table.c, distinct='foo') + + def test_invalid_order(self): + with self.assertRaises(ValueError): + Aggregate(self.table.c, order_by=['foo']) + + def test_invalid_within(self): + with self.assertRaises(ValueError): + Aggregate(self.table.c, within=['foo']) + + def test_invalid_filter(self): + with self.assertRaises(ValueError): + Aggregate(self.table.c, filter_='foo') + + def test_invalid_window(self): + with self.assertRaises(ValueError): + Aggregate(self.table.c, window='foo') + def test_avg(self): avg = Avg(self.table.c) self.assertEqual(str(avg), 'AVG("c")') avg = Avg(self.table.a + self.table.b) - self.assertEqual(str(avg), 'AVG(("a" + "b"))') + self.assertEqual(str(avg), 'AVG("a" + "b")') def test_count_without_expression(self): count = Count() - self.assertEqual(str(count), 'COUNT(%s)') - self.assertEqual(count.params, ('*',)) + self.assertEqual(str(count), 'COUNT(*)') + self.assertEqual(count.params, ()) def test_order_by_one_column(self): avg = Avg(self.table.a, order_by=self.table.b) @@ -43,7 +67,7 @@ def test_filter(self): try: avg = Avg(self.table.a + 1, filter_=self.table.a > 0) self.assertEqual( - str(avg), 'AVG(("a" + %s)) FILTER (WHERE ("a" > %s))') + str(avg), 'AVG("a" + %s) FILTER (WHERE "a" > %s)') self.assertEqual(avg.params, (1, 0)) finally: Flavor.set(Flavor()) @@ -51,13 +75,13 @@ def test_filter(self): def test_filter_case(self): avg = Avg(self.table.a + 1, filter_=self.table.a > 0) self.assertEqual( - str(avg), 'AVG(CASE WHEN ("a" > %s) THEN ("a" + %s) END)') + str(avg), 'AVG(CASE WHEN "a" > %s THEN "a" + %s END)') self.assertEqual(avg.params, (0, 1)) def test_filter_case_count_star(self): count = Count(Literal('*'), filter_=self.table.a > 0) self.assertEqual( - str(count), 'COUNT(CASE WHEN ("a" > %s) THEN %s END)') + str(count), 'COUNT(CASE WHEN "a" > %s THEN %s END)') self.assertEqual(count.params, (0, 1)) def test_window(self): diff --git a/sql/tests/test_alias.py b/sql/tests/test_alias.py index a630e2e..c37ad27 100644 --- a/sql/tests/test_alias.py +++ b/sql/tests/test_alias.py @@ -61,3 +61,18 @@ def test_threading(self): self.finish2.wait() if not self.succeed1.is_set() or not self.succeed2.is_set(): self.fail() + + def test_contains(self): + with AliasManager(): + AliasManager.get(self.t1) + self.assertTrue(AliasManager.contains(self.t1)) + + def test_contains_exclude(self): + with AliasManager(exclude=[self.t1]): + self.assertEqual(AliasManager.get(self.t1), '') + self.assertFalse(AliasManager.contains(self.t1)) + + def test_set(self): + with AliasManager(): + AliasManager.set(self.t1, 'foo') + self.assertEqual(AliasManager.get(self.t1), 'foo') diff --git a/sql/tests/test_collate.py b/sql/tests/test_collate.py index a06a6ba..381f7ca 100644 --- a/sql/tests/test_collate.py +++ b/sql/tests/test_collate.py @@ -17,8 +17,3 @@ def test_collate_no_expression(self): collate = Collate("foo", 'C') self.assertEqual(str(collate), '%s COLLATE "C"') self.assertEqual(collate.params, ("foo",)) - - def test_collate_injection(self): - collate = Collate(self.column, 'C";') - with self.assertRaises(ValueError): - str(collate) diff --git a/sql/tests/test_combining_query.py b/sql/tests/test_combining_query.py index 2d6ca47..caa2845 100644 --- a/sql/tests/test_combining_query.py +++ b/sql/tests/test_combining_query.py @@ -2,7 +2,7 @@ # this repository contains the full copyright notices and license terms. import unittest -from sql import Table, Union +from sql import CombiningQuery, Table, Union, With class TestUnion(unittest.TestCase): @@ -10,6 +10,10 @@ class TestUnion(unittest.TestCase): q2 = Table('t2').select() q3 = Table('t3').select() + def test_invalid_queries(self): + with self.assertRaises(ValueError): + CombiningQuery('foo', 'bar') + def test_union2(self): query = Union(self.q1, self.q2) self.assertEqual(str(query), @@ -21,6 +25,18 @@ def test_union2(self): 'SELECT * FROM "t1" AS "a" UNION SELECT * FROM "t2" AS "b"') self.assertEqual(tuple(query.params), ()) + def test_union_with(self): + table = Table('t') + with_ = With() + with_.query = table.select(table.id, where=table.id == 1) + query = Union(self.q1, self.q2, with_=with_) + + self.assertEqual(str(query), + 'WITH "a" AS (' + 'SELECT "b"."id" FROM "t" AS "b" WHERE "b"."id" = %s) ' + 'SELECT * FROM "t1" AS "c" UNION SELECT * FROM "t2" AS "d"') + self.assertEqual(tuple(query.params), (1,)) + def test_union3(self): query = Union(self.q1, self.q2, self.q3) self.assertEqual(str(query), diff --git a/sql/tests/test_conditionals.py b/sql/tests/test_conditionals.py index 937d7e3..7945f59 100644 --- a/sql/tests/test_conditionals.py +++ b/sql/tests/test_conditionals.py @@ -36,9 +36,9 @@ def test_case_sql(self): where=self.table.c2 == 'foo')) self.assertEqual(str(case), 'CASE WHEN ' - '(SELECT "a"."bool" FROM "t" AS "a" WHERE ("a"."c2" = %s)) ' + '(SELECT "a"."bool" FROM "t" AS "a" WHERE "a"."c2" = %s) ' 'THEN "c1" ' - 'ELSE (SELECT "a"."c1" FROM "t" AS "a" WHERE ("a"."c2" = %s)) END') + 'ELSE (SELECT "a"."c1" FROM "t" AS "a" WHERE "a"."c2" = %s) END') self.assertEqual(case.params, ('bar', 'foo')) def test_coalesce(self): @@ -52,7 +52,7 @@ def test_coalesce_sql(self): self.table.c2) self.assertEqual(str(coalesce), 'COALESCE(' - '(SELECT "a"."c1" FROM "t" AS "a" WHERE ("a"."c2" = %s)), "c2")') + '(SELECT "a"."c1" FROM "t" AS "a" WHERE "a"."c2" = %s), "c2")') self.assertEqual(coalesce.params, ('bar',)) def test_nullif(self): diff --git a/sql/tests/test_delete.py b/sql/tests/test_delete.py index 764b87f..2564991 100644 --- a/sql/tests/test_delete.py +++ b/sql/tests/test_delete.py @@ -2,7 +2,7 @@ # this repository contains the full copyright notices and license terms. import unittest -from sql import Table, With +from sql import Delete, Table, With class TestDelete(unittest.TestCase): @@ -15,7 +15,7 @@ def test_delete1(self): def test_delete2(self): query = self.table.delete(where=(self.table.c == 'foo')) - self.assertEqual(str(query), 'DELETE FROM "t" WHERE ("c" = %s)') + self.assertEqual(str(query), 'DELETE FROM "t" WHERE "c" = %s') self.assertEqual(query.params, ('foo',)) def test_delete3(self): @@ -23,15 +23,35 @@ def test_delete3(self): t2 = Table('t2') query = t1.delete(where=(t1.c.in_(t2.select(t2.c)))) self.assertEqual(str(query), - 'DELETE FROM "t1" WHERE ("c" IN (' - 'SELECT "a"."c" FROM "t2" AS "a"))') + 'DELETE FROM "t1" WHERE "c" IN (' + 'SELECT "a"."c" FROM "t2" AS "a")') self.assertEqual(query.params, ()) + def test_delete_invalid_table(self): + with self.assertRaises(ValueError): + Delete('foo') + + def test_delete_invalid_where(self): + with self.assertRaises(ValueError): + self.table.delete(where='foo') + def test_delete_returning(self): query = self.table.delete(returning=[self.table.c]) self.assertEqual(str(query), 'DELETE FROM "t" RETURNING "c"') self.assertEqual(query.params, ()) + def test_delet_returning_select(self): + query = self.table.delete(returning=[self.table.select()]) + + self.assertEqual( + str(query), + 'DELETE FROM "t" RETURNING (SELECT * FROM "t")') + self.assertEqual(query.params, ()) + + def test_delete_invalid_returning(self): + with self.assertRaises(ValueError): + self.table.delete(returning='foo') + def test_with(self): t1 = Table('t1') w = With(query=t1.select(t1.c1)) @@ -41,5 +61,5 @@ def test_with(self): self.assertEqual(str(query), 'WITH "a" AS (SELECT "b"."c1" FROM "t1" AS "b") ' 'DELETE FROM "t" WHERE ' - '("c2" IN (SELECT "a"."c3" FROM "a" AS "a"))') + '"c2" IN (SELECT "a"."c3" FROM "a" AS "a")') self.assertEqual(query.params, ()) diff --git a/sql/tests/test_excluded.py b/sql/tests/test_excluded.py new file mode 100644 index 0000000..fb47e53 --- /dev/null +++ b/sql/tests/test_excluded.py @@ -0,0 +1,15 @@ +# This file is part of python-sql. The COPYRIGHT file at the top level of +# this repository contains the full copyright notices and license terms. + +import unittest + +from sql import Excluded + + +class TestExcluded(unittest.TestCase): + + def test_alias(self): + self.assertEqual(Excluded.alias, 'EXCLUDED') + + def test_has_alias(self): + self.assertFalse(Excluded.has_alias) diff --git a/sql/tests/test_expression.py b/sql/tests/test_expression.py new file mode 100644 index 0000000..0741094 --- /dev/null +++ b/sql/tests/test_expression.py @@ -0,0 +1,17 @@ +# This file is part of python-sql. The COPYRIGHT file at the top level of +# this repository contains the full copyright notices and license terms. + +import unittest + +from sql import Expression + + +class TestExpression(unittest.TestCase): + + def test_str(self): + with self.assertRaises(NotImplementedError): + str(Expression()) + + def test_params(self): + with self.assertRaises(NotImplementedError): + Expression().params diff --git a/sql/tests/test_flavor.py b/sql/tests/test_flavor.py new file mode 100644 index 0000000..205422e --- /dev/null +++ b/sql/tests/test_flavor.py @@ -0,0 +1,46 @@ +# This file is part of python-sql. The COPYRIGHT file at the top level of +# this repository contains the full copyright notices and license terms. + +import unittest + +from sql import Flavor + + +class TestFlavor(unittest.TestCase): + + def test(self): + Flavor() + + def test_limitstyle(self): + flavor = Flavor(limitstyle='rownum') + + self.assertEqual(flavor.limitstyle, 'rownum') + + def test_invalid_limitstyle(self): + with self.assertRaises(ValueError): + Flavor(limitstyle='foo') + + def test_max_limit(self): + flavor = Flavor(max_limit=42) + + self.assertEqual(flavor.max_limit, 42) + + def test_invalid_max_limit(self): + with self.assertRaises(ValueError): + Flavor(max_limit='foo') + + def test_paramstyle_format(self): + flavor = Flavor(paramstyle='format') + + self.assertEqual(flavor.paramstyle, 'format') + self.assertEqual(flavor.param, '%s') + + def test_paramstyle_qmark(self): + flavor = Flavor(paramstyle='qmark') + + self.assertEqual(flavor.paramstyle, 'qmark') + self.assertEqual(flavor.param, '?') + + def test_invalid_paramstyle(self): + with self.assertRaises(ValueError): + Flavor(paramstyle='foo') diff --git a/sql/tests/test_for.py b/sql/tests/test_for.py index 8003257..599de0e 100644 --- a/sql/tests/test_for.py +++ b/sql/tests/test_for.py @@ -14,3 +14,7 @@ def test_for_single_table(self): for_ = For('UPDATE') for_.tables = Table('t1') self.assertEqual(str(for_), 'FOR UPDATE OF "t1"') + + def test_invalid_type(self): + with self.assertRaises(ValueError): + For('foo') diff --git a/sql/tests/test_from.py b/sql/tests/test_from.py new file mode 100644 index 0000000..98cb6c2 --- /dev/null +++ b/sql/tests/test_from.py @@ -0,0 +1,25 @@ +# This file is part of python-sql. The COPYRIGHT file at the top level of +# this repository contains the full copyright notices and license terms. + +import unittest + +from sql import CombiningQuery, From, Table + + +class TestFrom(unittest.TestCase): + + def test_add(self): + t1 = Table('t1') + t2 = Table('t2') + from_ = From([t1]) + t2 + + self.assertEqual(from_, [t1, t2]) + + def test_invalid_add(self): + with self.assertRaises(TypeError): + From([Table('t')]) + 'foo' + + def test_invalid_add_combining_query(self): + with self.assertRaises(TypeError): + From([Table('t')]) + CombiningQuery( + Table('t1').select(), Table('t2').select()) diff --git a/sql/tests/test_from_item.py b/sql/tests/test_from_item.py new file mode 100644 index 0000000..9b0a465 --- /dev/null +++ b/sql/tests/test_from_item.py @@ -0,0 +1,46 @@ +# This file is part of python-sql. The COPYRIGHT file at the top level of +# this repository contains the full copyright notices and license terms. + +import unittest + +from sql import AliasManager, Column, From, FromItem + + +class TestFromItem(unittest.TestCase): + + def test_from_item(self): + from_item = FromItem() + + with AliasManager(): + self.assertFalse(from_item.has_alias) + from_item.alias + self.assertTrue(from_item.has_alias) + + def test_get_column(self): + from_item = FromItem() + + foo = from_item.foo + + self.assertIsInstance(foo, Column) + self.assertEqual(foo.name, 'foo') + + def test_get_invalid_column(self): + from_item = FromItem() + + with self.assertRaises(AttributeError): + from_item.__foo__ + + def test_add(self): + from_item1 = FromItem() + from_item2 = FromItem() + + from_ = from_item1 + from_item2 + + self.assertIsInstance(from_, From) + self.assertEqual(from_, [from_item1, from_item2]) + + def test_invalid_add(self): + from_item = FromItem() + + with self.assertRaises(TypeError): + from_item + 'foo' diff --git a/sql/tests/test_functions.py b/sql/tests/test_functions.py index 51f7022..0b5d3a6 100644 --- a/sql/tests/test_functions.py +++ b/sql/tests/test_functions.py @@ -4,13 +4,17 @@ from sql import AliasManager, Flavor, Table, Window from sql.functions import ( - Abs, AtTimeZone, CurrentTime, Div, Function, FunctionKeyword, - FunctionNotCallable, Overlay, Rank, Trim) + Abs, AtTimeZone, CurrentTime, Div, Extract, Function, FunctionKeyword, + FunctionNotCallable, Overlay, Rank, Trim, WindowFunction) class TestFunctions(unittest.TestCase): table = Table('t') + def test_invalid_columns_definitions(self): + with self.assertRaises(ValueError): + Function(columns_definitions='foo') + def test_abs(self): abs_ = Abs(self.table.c1) self.assertEqual(str(abs_), 'ABS("c1")') @@ -66,7 +70,7 @@ def test_sql(self): abs_ = Abs(self.table.select(self.table.c1, where=self.table.c2 == 'foo')) self.assertEqual(str(abs_), - 'ABS((SELECT "a"."c1" FROM "t" AS "a" WHERE ("a"."c2" = %s)))') + 'ABS((SELECT "a"."c1" FROM "t" AS "a" WHERE "a"."c2" = %s))') self.assertEqual(abs_.params, ('foo',)) def test_overlay(self): @@ -87,6 +91,10 @@ def test_trim(self): self.assertEqual(str(trim), 'TRIM(BOTH %s FROM "c1")') self.assertEqual(trim.params, (' ',)) + def test_trim_invalid_position(self): + with self.assertRaises(ValueError): + Trim('test', 'foo') + def test_at_time_zone(self): time_zone = AtTimeZone(self.table.c1, 'UTC') self.assertEqual(str(time_zone), '"c1" AT TIME ZONE %s') @@ -102,7 +110,7 @@ def test_at_time_zone_sql(self): self.table.select(self.table.tz, where=self.table.c1 == 'foo')) self.assertEqual(str(time_zone), '"c1" AT TIME ZONE ' - '(SELECT "a"."tz" FROM "t" AS "a" WHERE ("a"."c1" = %s))') + '(SELECT "a"."tz" FROM "t" AS "a" WHERE "a"."c1" = %s)') self.assertEqual(time_zone.params, ('foo',)) def test_at_time_zone_mapping(self): @@ -131,6 +139,38 @@ def test_current_time(self): self.assertEqual(str(current_time), 'CURRENT_TIME') self.assertEqual(current_time.params, ()) + def test_extract(self): + extract = Extract(Extract.Fields.DAY, self.table.c) + self.assertEqual(str(extract), 'EXTRACT(DAY FROM "c")') + self.assertEqual(extract.params, ()) + + extract = Extract('day', self.table.c) + self.assertEqual(str(extract), 'EXTRACT(DAY FROM "c")') + self.assertEqual(extract.params, ()) + + extract = Extract(Extract.Fields.DAY, '2000-01-01') + self.assertEqual(str(extract), 'EXTRACT(DAY FROM %s)') + self.assertEqual(extract.params, ('2000-01-01',)) + + def test_extract_mapping(self): + class MyExtract(Function): + _function = 'MY_EXTRACT' + + extract = Extract(Extract.Fields.DAY, '2000-01-01') + flavor = Flavor(function_mapping={ + Extract: MyExtract, + }) + Flavor.set(flavor) + try: + self.assertEqual(str(extract), 'MY_EXTRACT(%s, %s)') + self.assertEqual(extract.params, ('DAY', '2000-01-01')) + finally: + Flavor.set(Flavor()) + + def test_extract_invalid_field(self): + with self.assertRaises(ValueError): + Extract('foo', self.table.c) + class TestWindowFunction(unittest.TestCase): @@ -142,11 +182,19 @@ def test_window(self): self.assertEqual(str(function), 'RANK("a"."c") OVER ()') self.assertEqual(function.params, ()) + def test_invalid_window(self): + with self.assertRaises(ValueError): + WindowFunction(window='foo') + def test_filter(self): t = Table('t') function = Rank(t.c, filter_=t.c > 0, window=Window([])) with AliasManager(): self.assertEqual(str(function), - 'RANK("a"."c") FILTER (WHERE ("a"."c" > %s)) OVER ()') + 'RANK("a"."c") FILTER (WHERE "a"."c" > %s) OVER ()') self.assertEqual(function.params, (0,)) + + def test_invalid_filter(self): + with self.assertRaises(ValueError): + WindowFunction(filter_='foo', window=Window([])) diff --git a/sql/tests/test_grouping.py b/sql/tests/test_grouping.py new file mode 100644 index 0000000..7da9757 --- /dev/null +++ b/sql/tests/test_grouping.py @@ -0,0 +1,13 @@ +# This file is part of python-sql. The COPYRIGHT file at the top level of +# this repository contains the full copyright notices and license terms. + +import unittest + +from sql import Grouping + + +class TestGrouping(unittest.TestCase): + + def test_invalid_sets(self): + with self.assertRaises(ValueError): + Grouping('foo') diff --git a/sql/tests/test_insert.py b/sql/tests/test_insert.py index 1921bf8..02a5439 100644 --- a/sql/tests/test_insert.py +++ b/sql/tests/test_insert.py @@ -2,30 +2,42 @@ # this repository contains the full copyright notices and license terms. import unittest -from sql import Table, With +from sql import Conflict, Excluded, Insert, Table, With from sql.functions import Abs class TestInsert(unittest.TestCase): table = Table('t') + def test_insert_invalid_table(self): + with self.assertRaises(ValueError): + Insert('foo') + + def test_insert_invalid_columns(self): + with self.assertRaises(ValueError): + self.table.insert(['foo'], [['foo']]) + + def test_insert_invalid_values(self): + with self.assertRaises(ValueError): + self.table.insert([self.table.c], 'foo') + def test_insert_default(self): query = self.table.insert() - self.assertEqual(str(query), 'INSERT INTO "t" AS "a" DEFAULT VALUES') + self.assertEqual(str(query), 'INSERT INTO "t" DEFAULT VALUES') self.assertEqual(tuple(query.params), ()) def test_insert_values(self): query = self.table.insert([self.table.c1, self.table.c2], [['foo', 'bar']]) self.assertEqual(str(query), - 'INSERT INTO "t" AS "a" ("c1", "c2") VALUES (%s, %s)') + 'INSERT INTO "t" ("c1", "c2") VALUES (%s, %s)') self.assertEqual(tuple(query.params), ('foo', 'bar')) def test_insert_many_values(self): query = self.table.insert([self.table.c1, self.table.c2], [['foo', 'bar'], ['spam', 'eggs']]) self.assertEqual(str(query), - 'INSERT INTO "t" AS "a" ("c1", "c2") VALUES (%s, %s), (%s, %s)') + 'INSERT INTO "t" ("c1", "c2") VALUES (%s, %s), (%s, %s)') self.assertEqual(tuple(query.params), ('foo', 'bar', 'spam', 'eggs')) def test_insert_subselect(self): @@ -34,14 +46,14 @@ def test_insert_subselect(self): subquery = t2.select(t2.c1, t2.c2) query = t1.insert([t1.c1, t1.c2], subquery) self.assertEqual(str(query), - 'INSERT INTO "t1" AS "b" ("c1", "c2") ' + 'INSERT INTO "t1" ("c1", "c2") ' 'SELECT "a"."c1", "a"."c2" FROM "t2" AS "a"') self.assertEqual(tuple(query.params), ()) def test_insert_function(self): query = self.table.insert([self.table.c], [[Abs(-1)]]) self.assertEqual(str(query), - 'INSERT INTO "t" AS "a" ("c") VALUES (ABS(%s))') + 'INSERT INTO "t" ("c") VALUES (ABS(%s))') self.assertEqual(tuple(query.params), (-1,)) def test_insert_returning(self): @@ -61,9 +73,13 @@ def test_insert_returning_select(self): self.assertEqual(str(query), 'INSERT INTO "t1" AS "b" ("c") VALUES (%s) ' 'RETURNING (SELECT "a"."c" FROM "t2" AS "a" ' - 'WHERE (("a"."c1" = "b"."c") AND ("a"."c2" = %s)))') + 'WHERE ("a"."c1" = "b"."c") AND ("a"."c2" = %s))') self.assertEqual(tuple(query.params), ('foo', 'bar')) + def test_insert_invalid_returning(self): + with self.assertRaises(ValueError): + self.table.insert(returning='foo') + def test_with(self): t1 = Table('t1') w = With(query=t1.select()) @@ -74,7 +90,7 @@ def test_with(self): values=w.select()) self.assertEqual(str(query), 'WITH "a" AS (SELECT * FROM "t1" AS "b") ' - 'INSERT INTO "t" AS "c" ("c1") SELECT * FROM "a" AS "a"') + 'INSERT INTO "t" ("c1") SELECT * FROM "a" AS "a"') self.assertEqual(tuple(query.params), ()) def test_insert_in_with(self): @@ -101,5 +117,138 @@ def test_schema(self): query = t1.insert([t1.c1], [['foo']]) self.assertEqual(str(query), - 'INSERT INTO "default"."t1" AS "a" ("c1") VALUES (%s)') + 'INSERT INTO "default"."t1" ("c1") VALUES (%s)') + self.assertEqual(tuple(query.params), ('foo',)) + + def test_upsert_invalid_on_conflict(self): + with self.assertRaises(ValueError): + self.table.insert(on_conflict='foo') + + def test_upsert_invalid_table_on_conflict(self): + with self.assertRaises(ValueError): + self.table.insert(on_conflict=Conflict(Table('t1'))) + + def test_upsert_nothing(self): + query = self.table.insert( + [self.table.c1], [['foo']], + on_conflict=Conflict(self.table)) + + self.assertEqual(str(query), + 'INSERT INTO "t" AS "a" ("c1") VALUES (%s) ' + 'ON CONFLICT DO NOTHING') + self.assertEqual(tuple(query.params), ('foo',)) + + def test_upsert_indexed_column(self): + query = self.table.insert( + [self.table.c1], [['foo']], + on_conflict=Conflict( + self.table, + indexed_columns=[self.table.c1, self.table.c2])) + + self.assertEqual(str(query), + 'INSERT INTO "t" AS "a" ("c1") VALUES (%s) ' + 'ON CONFLICT ("c1", "c2") DO NOTHING') self.assertEqual(tuple(query.params), ('foo',)) + + def test_upsert_indexed_column_index_where(self): + query = self.table.insert( + [self.table.c1], [['foo']], + on_conflict=Conflict( + self.table, + indexed_columns=[self.table.c1], + index_where=self.table.c2 == 'bar')) + + self.assertEqual(str(query), + 'INSERT INTO "t" AS "a" ("c1") VALUES (%s) ' + 'ON CONFLICT ("c1") WHERE "a"."c2" = %s DO NOTHING') + self.assertEqual(tuple(query.params), ('foo', 'bar')) + + def test_upsert_update(self): + query = self.table.insert( + [self.table.c1], [['baz']], + on_conflict=Conflict( + self.table, + columns=[self.table.c1, self.table.c2], + values=['foo', 'bar'])) + + self.assertEqual(str(query), + 'INSERT INTO "t" AS "a" ("c1") VALUES (%s) ' + 'ON CONFLICT DO UPDATE SET ("c1", "c2") = (%s, %s)') + self.assertEqual(tuple(query.params), ('baz', 'foo', 'bar')) + + def test_upsert_update_where(self): + query = self.table.insert( + [self.table.c1], [['baz']], + on_conflict=Conflict( + self.table, + columns=[self.table.c1], + values=['foo'], + where=self.table.c2 == 'bar')) + + self.assertEqual(str(query), + 'INSERT INTO "t" AS "a" ("c1") VALUES (%s) ' + 'ON CONFLICT DO UPDATE SET "c1" = (%s) ' + 'WHERE "a"."c2" = %s') + self.assertEqual(tuple(query.params), ('baz', 'foo', 'bar')) + + def test_upsert_update_subquery(self): + t1 = Table('t1') + t2 = Table('t2') + subquery = t2.select(t2.c1, t2.c2) + query = t1.insert( + [t1.c1], [['baz']], + on_conflict=Conflict( + t1, + columns=[t1.c1, t1.c2], + values=subquery)) + + self.assertEqual(str(query), + 'INSERT INTO "t1" AS "b" ("c1") VALUES (%s) ' + 'ON CONFLICT DO UPDATE SET ("c1", "c2") = ' + '(SELECT "a"."c1", "a"."c2" FROM "t2" AS "a")') + self.assertEqual(tuple(query.params), ('baz',)) + + def test_upsert_update_excluded(self): + query = self.table.insert( + [self.table.c1], [[1]], + on_conflict=Conflict( + self.table, + columns=[self.table.c1], + values=[Excluded.c1 + 2])) + + self.assertEqual(str(query), + 'INSERT INTO "t" AS "a" ("c1") VALUES (%s) ' + 'ON CONFLICT DO UPDATE SET "c1" = ("EXCLUDED"."c1" + %s)') + self.assertEqual(tuple(query.params), (1, 2)) + + def test_conflict_invalid_table(self): + with self.assertRaises(ValueError): + Conflict('foo') + + def test_conflict_invalid_indexed_columns(self): + with self.assertRaises(ValueError): + Conflict(self.table, indexed_columns=['foo']) + + def test_conflict_indexed_columns_invalid_table(self): + with self.assertRaises(ValueError): + Conflict(self.table, indexed_columns=[Table('t').c]) + + def test_conflict_invalid_index_where(self): + with self.assertRaises(ValueError): + Conflict(self.table, index_where='foo') + + def test_conflict_invalid_columns(self): + with self.assertRaises(ValueError): + Conflict(self.table, columns=['foo']) + + def test_conflict_columns_invalid_table(self): + with self.assertRaises(ValueError): + Conflict(self.table, columns=[Table('t').c]) + + def test_conflict_invalid_values(self): + with self.assertRaises(ValueError): + Conflict(self.table, values='foo') + + def test_conflict_invalid_where(self): + with self.assertRaises(ValueError): + Conflict(self.table, where='foo') diff --git a/sql/tests/test_join.py b/sql/tests/test_join.py index 7e537c9..1c8c576 100644 --- a/sql/tests/test_join.py +++ b/sql/tests/test_join.py @@ -18,7 +18,23 @@ def test_join(self): join.condition = t1.c == t2.c with AliasManager(): self.assertEqual(str(join), - '"t1" AS "a" INNER JOIN "t2" AS "b" ON ("a"."c" = "b"."c")') + '"t1" AS "a" INNER JOIN "t2" AS "b" ON "a"."c" = "b"."c"') + + def test_join_invalid_left(self): + with self.assertRaises(ValueError): + Join('foo', Table('t1')) + + def test_join_invalid_right(self): + with self.assertRaises(ValueError): + Join(Table('t1'), 'foo') + + def test_join_invalid_condition(self): + with self.assertRaises(ValueError): + Join(Table('t1'), Table('t2'), condition='foo') + + def test_join_invalid_type(self): + with self.assertRaises(ValueError): + Join(Table('t1'), Table('t2'), type_='foo') def test_join_subselect(self): t1 = Table('t1') @@ -29,7 +45,7 @@ def test_join_subselect(self): with AliasManager(): self.assertEqual(str(join), '"t1" AS "a" INNER JOIN (SELECT * FROM "t2" AS "c") AS "b" ' - 'ON ("a"."c" = "b"."c")') + 'ON "a"."c" = "b"."c"') self.assertEqual(tuple(join.params), ()) def test_join_function(self): @@ -50,3 +66,10 @@ def test_join_methods(self): join = getattr(t1, method)(t2) type_ = method[:-len('_join')].replace('_', ' ').upper() self.assertEqual(join.type_, type_) + + def test_join_alias(self): + join = Join(Table('t1'), Table('t2')) + with self.assertRaises(AttributeError): + join.alias + with self.assertRaises(AttributeError): + join.has_alias diff --git a/sql/tests/test_lateral.py b/sql/tests/test_lateral.py index d70bee0..4764079 100644 --- a/sql/tests/test_lateral.py +++ b/sql/tests/test_lateral.py @@ -11,12 +11,12 @@ class TestLateral(unittest.TestCase): def test_lateral_select(self): t1 = Table('t1') t2 = Table('t2') - lateral = Lateral(t2.select(where=t2.id == t1.t2)) + lateral = t2.select(where=t2.id == t1.t2).lateral() query = From([t1, lateral]).select() self.assertEqual(str(query), 'SELECT * FROM "t1" AS "a", LATERAL ' - '(SELECT * FROM "t2" AS "c" WHERE ("c"."id" = "a"."t2")) AS "b"') + '(SELECT * FROM "t2" AS "c" WHERE "c"."id" = "a"."t2") AS "b"') self.assertEqual(tuple(query.params), ()) def test_lateral_function(self): diff --git a/sql/tests/test_merge.py b/sql/tests/test_merge.py new file mode 100644 index 0000000..08bebae --- /dev/null +++ b/sql/tests/test_merge.py @@ -0,0 +1,149 @@ +# This file is part of python-sql. The COPYRIGHT file at the top level of +# this repository contains the full copyright notices and license terms. + +import unittest + +from sql import ( + Literal, Matched, MatchedDelete, MatchedUpdate, Merge, NotMatched, + NotMatchedInsert, Table, With) + + +class TestMerge(unittest.TestCase): + target = Table('t') + source = Table('s') + + def test_merge(self): + query = self.target.merge( + self.source, self.target.c1 == self.source.c2, Matched()) + self.assertEqual( + str(query), + 'MERGE INTO "t" AS "a" USING "s" AS "b" ' + 'ON "a"."c1" = "b"."c2" ' + 'WHEN MATCHED THEN DO NOTHING') + self.assertEqual(query.params, ()) + + def test_merge_invalid_target(self): + with self.assertRaises(ValueError): + Merge('foo', self.source, Literal(True)) + + def test_merge_invalid_source(self): + with self.assertRaises(ValueError): + self.target.merge('foo', Literal(True)) + + def test_merge_invalid_condition(self): + with self.assertRaises(ValueError): + self.target.merge(self.source, 'foo') + + def test_merge_invalid_whens(self): + with self.assertRaises(ValueError): + self.target.merge(self.source, Literal(True), 'foo') + + def test_condition(self): + query = self.target.merge( + self.source, + (self.target.c1 == self.source.c2) & (self.target.c3 == 42), + Matched()) + self.assertEqual( + str(query), + 'MERGE INTO "t" AS "a" USING "s" AS "b" ' + 'ON ("a"."c1" = "b"."c2") AND ("a"."c3" = %s) ' + 'WHEN MATCHED THEN DO NOTHING') + self.assertEqual(query.params, (42,)) + + def test_matched(self): + query = self.target.merge( + self.source, self.target.c1 == self.source.c2, + Matched((self.source.c3 == 42) + & (self.target.c4 == self.source.c5))) + self.assertEqual( + str(query), + 'MERGE INTO "t" AS "a" USING "s" AS "b" ' + 'ON "a"."c1" = "b"."c2" ' + 'WHEN MATCHED ' + 'AND ("b"."c3" = %s) AND ("a"."c4" = "b"."c5") ' + 'THEN DO NOTHING') + self.assertEqual(query.params, (42,)) + + def test_matched_update(self): + query = self.target.merge( + self.source, self.target.c1 == self.source.c2, + MatchedUpdate( + [self.target.c1, self.target.c2], + [self.target.c1 + self.source.c2, 42])) + self.assertEqual( + str(query), + 'MERGE INTO "t" AS "a" USING "s" AS "b" ' + 'ON "a"."c1" = "b"."c2" ' + 'WHEN MATCHED THEN ' + 'UPDATE SET "c1" = "a"."c1" + "b"."c2", "c2" = %s') + self.assertEqual(query.params, (42,)) + + def test_matched_delete(self): + query = self.target.merge( + self.source, self.target.c1 == self.source.c2, MatchedDelete()) + self.assertEqual( + str(query), + 'MERGE INTO "t" AS "a" USING "s" AS "b" ' + 'ON "a"."c1" = "b"."c2" ' + 'WHEN MATCHED THEN DELETE') + self.assertEqual(query.params, ()) + + def test_not_matched(self): + query = self.target.merge( + self.source, self.target.c1 == self.source.c2, NotMatched()) + self.assertEqual( + str(query), + 'MERGE INTO "t" AS "a" USING "s" AS "b" ' + 'ON "a"."c1" = "b"."c2" ' + 'WHEN NOT MATCHED THEN DO NOTHING') + self.assertEqual(query.params, ()) + + def test_not_matched_insert(self): + query = self.target.merge( + self.source, self.target.c1 == self.source.c2, + NotMatchedInsert( + [self.target.c1, self.target.c2], + [self.source.c3, self.source.c4])) + self.assertEqual( + str(query), + 'MERGE INTO "t" AS "a" USING "s" AS "b" ' + 'ON "a"."c1" = "b"."c2" ' + 'WHEN NOT MATCHED THEN ' + 'INSERT ("c1", "c2") VALUES ("b"."c3", "b"."c4")') + self.assertEqual(query.params, ()) + + def test_not_matched_insert_default(self): + query = self.target.merge( + self.source, self.target.c1 == self.source.c2, + NotMatchedInsert([self.target.c1, self.target.c2], None)) + self.assertEqual( + str(query), + 'MERGE INTO "t" AS "a" USING "s" AS "b" ' + 'ON "a"."c1" = "b"."c2" ' + 'WHEN NOT MATCHED THEN ' + 'INSERT ("c1", "c2") DEFAULT VALUES') + self.assertEqual(query.params, ()) + + def test_matched_invalid_condition(self): + with self.assertRaises(ValueError): + Matched('foo') + + def test_matched_values_invalid_columns(self): + with self.assertRaises(ValueError): + MatchedUpdate('foo', []) + + def test_with(self): + t1 = Table('t1') + w = With(query=t1.select(where=t1.c2 == 42)) + source = w.select() + + query = self.target.merge( + source, self.target.c1 == source.c2, Matched(), with_=[w]) + self.assertEqual( + str(query), + 'WITH "a" AS (SELECT * FROM "t1" AS "d" WHERE "d"."c2" = %s) ' + 'MERGE INTO "t" AS "b" ' + 'USING (SELECT * FROM "a" AS "a") AS "c" ' + 'ON "b"."c1" = "c"."c2" ' + 'WHEN MATCHED THEN DO NOTHING') + self.assertEqual(query.params, (42,)) diff --git a/sql/tests/test_operators.py b/sql/tests/test_operators.py index b126454..4fed465 100644 --- a/sql/tests/test_operators.py +++ b/sql/tests/test_operators.py @@ -6,42 +6,53 @@ from sql import Flavor, Literal, Null, Table from sql.operators import ( - Abs, And, Between, Div, Equal, Exists, FloorDiv, Greater, GreaterEqual, - ILike, In, Is, IsDistinct, IsNot, IsNotDistinct, Less, LessEqual, Like, - LShift, Mod, Mul, Neg, Not, NotBetween, NotEqual, NotILike, NotIn, NotLike, - Or, Pos, Pow, RShift, Sub) + Abs, And, Any, Between, Div, Equal, Exists, FloorDiv, Greater, + GreaterEqual, ILike, In, Is, IsDistinct, IsNot, IsNotDistinct, Less, + LessEqual, Like, LShift, Mod, Mul, Neg, Not, NotBetween, NotEqual, + NotILike, NotIn, NotLike, Operator, Or, Pos, Pow, RShift, Sub) class TestOperators(unittest.TestCase): table = Table('t') + def test_operator_operands(self): + self.assertEqual(Operator()._operands, ()) + + def test_operator_str(self): + with self.assertRaises(NotImplementedError): + str(Operator()) + def test_and(self): for and_ in [And((self.table.c1, self.table.c2)), self.table.c1 & self.table.c2]: - self.assertEqual(str(and_), '("c1" AND "c2")') + self.assertEqual(str(and_), '"c1" AND "c2"') self.assertEqual(and_.params, ()) and_ = And((Literal(True), self.table.c2)) - self.assertEqual(str(and_), '(%s AND "c2")') + self.assertEqual(str(and_), '%s AND "c2"') self.assertEqual(and_.params, (True,)) + and_ = And((Literal(True), 'foo')) + self.assertEqual(str(and_), '%s AND %s') + self.assertEqual(and_.params, (True, 'foo')) + def test_operator_operators(self): and_ = And((Literal(True), self.table.c1)) and2 = and_ & And((Literal(True), self.table.c2)) - self.assertEqual(str(and2), '((%s AND "c1") AND %s AND "c2")') + self.assertEqual(str(and2), '(%s AND "c1") AND %s AND "c2"') self.assertEqual(and2.params, (True, True)) and3 = and_ & Literal(True) - self.assertEqual(str(and3), '((%s AND "c1") AND %s)') + self.assertEqual(str(and3), '(%s AND "c1") AND %s') self.assertEqual(and3.params, (True, True)) or_ = Or((Literal(True), self.table.c1)) or2 = or_ | Or((Literal(True), self.table.c2)) - self.assertEqual(str(or2), '((%s OR "c1") OR %s OR "c2")') + self.assertEqual(str(or2), '(%s OR "c1") OR %s OR "c2"') self.assertEqual(or2.params, (True, True)) or3 = or_ | Literal(True) - self.assertEqual(str(or3), '((%s OR "c1") OR %s)') + self.assertEqual(str(or3), '(%s OR "c1") OR %s') self.assertEqual(or3.params, (True, True)) def test_operator_compat_column(self): @@ -52,192 +63,196 @@ def test_operator_compat_column(self): def test_or(self): for or_ in [Or((self.table.c1, self.table.c2)), self.table.c1 | self.table.c2]: - self.assertEqual(str(or_), '("c1" OR "c2")') + self.assertEqual(str(or_), '"c1" OR "c2"') self.assertEqual(or_.params, ()) def test_not(self): for not_ in [Not(self.table.c), ~self.table.c]: - self.assertEqual(str(not_), '(NOT "c")') + self.assertEqual(str(not_), 'NOT "c"') self.assertEqual(not_.params, ()) not_ = Not(Literal(False)) - self.assertEqual(str(not_), '(NOT %s)') + self.assertEqual(str(not_), 'NOT %s') self.assertEqual(not_.params, (False,)) def test_neg(self): for neg in [Neg(self.table.c1), -self.table.c1]: - self.assertEqual(str(neg), '(- "c1")') + self.assertEqual(str(neg), '- "c1"') self.assertEqual(neg.params, ()) def test_pos(self): for pos in [Pos(self.table.c1), +self.table.c1]: - self.assertEqual(str(pos), '(+ "c1")') + self.assertEqual(str(pos), '+ "c1"') self.assertEqual(pos.params, ()) def test_less(self): for less in [Less(self.table.c1, self.table.c2), self.table.c1 < self.table.c2, ~GreaterEqual(self.table.c1, self.table.c2)]: - self.assertEqual(str(less), '("c1" < "c2")') + self.assertEqual(str(less), '"c1" < "c2"') self.assertEqual(less.params, ()) less = Less(Literal(0), self.table.c2) - self.assertEqual(str(less), '(%s < "c2")') + self.assertEqual(str(less), '%s < "c2"') self.assertEqual(less.params, (0,)) def test_greater(self): for greater in [Greater(self.table.c1, self.table.c2), self.table.c1 > self.table.c2, ~LessEqual(self.table.c1, self.table.c2)]: - self.assertEqual(str(greater), '("c1" > "c2")') + self.assertEqual(str(greater), '"c1" > "c2"') self.assertEqual(greater.params, ()) def test_less_equal(self): for less in [LessEqual(self.table.c1, self.table.c2), self.table.c1 <= self.table.c2, ~Greater(self.table.c1, self.table.c2)]: - self.assertEqual(str(less), '("c1" <= "c2")') + self.assertEqual(str(less), '"c1" <= "c2"') self.assertEqual(less.params, ()) def test_greater_equal(self): for greater in [GreaterEqual(self.table.c1, self.table.c2), self.table.c1 >= self.table.c2, ~Less(self.table.c1, self.table.c2)]: - self.assertEqual(str(greater), '("c1" >= "c2")') + self.assertEqual(str(greater), '"c1" >= "c2"') self.assertEqual(greater.params, ()) def test_equal(self): for equal in [Equal(self.table.c1, self.table.c2), self.table.c1 == self.table.c2, ~NotEqual(self.table.c1, self.table.c2)]: - self.assertEqual(str(equal), '("c1" = "c2")') + self.assertEqual(str(equal), '"c1" = "c2"') self.assertEqual(equal.params, ()) equal = Equal(Literal('foo'), Literal('bar')) - self.assertEqual(str(equal), '(%s = %s)') + self.assertEqual(str(equal), '%s = %s') self.assertEqual(equal.params, ('foo', 'bar')) equal = Equal(self.table.c1, Null) - self.assertEqual(str(equal), '("c1" IS NULL)') + self.assertEqual(str(equal), '"c1" IS NULL') self.assertEqual(equal.params, ()) equal = Equal(Literal('test'), Null) - self.assertEqual(str(equal), '(%s IS NULL)') + self.assertEqual(str(equal), '%s IS NULL') self.assertEqual(equal.params, ('test',)) equal = Equal(Null, self.table.c1) - self.assertEqual(str(equal), '("c1" IS NULL)') + self.assertEqual(str(equal), '"c1" IS NULL') self.assertEqual(equal.params, ()) equal = Equal(Null, Literal('test')) - self.assertEqual(str(equal), '(%s IS NULL)') + self.assertEqual(str(equal), '%s IS NULL') self.assertEqual(equal.params, ('test',)) def test_not_equal(self): for equal in [NotEqual(self.table.c1, self.table.c2), self.table.c1 != self.table.c2, ~Equal(self.table.c1, self.table.c2)]: - self.assertEqual(str(equal), '("c1" != "c2")') + self.assertEqual(str(equal), '"c1" != "c2"') self.assertEqual(equal.params, ()) equal = NotEqual(self.table.c1, Null) - self.assertEqual(str(equal), '("c1" IS NOT NULL)') + self.assertEqual(str(equal), '"c1" IS NOT NULL') self.assertEqual(equal.params, ()) equal = NotEqual(Null, self.table.c1) - self.assertEqual(str(equal), '("c1" IS NOT NULL)') + self.assertEqual(str(equal), '"c1" IS NOT NULL') self.assertEqual(equal.params, ()) def test_between(self): for between in [Between(self.table.c1, 1, 2), ~NotBetween(self.table.c1, 1, 2)]: - self.assertEqual(str(between), '("c1" BETWEEN %s AND %s)') + self.assertEqual(str(between), '"c1" BETWEEN %s AND %s') self.assertEqual(between.params, (1, 2)) between = Between( self.table.c1, self.table.c2, self.table.c3, symmetric=True) self.assertEqual( - str(between), '("c1" BETWEEN SYMMETRIC "c2" AND "c3")') + str(between), '"c1" BETWEEN SYMMETRIC "c2" AND "c3"') self.assertEqual(between.params, ()) def test_not_between(self): for between in [NotBetween(self.table.c1, 1, 2), ~Between(self.table.c1, 1, 2)]: - self.assertEqual(str(between), '("c1" NOT BETWEEN %s AND %s)') + self.assertEqual(str(between), '"c1" NOT BETWEEN %s AND %s') self.assertEqual(between.params, (1, 2)) between = NotBetween( self.table.c1, self.table.c2, self.table.c3, symmetric=True) self.assertEqual( - str(between), '("c1" NOT BETWEEN SYMMETRIC "c2" AND "c3")') + str(between), '"c1" NOT BETWEEN SYMMETRIC "c2" AND "c3"') self.assertEqual(between.params, ()) def test_is_distinct(self): for distinct in [IsDistinct(self.table.c1, self.table.c2), ~IsNotDistinct(self.table.c1, self.table.c2)]: - self.assertEqual(str(distinct), '("c1" IS DISTINCT FROM "c2")') + self.assertEqual(str(distinct), '"c1" IS DISTINCT FROM "c2"') self.assertEqual(distinct.params, ()) def test_is_not_distinct(self): for distinct in [IsNotDistinct(self.table.c1, self.table.c2), ~IsDistinct(self.table.c1, self.table.c2)]: - self.assertEqual(str(distinct), '("c1" IS NOT DISTINCT FROM "c2")') + self.assertEqual(str(distinct), '"c1" IS NOT DISTINCT FROM "c2"') self.assertEqual(distinct.params, ()) def test_is(self): for is_ in [Is(self.table.c1, None), ~IsNot(self.table.c1, None)]: - self.assertEqual(str(is_), '("c1" IS UNKNOWN)') + self.assertEqual(str(is_), '"c1" IS UNKNOWN') self.assertEqual(is_.params, ()) for is_ in [Is(self.table.c1, True), ~IsNot(self.table.c1, True)]: - self.assertEqual(str(is_), '("c1" IS TRUE)') + self.assertEqual(str(is_), '"c1" IS TRUE') self.assertEqual(is_.params, ()) for is_ in [Is(self.table.c1, False), ~IsNot(self.table.c1, False)]: - self.assertEqual(str(is_), '("c1" IS FALSE)') + self.assertEqual(str(is_), '"c1" IS FALSE') self.assertEqual(is_.params, ()) + def test_is_invalid_right(self): + with self.assertRaises(ValueError): + Is(self.table.c, 'foo') + def test_is_not(self): for is_ in [IsNot(self.table.c1, None), ~Is(self.table.c1, None)]: - self.assertEqual(str(is_), '("c1" IS NOT UNKNOWN)') + self.assertEqual(str(is_), '"c1" IS NOT UNKNOWN') self.assertEqual(is_.params, ()) for is_ in [IsNot(self.table.c1, True), ~Is(self.table.c1, True)]: - self.assertEqual(str(is_), '("c1" IS NOT TRUE)') + self.assertEqual(str(is_), '"c1" IS NOT TRUE') self.assertEqual(is_.params, ()) for is_ in [IsNot(self.table.c1, False), ~Is(self.table.c1, False)]: - self.assertEqual(str(is_), '("c1" IS NOT FALSE)') + self.assertEqual(str(is_), '"c1" IS NOT FALSE') self.assertEqual(is_.params, ()) def test_sub(self): for sub in [Sub(self.table.c1, self.table.c2), self.table.c1 - self.table.c2]: - self.assertEqual(str(sub), '("c1" - "c2")') + self.assertEqual(str(sub), '"c1" - "c2"') self.assertEqual(sub.params, ()) def test_mul(self): for mul in [Mul(self.table.c1, self.table.c2), self.table.c1 * self.table.c2]: - self.assertEqual(str(mul), '("c1" * "c2")') + self.assertEqual(str(mul), '"c1" * "c2"') self.assertEqual(mul.params, ()) def test_div(self): for div in [Div(self.table.c1, self.table.c2), self.table.c1 / self.table.c2]: - self.assertEqual(str(div), '("c1" / "c2")') + self.assertEqual(str(div), '"c1" / "c2"') self.assertEqual(div.params, ()) def test_mod(self): for mod in [Mod(self.table.c1, self.table.c2), self.table.c1 % self.table.c2]: - self.assertEqual(str(mod), '("c1" %% "c2")') + self.assertEqual(str(mod), '"c1" %% "c2"') self.assertEqual(mod.params, ()) def test_mod_paramstyle(self): @@ -245,7 +260,7 @@ def test_mod_paramstyle(self): Flavor.set(flavor) try: mod = Mod(self.table.c1, self.table.c2) - self.assertEqual(str(mod), '("c1" %% "c2")') + self.assertEqual(str(mod), '"c1" %% "c2"') self.assertEqual(mod.params, ()) finally: Flavor.set(Flavor()) @@ -254,7 +269,7 @@ def test_mod_paramstyle(self): Flavor.set(flavor) try: mod = Mod(self.table.c1, self.table.c2) - self.assertEqual(str(mod), '("c1" % "c2")') + self.assertEqual(str(mod), '"c1" % "c2"') self.assertEqual(mod.params, ()) finally: Flavor.set(Flavor()) @@ -262,24 +277,24 @@ def test_mod_paramstyle(self): def test_pow(self): for pow_ in [Pow(self.table.c1, self.table.c2), self.table.c1 ** self.table.c2]: - self.assertEqual(str(pow_), '("c1" ^ "c2")') + self.assertEqual(str(pow_), '"c1" ^ "c2"') self.assertEqual(pow_.params, ()) def test_abs(self): for abs_ in [Abs(self.table.c1), abs(self.table.c1)]: - self.assertEqual(str(abs_), '(@ "c1")') + self.assertEqual(str(abs_), '@ "c1"') self.assertEqual(abs_.params, ()) def test_lshift(self): for lshift in [LShift(self.table.c1, 2), self.table.c1 << 2]: - self.assertEqual(str(lshift), '("c1" << %s)') + self.assertEqual(str(lshift), '"c1" << %s') self.assertEqual(lshift.params, (2,)) def test_rshift(self): for rshift in [RShift(self.table.c1, 2), self.table.c1 >> 2]: - self.assertEqual(str(rshift), '("c1" >> %s)') + self.assertEqual(str(rshift), '"c1" >> %s') self.assertEqual(rshift.params, (2,)) def test_like(self): @@ -287,20 +302,20 @@ def test_like(self): self.table.c1.like('foo'), ~NotLike(self.table.c1, 'foo'), ~~Like(self.table.c1, 'foo')]: - self.assertEqual(str(like), '("c1" LIKE %s ESCAPE %s)') - self.assertEqual(like.params, ('foo', '\\')) + self.assertEqual(str(like), '"c1" LIKE %s') + self.assertEqual(like.params, ('foo',)) def test_like_escape(self): like = Like(self.table.c1, 'foo', escape='$') - self.assertEqual(str(like), '("c1" LIKE %s ESCAPE %s)') + self.assertEqual(str(like), '"c1" LIKE %s ESCAPE %s') self.assertEqual(like.params, ('foo', '$')) def test_like_escape_empty_false(self): flavor = Flavor(escape_empty=False) Flavor.set(flavor) try: - like = Like(self.table.c1, 'foo', escape='') - self.assertEqual(str(like), '("c1" LIKE %s)') + like = Like(self.table.c1, 'foo') + self.assertEqual(str(like), '"c1" LIKE %s') self.assertEqual(like.params, ('foo',)) finally: Flavor.set(Flavor()) @@ -309,12 +324,16 @@ def test_like_escape_empty_true(self): flavor = Flavor(escape_empty=True) Flavor.set(flavor) try: - like = Like(self.table.c1, 'foo', escape='') - self.assertEqual(str(like), '("c1" LIKE %s ESCAPE %s)') + like = Like(self.table.c1, 'foo') + self.assertEqual(str(like), '"c1" LIKE %s ESCAPE %s') self.assertEqual(like.params, ('foo', '')) finally: Flavor.set(Flavor()) + def test_like_invalid_escape(self): + with self.assertRaises(ValueError): + Like(self.table.c, 'test', escape='fo') + def test_ilike(self): flavor = Flavor(ilike=True) Flavor.set(flavor) @@ -322,8 +341,8 @@ def test_ilike(self): for like in [ILike(self.table.c1, 'foo'), self.table.c1.ilike('foo'), ~NotILike(self.table.c1, 'foo')]: - self.assertEqual(str(like), '("c1" ILIKE %s ESCAPE %s)') - self.assertEqual(like.params, ('foo', '\\')) + self.assertEqual(str(like), '"c1" ILIKE %s') + self.assertEqual(like.params, ('foo',)) finally: Flavor.set(Flavor()) @@ -332,8 +351,8 @@ def test_ilike(self): try: like = ILike(self.table.c1, 'foo') self.assertEqual( - str(like), '(UPPER("c1") LIKE UPPER(%s) ESCAPE %s)') - self.assertEqual(like.params, ('foo', '\\')) + str(like), 'UPPER("c1") LIKE UPPER(%s)') + self.assertEqual(like.params, ('foo',)) finally: Flavor.set(Flavor()) @@ -343,8 +362,8 @@ def test_not_ilike(self): try: for like in [NotILike(self.table.c1, 'foo'), ~self.table.c1.ilike('foo')]: - self.assertEqual(str(like), '("c1" NOT ILIKE %s ESCAPE %s)') - self.assertEqual(like.params, ('foo', '\\')) + self.assertEqual(str(like), '"c1" NOT ILIKE %s') + self.assertEqual(like.params, ('foo',)) finally: Flavor.set(Flavor()) @@ -353,8 +372,8 @@ def test_not_ilike(self): try: like = NotILike(self.table.c1, 'foo') self.assertEqual( - str(like), '(UPPER("c1") NOT LIKE UPPER(%s) ESCAPE %s)') - self.assertEqual(like.params, ('foo', '\\')) + str(like), 'UPPER("c1") NOT LIKE UPPER(%s)') + self.assertEqual(like.params, ('foo',)) finally: Flavor.set(Flavor()) @@ -362,32 +381,31 @@ def test_in(self): for in_ in [In(self.table.c1, [self.table.c2, 1, Null]), ~NotIn(self.table.c1, [self.table.c2, 1, Null]), ~~In(self.table.c1, [self.table.c2, 1, Null])]: - self.assertEqual(str(in_), '("c1" IN ("c2", %s, %s))') + self.assertEqual(str(in_), '"c1" IN ("c2", %s, %s)') self.assertEqual(in_.params, (1, None)) t2 = Table('t2') in_ = In(self.table.c1, t2.select(t2.c2)) self.assertEqual(str(in_), - '("c1" IN (SELECT "a"."c2" FROM "t2" AS "a"))') + '"c1" IN (SELECT "a"."c2" FROM "t2" AS "a")') self.assertEqual(in_.params, ()) in_ = In(self.table.c1, t2.select(t2.c2) | t2.select(t2.c3)) self.assertEqual(str(in_), - '("c1" IN (SELECT "a"."c2" FROM "t2" AS "a" ' - 'UNION SELECT "a"."c3" FROM "t2" AS "a"))') + '"c1" IN (SELECT "a"."c2" FROM "t2" AS "a" ' + 'UNION SELECT "a"."c3" FROM "t2" AS "a")') self.assertEqual(in_.params, ()) in_ = In(self.table.c1, array('l', list(range(10)))) self.assertEqual(str(in_), - '("c1" IN (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s))') + '"c1" IN (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)') self.assertEqual(in_.params, tuple(range(10))) def test_exists(self): exists = Exists(self.table.select(self.table.c1, where=self.table.c1 == 1)) self.assertEqual(str(exists), - '(EXISTS (SELECT "a"."c1" FROM "t" AS "a" ' - 'WHERE ("a"."c1" = %s)))') + 'EXISTS (SELECT "a"."c1" FROM "t" AS "a" WHERE "a"."c1" = %s)') self.assertEqual(exists.params, (1,)) def test_floordiv(self): @@ -400,3 +418,21 @@ def test_floordiv(self): self.assertIn( 'FloorDiv operator is deprecated, use Div function', str(w[-1].message)) + + def test_any(self): + any_ = Any(self.table.select(self.table.c1, where=self.table.c2 == 1)) + self.assertEqual(str(any_), + 'ANY (SELECT "a"."c1" FROM "t" AS "a" WHERE "a"."c2" = %s)') + self.assertEqual(any_.params, (1,)) + + for value in [[1, 2, 3], (1, 2, 3), array('l', [1, 2, 3])]: + with self.subTest(value=value): + any_ = Any(value) + self.assertEqual(str(any_), 'ANY (%s)') + self.assertEqual(any_.params, ([1, 2, 3],)) + + def test_binary_unary(self): + operator = Equal(self.table.c1, Any([1, 2, 3])) + + self.assertEqual(str(operator), '"c1" = ANY (%s)') + self.assertEqual(operator.params, ([1, 2, 3],)) diff --git a/sql/tests/test_order.py b/sql/tests/test_order.py index f012d92..466950e 100644 --- a/sql/tests/test_order.py +++ b/sql/tests/test_order.py @@ -3,45 +3,50 @@ import unittest from sql import ( - Asc, Column, Desc, Flavor, Literal, NullsFirst, NullsLast, Table) + Asc, Column, Desc, Flavor, Literal, NullOrder, NullsFirst, NullsLast, + Order, Table) class TestOrder(unittest.TestCase): column = Column(Table('t'), 'c') def test_asc(self): - self.assertEqual(str(Asc(self.column)), '"c" ASC') + self.assertEqual(str(self.column.asc), '"c" ASC') def test_desc(self): - self.assertEqual(str(Desc(self.column)), '"c" DESC') + self.assertEqual(str(self.column.desc), '"c" DESC') def test_nulls_first(self): - self.assertEqual(str(NullsFirst(self.column)), '"c" NULLS FIRST') - self.assertEqual(str(NullsFirst(Asc(self.column))), + self.assertEqual(str(self.column.nulls_first), '"c" NULLS FIRST') + self.assertEqual(str(Asc(self.column).nulls_first), '"c" ASC NULLS FIRST') def test_nulls_last(self): - self.assertEqual(str(NullsLast(self.column)), '"c" NULLS LAST') - self.assertEqual(str(NullsLast(Asc(self.column))), + self.assertEqual(str(self.column.nulls_last), '"c" NULLS LAST') + self.assertEqual(str(Asc(self.column).nulls_last), '"c" ASC NULLS LAST') + def test_null_order_case_values(self): + with self.assertRaises(NotImplementedError): + NullOrder(self.column)._case_values() + def test_no_null_ordering(self): try: Flavor.set(Flavor(null_ordering=False)) exp = NullsFirst(self.column) self.assertEqual(str(exp), - 'CASE WHEN ("c" IS NULL) THEN %s ELSE %s END ASC, "c"') + 'CASE WHEN "c" IS NULL THEN %s ELSE %s END ASC, "c"') self.assertEqual(exp.params, (0, 1)) exp = NullsFirst(Desc(self.column)) self.assertEqual(str(exp), - 'CASE WHEN ("c" IS NULL) THEN %s ELSE %s END ASC, "c" DESC') + 'CASE WHEN "c" IS NULL THEN %s ELSE %s END ASC, "c" DESC') self.assertEqual(exp.params, (0, 1)) exp = NullsLast(Literal(2)) self.assertEqual(str(exp), - 'CASE WHEN (%s IS NULL) THEN %s ELSE %s END ASC, %s') + 'CASE WHEN %s IS NULL THEN %s ELSE %s END ASC, %s') self.assertEqual(exp.params, (2, 1, 0, 2)) finally: Flavor.set(Flavor()) @@ -54,3 +59,7 @@ def test_order_query(self): '(SELECT "a"."c" FROM "t" AS "a") ASC') self.assertEqual(str(Desc(query)), '(SELECT "a"."c" FROM "t" AS "a") DESC') + + def test_invalid_expression(self): + with self.assertRaises(ValueError): + Order('foo') diff --git a/sql/tests/test_rollup.py b/sql/tests/test_rollup.py new file mode 100644 index 0000000..8b29227 --- /dev/null +++ b/sql/tests/test_rollup.py @@ -0,0 +1,13 @@ +# This file is part of python-sql. The COPYRIGHT file at the top level of +# this repository contains the full copyright notices and license terms. + +import unittest + +from sql import Rollup + + +class TestRollup(unittest.TestCase): + + def test_invalid_expressions(self): + with self.assertRaises(ValueError): + Rollup('foo') diff --git a/sql/tests/test_select.py b/sql/tests/test_select.py index a41332a..232ea51 100644 --- a/sql/tests/test_select.py +++ b/sql/tests/test_select.py @@ -4,7 +4,9 @@ import warnings from copy import deepcopy -from sql import Flavor, For, Join, Literal, Select, Table, Union, Window, With +from sql import ( + Cube, Flavor, For, Grouping, Join, Literal, Rollup, Select, Table, Union, + Window, With) from sql.aggregate import Max, Min from sql.functions import DatePart, Function, Now, Rank @@ -29,9 +31,15 @@ def test_select2(self): def test_select3(self): query = self.table.select(where=(self.table.c == 'foo')) self.assertEqual(str(query), - 'SELECT * FROM "t" AS "a" WHERE ("a"."c" = %s)') + 'SELECT * FROM "t" AS "a" WHERE "a"."c" = %s') self.assertEqual(tuple(query.params), ('foo',)) + def test_select_iter(self): + query = self.table.select() + self.assertEqual( + tuple(query), + ('SELECT * FROM "t" AS "a"', ())) + def test_select_without_from(self): query = Select([Literal(1)]) self.assertEqual(str(query), 'SELECT %s') @@ -47,6 +55,14 @@ def test_select_select_as(self): self.assertEqual(str(query), 'SELECT (SELECT %s) AS "foo"') self.assertEqual(tuple(query.params), (1,)) + def test_select_invalid_column(self): + with self.assertRaises(ValueError): + Select(['foo']) + + def test_select_invalid_where(self): + with self.assertRaises(ValueError): + self.table.select(where='foo') + def test_select_distinct(self): query = self.table.select(self.table.c, distinct=True) self.assertEqual( @@ -66,6 +82,10 @@ def test_select_distinct_on(self): 'SELECT DISTINCT ON ("a"."a", "a"."b") "a"."c" FROM "t" AS "a"') self.assertEqual(tuple(query.params), ()) + def test_select_invalid_distinct_on(self): + with self.assertRaises(ValueError): + self.table.select(self.table.c, distinct_on='foo') + def test_select_from_list(self): t2 = Table('t2') t3 = Table('t3') @@ -90,7 +110,7 @@ def test_select_union(self): 'SELECT * FROM "t2" AS "c") AS "a"') query1.where = self.table.c == 'foo' self.assertEqual(str(union), - 'SELECT * FROM "t" AS "a" WHERE ("a"."c" = %s) UNION ALL ' + 'SELECT * FROM "t" AS "a" WHERE "a"."c" = %s UNION ALL ' 'SELECT * FROM "t2" AS "b"') self.assertEqual(tuple(union.params), ('foo',)) @@ -177,7 +197,7 @@ def test_select_group_by(self): output = column.as_('c1') query = self.table.select(output, group_by=output) self.assertEqual(str(query), - 'SELECT "a"."c" AS "c1" FROM "t" AS "a" GROUP BY "c1"') + 'SELECT "a"."c" AS "c1" FROM "t" AS "a" GROUP BY 1') self.assertEqual(tuple(query.params), ()) query = self.table.select(Literal('foo'), group_by=Literal('foo')) @@ -185,6 +205,73 @@ def test_select_group_by(self): 'SELECT %s FROM "t" AS "a" GROUP BY %s') self.assertEqual(tuple(query.params), ('foo', 'foo')) + output1 = column.as_('c1') + output2 = column.as_('c2') + query = self.table.select(output1, output2, group_by=output2) + self.assertEqual(str(query), + 'SELECT "a"."c" AS "c1", "a"."c" AS "c2" FROM "t" AS "a" ' + 'GROUP BY 2') + self.assertEqual(tuple(query.params), ()) + + query = self.table.select(column, group_by=output) + self.assertEqual(str(query), + 'SELECT "a"."c" FROM "t" AS "a" GROUP BY "c1"') + self.assertEqual(tuple(query.params), ()) + + def test_select_group_by_grouping_sets(self): + query = self.table.select( + Literal('*'), + group_by=Grouping((self.table.a, self.table.b), (Literal('foo'),))) + self.assertEqual(str(query), + 'SELECT %s FROM "t" AS "a" ' + 'GROUP BY GROUPING SETS (("a"."a", "a"."b"), (%s))') + self.assertEqual(tuple(query.params), ('*', 'foo',)) + + query = self.table.select( + Literal('*'), + group_by=[ + self.table.a, Grouping((self.table.b,), (self.table.c,))]) + self.assertEqual(str(query), + 'SELECT %s FROM "t" AS "a" ' + 'GROUP BY "a"."a", GROUPING SETS (("a"."b"), ("a"."c"))') + self.assertEqual(tuple(query.params), ('*',)) + + def test_select_group_by_rollup(self): + query = self.table.select( + Literal('*'), + group_by=Rollup(self.table.a, self.table.b, Literal('foo'))) + self.assertEqual(str(query), + 'SELECT %s FROM "t" AS "a" ' + 'GROUP BY ROLLUP ("a"."a", "a"."b", %s)') + self.assertEqual(tuple(query.params), ('*', 'foo')) + + query = self.table.select( + Literal('*'), + group_by=Rollup((self.table.a, self.table.b), self.table.c)) + self.assertEqual(str(query), + 'SELECT %s FROM "t" AS "a" ' + 'GROUP BY ROLLUP (("a"."a", "a"."b"), "a"."c")') + self.assertEqual(tuple(query.params), ('*',)) + + def test_select_group_by_cube(self): + query = self.table.select( + Literal('*'), + group_by=Cube(self.table.a, self.table.b)) + self.assertEqual(str(query), + 'SELECT %s FROM "t" AS "a" ' + 'GROUP BY CUBE ("a"."a", "a"."b")') + self.assertEqual(tuple(query.params), ('*',)) + + def test_select_invalid_group_by(self): + with self.assertRaises(ValueError): + self.table.select(group_by=['foo']) + + def test_select_invalid_group_by_alias(self): + query = self.table.select( + self.table.c1.as_('c'), group_by=self.table.c2.as_('c')) + with self.assertRaises(ValueError): + str(query) + def test_select_having(self): col1 = self.table.col1 col2 = self.table.col2 @@ -192,28 +279,48 @@ def test_select_having(self): having=(Min(col2) > 3)) self.assertEqual(str(query), 'SELECT "a"."col1", MIN("a"."col2") FROM "t" AS "a" ' - 'HAVING (MIN("a"."col2") > %s)') + 'HAVING MIN("a"."col2") > %s') self.assertEqual(tuple(query.params), (3,)) + def test_select_invalid_having(self): + with self.assertRaises(ValueError): + self.table.select(having='foo') + def test_select_order(self): - c = self.table.c - query = self.table.select(c, order_by=Literal(1)) + column = self.table.c + query = self.table.select(column, order_by=column) self.assertEqual(str(query), - 'SELECT "a"."c" FROM "t" AS "a" ORDER BY %s') - self.assertEqual(tuple(query.params), (1,)) + 'SELECT "a"."c" FROM "t" AS "a" ORDER BY "a"."c"') + self.assertEqual(tuple(query.params), ()) + + output = column.as_('c1') + query = self.table.select(output, order_by=output) + self.assertEqual(str(query), + 'SELECT "a"."c" AS "c1" FROM "t" AS "a" ORDER BY "c1"') + self.assertEqual(tuple(query.params), ()) + + def test_select_invalid_order(self): + with self.assertRaises(ValueError): + self.table.select(order_by='foo') + + def test_select_invalid_order_alias(self): + query = self.table.select( + self.table.c1.as_('c'), order_by=self.table.c2.as_('c')) + with self.assertRaises(ValueError): + str(query) def test_select_limit_offset(self): try: Flavor.set(Flavor(limitstyle='limit')) query = self.table.select(limit=50, offset=10) self.assertEqual(str(query), - 'SELECT * FROM "t" AS "a" LIMIT 50 OFFSET 10') - self.assertEqual(tuple(query.params), ()) + 'SELECT * FROM "t" AS "a" LIMIT %s OFFSET %s') + self.assertEqual(tuple(query.params), (50, 10)) query.limit = None self.assertEqual(str(query), - 'SELECT * FROM "t" AS "a" OFFSET 10') - self.assertEqual(tuple(query.params), ()) + 'SELECT * FROM "t" AS "a" OFFSET %s') + self.assertEqual(tuple(query.params), (10,)) query.offset = 0 self.assertEqual(str(query), @@ -234,24 +341,32 @@ def test_select_limit_offset(self): query.offset = 10 self.assertEqual(str(query), - 'SELECT * FROM "t" AS "a" LIMIT -1 OFFSET 10') - self.assertEqual(tuple(query.params), ()) + 'SELECT * FROM "t" AS "a" LIMIT -1 OFFSET %s') + self.assertEqual(tuple(query.params), (10,)) finally: Flavor.set(Flavor()) + def test_select_invalid_limit(self): + with self.assertRaises(ValueError): + self.table.select(limit='foo') + + def test_select_invalid_offset(self): + with self.assertRaises(ValueError): + self.table.select(offset='foo') + def test_select_offset_fetch(self): try: Flavor.set(Flavor(limitstyle='fetch')) query = self.table.select(limit=50, offset=10) self.assertEqual(str(query), 'SELECT * FROM "t" AS "a" ' - 'OFFSET (10) ROWS FETCH FIRST (50) ROWS ONLY') - self.assertEqual(tuple(query.params), ()) + 'OFFSET (%s) ROWS FETCH FIRST (%s) ROWS ONLY') + self.assertEqual(tuple(query.params), (10, 50)) query.limit = None self.assertEqual(str(query), - 'SELECT * FROM "t" AS "a" OFFSET (10) ROWS') - self.assertEqual(tuple(query.params), ()) + 'SELECT * FROM "t" AS "a" OFFSET (%s) ROWS') + self.assertEqual(tuple(query.params), (10,)) query.offset = 0 self.assertEqual(str(query), @@ -268,8 +383,8 @@ def test_select_rownum(self): 'SELECT "a".* FROM (' 'SELECT "b".*, ROWNUM AS "rnum" FROM (' 'SELECT * FROM "t" AS "c") AS "b" ' - 'WHERE (ROWNUM <= %s)) AS "a" ' - 'WHERE ("rnum" > %s)') + 'WHERE ROWNUM <= %s) AS "a" ' + 'WHERE "rnum" > %s') self.assertEqual(tuple(query.params), (60, 10)) query = self.table.select( @@ -280,8 +395,8 @@ def test_select_rownum(self): 'SELECT "b"."col1", "b"."col2", ROWNUM AS "rnum" FROM (' 'SELECT "c"."c1" AS "col1", "c"."c2" AS "col2" ' 'FROM "t" AS "c") AS "b" ' - 'WHERE (ROWNUM <= %s)) AS "a" ' - 'WHERE ("rnum" > %s)') + 'WHERE ROWNUM <= %s) AS "a" ' + 'WHERE "rnum" > %s') self.assertEqual(tuple(query.params), (60, 10)) subquery = query.select(query.col1, query.col2) @@ -292,8 +407,8 @@ def test_select_rownum(self): 'FROM (' 'SELECT "c"."c1" AS "col1", "c"."c2" AS "col2" ' 'FROM "t" AS "c") AS "a" ' - 'WHERE (ROWNUM <= %s)) AS "b" ' - 'WHERE ("rnum" > %s)) AS "a"') + 'WHERE ROWNUM <= %s) AS "b" ' + 'WHERE "rnum" > %s) AS "a"') # XXX alias of query is reused but not a problem # as it is hidden in subquery self.assertEqual(tuple(query.params), (60, 10)) @@ -304,15 +419,15 @@ def test_select_rownum(self): 'SELECT "a".* FROM (' 'SELECT "b".*, ROWNUM AS "rnum" FROM (' 'SELECT * FROM "t" AS "c" ORDER BY "c"."c") AS "b" ' - 'WHERE (ROWNUM <= %s)) AS "a" ' - 'WHERE ("rnum" > %s)') + 'WHERE ROWNUM <= %s) AS "a" ' + 'WHERE "rnum" > %s') self.assertEqual(tuple(query.params), (60, 10)) query = self.table.select(limit=50) self.assertEqual(str(query), 'SELECT "a".* FROM (' 'SELECT * FROM "t" AS "b") AS "a" ' - 'WHERE (ROWNUM <= %s)') + 'WHERE ROWNUM <= %s') self.assertEqual(tuple(query.params), (50,)) query = self.table.select(offset=10) @@ -320,7 +435,7 @@ def test_select_rownum(self): 'SELECT "a".* FROM (' 'SELECT "b".*, ROWNUM AS "rnum" FROM (' 'SELECT * FROM "t" AS "c") AS "b") AS "a" ' - 'WHERE ("rnum" > %s)') + 'WHERE "rnum" > %s') self.assertEqual(tuple(query.params), (10,)) query = self.table.select(self.table.c.as_('col'), @@ -330,9 +445,9 @@ def test_select_rownum(self): 'SELECT "a"."col" FROM (' 'SELECT "b"."col", ROWNUM AS "rnum" FROM (' 'SELECT "c"."c" AS "col" FROM "t" AS "c" ' - 'WHERE ("c"."c" >= %s)) AS "b" ' - 'WHERE (ROWNUM <= %s)) AS "a" ' - 'WHERE ("rnum" > %s)') + 'WHERE "c"."c" >= %s) AS "b" ' + 'WHERE ROWNUM <= %s) AS "a" ' + 'WHERE "rnum" > %s') self.assertEqual(tuple(query.params), (20, 60, 10)) finally: Flavor.set(Flavor()) @@ -344,6 +459,10 @@ def test_select_for(self): 'SELECT "a"."c" FROM "t" AS "a" FOR UPDATE') self.assertEqual(tuple(query.params), ()) + def test_select_invalid_for(self): + with self.assertRaises(ValueError): + self.table.select(for_=['foo']) + def test_copy(self): query = self.table.select() copy_query = deepcopy(query) @@ -380,7 +499,7 @@ def test_window(self): Rank(filter_=self.table.c1 > 0, window=window), Min(self.table.c1, window=window)) self.assertEqual(str(query), - 'SELECT RANK() FILTER (WHERE ("a"."c1" > %s)) OVER "b", ' + 'SELECT RANK() FILTER (WHERE "a"."c1" > %s) OVER "b", ' 'MIN("a"."c1") OVER "b" FROM "t" AS "a" ' 'WINDOW "b" AS (PARTITION BY "a"."c1")') self.assertEqual(tuple(query.params), (0,)) @@ -398,8 +517,8 @@ def test_window(self): Max(self.table.c1, window=window) / Min(self.table.c1, window=window)) self.assertEqual(str(query), - 'SELECT (MAX("a"."c1") OVER (PARTITION BY "a"."c2") ' - '/ MIN("a"."c1") OVER (PARTITION BY "a"."c2")) ' + 'SELECT MAX("a"."c1") OVER (PARTITION BY "a"."c2") ' + '/ MIN("a"."c1") OVER (PARTITION BY "a"."c2") ' 'FROM "t" AS "a"') self.assertEqual(tuple(query.params), ()) @@ -408,8 +527,8 @@ def test_window(self): Max(self.table.c1, window=window) / Min(self.table.c1, window=window)) self.assertEqual(str(query), - 'SELECT (MAX("a"."c1") OVER (PARTITION BY %s) ' - '/ MIN("a"."c1") OVER (PARTITION BY %s)) ' + 'SELECT MAX("a"."c1") OVER (PARTITION BY %s) ' + '/ MIN("a"."c1") OVER (PARTITION BY %s) ' 'FROM "t" AS "a"') self.assertEqual(tuple(query.params), (1, 1)) @@ -420,23 +539,47 @@ def test_window(self): / Min(self.table.c1, window=window2), windows=[window1]) self.assertEqual(str(query), - 'SELECT (MAX("a"."c1") OVER "b" ' - '/ MIN("a"."c1") OVER (PARTITION BY %s)) ' + 'SELECT MAX("a"."c1") OVER "b" ' + '/ MIN("a"."c1") OVER (PARTITION BY %s) ' 'FROM "t" AS "a" ' 'WINDOW "b" AS (PARTITION BY "a"."c2")') self.assertEqual(tuple(query.params), (1,)) + def test_window_with_alias(self): + query = self.table.select( + Min(self.table.c1, window=Window([self.table.c2])).as_('min')) + + self.assertEqual( + str(query), + 'SELECT MIN("a"."c1") OVER "b" AS "min" FROM "t" AS "a" ' + 'WINDOW "b" AS (PARTITION BY "a"."c2")') + self.assertEqual(query.params, ()) + + def test_select_invalid_window(self): + with self.assertRaises(ValueError): + self.table.select(windows=['foo']) + def test_order_params(self): with_ = With(query=self.table.select(self.table.c, where=(self.table.c > 1))) - w = Window([Literal(8)]) + w = Window([Literal(7)]) query = Select([Literal(2), Min(self.table.c, window=w)], from_=self.table.select(where=self.table.c > 3), with_=with_, where=self.table.c > 4, group_by=[Literal(5)], - order_by=[Literal(6)], - having=Literal(7)) + having=Literal(6), + order_by=[Literal(8)]) + self.assertEqual( + str(query), + 'WITH "c" AS (SELECT "a"."c" FROM "t" AS "a" WHERE "a"."c" > %s)' + ' SELECT %s, MIN("a"."c") OVER "b" ' + 'FROM SELECT * FROM "t" AS "a" WHERE "a"."c" > %s ' + 'WHERE "a"."c" > %s ' + 'GROUP BY %s ' + 'HAVING %s ' + 'WINDOW "b" AS (PARTITION BY %s) ' + 'ORDER BY %s') self.assertEqual(tuple(query.params), (1, 2, 3, 4, 5, 6, 7, 8)) def test_no_as(self): diff --git a/sql/tests/test_update.py b/sql/tests/test_update.py index 89f03f1..713f080 100644 --- a/sql/tests/test_update.py +++ b/sql/tests/test_update.py @@ -15,7 +15,7 @@ def test_update1(self): query.where = (self.table.b == Literal(True)) self.assertEqual(str(query), - 'UPDATE "t" AS "a" SET "c" = %s WHERE ("a"."b" = %s)') + 'UPDATE "t" AS "a" SET "c" = %s WHERE "a"."b" = %s') self.assertEqual(query.params, ('foo', True)) def test_update2(self): @@ -24,9 +24,17 @@ def test_update2(self): query = t1.update([t1.c], ['foo'], from_=[t2], where=(t1.c == t2.c)) self.assertEqual(str(query), 'UPDATE "t1" AS "b" SET "c" = %s FROM "t2" AS "a" ' - 'WHERE ("b"."c" = "a"."c")') + 'WHERE "b"."c" = "a"."c"') self.assertEqual(query.params, ('foo',)) + def test_update_invalid_values(self): + with self.assertRaises(ValueError): + self.table.update([self.table.c], 'foo') + + def test_update_invalid_where(self): + with self.assertRaises(ValueError): + self.table.update([self.table.c], ['foo'], where='foo') + def test_update_subselect(self): t1 = Table('t1') t2 = Table('t2') @@ -35,7 +43,7 @@ def test_update_subselect(self): for query in [query_list, query_nolist]: self.assertEqual(str(query), 'UPDATE "t1" AS "b" SET "c" = (' - 'SELECT "a"."c" FROM "t2" AS "a" WHERE ("a"."i" = "b"."i"))') + 'SELECT "a"."c" FROM "t2" AS "a" WHERE "a"."i" = "b"."i")') self.assertEqual(query.params, ()) def test_update_returning(self): @@ -54,7 +62,7 @@ def test_update_returning_select(self): self.assertEqual(str(query), 'UPDATE "t1" AS "b" SET "c" = %s ' 'RETURNING (SELECT "a"."c" FROM "t2" AS "a" ' - 'WHERE (("a"."c1" = "b"."c") AND ("a"."c2" = %s)))') + 'WHERE ("a"."c1" = "b"."c") AND ("a"."c2" = %s))') self.assertEqual(query.params, ('foo', 'bar')) def test_with(self): @@ -68,7 +76,7 @@ def test_with(self): self.assertEqual(str(query), 'WITH "a" AS (SELECT "b"."c1" FROM "t1" AS "b") ' 'UPDATE "t" AS "c" SET "c2" = (SELECT "a"."c3" FROM "a" AS "a" ' - 'WHERE ("a"."c4" = %s))') + 'WHERE "a"."c4" = %s)') self.assertEqual(query.params, (2,)) def test_schema(self): @@ -87,5 +95,5 @@ def test_schema_subselect(self): self.assertEqual(str(query), 'UPDATE "default"."t1" AS "b" SET "c1" = (' 'SELECT "a"."c" FROM "default"."t2" AS "a" ' - 'WHERE ("a"."i" = "b"."i"))') + 'WHERE "a"."i" = "b"."i")') self.assertEqual(query.params, ()) diff --git a/sql/tests/test_window.py b/sql/tests/test_window.py index 6fd1a1b..ae51835 100644 --- a/sql/tests/test_window.py +++ b/sql/tests/test_window.py @@ -14,6 +14,10 @@ def test_window(self): self.assertEqual(str(window), 'PARTITION BY "c1", "c2"') self.assertEqual(window.params, ()) + def test_window_invalid_partition(self): + with self.assertRaises(ValueError): + Window(['foo']) + def test_window_order(self): t = Table('t') window = Window([t.c], order_by=t.c) @@ -21,6 +25,10 @@ def test_window_order(self): self.assertEqual(str(window), 'PARTITION BY "c" ORDER BY "c"') self.assertEqual(window.params, ()) + def test_window_invalid_order(self): + with self.assertRaises(ValueError): + Window([Table('t').c], order_by='foo') + def test_window_range(self): t = Table('t') window = Window([t.c], frame='RANGE') @@ -33,22 +41,34 @@ def test_window_range(self): window.start = -1 self.assertEqual(str(window), 'PARTITION BY "c" RANGE ' - 'BETWEEN 1 PRECEDING AND CURRENT ROW') - self.assertEqual(window.params, ()) + 'BETWEEN %s PRECEDING AND CURRENT ROW') + self.assertEqual(window.params, (1,)) window.start = 0 window.end = 1 self.assertEqual(str(window), 'PARTITION BY "c" RANGE ' - 'BETWEEN CURRENT ROW AND 1 FOLLOWING') - self.assertEqual(window.params, ()) + 'BETWEEN CURRENT ROW AND %s FOLLOWING') + self.assertEqual(window.params, (1,)) window.start = 1 window.end = None self.assertEqual(str(window), 'PARTITION BY "c" RANGE ' - 'BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLOWING') - self.assertEqual(window.params, ()) + 'BETWEEN %s FOLLOWING AND UNBOUNDED FOLLOWING') + self.assertEqual(window.params, (1,)) + + def test_window_invalid_frame(self): + with self.assertRaises(ValueError): + Window([Table('t').c], frame='foo') + + def test_window_invalid_start(self): + with self.assertRaises(ValueError): + Window([Table('t').c], start='foo') + + def test_window_invalid_end(self): + with self.assertRaises(ValueError): + Window([Table('t').c], end='foo') def test_window_exclude(self): t = Table('t') @@ -58,6 +78,10 @@ def test_window_exclude(self): 'PARTITION BY "c" EXCLUDE TIES') self.assertEqual(window.params, ()) + def test_window_invalid_exclude(self): + with self.assertRaises(ValueError): + Window([Table('t').c], exclude='foo') + def test_window_rows(self): t = Table('t') window = Window([t.c], frame='ROWS') diff --git a/sql/tests/test_with.py b/sql/tests/test_with.py index f608b1c..72abf0a 100644 --- a/sql/tests/test_with.py +++ b/sql/tests/test_with.py @@ -15,7 +15,7 @@ def test_with(self): self.assertEqual(simple.statement(), '"a" AS (' - 'SELECT "b"."id" FROM "t" AS "b" WHERE ("b"."id" = %s)' + 'SELECT "b"."id" FROM "t" AS "b" WHERE "b"."id" = %s' ')') self.assertEqual(simple.statement_params(), (1,)) @@ -40,7 +40,7 @@ def test_with_query(self): wq = WithQuery(with_=[simple, second]) self.assertEqual(wq._with_str(), 'WITH "a" AS (' - 'SELECT "b"."id" FROM "t" AS "b" WHERE ("b"."id" = %s)' + 'SELECT "b"."id" FROM "t" AS "b" WHERE "b"."id" = %s' '), "c" AS (' 'SELECT * FROM "a" AS "a"' ') ') @@ -59,6 +59,10 @@ def test_recursive(self): 'WITH RECURSIVE "a" ("n") AS (' 'VALUES (%s) ' 'UNION ALL ' - 'SELECT ("a"."n" + %s) FROM "a" AS "a" WHERE ("a"."n" < %s)' + 'SELECT "a"."n" + %s FROM "a" AS "a" WHERE "a"."n" < %s' ') SELECT * FROM "a" AS "a"') self.assertEqual(tuple(q.params), (1, 1, 100)) + + def test_invalid_with(self): + with self.assertRaises(ValueError): + WithQuery(with_=['foo']) diff --git a/tox.ini b/tox.ini index 917c277..553e159 100644 --- a/tox.ini +++ b/tox.ini @@ -1,15 +1,17 @@ -# Tox (http://tox.testrun.org/) is a tool for running tests -# in multiple virtualenvs. This configuration file will run the -# test suite on all supported python versions. To use it, "pip install tox" -# and then run "tox" from this directory. - [tox] -envlist = py35, py36, py37, py38, py39, py310, pypy3 +envlist = py39, py310, py311, py312, py313, py314, pypy3 [testenv] +changedir = {env_site_packages_dir} commands = - coverage run -m unittest discover -s sql.tests - coverage report --include=./sql/* --omit=*/tests/* + coverage run --rcfile={toxinidir}/tox.ini --source=sql --omit=*/tests/* -m xmlrunner discover -s sql.tests {posargs} +commands_post = + coverage report --rcfile={toxinidir}/tox.ini --omit=README.rst + coverage xml --rcfile={toxinidir}/tox.ini --omit=README.rst -o {package_root}/coverage.xml deps = coverage + unittest-xml-reporting passenv = * + +[coverage:run] +relative_files = true