# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for DataFusion expression APIs: logical-plan variants, the case
builder, expression helper coercion, and scalar literal round-trips."""

import re
from concurrent.futures import ThreadPoolExecutor
from datetime import date, datetime, time, timezone
from decimal import Decimal

import arro3.core
import nanoarrow
import pyarrow as pa
import pytest
from datafusion import (
    SessionContext,
    col,
    functions,
    lit,
    lit_with_metadata,
    literal_with_metadata,
)
from datafusion.expr import (
    EXPR_TYPE_ERROR,
    Aggregate,
    AggregateFunction,
    BinaryExpr,
    Column,
    CopyTo,
    CreateIndex,
    DescribeTable,
    DmlStatement,
    DropCatalogSchema,
    Filter,
    Limit,
    Literal,
    Projection,
    RecursiveQuery,
    Sort,
    TableScan,
    TransactionEnd,
    TransactionStart,
    Values,
    coerce_to_expr,
    coerce_to_expr_or_none,
    ensure_expr,
    ensure_expr_list,
)


@pytest.fixture
def test_ctx():
    """SessionContext with the standard 100-row aggregate CSV registered."""
    ctx = SessionContext()
    ctx.register_csv("test", "testing/data/csv/aggregate_test_100.csv")
    return ctx


def test_projection(test_ctx):
    df = test_ctx.sql("select c1, 123, c1 < 123 from test")
    plan = df.logical_plan()

    plan = plan.to_variant()
    assert isinstance(plan, Projection)

    expr = plan.projections()

    col1 = expr[0].to_variant()
    assert isinstance(col1, Column)
    assert col1.name() == "c1"
    assert col1.qualified_name() == "test.c1"

    col2 = expr[1].to_variant()
    assert isinstance(col2, Literal)
    assert col2.data_type() == "Int64"
    assert col2.value_i64() == 123

    col3 = expr[2].to_variant()
    assert isinstance(col3, BinaryExpr)
    assert isinstance(col3.left().to_variant(), Column)
    assert col3.op() == "<"
    assert isinstance(col3.right().to_variant(), Literal)

    plan = plan.input()[0].to_variant()
    assert isinstance(plan, TableScan)


def test_filter(test_ctx):
    df = test_ctx.sql("select c1 from test WHERE c1 > 5")
    plan = df.logical_plan()

    plan = plan.to_variant()
    assert isinstance(plan, Projection)

    plan = plan.input()[0].to_variant()
    assert isinstance(plan, Filter)


def test_limit(test_ctx):
    df = test_ctx.sql("select c1 from test LIMIT 10")
    plan = df.logical_plan()

    plan = plan.to_variant()
    assert isinstance(plan, Limit)
    assert "Skip: None" in str(plan)

    df = test_ctx.sql("select c1 from test LIMIT 10 OFFSET 5")
    plan = df.logical_plan()

    plan = plan.to_variant()
    assert isinstance(plan, Limit)
    assert "Skip: Some(Literal(Int64(5), None))" in str(plan)


def test_aggregate_query(test_ctx):
    df = test_ctx.sql("select c1, count(*) from test group by c1")
    plan = df.logical_plan()

    projection = plan.to_variant()
    assert isinstance(projection, Projection)

    aggregate = projection.input()[0].to_variant()
    assert isinstance(aggregate, Aggregate)

    col1 = aggregate.group_by_exprs()[0].to_variant()
    assert isinstance(col1, Column)
    assert col1.name() == "c1"
    assert col1.qualified_name() == "test.c1"

    col2 = aggregate.aggregate_exprs()[0].to_variant()
    assert isinstance(col2, AggregateFunction)


def test_sort(test_ctx):
    df = test_ctx.sql("select c1 from test order by c1")
    plan = df.logical_plan()

    plan = plan.to_variant()
    assert isinstance(plan, Sort)


def test_relational_expr(test_ctx):
    ctx = SessionContext()

    batch = pa.RecordBatch.from_arrays(
        [
            pa.array([1, 2, 3, None]),
            pa.array(["alpha", "beta", "gamma", None], type=pa.string_view()),
        ],
        names=["a", "b"],
    )
    df = ctx.create_dataframe([[batch]], name="batch_array")

    assert df.filter(col("a") == 1).count() == 1
    assert df.filter(col("a") != 1).count() == 2
    assert df.filter(col("a") >= 1).count() == 3
    assert df.filter(col("a") > 1).count() == 2
    assert df.filter(col("a") <= 3).count() == 3
    assert df.filter(col("a") < 3).count() == 2

    assert df.filter(col("b") == "beta").count() == 1
    assert df.filter(col("b") != "beta").count() == 2

    assert df.filter(col("a") == "beta").count() == 0

    assert df.filter(col("a") == None).count() == 1  # noqa: E711
    assert df.filter(col("a") != None).count() == 3  # noqa: E711
    assert df.filter(col("b") == None).count() == 1  # noqa: E711
    assert df.filter(col("b") != None).count() == 3  # noqa: E711


def test_expr_to_variant():
    # Taken from https://github.com/apache/datafusion-python/issues/781
    from datafusion import SessionContext
    from datafusion.expr import Filter

    def traverse_logical_plan(plan):
        cur_node = plan.to_variant()
        if isinstance(cur_node, Filter):
            return cur_node.predicate().to_variant()
        if hasattr(plan, "inputs"):
            for input_plan in plan.inputs():
                res = traverse_logical_plan(input_plan)
                if res is not None:
                    return res
        return None

    ctx = SessionContext()
    data = {"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"]}
    ctx.from_pydict(data, name="table1")
    query = "SELECT * FROM table1 t1 WHERE t1.name IN ('dfa', 'ad', 'dfre', 'vsa')"
    logical_plan = ctx.sql(query).optimized_logical_plan()
    variant = traverse_logical_plan(logical_plan)
    assert variant is not None
    assert variant.expr().to_variant().qualified_name() == "table1.name"
    assert (
        str(variant.list())
        == '[Expr(Utf8("dfa")), Expr(Utf8("ad")), Expr(Utf8("dfre")), Expr(Utf8("vsa"))]'  # noqa: E501
    )
    assert not variant.negated()


def test_case_builder_error_preserves_builder_state():
    case_builder = functions.when(lit(True), lit(1))

    with pytest.raises(Exception) as exc_info:
        _ = case_builder.otherwise(lit("bad"))

    err_msg = str(exc_info.value)
    assert "multiple data types" in err_msg
    assert "CaseBuilder has already been consumed" not in err_msg

    # The failed otherwise() must not have consumed the builder.
    _ = case_builder.end()

    err_msg = str(exc_info.value)
    assert "multiple data types" in err_msg
    assert "CaseBuilder has already been consumed" not in err_msg


def test_case_builder_success_preserves_builder_state():
    ctx = SessionContext()
    df = ctx.from_pydict({"flag": [False]}, name="tbl")

    case_builder = functions.when(col("flag"), lit("true"))

    expr_default_one = case_builder.otherwise(lit("default-1")).alias("result")
    result_one = df.select(expr_default_one).collect()
    assert result_one[0].column(0).to_pylist() == ["default-1"]

    expr_default_two = case_builder.otherwise(lit("default-2")).alias("result")
    result_two = df.select(expr_default_two).collect()
    assert result_two[0].column(0).to_pylist() == ["default-2"]

    expr_end_one = case_builder.end().alias("result")
    end_one = df.select(expr_end_one).collect()
    assert end_one[0].column(0).to_pylist() == [None]


def test_case_builder_when_handles_are_independent():
    ctx = SessionContext()
    df = ctx.from_pydict(
        {
            "flag": [True, False, False, False],
            "value": [1, 15, 25, 5],
        },
        name="tbl",
    )

    base_builder = functions.when(col("flag"), lit("flag-true"))

    first_builder = base_builder.when(col("value") > lit(10), lit("gt10"))
    second_builder = base_builder.when(col("value") > lit(20), lit("gt20"))

    first_builder = first_builder.when(lit(True), lit("final-one"))

    expr_first = first_builder.otherwise(lit("fallback-one")).alias("first")
    expr_second = second_builder.otherwise(lit("fallback-two")).alias("second")

    result = df.select(expr_first, expr_second).collect()[0]

    assert result.column(0).to_pylist() == [
        "flag-true",
        "gt10",
        "gt10",
        "final-one",
    ]
    assert result.column(1).to_pylist() == [
        "flag-true",
        "fallback-two",
        "gt20",
        "fallback-two",
    ]


def test_case_builder_when_thread_safe():
    case_builder = functions.when(lit(True), lit(1))

    def build_expr(value: int) -> bool:
        builder = case_builder.when(lit(True), lit(value))
        builder.otherwise(lit(value))
        return True

    with ThreadPoolExecutor(max_workers=8) as executor:
        futures = [executor.submit(build_expr, idx) for idx in range(16)]
        results = [future.result() for future in futures]

    assert all(results)

    # Ensure the shared builder remains usable after concurrent `when` calls.
    follow_up_builder = case_builder.when(lit(True), lit(42))
    assert isinstance(follow_up_builder, type(case_builder))
    follow_up_builder.otherwise(lit(7))


def test_expr_getitem() -> None:
    ctx = SessionContext()
    data = {
        "array_values": [[1, 2, 3], [4, 5], [6], []],
        "struct_values": [
            {"name": "Alice", "age": 15},
            {"name": "Bob", "age": 14},
            {"name": "Charlie", "age": 13},
            {"name": None, "age": 12},
        ],
    }
    df = ctx.from_pydict(data, name="table1")

    names = df.select(col("struct_values")["name"].alias("name")).collect()
    names = [r.as_py() for rs in names for r in rs["name"]]

    array_values = df.select(col("array_values")[1].alias("value")).collect()
    array_values = [r.as_py() for rs in array_values for r in rs["value"]]

    assert names == ["Alice", "Bob", "Charlie", None]
    assert array_values == [2, 5, None, None]


@pytest.fixture
def df():
    """Three-column DataFrame with nulls, for fill_null / alias tests."""
    ctx = SessionContext()

    # create a RecordBatch and a new DataFrame from it
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, None]), pa.array([4, None, 6]), pa.array([None, None, 8])],
        names=["a", "b", "c"],
    )

    return ctx.from_arrow(batch)


def test_fill_null(df):
    df = df.select(
        col("a").fill_null(100).alias("a"),
        col("b").fill_null(25).alias("b"),
        col("c").fill_null(1234).alias("c"),
    )
    df.show()
    result = df.collect()[0]

    assert result.column(0) == pa.array([1, 2, 100])
    assert result.column(1) == pa.array([4, 25, 6])
    assert result.column(2) == pa.array([1234, 1234, 8])


def test_copy_to():
    ctx = SessionContext()
    ctx.sql("CREATE TABLE foo (a int, b int)").collect()
    df = ctx.sql("COPY foo TO bar STORED AS CSV")
    plan = df.logical_plan()
    plan = plan.to_variant()
    assert isinstance(plan, CopyTo)


def test_create_index():
    ctx = SessionContext()
    ctx.sql("CREATE TABLE foo (a int, b int)").collect()
    plan = ctx.sql("create index idx on foo (a)").logical_plan()
    plan = plan.to_variant()
    assert isinstance(plan, CreateIndex)


def test_describe_table():
    ctx = SessionContext()
    ctx.sql("CREATE TABLE foo (a int, b int)").collect()
    plan = ctx.sql("describe foo").logical_plan()
    plan = plan.to_variant()
    assert isinstance(plan, DescribeTable)


def test_dml_statement():
    ctx = SessionContext()
    ctx.sql("CREATE TABLE foo (a int, b int)").collect()
    plan = ctx.sql("insert into foo values (1, 2)").logical_plan()
    plan = plan.to_variant()
    assert isinstance(plan, DmlStatement)


def test_drop_catalog_schema():
    # BUGFIX: previously named `drop_catalog_schema`, so pytest never
    # collected it and DropCatalogSchema went untested.
    ctx = SessionContext()
    plan = ctx.sql("drop schema cat").logical_plan()
    plan = plan.to_variant()
    assert isinstance(plan, DropCatalogSchema)


def test_recursive_query():
    ctx = SessionContext()
    plan = ctx.sql(
        """
        WITH RECURSIVE cte AS (
        SELECT 1 as n
        UNION ALL
        SELECT n + 1 FROM cte WHERE n < 5
        )
        SELECT * FROM cte;
        """
    ).logical_plan()
    plan = plan.inputs()[0].inputs()[0].to_variant()
    assert isinstance(plan, RecursiveQuery)


def test_values():
    ctx = SessionContext()
    plan = ctx.sql("values (1, 'foo'), (2, 'bar')").logical_plan()
    plan = plan.to_variant()
    assert isinstance(plan, Values)


def test_transaction_start():
    ctx = SessionContext()
    plan = ctx.sql("START TRANSACTION").logical_plan()
    plan = plan.to_variant()
    assert isinstance(plan, TransactionStart)


def test_transaction_end():
    ctx = SessionContext()
    plan = ctx.sql("COMMIT").logical_plan()
    plan = plan.to_variant()
    assert isinstance(plan, TransactionEnd)


def test_col_getattr():
    ctx = SessionContext()
    data = {
        "array_values": [[1, 2, 3], [4, 5], [6], []],
        "struct_values": [
            {"name": "Alice", "age": 15},
            {"name": "Bob", "age": 14},
            {"name": "Charlie", "age": 13},
            {"name": None, "age": 12},
        ],
    }
    df = ctx.from_pydict(data, name="table1")

    names = df.select(col.struct_values["name"].alias("name")).collect()
    names = [r.as_py() for rs in names for r in rs["name"]]

    array_values = df.select(col.array_values[1].alias("value")).collect()
    array_values = [r.as_py() for rs in array_values for r in rs["value"]]

    assert names == ["Alice", "Bob", "Charlie", None]
    assert array_values == [2, 5, None, None]


def test_alias_with_metadata(df):
    df = df.select(col("a").alias("b", {"key": "value"}))
    assert df.schema().field("b").metadata == {b"key": b"value"}


# These unit tests are to ensure the expression functions do not regress
# For the math functions we will use `functions.round` so we can more
# easily test for equivalence and not worry about floating point precision
@pytest.mark.parametrize(
    ("function", "expected_result"),
    [
        # Math Functions
        pytest.param(
            functions.round(col("a").asin(), lit(4)),
            pa.array([-0.8481, 0.5236, 0.0, None], type=pa.float64()),
            id="asin",
        ),
        pytest.param(
            functions.round(col("a").sin(), lit(4)),
            pa.array([-0.6816, 0.4794, 0.0, None], type=pa.float64()),
            id="sin",
        ),
        pytest.param(
            # Since log10 of negative returns NaN and you can't test NaN for
            # equivalence, also do an abs() here.
            functions.round(col("a").abs().log10(), lit(4)),
            pa.array([-0.1249, -0.301, -float("inf"), None], type=pa.float64()),
            id="log10",
        ),
        pytest.param(
            col("a").iszero(),
            pa.array([False, False, True, None], type=pa.bool_()),
            id="iszero",
        ),
        pytest.param(
            functions.round(col("a").acos(), lit(4)),
            pa.array([2.4189, 1.0472, 1.5708, None], type=pa.float64()),
            id="acos",
        ),
        pytest.param(
            col("e").isnan(),
            pa.array([False, True, False, None], type=pa.bool_()),
            id="isnan",
        ),
        pytest.param(
            functions.round(col("a").degrees(), lit(4)),
            pa.array([-42.9718, 28.6479, 0.0, None], type=pa.float64()),
            id="degrees",
        ),
        pytest.param(
            functions.round(col("a").asinh(), lit(4)),
            pa.array([-0.6931, 0.4812, 0.0, None], type=pa.float64()),
            id="asinh",
        ),
        pytest.param(
            col("a").abs(),
            pa.array([0.75, 0.5, 0.0, None], type=pa.float64()),
            id="abs",
        ),
        pytest.param(
            functions.round(col("a").exp(), lit(4)),
            pa.array([0.4724, 1.6487, 1.0, None], type=pa.float64()),
            id="exp",
        ),
        pytest.param(
            functions.round(col("a").cosh(), lit(4)),
            pa.array([1.2947, 1.1276, 1.0, None], type=pa.float64()),
            id="cosh",
        ),
        pytest.param(
            functions.round(col("a").radians(), lit(4)),
            pa.array([-0.0131, 0.0087, 0.0, None], type=pa.float64()),
            id="radians",
        ),
        pytest.param(
            functions.round(col("a").abs().sqrt(), lit(4)),
            pa.array([0.866, 0.7071, 0.0, None], type=pa.float64()),
            id="sqrt",
        ),
        pytest.param(
            functions.round(col("a").tanh(), lit(4)),
            pa.array([-0.6351, 0.4621, 0.0, None], type=pa.float64()),
            id="tanh",
        ),
        pytest.param(
            functions.round(col("a").atan(), lit(4)),
            pa.array([-0.6435, 0.4636, 0.0, None], type=pa.float64()),
            id="atan",
        ),
        pytest.param(
            functions.round(col("a").atanh(), lit(4)),
            pa.array([-0.973, 0.5493, 0.0, None], type=pa.float64()),
            id="atanh",
        ),
        pytest.param(
            # large numbers cause an integer overflow so divide to make smaller
            (col("b") / lit(4)).factorial(),
            pa.array([1, 3628800, 1, None], type=pa.int64()),
            id="factorial",
        ),
        pytest.param(
            # Valid values of acosh must be >= 1.0
            functions.round((col("a").abs() + lit(1.0)).acosh(), lit(4)),
            pa.array([1.1588, 0.9624, 0.0, None], type=pa.float64()),
            id="acosh",
        ),
        pytest.param(
            col("a").floor(),
            pa.array([-1.0, 0.0, 0.0, None], type=pa.float64()),
            id="floor",
        ),
        pytest.param(
            col("a").ceil(),
            pa.array([-0.0, 1.0, 0.0, None], type=pa.float64()),
            id="ceil",
        ),
        pytest.param(
            functions.round(col("a").abs().ln(), lit(4)),
            pa.array([-0.2877, -0.6931, float("-inf"), None], type=pa.float64()),
            id="ln",
        ),
        pytest.param(
            functions.round(col("a").tan(), lit(4)),
            pa.array([-0.9316, 0.5463, 0.0, None], type=pa.float64()),
            id="tan",
        ),
        pytest.param(
            functions.round(col("a").cbrt(), lit(4)),
            pa.array([-0.9086, 0.7937, 0.0, None], type=pa.float64()),
            id="cbrt",
        ),
        pytest.param(
            functions.round(col("a").cos(), lit(4)),
            pa.array([0.7317, 0.8776, 1.0, None], type=pa.float64()),
            id="cos",
        ),
        pytest.param(
            functions.round(col("a").sinh(), lit(4)),
            pa.array([-0.8223, 0.5211, 0.0, None], type=pa.float64()),
            id="sinh",
        ),
        pytest.param(
            col("a").signum(),
            pa.array([-1.0, 1.0, 0.0, None], type=pa.float64()),
            id="signum",
        ),
        pytest.param(
            functions.round(col("a").abs().log2(), lit(4)),
            pa.array([-0.415, -1.0, float("-inf"), None], type=pa.float64()),
            id="log2",
        ),
        pytest.param(
            functions.round(col("a").cot(), lit(4)),
            pa.array([-1.0734, 1.8305, float("inf"), None], type=pa.float64()),
            id="cot",
        ),
        #
        # String Functions
        #
        pytest.param(
            col("c").reverse(),
            pa.array(["olleH", " dlrow ", "!", None], type=pa.string()),
            id="reverse",
        ),
        pytest.param(
            col("c").bit_length(),
            pa.array([40, 56, 8, None], type=pa.int32()),
            id="bit_length",
        ),
        pytest.param(
            col("b").to_hex(),
            pa.array(["ffffffffffffffe2", "2a", "0", None], type=pa.string()),
            id="to_hex",
        ),
        pytest.param(
            col("c").length(),
            pa.array([5, 7, 1, None], type=pa.int32()),
            id="length",
        ),
        pytest.param(
            col("c").lower(),
            pa.array(["hello", " world ", "!", None], type=pa.string()),
            id="lower",
        ),
        pytest.param(
            col("c").ascii(),
            pa.array([72, 32, 33, None], type=pa.int32()),
            id="ascii",
        ),
        pytest.param(
            col("c").sha512(),
            pa.array(
                [
                    bytes.fromhex(
                        "3615F80C9D293ED7402687F94B22D58E529B8CC7916F8FAC7FDDF7FBD5AF4CF777D3D795A7A00A16BF7E7F3FB9561EE9BAAE480DA9FE7A18769E71886B03F315"
                    ),
                    bytes.fromhex(
                        "A6758FDA3C2F0B554084E18308EA99B94B54EEE8FDA72697CEA7844E524CC2F2F2EE4CC8BAC87D2E3E7222959FE3D0CA1A841761FDC0D1780F6FE9E39E369500"
                    ),
                    bytes.fromhex(
                        "3831A6A6155E509DEE59A7F451EB35324D8F8F2DF6E3708894740F98FDEE23889F4DE5ADB0C5010DFB555CDA77C8AB5DC902094C52DE3278F35A75EBC25F093A"
                    ),
                    None,
                ],
                type=pa.binary(),
            ),
            id="sha512",
        ),
        pytest.param(
            col("c").sha384(),
            pa.array(
                [
                    bytes.fromhex(
                        "3519FE5AD2C596EFE3E276A6F351B8FC0B03DB861782490D45F7598EBD0AB5FD5520ED102F38C4A5EC834E98668035FC"
                    ),
                    bytes.fromhex(
                        "A6A38A9AE2CFD0D67F49989AD584632BF7D7A07DAD2277E92326A6A0B37F884A871D6173FB342CFE258E375258ACAAEC"
                    ),
                    bytes.fromhex(
                        "1D0EC8C84EE9521E21F06774DE232367B64DE628474CB5B2E372B699A1F55AE335CC37193EF823E33324DFD9A70738A6"
                    ),
                    None,
                ],
                type=pa.binary(),
            ),
            id="sha384",
        ),
        pytest.param(
            col("c").sha256(),
            pa.array(
                [
                    bytes.fromhex(
                        "185F8DB32271FE25F561A6FC938B2E264306EC304EDA518007D1764826381969"
                    ),
                    bytes.fromhex(
                        "DE2EF0D77D456EC1CDE2C52F75996F6636A64079297213D548D875A488B03A75"
                    ),
                    bytes.fromhex(
                        "BB7208BC9B5D7C04F1236A82A0093A5E33F40423D5BA8D4266F7092C3BA43B62"
                    ),
                    None,
                ],
                type=pa.binary(),
            ),
            id="sha256",
        ),
        pytest.param(
            col("c").sha224(),
            pa.array(
                [
                    bytes.fromhex(
                        "4149DA18AA8BFC2B1E382C6C26556D01A92C261B6436DAD5E3BE3FCC"
                    ),
                    bytes.fromhex(
                        "AD6DF6D9ECDDF50AF2A72D5E3144BA813EE954537572C0E8AB3066BE"
                    ),
                    bytes.fromhex(
                        "6641A7E8278BCD49E476E7ACAE158F4105B2952D22AEB2E0B9A231A0"
                    ),
                    None,
                ],
                type=pa.binary(),
            ),
            id="sha224",
        ),
        pytest.param(
            col("c").btrim(),
            pa.array(["Hello", "world", "!", None], type=pa.string_view()),
            id="btrim",
        ),
        pytest.param(
            col("c").trim(),
            pa.array(["Hello", "world", "!", None], type=pa.string_view()),
            id="trim",
        ),
        pytest.param(
            col("c").md5(),
            pa.array(
                [
                    "8b1a9953c4611296a827abf8c47804d7",
                    "de802497c24568d9a85d4eb8c2b6e8fe",
                    "9033e0e305f247c0c3c80d0c7848c8b3",
                    None,
                ],
                type=pa.string_view(),
            ),
            id="md5",
        ),
        pytest.param(
            col("c").octet_length(),
            pa.array([5, 7, 1, None], type=pa.int32()),
            id="octet_length",
        ),
        pytest.param(
            col("c").character_length(),
            pa.array([5, 7, 1, None], type=pa.int32()),
            id="character_length",
        ),
        pytest.param(
            col("c").char_length(),
            pa.array([5, 7, 1, None], type=pa.int32()),
            id="char_length",
        ),
        pytest.param(
            col("c").rtrim(),
            pa.array(["Hello", " world", "!", None], type=pa.string_view()),
            id="rtrim",
        ),
        pytest.param(
            col("c").ltrim(),
            pa.array(["Hello", "world ", "!", None], type=pa.string_view()),
            id="ltrim",
        ),
        pytest.param(
            col("c").upper(),
            pa.array(["HELLO", " WORLD ", "!", None], type=pa.string()),
            id="upper",
        ),
        pytest.param(
            lit(65).chr(),
            pa.array(["A", "A", "A", "A"], type=pa.string()),
            id="chr",
        ),
        #
        # Time Functions
        #
        pytest.param(
            col("b").from_unixtime(),
            pa.array(
                [
                    datetime(1969, 12, 31, 23, 59, 30, tzinfo=timezone.utc),
                    datetime(1970, 1, 1, 0, 0, 42, tzinfo=timezone.utc),
                    datetime(1970, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
                    None,
                ],
                type=pa.timestamp("s"),
            ),
            id="from_unixtime",
        ),
        pytest.param(
            col("c").initcap(),
            pa.array(["Hello", " World ", "!", None], type=pa.string_view()),
            id="initcap",
        ),
        #
        # Array Functions
        #
        pytest.param(
            col("d").array_pop_back(),
            pa.array([[-1, 1], [5, 10, 15], [], None], type=pa.list_(pa.int64())),
            id="array_pop_back",
        ),
        pytest.param(
            col("d").array_pop_front(),
            pa.array([[1, 0], [10, 15, 20], [], None], type=pa.list_(pa.int64())),
            id="array_pop_front",
        ),
        pytest.param(
            col("d").array_length(),
            pa.array([3, 4, 0, None], type=pa.uint64()),
            id="array_length",
        ),
        pytest.param(
            col("d").list_length(),
            pa.array([3, 4, 0, None], type=pa.uint64()),
            id="list_length",
        ),
        pytest.param(
            col("d").array_ndims(),
            pa.array([1, 1, 1, None], type=pa.uint64()),
            id="array_ndims",
        ),
        pytest.param(
            col("d").list_ndims(),
            pa.array([1, 1, 1, None], type=pa.uint64()),
            id="list_ndims",
        ),
        pytest.param(
            col("d").array_dims(),
            pa.array([[3], [4], None, None], type=pa.list_(pa.uint64())),
            id="array_dims",
        ),
        pytest.param(
            col("d").array_empty(),
            pa.array([False, False, True, None], type=pa.bool_()),
            id="array_empty",
        ),
        pytest.param(
            col("d").list_distinct(),
            pa.array(
                [[-1, 1, 0], [5, 10, 15, 20], [], None], type=pa.list_(pa.int64())
            ),
            id="list_distinct",
        ),
        pytest.param(
            col("d").array_distinct(),
            pa.array(
                [[-1, 1, 0], [5, 10, 15, 20], [], None], type=pa.list_(pa.int64())
            ),
            id="array_distinct",
        ),
        pytest.param(
            col("d").cardinality(),
            pa.array([3, 4, 0, None], type=pa.uint64()),
            id="cardinality",
        ),
        pytest.param(
            col("f").flatten(),
            pa.array(
                [[-1, 1, 0, 4, 4], [5, 10, 15, 20, 3], [], []],
                type=pa.list_(pa.int64()),
            ),
            id="flatten",
        ),
        pytest.param(
            col("d").list_dims(),
            pa.array([[3], [4], None, None], type=pa.list_(pa.uint64())),
            id="list_dims",
        ),
        pytest.param(
            col("d").empty(),
            pa.array([False, False, True, None], type=pa.bool_()),
            id="empty",
        ),
        #
        # Other Tests
        #
        pytest.param(
            col("d").arrow_typeof(),
            pa.array(
                [
                    "List(Int64)",
                    "List(Int64)",
                    "List(Int64)",
                    "List(Int64)",
                ],
                type=pa.string(),
            ),
            id="arrow_typeof",
        ),
    ],
)
def test_expr_functions(ctx, function, expected_result):
    batch = pa.RecordBatch.from_arrays(
        [
            pa.array([-0.75, 0.5, 0.0, None], type=pa.float64()),
            pa.array([-30, 42, 0, None], type=pa.int64()),
            pa.array(["Hello", " world ", "!", None], type=pa.string_view()),
            pa.array(
                [[-1, 1, 0], [5, 10, 15, 20], [], None], type=pa.list_(pa.int64())
            ),
            pa.array([-0.75, float("nan"), 0.0, None], type=pa.float64()),
            pa.array(
                [[[-1, 1, 0], [4, 4]], [[5, 10, 15, 20], [3]], [[]], [None]],
                type=pa.list_(pa.list_(pa.int64())),
            ),
        ],
        names=["a", "b", "c", "d", "e", "f"],
    )
    df = ctx.create_dataframe([[batch]]).select(function)

    result = df.collect()
    assert len(result) == 1
    assert result[0].column(0).equals(expected_result)


def test_literal_metadata(ctx):
    result = (
        ctx.from_pydict({"a": [1]})
        .select(
            lit(1).alias("no_metadata"),
            lit_with_metadata(2, {"key1": "value1"}).alias("lit_with_metadata_fn"),
            literal_with_metadata(3, {"key2": "value2"}).alias(
                "literal_with_metadata_fn"
            ),
        )
        .collect()
    )

    expected_schema = pa.schema(
        [
            pa.field("no_metadata", pa.int64(), nullable=False),
            pa.field(
                "lit_with_metadata_fn",
                pa.int64(),
                nullable=False,
                metadata={"key1": "value1"},
            ),
            pa.field(
                "literal_with_metadata_fn",
                pa.int64(),
                nullable=False,
                metadata={"key2": "value2"},
            ),
        ]
    )

    expected = pa.RecordBatch.from_pydict(
        {
            "no_metadata": pa.array([1]),
            "lit_with_metadata_fn": pa.array([2]),
            "literal_with_metadata_fn": pa.array([3]),
        },
        schema=expected_schema,
    )

    assert result[0] == expected

    # Testing result[0].schema == expected_schema does not check each key/value pair
    # so we want to explicitly test these
    for expected_field in expected_schema:
        actual_field = result[0].schema.field(expected_field.name)
        assert expected_field.metadata == actual_field.metadata


def test_scalar_conversion() -> None:
    class WrappedPyArrow:
        """Wrapper class for testing __arrow_c_array__."""

        def __init__(self, val: pa.Array) -> None:
            self.val = val

        def __arrow_c_array__(self, requested_schema=None):
            return self.val.__arrow_c_array__(requested_schema=requested_schema)

    expected_value = lit(1)
    assert str(expected_value) == "Expr(Int64(1))"

    # Test pyarrow imports
    assert expected_value == lit(pa.scalar(1))
    assert expected_value == lit(pa.scalar(1, type=pa.int32()))

    # Test nanoarrow
    na_scalar = nanoarrow.Array([1], nanoarrow.int32())[0]
    assert expected_value == lit(na_scalar)

    # Test pyo3
    arro3_scalar = arro3.core.Scalar(1, type=arro3.core.DataType.int32())
    assert expected_value == lit(arro3_scalar)

    generic_scalar = WrappedPyArrow(pa.array([1]))
    assert expected_value == lit(generic_scalar)

    expected_value = lit([1, 2, 3])
    assert str(expected_value) == "Expr(List([1, 2, 3]))"

    assert expected_value == lit(pa.scalar([1, 2, 3]))

    na_array = nanoarrow.Array([1, 2, 3], nanoarrow.int32())
    assert expected_value == lit(na_array)

    arro3_array = arro3.core.Array([1, 2, 3], type=arro3.core.DataType.int32())
    assert expected_value == lit(arro3_array)

    generic_array = WrappedPyArrow(pa.array([1, 2, 3]))
    assert expected_value == lit(generic_array)


def test_ensure_expr():
    e = col("a")
    assert ensure_expr(e) is e.expr
    with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
        ensure_expr("a")


def test_ensure_expr_list_string():
    with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
        ensure_expr_list("a")


def test_ensure_expr_list_bytes():
    with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
        ensure_expr_list(b"a")


def test_ensure_expr_list_bytearray():
    with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
        ensure_expr_list(bytearray(b"a"))


def test_coerce_to_expr_passes_expr_through():
    e = col("a")
    result = coerce_to_expr(e)
    assert isinstance(result, type(e))
    assert str(result) == str(e)


def test_coerce_to_expr_wraps_int():
    result = coerce_to_expr(42)
    assert isinstance(result, type(lit(42)))


def test_coerce_to_expr_wraps_str():
    result = coerce_to_expr("hello")
    assert isinstance(result, type(lit("hello")))


def test_coerce_to_expr_wraps_float():
    result = coerce_to_expr(3.14)
    assert isinstance(result, type(lit(3.14)))


def test_coerce_to_expr_wraps_bool():
    result = coerce_to_expr(True)
    assert isinstance(result, type(lit(True)))


def test_coerce_to_expr_or_none_returns_none():
    assert coerce_to_expr_or_none(None) is None


def test_coerce_to_expr_or_none_wraps_value():
    result = coerce_to_expr_or_none(42)
    assert isinstance(result, type(lit(42)))


def test_coerce_to_expr_or_none_passes_expr_through():
    e = col("a")
    result = coerce_to_expr_or_none(e)
    assert isinstance(result, type(e))
    assert str(result) == str(e)


@pytest.mark.parametrize(
    "value",
    [
        # Boolean
        pa.scalar(True, type=pa.bool_()),
        pa.scalar(False, type=pa.bool_()),
        # Integers - signed
        pa.scalar(127, type=pa.int8()),
        pa.scalar(-128, type=pa.int8()),
        pa.scalar(32767, type=pa.int16()),
        pa.scalar(-32768, type=pa.int16()),
        pa.scalar(2147483647, type=pa.int32()),
        pa.scalar(-2147483648, type=pa.int32()),
        pa.scalar(9223372036854775807, type=pa.int64()),
        pa.scalar(-9223372036854775808, type=pa.int64()),
        # Integers - unsigned
        pa.scalar(255, type=pa.uint8()),
        pa.scalar(65535, type=pa.uint16()),
        pa.scalar(4294967295, type=pa.uint32()),
        pa.scalar(18446744073709551615, type=pa.uint64()),
        # Floating point
        pa.scalar(3.14, type=pa.float32()),
        pa.scalar(3.141592653589793, type=pa.float64()),
        pa.scalar(float("inf"), type=pa.float64()),
        pa.scalar(float("-inf"), type=pa.float64()),
        # Decimal
        pa.scalar(Decimal("123.45"), type=pa.decimal128(10, 2)),
        pa.scalar(Decimal("-999999.999"), type=pa.decimal128(12, 3)),
        pa.scalar(Decimal("0.00001"), type=pa.decimal128(10, 5)),
        pa.scalar(Decimal("123.45"), type=pa.decimal256(20, 2)),
        # Strings
        pa.scalar("hello world", type=pa.string()),
        pa.scalar("", type=pa.string()),
        pa.scalar("unicode: 日本語 🎉", type=pa.string()),
        pa.scalar("hello", type=pa.large_string()),
        # Binary
        pa.scalar(b"binary data", type=pa.binary()),
        pa.scalar(b"", type=pa.binary()),
        pa.scalar(b"\x00\x01\x02\xff", type=pa.binary()),
        pa.scalar(b"large binary", type=pa.large_binary()),
        pa.scalar(b"fixed!", type=pa.binary(6)),  # fixed size binary
        # Date
        pa.scalar(date(2023, 8, 18), type=pa.date32()),
        pa.scalar(date(1970, 1, 1), type=pa.date32()),
        pa.scalar(date(2023, 8, 18), type=pa.date64()),
        # Time
        pa.scalar(time(12, 30, 45), type=pa.time32("s")),
        pa.scalar(time(12, 30, 45, 123000), type=pa.time32("ms")),
        pa.scalar(time(12, 30, 45, 123456), type=pa.time64("us")),
        pa.scalar(
            12 * 3600 * 10**9 + 30 * 60 * 10**9, type=pa.time64("ns")
        ),  # raw nanos
        # Timestamp - various resolutions
        pa.scalar(1692335046, type=pa.timestamp("s")),
        pa.scalar(1692335046618, type=pa.timestamp("ms")),
        pa.scalar(1692335046618897, type=pa.timestamp("us")),
        pa.scalar(1692335046618897499, type=pa.timestamp("ns")),
        # Timestamp with timezone
        pa.scalar(1692335046, type=pa.timestamp("s", tz="UTC")),
        pa.scalar(1692335046618897, type=pa.timestamp("us", tz="America/New_York")),
        pa.scalar(1692335046618897499, type=pa.timestamp("ns", tz="Europe/London")),
        # Duration
        pa.scalar(3600, type=pa.duration("s")),
        pa.scalar(3600000, type=pa.duration("ms")),
        pa.scalar(3600000000, type=pa.duration("us")),
        pa.scalar(3600000000000, type=pa.duration("ns")),
        # Interval
        pa.scalar((1, 15, 3600000000000), type=pa.month_day_nano_interval()),
        pa.scalar((0, 0, 0), type=pa.month_day_nano_interval()),
        pa.scalar((12, 30, 0), type=pa.month_day_nano_interval()),
        # Null
        pa.scalar(None, type=pa.null()),
        pa.scalar(None, type=pa.int64()),
        pa.scalar(None, type=pa.string()),
        # List types
        pa.scalar([1, 2, 3], type=pa.list_(pa.int64())),
        pa.scalar([], type=pa.list_(pa.int64())),
        pa.scalar(["a", "b", "c"], type=pa.list_(pa.string())),
        pa.scalar([[1, 2], [3, 4]], type=pa.list_(pa.list_(pa.int64()))),
        pa.scalar([1, 2, 3], type=pa.large_list(pa.int64())),
        # Fixed size list
        pa.scalar([1, 2, 3], type=pa.list_(pa.int64(), 3)),
        # Struct
        pa.scalar(
            {"x": 1, "y": 2}, type=pa.struct([("x", pa.int64()), ("y", pa.int64())])
        ),
        pa.scalar(
            {"name": "Alice", "age": 30},
            type=pa.struct([("name", pa.string()), ("age", pa.int32())]),
        ),
        pa.scalar(
            {"nested": {"a": 1}},
            type=pa.struct([("nested", pa.struct([("a", pa.int64())]))]),
        ),
        # Map
        pa.scalar([("key1", 1), ("key2", 2)], type=pa.map_(pa.string(), pa.int64())),
    ],
    ids=lambda v: f"{v.type}",
)
def test_round_trip_pyscalar_value(ctx: SessionContext, value: pa.Scalar):
    df = ctx.sql("select 1 as a")
    df = df.select(lit(value))
    assert pa.table(df)[0][0] == value