Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Simplify lu_factor by moving logic to dpnp_lu_factor
  • Loading branch information
vlad-perevezentsev committed Aug 28, 2025
commit 02b645e8500fb7635756a2f01929077a08e10e4a
90 changes: 16 additions & 74 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,7 @@ def _is_empty_2d(arr):
return arr.size == 0 and numpy.prod(arr.shape[-2:]) == 0


def _lu_factor(a, res_type, scipy=False, overwrite_a=False):
def _lu_factor(a, res_type):
"""
Compute pivoted LU decomposition.

Expand Down Expand Up @@ -1050,41 +1050,18 @@ def _lu_factor(a, res_type, scipy=False, overwrite_a=False):

a_usm_arr = dpnp.get_usm_ndarray(a)

if not scipy:
# Internal use case (e.g., det(), slogdet()). Always copy.
# `a` must be copied because getrf destroys the input matrix
a_h = dpnp.empty_like(a, order="C", dtype=res_type)
# `a` must be copied because getrf destroys the input matrix
a_h = dpnp.empty_like(a, order="C", dtype=res_type)

# use DPCTL tensor function to fill the сopy of the input array
# from the input array
ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
src=a_usm_arr,
dst=a_h.get_array(),
sycl_queue=a_sycl_queue,
depends=_manager.submitted_events,
)
_manager.add_event_pair(ht_ev, copy_ev)

else:
# SciPy-compatible behavior
# Copy is required if:
# - overwrite_a is False (always copy),
# - dtype mismatch,
# - not F-contiguous,
# - not writeable
if not overwrite_a or _is_copy_required(a, res_type):
a_h = dpnp.empty_like(a, order="F", dtype=res_type)
ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
src=a_usm_arr,
dst=a_h.get_array(),
sycl_queue=a_sycl_queue,
depends=_manager.submitted_events,
)
_manager.add_event_pair(ht_ev, copy_ev)
else:
# input is suitable for in-place modification
a_h = a
copy_ev = None
# use DPCTL tensor function to fill the сopy of the input array
# from the input array
ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
src=a_usm_arr,
dst=a_h.get_array(),
sycl_queue=a_sycl_queue,
depends=_manager.submitted_events,
)
_manager.add_event_pair(ht_ev, copy_ev)

ipiv_h = dpnp.empty(
n,
Expand All @@ -1102,55 +1079,20 @@ def _lu_factor(a, res_type, scipy=False, overwrite_a=False):
a_h.get_array(),
ipiv_h.get_array(),
dev_info_h,
depends=[copy_ev] if copy_ev is not None else [],
depends=[copy_ev],
)
_manager.add_event_pair(ht_ev, getrf_ev)

# Return list if called in SciPy-compatible mode
# else dpnp.ndarray
if scipy:
dev_info_array = dev_info_h
else:
dev_info_array = dpnp.array(
dev_info_h, usm_type=a_usm_type, sycl_queue=a_sycl_queue
)
dev_info_array = dpnp.array(
dev_info_h, usm_type=a_usm_type, sycl_queue=a_sycl_queue
)

# Return a tuple containing the factorized matrix 'a_h',
# pivot indices 'ipiv_h'
# and the status 'dev_info_h' from the LAPACK getrf call
return (a_h, ipiv_h, dev_info_array)


def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
"""Compute pivoted LU decomposition."""

res_type = _common_type(a)

# accommodate empty arrays
if a.size == 0:
lu = dpnp.empty_like(a)
piv = dpnp.arange(0, dtype=dpnp.int32)
return lu, piv

if check_finite:
if not dpnp.isfinite(a).all():
raise ValueError("array must not contain infs or NaNs")

lu, piv, dev_info = _lu_factor(
a, res_type, scipy=True, overwrite_a=overwrite_a
)

if any(dev_info):
diag_nums = ", ".join(str(v) for v in dev_info if v > 0)
warn(
f"Diagonal number {diag_nums} are exactly zero. Singular matrix.",
RuntimeWarning,
stacklevel=2,
)

return lu, piv


def _multi_dot(arrays, order, i, j, out=None):
"""Actually do the multiplication with the given order."""

Expand Down