forked from lance-format/lance
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata.py
More file actions
486 lines (417 loc) · 16.5 KB
/
data.py
File metadata and controls
486 lines (417 loc) · 16.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The Lance Authors
"""Read Lance dataset as torch DataPipe."""
# PEP-585. Can be removed after deprecating python 3.8 support.
from __future__ import annotations
import json
import logging
import math
import warnings
from pathlib import Path
from typing import (
Any,
Dict,
Iterable,
List,
Literal,
Optional,
Protocol,
Union,
)
import pyarrow as pa
import lance
from lance._dataset.cache import CachedDataset
from lance.dependencies import _check_for_numpy, torch
from lance.dependencies import numpy as np
from ..sampler import (
FullScanSampler,
Sampler,
ShardedBatchSampler,
ShardedFragmentSampler,
maybe_sample,
)
from .dist import get_global_rank, get_global_world_size
__all__ = ["LanceDataset", "SafeLanceDataset", "get_safe_loader"]
class ToTensorFn(Protocol):
    """Signature for user-supplied batch-to-tensor converter callables.

    A converter receives either a ``pa.RecordBatch`` or a dict mapping column
    name to array/value (the dict form is used on the blob-API path), plus the
    keyword arguments ``hf_converter`` and ``use_blob_api`` forwarded by
    ``LanceDataset.__iter__``.  It returns either a single tensor or a dict of
    column name -> tensor.
    """

    def __call__(
        self,
        batch: Union[pa.RecordBatch, Dict[str, Any]],
        *,
        hf_converter: Optional[dict] = None,
        use_blob_api: bool = False,
        **kwargs: Any,
    ) -> Union[dict[str, torch.Tensor], torch.Tensor]: ...
def _fsl_to_tensor(arr: pa.FixedSizeListArray, dimension: int) -> torch.Tensor:
    """Flatten a fixed-size-list array into a 2D ``torch.Tensor``.

    ``FixedSizeListArray.values`` does not take the array's offset/length into
    account, so the flat child values are sliced explicitly before reshaping
    into ``(len(arr), dimension)``.
    """
    flat = arr.values.slice(arr.offset * dimension, len(arr) * dimension)
    matrix = flat.to_numpy(zero_copy_only=False).reshape(-1, dimension)
    return torch.from_numpy(matrix)
def _to_tensor(
    batch: Union[pa.RecordBatch, Dict[str, pa.Array]],
    *,
    uint64_as_int64: bool = True,
    hf_converter: Optional[dict] = None,
    use_blob_api: bool = False,
    **kwargs,
) -> Union[dict[str, torch.Tensor], torch.Tensor]:
    """Convert a pyarrow RecordBatch to torch Tensor.

    Supported column types:

    - ``FixedShapeTensorType`` with numeric values (unwrapped to its storage
      fixed-size-list array first);
    - fixed-size-list of integer/float values (converted via
      :func:`_fsl_to_tensor` into a 2D tensor);
    - plain integer / float / boolean arrays;
    - anything else is delegated to ``hf_converter.to_pytorch`` when a
      HuggingFace converter is available.

    Parameters
    ----------
    batch : pa.RecordBatch or dict of column name to array
        Input batch; the dict form is produced on the blob-API path.
    uint64_as_int64 : bool, optional
        If True (default), cast uint64 results to int64, since most torch ops
        and indexing expect signed 64-bit integers.
    hf_converter : optional
        Fallback converter for columns carrying HuggingFace type metadata.
    use_blob_api : bool, optional
        When True, a column holding ``lance.BlobFile`` objects raises
        ``NotImplementedError`` — blob decoding needs a custom ``to_tensor_fn``.

    Returns
    -------
    dict[str, torch.Tensor] or torch.Tensor
        One tensor per column; a single-column batch is unwrapped and returned
        as a bare tensor.

    Raises
    ------
    NotImplementedError
        For blob-file columns (see ``use_blob_api`` above).
    ValueError
        For column types none of the branches above can handle.
    """
    ret = {}
    # Normalize column listing across the two accepted batch shapes.
    cols = (
        batch.column_names if isinstance(batch, pa.RecordBatch) else list(batch.keys())
    )
    for col in cols:
        arr: pa.Array = batch[col]
        # Blob columns arrive as plain Python lists of lance.BlobFile; there
        # is no generic tensor representation for them.
        if (
            use_blob_api
            and isinstance(arr, list)
            and arr
            and isinstance(arr[0], lance.BlobFile)
        ):
            raise NotImplementedError(
                'Need user-provided "to_tensor_fn" for Blob files'
            )
        tensor: Optional[torch.Tensor] = None
        # Fixed-shape tensors are stored as fixed-size lists; unwrap to the
        # storage array so the FSL branch below handles them.
        if (isinstance(arr.type, pa.FixedShapeTensorType)) and (
            pa.types.is_floating(arr.type.value_type)
            or pa.types.is_integer(arr.type.value_type)
        ):
            arr = arr.storage
        if (pa.types.is_fixed_size_list(arr.type)) and (
            pa.types.is_floating(arr.type.value_type)
            or pa.types.is_integer(arr.type.value_type)
        ):
            tensor = _fsl_to_tensor(arr, arr.type.list_size)
        elif (
            pa.types.is_integer(arr.type)
            or pa.types.is_floating(arr.type)
            or pa.types.is_boolean(arr.type)
        ):
            tensor = torch.from_numpy(arr.to_numpy(zero_copy_only=False))
            if uint64_as_int64 and tensor.dtype == torch.uint64:
                tensor = tensor.to(torch.int64)
        elif hf_converter is not None:
            # Columns with HuggingFace metadata (e.g. class labels, images).
            tensor = hf_converter.to_pytorch(col, arr)
        if tensor is None:
            raise ValueError(
                "Only support FixedSizeList<f16/f32/f64> or "
                + f"numeric values, got: {arr.type}"
            )
        # Drop the Arrow reference eagerly to release buffers sooner.
        del arr
        ret[col] = tensor
    # Single-column convenience: return the bare tensor, not a dict.
    if len(ret) == 1:
        t = next(iter(ret.values()))
        del ret
        return t
    return ret
class TensorDataset(torch.utils.data.Dataset):
    """A PyTorch Dataset that wraps over a tensor, returns in batches.

    Unlike `torch.utils.data.TensorDataset`, this has the same behavior as
    LanceDataset that it yields tensor in batches.
    """

    def __init__(
        self, data: Union[torch.Tensor, np.ndarray], batch_size: int, *args, **kwargs
    ):
        """
        Parameters
        ----------
        data : torch.Tensor or np.ndarray
            2D data to serve; numpy input is converted to a torch tensor.
        batch_size : int
            Number of rows returned per ``__getitem__`` call.
        """
        super().__init__(*args, **kwargs)
        if _check_for_numpy(data) and isinstance(data, np.ndarray):
            data = torch.from_numpy(data)
        self._data: torch.Tensor = data
        self._batch_size = batch_size

    def __repr__(self):
        return "LanceTensorDataset"

    def __len__(self) -> int:
        # Number of batches, including a final partial batch.
        return math.ceil(self._data.shape[0] / self._batch_size)

    def __getitem__(self, idx: int) -> torch.Tensor:
        if idx >= len(self):
            # BUGFIX: raise IndexError, not StopIteration.  The sequence
            # protocol (and map-style DataLoaders) terminate on IndexError;
            # a StopIteration escaping into a generator frame becomes a
            # RuntimeError under PEP 479.
            raise IndexError(f"index {idx} out of range for {len(self)} batches")
        start = idx * self._batch_size
        end = min((idx + 1) * self._batch_size, self._data.shape[0])
        return self._data[start:end, :]
def concat_batches(bs):
    """Merge a sequence of record batches into one, column by column.

    All batches must share the schema of ``bs[0]``, which is reused for the
    merged batch.
    """
    merged_columns = []
    for col_idx in range(bs[0].num_columns):
        pieces = [batch.columns[col_idx] for batch in bs]
        merged_columns.append(pa.concat_arrays(pieces))
    return pa.RecordBatch.from_arrays(merged_columns, schema=bs[0].schema)
def _buffer_arrow_batches(
it: Iterable[pa.RecordBatch],
buffer_size: int = 10240,
) -> Iterable[pa.RecordBatch]:
buffer = []
cur_size = 0
for item in it:
if cur_size > 0 and cur_size + item.num_rows > buffer_size:
if len(buffer) == 1:
# Most of the time, we are in the happy situation where we have a single
# batch to yield.
yield buffer[0]
else:
yield concat_batches(buffer)
buffer = []
cur_size = 0
buffer.append(item)
cur_size += item.num_rows
if buffer:
yield concat_batches(buffer)
class LanceDataset(torch.utils.data.IterableDataset):
    """PyTorch :class:`torch.utils.data.IterableDataset` over lance dataset."""

    def __init__(
        self,
        dataset: Union[torch.utils.data.Dataset, str, Path],
        batch_size: int,
        *args,
        dataset_options: Optional[Dict[str, Any]] = None,
        columns: Optional[Union[List[str], Dict[str, str]]] = None,
        filter: Optional[str] = None,
        samples: Optional[int] = 0,
        cache: Optional[Union[str, bool]] = None,
        with_row_id: bool = False,
        rank: Optional[int] = None,
        world_size: Optional[int] = None,
        shard_granularity: Optional[Literal["fragment", "batch"]] = None,
        batch_readahead: int = 16,
        to_tensor_fn: Optional[ToTensorFn] = _to_tensor,
        sampler: Optional[Sampler] = None,
        auto_detect_rank: bool = True,
        **kwargs,
    ):
        """Use PyTorch Dataset API to read Lance dataset.

        Parameters
        ----------
        dataset : Union[torch.utils.data.Dataset, str, Path]
            Lance dataset to read. Can be URI, path, or an initialized Lance Dataset.
        batch_size : int
            Batch size to yield for each iteration.
        dataset_options : dict, optional
            Extra keyword arguments forwarded to ``lance.dataset`` when
            ``dataset`` is given as a URI or path.
        columns : list of str, optional
            The names of the column to read, by default None, which means reading all
            columns.
        filter : str, optional
            If set, only rows that match the filter will be read. Currently, this
            can only be used when doing a full scan (`sampler` is None and
            shard_granularity is None or "fragment" and `samples` is None)
        samples : int, optional
            If set to a non-zero value, read at most this many sampled rows
            instead of scanning the full dataset.
        cache : str or bool, optional
            If set true, the dataset will be cached on disk from the first iteration.
            The following iterations will read from the cache.
        with_row_id : bool, optional
            If set true, the returned batch will have an additional column named
            `_rowid` that contains the row id of the batch.
        rank : int, optional (deprecated)
            If set, the rank (idx) of this process in distributed training / inference.
        world_size : int, optional (deprecated)
            If set, the total number of processes in distributed training / inference.
        shard_granularity : str, optional
            The basic unit of sharding data. If set to "fragment", each worker will
            get a subset of fragments.
            If set to "batch", it will read the "batch" interleave with the
            same fragments.
        batch_readahead : int, optional
            The number of batches to read ahead in different (Rust) threads for each
            fragment.
        to_tensor_fn : callable, optional
            A function that converts a pyarrow RecordBatch to torch.Tensor.
            Should accept a batch (RecordBatch or Dict[str, pa.Array]) as the first
            argument, plus optional keyword arguments ``hf_converter`` and
            ``use_blob_api``.
        sampler : callable, optional
            A function that samples the dataset.
        auto_detect_rank : bool = True, optional
            If set true, the rank and world_size will be detected automatically.
        """
        super().__init__()
        if isinstance(dataset, (str, Path)):
            dataset_options = dataset_options or {}
            dataset = lance.dataset(dataset, **dataset_options)
        self.dataset = dataset
        self.columns = columns
        self.batch_size = batch_size
        self.samples: Optional[int] = samples
        self.filter = filter
        self.with_row_id = with_row_id
        self.batch_readahead = batch_readahead
        self._to_tensor_fn = to_tensor_fn
        self._hf_converter = None
        # List of projected blob-encoded column names.  BUGFIX: previously the
        # result of the helper was assigned over the helper's own name
        # (self._blob_columns = self._blob_columns()), shadowing the method on
        # the instance; the helper is now _detect_blob_columns.
        self._blob_columns = self._detect_blob_columns()
        if self._blob_columns:
            # Blob content is fetched per batch via take_blobs, which needs
            # the row ids of each scanned batch.
            self.with_row_id = True
        # As Shared Dataset
        self.shard_granularity = shard_granularity
        self.rank = rank
        self.world_size = world_size
        if rank is not None and world_size is not None:
            warnings.warn("rank and world_size are deprecated", DeprecationWarning)
        self.sampler: Optional[Sampler] = sampler
        # Dataset with huggingface metadata
        if (
            dataset.schema.metadata is not None
            and (hf_meta := dataset.schema.metadata.get(b"huggingface")) is not None
        ):
            from ..hf import HuggingFaceConverter

            hf_ds_info = json.loads(hf_meta)
            self._hf_converter = HuggingFaceConverter(hf_ds_info)
        self.cache = cache
        self.cached_ds: Optional[CachedDataset] = None
        self._auto_detect_rank = auto_detect_rank

    def __repr__(self) -> str:
        return f"LanceTorchDataset({self.dataset.uri}, size={self.samples})"

    @property
    def schema(self) -> pa.Schema:
        """Schema of the projected columns (the full schema if no projection)."""
        if not self.columns:
            return self.dataset.schema
        fields = [self.dataset.schema.field(col) for col in self.columns]
        return pa.schema(fields, metadata=self.dataset.schema.metadata)

    def __iter__(self):
        if self.sampler is None:
            # Resolve rank / world_size: explicit (deprecated) args win,
            # otherwise auto-detect from the distributed environment.
            if self.rank is not None:
                rank = self.rank
            elif self._auto_detect_rank:
                rank = get_global_rank()
            else:
                rank = None
            if self.world_size is not None:
                world_size = self.world_size
            elif self._auto_detect_rank:
                world_size = get_global_world_size()
            else:
                world_size = None
            if self.shard_granularity is None:
                if rank is not None and world_size is not None:
                    sampler = ShardedFragmentSampler(rank=rank, world_size=world_size)
                else:
                    sampler = FullScanSampler()
            elif self.shard_granularity == "batch":
                sampler = ShardedBatchSampler(rank, world_size)
            elif self.shard_granularity == "fragment":
                sampler = ShardedFragmentSampler(rank, world_size)
            else:
                # BUGFIX: the message placeholder was never filled in.
                raise ValueError(
                    f"Invalid shard_granularity: {self.shard_granularity}"
                )
        else:
            sampler = self.sampler

        # Blob columns cannot be scanned directly; drop them from the
        # projection and fetch them through the blob API per batch below.
        projected_columns = self.columns or self.dataset.schema.names
        if self._blob_columns:
            projected_columns = [
                c for c in projected_columns if c not in self._blob_columns
            ]

        stream: Iterable[pa.RecordBatch]
        if self.cached_ds:
            stream = self.cached_ds
        else:
            if self.samples:
                raw_stream = maybe_sample(
                    self.dataset,
                    n=self.samples,
                    columns=projected_columns,
                    batch_size=self.batch_size,
                    filt=self.filter,
                )
            else:
                raw_stream = sampler(
                    self.dataset,
                    columns=projected_columns,
                    filter=self.filter,
                    batch_size=self.batch_size,
                    with_row_id=self.with_row_id,
                    batch_readahead=self.batch_readahead,
                )
            # Re-batch the stream so each yielded batch has ~batch_size rows.
            stream = _buffer_arrow_batches(raw_stream, buffer_size=self.batch_size)
            if self.cache:
                # First pass populates the cache; later passes replay from it.
                self.cached_ds = CachedDataset(stream, cache=self.cache)
                stream = self.cached_ds

        use_blob_api = bool(self._blob_columns)
        for batch in stream:
            if use_blob_api:
                # Convert to a dict batch and attach BlobFile objects fetched
                # by row id for each blob column.
                dict_batch = {}
                assert "_rowid" in batch.column_names
                row_ids = batch["_rowid"]
                for col in batch.column_names:
                    dict_batch[col] = batch[col]
                for col in self._blob_columns:
                    dict_batch[col] = self.dataset.take_blobs(
                        ids=row_ids.to_pylist(), blob_column=col
                    )
                batch = dict_batch
            if self._to_tensor_fn is not None:
                batch = self._to_tensor_fn(
                    batch, hf_converter=self._hf_converter, use_blob_api=use_blob_api
                )
            yield batch
            # Release the batch before pulling the next one to cap memory.
            del batch

    def _detect_blob_columns(self) -> List[str]:
        """Return the projected column names that are Large Blob encoded."""
        cols = self.columns
        if not cols:
            cols = self.dataset.schema.names
        blob_cols = []
        for col in cols:
            field = self.dataset.schema.field(col)
            if (
                field.type == pa.large_binary()
                and field.metadata is not None
                and field.metadata.get(b"lance-encoding:blob") == b"true"
            ):
                logging.debug("Column %s is a Large Blob column", col)
                blob_cols.append(col)
        return blob_cols
class SafeLanceDataset(torch.utils.data.Dataset):
    """Map-style dataset over a Lance URI that is multiprocessing-safe.

    The main process only performs a short-lived metadata probe; each worker
    lazily opens its own dataset handle on first access, so no native handle
    is shared across process boundaries.
    """

    def __init__(self, uri, *, dataset_options=None, **kwargs):
        """
        Parameters
        ----------
        uri : str
            URI or path of the Lance dataset.
        dataset_options : dict, optional
            Extra keyword arguments forwarded to ``lance.dataset``.
        """
        super().__init__(**kwargs)
        self.uri = uri
        self.dataset_options = dataset_options or {}
        self._len = self._safe_preload()
        # Opened lazily per worker process in __getitems__.
        self._ds = None

    def _safe_preload(self):
        """Main-process safe metadata loading"""
        ds = lance.dataset(self.uri, **self.dataset_options)
        length = ds.count_rows()
        # Drop the handle immediately so it is never forked into workers.
        del ds
        return length

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        return self.__getitems__([idx])[0]

    def __getitems__(self, indices):
        """Batch data fetching with worker-safe initialization

        Args:
            indices: List[int] - batch indices to retrieve
        Returns:
            List[dict] - samples in original data format
        """
        if self._ds is None:
            # Worker-process initialization.  BUGFIX: honor dataset_options
            # here too — previously the worker re-opened the dataset without
            # the options that _safe_preload used.
            import os

            self._ds = lance.dataset(self.uri, **self.dataset_options)
            # Use logging (not print) so library consumers control verbosity.
            logging.debug("Worker %s initialized dataset", os.getpid())
        # Leverage native batch reading
        batch = self._ds.take(indices)
        # Convert to python-native format
        return batch.to_pylist()
def get_safe_loader(dataset, batch_size=32, num_workers=4, **kwargs):
    """Create a DataLoader with safe multiprocessing defaults

    Args:
        dataset: Input dataset object
        batch_size: Number of samples per batch (default=32)
        num_workers: Number of parallel data workers (default=4)
        **kwargs: Additional DataLoader arguments. Note:
            - Forces 'spawn' context for Windows compatibility
            - Sets persistent_workers=True by default
            - User-provided args override defaults
            - Multiprocessing-only options are omitted when num_workers == 0,
              since DataLoader rejects them for in-process loading
    Returns:
        Configured DataLoader instance with process-safe settings
    """
    loader_args = {
        "batch_size": batch_size,
        "num_workers": num_workers,
    }
    # BUGFIX: only apply multiprocessing settings when workers are actually
    # used — DataLoader raises ValueError if persistent_workers=True or a
    # multiprocessing_context is given with num_workers == 0.
    if num_workers > 0:
        # Force spawn context for Windows/multiprocessing compatibility.
        loader_args["multiprocessing_context"] = torch.multiprocessing.get_context(
            "spawn"
        )
        loader_args["persistent_workers"] = kwargs.pop("persistent_workers", True)
    # User-provided arguments take priority.
    loader_args.update(kwargs)
    return torch.utils.data.DataLoader(dataset, **loader_args)