From 078493ecc742294da21da61b972ac8ff9a7d760f Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 1 Feb 2023 13:31:50 -0600 Subject: [PATCH] Changed simplify_iteration_two_strides, simplify_iteration_three_strides The target ordering used to be based on absolute values of the first vector of strides, now it uses lexicographic ordering of tuples of absolute values of all strides involved. This enables iteration space reduction for examples where all strides in the first vector are all zero, like in the example arising from ``` dpctl.tensor.full((2,3,4,), dpctl.tensor.asarray(1)) ``` The following two invocations show that iteration space used is 1d: ``` onetrace -d -v --demangle python -c "import dpctl.tensor as dpt; x = dpt.ones((30, 40, 50), dtype='i4'); y = dpt.empty_like(x, dtype='f4'); print((x.flags, y.flags)); y[:] = x" onetrace -d -v --demangle python -c "import dpctl.tensor._tensor_impl as ti, dpctl.tensor as dpt; dpt.full((2,3,4), dpt.asarray(1, dtype='f4'))" ``` --- .../libtensor/include/utils/strided_iters.hpp | 48 ++++++++++++++----- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/dpctl/tensor/libtensor/include/utils/strided_iters.hpp b/dpctl/tensor/libtensor/include/utils/strided_iters.hpp index e2fe6051bc..e7a7b1d75f 100644 --- a/dpctl/tensor/libtensor/include/utils/strided_iters.hpp +++ b/dpctl/tensor/libtensor/include/utils/strided_iters.hpp @@ -428,11 +428,19 @@ int simplify_iteration_two_strides(const int nd, std::iota(pos.begin(), pos.end(), 0); std::stable_sort( - pos.begin(), pos.end(), [&strides1, &shape](int i1, int i2) { - auto abs_str1 = (strides1[i1] < 0) ? -strides1[i1] : strides1[i1]; - auto abs_str2 = (strides1[i2] < 0) ? -strides1[i2] : strides1[i2]; - return (abs_str1 > abs_str2) || - (abs_str1 == abs_str2 && shape[i1] > shape[i2]); + pos.begin(), pos.end(), [&strides1, &strides2, &shape](int i1, int i2) { + auto abs_str1_i1 = + (strides1[i1] < 0) ? -strides1[i1] : strides1[i1]; + auto abs_str1_i2 = + (strides1[i2] < 0) ? -strides1[i2] : strides1[i2]; + auto abs_str2_i1 = + (strides2[i1] < 0) ? -strides2[i1] : strides2[i1]; + auto abs_str2_i2 = + (strides2[i2] < 0) ? -strides2[i2] : strides2[i2]; + return (abs_str1_i1 > abs_str1_i2) || + (abs_str1_i1 == abs_str1_i2 && + (abs_str2_i1 > abs_str2_i2 || + (abs_str2_i1 == abs_str2_i2 && shape[i1] > shape[i2]))); }); std::vector shape_w; @@ -458,6 +466,7 @@ int simplify_iteration_two_strides(const int nd, strides1_w.push_back(str1_p); strides2_w.push_back(str2_p); } + int nd_ = nd; while (contractable) { bool changed = false; @@ -570,13 +579,28 @@ int simplify_iteration_three_strides(const int nd, std::vector pos(nd); std::iota(pos.begin(), pos.end(), 0); - std::stable_sort( - pos.begin(), pos.end(), [&strides1, &shape](int i1, int i2) { - auto abs_str1 = (strides1[i1] < 0) ? -strides1[i1] : strides1[i1]; - auto abs_str2 = (strides1[i2] < 0) ? -strides1[i2] : strides1[i2]; - return (abs_str1 > abs_str2) || - (abs_str1 == abs_str2 && shape[i1] > shape[i2]); - }); + std::stable_sort(pos.begin(), pos.end(), + [&strides1, &strides2, &strides3, &shape](int i1, int i2) { + auto abs_str1_i1 = + (strides1[i1] < 0) ? -strides1[i1] : strides1[i1]; + auto abs_str1_i2 = + (strides1[i2] < 0) ? -strides1[i2] : strides1[i2]; + auto abs_str2_i1 = + (strides2[i1] < 0) ? -strides2[i1] : strides2[i1]; + auto abs_str2_i2 = + (strides2[i2] < 0) ? -strides2[i2] : strides2[i2]; + auto abs_str3_i1 = + (strides3[i1] < 0) ? -strides3[i1] : strides3[i1]; + auto abs_str3_i2 = + (strides3[i2] < 0) ? -strides3[i2] : strides3[i2]; + return (abs_str1_i1 > abs_str1_i2) || + ((abs_str1_i1 == abs_str1_i2) && + ((abs_str2_i1 > abs_str2_i2) || + ((abs_str2_i1 == abs_str2_i2) && + ((abs_str3_i1 > abs_str3_i2) || + ((abs_str3_i1 == abs_str3_i2) && + (shape[i1] > shape[i2])))))); + }); std::vector shape_w; std::vector strides1_w;