forked from lance-format/lance
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_huggingface.py
More file actions
76 lines (57 loc) · 2.12 KB
/
test_huggingface.py
File metadata and controls
76 lines (57 loc) · 2.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The Lance Authors
from pathlib import Path
import lance
import numpy as np
import pyarrow as pa
import pytest
datasets = pytest.importorskip("datasets")
pil = pytest.importorskip("PIL")
def test_write_hf_dataset(tmp_path: Path):
    """Round-trip a HuggingFace dataset slice through ``lance.write_dataset``.

    Checks that all rows land in the Lance dataset and that the Arrow
    schema derived from the HF features is preserved by the write.
    """
    # Only the first 50 training rows are fetched to keep the test light.
    source = datasets.load_dataset(
        "cornell-movie-review-data/rotten_tomatoes",
        split="train[:50]",
    )
    written = lance.write_dataset(source, tmp_path)
    assert written.count_rows() == 50
    assert written.schema == source.features.arrow_schema
@pytest.mark.cuda
def test_image_hf_dataset(tmp_path: Path):
    """Write an HF ``Image``-feature dataset and read it back via torch.

    A single all-zero 16x16 RGB image is stored; the torch ``LanceDataset``
    reader should yield it back as a PIL image with all-zero pixels.
    """
    import lance.torch.data

    # One zero-filled uint8 image, typed as a HuggingFace Image feature.
    hf_ds = datasets.Dataset.from_dict(
        {"i": [np.zeros(shape=(16, 16, 3), dtype=np.uint8)]},
        features=datasets.Features({"i": datasets.Image()}),
    )
    lance_ds = lance.write_dataset(hf_ds, tmp_path)

    torch_ds = lance.torch.data.LanceDataset(
        lance_ds,
        columns=["i"],
        batch_size=8,
    )
    first_batch = next(iter(torch_ds))
    assert len(first_batch) == 1
    for img in first_batch:
        # Each element must decode to a PIL image whose pixels are all zero.
        assert isinstance(img, pil.Image.Image)
        assert np.all(np.array(img) == 0)
def test_iterable_dataset(tmp_path: Path):
    """Write a streaming HF ``IterableDataset`` to Lance, whole and in batches."""

    def rows():
        # The IterableDataset yields one dict of scalar values per example.
        yield {"text": "Good", "label": 0}
        yield {"text": "Bad", "label": 1}

    schema = pa.schema([("text", pa.string()), ("label", pa.int64())])
    hf_features = datasets.Features.from_arrow_schema(schema)
    streaming = datasets.IterableDataset.from_generator(rows, features=hf_features)

    # One-shot write of the stream; batch size here is governed by
    # max_rows_per_group on the Lance side.
    full_write = lance.write_dataset(streaming, tmp_path / "ds1.lance")
    assert full_write.count_rows() == 2
    assert full_write.schema == streaming.features.arrow_schema

    # Manual batching: seed an empty dataset, then append one batch at a time
    # to control the streaming batch size explicitly.
    incremental = lance.write_dataset(
        pa.Table.from_arrays([[], []], schema=schema), tmp_path / "ds2.lance"
    )
    for batch in streaming.iter(batch_size=1):
        # Appending each single-row batch shouldn't fail.
        incremental = lance.write_dataset(
            batch, tmp_path / "ds2.lance", mode="append"
        )

    assert len(full_write) == len(incremental)
    assert full_write.schema == incremental.schema