// SPDX-License-Identifier: BSD-3-Clause /* Intel Distributed Runtime for MLIR */ #include #include #include #include #include #include #include #include #include #define STRINGIFY(a) #a constexpr id_t UNKNOWN_GUID = -1; using container_type = std::unordered_map>; static container_type garrays; static SHARPY::id_type _nguid = -1; inline SHARPY::id_type get_guid() { return ++_nguid; } static bool skip_comm = get_bool_env("SHARPY_SKIP_COMM"); static bool no_async = get_bool_env("SHARPY_NO_ASYNC"); // Transceiver * theTransceiver = MPITransceiver(); template T *mr_to_ptr(void *ptr, intptr_t offset) { if (!ptr) { throw std::invalid_argument("Fatal: cannot handle offset on nullptr"); } return reinterpret_cast(ptr) + offset; } // abstract handle providing an abstract wait method struct WaitHandleBase { virtual ~WaitHandleBase(){}; virtual void wait() = 0; }; // concrete handle to be instantiated with a lambda or alike // the lambda will be executed within wait() template class WaitHandle : public WaitHandleBase { T _fini; public: WaitHandle(T &&fini) : _fini(std::move(fini)) {} virtual void wait() override { _fini(); } }; template WaitHandle *mkWaitHandle(T &&fini) { return new WaitHandle(std::move(fini)); }; extern "C" { void _idtr_wait(WaitHandleBase *handle) { if (handle) { handle->wait(); delete handle; } } #define NO_TRANSCEIVER #ifdef NO_TRANSCEIVER static void initMPIRuntime() { if (SHARPY::getTransceiver() == nullptr) SHARPY::init_transceiver(new SHARPY::MPITransceiver(false)); } #endif // Return number of ranks/processes in given team/communicator uint64_t idtr_nprocs(SHARPY::Transceiver *tc) { #ifdef NO_TRANSCEIVER initMPIRuntime(); tc = SHARPY::getTransceiver(); #endif return tc ? tc->nranks() : 1; } #pragma weak _idtr_nprocs = idtr_nprocs #pragma weak _mlir_ciface__idtr_nprocs = idtr_nprocs // Return rank in given team/communicator uint64_t idtr_prank(SHARPY::Transceiver *tc) { #ifdef NO_TRANSCEIVER initMPIRuntime(); tc = SHARPY::getTransceiver(); #endif return tc ? tc->rank() : 0; } #pragma weak _idtr_prank = idtr_prank #pragma weak _mlir_ciface__idtr_prank = idtr_prank // Register a global array of given shape. // Returns guid. // The runtime does not own or manage any memory. id_t idtr_init_array(const uint64_t *shape, uint64_t nD) { auto guid = get_guid(); // garrays[guid] = std::unique_ptr(nD ? new // SHARPY::NDArray(shape, nD) : new SHARPY::NDArray); return guid; } id_t _idtr_init_array(void *alloced, void *aligned, intptr_t offset, intptr_t size, intptr_t stride, uint64_t nD) { return idtr_init_array(mr_to_ptr(aligned, offset), nD); } // Get the offsets (one for each dimension) of the local partition of a // distributed array in number of elements. Result is stored in provided array. void idtr_local_offsets(id_t guid, uint64_t *offsets, uint64_t nD) { #if 0 const auto & tnsr = garrays.at(guid); auto slcs = tnsr->slice().local_slice().slices(); assert(nD == slcs.size()); int i = -1; for(auto s : slcs) { offsets[++i] = s._start; } #endif } void _idtr_local_offsets(id_t guid, void *alloced, void *aligned, intptr_t offset, intptr_t size, intptr_t stride, uint64_t nD) { idtr_local_offsets(guid, mr_to_ptr(aligned, offset), nD); } // Get the shape (one size for each dimension) of the local partition of a // distributed array in number of elements. Result is stored in provided array. void idtr_local_shape(id_t guid, uint64_t *lshape, uint64_t N) { #if 0 const auto & tnsr = garrays.at(guid); auto shp = tnsr->slice().local_slice().shape(); std::copy(shp.begin(), shp.end(), lshape); #endif } void _idtr_local_shape(id_t guid, void *alloced, void *aligned, intptr_t offset, intptr_t size, intptr_t stride, uint64_t nD) { idtr_local_shape(guid, mr_to_ptr(aligned, offset), nD); } } // extern "C" // convert id of our reduction op to id of imex::ndarray reduction op static SHARPY::ReduceOpId mlir2sharpy(const ::imex::ndarray::ReduceOpId rop) { switch (rop) { case ::imex::ndarray::MEAN: return SHARPY::MEAN; case ::imex::ndarray::PROD: return SHARPY::PROD; case ::imex::ndarray::SUM: return SHARPY::SUM; case ::imex::ndarray::STD: return SHARPY::STD; case ::imex::ndarray::VAR: return SHARPY::VAR; case ::imex::ndarray::MAX: return SHARPY::MAX; case ::imex::ndarray::MIN: return SHARPY::MIN; default: throw std::invalid_argument("Unknown reduction operation"); } } // convert element type/dtype from MLIR to sharpy [[maybe_unused]] static SHARPY::DTypeId mlir2sharpy(const ::imex::ndarray::DType dt) { switch (dt) { case ::imex::ndarray::DType::F64: return SHARPY::FLOAT64; break; case ::imex::ndarray::DType::I64: return SHARPY::INT64; break; case ::imex::ndarray::DType::U64: return SHARPY::UINT64; break; case ::imex::ndarray::DType::F32: return SHARPY::FLOAT32; break; case ::imex::ndarray::DType::I32: return SHARPY::INT32; break; case ::imex::ndarray::DType::U32: return SHARPY::UINT32; break; case ::imex::ndarray::DType::I16: return SHARPY::INT16; break; case ::imex::ndarray::DType::U16: return SHARPY::UINT16; break; case ::imex::ndarray::DType::I8: return SHARPY::INT8; break; case ::imex::ndarray::DType::U8: return SHARPY::UINT8; break; case ::imex::ndarray::DType::I1: return SHARPY::BOOL; break; default: throw std::invalid_argument("unknown dtype"); }; } /// copy possibly strided array into a contiguous block of data void bufferize(void *cptr, SHARPY::DTypeId dtype, const int64_t *sizes, const int64_t *strides, const int64_t *tStarts, const int64_t *tSizes, uint64_t nd, uint64_t N, void *out) { if (!cptr || !sizes || !strides || !tStarts || !tSizes) { return; } dispatch(dtype, cptr, [sizes, strides, tStarts, tSizes, nd, N, out](auto *ptr) { auto buff = static_cast(out); for (auto i = 0ul; i < N; ++i) { auto szs = &tSizes[i * nd]; if (szs[0] > 0) { auto sts = &tStarts[i * nd]; uint64_t off = 0; for (auto r = 0ul; r < nd; ++r) { off += sts[r] * strides[r]; } SHARPY::forall(0, &ptr[off], szs, strides, nd, [&buff](const auto *in) { *buff = *in; ++buff; }); } } }); } /// copy contiguous block of data into a possibly strided array distributed to N /// ranks void unpackN(void *in, SHARPY::DTypeId dtype, const int64_t *sizes, const int64_t *strides, const int64_t *tStarts, const int64_t *tSizes, uint64_t nd, uint64_t N, void *out) { if (!in || !sizes || !strides || !tStarts || !tSizes || !out) { return; } dispatch(dtype, out, [sizes, strides, tStarts, tSizes, nd, N, in](auto *ptr) { auto buff = static_cast(in); for (auto i = 0ul; i < N; ++i) { auto szs = &tSizes[i * nd]; if (szs[0] > 0) { auto sts = &tStarts[i * nd]; uint64_t off = 0; for (auto r = 0ul; r < nd; ++r) { off += sts[r] * strides[r]; } SHARPY::forall(0, &ptr[off], szs, strides, nd, [&buff](auto *out) { *out = *buff; ++buff; }); } } }); } /// copy contiguous block of data into a possibly strided array void unpack(void *in, SHARPY::DTypeId dtype, const int64_t *sizes, const int64_t *strides, uint64_t ndim, void *out) { if (!in || !sizes || !strides || !out) { return; } dispatch(dtype, out, [sizes, strides, ndim, in](auto *out_) { auto in_ = static_cast(in); SHARPY::forall(0, out_, sizes, strides, ndim, [&in_](auto *out) { *out = *in_; ++in_; }); }); } template void copy_(uint64_t d, uint64_t &pos, T *cptr, const int64_t *sizes, const int64_t *strides, const uint64_t *chunks, uint64_t nd, uint64_t start, uint64_t end, T *&out) { if (!cptr || !sizes || !strides || !chunks || !out) { return; } auto stride = strides[d]; uint64_t sz = sizes[d]; uint64_t chunk = chunks[d]; uint64_t first = 0; if (pos < start) { first = (start - pos) / chunk; pos += first * chunk; cptr += first * stride; assert(pos <= start && pos < end); } if (d == nd - 1) { auto n = std::min(sz - first, end - pos); if (stride == 1) { memcpy(out, cptr, n * sizeof(T)); } else { for (auto i = 0ul; i < n; ++i) { out[i] = cptr[i * stride]; } } pos += n; out += n; } else { for (auto i = first; i < sz; ++i) { copy_(d + 1, pos, cptr, sizes, strides, chunks, nd, start, end, out); if (pos >= end) return; cptr += stride; } } } /// copy a number of array elements into a contiguous block of data void bufferizeN(uint64_t nd, void *cptr, const int64_t *sizes, const int64_t *strides, SHARPY::DTypeId dtype, uint64_t N, const int64_t *tStarts, const int64_t *tEnds, void *out) { if (!cptr || !sizes || !strides || !tStarts || !tEnds || !out) { return; } std::vector chunks(nd); chunks[nd - 1] = 1; for (uint64_t i = 1; i < nd; ++i) { auto j = nd - i; chunks[j - 1] = chunks[j] * sizes[j]; } dispatch(dtype, cptr, [sizes, strides, tStarts, tEnds, nd, N, out, &chunks](auto *ptr) { auto buff = static_cast(out); for (auto i = 0ul; i < N; ++i) { auto start = tStarts[i]; auto end = tEnds[i]; if (end > start) { uint64_t pos = 0; copy_(0, pos, ptr, sizes, strides, chunks.data(), nd, start, end, buff); } } }); } using MRIdx1d = SHARPY::Unranked1DMemRefType; // FIXME hard-coded for contiguous layout template void _idtr_reduce_all(int64_t dataRank, void *dataDescr, int op) { auto tc = SHARPY::getTransceiver(); if (!tc) return; SHARPY::UnrankedMemRefType data(dataRank, dataDescr); assert(dataRank == 0 || (dataRank == 1 && data.strides()[0] == 1)); auto d = data.data(); auto t = SHARPY::DTYPE::value; auto r = dataRank ? data.sizes()[0] : 1; auto o = mlir2sharpy(static_cast(op)); tc->reduce_all(d, t, r, o); } extern "C" { #define TYPED_REDUCEALL(_sfx, _typ) \ void _idtr_reduce_all_##_sfx(int64_t dataRank, void *dataDescr, int op) { \ _idtr_reduce_all<_typ>(dataRank, dataDescr, op); \ } \ _Pragma(STRINGIFY(weak _mlir_ciface__idtr_reduce_all_##_sfx = \ _idtr_reduce_all_##_sfx)) TYPED_REDUCEALL(f64, double); TYPED_REDUCEALL(f32, float); TYPED_REDUCEALL(i64, int64_t); TYPED_REDUCEALL(i32, int32_t); TYPED_REDUCEALL(i16, int16_t); TYPED_REDUCEALL(i8, int8_t); TYPED_REDUCEALL(i1, bool); } // extern "C" /// @brief reshape array /// We assume array is partitioned along the first dimension (only) and /// partitions are ordered by ranks WaitHandleBase *_idtr_copy_reshape(SHARPY::DTypeId sharpytype, SHARPY::Transceiver *tc, int64_t iNDims, int64_t *iGShapePtr, int64_t *iOffsPtr, void *iDataPtr, int64_t *iDataShapePtr, int64_t *iDataStridesPtr, int64_t oNDims, int64_t *oGShapePtr, int64_t *oOffsPtr, void *oDataPtr, int64_t *oDataShapePtr, int64_t *oDataStridesPtr) { #ifdef NO_TRANSCEIVER initMPIRuntime(); tc = SHARPY::getTransceiver(); #endif if (!iGShapePtr || !iOffsPtr || !iDataPtr || !iDataShapePtr || !iDataStridesPtr || !oGShapePtr || !oOffsPtr || !oDataPtr || !oDataShapePtr || !oDataStridesPtr || !tc) { throw std::invalid_argument("Fatal: received nullptr in reshape"); } assert(std::accumulate(&iGShapePtr[0], &iGShapePtr[iNDims], 1, std::multiplies()) == std::accumulate(&oGShapePtr[0], &oGShapePtr[oNDims], 1, std::multiplies())); assert(std::accumulate(&oOffsPtr[1], &oOffsPtr[oNDims], 0, std::plus()) == 0); auto N = tc->nranks(); auto me = tc->rank(); if (N <= me) { throw std::out_of_range("Fatal: rank must be < number of ranks"); } int64_t icSz = std::accumulate(&iGShapePtr[1], &iGShapePtr[iNDims], 1, std::multiplies()); assert(icSz == std::accumulate(&iDataShapePtr[1], &iDataShapePtr[iNDims], 1, std::multiplies())); int64_t mySz = icSz * iDataShapePtr[0]; if (mySz / icSz != iDataShapePtr[0]) { throw std::overflow_error("Fatal: Integer overflow in reshape"); } int64_t myOff = iOffsPtr[0] * icSz; if (myOff / icSz != iOffsPtr[0]) { throw std::overflow_error("Fatal: Integer overflow in reshape"); } int64_t myEnd = myOff + mySz; if (myEnd < myOff) { throw std::overflow_error("Fatal: Integer overflow in reshape"); } int64_t oCSz = std::accumulate(&oGShapePtr[1], &oGShapePtr[oNDims], 1, std::multiplies()); assert(oCSz == std::accumulate(&oDataShapePtr[1], &oDataShapePtr[oNDims], 1, std::multiplies())); int64_t myOSz = oCSz * oDataShapePtr[0]; if (myOSz / oCSz != oDataShapePtr[0]) { throw std::overflow_error("Fatal: Integer overflow in reshape"); } int64_t myOOff = oOffsPtr[0] * oCSz; if (myOOff / oCSz != oOffsPtr[0]) { throw std::overflow_error("Fatal: Integer overflow in reshape"); } int64_t myOEnd = myOOff + myOSz; if (myOEnd < myOOff) { throw std::overflow_error("Fatal: Integer overflow in reshape"); } // First we allgather the current and target partitioning ::std::vector buff(4 * N); buff[me * 4 + 0] = myOff; buff[me * 4 + 1] = mySz; buff[me * 4 + 2] = myOOff; buff[me * 4 + 3] = myOSz; ::std::vector counts(N, 4); ::std::vector dspl(N); for (auto i = 0ul; i < N; ++i) { dspl[i] = 4 * i; } tc->gather(buff.data(), counts.data(), dspl.data(), SHARPY::INT64, SHARPY::REPLICATED); // compute overlaps of current parts with requested parts // and store meta for alltoall std::vector soffs(N, 0); std::vector sszs(N, 0); std::vector roffs(N, 0); std::vector rszs(N, 0); std::vector lsOffs(N, 0); std::vector lsEnds(N, 0); int64_t totSSz = 0; for (auto i = 0ul; i < N; ++i) { int64_t *curr = &buff[i * 4]; auto xOff = curr[0]; auto xEnd = xOff + curr[1]; auto tOff = curr[2]; auto tEnd = tOff + curr[3]; // first check if this target part overlaps with my local part if (tEnd > myOff && tOff < myEnd) { auto sOff = std::max(tOff, myOff); sszs[i] = std::min(tEnd, myEnd) - sOff; soffs[i] = i ? soffs[i - 1] + sszs[i - 1] : 0; lsOffs[i] = sOff - myOff; lsEnds[i] = lsOffs[i] + sszs[i]; totSSz += sszs[i]; } // then check if my target part overlaps with the remote local part if (myOEnd > xOff && myOOff < xEnd) { auto rOff = std::max(xOff, myOOff); rszs[i] = std::min(xEnd, myOEnd) - rOff; roffs[i] = i ? roffs[i - 1] + rszs[i - 1] : 0; } } bool isStrided = !SHARPY::is_contiguous(oDataShapePtr, oDataStridesPtr, oNDims); void *rBuff = isStrided ? new char[sizeof_dtype(sharpytype) * myOSz] : oDataPtr; SHARPY::Buffer sendbuff(totSSz * sizeof_dtype(sharpytype), 2); bufferizeN(iNDims, iDataPtr, iDataShapePtr, iDataStridesPtr, sharpytype, N, lsOffs.data(), lsEnds.data(), sendbuff.data()); auto hdl = tc->alltoall(sendbuff.data(), sszs.data(), soffs.data(), sharpytype, rBuff, rszs.data(), roffs.data()); auto wait = [tc, hdl, isStrided, rBuff, sharpytype, oDataShapePtr, oDataStridesPtr, oNDims, oDataPtr, sendbuff = std::move(sendbuff), sszs = std::move(sszs), soffs = std::move(soffs), rszs = std::move(rszs), roffs = std::move(roffs)]() { tc->wait(hdl); if (isStrided) { unpack(rBuff, sharpytype, oDataShapePtr, oDataStridesPtr, oNDims, oDataPtr); delete[](char *) rBuff; } }; assert(sendbuff.empty() && sszs.empty() && soffs.empty() && rszs.empty() && roffs.empty()); if (no_async) { wait(); return nullptr; } return mkWaitHandle(std::move(wait)); } /// @brief reshape array template WaitHandleBase * _idtr_copy_reshape(SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr, int64_t iNOffs, void *iLOffsDescr, int64_t iNDims, void *iDataDescr, int64_t oNSzs, void *oGShapeDescr, int64_t oNOffs, void *oLOffsDescr, int64_t oNDims, void *oDataDescr) { if (!iGShapeDescr || !iLOffsDescr || !iDataDescr || !oGShapeDescr || !oLOffsDescr || !oDataDescr) { throw std::invalid_argument( "Fatal error: received nullptr in update_halo."); } auto sharpytype = SHARPY::DTYPE::value; // Construct unranked memrefs for metadata and data MRIdx1d iGShape(iNSzs, iGShapeDescr); MRIdx1d iOffs(iNOffs, iLOffsDescr); SHARPY::UnrankedMemRefType iData(iNDims, iDataDescr); MRIdx1d oGShape(oNSzs, oGShapeDescr); MRIdx1d oOffs(oNOffs, oLOffsDescr); SHARPY::UnrankedMemRefType oData(oNDims, oDataDescr); return _idtr_copy_reshape( sharpytype, tc, iNDims, iGShape.data(), iOffs.data(), iData.data(), iData.sizes(), iData.strides(), oNDims, oGShape.data(), oOffs.data(), oData.data(), oData.sizes(), oData.strides()); } namespace { /// /// An util class of multi-dimensional index /// class id { public: id(size_t dims) : _values(dims) {} id(size_t dims, int64_t *value) : _values(value, value + dims) {} id(const std::vector &values) : _values(values) {} id(const std::vector &&values) : _values(std::move(values)) {} /// Permute this id by axes and return a new id id permute(std::vector axes) const { std::vector new_values(_values.size()); for (size_t i = 0; i < _values.size(); i++) { new_values[i] = _values[axes[i]]; } return id(std::move(new_values)); } int64_t operator[](size_t i) const { return _values[i]; } int64_t &operator[](size_t i) { return _values[i]; } /// Subtract another id from this id and return a new id id operator-(const id &rhs) const { std::vector new_values(_values.size()); for (size_t i = 0; i < _values.size(); i++) { new_values[i] = _values[i] - rhs._values[i]; } return id(std::move(new_values)); } /// Subtract another id from this id and return a new id id operator-(const int64_t *rhs) const { std::vector new_values(_values.size()); for (size_t i = 0; i < _values.size(); i++) { new_values[i] = _values[i] - rhs[i]; } return id(std::move(new_values)); } /// Increase the last dimension value of this id which bounds by shape /// /// Example: /// In shape (2,2) : (0,0)->(0,1)->(1,0)->(1,1)->(0,0) void next(const int64_t *shape) { size_t i = _values.size(); while (i--) { ++_values[i]; if (_values[i] < shape[i]) { return; } _values[i] = 0; } } size_t size() { return _values.size(); } private: std::vector _values; }; /// /// An wrapper template class for distribute multi-dimensional array /// template class ndarray { public: ndarray(int64_t nDims, int64_t *gShape, int64_t *gOffsets, void *lData, int64_t *lShape, int64_t *lStrides) : _nDims(nDims), _gShape(gShape), _gOffsets(gOffsets), _lData((T *)lData), _lShape(lShape), _lStrides(lStrides) {} /// Return the first global index of local data id firstLocalIndex() const { return id(_nDims, _gOffsets); } /// Interate all global indices in local data void localIndices(const std::function &callback) const { size_t size = lSize(); id idx = firstLocalIndex(); while (size--) { callback(idx); idx.next(_gShape); } } /// Interate all global indices of the array void globalIndices(const std::function &callback) const { size_t size = gSize(); id idx(_nDims); while (size--) { callback(idx); idx.next(_gShape); } } int64_t getLocalDataOffset(const id &idx) const { auto localIdx = idx - _gOffsets; int64_t offset = 0; for (int64_t i = 0; i < _nDims - 1; ++i) { offset = (offset + localIdx[i]) * _lShape[i + 1]; } offset += localIdx[_nDims - 1]; return offset; } /// Using global index to access its data T &operator[](const id &idx) { return _lData[getLocalDataOffset(idx)]; } T operator[](const id &idx) const { return _lData[getLocalDataOffset(idx)]; } id gShape() { return id(_nDims, _gShape); } id lShape() { return id(_nDims, _lShape); } size_t gSize() const { return std::accumulate(_gShape, _gShape + _nDims, 1, std::multiplies()); } size_t lSize() const { return std::accumulate(_lShape, _lShape + _nDims, 1, std::multiplies()); } private: int64_t _nDims; int64_t *_gShape; int64_t *_gOffsets; T *_lData; int64_t *_lShape; int64_t *_lStrides; }; struct Parts { int64_t iStart; int64_t iEnd; int64_t oStart; int64_t oEnd; }; size_t getInputRank(const std::vector &parts, int64_t dim0) { for (size_t i = 0; i < parts.size(); i++) { if (dim0 >= parts[i].iStart && dim0 < parts[i].iEnd) { return i; } } assert(false && "unreachable"); return 0; } size_t getOutputRank(const std::vector &parts, int64_t dim0) { for (size_t i = 0; i < parts.size(); i++) { if (dim0 >= parts[i].oStart && dim0 < parts[i].oEnd) { return i; } } assert(false && "unreachable"); return 0; } template class WaitPermute { public: WaitPermute(SHARPY::Transceiver *tc, SHARPY::Transceiver::WaitHandle hdl, SHARPY::rank_type cRank, SHARPY::rank_type nRanks, std::vector &&parts, std::vector &&axes, std::vector oGShape, ndarray &&input, ndarray &&output, std::vector &&receiveBuffer, std::vector &&receiveOffsets, std::vector &&receiveSizes) : tc(tc), hdl(hdl), cRank(cRank), nRanks(nRanks), parts(std::move(parts)), axes(std::move(axes)), oGShape(std::move(oGShape)), input(std::move(input)), output(std::move(output)), receiveBuffer(std::move(receiveBuffer)), receiveOffsets(std::move(receiveOffsets)), receiveSizes(std::move(receiveSizes)) {} void operator()() { tc->wait(hdl); std::vector> receiveRankBuffer(nRanks); for (size_t rank = 0; rank < nRanks; ++rank) { auto &rankBuffer = receiveRankBuffer[rank]; rankBuffer.insert( rankBuffer.end(), receiveBuffer.begin() + receiveOffsets[rank], receiveBuffer.begin() + receiveOffsets[rank] + receiveSizes[rank]); } std::vector receiveRankBufferCount(nRanks, 0); input.globalIndices([&](const id &inputIndex) { id outputIndex = inputIndex.permute(axes); auto rank = getOutputRank(parts, outputIndex[0]); if (rank != cRank) return; rank = getInputRank(parts, inputIndex[0]); auto &count = receiveRankBufferCount[rank]; output[outputIndex] = receiveRankBuffer[rank][count++]; }); } private: SHARPY::Transceiver *tc; SHARPY::Transceiver::WaitHandle hdl; SHARPY::rank_type cRank; SHARPY::rank_type nRanks; std::vector parts; std::vector axes; std::vector oGShape; ndarray input; ndarray output; std::vector receiveBuffer; std::vector receiveOffsets; std::vector receiveSizes; }; } // namespace /// @brief permute array /// We assume array is partitioned along the first dimension (only) and /// partitions are ordered by ranks template WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype, SHARPY::Transceiver *tc, int64_t iNDims, int64_t *iGShapePtr, int64_t *iOffsPtr, void *iDataPtr, int64_t *iDataShapePtr, int64_t *iDataStridesPtr, int64_t *oOffsPtr, void *oDataPtr, int64_t *oDataShapePtr, int64_t *oDataStridesPtr, int64_t *axesPtr) { #ifdef NO_TRANSCEIVER initMPIRuntime(); tc = SHARPY::getTransceiver(); #endif if (!iGShapePtr || !iOffsPtr || !iDataPtr || !iDataShapePtr || !iDataStridesPtr || !oOffsPtr || !oDataPtr || !oDataShapePtr || !oDataStridesPtr || !tc) { throw std::invalid_argument("Fatal: received nullptr in reshape"); } std::vector oGShape(iNDims); for (int64_t i = 0; i < iNDims; ++i) { oGShape[i] = iGShapePtr[axesPtr[i]]; } auto *oGShapePtr = oGShape.data(); const auto oNDims = iNDims; assert(std::accumulate(&iGShapePtr[0], &iGShapePtr[iNDims], 1, std::multiplies()) == std::accumulate(&oGShapePtr[0], &oGShapePtr[oNDims], 1, std::multiplies())); assert(std::accumulate(&oOffsPtr[1], &oOffsPtr[oNDims], 0, std::plus()) == 0); const auto nRanks = tc->nranks(); const auto cRank = tc->rank(); if (nRanks <= cRank) { throw std::out_of_range("Fatal: rank must be < number of ranks"); } int64_t icSz = std::accumulate(&iGShapePtr[1], &iGShapePtr[iNDims], 1, std::multiplies()); assert(icSz == std::accumulate(&iDataShapePtr[1], &iDataShapePtr[iNDims], 1, std::multiplies())); int64_t mySz = icSz * iDataShapePtr[0]; if (mySz / icSz != iDataShapePtr[0]) { throw std::overflow_error("Fatal: Integer overflow in reshape"); } int64_t myOff = iOffsPtr[0] * icSz; if (myOff / icSz != iOffsPtr[0]) { throw std::overflow_error("Fatal: Integer overflow in reshape"); } int64_t myEnd = myOff + mySz; if (myEnd < myOff) { throw std::overflow_error("Fatal: Integer overflow in reshape"); } int64_t oCSz = std::accumulate(&oGShapePtr[1], &oGShapePtr[oNDims], 1, std::multiplies()); assert(oCSz == std::accumulate(&oDataShapePtr[1], &oDataShapePtr[oNDims], 1, std::multiplies())); int64_t myOSz = oCSz * oDataShapePtr[0]; if (myOSz / oCSz != oDataShapePtr[0]) { throw std::overflow_error("Fatal: Integer overflow in reshape"); } int64_t myOOff = oOffsPtr[0] * oCSz; if (myOOff / oCSz != oOffsPtr[0]) { throw std::overflow_error("Fatal: Integer overflow in reshape"); } int64_t myOEnd = myOOff + myOSz; if (myOEnd < myOOff) { throw std::overflow_error("Fatal: Integer overflow in reshape"); } // First we allgather the current and target partitioning std::vector parts(nRanks); parts[cRank].iStart = iOffsPtr[0]; parts[cRank].iEnd = iOffsPtr[0] + iDataShapePtr[0]; parts[cRank].oStart = oOffsPtr[0]; parts[cRank].oEnd = oOffsPtr[0] + oDataShapePtr[0]; std::vector counts(nRanks, 4); std::vector dspl(nRanks); for (auto i = 0ul; i < nRanks; ++i) { dspl[i] = 4 * i; } tc->gather(parts.data(), counts.data(), dspl.data(), SHARPY::INT64, SHARPY::REPLICATED); // Transpose ndarray input(iNDims, iGShapePtr, iOffsPtr, iDataPtr, iDataShapePtr, iDataStridesPtr); ndarray output(oNDims, oGShapePtr, oOffsPtr, oDataPtr, oDataShapePtr, oDataStridesPtr); std::vector axes(axesPtr, axesPtr + iNDims); std::vector sendBuffer; std::vector receiveBuffer(output.lSize()); std::vector sendSizes(nRanks); std::vector sendOffsets(nRanks); std::vector receiveSizes(nRanks); std::vector receiveOffsets(nRanks); { std::vector> sendRankBuffer(nRanks); input.localIndices([&](const id &inputIndex) { id outputIndex = inputIndex.permute(axes); auto rank = getOutputRank(parts, outputIndex[0]); sendRankBuffer[rank].push_back(input[inputIndex]); }); int lastOffset = 0; for (size_t rank = 0; rank < nRanks; rank++) { sendSizes[rank] = sendRankBuffer[rank].size(); sendOffsets[rank] = lastOffset; sendBuffer.insert(sendBuffer.end(), sendRankBuffer[rank].begin(), sendRankBuffer[rank].end()); lastOffset += sendSizes[rank]; } output.localIndices([&](const id &outputIndex) { id inputIndex = outputIndex.permute(axes); auto rank = getInputRank(parts, inputIndex[0]); ++receiveSizes[rank]; }); for (size_t rank = 1; rank < nRanks; rank++) { receiveOffsets[rank] = receiveOffsets[rank - 1] + receiveSizes[rank - 1]; } } auto hdl = tc->alltoall(sendBuffer.data(), sendSizes.data(), sendOffsets.data(), sharpytype, receiveBuffer.data(), receiveSizes.data(), receiveOffsets.data()); auto wait = WaitPermute(tc, hdl, cRank, nRanks, std::move(parts), std::move(axes), std::move(oGShape), std::move(input), std::move(output), std::move(receiveBuffer), std::move(receiveOffsets), std::move(receiveSizes)); assert(parts.empty() && axes.empty() && receiveBuffer.empty() && receiveOffsets.empty() && receiveSizes.empty()); if (no_async) { wait(); return nullptr; } return mkWaitHandle(std::move(wait)); } /// @brief permute array template WaitHandleBase * _idtr_copy_permute(SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr, int64_t iNOffs, void *iLOffsDescr, int64_t iNDims, void *iDataDescr, int64_t oNOffs, void *oLOffsDescr, int64_t oNDims, void *oDataDescr, int64_t axesSzs, void *axesDescr) { if (!iGShapeDescr || !iLOffsDescr || !iDataDescr || !oLOffsDescr || !oDataDescr || !axesDescr) { throw std::invalid_argument( "Fatal error: received nullptr in update_halo."); } auto sharpyType = SHARPY::DTYPE::value; // Construct unranked memrefs for metadata and data MRIdx1d iGShape(iNSzs, iGShapeDescr); MRIdx1d iOffs(iNOffs, iLOffsDescr); SHARPY::UnrankedMemRefType iData(iNDims, iDataDescr); MRIdx1d oOffs(oNOffs, oLOffsDescr); SHARPY::UnrankedMemRefType oData(oNDims, oDataDescr); MRIdx1d axes(axesSzs, axesDescr); return _idtr_copy_permute(sharpyType, tc, iNDims, iGShape.data(), iOffs.data(), iData.data(), iData.sizes(), iData.strides(), oOffs.data(), oData.data(), oData.sizes(), oData.strides(), axes.data()); } extern "C" { #define TYPED_COPY_RESHAPE(_sfx, _typ) \ void *_idtr_copy_reshape_##_sfx( \ SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr, \ int64_t iNOffs, void *iLOffsDescr, int64_t iNDims, void *iLDescr, \ int64_t oNSzs, void *oGShapeDescr, int64_t oNOffs, void *oLOffsDescr, \ int64_t oNDims, void *oLDescr) { \ return _idtr_copy_reshape<_typ>( \ tc, iNSzs, iGShapeDescr, iNOffs, iLOffsDescr, iNDims, iLDescr, oNSzs, \ oGShapeDescr, oNOffs, oLOffsDescr, oNDims, oLDescr); \ } \ _Pragma(STRINGIFY(weak _mlir_ciface__idtr_copy_reshape_##_sfx = \ _idtr_copy_reshape_##_sfx)) TYPED_COPY_RESHAPE(f64, double); TYPED_COPY_RESHAPE(f32, float); TYPED_COPY_RESHAPE(i64, int64_t); TYPED_COPY_RESHAPE(i32, int32_t); TYPED_COPY_RESHAPE(i16, int16_t); TYPED_COPY_RESHAPE(i8, int8_t); TYPED_COPY_RESHAPE(i1, bool); #define TYPED_COPY_PERMUTE(_sfx, _typ) \ void *_idtr_copy_permute_##_sfx( \ SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr, \ int64_t iNOffs, void *iLOffsDescr, int64_t iNDims, void *iLDescr, \ int64_t oNOffs, void *oLOffsDescr, int64_t oNDims, void *oLDescr, \ int64_t axesSzs, void *axesDescr) { \ return _idtr_copy_permute<_typ>( \ tc, iNSzs, iGShapeDescr, iNOffs, iLOffsDescr, iNDims, iLDescr, oNOffs, \ oLOffsDescr, oNDims, oLDescr, axesSzs, axesDescr); \ } \ _Pragma(STRINGIFY(weak _mlir_ciface__idtr_copy_permute_##_sfx = \ _idtr_copy_permute_##_sfx)) TYPED_COPY_PERMUTE(f64, double); TYPED_COPY_PERMUTE(f32, float); TYPED_COPY_PERMUTE(i64, int64_t); TYPED_COPY_PERMUTE(i32, int32_t); TYPED_COPY_PERMUTE(i16, int16_t); TYPED_COPY_PERMUTE(i8, int8_t); // FIXME: bool is not supported yet due to std::vector // TYPED_COPY_PERMUTE(i1, bool); } // extern "C" // struct for caching meta data for update_halo // no copies allowed, only move-semantics and reference access struct UHCache { // copying needed? std::vector _lBufferStart, _lBufferSize, _rBufferStart, _rBufferSize; std::vector _lRecvBufferSize, _rRecvBufferSize; // send maps std::vector _lSendSize, _rSendSize, _lSendOff, _rSendOff; // receive maps std::vector _lRecvSize, _rRecvSize, _lRecvOff, _rRecvOff; // buffers SHARPY::Buffer _recvLBuff, _recvRBuff, _sendLBuff, _sendRBuff; bool _bufferizeSend, _bufferizeLRecv, _bufferizeRRecv; // start and sizes for chunks from remotes if copies are needed int64_t _lTotalRecvSize, _rTotalRecvSize, _lTotalSendSize, _rTotalSendSize; UHCache() = default; UHCache(const UHCache &) = delete; UHCache(UHCache &&) = default; UHCache(std::vector &&lBufferStart, std::vector &&lBufferSize, std::vector &&rBufferStart, std::vector &&rBufferSize, std::vector &&lRecvBufferSize, std::vector &&rRecvBufferSize, std::vector &&lSendSize, std::vector &&rSendSize, std::vector &&lSendOff, std::vector &&rSendOff, std::vector &&lRecvSize, std::vector &&rRecvSize, std::vector &&lRecvOff, SHARPY::Buffer &&recvLBuff, SHARPY::Buffer &&recvRBuff, SHARPY::Buffer &&sendLBuff, SHARPY::Buffer &&sendRBuff, std::vector &&rRecvOff, bool bufferizeSend, bool bufferizeLRecv, bool bufferizeRRecv, int64_t lTotalRecvSize, int64_t rTotalRecvSize, int64_t lTotalSendSize, int64_t rTotalSendSize) : _lBufferStart(std::move(lBufferStart)), _lBufferSize(std::move(lBufferSize)), _rBufferStart(std::move(rBufferStart)), _rBufferSize(std::move(rBufferSize)), _lRecvBufferSize(std::move(lRecvBufferSize)), _rRecvBufferSize(std::move(rRecvBufferSize)), _lSendSize(std::move(lSendSize)), _rSendSize(std::move(rSendSize)), _lSendOff(std::move(lSendOff)), _rSendOff(std::move(rSendOff)), _lRecvSize(std::move(lRecvSize)), _rRecvSize(std::move(rRecvSize)), _lRecvOff(std::move(lRecvOff)), _rRecvOff(std::move(rRecvOff)), _recvLBuff(std::move(recvLBuff)), _recvRBuff(std::move(recvRBuff)), _sendLBuff(std::move(sendLBuff)), _sendRBuff(std::move(sendRBuff)), _bufferizeSend(bufferizeSend), _bufferizeLRecv(bufferizeLRecv), _bufferizeRRecv(bufferizeRRecv), _lTotalRecvSize(lTotalRecvSize), _rTotalRecvSize(rTotalRecvSize), _lTotalSendSize(lTotalSendSize), _rTotalSendSize(rTotalSendSize) {} UHCache &operator=(const UHCache &) = delete; UHCache &operator=(UHCache &&) = default; }; UHCache getMetaData(SHARPY::rank_type nworkers, int64_t ndims, int64_t *ownedOff, int64_t *ownedShape, int64_t *ownedStride, int64_t *bbOff, int64_t *bbShape, int64_t *leftHaloShape, int64_t *leftHaloStride, int64_t *rightHaloShape, int64_t *rightHaloStride, SHARPY::Transceiver *tc) { UHCache cE; // holds data if non-cached auto myWorkerIndex = tc->rank(); if (myWorkerIndex >= nworkers) { throw std::out_of_range("Fatal: rank must be < number of workers"); } cE._lTotalRecvSize = 0; cE._rTotalRecvSize = 0; cE._lTotalSendSize = 0; cE._rTotalSendSize = 0; // Gather table with bounding box offsets and shapes for all workers // [ (w0 offsets) o_0, o_1, ..., o_ndims, // (w0 shapes) s_0, s_1, ..., s_ndims, // (w1 offsets) ... ] auto nn = 2 * ndims * nworkers; if (nn / 2 != ndims * nworkers) { throw std::overflow_error("Fatal: Integer overflow in getMetaData"); } ::std::vector bbTable(nn); auto ptableStart = 2 * ndims * myWorkerIndex; for (int64_t i = 0; i < ndims; ++i) { bbTable[ptableStart + i] = bbOff[i]; bbTable[ptableStart + i + ndims] = bbShape[i]; } ::std::vector counts(nworkers, ndims * 2); ::std::vector offsets(nworkers); for (auto i = 0ul; i < nworkers; ++i) { offsets[i] = 2 * ndims * i; } tc->gather(bbTable.data(), counts.data(), offsets.data(), SHARPY::INT64, SHARPY::REPLICATED); // global indices for row partitioning auto ownedRowStart = ownedOff[0]; auto ownedRows = ownedShape[0]; auto ownedRowEnd = ownedRowStart + ownedRows; // all remaining dims are treated as one large column auto ownedTotCols = std::accumulate(&ownedShape[1], &ownedShape[ndims], 1, std::multiplies()); auto bbTotCols = std::accumulate(&bbShape[1], &bbShape[ndims], 1, std::multiplies()); // find local elements to send to next workers (destination leftHalo) // and previous workers (destination rightHalo) cE._lSendOff.resize(nworkers, 0); cE._rSendOff.resize(nworkers, 0); cE._lSendSize.resize(nworkers, 0); cE._rSendSize.resize(nworkers, 0); // use send buffer if owned data is strided or sending a subview cE._bufferizeSend = (!SHARPY::is_contiguous(ownedShape, ownedStride, ndims) || bbTotCols != ownedTotCols); cE._lBufferStart.resize(nworkers * ndims, 0); cE._lBufferSize.resize(nworkers * ndims, 0); cE._rBufferStart.resize(nworkers * ndims, 0); cE._rBufferSize.resize(nworkers * ndims, 0); for (auto i = 0ul; i < nworkers; ++i) { if (i == myWorkerIndex) { continue; } // worker i bounding box indices auto bRowStart = bbTable[2 * ndims * i]; auto bRows = bbTable[2 * ndims * i + ndims]; auto bRowEnd = bRowStart + bRows; if (bRowEnd > ownedRowStart && bRowStart < ownedRowEnd) { // bounding box overlaps with local data // calculate indices for data to be sent auto globalRowStart = std::max(ownedRowStart, bRowStart); auto globalRowEnd = std::min(ownedRowEnd, bRowEnd); auto localRowStart = globalRowStart - ownedRowStart; auto localStart = (int)(localRowStart)*ownedTotCols; auto nRows = globalRowEnd - globalRowStart; auto nSend = (int)(nRows)*bbTotCols; if (i < myWorkerIndex) { // target is rightHalo if (cE._bufferizeSend) { cE._rSendOff[i] = i ? cE._rSendOff[i - 1] + cE._rSendSize[i - 1] : 0; if (i && cE._rSendOff[i] < cE._rSendOff[i - 1]) { throw std::overflow_error("Fatal: Integer overflow in getMetaData"); } cE._rBufferStart[i * ndims] = localRowStart; cE._rBufferSize[i * ndims] = nRows; for (auto j = 1; j < ndims; ++j) { cE._rBufferStart[i * ndims + j] = bbOff[j]; cE._rBufferSize[i * ndims + j] = bbShape[j]; } } else { cE._rSendOff[i] = localStart; } cE._rSendSize[i] = nSend; cE._rTotalSendSize += nSend; } else { // target is leftHalo if (cE._bufferizeSend) { cE._lSendOff[i] = i ? cE._lSendOff[i - 1] + cE._lSendSize[i - 1] : 0; if (i && cE._lSendOff[i] < cE._lSendOff[i - 1]) { throw std::overflow_error("Fatal: Integer overflow in getMetaData"); } cE._lBufferStart[i * ndims] = localRowStart; cE._lBufferSize[i * ndims] = nRows; for (auto j = 1; j < ndims; ++j) { cE._lBufferStart[i * ndims + j] = bbOff[j]; cE._lBufferSize[i * ndims + j] = bbShape[j]; } } else { cE._lSendOff[i] = localStart; } cE._lSendSize[i] = nSend; cE._lTotalSendSize += nSend; } } } // receive maps cE._lRecvSize.resize(nworkers); cE._rRecvSize.resize(nworkers); cE._lRecvOff.resize(nworkers); cE._rRecvOff.resize(nworkers); // receive size is sender's send size tc->alltoall(cE._lSendSize.data(), 1, SHARPY::INT32, cE._lRecvSize.data()); tc->alltoall(cE._rSendSize.data(), 1, SHARPY::INT32, cE._rRecvSize.data()); // compute offset in a contiguous receive buffer cE._lRecvOff[0] = 0; cE._rRecvOff[0] = 0; for (auto i = 1ul; i < nworkers; ++i) { cE._lRecvOff[i] = cE._lRecvOff[i - 1] + cE._lRecvSize[i - 1]; cE._rRecvOff[i] = cE._rRecvOff[i - 1] + cE._rRecvSize[i - 1]; } // receive buffering cE._bufferizeLRecv = !SHARPY::is_contiguous(leftHaloShape, leftHaloStride, ndims); cE._bufferizeRRecv = !SHARPY::is_contiguous(rightHaloShape, rightHaloStride, ndims); cE._lRecvBufferSize.resize(nworkers * ndims, 0); cE._rRecvBufferSize.resize(nworkers * ndims, 0); // deduce receive shape for unpack for (auto i = 0ul; i < nworkers; ++i) { if (cE._bufferizeLRecv && cE._lRecvSize[i] != 0) { auto x = cE._lTotalRecvSize + cE._lRecvSize[i]; if (x < cE._lTotalRecvSize) { throw std::overflow_error("Fatal: Integer overflow in getMetaData"); } cE._lTotalRecvSize = x; cE._lRecvBufferSize[i * ndims] = cE._lRecvSize[i] / bbTotCols; // nrows for (auto j = 1; j < ndims; ++j) { cE._lRecvBufferSize[i * ndims + j] = bbShape[j]; // leftHaloShape[j] } } if (cE._bufferizeRRecv && cE._rRecvSize[i] != 0) { auto x = cE._rTotalRecvSize + cE._rRecvSize[i]; if (x < cE._rTotalRecvSize) { throw std::overflow_error("Fatal: Integer overflow in getMetaData"); } cE._rTotalRecvSize = x; if (cE._rTotalRecvSize < 0) { throw std::overflow_error("Fatal: Integer overflow in getMetaData"); } cE._rRecvBufferSize[i * ndims] = cE._rRecvSize[i] / bbTotCols; // nrows for (auto j = 1; j < ndims; ++j) { cE._rRecvBufferSize[i * ndims + j] = bbShape[j]; // rightHaloShape[j] } } } return cE; }; /// @brief Update data in halo parts /// We assume array is partitioned along the first dimension only /// (row partitioning) and partitions are ordered by ranks /// if cache-key is provided (>=0) meta data is read from cache /// @return (MPI) handles void *_idtr_update_halo(SHARPY::DTypeId sharpytype, int64_t ndims, int64_t *ownedOff, int64_t *ownedShape, int64_t *ownedStride, int64_t *bbOff, int64_t *bbShape, void *ownedData, int64_t *leftHaloShape, int64_t *leftHaloStride, void *leftHaloData, int64_t *rightHaloShape, int64_t *rightHaloStride, void *rightHaloData, SHARPY::Transceiver *tc, int64_t key) { #ifdef NO_TRANSCEIVER initMPIRuntime(); tc = SHARPY::getTransceiver(); #endif if (!ownedOff || !ownedShape || !ownedStride || !bbOff || !bbShape || !ownedData || !leftHaloShape || !leftHaloStride || !leftHaloData || !rightHaloShape || !rightHaloStride || !rightHaloData || !tc) { throw std::invalid_argument( "Fatal error: received nullptr in update_halo."); } auto nworkers = tc->nranks(); if (nworkers <= 1 || skip_comm) return nullptr; // not thread-safe static std::unordered_map uhCache; // meta-data cache static UHCache *cache = nullptr; // reading either from non-cached or cached auto cIt = key == -1 ? uhCache.end() : uhCache.find(key); if (cIt == uhCache.end()) { // not in cache // update cache if requested cIt = uhCache .insert_or_assign( key, std::move(getMetaData( nworkers, ndims, ownedOff, ownedShape, ownedStride, bbOff, bbShape, leftHaloShape, leftHaloStride, rightHaloShape, rightHaloStride, tc))) .first; } cache = &(cIt->second); int64_t nbytes = sizeof_dtype(sharpytype); if (cache->_bufferizeLRecv) { int64_t x = cache->_lTotalRecvSize * nbytes; if (x / nbytes != cache->_lTotalRecvSize) { throw std::overflow_error("Fatal: Integer overflow in update_halo"); } cache->_recvLBuff.resize(x); } if (cache->_bufferizeRRecv) { int64_t x = cache->_rTotalRecvSize * nbytes; if (x / nbytes != cache->_rTotalRecvSize) { throw std::overflow_error("Fatal: Integer overflow in update_halo"); } cache->_recvRBuff.resize(x); } if (cache->_bufferizeSend) { int64_t x = cache->_lTotalSendSize * nbytes; if (x / nbytes != cache->_lTotalSendSize) { throw std::overflow_error("Fatal: Integer overflow in update_halo"); } cache->_sendLBuff.resize(x); x = cache->_rTotalSendSize * nbytes; if (x / nbytes != cache->_rTotalSendSize) { throw std::overflow_error("Fatal: Integer overflow in update_halo"); } cache->_sendRBuff.resize(x); } void *lRecvData = cache->_bufferizeLRecv ? cache->_recvLBuff.data() : leftHaloData; void *rRecvData = cache->_bufferizeRRecv ? cache->_recvRBuff.data() : rightHaloData; void *lSendData = cache->_bufferizeSend ? cache->_sendLBuff.data() : ownedData; void *rSendData = cache->_bufferizeSend ? cache->_sendRBuff.data() : ownedData; // communicate left/right halos if (cache->_bufferizeSend) { bufferize(ownedData, sharpytype, ownedShape, ownedStride, cache->_lBufferStart.data(), cache->_lBufferSize.data(), ndims, nworkers, cache->_sendLBuff.data()); } auto lwh = tc->alltoall(lSendData, cache->_lSendSize.data(), cache->_lSendOff.data(), sharpytype, lRecvData, cache->_lRecvSize.data(), cache->_lRecvOff.data()); if (cache->_bufferizeSend) { bufferize(ownedData, sharpytype, ownedShape, ownedStride, cache->_rBufferStart.data(), cache->_rBufferSize.data(), ndims, nworkers, cache->_sendRBuff.data()); } auto rwh = tc->alltoall(rSendData, cache->_rSendSize.data(), cache->_rSendOff.data(), sharpytype, rRecvData, cache->_rRecvSize.data(), cache->_rRecvOff.data()); auto wait = [=]() { tc->wait(lwh); std::vector recvBufferStart(nworkers * ndims, 0); if (cache->_bufferizeLRecv) { unpackN(lRecvData, sharpytype, leftHaloShape, leftHaloStride, recvBufferStart.data(), cache->_lRecvBufferSize.data(), ndims, nworkers, leftHaloData); } tc->wait(rwh); if (cache->_bufferizeRRecv) { unpackN(rRecvData, sharpytype, rightHaloShape, rightHaloStride, recvBufferStart.data(), cache->_rRecvBufferSize.data(), ndims, nworkers, rightHaloData); } }; if (cache->_bufferizeLRecv || cache->_bufferizeRRecv || no_async) { wait(); return nullptr; } return mkWaitHandle(std::move(wait)); } /// @brief templated wrapper for typed function versions calling /// _idtr_update_halo template void *_idtr_update_halo(SHARPY::Transceiver *tc, int64_t gShapeRank, void *gShapeDescr, int64_t oOffRank, void *oOffDescr, int64_t oDataRank, void *oDataDescr, int64_t bbOffRank, void *bbOffDescr, int64_t bbShapeRank, void *bbShapeDescr, int64_t lHaloRank, void *lHaloDescr, int64_t rHaloRank, void *rHaloDescr, int64_t key) { if (!gShapeDescr || !oOffDescr || !oDataDescr || !bbOffDescr || !bbShapeDescr || !lHaloDescr || !rHaloDescr) { throw std::invalid_argument( "Fatal error: received nullptr in update_halo."); } auto sharpytype = SHARPY::DTYPE::value; // Construct unranked memrefs for metadata and data MRIdx1d ownedOff(oOffRank, oOffDescr); MRIdx1d bbOff(bbOffRank, bbOffDescr); MRIdx1d bbShape(bbShapeRank, bbShapeDescr); SHARPY::UnrankedMemRefType ownedData(oDataRank, oDataDescr); SHARPY::UnrankedMemRefType leftHalo(lHaloRank, lHaloDescr); SHARPY::UnrankedMemRefType rightHalo(rHaloRank, rHaloDescr); return _idtr_update_halo( sharpytype, ownedData.rank(), ownedOff.data(), ownedData.sizes(), ownedData.strides(), bbOff.data(), bbShape.data(), ownedData.data(), leftHalo.sizes(), leftHalo.strides(), leftHalo.data(), rightHalo.sizes(), rightHalo.strides(), rightHalo.data(), tc, key); } extern "C" { #define TYPED_UPDATE_HALO(_sfx, _typ) \ void *_idtr_update_halo_##_sfx( \ SHARPY::Transceiver *tc, int64_t gShapeRank, void *gShapeDescr, \ int64_t oOffRank, void *oOffDescr, int64_t oDataRank, void *oDataDescr, \ int64_t bbOffRank, void *bbOffDescr, int64_t bbShapeRank, \ void *bbShapeDescr, int64_t lHaloRank, void *lHaloDescr, \ int64_t rHaloRank, void *rHaloDescr, int64_t key) { \ return _idtr_update_halo<_typ>( \ tc, gShapeRank, gShapeDescr, oOffRank, oOffDescr, oDataRank, \ oDataDescr, bbOffRank, bbOffDescr, bbShapeRank, bbShapeDescr, \ lHaloRank, lHaloDescr, rHaloRank, rHaloDescr, key); \ } \ _Pragma(STRINGIFY(weak _mlir_ciface__idtr_update_halo_##_sfx = \ _idtr_update_halo_##_sfx)) TYPED_UPDATE_HALO(f64, double); TYPED_UPDATE_HALO(f32, float); TYPED_UPDATE_HALO(i64, int64_t); TYPED_UPDATE_HALO(i32, int32_t); TYPED_UPDATE_HALO(i16, int16_t); TYPED_UPDATE_HALO(i8, int8_t); TYPED_UPDATE_HALO(i1, bool); } // extern "C"