-
Notifications
You must be signed in to change notification settings - Fork 152
Expand file tree
/
Copy pathtest_plans.py
More file actions
232 lines (188 loc) · 7.99 KB
/
test_plans.py
File metadata and controls
232 lines (188 loc) · 7.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
import pytest
from datafusion import (
ExecutionPlan,
LogicalPlan,
Metric,
MetricsSet,
SessionContext,
)
# Note: We must use CSV because memory tables are currently not supported for
# conversion to/from protobuf.
@pytest.fixture
def df():
    """Return a DataFrame over the aggregate test CSV, limited to column ``c1``.

    The DataFrame is read through a ``SessionContext`` created inside this
    fixture.
    """
    session = SessionContext()
    csv_frame = session.read_csv(path="testing/data/csv/aggregate_test_100.csv")
    return csv_frame.select("c1")
def test_logical_plan_to_proto(ctx, df) -> None:
    """Round-trip the logical and execution plans of ``df`` through proto bytes.

    The logical plan is checked by comparing collected results of the original
    and the rebuilt DataFrame; the execution plan is checked by comparing the
    string renderings of the original and the decoded plan.
    """
    # Logical plan: encode -> decode -> rebuild a DataFrame -> compare results.
    encoded_logical = df.logical_plan().to_proto()
    decoded_logical = LogicalPlan.from_proto(ctx, encoded_logical)
    rebuilt_df = ctx.create_dataframe_from_logical_plan(decoded_logical)
    assert df.collect() == rebuilt_df.collect()

    # Execution plan: encode -> decode -> compare textual form.
    source_plan = df.execution_plan()
    decoded_plan = ExecutionPlan.from_proto(ctx, source_plan.to_proto())
    assert str(source_plan) == str(decoded_plan)
def test_metrics_tree_walk() -> None:
    """Check the shape of ``collect_metrics()`` output after running a filter.

    Each entry must be a ``(str, MetricsSet)`` pair, and at least one operator
    must report exactly two output rows.
    """
    session = SessionContext()
    session.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
    query = session.sql("SELECT * FROM t WHERE column1 > 1")
    query.collect()

    collected = query.execution_plan().collect_metrics()
    assert len(collected) > 0

    output_rows_by_op: dict[str, int] = {}
    for op_name, metrics_set in collected:
        assert isinstance(op_name, str)
        assert isinstance(metrics_set, MetricsSet)
        row_count = metrics_set.output_rows
        if row_count is None:
            # Operators without an output_rows metric are not recorded.
            continue
        output_rows_by_op[op_name] = row_count

    # Of the rows (1,'a'), (2,'b'), (3,'c') only two satisfy column1 > 1, so
    # the filter operator is expected to show up with exactly 2 output rows.
    reported_counts = output_rows_by_op.values()
    assert 2 in reported_counts, (
        f"Expected an operator with output_rows=2, got {output_rows_by_op}"
    )
def test_metric_properties() -> None:
    """Check the types exposed by every ``Metric`` after running a filter query.

    For each metric: ``name`` is a non-empty string, ``partition`` is ``None``
    or an ``int``, ``value`` is ``None``, an ``int`` or a ``datetime``,
    ``labels()`` is a dict and ``repr`` is a non-empty string. The ``repr`` of
    each ``MetricsSet`` must be a string, and at least one metric must exist.
    """
    session = SessionContext()
    session.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
    query = session.sql("SELECT * FROM t WHERE column1 > 1")
    query.collect()

    none_type = type(None)
    metrics_seen = 0
    for _, metrics_set in query.execution_plan().collect_metrics():
        assert isinstance(repr(metrics_set), str)

        for metric in metrics_set.metrics():
            metrics_seen += 1
            assert isinstance(metric, Metric)

            metric_name = metric.name
            assert isinstance(metric_name, str)
            assert len(metric_name) > 0

            # None is accepted for both partition and value.
            assert isinstance(metric.partition, (none_type, int))
            assert isinstance(metric.value, (none_type, int, datetime.datetime))
            assert isinstance(metric.labels(), dict)

            rendered = repr(metric)
            assert isinstance(rendered, str)
            assert len(rendered) > 0

    assert metrics_seen > 0, "Expected at least one metric after execution"
def test_no_meaningful_metrics_before_execution() -> None:
    """Check that ``output_rows`` is empty before execution and positive after.

    Before ``collect()`` every ``MetricsSet`` must report ``output_rows`` as
    ``None`` or ``0``; after ``collect()`` at least one must report a positive
    count.
    """
    session = SessionContext()
    session.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
    query = session.sql("SELECT * FROM t WHERE column1 > 1")

    # A MetricsSet may already exist before execution because some plan nodes
    # (e.g. DataSourceExec) initialize one eagerly. What must hold is that no
    # rows have been counted yet: output_rows is either missing or zero.
    pre_run_metrics = query.execution_plan().collect_metrics()
    for _, metrics_set in pre_run_metrics:
        rows = metrics_set.output_rows
        assert rows == 0 or rows is None, (
            f"Expected 0 output_rows before execution, got {rows}"
        )

    # Run the query; afterwards some operator has to have processed rows.
    query.collect()
    positive_counts = []
    for _, metrics_set in query.execution_plan().collect_metrics():
        rows = metrics_set.output_rows
        if rows is not None and rows > 0:
            positive_counts.append(rows)
    assert len(positive_counts) >= 1, "Expected output_rows > 0 after execution"
def test_collect_partitioned_metrics() -> None:
    """Check that ``collect_partitioned()`` leaves an operator with 2 output rows."""
    session = SessionContext()
    session.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
    query = session.sql("SELECT * FROM t WHERE column1 > 1")
    query.collect_partitioned()

    # Gather every output_rows value that an operator actually reported.
    output_rows_values = []
    for _, metrics_set in query.execution_plan().collect_metrics():
        row_count = metrics_set.output_rows
        if row_count is not None:
            output_rows_values.append(row_count)

    assert 2 in output_rows_values, f"Expected 2 in {output_rows_values}"
def test_execute_stream_metrics() -> None:
    """Check that draining ``execute_stream()`` leaves an operator with 2 output rows."""
    session = SessionContext()
    session.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
    query = session.sql("SELECT * FROM t WHERE column1 > 1")

    # Consume the stream to the end; the batches themselves are not needed.
    batch_stream = query.execute_stream()
    for _batch in batch_stream:
        pass

    # Gather every output_rows value that an operator actually reported.
    output_rows_values = []
    for _, metrics_set in query.execution_plan().collect_metrics():
        row_count = metrics_set.output_rows
        if row_count is not None:
            output_rows_values.append(row_count)

    assert 2 in output_rows_values, f"Expected 2 in {output_rows_values}"
def test_execute_stream_partitioned_metrics() -> None:
    """Check that draining every partition stream leaves an operator with 2 output rows."""
    session = SessionContext()
    session.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
    query = session.sql("SELECT * FROM t WHERE column1 > 1")

    # Consume each partition's stream in turn; the batches are discarded.
    partition_streams = query.execute_stream_partitioned()
    for partition_stream in partition_streams:
        for _batch in partition_stream:
            pass

    # Gather every output_rows value that an operator actually reported.
    output_rows_values = []
    for _, metrics_set in query.execution_plan().collect_metrics():
        row_count = metrics_set.output_rows
        if row_count is not None:
            output_rows_values.append(row_count)

    assert 2 in output_rows_values, f"Expected 2 in {output_rows_values}"
def test_value_as_datetime() -> None:
    """Check ``Metric.value_as_datetime`` for timestamp and non-timestamp metrics.

    Metrics named ``start_timestamp`` or ``end_timestamp`` must yield ``None``
    or a timezone-aware ``datetime``; every other metric must yield ``None``.
    """
    session = SessionContext()
    session.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
    query = session.sql("SELECT * FROM t WHERE column1 > 1")
    query.collect()

    timestamp_names = ("start_timestamp", "end_timestamp")
    for _, metrics_set in query.execution_plan().collect_metrics():
        for metric in metrics_set.metrics():
            if metric.name not in timestamp_names:
                # Non-timestamp metrics have no datetime representation.
                assert metric.value_as_datetime is None
                continue

            moment = metric.value_as_datetime
            assert moment is None or isinstance(moment, datetime.datetime)
            if moment is not None:
                # A present timestamp must carry timezone information.
                assert moment.tzinfo is not None
def test_metric_names_and_labels() -> None:
    """Check that expected metric names are present and labels map str to str."""
    session = SessionContext()
    session.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
    query = session.sql("SELECT * FROM t WHERE column1 > 1")
    query.collect()

    all_metric_names: set[str] = set()
    for _, metrics_set in query.execution_plan().collect_metrics():
        for metric in metrics_set.metrics():
            all_metric_names.add(metric.name)

            # Both the keys and the values of the label mapping must be strings.
            for label_key, label_value in metric.labels().items():
                assert isinstance(label_key, str)
                assert isinstance(label_value, str)

    # These two standard metric names are the minimum expected after running
    # a filter query.
    assert "output_rows" in all_metric_names, (
        f"Expected 'output_rows' in {all_metric_names}"
    )
    assert "elapsed_compute" in all_metric_names, (
        f"Expected 'elapsed_compute' in {all_metric_names}"
    )
def test_collect_twice_has_metrics() -> None:
    """Check that ``output_rows`` metrics are still reported after two ``collect()`` calls."""
    session = SessionContext()
    session.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
    query = session.sql("SELECT * FROM t WHERE column1 > 1")

    # Execute the same DataFrame twice before inspecting the plan.
    for _ in range(2):
        query.collect()

    # Gather every output_rows value that an operator actually reported.
    output_rows_values = []
    for _, metrics_set in query.execution_plan().collect_metrics():
        row_count = metrics_set.output_rows
        if row_count is not None:
            output_rows_values.append(row_count)

    assert len(output_rows_values) >= 1