diff --git a/.ai/skills/make-pythonic/SKILL.md b/.ai/skills/make-pythonic/SKILL.md new file mode 100644 index 000000000..57145ac6c --- /dev/null +++ b/.ai/skills/make-pythonic/SKILL.md @@ -0,0 +1,430 @@ + + +--- +name: make-pythonic +description: Audit and improve datafusion-python functions to accept native Python types (int, float, str, bool) instead of requiring explicit lit() or col() wrapping. Analyzes function signatures, checks upstream Rust implementations for type constraints, and applies the appropriate coercion pattern. +argument-hint: [scope] (e.g., "string functions", "datetime functions", "array functions", "math functions", "all", or a specific function name like "split_part") +--- + +# Make Python API Functions More Pythonic + +You are improving the datafusion-python API to feel more natural to Python users. The goal is to allow functions to accept native Python types (int, float, str, bool, etc.) for arguments that are contextually always or typically literal values, instead of requiring users to manually wrap them in `lit()`. + +**Core principle:** A Python user should be able to write `split_part(col("a"), ",", 2)` instead of `split_part(col("a"), lit(","), lit(2))` when the arguments are contextually obvious literals. + +## How to Identify Candidates + +The user may specify a scope via `$ARGUMENTS`. If no scope is given or "all" is specified, audit all functions in `python/datafusion/functions.py`. 
+ +For each function, determine if any parameter can accept native Python types by evaluating **two complementary signals**: + +### Signal 1: Contextual Understanding + +Some arguments are contextually always or almost always literal values based on what the function does: + +| Context | Typical Arguments | Examples | |---------|------------------|----------| | **String position/count** | Character counts, indices, repetition counts | `left(str, n)`, `right(str, n)`, `repeat(str, n)`, `lpad(str, count, ...)` | | **Delimiters/separators** | Fixed separator characters | `split_part(str, delim, idx)`, `concat_ws(sep, ...)` | | **Search/replace patterns** | Literal search strings, replacements | `replace(str, from, to)`, `regexp_replace(str, pattern, replacement, flags)` | | **Date/time parts** | Part names from a fixed set | `date_part(part, date)`, `date_trunc(part, date)` | | **Rounding precision** | Decimal place counts | `round(val, places)`, `trunc(val, places)` | | **Fill characters** | Padding characters | `lpad(str, count, fill)`, `rpad(str, count, fill)` | + +### Signal 2: Upstream Rust Implementation + +Check the Rust binding in `crates/core/src/functions.rs` and the upstream DataFusion function implementation to determine type constraints. The upstream source is cached locally at: + +``` +~/.cargo/registry/src/index.crates.io-*/datafusion-functions-<version>/src/ +``` + +Check the DataFusion version in `crates/core/Cargo.toml` to find the right directory. Key subdirectories: `string/`, `datetime/`, `math/`, `regex/`. 
+ +For **aggregate functions**, the upstream source is in a separate crate: + +``` +~/.cargo/registry/src/index.crates.io-*/datafusion-functions-aggregate-<version>/src/ +``` + +There are five concrete techniques to check, in order of signal strength: + +#### Technique 1: Check `invoke_with_args()` for literal-only enforcement (strongest signal) + +Some functions pattern-match on `ColumnarValue::Scalar` in their `invoke_with_args()` method and **return an error** if the argument is a column/array. This means the argument **must** be a literal — passing a column expression will fail at runtime. + +Example from `date_trunc.rs`: +```rust +let granularity_str = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = granularity { + v.to_lowercase() +} else { + return exec_err!("Granularity of `date_trunc` must be non-null scalar Utf8"); +}; +``` + +**If you find this pattern:** The argument is **Category B** — accept only the corresponding native Python type (e.g., `str`), not `Expr`. The function will error at runtime with a column expression anyway. + +#### Technique 1a: Check `accumulator()` for literal-only enforcement (aggregate functions) + +Technique 1 applies to scalar UDFs. Aggregate functions do not have `invoke_with_args()` — instead, they enforce literal-only arguments in their `accumulator()` (or `create_accumulator()`) method, which runs at planning time before any data is processed. + +Look for these patterns inside `accumulator()`: + +- `get_scalar_value(expr)` — evaluates the expression against an empty batch and errors if it's not a scalar +- `validate_percentile_expr(expr)` — specific helper used by percentile functions +- `downcast_ref::<Literal>()` — checks that the physical expression is a literal constant + +Example from `approx_percentile_cont.rs`: +```rust +fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { + let percentile = + validate_percentile_expr(&args.exprs[1], "APPROX_PERCENTILE_CONT")?; + // ... 
+} +``` + +Where `validate_percentile_expr` calls `get_scalar_value` and errors with `"must be a literal"`. + +Example from `string_agg.rs`: +```rust +fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { + let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::<Literal>() else { + return not_impl_err!( + "The second argument of the string_agg function must be a string literal" + ); + }; + // ... +} +``` + +**If you find this pattern:** The argument is **Category B** — accept only the corresponding native Python type, not `Expr`. The function will error at planning time with a non-literal expression. + +To discover which aggregate functions have literal-only arguments, search the upstream aggregate crate for `get_scalar_value`, `validate_percentile_expr`, and `downcast_ref::<Literal>()` inside `accumulator()` methods. For example, you should expect to find `approx_percentile_cont` (percentile) and `string_agg` (delimiter) among the results. + +#### Technique 1b: Check `partition_evaluator()` for literal-only enforcement (window functions) + +Window functions do not have `invoke_with_args()` or `accumulator()`. Instead, they enforce literal-only arguments in their `partition_evaluator()` method, which constructs the evaluator that processes each partition. + +The upstream source is in a separate crate: + +``` +~/.cargo/registry/src/index.crates.io-*/datafusion-functions-window-<version>/src/ +``` + +Look for `get_scalar_value_from_args()` calls inside `partition_evaluator()`. This helper (defined in the window crate's `utils.rs`) calls `downcast_ref::<Literal>()` and errors with `"There is only support Literal types for field at idx: {index} in Window Function"`. + +Example from `ntile.rs`: +```rust +fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, +) -> Result<Box<dyn PartitionEvaluator>> { + let scalar_n = + get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 0)? + .ok_or_else(|| { + exec_datafusion_err!("NTILE requires a positive integer") + })?; + // ... 
+} +``` + +**If you find this pattern:** The argument is **Category B** — accept only the corresponding native Python type, not `Expr`. The function will error at planning time with a non-literal expression. + +To discover which window functions have literal-only arguments, search the upstream window crate for `get_scalar_value_from_args` inside `partition_evaluator()` methods. For example, you should expect to find `ntile` (n) and `lead`/`lag` (offset, default_value) among the results. + +#### Technique 2: Check the `Signature` for data type constraints + +Each function defines a `Signature::coercible(...)` that specifies what data types each argument accepts, using `Coercion` entries. This tells you the expected **data type** even if it doesn't enforce literal-only. + +Example from `repeat.rs`: +```rust +signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Integer], + NativeType::Int64, + ), + ], + Volatility::Immutable, +), +``` + +This tells you arg 2 (`n`) must be an integer type coerced to Int64. Use this to choose the correct Python type (e.g., `int` not `str` or `float`). + +Common mappings: +| Rust Type Constraint | Python Type | +|---------------------|-------------| +| `logical_int64()` / `TypeSignatureClass::Integer` | `int` | +| `logical_float64()` / `TypeSignatureClass::Numeric` | `int \| float` | +| `logical_string()` / `TypeSignatureClass::String` | `str` | +| `LogicalType::Boolean` | `bool` | + +**Important:** In Python's type system (PEP 484), `float` already accepts `int` values, so `int | float` is redundant and will fail the `ruff` linter (rule PYI041). Use `float` alone when the Rust side accepts a float/numeric type — Python users can still pass integer literals like `log(10, col("a"))` or `power(col("a"), 3)` without issue. 
Only use `int` when the Rust side strictly requires an integer (e.g., `logical_int64()`). + +#### Technique 3: Check `return_field_from_args()` for `scalar_arguments` usage + +Functions that inspect literal values at query planning time use `args.scalar_arguments.get(n)` in their `return_field_from_args()` method. This indicates the argument is **expected to be a literal** for optimal behavior (e.g., to determine output type precision), but may still work as a column. + +Example from `round.rs`: +```rust +let decimal_places: Option<i32> = match args.scalar_arguments.get(1) { + None => Some(0), + Some(None) => None, // argument is not a literal (column) + Some(Some(scalar)) if scalar.is_null() => Some(0), + Some(Some(scalar)) => Some(decimal_places_from_scalar(scalar)?), +}; +``` + +**If you find this pattern:** The argument is **Category A** — accept native types AND `Expr`. It works as a column but is primarily used as a literal. + +#### Decision flow + +``` +What kind of function is this? + Scalar UDF: + Is argument rejected at runtime if not a literal? + (check invoke_with_args for ColumnarValue::Scalar-only match + exec_err!) + → YES: Category B — accept only native type, no Expr + → NO: continue below + Aggregate: + Is argument rejected at planning time if not a literal? + (check accumulator() for get_scalar_value / validate_percentile_expr / + downcast_ref::<Literal>() + error) + → YES: Category B — accept only native type, no Expr + → NO: continue below + Window: + Is argument rejected at planning time if not a literal? + (check partition_evaluator() for get_scalar_value_from_args / + downcast_ref::<Literal>() + error) + → YES: Category B — accept only native type, no Expr + → NO: continue below + +Does the Signature constrain it to a specific data type? 
→ YES: Category A — accept Expr | <native type> + → NO: Leave as Expr only +``` + +## Coercion Categories + +When making a function more pythonic, apply the correct coercion pattern based on **what the argument represents**: + +### Category A: Arguments That Should Accept Native Types AND Expr + +These are arguments that are *typically* literals but *could* be column references in advanced use cases. For these, accept a union type and coerce native types to `Expr.literal()`. + +**Type hint pattern:** `Expr | int`, `Expr | str`, `Expr | int | str`, etc. + +**When to use:** When the argument could plausibly come from a column in some use case (e.g., the repeat count might come from a column in a data-driven scenario). + +```python +def repeat(string: Expr, n: Expr | int) -> Expr: + """Repeats the ``string`` to ``n`` times. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": ["ha"]}) + >>> result = df.select( + ... dfn.functions.repeat(dfn.col("a"), 3).alias("r")) + >>> result.collect_column("r")[0].as_py() + 'hahaha' + """ + if not isinstance(n, Expr): + n = Expr.literal(n) + return Expr(f.repeat(string.expr, n.expr)) +``` + +### Category B: Arguments That Should ONLY Accept Specific Native Types + +These are arguments where an `Expr` never makes sense because the value must be a fixed literal known at query-planning time (not a per-row value). For these, accept only the native type(s) and wrap internally. + +**Type hint pattern:** `str`, `int`, `list[str]`, etc. 
(no `Expr` in the union) + +**When to use:** When the argument is from a fixed enumeration or is always a compile-time constant, **AND** the parameter was not previously typed as `Expr`: +- Separator in `concat_ws` (already typed as `str` in the Rust binding) +- Index in `array_position` (already typed as `int` in the Rust binding) +- Values that the Rust implementation already accepts as native types + +**Backward compatibility rule:** If a parameter was previously typed as `Expr`, you **must** keep `Expr` in the union even if the Rust side requires a literal. Removing `Expr` would break existing user code like `date_part(lit("year"), col("a"))`. Use **Category A** instead — accept `Expr | str` — and let users who pass column expressions discover the runtime error from the Rust side. Never silently break backward compatibility. + +```python +def concat_ws(separator: str, *args: Expr) -> Expr: + """Concatenates the list ``args`` with the separator. + + ``separator`` is already typed as ``str`` in the Rust binding, so + there is no backward-compatibility concern. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": ["hello"], "b": ["world"]}) + >>> result = df.select( + ... dfn.functions.concat_ws("-", dfn.col("a"), dfn.col("b")).alias("c")) + >>> result.collect_column("c")[0].as_py() + 'hello-world' + """ + args = [arg.expr for arg in args] + return Expr(f.concat_ws(separator, args)) +``` + +### Category C: Arguments That Should Accept str as Column Name + +In some contexts a string argument naturally refers to a column name rather than a literal. This is the pattern used by DataFrame methods. + +**Type hint pattern:** `Expr | str` + +**When to use:** Only when the string contextually means a column name (rare in `functions.py`, more common in DataFrame methods). 
+ +```python +# Use _to_raw_expr() from expr.py for this pattern +from datafusion.expr import _to_raw_expr + +def some_function(column: Expr | str) -> Expr: + raw = _to_raw_expr(column) # str -> col(str) + return Expr(f.some_function(raw)) +``` + +**IMPORTANT:** In `functions.py`, string arguments almost never mean column names. Functions operate on expressions, and column references should use `col()`. Category C applies mainly to DataFrame methods and context APIs, not to scalar/aggregate/window functions. Do NOT convert string arguments to column expressions in `functions.py` unless there is a very clear reason to do so. + +## Implementation Steps + +For each function being updated: + +### Step 1: Analyze the Function + +1. Read the current Python function signature in `python/datafusion/functions.py` +2. Read the Rust binding in `crates/core/src/functions.rs` +3. Optionally check the upstream DataFusion docs for the function +4. Determine which category (A, B, or C) applies to each parameter + +### Step 2: Update the Python Function + +1. **Change the type hints** to accept native types (e.g., `Expr` -> `Expr | int`) +2. **Add coercion logic** at the top of the function body +3. **Update the docstring** examples to use the simpler calling convention +4. **Preserve backward compatibility** — existing code using `Expr` must still work + +### Step 3: Update Alias Type Hints + +After updating a primary function, find all alias functions that delegate to it (e.g., `instr` and `position` delegate to `strpos`). Update each alias's **parameter type hints** to match the primary function's new signature. Do not add coercion logic to aliases — the primary function handles that. 
+ +### Step 4: Update Docstring Examples (primary functions only) + +Per the project's CLAUDE.md rules: +- Every function must have doctest-style examples +- Optional parameters need examples both without and with the optional args, using keyword argument syntax +- Reuse the same input data across examples where possible + +**Update examples to demonstrate the pythonic calling convention:** + +```python +# BEFORE (old style - still works but verbose) +dfn.functions.left(dfn.col("a"), dfn.lit(3)) + +# AFTER (new style - shown in examples) +dfn.functions.left(dfn.col("a"), 3) +``` + +### Step 5: Run Tests + +After making changes, run the doctests to verify: +```bash +python -m pytest --doctest-modules python/datafusion/functions.py -v +``` + +## Coercion Helper Pattern + +Use the coercion helpers from `datafusion.expr` to convert native Python values to `Expr`. These are the complement of `ensure_expr()` — where `ensure_expr` *rejects* non-`Expr` values, the coercion helpers *wrap* them via `Expr.literal()`. + +**For required parameters** use `coerce_to_expr`: + +```python +from datafusion.expr import coerce_to_expr + +def left(string: Expr, n: Expr | int) -> Expr: + n = coerce_to_expr(n) + return Expr(f.left(string.expr, n.expr)) +``` + +**For optional nullable parameters** use `coerce_to_expr_or_none`: + +```python +from datafusion.expr import coerce_to_expr, coerce_to_expr_or_none + +def regexp_count( + string: Expr, + pattern: Expr | str, + start: Expr | int | None = None, + flags: Expr | str | None = None, +) -> Expr: + pattern = coerce_to_expr(pattern) + start = coerce_to_expr_or_none(start) + flags = coerce_to_expr_or_none(flags) + return Expr( + f.regexp_count( + string.expr, + pattern.expr, + start.expr if start is not None else None, + flags.expr if flags is not None else None, + ) + ) +``` + +Both helpers are defined in `python/datafusion/expr.py` alongside `ensure_expr`. 
Import them in `functions.py` via: + +```python +from datafusion.expr import coerce_to_expr, coerce_to_expr_or_none +``` + +## What NOT to Change + +- **Do not change arguments that represent data columns.** If an argument is the primary data being operated on (e.g., the `string` in `left(string, n)` or the `array` in `array_sort(array)`), it should remain `Expr` only. Users should use `col()` for column references. +- **Do not change variadic `*args: Expr` parameters.** These represent multiple expressions and should stay as `Expr`. +- **Do not change arguments where the coercion is ambiguous.** If it is unclear whether a string should be a column name or a literal, leave it as `Expr` and let the user be explicit. +- **Do not add coercion logic to simple aliases.** If a function is just `return other_function(...)`, the primary function handles coercion. However, you **must update the alias's type hints** to match the primary function's signature so that type checkers and documentation accurately reflect what the alias accepts. +- **Do not change the Rust bindings.** All coercion happens in the Python layer. The Rust functions continue to accept `PyExpr`. + +## Priority Order + +When auditing functions, process them in this order: + +1. **Date/time functions** — `date_part`, `date_trunc`, `date_bin` — these have the clearest literal arguments +2. **String functions** — `left`, `right`, `repeat`, `lpad`, `rpad`, `split_part`, `substring`, `replace`, `regexp_replace`, `regexp_match`, `regexp_count` — common and verbose without coercion +3. **Math functions** — `round`, `trunc`, `power` — numeric literal arguments +4. **Array functions** — `array_slice`, `array_position`, `array_remove_n`, `array_replace_n`, `array_resize`, `array_element` — index and count arguments +5. 
**Other functions** — any remaining functions with literal arguments + +## Output Format + +For each function analyzed, report: + +``` +## [Function Name] + +**Current signature:** `function(arg1: Expr, arg2: Expr) -> Expr` +**Proposed signature:** `function(arg1: Expr, arg2: Expr | int) -> Expr` +**Category:** A (accepts native + Expr) +**Arguments changed:** +- `arg2`: Expr -> Expr | int (always a literal count) +**Rust binding:** Takes PyExpr, wraps to literal internally +**Status:** [Changed / Skipped / Needs Discussion] +``` + +If asked to implement (not just audit), make the changes directly and show a summary of what was updated. diff --git a/AGENTS.md b/AGENTS.md index 86c2e9c3b..7d3262710 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -17,7 +17,14 @@ under the License. --> -# Agent Instructions +# Agent Instructions for Contributors + +This file is for agents working **on** the datafusion-python project (developing, +testing, reviewing). If you need to **use** the DataFusion DataFrame API (write +queries, build expressions, understand available functions), see the user-facing +skill at [`SKILL.md`](SKILL.md). + +## Skills This project uses AI agent skills stored in `.ai/skills/`. Each skill is a directory containing a `SKILL.md` file with instructions for performing a specific task. @@ -26,6 +33,31 @@ Skills follow the [Agent Skills](https://agentskills.io) open standard. Each ski - `SKILL.md` — The skill definition with YAML frontmatter (name, description, argument-hint) and detailed instructions. - Additional supporting files as needed. +## Pull Requests + +Every pull request must follow the template in +`.github/pull_request_template.md`. The description must include these sections: + +1. **Which issue does this PR close?** — Link the issue with `Closes #NNN`. +2. **Rationale for this change** — Why the change is needed (skip if the issue + already explains it clearly). +3. **What changes are included in this PR?** — Summarize the individual changes. +4. 
**Are there any user-facing changes?** — Note any changes visible to users + (new APIs, changed behavior, new files shipped in the package, etc.). If + there are breaking changes to public APIs, add the `api change` label. + +## Pre-commit Checks + +Always run pre-commit checks **before** committing. The hooks are defined in +`.pre-commit-config.yaml` and run automatically on `git commit` if pre-commit +is installed as a git hook. To run all hooks manually: + +```bash +pre-commit run --all-files +``` + +Fix any failures before committing. + ## Python Function Docstrings Every Python function must include a docstring with usage examples. diff --git a/README.md b/README.md index 7849e7a02..4baed7d1d 100644 --- a/README.md +++ b/README.md @@ -215,6 +215,22 @@ You can verify the installation by running: '0.6.0' ``` +## Using DataFusion with AI coding assistants + +This project ships a [`SKILL.md`](SKILL.md) at the repo root that teaches AI +coding assistants how to write idiomatic DataFusion Python. It follows the +[Agent Skills](https://agentskills.io) open standard. + +**Preferred:** `npx skills add apache/datafusion-python` — installs the skill in +Claude Code, Cursor, Windsurf, Cline, Codex, Copilot, Gemini CLI, and other +supported agents. + +**Manual:** paste this line into your project's `AGENTS.md` / `CLAUDE.md`: + +``` +For DataFusion Python code, see https://github.com/apache/datafusion-python/blob/main/SKILL.md +``` + ## How to develop This assumes that you have rust and cargo installed. We use the workflow recommended by [pyo3](https://github.com/PyO3/pyo3) and [maturin](https://github.com/PyO3/maturin). The Maturin tools used in this workflow can be installed either via `uv` or `pip`. Both approaches should offer the same experience. 
It is recommended to use `uv` since it has significant performance improvements diff --git a/SKILL.md b/SKILL.md new file mode 100644 index 000000000..7b07b430f --- /dev/null +++ b/SKILL.md @@ -0,0 +1,722 @@ +--- +name: datafusion-python +description: Use when the user is writing datafusion-python (Apache DataFusion Python bindings) DataFrame or SQL code. Covers imports, data loading, DataFrame operations, expression building, SQL-to-DataFrame mappings, idiomatic patterns, and common pitfalls. +--- + +# DataFusion Python DataFrame API Guide + +## What Is DataFusion? + +DataFusion is an **in-process query engine** built on Apache Arrow. It is not a +database -- there is no server, no connection string, and no external +dependencies. You create a `SessionContext`, point it at data (Parquet, CSV, +JSON, Arrow IPC, Pandas, Polars, or raw Python dicts/lists), and run queries +using either SQL or the DataFrame API described below. + +All data flows through **Apache Arrow**. The canonical Python implementation is +PyArrow (`pyarrow.RecordBatch` / `pyarrow.Table`), but any library that +conforms to the [Arrow C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html) +can interoperate with DataFusion. + +## Core Abstractions + +| Abstraction | Role | Key import | +|---|---|---| +| `SessionContext` | Entry point. Loads data, runs SQL, produces DataFrames. | `from datafusion import SessionContext` | +| `DataFrame` | Lazy query builder. Each method returns a new DataFrame. | Returned by context methods | +| `Expr` | Expression tree node (column ref, literal, function call, ...). | `from datafusion import col, lit` | +| `functions` | 290+ built-in scalar, aggregate, and window functions. 
| `from datafusion import functions as F` | + +## Import Conventions + +```python +from datafusion import SessionContext, col, lit +from datafusion import functions as F +``` + +## Data Loading + +```python +ctx = SessionContext() + +# From files +df = ctx.read_parquet("path/to/data.parquet") +df = ctx.read_csv("path/to/data.csv") +df = ctx.read_json("path/to/data.json") + +# From Python objects +df = ctx.from_pydict({"a": [1, 2, 3], "b": ["x", "y", "z"]}) +df = ctx.from_pylist([{"a": 1, "b": "x"}, {"a": 2, "b": "y"}]) +df = ctx.from_pandas(pandas_df) +df = ctx.from_polars(polars_df) +df = ctx.from_arrow(arrow_table) + +# From SQL +df = ctx.sql("SELECT a, b FROM my_table WHERE a > 1") +``` + +To make a DataFrame queryable by name in SQL, register it first: + +```python +ctx.register_parquet("my_table", "path/to/data.parquet") +ctx.register_csv("my_table", "path/to/data.csv") +``` + +## DataFrame Operations Quick Reference + +Every method returns a **new** DataFrame (immutable/lazy). Chain them fluently. + +### Projection + +```python +df.select("a", "b") # preferred: plain names as strings +df.select(col("a"), (col("b") + 1).alias("b_plus_1")) # use col()/Expr only when you need an expression + +df.with_column("new_col", col("a") + lit(10)) # add one column +df.with_columns( + col("a").alias("x"), + y=col("b") + lit(1), # named keyword form +) + +df.drop("unwanted_col") +df.with_column_renamed("old_name", "new_name") +``` + +When a column is referenced by name alone, pass the name as a string rather +than wrapping it in `col()`. Reach for `col()` only when the projection needs +arithmetic, aliasing, casting, or another expression operation. + +**Case sensitivity**: both `select("Name")` and `col("Name")` lowercase the +identifier. For a column whose real name has uppercase letters, embed double +quotes inside the string: `select('"MyCol"')` or `col('"MyCol"')`. Without the +inner quotes the lookup will fail with `No field named mycol`. 
+ +### Filtering + +```python +df.filter(col("a") > 10) +df.filter(col("a") > 10, col("b") == "x") # multiple = AND +df.filter("a > 10") # SQL expression string +``` + +Raw Python values on the right-hand side of a comparison are auto-wrapped +into literals by the `Expr` operators, so prefer `col("a") > 10` over +`col("a") > lit(10)`. See the Comparisons section and pitfall #2 for the +full rule. + +### Aggregation + +```python +# GROUP BY a, compute sum(b) and count(*) +df.aggregate(["a"], [F.sum(col("b")), F.count(col("a"))]) + +# HAVING equivalent: use the filter keyword on the aggregate function +df.aggregate( + ["region"], + [F.sum(col("sales"), filter=col("sales") > 1000).alias("large_sales")], +) +``` + +As with `select()`, group keys can be passed as plain name strings. Reach for +`col(...)` only when the grouping expression needs arithmetic, aliasing, +casting, or another expression operation. + +Most aggregate functions accept an optional `filter` keyword argument. When +provided, only rows where the filter expression is true contribute to the +aggregate. + +### Sorting + +```python +df.sort("a") # ascending (plain name, preferred) +df.sort(col("a")) # ascending via col() +df.sort(col("a").sort(ascending=False)) # descending +df.sort(col("a").sort(nulls_first=False)) # override null placement + +df.sort_by("a", "b") # ascending-only shortcut +``` + +As with `select()` and `aggregate()`, bare column references can be passed as +plain name strings. A plain expression passed to `sort()` is already treated +as ascending, so reach for `col(...).sort(...)` only when you need to override +a default (descending order or null placement). Writing +`col("a").sort(ascending=True)` is redundant. + +For ascending-only sorts with no null-placement override, `df.sort_by(...)` is +a shorter alias for `df.sort(...)`. 
+ +### Joining + +```python +# Equi-join on shared column name +df1.join(df2, on="key") +df1.join(df2, on="key", how="left") + +# Different column names +df1.join(df2, left_on="id", right_on="fk_id", how="inner") + +# Expression-based join (supports inequality predicates) +df1.join_on(df2, col("a") == col("b"), how="inner") + +# Semi join: keep rows from left where a match exists in right (like EXISTS) +df1.join(df2, on="key", how="semi") + +# Anti join: keep rows from left where NO match exists in right (like NOT EXISTS) +df1.join(df2, on="key", how="anti") +``` + +Join types: `"inner"`, `"left"`, `"right"`, `"full"`, `"semi"`, `"anti"`. + +Inner is the default `how`. Prefer `df1.join(df2, on="key")` over +`df1.join(df2, on="key", how="inner")` — drop `how=` unless you need a +non-inner join type. + +When the two sides' join columns have different native names, use +`left_on=`/`right_on=` with the original names rather than aliasing one side +to match the other — see pitfall #7. + +### Window Functions + +```python +from datafusion import WindowFrame + +# Row number partitioned by group, ordered by value +df.window( + F.row_number( + partition_by=[col("group")], + order_by=[col("value")], + ).alias("rn") +) + +# Using a Window object for reuse +from datafusion.expr import Window + +win = Window( + partition_by=[col("group")], + order_by=[col("value").sort(ascending=True)], +) +df.select( + col("group"), + col("value"), + F.sum(col("value")).over(win).alias("running_total"), +) + +# With explicit frame bounds +win = Window( + partition_by=[col("group")], + order_by=[col("value").sort(ascending=True)], + window_frame=WindowFrame("rows", 0, None), # current row to unbounded following +) +``` + +### Set Operations + +```python +df1.union(df2) # UNION ALL (by position) +df1.union(df2, distinct=True) # UNION DISTINCT +df1.union_by_name(df2) # match columns by name, not position +df1.intersect(df2) # INTERSECT ALL +df1.intersect(df2, distinct=True) # INTERSECT (distinct) 
+df1.except_all(df2) # EXCEPT ALL +df1.except_all(df2, distinct=True) # EXCEPT (distinct) +``` + +### Limit and Offset + +```python +df.limit(10) # first 10 rows +df.limit(10, offset=20) # skip 20, then take 10 +``` + +### Deduplication + +```python +df.distinct() # remove duplicate rows +df.distinct_on( # keep first row per group (like DISTINCT ON in Postgres) + [col("a")], # uniqueness columns + [col("a"), col("b")], # output columns + [col("b").sort(ascending=True)], # which row to keep +) +``` + +## Executing and Collecting Results + +DataFrames are lazy until you collect. + +```python +df.show() # print formatted table to stdout +batches = df.collect() # list[pa.RecordBatch] +arr = df.collect_column("col_name") # pa.Array | pa.ChunkedArray (single column) +table = df.to_arrow_table() # pa.Table +pandas_df = df.to_pandas() # pd.DataFrame +polars_df = df.to_polars() # pl.DataFrame +py_dict = df.to_pydict() # dict[str, list] +py_list = df.to_pylist() # list[dict] +count = df.count() # int +``` + +### Date and Timestamp Type Conversion + +The Python type returned by `to_pydict()` / `to_pylist()` depends on the Arrow +column type, and the mapping is inherited from PyArrow: + +| Arrow type | Python type returned | +|---|---| +| `timestamp(s)` / `(ms)` / `(us)` | `datetime.datetime` | +| `timestamp(ns)` | `pandas.Timestamp` | +| `date32` / `date64` | `datetime.date` | +| `duration(s)` / `(ms)` / `(us)` | `datetime.timedelta` | +| `duration(ns)` | `pandas.Timedelta` | + +The nanosecond-precision fallback to pandas types is the main surprise: +pandas is not a hard dependency of `datafusion`, but PyArrow reaches for it +when `datetime.datetime` / `datetime.timedelta` would lose precision (stdlib +types only go to microseconds). If you need plain stdlib types, cast to a +coarser unit before collecting, e.g. +`df.select(col("ts").cast(pa.timestamp("us")))`. 
+ +`df.to_pandas()` has its own footgun for dates: pandas has no pure-date dtype, +so a `date32`/`date64` column comes back as an `object` column of +`datetime.date` values rather than `datetime64[ns]`. If downstream code +expects a datetime column, cast on the DataFusion side first: +`col("ship_date").cast(pa.timestamp("ns"))`. + +### Streaming Results + +Prefer streaming over `collect()` when the result is too large to materialize +in memory, when you want to start processing before the query finishes, or +when you may break out of the loop early. `execute_stream()` pulls one +`RecordBatch` at a time from the execution plan rather than buffering the +whole result up front. + +```python +# Single-partition stream; batch is a datafusion.RecordBatch +stream = df.execute_stream() +for batch in stream: + process(batch.to_pyarrow()) # convert to pa.RecordBatch if needed + +# DataFrame is iterable directly (delegates to execute_stream) +for batch in df: + process(batch.to_pyarrow()) + +# One stream per partition, for parallel consumption +for stream in df.execute_stream_partitioned(): + for batch in stream: + process(batch.to_pyarrow()) +``` + +Async iteration is also supported via `async for batch in df: ...` (or +`df.execute_stream()`), which is useful when batches are interleaved with +other I/O. + +### Writing Results + +```python +df.write_parquet("output.parquet") +df.write_csv("output.csv") +df.write_json("output.json") +``` + +You can also pass a directory path (e.g., `"output/"`) to write a multi-file +partitioned output. + +## Expression Building + +### Column References and Literals + +```python +col("column_name") # reference a column +lit(42) # integer literal +lit("hello") # string literal +lit(3.14) # float literal +lit(pa.scalar(value)) # PyArrow scalar (preserves Arrow type) +``` + +`lit()` accepts PyArrow scalars directly -- prefer this over converting Arrow +data to Python and back when working with values extracted from query results. 
+ +### Arithmetic + +```python +col("price") * col("quantity") # multiplication +col("a") + lit(1) # addition +col("a") - col("b") # subtraction +col("a") / lit(2) # division +col("a") % lit(3) # modulo +``` + +### Date Arithmetic + +`Date32` and `Date64` columns both require `Interval` types for arithmetic, +not `Duration`. Use PyArrow's `month_day_nano_interval` type, which takes a +`(months, days, nanos)` tuple: + +```python +import pyarrow as pa + +# Subtract 90 days from a date column +col("ship_date") - lit(pa.scalar((0, 90, 0), type=pa.month_day_nano_interval())) + +# Subtract 3 months +col("ship_date") - lit(pa.scalar((3, 0, 0), type=pa.month_day_nano_interval())) +``` + +**Important**: `lit(datetime.timedelta(days=90))` creates a `Duration(µs)` +literal, which is **not** compatible with `Date32`/`Date64` arithmetic +(`Duration(ms)` and `Duration(ns)` are rejected too). Always use +`pa.month_day_nano_interval()` for date operations. + +**Timestamps behave differently**: `Timestamp` columns *do* accept `Duration`, +so `col("ts") - lit(datetime.timedelta(days=1))` works. The interval-only +rule applies specifically to date columns. + +### Comparisons + +```python +col("a") > 10 +col("a") >= 10 +col("a") < 10 +col("a") <= 10 +col("a") == "x" +col("a") != "x" +col("a") == None # same as col("a").is_null() +col("a") != None # same as col("a").is_not_null() +``` + +Comparison operators auto-wrap the right-hand Python value into a literal, +so writing `col("a") > lit(10)` is redundant. Drop the `lit()` in +comparisons. Reach for `lit()` only when auto-wrapping does not apply — see +pitfall #2. + +### Boolean Logic + +**Important**: Python's `and`, `or`, `not` keywords do NOT work with Expr +objects. 
You must use the bitwise operators: + +```python +(col("a") > 1) & (col("b") < 10) # AND +(col("a") > 1) | (col("b") < 10) # OR +~(col("a") > 1) # NOT +``` + +Always wrap each comparison in parentheses when combining with `&`, `|`, `~` +because Python's operator precedence for bitwise operators is different from +logical operators. + +### Null Handling + +```python +col("a").is_null() +col("a").is_not_null() +col("a").fill_null(lit(0)) # replace NULL with a value +F.coalesce(col("a"), col("b")) # first non-null value +F.nullif(col("a"), lit(0)) # return NULL if a == 0 +``` + +### CASE / WHEN + +```python +# Simple CASE (matching on a single expression) +status_label = ( + F.case(col("status")) + .when(lit("A"), lit("Active")) + .when(lit("I"), lit("Inactive")) + .otherwise(lit("Unknown")) +) + +# Searched CASE (each branch has its own predicate) +severity = ( + F.when(col("value") > 100, lit("high")) + .when(col("value") > 50, lit("medium")) + .otherwise(lit("low")) +) +``` + +### Casting + +```python +import pyarrow as pa + +col("a").cast(pa.float64()) +col("a").cast(pa.utf8()) +col("a").cast(pa.date32()) +``` + +### Aliasing + +```python +(col("a") + col("b")).alias("total") +``` + +### BETWEEN and IN + +```python +col("a").between(lit(1), lit(10)) # 1 <= a <= 10 +F.in_list(col("a"), [lit(1), lit(2), lit(3)]) # a IN (1, 2, 3) +F.in_list(col("a"), [lit(1), lit(2)], negated=True) # a NOT IN (1, 2) +``` + +### Struct and Array Access + +```python +col("struct_col")["field_name"] # access struct field +col("array_col")[0] # access array element (0-indexed) +col("array_col")[1:3] # array slice (0-indexed) +``` + +## SQL-to-DataFrame Reference + +| SQL | DataFrame API | +|---|---| +| `SELECT a, b` | `df.select("a", "b")` | +| `SELECT a, b + 1 AS c` | `df.select(col("a"), (col("b") + lit(1)).alias("c"))` | +| `SELECT *, a + 1 AS c` | `df.with_column("c", col("a") + lit(1))` | +| `WHERE a > 10` | `df.filter(col("a") > 10)` | +| `GROUP BY a` with `SUM(b)` | 
`df.aggregate(["a"], [F.sum(col("b"))])` | +| `SUM(b) FILTER (WHERE b > 100)` | `F.sum(col("b"), filter=col("b") > 100)` | +| `ORDER BY a DESC` | `df.sort(col("a").sort(ascending=False))` | +| `LIMIT 10 OFFSET 5` | `df.limit(10, offset=5)` | +| `DISTINCT` | `df.distinct()` | +| `a INNER JOIN b ON a.id = b.id` | `a.join(b, on="id")` | +| `a LEFT JOIN b ON a.id = b.fk` | `a.join(b, left_on="id", right_on="fk", how="left")` | +| `WHERE EXISTS (SELECT ...)` | `a.join(b, on="key", how="semi")` | +| `WHERE NOT EXISTS (SELECT ...)` | `a.join(b, on="key", how="anti")` | +| `UNION ALL` | `df1.union(df2)` | +| `UNION` (distinct) | `df1.union(df2, distinct=True)` | +| `INTERSECT ALL` | `df1.intersect(df2)` | +| `INTERSECT` (distinct) | `df1.intersect(df2, distinct=True)` | +| `EXCEPT ALL` | `df1.except_all(df2)` | +| `EXCEPT` (distinct) | `df1.except_all(df2, distinct=True)` | +| `CASE x WHEN 1 THEN 'a' END` | `F.case(col("x")).when(lit(1), lit("a")).end()` | +| `CASE WHEN x > 1 THEN 'a' END` | `F.when(col("x") > 1, lit("a")).end()` | +| `x IN (1, 2, 3)` | `F.in_list(col("x"), [lit(1), lit(2), lit(3)])` | +| `x BETWEEN 1 AND 10` | `col("x").between(lit(1), lit(10))` | +| `CAST(x AS DOUBLE)` | `col("x").cast(pa.float64())` | +| `ROW_NUMBER() OVER (...)` | `F.row_number(partition_by=[...], order_by=[...])` | +| `SUM(x) OVER (...)` | `F.sum(col("x")).over(window)` | +| `x IS NULL` | `col("x").is_null()` | +| `COALESCE(a, b)` | `F.coalesce(col("a"), col("b"))` | + +## Common Pitfalls + +1. **Boolean operators**: Use `&`, `|`, `~` -- not Python's `and`, `or`, `not`. + Always parenthesize: `(col("a") > 1) & (col("b") < 2)`. + +2. **Wrapping scalars with `lit()`**: Prefer raw Python values on the + right-hand side of comparisons — `col("a") > 10`, `col("name") == "Alice"` + — because the Expr comparison operators auto-wrap them. Writing + `col("a") > lit(10)` is redundant. 
Reserve `lit()` for places where
+   auto-wrapping does *not* apply:
+   - standalone scalars passed into function calls:
+     `F.coalesce(col("a"), lit(0))`, not `F.coalesce(col("a"), 0)`
+   - a literal on the left of arithmetic, or scalar-only arithmetic:
+     write `lit(1) - col("discount")`, and wrap both in `lit(1) - lit(2)`
+   - values that must carry a specific Arrow type, via `lit(pa.scalar(...))`
+   - `.when(...)`, `.otherwise(...)`, `F.nullif(...)`, `.between(...)`,
+     `F.in_list(...)` and similar method/function arguments
+
+3. **Column name quoting**: Column names are normalized to lowercase by default
+   in both `select("...")` and `col("...")`. To reference a column with
+   uppercase letters, use double quotes inside the string:
+   `select('"MyColumn"')` or `col('"MyColumn"')`.
+
+4. **DataFrames are immutable**: Every method returns a **new** DataFrame. You
+   must capture the return value:
+   ```python
+   df = df.filter(col("a") > 1) # correct
+   df.filter(col("a") > 1) # WRONG -- result is discarded
+   ```
+
+5. **Window frame defaults**: When using `order_by` in a window, the default
+   frame is `RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`. For a full
+   partition frame, set `window_frame=WindowFrame("rows", None, None)`.
+
+6. **Arithmetic on aggregates belongs in a later `select`, not inside
+   `aggregate`**: Each item in the aggregate list must be a single aggregate
+   call (optionally aliased). Combining aggregates with arithmetic inside
+   `aggregate(...)` fails with `Internal error: Invalid aggregate expression`.
+   Alias the aggregates, then compute the combination downstream:
+   ```python
+   # WRONG -- arithmetic wraps two aggregates
+   df.aggregate([], [(lit(100) * F.sum(col("a")) / F.sum(col("b"))).alias("ratio")])
+
+   # CORRECT -- aggregate first, then combine
+   (df.aggregate([], [F.sum(col("a")).alias("num"), F.sum(col("b")).alias("den")])
+    .select((lit(100) * col("num") / col("den")).alias("ratio")))
+   ```
+
+7. 
**Don't alias a join column to match the other side**: When equi-joining + with `on="key"`, renaming the join column on one side via `.alias("key")` + in a fresh projection creates a schema where one side's `key` is + qualified (`?table?.key`) and the other is unqualified. The join then + fails with `Schema contains qualified field name ... and unqualified + field name ... which would be ambiguous`. Use `left_on=`/`right_on=` with + the native names, or use `join_on(...)` with an explicit equality. + ```python + # WRONG -- alias on one side produces ambiguous schema after join + failed = orders.select(col("o_orderkey").alias("l_orderkey")) + li.join(failed, on="l_orderkey") # ambiguous l_orderkey error + + # CORRECT -- keep native names, use left_on/right_on + failed = orders.select("o_orderkey") + li.join(failed, left_on="l_orderkey", right_on="o_orderkey") + + # ALSO CORRECT -- explicit predicate via join_on + # (note: join_on keeps both key columns in the output, unlike on="key") + li.join_on(failed, col("l_orderkey") == col("o_orderkey")) + ``` + +## Idiomatic Patterns + +### Fluent Chaining + +```python +result = ( + ctx.read_parquet("data.parquet") + .filter(col("year") >= 2020) + .select(col("region"), col("sales")) + .aggregate(["region"], [F.sum(col("sales")).alias("total")]) + .sort(col("total").sort(ascending=False)) + .limit(10) +) +result.show() +``` + +### Using Variables as CTEs + +Instead of SQL CTEs (`WITH ... AS`), assign intermediate DataFrames to +variables: + +```python +base = ctx.read_parquet("orders.parquet").filter(col("status") == "shipped") +by_region = base.aggregate(["region"], [F.sum(col("amount")).alias("total")]) +top_regions = by_region.filter(col("total") > 10000) +``` + +### Reusing Expressions as Variables + +Just like DataFrames, expressions (`Expr`) can be stored in variables and used +anywhere an `Expr` is expected. 
This is useful for building up complex
+expressions or reusing a computed value across multiple operations:
+
+```python
+# Build an expression and reuse it
+disc_price = col("price") * (lit(1) - col("discount"))
+df = df.select(
+    col("id"),
+    disc_price.alias("disc_price"),
+    (disc_price * (lit(1) + col("tax"))).alias("total"),
+)
+
+# Use a collected scalar as an expression
+max_val = result_df.collect_column("max_price")[0] # PyArrow scalar
+cutoff = lit(max_val) - lit(pa.scalar((0, 90, 0), type=pa.month_day_nano_interval()))
+df = df.filter(col("ship_date") <= cutoff) # cutoff is already an Expr
+```
+
+**Important**: Do not wrap an `Expr` in `lit()`. `lit()` is for converting
+Python/PyArrow values into expressions. If a value is already an `Expr`, use it
+directly.
+
+### Window Functions for Scalar Subqueries
+
+Where SQL uses a correlated scalar subquery, the idiomatic DataFrame approach
+is a window function:
+
+```sql
+-- SQL scalar subquery
+SELECT *, (SELECT SUM(b) FROM t WHERE t.group = s.group) AS group_total FROM s
+```
+
+```python
+# DataFrame: window function
+win = Window(partition_by=[col("group")])
+df = df.with_column("group_total", F.sum(col("b")).over(win))
+```
+
+### Semi/Anti Joins for EXISTS / NOT EXISTS
+
+```python
+# SQL: WHERE EXISTS (SELECT 1 FROM other WHERE other.key = main.key)
+# DataFrame:
+result = main.join(other, on="key", how="semi")
+
+# SQL: WHERE NOT EXISTS (SELECT 1 FROM other WHERE other.key = main.key)
+# DataFrame:
+result = main.join(other, on="key", how="anti")
+```
+
+### Computed Columns
+
+```python
+# Add computed columns while keeping all originals
+df = df.with_column("full_name", F.concat(col("first"), lit(" "), col("last")))
+df = df.with_column("discounted", col("price") * lit(0.9))
+```
+
+## Available Functions (Categorized)
+
+The `functions` module (imported as `F`) provides 290+ functions. 
Key categories: + +**Aggregate**: `sum`, `avg`, `min`, `max`, `count`, `count_star`, `median`, +`stddev`, `stddev_pop`, `var_samp`, `var_pop`, `corr`, `covar`, `approx_distinct`, +`approx_median`, `approx_percentile_cont`, `array_agg`, `string_agg`, +`first_value`, `last_value`, `bit_and`, `bit_or`, `bit_xor`, `bool_and`, +`bool_or`, `grouping`, `regr_*` (9 regression functions) + +**Window**: `row_number`, `rank`, `dense_rank`, `percent_rank`, `cume_dist`, +`ntile`, `lag`, `lead`, `first_value`, `last_value`, `nth_value` + +**String**: `length`, `lower`, `upper`, `trim`, `ltrim`, `rtrim`, `lpad`, +`rpad`, `starts_with`, `ends_with`, `contains`, `substr`, `substring`, +`replace`, `reverse`, `repeat`, `split_part`, `concat`, `concat_ws`, +`initcap`, `ascii`, `chr`, `left`, `right`, `strpos`, `translate`, `overlay`, +`levenshtein` + +`F.substr(str, start)` takes **only two arguments** and returns the tail of +the string from `start` onward — passing a third length argument raises +`TypeError: substr() takes 2 positional arguments but 3 were given`. For the +SQL-style 3-arg form (`SUBSTRING(str FROM start FOR length)`), use +`F.substring(col("s"), lit(start), lit(length))`. For a fixed-length prefix, +`F.left(col("s"), lit(n))` is cleanest. 
+ +```python +# WRONG — substr does not accept a length argument +F.substr(col("c_phone"), lit(1), lit(2)) +# CORRECT +F.substring(col("c_phone"), lit(1), lit(2)) # explicit length +F.left(col("c_phone"), lit(2)) # prefix shortcut +``` + +**Math**: `abs`, `ceil`, `floor`, `round`, `trunc`, `sqrt`, `cbrt`, `exp`, +`ln`, `log`, `log2`, `log10`, `pow`, `signum`, `pi`, `random`, `factorial`, +`gcd`, `lcm`, `greatest`, `least`, sin/cos/tan and inverse/hyperbolic variants + +**Date/Time**: `now`, `today`, `current_date`, `current_time`, +`current_timestamp`, `date_part`, `date_trunc`, `date_bin`, `extract`, +`to_timestamp`, `to_timestamp_millis`, `to_timestamp_micros`, +`to_timestamp_nanos`, `to_timestamp_seconds`, `to_unixtime`, `from_unixtime`, +`make_date`, `make_time`, `to_date`, `to_time`, `to_local_time`, `date_format` + +**Conditional**: `case`, `when`, `coalesce`, `nullif`, `ifnull`, `nvl`, `nvl2` + +**Array/List**: `array`, `make_array`, `array_agg`, `array_length`, +`array_element`, `array_slice`, `array_append`, `array_prepend`, +`array_concat`, `array_has`, `array_has_all`, `array_has_any`, `array_position`, +`array_remove`, `array_distinct`, `array_sort`, `array_reverse`, `flatten`, +`array_to_string`, `array_intersect`, `array_union`, `array_except`, +`generate_series` +(Most `array_*` functions also have `list_*` aliases.) 
+ +**Struct/Map**: `struct`, `named_struct`, `get_field`, `make_map`, `map_keys`, +`map_values`, `map_entries`, `map_extract` + +**Regex**: `regexp_like`, `regexp_match`, `regexp_replace`, `regexp_count`, +`regexp_instr` + +**Hash**: `md5`, `sha224`, `sha256`, `sha384`, `sha512`, `digest` + +**Type**: `arrow_typeof`, `arrow_cast`, `arrow_metadata` + +**Other**: `in_list`, `order_by`, `alias`, `col`, `encode`, `decode`, +`to_hex`, `to_char`, `uuid`, `version`, `bit_length`, `octet_length` diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index b2db144e8..a7a497dab 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -48,4 +48,5 @@ benchmarks/tpch/create_tables.sql .cargo/config.toml **/.cargo/config.toml uv.lock -examples/tpch/answers_sf1/*.tbl \ No newline at end of file +examples/tpch/answers_sf1/*.tbl +SKILL.md \ No newline at end of file diff --git a/examples/tpch/q01_pricing_summary_report.py b/examples/tpch/q01_pricing_summary_report.py index 3f97f00dc..105f1632d 100644 --- a/examples/tpch/q01_pricing_summary_report.py +++ b/examples/tpch/q01_pricing_summary_report.py @@ -27,6 +27,30 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + l_returnflag, + l_linestatus, + sum(l_quantity) as sum_qty, + sum(l_extendedprice) as sum_base_price, + sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + avg(l_quantity) as avg_qty, + avg(l_extendedprice) as avg_price, + avg(l_discount) as avg_disc, + count(*) as count_order + from + lineitem + where + l_shipdate <= date '1998-12-01' - interval '90 days' + group by + l_returnflag, + l_linestatus + order by + l_returnflag, + l_linestatus; """ import pyarrow as pa @@ -58,31 +82,25 @@ # Aggregate the results +disc_price = col("l_extendedprice") * (lit(1) - col("l_discount")) + df = df.aggregate( - [col("l_returnflag"), col("l_linestatus")], + ["l_returnflag", "l_linestatus"], [ F.sum(col("l_quantity")).alias("sum_qty"), F.sum(col("l_extendedprice")).alias("sum_base_price"), - F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias( - "sum_disc_price" - ), - F.sum( - col("l_extendedprice") - * (lit(1) - col("l_discount")) - * (lit(1) + col("l_tax")) - ).alias("sum_charge"), + F.sum(disc_price).alias("sum_disc_price"), + F.sum(disc_price * (lit(1) + col("l_tax"))).alias("sum_charge"), F.avg(col("l_quantity")).alias("avg_qty"), F.avg(col("l_extendedprice")).alias("avg_price"), F.avg(col("l_discount")).alias("avg_disc"), - F.count(col("l_returnflag")).alias( - "count_order" - ), # Counting any column should return same result + F.count_star().alias("count_order"), ], ) # Sort per the expected result -df = df.sort(col("l_returnflag").sort(), col("l_linestatus").sort()) +df = df.sort_by("l_returnflag", "l_linestatus") # Note: There appears to be a discrepancy between what is returned here and what is in the generated # answers file for the case of return flag N and line status O, but I did not investigate further. 
diff --git a/examples/tpch/q02_minimum_cost_supplier.py b/examples/tpch/q02_minimum_cost_supplier.py index 47961d2ef..c5c6b9c0b 100644 --- a/examples/tpch/q02_minimum_cost_supplier.py +++ b/examples/tpch/q02_minimum_cost_supplier.py @@ -27,6 +27,52 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + s_acctbal, + s_name, + n_name, + p_partkey, + p_mfgr, + s_address, + s_phone, + s_comment + from + part, + supplier, + partsupp, + nation, + region + where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and p_size = 15 + and p_type like '%BRASS' + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + and ps_supplycost = ( + select + min(ps_supplycost) + from + partsupp, + supplier, + nation, + region + where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + ) + order by + s_acctbal desc, + n_name, + s_name, + p_partkey limit 100; """ import datafusion @@ -67,35 +113,30 @@ "r_regionkey", "r_name" ) -# Filter down parts. Part names contain the type of interest, so we can use strpos to find where -# in the p_type column the word is. `strpos` will return 0 if not found, otherwise the position -# in the string where it is located. +# Filter down parts. The reference SQL uses ``p_type like '%BRASS'`` which +# is an ``ends_with`` check; use the dedicated string function rather than +# a manual substring match. 
df_part = df_part.filter( - F.strpos(col("p_type"), lit(TYPE_OF_INTEREST)) > lit(0) -).filter(col("p_size") == lit(SIZE_OF_INTEREST)) + F.ends_with(col("p_type"), lit(TYPE_OF_INTEREST)), + col("p_size") == SIZE_OF_INTEREST, +) # Filter regions down to the one of interest -df_region = df_region.filter(col("r_name") == lit(REGION_OF_INTEREST)) +df_region = df_region.filter(col("r_name") == REGION_OF_INTEREST) # Now that we have the region, find suppliers in that region. Suppliers are tied to their nation # and nations are tied to the region. -df_nation = df_nation.join( - df_region, left_on=["n_regionkey"], right_on=["r_regionkey"], how="inner" -) -df_supplier = df_supplier.join( - df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner" -) +df_nation = df_nation.join(df_region, left_on="n_regionkey", right_on="r_regionkey") +df_supplier = df_supplier.join(df_nation, left_on="s_nationkey", right_on="n_nationkey") # Now that we know who the potential suppliers are for the part, we can limit out part # supplies table down. We can further join down to the specific parts we've identified # as matching the request -df = df_partsupp.join( - df_supplier, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="inner" -) +df = df_partsupp.join(df_supplier, left_on="ps_suppkey", right_on="s_suppkey") # Locate the minimum cost across all suppliers. 
There are multiple ways you could do this, # but one way is to create a window function across all suppliers, find the minimum, and @@ -112,9 +153,9 @@ ), ) -df = df.filter(col("min_cost") == col("ps_supplycost")) - -df = df.join(df_part, left_on=["ps_partkey"], right_on=["p_partkey"], how="inner") +df = df.filter(col("min_cost") == col("ps_supplycost")).join( + df_part, left_on="ps_partkey", right_on="p_partkey" +) # From the problem statement, these are the values we wish to output @@ -132,12 +173,10 @@ # Sort and display 100 entries df = df.sort( col("s_acctbal").sort(ascending=False), - col("n_name").sort(), - col("s_name").sort(), - col("p_partkey").sort(), -) - -df = df.limit(100) + "n_name", + "s_name", + "p_partkey", +).limit(100) # Show results diff --git a/examples/tpch/q03_shipping_priority.py b/examples/tpch/q03_shipping_priority.py index fc1231e0a..880c7435f 100644 --- a/examples/tpch/q03_shipping_priority.py +++ b/examples/tpch/q03_shipping_priority.py @@ -25,6 +25,31 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + l_orderkey, + sum(l_extendedprice * (1 - l_discount)) as revenue, + o_orderdate, + o_shippriority + from + customer, + orders, + lineitem + where + c_mktsegment = 'BUILDING' + and c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate < date '1995-03-15' + and l_shipdate > date '1995-03-15' + group by + l_orderkey, + o_orderdate, + o_shippriority + order by + revenue desc, + o_orderdate limit 10; """ from datafusion import SessionContext, col, lit @@ -50,20 +75,20 @@ # Limit dataframes to the rows of interest -df_customer = df_customer.filter(col("c_mktsegment") == lit(SEGMENT_OF_INTEREST)) +df_customer = df_customer.filter(col("c_mktsegment") == SEGMENT_OF_INTEREST) df_orders = df_orders.filter(col("o_orderdate") < lit(DATE_OF_INTEREST)) df_lineitem = df_lineitem.filter(col("l_shipdate") > lit(DATE_OF_INTEREST)) # Join all 3 dataframes -df = df_customer.join( - df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner" -).join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner") +df = df_customer.join(df_orders, left_on="c_custkey", right_on="o_custkey").join( + df_lineitem, left_on="o_orderkey", right_on="l_orderkey" +) # Compute the revenue df = df.aggregate( - [col("l_orderkey")], + ["l_orderkey"], [ F.first_value(col("o_orderdate")).alias("o_orderdate"), F.first_value(col("o_shippriority")).alias("o_shippriority"), @@ -71,17 +96,13 @@ ], ) -# Sort by priority - -df = df.sort(col("revenue").sort(ascending=False), col("o_orderdate").sort()) - -# Only return 10 results +# Sort by priority, take 10, and project in the order expected by the spec. 
-df = df.limit(10) - -# Change the order that the columns are reported in just to match the spec - -df = df.select("l_orderkey", "revenue", "o_orderdate", "o_shippriority") +df = ( + df.sort(col("revenue").sort(ascending=False), "o_orderdate") + .limit(10) + .select("l_orderkey", "revenue", "o_orderdate", "o_shippriority") +) # Show result diff --git a/examples/tpch/q04_order_priority_checking.py b/examples/tpch/q04_order_priority_checking.py index 426338aea..6f11c1383 100644 --- a/examples/tpch/q04_order_priority_checking.py +++ b/examples/tpch/q04_order_priority_checking.py @@ -24,18 +24,40 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + o_orderpriority, + count(*) as order_count + from + orders + where + o_orderdate >= date '1993-07-01' + and o_orderdate < date '1993-07-01' + interval '3' month + and exists ( + select + * + from + lineitem + where + l_orderkey = o_orderkey + and l_commitdate < l_receiptdate + ) + group by + o_orderpriority + order by + o_orderpriority; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -# Ideally we could put 3 months into the interval. See note below. 
-INTERVAL_DAYS = 92 -DATE_OF_INTEREST = "1993-07-01" +QUARTER_START = date(1993, 7, 1) +QUARTER_END = date(1993, 10, 1) # Load the dataframes we need @@ -48,36 +70,23 @@ "l_orderkey", "l_commitdate", "l_receiptdate" ) -# Create a date object from the string -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, INTERVAL_DAYS, 0), type=pa.month_day_nano_interval()) - -# Limit results to cases where commitment date before receipt date -# Aggregate the results so we only get one row to join with the order table. -# Alternately, and likely more idiomatic is instead of `.aggregate` you could -# do `.select("l_orderkey").distinct()`. The goal here is to show -# multiple examples of how to use Data Fusion. -df_lineitem = df_lineitem.filter(col("l_commitdate") < col("l_receiptdate")).aggregate( - [col("l_orderkey")], [] +# Keep only orders in the quarter of interest, then restrict to those that +# have at least one late lineitem via a semi join (the DataFrame form of +# ``EXISTS`` from the reference SQL). +df_orders = df_orders.filter( + col("o_orderdate") >= lit(QUARTER_START), + col("o_orderdate") < lit(QUARTER_END), ) -# Limit orders to date range of interest -df_orders = df_orders.filter(col("o_orderdate") >= lit(date)).filter( - col("o_orderdate") < lit(date) + lit(interval) -) +late_lineitems = df_lineitem.filter(col("l_commitdate") < col("l_receiptdate")) -# Perform the join to find only orders for which there are lineitems outside of expected range df = df_orders.join( - df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner" + late_lineitems, left_on="o_orderkey", right_on="l_orderkey", how="semi" ) -# Based on priority, find the number of entries -df = df.aggregate( - [col("o_orderpriority")], [F.count(col("o_orderpriority")).alias("order_count")] +# Count the number of orders in each priority group and sort. 
+df = df.aggregate(["o_orderpriority"], [F.count_star().alias("order_count")]).sort_by( + "o_orderpriority" ) -# Sort the results -df = df.sort(col("o_orderpriority").sort()) - df.show() diff --git a/examples/tpch/q05_local_supplier_volume.py b/examples/tpch/q05_local_supplier_volume.py index fa2b01dea..bfdba5d4c 100644 --- a/examples/tpch/q05_local_supplier_volume.py +++ b/examples/tpch/q05_local_supplier_volume.py @@ -27,23 +27,45 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + n_name, + sum(l_extendedprice * (1 - l_discount)) as revenue + from + customer, + orders, + lineitem, + supplier, + nation, + region + where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and l_suppkey = s_suppkey + and c_nationkey = s_nationkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'ASIA' + and o_orderdate >= date '1994-01-01' + and o_orderdate < date '1994-01-01' + interval '1' year + group by + n_name + order by + revenue desc; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -DATE_OF_INTEREST = "1994-01-01" -INTERVAL_DAYS = 365 +YEAR_START = date(1994, 1, 1) +YEAR_END = date(1995, 1, 1) REGION_OF_INTEREST = "ASIA" -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, INTERVAL_DAYS, 0), type=pa.month_day_nano_interval()) - # Load the dataframes we need ctx = SessionContext() @@ -68,38 +90,32 @@ ) # Restrict dataframes to cases of interest -df_orders = df_orders.filter(col("o_orderdate") >= lit(date)).filter( - col("o_orderdate") < lit(date) + lit(interval) +df_orders = df_orders.filter( + col("o_orderdate") >= lit(YEAR_START), + col("o_orderdate") < 
lit(YEAR_END), ) -df_region = df_region.filter(col("r_name") == lit(REGION_OF_INTEREST)) +df_region = df_region.filter(col("r_name") == REGION_OF_INTEREST) # Join all the dataframes df = ( - df_customer.join( - df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner" - ) - .join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner") + df_customer.join(df_orders, left_on="c_custkey", right_on="o_custkey") + .join(df_lineitem, left_on="o_orderkey", right_on="l_orderkey") .join( df_supplier, left_on=["l_suppkey", "c_nationkey"], right_on=["s_suppkey", "s_nationkey"], - how="inner", ) - .join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner") - .join(df_region, left_on=["n_regionkey"], right_on=["r_regionkey"], how="inner") + .join(df_nation, left_on="s_nationkey", right_on="n_nationkey") + .join(df_region, left_on="n_regionkey", right_on="r_regionkey") ) -# Compute the final result +# Compute the final result, then sort in descending order. df = df.aggregate( - [col("n_name")], + ["n_name"], [F.sum(col("l_extendedprice") * (lit(1.0) - col("l_discount"))).alias("revenue")], -) - -# Sort in descending order - -df = df.sort(col("revenue").sort(ascending=False)) +).sort(col("revenue").sort(ascending=False)) df.show() diff --git a/examples/tpch/q06_forecasting_revenue_change.py b/examples/tpch/q06_forecasting_revenue_change.py index 1de5848b1..ed54d22a4 100644 --- a/examples/tpch/q06_forecasting_revenue_change.py +++ b/examples/tpch/q06_forecasting_revenue_change.py @@ -27,28 +27,34 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + sum(l_extendedprice * l_discount) as revenue + from + lineitem + where + l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + and l_discount between 0.06 - 0.01 and 0.06 + 0.01 + and l_quantity < 24; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path # Variables from the example query -DATE_OF_INTEREST = "1994-01-01" +YEAR_START = date(1994, 1, 1) +YEAR_END = date(1995, 1, 1) DISCOUT = 0.06 DELTA = 0.01 QUANTITY = 24 -INTERVAL_DAYS = 365 - -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, INTERVAL_DAYS, 0), type=pa.month_day_nano_interval()) - # Load the dataframes we need ctx = SessionContext() @@ -59,12 +65,11 @@ # Filter down to lineitems of interest -df = ( - df_lineitem.filter(col("l_shipdate") >= lit(date)) - .filter(col("l_shipdate") < lit(date) + lit(interval)) - .filter(col("l_discount") >= lit(DISCOUT) - lit(DELTA)) - .filter(col("l_discount") <= lit(DISCOUT) + lit(DELTA)) - .filter(col("l_quantity") < lit(QUANTITY)) +df = df_lineitem.filter( + col("l_shipdate") >= lit(YEAR_START), + col("l_shipdate") < lit(YEAR_END), + col("l_discount").between(lit(DISCOUT - DELTA), lit(DISCOUT + DELTA)), + col("l_quantity") < QUANTITY, ) # Add up all the "lost" revenue diff --git a/examples/tpch/q07_volume_shipping.py b/examples/tpch/q07_volume_shipping.py index ff2f891f1..df1c2ae0d 100644 --- a/examples/tpch/q07_volume_shipping.py +++ b/examples/tpch/q07_volume_shipping.py @@ -26,9 +26,51 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + supp_nation, + cust_nation, + l_year, + sum(volume) as revenue + from + ( + select + n1.n_name as supp_nation, + n2.n_name as cust_nation, + extract(year from l_shipdate) as l_year, + l_extendedprice * (1 - l_discount) as volume + from + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2 + where + s_suppkey = l_suppkey + and o_orderkey = l_orderkey + and c_custkey = o_custkey + and s_nationkey = n1.n_nationkey + and c_nationkey = n2.n_nationkey + and ( + (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') + or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') + ) + and l_shipdate between date '1995-01-01' and date '1996-12-31' + ) as shipping + group by + supp_nation, + cust_nation, + l_year + order by + supp_nation, + cust_nation, + l_year; """ -from datetime import datetime +from datetime import date import pyarrow as pa from datafusion import SessionContext, col, lit @@ -40,11 +82,8 @@ nation_1 = lit("FRANCE") nation_2 = lit("GERMANY") -START_DATE = "1995-01-01" -END_DATE = "1996-12-31" - -start_date = lit(datetime.strptime(START_DATE, "%Y-%m-%d").date()) -end_date = lit(datetime.strptime(END_DATE, "%Y-%m-%d").date()) +START_DATE = date(1995, 1, 1) +END_DATE = date(1996, 12, 31) # Load the dataframes we need @@ -69,60 +108,44 @@ # Filter to time of interest -df_lineitem = df_lineitem.filter(col("l_shipdate") >= start_date).filter( - col("l_shipdate") <= end_date +df_lineitem = df_lineitem.filter( + col("l_shipdate") >= lit(START_DATE), col("l_shipdate") <= lit(END_DATE) ) -# A simpler way to do the following operation is to use a filter, but we also want to demonstrate -# how to use case statements. Here we are assigning `n_name` to be itself when it is either of -# the two nations of interest. Since there is no `otherwise()` statement, any values that do -# not match these will result in a null value and then get filtered out. 
-# -# To do the same using a simple filter would be: -# df_nation = df_nation.filter((F.col("n_name") == nation_1) | (F.col("n_name") == nation_2)) # noqa: ERA001 -df_nation = df_nation.with_column( - "n_name", - F.case(col("n_name")) - .when(nation_1, col("n_name")) - .when(nation_2, col("n_name")) - .end(), -).filter(~col("n_name").is_null()) +# Limit the nation table to the two nations of interest. +df_nation = df_nation.filter(F.in_list(col("n_name"), [nation_1, nation_2])) # Limit suppliers to either nation df_supplier = df_supplier.join( - df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner" -).select(col("s_suppkey"), col("n_name").alias("supp_nation")) + df_nation, left_on="s_nationkey", right_on="n_nationkey" +).select("s_suppkey", col("n_name").alias("supp_nation")) # Limit customers to either nation df_customer = df_customer.join( - df_nation, left_on=["c_nationkey"], right_on=["n_nationkey"], how="inner" -).select(col("c_custkey"), col("n_name").alias("cust_nation")) + df_nation, left_on="c_nationkey", right_on="n_nationkey" +).select("c_custkey", col("n_name").alias("cust_nation")) # Join up all the data frames from line items, and make sure the supplier and customer are in # different nations. 
df = ( - df_lineitem.join( - df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner" - ) - .join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner") - .join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner") + df_lineitem.join(df_orders, left_on="l_orderkey", right_on="o_orderkey") + .join(df_customer, left_on="o_custkey", right_on="c_custkey") + .join(df_supplier, left_on="l_suppkey", right_on="s_suppkey") .filter(col("cust_nation") != col("supp_nation")) ) # Extract out two values for every line item -df = df.with_column( - "l_year", F.datepart(lit("year"), col("l_shipdate")).cast(pa.int32()) -).with_column("volume", col("l_extendedprice") * (lit(1.0) - col("l_discount"))) +df = df.with_columns( + l_year=F.datepart(lit("year"), col("l_shipdate")).cast(pa.int32()), + volume=col("l_extendedprice") * (lit(1.0) - col("l_discount")), +) -# Aggregate the results +# Aggregate and sort per the spec. df = df.aggregate( - [col("supp_nation"), col("cust_nation"), col("l_year")], + ["supp_nation", "cust_nation", "l_year"], [F.sum(col("volume")).alias("revenue")], -) - -# Sort based on problem statement requirements -df = df.sort(col("supp_nation").sort(), col("cust_nation").sort(), col("l_year").sort()) +).sort_by("supp_nation", "cust_nation", "l_year") df.show() diff --git a/examples/tpch/q08_market_share.py b/examples/tpch/q08_market_share.py index 4bf50efba..dd7bacedb 100644 --- a/examples/tpch/q08_market_share.py +++ b/examples/tpch/q08_market_share.py @@ -25,24 +25,61 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + o_year, + sum(case + when nation = 'BRAZIL' then volume + else 0 + end) / sum(volume) as mkt_share + from + ( + select + extract(year from o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) as volume, + n2.n_name as nation + from + part, + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2, + region + where + p_partkey = l_partkey + and s_suppkey = l_suppkey + and l_orderkey = o_orderkey + and o_custkey = c_custkey + and c_nationkey = n1.n_nationkey + and n1.n_regionkey = r_regionkey + and r_name = 'AMERICA' + and s_nationkey = n2.n_nationkey + and o_orderdate between date '1995-01-01' and date '1996-12-31' + and p_type = 'ECONOMY ANODIZED STEEL' + ) as all_nations + group by + o_year + order by + o_year; """ -from datetime import datetime +from datetime import date import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -supplier_nation = lit("BRAZIL") -customer_region = lit("AMERICA") -part_of_interest = lit("ECONOMY ANODIZED STEEL") - -START_DATE = "1995-01-01" -END_DATE = "1996-12-31" +supplier_nation = "BRAZIL" +customer_region = "AMERICA" +part_of_interest = "ECONOMY ANODIZED STEEL" -start_date = lit(datetime.strptime(START_DATE, "%Y-%m-%d").date()) -end_date = lit(datetime.strptime(END_DATE, "%Y-%m-%d").date()) +START_DATE = date(1995, 1, 1) +END_DATE = date(1996, 12, 31) # Load the dataframes we need @@ -74,105 +111,57 @@ # Limit orders to those in the specified range -df_orders = df_orders.filter(col("o_orderdate") >= start_date).filter( - col("o_orderdate") <= end_date -) - -# Part 1: Find customers in the region - -# We want customers in region specified by region_of_interest. This will be used to compute -# the total sales of the part of interest. We want to know of those sales what fraction -# was supplied by the nation of interest. 
There is no guarantee that the nation of -# interest is within the region of interest. - -# First we find all the sales that make up the basis. - -df_regional_customers = df_region.filter(col("r_name") == customer_region) - -# After this join we have all of the possible sales nations -df_regional_customers = df_regional_customers.join( - df_nation, left_on=["r_regionkey"], right_on=["n_regionkey"], how="inner" -) - -# Now find the possible customers -df_regional_customers = df_regional_customers.join( - df_customer, left_on=["n_nationkey"], right_on=["c_nationkey"], how="inner" -) - -# Next find orders for these customers -df_regional_customers = df_regional_customers.join( - df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner" -) - -# Find all line items from these orders -df_regional_customers = df_regional_customers.join( - df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner" -) - -# Limit to the part of interest -df_regional_customers = df_regional_customers.join( - df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="inner" -) - -# Compute the volume for each line item -df_regional_customers = df_regional_customers.with_column( - "volume", col("l_extendedprice") * (lit(1.0) - col("l_discount")) -) - -# Part 2: Find suppliers from the nation - -# Now that we have all of the sales of that part in the specified region, we need -# to determine which of those came from suppliers in the nation we are interested in. 
- -df_national_suppliers = df_nation.filter(col("n_name") == supplier_nation) - -# Determine the suppliers by the limited nation key we have in our single row df above -df_national_suppliers = df_national_suppliers.join( - df_supplier, left_on=["n_nationkey"], right_on=["s_nationkey"], how="inner" -) - -# When we join to the customer dataframe, we don't want to confuse other columns, so only -# select the supplier key that we need -df_national_suppliers = df_national_suppliers.select("s_suppkey") - - -# Part 3: Combine suppliers and customers and compute the market share - -# Now we can do a left outer join on the suppkey. Those line items from other suppliers -# will get a null value. We can check for the existence of this null to compute a volume -# column only from suppliers in the nation we are evaluating. - -df = df_regional_customers.join( - df_national_suppliers, left_on=["l_suppkey"], right_on=["s_suppkey"], how="left" -) - -# Use a case statement to compute the volume sold by suppliers in the nation of interest -df = df.with_column( - "national_volume", - F.case(col("s_suppkey").is_null()) - .when(lit(value=False), col("volume")) - .otherwise(lit(0.0)), -) - -df = df.with_column( - "o_year", F.datepart(lit("year"), col("o_orderdate")).cast(pa.int32()) -) - - -# Lastly, sum up the results - -df = df.aggregate( - [col("o_year")], - [ - F.sum(col("volume")).alias("volume"), - F.sum(col("national_volume")).alias("national_volume"), - ], +df_orders = df_orders.filter( + col("o_orderdate") >= lit(START_DATE), col("o_orderdate") <= lit(END_DATE) +) + +# Pair each supplier with its nation name so every regional-customer row +# below carries the supplier's nation and can be filtered inside the +# aggregate with ``F.sum(..., filter=...)``. 
+ +df_supplier_with_nation = df_supplier.join( + df_nation, left_on="s_nationkey", right_on="n_nationkey" +).select("s_suppkey", col("n_name").alias("supp_nation")) + +# Build every (part, lineitem, order, customer) row for customers in the +# target region ordering the target part. Each row carries the supplier's +# nation so we can aggregate on it below. + +df = ( + df_region.filter(col("r_name") == customer_region) + .join(df_nation, left_on="r_regionkey", right_on="n_regionkey") + .join(df_customer, left_on="n_nationkey", right_on="c_nationkey") + .join(df_orders, left_on="c_custkey", right_on="o_custkey") + .join(df_lineitem, left_on="o_orderkey", right_on="l_orderkey") + .join(df_part, left_on="l_partkey", right_on="p_partkey") + .join(df_supplier_with_nation, left_on="l_suppkey", right_on="s_suppkey") + .with_columns( + volume=col("l_extendedprice") * (lit(1.0) - col("l_discount")), + o_year=F.datepart(lit("year"), col("o_orderdate")).cast(pa.int32()), + ) +) + +# Aggregate the total and national volumes per year via the ``filter`` +# kwarg on ``F.sum`` (DataFrame form of SQL ``sum(... ) FILTER (WHERE ...)``). +# ``coalesce`` handles the case where no sale came from the target nation +# for a given year. 
+df = ( + df.aggregate( + ["o_year"], + [ + F.sum(col("volume"), filter=col("supp_nation") == supplier_nation).alias( + "national_volume" + ), + F.sum(col("volume")).alias("total_volume"), + ], + ) + .select( + "o_year", + (F.coalesce(col("national_volume"), lit(0.0)) / col("total_volume")).alias( + "mkt_share" + ), + ) + .sort_by("o_year") ) -df = df.select( - col("o_year"), (F.col("national_volume") / F.col("volume")).alias("mkt_share") -) - -df = df.sort(col("o_year").sort()) - df.show() diff --git a/examples/tpch/q09_product_type_profit_measure.py b/examples/tpch/q09_product_type_profit_measure.py index e2abbd095..ec68a2ab7 100644 --- a/examples/tpch/q09_product_type_profit_measure.py +++ b/examples/tpch/q09_product_type_profit_measure.py @@ -27,6 +27,41 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + nation, + o_year, + sum(amount) as sum_profit + from + ( + select + n_name as nation, + extract(year from o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount + from + part, + supplier, + lineitem, + partsupp, + orders, + nation + where + s_suppkey = l_suppkey + and ps_suppkey = l_suppkey + and ps_partkey = l_partkey + and p_partkey = l_partkey + and o_orderkey = l_orderkey + and s_nationkey = n_nationkey + and p_name like '%green%' + ) as profit + group by + nation, + o_year + order by + nation, + o_year desc; """ import pyarrow as pa @@ -34,7 +69,7 @@ from datafusion import functions as F from util import get_data_path -part_color = lit("green") +part_color = "green" # Load the dataframes we need @@ -62,37 +97,35 @@ "n_nationkey", "n_name", "n_regionkey" ) -# Limit possible parts to the color specified -df = df_part.filter(F.strpos(col("p_name"), part_color) > lit(0)) - -# We have a series of joins that get us to 
limit down to the line items we need -df = df.join(df_lineitem, left_on=["p_partkey"], right_on=["l_partkey"], how="inner") -df = df.join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner") -df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner") -df = df.join( - df_partsupp, - left_on=["l_suppkey", "l_partkey"], - right_on=["ps_suppkey", "ps_partkey"], - how="inner", +# Limit possible parts to the color specified, then walk the joins down to the +# line-item rows we need and attach the supplier's nation. ``F.contains`` +# maps directly to the reference SQL's ``p_name like '%green%'``. +df = ( + df_part.filter(F.contains(col("p_name"), lit(part_color))) + .join(df_lineitem, left_on="p_partkey", right_on="l_partkey") + .join(df_supplier, left_on="l_suppkey", right_on="s_suppkey") + .join(df_orders, left_on="l_orderkey", right_on="o_orderkey") + .join( + df_partsupp, + left_on=["l_suppkey", "l_partkey"], + right_on=["ps_suppkey", "ps_partkey"], + ) + .join(df_nation, left_on="s_nationkey", right_on="n_nationkey") ) -df = df.join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner") # Compute the intermediate values and limit down to the expressions we need df = df.select( col("n_name").alias("nation"), F.datepart(lit("year"), col("o_orderdate")).cast(pa.int32()).alias("o_year"), ( - (col("l_extendedprice") * (lit(1) - col("l_discount"))) - - (col("ps_supplycost") * col("l_quantity")) + col("l_extendedprice") * (lit(1) - col("l_discount")) + - col("ps_supplycost") * col("l_quantity") ).alias("amount"), ) -# Sum up the values by nation and year -df = df.aggregate( - [col("nation"), col("o_year")], [F.sum(col("amount")).alias("profit")] +# Sum up the values by nation and year, then sort per the spec. 
+df = df.aggregate(["nation", "o_year"], [F.sum(col("amount")).alias("profit")]).sort( + "nation", col("o_year").sort(ascending=False) ) -# Sort according to the problem specification -df = df.sort(col("nation").sort(), col("o_year").sort(ascending=False)) - df.show() diff --git a/examples/tpch/q10_returned_item_reporting.py b/examples/tpch/q10_returned_item_reporting.py index ed822e264..e6532517e 100644 --- a/examples/tpch/q10_returned_item_reporting.py +++ b/examples/tpch/q10_returned_item_reporting.py @@ -27,20 +27,50 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + c_custkey, + c_name, + sum(l_extendedprice * (1 - l_discount)) as revenue, + c_acctbal, + n_name, + c_address, + c_phone, + c_comment + from + customer, + orders, + lineitem, + nation + where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate >= date '1993-10-01' + and o_orderdate < date '1993-10-01' + interval '3' month + and l_returnflag = 'R' + and c_nationkey = n_nationkey + group by + c_custkey, + c_name, + c_acctbal, + c_phone, + n_name, + c_address, + c_comment + order by + revenue desc limit 20; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -DATE_START_OF_QUARTER = "1993-10-01" - -date_start_of_quarter = lit(datetime.strptime(DATE_START_OF_QUARTER, "%Y-%m-%d").date()) - -interval_one_quarter = lit(pa.scalar((0, 92, 0), type=pa.month_day_nano_interval())) +QUARTER_START = date(1993, 10, 1) +QUARTER_END = date(1994, 1, 1) # Load the dataframes we need @@ -66,44 +96,40 @@ ) # limit to returns -df_lineitem = df_lineitem.filter(col("l_returnflag") == lit("R")) +df_lineitem = df_lineitem.filter(col("l_returnflag") == "R") # 
Rather than aggregate by all of the customer fields as you might do looking at the specification, # we can aggregate by o_custkey and then join in the customer data at the end. -df = df_orders.filter(col("o_orderdate") >= date_start_of_quarter).filter( - col("o_orderdate") < date_start_of_quarter + interval_one_quarter +df = ( + df_orders.filter( + col("o_orderdate") >= lit(QUARTER_START), + col("o_orderdate") < lit(QUARTER_END), + ) + .join(df_lineitem, left_on="o_orderkey", right_on="l_orderkey") + .aggregate( + ["o_custkey"], + [F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias("revenue")], + ) ) -df = df.join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner") - -# Compute the revenue -df = df.aggregate( - [col("o_custkey")], - [F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias("revenue")], +# Now join in the customer data, project the spec's output columns, and take the top 20. +df = ( + df.join(df_customer, left_on="o_custkey", right_on="c_custkey") + .join(df_nation, left_on="c_nationkey", right_on="n_nationkey") + .select( + "c_custkey", + "c_name", + "revenue", + "c_acctbal", + "n_name", + "c_address", + "c_phone", + "c_comment", + ) + .sort(col("revenue").sort(ascending=False)) + .limit(20) ) -# Now join in the customer data -df = df.join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner") -df = df.join(df_nation, left_on=["c_nationkey"], right_on=["n_nationkey"], how="inner") - -# These are the columns the problem statement requires -df = df.select( - "c_custkey", - "c_name", - "revenue", - "c_acctbal", - "n_name", - "c_address", - "c_phone", - "c_comment", -) - -# Sort the results in descending order -df = df.sort(col("revenue").sort(ascending=False)) - -# Only return the top 20 results -df = df.limit(20) - df.show() diff --git a/examples/tpch/q11_important_stock_identification.py b/examples/tpch/q11_important_stock_identification.py index de309fa64..1f40bbdad 100644 --- 
a/examples/tpch/q11_important_stock_identification.py +++ b/examples/tpch/q11_important_stock_identification.py @@ -25,6 +25,36 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + ps_partkey, + sum(ps_supplycost * ps_availqty) as value + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' + group by + ps_partkey having + sum(ps_supplycost * ps_availqty) > ( + select + sum(ps_supplycost * ps_availqty) * 0.0001000000 + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' + ) + order by + value desc; """ from datafusion import SessionContext, WindowFrame, col, lit @@ -49,39 +79,30 @@ "n_nationkey", "n_name" ) -# limit to returns -df_nation = df_nation.filter(col("n_name") == lit(NATION)) - -# Find part supplies of within this target nation - -df = df_nation.join( - df_supplier, left_on=["n_nationkey"], right_on=["s_nationkey"], how="inner" +# Restrict to the target nation, then walk to partsupp rows via the supplier +# join. Aggregate the per-part inventory value. 
+df = ( + df_nation.filter(col("n_name") == NATION) + .join(df_supplier, left_on="n_nationkey", right_on="s_nationkey") + .join(df_partsupp, left_on="s_suppkey", right_on="ps_suppkey") + .with_column("value", col("ps_supplycost") * col("ps_availqty")) + .aggregate(["ps_partkey"], [F.sum(col("value")).alias("value")]) ) -df = df.join(df_partsupp, left_on=["s_suppkey"], right_on=["ps_suppkey"], how="inner") - - -# Compute the value of individual parts -df = df.with_column("value", col("ps_supplycost") * col("ps_availqty")) - -# Compute total value of specific parts -df = df.aggregate([col("ps_partkey")], [F.sum(col("value")).alias("value")]) - -# By default window functions go from unbounded preceding to current row, but we want -# to compute this sum across all rows -window_frame = WindowFrame("rows", None, None) - -df = df.with_column( - "total_value", F.sum(col("value")).over(Window(window_frame=window_frame)) +# A window function evaluated over the entire output produces a scalar grand +# total that can be referenced row-by-row in the filter — a DataFrame-native +# stand-in for the SQL HAVING ... > (SELECT SUM(...) * FRACTION ...) pattern. +# The default frame is "UNBOUNDED PRECEDING to CURRENT ROW"; override to the +# full partition for the grand total. 
+whole_frame = WindowFrame("rows", None, None) + +df = ( + df.with_column( + "total_value", F.sum(col("value")).over(Window(window_frame=whole_frame)) + ) + .filter(col("value") / col("total_value") >= lit(FRACTION)) + .select("ps_partkey", "value") + .sort(col("value").sort(ascending=False)) ) -# Limit to the parts for which there is a significant value based on the fraction of the total -df = df.filter(col("value") / col("total_value") >= lit(FRACTION)) - -# We only need to report on these two columns -df = df.select("ps_partkey", "value") - -# Sort in descending order of value -df = df.sort(col("value").sort(ascending=False)) - df.show() diff --git a/examples/tpch/q12_ship_mode_order_priority.py b/examples/tpch/q12_ship_mode_order_priority.py index 9071597f0..fb78fe3c2 100644 --- a/examples/tpch/q12_ship_mode_order_priority.py +++ b/examples/tpch/q12_ship_mode_order_priority.py @@ -27,18 +27,49 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + l_shipmode, + sum(case + when o_orderpriority = '1-URGENT' + or o_orderpriority = '2-HIGH' + then 1 + else 0 + end) as high_line_count, + sum(case + when o_orderpriority <> '1-URGENT' + and o_orderpriority <> '2-HIGH' + then 1 + else 0 + end) as low_line_count + from + orders, + lineitem + where + o_orderkey = l_orderkey + and l_shipmode in ('MAIL', 'SHIP') + and l_commitdate < l_receiptdate + and l_shipdate < l_commitdate + and l_receiptdate >= date '1994-01-01' + and l_receiptdate < date '1994-01-01' + interval '1' year + group by + l_shipmode + order by + l_shipmode; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path SHIP_MODE_1 = "MAIL" SHIP_MODE_2 = "SHIP" -DATE_OF_INTEREST = "1994-01-01" +YEAR_START = date(1994, 1, 1) +YEAR_END = date(1995, 1, 1) # Load the dataframes we need @@ -51,63 +82,30 @@ "l_orderkey", "l_shipmode", "l_commitdate", "l_shipdate", "l_receiptdate" ) -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, 365, 0), type=pa.month_day_nano_interval()) - - -df = df_lineitem.filter(col("l_receiptdate") >= lit(date)).filter( - col("l_receiptdate") < lit(date) + lit(interval) -) - -# Note: It is not recommended to use array_has because it treats the second argument as an argument -# so if you pass it col("l_shipmode") it will pass the entire array to process which is very slow. -# Instead check the position of the entry is not null. -df = df.filter( - ~F.array_position( - F.make_array(lit(SHIP_MODE_1), lit(SHIP_MODE_2)), col("l_shipmode") - ).is_null() -) - -# Since we have only two values, it's much easier to do this as a filter where the l_shipmode -# matches either of the two values, but we want to show doing some array operations in this -# example. 
If you want to see this done with filters, comment out the above line and uncomment -# this one. -# df = df.filter((col("l_shipmode") == lit(SHIP_MODE_1)) | (col("l_shipmode") == lit(SHIP_MODE_2))) # noqa: ERA001 +df = df_lineitem.filter( + col("l_receiptdate") >= lit(YEAR_START), + col("l_receiptdate") < lit(YEAR_END), + # ``in_list`` maps directly to ``l_shipmode in (...)`` from the SQL. + F.in_list(col("l_shipmode"), [lit(SHIP_MODE_1), lit(SHIP_MODE_2)]), + col("l_shipdate") < col("l_commitdate"), + col("l_commitdate") < col("l_receiptdate"), +).join(df_orders, left_on="l_orderkey", right_on="o_orderkey") -# We need order priority, so join order df to line item -df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner") +# Flag each line item as belonging to a high-priority order or not. +high_priorities = [lit("1-URGENT"), lit("2-HIGH")] +is_high = F.in_list(col("o_orderpriority"), high_priorities) +is_low = F.in_list(col("o_orderpriority"), high_priorities, negated=True) -# Restrict to line items we care about based on the problem statement. -df = df.filter(col("l_commitdate") < col("l_receiptdate")) - -df = df.filter(col("l_shipdate") < col("l_commitdate")) - -df = df.with_column( - "high_line_value", - F.case(col("o_orderpriority")) - .when(lit("1-URGENT"), lit(1)) - .when(lit("2-HIGH"), lit(1)) - .otherwise(lit(0)), -) - -# Aggregate the results +# Count the high-priority and low-priority lineitems per ship mode via the +# ``filter`` kwarg on ``F.count`` (DataFrame form of SQL's ``count(*) +# FILTER (WHERE ...)``). 
df = df.aggregate( - [col("l_shipmode")], + ["l_shipmode"], [ - F.sum(col("high_line_value")).alias("high_line_count"), - F.count(col("high_line_value")).alias("all_lines_count"), + F.count(col("o_orderkey"), filter=is_high).alias("high_line_count"), + F.count(col("o_orderkey"), filter=is_low).alias("low_line_count"), ], -) - -# Compute the final output -df = df.select( - col("l_shipmode"), - col("high_line_count"), - (col("all_lines_count") - col("high_line_count")).alias("low_line_count"), -) - -df = df.sort(col("l_shipmode").sort()) +).sort_by("l_shipmode") df.show() diff --git a/examples/tpch/q13_customer_distribution.py b/examples/tpch/q13_customer_distribution.py index 93f082ea3..37c0b93f6 100644 --- a/examples/tpch/q13_customer_distribution.py +++ b/examples/tpch/q13_customer_distribution.py @@ -26,6 +26,29 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + c_count, + count(*) as custdist + from + ( + select + c_custkey, + count(o_orderkey) + from + customer left outer join orders on + c_custkey = o_custkey + and o_comment not like '%special%requests%' + group by + c_custkey + ) as c_orders (c_custkey, c_count) + group by + c_count + order by + custdist desc, + c_count desc; """ from datafusion import SessionContext, col, lit @@ -49,20 +72,16 @@ F.regexp_match(col("o_comment"), lit(f"{WORD_1}.?*{WORD_2}")).is_null() ) -# Since we may have customers with no orders we must do a left join -df = df_customer.join( - df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="left" -) - -# Find the number of orders for each customer -df = df.aggregate([col("c_custkey")], [F.count(col("o_custkey")).alias("c_count")]) - -# Ultimately we want to know the number of customers that have that customer count -df = df.aggregate([col("c_count")], 
[F.count(col("c_count")).alias("custdist")]) - -# We want to order the results by the highest number of customers per count -df = df.sort( - col("custdist").sort(ascending=False), col("c_count").sort(ascending=False) +# Customers with no orders still participate, so this is a left join. Count the +# orders per customer, then count customers per order-count value. +df = ( + df_customer.join(df_orders, left_on="c_custkey", right_on="o_custkey", how="left") + .aggregate(["c_custkey"], [F.count(col("o_custkey")).alias("c_count")]) + .aggregate(["c_count"], [F.count_star().alias("custdist")]) + .sort( + col("custdist").sort(ascending=False), + col("c_count").sort(ascending=False), + ) ) df.show() diff --git a/examples/tpch/q14_promotion_effect.py b/examples/tpch/q14_promotion_effect.py index d62f76e3c..08f4f054d 100644 --- a/examples/tpch/q14_promotion_effect.py +++ b/examples/tpch/q14_promotion_effect.py @@ -24,20 +24,32 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + 100.00 * sum(case + when p_type like 'PROMO%' + then l_extendedprice * (1 - l_discount) + else 0 + end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue + from + lineitem, + part + where + l_partkey = p_partkey + and l_shipdate >= date '1995-09-01' + and l_shipdate < date '1995-09-01' + interval '1' month; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -DATE = "1995-09-01" - -date_of_interest = lit(datetime.strptime(DATE, "%Y-%m-%d").date()) - -interval_one_month = lit(pa.scalar((0, 30, 0), type=pa.month_day_nano_interval())) +MONTH_START = date(1995, 9, 1) +MONTH_END = date(1995, 10, 1) # Load the dataframes we need @@ -49,37 +61,30 @@ df_part = ctx.read_parquet(get_data_path("part.parquet")).select("p_partkey", "p_type") -# Check part type begins with PROMO -df_part = df_part.filter( - F.substring(col("p_type"), lit(0), lit(6)) == lit("PROMO") -).with_column("promo_factor", lit(1.0)) - -df_lineitem = df_lineitem.filter(col("l_shipdate") >= date_of_interest).filter( - col("l_shipdate") < date_of_interest + interval_one_month -) - -# Left join so we can sum up the promo parts different from other parts -df = df_lineitem.join( - df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="left" -) - -# Make a factor of 1.0 if it is a promotion, 0.0 otherwise -df = df.with_column("promo_factor", F.coalesce(col("promo_factor"), lit(0.0))) -df = df.with_column("revenue", col("l_extendedprice") * (lit(1.0) - col("l_discount"))) - - -# Sum up the promo and total revenue -df = df.aggregate( - [], - [ - F.sum(col("promo_factor") * col("revenue")).alias("promo_revenue"), - F.sum(col("revenue")).alias("total_revenue"), - ], -) - -# Return the percentage of revenue from promotions -df = df.select( - (lit(100.0) * 
col("promo_revenue") / col("total_revenue")).alias("promo_revenue") +# Restrict the line items to the month of interest, join the matching part +# rows, and aggregate revenue totals with a ``filter`` clause on the promo +# sum — the DataFrame form of SQL ``sum(... ) FILTER (WHERE ...)``. +revenue = col("l_extendedprice") * (lit(1.0) - col("l_discount")) +is_promo = F.starts_with(col("p_type"), lit("PROMO")) + +df = ( + df_lineitem.filter( + col("l_shipdate") >= lit(MONTH_START), + col("l_shipdate") < lit(MONTH_END), + ) + .join(df_part, left_on="l_partkey", right_on="p_partkey") + .aggregate( + [], + [ + F.sum(revenue, filter=is_promo).alias("promo_revenue"), + F.sum(revenue).alias("total_revenue"), + ], + ) + .select( + (lit(100.0) * col("promo_revenue") / col("total_revenue")).alias( + "promo_revenue" + ) + ) ) df.show() diff --git a/examples/tpch/q15_top_supplier.py b/examples/tpch/q15_top_supplier.py index 5128937a7..01c38b9f8 100644 --- a/examples/tpch/q15_top_supplier.py +++ b/examples/tpch/q15_top_supplier.py @@ -24,21 +24,50 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + create view revenue0 (supplier_no, total_revenue) as + select + l_suppkey, + sum(l_extendedprice * (1 - l_discount)) + from + lineitem + where + l_shipdate >= date '1996-01-01' + and l_shipdate < date '1996-01-01' + interval '3' month + group by + l_suppkey; + select + s_suppkey, + s_name, + s_address, + s_phone, + total_revenue + from + supplier, + revenue0 + where + s_suppkey = supplier_no + and total_revenue = ( + select + max(total_revenue) + from + revenue0 + ) + order by + s_suppkey; + drop view revenue0; """ -from datetime import datetime +from datetime import date -import pyarrow as pa -from datafusion import SessionContext, WindowFrame, col, lit +from datafusion import SessionContext, col, lit from datafusion import functions as F -from datafusion.expr import Window from util import get_data_path -DATE = "1996-01-01" - -date_of_interest = lit(datetime.strptime(DATE, "%Y-%m-%d").date()) - -interval_3_months = lit(pa.scalar((0, 91, 0), type=pa.month_day_nano_interval())) +QUARTER_START = date(1996, 1, 1) +QUARTER_END = date(1996, 4, 1) # Load the dataframes we need @@ -54,38 +83,29 @@ "s_phone", ) -# Limit line items to the quarter of interest -df_lineitem = df_lineitem.filter(col("l_shipdate") >= date_of_interest).filter( - col("l_shipdate") < date_of_interest + interval_3_months -) +# Per-supplier revenue over the quarter of interest. 
+revenue = col("l_extendedprice") * (lit(1) - col("l_discount")) -df = df_lineitem.aggregate( - [col("l_suppkey")], - [ - F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias( - "total_revenue" - ) - ], -) +per_supplier_revenue = df_lineitem.filter( + col("l_shipdate") >= lit(QUARTER_START), + col("l_shipdate") < lit(QUARTER_END), +).aggregate(["l_suppkey"], [F.sum(revenue).alias("total_revenue")]) -# Use a window function to find the maximum revenue across the entire dataframe -window_frame = WindowFrame("rows", None, None) -df = df.with_column( - "max_revenue", - F.max(col("total_revenue")).over(Window(window_frame=window_frame)), +# Compute the grand maximum revenue separately and join on equality — the +# DataFrame stand-in for the reference SQL's +# ``total_revenue = (select max(total_revenue) from revenue0)`` subquery. +max_revenue = per_supplier_revenue.aggregate( + [], [F.max(col("total_revenue")).alias("max_rev")] ) -# Find all suppliers whose total revenue is the same as the maximum -df = df.filter(col("total_revenue") == col("max_revenue")) - -# Now that we know the supplier(s) with maximum revenue, get the rest of their information -# from the supplier table -df = df.join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner") +top_suppliers = per_supplier_revenue.join_on( + max_revenue, col("total_revenue") == col("max_rev") +).select("l_suppkey", "total_revenue") -# Return only the columns requested -df = df.select("s_suppkey", "s_name", "s_address", "s_phone", "total_revenue") - -# If we have more than one, sort by supplier number (suppkey) -df = df.sort(col("s_suppkey").sort()) +df = ( + df_supplier.join(top_suppliers, left_on="s_suppkey", right_on="l_suppkey") + .select("s_suppkey", "s_name", "s_address", "s_phone", "total_revenue") + .sort_by("s_suppkey") +) df.show() diff --git a/examples/tpch/q16_part_supplier_relationship.py b/examples/tpch/q16_part_supplier_relationship.py index 65043ffda..ddeadff5f 100644 --- 
a/examples/tpch/q16_part_supplier_relationship.py +++ b/examples/tpch/q16_part_supplier_relationship.py @@ -26,9 +26,41 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + p_brand, + p_type, + p_size, + count(distinct ps_suppkey) as supplier_cnt + from + partsupp, + part + where + p_partkey = ps_partkey + and p_brand <> 'Brand#45' + and p_type not like 'MEDIUM POLISHED%' + and p_size in (49, 14, 23, 45, 19, 3, 36, 9) + and ps_suppkey not in ( + select + s_suppkey + from + supplier + where + s_comment like '%Customer%Complaints%' + ) + group by + p_brand, + p_type, + p_size + order by + supplier_cnt desc, + p_brand, + p_type, + p_size; """ -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path @@ -52,39 +84,36 @@ ) df_unwanted_suppliers = df_supplier.filter( - ~F.regexp_match(col("s_comment"), lit("Customer.?*Complaints")).is_null() + F.regexp_like(col("s_comment"), lit("Customer.*Complaints")) ) -# Remove unwanted suppliers +# Remove unwanted suppliers via an anti join (DataFrame form of NOT IN). df_partsupp = df_partsupp.join( - df_unwanted_suppliers, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="anti" + df_unwanted_suppliers, left_on="ps_suppkey", right_on="s_suppkey", how="anti" ) -# Select the parts we are interested in -df_part = df_part.filter(col("p_brand") != lit(BRAND)) +# Select the parts we are interested in. df_part = df_part.filter( - F.substring(col("p_type"), lit(0), lit(len(TYPE_TO_IGNORE) + 1)) - != lit(TYPE_TO_IGNORE) -) - -# Python conversion of integer to literal casts it to int64 but the data for -# part size is stored as an int32, so perform a cast. 
Then check to find if the part -# size is within the array of possible sizes by checking the position of it is not -# null. -p_sizes = F.make_array(*[lit(s).cast(pa.int32()) for s in SIZES_OF_INTEREST]) -df_part = df_part.filter(~F.array_position(p_sizes, col("p_size")).is_null()) - -df = df_part.join( - df_partsupp, left_on=["p_partkey"], right_on=["ps_partkey"], how="inner" + col("p_brand") != BRAND, + ~F.starts_with(col("p_type"), lit(TYPE_TO_IGNORE)), + F.in_list(col("p_size"), [lit(s) for s in SIZES_OF_INTEREST]), ) -df = df.select("p_brand", "p_type", "p_size", "ps_suppkey").distinct() - -df = df.aggregate( - [col("p_brand"), col("p_type"), col("p_size")], - [F.count(col("ps_suppkey")).alias("supplier_cnt")], +# For each (brand, type, size), count the distinct suppliers remaining. +df = ( + df_part.join(df_partsupp, left_on="p_partkey", right_on="ps_partkey") + .select("p_brand", "p_type", "p_size", "ps_suppkey") + .distinct() + .aggregate( + ["p_brand", "p_type", "p_size"], + [F.count(col("ps_suppkey")).alias("supplier_cnt")], + ) + .sort( + col("supplier_cnt").sort(ascending=False), + "p_brand", + "p_type", + "p_size", + ) ) -df = df.sort(col("supplier_cnt").sort(ascending=False)) - df.show() diff --git a/examples/tpch/q17_small_quantity_order.py b/examples/tpch/q17_small_quantity_order.py index 5ccb38422..f2229171f 100644 --- a/examples/tpch/q17_small_quantity_order.py +++ b/examples/tpch/q17_small_quantity_order.py @@ -26,6 +26,26 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + sum(l_extendedprice) / 7.0 as avg_yearly + from + lineitem, + part + where + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container = 'MED BOX' + and l_quantity < ( + select + 0.2 * avg(l_quantity) + from + lineitem + where + l_partkey = p_partkey + ); """ from datafusion import SessionContext, WindowFrame, col, lit @@ -47,29 +67,23 @@ "l_partkey", "l_quantity", "l_extendedprice" ) -# Limit to the problem statement's brand and container types -df = df_part.filter(col("p_brand") == lit(BRAND)).filter( - col("p_container") == lit(CONTAINER) -) - -# Combine data -df = df.join(df_lineitem, left_on=["p_partkey"], right_on=["l_partkey"], how="inner") +# Limit to parts of the target brand/container, join their line items, and +# attach the per-part average quantity via a partitioned window function — +# the DataFrame form of the SQL's correlated ``avg(l_quantity)`` subquery. +whole_frame = WindowFrame("rows", None, None) -# Find the average quantity -window_frame = WindowFrame("rows", None, None) -df = df.with_column( - "avg_quantity", - F.avg(col("l_quantity")).over( - Window(partition_by=[col("l_partkey")], window_frame=window_frame) - ), +df = ( + df_part.filter(col("p_brand") == BRAND, col("p_container") == CONTAINER) + .join(df_lineitem, left_on="p_partkey", right_on="l_partkey") + .with_column( + "avg_quantity", + F.avg(col("l_quantity")).over( + Window(partition_by=[col("l_partkey")], window_frame=whole_frame) + ), + ) + .filter(col("l_quantity") < lit(0.2) * col("avg_quantity")) + .aggregate([], [F.sum(col("l_extendedprice")).alias("total")]) + .select((col("total") / lit(7.0)).alias("avg_yearly")) ) -df = df.filter(col("l_quantity") < lit(0.2) * col("avg_quantity")) - -# Compute the total -df = df.aggregate([], [F.sum(col("l_extendedprice")).alias("total")]) - -# Divide by number of years in the problem statement to get average -df = df.select((col("total") / 
lit(7)).alias("avg_yearly")) - df.show() diff --git a/examples/tpch/q18_large_volume_customer.py b/examples/tpch/q18_large_volume_customer.py index 834d181c9..23132d60d 100644 --- a/examples/tpch/q18_large_volume_customer.py +++ b/examples/tpch/q18_large_volume_customer.py @@ -24,9 +24,44 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice, + sum(l_quantity) + from + customer, + orders, + lineitem + where + o_orderkey in ( + select + l_orderkey + from + lineitem + group by + l_orderkey having + sum(l_quantity) > 300 + ) + and c_custkey = o_custkey + and o_orderkey = l_orderkey + group by + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice + order by + o_totalprice desc, + o_orderdate limit 100; """ -from datafusion import SessionContext, col, lit +from datafusion import SessionContext, col from datafusion import functions as F from util import get_data_path @@ -46,22 +81,24 @@ "l_orderkey", "l_quantity", "l_extendedprice" ) -df = df_lineitem.aggregate( - [col("l_orderkey")], [F.sum(col("l_quantity")).alias("total_quantity")] +# Find orders whose total quantity exceeds the threshold, then join in the +# order + customer details the problem statement requires and sort. 
+df = ( + df_lineitem.aggregate( + ["l_orderkey"], [F.sum(col("l_quantity")).alias("total_quantity")] + ) + .filter(col("total_quantity") > QUANTITY) + .join(df_orders, left_on="l_orderkey", right_on="o_orderkey") + .join(df_customer, left_on="o_custkey", right_on="c_custkey") + .select( + "c_name", + "c_custkey", + "o_orderkey", + "o_orderdate", + "o_totalprice", + "total_quantity", + ) + .sort(col("o_totalprice").sort(ascending=False), "o_orderdate") ) -# Limit to orders in which the total quantity is above a threshold -df = df.filter(col("total_quantity") > lit(QUANTITY)) - -# We've identified the orders of interest, now join the additional data -# we are required to report on -df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner") -df = df.join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner") - -df = df.select( - "c_name", "c_custkey", "o_orderkey", "o_orderdate", "o_totalprice", "total_quantity" -) - -df = df.sort(col("o_totalprice").sort(ascending=False), col("o_orderdate").sort()) - df.show() diff --git a/examples/tpch/q19_discounted_revenue.py b/examples/tpch/q19_discounted_revenue.py index bd492aac0..a2be1c1b7 100644 --- a/examples/tpch/q19_discounted_revenue.py +++ b/examples/tpch/q19_discounted_revenue.py @@ -24,10 +24,47 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + sum(l_extendedprice* (1 - l_discount)) as revenue + from + lineitem, + part + where + ( + p_partkey = l_partkey + and p_brand = 'Brand#12' + and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + and l_quantity >= 1 and l_quantity <= 1 + 10 + and p_size between 1 and 5 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + and l_quantity >= 10 and l_quantity <= 10 + 10 + and p_size between 1 and 10 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#34' + and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + and l_quantity >= 20 and l_quantity <= 20 + 10 + and p_size between 1 and 15 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ); """ -import pyarrow as pa -from datafusion import SessionContext, col, lit, udf +from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path @@ -65,72 +102,41 @@ "l_discount", ) -# These limitations apply to all line items, so go ahead and do them first - -df = df_lineitem.filter(col("l_shipinstruct") == lit("DELIVER IN PERSON")) - -df = df.filter( - (col("l_shipmode") == lit("AIR")) | (col("l_shipmode") == lit("AIR REG")) -) +# Filter conditions that apply to every disjunct of the reference SQL's WHERE +# clause — pull them out up front so the per-brand predicate stays focused on +# the brand-specific parts. +df = df_lineitem.filter( + col("l_shipinstruct") == "DELIVER IN PERSON", + F.in_list(col("l_shipmode"), [lit("AIR"), lit("AIR REG")]), +).join(df_part, left_on="l_partkey", right_on="p_partkey") + + +# Build one OR-combined predicate per brand. 
Each disjunct encodes the +# brand-specific container list, quantity window, and size range from the +# reference SQL. This mirrors the SQL ``where (... brand A ...) or (... brand +# B ...) or (... brand C ...)`` form directly, without a UDF. +def _brand_predicate( + brand: str, min_quantity: int, containers: list[str], max_size: int +): + return ( + (col("p_brand") == brand) + & F.in_list(col("p_container"), [lit(c) for c in containers]) + & col("l_quantity").between(lit(min_quantity), lit(min_quantity + 10)) + & col("p_size").between(lit(1), lit(max_size)) + ) -df = df.join(df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="inner") - - -# Create the user defined function (UDF) definition that does the work -def is_of_interest( - brand_arr: pa.Array, - container_arr: pa.Array, - quantity_arr: pa.Array, - size_arr: pa.Array, -) -> pa.Array: - """ - The purpose of this function is to demonstrate how a UDF works, taking as input a pyarrow Array - and generating a resultant Array. The length of the inputs should match and there should be the - same number of rows in the output. 
- """ - result = [] - for idx, brand_val in enumerate(brand_arr): - brand = brand_val.as_py() - if brand in items_of_interest: - values_of_interest = items_of_interest[brand] - - container_matches = ( - container_arr[idx].as_py() in values_of_interest["containers"] - ) - - quantity = quantity_arr[idx].as_py() - quantity_matches = ( - values_of_interest["min_quantity"] - <= quantity - <= values_of_interest["min_quantity"] + 10 - ) - - size = size_arr[idx].as_py() - size_matches = 1 <= size <= values_of_interest["max_size"] - - result.append(container_matches and quantity_matches and size_matches) - else: - result.append(False) - - return pa.array(result) - - -# Turn the above function into a UDF that DataFusion can understand -is_of_interest_udf = udf( - is_of_interest, - [pa.utf8(), pa.utf8(), pa.decimal128(15, 2), pa.int32()], - pa.bool_(), - "stable", -) -# Filter results using the above UDF -df = df.filter( - is_of_interest_udf( - col("p_brand"), col("p_container"), col("l_quantity"), col("p_size") +predicate = None +for brand, params in items_of_interest.items(): + part_predicate = _brand_predicate( + brand, + params["min_quantity"], + params["containers"], + params["max_size"], ) -) + predicate = part_predicate if predicate is None else predicate | part_predicate -df = df.aggregate( +df = df.filter(predicate).aggregate( [], [F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias("revenue")], ) diff --git a/examples/tpch/q20_potential_part_promotion.py b/examples/tpch/q20_potential_part_promotion.py index a25188d31..18f96da97 100644 --- a/examples/tpch/q20_potential_part_promotion.py +++ b/examples/tpch/q20_potential_part_promotion.py @@ -25,17 +25,57 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + s_name, + s_address + from + supplier, + nation + where + s_suppkey in ( + select + ps_suppkey + from + partsupp + where + ps_partkey in ( + select + p_partkey + from + part + where + p_name like 'forest%' + ) + and ps_availqty > ( + select + 0.5 * sum(l_quantity) + from + lineitem + where + l_partkey = ps_partkey + and l_suppkey = ps_suppkey + and l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + ) + ) + and s_nationkey = n_nationkey + and n_name = 'CANADA' + order by + s_name; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path COLOR_OF_INTEREST = "forest" -DATE_OF_INTEREST = "1994-01-01" +YEAR_START = date(1994, 1, 1) +YEAR_END = date(1995, 1, 1) NATION_OF_INTEREST = "CANADA" # Load the dataframes we need @@ -56,46 +96,48 @@ "n_nationkey", "n_name" ) -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, 365, 0), type=pa.month_day_nano_interval()) - -# Filter down dataframes -df_nation = df_nation.filter(col("n_name") == lit(NATION_OF_INTEREST)) -df_part = df_part.filter( - F.substring(col("p_name"), lit(0), lit(len(COLOR_OF_INTEREST) + 1)) - == lit(COLOR_OF_INTEREST) -) - -df = df_lineitem.filter(col("l_shipdate") >= lit(date)).filter( - col("l_shipdate") < lit(date) + lit(interval) +# Filter down dataframes. ``starts_with`` reads more naturally than an +# explicit substring slice and maps directly to the reference SQL's +# ``p_name like 'forest%'`` clause. +df_nation = df_nation.filter(col("n_name") == NATION_OF_INTEREST) +df_part = df_part.filter(F.starts_with(col("p_name"), lit(COLOR_OF_INTEREST))) + +# Compute the total quantity of interesting parts shipped by each (part, +# supplier) pair within the year of interest. 
+totals = ( + df_lineitem.filter( + col("l_shipdate") >= lit(YEAR_START), + col("l_shipdate") < lit(YEAR_END), + ) + .join(df_part, left_on="l_partkey", right_on="p_partkey") + .aggregate( + ["l_partkey", "l_suppkey"], + [F.sum(col("l_quantity")).alias("total_sold")], + ) ) -# This will filter down the line items to the parts of interest -df = df.join(df_part, left_on="l_partkey", right_on="p_partkey", how="inner") - -# Compute the total sold and limit ourselves to individual supplier/part combinations -df = df.aggregate( - [col("l_partkey"), col("l_suppkey")], [F.sum(col("l_quantity")).alias("total_sold")] +# Keep only (part, supplier) pairs whose available quantity exceeds 50% of +# the total shipped. The result already contains one row per supplier of +# interest, so we can semi-join the supplier table rather than inner-join +# and deduplicate afterwards. +excess_suppliers = ( + df_partsupp.join( + totals, + left_on=["ps_partkey", "ps_suppkey"], + right_on=["l_partkey", "l_suppkey"], + ) + .filter(col("ps_availqty") > lit(0.5) * col("total_sold")) + .select(col("ps_suppkey").alias("suppkey")) + .distinct() ) -df = df.join( - df_partsupp, - left_on=["l_partkey", "l_suppkey"], - right_on=["ps_partkey", "ps_suppkey"], - how="inner", +# Limit to suppliers in the nation of interest and pick out the two +# requested columns. 
+df = ( + df_supplier.join(df_nation, left_on="s_nationkey", right_on="n_nationkey") + .join(excess_suppliers, left_on="s_suppkey", right_on="suppkey", how="semi") + .select("s_name", "s_address") + .sort_by("s_name") ) -# Find cases of excess quantity -df.filter(col("ps_availqty") > lit(0.5) * col("total_sold")) - -# We could do these joins earlier, but now limit to the nation of interest suppliers -df = df.join(df_supplier, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="inner") -df = df.join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner") - -# Restrict to the requested data per the problem statement -df = df.select("s_name", "s_address").distinct() - -df = df.sort(col("s_name").sort()) - df.show() diff --git a/examples/tpch/q21_suppliers_kept_orders_waiting.py b/examples/tpch/q21_suppliers_kept_orders_waiting.py index 4ee9d3733..d98f76ce7 100644 --- a/examples/tpch/q21_suppliers_kept_orders_waiting.py +++ b/examples/tpch/q21_suppliers_kept_orders_waiting.py @@ -24,9 +24,51 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + s_name, + count(*) as numwait + from + supplier, + lineitem l1, + orders, + nation + where + s_suppkey = l1.l_suppkey + and o_orderkey = l1.l_orderkey + and o_orderstatus = 'F' + and l1.l_receiptdate > l1.l_commitdate + and exists ( + select + * + from + lineitem l2 + where + l2.l_orderkey = l1.l_orderkey + and l2.l_suppkey <> l1.l_suppkey + ) + and not exists ( + select + * + from + lineitem l3 + where + l3.l_orderkey = l1.l_orderkey + and l3.l_suppkey <> l1.l_suppkey + and l3.l_receiptdate > l3.l_commitdate + ) + and s_nationkey = n_nationkey + and n_name = 'SAUDI ARABIA' + group by + s_name + order by + numwait desc, + s_name limit 100; """ -from datafusion import SessionContext, col, lit +from datafusion import SessionContext, col from datafusion import functions as F from util import get_data_path @@ -50,65 +92,57 @@ ) # Limit to suppliers in the nation of interest -df_suppliers_of_interest = df_nation.filter(col("n_name") == lit(NATION_OF_INTEREST)) - -df_suppliers_of_interest = df_suppliers_of_interest.join( - df_supplier, left_on="n_nationkey", right_on="s_nationkey", how="inner" +df_suppliers_of_interest = df_nation.filter(col("n_name") == NATION_OF_INTEREST).join( + df_supplier, left_on="n_nationkey", right_on="s_nationkey" ) -# Find the failed orders and all their line items -df = df_orders.filter(col("o_orderstatus") == lit("F")) - -df = df_lineitem.join(df, left_on="l_orderkey", right_on="o_orderkey", how="inner") - -# Identify the line items for which the order is failed due to. -df = df.with_column( - "failed_supp", - F.case(col("l_receiptdate") > col("l_commitdate")) - .when(lit(value=True), col("l_suppkey")) - .end(), +# Line items for orders that have status 'F'. This is the candidate set of +# (order, supplier) pairs we reason about below. 
+failed_order_lineitems = df_lineitem.join( + df_orders.filter(col("o_orderstatus") == "F"), + left_on="l_orderkey", + right_on="o_orderkey", ) -# There are other ways we could do this but the purpose of this example is to work with rows where -# an element is an array of values. In this case, we will create two columns of arrays. One will be -# an array of all of the suppliers who made up this order. That way we can filter the dataframe for -# only orders where this array is larger than one for multiple supplier orders. The second column -# is all of the suppliers who failed to make their commitment. We can filter the second column for -# arrays with size one. That combination will give us orders that had multiple suppliers where only -# one failed. Use distinct=True in the blow aggregation so we don't get multiple line items from the -# same supplier reported in either array. -df = df.aggregate( - [col("o_orderkey")], - [ - F.array_agg(col("l_suppkey"), distinct=True).alias("all_suppliers"), - F.array_agg( - col("failed_supp"), filter=col("failed_supp").is_not_null(), distinct=True - ).alias("failed_suppliers"), - ], +# Line items whose receipt was late. This corresponds to ``l1`` in the +# reference SQL. +late_lineitems = failed_order_lineitems.filter( + col("l_receiptdate") > col("l_commitdate") ) -# This is the check described above which will identify single failed supplier in a multiple -# supplier order. -df = df.filter(F.array_length(col("failed_suppliers")) == lit(1)).filter( - F.array_length(col("all_suppliers")) > lit(1) +# Orders that had more than one distinct supplier. Expressed as +# ``count(distinct l_suppkey) > 1``. Stands in for the reference SQL's +# ``exists (... l2.l_suppkey <> l1.l_suppkey ...)`` subquery. 
+multi_supplier_orders = ( + failed_order_lineitems.select("l_orderkey", "l_suppkey") + .distinct() + .aggregate(["l_orderkey"], [F.count_star().alias("n_suppliers")]) + .filter(col("n_suppliers") > 1) + .select("l_orderkey") ) -# Since we have an array we know is exactly one element long, we can extract that single value. -df = df.select( - col("o_orderkey"), F.array_element(col("failed_suppliers"), lit(1)).alias("suppkey") +# Orders where exactly one distinct supplier was late. Stands in for the +# reference SQL's ``not exists (... l3.l_suppkey <> l1.l_suppkey and l3 is +# also late ...)`` subquery: if only one supplier on the order was late, +# nobody else on the same order was late. +single_late_supplier_orders = ( + late_lineitems.select("l_orderkey", "l_suppkey") + .distinct() + .aggregate(["l_orderkey"], [F.count_star().alias("n_late_suppliers")]) + .filter(col("n_late_suppliers") == 1) + .select("l_orderkey") ) -# Join to the supplier of interest list for the nation of interest -df = df.join( - df_suppliers_of_interest, left_on=["suppkey"], right_on=["s_suppkey"], how="inner" +# Keep late line items whose order qualifies on both counts, attach the +# supplier name for suppliers in the nation of interest, count one row per +# qualifying order, and return the top 100. 
+df = ( + late_lineitems.join(multi_supplier_orders, on="l_orderkey", how="semi") + .join(single_late_supplier_orders, on="l_orderkey", how="semi") + .join(df_suppliers_of_interest, left_on="l_suppkey", right_on="s_suppkey") + .aggregate(["s_name"], [F.count_star().alias("numwait")]) + .sort(col("numwait").sort(ascending=False), "s_name") + .limit(100) ) -# Count how many orders that supplier is the only failed supplier for -df = df.aggregate([col("s_name")], [F.count(col("o_orderkey")).alias("numwait")]) - -# Return in descending order -df = df.sort(col("numwait").sort(ascending=False), col("s_name").sort()) - -df = df.limit(100) - df.show() diff --git a/examples/tpch/q22_global_sales_opportunity.py b/examples/tpch/q22_global_sales_opportunity.py index a2d41b215..5043eeb51 100644 --- a/examples/tpch/q22_global_sales_opportunity.py +++ b/examples/tpch/q22_global_sales_opportunity.py @@ -24,6 +24,46 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + cntrycode, + count(*) as numcust, + sum(c_acctbal) as totacctbal + from + ( + select + substring(c_phone from 1 for 2) as cntrycode, + c_acctbal + from + customer + where + substring(c_phone from 1 for 2) in + ('13', '31', '23', '29', '30', '18', '17') + and c_acctbal > ( + select + avg(c_acctbal) + from + customer + where + c_acctbal > 0.00 + and substring(c_phone from 1 for 2) in + ('13', '31', '23', '29', '30', '18', '17') + ) + and not exists ( + select + * + from + orders + where + o_custkey = c_custkey + ) + ) as custsale + group by + cntrycode + order by + cntrycode; """ from datafusion import SessionContext, WindowFrame, col, lit @@ -42,40 +82,36 @@ ) df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select("o_custkey") -# The nation code is a two digit number, but we need to convert it to a string literal -nation_codes = F.make_array(*[lit(str(n)) for n in NATION_CODES]) - -# Use the substring operation to extract the first two characters of the phone number -df = df_customer.with_column("cntrycode", F.substring(col("c_phone"), lit(0), lit(3))) - -# Limit our search to customers with some balance and in the country code above -df = df.filter(col("c_acctbal") > lit(0.0)) -df = df.filter(~F.array_position(nation_codes, col("cntrycode")).is_null()) - -# Compute the average balance. By default, the window frame is from unbounded preceding to the -# current row. We want our frame to cover the entire data frame. 
-window_frame = WindowFrame("rows", None, None) -df = df.with_column( - "avg_balance", - F.avg(col("c_acctbal")).over(Window(window_frame=window_frame)), -) - -df.show() -# Limit results to customers with above average balance -df = df.filter(col("c_acctbal") > col("avg_balance")) - -# Limit results to customers with no orders -df = df.join(df_orders, left_on="c_custkey", right_on="o_custkey", how="anti") - -# Count up the customers and the balances -df = df.aggregate( - [col("cntrycode")], - [ - F.count(col("c_custkey")).alias("numcust"), - F.sum(col("c_acctbal")).alias("totacctbal"), - ], +# Country code is the two-digit prefix of the phone number. +nation_codes = [lit(str(n)) for n in NATION_CODES] + +# Start from customers with a positive balance in one of the target country +# codes, then attach the grand-mean balance via a whole-frame window so we +# can filter per row — DataFrame stand-in for the SQL's scalar ``(select +# avg(c_acctbal) ... )`` subquery. +whole_frame = WindowFrame("rows", None, None) + +df = ( + df_customer.with_column("cntrycode", F.left(col("c_phone"), lit(2))) + .filter( + col("c_acctbal") > 0.0, + F.in_list(col("cntrycode"), nation_codes), + ) + .with_column( + "avg_balance", + F.avg(col("c_acctbal")).over(Window(window_frame=whole_frame)), + ) + .filter(col("c_acctbal") > col("avg_balance")) + # Keep only customers with no orders (anti join = NOT EXISTS). 
+ .join(df_orders, left_on="c_custkey", right_on="o_custkey", how="anti") + .aggregate( + ["cntrycode"], + [ + F.count_star().alias("numcust"), + F.sum(col("c_acctbal")).alias("totacctbal"), + ], + ) + .sort_by("cntrycode") ) -df = df.sort(col("cntrycode").sort()) - df.show() diff --git a/pyproject.toml b/pyproject.toml index 327199d1a..951f7adc3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,6 +111,7 @@ extend-allowed-calls = ["datafusion.lit", "lit"] "ARG", "BLE001", "D", + "FBT003", "PD", "PLC0415", "PLR0913", diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 80dfa2fab..e4972411a 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -15,10 +15,44 @@ # specific language governing permissions and limitations # under the License. -"""DataFusion python package. - -This is a Python library that binds to Apache Arrow in-memory query engine DataFusion. -See https://datafusion.apache.org/python for more information. +"""DataFusion: an in-process query engine built on Apache Arrow. + +DataFusion is not a database -- it has no server and no external dependencies. +You create a :py:class:`SessionContext`, point it at data sources (Parquet, CSV, +JSON, Arrow IPC, Pandas, Polars, or raw Python dicts/lists), and run queries +using either SQL or the DataFrame API. + +Core abstractions +----------------- +- **SessionContext** -- entry point for loading data, running SQL, and creating + DataFrames. +- **DataFrame** -- lazy query builder. Every method returns a new DataFrame; + call :py:meth:`~datafusion.dataframe.DataFrame.collect` or a ``to_*`` + method to execute. +- **Expr** -- expression tree node for column references, literals, and function + calls. Build with :py:func:`col` and :py:func:`lit`. +- **functions** -- 290+ built-in scalar, aggregate, and window functions. 
+ +Quick start +----------- + +>>> from datafusion import SessionContext, col +>>> from datafusion import functions as F +>>> ctx = SessionContext() +>>> df = ctx.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}) +>>> result = ( +... df.filter(col("a") > 1) +... .with_column("total", col("a") + col("b")) +... .aggregate([], [F.sum(col("total")).alias("grand_total")]) +... ) +>>> result.to_pydict() +{'grand_total': [16]} + +User guide and full documentation: https://datafusion.apache.org/python + +AI agent reference (SQL-to-DataFrame mappings, expression-building patterns, +common pitfalls), written in a dense, skill-oriented format: +https://github.com/apache/datafusion-python/blob/main/SKILL.md """ from __future__ import annotations diff --git a/python/datafusion/context.py b/python/datafusion/context.py index c3f94cc16..dd6790402 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -15,7 +15,32 @@ # specific language governing permissions and limitations # under the License. -"""Session Context and it's associated configuration.""" +""":py:class:`SessionContext` — entry point for running DataFusion queries. + +A :py:class:`SessionContext` holds registered tables, catalogs, and +configuration for the current session. It is the first object most programs +create: from it you register data, run SQL strings +(:py:meth:`SessionContext.sql`), read files +(:py:meth:`SessionContext.read_csv`, +:py:meth:`SessionContext.read_parquet`, ...), and construct +:py:class:`~datafusion.dataframe.DataFrame` objects in memory +(:py:meth:`SessionContext.from_pydict`, +:py:meth:`SessionContext.from_arrow`). + +Session behavior (memory limits, batch size, configured optimizer passes, +...) is controlled by :py:class:`SessionConfig` and +:py:class:`RuntimeEnvBuilder`; SQL dialect limits are controlled by +:py:class:`SQLOptions`. 
+ +Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3]}) + >>> ctx.sql("SELECT 1 AS n").to_pydict() + {'n': [1]} + +See :ref:`user_guide_concepts` in the online documentation for the broader +execution model. +""" from __future__ import annotations diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index c00c85fdb..2b07861da 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -14,9 +14,32 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""":py:class:`DataFrame` is one of the core concepts in DataFusion. - -See :ref:`user_guide_concepts` in the online documentation for more information. +""":py:class:`DataFrame` — lazy, chainable query representation. + +A :py:class:`DataFrame` is a logical plan over one or more data sources. +Methods that reshape the plan (:py:meth:`DataFrame.select`, +:py:meth:`DataFrame.filter`, :py:meth:`DataFrame.aggregate`, +:py:meth:`DataFrame.sort`, :py:meth:`DataFrame.join`, +:py:meth:`DataFrame.limit`, the set-operation methods, ...) return a new +:py:class:`DataFrame` and do no work until a terminal method such as +:py:meth:`DataFrame.collect`, :py:meth:`DataFrame.to_pydict`, +:py:meth:`DataFrame.show`, or one of the ``write_*`` methods is called. + +DataFrames are produced from a +:py:class:`~datafusion.context.SessionContext`, typically via +:py:meth:`~datafusion.context.SessionContext.sql`, +:py:meth:`~datafusion.context.SessionContext.read_csv`, +:py:meth:`~datafusion.context.SessionContext.read_parquet`, or +:py:meth:`~datafusion.context.SessionContext.from_pydict`. 
+ +Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3], "b": [10, 20, 30]}) + >>> df.filter(col("a") > 1).select("b").to_pydict() + {'b': [20, 30]} + +See :ref:`user_guide_concepts` in the online documentation for a high-level +overview of the execution model. """ from __future__ import annotations @@ -503,21 +526,29 @@ def select_exprs(self, *args: str) -> DataFrame: def select(self, *exprs: Expr | str) -> DataFrame: """Project arbitrary expressions into a new :py:class:`DataFrame`. + String arguments are treated as column names; :py:class:`~datafusion.expr.Expr` + arguments can reshape, rename, or compute new columns. + Args: exprs: Either column names or :py:class:`~datafusion.expr.Expr` to select. Returns: DataFrame after projection. It has one column for each expression. - Example usage: + Examples: + Select columns by name: - The following example will return 3 columns from the original dataframe. - The first two columns will be the original column ``a`` and ``b`` since the - string "a" is assumed to refer to column selection. Also a duplicate of - column ``a`` will be returned with the column name ``alternate_a``:: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3], "b": [10, 20, 30]}) + >>> df.select("a").to_pydict() + {'a': [1, 2, 3]} - df = df.select("a", col("b"), col("a").alias("alternate_a")) + Mix column names, expressions, and aliases. The string ``"a"`` selects + column ``a`` directly; ``col("a").alias("alternate_a")`` returns a + duplicate under a new name: + >>> df.select("a", col("b"), col("a").alias("alternate_a")).to_pydict() + {'a': [1, 2, 3], 'b': [10, 20, 30], 'alternate_a': [1, 2, 3]} """ exprs_internal = expr_list_to_raw_expr_list(exprs) return DataFrame(self.df.select(*exprs_internal)) @@ -766,6 +797,24 @@ def aggregate( Returns: DataFrame after aggregation. 
+ + Examples: + Aggregate without grouping — an empty ``group_by`` produces a + single row: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict( + ... {"team": ["x", "x", "y"], "score": [1, 2, 5]} + ... ) + >>> df.aggregate([], [F.sum(col("score")).alias("total")]).to_pydict() + {'total': [8]} + + Group by a column and produce one row per group: + + >>> df.aggregate( + ... ["team"], [F.sum(col("score")).alias("total")] + ... ).sort("team").to_pydict() + {'team': ['x', 'y'], 'total': [3, 5]} """ group_by_list = ( list(group_by) @@ -786,13 +835,27 @@ def sort(self, *exprs: SortKey) -> DataFrame: """Sort the DataFrame by the specified sorting expressions or column names. Note that any expression can be turned into a sort expression by - calling its ``sort`` method. + calling its ``sort`` method. For ascending-only sorts, the shorter + :py:meth:`sort_by` is usually more convenient. Args: exprs: Sort expressions or column names, applied in order. Returns: DataFrame after sorting. + + Examples: + Sort ascending by a column name: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [3, 1, 2], "b": [10, 20, 30]}) + >>> df.sort("a").to_pydict() + {'a': [1, 2, 3], 'b': [20, 30, 10]} + + Sort descending using :py:meth:`Expr.sort`: + + >>> df.sort(col("a").sort(ascending=False)).to_pydict() + {'a': [3, 2, 1], 'b': [10, 30, 20]} """ exprs_raw = sort_list_to_raw_sort_list(exprs) return DataFrame(self.df.sort(*exprs_raw)) @@ -812,12 +875,28 @@ def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame: def limit(self, count: int, offset: int = 0) -> DataFrame: """Return a new :py:class:`DataFrame` with a limited number of rows. + Results are returned in unspecified order unless the DataFrame is + explicitly sorted first via :py:meth:`sort` or :py:meth:`sort_by`. + Args: count: Number of rows to limit the DataFrame to. offset: Number of rows to skip. Returns: DataFrame after limiting. 
+ + Examples: + Take the first two rows: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3, 4]}).sort("a") + >>> df.limit(2).to_pydict() + {'a': [1, 2]} + + Skip the first row then take two (paging): + + >>> df.limit(2, offset=1).to_pydict() + {'a': [2, 3]} """ return DataFrame(self.df.limit(count, offset)) @@ -972,6 +1051,28 @@ def join( Returns: DataFrame after join. + + Examples: + Inner-join two DataFrames on a shared column: + + >>> ctx = dfn.SessionContext() + >>> left = ctx.from_pydict({"id": [1, 2, 3], "val": [10, 20, 30]}) + >>> right = ctx.from_pydict({"id": [2, 3, 4], "label": ["b", "c", "d"]}) + >>> left.join(right, on="id").sort("id").to_pydict() + {'id': [2, 3], 'val': [20, 30], 'label': ['b', 'c']} + + Left join to keep all rows from the left side: + + >>> left.join(right, on="id", how="left").sort("id").to_pydict() + {'id': [1, 2, 3], 'val': [10, 20, 30], 'label': [None, 'b', 'c']} + + Use ``left_on`` / ``right_on`` when the key columns differ in name: + + >>> right2 = ctx.from_pydict({"rid": [2, 3], "label": ["b", "c"]}) + >>> left.join( + ... right2, left_on="id", right_on="rid" + ... ).sort("id").to_pydict() + {'id': [2, 3], 'val': [20, 30], 'rid': [2, 3], 'label': ['b', 'c']} """ if join_keys is not None: warnings.warn( @@ -1165,6 +1266,20 @@ def union(self, other: DataFrame, distinct: bool = False) -> DataFrame: Returns: DataFrame after union. 
+ + Examples: + Stack rows from both DataFrames, preserving duplicates: + + >>> ctx = dfn.SessionContext() + >>> df1 = ctx.from_pydict({"a": [1, 2]}) + >>> df2 = ctx.from_pydict({"a": [2, 3]}) + >>> df1.union(df2).sort("a").to_pydict() + {'a': [1, 2, 2, 3]} + + Deduplicate the combined result with ``distinct=True``: + + >>> df1.union(df2, distinct=True).sort("a").to_pydict() + {'a': [1, 2, 3]} """ return DataFrame(self.df.union(other.df, distinct)) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 32004656f..0f7f3ab5a 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -15,9 +15,31 @@ # specific language governing permissions and limitations # under the License. -"""This module supports expressions, one of the core concepts in DataFusion. - -See :ref:`Expressions` in the online documentation for more details. +""":py:class:`Expr` — the logical expression type used to build DataFusion queries. + +An :py:class:`Expr` represents a computation over columns or literals: a +column reference (``col("a")``), a literal (``lit(5)``), an operator +combination (``col("a") + lit(1)``), or the output of a function from +:py:mod:`datafusion.functions`. Expressions are passed to +:py:class:`~datafusion.dataframe.DataFrame` methods such as +:py:meth:`~datafusion.dataframe.DataFrame.select`, +:py:meth:`~datafusion.dataframe.DataFrame.filter`, +:py:meth:`~datafusion.dataframe.DataFrame.aggregate`, and +:py:meth:`~datafusion.dataframe.DataFrame.sort`. + +Convenience constructors are re-exported at the package level: +:py:func:`datafusion.col` / :py:func:`datafusion.column` for column references +and :py:func:`datafusion.lit` / :py:func:`datafusion.literal` for scalar +literals. 
+ +Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3]}) + >>> df.select((col("a") * lit(10)).alias("ten_a")).to_pydict() + {'ten_a': [10, 20, 30]} + +See :ref:`expressions` in the online documentation for details on available +operators and helpers. """ # ruff: noqa: PLC0415 @@ -221,6 +243,8 @@ "WindowExpr", "WindowFrame", "WindowFrameBound", + "coerce_to_expr", + "coerce_to_expr_or_none", "ensure_expr", "ensure_expr_list", ] @@ -233,6 +257,10 @@ def ensure_expr(value: Expr | Any) -> expr_internal.Expr: higher level APIs consistently require explicit :func:`~datafusion.col` or :func:`~datafusion.lit` expressions. + See Also: + :func:`coerce_to_expr` — the opposite behavior: *wraps* non-``Expr`` + values as literals instead of rejecting them. + Args: value: Candidate expression or other object. @@ -277,6 +305,41 @@ def _iter( return list(_iter(exprs)) +def coerce_to_expr(value: Any) -> Expr: + """Coerce a native Python value to an ``Expr`` literal, passing ``Expr`` through. + + This is the complement of :func:`ensure_expr`: where ``ensure_expr`` + *rejects* non-``Expr`` values, ``coerce_to_expr`` *wraps* them via + :meth:`Expr.literal` so that functions can accept native Python types + (``int``, ``float``, ``str``, ``bool``, etc.) alongside ``Expr``. + + Args: + value: An ``Expr`` instance (returned as-is) or a Python literal to wrap. + + Returns: + An ``Expr`` representing the value. + """ + if isinstance(value, Expr): + return value + return Expr.literal(value) + + +def coerce_to_expr_or_none(value: Any | None) -> Expr | None: + """Coerce a value to ``Expr`` or pass ``None`` through unchanged. + + Same as :func:`coerce_to_expr` but accepts ``None`` for optional parameters. + + Args: + value: An ``Expr`` instance, a Python literal to wrap, or ``None``. + + Returns: + An ``Expr`` representing the value, or ``None``. 
+ """ + if value is None: + return None + return coerce_to_expr(value) + + def _to_raw_expr(value: Expr | str) -> expr_internal.Expr: """Convert a Python expression or column name to its raw variant. diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 841cd9c0b..08062851a 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -14,7 +14,27 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""User functions for operating on :py:class:`~datafusion.expr.Expr`.""" +"""Scalar, aggregate, and window functions for :py:class:`~datafusion.expr.Expr`. + +Each function returns an :py:class:`~datafusion.expr.Expr` that can be combined +with other expressions and passed to +:py:class:`~datafusion.dataframe.DataFrame` methods such as +:py:meth:`~datafusion.dataframe.DataFrame.select`, +:py:meth:`~datafusion.dataframe.DataFrame.filter`, +:py:meth:`~datafusion.dataframe.DataFrame.aggregate`, and +:py:meth:`~datafusion.dataframe.DataFrame.window`. The module is conventionally +imported as ``F`` so calls read like ``F.sum(col("price"))``. + +Examples: + >>> from datafusion import functions as F + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3, 4]}) + >>> df.aggregate([], [F.sum(col("a")).alias("total")]).to_pydict() + {'total': [10]} + +See :ref:`aggregation` and :ref:`window_functions` in the online documentation +for categorized catalogs of aggregate and window functions. 
+""" from __future__ import annotations @@ -29,6 +49,8 @@ Expr, SortExpr, SortKey, + coerce_to_expr, + coerce_to_expr_or_none, expr_list_to_raw_expr_list, sort_list_to_raw_sort_list, sort_or_default, @@ -363,49 +385,52 @@ def nullif(expr1: Expr, expr2: Expr) -> Expr: return Expr(f.nullif(expr1.expr, expr2.expr)) -def encode(expr: Expr, encoding: Expr) -> Expr: +def encode(expr: Expr, encoding: Expr | str) -> Expr: """Encode the ``input``, using the ``encoding``. encoding can be base64 or hex. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.encode(dfn.col("a"), dfn.lit("base64")).alias("enc")) + ... dfn.functions.encode(dfn.col("a"), "base64").alias("enc")) >>> result.collect_column("enc")[0].as_py() 'aGVsbG8' """ + encoding = coerce_to_expr(encoding) return Expr(f.encode(expr.expr, encoding.expr)) -def decode(expr: Expr, encoding: Expr) -> Expr: +def decode(expr: Expr, encoding: Expr | str) -> Expr: """Decode the ``input``, using the ``encoding``. encoding can be base64 or hex. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["aGVsbG8="]}) >>> result = df.select( - ... dfn.functions.decode(dfn.col("a"), dfn.lit("base64")).alias("dec")) + ... dfn.functions.decode(dfn.col("a"), "base64").alias("dec")) >>> result.collect_column("dec")[0].as_py() b'hello' """ + encoding = coerce_to_expr(encoding) return Expr(f.decode(expr.expr, encoding.expr)) -def array_to_string(expr: Expr, delimiter: Expr) -> Expr: +def array_to_string(expr: Expr, delimiter: Expr | str) -> Expr: """Converts each element to its text representation. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2, 3]]}) >>> result = df.select( - ... dfn.functions.array_to_string(dfn.col("a"), dfn.lit(",")).alias("s")) + ... 
dfn.functions.array_to_string(dfn.col("a"), ",").alias("s")) >>> result.collect_column("s")[0].as_py() '1,2,3' """ + delimiter = coerce_to_expr(delimiter) return Expr(f.array_to_string(expr.expr, delimiter.expr.cast(pa.string()))) -def array_join(expr: Expr, delimiter: Expr) -> Expr: +def array_join(expr: Expr, delimiter: Expr | str) -> Expr: """Converts each element to its text representation. See Also: @@ -414,7 +439,7 @@ def array_join(expr: Expr, delimiter: Expr) -> Expr: return array_to_string(expr, delimiter) -def list_to_string(expr: Expr, delimiter: Expr) -> Expr: +def list_to_string(expr: Expr, delimiter: Expr | str) -> Expr: """Converts each element to its text representation. See Also: @@ -423,7 +448,7 @@ def list_to_string(expr: Expr, delimiter: Expr) -> Expr: return array_to_string(expr, delimiter) -def list_join(expr: Expr, delimiter: Expr) -> Expr: +def list_join(expr: Expr, delimiter: Expr | str) -> Expr: """Converts each element to its text representation. See Also: @@ -459,7 +484,7 @@ def in_list(arg: Expr, values: list[Expr], negated: bool = False) -> Expr: return Expr(f.in_list(arg.expr, values, negated)) -def digest(value: Expr, method: Expr) -> Expr: +def digest(value: Expr, method: Expr | str) -> Expr: """Computes the binary hash of an expression using the specified algorithm. Standard algorithms are md5, sha224, sha256, sha384, sha512, blake2s, @@ -469,24 +494,26 @@ def digest(value: Expr, method: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.digest(dfn.col("a"), dfn.lit("md5")).alias("d")) + ... 
dfn.functions.digest(dfn.col("a"), "md5").alias("d")) >>> len(result.collect_column("d")[0].as_py()) > 0 True """ + method = coerce_to_expr(method) return Expr(f.digest(value.expr, method.expr)) -def contains(string: Expr, search_str: Expr) -> Expr: +def contains(string: Expr, search_str: Expr | str) -> Expr: """Returns true if ``search_str`` is found within ``string`` (case-sensitive). Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["the quick brown fox"]}) >>> result = df.select( - ... dfn.functions.contains(dfn.col("a"), dfn.lit("brown")).alias("c")) + ... dfn.functions.contains(dfn.col("a"), "brown").alias("c")) >>> result.collect_column("c")[0].as_py() True """ + search_str = coerce_to_expr(search_str) return Expr(f.contains(string.expr, search_str.expr)) @@ -949,17 +976,18 @@ def degrees(arg: Expr) -> Expr: return Expr(f.degrees(arg.expr)) -def ends_with(arg: Expr, suffix: Expr) -> Expr: +def ends_with(arg: Expr, suffix: Expr | str) -> Expr: """Returns true if the ``string`` ends with the ``suffix``, false otherwise. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["abc","b","c"]}) >>> ends_with_df = df.select( - ... dfn.functions.ends_with(dfn.col("a"), dfn.lit("c")).alias("ends_with")) + ... dfn.functions.ends_with(dfn.col("a"), "c").alias("ends_with")) >>> ends_with_df.collect_column("ends_with")[0].as_py() True """ + suffix = coerce_to_expr(suffix) return Expr(f.ends_with(arg.expr, suffix.expr)) @@ -991,7 +1019,7 @@ def factorial(arg: Expr) -> Expr: return Expr(f.factorial(arg.expr)) -def find_in_set(string: Expr, string_list: Expr) -> Expr: +def find_in_set(string: Expr, string_list: Expr | str) -> Expr: """Find a string in a list of strings. Returns a value in the range of 1 to N if the string is in the string list @@ -1003,10 +1031,11 @@ def find_in_set(string: Expr, string_list: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["b"]}) >>> result = df.select( - ... 
dfn.functions.find_in_set(dfn.col("a"), dfn.lit("a,b,c")).alias("pos")) + ... dfn.functions.find_in_set(dfn.col("a"), "a,b,c").alias("pos")) >>> result.collect_column("pos")[0].as_py() 2 """ + string_list = coerce_to_expr(string_list) return Expr(f.find_in_set(string.expr, string_list.expr)) @@ -1082,7 +1111,7 @@ def initcap(string: Expr) -> Expr: return Expr(f.initcap(string.expr)) -def instr(string: Expr, substring: Expr) -> Expr: +def instr(string: Expr, substring: Expr | str) -> Expr: """Finds the position from where the ``substring`` matches the ``string``. See Also: @@ -1138,31 +1167,33 @@ def least(*args: Expr) -> Expr: return Expr(f.least(*exprs)) -def left(string: Expr, n: Expr) -> Expr: +def left(string: Expr, n: Expr | int) -> Expr: """Returns the first ``n`` characters in the ``string``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["the cat"]}) >>> left_df = df.select( - ... dfn.functions.left(dfn.col("a"), dfn.lit(3)).alias("left")) + ... dfn.functions.left(dfn.col("a"), 3).alias("left")) >>> left_df.collect_column("left")[0].as_py() 'the' """ + n = coerce_to_expr(n) return Expr(f.left(string.expr, n.expr)) -def levenshtein(string1: Expr, string2: Expr) -> Expr: +def levenshtein(string1: Expr, string2: Expr | str) -> Expr: """Returns the Levenshtein distance between the two given strings. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["kitten"]}) >>> result = df.select( - ... dfn.functions.levenshtein(dfn.col("a"), dfn.lit("sitting")).alias("d")) + ... dfn.functions.levenshtein(dfn.col("a"), "sitting").alias("d")) >>> result.collect_column("d")[0].as_py() 3 """ + string2 = coerce_to_expr(string2) return Expr(f.levenshtein(string1.expr, string2.expr)) @@ -1179,18 +1210,19 @@ def ln(arg: Expr) -> Expr: return Expr(f.ln(arg.expr)) -def log(base: Expr, num: Expr) -> Expr: +def log(base: Expr | int | float, num: Expr) -> Expr: # noqa: PYI041 """Returns the logarithm of a number for a particular ``base``. 
Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [100.0]}) >>> result = df.select( - ... dfn.functions.log(dfn.lit(10.0), dfn.col("a")).alias("log") + ... dfn.functions.log(10.0, dfn.col("a")).alias("log") ... ) >>> result.collect_column("log")[0].as_py() 2.0 """ + base = coerce_to_expr(base) return Expr(f.log(base.expr, num.expr)) @@ -1233,7 +1265,7 @@ def lower(arg: Expr) -> Expr: return Expr(f.lower(arg.expr)) -def lpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: +def lpad(string: Expr, count: Expr | int, characters: Expr | str | None = None) -> Expr: """Add left padding to a string. Extends the string to length length by prepending the characters fill (a @@ -1244,9 +1276,7 @@ def lpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["the cat", "a hat"]}) >>> lpad_df = df.select( - ... dfn.functions.lpad( - ... dfn.col("a"), dfn.lit(6) - ... ).alias("lpad")) + ... dfn.functions.lpad(dfn.col("a"), 6).alias("lpad")) >>> lpad_df.collect_column("lpad")[0].as_py() 'the ca' >>> lpad_df.collect_column("lpad")[1].as_py() @@ -1254,12 +1284,13 @@ def lpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: >>> result = df.select( ... dfn.functions.lpad( - ... dfn.col("a"), dfn.lit(10), characters=dfn.lit(".") + ... dfn.col("a"), 10, characters="." ... 
).alias("lpad")) >>> result.collect_column("lpad")[0].as_py() '...the cat' """ - characters = characters if characters is not None else Expr.literal(" ") + count = coerce_to_expr(count) + characters = coerce_to_expr(characters if characters is not None else " ") return Expr(f.lpad(string.expr, count.expr, characters.expr)) @@ -1354,7 +1385,10 @@ def octet_length(arg: Expr) -> Expr: def overlay( - string: Expr, substring: Expr, start: Expr, length: Expr | None = None + string: Expr, + substring: Expr | str, + start: Expr | int, + length: Expr | int | None = None, ) -> Expr: """Replace a substring with a new substring. @@ -1365,13 +1399,15 @@ def overlay( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["abcdef"]}) >>> result = df.select( - ... dfn.functions.overlay(dfn.col("a"), dfn.lit("XY"), dfn.lit(3), - ... dfn.lit(2)).alias("o")) + ... dfn.functions.overlay(dfn.col("a"), "XY", 3, 2).alias("o")) >>> result.collect_column("o")[0].as_py() 'abXYef' """ + substring = coerce_to_expr(substring) + start = coerce_to_expr(start) if length is None: return Expr(f.overlay(string.expr, substring.expr, start.expr)) + length = coerce_to_expr(length) return Expr(f.overlay(string.expr, substring.expr, start.expr, length.expr)) @@ -1391,7 +1427,7 @@ def pi() -> Expr: return Expr(f.pi()) -def position(string: Expr, substring: Expr) -> Expr: +def position(string: Expr, substring: Expr | str) -> Expr: """Finds the position from where the ``substring`` matches the ``string``. See Also: @@ -1400,22 +1436,23 @@ def position(string: Expr, substring: Expr) -> Expr: return strpos(string, substring) -def power(base: Expr, exponent: Expr) -> Expr: +def power(base: Expr, exponent: Expr | int | float) -> Expr: # noqa: PYI041 """Returns ``base`` raised to the power of ``exponent``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [2.0]}) >>> result = df.select( - ... dfn.functions.power(dfn.col("a"), dfn.lit(3.0)).alias("pow") + ... 
dfn.functions.power(dfn.col("a"), 3.0).alias("pow") ... ) >>> result.collect_column("pow")[0].as_py() 8.0 """ + exponent = coerce_to_expr(exponent) return Expr(f.power(base.expr, exponent.expr)) -def pow(base: Expr, exponent: Expr) -> Expr: +def pow(base: Expr, exponent: Expr | int | float) -> Expr: # noqa: PYI041 """Returns ``base`` raised to the power of ``exponent``. See Also: @@ -1440,7 +1477,9 @@ def radians(arg: Expr) -> Expr: return Expr(f.radians(arg.expr)) -def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: +def regexp_like( + string: Expr, regex: Expr | str, flags: Expr | str | None = None +) -> Expr: r"""Find if any regular expression (regex) matches exist. Tests a string using a regular expression returning true if at least one match, @@ -1450,9 +1489,7 @@ def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello123"]}) >>> result = df.select( - ... dfn.functions.regexp_like( - ... dfn.col("a"), dfn.lit("\\d+") - ... ).alias("m") + ... dfn.functions.regexp_like(dfn.col("a"), "\\d+").alias("m") ... ) >>> result.collect_column("m")[0].as_py() True @@ -1461,19 +1498,24 @@ def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> result = df.select( ... dfn.functions.regexp_like( - ... dfn.col("a"), dfn.lit("HELLO"), - ... flags=dfn.lit("i"), + ... dfn.col("a"), "HELLO", flags="i", ... ).alias("m") ... 
) >>> result.collect_column("m")[0].as_py() True """ - if flags is not None: - flags = flags.expr - return Expr(f.regexp_like(string.expr, regex.expr, flags)) + regex = coerce_to_expr(regex) + flags = coerce_to_expr_or_none(flags) + return Expr( + f.regexp_like( + string.expr, regex.expr, flags.expr if flags is not None else None + ) + ) -def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: +def regexp_match( + string: Expr, regex: Expr | str, flags: Expr | str | None = None +) -> Expr: r"""Perform regular expression (regex) matching. Returns an array with each element containing the leftmost-first match of the @@ -1483,9 +1525,7 @@ def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello 42 world"]}) >>> result = df.select( - ... dfn.functions.regexp_match( - ... dfn.col("a"), dfn.lit("(\\d+)") - ... ).alias("m") + ... dfn.functions.regexp_match(dfn.col("a"), "(\\d+)").alias("m") ... ) >>> result.collect_column("m")[0].as_py() ['42'] @@ -1494,20 +1534,26 @@ def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> result = df.select( ... dfn.functions.regexp_match( - ... dfn.col("a"), dfn.lit("(HELLO)"), - ... flags=dfn.lit("i"), + ... dfn.col("a"), "(HELLO)", flags="i", ... ).alias("m") ... ) >>> result.collect_column("m")[0].as_py() ['hello'] """ - if flags is not None: - flags = flags.expr - return Expr(f.regexp_match(string.expr, regex.expr, flags)) + regex = coerce_to_expr(regex) + flags = coerce_to_expr_or_none(flags) + return Expr( + f.regexp_match( + string.expr, regex.expr, flags.expr if flags is not None else None + ) + ) def regexp_replace( - string: Expr, pattern: Expr, replacement: Expr, flags: Expr | None = None + string: Expr, + pattern: Expr | str, + replacement: Expr | str, + flags: Expr | str | None = None, ) -> Expr: r"""Replaces substring(s) matching a PCRE-like regular expression. 
@@ -1522,8 +1568,7 @@ def regexp_replace( >>> df = ctx.from_pydict({"a": ["hello 42"]}) >>> result = df.select( ... dfn.functions.regexp_replace( - ... dfn.col("a"), dfn.lit("\\d+"), - ... dfn.lit("XX") + ... dfn.col("a"), "\\d+", "XX" ... ).alias("r") ... ) >>> result.collect_column("r")[0].as_py() @@ -1534,20 +1579,30 @@ def regexp_replace( >>> df = ctx.from_pydict({"a": ["a1 b2 c3"]}) >>> result = df.select( ... dfn.functions.regexp_replace( - ... dfn.col("a"), dfn.lit("\\d+"), - ... dfn.lit("X"), flags=dfn.lit("g"), + ... dfn.col("a"), "\\d+", "X", flags="g", ... ).alias("r") ... ) >>> result.collect_column("r")[0].as_py() 'aX bX cX' """ - if flags is not None: - flags = flags.expr - return Expr(f.regexp_replace(string.expr, pattern.expr, replacement.expr, flags)) + pattern = coerce_to_expr(pattern) + replacement = coerce_to_expr(replacement) + flags = coerce_to_expr_or_none(flags) + return Expr( + f.regexp_replace( + string.expr, + pattern.expr, + replacement.expr, + flags.expr if flags is not None else None, + ) + ) def regexp_count( - string: Expr, pattern: Expr, start: Expr | None = None, flags: Expr | None = None + string: Expr, + pattern: Expr | str, + start: Expr | int | None = None, + flags: Expr | str | None = None, ) -> Expr: """Returns the number of matches in a string. @@ -1558,9 +1613,7 @@ def regexp_count( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["abcabc"]}) >>> result = df.select( - ... dfn.functions.regexp_count( - ... dfn.col("a"), dfn.lit("abc") - ... ).alias("c")) + ... dfn.functions.regexp_count(dfn.col("a"), "abc").alias("c")) >>> result.collect_column("c")[0].as_py() 2 @@ -1569,25 +1622,31 @@ def regexp_count( >>> result = df.select( ... dfn.functions.regexp_count( - ... dfn.col("a"), dfn.lit("ABC"), - ... start=dfn.lit(4), flags=dfn.lit("i"), + ... dfn.col("a"), "ABC", start=4, flags="i", ... 
).alias("c")) >>> result.collect_column("c")[0].as_py() 1 """ - if flags is not None: - flags = flags.expr - start = start.expr if start is not None else start - return Expr(f.regexp_count(string.expr, pattern.expr, start, flags)) + pattern = coerce_to_expr(pattern) + start = coerce_to_expr_or_none(start) + flags = coerce_to_expr_or_none(flags) + return Expr( + f.regexp_count( + string.expr, + pattern.expr, + start.expr if start is not None else None, + flags.expr if flags is not None else None, + ) + ) def regexp_instr( values: Expr, - regex: Expr, - start: Expr | None = None, - n: Expr | None = None, - flags: Expr | None = None, - sub_expr: Expr | None = None, + regex: Expr | str, + start: Expr | int | None = None, + n: Expr | int | None = None, + flags: Expr | str | None = None, + sub_expr: Expr | int | None = None, ) -> Expr: r"""Returns the position of a regular expression match in a string. @@ -1603,9 +1662,7 @@ def regexp_instr( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello 42 world"]}) >>> result = df.select( - ... dfn.functions.regexp_instr( - ... dfn.col("a"), dfn.lit("\\d+") - ... ).alias("pos") + ... dfn.functions.regexp_instr(dfn.col("a"), "\\d+").alias("pos") ... ) >>> result.collect_column("pos")[0].as_py() 7 @@ -1616,9 +1673,8 @@ def regexp_instr( >>> df = ctx.from_pydict({"a": ["abc ABC abc"]}) >>> result = df.select( ... dfn.functions.regexp_instr( - ... dfn.col("a"), dfn.lit("abc"), - ... start=dfn.lit(2), n=dfn.lit(1), - ... flags=dfn.lit("i"), + ... dfn.col("a"), "abc", + ... start=2, n=1, flags="i", ... ).alias("pos") ... ) >>> result.collect_column("pos")[0].as_py() @@ -1628,56 +1684,58 @@ def regexp_instr( >>> result = df.select( ... dfn.functions.regexp_instr( - ... dfn.col("a"), dfn.lit("(abc)"), - ... sub_expr=dfn.lit(1), + ... dfn.col("a"), "(abc)", sub_expr=1, ... ).alias("pos") ... 
) >>> result.collect_column("pos")[0].as_py() 1 """ - start = start.expr if start is not None else None - n = n.expr if n is not None else None - flags = flags.expr if flags is not None else None - sub_expr = sub_expr.expr if sub_expr is not None else None + regex = coerce_to_expr(regex) + start = coerce_to_expr_or_none(start) + n = coerce_to_expr_or_none(n) + flags = coerce_to_expr_or_none(flags) + sub_expr = coerce_to_expr_or_none(sub_expr) return Expr( f.regexp_instr( values.expr, regex.expr, - start, - n, - flags, - sub_expr, + start.expr if start is not None else None, + n.expr if n is not None else None, + flags.expr if flags is not None else None, + sub_expr.expr if sub_expr is not None else None, ) ) -def repeat(string: Expr, n: Expr) -> Expr: +def repeat(string: Expr, n: Expr | int) -> Expr: """Repeats the ``string`` to ``n`` times. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["ha"]}) >>> result = df.select( - ... dfn.functions.repeat(dfn.col("a"), dfn.lit(3)).alias("r")) + ... dfn.functions.repeat(dfn.col("a"), 3).alias("r")) >>> result.collect_column("r")[0].as_py() 'hahaha' """ + n = coerce_to_expr(n) return Expr(f.repeat(string.expr, n.expr)) -def replace(string: Expr, from_val: Expr, to_val: Expr) -> Expr: +def replace(string: Expr, from_val: Expr | str, to_val: Expr | str) -> Expr: """Replaces all occurrences of ``from_val`` with ``to_val`` in the ``string``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello world"]}) >>> result = df.select( - ... dfn.functions.replace(dfn.col("a"), dfn.lit("world"), - ... dfn.lit("there")).alias("r")) + ... 
dfn.functions.replace(dfn.col("a"), "world", "there").alias("r")) >>> result.collect_column("r")[0].as_py() 'hello there' """ + from_val = coerce_to_expr(from_val) + to_val = coerce_to_expr(to_val) return Expr(f.replace(string.expr, from_val.expr, to_val.expr)) @@ -1694,39 +1752,39 @@ def reverse(arg: Expr) -> Expr: return Expr(f.reverse(arg.expr)) -def right(string: Expr, n: Expr) -> Expr: +def right(string: Expr, n: Expr | int) -> Expr: """Returns the last ``n`` characters in the ``string``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) - >>> result = df.select(dfn.functions.right(dfn.col("a"), dfn.lit(3)).alias("r")) + >>> result = df.select(dfn.functions.right(dfn.col("a"), 3).alias("r")) >>> result.collect_column("r")[0].as_py() 'llo' """ + n = coerce_to_expr(n) return Expr(f.right(string.expr, n.expr)) -def round(value: Expr, decimal_places: Expr | None = None) -> Expr: +def round(value: Expr, decimal_places: Expr | int | None = None) -> Expr: """Round the argument to the nearest integer. If the optional ``decimal_places`` is specified, round to the nearest number of decimal places. You can specify a negative number of decimal places. For example - ``round(lit(125.2345), lit(-2))`` would yield a value of ``100.0``. + ``round(lit(125.2345), -2)`` would yield a value of ``100.0``. 
Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [1.567]}) - >>> result = df.select(dfn.functions.round(dfn.col("a"), dfn.lit(2)).alias("r")) + >>> result = df.select(dfn.functions.round(dfn.col("a"), 2).alias("r")) >>> result.collect_column("r")[0].as_py() 1.57 """ - if decimal_places is None: - decimal_places = Expr.literal(0) + decimal_places = coerce_to_expr(decimal_places if decimal_places is not None else 0) return Expr(f.round(value.expr, decimal_places.expr)) -def rpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: +def rpad(string: Expr, count: Expr | int, characters: Expr | str | None = None) -> Expr: """Add right padding to a string. Extends the string to length length by appending the characters fill (a space @@ -1736,11 +1794,12 @@ def rpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hi"]}) >>> result = df.select( - ... dfn.functions.rpad(dfn.col("a"), dfn.lit(5), dfn.lit("!")).alias("r")) + ... dfn.functions.rpad(dfn.col("a"), 5, "!").alias("r")) >>> result.collect_column("r")[0].as_py() 'hi!!!' """ - characters = characters if characters is not None else Expr.literal(" ") + count = coerce_to_expr(count) + characters = coerce_to_expr(characters if characters is not None else " ") return Expr(f.rpad(string.expr, count.expr, characters.expr)) @@ -1856,7 +1915,7 @@ def sinh(arg: Expr) -> Expr: return Expr(f.sinh(arg.expr)) -def split_part(string: Expr, delimiter: Expr, index: Expr) -> Expr: +def split_part(string: Expr, delimiter: Expr | str, index: Expr | int) -> Expr: """Split a string and return one part. Splits a string based on a delimiter and picks out the desired field based @@ -1866,12 +1925,12 @@ def split_part(string: Expr, delimiter: Expr, index: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["a,b,c"]}) >>> result = df.select( - ... dfn.functions.split_part( - ... 
dfn.col("a"), dfn.lit(","), dfn.lit(2) - ... ).alias("s")) + ... dfn.functions.split_part(dfn.col("a"), ",", 2).alias("s")) >>> result.collect_column("s")[0].as_py() 'b' """ + delimiter = coerce_to_expr(delimiter) + index = coerce_to_expr(index) return Expr(f.split_part(string.expr, delimiter.expr, index.expr)) @@ -1888,49 +1947,52 @@ def sqrt(arg: Expr) -> Expr: return Expr(f.sqrt(arg.expr)) -def starts_with(string: Expr, prefix: Expr) -> Expr: +def starts_with(string: Expr, prefix: Expr | str) -> Expr: """Returns true if string starts with prefix. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello_from_datafusion"]}) >>> result = df.select( - ... dfn.functions.starts_with(dfn.col("a"), dfn.lit("hello")).alias("sw")) + ... dfn.functions.starts_with(dfn.col("a"), "hello").alias("sw")) >>> result.collect_column("sw")[0].as_py() True """ + prefix = coerce_to_expr(prefix) return Expr(f.starts_with(string.expr, prefix.expr)) -def strpos(string: Expr, substring: Expr) -> Expr: +def strpos(string: Expr, substring: Expr | str) -> Expr: """Finds the position from where the ``substring`` matches the ``string``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.strpos(dfn.col("a"), dfn.lit("llo")).alias("pos")) + ... dfn.functions.strpos(dfn.col("a"), "llo").alias("pos")) >>> result.collect_column("pos")[0].as_py() 3 """ + substring = coerce_to_expr(substring) return Expr(f.strpos(string.expr, substring.expr)) -def substr(string: Expr, position: Expr) -> Expr: +def substr(string: Expr, position: Expr | int) -> Expr: """Substring from the ``position`` to the end. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.substr(dfn.col("a"), dfn.lit(3)).alias("s")) + ... 
dfn.functions.substr(dfn.col("a"), 3).alias("s")) >>> result.collect_column("s")[0].as_py() 'llo' """ + position = coerce_to_expr(position) return Expr(f.substr(string.expr, position.expr)) -def substr_index(string: Expr, delimiter: Expr, count: Expr) -> Expr: +def substr_index(string: Expr, delimiter: Expr | str, count: Expr | int) -> Expr: """Returns an indexed substring. The return will be the ``string`` from before ``count`` occurrences of @@ -1940,27 +2002,28 @@ def substr_index(string: Expr, delimiter: Expr, count: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["a.b.c"]}) >>> result = df.select( - ... dfn.functions.substr_index(dfn.col("a"), dfn.lit("."), - ... dfn.lit(2)).alias("s")) + ... dfn.functions.substr_index(dfn.col("a"), ".", 2).alias("s")) >>> result.collect_column("s")[0].as_py() 'a.b' """ + delimiter = coerce_to_expr(delimiter) + count = coerce_to_expr(count) return Expr(f.substr_index(string.expr, delimiter.expr, count.expr)) -def substring(string: Expr, position: Expr, length: Expr) -> Expr: +def substring(string: Expr, position: Expr | int, length: Expr | int) -> Expr: """Substring from the ``position`` with ``length`` characters. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello world"]}) >>> result = df.select( - ... dfn.functions.substring( - ... dfn.col("a"), dfn.lit(1), dfn.lit(5) - ... ).alias("s")) + ... dfn.functions.substring(dfn.col("a"), 1, 5).alias("s")) >>> result.collect_column("s")[0].as_py() 'hello' """ + position = coerce_to_expr(position) + length = coerce_to_expr(length) return Expr(f.substring(string.expr, position.expr, length.expr)) @@ -2033,7 +2096,7 @@ def current_timestamp() -> Expr: return now() -def to_char(arg: Expr, formatter: Expr) -> Expr: +def to_char(arg: Expr, formatter: Expr | str) -> Expr: """Returns a string representation of a date, time, timestamp or duration. For usage of ``formatter`` see the rust chrono package ``strftime`` package. 
@@ -2046,16 +2109,17 @@ def to_char(arg: Expr, formatter: Expr) -> Expr: >>> result = df.select( ... dfn.functions.to_char( ... dfn.functions.to_timestamp(dfn.col("a")), - ... dfn.lit("%Y/%m/%d"), + ... "%Y/%m/%d", ... ).alias("formatted") ... ) >>> result.collect_column("formatted")[0].as_py() '2021/01/01' """ + formatter = coerce_to_expr(formatter) return Expr(f.to_char(arg.expr, formatter.expr)) -def date_format(arg: Expr, formatter: Expr) -> Expr: +def date_format(arg: Expr, formatter: Expr | str) -> Expr: """Returns a string representation of a date, time, timestamp or duration. See Also: @@ -2267,7 +2331,7 @@ def current_time() -> Expr: return Expr(f.current_time()) -def datepart(part: Expr, date: Expr) -> Expr: +def datepart(part: Expr | str, date: Expr) -> Expr: """Return a specified part of a date. See Also: @@ -2276,22 +2340,28 @@ def datepart(part: Expr, date: Expr) -> Expr: return date_part(part, date) -def date_part(part: Expr, date: Expr) -> Expr: +def date_part(part: Expr | str, date: Expr) -> Expr: """Extracts a subfield from the date. + Args: + part: The part of the date to extract. Must be one of ``"year"``, + ``"month"``, ``"day"``, ``"hour"``, ``"minute"``, ``"second"``, etc. + date: The date expression to extract from. + Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["2021-07-15T00:00:00"]}) >>> df = df.select(dfn.functions.to_timestamp(dfn.col("a")).alias("a")) >>> result = df.select( - ... dfn.functions.date_part(dfn.lit("year"), dfn.col("a")).alias("y")) + ... dfn.functions.date_part("year", dfn.col("a")).alias("y")) >>> result.collect_column("y")[0].as_py() 2021 """ + part = coerce_to_expr(part) return Expr(f.date_part(part.expr, date.expr)) -def extract(part: Expr, date: Expr) -> Expr: +def extract(part: Expr | str, date: Expr) -> Expr: """Extracts a subfield from the date. 
See Also: @@ -2300,25 +2370,29 @@ def extract(part: Expr, date: Expr) -> Expr: return date_part(part, date) -def date_trunc(part: Expr, date: Expr) -> Expr: +def date_trunc(part: Expr | str, date: Expr) -> Expr: """Truncates the date to a specified level of precision. + Args: + part: The precision to truncate to. Must be one of ``"year"``, + ``"month"``, ``"day"``, ``"hour"``, ``"minute"``, ``"second"``, etc. + date: The date expression to truncate. + Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["2021-07-15T12:34:56"]}) >>> df = df.select(dfn.functions.to_timestamp(dfn.col("a")).alias("a")) >>> result = df.select( - ... dfn.functions.date_trunc( - ... dfn.lit("month"), dfn.col("a") - ... ).alias("t") + ... dfn.functions.date_trunc("month", dfn.col("a")).alias("t") ... ) >>> str(result.collect_column("t")[0].as_py()) '2021-07-01 00:00:00' """ + part = coerce_to_expr(part) return Expr(f.date_trunc(part.expr, date.expr)) -def datetrunc(part: Expr, date: Expr) -> Expr: +def datetrunc(part: Expr | str, date: Expr) -> Expr: """Truncates the date to a specified level of precision. See Also: @@ -2379,18 +2453,19 @@ def make_time(hour: Expr, minute: Expr, second: Expr) -> Expr: return Expr(f.make_time(hour.expr, minute.expr, second.expr)) -def translate(string: Expr, from_val: Expr, to_val: Expr) -> Expr: +def translate(string: Expr, from_val: Expr | str, to_val: Expr | str) -> Expr: """Replaces the characters in ``from_val`` with the counterpart in ``to_val``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.translate(dfn.col("a"), dfn.lit("helo"), - ... dfn.lit("HELO")).alias("t")) + ... 
dfn.functions.translate(dfn.col("a"), "helo", "HELO").alias("t")) >>> result.collect_column("t")[0].as_py() 'HELLO' """ + from_val = coerce_to_expr(from_val) + to_val = coerce_to_expr(to_val) return Expr(f.translate(string.expr, from_val.expr, to_val.expr)) @@ -2407,27 +2482,24 @@ def trim(arg: Expr) -> Expr: return Expr(f.trim(arg.expr)) -def trunc(num: Expr, precision: Expr | None = None) -> Expr: +def trunc(num: Expr, precision: Expr | int | None = None) -> Expr: """Truncate the number toward zero with optional precision. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [1.567]}) >>> result = df.select( - ... dfn.functions.trunc( - ... dfn.col("a") - ... ).alias("t")) + ... dfn.functions.trunc(dfn.col("a")).alias("t")) >>> result.collect_column("t")[0].as_py() 1.0 >>> result = df.select( - ... dfn.functions.trunc( - ... dfn.col("a"), precision=dfn.lit(2) - ... ).alias("t")) + ... dfn.functions.trunc(dfn.col("a"), precision=2).alias("t")) >>> result.collect_column("t")[0].as_py() 1.56 """ if precision is not None: + precision = coerce_to_expr(precision) return Expr(f.trunc(num.expr, precision.expr)) return Expr(f.trunc(num.expr)) @@ -2908,17 +2980,18 @@ def list_dims(array: Expr) -> Expr: return array_dims(array) -def array_element(array: Expr, n: Expr) -> Expr: +def array_element(array: Expr, n: Expr | int) -> Expr: """Extracts the element with the index n from the array. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[10, 20, 30]]}) >>> result = df.select( - ... dfn.functions.array_element(dfn.col("a"), dfn.lit(2)).alias("result")) + ... 
dfn.functions.array_element(dfn.col("a"), 2).alias("result")) >>> result.collect_column("result")[0].as_py() 20 """ + n = coerce_to_expr(n) return Expr(f.array_element(array.expr, n.expr)) @@ -2944,7 +3017,7 @@ def list_empty(array: Expr) -> Expr: return array_empty(array) -def array_extract(array: Expr, n: Expr) -> Expr: +def array_extract(array: Expr, n: Expr | int) -> Expr: """Extracts the element with the index n from the array. See Also: @@ -2953,7 +3026,7 @@ def array_extract(array: Expr, n: Expr) -> Expr: return array_element(array, n) -def list_element(array: Expr, n: Expr) -> Expr: +def list_element(array: Expr, n: Expr | int) -> Expr: """Extracts the element with the index n from the array. See Also: @@ -2962,7 +3035,7 @@ def list_element(array: Expr, n: Expr) -> Expr: return array_element(array, n) -def list_extract(array: Expr, n: Expr) -> Expr: +def list_extract(array: Expr, n: Expr | int) -> Expr: """Extracts the element with the index n from the array. See Also: @@ -3312,22 +3385,24 @@ def list_remove(array: Expr, element: Expr) -> Expr: return array_remove(array, element) -def array_remove_n(array: Expr, element: Expr, max: Expr) -> Expr: +def array_remove_n(array: Expr, element: Expr, max: Expr | int) -> Expr: """Removes the first ``max`` elements from the array equal to the given value. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2, 1, 1]]}) >>> result = df.select( - ... dfn.functions.array_remove_n(dfn.col("a"), dfn.lit(1), - ... dfn.lit(2)).alias("result")) + ... dfn.functions.array_remove_n( + ... dfn.col("a"), dfn.lit(1), 2 + ... 
).alias("result")) >>> result.collect_column("result")[0].as_py() [2, 1] """ + max = coerce_to_expr(max) return Expr(f.array_remove_n(array.expr, element.expr, max.expr)) -def list_remove_n(array: Expr, element: Expr, max: Expr) -> Expr: +def list_remove_n(array: Expr, element: Expr, max: Expr | int) -> Expr: """Removes the first ``max`` elements from the array equal to the given value. See Also: @@ -3361,21 +3436,22 @@ def list_remove_all(array: Expr, element: Expr) -> Expr: return array_remove_all(array, element) -def array_repeat(element: Expr, count: Expr) -> Expr: +def array_repeat(element: Expr, count: Expr | int) -> Expr: """Returns an array containing ``element`` ``count`` times. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [1]}) >>> result = df.select( - ... dfn.functions.array_repeat(dfn.lit(3), dfn.lit(3)).alias("result")) + ... dfn.functions.array_repeat(dfn.lit(3), 3).alias("result")) >>> result.collect_column("result")[0].as_py() [3, 3, 3] """ + count = coerce_to_expr(count) return Expr(f.array_repeat(element.expr, count.expr)) -def list_repeat(element: Expr, count: Expr) -> Expr: +def list_repeat(element: Expr, count: Expr | int) -> Expr: """Returns an array containing ``element`` ``count`` times. See Also: @@ -3408,7 +3484,7 @@ def list_replace(array: Expr, from_val: Expr, to_val: Expr) -> Expr: return array_replace(array, from_val, to_val) -def array_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Expr: +def array_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr | int) -> Expr: """Replace ``n`` occurrences of ``from_val`` with ``to_val``. Replaces the first ``max`` occurrences of the specified element with another @@ -3418,15 +3494,17 @@ def array_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Exp >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2, 1, 1]]}) >>> result = df.select( - ... 
dfn.functions.array_replace_n(dfn.col("a"), dfn.lit(1), dfn.lit(9), - ... dfn.lit(2)).alias("result")) + ... dfn.functions.array_replace_n( + ... dfn.col("a"), dfn.lit(1), dfn.lit(9), 2 + ... ).alias("result")) >>> result.collect_column("result")[0].as_py() [9, 2, 9, 1] """ + max = coerce_to_expr(max) return Expr(f.array_replace_n(array.expr, from_val.expr, to_val.expr, max.expr)) -def list_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Expr: +def list_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr | int) -> Expr: """Replace ``n`` occurrences of ``from_val`` with ``to_val``. Replaces the first ``max`` occurrences of the specified element with another @@ -3509,7 +3587,10 @@ def list_sort(array: Expr, descending: bool = False, null_first: bool = False) - def array_slice( - array: Expr, begin: Expr, end: Expr, stride: Expr | None = None + array: Expr, + begin: Expr | int, + end: Expr | int, + stride: Expr | int | None = None, ) -> Expr: """Returns a slice of the array. @@ -3517,9 +3598,7 @@ def array_slice( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2, 3, 4]]}) >>> result = df.select( - ... dfn.functions.array_slice( - ... dfn.col("a"), dfn.lit(2), dfn.lit(3) - ... ).alias("result")) + ... dfn.functions.array_slice(dfn.col("a"), 2, 3).alias("result")) >>> result.collect_column("result")[0].as_py() [2, 3] @@ -3527,18 +3606,27 @@ def array_slice( >>> result = df.select( ... dfn.functions.array_slice( - ... dfn.col("a"), dfn.lit(1), dfn.lit(4), - ... stride=dfn.lit(2), + ... dfn.col("a"), 1, 4, stride=2, ... 
).alias("result")) >>> result.collect_column("result")[0].as_py() [1, 3] """ - if stride is not None: - stride = stride.expr - return Expr(f.array_slice(array.expr, begin.expr, end.expr, stride)) + begin = coerce_to_expr(begin) + end = coerce_to_expr(end) + stride = coerce_to_expr_or_none(stride) + return Expr( + f.array_slice( + array.expr, + begin.expr, + end.expr, + stride.expr if stride is not None else None, + ) + ) -def list_slice(array: Expr, begin: Expr, end: Expr, stride: Expr | None = None) -> Expr: +def list_slice( + array: Expr, begin: Expr | int, end: Expr | int, stride: Expr | int | None = None +) -> Expr: """Returns a slice of the array. See Also: @@ -3630,7 +3718,7 @@ def list_except(array1: Expr, array2: Expr) -> Expr: return array_except(array1, array2) -def array_resize(array: Expr, size: Expr, value: Expr) -> Expr: +def array_resize(array: Expr, size: Expr | int, value: Expr) -> Expr: """Returns an array with the specified size filled. If ``size`` is greater than the ``array`` length, the additional entries will @@ -3640,15 +3728,15 @@ def array_resize(array: Expr, size: Expr, value: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2]]}) >>> result = df.select( - ... dfn.functions.array_resize(dfn.col("a"), dfn.lit(4), - ... dfn.lit(0)).alias("result")) + ... dfn.functions.array_resize(dfn.col("a"), 4, dfn.lit(0)).alias("result")) >>> result.collect_column("result")[0].as_py() [1, 2, 0, 0] """ + size = coerce_to_expr(size) return Expr(f.array_resize(array.expr, size.expr, value.expr)) -def list_resize(array: Expr, size: Expr, value: Expr) -> Expr: +def list_resize(array: Expr, size: Expr | int, value: Expr) -> Expr: """Returns an array with the specified size filled. 
If ``size`` is greater than the ``array`` length, the additional entries will be @@ -3802,7 +3890,7 @@ def list_zip(*arrays: Expr) -> Expr: def string_to_array( - string: Expr, delimiter: Expr, null_string: Expr | None = None + string: Expr, delimiter: Expr | str, null_string: Expr | str | None = None ) -> Expr: """Splits a string based on a delimiter and returns an array of parts. @@ -3812,9 +3900,7 @@ def string_to_array( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello,world"]}) >>> result = df.select( - ... dfn.functions.string_to_array( - ... dfn.col("a"), dfn.lit(","), - ... ).alias("result")) + ... dfn.functions.string_to_array(dfn.col("a"), ",").alias("result")) >>> result.collect_column("result")[0].as_py() ['hello', 'world'] @@ -3822,17 +3908,24 @@ def string_to_array( >>> result = df.select( ... dfn.functions.string_to_array( - ... dfn.col("a"), dfn.lit(","), null_string=dfn.lit("world"), + ... dfn.col("a"), ",", null_string="world", ... ).alias("result")) >>> result.collect_column("result")[0].as_py() ['hello', None] """ - null_expr = null_string.expr if null_string is not None else None - return Expr(f.string_to_array(string.expr, delimiter.expr, null_expr)) + delimiter = coerce_to_expr(delimiter) + null_string = coerce_to_expr_or_none(null_string) + return Expr( + f.string_to_array( + string.expr, + delimiter.expr, + null_string.expr if null_string is not None else None, + ) + ) def string_to_list( - string: Expr, delimiter: Expr, null_string: Expr | None = None + string: Expr, delimiter: Expr | str, null_string: Expr | str | None = None ) -> Expr: """Splits a string based on a delimiter and returns an array of parts. 
diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 13c05a9e6..e0ebdbae5 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -964,12 +964,12 @@ def test_csv_read_options_builder_pattern(): options = ( CsvReadOptions() - .with_has_header(False) # noqa: FBT003 + .with_has_header(False) .with_delimiter("|") .with_quote("'") .with_schema_infer_max_records(2000) - .with_truncated_rows(True) # noqa: FBT003 - .with_newlines_in_values(True) # noqa: FBT003 + .with_truncated_rows(True) + .with_newlines_in_values(True) .with_file_extension(".tsv") ) assert options.has_header is False diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 091fa9b56..9e2f791ea 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -3426,10 +3426,18 @@ def test_fill_null_all_null_column(ctx): assert result.column(1).to_pylist() == ["filled", "filled", "filled"] +_slow_udf_started = threading.Event() + + @udf([pa.int64()], pa.int64(), "immutable") def slow_udf(x: pa.Array) -> pa.Array: - # This must be longer than the check interval in wait_for_future - time.sleep(2.0) + _slow_udf_started.set() + # Sleep in small increments so Python's eval loop checks for pending + # async exceptions (like KeyboardInterrupt via PyThreadState_SetAsyncExc) + # between iterations. A single long time.sleep() is a C call where async + # exceptions are not checked on all Python versions (notably 3.11). 
+ for _ in range(200): + time.sleep(0.01) return x @@ -3463,6 +3471,7 @@ def test_collect_or_stream_interrupted(slow_query, as_c_stream): # noqa: C901 P if as_c_stream: reader = pa.RecordBatchReader.from_stream(df) + _slow_udf_started.clear() read_started = threading.Event() read_exception = [] read_thread_id = None @@ -3474,6 +3483,14 @@ def trigger_interrupt(): msg = f"Read operation did not start within {max_wait_time} seconds" raise RuntimeError(msg) + # For slow_query tests, wait until the UDF is actually executing Python + # bytecode before sending the interrupt. PyThreadState_SetAsyncExc only + # delivers exceptions when the thread is in the Python eval loop, not + # while in native (Rust/C) code. + if slow_query and not _slow_udf_started.wait(timeout=max_wait_time): + msg = f"UDF did not start within {max_wait_time} seconds" + raise RuntimeError(msg) + if read_thread_id is None: msg = "Cannot get read thread ID" raise RuntimeError(msg) diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index d046eb48c..8aa791ae1 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -53,6 +53,8 @@ TransactionEnd, TransactionStart, Values, + coerce_to_expr, + coerce_to_expr_or_none, ensure_expr, ensure_expr_list, ) @@ -1030,12 +1032,55 @@ def test_ensure_expr_list_bytearray(): ensure_expr_list(bytearray(b"a")) +def test_coerce_to_expr_passes_expr_through(): + e = col("a") + result = coerce_to_expr(e) + assert isinstance(result, type(e)) + assert str(result) == str(e) + + +def test_coerce_to_expr_wraps_int(): + result = coerce_to_expr(42) + assert isinstance(result, type(lit(42))) + + +def test_coerce_to_expr_wraps_str(): + result = coerce_to_expr("hello") + assert isinstance(result, type(lit("hello"))) + + +def test_coerce_to_expr_wraps_float(): + result = coerce_to_expr(3.14) + assert isinstance(result, type(lit(3.14))) + + +def test_coerce_to_expr_wraps_bool(): + result = coerce_to_expr(True) + assert isinstance(result, type(lit(True))) + 
+ +def test_coerce_to_expr_or_none_returns_none(): + assert coerce_to_expr_or_none(None) is None + + +def test_coerce_to_expr_or_none_wraps_value(): + result = coerce_to_expr_or_none(42) + assert isinstance(result, type(lit(42))) + + +def test_coerce_to_expr_or_none_passes_expr_through(): + e = col("a") + result = coerce_to_expr_or_none(e) + assert isinstance(result, type(e)) + assert str(result) == str(e) + + @pytest.mark.parametrize( "value", [ # Boolean - pa.scalar(True, type=pa.bool_()), # noqa: FBT003 - pa.scalar(False, type=pa.bool_()), # noqa: FBT003 + pa.scalar(True, type=pa.bool_()), + pa.scalar(False, type=pa.bool_()), # Integers - signed pa.scalar(127, type=pa.int8()), pa.scalar(-128, type=pa.int8()), diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 11e94af1c..d9781b1fb 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -2099,3 +2099,96 @@ def test_gen_series_with_step(): f.gen_series(literal(1), literal(10), literal(3)).alias("v") ).collect() assert result[0].column(0)[0].as_py() == [1, 4, 7, 10] + + +class TestPythonicNativeTypes: + """Tests for accepting native Python types instead of requiring lit().""" + + def test_split_part_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["a,b,c"]}) + result = df.select(f.split_part(column("a"), ",", 2).alias("s")).collect() + assert result[0].column(0)[0].as_py() == "b" + + def test_encode_native_str(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["hello"]}) + result = df.select(f.encode(column("a"), "base64").alias("e")).collect() + assert result[0].column(0)[0].as_py() == "aGVsbG8" + + def test_date_part_native_str(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["2021-07-15T00:00:00"]}) + df = df.select(f.to_timestamp(column("a")).alias("a")) + result = df.select(f.date_part("year", column("a")).alias("y")).collect() + assert result[0].column(0)[0].as_py() == 2021 + + def 
test_date_trunc_native_str(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["2021-07-15T12:34:56"]}) + df = df.select(f.to_timestamp(column("a")).alias("a")) + result = df.select(f.date_trunc("month", column("a")).alias("t")).collect() + assert str(result[0].column(0)[0].as_py()) == "2021-07-01 00:00:00" + + def test_left_native_int(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["the cat"]}) + result = df.select(f.left(column("a"), 3).alias("l")).collect() + assert result[0].column(0)[0].as_py() == "the" + + def test_round_native_int(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1.567]}) + result = df.select(f.round(column("a"), 2).alias("r")).collect() + assert result[0].column(0)[0].as_py() == 1.57 + + def test_regexp_count_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["abcabc"]}) + result = df.select( + f.regexp_count(column("a"), "abc", start=4, flags="i").alias("c") + ).collect() + assert result[0].column(0)[0].as_py() == 1 + + def test_log_native_int(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": [100.0]}) + result = df.select(f.log(10, column("a")).alias("l")).collect() + assert result[0].column(0)[0].as_py() == 2.0 + + def test_power_native_int(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": [2.0]}) + result = df.select(f.power(column("a"), 3).alias("p")).collect() + assert result[0].column(0)[0].as_py() == 8.0 + + def test_array_slice_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": [[1, 2, 3, 4]]}) + result = df.select(f.array_slice(column("a"), 2, 3).alias("s")).collect() + assert result[0].column(0)[0].as_py() == [2, 3] + + def test_string_to_array_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["hello,NA,world"]}) + result = df.select( + f.string_to_array(column("a"), ",", null_string="NA").alias("v") + ).collect() + assert result[0].column(0)[0].as_py() == ["hello", None, "world"] + + def 
test_regexp_replace_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["a1 b2 c3"]}) + result = df.select( + f.regexp_replace(column("a"), r"\d+", "X", flags="g").alias("r") + ).collect() + assert result[0].column(0)[0].as_py() == "aX bX cX" + + def test_backward_compat_with_lit(self): + """Verify that existing code using lit() still works.""" + ctx = SessionContext() + df = ctx.from_pydict({"a": ["a,b,c"]}) + result = df.select( + f.split_part(column("a"), literal(","), literal(2)).alias("s") + ).collect() + assert result[0].column(0)[0].as_py() == "b"