# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# NOTE(review): this section was restored from a whitespace-collapsed copy in
# which tag-like ``<...>`` spans had been stripped. Statement order and every
# surviving token are preserved; spots where text was lost are flagged with
# NOTE(review) below and should be checked against the upstream file.

import ctypes
import datetime
import os
import re
import threading
import time
from typing import Any

import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from datafusion import (
    DataFrame,
    ParquetColumnOptions,
    ParquetWriterOptions,
    SessionContext,
    WindowFrame,
    column,
    literal,
)
from datafusion import (
    col as df_col,
)
from datafusion import (
    functions as f,
)
from datafusion.dataframe_formatter import (
    DataFrameHtmlFormatter,
    configure_formatter,
    get_formatter,
    reset_formatter,
)
from datafusion.expr import EXPR_TYPE_ERROR, Window
from pyarrow.csv import write_csv

MB = 1024 * 1024


@pytest.fixture
def ctx():
    return SessionContext()


@pytest.fixture
def df():
    ctx = SessionContext()

    # create a RecordBatch and a new DataFrame from it
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6]), pa.array([8, 5, 8])],
        names=["a", "b", "c"],
    )

    return ctx.from_arrow(batch)


@pytest.fixture
def large_df():
    ctx = SessionContext()

    rows = 100000
    data = {
        "a": list(range(rows)),
        "b": [f"s-{i}" for i in range(rows)],
        "c": [float(i + 0.1) for i in range(rows)],
    }
    batch = pa.record_batch(data)

    return ctx.from_arrow(batch)


@pytest.fixture
def struct_df():
    ctx = SessionContext()

    # create a RecordBatch and a new DataFrame from it
    batch = pa.RecordBatch.from_arrays(
        [pa.array([{"c": 1}, {"c": 2}, {"c": 3}]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )

    return ctx.create_dataframe([[batch]])


@pytest.fixture
def nested_df():
    ctx = SessionContext()

    # create a RecordBatch and a new DataFrame from it
    # Intentionally make each array of different length
    batch = pa.RecordBatch.from_arrays(
        [pa.array([[1], [2, 3], [4, 5, 6], None]), pa.array([7, 8, 9, 10])],
        names=["a", "b"],
    )

    return ctx.create_dataframe([[batch]])


@pytest.fixture
def aggregate_df():
    ctx = SessionContext()
    ctx.register_csv("test", "testing/data/csv/aggregate_test_100.csv")
    return ctx.sql("select c1, sum(c2) from test group by c1")


@pytest.fixture
def partitioned_df():
    ctx = SessionContext()

    # create a RecordBatch and a new DataFrame from it
    batch = pa.RecordBatch.from_arrays(
        [
            pa.array([0, 1, 2, 3, 4, 5, 6]),
            pa.array([7, None, 7, 8, 9, None, 9]),
            pa.array(["A", "A", "A", "A", "B", "B", "B"]),
        ],
        names=["a", "b", "c"],
    )

    return ctx.create_dataframe([[batch]])


@pytest.fixture
def clean_formatter_state():
    """Reset the HTML formatter before the requesting test runs.

    The fixture has no ``yield``, so ``reset_formatter()`` executes during
    setup; nothing is reset after the test finishes.
    """
    reset_formatter()


@pytest.fixture
def null_df():
    """Create a DataFrame with null values of different types."""
    ctx = SessionContext()

    # Create a RecordBatch with nulls across different types
    batch = pa.RecordBatch.from_arrays(
        [
            pa.array([1, None, 3, None], type=pa.int64()),
            pa.array([4.5, 6.7, None, None], type=pa.float64()),
            pa.array(["a", None, "c", None], type=pa.string()),
            pa.array([True, None, False, None], type=pa.bool_()),
            pa.array(
                [10957, None, 18993, None], type=pa.date32()
            ),  # 2000-01-01, null, 2022-01-01, null
            pa.array(
                [946684800000, None, 1640995200000, None], type=pa.date64()
            ),  # 2000-01-01, null, 2022-01-01, null
        ],
        names=[
            "int_col",
            "float_col",
            "str_col",
            "bool_col",
            "date32_col",
            "date64_col",
        ],
    )

    return ctx.create_dataframe([[batch]])


# custom style for testing with html formatter
class CustomStyleProvider:
    def get_cell_style(self) -> str:
        return (
            "background-color: #f5f5f5; color: #333; padding: 8px; border: "
            "1px solid #ddd;"
        )

    def get_header_style(self) -> str:
        return (
            "background-color: #4285f4; color: white; font-weight: bold; "
            "padding: 10px; border: 1px solid #3367d6;"
        )


def count_table_rows(html_content: str) -> int:
    """Count the number of table rows in HTML content.

    Args:
        html_content: HTML string to analyze

    Returns:
        Number of table rows found (number of <tr> tags)
    """
    # NOTE(review): the original pattern was lost when tag-like text was
    # stripped; an opening-``<tr`` match is reconstructed from the docstring
    # and from the callers (header row + data rows). Confirm against upstream.
    return len(re.findall(r"<tr", html_content))


# NOTE(review): everything between the pattern above and the ``> literal(2)``
# comparison below was lost in extraction. Any test functions that originally
# lived here are NOT reconstructed; restore them from upstream.


# NOTE(review): the ``def`` line and the start of the first statement are
# reconstructed (name assumed from the body; ``column("a") >`` inferred from
# the expected results 3 + 6 == 9 and 3 - 6 == -3) — confirm.
def test_filter(df):
    df1 = df.filter(column("a") > literal(2)).select(
        column("a") + column("b"),
        column("a") - column("b"),
    )

    # execute and collect the first (and only) batch
    result = df1.collect()[0]

    assert result.column(0) == pa.array([9])
    assert result.column(1) == pa.array([-3])

    df.show()
    # verify that if there is no filter applied, internal dataframe is unchanged
    df2 = df.filter()
    assert df.df == df2.df

    df3 = df.filter(column("a") > literal(1), column("b") != literal(6))
    result = df3.collect()[0]

    assert result.column(0) == pa.array([2])
    assert result.column(1) == pa.array([5])
    assert result.column(2) == pa.array([5])


def test_show_empty(df, capsys):
    df_empty = df.filter(column("a") > literal(3))
    df_empty.show()
    captured = capsys.readouterr()
    assert "DataFrame has no rows" in captured.out


def test_sort(df):
    df = df.sort(column("b").sort(ascending=False))

    table = pa.Table.from_batches(df.collect())
    expected = {"a": [3, 2, 1], "b": [6, 5, 4], "c": [8, 5, 8]}

    assert table.to_pydict() == expected


def test_sort_string_and_expression_equivalent(df):
    from datafusion import col

    result_str = df.sort("a").to_pydict()
    result_expr = df.sort(col("a")).to_pydict()
    assert result_str == result_expr


def test_sort_unsupported(df):
    with pytest.raises(
        TypeError,
        match=f"Expected Expr or column name.*{re.escape(EXPR_TYPE_ERROR)}",
    ):
        df.sort(1)


def test_aggregate_string_and_expression_equivalent(df):
    from datafusion import col

    result_str = df.aggregate("a", [f.count()]).sort("a").to_pydict()
    result_expr = df.aggregate(col("a"), [f.count()]).sort("a").to_pydict()
    assert result_str == result_expr


def test_aggregate_tuple_group_by(df):
    result_list = df.aggregate(["a"], [f.count()]).sort("a").to_pydict()
    result_tuple = df.aggregate(("a",), [f.count()]).sort("a").to_pydict()
    assert result_tuple == result_list


def test_aggregate_tuple_aggs(df):
    result_list = df.aggregate("a", [f.count()]).sort("a").to_pydict()
    result_tuple = df.aggregate("a", (f.count(),)).sort("a").to_pydict()
    assert result_tuple == result_list


def test_filter_string_unsupported(df):
    with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
        df.filter("a > 1")


def test_drop(df):
    df = df.drop("c")

    # execute and collect the first (and only) batch
    result = df.collect()[0]

    assert df.schema().names == ["a", "b"]
    assert result.column(0) == pa.array([1, 2, 3])
    assert result.column(1) == pa.array([4, 5, 6])


def test_limit(df):
    df = df.limit(1)

    # execute and collect the first (and only) batch
    result = df.collect()[0]

    assert len(result.column(0)) == 1
    assert len(result.column(1)) == 1


def test_limit_with_offset(df):
    # only 3 rows, but limit past the end to ensure that offset is working
    df = df.limit(5, offset=2)

    # execute and collect the first (and only) batch
    result = df.collect()[0]

    assert len(result.column(0)) == 1
    assert len(result.column(1)) == 1


def test_head(df):
    df = df.head(1)

    # execute and collect the first (and only) batch
    result = df.collect()[0]

    assert result.column(0) == pa.array([1])
    assert result.column(1) == pa.array([4])
    assert result.column(2) == pa.array([8])


def test_tail(df):
    df = df.tail(1)

    # execute and collect the first (and only) batch
    result = df.collect()[0]

    assert result.column(0) == pa.array([3])
    assert result.column(1) == pa.array([6])
    assert result.column(2) == pa.array([8])


def test_with_column(df):
    df = df.with_column("c", column("a") + column("b"))

    # execute and collect the first (and only) batch
    result = df.collect()[0]

    assert result.schema.field(0).name == "a"
    assert result.schema.field(1).name == "b"
    assert result.schema.field(2).name == "c"

    assert result.column(0) == pa.array([1, 2, 3])
    assert result.column(1) == pa.array([4, 5, 6])
    assert result.column(2) == pa.array([5, 7, 9])


def test_with_column_invalid_expr(df):
    with pytest.raises(
        TypeError, match=r"Use col\(\)/column\(\) or lit\(\)/literal\(\)"
    ):
        df.with_column("c", "a")


def test_with_columns(df):
    df = df.with_columns(
        (column("a") + column("b")).alias("c"),
        (column("a") + column("b")).alias("d"),
        [
            (column("a") + column("b")).alias("e"),
            (column("a") + column("b")).alias("f"),
        ],
        g=(column("a") + column("b")),
    )

    # execute and collect the first (and only) batch
    result = df.collect()[0]

    assert result.schema.field(0).name == "a"
    assert result.schema.field(1).name == "b"
    assert result.schema.field(2).name == "c"
    assert result.schema.field(3).name == "d"
    assert result.schema.field(4).name == "e"
    assert result.schema.field(5).name == "f"
    assert result.schema.field(6).name == "g"

    assert result.column(0) == pa.array([1, 2, 3])
    assert result.column(1) == pa.array([4, 5, 6])
    assert result.column(2) == pa.array([5, 7, 9])
    assert result.column(3) == pa.array([5, 7, 9])
    assert result.column(4) == pa.array([5, 7, 9])
    assert result.column(5) == pa.array([5, 7, 9])
    assert result.column(6) == pa.array([5, 7, 9])


def test_with_columns_invalid_expr(df):
    with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
        df.with_columns("a")
    with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
        df.with_columns(c="a")
    with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
        df.with_columns(["a"])
    with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
        df.with_columns(c=["a"])


def test_cast(df):
    df = df.cast({"a": pa.float16(), "b": pa.list_(pa.uint32())})
    expected = pa.schema(
        [("a", pa.float16()), ("b", pa.list_(pa.uint32())), ("c", pa.int64())]
    )

    assert df.schema() == expected


def test_with_column_renamed(df):
    df = df.with_column("c", column("a") + column("b")).with_column_renamed("c", "sum")

    result = df.collect()[0]

    assert result.schema.field(0).name == "a"
    assert result.schema.field(1).name == "b"
    assert result.schema.field(2).name == "sum"


def test_unnest(nested_df):
    nested_df = nested_df.unnest_columns("a")

    # execute and collect the first (and only) batch
    result = nested_df.collect()[0]

    assert result.column(0) == pa.array([1, 2, 3, 4, 5, 6, None])
    assert result.column(1) == pa.array([7, 8, 8, 9, 9, 9, 10])


def test_unnest_without_nulls(nested_df):
    nested_df = nested_df.unnest_columns("a", preserve_nulls=False)

    # execute and collect the first (and only) batch
    result = nested_df.collect()[0]

    assert result.column(0) == pa.array([1, 2, 3, 4, 5, 6])
    assert result.column(1) == pa.array([7, 8, 8, 9, 9, 9])


@pytest.mark.filterwarnings("ignore:`join_keys`:DeprecationWarning")
def test_join():
    ctx = SessionContext()

    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )
    df = ctx.create_dataframe([[batch]], "l")

    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2]), pa.array([8, 10])],
        names=["a", "c"],
    )
    df1 = ctx.create_dataframe([[batch]], "r")

    df2 = df.join(df1, on="a", how="inner")
    df2.show()
    df2 = df2.sort(column("l.a"))
    table = pa.Table.from_batches(df2.collect())

    expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
    assert table.to_pydict() == expected

    df2 = df.join(df1, left_on="a", right_on="a", how="inner")
    df2.show()
    df2 = df2.sort(column("l.a"))
    table = pa.Table.from_batches(df2.collect())

    expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
    assert table.to_pydict() == expected

    # Verify we don't make a breaking change to pre-43.0.0
    # where users would pass join_keys as a positional argument
    df2 = df.join(df1, (["a"], ["a"]), how="inner")
    df2.show()
    df2 = df2.sort(column("l.a"))
    table = pa.Table.from_batches(df2.collect())

    expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
    assert table.to_pydict() == expected


def test_join_invalid_params():
    ctx = SessionContext()

    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )
    df = ctx.create_dataframe([[batch]], "l")

    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2]), pa.array([8, 10])],
        names=["a", "c"],
    )
    df1 = ctx.create_dataframe([[batch]], "r")

    with pytest.deprecated_call():
        df2 = df.join(df1, join_keys=(["a"], ["a"]), how="inner")
        df2.show()
        df2 = df2.sort(column("l.a"))
        table = pa.Table.from_batches(df2.collect())

        expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
        assert table.to_pydict() == expected

    with pytest.raises(
        ValueError, match=r"`left_on` or `right_on` should not provided with `on`"
    ):
        df2 = df.join(df1, on="a", how="inner", right_on="test")

    with pytest.raises(
        ValueError, match=r"`left_on` and `right_on` should both be provided."
    ):
        df2 = df.join(df1, left_on="a", how="inner")

    with pytest.raises(
        ValueError, match=r"either `on` or `left_on` and `right_on` should be provided."
    ):
        df2 = df.join(df1, how="inner")


def test_join_on():
    ctx = SessionContext()

    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )
    df = ctx.create_dataframe([[batch]], "l")

    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2]), pa.array([-8, 10])],
        names=["a", "c"],
    )
    df1 = ctx.create_dataframe([[batch]], "r")

    df2 = df.join_on(df1, column("l.a").__eq__(column("r.a")), how="inner")
    df2.show()
    df2 = df2.sort(column("l.a"))
    table = pa.Table.from_batches(df2.collect())

    expected = {"a": [1, 2], "c": [-8, 10], "b": [4, 5]}
    assert table.to_pydict() == expected

    df3 = df.join_on(
        df1,
        column("l.a").__eq__(column("r.a")),
        column("l.a").__lt__(column("r.c")),
        how="inner",
    )
    df3.show()
    df3 = df3.sort(column("l.a"))
    table = pa.Table.from_batches(df3.collect())
    expected = {"a": [2], "c": [10], "b": [5]}
    assert table.to_pydict() == expected


def test_join_on_invalid_expr():
    ctx = SessionContext()

    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2]), pa.array([4, 5])],
        names=["a", "b"],
    )
    df = ctx.create_dataframe([[batch]], "l")
    df1 = ctx.create_dataframe([[batch]], "r")

    with pytest.raises(
        TypeError, match=r"Use col\(\)/column\(\) or lit\(\)/literal\(\)"
    ):
        df.join_on(df1, "a")


def test_aggregate_invalid_aggs(df):
    with pytest.raises(
        TypeError, match=r"Use col\(\)/column\(\) or lit\(\)/literal\(\)"
    ):
        df.aggregate([], "a")


def test_distinct():
    ctx = SessionContext()

    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3, 1, 2, 3]), pa.array([4, 5, 6, 4, 5, 6])],
        names=["a", "b"],
    )
    df_a = ctx.create_dataframe([[batch]]).distinct().sort(column("a"))

    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )
    df_b = ctx.create_dataframe([[batch]]).sort(column("a"))

    assert df_a.collect() == df_b.collect()


# Each entry is (result column name, window expression, expected values once
# the output is sorted by column "a" of the ``partitioned_df`` fixture).
data_test_window_functions = [
    (
        "row",
        f.row_number(order_by=[column("b"), column("a").sort(ascending=False)]),
        [4, 2, 3, 5, 7, 1, 6],
    ),
    (
        "row_w_params",
        f.row_number(
            order_by=[column("b"), column("a")],
            partition_by=[column("c")],
        ),
        [2, 1, 3, 4, 2, 1, 3],
    ),
    (
        "row_w_params_no_lists",
        f.row_number(
            order_by=column("b"),
            partition_by=column("c"),
        ),
        [2, 1, 3, 4, 2, 1, 3],
    ),
    ("rank", f.rank(order_by=[column("b")]), [3, 1, 3, 5, 6, 1, 6]),
    (
        "rank_w_params",
        f.rank(order_by=[column("b"), column("a")], partition_by=[column("c")]),
        [2, 1, 3, 4, 2, 1, 3],
    ),
    (
        "rank_w_params_no_lists",
        f.rank(order_by=column("a"), partition_by=column("c")),
        [1, 2, 3, 4, 1, 2, 3],
    ),
    (
        "dense_rank",
        f.dense_rank(order_by=[column("b")]),
        [2, 1, 2, 3, 4, 1, 4],
    ),
    (
        "dense_rank_w_params",
        f.dense_rank(order_by=[column("b"), column("a")], partition_by=[column("c")]),
        [2, 1, 3, 4, 2, 1, 3],
    ),
    (
        "dense_rank_w_params_no_lists",
        f.dense_rank(order_by=column("a"), partition_by=column("c")),
        [1, 2, 3, 4, 1, 2, 3],
    ),
    (
        "percent_rank",
        f.round(f.percent_rank(order_by=[column("b")]), literal(3)),
        [0.333, 0.0, 0.333, 0.667, 0.833, 0.0, 0.833],
    ),
    (
        "percent_rank_w_params",
        f.round(
            f.percent_rank(
                order_by=[column("b"), column("a")], partition_by=[column("c")]
            ),
            literal(3),
        ),
        [0.333, 0.0, 0.667, 1.0, 0.5, 0.0, 1.0],
    ),
    (
        "percent_rank_w_params_no_lists",
        f.round(
            f.percent_rank(order_by=column("a"), partition_by=column("c")),
            literal(3),
        ),
        [0.0, 0.333, 0.667, 1.0, 0.0, 0.5, 1.0],
    ),
    (
        "cume_dist",
        f.round(f.cume_dist(order_by=[column("b")]), literal(3)),
        [0.571, 0.286, 0.571, 0.714, 1.0, 0.286, 1.0],
    ),
    (
        "cume_dist_w_params",
        f.round(
            f.cume_dist(
                order_by=[column("b"), column("a")], partition_by=[column("c")]
            ),
            literal(3),
        ),
        [0.5, 0.25, 0.75, 1.0, 0.667, 0.333, 1.0],
    ),
    (
        "cume_dist_w_params_no_lists",
        f.round(
            f.cume_dist(order_by=column("a"), partition_by=column("c")),
            literal(3),
        ),
        [0.25, 0.5, 0.75, 1.0, 0.333, 0.667, 1.0],
    ),
    (
        "ntile",
        f.ntile(2, order_by=[column("b")]),
        [1, 1, 1, 2, 2, 1, 2],
    ),
    (
        "ntile_w_params",
        f.ntile(2, order_by=[column("b"), column("a")], partition_by=[column("c")]),
        [1, 1, 2, 2, 1, 1, 2],
    ),
    (
        "ntile_w_params_no_lists",
        f.ntile(2, order_by=column("b"), partition_by=column("c")),
        [1, 1, 2, 2, 1, 1, 2],
    ),
    ("lead", f.lead(column("b"), order_by=[column("b")]), [7, None, 8, 9, 9, 7, None]),
    (
        "lead_w_params",
        f.lead(
            column("b"),
            shift_offset=2,
            default_value=-1,
            order_by=[column("b"), column("a")],
            partition_by=[column("c")],
        ),
        [8, 7, -1, -1, -1, 9, -1],
    ),
    (
        "lead_w_params_no_lists",
        f.lead(
            column("b"),
            shift_offset=2,
            default_value=-1,
            order_by=column("b"),
            partition_by=column("c"),
        ),
        [8, 7, -1, -1, -1, 9, -1],
    ),
    ("lag", f.lag(column("b"), order_by=[column("b")]), [None, None, 7, 7, 8, None, 9]),
    (
        "lag_w_params",
        f.lag(
            column("b"),
            shift_offset=2,
            default_value=-1,
            order_by=[column("b"), column("a")],
            partition_by=[column("c")],
        ),
        [-1, -1, None, 7, -1, -1, None],
    ),
    (
        "lag_w_params_no_lists",
        f.lag(
            column("b"),
            shift_offset=2,
            default_value=-1,
            order_by=column("b"),
            partition_by=column("c"),
        ),
        [-1, -1, None, 7, -1, -1, None],
    ),
    (
        "first_value",
        f.first_value(column("a")).over(
            Window(partition_by=[column("c")], order_by=[column("b")])
        ),
        [1, 1, 1, 1, 5, 5, 5],
    ),
    (
        "first_value_without_list_args",
        f.first_value(column("a")).over(
            Window(partition_by=column("c"), order_by=column("b"))
        ),
        [1, 1, 1, 1, 5, 5, 5],
    ),
    (
        "first_value_order_by_string",
        f.first_value(column("a")).over(
            Window(partition_by=[column("c")], order_by="b")
        ),
        [1, 1, 1, 1, 5, 5, 5],
    ),
    (
        "last_value",
        f.last_value(column("a")).over(
            Window(
                partition_by=[column("c")],
                order_by=[column("b")],
                window_frame=WindowFrame("rows", None, None),
            )
        ),
        [3, 3, 3, 3, 6, 6, 6],
    ),
    (
        "3rd_value",
        f.nth_value(column("b"), 3).over(Window(order_by=[column("a")])),
        [None, None, 7, 7, 7, 7, 7],
    ),
    (
        "avg",
        f.round(f.avg(column("b")).over(Window(order_by=[column("a")])), literal(3)),
        [7.0, 7.0, 7.0, 7.333, 7.75, 7.75, 8.0],
    ),
]


@pytest.mark.parametrize(("name", "expr", "result"), data_test_window_functions)
def test_window_functions(partitioned_df, name, expr, result):
    df = partitioned_df.select(
        column("a"), column("b"), column("c"), f.alias(expr, name)
    )
    df.sort(column("a")).show()
    table = pa.Table.from_batches(df.collect())

    expected = {
        "a": [0, 1, 2, 3, 4, 5, 6],
        "b": [7, None, 7, 8, 9, None, 9],
        "c": ["A", "A", "A", "A", "B", "B", "B"],
        name: result,
    }

    assert table.sort_by("a").to_pydict() == expected


@pytest.mark.parametrize("partition", ["c", df_col("c")])
def test_rank_partition_by_accepts_string(partitioned_df, partition):
    """Passing a string to partition_by should match using col()."""
    df = partitioned_df.select(
        f.rank(order_by=column("a"), partition_by=partition).alias("r")
    )
    table = pa.Table.from_batches(df.sort(column("a")).collect())
    assert table.column("r").to_pylist() == [1, 2, 3, 4, 1, 2, 3]


@pytest.mark.parametrize("partition", ["c", df_col("c")])
def test_window_partition_by_accepts_string(partitioned_df, partition):
    """Window.partition_by accepts string identifiers."""
    expr = f.first_value(column("a")).over(
        Window(partition_by=partition, order_by=column("b"))
    )
    df = partitioned_df.select(expr.alias("fv"))
    table = pa.Table.from_batches(df.sort(column("a")).collect())
    assert table.column("fv").to_pylist() == [1, 1, 1, 1, 5, 5, 5]


@pytest.mark.parametrize(
    ("units", "start_bound", "end_bound"),
    [
        (units, start_bound, end_bound)
        for units in ("rows", "range")
        for start_bound in (None, 0, 1)
        for end_bound in (None, 0, 1)
    ]
    + [
        ("groups", 0, 0),
    ],
)
def test_valid_window_frame(units, start_bound, end_bound):
    WindowFrame(units, start_bound, end_bound)


@pytest.mark.parametrize(
    ("units", "start_bound", "end_bound"),
    [
        ("invalid-units", 0, None),
        ("invalid-units", None, 0),
        ("invalid-units", None, None),
        ("groups", None, 0),
        ("groups", 0, None),
        ("groups", None, None),
    ],
)
def test_invalid_window_frame(units, start_bound, end_bound):
    with pytest.raises(NotImplementedError, match=f"(?i){units}"):
        WindowFrame(units, start_bound, end_bound)


def test_window_frame_defaults_match_postgres(partitioned_df):
    # ref: https://github.com/apache/datafusion-python/issues/688

    window_frame = WindowFrame("rows", None, None)

    col_a = column("a")

    # Using `f.window` with or without an unbounded window_frame produces the same
    # results. These tests are included as a regression check but can be removed when
    # f.window() is deprecated in favor of using the .over() approach.
    no_frame = f.window("avg", [col_a]).alias("no_frame")
    with_frame = f.window("avg", [col_a], window_frame=window_frame).alias("with_frame")
    df_1 = partitioned_df.select(col_a, no_frame, with_frame)

    expected = {
        "a": [0, 1, 2, 3, 4, 5, 6],
        "no_frame": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
        "with_frame": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
    }

    assert df_1.sort(col_a).to_pydict() == expected

    # When order is not set, the default frame should be unbounded preceding to
    # unbounded following. When order is set, the default frame is unbounded preceding
    # to current row.
    no_order = f.avg(col_a).over(Window()).alias("over_no_order")
    with_order = f.avg(col_a).over(Window(order_by=[col_a])).alias("over_with_order")
    df_2 = partitioned_df.select(col_a, no_order, with_order)

    expected = {
        "a": [0, 1, 2, 3, 4, 5, 6],
        "over_no_order": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
        "over_with_order": [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0],
    }

    assert df_2.sort(col_a).to_pydict() == expected


# The ``_build_*`` helpers below each return a DataFrame with an "expr" column
# (order_by given as Expr) and a "str" column (order_by given as a column name)
# so the parametrized test can compare the two spellings.
def _build_last_value_df(df):
    return df.select(
        f.last_value(column("a"))
        .over(
            Window(
                partition_by=[column("c")],
                order_by=[column("b")],
                window_frame=WindowFrame("rows", None, None),
            )
        )
        .alias("expr"),
        f.last_value(column("a"))
        .over(
            Window(
                partition_by=[column("c")],
                order_by="b",
                window_frame=WindowFrame("rows", None, None),
            )
        )
        .alias("str"),
    )


def _build_nth_value_df(df):
    return df.select(
        f.nth_value(column("b"), 3).over(Window(order_by=[column("a")])).alias("expr"),
        f.nth_value(column("b"), 3).over(Window(order_by="a")).alias("str"),
    )


def _build_rank_df(df):
    return df.select(
        f.rank(order_by=[column("b")]).alias("expr"),
        f.rank(order_by="b").alias("str"),
    )


def _build_array_agg_df(df):
    return df.aggregate(
        [column("c")],
        [
            f.array_agg(column("a"), order_by=[column("a")]).alias("expr"),
            f.array_agg(column("a"), order_by="a").alias("str"),
        ],
    ).sort(column("c"))


@pytest.mark.parametrize(
    ("builder", "expected"),
    [
        pytest.param(_build_last_value_df, [3, 3, 3, 3, 6, 6, 6], id="last_value"),
        pytest.param(_build_nth_value_df, [None, None, 7, 7, 7, 7, 7], id="nth_value"),
        pytest.param(_build_rank_df, [1, 1, 3, 3, 5, 6, 6], id="rank"),
        pytest.param(_build_array_agg_df, [[0, 1, 2, 3], [4, 5, 6]], id="array_agg"),
    ],
)
def test_order_by_string_equivalence(partitioned_df, builder, expected):
    df = builder(partitioned_df)
    table = pa.Table.from_batches(df.collect())
    assert table.column("expr").to_pylist() == expected
    assert table.column("expr").to_pylist() == table.column("str").to_pylist()


def test_html_formatter_cell_dimension(df, clean_formatter_state):
    """Test configuring the HTML formatter with different options."""
    # Configure with custom settings
    configure_formatter(
        max_width=500,
        max_height=200,
        enable_cell_expansion=False,
    )

    html_output = df._repr_html_()

    # Verify our configuration was applied
    assert "max-height: 200px" in html_output
    assert "max-width: 500px" in html_output
    # With cell expansion disabled, we shouldn't see expandable-container elements
    assert "expandable-container" not in html_output


def test_html_formatter_custom_style_provider(df, clean_formatter_state):
    """Test using custom style providers with the HTML formatter."""
    # Configure with custom style provider
    configure_formatter(style_provider=CustomStyleProvider())

    html_output = df._repr_html_()

    # Verify our custom styles were applied
    assert "background-color: #4285f4" in html_output
    assert "color: white" in html_output
    assert "background-color: #f5f5f5" in html_output


def test_html_formatter_type_formatters(df, clean_formatter_state):
    """Test registering custom type formatters for specific data types."""
    # Get current formatter and register custom formatters
    formatter = get_formatter()

    # Format integers with color based on value
    # Using int as the type for the formatter will work since we convert
    # Arrow scalar values to Python native types in _get_cell_value
    def format_int(value):
        # NOTE(review): the markup in this f-string was stripped in the copy
        # under review (only ` 2 else "blue"}">{value}` survived, which does
        # not parse). The tag, attribute and the "red" branch are
        # reconstructed — confirm against upstream.
        return f'<span style="color: {"red" if value > 2 else "blue"}">{value}</span>'

    formatter.register_formatter(int, format_int)

    html_output = df._repr_html_()

    # Our test dataframe has values 1,2,3 so we should see:
    # NOTE(review): the expected literal appears to have lost its surrounding
    # markup; kept exactly as found.
    assert '1' in html_output


def test_html_formatter_custom_cell_builder(df, clean_formatter_state):
    """Test using a custom cell builder function."""

    # NOTE(review): every HTML literal and regex in this test appears to have
    # had its tag text stripped (e.g. the builders return plain
    # '{value}-high'). The literals are kept exactly as found; restore them
    # from upstream before relying on this test.

    # Create a custom cell builder with distinct styling for different value ranges
    def custom_cell_builder(value, row, col, table_id):
        try:
            num_value = int(value)
            if num_value > 5:  # Values > 5 get green background with indicator
                return (
                    '{value}-high'
                )
            if num_value < 3:  # Values < 3 get blue background with indicator
                return (
                    '{value}-low'
                )
        except (ValueError, TypeError):
            pass

        # Default styling for other cells (3, 4, 5)
        return f'{value}-mid'

    # Set our custom cell builder
    formatter = get_formatter()
    formatter.set_custom_cell_builder(custom_cell_builder)

    html_output = df._repr_html_()

    # Extract cells with specific styling using regex
    low_cells = re.findall(
        r']*>(\d+)-low', html_output
    )
    mid_cells = re.findall(
        r']*>(\d+)-mid', html_output
    )
    high_cells = re.findall(
        r']*>(\d+)-high', html_output
    )

    # Sort the extracted values for consistent comparison
    low_cells = sorted(map(int, low_cells))
    mid_cells = sorted(map(int, mid_cells))
    high_cells = sorted(map(int, high_cells))

    # Verify specific values have the correct styling applied
    assert low_cells == [1, 2]  # Values < 3
    assert mid_cells == [3, 4, 5, 5]  # Values 3-5
    assert high_cells == [6, 8, 8]  # Values > 5

    # Verify the exact content with styling appears in the output
    assert (
        '1-low'
        in html_output
    )
    assert (
        '2-low'
        in html_output
    )
    assert (
        '3-mid' in html_output
    )
    assert (
        '4-mid' in html_output
    )
    assert (
        '6-high'
        in html_output
    )
    assert (
        '8-high'
        in html_output
    )

    # Count occurrences to ensure all cells are properly styled
    assert html_output.count("-low") == 2  # Two low values (1, 2)
    assert html_output.count("-mid") == 4  # Four mid values (3, 4, 5, 5)
    assert html_output.count("-high") == 3  # Three high values (6, 8, 8)

    # Create a custom cell builder that changes background color based on value
    def custom_cell_builder(value, row, col, table_id):
        # Handle numeric values regardless of their exact type
        try:
            num_value = int(value)
            if num_value > 5:  # Values > 5 get green background
                return f'{value}'
            if num_value < 3:  # Values < 3 get light blue background
                return f'{value}'
        except (ValueError, TypeError):
            pass

        # Default styling for other cells
        return f'{value}'

    # Set our custom cell builder
    formatter = get_formatter()
    formatter.set_custom_cell_builder(custom_cell_builder)

    html_output = df._repr_html_()

    # Verify our custom cell styling was applied
    assert "background-color: #d3e9f0" in html_output  # For values 1,2


def test_html_formatter_custom_header_builder(df, clean_formatter_state):
    """Test using a custom header builder function."""

    # Create a custom header builder with tooltips
    def custom_header_builder(field):
        tooltips = {
            "a": "Primary key column",
            "b": "Secondary values",
            "c": "Additional data",
        }
        tooltip = tooltips.get(field.name, "")
        # NOTE(review): the header markup (which presumably carried the
        # tooltip and the inline style asserted below) was stripped; only the
        # field name survives. Kept as found — restore from upstream.
        return (
            f'{field.name}'
        )

    # Set our custom header builder
    formatter = get_formatter()
    formatter.set_custom_header_builder(custom_header_builder)

    html_output = df._repr_html_()

    # Verify our custom headers were applied
    assert 'title="Primary key column"' in html_output
    assert 'title="Secondary values"' in html_output
    assert "background-color: #333; color: white" in html_output


def test_html_formatter_complex_customization(df, clean_formatter_state):
    """Test combining multiple customization options together."""

    # Create a dark mode style provider
    class DarkModeStyleProvider:
        def get_cell_style(self) -> str:
            return (
                "background-color: #222; color: #eee; "
                "padding: 8px; border: 1px solid #444;"
            )

        def get_header_style(self) -> str:
            return (
                "background-color: #111; color: #fff; padding: 10px; "
                "border: 1px solid #333;"
            )

    # Configure with dark mode style
    configure_formatter(
        max_cell_length=10,
        style_provider=DarkModeStyleProvider(),
        custom_css="""
        .datafusion-table {
            font-family: monospace;
            border-collapse: collapse;
        }
        .datafusion-table tr:hover td {
            background-color: #444 !important;
        }
        """,
    )

    # Add type formatters for special formatting - now working with native int values
    formatter = get_formatter()
    # NOTE(review): the formatter's markup (presumably the source of the
    # "color: #5af" asserted below) was stripped; kept as found.
    formatter.register_formatter(
        int,
        lambda n: f'{n}',
    )

    html_output = df._repr_html_()

    # Verify our customizations were applied
    assert "background-color: #222" in html_output
    assert "background-color: #111" in html_output
    assert ".datafusion-table" in html_output
    assert "color: #5af" in html_output  # Even numbers


def test_html_formatter_memory(df, clean_formatter_state):
    """Test the memory and row control parameters in DataFrameHtmlFormatter."""
    configure_formatter(max_memory_bytes=10, min_rows_display=1)
    html_output = df._repr_html_()

    # Count the number of table rows in the output
    tr_count = count_table_rows(html_output)
    # With a tiny memory limit of 10 bytes, the formatter should display
    # the minimum number of rows (1) plus a message about truncation
    assert tr_count == 2  # 1 for header row, 1 for data row
    assert "data truncated" in html_output.lower()

    configure_formatter(max_memory_bytes=10 * MB, min_rows_display=1)
    html_output = df._repr_html_()
    # With the larger memory limit (min_rows_display still 1), should display
    # all rows
    tr_count = count_table_rows(html_output)
    # Table should have header row (1) + 3 data rows = 4 rows
    assert tr_count == 4
    # No truncation message should appear
    assert "data truncated" not in html_output.lower()


def test_html_formatter_repr_rows(df, clean_formatter_state):
    configure_formatter(min_rows_display=2, repr_rows=2)
    html_output = df._repr_html_()

    tr_count = count_table_rows(html_output)
    # Table should have header row (1) + 2 data rows = 3 rows
    assert tr_count == 3

    configure_formatter(min_rows_display=2, repr_rows=3)
    html_output = df._repr_html_()

    tr_count = count_table_rows(html_output)
    # Table should have header row (1) + 3 data rows = 4 rows
    assert tr_count == 4


def test_html_formatter_validation():
    # Test validation for invalid parameters

    with pytest.raises(ValueError, match="max_cell_length must be a positive integer"):
        DataFrameHtmlFormatter(max_cell_length=0)

    with pytest.raises(ValueError, match="max_width must be a positive integer"):
        DataFrameHtmlFormatter(max_width=0)

    with pytest.raises(ValueError, match="max_height must be a positive integer"):
        DataFrameHtmlFormatter(max_height=0)

    with pytest.raises(ValueError, match="max_memory_bytes must be a positive integer"):
        DataFrameHtmlFormatter(max_memory_bytes=0)

    with pytest.raises(ValueError, match="max_memory_bytes must be a positive integer"):
        # NOTE(review): the reviewed chunk ends at this `with` header; its real
        # body (and the rest of this test) lies beyond the visible source. The
        # `...` is only a syntactic placeholder and must be replaced by the
        # original statement.
        ...
DataFrameHtmlFormatter(max_memory_bytes=-100) with pytest.raises(ValueError, match="min_rows_display must be a positive integer"): DataFrameHtmlFormatter(min_rows_display=0) with pytest.raises(ValueError, match="min_rows_display must be a positive integer"): DataFrameHtmlFormatter(min_rows_display=-5) with pytest.raises(ValueError, match="repr_rows must be a positive integer"): DataFrameHtmlFormatter(repr_rows=0) with pytest.raises(ValueError, match="repr_rows must be a positive integer"): DataFrameHtmlFormatter(repr_rows=-10) def test_configure_formatter(df, clean_formatter_state): """Test using custom style providers with the HTML formatter and configured parameters.""" # these are non-default values max_cell_length = 10 max_width = 500 max_height = 30 max_memory_bytes = 3 * MB min_rows_display = 2 repr_rows = 2 enable_cell_expansion = False show_truncation_message = False use_shared_styles = False reset_formatter() formatter_default = get_formatter() assert formatter_default.max_cell_length != max_cell_length assert formatter_default.max_width != max_width assert formatter_default.max_height != max_height assert formatter_default.max_memory_bytes != max_memory_bytes assert formatter_default.min_rows_display != min_rows_display assert formatter_default.repr_rows != repr_rows assert formatter_default.enable_cell_expansion != enable_cell_expansion assert formatter_default.show_truncation_message != show_truncation_message assert formatter_default.use_shared_styles != use_shared_styles # Configure with custom style provider and additional parameters configure_formatter( max_cell_length=max_cell_length, max_width=max_width, max_height=max_height, max_memory_bytes=max_memory_bytes, min_rows_display=min_rows_display, repr_rows=repr_rows, enable_cell_expansion=enable_cell_expansion, show_truncation_message=show_truncation_message, use_shared_styles=use_shared_styles, ) formatter_custom = get_formatter() assert formatter_custom.max_cell_length == max_cell_length assert 
formatter_custom.max_width == max_width assert formatter_custom.max_height == max_height assert formatter_custom.max_memory_bytes == max_memory_bytes assert formatter_custom.min_rows_display == min_rows_display assert formatter_custom.repr_rows == repr_rows assert formatter_custom.enable_cell_expansion == enable_cell_expansion assert formatter_custom.show_truncation_message == show_truncation_message assert formatter_custom.use_shared_styles == use_shared_styles def test_configure_formatter_invalid_params(clean_formatter_state): """Test that configure_formatter rejects invalid parameters.""" with pytest.raises(ValueError, match="Invalid formatter parameters"): configure_formatter(invalid_param=123) # Test with multiple parameters, one valid and one invalid with pytest.raises(ValueError, match="Invalid formatter parameters"): configure_formatter(max_width=500, not_a_real_param="test") # Test with multiple invalid parameters with pytest.raises(ValueError, match="Invalid formatter parameters"): configure_formatter(fake_param1="test", fake_param2=456) def test_get_dataframe(tmp_path): ctx = SessionContext() path = tmp_path / "test.csv" table = pa.Table.from_arrays( [ [1, 2, 3, 4], ["a", "b", "c", "d"], [1.1, 2.2, 3.3, 4.4], ], names=["int", "str", "float"], ) write_csv(table, path) ctx.register_csv("csv", path) df = ctx.table("csv") assert isinstance(df, DataFrame) def test_struct_select(struct_df): df = struct_df.select( column("a")["c"] + column("b"), column("a")["c"] - column("b"), ) # execute and collect the first (and only) batch result = df.collect()[0] assert result.column(0) == pa.array([5, 7, 9]) assert result.column(1) == pa.array([-3, -3, -3]) def test_explain(df): df = df.select( column("a") + column("b"), column("a") - column("b"), ) df.explain() def test_logical_plan(aggregate_df): plan = aggregate_df.logical_plan() expected = "Projection: test.c1, sum(test.c2)" assert expected == plan.display() expected = ( "Projection: test.c1, sum(test.c2)\n" " 
Aggregate: groupBy=[[test.c1]], aggr=[[sum(test.c2)]]\n" " TableScan: test" ) assert expected == plan.display_indent() def test_optimized_logical_plan(aggregate_df): plan = aggregate_df.optimized_logical_plan() expected = "Aggregate: groupBy=[[test.c1]], aggr=[[sum(test.c2)]]" assert expected == plan.display() expected = ( "Aggregate: groupBy=[[test.c1]], aggr=[[sum(test.c2)]]\n" " TableScan: test projection=[c1, c2]" ) assert expected == plan.display_indent() def test_execution_plan(aggregate_df): plan = aggregate_df.execution_plan() expected = ( "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[sum(test.c2)]\n" ) assert expected == plan.display() # Check the number of partitions is as expected. assert isinstance(plan.partition_count, int) expected = ( "ProjectionExec: expr=[c1@0 as c1, SUM(test.c2)@1 as SUM(test.c2)]\n" " Aggregate: groupBy=[[test.c1]], aggr=[[SUM(test.c2)]]\n" " TableScan: test projection=[c1, c2]" ) indent = plan.display_indent() # indent plan will be different for everyone due to absolute path # to filename, so we just check for some expected content assert "AggregateExec:" in indent assert "CoalesceBatchesExec:" in indent assert "RepartitionExec:" in indent assert "DataSourceExec:" in indent assert "file_type=csv" in indent ctx = SessionContext() rows_returned = 0 for idx in range(plan.partition_count): stream = ctx.execute(plan, idx) try: batch = stream.next() assert batch is not None rows_returned += len(batch.to_pyarrow()[0]) except StopIteration: # This is one of the partitions with no values pass with pytest.raises(StopIteration): stream.next() assert rows_returned == 5 @pytest.mark.asyncio async def test_async_iteration_of_df(aggregate_df): rows_returned = 0 async for batch in aggregate_df.execute_stream(): assert batch is not None rows_returned += len(batch.to_pyarrow()[0]) assert rows_returned == 5 def test_repartition(df): df.repartition(2) def test_repartition_by_hash(df): df.repartition_by_hash(column("a"), num=2) def 
test_intersect(): ctx = SessionContext() batch = pa.RecordBatch.from_arrays( [pa.array([1, 2, 3]), pa.array([4, 5, 6])], names=["a", "b"], ) df_a = ctx.create_dataframe([[batch]]) batch = pa.RecordBatch.from_arrays( [pa.array([3, 4, 5]), pa.array([6, 7, 8])], names=["a", "b"], ) df_b = ctx.create_dataframe([[batch]]) batch = pa.RecordBatch.from_arrays( [pa.array([3]), pa.array([6])], names=["a", "b"], ) df_c = ctx.create_dataframe([[batch]]).sort(column("a")) df_a_i_b = df_a.intersect(df_b).sort(column("a")) assert df_c.collect() == df_a_i_b.collect() def test_except_all(): ctx = SessionContext() batch = pa.RecordBatch.from_arrays( [pa.array([1, 2, 3]), pa.array([4, 5, 6])], names=["a", "b"], ) df_a = ctx.create_dataframe([[batch]]) batch = pa.RecordBatch.from_arrays( [pa.array([3, 4, 5]), pa.array([6, 7, 8])], names=["a", "b"], ) df_b = ctx.create_dataframe([[batch]]) batch = pa.RecordBatch.from_arrays( [pa.array([1, 2]), pa.array([4, 5])], names=["a", "b"], ) df_c = ctx.create_dataframe([[batch]]).sort(column("a")) df_a_e_b = df_a.except_all(df_b).sort(column("a")) assert df_c.collect() == df_a_e_b.collect() def test_collect_partitioned(): ctx = SessionContext() batch = pa.RecordBatch.from_arrays( [pa.array([1, 2, 3]), pa.array([4, 5, 6])], names=["a", "b"], ) assert [[batch]] == ctx.create_dataframe([[batch]]).collect_partitioned() def test_union(ctx): batch = pa.RecordBatch.from_arrays( [pa.array([1, 2, 3]), pa.array([4, 5, 6])], names=["a", "b"], ) df_a = ctx.create_dataframe([[batch]]) batch = pa.RecordBatch.from_arrays( [pa.array([3, 4, 5]), pa.array([6, 7, 8])], names=["a", "b"], ) df_b = ctx.create_dataframe([[batch]]) batch = pa.RecordBatch.from_arrays( [pa.array([1, 2, 3, 3, 4, 5]), pa.array([4, 5, 6, 6, 7, 8])], names=["a", "b"], ) df_c = ctx.create_dataframe([[batch]]).sort(column("a")) df_a_u_b = df_a.union(df_b).sort(column("a")) assert df_c.collect() == df_a_u_b.collect() def test_union_distinct(ctx): batch = pa.RecordBatch.from_arrays( 
[pa.array([1, 2, 3]), pa.array([4, 5, 6])], names=["a", "b"], ) df_a = ctx.create_dataframe([[batch]]) batch = pa.RecordBatch.from_arrays( [pa.array([3, 4, 5]), pa.array([6, 7, 8])], names=["a", "b"], ) df_b = ctx.create_dataframe([[batch]]) batch = pa.RecordBatch.from_arrays( [pa.array([1, 2, 3, 4, 5]), pa.array([4, 5, 6, 7, 8])], names=["a", "b"], ) df_c = ctx.create_dataframe([[batch]]).sort(column("a")) df_a_u_b = df_a.union(df_b, distinct=True).sort(column("a")) assert df_c.collect() == df_a_u_b.collect() assert df_c.collect() == df_a_u_b.collect() def test_cache(df): assert df.cache().collect() == df.collect() def test_count(df): # Get number of rows assert df.count() == 3 def test_to_pandas(df): # Skip test if pandas is not installed pd = pytest.importorskip("pandas") # Convert datafusion dataframe to pandas dataframe pandas_df = df.to_pandas() assert isinstance(pandas_df, pd.DataFrame) assert pandas_df.shape == (3, 3) assert set(pandas_df.columns) == {"a", "b", "c"} def test_empty_to_pandas(df): # Skip test if pandas is not installed pd = pytest.importorskip("pandas") # Convert empty datafusion dataframe to pandas dataframe pandas_df = df.limit(0).to_pandas() assert isinstance(pandas_df, pd.DataFrame) assert pandas_df.shape == (0, 3) assert set(pandas_df.columns) == {"a", "b", "c"} def test_to_polars(df): # Skip test if polars is not installed pl = pytest.importorskip("polars") # Convert datafusion dataframe to polars dataframe polars_df = df.to_polars() assert isinstance(polars_df, pl.DataFrame) assert polars_df.shape == (3, 3) assert set(polars_df.columns) == {"a", "b", "c"} def test_empty_to_polars(df): # Skip test if polars is not installed pl = pytest.importorskip("polars") # Convert empty datafusion dataframe to polars dataframe polars_df = df.limit(0).to_polars() assert isinstance(polars_df, pl.DataFrame) assert polars_df.shape == (0, 3) assert set(polars_df.columns) == {"a", "b", "c"} def test_to_arrow_table(df): # Convert datafusion dataframe to 
pyarrow Table pyarrow_table = df.to_arrow_table() assert isinstance(pyarrow_table, pa.Table) assert pyarrow_table.shape == (3, 3) assert set(pyarrow_table.column_names) == {"a", "b", "c"} def test_execute_stream(df): stream = df.execute_stream() assert all(batch is not None for batch in stream) assert not list(stream) # after one iteration the generator must be exhausted @pytest.mark.asyncio async def test_execute_stream_async(df): stream = df.execute_stream() batches = [batch async for batch in stream] assert all(batch is not None for batch in batches) # After consuming all batches, the stream should be exhausted remaining_batches = [batch async for batch in stream] assert not remaining_batches @pytest.mark.parametrize("schema", [True, False]) def test_execute_stream_to_arrow_table(df, schema): stream = df.execute_stream() if schema: pyarrow_table = pa.Table.from_batches( (batch.to_pyarrow() for batch in stream), schema=df.schema() ) else: pyarrow_table = pa.Table.from_batches(batch.to_pyarrow() for batch in stream) assert isinstance(pyarrow_table, pa.Table) assert pyarrow_table.shape == (3, 3) assert set(pyarrow_table.column_names) == {"a", "b", "c"} @pytest.mark.asyncio @pytest.mark.parametrize("schema", [True, False]) async def test_execute_stream_to_arrow_table_async(df, schema): stream = df.execute_stream() if schema: pyarrow_table = pa.Table.from_batches( [batch.to_pyarrow() async for batch in stream], schema=df.schema() ) else: pyarrow_table = pa.Table.from_batches( [batch.to_pyarrow() async for batch in stream] ) assert isinstance(pyarrow_table, pa.Table) assert pyarrow_table.shape == (3, 3) assert set(pyarrow_table.column_names) == {"a", "b", "c"} def test_execute_stream_partitioned(df): streams = df.execute_stream_partitioned() assert all(batch is not None for stream in streams for batch in stream) assert all( not list(stream) for stream in streams ) # after one iteration all generators must be exhausted @pytest.mark.asyncio async def 
test_execute_stream_partitioned_async(df): streams = df.execute_stream_partitioned() for stream in streams: batches = [batch async for batch in stream] assert all(batch is not None for batch in batches) # Ensure the stream is exhausted after iteration remaining_batches = [batch async for batch in stream] assert not remaining_batches def test_empty_to_arrow_table(df): # Convert empty datafusion dataframe to pyarrow Table pyarrow_table = df.limit(0).to_arrow_table() assert isinstance(pyarrow_table, pa.Table) assert pyarrow_table.shape == (0, 3) assert set(pyarrow_table.column_names) == {"a", "b", "c"} def test_to_pylist(df): # Convert datafusion dataframe to Python list pylist = df.to_pylist() assert isinstance(pylist, list) assert pylist == [ {"a": 1, "b": 4, "c": 8}, {"a": 2, "b": 5, "c": 5}, {"a": 3, "b": 6, "c": 8}, ] def test_to_pydict(df): # Convert datafusion dataframe to Python dictionary pydict = df.to_pydict() assert isinstance(pydict, dict) assert pydict == {"a": [1, 2, 3], "b": [4, 5, 6], "c": [8, 5, 8]} def test_describe(df): # Calculate statistics df = df.describe() # Collect the result result = df.to_pydict() assert result == { "describe": [ "count", "null_count", "mean", "std", "min", "max", "median", ], "a": [3.0, 0.0, 2.0, 1.0, 1.0, 3.0, 2.0], "b": [3.0, 0.0, 5.0, 1.0, 4.0, 6.0, 5.0], "c": [3.0, 0.0, 7.0, 1.7320508075688772, 5.0, 8.0, 8.0], } @pytest.mark.parametrize("path_to_str", [True, False]) def test_write_csv(ctx, df, tmp_path, path_to_str): path = str(tmp_path) if path_to_str else tmp_path df.write_csv(path, with_header=True) ctx.register_csv("csv", path) result = ctx.table("csv").to_pydict() expected = df.to_pydict() assert result == expected @pytest.mark.parametrize("path_to_str", [True, False]) def test_write_json(ctx, df, tmp_path, path_to_str): path = str(tmp_path) if path_to_str else tmp_path df.write_json(path) ctx.register_json("json", path) result = ctx.table("json").to_pydict() expected = df.to_pydict() assert result == expected 
@pytest.mark.parametrize("path_to_str", [True, False]) def test_write_parquet(df, tmp_path, path_to_str): path = str(tmp_path) if path_to_str else tmp_path df.write_parquet(str(path)) result = pq.read_table(str(path)).to_pydict() expected = df.to_pydict() assert result == expected @pytest.mark.parametrize( ("compression", "compression_level"), [("gzip", 6), ("brotli", 7), ("zstd", 15)], ) def test_write_compressed_parquet(df, tmp_path, compression, compression_level): path = tmp_path df.write_parquet( str(path), compression=compression, compression_level=compression_level ) # test that the actual compression scheme is the one written for _root, _dirs, files in os.walk(path): for file in files: if file.endswith(".parquet"): metadata = pq.ParquetFile(tmp_path / file).metadata.to_dict() for row_group in metadata["row_groups"]: for columns in row_group["columns"]: assert columns["compression"].lower() == compression result = pq.read_table(str(path)).to_pydict() expected = df.to_pydict() assert result == expected @pytest.mark.parametrize( ("compression", "compression_level"), [("gzip", 12), ("brotli", 15), ("zstd", 23), ("wrong", 12)], ) def test_write_compressed_parquet_wrong_compression_level( df, tmp_path, compression, compression_level ): path = tmp_path with pytest.raises(ValueError): df.write_parquet( str(path), compression=compression, compression_level=compression_level, ) @pytest.mark.parametrize("compression", ["wrong"]) def test_write_compressed_parquet_invalid_compression(df, tmp_path, compression): path = tmp_path with pytest.raises(ValueError): df.write_parquet(str(path), compression=compression) # not testing lzo because it it not implemented yet # https://github.com/apache/arrow-rs/issues/6970 @pytest.mark.parametrize("compression", ["zstd", "brotli", "gzip"]) def test_write_compressed_parquet_default_compression_level(df, tmp_path, compression): # Test write_parquet with zstd, brotli, gzip default compression level, # ie don't specify compression level 
# should complete without error path = tmp_path df.write_parquet(str(path), compression=compression) def test_write_parquet_with_options_default_compression(df, tmp_path): """Test that the default compression is ZSTD.""" df.write_parquet(tmp_path) for file in tmp_path.rglob("*.parquet"): metadata = pq.ParquetFile(file).metadata.to_dict() for row_group in metadata["row_groups"]: for col in row_group["columns"]: assert col["compression"].lower() == "zstd" @pytest.mark.parametrize( "compression", ["gzip(6)", "brotli(7)", "zstd(15)", "snappy", "uncompressed"], ) def test_write_parquet_with_options_compression(df, tmp_path, compression): import re path = tmp_path df.write_parquet_with_options( str(path), ParquetWriterOptions(compression=compression) ) # test that the actual compression scheme is the one written for _root, _dirs, files in os.walk(path): for file in files: if file.endswith(".parquet"): metadata = pq.ParquetFile(tmp_path / file).metadata.to_dict() for row_group in metadata["row_groups"]: for col in row_group["columns"]: assert col["compression"].lower() == re.sub( r"\(\d+\)", "", compression ) result = pq.read_table(str(path)).to_pydict() expected = df.to_pydict() assert result == expected @pytest.mark.parametrize( "compression", ["gzip(12)", "brotli(15)", "zstd(23)"], ) def test_write_parquet_with_options_wrong_compression_level(df, tmp_path, compression): path = tmp_path with pytest.raises(Exception, match=r"valid compression range .*? 
exceeded."): df.write_parquet_with_options( str(path), ParquetWriterOptions(compression=compression) ) @pytest.mark.parametrize("compression", ["wrong", "wrong(12)"]) def test_write_parquet_with_options_invalid_compression(df, tmp_path, compression): path = tmp_path with pytest.raises(Exception, match="Unknown or unsupported parquet compression"): df.write_parquet_with_options( str(path), ParquetWriterOptions(compression=compression) ) @pytest.mark.parametrize( ("writer_version", "format_version"), [("1.0", "1.0"), ("2.0", "2.6"), (None, "1.0")], ) def test_write_parquet_with_options_writer_version( df, tmp_path, writer_version, format_version ): """Test the Parquet writer version. Note that writer_version=2.0 results in format_version=2.6""" if writer_version is None: df.write_parquet_with_options(tmp_path, ParquetWriterOptions()) else: df.write_parquet_with_options( tmp_path, ParquetWriterOptions(writer_version=writer_version) ) for file in tmp_path.rglob("*.parquet"): parquet = pq.ParquetFile(file) metadata = parquet.metadata.to_dict() assert metadata["format_version"] == format_version @pytest.mark.parametrize("writer_version", ["1.2.3", "custom-version", "0"]) def test_write_parquet_with_options_wrong_writer_version(df, tmp_path, writer_version): """Test that invalid writer versions in Parquet throw an exception.""" with pytest.raises( Exception, match="Unknown or unsupported parquet writer version" ): df.write_parquet_with_options( tmp_path, ParquetWriterOptions(writer_version=writer_version) ) @pytest.mark.parametrize("dictionary_enabled", [True, False, None]) def test_write_parquet_with_options_dictionary_enabled( df, tmp_path, dictionary_enabled ): """Test enabling/disabling the dictionaries in Parquet.""" df.write_parquet_with_options( tmp_path, ParquetWriterOptions(dictionary_enabled=dictionary_enabled) ) # by default, the dictionary is enabled, so None results in True result = dictionary_enabled if dictionary_enabled is not None else True for file in 
tmp_path.rglob("*.parquet"): parquet = pq.ParquetFile(file) metadata = parquet.metadata.to_dict() for row_group in metadata["row_groups"]: for col in row_group["columns"]: assert col["has_dictionary_page"] == result @pytest.mark.parametrize( ("statistics_enabled", "has_statistics"), [("page", True), ("chunk", True), ("none", False), (None, True)], ) def test_write_parquet_with_options_statistics_enabled( df, tmp_path, statistics_enabled, has_statistics ): """Test configuring the statistics in Parquet. In pyarrow we can only check for column-level statistics, so "page" and "chunk" are tested in the same way.""" df.write_parquet_with_options( tmp_path, ParquetWriterOptions(statistics_enabled=statistics_enabled) ) for file in tmp_path.rglob("*.parquet"): parquet = pq.ParquetFile(file) metadata = parquet.metadata.to_dict() for row_group in metadata["row_groups"]: for col in row_group["columns"]: if has_statistics: assert col["statistics"] is not None else: assert col["statistics"] is None @pytest.mark.parametrize("max_row_group_size", [1000, 5000, 10000, 100000]) def test_write_parquet_with_options_max_row_group_size( large_df, tmp_path, max_row_group_size ): """Test configuring the max number of rows per group in Parquet. 
These test cases guarantee that the number of rows for each row group is max_row_group_size, given the total number of rows is a multiple of max_row_group_size.""" path = f"{tmp_path}/t.parquet" large_df.write_parquet_with_options( path, ParquetWriterOptions(max_row_group_size=max_row_group_size) ) parquet = pq.ParquetFile(path) metadata = parquet.metadata.to_dict() for row_group in metadata["row_groups"]: assert row_group["num_rows"] == max_row_group_size @pytest.mark.parametrize("created_by", ["datafusion", "datafusion-python", "custom"]) def test_write_parquet_with_options_created_by(df, tmp_path, created_by): """Test configuring the created by metadata in Parquet.""" df.write_parquet_with_options(tmp_path, ParquetWriterOptions(created_by=created_by)) for file in tmp_path.rglob("*.parquet"): parquet = pq.ParquetFile(file) metadata = parquet.metadata.to_dict() assert metadata["created_by"] == created_by @pytest.mark.parametrize("statistics_truncate_length", [5, 25, 50]) def test_write_parquet_with_options_statistics_truncate_length( df, tmp_path, statistics_truncate_length ): """Test configuring the truncate limit in Parquet's row-group-level statistics.""" ctx = SessionContext() data = { "a": [ "a_the_quick_brown_fox_jumps_over_the_lazy_dog", "m_the_quick_brown_fox_jumps_over_the_lazy_dog", "z_the_quick_brown_fox_jumps_over_the_lazy_dog", ], "b": ["a_smaller", "m_smaller", "z_smaller"], } df = ctx.from_arrow(pa.record_batch(data)) df.write_parquet_with_options( tmp_path, ParquetWriterOptions(statistics_truncate_length=statistics_truncate_length), ) for file in tmp_path.rglob("*.parquet"): parquet = pq.ParquetFile(file) metadata = parquet.metadata.to_dict() for row_group in metadata["row_groups"]: for col in row_group["columns"]: statistics = col["statistics"] assert len(statistics["min"]) <= statistics_truncate_length assert len(statistics["max"]) <= statistics_truncate_length def test_write_parquet_with_options_default_encoding(tmp_path): """Test that, by 
default, Parquet files are written with dictionary encoding. Note that dictionary encoding is not used for boolean values, so it is not tested here.""" ctx = SessionContext() data = { "a": [1, 2, 3], "b": ["1", "2", "3"], "c": [1.01, 2.02, 3.03], } df = ctx.from_arrow(pa.record_batch(data)) df.write_parquet_with_options(tmp_path, ParquetWriterOptions()) for file in tmp_path.rglob("*.parquet"): parquet = pq.ParquetFile(file) metadata = parquet.metadata.to_dict() for row_group in metadata["row_groups"]: for col in row_group["columns"]: assert col["encodings"] == ("PLAIN", "RLE", "RLE_DICTIONARY") @pytest.mark.parametrize( ("encoding", "data_types", "result"), [ ("plain", ["int", "float", "str", "bool"], ("PLAIN", "RLE")), ("rle", ["bool"], ("RLE",)), ("delta_binary_packed", ["int"], ("RLE", "DELTA_BINARY_PACKED")), ("delta_length_byte_array", ["str"], ("RLE", "DELTA_LENGTH_BYTE_ARRAY")), ("delta_byte_array", ["str"], ("RLE", "DELTA_BYTE_ARRAY")), ("byte_stream_split", ["int", "float"], ("RLE", "BYTE_STREAM_SPLIT")), ], ) def test_write_parquet_with_options_encoding(tmp_path, encoding, data_types, result): """Test different encodings in Parquet in their respective support column types.""" ctx = SessionContext() data = {} for data_type in data_types: if data_type == "int": data["int"] = [1, 2, 3] elif data_type == "float": data["float"] = [1.01, 2.02, 3.03] elif data_type == "str": data["str"] = ["a", "b", "c"] elif data_type == "bool": data["bool"] = [True, False, True] df = ctx.from_arrow(pa.record_batch(data)) df.write_parquet_with_options( tmp_path, ParquetWriterOptions(encoding=encoding, dictionary_enabled=False) ) for file in tmp_path.rglob("*.parquet"): parquet = pq.ParquetFile(file) metadata = parquet.metadata.to_dict() for row_group in metadata["row_groups"]: for col in row_group["columns"]: assert col["encodings"] == result @pytest.mark.parametrize("encoding", ["bit_packed"]) def test_write_parquet_with_options_unsupported_encoding(df, tmp_path, encoding): 
"""Test that unsupported Parquet encodings do not work.""" # BaseException is used since this throws a Rust panic: https://github.com/PyO3/pyo3/issues/3519 with pytest.raises(BaseException, match="Encoding .*? is not supported"): df.write_parquet_with_options(tmp_path, ParquetWriterOptions(encoding=encoding)) @pytest.mark.parametrize("encoding", ["non_existent", "unknown", "plain123"]) def test_write_parquet_with_options_invalid_encoding(df, tmp_path, encoding): """Test that invalid Parquet encodings do not work.""" with pytest.raises(Exception, match="Unknown or unsupported parquet encoding"): df.write_parquet_with_options(tmp_path, ParquetWriterOptions(encoding=encoding)) @pytest.mark.parametrize("encoding", ["plain_dictionary", "rle_dictionary"]) def test_write_parquet_with_options_dictionary_encoding_fallback( df, tmp_path, encoding ): """Test that the dictionary encoding cannot be used as fallback in Parquet.""" # BaseException is used since this throws a Rust panic: https://github.com/PyO3/pyo3/issues/3519 with pytest.raises( BaseException, match="Dictionary encoding can not be used as fallback encoding" ): df.write_parquet_with_options(tmp_path, ParquetWriterOptions(encoding=encoding)) def test_write_parquet_with_options_bloom_filter(df, tmp_path): """Test Parquet files with and without (default) bloom filters. 
Since pyarrow does not expose any information about bloom filters, the easiest way to confirm that they are actually written is to compare the file size.""" path_no_bloom_filter = tmp_path / "1" path_bloom_filter = tmp_path / "2" df.write_parquet_with_options(path_no_bloom_filter, ParquetWriterOptions()) df.write_parquet_with_options( path_bloom_filter, ParquetWriterOptions(bloom_filter_on_write=True) ) size_no_bloom_filter = 0 for file in path_no_bloom_filter.rglob("*.parquet"): size_no_bloom_filter += os.path.getsize(file) size_bloom_filter = 0 for file in path_bloom_filter.rglob("*.parquet"): size_bloom_filter += os.path.getsize(file) assert size_no_bloom_filter < size_bloom_filter def test_write_parquet_with_options_column_options(df, tmp_path): """Test writing Parquet files with different options for each column, which replace the global configs (when provided).""" data = { "a": [1, 2, 3], "b": ["a", "b", "c"], "c": [False, True, False], "d": [1.01, 2.02, 3.03], "e": [4, 5, 6], } column_specific_options = { "a": ParquetColumnOptions(statistics_enabled="none"), "b": ParquetColumnOptions(encoding="plain", dictionary_enabled=False), "c": ParquetColumnOptions( compression="snappy", encoding="rle", dictionary_enabled=False ), "d": ParquetColumnOptions( compression="zstd(6)", encoding="byte_stream_split", dictionary_enabled=False, statistics_enabled="none", ), # column "e" will use the global configs } results = { "a": { "statistics": False, "compression": "brotli", "encodings": ("PLAIN", "RLE", "RLE_DICTIONARY"), }, "b": { "statistics": True, "compression": "brotli", "encodings": ("PLAIN", "RLE"), }, "c": { "statistics": True, "compression": "snappy", "encodings": ("RLE",), }, "d": { "statistics": False, "compression": "zstd", "encodings": ("RLE", "BYTE_STREAM_SPLIT"), }, "e": { "statistics": True, "compression": "brotli", "encodings": ("PLAIN", "RLE", "RLE_DICTIONARY"), }, } ctx = SessionContext() df = ctx.from_arrow(pa.record_batch(data)) 
df.write_parquet_with_options( tmp_path, ParquetWriterOptions( compression="brotli(8)", column_specific_options=column_specific_options ), ) for file in tmp_path.rglob("*.parquet"): parquet = pq.ParquetFile(file) metadata = parquet.metadata.to_dict() for row_group in metadata["row_groups"]: for col in row_group["columns"]: column_name = col["path_in_schema"] result = results[column_name] assert (col["statistics"] is not None) == result["statistics"] assert col["compression"].lower() == result["compression"].lower() assert col["encodings"] == result["encodings"] def test_write_parquet_options(df, tmp_path): options = ParquetWriterOptions(compression="gzip", compression_level=6) df.write_parquet(str(tmp_path), options) result = pq.read_table(str(tmp_path)).to_pydict() expected = df.to_pydict() assert result == expected def test_write_parquet_options_error(df, tmp_path): options = ParquetWriterOptions(compression="gzip", compression_level=6) with pytest.raises(ValueError): df.write_parquet(str(tmp_path), options, compression_level=1) def test_dataframe_export(df) -> None: # Guarantees that we have the canonical implementation # reading our dataframe export table = pa.table(df) assert table.num_columns == 3 assert table.num_rows == 3 desired_schema = pa.schema([("a", pa.int64())]) # Verify we can request a schema table = pa.table(df, schema=desired_schema) assert table.num_columns == 1 assert table.num_rows == 3 # Expect a table of nulls if the schema don't overlap desired_schema = pa.schema([("g", pa.string())]) table = pa.table(df, schema=desired_schema) assert table.num_columns == 1 assert table.num_rows == 3 for i in range(3): assert table[0][i].as_py() is None # Expect an error when we cannot convert schema desired_schema = pa.schema([("a", pa.float32())]) failed_convert = False try: table = pa.table(df, schema=desired_schema) except Exception: failed_convert = True assert failed_convert # Expect an error when we have a not set non-nullable desired_schema = 
pa.schema([("g", pa.string(), False)]) failed_convert = False try: table = pa.table(df, schema=desired_schema) except Exception: failed_convert = True assert failed_convert def test_dataframe_transform(df): def add_string_col(df_internal) -> DataFrame: return df_internal.with_column("string_col", literal("string data")) def add_with_parameter(df_internal, value: Any) -> DataFrame: return df_internal.with_column("new_col", literal(value)) df = df.transform(add_string_col).transform(add_with_parameter, 3) result = df.to_pydict() assert result["a"] == [1, 2, 3] assert result["string_col"] == ["string data" for _i in range(3)] assert result["new_col"] == [3 for _i in range(3)] def test_dataframe_repr_html_structure(df, clean_formatter_state) -> None: """Test that DataFrame._repr_html_ produces expected HTML output structure.""" output = df._repr_html_() # Since we've added a fair bit of processing to the html output, lets just verify # the values we are expecting in the table exist. Use regex and ignore everything # between the and . We also don't want the closing > on the # td and th segments because that is where the formatting data is written. headers = ["a", "b", "c"] headers = [f"{v}" for v in headers] header_pattern = "(.*?)".join(headers) header_matches = re.findall(header_pattern, output, re.DOTALL) assert len(header_matches) == 1 # Update the pattern to handle values that may be wrapped in spans body_data = [[1, 4, 8], [2, 5, 5], [3, 6, 8]] body_lines = [ f"(?:]*?>)?{v}(?:)?" 
for inner in body_data for v in inner ] body_pattern = "(.*?)".join(body_lines) body_matches = re.findall(body_pattern, output, re.DOTALL) assert len(body_matches) == 1, "Expected pattern of values not found in HTML output" def test_dataframe_repr_html_values(df, clean_formatter_state): """Test that DataFrame._repr_html_ contains the expected data values.""" html = df._repr_html_() assert html is not None # Create a more flexible pattern that handles values being wrapped in spans # This pattern will match the sequence of values 1,4,8,2,5,5,3,6,8 regardless # of formatting pattern = re.compile( r"]*?>(?:]*?>)?1(?:)?.*?" r"]*?>(?:]*?>)?4(?:)?.*?" r"]*?>(?:]*?>)?8(?:)?.*?" r"]*?>(?:]*?>)?2(?:)?.*?" r"]*?>(?:]*?>)?5(?:)?.*?" r"]*?>(?:]*?>)?5(?:)?.*?" r"]*?>(?:]*?>)?3(?:)?.*?" r"]*?>(?:]*?>)?6(?:)?.*?" r"]*?>(?:]*?>)?8(?:)?", re.DOTALL, ) # Print debug info if the test fails matches = re.findall(pattern, html) if not matches: print(f"HTML output snippet: {html[:500]}...") # noqa: T201 assert len(matches) > 0, "Expected pattern of values not found in HTML output" def test_html_formatter_shared_styles(df, clean_formatter_state): """Test that shared styles work correctly across multiple tables.""" # First, ensure we're using shared styles configure_formatter(use_shared_styles=True) html_first = df._repr_html_() html_second = df._repr_html_() assert "