Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Search reductions use correct branch for float16
constexpr branch logic accounted for floating point types but not sycl::half,
which meant NaNs were not propagating for float16 data
  • Loading branch information
ndgrigorian committed Nov 3, 2023
commit 119d43d565e86e5055dbd85416ed8d9df1bb2e47
20 changes: 15 additions & 5 deletions dpctl/tensor/libtensor/include/kernels/reductions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3476,7 +3476,9 @@ struct SequentialSearchReduction
idx_val = static_cast<outT>(m);
}
}
else if constexpr (std::is_floating_point_v<argT>) {
else if constexpr (std::is_floating_point_v<argT> ||
std::is_same_v<argT, sycl::half>)
{
if (val < red_val || std::isnan(val)) {
red_val = val;
idx_val = static_cast<outT>(m);
Expand All @@ -3501,7 +3503,9 @@ struct SequentialSearchReduction
idx_val = static_cast<outT>(m);
}
}
else if constexpr (std::is_floating_point_v<argT>) {
else if constexpr (std::is_floating_point_v<argT> ||
std::is_same_v<argT, sycl::half>)
{
if (val > red_val || std::isnan(val)) {
red_val = val;
idx_val = static_cast<outT>(m);
Expand Down Expand Up @@ -3789,7 +3793,9 @@ struct CustomSearchReduction
}
}
}
else if constexpr (std::is_floating_point_v<argT>) {
else if constexpr (std::is_floating_point_v<argT> ||
std::is_same_v<argT, sycl::half>)
{
if (val < local_red_val || std::isnan(val)) {
local_red_val = val;
if constexpr (!First) {
Expand Down Expand Up @@ -3833,7 +3839,9 @@ struct CustomSearchReduction
}
}
}
else if constexpr (std::is_floating_point_v<argT>) {
else if constexpr (std::is_floating_point_v<argT> ||
std::is_same_v<argT, sycl::half>)
{
if (val > local_red_val || std::isnan(val)) {
local_red_val = val;
if constexpr (!First) {
Expand Down Expand Up @@ -3876,7 +3884,9 @@ struct CustomSearchReduction
? local_idx
: idx_identity_;
}
else if constexpr (std::is_floating_point_v<argT>) {
else if constexpr (std::is_floating_point_v<argT> ||
std::is_same_v<argT, sycl::half>)
{
// equality does not hold for NaNs, so check here
local_idx =
(red_val_over_wg == local_red_val || std::isnan(local_red_val))
Expand Down