From 40309978c920bd123a4c7b764a2ddfdb97758607 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 23 Apr 2026 18:28:55 -0400 Subject: [PATCH 1/6] Add SKILL.md and enrich package docstring (#1497) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add AGENTS.md and enrich __init__.py module docstring Add python/datafusion/AGENTS.md as a comprehensive DataFrame API guide for AI agents and users. It ships with pip automatically (Maturin includes everything under python-source = "python"). Covers core abstractions, import conventions, data loading, all DataFrame operations, expression building, a SQL-to-DataFrame reference table, common pitfalls, idiomatic patterns, and a categorized function index. Enrich the __init__.py module docstring from 2 lines to a full overview with core abstractions, a quick-start example, and a pointer to AGENTS.md. Closes #1394 (PR 1a) Co-Authored-By: Claude Opus 4.6 (1M context) * Clarify audience of root vs package AGENTS.md The root AGENTS.md (symlinked as CLAUDE.md) is for contributors working on the project. Add a pointer to python/datafusion/AGENTS.md which is the user-facing DataFrame API guide shipped with the package. Also add the Apache license header to the package AGENTS.md. Co-Authored-By: Claude Opus 4.6 (1M context) * Add PR template and pre-commit check guidance to AGENTS.md Document that all PRs must follow .github/pull_request_template.md and that pre-commit hooks must pass before committing. List all configured hooks (actionlint, ruff, ruff-format, cargo fmt, cargo clippy, codespell, uv-lock) and the command to run them manually. Co-Authored-By: Claude Opus 4.6 (1M context) * Remove duplicated hook list from AGENTS.md Let the hooks be discoverable from .pre-commit-config.yaml rather than maintaining a separate list that can drift. Co-Authored-By: Claude Opus 4.6 (1M context) * Fix AGENTS.md: Arrow C Data Interface, aggregate filter, fluent example - Clarify that DataFusion works with any Arrow C Data Interface implementation, not just PyArrow. - Show the filter keyword argument on aggregate functions (the idiomatic HAVING equivalent) instead of the post-aggregate .filter() pattern. - Update the SQL reference table to show FILTER (WHERE ...) syntax. - Remove the now-incorrect "Aggregate then filter for HAVING" pitfall. - Add .collect() to the fluent chaining example so the result is clearly materialized. Co-Authored-By: Claude Opus 4.6 (1M context) * Update agents file after working through the first tpc-h query using only the text description * Add feedback from working through each of the TPC-H queries * Address Copilot review feedback on AGENTS.md - Wrap CASE/WHEN method-chain examples in parentheses and assign to a variable so they are valid Python as shown (Copilot #1, #2). - Fix INTERSECT/EXCEPT mapping: the default distinct=False corresponds to INTERSECT ALL / EXCEPT ALL, not the distinct forms. Updated both the Set Operations section and the SQL reference table to show both the ALL and distinct variants (Copilot #4). - Change write_parquet / write_csv / write_json examples to file-style paths (output.parquet, etc.) to match the convention used in existing tests and examples. Note that a directory path is also valid for partitioned output (Copilot #5). 
Verified INTERSECT/EXCEPT semantics with a script: df1.intersect(df2) -> [1, 1, 2] (= INTERSECT ALL) df1.intersect(df2, distinct=True) -> [1, 2] (= INTERSECT) Co-Authored-By: Claude Opus 4.6 (1M context) * Use short-form comparisons in AGENTS.md examples Drop lit() on the RHS of comparison operators since Expr auto-wraps raw Python values, matching the style the guide recommends (Copilot #3, #6). Updates examples in the Aggregation, CASE/WHEN, SQL reference table, Common Pitfalls, Fluent Chaining, and Variables-as-CTEs sections, plus the __init__.py quick-start snippet. Prose explanations of the rule (which cite the long form as the thing to avoid) are left unchanged. Co-Authored-By: Claude Opus 4.6 (1M context) * Move user guide from python/datafusion/AGENTS.md to SKILL.md The in-wheel AGENTS.md was not a real distribution channel -- no shipping agent walks site-packages for AGENTS.md files. Moving to SKILL.md at the repo root, with YAML frontmatter, lets the skill ecosystems (npx skills, Claude Code plugin marketplaces, community aggregators) discover it. Update the pointers in the contributor AGENTS.md and the __init__.py module docstring accordingly. The docstring now references the GitHub URL since the file no longer ships with the wheel. Co-Authored-By: Claude Opus 4.7 (1M context) * Address review feedback: doctest, streaming, date/timestamp - Convert the __init__.py quick-start block to doctest format so it is picked up by `pytest --doctest-modules` (already the project default), preventing silent rot. - Extract streaming into its own SKILL.md subsection with guidance on when to prefer execute_stream() over collect(), sync and async iteration, and execute_stream_partitioned() for per-partition streams. - Generalize the date-arithmetic rule from Date32 to both Date32 and Date64 (both reject Duration at any precision, both accept month_day_nano_interval), and note that Timestamp columns differ and do accept Duration. - Document the PyArrow-inherited type mapping returned by to_pydict()/to_pylist(), including the nanosecond fallback to pandas.Timestamp / pandas.Timedelta and the to_pandas() footgun where date columns come back as an object dtype. Co-Authored-By: Claude Opus 4.7 (1M context) * Distinguish user guide from agent reference in module docstring The docstring pointed readers at SKILL.md as a "comprehensive guide," but SKILL.md is written in a dense, skill-oriented format for agents — humans are better served by the online user guide. Put the online docs first as the primary reference and label the SKILL.md link as the agent reference. Co-Authored-By: Claude Opus 4.7 (1M context) --------- Co-authored-by: Claude Opus 4.6 (1M context) --- AGENTS.md | 34 +- SKILL.md | 733 ++++++++++++++++++++++++++++++++++ python/datafusion/__init__.py | 42 +- 3 files changed, 804 insertions(+), 5 deletions(-) create mode 100644 SKILL.md diff --git a/AGENTS.md b/AGENTS.md index 86c2e9c3b..7d3262710 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -17,7 +17,14 @@ under the License. --> -# Agent Instructions +# Agent Instructions for Contributors + +This file is for agents working **on** the datafusion-python project (developing, +testing, reviewing). If you need to **use** the DataFusion DataFrame API (write +queries, build expressions, understand available functions), see the user-facing +skill at [`SKILL.md`](SKILL.md). + +## Skills This project uses AI agent skills stored in `.ai/skills/`. Each skill is a directory containing a `SKILL.md` file with instructions for performing a specific task. 
@@ -26,6 +33,31 @@ Skills follow the [Agent Skills](https://agentskills.io) open standard. Each ski - `SKILL.md` — The skill definition with YAML frontmatter (name, description, argument-hint) and detailed instructions. - Additional supporting files as needed. +## Pull Requests + +Every pull request must follow the template in +`.github/pull_request_template.md`. The description must include these sections: + +1. **Which issue does this PR close?** — Link the issue with `Closes #NNN`. +2. **Rationale for this change** — Why the change is needed (skip if the issue + already explains it clearly). +3. **What changes are included in this PR?** — Summarize the individual changes. +4. **Are there any user-facing changes?** — Note any changes visible to users + (new APIs, changed behavior, new files shipped in the package, etc.). If + there are breaking changes to public APIs, add the `api change` label. + +## Pre-commit Checks + +Always run pre-commit checks **before** committing. The hooks are defined in +`.pre-commit-config.yaml` and run automatically on `git commit` if pre-commit +is installed as a git hook. To run all hooks manually: + +```bash +pre-commit run --all-files +``` + +Fix any failures before committing. + ## Python Function Docstrings Every Python function must include a docstring with usage examples. diff --git a/SKILL.md b/SKILL.md new file mode 100644 index 000000000..9ba1c0cac --- /dev/null +++ b/SKILL.md @@ -0,0 +1,733 @@ + + +--- +name: datafusion-python +description: Use when the user is writing datafusion-python (Apache DataFusion Python bindings) DataFrame or SQL code. Covers imports, data loading, DataFrame operations, expression building, SQL-to-DataFrame mappings, idiomatic patterns, and common pitfalls. +--- + +# DataFusion Python DataFrame API Guide + +## What Is DataFusion? + +DataFusion is an **in-process query engine** built on Apache Arrow. It is not a +database -- there is no server, no connection string, and no external +dependencies. You create a `SessionContext`, point it at data (Parquet, CSV, +JSON, Arrow IPC, Pandas, Polars, or raw Python dicts/lists), and run queries +using either SQL or the DataFrame API described below. + +All data flows through **Apache Arrow**. The canonical Python implementation is +PyArrow (`pyarrow.RecordBatch` / `pyarrow.Table`), but any library that +conforms to the [Arrow C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html) +can interoperate with DataFusion. + +## Core Abstractions + +| Abstraction | Role | Key import | +|---|---|---| +| `SessionContext` | Entry point. Loads data, runs SQL, produces DataFrames. | `from datafusion import SessionContext` | +| `DataFrame` | Lazy query builder. Each method returns a new DataFrame. | Returned by context methods | +| `Expr` | Expression tree node (column ref, literal, function call, ...). | `from datafusion import col, lit` | +| `functions` | 290+ built-in scalar, aggregate, and window functions. 
| `from datafusion import functions as F` |
+
+## Import Conventions
+
+```python
+from datafusion import SessionContext, col, lit
+from datafusion import functions as F
+```
+
+## Data Loading
+
+```python
+ctx = SessionContext()
+
+# From files
+df = ctx.read_parquet("path/to/data.parquet")
+df = ctx.read_csv("path/to/data.csv")
+df = ctx.read_json("path/to/data.json")
+
+# From Python objects
+df = ctx.from_pydict({"a": [1, 2, 3], "b": ["x", "y", "z"]})
+df = ctx.from_pylist([{"a": 1, "b": "x"}, {"a": 2, "b": "y"}])
+df = ctx.from_pandas(pandas_df)
+df = ctx.from_polars(polars_df)
+df = ctx.from_arrow(arrow_table)
+
+# From SQL
+df = ctx.sql("SELECT a, b FROM my_table WHERE a > 1")
+```
+
+To make a file-backed data source queryable by table name in SQL, register it
+with the context first:
+
+```python
+ctx.register_parquet("my_table", "path/to/data.parquet")
+ctx.register_csv("my_table", "path/to/data.csv")
+```
+
+## DataFrame Operations Quick Reference
+
+Every method returns a **new** DataFrame (immutable/lazy). Chain them fluently.
+
+### Projection
+
+```python
+df.select("a", "b")  # preferred: plain names as strings
+df.select(col("a"), (col("b") + 1).alias("b_plus_1"))  # use col()/Expr only when you need an expression
+
+df.with_column("new_col", col("a") + lit(10))  # add one column
+df.with_columns(
+    col("a").alias("x"),
+    y=col("b") + lit(1),  # named keyword form
+)
+
+df.drop("unwanted_col")
+df.with_column_renamed("old_name", "new_name")
+```
+
+When a column is referenced by name alone, pass the name as a string rather
+than wrapping it in `col()`. Reach for `col()` only when the projection needs
+arithmetic, aliasing, casting, or another expression operation.
+
+**Case sensitivity**: both `select("Name")` and `col("Name")` lowercase the
+identifier. For a column whose real name has uppercase letters, embed double
+quotes inside the string: `select('"MyCol"')` or `col('"MyCol"')`. Without the
+inner quotes the lookup will fail with `No field named mycol`.
+
+### Filtering
+
+```python
+df.filter(col("a") > 10)
+df.filter(col("a") > 10, col("b") == "x")  # multiple = AND
+df.filter("a > 10")  # SQL expression string
+```
+
+Raw Python values on the right-hand side of a comparison are auto-wrapped
+into literals by the `Expr` operators, so prefer `col("a") > 10` over
+`col("a") > lit(10)`. See the Comparisons section and pitfall #2 for the
+full rule.
+
+### Aggregation
+
+```python
+# GROUP BY a, compute sum(b) and count(*)
+df.aggregate(["a"], [F.sum(col("b")), F.count(col("a"))])
+
+# HAVING equivalent: use the filter keyword on the aggregate function
+df.aggregate(
+    ["region"],
+    [F.sum(col("sales"), filter=col("sales") > 1000).alias("large_sales")],
+)
+```
+
+As with `select()`, group keys can be passed as plain name strings. Reach for
+`col(...)` only when the grouping expression needs arithmetic, aliasing,
+casting, or another expression operation.
+
+Most aggregate functions accept an optional `filter` keyword argument. When
+provided, only rows where the filter expression is true contribute to the
+aggregate.
+
+### Sorting
+
+```python
+df.sort(col("a"))  # ascending (default)
+df.sort(col("a").sort(ascending=False))  # descending
+df.sort(col("a").sort(nulls_first=False))  # override null placement
+```
+
+A plain expression passed to `sort()` is already treated as ascending. Only
+reach for `col(...).sort(...)` when you need to override a default (descending
+order or null placement). Writing `col("a").sort(ascending=True)` is redundant.
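+
+Multiple sort keys are applied left to right, and the direction can vary per
+key. A short sketch (hypothetical columns):
+
+```python
+df.sort(col("region"), col("total").sort(ascending=False))  # region asc, total desc
+```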
+ +### Joining + +```python +# Equi-join on shared column name +df1.join(df2, on="key") +df1.join(df2, on="key", how="left") + +# Different column names +df1.join(df2, left_on="id", right_on="fk_id", how="inner") + +# Expression-based join (supports inequality predicates) +df1.join_on(df2, col("a") == col("b"), how="inner") + +# Semi join: keep rows from left where a match exists in right (like EXISTS) +df1.join(df2, on="key", how="semi") + +# Anti join: keep rows from left where NO match exists in right (like NOT EXISTS) +df1.join(df2, on="key", how="anti") +``` + +Join types: `"inner"`, `"left"`, `"right"`, `"full"`, `"semi"`, `"anti"`. + +Inner is the default `how`. Prefer `df1.join(df2, on="key")` over +`df1.join(df2, on="key", how="inner")` — drop `how=` unless you need a +non-inner join type. + +When the two sides' join columns have different native names, use +`left_on=`/`right_on=` with the original names rather than aliasing one side +to match the other — see pitfall #7. + +### Window Functions + +```python +from datafusion import WindowFrame + +# Row number partitioned by group, ordered by value +df.window( + F.row_number( + partition_by=[col("group")], + order_by=[col("value")], + ).alias("rn") +) + +# Using a Window object for reuse +from datafusion.expr import Window + +win = Window( + partition_by=[col("group")], + order_by=[col("value").sort(ascending=True)], +) +df.select( + col("group"), + col("value"), + F.sum(col("value")).over(win).alias("running_total"), +) + +# With explicit frame bounds +win = Window( + partition_by=[col("group")], + order_by=[col("value").sort(ascending=True)], + window_frame=WindowFrame("rows", 0, None), # current row to unbounded following +) +``` + +### Set Operations + +```python +df1.union(df2) # UNION ALL (by position) +df1.union(df2, distinct=True) # UNION DISTINCT +df1.union_by_name(df2) # match columns by name, not position +df1.intersect(df2) # INTERSECT ALL +df1.intersect(df2, distinct=True) # INTERSECT (distinct) +df1.except_all(df2) # EXCEPT ALL +df1.except_all(df2, distinct=True) # EXCEPT (distinct) +``` + +### Limit and Offset + +```python +df.limit(10) # first 10 rows +df.limit(10, offset=20) # skip 20, then take 10 +``` + +### Deduplication + +```python +df.distinct() # remove duplicate rows +df.distinct_on( # keep first row per group (like DISTINCT ON in Postgres) + [col("a")], # uniqueness columns + [col("a"), col("b")], # output columns + [col("b").sort(ascending=True)], # which row to keep +) +``` + +## Executing and Collecting Results + +DataFrames are lazy until you collect. 
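+
+For example, nothing executes while the plan is being built; a minimal sketch
+of the lazy boundary, assuming a local Parquet file:
+
+```python
+plan = ctx.read_parquet("data.parquet").filter(col("a") > 10)  # builds a plan, no I/O yet
+batches = plan.collect()  # execution happens here
+```
+
+The terminal methods:
+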
+ +```python +df.show() # print formatted table to stdout +batches = df.collect() # list[pa.RecordBatch] +arr = df.collect_column("col_name") # pa.Array | pa.ChunkedArray (single column) +table = df.to_arrow_table() # pa.Table +pandas_df = df.to_pandas() # pd.DataFrame +polars_df = df.to_polars() # pl.DataFrame +py_dict = df.to_pydict() # dict[str, list] +py_list = df.to_pylist() # list[dict] +count = df.count() # int +``` + +### Date and Timestamp Type Conversion + +The Python type returned by `to_pydict()` / `to_pylist()` depends on the Arrow +column type, and the mapping is inherited from PyArrow: + +| Arrow type | Python type returned | +|---|---| +| `timestamp(s)` / `(ms)` / `(us)` | `datetime.datetime` | +| `timestamp(ns)` | `pandas.Timestamp` | +| `date32` / `date64` | `datetime.date` | +| `duration(s)` / `(ms)` / `(us)` | `datetime.timedelta` | +| `duration(ns)` | `pandas.Timedelta` | + +The nanosecond-precision fallback to pandas types is the main surprise: +pandas is not a hard dependency of `datafusion`, but PyArrow reaches for it +when `datetime.datetime` / `datetime.timedelta` would lose precision (stdlib +types only go to microseconds). If you need plain stdlib types, cast to a +coarser unit before collecting, e.g. +`df.select(col("ts").cast(pa.timestamp("us")))`. + +`df.to_pandas()` has its own footgun for dates: pandas has no pure-date dtype, +so a `date32`/`date64` column comes back as an `object` column of +`datetime.date` values rather than `datetime64[ns]`. If downstream code +expects a datetime column, cast on the DataFusion side first: +`col("ship_date").cast(pa.timestamp("ns"))`. + +### Streaming Results + +Prefer streaming over `collect()` when the result is too large to materialize +in memory, when you want to start processing before the query finishes, or +when you may break out of the loop early. `execute_stream()` pulls one +`RecordBatch` at a time from the execution plan rather than buffering the +whole result up front. + +```python +# Single-partition stream; batch is a datafusion.RecordBatch +stream = df.execute_stream() +for batch in stream: + process(batch.to_pyarrow()) # convert to pa.RecordBatch if needed + +# DataFrame is iterable directly (delegates to execute_stream) +for batch in df: + process(batch.to_pyarrow()) + +# One stream per partition, for parallel consumption +for stream in df.execute_stream_partitioned(): + for batch in stream: + process(batch.to_pyarrow()) +``` + +Async iteration is also supported via `async for batch in df: ...` (or +`df.execute_stream()`), which is useful when batches are interleaved with +other I/O. + +### Writing Results + +```python +df.write_parquet("output.parquet") +df.write_csv("output.csv") +df.write_json("output.json") +``` + +You can also pass a directory path (e.g., `"output/"`) to write a multi-file +partitioned output. + +## Expression Building + +### Column References and Literals + +```python +col("column_name") # reference a column +lit(42) # integer literal +lit("hello") # string literal +lit(3.14) # float literal +lit(pa.scalar(value)) # PyArrow scalar (preserves Arrow type) +``` + +`lit()` accepts PyArrow scalars directly -- prefer this over converting Arrow +data to Python and back when working with values extracted from query results. 
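+
+A sketch of that round trip (hypothetical `price` column):
+
+```python
+top = df.aggregate([], [F.max(col("price")).alias("m")]).collect_column("m")[0]
+df = df.filter(col("price") == lit(top))  # lit() keeps the Arrow type intact
+```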
+ +### Arithmetic + +```python +col("price") * col("quantity") # multiplication +col("a") + lit(1) # addition +col("a") - col("b") # subtraction +col("a") / lit(2) # division +col("a") % lit(3) # modulo +``` + +### Date Arithmetic + +`Date32` and `Date64` columns both require `Interval` types for arithmetic, +not `Duration`. Use PyArrow's `month_day_nano_interval` type, which takes a +`(months, days, nanos)` tuple: + +```python +import pyarrow as pa + +# Subtract 90 days from a date column +col("ship_date") - lit(pa.scalar((0, 90, 0), type=pa.month_day_nano_interval())) + +# Subtract 3 months +col("ship_date") - lit(pa.scalar((3, 0, 0), type=pa.month_day_nano_interval())) +``` + +**Important**: `lit(datetime.timedelta(days=90))` creates a `Duration(µs)` +literal, which is **not** compatible with `Date32`/`Date64` arithmetic +(`Duration(ms)` and `Duration(ns)` are rejected too). Always use +`pa.month_day_nano_interval()` for date operations. + +**Timestamps behave differently**: `Timestamp` columns *do* accept `Duration`, +so `col("ts") - lit(datetime.timedelta(days=1))` works. The interval-only +rule applies specifically to date columns. + +### Comparisons + +```python +col("a") > 10 +col("a") >= 10 +col("a") < 10 +col("a") <= 10 +col("a") == "x" +col("a") != "x" +col("a") == None # same as col("a").is_null() +col("a") != None # same as col("a").is_not_null() +``` + +Comparison operators auto-wrap the right-hand Python value into a literal, +so writing `col("a") > lit(10)` is redundant. Drop the `lit()` in +comparisons. Reach for `lit()` only when auto-wrapping does not apply — see +pitfall #2. + +### Boolean Logic + +**Important**: Python's `and`, `or`, `not` keywords do NOT work with Expr +objects. You must use the bitwise operators: + +```python +(col("a") > 1) & (col("b") < 10) # AND +(col("a") > 1) | (col("b") < 10) # OR +~(col("a") > 1) # NOT +``` + +Always wrap each comparison in parentheses when combining with `&`, `|`, `~` +because Python's operator precedence for bitwise operators is different from +logical operators. 
+ +### Null Handling + +```python +col("a").is_null() +col("a").is_not_null() +col("a").fill_null(lit(0)) # replace NULL with a value +F.coalesce(col("a"), col("b")) # first non-null value +F.nullif(col("a"), lit(0)) # return NULL if a == 0 +``` + +### CASE / WHEN + +```python +# Simple CASE (matching on a single expression) +status_label = ( + F.case(col("status")) + .when(lit("A"), lit("Active")) + .when(lit("I"), lit("Inactive")) + .otherwise(lit("Unknown")) +) + +# Searched CASE (each branch has its own predicate) +severity = ( + F.when(col("value") > 100, lit("high")) + .when(col("value") > 50, lit("medium")) + .otherwise(lit("low")) +) +``` + +### Casting + +```python +import pyarrow as pa + +col("a").cast(pa.float64()) +col("a").cast(pa.utf8()) +col("a").cast(pa.date32()) +``` + +### Aliasing + +```python +(col("a") + col("b")).alias("total") +``` + +### BETWEEN and IN + +```python +col("a").between(lit(1), lit(10)) # 1 <= a <= 10 +F.in_list(col("a"), [lit(1), lit(2), lit(3)]) # a IN (1, 2, 3) +F.in_list(col("a"), [lit(1), lit(2)], negated=True) # a NOT IN (1, 2) +``` + +### Struct and Array Access + +```python +col("struct_col")["field_name"] # access struct field +col("array_col")[0] # access array element (0-indexed) +col("array_col")[1:3] # array slice (0-indexed) +``` + +## SQL-to-DataFrame Reference + +| SQL | DataFrame API | +|---|---| +| `SELECT a, b` | `df.select("a", "b")` | +| `SELECT a, b + 1 AS c` | `df.select(col("a"), (col("b") + lit(1)).alias("c"))` | +| `SELECT *, a + 1 AS c` | `df.with_column("c", col("a") + lit(1))` | +| `WHERE a > 10` | `df.filter(col("a") > 10)` | +| `GROUP BY a` with `SUM(b)` | `df.aggregate(["a"], [F.sum(col("b"))])` | +| `SUM(b) FILTER (WHERE b > 100)` | `F.sum(col("b"), filter=col("b") > 100)` | +| `ORDER BY a DESC` | `df.sort(col("a").sort(ascending=False))` | +| `LIMIT 10 OFFSET 5` | `df.limit(10, offset=5)` | +| `DISTINCT` | `df.distinct()` | +| `a INNER JOIN b ON a.id = b.id` | `a.join(b, on="id")` | +| `a LEFT JOIN b ON a.id = b.fk` | `a.join(b, left_on="id", right_on="fk", how="left")` | +| `WHERE EXISTS (SELECT ...)` | `a.join(b, on="key", how="semi")` | +| `WHERE NOT EXISTS (SELECT ...)` | `a.join(b, on="key", how="anti")` | +| `UNION ALL` | `df1.union(df2)` | +| `UNION` (distinct) | `df1.union(df2, distinct=True)` | +| `INTERSECT ALL` | `df1.intersect(df2)` | +| `INTERSECT` (distinct) | `df1.intersect(df2, distinct=True)` | +| `EXCEPT ALL` | `df1.except_all(df2)` | +| `EXCEPT` (distinct) | `df1.except_all(df2, distinct=True)` | +| `CASE x WHEN 1 THEN 'a' END` | `F.case(col("x")).when(lit(1), lit("a")).end()` | +| `CASE WHEN x > 1 THEN 'a' END` | `F.when(col("x") > 1, lit("a")).end()` | +| `x IN (1, 2, 3)` | `F.in_list(col("x"), [lit(1), lit(2), lit(3)])` | +| `x BETWEEN 1 AND 10` | `col("x").between(lit(1), lit(10))` | +| `CAST(x AS DOUBLE)` | `col("x").cast(pa.float64())` | +| `ROW_NUMBER() OVER (...)` | `F.row_number(partition_by=[...], order_by=[...])` | +| `SUM(x) OVER (...)` | `F.sum(col("x")).over(window)` | +| `x IS NULL` | `col("x").is_null()` | +| `COALESCE(a, b)` | `F.coalesce(col("a"), col("b"))` | + +## Common Pitfalls + +1. **Boolean operators**: Use `&`, `|`, `~` -- not Python's `and`, `or`, `not`. + Always parenthesize: `(col("a") > 1) & (col("b") < 2)`. + +2. **Wrapping scalars with `lit()`**: Prefer raw Python values on the + right-hand side of comparisons — `col("a") > 10`, `col("name") == "Alice"` + — because the Expr comparison operators auto-wrap them. Writing + `col("a") > lit(10)` is redundant. 
Reserve `lit()` for places where + auto-wrapping does *not* apply: + - standalone scalars passed into function calls: + `F.coalesce(col("a"), lit(0))`, not `F.coalesce(col("a"), 0)` + - arithmetic between two literals with no column involved: + `lit(1) - col("discount")` is fine, but `lit(1) - lit(2)` needs both + - values that must carry a specific Arrow type, via `lit(pa.scalar(...))` + - `.when(...)`, `.otherwise(...)`, `F.nullif(...)`, `.between(...)`, + `F.in_list(...)` and similar method/function arguments + +3. **Column name quoting**: Column names are normalized to lowercase by default + in both `select("...")` and `col("...")`. To reference a column with + uppercase letters, use double quotes inside the string: + `select('"MyColumn"')` or `col('"MyColumn"')`. + +4. **DataFrames are immutable**: Every method returns a **new** DataFrame. You + must capture the return value: + ```python + df = df.filter(col("a") > 1) # correct + df.filter(col("a") > 1) # WRONG -- result is discarded + ``` + +5. **Window frame defaults**: When using `order_by` in a window, the default + frame is `RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`. For a full + partition frame, set `window_frame=WindowFrame("rows", None, None)`. + +6. **Arithmetic on aggregates belongs in a later `select`, not inside + `aggregate`**: Each item in the aggregate list must be a single aggregate + call (optionally aliased). Combining aggregates with arithmetic inside + `aggregate(...)` fails with `Internal error: Invalid aggregate expression`. + Alias the aggregates, then compute the combination downstream: + ```python + # WRONG -- arithmetic wraps two aggregates + df.aggregate([], [(lit(100) * F.sum(col("a")) / F.sum(col("b"))).alias("ratio")]) + + # CORRECT -- aggregate first, then combine + (df.aggregate([], [F.sum(col("a")).alias("num"), F.sum(col("b")).alias("den")]) + .select((lit(100) * col("num") / col("den")).alias("ratio"))) + ``` + +7. **Don't alias a join column to match the other side**: When equi-joining + with `on="key"`, renaming the join column on one side via `.alias("key")` + in a fresh projection creates a schema where one side's `key` is + qualified (`?table?.key`) and the other is unqualified. The join then + fails with `Schema contains qualified field name ... and unqualified + field name ... which would be ambiguous`. Use `left_on=`/`right_on=` with + the native names, or use `join_on(...)` with an explicit equality. + ```python + # WRONG -- alias on one side produces ambiguous schema after join + failed = orders.select(col("o_orderkey").alias("l_orderkey")) + li.join(failed, on="l_orderkey") # ambiguous l_orderkey error + + # CORRECT -- keep native names, use left_on/right_on + failed = orders.select("o_orderkey") + li.join(failed, left_on="l_orderkey", right_on="o_orderkey") + + # ALSO CORRECT -- explicit predicate via join_on + # (note: join_on keeps both key columns in the output, unlike on="key") + li.join_on(failed, col("l_orderkey") == col("o_orderkey")) + ``` + +## Idiomatic Patterns + +### Fluent Chaining + +```python +result = ( + ctx.read_parquet("data.parquet") + .filter(col("year") >= 2020) + .select(col("region"), col("sales")) + .aggregate(["region"], [F.sum(col("sales")).alias("total")]) + .sort(col("total").sort(ascending=False)) + .limit(10) +) +result.show() +``` + +### Using Variables as CTEs + +Instead of SQL CTEs (`WITH ... 
AS`), assign intermediate DataFrames to
+variables:
+
+```python
+base = ctx.read_parquet("orders.parquet").filter(col("status") == "shipped")
+by_region = base.aggregate(["region"], [F.sum(col("amount")).alias("total")])
+top_regions = by_region.filter(col("total") > 10000)
+```
+
+### Reusing Expressions as Variables
+
+Just like DataFrames, expressions (`Expr`) can be stored in variables and used
+anywhere an `Expr` is expected. This is useful for building up complex
+expressions or reusing a computed value across multiple operations:
+
+```python
+# Build an expression and reuse it
+disc_price = col("price") * (lit(1) - col("discount"))
+df = df.select(
+    col("id"),
+    disc_price.alias("disc_price"),
+    (disc_price * (lit(1) + col("tax"))).alias("total"),
+)
+
+# Use a collected scalar as an expression
+max_val = result_df.collect_column("max_price")[0]  # PyArrow scalar
+cutoff = lit(max_val) - lit(pa.scalar((0, 90, 0), type=pa.month_day_nano_interval()))
+df = df.filter(col("ship_date") <= cutoff)  # cutoff is already an Expr
+```
+
+**Important**: Do not wrap an `Expr` in `lit()`. `lit()` is for converting
+Python/PyArrow values into expressions. If a value is already an `Expr`, use it
+directly.
+
+### Window Functions for Scalar Subqueries
+
+Where SQL uses a correlated scalar subquery, the idiomatic DataFrame approach
+is a window function:
+
+```sql
+-- SQL scalar subquery
+SELECT *, (SELECT SUM(b) FROM t WHERE t.group = s.group) AS group_total FROM s
+```
+
+```python
+# DataFrame: window function
+win = Window(partition_by=[col("group")])
+df = df.with_column("group_total", F.sum(col("b")).over(win))
+```
+
+### Semi/Anti Joins for EXISTS / NOT EXISTS
+
+```sql
+-- WHERE EXISTS (SELECT 1 FROM other WHERE other.key = main.key)
+-- WHERE NOT EXISTS (SELECT 1 FROM other WHERE other.key = main.key)
+```
+
+```python
+result = main.join(other, on="key", how="semi")  # EXISTS
+result = main.join(other, on="key", how="anti")  # NOT EXISTS
+```
+
+### Computed Columns
+
+```python
+# Add computed columns while keeping all originals
+df = df.with_column("full_name", F.concat(col("first"), lit(" "), col("last")))
+df = df.with_column("discounted", col("price") * lit(0.9))
+```
+
+## Available Functions (Categorized)
+
+The `functions` module (imported as `F`) provides 290+ functions. Key categories:
+
+**Aggregate**: `sum`, `avg`, `min`, `max`, `count`, `count_star`, `median`,
+`stddev`, `stddev_pop`, `var_samp`, `var_pop`, `corr`, `covar`, `approx_distinct`,
+`approx_median`, `approx_percentile_cont`, `array_agg`, `string_agg`,
+`first_value`, `last_value`, `bit_and`, `bit_or`, `bit_xor`, `bool_and`,
+`bool_or`, `grouping`, `regr_*` (9 regression functions)
+
+**Window**: `row_number`, `rank`, `dense_rank`, `percent_rank`, `cume_dist`,
+`ntile`, `lag`, `lead`, `first_value`, `last_value`, `nth_value`
+
+**String**: `length`, `lower`, `upper`, `trim`, `ltrim`, `rtrim`, `lpad`,
+`rpad`, `starts_with`, `ends_with`, `contains`, `substr`, `substring`,
+`replace`, `reverse`, `repeat`, `split_part`, `concat`, `concat_ws`,
+`initcap`, `ascii`, `chr`, `left`, `right`, `strpos`, `translate`, `overlay`,
+`levenshtein`
+
+`F.substr(str, start)` takes **only two arguments** and returns the tail of
+the string from `start` onward — passing a third length argument raises
+`TypeError: substr() takes 2 positional arguments but 3 were given`. For the
+SQL-style 3-arg form (`SUBSTRING(str FROM start FOR length)`), use
+`F.substring(col("s"), lit(start), lit(length))`.
For a fixed-length prefix, +`F.left(col("s"), lit(n))` is cleanest. + +```python +# WRONG — substr does not accept a length argument +F.substr(col("c_phone"), lit(1), lit(2)) +# CORRECT +F.substring(col("c_phone"), lit(1), lit(2)) # explicit length +F.left(col("c_phone"), lit(2)) # prefix shortcut +``` + +**Math**: `abs`, `ceil`, `floor`, `round`, `trunc`, `sqrt`, `cbrt`, `exp`, +`ln`, `log`, `log2`, `log10`, `pow`, `signum`, `pi`, `random`, `factorial`, +`gcd`, `lcm`, `greatest`, `least`, sin/cos/tan and inverse/hyperbolic variants + +**Date/Time**: `now`, `today`, `current_date`, `current_time`, +`current_timestamp`, `date_part`, `date_trunc`, `date_bin`, `extract`, +`to_timestamp`, `to_timestamp_millis`, `to_timestamp_micros`, +`to_timestamp_nanos`, `to_timestamp_seconds`, `to_unixtime`, `from_unixtime`, +`make_date`, `make_time`, `to_date`, `to_time`, `to_local_time`, `date_format` + +**Conditional**: `case`, `when`, `coalesce`, `nullif`, `ifnull`, `nvl`, `nvl2` + +**Array/List**: `array`, `make_array`, `array_agg`, `array_length`, +`array_element`, `array_slice`, `array_append`, `array_prepend`, +`array_concat`, `array_has`, `array_has_all`, `array_has_any`, `array_position`, +`array_remove`, `array_distinct`, `array_sort`, `array_reverse`, `flatten`, +`array_to_string`, `array_intersect`, `array_union`, `array_except`, +`generate_series` +(Most `array_*` functions also have `list_*` aliases.) + +**Struct/Map**: `struct`, `named_struct`, `get_field`, `make_map`, `map_keys`, +`map_values`, `map_entries`, `map_extract` + +**Regex**: `regexp_like`, `regexp_match`, `regexp_replace`, `regexp_count`, +`regexp_instr` + +**Hash**: `md5`, `sha224`, `sha256`, `sha384`, `sha512`, `digest` + +**Type**: `arrow_typeof`, `arrow_cast`, `arrow_metadata` + +**Other**: `in_list`, `order_by`, `alias`, `col`, `encode`, `decode`, +`to_hex`, `to_char`, `uuid`, `version`, `bit_length`, `octet_length` diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 80dfa2fab..e4972411a 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -15,10 +15,44 @@ # specific language governing permissions and limitations # under the License. -"""DataFusion python package. - -This is a Python library that binds to Apache Arrow in-memory query engine DataFusion. -See https://datafusion.apache.org/python for more information. +"""DataFusion: an in-process query engine built on Apache Arrow. + +DataFusion is not a database -- it has no server and no external dependencies. +You create a :py:class:`SessionContext`, point it at data sources (Parquet, CSV, +JSON, Arrow IPC, Pandas, Polars, or raw Python dicts/lists), and run queries +using either SQL or the DataFrame API. + +Core abstractions +----------------- +- **SessionContext** -- entry point for loading data, running SQL, and creating + DataFrames. +- **DataFrame** -- lazy query builder. Every method returns a new DataFrame; + call :py:meth:`~datafusion.dataframe.DataFrame.collect` or a ``to_*`` + method to execute. +- **Expr** -- expression tree node for column references, literals, and function + calls. Build with :py:func:`col` and :py:func:`lit`. +- **functions** -- 290+ built-in scalar, aggregate, and window functions. + +Quick start +----------- + +>>> from datafusion import SessionContext, col +>>> from datafusion import functions as F +>>> ctx = SessionContext() +>>> df = ctx.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}) +>>> result = ( +... df.filter(col("a") > 1) +... .with_column("total", col("a") + col("b")) +... 
.aggregate([], [F.sum(col("total")).alias("grand_total")]) +... ) +>>> result.to_pydict() +{'grand_total': [16]} + +User guide and full documentation: https://datafusion.apache.org/python + +AI agent reference (SQL-to-DataFrame mappings, expression-building patterns, +common pitfalls), written in a dense, skill-oriented format: +https://github.com/apache/datafusion-python/blob/main/SKILL.md """ from __future__ import annotations From 8a5d783c7e418bfbbd95e48a2d9cacafea6162c7 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 23 Apr 2026 19:05:06 -0400 Subject: [PATCH 2/6] Skills require the header to be the first thing in the file which conflicts with the RAT check. Make an exception for this file. (#1501) --- SKILL.md | 19 ------------------- dev/release/rat_exclude_files.txt | 3 ++- 2 files changed, 2 insertions(+), 20 deletions(-) diff --git a/SKILL.md b/SKILL.md index 9ba1c0cac..14ea5c609 100644 --- a/SKILL.md +++ b/SKILL.md @@ -1,22 +1,3 @@ - - --- name: datafusion-python description: Use when the user is writing datafusion-python (Apache DataFusion Python bindings) DataFrame or SQL code. Covers imports, data loading, DataFrame operations, expression building, SQL-to-DataFrame mappings, idiomatic patterns, and common pitfalls. diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index b2db144e8..a7a497dab 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -48,4 +48,5 @@ benchmarks/tpch/create_tables.sql .cargo/config.toml **/.cargo/config.toml uv.lock -examples/tpch/answers_sf1/*.tbl \ No newline at end of file +examples/tpch/answers_sf1/*.tbl +SKILL.md \ No newline at end of file From 8741d30cd812e4668f3f9187b56f12ce2de0d6e7 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 23 Apr 2026 22:01:01 -0400 Subject: [PATCH 3/6] docs: enrich module docstrings and add doctest examples (#1498) * Enrich module docstrings and add doctest examples Expands the module docstrings for `functions.py`, `dataframe.py`, `expr.py`, and `context.py` so each module opens with a concept summary, cross-references to related APIs, and a small executable example. Adds doctest examples to the high-traffic `DataFrame` methods that previously lacked them: `select`, `aggregate`, `sort`, `limit`, `join`, and `union`. Optional parameters are demonstrated with keyword syntax, and examples reuse the same input data across variants so the effect of each option is easy to see. Co-Authored-By: Claude Opus 4.7 (1M context) * Use distinct group sums in aggregate docstring example Change the score data from [1, 2, 3] to [1, 2, 5] so the grouped result produces [3, 5] instead of [3, 3], removing ambiguity about which total belongs to which team. Co-Authored-By: Claude Opus 4.7 (1M context) * Align module-docstring examples with SKILL.md idioms Drop the redundant lit() in the dataframe.py module-docstring filter example and use a plain string group key in the aggregate() doctest, so both examples model the style SKILL.md recommends. Also document the sort("a") string form and sort_by() shortcut in SKILL.md's sorting section. 
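A before/after sketch of the filter example (the aggregate() doctest changes
in the same way):

    df.filter(col("a") > lit(1)).select("b")   # before
    df.filter(col("a") > 1).select("b")        # after
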
Co-Authored-By: Claude Opus 4.7 (1M context) --------- Co-authored-by: Claude Opus 4.7 (1M context) --- SKILL.md | 16 +++- python/datafusion/context.py | 27 ++++++- python/datafusion/dataframe.py | 135 ++++++++++++++++++++++++++++++--- python/datafusion/expr.py | 28 ++++++- python/datafusion/functions.py | 22 +++++- 5 files changed, 209 insertions(+), 19 deletions(-) diff --git a/SKILL.md b/SKILL.md index 14ea5c609..7b07b430f 100644 --- a/SKILL.md +++ b/SKILL.md @@ -128,14 +128,22 @@ aggregate. ### Sorting ```python -df.sort(col("a")) # ascending (default) +df.sort("a") # ascending (plain name, preferred) +df.sort(col("a")) # ascending via col() df.sort(col("a").sort(ascending=False)) # descending df.sort(col("a").sort(nulls_first=False)) # override null placement + +df.sort_by("a", "b") # ascending-only shortcut ``` -A plain expression passed to `sort()` is already treated as ascending. Only -reach for `col(...).sort(...)` when you need to override a default (descending -order or null placement). Writing `col("a").sort(ascending=True)` is redundant. +As with `select()` and `aggregate()`, bare column references can be passed as +plain name strings. A plain expression passed to `sort()` is already treated +as ascending, so reach for `col(...).sort(...)` only when you need to override +a default (descending order or null placement). Writing +`col("a").sort(ascending=True)` is redundant. + +For ascending-only sorts with no null-placement override, `df.sort_by(...)` is +a shorter alias for `df.sort(...)`. ### Joining diff --git a/python/datafusion/context.py b/python/datafusion/context.py index c3f94cc16..dd6790402 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -15,7 +15,32 @@ # specific language governing permissions and limitations # under the License. -"""Session Context and it's associated configuration.""" +""":py:class:`SessionContext` — entry point for running DataFusion queries. + +A :py:class:`SessionContext` holds registered tables, catalogs, and +configuration for the current session. It is the first object most programs +create: from it you register data, run SQL strings +(:py:meth:`SessionContext.sql`), read files +(:py:meth:`SessionContext.read_csv`, +:py:meth:`SessionContext.read_parquet`, ...), and construct +:py:class:`~datafusion.dataframe.DataFrame` objects in memory +(:py:meth:`SessionContext.from_pydict`, +:py:meth:`SessionContext.from_arrow`). + +Session behavior (memory limits, batch size, configured optimizer passes, +...) is controlled by :py:class:`SessionConfig` and +:py:class:`RuntimeEnvBuilder`; SQL dialect limits are controlled by +:py:class:`SQLOptions`. + +Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3]}) + >>> ctx.sql("SELECT 1 AS n").to_pydict() + {'n': [1]} + +See :ref:`user_guide_concepts` in the online documentation for the broader +execution model. +""" from __future__ import annotations diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index c00c85fdb..2b07861da 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -14,9 +14,32 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""":py:class:`DataFrame` is one of the core concepts in DataFusion. - -See :ref:`user_guide_concepts` in the online documentation for more information. +""":py:class:`DataFrame` — lazy, chainable query representation. 
+ +A :py:class:`DataFrame` is a logical plan over one or more data sources. +Methods that reshape the plan (:py:meth:`DataFrame.select`, +:py:meth:`DataFrame.filter`, :py:meth:`DataFrame.aggregate`, +:py:meth:`DataFrame.sort`, :py:meth:`DataFrame.join`, +:py:meth:`DataFrame.limit`, the set-operation methods, ...) return a new +:py:class:`DataFrame` and do no work until a terminal method such as +:py:meth:`DataFrame.collect`, :py:meth:`DataFrame.to_pydict`, +:py:meth:`DataFrame.show`, or one of the ``write_*`` methods is called. + +DataFrames are produced from a +:py:class:`~datafusion.context.SessionContext`, typically via +:py:meth:`~datafusion.context.SessionContext.sql`, +:py:meth:`~datafusion.context.SessionContext.read_csv`, +:py:meth:`~datafusion.context.SessionContext.read_parquet`, or +:py:meth:`~datafusion.context.SessionContext.from_pydict`. + +Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3], "b": [10, 20, 30]}) + >>> df.filter(col("a") > 1).select("b").to_pydict() + {'b': [20, 30]} + +See :ref:`user_guide_concepts` in the online documentation for a high-level +overview of the execution model. """ from __future__ import annotations @@ -503,21 +526,29 @@ def select_exprs(self, *args: str) -> DataFrame: def select(self, *exprs: Expr | str) -> DataFrame: """Project arbitrary expressions into a new :py:class:`DataFrame`. + String arguments are treated as column names; :py:class:`~datafusion.expr.Expr` + arguments can reshape, rename, or compute new columns. + Args: exprs: Either column names or :py:class:`~datafusion.expr.Expr` to select. Returns: DataFrame after projection. It has one column for each expression. - Example usage: + Examples: + Select columns by name: - The following example will return 3 columns from the original dataframe. - The first two columns will be the original column ``a`` and ``b`` since the - string "a" is assumed to refer to column selection. Also a duplicate of - column ``a`` will be returned with the column name ``alternate_a``:: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3], "b": [10, 20, 30]}) + >>> df.select("a").to_pydict() + {'a': [1, 2, 3]} - df = df.select("a", col("b"), col("a").alias("alternate_a")) + Mix column names, expressions, and aliases. The string ``"a"`` selects + column ``a`` directly; ``col("a").alias("alternate_a")`` returns a + duplicate under a new name: + >>> df.select("a", col("b"), col("a").alias("alternate_a")).to_pydict() + {'a': [1, 2, 3], 'b': [10, 20, 30], 'alternate_a': [1, 2, 3]} """ exprs_internal = expr_list_to_raw_expr_list(exprs) return DataFrame(self.df.select(*exprs_internal)) @@ -766,6 +797,24 @@ def aggregate( Returns: DataFrame after aggregation. + + Examples: + Aggregate without grouping — an empty ``group_by`` produces a + single row: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict( + ... {"team": ["x", "x", "y"], "score": [1, 2, 5]} + ... ) + >>> df.aggregate([], [F.sum(col("score")).alias("total")]).to_pydict() + {'total': [8]} + + Group by a column and produce one row per group: + + >>> df.aggregate( + ... ["team"], [F.sum(col("score")).alias("total")] + ... ).sort("team").to_pydict() + {'team': ['x', 'y'], 'total': [3, 5]} """ group_by_list = ( list(group_by) @@ -786,13 +835,27 @@ def sort(self, *exprs: SortKey) -> DataFrame: """Sort the DataFrame by the specified sorting expressions or column names. Note that any expression can be turned into a sort expression by - calling its ``sort`` method. + calling its ``sort`` method. 
For ascending-only sorts, the shorter + :py:meth:`sort_by` is usually more convenient. Args: exprs: Sort expressions or column names, applied in order. Returns: DataFrame after sorting. + + Examples: + Sort ascending by a column name: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [3, 1, 2], "b": [10, 20, 30]}) + >>> df.sort("a").to_pydict() + {'a': [1, 2, 3], 'b': [20, 30, 10]} + + Sort descending using :py:meth:`Expr.sort`: + + >>> df.sort(col("a").sort(ascending=False)).to_pydict() + {'a': [3, 2, 1], 'b': [10, 30, 20]} """ exprs_raw = sort_list_to_raw_sort_list(exprs) return DataFrame(self.df.sort(*exprs_raw)) @@ -812,12 +875,28 @@ def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame: def limit(self, count: int, offset: int = 0) -> DataFrame: """Return a new :py:class:`DataFrame` with a limited number of rows. + Results are returned in unspecified order unless the DataFrame is + explicitly sorted first via :py:meth:`sort` or :py:meth:`sort_by`. + Args: count: Number of rows to limit the DataFrame to. offset: Number of rows to skip. Returns: DataFrame after limiting. + + Examples: + Take the first two rows: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3, 4]}).sort("a") + >>> df.limit(2).to_pydict() + {'a': [1, 2]} + + Skip the first row then take two (paging): + + >>> df.limit(2, offset=1).to_pydict() + {'a': [2, 3]} """ return DataFrame(self.df.limit(count, offset)) @@ -972,6 +1051,28 @@ def join( Returns: DataFrame after join. + + Examples: + Inner-join two DataFrames on a shared column: + + >>> ctx = dfn.SessionContext() + >>> left = ctx.from_pydict({"id": [1, 2, 3], "val": [10, 20, 30]}) + >>> right = ctx.from_pydict({"id": [2, 3, 4], "label": ["b", "c", "d"]}) + >>> left.join(right, on="id").sort("id").to_pydict() + {'id': [2, 3], 'val': [20, 30], 'label': ['b', 'c']} + + Left join to keep all rows from the left side: + + >>> left.join(right, on="id", how="left").sort("id").to_pydict() + {'id': [1, 2, 3], 'val': [10, 20, 30], 'label': [None, 'b', 'c']} + + Use ``left_on`` / ``right_on`` when the key columns differ in name: + + >>> right2 = ctx.from_pydict({"rid": [2, 3], "label": ["b", "c"]}) + >>> left.join( + ... right2, left_on="id", right_on="rid" + ... ).sort("id").to_pydict() + {'id': [2, 3], 'val': [20, 30], 'rid': [2, 3], 'label': ['b', 'c']} """ if join_keys is not None: warnings.warn( @@ -1165,6 +1266,20 @@ def union(self, other: DataFrame, distinct: bool = False) -> DataFrame: Returns: DataFrame after union. + + Examples: + Stack rows from both DataFrames, preserving duplicates: + + >>> ctx = dfn.SessionContext() + >>> df1 = ctx.from_pydict({"a": [1, 2]}) + >>> df2 = ctx.from_pydict({"a": [2, 3]}) + >>> df1.union(df2).sort("a").to_pydict() + {'a': [1, 2, 2, 3]} + + Deduplicate the combined result with ``distinct=True``: + + >>> df1.union(df2, distinct=True).sort("a").to_pydict() + {'a': [1, 2, 3]} """ return DataFrame(self.df.union(other.df, distinct)) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 32004656f..1ff6976f7 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -15,9 +15,31 @@ # specific language governing permissions and limitations # under the License. -"""This module supports expressions, one of the core concepts in DataFusion. - -See :ref:`Expressions` in the online documentation for more details. +""":py:class:`Expr` — the logical expression type used to build DataFusion queries. 
+ +An :py:class:`Expr` represents a computation over columns or literals: a +column reference (``col("a")``), a literal (``lit(5)``), an operator +combination (``col("a") + lit(1)``), or the output of a function from +:py:mod:`datafusion.functions`. Expressions are passed to +:py:class:`~datafusion.dataframe.DataFrame` methods such as +:py:meth:`~datafusion.dataframe.DataFrame.select`, +:py:meth:`~datafusion.dataframe.DataFrame.filter`, +:py:meth:`~datafusion.dataframe.DataFrame.aggregate`, and +:py:meth:`~datafusion.dataframe.DataFrame.sort`. + +Convenience constructors are re-exported at the package level: +:py:func:`datafusion.col` / :py:func:`datafusion.column` for column references +and :py:func:`datafusion.lit` / :py:func:`datafusion.literal` for scalar +literals. + +Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3]}) + >>> df.select((col("a") * lit(10)).alias("ten_a")).to_pydict() + {'ten_a': [10, 20, 30]} + +See :ref:`expressions` in the online documentation for details on available +operators and helpers. """ # ruff: noqa: PLC0415 diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 841cd9c0b..280a6d3ac 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -14,7 +14,27 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""User functions for operating on :py:class:`~datafusion.expr.Expr`.""" +"""Scalar, aggregate, and window functions for :py:class:`~datafusion.expr.Expr`. + +Each function returns an :py:class:`~datafusion.expr.Expr` that can be combined +with other expressions and passed to +:py:class:`~datafusion.dataframe.DataFrame` methods such as +:py:meth:`~datafusion.dataframe.DataFrame.select`, +:py:meth:`~datafusion.dataframe.DataFrame.filter`, +:py:meth:`~datafusion.dataframe.DataFrame.aggregate`, and +:py:meth:`~datafusion.dataframe.DataFrame.window`. The module is conventionally +imported as ``F`` so calls read like ``F.sum(col("price"))``. + +Examples: + >>> from datafusion import functions as F + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 2, 3, 4]}) + >>> df.aggregate([], [F.sum(col("a")).alias("total")]).to_pydict() + {'total': [10]} + +See :ref:`aggregation` and :ref:`window_functions` in the online documentation +for categorized catalogs of aggregate and window functions. +""" from __future__ import annotations From c8bb9f7d3876de97141d204740a6b99d5facd10f Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 24 Apr 2026 07:57:11 -0400 Subject: [PATCH 4/6] docs: add README section for AI coding assistants (#1503) Points users to the repo-root SKILL.md via the npx skills registry or a manual AGENTS.md / CLAUDE.md pointer. Implements PR 1c of the plan in #1394. Co-authored-by: Claude Opus 4.7 (1M context) --- README.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/README.md b/README.md index 7849e7a02..4baed7d1d 100644 --- a/README.md +++ b/README.md @@ -215,6 +215,22 @@ You can verify the installation by running: '0.6.0' ``` +## Using DataFusion with AI coding assistants + +This project ships a [`SKILL.md`](SKILL.md) at the repo root that teaches AI +coding assistants how to write idiomatic DataFusion Python. It follows the +[Agent Skills](https://agentskills.io) open standard. 
+ +**Preferred:** `npx skills add apache/datafusion-python` — installs the skill in +Claude Code, Cursor, Windsurf, Cline, Codex, Copilot, Gemini CLI, and other +supported agents. + +**Manual:** paste this line into your project's `AGENTS.md` / `CLAUDE.md`: + +``` +For DataFusion Python code, see https://github.com/apache/datafusion-python/blob/main/SKILL.md +``` + ## How to develop This assumes that you have rust and cargo installed. We use the workflow recommended by [pyo3](https://github.com/PyO3/pyo3) and [maturin](https://github.com/PyO3/maturin). The Maturin tools used in this workflow can be installed either via `uv` or `pip`. Both approaches should offer the same experience. It is recommended to use `uv` since it has significant performance improvements From 03577163a057f791b19f30ce5130464a4a1c78a4 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 24 Apr 2026 11:47:06 -0400 Subject: [PATCH 5/6] tpch examples: rewrite queries idiomatically and embed reference SQL (#1504) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * tpch examples: add reference SQL to each query, fix Q20 - Append the canonical TPC-H reference SQL (from benchmarks/tpch/queries/) to each q01..q22 module docstring so readers can compare the DataFrame translation against the SQL at a glance. - Fix Q20: `df = df.filter(col("ps_availqty") > lit(0.5) * col("total_sold"))` was missing the assignment so the filter was dropped from the pipeline. Co-Authored-By: Claude Opus 4.7 (1M context) * tpch examples: rewrite non-idiomatic queries in idiomatic DataFrame form Rewrite the seven TPC-H example queries that did not demonstrate the idiomatic DataFrame pattern. The remaining queries (Q02/Q11/Q15/Q17/Q22, which use window functions in place of correlated subqueries) already are idiomatic and are left unchanged. - Q04: replace `.aggregate([col("l_orderkey")], [])` with `.select("l_orderkey").distinct()`, which is the natural way to express "reduce to one row per order" on a DataFrame. - Q07: remove the CASE-as-filter on `n_name` and use `F.in_list(col("n_name"), [nation_1, nation_2])` instead. Drops a comment block that admitted the filter form was simpler. - Q08: rewrite the switched CASE `F.case(...).when(lit(False), ...)` as a searched `F.when(col(...).is_not_null(), ...).otherwise(...)`. That mirrors the reference SQL's `case when ... then ... else 0 end` shape. - Q12: replace `array_position(make_array(...), col)` with `F.in_list(col("l_shipmode"), [...])`. Same semantics, without routing through array construction / array search. - Q19: remove the pyarrow UDF that re-implemented a disjunctive predicate in Python. Build the same predicate in DataFusion by OR-combining one `in_list` + range-filter expression per brand. Keeps the per-brand constants in the existing `items_of_interest` dict. - Q20: use `F.starts_with` instead of an explicit substring slice. Replace the inner-join + `select(...).distinct()` tail with a semi join against a precomputed set of excess-quantity suppliers so the supplier columns are preserved without deduplication after the fact. - Q21: replace the `array_agg` / `array_length` / `array_element` pipeline with two semi joins. One semi join keeps orders with more than one distinct supplier (stand-in for the reference SQL's `exists` subquery), the other keeps orders with exactly one late supplier (stand-in for the `not exists` subquery). 
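  Schematically, with illustrative intermediate names rather than the literal
  diff, both existence checks become semi joins against one-row-per-order
  summaries:

      # multi_supplier_orders / single_late_supplier_orders are illustrative names
      waiting = waiting.join(multi_supplier_orders, on="l_orderkey", how="semi")
      waiting = waiting.join(single_late_supplier_orders, on="l_orderkey", how="semi")
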
All 22 answer-file comparisons and 22 plan-comparison diagnostics still pass (`pytest examples/tpch/_tests.py`: 44 passed). Co-Authored-By: Claude Opus 4.7 (1M context) * tpch examples: align reference SQL constants with DataFrame queries The reference SQL embedded in each q01..q22 module docstring was carried over verbatim from ``benchmarks/tpch/queries/`` and uses a different set of TPC-H substitution parameters than the DataFrame examples (answer-file-validated at scale factor 1). Update each reference SQL to use the substitution parameters the DataFrame uses, so both expressions describe the same query and would produce the same results against the same data. Constants aligned: - Q01: ``90 days`` cutoff (DataFrame ``DAYS_BEFORE_FINAL = 90``). - Q02: ``p_size = 15``, ``p_type like '%BRASS'``, ``r_name = 'EUROPE'``. - Q04: base date ``1993-07-01`` (``3 month`` interval preserved per the "quarter of a year" wording). - Q05: ``r_name = 'ASIA'``. - Q06: ``l_discount between 0.06 - 0.01 and 0.06 + 0.01``. - Q07: nations ``'FRANCE'`` / ``'GERMANY'``. - Q08: ``r_name = 'AMERICA'``, ``p_type = 'ECONOMY ANODIZED STEEL'``, inner-case ``nation = 'BRAZIL'``. - Q09: ``p_name like '%green%'``. - Q10: base date ``1993-10-01`` (``3 month`` interval preserved). - Q11: ``n_name = 'GERMANY'``. - Q12: ship modes ``('MAIL', 'SHIP')``, base date ``1994-01-01``. - Q13: ``o_comment not like '%special%requests%'``. - Q14: base date ``1995-09-01``. - Q15: base date ``1996-01-01``. - Q16: ``p_brand <> 'Brand#45'``, ``p_type not like 'MEDIUM POLISHED%'``, sizes ``(49, 14, 23, 45, 19, 3, 36, 9)``. - Q17: ``p_brand = 'Brand#23'``, ``p_container = 'MED BOX'``. - Q18: ``sum(l_quantity) > 300``. - Q19: brands ``Brand#12`` / ``Brand#23`` / ``Brand#34`` with the matching minimum quantities (1, 10, 20). - Q20: ``p_name like 'forest%'``, base date ``1994-01-01``, ``n_name = 'CANADA'``. - Q21: ``n_name = 'SAUDI ARABIA'``. - Q22: country codes ``('13', '31', '23', '29', '30', '18', '17')``. Interval units (month / year) are preserved where the problem-statement text reads "given quarter", "given year", "given month". Q01 keeps the literal "days" unit because the TPC-H problem statement itself describes the cutoff in days. Co-Authored-By: Claude Opus 4.7 (1M context) * tpch examples: apply SKILL.md idioms across all 22 queries Sweep every q01..q22 example for idiomatic DataFrame style as described in the repo-root SKILL.md: - ``col("x") == "s"`` in place of ``col("x") == lit("s")`` on comparison right-hand sides (auto-wrap applies). - Plain-name strings in ``select``/``aggregate``/``sort`` group/sort key lists when the key is a bare column. - Drop redundant ``how="inner"`` and single-element ``left_on``/``right_on`` list wrapping on equi-joins. - Collapse chained ``.filter(a).filter(b)`` runs into ``.filter(a, b)`` and chained ``.with_column`` runs into ``.with_columns(a=..., b=...)``. - ``df.sort_by(...)`` or plain-name ``df.sort(...)`` when no null-placement override is needed. - ``F.count_star()`` in place of ``F.count(col("x"))`` whenever the SQL reads ``count(*)``. - ``F.starts_with(col, lit(prefix))`` and ``~F.starts_with(...)`` in place of substring-prefix equality/inequality tricks. - ``F.in_list(col, [lit(...)])`` in place of ``~F.array_position(...). is_null()`` and in place of disjunctions of equality comparisons. - Searched ``F.when(cond, x).otherwise(y)`` in place of switched ``F.case(bool_expr).when(lit(True/False), x).end()`` forms. 
- Semi-joins as the DataFrame form of ``EXISTS`` (Q04); anti-joins as ``NOT EXISTS`` (Q22 was already using this idiom). - Whole-frame window aggregates as the DataFrame stand-in for a SQL scalar subquery (Q11/Q15/Q17/Q22). Individual query fixes of note: - Q16 — add the secondary sort keys (``p_brand``, ``p_type``, ``p_size``) that the TPC-H spec requires but the original DataFrame omitted. - Q22 — drop a stray ``df.show()`` mid-pipeline; replace the 0-based substring slice with ``F.left(col("c_phone"), lit(2))``. - Q14 — rewrite the promo/non-promo factor split as a searched CASE inside ``F.sum(...)`` so the DataFrame expression matches the reference SQL shape exactly. All 22 answer-file comparisons still pass at scale factor 1. Co-Authored-By: Claude Opus 4.7 (1M context) * tpch examples: more idiomatic aggregate FILTER, string funcs, date handling Additional sweep of the TPC-H DataFrame examples informed by comparing against a fresh set of SKILL.md-only generations under ``examples/tpch/agentic_queries/``: - Q02: ``F.ends_with(col("p_type"), lit(TYPE_OF_INTEREST))`` in place of ``F.strpos(col, lit) > 0``. The reference SQL is ``p_type like '%BRASS'``, which is an ends_with check, not contains. ``F.strpos > 0`` returned the correct rows on TPC-H data by coincidence but is semantically wrong. - Q09: ``F.contains(col("p_name"), lit(part_color))`` in place of ``F.strpos(col, lit) > 0``. The SQL is ``p_name like '%green%'``. - Q08, Q12, Q14: use the ``filter`` keyword on ``F.sum`` / ``F.count`` — the DataFrame form of SQL ``sum(...) FILTER (WHERE ...)`` — instead of wrapping the aggregate input in ``F.when(cond, x).otherwise(0)``. Q08 also reorganises to inner-join the supplier's nation onto the regional sales, which removes the previous left-join + ``F.when(is_not_null, ...)`` dance. - Q15: compute the grand maximum revenue as a separate scalar aggregate and ``join_on(...)`` on equality, instead of the whole-frame window ``F.max`` + filter shape. Simpler plan, same result. - Q16: ``F.regexp_like(col, pattern)`` in place of ``F.regexp_match(col, pattern).is_not_null()``. - Q04, Q05, Q06, Q07, Q08, Q10, Q12, Q14, Q15, Q20: store both the start and the end of the date window as plain ``datetime.date`` objects and compare with ``lit(end_date)``, instead of carrying the start date + ``pa.month_day_nano_interval`` and adding them at query-build time. Drops unused ``pyarrow`` imports from the files that no longer need Arrow scalars. All 22 answer-file comparisons still pass at scale factor 1. 
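As a compact illustration of the aggregate FILTER idiom, the Q14 promo/total
split now reads roughly as follows (a sketch of the committed shape; ``df``
here stands for the month-filtered lineitem/part join built in
examples/tpch/q14_promotion_effect.py):

    from datafusion import col, lit
    from datafusion import functions as F

    revenue = col("l_extendedprice") * (lit(1.0) - col("l_discount"))
    is_promo = F.starts_with(col("p_type"), lit("PROMO"))

    df = df.aggregate(
        [],
        [
            # DataFrame form of SQL `sum(...) FILTER (WHERE ...)`.
            F.sum(revenue, filter=is_promo).alias("promo_revenue"),
            F.sum(revenue).alias("total_revenue"),
        ],
    )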
Co-Authored-By: Claude Opus 4.7 (1M context) --------- Co-authored-by: Claude Opus 4.7 (1M context) --- examples/tpch/q01_pricing_summary_report.py | 44 ++-- examples/tpch/q02_minimum_cost_supplier.py | 87 ++++++-- examples/tpch/q03_shipping_priority.py | 51 +++-- examples/tpch/q04_order_priority_checking.py | 67 +++--- examples/tpch/q05_local_supplier_volume.py | 66 +++--- .../tpch/q06_forecasting_revenue_change.py | 35 +-- examples/tpch/q07_volume_shipping.py | 103 +++++---- examples/tpch/q08_market_share.py | 205 +++++++++--------- .../tpch/q09_product_type_profit_measure.py | 77 +++++-- examples/tpch/q10_returned_item_reporting.py | 102 +++++---- .../q11_important_stock_identification.py | 83 ++++--- examples/tpch/q12_ship_mode_order_priority.py | 108 +++++---- examples/tpch/q13_customer_distribution.py | 47 ++-- examples/tpch/q14_promotion_effect.py | 81 +++---- examples/tpch/q15_top_supplier.py | 94 ++++---- .../tpch/q16_part_supplier_relationship.py | 81 ++++--- examples/tpch/q17_small_quantity_order.py | 58 +++-- examples/tpch/q18_large_volume_customer.py | 71 ++++-- examples/tpch/q19_discounted_revenue.py | 134 ++++++------ examples/tpch/q20_potential_part_promotion.py | 120 ++++++---- .../tpch/q21_suppliers_kept_orders_waiting.py | 134 +++++++----- examples/tpch/q22_global_sales_opportunity.py | 104 ++++++--- 22 files changed, 1196 insertions(+), 756 deletions(-) diff --git a/examples/tpch/q01_pricing_summary_report.py b/examples/tpch/q01_pricing_summary_report.py index 3f97f00dc..105f1632d 100644 --- a/examples/tpch/q01_pricing_summary_report.py +++ b/examples/tpch/q01_pricing_summary_report.py @@ -27,6 +27,30 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + l_returnflag, + l_linestatus, + sum(l_quantity) as sum_qty, + sum(l_extendedprice) as sum_base_price, + sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + avg(l_quantity) as avg_qty, + avg(l_extendedprice) as avg_price, + avg(l_discount) as avg_disc, + count(*) as count_order + from + lineitem + where + l_shipdate <= date '1998-12-01' - interval '90 days' + group by + l_returnflag, + l_linestatus + order by + l_returnflag, + l_linestatus; """ import pyarrow as pa @@ -58,31 +82,25 @@ # Aggregate the results +disc_price = col("l_extendedprice") * (lit(1) - col("l_discount")) + df = df.aggregate( - [col("l_returnflag"), col("l_linestatus")], + ["l_returnflag", "l_linestatus"], [ F.sum(col("l_quantity")).alias("sum_qty"), F.sum(col("l_extendedprice")).alias("sum_base_price"), - F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias( - "sum_disc_price" - ), - F.sum( - col("l_extendedprice") - * (lit(1) - col("l_discount")) - * (lit(1) + col("l_tax")) - ).alias("sum_charge"), + F.sum(disc_price).alias("sum_disc_price"), + F.sum(disc_price * (lit(1) + col("l_tax"))).alias("sum_charge"), F.avg(col("l_quantity")).alias("avg_qty"), F.avg(col("l_extendedprice")).alias("avg_price"), F.avg(col("l_discount")).alias("avg_disc"), - F.count(col("l_returnflag")).alias( - "count_order" - ), # Counting any column should return same result + F.count_star().alias("count_order"), ], ) # Sort per the expected result -df = df.sort(col("l_returnflag").sort(), col("l_linestatus").sort()) +df = df.sort_by("l_returnflag", "l_linestatus") # Note: There appears to be a discrepancy between what is returned here and what is in the generated # answers file for the case of return flag N and line status O, but I did not investigate further. diff --git a/examples/tpch/q02_minimum_cost_supplier.py b/examples/tpch/q02_minimum_cost_supplier.py index 47961d2ef..c5c6b9c0b 100644 --- a/examples/tpch/q02_minimum_cost_supplier.py +++ b/examples/tpch/q02_minimum_cost_supplier.py @@ -27,6 +27,52 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + s_acctbal, + s_name, + n_name, + p_partkey, + p_mfgr, + s_address, + s_phone, + s_comment + from + part, + supplier, + partsupp, + nation, + region + where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and p_size = 15 + and p_type like '%BRASS' + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + and ps_supplycost = ( + select + min(ps_supplycost) + from + partsupp, + supplier, + nation, + region + where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + ) + order by + s_acctbal desc, + n_name, + s_name, + p_partkey limit 100; """ import datafusion @@ -67,35 +113,30 @@ "r_regionkey", "r_name" ) -# Filter down parts. Part names contain the type of interest, so we can use strpos to find where -# in the p_type column the word is. `strpos` will return 0 if not found, otherwise the position -# in the string where it is located. +# Filter down parts. 
The reference SQL uses ``p_type like '%BRASS'`` which +# is an ``ends_with`` check; use the dedicated string function rather than +# a manual substring match. df_part = df_part.filter( - F.strpos(col("p_type"), lit(TYPE_OF_INTEREST)) > lit(0) -).filter(col("p_size") == lit(SIZE_OF_INTEREST)) + F.ends_with(col("p_type"), lit(TYPE_OF_INTEREST)), + col("p_size") == SIZE_OF_INTEREST, +) # Filter regions down to the one of interest -df_region = df_region.filter(col("r_name") == lit(REGION_OF_INTEREST)) +df_region = df_region.filter(col("r_name") == REGION_OF_INTEREST) # Now that we have the region, find suppliers in that region. Suppliers are tied to their nation # and nations are tied to the region. -df_nation = df_nation.join( - df_region, left_on=["n_regionkey"], right_on=["r_regionkey"], how="inner" -) -df_supplier = df_supplier.join( - df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner" -) +df_nation = df_nation.join(df_region, left_on="n_regionkey", right_on="r_regionkey") +df_supplier = df_supplier.join(df_nation, left_on="s_nationkey", right_on="n_nationkey") # Now that we know who the potential suppliers are for the part, we can limit out part # supplies table down. We can further join down to the specific parts we've identified # as matching the request -df = df_partsupp.join( - df_supplier, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="inner" -) +df = df_partsupp.join(df_supplier, left_on="ps_suppkey", right_on="s_suppkey") # Locate the minimum cost across all suppliers. There are multiple ways you could do this, # but one way is to create a window function across all suppliers, find the minimum, and @@ -112,9 +153,9 @@ ), ) -df = df.filter(col("min_cost") == col("ps_supplycost")) - -df = df.join(df_part, left_on=["ps_partkey"], right_on=["p_partkey"], how="inner") +df = df.filter(col("min_cost") == col("ps_supplycost")).join( + df_part, left_on="ps_partkey", right_on="p_partkey" +) # From the problem statement, these are the values we wish to output @@ -132,12 +173,10 @@ # Sort and display 100 entries df = df.sort( col("s_acctbal").sort(ascending=False), - col("n_name").sort(), - col("s_name").sort(), - col("p_partkey").sort(), -) - -df = df.limit(100) + "n_name", + "s_name", + "p_partkey", +).limit(100) # Show results diff --git a/examples/tpch/q03_shipping_priority.py b/examples/tpch/q03_shipping_priority.py index fc1231e0a..880c7435f 100644 --- a/examples/tpch/q03_shipping_priority.py +++ b/examples/tpch/q03_shipping_priority.py @@ -25,6 +25,31 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + l_orderkey, + sum(l_extendedprice * (1 - l_discount)) as revenue, + o_orderdate, + o_shippriority + from + customer, + orders, + lineitem + where + c_mktsegment = 'BUILDING' + and c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate < date '1995-03-15' + and l_shipdate > date '1995-03-15' + group by + l_orderkey, + o_orderdate, + o_shippriority + order by + revenue desc, + o_orderdate limit 10; """ from datafusion import SessionContext, col, lit @@ -50,20 +75,20 @@ # Limit dataframes to the rows of interest -df_customer = df_customer.filter(col("c_mktsegment") == lit(SEGMENT_OF_INTEREST)) +df_customer = df_customer.filter(col("c_mktsegment") == SEGMENT_OF_INTEREST) df_orders = df_orders.filter(col("o_orderdate") < lit(DATE_OF_INTEREST)) df_lineitem = df_lineitem.filter(col("l_shipdate") > lit(DATE_OF_INTEREST)) # Join all 3 dataframes -df = df_customer.join( - df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner" -).join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner") +df = df_customer.join(df_orders, left_on="c_custkey", right_on="o_custkey").join( + df_lineitem, left_on="o_orderkey", right_on="l_orderkey" +) # Compute the revenue df = df.aggregate( - [col("l_orderkey")], + ["l_orderkey"], [ F.first_value(col("o_orderdate")).alias("o_orderdate"), F.first_value(col("o_shippriority")).alias("o_shippriority"), @@ -71,17 +96,13 @@ ], ) -# Sort by priority - -df = df.sort(col("revenue").sort(ascending=False), col("o_orderdate").sort()) - -# Only return 10 results +# Sort by priority, take 10, and project in the order expected by the spec. -df = df.limit(10) - -# Change the order that the columns are reported in just to match the spec - -df = df.select("l_orderkey", "revenue", "o_orderdate", "o_shippriority") +df = ( + df.sort(col("revenue").sort(ascending=False), "o_orderdate") + .limit(10) + .select("l_orderkey", "revenue", "o_orderdate", "o_shippriority") +) # Show result diff --git a/examples/tpch/q04_order_priority_checking.py b/examples/tpch/q04_order_priority_checking.py index 426338aea..6f11c1383 100644 --- a/examples/tpch/q04_order_priority_checking.py +++ b/examples/tpch/q04_order_priority_checking.py @@ -24,18 +24,40 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + o_orderpriority, + count(*) as order_count + from + orders + where + o_orderdate >= date '1993-07-01' + and o_orderdate < date '1993-07-01' + interval '3' month + and exists ( + select + * + from + lineitem + where + l_orderkey = o_orderkey + and l_commitdate < l_receiptdate + ) + group by + o_orderpriority + order by + o_orderpriority; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -# Ideally we could put 3 months into the interval. See note below. 
-INTERVAL_DAYS = 92 -DATE_OF_INTEREST = "1993-07-01" +QUARTER_START = date(1993, 7, 1) +QUARTER_END = date(1993, 10, 1) # Load the dataframes we need @@ -48,36 +70,23 @@ "l_orderkey", "l_commitdate", "l_receiptdate" ) -# Create a date object from the string -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, INTERVAL_DAYS, 0), type=pa.month_day_nano_interval()) - -# Limit results to cases where commitment date before receipt date -# Aggregate the results so we only get one row to join with the order table. -# Alternately, and likely more idiomatic is instead of `.aggregate` you could -# do `.select("l_orderkey").distinct()`. The goal here is to show -# multiple examples of how to use Data Fusion. -df_lineitem = df_lineitem.filter(col("l_commitdate") < col("l_receiptdate")).aggregate( - [col("l_orderkey")], [] +# Keep only orders in the quarter of interest, then restrict to those that +# have at least one late lineitem via a semi join (the DataFrame form of +# ``EXISTS`` from the reference SQL). +df_orders = df_orders.filter( + col("o_orderdate") >= lit(QUARTER_START), + col("o_orderdate") < lit(QUARTER_END), ) -# Limit orders to date range of interest -df_orders = df_orders.filter(col("o_orderdate") >= lit(date)).filter( - col("o_orderdate") < lit(date) + lit(interval) -) +late_lineitems = df_lineitem.filter(col("l_commitdate") < col("l_receiptdate")) -# Perform the join to find only orders for which there are lineitems outside of expected range df = df_orders.join( - df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner" + late_lineitems, left_on="o_orderkey", right_on="l_orderkey", how="semi" ) -# Based on priority, find the number of entries -df = df.aggregate( - [col("o_orderpriority")], [F.count(col("o_orderpriority")).alias("order_count")] +# Count the number of orders in each priority group and sort. +df = df.aggregate(["o_orderpriority"], [F.count_star().alias("order_count")]).sort_by( + "o_orderpriority" ) -# Sort the results -df = df.sort(col("o_orderpriority").sort()) - df.show() diff --git a/examples/tpch/q05_local_supplier_volume.py b/examples/tpch/q05_local_supplier_volume.py index fa2b01dea..bfdba5d4c 100644 --- a/examples/tpch/q05_local_supplier_volume.py +++ b/examples/tpch/q05_local_supplier_volume.py @@ -27,23 +27,45 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + n_name, + sum(l_extendedprice * (1 - l_discount)) as revenue + from + customer, + orders, + lineitem, + supplier, + nation, + region + where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and l_suppkey = s_suppkey + and c_nationkey = s_nationkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'ASIA' + and o_orderdate >= date '1994-01-01' + and o_orderdate < date '1994-01-01' + interval '1' year + group by + n_name + order by + revenue desc; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -DATE_OF_INTEREST = "1994-01-01" -INTERVAL_DAYS = 365 +YEAR_START = date(1994, 1, 1) +YEAR_END = date(1995, 1, 1) REGION_OF_INTEREST = "ASIA" -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, INTERVAL_DAYS, 0), type=pa.month_day_nano_interval()) - # Load the dataframes we need ctx = SessionContext() @@ -68,38 +90,32 @@ ) # Restrict dataframes to cases of interest -df_orders = df_orders.filter(col("o_orderdate") >= lit(date)).filter( - col("o_orderdate") < lit(date) + lit(interval) +df_orders = df_orders.filter( + col("o_orderdate") >= lit(YEAR_START), + col("o_orderdate") < lit(YEAR_END), ) -df_region = df_region.filter(col("r_name") == lit(REGION_OF_INTEREST)) +df_region = df_region.filter(col("r_name") == REGION_OF_INTEREST) # Join all the dataframes df = ( - df_customer.join( - df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner" - ) - .join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner") + df_customer.join(df_orders, left_on="c_custkey", right_on="o_custkey") + .join(df_lineitem, left_on="o_orderkey", right_on="l_orderkey") .join( df_supplier, left_on=["l_suppkey", "c_nationkey"], right_on=["s_suppkey", "s_nationkey"], - how="inner", ) - .join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner") - .join(df_region, left_on=["n_regionkey"], right_on=["r_regionkey"], how="inner") + .join(df_nation, left_on="s_nationkey", right_on="n_nationkey") + .join(df_region, left_on="n_regionkey", right_on="r_regionkey") ) -# Compute the final result +# Compute the final result, then sort in descending order. df = df.aggregate( - [col("n_name")], + ["n_name"], [F.sum(col("l_extendedprice") * (lit(1.0) - col("l_discount"))).alias("revenue")], -) - -# Sort in descending order - -df = df.sort(col("revenue").sort(ascending=False)) +).sort(col("revenue").sort(ascending=False)) df.show() diff --git a/examples/tpch/q06_forecasting_revenue_change.py b/examples/tpch/q06_forecasting_revenue_change.py index 1de5848b1..ed54d22a4 100644 --- a/examples/tpch/q06_forecasting_revenue_change.py +++ b/examples/tpch/q06_forecasting_revenue_change.py @@ -27,28 +27,34 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + sum(l_extendedprice * l_discount) as revenue + from + lineitem + where + l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + and l_discount between 0.06 - 0.01 and 0.06 + 0.01 + and l_quantity < 24; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path # Variables from the example query -DATE_OF_INTEREST = "1994-01-01" +YEAR_START = date(1994, 1, 1) +YEAR_END = date(1995, 1, 1) DISCOUT = 0.06 DELTA = 0.01 QUANTITY = 24 -INTERVAL_DAYS = 365 - -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, INTERVAL_DAYS, 0), type=pa.month_day_nano_interval()) - # Load the dataframes we need ctx = SessionContext() @@ -59,12 +65,11 @@ # Filter down to lineitems of interest -df = ( - df_lineitem.filter(col("l_shipdate") >= lit(date)) - .filter(col("l_shipdate") < lit(date) + lit(interval)) - .filter(col("l_discount") >= lit(DISCOUT) - lit(DELTA)) - .filter(col("l_discount") <= lit(DISCOUT) + lit(DELTA)) - .filter(col("l_quantity") < lit(QUANTITY)) +df = df_lineitem.filter( + col("l_shipdate") >= lit(YEAR_START), + col("l_shipdate") < lit(YEAR_END), + col("l_discount").between(lit(DISCOUT - DELTA), lit(DISCOUT + DELTA)), + col("l_quantity") < QUANTITY, ) # Add up all the "lost" revenue diff --git a/examples/tpch/q07_volume_shipping.py b/examples/tpch/q07_volume_shipping.py index ff2f891f1..df1c2ae0d 100644 --- a/examples/tpch/q07_volume_shipping.py +++ b/examples/tpch/q07_volume_shipping.py @@ -26,9 +26,51 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + supp_nation, + cust_nation, + l_year, + sum(volume) as revenue + from + ( + select + n1.n_name as supp_nation, + n2.n_name as cust_nation, + extract(year from l_shipdate) as l_year, + l_extendedprice * (1 - l_discount) as volume + from + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2 + where + s_suppkey = l_suppkey + and o_orderkey = l_orderkey + and c_custkey = o_custkey + and s_nationkey = n1.n_nationkey + and c_nationkey = n2.n_nationkey + and ( + (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') + or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') + ) + and l_shipdate between date '1995-01-01' and date '1996-12-31' + ) as shipping + group by + supp_nation, + cust_nation, + l_year + order by + supp_nation, + cust_nation, + l_year; """ -from datetime import datetime +from datetime import date import pyarrow as pa from datafusion import SessionContext, col, lit @@ -40,11 +82,8 @@ nation_1 = lit("FRANCE") nation_2 = lit("GERMANY") -START_DATE = "1995-01-01" -END_DATE = "1996-12-31" - -start_date = lit(datetime.strptime(START_DATE, "%Y-%m-%d").date()) -end_date = lit(datetime.strptime(END_DATE, "%Y-%m-%d").date()) +START_DATE = date(1995, 1, 1) +END_DATE = date(1996, 12, 31) # Load the dataframes we need @@ -69,60 +108,44 @@ # Filter to time of interest -df_lineitem = df_lineitem.filter(col("l_shipdate") >= start_date).filter( - col("l_shipdate") <= end_date +df_lineitem = df_lineitem.filter( + col("l_shipdate") >= lit(START_DATE), col("l_shipdate") <= lit(END_DATE) ) -# A simpler way to do the following operation is to use a filter, but we also want to demonstrate -# how to use case statements. Here we are assigning `n_name` to be itself when it is either of -# the two nations of interest. Since there is no `otherwise()` statement, any values that do -# not match these will result in a null value and then get filtered out. -# -# To do the same using a simple filter would be: -# df_nation = df_nation.filter((F.col("n_name") == nation_1) | (F.col("n_name") == nation_2)) # noqa: ERA001 -df_nation = df_nation.with_column( - "n_name", - F.case(col("n_name")) - .when(nation_1, col("n_name")) - .when(nation_2, col("n_name")) - .end(), -).filter(~col("n_name").is_null()) +# Limit the nation table to the two nations of interest. +df_nation = df_nation.filter(F.in_list(col("n_name"), [nation_1, nation_2])) # Limit suppliers to either nation df_supplier = df_supplier.join( - df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner" -).select(col("s_suppkey"), col("n_name").alias("supp_nation")) + df_nation, left_on="s_nationkey", right_on="n_nationkey" +).select("s_suppkey", col("n_name").alias("supp_nation")) # Limit customers to either nation df_customer = df_customer.join( - df_nation, left_on=["c_nationkey"], right_on=["n_nationkey"], how="inner" -).select(col("c_custkey"), col("n_name").alias("cust_nation")) + df_nation, left_on="c_nationkey", right_on="n_nationkey" +).select("c_custkey", col("n_name").alias("cust_nation")) # Join up all the data frames from line items, and make sure the supplier and customer are in # different nations. 
df = ( - df_lineitem.join( - df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner" - ) - .join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner") - .join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner") + df_lineitem.join(df_orders, left_on="l_orderkey", right_on="o_orderkey") + .join(df_customer, left_on="o_custkey", right_on="c_custkey") + .join(df_supplier, left_on="l_suppkey", right_on="s_suppkey") .filter(col("cust_nation") != col("supp_nation")) ) # Extract out two values for every line item -df = df.with_column( - "l_year", F.datepart(lit("year"), col("l_shipdate")).cast(pa.int32()) -).with_column("volume", col("l_extendedprice") * (lit(1.0) - col("l_discount"))) +df = df.with_columns( + l_year=F.datepart(lit("year"), col("l_shipdate")).cast(pa.int32()), + volume=col("l_extendedprice") * (lit(1.0) - col("l_discount")), +) -# Aggregate the results +# Aggregate and sort per the spec. df = df.aggregate( - [col("supp_nation"), col("cust_nation"), col("l_year")], + ["supp_nation", "cust_nation", "l_year"], [F.sum(col("volume")).alias("revenue")], -) - -# Sort based on problem statement requirements -df = df.sort(col("supp_nation").sort(), col("cust_nation").sort(), col("l_year").sort()) +).sort_by("supp_nation", "cust_nation", "l_year") df.show() diff --git a/examples/tpch/q08_market_share.py b/examples/tpch/q08_market_share.py index 4bf50efba..dd7bacedb 100644 --- a/examples/tpch/q08_market_share.py +++ b/examples/tpch/q08_market_share.py @@ -25,24 +25,61 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + o_year, + sum(case + when nation = 'BRAZIL' then volume + else 0 + end) / sum(volume) as mkt_share + from + ( + select + extract(year from o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) as volume, + n2.n_name as nation + from + part, + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2, + region + where + p_partkey = l_partkey + and s_suppkey = l_suppkey + and l_orderkey = o_orderkey + and o_custkey = c_custkey + and c_nationkey = n1.n_nationkey + and n1.n_regionkey = r_regionkey + and r_name = 'AMERICA' + and s_nationkey = n2.n_nationkey + and o_orderdate between date '1995-01-01' and date '1996-12-31' + and p_type = 'ECONOMY ANODIZED STEEL' + ) as all_nations + group by + o_year + order by + o_year; """ -from datetime import datetime +from datetime import date import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -supplier_nation = lit("BRAZIL") -customer_region = lit("AMERICA") -part_of_interest = lit("ECONOMY ANODIZED STEEL") - -START_DATE = "1995-01-01" -END_DATE = "1996-12-31" +supplier_nation = "BRAZIL" +customer_region = "AMERICA" +part_of_interest = "ECONOMY ANODIZED STEEL" -start_date = lit(datetime.strptime(START_DATE, "%Y-%m-%d").date()) -end_date = lit(datetime.strptime(END_DATE, "%Y-%m-%d").date()) +START_DATE = date(1995, 1, 1) +END_DATE = date(1996, 12, 31) # Load the dataframes we need @@ -74,105 +111,57 @@ # Limit orders to those in the specified range -df_orders = df_orders.filter(col("o_orderdate") >= start_date).filter( - col("o_orderdate") <= end_date -) - -# Part 1: Find customers in the region - -# We want customers in region specified by region_of_interest. 
This will be used to compute -# the total sales of the part of interest. We want to know of those sales what fraction -# was supplied by the nation of interest. There is no guarantee that the nation of -# interest is within the region of interest. - -# First we find all the sales that make up the basis. - -df_regional_customers = df_region.filter(col("r_name") == customer_region) - -# After this join we have all of the possible sales nations -df_regional_customers = df_regional_customers.join( - df_nation, left_on=["r_regionkey"], right_on=["n_regionkey"], how="inner" -) - -# Now find the possible customers -df_regional_customers = df_regional_customers.join( - df_customer, left_on=["n_nationkey"], right_on=["c_nationkey"], how="inner" -) - -# Next find orders for these customers -df_regional_customers = df_regional_customers.join( - df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner" -) - -# Find all line items from these orders -df_regional_customers = df_regional_customers.join( - df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner" -) - -# Limit to the part of interest -df_regional_customers = df_regional_customers.join( - df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="inner" -) - -# Compute the volume for each line item -df_regional_customers = df_regional_customers.with_column( - "volume", col("l_extendedprice") * (lit(1.0) - col("l_discount")) -) - -# Part 2: Find suppliers from the nation - -# Now that we have all of the sales of that part in the specified region, we need -# to determine which of those came from suppliers in the nation we are interested in. - -df_national_suppliers = df_nation.filter(col("n_name") == supplier_nation) - -# Determine the suppliers by the limited nation key we have in our single row df above -df_national_suppliers = df_national_suppliers.join( - df_supplier, left_on=["n_nationkey"], right_on=["s_nationkey"], how="inner" -) - -# When we join to the customer dataframe, we don't want to confuse other columns, so only -# select the supplier key that we need -df_national_suppliers = df_national_suppliers.select("s_suppkey") - - -# Part 3: Combine suppliers and customers and compute the market share - -# Now we can do a left outer join on the suppkey. Those line items from other suppliers -# will get a null value. We can check for the existence of this null to compute a volume -# column only from suppliers in the nation we are evaluating. - -df = df_regional_customers.join( - df_national_suppliers, left_on=["l_suppkey"], right_on=["s_suppkey"], how="left" -) - -# Use a case statement to compute the volume sold by suppliers in the nation of interest -df = df.with_column( - "national_volume", - F.case(col("s_suppkey").is_null()) - .when(lit(value=False), col("volume")) - .otherwise(lit(0.0)), -) - -df = df.with_column( - "o_year", F.datepart(lit("year"), col("o_orderdate")).cast(pa.int32()) -) - - -# Lastly, sum up the results - -df = df.aggregate( - [col("o_year")], - [ - F.sum(col("volume")).alias("volume"), - F.sum(col("national_volume")).alias("national_volume"), - ], +df_orders = df_orders.filter( + col("o_orderdate") >= lit(START_DATE), col("o_orderdate") <= lit(END_DATE) +) + +# Pair each supplier with its nation name so every regional-customer row +# below carries the supplier's nation and can be filtered inside the +# aggregate with ``F.sum(..., filter=...)``. 
+ +df_supplier_with_nation = df_supplier.join( + df_nation, left_on="s_nationkey", right_on="n_nationkey" +).select("s_suppkey", col("n_name").alias("supp_nation")) + +# Build every (part, lineitem, order, customer) row for customers in the +# target region ordering the target part. Each row carries the supplier's +# nation so we can aggregate on it below. + +df = ( + df_region.filter(col("r_name") == customer_region) + .join(df_nation, left_on="r_regionkey", right_on="n_regionkey") + .join(df_customer, left_on="n_nationkey", right_on="c_nationkey") + .join(df_orders, left_on="c_custkey", right_on="o_custkey") + .join(df_lineitem, left_on="o_orderkey", right_on="l_orderkey") + .join(df_part, left_on="l_partkey", right_on="p_partkey") + .join(df_supplier_with_nation, left_on="l_suppkey", right_on="s_suppkey") + .with_columns( + volume=col("l_extendedprice") * (lit(1.0) - col("l_discount")), + o_year=F.datepart(lit("year"), col("o_orderdate")).cast(pa.int32()), + ) +) + +# Aggregate the total and national volumes per year via the ``filter`` +# kwarg on ``F.sum`` (DataFrame form of SQL ``sum(... ) FILTER (WHERE ...)``). +# ``coalesce`` handles the case where no sale came from the target nation +# for a given year. +df = ( + df.aggregate( + ["o_year"], + [ + F.sum(col("volume"), filter=col("supp_nation") == supplier_nation).alias( + "national_volume" + ), + F.sum(col("volume")).alias("total_volume"), + ], + ) + .select( + "o_year", + (F.coalesce(col("national_volume"), lit(0.0)) / col("total_volume")).alias( + "mkt_share" + ), + ) + .sort_by("o_year") ) -df = df.select( - col("o_year"), (F.col("national_volume") / F.col("volume")).alias("mkt_share") -) - -df = df.sort(col("o_year").sort()) - df.show() diff --git a/examples/tpch/q09_product_type_profit_measure.py b/examples/tpch/q09_product_type_profit_measure.py index e2abbd095..ec68a2ab7 100644 --- a/examples/tpch/q09_product_type_profit_measure.py +++ b/examples/tpch/q09_product_type_profit_measure.py @@ -27,6 +27,41 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + nation, + o_year, + sum(amount) as sum_profit + from + ( + select + n_name as nation, + extract(year from o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount + from + part, + supplier, + lineitem, + partsupp, + orders, + nation + where + s_suppkey = l_suppkey + and ps_suppkey = l_suppkey + and ps_partkey = l_partkey + and p_partkey = l_partkey + and o_orderkey = l_orderkey + and s_nationkey = n_nationkey + and p_name like '%green%' + ) as profit + group by + nation, + o_year + order by + nation, + o_year desc; """ import pyarrow as pa @@ -34,7 +69,7 @@ from datafusion import functions as F from util import get_data_path -part_color = lit("green") +part_color = "green" # Load the dataframes we need @@ -62,37 +97,35 @@ "n_nationkey", "n_name", "n_regionkey" ) -# Limit possible parts to the color specified -df = df_part.filter(F.strpos(col("p_name"), part_color) > lit(0)) - -# We have a series of joins that get us to limit down to the line items we need -df = df.join(df_lineitem, left_on=["p_partkey"], right_on=["l_partkey"], how="inner") -df = df.join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner") -df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner") -df = df.join( - df_partsupp, - left_on=["l_suppkey", "l_partkey"], - right_on=["ps_suppkey", "ps_partkey"], - how="inner", +# Limit possible parts to the color specified, then walk the joins down to the +# line-item rows we need and attach the supplier's nation. ``F.contains`` +# maps directly to the reference SQL's ``p_name like '%green%'``. +df = ( + df_part.filter(F.contains(col("p_name"), lit(part_color))) + .join(df_lineitem, left_on="p_partkey", right_on="l_partkey") + .join(df_supplier, left_on="l_suppkey", right_on="s_suppkey") + .join(df_orders, left_on="l_orderkey", right_on="o_orderkey") + .join( + df_partsupp, + left_on=["l_suppkey", "l_partkey"], + right_on=["ps_suppkey", "ps_partkey"], + ) + .join(df_nation, left_on="s_nationkey", right_on="n_nationkey") ) -df = df.join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner") # Compute the intermediate values and limit down to the expressions we need df = df.select( col("n_name").alias("nation"), F.datepart(lit("year"), col("o_orderdate")).cast(pa.int32()).alias("o_year"), ( - (col("l_extendedprice") * (lit(1) - col("l_discount"))) - - (col("ps_supplycost") * col("l_quantity")) + col("l_extendedprice") * (lit(1) - col("l_discount")) + - col("ps_supplycost") * col("l_quantity") ).alias("amount"), ) -# Sum up the values by nation and year -df = df.aggregate( - [col("nation"), col("o_year")], [F.sum(col("amount")).alias("profit")] +# Sum up the values by nation and year, then sort per the spec. +df = df.aggregate(["nation", "o_year"], [F.sum(col("amount")).alias("profit")]).sort( + "nation", col("o_year").sort(ascending=False) ) -# Sort according to the problem specification -df = df.sort(col("nation").sort(), col("o_year").sort(ascending=False)) - df.show() diff --git a/examples/tpch/q10_returned_item_reporting.py b/examples/tpch/q10_returned_item_reporting.py index ed822e264..e6532517e 100644 --- a/examples/tpch/q10_returned_item_reporting.py +++ b/examples/tpch/q10_returned_item_reporting.py @@ -27,20 +27,50 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + c_custkey, + c_name, + sum(l_extendedprice * (1 - l_discount)) as revenue, + c_acctbal, + n_name, + c_address, + c_phone, + c_comment + from + customer, + orders, + lineitem, + nation + where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate >= date '1993-10-01' + and o_orderdate < date '1993-10-01' + interval '3' month + and l_returnflag = 'R' + and c_nationkey = n_nationkey + group by + c_custkey, + c_name, + c_acctbal, + c_phone, + n_name, + c_address, + c_comment + order by + revenue desc limit 20; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -DATE_START_OF_QUARTER = "1993-10-01" - -date_start_of_quarter = lit(datetime.strptime(DATE_START_OF_QUARTER, "%Y-%m-%d").date()) - -interval_one_quarter = lit(pa.scalar((0, 92, 0), type=pa.month_day_nano_interval())) +QUARTER_START = date(1993, 10, 1) +QUARTER_END = date(1994, 1, 1) # Load the dataframes we need @@ -66,44 +96,40 @@ ) # limit to returns -df_lineitem = df_lineitem.filter(col("l_returnflag") == lit("R")) +df_lineitem = df_lineitem.filter(col("l_returnflag") == "R") # Rather than aggregate by all of the customer fields as you might do looking at the specification, # we can aggregate by o_custkey and then join in the customer data at the end. -df = df_orders.filter(col("o_orderdate") >= date_start_of_quarter).filter( - col("o_orderdate") < date_start_of_quarter + interval_one_quarter +df = ( + df_orders.filter( + col("o_orderdate") >= lit(QUARTER_START), + col("o_orderdate") < lit(QUARTER_END), + ) + .join(df_lineitem, left_on="o_orderkey", right_on="l_orderkey") + .aggregate( + ["o_custkey"], + [F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias("revenue")], + ) ) -df = df.join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner") - -# Compute the revenue -df = df.aggregate( - [col("o_custkey")], - [F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias("revenue")], +# Now join in the customer data, project the spec's output columns, and take the top 20. 
+df = ( + df.join(df_customer, left_on="o_custkey", right_on="c_custkey") + .join(df_nation, left_on="c_nationkey", right_on="n_nationkey") + .select( + "c_custkey", + "c_name", + "revenue", + "c_acctbal", + "n_name", + "c_address", + "c_phone", + "c_comment", + ) + .sort(col("revenue").sort(ascending=False)) + .limit(20) ) -# Now join in the customer data -df = df.join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner") -df = df.join(df_nation, left_on=["c_nationkey"], right_on=["n_nationkey"], how="inner") - -# These are the columns the problem statement requires -df = df.select( - "c_custkey", - "c_name", - "revenue", - "c_acctbal", - "n_name", - "c_address", - "c_phone", - "c_comment", -) - -# Sort the results in descending order -df = df.sort(col("revenue").sort(ascending=False)) - -# Only return the top 20 results -df = df.limit(20) - df.show() diff --git a/examples/tpch/q11_important_stock_identification.py b/examples/tpch/q11_important_stock_identification.py index de309fa64..1f40bbdad 100644 --- a/examples/tpch/q11_important_stock_identification.py +++ b/examples/tpch/q11_important_stock_identification.py @@ -25,6 +25,36 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + ps_partkey, + sum(ps_supplycost * ps_availqty) as value + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' + group by + ps_partkey having + sum(ps_supplycost * ps_availqty) > ( + select + sum(ps_supplycost * ps_availqty) * 0.0001000000 + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' + ) + order by + value desc; """ from datafusion import SessionContext, WindowFrame, col, lit @@ -49,39 +79,30 @@ "n_nationkey", "n_name" ) -# limit to returns -df_nation = df_nation.filter(col("n_name") == lit(NATION)) - -# Find part supplies of within this target nation - -df = df_nation.join( - df_supplier, left_on=["n_nationkey"], right_on=["s_nationkey"], how="inner" +# Restrict to the target nation, then walk to partsupp rows via the supplier +# join. Aggregate the per-part inventory value. +df = ( + df_nation.filter(col("n_name") == NATION) + .join(df_supplier, left_on="n_nationkey", right_on="s_nationkey") + .join(df_partsupp, left_on="s_suppkey", right_on="ps_suppkey") + .with_column("value", col("ps_supplycost") * col("ps_availqty")) + .aggregate(["ps_partkey"], [F.sum(col("value")).alias("value")]) ) -df = df.join(df_partsupp, left_on=["s_suppkey"], right_on=["ps_suppkey"], how="inner") - - -# Compute the value of individual parts -df = df.with_column("value", col("ps_supplycost") * col("ps_availqty")) - -# Compute total value of specific parts -df = df.aggregate([col("ps_partkey")], [F.sum(col("value")).alias("value")]) - -# By default window functions go from unbounded preceding to current row, but we want -# to compute this sum across all rows -window_frame = WindowFrame("rows", None, None) - -df = df.with_column( - "total_value", F.sum(col("value")).over(Window(window_frame=window_frame)) +# A window function evaluated over the entire output produces a scalar grand +# total that can be referenced row-by-row in the filter — a DataFrame-native +# stand-in for the SQL HAVING ... > (SELECT SUM(...) * FRACTION ...) pattern. 
+# The default frame is "UNBOUNDED PRECEDING to CURRENT ROW"; override to the +# full partition for the grand total. +whole_frame = WindowFrame("rows", None, None) + +df = ( + df.with_column( + "total_value", F.sum(col("value")).over(Window(window_frame=whole_frame)) + ) + .filter(col("value") / col("total_value") >= lit(FRACTION)) + .select("ps_partkey", "value") + .sort(col("value").sort(ascending=False)) ) -# Limit to the parts for which there is a significant value based on the fraction of the total -df = df.filter(col("value") / col("total_value") >= lit(FRACTION)) - -# We only need to report on these two columns -df = df.select("ps_partkey", "value") - -# Sort in descending order of value -df = df.sort(col("value").sort(ascending=False)) - df.show() diff --git a/examples/tpch/q12_ship_mode_order_priority.py b/examples/tpch/q12_ship_mode_order_priority.py index 9071597f0..fb78fe3c2 100644 --- a/examples/tpch/q12_ship_mode_order_priority.py +++ b/examples/tpch/q12_ship_mode_order_priority.py @@ -27,18 +27,49 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + l_shipmode, + sum(case + when o_orderpriority = '1-URGENT' + or o_orderpriority = '2-HIGH' + then 1 + else 0 + end) as high_line_count, + sum(case + when o_orderpriority <> '1-URGENT' + and o_orderpriority <> '2-HIGH' + then 1 + else 0 + end) as low_line_count + from + orders, + lineitem + where + o_orderkey = l_orderkey + and l_shipmode in ('MAIL', 'SHIP') + and l_commitdate < l_receiptdate + and l_shipdate < l_commitdate + and l_receiptdate >= date '1994-01-01' + and l_receiptdate < date '1994-01-01' + interval '1' year + group by + l_shipmode + order by + l_shipmode; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path SHIP_MODE_1 = "MAIL" SHIP_MODE_2 = "SHIP" -DATE_OF_INTEREST = "1994-01-01" +YEAR_START = date(1994, 1, 1) +YEAR_END = date(1995, 1, 1) # Load the dataframes we need @@ -51,63 +82,30 @@ "l_orderkey", "l_shipmode", "l_commitdate", "l_shipdate", "l_receiptdate" ) -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, 365, 0), type=pa.month_day_nano_interval()) - - -df = df_lineitem.filter(col("l_receiptdate") >= lit(date)).filter( - col("l_receiptdate") < lit(date) + lit(interval) -) - -# Note: It is not recommended to use array_has because it treats the second argument as an argument -# so if you pass it col("l_shipmode") it will pass the entire array to process which is very slow. -# Instead check the position of the entry is not null. -df = df.filter( - ~F.array_position( - F.make_array(lit(SHIP_MODE_1), lit(SHIP_MODE_2)), col("l_shipmode") - ).is_null() -) - -# Since we have only two values, it's much easier to do this as a filter where the l_shipmode -# matches either of the two values, but we want to show doing some array operations in this -# example. If you want to see this done with filters, comment out the above line and uncomment -# this one. 
-# df = df.filter((col("l_shipmode") == lit(SHIP_MODE_1)) | (col("l_shipmode") == lit(SHIP_MODE_2))) # noqa: ERA001 +df = df_lineitem.filter( + col("l_receiptdate") >= lit(YEAR_START), + col("l_receiptdate") < lit(YEAR_END), + # ``in_list`` maps directly to ``l_shipmode in (...)`` from the SQL. + F.in_list(col("l_shipmode"), [lit(SHIP_MODE_1), lit(SHIP_MODE_2)]), + col("l_shipdate") < col("l_commitdate"), + col("l_commitdate") < col("l_receiptdate"), +).join(df_orders, left_on="l_orderkey", right_on="o_orderkey") -# We need order priority, so join order df to line item -df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner") +# Flag each line item as belonging to a high-priority order or not. +high_priorities = [lit("1-URGENT"), lit("2-HIGH")] +is_high = F.in_list(col("o_orderpriority"), high_priorities) +is_low = F.in_list(col("o_orderpriority"), high_priorities, negated=True) -# Restrict to line items we care about based on the problem statement. -df = df.filter(col("l_commitdate") < col("l_receiptdate")) - -df = df.filter(col("l_shipdate") < col("l_commitdate")) - -df = df.with_column( - "high_line_value", - F.case(col("o_orderpriority")) - .when(lit("1-URGENT"), lit(1)) - .when(lit("2-HIGH"), lit(1)) - .otherwise(lit(0)), -) - -# Aggregate the results +# Count the high-priority and low-priority lineitems per ship mode via the +# ``filter`` kwarg on ``F.count`` (DataFrame form of SQL's ``count(*) +# FILTER (WHERE ...)``). df = df.aggregate( - [col("l_shipmode")], + ["l_shipmode"], [ - F.sum(col("high_line_value")).alias("high_line_count"), - F.count(col("high_line_value")).alias("all_lines_count"), + F.count(col("o_orderkey"), filter=is_high).alias("high_line_count"), + F.count(col("o_orderkey"), filter=is_low).alias("low_line_count"), ], -) - -# Compute the final output -df = df.select( - col("l_shipmode"), - col("high_line_count"), - (col("all_lines_count") - col("high_line_count")).alias("low_line_count"), -) - -df = df.sort(col("l_shipmode").sort()) +).sort_by("l_shipmode") df.show() diff --git a/examples/tpch/q13_customer_distribution.py b/examples/tpch/q13_customer_distribution.py index 93f082ea3..37c0b93f6 100644 --- a/examples/tpch/q13_customer_distribution.py +++ b/examples/tpch/q13_customer_distribution.py @@ -26,6 +26,29 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + c_count, + count(*) as custdist + from + ( + select + c_custkey, + count(o_orderkey) + from + customer left outer join orders on + c_custkey = o_custkey + and o_comment not like '%special%requests%' + group by + c_custkey + ) as c_orders (c_custkey, c_count) + group by + c_count + order by + custdist desc, + c_count desc; """ from datafusion import SessionContext, col, lit @@ -49,20 +72,16 @@ F.regexp_match(col("o_comment"), lit(f"{WORD_1}.?*{WORD_2}")).is_null() ) -# Since we may have customers with no orders we must do a left join -df = df_customer.join( - df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="left" -) - -# Find the number of orders for each customer -df = df.aggregate([col("c_custkey")], [F.count(col("o_custkey")).alias("c_count")]) - -# Ultimately we want to know the number of customers that have that customer count -df = df.aggregate([col("c_count")], [F.count(col("c_count")).alias("custdist")]) - -# We want to order the results by the highest number of customers per count -df = df.sort( - col("custdist").sort(ascending=False), col("c_count").sort(ascending=False) +# Customers with no orders still participate, so this is a left join. Count the +# orders per customer, then count customers per order-count value. +df = ( + df_customer.join(df_orders, left_on="c_custkey", right_on="o_custkey", how="left") + .aggregate(["c_custkey"], [F.count(col("o_custkey")).alias("c_count")]) + .aggregate(["c_count"], [F.count_star().alias("custdist")]) + .sort( + col("custdist").sort(ascending=False), + col("c_count").sort(ascending=False), + ) ) df.show() diff --git a/examples/tpch/q14_promotion_effect.py b/examples/tpch/q14_promotion_effect.py index d62f76e3c..08f4f054d 100644 --- a/examples/tpch/q14_promotion_effect.py +++ b/examples/tpch/q14_promotion_effect.py @@ -24,20 +24,32 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + 100.00 * sum(case + when p_type like 'PROMO%' + then l_extendedprice * (1 - l_discount) + else 0 + end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue + from + lineitem, + part + where + l_partkey = p_partkey + and l_shipdate >= date '1995-09-01' + and l_shipdate < date '1995-09-01' + interval '1' month; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path -DATE = "1995-09-01" - -date_of_interest = lit(datetime.strptime(DATE, "%Y-%m-%d").date()) - -interval_one_month = lit(pa.scalar((0, 30, 0), type=pa.month_day_nano_interval())) +MONTH_START = date(1995, 9, 1) +MONTH_END = date(1995, 10, 1) # Load the dataframes we need @@ -49,37 +61,30 @@ df_part = ctx.read_parquet(get_data_path("part.parquet")).select("p_partkey", "p_type") -# Check part type begins with PROMO -df_part = df_part.filter( - F.substring(col("p_type"), lit(0), lit(6)) == lit("PROMO") -).with_column("promo_factor", lit(1.0)) - -df_lineitem = df_lineitem.filter(col("l_shipdate") >= date_of_interest).filter( - col("l_shipdate") < date_of_interest + interval_one_month -) - -# Left join so we can sum up the promo parts different from other parts -df = df_lineitem.join( - df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="left" -) - -# Make a factor of 1.0 if it is a promotion, 0.0 otherwise -df = df.with_column("promo_factor", F.coalesce(col("promo_factor"), lit(0.0))) -df = df.with_column("revenue", col("l_extendedprice") * (lit(1.0) - col("l_discount"))) - - -# Sum up the promo and total revenue -df = df.aggregate( - [], - [ - F.sum(col("promo_factor") * col("revenue")).alias("promo_revenue"), - F.sum(col("revenue")).alias("total_revenue"), - ], -) - -# Return the percentage of revenue from promotions -df = df.select( - (lit(100.0) * col("promo_revenue") / col("total_revenue")).alias("promo_revenue") +# Restrict the line items to the month of interest, join the matching part +# rows, and aggregate revenue totals with a ``filter`` clause on the promo +# sum — the DataFrame form of SQL ``sum(... ) FILTER (WHERE ...)``. +revenue = col("l_extendedprice") * (lit(1.0) - col("l_discount")) +is_promo = F.starts_with(col("p_type"), lit("PROMO")) + +df = ( + df_lineitem.filter( + col("l_shipdate") >= lit(MONTH_START), + col("l_shipdate") < lit(MONTH_END), + ) + .join(df_part, left_on="l_partkey", right_on="p_partkey") + .aggregate( + [], + [ + F.sum(revenue, filter=is_promo).alias("promo_revenue"), + F.sum(revenue).alias("total_revenue"), + ], + ) + .select( + (lit(100.0) * col("promo_revenue") / col("total_revenue")).alias( + "promo_revenue" + ) + ) ) df.show() diff --git a/examples/tpch/q15_top_supplier.py b/examples/tpch/q15_top_supplier.py index 5128937a7..01c38b9f8 100644 --- a/examples/tpch/q15_top_supplier.py +++ b/examples/tpch/q15_top_supplier.py @@ -24,21 +24,50 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + create view revenue0 (supplier_no, total_revenue) as + select + l_suppkey, + sum(l_extendedprice * (1 - l_discount)) + from + lineitem + where + l_shipdate >= date '1996-01-01' + and l_shipdate < date '1996-01-01' + interval '3' month + group by + l_suppkey; + select + s_suppkey, + s_name, + s_address, + s_phone, + total_revenue + from + supplier, + revenue0 + where + s_suppkey = supplier_no + and total_revenue = ( + select + max(total_revenue) + from + revenue0 + ) + order by + s_suppkey; + drop view revenue0; """ -from datetime import datetime +from datetime import date -import pyarrow as pa -from datafusion import SessionContext, WindowFrame, col, lit +from datafusion import SessionContext, col, lit from datafusion import functions as F -from datafusion.expr import Window from util import get_data_path -DATE = "1996-01-01" - -date_of_interest = lit(datetime.strptime(DATE, "%Y-%m-%d").date()) - -interval_3_months = lit(pa.scalar((0, 91, 0), type=pa.month_day_nano_interval())) +QUARTER_START = date(1996, 1, 1) +QUARTER_END = date(1996, 4, 1) # Load the dataframes we need @@ -54,38 +83,29 @@ "s_phone", ) -# Limit line items to the quarter of interest -df_lineitem = df_lineitem.filter(col("l_shipdate") >= date_of_interest).filter( - col("l_shipdate") < date_of_interest + interval_3_months -) +# Per-supplier revenue over the quarter of interest. +revenue = col("l_extendedprice") * (lit(1) - col("l_discount")) -df = df_lineitem.aggregate( - [col("l_suppkey")], - [ - F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias( - "total_revenue" - ) - ], -) +per_supplier_revenue = df_lineitem.filter( + col("l_shipdate") >= lit(QUARTER_START), + col("l_shipdate") < lit(QUARTER_END), +).aggregate(["l_suppkey"], [F.sum(revenue).alias("total_revenue")]) -# Use a window function to find the maximum revenue across the entire dataframe -window_frame = WindowFrame("rows", None, None) -df = df.with_column( - "max_revenue", - F.max(col("total_revenue")).over(Window(window_frame=window_frame)), +# Compute the grand maximum revenue separately and join on equality — the +# DataFrame stand-in for the reference SQL's +# ``total_revenue = (select max(total_revenue) from revenue0)`` subquery. 
+max_revenue = per_supplier_revenue.aggregate( + [], [F.max(col("total_revenue")).alias("max_rev")] ) -# Find all suppliers whose total revenue is the same as the maximum -df = df.filter(col("total_revenue") == col("max_revenue")) - -# Now that we know the supplier(s) with maximum revenue, get the rest of their information -# from the supplier table -df = df.join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner") +top_suppliers = per_supplier_revenue.join_on( + max_revenue, col("total_revenue") == col("max_rev") +).select("l_suppkey", "total_revenue") -# Return only the columns requested -df = df.select("s_suppkey", "s_name", "s_address", "s_phone", "total_revenue") - -# If we have more than one, sort by supplier number (suppkey) -df = df.sort(col("s_suppkey").sort()) +df = ( + df_supplier.join(top_suppliers, left_on="s_suppkey", right_on="l_suppkey") + .select("s_suppkey", "s_name", "s_address", "s_phone", "total_revenue") + .sort_by("s_suppkey") +) df.show() diff --git a/examples/tpch/q16_part_supplier_relationship.py b/examples/tpch/q16_part_supplier_relationship.py index 65043ffda..ddeadff5f 100644 --- a/examples/tpch/q16_part_supplier_relationship.py +++ b/examples/tpch/q16_part_supplier_relationship.py @@ -26,9 +26,41 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + p_brand, + p_type, + p_size, + count(distinct ps_suppkey) as supplier_cnt + from + partsupp, + part + where + p_partkey = ps_partkey + and p_brand <> 'Brand#45' + and p_type not like 'MEDIUM POLISHED%' + and p_size in (49, 14, 23, 45, 19, 3, 36, 9) + and ps_suppkey not in ( + select + s_suppkey + from + supplier + where + s_comment like '%Customer%Complaints%' + ) + group by + p_brand, + p_type, + p_size + order by + supplier_cnt desc, + p_brand, + p_type, + p_size; """ -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path @@ -52,39 +84,36 @@ ) df_unwanted_suppliers = df_supplier.filter( - ~F.regexp_match(col("s_comment"), lit("Customer.?*Complaints")).is_null() + F.regexp_like(col("s_comment"), lit("Customer.*Complaints")) ) -# Remove unwanted suppliers +# Remove unwanted suppliers via an anti join (DataFrame form of NOT IN). df_partsupp = df_partsupp.join( - df_unwanted_suppliers, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="anti" + df_unwanted_suppliers, left_on="ps_suppkey", right_on="s_suppkey", how="anti" ) -# Select the parts we are interested in -df_part = df_part.filter(col("p_brand") != lit(BRAND)) +# Select the parts we are interested in. df_part = df_part.filter( - F.substring(col("p_type"), lit(0), lit(len(TYPE_TO_IGNORE) + 1)) - != lit(TYPE_TO_IGNORE) -) - -# Python conversion of integer to literal casts it to int64 but the data for -# part size is stored as an int32, so perform a cast. Then check to find if the part -# size is within the array of possible sizes by checking the position of it is not -# null. 
-p_sizes = F.make_array(*[lit(s).cast(pa.int32()) for s in SIZES_OF_INTEREST]) -df_part = df_part.filter(~F.array_position(p_sizes, col("p_size")).is_null()) - -df = df_part.join( - df_partsupp, left_on=["p_partkey"], right_on=["ps_partkey"], how="inner" + col("p_brand") != BRAND, + ~F.starts_with(col("p_type"), lit(TYPE_TO_IGNORE)), + F.in_list(col("p_size"), [lit(s) for s in SIZES_OF_INTEREST]), ) -df = df.select("p_brand", "p_type", "p_size", "ps_suppkey").distinct() - -df = df.aggregate( - [col("p_brand"), col("p_type"), col("p_size")], - [F.count(col("ps_suppkey")).alias("supplier_cnt")], +# For each (brand, type, size), count the distinct suppliers remaining. +df = ( + df_part.join(df_partsupp, left_on="p_partkey", right_on="ps_partkey") + .select("p_brand", "p_type", "p_size", "ps_suppkey") + .distinct() + .aggregate( + ["p_brand", "p_type", "p_size"], + [F.count(col("ps_suppkey")).alias("supplier_cnt")], + ) + .sort( + col("supplier_cnt").sort(ascending=False), + "p_brand", + "p_type", + "p_size", + ) ) -df = df.sort(col("supplier_cnt").sort(ascending=False)) - df.show() diff --git a/examples/tpch/q17_small_quantity_order.py b/examples/tpch/q17_small_quantity_order.py index 5ccb38422..f2229171f 100644 --- a/examples/tpch/q17_small_quantity_order.py +++ b/examples/tpch/q17_small_quantity_order.py @@ -26,6 +26,26 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + sum(l_extendedprice) / 7.0 as avg_yearly + from + lineitem, + part + where + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container = 'MED BOX' + and l_quantity < ( + select + 0.2 * avg(l_quantity) + from + lineitem + where + l_partkey = p_partkey + ); """ from datafusion import SessionContext, WindowFrame, col, lit @@ -47,29 +67,23 @@ "l_partkey", "l_quantity", "l_extendedprice" ) -# Limit to the problem statement's brand and container types -df = df_part.filter(col("p_brand") == lit(BRAND)).filter( - col("p_container") == lit(CONTAINER) -) - -# Combine data -df = df.join(df_lineitem, left_on=["p_partkey"], right_on=["l_partkey"], how="inner") +# Limit to parts of the target brand/container, join their line items, and +# attach the per-part average quantity via a partitioned window function — +# the DataFrame form of the SQL's correlated ``avg(l_quantity)`` subquery. 
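+# (A ``WindowFrame("rows", None, None)`` runs from unbounded preceding to
+# unbounded following, so every line item of a part sees the same average.)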
+whole_frame = WindowFrame("rows", None, None) -# Find the average quantity -window_frame = WindowFrame("rows", None, None) -df = df.with_column( - "avg_quantity", - F.avg(col("l_quantity")).over( - Window(partition_by=[col("l_partkey")], window_frame=window_frame) - ), +df = ( + df_part.filter(col("p_brand") == BRAND, col("p_container") == CONTAINER) + .join(df_lineitem, left_on="p_partkey", right_on="l_partkey") + .with_column( + "avg_quantity", + F.avg(col("l_quantity")).over( + Window(partition_by=[col("l_partkey")], window_frame=whole_frame) + ), + ) + .filter(col("l_quantity") < lit(0.2) * col("avg_quantity")) + .aggregate([], [F.sum(col("l_extendedprice")).alias("total")]) + .select((col("total") / lit(7.0)).alias("avg_yearly")) ) -df = df.filter(col("l_quantity") < lit(0.2) * col("avg_quantity")) - -# Compute the total -df = df.aggregate([], [F.sum(col("l_extendedprice")).alias("total")]) - -# Divide by number of years in the problem statement to get average -df = df.select((col("total") / lit(7)).alias("avg_yearly")) - df.show() diff --git a/examples/tpch/q18_large_volume_customer.py b/examples/tpch/q18_large_volume_customer.py index 834d181c9..23132d60d 100644 --- a/examples/tpch/q18_large_volume_customer.py +++ b/examples/tpch/q18_large_volume_customer.py @@ -24,9 +24,44 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice, + sum(l_quantity) + from + customer, + orders, + lineitem + where + o_orderkey in ( + select + l_orderkey + from + lineitem + group by + l_orderkey having + sum(l_quantity) > 300 + ) + and c_custkey = o_custkey + and o_orderkey = l_orderkey + group by + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice + order by + o_totalprice desc, + o_orderdate limit 100; """ -from datafusion import SessionContext, col, lit +from datafusion import SessionContext, col from datafusion import functions as F from util import get_data_path @@ -46,22 +81,24 @@ "l_orderkey", "l_quantity", "l_extendedprice" ) -df = df_lineitem.aggregate( - [col("l_orderkey")], [F.sum(col("l_quantity")).alias("total_quantity")] +# Find orders whose total quantity exceeds the threshold, then join in the +# order + customer details the problem statement requires and sort. 
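+# (Filtering on the aggregate output below is the DataFrame counterpart of
+# the reference SQL's ``having sum(l_quantity) > 300``.)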
+df = ( + df_lineitem.aggregate( + ["l_orderkey"], [F.sum(col("l_quantity")).alias("total_quantity")] + ) + .filter(col("total_quantity") > QUANTITY) + .join(df_orders, left_on="l_orderkey", right_on="o_orderkey") + .join(df_customer, left_on="o_custkey", right_on="c_custkey") + .select( + "c_name", + "c_custkey", + "o_orderkey", + "o_orderdate", + "o_totalprice", + "total_quantity", + ) + .sort(col("o_totalprice").sort(ascending=False), "o_orderdate") ) -# Limit to orders in which the total quantity is above a threshold -df = df.filter(col("total_quantity") > lit(QUANTITY)) - -# We've identified the orders of interest, now join the additional data -# we are required to report on -df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner") -df = df.join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner") - -df = df.select( - "c_name", "c_custkey", "o_orderkey", "o_orderdate", "o_totalprice", "total_quantity" -) - -df = df.sort(col("o_totalprice").sort(ascending=False), col("o_orderdate").sort()) - df.show() diff --git a/examples/tpch/q19_discounted_revenue.py b/examples/tpch/q19_discounted_revenue.py index bd492aac0..a2be1c1b7 100644 --- a/examples/tpch/q19_discounted_revenue.py +++ b/examples/tpch/q19_discounted_revenue.py @@ -24,10 +24,47 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + sum(l_extendedprice* (1 - l_discount)) as revenue + from + lineitem, + part + where + ( + p_partkey = l_partkey + and p_brand = 'Brand#12' + and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + and l_quantity >= 1 and l_quantity <= 1 + 10 + and p_size between 1 and 5 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + and l_quantity >= 10 and l_quantity <= 10 + 10 + and p_size between 1 and 10 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#34' + and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + and l_quantity >= 20 and l_quantity <= 20 + 10 + and p_size between 1 and 15 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ); """ -import pyarrow as pa -from datafusion import SessionContext, col, lit, udf +from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path @@ -65,72 +102,41 @@ "l_discount", ) -# These limitations apply to all line items, so go ahead and do them first - -df = df_lineitem.filter(col("l_shipinstruct") == lit("DELIVER IN PERSON")) - -df = df.filter( - (col("l_shipmode") == lit("AIR")) | (col("l_shipmode") == lit("AIR REG")) -) +# Filter conditions that apply to every disjunct of the reference SQL's WHERE +# clause — pull them out up front so the per-brand predicate stays focused on +# the brand-specific parts. +df = df_lineitem.filter( + col("l_shipinstruct") == "DELIVER IN PERSON", + F.in_list(col("l_shipmode"), [lit("AIR"), lit("AIR REG")]), +).join(df_part, left_on="l_partkey", right_on="p_partkey") + + +# Build one OR-combined predicate per brand. Each disjunct encodes the +# brand-specific container list, quantity window, and size range from the +# reference SQL. 
This mirrors the SQL ``where (... brand A ...) or (... brand +# B ...) or (... brand C ...)`` form directly, without a UDF. +def _brand_predicate( + brand: str, min_quantity: int, containers: list[str], max_size: int +): + return ( + (col("p_brand") == brand) + & F.in_list(col("p_container"), [lit(c) for c in containers]) + & col("l_quantity").between(lit(min_quantity), lit(min_quantity + 10)) + & col("p_size").between(lit(1), lit(max_size)) + ) -df = df.join(df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="inner") - - -# Create the user defined function (UDF) definition that does the work -def is_of_interest( - brand_arr: pa.Array, - container_arr: pa.Array, - quantity_arr: pa.Array, - size_arr: pa.Array, -) -> pa.Array: - """ - The purpose of this function is to demonstrate how a UDF works, taking as input a pyarrow Array - and generating a resultant Array. The length of the inputs should match and there should be the - same number of rows in the output. - """ - result = [] - for idx, brand_val in enumerate(brand_arr): - brand = brand_val.as_py() - if brand in items_of_interest: - values_of_interest = items_of_interest[brand] - - container_matches = ( - container_arr[idx].as_py() in values_of_interest["containers"] - ) - - quantity = quantity_arr[idx].as_py() - quantity_matches = ( - values_of_interest["min_quantity"] - <= quantity - <= values_of_interest["min_quantity"] + 10 - ) - - size = size_arr[idx].as_py() - size_matches = 1 <= size <= values_of_interest["max_size"] - - result.append(container_matches and quantity_matches and size_matches) - else: - result.append(False) - - return pa.array(result) - - -# Turn the above function into a UDF that DataFusion can understand -is_of_interest_udf = udf( - is_of_interest, - [pa.utf8(), pa.utf8(), pa.decimal128(15, 2), pa.int32()], - pa.bool_(), - "stable", -) -# Filter results using the above UDF -df = df.filter( - is_of_interest_udf( - col("p_brand"), col("p_container"), col("l_quantity"), col("p_size") +predicate = None +for brand, params in items_of_interest.items(): + part_predicate = _brand_predicate( + brand, + params["min_quantity"], + params["containers"], + params["max_size"], ) -) + predicate = part_predicate if predicate is None else predicate | part_predicate -df = df.aggregate( +df = df.filter(predicate).aggregate( [], [F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias("revenue")], ) diff --git a/examples/tpch/q20_potential_part_promotion.py b/examples/tpch/q20_potential_part_promotion.py index a25188d31..18f96da97 100644 --- a/examples/tpch/q20_potential_part_promotion.py +++ b/examples/tpch/q20_potential_part_promotion.py @@ -25,17 +25,57 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. 
+ +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + s_name, + s_address + from + supplier, + nation + where + s_suppkey in ( + select + ps_suppkey + from + partsupp + where + ps_partkey in ( + select + p_partkey + from + part + where + p_name like 'forest%' + ) + and ps_availqty > ( + select + 0.5 * sum(l_quantity) + from + lineitem + where + l_partkey = ps_partkey + and l_suppkey = ps_suppkey + and l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + ) + ) + and s_nationkey = n_nationkey + and n_name = 'CANADA' + order by + s_name; """ -from datetime import datetime +from datetime import date -import pyarrow as pa from datafusion import SessionContext, col, lit from datafusion import functions as F from util import get_data_path COLOR_OF_INTEREST = "forest" -DATE_OF_INTEREST = "1994-01-01" +YEAR_START = date(1994, 1, 1) +YEAR_END = date(1995, 1, 1) NATION_OF_INTEREST = "CANADA" # Load the dataframes we need @@ -56,46 +96,48 @@ "n_nationkey", "n_name" ) -date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date() - -interval = pa.scalar((0, 365, 0), type=pa.month_day_nano_interval()) - -# Filter down dataframes -df_nation = df_nation.filter(col("n_name") == lit(NATION_OF_INTEREST)) -df_part = df_part.filter( - F.substring(col("p_name"), lit(0), lit(len(COLOR_OF_INTEREST) + 1)) - == lit(COLOR_OF_INTEREST) -) - -df = df_lineitem.filter(col("l_shipdate") >= lit(date)).filter( - col("l_shipdate") < lit(date) + lit(interval) +# Filter down dataframes. ``starts_with`` reads more naturally than an +# explicit substring slice and maps directly to the reference SQL's +# ``p_name like 'forest%'`` clause. +df_nation = df_nation.filter(col("n_name") == NATION_OF_INTEREST) +df_part = df_part.filter(F.starts_with(col("p_name"), lit(COLOR_OF_INTEREST))) + +# Compute the total quantity of interesting parts shipped by each (part, +# supplier) pair within the year of interest. +totals = ( + df_lineitem.filter( + col("l_shipdate") >= lit(YEAR_START), + col("l_shipdate") < lit(YEAR_END), + ) + .join(df_part, left_on="l_partkey", right_on="p_partkey") + .aggregate( + ["l_partkey", "l_suppkey"], + [F.sum(col("l_quantity")).alias("total_sold")], + ) ) -# This will filter down the line items to the parts of interest -df = df.join(df_part, left_on="l_partkey", right_on="p_partkey", how="inner") - -# Compute the total sold and limit ourselves to individual supplier/part combinations -df = df.aggregate( - [col("l_partkey"), col("l_suppkey")], [F.sum(col("l_quantity")).alias("total_sold")] +# Keep only (part, supplier) pairs whose available quantity exceeds 50% of +# the total shipped. The result already contains one row per supplier of +# interest, so we can semi-join the supplier table rather than inner-join +# and deduplicate afterwards. +excess_suppliers = ( + df_partsupp.join( + totals, + left_on=["ps_partkey", "ps_suppkey"], + right_on=["l_partkey", "l_suppkey"], + ) + .filter(col("ps_availqty") > lit(0.5) * col("total_sold")) + .select(col("ps_suppkey").alias("suppkey")) + .distinct() ) -df = df.join( - df_partsupp, - left_on=["l_partkey", "l_suppkey"], - right_on=["ps_partkey", "ps_suppkey"], - how="inner", +# Limit to suppliers in the nation of interest and pick out the two +# requested columns. 
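+# (A semi join emits each left-hand supplier row at most once, however many
+# matches exist on the right, so no trailing ``distinct()`` is needed.)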
+df = ( + df_supplier.join(df_nation, left_on="s_nationkey", right_on="n_nationkey") + .join(excess_suppliers, left_on="s_suppkey", right_on="suppkey", how="semi") + .select("s_name", "s_address") + .sort_by("s_name") ) -# Find cases of excess quantity -df.filter(col("ps_availqty") > lit(0.5) * col("total_sold")) - -# We could do these joins earlier, but now limit to the nation of interest suppliers -df = df.join(df_supplier, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="inner") -df = df.join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner") - -# Restrict to the requested data per the problem statement -df = df.select("s_name", "s_address").distinct() - -df = df.sort(col("s_name").sort()) - df.show() diff --git a/examples/tpch/q21_suppliers_kept_orders_waiting.py b/examples/tpch/q21_suppliers_kept_orders_waiting.py index 4ee9d3733..d98f76ce7 100644 --- a/examples/tpch/q21_suppliers_kept_orders_waiting.py +++ b/examples/tpch/q21_suppliers_kept_orders_waiting.py @@ -24,9 +24,51 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + s_name, + count(*) as numwait + from + supplier, + lineitem l1, + orders, + nation + where + s_suppkey = l1.l_suppkey + and o_orderkey = l1.l_orderkey + and o_orderstatus = 'F' + and l1.l_receiptdate > l1.l_commitdate + and exists ( + select + * + from + lineitem l2 + where + l2.l_orderkey = l1.l_orderkey + and l2.l_suppkey <> l1.l_suppkey + ) + and not exists ( + select + * + from + lineitem l3 + where + l3.l_orderkey = l1.l_orderkey + and l3.l_suppkey <> l1.l_suppkey + and l3.l_receiptdate > l3.l_commitdate + ) + and s_nationkey = n_nationkey + and n_name = 'SAUDI ARABIA' + group by + s_name + order by + numwait desc, + s_name limit 100; """ -from datafusion import SessionContext, col, lit +from datafusion import SessionContext, col from datafusion import functions as F from util import get_data_path @@ -50,65 +92,57 @@ ) # Limit to suppliers in the nation of interest -df_suppliers_of_interest = df_nation.filter(col("n_name") == lit(NATION_OF_INTEREST)) - -df_suppliers_of_interest = df_suppliers_of_interest.join( - df_supplier, left_on="n_nationkey", right_on="s_nationkey", how="inner" +df_suppliers_of_interest = df_nation.filter(col("n_name") == NATION_OF_INTEREST).join( + df_supplier, left_on="n_nationkey", right_on="s_nationkey" ) -# Find the failed orders and all their line items -df = df_orders.filter(col("o_orderstatus") == lit("F")) - -df = df_lineitem.join(df, left_on="l_orderkey", right_on="o_orderkey", how="inner") - -# Identify the line items for which the order is failed due to. -df = df.with_column( - "failed_supp", - F.case(col("l_receiptdate") > col("l_commitdate")) - .when(lit(value=True), col("l_suppkey")) - .end(), +# Line items for orders that have status 'F'. This is the candidate set of +# (order, supplier) pairs we reason about below. +failed_order_lineitems = df_lineitem.join( + df_orders.filter(col("o_orderstatus") == "F"), + left_on="l_orderkey", + right_on="o_orderkey", ) -# There are other ways we could do this but the purpose of this example is to work with rows where -# an element is an array of values. In this case, we will create two columns of arrays. One will be -# an array of all of the suppliers who made up this order. 
That way we can filter the dataframe for
-# only orders where this array is larger than one for multiple supplier orders. The second column
-# is all of the suppliers who failed to make their commitment. We can filter the second column for
-# arrays with size one. That combination will give us orders that had multiple suppliers where only
-# one failed. Use distinct=True in the blow aggregation so we don't get multiple line items from the
-# same supplier reported in either array.
-df = df.aggregate(
-    [col("o_orderkey")],
-    [
-        F.array_agg(col("l_suppkey"), distinct=True).alias("all_suppliers"),
-        F.array_agg(
-            col("failed_supp"), filter=col("failed_supp").is_not_null(), distinct=True
-        ).alias("failed_suppliers"),
-    ],
+# Line items whose receipt was late. This corresponds to ``l1`` in the
+# reference SQL.
+late_lineitems = failed_order_lineitems.filter(
+    col("l_receiptdate") > col("l_commitdate")
 )
 
-# This is the check described above which will identify single failed supplier in a multiple
-# supplier order.
-df = df.filter(F.array_length(col("failed_suppliers")) == lit(1)).filter(
-    F.array_length(col("all_suppliers")) > lit(1)
+# Orders that had more than one distinct supplier. Expressed as
+# ``count(distinct l_suppkey) > 1``. Stands in for the reference SQL's
+# ``exists (... l2.l_suppkey <> l1.l_suppkey ...)`` subquery.
+multi_supplier_orders = (
+    failed_order_lineitems.select("l_orderkey", "l_suppkey")
+    .distinct()
+    .aggregate(["l_orderkey"], [F.count_star().alias("n_suppliers")])
+    .filter(col("n_suppliers") > 1)
+    .select("l_orderkey")
 )
 
-# Since we have an array we know is exactly one element long, we can extract that single value.
-df = df.select(
-    col("o_orderkey"), F.array_element(col("failed_suppliers"), lit(1)).alias("suppkey")
+# Orders where exactly one distinct supplier was late. Stands in for the
+# reference SQL's ``not exists (... l3.l_suppkey <> l1.l_suppkey and l3 is
+# also late ...)`` subquery: if only one supplier on the order was late,
+# nobody else on the same order was late.
+single_late_supplier_orders = (
+    late_lineitems.select("l_orderkey", "l_suppkey")
+    .distinct()
+    .aggregate(["l_orderkey"], [F.count_star().alias("n_late_suppliers")])
+    .filter(col("n_late_suppliers") == 1)
+    .select("l_orderkey")
 )
 
-# Join to the supplier of interest list for the nation of interest
-df = df.join(
-    df_suppliers_of_interest, left_on=["suppkey"], right_on=["s_suppkey"], how="inner"
+# Keep late line items whose order qualifies on both counts, attach the
+# supplier name for suppliers in the nation of interest, count the
+# qualifying late line items per supplier (the reference SQL's ``count(*)``),
+# and return the top 100.
+df = ( + late_lineitems.join(multi_supplier_orders, on="l_orderkey", how="semi") + .join(single_late_supplier_orders, on="l_orderkey", how="semi") + .join(df_suppliers_of_interest, left_on="l_suppkey", right_on="s_suppkey") + .aggregate(["s_name"], [F.count_star().alias("numwait")]) + .sort(col("numwait").sort(ascending=False), "s_name") + .limit(100) ) -# Count how many orders that supplier is the only failed supplier for -df = df.aggregate([col("s_name")], [F.count(col("o_orderkey")).alias("numwait")]) - -# Return in descending order -df = df.sort(col("numwait").sort(ascending=False), col("s_name").sort()) - -df = df.limit(100) - df.show() diff --git a/examples/tpch/q22_global_sales_opportunity.py b/examples/tpch/q22_global_sales_opportunity.py index a2d41b215..5043eeb51 100644 --- a/examples/tpch/q22_global_sales_opportunity.py +++ b/examples/tpch/q22_global_sales_opportunity.py @@ -24,6 +24,46 @@ The above problem statement text is copyrighted by the Transaction Processing Performance Council as part of their TPC Benchmark H Specification revision 2.18.0. + +Reference SQL (from TPC-H specification, used by the benchmark suite):: + + select + cntrycode, + count(*) as numcust, + sum(c_acctbal) as totacctbal + from + ( + select + substring(c_phone from 1 for 2) as cntrycode, + c_acctbal + from + customer + where + substring(c_phone from 1 for 2) in + ('13', '31', '23', '29', '30', '18', '17') + and c_acctbal > ( + select + avg(c_acctbal) + from + customer + where + c_acctbal > 0.00 + and substring(c_phone from 1 for 2) in + ('13', '31', '23', '29', '30', '18', '17') + ) + and not exists ( + select + * + from + orders + where + o_custkey = c_custkey + ) + ) as custsale + group by + cntrycode + order by + cntrycode; """ from datafusion import SessionContext, WindowFrame, col, lit @@ -42,40 +82,36 @@ ) df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select("o_custkey") -# The nation code is a two digit number, but we need to convert it to a string literal -nation_codes = F.make_array(*[lit(str(n)) for n in NATION_CODES]) - -# Use the substring operation to extract the first two characters of the phone number -df = df_customer.with_column("cntrycode", F.substring(col("c_phone"), lit(0), lit(3))) - -# Limit our search to customers with some balance and in the country code above -df = df.filter(col("c_acctbal") > lit(0.0)) -df = df.filter(~F.array_position(nation_codes, col("cntrycode")).is_null()) - -# Compute the average balance. By default, the window frame is from unbounded preceding to the -# current row. We want our frame to cover the entire data frame. -window_frame = WindowFrame("rows", None, None) -df = df.with_column( - "avg_balance", - F.avg(col("c_acctbal")).over(Window(window_frame=window_frame)), -) - -df.show() -# Limit results to customers with above average balance -df = df.filter(col("c_acctbal") > col("avg_balance")) - -# Limit results to customers with no orders -df = df.join(df_orders, left_on="c_custkey", right_on="o_custkey", how="anti") - -# Count up the customers and the balances -df = df.aggregate( - [col("cntrycode")], - [ - F.count(col("c_custkey")).alias("numcust"), - F.sum(col("c_acctbal")).alias("totacctbal"), - ], +# Country code is the two-digit prefix of the phone number. 
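+# (``F.left(col("c_phone"), lit(2))`` below is the DataFrame form of the
+# reference SQL's ``substring(c_phone from 1 for 2)``.)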
+nation_codes = [lit(str(n)) for n in NATION_CODES] + +# Start from customers with a positive balance in one of the target country +# codes, then attach the grand-mean balance via a whole-frame window so we +# can filter per row — DataFrame stand-in for the SQL's scalar ``(select +# avg(c_acctbal) ... )`` subquery. +whole_frame = WindowFrame("rows", None, None) + +df = ( + df_customer.with_column("cntrycode", F.left(col("c_phone"), lit(2))) + .filter( + col("c_acctbal") > 0.0, + F.in_list(col("cntrycode"), nation_codes), + ) + .with_column( + "avg_balance", + F.avg(col("c_acctbal")).over(Window(window_frame=whole_frame)), + ) + .filter(col("c_acctbal") > col("avg_balance")) + # Keep only customers with no orders (anti join = NOT EXISTS). + .join(df_orders, left_on="c_custkey", right_on="o_custkey", how="anti") + .aggregate( + ["cntrycode"], + [ + F.count_star().alias("numcust"), + F.sum(col("c_acctbal")).alias("totacctbal"), + ], + ) + .sort_by("cntrycode") ) -df = df.sort(col("cntrycode").sort()) - df.show() From e0284c6e788b6fc893495ed929b9badef1cf925c Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 24 Apr 2026 13:09:24 -0400 Subject: [PATCH 6/6] feat: add AI skill to find and improve the Pythonic interface to functions (#1484) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: accept native Python types in function arguments instead of requiring lit() Update 47 functions in functions.py to accept native Python types (int, float, str) for arguments that are contextually literals, eliminating verbose lit() wrapping. For example, users can now write split_part(col("a"), ",", 2) instead of split_part(col("a"), lit(","), lit(2)). All changes are backward compatible. Co-Authored-By: Claude Opus 4.6 (1M context) * fix: update alias function signatures to match pythonic primary functions Update instr and position (aliases of strpos) to accept Expr | str for the substring parameter, matching the updated primary function signature. Co-Authored-By: Claude Opus 4.6 (1M context) * docs: update make-pythonic skill to require alias type hint updates Alias functions that delegate to a primary function must have their type hints updated to match, even though coercion logic is only added to the primary. Added a new Step 3 to the implementation workflow for this. Co-Authored-By: Claude Opus 4.6 (1M context) * fix: address review feedback on pythonic skill and function signatures Update SKILL.md to prevent three classes of issues: clarify that float already accepts int per PEP 484 (avoiding redundant int | float that fails ruff PYI041), add backward-compat rule for Category B so existing Expr params aren't removed, and add guidance for inline coercion with many optional nullable params instead of local helpers. Replace regexp_instr's _to_raw() helper with inline coercion matching the pattern used throughout the rest of the file. Co-Authored-By: Claude Opus 4.6 (1M context) * refactor: add coerce_to_expr helpers and replace inline coercion patterns Introduce coerce_to_expr() and coerce_to_expr_or_none() in expr.py as the complement to ensure_expr() — where ensure_expr rejects non-Expr values, these helpers wrap them via Expr.literal(). Replaces ~60 inline isinstance checks in functions.py with single-line helper calls, and updates the make-pythonic skill to document the new pattern. 
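
For illustration, the shape of the change at a typical call site (a sketch,
using the names introduced in this patch):

    # before: coercion repeated inline in each function
    if not isinstance(n, Expr):
        n = Expr.literal(n)

    # after: a single helper call
    n = coerce_to_expr(n)
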
Co-Authored-By: Claude Opus 4.6 (1M context)

* docs: add aggregate function literal detection to make-pythonic skill

Add Technique 1a to detect literal-only arguments in aggregate functions. Unlike scalar UDFs which enforce literals in invoke_with_args(), aggregate functions enforce them in accumulator() via get_scalar_value(), validate_percentile_expr(), or downcast_ref::<Literal>(). Without this technique, the skill would incorrectly classify arguments like approx_percentile_cont's percentile as Category A (Expr | float) when they should be Category B (float only). Updates the decision flow to branch on scalar vs aggregate before checking for literal enforcement.

Co-Authored-By: Claude Opus 4.6 (1M context)

* docs: add window function literal detection to make-pythonic skill

Add Technique 1b to detect literal-only arguments in window functions. Window functions enforce literals in partition_evaluator() via get_scalar_value_from_args() / downcast_ref::<Literal>(), not in invoke_with_args() (scalar) or accumulator() (aggregate). Updates the decision flow to branch on scalar vs aggregate vs window. Known window functions with literal-only arguments: ntile (n), lead/lag (offset, default_value), nth_value (n).

Co-Authored-By: Claude Opus 4.6 (1M context)

* fix: use explicit None checks, widen numeric type hints, and add tests

Replace 7 fragile truthiness checks (x.expr if x else None) with explicit is not None checks to prevent silent None when zero-valued literals are passed. Widen log/power/pow type hints to Expr | int | float with noqa: PYI041 for clarity. Add unit tests for coerce_to_expr helpers and integration tests for pythonic calling conventions.

Co-Authored-By: Claude Opus 4.6 (1M context)

* chore: suppress FBT003 in tests and remove redundant noqa comments

Add FBT003 (boolean positional value) to the per-file-ignores for python/tests/* in pyproject.toml, and remove the 6 now-redundant inline noqa: FBT003 comments across test_expr.py and test_context.py.

Co-Authored-By: Claude Opus 4.6 (1M context)

* docs: replace static function lists with discovery instructions in skill

Replace hardcoded "Known aggregate/window functions with literal-only arguments" lists with instructions to discover them dynamically by searching the upstream crate source. Keeps a few examples as validation anchors so the agent knows its search is working correctly.

Co-Authored-By: Claude Opus 4.6 (1M context)

* fix: make interrupt test reliable on Python 3.11

PyThreadState_SetAsyncExc only delivers exceptions when the thread is executing Python bytecode, not while in native (Rust/C) code. The previous test had two issues causing flakiness on Python 3.11:

1. The interrupt fired before df.collect() entered the UDF, while the thread was still in native code where async exceptions are ignored.

2. time.sleep(2.0) is a single C call where async exceptions are not checked — they're only checked between bytecode instructions.

Fix by adding a threading.Event so the interrupt waits until the UDF is actually executing Python code, and by sleeping in small increments so the eval loop has opportunities to check for pending exceptions.
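
Roughly (an illustrative sketch, not the test verbatim):

    # before: one C-level call; the async exception waits until it returns
    time.sleep(2.0)

    # after: short slices, so the eval loop can deliver the exception
    # between bytecode instructions
    for _ in range(200):
        time.sleep(0.01)
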
Co-Authored-By: Claude Opus 4.6 (1M context)

---------

Co-authored-by: Claude Opus 4.6 (1M context)
---
 .ai/skills/make-pythonic/SKILL.md | 430 +++++++++++++++++++++++++++++
 pyproject.toml                    |   1 +
 python/datafusion/expr.py         |  41 +++
 python/datafusion/functions.py    | 445 +++++++++++++++++------------
 python/tests/test_context.py     |   6 +-
 python/tests/test_dataframe.py   |  21 +-
 python/tests/test_expr.py        |  49 +++-
 python/tests/test_functions.py   |  93 +++++++
 8 files changed, 893 insertions(+), 193 deletions(-)
 create mode 100644 .ai/skills/make-pythonic/SKILL.md

diff --git a/.ai/skills/make-pythonic/SKILL.md b/.ai/skills/make-pythonic/SKILL.md
new file mode 100644
index 000000000..57145ac6c
--- /dev/null
+++ b/.ai/skills/make-pythonic/SKILL.md
@@ -0,0 +1,430 @@
+
+
+---
+name: make-pythonic
+description: Audit and improve datafusion-python functions to accept native Python types (int, float, str, bool) instead of requiring explicit lit() or col() wrapping. Analyzes function signatures, checks upstream Rust implementations for type constraints, and applies the appropriate coercion pattern.
+argument-hint: [scope] (e.g., "string functions", "datetime functions", "array functions", "math functions", "all", or a specific function name like "split_part")
+---
+
+# Make Python API Functions More Pythonic
+
+You are improving the datafusion-python API to feel more natural to Python users. The goal is to allow functions to accept native Python types (int, float, str, bool, etc.) for arguments that are contextually always or typically literal values, instead of requiring users to manually wrap them in `lit()`.
+
+**Core principle:** A Python user should be able to write `split_part(col("a"), ",", 2)` instead of `split_part(col("a"), lit(","), lit(2))` when the arguments are contextually obvious literals.
+
+## How to Identify Candidates
+
+The user may specify a scope via `$ARGUMENTS`. If no scope is given or "all" is specified, audit all functions in `python/datafusion/functions.py`.
+
+For each function, determine if any parameter can accept native Python types by evaluating **two complementary signals**:
+
+### Signal 1: Contextual Understanding
+
+Some arguments are contextually always or almost always literal values based on what the function does:
+
+| Context | Typical Arguments | Examples |
+|---------|------------------|----------|
+| **String position/count** | Character counts, indices, repetition counts | `left(str, n)`, `right(str, n)`, `repeat(str, n)`, `lpad(str, count, ...)` |
+| **Delimiters/separators** | Fixed separator characters | `split_part(str, delim, idx)`, `concat_ws(sep, ...)` |
+| **Search/replace patterns** | Literal search strings, replacements | `replace(str, from, to)`, `regexp_replace(str, pattern, replacement, flags)` |
+| **Date/time parts** | Part names from a fixed set | `date_part(part, date)`, `date_trunc(part, date)` |
+| **Rounding precision** | Decimal place counts | `round(val, places)`, `trunc(val, places)` |
+| **Fill characters** | Padding characters | `lpad(str, count, fill)`, `rpad(str, count, fill)` |
+
+### Signal 2: Upstream Rust Implementation
+
+Check the Rust binding in `crates/core/src/functions.rs` and the upstream DataFusion function implementation to determine type constraints. The upstream source is cached locally at:
+
+```
+~/.cargo/registry/src/index.crates.io-*/datafusion-functions-<version>/src/
+```
+
+Check the DataFusion version in `crates/core/Cargo.toml` to find the right directory. Key subdirectories: `string/`, `datetime/`, `math/`, `regex/`.
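+
+A minimal lookup sketch (``find_crate_src`` is a hypothetical helper for
+illustration, not part of this project):
+
+```python
+from pathlib import Path
+
+def find_crate_src(crate: str) -> list[Path]:
+    """Locate cached upstream crate sources in the local cargo registry."""
+    registry = Path.home() / ".cargo" / "registry" / "src"
+    return sorted(registry.glob(f"index.crates.io-*/{crate}-*/src"))
+
+print(find_crate_src("datafusion-functions"))
+```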
+
+For **aggregate functions**, the upstream source is in a separate crate:
+
+```
+~/.cargo/registry/src/index.crates.io-*/datafusion-functions-aggregate-<version>/src/
+```
+
+There are five concrete techniques to check, in order of signal strength:
+
+#### Technique 1: Check `invoke_with_args()` for literal-only enforcement (strongest signal)
+
+Some functions pattern-match on `ColumnarValue::Scalar` in their `invoke_with_args()` method and **return an error** if the argument is a column/array. This means the argument **must** be a literal — passing a column expression will fail at runtime.
+
+Example from `date_trunc.rs`:
+```rust
+let granularity_str = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = granularity {
+    v.to_lowercase()
+} else {
+    return exec_err!("Granularity of `date_trunc` must be non-null scalar Utf8");
+};
+```
+
+**If you find this pattern:** The argument is **Category B** — accept only the corresponding native Python type (e.g., `str`), not `Expr`. The function will error at runtime with a column expression anyway.
+
+#### Technique 1a: Check `accumulator()` for literal-only enforcement (aggregate functions)
+
+Technique 1 applies to scalar UDFs. Aggregate functions do not have `invoke_with_args()` — instead, they enforce literal-only arguments in their `accumulator()` (or `create_accumulator()`) method, which runs at planning time before any data is processed.
+
+Look for these patterns inside `accumulator()`:
+
+- `get_scalar_value(expr)` — evaluates the expression against an empty batch and errors if it's not a scalar
+- `validate_percentile_expr(expr)` — specific helper used by percentile functions
+- `downcast_ref::<Literal>()` — checks that the physical expression is a literal constant
+
+Example from `approx_percentile_cont.rs`:
+```rust
+fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+    let percentile =
+        validate_percentile_expr(&args.exprs[1], "APPROX_PERCENTILE_CONT")?;
+    // ...
+}
+```
+
+Where `validate_percentile_expr` calls `get_scalar_value` and errors with `"must be a literal"`.
+
+Example from `string_agg.rs`:
+```rust
+fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+    let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::<Literal>() else {
+        return not_impl_err!(
+            "The second argument of the string_agg function must be a string literal"
+        );
+    };
+    // ...
+}
+```
+
+**If you find this pattern:** The argument is **Category B** — accept only the corresponding native Python type, not `Expr`. The function will error at planning time with a non-literal expression.
+
+To discover which aggregate functions have literal-only arguments, search the upstream aggregate crate for `get_scalar_value`, `validate_percentile_expr`, and `downcast_ref::<Literal>()` inside `accumulator()` methods. For example, you should expect to find `approx_percentile_cont` (percentile) and `string_agg` (delimiter) among the results.
+
+#### Technique 1b: Check `partition_evaluator()` for literal-only enforcement (window functions)
+
+Window functions do not have `invoke_with_args()` or `accumulator()`. Instead, they enforce literal-only arguments in their `partition_evaluator()` method, which constructs the evaluator that processes each partition.
+
+The upstream source is in a separate crate:
+
+```
+~/.cargo/registry/src/index.crates.io-*/datafusion-functions-window-<version>/src/
+```
+
+Look for `get_scalar_value_from_args()` calls inside `partition_evaluator()`. This helper (defined in the window crate's `utils.rs`) calls `downcast_ref::<Literal>()` and errors with `"There is only support Literal types for field at idx: {index} in Window Function"`.
+
+Example from `ntile.rs`:
+```rust
+fn partition_evaluator(
+    &self,
+    partition_evaluator_args: PartitionEvaluatorArgs,
+) -> Result<Box<dyn PartitionEvaluator>> {
+    let scalar_n =
+        get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 0)?
+            .ok_or_else(|| {
+                exec_datafusion_err!("NTILE requires a positive integer")
+            })?;
+    // ...
+}
+```
+
+**If you find this pattern:** The argument is **Category B** — accept only the corresponding native Python type, not `Expr`. The function will error at planning time with a non-literal expression.
+
+To discover which window functions have literal-only arguments, search the upstream window crate for `get_scalar_value_from_args` inside `partition_evaluator()` methods. For example, you should expect to find `ntile` (n) and `lead`/`lag` (offset, default_value) among the results.
+
+#### Technique 2: Check the `Signature` for data type constraints
+
+Each function defines a `Signature::coercible(...)` that specifies what data types each argument accepts, using `Coercion` entries. This tells you the expected **data type** even if it doesn't enforce literal-only.
+
+Example from `repeat.rs`:
+```rust
+signature: Signature::coercible(
+    vec![
+        Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
+        Coercion::new_implicit(
+            TypeSignatureClass::Native(logical_int64()),
+            vec![TypeSignatureClass::Integer],
+            NativeType::Int64,
+        ),
+    ],
+    Volatility::Immutable,
+),
+```
+
+This tells you arg 2 (`n`) must be an integer type coerced to Int64. Use this to choose the correct Python type (e.g., `int` not `str` or `float`).
+
+Common mappings:
+| Rust Type Constraint | Python Type |
+|---------------------|-------------|
+| `logical_int64()` / `TypeSignatureClass::Integer` | `int` |
+| `logical_float64()` / `TypeSignatureClass::Numeric` | `int \| float` |
+| `logical_string()` / `TypeSignatureClass::String` | `str` |
+| `LogicalType::Boolean` | `bool` |
+
+**Important:** In Python's type system (PEP 484), `float` already accepts `int` values, so `int | float` is redundant and will fail the `ruff` linter (rule PYI041). Use `float` alone when the Rust side accepts a float/numeric type — Python users can still pass integer literals like `log(10, col("a"))` or `power(col("a"), 3)` without issue. Only use `int` when the Rust side strictly requires an integer (e.g., `logical_int64()`).
+
+#### Technique 3: Check `return_field_from_args()` for `scalar_arguments` usage
+
+Functions that inspect literal values at query planning time use `args.scalar_arguments.get(n)` in their `return_field_from_args()` method. This indicates the argument is **expected to be a literal** for optimal behavior (e.g., to determine output type precision), but may still work as a column.
+
+Example from `round.rs`:
+```rust
+let decimal_places: Option<i32> = match args.scalar_arguments.get(1) {
+    None => Some(0),
+    Some(None) => None, // argument is not a literal (column)
+    Some(Some(scalar)) if scalar.is_null() => Some(0),
+    Some(Some(scalar)) => Some(decimal_places_from_scalar(scalar)?),
+};
+```
+
+**If you find this pattern:** The argument is **Category A** — accept native types AND `Expr`. It works as a column but is primarily used as a literal.
+
+#### Decision flow
+
+```
+What kind of function is this?
+  Scalar UDF:
+    Is argument rejected at runtime if not a literal?
+    (check invoke_with_args for ColumnarValue::Scalar-only match + exec_err!)
+    → YES: Category B — accept only native type, no Expr
+    → NO: continue below
+  Aggregate:
+    Is argument rejected at planning time if not a literal?
+    (check accumulator() for get_scalar_value / validate_percentile_expr /
+     downcast_ref::<Literal>() + error)
+    → YES: Category B — accept only native type, no Expr
+    → NO: continue below
+  Window:
+    Is argument rejected at planning time if not a literal?
+    (check partition_evaluator() for get_scalar_value_from_args /
+     downcast_ref::<Literal>() + error)
+    → YES: Category B — accept only native type, no Expr
+    → NO: continue below
+
+Does the Signature constrain it to a specific data type?
+  → YES: Category A — accept Expr | <native type>
+  → NO: Leave as Expr only
+```
+
+## Coercion Categories
+
+When making a function more pythonic, apply the correct coercion pattern based on **what the argument represents**:
+
+### Category A: Arguments That Should Accept Native Types AND Expr
+
+These are arguments that are *typically* literals but *could* be column references in advanced use cases. For these, accept a union type and coerce native types to `Expr.literal()`.
+
+**Type hint pattern:** `Expr | int`, `Expr | str`, `Expr | int | str`, etc.
+
+**When to use:** When the argument could plausibly come from a column in some use case (e.g., the repeat count might come from a column in a data-driven scenario).
+
+```python
+def repeat(string: Expr, n: Expr | int) -> Expr:
+    """Repeats the ``string`` to ``n`` times.
+
+    Examples:
+        >>> ctx = dfn.SessionContext()
+        >>> df = ctx.from_pydict({"a": ["ha"]})
+        >>> result = df.select(
+        ...     dfn.functions.repeat(dfn.col("a"), 3).alias("r"))
+        >>> result.collect_column("r")[0].as_py()
+        'hahaha'
+    """
+    if not isinstance(n, Expr):
+        n = Expr.literal(n)
+    return Expr(f.repeat(string.expr, n.expr))
+```
+
+### Category B: Arguments That Should ONLY Accept Specific Native Types
+
+These are arguments where an `Expr` never makes sense because the value must be a fixed literal known at query-planning time (not a per-row value). For these, accept only the native type(s) and wrap internally.
+
+**Type hint pattern:** `str`, `int`, `list[str]`, etc. (no `Expr` in the union)
+
+**When to use:** When the argument is from a fixed enumeration or is always a compile-time constant, **AND** the parameter was not previously typed as `Expr`:
+- Separator in `concat_ws` (already typed as `str` in the Rust binding)
+- Index in `array_position` (already typed as `int` in the Rust binding)
+- Values that the Rust implementation already accepts as native types
+
+**Backward compatibility rule:** If a parameter was previously typed as `Expr`, you **must** keep `Expr` in the union even if the Rust side requires a literal. Removing `Expr` would break existing user code like `date_part(lit("year"), col("a"))`. Use **Category A** instead — accept `Expr | str` — and let users who pass column expressions discover the runtime error from the Rust side. Never silently break backward compatibility.
+
+```python
+def concat_ws(separator: str, *args: Expr) -> Expr:
+    """Concatenates the list ``args`` with the separator.
+
+    ``separator`` is already typed as ``str`` in the Rust binding, so
+    there is no backward-compatibility concern.
+
+    Examples:
+        >>> ctx = dfn.SessionContext()
+        >>> df = ctx.from_pydict({"a": ["hello"], "b": ["world"]})
+        >>> result = df.select(
+        ...     
dfn.functions.concat_ws("-", dfn.col("a"), dfn.col("b")).alias("c")) + >>> result.collect_column("c")[0].as_py() + 'hello-world' + """ + args = [arg.expr for arg in args] + return Expr(f.concat_ws(separator, args)) +``` + +### Category C: Arguments That Should Accept str as Column Name + +In some contexts a string argument naturally refers to a column name rather than a literal. This is the pattern used by DataFrame methods. + +**Type hint pattern:** `Expr | str` + +**When to use:** Only when the string contextually means a column name (rare in `functions.py`, more common in DataFrame methods). + +```python +# Use _to_raw_expr() from expr.py for this pattern +from datafusion.expr import _to_raw_expr + +def some_function(column: Expr | str) -> Expr: + raw = _to_raw_expr(column) # str -> col(str) + return Expr(f.some_function(raw)) +``` + +**IMPORTANT:** In `functions.py`, string arguments almost never mean column names. Functions operate on expressions, and column references should use `col()`. Category C applies mainly to DataFrame methods and context APIs, not to scalar/aggregate/window functions. Do NOT convert string arguments to column expressions in `functions.py` unless there is a very clear reason to do so. + +## Implementation Steps + +For each function being updated: + +### Step 1: Analyze the Function + +1. Read the current Python function signature in `python/datafusion/functions.py` +2. Read the Rust binding in `crates/core/src/functions.rs` +3. Optionally check the upstream DataFusion docs for the function +4. Determine which category (A, B, or C) applies to each parameter + +### Step 2: Update the Python Function + +1. **Change the type hints** to accept native types (e.g., `Expr` -> `Expr | int`) +2. **Add coercion logic** at the top of the function body +3. **Update the docstring** examples to use the simpler calling convention +4. **Preserve backward compatibility** — existing code using `Expr` must still work + +### Step 3: Update Alias Type Hints + +After updating a primary function, find all alias functions that delegate to it (e.g., `instr` and `position` delegate to `strpos`). Update each alias's **parameter type hints** to match the primary function's new signature. Do not add coercion logic to aliases — the primary function handles that. + +### Step 4: Update Docstring Examples (primary functions only) + +Per the project's CLAUDE.md rules: +- Every function must have doctest-style examples +- Optional parameters need examples both without and with the optional args, using keyword argument syntax +- Reuse the same input data across examples where possible + +**Update examples to demonstrate the pythonic calling convention:** + +```python +# BEFORE (old style - still works but verbose) +dfn.functions.left(dfn.col("a"), dfn.lit(3)) + +# AFTER (new style - shown in examples) +dfn.functions.left(dfn.col("a"), 3) +``` + +### Step 5: Run Tests + +After making changes, run the doctests to verify: +```bash +python -m pytest --doctest-modules python/datafusion/functions.py -v +``` + +## Coercion Helper Pattern + +Use the coercion helpers from `datafusion.expr` to convert native Python values to `Expr`. These are the complement of `ensure_expr()` — where `ensure_expr` *rejects* non-`Expr` values, the coercion helpers *wrap* them via `Expr.literal()`. 
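+
+Behavior at a glance (a short sketch of the helpers as defined in
+``expr.py``):
+
+```python
+from datafusion import col
+from datafusion.expr import coerce_to_expr, coerce_to_expr_or_none
+
+coerce_to_expr(col("a"))      # Expr instances pass through unchanged
+coerce_to_expr(3)             # wrapped via Expr.literal(3)
+coerce_to_expr_or_none(None)  # None passes through for optional parameters
+```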
+ +**For required parameters** use `coerce_to_expr`: + +```python +from datafusion.expr import coerce_to_expr + +def left(string: Expr, n: Expr | int) -> Expr: + n = coerce_to_expr(n) + return Expr(f.left(string.expr, n.expr)) +``` + +**For optional nullable parameters** use `coerce_to_expr_or_none`: + +```python +from datafusion.expr import coerce_to_expr, coerce_to_expr_or_none + +def regexp_count( + string: Expr, + pattern: Expr | str, + start: Expr | int | None = None, + flags: Expr | str | None = None, +) -> Expr: + pattern = coerce_to_expr(pattern) + start = coerce_to_expr_or_none(start) + flags = coerce_to_expr_or_none(flags) + return Expr( + f.regexp_count( + string.expr, + pattern.expr, + start.expr if start is not None else None, + flags.expr if flags is not None else None, + ) + ) +``` + +Both helpers are defined in `python/datafusion/expr.py` alongside `ensure_expr`. Import them in `functions.py` via: + +```python +from datafusion.expr import coerce_to_expr, coerce_to_expr_or_none +``` + +## What NOT to Change + +- **Do not change arguments that represent data columns.** If an argument is the primary data being operated on (e.g., the `string` in `left(string, n)` or the `array` in `array_sort(array)`), it should remain `Expr` only. Users should use `col()` for column references. +- **Do not change variadic `*args: Expr` parameters.** These represent multiple expressions and should stay as `Expr`. +- **Do not change arguments where the coercion is ambiguous.** If it is unclear whether a string should be a column name or a literal, leave it as `Expr` and let the user be explicit. +- **Do not add coercion logic to simple aliases.** If a function is just `return other_function(...)`, the primary function handles coercion. However, you **must update the alias's type hints** to match the primary function's signature so that type checkers and documentation accurately reflect what the alias accepts. +- **Do not change the Rust bindings.** All coercion happens in the Python layer. The Rust functions continue to accept `PyExpr`. + +## Priority Order + +When auditing functions, process them in this order: + +1. **Date/time functions** — `date_part`, `date_trunc`, `date_bin` — these have the clearest literal arguments +2. **String functions** — `left`, `right`, `repeat`, `lpad`, `rpad`, `split_part`, `substring`, `replace`, `regexp_replace`, `regexp_match`, `regexp_count` — common and verbose without coercion +3. **Math functions** — `round`, `trunc`, `power` — numeric literal arguments +4. **Array functions** — `array_slice`, `array_position`, `array_remove_n`, `array_replace_n`, `array_resize`, `array_element` — index and count arguments +5. **Other functions** — any remaining functions with literal arguments + +## Output Format + +For each function analyzed, report: + +``` +## [Function Name] + +**Current signature:** `function(arg1: Expr, arg2: Expr) -> Expr` +**Proposed signature:** `function(arg1: Expr, arg2: Expr | int) -> Expr` +**Category:** A (accepts native + Expr) +**Arguments changed:** +- `arg2`: Expr -> Expr | int (always a literal count) +**Rust binding:** Takes PyExpr, wraps to literal internally +**Status:** [Changed / Skipped / Needs Discussion] +``` + +If asked to implement (not just audit), make the changes directly and show a summary of what was updated. 
diff --git a/pyproject.toml b/pyproject.toml index 327199d1a..951f7adc3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,6 +111,7 @@ extend-allowed-calls = ["datafusion.lit", "lit"] "ARG", "BLE001", "D", + "FBT003", "PD", "PLC0415", "PLR0913", diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 1ff6976f7..0f7f3ab5a 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -243,6 +243,8 @@ "WindowExpr", "WindowFrame", "WindowFrameBound", + "coerce_to_expr", + "coerce_to_expr_or_none", "ensure_expr", "ensure_expr_list", ] @@ -255,6 +257,10 @@ def ensure_expr(value: Expr | Any) -> expr_internal.Expr: higher level APIs consistently require explicit :func:`~datafusion.col` or :func:`~datafusion.lit` expressions. + See Also: + :func:`coerce_to_expr` — the opposite behavior: *wraps* non-``Expr`` + values as literals instead of rejecting them. + Args: value: Candidate expression or other object. @@ -299,6 +305,41 @@ def _iter( return list(_iter(exprs)) +def coerce_to_expr(value: Any) -> Expr: + """Coerce a native Python value to an ``Expr`` literal, passing ``Expr`` through. + + This is the complement of :func:`ensure_expr`: where ``ensure_expr`` + *rejects* non-``Expr`` values, ``coerce_to_expr`` *wraps* them via + :meth:`Expr.literal` so that functions can accept native Python types + (``int``, ``float``, ``str``, ``bool``, etc.) alongside ``Expr``. + + Args: + value: An ``Expr`` instance (returned as-is) or a Python literal to wrap. + + Returns: + An ``Expr`` representing the value. + """ + if isinstance(value, Expr): + return value + return Expr.literal(value) + + +def coerce_to_expr_or_none(value: Any | None) -> Expr | None: + """Coerce a value to ``Expr`` or pass ``None`` through unchanged. + + Same as :func:`coerce_to_expr` but accepts ``None`` for optional parameters. + + Args: + value: An ``Expr`` instance, a Python literal to wrap, or ``None``. + + Returns: + An ``Expr`` representing the value, or ``None``. + """ + if value is None: + return None + return coerce_to_expr(value) + + def _to_raw_expr(value: Expr | str) -> expr_internal.Expr: """Convert a Python expression or column name to its raw variant. diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 280a6d3ac..08062851a 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -49,6 +49,8 @@ Expr, SortExpr, SortKey, + coerce_to_expr, + coerce_to_expr_or_none, expr_list_to_raw_expr_list, sort_list_to_raw_sort_list, sort_or_default, @@ -383,49 +385,52 @@ def nullif(expr1: Expr, expr2: Expr) -> Expr: return Expr(f.nullif(expr1.expr, expr2.expr)) -def encode(expr: Expr, encoding: Expr) -> Expr: +def encode(expr: Expr, encoding: Expr | str) -> Expr: """Encode the ``input``, using the ``encoding``. encoding can be base64 or hex. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.encode(dfn.col("a"), dfn.lit("base64")).alias("enc")) + ... dfn.functions.encode(dfn.col("a"), "base64").alias("enc")) >>> result.collect_column("enc")[0].as_py() 'aGVsbG8' """ + encoding = coerce_to_expr(encoding) return Expr(f.encode(expr.expr, encoding.expr)) -def decode(expr: Expr, encoding: Expr) -> Expr: +def decode(expr: Expr, encoding: Expr | str) -> Expr: """Decode the ``input``, using the ``encoding``. encoding can be base64 or hex. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["aGVsbG8="]}) >>> result = df.select( - ... 
dfn.functions.decode(dfn.col("a"), dfn.lit("base64")).alias("dec")) + ... dfn.functions.decode(dfn.col("a"), "base64").alias("dec")) >>> result.collect_column("dec")[0].as_py() b'hello' """ + encoding = coerce_to_expr(encoding) return Expr(f.decode(expr.expr, encoding.expr)) -def array_to_string(expr: Expr, delimiter: Expr) -> Expr: +def array_to_string(expr: Expr, delimiter: Expr | str) -> Expr: """Converts each element to its text representation. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2, 3]]}) >>> result = df.select( - ... dfn.functions.array_to_string(dfn.col("a"), dfn.lit(",")).alias("s")) + ... dfn.functions.array_to_string(dfn.col("a"), ",").alias("s")) >>> result.collect_column("s")[0].as_py() '1,2,3' """ + delimiter = coerce_to_expr(delimiter) return Expr(f.array_to_string(expr.expr, delimiter.expr.cast(pa.string()))) -def array_join(expr: Expr, delimiter: Expr) -> Expr: +def array_join(expr: Expr, delimiter: Expr | str) -> Expr: """Converts each element to its text representation. See Also: @@ -434,7 +439,7 @@ def array_join(expr: Expr, delimiter: Expr) -> Expr: return array_to_string(expr, delimiter) -def list_to_string(expr: Expr, delimiter: Expr) -> Expr: +def list_to_string(expr: Expr, delimiter: Expr | str) -> Expr: """Converts each element to its text representation. See Also: @@ -443,7 +448,7 @@ def list_to_string(expr: Expr, delimiter: Expr) -> Expr: return array_to_string(expr, delimiter) -def list_join(expr: Expr, delimiter: Expr) -> Expr: +def list_join(expr: Expr, delimiter: Expr | str) -> Expr: """Converts each element to its text representation. See Also: @@ -479,7 +484,7 @@ def in_list(arg: Expr, values: list[Expr], negated: bool = False) -> Expr: return Expr(f.in_list(arg.expr, values, negated)) -def digest(value: Expr, method: Expr) -> Expr: +def digest(value: Expr, method: Expr | str) -> Expr: """Computes the binary hash of an expression using the specified algorithm. Standard algorithms are md5, sha224, sha256, sha384, sha512, blake2s, @@ -489,24 +494,26 @@ def digest(value: Expr, method: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.digest(dfn.col("a"), dfn.lit("md5")).alias("d")) + ... dfn.functions.digest(dfn.col("a"), "md5").alias("d")) >>> len(result.collect_column("d")[0].as_py()) > 0 True """ + method = coerce_to_expr(method) return Expr(f.digest(value.expr, method.expr)) -def contains(string: Expr, search_str: Expr) -> Expr: +def contains(string: Expr, search_str: Expr | str) -> Expr: """Returns true if ``search_str`` is found within ``string`` (case-sensitive). Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["the quick brown fox"]}) >>> result = df.select( - ... dfn.functions.contains(dfn.col("a"), dfn.lit("brown")).alias("c")) + ... dfn.functions.contains(dfn.col("a"), "brown").alias("c")) >>> result.collect_column("c")[0].as_py() True """ + search_str = coerce_to_expr(search_str) return Expr(f.contains(string.expr, search_str.expr)) @@ -969,17 +976,18 @@ def degrees(arg: Expr) -> Expr: return Expr(f.degrees(arg.expr)) -def ends_with(arg: Expr, suffix: Expr) -> Expr: +def ends_with(arg: Expr, suffix: Expr | str) -> Expr: """Returns true if the ``string`` ends with the ``suffix``, false otherwise. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["abc","b","c"]}) >>> ends_with_df = df.select( - ... dfn.functions.ends_with(dfn.col("a"), dfn.lit("c")).alias("ends_with")) + ... 
dfn.functions.ends_with(dfn.col("a"), "c").alias("ends_with")) >>> ends_with_df.collect_column("ends_with")[0].as_py() True """ + suffix = coerce_to_expr(suffix) return Expr(f.ends_with(arg.expr, suffix.expr)) @@ -1011,7 +1019,7 @@ def factorial(arg: Expr) -> Expr: return Expr(f.factorial(arg.expr)) -def find_in_set(string: Expr, string_list: Expr) -> Expr: +def find_in_set(string: Expr, string_list: Expr | str) -> Expr: """Find a string in a list of strings. Returns a value in the range of 1 to N if the string is in the string list @@ -1023,10 +1031,11 @@ def find_in_set(string: Expr, string_list: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["b"]}) >>> result = df.select( - ... dfn.functions.find_in_set(dfn.col("a"), dfn.lit("a,b,c")).alias("pos")) + ... dfn.functions.find_in_set(dfn.col("a"), "a,b,c").alias("pos")) >>> result.collect_column("pos")[0].as_py() 2 """ + string_list = coerce_to_expr(string_list) return Expr(f.find_in_set(string.expr, string_list.expr)) @@ -1102,7 +1111,7 @@ def initcap(string: Expr) -> Expr: return Expr(f.initcap(string.expr)) -def instr(string: Expr, substring: Expr) -> Expr: +def instr(string: Expr, substring: Expr | str) -> Expr: """Finds the position from where the ``substring`` matches the ``string``. See Also: @@ -1158,31 +1167,33 @@ def least(*args: Expr) -> Expr: return Expr(f.least(*exprs)) -def left(string: Expr, n: Expr) -> Expr: +def left(string: Expr, n: Expr | int) -> Expr: """Returns the first ``n`` characters in the ``string``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["the cat"]}) >>> left_df = df.select( - ... dfn.functions.left(dfn.col("a"), dfn.lit(3)).alias("left")) + ... dfn.functions.left(dfn.col("a"), 3).alias("left")) >>> left_df.collect_column("left")[0].as_py() 'the' """ + n = coerce_to_expr(n) return Expr(f.left(string.expr, n.expr)) -def levenshtein(string1: Expr, string2: Expr) -> Expr: +def levenshtein(string1: Expr, string2: Expr | str) -> Expr: """Returns the Levenshtein distance between the two given strings. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["kitten"]}) >>> result = df.select( - ... dfn.functions.levenshtein(dfn.col("a"), dfn.lit("sitting")).alias("d")) + ... dfn.functions.levenshtein(dfn.col("a"), "sitting").alias("d")) >>> result.collect_column("d")[0].as_py() 3 """ + string2 = coerce_to_expr(string2) return Expr(f.levenshtein(string1.expr, string2.expr)) @@ -1199,18 +1210,19 @@ def ln(arg: Expr) -> Expr: return Expr(f.ln(arg.expr)) -def log(base: Expr, num: Expr) -> Expr: +def log(base: Expr | int | float, num: Expr) -> Expr: # noqa: PYI041 """Returns the logarithm of a number for a particular ``base``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [100.0]}) >>> result = df.select( - ... dfn.functions.log(dfn.lit(10.0), dfn.col("a")).alias("log") + ... dfn.functions.log(10.0, dfn.col("a")).alias("log") ... ) >>> result.collect_column("log")[0].as_py() 2.0 """ + base = coerce_to_expr(base) return Expr(f.log(base.expr, num.expr)) @@ -1253,7 +1265,7 @@ def lower(arg: Expr) -> Expr: return Expr(f.lower(arg.expr)) -def lpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: +def lpad(string: Expr, count: Expr | int, characters: Expr | str | None = None) -> Expr: """Add left padding to a string. 
Extends the string to length length by prepending the characters fill (a @@ -1264,9 +1276,7 @@ def lpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["the cat", "a hat"]}) >>> lpad_df = df.select( - ... dfn.functions.lpad( - ... dfn.col("a"), dfn.lit(6) - ... ).alias("lpad")) + ... dfn.functions.lpad(dfn.col("a"), 6).alias("lpad")) >>> lpad_df.collect_column("lpad")[0].as_py() 'the ca' >>> lpad_df.collect_column("lpad")[1].as_py() @@ -1274,12 +1284,13 @@ def lpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: >>> result = df.select( ... dfn.functions.lpad( - ... dfn.col("a"), dfn.lit(10), characters=dfn.lit(".") + ... dfn.col("a"), 10, characters="." ... ).alias("lpad")) >>> result.collect_column("lpad")[0].as_py() '...the cat' """ - characters = characters if characters is not None else Expr.literal(" ") + count = coerce_to_expr(count) + characters = coerce_to_expr(characters if characters is not None else " ") return Expr(f.lpad(string.expr, count.expr, characters.expr)) @@ -1374,7 +1385,10 @@ def octet_length(arg: Expr) -> Expr: def overlay( - string: Expr, substring: Expr, start: Expr, length: Expr | None = None + string: Expr, + substring: Expr | str, + start: Expr | int, + length: Expr | int | None = None, ) -> Expr: """Replace a substring with a new substring. @@ -1385,13 +1399,15 @@ def overlay( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["abcdef"]}) >>> result = df.select( - ... dfn.functions.overlay(dfn.col("a"), dfn.lit("XY"), dfn.lit(3), - ... dfn.lit(2)).alias("o")) + ... dfn.functions.overlay(dfn.col("a"), "XY", 3, 2).alias("o")) >>> result.collect_column("o")[0].as_py() 'abXYef' """ + substring = coerce_to_expr(substring) + start = coerce_to_expr(start) if length is None: return Expr(f.overlay(string.expr, substring.expr, start.expr)) + length = coerce_to_expr(length) return Expr(f.overlay(string.expr, substring.expr, start.expr, length.expr)) @@ -1411,7 +1427,7 @@ def pi() -> Expr: return Expr(f.pi()) -def position(string: Expr, substring: Expr) -> Expr: +def position(string: Expr, substring: Expr | str) -> Expr: """Finds the position from where the ``substring`` matches the ``string``. See Also: @@ -1420,22 +1436,23 @@ def position(string: Expr, substring: Expr) -> Expr: return strpos(string, substring) -def power(base: Expr, exponent: Expr) -> Expr: +def power(base: Expr, exponent: Expr | int | float) -> Expr: # noqa: PYI041 """Returns ``base`` raised to the power of ``exponent``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [2.0]}) >>> result = df.select( - ... dfn.functions.power(dfn.col("a"), dfn.lit(3.0)).alias("pow") + ... dfn.functions.power(dfn.col("a"), 3.0).alias("pow") ... ) >>> result.collect_column("pow")[0].as_py() 8.0 """ + exponent = coerce_to_expr(exponent) return Expr(f.power(base.expr, exponent.expr)) -def pow(base: Expr, exponent: Expr) -> Expr: +def pow(base: Expr, exponent: Expr | int | float) -> Expr: # noqa: PYI041 """Returns ``base`` raised to the power of ``exponent``. See Also: @@ -1460,7 +1477,9 @@ def radians(arg: Expr) -> Expr: return Expr(f.radians(arg.expr)) -def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: +def regexp_like( + string: Expr, regex: Expr | str, flags: Expr | str | None = None +) -> Expr: r"""Find if any regular expression (regex) matches exist. 
Tests a string using a regular expression returning true if at least one match, @@ -1470,9 +1489,7 @@ def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello123"]}) >>> result = df.select( - ... dfn.functions.regexp_like( - ... dfn.col("a"), dfn.lit("\\d+") - ... ).alias("m") + ... dfn.functions.regexp_like(dfn.col("a"), "\\d+").alias("m") ... ) >>> result.collect_column("m")[0].as_py() True @@ -1481,19 +1498,24 @@ def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> result = df.select( ... dfn.functions.regexp_like( - ... dfn.col("a"), dfn.lit("HELLO"), - ... flags=dfn.lit("i"), + ... dfn.col("a"), "HELLO", flags="i", ... ).alias("m") ... ) >>> result.collect_column("m")[0].as_py() True """ - if flags is not None: - flags = flags.expr - return Expr(f.regexp_like(string.expr, regex.expr, flags)) + regex = coerce_to_expr(regex) + flags = coerce_to_expr_or_none(flags) + return Expr( + f.regexp_like( + string.expr, regex.expr, flags.expr if flags is not None else None + ) + ) -def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: +def regexp_match( + string: Expr, regex: Expr | str, flags: Expr | str | None = None +) -> Expr: r"""Perform regular expression (regex) matching. Returns an array with each element containing the leftmost-first match of the @@ -1503,9 +1525,7 @@ def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello 42 world"]}) >>> result = df.select( - ... dfn.functions.regexp_match( - ... dfn.col("a"), dfn.lit("(\\d+)") - ... ).alias("m") + ... dfn.functions.regexp_match(dfn.col("a"), "(\\d+)").alias("m") ... ) >>> result.collect_column("m")[0].as_py() ['42'] @@ -1514,20 +1534,26 @@ def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: >>> result = df.select( ... dfn.functions.regexp_match( - ... dfn.col("a"), dfn.lit("(HELLO)"), - ... flags=dfn.lit("i"), + ... dfn.col("a"), "(HELLO)", flags="i", ... ).alias("m") ... ) >>> result.collect_column("m")[0].as_py() ['hello'] """ - if flags is not None: - flags = flags.expr - return Expr(f.regexp_match(string.expr, regex.expr, flags)) + regex = coerce_to_expr(regex) + flags = coerce_to_expr_or_none(flags) + return Expr( + f.regexp_match( + string.expr, regex.expr, flags.expr if flags is not None else None + ) + ) def regexp_replace( - string: Expr, pattern: Expr, replacement: Expr, flags: Expr | None = None + string: Expr, + pattern: Expr | str, + replacement: Expr | str, + flags: Expr | str | None = None, ) -> Expr: r"""Replaces substring(s) matching a PCRE-like regular expression. @@ -1542,8 +1568,7 @@ def regexp_replace( >>> df = ctx.from_pydict({"a": ["hello 42"]}) >>> result = df.select( ... dfn.functions.regexp_replace( - ... dfn.col("a"), dfn.lit("\\d+"), - ... dfn.lit("XX") + ... dfn.col("a"), "\\d+", "XX" ... ).alias("r") ... ) >>> result.collect_column("r")[0].as_py() @@ -1554,20 +1579,30 @@ def regexp_replace( >>> df = ctx.from_pydict({"a": ["a1 b2 c3"]}) >>> result = df.select( ... dfn.functions.regexp_replace( - ... dfn.col("a"), dfn.lit("\\d+"), - ... dfn.lit("X"), flags=dfn.lit("g"), + ... dfn.col("a"), "\\d+", "X", flags="g", ... ).alias("r") ... 
) >>> result.collect_column("r")[0].as_py() 'aX bX cX' """ - if flags is not None: - flags = flags.expr - return Expr(f.regexp_replace(string.expr, pattern.expr, replacement.expr, flags)) + pattern = coerce_to_expr(pattern) + replacement = coerce_to_expr(replacement) + flags = coerce_to_expr_or_none(flags) + return Expr( + f.regexp_replace( + string.expr, + pattern.expr, + replacement.expr, + flags.expr if flags is not None else None, + ) + ) def regexp_count( - string: Expr, pattern: Expr, start: Expr | None = None, flags: Expr | None = None + string: Expr, + pattern: Expr | str, + start: Expr | int | None = None, + flags: Expr | str | None = None, ) -> Expr: """Returns the number of matches in a string. @@ -1578,9 +1613,7 @@ def regexp_count( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["abcabc"]}) >>> result = df.select( - ... dfn.functions.regexp_count( - ... dfn.col("a"), dfn.lit("abc") - ... ).alias("c")) + ... dfn.functions.regexp_count(dfn.col("a"), "abc").alias("c")) >>> result.collect_column("c")[0].as_py() 2 @@ -1589,25 +1622,31 @@ def regexp_count( >>> result = df.select( ... dfn.functions.regexp_count( - ... dfn.col("a"), dfn.lit("ABC"), - ... start=dfn.lit(4), flags=dfn.lit("i"), + ... dfn.col("a"), "ABC", start=4, flags="i", ... ).alias("c")) >>> result.collect_column("c")[0].as_py() 1 """ - if flags is not None: - flags = flags.expr - start = start.expr if start is not None else start - return Expr(f.regexp_count(string.expr, pattern.expr, start, flags)) + pattern = coerce_to_expr(pattern) + start = coerce_to_expr_or_none(start) + flags = coerce_to_expr_or_none(flags) + return Expr( + f.regexp_count( + string.expr, + pattern.expr, + start.expr if start is not None else None, + flags.expr if flags is not None else None, + ) + ) def regexp_instr( values: Expr, - regex: Expr, - start: Expr | None = None, - n: Expr | None = None, - flags: Expr | None = None, - sub_expr: Expr | None = None, + regex: Expr | str, + start: Expr | int | None = None, + n: Expr | int | None = None, + flags: Expr | str | None = None, + sub_expr: Expr | int | None = None, ) -> Expr: r"""Returns the position of a regular expression match in a string. @@ -1623,9 +1662,7 @@ def regexp_instr( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello 42 world"]}) >>> result = df.select( - ... dfn.functions.regexp_instr( - ... dfn.col("a"), dfn.lit("\\d+") - ... ).alias("pos") + ... dfn.functions.regexp_instr(dfn.col("a"), "\\d+").alias("pos") ... ) >>> result.collect_column("pos")[0].as_py() 7 @@ -1636,9 +1673,8 @@ def regexp_instr( >>> df = ctx.from_pydict({"a": ["abc ABC abc"]}) >>> result = df.select( ... dfn.functions.regexp_instr( - ... dfn.col("a"), dfn.lit("abc"), - ... start=dfn.lit(2), n=dfn.lit(1), - ... flags=dfn.lit("i"), + ... dfn.col("a"), "abc", + ... start=2, n=1, flags="i", ... ).alias("pos") ... ) >>> result.collect_column("pos")[0].as_py() @@ -1648,56 +1684,58 @@ def regexp_instr( >>> result = df.select( ... dfn.functions.regexp_instr( - ... dfn.col("a"), dfn.lit("(abc)"), - ... sub_expr=dfn.lit(1), + ... dfn.col("a"), "(abc)", sub_expr=1, ... ).alias("pos") ... 
) >>> result.collect_column("pos")[0].as_py() 1 """ - start = start.expr if start is not None else None - n = n.expr if n is not None else None - flags = flags.expr if flags is not None else None - sub_expr = sub_expr.expr if sub_expr is not None else None + regex = coerce_to_expr(regex) + start = coerce_to_expr_or_none(start) + n = coerce_to_expr_or_none(n) + flags = coerce_to_expr_or_none(flags) + sub_expr = coerce_to_expr_or_none(sub_expr) return Expr( f.regexp_instr( values.expr, regex.expr, - start, - n, - flags, - sub_expr, + start.expr if start is not None else None, + n.expr if n is not None else None, + flags.expr if flags is not None else None, + sub_expr.expr if sub_expr is not None else None, ) ) -def repeat(string: Expr, n: Expr) -> Expr: +def repeat(string: Expr, n: Expr | int) -> Expr: """Repeats the ``string`` to ``n`` times. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["ha"]}) >>> result = df.select( - ... dfn.functions.repeat(dfn.col("a"), dfn.lit(3)).alias("r")) + ... dfn.functions.repeat(dfn.col("a"), 3).alias("r")) >>> result.collect_column("r")[0].as_py() 'hahaha' """ + n = coerce_to_expr(n) return Expr(f.repeat(string.expr, n.expr)) -def replace(string: Expr, from_val: Expr, to_val: Expr) -> Expr: +def replace(string: Expr, from_val: Expr | str, to_val: Expr | str) -> Expr: """Replaces all occurrences of ``from_val`` with ``to_val`` in the ``string``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello world"]}) >>> result = df.select( - ... dfn.functions.replace(dfn.col("a"), dfn.lit("world"), - ... dfn.lit("there")).alias("r")) + ... dfn.functions.replace(dfn.col("a"), "world", "there").alias("r")) >>> result.collect_column("r")[0].as_py() 'hello there' """ + from_val = coerce_to_expr(from_val) + to_val = coerce_to_expr(to_val) return Expr(f.replace(string.expr, from_val.expr, to_val.expr)) @@ -1714,39 +1752,39 @@ def reverse(arg: Expr) -> Expr: return Expr(f.reverse(arg.expr)) -def right(string: Expr, n: Expr) -> Expr: +def right(string: Expr, n: Expr | int) -> Expr: """Returns the last ``n`` characters in the ``string``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) - >>> result = df.select(dfn.functions.right(dfn.col("a"), dfn.lit(3)).alias("r")) + >>> result = df.select(dfn.functions.right(dfn.col("a"), 3).alias("r")) >>> result.collect_column("r")[0].as_py() 'llo' """ + n = coerce_to_expr(n) return Expr(f.right(string.expr, n.expr)) -def round(value: Expr, decimal_places: Expr | None = None) -> Expr: +def round(value: Expr, decimal_places: Expr | int | None = None) -> Expr: """Round the argument to the nearest integer. If the optional ``decimal_places`` is specified, round to the nearest number of decimal places. You can specify a negative number of decimal places. For example - ``round(lit(125.2345), lit(-2))`` would yield a value of ``100.0``. + ``round(lit(125.2345), -2)`` would yield a value of ``100.0``. 
Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [1.567]}) - >>> result = df.select(dfn.functions.round(dfn.col("a"), dfn.lit(2)).alias("r")) + >>> result = df.select(dfn.functions.round(dfn.col("a"), 2).alias("r")) >>> result.collect_column("r")[0].as_py() 1.57 """ - if decimal_places is None: - decimal_places = Expr.literal(0) + decimal_places = coerce_to_expr(decimal_places if decimal_places is not None else 0) return Expr(f.round(value.expr, decimal_places.expr)) -def rpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: +def rpad(string: Expr, count: Expr | int, characters: Expr | str | None = None) -> Expr: """Add right padding to a string. Extends the string to length length by appending the characters fill (a space @@ -1756,11 +1794,12 @@ def rpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hi"]}) >>> result = df.select( - ... dfn.functions.rpad(dfn.col("a"), dfn.lit(5), dfn.lit("!")).alias("r")) + ... dfn.functions.rpad(dfn.col("a"), 5, "!").alias("r")) >>> result.collect_column("r")[0].as_py() 'hi!!!' """ - characters = characters if characters is not None else Expr.literal(" ") + count = coerce_to_expr(count) + characters = coerce_to_expr(characters if characters is not None else " ") return Expr(f.rpad(string.expr, count.expr, characters.expr)) @@ -1876,7 +1915,7 @@ def sinh(arg: Expr) -> Expr: return Expr(f.sinh(arg.expr)) -def split_part(string: Expr, delimiter: Expr, index: Expr) -> Expr: +def split_part(string: Expr, delimiter: Expr | str, index: Expr | int) -> Expr: """Split a string and return one part. Splits a string based on a delimiter and picks out the desired field based @@ -1886,12 +1925,12 @@ def split_part(string: Expr, delimiter: Expr, index: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["a,b,c"]}) >>> result = df.select( - ... dfn.functions.split_part( - ... dfn.col("a"), dfn.lit(","), dfn.lit(2) - ... ).alias("s")) + ... dfn.functions.split_part(dfn.col("a"), ",", 2).alias("s")) >>> result.collect_column("s")[0].as_py() 'b' """ + delimiter = coerce_to_expr(delimiter) + index = coerce_to_expr(index) return Expr(f.split_part(string.expr, delimiter.expr, index.expr)) @@ -1908,49 +1947,52 @@ def sqrt(arg: Expr) -> Expr: return Expr(f.sqrt(arg.expr)) -def starts_with(string: Expr, prefix: Expr) -> Expr: +def starts_with(string: Expr, prefix: Expr | str) -> Expr: """Returns true if string starts with prefix. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello_from_datafusion"]}) >>> result = df.select( - ... dfn.functions.starts_with(dfn.col("a"), dfn.lit("hello")).alias("sw")) + ... dfn.functions.starts_with(dfn.col("a"), "hello").alias("sw")) >>> result.collect_column("sw")[0].as_py() True """ + prefix = coerce_to_expr(prefix) return Expr(f.starts_with(string.expr, prefix.expr)) -def strpos(string: Expr, substring: Expr) -> Expr: +def strpos(string: Expr, substring: Expr | str) -> Expr: """Finds the position from where the ``substring`` matches the ``string``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.strpos(dfn.col("a"), dfn.lit("llo")).alias("pos")) + ... 
dfn.functions.strpos(dfn.col("a"), "llo").alias("pos")) >>> result.collect_column("pos")[0].as_py() 3 """ + substring = coerce_to_expr(substring) return Expr(f.strpos(string.expr, substring.expr)) -def substr(string: Expr, position: Expr) -> Expr: +def substr(string: Expr, position: Expr | int) -> Expr: """Substring from the ``position`` to the end. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.substr(dfn.col("a"), dfn.lit(3)).alias("s")) + ... dfn.functions.substr(dfn.col("a"), 3).alias("s")) >>> result.collect_column("s")[0].as_py() 'llo' """ + position = coerce_to_expr(position) return Expr(f.substr(string.expr, position.expr)) -def substr_index(string: Expr, delimiter: Expr, count: Expr) -> Expr: +def substr_index(string: Expr, delimiter: Expr | str, count: Expr | int) -> Expr: """Returns an indexed substring. The return will be the ``string`` from before ``count`` occurrences of @@ -1960,27 +2002,28 @@ def substr_index(string: Expr, delimiter: Expr, count: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["a.b.c"]}) >>> result = df.select( - ... dfn.functions.substr_index(dfn.col("a"), dfn.lit("."), - ... dfn.lit(2)).alias("s")) + ... dfn.functions.substr_index(dfn.col("a"), ".", 2).alias("s")) >>> result.collect_column("s")[0].as_py() 'a.b' """ + delimiter = coerce_to_expr(delimiter) + count = coerce_to_expr(count) return Expr(f.substr_index(string.expr, delimiter.expr, count.expr)) -def substring(string: Expr, position: Expr, length: Expr) -> Expr: +def substring(string: Expr, position: Expr | int, length: Expr | int) -> Expr: """Substring from the ``position`` with ``length`` characters. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello world"]}) >>> result = df.select( - ... dfn.functions.substring( - ... dfn.col("a"), dfn.lit(1), dfn.lit(5) - ... ).alias("s")) + ... dfn.functions.substring(dfn.col("a"), 1, 5).alias("s")) >>> result.collect_column("s")[0].as_py() 'hello' """ + position = coerce_to_expr(position) + length = coerce_to_expr(length) return Expr(f.substring(string.expr, position.expr, length.expr)) @@ -2053,7 +2096,7 @@ def current_timestamp() -> Expr: return now() -def to_char(arg: Expr, formatter: Expr) -> Expr: +def to_char(arg: Expr, formatter: Expr | str) -> Expr: """Returns a string representation of a date, time, timestamp or duration. For usage of ``formatter`` see the rust chrono package ``strftime`` package. @@ -2066,16 +2109,17 @@ def to_char(arg: Expr, formatter: Expr) -> Expr: >>> result = df.select( ... dfn.functions.to_char( ... dfn.functions.to_timestamp(dfn.col("a")), - ... dfn.lit("%Y/%m/%d"), + ... "%Y/%m/%d", ... ).alias("formatted") ... ) >>> result.collect_column("formatted")[0].as_py() '2021/01/01' """ + formatter = coerce_to_expr(formatter) return Expr(f.to_char(arg.expr, formatter.expr)) -def date_format(arg: Expr, formatter: Expr) -> Expr: +def date_format(arg: Expr, formatter: Expr | str) -> Expr: """Returns a string representation of a date, time, timestamp or duration. See Also: @@ -2287,7 +2331,7 @@ def current_time() -> Expr: return Expr(f.current_time()) -def datepart(part: Expr, date: Expr) -> Expr: +def datepart(part: Expr | str, date: Expr) -> Expr: """Return a specified part of a date. 
See Also: @@ -2296,22 +2340,28 @@ def datepart(part: Expr, date: Expr) -> Expr: return date_part(part, date) -def date_part(part: Expr, date: Expr) -> Expr: +def date_part(part: Expr | str, date: Expr) -> Expr: """Extracts a subfield from the date. + Args: + part: The part of the date to extract. Must be one of ``"year"``, + ``"month"``, ``"day"``, ``"hour"``, ``"minute"``, ``"second"``, etc. + date: The date expression to extract from. + Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["2021-07-15T00:00:00"]}) >>> df = df.select(dfn.functions.to_timestamp(dfn.col("a")).alias("a")) >>> result = df.select( - ... dfn.functions.date_part(dfn.lit("year"), dfn.col("a")).alias("y")) + ... dfn.functions.date_part("year", dfn.col("a")).alias("y")) >>> result.collect_column("y")[0].as_py() 2021 """ + part = coerce_to_expr(part) return Expr(f.date_part(part.expr, date.expr)) -def extract(part: Expr, date: Expr) -> Expr: +def extract(part: Expr | str, date: Expr) -> Expr: """Extracts a subfield from the date. See Also: @@ -2320,25 +2370,29 @@ def extract(part: Expr, date: Expr) -> Expr: return date_part(part, date) -def date_trunc(part: Expr, date: Expr) -> Expr: +def date_trunc(part: Expr | str, date: Expr) -> Expr: """Truncates the date to a specified level of precision. + Args: + part: The precision to truncate to. Must be one of ``"year"``, + ``"month"``, ``"day"``, ``"hour"``, ``"minute"``, ``"second"``, etc. + date: The date expression to truncate. + Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["2021-07-15T12:34:56"]}) >>> df = df.select(dfn.functions.to_timestamp(dfn.col("a")).alias("a")) >>> result = df.select( - ... dfn.functions.date_trunc( - ... dfn.lit("month"), dfn.col("a") - ... ).alias("t") + ... dfn.functions.date_trunc("month", dfn.col("a")).alias("t") ... ) >>> str(result.collect_column("t")[0].as_py()) '2021-07-01 00:00:00' """ + part = coerce_to_expr(part) return Expr(f.date_trunc(part.expr, date.expr)) -def datetrunc(part: Expr, date: Expr) -> Expr: +def datetrunc(part: Expr | str, date: Expr) -> Expr: """Truncates the date to a specified level of precision. See Also: @@ -2399,18 +2453,19 @@ def make_time(hour: Expr, minute: Expr, second: Expr) -> Expr: return Expr(f.make_time(hour.expr, minute.expr, second.expr)) -def translate(string: Expr, from_val: Expr, to_val: Expr) -> Expr: +def translate(string: Expr, from_val: Expr | str, to_val: Expr | str) -> Expr: """Replaces the characters in ``from_val`` with the counterpart in ``to_val``. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello"]}) >>> result = df.select( - ... dfn.functions.translate(dfn.col("a"), dfn.lit("helo"), - ... dfn.lit("HELO")).alias("t")) + ... dfn.functions.translate(dfn.col("a"), "helo", "HELO").alias("t")) >>> result.collect_column("t")[0].as_py() 'HELLO' """ + from_val = coerce_to_expr(from_val) + to_val = coerce_to_expr(to_val) return Expr(f.translate(string.expr, from_val.expr, to_val.expr)) @@ -2427,27 +2482,24 @@ def trim(arg: Expr) -> Expr: return Expr(f.trim(arg.expr)) -def trunc(num: Expr, precision: Expr | None = None) -> Expr: +def trunc(num: Expr, precision: Expr | int | None = None) -> Expr: """Truncate the number toward zero with optional precision. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [1.567]}) >>> result = df.select( - ... dfn.functions.trunc( - ... dfn.col("a") - ... ).alias("t")) + ... 
dfn.functions.trunc(dfn.col("a")).alias("t")) >>> result.collect_column("t")[0].as_py() 1.0 >>> result = df.select( - ... dfn.functions.trunc( - ... dfn.col("a"), precision=dfn.lit(2) - ... ).alias("t")) + ... dfn.functions.trunc(dfn.col("a"), precision=2).alias("t")) >>> result.collect_column("t")[0].as_py() 1.56 """ if precision is not None: + precision = coerce_to_expr(precision) return Expr(f.trunc(num.expr, precision.expr)) return Expr(f.trunc(num.expr)) @@ -2928,17 +2980,18 @@ def list_dims(array: Expr) -> Expr: return array_dims(array) -def array_element(array: Expr, n: Expr) -> Expr: +def array_element(array: Expr, n: Expr | int) -> Expr: """Extracts the element with the index n from the array. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[10, 20, 30]]}) >>> result = df.select( - ... dfn.functions.array_element(dfn.col("a"), dfn.lit(2)).alias("result")) + ... dfn.functions.array_element(dfn.col("a"), 2).alias("result")) >>> result.collect_column("result")[0].as_py() 20 """ + n = coerce_to_expr(n) return Expr(f.array_element(array.expr, n.expr)) @@ -2964,7 +3017,7 @@ def list_empty(array: Expr) -> Expr: return array_empty(array) -def array_extract(array: Expr, n: Expr) -> Expr: +def array_extract(array: Expr, n: Expr | int) -> Expr: """Extracts the element with the index n from the array. See Also: @@ -2973,7 +3026,7 @@ def array_extract(array: Expr, n: Expr) -> Expr: return array_element(array, n) -def list_element(array: Expr, n: Expr) -> Expr: +def list_element(array: Expr, n: Expr | int) -> Expr: """Extracts the element with the index n from the array. See Also: @@ -2982,7 +3035,7 @@ def list_element(array: Expr, n: Expr) -> Expr: return array_element(array, n) -def list_extract(array: Expr, n: Expr) -> Expr: +def list_extract(array: Expr, n: Expr | int) -> Expr: """Extracts the element with the index n from the array. See Also: @@ -3332,22 +3385,24 @@ def list_remove(array: Expr, element: Expr) -> Expr: return array_remove(array, element) -def array_remove_n(array: Expr, element: Expr, max: Expr) -> Expr: +def array_remove_n(array: Expr, element: Expr, max: Expr | int) -> Expr: """Removes the first ``max`` elements from the array equal to the given value. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2, 1, 1]]}) >>> result = df.select( - ... dfn.functions.array_remove_n(dfn.col("a"), dfn.lit(1), - ... dfn.lit(2)).alias("result")) + ... dfn.functions.array_remove_n( + ... dfn.col("a"), dfn.lit(1), 2 + ... ).alias("result")) >>> result.collect_column("result")[0].as_py() [2, 1] """ + max = coerce_to_expr(max) return Expr(f.array_remove_n(array.expr, element.expr, max.expr)) -def list_remove_n(array: Expr, element: Expr, max: Expr) -> Expr: +def list_remove_n(array: Expr, element: Expr, max: Expr | int) -> Expr: """Removes the first ``max`` elements from the array equal to the given value. See Also: @@ -3381,21 +3436,22 @@ def list_remove_all(array: Expr, element: Expr) -> Expr: return array_remove_all(array, element) -def array_repeat(element: Expr, count: Expr) -> Expr: +def array_repeat(element: Expr, count: Expr | int) -> Expr: """Returns an array containing ``element`` ``count`` times. Examples: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [1]}) >>> result = df.select( - ... dfn.functions.array_repeat(dfn.lit(3), dfn.lit(3)).alias("result")) + ... 
dfn.functions.array_repeat(dfn.lit(3), 3).alias("result")) >>> result.collect_column("result")[0].as_py() [3, 3, 3] """ + count = coerce_to_expr(count) return Expr(f.array_repeat(element.expr, count.expr)) -def list_repeat(element: Expr, count: Expr) -> Expr: +def list_repeat(element: Expr, count: Expr | int) -> Expr: """Returns an array containing ``element`` ``count`` times. See Also: @@ -3428,7 +3484,7 @@ def list_replace(array: Expr, from_val: Expr, to_val: Expr) -> Expr: return array_replace(array, from_val, to_val) -def array_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Expr: +def array_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr | int) -> Expr: """Replace ``n`` occurrences of ``from_val`` with ``to_val``. Replaces the first ``max`` occurrences of the specified element with another @@ -3438,15 +3494,17 @@ def array_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Exp >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2, 1, 1]]}) >>> result = df.select( - ... dfn.functions.array_replace_n(dfn.col("a"), dfn.lit(1), dfn.lit(9), - ... dfn.lit(2)).alias("result")) + ... dfn.functions.array_replace_n( + ... dfn.col("a"), dfn.lit(1), dfn.lit(9), 2 + ... ).alias("result")) >>> result.collect_column("result")[0].as_py() [9, 2, 9, 1] """ + max = coerce_to_expr(max) return Expr(f.array_replace_n(array.expr, from_val.expr, to_val.expr, max.expr)) -def list_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Expr: +def list_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr | int) -> Expr: """Replace ``n`` occurrences of ``from_val`` with ``to_val``. Replaces the first ``max`` occurrences of the specified element with another @@ -3529,7 +3587,10 @@ def list_sort(array: Expr, descending: bool = False, null_first: bool = False) - def array_slice( - array: Expr, begin: Expr, end: Expr, stride: Expr | None = None + array: Expr, + begin: Expr | int, + end: Expr | int, + stride: Expr | int | None = None, ) -> Expr: """Returns a slice of the array. @@ -3537,9 +3598,7 @@ def array_slice( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2, 3, 4]]}) >>> result = df.select( - ... dfn.functions.array_slice( - ... dfn.col("a"), dfn.lit(2), dfn.lit(3) - ... ).alias("result")) + ... dfn.functions.array_slice(dfn.col("a"), 2, 3).alias("result")) >>> result.collect_column("result")[0].as_py() [2, 3] @@ -3547,18 +3606,27 @@ def array_slice( >>> result = df.select( ... dfn.functions.array_slice( - ... dfn.col("a"), dfn.lit(1), dfn.lit(4), - ... stride=dfn.lit(2), + ... dfn.col("a"), 1, 4, stride=2, ... ).alias("result")) >>> result.collect_column("result")[0].as_py() [1, 3] """ - if stride is not None: - stride = stride.expr - return Expr(f.array_slice(array.expr, begin.expr, end.expr, stride)) + begin = coerce_to_expr(begin) + end = coerce_to_expr(end) + stride = coerce_to_expr_or_none(stride) + return Expr( + f.array_slice( + array.expr, + begin.expr, + end.expr, + stride.expr if stride is not None else None, + ) + ) -def list_slice(array: Expr, begin: Expr, end: Expr, stride: Expr | None = None) -> Expr: +def list_slice( + array: Expr, begin: Expr | int, end: Expr | int, stride: Expr | int | None = None +) -> Expr: """Returns a slice of the array. 
See Also: @@ -3650,7 +3718,7 @@ def list_except(array1: Expr, array2: Expr) -> Expr: return array_except(array1, array2) -def array_resize(array: Expr, size: Expr, value: Expr) -> Expr: +def array_resize(array: Expr, size: Expr | int, value: Expr) -> Expr: """Returns an array with the specified size filled. If ``size`` is greater than the ``array`` length, the additional entries will @@ -3660,15 +3728,15 @@ def array_resize(array: Expr, size: Expr, value: Expr) -> Expr: >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [[1, 2]]}) >>> result = df.select( - ... dfn.functions.array_resize(dfn.col("a"), dfn.lit(4), - ... dfn.lit(0)).alias("result")) + ... dfn.functions.array_resize(dfn.col("a"), 4, dfn.lit(0)).alias("result")) >>> result.collect_column("result")[0].as_py() [1, 2, 0, 0] """ + size = coerce_to_expr(size) return Expr(f.array_resize(array.expr, size.expr, value.expr)) -def list_resize(array: Expr, size: Expr, value: Expr) -> Expr: +def list_resize(array: Expr, size: Expr | int, value: Expr) -> Expr: """Returns an array with the specified size filled. If ``size`` is greater than the ``array`` length, the additional entries will be @@ -3822,7 +3890,7 @@ def list_zip(*arrays: Expr) -> Expr: def string_to_array( - string: Expr, delimiter: Expr, null_string: Expr | None = None + string: Expr, delimiter: Expr | str, null_string: Expr | str | None = None ) -> Expr: """Splits a string based on a delimiter and returns an array of parts. @@ -3832,9 +3900,7 @@ def string_to_array( >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": ["hello,world"]}) >>> result = df.select( - ... dfn.functions.string_to_array( - ... dfn.col("a"), dfn.lit(","), - ... ).alias("result")) + ... dfn.functions.string_to_array(dfn.col("a"), ",").alias("result")) >>> result.collect_column("result")[0].as_py() ['hello', 'world'] @@ -3842,17 +3908,24 @@ def string_to_array( >>> result = df.select( ... dfn.functions.string_to_array( - ... dfn.col("a"), dfn.lit(","), null_string=dfn.lit("world"), + ... dfn.col("a"), ",", null_string="world", ... ).alias("result")) >>> result.collect_column("result")[0].as_py() ['hello', None] """ - null_expr = null_string.expr if null_string is not None else None - return Expr(f.string_to_array(string.expr, delimiter.expr, null_expr)) + delimiter = coerce_to_expr(delimiter) + null_string = coerce_to_expr_or_none(null_string) + return Expr( + f.string_to_array( + string.expr, + delimiter.expr, + null_string.expr if null_string is not None else None, + ) + ) def string_to_list( - string: Expr, delimiter: Expr, null_string: Expr | None = None + string: Expr, delimiter: Expr | str, null_string: Expr | str | None = None ) -> Expr: """Splits a string based on a delimiter and returns an array of parts. 
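Taken together, the functions.py changes above make the following call styles equivalent. A minimal sketch based on the updated doctests and the new tests in this diff (the `column`/`literal`/`f` import style follows python/tests/test_functions.py):

    from datafusion import SessionContext, column, literal
    from datafusion import functions as f

    ctx = SessionContext()
    df = ctx.from_pydict({"s": ["a,b,c"], "x": [1.567]})

    # Previously, scalar arguments had to be wrapped explicitly:
    df.select(f.split_part(column("s"), literal(","), literal(2)).alias("p"))

    # Now native str/int values are accepted and coerced via coerce_to_expr():
    df.select(f.split_part(column("s"), ",", 2).alias("p"))
    df.select(f.round(column("x"), 2).alias("r"))

    # Keyword-only options accept native values too; coerce_to_expr_or_none()
    # keeps an omitted None meaning "not supplied":
    df2 = ctx.from_pydict({"a": ["abcabc"]})
    df2.select(f.regexp_count(column("a"), "abc", start=4, flags="i").alias("c"))  # 1
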
diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 13c05a9e6..e0ebdbae5 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -964,12 +964,12 @@ def test_csv_read_options_builder_pattern(): options = ( CsvReadOptions() - .with_has_header(False) # noqa: FBT003 + .with_has_header(False) .with_delimiter("|") .with_quote("'") .with_schema_infer_max_records(2000) - .with_truncated_rows(True) # noqa: FBT003 - .with_newlines_in_values(True) # noqa: FBT003 + .with_truncated_rows(True) + .with_newlines_in_values(True) .with_file_extension(".tsv") ) assert options.has_header is False diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 091fa9b56..9e2f791ea 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -3426,10 +3426,18 @@ def test_fill_null_all_null_column(ctx): assert result.column(1).to_pylist() == ["filled", "filled", "filled"] +_slow_udf_started = threading.Event() + + @udf([pa.int64()], pa.int64(), "immutable") def slow_udf(x: pa.Array) -> pa.Array: - # This must be longer than the check interval in wait_for_future - time.sleep(2.0) + _slow_udf_started.set() + # Sleep in small increments so Python's eval loop checks for pending + # async exceptions (like KeyboardInterrupt via PyThreadState_SetAsyncExc) + # between iterations. A single long time.sleep() is a C call where async + # exceptions are not checked on all Python versions (notably 3.11). + for _ in range(200): + time.sleep(0.01) return x @@ -3463,6 +3471,7 @@ def test_collect_or_stream_interrupted(slow_query, as_c_stream): # noqa: C901 P if as_c_stream: reader = pa.RecordBatchReader.from_stream(df) + _slow_udf_started.clear() read_started = threading.Event() read_exception = [] read_thread_id = None @@ -3474,6 +3483,14 @@ def trigger_interrupt(): msg = f"Read operation did not start within {max_wait_time} seconds" raise RuntimeError(msg) + # For slow_query tests, wait until the UDF is actually executing Python + # bytecode before sending the interrupt. PyThreadState_SetAsyncExc only + # delivers exceptions when the thread is in the Python eval loop, not + # while in native (Rust/C) code. 
+ if slow_query and not _slow_udf_started.wait(timeout=max_wait_time): + msg = f"UDF did not start within {max_wait_time} seconds" + raise RuntimeError(msg) + if read_thread_id is None: msg = "Cannot get read thread ID" raise RuntimeError(msg) diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index d046eb48c..8aa791ae1 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -53,6 +53,8 @@ TransactionEnd, TransactionStart, Values, + coerce_to_expr, + coerce_to_expr_or_none, ensure_expr, ensure_expr_list, ) @@ -1030,12 +1032,55 @@ def test_ensure_expr_list_bytearray(): ensure_expr_list(bytearray(b"a")) +def test_coerce_to_expr_passes_expr_through(): + e = col("a") + result = coerce_to_expr(e) + assert isinstance(result, type(e)) + assert str(result) == str(e) + + +def test_coerce_to_expr_wraps_int(): + result = coerce_to_expr(42) + assert isinstance(result, type(lit(42))) + + +def test_coerce_to_expr_wraps_str(): + result = coerce_to_expr("hello") + assert isinstance(result, type(lit("hello"))) + + +def test_coerce_to_expr_wraps_float(): + result = coerce_to_expr(3.14) + assert isinstance(result, type(lit(3.14))) + + +def test_coerce_to_expr_wraps_bool(): + result = coerce_to_expr(True) + assert isinstance(result, type(lit(True))) + + +def test_coerce_to_expr_or_none_returns_none(): + assert coerce_to_expr_or_none(None) is None + + +def test_coerce_to_expr_or_none_wraps_value(): + result = coerce_to_expr_or_none(42) + assert isinstance(result, type(lit(42))) + + +def test_coerce_to_expr_or_none_passes_expr_through(): + e = col("a") + result = coerce_to_expr_or_none(e) + assert isinstance(result, type(e)) + assert str(result) == str(e) + + @pytest.mark.parametrize( "value", [ # Boolean - pa.scalar(True, type=pa.bool_()), # noqa: FBT003 - pa.scalar(False, type=pa.bool_()), # noqa: FBT003 + pa.scalar(True, type=pa.bool_()), + pa.scalar(False, type=pa.bool_()), # Integers - signed pa.scalar(127, type=pa.int8()), pa.scalar(-128, type=pa.int8()), diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 11e94af1c..d9781b1fb 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -2099,3 +2099,96 @@ def test_gen_series_with_step(): f.gen_series(literal(1), literal(10), literal(3)).alias("v") ).collect() assert result[0].column(0)[0].as_py() == [1, 4, 7, 10] + + +class TestPythonicNativeTypes: + """Tests for accepting native Python types instead of requiring lit().""" + + def test_split_part_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["a,b,c"]}) + result = df.select(f.split_part(column("a"), ",", 2).alias("s")).collect() + assert result[0].column(0)[0].as_py() == "b" + + def test_encode_native_str(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["hello"]}) + result = df.select(f.encode(column("a"), "base64").alias("e")).collect() + assert result[0].column(0)[0].as_py() == "aGVsbG8" + + def test_date_part_native_str(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["2021-07-15T00:00:00"]}) + df = df.select(f.to_timestamp(column("a")).alias("a")) + result = df.select(f.date_part("year", column("a")).alias("y")).collect() + assert result[0].column(0)[0].as_py() == 2021 + + def test_date_trunc_native_str(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["2021-07-15T12:34:56"]}) + df = df.select(f.to_timestamp(column("a")).alias("a")) + result = df.select(f.date_trunc("month", column("a")).alias("t")).collect() + assert str(result[0].column(0)[0].as_py()) 
== "2021-07-01 00:00:00" + + def test_left_native_int(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["the cat"]}) + result = df.select(f.left(column("a"), 3).alias("l")).collect() + assert result[0].column(0)[0].as_py() == "the" + + def test_round_native_int(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1.567]}) + result = df.select(f.round(column("a"), 2).alias("r")).collect() + assert result[0].column(0)[0].as_py() == 1.57 + + def test_regexp_count_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["abcabc"]}) + result = df.select( + f.regexp_count(column("a"), "abc", start=4, flags="i").alias("c") + ).collect() + assert result[0].column(0)[0].as_py() == 1 + + def test_log_native_int(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": [100.0]}) + result = df.select(f.log(10, column("a")).alias("l")).collect() + assert result[0].column(0)[0].as_py() == 2.0 + + def test_power_native_int(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": [2.0]}) + result = df.select(f.power(column("a"), 3).alias("p")).collect() + assert result[0].column(0)[0].as_py() == 8.0 + + def test_array_slice_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": [[1, 2, 3, 4]]}) + result = df.select(f.array_slice(column("a"), 2, 3).alias("s")).collect() + assert result[0].column(0)[0].as_py() == [2, 3] + + def test_string_to_array_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["hello,NA,world"]}) + result = df.select( + f.string_to_array(column("a"), ",", null_string="NA").alias("v") + ).collect() + assert result[0].column(0)[0].as_py() == ["hello", None, "world"] + + def test_regexp_replace_native(self): + ctx = SessionContext() + df = ctx.from_pydict({"a": ["a1 b2 c3"]}) + result = df.select( + f.regexp_replace(column("a"), r"\d+", "X", flags="g").alias("r") + ).collect() + assert result[0].column(0)[0].as_py() == "aX bX cX" + + def test_backward_compat_with_lit(self): + """Verify that existing code using lit() still works.""" + ctx = SessionContext() + df = ctx.from_pydict({"a": ["a,b,c"]}) + result = df.select( + f.split_part(column("a"), literal(","), literal(2)).alias("s") + ).collect() + assert result[0].column(0)[0].as_py() == "b"