// SPDX-License-Identifier: BSD-3-Clause /* Random number ops. */ #include "sharpy/Random.hpp" #include "sharpy/Factory.hpp" #include "sharpy/NDArray.hpp" #include namespace SHARPY { using ptr_type = array_i::ptr_type; #if 0 namespace x { template struct Rand { //template static ptr_type op(const shape_type & shp, T lower, T upper) { PVSlice pvslice(shp); shape_type shape(std::move(pvslice.tile_shape())); auto r = operatorx::mk_tx(std::move(pvslice), std::move(xt::random::rand(std::move(shape), lower, upper))); return r; } }; } #endif // if 0 struct DeferredRandomOp : public Deferred { shape_type _shape; double _lower, _upper; DTypeId _dtype; DeferredRandomOp() = default; DeferredRandomOp(const shape_type &shape, double lower, double upper, DTypeId dtype) : _shape(shape), _lower(lower), _upper(upper), _dtype(dtype) {} void run() override { #if 0 switch(_dtype) { case FLOAT64: set_value(std::move(x::Rand::op(_shape, _lower, _upper))); return; case FLOAT32: set_value(std::move(x::Rand::op(_shape, static_cast(_lower), static_cast(_upper)))); return; } throw std::runtime_error("rand: dtype must be a floating point type"); #endif // if 0 } FactoryId factory() const override { return F_RANDOM; } template void serialize(S &ser) { ser.template container(_shape, 8); ser.template value(_lower); ser.template value(_upper); ser.template value(_dtype); } }; FutureArray *Random::rand(DTypeId dtype, const shape_type &shape, const py::object &lower, const py::object &upper) { return new FutureArray(defer( shape, to_native(lower), to_native(upper), dtype)); } void Random::seed(uint64_t s) { // FIXME defer_lambda([s](){xt::random::seed(s); return // array_i::ptr_type();}); } FACTORY_INIT(DeferredRandomOp, F_RANDOM); } // namespace SHARPY