forked from lance-format/lance
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_torch.py
More file actions
65 lines (50 loc) · 1.9 KB
/
test_torch.py
File metadata and controls
65 lines (50 loc) · 1.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The Lance Authors
import lance
import numpy as np
import pyarrow as pa
import pytest
pytest.importorskip("torch")
from lance.torch.data import SafeLanceDataset, get_safe_loader # noqa: E402
@pytest.fixture(scope="module")
def temp_lance_dataset(tmp_path_factory):
    """Write a small Lance dataset to a temporary directory and yield its path.

    The dataset has 96 rows and two columns: ``id`` (the integers 0..95) and
    ``embedding`` (the raw bytes of 128 random float32 values per row). The
    fixture is module-scoped, so the dataset is written once per test module.

    Yields:
        The dataset location as a ``str``.
    """
    # 96 rows = 6 batches of 16, i.e. a whole number of batches at batch size 16.
    row_count = 96
    lance_path = tmp_path_factory.mktemp("lance_data") / "test_dataset.lance"

    # One opaque byte blob per row; the values themselves are never inspected.
    embeddings = []
    for _ in range(row_count):
        vector = np.random.rand(128).astype(np.float32)
        embeddings.append(vector.tobytes())

    table = pa.table({"id": range(row_count), "embedding": embeddings})
    lance.write_dataset(table, lance_path)
    yield str(lance_path)
def test_dataset_initialization(temp_lance_dataset):
    """Check that SafeLanceDataset reports its length and returns dict samples."""
    dataset = SafeLanceDataset(temp_lance_dataset)

    # The length must equal the 96 rows this test expects in the dataset.
    assert len(dataset) == 96, "Sample count should match configured size"

    # Indexing a single row must give a dict carrying both columns.
    first = dataset[0]
    assert isinstance(first, dict), "Sample should be dictionary type"
    required = {"id", "embedding"}
    assert required.issubset(first.keys()), "Missing required fields"
def test_multiprocess_loading(temp_lance_dataset, capsys):
    """Check that a two-worker loader yields every sample in batches of 16.

    Each batch's ``id`` field must have shape ``(16,)`` and the batch sizes
    must add up to 96. ``capsys`` is requested but not read by this test.
    """
    loader = get_safe_loader(
        SafeLanceDataset(temp_lance_dataset),
        num_workers=2,
        batch_size=16,
        # 96 is a multiple of 16, so no short final batch is expected.
        # NOTE(review): assumes drop_last follows torch DataLoader semantics.
        drop_last=False,
    )

    seen = 0
    for batch in loader:
        ids = batch["id"]
        assert ids.shape == (16,), "Batch dimension mismatch"
        seen += ids.shape[0]

    # Every row of the dataset must have been delivered.
    assert seen == 96, "Should load all samples"