diff --git a/.ai/skills/check-upstream/SKILL.md b/.ai/skills/check-upstream/SKILL.md new file mode 100644 index 000000000..ac4835a4e --- /dev/null +++ b/.ai/skills/check-upstream/SKILL.md @@ -0,0 +1,383 @@ + + +--- +name: check-upstream +description: Check if upstream Apache DataFusion features (functions, DataFrame ops, SessionContext methods, FFI types) are exposed in this Python project. Use when adding missing functions, auditing API coverage, or ensuring parity with upstream. +argument-hint: [area] (e.g., "scalar functions", "aggregate functions", "window functions", "dataframe", "session context", "ffi types", "all") +--- + +# Check Upstream DataFusion Feature Coverage + +You are auditing the datafusion-python project to find features from the upstream Apache DataFusion Rust library that are **not yet exposed** in this Python binding project. Your goal is to identify gaps and, if asked, implement the missing bindings. + +**IMPORTANT: The Python API is the source of truth for coverage.** A function or method is considered "exposed" if it exists in the Python API (e.g., `python/datafusion/functions.py`), even if there is no corresponding entry in the Rust bindings. Many upstream functions are aliases of other functions — the Python layer can expose these aliases by calling a different underlying Rust binding. Do NOT report a function as missing if it appears in the Python `__all__` list and has a working implementation, regardless of whether a matching `#[pyfunction]` exists in Rust. + +## Areas to Check + +The user may specify an area via `$ARGUMENTS`. If no area is specified or "all" is given, check all areas. + +### 1. Scalar Functions + +**Upstream source of truth:** +- Rust docs: https://docs.rs/datafusion/latest/datafusion/functions/index.html +- User docs: https://datafusion.apache.org/user-guide/sql/scalar_functions.html + +**Where they are exposed in this project:** +- Python API: `python/datafusion/functions.py` — each function wraps a call to `datafusion._internal.functions` +- Rust bindings: `crates/core/src/functions.rs` — `#[pyfunction]` definitions registered via `init_module()` + +**How to check:** +1. Fetch the upstream scalar function documentation page +2. Compare against functions listed in `python/datafusion/functions.py` (check the `__all__` list and function definitions) +3. A function is covered if it exists in the Python API — it does NOT need a dedicated Rust `#[pyfunction]`. Many functions are aliases that reuse another function's Rust binding. +4. Only report functions that are missing from the Python `__all__` list / function definitions + +### 2. Aggregate Functions + +**Upstream source of truth:** +- Rust docs: https://docs.rs/datafusion/latest/datafusion/functions_aggregate/index.html +- User docs: https://datafusion.apache.org/user-guide/sql/aggregate_functions.html + +**Where they are exposed in this project:** +- Python API: `python/datafusion/functions.py` (aggregate functions are mixed in with scalar functions) +- Rust bindings: `crates/core/src/functions.rs` + +**How to check:** +1. Fetch the upstream aggregate function documentation page +2. Compare against aggregate functions in `python/datafusion/functions.py` (check `__all__` list and function definitions) +3. A function is covered if it exists in the Python API, even if it aliases another function's Rust binding +4. Report only functions missing from the Python API + +### 3. 
Window Functions + +**Upstream source of truth:** +- Rust docs: https://docs.rs/datafusion/latest/datafusion/functions_window/index.html +- User docs: https://datafusion.apache.org/user-guide/sql/window_functions.html + +**Where they are exposed in this project:** +- Python API: `python/datafusion/functions.py` (window functions like `rank`, `dense_rank`, `lag`, `lead`, etc.) +- Rust bindings: `crates/core/src/functions.rs` + +**How to check:** +1. Fetch the upstream window function documentation page +2. Compare against window functions in `python/datafusion/functions.py` (check `__all__` list and function definitions) +3. A function is covered if it exists in the Python API, even if it aliases another function's Rust binding +4. Report only functions missing from the Python API + +### 4. Table Functions + +**Upstream source of truth:** +- Rust docs: https://docs.rs/datafusion/latest/datafusion/functions_table/index.html +- User docs: https://datafusion.apache.org/user-guide/sql/table_functions.html (if available) + +**Where they are exposed in this project:** +- Python API: `python/datafusion/functions.py` and `python/datafusion/user_defined.py` (TableFunction/udtf) +- Rust bindings: `crates/core/src/functions.rs` and `crates/core/src/udtf.rs` + +**How to check:** +1. Fetch the upstream table function documentation +2. Compare against what's available in the Python API +3. A function is covered if it exists in the Python API, even if it aliases another function's Rust binding +4. Report only functions missing from the Python API + +### 5. DataFrame Operations + +**Upstream source of truth:** +- Rust docs: https://docs.rs/datafusion/latest/datafusion/dataframe/struct.DataFrame.html + +**Where they are exposed in this project:** +- Python API: `python/datafusion/dataframe.py` — the `DataFrame` class +- Rust bindings: `crates/core/src/dataframe.rs` — `PyDataFrame` with `#[pymethods]` + +**Evaluated and not requiring separate Python exposure:** +- `show_limit` — already covered by `DataFrame.show()`, which provides the same functionality with a simpler API +- `with_param_values` — already covered by the `param_values` argument on `SessionContext.sql()`, which accomplishes the same thing more robustly +- `union_by_name_distinct` — already covered by `DataFrame.union_by_name(distinct=True)`, which provides a more Pythonic API + +**How to check:** +1. Fetch the upstream DataFrame documentation page listing all methods +2. Compare against methods in `python/datafusion/dataframe.py` — this is the source of truth for coverage +3. The Rust bindings (`crates/core/src/dataframe.rs`) may be consulted for context, but a method is covered if it exists in the Python API +4. Check against the "evaluated and not requiring exposure" list before flagging as a gap +5. Report only methods missing from the Python API + +### 6. SessionContext Methods + +**Upstream source of truth:** +- Rust docs: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html + +**Where they are exposed in this project:** +- Python API: `python/datafusion/context.py` — the `SessionContext` class +- Rust bindings: `crates/core/src/context.rs` — `PySessionContext` with `#[pymethods]` + +**How to check:** +1. Fetch the upstream SessionContext documentation page listing all methods +2. Compare against methods in `python/datafusion/context.py` — this is the source of truth for coverage +3. 
The Rust bindings (`crates/core/src/context.rs`) may be consulted for context, but a method is covered if it exists in the Python API +4. Report only methods missing from the Python API + +### 7. FFI Types (datafusion-ffi) + +**Upstream source of truth:** +- Crate source: https://github.com/apache/datafusion/tree/main/datafusion/ffi/src +- Rust docs: https://docs.rs/datafusion-ffi/latest/datafusion_ffi/ + +**Where they are exposed in this project:** +- Rust bindings: various files under `crates/core/src/` and `crates/util/src/` +- FFI example: `examples/datafusion-ffi-example/src/` +- Dependency declared in root `Cargo.toml` and `crates/core/Cargo.toml` + +**Discovering currently supported FFI types:** +Grep for `use datafusion_ffi::` in `crates/core/src/` and `crates/util/src/` to find all FFI types currently imported and used. + +**Evaluated and not requiring direct Python exposure:** +These upstream FFI types have been reviewed and do not need to be independently exposed to end users: +- `FFI_ExecutionPlan` — already used indirectly through table providers; no need for direct exposure +- `FFI_PhysicalExpr` / `FFI_PhysicalSortExpr` — internal physical planning types not expected to be needed by end users +- `FFI_RecordBatchStream` — one level deeper than FFI_ExecutionPlan, used internally when execution plans stream results +- `FFI_SessionRef` / `ForeignSession` — session sharing across FFI; Python manages sessions natively via SessionContext +- `FFI_SessionConfig` — Python can configure sessions natively without FFI +- `FFI_ConfigOptions` / `FFI_TableOptions` — internal configuration plumbing +- `FFI_PlanProperties` / `FFI_Boundedness` / `FFI_EmissionType` — read from existing plans, not user-facing +- `FFI_Partitioning` — supporting type for physical planning +- Supporting/utility types (`FFI_Option`, `FFI_Result`, `WrappedSchema`, `WrappedArray`, `FFI_ColumnarValue`, `FFI_Volatility`, `FFI_InsertOp`, `FFI_AccumulatorArgs`, `FFI_Accumulator`, `FFI_GroupsAccumulator`, `FFI_EmitTo`, `FFI_AggregateOrderSensitivity`, `FFI_PartitionEvaluator`, `FFI_PartitionEvaluatorArgs`, `FFI_Range`, `FFI_SortOptions`, `FFI_Distribution`, `FFI_ExprProperties`, `FFI_SortProperties`, `FFI_Interval`, `FFI_TableProviderFilterPushDown`, `FFI_TableType`) — used as building blocks within the types above, not independently exposed + +**How to check:** +1. Discover currently supported types by grepping for `use datafusion_ffi::` in `crates/core/src/` and `crates/util/src/`, then compare against the upstream `datafusion-ffi` crate's `lib.rs` exports +2. If new FFI types appear upstream, evaluate whether they represent a user-facing capability +3. Check against the "evaluated and not requiring exposure" list before flagging as a gap +4. Report any genuinely new types that enable user-facing functionality +5. 
For each currently supported FFI type, verify the full pipeline is present using the checklist from "Adding a New FFI Type":
   - Rust PyO3 wrapper with `from_pycapsule()` method
   - Python Protocol type (e.g., `ScalarUDFExportable`) for FFI objects
   - Python wrapper class with full type hints on all public methods
   - ABC base class (if the type can be user-implemented)
   - Registered in Rust `init_module()` and Python `__init__.py`
   - FFI example in `examples/datafusion-ffi-example/`
   - Type appears in union type hints where accepted

## Checking for Existing GitHub Issues

After identifying missing APIs, search the open issues at https://github.com/apache/datafusion-python/issues for each gap to see if an issue already exists requesting that API be exposed. Search using the function or method name as the query.

- If an existing issue is found, include a link to it in the report. Do NOT create a new issue.
- If no existing issue is found, note that no issue exists yet. If the user asks to create issues for missing APIs, each issue should specify that Python test coverage is required as part of the implementation.

## Output Format

For each area checked, produce a report like:

```
## [Area Name] Coverage Report

### Currently Exposed (X functions/methods)
- list of what's already available

### Missing from Upstream (Y functions/methods)
- function_name — brief description of what it does (existing issue: #123)
- function_name — brief description of what it does (no existing issue)

### Notes
- Any relevant observations about partial implementations, naming differences, etc.
```

## Implementation Pattern

If the user asks you to implement missing features, follow these patterns:

### Adding a New Function (Scalar/Aggregate/Window)

**Step 1: Rust binding** in `crates/core/src/functions.rs`:
```rust
#[pyfunction]
#[pyo3(signature = (arg1, arg2))]
fn new_function_name(arg1: PyExpr, arg2: PyExpr) -> PyResult<PyExpr> {
    Ok(datafusion::functions::module::expr_fn::new_function_name(arg1.expr, arg2.expr).into())
}
```
Then register in `init_module()`:
```rust
m.add_wrapped(wrap_pyfunction!(new_function_name))?;
```

**Step 2: Python wrapper** in `python/datafusion/functions.py`:
```python
def new_function_name(arg1: Expr, arg2: Expr) -> Expr:
    """Description of what the function does.

    Args:
        arg1: Description of first argument.
        arg2: Description of second argument.

    Returns:
        Description of return value.
    """
    return Expr(f.new_function_name(arg1.expr, arg2.expr))
```
Add to `__all__` list.
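For an alias, only the Python half of this pattern is needed. A minimal sketch using the `list_append`/`array_append` alias pair, mirroring the delegating style already used in `functions.py`:

```python
# No new #[pyfunction] required: the alias reuses array_append's Rust binding.
def list_append(array: Expr, element: Expr) -> Expr:
    """Appends ``element`` to the end of ``array``.

    See Also:
        array_append
    """
    return array_append(array, element)
```

Remember to add the alias to `__all__` as well; the coverage check treats that entry, not a Rust binding, as proof of exposure.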
### Adding a New DataFrame Method

**Step 1: Rust binding** in `crates/core/src/dataframe.rs`:
```rust
#[pymethods]
impl PyDataFrame {
    fn new_method(&self, py: Python, param: PyExpr) -> PyDataFusionResult<Self> {
        let df = self.df.as_ref().clone().new_method(param.into())?;
        Ok(Self::new(df))
    }
}
```

**Step 2: Python wrapper** in `python/datafusion/dataframe.py`:
```python
def new_method(self, param: Expr) -> DataFrame:
    """Description of the method."""
    return DataFrame(self.df.new_method(param.expr))
```

### Adding a New SessionContext Method

**Step 1: Rust binding** in `crates/core/src/context.rs`:
```rust
#[pymethods]
impl PySessionContext {
    pub fn new_method(&self, py: Python, param: String) -> PyDataFusionResult<PyDataFrame> {
        let df = wait_for_future(py, self.ctx.new_method(&param))?;
        Ok(PyDataFrame::new(df))
    }
}
```

**Step 2: Python wrapper** in `python/datafusion/context.py`:
```python
def new_method(self, param: str) -> DataFrame:
    """Description of the method."""
    return DataFrame(self.ctx.new_method(param))
```

### Adding a New FFI Type

FFI types require a full pipeline from C struct through to a typed Python wrapper. Each layer must be present.

**Step 1: Rust PyO3 wrapper class** in a new or existing file under `crates/core/src/`:
```rust
use datafusion_ffi::new_type::FFI_NewType;

#[pyclass(from_py_object, frozen, name = "RawNewType", module = "datafusion.module_name", subclass)]
pub struct PyNewType {
    pub inner: Arc<dyn NewType>,
}

#[pymethods]
impl PyNewType {
    #[staticmethod]
    fn from_pycapsule(obj: &Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
        let capsule = obj
            .getattr("__datafusion_new_type__")?
            .call0()?
            .downcast::<PyCapsule>()?;
        let ffi_ptr = unsafe { capsule.reference::<FFI_NewType>() };
        let provider: Arc<dyn NewType> = ffi_ptr.into();
        Ok(Self { inner: provider })
    }

    fn some_method(&self) -> PyResult<...> {
        // wrap inner trait method
    }
}
```
Register in the appropriate `init_module()`:
```rust
m.add_class::<PyNewType>()?;
```

**Step 2: Python Protocol type** in the appropriate Python module (e.g., `python/datafusion/catalog.py`):
```python
class NewTypeExportable(Protocol):
    """Type hint for objects providing a __datafusion_new_type__ PyCapsule."""

    def __datafusion_new_type__(self) -> object: ...
```

**Step 3: Python wrapper class** in the same module:
```python
class NewType:
    """Description of the type.

    This class wraps a DataFusion NewType, which can be created from a native
    Python implementation or imported from an FFI-compatible library.
    """

    def __init__(
        self,
        new_type: df_internal.module_name.RawNewType | NewTypeExportable,
    ) -> None:
        if isinstance(new_type, df_internal.module_name.RawNewType):
            self._raw = new_type
        else:
            self._raw = df_internal.module_name.RawNewType.from_pycapsule(new_type)

    def some_method(self) -> ReturnType:
        """Description of the method."""
        return self._raw.some_method()
```

**Step 4: ABC base class** (if users should be able to subclass and provide custom implementations in Python):
```python
from abc import ABC, abstractmethod

class NewTypeProvider(ABC):
    """Abstract base class for implementing a custom NewType in Python."""

    @abstractmethod
    def some_method(self) -> ReturnType:
        """Description of the method."""
        ...
+``` + +**Step 5: Module exports** — add to the appropriate `__init__.py`: +- Add the wrapper class (`NewType`) to `python/datafusion/__init__.py` +- Add the ABC (`NewTypeProvider`) if applicable +- Add the Protocol type (`NewTypeExportable`) if it should be public + +**Step 6: FFI example** — add an example implementation under `examples/datafusion-ffi-example/src/`: +```rust +// examples/datafusion-ffi-example/src/new_type.rs +use datafusion_ffi::new_type::FFI_NewType; +// ... example showing how an external Rust library exposes this type via PyCapsule +``` + +**Checklist for each FFI type:** +- [ ] Rust PyO3 wrapper with `from_pycapsule()` method +- [ ] Python Protocol type (e.g., `NewTypeExportable`) for FFI objects +- [ ] Python wrapper class with full type hints on all public methods +- [ ] ABC base class (if the type can be user-implemented) +- [ ] Registered in Rust `init_module()` and Python `__init__.py` +- [ ] FFI example in `examples/datafusion-ffi-example/` +- [ ] Type appears in union type hints where accepted (e.g., `Table | TableProviderExportable`) + +## Important Notes + +- The upstream DataFusion version used by this project is specified in `crates/core/Cargo.toml` — check the `datafusion` dependency version to ensure you're comparing against the right upstream version. +- Some upstream features may intentionally not be exposed (e.g., internal-only APIs). Use judgment about what's user-facing. +- When fetching upstream docs, prefer the published docs.rs documentation as it matches the crate version. +- Function aliases (e.g., `array_append` / `list_append`) should both be exposed if upstream supports them. +- Check the `__all__` list in `functions.py` to see what's publicly exported vs just defined. diff --git a/.ai/skills/make-pythonic/SKILL.md b/.ai/skills/make-pythonic/SKILL.md new file mode 100644 index 000000000..57145ac6c --- /dev/null +++ b/.ai/skills/make-pythonic/SKILL.md @@ -0,0 +1,430 @@ + + +--- +name: make-pythonic +description: Audit and improve datafusion-python functions to accept native Python types (int, float, str, bool) instead of requiring explicit lit() or col() wrapping. Analyzes function signatures, checks upstream Rust implementations for type constraints, and applies the appropriate coercion pattern. +argument-hint: [scope] (e.g., "string functions", "datetime functions", "array functions", "math functions", "all", or a specific function name like "split_part") +--- + +# Make Python API Functions More Pythonic + +You are improving the datafusion-python API to feel more natural to Python users. The goal is to allow functions to accept native Python types (int, float, str, bool, etc.) for arguments that are contextually always or typically literal values, instead of requiring users to manually wrap them in `lit()`. + +**Core principle:** A Python user should be able to write `split_part(col("a"), ",", 2)` instead of `split_part(col("a"), lit(","), lit(2))` when the arguments are contextually obvious literals. + +## How to Identify Candidates + +The user may specify a scope via `$ARGUMENTS`. If no scope is given or "all" is specified, audit all functions in `python/datafusion/functions.py`. 
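Before working through the signals, keep the end state in mind. A sketch of the core-principle example (the native-type calls only work once `split_part` has been updated by this skill):

```python
from datafusion import col, lit
from datafusion import functions as F

# Today: every literal must be wrapped explicitly.
verbose = F.split_part(col("a"), lit(","), lit(2))

# Goal: contextually obvious literals are accepted directly.
pythonic = F.split_part(col("a"), ",", 2)
```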
For each function, determine if any parameter can accept native Python types by evaluating **two complementary signals**:

### Signal 1: Contextual Understanding

Some arguments are contextually always or almost always literal values based on what the function does:

| Context | Typical Arguments | Examples |
|---------|------------------|----------|
| **String position/count** | Character counts, indices, repetition counts | `left(str, n)`, `right(str, n)`, `repeat(str, n)`, `lpad(str, count, ...)` |
| **Delimiters/separators** | Fixed separator characters | `split_part(str, delim, idx)`, `concat_ws(sep, ...)` |
| **Search/replace patterns** | Literal search strings, replacements | `replace(str, from, to)`, `regexp_replace(str, pattern, replacement, flags)` |
| **Date/time parts** | Part names from a fixed set | `date_part(part, date)`, `date_trunc(part, date)` |
| **Rounding precision** | Decimal place counts | `round(val, places)`, `trunc(val, places)` |
| **Fill characters** | Padding characters | `lpad(str, count, fill)`, `rpad(str, count, fill)` |

### Signal 2: Upstream Rust Implementation

Check the Rust binding in `crates/core/src/functions.rs` and the upstream DataFusion function implementation to determine type constraints. The upstream source is cached locally at:

```
~/.cargo/registry/src/index.crates.io-*/datafusion-functions-<version>/src/
```

Check the DataFusion version in `crates/core/Cargo.toml` to find the right directory. Key subdirectories: `string/`, `datetime/`, `math/`, `regex/`.

For **aggregate functions**, the upstream source is in a separate crate:

```
~/.cargo/registry/src/index.crates.io-*/datafusion-functions-aggregate-<version>/src/
```

There are five concrete techniques to check, in order of signal strength:

#### Technique 1: Check `invoke_with_args()` for literal-only enforcement (strongest signal)

Some functions pattern-match on `ColumnarValue::Scalar` in their `invoke_with_args()` method and **return an error** if the argument is a column/array. This means the argument **must** be a literal — passing a column expression will fail at runtime.

Example from `date_trunc.rs`:
```rust
let granularity_str = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = granularity {
    v.to_lowercase()
} else {
    return exec_err!("Granularity of `date_trunc` must be non-null scalar Utf8");
};
```

**If you find this pattern:** The argument is **Category B** — accept only the corresponding native Python type (e.g., `str`), not `Expr`. The function will error at runtime with a column expression anyway.

#### Technique 1a: Check `accumulator()` for literal-only enforcement (aggregate functions)

Technique 1 applies to scalar UDFs. Aggregate functions do not have `invoke_with_args()` — instead, they enforce literal-only arguments in their `accumulator()` (or `create_accumulator()`) method, which runs at planning time before any data is processed.

Look for these patterns inside `accumulator()`:

- `get_scalar_value(expr)` — evaluates the expression against an empty batch and errors if it's not a scalar
- `validate_percentile_expr(expr)` — specific helper used by percentile functions
- `downcast_ref::<Literal>()` — checks that the physical expression is a literal constant

Example from `approx_percentile_cont.rs`:
```rust
fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
    let percentile =
        validate_percentile_expr(&args.exprs[1], "APPROX_PERCENTILE_CONT")?;
    // ...
}
```

Where `validate_percentile_expr` calls `get_scalar_value` and errors with `"must be a literal"`.

Example from `string_agg.rs`:
```rust
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
    let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::<Literal>() else {
        return not_impl_err!(
            "The second argument of the string_agg function must be a string literal"
        );
    };
    // ...
}
```

**If you find this pattern:** The argument is **Category B** — accept only the corresponding native Python type, not `Expr`. The function will error at planning time with a non-literal expression.

To discover which aggregate functions have literal-only arguments, search the upstream aggregate crate for `get_scalar_value`, `validate_percentile_expr`, and `downcast_ref::<Literal>()` inside `accumulator()` methods. For example, you should expect to find `approx_percentile_cont` (percentile) and `string_agg` (delimiter) among the results.

#### Technique 1b: Check `partition_evaluator()` for literal-only enforcement (window functions)

Window functions do not have `invoke_with_args()` or `accumulator()`. Instead, they enforce literal-only arguments in their `partition_evaluator()` method, which constructs the evaluator that processes each partition.

The upstream source is in a separate crate:

```
~/.cargo/registry/src/index.crates.io-*/datafusion-functions-window-<version>/src/
```

Look for `get_scalar_value_from_args()` calls inside `partition_evaluator()`. This helper (defined in the window crate's `utils.rs`) calls `downcast_ref::<Literal>()` and errors with `"There is only support Literal types for field at idx: {index} in Window Function"`.

Example from `ntile.rs`:
```rust
fn partition_evaluator(
    &self,
    partition_evaluator_args: PartitionEvaluatorArgs,
) -> Result<Box<dyn PartitionEvaluator>> {
    let scalar_n =
        get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 0)?
            .ok_or_else(|| {
                exec_datafusion_err!("NTILE requires a positive integer")
            })?;
    // ...
}
```

**If you find this pattern:** The argument is **Category B** — accept only the corresponding native Python type, not `Expr`. The function will error at planning time with a non-literal expression.

To discover which window functions have literal-only arguments, search the upstream window crate for `get_scalar_value_from_args` inside `partition_evaluator()` methods. For example, you should expect to find `ntile` (n) and `lead`/`lag` (offset, default_value) among the results.

#### Technique 2: Check the `Signature` for data type constraints

Each function defines a `Signature::coercible(...)` that specifies what data types each argument accepts, using `Coercion` entries. This tells you the expected **data type** even if it doesn't enforce literal-only.

Example from `repeat.rs`:
```rust
signature: Signature::coercible(
    vec![
        Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
        Coercion::new_implicit(
            TypeSignatureClass::Native(logical_int64()),
            vec![TypeSignatureClass::Integer],
            NativeType::Int64,
        ),
    ],
    Volatility::Immutable,
),
```

This tells you arg 2 (`n`) must be an integer type coerced to Int64. Use this to choose the correct Python type (e.g., `int` not `str` or `float`).
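A quick sanity check for a chosen hint is that the native value and the explicit `Expr` form produce identical results (a sketch assuming `repeat` has already been updated):

```python
from datafusion import SessionContext, col, lit
from datafusion import functions as F

ctx = SessionContext()
df = ctx.from_pydict({"a": ["ha"]})

# After the change, both forms must build the same expression.
native = df.select(F.repeat(col("a"), 3).alias("r"))
wrapped = df.select(F.repeat(col("a"), lit(3)).alias("r"))
assert native.collect_column("r").equals(wrapped.collect_column("r"))
```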
Common mappings:
| Rust Type Constraint | Python Type |
|---------------------|-------------|
| `logical_int64()` / `TypeSignatureClass::Integer` | `int` |
| `logical_float64()` / `TypeSignatureClass::Numeric` | `float` |
| `logical_string()` / `TypeSignatureClass::String` | `str` |
| `LogicalType::Boolean` | `bool` |

**Important:** In Python's type system (PEP 484), `float` already accepts `int` values, so `int | float` is redundant and will fail the `ruff` linter (rule PYI041). Use `float` alone when the Rust side accepts a float/numeric type — Python users can still pass integer literals like `log(10, col("a"))` or `power(col("a"), 3)` without issue. Only use `int` when the Rust side strictly requires an integer (e.g., `logical_int64()`).

#### Technique 3: Check `return_field_from_args()` for `scalar_arguments` usage

Functions that inspect literal values at query planning time use `args.scalar_arguments.get(n)` in their `return_field_from_args()` method. This indicates the argument is **expected to be a literal** for optimal behavior (e.g., to determine output type precision), but may still work as a column.

Example from `round.rs`:
```rust
let decimal_places: Option<i64> = match args.scalar_arguments.get(1) {
    None => Some(0),
    Some(None) => None, // argument is not a literal (column)
    Some(Some(scalar)) if scalar.is_null() => Some(0),
    Some(Some(scalar)) => Some(decimal_places_from_scalar(scalar)?),
};
```

**If you find this pattern:** The argument is **Category A** — accept native types AND `Expr`. It works as a column but is primarily used as a literal.

#### Decision flow

```
What kind of function is this?
  Scalar UDF:
    Is argument rejected at runtime if not a literal?
    (check invoke_with_args for ColumnarValue::Scalar-only match + exec_err!)
    → YES: Category B — accept only native type, no Expr
    → NO: continue below
  Aggregate:
    Is argument rejected at planning time if not a literal?
    (check accumulator() for get_scalar_value / validate_percentile_expr /
     downcast_ref::<Literal>() + error)
    → YES: Category B — accept only native type, no Expr
    → NO: continue below
  Window:
    Is argument rejected at planning time if not a literal?
    (check partition_evaluator() for get_scalar_value_from_args /
     downcast_ref::<Literal>() + error)
    → YES: Category B — accept only native type, no Expr
    → NO: continue below

Does the Signature constrain it to a specific data type?
  → YES: Category A — accept Expr | <native type>
  → NO: Leave as Expr only
```

## Coercion Categories

When making a function more pythonic, apply the correct coercion pattern based on **what the argument represents**:

### Category A: Arguments That Should Accept Native Types AND Expr

These are arguments that are *typically* literals but *could* be column references in advanced use cases. For these, accept a union type and coerce native types to `Expr.literal()`.

**Type hint pattern:** `Expr | int`, `Expr | str`, `Expr | int | str`, etc.

**When to use:** When the argument could plausibly come from a column in some use case (e.g., the repeat count might come from a column in a data-driven scenario).

```python
def repeat(string: Expr, n: Expr | int) -> Expr:
    """Repeats ``string`` ``n`` times.

    Examples:
        >>> ctx = dfn.SessionContext()
        >>> df = ctx.from_pydict({"a": ["ha"]})
        >>> result = df.select(
        ...
dfn.functions.repeat(dfn.col("a"), 3).alias("r")) + >>> result.collect_column("r")[0].as_py() + 'hahaha' + """ + if not isinstance(n, Expr): + n = Expr.literal(n) + return Expr(f.repeat(string.expr, n.expr)) +``` + +### Category B: Arguments That Should ONLY Accept Specific Native Types + +These are arguments where an `Expr` never makes sense because the value must be a fixed literal known at query-planning time (not a per-row value). For these, accept only the native type(s) and wrap internally. + +**Type hint pattern:** `str`, `int`, `list[str]`, etc. (no `Expr` in the union) + +**When to use:** When the argument is from a fixed enumeration or is always a compile-time constant, **AND** the parameter was not previously typed as `Expr`: +- Separator in `concat_ws` (already typed as `str` in the Rust binding) +- Index in `array_position` (already typed as `int` in the Rust binding) +- Values that the Rust implementation already accepts as native types + +**Backward compatibility rule:** If a parameter was previously typed as `Expr`, you **must** keep `Expr` in the union even if the Rust side requires a literal. Removing `Expr` would break existing user code like `date_part(lit("year"), col("a"))`. Use **Category A** instead — accept `Expr | str` — and let users who pass column expressions discover the runtime error from the Rust side. Never silently break backward compatibility. + +```python +def concat_ws(separator: str, *args: Expr) -> Expr: + """Concatenates the list ``args`` with the separator. + + ``separator`` is already typed as ``str`` in the Rust binding, so + there is no backward-compatibility concern. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": ["hello"], "b": ["world"]}) + >>> result = df.select( + ... dfn.functions.concat_ws("-", dfn.col("a"), dfn.col("b")).alias("c")) + >>> result.collect_column("c")[0].as_py() + 'hello-world' + """ + args = [arg.expr for arg in args] + return Expr(f.concat_ws(separator, args)) +``` + +### Category C: Arguments That Should Accept str as Column Name + +In some contexts a string argument naturally refers to a column name rather than a literal. This is the pattern used by DataFrame methods. + +**Type hint pattern:** `Expr | str` + +**When to use:** Only when the string contextually means a column name (rare in `functions.py`, more common in DataFrame methods). + +```python +# Use _to_raw_expr() from expr.py for this pattern +from datafusion.expr import _to_raw_expr + +def some_function(column: Expr | str) -> Expr: + raw = _to_raw_expr(column) # str -> col(str) + return Expr(f.some_function(raw)) +``` + +**IMPORTANT:** In `functions.py`, string arguments almost never mean column names. Functions operate on expressions, and column references should use `col()`. Category C applies mainly to DataFrame methods and context APIs, not to scalar/aggregate/window functions. Do NOT convert string arguments to column expressions in `functions.py` unless there is a very clear reason to do so. + +## Implementation Steps + +For each function being updated: + +### Step 1: Analyze the Function + +1. Read the current Python function signature in `python/datafusion/functions.py` +2. Read the Rust binding in `crates/core/src/functions.rs` +3. Optionally check the upstream DataFusion docs for the function +4. Determine which category (A, B, or C) applies to each parameter + +### Step 2: Update the Python Function + +1. **Change the type hints** to accept native types (e.g., `Expr` -> `Expr | int`) +2. 
**Add coercion logic** at the top of the function body +3. **Update the docstring** examples to use the simpler calling convention +4. **Preserve backward compatibility** — existing code using `Expr` must still work + +### Step 3: Update Alias Type Hints + +After updating a primary function, find all alias functions that delegate to it (e.g., `instr` and `position` delegate to `strpos`). Update each alias's **parameter type hints** to match the primary function's new signature. Do not add coercion logic to aliases — the primary function handles that. + +### Step 4: Update Docstring Examples (primary functions only) + +Per the project's CLAUDE.md rules: +- Every function must have doctest-style examples +- Optional parameters need examples both without and with the optional args, using keyword argument syntax +- Reuse the same input data across examples where possible + +**Update examples to demonstrate the pythonic calling convention:** + +```python +# BEFORE (old style - still works but verbose) +dfn.functions.left(dfn.col("a"), dfn.lit(3)) + +# AFTER (new style - shown in examples) +dfn.functions.left(dfn.col("a"), 3) +``` + +### Step 5: Run Tests + +After making changes, run the doctests to verify: +```bash +python -m pytest --doctest-modules python/datafusion/functions.py -v +``` + +## Coercion Helper Pattern + +Use the coercion helpers from `datafusion.expr` to convert native Python values to `Expr`. These are the complement of `ensure_expr()` — where `ensure_expr` *rejects* non-`Expr` values, the coercion helpers *wrap* them via `Expr.literal()`. + +**For required parameters** use `coerce_to_expr`: + +```python +from datafusion.expr import coerce_to_expr + +def left(string: Expr, n: Expr | int) -> Expr: + n = coerce_to_expr(n) + return Expr(f.left(string.expr, n.expr)) +``` + +**For optional nullable parameters** use `coerce_to_expr_or_none`: + +```python +from datafusion.expr import coerce_to_expr, coerce_to_expr_or_none + +def regexp_count( + string: Expr, + pattern: Expr | str, + start: Expr | int | None = None, + flags: Expr | str | None = None, +) -> Expr: + pattern = coerce_to_expr(pattern) + start = coerce_to_expr_or_none(start) + flags = coerce_to_expr_or_none(flags) + return Expr( + f.regexp_count( + string.expr, + pattern.expr, + start.expr if start is not None else None, + flags.expr if flags is not None else None, + ) + ) +``` + +Both helpers are defined in `python/datafusion/expr.py` alongside `ensure_expr`. Import them in `functions.py` via: + +```python +from datafusion.expr import coerce_to_expr, coerce_to_expr_or_none +``` + +## What NOT to Change + +- **Do not change arguments that represent data columns.** If an argument is the primary data being operated on (e.g., the `string` in `left(string, n)` or the `array` in `array_sort(array)`), it should remain `Expr` only. Users should use `col()` for column references. +- **Do not change variadic `*args: Expr` parameters.** These represent multiple expressions and should stay as `Expr`. +- **Do not change arguments where the coercion is ambiguous.** If it is unclear whether a string should be a column name or a literal, leave it as `Expr` and let the user be explicit. +- **Do not add coercion logic to simple aliases.** If a function is just `return other_function(...)`, the primary function handles coercion. However, you **must update the alias's type hints** to match the primary function's signature so that type checkers and documentation accurately reflect what the alias accepts. 
+- **Do not change the Rust bindings.** All coercion happens in the Python layer. The Rust functions continue to accept `PyExpr`. + +## Priority Order + +When auditing functions, process them in this order: + +1. **Date/time functions** — `date_part`, `date_trunc`, `date_bin` — these have the clearest literal arguments +2. **String functions** — `left`, `right`, `repeat`, `lpad`, `rpad`, `split_part`, `substring`, `replace`, `regexp_replace`, `regexp_match`, `regexp_count` — common and verbose without coercion +3. **Math functions** — `round`, `trunc`, `power` — numeric literal arguments +4. **Array functions** — `array_slice`, `array_position`, `array_remove_n`, `array_replace_n`, `array_resize`, `array_element` — index and count arguments +5. **Other functions** — any remaining functions with literal arguments + +## Output Format + +For each function analyzed, report: + +``` +## [Function Name] + +**Current signature:** `function(arg1: Expr, arg2: Expr) -> Expr` +**Proposed signature:** `function(arg1: Expr, arg2: Expr | int) -> Expr` +**Category:** A (accepts native + Expr) +**Arguments changed:** +- `arg2`: Expr -> Expr | int (always a literal count) +**Rust binding:** Takes PyExpr, wraps to literal internally +**Status:** [Changed / Skipped / Needs Discussion] +``` + +If asked to implement (not just audit), make the changes directly and show a summary of what was updated. diff --git a/.claude/skills b/.claude/skills new file mode 120000 index 000000000..6838a1160 --- /dev/null +++ b/.claude/skills @@ -0,0 +1 @@ +../.ai/skills \ No newline at end of file diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4b37046ee..7682d6cb0 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -159,6 +159,19 @@ jobs: with: enable-cache: true + - name: Add extra swap for release build + if: inputs.build_mode == 'release' + run: | + set -euxo pipefail + sudo swapoff -a || true + sudo rm -f /swapfile + sudo fallocate -l 8G /swapfile || sudo dd if=/dev/zero of=/swapfile bs=1M count=8192 + sudo chmod 600 /swapfile + sudo mkswap /swapfile + sudo swapon /swapfile + free -h + swapon --show + - name: Build (release mode) uses: PyO3/maturin-action@v1 if: inputs.build_mode == 'release' @@ -233,7 +246,7 @@ jobs: set -euxo pipefail sudo swapoff -a || true sudo rm -f /swapfile - sudo fallocate -l 16G /swapfile || sudo dd if=/dev/zero of=/swapfile bs=1M count=16384 + sudo fallocate -l 8G /swapfile || sudo dd if=/dev/zero of=/swapfile bs=1M count=8192 sudo chmod 600 /swapfile sudo mkswap /swapfile sudo swapon /swapfile diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 000000000..a9855cf48 --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# + +name: "CodeQL" + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + schedule: + - cron: '16 4 * * 1' + +permissions: + contents: read + +jobs: + analyze: + name: Analyze Actions + runs-on: ubuntu-latest + permissions: + contents: read + security-events: write + + steps: + - name: Checkout repository + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + with: + persist-credentials: false + + - name: Initialize CodeQL + uses: github/codeql-action/init@c793b717bc78562f491db7b0e93a3a178b099162 # v4 + with: + languages: actions + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@c793b717bc78562f491db7b0e93a3a178b099162 # v4 + with: + category: "/language:actions" diff --git a/.github/workflows/verify-release-candidate.yml b/.github/workflows/verify-release-candidate.yml index a10a4faa9..6ecb547b5 100644 --- a/.github/workflows/verify-release-candidate.yml +++ b/.github/workflows/verify-release-candidate.yml @@ -27,7 +27,7 @@ on: required: true type: string rc_number: - description: Release candidate number (e.g., 0) + description: Release candidate number (e.g., 1) required: true type: string @@ -73,6 +73,11 @@ jobs: version: "27.4" repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Set RUSTFLAGS for Windows GNU linker + if: matrix.os == 'windows' + shell: bash + run: echo "RUSTFLAGS=-C link-arg=-Wl,--exclude-libs=ALL" >> "$GITHUB_ENV" + - name: Run release candidate verification shell: bash run: ./dev/release/verify-release-candidate.sh "${{ inputs.version }}" "${{ inputs.rc_number }}" diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..7d3262710 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,88 @@ + + +# Agent Instructions for Contributors + +This file is for agents working **on** the datafusion-python project (developing, +testing, reviewing). If you need to **use** the DataFusion DataFrame API (write +queries, build expressions, understand available functions), see the user-facing +skill at [`SKILL.md`](SKILL.md). + +## Skills + +This project uses AI agent skills stored in `.ai/skills/`. Each skill is a directory containing a `SKILL.md` file with instructions for performing a specific task. + +Skills follow the [Agent Skills](https://agentskills.io) open standard. Each skill directory contains: + +- `SKILL.md` — The skill definition with YAML frontmatter (name, description, argument-hint) and detailed instructions. +- Additional supporting files as needed. + +## Pull Requests + +Every pull request must follow the template in +`.github/pull_request_template.md`. The description must include these sections: + +1. **Which issue does this PR close?** — Link the issue with `Closes #NNN`. +2. **Rationale for this change** — Why the change is needed (skip if the issue + already explains it clearly). +3. **What changes are included in this PR?** — Summarize the individual changes. +4. **Are there any user-facing changes?** — Note any changes visible to users + (new APIs, changed behavior, new files shipped in the package, etc.). If + there are breaking changes to public APIs, add the `api change` label. + +## Pre-commit Checks + +Always run pre-commit checks **before** committing. The hooks are defined in +`.pre-commit-config.yaml` and run automatically on `git commit` if pre-commit +is installed as a git hook. 
To run all hooks manually: + +```bash +pre-commit run --all-files +``` + +Fix any failures before committing. + +## Python Function Docstrings + +Every Python function must include a docstring with usage examples. + +- **Examples are required**: Each function needs at least one doctest-style example + demonstrating basic usage. +- **Optional parameters**: If a function has optional parameters, include separate + examples that show usage both without and with the optional arguments. Pass + optional arguments using their keyword name (e.g., `step=dfn.lit(3)`) so readers + can immediately see which parameter is being demonstrated. +- **Reuse input data**: Use the same input data across examples wherever possible. + The examples should demonstrate how different optional arguments change the output + for the same input, making the effect of each option easy to understand. +- **Alias functions**: Functions that are simple aliases (e.g., `list_sort` aliasing + `array_sort`) only need a one-line description and a `See Also` reference to the + primary function. They do not need their own examples. + +## Aggregate and Window Function Documentation + +When adding or updating an aggregate or window function, ensure the corresponding +site documentation is kept in sync: + +- **Aggregations**: `docs/source/user-guide/common-operations/aggregations.rst` — + add new aggregate functions to the "Aggregate Functions" list and include usage + examples if appropriate. +- **Window functions**: `docs/source/user-guide/common-operations/windows.rst` — + add new window functions to the "Available Functions" list and include usage + examples if appropriate. diff --git a/CLAUDE.md b/CLAUDE.md new file mode 120000 index 000000000..47dc3e3d8 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1 @@ +AGENTS.md \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index ee89c8bda..4efca3eb6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1312,7 +1312,7 @@ dependencies = [ [[package]] name = "datafusion-ffi-example" -version = "52.0.0" +version = "53.0.0" dependencies = [ "arrow", "arrow-array", @@ -1662,11 +1662,12 @@ dependencies = [ [[package]] name = "datafusion-python" -version = "52.0.0" +version = "53.0.0" dependencies = [ "arrow", "arrow-select", "async-trait", + "chrono", "cstr", "datafusion", "datafusion-ffi", @@ -1692,7 +1693,7 @@ dependencies = [ [[package]] name = "datafusion-python-util" -version = "52.0.0" +version = "53.0.0" dependencies = [ "arrow", "datafusion", diff --git a/Cargo.toml b/Cargo.toml index 346f6da3e..d0e87a9a4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ # under the License. 
[workspace.package] -version = "52.0.0" +version = "53.0.0" homepage = "https://datafusion.apache.org/python" repository = "https://github.com/apache/datafusion-python" authors = ["Apache DataFusion "] @@ -35,6 +35,7 @@ tokio = { version = "1.50" } pyo3 = { version = "0.28" } pyo3-async-runtimes = { version = "0.28" } pyo3-log = "0.13.3" +chrono = { version = "0.4", default-features = false } arrow = { version = "58" } arrow-array = { version = "58" } arrow-schema = { version = "58" } @@ -59,13 +60,13 @@ object_store = { version = "0.13.1" } url = "2" log = "0.4.29" parking_lot = "0.12" -prost-types = "0.14.3" # keep in line with `datafusion-substrait` +prost-types = "0.14.3" # keep in line with `datafusion-substrait` pyo3-build-config = "0.28" -datafusion-python-util = { path = "crates/util" } +datafusion-python-util = { path = "crates/util", version = "53.0.0" } [profile.release] -lto = true -codegen-units = 1 +lto = "thin" +codegen-units = 2 # We cannot publish to crates.io with any patches in the below section. Developers # must remove any entries in this section before creating a release candidate. diff --git a/README.md b/README.md index c24257876..4baed7d1d 100644 --- a/README.md +++ b/README.md @@ -215,6 +215,22 @@ You can verify the installation by running: '0.6.0' ``` +## Using DataFusion with AI coding assistants + +This project ships a [`SKILL.md`](SKILL.md) at the repo root that teaches AI +coding assistants how to write idiomatic DataFusion Python. It follows the +[Agent Skills](https://agentskills.io) open standard. + +**Preferred:** `npx skills add apache/datafusion-python` — installs the skill in +Claude Code, Cursor, Windsurf, Cline, Codex, Copilot, Gemini CLI, and other +supported agents. + +**Manual:** paste this line into your project's `AGENTS.md` / `CLAUDE.md`: + +``` +For DataFusion Python code, see https://github.com/apache/datafusion-python/blob/main/SKILL.md +``` + ## How to develop This assumes that you have rust and cargo installed. We use the workflow recommended by [pyo3](https://github.com/PyO3/pyo3) and [maturin](https://github.com/PyO3/maturin). The Maturin tools used in this workflow can be installed either via `uv` or `pip`. Both approaches should offer the same experience. It is recommended to use `uv` since it has significant performance improvements @@ -275,7 +291,7 @@ needing to activate the virtual environment: ```bash uv run --no-project maturin develop --uv -uv run --no-project pytest . +uv run --no-project pytest ``` To run the FFI tests within the examples folder, after you have built @@ -312,6 +328,33 @@ There are scripts in `ci/scripts` for running Rust and Python linters. ./ci/scripts/rust_toml_fmt.sh ``` +## Checking Upstream DataFusion Coverage + +This project includes an [AI agent skill](.ai/skills/check-upstream/SKILL.md) for auditing which +features from the upstream Apache DataFusion Rust library are not yet exposed in these Python +bindings. This is useful when adding missing functions, auditing API coverage, or ensuring parity +with upstream. + +The skill accepts an optional area argument: + +``` +scalar functions +aggregate functions +window functions +dataframe +session context +ffi types +all +``` + +If no argument is provided, it defaults to checking all areas. The skill will fetch the upstream +DataFusion documentation, compare it against the functions and methods exposed in this project, and +produce a coverage report listing what is currently exposed and what is missing. 
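For example, a request to an agent might look like this (illustrative only; the exact phrasing depends on your agent):

```
Use the check-upstream skill to audit "window functions" and report any
missing functions, linking existing GitHub issues where they exist.
```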
+ +The skill definition lives in `.ai/skills/check-upstream/SKILL.md` and follows the +[Agent Skills](https://agentskills.io) open standard. It can be used by any AI coding agent that +supports skill discovery, or followed manually. + ## How to update dependencies To change test dependencies, change the `pyproject.toml` and run diff --git a/SKILL.md b/SKILL.md new file mode 100644 index 000000000..7b07b430f --- /dev/null +++ b/SKILL.md @@ -0,0 +1,722 @@ +--- +name: datafusion-python +description: Use when the user is writing datafusion-python (Apache DataFusion Python bindings) DataFrame or SQL code. Covers imports, data loading, DataFrame operations, expression building, SQL-to-DataFrame mappings, idiomatic patterns, and common pitfalls. +--- + +# DataFusion Python DataFrame API Guide + +## What Is DataFusion? + +DataFusion is an **in-process query engine** built on Apache Arrow. It is not a +database -- there is no server, no connection string, and no external +dependencies. You create a `SessionContext`, point it at data (Parquet, CSV, +JSON, Arrow IPC, Pandas, Polars, or raw Python dicts/lists), and run queries +using either SQL or the DataFrame API described below. + +All data flows through **Apache Arrow**. The canonical Python implementation is +PyArrow (`pyarrow.RecordBatch` / `pyarrow.Table`), but any library that +conforms to the [Arrow C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html) +can interoperate with DataFusion. + +## Core Abstractions + +| Abstraction | Role | Key import | +|---|---|---| +| `SessionContext` | Entry point. Loads data, runs SQL, produces DataFrames. | `from datafusion import SessionContext` | +| `DataFrame` | Lazy query builder. Each method returns a new DataFrame. | Returned by context methods | +| `Expr` | Expression tree node (column ref, literal, function call, ...). | `from datafusion import col, lit` | +| `functions` | 290+ built-in scalar, aggregate, and window functions. | `from datafusion import functions as F` | + +## Import Conventions + +```python +from datafusion import SessionContext, col, lit +from datafusion import functions as F +``` + +## Data Loading + +```python +ctx = SessionContext() + +# From files +df = ctx.read_parquet("path/to/data.parquet") +df = ctx.read_csv("path/to/data.csv") +df = ctx.read_json("path/to/data.json") + +# From Python objects +df = ctx.from_pydict({"a": [1, 2, 3], "b": ["x", "y", "z"]}) +df = ctx.from_pylist([{"a": 1, "b": "x"}, {"a": 2, "b": "y"}]) +df = ctx.from_pandas(pandas_df) +df = ctx.from_polars(polars_df) +df = ctx.from_arrow(arrow_table) + +# From SQL +df = ctx.sql("SELECT a, b FROM my_table WHERE a > 1") +``` + +To make a DataFrame queryable by name in SQL, register it first: + +```python +ctx.register_parquet("my_table", "path/to/data.parquet") +ctx.register_csv("my_table", "path/to/data.csv") +``` + +## DataFrame Operations Quick Reference + +Every method returns a **new** DataFrame (immutable/lazy). Chain them fluently. + +### Projection + +```python +df.select("a", "b") # preferred: plain names as strings +df.select(col("a"), (col("b") + 1).alias("b_plus_1")) # use col()/Expr only when you need an expression + +df.with_column("new_col", col("a") + lit(10)) # add one column +df.with_columns( + col("a").alias("x"), + y=col("b") + lit(1), # named keyword form +) + +df.drop("unwanted_col") +df.with_column_renamed("old_name", "new_name") +``` + +When a column is referenced by name alone, pass the name as a string rather +than wrapping it in `col()`. 
Reach for `col()` only when the projection needs +arithmetic, aliasing, casting, or another expression operation. + +**Case sensitivity**: both `select("Name")` and `col("Name")` lowercase the +identifier. For a column whose real name has uppercase letters, embed double +quotes inside the string: `select('"MyCol"')` or `col('"MyCol"')`. Without the +inner quotes the lookup will fail with `No field named mycol`. + +### Filtering + +```python +df.filter(col("a") > 10) +df.filter(col("a") > 10, col("b") == "x") # multiple = AND +df.filter("a > 10") # SQL expression string +``` + +Raw Python values on the right-hand side of a comparison are auto-wrapped +into literals by the `Expr` operators, so prefer `col("a") > 10` over +`col("a") > lit(10)`. See the Comparisons section and pitfall #2 for the +full rule. + +### Aggregation + +```python +# GROUP BY a, compute sum(b) and count(*) +df.aggregate(["a"], [F.sum(col("b")), F.count(col("a"))]) + +# HAVING equivalent: use the filter keyword on the aggregate function +df.aggregate( + ["region"], + [F.sum(col("sales"), filter=col("sales") > 1000).alias("large_sales")], +) +``` + +As with `select()`, group keys can be passed as plain name strings. Reach for +`col(...)` only when the grouping expression needs arithmetic, aliasing, +casting, or another expression operation. + +Most aggregate functions accept an optional `filter` keyword argument. When +provided, only rows where the filter expression is true contribute to the +aggregate. + +### Sorting + +```python +df.sort("a") # ascending (plain name, preferred) +df.sort(col("a")) # ascending via col() +df.sort(col("a").sort(ascending=False)) # descending +df.sort(col("a").sort(nulls_first=False)) # override null placement + +df.sort_by("a", "b") # ascending-only shortcut +``` + +As with `select()` and `aggregate()`, bare column references can be passed as +plain name strings. A plain expression passed to `sort()` is already treated +as ascending, so reach for `col(...).sort(...)` only when you need to override +a default (descending order or null placement). Writing +`col("a").sort(ascending=True)` is redundant. + +For ascending-only sorts with no null-placement override, `df.sort_by(...)` is +a shorter alias for `df.sort(...)`. + +### Joining + +```python +# Equi-join on shared column name +df1.join(df2, on="key") +df1.join(df2, on="key", how="left") + +# Different column names +df1.join(df2, left_on="id", right_on="fk_id", how="inner") + +# Expression-based join (supports inequality predicates) +df1.join_on(df2, col("a") == col("b"), how="inner") + +# Semi join: keep rows from left where a match exists in right (like EXISTS) +df1.join(df2, on="key", how="semi") + +# Anti join: keep rows from left where NO match exists in right (like NOT EXISTS) +df1.join(df2, on="key", how="anti") +``` + +Join types: `"inner"`, `"left"`, `"right"`, `"full"`, `"semi"`, `"anti"`. + +Inner is the default `how`. Prefer `df1.join(df2, on="key")` over +`df1.join(df2, on="key", how="inner")` — drop `how=` unless you need a +non-inner join type. + +When the two sides' join columns have different native names, use +`left_on=`/`right_on=` with the original names rather than aliasing one side +to match the other — see pitfall #7. 
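A minimal sketch of the semi/anti distinction on toy data:

```python
from datafusion import SessionContext

ctx = SessionContext()
orders = ctx.from_pydict({"key": [1, 2, 3], "amount": [10, 20, 30]})
shipped = ctx.from_pydict({"key": [1, 3]})

# Semi join: keep left rows with a match (keys 1 and 3), never duplicated.
orders.join(shipped, on="key", how="semi").show()

# Anti join: keep left rows with no match (key 2 only).
orders.join(shipped, on="key", how="anti").show()
```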
+ +### Window Functions + +```python +from datafusion import WindowFrame + +# Row number partitioned by group, ordered by value +df.window( + F.row_number( + partition_by=[col("group")], + order_by=[col("value")], + ).alias("rn") +) + +# Using a Window object for reuse +from datafusion.expr import Window + +win = Window( + partition_by=[col("group")], + order_by=[col("value").sort(ascending=True)], +) +df.select( + col("group"), + col("value"), + F.sum(col("value")).over(win).alias("running_total"), +) + +# With explicit frame bounds +win = Window( + partition_by=[col("group")], + order_by=[col("value").sort(ascending=True)], + window_frame=WindowFrame("rows", 0, None), # current row to unbounded following +) +``` + +### Set Operations + +```python +df1.union(df2) # UNION ALL (by position) +df1.union(df2, distinct=True) # UNION DISTINCT +df1.union_by_name(df2) # match columns by name, not position +df1.intersect(df2) # INTERSECT ALL +df1.intersect(df2, distinct=True) # INTERSECT (distinct) +df1.except_all(df2) # EXCEPT ALL +df1.except_all(df2, distinct=True) # EXCEPT (distinct) +``` + +### Limit and Offset + +```python +df.limit(10) # first 10 rows +df.limit(10, offset=20) # skip 20, then take 10 +``` + +### Deduplication + +```python +df.distinct() # remove duplicate rows +df.distinct_on( # keep first row per group (like DISTINCT ON in Postgres) + [col("a")], # uniqueness columns + [col("a"), col("b")], # output columns + [col("b").sort(ascending=True)], # which row to keep +) +``` + +## Executing and Collecting Results + +DataFrames are lazy until you collect. + +```python +df.show() # print formatted table to stdout +batches = df.collect() # list[pa.RecordBatch] +arr = df.collect_column("col_name") # pa.Array | pa.ChunkedArray (single column) +table = df.to_arrow_table() # pa.Table +pandas_df = df.to_pandas() # pd.DataFrame +polars_df = df.to_polars() # pl.DataFrame +py_dict = df.to_pydict() # dict[str, list] +py_list = df.to_pylist() # list[dict] +count = df.count() # int +``` + +### Date and Timestamp Type Conversion + +The Python type returned by `to_pydict()` / `to_pylist()` depends on the Arrow +column type, and the mapping is inherited from PyArrow: + +| Arrow type | Python type returned | +|---|---| +| `timestamp(s)` / `(ms)` / `(us)` | `datetime.datetime` | +| `timestamp(ns)` | `pandas.Timestamp` | +| `date32` / `date64` | `datetime.date` | +| `duration(s)` / `(ms)` / `(us)` | `datetime.timedelta` | +| `duration(ns)` | `pandas.Timedelta` | + +The nanosecond-precision fallback to pandas types is the main surprise: +pandas is not a hard dependency of `datafusion`, but PyArrow reaches for it +when `datetime.datetime` / `datetime.timedelta` would lose precision (stdlib +types only go to microseconds). If you need plain stdlib types, cast to a +coarser unit before collecting, e.g. +`df.select(col("ts").cast(pa.timestamp("us")))`. + +`df.to_pandas()` has its own footgun for dates: pandas has no pure-date dtype, +so a `date32`/`date64` column comes back as an `object` column of +`datetime.date` values rather than `datetime64[ns]`. If downstream code +expects a datetime column, cast on the DataFusion side first: +`col("ship_date").cast(pa.timestamp("ns"))`. + +### Streaming Results + +Prefer streaming over `collect()` when the result is too large to materialize +in memory, when you want to start processing before the query finishes, or +when you may break out of the loop early. 
`execute_stream()` pulls one +`RecordBatch` at a time from the execution plan rather than buffering the +whole result up front. + +```python +# Single-partition stream; batch is a datafusion.RecordBatch +stream = df.execute_stream() +for batch in stream: + process(batch.to_pyarrow()) # convert to pa.RecordBatch if needed + +# DataFrame is iterable directly (delegates to execute_stream) +for batch in df: + process(batch.to_pyarrow()) + +# One stream per partition, for parallel consumption +for stream in df.execute_stream_partitioned(): + for batch in stream: + process(batch.to_pyarrow()) +``` + +Async iteration is also supported via `async for batch in df: ...` (or over +the stream returned by `df.execute_stream()`), which is useful when batches +are interleaved with other I/O. + +### Writing Results + +```python +df.write_parquet("output.parquet") +df.write_csv("output.csv") +df.write_json("output.json") +``` + +You can also pass a directory path (e.g., `"output/"`) to write a multi-file +partitioned output. + +## Expression Building + +### Column References and Literals + +```python +col("column_name") # reference a column +lit(42) # integer literal +lit("hello") # string literal +lit(3.14) # float literal +lit(pa.scalar(value)) # PyArrow scalar (preserves Arrow type) +``` + +`lit()` accepts PyArrow scalars directly -- prefer this over converting Arrow +data to Python and back when working with values extracted from query results. + +### Arithmetic + +```python +col("price") * col("quantity") # multiplication +col("a") + lit(1) # addition +col("a") - col("b") # subtraction +col("a") / lit(2) # division +col("a") % lit(3) # modulo +``` + +### Date Arithmetic + +`Date32` and `Date64` columns both require `Interval` types for arithmetic, +not `Duration`. Use PyArrow's `month_day_nano_interval` type, which takes a +`(months, days, nanos)` tuple: + +```python +import pyarrow as pa + +# Subtract 90 days from a date column +col("ship_date") - lit(pa.scalar((0, 90, 0), type=pa.month_day_nano_interval())) + +# Subtract 3 months +col("ship_date") - lit(pa.scalar((3, 0, 0), type=pa.month_day_nano_interval())) +``` + +**Important**: `lit(datetime.timedelta(days=90))` creates a `Duration(µs)` +literal, which is **not** compatible with `Date32`/`Date64` arithmetic +(`Duration(ms)` and `Duration(ns)` are rejected too). Always use +`pa.month_day_nano_interval()` for date operations. + +**Timestamps behave differently**: `Timestamp` columns *do* accept `Duration`, +so `col("ts") - lit(datetime.timedelta(days=1))` works. The interval-only +rule applies specifically to date columns. + +### Comparisons + +```python +col("a") > 10 +col("a") >= 10 +col("a") < 10 +col("a") <= 10 +col("a") == "x" +col("a") != "x" +col("a") == None # same as col("a").is_null() +col("a") != None # same as col("a").is_not_null() +``` + +Comparison operators auto-wrap the right-hand Python value into a literal, +so writing `col("a") > lit(10)` is redundant. Drop the `lit()` in +comparisons. Reach for `lit()` only when auto-wrapping does not apply — see +pitfall #2. + +### Boolean Logic + +**Important**: Python's `and`, `or`, `not` keywords do NOT work with Expr +objects. You must use the bitwise operators: + +```python +(col("a") > 1) & (col("b") < 10) # AND +(col("a") > 1) | (col("b") < 10) # OR +~(col("a") > 1) # NOT +``` + +Always wrap each comparison in parentheses when combining with `&`, `|`, `~`: +these bitwise operators bind more tightly than comparisons such as `>`, so an +unparenthesized expression groups the wrong way.
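+ +A sketch of the failure mode: + +```python +# WRONG: & binds tighter than >, so Python evaluates 1 & col("b") first and +# the resulting chained comparison needs `and`, which Expr does not support; +# this raises an error instead of building the intended predicate +col("a") > 1 & col("b") < 10 + +# CORRECT +(col("a") > 1) & (col("b") < 10) +```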
+ +### Null Handling + +```python +col("a").is_null() +col("a").is_not_null() +col("a").fill_null(lit(0)) # replace NULL with a value +F.coalesce(col("a"), col("b")) # first non-null value +F.nullif(col("a"), lit(0)) # return NULL if a == 0 +``` + +### CASE / WHEN + +```python +# Simple CASE (matching on a single expression) +status_label = ( + F.case(col("status")) + .when(lit("A"), lit("Active")) + .when(lit("I"), lit("Inactive")) + .otherwise(lit("Unknown")) +) + +# Searched CASE (each branch has its own predicate) +severity = ( + F.when(col("value") > 100, lit("high")) + .when(col("value") > 50, lit("medium")) + .otherwise(lit("low")) +) +``` + +### Casting + +```python +import pyarrow as pa + +col("a").cast(pa.float64()) +col("a").cast(pa.utf8()) +col("a").cast(pa.date32()) +``` + +### Aliasing + +```python +(col("a") + col("b")).alias("total") +``` + +### BETWEEN and IN + +```python +col("a").between(lit(1), lit(10)) # 1 <= a <= 10 +F.in_list(col("a"), [lit(1), lit(2), lit(3)]) # a IN (1, 2, 3) +F.in_list(col("a"), [lit(1), lit(2)], negated=True) # a NOT IN (1, 2) +``` + +### Struct and Array Access + +```python +col("struct_col")["field_name"] # access struct field +col("array_col")[0] # access array element (0-indexed) +col("array_col")[1:3] # array slice (0-indexed) +``` + +## SQL-to-DataFrame Reference + +| SQL | DataFrame API | +|---|---| +| `SELECT a, b` | `df.select("a", "b")` | +| `SELECT a, b + 1 AS c` | `df.select(col("a"), (col("b") + lit(1)).alias("c"))` | +| `SELECT *, a + 1 AS c` | `df.with_column("c", col("a") + lit(1))` | +| `WHERE a > 10` | `df.filter(col("a") > 10)` | +| `GROUP BY a` with `SUM(b)` | `df.aggregate(["a"], [F.sum(col("b"))])` | +| `SUM(b) FILTER (WHERE b > 100)` | `F.sum(col("b"), filter=col("b") > 100)` | +| `ORDER BY a DESC` | `df.sort(col("a").sort(ascending=False))` | +| `LIMIT 10 OFFSET 5` | `df.limit(10, offset=5)` | +| `DISTINCT` | `df.distinct()` | +| `a INNER JOIN b ON a.id = b.id` | `a.join(b, on="id")` | +| `a LEFT JOIN b ON a.id = b.fk` | `a.join(b, left_on="id", right_on="fk", how="left")` | +| `WHERE EXISTS (SELECT ...)` | `a.join(b, on="key", how="semi")` | +| `WHERE NOT EXISTS (SELECT ...)` | `a.join(b, on="key", how="anti")` | +| `UNION ALL` | `df1.union(df2)` | +| `UNION` (distinct) | `df1.union(df2, distinct=True)` | +| `INTERSECT ALL` | `df1.intersect(df2)` | +| `INTERSECT` (distinct) | `df1.intersect(df2, distinct=True)` | +| `EXCEPT ALL` | `df1.except_all(df2)` | +| `EXCEPT` (distinct) | `df1.except_all(df2, distinct=True)` | +| `CASE x WHEN 1 THEN 'a' END` | `F.case(col("x")).when(lit(1), lit("a")).end()` | +| `CASE WHEN x > 1 THEN 'a' END` | `F.when(col("x") > 1, lit("a")).end()` | +| `x IN (1, 2, 3)` | `F.in_list(col("x"), [lit(1), lit(2), lit(3)])` | +| `x BETWEEN 1 AND 10` | `col("x").between(lit(1), lit(10))` | +| `CAST(x AS DOUBLE)` | `col("x").cast(pa.float64())` | +| `ROW_NUMBER() OVER (...)` | `F.row_number(partition_by=[...], order_by=[...])` | +| `SUM(x) OVER (...)` | `F.sum(col("x")).over(window)` | +| `x IS NULL` | `col("x").is_null()` | +| `COALESCE(a, b)` | `F.coalesce(col("a"), col("b"))` | + +## Common Pitfalls + +1. **Boolean operators**: Use `&`, `|`, `~` -- not Python's `and`, `or`, `not`. + Always parenthesize: `(col("a") > 1) & (col("b") < 2)`. + +2. **Wrapping scalars with `lit()`**: Prefer raw Python values on the + right-hand side of comparisons — `col("a") > 10`, `col("name") == "Alice"` + — because the Expr comparison operators auto-wrap them. Writing + `col("a") > lit(10)` is redundant. 
Reserve `lit()` for places where + auto-wrapping does *not* apply: + - standalone scalars passed into function calls: + `F.coalesce(col("a"), lit(0))`, not `F.coalesce(col("a"), 0)` + - a literal on the left of an operator, and arithmetic with no column + involved at all: `lit(1) - col("discount")`, `lit(1) - lit(2)`. With two + plain scalars Python computes `1 - 2` itself before DataFusion ever sees + an expression, so at least one side must be wrapped (the other side is + then auto-wrapped) + - values that must carry a specific Arrow type, via `lit(pa.scalar(...))` + - `.when(...)`, `.otherwise(...)`, `F.nullif(...)`, `.between(...)`, + `F.in_list(...)` and similar method/function arguments + +3. **Column name quoting**: Column names are normalized to lowercase by default + in both `select("...")` and `col("...")`. To reference a column with + uppercase letters, use double quotes inside the string: + `select('"MyColumn"')` or `col('"MyColumn"')`. + +4. **DataFrames are immutable**: Every method returns a **new** DataFrame. You + must capture the return value: + ```python + df = df.filter(col("a") > 1) # correct + df.filter(col("a") > 1) # WRONG -- result is discarded + ``` + +5. **Window frame defaults**: When using `order_by` in a window, the default + frame is `RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`. For a full + partition frame, set `window_frame=WindowFrame("rows", None, None)`; see + the sketch after this list. + +6. **Arithmetic on aggregates belongs in a later `select`, not inside + `aggregate`**: Each item in the aggregate list must be a single aggregate + call (optionally aliased). Combining aggregates with arithmetic inside + `aggregate(...)` fails with `Internal error: Invalid aggregate expression`. + Alias the aggregates, then compute the combination downstream: + ```python + # WRONG -- arithmetic wraps two aggregates + df.aggregate([], [(lit(100) * F.sum(col("a")) / F.sum(col("b"))).alias("ratio")]) + + # CORRECT -- aggregate first, then combine + (df.aggregate([], [F.sum(col("a")).alias("num"), F.sum(col("b")).alias("den")]) + .select((lit(100) * col("num") / col("den")).alias("ratio"))) + ``` + +7. **Don't alias a join column to match the other side**: When equi-joining + with `on="key"`, renaming the join column on one side via `.alias("key")` + in a fresh projection creates a schema where one side's `key` is + qualified (`?table?.key`) and the other is unqualified. The join then + fails with `Schema contains qualified field name ... and unqualified + field name ... which would be ambiguous`. Use `left_on=`/`right_on=` with + the native names, or use `join_on(...)` with an explicit equality. + ```python + # WRONG -- alias on one side produces ambiguous schema after join + failed = orders.select(col("o_orderkey").alias("l_orderkey")) + li.join(failed, on="l_orderkey") # ambiguous l_orderkey error + + # CORRECT -- keep native names, use left_on/right_on + failed = orders.select("o_orderkey") + li.join(failed, left_on="l_orderkey", right_on="o_orderkey") + + # ALSO CORRECT -- explicit predicate via join_on + # (note: join_on keeps both key columns in the output, unlike on="key") + li.join_on(failed, col("l_orderkey") == col("o_orderkey")) + ```
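+ +A sketch of pitfall #5 (column names are illustrative): + +```python +# Default frame with order_by: a running total up to the current row +win = Window(partition_by=[col("g")], order_by=[col("v")]) +running = F.sum(col("v")).over(win) + +# Explicit full-partition frame: the same total repeated on every row +win_full = Window( + partition_by=[col("g")], + order_by=[col("v")], + window_frame=WindowFrame("rows", None, None), +) +total = F.sum(col("v")).over(win_full) +```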
## Idiomatic Patterns + +### Fluent Chaining + +```python +result = ( + ctx.read_parquet("data.parquet") + .filter(col("year") >= 2020) + .select(col("region"), col("sales")) + .aggregate(["region"], [F.sum(col("sales")).alias("total")]) + .sort(col("total").sort(ascending=False)) + .limit(10) +) +result.show() +``` + +### Using Variables as CTEs + +Instead of SQL CTEs (`WITH ... AS`), assign intermediate DataFrames to +variables: + +```python +base = ctx.read_parquet("orders.parquet").filter(col("status") == "shipped") +by_region = base.aggregate(["region"], [F.sum(col("amount")).alias("total")]) +top_regions = by_region.filter(col("total") > 10000) +``` + +### Reusing Expressions as Variables + +Just like DataFrames, expressions (`Expr`) can be stored in variables and used +anywhere an `Expr` is expected. This is useful for building up complex +expressions or reusing a computed value across multiple operations: + +```python +# Build an expression and reuse it +disc_price = col("price") * (lit(1) - col("discount")) +df = df.select( + col("id"), + disc_price.alias("disc_price"), + (disc_price * (lit(1) + col("tax"))).alias("total"), +) + +# Use a collected scalar as an expression +max_val = result_df.collect_column("max_price")[0] # PyArrow scalar +cutoff = lit(max_val) - lit(pa.scalar((0, 90, 0), type=pa.month_day_nano_interval())) +df = df.filter(col("ship_date") <= cutoff) # cutoff is already an Expr +``` + +**Important**: Do not wrap an `Expr` in `lit()`. `lit()` is for converting +Python/PyArrow values into expressions. If a value is already an `Expr`, use it +directly. + +### Window Functions for Scalar Subqueries + +Where SQL uses a correlated scalar subquery, the idiomatic DataFrame approach +is a window function: + +```sql +-- SQL scalar subquery +SELECT *, (SELECT SUM(b) FROM t WHERE t.group = s.group) AS group_total FROM s +``` + +```python +# DataFrame: window function +win = Window(partition_by=[col("group")]) +df = df.with_column("group_total", F.sum(col("b")).over(win)) +``` + +### Semi/Anti Joins for EXISTS / NOT EXISTS + +```python +# SQL: WHERE EXISTS (SELECT 1 FROM other WHERE other.key = main.key) +result = main.join(other, on="key", how="semi") + +# SQL: WHERE NOT EXISTS (SELECT 1 FROM other WHERE other.key = main.key) +result = main.join(other, on="key", how="anti") +``` + +### Computed Columns + +```python +# Add computed columns while keeping all originals +df = df.with_column("full_name", F.concat(col("first"), lit(" "), col("last"))) +df = df.with_column("discounted", col("price") * lit(0.9)) +``` + +## Available Functions (Categorized) + +The `functions` module (imported as `F`) provides 290+ functions. Key categories: + +**Aggregate**: `sum`, `avg`, `min`, `max`, `count`, `count_star`, `median`, +`stddev`, `stddev_pop`, `var_samp`, `var_pop`, `corr`, `covar`, `approx_distinct`, +`approx_median`, `approx_percentile_cont`, `array_agg`, `string_agg`, +`first_value`, `last_value`, `bit_and`, `bit_or`, `bit_xor`, `bool_and`, +`bool_or`, `grouping`, `regr_*` (9 regression functions) + +**Window**: `row_number`, `rank`, `dense_rank`, `percent_rank`, `cume_dist`, +`ntile`, `lag`, `lead`, `first_value`, `last_value`, `nth_value` + +**String**: `length`, `lower`, `upper`, `trim`, `ltrim`, `rtrim`, `lpad`, +`rpad`, `starts_with`, `ends_with`, `contains`, `substr`, `substring`, +`replace`, `reverse`, `repeat`, `split_part`, `concat`, `concat_ws`, +`initcap`, `ascii`, `chr`, `left`, `right`, `strpos`, `translate`, `overlay`, +`levenshtein` + +`F.substr(str, start)` takes **only two arguments** and returns the tail of +the string from `start` onward — passing a third length argument raises +`TypeError: substr() takes 2 positional arguments but 3 were given`. For the +SQL-style 3-arg form (`SUBSTRING(str FROM start FOR length)`), use +`F.substring(col("s"), lit(start), lit(length))`.
For a fixed-length prefix, +`F.left(col("s"), lit(n))` is cleanest. + +```python +# WRONG — substr does not accept a length argument +F.substr(col("c_phone"), lit(1), lit(2)) +# CORRECT +F.substring(col("c_phone"), lit(1), lit(2)) # explicit length +F.left(col("c_phone"), lit(2)) # prefix shortcut +``` + +**Math**: `abs`, `ceil`, `floor`, `round`, `trunc`, `sqrt`, `cbrt`, `exp`, +`ln`, `log`, `log2`, `log10`, `pow`, `signum`, `pi`, `random`, `factorial`, +`gcd`, `lcm`, `greatest`, `least`, sin/cos/tan and inverse/hyperbolic variants + +**Date/Time**: `now`, `today`, `current_date`, `current_time`, +`current_timestamp`, `date_part`, `date_trunc`, `date_bin`, `extract`, +`to_timestamp`, `to_timestamp_millis`, `to_timestamp_micros`, +`to_timestamp_nanos`, `to_timestamp_seconds`, `to_unixtime`, `from_unixtime`, +`make_date`, `make_time`, `to_date`, `to_time`, `to_local_time`, `date_format` + +**Conditional**: `case`, `when`, `coalesce`, `nullif`, `ifnull`, `nvl`, `nvl2` + +**Array/List**: `array`, `make_array`, `array_agg`, `array_length`, +`array_element`, `array_slice`, `array_append`, `array_prepend`, +`array_concat`, `array_has`, `array_has_all`, `array_has_any`, `array_position`, +`array_remove`, `array_distinct`, `array_sort`, `array_reverse`, `flatten`, +`array_to_string`, `array_intersect`, `array_union`, `array_except`, +`generate_series` +(Most `array_*` functions also have `list_*` aliases.) + +**Struct/Map**: `struct`, `named_struct`, `get_field`, `make_map`, `map_keys`, +`map_values`, `map_entries`, `map_extract` + +**Regex**: `regexp_like`, `regexp_match`, `regexp_replace`, `regexp_count`, +`regexp_instr` + +**Hash**: `md5`, `sha224`, `sha256`, `sha384`, `sha512`, `digest` + +**Type**: `arrow_typeof`, `arrow_cast`, `arrow_metadata` + +**Other**: `in_list`, `order_by`, `alias`, `col`, `encode`, `decode`, +`to_hex`, `to_char`, `uuid`, `version`, `bit_length`, `octet_length` diff --git a/conftest.py b/conftest.py index 73e90077a..0c9410636 100644 --- a/conftest.py +++ b/conftest.py @@ -19,6 +19,7 @@ import datafusion as dfn import numpy as np +import pyarrow as pa import pytest from datafusion import col, lit from datafusion import functions as F @@ -29,6 +30,7 @@ def _doctest_namespace(doctest_namespace: dict) -> None: """Add common imports to the doctest namespace.""" doctest_namespace["dfn"] = dfn doctest_namespace["np"] = np + doctest_namespace["pa"] = pa doctest_namespace["col"] = col doctest_namespace["lit"] = lit doctest_namespace["F"] = F diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 3e2b01c8e..d714dc978 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -47,6 +47,7 @@ pyo3 = { workspace = true, features = [ ] } pyo3-async-runtimes = { workspace = true, features = ["tokio-runtime"] } pyo3-log = { workspace = true } +chrono = { workspace = true } arrow = { workspace = true, features = ["pyarrow"] } arrow-select = { workspace = true } datafusion = { workspace = true, features = ["avro", "unicode_expressions"] } diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs index 53994d2f5..e46d359d6 100644 --- a/crates/core/src/context.rs +++ b/crates/core/src/context.rs @@ -28,7 +28,7 @@ use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::arrow::record_batch::RecordBatch; use datafusion::catalog::{CatalogProvider, CatalogProviderList, TableProviderFactory}; -use datafusion::common::{ScalarValue, TableReference, exec_err}; +use datafusion::common::{DFSchema, 
ScalarValue, TableReference, exec_err}; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::{ @@ -41,7 +41,7 @@ use datafusion::execution::context::{ }; use datafusion::execution::disk_manager::DiskManagerMode; use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool}; -use datafusion::execution::options::ReadOptions; +use datafusion::execution::options::{ArrowReadOptions, ReadOptions}; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::prelude::{ @@ -60,7 +60,7 @@ use datafusion_python_util::{ }; use object_store::ObjectStore; use pyo3::IntoPyObjectExt; -use pyo3::exceptions::{PyKeyError, PyValueError}; +use pyo3::exceptions::{PyKeyError, PyRuntimeError, PyValueError}; use pyo3::prelude::*; use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple}; use url::Url; @@ -70,11 +70,13 @@ use crate::catalog::{ PyCatalog, PyCatalogList, RustWrappedPyCatalogProvider, RustWrappedPyCatalogProviderList, }; use crate::common::data_type::PyScalarValue; +use crate::common::df_schema::PyDFSchema; use crate::dataframe::PyDataFrame; use crate::dataset::Dataset; use crate::errors::{ PyDataFusionError, PyDataFusionResult, from_datafusion_error, py_datafusion_err, }; +use crate::expr::PyExpr; use crate::expr::sort_expr::PySortExpr; use crate::options::PyCsvReadOptions; use crate::physical_plan::PyExecutionPlan; @@ -434,11 +436,25 @@ impl PySessionContext { &upstream_host }; let url_string = format!("{scheme}{derived_host}"); - let url = Url::parse(&url_string).unwrap(); + let url = Url::parse(&url_string).map_err(|e| PyValueError::new_err(e.to_string()))?; self.ctx.runtime_env().register_object_store(&url, store); Ok(()) } + /// Deregister an object store with the given url + #[pyo3(signature = (scheme, host=None))] + pub fn deregister_object_store( + &self, + scheme: &str, + host: Option<&str>, + ) -> PyDataFusionResult<()> { + let host = host.unwrap_or(""); + let url_string = format!("{scheme}{host}"); + let url = Url::parse(&url_string).map_err(|e| PyDataFusionError::Common(e.to_string()))?; + self.ctx.runtime_env().deregister_object_store(&url)?; + Ok(()) + } + #[allow(clippy::too_many_arguments)] #[pyo3(signature = (name, path, table_partition_cols=vec![], file_extension=".parquet", @@ -492,6 +508,10 @@ impl PySessionContext { self.ctx.register_udtf(&name, func); } + pub fn deregister_udtf(&self, name: &str) { + self.ctx.deregister_udtf(name); + } + #[pyo3(signature = (query, options=None, param_values=HashMap::default(), param_strings=HashMap::default()))] pub fn sql_with_options( &self, @@ -956,6 +976,39 @@ impl PySessionContext { Ok(()) } + #[pyo3(signature = (name, path, schema=None, file_extension=".arrow", table_partition_cols=vec![]))] + pub fn register_arrow( + &self, + name: &str, + path: &str, + schema: Option<PyArrowType<Schema>>, + file_extension: &str, + table_partition_cols: Vec<(String, PyArrowType<DataType>)>, + py: Python, + ) -> PyDataFusionResult<()> { + let mut options = ArrowReadOptions::default().table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::<Vec<(String, DataType)>>(), + ); + options.file_extension = file_extension; + options.schema = schema.as_ref().map(|x| &x.0); + + let result = self.ctx.register_arrow(name, path, options); + wait_for_future(py, result)??; + Ok(()) + } + + pub fn register_batch( + &self, + name: &str,
batch: PyArrowType<RecordBatch>, + ) -> PyDataFusionResult<()> { + self.ctx.register_batch(name, batch.0)?; + Ok(()) + } + // Registers a PyArrow.Dataset pub fn register_dataset( &self, @@ -975,16 +1028,28 @@ Ok(()) } + pub fn deregister_udf(&self, name: &str) { + self.ctx.deregister_udf(name); + } + pub fn register_udaf(&self, udaf: PyAggregateUDF) -> PyResult<()> { self.ctx.register_udaf(udaf.function); Ok(()) } + pub fn deregister_udaf(&self, name: &str) { + self.ctx.deregister_udaf(name); + } + pub fn register_udwf(&self, udwf: PyWindowUDF) -> PyResult<()> { self.ctx.register_udwf(udwf.function); Ok(()) } + pub fn deregister_udwf(&self, name: &str) { + self.ctx.deregister_udwf(name); + } + #[pyo3(signature = (name="datafusion"))] pub fn catalog(&self, py: Python, name: &str) -> PyResult<Py<PyAny>> { let catalog = self.ctx.catalog(name).ok_or(PyKeyError::new_err(format!( @@ -1007,21 +1072,6 @@ self.ctx.catalog_names().into_iter().collect() } - pub fn tables(&self) -> HashSet<String> { - self.ctx - .catalog_names() - .into_iter() - .filter_map(|name| self.ctx.catalog(&name)) - .flat_map(move |catalog| { - catalog - .schema_names() - .into_iter() - .filter_map(move |name| catalog.schema(&name)) - }) - .flat_map(|schema| schema.table_names()) - .collect() - } - pub fn table(&self, name: &str, py: Python) -> PyResult<PyDataFrame> { let res = wait_for_future(py, self.ctx.table(name)) .map_err(|e| PyKeyError::new_err(e.to_string()))?; @@ -1050,6 +1100,49 @@ self.ctx.session_id() } + pub fn session_start_time(&self) -> String { + self.ctx.session_start_time().to_rfc3339() + } + + pub fn enable_ident_normalization(&self) -> bool { + self.ctx.enable_ident_normalization() + } + + pub fn parse_sql_expr(&self, sql: &str, schema: PyDFSchema) -> PyDataFusionResult<PyExpr> { + let df_schema: DFSchema = schema.into(); + Ok(self.ctx.parse_sql_expr(sql, &df_schema)?.into()) + } + + pub fn execute_logical_plan( + &self, + plan: PyLogicalPlan, + py: Python, + ) -> PyDataFusionResult<PyDataFrame> { + let df = wait_for_future( + py, + self.ctx.execute_logical_plan(plan.plan.as_ref().clone()), + )??; + Ok(PyDataFrame::new(df)) + } + + pub fn refresh_catalogs(&self, py: Python) -> PyDataFusionResult<()> { + wait_for_future(py, self.ctx.refresh_catalogs())??; + Ok(()) + } + + pub fn remove_optimizer_rule(&self, name: &str) -> bool { + self.ctx.remove_optimizer_rule(name) + } + + pub fn table_provider(&self, name: &str, py: Python) -> PyResult<PyTable> { + let provider = wait_for_future(py, self.ctx.table_provider(name)) + // Outer error: runtime/async failure + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?
+ // Inner error: table not found + .map_err(|e| PyKeyError::new_err(e.to_string()))?; + Ok(PyTable { table: provider }) + } + #[allow(clippy::too_many_arguments)] #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![], file_compression_type=None))] pub fn read_json( @@ -1184,6 +1277,29 @@ impl PySessionContext { Ok(PyDataFrame::new(df)) } + #[pyo3(signature = (path, schema=None, file_extension=".arrow", table_partition_cols=vec![]))] + pub fn read_arrow( + &self, + path: &str, + schema: Option<PyArrowType<Schema>>, + file_extension: &str, + table_partition_cols: Vec<(String, PyArrowType<DataType>)>, + py: Python, + ) -> PyDataFusionResult<PyDataFrame> { + let mut options = ArrowReadOptions::default().table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::<Vec<(String, DataType)>>(), + ); + options.file_extension = file_extension; + options.schema = schema.as_ref().map(|x| &x.0); + + let result = self.ctx.read_arrow(path, options); + let df = wait_for_future(py, result)??; + Ok(PyDataFrame::new(df)) + } + pub fn read_table(&self, table: Bound<'_, PyAny>) -> PyDataFusionResult<PyDataFrame> { let session = self.clone().into_bound_py_any(table.py())?; let table = PyTable::new(table, Some(session))?; diff --git a/crates/core/src/dataframe.rs b/crates/core/src/dataframe.rs index 72595ba81..2e74991b8 100644 --- a/crates/core/src/dataframe.rs +++ b/crates/core/src/dataframe.rs @@ -37,9 +37,15 @@ use datafusion::config::{CsvOptions, ParquetColumnOptions, ParquetOptions, Table use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; use datafusion::error::DataFusionError; use datafusion::execution::SendableRecordBatchStream; -use datafusion::logical_expr::SortExpr; +use datafusion::execution::context::TaskContext; use datafusion::logical_expr::dml::InsertOp; +use datafusion::logical_expr::{LogicalPlan, SortExpr}; use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel}; +use datafusion::physical_plan::{ + ExecutionPlan as DFExecutionPlan, collect as df_collect, + collect_partitioned as df_collect_partitioned, execute_stream as df_execute_stream, + execute_stream_partitioned as df_execute_stream_partitioned, +}; use datafusion::prelude::*; use datafusion_python_util::{is_ipython_env, spawn_future, wait_for_future}; use futures::{StreamExt, TryStreamExt}; @@ -308,6 +314,9 @@ pub struct PyDataFrame { // In IPython environment cache batches between __repr__ and _repr_html_ calls. batches: SharedCachedBatches, + + // Cache the last physical plan so that metrics are available after execution. + last_plan: Arc<Mutex<Option<Arc<dyn DFExecutionPlan>>>>, +} impl PyDataFrame { @@ -316,6 +325,7 @@ Self { df: Arc::new(df), batches: Arc::new(Mutex::new(None)), + last_plan: Arc::new(Mutex::new(None)), } } @@ -387,6 +397,20 @@ Ok(html_str) } + /// Create the physical plan, cache it in `last_plan`, and return the plan together + /// with a task context. Centralises the repeated three-line pattern that appears in + /// `collect`, `collect_partitioned`, `execute_stream`, and `execute_stream_partitioned`.
+ fn create_and_cache_plan( + &self, + py: Python, + ) -> PyDataFusionResult<(Arc<dyn DFExecutionPlan>, Arc<TaskContext>)> { + let df = self.df.as_ref().clone(); + let new_plan = wait_for_future(py, df.create_physical_plan())??; + *self.last_plan.lock() = Some(Arc::clone(&new_plan)); + let task_ctx = Arc::new(self.df.as_ref().task_ctx()); + Ok((new_plan, task_ctx)) + } + async fn collect_column_inner(&self, column: &str) -> Result { let batches = self .df @@ -468,17 +492,17 @@ impl PyDataFrame { fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyDataFusionResult<Self> { if let Ok(key) = key.extract::<String>() { // df[col] - self.select_columns(vec![key]) + self.select_exprs(vec![key]) } else if let Ok(tuple) = key.cast::<PyTuple>() { // df[col1, col2, col3] let keys = tuple .iter() .map(|item| item.extract::<String>()) .collect::<PyResult<Vec<String>>>()?; - self.select_columns(keys) + self.select_exprs(keys) } else if let Ok(keys) = key.extract::<Vec<String>>() { // df[[col1, col2, col3]] - self.select_columns(keys) + self.select_exprs(keys) } else { let message = "DataFrame can only be indexed by string index or indices".to_string(); Err(PyDataFusionError::Common(message)) @@ -554,13 +578,6 @@ impl PyDataFrame { Ok(PyTable::from(table_provider)) } - #[pyo3(signature = (*args))] - fn select_columns(&self, args: Vec<String>) -> PyDataFusionResult<Self> { - let args = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>(); - let df = self.df.as_ref().clone().select_columns(&args)?; - Ok(Self::new(df)) - } - #[pyo3(signature = (*args))] fn select_exprs(&self, args: Vec<String>) -> PyDataFusionResult<Self> { let args = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>(); @@ -582,6 +599,14 @@ Ok(Self::new(df)) } + /// Apply window function expressions to the DataFrame + #[pyo3(signature = (*exprs))] + fn window(&self, exprs: Vec<PyExpr>) -> PyDataFusionResult<Self> { + let window_exprs = exprs.into_iter().map(|e| e.into()).collect(); + let df = self.df.as_ref().clone().window(window_exprs)?; + Ok(Self::new(df)) + } + fn filter(&self, predicate: PyExpr) -> PyDataFusionResult<Self> { let df = self.df.as_ref().clone().filter(predicate.into())?; Ok(Self::new(df)) @@ -645,8 +670,9 @@ /// Unless some order is specified in the plan, there is no /// guarantee of the order of the result. fn collect<'py>(&self, py: Python<'py>) -> PyResult<Vec<Bound<'py, PyAny>>> { - let batches = wait_for_future(py, self.df.as_ref().clone().collect())? - .map_err(PyDataFusionError::from)?; + let (plan, task_ctx) = self.create_and_cache_plan(py)?; + let batches = + wait_for_future(py, df_collect(plan, task_ctx))?.map_err(PyDataFusionError::from)?; // cannot use PyResult<Vec<RecordBatch>> return type due to // https://github.com/PyO3/pyo3/issues/1813 batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect() } @@ -661,7 +687,8 @@ /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch /// maintaining the input partitioning. fn collect_partitioned<'py>(&self, py: Python<'py>) -> PyResult<Vec<Vec<Bound<'py, PyAny>>>> { - let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())? + let (plan, task_ctx) = self.create_and_cache_plan(py)?; + let batches = wait_for_future(py, df_collect_partitioned(plan, task_ctx))?
.map_err(PyDataFusionError::from)?; batches @@ -680,7 +707,15 @@ impl PyDataFrame { /// Print the result, 20 lines by default #[pyo3(signature = (num=20))] fn show(&self, py: Python, num: usize) -> PyDataFusionResult<()> { - let df = self.df.as_ref().clone().limit(0, Some(num))?; + let mut df = self.df.as_ref().clone(); + df = match self.df.logical_plan() { + LogicalPlan::Explain(_) | LogicalPlan::Analyze(_) => { + // Explain and Analyzer require they are at the top + // of the plan, so do not add a limit. + df + } + _ => df.limit(0, Some(num))?, + }; print_dataframe(py, df) } @@ -804,9 +839,27 @@ impl PyDataFrame { } /// Print the query plan - #[pyo3(signature = (verbose=false, analyze=false))] - fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyDataFusionResult<()> { - let df = self.df.as_ref().clone().explain(verbose, analyze)?; + #[pyo3(signature = (verbose=false, analyze=false, format=None))] + fn explain( + &self, + py: Python, + verbose: bool, + analyze: bool, + format: Option<&str>, + ) -> PyDataFusionResult<()> { + let explain_format = match format { + Some(f) => f + .parse::() + .map_err(|e| { + PyDataFusionError::Common(format!("Invalid explain format '{}': {}", f, e)) + })?, + None => datafusion::common::format::ExplainFormat::Indent, + }; + let opts = datafusion::logical_expr::ExplainOption::default() + .with_verbose(verbose) + .with_analyze(analyze) + .with_format(explain_format); + let df = self.df.as_ref().clone().explain_with_options(opts)?; print_dataframe(py, df) } @@ -821,7 +874,13 @@ impl PyDataFrame { } /// Get the execution plan for this `DataFrame` + /// + /// If the DataFrame has already been executed (e.g. via `collect()`), + /// returns the cached plan which includes populated metrics. fn execution_plan(&self, py: Python) -> PyDataFusionResult { + if let Some(plan) = self.last_plan.lock().as_ref() { + return Ok(PyExecutionPlan::new(Arc::clone(plan))); + } let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())??; Ok(plan.into()) } @@ -864,39 +923,14 @@ impl PyDataFrame { Ok(Self::new(new_df)) } - /// Calculate the distinct union of two `DataFrame`s. The - /// two `DataFrame`s must have exactly the same schema - fn union_distinct(&self, py_df: PyDataFrame) -> PyDataFusionResult { - let new_df = self - .df - .as_ref() - .clone() - .union_distinct(py_df.df.as_ref().clone())?; - Ok(Self::new(new_df)) - } - - #[pyo3(signature = (column, preserve_nulls=true))] - fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyDataFusionResult { - // TODO: expose RecursionUnnestOptions - // REF: https://github.com/apache/datafusion/pull/11577 - let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls); - let df = self - .df - .as_ref() - .clone() - .unnest_columns_with_options(&[column], unnest_options)?; - Ok(Self::new(df)) - } - - #[pyo3(signature = (columns, preserve_nulls=true))] + #[pyo3(signature = (columns, preserve_nulls=true, recursions=None))] fn unnest_columns( &self, columns: Vec, preserve_nulls: bool, + recursions: Option>, ) -> PyDataFusionResult { - // TODO: expose RecursionUnnestOptions - // REF: https://github.com/apache/datafusion/pull/11577 - let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls); + let unnest_options = build_unnest_options(preserve_nulls, recursions); let cols = columns.iter().map(|s| s.as_ref()).collect::>(); let df = self .df @@ -907,21 +941,79 @@ impl PyDataFrame { } /// Calculate the intersection of two `DataFrame`s. 
The two `DataFrame`s must have exactly the same schema - fn intersect(&self, py_df: PyDataFrame) -> PyDataFusionResult { - let new_df = self - .df - .as_ref() - .clone() - .intersect(py_df.df.as_ref().clone())?; + #[pyo3(signature = (py_df, distinct=false))] + fn intersect(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult { + let base = self.df.as_ref().clone(); + let other = py_df.df.as_ref().clone(); + let new_df = if distinct { + base.intersect_distinct(other)? + } else { + base.intersect(other)? + }; Ok(Self::new(new_df)) } /// Calculate the exception of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema - fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult { - let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?; + #[pyo3(signature = (py_df, distinct=false))] + fn except_all(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult { + let base = self.df.as_ref().clone(); + let other = py_df.df.as_ref().clone(); + let new_df = if distinct { + base.except_distinct(other)? + } else { + base.except(other)? + }; Ok(Self::new(new_df)) } + /// Union two DataFrames matching columns by name + #[pyo3(signature = (py_df, distinct=false))] + fn union_by_name(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult { + let base = self.df.as_ref().clone(); + let other = py_df.df.as_ref().clone(); + let new_df = if distinct { + base.union_by_name_distinct(other)? + } else { + base.union_by_name(other)? + }; + Ok(Self::new(new_df)) + } + + /// Deduplicate rows based on specific columns, keeping the first row per group + fn distinct_on( + &self, + on_expr: Vec, + select_expr: Vec, + sort_expr: Option>, + ) -> PyDataFusionResult { + let on_expr = on_expr.into_iter().map(|e| e.into()).collect(); + let select_expr = select_expr.into_iter().map(|e| e.into()).collect(); + let sort_expr = sort_expr.map(to_sort_expressions); + let df = self + .df + .as_ref() + .clone() + .distinct_on(on_expr, select_expr, sort_expr)?; + Ok(Self::new(df)) + } + + /// Sort by column expressions with ascending order and nulls last + fn sort_by(&self, exprs: Vec) -> PyDataFusionResult { + let exprs = exprs.into_iter().map(|e| e.into()).collect(); + let df = self.df.as_ref().clone().sort_by(exprs)?; + Ok(Self::new(df)) + } + + /// Return fully qualified column expressions for the given column names + fn find_qualified_columns(&self, names: Vec) -> PyDataFusionResult> { + let name_refs: Vec<&str> = names.iter().map(|s| s.as_str()).collect(); + let qualified = self.df.find_qualified_columns(&name_refs)?; + Ok(qualified + .into_iter() + .map(|q| Expr::Column(Column::from(q)).into()) + .collect()) + } + /// Write a `DataFrame` to a CSV file. 
fn write_csv( &self, @@ -1146,14 +1238,17 @@ impl PyDataFrame { } fn execute_stream(&self, py: Python) -> PyDataFusionResult { - let df = self.df.as_ref().clone(); - let stream = spawn_future(py, async move { df.execute_stream().await })?; + let (plan, task_ctx) = self.create_and_cache_plan(py)?; + let stream = spawn_future(py, async move { df_execute_stream(plan, task_ctx) })?; Ok(PyRecordBatchStream::new(stream)) } fn execute_stream_partitioned(&self, py: Python) -> PyResult> { - let df = self.df.as_ref().clone(); - let streams = spawn_future(py, async move { df.execute_stream_partitioned().await })?; + let (plan, task_ctx) = self.create_and_cache_plan(py)?; + let streams = spawn_future( + py, + async move { df_execute_stream_partitioned(plan, task_ctx) }, + )?; Ok(streams.into_iter().map(PyRecordBatchStream::new).collect()) } @@ -1295,6 +1390,26 @@ impl PyDataFrameWriteOptions { } } +fn build_unnest_options( + preserve_nulls: bool, + recursions: Option>, +) -> UnnestOptions { + let mut opts = UnnestOptions::default().with_preserve_nulls(preserve_nulls); + if let Some(recs) = recursions { + opts.recursions = recs + .into_iter() + .map( + |(input, output, depth)| datafusion::common::RecursionUnnestOption { + input_column: datafusion::common::Column::from(input.as_str()), + output_column: datafusion::common::Column::from(output.as_str()), + depth, + }, + ) + .collect(); + } + opts +} + /// Print DataFrame fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> { // Get string representation of record batches diff --git a/crates/core/src/expr/grouping_set.rs b/crates/core/src/expr/grouping_set.rs index 549a866ed..11d8f4fcd 100644 --- a/crates/core/src/expr/grouping_set.rs +++ b/crates/core/src/expr/grouping_set.rs @@ -15,9 +15,11 @@ // specific language governing permissions and limitations // under the License. 
-use datafusion::logical_expr::GroupingSet; +use datafusion::logical_expr::{Expr, GroupingSet}; use pyo3::prelude::*; +use crate::expr::PyExpr; + #[pyclass( from_py_object, frozen, @@ -30,6 +32,39 @@ pub struct PyGroupingSet { grouping_set: GroupingSet, } +#[pymethods] +impl PyGroupingSet { + #[staticmethod] + #[pyo3(signature = (*exprs))] + fn rollup(exprs: Vec<PyExpr>) -> PyExpr { + Expr::GroupingSet(GroupingSet::Rollup( + exprs.into_iter().map(|e| e.expr).collect(), + )) + .into() + } + + #[staticmethod] + #[pyo3(signature = (*exprs))] + fn cube(exprs: Vec<PyExpr>) -> PyExpr { + Expr::GroupingSet(GroupingSet::Cube( + exprs.into_iter().map(|e| e.expr).collect(), + )) + .into() + } + + #[staticmethod] + #[pyo3(signature = (*expr_lists))] + fn grouping_sets(expr_lists: Vec<Vec<PyExpr>>) -> PyExpr { + Expr::GroupingSet(GroupingSet::GroupingSets( + expr_lists + .into_iter() + .map(|list| list.into_iter().map(|e| e.expr).collect()) + .collect(), + )) + .into() + } +} + impl From<PyGroupingSet> for GroupingSet { fn from(grouping_set: PyGroupingSet) -> Self { grouping_set.grouping_set diff --git a/crates/core/src/functions.rs b/crates/core/src/functions.rs index c32134054..7feb62d79 100644 --- a/crates/core/src/functions.rs +++ b/crates/core/src/functions.rs @@ -18,20 +18,14 @@ use std::collections::HashMap; use datafusion::common::{Column, ScalarValue, TableReference}; -use datafusion::execution::FunctionRegistry; -use datafusion::functions_aggregate::all_default_aggregate_functions; -use datafusion::functions_window::all_default_window_functions; -use datafusion::logical_expr::expr::{ - Alias, FieldMetadata, NullTreatment as DFNullTreatment, WindowFunction, WindowFunctionParams, -}; -use datafusion::logical_expr::{Expr, ExprFunctionExt, WindowFrame, WindowFunctionDefinition, lit}; +use datafusion::logical_expr::expr::{Alias, FieldMetadata, NullTreatment as DFNullTreatment}; +use datafusion::logical_expr::{Expr, ExprFunctionExt, lit}; use datafusion::{functions, functions_aggregate, functions_window}; use pyo3::prelude::*; use pyo3::wrap_pyfunction; use crate::common::data_type::{NullTreatment, PyScalarValue}; -use crate::context::PySessionContext; -use crate::errors::{PyDataFusionError, PyDataFusionResult}; +use crate::errors::PyDataFusionResult; use crate::expr::PyExpr; use crate::expr::conditional_expr::PyCaseBuilder; use crate::expr::sort_expr::{PySortExpr, to_sort_expressions}; @@ -93,6 +87,57 @@ fn array_cat(exprs: Vec<PyExpr>) -> PyExpr { array_concat(exprs) } +#[pyfunction] +fn array_distance(array1: PyExpr, array2: PyExpr) -> PyExpr { + let args = vec![array1.into(), array2.into()]; + Expr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf( + datafusion::functions_nested::distance::array_distance_udf(), + args, + )) + .into() +} + +#[pyfunction] +fn arrays_zip(exprs: Vec<PyExpr>) -> PyExpr { + let exprs = exprs.into_iter().map(|x| x.into()).collect(); + datafusion::functions_nested::expr_fn::arrays_zip(exprs).into() +} + +#[pyfunction] +#[pyo3(signature = (string, delimiter, null_string=None))] +fn string_to_array(string: PyExpr, delimiter: PyExpr, null_string: Option<PyExpr>) -> PyExpr { + let mut args = vec![string.into(), delimiter.into()]; + if let Some(null_string) = null_string { + args.push(null_string.into()); + } + Expr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf( + datafusion::functions_nested::string::string_to_array_udf(), + args, + )) + .into() +} + +#[pyfunction] +#[pyo3(signature = (start, stop, step=None))] +fn gen_series(start: PyExpr, stop: PyExpr, step: Option<PyExpr>) -> PyExpr { + let mut
args = vec![start.into(), stop.into()]; + if let Some(step) = step { + args.push(step.into()); + } + Expr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf( + datafusion::functions_nested::range::gen_series_udf(), + args, + )) + .into() +} + +#[pyfunction] +fn make_map(keys: Vec, values: Vec) -> PyExpr { + let keys = keys.into_iter().map(|x| x.into()).collect(); + let values = values.into_iter().map(|x| x.into()).collect(); + datafusion::functions_nested::map::map(keys, values).into() +} + #[pyfunction] #[pyo3(signature = (array, element, index=None))] fn array_position(array: PyExpr, element: PyExpr, index: Option) -> PyExpr { @@ -255,126 +300,6 @@ fn when(when: PyExpr, then: PyExpr) -> PyResult { Ok(PyCaseBuilder::new(None).when(when, then)) } -/// Helper function to find the appropriate window function. -/// -/// Search procedure: -/// 1) Search built in window functions, which are being deprecated. -/// 1) If a session context is provided: -/// 1) search User Defined Aggregate Functions (UDAFs) -/// 1) search registered window functions -/// 1) search registered aggregate functions -/// 1) If no function has been found, search default aggregate functions. -/// -/// NOTE: we search the built-ins first because the `UDAF` versions currently do not have the same behavior. -fn find_window_fn( - name: &str, - ctx: Option, -) -> PyDataFusionResult { - if let Some(ctx) = ctx { - // search UDAFs - let udaf = ctx - .ctx - .udaf(name) - .map(WindowFunctionDefinition::AggregateUDF) - .ok(); - - if let Some(udaf) = udaf { - return Ok(udaf); - } - - let session_state = ctx.ctx.state(); - - // search registered window functions - let window_fn = session_state - .window_functions() - .get(name) - .map(|f| WindowFunctionDefinition::WindowUDF(f.clone())); - - if let Some(window_fn) = window_fn { - return Ok(window_fn); - } - - // search registered aggregate functions - let agg_fn = session_state - .aggregate_functions() - .get(name) - .map(|f| WindowFunctionDefinition::AggregateUDF(f.clone())); - - if let Some(agg_fn) = agg_fn { - return Ok(agg_fn); - } - } - - // search default aggregate functions - let agg_fn = all_default_aggregate_functions() - .iter() - .find(|v| v.name() == name || v.aliases().contains(&name.to_string())) - .map(|f| WindowFunctionDefinition::AggregateUDF(f.clone())); - - if let Some(agg_fn) = agg_fn { - return Ok(agg_fn); - } - - // search default window functions - let window_fn = all_default_window_functions() - .iter() - .find(|v| v.name() == name || v.aliases().contains(&name.to_string())) - .map(|f| WindowFunctionDefinition::WindowUDF(f.clone())); - - if let Some(window_fn) = window_fn { - return Ok(window_fn); - } - - Err(PyDataFusionError::Common(format!( - "window function `{name}` not found" - ))) -} - -/// Creates a new Window function expression -#[allow(clippy::too_many_arguments)] -#[pyfunction] -#[pyo3(signature = (name, args, partition_by=None, order_by=None, window_frame=None, filter=None, distinct=false, ctx=None))] -fn window( - name: &str, - args: Vec, - partition_by: Option>, - order_by: Option>, - window_frame: Option, - filter: Option, - distinct: bool, - ctx: Option, -) -> PyResult { - let fun = find_window_fn(name, ctx)?; - - let window_frame = window_frame - .map(|w| w.into()) - .unwrap_or(WindowFrame::new(order_by.as_ref().map(|v| !v.is_empty()))); - let filter = filter.map(|f| f.expr.into()); - - Ok(PyExpr { - expr: datafusion::logical_expr::Expr::WindowFunction(Box::new(WindowFunction { - fun, - params: WindowFunctionParams { - 
args: args.into_iter().map(|x| x.expr).collect::>(), - partition_by: partition_by - .unwrap_or_default() - .into_iter() - .map(|x| x.expr) - .collect::>(), - order_by: order_by - .unwrap_or_default() - .into_iter() - .map(|x| x.into()) - .collect::>(), - window_frame, - filter, - distinct, - null_treatment: None, - }, - })), - }) -} - // Generates a [pyo3] wrapper for associated aggregate functions. // All of the builder options are exposed to the python internal // function and we rely on the wrappers to only use those that @@ -494,6 +419,13 @@ expr_fn!(length, string); expr_fn!(char_length, string); expr_fn!(chr, arg, "Returns the character with the given code."); expr_fn_vec!(coalesce); +expr_fn_vec!(greatest); +expr_fn_vec!(least); +expr_fn!( + contains, + string search_str, + "Return true if search_str is found within string (case-sensitive)." +); expr_fn!(cos, num); expr_fn!(cosh, num); expr_fn!(cot, num); @@ -543,6 +475,11 @@ expr_fn!( x y, "Returns x if x is not NULL otherwise returns y." ); +expr_fn!( + nvl2, + x y z, + "Returns y if x is not NULL; otherwise returns z." +); expr_fn!(nullif, arg_1 arg_2); expr_fn!( octet_length, @@ -616,6 +553,7 @@ expr_fn!(date_part, part date); expr_fn!(date_trunc, part date); expr_fn!(date_bin, stride source origin); expr_fn!(make_date, year month day); +expr_fn!(make_time, hour minute second); expr_fn!(to_char, datetime format); expr_fn!(translate, string from to, "Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted."); @@ -631,8 +569,29 @@ expr_fn_vec!(named_struct); expr_fn!(from_unixtime, unixtime); expr_fn!(arrow_typeof, arg_1); expr_fn!(arrow_cast, arg_1 datatype); +expr_fn_vec!(arrow_metadata); +expr_fn!(union_tag, arg1); expr_fn!(random); +#[pyfunction] +fn get_field(expr: PyExpr, name: PyExpr) -> PyExpr { + functions::core::get_field() + .call(vec![expr.into(), name.into()]) + .into() +} + +#[pyfunction] +fn union_extract(union_expr: PyExpr, field_name: PyExpr) -> PyExpr { + functions::core::union_extract() + .call(vec![union_expr.into(), field_name.into()]) + .into() +} + +#[pyfunction] +fn version() -> PyExpr { + functions::core::version().call(vec![]).into() +} + // Array Functions array_fn!(array_append, array element); array_fn!(array_to_string, array delimiter); @@ -661,10 +620,20 @@ array_fn!(array_intersect, first_array second_array); array_fn!(array_union, array1 array2); array_fn!(array_except, first_array second_array); array_fn!(array_resize, array size value); +array_fn!(array_any_value, array); +array_fn!(array_max, array); +array_fn!(array_min, array); +array_fn!(array_reverse, array); array_fn!(cardinality, array); array_fn!(flatten, array); array_fn!(range, start stop step); +// Map Functions +array_fn!(map_keys, map); +array_fn!(map_values, map); +array_fn!(map_extract, map key); +array_fn!(map_entries, map); + aggregate_function!(array_agg); aggregate_function!(max); aggregate_function!(min); @@ -696,9 +665,10 @@ aggregate_function!(var_pop); aggregate_function!(approx_distinct); aggregate_function!(approx_median); -// Code is commented out since grouping is not yet implemented -// https://github.com/apache/datafusion-python/issues/861 -// aggregate_function!(grouping); +// The grouping function's physical plan is not implemented, but the +// ResolveGroupingFunction analyzer rule rewrites it before the physical +// planner sees it, so it works correctly at runtime. 
+aggregate_function!(grouping); #[pyfunction] #[pyo3(signature = (sort_expression, percentile, num_centroids=None, filter=None))] @@ -736,6 +706,19 @@ pub fn approx_percentile_cont_with_weight( add_builder_fns_to_aggregate(agg_fn, None, filter, None, None) } +#[pyfunction] +#[pyo3(signature = (sort_expression, percentile, filter=None))] +pub fn percentile_cont( + sort_expression: PySortExpr, + percentile: f64, + filter: Option, +) -> PyDataFusionResult { + let agg_fn = + functions_aggregate::expr_fn::percentile_cont(sort_expression.sort, lit(percentile)); + + add_builder_fns_to_aggregate(agg_fn, None, filter, None, None) +} + // We handle last_value explicitly because the signature expects an order_by // https://github.com/apache/datafusion/issues/12376 #[pyfunction] @@ -936,10 +919,12 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(approx_median))?; m.add_wrapped(wrap_pyfunction!(approx_percentile_cont))?; m.add_wrapped(wrap_pyfunction!(approx_percentile_cont_with_weight))?; + m.add_wrapped(wrap_pyfunction!(percentile_cont))?; m.add_wrapped(wrap_pyfunction!(range))?; m.add_wrapped(wrap_pyfunction!(array_agg))?; m.add_wrapped(wrap_pyfunction!(arrow_typeof))?; m.add_wrapped(wrap_pyfunction!(arrow_cast))?; + m.add_wrapped(wrap_pyfunction!(arrow_metadata))?; m.add_wrapped(wrap_pyfunction!(ascii))?; m.add_wrapped(wrap_pyfunction!(asin))?; m.add_wrapped(wrap_pyfunction!(asinh))?; @@ -960,6 +945,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(col))?; m.add_wrapped(wrap_pyfunction!(concat_ws))?; m.add_wrapped(wrap_pyfunction!(concat))?; + m.add_wrapped(wrap_pyfunction!(contains))?; m.add_wrapped(wrap_pyfunction!(corr))?; m.add_wrapped(wrap_pyfunction!(cos))?; m.add_wrapped(wrap_pyfunction!(cosh))?; @@ -974,6 +960,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(date_part))?; m.add_wrapped(wrap_pyfunction!(date_trunc))?; m.add_wrapped(wrap_pyfunction!(make_date))?; + m.add_wrapped(wrap_pyfunction!(make_time))?; m.add_wrapped(wrap_pyfunction!(digest))?; m.add_wrapped(wrap_pyfunction!(ends_with))?; m.add_wrapped(wrap_pyfunction!(exp))?; @@ -981,13 +968,15 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(floor))?; m.add_wrapped(wrap_pyfunction!(from_unixtime))?; m.add_wrapped(wrap_pyfunction!(gcd))?; - // m.add_wrapped(wrap_pyfunction!(grouping))?; + m.add_wrapped(wrap_pyfunction!(greatest))?; + m.add_wrapped(wrap_pyfunction!(grouping))?; m.add_wrapped(wrap_pyfunction!(in_list))?; m.add_wrapped(wrap_pyfunction!(initcap))?; m.add_wrapped(wrap_pyfunction!(isnan))?; m.add_wrapped(wrap_pyfunction!(iszero))?; m.add_wrapped(wrap_pyfunction!(levenshtein))?; m.add_wrapped(wrap_pyfunction!(lcm))?; + m.add_wrapped(wrap_pyfunction!(least))?; m.add_wrapped(wrap_pyfunction!(left))?; m.add_wrapped(wrap_pyfunction!(length))?; m.add_wrapped(wrap_pyfunction!(ln))?; @@ -1005,6 +994,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(named_struct))?; m.add_wrapped(wrap_pyfunction!(nanvl))?; m.add_wrapped(wrap_pyfunction!(nvl))?; + m.add_wrapped(wrap_pyfunction!(nvl2))?; m.add_wrapped(wrap_pyfunction!(now))?; m.add_wrapped(wrap_pyfunction!(nullif))?; m.add_wrapped(wrap_pyfunction!(octet_length))?; @@ -1063,10 +1053,13 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(trim))?; 
m.add_wrapped(wrap_pyfunction!(trunc))?; m.add_wrapped(wrap_pyfunction!(upper))?; + m.add_wrapped(wrap_pyfunction!(get_field))?; + m.add_wrapped(wrap_pyfunction!(union_extract))?; + m.add_wrapped(wrap_pyfunction!(union_tag))?; + m.add_wrapped(wrap_pyfunction!(version))?; m.add_wrapped(wrap_pyfunction!(self::uuid))?; // Use self to avoid name collision m.add_wrapped(wrap_pyfunction!(var_pop))?; m.add_wrapped(wrap_pyfunction!(var_sample))?; - m.add_wrapped(wrap_pyfunction!(window))?; m.add_wrapped(wrap_pyfunction!(regr_avgx))?; m.add_wrapped(wrap_pyfunction!(regr_avgy))?; m.add_wrapped(wrap_pyfunction!(regr_count))?; @@ -1121,9 +1114,24 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(array_replace_all))?; m.add_wrapped(wrap_pyfunction!(array_sort))?; m.add_wrapped(wrap_pyfunction!(array_slice))?; + m.add_wrapped(wrap_pyfunction!(array_any_value))?; + m.add_wrapped(wrap_pyfunction!(array_distance))?; + m.add_wrapped(wrap_pyfunction!(array_max))?; + m.add_wrapped(wrap_pyfunction!(array_min))?; + m.add_wrapped(wrap_pyfunction!(array_reverse))?; + m.add_wrapped(wrap_pyfunction!(arrays_zip))?; + m.add_wrapped(wrap_pyfunction!(string_to_array))?; + m.add_wrapped(wrap_pyfunction!(gen_series))?; m.add_wrapped(wrap_pyfunction!(flatten))?; m.add_wrapped(wrap_pyfunction!(cardinality))?; + // Map Functions + m.add_wrapped(wrap_pyfunction!(make_map))?; + m.add_wrapped(wrap_pyfunction!(map_keys))?; + m.add_wrapped(wrap_pyfunction!(map_values))?; + m.add_wrapped(wrap_pyfunction!(map_extract))?; + m.add_wrapped(wrap_pyfunction!(map_entries))?; + // Window Functions m.add_wrapped(wrap_pyfunction!(lead))?; m.add_wrapped(wrap_pyfunction!(lag))?; diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index fc2d006d3..77d69911a 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -43,6 +43,7 @@ pub mod errors; pub mod expr; #[allow(clippy::borrow_deref_ref)] mod functions; +pub mod metrics; mod options; pub mod physical_plan; mod pyarrow_filter_expression; @@ -92,6 +93,8 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/crates/core/src/metrics.rs b/crates/core/src/metrics.rs new file mode 100644 index 000000000..ee0937e25 --- /dev/null +++ b/crates/core/src/metrics.rs @@ -0,0 +1,169 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::collections::HashMap; +use std::sync::Arc; + +use chrono::{Datelike, Timelike}; +use datafusion::physical_plan::metrics::{Metric, MetricValue, MetricsSet, Timestamp}; +use pyo3::prelude::*; + +#[pyclass(from_py_object, frozen, name = "MetricsSet", module = "datafusion")] +#[derive(Debug, Clone)] +pub struct PyMetricsSet { + metrics: MetricsSet, +} + +impl PyMetricsSet { + pub fn new(metrics: MetricsSet) -> Self { + Self { metrics } + } +} + +#[pymethods] +impl PyMetricsSet { + fn metrics(&self) -> Vec<PyMetric> { + self.metrics + .iter() + .map(|m| PyMetric::new(Arc::clone(m))) + .collect() + } + + fn output_rows(&self) -> Option<usize> { + self.metrics.output_rows() + } + + fn elapsed_compute(&self) -> Option<usize> { + self.metrics.elapsed_compute() + } + + fn spill_count(&self) -> Option<usize> { + self.metrics.spill_count() + } + + fn spilled_bytes(&self) -> Option<usize> { + self.metrics.spilled_bytes() + } + + fn spilled_rows(&self) -> Option<usize> { + self.metrics.spilled_rows() + } + + fn sum_by_name(&self, name: &str) -> Option<usize> { + self.metrics.sum_by_name(name).map(|v| v.as_usize()) + } + + fn __repr__(&self) -> String { + format!("{}", self.metrics) + } +} + +#[pyclass(from_py_object, frozen, name = "Metric", module = "datafusion")] +#[derive(Debug, Clone)] +pub struct PyMetric { + metric: Arc<Metric>, +} + +impl PyMetric { + pub fn new(metric: Arc<Metric>) -> Self { + Self { metric } + } + + fn timestamp_to_pyobject<'py>( + py: Python<'py>, + ts: &Timestamp, + ) -> PyResult<Option<Bound<'py, PyAny>>> { + match ts.value() { + Some(dt) => { + let datetime_mod = py.import("datetime")?; + let datetime_cls = datetime_mod.getattr("datetime")?; + let tz_utc = datetime_mod.getattr("timezone")?.getattr("utc")?; + let result = datetime_cls.call1(( + dt.year(), + dt.month(), + dt.day(), + dt.hour(), + dt.minute(), + dt.second(), + dt.timestamp_subsec_micros(), + tz_utc, + ))?; + Ok(Some(result)) + } + None => Ok(None), + } + } +} + +#[pymethods] +impl PyMetric { + #[getter] + fn name(&self) -> String { + self.metric.value().name().to_string() + } + + #[getter] + fn value<'py>(&self, py: Python<'py>) -> PyResult<Option<Bound<'py, PyAny>>> { + match self.metric.value() { + MetricValue::OutputRows(c) => Ok(Some(c.value().into_pyobject(py)?.into_any())), + MetricValue::OutputBytes(c) => Ok(Some(c.value().into_pyobject(py)?.into_any())), + MetricValue::ElapsedCompute(t) => Ok(Some(t.value().into_pyobject(py)?.into_any())), + MetricValue::SpillCount(c) => Ok(Some(c.value().into_pyobject(py)?.into_any())), + MetricValue::SpilledBytes(c) => Ok(Some(c.value().into_pyobject(py)?.into_any())), + MetricValue::SpilledRows(c) => Ok(Some(c.value().into_pyobject(py)?.into_any())), + MetricValue::CurrentMemoryUsage(g) => Ok(Some(g.value().into_pyobject(py)?.into_any())), + MetricValue::Count { count, .. } => { + Ok(Some(count.value().into_pyobject(py)?.into_any())) + } + MetricValue::Gauge { gauge, .. } => { + Ok(Some(gauge.value().into_pyobject(py)?.into_any())) + } + MetricValue::Time { time, ..
} => Ok(Some(time.value().into_pyobject(py)?.into_any())), + MetricValue::StartTimestamp(ts) | MetricValue::EndTimestamp(ts) => { + Self::timestamp_to_pyobject(py, ts) + } + _ => Ok(None), + } + } + + #[getter] + fn value_as_datetime<'py>(&self, py: Python<'py>) -> PyResult<Option<Bound<'py, PyAny>>> { + match self.metric.value() { + MetricValue::StartTimestamp(ts) | MetricValue::EndTimestamp(ts) => { + Self::timestamp_to_pyobject(py, ts) + } + _ => Ok(None), + } + } + + #[getter] + fn partition(&self) -> Option<usize> { + self.metric.partition() + } + + fn labels(&self) -> HashMap<String, String> { + self.metric + .labels() + .iter() + .map(|l| (l.name().to_string(), l.value().to_string())) + .collect() + } + + fn __repr__(&self) -> String { + format!("{}", self.metric.value()) + } +} diff --git a/crates/core/src/physical_plan.rs b/crates/core/src/physical_plan.rs index 8674a8b55..fac973884 100644 --- a/crates/core/src/physical_plan.rs +++ b/crates/core/src/physical_plan.rs @@ -26,6 +26,7 @@ use pyo3::types::PyBytes; use crate::context::PySessionContext; use crate::errors::PyDataFusionResult; +use crate::metrics::PyMetricsSet; #[pyclass( from_py_object, @@ -96,6 +97,10 @@ impl PyExecutionPlan { Ok(Self::new(plan)) } + pub fn metrics(&self) -> Option<PyMetricsSet> { + self.plan.metrics().map(PyMetricsSet::new) + } + fn __repr__(&self) -> String { self.display_indent() } diff --git a/dev/changelog/53.0.0.md b/dev/changelog/53.0.0.md new file mode 100644 index 000000000..3e27a852d --- /dev/null +++ b/dev/changelog/53.0.0.md @@ -0,0 +1,107 @@ + + +# Apache DataFusion Python 53.0.0 Changelog + +This release consists of 52 commits from 9 contributors. See credits at the end of this changelog for more information. + +**Breaking changes:** + +- minor: remove deprecated interfaces [#1481](https://github.com/apache/datafusion-python/pull/1481) (timsaucer) + +**Implemented enhancements:** + +- feat: feat: add to_time, to_local_time, to_date functions [#1387](https://github.com/apache/datafusion-python/pull/1387) (mesejo) +- feat: Add FFI_TableProviderFactory support [#1396](https://github.com/apache/datafusion-python/pull/1396) (davisp) + +**Fixed bugs:** + +- fix: satisfy rustfmt check in lib.rs re-exports [#1406](https://github.com/apache/datafusion-python/pull/1406) (kevinjqliu) + +**Documentation updates:** + +- docs: clarify DataFusion 52 FFI session-parameter requirement for provider hooks [#1439](https://github.com/apache/datafusion-python/pull/1439) (kevinjqliu) + +**Other:** + +- Merge release 52.0.0 into main [#1389](https://github.com/apache/datafusion-python/pull/1389) (timsaucer) + +- Add workflow to verify release candidate on multiple systems [#1388](https://github.com/apache/datafusion-python/pull/1388) (timsaucer) +- Allow running "verify release candidate" github workflow on Windows [#1392](https://github.com/apache/datafusion-python/pull/1392) (kevinjqliu) +- ci: update pre-commit hooks, fix linting, and refresh dependencies [#1385](https://github.com/apache/datafusion-python/pull/1385) (dariocurr) +- Add CI check for crates.io patches [#1407](https://github.com/apache/datafusion-python/pull/1407) (timsaucer) +- Enable doc tests in local and CI testing [#1409](https://github.com/apache/datafusion-python/pull/1409) (ntjohnson1) +- Upgrade to DataFusion 53 [#1402](https://github.com/apache/datafusion-python/pull/1402) (nuno-faria) +- Catch warnings in FFI unit tests [#1410](https://github.com/apache/datafusion-python/pull/1410) (timsaucer) +- Add docstring examples for Scalar trigonometric functions
[#1411](https://github.com/apache/datafusion-python/pull/1411) (ntjohnson1) +- Create workspace with core and util crates [#1414](https://github.com/apache/datafusion-python/pull/1414) (timsaucer) +- Add docstring examples for Scalar regex, crypto, struct and other [#1422](https://github.com/apache/datafusion-python/pull/1422) (ntjohnson1) +- Add docstring examples for Scalar math functions [#1421](https://github.com/apache/datafusion-python/pull/1421) (ntjohnson1) +- Add docstring examples for Common utility functions [#1419](https://github.com/apache/datafusion-python/pull/1419) (ntjohnson1) +- Add docstring examples for Aggregate basic and bitwise/boolean functions [#1416](https://github.com/apache/datafusion-python/pull/1416) (ntjohnson1) +- Fix CI errors on main [#1432](https://github.com/apache/datafusion-python/pull/1432) (timsaucer) +- Add docstring examples for Scalar temporal functions [#1424](https://github.com/apache/datafusion-python/pull/1424) (ntjohnson1) +- Add docstring examples for Aggregate statistical and regression functions [#1417](https://github.com/apache/datafusion-python/pull/1417) (ntjohnson1) +- Add docstring examples for Scalar array/list functions [#1420](https://github.com/apache/datafusion-python/pull/1420) (ntjohnson1) +- Add docstring examples for Scalar string functions [#1423](https://github.com/apache/datafusion-python/pull/1423) (ntjohnson1) +- Add docstring examples for Aggregate window functions [#1418](https://github.com/apache/datafusion-python/pull/1418) (ntjohnson1) +- ci: pin third-party actions to Apache-approved SHAs [#1438](https://github.com/apache/datafusion-python/pull/1438) (kevinjqliu) +- minor: bump datafusion to release version [#1441](https://github.com/apache/datafusion-python/pull/1441) (timsaucer) +- ci: add swap during build, use tpchgen-cli [#1443](https://github.com/apache/datafusion-python/pull/1443) (timsaucer) +- Update remaining existing examples to make testable/standalone executable [#1437](https://github.com/apache/datafusion-python/pull/1437) (ntjohnson1) +- Do not run validate_pycapsule if pointer_checked is used [#1426](https://github.com/apache/datafusion-python/pull/1426) (Tpt) +- Implement configuration extension support [#1391](https://github.com/apache/datafusion-python/pull/1391) (timsaucer) +- Add a working, more complete example of using a catalog (docs) [#1427](https://github.com/apache/datafusion-python/pull/1427) (toppyy) +- chore: update dependencies [#1447](https://github.com/apache/datafusion-python/pull/1447) (timsaucer) +- Complete doc string examples for functions.py [#1435](https://github.com/apache/datafusion-python/pull/1435) (ntjohnson1) +- chore: enforce uv lockfile consistency in CI and pre-commit [#1398](https://github.com/apache/datafusion-python/pull/1398) (mesejo) +- CI: Add CodeQL workflow for GitHub Actions security scanning [#1408](https://github.com/apache/datafusion-python/pull/1408) (kevinjqliu) +- ci: update codespell paths [#1469](https://github.com/apache/datafusion-python/pull/1469) (timsaucer) +- Add missing datetime functions [#1467](https://github.com/apache/datafusion-python/pull/1467) (timsaucer) +- Add AI skill to check current repository against upstream APIs [#1460](https://github.com/apache/datafusion-python/pull/1460) (timsaucer) +- Add missing string function `contains` [#1465](https://github.com/apache/datafusion-python/pull/1465) (timsaucer) +- Add missing conditional functions [#1464](https://github.com/apache/datafusion-python/pull/1464) (timsaucer) +- Reduce peak 
memory usage during release builds to fix OOM on manylinux runners [#1445](https://github.com/apache/datafusion-python/pull/1445) (kevinjqliu) +- Add missing map functions [#1461](https://github.com/apache/datafusion-python/pull/1461) (timsaucer) +- minor: Fix pytest instructions in the README [#1477](https://github.com/apache/datafusion-python/pull/1477) (nuno-faria) +- Add missing array functions [#1468](https://github.com/apache/datafusion-python/pull/1468) (timsaucer) +- Add missing scalar functions [#1470](https://github.com/apache/datafusion-python/pull/1470) (timsaucer) +- Add missing aggregate functions [#1471](https://github.com/apache/datafusion-python/pull/1471) (timsaucer) +- Add missing Dataframe functions [#1472](https://github.com/apache/datafusion-python/pull/1472) (timsaucer) +- Add missing deregister methods to SessionContext [#1473](https://github.com/apache/datafusion-python/pull/1473) (timsaucer) +- Add missing registration methods [#1474](https://github.com/apache/datafusion-python/pull/1474) (timsaucer) +- Add missing SessionContext utility methods [#1475](https://github.com/apache/datafusion-python/pull/1475) (timsaucer) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 25 Tim Saucer + 13 Nick + 6 Kevin Liu + 2 Daniel Mesejo + 2 Nuno Faria + 1 Paul J. Davis + 1 Thomas Tanon + 1 Topias Pyykkönen + 1 dario curreri +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. + diff --git a/dev/release/README.md b/dev/release/README.md index ed28f4aa6..4833be55a 100644 --- a/dev/release/README.md +++ b/dev/release/README.md @@ -26,11 +26,11 @@ required due to changes in DataFusion rather than having a large amount of work is available. When there is a new official release of DataFusion, we update the `main` branch to point to that, update the version -number, and create a new release branch, such as `branch-0.8`. Once this branch is created, we switch the `main` branch +number, and create a new release branch, such as `branch-53`. Once this branch is created, we switch the `main` branch back to using GitHub dependencies. The release activity (such as generating the changelog) can then happen on the release branch without blocking ongoing development in the `main` branch. -We can cherry-pick commits from the `main` branch into `branch-0.8` as needed and then create new patch releases +We can cherry-pick commits from the `main` branch into `branch-53` as needed and then create new patch releases from that branch. ## Detailed Guide @@ -54,7 +54,8 @@ Before creating a new release: - We need to ensure that the main branch does not have any GitHub dependencies - a PR should be created and merged to update the major version number of the project -- A new release branch should be created, such as `branch-0.8` +- A new release branch should be created, such as `branch-53` +- It is best to push this branch to the apache repository rather than a personal fork in case patch releases are required.
## Preparing a Release Candidate @@ -65,14 +66,14 @@ We maintain a `CHANGELOG.md` so our users know what has been changed between rel The changelog is generated using a Python script: ```bash -$ GITHUB_TOKEN=<TOKEN> ./dev/release/generate-changelog.py 24.0.0 HEAD 25.0.0 > dev/changelog/25.0.0.md +$ GITHUB_TOKEN=<TOKEN> ./dev/release/generate-changelog.py 52.0.0 HEAD 53.0.0 > dev/changelog/53.0.0.md ``` This script creates a changelog from GitHub PRs based on the labels associated with them as well as looking for titles starting with `feat:`, `fix:`, or `docs:` . The script will produce output similar to: ``` -Fetching list of commits between 24.0.0 and HEAD +Fetching list of commits between 52.0.0 and HEAD Fetching pull requests Categorizing pull requests Generating changelog content @@ -81,6 +82,7 @@ Generating changelog content ### Update the version number The only place you should need to update the version is in the root `Cargo.toml`. +You will need to update this both in the workspace section and also in the dependencies. After updating the toml file, run `cargo update` to update the cargo lock file. If you do not want to update all the dependencies, you can instead run `cargo build` which should only update the version number for `datafusion-python`. @@ -94,14 +96,14 @@ you need to push a tag to start the CI process for release candidates. The follo the upstream repository is called `apache`. ```bash -git tag 0.8.0-rc1 -git push apache 0.8.0-rc1 +git tag 53.0.0-rc1 +git push apache 53.0.0-rc1 ``` ### Create a source release ```bash -./dev/release/create-tarball.sh 0.8.0 1 +./dev/release/create-tarball.sh 53.0.0 1 ``` This will also create the email template to send to the mailing list. @@ -124,10 +126,10 @@ Click on the action and scroll down to the bottom of the page titled "Artifacts" contain files such as: ```text -datafusion-22.0.0-cp37-abi3-macosx_10_7_x86_64.whl -datafusion-22.0.0-cp37-abi3-macosx_11_0_arm64.whl -datafusion-22.0.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -datafusion-22.0.0-cp37-abi3-win_amd64.whl +datafusion-53.0.0-cp37-abi3-macosx_10_7_x86_64.whl +datafusion-53.0.0-cp37-abi3-macosx_11_0_arm64.whl +datafusion-53.0.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl +datafusion-53.0.0-cp37-abi3-win_amd64.whl ``` Upload the wheels to testpypi. @@ -135,23 +137,23 @@ Upload the wheels to testpypi. ```bash unzip dist.zip python3 -m pip install --upgrade setuptools twine build -python3 -m twine upload --repository testpypi datafusion-22.0.0-cp37-abi3-*.whl +python3 -m twine upload --repository testpypi datafusion-53.0.0-cp37-abi3-*.whl ``` When prompted for username, enter `__token__`. When prompted for a password, enter a valid GitHub Personal Access Token #### Publish Python Source Distribution to testpypi -Download the source tarball created in the previous step, untar it, and run: +Download the source tarball from the Apache server created in the previous step, untar it, and run: ```bash maturin sdist ``` -This will create a file named `dist/datafusion-0.7.0.tar.gz`. Upload this to testpypi: +This will create a file named `dist/datafusion-53.0.0.tar.gz`. Upload this to testpypi: ```bash -python3 -m twine upload --repository testpypi dist/datafusion-0.7.0.tar.gz +python3 -m twine upload --repository testpypi dist/datafusion-53.0.0.tar.gz ``` ### Run Verify Release Candidate Workflow @@ -162,8 +164,8 @@ Before sending the vote email, run the manually triggered GitHub Actions workflo 1.
Go to https://github.com/apache/datafusion-python/actions/workflows/verify-release-candidate.yml 2. Click "Run workflow" -3. Set `version` to the release version (for example, `52.0.0`) -4. Set `rc_number` to the RC number (for example, `0`) +3. Set `version` to the release version (for example, `53.0.0`) +4. Set `rc_number` to the RC number (for example, `1`) 5. Wait for all jobs to complete successfully Include a short note in the vote email template that this workflow was run across all OS/architecture @@ -183,7 +185,7 @@ Releases may be verified using `verify-release-candidate.sh`: ```bash git clone https://github.com/apache/datafusion-python.git -dev/release/verify-release-candidate.sh 48.0.0 1 +dev/release/verify-release-candidate.sh 53.0.0 1 ``` Alternatively, one can run unit tests against a testpypi release candidate: @@ -195,7 +197,7 @@ cd datafusion-python # checkout the release commit git fetch --tags -git checkout 40.0.0-rc1 +git checkout 53.0.0-rc1 git submodule update --init --recursive # create the env @@ -203,7 +205,7 @@ python3 -m venv .venv source .venv/bin/activate # install release candidate -pip install --extra-index-url https://test.pypi.org/simple/ datafusion==40.0.0 +pip install --extra-index-url https://test.pypi.org/simple/ datafusion==53.0.0 # install test dependencies pip install pytest numpy pytest-asyncio @@ -224,7 +226,7 @@ Once the vote passes, we can publish the release. Create the source release tarball: ```bash -./dev/release/release-tarball.sh 0.8.0 1 +./dev/release/release-tarball.sh 53.0.0 1 ``` ### Publishing Rust Crate to crates.io @@ -232,7 +234,7 @@ Create the source release tarball: Some projects depend on the Rust crate directly, so we publish this to crates.io ```shell -cargo publish +cargo publish --workspace ``` ### Publishing Python Artifacts to PyPi @@ -252,15 +254,15 @@ Pypi packages auto upload to conda-forge via [datafusion feedstock](https://gith ### Push the Release Tag ```bash -git checkout 0.8.0-rc1 -git tag 0.8.0 -git push apache 0.8.0 +git checkout 53.0.0-rc1 +git tag 53.0.0 +git push apache 53.0.0 ``` ### Add the release to Apache Reporter Add the release to https://reporter.apache.org/addrelease.html?datafusion with a version name prefixed with `DATAFUSION-PYTHON`, -for example `DATAFUSION-PYTHON-31.0.0`. +for example `DATAFUSION-PYTHON-53.0.0`. The release information is used to generate a template for a board report (see example from Apache Arrow [here](https://github.com/apache/arrow/pull/14357)). 
@@ -283,7 +285,7 @@ svn ls https://dist.apache.org/repos/dist/dev/datafusion | grep datafusion-pytho Delete a release candidate: ```bash -svn delete -m "delete old DataFusion RC" https://dist.apache.org/repos/dist/dev/datafusion/apache-datafusion-python-7.1.0-rc1/ +svn delete -m "delete old DataFusion RC" https://dist.apache.org/repos/dist/dev/datafusion/apache-datafusion-python-53.0.0-rc1/ ``` #### Deleting old releases from `release` svn @@ -299,5 +301,5 @@ svn ls https://dist.apache.org/repos/dist/release/datafusion | grep datafusion-p Delete a release: ```bash -svn delete -m "delete old DataFusion release" https://dist.apache.org/repos/dist/release/datafusion/datafusion-python-7.0.0 +svn delete -m "delete old DataFusion release" https://dist.apache.org/repos/dist/release/datafusion/datafusion-python-52.0.0 ``` diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index b2db144e8..a7a497dab 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -48,4 +48,5 @@ benchmarks/tpch/create_tables.sql .cargo/config.toml **/.cargo/config.toml uv.lock -examples/tpch/answers_sf1/*.tbl \ No newline at end of file +examples/tpch/answers_sf1/*.tbl +SKILL.md \ No newline at end of file diff --git a/docs/source/user-guide/common-operations/aggregations.rst b/docs/source/user-guide/common-operations/aggregations.rst index e458e5fcb..de24a2ba5 100644 --- a/docs/source/user-guide/common-operations/aggregations.rst +++ b/docs/source/user-guide/common-operations/aggregations.rst @@ -163,6 +163,168 @@ Suppose we want to find the speed values for only Pokemon that have low Attack v f.avg(col_speed, filter=col_attack < lit(50)).alias("Avg Speed Low Attack")]) +Grouping Sets +------------- + +The default style of aggregation produces one row per group. Sometimes you want a single query to +produce rows at multiple levels of detail — for example, totals per type *and* an overall grand +total, or subtotals for every combination of two columns plus the individual column totals. Writing +separate queries and concatenating them is tedious and runs the data multiple times. Grouping sets +solve this by letting you specify several grouping levels in one pass. + +DataFusion supports three grouping set styles through the +:py:class:`~datafusion.expr.GroupingSet` class: + +- :py:meth:`~datafusion.expr.GroupingSet.rollup` — hierarchical subtotals, like a drill-down report +- :py:meth:`~datafusion.expr.GroupingSet.cube` — every possible subtotal combination, like a pivot table +- :py:meth:`~datafusion.expr.GroupingSet.grouping_sets` — explicitly list exactly which grouping levels you want + +Because result rows come from different grouping levels, a column that is *not* part of a +particular level will be ``null`` in that row. Use :py:func:`~datafusion.functions.grouping` to +distinguish a real ``null`` in the data from one that means "this column was aggregated across." +It returns ``0`` when the column is a grouping key for that row, and ``1`` when it is not. + +Rollup +^^^^^^ + +:py:meth:`~datafusion.expr.GroupingSet.rollup` creates a hierarchy. ``rollup(a, b)`` produces +grouping sets ``(a, b)``, ``(a)``, and ``()`` — like nested subtotals in a report. This is useful +when your columns have a natural hierarchy, such as region → city or type → subtype. + +Suppose we want to summarize Pokemon stats by ``Type 1`` with subtotals and a grand total. With +the default aggregation style we would need two separate queries. 
With ``rollup`` we get it all at +once: + +.. ipython:: python + + from datafusion.expr import GroupingSet + + df.aggregate( + [GroupingSet.rollup(col_type_1)], + [f.count(col_speed).alias("Count"), + f.avg(col_speed).alias("Avg Speed"), + f.max(col_speed).alias("Max Speed")] + ).sort(col_type_1.sort(ascending=True, nulls_first=True)) + +The first row — where ``Type 1`` is ``null`` — is the grand total across all types. But how do you +tell a grand-total ``null`` apart from a Pokemon that genuinely has no type? The +:py:func:`~datafusion.functions.grouping` function returns ``0`` when the column is a grouping key +for that row and ``1`` when it is aggregated across. + +.. note:: + + Due to an upstream DataFusion limitation + (`apache/datafusion#21411 <https://github.com/apache/datafusion/issues/21411>`_), + ``.alias()`` cannot be applied directly to a ``grouping()`` expression — it will raise an + error at execution time. Instead, use + :py:meth:`~datafusion.dataframe.DataFrame.with_column_renamed` on the result DataFrame to + give the column a readable name. Once the upstream issue is resolved, you will be able to + use ``.alias()`` directly and the workaround below will no longer be necessary. + +The raw column name generated by ``grouping()`` contains internal identifiers, so we use +:py:meth:`~datafusion.dataframe.DataFrame.with_column_renamed` to clean it up: + +.. ipython:: python + + result = df.aggregate( + [GroupingSet.rollup(col_type_1)], + [f.count(col_speed).alias("Count"), + f.avg(col_speed).alias("Avg Speed"), + f.grouping(col_type_1)] + ) + for field in result.schema(): + if field.name.startswith("grouping("): + result = result.with_column_renamed(field.name, "Is Total") + result.sort(col_type_1.sort(ascending=True, nulls_first=True)) + +With two columns the hierarchy becomes more apparent. ``rollup(Type 1, Type 2)`` produces: + +- one row per ``(Type 1, Type 2)`` pair — the most detailed level +- one row per ``Type 1`` — subtotals +- one grand total row + +.. ipython:: python + + df.aggregate( + [GroupingSet.rollup(col_type_1, col_type_2)], + [f.count(col_speed).alias("Count"), + f.avg(col_speed).alias("Avg Speed")] + ).sort( + col_type_1.sort(ascending=True, nulls_first=True), + col_type_2.sort(ascending=True, nulls_first=True) + ) + +Cube +^^^^ + +:py:meth:`~datafusion.expr.GroupingSet.cube` produces every possible subset. ``cube(a, b)`` +produces grouping sets ``(a, b)``, ``(a)``, ``(b)``, and ``()`` — one more than ``rollup`` because +it also includes ``(b)`` alone. This is useful when neither column is "above" the other in a +hierarchy and you want all cross-tabulations. + +For our Pokemon data, ``cube(Type 1, Type 2)`` gives us stats broken down by the type pair, +by ``Type 1`` alone, by ``Type 2`` alone, and a grand total — all in one query: + +.. ipython:: python + + df.aggregate( + [GroupingSet.cube(col_type_1, col_type_2)], + [f.count(col_speed).alias("Count"), + f.avg(col_speed).alias("Avg Speed")] + ).sort( + col_type_1.sort(ascending=True, nulls_first=True), + col_type_2.sort(ascending=True, nulls_first=True) + ) + +Compared to the ``rollup`` example above, notice the extra rows where ``Type 1`` is ``null`` but +``Type 2`` has a value — those are the per-``Type 2`` subtotals that ``rollup`` does not include. + +Explicit Grouping Sets +^^^^^^^^^^^^^^^^^^^^^^ + +:py:meth:`~datafusion.expr.GroupingSet.grouping_sets` lets you list exactly which grouping levels +you need when ``rollup`` or ``cube`` would produce too many or too few. Each argument is a list of +columns forming one grouping set.
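+
+For readers who think in SQL, ``grouping_sets`` corresponds directly to the
+``GROUP BY GROUPING SETS`` clause. The following is a rough sketch of that
+correspondence; the table name ``pokemon`` and the session ``ctx`` are
+illustrative assumptions, not objects created earlier on this page:
+
+.. code-block:: python
+
+    # Roughly the SQL form of GroupingSet.grouping_sets([col_type_1], [col_type_2]);
+    # "pokemon" and ctx are assumed registrations, not defined in this guide.
+    ctx.sql(
+        """
+        SELECT "Type 1", "Type 2", count("Speed") AS "Count"
+        FROM pokemon
+        GROUP BY GROUPING SETS (("Type 1"), ("Type 2"))
+        """
+    )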
+ +For example, if we want only the per-``Type 1`` totals and per-``Type 2`` totals — but *not* the +full ``(Type 1, Type 2)`` detail rows or the grand total — we can ask for exactly that: + +.. ipython:: python + + df.aggregate( + [GroupingSet.grouping_sets([col_type_1], [col_type_2])], + [f.count(col_speed).alias("Count"), + f.avg(col_speed).alias("Avg Speed")] + ).sort( + col_type_1.sort(ascending=True, nulls_first=True), + col_type_2.sort(ascending=True, nulls_first=True) + ) + +Each row belongs to exactly one grouping level. The :py:func:`~datafusion.functions.grouping` +function tells you which level each row comes from: + +.. ipython:: python + + result = df.aggregate( + [GroupingSet.grouping_sets([col_type_1], [col_type_2])], + [f.count(col_speed).alias("Count"), + f.avg(col_speed).alias("Avg Speed"), + f.grouping(col_type_1), + f.grouping(col_type_2)] + ) + for field in result.schema(): + if field.name.startswith("grouping("): + clean = field.name.split(".")[-1].rstrip(")") + result = result.with_column_renamed(field.name, f"grouping({clean})") + result.sort( + col_type_1.sort(ascending=True, nulls_first=True), + col_type_2.sort(ascending=True, nulls_first=True) + ) + +Where ``grouping(Type 1)`` is ``0`` the row is a per-``Type 1`` total (and ``Type 2`` is ``null``). +Where ``grouping(Type 2)`` is ``0`` the row is a per-``Type 2`` total (and ``Type 1`` is ``null``). + + Aggregate Functions ------------------- @@ -192,6 +354,7 @@ The available aggregate functions are: - :py:func:`datafusion.functions.stddev_pop` - :py:func:`datafusion.functions.var_samp` - :py:func:`datafusion.functions.var_pop` + - :py:func:`datafusion.functions.var_population` 6. Linear Regression Functions - :py:func:`datafusion.functions.regr_count` - :py:func:`datafusion.functions.regr_slope` @@ -208,9 +371,16 @@ The available aggregate functions are: - :py:func:`datafusion.functions.nth_value` 8. String Functions - :py:func:`datafusion.functions.string_agg` -9. Approximation Functions +9. Percentile Functions + - :py:func:`datafusion.functions.percentile_cont` + - :py:func:`datafusion.functions.quantile_cont` - :py:func:`datafusion.functions.approx_distinct` - :py:func:`datafusion.functions.approx_median` - :py:func:`datafusion.functions.approx_percentile_cont` - :py:func:`datafusion.functions.approx_percentile_cont_with_weight` +10. Grouping Set Functions + - :py:func:`datafusion.functions.grouping` + - :py:meth:`datafusion.expr.GroupingSet.rollup` + - :py:meth:`datafusion.expr.GroupingSet.cube` + - :py:meth:`datafusion.expr.GroupingSet.grouping_sets` diff --git a/docs/source/user-guide/common-operations/joins.rst b/docs/source/user-guide/common-operations/joins.rst index 1d9d70385..a289c9377 100644 --- a/docs/source/user-guide/common-operations/joins.rst +++ b/docs/source/user-guide/common-operations/joins.rst @@ -134,3 +134,36 @@ In contrast to the above example, if we wish to get both columns: .. ipython:: python left.join(right, "id", how="inner", coalesce_duplicate_keys=False) + +Disambiguating Columns with ``DataFrame.col()`` +------------------------------------------------ + +When both DataFrames contain non-key columns with the same name, you can use +:py:meth:`~datafusion.dataframe.DataFrame.col` on each DataFrame **before** the +join to create fully qualified column references. These references can then be +used in the join predicate and when selecting from the result. + +This is especially useful with :py:meth:`~datafusion.dataframe.DataFrame.join_on`, +which accepts expression-based predicates. 
+ +.. ipython:: python + + left = ctx.from_pydict( + { + "id": [1, 2, 3], + "val": [10, 20, 30], + } + ) + + right = ctx.from_pydict( + { + "id": [1, 2, 3], + "val": [40, 50, 60], + } + ) + + joined = left.join_on( + right, left.col("id") == right.col("id"), how="inner" + ) + + joined.select(left.col("id"), left.col("val"), right.col("val")) diff --git a/docs/source/user-guide/common-operations/windows.rst b/docs/source/user-guide/common-operations/windows.rst index c8fdea8f4..d77881bcf 100644 --- a/docs/source/user-guide/common-operations/windows.rst +++ b/docs/source/user-guide/common-operations/windows.rst @@ -175,10 +175,7 @@ it's ``Type 2`` column that are null. Aggregate Functions ------------------- -You can use any :ref:`Aggregation Function` as a window function. Currently -aggregate functions must use the deprecated -:py:func:`datafusion.functions.window` API but this should be resolved in -DataFusion 42.0 (`Issue Link `_). Here +You can use any :ref:`Aggregation Function` as a window function. Here is an example that shows how to compare each pokemons’s attack power with the average attack power in its ``"Type 1"`` using the :py:func:`datafusion.functions.avg` function. @@ -189,10 +186,12 @@ power in its ``"Type 1"`` using the :py:func:`datafusion.functions.avg` function col('"Name"'), col('"Attack"'), col('"Type 1"'), - f.window("avg", [col('"Attack"')]) - .partition_by(col('"Type 1"')) - .build() - .alias("Average Attack"), + f.avg(col('"Attack"')).over( + Window( + window_frame=WindowFrame("rows", None, None), + partition_by=[col('"Type 1"')], + ) + ).alias("Average Attack"), ) Available Functions diff --git a/docs/source/user-guide/dataframe/execution-metrics.rst b/docs/source/user-guide/dataframe/execution-metrics.rst new file mode 100644 index 000000000..764fa76ef --- /dev/null +++ b/docs/source/user-guide/dataframe/execution-metrics.rst @@ -0,0 +1,215 @@ +.. Licensed to the Apache Software Foundation (ASF) under one +.. or more contributor license agreements. See the NOTICE file +.. distributed with this work for additional information +.. regarding copyright ownership. The ASF licenses this file +.. to you under the Apache License, Version 2.0 (the +.. "License"); you may not use this file except in compliance +.. with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, +.. software distributed under the License is distributed on an +.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +.. KIND, either express or implied. See the License for the +.. specific language governing permissions and limitations +.. under the License. + +.. _execution_metrics: + +Execution Metrics +================= + +Overview +-------- + +When DataFusion executes a query it compiles the logical plan into a tree of +*physical plan operators* (e.g. ``FilterExec``, ``ProjectionExec``, +``AggregateExec``). Each operator can record runtime statistics while it +runs. These statistics are called **execution metrics**.
+ +Typical metrics include: + +- **output_rows** – number of rows produced by the operator +- **elapsed_compute** – total CPU time (nanoseconds) spent inside the operator +- **spill_count** – number of times the operator spilled data to disk +- **spilled_bytes** – total bytes written to disk during spills +- **spilled_rows** – total rows written to disk during spills + +Metrics are collected *per-partition*: DataFusion may execute each operator +in parallel across several partitions. The convenience properties on +:py:class:`~datafusion.MetricsSet` (e.g. ``output_rows``, ``elapsed_compute``) +automatically sum the named metric across **all** partitions, giving a single +aggregate value for the operator as a whole. You can also access the raw +per-partition :py:class:`~datafusion.Metric` objects via +:py:meth:`~datafusion.MetricsSet.metrics`. + +When Are Metrics Available? +--------------------------- + +Some operators (for example ``DataSourceExec``) eagerly create a +:py:class:`~datafusion.MetricsSet` when the physical plan is built, so +:py:meth:`~datafusion.ExecutionPlan.metrics` may return a set even before any +rows have been processed. However, metric **values** such as ``output_rows`` +are only meaningful **after** the DataFrame has been executed via one of the +terminal operations: + +- :py:meth:`~datafusion.DataFrame.collect` +- :py:meth:`~datafusion.DataFrame.collect_partitioned` +- :py:meth:`~datafusion.DataFrame.execute_stream` + (metrics are available once the stream has been fully consumed) +- :py:meth:`~datafusion.DataFrame.execute_stream_partitioned` + (metrics are available once all partition streams have been fully consumed) + +Before execution, metric values will be ``0`` or ``None``. + +.. note:: + + **display() does not populate metrics.** + When a DataFrame is displayed in a notebook (e.g. via ``display(df)`` or + automatic ``repr`` output), DataFusion runs a *limited* internal execution + to fetch preview rows. This internal execution does **not** cache the + physical plan used, so :py:meth:`~datafusion.ExecutionPlan.collect_metrics` + will not reflect the display execution. To access metrics you must call + one of the terminal operations listed above. + +If you call :py:meth:`~datafusion.DataFrame.collect` (or another terminal +operation) multiple times on the same DataFrame, each call creates a fresh +physical plan. Metrics from :py:meth:`~datafusion.DataFrame.execution_plan` +always reflect the **most recent** execution. + +Reading the Physical Plan Tree +-------------------------------- + +:py:meth:`~datafusion.DataFrame.execution_plan` returns the root +:py:class:`~datafusion.ExecutionPlan` node of the physical plan tree. The tree +mirrors the operator pipeline: the root is typically a projection or +coalescing node; its children are filters, aggregates, scans, etc. + +The ``operator_name`` string returned by +:py:meth:`~datafusion.ExecutionPlan.collect_metrics` is the *display* name of +the node, for example ``"FilterExec: column1@0 > 1"``. This is the same string +you would see when calling ``plan.display()``. + +Aggregated vs Per-Partition Metrics +------------------------------------ + +DataFusion executes each operator across one or more **partitions** in +parallel. The :py:class:`~datafusion.MetricsSet` convenience properties +(``output_rows``, ``elapsed_compute``, etc.) automatically **sum** the named +metric across all partitions, giving a single aggregate value. 
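+
+As a quick sketch of the aggregated view, assuming ``plan`` was obtained from
+``df.execution_plan()`` after a terminal operation such as ``df.collect()``
+(see the end-to-end example below):
+
+.. code-block:: python
+
+    # Each convenience property sums one named metric across all partitions
+    for operator_name, metrics_set in plan.collect_metrics():
+        print(operator_name, metrics_set.output_rows, metrics_set.elapsed_compute)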
+ +To inspect individual partitions — for example to detect data skew where one +partition processes far more rows than others — iterate over the raw +:py:class:`~datafusion.Metric` objects: + +.. code-block:: python + + for metric in metrics_set.metrics(): + print(f" partition={metric.partition} {metric.name}={metric.value}") + +The ``partition`` property is a 0-based index (``0``, ``1``, …) identifying +which parallel slot processed this metric. It is ``None`` for metrics that +apply globally (not tied to a specific partition). + +Available Metrics +----------------- + +The following metrics are directly accessible as properties on +:py:class:`~datafusion.MetricsSet`: + +.. list-table:: + :header-rows: 1 + :widths: 25 75 + + * - Property + - Description + * - ``output_rows`` + - Number of rows emitted by the operator (summed across partitions). + * - ``elapsed_compute`` + - CPU time **in nanoseconds** spent inside the operator's + compute loop, excluding I/O wait. Useful for identifying which + operators are most expensive (summed across partitions). + * - ``spill_count`` + - Number of spill-to-disk events triggered by memory pressure. This is + a unitless count of events, not a measure of data volume (summed across + partitions). + * - ``spilled_bytes`` + - Total bytes written to disk during spill events (summed across + partitions). + * - ``spilled_rows`` + - Total rows written to disk during spill events (summed across + partitions). + +Any metric not listed above can be accessed via +:py:meth:`~datafusion.MetricsSet.sum_by_name`, or by iterating over the raw +:py:class:`~datafusion.Metric` objects returned by +:py:meth:`~datafusion.MetricsSet.metrics`. + +Labels +------ + +A :py:class:`~datafusion.Metric` may carry *labels*: key/value pairs that +provide additional context. Labels are operator-specific; most metrics have +an empty label dict. + +Some operators tag their metrics with labels to distinguish variants. For +example, an ``AggregateExec`` may record separate ``output_rows`` metrics +for intermediate and final output: + +.. code-block:: python + + for metric in metrics_set.metrics(): + print(metric.name, metric.labels()) + # output_rows {'output_type': 'final'} + # output_rows {'output_type': 'intermediate'} + +When summing by name (via :py:attr:`~datafusion.MetricsSet.output_rows` or +:py:meth:`~datafusion.MetricsSet.sum_by_name`), **all** metrics with that +name are summed regardless of labels. To filter by label, iterate over the +raw :py:class:`~datafusion.Metric` objects directly. + +End-to-End Example +------------------ + +..
code-block:: python + + from datafusion import SessionContext + + ctx = SessionContext() + ctx.sql("CREATE TABLE sales AS VALUES (1, 100), (2, 200), (3, 50)") + + df = ctx.sql("SELECT * FROM sales WHERE column1 > 1") + + # Execute the query — this populates the metrics + results = df.collect() + + # Retrieve the physical plan with metrics + plan = df.execution_plan() + + # Walk every operator and print its metrics + for operator_name, ms in plan.collect_metrics(): + if ms.output_rows is not None: + print(f"{operator_name}") + print(f" output_rows = {ms.output_rows}") + print(f" elapsed_compute = {ms.elapsed_compute} ns") + + # Access raw per-partition metrics + for operator_name, ms in plan.collect_metrics(): + for metric in ms.metrics(): + print( + f" partition={metric.partition} " + f"{metric.name}={metric.value} " + f"labels={metric.labels()}" + ) + +API Reference +------------- + +- :py:class:`datafusion.ExecutionPlan` — physical plan node +- :py:meth:`datafusion.ExecutionPlan.collect_metrics` — walk the tree and + return ``(operator_name, MetricsSet)`` pairs +- :py:meth:`datafusion.ExecutionPlan.metrics` — return the + :py:class:`~datafusion.MetricsSet` for a single node +- :py:class:`datafusion.MetricsSet` — aggregated metrics for one operator +- :py:class:`datafusion.Metric` — a single per-partition metric value diff --git a/docs/source/user-guide/dataframe/index.rst b/docs/source/user-guide/dataframe/index.rst index 510bcbc68..8475a7bd7 100644 --- a/docs/source/user-guide/dataframe/index.rst +++ b/docs/source/user-guide/dataframe/index.rst @@ -365,7 +365,16 @@ DataFusion provides many built-in functions for data manipulation: For a complete list of available functions, see the :py:mod:`datafusion.functions` module documentation. +Execution Metrics +----------------- + +After executing a DataFrame (via ``collect()``, ``execute_stream()``, etc.), +DataFusion populates per-operator runtime statistics such as row counts and +compute time. See :doc:`execution-metrics` for a full explanation and +worked example. + .. toctree:: :maxdepth: 1 rendering + execution-metrics diff --git a/docs/source/user-guide/dataframe/rendering.rst b/docs/source/user-guide/dataframe/rendering.rst index 9dea948bb..dc61a422f 100644 --- a/docs/source/user-guide/dataframe/rendering.rst +++ b/docs/source/user-guide/dataframe/rendering.rst @@ -15,18 +15,18 @@ .. specific language governing permissions and limitations .. under the License. -HTML Rendering in Jupyter -========================= +DataFrame Rendering +=================== -When working in Jupyter notebooks or other environments that support rich HTML display, -DataFusion DataFrames automatically render as nicely formatted HTML tables. This functionality -is provided by the ``_repr_html_`` method, which is automatically called by Jupyter to provide -a richer visualization than plain text output. +DataFusion provides configurable rendering for DataFrames in both plain text and HTML +formats. The ``datafusion.dataframe_formatter`` module controls how DataFrames are +displayed in Jupyter notebooks (via ``_repr_html_``), in the terminal (via ``__repr__``), +and anywhere else a string or HTML representation is needed. -Basic HTML Rendering --------------------- +Basic Rendering +--------------- -In a Jupyter environment, simply displaying a DataFrame object will trigger HTML rendering: +In a Jupyter environment, displaying a DataFrame triggers HTML rendering: .. 
code-block:: python @@ -36,74 +36,117 @@ In a Jupyter environment, simply displaying a DataFrame object will trigger HTML # Explicit display also uses HTML rendering display(df) -Customizing HTML Rendering ---------------------------- +In a terminal or when converting to string, plain text rendering is used: + +.. code-block:: python -DataFusion provides extensive customization options for HTML table rendering through the -``datafusion.html_formatter`` module. + # Plain text table output + print(df) -Configuring the HTML Formatter -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Configuring the Formatter +------------------------- -You can customize how DataFrames are rendered by configuring the formatter: +You can customize how DataFrames are rendered by configuring the global formatter: .. code-block:: python - from datafusion.html_formatter import configure_formatter - - # Change the default styling + from datafusion.dataframe_formatter import configure_formatter + configure_formatter( - max_cell_length=25, # Maximum characters in a cell before truncation - max_width=1000, # Maximum width in pixels - max_height=300, # Maximum height in pixels - max_memory_bytes=2097152, # Maximum memory for rendering (2MB) - min_rows=10, # Minimum number of rows to display - max_rows=10, # Maximum rows to display in __repr__ - enable_cell_expansion=True,# Allow expanding truncated cells - custom_css=None, # Additional custom CSS + max_cell_length=25, # Maximum characters in a cell before truncation + max_width=1000, # Maximum width in pixels (HTML only) + max_height=300, # Maximum height in pixels (HTML only) + max_memory_bytes=2097152, # Maximum memory for rendering (2MB) + min_rows=10, # Minimum number of rows to display + max_rows=10, # Maximum rows to display + enable_cell_expansion=True, # Allow expanding truncated cells (HTML only) + custom_css=None, # Additional custom CSS (HTML only) show_truncation_message=True, # Show message when data is truncated - style_provider=None, # Custom styling provider - use_shared_styles=True # Share styles across tables + style_provider=None, # Custom styling provider (HTML only) + use_shared_styles=True, # Share styles across tables (HTML only) ) The formatter settings affect all DataFrames displayed after configuration. Custom Style Providers ------------------------ +---------------------- -For advanced styling needs, you can create a custom style provider: +For HTML styling, you can create a custom style provider that implements the +``StyleProvider`` protocol: .. 
code-block:: python - from datafusion.html_formatter import StyleProvider, configure_formatter - - class MyStyleProvider(StyleProvider): - def get_table_styles(self): - return { - "table": "border-collapse: collapse; width: 100%;", - "th": "background-color: #007bff; color: white; padding: 8px; text-align: left;", - "td": "border: 1px solid #ddd; padding: 8px;", - "tr:nth-child(even)": "background-color: #f2f2f2;", - } - - def get_value_styles(self, dtype, value): - """Return custom styles for specific values""" - if dtype == "float" and value < 0: - return "color: red;" - return None - + from datafusion.dataframe_formatter import configure_formatter + + class MyStyleProvider: + def get_cell_style(self): + """Return CSS style string for table data cells.""" + return "border: 1px solid #ddd; padding: 8px; text-align: left;" + + def get_header_style(self): + """Return CSS style string for table header cells.""" + return ( + "background-color: #007bff; color: white; " + "padding: 8px; text-align: left;" + ) + # Apply the custom style provider configure_formatter(style_provider=MyStyleProvider()) +Custom Cell Formatters +---------------------- + +You can register custom formatters for specific Python types. A cell formatter is any +callable that takes a value and returns a string: + +.. code-block:: python + + from datafusion.dataframe_formatter import get_formatter + + formatter = get_formatter() + + # Format floats to 2 decimal places + formatter.register_formatter(float, lambda v: f"{v:.2f}") + + # Format dates in a custom way + from datetime import date + formatter.register_formatter(date, lambda v: v.strftime("%B %d, %Y")) + +Custom Cell and Header Builders +------------------------------- + +For full control over the HTML of individual cells or headers, you can set custom +builder functions: + +.. code-block:: python + + from datafusion.dataframe_formatter import get_formatter + + formatter = get_formatter() + + # Custom cell builder receives (value, row, col, table_id) and returns HTML + def my_cell_builder(value, row, col, table_id): + color = "red" if isinstance(value, (int, float)) and value < 0 else "black" + return f"<td style='color: {color}'>{value}</td>" + + formatter.set_custom_cell_builder(my_cell_builder) + + # Custom header builder receives a schema field and returns HTML + def my_header_builder(field): + return f"<th>{field.name}</th>" + + formatter.set_custom_header_builder(my_header_builder) + Performance Optimization with Shared Styles -------------------------------------------- -The ``use_shared_styles`` parameter (enabled by default) optimizes performance when displaying -multiple DataFrames in notebook environments: +The ``use_shared_styles`` parameter (enabled by default) optimizes performance when +displaying multiple DataFrames in notebook environments: ..
code-block:: python - from datafusion.html_formatter import StyleProvider, configure_formatter + from datafusion.dataframe_formatter import configure_formatter + # Default: Use shared styles (recommended for notebooks) configure_formatter(use_shared_styles=True) @@ -111,76 +154,48 @@ multiple DataFrames in notebook environments: configure_formatter(use_shared_styles=False) When ``use_shared_styles=True``: + - CSS styles and JavaScript are included only once per notebook session - This reduces HTML output size and prevents style duplication - Improves rendering performance with many DataFrames - Applies consistent styling across all DataFrames -Creating a Custom Formatter ----------------------------- +Working with the Formatter Directly +------------------------------------ -For complete control over rendering, you can implement a custom formatter: +You can use ``get_formatter()`` and ``set_formatter()`` for direct access to the global +formatter instance: .. code-block:: python - from datafusion.html_formatter import Formatter, get_formatter - - class MyFormatter(Formatter): - def format_html(self, batches, schema, has_more=False, table_uuid=None): - # Create your custom HTML here - html = "
" - # ... formatting logic ... - html += "
" - return html - - # Set as the global formatter - configure_formatter(formatter_class=MyFormatter) - - # Or use the formatter just for specific operations + from datafusion.dataframe_formatter import ( + DataFrameHtmlFormatter, + get_formatter, + set_formatter, + ) + + # Get and modify the current formatter formatter = get_formatter() - custom_html = formatter.format_html(batches, schema) + print(formatter.max_rows) + print(formatter.max_cell_length) -Managing Formatters -------------------- + # Create and set a fully custom formatter + custom_formatter = DataFrameHtmlFormatter( + max_cell_length=50, + max_rows=20, + enable_cell_expansion=False, + ) + set_formatter(custom_formatter) Reset to default formatting: .. code-block:: python - from datafusion.html_formatter import reset_formatter - + from datafusion.dataframe_formatter import reset_formatter + # Reset to default settings reset_formatter() -Get the current formatter settings: - -.. code-block:: python - - from datafusion.html_formatter import get_formatter - - formatter = get_formatter() - print(formatter.max_rows) - print(formatter.theme) - -Contextual Formatting ----------------------- - -You can also use a context manager to temporarily change formatting settings: - -.. code-block:: python - - from datafusion.html_formatter import formatting_context - - # Default formatting - df.show() - - # Temporarily use different formatting - with formatting_context(max_rows=100, theme="dark"): - df.show() # Will use the temporary settings - - # Back to default formatting - df.show() - Memory and Display Controls --------------------------- @@ -188,10 +203,12 @@ You can control how much data is displayed and how much memory is used for rende .. code-block:: python + from datafusion.dataframe_formatter import configure_formatter + configure_formatter( max_memory_bytes=4 * 1024 * 1024, # 4MB maximum memory for display min_rows=20, # Always show at least 20 rows - max_rows=50 # Show up to 50 rows in output + max_rows=50, # Show up to 50 rows in output ) These parameters help balance comprehensive data display against performance considerations. @@ -216,7 +233,7 @@ Additional Resources * :doc:`../io/index` - I/O Guide for reading data from various sources * :doc:`../data-sources` - Comprehensive data sources guide * :ref:`io_csv` - CSV file reading -* :ref:`io_parquet` - Parquet file reading +* :ref:`io_parquet` - Parquet file reading * :ref:`io_json` - JSON file reading * :ref:`io_avro` - Avro file reading * :ref:`io_custom_table_provider` - Custom table providers diff --git a/examples/tpch/q01_pricing_summary_report.py b/examples/tpch/q01_pricing_summary_report.py index 3f97f00dc..105f1632d 100644 --- a/examples/tpch/q01_pricing_summary_report.py +++ b/examples/tpch/q01_pricing_summary_report.py @@ -27,6 +27,30 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + l_returnflag, + l_linestatus, + sum(l_quantity) as sum_qty, + sum(l_extendedprice) as sum_base_price, + sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + avg(l_quantity) as avg_qty, + avg(l_extendedprice) as avg_price, + avg(l_discount) as avg_disc, + count(*) as count_order + from + lineitem + where + l_shipdate <= date '1998-12-01' - interval '90 days' + group by + l_returnflag, + l_linestatus + order by + l_returnflag, + l_linestatus; """ import pyarrow as pa @@ -58,31 +82,25 @@ # Aggregate the results +disc_price = col("l_extendedprice") * (lit(1) - col("l_discount")) + df = df.aggregate( - [col("l_returnflag"), col("l_linestatus")], + ["l_returnflag", "l_linestatus"], [ F.sum(col("l_quantity")).alias("sum_qty"), F.sum(col("l_extendedprice")).alias("sum_base_price"), - F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias( - "sum_disc_price" - ), - F.sum( - col("l_extendedprice") - * (lit(1) - col("l_discount")) - * (lit(1) + col("l_tax")) - ).alias("sum_charge"), + F.sum(disc_price).alias("sum_disc_price"), + F.sum(disc_price * (lit(1) + col("l_tax"))).alias("sum_charge"), F.avg(col("l_quantity")).alias("avg_qty"), F.avg(col("l_extendedprice")).alias("avg_price"), F.avg(col("l_discount")).alias("avg_disc"), - F.count(col("l_returnflag")).alias( - "count_order" - ), # Counting any column should return same result + F.count_star().alias("count_order"), ], ) # Sort per the expected result -df = df.sort(col("l_returnflag").sort(), col("l_linestatus").sort()) +df = df.sort_by("l_returnflag", "l_linestatus") # Note: There appears to be a discrepancy between what is returned here and what is in the generated # answers file for the case of return flag N and line status O, but I did not investigate further. diff --git a/examples/tpch/q02_minimum_cost_supplier.py b/examples/tpch/q02_minimum_cost_supplier.py index 7390d0892..c5c6b9c0b 100644 --- a/examples/tpch/q02_minimum_cost_supplier.py +++ b/examples/tpch/q02_minimum_cost_supplier.py @@ -27,11 +27,58 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + s_acctbal, + s_name, + n_name, + p_partkey, + p_mfgr, + s_address, + s_phone, + s_comment + from + part, + supplier, + partsupp, + nation, + region + where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and p_size = 15 + and p_type like '%BRASS' + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + and ps_supplycost = ( + select + min(ps_supplycost) + from + partsupp, + supplier, + nation, + region + where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + ) + order by + s_acctbal desc, + n_name, + s_name, + p_partkey limit 100; """ import datafusion from datafusion import SessionContext, col, lit from datafusion import functions as F +from datafusion.expr import Window from util import get_data_path # This is the part we're looking for. Values selected here differ from the spec in order to run @@ -66,35 +113,30 @@ "r_regionkey", "r_name" ) -# Filter down parts. Part names contain the type of interest, so we can use strpos to find where -# in the p_type column the word is. 
`strpos` will return 0 if not found, otherwise the position -# in the string where it is located. +# Filter down parts. The reference SQL uses ``p_type like '%BRASS'`` which +# is an ``ends_with`` check; use the dedicated string function rather than +# a manual substring match. df_part = df_part.filter( - F.strpos(col("p_type"), lit(TYPE_OF_INTEREST)) > lit(0) -).filter(col("p_size") == lit(SIZE_OF_INTEREST)) + F.ends_with(col("p_type"), lit(TYPE_OF_INTEREST)), + col("p_size") == SIZE_OF_INTEREST, +) # Filter regions down to the one of interest -df_region = df_region.filter(col("r_name") == lit(REGION_OF_INTEREST)) +df_region = df_region.filter(col("r_name") == REGION_OF_INTEREST) # Now that we have the region, find suppliers in that region. Suppliers are tied to their nation # and nations are tied to the region. -df_nation = df_nation.join( - df_region, left_on=["n_regionkey"], right_on=["r_regionkey"], how="inner" -) -df_supplier = df_supplier.join( - df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner" -) +df_nation = df_nation.join(df_region, left_on="n_regionkey", right_on="r_regionkey") +df_supplier = df_supplier.join(df_nation, left_on="s_nationkey", right_on="n_nationkey") # Now that we know who the potential suppliers are for the part, we can limit out part # supplies table down. We can further join down to the specific parts we've identified # as matching the request -df = df_partsupp.join( - df_supplier, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="inner" -) +df = df_partsupp.join(df_supplier, left_on="ps_suppkey", right_on="s_suppkey") # Locate the minimum cost across all suppliers. There are multiple ways you could do this, # but one way is to create a window function across all suppliers, find the minimum, and @@ -106,17 +148,14 @@ window_frame = datafusion.WindowFrame("rows", None, None) df = df.with_column( "min_cost", - F.window( - "min", - [col("ps_supplycost")], - partition_by=[col("ps_partkey")], - window_frame=window_frame, + F.min(col("ps_supplycost")).over( + Window(partition_by=[col("ps_partkey")], window_frame=window_frame) ), ) -df = df.filter(col("min_cost") == col("ps_supplycost")) - -df = df.join(df_part, left_on=["ps_partkey"], right_on=["p_partkey"], how="inner") +df = df.filter(col("min_cost") == col("ps_supplycost")).join( + df_part, left_on="ps_partkey", right_on="p_partkey" +) # From the problem statement, these are the values we wish to output @@ -134,12 +173,10 @@ # Sort and display 100 entries df = df.sort( col("s_acctbal").sort(ascending=False), - col("n_name").sort(), - col("s_name").sort(), - col("p_partkey").sort(), -) - -df = df.limit(100) + "n_name", + "s_name", + "p_partkey", +).limit(100) # Show results diff --git a/examples/tpch/q03_shipping_priority.py b/examples/tpch/q03_shipping_priority.py index fc1231e0a..880c7435f 100644 --- a/examples/tpch/q03_shipping_priority.py +++ b/examples/tpch/q03_shipping_priority.py @@ -25,6 +25,31 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + l_orderkey, + sum(l_extendedprice * (1 - l_discount)) as revenue, + o_orderdate, + o_shippriority + from + customer, + orders, + lineitem + where + c_mktsegment = 'BUILDING' + and c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate < date '1995-03-15' + and l_shipdate > date '1995-03-15' + group by + l_orderkey, + o_orderdate, + o_shippriority + order by + revenue desc, + o_orderdate limit 10; """ from datafusion import SessionContext, col, lit @@ -50,20 +75,20 @@ # Limit dataframes to the rows of interest -df_customer = df_customer.filter(col("c_mktsegment") == lit(SEGMENT_OF_INTEREST)) +df_customer = df_customer.filter(col("c_mktsegment") == SEGMENT_OF_INTEREST) df_orders = df_orders.filter(col("o_orderdate") < lit(DATE_OF_INTEREST)) df_lineitem = df_lineitem.filter(col("l_shipdate") > lit(DATE_OF_INTEREST)) # Join all 3 dataframes -df = df_customer.join( - df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner" -).join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner") +df = df_customer.join(df_orders, left_on="c_custkey", right_on="o_custkey").join( + df_lineitem, left_on="o_orderkey", right_on="l_orderkey" +) # Compute the revenue df = df.aggregate( - [col("l_orderkey")], + ["l_orderkey"], [ F.first_value(col("o_orderdate")).alias("o_orderdate"), F.first_value(col("o_shippriority")).alias("o_shippriority"), @@ -71,17 +96,13 @@ ], ) -# Sort by priority - -df = df.sort(col("revenue").sort(ascending=False), col("o_orderdate").sort()) - -# Only return 10 results +# Sort by priority, take 10, and project in the order expected by the spec. -df = df.limit(10) - -# Change the order that the columns are reported in just to match the spec - -df = df.select("l_orderkey", "revenue", "o_orderdate", "o_shippriority") +df = ( + df.sort(col("revenue").sort(ascending=False), "o_orderdate") + .limit(10) + .select("l_orderkey", "revenue", "o_orderdate", "o_shippriority") +) # Show result diff --git a/examples/tpch/q04_order_priority_checking.py b/examples/tpch/q04_order_priority_checking.py index 426338aea..6f11c1383 100644 --- a/examples/tpch/q04_order_priority_checking.py +++ b/examples/tpch/q04_order_priority_checking.py @@ -24,18 +24,40 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + o_orderpriority, + count(*) as order_count + from + orders + where + o_orderdate >= date '1993-07-01' + and o_orderdate < date '1993-07-01' + interval '3' month + and exists ( + select + * + from + lineitem + where + l_orderkey = o_orderkey + and l_commitdate < l_receiptdate + ) + group by + o_orderpriority + order by + o_orderpriority; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -# Ideally we could put 3 months into the interval. See note below. 
-INTERVAL_DAYS = 92 -DATE_OF_INTEREST = "1993-07-01" +QUARTER_START = date(1993, 7, 1) +QUARTER_END = date(1993, 10, 1) # Load the dataframes we need @@ -48,36 +70,23 @@ "l_orderkey", "l_commitdate", "l_receiptdate" ) -# Create a date object from the string -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, INTERVAL_DAYS, 0), type=pa.month_day_nano_interval()) - -# Limit results to cases where commitment date before receipt date -# Aggregate the results so we only get one row to join with the order table. -# Alternately, and likely more idiomatic is instead of `.aggregate` you could -# do `.select("l_orderkey").distinct()`. The goal here is to show -# multiple examples of how to use Data Fusion. -df_lineitem = df_lineitem.filter(col("l_commitdate") < col("l_receiptdate")).aggregate( - [col("l_orderkey")], [] +# Keep only orders in the quarter of interest, then restrict to those that +# have at least one late lineitem via a semi join (the DataFrame form of +# ``EXISTS`` from the reference SQL). +df_orders = df_orders.filter( + col("o_orderdate") >= lit(QUARTER_START), + col("o_orderdate") < lit(QUARTER_END), ) -# Limit orders to date range of interest -df_orders = df_orders.filter(col("o_orderdate") >= lit(date)).filter( - col("o_orderdate") < lit(date) + lit(interval) -) +late_lineitems = df_lineitem.filter(col("l_commitdate") < col("l_receiptdate")) -# Perform the join to find only orders for which there are lineitems outside of expected range df = df_orders.join( - df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner" + late_lineitems, left_on="o_orderkey", right_on="l_orderkey", how="semi" ) -# Based on priority, find the number of entries -df = df.aggregate( - [col("o_orderpriority")], [F.count(col("o_orderpriority")).alias("order_count")] +# Count the number of orders in each priority group and sort. +df = df.aggregate(["o_orderpriority"], [F.count_star().alias("order_count")]).sort_by( + "o_orderpriority" ) -# Sort the results -df = df.sort(col("o_orderpriority").sort()) - df.show() diff --git a/examples/tpch/q05_local_supplier_volume.py b/examples/tpch/q05_local_supplier_volume.py index fa2b01dea..bfdba5d4c 100644 --- a/examples/tpch/q05_local_supplier_volume.py +++ b/examples/tpch/q05_local_supplier_volume.py @@ -27,23 +27,45 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
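# --- Editorial aside (illustration only, not part of the patch) ---
# The q04 hunk above replaces an aggregate-then-inner-join with a semi join,
# the DataFrame counterpart of SQL ``EXISTS``. A toy sketch (data invented):
from datafusion import SessionContext

ctx = SessionContext()
orders = ctx.from_pydict({"o_orderkey": [1, 2, 3]})
late = ctx.from_pydict({"l_orderkey": [2, 2, 3]})
# how="semi" keeps each left row at most once when a right-side match
# exists, and never duplicates rows or pulls in right-side columns.
orders.join(late, left_on="o_orderkey", right_on="l_orderkey", how="semi").show()
# -> rows for orderkey 2 and 3, each exactly once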
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + n_name, + sum(l_extendedprice * (1 - l_discount)) as revenue + from + customer, + orders, + lineitem, + supplier, + nation, + region + where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and l_suppkey = s_suppkey + and c_nationkey = s_nationkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'ASIA' + and o_orderdate >= date '1994-01-01' + and o_orderdate < date '1994-01-01' + interval '1' year + group by + n_name + order by + revenue desc; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -DATE_OF_INTEREST = "1994-01-01" -INTERVAL_DAYS = 365 +YEAR_START = date(1994, 1, 1) +YEAR_END = date(1995, 1, 1) REGION_OF_INTEREST = "ASIA" -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, INTERVAL_DAYS, 0), type=pa.month_day_nano_interval()) - # Load the dataframes we need ctx = SessionContext() @@ -68,38 +90,32 @@ ) # Restrict dataframes to cases of interest -df_orders = df_orders.filter(col("o_orderdate") >= lit(date)).filter( - col("o_orderdate") < lit(date) + lit(interval) +df_orders = df_orders.filter( + col("o_orderdate") >= lit(YEAR_START), + col("o_orderdate") < lit(YEAR_END), ) -df_region = df_region.filter(col("r_name") == lit(REGION_OF_INTEREST)) +df_region = df_region.filter(col("r_name") == REGION_OF_INTEREST) # Join all the dataframes df = ( - df_customer.join( - df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner" - ) - .join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner") + df_customer.join(df_orders, left_on="c_custkey", right_on="o_custkey") + .join(df_lineitem, left_on="o_orderkey", right_on="l_orderkey") .join( df_supplier, left_on=["l_suppkey", "c_nationkey"], right_on=["s_suppkey", "s_nationkey"], - how="inner", ) - .join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner") - .join(df_region, left_on=["n_regionkey"], right_on=["r_regionkey"], how="inner") + .join(df_nation, left_on="s_nationkey", right_on="n_nationkey") + .join(df_region, left_on="n_regionkey", right_on="r_regionkey") ) -# Compute the final result +# Compute the final result, then sort in descending order. df = df.aggregate( - [col("n_name")], + ["n_name"], [F.sum(col("l_extendedprice") * (lit(1.0) - col("l_discount"))).alias("revenue")], -) - -# Sort in descending order - -df = df.sort(col("revenue").sort(ascending=False)) +).sort(col("revenue").sort(ascending=False)) df.show() diff --git a/examples/tpch/q06_forecasting_revenue_change.py b/examples/tpch/q06_forecasting_revenue_change.py index 1de5848b1..ed54d22a4 100644 --- a/examples/tpch/q06_forecasting_revenue_change.py +++ b/examples/tpch/q06_forecasting_revenue_change.py @@ -27,28 +27,34 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + sum(l_extendedprice * l_discount) as revenue + from + lineitem + where + l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + and l_discount between 0.06 - 0.01 and 0.06 + 0.01 + and l_quantity < 24; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path # Variables from the example query -DATE_OF_INTEREST = "1994-01-01" +YEAR_START = date(1994, 1, 1) +YEAR_END = date(1995, 1, 1) DISCOUT = 0.06 DELTA = 0.01 QUANTITY = 24 -INTERVAL_DAYS = 365 - -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, INTERVAL_DAYS, 0), type=pa.month_day_nano_interval()) - # Load the dataframes we need ctx = SessionContext() @@ -59,12 +65,11 @@ # Filter down to lineitems of interest -df = ( - df_lineitem.filter(col("l_shipdate") >= lit(date)) - .filter(col("l_shipdate") < lit(date) + lit(interval)) - .filter(col("l_discount") >= lit(DISCOUT) - lit(DELTA)) - .filter(col("l_discount") <= lit(DISCOUT) + lit(DELTA)) - .filter(col("l_quantity") < lit(QUANTITY)) +df = df_lineitem.filter( + col("l_shipdate") >= lit(YEAR_START), + col("l_shipdate") < lit(YEAR_END), + col("l_discount").between(lit(DISCOUT - DELTA), lit(DISCOUT + DELTA)), + col("l_quantity") < QUANTITY, ) # Add up all the "lost" revenue diff --git a/examples/tpch/q07_volume_shipping.py b/examples/tpch/q07_volume_shipping.py index ff2f891f1..df1c2ae0d 100644 --- a/examples/tpch/q07_volume_shipping.py +++ b/examples/tpch/q07_volume_shipping.py @@ -26,9 +26,51 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
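# --- Editorial aside (illustration only, not part of the patch) ---
# ``Expr.between`` collapses the paired >=/<= filters the q06 hunk above
# removes. A toy sketch (data invented):
from datafusion import SessionContext, col, lit

ctx = SessionContext()
items = ctx.from_pydict({"l_discount": [0.04, 0.06, 0.09]})
# between is inclusive on both ends, matching SQL BETWEEN.
items.filter(col("l_discount").between(lit(0.05), lit(0.07))).show()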
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + supp_nation, + cust_nation, + l_year, + sum(volume) as revenue + from + ( + select + n1.n_name as supp_nation, + n2.n_name as cust_nation, + extract(year from l_shipdate) as l_year, + l_extendedprice * (1 - l_discount) as volume + from + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2 + where + s_suppkey = l_suppkey + and o_orderkey = l_orderkey + and c_custkey = o_custkey + and s_nationkey = n1.n_nationkey + and c_nationkey = n2.n_nationkey + and ( + (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') + or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') + ) + and l_shipdate between date '1995-01-01' and date '1996-12-31' + ) as shipping + group by + supp_nation, + cust_nation, + l_year + order by + supp_nation, + cust_nation, + l_year; """ -from datetime import datetime +from datetime import date import pyarrow as pa from datafusion import SessionContext, col, lit @@ -40,11 +82,8 @@ nation_1 = lit("FRANCE") nation_2 = lit("GERMANY") -START_DATE = "1995-01-01" -END_DATE = "1996-12-31" - -start_date = lit(datetime.strptime(START_DATE, "%Y-%m-%d").date()) -end_date = lit(datetime.strptime(END_DATE, "%Y-%m-%d").date()) +START_DATE = date(1995, 1, 1) +END_DATE = date(1996, 12, 31) # Load the dataframes we need @@ -69,60 +108,44 @@ # Filter to time of interest -df_lineitem = df_lineitem.filter(col("l_shipdate") >= start_date).filter( - col("l_shipdate") <= end_date +df_lineitem = df_lineitem.filter( + col("l_shipdate") >= lit(START_DATE), col("l_shipdate") <= lit(END_DATE) ) -# A simpler way to do the following operation is to use a filter, but we also want to demonstrate -# how to use case statements. Here we are assigning `n_name` to be itself when it is either of -# the two nations of interest. Since there is no `otherwise()` statement, any values that do -# not match these will result in a null value and then get filtered out. -# -# To do the same using a simple filter would be: -# df_nation = df_nation.filter((F.col("n_name") == nation_1) | (F.col("n_name") == nation_2)) # noqa: ERA001 -df_nation = df_nation.with_column( - "n_name", - F.case(col("n_name")) - .when(nation_1, col("n_name")) - .when(nation_2, col("n_name")) - .end(), -).filter(~col("n_name").is_null()) +# Limit the nation table to the two nations of interest. +df_nation = df_nation.filter(F.in_list(col("n_name"), [nation_1, nation_2])) # Limit suppliers to either nation df_supplier = df_supplier.join( - df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner" -).select(col("s_suppkey"), col("n_name").alias("supp_nation")) + df_nation, left_on="s_nationkey", right_on="n_nationkey" +).select("s_suppkey", col("n_name").alias("supp_nation")) # Limit customers to either nation df_customer = df_customer.join( - df_nation, left_on=["c_nationkey"], right_on=["n_nationkey"], how="inner" -).select(col("c_custkey"), col("n_name").alias("cust_nation")) + df_nation, left_on="c_nationkey", right_on="n_nationkey" +).select("c_custkey", col("n_name").alias("cust_nation")) # Join up all the data frames from line items, and make sure the supplier and customer are in # different nations. 
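# --- Editorial aside (illustration only, not part of the patch) ---
# ``F.in_list`` replaces the case/when-then-drop-nulls trick deleted in the
# q07 hunk above and reads like SQL ``IN``. Toy sketch (data invented):
from datafusion import SessionContext, col, lit
from datafusion import functions as F

ctx = SessionContext()
nations = ctx.from_pydict({"n_name": ["FRANCE", "GERMANY", "JAPAN"]})
nations.filter(
    F.in_list(col("n_name"), [lit("FRANCE"), lit("GERMANY")])
).show()  # JAPAN drops out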
df = ( - df_lineitem.join( - df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner" - ) - .join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner") - .join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner") + df_lineitem.join(df_orders, left_on="l_orderkey", right_on="o_orderkey") + .join(df_customer, left_on="o_custkey", right_on="c_custkey") + .join(df_supplier, left_on="l_suppkey", right_on="s_suppkey") .filter(col("cust_nation") != col("supp_nation")) ) # Extract out two values for every line item -df = df.with_column( - "l_year", F.datepart(lit("year"), col("l_shipdate")).cast(pa.int32()) -).with_column("volume", col("l_extendedprice") * (lit(1.0) - col("l_discount"))) +df = df.with_columns( + l_year=F.datepart(lit("year"), col("l_shipdate")).cast(pa.int32()), + volume=col("l_extendedprice") * (lit(1.0) - col("l_discount")), +) -# Aggregate the results +# Aggregate and sort per the spec. df = df.aggregate( - [col("supp_nation"), col("cust_nation"), col("l_year")], + ["supp_nation", "cust_nation", "l_year"], [F.sum(col("volume")).alias("revenue")], -) - -# Sort based on problem statement requirements -df = df.sort(col("supp_nation").sort(), col("cust_nation").sort(), col("l_year").sort()) +).sort_by("supp_nation", "cust_nation", "l_year") df.show() diff --git a/examples/tpch/q08_market_share.py b/examples/tpch/q08_market_share.py index 4bf50efba..dd7bacedb 100644 --- a/examples/tpch/q08_market_share.py +++ b/examples/tpch/q08_market_share.py @@ -25,24 +25,61 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + o_year, + sum(case + when nation = 'BRAZIL' then volume + else 0 + end) / sum(volume) as mkt_share + from + ( + select + extract(year from o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) as volume, + n2.n_name as nation + from + part, + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2, + region + where + p_partkey = l_partkey + and s_suppkey = l_suppkey + and l_orderkey = o_orderkey + and o_custkey = c_custkey + and c_nationkey = n1.n_nationkey + and n1.n_regionkey = r_regionkey + and r_name = 'AMERICA' + and s_nationkey = n2.n_nationkey + and o_orderdate between date '1995-01-01' and date '1996-12-31' + and p_type = 'ECONOMY ANODIZED STEEL' + ) as all_nations + group by + o_year + order by + o_year; """ -from datetime import datetime +from datetime import date import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -supplier_nation = lit("BRAZIL") -customer_region = lit("AMERICA") -part_of_interest = lit("ECONOMY ANODIZED STEEL") - -START_DATE = "1995-01-01" -END_DATE = "1996-12-31" +supplier_nation = "BRAZIL" +customer_region = "AMERICA" +part_of_interest = "ECONOMY ANODIZED STEEL" -start_date = lit(datetime.strptime(START_DATE, "%Y-%m-%d").date()) -end_date = lit(datetime.strptime(END_DATE, "%Y-%m-%d").date()) +START_DATE = date(1995, 1, 1) +END_DATE = date(1996, 12, 31) # Load the dataframes we need @@ -74,105 +111,57 @@ # Limit orders to those in the specified range -df_orders = df_orders.filter(col("o_orderdate") >= start_date).filter( - col("o_orderdate") <= end_date -) - -# Part 1: Find customers in the region - -# We want customers in region specified by region_of_interest. 
This will be used to compute -# the total sales of the part of interest. We want to know of those sales what fraction -# was supplied by the nation of interest. There is no guarantee that the nation of -# interest is within the region of interest. - -# First we find all the sales that make up the basis. - -df_regional_customers = df_region.filter(col("r_name") == customer_region) - -# After this join we have all of the possible sales nations -df_regional_customers = df_regional_customers.join( - df_nation, left_on=["r_regionkey"], right_on=["n_regionkey"], how="inner" -) - -# Now find the possible customers -df_regional_customers = df_regional_customers.join( - df_customer, left_on=["n_nationkey"], right_on=["c_nationkey"], how="inner" -) - -# Next find orders for these customers -df_regional_customers = df_regional_customers.join( - df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner" -) - -# Find all line items from these orders -df_regional_customers = df_regional_customers.join( - df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner" -) - -# Limit to the part of interest -df_regional_customers = df_regional_customers.join( - df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="inner" -) - -# Compute the volume for each line item -df_regional_customers = df_regional_customers.with_column( - "volume", col("l_extendedprice") * (lit(1.0) - col("l_discount")) -) - -# Part 2: Find suppliers from the nation - -# Now that we have all of the sales of that part in the specified region, we need -# to determine which of those came from suppliers in the nation we are interested in. - -df_national_suppliers = df_nation.filter(col("n_name") == supplier_nation) - -# Determine the suppliers by the limited nation key we have in our single row df above -df_national_suppliers = df_national_suppliers.join( - df_supplier, left_on=["n_nationkey"], right_on=["s_nationkey"], how="inner" -) - -# When we join to the customer dataframe, we don't want to confuse other columns, so only -# select the supplier key that we need -df_national_suppliers = df_national_suppliers.select("s_suppkey") - - -# Part 3: Combine suppliers and customers and compute the market share - -# Now we can do a left outer join on the suppkey. Those line items from other suppliers -# will get a null value. We can check for the existence of this null to compute a volume -# column only from suppliers in the nation we are evaluating. - -df = df_regional_customers.join( - df_national_suppliers, left_on=["l_suppkey"], right_on=["s_suppkey"], how="left" -) - -# Use a case statement to compute the volume sold by suppliers in the nation of interest -df = df.with_column( - "national_volume", - F.case(col("s_suppkey").is_null()) - .when(lit(value=False), col("volume")) - .otherwise(lit(0.0)), -) - -df = df.with_column( - "o_year", F.datepart(lit("year"), col("o_orderdate")).cast(pa.int32()) -) - - -# Lastly, sum up the results - -df = df.aggregate( - [col("o_year")], - [ - F.sum(col("volume")).alias("volume"), - F.sum(col("national_volume")).alias("national_volume"), - ], +df_orders = df_orders.filter( + col("o_orderdate") >= lit(START_DATE), col("o_orderdate") <= lit(END_DATE) +) + +# Pair each supplier with its nation name so every regional-customer row +# below carries the supplier's nation and can be filtered inside the +# aggregate with ``F.sum(..., filter=...)``. 
+ +df_supplier_with_nation = df_supplier.join( + df_nation, left_on="s_nationkey", right_on="n_nationkey" +).select("s_suppkey", col("n_name").alias("supp_nation")) + +# Build every (part, lineitem, order, customer) row for customers in the +# target region ordering the target part. Each row carries the supplier's +# nation so we can aggregate on it below. + +df = ( + df_region.filter(col("r_name") == customer_region) + .join(df_nation, left_on="r_regionkey", right_on="n_regionkey") + .join(df_customer, left_on="n_nationkey", right_on="c_nationkey") + .join(df_orders, left_on="c_custkey", right_on="o_custkey") + .join(df_lineitem, left_on="o_orderkey", right_on="l_orderkey") + .join(df_part, left_on="l_partkey", right_on="p_partkey") + .join(df_supplier_with_nation, left_on="l_suppkey", right_on="s_suppkey") + .with_columns( + volume=col("l_extendedprice") * (lit(1.0) - col("l_discount")), + o_year=F.datepart(lit("year"), col("o_orderdate")).cast(pa.int32()), + ) +) + +# Aggregate the total and national volumes per year via the ``filter`` +# kwarg on ``F.sum`` (DataFrame form of SQL ``sum(... ) FILTER (WHERE ...)``). +# ``coalesce`` handles the case where no sale came from the target nation +# for a given year. +df = ( + df.aggregate( + ["o_year"], + [ + F.sum(col("volume"), filter=col("supp_nation") == supplier_nation).alias( + "national_volume" + ), + F.sum(col("volume")).alias("total_volume"), + ], + ) + .select( + "o_year", + (F.coalesce(col("national_volume"), lit(0.0)) / col("total_volume")).alias( + "mkt_share" + ), + ) + .sort_by("o_year") ) -df = df.select( - col("o_year"), (F.col("national_volume") / F.col("volume")).alias("mkt_share") -) - -df = df.sort(col("o_year").sort()) - df.show() diff --git a/examples/tpch/q09_product_type_profit_measure.py b/examples/tpch/q09_product_type_profit_measure.py index e2abbd095..ec68a2ab7 100644 --- a/examples/tpch/q09_product_type_profit_measure.py +++ b/examples/tpch/q09_product_type_profit_measure.py @@ -27,6 +27,41 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
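# --- Editorial aside (illustration only, not part of the patch) ---
# The ``filter=`` kwarg used in the q08 hunk above evaluates an aggregate
# over only the rows matching a predicate, like SQL's
# ``sum(...) FILTER (WHERE ...)``. Toy sketch (data invented):
from datafusion import SessionContext, col, lit
from datafusion import functions as F

ctx = SessionContext()
sales = ctx.from_pydict(
    {"nation": ["BRAZIL", "PERU", "BRAZIL"], "volume": [1.0, 2.0, 4.0]}
)
sales.aggregate(
    [],
    [
        F.sum(col("volume"), filter=col("nation") == lit("BRAZIL")).alias("brazil"),
        F.sum(col("volume")).alias("total"),
    ],
).show()  # brazil = 5.0, total = 7.0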
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + nation, + o_year, + sum(amount) as sum_profit + from + ( + select + n_name as nation, + extract(year from o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount + from + part, + supplier, + lineitem, + partsupp, + orders, + nation + where + s_suppkey = l_suppkey + and ps_suppkey = l_suppkey + and ps_partkey = l_partkey + and p_partkey = l_partkey + and o_orderkey = l_orderkey + and s_nationkey = n_nationkey + and p_name like '%green%' + ) as profit + group by + nation, + o_year + order by + nation, + o_year desc; """ import pyarrow as pa @@ -34,7 +69,7 @@ from datafusion import functions as F from util import get_data_path -part_color = lit("green") +part_color = "green" # Load the dataframes we need @@ -62,37 +97,35 @@ "n_nationkey", "n_name", "n_regionkey" ) -# Limit possible parts to the color specified -df = df_part.filter(F.strpos(col("p_name"), part_color) > lit(0)) - -# We have a series of joins that get us to limit down to the line items we need -df = df.join(df_lineitem, left_on=["p_partkey"], right_on=["l_partkey"], how="inner") -df = df.join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner") -df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner") -df = df.join( - df_partsupp, - left_on=["l_suppkey", "l_partkey"], - right_on=["ps_suppkey", "ps_partkey"], - how="inner", +# Limit possible parts to the color specified, then walk the joins down to the +# line-item rows we need and attach the supplier's nation. ``F.contains`` +# maps directly to the reference SQL's ``p_name like '%green%'``. +df = ( + df_part.filter(F.contains(col("p_name"), lit(part_color))) + .join(df_lineitem, left_on="p_partkey", right_on="l_partkey") + .join(df_supplier, left_on="l_suppkey", right_on="s_suppkey") + .join(df_orders, left_on="l_orderkey", right_on="o_orderkey") + .join( + df_partsupp, + left_on=["l_suppkey", "l_partkey"], + right_on=["ps_suppkey", "ps_partkey"], + ) + .join(df_nation, left_on="s_nationkey", right_on="n_nationkey") ) -df = df.join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner") # Compute the intermediate values and limit down to the expressions we need df = df.select( col("n_name").alias("nation"), F.datepart(lit("year"), col("o_orderdate")).cast(pa.int32()).alias("o_year"), ( - (col("l_extendedprice") * (lit(1) - col("l_discount"))) - - (col("ps_supplycost") * col("l_quantity")) + col("l_extendedprice") * (lit(1) - col("l_discount")) + - col("ps_supplycost") * col("l_quantity") ).alias("amount"), ) -# Sum up the values by nation and year -df = df.aggregate( - [col("nation"), col("o_year")], [F.sum(col("amount")).alias("profit")] +# Sum up the values by nation and year, then sort per the spec. +df = df.aggregate(["nation", "o_year"], [F.sum(col("amount")).alias("profit")]).sort( + "nation", col("o_year").sort(ascending=False) ) -# Sort according to the problem specification -df = df.sort(col("nation").sort(), col("o_year").sort(ascending=False)) - df.show() diff --git a/examples/tpch/q10_returned_item_reporting.py b/examples/tpch/q10_returned_item_reporting.py index ed822e264..e6532517e 100644 --- a/examples/tpch/q10_returned_item_reporting.py +++ b/examples/tpch/q10_returned_item_reporting.py @@ -27,20 +27,50 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
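# --- Editorial aside (illustration only, not part of the patch) ---
# ``F.contains`` is a direct match for ``like '%green%'`` and avoids the
# less obvious ``strpos(...) > 0`` idiom the q09 hunk above removes.
# Toy sketch (data invented):
from datafusion import SessionContext, col, lit
from datafusion import functions as F

ctx = SessionContext()
parts = ctx.from_pydict({"p_name": ["forest green metal", "blue steel"]})
parts.filter(F.contains(col("p_name"), lit("green"))).show()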
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + c_custkey, + c_name, + sum(l_extendedprice * (1 - l_discount)) as revenue, + c_acctbal, + n_name, + c_address, + c_phone, + c_comment + from + customer, + orders, + lineitem, + nation + where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate >= date '1993-10-01' + and o_orderdate < date '1993-10-01' + interval '3' month + and l_returnflag = 'R' + and c_nationkey = n_nationkey + group by + c_custkey, + c_name, + c_acctbal, + c_phone, + n_name, + c_address, + c_comment + order by + revenue desc limit 20; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -DATE_START_OF_QUARTER = "1993-10-01" - -date_start_of_quarter = lit(datetime.strptime(DATE_START_OF_QUARTER, "%Y-%m-%d").date()) - -interval_one_quarter = lit(pa.scalar((0, 92, 0), type=pa.month_day_nano_interval())) +QUARTER_START = date(1993, 10, 1) +QUARTER_END = date(1994, 1, 1) # Load the dataframes we need @@ -66,44 +96,40 @@ ) # limit to returns -df_lineitem = df_lineitem.filter(col("l_returnflag") == lit("R")) +df_lineitem = df_lineitem.filter(col("l_returnflag") == "R") # Rather than aggregate by all of the customer fields as you might do looking at the specification, # we can aggregate by o_custkey and then join in the customer data at the end. -df = df_orders.filter(col("o_orderdate") >= date_start_of_quarter).filter( - col("o_orderdate") < date_start_of_quarter + interval_one_quarter +df = ( + df_orders.filter( + col("o_orderdate") >= lit(QUARTER_START), + col("o_orderdate") < lit(QUARTER_END), + ) + .join(df_lineitem, left_on="o_orderkey", right_on="l_orderkey") + .aggregate( + ["o_custkey"], + [F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias("revenue")], + ) ) -df = df.join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner") - -# Compute the revenue -df = df.aggregate( - [col("o_custkey")], - [F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias("revenue")], +# Now join in the customer data, project the spec's output columns, and take the top 20. 
+df = ( + df.join(df_customer, left_on="o_custkey", right_on="c_custkey") + .join(df_nation, left_on="c_nationkey", right_on="n_nationkey") + .select( + "c_custkey", + "c_name", + "revenue", + "c_acctbal", + "n_name", + "c_address", + "c_phone", + "c_comment", + ) + .sort(col("revenue").sort(ascending=False)) + .limit(20) ) -# Now join in the customer data -df = df.join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner") -df = df.join(df_nation, left_on=["c_nationkey"], right_on=["n_nationkey"], how="inner") - -# These are the columns the problem statement requires -df = df.select( - "c_custkey", - "c_name", - "revenue", - "c_acctbal", - "n_name", - "c_address", - "c_phone", - "c_comment", -) - -# Sort the results in descending order -df = df.sort(col("revenue").sort(ascending=False)) - -# Only return the top 20 results -df = df.limit(20) - df.show() diff --git a/examples/tpch/q11_important_stock_identification.py b/examples/tpch/q11_important_stock_identification.py index 22829ab7c..1f40bbdad 100644 --- a/examples/tpch/q11_important_stock_identification.py +++ b/examples/tpch/q11_important_stock_identification.py @@ -25,10 +25,41 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + ps_partkey, + sum(ps_supplycost * ps_availqty) as value + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' + group by + ps_partkey having + sum(ps_supplycost * ps_availqty) > ( + select + sum(ps_supplycost * ps_availqty) * 0.0001000000 + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' + ) + order by + value desc; """ from datafusion import SessionContext, WindowFrame, col, lit from datafusion import functions as F +from datafusion.expr import Window from util import get_data_path NATION = "GERMANY" @@ -48,39 +79,30 @@ "n_nationkey", "n_name" ) -# limit to returns -df_nation = df_nation.filter(col("n_name") == lit(NATION)) - -# Find part supplies of within this target nation - -df = df_nation.join( - df_supplier, left_on=["n_nationkey"], right_on=["s_nationkey"], how="inner" +# Restrict to the target nation, then walk to partsupp rows via the supplier +# join. Aggregate the per-part inventory value. 
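# --- Editorial aside (illustration only, not part of the patch) ---
# The q10 hunk above aggregates on ``o_custkey`` alone and joins the wide
# customer columns back afterwards, instead of grouping by every output
# column as the reference SQL does. Toy sketch of that shape (data invented):
from datafusion import SessionContext, col
from datafusion import functions as F

ctx = SessionContext()
orders = ctx.from_pydict({"custkey": [1, 1, 2], "price": [10.0, 5.0, 7.0]})
customers = ctx.from_pydict({"c_custkey": [1, 2], "c_name": ["ACME", "BOLT"]})
per_customer = orders.aggregate(
    ["custkey"], [F.sum(col("price")).alias("revenue")]
)
# Grouping on the narrow key first keeps the aggregation cheap; the
# descriptive columns ride along via the join afterwards.
per_customer.join(customers, left_on="custkey", right_on="c_custkey").show()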
+df = ( + df_nation.filter(col("n_name") == NATION) + .join(df_supplier, left_on="n_nationkey", right_on="s_nationkey") + .join(df_partsupp, left_on="s_suppkey", right_on="ps_suppkey") + .with_column("value", col("ps_supplycost") * col("ps_availqty")) + .aggregate(["ps_partkey"], [F.sum(col("value")).alias("value")]) ) -df = df.join(df_partsupp, left_on=["s_suppkey"], right_on=["ps_suppkey"], how="inner") - - -# Compute the value of individual parts -df = df.with_column("value", col("ps_supplycost") * col("ps_availqty")) - -# Compute total value of specific parts -df = df.aggregate([col("ps_partkey")], [F.sum(col("value")).alias("value")]) - -# By default window functions go from unbounded preceding to current row, but we want -# to compute this sum across all rows -window_frame = WindowFrame("rows", None, None) - -df = df.with_column( - "total_value", F.window("sum", [col("value")], window_frame=window_frame) +# A window function evaluated over the entire output produces a scalar grand +# total that can be referenced row-by-row in the filter — a DataFrame-native +# stand-in for the SQL HAVING ... > (SELECT SUM(...) * FRACTION ...) pattern. +# The default frame is "UNBOUNDED PRECEDING to CURRENT ROW"; override to the +# full partition for the grand total. +whole_frame = WindowFrame("rows", None, None) + +df = ( + df.with_column( + "total_value", F.sum(col("value")).over(Window(window_frame=whole_frame)) + ) + .filter(col("value") / col("total_value") >= lit(FRACTION)) + .select("ps_partkey", "value") + .sort(col("value").sort(ascending=False)) ) -# Limit to the parts for which there is a significant value based on the fraction of the total -df = df.filter(col("value") / col("total_value") >= lit(FRACTION)) - -# We only need to report on these two columns -df = df.select("ps_partkey", "value") - -# Sort in descending order of value -df = df.sort(col("value").sort(ascending=False)) - df.show() diff --git a/examples/tpch/q12_ship_mode_order_priority.py b/examples/tpch/q12_ship_mode_order_priority.py index 9071597f0..fb78fe3c2 100644 --- a/examples/tpch/q12_ship_mode_order_priority.py +++ b/examples/tpch/q12_ship_mode_order_priority.py @@ -27,18 +27,49 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
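# --- Editorial aside (illustration only, not part of the patch) ---
# The q11 hunk above sizes a window frame over the whole input so a single
# grand total becomes available on every row, which the filter can then
# reference. Toy sketch (data invented):
from datafusion import SessionContext, WindowFrame, col, lit
from datafusion import functions as F
from datafusion.expr import Window

ctx = SessionContext()
vals = ctx.from_pydict({"v": [1.0, 2.0, 3.0]})
whole = WindowFrame("rows", None, None)  # unbounded preceding .. following
df = vals.with_column("total", F.sum(col("v")).over(Window(window_frame=whole)))
df.filter(col("v") / col("total") >= lit(0.4)).show()  # keeps only v=3.0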
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + l_shipmode, + sum(case + when o_orderpriority = '1-URGENT' + or o_orderpriority = '2-HIGH' + then 1 + else 0 + end) as high_line_count, + sum(case + when o_orderpriority <> '1-URGENT' + and o_orderpriority <> '2-HIGH' + then 1 + else 0 + end) as low_line_count + from + orders, + lineitem + where + o_orderkey = l_orderkey + and l_shipmode in ('MAIL', 'SHIP') + and l_commitdate < l_receiptdate + and l_shipdate < l_commitdate + and l_receiptdate >= date '1994-01-01' + and l_receiptdate < date '1994-01-01' + interval '1' year + group by + l_shipmode + order by + l_shipmode; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path SHIP_MODE_1 = "MAIL" SHIP_MODE_2 = "SHIP" -DATE_OF_INTEREST = "1994-01-01" +YEAR_START = date(1994, 1, 1) +YEAR_END = date(1995, 1, 1) # Load the dataframes we need @@ -51,63 +82,30 @@ "l_orderkey", "l_shipmode", "l_commitdate", "l_shipdate", "l_receiptdate" ) -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, 365, 0), type=pa.month_day_nano_interval()) - - -df = df_lineitem.filter(col("l_receiptdate") >= lit(date)).filter( - col("l_receiptdate") < lit(date) + lit(interval) -) - -# Note: It is not recommended to use array_has because it treats the second argument as an argument -# so if you pass it col("l_shipmode") it will pass the entire array to process which is very slow. -# Instead check the position of the entry is not null. -df = df.filter( - ~F.array_position( - F.make_array(lit(SHIP_MODE_1), lit(SHIP_MODE_2)), col("l_shipmode") - ).is_null() -) - -# Since we have only two values, it's much easier to do this as a filter where the l_shipmode -# matches either of the two values, but we want to show doing some array operations in this -# example. If you want to see this done with filters, comment out the above line and uncomment -# this one. -# df = df.filter((col("l_shipmode") == lit(SHIP_MODE_1)) | (col("l_shipmode") == lit(SHIP_MODE_2))) # noqa: ERA001 +df = df_lineitem.filter( + col("l_receiptdate") >= lit(YEAR_START), + col("l_receiptdate") < lit(YEAR_END), + # ``in_list`` maps directly to ``l_shipmode in (...)`` from the SQL. + F.in_list(col("l_shipmode"), [lit(SHIP_MODE_1), lit(SHIP_MODE_2)]), + col("l_shipdate") < col("l_commitdate"), + col("l_commitdate") < col("l_receiptdate"), +).join(df_orders, left_on="l_orderkey", right_on="o_orderkey") -# We need order priority, so join order df to line item -df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner") +# Flag each line item as belonging to a high-priority order or not. +high_priorities = [lit("1-URGENT"), lit("2-HIGH")] +is_high = F.in_list(col("o_orderpriority"), high_priorities) +is_low = F.in_list(col("o_orderpriority"), high_priorities, negated=True) -# Restrict to line items we care about based on the problem statement. -df = df.filter(col("l_commitdate") < col("l_receiptdate")) - -df = df.filter(col("l_shipdate") < col("l_commitdate")) - -df = df.with_column( - "high_line_value", - F.case(col("o_orderpriority")) - .when(lit("1-URGENT"), lit(1)) - .when(lit("2-HIGH"), lit(1)) - .otherwise(lit(0)), -) - -# Aggregate the results +# Count the high-priority and low-priority lineitems per ship mode via the +# ``filter`` kwarg on ``F.count`` (DataFrame form of SQL's ``count(*) +# FILTER (WHERE ...)``). 
df = df.aggregate( - [col("l_shipmode")], + ["l_shipmode"], [ - F.sum(col("high_line_value")).alias("high_line_count"), - F.count(col("high_line_value")).alias("all_lines_count"), + F.count(col("o_orderkey"), filter=is_high).alias("high_line_count"), + F.count(col("o_orderkey"), filter=is_low).alias("low_line_count"), ], -) - -# Compute the final output -df = df.select( - col("l_shipmode"), - col("high_line_count"), - (col("all_lines_count") - col("high_line_count")).alias("low_line_count"), -) - -df = df.sort(col("l_shipmode").sort()) +).sort_by("l_shipmode") df.show() diff --git a/examples/tpch/q13_customer_distribution.py b/examples/tpch/q13_customer_distribution.py index 93f082ea3..37c0b93f6 100644 --- a/examples/tpch/q13_customer_distribution.py +++ b/examples/tpch/q13_customer_distribution.py @@ -26,6 +26,29 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + c_count, + count(*) as custdist + from + ( + select + c_custkey, + count(o_orderkey) + from + customer left outer join orders on + c_custkey = o_custkey + and o_comment not like '%special%requests%' + group by + c_custkey + ) as c_orders (c_custkey, c_count) + group by + c_count + order by + custdist desc, + c_count desc; """ from datafusion import SessionContext, col, lit @@ -49,20 +72,16 @@ F.regexp_match(col("o_comment"), lit(f"{WORD_1}.?*{WORD_2}")).is_null() ) -# Since we may have customers with no orders we must do a left join -df = df_customer.join( - df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="left" -) - -# Find the number of orders for each customer -df = df.aggregate([col("c_custkey")], [F.count(col("o_custkey")).alias("c_count")]) - -# Ultimately we want to know the number of customers that have that customer count -df = df.aggregate([col("c_count")], [F.count(col("c_count")).alias("custdist")]) - -# We want to order the results by the highest number of customers per count -df = df.sort( - col("custdist").sort(ascending=False), col("c_count").sort(ascending=False) +# Customers with no orders still participate, so this is a left join. Count the +# orders per customer, then count customers per order-count value. +df = ( + df_customer.join(df_orders, left_on="c_custkey", right_on="o_custkey", how="left") + .aggregate(["c_custkey"], [F.count(col("o_custkey")).alias("c_count")]) + .aggregate(["c_count"], [F.count_star().alias("custdist")]) + .sort( + col("custdist").sort(ascending=False), + col("c_count").sort(ascending=False), + ) ) df.show() diff --git a/examples/tpch/q14_promotion_effect.py b/examples/tpch/q14_promotion_effect.py index d62f76e3c..08f4f054d 100644 --- a/examples/tpch/q14_promotion_effect.py +++ b/examples/tpch/q14_promotion_effect.py @@ -24,20 +24,32 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
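# --- Editorial aside (illustration only, not part of the patch) ---
# The q13 hunk above stacks two aggregates: orders per customer, then
# customers per order-count. The left join keeps zero-order customers in
# the first count. Toy sketch (data invented):
from datafusion import SessionContext, col
from datafusion import functions as F

ctx = SessionContext()
customers = ctx.from_pydict({"c_custkey": [1, 2, 3]})
orders = ctx.from_pydict({"o_custkey": [1, 1, 2]})
(
    customers.join(orders, left_on="c_custkey", right_on="o_custkey", how="left")
    # count(o_custkey) ignores the NULLs produced for customers with no orders
    .aggregate(["c_custkey"], [F.count(col("o_custkey")).alias("c_count")])
    .aggregate(["c_count"], [F.count_star().alias("custdist")])
    .show()
)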
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + 100.00 * sum(case + when p_type like 'PROMO%' + then l_extendedprice * (1 - l_discount) + else 0 + end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue + from + lineitem, + part + where + l_partkey = p_partkey + and l_shipdate >= date '1995-09-01' + and l_shipdate < date '1995-09-01' + interval '1' month; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -DATE = "1995-09-01" - -date_of_interest = lit(datetime.strptime(DATE, "%Y-%m-%d").date()) - -interval_one_month = lit(pa.scalar((0, 30, 0), type=pa.month_day_nano_interval())) +MONTH_START = date(1995, 9, 1) +MONTH_END = date(1995, 10, 1) # Load the dataframes we need @@ -49,37 +61,30 @@ df_part = ctx.read_parquet(get_data_path("part.parquet")).select("p_partkey", "p_type") -# Check part type begins with PROMO -df_part = df_part.filter( - F.substring(col("p_type"), lit(0), lit(6)) == lit("PROMO") -).with_column("promo_factor", lit(1.0)) - -df_lineitem = df_lineitem.filter(col("l_shipdate") >= date_of_interest).filter( - col("l_shipdate") < date_of_interest + interval_one_month -) - -# Left join so we can sum up the promo parts different from other parts -df = df_lineitem.join( - df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="left" -) - -# Make a factor of 1.0 if it is a promotion, 0.0 otherwise -df = df.with_column("promo_factor", F.coalesce(col("promo_factor"), lit(0.0))) -df = df.with_column("revenue", col("l_extendedprice") * (lit(1.0) - col("l_discount"))) - - -# Sum up the promo and total revenue -df = df.aggregate( - [], - [ - F.sum(col("promo_factor") * col("revenue")).alias("promo_revenue"), - F.sum(col("revenue")).alias("total_revenue"), - ], -) - -# Return the percentage of revenue from promotions -df = df.select( - (lit(100.0) * col("promo_revenue") / col("total_revenue")).alias("promo_revenue") +# Restrict the line items to the month of interest, join the matching part +# rows, and aggregate revenue totals with a ``filter`` clause on the promo +# sum — the DataFrame form of SQL ``sum(... ) FILTER (WHERE ...)``. +revenue = col("l_extendedprice") * (lit(1.0) - col("l_discount")) +is_promo = F.starts_with(col("p_type"), lit("PROMO")) + +df = ( + df_lineitem.filter( + col("l_shipdate") >= lit(MONTH_START), + col("l_shipdate") < lit(MONTH_END), + ) + .join(df_part, left_on="l_partkey", right_on="p_partkey") + .aggregate( + [], + [ + F.sum(revenue, filter=is_promo).alias("promo_revenue"), + F.sum(revenue).alias("total_revenue"), + ], + ) + .select( + (lit(100.0) * col("promo_revenue") / col("total_revenue")).alias( + "promo_revenue" + ) + ) ) df.show() diff --git a/examples/tpch/q15_top_supplier.py b/examples/tpch/q15_top_supplier.py index c321048f2..01c38b9f8 100644 --- a/examples/tpch/q15_top_supplier.py +++ b/examples/tpch/q15_top_supplier.py @@ -24,20 +24,50 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + create view revenue0 (supplier_no, total_revenue) as + select + l_suppkey, + sum(l_extendedprice * (1 - l_discount)) + from + lineitem + where + l_shipdate >= date '1996-01-01' + and l_shipdate < date '1996-01-01' + interval '3' month + group by + l_suppkey; + select + s_suppkey, + s_name, + s_address, + s_phone, + total_revenue + from + supplier, + revenue0 + where + s_suppkey = supplier_no + and total_revenue = ( + select + max(total_revenue) + from + revenue0 + ) + order by + s_suppkey; + drop view revenue0; """ -from datetime import datetime +from datetime import date -import pyarrow as pa -from datafusion import SessionContext, WindowFrame, col, lit +from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -DATE = "1996-01-01" - -date_of_interest = lit(datetime.strptime(DATE, "%Y-%m-%d").date()) - -interval_3_months = lit(pa.scalar((0, 91, 0), type=pa.month_day_nano_interval())) +QUARTER_START = date(1996, 1, 1) +QUARTER_END = date(1996, 4, 1) # Load the dataframes we need @@ -53,37 +83,29 @@ "s_phone", ) -# Limit line items to the quarter of interest -df_lineitem = df_lineitem.filter(col("l_shipdate") >= date_of_interest).filter( - col("l_shipdate") < date_of_interest + interval_3_months -) +# Per-supplier revenue over the quarter of interest. +revenue = col("l_extendedprice") * (lit(1) - col("l_discount")) -df = df_lineitem.aggregate( - [col("l_suppkey")], - [ - F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias( - "total_revenue" - ) - ], -) +per_supplier_revenue = df_lineitem.filter( + col("l_shipdate") >= lit(QUARTER_START), + col("l_shipdate") < lit(QUARTER_END), +).aggregate(["l_suppkey"], [F.sum(revenue).alias("total_revenue")]) -# Use a window function to find the maximum revenue across the entire dataframe -window_frame = WindowFrame("rows", None, None) -df = df.with_column( - "max_revenue", F.window("max", [col("total_revenue")], window_frame=window_frame) +# Compute the grand maximum revenue separately and join on equality — the +# DataFrame stand-in for the reference SQL's +# ``total_revenue = (select max(total_revenue) from revenue0)`` subquery. 
+max_revenue = per_supplier_revenue.aggregate( + [], [F.max(col("total_revenue")).alias("max_rev")] ) -# Find all suppliers whose total revenue is the same as the maximum -df = df.filter(col("total_revenue") == col("max_revenue")) - -# Now that we know the supplier(s) with maximum revenue, get the rest of their information -# from the supplier table -df = df.join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner") +top_suppliers = per_supplier_revenue.join_on( + max_revenue, col("total_revenue") == col("max_rev") +).select("l_suppkey", "total_revenue") -# Return only the columns requested -df = df.select("s_suppkey", "s_name", "s_address", "s_phone", "total_revenue") - -# If we have more than one, sort by supplier number (suppkey) -df = df.sort(col("s_suppkey").sort()) +df = ( + df_supplier.join(top_suppliers, left_on="s_suppkey", right_on="l_suppkey") + .select("s_suppkey", "s_name", "s_address", "s_phone", "total_revenue") + .sort_by("s_suppkey") +) df.show() diff --git a/examples/tpch/q16_part_supplier_relationship.py b/examples/tpch/q16_part_supplier_relationship.py index 65043ffda..ddeadff5f 100644 --- a/examples/tpch/q16_part_supplier_relationship.py +++ b/examples/tpch/q16_part_supplier_relationship.py @@ -26,9 +26,41 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + p_brand, + p_type, + p_size, + count(distinct ps_suppkey) as supplier_cnt + from + partsupp, + part + where + p_partkey = ps_partkey + and p_brand <> 'Brand#45' + and p_type not like 'MEDIUM POLISHED%' + and p_size in (49, 14, 23, 45, 19, 3, 36, 9) + and ps_suppkey not in ( + select + s_suppkey + from + supplier + where + s_comment like '%Customer%Complaints%' + ) + group by + p_brand, + p_type, + p_size + order by + supplier_cnt desc, + p_brand, + p_type, + p_size; """ -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path @@ -52,39 +84,36 @@ ) df_unwanted_suppliers = df_supplier.filter( - ~F.regexp_match(col("s_comment"), lit("Customer.?*Complaints")).is_null() + F.regexp_like(col("s_comment"), lit("Customer.*Complaints")) ) -# Remove unwanted suppliers +# Remove unwanted suppliers via an anti join (DataFrame form of NOT IN). df_partsupp = df_partsupp.join( - df_unwanted_suppliers, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="anti" + df_unwanted_suppliers, left_on="ps_suppkey", right_on="s_suppkey", how="anti" ) -# Select the parts we are interested in -df_part = df_part.filter(col("p_brand") != lit(BRAND)) +# Select the parts we are interested in. df_part = df_part.filter( - F.substring(col("p_type"), lit(0), lit(len(TYPE_TO_IGNORE) + 1)) - != lit(TYPE_TO_IGNORE) -) - -# Python conversion of integer to literal casts it to int64 but the data for -# part size is stored as an int32, so perform a cast. Then check to find if the part -# size is within the array of possible sizes by checking the position of it is not -# null. 
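# --- Editorial aside (illustration only, not part of the patch) ---
# The q15 hunk above materialises a one-row grand maximum and equi-joins it
# back with ``join_on`` in place of the SQL scalar subquery. Toy sketch
# (data invented):
from datafusion import SessionContext, col
from datafusion import functions as F

ctx = SessionContext()
revenue = ctx.from_pydict({"suppkey": [1, 2, 3], "total": [5.0, 9.0, 9.0]})
max_rev = revenue.aggregate([], [F.max(col("total")).alias("max_rev")])
# join_on takes an arbitrary boolean expression instead of column names;
# ties with the maximum all survive, as in the reference SQL.
revenue.join_on(max_rev, col("total") == col("max_rev")).show()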
-p_sizes = F.make_array(*[lit(s).cast(pa.int32()) for s in SIZES_OF_INTEREST]) -df_part = df_part.filter(~F.array_position(p_sizes, col("p_size")).is_null()) - -df = df_part.join( - df_partsupp, left_on=["p_partkey"], right_on=["ps_partkey"], how="inner" + col("p_brand") != BRAND, + ~F.starts_with(col("p_type"), lit(TYPE_TO_IGNORE)), + F.in_list(col("p_size"), [lit(s) for s in SIZES_OF_INTEREST]), ) -df = df.select("p_brand", "p_type", "p_size", "ps_suppkey").distinct() - -df = df.aggregate( - [col("p_brand"), col("p_type"), col("p_size")], - [F.count(col("ps_suppkey")).alias("supplier_cnt")], +# For each (brand, type, size), count the distinct suppliers remaining. +df = ( + df_part.join(df_partsupp, left_on="p_partkey", right_on="ps_partkey") + .select("p_brand", "p_type", "p_size", "ps_suppkey") + .distinct() + .aggregate( + ["p_brand", "p_type", "p_size"], + [F.count(col("ps_suppkey")).alias("supplier_cnt")], + ) + .sort( + col("supplier_cnt").sort(ascending=False), + "p_brand", + "p_type", + "p_size", + ) ) -df = df.sort(col("supplier_cnt").sort(ascending=False)) - df.show() diff --git a/examples/tpch/q17_small_quantity_order.py b/examples/tpch/q17_small_quantity_order.py index 6d76fe506..f2229171f 100644 --- a/examples/tpch/q17_small_quantity_order.py +++ b/examples/tpch/q17_small_quantity_order.py @@ -26,10 +26,31 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + sum(l_extendedprice) / 7.0 as avg_yearly + from + lineitem, + part + where + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container = 'MED BOX' + and l_quantity < ( + select + 0.2 * avg(l_quantity) + from + lineitem + where + l_partkey = p_partkey + ); """ from datafusion import SessionContext, WindowFrame, col, lit from datafusion import functions as F +from datafusion.expr import Window from util import get_data_path BRAND = "Brand#23" @@ -46,32 +67,23 @@ "l_partkey", "l_quantity", "l_extendedprice" ) -# Limit to the problem statement's brand and container types -df = df_part.filter(col("p_brand") == lit(BRAND)).filter( - col("p_container") == lit(CONTAINER) -) - -# Combine data -df = df.join(df_lineitem, left_on=["p_partkey"], right_on=["l_partkey"], how="inner") +# Limit to parts of the target brand/container, join their line items, and +# attach the per-part average quantity via a partitioned window function — +# the DataFrame form of the SQL's correlated ``avg(l_quantity)`` subquery. 
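# --- Editorial aside (illustration only, not part of the patch) ---
# The q16 hunk above uses an anti join for ``ps_suppkey not in (...)``.
# Toy sketch (data invented):
from datafusion import SessionContext

ctx = SessionContext()
partsupp = ctx.from_pydict({"ps_suppkey": [1, 2, 3]})
bad = ctx.from_pydict({"s_suppkey": [2]})
# how="anti" keeps only left rows with no match on the right.
partsupp.join(bad, left_on="ps_suppkey", right_on="s_suppkey", how="anti").show()
# -> suppkeys 1 and 3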
+whole_frame = WindowFrame("rows", None, None) -# Find the average quantity -window_frame = WindowFrame("rows", None, None) -df = df.with_column( - "avg_quantity", - F.window( - "avg", - [col("l_quantity")], - window_frame=window_frame, - partition_by=[col("l_partkey")], - ), +df = ( + df_part.filter(col("p_brand") == BRAND, col("p_container") == CONTAINER) + .join(df_lineitem, left_on="p_partkey", right_on="l_partkey") + .with_column( + "avg_quantity", + F.avg(col("l_quantity")).over( + Window(partition_by=[col("l_partkey")], window_frame=whole_frame) + ), + ) + .filter(col("l_quantity") < lit(0.2) * col("avg_quantity")) + .aggregate([], [F.sum(col("l_extendedprice")).alias("total")]) + .select((col("total") / lit(7.0)).alias("avg_yearly")) ) -df = df.filter(col("l_quantity") < lit(0.2) * col("avg_quantity")) - -# Compute the total -df = df.aggregate([], [F.sum(col("l_extendedprice")).alias("total")]) - -# Divide by number of years in the problem statement to get average -df = df.select((col("total") / lit(7)).alias("avg_yearly")) - df.show() diff --git a/examples/tpch/q18_large_volume_customer.py b/examples/tpch/q18_large_volume_customer.py index 834d181c9..23132d60d 100644 --- a/examples/tpch/q18_large_volume_customer.py +++ b/examples/tpch/q18_large_volume_customer.py @@ -24,9 +24,44 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice, + sum(l_quantity) + from + customer, + orders, + lineitem + where + o_orderkey in ( + select + l_orderkey + from + lineitem + group by + l_orderkey having + sum(l_quantity) > 300 + ) + and c_custkey = o_custkey + and o_orderkey = l_orderkey + group by + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice + order by + o_totalprice desc, + o_orderdate limit 100; """ -from datafusion import SessionContext, col, lit +from datafusion import SessionContext, col from datafusion import functions as F from util import get_data_path @@ -46,22 +81,24 @@ "l_orderkey", "l_quantity", "l_extendedprice" ) -df = df_lineitem.aggregate( - [col("l_orderkey")], [F.sum(col("l_quantity")).alias("total_quantity")] +# Find orders whose total quantity exceeds the threshold, then join in the +# order + customer details the problem statement requires and sort. 
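# --- Editorial aside (illustration only, not part of the patch) ---
# The q17 hunk above answers the correlated ``avg(l_quantity)`` subquery
# with a window average partitioned by part key. Toy sketch (data invented):
from datafusion import SessionContext, WindowFrame, col, lit
from datafusion import functions as F
from datafusion.expr import Window

ctx = SessionContext()
items = ctx.from_pydict({"partkey": [1, 1, 2], "qty": [1.0, 9.0, 4.0]})
whole = WindowFrame("rows", None, None)
items.with_column(
    "avg_qty",
    F.avg(col("qty")).over(Window(partition_by=[col("partkey")], window_frame=whole)),
).filter(col("qty") < lit(0.5) * col("avg_qty")).show()
# partkey 1 averages 5.0, so only its qty=1.0 row survives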
+df = ( + df_lineitem.aggregate( + ["l_orderkey"], [F.sum(col("l_quantity")).alias("total_quantity")] + ) + .filter(col("total_quantity") > QUANTITY) + .join(df_orders, left_on="l_orderkey", right_on="o_orderkey") + .join(df_customer, left_on="o_custkey", right_on="c_custkey") + .select( + "c_name", + "c_custkey", + "o_orderkey", + "o_orderdate", + "o_totalprice", + "total_quantity", + ) + .sort(col("o_totalprice").sort(ascending=False), "o_orderdate") ) -# Limit to orders in which the total quantity is above a threshold -df = df.filter(col("total_quantity") > lit(QUANTITY)) - -# We've identified the orders of interest, now join the additional data -# we are required to report on -df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner") -df = df.join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner") - -df = df.select( - "c_name", "c_custkey", "o_orderkey", "o_orderdate", "o_totalprice", "total_quantity" -) - -df = df.sort(col("o_totalprice").sort(ascending=False), col("o_orderdate").sort()) - df.show() diff --git a/examples/tpch/q19_discounted_revenue.py b/examples/tpch/q19_discounted_revenue.py index bd492aac0..a2be1c1b7 100644 --- a/examples/tpch/q19_discounted_revenue.py +++ b/examples/tpch/q19_discounted_revenue.py @@ -24,10 +24,47 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + sum(l_extendedprice* (1 - l_discount)) as revenue + from + lineitem, + part + where + ( + p_partkey = l_partkey + and p_brand = 'Brand#12' + and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + and l_quantity >= 1 and l_quantity <= 1 + 10 + and p_size between 1 and 5 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + and l_quantity >= 10 and l_quantity <= 10 + 10 + and p_size between 1 and 10 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#34' + and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + and l_quantity >= 20 and l_quantity <= 20 + 10 + and p_size between 1 and 15 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ); """ -import pyarrow as pa -from datafusion import SessionContext, col, lit, udf +from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path @@ -65,72 +102,41 @@ "l_discount", ) -# These limitations apply to all line items, so go ahead and do them first - -df = df_lineitem.filter(col("l_shipinstruct") == lit("DELIVER IN PERSON")) - -df = df.filter( - (col("l_shipmode") == lit("AIR")) | (col("l_shipmode") == lit("AIR REG")) -) +# Filter conditions that apply to every disjunct of the reference SQL's WHERE +# clause — pull them out up front so the per-brand predicate stays focused on +# the brand-specific parts. +df = df_lineitem.filter( + col("l_shipinstruct") == "DELIVER IN PERSON", + F.in_list(col("l_shipmode"), [lit("AIR"), lit("AIR REG")]), +).join(df_part, left_on="l_partkey", right_on="p_partkey") + + +# Build one OR-combined predicate per brand. Each disjunct encodes the +# brand-specific container list, quantity window, and size range from the +# reference SQL. 
This mirrors the SQL ``where (... brand A ...) or (... brand +# B ...) or (... brand C ...)`` form directly, without a UDF. +def _brand_predicate( + brand: str, min_quantity: int, containers: list[str], max_size: int +): + return ( + (col("p_brand") == brand) + & F.in_list(col("p_container"), [lit(c) for c in containers]) + & col("l_quantity").between(lit(min_quantity), lit(min_quantity + 10)) + & col("p_size").between(lit(1), lit(max_size)) + ) -df = df.join(df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="inner") - - -# Create the user defined function (UDF) definition that does the work -def is_of_interest( - brand_arr: pa.Array, - container_arr: pa.Array, - quantity_arr: pa.Array, - size_arr: pa.Array, -) -> pa.Array: - """ - The purpose of this function is to demonstrate how a UDF works, taking as input a pyarrow Array - and generating a resultant Array. The length of the inputs should match and there should be the - same number of rows in the output. - """ - result = [] - for idx, brand_val in enumerate(brand_arr): - brand = brand_val.as_py() - if brand in items_of_interest: - values_of_interest = items_of_interest[brand] - - container_matches = ( - container_arr[idx].as_py() in values_of_interest["containers"] - ) - - quantity = quantity_arr[idx].as_py() - quantity_matches = ( - values_of_interest["min_quantity"] - <= quantity - <= values_of_interest["min_quantity"] + 10 - ) - - size = size_arr[idx].as_py() - size_matches = 1 <= size <= values_of_interest["max_size"] - - result.append(container_matches and quantity_matches and size_matches) - else: - result.append(False) - - return pa.array(result) - - -# Turn the above function into a UDF that DataFusion can understand -is_of_interest_udf = udf( - is_of_interest, - [pa.utf8(), pa.utf8(), pa.decimal128(15, 2), pa.int32()], - pa.bool_(), - "stable", -) -# Filter results using the above UDF -df = df.filter( - is_of_interest_udf( - col("p_brand"), col("p_container"), col("l_quantity"), col("p_size") +predicate = None +for brand, params in items_of_interest.items(): + part_predicate = _brand_predicate( + brand, + params["min_quantity"], + params["containers"], + params["max_size"], ) -) + predicate = part_predicate if predicate is None else predicate | part_predicate -df = df.aggregate( +df = df.filter(predicate).aggregate( [], [F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias("revenue")], ) diff --git a/examples/tpch/q20_potential_part_promotion.py b/examples/tpch/q20_potential_part_promotion.py index a25188d31..18f96da97 100644 --- a/examples/tpch/q20_potential_part_promotion.py +++ b/examples/tpch/q20_potential_part_promotion.py @@ -25,17 +25,57 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
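# --- Editorial aside (illustration only, not part of the patch) ---
# The q19 hunk above folds per-brand predicates together with ``|`` instead
# of a row-at-a-time UDF, keeping the whole filter vectorised. Toy sketch of
# the folding pattern (data and predicate parameters invented):
from datafusion import SessionContext, col, lit

ctx = SessionContext()
rows = ctx.from_pydict({"brand": ["A", "B", "C"], "qty": [1, 20, 5]})
cases = [("A", 0, 10), ("B", 15, 25)]  # (brand, lo, hi)

predicate = None
for brand, lo, hi in cases:
    p = (col("brand") == lit(brand)) & col("qty").between(lit(lo), lit(hi))
    predicate = p if predicate is None else predicate | p

rows.filter(predicate).show()  # brands A and B survive; C matches nothing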
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + s_name, + s_address + from + supplier, + nation + where + s_suppkey in ( + select + ps_suppkey + from + partsupp + where + ps_partkey in ( + select + p_partkey + from + part + where + p_name like 'forest%' + ) + and ps_availqty > ( + select + 0.5 * sum(l_quantity) + from + lineitem + where + l_partkey = ps_partkey + and l_suppkey = ps_suppkey + and l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + ) + ) + and s_nationkey = n_nationkey + and n_name = 'CANADA' + order by + s_name; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path COLOR_OF_INTEREST = "forest" -DATE_OF_INTEREST = "1994-01-01" +YEAR_START = date(1994, 1, 1) +YEAR_END = date(1995, 1, 1) NATION_OF_INTEREST = "CANADA" # Load the dataframes we need @@ -56,46 +96,48 @@ "n_nationkey", "n_name" ) -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, 365, 0), type=pa.month_day_nano_interval()) - -# Filter down dataframes -df_nation = df_nation.filter(col("n_name") == lit(NATION_OF_INTEREST)) -df_part = df_part.filter( - F.substring(col("p_name"), lit(0), lit(len(COLOR_OF_INTEREST) + 1)) - == lit(COLOR_OF_INTEREST) -) - -df = df_lineitem.filter(col("l_shipdate") >= lit(date)).filter( - col("l_shipdate") < lit(date) + lit(interval) +# Filter down dataframes. ``starts_with`` reads more naturally than an +# explicit substring slice and maps directly to the reference SQL's +# ``p_name like 'forest%'`` clause. +df_nation = df_nation.filter(col("n_name") == NATION_OF_INTEREST) +df_part = df_part.filter(F.starts_with(col("p_name"), lit(COLOR_OF_INTEREST))) + +# Compute the total quantity of interesting parts shipped by each (part, +# supplier) pair within the year of interest. +totals = ( + df_lineitem.filter( + col("l_shipdate") >= lit(YEAR_START), + col("l_shipdate") < lit(YEAR_END), + ) + .join(df_part, left_on="l_partkey", right_on="p_partkey") + .aggregate( + ["l_partkey", "l_suppkey"], + [F.sum(col("l_quantity")).alias("total_sold")], + ) ) -# This will filter down the line items to the parts of interest -df = df.join(df_part, left_on="l_partkey", right_on="p_partkey", how="inner") - -# Compute the total sold and limit ourselves to individual supplier/part combinations -df = df.aggregate( - [col("l_partkey"), col("l_suppkey")], [F.sum(col("l_quantity")).alias("total_sold")] +# Keep only (part, supplier) pairs whose available quantity exceeds 50% of +# the total shipped. The result already contains one row per supplier of +# interest, so we can semi-join the supplier table rather than inner-join +# and deduplicate afterwards. +excess_suppliers = ( + df_partsupp.join( + totals, + left_on=["ps_partkey", "ps_suppkey"], + right_on=["l_partkey", "l_suppkey"], + ) + .filter(col("ps_availqty") > lit(0.5) * col("total_sold")) + .select(col("ps_suppkey").alias("suppkey")) + .distinct() ) -df = df.join( - df_partsupp, - left_on=["l_partkey", "l_suppkey"], - right_on=["ps_partkey", "ps_suppkey"], - how="inner", +# Limit to suppliers in the nation of interest and pick out the two +# requested columns. 
+df = ( + df_supplier.join(df_nation, left_on="s_nationkey", right_on="n_nationkey") + .join(excess_suppliers, left_on="s_suppkey", right_on="suppkey", how="semi") + .select("s_name", "s_address") + .sort_by("s_name") ) -# Find cases of excess quantity -df.filter(col("ps_availqty") > lit(0.5) * col("total_sold")) - -# We could do these joins earlier, but now limit to the nation of interest suppliers -df = df.join(df_supplier, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="inner") -df = df.join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner") - -# Restrict to the requested data per the problem statement -df = df.select("s_name", "s_address").distinct() - -df = df.sort(col("s_name").sort()) - df.show() diff --git a/examples/tpch/q21_suppliers_kept_orders_waiting.py b/examples/tpch/q21_suppliers_kept_orders_waiting.py index 4ee9d3733..d98f76ce7 100644 --- a/examples/tpch/q21_suppliers_kept_orders_waiting.py +++ b/examples/tpch/q21_suppliers_kept_orders_waiting.py @@ -24,9 +24,51 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + s_name, + count(*) as numwait + from + supplier, + lineitem l1, + orders, + nation + where + s_suppkey = l1.l_suppkey + and o_orderkey = l1.l_orderkey + and o_orderstatus = 'F' + and l1.l_receiptdate > l1.l_commitdate + and exists ( + select + * + from + lineitem l2 + where + l2.l_orderkey = l1.l_orderkey + and l2.l_suppkey <> l1.l_suppkey + ) + and not exists ( + select + * + from + lineitem l3 + where + l3.l_orderkey = l1.l_orderkey + and l3.l_suppkey <> l1.l_suppkey + and l3.l_receiptdate > l3.l_commitdate + ) + and s_nationkey = n_nationkey + and n_name = 'SAUDI ARABIA' + group by + s_name + order by + numwait desc, + s_name limit 100; """ -from datafusion import SessionContext, col, lit +from datafusion import SessionContext, col from datafusion import functions as F from util import get_data_path @@ -50,65 +92,57 @@ ) # Limit to suppliers in the nation of interest -df_suppliers_of_interest = df_nation.filter(col("n_name") == lit(NATION_OF_INTEREST)) - -df_suppliers_of_interest = df_suppliers_of_interest.join( - df_supplier, left_on="n_nationkey", right_on="s_nationkey", how="inner" +df_suppliers_of_interest = df_nation.filter(col("n_name") == NATION_OF_INTEREST).join( + df_supplier, left_on="n_nationkey", right_on="s_nationkey" ) -# Find the failed orders and all their line items -df = df_orders.filter(col("o_orderstatus") == lit("F")) - -df = df_lineitem.join(df, left_on="l_orderkey", right_on="o_orderkey", how="inner") - -# Identify the line items for which the order is failed due to. -df = df.with_column( - "failed_supp", - F.case(col("l_receiptdate") > col("l_commitdate")) - .when(lit(value=True), col("l_suppkey")) - .end(), +# Line items for orders that have status 'F'. This is the candidate set of +# (order, supplier) pairs we reason about below. +failed_order_lineitems = df_lineitem.join( + df_orders.filter(col("o_orderstatus") == "F"), + left_on="l_orderkey", + right_on="o_orderkey", ) -# There are other ways we could do this but the purpose of this example is to work with rows where -# an element is an array of values. In this case, we will create two columns of arrays. One will be -# an array of all of the suppliers who made up this order. 
That way we can filter the dataframe for -# only orders where this array is larger than one for multiple supplier orders. The second column -# is all of the suppliers who failed to make their commitment. We can filter the second column for -# arrays with size one. That combination will give us orders that had multiple suppliers where only -# one failed. Use distinct=True in the blow aggregation so we don't get multiple line items from the -# same supplier reported in either array. -df = df.aggregate( - [col("o_orderkey")], - [ - F.array_agg(col("l_suppkey"), distinct=True).alias("all_suppliers"), - F.array_agg( - col("failed_supp"), filter=col("failed_supp").is_not_null(), distinct=True - ).alias("failed_suppliers"), - ], +# Line items whose receipt was late. This corresponds to ``l1`` in the +# reference SQL. +late_lineitems = failed_order_lineitems.filter( + col("l_receiptdate") > col("l_commitdate") ) -# This is the check described above which will identify single failed supplier in a multiple -# supplier order. -df = df.filter(F.array_length(col("failed_suppliers")) == lit(1)).filter( - F.array_length(col("all_suppliers")) > lit(1) +# Orders that had more than one distinct supplier. Expressed as +# ``count(distinct l_suppkey) > 1``. Stands in for the reference SQL's +# ``exists (... l2.l_suppkey <> l1.l_suppkey ...)`` subquery. +multi_supplier_orders = ( + failed_order_lineitems.select("l_orderkey", "l_suppkey") + .distinct() + .aggregate(["l_orderkey"], [F.count_star().alias("n_suppliers")]) + .filter(col("n_suppliers") > 1) + .select("l_orderkey") ) -# Since we have an array we know is exactly one element long, we can extract that single value. -df = df.select( - col("o_orderkey"), F.array_element(col("failed_suppliers"), lit(1)).alias("suppkey") +# Orders where exactly one distinct supplier was late. Stands in for the +# reference SQL's ``not exists (... l3.l_suppkey <> l1.l_suppkey and l3 is +# also late ...)`` subquery: if only one supplier on the order was late, +# nobody else on the same order was late. +single_late_supplier_orders = ( + late_lineitems.select("l_orderkey", "l_suppkey") + .distinct() + .aggregate(["l_orderkey"], [F.count_star().alias("n_late_suppliers")]) + .filter(col("n_late_suppliers") == 1) + .select("l_orderkey") ) -# Join to the supplier of interest list for the nation of interest -df = df.join( - df_suppliers_of_interest, left_on=["suppkey"], right_on=["s_suppkey"], how="inner" +# Keep late line items whose order qualifies on both counts, attach the +# supplier name for suppliers in the nation of interest, count one row per +# qualifying order, and return the top 100. 
+df = ( + late_lineitems.join(multi_supplier_orders, on="l_orderkey", how="semi") + .join(single_late_supplier_orders, on="l_orderkey", how="semi") + .join(df_suppliers_of_interest, left_on="l_suppkey", right_on="s_suppkey") + .aggregate(["s_name"], [F.count_star().alias("numwait")]) + .sort(col("numwait").sort(ascending=False), "s_name") + .limit(100) ) -# Count how many orders that supplier is the only failed supplier for -df = df.aggregate([col("s_name")], [F.count(col("o_orderkey")).alias("numwait")]) - -# Return in descending order -df = df.sort(col("numwait").sort(ascending=False), col("s_name").sort()) - -df = df.limit(100) - df.show() diff --git a/examples/tpch/q22_global_sales_opportunity.py b/examples/tpch/q22_global_sales_opportunity.py index c4d115b74..5043eeb51 100644 --- a/examples/tpch/q22_global_sales_opportunity.py +++ b/examples/tpch/q22_global_sales_opportunity.py @@ -24,10 +24,51 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + cntrycode, + count(*) as numcust, + sum(c_acctbal) as totacctbal + from + ( + select + substring(c_phone from 1 for 2) as cntrycode, + c_acctbal + from + customer + where + substring(c_phone from 1 for 2) in + ('13', '31', '23', '29', '30', '18', '17') + and c_acctbal > ( + select + avg(c_acctbal) + from + customer + where + c_acctbal > 0.00 + and substring(c_phone from 1 for 2) in + ('13', '31', '23', '29', '30', '18', '17') + ) + and not exists ( + select + * + from + orders + where + o_custkey = c_custkey + ) + ) as custsale + group by + cntrycode + order by + cntrycode; """ from datafusion import SessionContext, WindowFrame, col, lit from datafusion import functions as F +from datafusion.expr import Window from util import get_data_path NATION_CODES = [13, 31, 23, 29, 30, 18, 17] @@ -41,39 +82,36 @@ ) df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select("o_custkey") -# The nation code is a two digit number, but we need to convert it to a string literal -nation_codes = F.make_array(*[lit(str(n)) for n in NATION_CODES]) - -# Use the substring operation to extract the first two characters of the phone number -df = df_customer.with_column("cntrycode", F.substring(col("c_phone"), lit(0), lit(3))) - -# Limit our search to customers with some balance and in the country code above -df = df.filter(col("c_acctbal") > lit(0.0)) -df = df.filter(~F.array_position(nation_codes, col("cntrycode")).is_null()) - -# Compute the average balance. By default, the window frame is from unbounded preceding to the -# current row. We want our frame to cover the entire data frame. -window_frame = WindowFrame("rows", None, None) -df = df.with_column( - "avg_balance", F.window("avg", [col("c_acctbal")], window_frame=window_frame) -) - -df.show() -# Limit results to customers with above average balance -df = df.filter(col("c_acctbal") > col("avg_balance")) - -# Limit results to customers with no orders -df = df.join(df_orders, left_on="c_custkey", right_on="o_custkey", how="anti") - -# Count up the customers and the balances -df = df.aggregate( - [col("cntrycode")], - [ - F.count(col("c_custkey")).alias("numcust"), - F.sum(col("c_acctbal")).alias("totacctbal"), - ], +# Country code is the two-digit prefix of the phone number. 
+nation_codes = [lit(str(n)) for n in NATION_CODES] + +# Start from customers with a positive balance in one of the target country +# codes, then attach the grand-mean balance via a whole-frame window so we +# can filter per row — DataFrame stand-in for the SQL's scalar ``(select +# avg(c_acctbal) ... )`` subquery. +whole_frame = WindowFrame("rows", None, None) + +df = ( + df_customer.with_column("cntrycode", F.left(col("c_phone"), lit(2))) + .filter( + col("c_acctbal") > 0.0, + F.in_list(col("cntrycode"), nation_codes), + ) + .with_column( + "avg_balance", + F.avg(col("c_acctbal")).over(Window(window_frame=whole_frame)), + ) + .filter(col("c_acctbal") > col("avg_balance")) + # Keep only customers with no orders (anti join = NOT EXISTS). + .join(df_orders, left_on="c_custkey", right_on="o_custkey", how="anti") + .aggregate( + ["cntrycode"], + [ + F.count_star().alias("numcust"), + F.sum(col("c_acctbal")).alias("totacctbal"), + ], + ) + .sort_by("cntrycode") ) -df = df.sort(col("cntrycode").sort()) - df.show() diff --git a/pyproject.toml b/pyproject.toml index d05a64083..951f7adc3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,6 +111,7 @@ extend-allowed-calls = ["datafusion.lit", "lit"] "ARG", "BLE001", "D", + "FBT003", "PD", "PLC0415", "PLR0913", @@ -170,12 +171,15 @@ extend-allowed-calls = ["datafusion.lit", "lit"] "docs/*" = ["D"] "docs/source/conf.py" = ["ANN001", "ERA001", "INP001"] +# CI and pre-commit invoke codespell with different paths, so we have a little +# redundancy here, and we intentionally drop python in the path. [tool.codespell] skip = [ - "./python/tests/test_functions.py", - "./target", + "*/tests/test_functions.py", + "*/target", + "./uv.lock", "uv.lock", - "./examples/tpch/answers_sf1/*", + "*/tpch/answers_sf1/*", ] count = true ignore-words-list = ["IST", "ans"] diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 2e6f81166..e4972411a 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -15,10 +15,44 @@ # specific language governing permissions and limitations # under the License. -"""DataFusion python package. - -This is a Python library that binds to Apache Arrow in-memory query engine DataFusion. -See https://datafusion.apache.org/python for more information. +"""DataFusion: an in-process query engine built on Apache Arrow. + +DataFusion is not a database -- it has no server and no external dependencies. +You create a :py:class:`SessionContext`, point it at data sources (Parquet, CSV, +JSON, Arrow IPC, Pandas, Polars, or raw Python dicts/lists), and run queries +using either SQL or the DataFrame API. + +Core abstractions +----------------- +- **SessionContext** -- entry point for loading data, running SQL, and creating + DataFrames. +- **DataFrame** -- lazy query builder. Every method returns a new DataFrame; + call :py:meth:`~datafusion.dataframe.DataFrame.collect` or a ``to_*`` + method to execute. +- **Expr** -- expression tree node for column references, literals, and function + calls. Build with :py:func:`col` and :py:func:`lit`. +- **functions** -- 290+ built-in scalar, aggregate, and window functions. + +Quick start +----------- + +>>> from datafusion import SessionContext, col +>>> from datafusion import functions as F +>>> ctx = SessionContext() +>>> df = ctx.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}) +>>> result = ( +... df.filter(col("a") > 1) +... .with_column("total", col("a") + col("b")) +... .aggregate([], [F.sum(col("total")).alias("grand_total")]) +... 
)
+>>> result.to_pydict()
+{'grand_total': [16]}
+
+User guide and full documentation: https://datafusion.apache.org/python
+
+AI agent reference (SQL-to-DataFrame mappings, expression-building patterns,
+common pitfalls), written in a dense, skill-oriented format:
+https://github.com/apache/datafusion-python/blob/main/SKILL.md
 """

 from __future__ import annotations
@@ -35,7 +69,7 @@
 # The following imports are okay to remain as opaque to the user.
 from ._internal import Config

-from .catalog import Catalog, Database, Table
+from .catalog import Catalog, Table
 from .col import col, column
 from .common import DFSchema
 from .context import (
@@ -47,6 +81,7 @@
 from .dataframe import (
     DataFrame,
     DataFrameWriteOptions,
+    ExplainFormat,
     InsertOp,
     ParquetColumnOptions,
     ParquetWriterOptions,
@@ -55,7 +90,7 @@
 from .expr import Expr, WindowFrame
 from .io import read_avro, read_csv, read_json, read_parquet
 from .options import CsvReadOptions
-from .plan import ExecutionPlan, LogicalPlan
+from .plan import ExecutionPlan, LogicalPlan, Metric, MetricsSet
 from .record_batch import RecordBatch, RecordBatchStream
 from .user_defined import (
     Accumulator,
@@ -80,11 +115,13 @@
     "DFSchema",
     "DataFrame",
     "DataFrameWriteOptions",
-    "Database",
     "ExecutionPlan",
+    "ExplainFormat",
     "Expr",
     "InsertOp",
     "LogicalPlan",
+    "Metric",
+    "MetricsSet",
     "ParquetColumnOptions",
     "ParquetWriterOptions",
     "RecordBatch",
diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py
index 03c0ddc68..20da5e671 100644
--- a/python/datafusion/catalog.py
+++ b/python/datafusion/catalog.py
@@ -129,11 +129,6 @@ def schema(self, name: str = "public") -> Schema:
             else schema
         )

-    @deprecated("Use `schema` instead.")
-    def database(self, name: str = "public") -> Schema:
-        """Returns the database with the given ``name`` from this catalog."""
-        return self.schema(name)
-
     def register_schema(
         self,
         name: str,
@@ -195,11 +190,6 @@ def table_exist(self, name: str) -> bool:
         return self._raw_schema.table_exist(name)


-@deprecated("Use `Schema` instead.")
-class Database(Schema):
-    """See `Schema`."""
-
-
 class Table:
     """A DataFusion table.

diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index c8edc816f..dd6790402 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -15,7 +15,32 @@
 # specific language governing permissions and limitations
 # under the License.

-"""Session Context and it's associated configuration."""
+""":py:class:`SessionContext` — entry point for running DataFusion queries.
+
+A :py:class:`SessionContext` holds registered tables, catalogs, and
+configuration for the current session. It is the first object most programs
+create: from it you register data, run SQL strings
+(:py:meth:`SessionContext.sql`), read files
+(:py:meth:`SessionContext.read_csv`,
+:py:meth:`SessionContext.read_parquet`, ...), and construct
+:py:class:`~datafusion.dataframe.DataFrame` objects in memory
+(:py:meth:`SessionContext.from_pydict`,
+:py:meth:`SessionContext.from_arrow`).
+
+Session behavior (memory limits, batch size, configured optimizer passes,
+...) is controlled by :py:class:`SessionConfig` and
+:py:class:`RuntimeEnvBuilder`; which kinds of SQL statement are permitted
+(DDL, DML, and so on) is controlled by :py:class:`SQLOptions`.
+
+Examples:
+    >>> ctx = dfn.SessionContext()
+    >>> df = ctx.from_pydict({"a": [1, 2, 3]})
+    >>> ctx.sql("SELECT 1 AS n").to_pydict()
+    {'n': [1]}
+
+See :ref:`user_guide_concepts` in the online documentation for the broader
+execution model.
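+
+Registering data under a name makes it visible to SQL (a minimal sketch;
+``t`` is an arbitrary table name chosen for illustration):
+
+    >>> ctx = dfn.SessionContext()
+    >>> df = ctx.from_pydict({"a": [1, 2, 3]}, name="t")
+    >>> ctx.sql("SELECT sum(a) AS total FROM t").to_pydict()
+    {'total': [6]}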
+""" from __future__ import annotations @@ -63,7 +88,8 @@ import polars as pl # type: ignore[import] from datafusion.catalog import CatalogProvider, Table - from datafusion.expr import SortKey + from datafusion.common import DFSchema + from datafusion.expr import Expr, SortKey from datafusion.plan import ExecutionPlan, LogicalPlan from datafusion.user_defined import ( AggregateUDF, @@ -425,11 +451,6 @@ def with_temp_file_path(self, path: str | pathlib.Path) -> RuntimeEnvBuilder: return self -@deprecated("Use `RuntimeEnvBuilder` instead.") -class RuntimeConfig(RuntimeEnvBuilder): - """See `RuntimeEnvBuilder`.""" - - class SQLOptions: """Options to be used when performing SQL queries.""" @@ -568,6 +589,15 @@ def register_object_store( """ self.ctx.register_object_store(schema, store, host) + def deregister_object_store(self, schema: str, host: str | None = None) -> None: + """Remove an object store from the session. + + Args: + schema: The data source schema (e.g. ``"s3://"``). + host: URL for the host (e.g. bucket name). + """ + self.ctx.deregister_object_store(schema, host) + def register_listing_table( self, name: str, @@ -775,14 +805,6 @@ def from_arrow( """ return DataFrame(self.ctx.from_arrow(data, name)) - @deprecated("Use ``from_arrow`` instead.") - def from_arrow_table(self, data: pa.Table, name: str | None = None) -> DataFrame: - """Create a :py:class:`~datafusion.dataframe.DataFrame` from an Arrow table. - - This is an alias for :py:func:`from_arrow`. - """ - return self.from_arrow(data, name) - def from_pandas(self, data: pd.DataFrame, name: str | None = None) -> DataFrame: """Create a :py:class:`~datafusion.dataframe.DataFrame` from a Pandas DataFrame. @@ -894,6 +916,35 @@ def register_udtf(self, func: TableFunction) -> None: """Register a user defined table function.""" self.ctx.register_udtf(func._udtf) + def register_batch(self, name: str, batch: pa.RecordBatch) -> None: + """Register a single :py:class:`pa.RecordBatch` as a table. + + Args: + name: Name of the resultant table. + batch: Record batch to register as a table. + + Examples: + >>> ctx = dfn.SessionContext() + >>> batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]}) + >>> ctx.register_batch("batch_tbl", batch) + >>> ctx.sql("SELECT * FROM batch_tbl").collect()[0].column(0) + + [ + 1, + 2, + 3 + ] + """ + self.ctx.register_batch(name, batch) + + def deregister_udtf(self, name: str) -> None: + """Remove a user-defined table function from the session. + + Args: + name: Name of the UDTF to deregister. + """ + self.ctx.deregister_udtf(name) + def register_record_batches( self, name: str, partitions: list[list[pa.RecordBatch]] ) -> None: @@ -1092,6 +1143,86 @@ def register_avro( name, str(path), schema, file_extension, table_partition_cols ) + def register_arrow( + self, + name: str, + path: str | pathlib.Path, + schema: pa.Schema | None = None, + file_extension: str = ".arrow", + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, + ) -> None: + """Register an Arrow IPC file as a table. + + The registered table can be referenced from SQL statements executed + against this context. + + Args: + name: Name of the table to register. + path: Path to the Arrow IPC file. + schema: The data source schema. + file_extension: File extension to select. + table_partition_cols: Partition columns. + + Examples: + >>> import tempfile, os + >>> ctx = dfn.SessionContext() + >>> table = pa.table({"x": [10, 20, 30]}) + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... 
path = os.path.join(tmpdir, "data.arrow") + ... with pa.ipc.new_file(path, table.schema) as writer: + ... writer.write_table(table) + ... ctx.register_arrow("arrow_tbl", path) + ... ctx.sql("SELECT * FROM arrow_tbl").collect()[0].column(0) + + [ + 10, + 20, + 30 + ] + + Provide an explicit ``schema`` to override schema inference: + + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... path = os.path.join(tmpdir, "data.arrow") + ... with pa.ipc.new_file(path, table.schema) as writer: + ... writer.write_table(table) + ... ctx.register_arrow( + ... "arrow_schema", + ... path, + ... schema=pa.schema([("x", pa.int64())]), + ... ) + ... ctx.sql("SELECT * FROM arrow_schema").collect()[0].column(0) + + [ + 10, + 20, + 30 + ] + + Use ``file_extension`` to read files with a non-default extension: + + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... path = os.path.join(tmpdir, "data.ipc") + ... with pa.ipc.new_file(path, table.schema) as writer: + ... writer.write_table(table) + ... ctx.register_arrow( + ... "arrow_ipc", path, file_extension=".ipc" + ... ) + ... ctx.sql("SELECT * FROM arrow_ipc").collect()[0].column(0) + + [ + 10, + 20, + 30 + ] + """ + if table_partition_cols is None: + table_partition_cols = [] + table_partition_cols = _convert_table_partition_cols(table_partition_cols) + self.ctx.register_arrow( + name, str(path), schema, file_extension, table_partition_cols + ) + def register_dataset(self, name: str, dataset: pa.dataset.Dataset) -> None: """Register a :py:class:`pa.dataset.Dataset` as a table. @@ -1105,26 +1236,42 @@ def register_udf(self, udf: ScalarUDF) -> None: """Register a user-defined function (UDF) with the context.""" self.ctx.register_udf(udf._udf) + def deregister_udf(self, name: str) -> None: + """Remove a user-defined scalar function from the session. + + Args: + name: Name of the UDF to deregister. + """ + self.ctx.deregister_udf(name) + def register_udaf(self, udaf: AggregateUDF) -> None: """Register a user-defined aggregation function (UDAF) with the context.""" self.ctx.register_udaf(udaf._udaf) + def deregister_udaf(self, name: str) -> None: + """Remove a user-defined aggregate function from the session. + + Args: + name: Name of the UDAF to deregister. + """ + self.ctx.deregister_udaf(name) + def register_udwf(self, udwf: WindowUDF) -> None: """Register a user-defined window function (UDWF) with the context.""" self.ctx.register_udwf(udwf._udwf) + def deregister_udwf(self, name: str) -> None: + """Remove a user-defined window function from the session. + + Args: + name: Name of the UDWF to deregister. + """ + self.ctx.deregister_udwf(name) + def catalog(self, name: str = "datafusion") -> Catalog: """Retrieve a catalog by name.""" return Catalog(self.ctx.catalog(name)) - @deprecated( - "Use the catalog provider interface ``SessionContext.Catalog`` to " - "examine available catalogs, schemas and tables" - ) - def tables(self) -> set[str]: - """Deprecated.""" - return self.ctx.tables() - def table(self, name: str) -> DataFrame: """Retrieve a previously registered table by name.""" return DataFrame(self.ctx.table(name)) @@ -1141,6 +1288,121 @@ def session_id(self) -> str: """Return an id that uniquely identifies this :py:class:`SessionContext`.""" return self.ctx.session_id() + def session_start_time(self) -> str: + """Return the session start time as an RFC 3339 formatted string. 
+ + Examples: + >>> ctx = SessionContext() + >>> ctx.session_start_time() # doctest: +SKIP + '2026-01-01T12:34:56.123456789+00:00' + """ + return self.ctx.session_start_time() + + def enable_ident_normalization(self) -> bool: + """Return whether identifier normalization (lowercasing) is enabled. + + Examples: + >>> ctx = SessionContext() + >>> ctx.enable_ident_normalization() + True + """ + return self.ctx.enable_ident_normalization() + + def parse_sql_expr(self, sql: str, schema: DFSchema) -> Expr: + """Parse a SQL expression string into a logical expression. + + Args: + sql: SQL expression string. + schema: Schema to use for resolving column references. + + Returns: + Parsed expression. + + Examples: + >>> from datafusion.common import DFSchema + >>> ctx = SessionContext() + >>> schema = DFSchema.empty() + >>> ctx.parse_sql_expr("1 + 2", schema=schema) + Expr(Int64(1) + Int64(2)) + """ + from datafusion.expr import Expr # noqa: PLC0415 + + return Expr(self.ctx.parse_sql_expr(sql, schema)) + + def execute_logical_plan(self, plan: LogicalPlan) -> DataFrame: + """Execute a :py:class:`~datafusion.plan.LogicalPlan` and return a DataFrame. + + Args: + plan: Logical plan to execute. + + Returns: + DataFrame resulting from the execution. + + Examples: + >>> ctx = SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3]}) + >>> plan = df.logical_plan() + >>> df2 = ctx.execute_logical_plan(plan) + >>> df2.collect()[0].column(0) + + [ + 1, + 2, + 3 + ] + """ + return DataFrame(self.ctx.execute_logical_plan(plan._raw_plan)) + + def refresh_catalogs(self) -> None: + """Refresh catalog metadata. + + Examples: + >>> ctx = SessionContext() + >>> ctx.refresh_catalogs() + """ + self.ctx.refresh_catalogs() + + def remove_optimizer_rule(self, name: str) -> bool: + """Remove an optimizer rule by name. + + Args: + name: Name of the optimizer rule to remove. + + Returns: + True if a rule with the given name was found and removed. + + Examples: + >>> ctx = SessionContext() + >>> ctx.remove_optimizer_rule("nonexistent_rule") + False + """ + return self.ctx.remove_optimizer_rule(name) + + def table_provider(self, name: str) -> Table: + """Return the :py:class:`~datafusion.catalog.Table` for the given table name. + + Args: + name: Name of the table. + + Returns: + The table provider. + + Raises: + KeyError: If the table is not found. + + Examples: + >>> import pyarrow as pa + >>> ctx = SessionContext() + >>> batch = pa.RecordBatch.from_pydict({"x": [1, 2]}) + >>> ctx.register_record_batches("my_table", [[batch]]) + >>> tbl = ctx.table_provider("my_table") + >>> tbl.schema + x: int64 + """ + from datafusion.catalog import Table # noqa: PLC0415 + + return Table(self.ctx.table_provider(name)) + def read_json( self, path: str | pathlib.Path, @@ -1328,6 +1590,86 @@ def read_avro( self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension) ) + def read_arrow( + self, + path: str | pathlib.Path, + schema: pa.Schema | None = None, + file_extension: str = ".arrow", + file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, + ) -> DataFrame: + """Create a :py:class:`DataFrame` for reading an Arrow IPC data source. + + Args: + path: Path to the Arrow IPC file. + schema: The data source schema. + file_extension: File extension to select. + file_partition_cols: Partition columns. + + Returns: + DataFrame representation of the read Arrow IPC file. 
+ + Examples: + >>> import tempfile, os + >>> ctx = dfn.SessionContext() + >>> table = pa.table({"a": [1, 2, 3]}) + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... path = os.path.join(tmpdir, "data.arrow") + ... with pa.ipc.new_file(path, table.schema) as writer: + ... writer.write_table(table) + ... df = ctx.read_arrow(path) + ... df.collect()[0].column(0) + + [ + 1, + 2, + 3 + ] + + Provide an explicit ``schema`` to override schema inference: + + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... path = os.path.join(tmpdir, "data.arrow") + ... with pa.ipc.new_file(path, table.schema) as writer: + ... writer.write_table(table) + ... df = ctx.read_arrow(path, schema=pa.schema([("a", pa.int64())])) + ... df.collect()[0].column(0) + + [ + 1, + 2, + 3 + ] + + Use ``file_extension`` to read files with a non-default extension: + + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... path = os.path.join(tmpdir, "data.ipc") + ... with pa.ipc.new_file(path, table.schema) as writer: + ... writer.write_table(table) + ... df = ctx.read_arrow(path, file_extension=".ipc") + ... df.collect()[0].column(0) + + [ + 1, + 2, + 3 + ] + """ + if file_partition_cols is None: + file_partition_cols = [] + file_partition_cols = _convert_table_partition_cols(file_partition_cols) + return DataFrame( + self.ctx.read_arrow(str(path), schema, file_extension, file_partition_cols) + ) + + def read_empty(self) -> DataFrame: + """Create an empty :py:class:`DataFrame` with no columns or rows. + + See Also: + This is an alias for :meth:`empty_table`. + """ + return self.empty_table() + def read_table( self, table: Table | TableProviderExportable | DataFrame | pa.dataset.Dataset ) -> DataFrame: diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 10e2a913f..2b07861da 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -14,9 +14,32 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""":py:class:`DataFrame` is one of the core concepts in DataFusion. - -See :ref:`user_guide_concepts` in the online documentation for more information. +""":py:class:`DataFrame` — lazy, chainable query representation. + +A :py:class:`DataFrame` is a logical plan over one or more data sources. +Methods that reshape the plan (:py:meth:`DataFrame.select`, +:py:meth:`DataFrame.filter`, :py:meth:`DataFrame.aggregate`, +:py:meth:`DataFrame.sort`, :py:meth:`DataFrame.join`, +:py:meth:`DataFrame.limit`, the set-operation methods, ...) return a new +:py:class:`DataFrame` and do no work until a terminal method such as +:py:meth:`DataFrame.collect`, :py:meth:`DataFrame.to_pydict`, +:py:meth:`DataFrame.show`, or one of the ``write_*`` methods is called. + +DataFrames are produced from a +:py:class:`~datafusion.context.SessionContext`, typically via +:py:meth:`~datafusion.context.SessionContext.sql`, +:py:meth:`~datafusion.context.SessionContext.read_csv`, +:py:meth:`~datafusion.context.SessionContext.read_parquet`, or +:py:meth:`~datafusion.context.SessionContext.from_pydict`. + +Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3], "b": [10, 20, 30]}) + >>> df.filter(col("a") > 1).select("b").to_pydict() + {'b': [20, 30]} + +See :ref:`user_guide_concepts` in the online documentation for a high-level +overview of the execution model. 
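+
+Because execution is deferred, chaining transformations costs nothing until
+a terminal method runs (a minimal sketch reusing ``df`` from above):
+
+    >>> lazy = df.filter(col("a") > 1)  # builds the plan only
+    >>> lazy.count()  # terminal method executes the plan
+    2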
""" from __future__ import annotations @@ -44,6 +67,7 @@ Expr, SortExpr, SortKey, + _to_raw_expr, ensure_expr, ensure_expr_list, expr_list_to_raw_expr_list, @@ -65,6 +89,25 @@ from enum import Enum +class ExplainFormat(Enum): + """Output format for explain plans. + + Controls how the query plan is rendered in :py:meth:`DataFrame.explain`. + """ + + INDENT = "indent" + """Default indented text format.""" + + TREE = "tree" + """Tree-style visual format with box-drawing characters.""" + + PGJSON = "pgjson" + """PostgreSQL-compatible JSON format for use with visualization tools.""" + + GRAPHVIZ = "graphviz" + """Graphviz DOT format for graph rendering.""" + + # excerpt from deltalake # https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163 class Compression(Enum): @@ -395,16 +438,79 @@ def schema(self) -> pa.Schema: """ return self.df.schema() - @deprecated( - "select_columns() is deprecated. Use :py:meth:`~DataFrame.select` instead" - ) - def select_columns(self, *args: str) -> DataFrame: - """Filter the DataFrame by columns. + def column(self, name: str) -> Expr: + """Return a fully qualified column expression for ``name``. + + Resolves an unqualified column name against this DataFrame's schema + and returns an :py:class:`Expr` whose underlying column reference + includes the table qualifier. This is especially useful after joins, + where the same column name may appear in multiple relations. + + Args: + name: Unqualified column name to look up. Returns: - DataFrame only containing the specified columns. + A fully qualified column expression. + + Raises: + Exception: If the column is not found or is ambiguous (exists in + multiple relations). + + Examples: + Resolve a column from a simple DataFrame: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2], "b": [3, 4]}) + >>> expr = df.column("a") + >>> df.select(expr).to_pydict() + {'a': [1, 2]} + + Resolve qualified columns after a join: + + >>> left = ctx.from_pydict({"id": [1, 2], "x": [10, 20]}) + >>> right = ctx.from_pydict({"id": [1, 2], "y": [30, 40]}) + >>> joined = left.join(right, on="id", how="inner") + >>> expr = joined.column("y") + >>> joined.select("id", expr).sort("id").to_pydict() + {'id': [1, 2], 'y': [30, 40]} """ - return self.select(*args) + return self.find_qualified_columns(name)[0] + + def col(self, name: str) -> Expr: + """Alias for :py:meth:`column`. + + See Also: + :py:meth:`column` + """ + return self.column(name) + + def find_qualified_columns(self, *names: str) -> list[Expr]: + """Return fully qualified column expressions for the given names. + + This is a batch version of :py:meth:`column` — it resolves each + unqualified name against the DataFrame's schema and returns a list + of qualified column expressions. + + Args: + names: Unqualified column names to look up. + + Returns: + List of fully qualified column expressions, one per name. + + Raises: + Exception: If any column is not found or is ambiguous. + + Examples: + Resolve multiple columns at once: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + >>> exprs = df.find_qualified_columns("a", "c") + >>> df.select(*exprs).to_pydict() + {'a': [1, 2], 'c': [5, 6]} + """ + raw_exprs = self.df.find_qualified_columns(list(names)) + return [Expr(e) for e in raw_exprs] def select_exprs(self, *args: str) -> DataFrame: """Project arbitrary list of expression strings into a new DataFrame. 
@@ -420,21 +526,29 @@ def select_exprs(self, *args: str) -> DataFrame: def select(self, *exprs: Expr | str) -> DataFrame: """Project arbitrary expressions into a new :py:class:`DataFrame`. + String arguments are treated as column names; :py:class:`~datafusion.expr.Expr` + arguments can reshape, rename, or compute new columns. + Args: exprs: Either column names or :py:class:`~datafusion.expr.Expr` to select. Returns: DataFrame after projection. It has one column for each expression. - Example usage: + Examples: + Select columns by name: - The following example will return 3 columns from the original dataframe. - The first two columns will be the original column ``a`` and ``b`` since the - string "a" is assumed to refer to column selection. Also a duplicate of - column ``a`` will be returned with the column name ``alternate_a``:: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3], "b": [10, 20, 30]}) + >>> df.select("a").to_pydict() + {'a': [1, 2, 3]} - df = df.select("a", col("b"), col("a").alias("alternate_a")) + Mix column names, expressions, and aliases. The string ``"a"`` selects + column ``a`` directly; ``col("a").alias("alternate_a")`` returns a + duplicate under a new name: + >>> df.select("a", col("b"), col("a").alias("alternate_a")).to_pydict() + {'a': [1, 2, 3], 'b': [10, 20, 30], 'alternate_a': [1, 2, 3]} """ exprs_internal = expr_list_to_raw_expr_list(exprs) return DataFrame(self.df.select(*exprs_internal)) @@ -468,6 +582,36 @@ def drop(self, *columns: str) -> DataFrame: """ return DataFrame(self.df.drop(*columns)) + def window(self, *exprs: Expr) -> DataFrame: + """Add window function columns to the DataFrame. + + Applies the given window function expressions and appends the results + as new columns. + + Args: + exprs: Window function expressions to evaluate. + + Returns: + DataFrame with new window function columns appended. + + Examples: + Add a row number within each group: + + >>> import datafusion.functions as f + >>> from datafusion import col + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3], "b": ["x", "x", "y"]}) + >>> df = df.window( + ... f.row_number( + ... partition_by=[col("b")], order_by=[col("a")] + ... ).alias("rn") + ... ) + >>> "rn" in df.schema().names + True + """ + raw = expr_list_to_raw_expr_list(exprs) + return DataFrame(self.df.window(*raw)) + def filter(self, *predicates: Expr | str) -> DataFrame: """Return a DataFrame for which ``predicate`` evaluates to ``True``. @@ -633,12 +777,44 @@ def aggregate( ) -> DataFrame: """Aggregates the rows of the current DataFrame. + By default each unique combination of the ``group_by`` columns + produces one row. To get multiple levels of subtotals in a + single pass, pass a + :py:class:`~datafusion.expr.GroupingSet` expression + (created via + :py:meth:`~datafusion.expr.GroupingSet.rollup`, + :py:meth:`~datafusion.expr.GroupingSet.cube`, or + :py:meth:`~datafusion.expr.GroupingSet.grouping_sets`) + as the ``group_by`` argument. See the + :ref:`aggregation` user guide for detailed examples. + Args: - group_by: Sequence of expressions or column names to group by. + group_by: Sequence of expressions or column names to group + by. A :py:class:`~datafusion.expr.GroupingSet` + expression may be included to produce multiple grouping + levels (rollup, cube, or explicit grouping sets). aggs: Sequence of expressions to aggregate. Returns: DataFrame after aggregation. 
+
+        Examples:
+            Aggregate without grouping — an empty ``group_by`` produces a
+            single row:
+
+            >>> ctx = dfn.SessionContext()
+            >>> df = ctx.from_pydict(
+            ...     {"team": ["x", "x", "y"], "score": [1, 2, 5]}
+            ... )
+            >>> df.aggregate([], [F.sum(col("score")).alias("total")]).to_pydict()
+            {'total': [8]}
+
+            Group by a column and produce one row per group:
+
+            >>> df.aggregate(
+            ...     ["team"], [F.sum(col("score")).alias("total")]
+            ... ).sort("team").to_pydict()
+            {'team': ['x', 'y'], 'total': [3, 5]}
         """
         group_by_list = (
             list(group_by)
@@ -659,13 +835,27 @@
     def sort(self, *exprs: SortKey) -> DataFrame:
         """Sort the DataFrame by the specified sorting expressions or column names.

         Note that any expression can be turned into a sort expression by
-        calling its ``sort`` method.
+        calling its ``sort`` method. For ascending-only sorts, the shorter
+        :py:meth:`sort_by` is usually more convenient.

         Args:
             exprs: Sort expressions or column names, applied in order.

         Returns:
             DataFrame after sorting.
+
+        Examples:
+            Sort ascending by a column name:
+
+            >>> ctx = dfn.SessionContext()
+            >>> df = ctx.from_pydict({"a": [3, 1, 2], "b": [10, 20, 30]})
+            >>> df.sort("a").to_pydict()
+            {'a': [1, 2, 3], 'b': [20, 30, 10]}
+
+            Sort descending using :py:meth:`Expr.sort`:
+
+            >>> df.sort(col("a").sort(ascending=False)).to_pydict()
+            {'a': [3, 2, 1], 'b': [10, 30, 20]}
         """
         exprs_raw = sort_list_to_raw_sort_list(exprs)
         return DataFrame(self.df.sort(*exprs_raw))
@@ -685,12 +875,28 @@ def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame:
     def limit(self, count: int, offset: int = 0) -> DataFrame:
         """Return a new :py:class:`DataFrame` with a limited number of rows.

+        Results are returned in unspecified order unless the DataFrame is
+        explicitly sorted first via :py:meth:`sort` or :py:meth:`sort_by`.
+
         Args:
             count: Number of rows to limit the DataFrame to.
             offset: Number of rows to skip.

         Returns:
             DataFrame after limiting.
+
+        Examples:
+            Take the first two rows:
+
+            >>> ctx = dfn.SessionContext()
+            >>> df = ctx.from_pydict({"a": [1, 2, 3, 4]}).sort("a")
+            >>> df.limit(2).to_pydict()
+            {'a': [1, 2]}
+
+            Skip the first row then take two (paging):
+
+            >>> df.limit(2, offset=1).to_pydict()
+            {'a': [2, 3]}
         """
         return DataFrame(self.df.limit(count, offset))

@@ -823,7 +1029,13 @@ def join(
     ) -> DataFrame:
         """Join this :py:class:`DataFrame` with another :py:class:`DataFrame`.

-        `on` has to be provided or both `left_on` and `right_on` in conjunction.
+        Provide either ``on``, or both ``left_on`` and ``right_on`` together.
+
+        When non-key columns share the same name in both DataFrames, use
+        :py:meth:`DataFrame.col` on each DataFrame **before** the join to
+        obtain fully qualified column references that can disambiguate them.
+        See :py:meth:`join_on` for an example.

         Args:
             right: Other DataFrame to join with.
@@ -839,6 +1051,28 @@
         Returns:
             DataFrame after join.
+ + Examples: + Inner-join two DataFrames on a shared column: + + >>> ctx = dfn.SessionContext() + >>> left = ctx.from_pydict({"id": [1, 2, 3], "val": [10, 20, 30]}) + >>> right = ctx.from_pydict({"id": [2, 3, 4], "label": ["b", "c", "d"]}) + >>> left.join(right, on="id").sort("id").to_pydict() + {'id': [2, 3], 'val': [20, 30], 'label': ['b', 'c']} + + Left join to keep all rows from the left side: + + >>> left.join(right, on="id", how="left").sort("id").to_pydict() + {'id': [1, 2, 3], 'val': [10, 20, 30], 'label': [None, 'b', 'c']} + + Use ``left_on`` / ``right_on`` when the key columns differ in name: + + >>> right2 = ctx.from_pydict({"rid": [2, 3], "label": ["b", "c"]}) + >>> left.join( + ... right2, left_on="id", right_on="rid" + ... ).sort("id").to_pydict() + {'id': [2, 3], 'val': [20, 30], 'rid': [2, 3], 'label': ['b', 'c']} """ if join_keys is not None: warnings.warn( @@ -897,7 +1131,14 @@ def join_on( built with :func:`datafusion.col`. On expressions are used to support in-equality predicates. Equality predicates are correctly optimized. + Use :py:meth:`DataFrame.col` on each DataFrame **before** the join to + obtain fully qualified column references. These qualified references + can then be used in the join predicate and to disambiguate columns + with the same name when selecting from the result. + Examples: + Join with unique column names: + >>> ctx = dfn.SessionContext() >>> left = ctx.from_pydict({"a": [1, 2], "x": ["a", "b"]}) >>> right = ctx.from_pydict({"b": [1, 2], "y": ["c", "d"]}) @@ -906,6 +1147,18 @@ def join_on( ... ).sort(col("x")).to_pydict() {'a': [1, 2], 'x': ['a', 'b'], 'b': [1, 2], 'y': ['c', 'd']} + Use :py:meth:`col` to disambiguate shared column names: + + >>> left = ctx.from_pydict({"id": [1, 2], "val": [10, 20]}) + >>> right = ctx.from_pydict({"id": [1, 2], "val": [30, 40]}) + >>> joined = left.join_on( + ... right, left.col("id") == right.col("id"), how="inner" + ... ) + >>> joined.select( + ... left.col("id"), left.col("val"), right.col("val").alias("rval") + ... ).sort(left.col("id")).to_pydict() + {'id': [1, 2], 'val': [10, 20], 'rval': [30, 40]} + Args: right: Other DataFrame to join with. on_exprs: single or multiple (in)-equality predicates. @@ -918,7 +1171,12 @@ def join_on( exprs = [ensure_expr(expr) for expr in on_exprs] return DataFrame(self.df.join_on(right.df, exprs, how)) - def explain(self, verbose: bool = False, analyze: bool = False) -> None: + def explain( + self, + verbose: bool = False, + analyze: bool = False, + format: ExplainFormat | None = None, + ) -> None: """Print an explanation of the DataFrame's plan so far. If ``analyze`` is specified, runs the plan and reports metrics. @@ -926,8 +1184,23 @@ def explain(self, verbose: bool = False, analyze: bool = False) -> None: Args: verbose: If ``True``, more details will be included. analyze: If ``True``, the plan will run and metrics reported. + format: Output format for the plan. Defaults to + :py:attr:`ExplainFormat.INDENT`. + + Examples: + Show the plan in tree format: + + >>> from datafusion import ExplainFormat + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3]}) + >>> df.explain(format=ExplainFormat.TREE) # doctest: +SKIP + + Show plan with runtime metrics: + + >>> df.explain(analyze=True) # doctest: +SKIP """ - self.df.explain(verbose, analyze) + fmt = format.value if format is not None else None + self.df.explain(verbose, analyze, fmt) def logical_plan(self) -> LogicalPlan: """Return the unoptimized ``LogicalPlan``. 
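# The plan accessors pair naturally with ``explain`` when a plan needs to be
# inspected programmatically instead of printed. A sketch (output shape
# depends on the plan, so it is skipped here):
#
#   >>> plan = df.logical_plan()
#   >>> print(plan.display_indent())  # doctest: +SKIP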
@@ -993,48 +1266,187 @@ def union(self, other: DataFrame, distinct: bool = False) -> DataFrame: Returns: DataFrame after union. + + Examples: + Stack rows from both DataFrames, preserving duplicates: + + >>> ctx = dfn.SessionContext() + >>> df1 = ctx.from_pydict({"a": [1, 2]}) + >>> df2 = ctx.from_pydict({"a": [2, 3]}) + >>> df1.union(df2).sort("a").to_pydict() + {'a': [1, 2, 2, 3]} + + Deduplicate the combined result with ``distinct=True``: + + >>> df1.union(df2, distinct=True).sort("a").to_pydict() + {'a': [1, 2, 3]} """ return DataFrame(self.df.union(other.df, distinct)) + @deprecated( + "union_distinct() is deprecated. Use union(other, distinct=True) instead." + ) def union_distinct(self, other: DataFrame) -> DataFrame: """Calculate the distinct union of two :py:class:`DataFrame`. + See Also: + :py:meth:`union` + """ + return self.union(other, distinct=True) + + def intersect(self, other: DataFrame, distinct: bool = False) -> DataFrame: + """Calculate the intersection of two :py:class:`DataFrame`. + The two :py:class:`DataFrame` must have exactly the same schema. - Any duplicate rows are discarded. Args: - other: DataFrame to union with. + other: DataFrame to intersect with. + distinct: If ``True``, duplicate rows are removed from the result. Returns: - DataFrame after union. + DataFrame after intersection. + + Examples: + Find rows common to both DataFrames: + + >>> ctx = dfn.SessionContext() + >>> df1 = ctx.from_pydict({"a": [1, 2, 3], "b": [10, 20, 30]}) + >>> df2 = ctx.from_pydict({"a": [1, 4], "b": [10, 40]}) + >>> df1.intersect(df2).to_pydict() + {'a': [1], 'b': [10]} + + Intersect with deduplication: + + >>> df1 = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 10, 20]}) + >>> df2 = ctx.from_pydict({"a": [1, 1], "b": [10, 10]}) + >>> df1.intersect(df2, distinct=True).to_pydict() + {'a': [1], 'b': [10]} """ - return DataFrame(self.df.union_distinct(other.df)) + return DataFrame(self.df.intersect(other.df, distinct)) - def intersect(self, other: DataFrame) -> DataFrame: - """Calculate the intersection of two :py:class:`DataFrame`. + def except_all(self, other: DataFrame, distinct: bool = False) -> DataFrame: + """Calculate the set difference of two :py:class:`DataFrame`. + + Returns rows that are in this DataFrame but not in ``other``. The two :py:class:`DataFrame` must have exactly the same schema. Args: - other: DataFrame to intersect with. + other: DataFrame to calculate exception with. + distinct: If ``True``, duplicate rows are removed from the result. Returns: - DataFrame after intersection. + DataFrame after set difference. + + Examples: + Remove rows present in ``df2``: + + >>> ctx = dfn.SessionContext() + >>> df1 = ctx.from_pydict({"a": [1, 2, 3], "b": [10, 20, 30]}) + >>> df2 = ctx.from_pydict({"a": [1, 2], "b": [10, 20]}) + >>> df1.except_all(df2).sort("a").to_pydict() + {'a': [3], 'b': [30]} + + Remove rows present in ``df2`` and deduplicate: + + >>> df1.except_all(df2, distinct=True).sort("a").to_pydict() + {'a': [3], 'b': [30]} """ - return DataFrame(self.df.intersect(other.df)) + return DataFrame(self.df.except_all(other.df, distinct)) - def except_all(self, other: DataFrame) -> DataFrame: - """Calculate the exception of two :py:class:`DataFrame`. + def union_by_name(self, other: DataFrame, distinct: bool = False) -> DataFrame: + """Union two :py:class:`DataFrame` matching columns by name. - The two :py:class:`DataFrame` must have exactly the same schema. 
+ Unlike :py:meth:`union` which matches columns by position, this method + matches columns by their names, allowing DataFrames with different + column orders to be combined. Args: - other: DataFrame to calculate exception with. + other: DataFrame to union with. + distinct: If ``True``, duplicate rows are removed from the result. Returns: - DataFrame after exception. + DataFrame after union by name. + + Examples: + Combine DataFrames with different column orders: + + >>> ctx = dfn.SessionContext() + >>> df1 = ctx.from_pydict({"a": [1], "b": [10]}) + >>> df2 = ctx.from_pydict({"b": [20], "a": [2]}) + >>> df1.union_by_name(df2).sort("a").to_pydict() + {'a': [1, 2], 'b': [10, 20]} + + Union by name with deduplication: + + >>> df1 = ctx.from_pydict({"a": [1, 1], "b": [10, 10]}) + >>> df2 = ctx.from_pydict({"b": [10], "a": [1]}) + >>> df1.union_by_name(df2, distinct=True).to_pydict() + {'a': [1], 'b': [10]} + """ + return DataFrame(self.df.union_by_name(other.df, distinct)) + + def distinct_on( + self, + on_expr: list[Expr], + select_expr: list[Expr], + sort_expr: list[SortKey] | None = None, + ) -> DataFrame: + """Deduplicate rows based on specific columns. + + Returns a new DataFrame with one row per unique combination of the + ``on_expr`` columns, keeping the first row per group as determined by + ``sort_expr``. + + Args: + on_expr: Expressions that determine uniqueness. + select_expr: Expressions to include in the output. + sort_expr: Optional sort expressions to determine which row to keep. + + Returns: + DataFrame after deduplication. + + Examples: + Keep the row with the smallest ``b`` for each unique ``a``: + + >>> from datafusion import col + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 1, 2, 2], "b": [10, 20, 30, 40]}) + >>> df.distinct_on( + ... [col("a")], + ... [col("a"), col("b")], + ... [col("a").sort(ascending=True), col("b").sort(ascending=True)], + ... ).sort("a").to_pydict() + {'a': [1, 2], 'b': [10, 30]} + """ + on_raw = expr_list_to_raw_expr_list(on_expr) + select_raw = expr_list_to_raw_expr_list(select_expr) + sort_raw = sort_list_to_raw_sort_list(sort_expr) if sort_expr else None + return DataFrame(self.df.distinct_on(on_raw, select_raw, sort_raw)) + + def sort_by(self, *exprs: Expr | str) -> DataFrame: + """Sort the DataFrame by column expressions in ascending order. + + This is a convenience method that sorts the DataFrame by the given + expressions in ascending order with nulls last. For more control over + sort direction and null ordering, use :py:meth:`sort` instead. + + Args: + exprs: Expressions or column names to sort by. + + Returns: + DataFrame after sorting. 
+ + Examples: + Sort by a single column: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [3, 1, 2]}) + >>> df.sort_by("a").to_pydict() + {'a': [1, 2, 3]} """ - return DataFrame(self.df.except_all(other.df)) + raw = [_to_raw_expr(e) for e in exprs] + return DataFrame(self.df.sort_by(raw)) def write_csv( self, @@ -1295,24 +1707,44 @@ def count(self) -> int: """ return self.df.count() - @deprecated("Use :py:func:`unnest_columns` instead.") - def unnest_column(self, column: str, preserve_nulls: bool = True) -> DataFrame: - """See :py:func:`unnest_columns`.""" - return DataFrame(self.df.unnest_column(column, preserve_nulls=preserve_nulls)) - - def unnest_columns(self, *columns: str, preserve_nulls: bool = True) -> DataFrame: + def unnest_columns( + self, + *columns: str, + preserve_nulls: bool = True, + recursions: list[tuple[str, str, int]] | None = None, + ) -> DataFrame: """Expand columns of arrays into a single row per array element. Args: columns: Column names to perform unnest operation on. preserve_nulls: If False, rows with null entries will not be returned. + recursions: Optional list of ``(input_column, output_column, depth)`` + tuples that control how deeply nested columns are unnested. Any + column not mentioned here is unnested with depth 1. Returns: A DataFrame with the columns expanded. + + Examples: + Unnest an array column: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [[1, 2], [3]], "b": ["x", "y"]}) + >>> df.unnest_columns("a").to_pydict() + {'a': [1, 2, 3], 'b': ['x', 'x', 'y']} + + With explicit recursion depth: + + >>> df.unnest_columns("a", recursions=[("a", "a", 1)]).to_pydict() + {'a': [1, 2, 3], 'b': ['x', 'x', 'y']} """ columns = list(columns) - return DataFrame(self.df.unnest_columns(columns, preserve_nulls=preserve_nulls)) + return DataFrame( + self.df.unnest_columns( + columns, preserve_nulls=preserve_nulls, recursions=recursions + ) + ) def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: """Export the DataFrame as an Arrow C Stream. diff --git a/python/datafusion/dataframe_formatter.py b/python/datafusion/dataframe_formatter.py index b8af45a1b..fd2da99f0 100644 --- a/python/datafusion/dataframe_formatter.py +++ b/python/datafusion/dataframe_formatter.py @@ -748,7 +748,7 @@ def get_formatter() -> DataFrameHtmlFormatter: The global HTML formatter instance Example: - >>> from datafusion.html_formatter import get_formatter + >>> from datafusion.dataframe_formatter import get_formatter >>> formatter = get_formatter() >>> formatter.max_cell_length = 50 # Increase cell length """ @@ -762,7 +762,7 @@ def set_formatter(formatter: DataFrameHtmlFormatter) -> None: formatter: The formatter instance to use globally Example: - >>> from datafusion.html_formatter import get_formatter, set_formatter + >>> from datafusion.dataframe_formatter import get_formatter, set_formatter >>> custom_formatter = DataFrameHtmlFormatter(max_cell_length=100) >>> set_formatter(custom_formatter) """ @@ -783,7 +783,7 @@ def configure_formatter(**kwargs: Any) -> None: ValueError: If any invalid parameters are provided Example: - >>> from datafusion.html_formatter import configure_formatter + >>> from datafusion.dataframe_formatter import configure_formatter >>> configure_formatter( ... max_cell_length=50, ... max_height=500, @@ -827,7 +827,7 @@ def reset_formatter() -> None: and sets it as the global formatter for all DataFrames. 
Example: - >>> from datafusion.html_formatter import reset_formatter + >>> from datafusion.dataframe_formatter import reset_formatter >>> reset_formatter() # Reset formatter to default settings """ formatter = DataFrameHtmlFormatter() diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 14753a4f5..0f7f3ab5a 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -15,9 +15,31 @@ # specific language governing permissions and limitations # under the License. -"""This module supports expressions, one of the core concepts in DataFusion. - -See :ref:`Expressions` in the online documentation for more details. +""":py:class:`Expr` — the logical expression type used to build DataFusion queries. + +An :py:class:`Expr` represents a computation over columns or literals: a +column reference (``col("a")``), a literal (``lit(5)``), an operator +combination (``col("a") + lit(1)``), or the output of a function from +:py:mod:`datafusion.functions`. Expressions are passed to +:py:class:`~datafusion.dataframe.DataFrame` methods such as +:py:meth:`~datafusion.dataframe.DataFrame.select`, +:py:meth:`~datafusion.dataframe.DataFrame.filter`, +:py:meth:`~datafusion.dataframe.DataFrame.aggregate`, and +:py:meth:`~datafusion.dataframe.DataFrame.sort`. + +Convenience constructors are re-exported at the package level: +:py:func:`datafusion.col` / :py:func:`datafusion.column` for column references +and :py:func:`datafusion.lit` / :py:func:`datafusion.literal` for scalar +literals. + +Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3]}) + >>> df.select((col("a") * lit(10)).alias("ten_a")).to_pydict() + {'ten_a': [10, 20, 30]} + +See :ref:`expressions` in the online documentation for details on available +operators and helpers. """ # ruff: noqa: PLC0415 @@ -27,11 +49,6 @@ from collections.abc import Iterable, Sequence from typing import TYPE_CHECKING, Any, ClassVar -try: - from warnings import deprecated # Python 3.13+ -except ImportError: - from typing_extensions import deprecated # Python 3.12 - import pyarrow as pa from ._internal import expr as expr_internal @@ -91,7 +108,6 @@ Extension = expr_internal.Extension FileType = expr_internal.FileType Filter = expr_internal.Filter -GroupingSet = expr_internal.GroupingSet Join = expr_internal.Join ILike = expr_internal.ILike InList = expr_internal.InList @@ -227,6 +243,8 @@ "WindowExpr", "WindowFrame", "WindowFrameBound", + "coerce_to_expr", + "coerce_to_expr_or_none", "ensure_expr", "ensure_expr_list", ] @@ -239,6 +257,10 @@ def ensure_expr(value: Expr | Any) -> expr_internal.Expr: higher level APIs consistently require explicit :func:`~datafusion.col` or :func:`~datafusion.lit` expressions. + See Also: + :func:`coerce_to_expr` — the opposite behavior: *wraps* non-``Expr`` + values as literals instead of rejecting them. + Args: value: Candidate expression or other object. @@ -283,6 +305,41 @@ def _iter( return list(_iter(exprs)) +def coerce_to_expr(value: Any) -> Expr: + """Coerce a native Python value to an ``Expr`` literal, passing ``Expr`` through. + + This is the complement of :func:`ensure_expr`: where ``ensure_expr`` + *rejects* non-``Expr`` values, ``coerce_to_expr`` *wraps* them via + :meth:`Expr.literal` so that functions can accept native Python types + (``int``, ``float``, ``str``, ``bool``, etc.) alongside ``Expr``. + + Args: + value: An ``Expr`` instance (returned as-is) or a Python literal to wrap. + + Returns: + An ``Expr`` representing the value. 
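+
+    Examples:
+        A quick sketch of both behaviors, wrapping a native value and
+        passing an ``Expr`` through unchanged (the literal repr shown is
+        illustrative):
+
+        >>> from datafusion import col
+        >>> from datafusion.expr import coerce_to_expr
+        >>> coerce_to_expr(5)
+        Expr(Int64(5))
+        >>> e = col("a")
+        >>> coerce_to_expr(e) is e
+        True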
+ """ + if isinstance(value, Expr): + return value + return Expr.literal(value) + + +def coerce_to_expr_or_none(value: Any | None) -> Expr | None: + """Coerce a value to ``Expr`` or pass ``None`` through unchanged. + + Same as :func:`coerce_to_expr` but accepts ``None`` for optional parameters. + + Args: + value: An ``Expr`` instance, a Python literal to wrap, or ``None``. + + Returns: + An ``Expr`` representing the value, or ``None``. + """ + if value is None: + return None + return coerce_to_expr(value) + + def _to_raw_expr(value: Expr | str) -> expr_internal.Expr: """Convert a Python expression or column name to its raw variant. @@ -357,16 +414,6 @@ def to_variant(self) -> Any: """Convert this expression into a python object if possible.""" return self.expr.to_variant() - @deprecated( - "display_name() is deprecated. Use :py:meth:`~Expr.schema_name` instead" - ) - def display_name(self) -> str: - """Returns the name of this expression as it should appear in a schema. - - This name will not include any CAST expressions. - """ - return self.schema_name() - def schema_name(self) -> str: """Returns the name of this expression as it should appear in a schema. @@ -499,6 +546,8 @@ def __eq__(self, rhs: object) -> Expr: Accepts either an expression or any valid PyArrow scalar literal value. """ + if rhs is None: + return self.is_null() if not isinstance(rhs, Expr): rhs = Expr.literal(rhs) return Expr(self.expr.__eq__(rhs.expr)) @@ -508,6 +557,8 @@ def __ne__(self, rhs: object) -> Expr: Accepts either an expression or any valid PyArrow scalar literal value. """ + if rhs is None: + return self.is_not_null() if not isinstance(rhs, Expr): rhs = Expr.literal(rhs) return Expr(self.expr.__ne__(rhs.expr)) @@ -1430,3 +1481,129 @@ def __repr__(self) -> str: SortKey = Expr | SortExpr | str + + +class GroupingSet: + """Factory for creating grouping set expressions. + + Grouping sets control how + :py:meth:`~datafusion.dataframe.DataFrame.aggregate` groups rows. + Instead of a single ``GROUP BY``, they produce multiple grouping + levels in one pass — subtotals, cross-tabulations, or arbitrary + column subsets. + + Use :py:func:`~datafusion.functions.grouping` in the aggregate list + to tell which columns are aggregated across in each result row. + """ + + @staticmethod + def rollup(*exprs: Expr | str) -> Expr: + """Create a ``ROLLUP`` grouping set for use with ``aggregate()``. + + ``ROLLUP`` generates all prefixes of the given column list as + grouping sets. For example, ``rollup(a, b)`` produces grouping + sets ``(a, b)``, ``(a)``, and ``()`` (grand total). + + This is equivalent to ``GROUP BY ROLLUP(a, b)`` in SQL. + + Args: + *exprs: Column expressions or column name strings to + include in the rollup. + + Examples: + >>> from datafusion.expr import GroupingSet + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]}) + >>> result = df.aggregate( + ... [GroupingSet.rollup(dfn.col("a"))], + ... [dfn.functions.sum(dfn.col("b")).alias("s"), + ... dfn.functions.grouping(dfn.col("a"))], + ... ).sort(dfn.col("a").sort(nulls_first=False)) + >>> result.collect_column("s").to_pylist() + [30, 30, 60] + + See Also: + :py:meth:`cube`, :py:meth:`grouping_sets`, + :py:func:`~datafusion.functions.grouping` + """ + args = [_to_raw_expr(e) for e in exprs] + return Expr(expr_internal.GroupingSet.rollup(*args)) + + @staticmethod + def cube(*exprs: Expr | str) -> Expr: + """Create a ``CUBE`` grouping set for use with ``aggregate()``. 
+ + ``CUBE`` generates all possible subsets of the given column list + as grouping sets. For example, ``cube(a, b)`` produces grouping + sets ``(a, b)``, ``(a)``, ``(b)``, and ``()`` (grand total). + + This is equivalent to ``GROUP BY CUBE(a, b)`` in SQL. + + Args: + *exprs: Column expressions or column name strings to + include in the cube. + + Examples: + With a single column, ``cube`` behaves identically to + :py:meth:`rollup`: + + >>> from datafusion.expr import GroupingSet + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]}) + >>> result = df.aggregate( + ... [GroupingSet.cube(dfn.col("a"))], + ... [dfn.functions.sum(dfn.col("b")).alias("s"), + ... dfn.functions.grouping(dfn.col("a"))], + ... ).sort(dfn.col("a").sort(nulls_first=False)) + >>> result.collect_column("s").to_pylist() + [30, 30, 60] + + See Also: + :py:meth:`rollup`, :py:meth:`grouping_sets`, + :py:func:`~datafusion.functions.grouping` + """ + args = [_to_raw_expr(e) for e in exprs] + return Expr(expr_internal.GroupingSet.cube(*args)) + + @staticmethod + def grouping_sets(*expr_lists: list[Expr | str]) -> Expr: + """Create explicit grouping sets for use with ``aggregate()``. + + Each argument is a list of column expressions or column name + strings representing one grouping set. For example, + ``grouping_sets([a], [b])`` groups by ``a`` alone and by ``b`` + alone in a single query. + + This is equivalent to ``GROUP BY GROUPING SETS ((a), (b))`` in + SQL. + + Args: + *expr_lists: Each positional argument is a list of + expressions or column name strings forming one + grouping set. + + Examples: + >>> from datafusion.expr import GroupingSet + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict( + ... {"a": ["x", "x", "y"], "b": ["m", "n", "m"], + ... "c": [1, 2, 3]}) + >>> result = df.aggregate( + ... [GroupingSet.grouping_sets( + ... [dfn.col("a")], [dfn.col("b")])], + ... [dfn.functions.sum(dfn.col("c")).alias("s"), + ... dfn.functions.grouping(dfn.col("a")), + ... dfn.functions.grouping(dfn.col("b"))], + ... ).sort( + ... dfn.col("a").sort(nulls_first=False), + ... dfn.col("b").sort(nulls_first=False), + ... ) + >>> result.collect_column("s").to_pylist() + [3, 3, 4, 2] + + See Also: + :py:meth:`rollup`, :py:meth:`cube`, + :py:func:`~datafusion.functions.grouping` + """ + raw_lists = [[_to_raw_expr(e) for e in lst] for lst in expr_lists] + return Expr(expr_internal.GroupingSet.grouping_sets(*raw_lists)) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index f062cbfce..08062851a 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -14,11 +14,31 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""User functions for operating on :py:class:`~datafusion.expr.Expr`.""" +"""Scalar, aggregate, and window functions for :py:class:`~datafusion.expr.Expr`. + +Each function returns an :py:class:`~datafusion.expr.Expr` that can be combined +with other expressions and passed to +:py:class:`~datafusion.dataframe.DataFrame` methods such as +:py:meth:`~datafusion.dataframe.DataFrame.select`, +:py:meth:`~datafusion.dataframe.DataFrame.filter`, +:py:meth:`~datafusion.dataframe.DataFrame.aggregate`, and +:py:meth:`~datafusion.dataframe.DataFrame.window`. The module is conventionally +imported as ``F`` so calls read like ``F.sum(col("price"))``. 
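To round out the `GroupingSet` additions just above, a compact end-to-end sketch combining `rollup` with the new `grouping()` marker (illustrative, not part of the patch; the expected output follows the rollup docstring and standard SQL grouping semantics):

```python
import datafusion as dfn
from datafusion import col, functions as F
from datafusion.expr import GroupingSet

ctx = dfn.SessionContext()
df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]})

result = df.aggregate(
    [GroupingSet.rollup(col("a"))],
    [F.sum(col("b")).alias("s"), F.grouping(col("a")).alias("g")],
).sort(col("a").sort(nulls_first=False))

# g is 0 for the per-key rows (a=1, a=2) and 1 for the grand-total row
# where `a` was aggregated away.
print(result.to_pydict())
# Expected: {'a': [1, 2, None], 's': [30, 30, 60], 'g': [0, 0, 1]}
```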
+ +Examples: + >>> from datafusion import functions as F + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3, 4]}) + >>> df.aggregate([], [F.sum(col("a")).alias("total")]).to_pydict() + {'total': [10]} + +See :ref:`aggregation` and :ref:`window_functions` in the online documentation +for categorized catalogs of aggregate and window functions. +""" from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import Any import pyarrow as pa @@ -29,19 +49,13 @@ Expr, SortExpr, SortKey, - WindowFrame, + coerce_to_expr, + coerce_to_expr_or_none, expr_list_to_raw_expr_list, sort_list_to_raw_sort_list, sort_or_default, ) -try: - from warnings import deprecated # Python 3.13+ -except ImportError: - from typing_extensions import deprecated # Python 3.12 - -if TYPE_CHECKING: - from datafusion.context import SessionContext __all__ = [ "abs", "acos", @@ -53,10 +67,13 @@ "approx_percentile_cont_with_weight", "array", "array_agg", + "array_any_value", "array_append", "array_cat", "array_concat", + "array_contains", "array_dims", + "array_distance", "array_distinct", "array_element", "array_empty", @@ -69,6 +86,8 @@ "array_intersect", "array_join", "array_length", + "array_max", + "array_min", "array_ndims", "array_pop_back", "array_pop_front", @@ -85,11 +104,15 @@ "array_replace_all", "array_replace_n", "array_resize", + "array_reverse", "array_slice", "array_sort", "array_to_string", "array_union", + "arrays_overlap", + "arrays_zip", "arrow_cast", + "arrow_metadata", "arrow_typeof", "ascii", "asin", @@ -116,6 +139,7 @@ "col", "concat", "concat_ws", + "contains", "corr", "cos", "cosh", @@ -128,7 +152,9 @@ "cume_dist", "current_date", "current_time", + "current_timestamp", "date_bin", + "date_format", "date_part", "date_trunc", "datepart", @@ -137,6 +163,7 @@ "degrees", "dense_rank", "digest", + "element_at", "empty", "encode", "ends_with", @@ -149,6 +176,12 @@ "floor", "from_unixtime", "gcd", + "gen_series", + "generate_series", + "get_field", + "greatest", + "grouping", + "ifnull", "in_list", "initcap", "isnan", @@ -157,22 +190,35 @@ "last_value", "lcm", "lead", + "least", "left", "length", "levenshtein", + "list_any_value", "list_append", "list_cat", "list_concat", + "list_contains", "list_dims", + "list_distance", "list_distinct", "list_element", + "list_empty", "list_except", "list_extract", + "list_has", + "list_has_all", + "list_has_any", "list_indexof", "list_intersect", "list_join", "list_length", + "list_max", + "list_min", "list_ndims", + "list_overlap", + "list_pop_back", + "list_pop_front", "list_position", "list_positions", "list_prepend", @@ -186,10 +232,12 @@ "list_replace_all", "list_replace_n", "list_resize", + "list_reverse", "list_slice", "list_sort", "list_to_string", "list_union", + "list_zip", "ln", "log", "log2", @@ -200,6 +248,12 @@ "make_array", "make_date", "make_list", + "make_map", + "make_time", + "map_entries", + "map_extract", + "map_keys", + "map_values", "max", "md5", "mean", @@ -212,13 +266,16 @@ "ntile", "nullif", "nvl", + "nvl2", "octet_length", "order_by", "overlay", "percent_rank", + "percentile_cont", "pi", "pow", "power", + "quantile_cont", "radians", "random", "range", @@ -242,6 +299,7 @@ "reverse", "right", "round", + "row", "row_number", "rpad", "rtrim", @@ -259,6 +317,8 @@ "stddev_pop", "stddev_samp", "string_agg", + "string_to_array", + "string_to_list", "strpos", "struct", "substr", @@ -282,15 +342,17 @@ "translate", "trim", "trunc", + "union_extract", + "union_tag", "upper", "uuid", "var", "var_pop", + 
"var_population", "var_samp", "var_sample", + "version", "when", - # Window Functions - "window", ] @@ -323,49 +385,52 @@ def nullif(expr1: Expr, expr2: Expr) -> Expr: return Expr(f.nullif(expr1.expr, expr2.expr)) -def encode(expr: Expr, encoding: Expr) -> Expr: +def encode(expr: Expr, encoding: Expr | str) -> Expr: """Encode the ``input``, using the ``encoding``. encoding can be base64 or hex. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.encode(dfn.col("a"), dfn.lit("base64")).alias("enc")) + ... dfn.functions.encode(dfn.col("a"), "base64").alias("enc")) >>> result.collect_column("enc")[0].as_py() 'aGVsbG8' """ + encoding = coerce_to_expr(encoding) return Expr(f.encode(expr.expr, encoding.expr)) -def decode(expr: Expr, encoding: Expr) -> Expr: +def decode(expr: Expr, encoding: Expr | str) -> Expr: """Decode the ``input``, using the ``encoding``. encoding can be base64 or hex. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["aGVsbG8="]}) >>> result = df.select( - ... dfn.functions.decode(dfn.col("a"), dfn.lit("base64")).alias("dec")) + ... dfn.functions.decode(dfn.col("a"), "base64").alias("dec")) >>> result.collect_column("dec")[0].as_py() b'hello' """ + encoding = coerce_to_expr(encoding) return Expr(f.decode(expr.expr, encoding.expr)) -def array_to_string(expr: Expr, delimiter: Expr) -> Expr: +def array_to_string(expr: Expr, delimiter: Expr | str) -> Expr: """Converts each element to its text representation. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2, 3]]}) >>> result = df.select( - ... dfn.functions.array_to_string(dfn.col("a"), dfn.lit(",")).alias("s")) + ... dfn.functions.array_to_string(dfn.col("a"), ",").alias("s")) >>> result.collect_column("s")[0].as_py() '1,2,3' """ + delimiter = coerce_to_expr(delimiter) return Expr(f.array_to_string(expr.expr, delimiter.expr.cast(pa.string()))) -def array_join(expr: Expr, delimiter: Expr) -> Expr: +def array_join(expr: Expr, delimiter: Expr | str) -> Expr: """Converts each element to its text representation. See Also: @@ -374,7 +439,7 @@ def array_join(expr: Expr, delimiter: Expr) -> Expr: return array_to_string(expr, delimiter) -def list_to_string(expr: Expr, delimiter: Expr) -> Expr: +def list_to_string(expr: Expr, delimiter: Expr | str) -> Expr: """Converts each element to its text representation. See Also: @@ -383,7 +448,7 @@ def list_to_string(expr: Expr, delimiter: Expr) -> Expr: return array_to_string(expr, delimiter) -def list_join(expr: Expr, delimiter: Expr) -> Expr: +def list_join(expr: Expr, delimiter: Expr | str) -> Expr: """Converts each element to its text representation. See Also: @@ -419,7 +484,7 @@ def in_list(arg: Expr, values: list[Expr], negated: bool = False) -> Expr: return Expr(f.in_list(arg.expr, values, negated)) -def digest(value: Expr, method: Expr) -> Expr: +def digest(value: Expr, method: Expr | str) -> Expr: """Computes the binary hash of an expression using the specified algorithm. Standard algorithms are md5, sha224, sha256, sha384, sha512, blake2s, @@ -429,13 +494,29 @@ def digest(value: Expr, method: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.digest(dfn.col("a"), dfn.lit("md5")).alias("d")) + ... 
dfn.functions.digest(dfn.col("a"), "md5").alias("d")) >>> len(result.collect_column("d")[0].as_py()) > 0 True """ + method = coerce_to_expr(method) return Expr(f.digest(value.expr, method.expr)) +def contains(string: Expr, search_str: Expr | str) -> Expr: + """Returns true if ``search_str`` is found within ``string`` (case-sensitive). + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": ["the quick brown fox"]}) + >>> result = df.select( + ... dfn.functions.contains(dfn.col("a"), "brown").alias("c")) + >>> result.collect_column("c")[0].as_py() + True + """ + search_str = coerce_to_expr(search_str) + return Expr(f.contains(string.expr, search_str.expr)) + + def concat(*args: Expr) -> Expr: """Concatenates the text representations of all the arguments. @@ -600,49 +681,6 @@ def when(when: Expr, then: Expr) -> CaseBuilder: return CaseBuilder(f.when(when.expr, then.expr)) -@deprecated("Prefer to call Expr.over() instead") -def window( - name: str, - args: list[Expr], - partition_by: list[Expr] | Expr | None = None, - order_by: list[SortKey] | SortKey | None = None, - window_frame: WindowFrame | None = None, - filter: Expr | None = None, - distinct: bool = False, - ctx: SessionContext | None = None, -) -> Expr: - """Creates a new Window function expression. - - This interface will soon be deprecated. Instead of using this interface, - users should call the window functions directly. For example, to perform a - lag use:: - - df.select(functions.lag(col("a")).partition_by(col("b")).build()) - - The ``order_by`` parameter accepts column names or expressions, e.g.:: - - window("lag", [col("a")], order_by="ts") - """ - args = [a.expr for a in args] - partition_by_raw = expr_list_to_raw_expr_list(partition_by) - order_by_raw = sort_list_to_raw_sort_list(order_by) - window_frame = window_frame.window_frame if window_frame is not None else None - ctx = ctx.ctx if ctx is not None else None - filter_raw = filter.expr if filter is not None else None - return Expr( - f.window( - name, - args, - partition_by=partition_by_raw, - order_by=order_by_raw, - window_frame=window_frame, - ctx=ctx, - filter=filter_raw, - distinct=distinct, - ) - ) - - # scalar functions def abs(arg: Expr) -> Expr: """Return the absolute value of a given number. @@ -938,17 +976,18 @@ def degrees(arg: Expr) -> Expr: return Expr(f.degrees(arg.expr)) -def ends_with(arg: Expr, suffix: Expr) -> Expr: +def ends_with(arg: Expr, suffix: Expr | str) -> Expr: """Returns true if the ``string`` ends with the ``suffix``, false otherwise. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["abc","b","c"]}) >>> ends_with_df = df.select( - ... dfn.functions.ends_with(dfn.col("a"), dfn.lit("c")).alias("ends_with")) + ... dfn.functions.ends_with(dfn.col("a"), "c").alias("ends_with")) >>> ends_with_df.collect_column("ends_with")[0].as_py() True """ + suffix = coerce_to_expr(suffix) return Expr(f.ends_with(arg.expr, suffix.expr)) @@ -980,7 +1019,7 @@ def factorial(arg: Expr) -> Expr: return Expr(f.factorial(arg.expr)) -def find_in_set(string: Expr, string_list: Expr) -> Expr: +def find_in_set(string: Expr, string_list: Expr | str) -> Expr: """Find a string in a list of strings. Returns a value in the range of 1 to N if the string is in the string list @@ -992,10 +1031,11 @@ def find_in_set(string: Expr, string_list: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["b"]}) >>> result = df.select( - ... dfn.functions.find_in_set(dfn.col("a"), dfn.lit("a,b,c")).alias("pos")) + ... 
dfn.functions.find_in_set(dfn.col("a"), "a,b,c").alias("pos")) >>> result.collect_column("pos")[0].as_py() 2 """ + string_list = coerce_to_expr(string_list) return Expr(f.find_in_set(string.expr, string_list.expr)) @@ -1027,6 +1067,34 @@ def gcd(x: Expr, y: Expr) -> Expr: return Expr(f.gcd(x.expr, y.expr)) +def greatest(*args: Expr) -> Expr: + """Returns the greatest value from a list of expressions. + + Returns NULL if all expressions are NULL. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 3], "b": [2, 1]}) + >>> result = df.select( + ... dfn.functions.greatest(dfn.col("a"), dfn.col("b")).alias("greatest")) + >>> result.collect_column("greatest")[0].as_py() + 2 + >>> result.collect_column("greatest")[1].as_py() + 3 + """ + exprs = [arg.expr for arg in args] + return Expr(f.greatest(*exprs)) + + +def ifnull(x: Expr, y: Expr) -> Expr: + """Returns ``x`` if ``x`` is not NULL. Otherwise returns ``y``. + + See Also: + This is an alias for :py:func:`nvl`. + """ + return nvl(x, y) + + def initcap(string: Expr) -> Expr: """Set the initial letter of each word to capital. @@ -1043,7 +1111,7 @@ def initcap(string: Expr) -> Expr: return Expr(f.initcap(string.expr)) -def instr(string: Expr, substring: Expr) -> Expr: +def instr(string: Expr, substring: Expr | str) -> Expr: """Finds the position from where the ``substring`` matches the ``string``. See Also: @@ -1080,31 +1148,52 @@ def lcm(x: Expr, y: Expr) -> Expr: return Expr(f.lcm(x.expr, y.expr)) -def left(string: Expr, n: Expr) -> Expr: +def least(*args: Expr) -> Expr: + """Returns the least value from a list of expressions. + + Returns NULL if all expressions are NULL. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 3], "b": [2, 1]}) + >>> result = df.select( + ... dfn.functions.least(dfn.col("a"), dfn.col("b")).alias("least")) + >>> result.collect_column("least")[0].as_py() + 1 + >>> result.collect_column("least")[1].as_py() + 1 + """ + exprs = [arg.expr for arg in args] + return Expr(f.least(*exprs)) + + +def left(string: Expr, n: Expr | int) -> Expr: """Returns the first ``n`` characters in the ``string``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["the cat"]}) >>> left_df = df.select( - ... dfn.functions.left(dfn.col("a"), dfn.lit(3)).alias("left")) + ... dfn.functions.left(dfn.col("a"), 3).alias("left")) >>> left_df.collect_column("left")[0].as_py() 'the' """ + n = coerce_to_expr(n) return Expr(f.left(string.expr, n.expr)) -def levenshtein(string1: Expr, string2: Expr) -> Expr: +def levenshtein(string1: Expr, string2: Expr | str) -> Expr: """Returns the Levenshtein distance between the two given strings. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["kitten"]}) >>> result = df.select( - ... dfn.functions.levenshtein(dfn.col("a"), dfn.lit("sitting")).alias("d")) + ... dfn.functions.levenshtein(dfn.col("a"), "sitting").alias("d")) >>> result.collect_column("d")[0].as_py() 3 """ + string2 = coerce_to_expr(string2) return Expr(f.levenshtein(string1.expr, string2.expr)) @@ -1121,18 +1210,19 @@ def ln(arg: Expr) -> Expr: return Expr(f.ln(arg.expr)) -def log(base: Expr, num: Expr) -> Expr: +def log(base: Expr | int | float, num: Expr) -> Expr: # noqa: PYI041 """Returns the logarithm of a number for a particular ``base``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [100.0]}) >>> result = df.select( - ... dfn.functions.log(dfn.lit(10.0), dfn.col("a")).alias("log") + ... 
dfn.functions.log(10.0, dfn.col("a")).alias("log") ... ) >>> result.collect_column("log")[0].as_py() 2.0 """ + base = coerce_to_expr(base) return Expr(f.log(base.expr, num.expr)) @@ -1175,7 +1265,7 @@ def lower(arg: Expr) -> Expr: return Expr(f.lower(arg.expr)) -def lpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: +def lpad(string: Expr, count: Expr | int, characters: Expr | str | None = None) -> Expr: """Add left padding to a string. Extends the string to length length by prepending the characters fill (a @@ -1186,9 +1276,7 @@ def lpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["the cat", "a hat"]}) >>> lpad_df = df.select( - ... dfn.functions.lpad( - ... dfn.col("a"), dfn.lit(6) - ... ).alias("lpad")) + ... dfn.functions.lpad(dfn.col("a"), 6).alias("lpad")) >>> lpad_df.collect_column("lpad")[0].as_py() 'the ca' >>> lpad_df.collect_column("lpad")[1].as_py() @@ -1196,12 +1284,13 @@ def lpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: >>> result = df.select( ... dfn.functions.lpad( - ... dfn.col("a"), dfn.lit(10), characters=dfn.lit(".") + ... dfn.col("a"), 10, characters="." ... ).alias("lpad")) >>> result.collect_column("lpad")[0].as_py() '...the cat' """ - characters = characters if characters is not None else Expr.literal(" ") + count = coerce_to_expr(count) + characters = coerce_to_expr(characters if characters is not None else " ") return Expr(f.lpad(string.expr, count.expr, characters.expr)) @@ -1264,6 +1353,24 @@ def nvl(x: Expr, y: Expr) -> Expr: return Expr(f.nvl(x.expr, y.expr)) +def nvl2(x: Expr, y: Expr, z: Expr) -> Expr: + """Returns ``y`` if ``x`` is not NULL. Otherwise returns ``z``. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [None, 1], "b": [10, 20], "c": [30, 40]}) + >>> result = df.select( + ... dfn.functions.nvl2( + ... dfn.col("a"), dfn.col("b"), dfn.col("c")).alias("nvl2") + ... ) + >>> result.collect_column("nvl2")[0].as_py() + 30 + >>> result.collect_column("nvl2")[1].as_py() + 20 + """ + return Expr(f.nvl2(x.expr, y.expr, z.expr)) + + def octet_length(arg: Expr) -> Expr: """Returns the number of bytes of a string. @@ -1278,7 +1385,10 @@ def octet_length(arg: Expr) -> Expr: def overlay( - string: Expr, substring: Expr, start: Expr, length: Expr | None = None + string: Expr, + substring: Expr | str, + start: Expr | int, + length: Expr | int | None = None, ) -> Expr: """Replace a substring with a new substring. @@ -1289,13 +1399,15 @@ def overlay( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["abcdef"]}) >>> result = df.select( - ... dfn.functions.overlay(dfn.col("a"), dfn.lit("XY"), dfn.lit(3), - ... dfn.lit(2)).alias("o")) + ... dfn.functions.overlay(dfn.col("a"), "XY", 3, 2).alias("o")) >>> result.collect_column("o")[0].as_py() 'abXYef' """ + substring = coerce_to_expr(substring) + start = coerce_to_expr(start) if length is None: return Expr(f.overlay(string.expr, substring.expr, start.expr)) + length = coerce_to_expr(length) return Expr(f.overlay(string.expr, substring.expr, start.expr, length.expr)) @@ -1315,7 +1427,7 @@ def pi() -> Expr: return Expr(f.pi()) -def position(string: Expr, substring: Expr) -> Expr: +def position(string: Expr, substring: Expr | str) -> Expr: """Finds the position from where the ``substring`` matches the ``string``. 
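Since the NULL-handling helpers added a few hunks above are easy to confuse, here is a hedged sketch of how `ifnull`/`nvl` and the new `nvl2` relate (illustrative, not part of the patch; expected outputs follow the docstring semantics):

```python
import datafusion as dfn
from datafusion import col, lit, functions as F

ctx = dfn.SessionContext()
df = ctx.from_pydict({"a": [None, 1]})

# ifnull aliases nvl (substitute a fallback for NULL); nvl2 instead picks
# between two values depending on whether its first argument is NULL.
r = df.select(
    F.ifnull(col("a"), lit(0)).alias("filled"),
    F.nvl2(col("a"), lit("set"), lit("missing")).alias("state"),
)
print(r.to_pydict())  # {'filled': [0, 1], 'state': ['missing', 'set']}
```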
See Also: @@ -1324,22 +1436,23 @@ def position(string: Expr, substring: Expr) -> Expr: return strpos(string, substring) -def power(base: Expr, exponent: Expr) -> Expr: +def power(base: Expr, exponent: Expr | int | float) -> Expr: # noqa: PYI041 """Returns ``base`` raised to the power of ``exponent``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [2.0]}) >>> result = df.select( - ... dfn.functions.power(dfn.col("a"), dfn.lit(3.0)).alias("pow") + ... dfn.functions.power(dfn.col("a"), 3.0).alias("pow") ... ) >>> result.collect_column("pow")[0].as_py() 8.0 """ + exponent = coerce_to_expr(exponent) return Expr(f.power(base.expr, exponent.expr)) -def pow(base: Expr, exponent: Expr) -> Expr: +def pow(base: Expr, exponent: Expr | int | float) -> Expr: # noqa: PYI041 """Returns ``base`` raised to the power of ``exponent``. See Also: @@ -1364,7 +1477,9 @@ def radians(arg: Expr) -> Expr: return Expr(f.radians(arg.expr)) -def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: +def regexp_like( + string: Expr, regex: Expr | str, flags: Expr | str | None = None +) -> Expr: r"""Find if any regular expression (regex) matches exist. Tests a string using a regular expression returning true if at least one match, @@ -1374,9 +1489,7 @@ def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello123"]}) >>> result = df.select( - ... dfn.functions.regexp_like( - ... dfn.col("a"), dfn.lit("\\d+") - ... ).alias("m") + ... dfn.functions.regexp_like(dfn.col("a"), "\\d+").alias("m") ... ) >>> result.collect_column("m")[0].as_py() True @@ -1385,19 +1498,24 @@ def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> result = df.select( ... dfn.functions.regexp_like( - ... dfn.col("a"), dfn.lit("HELLO"), - ... flags=dfn.lit("i"), + ... dfn.col("a"), "HELLO", flags="i", ... ).alias("m") ... ) >>> result.collect_column("m")[0].as_py() True """ - if flags is not None: - flags = flags.expr - return Expr(f.regexp_like(string.expr, regex.expr, flags)) + regex = coerce_to_expr(regex) + flags = coerce_to_expr_or_none(flags) + return Expr( + f.regexp_like( + string.expr, regex.expr, flags.expr if flags is not None else None + ) + ) -def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: +def regexp_match( + string: Expr, regex: Expr | str, flags: Expr | str | None = None +) -> Expr: r"""Perform regular expression (regex) matching. Returns an array with each element containing the leftmost-first match of the @@ -1407,9 +1525,7 @@ def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello 42 world"]}) >>> result = df.select( - ... dfn.functions.regexp_match( - ... dfn.col("a"), dfn.lit("(\\d+)") - ... ).alias("m") + ... dfn.functions.regexp_match(dfn.col("a"), "(\\d+)").alias("m") ... ) >>> result.collect_column("m")[0].as_py() ['42'] @@ -1418,20 +1534,26 @@ def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> result = df.select( ... dfn.functions.regexp_match( - ... dfn.col("a"), dfn.lit("(HELLO)"), - ... flags=dfn.lit("i"), + ... dfn.col("a"), "(HELLO)", flags="i", ... ).alias("m") ... 
) >>> result.collect_column("m")[0].as_py() ['hello'] """ - if flags is not None: - flags = flags.expr - return Expr(f.regexp_match(string.expr, regex.expr, flags)) + regex = coerce_to_expr(regex) + flags = coerce_to_expr_or_none(flags) + return Expr( + f.regexp_match( + string.expr, regex.expr, flags.expr if flags is not None else None + ) + ) def regexp_replace( - string: Expr, pattern: Expr, replacement: Expr, flags: Expr | None = None + string: Expr, + pattern: Expr | str, + replacement: Expr | str, + flags: Expr | str | None = None, ) -> Expr: r"""Replaces substring(s) matching a PCRE-like regular expression. @@ -1446,8 +1568,7 @@ def regexp_replace( >>> df = ctx.from_pydict({"a": ["hello 42"]}) >>> result = df.select( ... dfn.functions.regexp_replace( - ... dfn.col("a"), dfn.lit("\\d+"), - ... dfn.lit("XX") + ... dfn.col("a"), "\\d+", "XX" ... ).alias("r") ... ) >>> result.collect_column("r")[0].as_py() @@ -1458,20 +1579,30 @@ def regexp_replace( >>> df = ctx.from_pydict({"a": ["a1 b2 c3"]}) >>> result = df.select( ... dfn.functions.regexp_replace( - ... dfn.col("a"), dfn.lit("\\d+"), - ... dfn.lit("X"), flags=dfn.lit("g"), + ... dfn.col("a"), "\\d+", "X", flags="g", ... ).alias("r") ... ) >>> result.collect_column("r")[0].as_py() 'aX bX cX' """ - if flags is not None: - flags = flags.expr - return Expr(f.regexp_replace(string.expr, pattern.expr, replacement.expr, flags)) + pattern = coerce_to_expr(pattern) + replacement = coerce_to_expr(replacement) + flags = coerce_to_expr_or_none(flags) + return Expr( + f.regexp_replace( + string.expr, + pattern.expr, + replacement.expr, + flags.expr if flags is not None else None, + ) + ) def regexp_count( - string: Expr, pattern: Expr, start: Expr | None = None, flags: Expr | None = None + string: Expr, + pattern: Expr | str, + start: Expr | int | None = None, + flags: Expr | str | None = None, ) -> Expr: """Returns the number of matches in a string. @@ -1482,9 +1613,7 @@ def regexp_count( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["abcabc"]}) >>> result = df.select( - ... dfn.functions.regexp_count( - ... dfn.col("a"), dfn.lit("abc") - ... ).alias("c")) + ... dfn.functions.regexp_count(dfn.col("a"), "abc").alias("c")) >>> result.collect_column("c")[0].as_py() 2 @@ -1493,25 +1622,31 @@ def regexp_count( >>> result = df.select( ... dfn.functions.regexp_count( - ... dfn.col("a"), dfn.lit("ABC"), - ... start=dfn.lit(4), flags=dfn.lit("i"), + ... dfn.col("a"), "ABC", start=4, flags="i", ... ).alias("c")) >>> result.collect_column("c")[0].as_py() 1 """ - if flags is not None: - flags = flags.expr - start = start.expr if start is not None else start - return Expr(f.regexp_count(string.expr, pattern.expr, start, flags)) + pattern = coerce_to_expr(pattern) + start = coerce_to_expr_or_none(start) + flags = coerce_to_expr_or_none(flags) + return Expr( + f.regexp_count( + string.expr, + pattern.expr, + start.expr if start is not None else None, + flags.expr if flags is not None else None, + ) + ) def regexp_instr( values: Expr, - regex: Expr, - start: Expr | None = None, - n: Expr | None = None, - flags: Expr | None = None, - sub_expr: Expr | None = None, + regex: Expr | str, + start: Expr | int | None = None, + n: Expr | int | None = None, + flags: Expr | str | None = None, + sub_expr: Expr | int | None = None, ) -> Expr: r"""Returns the position of a regular expression match in a string. 
@@ -1527,9 +1662,7 @@ def regexp_instr( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello 42 world"]}) >>> result = df.select( - ... dfn.functions.regexp_instr( - ... dfn.col("a"), dfn.lit("\\d+") - ... ).alias("pos") + ... dfn.functions.regexp_instr(dfn.col("a"), "\\d+").alias("pos") ... ) >>> result.collect_column("pos")[0].as_py() 7 @@ -1540,9 +1673,8 @@ def regexp_instr( >>> df = ctx.from_pydict({"a": ["abc ABC abc"]}) >>> result = df.select( ... dfn.functions.regexp_instr( - ... dfn.col("a"), dfn.lit("abc"), - ... start=dfn.lit(2), n=dfn.lit(1), - ... flags=dfn.lit("i"), + ... dfn.col("a"), "abc", + ... start=2, n=1, flags="i", ... ).alias("pos") ... ) >>> result.collect_column("pos")[0].as_py() @@ -1552,56 +1684,58 @@ def regexp_instr( >>> result = df.select( ... dfn.functions.regexp_instr( - ... dfn.col("a"), dfn.lit("(abc)"), - ... sub_expr=dfn.lit(1), + ... dfn.col("a"), "(abc)", sub_expr=1, ... ).alias("pos") ... ) >>> result.collect_column("pos")[0].as_py() 1 """ - start = start.expr if start is not None else None - n = n.expr if n is not None else None - flags = flags.expr if flags is not None else None - sub_expr = sub_expr.expr if sub_expr is not None else None + regex = coerce_to_expr(regex) + start = coerce_to_expr_or_none(start) + n = coerce_to_expr_or_none(n) + flags = coerce_to_expr_or_none(flags) + sub_expr = coerce_to_expr_or_none(sub_expr) return Expr( f.regexp_instr( values.expr, regex.expr, - start, - n, - flags, - sub_expr, + start.expr if start is not None else None, + n.expr if n is not None else None, + flags.expr if flags is not None else None, + sub_expr.expr if sub_expr is not None else None, ) ) -def repeat(string: Expr, n: Expr) -> Expr: +def repeat(string: Expr, n: Expr | int) -> Expr: """Repeats the ``string`` to ``n`` times. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["ha"]}) >>> result = df.select( - ... dfn.functions.repeat(dfn.col("a"), dfn.lit(3)).alias("r")) + ... dfn.functions.repeat(dfn.col("a"), 3).alias("r")) >>> result.collect_column("r")[0].as_py() 'hahaha' """ + n = coerce_to_expr(n) return Expr(f.repeat(string.expr, n.expr)) -def replace(string: Expr, from_val: Expr, to_val: Expr) -> Expr: +def replace(string: Expr, from_val: Expr | str, to_val: Expr | str) -> Expr: """Replaces all occurrences of ``from_val`` with ``to_val`` in the ``string``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello world"]}) >>> result = df.select( - ... dfn.functions.replace(dfn.col("a"), dfn.lit("world"), - ... dfn.lit("there")).alias("r")) + ... dfn.functions.replace(dfn.col("a"), "world", "there").alias("r")) >>> result.collect_column("r")[0].as_py() 'hello there' """ + from_val = coerce_to_expr(from_val) + to_val = coerce_to_expr(to_val) return Expr(f.replace(string.expr, from_val.expr, to_val.expr)) @@ -1618,39 +1752,39 @@ def reverse(arg: Expr) -> Expr: return Expr(f.reverse(arg.expr)) -def right(string: Expr, n: Expr) -> Expr: +def right(string: Expr, n: Expr | int) -> Expr: """Returns the last ``n`` characters in the ``string``. 
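Stepping back to the regexp changes above: patterns, flags, and counts now accept plain Python strings and ints. A quick equivalence check (illustrative, not part of the patch):

```python
import datafusion as dfn
from datafusion import col, lit, functions as F

ctx = dfn.SessionContext()
df = ctx.from_pydict({"s": ["hello 42"]})

# Both spellings build the same plan: plain strings are wrapped via
# coerce_to_expr, so lit() is no longer required for patterns.
old = df.select(F.regexp_like(col("s"), lit(r"\d+")).alias("m"))
new = df.select(F.regexp_like(col("s"), r"\d+").alias("m"))
assert old.to_pydict() == new.to_pydict() == {"m": [True]}

# Optional arguments follow the same rule, e.g. flags="i" instead of
# flags=lit("i").
count = df.select(F.regexp_count(col("s"), "HELLO", flags="i").alias("c"))
print(count.to_pydict())  # Expected: {'c': [1]}
```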
Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) - >>> result = df.select(dfn.functions.right(dfn.col("a"), dfn.lit(3)).alias("r")) + >>> result = df.select(dfn.functions.right(dfn.col("a"), 3).alias("r")) >>> result.collect_column("r")[0].as_py() 'llo' """ + n = coerce_to_expr(n) return Expr(f.right(string.expr, n.expr)) -def round(value: Expr, decimal_places: Expr | None = None) -> Expr: +def round(value: Expr, decimal_places: Expr | int | None = None) -> Expr: """Round the argument to the nearest integer. If the optional ``decimal_places`` is specified, round to the nearest number of decimal places. You can specify a negative number of decimal places. For example - ``round(lit(125.2345), lit(-2))`` would yield a value of ``100.0``. + ``round(lit(125.2345), -2)`` would yield a value of ``100.0``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [1.567]}) - >>> result = df.select(dfn.functions.round(dfn.col("a"), dfn.lit(2)).alias("r")) + >>> result = df.select(dfn.functions.round(dfn.col("a"), 2).alias("r")) >>> result.collect_column("r")[0].as_py() 1.57 """ - if decimal_places is None: - decimal_places = Expr.literal(0) + decimal_places = coerce_to_expr(decimal_places if decimal_places is not None else 0) return Expr(f.round(value.expr, decimal_places.expr)) -def rpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: +def rpad(string: Expr, count: Expr | int, characters: Expr | str | None = None) -> Expr: """Add right padding to a string. Extends the string to length length by appending the characters fill (a space @@ -1660,11 +1794,12 @@ def rpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hi"]}) >>> result = df.select( - ... dfn.functions.rpad(dfn.col("a"), dfn.lit(5), dfn.lit("!")).alias("r")) + ... dfn.functions.rpad(dfn.col("a"), 5, "!").alias("r")) >>> result.collect_column("r")[0].as_py() 'hi!!!' """ - characters = characters if characters is not None else Expr.literal(" ") + count = coerce_to_expr(count) + characters = coerce_to_expr(characters if characters is not None else " ") return Expr(f.rpad(string.expr, count.expr, characters.expr)) @@ -1780,7 +1915,7 @@ def sinh(arg: Expr) -> Expr: return Expr(f.sinh(arg.expr)) -def split_part(string: Expr, delimiter: Expr, index: Expr) -> Expr: +def split_part(string: Expr, delimiter: Expr | str, index: Expr | int) -> Expr: """Split a string and return one part. Splits a string based on a delimiter and picks out the desired field based @@ -1790,12 +1925,12 @@ def split_part(string: Expr, delimiter: Expr, index: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["a,b,c"]}) >>> result = df.select( - ... dfn.functions.split_part( - ... dfn.col("a"), dfn.lit(","), dfn.lit(2) - ... ).alias("s")) + ... dfn.functions.split_part(dfn.col("a"), ",", 2).alias("s")) >>> result.collect_column("s")[0].as_py() 'b' """ + delimiter = coerce_to_expr(delimiter) + index = coerce_to_expr(index) return Expr(f.split_part(string.expr, delimiter.expr, index.expr)) @@ -1812,49 +1947,52 @@ def sqrt(arg: Expr) -> Expr: return Expr(f.sqrt(arg.expr)) -def starts_with(string: Expr, prefix: Expr) -> Expr: +def starts_with(string: Expr, prefix: Expr | str) -> Expr: """Returns true if string starts with prefix. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello_from_datafusion"]}) >>> result = df.select( - ... 
dfn.functions.starts_with(dfn.col("a"), dfn.lit("hello")).alias("sw")) + ... dfn.functions.starts_with(dfn.col("a"), "hello").alias("sw")) >>> result.collect_column("sw")[0].as_py() True """ + prefix = coerce_to_expr(prefix) return Expr(f.starts_with(string.expr, prefix.expr)) -def strpos(string: Expr, substring: Expr) -> Expr: +def strpos(string: Expr, substring: Expr | str) -> Expr: """Finds the position from where the ``substring`` matches the ``string``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.strpos(dfn.col("a"), dfn.lit("llo")).alias("pos")) + ... dfn.functions.strpos(dfn.col("a"), "llo").alias("pos")) >>> result.collect_column("pos")[0].as_py() 3 """ + substring = coerce_to_expr(substring) return Expr(f.strpos(string.expr, substring.expr)) -def substr(string: Expr, position: Expr) -> Expr: +def substr(string: Expr, position: Expr | int) -> Expr: """Substring from the ``position`` to the end. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.substr(dfn.col("a"), dfn.lit(3)).alias("s")) + ... dfn.functions.substr(dfn.col("a"), 3).alias("s")) >>> result.collect_column("s")[0].as_py() 'llo' """ + position = coerce_to_expr(position) return Expr(f.substr(string.expr, position.expr)) -def substr_index(string: Expr, delimiter: Expr, count: Expr) -> Expr: +def substr_index(string: Expr, delimiter: Expr | str, count: Expr | int) -> Expr: """Returns an indexed substring. The return will be the ``string`` from before ``count`` occurrences of @@ -1864,27 +2002,28 @@ def substr_index(string: Expr, delimiter: Expr, count: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["a.b.c"]}) >>> result = df.select( - ... dfn.functions.substr_index(dfn.col("a"), dfn.lit("."), - ... dfn.lit(2)).alias("s")) + ... dfn.functions.substr_index(dfn.col("a"), ".", 2).alias("s")) >>> result.collect_column("s")[0].as_py() 'a.b' """ + delimiter = coerce_to_expr(delimiter) + count = coerce_to_expr(count) return Expr(f.substr_index(string.expr, delimiter.expr, count.expr)) -def substring(string: Expr, position: Expr, length: Expr) -> Expr: +def substring(string: Expr, position: Expr | int, length: Expr | int) -> Expr: """Substring from the ``position`` with ``length`` characters. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello world"]}) >>> result = df.select( - ... dfn.functions.substring( - ... dfn.col("a"), dfn.lit(1), dfn.lit(5) - ... ).alias("s")) + ... dfn.functions.substring(dfn.col("a"), 1, 5).alias("s")) >>> result.collect_column("s")[0].as_py() 'hello' """ + position = coerce_to_expr(position) + length = coerce_to_expr(length) return Expr(f.substring(string.expr, position.expr, length.expr)) @@ -1948,7 +2087,16 @@ def now() -> Expr: return Expr(f.now()) -def to_char(arg: Expr, formatter: Expr) -> Expr: +def current_timestamp() -> Expr: + """Returns the current timestamp in nanoseconds. + + See Also: + This is an alias for :py:func:`now`. + """ + return now() + + +def to_char(arg: Expr, formatter: Expr | str) -> Expr: """Returns a string representation of a date, time, timestamp or duration. For usage of ``formatter`` see the rust chrono package ``strftime`` package. @@ -1961,15 +2109,25 @@ def to_char(arg: Expr, formatter: Expr) -> Expr: >>> result = df.select( ... dfn.functions.to_char( ... dfn.functions.to_timestamp(dfn.col("a")), - ... dfn.lit("%Y/%m/%d"), + ... "%Y/%m/%d", ... 
).alias("formatted") ... ) >>> result.collect_column("formatted")[0].as_py() '2021/01/01' """ + formatter = coerce_to_expr(formatter) return Expr(f.to_char(arg.expr, formatter.expr)) +def date_format(arg: Expr, formatter: Expr | str) -> Expr: + """Returns a string representation of a date, time, timestamp or duration. + + See Also: + This is an alias for :py:func:`to_char`. + """ + return to_char(arg, formatter) + + def _unwrap_exprs(args: tuple[Expr, ...]) -> list: return [arg.expr for arg in args] @@ -2173,7 +2331,7 @@ def current_time() -> Expr: return Expr(f.current_time()) -def datepart(part: Expr, date: Expr) -> Expr: +def datepart(part: Expr | str, date: Expr) -> Expr: """Return a specified part of a date. See Also: @@ -2182,22 +2340,28 @@ def datepart(part: Expr, date: Expr) -> Expr: return date_part(part, date) -def date_part(part: Expr, date: Expr) -> Expr: +def date_part(part: Expr | str, date: Expr) -> Expr: """Extracts a subfield from the date. + Args: + part: The part of the date to extract. Must be one of ``"year"``, + ``"month"``, ``"day"``, ``"hour"``, ``"minute"``, ``"second"``, etc. + date: The date expression to extract from. + Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["2021-07-15T00:00:00"]}) >>> df = df.select(dfn.functions.to_timestamp(dfn.col("a")).alias("a")) >>> result = df.select( - ... dfn.functions.date_part(dfn.lit("year"), dfn.col("a")).alias("y")) + ... dfn.functions.date_part("year", dfn.col("a")).alias("y")) >>> result.collect_column("y")[0].as_py() 2021 """ + part = coerce_to_expr(part) return Expr(f.date_part(part.expr, date.expr)) -def extract(part: Expr, date: Expr) -> Expr: +def extract(part: Expr | str, date: Expr) -> Expr: """Extracts a subfield from the date. See Also: @@ -2206,25 +2370,29 @@ def extract(part: Expr, date: Expr) -> Expr: return date_part(part, date) -def date_trunc(part: Expr, date: Expr) -> Expr: +def date_trunc(part: Expr | str, date: Expr) -> Expr: """Truncates the date to a specified level of precision. + Args: + part: The precision to truncate to. Must be one of ``"year"``, + ``"month"``, ``"day"``, ``"hour"``, ``"minute"``, ``"second"``, etc. + date: The date expression to truncate. + Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["2021-07-15T12:34:56"]}) >>> df = df.select(dfn.functions.to_timestamp(dfn.col("a")).alias("a")) >>> result = df.select( - ... dfn.functions.date_trunc( - ... dfn.lit("month"), dfn.col("a") - ... ).alias("t") + ... dfn.functions.date_trunc("month", dfn.col("a")).alias("t") ... ) >>> str(result.collect_column("t")[0].as_py()) '2021-07-01 00:00:00' """ + part = coerce_to_expr(part) return Expr(f.date_trunc(part.expr, date.expr)) -def datetrunc(part: Expr, date: Expr) -> Expr: +def datetrunc(part: Expr | str, date: Expr) -> Expr: """Truncates the date to a specified level of precision. See Also: @@ -2270,18 +2438,34 @@ def make_date(year: Expr, month: Expr, day: Expr) -> Expr: return Expr(f.make_date(year.expr, month.expr, day.expr)) -def translate(string: Expr, from_val: Expr, to_val: Expr) -> Expr: +def make_time(hour: Expr, minute: Expr, second: Expr) -> Expr: + """Make a time from hour, minute and second component parts. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"h": [12], "m": [30], "s": [0]}) + >>> result = df.select( + ... dfn.functions.make_time(dfn.col("h"), dfn.col("m"), + ... 
dfn.col("s")).alias("t")) + >>> result.collect_column("t")[0].as_py() + datetime.time(12, 30) + """ + return Expr(f.make_time(hour.expr, minute.expr, second.expr)) + + +def translate(string: Expr, from_val: Expr | str, to_val: Expr | str) -> Expr: """Replaces the characters in ``from_val`` with the counterpart in ``to_val``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.translate(dfn.col("a"), dfn.lit("helo"), - ... dfn.lit("HELO")).alias("t")) + ... dfn.functions.translate(dfn.col("a"), "helo", "HELO").alias("t")) >>> result.collect_column("t")[0].as_py() 'HELLO' """ + from_val = coerce_to_expr(from_val) + to_val = coerce_to_expr(to_val) return Expr(f.translate(string.expr, from_val.expr, to_val.expr)) @@ -2298,27 +2482,24 @@ def trim(arg: Expr) -> Expr: return Expr(f.trim(arg.expr)) -def trunc(num: Expr, precision: Expr | None = None) -> Expr: +def trunc(num: Expr, precision: Expr | int | None = None) -> Expr: """Truncate the number toward zero with optional precision. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [1.567]}) >>> result = df.select( - ... dfn.functions.trunc( - ... dfn.col("a") - ... ).alias("t")) + ... dfn.functions.trunc(dfn.col("a")).alias("t")) >>> result.collect_column("t")[0].as_py() 1.0 >>> result = df.select( - ... dfn.functions.trunc( - ... dfn.col("a"), precision=dfn.lit(2) - ... ).alias("t")) + ... dfn.functions.trunc(dfn.col("a"), precision=2).alias("t")) >>> result.collect_column("t")[0].as_py() 1.56 """ if precision is not None: + precision = coerce_to_expr(precision) return Expr(f.trunc(num.expr, precision.expr)) return Expr(f.trunc(num.expr)) @@ -2476,22 +2657,180 @@ def arrow_typeof(arg: Expr) -> Expr: return Expr(f.arrow_typeof(arg.expr)) -def arrow_cast(expr: Expr, data_type: Expr) -> Expr: +def arrow_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr: """Casts an expression to a specified data type. + The ``data_type`` can be a string, a ``pyarrow.DataType``, or an + ``Expr``. For simple types, :py:meth:`Expr.cast() + ` is more concise + (e.g., ``col("a").cast(pa.float64())``). Use ``arrow_cast`` when + you want to specify the target type as a string using DataFusion's + type syntax, which can be more readable for complex types like + ``"Timestamp(Nanosecond, None)"``. + Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [1]}) - >>> data_type = dfn.string_literal("Float64") >>> result = df.select( - ... dfn.functions.arrow_cast(dfn.col("a"), data_type).alias("c") + ... dfn.functions.arrow_cast(dfn.col("a"), "Float64").alias("c") + ... ) + >>> result.collect_column("c")[0].as_py() + 1.0 + + >>> result = df.select( + ... dfn.functions.arrow_cast( + ... dfn.col("a"), data_type=pa.float64() + ... ).alias("c") ... ) >>> result.collect_column("c")[0].as_py() 1.0 """ + if isinstance(data_type, pa.DataType): + return expr.cast(data_type) + if isinstance(data_type, str): + data_type = Expr.string_literal(data_type) return Expr(f.arrow_cast(expr.expr, data_type.expr)) +def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr: + """Returns the metadata of the input expression. + + If called with one argument, returns a Map of all metadata key-value pairs. + If called with two arguments, returns the value for the specified metadata key. 
+ + Examples: + >>> field = pa.field("val", pa.int64(), metadata={"k": "v"}) + >>> schema = pa.schema([field]) + >>> batch = pa.RecordBatch.from_arrays([pa.array([1])], schema=schema) + >>> ctx = dfn.SessionContext() + >>> df = ctx.create_dataframe([[batch]]) + >>> result = df.select( + ... dfn.functions.arrow_metadata(dfn.col("val")).alias("meta") + ... ) + >>> ("k", "v") in result.collect_column("meta")[0].as_py() + True + + >>> result = df.select( + ... dfn.functions.arrow_metadata( + ... dfn.col("val"), key="k" + ... ).alias("meta_val") + ... ) + >>> result.collect_column("meta_val")[0].as_py() + 'v' + """ + if key is None: + return Expr(f.arrow_metadata(expr.expr)) + if isinstance(key, str): + key = Expr.string_literal(key) + return Expr(f.arrow_metadata(expr.expr, key.expr)) + + +def get_field(expr: Expr, name: Expr | str) -> Expr: + """Extracts a field from a struct or map by name. + + When the field name is a static string, the bracket operator + ``expr["field"]`` is a convenient shorthand. Use ``get_field`` + when the field name is a dynamic expression. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1], "b": [2]}) + >>> df = df.with_column( + ... "s", + ... dfn.functions.named_struct( + ... [("x", dfn.col("a")), ("y", dfn.col("b"))] + ... ), + ... ) + >>> result = df.select( + ... dfn.functions.get_field(dfn.col("s"), "x").alias("x_val") + ... ) + >>> result.collect_column("x_val")[0].as_py() + 1 + + Equivalent using bracket syntax: + + >>> result = df.select( + ... dfn.col("s")["x"].alias("x_val") + ... ) + >>> result.collect_column("x_val")[0].as_py() + 1 + """ + if isinstance(name, str): + name = Expr.string_literal(name) + return Expr(f.get_field(expr.expr, name.expr)) + + +def union_extract(union_expr: Expr, field_name: Expr | str) -> Expr: + """Extracts a value from a union type by field name. + + Returns the value of the named field if it is the currently selected + variant, otherwise returns NULL. + + Examples: + >>> ctx = dfn.SessionContext() + >>> types = pa.array([0, 1, 0], type=pa.int8()) + >>> offsets = pa.array([0, 0, 1], type=pa.int32()) + >>> arr = pa.UnionArray.from_dense( + ... types, offsets, [pa.array([1, 2]), pa.array(["hi"])], + ... ["int", "str"], [0, 1], + ... ) + >>> batch = pa.RecordBatch.from_arrays([arr], names=["u"]) + >>> df = ctx.create_dataframe([[batch]]) + >>> result = df.select( + ... dfn.functions.union_extract(dfn.col("u"), "int").alias("val") + ... ) + >>> result.collect_column("val").to_pylist() + [1, None, 2] + """ + if isinstance(field_name, str): + field_name = Expr.string_literal(field_name) + return Expr(f.union_extract(union_expr.expr, field_name.expr)) + + +def union_tag(union_expr: Expr) -> Expr: + """Returns the tag (active field name) of a union type. + + Examples: + >>> ctx = dfn.SessionContext() + >>> types = pa.array([0, 1, 0], type=pa.int8()) + >>> offsets = pa.array([0, 0, 1], type=pa.int32()) + >>> arr = pa.UnionArray.from_dense( + ... types, offsets, [pa.array([1, 2]), pa.array(["hi"])], + ... ["int", "str"], [0, 1], + ... ) + >>> batch = pa.RecordBatch.from_arrays([arr], names=["u"]) + >>> df = ctx.create_dataframe([[batch]]) + >>> result = df.select( + ... dfn.functions.union_tag(dfn.col("u")).alias("tag") + ... ) + >>> result.collect_column("tag").to_pylist() + ['int', 'str', 'int'] + """ + return Expr(f.union_tag(union_expr.expr)) + + +def version() -> Expr: + """Returns the DataFusion version string. 
+
+    Examples:
+        >>> ctx = dfn.SessionContext()
+        >>> df = ctx.empty_table()
+        >>> result = df.select(dfn.functions.version().alias("v"))
+        >>> "Apache DataFusion" in result.collect_column("v")[0].as_py()
+        True
+    """
+    return Expr(f.version())
+
+
+def row(*args: Expr) -> Expr:
+    """Returns a struct with the given arguments.
+
+    See Also:
+        This is an alias for :py:func:`struct`.
+    """
+    return struct(*args)
+
+
 def random() -> Expr:
     """Returns a random value in the range ``0.0 <= x < 1.0``.
 
@@ -2641,17 +2980,18 @@ def list_dims(array: Expr) -> Expr:
     return array_dims(array)
 
 
-def array_element(array: Expr, n: Expr) -> Expr:
+def array_element(array: Expr, n: Expr | int) -> Expr:
     """Extracts the element with the index n from the array.
 
     Examples:
         >>> ctx = dfn.SessionContext()
         >>> df = ctx.from_pydict({"a": [[10, 20, 30]]})
         >>> result = df.select(
-        ...     dfn.functions.array_element(dfn.col("a"), dfn.lit(2)).alias("result"))
+        ...     dfn.functions.array_element(dfn.col("a"), 2).alias("result"))
         >>> result.collect_column("result")[0].as_py()
         20
     """
+    n = coerce_to_expr(n)
     return Expr(f.array_element(array.expr, n.expr))
 
 
@@ -2668,7 +3008,16 @@ def array_empty(array: Expr) -> Expr:
     return Expr(f.array_empty(array.expr))
 
 
-def array_extract(array: Expr, n: Expr) -> Expr:
+def list_empty(array: Expr) -> Expr:
+    """Returns a boolean indicating whether the array is empty.
+
+    See Also:
+        This is an alias for :py:func:`array_empty`.
+    """
+    return array_empty(array)
+
+
+def array_extract(array: Expr, n: Expr | int) -> Expr:
     """Extracts the element with the index n from the array.
 
     See Also:
@@ -2677,7 +3026,7 @@ def array_extract(array: Expr, n: Expr | int) -> Expr:
     return array_element(array, n)
 
 
-def list_element(array: Expr, n: Expr) -> Expr:
+def list_element(array: Expr, n: Expr | int) -> Expr:
     """Extracts the element with the index n from the array.
 
     See Also:
@@ -2686,7 +3035,7 @@ def list_element(array: Expr, n: Expr | int) -> Expr:
     return array_element(array, n)
 
 
-def list_extract(array: Expr, n: Expr) -> Expr:
+def list_extract(array: Expr, n: Expr | int) -> Expr:
     """Extracts the element with the index n from the array.
 
     See Also:
@@ -2765,6 +3114,69 @@ def array_has_any(first_array: Expr, second_array: Expr) -> Expr:
     return Expr(f.array_has_any(first_array.expr, second_array.expr))
 
+
+def array_contains(array: Expr, element: Expr) -> Expr:
+    """Returns true if the element appears in the array, otherwise false.
+
+    See Also:
+        This is an alias for :py:func:`array_has`.
+    """
+    return array_has(array, element)
+
+
+def list_has(array: Expr, element: Expr) -> Expr:
+    """Returns true if the element appears in the array, otherwise false.
+
+    See Also:
+        This is an alias for :py:func:`array_has`.
+    """
+    return array_has(array, element)
+
+
+def list_has_all(first_array: Expr, second_array: Expr) -> Expr:
+    """Returns true if all elements of ``second_array`` are in ``first_array``.
+
+    See Also:
+        This is an alias for :py:func:`array_has_all`.
+    """
+    return array_has_all(first_array, second_array)
+
+
+def list_has_any(first_array: Expr, second_array: Expr) -> Expr:
+    """Returns true if any element appears in both arrays.
+
+    See Also:
+        This is an alias for :py:func:`array_has_any`.
+    """
+    return array_has_any(first_array, second_array)
+
+
+def arrays_overlap(first_array: Expr, second_array: Expr) -> Expr:
+    """Returns true if any element appears in both arrays.
+
+    See Also:
+        This is an alias for :py:func:`array_has_any`.
+ """ + return array_has_any(first_array, second_array) + + +def list_overlap(first_array: Expr, second_array: Expr) -> Expr: + """Returns true if any element appears in both arrays. + + See Also: + This is an alias for :py:func:`array_has_any`. + """ + return array_has_any(first_array, second_array) + + +def list_contains(array: Expr, element: Expr) -> Expr: + """Returns true if the element appears in the array, otherwise false. + + See Also: + This is an alias for :py:func:`array_has`. + """ + return array_has(array, element) + + def array_position(array: Expr, element: Expr, index: int | None = 1) -> Expr: """Return the position of the first occurrence of ``element`` in ``array``. @@ -2932,6 +3344,24 @@ def array_pop_front(array: Expr) -> Expr: return Expr(f.array_pop_front(array.expr)) +def list_pop_back(array: Expr) -> Expr: + """Returns the array without the last element. + + See Also: + This is an alias for :py:func:`array_pop_back`. + """ + return array_pop_back(array) + + +def list_pop_front(array: Expr) -> Expr: + """Returns the array without the first element. + + See Also: + This is an alias for :py:func:`array_pop_front`. + """ + return array_pop_front(array) + + def array_remove(array: Expr, element: Expr) -> Expr: """Removes the first element from the array equal to the given value. @@ -2955,22 +3385,24 @@ def list_remove(array: Expr, element: Expr) -> Expr: return array_remove(array, element) -def array_remove_n(array: Expr, element: Expr, max: Expr) -> Expr: +def array_remove_n(array: Expr, element: Expr, max: Expr | int) -> Expr: """Removes the first ``max`` elements from the array equal to the given value. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2, 1, 1]]}) >>> result = df.select( - ... dfn.functions.array_remove_n(dfn.col("a"), dfn.lit(1), - ... dfn.lit(2)).alias("result")) + ... dfn.functions.array_remove_n( + ... dfn.col("a"), dfn.lit(1), 2 + ... ).alias("result")) >>> result.collect_column("result")[0].as_py() [2, 1] """ + max = coerce_to_expr(max) return Expr(f.array_remove_n(array.expr, element.expr, max.expr)) -def list_remove_n(array: Expr, element: Expr, max: Expr) -> Expr: +def list_remove_n(array: Expr, element: Expr, max: Expr | int) -> Expr: """Removes the first ``max`` elements from the array equal to the given value. See Also: @@ -3004,21 +3436,22 @@ def list_remove_all(array: Expr, element: Expr) -> Expr: return array_remove_all(array, element) -def array_repeat(element: Expr, count: Expr) -> Expr: +def array_repeat(element: Expr, count: Expr | int) -> Expr: """Returns an array containing ``element`` ``count`` times. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [1]}) >>> result = df.select( - ... dfn.functions.array_repeat(dfn.lit(3), dfn.lit(3)).alias("result")) + ... dfn.functions.array_repeat(dfn.lit(3), 3).alias("result")) >>> result.collect_column("result")[0].as_py() [3, 3, 3] """ + count = coerce_to_expr(count) return Expr(f.array_repeat(element.expr, count.expr)) -def list_repeat(element: Expr, count: Expr) -> Expr: +def list_repeat(element: Expr, count: Expr | int) -> Expr: """Returns an array containing ``element`` ``count`` times. 
See Also: @@ -3051,7 +3484,7 @@ def list_replace(array: Expr, from_val: Expr, to_val: Expr) -> Expr: return array_replace(array, from_val, to_val) -def array_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Expr: +def array_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr | int) -> Expr: """Replace ``n`` occurrences of ``from_val`` with ``to_val``. Replaces the first ``max`` occurrences of the specified element with another @@ -3061,15 +3494,17 @@ def array_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Exp >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2, 1, 1]]}) >>> result = df.select( - ... dfn.functions.array_replace_n(dfn.col("a"), dfn.lit(1), dfn.lit(9), - ... dfn.lit(2)).alias("result")) + ... dfn.functions.array_replace_n( + ... dfn.col("a"), dfn.lit(1), dfn.lit(9), 2 + ... ).alias("result")) >>> result.collect_column("result")[0].as_py() [9, 2, 9, 1] """ + max = coerce_to_expr(max) return Expr(f.array_replace_n(array.expr, from_val.expr, to_val.expr, max.expr)) -def list_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Expr: +def list_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr | int) -> Expr: """Replace ``n`` occurrences of ``from_val`` with ``to_val``. Replaces the first ``max`` occurrences of the specified element with another @@ -3152,7 +3587,10 @@ def list_sort(array: Expr, descending: bool = False, null_first: bool = False) - def array_slice( - array: Expr, begin: Expr, end: Expr, stride: Expr | None = None + array: Expr, + begin: Expr | int, + end: Expr | int, + stride: Expr | int | None = None, ) -> Expr: """Returns a slice of the array. @@ -3160,9 +3598,7 @@ def array_slice( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2, 3, 4]]}) >>> result = df.select( - ... dfn.functions.array_slice( - ... dfn.col("a"), dfn.lit(2), dfn.lit(3) - ... ).alias("result")) + ... dfn.functions.array_slice(dfn.col("a"), 2, 3).alias("result")) >>> result.collect_column("result")[0].as_py() [2, 3] @@ -3170,18 +3606,27 @@ def array_slice( >>> result = df.select( ... dfn.functions.array_slice( - ... dfn.col("a"), dfn.lit(1), dfn.lit(4), - ... stride=dfn.lit(2), + ... dfn.col("a"), 1, 4, stride=2, ... ).alias("result")) >>> result.collect_column("result")[0].as_py() [1, 3] """ - if stride is not None: - stride = stride.expr - return Expr(f.array_slice(array.expr, begin.expr, end.expr, stride)) + begin = coerce_to_expr(begin) + end = coerce_to_expr(end) + stride = coerce_to_expr_or_none(stride) + return Expr( + f.array_slice( + array.expr, + begin.expr, + end.expr, + stride.expr if stride is not None else None, + ) + ) -def list_slice(array: Expr, begin: Expr, end: Expr, stride: Expr | None = None) -> Expr: +def list_slice( + array: Expr, begin: Expr | int, end: Expr | int, stride: Expr | int | None = None +) -> Expr: """Returns a slice of the array. See Also: @@ -3273,7 +3718,7 @@ def list_except(array1: Expr, array2: Expr) -> Expr: return array_except(array1, array2) -def array_resize(array: Expr, size: Expr, value: Expr) -> Expr: +def array_resize(array: Expr, size: Expr | int, value: Expr) -> Expr: """Returns an array with the specified size filled. If ``size`` is greater than the ``array`` length, the additional entries will @@ -3283,15 +3728,15 @@ def array_resize(array: Expr, size: Expr, value: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2]]}) >>> result = df.select( - ... 
dfn.functions.array_resize(dfn.col("a"), dfn.lit(4), - ... dfn.lit(0)).alias("result")) + ... dfn.functions.array_resize(dfn.col("a"), 4, dfn.lit(0)).alias("result")) >>> result.collect_column("result")[0].as_py() [1, 2, 0, 0] """ + size = coerce_to_expr(size) return Expr(f.array_resize(array.expr, size.expr, value.expr)) -def list_resize(array: Expr, size: Expr, value: Expr) -> Expr: +def list_resize(array: Expr, size: Expr | int, value: Expr) -> Expr: """Returns an array with the specified size filled. If ``size`` is greater than the ``array`` length, the additional entries will be @@ -3303,13 +3748,239 @@ def list_resize(array: Expr, size: Expr, value: Expr) -> Expr: return array_resize(array, size, value) -def flatten(array: Expr) -> Expr: - """Flattens an array of arrays into a single array. +def array_any_value(array: Expr) -> Expr: + """Returns the first non-null element in the array. Examples: >>> ctx = dfn.SessionContext() - >>> df = ctx.from_pydict({"a": [[[1, 2], [3, 4]]]}) - >>> result = df.select(dfn.functions.flatten(dfn.col("a")).alias("result")) + >>> df = ctx.from_pydict({"a": [[None, 2, 3]]}) + >>> result = df.select( + ... dfn.functions.array_any_value(dfn.col("a")).alias("result")) + >>> result.collect_column("result")[0].as_py() + 2 + """ + return Expr(f.array_any_value(array.expr)) + + +def list_any_value(array: Expr) -> Expr: + """Returns the first non-null element in the array. + + See Also: + This is an alias for :py:func:`array_any_value`. + """ + return array_any_value(array) + + +def array_distance(array1: Expr, array2: Expr) -> Expr: + """Returns the Euclidean distance between two numeric arrays. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [[1.0, 2.0]], "b": [[1.0, 4.0]]}) + >>> result = df.select( + ... dfn.functions.array_distance( + ... dfn.col("a"), dfn.col("b"), + ... ).alias("result")) + >>> result.collect_column("result")[0].as_py() + 2.0 + """ + return Expr(f.array_distance(array1.expr, array2.expr)) + + +def list_distance(array1: Expr, array2: Expr) -> Expr: + """Returns the Euclidean distance between two numeric arrays. + + See Also: + This is an alias for :py:func:`array_distance`. + """ + return array_distance(array1, array2) + + +def array_max(array: Expr) -> Expr: + """Returns the maximum value in the array. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [[1, 2, 3]]}) + >>> result = df.select( + ... dfn.functions.array_max(dfn.col("a")).alias("result")) + >>> result.collect_column("result")[0].as_py() + 3 + """ + return Expr(f.array_max(array.expr)) + + +def list_max(array: Expr) -> Expr: + """Returns the maximum value in the array. + + See Also: + This is an alias for :py:func:`array_max`. + """ + return array_max(array) + + +def array_min(array: Expr) -> Expr: + """Returns the minimum value in the array. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [[1, 2, 3]]}) + >>> result = df.select( + ... dfn.functions.array_min(dfn.col("a")).alias("result")) + >>> result.collect_column("result")[0].as_py() + 1 + """ + return Expr(f.array_min(array.expr)) + + +def list_min(array: Expr) -> Expr: + """Returns the minimum value in the array. + + See Also: + This is an alias for :py:func:`array_min`. + """ + return array_min(array) + + +def array_reverse(array: Expr) -> Expr: + """Reverses the order of elements in the array. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [[1, 2, 3]]}) + >>> result = df.select( + ... 
dfn.functions.array_reverse(dfn.col("a")).alias("result")) + >>> result.collect_column("result")[0].as_py() + [3, 2, 1] + """ + return Expr(f.array_reverse(array.expr)) + + +def list_reverse(array: Expr) -> Expr: + """Reverses the order of elements in the array. + + See Also: + This is an alias for :py:func:`array_reverse`. + """ + return array_reverse(array) + + +def arrays_zip(*arrays: Expr) -> Expr: + """Combines multiple arrays into a single array of structs. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [[1, 2]], "b": [[3, 4]]}) + >>> result = df.select( + ... dfn.functions.arrays_zip(dfn.col("a"), dfn.col("b")).alias("result")) + >>> result.collect_column("result")[0].as_py() + [{'c0': 1, 'c1': 3}, {'c0': 2, 'c1': 4}] + """ + args = [a.expr for a in arrays] + return Expr(f.arrays_zip(args)) + + +def list_zip(*arrays: Expr) -> Expr: + """Combines multiple arrays into a single array of structs. + + See Also: + This is an alias for :py:func:`arrays_zip`. + """ + return arrays_zip(*arrays) + + +def string_to_array( + string: Expr, delimiter: Expr | str, null_string: Expr | str | None = None +) -> Expr: + """Splits a string based on a delimiter and returns an array of parts. + + Any parts matching the optional ``null_string`` will be replaced with ``NULL``. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": ["hello,world"]}) + >>> result = df.select( + ... dfn.functions.string_to_array(dfn.col("a"), ",").alias("result")) + >>> result.collect_column("result")[0].as_py() + ['hello', 'world'] + + Replace parts matching a ``null_string`` with ``NULL``: + + >>> result = df.select( + ... dfn.functions.string_to_array( + ... dfn.col("a"), ",", null_string="world", + ... ).alias("result")) + >>> result.collect_column("result")[0].as_py() + ['hello', None] + """ + delimiter = coerce_to_expr(delimiter) + null_string = coerce_to_expr_or_none(null_string) + return Expr( + f.string_to_array( + string.expr, + delimiter.expr, + null_string.expr if null_string is not None else None, + ) + ) + + +def string_to_list( + string: Expr, delimiter: Expr | str, null_string: Expr | str | None = None +) -> Expr: + """Splits a string based on a delimiter and returns an array of parts. + + See Also: + This is an alias for :py:func:`string_to_array`. + """ + return string_to_array(string, delimiter, null_string) + + +def gen_series(start: Expr, stop: Expr, step: Expr | None = None) -> Expr: + """Creates a list of values in the range between start and stop. + + Unlike :py:func:`range`, this includes the upper bound. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [0]}) + >>> result = df.select( + ... dfn.functions.gen_series( + ... dfn.lit(1), dfn.lit(5), + ... ).alias("result")) + >>> result.collect_column("result")[0].as_py() + [1, 2, 3, 4, 5] + + Specify a custom ``step``: + + >>> result = df.select( + ... dfn.functions.gen_series( + ... dfn.lit(1), dfn.lit(10), step=dfn.lit(3), + ... ).alias("result")) + >>> result.collect_column("result")[0].as_py() + [1, 4, 7, 10] + """ + step_expr = step.expr if step is not None else None + return Expr(f.gen_series(start.expr, stop.expr, step_expr)) + + +def generate_series(start: Expr, stop: Expr, step: Expr | None = None) -> Expr: + """Creates a list of values in the range between start and stop. + + Unlike :py:func:`range`, this includes the upper bound. + + See Also: + This is an alias for :py:func:`gen_series`. 
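+
+    Examples:
+        >>> ctx = dfn.SessionContext()
+        >>> df = ctx.from_pydict({"a": [0]})
+        >>> result = df.select(
+        ...     dfn.functions.generate_series(dfn.lit(1), dfn.lit(3)).alias("r"))
+        >>> result.collect_column("r")[0].as_py()
+        [1, 2, 3]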
+ """ + return gen_series(start, stop, step) + + +def flatten(array: Expr) -> Expr: + """Flattens an array of arrays into a single array. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [[[1, 2], [3, 4]]]}) + >>> result = df.select(dfn.functions.flatten(dfn.col("a")).alias("result")) >>> result.collect_column("result")[0].as_py() [1, 2, 3, 4] """ @@ -3338,6 +4009,158 @@ def empty(array: Expr) -> Expr: return array_empty(array) +# map functions + + +def make_map(*args: Any) -> Expr: + """Returns a map expression. + + Supports three calling conventions: + + - ``make_map({"a": 1, "b": 2})`` — from a Python dictionary. + - ``make_map([keys], [values])`` — from a list of keys and a list of + their associated values. Both lists must be the same length. + - ``make_map(k1, v1, k2, v2, ...)`` — from alternating keys and their + associated values. + + Keys and values that are not already :py:class:`~datafusion.expr.Expr` + are automatically converted to literal expressions. + + Examples: + From a dictionary: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1]}) + >>> result = df.select( + ... dfn.functions.make_map({"a": 1, "b": 2}).alias("m")) + >>> result.collect_column("m")[0].as_py() + [('a', 1), ('b', 2)] + + From two lists: + + >>> df = ctx.from_pydict({"key": ["x", "y"], "val": [10, 20]}) + >>> df = df.select( + ... dfn.functions.make_map( + ... [dfn.col("key")], [dfn.col("val")] + ... ).alias("m")) + >>> df.collect_column("m")[0].as_py() + [('x', 10)] + + From alternating keys and values: + + >>> df = ctx.from_pydict({"a": [1]}) + >>> result = df.select( + ... dfn.functions.make_map("x", 1, "y", 2).alias("m")) + >>> result.collect_column("m")[0].as_py() + [('x', 1), ('y', 2)] + """ + if len(args) == 1 and isinstance(args[0], dict): + key_list = list(args[0].keys()) + value_list = list(args[0].values()) + elif ( + len(args) == 2 # noqa: PLR2004 + and isinstance(args[0], list) + and isinstance(args[1], list) + ): + if len(args[0]) != len(args[1]): + msg = "make_map requires key and value lists to be the same length" + raise ValueError(msg) + key_list = args[0] + value_list = args[1] + elif len(args) >= 2 and len(args) % 2 == 0: # noqa: PLR2004 + key_list = list(args[0::2]) + value_list = list(args[1::2]) + else: + msg = ( + "make_map expects a dict, two lists, or an even number of " + "key-value arguments" + ) + raise ValueError(msg) + + key_exprs = [k if isinstance(k, Expr) else Expr.literal(k) for k in key_list] + val_exprs = [v if isinstance(v, Expr) else Expr.literal(v) for v in value_list] + return Expr(f.make_map([k.expr for k in key_exprs], [v.expr for v in val_exprs])) + + +def map_keys(map: Expr) -> Expr: + """Returns a list of all keys in the map. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1]}) + >>> df = df.select( + ... dfn.functions.make_map({"x": 1, "y": 2}).alias("m")) + >>> result = df.select( + ... dfn.functions.map_keys(dfn.col("m")).alias("keys")) + >>> result.collect_column("keys")[0].as_py() + ['x', 'y'] + """ + return Expr(f.map_keys(map.expr)) + + +def map_values(map: Expr) -> Expr: + """Returns a list of all values in the map. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1]}) + >>> df = df.select( + ... dfn.functions.make_map({"x": 1, "y": 2}).alias("m")) + >>> result = df.select( + ... 
dfn.functions.map_values(dfn.col("m")).alias("vals")) + >>> result.collect_column("vals")[0].as_py() + [1, 2] + """ + return Expr(f.map_values(map.expr)) + + +def map_extract(map: Expr, key: Expr) -> Expr: + """Returns the value for a given key in the map. + + Returns ``[None]`` if the key is absent. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1]}) + >>> df = df.select( + ... dfn.functions.make_map({"x": 1, "y": 2}).alias("m")) + >>> result = df.select( + ... dfn.functions.map_extract( + ... dfn.col("m"), dfn.lit("x") + ... ).alias("val")) + >>> result.collect_column("val")[0].as_py() + [1] + """ + return Expr(f.map_extract(map.expr, key.expr)) + + +def map_entries(map: Expr) -> Expr: + """Returns a list of all entries (key-value struct pairs) in the map. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1]}) + >>> df = df.select( + ... dfn.functions.make_map({"x": 1, "y": 2}).alias("m")) + >>> result = df.select( + ... dfn.functions.map_entries(dfn.col("m")).alias("entries")) + >>> result.collect_column("entries")[0].as_py() + [{'key': 'x', 'value': 1}, {'key': 'y', 'value': 2}] + """ + return Expr(f.map_entries(map.expr)) + + +def element_at(map: Expr, key: Expr) -> Expr: + """Returns the value for a given key in the map. + + Returns ``[None]`` if the key is absent. + + See Also: + This is an alias for :py:func:`map_extract`. + """ + return map_extract(map, key) + + # aggregate functions def approx_distinct( expression: Expr, @@ -3523,6 +4346,60 @@ def approx_percentile_cont_with_weight( ) +def percentile_cont( + sort_expression: Expr | SortExpr, + percentile: float, + filter: Expr | None = None, +) -> Expr: + """Computes the exact percentile of input values using continuous interpolation. + + Unlike :py:func:`approx_percentile_cont`, this function computes the exact + percentile value rather than an approximation. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + + Args: + sort_expression: Values for which to find the percentile + percentile: This must be between 0.0 and 1.0, inclusive + filter: If provided, only compute against rows for which the filter is True + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1.0, 2.0, 3.0, 4.0, 5.0]}) + >>> result = df.aggregate( + ... [], [dfn.functions.percentile_cont( + ... dfn.col("a"), 0.5 + ... ).alias("v")]) + >>> result.collect_column("v")[0].as_py() + 3.0 + + >>> result = df.aggregate( + ... [], [dfn.functions.percentile_cont( + ... dfn.col("a"), 0.5, + ... filter=dfn.col("a") > dfn.lit(1.0), + ... ).alias("v")]) + >>> result.collect_column("v")[0].as_py() + 3.5 + """ + sort_expr_raw = sort_or_default(sort_expression) + filter_raw = filter.expr if filter is not None else None + return Expr(f.percentile_cont(sort_expr_raw, percentile, filter=filter_raw)) + + +def quantile_cont( + sort_expression: Expr | SortExpr, + percentile: float, + filter: Expr | None = None, +) -> Expr: + """Computes the exact percentile of input values using continuous interpolation. + + See Also: + This is an alias for :py:func:`percentile_cont`. 
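+
+    Examples:
+        >>> ctx = dfn.SessionContext()
+        >>> df = ctx.from_pydict({"a": [1.0, 2.0, 3.0, 4.0, 5.0]})
+        >>> result = df.aggregate(
+        ...     [], [dfn.functions.quantile_cont(dfn.col("a"), 0.5).alias("v")])
+        >>> result.collect_column("v")[0].as_py()
+        3.0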
+    """
+    return percentile_cont(sort_expression, percentile, filter)
+
+
 def array_agg(
     expression: Expr,
     distinct: bool = False,
@@ -3581,6 +4458,65 @@ def array_agg(
     )


+def grouping(
+    expression: Expr,
+    distinct: bool = False,
+    filter: Expr | None = None,
+) -> Expr:
+    """Indicates whether the current row aggregates across a column.
+
+    Returns 0 when the column is part of the grouping key for that row
+    (i.e., the row contains per-group results for that column). Returns 1
+    when the column is *not* part of the grouping key (i.e., the row's
+    aggregate spans all values of that column).
+
+    This function is meaningful with
+    :py:meth:`GroupingSet.rollup <datafusion.expr.GroupingSet.rollup>`,
+    :py:meth:`GroupingSet.cube <datafusion.expr.GroupingSet.cube>`, or
+    :py:meth:`GroupingSet.grouping_sets <datafusion.expr.GroupingSet.grouping_sets>`,
+    where different rows are grouped by different subsets of columns. In a
+    default aggregation without grouping sets, every column is always part
+    of the key, so ``grouping()`` always returns 0.
+
+    .. warning::
+
+        Due to an upstream DataFusion limitation
+        (`#21411 <https://github.com/apache/datafusion/issues/21411>`_),
+        ``.alias()`` cannot be applied directly to a ``grouping()``
+        expression. Doing so will raise an error at execution time. To
+        rename the column, use
+        :py:meth:`~datafusion.dataframe.DataFrame.with_column_renamed`
+        on the result DataFrame instead.
+
+    Args:
+        expression: The column to check grouping status for
+        distinct: If True, compute on distinct values only
+        filter: If provided, only compute against rows for which the filter is True
+
+    Examples:
+        With :py:meth:`~datafusion.expr.GroupingSet.rollup`, the result
+        includes both per-group rows (``grouping(a) = 0``) and a
+        grand-total row where ``a`` is aggregated across
+        (``grouping(a) = 1``):
+
+        >>> from datafusion.expr import GroupingSet
+        >>> ctx = dfn.SessionContext()
+        >>> df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]})
+        >>> result = df.aggregate(
+        ...     [GroupingSet.rollup(dfn.col("a"))],
+        ...     [dfn.functions.sum(dfn.col("b")).alias("s"),
+        ...      dfn.functions.grouping(dfn.col("a"))],
+        ... ).sort(dfn.col("a").sort(nulls_first=False))
+        >>> result.collect_column("s").to_pylist()
+        [30, 30, 60]
+
+    See Also:
+        :py:class:`~datafusion.expr.GroupingSet`
+    """
+    filter_raw = filter.expr if filter is not None else None
+    return Expr(f.grouping(expression.expr, distinct=distinct, filter=filter_raw))
+
+
 def avg(
     expression: Expr,
     filter: Expr | None = None,
@@ -4052,6 +4988,15 @@ def var_pop(expression: Expr, filter: Expr | None = None) -> Expr:
     return Expr(f.var_pop(expression.expr, filter=filter_raw))


+def var_population(expression: Expr, filter: Expr | None = None) -> Expr:
+    """Computes the population variance of the argument.
+
+    See Also:
+        This is an alias for :py:func:`var_pop`.
+    """
+    return var_pop(expression, filter)
+
+
 def var_samp(expression: Expr, filter: Expr | None = None) -> Expr:
     """Computes the sample variance of the argument.

diff --git a/python/datafusion/html_formatter.py b/python/datafusion/html_formatter.py
deleted file mode 100644
index 65eb1f042..000000000
--- a/python/datafusion/html_formatter.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Deprecated module for dataframe formatting.""" - -import warnings - -from datafusion.dataframe_formatter import * # noqa: F403 - -warnings.warn( - "The module 'html_formatter' is deprecated and will be removed in the next release." - "Please use 'dataframe_formatter' instead.", - DeprecationWarning, - stacklevel=3, -) diff --git a/python/datafusion/plan.py b/python/datafusion/plan.py index 9c96a18fc..c0cfd523f 100644 --- a/python/datafusion/plan.py +++ b/python/datafusion/plan.py @@ -24,11 +24,15 @@ import datafusion._internal as df_internal if TYPE_CHECKING: + import datetime + from datafusion.context import SessionContext __all__ = [ "ExecutionPlan", "LogicalPlan", + "Metric", + "MetricsSet", ] @@ -151,3 +155,176 @@ def to_proto(self) -> bytes: Tables created in memory from record batches are currently not supported. """ return self._raw_plan.to_proto() + + def metrics(self) -> MetricsSet | None: + """Return metrics for this plan node, or None if this plan has no MetricsSet. + + Some operators (e.g. DataSourceExec) eagerly initialize a MetricsSet + when the plan is created, so this may return a set even before + execution. Metric *values* (such as ``output_rows``) are only + meaningful after the DataFrame has been executed. + """ + raw = self._raw_plan.metrics() + if raw is None: + return None + return MetricsSet(raw) + + def collect_metrics(self) -> list[tuple[str, MetricsSet]]: + """Return runtime statistics for each step of the query execution. + + DataFusion executes a query as a pipeline of operators — for example a + data source scan, followed by a filter, followed by a projection. After + the DataFrame has been executed (via + :py:meth:`~datafusion.DataFrame.collect`, + :py:meth:`~datafusion.DataFrame.execute_stream`, etc.), each operator + records statistics such as how many rows it produced and how much CPU + time it consumed. + + Each entry in the returned list corresponds to one operator that + recorded metrics. The first element of the tuple is the operator's + description string — the same text shown by + :py:meth:`display_indent` — which identifies both the operator type + and its key parameters, for example ``"FilterExec: column1@0 > 1"`` + or ``"DataSourceExec: partitions=1"``. + + Returns: + A list of ``(description, MetricsSet)`` tuples ordered from the + outermost operator (top of the execution tree) down to the + data-source leaves. Only operators that recorded at least one + metric are included. Returns an empty list if called before the + DataFrame has been executed. + """ + result: list[tuple[str, MetricsSet]] = [] + + def _walk(node: ExecutionPlan) -> None: + ms = node.metrics() + if ms is not None: + result.append((node.display(), ms)) + for child in node.children(): + _walk(child) + + _walk(self) + return result + + +class MetricsSet: + """A set of metrics for a single execution plan operator. + + A physical plan operator runs independently across one or more partitions. + :py:meth:`metrics` returns the raw per-partition :py:class:`Metric` objects. + The convenience properties (:py:attr:`output_rows`, :py:attr:`elapsed_compute`, + etc.) 
automatically sum the named metric across *all* partitions, giving a + single aggregate value for the operator as a whole. + """ + + def __init__(self, raw: df_internal.MetricsSet) -> None: + """This constructor should not be called by the end user.""" + self._raw = raw + + def metrics(self) -> list[Metric]: + """Return all individual metrics in this set.""" + return [Metric(m) for m in self._raw.metrics()] + + @property + def output_rows(self) -> int | None: + """Sum of output_rows across all partitions.""" + return self._raw.output_rows() + + @property + def elapsed_compute(self) -> int | None: + """Total CPU time (in nanoseconds) spent inside this operator's execute loop. + + Summed across all partitions. Returns ``None`` if no ``elapsed_compute`` + metric was recorded. + """ + return self._raw.elapsed_compute() + + @property + def spill_count(self) -> int | None: + """Number of times this operator spilled data to disk due to memory pressure. + + This is a count of spill events, not a byte count. Summed across all + partitions. Returns ``None`` if no ``spill_count`` metric was recorded. + """ + return self._raw.spill_count() + + @property + def spilled_bytes(self) -> int | None: + """Sum of spilled_bytes across all partitions.""" + return self._raw.spilled_bytes() + + @property + def spilled_rows(self) -> int | None: + """Sum of spilled_rows across all partitions.""" + return self._raw.spilled_rows() + + def sum_by_name(self, name: str) -> int | None: + """Sum the named metric across all partitions. + + Useful for accessing any metric not exposed as a first-class property. + Returns ``None`` if no metric with the given name was recorded. + + Args: + name: The metric name, e.g. ``"output_rows"`` or ``"elapsed_compute"``. + """ + return self._raw.sum_by_name(name) + + def __repr__(self) -> str: + """Return a string representation of the metrics set.""" + return repr(self._raw) + + +class Metric: + """A single execution metric with name, value, partition, and labels.""" + + def __init__(self, raw: df_internal.Metric) -> None: + """This constructor should not be called by the end user.""" + self._raw = raw + + @property + def name(self) -> str: + """The name of this metric (e.g. ``output_rows``).""" + return self._raw.name + + @property + def value(self) -> int | datetime.datetime | None: + """The value of this metric. + + Returns an ``int`` for counters, gauges, and time-based metrics + (nanoseconds), a :py:class:`~datetime.datetime` (UTC) for + ``start_timestamp`` / ``end_timestamp`` metrics, or ``None`` + when the value has not been set or is not representable. + """ + return self._raw.value + + @property + def value_as_datetime(self) -> datetime.datetime | None: + """The value as a UTC :py:class:`~datetime.datetime` for timestamp metrics. + + Returns ``None`` for all non-timestamp metrics and for timestamp + metrics whose value has not been set (e.g. before execution). + """ + return self._raw.value_as_datetime + + @property + def partition(self) -> int | None: + """The 0-based partition index this metric applies to. + + Returns ``None`` for metrics that are not partition-specific (i.e. they + apply globally across all partitions of the operator). + """ + return self._raw.partition + + def labels(self) -> dict[str, str]: + """Return the labels associated with this metric. + + Labels provide additional context for a metric. 
For example:: + + metric.labels() + # {'output_type': 'final'} + """ + return self._raw.labels() + + def __repr__(self) -> str: + """Return a string representation of the metric.""" + return repr(self._raw) diff --git a/python/datafusion/substrait.py b/python/datafusion/substrait.py index 3115238fa..6353ef8cc 100644 --- a/python/datafusion/substrait.py +++ b/python/datafusion/substrait.py @@ -25,11 +25,6 @@ from typing import TYPE_CHECKING -try: - from warnings import deprecated # Python 3.13+ -except ImportError: - from typing_extensions import deprecated # Python 3.12 - from datafusion.plan import LogicalPlan from ._internal import substrait as substrait_internal @@ -88,11 +83,6 @@ def from_json(json: str) -> Plan: return Plan(substrait_internal.Plan.from_json(json)) -@deprecated("Use `Plan` instead.") -class plan(Plan): # noqa: N801 - """See `Plan`.""" - - class Serde: """Provides the ``Substrait`` serialization and deserialization.""" @@ -158,11 +148,6 @@ def deserialize_bytes(proto_bytes: bytes) -> Plan: return Plan(substrait_internal.Serde.deserialize_bytes(proto_bytes)) -@deprecated("Use `Serde` instead.") -class serde(Serde): # noqa: N801 - """See `Serde` instead.""" - - class Producer: """Generates substrait plans from a logical plan.""" @@ -184,11 +169,6 @@ def to_substrait_plan(logical_plan: LogicalPlan, ctx: SessionContext) -> Plan: ) -@deprecated("Use `Producer` instead.") -class producer(Producer): # noqa: N801 - """Use `Producer` instead.""" - - class Consumer: """Generates a logical plan from a substrait plan.""" @@ -206,8 +186,3 @@ def from_substrait_plan(ctx: SessionContext, plan: Plan) -> LogicalPlan: return LogicalPlan( substrait_internal.Consumer.from_substrait_plan(ctx.ctx, plan.plan_internal) ) - - -@deprecated("Use `Consumer` instead.") -class consumer(Consumer): # noqa: N801 - """Use `Consumer` instead.""" diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py deleted file mode 100644 index c7265fa09..000000000 --- a/python/datafusion/udf.py +++ /dev/null @@ -1,29 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Deprecated module for user defined functions.""" - -import warnings - -from datafusion.user_defined import * # noqa: F403 - -warnings.warn( - "The module 'udf' is deprecated and will be removed in the next release. 
" - "Please use 'user_defined' instead.", - DeprecationWarning, - stacklevel=2, -) diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 3eaccdfa3..848ab4cee 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -213,7 +213,6 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417 Examples: Using ``udf`` as a function: - >>> import pyarrow as pa >>> import pyarrow.compute as pc >>> from datafusion.user_defined import ScalarUDF >>> def double_func(x): @@ -480,7 +479,6 @@ def udaf(*args: Any, **kwargs: Any): # noqa: D417, C901 instance in which this UDAF is used. Examples: - >>> import pyarrow as pa >>> import pyarrow.compute as pc >>> from datafusion.user_defined import AggregateUDF, Accumulator, udaf >>> class Summarize(Accumulator): @@ -874,7 +872,6 @@ def udwf(*args: Any, **kwargs: Any): # noqa: D417 When using ``udwf`` as a decorator, do not pass ``func`` explicitly. Examples: - >>> import pyarrow as pa >>> from datafusion.user_defined import WindowUDF, WindowEvaluator, udwf >>> class BiasedNumbers(WindowEvaluator): ... def __init__(self, start: int = 0): diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 5df6ed20f..e0ebdbae5 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -31,6 +31,7 @@ Table, column, literal, + udf, ) @@ -351,6 +352,125 @@ def test_deregister_table(ctx, database): assert public.names() == {"csv1", "csv2"} +def test_deregister_udf(): + ctx = SessionContext() + + is_null = udf( + lambda x: x.is_null(), + [pa.float64()], + pa.bool_(), + volatility="immutable", + name="my_is_null", + ) + ctx.register_udf(is_null) + + # Verify it works + df = ctx.from_pydict({"a": [1.0, None]}) + ctx.register_table("t", df.into_view()) + result = ctx.sql("SELECT my_is_null(a) FROM t").collect() + assert result[0].column(0) == pa.array([False, True]) + + # Deregister and verify it's gone + ctx.deregister_udf("my_is_null") + with pytest.raises(ValueError): + ctx.sql("SELECT my_is_null(a) FROM t").collect() + + +def test_deregister_udaf(): + import pyarrow.compute as pc + + ctx = SessionContext() + from datafusion import Accumulator, udaf + + class MySum(Accumulator): + def __init__(self): + self._sum = 0.0 + + def update(self, values: pa.Array) -> None: + self._sum += pc.sum(values).as_py() + + def merge(self, states: list[pa.Array]) -> None: + self._sum += pc.sum(states[0]).as_py() + + def state(self) -> list: + return [self._sum] + + def evaluate(self) -> pa.Scalar: + return self._sum + + my_sum = udaf( + MySum, + [pa.float64()], + pa.float64(), + [pa.float64()], + volatility="immutable", + name="my_sum", + ) + ctx.register_udaf(my_sum) + df = ctx.from_pydict({"a": [1.0, 2.0, 3.0]}) + ctx.register_table("t", df.into_view()) + + result = ctx.sql("SELECT my_sum(a) FROM t").collect() + assert result[0].column(0) == pa.array([6.0]) + + ctx.deregister_udaf("my_sum") + with pytest.raises(ValueError): + ctx.sql("SELECT my_sum(a) FROM t").collect() + + +def test_deregister_udwf(): + ctx = SessionContext() + from datafusion import udwf + from datafusion.user_defined import WindowEvaluator + + class MyRowNumber(WindowEvaluator): + def __init__(self): + self._row = 0 + + def evaluate_all(self, values, num_rows): + return pa.array(list(range(1, num_rows + 1)), type=pa.uint64()) + + my_row_number = udwf( + MyRowNumber, + [pa.float64()], + pa.uint64(), + volatility="immutable", + name="my_row_number", + ) + ctx.register_udwf(my_row_number) + df = ctx.from_pydict({"a": [1.0, 2.0, 
3.0]}) + ctx.register_table("t", df.into_view()) + + result = ctx.sql("SELECT my_row_number(a) OVER () FROM t").collect() + assert result[0].column(0) == pa.array([1, 2, 3], type=pa.uint64()) + + ctx.deregister_udwf("my_row_number") + with pytest.raises(ValueError): + ctx.sql("SELECT my_row_number(a) OVER () FROM t").collect() + + +def test_deregister_udtf(): + import pyarrow.dataset as ds + + ctx = SessionContext() + from datafusion import Table, udtf + + class MyTable: + def __call__(self): + batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}) + return Table(ds.dataset([batch])) + + my_table = udtf(MyTable(), "my_table") + ctx.register_udtf(my_table) + + result = ctx.sql("SELECT * FROM my_table()").collect() + assert result[0].column(0) == pa.array([1, 2, 3]) + + ctx.deregister_udtf("my_table") + with pytest.raises(ValueError): + ctx.sql("SELECT * FROM my_table()").collect() + + def test_register_table_from_dataframe(ctx): df = ctx.from_pydict({"a": [1, 2]}) ctx.register_table("df_tbl", df) @@ -551,6 +671,61 @@ def test_table_not_found(ctx): ctx.table(f"not-found-{uuid4()}") +def test_session_start_time(ctx): + import datetime + import re + + st = ctx.session_start_time() + assert isinstance(st, str) + # Truncate nanoseconds to microseconds for Python 3.10 compat + st = re.sub(r"(\.\d{6})\d+", r"\1", st) + dt = datetime.datetime.fromisoformat(st) + assert dt.isoformat() + + +def test_enable_ident_normalization(ctx): + assert ctx.enable_ident_normalization() is True + ctx.sql("SET datafusion.sql_parser.enable_ident_normalization = false") + assert ctx.enable_ident_normalization() is False + + +def test_parse_sql_expr(ctx): + from datafusion.common import DFSchema + + schema = DFSchema.empty() + expr = ctx.parse_sql_expr("1 + 2", schema) + assert str(expr) == "Expr(Int64(1) + Int64(2))" + + +def test_execute_logical_plan(ctx): + df = ctx.from_pydict({"a": [1, 2, 3]}) + plan = df.logical_plan() + df2 = ctx.execute_logical_plan(plan) + result = df2.collect() + assert result[0].column(0) == pa.array([1, 2, 3]) + + +def test_refresh_catalogs(ctx): + ctx.refresh_catalogs() + + +def test_remove_optimizer_rule(ctx): + assert ctx.remove_optimizer_rule("push_down_filter") is True + assert ctx.remove_optimizer_rule("nonexistent_rule") is False + + +def test_table_provider(ctx): + batch = pa.RecordBatch.from_pydict({"x": [10, 20, 30]}) + ctx.register_record_batches("provider_test", [[batch]]) + tbl = ctx.table_provider("provider_test") + assert tbl.schema == pa.schema([("x", pa.int64())]) + + +def test_table_provider_not_found(ctx): + with pytest.raises(KeyError): + ctx.table_provider("nonexistent_table") + + def test_read_json(ctx): path = pathlib.Path(__file__).parent.resolve() @@ -668,6 +843,68 @@ def test_read_avro(ctx): assert avro_df is not None +def test_read_arrow(ctx, tmp_path): + # Write an Arrow IPC file, then read it back + table = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]}) + arrow_path = tmp_path / "test.arrow" + with pa.ipc.new_file(str(arrow_path), table.schema) as writer: + writer.write_table(table) + + df = ctx.read_arrow(str(arrow_path)) + result = df.collect() + assert result[0].column(0) == pa.array([1, 2, 3]) + assert result[0].column(1) == pa.array(["x", "y", "z"]) + + # Also verify pathlib.Path works + df = ctx.read_arrow(arrow_path) + result = df.collect() + assert result[0].column(0) == pa.array([1, 2, 3]) + + +def test_read_empty(ctx): + df = ctx.read_empty() + result = df.collect() + assert len(result) == 1 + assert result[0].num_columns == 0 + + df = 
ctx.empty_table() + result = df.collect() + assert len(result) == 1 + assert result[0].num_columns == 0 + + +def test_register_arrow(ctx, tmp_path): + # Write an Arrow IPC file, then register and query it + table = pa.table({"x": [10, 20, 30]}) + arrow_path = tmp_path / "test.arrow" + with pa.ipc.new_file(str(arrow_path), table.schema) as writer: + writer.write_table(table) + + ctx.register_arrow("arrow_tbl", str(arrow_path)) + result = ctx.sql("SELECT * FROM arrow_tbl").collect() + assert result[0].column(0) == pa.array([10, 20, 30]) + + # Also verify pathlib.Path works + ctx.register_arrow("arrow_tbl_path", arrow_path) + result = ctx.sql("SELECT * FROM arrow_tbl_path").collect() + assert result[0].column(0) == pa.array([10, 20, 30]) + + +def test_register_batch(ctx): + batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}) + ctx.register_batch("batch_tbl", batch) + result = ctx.sql("SELECT * FROM batch_tbl").collect() + assert result[0].column(0) == pa.array([1, 2, 3]) + assert result[0].column(1) == pa.array([4, 5, 6]) + + +def test_register_batch_empty(ctx): + batch = pa.RecordBatch.from_pydict({"a": pa.array([], type=pa.int64())}) + ctx.register_batch("empty_batch_tbl", batch) + result = ctx.sql("SELECT * FROM empty_batch_tbl").collect() + assert result[0].num_rows == 0 + + def test_create_sql_options(): SQLOptions() @@ -727,12 +964,12 @@ def test_csv_read_options_builder_pattern(): options = ( CsvReadOptions() - .with_has_header(False) # noqa: FBT003 + .with_has_header(False) .with_delimiter("|") .with_quote("'") .with_schema_infer_max_records(2000) - .with_truncated_rows(True) # noqa: FBT003 - .with_newlines_in_values(True) # noqa: FBT003 + .with_truncated_rows(True) + .with_newlines_in_values(True) .with_file_extension(".tsv") ) assert options.has_header is False diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 759d6278c..9e2f791ea 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -29,6 +29,7 @@ import pytest from datafusion import ( DataFrame, + ExplainFormat, InsertOp, ParquetColumnOptions, ParquetWriterOptions, @@ -411,6 +412,16 @@ def test_show_empty(df, capsys): assert "DataFrame has no rows" in captured.out +def test_show_on_explain(ctx, capsys): + ctx.sql("explain select 1").show() + captured = capsys.readouterr() + assert "1 as Int64(1)" in captured.out + + ctx.sql("explain analyze select 1").show() + captured = capsys.readouterr() + assert "1 as Int64(1)" in captured.out + + def test_sort(df): df = df.sort(column("b").sort(ascending=False)) @@ -3415,10 +3426,18 @@ def test_fill_null_all_null_column(ctx): assert result.column(1).to_pylist() == ["filled", "filled", "filled"] +_slow_udf_started = threading.Event() + + @udf([pa.int64()], pa.int64(), "immutable") def slow_udf(x: pa.Array) -> pa.Array: - # This must be longer than the check interval in wait_for_future - time.sleep(2.0) + _slow_udf_started.set() + # Sleep in small increments so Python's eval loop checks for pending + # async exceptions (like KeyboardInterrupt via PyThreadState_SetAsyncExc) + # between iterations. A single long time.sleep() is a C call where async + # exceptions are not checked on all Python versions (notably 3.11). 
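+    # The loop still sleeps for ~2 seconds in total, which must remain longer
+    # than the check interval in wait_for_future.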
+ for _ in range(200): + time.sleep(0.01) return x @@ -3452,6 +3471,7 @@ def test_collect_or_stream_interrupted(slow_query, as_c_stream): # noqa: C901 P if as_c_stream: reader = pa.RecordBatchReader.from_stream(df) + _slow_udf_started.clear() read_started = threading.Event() read_exception = [] read_thread_id = None @@ -3463,6 +3483,14 @@ def trigger_interrupt(): msg = f"Read operation did not start within {max_wait_time} seconds" raise RuntimeError(msg) + # For slow_query tests, wait until the UDF is actually executing Python + # bytecode before sending the interrupt. PyThreadState_SetAsyncExc only + # delivers exceptions when the thread is in the Python eval loop, not + # while in native (Rust/C) code. + if slow_query and not _slow_udf_started.wait(timeout=max_wait_time): + msg = f"UDF did not start within {max_wait_time} seconds" + raise RuntimeError(msg) + if read_thread_id is None: msg = "Cannot get read thread ID" raise RuntimeError(msg) @@ -3569,3 +3597,263 @@ def test_read_parquet_file_sort_order(tmp_path, file_sort_order): pa.parquet.write_table(table, path) df = ctx.read_parquet(path, file_sort_order=file_sort_order) assert df.collect()[0].column(0).to_pylist() == [1, 2] + + +@pytest.mark.parametrize( + ("df1_data", "df2_data", "method", "kwargs", "expected_a", "expected_b"), + [ + pytest.param( + {"a": [1, 2, 3, 1], "b": [10, 20, 30, 10]}, + {"a": [1, 2], "b": [10, 20]}, + "except_all", + {"distinct": True}, + [3], + [30], + id="except_all(distinct=True): removes matching rows and deduplicates", + ), + pytest.param( + {"a": [1, 2, 3, 1], "b": [10, 20, 30, 10]}, + {"a": [1, 4], "b": [10, 40]}, + "intersect", + {"distinct": True}, + [1], + [10], + id="intersect(distinct=True): keeps common rows and deduplicates", + ), + pytest.param( + {"a": [1], "b": [10]}, + {"b": [20], "a": [2]}, # reversed column order tests matching by name + "union_by_name", + {}, + [1, 2], + [10, 20], + id="union_by_name: matches columns by name not position", + ), + ], +) +def test_set_operations_distinct( + df1_data, df2_data, method, kwargs, expected_a, expected_b +): + ctx = SessionContext() + df1 = ctx.from_pydict(df1_data) + df2 = ctx.from_pydict(df2_data) + result = ( + getattr(df1, method)(df2, **kwargs) + .sort(column("a").sort(ascending=True)) + .collect()[0] + ) + assert result.column(0).to_pylist() == expected_a + assert result.column(1).to_pylist() == expected_b + + +def test_union_by_name_distinct(): + ctx = SessionContext() + df1 = ctx.from_pydict({"a": [1, 1], "b": [10, 10]}) + df2 = ctx.from_pydict({"b": [10], "a": [1]}) + result = df1.union_by_name(df2, distinct=True).collect()[0] + assert result.column(0).to_pylist() == [1] + assert result.column(1).to_pylist() == [10] + + +def test_column_qualified(): + """DataFrame.column() returns a qualified column expression.""" + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 2], "b": [3, 4]}) + expr = df.column("a") + result = df.select(expr).collect()[0] + assert result.column(0).to_pylist() == [1, 2] + + +def test_column_not_found(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1]}) + with pytest.raises(Exception, match="not found"): + df.column("z") + + +def test_column_ambiguous(): + """After a join, duplicate column names that cannot be resolved raise an error.""" + ctx = SessionContext() + left = ctx.from_pydict({"id": [1, 2], "val": [10, 20]}) + right = ctx.from_pydict({"id": [1, 2], "val": [30, 40]}) + joined = left.join(right, on="id", how="inner") + with pytest.raises(Exception, match="not found"): + 
joined.column("val") + + +def test_column_after_join(): + """Qualified column works for non-ambiguous columns after a join.""" + ctx = SessionContext() + left = ctx.from_pydict({"id": [1, 2], "x": [10, 20]}) + right = ctx.from_pydict({"id": [1, 2], "y": [30, 40]}) + joined = left.join(right, on="id", how="inner") + expr = joined.column("y") + result = joined.select("id", expr).sort("id").collect()[0] + assert result.column(0).to_pylist() == [1, 2] + assert result.column(1).to_pylist() == [30, 40] + + +def test_col_join_disambiguate(): + """Use col() to disambiguate and select columns after a join.""" + ctx = SessionContext() + df1 = ctx.from_pydict({"foo": [1, 2, 3], "bar": [5, 6, 7]}) + df2 = ctx.from_pydict({"foo": [1, 2, 3], "baz": [8, 9, 10]}) + joined = df1.join_on(df2, df1.col("foo") == df2.col("foo"), how="inner") + result = ( + joined.select(df1.col("foo"), df1.col("bar"), df2.col("baz")) + .sort(df1.col("foo")) + .to_pydict() + ) + assert result["bar"] == [5, 6, 7] + assert result["baz"] == [8, 9, 10] + + +def test_find_qualified_columns(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + exprs = df.find_qualified_columns("a", "c") + assert len(exprs) == 2 + result = df.select(*exprs).collect()[0] + assert result.column(0).to_pylist() == [1, 2] + assert result.column(1).to_pylist() == [5, 6] + + +def test_find_qualified_columns_not_found(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1]}) + with pytest.raises(Exception, match="not found"): + df.find_qualified_columns("a", "z") + + +def test_distinct_on(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 1, 2, 2], "b": [10, 20, 30, 40]}) + result = ( + df.distinct_on( + [column("a")], + [column("a"), column("b")], + [column("a").sort(ascending=True), column("b").sort(ascending=True)], + ) + .sort(column("a").sort(ascending=True)) + .collect()[0] + ) + # Keeps the first row per group (smallest b per a) + assert result.column(0).to_pylist() == [1, 2] + assert result.column(1).to_pylist() == [10, 30] + + +@pytest.mark.parametrize( + ("input_values", "expected"), + [ + ([3, 1, 2], [1, 2, 3]), + ([1, 2, 3], [1, 2, 3]), + ([3, None, 1, 2], [1, 2, 3, None]), + ], +) +def test_sort_by(input_values, expected): + """sort_by always sorts ascending with nulls last regardless of input order.""" + ctx = SessionContext() + df = ctx.from_pydict({"a": input_values}) + result = df.sort_by(column("a")).collect()[0] + assert result.column(0).to_pylist() == expected + + +@pytest.mark.parametrize( + ("fmt", "verbose", "analyze", "expected_substring"), + [ + pytest.param(None, False, False, None, id="default format"), + pytest.param(ExplainFormat.TREE, False, False, "---", id="tree format"), + pytest.param( + ExplainFormat.INDENT, True, True, None, id="indent verbose+analyze" + ), + pytest.param(ExplainFormat.PGJSON, False, False, '"Plan"', id="pgjson format"), + pytest.param( + ExplainFormat.GRAPHVIZ, False, False, "digraph", id="graphviz format" + ), + ], +) +def test_explain_with_format(capsys, fmt, verbose, analyze, expected_substring): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1]}) + df.explain(verbose=verbose, analyze=analyze, format=fmt) + captured = capsys.readouterr() + assert "plan_type" in captured.out + if expected_substring is not None: + assert expected_substring in captured.out + + +@pytest.mark.parametrize( + ("window_exprs", "expected_columns"), + [ + pytest.param( + lambda: [ + f.row_number(partition_by=[column("b")], order_by=[column("a")]).alias( + "rn" + ), + 
], + {"rn": [1, 2, 1]}, + id="single window expression", + ), + pytest.param( + lambda: [ + f.row_number(partition_by=[column("b")], order_by=[column("a")]).alias( + "rn" + ), + f.rank(partition_by=[column("b")], order_by=[column("a")]).alias("rnk"), + ], + {"rn": [1, 2, 1], "rnk": [1, 2, 1]}, + id="multiple window expressions", + ), + ], +) +def test_window(window_exprs, expected_columns): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 2, 3], "b": ["x", "x", "y"]}) + result = ( + df.window(*window_exprs()).sort(column("a").sort(ascending=True)).collect()[0] + ) + for col_name, expected_values in expected_columns.items(): + assert col_name in result.schema.names + assert ( + result.column(result.schema.get_field_index(col_name)).to_pylist() + == expected_values + ) + + +@pytest.mark.parametrize( + ("input_data", "recursions", "expected_a"), + [ + pytest.param( + {"a": [[1, 2], [3]], "b": ["x", "y"]}, + None, + [1, 2, 3], + id="basic unnest without recursions", + ), + pytest.param( + {"a": [[1, 2], [3]], "b": ["x", "y"]}, + [("a", "a", 1)], + [1, 2, 3], + id="explicit depth 1 matches basic unnest", + ), + pytest.param( + {"a": [[[1, 2], [3]], [[4]]], "b": ["x", "y"]}, + [("a", "a", 1)], + [[1, 2], [3], [4]], + id="depth 1 on nested lists keeps inner lists", + ), + pytest.param( + {"a": [[[1, 2], [3]], [[4]]], "b": ["x", "y"]}, + [("a", "a", 2)], + [1, 2, 3, 4], + id="depth 2 fully flattens nested lists", + ), + ], +) +def test_unnest_columns_with_recursions(input_data, recursions, expected_a): + ctx = SessionContext() + df = ctx.from_pydict(input_data) + kwargs = {} + if recursions is not None: + kwargs["recursions"] = recursions + result = df.unnest_columns("a", **kwargs).collect()[0] + assert result.column(0).to_pylist() == expected_a diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 9a287c1f7..8aa791ae1 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -53,6 +53,8 @@ TransactionEnd, TransactionStart, Values, + coerce_to_expr, + coerce_to_expr_or_none, ensure_expr, ensure_expr_list, ) @@ -153,8 +155,8 @@ def test_relational_expr(test_ctx): batch = pa.RecordBatch.from_arrays( [ - pa.array([1, 2, 3]), - pa.array(["alpha", "beta", "gamma"], type=pa.string_view()), + pa.array([1, 2, 3, None]), + pa.array(["alpha", "beta", "gamma", None], type=pa.string_view()), ], names=["a", "b"], ) @@ -171,6 +173,10 @@ def test_relational_expr(test_ctx): assert df.filter(col("b") != "beta").count() == 2 assert df.filter(col("a") == "beta").count() == 0 + assert df.filter(col("a") == None).count() == 1 # noqa: E711 + assert df.filter(col("a") != None).count() == 3 # noqa: E711 + assert df.filter(col("b") == None).count() == 1 # noqa: E711 + assert df.filter(col("b") != None).count() == 3 # noqa: E711 def test_expr_to_variant(): @@ -319,27 +325,6 @@ def test_expr_getitem() -> None: assert array_values == [2, 5, None, None] -def test_display_name_deprecation(): - import warnings - - expr = col("foo") - with warnings.catch_warnings(record=True) as w: - # Cause all warnings to always be triggered - warnings.simplefilter("always") - - # should trigger warning - name = expr.display_name() - - # Verify some things - assert len(w) == 1 - assert issubclass(w[-1].category, DeprecationWarning) - assert "deprecated" in str(w[-1].message) - - # returns appropriate result - assert name == expr.schema_name() - assert name == "foo" - - @pytest.fixture def df(): ctx = SessionContext() @@ -1047,12 +1032,55 @@ def test_ensure_expr_list_bytearray(): 
ensure_expr_list(bytearray(b"a")) +def test_coerce_to_expr_passes_expr_through(): + e = col("a") + result = coerce_to_expr(e) + assert isinstance(result, type(e)) + assert str(result) == str(e) + + +def test_coerce_to_expr_wraps_int(): + result = coerce_to_expr(42) + assert isinstance(result, type(lit(42))) + + +def test_coerce_to_expr_wraps_str(): + result = coerce_to_expr("hello") + assert isinstance(result, type(lit("hello"))) + + +def test_coerce_to_expr_wraps_float(): + result = coerce_to_expr(3.14) + assert isinstance(result, type(lit(3.14))) + + +def test_coerce_to_expr_wraps_bool(): + result = coerce_to_expr(True) + assert isinstance(result, type(lit(True))) + + +def test_coerce_to_expr_or_none_returns_none(): + assert coerce_to_expr_or_none(None) is None + + +def test_coerce_to_expr_or_none_wraps_value(): + result = coerce_to_expr_or_none(42) + assert isinstance(result, type(lit(42))) + + +def test_coerce_to_expr_or_none_passes_expr_through(): + e = col("a") + result = coerce_to_expr_or_none(e) + assert isinstance(result, type(e)) + assert str(result) == str(e) + + @pytest.mark.parametrize( "value", [ # Boolean - pa.scalar(True, type=pa.bool_()), # noqa: FBT003 - pa.scalar(False, type=pa.bool_()), # noqa: FBT003 + pa.scalar(True, type=pa.bool_()), + pa.scalar(False, type=pa.bool_()), # Integers - signed pa.scalar(127, type=pa.int8()), pa.scalar(-128, type=pa.int8()), diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 37d349c58..d9781b1fb 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -20,8 +20,9 @@ import numpy as np import pyarrow as pa import pytest -from datafusion import SessionContext, column, literal, string_literal +from datafusion import SessionContext, column, literal from datafusion import functions as f +from datafusion.expr import GroupingSet np.seterr(invalid="ignore") @@ -330,6 +331,10 @@ def py_flatten(arr): f.empty, lambda data: [len(r) == 0 for r in data], ), + ( + f.list_empty, + lambda data: [len(r) == 0 for r in data], + ), ( lambda col: f.array_extract(col, literal(1)), lambda data: [r[0] for r in data], @@ -354,18 +359,54 @@ def py_flatten(arr): lambda col: f.array_has(col, literal(1.0)), lambda data: [1.0 in r for r in data], ), + ( + lambda col: f.list_has(col, literal(1.0)), + lambda data: [1.0 in r for r in data], + ), + ( + lambda col: f.array_contains(col, literal(1.0)), + lambda data: [1.0 in r for r in data], + ), + ( + lambda col: f.list_contains(col, literal(1.0)), + lambda data: [1.0 in r for r in data], + ), ( lambda col: f.array_has_all( col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]]) ), lambda data: [np.all([v in r for v in [1.0, 3.0, 5.0]]) for r in data], ), + ( + lambda col: f.list_has_all( + col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]]) + ), + lambda data: [np.all([v in r for v in [1.0, 3.0, 5.0]]) for r in data], + ), ( lambda col: f.array_has_any( col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]]) ), lambda data: [np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in data], ), + ( + lambda col: f.list_has_any( + col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]]) + ), + lambda data: [np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in data], + ), + ( + lambda col: f.arrays_overlap( + col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]]) + ), + lambda data: [np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in data], + ), + ( + lambda col: f.list_overlap( + col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]]) + ), + lambda data: 
[np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in data], + ), ( lambda col: f.array_position(col, literal(1.0)), lambda data: [py_indexof(r, 1.0) for r in data], @@ -418,10 +459,18 @@ def py_flatten(arr): f.array_pop_back, lambda data: [arr[:-1] for arr in data], ), + ( + f.list_pop_back, + lambda data: [arr[:-1] for arr in data], + ), ( f.array_pop_front, lambda data: [arr[1:] for arr in data], ), + ( + f.list_pop_front, + lambda data: [arr[1:] for arr in data], + ), ( lambda col: f.array_remove(col, literal(3.0)), lambda data: [py_arr_remove(arr, 3.0, 1) for arr in data], @@ -668,6 +717,106 @@ def test_array_function_obj_tests(stmt, py_expr): assert a == b +@pytest.mark.parametrize( + ("args", "expected"), + [ + pytest.param( + ({"x": 1, "y": 2},), + [("x", 1), ("y", 2)], + id="dict", + ), + pytest.param( + ({"x": literal(1), "y": literal(2)},), + [("x", 1), ("y", 2)], + id="dict_with_exprs", + ), + pytest.param( + ("x", 1, "y", 2), + [("x", 1), ("y", 2)], + id="variadic_pairs", + ), + pytest.param( + (literal("x"), literal(1), literal("y"), literal(2)), + [("x", 1), ("y", 2)], + id="variadic_with_exprs", + ), + ], +) +def test_make_map(args, expected): + ctx = SessionContext() + batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"]) + df = ctx.create_dataframe([[batch]]) + + result = df.select(f.make_map(*args).alias("m")).collect()[0].column(0) + assert result[0].as_py() == expected + + +def test_make_map_from_two_lists(): + ctx = SessionContext() + batch = pa.RecordBatch.from_arrays( + [ + pa.array(["k1", "k2", "k3"]), + pa.array([10, 20, 30]), + ], + names=["keys", "vals"], + ) + df = ctx.create_dataframe([[batch]]) + + m = f.make_map([column("keys")], [column("vals")]) + result = df.select(f.map_keys(m).alias("k")).collect()[0].column(0) + assert result.to_pylist() == [["k1"], ["k2"], ["k3"]] + + result = df.select(f.map_values(m).alias("v")).collect()[0].column(0) + assert result.to_pylist() == [[10], [20], [30]] + + +def test_make_map_odd_args_raises(): + with pytest.raises(ValueError, match="make_map expects"): + f.make_map("x", 1, "y") + + +def test_make_map_mismatched_lengths(): + with pytest.raises(ValueError, match="same length"): + f.make_map(["a", "b"], [1]) + + +@pytest.mark.parametrize( + ("func", "expected"), + [ + pytest.param(f.map_keys, ["x", "y"], id="map_keys"), + pytest.param(f.map_values, [1, 2], id="map_values"), + pytest.param( + lambda m: f.map_extract(m, literal("x")), + [1], + id="map_extract", + ), + pytest.param( + lambda m: f.map_extract(m, literal("z")), + [None], + id="map_extract_missing_key", + ), + pytest.param( + f.map_entries, + [{"key": "x", "value": 1}, {"key": "y", "value": 2}], + id="map_entries", + ), + pytest.param( + lambda m: f.element_at(m, literal("y")), + [2], + id="element_at", + ), + ], +) +def test_map_functions(func, expected): + ctx = SessionContext() + batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"]) + df = ctx.create_dataframe([[batch]]) + + m = f.make_map({"x": 1, "y": 2}) + result = df.select(func(m).alias("out")).collect()[0].column(0) + assert result[0].as_py() == expected + + @pytest.mark.parametrize( ("function", "expected_result"), [ @@ -745,6 +894,7 @@ def test_array_function_obj_tests(stmt, py_expr): f.split_part(column("a"), literal("l"), literal(1)), pa.array(["He", "Wor", "!"]), ), + (f.contains(column("a"), literal("ell")), pa.array([True, False, False])), (f.starts_with(column("a"), literal("Wor")), pa.array([False, True, False])), (f.strpos(column("a"), literal("o")), pa.array([5, 2, 0], 
type=pa.int32())), ( @@ -1107,13 +1257,43 @@ def test_today_alias_matches_current_date(df): assert result.column(0) == result.column(1) +def test_current_timestamp_alias_matches_now(df): + result = df.select( + f.now().alias("now"), + f.current_timestamp().alias("current_timestamp"), + ).collect()[0] + + assert result.column(0) == result.column(1) + + +def test_date_format_alias_matches_to_char(df): + result = df.select( + f.to_char( + f.to_timestamp(literal("2021-01-01T00:00:00")), literal("%Y/%m/%d") + ).alias("to_char"), + f.date_format( + f.to_timestamp(literal("2021-01-01T00:00:00")), literal("%Y/%m/%d") + ).alias("date_format"), + ).collect()[0] + + assert result.column(0) == result.column(1) + assert result.column(0)[0].as_py() == "2021/01/01" + + +def test_make_time(df): + ctx = SessionContext() + df_time = ctx.from_pydict({"h": [12], "m": [30], "s": [0]}) + result = df_time.select( + f.make_time(column("h"), column("m"), column("s")).alias("t") + ).collect()[0] + + assert result.column(0)[0].as_py() == time(12, 30) + + def test_arrow_cast(df): df = df.select( - # we use `string_literal` to return utf8 instead of `literal` which returns - # utf8view because datafusion.arrow_cast expects a utf8 instead of utf8view - # https://github.com/apache/datafusion/blob/86740bfd3d9831d6b7c1d0e1bf4a21d91598a0ac/datafusion/functions/src/core/arrow_cast.rs#L179 - f.arrow_cast(column("b"), string_literal("Float64")).alias("b_as_float"), - f.arrow_cast(column("b"), string_literal("Int32")).alias("b_as_int"), + f.arrow_cast(column("b"), "Float64").alias("b_as_float"), + f.arrow_cast(column("b"), "Int32").alias("b_as_int"), ) result = df.collect() assert len(result) == 1 @@ -1123,6 +1303,19 @@ def test_arrow_cast(df): assert result.column(1) == pa.array([4, 5, 6], type=pa.int32()) +def test_arrow_cast_with_pyarrow_type(df): + df = df.select( + f.arrow_cast(column("b"), pa.float64()).alias("b_as_float"), + f.arrow_cast(column("b"), pa.int32()).alias("b_as_int"), + f.arrow_cast(column("b"), pa.string()).alias("b_as_str"), + ) + result = df.collect()[0] + + assert result.column(0) == pa.array([4.0, 5.0, 6.0], type=pa.float64()) + assert result.column(1) == pa.array([4, 5, 6], type=pa.int32()) + assert result.column(2) == pa.array(["4", "5", "6"], type=pa.string()) + + def test_case(df): df = df.select( f.case(column("b")).when(literal(4), literal(10)).otherwise(literal(8)), @@ -1376,62 +1569,626 @@ def test_alias_with_metadata(df): assert df.schema().field("b").metadata == {b"key": b"value"} -def test_coalesce(df): - # Create a DataFrame with null values +@pytest.fixture +def df_with_nulls(): ctx = SessionContext() + # Rows: + # 0: both values present + # 1: a/d/h/k null, b/e/i/l present + # 2: a/d/h/k present, b/e/i/l null + # 3: all null batch = pa.RecordBatch.from_arrays( [ - pa.array(["Hello", None, "!"]), # string column with null - pa.array([4, None, 6]), # integer column with null - pa.array(["hello ", None, " !"]), # string column with null + pa.array([1, None, 3, None], type=pa.int64()), + pa.array([5, 10, None, None], type=pa.int64()), + pa.array([20, 30, 40, None], type=pa.int64()), + pa.array(["apple", None, "cherry", None], type=pa.utf8()), + pa.array(["banana", "date", None, None], type=pa.utf8()), + pa.array(["x", "y", "z", None], type=pa.utf8()), pa.array( [ - datetime(2022, 12, 31, tzinfo=DEFAULT_TZ), + datetime(2020, 1, 1, tzinfo=DEFAULT_TZ), None, - datetime(2020, 7, 2, tzinfo=DEFAULT_TZ), - ] - ), # datetime with null - pa.array([False, None, True]), # boolean column with null + 
datetime(2025, 6, 15, tzinfo=DEFAULT_TZ), + None, + ], + type=pa.timestamp("us", tz="UTC"), + ), + pa.array( + [ + datetime(2022, 7, 4, tzinfo=DEFAULT_TZ), + datetime(2023, 12, 25, tzinfo=DEFAULT_TZ), + None, + None, + ], + type=pa.timestamp("us", tz="UTC"), + ), + pa.array([True, None, False, None], type=pa.bool_()), + pa.array([False, True, None, None], type=pa.bool_()), ], - names=["a", "b", "c", "d", "e"], + names=["a", "b", "c", "d", "e", "g", "h", "i", "k", "l"], ) - df_with_nulls = ctx.create_dataframe([[batch]]) + return ctx.create_dataframe([[batch]]) + - # Test coalesce with different data types - result_df = df_with_nulls.select( - f.coalesce(column("a"), literal("default")).alias("a_coalesced"), - f.coalesce(column("b"), literal(0)).alias("b_coalesced"), - f.coalesce(column("c"), literal("default")).alias("c_coalesced"), - f.coalesce(column("d"), literal(datetime(2000, 1, 1, tzinfo=DEFAULT_TZ))).alias( - "d_coalesced" +@pytest.mark.parametrize( + ("expr", "expected"), + [ + pytest.param( + f.greatest(column("a"), column("b")), + pa.array([5, 10, 3, None], type=pa.int64()), + id="greatest_int", ), - f.coalesce(column("e"), literal(value=False)).alias("e_coalesced"), - ) + pytest.param( + f.greatest(column("d"), column("e")), + pa.array(["banana", "date", "cherry", None], type=pa.utf8()), + id="greatest_str", + ), + pytest.param( + f.least(column("a"), column("b")), + pa.array([1, 10, 3, None], type=pa.int64()), + id="least_int", + ), + pytest.param( + f.least(column("d"), column("e")), + pa.array(["apple", "date", "cherry", None], type=pa.utf8()), + id="least_str", + ), + pytest.param( + f.coalesce(column("a"), column("b"), column("c")), + pa.array([1, 10, 3, None], type=pa.int64()), + id="coalesce_int", + ), + pytest.param( + f.coalesce(column("d"), column("e"), column("g")), + pa.array(["apple", "date", "cherry", None], type=pa.utf8()), + id="coalesce_str", + ), + pytest.param( + f.nvl(column("a"), column("c")), + pa.array([1, 30, 3, None], type=pa.int64()), + id="nvl_int", + ), + pytest.param( + f.nvl(column("d"), column("g")), + pa.array(["apple", "y", "cherry", None], type=pa.utf8()), + id="nvl_str", + ), + pytest.param( + f.ifnull(column("a"), column("c")), + pa.array([1, 30, 3, None], type=pa.int64()), + id="ifnull_int", + ), + pytest.param( + f.ifnull(column("d"), column("g")), + pa.array(["apple", "y", "cherry", None], type=pa.utf8()), + id="ifnull_str", + ), + pytest.param( + f.nvl2(column("a"), column("b"), column("c")), + pa.array([5, 30, None, None], type=pa.int64()), + id="nvl2_int", + ), + pytest.param( + f.nvl2(column("d"), column("e"), column("g")), + pa.array(["banana", "y", None, None], type=pa.utf8()), + id="nvl2_str", + ), + pytest.param( + f.nullif(column("a"), column("b")), + pa.array([1, None, 3, None], type=pa.int64()), + id="nullif_int", + ), + pytest.param( + f.nullif(column("d"), column("e")), + pa.array(["apple", None, "cherry", None], type=pa.utf8()), + id="nullif_str", + ), + pytest.param( + f.nullif(column("a"), literal(1)), + pa.array([None, None, 3, None], type=pa.int64()), + id="nullif_equal_values", + ), + pytest.param( + f.greatest(column("a"), column("b"), column("c")), + pa.array([20, 30, 40, None], type=pa.int64()), + id="greatest_variadic", + ), + pytest.param( + f.least(column("a"), column("b"), column("c")), + pa.array([1, 10, 3, None], type=pa.int64()), + id="least_variadic", + ), + pytest.param( + f.greatest(column("a"), literal(2)), + pa.array([2, 2, 3, 2], type=pa.int64()), + id="greatest_literal", + ), + pytest.param( + 
f.least(column("a"), literal(2)), + pa.array([1, 2, 2, 2], type=pa.int64()), + id="least_literal", + ), + pytest.param( + f.coalesce(column("a"), literal(0)), + pa.array([1, 0, 3, 0], type=pa.int64()), + id="coalesce_literal_int", + ), + pytest.param( + f.coalesce(column("d"), literal("default")), + pa.array(["apple", "default", "cherry", "default"], type=pa.string_view()), + id="coalesce_literal_str", + ), + pytest.param( + f.nvl(column("a"), literal(99)), + pa.array([1, 99, 3, 99], type=pa.int64()), + id="nvl_literal", + ), + pytest.param( + f.ifnull(column("d"), literal("unknown")), + pa.array(["apple", "unknown", "cherry", "unknown"], type=pa.string_view()), + id="ifnull_literal", + ), + pytest.param( + f.nvl2(column("a"), literal(1), literal(0)), + pa.array([1, 0, 1, 0], type=pa.int64()), + id="nvl2_literal", + ), + pytest.param( + f.greatest(column("h"), column("i")), + pa.array( + [ + datetime(2022, 7, 4, tzinfo=DEFAULT_TZ), + datetime(2023, 12, 25, tzinfo=DEFAULT_TZ), + datetime(2025, 6, 15, tzinfo=DEFAULT_TZ), + None, + ], + type=pa.timestamp("us", tz="UTC"), + ), + id="greatest_datetime", + ), + pytest.param( + f.least(column("h"), column("i")), + pa.array( + [ + datetime(2020, 1, 1, tzinfo=DEFAULT_TZ), + datetime(2023, 12, 25, tzinfo=DEFAULT_TZ), + datetime(2025, 6, 15, tzinfo=DEFAULT_TZ), + None, + ], + type=pa.timestamp("us", tz="UTC"), + ), + id="least_datetime", + ), + pytest.param( + f.coalesce(column("h"), column("i")), + pa.array( + [ + datetime(2020, 1, 1, tzinfo=DEFAULT_TZ), + datetime(2023, 12, 25, tzinfo=DEFAULT_TZ), + datetime(2025, 6, 15, tzinfo=DEFAULT_TZ), + None, + ], + type=pa.timestamp("us", tz="UTC"), + ), + id="coalesce_datetime", + ), + pytest.param( + f.nvl(column("k"), column("l")), + pa.array([True, True, False, None], type=pa.bool_()), + id="nvl_bool", + ), + pytest.param( + f.coalesce(column("k"), column("l")), + pa.array([True, True, False, None], type=pa.bool_()), + id="coalesce_bool", + ), + pytest.param( + f.nvl2(column("k"), column("k"), column("l")), + pa.array([True, True, False, None], type=pa.bool_()), + id="nvl2_bool", + ), + pytest.param( + f.coalesce( + column("h"), + literal(datetime(2000, 1, 1, tzinfo=DEFAULT_TZ)), + ), + pa.array( + [ + datetime(2020, 1, 1, tzinfo=DEFAULT_TZ), + datetime(2000, 1, 1, tzinfo=DEFAULT_TZ), + datetime(2025, 6, 15, tzinfo=DEFAULT_TZ), + datetime(2000, 1, 1, tzinfo=DEFAULT_TZ), + ], + type=pa.timestamp("us", tz="UTC"), + ), + id="coalesce_literal_datetime", + ), + pytest.param( + f.coalesce(column("k"), literal(value=False)), + pa.array([True, False, False, False], type=pa.bool_()), + id="coalesce_literal_bool", + ), + pytest.param( + f.coalesce(column("a"), literal(None), literal(99)), + pa.array([1, 99, 3, 99], type=pa.int64()), + id="coalesce_skip_null_literal", + ), + ], +) +def test_conditional_functions(df_with_nulls, expr, expected): + result = df_with_nulls.select(expr.alias("result")).collect()[0] + assert result.column(0) == expected - result = result_df.collect()[0] - # Verify results - assert result.column(0) == pa.array( - ["Hello", "default", "!"], type=pa.string_view() - ) - assert result.column(1) == pa.array([4, 0, 6], type=pa.int64()) - assert result.column(2) == pa.array( - ["hello ", "default", " !"], type=pa.string_view() +@pytest.mark.parametrize( + ("func", "filter_expr", "expected"), + [ + (f.percentile_cont, None, 3.0), + (f.percentile_cont, column("a") > literal(1.0), 3.5), + (f.quantile_cont, None, 3.0), + ], + ids=["no_filter", "with_filter", "quantile_cont_alias"], +) +def 
test_percentile_cont(func, filter_expr, expected): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1.0, 2.0, 3.0, 4.0, 5.0]}) + result = df.aggregate( + [], [func(column("a"), 0.5, filter=filter_expr).alias("v")] + ).collect()[0] + assert result.column(0)[0].as_py() == expected + + +@pytest.mark.parametrize( + ("grouping_set_expr", "expected_grouping", "expected_sums"), + [ + (GroupingSet.rollup(column("a")), [0, 0, 1], [30, 30, 60]), + (GroupingSet.cube(column("a")), [0, 0, 1], [30, 30, 60]), + (GroupingSet.rollup("a"), [0, 0, 1], [30, 30, 60]), + (GroupingSet.cube("a"), [0, 0, 1], [30, 30, 60]), + ], + ids=["rollup", "cube", "rollup_str", "cube_str"], +) +def test_grouping_set_single_column( + grouping_set_expr, expected_grouping, expected_sums +): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]}) + result = df.aggregate( + [grouping_set_expr], + [f.sum(column("b")).alias("s"), f.grouping(column("a"))], + ).sort(column("a").sort(ascending=True, nulls_first=False)) + batches = result.collect() + g = pa.concat_arrays([b.column(2) for b in batches]).to_pylist() + s = pa.concat_arrays([b.column("s") for b in batches]).to_pylist() + assert g == expected_grouping + assert s == expected_sums + + +@pytest.mark.parametrize( + ("grouping_set_expr", "expected_rows"), + [ + # rollup(a, b) => (a,b), (a), () => 3 + 2 + 1 = 6 + (GroupingSet.rollup(column("a"), column("b")), 6), + # cube(a, b) => (a,b), (a), (b), () => 3 + 2 + 2 + 1 = 8 + (GroupingSet.cube(column("a"), column("b")), 8), + (GroupingSet.rollup("a", "b"), 6), + (GroupingSet.cube("a", "b"), 8), + ], + ids=["rollup", "cube", "rollup_str", "cube_str"], +) +def test_grouping_set_multi_column(grouping_set_expr, expected_rows): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 1, 2], "b": ["x", "y", "x"], "c": [10, 20, 30]}) + result = df.aggregate( + [grouping_set_expr], + [f.sum(column("c")).alias("s")], ) - assert result.column(3).to_pylist() == [ - datetime(2022, 12, 31, tzinfo=DEFAULT_TZ), - datetime(2000, 1, 1, tzinfo=DEFAULT_TZ), - datetime(2020, 7, 2, tzinfo=DEFAULT_TZ), - ] - assert result.column(4) == pa.array([False, False, True], type=pa.bool_()) + total_rows = sum(b.num_rows for b in result.collect()) + assert total_rows == expected_rows - # Test multiple arguments - result_df = df_with_nulls.select( - f.coalesce(column("a"), literal(None), literal("fallback")).alias( - "multi_coalesce" - ) + +@pytest.mark.parametrize( + "grouping_set_expr", + [ + GroupingSet.grouping_sets([column("a")], [column("b")]), + GroupingSet.grouping_sets(["a"], ["b"]), + ], + ids=["expr", "str"], +) +def test_grouping_sets_explicit(grouping_set_expr): + # Each row's grouping() value tells you which columns are aggregated across. 
+ ctx = SessionContext() + df = ctx.from_pydict({"a": ["x", "x", "y"], "b": ["m", "n", "m"], "c": [1, 2, 3]}) + result = df.aggregate( + [grouping_set_expr], + [ + f.sum(column("c")).alias("s"), + f.grouping(column("a")), + f.grouping(column("b")), + ], + ).sort( + column("a").sort(ascending=True, nulls_first=False), + column("b").sort(ascending=True, nulls_first=False), ) - result = result_df.collect()[0] - assert result.column(0) == pa.array( - ["Hello", "fallback", "!"], type=pa.string_view() + batches = result.collect() + ga = pa.concat_arrays([b.column(3) for b in batches]).to_pylist() + gb = pa.concat_arrays([b.column(4) for b in batches]).to_pylist() + # Rows grouped by (a): ga=0 (a is a key), gb=1 (b is aggregated across) + # Rows grouped by (b): ga=1 (a is aggregated across), gb=0 (b is a key) + assert ga == [0, 0, 1, 1] + assert gb == [1, 1, 0, 0] + + +def test_var_population(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [-1.0, 0.0, 2.0]}) + result = df.aggregate([], [f.var_population(column("a")).alias("v")]).collect()[0] + # var_population is an alias for var_pop + expected = df.aggregate([], [f.var_pop(column("a")).alias("v")]).collect()[0] + assert abs(result.column(0)[0].as_py() - expected.column(0)[0].as_py()) < 1e-10 + + +def test_get_field(df): + df = df.with_column( + "s", + f.named_struct( + [ + ("x", column("a")), + ("y", column("b")), + ] + ), ) + result = df.select( + f.get_field(column("s"), "x").alias("x_val"), + f.get_field(column("s"), "y").alias("y_val"), + ).collect()[0] + + assert result.column(0) == pa.array(["Hello", "World", "!"], type=pa.string_view()) + assert result.column(1) == pa.array([4, 5, 6]) + + +def test_arrow_metadata(): + ctx = SessionContext() + field = pa.field("val", pa.int64(), metadata={"key1": "value1", "key2": "value2"}) + schema = pa.schema([field]) + batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], schema=schema) + df = ctx.create_dataframe([[batch]]) + + # One-argument form: returns a Map of all metadata key-value pairs + result = df.select( + f.arrow_metadata(column("val")).alias("meta"), + ).collect()[0] + assert result.column(0).type == pa.map_(pa.utf8(), pa.utf8()) + meta = result.column(0)[0].as_py() + assert ("key1", "value1") in meta + assert ("key2", "value2") in meta + + # Two-argument form: returns the value for a specific metadata key + result = df.select( + f.arrow_metadata(column("val"), "key1").alias("meta_val"), + ).collect()[0] + assert result.column(0)[0].as_py() == "value1" + + +def test_version(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1]}) + result = df.select(f.version().alias("v")).collect()[0] + version_str = result.column(0)[0].as_py() + assert "Apache DataFusion" in version_str + + +def test_row(df): + result = df.select( + f.row(column("a"), column("b")).alias("r"), + f.struct(column("a"), column("b")).alias("s"), + ).collect()[0] + # row is an alias for struct, so they should produce the same output + assert result.column(0) == result.column(1) + + +def test_union_tag(): + ctx = SessionContext() + types = pa.array([0, 1, 0], type=pa.int8()) + offsets = pa.array([0, 0, 1], type=pa.int32()) + children = [pa.array([1, 2]), pa.array(["hello"])] + arr = pa.UnionArray.from_dense(types, offsets, children, ["int", "str"], [0, 1]) + df = ctx.create_dataframe([[pa.RecordBatch.from_arrays([arr], names=["u"])]]) + + result = df.select(f.union_tag(column("u")).alias("tag")).collect()[0] + assert result.column(0).to_pylist() == ["int", "str", "int"] + + +def test_union_extract(): + ctx 
= SessionContext() + types = pa.array([0, 1, 0], type=pa.int8()) + offsets = pa.array([0, 0, 1], type=pa.int32()) + children = [pa.array([1, 2]), pa.array(["hello"])] + arr = pa.UnionArray.from_dense(types, offsets, children, ["int", "str"], [0, 1]) + df = ctx.create_dataframe([[pa.RecordBatch.from_arrays([arr], names=["u"])]]) + + result = df.select(f.union_extract(column("u"), "int").alias("val")).collect()[0] + assert result.column(0).to_pylist() == [1, None, 2] + + +@pytest.mark.parametrize("func", [f.array_any_value, f.list_any_value]) +def test_any_value_aliases(func): + ctx = SessionContext() + df = ctx.from_pydict({"a": [[None, 2, 3], [None, None, None], [1, 2, 3]]}) + result = df.select(func(column("a")).alias("v")).collect() + values = [row.as_py() for row in result[0].column(0)] + assert values[0] == 2 + assert values[1] is None + assert values[2] == 1 + + +@pytest.mark.parametrize("func", [f.array_distance, f.list_distance]) +def test_array_distance_aliases(func): + ctx = SessionContext() + df = ctx.from_pydict({"a": [[1.0, 2.0]], "b": [[1.0, 4.0]]}) + result = df.select(func(column("a"), column("b")).alias("v")).collect() + assert result[0].column(0)[0].as_py() == pytest.approx(2.0) + + +@pytest.mark.parametrize( + ("func", "expected"), + [ + (f.array_max, [5, 10]), + (f.list_max, [5, 10]), + (f.array_min, [1, 2]), + (f.list_min, [1, 2]), + ], +) +def test_array_min_max(func, expected): + ctx = SessionContext() + df = ctx.from_pydict({"a": [[1, 5, 3], [10, 2]]}) + result = df.select(func(column("a")).alias("v")).collect() + values = [row.as_py() for row in result[0].column(0)] + assert values == expected + + +@pytest.mark.parametrize("func", [f.array_reverse, f.list_reverse]) +def test_array_reverse_aliases(func): + ctx = SessionContext() + df = ctx.from_pydict({"a": [[1, 2, 3], [4, 5]]}) + result = df.select(func(column("a")).alias("v")).collect() + values = [row.as_py() for row in result[0].column(0)] + assert values == [[3, 2, 1], [5, 4]] + + +@pytest.mark.parametrize("func", [f.arrays_zip, f.list_zip]) +def test_arrays_zip_aliases(func): + ctx = SessionContext() + df = ctx.from_pydict({"a": [[1, 2]], "b": [[3, 4]]}) + result = df.select(func(column("a"), column("b")).alias("v")).collect() + values = result[0].column(0)[0].as_py() + assert values == [{"c0": 1, "c1": 3}, {"c0": 2, "c1": 4}] + + +@pytest.mark.parametrize("func", [f.string_to_array, f.string_to_list]) +def test_string_to_array_aliases(func): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["hello,world,foo"]}) + result = df.select(func(column("a"), literal(",")).alias("v")).collect() + assert result[0].column(0)[0].as_py() == ["hello", "world", "foo"] + + +def test_string_to_array_with_null_string(): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["hello,NA,world"]}) + result = df.select( + f.string_to_array(column("a"), literal(","), literal("NA")).alias("v") + ).collect() + values = result[0].column(0)[0].as_py() + assert values == ["hello", None, "world"] + + +@pytest.mark.parametrize("func", [f.gen_series, f.generate_series]) +def test_gen_series_aliases(func): + ctx = SessionContext() + df = ctx.from_pydict({"a": [0]}) + result = df.select(func(literal(1), literal(5)).alias("v")).collect() + assert result[0].column(0)[0].as_py() == [1, 2, 3, 4, 5] + + +def test_gen_series_with_step(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [0]}) + result = df.select( + f.gen_series(literal(1), literal(10), literal(3)).alias("v") + ).collect() + assert result[0].column(0)[0].as_py() == [1, 4, 7, 
10] + + +class TestPythonicNativeTypes: + """Tests for accepting native Python types instead of requiring lit().""" + + def test_split_part_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["a,b,c"]}) + result = df.select(f.split_part(column("a"), ",", 2).alias("s")).collect() + assert result[0].column(0)[0].as_py() == "b" + + def test_encode_native_str(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["hello"]}) + result = df.select(f.encode(column("a"), "base64").alias("e")).collect() + assert result[0].column(0)[0].as_py() == "aGVsbG8" + + def test_date_part_native_str(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["2021-07-15T00:00:00"]}) + df = df.select(f.to_timestamp(column("a")).alias("a")) + result = df.select(f.date_part("year", column("a")).alias("y")).collect() + assert result[0].column(0)[0].as_py() == 2021 + + def test_date_trunc_native_str(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["2021-07-15T12:34:56"]}) + df = df.select(f.to_timestamp(column("a")).alias("a")) + result = df.select(f.date_trunc("month", column("a")).alias("t")).collect() + assert str(result[0].column(0)[0].as_py()) == "2021-07-01 00:00:00" + + def test_left_native_int(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["the cat"]}) + result = df.select(f.left(column("a"), 3).alias("l")).collect() + assert result[0].column(0)[0].as_py() == "the" + + def test_round_native_int(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1.567]}) + result = df.select(f.round(column("a"), 2).alias("r")).collect() + assert result[0].column(0)[0].as_py() == 1.57 + + def test_regexp_count_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["abcabc"]}) + result = df.select( + f.regexp_count(column("a"), "abc", start=4, flags="i").alias("c") + ).collect() + assert result[0].column(0)[0].as_py() == 1 + + def test_log_native_int(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": [100.0]}) + result = df.select(f.log(10, column("a")).alias("l")).collect() + assert result[0].column(0)[0].as_py() == 2.0 + + def test_power_native_int(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": [2.0]}) + result = df.select(f.power(column("a"), 3).alias("p")).collect() + assert result[0].column(0)[0].as_py() == 8.0 + + def test_array_slice_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": [[1, 2, 3, 4]]}) + result = df.select(f.array_slice(column("a"), 2, 3).alias("s")).collect() + assert result[0].column(0)[0].as_py() == [2, 3] + + def test_string_to_array_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["hello,NA,world"]}) + result = df.select( + f.string_to_array(column("a"), ",", null_string="NA").alias("v") + ).collect() + assert result[0].column(0)[0].as_py() == ["hello", None, "world"] + + def test_regexp_replace_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["a1 b2 c3"]}) + result = df.select( + f.regexp_replace(column("a"), r"\d+", "X", flags="g").alias("r") + ).collect() + assert result[0].column(0)[0].as_py() == "aX bX cX" + + def test_backward_compat_with_lit(self): + """Verify that existing code using lit() still works.""" + ctx = SessionContext() + df = ctx.from_pydict({"a": ["a,b,c"]}) + result = df.select( + f.split_part(column("a"), literal(","), literal(2)).alias("s") + ).collect() + assert result[0].column(0)[0].as_py() == "b" diff --git a/python/tests/test_plans.py b/python/tests/test_plans.py index 396acbe97..3705fc7ef 100644 --- 
a/python/tests/test_plans.py +++ b/python/tests/test_plans.py @@ -15,8 +15,16 @@ # specific language governing permissions and limitations # under the License. +import datetime + import pytest -from datafusion import ExecutionPlan, LogicalPlan, SessionContext +from datafusion import ( + ExecutionPlan, + LogicalPlan, + Metric, + MetricsSet, + SessionContext, +) # Note: We must use CSV because memory tables are currently not supported for @@ -40,3 +48,185 @@ def test_logical_plan_to_proto(ctx, df) -> None: execution_plan = ExecutionPlan.from_proto(ctx, execution_plan_bytes) assert str(original_execution_plan) == str(execution_plan) + + +def test_metrics_tree_walk() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + df.collect() + plan = df.execution_plan() + + results = plan.collect_metrics() + assert len(results) >= 1 + output_rows_by_op: dict[str, int] = {} + for name, ms in results: + assert isinstance(name, str) + assert isinstance(ms, MetricsSet) + if ms.output_rows is not None: + output_rows_by_op[name] = ms.output_rows + + # The filter passes rows where column1 > 1, so exactly + # 2 rows from (1,'a'),(2,'b'),(3,'c'). + # At least one operator must report exactly 2 output rows (the filter). + assert 2 in output_rows_by_op.values(), ( + f"Expected an operator with output_rows=2, got {output_rows_by_op}" + ) + + +def test_metric_properties() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + df.collect() + plan = df.execution_plan() + + found_any_metric = False + for _, ms in plan.collect_metrics(): + r = repr(ms) + assert isinstance(r, str) + for metric in ms.metrics(): + found_any_metric = True + assert isinstance(metric, Metric) + assert isinstance(metric.name, str) + assert len(metric.name) > 0 + assert metric.partition is None or isinstance(metric.partition, int) + assert metric.value is None or isinstance( + metric.value, int | datetime.datetime + ) + assert isinstance(metric.labels(), dict) + mr = repr(metric) + assert isinstance(mr, str) + assert len(mr) > 0 + assert found_any_metric, "Expected at least one metric after execution" + + +def test_no_meaningful_metrics_before_execution() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + plan_before = df.execution_plan() + + # Some plan nodes (e.g. DataSourceExec) eagerly initialize a MetricsSet, + # so metrics() may return a set even before execution. However, no rows + # should have been processed yet — output_rows must be absent or zero. + for _, ms in plan_before.collect_metrics(): + rows = ms.output_rows + assert rows is None or rows == 0, ( + f"Expected 0 output_rows before execution, got {rows}" + ) + + # After execution, at least one operator must report rows processed. 
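+    # Re-fetching the execution plan after collect() should now yield at
+    # least one operator whose output_rows metric is positive.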
+ df.collect() + plan_after = df.execution_plan() + output_rows_after = [ + ms.output_rows + for _, ms in plan_after.collect_metrics() + if ms.output_rows is not None and ms.output_rows > 0 + ] + assert len(output_rows_after) > 0, "Expected output_rows > 0 after execution" + + +def test_collect_partitioned_metrics() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + + df.collect_partitioned() + plan = df.execution_plan() + + output_rows_values = [ + ms.output_rows for _, ms in plan.collect_metrics() if ms.output_rows is not None + ] + assert 2 in output_rows_values, f"Expected 2 in {output_rows_values}" + + +def test_execute_stream_metrics() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + + for _ in df.execute_stream(): + pass + + plan = df.execution_plan() + output_rows_values = [ + ms.output_rows for _, ms in plan.collect_metrics() if ms.output_rows is not None + ] + assert 2 in output_rows_values, f"Expected 2 in {output_rows_values}" + + +def test_execute_stream_partitioned_metrics() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + + for stream in df.execute_stream_partitioned(): + for _ in stream: + pass + + plan = df.execution_plan() + output_rows_values = [ + ms.output_rows for _, ms in plan.collect_metrics() if ms.output_rows is not None + ] + assert 2 in output_rows_values, f"Expected 2 in {output_rows_values}" + + +def test_value_as_datetime() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + df.collect() + plan = df.execution_plan() + + for _, ms in plan.collect_metrics(): + for metric in ms.metrics(): + if metric.name in ("start_timestamp", "end_timestamp"): + dt = metric.value_as_datetime + assert dt is None or isinstance(dt, datetime.datetime) + if dt is not None: + assert dt.tzinfo is not None + else: + assert metric.value_as_datetime is None + + +def test_metric_names_and_labels() -> None: + """Verify that known metric names appear and labels are well-formed.""" + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + df.collect() + plan = df.execution_plan() + + all_metric_names: set[str] = set() + for _, ms in plan.collect_metrics(): + for metric in ms.metrics(): + all_metric_names.add(metric.name) + # Labels must be a dict of str->str + labels = metric.labels() + for k, v in labels.items(): + assert isinstance(k, str) + assert isinstance(v, str) + + # After a filter query, we expect at minimum these standard metric names. + assert "output_rows" in all_metric_names, ( + f"Expected 'output_rows' in {all_metric_names}" + ) + assert "elapsed_compute" in all_metric_names, ( + f"Expected 'elapsed_compute' in {all_metric_names}" + ) + + +def test_collect_twice_has_metrics() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + + df.collect() + df.collect() + + plan = df.execution_plan() + output_rows_values = [ + ms.output_rows for _, ms in plan.collect_metrics() if ms.output_rows is not None + ] + assert len(output_rows_values) > 0