/*
  C++ representation of the array-API's creation functions.
*/

#include "sharpy/Creator.hpp"
#include "sharpy/Deferred.hpp"
#include "sharpy/Factory.hpp"
#include "sharpy/NDArray.hpp"
#include "sharpy/Transceiver.hpp"
#include "sharpy/TypeDispatch.hpp"
#include "sharpy/jit/mlir.hpp"

// NOTE(review): the angle-bracket include paths of this file were unreadable
// in the copy this edit was made from; the MLIR/IMEX list below is a
// reconstruction -- verify the paths against the original file.
#include <imex/Dialect/Dist/IR/DistOps.h>
#include <imex/Dialect/NDArray/IR/NDArrayOps.h>
#include <imex/Utils/PassUtils.h>
#include <mlir/Dialect/Linalg/IR/Linalg.h>
#include <mlir/Dialect/Shape/IR/Shape.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinTypes.h>

#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

namespace SHARPY {

// Result of get_bool_env("SHARPY_FORCE_DIST"), evaluated once during static
// initialization. When true, mkTeam keeps a requested team even if there is
// only a single rank.
static const bool FORCE_DIST = get_bool_env("SHARPY_FORCE_DIST");

// Map the caller-supplied team to the team id stored in new arrays:
// 1 if a team was requested (non-zero) and either FORCE_DIST is set or the
// transceiver reports more than one rank; 0 otherwise.
inline uint64_t mkTeam(uint64_t team) {
  if (team && (FORCE_DIST || getTransceiver()->nranks() > 1)) {
    return 1;
  }
  return 0;
}

// Throw std::invalid_argument if any extent of `shape` is negative.
void validateShape(const shape_type &shape) {
  for (auto &v : shape) {
    if (v < 0) {
      throw std::invalid_argument(
          "invalid shape, negative dimensions are not allowed\n");
    }
  }
}

// ***************************************************************************

// Deferred creation of an array of a given shape, optionally filled with a
// scalar value (an "empty" array when the scalar is None).
struct DeferredFull : public Deferred {
  PyScalar _val;

  DeferredFull() = default;
  DeferredFull(const shape_type &shape, PyScalar val, DTypeId dtype,
               const std::string &device, uint64_t team)
      : Deferred(dtype, shape, device, team), _val(val) {
    validateShape(shape);
  }

  // Type-dispatched helper: reports the NDArray dtype for element type T via
  // `dtyp` and materializes `val` as an MLIR constant of matching width.
  // Returns a null value when `val` is None; throws std::invalid_argument for
  // element types that are neither floating point nor integral.
  template <typename T> struct ValAndDType {
    static ::mlir::Value op(::mlir::OpBuilder &builder,
                            const ::mlir::Location &loc, const PyScalar &val,
                            ::imex::ndarray::DType &dtyp) {
      dtyp = jit::PT_DTYPE<T>::value;

      if (is_none(val)) {
        return {};
      }
      if constexpr (std::is_floating_point_v<T>) {
        return ::imex::createFloat(loc, builder, val._float, sizeof(T) * 8);
      } else if constexpr (std::is_same_v<bool, T>) {
        // bool is integral too; it must be checked first to get width 1
        return ::imex::createInt(loc, builder, val._int, 1);
      } else if constexpr (std::is_integral_v<T>) {
        return ::imex::createInt(loc, builder, val._int, sizeof(T) * 8);
      } else {
        throw std::invalid_argument("Unsupported dtype in dispatch");
      }
    }
  };

  bool generate_mlir(::mlir::OpBuilder &builder, const ::mlir::Location &loc,
                     jit::DepManager &dm) override {
    ::mlir::SmallVector<::mlir::Value> shp(rank());
    for (auto i = 0ul; i < rank(); ++i) {
      shp[i] = ::imex::createIndex(loc, builder, shape()[i]);
    }

    ::imex::ndarray::DType dtyp;
    ::mlir::Value val =
        dispatch<ValAndDType>(_dtype, builder, loc, _val, dtyp);

    auto envs = jit::mkEnvs(builder, rank(), _device, team());

    // The callback wraps the returned memref descriptors into an array via
    // mk_tnsr and publishes it as this future's value. It captures `this`,
    // so it relies on the Deferred outliving the DepManager callback.
    dm.addVal(
        this->guid(),
        builder.create<::imex::ndarray::CreateOp>(loc, shp, dtyp, val, envs),
        [this](uint64_t rank, void *l_allocated, void *l_aligned,
               intptr_t l_offset, const intptr_t *l_sizes,
               const intptr_t *l_strides, void *o_allocated, void *o_aligned,
               intptr_t o_offset, const intptr_t *o_sizes,
               const intptr_t *o_strides, void *r_allocated, void *r_aligned,
               intptr_t r_offset, const intptr_t *r_sizes,
               const intptr_t *r_strides, std::vector<int64_t> &&loffs) {
          assert(rank == this->rank());
          this->set_value(mk_tnsr(
              this->guid(), _dtype, this->shape(), this->device(),
              this->team(), l_allocated, l_aligned, l_offset, l_sizes,
              l_strides, o_allocated, o_aligned, o_offset, o_sizes, o_strides,
              r_allocated, r_aligned, r_offset, r_sizes, r_strides,
              std::move(loffs)));
        });
    return false;
  }

  FactoryId factory() const override { return F_FULL; }

  template <typename S> void serialize(S &ser) {
    // NOTE(review): the shape is not serialized here; confirm the Deferred
    // base (or the receiver) restores it.
    // ser.template container<sizeof(shape_type::value_type)>(_shape, 8);
    ser.template value<sizeof(_val._int)>(_val._int);
    ser.template value<sizeof(_dtype)>(_dtype);
  }
};

// Create a (deferred) array of `shape` filled with `val` converted to `dtype`.
FutureArray *Creator::full(const shape_type &shape, const py::object &val,
                           DTypeId dtype, const std::string &device,
                           uint64_t team) {
  auto v = mk_scalar(val, dtype);
  return new FutureArray(
      defer<DeferredFull>(shape, v, dtype, device, mkTeam(team)));
}

// ***************************************************************************

// Number of elements of arange(start, end, step), i.e. ceil((end - start) /
// step). The arguments arrive as uint64_t but are interpreted as
// two's-complement int64 values so that a negative step can be expressed.
// All validation happens here, before the length is handed to the Deferred
// base-class constructor, so an invalid step never reaches a division.
// Throws std::invalid_argument for inconsistent start/end/step or step == 0.
static int64_t arangeLength(uint64_t start, uint64_t end, uint64_t step) {
  const auto sStart = static_cast<int64_t>(start);
  const auto sEnd = static_cast<int64_t>(end);
  const auto sStep = static_cast<int64_t>(step);

  if (sStart > sEnd && sStep > -1) {
    throw std::invalid_argument("start > end and step > -1 in arange");
  }
  if (sStart < sEnd && sStep < 1) {
    throw std::invalid_argument("start < end and step < 1 in arange");
  }
  if (sStep == 0) {
    // only reachable for start == end; previously a division by zero
    throw std::invalid_argument("step must not be 0 in arange");
  }

  // |end - start| and |step| in unsigned arithmetic: exact for any pair of
  // int64 values and free of signed-overflow UB.
  const uint64_t span = sStart < sEnd ? end - start : start - end;
  const uint64_t stride = sStep < 0 ? uint64_t{0} - step : step;
  const uint64_t count = span / stride + (span % stride != 0 ? 1 : 0);
  if (count > static_cast<uint64_t>(INT64_MAX)) {
    throw std::invalid_argument("too many elements in arange");
  }
  return static_cast<int64_t>(count);
}

// Deferred 1-d array with values start, start + step, ... (end excluded).
struct DeferredArange : public Deferred {
  uint64_t _start, _end, _step;

  DeferredArange() = default;
  DeferredArange(uint64_t start, uint64_t end, uint64_t step, DTypeId dtype,
                 const std::string &device, uint64_t team)
      : Deferred(dtype,
                 {static_cast<shape_type::value_type>(
                     arangeLength(start, end, step))},
                 device, team),
        _start(start), _end(end), _step(step) {}

  bool generate_mlir(::mlir::OpBuilder &builder, const ::mlir::Location &loc,
                     jit::DepManager &dm) override {
    const auto count = shape()[0];
    // arange is lowered as linspace(start, start + count * step, count,
    // endpoint=false). The stop value uses wrapping uint64 arithmetic and is
    // then reinterpreted as int64, consistent with arangeLength.
    const auto first = static_cast<int64_t>(_start);
    const auto last =
        static_cast<int64_t>(_start + static_cast<uint64_t>(count) * _step);
    auto start = ::imex::createFloat(loc, builder, static_cast<double>(first));
    auto stop = ::imex::createFloat(loc, builder, static_cast<double>(last));
    auto num = ::imex::createIndex(loc, builder, count);
    auto dtyp = jit::getPTDType(dtype());
    auto envs = jit::mkEnvs(builder, rank(), _device, team());

    dm.addVal(
        this->guid(),
        builder.create<::imex::ndarray::LinSpaceOp>(loc, start, stop, num,
                                                    false, dtyp, envs),
        [this](uint64_t rank, void *l_allocated, void *l_aligned,
               intptr_t l_offset, const intptr_t *l_sizes,
               const intptr_t *l_strides, void *o_allocated, void *o_aligned,
               intptr_t o_offset, const intptr_t *o_sizes,
               const intptr_t *o_strides, void *r_allocated, void *r_aligned,
               intptr_t r_offset, const intptr_t *r_sizes,
               const intptr_t *r_strides, std::vector<int64_t> &&loffs) {
          assert(rank == 1);
          assert(o_strides[0] == 1);
          this->set_value(mk_tnsr(
              this->guid(), _dtype, this->shape(), this->device(),
              this->team(), l_allocated, l_aligned, l_offset, l_sizes,
              l_strides, o_allocated, o_aligned, o_offset, o_sizes, o_strides,
              r_allocated, r_aligned, r_offset, r_sizes, r_strides,
              std::move(loffs)));
        });
    return false;
  }

  FactoryId factory() const override { return F_ARANGE; }

  template <typename S> void serialize(S &ser) {
    ser.template value<sizeof(_start)>(_start);
    ser.template value<sizeof(_end)>(_end);
    ser.template value<sizeof(_step)>(_step);
  }
};

// Create a (deferred) 1-d array [start, end) with stride `step`.
// Throws std::invalid_argument for inconsistent arguments (see arangeLength).
FutureArray *Creator::arange(uint64_t start, uint64_t end, uint64_t step,
                             DTypeId dtype, const std::string &device,
                             uint64_t team) {
  return new FutureArray(defer<DeferredArange>(start, end, step, dtype, device,
                                               mkTeam(team)));
}

// ***************************************************************************

// Deferred 1-d array of `num` evenly spaced values from start to end
// (end included iff `endpoint`).
struct DeferredLinspace : public Deferred {
  double _start, _end;
  uint64_t _num;
  bool _endpoint;

  DeferredLinspace() = default;
  DeferredLinspace(double start, double end, uint64_t num, bool endpoint,
                   DTypeId dtype, const std::string &device, uint64_t team)
      : Deferred(dtype, {static_cast<shape_type::value_type>(num)}, device,
                 team),
        _start(start), _end(end), _num(num), _endpoint(endpoint) {}

  bool generate_mlir(::mlir::OpBuilder &builder, const ::mlir::Location &loc,
                     jit::DepManager &dm) override {
    auto start = ::imex::createFloat(loc, builder, _start);
    auto stop = ::imex::createFloat(loc, builder, _end);
    auto num = ::imex::createIndex(loc, builder, _num);
    auto dtyp = jit::getPTDType(dtype());
    auto envs = jit::mkEnvs(builder, rank(), _device, team());

    dm.addVal(
        this->guid(),
        builder.create<::imex::ndarray::LinSpaceOp>(loc, start, stop, num,
                                                    _endpoint, dtyp, envs),
        [this](uint64_t rank, void *l_allocated, void *l_aligned,
               intptr_t l_offset, const intptr_t *l_sizes,
               const intptr_t *l_strides, void *o_allocated, void *o_aligned,
               intptr_t o_offset, const intptr_t *o_sizes,
               const intptr_t *o_strides, void *r_allocated, void *r_aligned,
               intptr_t r_offset, const intptr_t *r_sizes,
               const intptr_t *r_strides, std::vector<int64_t> &&loffs) {
          assert(rank == 1);
          // NOTE(review): DeferredArange asserts on o_strides here; confirm
          // which descriptor is meant to be checked.
          assert(l_strides[0] == 1);
          this->set_value(mk_tnsr(
              this->guid(), _dtype, this->shape(), this->device(),
              this->team(), l_allocated, l_aligned, l_offset, l_sizes,
              l_strides, o_allocated, o_aligned, o_offset, o_sizes, o_strides,
              r_allocated, r_aligned, r_offset, r_sizes, r_strides,
              std::move(loffs)));
        });
    return false;
  }

  // Must agree with the id registered via FACTORY_INIT(DeferredLinspace, ...)
  // below (previously returned F_ARANGE by mistake).
  FactoryId factory() const override { return F_LINSPACE; }

  template <typename S> void serialize(S &ser) {
    ser.template value<sizeof(_start)>(_start);
    ser.template value<sizeof(_end)>(_end);
    ser.template value<sizeof(_num)>(_num);
    ser.template value<sizeof(_endpoint)>(_endpoint);
  }
};

// Create a (deferred) 1-d array of `num` evenly spaced values in
// [start, end] (or [start, end) when `endpoint` is false).
FutureArray *Creator::linspace(double start, double end, uint64_t num,
                               bool endpoint, DTypeId dtype,
                               const std::string &device, uint64_t team) {
  return new FutureArray(defer<DeferredLinspace>(start, end, num, endpoint,
                                                 dtype, device, mkTeam(team)));
}

// ***************************************************************************

extern DTypeId DEFAULT_FLOAT;
extern DTypeId DEFAULT_INT;

// Turn an operand into a FutureArray.
// - An existing FutureArray is returned as-is with `false`.
// - A Python float or int is wrapped into a new 0-d array via Creator::full,
//   returned with `true` (the array was allocated by this call).
// Anything else throws std::invalid_argument.
std::pair<FutureArray *, bool> Creator::mk_future(const py::object &b,
                                                  const std::string &device,
                                                  uint64_t team,
                                                  DTypeId dtype) {
  if (py::isinstance<FutureArray>(b)) {
    return {b.cast<FutureArray *>(), false};
  }
  if (py::isinstance<py::float_>(b) || py::isinstance<py::int_>(b)) {
    return {Creator::full({}, b, dtype, device, team), true};
  }
  throw std::invalid_argument(
      "Invalid right operand to elementwise binary operation");
}

FACTORY_INIT(DeferredFull, F_FULL);
FACTORY_INIT(DeferredArange, F_ARANGE);
FACTORY_INIT(DeferredLinspace, F_LINSPACE);

} // namespace SHARPY