From 73a9d53a37f6ce864b68dda1b07e92a0fed8c8ba Mon Sep 17 00:00:00 2001 From: Kevin Liu Date: Tue, 31 Mar 2026 01:57:32 -0700 Subject: [PATCH 01/29] CI: Add CodeQL workflow for GitHub Actions security scanning (#1408) * CI: Add CodeQL workflow for GitHub Actions security scanning * Update .github/workflows/codeql.yml --- .github/workflows/codeql.yml | 54 ++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 .github/workflows/codeql.yml diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 000000000..a9855cf48 --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +name: "CodeQL" + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + schedule: + - cron: '16 4 * * 1' + +permissions: + contents: read + +jobs: + analyze: + name: Analyze Actions + runs-on: ubuntu-latest + permissions: + contents: read + security-events: write + + steps: + - name: Checkout repository + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + with: + persist-credentials: false + + - name: Initialize CodeQL + uses: github/codeql-action/init@c793b717bc78562f491db7b0e93a3a178b099162 # v4 + with: + languages: actions + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@c793b717bc78562f491db7b0e93a3a178b099162 # v4 + with: + category: "/language:actions" From 24994099e41a4e933f883557e2bce1a963bac0ea Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 31 Mar 2026 14:09:16 -0400 Subject: [PATCH 02/29] ci: update codespell paths (#1469) * Update path so it works well with pre-commit * Prefix path with asterisk so we get matching in both CI and pre-commit * Update paths for codespell --- pyproject.toml | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d05a64083..327199d1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -170,12 +170,15 @@ extend-allowed-calls = ["datafusion.lit", "lit"] "docs/*" = ["D"] "docs/source/conf.py" = ["ANN001", "ERA001", "INP001"] +# CI and pre-commit invoke codespell with different paths, so we have a little +# redundancy here, and we intentionally drop python in the path. 
[tool.codespell] skip = [ - "./python/tests/test_functions.py", - "./target", + "*/tests/test_functions.py", + "*/target", + "./uv.lock", "uv.lock", - "./examples/tpch/answers_sf1/*", + "*/tpch/answers_sf1/*", ] count = true ignore-words-list = ["IST", "ans"] From 0113a6ee55cc61f9ebd897ae8cfc9213f560e468 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 2 Apr 2026 17:47:47 -0400 Subject: [PATCH 03/29] Add missing datetime functions (#1467) * Add missing datetime functions: make_time, current_timestamp, date_format Closes #1451. Adds make_time Rust binding and Python wrapper, and adds current_timestamp (alias for now) and date_format (alias for to_char) Python functions. Co-Authored-By: Claude Opus 4.6 (1M context) * Add unit tests for make_time, current_timestamp, and date_format Co-Authored-By: Claude Opus 4.6 (1M context) --------- Co-authored-by: Claude Opus 4.6 (1M context) --- crates/core/src/functions.rs | 2 ++ python/datafusion/functions.py | 36 ++++++++++++++++++++++++++++++++++ python/tests/test_functions.py | 33 +++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+) diff --git a/crates/core/src/functions.rs b/crates/core/src/functions.rs index c32134054..6996dca94 100644 --- a/crates/core/src/functions.rs +++ b/crates/core/src/functions.rs @@ -616,6 +616,7 @@ expr_fn!(date_part, part date); expr_fn!(date_trunc, part date); expr_fn!(date_bin, stride source origin); expr_fn!(make_date, year month day); +expr_fn!(make_time, hour minute second); expr_fn!(to_char, datetime format); expr_fn!(translate, string from to, "Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted."); @@ -974,6 +975,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(date_part))?; m.add_wrapped(wrap_pyfunction!(date_trunc))?; m.add_wrapped(wrap_pyfunction!(make_date))?; + m.add_wrapped(wrap_pyfunction!(make_time))?; m.add_wrapped(wrap_pyfunction!(digest))?; m.add_wrapped(wrap_pyfunction!(ends_with))?; m.add_wrapped(wrap_pyfunction!(exp))?; diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index f062cbfce..3c8d2bcee 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -128,7 +128,9 @@ "cume_dist", "current_date", "current_time", + "current_timestamp", "date_bin", + "date_format", "date_part", "date_trunc", "datepart", @@ -200,6 +202,7 @@ "make_array", "make_date", "make_list", + "make_time", "max", "md5", "mean", @@ -1948,6 +1951,15 @@ def now() -> Expr: return Expr(f.now()) +def current_timestamp() -> Expr: + """Returns the current timestamp in nanoseconds. + + See Also: + This is an alias for :py:func:`now`. + """ + return now() + + def to_char(arg: Expr, formatter: Expr) -> Expr: """Returns a string representation of a date, time, timestamp or duration. @@ -1970,6 +1982,15 @@ def to_char(arg: Expr, formatter: Expr) -> Expr: return Expr(f.to_char(arg.expr, formatter.expr)) +def date_format(arg: Expr, formatter: Expr) -> Expr: + """Returns a string representation of a date, time, timestamp or duration. + + See Also: + This is an alias for :py:func:`to_char`. 
+ """ + return to_char(arg, formatter) + + def _unwrap_exprs(args: tuple[Expr, ...]) -> list: return [arg.expr for arg in args] @@ -2270,6 +2291,21 @@ def make_date(year: Expr, month: Expr, day: Expr) -> Expr: return Expr(f.make_date(year.expr, month.expr, day.expr)) +def make_time(hour: Expr, minute: Expr, second: Expr) -> Expr: + """Make a time from hour, minute and second component parts. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"h": [12], "m": [30], "s": [0]}) + >>> result = df.select( + ... dfn.functions.make_time(dfn.col("h"), dfn.col("m"), + ... dfn.col("s")).alias("t")) + >>> result.collect_column("t")[0].as_py() + datetime.time(12, 30) + """ + return Expr(f.make_time(hour.expr, minute.expr, second.expr)) + + def translate(string: Expr, from_val: Expr, to_val: Expr) -> Expr: """Replaces the characters in ``from_val`` with the counterpart in ``to_val``. diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 37d349c58..08420826d 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -1107,6 +1107,39 @@ def test_today_alias_matches_current_date(df): assert result.column(0) == result.column(1) +def test_current_timestamp_alias_matches_now(df): + result = df.select( + f.now().alias("now"), + f.current_timestamp().alias("current_timestamp"), + ).collect()[0] + + assert result.column(0) == result.column(1) + + +def test_date_format_alias_matches_to_char(df): + result = df.select( + f.to_char( + f.to_timestamp(literal("2021-01-01T00:00:00")), literal("%Y/%m/%d") + ).alias("to_char"), + f.date_format( + f.to_timestamp(literal("2021-01-01T00:00:00")), literal("%Y/%m/%d") + ).alias("date_format"), + ).collect()[0] + + assert result.column(0) == result.column(1) + assert result.column(0)[0].as_py() == "2021/01/01" + + +def test_make_time(df): + ctx = SessionContext() + df_time = ctx.from_pydict({"h": [12], "m": [30], "s": [0]}) + result = df_time.select( + f.make_time(column("h"), column("m"), column("s")).alias("t") + ).collect()[0] + + assert result.column(0)[0].as_py() == time(12, 30) + + def test_arrow_cast(df): df = df.select( # we use `string_literal` to return utf8 instead of `literal` which returns From be8dd9d08fd284cf1747a2c1b965d9c95fff117c Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 3 Apr 2026 09:37:00 -0400 Subject: [PATCH 04/29] Add AI skill to check current repository against upstream APIs (#1460) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Initial commit for skill to check upstream repo * Add instructions on using the check-upstream skill * Add FFI type coverage and implementation pattern to check-upstream skill Document the full FFI type pipeline (Rust PyO3 wrapper → Protocol type → Python wrapper → ABC base class → exports → example) and catalog which upstream datafusion-ffi types are supported, which have been evaluated as not needing direct exposure, and how to check for new gaps. Co-Authored-By: Claude Opus 4.6 (1M context) * Update check-upstream skill to include FFI types as a checkable area Add "ffi types" to the argument-hint and description so users can invoke the skill with `/check-upstream ffi types`. Also add pipeline verification step to ensure each supported FFI type has the full end-to-end chain (PyO3 wrapper, Protocol, Python wrapper with type hints, ABC, exports). 
Co-Authored-By: Claude Opus 4.6 (1M context) * Move FFI Types section alongside other areas to check Section 7 (FFI Types) was incorrectly placed after the Output Format and Implementation Pattern sections. Move it to sit after Section 6 (SessionContext Methods), consistent with the other checkable areas. Co-Authored-By: Claude Opus 4.6 (1M context) * Replace static FFI type list with dynamic discovery instruction The supported FFI types list would go stale as new types are added. Replace it with a grep instruction to discover them at check time, keeping only the "evaluated and not requiring exposure" list which captures rationale not derivable from code. Co-Authored-By: Claude Opus 4.6 (1M context) * Make Python API the source of truth for upstream coverage checks Functions exposed in Python (e.g., as aliases of other Rust bindings) were being falsely reported as missing because they lacked a dedicated #[pyfunction] in Rust. The user-facing API is the Python layer, so coverage should be measured there. Co-Authored-By: Claude Opus 4.6 (1M context) * Add exclusion list for DataFrame methods already covered by Python API show_limit is covered by DataFrame.show() and with_param_values is covered by SessionContext.sql(param_values=...), so neither needs separate exposure. Co-Authored-By: Claude Opus 4.6 (1M context) * Move skills to .ai/skills/ for tool-agnostic discoverability Moves the canonical skill definitions from .claude/skills/ to .ai/skills/ and replaces .claude/skills with a symlink, so Claude Code still discovers them while other AI agents can find them in a tool-neutral location. Co-Authored-By: Claude Opus 4.6 (1M context) * Add AGENTS.md for tool-agnostic agent instructions with CLAUDE.md symlink AGENTS.md points agents to .ai/skills/ for skill discovery. CLAUDE.md symlinks to it so Claude Code picks it up as project instructions. Co-Authored-By: Claude Opus 4.6 (1M context) * Make README upstream coverage section tool-agnostic Remove Claude Code references and update skill path from .claude/skills/ to .ai/skills/ to match the new tool-neutral directory structure. Co-Authored-By: Claude Opus 4.6 (1M context) * Add GitHub issue lookup step to check-upstream skill When gaps are identified, search open issues at apache/datafusion-python before reporting. Existing issues are linked in the report rather than duplicated. Co-Authored-By: Claude Opus 4.6 (1M context) * Require Python test coverage in issues created by check-upstream skill Co-Authored-By: Claude Opus 4.6 (1M context) * Add license text --------- Co-authored-by: Claude Opus 4.6 (1M context) --- .ai/skills/check-upstream/SKILL.md | 382 +++++++++++++++++++++++++++++ .claude/skills | 1 + AGENTS.md | 27 ++ CLAUDE.md | 1 + README.md | 27 ++ 5 files changed, 438 insertions(+) create mode 100644 .ai/skills/check-upstream/SKILL.md create mode 120000 .claude/skills create mode 100644 AGENTS.md create mode 120000 CLAUDE.md diff --git a/.ai/skills/check-upstream/SKILL.md b/.ai/skills/check-upstream/SKILL.md new file mode 100644 index 000000000..f77210371 --- /dev/null +++ b/.ai/skills/check-upstream/SKILL.md @@ -0,0 +1,382 @@ + + +--- +name: check-upstream +description: Check if upstream Apache DataFusion features (functions, DataFrame ops, SessionContext methods, FFI types) are exposed in this Python project. Use when adding missing functions, auditing API coverage, or ensuring parity with upstream. 
+argument-hint: [area] (e.g., "scalar functions", "aggregate functions", "window functions", "dataframe", "session context", "ffi types", "all") +--- + +# Check Upstream DataFusion Feature Coverage + +You are auditing the datafusion-python project to find features from the upstream Apache DataFusion Rust library that are **not yet exposed** in this Python binding project. Your goal is to identify gaps and, if asked, implement the missing bindings. + +**IMPORTANT: The Python API is the source of truth for coverage.** A function or method is considered "exposed" if it exists in the Python API (e.g., `python/datafusion/functions.py`), even if there is no corresponding entry in the Rust bindings. Many upstream functions are aliases of other functions — the Python layer can expose these aliases by calling a different underlying Rust binding. Do NOT report a function as missing if it appears in the Python `__all__` list and has a working implementation, regardless of whether a matching `#[pyfunction]` exists in Rust. + +## Areas to Check + +The user may specify an area via `$ARGUMENTS`. If no area is specified or "all" is given, check all areas. + +### 1. Scalar Functions + +**Upstream source of truth:** +- Rust docs: https://docs.rs/datafusion/latest/datafusion/functions/index.html +- User docs: https://datafusion.apache.org/user-guide/sql/scalar_functions.html + +**Where they are exposed in this project:** +- Python API: `python/datafusion/functions.py` — each function wraps a call to `datafusion._internal.functions` +- Rust bindings: `crates/core/src/functions.rs` — `#[pyfunction]` definitions registered via `init_module()` + +**How to check:** +1. Fetch the upstream scalar function documentation page +2. Compare against functions listed in `python/datafusion/functions.py` (check the `__all__` list and function definitions) +3. A function is covered if it exists in the Python API — it does NOT need a dedicated Rust `#[pyfunction]`. Many functions are aliases that reuse another function's Rust binding. +4. Only report functions that are missing from the Python `__all__` list / function definitions + +### 2. Aggregate Functions + +**Upstream source of truth:** +- Rust docs: https://docs.rs/datafusion/latest/datafusion/functions_aggregate/index.html +- User docs: https://datafusion.apache.org/user-guide/sql/aggregate_functions.html + +**Where they are exposed in this project:** +- Python API: `python/datafusion/functions.py` (aggregate functions are mixed in with scalar functions) +- Rust bindings: `crates/core/src/functions.rs` + +**How to check:** +1. Fetch the upstream aggregate function documentation page +2. Compare against aggregate functions in `python/datafusion/functions.py` (check `__all__` list and function definitions) +3. A function is covered if it exists in the Python API, even if it aliases another function's Rust binding +4. Report only functions missing from the Python API + +### 3. Window Functions + +**Upstream source of truth:** +- Rust docs: https://docs.rs/datafusion/latest/datafusion/functions_window/index.html +- User docs: https://datafusion.apache.org/user-guide/sql/window_functions.html + +**Where they are exposed in this project:** +- Python API: `python/datafusion/functions.py` (window functions like `rank`, `dense_rank`, `lag`, `lead`, etc.) +- Rust bindings: `crates/core/src/functions.rs` + +**How to check:** +1. Fetch the upstream window function documentation page +2. 
Compare against window functions in `python/datafusion/functions.py` (check `__all__` list and function definitions) +3. A function is covered if it exists in the Python API, even if it aliases another function's Rust binding +4. Report only functions missing from the Python API + +### 4. Table Functions + +**Upstream source of truth:** +- Rust docs: https://docs.rs/datafusion/latest/datafusion/functions_table/index.html +- User docs: https://datafusion.apache.org/user-guide/sql/table_functions.html (if available) + +**Where they are exposed in this project:** +- Python API: `python/datafusion/functions.py` and `python/datafusion/user_defined.py` (TableFunction/udtf) +- Rust bindings: `crates/core/src/functions.rs` and `crates/core/src/udtf.rs` + +**How to check:** +1. Fetch the upstream table function documentation +2. Compare against what's available in the Python API +3. A function is covered if it exists in the Python API, even if it aliases another function's Rust binding +4. Report only functions missing from the Python API + +### 5. DataFrame Operations + +**Upstream source of truth:** +- Rust docs: https://docs.rs/datafusion/latest/datafusion/dataframe/struct.DataFrame.html + +**Where they are exposed in this project:** +- Python API: `python/datafusion/dataframe.py` — the `DataFrame` class +- Rust bindings: `crates/core/src/dataframe.rs` — `PyDataFrame` with `#[pymethods]` + +**Evaluated and not requiring separate Python exposure:** +- `show_limit` — already covered by `DataFrame.show()`, which provides the same functionality with a simpler API +- `with_param_values` — already covered by the `param_values` argument on `SessionContext.sql()`, which accomplishes the same thing more robustly + +**How to check:** +1. Fetch the upstream DataFrame documentation page listing all methods +2. Compare against methods in `python/datafusion/dataframe.py` — this is the source of truth for coverage +3. The Rust bindings (`crates/core/src/dataframe.rs`) may be consulted for context, but a method is covered if it exists in the Python API +4. Check against the "evaluated and not requiring exposure" list before flagging as a gap +5. Report only methods missing from the Python API + +### 6. SessionContext Methods + +**Upstream source of truth:** +- Rust docs: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html + +**Where they are exposed in this project:** +- Python API: `python/datafusion/context.py` — the `SessionContext` class +- Rust bindings: `crates/core/src/context.rs` — `PySessionContext` with `#[pymethods]` + +**How to check:** +1. Fetch the upstream SessionContext documentation page listing all methods +2. Compare against methods in `python/datafusion/context.py` — this is the source of truth for coverage +3. The Rust bindings (`crates/core/src/context.rs`) may be consulted for context, but a method is covered if it exists in the Python API +4. Report only methods missing from the Python API + +### 7. 
FFI Types (datafusion-ffi) + +**Upstream source of truth:** +- Crate source: https://github.com/apache/datafusion/tree/main/datafusion/ffi/src +- Rust docs: https://docs.rs/datafusion-ffi/latest/datafusion_ffi/ + +**Where they are exposed in this project:** +- Rust bindings: various files under `crates/core/src/` and `crates/util/src/` +- FFI example: `examples/datafusion-ffi-example/src/` +- Dependency declared in root `Cargo.toml` and `crates/core/Cargo.toml` + +**Discovering currently supported FFI types:** +Grep for `use datafusion_ffi::` in `crates/core/src/` and `crates/util/src/` to find all FFI types currently imported and used. + +**Evaluated and not requiring direct Python exposure:** +These upstream FFI types have been reviewed and do not need to be independently exposed to end users: +- `FFI_ExecutionPlan` — already used indirectly through table providers; no need for direct exposure +- `FFI_PhysicalExpr` / `FFI_PhysicalSortExpr` — internal physical planning types not expected to be needed by end users +- `FFI_RecordBatchStream` — one level deeper than FFI_ExecutionPlan, used internally when execution plans stream results +- `FFI_SessionRef` / `ForeignSession` — session sharing across FFI; Python manages sessions natively via SessionContext +- `FFI_SessionConfig` — Python can configure sessions natively without FFI +- `FFI_ConfigOptions` / `FFI_TableOptions` — internal configuration plumbing +- `FFI_PlanProperties` / `FFI_Boundedness` / `FFI_EmissionType` — read from existing plans, not user-facing +- `FFI_Partitioning` — supporting type for physical planning +- Supporting/utility types (`FFI_Option`, `FFI_Result`, `WrappedSchema`, `WrappedArray`, `FFI_ColumnarValue`, `FFI_Volatility`, `FFI_InsertOp`, `FFI_AccumulatorArgs`, `FFI_Accumulator`, `FFI_GroupsAccumulator`, `FFI_EmitTo`, `FFI_AggregateOrderSensitivity`, `FFI_PartitionEvaluator`, `FFI_PartitionEvaluatorArgs`, `FFI_Range`, `FFI_SortOptions`, `FFI_Distribution`, `FFI_ExprProperties`, `FFI_SortProperties`, `FFI_Interval`, `FFI_TableProviderFilterPushDown`, `FFI_TableType`) — used as building blocks within the types above, not independently exposed + +**How to check:** +1. Discover currently supported types by grepping for `use datafusion_ffi::` in `crates/core/src/` and `crates/util/src/`, then compare against the upstream `datafusion-ffi` crate's `lib.rs` exports +2. If new FFI types appear upstream, evaluate whether they represent a user-facing capability +3. Check against the "evaluated and not requiring exposure" list before flagging as a gap +4. Report any genuinely new types that enable user-facing functionality +5. For each currently supported FFI type, verify the full pipeline is present using the checklist from "Adding a New FFI Type": + - Rust PyO3 wrapper with `from_pycapsule()` method + - Python Protocol type (e.g., `ScalarUDFExportable`) for FFI objects + - Python wrapper class with full type hints on all public methods + - ABC base class (if the type can be user-implemented) + - Registered in Rust `init_module()` and Python `__init__.py` + - FFI example in `examples/datafusion-ffi-example/` + - Type appears in union type hints where accepted + +## Checking for Existing GitHub Issues + +After identifying missing APIs, search the open issues at https://github.com/apache/datafusion-python/issues for each gap to see if an issue already exists requesting that API be exposed. Search using the function or method name as the query. + +- If an existing issue is found, include a link to it in the report. 
Do NOT create a new issue. +- If no existing issue is found, note that no issue exists yet. If the user asks to create issues for missing APIs, each issue should specify that Python test coverage is required as part of the implementation. + +## Output Format + +For each area checked, produce a report like: + +``` +## [Area Name] Coverage Report + +### Currently Exposed (X functions/methods) +- list of what's already available + +### Missing from Upstream (Y functions/methods) +- function_name — brief description of what it does (existing issue: #123) +- function_name — brief description of what it does (no existing issue) + +### Notes +- Any relevant observations about partial implementations, naming differences, etc. +``` + +## Implementation Pattern + +If the user asks you to implement missing features, follow these patterns: + +### Adding a New Function (Scalar/Aggregate/Window) + +**Step 1: Rust binding** in `crates/core/src/functions.rs`: +```rust +#[pyfunction] +#[pyo3(signature = (arg1, arg2))] +fn new_function_name(arg1: PyExpr, arg2: PyExpr) -> PyResult { + Ok(datafusion::functions::module::expr_fn::new_function_name(arg1.expr, arg2.expr).into()) +} +``` +Then register in `init_module()`: +```rust +m.add_wrapped(wrap_pyfunction!(new_function_name))?; +``` + +**Step 2: Python wrapper** in `python/datafusion/functions.py`: +```python +def new_function_name(arg1: Expr, arg2: Expr) -> Expr: + """Description of what the function does. + + Args: + arg1: Description of first argument. + arg2: Description of second argument. + + Returns: + Description of return value. + """ + return Expr(f.new_function_name(arg1.expr, arg2.expr)) +``` +Add to `__all__` list. + +### Adding a New DataFrame Method + +**Step 1: Rust binding** in `crates/core/src/dataframe.rs`: +```rust +#[pymethods] +impl PyDataFrame { + fn new_method(&self, py: Python, param: PyExpr) -> PyDataFusionResult { + let df = self.df.as_ref().clone().new_method(param.into())?; + Ok(Self::new(df)) + } +} +``` + +**Step 2: Python wrapper** in `python/datafusion/dataframe.py`: +```python +def new_method(self, param: Expr) -> DataFrame: + """Description of the method.""" + return DataFrame(self.df.new_method(param.expr)) +``` + +### Adding a New SessionContext Method + +**Step 1: Rust binding** in `crates/core/src/context.rs`: +```rust +#[pymethods] +impl PySessionContext { + pub fn new_method(&self, py: Python, param: String) -> PyDataFusionResult { + let df = wait_for_future(py, self.ctx.new_method(¶m))?; + Ok(PyDataFrame::new(df)) + } +} +``` + +**Step 2: Python wrapper** in `python/datafusion/context.py`: +```python +def new_method(self, param: str) -> DataFrame: + """Description of the method.""" + return DataFrame(self.ctx.new_method(param)) +``` + +### Adding a New FFI Type + +FFI types require a full pipeline from C struct through to a typed Python wrapper. Each layer must be present. + +**Step 1: Rust PyO3 wrapper class** in a new or existing file under `crates/core/src/`: +```rust +use datafusion_ffi::new_type::FFI_NewType; + +#[pyclass(from_py_object, frozen, name = "RawNewType", module = "datafusion.module_name", subclass)] +pub struct PyNewType { + pub inner: Arc, +} + +#[pymethods] +impl PyNewType { + #[staticmethod] + fn from_pycapsule(obj: &Bound<'_, PyAny>) -> PyDataFusionResult { + let capsule = obj + .getattr("__datafusion_new_type__")? + .call0()? 
+ .downcast::()?; + let ffi_ptr = unsafe { capsule.reference::() }; + let provider: Arc = ffi_ptr.into(); + Ok(Self { inner: provider }) + } + + fn some_method(&self) -> PyResult<...> { + // wrap inner trait method + } +} +``` +Register in the appropriate `init_module()`: +```rust +m.add_class::()?; +``` + +**Step 2: Python Protocol type** in the appropriate Python module (e.g., `python/datafusion/catalog.py`): +```python +class NewTypeExportable(Protocol): + """Type hint for objects providing a __datafusion_new_type__ PyCapsule.""" + + def __datafusion_new_type__(self) -> object: ... +``` + +**Step 3: Python wrapper class** in the same module: +```python +class NewType: + """Description of the type. + + This class wraps a DataFusion NewType, which can be created from a native + Python implementation or imported from an FFI-compatible library. + """ + + def __init__( + self, + new_type: df_internal.module_name.RawNewType | NewTypeExportable, + ) -> None: + if isinstance(new_type, df_internal.module_name.RawNewType): + self._raw = new_type + else: + self._raw = df_internal.module_name.RawNewType.from_pycapsule(new_type) + + def some_method(self) -> ReturnType: + """Description of the method.""" + return self._raw.some_method() +``` + +**Step 4: ABC base class** (if users should be able to subclass and provide custom implementations in Python): +```python +from abc import ABC, abstractmethod + +class NewTypeProvider(ABC): + """Abstract base class for implementing a custom NewType in Python.""" + + @abstractmethod + def some_method(self) -> ReturnType: + """Description of the method.""" + ... +``` + +**Step 5: Module exports** — add to the appropriate `__init__.py`: +- Add the wrapper class (`NewType`) to `python/datafusion/__init__.py` +- Add the ABC (`NewTypeProvider`) if applicable +- Add the Protocol type (`NewTypeExportable`) if it should be public + +**Step 6: FFI example** — add an example implementation under `examples/datafusion-ffi-example/src/`: +```rust +// examples/datafusion-ffi-example/src/new_type.rs +use datafusion_ffi::new_type::FFI_NewType; +// ... example showing how an external Rust library exposes this type via PyCapsule +``` + +**Checklist for each FFI type:** +- [ ] Rust PyO3 wrapper with `from_pycapsule()` method +- [ ] Python Protocol type (e.g., `NewTypeExportable`) for FFI objects +- [ ] Python wrapper class with full type hints on all public methods +- [ ] ABC base class (if the type can be user-implemented) +- [ ] Registered in Rust `init_module()` and Python `__init__.py` +- [ ] FFI example in `examples/datafusion-ffi-example/` +- [ ] Type appears in union type hints where accepted (e.g., `Table | TableProviderExportable`) + +## Important Notes + +- The upstream DataFusion version used by this project is specified in `crates/core/Cargo.toml` — check the `datafusion` dependency version to ensure you're comparing against the right upstream version. +- Some upstream features may intentionally not be exposed (e.g., internal-only APIs). Use judgment about what's user-facing. +- When fetching upstream docs, prefer the published docs.rs documentation as it matches the crate version. +- Function aliases (e.g., `array_append` / `list_append`) should both be exposed if upstream supports them. +- Check the `__all__` list in `functions.py` to see what's publicly exported vs just defined. 
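The comparison step that the skill describes in prose can be illustrated with a short, self-contained sketch. This is only an illustration, assuming the `datafusion` package is importable; the `upstream_names` list below is a made-up placeholder that a real audit would fill from the upstream documentation for the area being checked:

```python
# Rough sketch of the "compare upstream names against the Python __all__ list"
# step from the check-upstream skill. The upstream_names list is a hypothetical
# placeholder, not a real inventory of upstream DataFusion functions.
import datafusion.functions as fns

upstream_names = ["contains", "greatest", "least", "nvl2", "make_time"]

exposed = set(getattr(fns, "__all__", dir(fns)))
missing = sorted(name for name in upstream_names if name not in exposed)

print(f"{len(exposed)} names exported from python/datafusion/functions.py")
print("missing from the Python API:", missing or "none")
```

Because the skill treats the Python API as the source of truth, a check like this never needs to consult the Rust `#[pyfunction]` registrations; aliases that reuse another function's binding still count as exposed.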
diff --git a/.claude/skills b/.claude/skills new file mode 120000 index 000000000..6838a1160 --- /dev/null +++ b/.claude/skills @@ -0,0 +1 @@ +../.ai/skills \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..1853a84cd --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,27 @@ + + +# Agent Instructions + +This project uses AI agent skills stored in `.ai/skills/`. Each skill is a directory containing a `SKILL.md` file with instructions for performing a specific task. + +Skills follow the [Agent Skills](https://agentskills.io) open standard. Each skill directory contains: + +- `SKILL.md` — The skill definition with YAML frontmatter (name, description, argument-hint) and detailed instructions. +- Additional supporting files as needed. diff --git a/CLAUDE.md b/CLAUDE.md new file mode 120000 index 000000000..47dc3e3d8 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1 @@ +AGENTS.md \ No newline at end of file diff --git a/README.md b/README.md index c24257876..7c1c71281 100644 --- a/README.md +++ b/README.md @@ -312,6 +312,33 @@ There are scripts in `ci/scripts` for running Rust and Python linters. ./ci/scripts/rust_toml_fmt.sh ``` +## Checking Upstream DataFusion Coverage + +This project includes an [AI agent skill](.ai/skills/check-upstream/SKILL.md) for auditing which +features from the upstream Apache DataFusion Rust library are not yet exposed in these Python +bindings. This is useful when adding missing functions, auditing API coverage, or ensuring parity +with upstream. + +The skill accepts an optional area argument: + +``` +scalar functions +aggregate functions +window functions +dataframe +session context +ffi types +all +``` + +If no argument is provided, it defaults to checking all areas. The skill will fetch the upstream +DataFusion documentation, compare it against the functions and methods exposed in this project, and +produce a coverage report listing what is currently exposed and what is missing. + +The skill definition lives in `.ai/skills/check-upstream/SKILL.md` and follows the +[Agent Skills](https://agentskills.io) open standard. It can be used by any AI coding agent that +supports skill discovery, or followed manually. + ## How to update dependencies To change test dependencies, change the `pyproject.toml` and run From 645d261ce3bc0b3b610c8d82422042b3e573e793 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 3 Apr 2026 13:51:43 -0400 Subject: [PATCH 05/29] Add missing string function `contains` (#1465) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add missing `contains` string function Expose the upstream DataFusion `contains(string, search_str)` function which returns true if search_str is found within string (case-sensitive). Note: the other functions from #1450 (instr, position, substring_index) already exist — instr and position are aliases for strpos, and substring_index is exposed as substr_index. 
Closes #1450 Co-Authored-By: Claude Opus 4.6 (1M context) * Add unit test for contains string function Co-Authored-By: Claude Opus 4.6 (1M context) * Update python/datafusion/functions.py Co-authored-by: Nuno Faria --------- Co-authored-by: Claude Opus 4.6 (1M context) Co-authored-by: Nuno Faria --- crates/core/src/functions.rs | 6 ++++++ python/datafusion/functions.py | 15 +++++++++++++++ python/tests/test_functions.py | 1 + 3 files changed, 22 insertions(+) diff --git a/crates/core/src/functions.rs b/crates/core/src/functions.rs index 6996dca94..fefe14b3e 100644 --- a/crates/core/src/functions.rs +++ b/crates/core/src/functions.rs @@ -494,6 +494,11 @@ expr_fn!(length, string); expr_fn!(char_length, string); expr_fn!(chr, arg, "Returns the character with the given code."); expr_fn_vec!(coalesce); +expr_fn!( + contains, + string search_str, + "Return true if search_str is found within string (case-sensitive)." +); expr_fn!(cos, num); expr_fn!(cosh, num); expr_fn!(cot, num); @@ -961,6 +966,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(col))?; m.add_wrapped(wrap_pyfunction!(concat_ws))?; m.add_wrapped(wrap_pyfunction!(concat))?; + m.add_wrapped(wrap_pyfunction!(contains))?; m.add_wrapped(wrap_pyfunction!(corr))?; m.add_wrapped(wrap_pyfunction!(cos))?; m.add_wrapped(wrap_pyfunction!(cosh))?; diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 3c8d2bcee..2ef2f0473 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -116,6 +116,7 @@ "col", "concat", "concat_ws", + "contains", "corr", "cos", "cosh", @@ -439,6 +440,20 @@ def digest(value: Expr, method: Expr) -> Expr: return Expr(f.digest(value.expr, method.expr)) +def contains(string: Expr, search_str: Expr) -> Expr: + """Returns true if ``search_str`` is found within ``string`` (case-sensitive). + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": ["the quick brown fox"]}) + >>> result = df.select( + ... dfn.functions.contains(dfn.col("a"), dfn.lit("brown")).alias("c")) + >>> result.collect_column("c")[0].as_py() + True + """ + return Expr(f.contains(string.expr, search_str.expr)) + + def concat(*args: Expr) -> Expr: """Concatenates the text representations of all the arguments. diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 08420826d..db141fbe0 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -745,6 +745,7 @@ def test_array_function_obj_tests(stmt, py_expr): f.split_part(column("a"), literal("l"), literal(1)), pa.array(["He", "Wor", "!"]), ), + (f.contains(column("a"), literal("ell")), pa.array([True, False, False])), (f.starts_with(column("a"), literal("Wor")), pa.array([False, True, False])), (f.strpos(column("a"), literal("o")), pa.array([5, 2, 0], type=pa.int32())), ( From 0b6ea95a3d304a774bbe512bb70fbca332aa5426 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 3 Apr 2026 15:43:28 -0400 Subject: [PATCH 06/29] Add missing conditional functions (#1464) * Add missing conditional functions: greatest, least, nvl2, ifnull (#1449) Expose four conditional functions from upstream DataFusion that were not yet available in the Python bindings. Co-Authored-By: Claude Opus 4.6 (1M context) * Add unit tests for greatest, least, nvl2, and ifnull functions Tests cover multiple data types (integers, strings), null handling (all-null, partial-null), multiple arguments, and ifnull/nvl equivalence. 
Co-Authored-By: Claude Opus 4.6 (1M context) * Use standard alias docstring pattern for ifnull Co-Authored-By: Claude Opus 4.6 (1M context) * remove unused df fixture and fix parameter shadowing * Refactor conditional function tests into parametrized test suite Replace separate test functions for coalesce, greatest, least, nvl, nvl2, ifnull with a single parametrized test using a shared fixture. Adds coverage for nvl, nullif (previously untested), datetime and boolean types, literal fallbacks, and variadic calls. Co-Authored-By: Claude Opus 4.6 (1M context) --------- Co-authored-by: Claude Opus 4.6 (1M context) --- crates/core/src/functions.rs | 10 ++ python/datafusion/functions.py | 69 ++++++++ python/tests/test_functions.py | 289 +++++++++++++++++++++++++++------ 3 files changed, 319 insertions(+), 49 deletions(-) diff --git a/crates/core/src/functions.rs b/crates/core/src/functions.rs index fefe14b3e..3f07da95b 100644 --- a/crates/core/src/functions.rs +++ b/crates/core/src/functions.rs @@ -494,6 +494,8 @@ expr_fn!(length, string); expr_fn!(char_length, string); expr_fn!(chr, arg, "Returns the character with the given code."); expr_fn_vec!(coalesce); +expr_fn_vec!(greatest); +expr_fn_vec!(least); expr_fn!( contains, string search_str, @@ -548,6 +550,11 @@ expr_fn!( x y, "Returns x if x is not NULL otherwise returns y." ); +expr_fn!( + nvl2, + x y z, + "Returns y if x is not NULL; otherwise returns z." +); expr_fn!(nullif, arg_1 arg_2); expr_fn!( octet_length, @@ -989,6 +996,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(floor))?; m.add_wrapped(wrap_pyfunction!(from_unixtime))?; m.add_wrapped(wrap_pyfunction!(gcd))?; + m.add_wrapped(wrap_pyfunction!(greatest))?; // m.add_wrapped(wrap_pyfunction!(grouping))?; m.add_wrapped(wrap_pyfunction!(in_list))?; m.add_wrapped(wrap_pyfunction!(initcap))?; @@ -996,6 +1004,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(iszero))?; m.add_wrapped(wrap_pyfunction!(levenshtein))?; m.add_wrapped(wrap_pyfunction!(lcm))?; + m.add_wrapped(wrap_pyfunction!(least))?; m.add_wrapped(wrap_pyfunction!(left))?; m.add_wrapped(wrap_pyfunction!(length))?; m.add_wrapped(wrap_pyfunction!(ln))?; @@ -1013,6 +1022,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(named_struct))?; m.add_wrapped(wrap_pyfunction!(nanvl))?; m.add_wrapped(wrap_pyfunction!(nvl))?; + m.add_wrapped(wrap_pyfunction!(nvl2))?; m.add_wrapped(wrap_pyfunction!(now))?; m.add_wrapped(wrap_pyfunction!(nullif))?; m.add_wrapped(wrap_pyfunction!(octet_length))?; diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 2ef2f0473..f1ea3d256 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -152,6 +152,8 @@ "floor", "from_unixtime", "gcd", + "greatest", + "ifnull", "in_list", "initcap", "isnan", @@ -160,6 +162,7 @@ "last_value", "lcm", "lead", + "least", "left", "length", "levenshtein", @@ -216,6 +219,7 @@ "ntile", "nullif", "nvl", + "nvl2", "octet_length", "order_by", "overlay", @@ -1045,6 +1049,34 @@ def gcd(x: Expr, y: Expr) -> Expr: return Expr(f.gcd(x.expr, y.expr)) +def greatest(*args: Expr) -> Expr: + """Returns the greatest value from a list of expressions. + + Returns NULL if all expressions are NULL. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 3], "b": [2, 1]}) + >>> result = df.select( + ... 
dfn.functions.greatest(dfn.col("a"), dfn.col("b")).alias("greatest")) + >>> result.collect_column("greatest")[0].as_py() + 2 + >>> result.collect_column("greatest")[1].as_py() + 3 + """ + exprs = [arg.expr for arg in args] + return Expr(f.greatest(*exprs)) + + +def ifnull(x: Expr, y: Expr) -> Expr: + """Returns ``x`` if ``x`` is not NULL. Otherwise returns ``y``. + + See Also: + This is an alias for :py:func:`nvl`. + """ + return nvl(x, y) + + def initcap(string: Expr) -> Expr: """Set the initial letter of each word to capital. @@ -1098,6 +1130,25 @@ def lcm(x: Expr, y: Expr) -> Expr: return Expr(f.lcm(x.expr, y.expr)) +def least(*args: Expr) -> Expr: + """Returns the least value from a list of expressions. + + Returns NULL if all expressions are NULL. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 3], "b": [2, 1]}) + >>> result = df.select( + ... dfn.functions.least(dfn.col("a"), dfn.col("b")).alias("least")) + >>> result.collect_column("least")[0].as_py() + 1 + >>> result.collect_column("least")[1].as_py() + 1 + """ + exprs = [arg.expr for arg in args] + return Expr(f.least(*exprs)) + + def left(string: Expr, n: Expr) -> Expr: """Returns the first ``n`` characters in the ``string``. @@ -1282,6 +1333,24 @@ def nvl(x: Expr, y: Expr) -> Expr: return Expr(f.nvl(x.expr, y.expr)) +def nvl2(x: Expr, y: Expr, z: Expr) -> Expr: + """Returns ``y`` if ``x`` is not NULL. Otherwise returns ``z``. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [None, 1], "b": [10, 20], "c": [30, 40]}) + >>> result = df.select( + ... dfn.functions.nvl2( + ... dfn.col("a"), dfn.col("b"), dfn.col("c")).alias("nvl2") + ... ) + >>> result.collect_column("nvl2")[0].as_py() + 30 + >>> result.collect_column("nvl2")[1].as_py() + 20 + """ + return Expr(f.nvl2(x.expr, y.expr, z.expr)) + + def octet_length(arg: Expr) -> Expr: """Returns the number of bytes of a string. 
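As a quick illustration of how the conditional functions added above relate to one another, here is a small usage sketch. It assumes a build that already includes this patch, and the input data is made up:

```python
# Usage sketch for the conditional functions added in this patch
# (greatest, least, ifnull, nvl2). Assumes a datafusion build that
# includes these additions; the column values are made up.
from datafusion import SessionContext, col, lit
from datafusion import functions as f

ctx = SessionContext()
df = ctx.from_pydict({"a": [None, 3], "b": [2, 1]})

df.select(
    f.greatest(col("a"), col("b")).alias("greatest"),   # [2, 3]: NULLs are skipped
    f.least(col("a"), col("b")).alias("least"),         # [2, 1]
    f.ifnull(col("a"), lit(0)).alias("ifnull"),          # [0, 3]: alias for nvl
    f.nvl2(col("a"), col("b"), lit(-1)).alias("nvl2"),   # [-1, 1]: b when a is set
).show()
```

Note that `ifnull` falls back to its second argument when the first is NULL, while `nvl2` picks between its second and third arguments depending on whether the first is NULL.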
diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index db141fbe0..74fcbffb4 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -1410,62 +1410,253 @@ def test_alias_with_metadata(df): assert df.schema().field("b").metadata == {b"key": b"value"} -def test_coalesce(df): - # Create a DataFrame with null values +@pytest.fixture +def df_with_nulls(): ctx = SessionContext() + # Rows: + # 0: both values present + # 1: a/d/h/k null, b/e/i/l present + # 2: a/d/h/k present, b/e/i/l null + # 3: all null batch = pa.RecordBatch.from_arrays( [ - pa.array(["Hello", None, "!"]), # string column with null - pa.array([4, None, 6]), # integer column with null - pa.array(["hello ", None, " !"]), # string column with null + pa.array([1, None, 3, None], type=pa.int64()), + pa.array([5, 10, None, None], type=pa.int64()), + pa.array([20, 30, 40, None], type=pa.int64()), + pa.array(["apple", None, "cherry", None], type=pa.utf8()), + pa.array(["banana", "date", None, None], type=pa.utf8()), + pa.array(["x", "y", "z", None], type=pa.utf8()), pa.array( [ - datetime(2022, 12, 31, tzinfo=DEFAULT_TZ), + datetime(2020, 1, 1, tzinfo=DEFAULT_TZ), None, - datetime(2020, 7, 2, tzinfo=DEFAULT_TZ), - ] - ), # datetime with null - pa.array([False, None, True]), # boolean column with null + datetime(2025, 6, 15, tzinfo=DEFAULT_TZ), + None, + ], + type=pa.timestamp("us", tz="UTC"), + ), + pa.array( + [ + datetime(2022, 7, 4, tzinfo=DEFAULT_TZ), + datetime(2023, 12, 25, tzinfo=DEFAULT_TZ), + None, + None, + ], + type=pa.timestamp("us", tz="UTC"), + ), + pa.array([True, None, False, None], type=pa.bool_()), + pa.array([False, True, None, None], type=pa.bool_()), ], - names=["a", "b", "c", "d", "e"], - ) - df_with_nulls = ctx.create_dataframe([[batch]]) - - # Test coalesce with different data types - result_df = df_with_nulls.select( - f.coalesce(column("a"), literal("default")).alias("a_coalesced"), - f.coalesce(column("b"), literal(0)).alias("b_coalesced"), - f.coalesce(column("c"), literal("default")).alias("c_coalesced"), - f.coalesce(column("d"), literal(datetime(2000, 1, 1, tzinfo=DEFAULT_TZ))).alias( - "d_coalesced" - ), - f.coalesce(column("e"), literal(value=False)).alias("e_coalesced"), + names=["a", "b", "c", "d", "e", "g", "h", "i", "k", "l"], ) + return ctx.create_dataframe([[batch]]) - result = result_df.collect()[0] - # Verify results - assert result.column(0) == pa.array( - ["Hello", "default", "!"], type=pa.string_view() - ) - assert result.column(1) == pa.array([4, 0, 6], type=pa.int64()) - assert result.column(2) == pa.array( - ["hello ", "default", " !"], type=pa.string_view() - ) - assert result.column(3).to_pylist() == [ - datetime(2022, 12, 31, tzinfo=DEFAULT_TZ), - datetime(2000, 1, 1, tzinfo=DEFAULT_TZ), - datetime(2020, 7, 2, tzinfo=DEFAULT_TZ), - ] - assert result.column(4) == pa.array([False, False, True], type=pa.bool_()) - - # Test multiple arguments - result_df = df_with_nulls.select( - f.coalesce(column("a"), literal(None), literal("fallback")).alias( - "multi_coalesce" - ) - ) - result = result_df.collect()[0] - assert result.column(0) == pa.array( - ["Hello", "fallback", "!"], type=pa.string_view() - ) +@pytest.mark.parametrize( + ("expr", "expected"), + [ + pytest.param( + f.greatest(column("a"), column("b")), + pa.array([5, 10, 3, None], type=pa.int64()), + id="greatest_int", + ), + pytest.param( + f.greatest(column("d"), column("e")), + pa.array(["banana", "date", "cherry", None], type=pa.utf8()), + id="greatest_str", + ), + 
pytest.param( + f.least(column("a"), column("b")), + pa.array([1, 10, 3, None], type=pa.int64()), + id="least_int", + ), + pytest.param( + f.least(column("d"), column("e")), + pa.array(["apple", "date", "cherry", None], type=pa.utf8()), + id="least_str", + ), + pytest.param( + f.coalesce(column("a"), column("b"), column("c")), + pa.array([1, 10, 3, None], type=pa.int64()), + id="coalesce_int", + ), + pytest.param( + f.coalesce(column("d"), column("e"), column("g")), + pa.array(["apple", "date", "cherry", None], type=pa.utf8()), + id="coalesce_str", + ), + pytest.param( + f.nvl(column("a"), column("c")), + pa.array([1, 30, 3, None], type=pa.int64()), + id="nvl_int", + ), + pytest.param( + f.nvl(column("d"), column("g")), + pa.array(["apple", "y", "cherry", None], type=pa.utf8()), + id="nvl_str", + ), + pytest.param( + f.ifnull(column("a"), column("c")), + pa.array([1, 30, 3, None], type=pa.int64()), + id="ifnull_int", + ), + pytest.param( + f.ifnull(column("d"), column("g")), + pa.array(["apple", "y", "cherry", None], type=pa.utf8()), + id="ifnull_str", + ), + pytest.param( + f.nvl2(column("a"), column("b"), column("c")), + pa.array([5, 30, None, None], type=pa.int64()), + id="nvl2_int", + ), + pytest.param( + f.nvl2(column("d"), column("e"), column("g")), + pa.array(["banana", "y", None, None], type=pa.utf8()), + id="nvl2_str", + ), + pytest.param( + f.nullif(column("a"), column("b")), + pa.array([1, None, 3, None], type=pa.int64()), + id="nullif_int", + ), + pytest.param( + f.nullif(column("d"), column("e")), + pa.array(["apple", None, "cherry", None], type=pa.utf8()), + id="nullif_str", + ), + pytest.param( + f.nullif(column("a"), literal(1)), + pa.array([None, None, 3, None], type=pa.int64()), + id="nullif_equal_values", + ), + pytest.param( + f.greatest(column("a"), column("b"), column("c")), + pa.array([20, 30, 40, None], type=pa.int64()), + id="greatest_variadic", + ), + pytest.param( + f.least(column("a"), column("b"), column("c")), + pa.array([1, 10, 3, None], type=pa.int64()), + id="least_variadic", + ), + pytest.param( + f.greatest(column("a"), literal(2)), + pa.array([2, 2, 3, 2], type=pa.int64()), + id="greatest_literal", + ), + pytest.param( + f.least(column("a"), literal(2)), + pa.array([1, 2, 2, 2], type=pa.int64()), + id="least_literal", + ), + pytest.param( + f.coalesce(column("a"), literal(0)), + pa.array([1, 0, 3, 0], type=pa.int64()), + id="coalesce_literal_int", + ), + pytest.param( + f.coalesce(column("d"), literal("default")), + pa.array(["apple", "default", "cherry", "default"], type=pa.string_view()), + id="coalesce_literal_str", + ), + pytest.param( + f.nvl(column("a"), literal(99)), + pa.array([1, 99, 3, 99], type=pa.int64()), + id="nvl_literal", + ), + pytest.param( + f.ifnull(column("d"), literal("unknown")), + pa.array(["apple", "unknown", "cherry", "unknown"], type=pa.string_view()), + id="ifnull_literal", + ), + pytest.param( + f.nvl2(column("a"), literal(1), literal(0)), + pa.array([1, 0, 1, 0], type=pa.int64()), + id="nvl2_literal", + ), + pytest.param( + f.greatest(column("h"), column("i")), + pa.array( + [ + datetime(2022, 7, 4, tzinfo=DEFAULT_TZ), + datetime(2023, 12, 25, tzinfo=DEFAULT_TZ), + datetime(2025, 6, 15, tzinfo=DEFAULT_TZ), + None, + ], + type=pa.timestamp("us", tz="UTC"), + ), + id="greatest_datetime", + ), + pytest.param( + f.least(column("h"), column("i")), + pa.array( + [ + datetime(2020, 1, 1, tzinfo=DEFAULT_TZ), + datetime(2023, 12, 25, tzinfo=DEFAULT_TZ), + datetime(2025, 6, 15, tzinfo=DEFAULT_TZ), + None, + ], + 
type=pa.timestamp("us", tz="UTC"), + ), + id="least_datetime", + ), + pytest.param( + f.coalesce(column("h"), column("i")), + pa.array( + [ + datetime(2020, 1, 1, tzinfo=DEFAULT_TZ), + datetime(2023, 12, 25, tzinfo=DEFAULT_TZ), + datetime(2025, 6, 15, tzinfo=DEFAULT_TZ), + None, + ], + type=pa.timestamp("us", tz="UTC"), + ), + id="coalesce_datetime", + ), + pytest.param( + f.nvl(column("k"), column("l")), + pa.array([True, True, False, None], type=pa.bool_()), + id="nvl_bool", + ), + pytest.param( + f.coalesce(column("k"), column("l")), + pa.array([True, True, False, None], type=pa.bool_()), + id="coalesce_bool", + ), + pytest.param( + f.nvl2(column("k"), column("k"), column("l")), + pa.array([True, True, False, None], type=pa.bool_()), + id="nvl2_bool", + ), + pytest.param( + f.coalesce( + column("h"), + literal(datetime(2000, 1, 1, tzinfo=DEFAULT_TZ)), + ), + pa.array( + [ + datetime(2020, 1, 1, tzinfo=DEFAULT_TZ), + datetime(2000, 1, 1, tzinfo=DEFAULT_TZ), + datetime(2025, 6, 15, tzinfo=DEFAULT_TZ), + datetime(2000, 1, 1, tzinfo=DEFAULT_TZ), + ], + type=pa.timestamp("us", tz="UTC"), + ), + id="coalesce_literal_datetime", + ), + pytest.param( + f.coalesce(column("k"), literal(value=False)), + pa.array([True, False, False, False], type=pa.bool_()), + id="coalesce_literal_bool", + ), + pytest.param( + f.coalesce(column("a"), literal(None), literal(99)), + pa.array([1, 99, 3, 99], type=pa.int64()), + id="coalesce_skip_null_literal", + ), + ], +) +def test_conditional_functions(df_with_nulls, expr, expected): + result = df_with_nulls.select(expr.alias("result")).collect()[0] + assert result.column(0) == expected From 16feeb136737ae45fac39f7a82cca2d88fd6224b Mon Sep 17 00:00:00 2001 From: Kevin Liu Date: Fri, 3 Apr 2026 12:47:31 -0700 Subject: [PATCH 07/29] Reduce peak memory usage during release builds to fix OOM on manylinux runners (#1445) * adjust swap to 8gb * modify profile.release --- .github/workflows/build.yml | 15 ++++++++++++++- Cargo.toml | 4 ++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4b37046ee..7682d6cb0 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -159,6 +159,19 @@ jobs: with: enable-cache: true + - name: Add extra swap for release build + if: inputs.build_mode == 'release' + run: | + set -euxo pipefail + sudo swapoff -a || true + sudo rm -f /swapfile + sudo fallocate -l 8G /swapfile || sudo dd if=/dev/zero of=/swapfile bs=1M count=8192 + sudo chmod 600 /swapfile + sudo mkswap /swapfile + sudo swapon /swapfile + free -h + swapon --show + - name: Build (release mode) uses: PyO3/maturin-action@v1 if: inputs.build_mode == 'release' @@ -233,7 +246,7 @@ jobs: set -euxo pipefail sudo swapoff -a || true sudo rm -f /swapfile - sudo fallocate -l 16G /swapfile || sudo dd if=/dev/zero of=/swapfile bs=1M count=16384 + sudo fallocate -l 8G /swapfile || sudo dd if=/dev/zero of=/swapfile bs=1M count=8192 sudo chmod 600 /swapfile sudo mkswap /swapfile sudo swapon /swapfile diff --git a/Cargo.toml b/Cargo.toml index 346f6da3e..3a34e204c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,8 +64,8 @@ pyo3-build-config = "0.28" datafusion-python-util = { path = "crates/util" } [profile.release] -lto = true -codegen-units = 1 +lto = "thin" +codegen-units = 2 # We cannot publish to crates.io with any patches in the below section. Developers # must remove any entries in this section before creating a release candidate. 
From 8a35caea9ed01492742738f161fa5b4459d69402 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 4 Apr 2026 12:20:31 -0400 Subject: [PATCH 08/29] Add missing map functions (#1461) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add map functions (make_map, map_keys, map_values, map_extract, map_entries, element_at) Closes #1448 Co-Authored-By: Claude Opus 4.6 (1M context) * Add unit tests for map functions Co-Authored-By: Claude Opus 4.6 (1M context) * Remove redundant pyo3 element_at function element_at is already a Python-only alias for map_extract, so the Rust binding is unnecessary. Co-Authored-By: Claude Opus 4.6 (1M context) * Change make_map to accept a Python dictionary make_map now takes a dict for the common case and also supports separate keys/values lists for column expressions. Non-Expr keys and values are automatically converted to literals. Co-Authored-By: Claude Opus 4.6 (1M context) * Make map the primary function with make_map as alias map() now supports three calling conventions matching upstream: - map({"a": 1, "b": 2}) — from a Python dictionary - map([keys], [values]) — two lists that get zipped - map(k1, v1, k2, v2, ...) — variadic key-value pairs Non-Expr keys and values are automatically converted to literals. Co-Authored-By: Claude Opus 4.6 (1M context) * Improve map function docstrings - Add examples for all three map() calling conventions - Use clearer descriptions instead of jargon (no "zipped" or "variadic") - Break map_keys/map_values/map_extract/map_entries examples into two steps: create the map column first, then call the function Co-Authored-By: Claude Opus 4.6 (1M context) * Remove map() in favor of make_map(), fix docstrings, add validation - Remove map() function that shadowed Python builtin; make_map() is now the sole entry point for creating map expressions - Fix map_extract/element_at docstrings: missing keys return [None], not an empty list (matches actual upstream behavior) - Add length validation for the two-list calling convention - Update all tests and docstring examples accordingly Co-Authored-By: Claude Opus 4.6 (1M context) * Consolidate map function tests into parametrized groups Reduce boilerplate by combining make_map construction tests and map accessor function tests into two @pytest.mark.parametrize groups. 
Co-Authored-By: Claude Opus 4.6 (1M context) * Docstring update Co-authored-by: Nuno Faria * Docstring update Co-authored-by: Nuno Faria * Simplify test for readability Co-authored-by: Nuno Faria * Simplify test for readability Co-authored-by: Nuno Faria --------- Co-authored-by: Claude Opus 4.6 (1M context) Co-authored-by: Nuno Faria --- crates/core/src/functions.rs | 20 +++++ python/datafusion/functions.py | 158 +++++++++++++++++++++++++++++++++ python/tests/test_functions.py | 100 +++++++++++++++++++++ 3 files changed, 278 insertions(+) diff --git a/crates/core/src/functions.rs b/crates/core/src/functions.rs index 3f07da95b..5e61b71be 100644 --- a/crates/core/src/functions.rs +++ b/crates/core/src/functions.rs @@ -93,6 +93,13 @@ fn array_cat(exprs: Vec) -> PyExpr { array_concat(exprs) } +#[pyfunction] +fn make_map(keys: Vec, values: Vec) -> PyExpr { + let keys = keys.into_iter().map(|x| x.into()).collect(); + let values = values.into_iter().map(|x| x.into()).collect(); + datafusion::functions_nested::map::map(keys, values).into() +} + #[pyfunction] #[pyo3(signature = (array, element, index=None))] fn array_position(array: PyExpr, element: PyExpr, index: Option) -> PyExpr { @@ -678,6 +685,12 @@ array_fn!(cardinality, array); array_fn!(flatten, array); array_fn!(range, start stop step); +// Map Functions +array_fn!(map_keys, map); +array_fn!(map_values, map); +array_fn!(map_extract, map key); +array_fn!(map_entries, map); + aggregate_function!(array_agg); aggregate_function!(max); aggregate_function!(min); @@ -1142,6 +1155,13 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(flatten))?; m.add_wrapped(wrap_pyfunction!(cardinality))?; + // Map Functions + m.add_wrapped(wrap_pyfunction!(make_map))?; + m.add_wrapped(wrap_pyfunction!(map_keys))?; + m.add_wrapped(wrap_pyfunction!(map_values))?; + m.add_wrapped(wrap_pyfunction!(map_extract))?; + m.add_wrapped(wrap_pyfunction!(map_entries))?; + // Window Functions m.add_wrapped(wrap_pyfunction!(lead))?; m.add_wrapped(wrap_pyfunction!(lag))?; diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index f1ea3d256..3febb44e3 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -140,6 +140,7 @@ "degrees", "dense_rank", "digest", + "element_at", "empty", "encode", "ends_with", @@ -206,7 +207,12 @@ "make_array", "make_date", "make_list", + "make_map", "make_time", + "map_entries", + "map_extract", + "map_keys", + "map_values", "max", "md5", "mean", @@ -3458,6 +3464,158 @@ def empty(array: Expr) -> Expr: return array_empty(array) +# map functions + + +def make_map(*args: Any) -> Expr: + """Returns a map expression. + + Supports three calling conventions: + + - ``make_map({"a": 1, "b": 2})`` — from a Python dictionary. + - ``make_map([keys], [values])`` — from a list of keys and a list of + their associated values. Both lists must be the same length. + - ``make_map(k1, v1, k2, v2, ...)`` — from alternating keys and their + associated values. + + Keys and values that are not already :py:class:`~datafusion.expr.Expr` + are automatically converted to literal expressions. + + Examples: + From a dictionary: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1]}) + >>> result = df.select( + ... dfn.functions.make_map({"a": 1, "b": 2}).alias("m")) + >>> result.collect_column("m")[0].as_py() + [('a', 1), ('b', 2)] + + From two lists: + + >>> df = ctx.from_pydict({"key": ["x", "y"], "val": [10, 20]}) + >>> df = df.select( + ... 
dfn.functions.make_map( + ... [dfn.col("key")], [dfn.col("val")] + ... ).alias("m")) + >>> df.collect_column("m")[0].as_py() + [('x', 10)] + + From alternating keys and values: + + >>> df = ctx.from_pydict({"a": [1]}) + >>> result = df.select( + ... dfn.functions.make_map("x", 1, "y", 2).alias("m")) + >>> result.collect_column("m")[0].as_py() + [('x', 1), ('y', 2)] + """ + if len(args) == 1 and isinstance(args[0], dict): + key_list = list(args[0].keys()) + value_list = list(args[0].values()) + elif ( + len(args) == 2 # noqa: PLR2004 + and isinstance(args[0], list) + and isinstance(args[1], list) + ): + if len(args[0]) != len(args[1]): + msg = "make_map requires key and value lists to be the same length" + raise ValueError(msg) + key_list = args[0] + value_list = args[1] + elif len(args) >= 2 and len(args) % 2 == 0: # noqa: PLR2004 + key_list = list(args[0::2]) + value_list = list(args[1::2]) + else: + msg = ( + "make_map expects a dict, two lists, or an even number of " + "key-value arguments" + ) + raise ValueError(msg) + + key_exprs = [k if isinstance(k, Expr) else Expr.literal(k) for k in key_list] + val_exprs = [v if isinstance(v, Expr) else Expr.literal(v) for v in value_list] + return Expr(f.make_map([k.expr for k in key_exprs], [v.expr for v in val_exprs])) + + +def map_keys(map: Expr) -> Expr: + """Returns a list of all keys in the map. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1]}) + >>> df = df.select( + ... dfn.functions.make_map({"x": 1, "y": 2}).alias("m")) + >>> result = df.select( + ... dfn.functions.map_keys(dfn.col("m")).alias("keys")) + >>> result.collect_column("keys")[0].as_py() + ['x', 'y'] + """ + return Expr(f.map_keys(map.expr)) + + +def map_values(map: Expr) -> Expr: + """Returns a list of all values in the map. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1]}) + >>> df = df.select( + ... dfn.functions.make_map({"x": 1, "y": 2}).alias("m")) + >>> result = df.select( + ... dfn.functions.map_values(dfn.col("m")).alias("vals")) + >>> result.collect_column("vals")[0].as_py() + [1, 2] + """ + return Expr(f.map_values(map.expr)) + + +def map_extract(map: Expr, key: Expr) -> Expr: + """Returns the value for a given key in the map. + + Returns ``[None]`` if the key is absent. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1]}) + >>> df = df.select( + ... dfn.functions.make_map({"x": 1, "y": 2}).alias("m")) + >>> result = df.select( + ... dfn.functions.map_extract( + ... dfn.col("m"), dfn.lit("x") + ... ).alias("val")) + >>> result.collect_column("val")[0].as_py() + [1] + """ + return Expr(f.map_extract(map.expr, key.expr)) + + +def map_entries(map: Expr) -> Expr: + """Returns a list of all entries (key-value struct pairs) in the map. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1]}) + >>> df = df.select( + ... dfn.functions.make_map({"x": 1, "y": 2}).alias("m")) + >>> result = df.select( + ... dfn.functions.map_entries(dfn.col("m")).alias("entries")) + >>> result.collect_column("entries")[0].as_py() + [{'key': 'x', 'value': 1}, {'key': 'y', 'value': 2}] + """ + return Expr(f.map_entries(map.expr)) + + +def element_at(map: Expr, key: Expr) -> Expr: + """Returns the value for a given key in the map. + + Returns ``[None]`` if the key is absent. + + See Also: + This is an alias for :py:func:`map_extract`. 
+ """ + return map_extract(map, key) + + # aggregate functions def approx_distinct( expression: Expr, diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 74fcbffb4..f25c6e78c 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -668,6 +668,106 @@ def test_array_function_obj_tests(stmt, py_expr): assert a == b +@pytest.mark.parametrize( + ("args", "expected"), + [ + pytest.param( + ({"x": 1, "y": 2},), + [("x", 1), ("y", 2)], + id="dict", + ), + pytest.param( + ({"x": literal(1), "y": literal(2)},), + [("x", 1), ("y", 2)], + id="dict_with_exprs", + ), + pytest.param( + ("x", 1, "y", 2), + [("x", 1), ("y", 2)], + id="variadic_pairs", + ), + pytest.param( + (literal("x"), literal(1), literal("y"), literal(2)), + [("x", 1), ("y", 2)], + id="variadic_with_exprs", + ), + ], +) +def test_make_map(args, expected): + ctx = SessionContext() + batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"]) + df = ctx.create_dataframe([[batch]]) + + result = df.select(f.make_map(*args).alias("m")).collect()[0].column(0) + assert result[0].as_py() == expected + + +def test_make_map_from_two_lists(): + ctx = SessionContext() + batch = pa.RecordBatch.from_arrays( + [ + pa.array(["k1", "k2", "k3"]), + pa.array([10, 20, 30]), + ], + names=["keys", "vals"], + ) + df = ctx.create_dataframe([[batch]]) + + m = f.make_map([column("keys")], [column("vals")]) + result = df.select(f.map_keys(m).alias("k")).collect()[0].column(0) + assert result.to_pylist() == [["k1"], ["k2"], ["k3"]] + + result = df.select(f.map_values(m).alias("v")).collect()[0].column(0) + assert result.to_pylist() == [[10], [20], [30]] + + +def test_make_map_odd_args_raises(): + with pytest.raises(ValueError, match="make_map expects"): + f.make_map("x", 1, "y") + + +def test_make_map_mismatched_lengths(): + with pytest.raises(ValueError, match="same length"): + f.make_map(["a", "b"], [1]) + + +@pytest.mark.parametrize( + ("func", "expected"), + [ + pytest.param(f.map_keys, ["x", "y"], id="map_keys"), + pytest.param(f.map_values, [1, 2], id="map_values"), + pytest.param( + lambda m: f.map_extract(m, literal("x")), + [1], + id="map_extract", + ), + pytest.param( + lambda m: f.map_extract(m, literal("z")), + [None], + id="map_extract_missing_key", + ), + pytest.param( + f.map_entries, + [{"key": "x", "value": 1}, {"key": "y", "value": 2}], + id="map_entries", + ), + pytest.param( + lambda m: f.element_at(m, literal("y")), + [2], + id="element_at", + ), + ], +) +def test_map_functions(func, expected): + ctx = SessionContext() + batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"]) + df = ctx.create_dataframe([[batch]]) + + m = f.make_map({"x": 1, "y": 2}) + result = df.select(func(m).alias("out")).collect()[0].column(0) + assert result[0].as_py() == expected + + @pytest.mark.parametrize( ("function", "expected_result"), [ From ff15648c5dca6b41d3f6146c6c36c97e605f8561 Mon Sep 17 00:00:00 2001 From: Nuno Faria Date: Sun, 5 Apr 2026 13:29:32 +0100 Subject: [PATCH 09/29] minor: Fix pytest instructions in README (#1477) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7c1c71281..7849e7a02 100644 --- a/README.md +++ b/README.md @@ -275,7 +275,7 @@ needing to activate the virtual environment: ```bash uv run --no-project maturin develop --uv -uv run --no-project pytest . 
+uv run --no-project pytest ``` To run the FFI tests within the examples folder, after you have built From 99bc9602dd077c924685f1fc6e54e6feb3429302 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 6 Apr 2026 07:47:13 -0400 Subject: [PATCH 10/29] Add missing array functions (#1468) * Add missing array/list functions and aliases (#1452) Add new array functions from upstream DataFusion v53: array_any_value, array_distance, array_max, array_min, array_reverse, arrays_zip, string_to_array, and gen_series. Add corresponding list_* aliases and missing list_* aliases for existing functions (list_empty, list_pop_back, list_pop_front, list_has, list_has_all, list_has_any). Also add array_contains/list_contains as aliases for array_has, generate_series as alias for gen_series, and string_to_list as alias for string_to_array. Co-Authored-By: Claude Opus 4.6 (1M context) * Add unit tests for new array/list functions and aliases Tests cover all functions and aliases added in the previous commit: array_any_value, array_distance, array_max, array_min, array_reverse, arrays_zip, string_to_array, gen_series, generate_series, array_contains, list_contains, list_empty, list_pop_back, list_pop_front, list_has, list_has_all, list_has_any, and list_* aliases for the new functions. Co-Authored-By: Claude Opus 4.6 (1M context) * Improve array function APIs: optional params, better naming, restore comment - Make null_string optional in string_to_array/string_to_list - Make step optional in gen_series/generate_series - Rename second_array to element in array_contains/list_has/list_contains - Restore # Window Functions section comment in __all__ - Add tests for optional parameter variants Co-Authored-By: Claude Opus 4.6 (1M context) * Consolidate array/list function tests using pytest parametrize Reduce 26 individual tests to 14 test functions with parametrized cases, eliminating boilerplate while maintaining full coverage. Co-Authored-By: Claude Opus 4.6 (1M context) * Move list alias tests into existing test_array_functions parametrize block Merge standalone tests for list_empty, list_pop_back, list_pop_front, list_has, array_contains, list_contains, list_has_all, and list_has_any into the existing parametrized test_array_functions block alongside their array_* counterparts. Co-Authored-By: Claude Opus 4.6 (1M context) * Merge test_array_any_value into parametrized test_any_value_aliases Use the richer multi-row dataset (including all-nulls case) for both array_any_value and list_any_value via the parametrized test. Co-Authored-By: Claude Opus 4.6 (1M context) * Add arrays_overlap and list_overlap as aliases for array_has_any These aliases match the upstream DataFusion SQL-level aliases, completing the set of missing array functions from issue #1452. Co-Authored-By: Claude Opus 4.6 (1M context) * Add docstring examples for optional params in string_to_array and gen_series Co-Authored-By: Claude Opus 4.6 (1M context) * Update AGENTS file to demonstrate preferred method of documenting python functions --------- Co-authored-by: Claude Opus 4.6 (1M context) --- AGENTS.md | 17 ++ crates/core/src/functions.rs | 56 ++++++ python/datafusion/functions.py | 337 +++++++++++++++++++++++++++++++++ python/tests/test_functions.py | 137 ++++++++++++++ 4 files changed, 547 insertions(+) diff --git a/AGENTS.md b/AGENTS.md index 1853a84cd..f6fdfbd90 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -25,3 +25,20 @@ Skills follow the [Agent Skills](https://agentskills.io) open standard. 
Each ski - `SKILL.md` — The skill definition with YAML frontmatter (name, description, argument-hint) and detailed instructions. - Additional supporting files as needed. + +## Python Function Docstrings + +Every Python function must include a docstring with usage examples. + +- **Examples are required**: Each function needs at least one doctest-style example + demonstrating basic usage. +- **Optional parameters**: If a function has optional parameters, include separate + examples that show usage both without and with the optional arguments. Pass + optional arguments using their keyword name (e.g., `step=dfn.lit(3)`) so readers + can immediately see which parameter is being demonstrated. +- **Reuse input data**: Use the same input data across examples wherever possible. + The examples should demonstrate how different optional arguments change the output + for the same input, making the effect of each option easy to understand. +- **Alias functions**: Functions that are simple aliases (e.g., `list_sort` aliasing + `array_sort`) only need a one-line description and a `See Also` reference to the + primary function. They do not need their own examples. diff --git a/crates/core/src/functions.rs b/crates/core/src/functions.rs index 5e61b71be..8bb927718 100644 --- a/crates/core/src/functions.rs +++ b/crates/core/src/functions.rs @@ -93,6 +93,50 @@ fn array_cat(exprs: Vec) -> PyExpr { array_concat(exprs) } +#[pyfunction] +fn array_distance(array1: PyExpr, array2: PyExpr) -> PyExpr { + let args = vec![array1.into(), array2.into()]; + Expr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf( + datafusion::functions_nested::distance::array_distance_udf(), + args, + )) + .into() +} + +#[pyfunction] +fn arrays_zip(exprs: Vec) -> PyExpr { + let exprs = exprs.into_iter().map(|x| x.into()).collect(); + datafusion::functions_nested::expr_fn::arrays_zip(exprs).into() +} + +#[pyfunction] +#[pyo3(signature = (string, delimiter, null_string=None))] +fn string_to_array(string: PyExpr, delimiter: PyExpr, null_string: Option) -> PyExpr { + let mut args = vec![string.into(), delimiter.into()]; + if let Some(null_string) = null_string { + args.push(null_string.into()); + } + Expr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf( + datafusion::functions_nested::string::string_to_array_udf(), + args, + )) + .into() +} + +#[pyfunction] +#[pyo3(signature = (start, stop, step=None))] +fn gen_series(start: PyExpr, stop: PyExpr, step: Option) -> PyExpr { + let mut args = vec![start.into(), stop.into()]; + if let Some(step) = step { + args.push(step.into()); + } + Expr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf( + datafusion::functions_nested::range::gen_series_udf(), + args, + )) + .into() +} + #[pyfunction] fn make_map(keys: Vec, values: Vec) -> PyExpr { let keys = keys.into_iter().map(|x| x.into()).collect(); @@ -681,6 +725,10 @@ array_fn!(array_intersect, first_array second_array); array_fn!(array_union, array1 array2); array_fn!(array_except, first_array second_array); array_fn!(array_resize, array size value); +array_fn!(array_any_value, array); +array_fn!(array_max, array); +array_fn!(array_min, array); +array_fn!(array_reverse, array); array_fn!(cardinality, array); array_fn!(flatten, array); array_fn!(range, start stop step); @@ -1152,6 +1200,14 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(array_replace_all))?; m.add_wrapped(wrap_pyfunction!(array_sort))?; 
m.add_wrapped(wrap_pyfunction!(array_slice))?; + m.add_wrapped(wrap_pyfunction!(array_any_value))?; + m.add_wrapped(wrap_pyfunction!(array_distance))?; + m.add_wrapped(wrap_pyfunction!(array_max))?; + m.add_wrapped(wrap_pyfunction!(array_min))?; + m.add_wrapped(wrap_pyfunction!(array_reverse))?; + m.add_wrapped(wrap_pyfunction!(arrays_zip))?; + m.add_wrapped(wrap_pyfunction!(string_to_array))?; + m.add_wrapped(wrap_pyfunction!(gen_series))?; m.add_wrapped(wrap_pyfunction!(flatten))?; m.add_wrapped(wrap_pyfunction!(cardinality))?; diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 3febb44e3..1b267731e 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -53,10 +53,13 @@ "approx_percentile_cont_with_weight", "array", "array_agg", + "array_any_value", "array_append", "array_cat", "array_concat", + "array_contains", "array_dims", + "array_distance", "array_distinct", "array_element", "array_empty", @@ -69,6 +72,8 @@ "array_intersect", "array_join", "array_length", + "array_max", + "array_min", "array_ndims", "array_pop_back", "array_pop_front", @@ -85,10 +90,13 @@ "array_replace_all", "array_replace_n", "array_resize", + "array_reverse", "array_slice", "array_sort", "array_to_string", "array_union", + "arrays_overlap", + "arrays_zip", "arrow_cast", "arrow_typeof", "ascii", @@ -153,6 +161,8 @@ "floor", "from_unixtime", "gcd", + "gen_series", + "generate_series", "greatest", "ifnull", "in_list", @@ -167,19 +177,31 @@ "left", "length", "levenshtein", + "list_any_value", "list_append", "list_cat", "list_concat", + "list_contains", "list_dims", + "list_distance", "list_distinct", "list_element", + "list_empty", "list_except", "list_extract", + "list_has", + "list_has_all", + "list_has_any", "list_indexof", "list_intersect", "list_join", "list_length", + "list_max", + "list_min", "list_ndims", + "list_overlap", + "list_pop_back", + "list_pop_front", "list_position", "list_positions", "list_prepend", @@ -193,10 +215,12 @@ "list_replace_all", "list_replace_n", "list_resize", + "list_reverse", "list_slice", "list_sort", "list_to_string", "list_union", + "list_zip", "ln", "log", "log2", @@ -273,6 +297,8 @@ "stddev_pop", "stddev_samp", "string_agg", + "string_to_array", + "string_to_list", "strpos", "struct", "substr", @@ -2794,6 +2820,15 @@ def array_empty(array: Expr) -> Expr: return Expr(f.array_empty(array.expr)) +def list_empty(array: Expr) -> Expr: + """Returns a boolean indicating whether the array is empty. + + See Also: + This is an alias for :py:func:`array_empty`. + """ + return array_empty(array) + + def array_extract(array: Expr, n: Expr) -> Expr: """Extracts the element with the index n from the array. @@ -2891,6 +2926,69 @@ def array_has_any(first_array: Expr, second_array: Expr) -> Expr: return Expr(f.array_has_any(first_array.expr, second_array.expr)) +def array_contains(array: Expr, element: Expr) -> Expr: + """Returns true if the element appears in the array, otherwise false. + + See Also: + This is an alias for :py:func:`array_has`. + """ + return array_has(array, element) + + +def list_has(array: Expr, element: Expr) -> Expr: + """Returns true if the element appears in the array, otherwise false. + + See Also: + This is an alias for :py:func:`array_has`. + """ + return array_has(array, element) + + +def list_has_all(first_array: Expr, second_array: Expr) -> Expr: + """Determines if there is complete overlap ``second_array`` in ``first_array``. + + See Also: + This is an alias for :py:func:`array_has_all`. 
+ """ + return array_has_all(first_array, second_array) + + +def list_has_any(first_array: Expr, second_array: Expr) -> Expr: + """Determine if there is an overlap between ``first_array`` and ``second_array``. + + See Also: + This is an alias for :py:func:`array_has_any`. + """ + return array_has_any(first_array, second_array) + + +def arrays_overlap(first_array: Expr, second_array: Expr) -> Expr: + """Returns true if any element appears in both arrays. + + See Also: + This is an alias for :py:func:`array_has_any`. + """ + return array_has_any(first_array, second_array) + + +def list_overlap(first_array: Expr, second_array: Expr) -> Expr: + """Returns true if any element appears in both arrays. + + See Also: + This is an alias for :py:func:`array_has_any`. + """ + return array_has_any(first_array, second_array) + + +def list_contains(array: Expr, element: Expr) -> Expr: + """Returns true if the element appears in the array, otherwise false. + + See Also: + This is an alias for :py:func:`array_has`. + """ + return array_has(array, element) + + def array_position(array: Expr, element: Expr, index: int | None = 1) -> Expr: """Return the position of the first occurrence of ``element`` in ``array``. @@ -3058,6 +3156,24 @@ def array_pop_front(array: Expr) -> Expr: return Expr(f.array_pop_front(array.expr)) +def list_pop_back(array: Expr) -> Expr: + """Returns the array without the last element. + + See Also: + This is an alias for :py:func:`array_pop_back`. + """ + return array_pop_back(array) + + +def list_pop_front(array: Expr) -> Expr: + """Returns the array without the first element. + + See Also: + This is an alias for :py:func:`array_pop_front`. + """ + return array_pop_front(array) + + def array_remove(array: Expr, element: Expr) -> Expr: """Removes the first element from the array equal to the given value. @@ -3429,6 +3545,227 @@ def list_resize(array: Expr, size: Expr, value: Expr) -> Expr: return array_resize(array, size, value) +def array_any_value(array: Expr) -> Expr: + """Returns the first non-null element in the array. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [[None, 2, 3]]}) + >>> result = df.select( + ... dfn.functions.array_any_value(dfn.col("a")).alias("result")) + >>> result.collect_column("result")[0].as_py() + 2 + """ + return Expr(f.array_any_value(array.expr)) + + +def list_any_value(array: Expr) -> Expr: + """Returns the first non-null element in the array. + + See Also: + This is an alias for :py:func:`array_any_value`. + """ + return array_any_value(array) + + +def array_distance(array1: Expr, array2: Expr) -> Expr: + """Returns the Euclidean distance between two numeric arrays. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [[1.0, 2.0]], "b": [[1.0, 4.0]]}) + >>> result = df.select( + ... dfn.functions.array_distance( + ... dfn.col("a"), dfn.col("b"), + ... ).alias("result")) + >>> result.collect_column("result")[0].as_py() + 2.0 + """ + return Expr(f.array_distance(array1.expr, array2.expr)) + + +def list_distance(array1: Expr, array2: Expr) -> Expr: + """Returns the Euclidean distance between two numeric arrays. + + See Also: + This is an alias for :py:func:`array_distance`. + """ + return array_distance(array1, array2) + + +def array_max(array: Expr) -> Expr: + """Returns the maximum value in the array. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [[1, 2, 3]]}) + >>> result = df.select( + ... 
dfn.functions.array_max(dfn.col("a")).alias("result")) + >>> result.collect_column("result")[0].as_py() + 3 + """ + return Expr(f.array_max(array.expr)) + + +def list_max(array: Expr) -> Expr: + """Returns the maximum value in the array. + + See Also: + This is an alias for :py:func:`array_max`. + """ + return array_max(array) + + +def array_min(array: Expr) -> Expr: + """Returns the minimum value in the array. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [[1, 2, 3]]}) + >>> result = df.select( + ... dfn.functions.array_min(dfn.col("a")).alias("result")) + >>> result.collect_column("result")[0].as_py() + 1 + """ + return Expr(f.array_min(array.expr)) + + +def list_min(array: Expr) -> Expr: + """Returns the minimum value in the array. + + See Also: + This is an alias for :py:func:`array_min`. + """ + return array_min(array) + + +def array_reverse(array: Expr) -> Expr: + """Reverses the order of elements in the array. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [[1, 2, 3]]}) + >>> result = df.select( + ... dfn.functions.array_reverse(dfn.col("a")).alias("result")) + >>> result.collect_column("result")[0].as_py() + [3, 2, 1] + """ + return Expr(f.array_reverse(array.expr)) + + +def list_reverse(array: Expr) -> Expr: + """Reverses the order of elements in the array. + + See Also: + This is an alias for :py:func:`array_reverse`. + """ + return array_reverse(array) + + +def arrays_zip(*arrays: Expr) -> Expr: + """Combines multiple arrays into a single array of structs. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [[1, 2]], "b": [[3, 4]]}) + >>> result = df.select( + ... dfn.functions.arrays_zip(dfn.col("a"), dfn.col("b")).alias("result")) + >>> result.collect_column("result")[0].as_py() + [{'c0': 1, 'c1': 3}, {'c0': 2, 'c1': 4}] + """ + args = [a.expr for a in arrays] + return Expr(f.arrays_zip(args)) + + +def list_zip(*arrays: Expr) -> Expr: + """Combines multiple arrays into a single array of structs. + + See Also: + This is an alias for :py:func:`arrays_zip`. + """ + return arrays_zip(*arrays) + + +def string_to_array( + string: Expr, delimiter: Expr, null_string: Expr | None = None +) -> Expr: + """Splits a string based on a delimiter and returns an array of parts. + + Any parts matching the optional ``null_string`` will be replaced with ``NULL``. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": ["hello,world"]}) + >>> result = df.select( + ... dfn.functions.string_to_array( + ... dfn.col("a"), dfn.lit(","), + ... ).alias("result")) + >>> result.collect_column("result")[0].as_py() + ['hello', 'world'] + + Replace parts matching a ``null_string`` with ``NULL``: + + >>> result = df.select( + ... dfn.functions.string_to_array( + ... dfn.col("a"), dfn.lit(","), null_string=dfn.lit("world"), + ... ).alias("result")) + >>> result.collect_column("result")[0].as_py() + ['hello', None] + """ + null_expr = null_string.expr if null_string is not None else None + return Expr(f.string_to_array(string.expr, delimiter.expr, null_expr)) + + +def string_to_list( + string: Expr, delimiter: Expr, null_string: Expr | None = None +) -> Expr: + """Splits a string based on a delimiter and returns an array of parts. + + See Also: + This is an alias for :py:func:`string_to_array`. 
+ """ + return string_to_array(string, delimiter, null_string) + + +def gen_series(start: Expr, stop: Expr, step: Expr | None = None) -> Expr: + """Creates a list of values in the range between start and stop. + + Unlike :py:func:`range`, this includes the upper bound. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [0]}) + >>> result = df.select( + ... dfn.functions.gen_series( + ... dfn.lit(1), dfn.lit(5), + ... ).alias("result")) + >>> result.collect_column("result")[0].as_py() + [1, 2, 3, 4, 5] + + Specify a custom ``step``: + + >>> result = df.select( + ... dfn.functions.gen_series( + ... dfn.lit(1), dfn.lit(10), step=dfn.lit(3), + ... ).alias("result")) + >>> result.collect_column("result")[0].as_py() + [1, 4, 7, 10] + """ + step_expr = step.expr if step is not None else None + return Expr(f.gen_series(start.expr, stop.expr, step_expr)) + + +def generate_series(start: Expr, stop: Expr, step: Expr | None = None) -> Expr: + """Creates a list of values in the range between start and stop. + + Unlike :py:func:`range`, this includes the upper bound. + + See Also: + This is an alias for :py:func:`gen_series`. + """ + return gen_series(start, stop, step) + + def flatten(array: Expr) -> Expr: """Flattens an array of arrays into a single array. diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index f25c6e78c..2100da9ae 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -330,6 +330,10 @@ def py_flatten(arr): f.empty, lambda data: [len(r) == 0 for r in data], ), + ( + f.list_empty, + lambda data: [len(r) == 0 for r in data], + ), ( lambda col: f.array_extract(col, literal(1)), lambda data: [r[0] for r in data], @@ -354,18 +358,54 @@ def py_flatten(arr): lambda col: f.array_has(col, literal(1.0)), lambda data: [1.0 in r for r in data], ), + ( + lambda col: f.list_has(col, literal(1.0)), + lambda data: [1.0 in r for r in data], + ), + ( + lambda col: f.array_contains(col, literal(1.0)), + lambda data: [1.0 in r for r in data], + ), + ( + lambda col: f.list_contains(col, literal(1.0)), + lambda data: [1.0 in r for r in data], + ), ( lambda col: f.array_has_all( col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]]) ), lambda data: [np.all([v in r for v in [1.0, 3.0, 5.0]]) for r in data], ), + ( + lambda col: f.list_has_all( + col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]]) + ), + lambda data: [np.all([v in r for v in [1.0, 3.0, 5.0]]) for r in data], + ), ( lambda col: f.array_has_any( col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]]) ), lambda data: [np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in data], ), + ( + lambda col: f.list_has_any( + col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]]) + ), + lambda data: [np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in data], + ), + ( + lambda col: f.arrays_overlap( + col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]]) + ), + lambda data: [np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in data], + ), + ( + lambda col: f.list_overlap( + col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]]) + ), + lambda data: [np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in data], + ), ( lambda col: f.array_position(col, literal(1.0)), lambda data: [py_indexof(r, 1.0) for r in data], @@ -418,10 +458,18 @@ def py_flatten(arr): f.array_pop_back, lambda data: [arr[:-1] for arr in data], ), + ( + f.list_pop_back, + lambda data: [arr[:-1] for arr in data], + ), ( f.array_pop_front, lambda data: [arr[1:] for arr in data], ), + ( + 
f.list_pop_front, + lambda data: [arr[1:] for arr in data], + ), ( lambda col: f.array_remove(col, literal(3.0)), lambda data: [py_arr_remove(arr, 3.0, 1) for arr in data], @@ -1760,3 +1808,92 @@ def df_with_nulls(): def test_conditional_functions(df_with_nulls, expr, expected): result = df_with_nulls.select(expr.alias("result")).collect()[0] assert result.column(0) == expected + + +@pytest.mark.parametrize("func", [f.array_any_value, f.list_any_value]) +def test_any_value_aliases(func): + ctx = SessionContext() + df = ctx.from_pydict({"a": [[None, 2, 3], [None, None, None], [1, 2, 3]]}) + result = df.select(func(column("a")).alias("v")).collect() + values = [row.as_py() for row in result[0].column(0)] + assert values[0] == 2 + assert values[1] is None + assert values[2] == 1 + + +@pytest.mark.parametrize("func", [f.array_distance, f.list_distance]) +def test_array_distance_aliases(func): + ctx = SessionContext() + df = ctx.from_pydict({"a": [[1.0, 2.0]], "b": [[1.0, 4.0]]}) + result = df.select(func(column("a"), column("b")).alias("v")).collect() + assert result[0].column(0)[0].as_py() == pytest.approx(2.0) + + +@pytest.mark.parametrize( + ("func", "expected"), + [ + (f.array_max, [5, 10]), + (f.list_max, [5, 10]), + (f.array_min, [1, 2]), + (f.list_min, [1, 2]), + ], +) +def test_array_min_max(func, expected): + ctx = SessionContext() + df = ctx.from_pydict({"a": [[1, 5, 3], [10, 2]]}) + result = df.select(func(column("a")).alias("v")).collect() + values = [row.as_py() for row in result[0].column(0)] + assert values == expected + + +@pytest.mark.parametrize("func", [f.array_reverse, f.list_reverse]) +def test_array_reverse_aliases(func): + ctx = SessionContext() + df = ctx.from_pydict({"a": [[1, 2, 3], [4, 5]]}) + result = df.select(func(column("a")).alias("v")).collect() + values = [row.as_py() for row in result[0].column(0)] + assert values == [[3, 2, 1], [5, 4]] + + +@pytest.mark.parametrize("func", [f.arrays_zip, f.list_zip]) +def test_arrays_zip_aliases(func): + ctx = SessionContext() + df = ctx.from_pydict({"a": [[1, 2]], "b": [[3, 4]]}) + result = df.select(func(column("a"), column("b")).alias("v")).collect() + values = result[0].column(0)[0].as_py() + assert values == [{"c0": 1, "c1": 3}, {"c0": 2, "c1": 4}] + + +@pytest.mark.parametrize("func", [f.string_to_array, f.string_to_list]) +def test_string_to_array_aliases(func): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["hello,world,foo"]}) + result = df.select(func(column("a"), literal(",")).alias("v")).collect() + assert result[0].column(0)[0].as_py() == ["hello", "world", "foo"] + + +def test_string_to_array_with_null_string(): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["hello,NA,world"]}) + result = df.select( + f.string_to_array(column("a"), literal(","), literal("NA")).alias("v") + ).collect() + values = result[0].column(0)[0].as_py() + assert values == ["hello", None, "world"] + + +@pytest.mark.parametrize("func", [f.gen_series, f.generate_series]) +def test_gen_series_aliases(func): + ctx = SessionContext() + df = ctx.from_pydict({"a": [0]}) + result = df.select(func(literal(1), literal(5)).alias("v")).collect() + assert result[0].column(0)[0].as_py() == [1, 2, 3, 4, 5] + + +def test_gen_series_with_step(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [0]}) + result = df.select( + f.gen_series(literal(1), literal(10), literal(3)).alias("v") + ).collect() + assert result[0].column(0)[0].as_py() == [1, 4, 7, 10] From d07fdb3ef7d211920f40d0106fa50161c0bf20ce Mon Sep 17 00:00:00 2001 From: Tim 
Saucer Date: Mon, 6 Apr 2026 08:54:30 -0400 Subject: [PATCH 11/29] Add missing scalar functions (#1470) * Add missing scalar functions: get_field, union_extract, union_tag, arrow_metadata, version, row Expose upstream DataFusion scalar functions that were not yet available in the Python API. Closes #1453. - get_field: extracts a field from a struct or map by name - union_extract: extracts a value from a union type by field name - union_tag: returns the active field name of a union type - arrow_metadata: returns Arrow field metadata (all or by key) - version: returns the DataFusion version string - row: alias for the struct constructor Note: arrow_try_cast was listed in the issue but does not exist in DataFusion 53, so it is not included. Co-Authored-By: Claude Opus 4.6 (1M context) * Add tests for new scalar functions Tests for get_field, arrow_metadata, version, row, union_tag, and union_extract. Co-Authored-By: Claude Opus 4.6 (1M context) * Accept str for field name and type parameters in scalar functions Allow arrow_cast, get_field, and union_extract to accept plain str arguments instead of requiring Expr wrappers. Also improve arrow_metadata test coverage and fix parameter shadowing. Co-Authored-By: Claude Opus 4.6 (1M context) * Accept str for key parameter in arrow_metadata for consistency Co-Authored-By: Claude Opus 4.6 (1M context) * Add doctest examples and fix docstring style for new scalar functions Replace Args/Returns sections with doctest Examples blocks for arrow_metadata, get_field, union_extract, union_tag, and version to match existing codebase conventions. Simplify row to alias-style docstring with See Also reference. Document that arrow_cast accepts both str and Expr for data_type. Co-Authored-By: Claude Opus 4.6 (1M context) * Support pyarrow DataType in arrow_cast Allow arrow_cast to accept a pyarrow DataType in addition to str and Expr. The DataType is converted to its string representation before being passed to DataFusion. Adds test coverage for the new input type. Co-Authored-By: Claude Opus 4.6 (1M context) * Document bracket syntax shorthand in get_field docstring Note that expr["field"] is a convenient alternative when the field name is a static string, and get_field is needed for dynamic expressions. Add a second doctest example showing the bracket syntax. Co-Authored-By: Claude Opus 4.6 (1M context) * Fix arrow_cast with pyarrow DataType by delegating to Expr.cast Use the existing Rust-side PyArrowType conversion via Expr.cast() instead of str() which produces pyarrow type names that DataFusion does not recognize. 
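As a quick illustration of the accepted `data_type` forms described above, a sketch assuming the `arrow_cast` wrapper as updated in this patch (the context and data are illustrative):

```python
import pyarrow as pa
from datafusion import SessionContext, col
from datafusion import functions as f

ctx = SessionContext()
df = ctx.from_pydict({"a": [1, 2, 3]})

df.select(
    # DataFusion type-syntax string, wrapped into a string literal by the Python layer.
    f.arrow_cast(col("a"), "Float64").alias("as_float"),
    # pyarrow DataType, delegated to Expr.cast() as described above.
    f.arrow_cast(col("a"), pa.int32()).alias("as_int32"),
).show()
```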
Co-Authored-By: Claude Opus 4.6 (1M context) * Clarify when to use arrow_cast vs Expr.cast in docstring Co-Authored-By: Claude Opus 4.6 (1M context) --------- Co-authored-by: Claude Opus 4.6 (1M context) --- crates/core/src/functions.rs | 26 +++++ python/datafusion/functions.py | 174 ++++++++++++++++++++++++++++++++- python/tests/test_functions.py | 105 ++++++++++++++++++-- 3 files changed, 296 insertions(+), 9 deletions(-) diff --git a/crates/core/src/functions.rs b/crates/core/src/functions.rs index 8bb927718..74654ce46 100644 --- a/crates/core/src/functions.rs +++ b/crates/core/src/functions.rs @@ -695,8 +695,29 @@ expr_fn_vec!(named_struct); expr_fn!(from_unixtime, unixtime); expr_fn!(arrow_typeof, arg_1); expr_fn!(arrow_cast, arg_1 datatype); +expr_fn_vec!(arrow_metadata); +expr_fn!(union_tag, arg1); expr_fn!(random); +#[pyfunction] +fn get_field(expr: PyExpr, name: PyExpr) -> PyExpr { + functions::core::get_field() + .call(vec![expr.into(), name.into()]) + .into() +} + +#[pyfunction] +fn union_extract(union_expr: PyExpr, field_name: PyExpr) -> PyExpr { + functions::core::union_extract() + .call(vec![union_expr.into(), field_name.into()]) + .into() +} + +#[pyfunction] +fn version() -> PyExpr { + functions::core::version().call(vec![]).into() +} + // Array Functions array_fn!(array_append, array element); array_fn!(array_to_string, array delimiter); @@ -1014,6 +1035,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(array_agg))?; m.add_wrapped(wrap_pyfunction!(arrow_typeof))?; m.add_wrapped(wrap_pyfunction!(arrow_cast))?; + m.add_wrapped(wrap_pyfunction!(arrow_metadata))?; m.add_wrapped(wrap_pyfunction!(ascii))?; m.add_wrapped(wrap_pyfunction!(asin))?; m.add_wrapped(wrap_pyfunction!(asinh))?; @@ -1142,6 +1164,10 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(trim))?; m.add_wrapped(wrap_pyfunction!(trunc))?; m.add_wrapped(wrap_pyfunction!(upper))?; + m.add_wrapped(wrap_pyfunction!(get_field))?; + m.add_wrapped(wrap_pyfunction!(union_extract))?; + m.add_wrapped(wrap_pyfunction!(union_tag))?; + m.add_wrapped(wrap_pyfunction!(version))?; m.add_wrapped(wrap_pyfunction!(self::uuid))?; // Use self to avoid name collision m.add_wrapped(wrap_pyfunction!(var_pop))?; m.add_wrapped(wrap_pyfunction!(var_sample))?; diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 1b267731e..aa7f28746 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -98,6 +98,7 @@ "arrays_overlap", "arrays_zip", "arrow_cast", + "arrow_metadata", "arrow_typeof", "ascii", "asin", @@ -163,6 +164,7 @@ "gcd", "gen_series", "generate_series", + "get_field", "greatest", "ifnull", "in_list", @@ -280,6 +282,7 @@ "reverse", "right", "round", + "row", "row_number", "rpad", "rtrim", @@ -322,12 +325,15 @@ "translate", "trim", "trunc", + "union_extract", + "union_tag", "upper", "uuid", "var", "var_pop", "var_samp", "var_sample", + "version", "when", # Window Functions "window", @@ -2628,22 +2634,184 @@ def arrow_typeof(arg: Expr) -> Expr: return Expr(f.arrow_typeof(arg.expr)) -def arrow_cast(expr: Expr, data_type: Expr) -> Expr: +def arrow_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr: """Casts an expression to a specified data type. + The ``data_type`` can be a string, a ``pyarrow.DataType``, or an + ``Expr``. For simple types, :py:meth:`Expr.cast() + ` is more concise + (e.g., ``col("a").cast(pa.float64())``). 
Use ``arrow_cast`` when + you want to specify the target type as a string using DataFusion's + type syntax, which can be more readable for complex types like + ``"Timestamp(Nanosecond, None)"``. + Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [1]}) - >>> data_type = dfn.string_literal("Float64") >>> result = df.select( - ... dfn.functions.arrow_cast(dfn.col("a"), data_type).alias("c") + ... dfn.functions.arrow_cast(dfn.col("a"), "Float64").alias("c") + ... ) + >>> result.collect_column("c")[0].as_py() + 1.0 + + >>> import pyarrow as pa + >>> result = df.select( + ... dfn.functions.arrow_cast( + ... dfn.col("a"), data_type=pa.float64() + ... ).alias("c") ... ) >>> result.collect_column("c")[0].as_py() 1.0 """ + if isinstance(data_type, pa.DataType): + return expr.cast(data_type) + if isinstance(data_type, str): + data_type = Expr.string_literal(data_type) return Expr(f.arrow_cast(expr.expr, data_type.expr)) +def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr: + """Returns the metadata of the input expression. + + If called with one argument, returns a Map of all metadata key-value pairs. + If called with two arguments, returns the value for the specified metadata key. + + Examples: + >>> import pyarrow as pa + >>> field = pa.field("val", pa.int64(), metadata={"k": "v"}) + >>> schema = pa.schema([field]) + >>> batch = pa.RecordBatch.from_arrays([pa.array([1])], schema=schema) + >>> ctx = dfn.SessionContext() + >>> df = ctx.create_dataframe([[batch]]) + >>> result = df.select( + ... dfn.functions.arrow_metadata(dfn.col("val")).alias("meta") + ... ) + >>> ("k", "v") in result.collect_column("meta")[0].as_py() + True + + >>> result = df.select( + ... dfn.functions.arrow_metadata( + ... dfn.col("val"), key="k" + ... ).alias("meta_val") + ... ) + >>> result.collect_column("meta_val")[0].as_py() + 'v' + """ + if key is None: + return Expr(f.arrow_metadata(expr.expr)) + if isinstance(key, str): + key = Expr.string_literal(key) + return Expr(f.arrow_metadata(expr.expr, key.expr)) + + +def get_field(expr: Expr, name: Expr | str) -> Expr: + """Extracts a field from a struct or map by name. + + When the field name is a static string, the bracket operator + ``expr["field"]`` is a convenient shorthand. Use ``get_field`` + when the field name is a dynamic expression. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1], "b": [2]}) + >>> df = df.with_column( + ... "s", + ... dfn.functions.named_struct( + ... [("x", dfn.col("a")), ("y", dfn.col("b"))] + ... ), + ... ) + >>> result = df.select( + ... dfn.functions.get_field(dfn.col("s"), "x").alias("x_val") + ... ) + >>> result.collect_column("x_val")[0].as_py() + 1 + + Equivalent using bracket syntax: + + >>> result = df.select( + ... dfn.col("s")["x"].alias("x_val") + ... ) + >>> result.collect_column("x_val")[0].as_py() + 1 + """ + if isinstance(name, str): + name = Expr.string_literal(name) + return Expr(f.get_field(expr.expr, name.expr)) + + +def union_extract(union_expr: Expr, field_name: Expr | str) -> Expr: + """Extracts a value from a union type by field name. + + Returns the value of the named field if it is the currently selected + variant, otherwise returns NULL. + + Examples: + >>> import pyarrow as pa + >>> ctx = dfn.SessionContext() + >>> types = pa.array([0, 1, 0], type=pa.int8()) + >>> offsets = pa.array([0, 0, 1], type=pa.int32()) + >>> arr = pa.UnionArray.from_dense( + ... types, offsets, [pa.array([1, 2]), pa.array(["hi"])], + ... ["int", "str"], [0, 1], + ... 
) + >>> batch = pa.RecordBatch.from_arrays([arr], names=["u"]) + >>> df = ctx.create_dataframe([[batch]]) + >>> result = df.select( + ... dfn.functions.union_extract(dfn.col("u"), "int").alias("val") + ... ) + >>> result.collect_column("val").to_pylist() + [1, None, 2] + """ + if isinstance(field_name, str): + field_name = Expr.string_literal(field_name) + return Expr(f.union_extract(union_expr.expr, field_name.expr)) + + +def union_tag(union_expr: Expr) -> Expr: + """Returns the tag (active field name) of a union type. + + Examples: + >>> import pyarrow as pa + >>> ctx = dfn.SessionContext() + >>> types = pa.array([0, 1, 0], type=pa.int8()) + >>> offsets = pa.array([0, 0, 1], type=pa.int32()) + >>> arr = pa.UnionArray.from_dense( + ... types, offsets, [pa.array([1, 2]), pa.array(["hi"])], + ... ["int", "str"], [0, 1], + ... ) + >>> batch = pa.RecordBatch.from_arrays([arr], names=["u"]) + >>> df = ctx.create_dataframe([[batch]]) + >>> result = df.select( + ... dfn.functions.union_tag(dfn.col("u")).alias("tag") + ... ) + >>> result.collect_column("tag").to_pylist() + ['int', 'str', 'int'] + """ + return Expr(f.union_tag(union_expr.expr)) + + +def version() -> Expr: + """Returns the DataFusion version string. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.empty_table() + >>> result = df.select(dfn.functions.version().alias("v")) + >>> "Apache DataFusion" in result.collect_column("v")[0].as_py() + True + """ + return Expr(f.version()) + + +def row(*args: Expr) -> Expr: + """Returns a struct with the given arguments. + + See Also: + This is an alias for :py:func:`struct`. + """ + return struct(*args) + + def random() -> Expr: """Returns a random value in the range ``0.0 <= x < 1.0``. diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 2100da9ae..4e99fa9e3 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -20,7 +20,7 @@ import numpy as np import pyarrow as pa import pytest -from datafusion import SessionContext, column, literal, string_literal +from datafusion import SessionContext, column, literal from datafusion import functions as f np.seterr(invalid="ignore") @@ -1291,11 +1291,8 @@ def test_make_time(df): def test_arrow_cast(df): df = df.select( - # we use `string_literal` to return utf8 instead of `literal` which returns - # utf8view because datafusion.arrow_cast expects a utf8 instead of utf8view - # https://github.com/apache/datafusion/blob/86740bfd3d9831d6b7c1d0e1bf4a21d91598a0ac/datafusion/functions/src/core/arrow_cast.rs#L179 - f.arrow_cast(column("b"), string_literal("Float64")).alias("b_as_float"), - f.arrow_cast(column("b"), string_literal("Int32")).alias("b_as_int"), + f.arrow_cast(column("b"), "Float64").alias("b_as_float"), + f.arrow_cast(column("b"), "Int32").alias("b_as_int"), ) result = df.collect() assert len(result) == 1 @@ -1305,6 +1302,19 @@ def test_arrow_cast(df): assert result.column(1) == pa.array([4, 5, 6], type=pa.int32()) +def test_arrow_cast_with_pyarrow_type(df): + df = df.select( + f.arrow_cast(column("b"), pa.float64()).alias("b_as_float"), + f.arrow_cast(column("b"), pa.int32()).alias("b_as_int"), + f.arrow_cast(column("b"), pa.string()).alias("b_as_str"), + ) + result = df.collect()[0] + + assert result.column(0) == pa.array([4.0, 5.0, 6.0], type=pa.float64()) + assert result.column(1) == pa.array([4, 5, 6], type=pa.int32()) + assert result.column(2) == pa.array(["4", "5", "6"], type=pa.string()) + + def test_case(df): df = df.select( f.case(column("b")).when(literal(4), 
literal(10)).otherwise(literal(8)), @@ -1810,6 +1820,89 @@ def test_conditional_functions(df_with_nulls, expr, expected): assert result.column(0) == expected +def test_get_field(df): + df = df.with_column( + "s", + f.named_struct( + [ + ("x", column("a")), + ("y", column("b")), + ] + ), + ) + result = df.select( + f.get_field(column("s"), "x").alias("x_val"), + f.get_field(column("s"), "y").alias("y_val"), + ).collect()[0] + + assert result.column(0) == pa.array(["Hello", "World", "!"], type=pa.string_view()) + assert result.column(1) == pa.array([4, 5, 6]) + + +def test_arrow_metadata(): + ctx = SessionContext() + field = pa.field("val", pa.int64(), metadata={"key1": "value1", "key2": "value2"}) + schema = pa.schema([field]) + batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], schema=schema) + df = ctx.create_dataframe([[batch]]) + + # One-argument form: returns a Map of all metadata key-value pairs + result = df.select( + f.arrow_metadata(column("val")).alias("meta"), + ).collect()[0] + assert result.column(0).type == pa.map_(pa.utf8(), pa.utf8()) + meta = result.column(0)[0].as_py() + assert ("key1", "value1") in meta + assert ("key2", "value2") in meta + + # Two-argument form: returns the value for a specific metadata key + result = df.select( + f.arrow_metadata(column("val"), "key1").alias("meta_val"), + ).collect()[0] + assert result.column(0)[0].as_py() == "value1" + + +def test_version(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1]}) + result = df.select(f.version().alias("v")).collect()[0] + version_str = result.column(0)[0].as_py() + assert "Apache DataFusion" in version_str + + +def test_row(df): + result = df.select( + f.row(column("a"), column("b")).alias("r"), + f.struct(column("a"), column("b")).alias("s"), + ).collect()[0] + # row is an alias for struct, so they should produce the same output + assert result.column(0) == result.column(1) + + +def test_union_tag(): + ctx = SessionContext() + types = pa.array([0, 1, 0], type=pa.int8()) + offsets = pa.array([0, 0, 1], type=pa.int32()) + children = [pa.array([1, 2]), pa.array(["hello"])] + arr = pa.UnionArray.from_dense(types, offsets, children, ["int", "str"], [0, 1]) + df = ctx.create_dataframe([[pa.RecordBatch.from_arrays([arr], names=["u"])]]) + + result = df.select(f.union_tag(column("u")).alias("tag")).collect()[0] + assert result.column(0).to_pylist() == ["int", "str", "int"] + + +def test_union_extract(): + ctx = SessionContext() + types = pa.array([0, 1, 0], type=pa.int8()) + offsets = pa.array([0, 0, 1], type=pa.int32()) + children = [pa.array([1, 2]), pa.array(["hello"])] + arr = pa.UnionArray.from_dense(types, offsets, children, ["int", "str"], [0, 1]) + df = ctx.create_dataframe([[pa.RecordBatch.from_arrays([arr], names=["u"])]]) + + result = df.select(f.union_extract(column("u"), "int").alias("val")).collect()[0] + assert result.column(0).to_pylist() == [1, None, 2] + + @pytest.mark.parametrize("func", [f.array_any_value, f.list_any_value]) def test_any_value_aliases(func): ctx = SessionContext() From 898d73de20346bba7241907bb18cba47da53e9a9 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 7 Apr 2026 09:01:36 -0400 Subject: [PATCH 12/29] Add missing aggregate functions (#1471) * Add missing aggregate functions: grouping, percentile_cont, var_population Expose upstream DataFusion aggregate functions that were not yet available in the Python API. Closes #1454. 
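A minimal usage sketch of the grouping support listed below, assuming the Python API added in this patch (the data and column names are illustrative):

```python
from datafusion import SessionContext, col
from datafusion import functions as f
from datafusion.expr import GroupingSet

ctx = SessionContext()
df = ctx.from_pydict({"type_1": ["Grass", "Grass", "Fire"], "speed": [45, 60, 65]})

# ROLLUP yields one row per type_1 plus a grand-total row where type_1 is null.
# grouping() tells the two apart: 0 when the column is a grouping key for the
# row, 1 when it was aggregated across. Note that .alias() cannot yet be applied
# directly to grouping() (apache/datafusion#21411).
df.aggregate(
    [GroupingSet.rollup(col("type_1"))],  # string column names are also accepted
    [f.count(col("speed")).alias("count"), f.grouping(col("type_1"))],
).show()
```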
- grouping: returns grouping set membership indicator (rewritten by the ResolveGroupingFunction analyzer rule before physical planning) - percentile_cont: computes exact percentile using continuous interpolation (unlike approx_percentile_cont which uses t-digest) - var_population: alias for var_pop Co-Authored-By: Claude Opus 4.6 (1M context) * Fix grouping() distinct parameter type for API consistency Co-Authored-By: Claude Opus 4.6 (1M context) * Improve aggregate function tests and docstrings per review feedback Add docstring example to grouping(), parametrize percentile_cont tests, and add multi-column grouping test case. Co-Authored-By: Claude Opus 4.6 (1M context) * Add GroupingSet.rollup, .cube, and .grouping_sets factory methods Expose ROLLUP, CUBE, and GROUPING SETS via the DataFrame API by adding static methods on GroupingSet that construct the corresponding Expr variants. Update grouping() docstring and tests to use the new API. Co-Authored-By: Claude Opus 4.6 (1M context) * Remove _GroupingSetInternal alias, use expr_internal.GroupingSet directly Co-Authored-By: Claude Opus 4.6 (1M context) * Parametrize grouping set tests for rollup and cube Co-Authored-By: Claude Opus 4.6 (1M context) * Add grouping sets documentation and note grouping() alias limitation Add user documentation for GroupingSet.rollup, .cube, and .grouping_sets with Pokemon dataset examples. Document the upstream alias limitation (apache/datafusion#21411) in both the grouping() docstring and the aggregation user guide. Co-Authored-By: Claude Opus 4.6 (1M context) * Add grouping sets note to DataFrame.aggregate() docstring Co-Authored-By: Claude Opus 4.6 (1M context) * Address PR review feedback: add quantile_cont alias and simplify examples - Add quantile_cont as alias for percentile_cont (matches upstream) - Replace pa.concat_arrays batch pattern with collect_column() in docstrings - Add percentile_cont, quantile_cont, var_population to docs function list Co-Authored-By: Claude Opus 4.6 (1M context) * Accept string column names in GroupingSet factory methods GroupingSet.rollup(), .cube(), and .grouping_sets() now accept both Expr objects and string column names, consistent with DataFrame.aggregate(). Co-Authored-By: Claude Opus 4.6 (1M context) * Add agent instructions to keep aggregation/window docs in sync Co-Authored-By: Claude Opus 4.6 (1M context) * dfn is already available globally * Remove unnecessary import on doctest --------- Co-authored-by: Claude Opus 4.6 (1M context) --- AGENTS.md | 12 ++ crates/core/src/expr/grouping_set.rs | 37 +++- crates/core/src/functions.rs | 23 ++- .../common-operations/aggregations.rst | 172 +++++++++++++++++- python/datafusion/dataframe.py | 16 +- python/datafusion/expr.py | 127 ++++++++++++- python/datafusion/functions.py | 130 ++++++++++++- python/tests/test_functions.py | 109 +++++++++++ 8 files changed, 614 insertions(+), 12 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index f6fdfbd90..86c2e9c3b 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -42,3 +42,15 @@ Every Python function must include a docstring with usage examples. - **Alias functions**: Functions that are simple aliases (e.g., `list_sort` aliasing `array_sort`) only need a one-line description and a `See Also` reference to the primary function. They do not need their own examples. 
+ +## Aggregate and Window Function Documentation + +When adding or updating an aggregate or window function, ensure the corresponding +site documentation is kept in sync: + +- **Aggregations**: `docs/source/user-guide/common-operations/aggregations.rst` — + add new aggregate functions to the "Aggregate Functions" list and include usage + examples if appropriate. +- **Window functions**: `docs/source/user-guide/common-operations/windows.rst` — + add new window functions to the "Available Functions" list and include usage + examples if appropriate. diff --git a/crates/core/src/expr/grouping_set.rs b/crates/core/src/expr/grouping_set.rs index 549a866ed..11d8f4fcd 100644 --- a/crates/core/src/expr/grouping_set.rs +++ b/crates/core/src/expr/grouping_set.rs @@ -15,9 +15,11 @@ // specific language governing permissions and limitations // under the License. -use datafusion::logical_expr::GroupingSet; +use datafusion::logical_expr::{Expr, GroupingSet}; use pyo3::prelude::*; +use crate::expr::PyExpr; + #[pyclass( from_py_object, frozen, @@ -30,6 +32,39 @@ pub struct PyGroupingSet { grouping_set: GroupingSet, } +#[pymethods] +impl PyGroupingSet { + #[staticmethod] + #[pyo3(signature = (*exprs))] + fn rollup(exprs: Vec) -> PyExpr { + Expr::GroupingSet(GroupingSet::Rollup( + exprs.into_iter().map(|e| e.expr).collect(), + )) + .into() + } + + #[staticmethod] + #[pyo3(signature = (*exprs))] + fn cube(exprs: Vec) -> PyExpr { + Expr::GroupingSet(GroupingSet::Cube( + exprs.into_iter().map(|e| e.expr).collect(), + )) + .into() + } + + #[staticmethod] + #[pyo3(signature = (*expr_lists))] + fn grouping_sets(expr_lists: Vec>) -> PyExpr { + Expr::GroupingSet(GroupingSet::GroupingSets( + expr_lists + .into_iter() + .map(|list| list.into_iter().map(|e| e.expr).collect()) + .collect(), + )) + .into() + } +} + impl From for GroupingSet { fn from(grouping_set: PyGroupingSet) -> Self { grouping_set.grouping_set diff --git a/crates/core/src/functions.rs b/crates/core/src/functions.rs index 74654ce46..f173aaa51 100644 --- a/crates/core/src/functions.rs +++ b/crates/core/src/functions.rs @@ -791,9 +791,10 @@ aggregate_function!(var_pop); aggregate_function!(approx_distinct); aggregate_function!(approx_median); -// Code is commented out since grouping is not yet implemented -// https://github.com/apache/datafusion-python/issues/861 -// aggregate_function!(grouping); +// The grouping function's physical plan is not implemented, but the +// ResolveGroupingFunction analyzer rule rewrites it before the physical +// planner sees it, so it works correctly at runtime. 
+aggregate_function!(grouping); #[pyfunction] #[pyo3(signature = (sort_expression, percentile, num_centroids=None, filter=None))] @@ -831,6 +832,19 @@ pub fn approx_percentile_cont_with_weight( add_builder_fns_to_aggregate(agg_fn, None, filter, None, None) } +#[pyfunction] +#[pyo3(signature = (sort_expression, percentile, filter=None))] +pub fn percentile_cont( + sort_expression: PySortExpr, + percentile: f64, + filter: Option, +) -> PyDataFusionResult { + let agg_fn = + functions_aggregate::expr_fn::percentile_cont(sort_expression.sort, lit(percentile)); + + add_builder_fns_to_aggregate(agg_fn, None, filter, None, None) +} + // We handle last_value explicitly because the signature expects an order_by // https://github.com/apache/datafusion/issues/12376 #[pyfunction] @@ -1031,6 +1045,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(approx_median))?; m.add_wrapped(wrap_pyfunction!(approx_percentile_cont))?; m.add_wrapped(wrap_pyfunction!(approx_percentile_cont_with_weight))?; + m.add_wrapped(wrap_pyfunction!(percentile_cont))?; m.add_wrapped(wrap_pyfunction!(range))?; m.add_wrapped(wrap_pyfunction!(array_agg))?; m.add_wrapped(wrap_pyfunction!(arrow_typeof))?; @@ -1080,7 +1095,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(from_unixtime))?; m.add_wrapped(wrap_pyfunction!(gcd))?; m.add_wrapped(wrap_pyfunction!(greatest))?; - // m.add_wrapped(wrap_pyfunction!(grouping))?; + m.add_wrapped(wrap_pyfunction!(grouping))?; m.add_wrapped(wrap_pyfunction!(in_list))?; m.add_wrapped(wrap_pyfunction!(initcap))?; m.add_wrapped(wrap_pyfunction!(isnan))?; diff --git a/docs/source/user-guide/common-operations/aggregations.rst b/docs/source/user-guide/common-operations/aggregations.rst index e458e5fcb..de24a2ba5 100644 --- a/docs/source/user-guide/common-operations/aggregations.rst +++ b/docs/source/user-guide/common-operations/aggregations.rst @@ -163,6 +163,168 @@ Suppose we want to find the speed values for only Pokemon that have low Attack v f.avg(col_speed, filter=col_attack < lit(50)).alias("Avg Speed Low Attack")]) +Grouping Sets +------------- + +The default style of aggregation produces one row per group. Sometimes you want a single query to +produce rows at multiple levels of detail — for example, totals per type *and* an overall grand +total, or subtotals for every combination of two columns plus the individual column totals. Writing +separate queries and concatenating them is tedious and runs the data multiple times. Grouping sets +solve this by letting you specify several grouping levels in one pass. + +DataFusion supports three grouping set styles through the +:py:class:`~datafusion.expr.GroupingSet` class: + +- :py:meth:`~datafusion.expr.GroupingSet.rollup` — hierarchical subtotals, like a drill-down report +- :py:meth:`~datafusion.expr.GroupingSet.cube` — every possible subtotal combination, like a pivot table +- :py:meth:`~datafusion.expr.GroupingSet.grouping_sets` — explicitly list exactly which grouping levels you want + +Because result rows come from different grouping levels, a column that is *not* part of a +particular level will be ``null`` in that row. Use :py:func:`~datafusion.functions.grouping` to +distinguish a real ``null`` in the data from one that means "this column was aggregated across." +It returns ``0`` when the column is a grouping key for that row, and ``1`` when it is not. + +Rollup +^^^^^^ + +:py:meth:`~datafusion.expr.GroupingSet.rollup` creates a hierarchy. 
``rollup(a, b)`` produces +grouping sets ``(a, b)``, ``(a)``, and ``()`` — like nested subtotals in a report. This is useful +when your columns have a natural hierarchy, such as region → city or type → subtype. + +Suppose we want to summarize Pokemon stats by ``Type 1`` with subtotals and a grand total. With +the default aggregation style we would need two separate queries. With ``rollup`` we get it all at +once: + +.. ipython:: python + + from datafusion.expr import GroupingSet + + df.aggregate( + [GroupingSet.rollup(col_type_1)], + [f.count(col_speed).alias("Count"), + f.avg(col_speed).alias("Avg Speed"), + f.max(col_speed).alias("Max Speed")] + ).sort(col_type_1.sort(ascending=True, nulls_first=True)) + +The first row — where ``Type 1`` is ``null`` — is the grand total across all types. But how do you +tell a grand-total ``null`` apart from a Pokemon that genuinely has no type? The +:py:func:`~datafusion.functions.grouping` function returns ``0`` when the column is a grouping key +for that row and ``1`` when it is aggregated across. + +.. note:: + + Due to an upstream DataFusion limitation + (`apache/datafusion#21411 `_), + ``.alias()`` cannot be applied directly to a ``grouping()`` expression — it will raise an + error at execution time. Instead, use + :py:meth:`~datafusion.dataframe.DataFrame.with_column_renamed` on the result DataFrame to + give the column a readable name. Once the upstream issue is resolved, you will be able to + use ``.alias()`` directly and the workaround below will no longer be necessary. + +The raw column name generated by ``grouping()`` contains internal identifiers, so we use +:py:meth:`~datafusion.dataframe.DataFrame.with_column_renamed` to clean it up: + +.. ipython:: python + + result = df.aggregate( + [GroupingSet.rollup(col_type_1)], + [f.count(col_speed).alias("Count"), + f.avg(col_speed).alias("Avg Speed"), + f.grouping(col_type_1)] + ) + for field in result.schema(): + if field.name.startswith("grouping("): + result = result.with_column_renamed(field.name, "Is Total") + result.sort(col_type_1.sort(ascending=True, nulls_first=True)) + +With two columns the hierarchy becomes more apparent. ``rollup(Type 1, Type 2)`` produces: + +- one row per ``(Type 1, Type 2)`` pair — the most detailed level +- one row per ``Type 1`` — subtotals +- one grand total row + +.. ipython:: python + + df.aggregate( + [GroupingSet.rollup(col_type_1, col_type_2)], + [f.count(col_speed).alias("Count"), + f.avg(col_speed).alias("Avg Speed")] + ).sort( + col_type_1.sort(ascending=True, nulls_first=True), + col_type_2.sort(ascending=True, nulls_first=True) + ) + +Cube +^^^^ + +:py:meth:`~datafusion.expr.GroupingSet.cube` produces every possible subset. ``cube(a, b)`` +produces grouping sets ``(a, b)``, ``(a)``, ``(b)``, and ``()`` — one more than ``rollup`` because +it also includes ``(b)`` alone. This is useful when neither column is "above" the other in a +hierarchy and you want all cross-tabulations. + +For our Pokemon data, ``cube(Type 1, Type 2)`` gives us stats broken down by the type pair, +by ``Type 1`` alone, by ``Type 2`` alone, and a grand total — all in one query: + +.. 
ipython:: python + + df.aggregate( + [GroupingSet.cube(col_type_1, col_type_2)], + [f.count(col_speed).alias("Count"), + f.avg(col_speed).alias("Avg Speed")] + ).sort( + col_type_1.sort(ascending=True, nulls_first=True), + col_type_2.sort(ascending=True, nulls_first=True) + ) + +Compared to the ``rollup`` example above, notice the extra rows where ``Type 1`` is ``null`` but +``Type 2`` has a value — those are the per-``Type 2`` subtotals that ``rollup`` does not include. + +Explicit Grouping Sets +^^^^^^^^^^^^^^^^^^^^^^ + +:py:meth:`~datafusion.expr.GroupingSet.grouping_sets` lets you list exactly which grouping levels +you need when ``rollup`` or ``cube`` would produce too many or too few. Each argument is a list of +columns forming one grouping set. + +For example, if we want only the per-``Type 1`` totals and per-``Type 2`` totals — but *not* the +full ``(Type 1, Type 2)`` detail rows or the grand total — we can ask for exactly that: + +.. ipython:: python + + df.aggregate( + [GroupingSet.grouping_sets([col_type_1], [col_type_2])], + [f.count(col_speed).alias("Count"), + f.avg(col_speed).alias("Avg Speed")] + ).sort( + col_type_1.sort(ascending=True, nulls_first=True), + col_type_2.sort(ascending=True, nulls_first=True) + ) + +Each row belongs to exactly one grouping level. The :py:func:`~datafusion.functions.grouping` +function tells you which level each row comes from: + +.. ipython:: python + + result = df.aggregate( + [GroupingSet.grouping_sets([col_type_1], [col_type_2])], + [f.count(col_speed).alias("Count"), + f.avg(col_speed).alias("Avg Speed"), + f.grouping(col_type_1), + f.grouping(col_type_2)] + ) + for field in result.schema(): + if field.name.startswith("grouping("): + clean = field.name.split(".")[-1].rstrip(")") + result = result.with_column_renamed(field.name, f"grouping({clean})") + result.sort( + col_type_1.sort(ascending=True, nulls_first=True), + col_type_2.sort(ascending=True, nulls_first=True) + ) + +Where ``grouping(Type 1)`` is ``0`` the row is a per-``Type 1`` total (and ``Type 2`` is ``null``). +Where ``grouping(Type 2)`` is ``0`` the row is a per-``Type 2`` total (and ``Type 1`` is ``null``). + + Aggregate Functions ------------------- @@ -192,6 +354,7 @@ The available aggregate functions are: - :py:func:`datafusion.functions.stddev_pop` - :py:func:`datafusion.functions.var_samp` - :py:func:`datafusion.functions.var_pop` + - :py:func:`datafusion.functions.var_population` 6. Linear Regression Functions - :py:func:`datafusion.functions.regr_count` - :py:func:`datafusion.functions.regr_slope` @@ -208,9 +371,16 @@ The available aggregate functions are: - :py:func:`datafusion.functions.nth_value` 8. String Functions - :py:func:`datafusion.functions.string_agg` -9. Approximation Functions +9. Percentile Functions + - :py:func:`datafusion.functions.percentile_cont` + - :py:func:`datafusion.functions.quantile_cont` - :py:func:`datafusion.functions.approx_distinct` - :py:func:`datafusion.functions.approx_median` - :py:func:`datafusion.functions.approx_percentile_cont` - :py:func:`datafusion.functions.approx_percentile_cont_with_weight` +10. 
Grouping Set Functions + - :py:func:`datafusion.functions.grouping` + - :py:meth:`datafusion.expr.GroupingSet.rollup` + - :py:meth:`datafusion.expr.GroupingSet.cube` + - :py:meth:`datafusion.expr.GroupingSet.grouping_sets` diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 10e2a913f..9907eae8b 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -633,8 +633,22 @@ def aggregate( ) -> DataFrame: """Aggregates the rows of the current DataFrame. + By default each unique combination of the ``group_by`` columns + produces one row. To get multiple levels of subtotals in a + single pass, pass a + :py:class:`~datafusion.expr.GroupingSet` expression + (created via + :py:meth:`~datafusion.expr.GroupingSet.rollup`, + :py:meth:`~datafusion.expr.GroupingSet.cube`, or + :py:meth:`~datafusion.expr.GroupingSet.grouping_sets`) + as the ``group_by`` argument. See the + :ref:`aggregation` user guide for detailed examples. + Args: - group_by: Sequence of expressions or column names to group by. + group_by: Sequence of expressions or column names to group + by. A :py:class:`~datafusion.expr.GroupingSet` + expression may be included to produce multiple grouping + levels (rollup, cube, or explicit grouping sets). aggs: Sequence of expressions to aggregate. Returns: diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 14753a4f5..35388468c 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -91,7 +91,6 @@ Extension = expr_internal.Extension FileType = expr_internal.FileType Filter = expr_internal.Filter -GroupingSet = expr_internal.GroupingSet Join = expr_internal.Join ILike = expr_internal.ILike InList = expr_internal.InList @@ -1430,3 +1429,129 @@ def __repr__(self) -> str: SortKey = Expr | SortExpr | str + + +class GroupingSet: + """Factory for creating grouping set expressions. + + Grouping sets control how + :py:meth:`~datafusion.dataframe.DataFrame.aggregate` groups rows. + Instead of a single ``GROUP BY``, they produce multiple grouping + levels in one pass — subtotals, cross-tabulations, or arbitrary + column subsets. + + Use :py:func:`~datafusion.functions.grouping` in the aggregate list + to tell which columns are aggregated across in each result row. + """ + + @staticmethod + def rollup(*exprs: Expr | str) -> Expr: + """Create a ``ROLLUP`` grouping set for use with ``aggregate()``. + + ``ROLLUP`` generates all prefixes of the given column list as + grouping sets. For example, ``rollup(a, b)`` produces grouping + sets ``(a, b)``, ``(a)``, and ``()`` (grand total). + + This is equivalent to ``GROUP BY ROLLUP(a, b)`` in SQL. + + Args: + *exprs: Column expressions or column name strings to + include in the rollup. + + Examples: + >>> from datafusion.expr import GroupingSet + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]}) + >>> result = df.aggregate( + ... [GroupingSet.rollup(dfn.col("a"))], + ... [dfn.functions.sum(dfn.col("b")).alias("s"), + ... dfn.functions.grouping(dfn.col("a"))], + ... ).sort(dfn.col("a").sort(nulls_first=False)) + >>> result.collect_column("s").to_pylist() + [30, 30, 60] + + See Also: + :py:meth:`cube`, :py:meth:`grouping_sets`, + :py:func:`~datafusion.functions.grouping` + """ + args = [_to_raw_expr(e) for e in exprs] + return Expr(expr_internal.GroupingSet.rollup(*args)) + + @staticmethod + def cube(*exprs: Expr | str) -> Expr: + """Create a ``CUBE`` grouping set for use with ``aggregate()``. 
+ + ``CUBE`` generates all possible subsets of the given column list + as grouping sets. For example, ``cube(a, b)`` produces grouping + sets ``(a, b)``, ``(a)``, ``(b)``, and ``()`` (grand total). + + This is equivalent to ``GROUP BY CUBE(a, b)`` in SQL. + + Args: + *exprs: Column expressions or column name strings to + include in the cube. + + Examples: + With a single column, ``cube`` behaves identically to + :py:meth:`rollup`: + + >>> from datafusion.expr import GroupingSet + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]}) + >>> result = df.aggregate( + ... [GroupingSet.cube(dfn.col("a"))], + ... [dfn.functions.sum(dfn.col("b")).alias("s"), + ... dfn.functions.grouping(dfn.col("a"))], + ... ).sort(dfn.col("a").sort(nulls_first=False)) + >>> result.collect_column("s").to_pylist() + [30, 30, 60] + + See Also: + :py:meth:`rollup`, :py:meth:`grouping_sets`, + :py:func:`~datafusion.functions.grouping` + """ + args = [_to_raw_expr(e) for e in exprs] + return Expr(expr_internal.GroupingSet.cube(*args)) + + @staticmethod + def grouping_sets(*expr_lists: list[Expr | str]) -> Expr: + """Create explicit grouping sets for use with ``aggregate()``. + + Each argument is a list of column expressions or column name + strings representing one grouping set. For example, + ``grouping_sets([a], [b])`` groups by ``a`` alone and by ``b`` + alone in a single query. + + This is equivalent to ``GROUP BY GROUPING SETS ((a), (b))`` in + SQL. + + Args: + *expr_lists: Each positional argument is a list of + expressions or column name strings forming one + grouping set. + + Examples: + >>> from datafusion.expr import GroupingSet + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict( + ... {"a": ["x", "x", "y"], "b": ["m", "n", "m"], + ... "c": [1, 2, 3]}) + >>> result = df.aggregate( + ... [GroupingSet.grouping_sets( + ... [dfn.col("a")], [dfn.col("b")])], + ... [dfn.functions.sum(dfn.col("c")).alias("s"), + ... dfn.functions.grouping(dfn.col("a")), + ... dfn.functions.grouping(dfn.col("b"))], + ... ).sort( + ... dfn.col("a").sort(nulls_first=False), + ... dfn.col("b").sort(nulls_first=False), + ... ) + >>> result.collect_column("s").to_pylist() + [3, 3, 4, 2] + + See Also: + :py:meth:`rollup`, :py:meth:`cube`, + :py:func:`~datafusion.functions.grouping` + """ + raw_lists = [[_to_raw_expr(e) for e in lst] for lst in expr_lists] + return Expr(expr_internal.GroupingSet.grouping_sets(*raw_lists)) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index aa7f28746..9dfabb62d 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -166,6 +166,7 @@ "generate_series", "get_field", "greatest", + "grouping", "ifnull", "in_list", "initcap", @@ -256,9 +257,11 @@ "order_by", "overlay", "percent_rank", + "percentile_cont", "pi", "pow", "power", + "quantile_cont", "radians", "random", "range", @@ -331,6 +334,7 @@ "uuid", "var", "var_pop", + "var_population", "var_samp", "var_sample", "version", @@ -2654,7 +2658,6 @@ def arrow_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr: >>> result.collect_column("c")[0].as_py() 1.0 - >>> import pyarrow as pa >>> result = df.select( ... dfn.functions.arrow_cast( ... dfn.col("a"), data_type=pa.float64() @@ -2677,7 +2680,6 @@ def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr: If called with two arguments, returns the value for the specified metadata key. 
Examples: - >>> import pyarrow as pa >>> field = pa.field("val", pa.int64(), metadata={"k": "v"}) >>> schema = pa.schema([field]) >>> batch = pa.RecordBatch.from_arrays([pa.array([1])], schema=schema) @@ -2746,7 +2748,6 @@ def union_extract(union_expr: Expr, field_name: Expr | str) -> Expr: variant, otherwise returns NULL. Examples: - >>> import pyarrow as pa >>> ctx = dfn.SessionContext() >>> types = pa.array([0, 1, 0], type=pa.int8()) >>> offsets = pa.array([0, 0, 1], type=pa.int32()) @@ -2771,7 +2772,6 @@ def union_tag(union_expr: Expr) -> Expr: """Returns the tag (active field name) of a union type. Examples: - >>> import pyarrow as pa >>> ctx = dfn.SessionContext() >>> types = pa.array([0, 1, 0], type=pa.int8()) >>> offsets = pa.array([0, 0, 1], type=pa.int32()) @@ -4306,6 +4306,60 @@ def approx_percentile_cont_with_weight( ) +def percentile_cont( + sort_expression: Expr | SortExpr, + percentile: float, + filter: Expr | None = None, +) -> Expr: + """Computes the exact percentile of input values using continuous interpolation. + + Unlike :py:func:`approx_percentile_cont`, this function computes the exact + percentile value rather than an approximation. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + + Args: + sort_expression: Values for which to find the percentile + percentile: This must be between 0.0 and 1.0, inclusive + filter: If provided, only compute against rows for which the filter is True + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1.0, 2.0, 3.0, 4.0, 5.0]}) + >>> result = df.aggregate( + ... [], [dfn.functions.percentile_cont( + ... dfn.col("a"), 0.5 + ... ).alias("v")]) + >>> result.collect_column("v")[0].as_py() + 3.0 + + >>> result = df.aggregate( + ... [], [dfn.functions.percentile_cont( + ... dfn.col("a"), 0.5, + ... filter=dfn.col("a") > dfn.lit(1.0), + ... ).alias("v")]) + >>> result.collect_column("v")[0].as_py() + 3.5 + """ + sort_expr_raw = sort_or_default(sort_expression) + filter_raw = filter.expr if filter is not None else None + return Expr(f.percentile_cont(sort_expr_raw, percentile, filter=filter_raw)) + + +def quantile_cont( + sort_expression: Expr | SortExpr, + percentile: float, + filter: Expr | None = None, +) -> Expr: + """Computes the exact percentile of input values using continuous interpolation. + + See Also: + This is an alias for :py:func:`percentile_cont`. + """ + return percentile_cont(sort_expression, percentile, filter) + + def array_agg( expression: Expr, distinct: bool = False, @@ -4364,6 +4418,65 @@ def array_agg( ) +def grouping( + expression: Expr, + distinct: bool = False, + filter: Expr | None = None, +) -> Expr: + """Indicates whether a column is aggregated across in the current row. + + Returns 0 when the column is part of the grouping key for that row + (i.e., the row contains per-group results for that column). Returns 1 + when the column is *not* part of the grouping key (i.e., the row's + aggregate spans all values of that column). + + This function is meaningful with + :py:meth:`GroupingSet.rollup `, + :py:meth:`GroupingSet.cube `, or + :py:meth:`GroupingSet.grouping_sets `, + where different rows are grouped by different subsets of columns. In a + default aggregation without grouping sets every column is always part + of the key, so ``grouping()`` always returns 0. + + .. 
warning:: + + Due to an upstream DataFusion limitation + (`#21411 `_), + ``.alias()`` cannot be applied directly to a ``grouping()`` + expression. Doing so will raise an error at execution time. To + rename the column, use + :py:meth:`~datafusion.dataframe.DataFrame.with_column_renamed` + on the result DataFrame instead. + + Args: + expression: The column to check grouping status for + distinct: If True, compute on distinct values only + filter: If provided, only compute against rows for which the filter is True + + Examples: + With :py:meth:`~datafusion.expr.GroupingSet.rollup`, the result + includes both per-group rows (``grouping(a) = 0``) and a + grand-total row where ``a`` is aggregated across + (``grouping(a) = 1``): + + >>> from datafusion.expr import GroupingSet + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]}) + >>> result = df.aggregate( + ... [GroupingSet.rollup(dfn.col("a"))], + ... [dfn.functions.sum(dfn.col("b")).alias("s"), + ... dfn.functions.grouping(dfn.col("a"))], + ... ).sort(dfn.col("a").sort(nulls_first=False)) + >>> result.collect_column("s").to_pylist() + [30, 30, 60] + + See Also: + :py:class:`~datafusion.expr.GroupingSet` + """ + filter_raw = filter.expr if filter is not None else None + return Expr(f.grouping(expression.expr, distinct=distinct, filter=filter_raw)) + + def avg( expression: Expr, filter: Expr | None = None, @@ -4835,6 +4948,15 @@ def var_pop(expression: Expr, filter: Expr | None = None) -> Expr: return Expr(f.var_pop(expression.expr, filter=filter_raw)) +def var_population(expression: Expr, filter: Expr | None = None) -> Expr: + """Computes the population variance of the argument. + + See Also: + This is an alias for :py:func:`var_pop`. + """ + return var_pop(expression, filter) + + def var_samp(expression: Expr, filter: Expr | None = None) -> Expr: """Computes the sample variance of the argument. 
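The functions.py additions above (percentile_cont, quantile_cont, grouping, var_population) each carry their own docstring example; the sketch below is a minimal, hedged illustration of how two of them might be combined in a single aggregation pass, assuming the APIs land exactly as defined in this patch. The table name, column names, and data are invented for illustration only.

    from datafusion import SessionContext, col
    from datafusion import functions as f
    from datafusion.expr import GroupingSet

    ctx = SessionContext()
    df = ctx.from_pydict({"type": ["a", "a", "b"], "speed": [10.0, 30.0, 20.0]})

    # One row per type plus a grand-total row; grouping() returns 1 on the
    # total row. It is deliberately left un-aliased because of the upstream
    # limitation documented in its docstring above.
    result = df.aggregate(
        [GroupingSet.rollup(col("type"))],
        [
            f.percentile_cont(col("speed"), 0.5).alias("median_speed"),
            f.grouping(col("type")),
        ],
    )
    result.show()
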
diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 4e99fa9e3..11e94af1c 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -22,6 +22,7 @@ import pytest from datafusion import SessionContext, column, literal from datafusion import functions as f +from datafusion.expr import GroupingSet np.seterr(invalid="ignore") @@ -1820,6 +1821,114 @@ def test_conditional_functions(df_with_nulls, expr, expected): assert result.column(0) == expected +@pytest.mark.parametrize( + ("func", "filter_expr", "expected"), + [ + (f.percentile_cont, None, 3.0), + (f.percentile_cont, column("a") > literal(1.0), 3.5), + (f.quantile_cont, None, 3.0), + ], + ids=["no_filter", "with_filter", "quantile_cont_alias"], +) +def test_percentile_cont(func, filter_expr, expected): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1.0, 2.0, 3.0, 4.0, 5.0]}) + result = df.aggregate( + [], [func(column("a"), 0.5, filter=filter_expr).alias("v")] + ).collect()[0] + assert result.column(0)[0].as_py() == expected + + +@pytest.mark.parametrize( + ("grouping_set_expr", "expected_grouping", "expected_sums"), + [ + (GroupingSet.rollup(column("a")), [0, 0, 1], [30, 30, 60]), + (GroupingSet.cube(column("a")), [0, 0, 1], [30, 30, 60]), + (GroupingSet.rollup("a"), [0, 0, 1], [30, 30, 60]), + (GroupingSet.cube("a"), [0, 0, 1], [30, 30, 60]), + ], + ids=["rollup", "cube", "rollup_str", "cube_str"], +) +def test_grouping_set_single_column( + grouping_set_expr, expected_grouping, expected_sums +): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]}) + result = df.aggregate( + [grouping_set_expr], + [f.sum(column("b")).alias("s"), f.grouping(column("a"))], + ).sort(column("a").sort(ascending=True, nulls_first=False)) + batches = result.collect() + g = pa.concat_arrays([b.column(2) for b in batches]).to_pylist() + s = pa.concat_arrays([b.column("s") for b in batches]).to_pylist() + assert g == expected_grouping + assert s == expected_sums + + +@pytest.mark.parametrize( + ("grouping_set_expr", "expected_rows"), + [ + # rollup(a, b) => (a,b), (a), () => 3 + 2 + 1 = 6 + (GroupingSet.rollup(column("a"), column("b")), 6), + # cube(a, b) => (a,b), (a), (b), () => 3 + 2 + 2 + 1 = 8 + (GroupingSet.cube(column("a"), column("b")), 8), + (GroupingSet.rollup("a", "b"), 6), + (GroupingSet.cube("a", "b"), 8), + ], + ids=["rollup", "cube", "rollup_str", "cube_str"], +) +def test_grouping_set_multi_column(grouping_set_expr, expected_rows): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 1, 2], "b": ["x", "y", "x"], "c": [10, 20, 30]}) + result = df.aggregate( + [grouping_set_expr], + [f.sum(column("c")).alias("s")], + ) + total_rows = sum(b.num_rows for b in result.collect()) + assert total_rows == expected_rows + + +@pytest.mark.parametrize( + "grouping_set_expr", + [ + GroupingSet.grouping_sets([column("a")], [column("b")]), + GroupingSet.grouping_sets(["a"], ["b"]), + ], + ids=["expr", "str"], +) +def test_grouping_sets_explicit(grouping_set_expr): + # Each row's grouping() value tells you which columns are aggregated across. 
+ ctx = SessionContext() + df = ctx.from_pydict({"a": ["x", "x", "y"], "b": ["m", "n", "m"], "c": [1, 2, 3]}) + result = df.aggregate( + [grouping_set_expr], + [ + f.sum(column("c")).alias("s"), + f.grouping(column("a")), + f.grouping(column("b")), + ], + ).sort( + column("a").sort(ascending=True, nulls_first=False), + column("b").sort(ascending=True, nulls_first=False), + ) + batches = result.collect() + ga = pa.concat_arrays([b.column(3) for b in batches]).to_pylist() + gb = pa.concat_arrays([b.column(4) for b in batches]).to_pylist() + # Rows grouped by (a): ga=0 (a is a key), gb=1 (b is aggregated across) + # Rows grouped by (b): ga=1 (a is aggregated across), gb=0 (b is a key) + assert ga == [0, 0, 1, 1] + assert gb == [1, 1, 0, 0] + + +def test_var_population(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [-1.0, 0.0, 2.0]}) + result = df.aggregate([], [f.var_population(column("a")).alias("v")]).collect()[0] + # var_population is an alias for var_pop + expected = df.aggregate([], [f.var_pop(column("a")).alias("v")]).collect()[0] + assert abs(result.column(0)[0].as_py() - expected.column(0)[0].as_py()) < 1e-10 + + def test_get_field(df): df = df.with_column( "s", From 52932128d353e417ddae2c5ff3f14135cb806f7e Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 7 Apr 2026 14:58:09 -0400 Subject: [PATCH 13/29] Add missing Dataframe functions (#1472) * Add missing DataFrame methods for set operations and query Expose upstream DataFusion DataFrame methods that were not yet available in the Python API. Closes #1455. Set operations: - except_distinct: set difference with deduplication - intersect_distinct: set intersection with deduplication - union_by_name: union matching columns by name instead of position - union_by_name_distinct: union by name with deduplication Query: - distinct_on: deduplicate rows based on specific columns - sort_by: sort by expressions with ascending order and nulls last Note: show_limit is already covered by the existing show(num) method. explain_with_options and with_param_values are deferred as they require exposing additional types (ExplainOption, ParamValues). Co-Authored-By: Claude Opus 4.6 (1M context) * Add ExplainFormat enum and format option to DataFrame.explain() Extend the existing explain() method with an optional format parameter instead of adding a separate explain_with_options() method. This keeps the API simple while exposing all upstream ExplainOption functionality. Available formats: indent (default), tree, pgjson, graphviz. The ExplainFormat enum is exported from the top-level datafusion module. Co-Authored-By: Claude Opus 4.6 (1M context) * Add DataFrame.window() and unnest recursion options Expose remaining DataFrame methods from upstream DataFusion. Closes #1456. - window(*exprs): apply window function expressions and append results as new columns - unnest_column/unnest_columns: add optional recursions parameter for controlling unnest depth via (input_column, output_column, depth) tuples Note: drop_columns is already exposed as the existing drop() method. Co-Authored-By: Claude Opus 4.6 (1M context) * Update docstring Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Improve docstrings and test robustness for new DataFrame methods Clarify except_distinct/intersect_distinct docstrings, add deterministic sort to test_window, add sort_by ascending verification test, and add smoke tests for PGJSON and GRAPHVIZ explain formats. 
Co-Authored-By: Claude Opus 4.6 (1M context) * Consolidate new DataFrame tests into parametrized tests Combine set operation tests (except_distinct, intersect_distinct, union_by_name, union_by_name_distinct) into a single parametrized test_set_operations_distinct. Merge sort_by tests and convert explain format tests to parametrized form. Co-Authored-By: Claude Opus 4.6 (1M context) * Add doctest examples to new DataFrame method docstrings Add >>> style usage examples for window, explain, except_distinct, intersect_distinct, union_by_name, union_by_name_distinct, distinct_on, sort_by, and unnest_columns to match existing docstring conventions. Co-Authored-By: Claude Opus 4.6 (1M context) * Improve error messages, tests, and API hygiene from PR review - Provide actionable error message for invalid explain format strings - Remove recursions param from deprecated unnest_column (use unnest_columns) - Add null-handling test case for sort_by to verify nulls-last behavior - Add format-specific assertions to explain tests (TREE, PGJSON, GRAPHVIZ) - Add deep recursion test for unnest_columns with depth > 1 - Add multi-expression window test to verify variadic *exprs Co-Authored-By: Claude Opus 4.6 (1M context) * Consolidate window and unnest tests into parametrized tests Combine test_window and test_window_multiple_expressions into a single parametrized test. Merge unnest recursion tests into one parametrized test covering basic, explicit depth 1, and deep recursion cases. Co-Authored-By: Claude Opus 4.6 (1M context) * Address PR review feedback for DataFrame operations - Use upstream parse error for explain format instead of hardcoded options - Fix sort_by to use column name resolution consistent with sort() - Use ExplainFormat enum members directly in tests instead of string lookup - Merge union_by_name_distinct into union_by_name(distinct=False) for a more Pythonic API - Update check-upstream skill to note union_by_name_distinct coverage Co-Authored-By: Claude Opus 4.6 (1M context) * Add DataFrame.column(), col(), and find_qualified_columns() methods Expose upstream find_qualified_columns to resolve unqualified column names into fully qualified column expressions. This is especially useful for disambiguating columns after joins. - find_qualified_columns(*names) on Rust side calls upstream directly - DataFrame.column(name) and col(name) alias on Python side - Update join and join_on docstrings to reference DataFrame.col() - Add "Disambiguating Columns with DataFrame.col()" section to joins docs - Add tests for qualified column resolution, ambiguity, and join usage Co-Authored-By: Claude Opus 4.6 (1M context) * Merge union_by_name and union_by_name_distinct into a single method with distinct flag Co-Authored-By: Claude Opus 4.6 (1M context) * converting into a python dict loses a column when the names are identical * Consolidate except_all/except_distinct and intersect/intersect_distinct into single methods with distinct flag Follows the same pattern as union(distinct=) and union_by_name(distinct=). Also deprecates union_distinct() in favor of union(distinct=True). 
Co-Authored-By: Claude Opus 4.6 (1M context) --------- Co-authored-by: Claude Opus 4.6 (1M context) Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .ai/skills/check-upstream/SKILL.md | 1 + crates/core/src/dataframe.rs | 157 ++++++-- .../user-guide/common-operations/joins.rst | 33 ++ python/datafusion/__init__.py | 2 + python/datafusion/dataframe.py | 365 +++++++++++++++++- python/tests/test_dataframe.py | 261 +++++++++++++ 6 files changed, 767 insertions(+), 52 deletions(-) diff --git a/.ai/skills/check-upstream/SKILL.md b/.ai/skills/check-upstream/SKILL.md index f77210371..ac4835a4e 100644 --- a/.ai/skills/check-upstream/SKILL.md +++ b/.ai/skills/check-upstream/SKILL.md @@ -109,6 +109,7 @@ The user may specify an area via `$ARGUMENTS`. If no area is specified or "all" **Evaluated and not requiring separate Python exposure:** - `show_limit` — already covered by `DataFrame.show()`, which provides the same functionality with a simpler API - `with_param_values` — already covered by the `param_values` argument on `SessionContext.sql()`, which accomplishes the same thing more robustly +- `union_by_name_distinct` — already covered by `DataFrame.union_by_name(distinct=True)`, which provides a more Pythonic API **How to check:** 1. Fetch the upstream DataFrame documentation page listing all methods diff --git a/crates/core/src/dataframe.rs b/crates/core/src/dataframe.rs index 72595ba81..fff5118d5 100644 --- a/crates/core/src/dataframe.rs +++ b/crates/core/src/dataframe.rs @@ -582,6 +582,14 @@ impl PyDataFrame { Ok(Self::new(df)) } + /// Apply window function expressions to the DataFrame + #[pyo3(signature = (*exprs))] + fn window(&self, exprs: Vec) -> PyDataFusionResult { + let window_exprs = exprs.into_iter().map(|e| e.into()).collect(); + let df = self.df.as_ref().clone().window(window_exprs)?; + Ok(Self::new(df)) + } + fn filter(&self, predicate: PyExpr) -> PyDataFusionResult { let df = self.df.as_ref().clone().filter(predicate.into())?; Ok(Self::new(df)) @@ -804,9 +812,27 @@ impl PyDataFrame { } /// Print the query plan - #[pyo3(signature = (verbose=false, analyze=false))] - fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyDataFusionResult<()> { - let df = self.df.as_ref().clone().explain(verbose, analyze)?; + #[pyo3(signature = (verbose=false, analyze=false, format=None))] + fn explain( + &self, + py: Python, + verbose: bool, + analyze: bool, + format: Option<&str>, + ) -> PyDataFusionResult<()> { + let explain_format = match format { + Some(f) => f + .parse::() + .map_err(|e| { + PyDataFusionError::Common(format!("Invalid explain format '{}': {}", f, e)) + })?, + None => datafusion::common::format::ExplainFormat::Indent, + }; + let opts = datafusion::logical_expr::ExplainOption::default() + .with_verbose(verbose) + .with_analyze(analyze) + .with_format(explain_format); + let df = self.df.as_ref().clone().explain_with_options(opts)?; print_dataframe(py, df) } @@ -864,22 +890,14 @@ impl PyDataFrame { Ok(Self::new(new_df)) } - /// Calculate the distinct union of two `DataFrame`s. 
The - /// two `DataFrame`s must have exactly the same schema - fn union_distinct(&self, py_df: PyDataFrame) -> PyDataFusionResult { - let new_df = self - .df - .as_ref() - .clone() - .union_distinct(py_df.df.as_ref().clone())?; - Ok(Self::new(new_df)) - } - - #[pyo3(signature = (column, preserve_nulls=true))] - fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyDataFusionResult { - // TODO: expose RecursionUnnestOptions - // REF: https://github.com/apache/datafusion/pull/11577 - let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls); + #[pyo3(signature = (column, preserve_nulls=true, recursions=None))] + fn unnest_column( + &self, + column: &str, + preserve_nulls: bool, + recursions: Option>, + ) -> PyDataFusionResult { + let unnest_options = build_unnest_options(preserve_nulls, recursions); let df = self .df .as_ref() @@ -888,15 +906,14 @@ impl PyDataFrame { Ok(Self::new(df)) } - #[pyo3(signature = (columns, preserve_nulls=true))] + #[pyo3(signature = (columns, preserve_nulls=true, recursions=None))] fn unnest_columns( &self, columns: Vec, preserve_nulls: bool, + recursions: Option>, ) -> PyDataFusionResult { - // TODO: expose RecursionUnnestOptions - // REF: https://github.com/apache/datafusion/pull/11577 - let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls); + let unnest_options = build_unnest_options(preserve_nulls, recursions); let cols = columns.iter().map(|s| s.as_ref()).collect::>(); let df = self .df @@ -907,21 +924,79 @@ impl PyDataFrame { } /// Calculate the intersection of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema - fn intersect(&self, py_df: PyDataFrame) -> PyDataFusionResult { - let new_df = self - .df - .as_ref() - .clone() - .intersect(py_df.df.as_ref().clone())?; + #[pyo3(signature = (py_df, distinct=false))] + fn intersect(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult { + let base = self.df.as_ref().clone(); + let other = py_df.df.as_ref().clone(); + let new_df = if distinct { + base.intersect_distinct(other)? + } else { + base.intersect(other)? + }; Ok(Self::new(new_df)) } /// Calculate the exception of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema - fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult { - let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?; + #[pyo3(signature = (py_df, distinct=false))] + fn except_all(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult { + let base = self.df.as_ref().clone(); + let other = py_df.df.as_ref().clone(); + let new_df = if distinct { + base.except_distinct(other)? + } else { + base.except(other)? + }; Ok(Self::new(new_df)) } + /// Union two DataFrames matching columns by name + #[pyo3(signature = (py_df, distinct=false))] + fn union_by_name(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult { + let base = self.df.as_ref().clone(); + let other = py_df.df.as_ref().clone(); + let new_df = if distinct { + base.union_by_name_distinct(other)? + } else { + base.union_by_name(other)? 
+ }; + Ok(Self::new(new_df)) + } + + /// Deduplicate rows based on specific columns, keeping the first row per group + fn distinct_on( + &self, + on_expr: Vec, + select_expr: Vec, + sort_expr: Option>, + ) -> PyDataFusionResult { + let on_expr = on_expr.into_iter().map(|e| e.into()).collect(); + let select_expr = select_expr.into_iter().map(|e| e.into()).collect(); + let sort_expr = sort_expr.map(to_sort_expressions); + let df = self + .df + .as_ref() + .clone() + .distinct_on(on_expr, select_expr, sort_expr)?; + Ok(Self::new(df)) + } + + /// Sort by column expressions with ascending order and nulls last + fn sort_by(&self, exprs: Vec) -> PyDataFusionResult { + let exprs = exprs.into_iter().map(|e| e.into()).collect(); + let df = self.df.as_ref().clone().sort_by(exprs)?; + Ok(Self::new(df)) + } + + /// Return fully qualified column expressions for the given column names + fn find_qualified_columns(&self, names: Vec) -> PyDataFusionResult> { + let name_refs: Vec<&str> = names.iter().map(|s| s.as_str()).collect(); + let qualified = self.df.find_qualified_columns(&name_refs)?; + Ok(qualified + .into_iter() + .map(|q| Expr::Column(Column::from(q)).into()) + .collect()) + } + /// Write a `DataFrame` to a CSV file. fn write_csv( &self, @@ -1295,6 +1370,26 @@ impl PyDataFrameWriteOptions { } } +fn build_unnest_options( + preserve_nulls: bool, + recursions: Option>, +) -> UnnestOptions { + let mut opts = UnnestOptions::default().with_preserve_nulls(preserve_nulls); + if let Some(recs) = recursions { + opts.recursions = recs + .into_iter() + .map( + |(input, output, depth)| datafusion::common::RecursionUnnestOption { + input_column: datafusion::common::Column::from(input.as_str()), + output_column: datafusion::common::Column::from(output.as_str()), + depth, + }, + ) + .collect(); + } + opts +} + /// Print DataFrame fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> { // Get string representation of record batches diff --git a/docs/source/user-guide/common-operations/joins.rst b/docs/source/user-guide/common-operations/joins.rst index 1d9d70385..a289c9377 100644 --- a/docs/source/user-guide/common-operations/joins.rst +++ b/docs/source/user-guide/common-operations/joins.rst @@ -134,3 +134,36 @@ In contrast to the above example, if we wish to get both columns: .. ipython:: python left.join(right, "id", how="inner", coalesce_duplicate_keys=False) + +Disambiguating Columns with ``DataFrame.col()`` +------------------------------------------------ + +When both DataFrames contain non-key columns with the same name, you can use +:py:meth:`~datafusion.dataframe.DataFrame.col` on each DataFrame **before** the +join to create fully qualified column references. These references can then be +used in the join predicate and when selecting from the result. + +This is especially useful with :py:meth:`~datafusion.dataframe.DataFrame.join_on`, +which accepts expression-based predicates. + +.. 
ipython:: python + + left = ctx.from_pydict( + { + "id": [1, 2, 3], + "val": [10, 20, 30], + } + ) + + right = ctx.from_pydict( + { + "id": [1, 2, 3], + "val": [40, 50, 60], + } + ) + + joined = left.join_on( + right, left.col("id") == right.col("id"), how="inner" + ) + + joined.select(left.col("id"), left.col("val"), right.col("val")) diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 2e6f81166..a736c3966 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -47,6 +47,7 @@ from .dataframe import ( DataFrame, DataFrameWriteOptions, + ExplainFormat, InsertOp, ParquetColumnOptions, ParquetWriterOptions, @@ -82,6 +83,7 @@ "DataFrameWriteOptions", "Database", "ExecutionPlan", + "ExplainFormat", "Expr", "InsertOp", "LogicalPlan", diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 9907eae8b..9dc5f0e7d 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -44,6 +44,7 @@ Expr, SortExpr, SortKey, + _to_raw_expr, ensure_expr, ensure_expr_list, expr_list_to_raw_expr_list, @@ -65,6 +66,25 @@ from enum import Enum +class ExplainFormat(Enum): + """Output format for explain plans. + + Controls how the query plan is rendered in :py:meth:`DataFrame.explain`. + """ + + INDENT = "indent" + """Default indented text format.""" + + TREE = "tree" + """Tree-style visual format with box-drawing characters.""" + + PGJSON = "pgjson" + """PostgreSQL-compatible JSON format for use with visualization tools.""" + + GRAPHVIZ = "graphviz" + """Graphviz DOT format for graph rendering.""" + + # excerpt from deltalake # https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163 class Compression(Enum): @@ -395,6 +415,80 @@ def schema(self) -> pa.Schema: """ return self.df.schema() + def column(self, name: str) -> Expr: + """Return a fully qualified column expression for ``name``. + + Resolves an unqualified column name against this DataFrame's schema + and returns an :py:class:`Expr` whose underlying column reference + includes the table qualifier. This is especially useful after joins, + where the same column name may appear in multiple relations. + + Args: + name: Unqualified column name to look up. + + Returns: + A fully qualified column expression. + + Raises: + Exception: If the column is not found or is ambiguous (exists in + multiple relations). + + Examples: + Resolve a column from a simple DataFrame: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2], "b": [3, 4]}) + >>> expr = df.column("a") + >>> df.select(expr).to_pydict() + {'a': [1, 2]} + + Resolve qualified columns after a join: + + >>> left = ctx.from_pydict({"id": [1, 2], "x": [10, 20]}) + >>> right = ctx.from_pydict({"id": [1, 2], "y": [30, 40]}) + >>> joined = left.join(right, on="id", how="inner") + >>> expr = joined.column("y") + >>> joined.select("id", expr).sort("id").to_pydict() + {'id': [1, 2], 'y': [30, 40]} + """ + return self.find_qualified_columns(name)[0] + + def col(self, name: str) -> Expr: + """Alias for :py:meth:`column`. + + See Also: + :py:meth:`column` + """ + return self.column(name) + + def find_qualified_columns(self, *names: str) -> list[Expr]: + """Return fully qualified column expressions for the given names. + + This is a batch version of :py:meth:`column` — it resolves each + unqualified name against the DataFrame's schema and returns a list + of qualified column expressions. + + Args: + names: Unqualified column names to look up. 
+ + Returns: + List of fully qualified column expressions, one per name. + + Raises: + Exception: If any column is not found or is ambiguous. + + Examples: + Resolve multiple columns at once: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + >>> exprs = df.find_qualified_columns("a", "c") + >>> df.select(*exprs).to_pydict() + {'a': [1, 2], 'c': [5, 6]} + """ + raw_exprs = self.df.find_qualified_columns(list(names)) + return [Expr(e) for e in raw_exprs] + @deprecated( "select_columns() is deprecated. Use :py:meth:`~DataFrame.select` instead" ) @@ -468,6 +562,36 @@ def drop(self, *columns: str) -> DataFrame: """ return DataFrame(self.df.drop(*columns)) + def window(self, *exprs: Expr) -> DataFrame: + """Add window function columns to the DataFrame. + + Applies the given window function expressions and appends the results + as new columns. + + Args: + exprs: Window function expressions to evaluate. + + Returns: + DataFrame with new window function columns appended. + + Examples: + Add a row number within each group: + + >>> import datafusion.functions as f + >>> from datafusion import col + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3], "b": ["x", "x", "y"]}) + >>> df = df.window( + ... f.row_number( + ... partition_by=[col("b")], order_by=[col("a")] + ... ).alias("rn") + ... ) + >>> "rn" in df.schema().names + True + """ + raw = expr_list_to_raw_expr_list(exprs) + return DataFrame(self.df.window(*raw)) + def filter(self, *predicates: Expr | str) -> DataFrame: """Return a DataFrame for which ``predicate`` evaluates to ``True``. @@ -837,7 +961,13 @@ def join( ) -> DataFrame: """Join this :py:class:`DataFrame` with another :py:class:`DataFrame`. - `on` has to be provided or both `left_on` and `right_on` in conjunction. + ``on`` has to be provided or both ``left_on`` and ``right_on`` in + conjunction. + + When non-key columns share the same name in both DataFrames, use + :py:meth:`DataFrame.col` on each DataFrame **before** the join to + obtain fully qualified column references that can disambiguate them. + See :py:meth:`join_on` for an example. Args: right: Other DataFrame to join with. @@ -911,7 +1041,14 @@ def join_on( built with :func:`datafusion.col`. On expressions are used to support in-equality predicates. Equality predicates are correctly optimized. + Use :py:meth:`DataFrame.col` on each DataFrame **before** the join to + obtain fully qualified column references. These qualified references + can then be used in the join predicate and to disambiguate columns + with the same name when selecting from the result. + Examples: + Join with unique column names: + >>> ctx = dfn.SessionContext() >>> left = ctx.from_pydict({"a": [1, 2], "x": ["a", "b"]}) >>> right = ctx.from_pydict({"b": [1, 2], "y": ["c", "d"]}) @@ -920,6 +1057,18 @@ def join_on( ... ).sort(col("x")).to_pydict() {'a': [1, 2], 'x': ['a', 'b'], 'b': [1, 2], 'y': ['c', 'd']} + Use :py:meth:`col` to disambiguate shared column names: + + >>> left = ctx.from_pydict({"id": [1, 2], "val": [10, 20]}) + >>> right = ctx.from_pydict({"id": [1, 2], "val": [30, 40]}) + >>> joined = left.join_on( + ... right, left.col("id") == right.col("id"), how="inner" + ... ) + >>> joined.select( + ... left.col("id"), left.col("val"), right.col("val").alias("rval") + ... ).sort(left.col("id")).to_pydict() + {'id': [1, 2], 'val': [10, 20], 'rval': [30, 40]} + Args: right: Other DataFrame to join with. on_exprs: single or multiple (in)-equality predicates. 
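The hunks above introduce DataFrame.window() alongside the qualified-column helpers (column(), col(), find_qualified_columns()). As a hedged sketch of how the two could work together — the join, data, and the "rn" name are invented for illustration and are not taken from the patch — a window expression can be built from qualified references obtained on the pre-join frames, avoiding the ambiguous bare name "val":

    from datafusion import SessionContext
    from datafusion import functions as f

    ctx = SessionContext()
    left = ctx.from_pydict({"id": [1, 2, 2], "val": [10, 20, 25]})
    right = ctx.from_pydict({"id": [1, 2], "val": [100, 200]})

    joined = left.join_on(right, left.col("id") == right.col("id"), how="inner")

    # Number the left-side rows within each id; qualified references from the
    # pre-join frames disambiguate the two "val" columns.
    ranked = joined.window(
        f.row_number(
            partition_by=[left.col("id")], order_by=[left.col("val")]
        ).alias("rn")
    )
    ranked.select(left.col("id"), left.col("val"), "rn").show()
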
@@ -932,7 +1081,12 @@ def join_on( exprs = [ensure_expr(expr) for expr in on_exprs] return DataFrame(self.df.join_on(right.df, exprs, how)) - def explain(self, verbose: bool = False, analyze: bool = False) -> None: + def explain( + self, + verbose: bool = False, + analyze: bool = False, + format: ExplainFormat | None = None, + ) -> None: """Print an explanation of the DataFrame's plan so far. If ``analyze`` is specified, runs the plan and reports metrics. @@ -940,8 +1094,23 @@ def explain(self, verbose: bool = False, analyze: bool = False) -> None: Args: verbose: If ``True``, more details will be included. analyze: If ``True``, the plan will run and metrics reported. + format: Output format for the plan. Defaults to + :py:attr:`ExplainFormat.INDENT`. + + Examples: + Show the plan in tree format: + + >>> from datafusion import ExplainFormat + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3]}) + >>> df.explain(format=ExplainFormat.TREE) # doctest: +SKIP + + Show plan with runtime metrics: + + >>> df.explain(analyze=True) # doctest: +SKIP """ - self.df.explain(verbose, analyze) + fmt = format.value if format is not None else None + self.df.explain(verbose, analyze, fmt) def logical_plan(self) -> LogicalPlan: """Return the unoptimized ``LogicalPlan``. @@ -1010,45 +1179,170 @@ def union(self, other: DataFrame, distinct: bool = False) -> DataFrame: """ return DataFrame(self.df.union(other.df, distinct)) + @deprecated( + "union_distinct() is deprecated. Use union(other, distinct=True) instead." + ) def union_distinct(self, other: DataFrame) -> DataFrame: """Calculate the distinct union of two :py:class:`DataFrame`. + See Also: + :py:meth:`union` + """ + return self.union(other, distinct=True) + + def intersect(self, other: DataFrame, distinct: bool = False) -> DataFrame: + """Calculate the intersection of two :py:class:`DataFrame`. + The two :py:class:`DataFrame` must have exactly the same schema. - Any duplicate rows are discarded. Args: - other: DataFrame to union with. + other: DataFrame to intersect with. + distinct: If ``True``, duplicate rows are removed from the result. Returns: - DataFrame after union. + DataFrame after intersection. + + Examples: + Find rows common to both DataFrames: + + >>> ctx = dfn.SessionContext() + >>> df1 = ctx.from_pydict({"a": [1, 2, 3], "b": [10, 20, 30]}) + >>> df2 = ctx.from_pydict({"a": [1, 4], "b": [10, 40]}) + >>> df1.intersect(df2).to_pydict() + {'a': [1], 'b': [10]} + + Intersect with deduplication: + + >>> df1 = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 10, 20]}) + >>> df2 = ctx.from_pydict({"a": [1, 1], "b": [10, 10]}) + >>> df1.intersect(df2, distinct=True).to_pydict() + {'a': [1], 'b': [10]} """ - return DataFrame(self.df.union_distinct(other.df)) + return DataFrame(self.df.intersect(other.df, distinct)) - def intersect(self, other: DataFrame) -> DataFrame: - """Calculate the intersection of two :py:class:`DataFrame`. + def except_all(self, other: DataFrame, distinct: bool = False) -> DataFrame: + """Calculate the set difference of two :py:class:`DataFrame`. + + Returns rows that are in this DataFrame but not in ``other``. The two :py:class:`DataFrame` must have exactly the same schema. Args: - other: DataFrame to intersect with. + other: DataFrame to calculate exception with. + distinct: If ``True``, duplicate rows are removed from the result. Returns: - DataFrame after intersection. + DataFrame after set difference. 
+ + Examples: + Remove rows present in ``df2``: + + >>> ctx = dfn.SessionContext() + >>> df1 = ctx.from_pydict({"a": [1, 2, 3], "b": [10, 20, 30]}) + >>> df2 = ctx.from_pydict({"a": [1, 2], "b": [10, 20]}) + >>> df1.except_all(df2).sort("a").to_pydict() + {'a': [3], 'b': [30]} + + Remove rows present in ``df2`` and deduplicate: + + >>> df1.except_all(df2, distinct=True).sort("a").to_pydict() + {'a': [3], 'b': [30]} """ - return DataFrame(self.df.intersect(other.df)) + return DataFrame(self.df.except_all(other.df, distinct)) - def except_all(self, other: DataFrame) -> DataFrame: - """Calculate the exception of two :py:class:`DataFrame`. + def union_by_name(self, other: DataFrame, distinct: bool = False) -> DataFrame: + """Union two :py:class:`DataFrame` matching columns by name. - The two :py:class:`DataFrame` must have exactly the same schema. + Unlike :py:meth:`union` which matches columns by position, this method + matches columns by their names, allowing DataFrames with different + column orders to be combined. Args: - other: DataFrame to calculate exception with. + other: DataFrame to union with. + distinct: If ``True``, duplicate rows are removed from the result. Returns: - DataFrame after exception. + DataFrame after union by name. + + Examples: + Combine DataFrames with different column orders: + + >>> ctx = dfn.SessionContext() + >>> df1 = ctx.from_pydict({"a": [1], "b": [10]}) + >>> df2 = ctx.from_pydict({"b": [20], "a": [2]}) + >>> df1.union_by_name(df2).sort("a").to_pydict() + {'a': [1, 2], 'b': [10, 20]} + + Union by name with deduplication: + + >>> df1 = ctx.from_pydict({"a": [1, 1], "b": [10, 10]}) + >>> df2 = ctx.from_pydict({"b": [10], "a": [1]}) + >>> df1.union_by_name(df2, distinct=True).to_pydict() + {'a': [1], 'b': [10]} """ - return DataFrame(self.df.except_all(other.df)) + return DataFrame(self.df.union_by_name(other.df, distinct)) + + def distinct_on( + self, + on_expr: list[Expr], + select_expr: list[Expr], + sort_expr: list[SortKey] | None = None, + ) -> DataFrame: + """Deduplicate rows based on specific columns. + + Returns a new DataFrame with one row per unique combination of the + ``on_expr`` columns, keeping the first row per group as determined by + ``sort_expr``. + + Args: + on_expr: Expressions that determine uniqueness. + select_expr: Expressions to include in the output. + sort_expr: Optional sort expressions to determine which row to keep. + + Returns: + DataFrame after deduplication. + + Examples: + Keep the row with the smallest ``b`` for each unique ``a``: + + >>> from datafusion import col + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 1, 2, 2], "b": [10, 20, 30, 40]}) + >>> df.distinct_on( + ... [col("a")], + ... [col("a"), col("b")], + ... [col("a").sort(ascending=True), col("b").sort(ascending=True)], + ... ).sort("a").to_pydict() + {'a': [1, 2], 'b': [10, 30]} + """ + on_raw = expr_list_to_raw_expr_list(on_expr) + select_raw = expr_list_to_raw_expr_list(select_expr) + sort_raw = sort_list_to_raw_sort_list(sort_expr) if sort_expr else None + return DataFrame(self.df.distinct_on(on_raw, select_raw, sort_raw)) + + def sort_by(self, *exprs: Expr | str) -> DataFrame: + """Sort the DataFrame by column expressions in ascending order. + + This is a convenience method that sorts the DataFrame by the given + expressions in ascending order with nulls last. For more control over + sort direction and null ordering, use :py:meth:`sort` instead. + + Args: + exprs: Expressions or column names to sort by. 
+ + Returns: + DataFrame after sorting. + + Examples: + Sort by a single column: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [3, 1, 2]}) + >>> df.sort_by("a").to_pydict() + {'a': [1, 2, 3]} + """ + raw = [_to_raw_expr(e) for e in exprs] + return DataFrame(self.df.sort_by(raw)) def write_csv( self, @@ -1310,23 +1604,52 @@ def count(self) -> int: return self.df.count() @deprecated("Use :py:func:`unnest_columns` instead.") - def unnest_column(self, column: str, preserve_nulls: bool = True) -> DataFrame: + def unnest_column( + self, + column: str, + preserve_nulls: bool = True, + ) -> DataFrame: """See :py:func:`unnest_columns`.""" return DataFrame(self.df.unnest_column(column, preserve_nulls=preserve_nulls)) - def unnest_columns(self, *columns: str, preserve_nulls: bool = True) -> DataFrame: + def unnest_columns( + self, + *columns: str, + preserve_nulls: bool = True, + recursions: list[tuple[str, str, int]] | None = None, + ) -> DataFrame: """Expand columns of arrays into a single row per array element. Args: columns: Column names to perform unnest operation on. preserve_nulls: If False, rows with null entries will not be returned. + recursions: Optional list of ``(input_column, output_column, depth)`` + tuples that control how deeply nested columns are unnested. Any + column not mentioned here is unnested with depth 1. Returns: A DataFrame with the columns expanded. + + Examples: + Unnest an array column: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [[1, 2], [3]], "b": ["x", "y"]}) + >>> df.unnest_columns("a").to_pydict() + {'a': [1, 2, 3], 'b': ['x', 'x', 'y']} + + With explicit recursion depth: + + >>> df.unnest_columns("a", recursions=[("a", "a", 1)]).to_pydict() + {'a': [1, 2, 3], 'b': ['x', 'x', 'y']} """ columns = list(columns) - return DataFrame(self.df.unnest_columns(columns, preserve_nulls=preserve_nulls)) + return DataFrame( + self.df.unnest_columns( + columns, preserve_nulls=preserve_nulls, recursions=recursions + ) + ) def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: """Export the DataFrame as an Arrow C Stream. 
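Before the test changes, a short sketch of the new recursions argument on unnest_columns: the docstring above only demonstrates depth 1, so this illustrates a doubly nested list column, assuming depths greater than 1 behave as described by the upstream RecursionUnnestOption. The data and expected output shown in comments are illustrative, not taken from the patch.

    from datafusion import SessionContext

    ctx = SessionContext()
    df = ctx.from_pydict({"a": [[[1, 2], [3]], [[4]]], "b": ["x", "y"]})

    # Depth 1 peels a single level of nesting, leaving list values in "a".
    one_level = df.unnest_columns("a", recursions=[("a", "a", 1)])

    # Depth 2 should flatten all the way to scalars:
    #   a = [1, 2, 3, 4], b = ['x', 'x', 'x', 'y']
    flat = df.unnest_columns("a", recursions=[("a", "a", 2)])
    flat.show()
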
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 759d6278c..bb8e9685c 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -29,6 +29,7 @@ import pytest from datafusion import ( DataFrame, + ExplainFormat, InsertOp, ParquetColumnOptions, ParquetWriterOptions, @@ -3569,3 +3570,263 @@ def test_read_parquet_file_sort_order(tmp_path, file_sort_order): pa.parquet.write_table(table, path) df = ctx.read_parquet(path, file_sort_order=file_sort_order) assert df.collect()[0].column(0).to_pylist() == [1, 2] + + +@pytest.mark.parametrize( + ("df1_data", "df2_data", "method", "kwargs", "expected_a", "expected_b"), + [ + pytest.param( + {"a": [1, 2, 3, 1], "b": [10, 20, 30, 10]}, + {"a": [1, 2], "b": [10, 20]}, + "except_all", + {"distinct": True}, + [3], + [30], + id="except_all(distinct=True): removes matching rows and deduplicates", + ), + pytest.param( + {"a": [1, 2, 3, 1], "b": [10, 20, 30, 10]}, + {"a": [1, 4], "b": [10, 40]}, + "intersect", + {"distinct": True}, + [1], + [10], + id="intersect(distinct=True): keeps common rows and deduplicates", + ), + pytest.param( + {"a": [1], "b": [10]}, + {"b": [20], "a": [2]}, # reversed column order tests matching by name + "union_by_name", + {}, + [1, 2], + [10, 20], + id="union_by_name: matches columns by name not position", + ), + ], +) +def test_set_operations_distinct( + df1_data, df2_data, method, kwargs, expected_a, expected_b +): + ctx = SessionContext() + df1 = ctx.from_pydict(df1_data) + df2 = ctx.from_pydict(df2_data) + result = ( + getattr(df1, method)(df2, **kwargs) + .sort(column("a").sort(ascending=True)) + .collect()[0] + ) + assert result.column(0).to_pylist() == expected_a + assert result.column(1).to_pylist() == expected_b + + +def test_union_by_name_distinct(): + ctx = SessionContext() + df1 = ctx.from_pydict({"a": [1, 1], "b": [10, 10]}) + df2 = ctx.from_pydict({"b": [10], "a": [1]}) + result = df1.union_by_name(df2, distinct=True).collect()[0] + assert result.column(0).to_pylist() == [1] + assert result.column(1).to_pylist() == [10] + + +def test_column_qualified(): + """DataFrame.column() returns a qualified column expression.""" + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 2], "b": [3, 4]}) + expr = df.column("a") + result = df.select(expr).collect()[0] + assert result.column(0).to_pylist() == [1, 2] + + +def test_column_not_found(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1]}) + with pytest.raises(Exception, match="not found"): + df.column("z") + + +def test_column_ambiguous(): + """After a join, duplicate column names that cannot be resolved raise an error.""" + ctx = SessionContext() + left = ctx.from_pydict({"id": [1, 2], "val": [10, 20]}) + right = ctx.from_pydict({"id": [1, 2], "val": [30, 40]}) + joined = left.join(right, on="id", how="inner") + with pytest.raises(Exception, match="not found"): + joined.column("val") + + +def test_column_after_join(): + """Qualified column works for non-ambiguous columns after a join.""" + ctx = SessionContext() + left = ctx.from_pydict({"id": [1, 2], "x": [10, 20]}) + right = ctx.from_pydict({"id": [1, 2], "y": [30, 40]}) + joined = left.join(right, on="id", how="inner") + expr = joined.column("y") + result = joined.select("id", expr).sort("id").collect()[0] + assert result.column(0).to_pylist() == [1, 2] + assert result.column(1).to_pylist() == [30, 40] + + +def test_col_join_disambiguate(): + """Use col() to disambiguate and select columns after a join.""" + ctx = SessionContext() + df1 = 
ctx.from_pydict({"foo": [1, 2, 3], "bar": [5, 6, 7]}) + df2 = ctx.from_pydict({"foo": [1, 2, 3], "baz": [8, 9, 10]}) + joined = df1.join_on(df2, df1.col("foo") == df2.col("foo"), how="inner") + result = ( + joined.select(df1.col("foo"), df1.col("bar"), df2.col("baz")) + .sort(df1.col("foo")) + .to_pydict() + ) + assert result["bar"] == [5, 6, 7] + assert result["baz"] == [8, 9, 10] + + +def test_find_qualified_columns(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + exprs = df.find_qualified_columns("a", "c") + assert len(exprs) == 2 + result = df.select(*exprs).collect()[0] + assert result.column(0).to_pylist() == [1, 2] + assert result.column(1).to_pylist() == [5, 6] + + +def test_find_qualified_columns_not_found(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1]}) + with pytest.raises(Exception, match="not found"): + df.find_qualified_columns("a", "z") + + +def test_distinct_on(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 1, 2, 2], "b": [10, 20, 30, 40]}) + result = ( + df.distinct_on( + [column("a")], + [column("a"), column("b")], + [column("a").sort(ascending=True), column("b").sort(ascending=True)], + ) + .sort(column("a").sort(ascending=True)) + .collect()[0] + ) + # Keeps the first row per group (smallest b per a) + assert result.column(0).to_pylist() == [1, 2] + assert result.column(1).to_pylist() == [10, 30] + + +@pytest.mark.parametrize( + ("input_values", "expected"), + [ + ([3, 1, 2], [1, 2, 3]), + ([1, 2, 3], [1, 2, 3]), + ([3, None, 1, 2], [1, 2, 3, None]), + ], +) +def test_sort_by(input_values, expected): + """sort_by always sorts ascending with nulls last regardless of input order.""" + ctx = SessionContext() + df = ctx.from_pydict({"a": input_values}) + result = df.sort_by(column("a")).collect()[0] + assert result.column(0).to_pylist() == expected + + +@pytest.mark.parametrize( + ("fmt", "verbose", "analyze", "expected_substring"), + [ + pytest.param(None, False, False, None, id="default format"), + pytest.param(ExplainFormat.TREE, False, False, "---", id="tree format"), + pytest.param( + ExplainFormat.INDENT, True, True, None, id="indent verbose+analyze" + ), + pytest.param(ExplainFormat.PGJSON, False, False, '"Plan"', id="pgjson format"), + pytest.param( + ExplainFormat.GRAPHVIZ, False, False, "digraph", id="graphviz format" + ), + ], +) +def test_explain_with_format(capsys, fmt, verbose, analyze, expected_substring): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1]}) + df.explain(verbose=verbose, analyze=analyze, format=fmt) + captured = capsys.readouterr() + assert "plan_type" in captured.out + if expected_substring is not None: + assert expected_substring in captured.out + + +@pytest.mark.parametrize( + ("window_exprs", "expected_columns"), + [ + pytest.param( + lambda: [ + f.row_number(partition_by=[column("b")], order_by=[column("a")]).alias( + "rn" + ), + ], + {"rn": [1, 2, 1]}, + id="single window expression", + ), + pytest.param( + lambda: [ + f.row_number(partition_by=[column("b")], order_by=[column("a")]).alias( + "rn" + ), + f.rank(partition_by=[column("b")], order_by=[column("a")]).alias("rnk"), + ], + {"rn": [1, 2, 1], "rnk": [1, 2, 1]}, + id="multiple window expressions", + ), + ], +) +def test_window(window_exprs, expected_columns): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 2, 3], "b": ["x", "x", "y"]}) + result = ( + df.window(*window_exprs()).sort(column("a").sort(ascending=True)).collect()[0] + ) + for col_name, expected_values in expected_columns.items(): 
+ assert col_name in result.schema.names + assert ( + result.column(result.schema.get_field_index(col_name)).to_pylist() + == expected_values + ) + + +@pytest.mark.parametrize( + ("input_data", "recursions", "expected_a"), + [ + pytest.param( + {"a": [[1, 2], [3]], "b": ["x", "y"]}, + None, + [1, 2, 3], + id="basic unnest without recursions", + ), + pytest.param( + {"a": [[1, 2], [3]], "b": ["x", "y"]}, + [("a", "a", 1)], + [1, 2, 3], + id="explicit depth 1 matches basic unnest", + ), + pytest.param( + {"a": [[[1, 2], [3]], [[4]]], "b": ["x", "y"]}, + [("a", "a", 1)], + [[1, 2], [3], [4]], + id="depth 1 on nested lists keeps inner lists", + ), + pytest.param( + {"a": [[[1, 2], [3]], [[4]]], "b": ["x", "y"]}, + [("a", "a", 2)], + [1, 2, 3, 4], + id="depth 2 fully flattens nested lists", + ), + ], +) +def test_unnest_columns_with_recursions(input_data, recursions, expected_a): + ctx = SessionContext() + df = ctx.from_pydict(input_data) + kwargs = {} + if recursions is not None: + kwargs["recursions"] = recursions + result = df.unnest_columns("a", **kwargs).collect()[0] + assert result.column(0).to_pylist() == expected_a From 46f9ab8fcad03913234ce29e5075644c1ecdb9b7 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 7 Apr 2026 15:03:38 -0400 Subject: [PATCH 14/29] Add missing deregister methods to SessionContext (#1473) * Add deregister methods to SessionContext for UDFs and object stores Expose upstream DataFusion deregister methods (deregister_udf, deregister_udaf, deregister_udwf, deregister_udtf, deregister_object_store) in both the Rust PyO3 bindings and Python wrappers, closing the gap identified in #1457. Co-Authored-By: Claude Opus 4.6 (1M context) * Fix deregister tests to expect ValueError instead of RuntimeError DataFusion raises ValueError for planning errors when a deregistered function is used in a query. Co-Authored-By: Claude Opus 4.6 (1M context) * Replace .unwrap() with proper error propagation in object store methods Url::parse() can fail on invalid input. Use .map_err() to convert the error into a Python exception instead of panicking. 
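At the Python level the intended round trip is roughly the following (a sketch based on the tests added in this patch; ``my_is_null`` is just an illustrative UDF name):

    import pyarrow as pa
    from datafusion import SessionContext, udf

    ctx = SessionContext()
    is_null = udf(
        lambda x: x.is_null(),
        [pa.float64()],
        pa.bool_(),
        volatility="immutable",
        name="my_is_null",
    )
    ctx.register_udf(is_null)

    ctx.deregister_udf("my_is_null")
    # A query that still references my_is_null now fails at planning time
    # with a ValueError rather than a RuntimeError.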
Co-Authored-By: Claude Opus 4.6 (1M context) * Minor move of import statement --------- Co-authored-by: Claude Opus 4.6 (1M context) --- crates/core/src/context.rs | 32 +++++++++- python/datafusion/context.py | 41 ++++++++++++ python/tests/test_context.py | 120 +++++++++++++++++++++++++++++++++++ 3 files changed, 192 insertions(+), 1 deletion(-) diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs index 53994d2f5..1300a1595 100644 --- a/crates/core/src/context.rs +++ b/crates/core/src/context.rs @@ -434,11 +434,25 @@ impl PySessionContext { &upstream_host }; let url_string = format!("{scheme}{derived_host}"); - let url = Url::parse(&url_string).unwrap(); + let url = Url::parse(&url_string).map_err(|e| PyValueError::new_err(e.to_string()))?; self.ctx.runtime_env().register_object_store(&url, store); Ok(()) } + /// Deregister an object store with the given url + #[pyo3(signature = (scheme, host=None))] + pub fn deregister_object_store( + &self, + scheme: &str, + host: Option<&str>, + ) -> PyDataFusionResult<()> { + let host = host.unwrap_or(""); + let url_string = format!("{scheme}{host}"); + let url = Url::parse(&url_string).map_err(|e| PyDataFusionError::Common(e.to_string()))?; + self.ctx.runtime_env().deregister_object_store(&url)?; + Ok(()) + } + #[allow(clippy::too_many_arguments)] #[pyo3(signature = (name, path, table_partition_cols=vec![], file_extension=".parquet", @@ -492,6 +506,10 @@ impl PySessionContext { self.ctx.register_udtf(&name, func); } + pub fn deregister_udtf(&self, name: &str) { + self.ctx.deregister_udtf(name); + } + #[pyo3(signature = (query, options=None, param_values=HashMap::default(), param_strings=HashMap::default()))] pub fn sql_with_options( &self, @@ -975,16 +993,28 @@ impl PySessionContext { Ok(()) } + pub fn deregister_udf(&self, name: &str) { + self.ctx.deregister_udf(name); + } + pub fn register_udaf(&self, udaf: PyAggregateUDF) -> PyResult<()> { self.ctx.register_udaf(udaf.function); Ok(()) } + pub fn deregister_udaf(&self, name: &str) { + self.ctx.deregister_udaf(name); + } + pub fn register_udwf(&self, udwf: PyWindowUDF) -> PyResult<()> { self.ctx.register_udwf(udwf.function); Ok(()) } + pub fn deregister_udwf(&self, name: &str) { + self.ctx.deregister_udwf(name); + } + #[pyo3(signature = (name="datafusion"))] pub fn catalog(&self, py: Python, name: &str) -> PyResult> { let catalog = self.ctx.catalog(name).ok_or(PyKeyError::new_err(format!( diff --git a/python/datafusion/context.py b/python/datafusion/context.py index c8edc816f..f190e3ca1 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -568,6 +568,15 @@ def register_object_store( """ self.ctx.register_object_store(schema, store, host) + def deregister_object_store(self, schema: str, host: str | None = None) -> None: + """Remove an object store from the session. + + Args: + schema: The data source schema (e.g. ``"s3://"``). + host: URL for the host (e.g. bucket name). + """ + self.ctx.deregister_object_store(schema, host) + def register_listing_table( self, name: str, @@ -894,6 +903,14 @@ def register_udtf(self, func: TableFunction) -> None: """Register a user defined table function.""" self.ctx.register_udtf(func._udtf) + def deregister_udtf(self, name: str) -> None: + """Remove a user-defined table function from the session. + + Args: + name: Name of the UDTF to deregister. 
+ """ + self.ctx.deregister_udtf(name) + def register_record_batches( self, name: str, partitions: list[list[pa.RecordBatch]] ) -> None: @@ -1105,14 +1122,38 @@ def register_udf(self, udf: ScalarUDF) -> None: """Register a user-defined function (UDF) with the context.""" self.ctx.register_udf(udf._udf) + def deregister_udf(self, name: str) -> None: + """Remove a user-defined scalar function from the session. + + Args: + name: Name of the UDF to deregister. + """ + self.ctx.deregister_udf(name) + def register_udaf(self, udaf: AggregateUDF) -> None: """Register a user-defined aggregation function (UDAF) with the context.""" self.ctx.register_udaf(udaf._udaf) + def deregister_udaf(self, name: str) -> None: + """Remove a user-defined aggregate function from the session. + + Args: + name: Name of the UDAF to deregister. + """ + self.ctx.deregister_udaf(name) + def register_udwf(self, udwf: WindowUDF) -> None: """Register a user-defined window function (UDWF) with the context.""" self.ctx.register_udwf(udwf._udwf) + def deregister_udwf(self, name: str) -> None: + """Remove a user-defined window function from the session. + + Args: + name: Name of the UDWF to deregister. + """ + self.ctx.deregister_udwf(name) + def catalog(self, name: str = "datafusion") -> Catalog: """Retrieve a catalog by name.""" return Catalog(self.ctx.catalog(name)) diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 5df6ed20f..8491cc3a5 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -31,6 +31,7 @@ Table, column, literal, + udf, ) @@ -351,6 +352,125 @@ def test_deregister_table(ctx, database): assert public.names() == {"csv1", "csv2"} +def test_deregister_udf(): + ctx = SessionContext() + + is_null = udf( + lambda x: x.is_null(), + [pa.float64()], + pa.bool_(), + volatility="immutable", + name="my_is_null", + ) + ctx.register_udf(is_null) + + # Verify it works + df = ctx.from_pydict({"a": [1.0, None]}) + ctx.register_table("t", df.into_view()) + result = ctx.sql("SELECT my_is_null(a) FROM t").collect() + assert result[0].column(0) == pa.array([False, True]) + + # Deregister and verify it's gone + ctx.deregister_udf("my_is_null") + with pytest.raises(ValueError): + ctx.sql("SELECT my_is_null(a) FROM t").collect() + + +def test_deregister_udaf(): + import pyarrow.compute as pc + + ctx = SessionContext() + from datafusion import Accumulator, udaf + + class MySum(Accumulator): + def __init__(self): + self._sum = 0.0 + + def update(self, values: pa.Array) -> None: + self._sum += pc.sum(values).as_py() + + def merge(self, states: list[pa.Array]) -> None: + self._sum += pc.sum(states[0]).as_py() + + def state(self) -> list: + return [self._sum] + + def evaluate(self) -> pa.Scalar: + return self._sum + + my_sum = udaf( + MySum, + [pa.float64()], + pa.float64(), + [pa.float64()], + volatility="immutable", + name="my_sum", + ) + ctx.register_udaf(my_sum) + df = ctx.from_pydict({"a": [1.0, 2.0, 3.0]}) + ctx.register_table("t", df.into_view()) + + result = ctx.sql("SELECT my_sum(a) FROM t").collect() + assert result[0].column(0) == pa.array([6.0]) + + ctx.deregister_udaf("my_sum") + with pytest.raises(ValueError): + ctx.sql("SELECT my_sum(a) FROM t").collect() + + +def test_deregister_udwf(): + ctx = SessionContext() + from datafusion import udwf + from datafusion.user_defined import WindowEvaluator + + class MyRowNumber(WindowEvaluator): + def __init__(self): + self._row = 0 + + def evaluate_all(self, values, num_rows): + return pa.array(list(range(1, num_rows + 1)), 
type=pa.uint64()) + + my_row_number = udwf( + MyRowNumber, + [pa.float64()], + pa.uint64(), + volatility="immutable", + name="my_row_number", + ) + ctx.register_udwf(my_row_number) + df = ctx.from_pydict({"a": [1.0, 2.0, 3.0]}) + ctx.register_table("t", df.into_view()) + + result = ctx.sql("SELECT my_row_number(a) OVER () FROM t").collect() + assert result[0].column(0) == pa.array([1, 2, 3], type=pa.uint64()) + + ctx.deregister_udwf("my_row_number") + with pytest.raises(ValueError): + ctx.sql("SELECT my_row_number(a) OVER () FROM t").collect() + + +def test_deregister_udtf(): + import pyarrow.dataset as ds + + ctx = SessionContext() + from datafusion import Table, udtf + + class MyTable: + def __call__(self): + batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}) + return Table(ds.dataset([batch])) + + my_table = udtf(MyTable(), "my_table") + ctx.register_udtf(my_table) + + result = ctx.sql("SELECT * FROM my_table()").collect() + assert result[0].column(0) == pa.array([1, 2, 3]) + + ctx.deregister_udtf("my_table") + with pytest.raises(ValueError): + ctx.sql("SELECT * FROM my_table()").collect() + + def test_register_table_from_dataframe(ctx): df = ctx.from_pydict({"a": [1, 2]}) ctx.register_table("df_tbl", df) From aa3b1948c3a49d14395093287a6e93354229c539 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 8 Apr 2026 09:22:28 -0400 Subject: [PATCH 15/29] Add missing registration methods (#1474) * Add missing SessionContext read/register methods for Arrow IPC and batches Add read_arrow, read_empty, register_arrow, and register_batch methods to SessionContext, exposing upstream DataFusion v53 functionality. The write_* methods and read_batch/read_batches are already covered by DataFrame.write_* and SessionContext.from_arrow respectively. Closes #1458. Co-Authored-By: Claude Opus 4.6 (1M context) * Remove redundant read_empty Rust binding, make Python read_empty an alias for empty_table Co-Authored-By: Claude Opus 4.6 (1M context) * Add pathlib.Path and empty batch tests for Arrow IPC and register_batch Co-Authored-By: Claude Opus 4.6 (1M context) * Make test_read_empty more robust with length and num_rows checks Co-Authored-By: Claude Opus 4.6 (1M context) * Add examples to docstrings for new register/read methods Co-Authored-By: Claude Opus 4.6 (1M context) * Empty table actually returns record batch of length one but there are no columns * Add optional argument examples to register_arrow and read_arrow docstrings Demonstrate schema= and file_extension= keyword arguments in the docstring examples for register_arrow and read_arrow, following project guidelines for optional parameter documentation. Co-Authored-By: Claude Opus 4.6 (1M context) * Simplify read_empty docstring to use alias pattern Follow the same See Also alias convention used in functions.py since read_empty is a simple alias for empty_table. Co-Authored-By: Claude Opus 4.6 (1M context) * Remove shared ctx from doctest namespace, use inline SessionContext Avoid shared SessionContext state across doctests by having each docstring example create its own ctx instance, matching the pattern used throughout the rest of the codebase. Co-Authored-By: Claude Opus 4.6 (1M context) * Remove redundant import pyarrow as pa from docstrings The pa alias is already provided by the doctest namespace in conftest.py, so inline imports are unnecessary. 
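In rough outline, the new surface looks like this (a sketch mirroring the docstring examples added below; the file path is illustrative):

    import pyarrow as pa
    from datafusion import SessionContext

    ctx = SessionContext()

    # Register a single RecordBatch directly as a SQL table.
    batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]})
    ctx.register_batch("batch_tbl", batch)
    ctx.sql("SELECT * FROM batch_tbl").collect()

    # Write an Arrow IPC file, then read or register it.
    table = pa.table({"x": [10, 20, 30]})
    with pa.ipc.new_file("data.arrow", table.schema) as writer:
        writer.write_table(table)

    df = ctx.read_arrow("data.arrow")
    ctx.register_arrow("arrow_tbl", "data.arrow")

    # read_empty() is an alias for empty_table().
    ctx.read_empty()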
Co-Authored-By: Claude Opus 4.6 (1M context) --------- Co-authored-by: Claude Opus 4.6 (1M context) --- conftest.py | 2 + crates/core/src/context.rs | 58 +++++++++- python/datafusion/context.py | 181 ++++++++++++++++++++++++++++++ python/datafusion/user_defined.py | 3 - python/tests/test_context.py | 62 ++++++++++ 5 files changed, 302 insertions(+), 4 deletions(-) diff --git a/conftest.py b/conftest.py index 73e90077a..0c9410636 100644 --- a/conftest.py +++ b/conftest.py @@ -19,6 +19,7 @@ import datafusion as dfn import numpy as np +import pyarrow as pa import pytest from datafusion import col, lit from datafusion import functions as F @@ -29,6 +30,7 @@ def _doctest_namespace(doctest_namespace: dict) -> None: """Add common imports to the doctest namespace.""" doctest_namespace["dfn"] = dfn doctest_namespace["np"] = np + doctest_namespace["pa"] = pa doctest_namespace["col"] = col doctest_namespace["lit"] = lit doctest_namespace["F"] = F diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs index 1300a1595..ce11ef04e 100644 --- a/crates/core/src/context.rs +++ b/crates/core/src/context.rs @@ -41,7 +41,7 @@ use datafusion::execution::context::{ }; use datafusion::execution::disk_manager::DiskManagerMode; use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool}; -use datafusion::execution::options::ReadOptions; +use datafusion::execution::options::{ArrowReadOptions, ReadOptions}; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::prelude::{ @@ -974,6 +974,39 @@ impl PySessionContext { Ok(()) } + #[pyo3(signature = (name, path, schema=None, file_extension=".arrow", table_partition_cols=vec![]))] + pub fn register_arrow( + &self, + name: &str, + path: &str, + schema: Option>, + file_extension: &str, + table_partition_cols: Vec<(String, PyArrowType)>, + py: Python, + ) -> PyDataFusionResult<()> { + let mut options = ArrowReadOptions::default().table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::>(), + ); + options.file_extension = file_extension; + options.schema = schema.as_ref().map(|x| &x.0); + + let result = self.ctx.register_arrow(name, path, options); + wait_for_future(py, result)??; + Ok(()) + } + + pub fn register_batch( + &self, + name: &str, + batch: PyArrowType, + ) -> PyDataFusionResult<()> { + self.ctx.register_batch(name, batch.0)?; + Ok(()) + } + // Registers a PyArrow.Dataset pub fn register_dataset( &self, @@ -1214,6 +1247,29 @@ impl PySessionContext { Ok(PyDataFrame::new(df)) } + #[pyo3(signature = (path, schema=None, file_extension=".arrow", table_partition_cols=vec![]))] + pub fn read_arrow( + &self, + path: &str, + schema: Option>, + file_extension: &str, + table_partition_cols: Vec<(String, PyArrowType)>, + py: Python, + ) -> PyDataFusionResult { + let mut options = ArrowReadOptions::default().table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::>(), + ); + options.file_extension = file_extension; + options.schema = schema.as_ref().map(|x| &x.0); + + let result = self.ctx.read_arrow(path, options); + let df = wait_for_future(py, result)??; + Ok(PyDataFrame::new(df)) + } + pub fn read_table(&self, table: Bound<'_, PyAny>) -> PyDataFusionResult { let session = self.clone().into_bound_py_any(table.py())?; let table = PyTable::new(table, Some(session))?; diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 
f190e3ca1..7a306f04c 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -903,6 +903,27 @@ def register_udtf(self, func: TableFunction) -> None: """Register a user defined table function.""" self.ctx.register_udtf(func._udtf) + def register_batch(self, name: str, batch: pa.RecordBatch) -> None: + """Register a single :py:class:`pa.RecordBatch` as a table. + + Args: + name: Name of the resultant table. + batch: Record batch to register as a table. + + Examples: + >>> ctx = dfn.SessionContext() + >>> batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]}) + >>> ctx.register_batch("batch_tbl", batch) + >>> ctx.sql("SELECT * FROM batch_tbl").collect()[0].column(0) + + [ + 1, + 2, + 3 + ] + """ + self.ctx.register_batch(name, batch) + def deregister_udtf(self, name: str) -> None: """Remove a user-defined table function from the session. @@ -1109,6 +1130,86 @@ def register_avro( name, str(path), schema, file_extension, table_partition_cols ) + def register_arrow( + self, + name: str, + path: str | pathlib.Path, + schema: pa.Schema | None = None, + file_extension: str = ".arrow", + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, + ) -> None: + """Register an Arrow IPC file as a table. + + The registered table can be referenced from SQL statements executed + against this context. + + Args: + name: Name of the table to register. + path: Path to the Arrow IPC file. + schema: The data source schema. + file_extension: File extension to select. + table_partition_cols: Partition columns. + + Examples: + >>> import tempfile, os + >>> ctx = dfn.SessionContext() + >>> table = pa.table({"x": [10, 20, 30]}) + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... path = os.path.join(tmpdir, "data.arrow") + ... with pa.ipc.new_file(path, table.schema) as writer: + ... writer.write_table(table) + ... ctx.register_arrow("arrow_tbl", path) + ... ctx.sql("SELECT * FROM arrow_tbl").collect()[0].column(0) + + [ + 10, + 20, + 30 + ] + + Provide an explicit ``schema`` to override schema inference: + + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... path = os.path.join(tmpdir, "data.arrow") + ... with pa.ipc.new_file(path, table.schema) as writer: + ... writer.write_table(table) + ... ctx.register_arrow( + ... "arrow_schema", + ... path, + ... schema=pa.schema([("x", pa.int64())]), + ... ) + ... ctx.sql("SELECT * FROM arrow_schema").collect()[0].column(0) + + [ + 10, + 20, + 30 + ] + + Use ``file_extension`` to read files with a non-default extension: + + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... path = os.path.join(tmpdir, "data.ipc") + ... with pa.ipc.new_file(path, table.schema) as writer: + ... writer.write_table(table) + ... ctx.register_arrow( + ... "arrow_ipc", path, file_extension=".ipc" + ... ) + ... ctx.sql("SELECT * FROM arrow_ipc").collect()[0].column(0) + + [ + 10, + 20, + 30 + ] + """ + if table_partition_cols is None: + table_partition_cols = [] + table_partition_cols = _convert_table_partition_cols(table_partition_cols) + self.ctx.register_arrow( + name, str(path), schema, file_extension, table_partition_cols + ) + def register_dataset(self, name: str, dataset: pa.dataset.Dataset) -> None: """Register a :py:class:`pa.dataset.Dataset` as a table. 
@@ -1369,6 +1470,86 @@ def read_avro( self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension) ) + def read_arrow( + self, + path: str | pathlib.Path, + schema: pa.Schema | None = None, + file_extension: str = ".arrow", + file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, + ) -> DataFrame: + """Create a :py:class:`DataFrame` for reading an Arrow IPC data source. + + Args: + path: Path to the Arrow IPC file. + schema: The data source schema. + file_extension: File extension to select. + file_partition_cols: Partition columns. + + Returns: + DataFrame representation of the read Arrow IPC file. + + Examples: + >>> import tempfile, os + >>> ctx = dfn.SessionContext() + >>> table = pa.table({"a": [1, 2, 3]}) + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... path = os.path.join(tmpdir, "data.arrow") + ... with pa.ipc.new_file(path, table.schema) as writer: + ... writer.write_table(table) + ... df = ctx.read_arrow(path) + ... df.collect()[0].column(0) + + [ + 1, + 2, + 3 + ] + + Provide an explicit ``schema`` to override schema inference: + + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... path = os.path.join(tmpdir, "data.arrow") + ... with pa.ipc.new_file(path, table.schema) as writer: + ... writer.write_table(table) + ... df = ctx.read_arrow(path, schema=pa.schema([("a", pa.int64())])) + ... df.collect()[0].column(0) + + [ + 1, + 2, + 3 + ] + + Use ``file_extension`` to read files with a non-default extension: + + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... path = os.path.join(tmpdir, "data.ipc") + ... with pa.ipc.new_file(path, table.schema) as writer: + ... writer.write_table(table) + ... df = ctx.read_arrow(path, file_extension=".ipc") + ... df.collect()[0].column(0) + + [ + 1, + 2, + 3 + ] + """ + if file_partition_cols is None: + file_partition_cols = [] + file_partition_cols = _convert_table_partition_cols(file_partition_cols) + return DataFrame( + self.ctx.read_arrow(str(path), schema, file_extension, file_partition_cols) + ) + + def read_empty(self) -> DataFrame: + """Create an empty :py:class:`DataFrame` with no columns or rows. + + See Also: + This is an alias for :meth:`empty_table`. + """ + return self.empty_table() + def read_table( self, table: Table | TableProviderExportable | DataFrame | pa.dataset.Dataset ) -> DataFrame: diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 3eaccdfa3..848ab4cee 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -213,7 +213,6 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417 Examples: Using ``udf`` as a function: - >>> import pyarrow as pa >>> import pyarrow.compute as pc >>> from datafusion.user_defined import ScalarUDF >>> def double_func(x): @@ -480,7 +479,6 @@ def udaf(*args: Any, **kwargs: Any): # noqa: D417, C901 instance in which this UDAF is used. Examples: - >>> import pyarrow as pa >>> import pyarrow.compute as pc >>> from datafusion.user_defined import AggregateUDF, Accumulator, udaf >>> class Summarize(Accumulator): @@ -874,7 +872,6 @@ def udwf(*args: Any, **kwargs: Any): # noqa: D417 When using ``udwf`` as a decorator, do not pass ``func`` explicitly. Examples: - >>> import pyarrow as pa >>> from datafusion.user_defined import WindowUDF, WindowEvaluator, udwf >>> class BiasedNumbers(WindowEvaluator): ... 
def __init__(self, start: int = 0): diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 8491cc3a5..25f66a647 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -788,6 +788,68 @@ def test_read_avro(ctx): assert avro_df is not None +def test_read_arrow(ctx, tmp_path): + # Write an Arrow IPC file, then read it back + table = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]}) + arrow_path = tmp_path / "test.arrow" + with pa.ipc.new_file(str(arrow_path), table.schema) as writer: + writer.write_table(table) + + df = ctx.read_arrow(str(arrow_path)) + result = df.collect() + assert result[0].column(0) == pa.array([1, 2, 3]) + assert result[0].column(1) == pa.array(["x", "y", "z"]) + + # Also verify pathlib.Path works + df = ctx.read_arrow(arrow_path) + result = df.collect() + assert result[0].column(0) == pa.array([1, 2, 3]) + + +def test_read_empty(ctx): + df = ctx.read_empty() + result = df.collect() + assert len(result) == 1 + assert result[0].num_columns == 0 + + df = ctx.empty_table() + result = df.collect() + assert len(result) == 1 + assert result[0].num_columns == 0 + + +def test_register_arrow(ctx, tmp_path): + # Write an Arrow IPC file, then register and query it + table = pa.table({"x": [10, 20, 30]}) + arrow_path = tmp_path / "test.arrow" + with pa.ipc.new_file(str(arrow_path), table.schema) as writer: + writer.write_table(table) + + ctx.register_arrow("arrow_tbl", str(arrow_path)) + result = ctx.sql("SELECT * FROM arrow_tbl").collect() + assert result[0].column(0) == pa.array([10, 20, 30]) + + # Also verify pathlib.Path works + ctx.register_arrow("arrow_tbl_path", arrow_path) + result = ctx.sql("SELECT * FROM arrow_tbl_path").collect() + assert result[0].column(0) == pa.array([10, 20, 30]) + + +def test_register_batch(ctx): + batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}) + ctx.register_batch("batch_tbl", batch) + result = ctx.sql("SELECT * FROM batch_tbl").collect() + assert result[0].column(0) == pa.array([1, 2, 3]) + assert result[0].column(1) == pa.array([4, 5, 6]) + + +def test_register_batch_empty(ctx): + batch = pa.RecordBatch.from_pydict({"a": pa.array([], type=pa.int64())}) + ctx.register_batch("empty_batch_tbl", batch) + result = ctx.sql("SELECT * FROM empty_batch_tbl").collect() + assert result[0].num_rows == 0 + + def test_create_sql_options(): SQLOptions() From ecd14c10aff67169f2bfe1b7f86ff07621088dd0 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 8 Apr 2026 11:11:48 -0400 Subject: [PATCH 16/29] Add missing SessionContext utility methods (#1475) * Add missing SessionContext utility methods Expose upstream DataFusion v53 utility methods: session_start_time, enable_ident_normalization, parse_sql_expr, execute_logical_plan, refresh_catalogs, remove_optimizer_rule, and table_provider. The add_optimizer_rule and add_analyzer_rule methods are omitted as the OptimizerRule and AnalyzerRule traits are not yet exposed to Python. Closes #1459. 
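Taken together, the Python-side additions behave roughly as follows (a sketch assembled from the docstring examples and tests in this patch):

    import pyarrow as pa
    from datafusion import SessionContext
    from datafusion.common import DFSchema

    ctx = SessionContext()

    ctx.session_start_time()           # RFC 3339 timestamp string
    ctx.enable_ident_normalization()   # True unless disabled via SQL config

    expr = ctx.parse_sql_expr("1 + 2", DFSchema.empty())

    df = ctx.from_pydict({"a": [1, 2, 3]})
    ctx.execute_logical_plan(df.logical_plan()).collect()

    ctx.refresh_catalogs()
    ctx.remove_optimizer_rule("push_down_filter")   # True if the rule existed

    batch = pa.RecordBatch.from_pydict({"x": [1, 2]})
    ctx.register_record_batches("my_table", [[batch]])
    ctx.table_provider("my_table")     # raises KeyError for unknown tables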
Co-Authored-By: Claude Opus 4.6 (1M context) * Raise KeyError from table_provider for consistency with table() Co-Authored-By: Claude Opus 4.6 (1M context) * Add docstring examples for new SessionContext utility methods Co-Authored-By: Claude Opus 4.6 (1M context) * update docstring * Address PR review feedback for SessionContext utility methods - Improve docstring examples to show actual output instead of asserts - Use doctest +SKIP for non-deterministic session_start_time output - Fix table_provider error mapping: outer async error is now RuntimeError - Strengthen tests: validate RFC 3339 with fromisoformat, test both optimizer rule removal paths, exact string match for parse_sql_expr, verify enable_ident_normalization with dynamic state change Co-Authored-By: Claude Opus 4.6 (1M context) * Fix test_session_start_time failure on Python 3.10 datetime.fromisoformat() only supports up to 6 fractional-second digits (microseconds) on Python 3.10. Truncate nanosecond precision before parsing. Co-Authored-By: Claude Opus 4.6 (1M context) --------- Co-authored-by: Claude Opus 4.6 (1M context) --- crates/core/src/context.rs | 49 ++++++++++++++- python/datafusion/context.py | 118 ++++++++++++++++++++++++++++++++++- python/tests/test_context.py | 55 ++++++++++++++++ 3 files changed, 219 insertions(+), 3 deletions(-) diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs index ce11ef04e..b4fe524df 100644 --- a/crates/core/src/context.rs +++ b/crates/core/src/context.rs @@ -28,7 +28,7 @@ use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::arrow::record_batch::RecordBatch; use datafusion::catalog::{CatalogProvider, CatalogProviderList, TableProviderFactory}; -use datafusion::common::{ScalarValue, TableReference, exec_err}; +use datafusion::common::{DFSchema, ScalarValue, TableReference, exec_err}; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::{ @@ -60,7 +60,7 @@ use datafusion_python_util::{ }; use object_store::ObjectStore; use pyo3::IntoPyObjectExt; -use pyo3::exceptions::{PyKeyError, PyValueError}; +use pyo3::exceptions::{PyKeyError, PyRuntimeError, PyValueError}; use pyo3::prelude::*; use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple}; use url::Url; @@ -70,11 +70,13 @@ use crate::catalog::{ PyCatalog, PyCatalogList, RustWrappedPyCatalogProvider, RustWrappedPyCatalogProviderList, }; use crate::common::data_type::PyScalarValue; +use crate::common::df_schema::PyDFSchema; use crate::dataframe::PyDataFrame; use crate::dataset::Dataset; use crate::errors::{ PyDataFusionError, PyDataFusionResult, from_datafusion_error, py_datafusion_err, }; +use crate::expr::PyExpr; use crate::expr::sort_expr::PySortExpr; use crate::options::PyCsvReadOptions; use crate::physical_plan::PyExecutionPlan; @@ -1113,6 +1115,49 @@ impl PySessionContext { self.ctx.session_id() } + pub fn session_start_time(&self) -> String { + self.ctx.session_start_time().to_rfc3339() + } + + pub fn enable_ident_normalization(&self) -> bool { + self.ctx.enable_ident_normalization() + } + + pub fn parse_sql_expr(&self, sql: &str, schema: PyDFSchema) -> PyDataFusionResult { + let df_schema: DFSchema = schema.into(); + Ok(self.ctx.parse_sql_expr(sql, &df_schema)?.into()) + } + + pub fn execute_logical_plan( + &self, + plan: PyLogicalPlan, + py: Python, + ) -> PyDataFusionResult { + let df = wait_for_future( + py, + 
self.ctx.execute_logical_plan(plan.plan.as_ref().clone()), + )??; + Ok(PyDataFrame::new(df)) + } + + pub fn refresh_catalogs(&self, py: Python) -> PyDataFusionResult<()> { + wait_for_future(py, self.ctx.refresh_catalogs())??; + Ok(()) + } + + pub fn remove_optimizer_rule(&self, name: &str) -> bool { + self.ctx.remove_optimizer_rule(name) + } + + pub fn table_provider(&self, name: &str, py: Python) -> PyResult { + let provider = wait_for_future(py, self.ctx.table_provider(name)) + // Outer error: runtime/async failure + .map_err(|e| PyRuntimeError::new_err(e.to_string()))? + // Inner error: table not found + .map_err(|e| PyKeyError::new_err(e.to_string()))?; + Ok(PyTable { table: provider }) + } + #[allow(clippy::too_many_arguments)] #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![], file_compression_type=None))] pub fn read_json( diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 7a306f04c..e3949de83 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -63,7 +63,8 @@ import polars as pl # type: ignore[import] from datafusion.catalog import CatalogProvider, Table - from datafusion.expr import SortKey + from datafusion.common import DFSchema + from datafusion.expr import Expr, SortKey from datafusion.plan import ExecutionPlan, LogicalPlan from datafusion.user_defined import ( AggregateUDF, @@ -1283,6 +1284,121 @@ def session_id(self) -> str: """Return an id that uniquely identifies this :py:class:`SessionContext`.""" return self.ctx.session_id() + def session_start_time(self) -> str: + """Return the session start time as an RFC 3339 formatted string. + + Examples: + >>> ctx = SessionContext() + >>> ctx.session_start_time() # doctest: +SKIP + '2026-01-01T12:34:56.123456789+00:00' + """ + return self.ctx.session_start_time() + + def enable_ident_normalization(self) -> bool: + """Return whether identifier normalization (lowercasing) is enabled. + + Examples: + >>> ctx = SessionContext() + >>> ctx.enable_ident_normalization() + True + """ + return self.ctx.enable_ident_normalization() + + def parse_sql_expr(self, sql: str, schema: DFSchema) -> Expr: + """Parse a SQL expression string into a logical expression. + + Args: + sql: SQL expression string. + schema: Schema to use for resolving column references. + + Returns: + Parsed expression. + + Examples: + >>> from datafusion.common import DFSchema + >>> ctx = SessionContext() + >>> schema = DFSchema.empty() + >>> ctx.parse_sql_expr("1 + 2", schema=schema) + Expr(Int64(1) + Int64(2)) + """ + from datafusion.expr import Expr # noqa: PLC0415 + + return Expr(self.ctx.parse_sql_expr(sql, schema)) + + def execute_logical_plan(self, plan: LogicalPlan) -> DataFrame: + """Execute a :py:class:`~datafusion.plan.LogicalPlan` and return a DataFrame. + + Args: + plan: Logical plan to execute. + + Returns: + DataFrame resulting from the execution. + + Examples: + >>> ctx = SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3]}) + >>> plan = df.logical_plan() + >>> df2 = ctx.execute_logical_plan(plan) + >>> df2.collect()[0].column(0) + + [ + 1, + 2, + 3 + ] + """ + return DataFrame(self.ctx.execute_logical_plan(plan._raw_plan)) + + def refresh_catalogs(self) -> None: + """Refresh catalog metadata. + + Examples: + >>> ctx = SessionContext() + >>> ctx.refresh_catalogs() + """ + self.ctx.refresh_catalogs() + + def remove_optimizer_rule(self, name: str) -> bool: + """Remove an optimizer rule by name. 
+ + Args: + name: Name of the optimizer rule to remove. + + Returns: + True if a rule with the given name was found and removed. + + Examples: + >>> ctx = SessionContext() + >>> ctx.remove_optimizer_rule("nonexistent_rule") + False + """ + return self.ctx.remove_optimizer_rule(name) + + def table_provider(self, name: str) -> Table: + """Return the :py:class:`~datafusion.catalog.Table` for the given table name. + + Args: + name: Name of the table. + + Returns: + The table provider. + + Raises: + KeyError: If the table is not found. + + Examples: + >>> import pyarrow as pa + >>> ctx = SessionContext() + >>> batch = pa.RecordBatch.from_pydict({"x": [1, 2]}) + >>> ctx.register_record_batches("my_table", [[batch]]) + >>> tbl = ctx.table_provider("my_table") + >>> tbl.schema + x: int64 + """ + from datafusion.catalog import Table # noqa: PLC0415 + + return Table(self.ctx.table_provider(name)) + def read_json( self, path: str | pathlib.Path, diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 25f66a647..13c05a9e6 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -671,6 +671,61 @@ def test_table_not_found(ctx): ctx.table(f"not-found-{uuid4()}") +def test_session_start_time(ctx): + import datetime + import re + + st = ctx.session_start_time() + assert isinstance(st, str) + # Truncate nanoseconds to microseconds for Python 3.10 compat + st = re.sub(r"(\.\d{6})\d+", r"\1", st) + dt = datetime.datetime.fromisoformat(st) + assert dt.isoformat() + + +def test_enable_ident_normalization(ctx): + assert ctx.enable_ident_normalization() is True + ctx.sql("SET datafusion.sql_parser.enable_ident_normalization = false") + assert ctx.enable_ident_normalization() is False + + +def test_parse_sql_expr(ctx): + from datafusion.common import DFSchema + + schema = DFSchema.empty() + expr = ctx.parse_sql_expr("1 + 2", schema) + assert str(expr) == "Expr(Int64(1) + Int64(2))" + + +def test_execute_logical_plan(ctx): + df = ctx.from_pydict({"a": [1, 2, 3]}) + plan = df.logical_plan() + df2 = ctx.execute_logical_plan(plan) + result = df2.collect() + assert result[0].column(0) == pa.array([1, 2, 3]) + + +def test_refresh_catalogs(ctx): + ctx.refresh_catalogs() + + +def test_remove_optimizer_rule(ctx): + assert ctx.remove_optimizer_rule("push_down_filter") is True + assert ctx.remove_optimizer_rule("nonexistent_rule") is False + + +def test_table_provider(ctx): + batch = pa.RecordBatch.from_pydict({"x": [10, 20, 30]}) + ctx.register_record_batches("provider_test", [[batch]]) + tbl = ctx.table_provider("provider_test") + assert tbl.schema == pa.schema([("x", pa.int64())]) + + +def test_table_provider_not_found(ctx): + with pytest.raises(KeyError): + ctx.table_provider("nonexistent_table") + + def test_read_json(ctx): path = pathlib.Path(__file__).parent.resolve() From 3585c11eed778810e3317c56c2c25a8cdc29be5b Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 9 Apr 2026 07:38:59 -0400 Subject: [PATCH 17/29] minor: remove deprecated interfaces (#1481) * udf module has been deprecated since DF47. html_formatter module has been deprecated since DF48. 
* database has been deprecated since DF48 * select_columns has been deprecated since DF43 * unnest_column has been deprecated since DF42 * display_name has been deprecated since DF42 * window() has been deprecated since DF50 * serde functions have been deprecated since DF42 * from_arrow_table and tables have been deprecated since DF42 * RuntimeConfig has been deprecated since DF44 * Update user documentation to remove deprecated function * update tpch examples for latest function uses * Remove unnecessary options in example * update rendering for the most recent dataframe_formatter instead of the deprecated html_formatter --- crates/core/src/context.rs | 15 -- crates/core/src/dataframe.rs | 29 +-- crates/core/src/functions.rs | 133 +---------- .../user-guide/common-operations/windows.rst | 15 +- .../source/user-guide/dataframe/rendering.rst | 225 ++++++++++-------- examples/tpch/q02_minimum_cost_supplier.py | 8 +- .../q11_important_stock_identification.py | 3 +- examples/tpch/q15_top_supplier.py | 4 +- examples/tpch/q17_small_quantity_order.py | 8 +- examples/tpch/q22_global_sales_opportunity.py | 4 +- python/datafusion/__init__.py | 3 +- python/datafusion/catalog.py | 10 - python/datafusion/context.py | 21 -- python/datafusion/dataframe.py | 20 -- python/datafusion/dataframe_formatter.py | 8 +- python/datafusion/expr.py | 15 -- python/datafusion/functions.py | 55 +---- python/datafusion/html_formatter.py | 29 --- python/datafusion/substrait.py | 25 -- python/datafusion/udf.py | 29 --- python/tests/test_expr.py | 21 -- 21 files changed, 154 insertions(+), 526 deletions(-) delete mode 100644 python/datafusion/html_formatter.py delete mode 100644 python/datafusion/udf.py diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs index b4fe524df..e46d359d6 100644 --- a/crates/core/src/context.rs +++ b/crates/core/src/context.rs @@ -1072,21 +1072,6 @@ impl PySessionContext { self.ctx.catalog_names().into_iter().collect() } - pub fn tables(&self) -> HashSet { - self.ctx - .catalog_names() - .into_iter() - .filter_map(|name| self.ctx.catalog(&name)) - .flat_map(move |catalog| { - catalog - .schema_names() - .into_iter() - .filter_map(move |name| catalog.schema(&name)) - }) - .flat_map(|schema| schema.table_names()) - .collect() - } - pub fn table(&self, name: &str, py: Python) -> PyResult { let res = wait_for_future(py, self.ctx.table(name)) .map_err(|e| PyKeyError::new_err(e.to_string()))?; diff --git a/crates/core/src/dataframe.rs b/crates/core/src/dataframe.rs index fff5118d5..c067eac30 100644 --- a/crates/core/src/dataframe.rs +++ b/crates/core/src/dataframe.rs @@ -468,17 +468,17 @@ impl PyDataFrame { fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyDataFusionResult { if let Ok(key) = key.extract::() { // df[col] - self.select_columns(vec![key]) + self.select_exprs(vec![key]) } else if let Ok(tuple) = key.cast::() { // df[col1, col2, col3] let keys = tuple .iter() .map(|item| item.extract::()) .collect::>>()?; - self.select_columns(keys) + self.select_exprs(keys) } else if let Ok(keys) = key.extract::>() { // df[[col1, col2, col3]] - self.select_columns(keys) + self.select_exprs(keys) } else { let message = "DataFrame can only be indexed by string index or indices".to_string(); Err(PyDataFusionError::Common(message)) @@ -554,13 +554,6 @@ impl PyDataFrame { Ok(PyTable::from(table_provider)) } - #[pyo3(signature = (*args))] - fn select_columns(&self, args: Vec) -> PyDataFusionResult { - let args = args.iter().map(|s| s.as_ref()).collect::>(); - let df = 
self.df.as_ref().clone().select_columns(&args)?; - Ok(Self::new(df)) - } - #[pyo3(signature = (*args))] fn select_exprs(&self, args: Vec) -> PyDataFusionResult { let args = args.iter().map(|s| s.as_ref()).collect::>(); @@ -890,22 +883,6 @@ impl PyDataFrame { Ok(Self::new(new_df)) } - #[pyo3(signature = (column, preserve_nulls=true, recursions=None))] - fn unnest_column( - &self, - column: &str, - preserve_nulls: bool, - recursions: Option>, - ) -> PyDataFusionResult { - let unnest_options = build_unnest_options(preserve_nulls, recursions); - let df = self - .df - .as_ref() - .clone() - .unnest_columns_with_options(&[column], unnest_options)?; - Ok(Self::new(df)) - } - #[pyo3(signature = (columns, preserve_nulls=true, recursions=None))] fn unnest_columns( &self, diff --git a/crates/core/src/functions.rs b/crates/core/src/functions.rs index f173aaa51..7feb62d79 100644 --- a/crates/core/src/functions.rs +++ b/crates/core/src/functions.rs @@ -18,20 +18,14 @@ use std::collections::HashMap; use datafusion::common::{Column, ScalarValue, TableReference}; -use datafusion::execution::FunctionRegistry; -use datafusion::functions_aggregate::all_default_aggregate_functions; -use datafusion::functions_window::all_default_window_functions; -use datafusion::logical_expr::expr::{ - Alias, FieldMetadata, NullTreatment as DFNullTreatment, WindowFunction, WindowFunctionParams, -}; -use datafusion::logical_expr::{Expr, ExprFunctionExt, WindowFrame, WindowFunctionDefinition, lit}; +use datafusion::logical_expr::expr::{Alias, FieldMetadata, NullTreatment as DFNullTreatment}; +use datafusion::logical_expr::{Expr, ExprFunctionExt, lit}; use datafusion::{functions, functions_aggregate, functions_window}; use pyo3::prelude::*; use pyo3::wrap_pyfunction; use crate::common::data_type::{NullTreatment, PyScalarValue}; -use crate::context::PySessionContext; -use crate::errors::{PyDataFusionError, PyDataFusionResult}; +use crate::errors::PyDataFusionResult; use crate::expr::PyExpr; use crate::expr::conditional_expr::PyCaseBuilder; use crate::expr::sort_expr::{PySortExpr, to_sort_expressions}; @@ -306,126 +300,6 @@ fn when(when: PyExpr, then: PyExpr) -> PyResult { Ok(PyCaseBuilder::new(None).when(when, then)) } -/// Helper function to find the appropriate window function. -/// -/// Search procedure: -/// 1) Search built in window functions, which are being deprecated. -/// 1) If a session context is provided: -/// 1) search User Defined Aggregate Functions (UDAFs) -/// 1) search registered window functions -/// 1) search registered aggregate functions -/// 1) If no function has been found, search default aggregate functions. -/// -/// NOTE: we search the built-ins first because the `UDAF` versions currently do not have the same behavior. 
-fn find_window_fn( - name: &str, - ctx: Option, -) -> PyDataFusionResult { - if let Some(ctx) = ctx { - // search UDAFs - let udaf = ctx - .ctx - .udaf(name) - .map(WindowFunctionDefinition::AggregateUDF) - .ok(); - - if let Some(udaf) = udaf { - return Ok(udaf); - } - - let session_state = ctx.ctx.state(); - - // search registered window functions - let window_fn = session_state - .window_functions() - .get(name) - .map(|f| WindowFunctionDefinition::WindowUDF(f.clone())); - - if let Some(window_fn) = window_fn { - return Ok(window_fn); - } - - // search registered aggregate functions - let agg_fn = session_state - .aggregate_functions() - .get(name) - .map(|f| WindowFunctionDefinition::AggregateUDF(f.clone())); - - if let Some(agg_fn) = agg_fn { - return Ok(agg_fn); - } - } - - // search default aggregate functions - let agg_fn = all_default_aggregate_functions() - .iter() - .find(|v| v.name() == name || v.aliases().contains(&name.to_string())) - .map(|f| WindowFunctionDefinition::AggregateUDF(f.clone())); - - if let Some(agg_fn) = agg_fn { - return Ok(agg_fn); - } - - // search default window functions - let window_fn = all_default_window_functions() - .iter() - .find(|v| v.name() == name || v.aliases().contains(&name.to_string())) - .map(|f| WindowFunctionDefinition::WindowUDF(f.clone())); - - if let Some(window_fn) = window_fn { - return Ok(window_fn); - } - - Err(PyDataFusionError::Common(format!( - "window function `{name}` not found" - ))) -} - -/// Creates a new Window function expression -#[allow(clippy::too_many_arguments)] -#[pyfunction] -#[pyo3(signature = (name, args, partition_by=None, order_by=None, window_frame=None, filter=None, distinct=false, ctx=None))] -fn window( - name: &str, - args: Vec, - partition_by: Option>, - order_by: Option>, - window_frame: Option, - filter: Option, - distinct: bool, - ctx: Option, -) -> PyResult { - let fun = find_window_fn(name, ctx)?; - - let window_frame = window_frame - .map(|w| w.into()) - .unwrap_or(WindowFrame::new(order_by.as_ref().map(|v| !v.is_empty()))); - let filter = filter.map(|f| f.expr.into()); - - Ok(PyExpr { - expr: datafusion::logical_expr::Expr::WindowFunction(Box::new(WindowFunction { - fun, - params: WindowFunctionParams { - args: args.into_iter().map(|x| x.expr).collect::>(), - partition_by: partition_by - .unwrap_or_default() - .into_iter() - .map(|x| x.expr) - .collect::>(), - order_by: order_by - .unwrap_or_default() - .into_iter() - .map(|x| x.into()) - .collect::>(), - window_frame, - filter, - distinct, - null_treatment: None, - }, - })), - }) -} - // Generates a [pyo3] wrapper for associated aggregate functions. // All of the builder options are exposed to the python internal // function and we rely on the wrappers to only use those that @@ -1186,7 +1060,6 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(self::uuid))?; // Use self to avoid name collision m.add_wrapped(wrap_pyfunction!(var_pop))?; m.add_wrapped(wrap_pyfunction!(var_sample))?; - m.add_wrapped(wrap_pyfunction!(window))?; m.add_wrapped(wrap_pyfunction!(regr_avgx))?; m.add_wrapped(wrap_pyfunction!(regr_avgy))?; m.add_wrapped(wrap_pyfunction!(regr_count))?; diff --git a/docs/source/user-guide/common-operations/windows.rst b/docs/source/user-guide/common-operations/windows.rst index c8fdea8f4..d77881bcf 100644 --- a/docs/source/user-guide/common-operations/windows.rst +++ b/docs/source/user-guide/common-operations/windows.rst @@ -175,10 +175,7 @@ it's ``Type 2`` column that are null. 
Aggregate Functions ------------------- -You can use any :ref:`Aggregation Function` as a window function. Currently -aggregate functions must use the deprecated -:py:func:`datafusion.functions.window` API but this should be resolved in -DataFusion 42.0 (`Issue Link `_). Here +You can use any :ref:`Aggregation Function` as a window function. Here is an example that shows how to compare each pokemons’s attack power with the average attack power in its ``"Type 1"`` using the :py:func:`datafusion.functions.avg` function. @@ -189,10 +186,12 @@ power in its ``"Type 1"`` using the :py:func:`datafusion.functions.avg` function col('"Name"'), col('"Attack"'), col('"Type 1"'), - f.window("avg", [col('"Attack"')]) - .partition_by(col('"Type 1"')) - .build() - .alias("Average Attack"), + f.avg(col('"Attack"')).over( + Window( + window_frame=WindowFrame("rows", None, None), + partition_by=[col('"Type 1"')], + ) + ).alias("Average Attack"), ) Available Functions diff --git a/docs/source/user-guide/dataframe/rendering.rst b/docs/source/user-guide/dataframe/rendering.rst index 9dea948bb..dc61a422f 100644 --- a/docs/source/user-guide/dataframe/rendering.rst +++ b/docs/source/user-guide/dataframe/rendering.rst @@ -15,18 +15,18 @@ .. specific language governing permissions and limitations .. under the License. -HTML Rendering in Jupyter -========================= +DataFrame Rendering +=================== -When working in Jupyter notebooks or other environments that support rich HTML display, -DataFusion DataFrames automatically render as nicely formatted HTML tables. This functionality -is provided by the ``_repr_html_`` method, which is automatically called by Jupyter to provide -a richer visualization than plain text output. +DataFusion provides configurable rendering for DataFrames in both plain text and HTML +formats. The ``datafusion.dataframe_formatter`` module controls how DataFrames are +displayed in Jupyter notebooks (via ``_repr_html_``), in the terminal (via ``__repr__``), +and anywhere else a string or HTML representation is needed. -Basic HTML Rendering --------------------- +Basic Rendering +--------------- -In a Jupyter environment, simply displaying a DataFrame object will trigger HTML rendering: +In a Jupyter environment, displaying a DataFrame triggers HTML rendering: .. code-block:: python @@ -36,74 +36,117 @@ In a Jupyter environment, simply displaying a DataFrame object will trigger HTML # Explicit display also uses HTML rendering display(df) -Customizing HTML Rendering ---------------------------- +In a terminal or when converting to string, plain text rendering is used: + +.. code-block:: python -DataFusion provides extensive customization options for HTML table rendering through the -``datafusion.html_formatter`` module. + # Plain text table output + print(df) -Configuring the HTML Formatter -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Configuring the Formatter +------------------------- -You can customize how DataFrames are rendered by configuring the formatter: +You can customize how DataFrames are rendered by configuring the global formatter: .. 
code-block:: python - from datafusion.html_formatter import configure_formatter - - # Change the default styling + from datafusion.dataframe_formatter import configure_formatter + configure_formatter( - max_cell_length=25, # Maximum characters in a cell before truncation - max_width=1000, # Maximum width in pixels - max_height=300, # Maximum height in pixels - max_memory_bytes=2097152, # Maximum memory for rendering (2MB) - min_rows=10, # Minimum number of rows to display - max_rows=10, # Maximum rows to display in __repr__ - enable_cell_expansion=True,# Allow expanding truncated cells - custom_css=None, # Additional custom CSS + max_cell_length=25, # Maximum characters in a cell before truncation + max_width=1000, # Maximum width in pixels (HTML only) + max_height=300, # Maximum height in pixels (HTML only) + max_memory_bytes=2097152, # Maximum memory for rendering (2MB) + min_rows=10, # Minimum number of rows to display + max_rows=10, # Maximum rows to display + enable_cell_expansion=True, # Allow expanding truncated cells (HTML only) + custom_css=None, # Additional custom CSS (HTML only) show_truncation_message=True, # Show message when data is truncated - style_provider=None, # Custom styling provider - use_shared_styles=True # Share styles across tables + style_provider=None, # Custom styling provider (HTML only) + use_shared_styles=True, # Share styles across tables (HTML only) ) The formatter settings affect all DataFrames displayed after configuration. Custom Style Providers ------------------------ +---------------------- -For advanced styling needs, you can create a custom style provider: +For HTML styling, you can create a custom style provider that implements the +``StyleProvider`` protocol: .. code-block:: python - from datafusion.html_formatter import StyleProvider, configure_formatter - - class MyStyleProvider(StyleProvider): - def get_table_styles(self): - return { - "table": "border-collapse: collapse; width: 100%;", - "th": "background-color: #007bff; color: white; padding: 8px; text-align: left;", - "td": "border: 1px solid #ddd; padding: 8px;", - "tr:nth-child(even)": "background-color: #f2f2f2;", - } - - def get_value_styles(self, dtype, value): - """Return custom styles for specific values""" - if dtype == "float" and value < 0: - return "color: red;" - return None - + from datafusion.dataframe_formatter import configure_formatter + + class MyStyleProvider: + def get_cell_style(self): + """Return CSS style string for table data cells.""" + return "border: 1px solid #ddd; padding: 8px; text-align: left;" + + def get_header_style(self): + """Return CSS style string for table header cells.""" + return ( + "background-color: #007bff; color: white; " + "padding: 8px; text-align: left;" + ) + # Apply the custom style provider configure_formatter(style_provider=MyStyleProvider()) +Custom Cell Formatters +---------------------- + +You can register custom formatters for specific Python types. A cell formatter is any +callable that takes a value and returns a string: + +.. 
code-block:: python + + from datafusion.dataframe_formatter import get_formatter + + formatter = get_formatter() + + # Format floats to 2 decimal places + formatter.register_formatter(float, lambda v: f"{v:.2f}") + + # Format dates in a custom way + from datetime import date + formatter.register_formatter(date, lambda v: v.strftime("%B %d, %Y")) + +Custom Cell and Header Builders +------------------------------- + +For full control over the HTML of individual cells or headers, you can set custom +builder functions: + +.. code-block:: python + + from datafusion.dataframe_formatter import get_formatter + + formatter = get_formatter() + + # Custom cell builder receives (value, row, col, table_id) and returns HTML + def my_cell_builder(value, row, col, table_id): + color = "red" if isinstance(value, (int, float)) and value < 0 else "black" + return f"{value}" + + formatter.set_custom_cell_builder(my_cell_builder) + + # Custom header builder receives a schema field and returns HTML + def my_header_builder(field): + return f"{field.name}" + + formatter.set_custom_header_builder(my_header_builder) + Performance Optimization with Shared Styles -------------------------------------------- -The ``use_shared_styles`` parameter (enabled by default) optimizes performance when displaying -multiple DataFrames in notebook environments: +The ``use_shared_styles`` parameter (enabled by default) optimizes performance when +displaying multiple DataFrames in notebook environments: .. code-block:: python - from datafusion.html_formatter import StyleProvider, configure_formatter + from datafusion.dataframe_formatter import configure_formatter + # Default: Use shared styles (recommended for notebooks) configure_formatter(use_shared_styles=True) @@ -111,76 +154,48 @@ multiple DataFrames in notebook environments: configure_formatter(use_shared_styles=False) When ``use_shared_styles=True``: + - CSS styles and JavaScript are included only once per notebook session - This reduces HTML output size and prevents style duplication - Improves rendering performance with many DataFrames - Applies consistent styling across all DataFrames -Creating a Custom Formatter ----------------------------- +Working with the Formatter Directly +------------------------------------ -For complete control over rendering, you can implement a custom formatter: +You can use ``get_formatter()`` and ``set_formatter()`` for direct access to the global +formatter instance: .. code-block:: python - from datafusion.html_formatter import Formatter, get_formatter - - class MyFormatter(Formatter): - def format_html(self, batches, schema, has_more=False, table_uuid=None): - # Create your custom HTML here - html = "
" - # ... formatting logic ... - html += "
" - return html - - # Set as the global formatter - configure_formatter(formatter_class=MyFormatter) - - # Or use the formatter just for specific operations + from datafusion.dataframe_formatter import ( + DataFrameHtmlFormatter, + get_formatter, + set_formatter, + ) + + # Get and modify the current formatter formatter = get_formatter() - custom_html = formatter.format_html(batches, schema) + print(formatter.max_rows) + print(formatter.max_cell_length) -Managing Formatters -------------------- + # Create and set a fully custom formatter + custom_formatter = DataFrameHtmlFormatter( + max_cell_length=50, + max_rows=20, + enable_cell_expansion=False, + ) + set_formatter(custom_formatter) Reset to default formatting: .. code-block:: python - from datafusion.html_formatter import reset_formatter - + from datafusion.dataframe_formatter import reset_formatter + # Reset to default settings reset_formatter() -Get the current formatter settings: - -.. code-block:: python - - from datafusion.html_formatter import get_formatter - - formatter = get_formatter() - print(formatter.max_rows) - print(formatter.theme) - -Contextual Formatting ----------------------- - -You can also use a context manager to temporarily change formatting settings: - -.. code-block:: python - - from datafusion.html_formatter import formatting_context - - # Default formatting - df.show() - - # Temporarily use different formatting - with formatting_context(max_rows=100, theme="dark"): - df.show() # Will use the temporary settings - - # Back to default formatting - df.show() - Memory and Display Controls --------------------------- @@ -188,10 +203,12 @@ You can control how much data is displayed and how much memory is used for rende .. code-block:: python + from datafusion.dataframe_formatter import configure_formatter + configure_formatter( max_memory_bytes=4 * 1024 * 1024, # 4MB maximum memory for display min_rows=20, # Always show at least 20 rows - max_rows=50 # Show up to 50 rows in output + max_rows=50, # Show up to 50 rows in output ) These parameters help balance comprehensive data display against performance considerations. @@ -216,7 +233,7 @@ Additional Resources * :doc:`../io/index` - I/O Guide for reading data from various sources * :doc:`../data-sources` - Comprehensive data sources guide * :ref:`io_csv` - CSV file reading -* :ref:`io_parquet` - Parquet file reading +* :ref:`io_parquet` - Parquet file reading * :ref:`io_json` - JSON file reading * :ref:`io_avro` - Avro file reading * :ref:`io_custom_table_provider` - Custom table providers diff --git a/examples/tpch/q02_minimum_cost_supplier.py b/examples/tpch/q02_minimum_cost_supplier.py index 7390d0892..47961d2ef 100644 --- a/examples/tpch/q02_minimum_cost_supplier.py +++ b/examples/tpch/q02_minimum_cost_supplier.py @@ -32,6 +32,7 @@ import datafusion from datafusion import SessionContext, col, lit from datafusion import functions as F +from datafusion.expr import Window from util import get_data_path # This is the part we're looking for. 
Values selected here differ from the spec in order to run @@ -106,11 +107,8 @@ window_frame = datafusion.WindowFrame("rows", None, None) df = df.with_column( "min_cost", - F.window( - "min", - [col("ps_supplycost")], - partition_by=[col("ps_partkey")], - window_frame=window_frame, + F.min(col("ps_supplycost")).over( + Window(partition_by=[col("ps_partkey")], window_frame=window_frame) ), ) diff --git a/examples/tpch/q11_important_stock_identification.py b/examples/tpch/q11_important_stock_identification.py index 22829ab7c..de309fa64 100644 --- a/examples/tpch/q11_important_stock_identification.py +++ b/examples/tpch/q11_important_stock_identification.py @@ -29,6 +29,7 @@ from datafusion import SessionContext, WindowFrame, col, lit from datafusion import functions as F +from datafusion.expr import Window from util import get_data_path NATION = "GERMANY" @@ -71,7 +72,7 @@ window_frame = WindowFrame("rows", None, None) df = df.with_column( - "total_value", F.window("sum", [col("value")], window_frame=window_frame) + "total_value", F.sum(col("value")).over(Window(window_frame=window_frame)) ) # Limit to the parts for which there is a significant value based on the fraction of the total diff --git a/examples/tpch/q15_top_supplier.py b/examples/tpch/q15_top_supplier.py index c321048f2..5128937a7 100644 --- a/examples/tpch/q15_top_supplier.py +++ b/examples/tpch/q15_top_supplier.py @@ -31,6 +31,7 @@ import pyarrow as pa from datafusion import SessionContext, WindowFrame, col, lit from datafusion import functions as F +from datafusion.expr import Window from util import get_data_path DATE = "1996-01-01" @@ -70,7 +71,8 @@ # Use a window function to find the maximum revenue across the entire dataframe window_frame = WindowFrame("rows", None, None) df = df.with_column( - "max_revenue", F.window("max", [col("total_revenue")], window_frame=window_frame) + "max_revenue", + F.max(col("total_revenue")).over(Window(window_frame=window_frame)), ) # Find all suppliers whose total revenue is the same as the maximum diff --git a/examples/tpch/q17_small_quantity_order.py b/examples/tpch/q17_small_quantity_order.py index 6d76fe506..5ccb38422 100644 --- a/examples/tpch/q17_small_quantity_order.py +++ b/examples/tpch/q17_small_quantity_order.py @@ -30,6 +30,7 @@ from datafusion import SessionContext, WindowFrame, col, lit from datafusion import functions as F +from datafusion.expr import Window from util import get_data_path BRAND = "Brand#23" @@ -58,11 +59,8 @@ window_frame = WindowFrame("rows", None, None) df = df.with_column( "avg_quantity", - F.window( - "avg", - [col("l_quantity")], - window_frame=window_frame, - partition_by=[col("l_partkey")], + F.avg(col("l_quantity")).over( + Window(partition_by=[col("l_partkey")], window_frame=window_frame) ), ) diff --git a/examples/tpch/q22_global_sales_opportunity.py b/examples/tpch/q22_global_sales_opportunity.py index c4d115b74..a2d41b215 100644 --- a/examples/tpch/q22_global_sales_opportunity.py +++ b/examples/tpch/q22_global_sales_opportunity.py @@ -28,6 +28,7 @@ from datafusion import SessionContext, WindowFrame, col, lit from datafusion import functions as F +from datafusion.expr import Window from util import get_data_path NATION_CODES = [13, 31, 23, 29, 30, 18, 17] @@ -55,7 +56,8 @@ # current row. We want our frame to cover the entire data frame. 
window_frame = WindowFrame("rows", None, None) df = df.with_column( - "avg_balance", F.window("avg", [col("c_acctbal")], window_frame=window_frame) + "avg_balance", + F.avg(col("c_acctbal")).over(Window(window_frame=window_frame)), ) df.show() diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index a736c3966..ee02c921d 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -35,7 +35,7 @@ # The following imports are okay to remain as opaque to the user. from ._internal import Config -from .catalog import Catalog, Database, Table +from .catalog import Catalog, Table from .col import col, column from .common import DFSchema from .context import ( @@ -81,7 +81,6 @@ "DFSchema", "DataFrame", "DataFrameWriteOptions", - "Database", "ExecutionPlan", "ExplainFormat", "Expr", diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index 03c0ddc68..20da5e671 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -129,11 +129,6 @@ def schema(self, name: str = "public") -> Schema: else schema ) - @deprecated("Use `schema` instead.") - def database(self, name: str = "public") -> Schema: - """Returns the database with the given ``name`` from this catalog.""" - return self.schema(name) - def register_schema( self, name: str, @@ -195,11 +190,6 @@ def table_exist(self, name: str) -> bool: return self._raw_schema.table_exist(name) -@deprecated("Use `Schema` instead.") -class Database(Schema): - """See `Schema`.""" - - class Table: """A DataFusion table. diff --git a/python/datafusion/context.py b/python/datafusion/context.py index e3949de83..c3f94cc16 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -426,11 +426,6 @@ def with_temp_file_path(self, path: str | pathlib.Path) -> RuntimeEnvBuilder: return self -@deprecated("Use `RuntimeEnvBuilder` instead.") -class RuntimeConfig(RuntimeEnvBuilder): - """See `RuntimeEnvBuilder`.""" - - class SQLOptions: """Options to be used when performing SQL queries.""" @@ -785,14 +780,6 @@ def from_arrow( """ return DataFrame(self.ctx.from_arrow(data, name)) - @deprecated("Use ``from_arrow`` instead.") - def from_arrow_table(self, data: pa.Table, name: str | None = None) -> DataFrame: - """Create a :py:class:`~datafusion.dataframe.DataFrame` from an Arrow table. - - This is an alias for :py:func:`from_arrow`. - """ - return self.from_arrow(data, name) - def from_pandas(self, data: pd.DataFrame, name: str | None = None) -> DataFrame: """Create a :py:class:`~datafusion.dataframe.DataFrame` from a Pandas DataFrame. @@ -1260,14 +1247,6 @@ def catalog(self, name: str = "datafusion") -> Catalog: """Retrieve a catalog by name.""" return Catalog(self.ctx.catalog(name)) - @deprecated( - "Use the catalog provider interface ``SessionContext.Catalog`` to " - "examine available catalogs, schemas and tables" - ) - def tables(self) -> set[str]: - """Deprecated.""" - return self.ctx.tables() - def table(self, name: str) -> DataFrame: """Retrieve a previously registered table by name.""" return DataFrame(self.ctx.table(name)) diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 9dc5f0e7d..c00c85fdb 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -489,17 +489,6 @@ def find_qualified_columns(self, *names: str) -> list[Expr]: raw_exprs = self.df.find_qualified_columns(list(names)) return [Expr(e) for e in raw_exprs] - @deprecated( - "select_columns() is deprecated. 
Use :py:meth:`~DataFrame.select` instead" - ) - def select_columns(self, *args: str) -> DataFrame: - """Filter the DataFrame by columns. - - Returns: - DataFrame only containing the specified columns. - """ - return self.select(*args) - def select_exprs(self, *args: str) -> DataFrame: """Project arbitrary list of expression strings into a new DataFrame. @@ -1603,15 +1592,6 @@ def count(self) -> int: """ return self.df.count() - @deprecated("Use :py:func:`unnest_columns` instead.") - def unnest_column( - self, - column: str, - preserve_nulls: bool = True, - ) -> DataFrame: - """See :py:func:`unnest_columns`.""" - return DataFrame(self.df.unnest_column(column, preserve_nulls=preserve_nulls)) - def unnest_columns( self, *columns: str, diff --git a/python/datafusion/dataframe_formatter.py b/python/datafusion/dataframe_formatter.py index b8af45a1b..fd2da99f0 100644 --- a/python/datafusion/dataframe_formatter.py +++ b/python/datafusion/dataframe_formatter.py @@ -748,7 +748,7 @@ def get_formatter() -> DataFrameHtmlFormatter: The global HTML formatter instance Example: - >>> from datafusion.html_formatter import get_formatter + >>> from datafusion.dataframe_formatter import get_formatter >>> formatter = get_formatter() >>> formatter.max_cell_length = 50 # Increase cell length """ @@ -762,7 +762,7 @@ def set_formatter(formatter: DataFrameHtmlFormatter) -> None: formatter: The formatter instance to use globally Example: - >>> from datafusion.html_formatter import get_formatter, set_formatter + >>> from datafusion.dataframe_formatter import get_formatter, set_formatter >>> custom_formatter = DataFrameHtmlFormatter(max_cell_length=100) >>> set_formatter(custom_formatter) """ @@ -783,7 +783,7 @@ def configure_formatter(**kwargs: Any) -> None: ValueError: If any invalid parameters are provided Example: - >>> from datafusion.html_formatter import configure_formatter + >>> from datafusion.dataframe_formatter import configure_formatter >>> configure_formatter( ... max_cell_length=50, ... max_height=500, @@ -827,7 +827,7 @@ def reset_formatter() -> None: and sets it as the global formatter for all DataFrames. Example: - >>> from datafusion.html_formatter import reset_formatter + >>> from datafusion.dataframe_formatter import reset_formatter >>> reset_formatter() # Reset formatter to default settings """ formatter = DataFrameHtmlFormatter() diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 35388468c..7cd74ecd5 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -27,11 +27,6 @@ from collections.abc import Iterable, Sequence from typing import TYPE_CHECKING, Any, ClassVar -try: - from warnings import deprecated # Python 3.13+ -except ImportError: - from typing_extensions import deprecated # Python 3.12 - import pyarrow as pa from ._internal import expr as expr_internal @@ -356,16 +351,6 @@ def to_variant(self) -> Any: """Convert this expression into a python object if possible.""" return self.expr.to_variant() - @deprecated( - "display_name() is deprecated. Use :py:meth:`~Expr.schema_name` instead" - ) - def display_name(self) -> str: - """Returns the name of this expression as it should appear in a schema. - - This name will not include any CAST expressions. - """ - return self.schema_name() - def schema_name(self) -> str: """Returns the name of this expression as it should appear in a schema. 
diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 9dfabb62d..841cd9c0b 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -18,7 +18,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import Any import pyarrow as pa @@ -29,19 +29,11 @@ Expr, SortExpr, SortKey, - WindowFrame, expr_list_to_raw_expr_list, sort_list_to_raw_sort_list, sort_or_default, ) -try: - from warnings import deprecated # Python 3.13+ -except ImportError: - from typing_extensions import deprecated # Python 3.12 - -if TYPE_CHECKING: - from datafusion.context import SessionContext __all__ = [ "abs", "acos", @@ -339,8 +331,6 @@ "var_sample", "version", "when", - # Window Functions - "window", ] @@ -664,49 +654,6 @@ def when(when: Expr, then: Expr) -> CaseBuilder: return CaseBuilder(f.when(when.expr, then.expr)) -@deprecated("Prefer to call Expr.over() instead") -def window( - name: str, - args: list[Expr], - partition_by: list[Expr] | Expr | None = None, - order_by: list[SortKey] | SortKey | None = None, - window_frame: WindowFrame | None = None, - filter: Expr | None = None, - distinct: bool = False, - ctx: SessionContext | None = None, -) -> Expr: - """Creates a new Window function expression. - - This interface will soon be deprecated. Instead of using this interface, - users should call the window functions directly. For example, to perform a - lag use:: - - df.select(functions.lag(col("a")).partition_by(col("b")).build()) - - The ``order_by`` parameter accepts column names or expressions, e.g.:: - - window("lag", [col("a")], order_by="ts") - """ - args = [a.expr for a in args] - partition_by_raw = expr_list_to_raw_expr_list(partition_by) - order_by_raw = sort_list_to_raw_sort_list(order_by) - window_frame = window_frame.window_frame if window_frame is not None else None - ctx = ctx.ctx if ctx is not None else None - filter_raw = filter.expr if filter is not None else None - return Expr( - f.window( - name, - args, - partition_by=partition_by_raw, - order_by=order_by_raw, - window_frame=window_frame, - ctx=ctx, - filter=filter_raw, - distinct=distinct, - ) - ) - - # scalar functions def abs(arg: Expr) -> Expr: """Return the absolute value of a given number. diff --git a/python/datafusion/html_formatter.py b/python/datafusion/html_formatter.py deleted file mode 100644 index 65eb1f042..000000000 --- a/python/datafusion/html_formatter.py +++ /dev/null @@ -1,29 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Deprecated module for dataframe formatting.""" - -import warnings - -from datafusion.dataframe_formatter import * # noqa: F403 - -warnings.warn( - "The module 'html_formatter' is deprecated and will be removed in the next release." 
- "Please use 'dataframe_formatter' instead.", - DeprecationWarning, - stacklevel=3, -) diff --git a/python/datafusion/substrait.py b/python/datafusion/substrait.py index 3115238fa..6353ef8cc 100644 --- a/python/datafusion/substrait.py +++ b/python/datafusion/substrait.py @@ -25,11 +25,6 @@ from typing import TYPE_CHECKING -try: - from warnings import deprecated # Python 3.13+ -except ImportError: - from typing_extensions import deprecated # Python 3.12 - from datafusion.plan import LogicalPlan from ._internal import substrait as substrait_internal @@ -88,11 +83,6 @@ def from_json(json: str) -> Plan: return Plan(substrait_internal.Plan.from_json(json)) -@deprecated("Use `Plan` instead.") -class plan(Plan): # noqa: N801 - """See `Plan`.""" - - class Serde: """Provides the ``Substrait`` serialization and deserialization.""" @@ -158,11 +148,6 @@ def deserialize_bytes(proto_bytes: bytes) -> Plan: return Plan(substrait_internal.Serde.deserialize_bytes(proto_bytes)) -@deprecated("Use `Serde` instead.") -class serde(Serde): # noqa: N801 - """See `Serde` instead.""" - - class Producer: """Generates substrait plans from a logical plan.""" @@ -184,11 +169,6 @@ def to_substrait_plan(logical_plan: LogicalPlan, ctx: SessionContext) -> Plan: ) -@deprecated("Use `Producer` instead.") -class producer(Producer): # noqa: N801 - """Use `Producer` instead.""" - - class Consumer: """Generates a logical plan from a substrait plan.""" @@ -206,8 +186,3 @@ def from_substrait_plan(ctx: SessionContext, plan: Plan) -> LogicalPlan: return LogicalPlan( substrait_internal.Consumer.from_substrait_plan(ctx.ctx, plan.plan_internal) ) - - -@deprecated("Use `Consumer` instead.") -class consumer(Consumer): # noqa: N801 - """Use `Consumer` instead.""" diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py deleted file mode 100644 index c7265fa09..000000000 --- a/python/datafusion/udf.py +++ /dev/null @@ -1,29 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Deprecated module for user defined functions.""" - -import warnings - -from datafusion.user_defined import * # noqa: F403 - -warnings.warn( - "The module 'udf' is deprecated and will be removed in the next release. 
" - "Please use 'user_defined' instead.", - DeprecationWarning, - stacklevel=2, -) diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 9a287c1f7..1cf824a15 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -319,27 +319,6 @@ def test_expr_getitem() -> None: assert array_values == [2, 5, None, None] -def test_display_name_deprecation(): - import warnings - - expr = col("foo") - with warnings.catch_warnings(record=True) as w: - # Cause all warnings to always be triggered - warnings.simplefilter("always") - - # should trigger warning - name = expr.display_name() - - # Verify some things - assert len(w) == 1 - assert issubclass(w[-1].category, DeprecationWarning) - assert "deprecated" in str(w[-1].message) - - # returns appropriate result - assert name == expr.schema_name() - assert name == "foo" - - @pytest.fixture def df(): ctx = SessionContext() From 1be838bb47f04bcf4d1a0f65e3e6958aa9366f3f Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 12 Apr 2026 21:24:39 -0400 Subject: [PATCH 18/29] Release 53.0.0 (#1491) * Update version number and changelog * minor: set version number on dependency to publish to crates.io * taplo fmt --- Cargo.lock | 6 +-- Cargo.toml | 6 +-- dev/changelog/53.0.0.md | 107 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 113 insertions(+), 6 deletions(-) create mode 100644 dev/changelog/53.0.0.md diff --git a/Cargo.lock b/Cargo.lock index ee89c8bda..1cbb0acb8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1312,7 +1312,7 @@ dependencies = [ [[package]] name = "datafusion-ffi-example" -version = "52.0.0" +version = "53.0.0" dependencies = [ "arrow", "arrow-array", @@ -1662,7 +1662,7 @@ dependencies = [ [[package]] name = "datafusion-python" -version = "52.0.0" +version = "53.0.0" dependencies = [ "arrow", "arrow-select", @@ -1692,7 +1692,7 @@ dependencies = [ [[package]] name = "datafusion-python-util" -version = "52.0.0" +version = "53.0.0" dependencies = [ "arrow", "datafusion", diff --git a/Cargo.toml b/Cargo.toml index 3a34e204c..14408d2bc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ # under the License. [workspace.package] -version = "52.0.0" +version = "53.0.0" homepage = "https://datafusion.apache.org/python" repository = "https://github.com/apache/datafusion-python" authors = ["Apache DataFusion "] @@ -59,9 +59,9 @@ object_store = { version = "0.13.1" } url = "2" log = "0.4.29" parking_lot = "0.12" -prost-types = "0.14.3" # keep in line with `datafusion-substrait` +prost-types = "0.14.3" # keep in line with `datafusion-substrait` pyo3-build-config = "0.28" -datafusion-python-util = { path = "crates/util" } +datafusion-python-util = { path = "crates/util", version = "53.0.0" } [profile.release] lto = "thin" diff --git a/dev/changelog/53.0.0.md b/dev/changelog/53.0.0.md new file mode 100644 index 000000000..3e27a852d --- /dev/null +++ b/dev/changelog/53.0.0.md @@ -0,0 +1,107 @@ + + +# Apache DataFusion Python 53.0.0 Changelog + +This release consists of 52 commits from 9 contributors. See credits at the end of this changelog for more information. 
+ +**Breaking changes:** + +- minor: remove deprecated interfaces [#1481](https://github.com/apache/datafusion-python/pull/1481) (timsaucer) + +**Implemented enhancements:** + +- feat: feat: add to_time, to_local_time, to_date functions [#1387](https://github.com/apache/datafusion-python/pull/1387) (mesejo) +- feat: Add FFI_TableProviderFactory support [#1396](https://github.com/apache/datafusion-python/pull/1396) (davisp) + +**Fixed bugs:** + +- fix: satisfy rustfmt check in lib.rs re-exports [#1406](https://github.com/apache/datafusion-python/pull/1406) (kevinjqliu) + +**Documentation updates:** + +- docs: clarify DataFusion 52 FFI session-parameter requirement for provider hooks [#1439](https://github.com/apache/datafusion-python/pull/1439) (kevinjqliu) + +**Other:** + +- Merge release 52.0.0 into main [#1389](https://github.com/apache/datafusion-python/pull/1389) (timsaucer) +- Add workflow to verify release candidate on multiple systems [#1388](https://github.com/apache/datafusion-python/pull/1388) (timsaucer) +- Allow running "verify release candidate" github workflow on Windows [#1392](https://github.com/apache/datafusion-python/pull/1392) (kevinjqliu) +- ci: update pre-commit hooks, fix linting, and refresh dependencies [#1385](https://github.com/apache/datafusion-python/pull/1385) (dariocurr) +- Add CI check for crates.io patches [#1407](https://github.com/apache/datafusion-python/pull/1407) (timsaucer) +- Enable doc tests in local and CI testing [#1409](https://github.com/apache/datafusion-python/pull/1409) (ntjohnson1) +- Upgrade to DataFusion 53 [#1402](https://github.com/apache/datafusion-python/pull/1402) (nuno-faria) +- Catch warnings in FFI unit tests [#1410](https://github.com/apache/datafusion-python/pull/1410) (timsaucer) +- Add docstring examples for Scalar trigonometric functions [#1411](https://github.com/apache/datafusion-python/pull/1411) (ntjohnson1) +- Create workspace with core and util crates [#1414](https://github.com/apache/datafusion-python/pull/1414) (timsaucer) +- Add docstring examples for Scalar regex, crypto, struct and other [#1422](https://github.com/apache/datafusion-python/pull/1422) (ntjohnson1) +- Add docstring examples for Scalar math functions [#1421](https://github.com/apache/datafusion-python/pull/1421) (ntjohnson1) +- Add docstring examples for Common utility functions [#1419](https://github.com/apache/datafusion-python/pull/1419) (ntjohnson1) +- Add docstring examples for Aggregate basic and bitwise/boolean functions [#1416](https://github.com/apache/datafusion-python/pull/1416) (ntjohnson1) +- Fix CI errors on main [#1432](https://github.com/apache/datafusion-python/pull/1432) (timsaucer) +- Add docstring examples for Scalar temporal functions [#1424](https://github.com/apache/datafusion-python/pull/1424) (ntjohnson1) +- Add docstring examples for Aggregate statistical and regression functions [#1417](https://github.com/apache/datafusion-python/pull/1417) (ntjohnson1) +- Add docstring examples for Scalar array/list functions [#1420](https://github.com/apache/datafusion-python/pull/1420) (ntjohnson1) +- Add docstring examples for Scalar string functions [#1423](https://github.com/apache/datafusion-python/pull/1423) (ntjohnson1) +- Add docstring examples for Aggregate window functions [#1418](https://github.com/apache/datafusion-python/pull/1418) (ntjohnson1) +- ci: pin third-party actions to Apache-approved SHAs [#1438](https://github.com/apache/datafusion-python/pull/1438) (kevinjqliu) +- minor: bump datafusion to release version 
[#1441](https://github.com/apache/datafusion-python/pull/1441) (timsaucer) +- ci: add swap during build, use tpchgen-cli [#1443](https://github.com/apache/datafusion-python/pull/1443) (timsaucer) +- Update remaining existing examples to make testable/standalone executable [#1437](https://github.com/apache/datafusion-python/pull/1437) (ntjohnson1) +- Do not run validate_pycapsule if pointer_checked is used [#1426](https://github.com/apache/datafusion-python/pull/1426) (Tpt) +- Implement configuration extension support [#1391](https://github.com/apache/datafusion-python/pull/1391) (timsaucer) +- Add a working, more complete example of using a catalog (docs) [#1427](https://github.com/apache/datafusion-python/pull/1427) (toppyy) +- chore: update dependencies [#1447](https://github.com/apache/datafusion-python/pull/1447) (timsaucer) +- Complete doc string examples for functions.py [#1435](https://github.com/apache/datafusion-python/pull/1435) (ntjohnson1) +- chore: enforce uv lockfile consistency in CI and pre-commit [#1398](https://github.com/apache/datafusion-python/pull/1398) (mesejo) +- CI: Add CodeQL workflow for GitHub Actions security scanning [#1408](https://github.com/apache/datafusion-python/pull/1408) (kevinjqliu) +- ci: update codespell paths [#1469](https://github.com/apache/datafusion-python/pull/1469) (timsaucer) +- Add missing datetime functions [#1467](https://github.com/apache/datafusion-python/pull/1467) (timsaucer) +- Add AI skill to check current repository against upstream APIs [#1460](https://github.com/apache/datafusion-python/pull/1460) (timsaucer) +- Add missing string function `contains` [#1465](https://github.com/apache/datafusion-python/pull/1465) (timsaucer) +- Add missing conditional functions [#1464](https://github.com/apache/datafusion-python/pull/1464) (timsaucer) +- Reduce peak memory usage during release builds to fix OOM on manylinux runners [#1445](https://github.com/apache/datafusion-python/pull/1445) (kevinjqliu) +- Add missing map functions [#1461](https://github.com/apache/datafusion-python/pull/1461) (timsaucer) +- minor: Fix pytest instructions in the README [#1477](https://github.com/apache/datafusion-python/pull/1477) (nuno-faria) +- Add missing array functions [#1468](https://github.com/apache/datafusion-python/pull/1468) (timsaucer) +- Add missing scalar functions [#1470](https://github.com/apache/datafusion-python/pull/1470) (timsaucer) +- Add missing aggregate functions [#1471](https://github.com/apache/datafusion-python/pull/1471) (timsaucer) +- Add missing Dataframe functions [#1472](https://github.com/apache/datafusion-python/pull/1472) (timsaucer) +- Add missing deregister methods to SessionContext [#1473](https://github.com/apache/datafusion-python/pull/1473) (timsaucer) +- Add missing registration methods [#1474](https://github.com/apache/datafusion-python/pull/1474) (timsaucer) +- Add missing SessionContext utility methods [#1475](https://github.com/apache/datafusion-python/pull/1475) (timsaucer) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 25 Tim Saucer + 13 Nick + 6 Kevin Liu + 2 Daniel Mesejo + 2 Nuno Faria + 1 Paul J. Davis + 1 Thomas Tanon + 1 Topias Pyykkönen + 1 dario curreri +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. 
+ From 00b24572c98a257f06ff026a90c07634a86204d4 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 13 Apr 2026 06:33:29 -0400 Subject: [PATCH 19/29] ci: disable symbol export on Windows verification (#1486) * Set rust flags on windows release verification * Forward flag to linker * Switch to msvc rust toolchain * Revert "Switch to msvc rust toolchain" This reverts commit 9879fc7dbe066098445b9600087e665435b58f8a. --- .github/workflows/verify-release-candidate.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/verify-release-candidate.yml b/.github/workflows/verify-release-candidate.yml index a10a4faa9..7a5deff5b 100644 --- a/.github/workflows/verify-release-candidate.yml +++ b/.github/workflows/verify-release-candidate.yml @@ -73,6 +73,11 @@ jobs: version: "27.4" repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Set RUSTFLAGS for Windows GNU linker + if: matrix.os == 'windows' + shell: bash + run: echo "RUSTFLAGS=-C link-arg=-Wl,--exclude-libs=ALL" >> "$GITHUB_ENV" + - name: Run release candidate verification shell: bash run: ./dev/release/verify-release-candidate.sh "${{ inputs.version }}" "${{ inputs.rc_number }}" From 8a7efead43cff8dc7515e27e53da7545100e25a7 Mon Sep 17 00:00:00 2001 From: Shreyesh Date: Mon, 13 Apr 2026 03:34:35 -0700 Subject: [PATCH 20/29] Add Python bindings for accessing ExecutionMetrics (#1381) * feat: add Python bindings for accessing ExecutionMetrics * test: imporve tests * first round of reviews * plan caching * address some concerns * merge and address comments * fix Ci issues * attempt to fix lint * fix build * fix docstring * address some more comments --------- Co-authored-by: ShreyeshArangath --- Cargo.lock | 1 + Cargo.toml | 1 + crates/core/Cargo.toml | 1 + crates/core/src/dataframe.rs | 49 +++- crates/core/src/lib.rs | 3 + crates/core/src/metrics.rs | 169 ++++++++++++++ crates/core/src/physical_plan.rs | 5 + .../dataframe/execution-metrics.rst | 215 ++++++++++++++++++ docs/source/user-guide/dataframe/index.rst | 9 + python/datafusion/__init__.py | 4 +- python/datafusion/plan.py | 177 ++++++++++++++ python/tests/test_plans.py | 192 +++++++++++++++- 12 files changed, 817 insertions(+), 9 deletions(-) create mode 100644 crates/core/src/metrics.rs create mode 100644 docs/source/user-guide/dataframe/execution-metrics.rst diff --git a/Cargo.lock b/Cargo.lock index 1cbb0acb8..4efca3eb6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1667,6 +1667,7 @@ dependencies = [ "arrow", "arrow-select", "async-trait", + "chrono", "cstr", "datafusion", "datafusion-ffi", diff --git a/Cargo.toml b/Cargo.toml index 14408d2bc..d0e87a9a4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,7 @@ tokio = { version = "1.50" } pyo3 = { version = "0.28" } pyo3-async-runtimes = { version = "0.28" } pyo3-log = "0.13.3" +chrono = { version = "0.4", default-features = false } arrow = { version = "58" } arrow-array = { version = "58" } arrow-schema = { version = "58" } diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 3e2b01c8e..d714dc978 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -47,6 +47,7 @@ pyo3 = { workspace = true, features = [ ] } pyo3-async-runtimes = { workspace = true, features = ["tokio-runtime"] } pyo3-log = { workspace = true } +chrono = { workspace = true } arrow = { workspace = true, features = ["pyarrow"] } arrow-select = { workspace = true } datafusion = { workspace = true, features = ["avro", "unicode_expressions"] } diff --git a/crates/core/src/dataframe.rs b/crates/core/src/dataframe.rs index 
c067eac30..2d815ec76 100644 --- a/crates/core/src/dataframe.rs +++ b/crates/core/src/dataframe.rs @@ -37,9 +37,15 @@ use datafusion::config::{CsvOptions, ParquetColumnOptions, ParquetOptions, Table use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; use datafusion::error::DataFusionError; use datafusion::execution::SendableRecordBatchStream; +use datafusion::execution::context::TaskContext; use datafusion::logical_expr::SortExpr; use datafusion::logical_expr::dml::InsertOp; use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel}; +use datafusion::physical_plan::{ + ExecutionPlan as DFExecutionPlan, collect as df_collect, + collect_partitioned as df_collect_partitioned, execute_stream as df_execute_stream, + execute_stream_partitioned as df_execute_stream_partitioned, +}; use datafusion::prelude::*; use datafusion_python_util::{is_ipython_env, spawn_future, wait_for_future}; use futures::{StreamExt, TryStreamExt}; @@ -308,6 +314,9 @@ pub struct PyDataFrame { // In IPython environment cache batches between __repr__ and _repr_html_ calls. batches: SharedCachedBatches, + + // Cache the last physical plan so that metrics are available after execution. + last_plan: Arc>>>, } impl PyDataFrame { @@ -316,6 +325,7 @@ impl PyDataFrame { Self { df: Arc::new(df), batches: Arc::new(Mutex::new(None)), + last_plan: Arc::new(Mutex::new(None)), } } @@ -387,6 +397,20 @@ impl PyDataFrame { Ok(html_str) } + /// Create the physical plan, cache it in `last_plan`, and return the plan together + /// with a task context. Centralises the repeated three-line pattern that appears in + /// `collect`, `collect_partitioned`, `execute_stream`, and `execute_stream_partitioned`. + fn create_and_cache_plan( + &self, + py: Python, + ) -> PyDataFusionResult<(Arc, Arc)> { + let df = self.df.as_ref().clone(); + let new_plan = wait_for_future(py, df.create_physical_plan())??; + *self.last_plan.lock() = Some(Arc::clone(&new_plan)); + let task_ctx = Arc::new(self.df.as_ref().task_ctx()); + Ok((new_plan, task_ctx)) + } + async fn collect_column_inner(&self, column: &str) -> Result { let batches = self .df @@ -646,8 +670,9 @@ impl PyDataFrame { /// Unless some order is specified in the plan, there is no /// guarantee of the order of the result. fn collect<'py>(&self, py: Python<'py>) -> PyResult>> { - let batches = wait_for_future(py, self.df.as_ref().clone().collect())? - .map_err(PyDataFusionError::from)?; + let (plan, task_ctx) = self.create_and_cache_plan(py)?; + let batches = + wait_for_future(py, df_collect(plan, task_ctx))?.map_err(PyDataFusionError::from)?; // cannot use PyResult> return type due to // https://github.com/PyO3/pyo3/issues/1813 batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect() @@ -662,7 +687,8 @@ impl PyDataFrame { /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch /// maintaining the input partitioning. fn collect_partitioned<'py>(&self, py: Python<'py>) -> PyResult>>> { - let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())? + let (plan, task_ctx) = self.create_and_cache_plan(py)?; + let batches = wait_for_future(py, df_collect_partitioned(plan, task_ctx))? .map_err(PyDataFusionError::from)?; batches @@ -840,7 +866,13 @@ impl PyDataFrame { } /// Get the execution plan for this `DataFrame` + /// + /// If the DataFrame has already been executed (e.g. via `collect()`), + /// returns the cached plan which includes populated metrics. 
fn execution_plan(&self, py: Python) -> PyDataFusionResult { + if let Some(plan) = self.last_plan.lock().as_ref() { + return Ok(PyExecutionPlan::new(Arc::clone(plan))); + } let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())??; Ok(plan.into()) } @@ -1198,14 +1230,17 @@ impl PyDataFrame { } fn execute_stream(&self, py: Python) -> PyDataFusionResult { - let df = self.df.as_ref().clone(); - let stream = spawn_future(py, async move { df.execute_stream().await })?; + let (plan, task_ctx) = self.create_and_cache_plan(py)?; + let stream = spawn_future(py, async move { df_execute_stream(plan, task_ctx) })?; Ok(PyRecordBatchStream::new(stream)) } fn execute_stream_partitioned(&self, py: Python) -> PyResult> { - let df = self.df.as_ref().clone(); - let streams = spawn_future(py, async move { df.execute_stream_partitioned().await })?; + let (plan, task_ctx) = self.create_and_cache_plan(py)?; + let streams = spawn_future( + py, + async move { df_execute_stream_partitioned(plan, task_ctx) }, + )?; Ok(streams.into_iter().map(PyRecordBatchStream::new).collect()) } diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index fc2d006d3..77d69911a 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -43,6 +43,7 @@ pub mod errors; pub mod expr; #[allow(clippy::borrow_deref_ref)] mod functions; +pub mod metrics; mod options; pub mod physical_plan; mod pyarrow_filter_expression; @@ -92,6 +93,8 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/crates/core/src/metrics.rs b/crates/core/src/metrics.rs new file mode 100644 index 000000000..ee0937e25 --- /dev/null +++ b/crates/core/src/metrics.rs @@ -0,0 +1,169 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::collections::HashMap; +use std::sync::Arc; + +use chrono::{Datelike, Timelike}; +use datafusion::physical_plan::metrics::{Metric, MetricValue, MetricsSet, Timestamp}; +use pyo3::prelude::*; + +#[pyclass(from_py_object, frozen, name = "MetricsSet", module = "datafusion")] +#[derive(Debug, Clone)] +pub struct PyMetricsSet { + metrics: MetricsSet, +} + +impl PyMetricsSet { + pub fn new(metrics: MetricsSet) -> Self { + Self { metrics } + } +} + +#[pymethods] +impl PyMetricsSet { + fn metrics(&self) -> Vec { + self.metrics + .iter() + .map(|m| PyMetric::new(Arc::clone(m))) + .collect() + } + + fn output_rows(&self) -> Option { + self.metrics.output_rows() + } + + fn elapsed_compute(&self) -> Option { + self.metrics.elapsed_compute() + } + + fn spill_count(&self) -> Option { + self.metrics.spill_count() + } + + fn spilled_bytes(&self) -> Option { + self.metrics.spilled_bytes() + } + + fn spilled_rows(&self) -> Option { + self.metrics.spilled_rows() + } + + fn sum_by_name(&self, name: &str) -> Option { + self.metrics.sum_by_name(name).map(|v| v.as_usize()) + } + + fn __repr__(&self) -> String { + format!("{}", self.metrics) + } +} + +#[pyclass(from_py_object, frozen, name = "Metric", module = "datafusion")] +#[derive(Debug, Clone)] +pub struct PyMetric { + metric: Arc, +} + +impl PyMetric { + pub fn new(metric: Arc) -> Self { + Self { metric } + } + + fn timestamp_to_pyobject<'py>( + py: Python<'py>, + ts: &Timestamp, + ) -> PyResult>> { + match ts.value() { + Some(dt) => { + let datetime_mod = py.import("datetime")?; + let datetime_cls = datetime_mod.getattr("datetime")?; + let tz_utc = datetime_mod.getattr("timezone")?.getattr("utc")?; + let result = datetime_cls.call1(( + dt.year(), + dt.month(), + dt.day(), + dt.hour(), + dt.minute(), + dt.second(), + dt.timestamp_subsec_micros(), + tz_utc, + ))?; + Ok(Some(result)) + } + None => Ok(None), + } + } +} + +#[pymethods] +impl PyMetric { + #[getter] + fn name(&self) -> String { + self.metric.value().name().to_string() + } + + #[getter] + fn value<'py>(&self, py: Python<'py>) -> PyResult>> { + match self.metric.value() { + MetricValue::OutputRows(c) => Ok(Some(c.value().into_pyobject(py)?.into_any())), + MetricValue::OutputBytes(c) => Ok(Some(c.value().into_pyobject(py)?.into_any())), + MetricValue::ElapsedCompute(t) => Ok(Some(t.value().into_pyobject(py)?.into_any())), + MetricValue::SpillCount(c) => Ok(Some(c.value().into_pyobject(py)?.into_any())), + MetricValue::SpilledBytes(c) => Ok(Some(c.value().into_pyobject(py)?.into_any())), + MetricValue::SpilledRows(c) => Ok(Some(c.value().into_pyobject(py)?.into_any())), + MetricValue::CurrentMemoryUsage(g) => Ok(Some(g.value().into_pyobject(py)?.into_any())), + MetricValue::Count { count, .. } => { + Ok(Some(count.value().into_pyobject(py)?.into_any())) + } + MetricValue::Gauge { gauge, .. } => { + Ok(Some(gauge.value().into_pyobject(py)?.into_any())) + } + MetricValue::Time { time, .. 
} => Ok(Some(time.value().into_pyobject(py)?.into_any())), + MetricValue::StartTimestamp(ts) | MetricValue::EndTimestamp(ts) => { + Self::timestamp_to_pyobject(py, ts) + } + _ => Ok(None), + } + } + + #[getter] + fn value_as_datetime<'py>(&self, py: Python<'py>) -> PyResult>> { + match self.metric.value() { + MetricValue::StartTimestamp(ts) | MetricValue::EndTimestamp(ts) => { + Self::timestamp_to_pyobject(py, ts) + } + _ => Ok(None), + } + } + + #[getter] + fn partition(&self) -> Option { + self.metric.partition() + } + + fn labels(&self) -> HashMap { + self.metric + .labels() + .iter() + .map(|l| (l.name().to_string(), l.value().to_string())) + .collect() + } + + fn __repr__(&self) -> String { + format!("{}", self.metric.value()) + } +} diff --git a/crates/core/src/physical_plan.rs b/crates/core/src/physical_plan.rs index 8674a8b55..fac973884 100644 --- a/crates/core/src/physical_plan.rs +++ b/crates/core/src/physical_plan.rs @@ -26,6 +26,7 @@ use pyo3::types::PyBytes; use crate::context::PySessionContext; use crate::errors::PyDataFusionResult; +use crate::metrics::PyMetricsSet; #[pyclass( from_py_object, @@ -96,6 +97,10 @@ impl PyExecutionPlan { Ok(Self::new(plan)) } + pub fn metrics(&self) -> Option { + self.plan.metrics().map(PyMetricsSet::new) + } + fn __repr__(&self) -> String { self.display_indent() } diff --git a/docs/source/user-guide/dataframe/execution-metrics.rst b/docs/source/user-guide/dataframe/execution-metrics.rst new file mode 100644 index 000000000..764fa76ef --- /dev/null +++ b/docs/source/user-guide/dataframe/execution-metrics.rst @@ -0,0 +1,215 @@ +.. Licensed to the Apache Software Foundation (ASF) under one +.. or more contributor license agreements. See the NOTICE file +.. distributed with this work for additional information +.. regarding copyright ownership. The ASF licenses this file +.. to you under the Apache License, Version 2.0 (the +.. "License"); you may not use this file except in compliance +.. with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, +.. software distributed under the License is distributed on an +.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +.. KIND, either express or implied. See the License for the +.. specific language governing permissions and limitations +.. under the License. + +.. _execution_metrics: + +Execution Metrics +================= + +Overview +-------- + +When DataFusion executes a query it compiles the logical plan into a tree of +*physical plan operators* (e.g. ``FilterExec``, ``ProjectionExec``, +``HashAggregateExec``). Each operator can record runtime statistics while it +runs. These statistics are called **execution metrics**. + +Typical metrics include: + +- **output_rows** – number of rows produced by the operator +- **elapsed_compute** – total CPU time (nanoseconds) spent inside the operator +- **spill_count** – number of times the operator spilled data to disk +- **spilled_bytes** – total bytes written to disk during spills +- **spilled_rows** – total rows written to disk during spills + +Metrics are collected *per-partition*: DataFusion may execute each operator +in parallel across several partitions. The convenience properties on +:py:class:`~datafusion.MetricsSet` (e.g. ``output_rows``, ``elapsed_compute``) +automatically sum the named metric across **all** partitions, giving a single +aggregate value for the operator as a whole. 
You can also access the raw +per-partition :py:class:`~datafusion.Metric` objects via +:py:meth:`~datafusion.MetricsSet.metrics`. + +When Are Metrics Available? +--------------------------- + +Some operators (for example ``DataSourceExec``) eagerly create a +:py:class:`~datafusion.MetricsSet` when the physical plan is built, so +:py:meth:`~datafusion.ExecutionPlan.metrics` may return a set even before any +rows have been processed. However, metric **values** such as ``output_rows`` +are only meaningful **after** the DataFrame has been executed via one of the +terminal operations: + +- :py:meth:`~datafusion.DataFrame.collect` +- :py:meth:`~datafusion.DataFrame.collect_partitioned` +- :py:meth:`~datafusion.DataFrame.execute_stream` + (metrics are available once the stream has been fully consumed) +- :py:meth:`~datafusion.DataFrame.execute_stream_partitioned` + (metrics are available once all partition streams have been fully consumed) + +Before execution, metric values will be ``0`` or ``None``. + +.. note:: + + **display() does not populate metrics.** + When a DataFrame is displayed in a notebook (e.g. via ``display(df)`` or + automatic ``repr`` output), DataFusion runs a *limited* internal execution + to fetch preview rows. This internal execution does **not** cache the + physical plan used, so :py:meth:`~datafusion.ExecutionPlan.collect_metrics` + will not reflect the display execution. To access metrics you must call + one of the terminal operations listed above. + +If you call :py:meth:`~datafusion.DataFrame.collect` (or another terminal +operation) multiple times on the same DataFrame, each call creates a fresh +physical plan. Metrics from :py:meth:`~datafusion.DataFrame.execution_plan` +always reflect the **most recent** execution. + +Reading the Physical Plan Tree +-------------------------------- + +:py:meth:`~datafusion.DataFrame.execution_plan` returns the root +:py:class:`~datafusion.ExecutionPlan` node of the physical plan tree. The tree +mirrors the operator pipeline: the root is typically a projection or +coalescing node; its children are filters, aggregates, scans, etc. + +The ``operator_name`` string returned by +:py:meth:`~datafusion.ExecutionPlan.collect_metrics` is the *display* name of +the node, for example ``"FilterExec: column1@0 > 1"``. This is the same string +you would see when calling ``plan.display()``. + +Aggregated vs Per-Partition Metrics +------------------------------------ + +DataFusion executes each operator across one or more **partitions** in +parallel. The :py:class:`~datafusion.MetricsSet` convenience properties +(``output_rows``, ``elapsed_compute``, etc.) automatically **sum** the named +metric across all partitions, giving a single aggregate value. + +To inspect individual partitions — for example to detect data skew where one +partition processes far more rows than others — iterate over the raw +:py:class:`~datafusion.Metric` objects: + +.. code-block:: python + + for metric in metrics_set.metrics(): + print(f" partition={metric.partition} {metric.name}={metric.value}") + +The ``partition`` property is a 0-based index (``0``, ``1``, …) identifying +which parallel slot processed this metric. It is ``None`` for metrics that +apply globally (not tied to a specific partition). + +Available Metrics +----------------- + +The following metrics are directly accessible as properties on +:py:class:`~datafusion.MetricsSet`: + +.. 
list-table:: + :header-rows: 1 + :widths: 25 75 + + * - Property + - Description + * - ``output_rows`` + - Number of rows emitted by the operator (summed across partitions). + * - ``elapsed_compute`` + - Wall-clock CPU time **in nanoseconds** spent inside the operator's + compute loop, excluding I/O wait. Useful for identifying which + operators are most expensive (summed across partitions). + * - ``spill_count`` + - Number of spill-to-disk events triggered by memory pressure. This is + a unitless count of events, not a measure of data volume (summed across + partitions). + * - ``spilled_bytes`` + - Total bytes written to disk during spill events (summed across + partitions). + * - ``spilled_rows`` + - Total rows written to disk during spill events (summed across + partitions). + +Any metric not listed above can be accessed via +:py:meth:`~datafusion.MetricsSet.sum_by_name`, or by iterating over the raw +:py:class:`~datafusion.Metric` objects returned by +:py:meth:`~datafusion.MetricsSet.metrics`. + +Labels +------ + +A :py:class:`~datafusion.Metric` may carry *labels*: key/value pairs that +provide additional context. Labels are operator-specific; most metrics have +an empty label dict. + +Some operators tag their metrics with labels to distinguish variants. For +example, a ``HashAggregateExec`` may record separate ``output_rows`` metrics +for intermediate and final output: + +.. code-block:: python + + for metric in metrics_set.metrics(): + print(metric.name, metric.labels()) + # output_rows {'output_type': 'final'} + # output_rows {'output_type': 'intermediate'} + +When summing by name (via :py:attr:`~datafusion.MetricsSet.output_rows` or +:py:meth:`~datafusion.MetricsSet.sum_by_name`), **all** metrics with that +name are summed regardless of labels. To filter by label, iterate over the +raw :py:class:`~datafusion.Metric` objects directly. + +End-to-End Example +------------------ + +.. 
code-block:: python + + from datafusion import SessionContext + + ctx = SessionContext() + ctx.sql("CREATE TABLE sales AS VALUES (1, 100), (2, 200), (3, 50)") + + df = ctx.sql("SELECT * FROM sales WHERE column1 > 1") + + # Execute the query — this populates the metrics + results = df.collect() + + # Retrieve the physical plan with metrics + plan = df.execution_plan() + + # Walk every operator and print its metrics + for operator_name, ms in plan.collect_metrics(): + if ms.output_rows is not None: + print(f"{operator_name}") + print(f" output_rows = {ms.output_rows}") + print(f" elapsed_compute = {ms.elapsed_compute} ns") + + # Access raw per-partition metrics + for operator_name, ms in plan.collect_metrics(): + for metric in ms.metrics(): + print( + f" partition={metric.partition} " + f"{metric.name}={metric.value} " + f"labels={metric.labels()}" + ) + +API Reference +------------- + +- :py:class:`datafusion.ExecutionPlan` — physical plan node +- :py:meth:`datafusion.ExecutionPlan.collect_metrics` — walk the tree and + return ``(operator_name, MetricsSet)`` pairs +- :py:meth:`datafusion.ExecutionPlan.metrics` — return the + :py:class:`~datafusion.MetricsSet` for a single node +- :py:class:`datafusion.MetricsSet` — aggregated metrics for one operator +- :py:class:`datafusion.Metric` — a single per-partition metric value diff --git a/docs/source/user-guide/dataframe/index.rst b/docs/source/user-guide/dataframe/index.rst index 510bcbc68..8475a7bd7 100644 --- a/docs/source/user-guide/dataframe/index.rst +++ b/docs/source/user-guide/dataframe/index.rst @@ -365,7 +365,16 @@ DataFusion provides many built-in functions for data manipulation: For a complete list of available functions, see the :py:mod:`datafusion.functions` module documentation. +Execution Metrics +----------------- + +After executing a DataFrame (via ``collect()``, ``execute_stream()``, etc.), +DataFusion populates per-operator runtime statistics such as row counts and +compute time. See :doc:`execution-metrics` for a full explanation and +worked example. + .. toctree:: :maxdepth: 1 rendering + execution-metrics diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index ee02c921d..80dfa2fab 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -56,7 +56,7 @@ from .expr import Expr, WindowFrame from .io import read_avro, read_csv, read_json, read_parquet from .options import CsvReadOptions -from .plan import ExecutionPlan, LogicalPlan +from .plan import ExecutionPlan, LogicalPlan, Metric, MetricsSet from .record_batch import RecordBatch, RecordBatchStream from .user_defined import ( Accumulator, @@ -86,6 +86,8 @@ "Expr", "InsertOp", "LogicalPlan", + "Metric", + "MetricsSet", "ParquetColumnOptions", "ParquetWriterOptions", "RecordBatch", diff --git a/python/datafusion/plan.py b/python/datafusion/plan.py index 9c96a18fc..c0cfd523f 100644 --- a/python/datafusion/plan.py +++ b/python/datafusion/plan.py @@ -24,11 +24,15 @@ import datafusion._internal as df_internal if TYPE_CHECKING: + import datetime + from datafusion.context import SessionContext __all__ = [ "ExecutionPlan", "LogicalPlan", + "Metric", + "MetricsSet", ] @@ -151,3 +155,176 @@ def to_proto(self) -> bytes: Tables created in memory from record batches are currently not supported. """ return self._raw_plan.to_proto() + + def metrics(self) -> MetricsSet | None: + """Return metrics for this plan node, or None if this plan has no MetricsSet. + + Some operators (e.g. 
DataSourceExec) eagerly initialize a MetricsSet + when the plan is created, so this may return a set even before + execution. Metric *values* (such as ``output_rows``) are only + meaningful after the DataFrame has been executed. + """ + raw = self._raw_plan.metrics() + if raw is None: + return None + return MetricsSet(raw) + + def collect_metrics(self) -> list[tuple[str, MetricsSet]]: + """Return runtime statistics for each step of the query execution. + + DataFusion executes a query as a pipeline of operators — for example a + data source scan, followed by a filter, followed by a projection. After + the DataFrame has been executed (via + :py:meth:`~datafusion.DataFrame.collect`, + :py:meth:`~datafusion.DataFrame.execute_stream`, etc.), each operator + records statistics such as how many rows it produced and how much CPU + time it consumed. + + Each entry in the returned list corresponds to one operator that + recorded metrics. The first element of the tuple is the operator's + description string — the same text shown by + :py:meth:`display_indent` — which identifies both the operator type + and its key parameters, for example ``"FilterExec: column1@0 > 1"`` + or ``"DataSourceExec: partitions=1"``. + + Returns: + A list of ``(description, MetricsSet)`` tuples ordered from the + outermost operator (top of the execution tree) down to the + data-source leaves. Only operators that recorded at least one + metric are included. Returns an empty list if called before the + DataFrame has been executed. + """ + result: list[tuple[str, MetricsSet]] = [] + + def _walk(node: ExecutionPlan) -> None: + ms = node.metrics() + if ms is not None: + result.append((node.display(), ms)) + for child in node.children(): + _walk(child) + + _walk(self) + return result + + +class MetricsSet: + """A set of metrics for a single execution plan operator. + + A physical plan operator runs independently across one or more partitions. + :py:meth:`metrics` returns the raw per-partition :py:class:`Metric` objects. + The convenience properties (:py:attr:`output_rows`, :py:attr:`elapsed_compute`, + etc.) automatically sum the named metric across *all* partitions, giving a + single aggregate value for the operator as a whole. + """ + + def __init__(self, raw: df_internal.MetricsSet) -> None: + """This constructor should not be called by the end user.""" + self._raw = raw + + def metrics(self) -> list[Metric]: + """Return all individual metrics in this set.""" + return [Metric(m) for m in self._raw.metrics()] + + @property + def output_rows(self) -> int | None: + """Sum of output_rows across all partitions.""" + return self._raw.output_rows() + + @property + def elapsed_compute(self) -> int | None: + """Total CPU time (in nanoseconds) spent inside this operator's execute loop. + + Summed across all partitions. Returns ``None`` if no ``elapsed_compute`` + metric was recorded. + """ + return self._raw.elapsed_compute() + + @property + def spill_count(self) -> int | None: + """Number of times this operator spilled data to disk due to memory pressure. + + This is a count of spill events, not a byte count. Summed across all + partitions. Returns ``None`` if no ``spill_count`` metric was recorded. 
+ """ + return self._raw.spill_count() + + @property + def spilled_bytes(self) -> int | None: + """Sum of spilled_bytes across all partitions.""" + return self._raw.spilled_bytes() + + @property + def spilled_rows(self) -> int | None: + """Sum of spilled_rows across all partitions.""" + return self._raw.spilled_rows() + + def sum_by_name(self, name: str) -> int | None: + """Sum the named metric across all partitions. + + Useful for accessing any metric not exposed as a first-class property. + Returns ``None`` if no metric with the given name was recorded. + + Args: + name: The metric name, e.g. ``"output_rows"`` or ``"elapsed_compute"``. + """ + return self._raw.sum_by_name(name) + + def __repr__(self) -> str: + """Return a string representation of the metrics set.""" + return repr(self._raw) + + +class Metric: + """A single execution metric with name, value, partition, and labels.""" + + def __init__(self, raw: df_internal.Metric) -> None: + """This constructor should not be called by the end user.""" + self._raw = raw + + @property + def name(self) -> str: + """The name of this metric (e.g. ``output_rows``).""" + return self._raw.name + + @property + def value(self) -> int | datetime.datetime | None: + """The value of this metric. + + Returns an ``int`` for counters, gauges, and time-based metrics + (nanoseconds), a :py:class:`~datetime.datetime` (UTC) for + ``start_timestamp`` / ``end_timestamp`` metrics, or ``None`` + when the value has not been set or is not representable. + """ + return self._raw.value + + @property + def value_as_datetime(self) -> datetime.datetime | None: + """The value as a UTC :py:class:`~datetime.datetime` for timestamp metrics. + + Returns ``None`` for all non-timestamp metrics and for timestamp + metrics whose value has not been set (e.g. before execution). + """ + return self._raw.value_as_datetime + + @property + def partition(self) -> int | None: + """The 0-based partition index this metric applies to. + + Returns ``None`` for metrics that are not partition-specific (i.e. they + apply globally across all partitions of the operator). + """ + return self._raw.partition + + def labels(self) -> dict[str, str]: + """Return the labels associated with this metric. + + Labels provide additional context for a metric. For example:: + + metric.labels() + # {'output_type': 'final'} + """ + return self._raw.labels() + + def __repr__(self) -> str: + """Return a string representation of the metric.""" + return repr(self._raw) diff --git a/python/tests/test_plans.py b/python/tests/test_plans.py index 396acbe97..3705fc7ef 100644 --- a/python/tests/test_plans.py +++ b/python/tests/test_plans.py @@ -15,8 +15,16 @@ # specific language governing permissions and limitations # under the License. 
+import datetime + import pytest -from datafusion import ExecutionPlan, LogicalPlan, SessionContext +from datafusion import ( + ExecutionPlan, + LogicalPlan, + Metric, + MetricsSet, + SessionContext, +) # Note: We must use CSV because memory tables are currently not supported for @@ -40,3 +48,185 @@ def test_logical_plan_to_proto(ctx, df) -> None: execution_plan = ExecutionPlan.from_proto(ctx, execution_plan_bytes) assert str(original_execution_plan) == str(execution_plan) + + +def test_metrics_tree_walk() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + df.collect() + plan = df.execution_plan() + + results = plan.collect_metrics() + assert len(results) >= 1 + output_rows_by_op: dict[str, int] = {} + for name, ms in results: + assert isinstance(name, str) + assert isinstance(ms, MetricsSet) + if ms.output_rows is not None: + output_rows_by_op[name] = ms.output_rows + + # The filter passes rows where column1 > 1, so exactly + # 2 rows from (1,'a'),(2,'b'),(3,'c'). + # At least one operator must report exactly 2 output rows (the filter). + assert 2 in output_rows_by_op.values(), ( + f"Expected an operator with output_rows=2, got {output_rows_by_op}" + ) + + +def test_metric_properties() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + df.collect() + plan = df.execution_plan() + + found_any_metric = False + for _, ms in plan.collect_metrics(): + r = repr(ms) + assert isinstance(r, str) + for metric in ms.metrics(): + found_any_metric = True + assert isinstance(metric, Metric) + assert isinstance(metric.name, str) + assert len(metric.name) > 0 + assert metric.partition is None or isinstance(metric.partition, int) + assert metric.value is None or isinstance( + metric.value, int | datetime.datetime + ) + assert isinstance(metric.labels(), dict) + mr = repr(metric) + assert isinstance(mr, str) + assert len(mr) > 0 + assert found_any_metric, "Expected at least one metric after execution" + + +def test_no_meaningful_metrics_before_execution() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + plan_before = df.execution_plan() + + # Some plan nodes (e.g. DataSourceExec) eagerly initialize a MetricsSet, + # so metrics() may return a set even before execution. However, no rows + # should have been processed yet — output_rows must be absent or zero. + for _, ms in plan_before.collect_metrics(): + rows = ms.output_rows + assert rows is None or rows == 0, ( + f"Expected 0 output_rows before execution, got {rows}" + ) + + # After execution, at least one operator must report rows processed. 
+ df.collect() + plan_after = df.execution_plan() + output_rows_after = [ + ms.output_rows + for _, ms in plan_after.collect_metrics() + if ms.output_rows is not None and ms.output_rows > 0 + ] + assert len(output_rows_after) > 0, "Expected output_rows > 0 after execution" + + +def test_collect_partitioned_metrics() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + + df.collect_partitioned() + plan = df.execution_plan() + + output_rows_values = [ + ms.output_rows for _, ms in plan.collect_metrics() if ms.output_rows is not None + ] + assert 2 in output_rows_values, f"Expected 2 in {output_rows_values}" + + +def test_execute_stream_metrics() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + + for _ in df.execute_stream(): + pass + + plan = df.execution_plan() + output_rows_values = [ + ms.output_rows for _, ms in plan.collect_metrics() if ms.output_rows is not None + ] + assert 2 in output_rows_values, f"Expected 2 in {output_rows_values}" + + +def test_execute_stream_partitioned_metrics() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + + for stream in df.execute_stream_partitioned(): + for _ in stream: + pass + + plan = df.execution_plan() + output_rows_values = [ + ms.output_rows for _, ms in plan.collect_metrics() if ms.output_rows is not None + ] + assert 2 in output_rows_values, f"Expected 2 in {output_rows_values}" + + +def test_value_as_datetime() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + df.collect() + plan = df.execution_plan() + + for _, ms in plan.collect_metrics(): + for metric in ms.metrics(): + if metric.name in ("start_timestamp", "end_timestamp"): + dt = metric.value_as_datetime + assert dt is None or isinstance(dt, datetime.datetime) + if dt is not None: + assert dt.tzinfo is not None + else: + assert metric.value_as_datetime is None + + +def test_metric_names_and_labels() -> None: + """Verify that known metric names appear and labels are well-formed.""" + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + df.collect() + plan = df.execution_plan() + + all_metric_names: set[str] = set() + for _, ms in plan.collect_metrics(): + for metric in ms.metrics(): + all_metric_names.add(metric.name) + # Labels must be a dict of str->str + labels = metric.labels() + for k, v in labels.items(): + assert isinstance(k, str) + assert isinstance(v, str) + + # After a filter query, we expect at minimum these standard metric names. 
+ assert "output_rows" in all_metric_names, ( + f"Expected 'output_rows' in {all_metric_names}" + ) + assert "elapsed_compute" in all_metric_names, ( + f"Expected 'elapsed_compute' in {all_metric_names}" + ) + + +def test_collect_twice_has_metrics() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + + df.collect() + df.collect() + + plan = df.execution_plan() + output_rows_values = [ + ms.output_rows for _, ms in plan.collect_metrics() if ms.output_rows is not None + ] + assert len(output_rows_values) > 0 From 398980d1edbb8ad6d9744236f2dfe0c6ab4b4665 Mon Sep 17 00:00:00 2001 From: Zeel Desai <72783325+zeel2104@users.noreply.github.com> Date: Mon, 13 Apr 2026 09:24:56 -0400 Subject: [PATCH 21/29] Support None comparisons for null expressions (#1489) * Support None comparisons for null expressions * Fold None comparison coverage into relational expr test --- python/datafusion/expr.py | 4 ++++ python/tests/test_expr.py | 8 ++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 7cd74ecd5..32004656f 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -483,6 +483,8 @@ def __eq__(self, rhs: object) -> Expr: Accepts either an expression or any valid PyArrow scalar literal value. """ + if rhs is None: + return self.is_null() if not isinstance(rhs, Expr): rhs = Expr.literal(rhs) return Expr(self.expr.__eq__(rhs.expr)) @@ -492,6 +494,8 @@ def __ne__(self, rhs: object) -> Expr: Accepts either an expression or any valid PyArrow scalar literal value. """ + if rhs is None: + return self.is_not_null() if not isinstance(rhs, Expr): rhs = Expr.literal(rhs) return Expr(self.expr.__ne__(rhs.expr)) diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 1cf824a15..d046eb48c 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -153,8 +153,8 @@ def test_relational_expr(test_ctx): batch = pa.RecordBatch.from_arrays( [ - pa.array([1, 2, 3]), - pa.array(["alpha", "beta", "gamma"], type=pa.string_view()), + pa.array([1, 2, 3, None]), + pa.array(["alpha", "beta", "gamma", None], type=pa.string_view()), ], names=["a", "b"], ) @@ -171,6 +171,10 @@ def test_relational_expr(test_ctx): assert df.filter(col("b") != "beta").count() == 2 assert df.filter(col("a") == "beta").count() == 0 + assert df.filter(col("a") == None).count() == 1 # noqa: E711 + assert df.filter(col("a") != None).count() == 3 # noqa: E711 + assert df.filter(col("b") == None).count() == 1 # noqa: E711 + assert df.filter(col("b") != None).count() == 3 # noqa: E711 def test_expr_to_variant(): From 2715a32e939d17222c18e8adacf85ee45da464b9 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 14 Apr 2026 03:27:00 -0400 Subject: [PATCH 22/29] chore: update release documentation (#1494) * Update release documentation * Minor change to workflow because release start at 1 --- .../workflows/verify-release-candidate.yml | 2 +- dev/release/README.md | 60 ++++++++++--------- 2 files changed, 32 insertions(+), 30 deletions(-) diff --git a/.github/workflows/verify-release-candidate.yml b/.github/workflows/verify-release-candidate.yml index 7a5deff5b..6ecb547b5 100644 --- a/.github/workflows/verify-release-candidate.yml +++ b/.github/workflows/verify-release-candidate.yml @@ -27,7 +27,7 @@ on: required: true type: string rc_number: - description: Release candidate number (e.g., 0) + description: Release candidate number (e.g., 1) 
required: true type: string diff --git a/dev/release/README.md b/dev/release/README.md index ed28f4aa6..4833be55a 100644 --- a/dev/release/README.md +++ b/dev/release/README.md @@ -26,11 +26,11 @@ required due to changes in DataFusion rather than having a large amount of work is available. When there is a new official release of DataFusion, we update the `main` branch to point to that, update the version -number, and create a new release branch, such as `branch-0.8`. Once this branch is created, we switch the `main` branch +number, and create a new release branch, such as `branch-53`. Once this branch is created, we switch the `main` branch back to using GitHub dependencies. The release activity (such as generating the changelog) can then happen on the release branch without blocking ongoing development in the `main` branch. -We can cherry-pick commits from the `main` branch into `branch-0.8` as needed and then create new patch releases +We can cherry-pick commits from the `main` branch into `branch-53` as needed and then create new patch releases from that branch. ## Detailed Guide @@ -54,7 +54,8 @@ Before creating a new release: - We need to ensure that the main branch does not have any GitHub dependencies - a PR should be created and merged to update the major version number of the project -- A new release branch should be created, such as `branch-0.8` +- A new release branch should be created, such as `branch-53` +- It is best to push this branch to the apache repository rather than a personal fork in case patch releases are required. ## Preparing a Release Candidate @@ -65,14 +66,14 @@ We maintain a `CHANGELOG.md` so our users know what has been changed between rel The changelog is generated using a Python script: ```bash -$ GITHUB_TOKEN= ./dev/release/generate-changelog.py 24.0.0 HEAD 25.0.0 > dev/changelog/25.0.0.md +$ GITHUB_TOKEN= ./dev/release/generate-changelog.py 52.0.0 HEAD 53.0.0 > dev/changelog/53.0.0.md ``` This script creates a changelog from GitHub PRs based on the labels associated with them as well as looking for titles starting with `feat:`, `fix:`, or `docs:` . The script will produce output similar to: ``` -Fetching list of commits between 24.0.0 and HEAD +Fetching list of commits between 52.0.0 and HEAD Fetching pull requests Categorizing pull requests Generating changelog content @@ -81,6 +82,7 @@ Generating changelog content ### Update the version number The only place you should need to update the version is in the root `Cargo.toml`. +You will need to update this both in the workspace section and also in the dependencies. After updating the toml file, run `cargo update` to update the cargo lock file. If you do not want to update all the dependencies, you can instead run `cargo build` which should only update the version number for `datafusion-python`. @@ -94,14 +96,14 @@ you need to push a tag to start the CI process for release candidates. The follo the upstream repository is called `apache`. ```bash -git tag 0.8.0-rc1 -git push apache 0.8.0-rc1 +git tag 53.0.0-rc1 +git push apache 53.0.0-rc1 ``` ### Create a source release ```bash -./dev/release/create-tarball.sh 0.8.0 1 +./dev/release/create-tarball.sh 53.0.0 1 ``` This will also create the email template to send to the mailing list. 
@@ -124,10 +126,10 @@ Click on the action and scroll down to the bottom of the page titled "Artifacts" contain files such as: ```text -datafusion-22.0.0-cp37-abi3-macosx_10_7_x86_64.whl -datafusion-22.0.0-cp37-abi3-macosx_11_0_arm64.whl -datafusion-22.0.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -datafusion-22.0.0-cp37-abi3-win_amd64.whl +datafusion-53.0.0-cp37-abi3-macosx_10_7_x86_64.whl +datafusion-53.0.0-cp37-abi3-macosx_11_0_arm64.whl +datafusion-53.0.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl +datafusion-53.0.0-cp37-abi3-win_amd64.whl ``` Upload the wheels to testpypi. @@ -135,23 +137,23 @@ Upload the wheels to testpypi. ```bash unzip dist.zip python3 -m pip install --upgrade setuptools twine build -python3 -m twine upload --repository testpypi datafusion-22.0.0-cp37-abi3-*.whl +python3 -m twine upload --repository testpypi datafusion-53.0.0-cp37-abi3-*.whl ``` When prompted for username, enter `__token__`. When prompted for a password, enter a valid GitHub Personal Access Token #### Publish Python Source Distribution to testpypi -Download the source tarball created in the previous step, untar it, and run: +Download the source tarball from the Apache server created in the previous step, untar it, and run: ```bash maturin sdist ``` -This will create a file named `dist/datafusion-0.7.0.tar.gz`. Upload this to testpypi: +This will create a file named `dist/datafusion-53.0.0.tar.gz`. Upload this to testpypi: ```bash -python3 -m twine upload --repository testpypi dist/datafusion-0.7.0.tar.gz +python3 -m twine upload --repository testpypi dist/datafusion-53.0.0.tar.gz ``` ### Run Verify Release Candidate Workflow @@ -162,8 +164,8 @@ Before sending the vote email, run the manually triggered GitHub Actions workflo 1. Go to https://github.com/apache/datafusion-python/actions/workflows/verify-release-candidate.yml 2. Click "Run workflow" -3. Set `version` to the release version (for example, `52.0.0`) -4. Set `rc_number` to the RC number (for example, `0`) +3. Set `version` to the release version (for example, `53.0.0`) +4. Set `rc_number` to the RC number (for example, `1`) 5. Wait for all jobs to complete successfully Include a short note in the vote email template that this workflow was run across all OS/architecture @@ -183,7 +185,7 @@ Releases may be verified using `verify-release-candidate.sh`: ```bash git clone https://github.com/apache/datafusion-python.git -dev/release/verify-release-candidate.sh 48.0.0 1 +dev/release/verify-release-candidate.sh 53.0.0 1 ``` Alternatively, one can run unit tests against a testpypi release candidate: @@ -195,7 +197,7 @@ cd datafusion-python # checkout the release commit git fetch --tags -git checkout 40.0.0-rc1 +git checkout 53.0.0-rc1 git submodule update --init --recursive # create the env @@ -203,7 +205,7 @@ python3 -m venv .venv source .venv/bin/activate # install release candidate -pip install --extra-index-url https://test.pypi.org/simple/ datafusion==40.0.0 +pip install --extra-index-url https://test.pypi.org/simple/ datafusion==53.0.0 # install test dependencies pip install pytest numpy pytest-asyncio @@ -224,7 +226,7 @@ Once the vote passes, we can publish the release. 
Create the source release tarball: ```bash -./dev/release/release-tarball.sh 0.8.0 1 +./dev/release/release-tarball.sh 53.0.0 1 ``` ### Publishing Rust Crate to crates.io @@ -232,7 +234,7 @@ Create the source release tarball: Some projects depend on the Rust crate directly, so we publish this to crates.io ```shell -cargo publish +cargo publish --workspace ``` ### Publishing Python Artifacts to PyPi @@ -252,15 +254,15 @@ Pypi packages auto upload to conda-forge via [datafusion feedstock](https://gith ### Push the Release Tag ```bash -git checkout 0.8.0-rc1 -git tag 0.8.0 -git push apache 0.8.0 +git checkout 53.0.0-rc1 +git tag 53.0.0 +git push apache 53.0.0 ``` ### Add the release to Apache Reporter Add the release to https://reporter.apache.org/addrelease.html?datafusion with a version name prefixed with `DATAFUSION-PYTHON`, -for example `DATAFUSION-PYTHON-31.0.0`. +for example `DATAFUSION-PYTHON-53.0.0`. The release information is used to generate a template for a board report (see example from Apache Arrow [here](https://github.com/apache/arrow/pull/14357)). @@ -283,7 +285,7 @@ svn ls https://dist.apache.org/repos/dist/dev/datafusion | grep datafusion-pytho Delete a release candidate: ```bash -svn delete -m "delete old DataFusion RC" https://dist.apache.org/repos/dist/dev/datafusion/apache-datafusion-python-7.1.0-rc1/ +svn delete -m "delete old DataFusion RC" https://dist.apache.org/repos/dist/dev/datafusion/apache-datafusion-python-53.0.0-rc1/ ``` #### Deleting old releases from `release` svn @@ -299,5 +301,5 @@ svn ls https://dist.apache.org/repos/dist/release/datafusion | grep datafusion-p Delete a release: ```bash -svn delete -m "delete old DataFusion release" https://dist.apache.org/repos/dist/release/datafusion/datafusion-python-7.0.0 +svn delete -m "delete old DataFusion release" https://dist.apache.org/repos/dist/release/datafusion/datafusion-python-52.0.0 ``` From 60d8b5dbb5e409cd9ce7692972420e955b8a802e Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 14 Apr 2026 03:31:01 -0400 Subject: [PATCH 23/29] Fix error on show() with an explain plan (#1492) --- crates/core/src/dataframe.rs | 12 ++++++++++-- python/tests/test_dataframe.py | 10 ++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/crates/core/src/dataframe.rs b/crates/core/src/dataframe.rs index 2d815ec76..2e74991b8 100644 --- a/crates/core/src/dataframe.rs +++ b/crates/core/src/dataframe.rs @@ -38,8 +38,8 @@ use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; use datafusion::error::DataFusionError; use datafusion::execution::SendableRecordBatchStream; use datafusion::execution::context::TaskContext; -use datafusion::logical_expr::SortExpr; use datafusion::logical_expr::dml::InsertOp; +use datafusion::logical_expr::{LogicalPlan, SortExpr}; use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel}; use datafusion::physical_plan::{ ExecutionPlan as DFExecutionPlan, collect as df_collect, @@ -707,7 +707,15 @@ impl PyDataFrame { /// Print the result, 20 lines by default #[pyo3(signature = (num=20))] fn show(&self, py: Python, num: usize) -> PyDataFusionResult<()> { - let df = self.df.as_ref().clone().limit(0, Some(num))?; + let mut df = self.df.as_ref().clone(); + df = match self.df.logical_plan() { + LogicalPlan::Explain(_) | LogicalPlan::Analyze(_) => { + // Explain and Analyzer require they are at the top + // of the plan, so do not add a limit. 
+ df + } + _ => df.limit(0, Some(num))?, + }; print_dataframe(py, df) } diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index bb8e9685c..091fa9b56 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -412,6 +412,16 @@ def test_show_empty(df, capsys): assert "DataFrame has no rows" in captured.out +def test_show_on_explain(ctx, capsys): + ctx.sql("explain select 1").show() + captured = capsys.readouterr() + assert "1 as Int64(1)" in captured.out + + ctx.sql("explain analyze select 1").show() + captured = capsys.readouterr() + assert "1 as Int64(1)" in captured.out + + def test_sort(df): df = df.sort(column("b").sort(ascending=False)) From 40309978c920bd123a4c7b764a2ddfdb97758607 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 23 Apr 2026 18:28:55 -0400 Subject: [PATCH 24/29] Add SKILL.md and enrich package docstring (#1497) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add AGENTS.md and enrich __init__.py module docstring Add python/datafusion/AGENTS.md as a comprehensive DataFrame API guide for AI agents and users. It ships with pip automatically (Maturin includes everything under python-source = "python"). Covers core abstractions, import conventions, data loading, all DataFrame operations, expression building, a SQL-to-DataFrame reference table, common pitfalls, idiomatic patterns, and a categorized function index. Enrich the __init__.py module docstring from 2 lines to a full overview with core abstractions, a quick-start example, and a pointer to AGENTS.md. Closes #1394 (PR 1a) Co-Authored-By: Claude Opus 4.6 (1M context) * Clarify audience of root vs package AGENTS.md The root AGENTS.md (symlinked as CLAUDE.md) is for contributors working on the project. Add a pointer to python/datafusion/AGENTS.md which is the user-facing DataFrame API guide shipped with the package. Also add the Apache license header to the package AGENTS.md. Co-Authored-By: Claude Opus 4.6 (1M context) * Add PR template and pre-commit check guidance to AGENTS.md Document that all PRs must follow .github/pull_request_template.md and that pre-commit hooks must pass before committing. List all configured hooks (actionlint, ruff, ruff-format, cargo fmt, cargo clippy, codespell, uv-lock) and the command to run them manually. Co-Authored-By: Claude Opus 4.6 (1M context) * Remove duplicated hook list from AGENTS.md Let the hooks be discoverable from .pre-commit-config.yaml rather than maintaining a separate list that can drift. Co-Authored-By: Claude Opus 4.6 (1M context) * Fix AGENTS.md: Arrow C Data Interface, aggregate filter, fluent example - Clarify that DataFusion works with any Arrow C Data Interface implementation, not just PyArrow. - Show the filter keyword argument on aggregate functions (the idiomatic HAVING equivalent) instead of the post-aggregate .filter() pattern. - Update the SQL reference table to show FILTER (WHERE ...) syntax. - Remove the now-incorrect "Aggregate then filter for HAVING" pitfall. - Add .collect() to the fluent chaining example so the result is clearly materialized. Co-Authored-By: Claude Opus 4.6 (1M context) * Update agents file after working through the first tpc-h query using only the text description * Add feedback from working through each of the TPC-H queries * Address Copilot review feedback on AGENTS.md - Wrap CASE/WHEN method-chain examples in parentheses and assign to a variable so they are valid Python as shown (Copilot #1, #2). 
- Fix INTERSECT/EXCEPT mapping: the default distinct=False corresponds to INTERSECT ALL / EXCEPT ALL, not the distinct forms. Updated both the Set Operations section and the SQL reference table to show both the ALL and distinct variants (Copilot #4). - Change write_parquet / write_csv / write_json examples to file-style paths (output.parquet, etc.) to match the convention used in existing tests and examples. Note that a directory path is also valid for partitioned output (Copilot #5). Verified INTERSECT/EXCEPT semantics with a script: df1.intersect(df2) -> [1, 1, 2] (= INTERSECT ALL) df1.intersect(df2, distinct=True) -> [1, 2] (= INTERSECT) Co-Authored-By: Claude Opus 4.6 (1M context) * Use short-form comparisons in AGENTS.md examples Drop lit() on the RHS of comparison operators since Expr auto-wraps raw Python values, matching the style the guide recommends (Copilot #3, #6). Updates examples in the Aggregation, CASE/WHEN, SQL reference table, Common Pitfalls, Fluent Chaining, and Variables-as-CTEs sections, plus the __init__.py quick-start snippet. Prose explanations of the rule (which cite the long form as the thing to avoid) are left unchanged. Co-Authored-By: Claude Opus 4.6 (1M context) * Move user guide from python/datafusion/AGENTS.md to SKILL.md The in-wheel AGENTS.md was not a real distribution channel -- no shipping agent walks site-packages for AGENTS.md files. Moving to SKILL.md at the repo root, with YAML frontmatter, lets the skill ecosystems (npx skills, Claude Code plugin marketplaces, community aggregators) discover it. Update the pointers in the contributor AGENTS.md and the __init__.py module docstring accordingly. The docstring now references the GitHub URL since the file no longer ships with the wheel. Co-Authored-By: Claude Opus 4.7 (1M context) * Address review feedback: doctest, streaming, date/timestamp - Convert the __init__.py quick-start block to doctest format so it is picked up by `pytest --doctest-modules` (already the project default), preventing silent rot. - Extract streaming into its own SKILL.md subsection with guidance on when to prefer execute_stream() over collect(), sync and async iteration, and execute_stream_partitioned() for per-partition streams. - Generalize the date-arithmetic rule from Date32 to both Date32 and Date64 (both reject Duration at any precision, both accept month_day_nano_interval), and note that Timestamp columns differ and do accept Duration. - Document the PyArrow-inherited type mapping returned by to_pydict()/to_pylist(), including the nanosecond fallback to pandas.Timestamp / pandas.Timedelta and the to_pandas() footgun where date columns come back as an object dtype. Co-Authored-By: Claude Opus 4.7 (1M context) * Distinguish user guide from agent reference in module docstring The docstring pointed readers at SKILL.md as a "comprehensive guide," but SKILL.md is written in a dense, skill-oriented format for agents — humans are better served by the online user guide. Put the online docs first as the primary reference and label the SKILL.md link as the agent reference. Co-Authored-By: Claude Opus 4.7 (1M context) --------- Co-authored-by: Claude Opus 4.6 (1M context) --- AGENTS.md | 34 +- SKILL.md | 733 ++++++++++++++++++++++++++++++++++ python/datafusion/__init__.py | 42 +- 3 files changed, 804 insertions(+), 5 deletions(-) create mode 100644 SKILL.md diff --git a/AGENTS.md b/AGENTS.md index 86c2e9c3b..7d3262710 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -17,7 +17,14 @@ under the License. 
--> -# Agent Instructions +# Agent Instructions for Contributors + +This file is for agents working **on** the datafusion-python project (developing, +testing, reviewing). If you need to **use** the DataFusion DataFrame API (write +queries, build expressions, understand available functions), see the user-facing +skill at [`SKILL.md`](SKILL.md). + +## Skills This project uses AI agent skills stored in `.ai/skills/`. Each skill is a directory containing a `SKILL.md` file with instructions for performing a specific task. @@ -26,6 +33,31 @@ Skills follow the [Agent Skills](https://agentskills.io) open standard. Each ski - `SKILL.md` — The skill definition with YAML frontmatter (name, description, argument-hint) and detailed instructions. - Additional supporting files as needed. +## Pull Requests + +Every pull request must follow the template in +`.github/pull_request_template.md`. The description must include these sections: + +1. **Which issue does this PR close?** — Link the issue with `Closes #NNN`. +2. **Rationale for this change** — Why the change is needed (skip if the issue + already explains it clearly). +3. **What changes are included in this PR?** — Summarize the individual changes. +4. **Are there any user-facing changes?** — Note any changes visible to users + (new APIs, changed behavior, new files shipped in the package, etc.). If + there are breaking changes to public APIs, add the `api change` label. + +## Pre-commit Checks + +Always run pre-commit checks **before** committing. The hooks are defined in +`.pre-commit-config.yaml` and run automatically on `git commit` if pre-commit +is installed as a git hook. To run all hooks manually: + +```bash +pre-commit run --all-files +``` + +Fix any failures before committing. + ## Python Function Docstrings Every Python function must include a docstring with usage examples. diff --git a/SKILL.md b/SKILL.md new file mode 100644 index 000000000..9ba1c0cac --- /dev/null +++ b/SKILL.md @@ -0,0 +1,733 @@ + + +--- +name: datafusion-python +description: Use when the user is writing datafusion-python (Apache DataFusion Python bindings) DataFrame or SQL code. Covers imports, data loading, DataFrame operations, expression building, SQL-to-DataFrame mappings, idiomatic patterns, and common pitfalls. +--- + +# DataFusion Python DataFrame API Guide + +## What Is DataFusion? + +DataFusion is an **in-process query engine** built on Apache Arrow. It is not a +database -- there is no server, no connection string, and no external +dependencies. You create a `SessionContext`, point it at data (Parquet, CSV, +JSON, Arrow IPC, Pandas, Polars, or raw Python dicts/lists), and run queries +using either SQL or the DataFrame API described below. + +All data flows through **Apache Arrow**. The canonical Python implementation is +PyArrow (`pyarrow.RecordBatch` / `pyarrow.Table`), but any library that +conforms to the [Arrow C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html) +can interoperate with DataFusion. + +## Core Abstractions + +| Abstraction | Role | Key import | +|---|---|---| +| `SessionContext` | Entry point. Loads data, runs SQL, produces DataFrames. | `from datafusion import SessionContext` | +| `DataFrame` | Lazy query builder. Each method returns a new DataFrame. | Returned by context methods | +| `Expr` | Expression tree node (column ref, literal, function call, ...). | `from datafusion import col, lit` | +| `functions` | 290+ built-in scalar, aggregate, and window functions. 
| `from datafusion import functions as F` | + +## Import Conventions + +```python +from datafusion import SessionContext, col, lit +from datafusion import functions as F +``` + +## Data Loading + +```python +ctx = SessionContext() + +# From files +df = ctx.read_parquet("path/to/data.parquet") +df = ctx.read_csv("path/to/data.csv") +df = ctx.read_json("path/to/data.json") + +# From Python objects +df = ctx.from_pydict({"a": [1, 2, 3], "b": ["x", "y", "z"]}) +df = ctx.from_pylist([{"a": 1, "b": "x"}, {"a": 2, "b": "y"}]) +df = ctx.from_pandas(pandas_df) +df = ctx.from_polars(polars_df) +df = ctx.from_arrow(arrow_table) + +# From SQL +df = ctx.sql("SELECT a, b FROM my_table WHERE a > 1") +``` + +To make a DataFrame queryable by name in SQL, register it first: + +```python +ctx.register_parquet("my_table", "path/to/data.parquet") +ctx.register_csv("my_table", "path/to/data.csv") +``` + +## DataFrame Operations Quick Reference + +Every method returns a **new** DataFrame (immutable/lazy). Chain them fluently. + +### Projection + +```python +df.select("a", "b") # preferred: plain names as strings +df.select(col("a"), (col("b") + 1).alias("b_plus_1")) # use col()/Expr only when you need an expression + +df.with_column("new_col", col("a") + lit(10)) # add one column +df.with_columns( + col("a").alias("x"), + y=col("b") + lit(1), # named keyword form +) + +df.drop("unwanted_col") +df.with_column_renamed("old_name", "new_name") +``` + +When a column is referenced by name alone, pass the name as a string rather +than wrapping it in `col()`. Reach for `col()` only when the projection needs +arithmetic, aliasing, casting, or another expression operation. + +**Case sensitivity**: both `select("Name")` and `col("Name")` lowercase the +identifier. For a column whose real name has uppercase letters, embed double +quotes inside the string: `select('"MyCol"')` or `col('"MyCol"')`. Without the +inner quotes the lookup will fail with `No field named mycol`. + +### Filtering + +```python +df.filter(col("a") > 10) +df.filter(col("a") > 10, col("b") == "x") # multiple = AND +df.filter("a > 10") # SQL expression string +``` + +Raw Python values on the right-hand side of a comparison are auto-wrapped +into literals by the `Expr` operators, so prefer `col("a") > 10` over +`col("a") > lit(10)`. See the Comparisons section and pitfall #2 for the +full rule. + +### Aggregation + +```python +# GROUP BY a, compute sum(b) and count(*) +df.aggregate(["a"], [F.sum(col("b")), F.count(col("a"))]) + +# HAVING equivalent: use the filter keyword on the aggregate function +df.aggregate( + ["region"], + [F.sum(col("sales"), filter=col("sales") > 1000).alias("large_sales")], +) +``` + +As with `select()`, group keys can be passed as plain name strings. Reach for +`col(...)` only when the grouping expression needs arithmetic, aliasing, +casting, or another expression operation. + +Most aggregate functions accept an optional `filter` keyword argument. When +provided, only rows where the filter expression is true contribute to the +aggregate. + +### Sorting + +```python +df.sort(col("a")) # ascending (default) +df.sort(col("a").sort(ascending=False)) # descending +df.sort(col("a").sort(nulls_first=False)) # override null placement +``` + +A plain expression passed to `sort()` is already treated as ascending. Only +reach for `col(...).sort(...)` when you need to override a default (descending +order or null placement). Writing `col("a").sort(ascending=True)` is redundant. 
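+
+As a small sketch (the column names `region` and `total` are only
+illustrative), the two forms can be mixed to sort on several keys with
+different directions:
+
+```python
+# region ascending (the default), then total descending within each region
+df.sort(col("region"), col("total").sort(ascending=False))
+```
+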
+ +### Joining + +```python +# Equi-join on shared column name +df1.join(df2, on="key") +df1.join(df2, on="key", how="left") + +# Different column names +df1.join(df2, left_on="id", right_on="fk_id", how="inner") + +# Expression-based join (supports inequality predicates) +df1.join_on(df2, col("a") == col("b"), how="inner") + +# Semi join: keep rows from left where a match exists in right (like EXISTS) +df1.join(df2, on="key", how="semi") + +# Anti join: keep rows from left where NO match exists in right (like NOT EXISTS) +df1.join(df2, on="key", how="anti") +``` + +Join types: `"inner"`, `"left"`, `"right"`, `"full"`, `"semi"`, `"anti"`. + +Inner is the default `how`. Prefer `df1.join(df2, on="key")` over +`df1.join(df2, on="key", how="inner")` — drop `how=` unless you need a +non-inner join type. + +When the two sides' join columns have different native names, use +`left_on=`/`right_on=` with the original names rather than aliasing one side +to match the other — see pitfall #7. + +### Window Functions + +```python +from datafusion import WindowFrame + +# Row number partitioned by group, ordered by value +df.window( + F.row_number( + partition_by=[col("group")], + order_by=[col("value")], + ).alias("rn") +) + +# Using a Window object for reuse +from datafusion.expr import Window + +win = Window( + partition_by=[col("group")], + order_by=[col("value").sort(ascending=True)], +) +df.select( + col("group"), + col("value"), + F.sum(col("value")).over(win).alias("running_total"), +) + +# With explicit frame bounds +win = Window( + partition_by=[col("group")], + order_by=[col("value").sort(ascending=True)], + window_frame=WindowFrame("rows", 0, None), # current row to unbounded following +) +``` + +### Set Operations + +```python +df1.union(df2) # UNION ALL (by position) +df1.union(df2, distinct=True) # UNION DISTINCT +df1.union_by_name(df2) # match columns by name, not position +df1.intersect(df2) # INTERSECT ALL +df1.intersect(df2, distinct=True) # INTERSECT (distinct) +df1.except_all(df2) # EXCEPT ALL +df1.except_all(df2, distinct=True) # EXCEPT (distinct) +``` + +### Limit and Offset + +```python +df.limit(10) # first 10 rows +df.limit(10, offset=20) # skip 20, then take 10 +``` + +### Deduplication + +```python +df.distinct() # remove duplicate rows +df.distinct_on( # keep first row per group (like DISTINCT ON in Postgres) + [col("a")], # uniqueness columns + [col("a"), col("b")], # output columns + [col("b").sort(ascending=True)], # which row to keep +) +``` + +## Executing and Collecting Results + +DataFrames are lazy until you collect. 
+ +```python +df.show() # print formatted table to stdout +batches = df.collect() # list[pa.RecordBatch] +arr = df.collect_column("col_name") # pa.Array | pa.ChunkedArray (single column) +table = df.to_arrow_table() # pa.Table +pandas_df = df.to_pandas() # pd.DataFrame +polars_df = df.to_polars() # pl.DataFrame +py_dict = df.to_pydict() # dict[str, list] +py_list = df.to_pylist() # list[dict] +count = df.count() # int +``` + +### Date and Timestamp Type Conversion + +The Python type returned by `to_pydict()` / `to_pylist()` depends on the Arrow +column type, and the mapping is inherited from PyArrow: + +| Arrow type | Python type returned | +|---|---| +| `timestamp(s)` / `(ms)` / `(us)` | `datetime.datetime` | +| `timestamp(ns)` | `pandas.Timestamp` | +| `date32` / `date64` | `datetime.date` | +| `duration(s)` / `(ms)` / `(us)` | `datetime.timedelta` | +| `duration(ns)` | `pandas.Timedelta` | + +The nanosecond-precision fallback to pandas types is the main surprise: +pandas is not a hard dependency of `datafusion`, but PyArrow reaches for it +when `datetime.datetime` / `datetime.timedelta` would lose precision (stdlib +types only go to microseconds). If you need plain stdlib types, cast to a +coarser unit before collecting, e.g. +`df.select(col("ts").cast(pa.timestamp("us")))`. + +`df.to_pandas()` has its own footgun for dates: pandas has no pure-date dtype, +so a `date32`/`date64` column comes back as an `object` column of +`datetime.date` values rather than `datetime64[ns]`. If downstream code +expects a datetime column, cast on the DataFusion side first: +`col("ship_date").cast(pa.timestamp("ns"))`. + +### Streaming Results + +Prefer streaming over `collect()` when the result is too large to materialize +in memory, when you want to start processing before the query finishes, or +when you may break out of the loop early. `execute_stream()` pulls one +`RecordBatch` at a time from the execution plan rather than buffering the +whole result up front. + +```python +# Single-partition stream; batch is a datafusion.RecordBatch +stream = df.execute_stream() +for batch in stream: + process(batch.to_pyarrow()) # convert to pa.RecordBatch if needed + +# DataFrame is iterable directly (delegates to execute_stream) +for batch in df: + process(batch.to_pyarrow()) + +# One stream per partition, for parallel consumption +for stream in df.execute_stream_partitioned(): + for batch in stream: + process(batch.to_pyarrow()) +``` + +Async iteration is also supported via `async for batch in df: ...` (or +`df.execute_stream()`), which is useful when batches are interleaved with +other I/O. + +### Writing Results + +```python +df.write_parquet("output.parquet") +df.write_csv("output.csv") +df.write_json("output.json") +``` + +You can also pass a directory path (e.g., `"output/"`) to write a multi-file +partitioned output. + +## Expression Building + +### Column References and Literals + +```python +col("column_name") # reference a column +lit(42) # integer literal +lit("hello") # string literal +lit(3.14) # float literal +lit(pa.scalar(value)) # PyArrow scalar (preserves Arrow type) +``` + +`lit()` accepts PyArrow scalars directly -- prefer this over converting Arrow +data to Python and back when working with values extracted from query results. 
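+
+For example (a minimal sketch -- the Int32 and date values are only
+illustrative of type-preserving literals):
+
+```python
+import datetime
+import pyarrow as pa
+
+lit(42)                                    # type inferred from the Python value
+lit(pa.scalar(42, type=pa.int32()))        # keeps the exact Arrow type (Int32)
+lit(pa.scalar(datetime.date(2024, 1, 1)))  # Date32 literal built from a PyArrow scalar
+```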
+ +### Arithmetic + +```python +col("price") * col("quantity") # multiplication +col("a") + lit(1) # addition +col("a") - col("b") # subtraction +col("a") / lit(2) # division +col("a") % lit(3) # modulo +``` + +### Date Arithmetic + +`Date32` and `Date64` columns both require `Interval` types for arithmetic, +not `Duration`. Use PyArrow's `month_day_nano_interval` type, which takes a +`(months, days, nanos)` tuple: + +```python +import pyarrow as pa + +# Subtract 90 days from a date column +col("ship_date") - lit(pa.scalar((0, 90, 0), type=pa.month_day_nano_interval())) + +# Subtract 3 months +col("ship_date") - lit(pa.scalar((3, 0, 0), type=pa.month_day_nano_interval())) +``` + +**Important**: `lit(datetime.timedelta(days=90))` creates a `Duration(µs)` +literal, which is **not** compatible with `Date32`/`Date64` arithmetic +(`Duration(ms)` and `Duration(ns)` are rejected too). Always use +`pa.month_day_nano_interval()` for date operations. + +**Timestamps behave differently**: `Timestamp` columns *do* accept `Duration`, +so `col("ts") - lit(datetime.timedelta(days=1))` works. The interval-only +rule applies specifically to date columns. + +### Comparisons + +```python +col("a") > 10 +col("a") >= 10 +col("a") < 10 +col("a") <= 10 +col("a") == "x" +col("a") != "x" +col("a") == None # same as col("a").is_null() +col("a") != None # same as col("a").is_not_null() +``` + +Comparison operators auto-wrap the right-hand Python value into a literal, +so writing `col("a") > lit(10)` is redundant. Drop the `lit()` in +comparisons. Reach for `lit()` only when auto-wrapping does not apply — see +pitfall #2. + +### Boolean Logic + +**Important**: Python's `and`, `or`, `not` keywords do NOT work with Expr +objects. You must use the bitwise operators: + +```python +(col("a") > 1) & (col("b") < 10) # AND +(col("a") > 1) | (col("b") < 10) # OR +~(col("a") > 1) # NOT +``` + +Always wrap each comparison in parentheses when combining with `&`, `|`, `~` +because Python's operator precedence for bitwise operators is different from +logical operators. 
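+
+A minimal sketch of why the parentheses matter -- `&` binds tighter than the
+comparison operators in Python, so the unparenthesized form is grouped wrong:
+
+```python
+# WRONG: Python evaluates `1 & col("b")` first, not the two comparisons
+# df.filter(col("a") > 1 & col("b") < 10)
+
+# CORRECT: parenthesize each comparison before combining
+df.filter((col("a") > 1) & (col("b") < 10))
+```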
+ +### Null Handling + +```python +col("a").is_null() +col("a").is_not_null() +col("a").fill_null(lit(0)) # replace NULL with a value +F.coalesce(col("a"), col("b")) # first non-null value +F.nullif(col("a"), lit(0)) # return NULL if a == 0 +``` + +### CASE / WHEN + +```python +# Simple CASE (matching on a single expression) +status_label = ( + F.case(col("status")) + .when(lit("A"), lit("Active")) + .when(lit("I"), lit("Inactive")) + .otherwise(lit("Unknown")) +) + +# Searched CASE (each branch has its own predicate) +severity = ( + F.when(col("value") > 100, lit("high")) + .when(col("value") > 50, lit("medium")) + .otherwise(lit("low")) +) +``` + +### Casting + +```python +import pyarrow as pa + +col("a").cast(pa.float64()) +col("a").cast(pa.utf8()) +col("a").cast(pa.date32()) +``` + +### Aliasing + +```python +(col("a") + col("b")).alias("total") +``` + +### BETWEEN and IN + +```python +col("a").between(lit(1), lit(10)) # 1 <= a <= 10 +F.in_list(col("a"), [lit(1), lit(2), lit(3)]) # a IN (1, 2, 3) +F.in_list(col("a"), [lit(1), lit(2)], negated=True) # a NOT IN (1, 2) +``` + +### Struct and Array Access + +```python +col("struct_col")["field_name"] # access struct field +col("array_col")[0] # access array element (0-indexed) +col("array_col")[1:3] # array slice (0-indexed) +``` + +## SQL-to-DataFrame Reference + +| SQL | DataFrame API | +|---|---| +| `SELECT a, b` | `df.select("a", "b")` | +| `SELECT a, b + 1 AS c` | `df.select(col("a"), (col("b") + lit(1)).alias("c"))` | +| `SELECT *, a + 1 AS c` | `df.with_column("c", col("a") + lit(1))` | +| `WHERE a > 10` | `df.filter(col("a") > 10)` | +| `GROUP BY a` with `SUM(b)` | `df.aggregate(["a"], [F.sum(col("b"))])` | +| `SUM(b) FILTER (WHERE b > 100)` | `F.sum(col("b"), filter=col("b") > 100)` | +| `ORDER BY a DESC` | `df.sort(col("a").sort(ascending=False))` | +| `LIMIT 10 OFFSET 5` | `df.limit(10, offset=5)` | +| `DISTINCT` | `df.distinct()` | +| `a INNER JOIN b ON a.id = b.id` | `a.join(b, on="id")` | +| `a LEFT JOIN b ON a.id = b.fk` | `a.join(b, left_on="id", right_on="fk", how="left")` | +| `WHERE EXISTS (SELECT ...)` | `a.join(b, on="key", how="semi")` | +| `WHERE NOT EXISTS (SELECT ...)` | `a.join(b, on="key", how="anti")` | +| `UNION ALL` | `df1.union(df2)` | +| `UNION` (distinct) | `df1.union(df2, distinct=True)` | +| `INTERSECT ALL` | `df1.intersect(df2)` | +| `INTERSECT` (distinct) | `df1.intersect(df2, distinct=True)` | +| `EXCEPT ALL` | `df1.except_all(df2)` | +| `EXCEPT` (distinct) | `df1.except_all(df2, distinct=True)` | +| `CASE x WHEN 1 THEN 'a' END` | `F.case(col("x")).when(lit(1), lit("a")).end()` | +| `CASE WHEN x > 1 THEN 'a' END` | `F.when(col("x") > 1, lit("a")).end()` | +| `x IN (1, 2, 3)` | `F.in_list(col("x"), [lit(1), lit(2), lit(3)])` | +| `x BETWEEN 1 AND 10` | `col("x").between(lit(1), lit(10))` | +| `CAST(x AS DOUBLE)` | `col("x").cast(pa.float64())` | +| `ROW_NUMBER() OVER (...)` | `F.row_number(partition_by=[...], order_by=[...])` | +| `SUM(x) OVER (...)` | `F.sum(col("x")).over(window)` | +| `x IS NULL` | `col("x").is_null()` | +| `COALESCE(a, b)` | `F.coalesce(col("a"), col("b"))` | + +## Common Pitfalls + +1. **Boolean operators**: Use `&`, `|`, `~` -- not Python's `and`, `or`, `not`. + Always parenthesize: `(col("a") > 1) & (col("b") < 2)`. + +2. **Wrapping scalars with `lit()`**: Prefer raw Python values on the + right-hand side of comparisons — `col("a") > 10`, `col("name") == "Alice"` + — because the Expr comparison operators auto-wrap them. Writing + `col("a") > lit(10)` is redundant. 
Reserve `lit()` for places where + auto-wrapping does *not* apply: + - standalone scalars passed into function calls: + `F.coalesce(col("a"), lit(0))`, not `F.coalesce(col("a"), 0)` + - arithmetic between two literals with no column involved: + `lit(1) - col("discount")` is fine, but `lit(1) - lit(2)` needs both + - values that must carry a specific Arrow type, via `lit(pa.scalar(...))` + - `.when(...)`, `.otherwise(...)`, `F.nullif(...)`, `.between(...)`, + `F.in_list(...)` and similar method/function arguments + +3. **Column name quoting**: Column names are normalized to lowercase by default + in both `select("...")` and `col("...")`. To reference a column with + uppercase letters, use double quotes inside the string: + `select('"MyColumn"')` or `col('"MyColumn"')`. + +4. **DataFrames are immutable**: Every method returns a **new** DataFrame. You + must capture the return value: + ```python + df = df.filter(col("a") > 1) # correct + df.filter(col("a") > 1) # WRONG -- result is discarded + ``` + +5. **Window frame defaults**: When using `order_by` in a window, the default + frame is `RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`. For a full + partition frame, set `window_frame=WindowFrame("rows", None, None)`. + +6. **Arithmetic on aggregates belongs in a later `select`, not inside + `aggregate`**: Each item in the aggregate list must be a single aggregate + call (optionally aliased). Combining aggregates with arithmetic inside + `aggregate(...)` fails with `Internal error: Invalid aggregate expression`. + Alias the aggregates, then compute the combination downstream: + ```python + # WRONG -- arithmetic wraps two aggregates + df.aggregate([], [(lit(100) * F.sum(col("a")) / F.sum(col("b"))).alias("ratio")]) + + # CORRECT -- aggregate first, then combine + (df.aggregate([], [F.sum(col("a")).alias("num"), F.sum(col("b")).alias("den")]) + .select((lit(100) * col("num") / col("den")).alias("ratio"))) + ``` + +7. **Don't alias a join column to match the other side**: When equi-joining + with `on="key"`, renaming the join column on one side via `.alias("key")` + in a fresh projection creates a schema where one side's `key` is + qualified (`?table?.key`) and the other is unqualified. The join then + fails with `Schema contains qualified field name ... and unqualified + field name ... which would be ambiguous`. Use `left_on=`/`right_on=` with + the native names, or use `join_on(...)` with an explicit equality. + ```python + # WRONG -- alias on one side produces ambiguous schema after join + failed = orders.select(col("o_orderkey").alias("l_orderkey")) + li.join(failed, on="l_orderkey") # ambiguous l_orderkey error + + # CORRECT -- keep native names, use left_on/right_on + failed = orders.select("o_orderkey") + li.join(failed, left_on="l_orderkey", right_on="o_orderkey") + + # ALSO CORRECT -- explicit predicate via join_on + # (note: join_on keeps both key columns in the output, unlike on="key") + li.join_on(failed, col("l_orderkey") == col("o_orderkey")) + ``` + +## Idiomatic Patterns + +### Fluent Chaining + +```python +result = ( + ctx.read_parquet("data.parquet") + .filter(col("year") >= 2020) + .select(col("region"), col("sales")) + .aggregate(["region"], [F.sum(col("sales")).alias("total")]) + .sort(col("total").sort(ascending=False)) + .limit(10) +) +result.show() +``` + +### Using Variables as CTEs + +Instead of SQL CTEs (`WITH ... 
AS`), assign intermediate DataFrames to +variables: + +```python +base = ctx.read_parquet("orders.parquet").filter(col("status") == "shipped") +by_region = base.aggregate(["region"], [F.sum(col("amount")).alias("total")]) +top_regions = by_region.filter(col("total") > 10000) +``` + +### Reusing Expressions as Variables + +Just like DataFrames, expressions (`Expr`) can be stored in variables and used +anywhere an `Expr` is expected. This is useful for building up complex +expressions or reusing a computed value across multiple operations: + +```python +# Build an expression and reuse it +disc_price = col("price") * (lit(1) - col("discount")) +df = df.select( + col("id"), + disc_price.alias("disc_price"), + (disc_price * (lit(1) + col("tax"))).alias("total"), +) + +# Use a collected scalar as an expression +max_val = result_df.collect_column("max_price")[0] # PyArrow scalar +cutoff = lit(max_val) - lit(pa.scalar((0, 90, 0), type=pa.month_day_nano_interval())) +df = df.filter(col("ship_date") <= cutoff) # cutoff is already an Expr +``` + +**Important**: Do not wrap an `Expr` in `lit()`. `lit()` is for converting +Python/PyArrow values into expressions. If a value is already an `Expr`, use it +directly. + +### Window Functions for Scalar Subqueries + +Where SQL uses a correlated scalar subquery, the idiomatic DataFrame approach +is a window function: + +```sql +-- SQL scalar subquery +SELECT *, (SELECT SUM(b) FROM t WHERE t.group = s.group) AS group_total FROM s +``` + +```python +# DataFrame: window function +win = Window(partition_by=[col("group")]) +df = df.with_column("group_total", F.sum(col("b")).over(win)) +``` + +### Semi/Anti Joins for EXISTS / NOT EXISTS + +```sql +-- SQL: WHERE EXISTS (SELECT 1 FROM other WHERE other.key = main.key) +-- DataFrame: +result = main.join(other, on="key", how="semi") + +-- SQL: WHERE NOT EXISTS (SELECT 1 FROM other WHERE other.key = main.key) +-- DataFrame: +result = main.join(other, on="key", how="anti") +``` + +### Computed Columns + +```python +# Add computed columns while keeping all originals +df = df.with_column("full_name", F.concat(col("first"), lit(" "), col("last"))) +df = df.with_column("discounted", col("price") * lit(0.9)) +``` + +## Available Functions (Categorized) + +The `functions` module (imported as `F`) provides 290+ functions. Key categories: + +**Aggregate**: `sum`, `avg`, `min`, `max`, `count`, `count_star`, `median`, +`stddev`, `stddev_pop`, `var_samp`, `var_pop`, `corr`, `covar`, `approx_distinct`, +`approx_median`, `approx_percentile_cont`, `array_agg`, `string_agg`, +`first_value`, `last_value`, `bit_and`, `bit_or`, `bit_xor`, `bool_and`, +`bool_or`, `grouping`, `regr_*` (9 regression functions) + +**Window**: `row_number`, `rank`, `dense_rank`, `percent_rank`, `cume_dist`, +`ntile`, `lag`, `lead`, `first_value`, `last_value`, `nth_value` + +**String**: `length`, `lower`, `upper`, `trim`, `ltrim`, `rtrim`, `lpad`, +`rpad`, `starts_with`, `ends_with`, `contains`, `substr`, `substring`, +`replace`, `reverse`, `repeat`, `split_part`, `concat`, `concat_ws`, +`initcap`, `ascii`, `chr`, `left`, `right`, `strpos`, `translate`, `overlay`, +`levenshtein` + +`F.substr(str, start)` takes **only two arguments** and returns the tail of +the string from `start` onward — passing a third length argument raises +`TypeError: substr() takes 2 positional arguments but 3 were given`. For the +SQL-style 3-arg form (`SUBSTRING(str FROM start FOR length)`), use +`F.substring(col("s"), lit(start), lit(length))`. 
For a fixed-length prefix, +`F.left(col("s"), lit(n))` is cleanest. + +```python +# WRONG — substr does not accept a length argument +F.substr(col("c_phone"), lit(1), lit(2)) +# CORRECT +F.substring(col("c_phone"), lit(1), lit(2)) # explicit length +F.left(col("c_phone"), lit(2)) # prefix shortcut +``` + +**Math**: `abs`, `ceil`, `floor`, `round`, `trunc`, `sqrt`, `cbrt`, `exp`, +`ln`, `log`, `log2`, `log10`, `pow`, `signum`, `pi`, `random`, `factorial`, +`gcd`, `lcm`, `greatest`, `least`, sin/cos/tan and inverse/hyperbolic variants + +**Date/Time**: `now`, `today`, `current_date`, `current_time`, +`current_timestamp`, `date_part`, `date_trunc`, `date_bin`, `extract`, +`to_timestamp`, `to_timestamp_millis`, `to_timestamp_micros`, +`to_timestamp_nanos`, `to_timestamp_seconds`, `to_unixtime`, `from_unixtime`, +`make_date`, `make_time`, `to_date`, `to_time`, `to_local_time`, `date_format` + +**Conditional**: `case`, `when`, `coalesce`, `nullif`, `ifnull`, `nvl`, `nvl2` + +**Array/List**: `array`, `make_array`, `array_agg`, `array_length`, +`array_element`, `array_slice`, `array_append`, `array_prepend`, +`array_concat`, `array_has`, `array_has_all`, `array_has_any`, `array_position`, +`array_remove`, `array_distinct`, `array_sort`, `array_reverse`, `flatten`, +`array_to_string`, `array_intersect`, `array_union`, `array_except`, +`generate_series` +(Most `array_*` functions also have `list_*` aliases.) + +**Struct/Map**: `struct`, `named_struct`, `get_field`, `make_map`, `map_keys`, +`map_values`, `map_entries`, `map_extract` + +**Regex**: `regexp_like`, `regexp_match`, `regexp_replace`, `regexp_count`, +`regexp_instr` + +**Hash**: `md5`, `sha224`, `sha256`, `sha384`, `sha512`, `digest` + +**Type**: `arrow_typeof`, `arrow_cast`, `arrow_metadata` + +**Other**: `in_list`, `order_by`, `alias`, `col`, `encode`, `decode`, +`to_hex`, `to_char`, `uuid`, `version`, `bit_length`, `octet_length` diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 80dfa2fab..e4972411a 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -15,10 +15,44 @@ # specific language governing permissions and limitations # under the License. -"""DataFusion python package. - -This is a Python library that binds to Apache Arrow in-memory query engine DataFusion. -See https://datafusion.apache.org/python for more information. +"""DataFusion: an in-process query engine built on Apache Arrow. + +DataFusion is not a database -- it has no server and no external dependencies. +You create a :py:class:`SessionContext`, point it at data sources (Parquet, CSV, +JSON, Arrow IPC, Pandas, Polars, or raw Python dicts/lists), and run queries +using either SQL or the DataFrame API. + +Core abstractions +----------------- +- **SessionContext** -- entry point for loading data, running SQL, and creating + DataFrames. +- **DataFrame** -- lazy query builder. Every method returns a new DataFrame; + call :py:meth:`~datafusion.dataframe.DataFrame.collect` or a ``to_*`` + method to execute. +- **Expr** -- expression tree node for column references, literals, and function + calls. Build with :py:func:`col` and :py:func:`lit`. +- **functions** -- 290+ built-in scalar, aggregate, and window functions. + +Quick start +----------- + +>>> from datafusion import SessionContext, col +>>> from datafusion import functions as F +>>> ctx = SessionContext() +>>> df = ctx.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}) +>>> result = ( +... df.filter(col("a") > 1) +... .with_column("total", col("a") + col("b")) +... 
.aggregate([], [F.sum(col("total")).alias("grand_total")]) +... ) +>>> result.to_pydict() +{'grand_total': [16]} + +User guide and full documentation: https://datafusion.apache.org/python + +AI agent reference (SQL-to-DataFrame mappings, expression-building patterns, +common pitfalls), written in a dense, skill-oriented format: +https://github.com/apache/datafusion-python/blob/main/SKILL.md """ from __future__ import annotations From 8a5d783c7e418bfbbd95e48a2d9cacafea6162c7 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 23 Apr 2026 19:05:06 -0400 Subject: [PATCH 25/29] Skills require the header to be the first thing in the file which conflicts with the RAT check. Make an exception for this file. (#1501) --- SKILL.md | 19 ------------------- dev/release/rat_exclude_files.txt | 3 ++- 2 files changed, 2 insertions(+), 20 deletions(-) diff --git a/SKILL.md b/SKILL.md index 9ba1c0cac..14ea5c609 100644 --- a/SKILL.md +++ b/SKILL.md @@ -1,22 +1,3 @@ - - --- name: datafusion-python description: Use when the user is writing datafusion-python (Apache DataFusion Python bindings) DataFrame or SQL code. Covers imports, data loading, DataFrame operations, expression building, SQL-to-DataFrame mappings, idiomatic patterns, and common pitfalls. diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index b2db144e8..a7a497dab 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -48,4 +48,5 @@ benchmarks/tpch/create_tables.sql .cargo/config.toml **/.cargo/config.toml uv.lock -examples/tpch/answers_sf1/*.tbl \ No newline at end of file +examples/tpch/answers_sf1/*.tbl +SKILL.md \ No newline at end of file From 8741d30cd812e4668f3f9187b56f12ce2de0d6e7 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 23 Apr 2026 22:01:01 -0400 Subject: [PATCH 26/29] docs: enrich module docstrings and add doctest examples (#1498) * Enrich module docstrings and add doctest examples Expands the module docstrings for `functions.py`, `dataframe.py`, `expr.py`, and `context.py` so each module opens with a concept summary, cross-references to related APIs, and a small executable example. Adds doctest examples to the high-traffic `DataFrame` methods that previously lacked them: `select`, `aggregate`, `sort`, `limit`, `join`, and `union`. Optional parameters are demonstrated with keyword syntax, and examples reuse the same input data across variants so the effect of each option is easy to see. Co-Authored-By: Claude Opus 4.7 (1M context) * Use distinct group sums in aggregate docstring example Change the score data from [1, 2, 3] to [1, 2, 5] so the grouped result produces [3, 5] instead of [3, 3], removing ambiguity about which total belongs to which team. Co-Authored-By: Claude Opus 4.7 (1M context) * Align module-docstring examples with SKILL.md idioms Drop the redundant lit() in the dataframe.py module-docstring filter example and use a plain string group key in the aggregate() doctest, so both examples model the style SKILL.md recommends. Also document the sort("a") string form and sort_by() shortcut in SKILL.md's sorting section. 
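For reference, a minimal sketch of the idioms these examples now model, using toy data rather than the actual doctest inputs:

```python
from datafusion import SessionContext, col
from datafusion import functions as F

ctx = SessionContext()
df = ctx.from_pydict({"team": ["x", "x", "y"], "score": [1, 2, 5]})

# Comparison right-hand sides auto-wrap, so no lit() is needed here.
df = df.filter(col("score") > 1)

# Bare group keys and sort keys can be plain name strings.
df = df.aggregate(["team"], [F.sum(col("score")).alias("total")])
df = df.sort("team")  # or df.sort_by("team") for the ascending-only shortcut
```
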
Co-Authored-By: Claude Opus 4.7 (1M context) --------- Co-authored-by: Claude Opus 4.7 (1M context) --- SKILL.md | 16 +++- python/datafusion/context.py | 27 ++++++- python/datafusion/dataframe.py | 135 ++++++++++++++++++++++++++++++--- python/datafusion/expr.py | 28 ++++++- python/datafusion/functions.py | 22 +++++- 5 files changed, 209 insertions(+), 19 deletions(-) diff --git a/SKILL.md b/SKILL.md index 14ea5c609..7b07b430f 100644 --- a/SKILL.md +++ b/SKILL.md @@ -128,14 +128,22 @@ aggregate. ### Sorting ```python -df.sort(col("a")) # ascending (default) +df.sort("a") # ascending (plain name, preferred) +df.sort(col("a")) # ascending via col() df.sort(col("a").sort(ascending=False)) # descending df.sort(col("a").sort(nulls_first=False)) # override null placement + +df.sort_by("a", "b") # ascending-only shortcut ``` -A plain expression passed to `sort()` is already treated as ascending. Only -reach for `col(...).sort(...)` when you need to override a default (descending -order or null placement). Writing `col("a").sort(ascending=True)` is redundant. +As with `select()` and `aggregate()`, bare column references can be passed as +plain name strings. A plain expression passed to `sort()` is already treated +as ascending, so reach for `col(...).sort(...)` only when you need to override +a default (descending order or null placement). Writing +`col("a").sort(ascending=True)` is redundant. + +For ascending-only sorts with no null-placement override, `df.sort_by(...)` is +a shorter alias for `df.sort(...)`. ### Joining diff --git a/python/datafusion/context.py b/python/datafusion/context.py index c3f94cc16..dd6790402 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -15,7 +15,32 @@ # specific language governing permissions and limitations # under the License. -"""Session Context and it's associated configuration.""" +""":py:class:`SessionContext` — entry point for running DataFusion queries. + +A :py:class:`SessionContext` holds registered tables, catalogs, and +configuration for the current session. It is the first object most programs +create: from it you register data, run SQL strings +(:py:meth:`SessionContext.sql`), read files +(:py:meth:`SessionContext.read_csv`, +:py:meth:`SessionContext.read_parquet`, ...), and construct +:py:class:`~datafusion.dataframe.DataFrame` objects in memory +(:py:meth:`SessionContext.from_pydict`, +:py:meth:`SessionContext.from_arrow`). + +Session behavior (memory limits, batch size, configured optimizer passes, +...) is controlled by :py:class:`SessionConfig` and +:py:class:`RuntimeEnvBuilder`; SQL dialect limits are controlled by +:py:class:`SQLOptions`. + +Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3]}) + >>> ctx.sql("SELECT 1 AS n").to_pydict() + {'n': [1]} + +See :ref:`user_guide_concepts` in the online documentation for the broader +execution model. +""" from __future__ import annotations diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index c00c85fdb..2b07861da 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -14,9 +14,32 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""":py:class:`DataFrame` is one of the core concepts in DataFusion. - -See :ref:`user_guide_concepts` in the online documentation for more information. +""":py:class:`DataFrame` — lazy, chainable query representation. 
+ +A :py:class:`DataFrame` is a logical plan over one or more data sources. +Methods that reshape the plan (:py:meth:`DataFrame.select`, +:py:meth:`DataFrame.filter`, :py:meth:`DataFrame.aggregate`, +:py:meth:`DataFrame.sort`, :py:meth:`DataFrame.join`, +:py:meth:`DataFrame.limit`, the set-operation methods, ...) return a new +:py:class:`DataFrame` and do no work until a terminal method such as +:py:meth:`DataFrame.collect`, :py:meth:`DataFrame.to_pydict`, +:py:meth:`DataFrame.show`, or one of the ``write_*`` methods is called. + +DataFrames are produced from a +:py:class:`~datafusion.context.SessionContext`, typically via +:py:meth:`~datafusion.context.SessionContext.sql`, +:py:meth:`~datafusion.context.SessionContext.read_csv`, +:py:meth:`~datafusion.context.SessionContext.read_parquet`, or +:py:meth:`~datafusion.context.SessionContext.from_pydict`. + +Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3], "b": [10, 20, 30]}) + >>> df.filter(col("a") > 1).select("b").to_pydict() + {'b': [20, 30]} + +See :ref:`user_guide_concepts` in the online documentation for a high-level +overview of the execution model. """ from __future__ import annotations @@ -503,21 +526,29 @@ def select_exprs(self, *args: str) -> DataFrame: def select(self, *exprs: Expr | str) -> DataFrame: """Project arbitrary expressions into a new :py:class:`DataFrame`. + String arguments are treated as column names; :py:class:`~datafusion.expr.Expr` + arguments can reshape, rename, or compute new columns. + Args: exprs: Either column names or :py:class:`~datafusion.expr.Expr` to select. Returns: DataFrame after projection. It has one column for each expression. - Example usage: + Examples: + Select columns by name: - The following example will return 3 columns from the original dataframe. - The first two columns will be the original column ``a`` and ``b`` since the - string "a" is assumed to refer to column selection. Also a duplicate of - column ``a`` will be returned with the column name ``alternate_a``:: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3], "b": [10, 20, 30]}) + >>> df.select("a").to_pydict() + {'a': [1, 2, 3]} - df = df.select("a", col("b"), col("a").alias("alternate_a")) + Mix column names, expressions, and aliases. The string ``"a"`` selects + column ``a`` directly; ``col("a").alias("alternate_a")`` returns a + duplicate under a new name: + >>> df.select("a", col("b"), col("a").alias("alternate_a")).to_pydict() + {'a': [1, 2, 3], 'b': [10, 20, 30], 'alternate_a': [1, 2, 3]} """ exprs_internal = expr_list_to_raw_expr_list(exprs) return DataFrame(self.df.select(*exprs_internal)) @@ -766,6 +797,24 @@ def aggregate( Returns: DataFrame after aggregation. + + Examples: + Aggregate without grouping — an empty ``group_by`` produces a + single row: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict( + ... {"team": ["x", "x", "y"], "score": [1, 2, 5]} + ... ) + >>> df.aggregate([], [F.sum(col("score")).alias("total")]).to_pydict() + {'total': [8]} + + Group by a column and produce one row per group: + + >>> df.aggregate( + ... ["team"], [F.sum(col("score")).alias("total")] + ... ).sort("team").to_pydict() + {'team': ['x', 'y'], 'total': [3, 5]} """ group_by_list = ( list(group_by) @@ -786,13 +835,27 @@ def sort(self, *exprs: SortKey) -> DataFrame: """Sort the DataFrame by the specified sorting expressions or column names. Note that any expression can be turned into a sort expression by - calling its ``sort`` method. + calling its ``sort`` method. 
For ascending-only sorts, the shorter + :py:meth:`sort_by` is usually more convenient. Args: exprs: Sort expressions or column names, applied in order. Returns: DataFrame after sorting. + + Examples: + Sort ascending by a column name: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [3, 1, 2], "b": [10, 20, 30]}) + >>> df.sort("a").to_pydict() + {'a': [1, 2, 3], 'b': [20, 30, 10]} + + Sort descending using :py:meth:`Expr.sort`: + + >>> df.sort(col("a").sort(ascending=False)).to_pydict() + {'a': [3, 2, 1], 'b': [10, 30, 20]} """ exprs_raw = sort_list_to_raw_sort_list(exprs) return DataFrame(self.df.sort(*exprs_raw)) @@ -812,12 +875,28 @@ def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame: def limit(self, count: int, offset: int = 0) -> DataFrame: """Return a new :py:class:`DataFrame` with a limited number of rows. + Results are returned in unspecified order unless the DataFrame is + explicitly sorted first via :py:meth:`sort` or :py:meth:`sort_by`. + Args: count: Number of rows to limit the DataFrame to. offset: Number of rows to skip. Returns: DataFrame after limiting. + + Examples: + Take the first two rows: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3, 4]}).sort("a") + >>> df.limit(2).to_pydict() + {'a': [1, 2]} + + Skip the first row then take two (paging): + + >>> df.limit(2, offset=1).to_pydict() + {'a': [2, 3]} """ return DataFrame(self.df.limit(count, offset)) @@ -972,6 +1051,28 @@ def join( Returns: DataFrame after join. + + Examples: + Inner-join two DataFrames on a shared column: + + >>> ctx = dfn.SessionContext() + >>> left = ctx.from_pydict({"id": [1, 2, 3], "val": [10, 20, 30]}) + >>> right = ctx.from_pydict({"id": [2, 3, 4], "label": ["b", "c", "d"]}) + >>> left.join(right, on="id").sort("id").to_pydict() + {'id': [2, 3], 'val': [20, 30], 'label': ['b', 'c']} + + Left join to keep all rows from the left side: + + >>> left.join(right, on="id", how="left").sort("id").to_pydict() + {'id': [1, 2, 3], 'val': [10, 20, 30], 'label': [None, 'b', 'c']} + + Use ``left_on`` / ``right_on`` when the key columns differ in name: + + >>> right2 = ctx.from_pydict({"rid": [2, 3], "label": ["b", "c"]}) + >>> left.join( + ... right2, left_on="id", right_on="rid" + ... ).sort("id").to_pydict() + {'id': [2, 3], 'val': [20, 30], 'rid': [2, 3], 'label': ['b', 'c']} """ if join_keys is not None: warnings.warn( @@ -1165,6 +1266,20 @@ def union(self, other: DataFrame, distinct: bool = False) -> DataFrame: Returns: DataFrame after union. + + Examples: + Stack rows from both DataFrames, preserving duplicates: + + >>> ctx = dfn.SessionContext() + >>> df1 = ctx.from_pydict({"a": [1, 2]}) + >>> df2 = ctx.from_pydict({"a": [2, 3]}) + >>> df1.union(df2).sort("a").to_pydict() + {'a': [1, 2, 2, 3]} + + Deduplicate the combined result with ``distinct=True``: + + >>> df1.union(df2, distinct=True).sort("a").to_pydict() + {'a': [1, 2, 3]} """ return DataFrame(self.df.union(other.df, distinct)) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 32004656f..1ff6976f7 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -15,9 +15,31 @@ # specific language governing permissions and limitations # under the License. -"""This module supports expressions, one of the core concepts in DataFusion. - -See :ref:`Expressions` in the online documentation for more details. +""":py:class:`Expr` — the logical expression type used to build DataFusion queries. 
+ +An :py:class:`Expr` represents a computation over columns or literals: a +column reference (``col("a")``), a literal (``lit(5)``), an operator +combination (``col("a") + lit(1)``), or the output of a function from +:py:mod:`datafusion.functions`. Expressions are passed to +:py:class:`~datafusion.dataframe.DataFrame` methods such as +:py:meth:`~datafusion.dataframe.DataFrame.select`, +:py:meth:`~datafusion.dataframe.DataFrame.filter`, +:py:meth:`~datafusion.dataframe.DataFrame.aggregate`, and +:py:meth:`~datafusion.dataframe.DataFrame.sort`. + +Convenience constructors are re-exported at the package level: +:py:func:`datafusion.col` / :py:func:`datafusion.column` for column references +and :py:func:`datafusion.lit` / :py:func:`datafusion.literal` for scalar +literals. + +Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3]}) + >>> df.select((col("a") * lit(10)).alias("ten_a")).to_pydict() + {'ten_a': [10, 20, 30]} + +See :ref:`expressions` in the online documentation for details on available +operators and helpers. """ # ruff: noqa: PLC0415 diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 841cd9c0b..280a6d3ac 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -14,7 +14,27 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""User functions for operating on :py:class:`~datafusion.expr.Expr`.""" +"""Scalar, aggregate, and window functions for :py:class:`~datafusion.expr.Expr`. + +Each function returns an :py:class:`~datafusion.expr.Expr` that can be combined +with other expressions and passed to +:py:class:`~datafusion.dataframe.DataFrame` methods such as +:py:meth:`~datafusion.dataframe.DataFrame.select`, +:py:meth:`~datafusion.dataframe.DataFrame.filter`, +:py:meth:`~datafusion.dataframe.DataFrame.aggregate`, and +:py:meth:`~datafusion.dataframe.DataFrame.window`. The module is conventionally +imported as ``F`` so calls read like ``F.sum(col("price"))``. + +Examples: + >>> from datafusion import functions as F + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3, 4]}) + >>> df.aggregate([], [F.sum(col("a")).alias("total")]).to_pydict() + {'total': [10]} + +See :ref:`aggregation` and :ref:`window_functions` in the online documentation +for categorized catalogs of aggregate and window functions. +""" from __future__ import annotations From c8bb9f7d3876de97141d204740a6b99d5facd10f Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 24 Apr 2026 07:57:11 -0400 Subject: [PATCH 27/29] docs: add README section for AI coding assistants (#1503) Points users to the repo-root SKILL.md via the npx skills registry or a manual AGENTS.md / CLAUDE.md pointer. Implements PR 1c of the plan in #1394. Co-authored-by: Claude Opus 4.7 (1M context) --- README.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/README.md b/README.md index 7849e7a02..4baed7d1d 100644 --- a/README.md +++ b/README.md @@ -215,6 +215,22 @@ You can verify the installation by running: '0.6.0' ``` +## Using DataFusion with AI coding assistants + +This project ships a [`SKILL.md`](SKILL.md) at the repo root that teaches AI +coding assistants how to write idiomatic DataFusion Python. It follows the +[Agent Skills](https://agentskills.io) open standard. 
+ +**Preferred:** `npx skills add apache/datafusion-python` — installs the skill in +Claude Code, Cursor, Windsurf, Cline, Codex, Copilot, Gemini CLI, and other +supported agents. + +**Manual:** paste this line into your project's `AGENTS.md` / `CLAUDE.md`: + +``` +For DataFusion Python code, see https://github.com/apache/datafusion-python/blob/main/SKILL.md +``` + ## How to develop This assumes that you have rust and cargo installed. We use the workflow recommended by [pyo3](https://github.com/PyO3/pyo3) and [maturin](https://github.com/PyO3/maturin). The Maturin tools used in this workflow can be installed either via `uv` or `pip`. Both approaches should offer the same experience. It is recommended to use `uv` since it has significant performance improvements From 03577163a057f791b19f30ce5130464a4a1c78a4 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 24 Apr 2026 11:47:06 -0400 Subject: [PATCH 28/29] tpch examples: rewrite queries idiomatically and embed reference SQL (#1504) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * tpch examples: add reference SQL to each query, fix Q20 - Append the canonical TPC-H reference SQL (from benchmarks/tpch/queries/) to each q01..q22 module docstring so readers can compare the DataFrame translation against the SQL at a glance. - Fix Q20: `df = df.filter(col("ps_availqty") > lit(0.5) * col("total_sold"))` was missing the assignment so the filter was dropped from the pipeline. Co-Authored-By: Claude Opus 4.7 (1M context) * tpch examples: rewrite non-idiomatic queries in idiomatic DataFrame form Rewrite the seven TPC-H example queries that did not demonstrate the idiomatic DataFrame pattern. The remaining queries (Q02/Q11/Q15/Q17/Q22, which use window functions in place of correlated subqueries) already are idiomatic and are left unchanged. - Q04: replace `.aggregate([col("l_orderkey")], [])` with `.select("l_orderkey").distinct()`, which is the natural way to express "reduce to one row per order" on a DataFrame. - Q07: remove the CASE-as-filter on `n_name` and use `F.in_list(col("n_name"), [nation_1, nation_2])` instead. Drops a comment block that admitted the filter form was simpler. - Q08: rewrite the switched CASE `F.case(...).when(lit(False), ...)` as a searched `F.when(col(...).is_not_null(), ...).otherwise(...)`. That mirrors the reference SQL's `case when ... then ... else 0 end` shape. - Q12: replace `array_position(make_array(...), col)` with `F.in_list(col("l_shipmode"), [...])`. Same semantics, without routing through array construction / array search. - Q19: remove the pyarrow UDF that re-implemented a disjunctive predicate in Python. Build the same predicate in DataFusion by OR-combining one `in_list` + range-filter expression per brand. Keeps the per-brand constants in the existing `items_of_interest` dict. - Q20: use `F.starts_with` instead of an explicit substring slice. Replace the inner-join + `select(...).distinct()` tail with a semi join against a precomputed set of excess-quantity suppliers so the supplier columns are preserved without deduplication after the fact. - Q21: replace the `array_agg` / `array_length` / `array_element` pipeline with two semi joins. One semi join keeps orders with more than one distinct supplier (stand-in for the reference SQL's `exists` subquery), the other keeps orders with exactly one late supplier (stand-in for the `not exists` subquery). 
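As a self-contained sketch of the semi/anti-join shape these rewrites lean on (toy tables and integer stand-ins for dates, not the actual Q04/Q21 code):

```python
from datafusion import SessionContext, col

ctx = SessionContext()
orders = ctx.from_pydict({"o_orderkey": [1, 2, 3]})
lineitem = ctx.from_pydict(
    {"l_orderkey": [1, 1, 2], "l_commitdate": [3, 8, 5], "l_receiptdate": [4, 7, 9]}
)

# EXISTS     -> semi join: keep orders that have at least one late lineitem
late = lineitem.filter(col("l_commitdate") < col("l_receiptdate"))
with_late = orders.join(late, left_on="o_orderkey", right_on="l_orderkey", how="semi")

# NOT EXISTS -> anti join: keep orders that have no late lineitem at all
without_late = orders.join(late, left_on="o_orderkey", right_on="l_orderkey", how="anti")
```
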
All 22 answer-file comparisons and 22 plan-comparison diagnostics still pass (`pytest examples/tpch/_tests.py`: 44 passed). Co-Authored-By: Claude Opus 4.7 (1M context) * tpch examples: align reference SQL constants with DataFrame queries The reference SQL embedded in each q01..q22 module docstring was carried over verbatim from ``benchmarks/tpch/queries/`` and uses a different set of TPC-H substitution parameters than the DataFrame examples (answer-file-validated at scale factor 1). Update each reference SQL to use the substitution parameters the DataFrame uses, so both expressions describe the same query and would produce the same results against the same data. Constants aligned: - Q01: ``90 days`` cutoff (DataFrame ``DAYS_BEFORE_FINAL = 90``). - Q02: ``p_size = 15``, ``p_type like '%BRASS'``, ``r_name = 'EUROPE'``. - Q04: base date ``1993-07-01`` (``3 month`` interval preserved per the "quarter of a year" wording). - Q05: ``r_name = 'ASIA'``. - Q06: ``l_discount between 0.06 - 0.01 and 0.06 + 0.01``. - Q07: nations ``'FRANCE'`` / ``'GERMANY'``. - Q08: ``r_name = 'AMERICA'``, ``p_type = 'ECONOMY ANODIZED STEEL'``, inner-case ``nation = 'BRAZIL'``. - Q09: ``p_name like '%green%'``. - Q10: base date ``1993-10-01`` (``3 month`` interval preserved). - Q11: ``n_name = 'GERMANY'``. - Q12: ship modes ``('MAIL', 'SHIP')``, base date ``1994-01-01``. - Q13: ``o_comment not like '%special%requests%'``. - Q14: base date ``1995-09-01``. - Q15: base date ``1996-01-01``. - Q16: ``p_brand <> 'Brand#45'``, ``p_type not like 'MEDIUM POLISHED%'``, sizes ``(49, 14, 23, 45, 19, 3, 36, 9)``. - Q17: ``p_brand = 'Brand#23'``, ``p_container = 'MED BOX'``. - Q18: ``sum(l_quantity) > 300``. - Q19: brands ``Brand#12`` / ``Brand#23`` / ``Brand#34`` with the matching minimum quantities (1, 10, 20). - Q20: ``p_name like 'forest%'``, base date ``1994-01-01``, ``n_name = 'CANADA'``. - Q21: ``n_name = 'SAUDI ARABIA'``. - Q22: country codes ``('13', '31', '23', '29', '30', '18', '17')``. Interval units (month / year) are preserved where the problem-statement text reads "given quarter", "given year", "given month". Q01 keeps the literal "days" unit because the TPC-H problem statement itself describes the cutoff in days. Co-Authored-By: Claude Opus 4.7 (1M context) * tpch examples: apply SKILL.md idioms across all 22 queries Sweep every q01..q22 example for idiomatic DataFrame style as described in the repo-root SKILL.md: - ``col("x") == "s"`` in place of ``col("x") == lit("s")`` on comparison right-hand sides (auto-wrap applies). - Plain-name strings in ``select``/``aggregate``/``sort`` group/sort key lists when the key is a bare column. - Drop redundant ``how="inner"`` and single-element ``left_on``/``right_on`` list wrapping on equi-joins. - Collapse chained ``.filter(a).filter(b)`` runs into ``.filter(a, b)`` and chained ``.with_column`` runs into ``.with_columns(a=..., b=...)``. - ``df.sort_by(...)`` or plain-name ``df.sort(...)`` when no null-placement override is needed. - ``F.count_star()`` in place of ``F.count(col("x"))`` whenever the SQL reads ``count(*)``. - ``F.starts_with(col, lit(prefix))`` and ``~F.starts_with(...)`` in place of substring-prefix equality/inequality tricks. - ``F.in_list(col, [lit(...)])`` in place of ``~F.array_position(...). is_null()`` and in place of disjunctions of equality comparisons. - Searched ``F.when(cond, x).otherwise(y)`` in place of switched ``F.case(bool_expr).when(lit(True/False), x).end()`` forms. 
- Semi-joins as the DataFrame form of ``EXISTS`` (Q04); anti-joins as ``NOT EXISTS`` (Q22 was already using this idiom). - Whole-frame window aggregates as the DataFrame stand-in for a SQL scalar subquery (Q11/Q15/Q17/Q22). Individual query fixes of note: - Q16 — add the secondary sort keys (``p_brand``, ``p_type``, ``p_size``) that the TPC-H spec requires but the original DataFrame omitted. - Q22 — drop a stray ``df.show()`` mid-pipeline; replace the 0-based substring slice with ``F.left(col("c_phone"), lit(2))``. - Q14 — rewrite the promo/non-promo factor split as a searched CASE inside ``F.sum(...)`` so the DataFrame expression matches the reference SQL shape exactly. All 22 answer-file comparisons still pass at scale factor 1. Co-Authored-By: Claude Opus 4.7 (1M context) * tpch examples: more idiomatic aggregate FILTER, string funcs, date handling Additional sweep of the TPC-H DataFrame examples informed by comparing against a fresh set of SKILL.md-only generations under ``examples/tpch/agentic_queries/``: - Q02: ``F.ends_with(col("p_type"), lit(TYPE_OF_INTEREST))`` in place of ``F.strpos(col, lit) > 0``. The reference SQL is ``p_type like '%BRASS'``, which is an ends_with check, not contains. ``F.strpos > 0`` returned the correct rows on TPC-H data by coincidence but is semantically wrong. - Q09: ``F.contains(col("p_name"), lit(part_color))`` in place of ``F.strpos(col, lit) > 0``. The SQL is ``p_name like '%green%'``. - Q08, Q12, Q14: use the ``filter`` keyword on ``F.sum`` / ``F.count`` — the DataFrame form of SQL ``sum(...) FILTER (WHERE ...)`` — instead of wrapping the aggregate input in ``F.when(cond, x).otherwise(0)``. Q08 also reorganises to inner-join the supplier's nation onto the regional sales, which removes the previous left-join + ``F.when(is_not_null, ...)`` dance. - Q15: compute the grand maximum revenue as a separate scalar aggregate and ``join_on(...)`` on equality, instead of the whole-frame window ``F.max`` + filter shape. Simpler plan, same result. - Q16: ``F.regexp_like(col, pattern)`` in place of ``F.regexp_match(col, pattern).is_not_null()``. - Q04, Q05, Q06, Q07, Q08, Q10, Q12, Q14, Q15, Q20: store both the start and the end of the date window as plain ``datetime.date`` objects and compare with ``lit(end_date)``, instead of carrying the start date + ``pa.month_day_nano_interval`` and adding them at query-build time. Drops unused ``pyarrow`` imports from the files that no longer need Arrow scalars. All 22 answer-file comparisons still pass at scale factor 1. 
Co-Authored-By: Claude Opus 4.7 (1M context) --------- Co-authored-by: Claude Opus 4.7 (1M context) --- examples/tpch/q01_pricing_summary_report.py | 44 ++-- examples/tpch/q02_minimum_cost_supplier.py | 87 ++++++-- examples/tpch/q03_shipping_priority.py | 51 +++-- examples/tpch/q04_order_priority_checking.py | 67 +++--- examples/tpch/q05_local_supplier_volume.py | 66 +++--- .../tpch/q06_forecasting_revenue_change.py | 35 +-- examples/tpch/q07_volume_shipping.py | 103 +++++---- examples/tpch/q08_market_share.py | 205 +++++++++--------- .../tpch/q09_product_type_profit_measure.py | 77 +++++-- examples/tpch/q10_returned_item_reporting.py | 102 +++++---- .../q11_important_stock_identification.py | 83 ++++--- examples/tpch/q12_ship_mode_order_priority.py | 108 +++++---- examples/tpch/q13_customer_distribution.py | 47 ++-- examples/tpch/q14_promotion_effect.py | 81 +++---- examples/tpch/q15_top_supplier.py | 94 ++++---- .../tpch/q16_part_supplier_relationship.py | 81 ++++--- examples/tpch/q17_small_quantity_order.py | 58 +++-- examples/tpch/q18_large_volume_customer.py | 71 ++++-- examples/tpch/q19_discounted_revenue.py | 134 ++++++------ examples/tpch/q20_potential_part_promotion.py | 120 ++++++---- .../tpch/q21_suppliers_kept_orders_waiting.py | 134 +++++++----- examples/tpch/q22_global_sales_opportunity.py | 104 ++++++--- 22 files changed, 1196 insertions(+), 756 deletions(-) diff --git a/examples/tpch/q01_pricing_summary_report.py b/examples/tpch/q01_pricing_summary_report.py index 3f97f00dc..105f1632d 100644 --- a/examples/tpch/q01_pricing_summary_report.py +++ b/examples/tpch/q01_pricing_summary_report.py @@ -27,6 +27,30 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + l_returnflag, + l_linestatus, + sum(l_quantity) as sum_qty, + sum(l_extendedprice) as sum_base_price, + sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + avg(l_quantity) as avg_qty, + avg(l_extendedprice) as avg_price, + avg(l_discount) as avg_disc, + count(*) as count_order + from + lineitem + where + l_shipdate <= date '1998-12-01' - interval '90 days' + group by + l_returnflag, + l_linestatus + order by + l_returnflag, + l_linestatus; """ import pyarrow as pa @@ -58,31 +82,25 @@ # Aggregate the results +disc_price = col("l_extendedprice") * (lit(1) - col("l_discount")) + df = df.aggregate( - [col("l_returnflag"), col("l_linestatus")], + ["l_returnflag", "l_linestatus"], [ F.sum(col("l_quantity")).alias("sum_qty"), F.sum(col("l_extendedprice")).alias("sum_base_price"), - F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias( - "sum_disc_price" - ), - F.sum( - col("l_extendedprice") - * (lit(1) - col("l_discount")) - * (lit(1) + col("l_tax")) - ).alias("sum_charge"), + F.sum(disc_price).alias("sum_disc_price"), + F.sum(disc_price * (lit(1) + col("l_tax"))).alias("sum_charge"), F.avg(col("l_quantity")).alias("avg_qty"), F.avg(col("l_extendedprice")).alias("avg_price"), F.avg(col("l_discount")).alias("avg_disc"), - F.count(col("l_returnflag")).alias( - "count_order" - ), # Counting any column should return same result + F.count_star().alias("count_order"), ], ) # Sort per the expected result -df = df.sort(col("l_returnflag").sort(), col("l_linestatus").sort()) +df = df.sort_by("l_returnflag", "l_linestatus") # Note: There appears to be a discrepancy between what is returned here and what is in the generated # answers file for the case of return flag N and line status O, but I did not investigate further. diff --git a/examples/tpch/q02_minimum_cost_supplier.py b/examples/tpch/q02_minimum_cost_supplier.py index 47961d2ef..c5c6b9c0b 100644 --- a/examples/tpch/q02_minimum_cost_supplier.py +++ b/examples/tpch/q02_minimum_cost_supplier.py @@ -27,6 +27,52 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + s_acctbal, + s_name, + n_name, + p_partkey, + p_mfgr, + s_address, + s_phone, + s_comment + from + part, + supplier, + partsupp, + nation, + region + where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and p_size = 15 + and p_type like '%BRASS' + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + and ps_supplycost = ( + select + min(ps_supplycost) + from + partsupp, + supplier, + nation, + region + where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + ) + order by + s_acctbal desc, + n_name, + s_name, + p_partkey limit 100; """ import datafusion @@ -67,35 +113,30 @@ "r_regionkey", "r_name" ) -# Filter down parts. Part names contain the type of interest, so we can use strpos to find where -# in the p_type column the word is. `strpos` will return 0 if not found, otherwise the position -# in the string where it is located. +# Filter down parts. 
The reference SQL uses ``p_type like '%BRASS'`` which +# is an ``ends_with`` check; use the dedicated string function rather than +# a manual substring match. df_part = df_part.filter( - F.strpos(col("p_type"), lit(TYPE_OF_INTEREST)) > lit(0) -).filter(col("p_size") == lit(SIZE_OF_INTEREST)) + F.ends_with(col("p_type"), lit(TYPE_OF_INTEREST)), + col("p_size") == SIZE_OF_INTEREST, +) # Filter regions down to the one of interest -df_region = df_region.filter(col("r_name") == lit(REGION_OF_INTEREST)) +df_region = df_region.filter(col("r_name") == REGION_OF_INTEREST) # Now that we have the region, find suppliers in that region. Suppliers are tied to their nation # and nations are tied to the region. -df_nation = df_nation.join( - df_region, left_on=["n_regionkey"], right_on=["r_regionkey"], how="inner" -) -df_supplier = df_supplier.join( - df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner" -) +df_nation = df_nation.join(df_region, left_on="n_regionkey", right_on="r_regionkey") +df_supplier = df_supplier.join(df_nation, left_on="s_nationkey", right_on="n_nationkey") # Now that we know who the potential suppliers are for the part, we can limit out part # supplies table down. We can further join down to the specific parts we've identified # as matching the request -df = df_partsupp.join( - df_supplier, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="inner" -) +df = df_partsupp.join(df_supplier, left_on="ps_suppkey", right_on="s_suppkey") # Locate the minimum cost across all suppliers. There are multiple ways you could do this, # but one way is to create a window function across all suppliers, find the minimum, and @@ -112,9 +153,9 @@ ), ) -df = df.filter(col("min_cost") == col("ps_supplycost")) - -df = df.join(df_part, left_on=["ps_partkey"], right_on=["p_partkey"], how="inner") +df = df.filter(col("min_cost") == col("ps_supplycost")).join( + df_part, left_on="ps_partkey", right_on="p_partkey" +) # From the problem statement, these are the values we wish to output @@ -132,12 +173,10 @@ # Sort and display 100 entries df = df.sort( col("s_acctbal").sort(ascending=False), - col("n_name").sort(), - col("s_name").sort(), - col("p_partkey").sort(), -) - -df = df.limit(100) + "n_name", + "s_name", + "p_partkey", +).limit(100) # Show results diff --git a/examples/tpch/q03_shipping_priority.py b/examples/tpch/q03_shipping_priority.py index fc1231e0a..880c7435f 100644 --- a/examples/tpch/q03_shipping_priority.py +++ b/examples/tpch/q03_shipping_priority.py @@ -25,6 +25,31 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + l_orderkey, + sum(l_extendedprice * (1 - l_discount)) as revenue, + o_orderdate, + o_shippriority + from + customer, + orders, + lineitem + where + c_mktsegment = 'BUILDING' + and c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate < date '1995-03-15' + and l_shipdate > date '1995-03-15' + group by + l_orderkey, + o_orderdate, + o_shippriority + order by + revenue desc, + o_orderdate limit 10; """ from datafusion import SessionContext, col, lit @@ -50,20 +75,20 @@ # Limit dataframes to the rows of interest -df_customer = df_customer.filter(col("c_mktsegment") == lit(SEGMENT_OF_INTEREST)) +df_customer = df_customer.filter(col("c_mktsegment") == SEGMENT_OF_INTEREST) df_orders = df_orders.filter(col("o_orderdate") < lit(DATE_OF_INTEREST)) df_lineitem = df_lineitem.filter(col("l_shipdate") > lit(DATE_OF_INTEREST)) # Join all 3 dataframes -df = df_customer.join( - df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner" -).join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner") +df = df_customer.join(df_orders, left_on="c_custkey", right_on="o_custkey").join( + df_lineitem, left_on="o_orderkey", right_on="l_orderkey" +) # Compute the revenue df = df.aggregate( - [col("l_orderkey")], + ["l_orderkey"], [ F.first_value(col("o_orderdate")).alias("o_orderdate"), F.first_value(col("o_shippriority")).alias("o_shippriority"), @@ -71,17 +96,13 @@ ], ) -# Sort by priority - -df = df.sort(col("revenue").sort(ascending=False), col("o_orderdate").sort()) - -# Only return 10 results +# Sort by priority, take 10, and project in the order expected by the spec. -df = df.limit(10) - -# Change the order that the columns are reported in just to match the spec - -df = df.select("l_orderkey", "revenue", "o_orderdate", "o_shippriority") +df = ( + df.sort(col("revenue").sort(ascending=False), "o_orderdate") + .limit(10) + .select("l_orderkey", "revenue", "o_orderdate", "o_shippriority") +) # Show result diff --git a/examples/tpch/q04_order_priority_checking.py b/examples/tpch/q04_order_priority_checking.py index 426338aea..6f11c1383 100644 --- a/examples/tpch/q04_order_priority_checking.py +++ b/examples/tpch/q04_order_priority_checking.py @@ -24,18 +24,40 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + o_orderpriority, + count(*) as order_count + from + orders + where + o_orderdate >= date '1993-07-01' + and o_orderdate < date '1993-07-01' + interval '3' month + and exists ( + select + * + from + lineitem + where + l_orderkey = o_orderkey + and l_commitdate < l_receiptdate + ) + group by + o_orderpriority + order by + o_orderpriority; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -# Ideally we could put 3 months into the interval. See note below. 
-INTERVAL_DAYS = 92 -DATE_OF_INTEREST = "1993-07-01" +QUARTER_START = date(1993, 7, 1) +QUARTER_END = date(1993, 10, 1) # Load the dataframes we need @@ -48,36 +70,23 @@ "l_orderkey", "l_commitdate", "l_receiptdate" ) -# Create a date object from the string -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, INTERVAL_DAYS, 0), type=pa.month_day_nano_interval()) - -# Limit results to cases where commitment date before receipt date -# Aggregate the results so we only get one row to join with the order table. -# Alternately, and likely more idiomatic is instead of `.aggregate` you could -# do `.select("l_orderkey").distinct()`. The goal here is to show -# multiple examples of how to use Data Fusion. -df_lineitem = df_lineitem.filter(col("l_commitdate") < col("l_receiptdate")).aggregate( - [col("l_orderkey")], [] +# Keep only orders in the quarter of interest, then restrict to those that +# have at least one late lineitem via a semi join (the DataFrame form of +# ``EXISTS`` from the reference SQL). +df_orders = df_orders.filter( + col("o_orderdate") >= lit(QUARTER_START), + col("o_orderdate") < lit(QUARTER_END), ) -# Limit orders to date range of interest -df_orders = df_orders.filter(col("o_orderdate") >= lit(date)).filter( - col("o_orderdate") < lit(date) + lit(interval) -) +late_lineitems = df_lineitem.filter(col("l_commitdate") < col("l_receiptdate")) -# Perform the join to find only orders for which there are lineitems outside of expected range df = df_orders.join( - df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner" + late_lineitems, left_on="o_orderkey", right_on="l_orderkey", how="semi" ) -# Based on priority, find the number of entries -df = df.aggregate( - [col("o_orderpriority")], [F.count(col("o_orderpriority")).alias("order_count")] +# Count the number of orders in each priority group and sort. +df = df.aggregate(["o_orderpriority"], [F.count_star().alias("order_count")]).sort_by( + "o_orderpriority" ) -# Sort the results -df = df.sort(col("o_orderpriority").sort()) - df.show() diff --git a/examples/tpch/q05_local_supplier_volume.py b/examples/tpch/q05_local_supplier_volume.py index fa2b01dea..bfdba5d4c 100644 --- a/examples/tpch/q05_local_supplier_volume.py +++ b/examples/tpch/q05_local_supplier_volume.py @@ -27,23 +27,45 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + n_name, + sum(l_extendedprice * (1 - l_discount)) as revenue + from + customer, + orders, + lineitem, + supplier, + nation, + region + where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and l_suppkey = s_suppkey + and c_nationkey = s_nationkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'ASIA' + and o_orderdate >= date '1994-01-01' + and o_orderdate < date '1994-01-01' + interval '1' year + group by + n_name + order by + revenue desc; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -DATE_OF_INTEREST = "1994-01-01" -INTERVAL_DAYS = 365 +YEAR_START = date(1994, 1, 1) +YEAR_END = date(1995, 1, 1) REGION_OF_INTEREST = "ASIA" -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, INTERVAL_DAYS, 0), type=pa.month_day_nano_interval()) - # Load the dataframes we need ctx = SessionContext() @@ -68,38 +90,32 @@ ) # Restrict dataframes to cases of interest -df_orders = df_orders.filter(col("o_orderdate") >= lit(date)).filter( - col("o_orderdate") < lit(date) + lit(interval) +df_orders = df_orders.filter( + col("o_orderdate") >= lit(YEAR_START), + col("o_orderdate") < lit(YEAR_END), ) -df_region = df_region.filter(col("r_name") == lit(REGION_OF_INTEREST)) +df_region = df_region.filter(col("r_name") == REGION_OF_INTEREST) # Join all the dataframes df = ( - df_customer.join( - df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner" - ) - .join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner") + df_customer.join(df_orders, left_on="c_custkey", right_on="o_custkey") + .join(df_lineitem, left_on="o_orderkey", right_on="l_orderkey") .join( df_supplier, left_on=["l_suppkey", "c_nationkey"], right_on=["s_suppkey", "s_nationkey"], - how="inner", ) - .join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner") - .join(df_region, left_on=["n_regionkey"], right_on=["r_regionkey"], how="inner") + .join(df_nation, left_on="s_nationkey", right_on="n_nationkey") + .join(df_region, left_on="n_regionkey", right_on="r_regionkey") ) -# Compute the final result +# Compute the final result, then sort in descending order. df = df.aggregate( - [col("n_name")], + ["n_name"], [F.sum(col("l_extendedprice") * (lit(1.0) - col("l_discount"))).alias("revenue")], -) - -# Sort in descending order - -df = df.sort(col("revenue").sort(ascending=False)) +).sort(col("revenue").sort(ascending=False)) df.show() diff --git a/examples/tpch/q06_forecasting_revenue_change.py b/examples/tpch/q06_forecasting_revenue_change.py index 1de5848b1..ed54d22a4 100644 --- a/examples/tpch/q06_forecasting_revenue_change.py +++ b/examples/tpch/q06_forecasting_revenue_change.py @@ -27,28 +27,34 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + sum(l_extendedprice * l_discount) as revenue + from + lineitem + where + l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + and l_discount between 0.06 - 0.01 and 0.06 + 0.01 + and l_quantity < 24; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path # Variables from the example query -DATE_OF_INTEREST = "1994-01-01" +YEAR_START = date(1994, 1, 1) +YEAR_END = date(1995, 1, 1) DISCOUT = 0.06 DELTA = 0.01 QUANTITY = 24 -INTERVAL_DAYS = 365 - -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, INTERVAL_DAYS, 0), type=pa.month_day_nano_interval()) - # Load the dataframes we need ctx = SessionContext() @@ -59,12 +65,11 @@ # Filter down to lineitems of interest -df = ( - df_lineitem.filter(col("l_shipdate") >= lit(date)) - .filter(col("l_shipdate") < lit(date) + lit(interval)) - .filter(col("l_discount") >= lit(DISCOUT) - lit(DELTA)) - .filter(col("l_discount") <= lit(DISCOUT) + lit(DELTA)) - .filter(col("l_quantity") < lit(QUANTITY)) +df = df_lineitem.filter( + col("l_shipdate") >= lit(YEAR_START), + col("l_shipdate") < lit(YEAR_END), + col("l_discount").between(lit(DISCOUT - DELTA), lit(DISCOUT + DELTA)), + col("l_quantity") < QUANTITY, ) # Add up all the "lost" revenue diff --git a/examples/tpch/q07_volume_shipping.py b/examples/tpch/q07_volume_shipping.py index ff2f891f1..df1c2ae0d 100644 --- a/examples/tpch/q07_volume_shipping.py +++ b/examples/tpch/q07_volume_shipping.py @@ -26,9 +26,51 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + supp_nation, + cust_nation, + l_year, + sum(volume) as revenue + from + ( + select + n1.n_name as supp_nation, + n2.n_name as cust_nation, + extract(year from l_shipdate) as l_year, + l_extendedprice * (1 - l_discount) as volume + from + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2 + where + s_suppkey = l_suppkey + and o_orderkey = l_orderkey + and c_custkey = o_custkey + and s_nationkey = n1.n_nationkey + and c_nationkey = n2.n_nationkey + and ( + (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') + or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') + ) + and l_shipdate between date '1995-01-01' and date '1996-12-31' + ) as shipping + group by + supp_nation, + cust_nation, + l_year + order by + supp_nation, + cust_nation, + l_year; """ -from datetime import datetime +from datetime import date import pyarrow as pa from datafusion import SessionContext, col, lit @@ -40,11 +82,8 @@ nation_1 = lit("FRANCE") nation_2 = lit("GERMANY") -START_DATE = "1995-01-01" -END_DATE = "1996-12-31" - -start_date = lit(datetime.strptime(START_DATE, "%Y-%m-%d").date()) -end_date = lit(datetime.strptime(END_DATE, "%Y-%m-%d").date()) +START_DATE = date(1995, 1, 1) +END_DATE = date(1996, 12, 31) # Load the dataframes we need @@ -69,60 +108,44 @@ # Filter to time of interest -df_lineitem = df_lineitem.filter(col("l_shipdate") >= start_date).filter( - col("l_shipdate") <= end_date +df_lineitem = df_lineitem.filter( + col("l_shipdate") >= lit(START_DATE), col("l_shipdate") <= lit(END_DATE) ) -# A simpler way to do the following operation is to use a filter, but we also want to demonstrate -# how to use case statements. Here we are assigning `n_name` to be itself when it is either of -# the two nations of interest. Since there is no `otherwise()` statement, any values that do -# not match these will result in a null value and then get filtered out. -# -# To do the same using a simple filter would be: -# df_nation = df_nation.filter((F.col("n_name") == nation_1) | (F.col("n_name") == nation_2)) # noqa: ERA001 -df_nation = df_nation.with_column( - "n_name", - F.case(col("n_name")) - .when(nation_1, col("n_name")) - .when(nation_2, col("n_name")) - .end(), -).filter(~col("n_name").is_null()) +# Limit the nation table to the two nations of interest. +df_nation = df_nation.filter(F.in_list(col("n_name"), [nation_1, nation_2])) # Limit suppliers to either nation df_supplier = df_supplier.join( - df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner" -).select(col("s_suppkey"), col("n_name").alias("supp_nation")) + df_nation, left_on="s_nationkey", right_on="n_nationkey" +).select("s_suppkey", col("n_name").alias("supp_nation")) # Limit customers to either nation df_customer = df_customer.join( - df_nation, left_on=["c_nationkey"], right_on=["n_nationkey"], how="inner" -).select(col("c_custkey"), col("n_name").alias("cust_nation")) + df_nation, left_on="c_nationkey", right_on="n_nationkey" +).select("c_custkey", col("n_name").alias("cust_nation")) # Join up all the data frames from line items, and make sure the supplier and customer are in # different nations. 
df = ( - df_lineitem.join( - df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner" - ) - .join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner") - .join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner") + df_lineitem.join(df_orders, left_on="l_orderkey", right_on="o_orderkey") + .join(df_customer, left_on="o_custkey", right_on="c_custkey") + .join(df_supplier, left_on="l_suppkey", right_on="s_suppkey") .filter(col("cust_nation") != col("supp_nation")) ) # Extract out two values for every line item -df = df.with_column( - "l_year", F.datepart(lit("year"), col("l_shipdate")).cast(pa.int32()) -).with_column("volume", col("l_extendedprice") * (lit(1.0) - col("l_discount"))) +df = df.with_columns( + l_year=F.datepart(lit("year"), col("l_shipdate")).cast(pa.int32()), + volume=col("l_extendedprice") * (lit(1.0) - col("l_discount")), +) -# Aggregate the results +# Aggregate and sort per the spec. df = df.aggregate( - [col("supp_nation"), col("cust_nation"), col("l_year")], + ["supp_nation", "cust_nation", "l_year"], [F.sum(col("volume")).alias("revenue")], -) - -# Sort based on problem statement requirements -df = df.sort(col("supp_nation").sort(), col("cust_nation").sort(), col("l_year").sort()) +).sort_by("supp_nation", "cust_nation", "l_year") df.show() diff --git a/examples/tpch/q08_market_share.py b/examples/tpch/q08_market_share.py index 4bf50efba..dd7bacedb 100644 --- a/examples/tpch/q08_market_share.py +++ b/examples/tpch/q08_market_share.py @@ -25,24 +25,61 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + o_year, + sum(case + when nation = 'BRAZIL' then volume + else 0 + end) / sum(volume) as mkt_share + from + ( + select + extract(year from o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) as volume, + n2.n_name as nation + from + part, + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2, + region + where + p_partkey = l_partkey + and s_suppkey = l_suppkey + and l_orderkey = o_orderkey + and o_custkey = c_custkey + and c_nationkey = n1.n_nationkey + and n1.n_regionkey = r_regionkey + and r_name = 'AMERICA' + and s_nationkey = n2.n_nationkey + and o_orderdate between date '1995-01-01' and date '1996-12-31' + and p_type = 'ECONOMY ANODIZED STEEL' + ) as all_nations + group by + o_year + order by + o_year; """ -from datetime import datetime +from datetime import date import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -supplier_nation = lit("BRAZIL") -customer_region = lit("AMERICA") -part_of_interest = lit("ECONOMY ANODIZED STEEL") - -START_DATE = "1995-01-01" -END_DATE = "1996-12-31" +supplier_nation = "BRAZIL" +customer_region = "AMERICA" +part_of_interest = "ECONOMY ANODIZED STEEL" -start_date = lit(datetime.strptime(START_DATE, "%Y-%m-%d").date()) -end_date = lit(datetime.strptime(END_DATE, "%Y-%m-%d").date()) +START_DATE = date(1995, 1, 1) +END_DATE = date(1996, 12, 31) # Load the dataframes we need @@ -74,105 +111,57 @@ # Limit orders to those in the specified range -df_orders = df_orders.filter(col("o_orderdate") >= start_date).filter( - col("o_orderdate") <= end_date -) - -# Part 1: Find customers in the region - -# We want customers in region specified by region_of_interest. 
This will be used to compute -# the total sales of the part of interest. We want to know of those sales what fraction -# was supplied by the nation of interest. There is no guarantee that the nation of -# interest is within the region of interest. - -# First we find all the sales that make up the basis. - -df_regional_customers = df_region.filter(col("r_name") == customer_region) - -# After this join we have all of the possible sales nations -df_regional_customers = df_regional_customers.join( - df_nation, left_on=["r_regionkey"], right_on=["n_regionkey"], how="inner" -) - -# Now find the possible customers -df_regional_customers = df_regional_customers.join( - df_customer, left_on=["n_nationkey"], right_on=["c_nationkey"], how="inner" -) - -# Next find orders for these customers -df_regional_customers = df_regional_customers.join( - df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner" -) - -# Find all line items from these orders -df_regional_customers = df_regional_customers.join( - df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner" -) - -# Limit to the part of interest -df_regional_customers = df_regional_customers.join( - df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="inner" -) - -# Compute the volume for each line item -df_regional_customers = df_regional_customers.with_column( - "volume", col("l_extendedprice") * (lit(1.0) - col("l_discount")) -) - -# Part 2: Find suppliers from the nation - -# Now that we have all of the sales of that part in the specified region, we need -# to determine which of those came from suppliers in the nation we are interested in. - -df_national_suppliers = df_nation.filter(col("n_name") == supplier_nation) - -# Determine the suppliers by the limited nation key we have in our single row df above -df_national_suppliers = df_national_suppliers.join( - df_supplier, left_on=["n_nationkey"], right_on=["s_nationkey"], how="inner" -) - -# When we join to the customer dataframe, we don't want to confuse other columns, so only -# select the supplier key that we need -df_national_suppliers = df_national_suppliers.select("s_suppkey") - - -# Part 3: Combine suppliers and customers and compute the market share - -# Now we can do a left outer join on the suppkey. Those line items from other suppliers -# will get a null value. We can check for the existence of this null to compute a volume -# column only from suppliers in the nation we are evaluating. - -df = df_regional_customers.join( - df_national_suppliers, left_on=["l_suppkey"], right_on=["s_suppkey"], how="left" -) - -# Use a case statement to compute the volume sold by suppliers in the nation of interest -df = df.with_column( - "national_volume", - F.case(col("s_suppkey").is_null()) - .when(lit(value=False), col("volume")) - .otherwise(lit(0.0)), -) - -df = df.with_column( - "o_year", F.datepart(lit("year"), col("o_orderdate")).cast(pa.int32()) -) - - -# Lastly, sum up the results - -df = df.aggregate( - [col("o_year")], - [ - F.sum(col("volume")).alias("volume"), - F.sum(col("national_volume")).alias("national_volume"), - ], +df_orders = df_orders.filter( + col("o_orderdate") >= lit(START_DATE), col("o_orderdate") <= lit(END_DATE) +) + +# Pair each supplier with its nation name so every regional-customer row +# below carries the supplier's nation and can be filtered inside the +# aggregate with ``F.sum(..., filter=...)``. 
+ +df_supplier_with_nation = df_supplier.join( + df_nation, left_on="s_nationkey", right_on="n_nationkey" +).select("s_suppkey", col("n_name").alias("supp_nation")) + +# Build every (part, lineitem, order, customer) row for customers in the +# target region ordering the target part. Each row carries the supplier's +# nation so we can aggregate on it below. + +df = ( + df_region.filter(col("r_name") == customer_region) + .join(df_nation, left_on="r_regionkey", right_on="n_regionkey") + .join(df_customer, left_on="n_nationkey", right_on="c_nationkey") + .join(df_orders, left_on="c_custkey", right_on="o_custkey") + .join(df_lineitem, left_on="o_orderkey", right_on="l_orderkey") + .join(df_part, left_on="l_partkey", right_on="p_partkey") + .join(df_supplier_with_nation, left_on="l_suppkey", right_on="s_suppkey") + .with_columns( + volume=col("l_extendedprice") * (lit(1.0) - col("l_discount")), + o_year=F.datepart(lit("year"), col("o_orderdate")).cast(pa.int32()), + ) +) + +# Aggregate the total and national volumes per year via the ``filter`` +# kwarg on ``F.sum`` (DataFrame form of SQL ``sum(... ) FILTER (WHERE ...)``). +# ``coalesce`` handles the case where no sale came from the target nation +# for a given year. +df = ( + df.aggregate( + ["o_year"], + [ + F.sum(col("volume"), filter=col("supp_nation") == supplier_nation).alias( + "national_volume" + ), + F.sum(col("volume")).alias("total_volume"), + ], + ) + .select( + "o_year", + (F.coalesce(col("national_volume"), lit(0.0)) / col("total_volume")).alias( + "mkt_share" + ), + ) + .sort_by("o_year") ) -df = df.select( - col("o_year"), (F.col("national_volume") / F.col("volume")).alias("mkt_share") -) - -df = df.sort(col("o_year").sort()) - df.show() diff --git a/examples/tpch/q09_product_type_profit_measure.py b/examples/tpch/q09_product_type_profit_measure.py index e2abbd095..ec68a2ab7 100644 --- a/examples/tpch/q09_product_type_profit_measure.py +++ b/examples/tpch/q09_product_type_profit_measure.py @@ -27,6 +27,41 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + nation, + o_year, + sum(amount) as sum_profit + from + ( + select + n_name as nation, + extract(year from o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount + from + part, + supplier, + lineitem, + partsupp, + orders, + nation + where + s_suppkey = l_suppkey + and ps_suppkey = l_suppkey + and ps_partkey = l_partkey + and p_partkey = l_partkey + and o_orderkey = l_orderkey + and s_nationkey = n_nationkey + and p_name like '%green%' + ) as profit + group by + nation, + o_year + order by + nation, + o_year desc; """ import pyarrow as pa @@ -34,7 +69,7 @@ from datafusion import functions as F from util import get_data_path -part_color = lit("green") +part_color = "green" # Load the dataframes we need @@ -62,37 +97,35 @@ "n_nationkey", "n_name", "n_regionkey" ) -# Limit possible parts to the color specified -df = df_part.filter(F.strpos(col("p_name"), part_color) > lit(0)) - -# We have a series of joins that get us to limit down to the line items we need -df = df.join(df_lineitem, left_on=["p_partkey"], right_on=["l_partkey"], how="inner") -df = df.join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner") -df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner") -df = df.join( - df_partsupp, - left_on=["l_suppkey", "l_partkey"], - right_on=["ps_suppkey", "ps_partkey"], - how="inner", +# Limit possible parts to the color specified, then walk the joins down to the +# line-item rows we need and attach the supplier's nation. ``F.contains`` +# maps directly to the reference SQL's ``p_name like '%green%'``. +df = ( + df_part.filter(F.contains(col("p_name"), lit(part_color))) + .join(df_lineitem, left_on="p_partkey", right_on="l_partkey") + .join(df_supplier, left_on="l_suppkey", right_on="s_suppkey") + .join(df_orders, left_on="l_orderkey", right_on="o_orderkey") + .join( + df_partsupp, + left_on=["l_suppkey", "l_partkey"], + right_on=["ps_suppkey", "ps_partkey"], + ) + .join(df_nation, left_on="s_nationkey", right_on="n_nationkey") ) -df = df.join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner") # Compute the intermediate values and limit down to the expressions we need df = df.select( col("n_name").alias("nation"), F.datepart(lit("year"), col("o_orderdate")).cast(pa.int32()).alias("o_year"), ( - (col("l_extendedprice") * (lit(1) - col("l_discount"))) - - (col("ps_supplycost") * col("l_quantity")) + col("l_extendedprice") * (lit(1) - col("l_discount")) + - col("ps_supplycost") * col("l_quantity") ).alias("amount"), ) -# Sum up the values by nation and year -df = df.aggregate( - [col("nation"), col("o_year")], [F.sum(col("amount")).alias("profit")] +# Sum up the values by nation and year, then sort per the spec. +df = df.aggregate(["nation", "o_year"], [F.sum(col("amount")).alias("profit")]).sort( + "nation", col("o_year").sort(ascending=False) ) -# Sort according to the problem specification -df = df.sort(col("nation").sort(), col("o_year").sort(ascending=False)) - df.show() diff --git a/examples/tpch/q10_returned_item_reporting.py b/examples/tpch/q10_returned_item_reporting.py index ed822e264..e6532517e 100644 --- a/examples/tpch/q10_returned_item_reporting.py +++ b/examples/tpch/q10_returned_item_reporting.py @@ -27,20 +27,50 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + c_custkey, + c_name, + sum(l_extendedprice * (1 - l_discount)) as revenue, + c_acctbal, + n_name, + c_address, + c_phone, + c_comment + from + customer, + orders, + lineitem, + nation + where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate >= date '1993-10-01' + and o_orderdate < date '1993-10-01' + interval '3' month + and l_returnflag = 'R' + and c_nationkey = n_nationkey + group by + c_custkey, + c_name, + c_acctbal, + c_phone, + n_name, + c_address, + c_comment + order by + revenue desc limit 20; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -DATE_START_OF_QUARTER = "1993-10-01" - -date_start_of_quarter = lit(datetime.strptime(DATE_START_OF_QUARTER, "%Y-%m-%d").date()) - -interval_one_quarter = lit(pa.scalar((0, 92, 0), type=pa.month_day_nano_interval())) +QUARTER_START = date(1993, 10, 1) +QUARTER_END = date(1994, 1, 1) # Load the dataframes we need @@ -66,44 +96,40 @@ ) # limit to returns -df_lineitem = df_lineitem.filter(col("l_returnflag") == lit("R")) +df_lineitem = df_lineitem.filter(col("l_returnflag") == "R") # Rather than aggregate by all of the customer fields as you might do looking at the specification, # we can aggregate by o_custkey and then join in the customer data at the end. -df = df_orders.filter(col("o_orderdate") >= date_start_of_quarter).filter( - col("o_orderdate") < date_start_of_quarter + interval_one_quarter +df = ( + df_orders.filter( + col("o_orderdate") >= lit(QUARTER_START), + col("o_orderdate") < lit(QUARTER_END), + ) + .join(df_lineitem, left_on="o_orderkey", right_on="l_orderkey") + .aggregate( + ["o_custkey"], + [F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias("revenue")], + ) ) -df = df.join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner") - -# Compute the revenue -df = df.aggregate( - [col("o_custkey")], - [F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias("revenue")], +# Now join in the customer data, project the spec's output columns, and take the top 20. 
+df = ( + df.join(df_customer, left_on="o_custkey", right_on="c_custkey") + .join(df_nation, left_on="c_nationkey", right_on="n_nationkey") + .select( + "c_custkey", + "c_name", + "revenue", + "c_acctbal", + "n_name", + "c_address", + "c_phone", + "c_comment", + ) + .sort(col("revenue").sort(ascending=False)) + .limit(20) ) -# Now join in the customer data -df = df.join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner") -df = df.join(df_nation, left_on=["c_nationkey"], right_on=["n_nationkey"], how="inner") - -# These are the columns the problem statement requires -df = df.select( - "c_custkey", - "c_name", - "revenue", - "c_acctbal", - "n_name", - "c_address", - "c_phone", - "c_comment", -) - -# Sort the results in descending order -df = df.sort(col("revenue").sort(ascending=False)) - -# Only return the top 20 results -df = df.limit(20) - df.show() diff --git a/examples/tpch/q11_important_stock_identification.py b/examples/tpch/q11_important_stock_identification.py index de309fa64..1f40bbdad 100644 --- a/examples/tpch/q11_important_stock_identification.py +++ b/examples/tpch/q11_important_stock_identification.py @@ -25,6 +25,36 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + ps_partkey, + sum(ps_supplycost * ps_availqty) as value + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' + group by + ps_partkey having + sum(ps_supplycost * ps_availqty) > ( + select + sum(ps_supplycost * ps_availqty) * 0.0001000000 + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' + ) + order by + value desc; """ from datafusion import SessionContext, WindowFrame, col, lit @@ -49,39 +79,30 @@ "n_nationkey", "n_name" ) -# limit to returns -df_nation = df_nation.filter(col("n_name") == lit(NATION)) - -# Find part supplies of within this target nation - -df = df_nation.join( - df_supplier, left_on=["n_nationkey"], right_on=["s_nationkey"], how="inner" +# Restrict to the target nation, then walk to partsupp rows via the supplier +# join. Aggregate the per-part inventory value. +df = ( + df_nation.filter(col("n_name") == NATION) + .join(df_supplier, left_on="n_nationkey", right_on="s_nationkey") + .join(df_partsupp, left_on="s_suppkey", right_on="ps_suppkey") + .with_column("value", col("ps_supplycost") * col("ps_availqty")) + .aggregate(["ps_partkey"], [F.sum(col("value")).alias("value")]) ) -df = df.join(df_partsupp, left_on=["s_suppkey"], right_on=["ps_suppkey"], how="inner") - - -# Compute the value of individual parts -df = df.with_column("value", col("ps_supplycost") * col("ps_availqty")) - -# Compute total value of specific parts -df = df.aggregate([col("ps_partkey")], [F.sum(col("value")).alias("value")]) - -# By default window functions go from unbounded preceding to current row, but we want -# to compute this sum across all rows -window_frame = WindowFrame("rows", None, None) - -df = df.with_column( - "total_value", F.sum(col("value")).over(Window(window_frame=window_frame)) +# A window function evaluated over the entire output produces a scalar grand +# total that can be referenced row-by-row in the filter — a DataFrame-native +# stand-in for the SQL HAVING ... > (SELECT SUM(...) * FRACTION ...) pattern. 
+# The default frame is "UNBOUNDED PRECEDING to CURRENT ROW"; override to the +# full partition for the grand total. +whole_frame = WindowFrame("rows", None, None) + +df = ( + df.with_column( + "total_value", F.sum(col("value")).over(Window(window_frame=whole_frame)) + ) + .filter(col("value") / col("total_value") >= lit(FRACTION)) + .select("ps_partkey", "value") + .sort(col("value").sort(ascending=False)) ) -# Limit to the parts for which there is a significant value based on the fraction of the total -df = df.filter(col("value") / col("total_value") >= lit(FRACTION)) - -# We only need to report on these two columns -df = df.select("ps_partkey", "value") - -# Sort in descending order of value -df = df.sort(col("value").sort(ascending=False)) - df.show() diff --git a/examples/tpch/q12_ship_mode_order_priority.py b/examples/tpch/q12_ship_mode_order_priority.py index 9071597f0..fb78fe3c2 100644 --- a/examples/tpch/q12_ship_mode_order_priority.py +++ b/examples/tpch/q12_ship_mode_order_priority.py @@ -27,18 +27,49 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + l_shipmode, + sum(case + when o_orderpriority = '1-URGENT' + or o_orderpriority = '2-HIGH' + then 1 + else 0 + end) as high_line_count, + sum(case + when o_orderpriority <> '1-URGENT' + and o_orderpriority <> '2-HIGH' + then 1 + else 0 + end) as low_line_count + from + orders, + lineitem + where + o_orderkey = l_orderkey + and l_shipmode in ('MAIL', 'SHIP') + and l_commitdate < l_receiptdate + and l_shipdate < l_commitdate + and l_receiptdate >= date '1994-01-01' + and l_receiptdate < date '1994-01-01' + interval '1' year + group by + l_shipmode + order by + l_shipmode; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path SHIP_MODE_1 = "MAIL" SHIP_MODE_2 = "SHIP" -DATE_OF_INTEREST = "1994-01-01" +YEAR_START = date(1994, 1, 1) +YEAR_END = date(1995, 1, 1) # Load the dataframes we need @@ -51,63 +82,30 @@ "l_orderkey", "l_shipmode", "l_commitdate", "l_shipdate", "l_receiptdate" ) -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, 365, 0), type=pa.month_day_nano_interval()) - - -df = df_lineitem.filter(col("l_receiptdate") >= lit(date)).filter( - col("l_receiptdate") < lit(date) + lit(interval) -) - -# Note: It is not recommended to use array_has because it treats the second argument as an argument -# so if you pass it col("l_shipmode") it will pass the entire array to process which is very slow. -# Instead check the position of the entry is not null. -df = df.filter( - ~F.array_position( - F.make_array(lit(SHIP_MODE_1), lit(SHIP_MODE_2)), col("l_shipmode") - ).is_null() -) - -# Since we have only two values, it's much easier to do this as a filter where the l_shipmode -# matches either of the two values, but we want to show doing some array operations in this -# example. If you want to see this done with filters, comment out the above line and uncomment -# this one. 
-# df = df.filter((col("l_shipmode") == lit(SHIP_MODE_1)) | (col("l_shipmode") == lit(SHIP_MODE_2))) # noqa: ERA001 +df = df_lineitem.filter( + col("l_receiptdate") >= lit(YEAR_START), + col("l_receiptdate") < lit(YEAR_END), + # ``in_list`` maps directly to ``l_shipmode in (...)`` from the SQL. + F.in_list(col("l_shipmode"), [lit(SHIP_MODE_1), lit(SHIP_MODE_2)]), + col("l_shipdate") < col("l_commitdate"), + col("l_commitdate") < col("l_receiptdate"), +).join(df_orders, left_on="l_orderkey", right_on="o_orderkey") -# We need order priority, so join order df to line item -df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner") +# Flag each line item as belonging to a high-priority order or not. +high_priorities = [lit("1-URGENT"), lit("2-HIGH")] +is_high = F.in_list(col("o_orderpriority"), high_priorities) +is_low = F.in_list(col("o_orderpriority"), high_priorities, negated=True) -# Restrict to line items we care about based on the problem statement. -df = df.filter(col("l_commitdate") < col("l_receiptdate")) - -df = df.filter(col("l_shipdate") < col("l_commitdate")) - -df = df.with_column( - "high_line_value", - F.case(col("o_orderpriority")) - .when(lit("1-URGENT"), lit(1)) - .when(lit("2-HIGH"), lit(1)) - .otherwise(lit(0)), -) - -# Aggregate the results +# Count the high-priority and low-priority lineitems per ship mode via the +# ``filter`` kwarg on ``F.count`` (DataFrame form of SQL's ``count(*) +# FILTER (WHERE ...)``). df = df.aggregate( - [col("l_shipmode")], + ["l_shipmode"], [ - F.sum(col("high_line_value")).alias("high_line_count"), - F.count(col("high_line_value")).alias("all_lines_count"), + F.count(col("o_orderkey"), filter=is_high).alias("high_line_count"), + F.count(col("o_orderkey"), filter=is_low).alias("low_line_count"), ], -) - -# Compute the final output -df = df.select( - col("l_shipmode"), - col("high_line_count"), - (col("all_lines_count") - col("high_line_count")).alias("low_line_count"), -) - -df = df.sort(col("l_shipmode").sort()) +).sort_by("l_shipmode") df.show() diff --git a/examples/tpch/q13_customer_distribution.py b/examples/tpch/q13_customer_distribution.py index 93f082ea3..37c0b93f6 100644 --- a/examples/tpch/q13_customer_distribution.py +++ b/examples/tpch/q13_customer_distribution.py @@ -26,6 +26,29 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + c_count, + count(*) as custdist + from + ( + select + c_custkey, + count(o_orderkey) + from + customer left outer join orders on + c_custkey = o_custkey + and o_comment not like '%special%requests%' + group by + c_custkey + ) as c_orders (c_custkey, c_count) + group by + c_count + order by + custdist desc, + c_count desc; """ from datafusion import SessionContext, col, lit @@ -49,20 +72,16 @@ F.regexp_match(col("o_comment"), lit(f"{WORD_1}.?*{WORD_2}")).is_null() ) -# Since we may have customers with no orders we must do a left join -df = df_customer.join( - df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="left" -) - -# Find the number of orders for each customer -df = df.aggregate([col("c_custkey")], [F.count(col("o_custkey")).alias("c_count")]) - -# Ultimately we want to know the number of customers that have that customer count -df = df.aggregate([col("c_count")], [F.count(col("c_count")).alias("custdist")]) - -# We want to order the results by the highest number of customers per count -df = df.sort( - col("custdist").sort(ascending=False), col("c_count").sort(ascending=False) +# Customers with no orders still participate, so this is a left join. Count the +# orders per customer, then count customers per order-count value. +df = ( + df_customer.join(df_orders, left_on="c_custkey", right_on="o_custkey", how="left") + .aggregate(["c_custkey"], [F.count(col("o_custkey")).alias("c_count")]) + .aggregate(["c_count"], [F.count_star().alias("custdist")]) + .sort( + col("custdist").sort(ascending=False), + col("c_count").sort(ascending=False), + ) ) df.show() diff --git a/examples/tpch/q14_promotion_effect.py b/examples/tpch/q14_promotion_effect.py index d62f76e3c..08f4f054d 100644 --- a/examples/tpch/q14_promotion_effect.py +++ b/examples/tpch/q14_promotion_effect.py @@ -24,20 +24,32 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + 100.00 * sum(case + when p_type like 'PROMO%' + then l_extendedprice * (1 - l_discount) + else 0 + end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue + from + lineitem, + part + where + l_partkey = p_partkey + and l_shipdate >= date '1995-09-01' + and l_shipdate < date '1995-09-01' + interval '1' month; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -DATE = "1995-09-01" - -date_of_interest = lit(datetime.strptime(DATE, "%Y-%m-%d").date()) - -interval_one_month = lit(pa.scalar((0, 30, 0), type=pa.month_day_nano_interval())) +MONTH_START = date(1995, 9, 1) +MONTH_END = date(1995, 10, 1) # Load the dataframes we need @@ -49,37 +61,30 @@ df_part = ctx.read_parquet(get_data_path("part.parquet")).select("p_partkey", "p_type") -# Check part type begins with PROMO -df_part = df_part.filter( - F.substring(col("p_type"), lit(0), lit(6)) == lit("PROMO") -).with_column("promo_factor", lit(1.0)) - -df_lineitem = df_lineitem.filter(col("l_shipdate") >= date_of_interest).filter( - col("l_shipdate") < date_of_interest + interval_one_month -) - -# Left join so we can sum up the promo parts different from other parts -df = df_lineitem.join( - df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="left" -) - -# Make a factor of 1.0 if it is a promotion, 0.0 otherwise -df = df.with_column("promo_factor", F.coalesce(col("promo_factor"), lit(0.0))) -df = df.with_column("revenue", col("l_extendedprice") * (lit(1.0) - col("l_discount"))) - - -# Sum up the promo and total revenue -df = df.aggregate( - [], - [ - F.sum(col("promo_factor") * col("revenue")).alias("promo_revenue"), - F.sum(col("revenue")).alias("total_revenue"), - ], -) - -# Return the percentage of revenue from promotions -df = df.select( - (lit(100.0) * col("promo_revenue") / col("total_revenue")).alias("promo_revenue") +# Restrict the line items to the month of interest, join the matching part +# rows, and aggregate revenue totals with a ``filter`` clause on the promo +# sum — the DataFrame form of SQL ``sum(... ) FILTER (WHERE ...)``. +revenue = col("l_extendedprice") * (lit(1.0) - col("l_discount")) +is_promo = F.starts_with(col("p_type"), lit("PROMO")) + +df = ( + df_lineitem.filter( + col("l_shipdate") >= lit(MONTH_START), + col("l_shipdate") < lit(MONTH_END), + ) + .join(df_part, left_on="l_partkey", right_on="p_partkey") + .aggregate( + [], + [ + F.sum(revenue, filter=is_promo).alias("promo_revenue"), + F.sum(revenue).alias("total_revenue"), + ], + ) + .select( + (lit(100.0) * col("promo_revenue") / col("total_revenue")).alias( + "promo_revenue" + ) + ) ) df.show() diff --git a/examples/tpch/q15_top_supplier.py b/examples/tpch/q15_top_supplier.py index 5128937a7..01c38b9f8 100644 --- a/examples/tpch/q15_top_supplier.py +++ b/examples/tpch/q15_top_supplier.py @@ -24,21 +24,50 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + create view revenue0 (supplier_no, total_revenue) as + select + l_suppkey, + sum(l_extendedprice * (1 - l_discount)) + from + lineitem + where + l_shipdate >= date '1996-01-01' + and l_shipdate < date '1996-01-01' + interval '3' month + group by + l_suppkey; + select + s_suppkey, + s_name, + s_address, + s_phone, + total_revenue + from + supplier, + revenue0 + where + s_suppkey = supplier_no + and total_revenue = ( + select + max(total_revenue) + from + revenue0 + ) + order by + s_suppkey; + drop view revenue0; """ -from datetime import datetime +from datetime import date -import pyarrow as pa -from datafusion import SessionContext, WindowFrame, col, lit +from datafusion import SessionContext, col, lit from datafusion import functions as F -from datafusion.expr import Window from util import get_data_path -DATE = "1996-01-01" - -date_of_interest = lit(datetime.strptime(DATE, "%Y-%m-%d").date()) - -interval_3_months = lit(pa.scalar((0, 91, 0), type=pa.month_day_nano_interval())) +QUARTER_START = date(1996, 1, 1) +QUARTER_END = date(1996, 4, 1) # Load the dataframes we need @@ -54,38 +83,29 @@ "s_phone", ) -# Limit line items to the quarter of interest -df_lineitem = df_lineitem.filter(col("l_shipdate") >= date_of_interest).filter( - col("l_shipdate") < date_of_interest + interval_3_months -) +# Per-supplier revenue over the quarter of interest. +revenue = col("l_extendedprice") * (lit(1) - col("l_discount")) -df = df_lineitem.aggregate( - [col("l_suppkey")], - [ - F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias( - "total_revenue" - ) - ], -) +per_supplier_revenue = df_lineitem.filter( + col("l_shipdate") >= lit(QUARTER_START), + col("l_shipdate") < lit(QUARTER_END), +).aggregate(["l_suppkey"], [F.sum(revenue).alias("total_revenue")]) -# Use a window function to find the maximum revenue across the entire dataframe -window_frame = WindowFrame("rows", None, None) -df = df.with_column( - "max_revenue", - F.max(col("total_revenue")).over(Window(window_frame=window_frame)), +# Compute the grand maximum revenue separately and join on equality — the +# DataFrame stand-in for the reference SQL's +# ``total_revenue = (select max(total_revenue) from revenue0)`` subquery. 
+max_revenue = per_supplier_revenue.aggregate( + [], [F.max(col("total_revenue")).alias("max_rev")] ) -# Find all suppliers whose total revenue is the same as the maximum -df = df.filter(col("total_revenue") == col("max_revenue")) - -# Now that we know the supplier(s) with maximum revenue, get the rest of their information -# from the supplier table -df = df.join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner") +top_suppliers = per_supplier_revenue.join_on( + max_revenue, col("total_revenue") == col("max_rev") +).select("l_suppkey", "total_revenue") -# Return only the columns requested -df = df.select("s_suppkey", "s_name", "s_address", "s_phone", "total_revenue") - -# If we have more than one, sort by supplier number (suppkey) -df = df.sort(col("s_suppkey").sort()) +df = ( + df_supplier.join(top_suppliers, left_on="s_suppkey", right_on="l_suppkey") + .select("s_suppkey", "s_name", "s_address", "s_phone", "total_revenue") + .sort_by("s_suppkey") +) df.show() diff --git a/examples/tpch/q16_part_supplier_relationship.py b/examples/tpch/q16_part_supplier_relationship.py index 65043ffda..ddeadff5f 100644 --- a/examples/tpch/q16_part_supplier_relationship.py +++ b/examples/tpch/q16_part_supplier_relationship.py @@ -26,9 +26,41 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + p_brand, + p_type, + p_size, + count(distinct ps_suppkey) as supplier_cnt + from + partsupp, + part + where + p_partkey = ps_partkey + and p_brand <> 'Brand#45' + and p_type not like 'MEDIUM POLISHED%' + and p_size in (49, 14, 23, 45, 19, 3, 36, 9) + and ps_suppkey not in ( + select + s_suppkey + from + supplier + where + s_comment like '%Customer%Complaints%' + ) + group by + p_brand, + p_type, + p_size + order by + supplier_cnt desc, + p_brand, + p_type, + p_size; """ -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path @@ -52,39 +84,36 @@ ) df_unwanted_suppliers = df_supplier.filter( - ~F.regexp_match(col("s_comment"), lit("Customer.?*Complaints")).is_null() + F.regexp_like(col("s_comment"), lit("Customer.*Complaints")) ) -# Remove unwanted suppliers +# Remove unwanted suppliers via an anti join (DataFrame form of NOT IN). df_partsupp = df_partsupp.join( - df_unwanted_suppliers, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="anti" + df_unwanted_suppliers, left_on="ps_suppkey", right_on="s_suppkey", how="anti" ) -# Select the parts we are interested in -df_part = df_part.filter(col("p_brand") != lit(BRAND)) +# Select the parts we are interested in. df_part = df_part.filter( - F.substring(col("p_type"), lit(0), lit(len(TYPE_TO_IGNORE) + 1)) - != lit(TYPE_TO_IGNORE) -) - -# Python conversion of integer to literal casts it to int64 but the data for -# part size is stored as an int32, so perform a cast. Then check to find if the part -# size is within the array of possible sizes by checking the position of it is not -# null. 
-p_sizes = F.make_array(*[lit(s).cast(pa.int32()) for s in SIZES_OF_INTEREST]) -df_part = df_part.filter(~F.array_position(p_sizes, col("p_size")).is_null()) - -df = df_part.join( - df_partsupp, left_on=["p_partkey"], right_on=["ps_partkey"], how="inner" + col("p_brand") != BRAND, + ~F.starts_with(col("p_type"), lit(TYPE_TO_IGNORE)), + F.in_list(col("p_size"), [lit(s) for s in SIZES_OF_INTEREST]), ) -df = df.select("p_brand", "p_type", "p_size", "ps_suppkey").distinct() - -df = df.aggregate( - [col("p_brand"), col("p_type"), col("p_size")], - [F.count(col("ps_suppkey")).alias("supplier_cnt")], +# For each (brand, type, size), count the distinct suppliers remaining. +df = ( + df_part.join(df_partsupp, left_on="p_partkey", right_on="ps_partkey") + .select("p_brand", "p_type", "p_size", "ps_suppkey") + .distinct() + .aggregate( + ["p_brand", "p_type", "p_size"], + [F.count(col("ps_suppkey")).alias("supplier_cnt")], + ) + .sort( + col("supplier_cnt").sort(ascending=False), + "p_brand", + "p_type", + "p_size", + ) ) -df = df.sort(col("supplier_cnt").sort(ascending=False)) - df.show() diff --git a/examples/tpch/q17_small_quantity_order.py b/examples/tpch/q17_small_quantity_order.py index 5ccb38422..f2229171f 100644 --- a/examples/tpch/q17_small_quantity_order.py +++ b/examples/tpch/q17_small_quantity_order.py @@ -26,6 +26,26 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + sum(l_extendedprice) / 7.0 as avg_yearly + from + lineitem, + part + where + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container = 'MED BOX' + and l_quantity < ( + select + 0.2 * avg(l_quantity) + from + lineitem + where + l_partkey = p_partkey + ); """ from datafusion import SessionContext, WindowFrame, col, lit @@ -47,29 +67,23 @@ "l_partkey", "l_quantity", "l_extendedprice" ) -# Limit to the problem statement's brand and container types -df = df_part.filter(col("p_brand") == lit(BRAND)).filter( - col("p_container") == lit(CONTAINER) -) - -# Combine data -df = df.join(df_lineitem, left_on=["p_partkey"], right_on=["l_partkey"], how="inner") +# Limit to parts of the target brand/container, join their line items, and +# attach the per-part average quantity via a partitioned window function — +# the DataFrame form of the SQL's correlated ``avg(l_quantity)`` subquery. 
+whole_frame = WindowFrame("rows", None, None) -# Find the average quantity -window_frame = WindowFrame("rows", None, None) -df = df.with_column( - "avg_quantity", - F.avg(col("l_quantity")).over( - Window(partition_by=[col("l_partkey")], window_frame=window_frame) - ), +df = ( + df_part.filter(col("p_brand") == BRAND, col("p_container") == CONTAINER) + .join(df_lineitem, left_on="p_partkey", right_on="l_partkey") + .with_column( + "avg_quantity", + F.avg(col("l_quantity")).over( + Window(partition_by=[col("l_partkey")], window_frame=whole_frame) + ), + ) + .filter(col("l_quantity") < lit(0.2) * col("avg_quantity")) + .aggregate([], [F.sum(col("l_extendedprice")).alias("total")]) + .select((col("total") / lit(7.0)).alias("avg_yearly")) ) -df = df.filter(col("l_quantity") < lit(0.2) * col("avg_quantity")) - -# Compute the total -df = df.aggregate([], [F.sum(col("l_extendedprice")).alias("total")]) - -# Divide by number of years in the problem statement to get average -df = df.select((col("total") / lit(7)).alias("avg_yearly")) - df.show() diff --git a/examples/tpch/q18_large_volume_customer.py b/examples/tpch/q18_large_volume_customer.py index 834d181c9..23132d60d 100644 --- a/examples/tpch/q18_large_volume_customer.py +++ b/examples/tpch/q18_large_volume_customer.py @@ -24,9 +24,44 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice, + sum(l_quantity) + from + customer, + orders, + lineitem + where + o_orderkey in ( + select + l_orderkey + from + lineitem + group by + l_orderkey having + sum(l_quantity) > 300 + ) + and c_custkey = o_custkey + and o_orderkey = l_orderkey + group by + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice + order by + o_totalprice desc, + o_orderdate limit 100; """ -from datafusion import SessionContext, col, lit +from datafusion import SessionContext, col from datafusion import functions as F from util import get_data_path @@ -46,22 +81,24 @@ "l_orderkey", "l_quantity", "l_extendedprice" ) -df = df_lineitem.aggregate( - [col("l_orderkey")], [F.sum(col("l_quantity")).alias("total_quantity")] +# Find orders whose total quantity exceeds the threshold, then join in the +# order + customer details the problem statement requires and sort. 
+df = ( + df_lineitem.aggregate( + ["l_orderkey"], [F.sum(col("l_quantity")).alias("total_quantity")] + ) + .filter(col("total_quantity") > QUANTITY) + .join(df_orders, left_on="l_orderkey", right_on="o_orderkey") + .join(df_customer, left_on="o_custkey", right_on="c_custkey") + .select( + "c_name", + "c_custkey", + "o_orderkey", + "o_orderdate", + "o_totalprice", + "total_quantity", + ) + .sort(col("o_totalprice").sort(ascending=False), "o_orderdate") ) -# Limit to orders in which the total quantity is above a threshold -df = df.filter(col("total_quantity") > lit(QUANTITY)) - -# We've identified the orders of interest, now join the additional data -# we are required to report on -df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner") -df = df.join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner") - -df = df.select( - "c_name", "c_custkey", "o_orderkey", "o_orderdate", "o_totalprice", "total_quantity" -) - -df = df.sort(col("o_totalprice").sort(ascending=False), col("o_orderdate").sort()) - df.show() diff --git a/examples/tpch/q19_discounted_revenue.py b/examples/tpch/q19_discounted_revenue.py index bd492aac0..a2be1c1b7 100644 --- a/examples/tpch/q19_discounted_revenue.py +++ b/examples/tpch/q19_discounted_revenue.py @@ -24,10 +24,47 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + sum(l_extendedprice* (1 - l_discount)) as revenue + from + lineitem, + part + where + ( + p_partkey = l_partkey + and p_brand = 'Brand#12' + and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + and l_quantity >= 1 and l_quantity <= 1 + 10 + and p_size between 1 and 5 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + and l_quantity >= 10 and l_quantity <= 10 + 10 + and p_size between 1 and 10 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#34' + and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + and l_quantity >= 20 and l_quantity <= 20 + 10 + and p_size between 1 and 15 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ); """ -import pyarrow as pa -from datafusion import SessionContext, col, lit, udf +from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path @@ -65,72 +102,41 @@ "l_discount", ) -# These limitations apply to all line items, so go ahead and do them first - -df = df_lineitem.filter(col("l_shipinstruct") == lit("DELIVER IN PERSON")) - -df = df.filter( - (col("l_shipmode") == lit("AIR")) | (col("l_shipmode") == lit("AIR REG")) -) +# Filter conditions that apply to every disjunct of the reference SQL's WHERE +# clause — pull them out up front so the per-brand predicate stays focused on +# the brand-specific parts. +df = df_lineitem.filter( + col("l_shipinstruct") == "DELIVER IN PERSON", + F.in_list(col("l_shipmode"), [lit("AIR"), lit("AIR REG")]), +).join(df_part, left_on="l_partkey", right_on="p_partkey") + + +# Build one OR-combined predicate per brand. Each disjunct encodes the +# brand-specific container list, quantity window, and size range from the +# reference SQL. 
This mirrors the SQL ``where (... brand A ...) or (... brand +# B ...) or (... brand C ...)`` form directly, without a UDF. +def _brand_predicate( + brand: str, min_quantity: int, containers: list[str], max_size: int +): + return ( + (col("p_brand") == brand) + & F.in_list(col("p_container"), [lit(c) for c in containers]) + & col("l_quantity").between(lit(min_quantity), lit(min_quantity + 10)) + & col("p_size").between(lit(1), lit(max_size)) + ) -df = df.join(df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="inner") - - -# Create the user defined function (UDF) definition that does the work -def is_of_interest( - brand_arr: pa.Array, - container_arr: pa.Array, - quantity_arr: pa.Array, - size_arr: pa.Array, -) -> pa.Array: - """ - The purpose of this function is to demonstrate how a UDF works, taking as input a pyarrow Array - and generating a resultant Array. The length of the inputs should match and there should be the - same number of rows in the output. - """ - result = [] - for idx, brand_val in enumerate(brand_arr): - brand = brand_val.as_py() - if brand in items_of_interest: - values_of_interest = items_of_interest[brand] - - container_matches = ( - container_arr[idx].as_py() in values_of_interest["containers"] - ) - - quantity = quantity_arr[idx].as_py() - quantity_matches = ( - values_of_interest["min_quantity"] - <= quantity - <= values_of_interest["min_quantity"] + 10 - ) - - size = size_arr[idx].as_py() - size_matches = 1 <= size <= values_of_interest["max_size"] - - result.append(container_matches and quantity_matches and size_matches) - else: - result.append(False) - - return pa.array(result) - - -# Turn the above function into a UDF that DataFusion can understand -is_of_interest_udf = udf( - is_of_interest, - [pa.utf8(), pa.utf8(), pa.decimal128(15, 2), pa.int32()], - pa.bool_(), - "stable", -) -# Filter results using the above UDF -df = df.filter( - is_of_interest_udf( - col("p_brand"), col("p_container"), col("l_quantity"), col("p_size") +predicate = None +for brand, params in items_of_interest.items(): + part_predicate = _brand_predicate( + brand, + params["min_quantity"], + params["containers"], + params["max_size"], ) -) + predicate = part_predicate if predicate is None else predicate | part_predicate -df = df.aggregate( +df = df.filter(predicate).aggregate( [], [F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias("revenue")], ) diff --git a/examples/tpch/q20_potential_part_promotion.py b/examples/tpch/q20_potential_part_promotion.py index a25188d31..18f96da97 100644 --- a/examples/tpch/q20_potential_part_promotion.py +++ b/examples/tpch/q20_potential_part_promotion.py @@ -25,17 +25,57 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + s_name, + s_address + from + supplier, + nation + where + s_suppkey in ( + select + ps_suppkey + from + partsupp + where + ps_partkey in ( + select + p_partkey + from + part + where + p_name like 'forest%' + ) + and ps_availqty > ( + select + 0.5 * sum(l_quantity) + from + lineitem + where + l_partkey = ps_partkey + and l_suppkey = ps_suppkey + and l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + ) + ) + and s_nationkey = n_nationkey + and n_name = 'CANADA' + order by + s_name; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path COLOR_OF_INTEREST = "forest" -DATE_OF_INTEREST = "1994-01-01" +YEAR_START = date(1994, 1, 1) +YEAR_END = date(1995, 1, 1) NATION_OF_INTEREST = "CANADA" # Load the dataframes we need @@ -56,46 +96,48 @@ "n_nationkey", "n_name" ) -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, 365, 0), type=pa.month_day_nano_interval()) - -# Filter down dataframes -df_nation = df_nation.filter(col("n_name") == lit(NATION_OF_INTEREST)) -df_part = df_part.filter( - F.substring(col("p_name"), lit(0), lit(len(COLOR_OF_INTEREST) + 1)) - == lit(COLOR_OF_INTEREST) -) - -df = df_lineitem.filter(col("l_shipdate") >= lit(date)).filter( - col("l_shipdate") < lit(date) + lit(interval) +# Filter down dataframes. ``starts_with`` reads more naturally than an +# explicit substring slice and maps directly to the reference SQL's +# ``p_name like 'forest%'`` clause. +df_nation = df_nation.filter(col("n_name") == NATION_OF_INTEREST) +df_part = df_part.filter(F.starts_with(col("p_name"), lit(COLOR_OF_INTEREST))) + +# Compute the total quantity of interesting parts shipped by each (part, +# supplier) pair within the year of interest. +totals = ( + df_lineitem.filter( + col("l_shipdate") >= lit(YEAR_START), + col("l_shipdate") < lit(YEAR_END), + ) + .join(df_part, left_on="l_partkey", right_on="p_partkey") + .aggregate( + ["l_partkey", "l_suppkey"], + [F.sum(col("l_quantity")).alias("total_sold")], + ) ) -# This will filter down the line items to the parts of interest -df = df.join(df_part, left_on="l_partkey", right_on="p_partkey", how="inner") - -# Compute the total sold and limit ourselves to individual supplier/part combinations -df = df.aggregate( - [col("l_partkey"), col("l_suppkey")], [F.sum(col("l_quantity")).alias("total_sold")] +# Keep only (part, supplier) pairs whose available quantity exceeds 50% of +# the total shipped. The result already contains one row per supplier of +# interest, so we can semi-join the supplier table rather than inner-join +# and deduplicate afterwards. +excess_suppliers = ( + df_partsupp.join( + totals, + left_on=["ps_partkey", "ps_suppkey"], + right_on=["l_partkey", "l_suppkey"], + ) + .filter(col("ps_availqty") > lit(0.5) * col("total_sold")) + .select(col("ps_suppkey").alias("suppkey")) + .distinct() ) -df = df.join( - df_partsupp, - left_on=["l_partkey", "l_suppkey"], - right_on=["ps_partkey", "ps_suppkey"], - how="inner", +# Limit to suppliers in the nation of interest and pick out the two +# requested columns. 
+df = ( + df_supplier.join(df_nation, left_on="s_nationkey", right_on="n_nationkey") + .join(excess_suppliers, left_on="s_suppkey", right_on="suppkey", how="semi") + .select("s_name", "s_address") + .sort_by("s_name") ) -# Find cases of excess quantity -df.filter(col("ps_availqty") > lit(0.5) * col("total_sold")) - -# We could do these joins earlier, but now limit to the nation of interest suppliers -df = df.join(df_supplier, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="inner") -df = df.join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner") - -# Restrict to the requested data per the problem statement -df = df.select("s_name", "s_address").distinct() - -df = df.sort(col("s_name").sort()) - df.show() diff --git a/examples/tpch/q21_suppliers_kept_orders_waiting.py b/examples/tpch/q21_suppliers_kept_orders_waiting.py index 4ee9d3733..d98f76ce7 100644 --- a/examples/tpch/q21_suppliers_kept_orders_waiting.py +++ b/examples/tpch/q21_suppliers_kept_orders_waiting.py @@ -24,9 +24,51 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + s_name, + count(*) as numwait + from + supplier, + lineitem l1, + orders, + nation + where + s_suppkey = l1.l_suppkey + and o_orderkey = l1.l_orderkey + and o_orderstatus = 'F' + and l1.l_receiptdate > l1.l_commitdate + and exists ( + select + * + from + lineitem l2 + where + l2.l_orderkey = l1.l_orderkey + and l2.l_suppkey <> l1.l_suppkey + ) + and not exists ( + select + * + from + lineitem l3 + where + l3.l_orderkey = l1.l_orderkey + and l3.l_suppkey <> l1.l_suppkey + and l3.l_receiptdate > l3.l_commitdate + ) + and s_nationkey = n_nationkey + and n_name = 'SAUDI ARABIA' + group by + s_name + order by + numwait desc, + s_name limit 100; """ -from datafusion import SessionContext, col, lit +from datafusion import SessionContext, col from datafusion import functions as F from util import get_data_path @@ -50,65 +92,57 @@ ) # Limit to suppliers in the nation of interest -df_suppliers_of_interest = df_nation.filter(col("n_name") == lit(NATION_OF_INTEREST)) - -df_suppliers_of_interest = df_suppliers_of_interest.join( - df_supplier, left_on="n_nationkey", right_on="s_nationkey", how="inner" +df_suppliers_of_interest = df_nation.filter(col("n_name") == NATION_OF_INTEREST).join( + df_supplier, left_on="n_nationkey", right_on="s_nationkey" ) -# Find the failed orders and all their line items -df = df_orders.filter(col("o_orderstatus") == lit("F")) - -df = df_lineitem.join(df, left_on="l_orderkey", right_on="o_orderkey", how="inner") - -# Identify the line items for which the order is failed due to. -df = df.with_column( - "failed_supp", - F.case(col("l_receiptdate") > col("l_commitdate")) - .when(lit(value=True), col("l_suppkey")) - .end(), +# Line items for orders that have status 'F'. This is the candidate set of +# (order, supplier) pairs we reason about below. +failed_order_lineitems = df_lineitem.join( + df_orders.filter(col("o_orderstatus") == "F"), + left_on="l_orderkey", + right_on="o_orderkey", ) -# There are other ways we could do this but the purpose of this example is to work with rows where -# an element is an array of values. In this case, we will create two columns of arrays. One will be -# an array of all of the suppliers who made up this order. 
That way we can filter the dataframe for -# only orders where this array is larger than one for multiple supplier orders. The second column -# is all of the suppliers who failed to make their commitment. We can filter the second column for -# arrays with size one. That combination will give us orders that had multiple suppliers where only -# one failed. Use distinct=True in the blow aggregation so we don't get multiple line items from the -# same supplier reported in either array. -df = df.aggregate( - [col("o_orderkey")], - [ - F.array_agg(col("l_suppkey"), distinct=True).alias("all_suppliers"), - F.array_agg( - col("failed_supp"), filter=col("failed_supp").is_not_null(), distinct=True - ).alias("failed_suppliers"), - ], +# Line items whose receipt was late. This corresponds to ``l1`` in the +# reference SQL. +late_lineitems = failed_order_lineitems.filter( + col("l_receiptdate") > col("l_commitdate") ) -# This is the check described above which will identify single failed supplier in a multiple -# supplier order. -df = df.filter(F.array_length(col("failed_suppliers")) == lit(1)).filter( - F.array_length(col("all_suppliers")) > lit(1) +# Orders that had more than one distinct supplier. Expressed as +# ``count(distinct l_suppkey) > 1``. Stands in for the reference SQL's +# ``exists (... l2.l_suppkey <> l1.l_suppkey ...)`` subquery. +multi_supplier_orders = ( + failed_order_lineitems.select("l_orderkey", "l_suppkey") + .distinct() + .aggregate(["l_orderkey"], [F.count_star().alias("n_suppliers")]) + .filter(col("n_suppliers") > 1) + .select("l_orderkey") ) -# Since we have an array we know is exactly one element long, we can extract that single value. -df = df.select( - col("o_orderkey"), F.array_element(col("failed_suppliers"), lit(1)).alias("suppkey") +# Orders where exactly one distinct supplier was late. Stands in for the +# reference SQL's ``not exists (... l3.l_suppkey <> l1.l_suppkey and l3 is +# also late ...)`` subquery: if only one supplier on the order was late, +# nobody else on the same order was late. +single_late_supplier_orders = ( + late_lineitems.select("l_orderkey", "l_suppkey") + .distinct() + .aggregate(["l_orderkey"], [F.count_star().alias("n_late_suppliers")]) + .filter(col("n_late_suppliers") == 1) + .select("l_orderkey") ) -# Join to the supplier of interest list for the nation of interest -df = df.join( - df_suppliers_of_interest, left_on=["suppkey"], right_on=["s_suppkey"], how="inner" +# Keep late line items whose order qualifies on both counts, attach the +# supplier name for suppliers in the nation of interest, count one row per +# qualifying order, and return the top 100. 
+df = ( + late_lineitems.join(multi_supplier_orders, on="l_orderkey", how="semi") + .join(single_late_supplier_orders, on="l_orderkey", how="semi") + .join(df_suppliers_of_interest, left_on="l_suppkey", right_on="s_suppkey") + .aggregate(["s_name"], [F.count_star().alias("numwait")]) + .sort(col("numwait").sort(ascending=False), "s_name") + .limit(100) ) -# Count how many orders that supplier is the only failed supplier for -df = df.aggregate([col("s_name")], [F.count(col("o_orderkey")).alias("numwait")]) - -# Return in descending order -df = df.sort(col("numwait").sort(ascending=False), col("s_name").sort()) - -df = df.limit(100) - df.show() diff --git a/examples/tpch/q22_global_sales_opportunity.py b/examples/tpch/q22_global_sales_opportunity.py index a2d41b215..5043eeb51 100644 --- a/examples/tpch/q22_global_sales_opportunity.py +++ b/examples/tpch/q22_global_sales_opportunity.py @@ -24,6 +24,46 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + cntrycode, + count(*) as numcust, + sum(c_acctbal) as totacctbal + from + ( + select + substring(c_phone from 1 for 2) as cntrycode, + c_acctbal + from + customer + where + substring(c_phone from 1 for 2) in + ('13', '31', '23', '29', '30', '18', '17') + and c_acctbal > ( + select + avg(c_acctbal) + from + customer + where + c_acctbal > 0.00 + and substring(c_phone from 1 for 2) in + ('13', '31', '23', '29', '30', '18', '17') + ) + and not exists ( + select + * + from + orders + where + o_custkey = c_custkey + ) + ) as custsale + group by + cntrycode + order by + cntrycode; """ from datafusion import SessionContext, WindowFrame, col, lit @@ -42,40 +82,36 @@ ) df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select("o_custkey") -# The nation code is a two digit number, but we need to convert it to a string literal -nation_codes = F.make_array(*[lit(str(n)) for n in NATION_CODES]) - -# Use the substring operation to extract the first two characters of the phone number -df = df_customer.with_column("cntrycode", F.substring(col("c_phone"), lit(0), lit(3))) - -# Limit our search to customers with some balance and in the country code above -df = df.filter(col("c_acctbal") > lit(0.0)) -df = df.filter(~F.array_position(nation_codes, col("cntrycode")).is_null()) - -# Compute the average balance. By default, the window frame is from unbounded preceding to the -# current row. We want our frame to cover the entire data frame. -window_frame = WindowFrame("rows", None, None) -df = df.with_column( - "avg_balance", - F.avg(col("c_acctbal")).over(Window(window_frame=window_frame)), -) - -df.show() -# Limit results to customers with above average balance -df = df.filter(col("c_acctbal") > col("avg_balance")) - -# Limit results to customers with no orders -df = df.join(df_orders, left_on="c_custkey", right_on="o_custkey", how="anti") - -# Count up the customers and the balances -df = df.aggregate( - [col("cntrycode")], - [ - F.count(col("c_custkey")).alias("numcust"), - F.sum(col("c_acctbal")).alias("totacctbal"), - ], +# Country code is the two-digit prefix of the phone number. 
+nation_codes = [lit(str(n)) for n in NATION_CODES] + +# Start from customers with a positive balance in one of the target country +# codes, then attach the grand-mean balance via a whole-frame window so we +# can filter per row — DataFrame stand-in for the SQL's scalar ``(select +# avg(c_acctbal) ... )`` subquery. +whole_frame = WindowFrame("rows", None, None) + +df = ( + df_customer.with_column("cntrycode", F.left(col("c_phone"), lit(2))) + .filter( + col("c_acctbal") > 0.0, + F.in_list(col("cntrycode"), nation_codes), + ) + .with_column( + "avg_balance", + F.avg(col("c_acctbal")).over(Window(window_frame=whole_frame)), + ) + .filter(col("c_acctbal") > col("avg_balance")) + # Keep only customers with no orders (anti join = NOT EXISTS). + .join(df_orders, left_on="c_custkey", right_on="o_custkey", how="anti") + .aggregate( + ["cntrycode"], + [ + F.count_star().alias("numcust"), + F.sum(col("c_acctbal")).alias("totacctbal"), + ], + ) + .sort_by("cntrycode") ) -df = df.sort(col("cntrycode").sort()) - df.show() From e0284c6e788b6fc893495ed929b9badef1cf925c Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 24 Apr 2026 13:09:24 -0400 Subject: [PATCH 29/29] feat: add AI skill to find and improve the Pythonic interface to functions (#1484) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: accept native Python types in function arguments instead of requiring lit() Update 47 functions in functions.py to accept native Python types (int, float, str) for arguments that are contextually literals, eliminating verbose lit() wrapping. For example, users can now write split_part(col("a"), ",", 2) instead of split_part(col("a"), lit(","), lit(2)). All changes are backward compatible. Co-Authored-By: Claude Opus 4.6 (1M context) * fix: update alias function signatures to match pythonic primary functions Update instr and position (aliases of strpos) to accept Expr | str for the substring parameter, matching the updated primary function signature. Co-Authored-By: Claude Opus 4.6 (1M context) * docs: update make-pythonic skill to require alias type hint updates Alias functions that delegate to a primary function must have their type hints updated to match, even though coercion logic is only added to the primary. Added a new Step 3 to the implementation workflow for this. Co-Authored-By: Claude Opus 4.6 (1M context) * fix: address review feedback on pythonic skill and function signatures Update SKILL.md to prevent three classes of issues: clarify that float already accepts int per PEP 484 (avoiding redundant int | float that fails ruff PYI041), add backward-compat rule for Category B so existing Expr params aren't removed, and add guidance for inline coercion with many optional nullable params instead of local helpers. Replace regexp_instr's _to_raw() helper with inline coercion matching the pattern used throughout the rest of the file. Co-Authored-By: Claude Opus 4.6 (1M context) * refactor: add coerce_to_expr helpers and replace inline coercion patterns Introduce coerce_to_expr() and coerce_to_expr_or_none() in expr.py as the complement to ensure_expr() — where ensure_expr rejects non-Expr values, these helpers wrap them via Expr.literal(). Replaces ~60 inline isinstance checks in functions.py with single-line helper calls, and updates the make-pythonic skill to document the new pattern. 
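For illustration, the helper behaves roughly like the sketch below (the exact
implementation in expr.py may differ in typing and error handling; only the
Expr.literal() wrapping described above is assumed):

```python
from datafusion import Expr

def coerce_to_expr(value):
    """Return Expr inputs unchanged; wrap any other value as a literal."""
    return value if isinstance(value, Expr) else Expr.literal(value)

# Wrappers such as split_part(col("a"), ",", 2) can then route each argument
# through coerce_to_expr() instead of repeating an inline isinstance check.
```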
Co-Authored-By: Claude Opus 4.6 (1M context) * docs: add aggregate function literal detection to make-pythonic skill Add Technique 1a to detect literal-only arguments in aggregate functions. Unlike scalar UDFs which enforce literals in invoke_with_args(), aggregate functions enforce them in accumulator() via get_scalar_value(), validate_percentile_expr(), or downcast_ref::(). Without this technique, the skill would incorrectly classify arguments like approx_percentile_cont's percentile as Category A (Expr | float) when they should be Category B (float only). Updates the decision flow to branch on scalar vs aggregate before checking for literal enforcement. Co-Authored-By: Claude Opus 4.6 (1M context) * docs: add window function literal detection to make-pythonic skill Add Technique 1b to detect literal-only arguments in window functions. Window functions enforce literals in partition_evaluator() via get_scalar_value_from_args() / downcast_ref::(), not in invoke_with_args() (scalar) or accumulator() (aggregate). Updates the decision flow to branch on scalar vs aggregate vs window. Known window functions with literal-only arguments: ntile (n), lead/lag (offset, default_value), nth_value (n). Co-Authored-By: Claude Opus 4.6 (1M context) * fix: use explicit None checks, widen numeric type hints, and add tests Replace 7 fragile truthiness checks (x.expr if x else None) with explicit is not None checks to prevent silent None when zero-valued literals are passed. Widen log/power/pow type hints to Expr | int | float with noqa: PYI041 for clarity. Add unit tests for coerce_to_expr helpers and integration tests for pythonic calling conventions. Co-Authored-By: Claude Opus 4.6 (1M context) * chore: suppress FBT003 in tests and remove redundant noqa comments Add FBT003 (boolean positional value) to the per-file-ignores for python/tests/* in pyproject.toml, and remove the 6 now-redundant inline noqa: FBT003 comments across test_expr.py and test_context.py. Co-Authored-By: Claude Opus 4.6 (1M context) * docs: replace static function lists with discovery instructions in skill Replace hardcoded "Known aggregate/window functions with literal-only arguments" lists with instructions to discover them dynamically by searching the upstream crate source. Keeps a few examples as validation anchors so the agent knows its search is working correctly. Co-Authored-By: Claude Opus 4.6 (1M context) * fix: make interrupt test reliable on Python 3.11 PyThreadState_SetAsyncExc only delivers exceptions when the thread is executing Python bytecode, not while in native (Rust/C) code. The previous test had two issues causing flakiness on Python 3.11: 1. The interrupt fired before df.collect() entered the UDF, while the thread was still in native code where async exceptions are ignored. 2. time.sleep(2.0) is a single C call where async exceptions are not checked — they're only checked between bytecode instructions. Fix by adding a threading.Event so the interrupt waits until the UDF is actually executing Python code, and by sleeping in small increments so the eval loop has opportunities to check for pending exceptions. 
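The shape of that fix, as a self-contained sketch (names, timings, and the
assertion are illustrative rather than the actual test code):

```python
import ctypes
import threading
import time

udf_running = threading.Event()
caught = []

def slow_python_work() -> None:
    try:
        # Signal that we are now executing Python bytecode, so the interrupt
        # is only sent once async exceptions can actually be delivered.
        udf_running.set()
        deadline = time.monotonic() + 2.0
        while time.monotonic() < deadline:
            # Sleep in small increments: async exceptions are only checked
            # between bytecode instructions, never inside one long C-level sleep.
            time.sleep(0.05)
    except KeyboardInterrupt:
        caught.append(True)

worker = threading.Thread(target=slow_python_work)
worker.start()
udf_running.wait(timeout=10.0)
ctypes.pythonapi.PyThreadState_SetAsyncExc(
    ctypes.c_long(worker.ident), ctypes.py_object(KeyboardInterrupt)
)
worker.join()
assert caught, "worker was not interrupted"
```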
Co-Authored-By: Claude Opus 4.6 (1M context) --------- Co-authored-by: Claude Opus 4.6 (1M context) --- .ai/skills/make-pythonic/SKILL.md | 430 +++++++++++++++++++++++++++++ pyproject.toml | 1 + python/datafusion/expr.py | 41 +++ python/datafusion/functions.py | 445 +++++++++++++++++------------- python/tests/test_context.py | 6 +- python/tests/test_dataframe.py | 21 +- python/tests/test_expr.py | 49 +++- python/tests/test_functions.py | 93 +++++++ 8 files changed, 893 insertions(+), 193 deletions(-) create mode 100644 .ai/skills/make-pythonic/SKILL.md diff --git a/.ai/skills/make-pythonic/SKILL.md b/.ai/skills/make-pythonic/SKILL.md new file mode 100644 index 000000000..57145ac6c --- /dev/null +++ b/.ai/skills/make-pythonic/SKILL.md @@ -0,0 +1,430 @@ + + +--- +name: make-pythonic +description: Audit and improve datafusion-python functions to accept native Python types (int, float, str, bool) instead of requiring explicit lit() or col() wrapping. Analyzes function signatures, checks upstream Rust implementations for type constraints, and applies the appropriate coercion pattern. +argument-hint: [scope] (e.g., "string functions", "datetime functions", "array functions", "math functions", "all", or a specific function name like "split_part") +--- + +# Make Python API Functions More Pythonic + +You are improving the datafusion-python API to feel more natural to Python users. The goal is to allow functions to accept native Python types (int, float, str, bool, etc.) for arguments that are contextually always or typically literal values, instead of requiring users to manually wrap them in `lit()`. + +**Core principle:** A Python user should be able to write `split_part(col("a"), ",", 2)` instead of `split_part(col("a"), lit(","), lit(2))` when the arguments are contextually obvious literals. + +## How to Identify Candidates + +The user may specify a scope via `$ARGUMENTS`. If no scope is given or "all" is specified, audit all functions in `python/datafusion/functions.py`. + +For each function, determine if any parameter can accept native Python types by evaluating **two complementary signals**: + +### Signal 1: Contextual Understanding + +Some arguments are contextually always or almost always literal values based on what the function does: + +| Context | Typical Arguments | Examples | +|---------|------------------|----------| +| **String position/count** | Character counts, indices, repetition counts | `left(str, n)`, `right(str, n)`, `repeat(str, n)`, `lpad(str, count, ...)` | +| **Delimiters/separators** | Fixed separator characters | `split_part(str, delim, idx)`, `concat_ws(sep, ...)` | +| **Search/replace patterns** | Literal search strings, replacements | `replace(str, from, to)`, `regexp_replace(str, pattern, replacement, flags)` | +| **Date/time parts** | Part names from a fixed set | `date_part(part, date)`, `date_trunc(part, date)` | +| **Rounding precision** | Decimal place counts | `round(val, places)`, `trunc(val, places)` | +| **Fill characters** | Padding characters | `lpad(str, count, fill)`, `rpad(str, count, fill)` | + +### Signal 2: Upstream Rust Implementation + +Check the Rust binding in `crates/core/src/functions.rs` and the upstream DataFusion function implementation to determine type constraints. The upstream source is cached locally at: + +``` +~/.cargo/registry/src/index.crates.io-*/datafusion-functions-/src/ +``` + +Check the DataFusion version in `crates/core/Cargo.toml` to find the right directory. Key subdirectories: `string/`, `datetime/`, `math/`, `regex/`. 
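+
+If it helps to script that lookup, a small sketch along these lines (the helper name and glob pattern are illustrative, not a required part of this skill) resolves the cached path for whichever version `crates/core/Cargo.toml` pins:
+
+```python
+from pathlib import Path
+
+def upstream_src(crate: str, version: str) -> Path:
+    """Locate the cached source tree for a crate in the local cargo registry."""
+    pattern = f".cargo/registry/src/index.crates.io-*/{crate}-{version}/src"
+    matches = sorted(Path.home().glob(pattern))
+    if not matches:
+        raise FileNotFoundError(f"{crate} {version} is not cached locally")
+    return matches[0]
+
+# e.g. upstream_src("datafusion-functions", version_from_cargo_toml) / "datetime"
+```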
+
+For **aggregate functions**, the upstream source is in a separate crate:
+
+```
+~/.cargo/registry/src/index.crates.io-*/datafusion-functions-aggregate-<version>/src/
+```
+
+There are five concrete techniques to check, in order of signal strength:
+
+#### Technique 1: Check `invoke_with_args()` for literal-only enforcement (strongest signal)
+
+Some functions pattern-match on `ColumnarValue::Scalar` in their `invoke_with_args()` method and **return an error** if the argument is a column/array. This means the argument **must** be a literal — passing a column expression will fail at runtime.
+
+Example from `date_trunc.rs`:
+```rust
+let granularity_str = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = granularity {
+    v.to_lowercase()
+} else {
+    return exec_err!("Granularity of `date_trunc` must be non-null scalar Utf8");
+};
+```
+
+**If you find this pattern:** The argument is **Category B** — accept only the corresponding native Python type (e.g., `str`), not `Expr`. The function will error at runtime with a column expression anyway.
+
+#### Technique 1a: Check `accumulator()` for literal-only enforcement (aggregate functions)
+
+Technique 1 applies to scalar UDFs. Aggregate functions do not have `invoke_with_args()` — instead, they enforce literal-only arguments in their `accumulator()` (or `create_accumulator()`) method, which runs at planning time before any data is processed.
+
+Look for these patterns inside `accumulator()`:
+
+- `get_scalar_value(expr)` — evaluates the expression against an empty batch and errors if it's not a scalar
+- `validate_percentile_expr(expr)` — specific helper used by percentile functions
+- `downcast_ref::<Literal>()` — checks that the physical expression is a literal constant
+
+Example from `approx_percentile_cont.rs`:
+```rust
+fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+    let percentile =
+        validate_percentile_expr(&args.exprs[1], "APPROX_PERCENTILE_CONT")?;
+    // ...
+}
+```
+
+Where `validate_percentile_expr` calls `get_scalar_value` and errors with `"must be a literal"`.
+
+Example from `string_agg.rs`:
+```rust
+fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+    let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::<Literal>() else {
+        return not_impl_err!(
+            "The second argument of the string_agg function must be a string literal"
+        );
+    };
+    // ...
+}
+```
+
+**If you find this pattern:** The argument is **Category B** — accept only the corresponding native Python type, not `Expr`. The function will error at planning time with a non-literal expression.
+
+To discover which aggregate functions have literal-only arguments, search the upstream aggregate crate for `get_scalar_value`, `validate_percentile_expr`, and `downcast_ref::<Literal>()` inside `accumulator()` methods. For example, you should expect to find `approx_percentile_cont` (percentile) and `string_agg` (delimiter) among the results.
+
+#### Technique 1b: Check `partition_evaluator()` for literal-only enforcement (window functions)
+
+Window functions do not have `invoke_with_args()` or `accumulator()`. Instead, they enforce literal-only arguments in their `partition_evaluator()` method, which constructs the evaluator that processes each partition.
+
+The upstream source is in a separate crate:
+
+```
+~/.cargo/registry/src/index.crates.io-*/datafusion-functions-window-<version>/src/
+```
+
+Look for `get_scalar_value_from_args()` calls inside `partition_evaluator()`.
This helper (defined in the window crate's `utils.rs`) calls `downcast_ref::<Literal>()` and errors with `"There is only support Literal types for field at idx: {index} in Window Function"`.
+
+Example from `ntile.rs`:
+```rust
+fn partition_evaluator(
+    &self,
+    partition_evaluator_args: PartitionEvaluatorArgs,
+) -> Result<Box<dyn PartitionEvaluator>> {
+    let scalar_n =
+        get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 0)?
+            .ok_or_else(|| {
+                exec_datafusion_err!("NTILE requires a positive integer")
+            })?;
+    // ...
+}
+```
+
+**If you find this pattern:** The argument is **Category B** — accept only the corresponding native Python type, not `Expr`. The function will error at planning time with a non-literal expression.
+
+To discover which window functions have literal-only arguments, search the upstream window crate for `get_scalar_value_from_args` inside `partition_evaluator()` methods. For example, you should expect to find `ntile` (n) and `lead`/`lag` (offset, default_value) among the results.
+
+#### Technique 2: Check the `Signature` for data type constraints
+
+Each function defines a `Signature::coercible(...)` that specifies what data types each argument accepts, using `Coercion` entries. This tells you the expected **data type** even if it doesn't enforce literal-only.
+
+Example from `repeat.rs`:
+```rust
+signature: Signature::coercible(
+    vec![
+        Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
+        Coercion::new_implicit(
+            TypeSignatureClass::Native(logical_int64()),
+            vec![TypeSignatureClass::Integer],
+            NativeType::Int64,
+        ),
+    ],
+    Volatility::Immutable,
+),
+```
+
+This tells you arg 2 (`n`) must be an integer type coerced to Int64. Use this to choose the correct Python type (e.g., `int` not `str` or `float`).
+
+Common mappings:
+| Rust Type Constraint | Python Type |
+|---------------------|-------------|
+| `logical_int64()` / `TypeSignatureClass::Integer` | `int` |
+| `logical_float64()` / `TypeSignatureClass::Numeric` | `int \| float` |
+| `logical_string()` / `TypeSignatureClass::String` | `str` |
+| `LogicalType::Boolean` | `bool` |
+
+**Important:** In Python's type system (PEP 484), `float` already accepts `int` values, so `int | float` is redundant and will fail the `ruff` linter (rule PYI041). Use `float` alone when the Rust side accepts a float/numeric type — Python users can still pass integer literals like `log(10, col("a"))` or `power(col("a"), 3)` without issue. Only use `int` when the Rust side strictly requires an integer (e.g., `logical_int64()`).
+
+#### Technique 3: Check `return_field_from_args()` for `scalar_arguments` usage
+
+Functions that inspect literal values at query planning time use `args.scalar_arguments.get(n)` in their `return_field_from_args()` method. This indicates the argument is **expected to be a literal** for optimal behavior (e.g., to determine output type precision), but may still work as a column.
+
+Example from `round.rs`:
+```rust
+let decimal_places: Option = match args.scalar_arguments.get(1) {
+    None => Some(0),
+    Some(None) => None, // argument is not a literal (column)
+    Some(Some(scalar)) if scalar.is_null() => Some(0),
+    Some(Some(scalar)) => Some(decimal_places_from_scalar(scalar)?),
+};
+```
+
+**If you find this pattern:** The argument is **Category A** — accept native types AND `Expr`. It works as a column but is primarily used as a literal.
+
+#### Decision flow
+
+```
+What kind of function is this?
+  Scalar UDF:
+    Is argument rejected at runtime if not a literal?
+      (check invoke_with_args for ColumnarValue::Scalar-only match + exec_err!)
+      → YES: Category B — accept only native type, no Expr
+      → NO: continue below
+  Aggregate:
+    Is argument rejected at planning time if not a literal?
+      (check accumulator() for get_scalar_value / validate_percentile_expr /
+       downcast_ref::<Literal>() + error)
+      → YES: Category B — accept only native type, no Expr
+      → NO: continue below
+  Window:
+    Is argument rejected at planning time if not a literal?
+      (check partition_evaluator() for get_scalar_value_from_args /
+       downcast_ref::<Literal>() + error)
+      → YES: Category B — accept only native type, no Expr
+      → NO: continue below
+
+Does the Signature constrain it to a specific data type?
+  → YES: Category A — accept Expr | <native type>
+  → NO: Leave as Expr only
+```
+
+## Coercion Categories
+
+When making a function more pythonic, apply the correct coercion pattern based on **what the argument represents**:
+
+### Category A: Arguments That Should Accept Native Types AND Expr
+
+These are arguments that are *typically* literals but *could* be column references in advanced use cases. For these, accept a union type and coerce native types to `Expr.literal()`.
+
+**Type hint pattern:** `Expr | int`, `Expr | str`, `Expr | int | str`, etc.
+
+**When to use:** When the argument could plausibly come from a column in some use case (e.g., the repeat count might come from a column in a data-driven scenario).
+
+```python
+def repeat(string: Expr, n: Expr | int) -> Expr:
+    """Repeats the ``string`` to ``n`` times.
+
+    Examples:
+        >>> ctx = dfn.SessionContext()
+        >>> df = ctx.from_pydict({"a": ["ha"]})
+        >>> result = df.select(
+        ...     dfn.functions.repeat(dfn.col("a"), 3).alias("r"))
+        >>> result.collect_column("r")[0].as_py()
+        'hahaha'
+    """
+    if not isinstance(n, Expr):
+        n = Expr.literal(n)
+    return Expr(f.repeat(string.expr, n.expr))
+```
+
+### Category B: Arguments That Should ONLY Accept Specific Native Types
+
+These are arguments where an `Expr` never makes sense because the value must be a fixed literal known at query-planning time (not a per-row value). For these, accept only the native type(s) and wrap internally.
+
+**Type hint pattern:** `str`, `int`, `list[str]`, etc. (no `Expr` in the union)
+
+**When to use:** When the argument is from a fixed enumeration or is always a compile-time constant, **AND** the parameter was not previously typed as `Expr`:
+- Separator in `concat_ws` (already typed as `str` in the Rust binding)
+- Index in `array_position` (already typed as `int` in the Rust binding)
+- Values that the Rust implementation already accepts as native types
+
+**Backward compatibility rule:** If a parameter was previously typed as `Expr`, you **must** keep `Expr` in the union even if the Rust side requires a literal. Removing `Expr` would break existing user code like `date_part(lit("year"), col("a"))`. Use **Category A** instead — accept `Expr | str` — and let users who pass column expressions discover the runtime error from the Rust side. Never silently break backward compatibility.
+
+```python
+def concat_ws(separator: str, *args: Expr) -> Expr:
+    """Concatenates the list ``args`` with the separator.
+
+    ``separator`` is already typed as ``str`` in the Rust binding, so
+    there is no backward-compatibility concern.
+
+    Examples:
+        >>> ctx = dfn.SessionContext()
+        >>> df = ctx.from_pydict({"a": ["hello"], "b": ["world"]})
+        >>> result = df.select(
+        ...
dfn.functions.concat_ws("-", dfn.col("a"), dfn.col("b")).alias("c")) + >>> result.collect_column("c")[0].as_py() + 'hello-world' + """ + args = [arg.expr for arg in args] + return Expr(f.concat_ws(separator, args)) +``` + +### Category C: Arguments That Should Accept str as Column Name + +In some contexts a string argument naturally refers to a column name rather than a literal. This is the pattern used by DataFrame methods. + +**Type hint pattern:** `Expr | str` + +**When to use:** Only when the string contextually means a column name (rare in `functions.py`, more common in DataFrame methods). + +```python +# Use _to_raw_expr() from expr.py for this pattern +from datafusion.expr import _to_raw_expr + +def some_function(column: Expr | str) -> Expr: + raw = _to_raw_expr(column) # str -> col(str) + return Expr(f.some_function(raw)) +``` + +**IMPORTANT:** In `functions.py`, string arguments almost never mean column names. Functions operate on expressions, and column references should use `col()`. Category C applies mainly to DataFrame methods and context APIs, not to scalar/aggregate/window functions. Do NOT convert string arguments to column expressions in `functions.py` unless there is a very clear reason to do so. + +## Implementation Steps + +For each function being updated: + +### Step 1: Analyze the Function + +1. Read the current Python function signature in `python/datafusion/functions.py` +2. Read the Rust binding in `crates/core/src/functions.rs` +3. Optionally check the upstream DataFusion docs for the function +4. Determine which category (A, B, or C) applies to each parameter + +### Step 2: Update the Python Function + +1. **Change the type hints** to accept native types (e.g., `Expr` -> `Expr | int`) +2. **Add coercion logic** at the top of the function body +3. **Update the docstring** examples to use the simpler calling convention +4. **Preserve backward compatibility** — existing code using `Expr` must still work + +### Step 3: Update Alias Type Hints + +After updating a primary function, find all alias functions that delegate to it (e.g., `instr` and `position` delegate to `strpos`). Update each alias's **parameter type hints** to match the primary function's new signature. Do not add coercion logic to aliases — the primary function handles that. + +### Step 4: Update Docstring Examples (primary functions only) + +Per the project's CLAUDE.md rules: +- Every function must have doctest-style examples +- Optional parameters need examples both without and with the optional args, using keyword argument syntax +- Reuse the same input data across examples where possible + +**Update examples to demonstrate the pythonic calling convention:** + +```python +# BEFORE (old style - still works but verbose) +dfn.functions.left(dfn.col("a"), dfn.lit(3)) + +# AFTER (new style - shown in examples) +dfn.functions.left(dfn.col("a"), 3) +``` + +### Step 5: Run Tests + +After making changes, run the doctests to verify: +```bash +python -m pytest --doctest-modules python/datafusion/functions.py -v +``` + +## Coercion Helper Pattern + +Use the coercion helpers from `datafusion.expr` to convert native Python values to `Expr`. These are the complement of `ensure_expr()` — where `ensure_expr` *rejects* non-`Expr` values, the coercion helpers *wrap* them via `Expr.literal()`. 
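As a quick illustration of that contrast, the snippet below only restates the behavior documented for these helpers in this patch (a sketch, not an additional requirement):

```python
from datafusion import col
from datafusion.expr import Expr, coerce_to_expr, coerce_to_expr_or_none

e = col("a")
assert coerce_to_expr(e) is e                # Expr instances pass through as-is
assert isinstance(coerce_to_expr(3), Expr)   # native values are wrapped via Expr.literal
assert coerce_to_expr_or_none(None) is None  # None is preserved for optional parameters
assert isinstance(coerce_to_expr_or_none("g"), Expr)
```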
+ +**For required parameters** use `coerce_to_expr`: + +```python +from datafusion.expr import coerce_to_expr + +def left(string: Expr, n: Expr | int) -> Expr: + n = coerce_to_expr(n) + return Expr(f.left(string.expr, n.expr)) +``` + +**For optional nullable parameters** use `coerce_to_expr_or_none`: + +```python +from datafusion.expr import coerce_to_expr, coerce_to_expr_or_none + +def regexp_count( + string: Expr, + pattern: Expr | str, + start: Expr | int | None = None, + flags: Expr | str | None = None, +) -> Expr: + pattern = coerce_to_expr(pattern) + start = coerce_to_expr_or_none(start) + flags = coerce_to_expr_or_none(flags) + return Expr( + f.regexp_count( + string.expr, + pattern.expr, + start.expr if start is not None else None, + flags.expr if flags is not None else None, + ) + ) +``` + +Both helpers are defined in `python/datafusion/expr.py` alongside `ensure_expr`. Import them in `functions.py` via: + +```python +from datafusion.expr import coerce_to_expr, coerce_to_expr_or_none +``` + +## What NOT to Change + +- **Do not change arguments that represent data columns.** If an argument is the primary data being operated on (e.g., the `string` in `left(string, n)` or the `array` in `array_sort(array)`), it should remain `Expr` only. Users should use `col()` for column references. +- **Do not change variadic `*args: Expr` parameters.** These represent multiple expressions and should stay as `Expr`. +- **Do not change arguments where the coercion is ambiguous.** If it is unclear whether a string should be a column name or a literal, leave it as `Expr` and let the user be explicit. +- **Do not add coercion logic to simple aliases.** If a function is just `return other_function(...)`, the primary function handles coercion. However, you **must update the alias's type hints** to match the primary function's signature so that type checkers and documentation accurately reflect what the alias accepts. +- **Do not change the Rust bindings.** All coercion happens in the Python layer. The Rust functions continue to accept `PyExpr`. + +## Priority Order + +When auditing functions, process them in this order: + +1. **Date/time functions** — `date_part`, `date_trunc`, `date_bin` — these have the clearest literal arguments +2. **String functions** — `left`, `right`, `repeat`, `lpad`, `rpad`, `split_part`, `substring`, `replace`, `regexp_replace`, `regexp_match`, `regexp_count` — common and verbose without coercion +3. **Math functions** — `round`, `trunc`, `power` — numeric literal arguments +4. **Array functions** — `array_slice`, `array_position`, `array_remove_n`, `array_replace_n`, `array_resize`, `array_element` — index and count arguments +5. **Other functions** — any remaining functions with literal arguments + +## Output Format + +For each function analyzed, report: + +``` +## [Function Name] + +**Current signature:** `function(arg1: Expr, arg2: Expr) -> Expr` +**Proposed signature:** `function(arg1: Expr, arg2: Expr | int) -> Expr` +**Category:** A (accepts native + Expr) +**Arguments changed:** +- `arg2`: Expr -> Expr | int (always a literal count) +**Rust binding:** Takes PyExpr, wraps to literal internally +**Status:** [Changed / Skipped / Needs Discussion] +``` + +If asked to implement (not just audit), make the changes directly and show a summary of what was updated. 
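To make the expected report concrete, a filled-in example for `repeat` (based on the Category A walkthrough above; the wording is illustrative, not prescriptive) might look like:

```
## repeat

**Current signature:** `repeat(string: Expr, n: Expr) -> Expr`
**Proposed signature:** `repeat(string: Expr, n: Expr | int) -> Expr`
**Category:** A (accepts native + Expr)
**Arguments changed:**
- `n`: Expr -> Expr | int (contextually always a literal repetition count)
**Rust binding:** Takes PyExpr; the upstream Signature coerces arg 2 to Int64
**Status:** Changed
```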
diff --git a/pyproject.toml b/pyproject.toml index 327199d1a..951f7adc3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,6 +111,7 @@ extend-allowed-calls = ["datafusion.lit", "lit"] "ARG", "BLE001", "D", + "FBT003", "PD", "PLC0415", "PLR0913", diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 1ff6976f7..0f7f3ab5a 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -243,6 +243,8 @@ "WindowExpr", "WindowFrame", "WindowFrameBound", + "coerce_to_expr", + "coerce_to_expr_or_none", "ensure_expr", "ensure_expr_list", ] @@ -255,6 +257,10 @@ def ensure_expr(value: Expr | Any) -> expr_internal.Expr: higher level APIs consistently require explicit :func:`~datafusion.col` or :func:`~datafusion.lit` expressions. + See Also: + :func:`coerce_to_expr` — the opposite behavior: *wraps* non-``Expr`` + values as literals instead of rejecting them. + Args: value: Candidate expression or other object. @@ -299,6 +305,41 @@ def _iter( return list(_iter(exprs)) +def coerce_to_expr(value: Any) -> Expr: + """Coerce a native Python value to an ``Expr`` literal, passing ``Expr`` through. + + This is the complement of :func:`ensure_expr`: where ``ensure_expr`` + *rejects* non-``Expr`` values, ``coerce_to_expr`` *wraps* them via + :meth:`Expr.literal` so that functions can accept native Python types + (``int``, ``float``, ``str``, ``bool``, etc.) alongside ``Expr``. + + Args: + value: An ``Expr`` instance (returned as-is) or a Python literal to wrap. + + Returns: + An ``Expr`` representing the value. + """ + if isinstance(value, Expr): + return value + return Expr.literal(value) + + +def coerce_to_expr_or_none(value: Any | None) -> Expr | None: + """Coerce a value to ``Expr`` or pass ``None`` through unchanged. + + Same as :func:`coerce_to_expr` but accepts ``None`` for optional parameters. + + Args: + value: An ``Expr`` instance, a Python literal to wrap, or ``None``. + + Returns: + An ``Expr`` representing the value, or ``None``. + """ + if value is None: + return None + return coerce_to_expr(value) + + def _to_raw_expr(value: Expr | str) -> expr_internal.Expr: """Convert a Python expression or column name to its raw variant. diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 280a6d3ac..08062851a 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -49,6 +49,8 @@ Expr, SortExpr, SortKey, + coerce_to_expr, + coerce_to_expr_or_none, expr_list_to_raw_expr_list, sort_list_to_raw_sort_list, sort_or_default, @@ -383,49 +385,52 @@ def nullif(expr1: Expr, expr2: Expr) -> Expr: return Expr(f.nullif(expr1.expr, expr2.expr)) -def encode(expr: Expr, encoding: Expr) -> Expr: +def encode(expr: Expr, encoding: Expr | str) -> Expr: """Encode the ``input``, using the ``encoding``. encoding can be base64 or hex. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.encode(dfn.col("a"), dfn.lit("base64")).alias("enc")) + ... dfn.functions.encode(dfn.col("a"), "base64").alias("enc")) >>> result.collect_column("enc")[0].as_py() 'aGVsbG8' """ + encoding = coerce_to_expr(encoding) return Expr(f.encode(expr.expr, encoding.expr)) -def decode(expr: Expr, encoding: Expr) -> Expr: +def decode(expr: Expr, encoding: Expr | str) -> Expr: """Decode the ``input``, using the ``encoding``. encoding can be base64 or hex. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["aGVsbG8="]}) >>> result = df.select( - ... 
dfn.functions.decode(dfn.col("a"), dfn.lit("base64")).alias("dec")) + ... dfn.functions.decode(dfn.col("a"), "base64").alias("dec")) >>> result.collect_column("dec")[0].as_py() b'hello' """ + encoding = coerce_to_expr(encoding) return Expr(f.decode(expr.expr, encoding.expr)) -def array_to_string(expr: Expr, delimiter: Expr) -> Expr: +def array_to_string(expr: Expr, delimiter: Expr | str) -> Expr: """Converts each element to its text representation. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2, 3]]}) >>> result = df.select( - ... dfn.functions.array_to_string(dfn.col("a"), dfn.lit(",")).alias("s")) + ... dfn.functions.array_to_string(dfn.col("a"), ",").alias("s")) >>> result.collect_column("s")[0].as_py() '1,2,3' """ + delimiter = coerce_to_expr(delimiter) return Expr(f.array_to_string(expr.expr, delimiter.expr.cast(pa.string()))) -def array_join(expr: Expr, delimiter: Expr) -> Expr: +def array_join(expr: Expr, delimiter: Expr | str) -> Expr: """Converts each element to its text representation. See Also: @@ -434,7 +439,7 @@ def array_join(expr: Expr, delimiter: Expr) -> Expr: return array_to_string(expr, delimiter) -def list_to_string(expr: Expr, delimiter: Expr) -> Expr: +def list_to_string(expr: Expr, delimiter: Expr | str) -> Expr: """Converts each element to its text representation. See Also: @@ -443,7 +448,7 @@ def list_to_string(expr: Expr, delimiter: Expr) -> Expr: return array_to_string(expr, delimiter) -def list_join(expr: Expr, delimiter: Expr) -> Expr: +def list_join(expr: Expr, delimiter: Expr | str) -> Expr: """Converts each element to its text representation. See Also: @@ -479,7 +484,7 @@ def in_list(arg: Expr, values: list[Expr], negated: bool = False) -> Expr: return Expr(f.in_list(arg.expr, values, negated)) -def digest(value: Expr, method: Expr) -> Expr: +def digest(value: Expr, method: Expr | str) -> Expr: """Computes the binary hash of an expression using the specified algorithm. Standard algorithms are md5, sha224, sha256, sha384, sha512, blake2s, @@ -489,24 +494,26 @@ def digest(value: Expr, method: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.digest(dfn.col("a"), dfn.lit("md5")).alias("d")) + ... dfn.functions.digest(dfn.col("a"), "md5").alias("d")) >>> len(result.collect_column("d")[0].as_py()) > 0 True """ + method = coerce_to_expr(method) return Expr(f.digest(value.expr, method.expr)) -def contains(string: Expr, search_str: Expr) -> Expr: +def contains(string: Expr, search_str: Expr | str) -> Expr: """Returns true if ``search_str`` is found within ``string`` (case-sensitive). Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["the quick brown fox"]}) >>> result = df.select( - ... dfn.functions.contains(dfn.col("a"), dfn.lit("brown")).alias("c")) + ... dfn.functions.contains(dfn.col("a"), "brown").alias("c")) >>> result.collect_column("c")[0].as_py() True """ + search_str = coerce_to_expr(search_str) return Expr(f.contains(string.expr, search_str.expr)) @@ -969,17 +976,18 @@ def degrees(arg: Expr) -> Expr: return Expr(f.degrees(arg.expr)) -def ends_with(arg: Expr, suffix: Expr) -> Expr: +def ends_with(arg: Expr, suffix: Expr | str) -> Expr: """Returns true if the ``string`` ends with the ``suffix``, false otherwise. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["abc","b","c"]}) >>> ends_with_df = df.select( - ... dfn.functions.ends_with(dfn.col("a"), dfn.lit("c")).alias("ends_with")) + ... 
dfn.functions.ends_with(dfn.col("a"), "c").alias("ends_with")) >>> ends_with_df.collect_column("ends_with")[0].as_py() True """ + suffix = coerce_to_expr(suffix) return Expr(f.ends_with(arg.expr, suffix.expr)) @@ -1011,7 +1019,7 @@ def factorial(arg: Expr) -> Expr: return Expr(f.factorial(arg.expr)) -def find_in_set(string: Expr, string_list: Expr) -> Expr: +def find_in_set(string: Expr, string_list: Expr | str) -> Expr: """Find a string in a list of strings. Returns a value in the range of 1 to N if the string is in the string list @@ -1023,10 +1031,11 @@ def find_in_set(string: Expr, string_list: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["b"]}) >>> result = df.select( - ... dfn.functions.find_in_set(dfn.col("a"), dfn.lit("a,b,c")).alias("pos")) + ... dfn.functions.find_in_set(dfn.col("a"), "a,b,c").alias("pos")) >>> result.collect_column("pos")[0].as_py() 2 """ + string_list = coerce_to_expr(string_list) return Expr(f.find_in_set(string.expr, string_list.expr)) @@ -1102,7 +1111,7 @@ def initcap(string: Expr) -> Expr: return Expr(f.initcap(string.expr)) -def instr(string: Expr, substring: Expr) -> Expr: +def instr(string: Expr, substring: Expr | str) -> Expr: """Finds the position from where the ``substring`` matches the ``string``. See Also: @@ -1158,31 +1167,33 @@ def least(*args: Expr) -> Expr: return Expr(f.least(*exprs)) -def left(string: Expr, n: Expr) -> Expr: +def left(string: Expr, n: Expr | int) -> Expr: """Returns the first ``n`` characters in the ``string``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["the cat"]}) >>> left_df = df.select( - ... dfn.functions.left(dfn.col("a"), dfn.lit(3)).alias("left")) + ... dfn.functions.left(dfn.col("a"), 3).alias("left")) >>> left_df.collect_column("left")[0].as_py() 'the' """ + n = coerce_to_expr(n) return Expr(f.left(string.expr, n.expr)) -def levenshtein(string1: Expr, string2: Expr) -> Expr: +def levenshtein(string1: Expr, string2: Expr | str) -> Expr: """Returns the Levenshtein distance between the two given strings. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["kitten"]}) >>> result = df.select( - ... dfn.functions.levenshtein(dfn.col("a"), dfn.lit("sitting")).alias("d")) + ... dfn.functions.levenshtein(dfn.col("a"), "sitting").alias("d")) >>> result.collect_column("d")[0].as_py() 3 """ + string2 = coerce_to_expr(string2) return Expr(f.levenshtein(string1.expr, string2.expr)) @@ -1199,18 +1210,19 @@ def ln(arg: Expr) -> Expr: return Expr(f.ln(arg.expr)) -def log(base: Expr, num: Expr) -> Expr: +def log(base: Expr | int | float, num: Expr) -> Expr: # noqa: PYI041 """Returns the logarithm of a number for a particular ``base``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [100.0]}) >>> result = df.select( - ... dfn.functions.log(dfn.lit(10.0), dfn.col("a")).alias("log") + ... dfn.functions.log(10.0, dfn.col("a")).alias("log") ... ) >>> result.collect_column("log")[0].as_py() 2.0 """ + base = coerce_to_expr(base) return Expr(f.log(base.expr, num.expr)) @@ -1253,7 +1265,7 @@ def lower(arg: Expr) -> Expr: return Expr(f.lower(arg.expr)) -def lpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: +def lpad(string: Expr, count: Expr | int, characters: Expr | str | None = None) -> Expr: """Add left padding to a string. 
Extends the string to length length by prepending the characters fill (a @@ -1264,9 +1276,7 @@ def lpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["the cat", "a hat"]}) >>> lpad_df = df.select( - ... dfn.functions.lpad( - ... dfn.col("a"), dfn.lit(6) - ... ).alias("lpad")) + ... dfn.functions.lpad(dfn.col("a"), 6).alias("lpad")) >>> lpad_df.collect_column("lpad")[0].as_py() 'the ca' >>> lpad_df.collect_column("lpad")[1].as_py() @@ -1274,12 +1284,13 @@ def lpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: >>> result = df.select( ... dfn.functions.lpad( - ... dfn.col("a"), dfn.lit(10), characters=dfn.lit(".") + ... dfn.col("a"), 10, characters="." ... ).alias("lpad")) >>> result.collect_column("lpad")[0].as_py() '...the cat' """ - characters = characters if characters is not None else Expr.literal(" ") + count = coerce_to_expr(count) + characters = coerce_to_expr(characters if characters is not None else " ") return Expr(f.lpad(string.expr, count.expr, characters.expr)) @@ -1374,7 +1385,10 @@ def octet_length(arg: Expr) -> Expr: def overlay( - string: Expr, substring: Expr, start: Expr, length: Expr | None = None + string: Expr, + substring: Expr | str, + start: Expr | int, + length: Expr | int | None = None, ) -> Expr: """Replace a substring with a new substring. @@ -1385,13 +1399,15 @@ def overlay( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["abcdef"]}) >>> result = df.select( - ... dfn.functions.overlay(dfn.col("a"), dfn.lit("XY"), dfn.lit(3), - ... dfn.lit(2)).alias("o")) + ... dfn.functions.overlay(dfn.col("a"), "XY", 3, 2).alias("o")) >>> result.collect_column("o")[0].as_py() 'abXYef' """ + substring = coerce_to_expr(substring) + start = coerce_to_expr(start) if length is None: return Expr(f.overlay(string.expr, substring.expr, start.expr)) + length = coerce_to_expr(length) return Expr(f.overlay(string.expr, substring.expr, start.expr, length.expr)) @@ -1411,7 +1427,7 @@ def pi() -> Expr: return Expr(f.pi()) -def position(string: Expr, substring: Expr) -> Expr: +def position(string: Expr, substring: Expr | str) -> Expr: """Finds the position from where the ``substring`` matches the ``string``. See Also: @@ -1420,22 +1436,23 @@ def position(string: Expr, substring: Expr) -> Expr: return strpos(string, substring) -def power(base: Expr, exponent: Expr) -> Expr: +def power(base: Expr, exponent: Expr | int | float) -> Expr: # noqa: PYI041 """Returns ``base`` raised to the power of ``exponent``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [2.0]}) >>> result = df.select( - ... dfn.functions.power(dfn.col("a"), dfn.lit(3.0)).alias("pow") + ... dfn.functions.power(dfn.col("a"), 3.0).alias("pow") ... ) >>> result.collect_column("pow")[0].as_py() 8.0 """ + exponent = coerce_to_expr(exponent) return Expr(f.power(base.expr, exponent.expr)) -def pow(base: Expr, exponent: Expr) -> Expr: +def pow(base: Expr, exponent: Expr | int | float) -> Expr: # noqa: PYI041 """Returns ``base`` raised to the power of ``exponent``. See Also: @@ -1460,7 +1477,9 @@ def radians(arg: Expr) -> Expr: return Expr(f.radians(arg.expr)) -def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: +def regexp_like( + string: Expr, regex: Expr | str, flags: Expr | str | None = None +) -> Expr: r"""Find if any regular expression (regex) matches exist. 
Tests a string using a regular expression returning true if at least one match, @@ -1470,9 +1489,7 @@ def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello123"]}) >>> result = df.select( - ... dfn.functions.regexp_like( - ... dfn.col("a"), dfn.lit("\\d+") - ... ).alias("m") + ... dfn.functions.regexp_like(dfn.col("a"), "\\d+").alias("m") ... ) >>> result.collect_column("m")[0].as_py() True @@ -1481,19 +1498,24 @@ def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> result = df.select( ... dfn.functions.regexp_like( - ... dfn.col("a"), dfn.lit("HELLO"), - ... flags=dfn.lit("i"), + ... dfn.col("a"), "HELLO", flags="i", ... ).alias("m") ... ) >>> result.collect_column("m")[0].as_py() True """ - if flags is not None: - flags = flags.expr - return Expr(f.regexp_like(string.expr, regex.expr, flags)) + regex = coerce_to_expr(regex) + flags = coerce_to_expr_or_none(flags) + return Expr( + f.regexp_like( + string.expr, regex.expr, flags.expr if flags is not None else None + ) + ) -def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: +def regexp_match( + string: Expr, regex: Expr | str, flags: Expr | str | None = None +) -> Expr: r"""Perform regular expression (regex) matching. Returns an array with each element containing the leftmost-first match of the @@ -1503,9 +1525,7 @@ def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello 42 world"]}) >>> result = df.select( - ... dfn.functions.regexp_match( - ... dfn.col("a"), dfn.lit("(\\d+)") - ... ).alias("m") + ... dfn.functions.regexp_match(dfn.col("a"), "(\\d+)").alias("m") ... ) >>> result.collect_column("m")[0].as_py() ['42'] @@ -1514,20 +1534,26 @@ def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> result = df.select( ... dfn.functions.regexp_match( - ... dfn.col("a"), dfn.lit("(HELLO)"), - ... flags=dfn.lit("i"), + ... dfn.col("a"), "(HELLO)", flags="i", ... ).alias("m") ... ) >>> result.collect_column("m")[0].as_py() ['hello'] """ - if flags is not None: - flags = flags.expr - return Expr(f.regexp_match(string.expr, regex.expr, flags)) + regex = coerce_to_expr(regex) + flags = coerce_to_expr_or_none(flags) + return Expr( + f.regexp_match( + string.expr, regex.expr, flags.expr if flags is not None else None + ) + ) def regexp_replace( - string: Expr, pattern: Expr, replacement: Expr, flags: Expr | None = None + string: Expr, + pattern: Expr | str, + replacement: Expr | str, + flags: Expr | str | None = None, ) -> Expr: r"""Replaces substring(s) matching a PCRE-like regular expression. @@ -1542,8 +1568,7 @@ def regexp_replace( >>> df = ctx.from_pydict({"a": ["hello 42"]}) >>> result = df.select( ... dfn.functions.regexp_replace( - ... dfn.col("a"), dfn.lit("\\d+"), - ... dfn.lit("XX") + ... dfn.col("a"), "\\d+", "XX" ... ).alias("r") ... ) >>> result.collect_column("r")[0].as_py() @@ -1554,20 +1579,30 @@ def regexp_replace( >>> df = ctx.from_pydict({"a": ["a1 b2 c3"]}) >>> result = df.select( ... dfn.functions.regexp_replace( - ... dfn.col("a"), dfn.lit("\\d+"), - ... dfn.lit("X"), flags=dfn.lit("g"), + ... dfn.col("a"), "\\d+", "X", flags="g", ... ).alias("r") ... 
) >>> result.collect_column("r")[0].as_py() 'aX bX cX' """ - if flags is not None: - flags = flags.expr - return Expr(f.regexp_replace(string.expr, pattern.expr, replacement.expr, flags)) + pattern = coerce_to_expr(pattern) + replacement = coerce_to_expr(replacement) + flags = coerce_to_expr_or_none(flags) + return Expr( + f.regexp_replace( + string.expr, + pattern.expr, + replacement.expr, + flags.expr if flags is not None else None, + ) + ) def regexp_count( - string: Expr, pattern: Expr, start: Expr | None = None, flags: Expr | None = None + string: Expr, + pattern: Expr | str, + start: Expr | int | None = None, + flags: Expr | str | None = None, ) -> Expr: """Returns the number of matches in a string. @@ -1578,9 +1613,7 @@ def regexp_count( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["abcabc"]}) >>> result = df.select( - ... dfn.functions.regexp_count( - ... dfn.col("a"), dfn.lit("abc") - ... ).alias("c")) + ... dfn.functions.regexp_count(dfn.col("a"), "abc").alias("c")) >>> result.collect_column("c")[0].as_py() 2 @@ -1589,25 +1622,31 @@ def regexp_count( >>> result = df.select( ... dfn.functions.regexp_count( - ... dfn.col("a"), dfn.lit("ABC"), - ... start=dfn.lit(4), flags=dfn.lit("i"), + ... dfn.col("a"), "ABC", start=4, flags="i", ... ).alias("c")) >>> result.collect_column("c")[0].as_py() 1 """ - if flags is not None: - flags = flags.expr - start = start.expr if start is not None else start - return Expr(f.regexp_count(string.expr, pattern.expr, start, flags)) + pattern = coerce_to_expr(pattern) + start = coerce_to_expr_or_none(start) + flags = coerce_to_expr_or_none(flags) + return Expr( + f.regexp_count( + string.expr, + pattern.expr, + start.expr if start is not None else None, + flags.expr if flags is not None else None, + ) + ) def regexp_instr( values: Expr, - regex: Expr, - start: Expr | None = None, - n: Expr | None = None, - flags: Expr | None = None, - sub_expr: Expr | None = None, + regex: Expr | str, + start: Expr | int | None = None, + n: Expr | int | None = None, + flags: Expr | str | None = None, + sub_expr: Expr | int | None = None, ) -> Expr: r"""Returns the position of a regular expression match in a string. @@ -1623,9 +1662,7 @@ def regexp_instr( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello 42 world"]}) >>> result = df.select( - ... dfn.functions.regexp_instr( - ... dfn.col("a"), dfn.lit("\\d+") - ... ).alias("pos") + ... dfn.functions.regexp_instr(dfn.col("a"), "\\d+").alias("pos") ... ) >>> result.collect_column("pos")[0].as_py() 7 @@ -1636,9 +1673,8 @@ def regexp_instr( >>> df = ctx.from_pydict({"a": ["abc ABC abc"]}) >>> result = df.select( ... dfn.functions.regexp_instr( - ... dfn.col("a"), dfn.lit("abc"), - ... start=dfn.lit(2), n=dfn.lit(1), - ... flags=dfn.lit("i"), + ... dfn.col("a"), "abc", + ... start=2, n=1, flags="i", ... ).alias("pos") ... ) >>> result.collect_column("pos")[0].as_py() @@ -1648,56 +1684,58 @@ def regexp_instr( >>> result = df.select( ... dfn.functions.regexp_instr( - ... dfn.col("a"), dfn.lit("(abc)"), - ... sub_expr=dfn.lit(1), + ... dfn.col("a"), "(abc)", sub_expr=1, ... ).alias("pos") ... 
) >>> result.collect_column("pos")[0].as_py() 1 """ - start = start.expr if start is not None else None - n = n.expr if n is not None else None - flags = flags.expr if flags is not None else None - sub_expr = sub_expr.expr if sub_expr is not None else None + regex = coerce_to_expr(regex) + start = coerce_to_expr_or_none(start) + n = coerce_to_expr_or_none(n) + flags = coerce_to_expr_or_none(flags) + sub_expr = coerce_to_expr_or_none(sub_expr) return Expr( f.regexp_instr( values.expr, regex.expr, - start, - n, - flags, - sub_expr, + start.expr if start is not None else None, + n.expr if n is not None else None, + flags.expr if flags is not None else None, + sub_expr.expr if sub_expr is not None else None, ) ) -def repeat(string: Expr, n: Expr) -> Expr: +def repeat(string: Expr, n: Expr | int) -> Expr: """Repeats the ``string`` to ``n`` times. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["ha"]}) >>> result = df.select( - ... dfn.functions.repeat(dfn.col("a"), dfn.lit(3)).alias("r")) + ... dfn.functions.repeat(dfn.col("a"), 3).alias("r")) >>> result.collect_column("r")[0].as_py() 'hahaha' """ + n = coerce_to_expr(n) return Expr(f.repeat(string.expr, n.expr)) -def replace(string: Expr, from_val: Expr, to_val: Expr) -> Expr: +def replace(string: Expr, from_val: Expr | str, to_val: Expr | str) -> Expr: """Replaces all occurrences of ``from_val`` with ``to_val`` in the ``string``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello world"]}) >>> result = df.select( - ... dfn.functions.replace(dfn.col("a"), dfn.lit("world"), - ... dfn.lit("there")).alias("r")) + ... dfn.functions.replace(dfn.col("a"), "world", "there").alias("r")) >>> result.collect_column("r")[0].as_py() 'hello there' """ + from_val = coerce_to_expr(from_val) + to_val = coerce_to_expr(to_val) return Expr(f.replace(string.expr, from_val.expr, to_val.expr)) @@ -1714,39 +1752,39 @@ def reverse(arg: Expr) -> Expr: return Expr(f.reverse(arg.expr)) -def right(string: Expr, n: Expr) -> Expr: +def right(string: Expr, n: Expr | int) -> Expr: """Returns the last ``n`` characters in the ``string``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) - >>> result = df.select(dfn.functions.right(dfn.col("a"), dfn.lit(3)).alias("r")) + >>> result = df.select(dfn.functions.right(dfn.col("a"), 3).alias("r")) >>> result.collect_column("r")[0].as_py() 'llo' """ + n = coerce_to_expr(n) return Expr(f.right(string.expr, n.expr)) -def round(value: Expr, decimal_places: Expr | None = None) -> Expr: +def round(value: Expr, decimal_places: Expr | int | None = None) -> Expr: """Round the argument to the nearest integer. If the optional ``decimal_places`` is specified, round to the nearest number of decimal places. You can specify a negative number of decimal places. For example - ``round(lit(125.2345), lit(-2))`` would yield a value of ``100.0``. + ``round(lit(125.2345), -2)`` would yield a value of ``100.0``. 
Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [1.567]}) - >>> result = df.select(dfn.functions.round(dfn.col("a"), dfn.lit(2)).alias("r")) + >>> result = df.select(dfn.functions.round(dfn.col("a"), 2).alias("r")) >>> result.collect_column("r")[0].as_py() 1.57 """ - if decimal_places is None: - decimal_places = Expr.literal(0) + decimal_places = coerce_to_expr(decimal_places if decimal_places is not None else 0) return Expr(f.round(value.expr, decimal_places.expr)) -def rpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: +def rpad(string: Expr, count: Expr | int, characters: Expr | str | None = None) -> Expr: """Add right padding to a string. Extends the string to length length by appending the characters fill (a space @@ -1756,11 +1794,12 @@ def rpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hi"]}) >>> result = df.select( - ... dfn.functions.rpad(dfn.col("a"), dfn.lit(5), dfn.lit("!")).alias("r")) + ... dfn.functions.rpad(dfn.col("a"), 5, "!").alias("r")) >>> result.collect_column("r")[0].as_py() 'hi!!!' """ - characters = characters if characters is not None else Expr.literal(" ") + count = coerce_to_expr(count) + characters = coerce_to_expr(characters if characters is not None else " ") return Expr(f.rpad(string.expr, count.expr, characters.expr)) @@ -1876,7 +1915,7 @@ def sinh(arg: Expr) -> Expr: return Expr(f.sinh(arg.expr)) -def split_part(string: Expr, delimiter: Expr, index: Expr) -> Expr: +def split_part(string: Expr, delimiter: Expr | str, index: Expr | int) -> Expr: """Split a string and return one part. Splits a string based on a delimiter and picks out the desired field based @@ -1886,12 +1925,12 @@ def split_part(string: Expr, delimiter: Expr, index: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["a,b,c"]}) >>> result = df.select( - ... dfn.functions.split_part( - ... dfn.col("a"), dfn.lit(","), dfn.lit(2) - ... ).alias("s")) + ... dfn.functions.split_part(dfn.col("a"), ",", 2).alias("s")) >>> result.collect_column("s")[0].as_py() 'b' """ + delimiter = coerce_to_expr(delimiter) + index = coerce_to_expr(index) return Expr(f.split_part(string.expr, delimiter.expr, index.expr)) @@ -1908,49 +1947,52 @@ def sqrt(arg: Expr) -> Expr: return Expr(f.sqrt(arg.expr)) -def starts_with(string: Expr, prefix: Expr) -> Expr: +def starts_with(string: Expr, prefix: Expr | str) -> Expr: """Returns true if string starts with prefix. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello_from_datafusion"]}) >>> result = df.select( - ... dfn.functions.starts_with(dfn.col("a"), dfn.lit("hello")).alias("sw")) + ... dfn.functions.starts_with(dfn.col("a"), "hello").alias("sw")) >>> result.collect_column("sw")[0].as_py() True """ + prefix = coerce_to_expr(prefix) return Expr(f.starts_with(string.expr, prefix.expr)) -def strpos(string: Expr, substring: Expr) -> Expr: +def strpos(string: Expr, substring: Expr | str) -> Expr: """Finds the position from where the ``substring`` matches the ``string``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.strpos(dfn.col("a"), dfn.lit("llo")).alias("pos")) + ... 
dfn.functions.strpos(dfn.col("a"), "llo").alias("pos")) >>> result.collect_column("pos")[0].as_py() 3 """ + substring = coerce_to_expr(substring) return Expr(f.strpos(string.expr, substring.expr)) -def substr(string: Expr, position: Expr) -> Expr: +def substr(string: Expr, position: Expr | int) -> Expr: """Substring from the ``position`` to the end. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.substr(dfn.col("a"), dfn.lit(3)).alias("s")) + ... dfn.functions.substr(dfn.col("a"), 3).alias("s")) >>> result.collect_column("s")[0].as_py() 'llo' """ + position = coerce_to_expr(position) return Expr(f.substr(string.expr, position.expr)) -def substr_index(string: Expr, delimiter: Expr, count: Expr) -> Expr: +def substr_index(string: Expr, delimiter: Expr | str, count: Expr | int) -> Expr: """Returns an indexed substring. The return will be the ``string`` from before ``count`` occurrences of @@ -1960,27 +2002,28 @@ def substr_index(string: Expr, delimiter: Expr, count: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["a.b.c"]}) >>> result = df.select( - ... dfn.functions.substr_index(dfn.col("a"), dfn.lit("."), - ... dfn.lit(2)).alias("s")) + ... dfn.functions.substr_index(dfn.col("a"), ".", 2).alias("s")) >>> result.collect_column("s")[0].as_py() 'a.b' """ + delimiter = coerce_to_expr(delimiter) + count = coerce_to_expr(count) return Expr(f.substr_index(string.expr, delimiter.expr, count.expr)) -def substring(string: Expr, position: Expr, length: Expr) -> Expr: +def substring(string: Expr, position: Expr | int, length: Expr | int) -> Expr: """Substring from the ``position`` with ``length`` characters. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello world"]}) >>> result = df.select( - ... dfn.functions.substring( - ... dfn.col("a"), dfn.lit(1), dfn.lit(5) - ... ).alias("s")) + ... dfn.functions.substring(dfn.col("a"), 1, 5).alias("s")) >>> result.collect_column("s")[0].as_py() 'hello' """ + position = coerce_to_expr(position) + length = coerce_to_expr(length) return Expr(f.substring(string.expr, position.expr, length.expr)) @@ -2053,7 +2096,7 @@ def current_timestamp() -> Expr: return now() -def to_char(arg: Expr, formatter: Expr) -> Expr: +def to_char(arg: Expr, formatter: Expr | str) -> Expr: """Returns a string representation of a date, time, timestamp or duration. For usage of ``formatter`` see the rust chrono package ``strftime`` package. @@ -2066,16 +2109,17 @@ def to_char(arg: Expr, formatter: Expr) -> Expr: >>> result = df.select( ... dfn.functions.to_char( ... dfn.functions.to_timestamp(dfn.col("a")), - ... dfn.lit("%Y/%m/%d"), + ... "%Y/%m/%d", ... ).alias("formatted") ... ) >>> result.collect_column("formatted")[0].as_py() '2021/01/01' """ + formatter = coerce_to_expr(formatter) return Expr(f.to_char(arg.expr, formatter.expr)) -def date_format(arg: Expr, formatter: Expr) -> Expr: +def date_format(arg: Expr, formatter: Expr | str) -> Expr: """Returns a string representation of a date, time, timestamp or duration. See Also: @@ -2287,7 +2331,7 @@ def current_time() -> Expr: return Expr(f.current_time()) -def datepart(part: Expr, date: Expr) -> Expr: +def datepart(part: Expr | str, date: Expr) -> Expr: """Return a specified part of a date. 
See Also: @@ -2296,22 +2340,28 @@ def datepart(part: Expr, date: Expr) -> Expr: return date_part(part, date) -def date_part(part: Expr, date: Expr) -> Expr: +def date_part(part: Expr | str, date: Expr) -> Expr: """Extracts a subfield from the date. + Args: + part: The part of the date to extract. Must be one of ``"year"``, + ``"month"``, ``"day"``, ``"hour"``, ``"minute"``, ``"second"``, etc. + date: The date expression to extract from. + Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["2021-07-15T00:00:00"]}) >>> df = df.select(dfn.functions.to_timestamp(dfn.col("a")).alias("a")) >>> result = df.select( - ... dfn.functions.date_part(dfn.lit("year"), dfn.col("a")).alias("y")) + ... dfn.functions.date_part("year", dfn.col("a")).alias("y")) >>> result.collect_column("y")[0].as_py() 2021 """ + part = coerce_to_expr(part) return Expr(f.date_part(part.expr, date.expr)) -def extract(part: Expr, date: Expr) -> Expr: +def extract(part: Expr | str, date: Expr) -> Expr: """Extracts a subfield from the date. See Also: @@ -2320,25 +2370,29 @@ def extract(part: Expr, date: Expr) -> Expr: return date_part(part, date) -def date_trunc(part: Expr, date: Expr) -> Expr: +def date_trunc(part: Expr | str, date: Expr) -> Expr: """Truncates the date to a specified level of precision. + Args: + part: The precision to truncate to. Must be one of ``"year"``, + ``"month"``, ``"day"``, ``"hour"``, ``"minute"``, ``"second"``, etc. + date: The date expression to truncate. + Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["2021-07-15T12:34:56"]}) >>> df = df.select(dfn.functions.to_timestamp(dfn.col("a")).alias("a")) >>> result = df.select( - ... dfn.functions.date_trunc( - ... dfn.lit("month"), dfn.col("a") - ... ).alias("t") + ... dfn.functions.date_trunc("month", dfn.col("a")).alias("t") ... ) >>> str(result.collect_column("t")[0].as_py()) '2021-07-01 00:00:00' """ + part = coerce_to_expr(part) return Expr(f.date_trunc(part.expr, date.expr)) -def datetrunc(part: Expr, date: Expr) -> Expr: +def datetrunc(part: Expr | str, date: Expr) -> Expr: """Truncates the date to a specified level of precision. See Also: @@ -2399,18 +2453,19 @@ def make_time(hour: Expr, minute: Expr, second: Expr) -> Expr: return Expr(f.make_time(hour.expr, minute.expr, second.expr)) -def translate(string: Expr, from_val: Expr, to_val: Expr) -> Expr: +def translate(string: Expr, from_val: Expr | str, to_val: Expr | str) -> Expr: """Replaces the characters in ``from_val`` with the counterpart in ``to_val``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.translate(dfn.col("a"), dfn.lit("helo"), - ... dfn.lit("HELO")).alias("t")) + ... dfn.functions.translate(dfn.col("a"), "helo", "HELO").alias("t")) >>> result.collect_column("t")[0].as_py() 'HELLO' """ + from_val = coerce_to_expr(from_val) + to_val = coerce_to_expr(to_val) return Expr(f.translate(string.expr, from_val.expr, to_val.expr)) @@ -2427,27 +2482,24 @@ def trim(arg: Expr) -> Expr: return Expr(f.trim(arg.expr)) -def trunc(num: Expr, precision: Expr | None = None) -> Expr: +def trunc(num: Expr, precision: Expr | int | None = None) -> Expr: """Truncate the number toward zero with optional precision. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [1.567]}) >>> result = df.select( - ... dfn.functions.trunc( - ... dfn.col("a") - ... ).alias("t")) + ... 
dfn.functions.trunc(dfn.col("a")).alias("t")) >>> result.collect_column("t")[0].as_py() 1.0 >>> result = df.select( - ... dfn.functions.trunc( - ... dfn.col("a"), precision=dfn.lit(2) - ... ).alias("t")) + ... dfn.functions.trunc(dfn.col("a"), precision=2).alias("t")) >>> result.collect_column("t")[0].as_py() 1.56 """ if precision is not None: + precision = coerce_to_expr(precision) return Expr(f.trunc(num.expr, precision.expr)) return Expr(f.trunc(num.expr)) @@ -2928,17 +2980,18 @@ def list_dims(array: Expr) -> Expr: return array_dims(array) -def array_element(array: Expr, n: Expr) -> Expr: +def array_element(array: Expr, n: Expr | int) -> Expr: """Extracts the element with the index n from the array. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[10, 20, 30]]}) >>> result = df.select( - ... dfn.functions.array_element(dfn.col("a"), dfn.lit(2)).alias("result")) + ... dfn.functions.array_element(dfn.col("a"), 2).alias("result")) >>> result.collect_column("result")[0].as_py() 20 """ + n = coerce_to_expr(n) return Expr(f.array_element(array.expr, n.expr)) @@ -2964,7 +3017,7 @@ def list_empty(array: Expr) -> Expr: return array_empty(array) -def array_extract(array: Expr, n: Expr) -> Expr: +def array_extract(array: Expr, n: Expr | int) -> Expr: """Extracts the element with the index n from the array. See Also: @@ -2973,7 +3026,7 @@ def array_extract(array: Expr, n: Expr) -> Expr: return array_element(array, n) -def list_element(array: Expr, n: Expr) -> Expr: +def list_element(array: Expr, n: Expr | int) -> Expr: """Extracts the element with the index n from the array. See Also: @@ -2982,7 +3035,7 @@ def list_element(array: Expr, n: Expr) -> Expr: return array_element(array, n) -def list_extract(array: Expr, n: Expr) -> Expr: +def list_extract(array: Expr, n: Expr | int) -> Expr: """Extracts the element with the index n from the array. See Also: @@ -3332,22 +3385,24 @@ def list_remove(array: Expr, element: Expr) -> Expr: return array_remove(array, element) -def array_remove_n(array: Expr, element: Expr, max: Expr) -> Expr: +def array_remove_n(array: Expr, element: Expr, max: Expr | int) -> Expr: """Removes the first ``max`` elements from the array equal to the given value. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2, 1, 1]]}) >>> result = df.select( - ... dfn.functions.array_remove_n(dfn.col("a"), dfn.lit(1), - ... dfn.lit(2)).alias("result")) + ... dfn.functions.array_remove_n( + ... dfn.col("a"), dfn.lit(1), 2 + ... ).alias("result")) >>> result.collect_column("result")[0].as_py() [2, 1] """ + max = coerce_to_expr(max) return Expr(f.array_remove_n(array.expr, element.expr, max.expr)) -def list_remove_n(array: Expr, element: Expr, max: Expr) -> Expr: +def list_remove_n(array: Expr, element: Expr, max: Expr | int) -> Expr: """Removes the first ``max`` elements from the array equal to the given value. See Also: @@ -3381,21 +3436,22 @@ def list_remove_all(array: Expr, element: Expr) -> Expr: return array_remove_all(array, element) -def array_repeat(element: Expr, count: Expr) -> Expr: +def array_repeat(element: Expr, count: Expr | int) -> Expr: """Returns an array containing ``element`` ``count`` times. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [1]}) >>> result = df.select( - ... dfn.functions.array_repeat(dfn.lit(3), dfn.lit(3)).alias("result")) + ... 
dfn.functions.array_repeat(dfn.lit(3), 3).alias("result")) >>> result.collect_column("result")[0].as_py() [3, 3, 3] """ + count = coerce_to_expr(count) return Expr(f.array_repeat(element.expr, count.expr)) -def list_repeat(element: Expr, count: Expr) -> Expr: +def list_repeat(element: Expr, count: Expr | int) -> Expr: """Returns an array containing ``element`` ``count`` times. See Also: @@ -3428,7 +3484,7 @@ def list_replace(array: Expr, from_val: Expr, to_val: Expr) -> Expr: return array_replace(array, from_val, to_val) -def array_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Expr: +def array_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr | int) -> Expr: """Replace ``n`` occurrences of ``from_val`` with ``to_val``. Replaces the first ``max`` occurrences of the specified element with another @@ -3438,15 +3494,17 @@ def array_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Exp >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2, 1, 1]]}) >>> result = df.select( - ... dfn.functions.array_replace_n(dfn.col("a"), dfn.lit(1), dfn.lit(9), - ... dfn.lit(2)).alias("result")) + ... dfn.functions.array_replace_n( + ... dfn.col("a"), dfn.lit(1), dfn.lit(9), 2 + ... ).alias("result")) >>> result.collect_column("result")[0].as_py() [9, 2, 9, 1] """ + max = coerce_to_expr(max) return Expr(f.array_replace_n(array.expr, from_val.expr, to_val.expr, max.expr)) -def list_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Expr: +def list_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr | int) -> Expr: """Replace ``n`` occurrences of ``from_val`` with ``to_val``. Replaces the first ``max`` occurrences of the specified element with another @@ -3529,7 +3587,10 @@ def list_sort(array: Expr, descending: bool = False, null_first: bool = False) - def array_slice( - array: Expr, begin: Expr, end: Expr, stride: Expr | None = None + array: Expr, + begin: Expr | int, + end: Expr | int, + stride: Expr | int | None = None, ) -> Expr: """Returns a slice of the array. @@ -3537,9 +3598,7 @@ def array_slice( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2, 3, 4]]}) >>> result = df.select( - ... dfn.functions.array_slice( - ... dfn.col("a"), dfn.lit(2), dfn.lit(3) - ... ).alias("result")) + ... dfn.functions.array_slice(dfn.col("a"), 2, 3).alias("result")) >>> result.collect_column("result")[0].as_py() [2, 3] @@ -3547,18 +3606,27 @@ def array_slice( >>> result = df.select( ... dfn.functions.array_slice( - ... dfn.col("a"), dfn.lit(1), dfn.lit(4), - ... stride=dfn.lit(2), + ... dfn.col("a"), 1, 4, stride=2, ... ).alias("result")) >>> result.collect_column("result")[0].as_py() [1, 3] """ - if stride is not None: - stride = stride.expr - return Expr(f.array_slice(array.expr, begin.expr, end.expr, stride)) + begin = coerce_to_expr(begin) + end = coerce_to_expr(end) + stride = coerce_to_expr_or_none(stride) + return Expr( + f.array_slice( + array.expr, + begin.expr, + end.expr, + stride.expr if stride is not None else None, + ) + ) -def list_slice(array: Expr, begin: Expr, end: Expr, stride: Expr | None = None) -> Expr: +def list_slice( + array: Expr, begin: Expr | int, end: Expr | int, stride: Expr | int | None = None +) -> Expr: """Returns a slice of the array. 
See Also: @@ -3650,7 +3718,7 @@ def list_except(array1: Expr, array2: Expr) -> Expr: return array_except(array1, array2) -def array_resize(array: Expr, size: Expr, value: Expr) -> Expr: +def array_resize(array: Expr, size: Expr | int, value: Expr) -> Expr: """Returns an array with the specified size filled. If ``size`` is greater than the ``array`` length, the additional entries will @@ -3660,15 +3728,15 @@ def array_resize(array: Expr, size: Expr, value: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2]]}) >>> result = df.select( - ... dfn.functions.array_resize(dfn.col("a"), dfn.lit(4), - ... dfn.lit(0)).alias("result")) + ... dfn.functions.array_resize(dfn.col("a"), 4, dfn.lit(0)).alias("result")) >>> result.collect_column("result")[0].as_py() [1, 2, 0, 0] """ + size = coerce_to_expr(size) return Expr(f.array_resize(array.expr, size.expr, value.expr)) -def list_resize(array: Expr, size: Expr, value: Expr) -> Expr: +def list_resize(array: Expr, size: Expr | int, value: Expr) -> Expr: """Returns an array with the specified size filled. If ``size`` is greater than the ``array`` length, the additional entries will be @@ -3822,7 +3890,7 @@ def list_zip(*arrays: Expr) -> Expr: def string_to_array( - string: Expr, delimiter: Expr, null_string: Expr | None = None + string: Expr, delimiter: Expr | str, null_string: Expr | str | None = None ) -> Expr: """Splits a string based on a delimiter and returns an array of parts. @@ -3832,9 +3900,7 @@ def string_to_array( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello,world"]}) >>> result = df.select( - ... dfn.functions.string_to_array( - ... dfn.col("a"), dfn.lit(","), - ... ).alias("result")) + ... dfn.functions.string_to_array(dfn.col("a"), ",").alias("result")) >>> result.collect_column("result")[0].as_py() ['hello', 'world'] @@ -3842,17 +3908,24 @@ def string_to_array( >>> result = df.select( ... dfn.functions.string_to_array( - ... dfn.col("a"), dfn.lit(","), null_string=dfn.lit("world"), + ... dfn.col("a"), ",", null_string="world", ... ).alias("result")) >>> result.collect_column("result")[0].as_py() ['hello', None] """ - null_expr = null_string.expr if null_string is not None else None - return Expr(f.string_to_array(string.expr, delimiter.expr, null_expr)) + delimiter = coerce_to_expr(delimiter) + null_string = coerce_to_expr_or_none(null_string) + return Expr( + f.string_to_array( + string.expr, + delimiter.expr, + null_string.expr if null_string is not None else None, + ) + ) def string_to_list( - string: Expr, delimiter: Expr, null_string: Expr | None = None + string: Expr, delimiter: Expr | str, null_string: Expr | str | None = None ) -> Expr: """Splits a string based on a delimiter and returns an array of parts. 
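
The pattern repeated throughout the functions.py changes above is small and uniform: widen the accepted annotation to ``Expr | int`` (or ``Expr | str``), run the argument through ``coerce_to_expr`` / ``coerce_to_expr_or_none``, and only then unwrap ``.expr`` for the Rust binding. The helpers themselves are not shown in this hunk (they are imported alongside ``ensure_expr`` in test_expr.py below); the sketch that follows only illustrates the behavior the wrappers and new tests rely on, and is a rough approximation rather than the actual implementation:

    from __future__ import annotations

    from typing import Any

    from datafusion import lit
    from datafusion.expr import Expr


    def coerce_to_expr(value: Any) -> Expr:
        """Pass an Expr through unchanged; wrap any other Python value with lit()."""
        return value if isinstance(value, Expr) else lit(value)


    def coerce_to_expr_or_none(value: Any) -> Expr | None:
        """Like coerce_to_expr, but preserve None so optional arguments stay optional."""
        return None if value is None else coerce_to_expr(value)

Because plain values are wrapped rather than required, callers that already pass explicit ``lit()`` expressions keep working unchanged, which is what the backward-compatibility test added to test_functions.py below verifies.
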
diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 13c05a9e6..e0ebdbae5 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -964,12 +964,12 @@ def test_csv_read_options_builder_pattern(): options = ( CsvReadOptions() - .with_has_header(False) # noqa: FBT003 + .with_has_header(False) .with_delimiter("|") .with_quote("'") .with_schema_infer_max_records(2000) - .with_truncated_rows(True) # noqa: FBT003 - .with_newlines_in_values(True) # noqa: FBT003 + .with_truncated_rows(True) + .with_newlines_in_values(True) .with_file_extension(".tsv") ) assert options.has_header is False diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 091fa9b56..9e2f791ea 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -3426,10 +3426,18 @@ def test_fill_null_all_null_column(ctx): assert result.column(1).to_pylist() == ["filled", "filled", "filled"] +_slow_udf_started = threading.Event() + + @udf([pa.int64()], pa.int64(), "immutable") def slow_udf(x: pa.Array) -> pa.Array: - # This must be longer than the check interval in wait_for_future - time.sleep(2.0) + _slow_udf_started.set() + # Sleep in small increments so Python's eval loop checks for pending + # async exceptions (like KeyboardInterrupt via PyThreadState_SetAsyncExc) + # between iterations. A single long time.sleep() is a C call where async + # exceptions are not checked on all Python versions (notably 3.11). + for _ in range(200): + time.sleep(0.01) return x @@ -3463,6 +3471,7 @@ def test_collect_or_stream_interrupted(slow_query, as_c_stream): # noqa: C901 P if as_c_stream: reader = pa.RecordBatchReader.from_stream(df) + _slow_udf_started.clear() read_started = threading.Event() read_exception = [] read_thread_id = None @@ -3474,6 +3483,14 @@ def trigger_interrupt(): msg = f"Read operation did not start within {max_wait_time} seconds" raise RuntimeError(msg) + # For slow_query tests, wait until the UDF is actually executing Python + # bytecode before sending the interrupt. PyThreadState_SetAsyncExc only + # delivers exceptions when the thread is in the Python eval loop, not + # while in native (Rust/C) code. 
+ if slow_query and not _slow_udf_started.wait(timeout=max_wait_time): + msg = f"UDF did not start within {max_wait_time} seconds" + raise RuntimeError(msg) + if read_thread_id is None: msg = "Cannot get read thread ID" raise RuntimeError(msg) diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index d046eb48c..8aa791ae1 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -53,6 +53,8 @@ TransactionEnd, TransactionStart, Values, + coerce_to_expr, + coerce_to_expr_or_none, ensure_expr, ensure_expr_list, ) @@ -1030,12 +1032,55 @@ def test_ensure_expr_list_bytearray(): ensure_expr_list(bytearray(b"a")) +def test_coerce_to_expr_passes_expr_through(): + e = col("a") + result = coerce_to_expr(e) + assert isinstance(result, type(e)) + assert str(result) == str(e) + + +def test_coerce_to_expr_wraps_int(): + result = coerce_to_expr(42) + assert isinstance(result, type(lit(42))) + + +def test_coerce_to_expr_wraps_str(): + result = coerce_to_expr("hello") + assert isinstance(result, type(lit("hello"))) + + +def test_coerce_to_expr_wraps_float(): + result = coerce_to_expr(3.14) + assert isinstance(result, type(lit(3.14))) + + +def test_coerce_to_expr_wraps_bool(): + result = coerce_to_expr(True) + assert isinstance(result, type(lit(True))) + + +def test_coerce_to_expr_or_none_returns_none(): + assert coerce_to_expr_or_none(None) is None + + +def test_coerce_to_expr_or_none_wraps_value(): + result = coerce_to_expr_or_none(42) + assert isinstance(result, type(lit(42))) + + +def test_coerce_to_expr_or_none_passes_expr_through(): + e = col("a") + result = coerce_to_expr_or_none(e) + assert isinstance(result, type(e)) + assert str(result) == str(e) + + @pytest.mark.parametrize( "value", [ # Boolean - pa.scalar(True, type=pa.bool_()), # noqa: FBT003 - pa.scalar(False, type=pa.bool_()), # noqa: FBT003 + pa.scalar(True, type=pa.bool_()), + pa.scalar(False, type=pa.bool_()), # Integers - signed pa.scalar(127, type=pa.int8()), pa.scalar(-128, type=pa.int8()), diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 11e94af1c..d9781b1fb 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -2099,3 +2099,96 @@ def test_gen_series_with_step(): f.gen_series(literal(1), literal(10), literal(3)).alias("v") ).collect() assert result[0].column(0)[0].as_py() == [1, 4, 7, 10] + + +class TestPythonicNativeTypes: + """Tests for accepting native Python types instead of requiring lit().""" + + def test_split_part_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["a,b,c"]}) + result = df.select(f.split_part(column("a"), ",", 2).alias("s")).collect() + assert result[0].column(0)[0].as_py() == "b" + + def test_encode_native_str(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["hello"]}) + result = df.select(f.encode(column("a"), "base64").alias("e")).collect() + assert result[0].column(0)[0].as_py() == "aGVsbG8" + + def test_date_part_native_str(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["2021-07-15T00:00:00"]}) + df = df.select(f.to_timestamp(column("a")).alias("a")) + result = df.select(f.date_part("year", column("a")).alias("y")).collect() + assert result[0].column(0)[0].as_py() == 2021 + + def test_date_trunc_native_str(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["2021-07-15T12:34:56"]}) + df = df.select(f.to_timestamp(column("a")).alias("a")) + result = df.select(f.date_trunc("month", column("a")).alias("t")).collect() + assert str(result[0].column(0)[0].as_py()) 
== "2021-07-01 00:00:00" + + def test_left_native_int(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["the cat"]}) + result = df.select(f.left(column("a"), 3).alias("l")).collect() + assert result[0].column(0)[0].as_py() == "the" + + def test_round_native_int(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1.567]}) + result = df.select(f.round(column("a"), 2).alias("r")).collect() + assert result[0].column(0)[0].as_py() == 1.57 + + def test_regexp_count_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["abcabc"]}) + result = df.select( + f.regexp_count(column("a"), "abc", start=4, flags="i").alias("c") + ).collect() + assert result[0].column(0)[0].as_py() == 1 + + def test_log_native_int(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": [100.0]}) + result = df.select(f.log(10, column("a")).alias("l")).collect() + assert result[0].column(0)[0].as_py() == 2.0 + + def test_power_native_int(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": [2.0]}) + result = df.select(f.power(column("a"), 3).alias("p")).collect() + assert result[0].column(0)[0].as_py() == 8.0 + + def test_array_slice_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": [[1, 2, 3, 4]]}) + result = df.select(f.array_slice(column("a"), 2, 3).alias("s")).collect() + assert result[0].column(0)[0].as_py() == [2, 3] + + def test_string_to_array_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["hello,NA,world"]}) + result = df.select( + f.string_to_array(column("a"), ",", null_string="NA").alias("v") + ).collect() + assert result[0].column(0)[0].as_py() == ["hello", None, "world"] + + def test_regexp_replace_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["a1 b2 c3"]}) + result = df.select( + f.regexp_replace(column("a"), r"\d+", "X", flags="g").alias("r") + ).collect() + assert result[0].column(0)[0].as_py() == "aX bX cX" + + def test_backward_compat_with_lit(self): + """Verify that existing code using lit() still works.""" + ctx = SessionContext() + df = ctx.from_pydict({"a": ["a,b,c"]}) + result = df.select( + f.split_part(column("a"), literal(","), literal(2)).alias("s") + ).collect() + assert result[0].column(0)[0].as_py() == "b"