forked from lance-format/lance
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path__init__.py
More file actions
31 lines (25 loc) · 878 Bytes
/
__init__.py
File metadata and controls
31 lines (25 loc) · 878 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The Lance Authors
from typing import Optional
from lance.dependencies import torch
def preferred_device(device: Optional[str] = None) -> "torch.device":
    """Get the preferred device for computation.

    Parameters
    ----------
    device : str or torch.device, optional
        Device to use for computation. A string such as ``"cpu"`` or
        ``"cuda:0"`` is converted with ``torch.device``; any other non-None
        value (e.g. an existing ``torch.device``) is returned unchanged.
        If None, the device is detected automatically, preferring CUDA,
        then Apple MPS, then CPU.

    Returns
    -------
    device : torch.device
        Device to use for computation.
    """
    if device is not None:
        # An explicit choice always wins; only strings need converting.
        if isinstance(device, str):
            return torch.device(device)
        return device

    if torch.cuda.is_available():
        return torch.device("cuda")

    # ``torch.backends.mps`` may be missing on older PyTorch releases, so
    # probe for it rather than assuming the attribute exists; otherwise
    # auto-detection would crash instead of falling back to CPU.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available() and mps.is_built():
        return torch.device("mps")

    return torch.device("cpu")