diff --git a/CMakeLists.txt b/CMakeLists.txt index 786f9078..6f3dca02 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,6 +5,7 @@ project(InfiniTensor C CXX) # Do not change these options in this file. Use cmake.config, cmake -DOPTION=VALUE, or ccmake to specify them. option(USE_CUDA "Support CUDA GPU" OFF) option(USE_BANG "Support BANG MLU" OFF) +option(USE_MKL "Support MKL" OFF) option(USE_BACKTRACE "Print backtrace on exception and segmentation fault" ON) option(USE_PROTOBUF "Serialize and deserialize tensors" ON) option(BUILD_TEST "Build tests" ON) @@ -86,6 +87,11 @@ if(USE_BANG) list (APPEND SRC ${SRC_BANG}) endif() +if(USE_MKL) + file(GLOB_RECURSE SRC_MKL src/mkl/*.cc src/kernels/mkl/*.cc ) + list (APPEND SRC ${SRC_MKL}) +endif() + # Libraries add_library(InfiniTensor SHARED ${SRC}) if(USE_PROTOBUF) @@ -107,6 +113,21 @@ if(USE_BACKTRACE) target_link_libraries(InfiniTensor dw) endif() +if(USE_MKL) + find_package(MKL CONFIG REQUIRED) + target_link_libraries(InfiniTensor $) + set(DNNL_CONFIGURATION "cpu_gomp") + find_package(dnnl CONFIG REQUIRED) + if(dnnl_FOUND) + add_compile_definitions(USE_MKL=1) + include_directories(BEFORE ${dnnl_DIR}/../../../cpu_gomp/include/) + link_directories(${dnnl_DIR}/../../../cpu_gomp/lib) + target_link_libraries(InfiniTensor dnnl) + else() + message(FATAL_ERROR "dnnl library not found") + endif() +endif() + if(USE_CUDA) add_compile_definitions(USE_CUDA=1) # Since enable_language only executes once, rerun cmake is required if CMAKE_CUDA_HOST_COMPILER is wrong @@ -189,6 +210,9 @@ if(BUILD_TEST) if (USE_BANG) build_test(test/kernels/bang/*.cc) endif() + if (USE_MKL) + build_test(test/kernels/mkl/*.cc) + endif() endif() if(BUILD_TEST_PET) build_test(test/pet/*.cc) diff --git a/include/core/mutator.h b/include/core/mutator.h index 0b4ef2fb..5cb7d16a 100644 --- a/include/core/mutator.h +++ b/include/core/mutator.h @@ -12,7 +12,8 @@ class Mutator { Runtime runtime; public: - Mutator(int candidatesLimit, Runtime runtime = CpuRuntimeObj::getInstance()) + Mutator(int candidatesLimit, + Runtime runtime = NativeCpuRuntimeObj::getInstance()) : candidatesLimit(candidatesLimit), runtime(runtime){}; virtual ~Mutator(){}; diff --git a/include/core/runtime.h b/include/core/runtime.h index 10103b4d..8e7be034 100644 --- a/include/core/runtime.h +++ b/include/core/runtime.h @@ -28,7 +28,7 @@ using OpVec = vector; using VType = uint32_t; -enum class Device { CPU = 1, CUDA, BANG }; +enum class Device { CPU = 1, CUDA, BANG, MKL }; /***************** Forward declaration end *****************/ class RuntimeObj : public std::enable_shared_from_this { @@ -53,7 +53,7 @@ class RuntimeObj : public std::enable_shared_from_this { bool profiling = false) const = 0; virtual void *alloc(size_t size) = 0; virtual void dealloc(void *ptr) = 0; - + void prepareAndRun(Graph &graph, bool tune = false, bool profiling = false); /** * @brief Get the execution time of each operator in performance record. No * execution happens.
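(Review note, not part of the patch: the hunks above add Device::MKL and a prepareAndRun entry point. Below is a minimal sketch of how the new MKL runtime is intended to be driven, modeled on the operator-timer code later in this diff; the template arguments, shapes, and variable names are illustrative rather than taken verbatim from the patch.)

    // Sketch: build a small graph on the MKL runtime, tune, and read back perf data.
    Runtime mkl = MklRuntimeObj::getInstance();      // singleton added by this patch
    Graph g = make_ref<GraphObj>(mkl);
    Tensor a = g->addTensor({1, 64, 128}, DataType::Float32);
    Tensor b = g->addTensor({1, 128, 32}, DataType::Float32);
    g->addOp<MatmulObj>(a, b, nullptr);
    g->dataMalloc();
    a->setData(IncrementalGenerator());
    b->setData(IncrementalGenerator());
    mkl->prepareAndRun(g, /*tune=*/true);            // new convenience wrapper around run()
    double ms = mkl->getPerfTime(g);                 // reads recorded kernel times; no execution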
@@ -64,7 +64,9 @@ class RuntimeObj : public std::enable_shared_from_this { */ double getPerfTime(const Graph &graph, bool profiling = false) const; Blob allocBlob(size_t size); - bool isCpu() const { return device == Device::CPU; } + bool isCpu() const { + return device == Device::CPU || device == Device::MKL; + } bool isCuda() const { return device == Device::CUDA; } bool isBang() const { return device == Device::BANG; } void copyBlob(const TensorObj *dst, const TensorObj *src) const; @@ -85,26 +87,33 @@ class RuntimeObj : public std::enable_shared_from_this { class CpuRuntimeObj : public RuntimeObj { public: - CpuRuntimeObj() : RuntimeObj(Device::CPU) {} - static Ref &getInstance() { - static Ref instance = make_ref(); - return instance; - } + CpuRuntimeObj(Device dev) : RuntimeObj(dev) {} void run(const Graph &graph, bool tune = false, bool profiling = false) const override; - void dealloc(void *ptr) override { return free(ptr); }; - - void *alloc(size_t size) override { - return calloc((size + sizeof(uint64_t) - 1) / sizeof(uint64_t), - sizeof(uint64_t)); - }; void copyBlobFromCPU(void *dst, const void *src, size_t bytes) const override; void copyBlobToCPU(void *dst, const void *src, size_t bytes) const override; void copyBlobInsideRuntime(void *dst, const void *src, size_t bytes) const override; +}; + +class NativeCpuRuntimeObj : public CpuRuntimeObj { + public: + NativeCpuRuntimeObj() : CpuRuntimeObj(Device::CPU) {} + + static Ref &getInstance() { + static Ref instance = + make_ref(); + return instance; + } + void dealloc(void *ptr) override { return free(ptr); }; + + void *alloc(size_t size) override { + return calloc((size + sizeof(uint64_t) - 1) / sizeof(uint64_t), + sizeof(uint64_t)); + }; string toString() const override; }; diff --git a/include/core/tensor_base.h b/include/core/tensor_base.h index 61b8d032..3b9bff3f 100644 --- a/include/core/tensor_base.h +++ b/include/core/tensor_base.h @@ -5,7 +5,6 @@ #include "core/runtime.h" namespace infini { - class TensorBaseObj : public Object { public: // enum TensorType { diff --git a/include/mkl/mkl_runtime.h b/include/mkl/mkl_runtime.h new file mode 100644 index 00000000..6cfc7993 --- /dev/null +++ b/include/mkl/mkl_runtime.h @@ -0,0 +1,33 @@ +#pragma once +#include "core/runtime.h" +#include "dnnl.h" +#include "oneapi/dnnl/dnnl.h" +#include "oneapi/dnnl/dnnl.hpp" +#include "oneapi/dnnl/dnnl_types.h" +#include +#include +namespace infini { +// TODO move utility function to alone file +class MklRuntimeObj : public CpuRuntimeObj { + dnnl_engine_t engine; + + public: + MklRuntimeObj(); + static Ref &getInstance() { + static Ref instance = make_ref(); + return instance; + } + + virtual ~MklRuntimeObj(); + void dealloc(void *ptr) override { return mkl_free(ptr); }; + + void *alloc(size_t size) override { + return mkl_calloc((size + sizeof(uint64_t) - 1) / sizeof(uint64_t), + sizeof(uint64_t), 64); + }; + + string toString() const override { return "CPU MKL Runtime"; }; + dnnl::engine getEngine() const { return dnnl::engine(engine, true); } +}; + +} // namespace infini diff --git a/include/mkl/operator_timer.h b/include/mkl/operator_timer.h new file mode 100644 index 00000000..5ef03691 --- /dev/null +++ b/include/mkl/operator_timer.h @@ -0,0 +1,15 @@ +#pragma once +namespace infini { +namespace opTimer { +double getPerfConvMkl(int n, int c, int h, int w, int f, int r, int s, int padh, + int padw, int strideh, int stridew, int dilationh, + int dilationw, int group); + +double getPerfConvTransposed2dMkl(int n, int c, int h, int w, int f, 
int r, + int s, int padh, int padw, int strideh, + int stridew, int dilationh, int dilationw, + int oph, int opw, int group); + +double getPerfMatmulMkl(int b, int m, int n, int k); +} // namespace opTimer +} // namespace infini diff --git a/python/infinitensor/operator_timer.py b/python/infinitensor/operator_timer.py index 52c776fa..e39d9814 100644 --- a/python/infinitensor/operator_timer.py +++ b/python/infinitensor/operator_timer.py @@ -2,14 +2,28 @@ from tokenize import Double import pyinfinitensor # import getPerfConv, getPerfMatmul -def getPerfConv(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, group, name=""): - return pyinfinitensor.getPerfConvCudnn(n, c, h, w, f, r, s, padh, padw, +# FIXME: change API from getPerfOpDevice(...) to getPerfOp(device='dev', ...) +def getPerfConvCuda(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, group, name=""): + return pyinfinitensor.getPerfConvCuda(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, group, name) -def getPerfConvTransposed2dCudnn(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, oph, opw, group): - return pyinfinitensor.getPerfConvTransposed2dCudnn(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, oph, opw, group) +def getPerfConvTransposed2dCuda(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, oph, opw, group): + return pyinfinitensor.getPerfConvTransposed2dCuda(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, oph, opw, group) -def getPerfMatmul(b, m, n, k, name=""): - return pyinfinitensor.getPerfMatmulCublas(b, m, n, k, name) +def getPerfMatmulCuda(b, m, n, k, name=""): + return pyinfinitensor.getPerfMatmulCuda(b, m, n, k, name) + + +def getPerfConvMkl(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, group, name=""): + return pyinfinitensor.getPerfConvMkl(n, c, h, w, f, r, s, padh, padw, + strideh, stridew, dilationh, dilationw, group) + + +def getPerfConvTransposed2dMkl(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, oph, opw, group): + return pyinfinitensor.getPerfConvTransposed2dMkl(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, oph, opw, group) + + +def getPerfMatmulMkl(b, m, n, k, name=""): + return pyinfinitensor.getPerfMatmulMkl(b, m, n, k) diff --git a/src/core/runtime.cc b/src/core/runtime.cc index 28f5ecc8..71ee6bdd 100644 --- a/src/core/runtime.cc +++ b/src/core/runtime.cc @@ -6,6 +6,9 @@ #include #include namespace infini { +void RuntimeObj::prepareAndRun(Graph &graph, bool tune, bool profiling) { + run(graph, tune, profiling); +} void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const { if (!tune && profiling) @@ -159,6 +162,6 @@ void CpuRuntimeObj::copyBlobInsideRuntime(void *dst, const void *src, memcpy(dst, src, bytes); } -string CpuRuntimeObj::toString() const { return "CPU Runtime"; } +string NativeCpuRuntimeObj::toString() const { return "CPU Runtime"; } } // namespace infini diff --git a/src/core/tensor.cc b/src/core/tensor.cc index 362c3e76..cc2de201 100644 --- a/src/core/tensor.cc +++ b/src/core/tensor.cc @@ -3,6 +3,7 @@ #include "core/operator.h" #include "core/runtime.h" #include "utils/dataloader.h" +#include #include namespace infini { @@ -157,7 +158,7 @@ void TensorObj::setData( generator(getRawDataPtr(), size(), dtype); } else { // Create a CPU buffer for the generetor and copy results to the device - auto cpuRuntime = 
CpuRuntimeObj::getInstance(); + auto cpuRuntime = NativeCpuRuntimeObj::getInstance(); size_t nBytes = size() * dtype.getSize(); Blob buffer = cpuRuntime->allocBlob(nBytes); generator(buffer->getPtr(), size(), dtype); @@ -200,5 +201,4 @@ size_t TensorObj::getOffsetByBroadcastOffset(size_t bcOffset, } return getOffsetByPos(pos, shape); } - }; // namespace infini diff --git a/src/cuda/operator_timer.cc b/src/cuda/operator_timer.cc index 34241e27..c011e588 100644 --- a/src/cuda/operator_timer.cc +++ b/src/cuda/operator_timer.cc @@ -17,7 +17,7 @@ double getPerfConvCudnn(int n, int c, int h, int w, int f, int r, int s, // const auto &[n, c, h, w, f, r, s, padh, padw, strideh, stridew, // dilationh, dilationw, group] = // tuple{1, 512, 14, 14, 512, 3, 3, 2, 2, 1, 1, 2, 2, 1}; - Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton Graph gCpu = make_ref(cpu); Runtime cuda = make_ref(); Graph gCuda = make_ref(cuda); @@ -51,7 +51,7 @@ double getPerfConvTransposed2dCudnn(int n, int c, int h, int w, int f, int r, // const auto &[n, c, h, w, f, r, s, padh, padw, strideh, stridew, // dilationh, dilationw, group] = // tuple{1, 512, 14, 14, 512, 3, 3, 2, 2, 1, 1, 2, 2, 1}; - Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton Graph gCpu = make_ref(cpu); Runtime cuda = make_ref(); Graph gCuda = make_ref(cuda); @@ -83,7 +83,7 @@ double getPerfMatmulCublas(int b, int m, int n, int k, const char *name) { // const auto &[n, c, h, w, f, r, s, padh, padw, strideh, stridew, // dilationh, dilationw, group] = // tuple{1, 512, 14, 14, 512, 3, 3, 2, 2, 1, 1, 2, 2, 1}; - Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton Graph gCpu = make_ref(cpu); Runtime cuda = make_ref(); Graph gCuda = make_ref(cuda); @@ -109,4 +109,4 @@ double getPerfMatmulCublas(int b, int m, int n, int k, const char *name) { } } // namespace opTimer -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 7d080548..7414d4d5 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -12,7 +12,9 @@ #include "cuda/cuda_runtime.h" #include "cuda/operator_timer.h" #endif - +#ifdef USE_MKL +#include "mkl/operator_timer.h" +#endif namespace py = pybind11; namespace infini { @@ -27,6 +29,13 @@ void register_operator_timer(py::module &m) { m.def("getPerfConvTransposed2dCudnn", &getPerfConvTransposed2dCudnn); m.def("getPerfMatmulCublas", &getPerfMatmulCublas); #endif + +#ifdef USE_MKL + using namespace opTimer; + m.def("getPerfConvMkl", &getPerfConvMkl); + m.def("getPerfConvTransposed2dMkl", &getPerfConvTransposed2dMkl); + m.def("getPerfMatmulMkl", &getPerfMatmulMkl); +#endif } void export_values(py::module &m) { @@ -149,7 +158,7 @@ static Shape reshape_shape_of(Operator op) { void export_functions(py::module &m) { #define FUNCTION(NAME) def(#NAME, &NAME) - m.def("cpu_runtime", &CpuRuntimeObj::getInstance) + m.def("cpu_runtime", &NativeCpuRuntimeObj::getInstance) #ifdef USE_CUDA .FUNCTION(cuda_runtime) #endif @@ -168,8 +177,8 @@ void init_graph_builder(py::module &m) { using Handler = GraphHandlerObj; py::class_>(m, "Runtime"); - py::class_, RuntimeObj>( - m, "CpuRuntime"); + py::class_, + RuntimeObj>(m, "CpuRuntime"); #ifdef USE_CUDA py::class_, 
RuntimeObj>( m, "CudaRuntime"); @@ -184,7 +193,8 @@ void init_graph_builder(py::module &m) { .def("copyout_int32", &TensorObj::copyout, policy::move) .def("copyout_int64", &TensorObj::copyout, policy::move) .def("has_target", &TensorObj::hasTarget, policy::automatic) - .def("src", &TensorObj::getSource, policy::move); + .def("src", &TensorObj::getSource, policy::move) + .def("printData", &TensorObj::printData, policy::automatic); py::class_>(m, "Operator") .def("op_type", &OperatorObj::getOpType, policy::automatic) .def("inputs", py::overload_cast<>(&OperatorObj::getInputs, py::const_), diff --git a/src/kernels/cuda/batch_norm.cc b/src/kernels/cuda/batch_norm.cc index 35cc78c0..ce1aaf27 100644 --- a/src/kernels/cuda/batch_norm.cc +++ b/src/kernels/cuda/batch_norm.cc @@ -18,8 +18,6 @@ class BatchNormCudnn : public CudaKernelWithoutConfig { void *const biasData = (op->getInputs(4)->getRawDataPtr()); auto dims = op->getInputs(0)->getDims(); - if (dims.size() == 2) - IT_TODO_HALT(); // Only 4D and 5D tensors are supported by // cudnnBatchNormalizationForwardInference IT_ASSERT(dims.size() == 4 || dims.size() == 5); diff --git a/src/kernels/mkl/conv.cc b/src/kernels/mkl/conv.cc new file mode 100644 index 00000000..18cc4ca2 --- /dev/null +++ b/src/kernels/mkl/conv.cc @@ -0,0 +1,237 @@ +#include "operators/conv.h" +#include "core/kernel.h" +#include "mkl/mkl_runtime.h" + +namespace infini { +struct ConvMklPerfRecordObj : public PerfRecordObj { + dnnl::algorithm algo = dnnl::algorithm::convolution_auto; + void to_json(json &j) override { + j["type"] = 1; + j["data"] = std::make_tuple(enum_to_underlying(algo), time); + } + static PerfRecord from_json(const json &j) { + ConvMklPerfRecordObj tmp; + auto [Algo, Time] = j["data"].get>(); + tmp.algo = (dnnl::algorithm)Algo; + tmp.time = Time; + return make_ref(tmp); + } +}; + +using ConvMklPerfRecord = Ref; +class MklConv : public Kernel { + bool createPrimitives( + const Ref &op, const ConvMklPerfRecord &record, + const MklRuntimeObj *context, bool allowEmpty, + std::vector &prims, + std::vector> &primArgs) const { + auto srcData = op->getInputs(0)->getRawDataPtr(); + auto wData = op->getInputs(1)->getRawDataPtr(); + auto dstData = op->getOutput(0)->getRawDataPtr(); + + auto [n, c, h, w, f, r, s] = op->getNCHWFRS(); + auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); + const int cpg = op->getChannelPerGroup(); + + auto oDims = op->getOutput(0)->getDims(); + int oH = oDims[oDims.size() - 2]; + int oW = oDims[oDims.size() - 1]; + + // create user memory that describes data layout in the buffers + auto userSrcMd = + dnnl::memory::desc({n, c, h, w}, dnnl::memory::data_type::f32, + dnnl::memory::format_tag::nchw); + auto userSrcMemory = + dnnl::memory(userSrcMd, context->getEngine(), srcData); + + auto userWMd = + dnnl::memory::desc({f, cpg, r, s}, dnnl::memory::data_type::f32, + dnnl::memory::format_tag::oihw); + auto userWMemory = dnnl::memory(userWMd, context->getEngine(), wData); + auto userDstMd = + dnnl::memory::desc({n, f, oH, oW}, dnnl::memory::data_type::f32, + dnnl::memory::format_tag::nchw); + + // create memory descriptors with layout tag::any, to let convolution + // choose memory format + // Convolution and inner product primitives choose the memory format + // when you create them with the placeholder memory format + // dnnl::memory::format_tag::any for input or output. The memory format + // chosen is based on different circumstances such as hardware and + // convolutional parameters. 
Using the placeholder memory format is the + // recommended practice for convolutions, since they are the most + // compute-intensive operations in most topologies where they are + // present. + auto srcMd = + dnnl::memory::desc({n, c, h, w}, dnnl::memory::data_type::f32, + dnnl::memory::format_tag::any); + auto wMd = + dnnl::memory::desc({f, cpg, r, s}, dnnl::memory::data_type::f32, + dnnl::memory::format_tag::any); + auto dstMd = + dnnl::memory::desc({n, f, oH, oW}, dnnl::memory::data_type::f32, + dnnl::memory::format_tag::any); + + // create convolution descriptor + dnnl::memory::dims strides = {sh, sw}; + dnnl::memory::dims pads = {ph, pw}; + dnnl::memory::dims dilations = {dh - 1, dw - 1}; + auto convDesc = dnnl::convolution_forward::desc( + dnnl::prop_kind::forward_inference, record->algo, srcMd, wMd, dstMd, + strides, dilations, pads, pads); + + dnnl::convolution_forward::primitive_desc primDesc; + + // fused convolution + // The non-intensive operation is added as a post-op attribute to the + // compute intensive primitive descriptor + if (ActType::None != op->getAct()) { + dnnl::algorithm algo; + switch (op->getAct()) { + case ActType::Relu: + algo = dnnl::algorithm::eltwise_relu; + break; + case ActType::Sigmoid: + algo = dnnl::algorithm::eltwise_logsigmoid; + break; + case ActType::Tanh: + algo = dnnl::algorithm::eltwise_tanh; + break; + + default: + IT_ASSERT(0); + } + dnnl::primitive_attr attr; + dnnl::post_ops po; + po.append_eltwise(1.f, algo, 0.f, 0.f); + attr.set_post_ops(po); + + primDesc = dnnl::convolution_forward::primitive_desc( + convDesc, attr, context->getEngine(), allowEmpty); + + } else { + primDesc = dnnl::convolution_forward::primitive_desc( + convDesc, context->getEngine(), allowEmpty); + } + + if (primDesc.get(allowEmpty) == nullptr) + return false; + + // reorder data and weight + auto srcMemory = userSrcMemory; + if (primDesc.src_desc() != userSrcMemory.get_desc()) { + srcMemory = dnnl::memory(primDesc.src_desc(), context->getEngine()); + + prims.push_back(dnnl::reorder(userSrcMemory, srcMemory)); + primArgs.push_back( + {{DNNL_ARG_FROM, userSrcMemory}, {DNNL_ARG_TO, srcMemory}}); + } + + auto wMemory = userWMemory; + if (primDesc.weights_desc() != userWMemory.get_desc()) { + wMemory = + dnnl::memory(primDesc.weights_desc(), context->getEngine()); + + prims.push_back(dnnl::reorder(userWMemory, wMemory)); + primArgs.push_back( + {{DNNL_ARG_FROM, userWMemory}, {DNNL_ARG_TO, wMemory}}); + } + + // Create memory for output + if (primDesc.dst_desc() == userDstMd) { + auto output = dnnl::memory(primDesc.dst_desc(), + context->getEngine(), dstData); + + // create convolution primitivee + prims.push_back(dnnl::convolution_forward(primDesc)); + primArgs.push_back({{DNNL_ARG_SRC, srcMemory}, + {DNNL_ARG_WEIGHTS, wMemory}, + {DNNL_ARG_DST, output}}); + } else { + auto dstMemory = + dnnl::memory(primDesc.dst_desc(), context->getEngine()); + + // create convolution primitivee + prims.push_back(dnnl::convolution_forward(primDesc)); + primArgs.push_back({{DNNL_ARG_SRC, srcMemory}, + {DNNL_ARG_WEIGHTS, wMemory}, + {DNNL_ARG_DST, dstMemory}}); + + auto output = + dnnl::memory(userDstMd, context->getEngine(), dstData); + prims.push_back(dnnl::reorder(dstMemory, output)); + primArgs.push_back( + {{DNNL_ARG_FROM, dstMemory}, {DNNL_ARG_TO, output}}); + } + return true; + } + + void compute(const Operator &_op, const PerfRecord &_record, + const RuntimeObj *_context) const { + auto op = as(_op); + auto context = dynamic_cast(_context); + auto record = as(_record); + + 
dnnl::stream stream(context->getEngine()); + std::vector prims; + std::vector> primArgs; + IT_ASSERT(createPrimitives(op, record, context, true, prims, primArgs)); + + IT_ASSERT(prims.size() == primArgs.size()); + for (size_t i = 0; i < prims.size(); ++i) + prims.at(i).execute(stream, primArgs.at(i)); + stream.wait(); + } + + void compute(const Operator &op, const RuntimeObj *context) const override { + auto record = make_ref(); + compute(op, record, context); + } + + PerfRecord tune(const Operator &_op, + const RuntimeObj *_context) const override { + ConvMklPerfRecordObj ret; + ret.time = std::numeric_limits::max(); + auto context = dynamic_cast(_context); + auto op = as(_op); + + // Try every possible algorithm of convolution + for (auto algo : {dnnl::algorithm::convolution_auto, + dnnl::algorithm::convolution_direct, + dnnl::algorithm::convolution_winograd}) { + ConvMklPerfRecordObj record; + record.algo = algo; + + std::vector prims; + std::vector> primArgs; + if (!createPrimitives(op, make_ref(record), + context, true, prims, primArgs)) + continue; + + IT_ASSERT(prims.size() == primArgs.size()); + dnnl::stream stream(context->getEngine()); + for (size_t i = 0; i < prims.size(); ++i) + prims.at(i).execute(stream, primArgs.at(i)); + stream.wait(); + + record.time = timeit( + [&]() { + for (size_t i = 0; i < prims.size(); ++i) + prims.at(i).execute(stream, primArgs.at(i)); + }, + [&]() { stream.wait(); }); + + // Update the tune result + if (ret.time > record.time) + ret = record; + } + + IT_ASSERT(ret.time < std::numeric_limits::max(), "No valid " + "algorithm " + "found"); + return make_ref(ret); + } +}; +REGISTER_KERNEL(Device::MKL, OpType::Conv, DataType::Float32, MklConv, + "MklConv_CPU_float32"); +} // namespace infini diff --git a/src/kernels/mkl/conv_transposed.cc b/src/kernels/mkl/conv_transposed.cc new file mode 100644 index 00000000..3c45ddd4 --- /dev/null +++ b/src/kernels/mkl/conv_transposed.cc @@ -0,0 +1,250 @@ +#include "core/kernel.h" +#include "mkl/mkl_runtime.h" +#include "operators/conv.h" + +namespace infini { +struct ConvTransposeMklPerfRecordObj : public PerfRecordObj { + dnnl::algorithm algo = dnnl::algorithm::deconvolution_direct; + void to_json(json &j) override { + j["type"] = 1; + j["data"] = std::make_tuple(enum_to_underlying(algo), time); + } + static PerfRecord from_json(const json &j) { + ConvTransposeMklPerfRecordObj tmp; + auto [Algo, Time] = j["data"].get>(); + tmp.algo = (dnnl::algorithm)Algo; + tmp.time = Time; + return make_ref(tmp); + } +}; + +using ConvTransposeMklPerfRecord = Ref; +class MklConvTranspose : public Kernel { + private: + bool createPrimitives( + const Ref &op, + const ConvTransposeMklPerfRecord &record, const MklRuntimeObj *context, + bool allowEmpty, std::vector &prims, + std::vector> &primArgs) const { + auto srcData = op->getInputs(0)->getRawDataPtr(); + auto wData = op->getInputs(1)->getRawDataPtr(); + // FIXME: iohw2iohwData + auto dstData = op->getOutput(0)->getRawDataPtr(); + + auto [n, c, h, w, f, r, s] = op->getNCHWFRS(); + auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); + const int cpg = op->getChannelPerGroup(); + if (cpg != c) + IT_TODO_HALT(); + + auto oDims = op->getOutput(0)->getDims(); + int oH = oDims[oDims.size() - 2]; + int oW = oDims[oDims.size() - 1]; + + // create user memory that describes data layout in the buffers + auto userSrcMd = + dnnl::memory::desc({n, f, h, w}, dnnl::memory::data_type::f32, + dnnl::memory::format_tag::nchw); + auto userSrcMemory = + dnnl::memory(userSrcMd, 
context->getEngine(), srcData); + + // DNNL deconvolution expects the logical order of weights (parameters) + // dimensions to be in order {o, i, h, w}. So need to reorder wData. + // TODO: to make reorder happen only once when inference (because + // weights are fixed). + // TODO: Fix by whj, change memory format tag from oihw to iohw to + // remove extra transpose. Correctness to be confirmed. + auto userWMd = + dnnl::memory::desc({cpg, f, r, s}, dnnl::memory::data_type::f32, + dnnl::memory::format_tag::iohw); + + auto userWMemory = dnnl::memory(userWMd, context->getEngine(), wData); + + auto userDstMd = + dnnl::memory::desc({n, c, oH, oW}, dnnl::memory::data_type::f32, + dnnl::memory::format_tag::nchw); + + // create memory descriptors with layout tag::any, to let convolution + // choose memory format + // Convolution and inner product primitives choose the memory format + // when you create them with the placeholder memory format + // dnnl::memory::format_tag::any for input or output. The memory format + // chosen is based on different circumstances such as hardware and + // convolutional parameters. Using the placeholder memory format is the + // recommended practice for convolutions, since they are the most + // compute-intensive operations in most topologies where they are + // present. + auto srcMd = + dnnl::memory::desc({n, f, h, w}, dnnl::memory::data_type::f32, + dnnl::memory::format_tag::any); + auto wMd = + dnnl::memory::desc({cpg, f, r, s}, dnnl::memory::data_type::f32, + dnnl::memory::format_tag::any); + auto dstMd = + dnnl::memory::desc({n, c, oH, oW}, dnnl::memory::data_type::f32, + dnnl::memory::format_tag::any); + + // create convolution descriptor + dnnl::memory::dims strides = {sh, sw}; + dnnl::memory::dims pads = {ph, pw}; + dnnl::memory::dims dilations = {dh - 1, dw - 1}; + auto deconvDesc = dnnl::deconvolution_forward::desc( + dnnl::prop_kind::forward_inference, record->algo, srcMd, wMd, dstMd, + strides, dilations, pads, pads); + + dnnl::deconvolution_forward::primitive_desc primDesc; + // fused convolution + // The non-intensive operation is added as a post-op attribute to the + // compute intensive primitive descriptor + if (ActType::None != op->getAct()) { + dnnl::algorithm algo; + switch (op->getAct()) { + case ActType::Relu: + algo = dnnl::algorithm::eltwise_relu; + break; + case ActType::Sigmoid: + algo = dnnl::algorithm::eltwise_logsigmoid; + break; + case ActType::Tanh: + algo = dnnl::algorithm::eltwise_tanh; + break; + + default: + IT_ASSERT(0); + } + dnnl::primitive_attr attr; + dnnl::post_ops po; + po.append_eltwise(1.f, algo, 0.f, 0.f); + attr.set_post_ops(po); + + primDesc = dnnl::deconvolution_forward::primitive_desc( + deconvDesc, attr, context->getEngine(), allowEmpty); + + } else { + primDesc = dnnl::deconvolution_forward::primitive_desc( + deconvDesc, context->getEngine(), allowEmpty); + } + + if (primDesc.get(allowEmpty) == nullptr) + return false; + + // reorder data and weight + auto srcMemory = userSrcMemory; + if (primDesc.src_desc() != userSrcMemory.get_desc()) { + srcMemory = dnnl::memory(primDesc.src_desc(), context->getEngine()); + + prims.push_back(dnnl::reorder(userSrcMemory, srcMemory)); + primArgs.push_back( + {{DNNL_ARG_FROM, userSrcMemory}, {DNNL_ARG_TO, srcMemory}}); + } + + auto wMemory = userWMemory; + if (primDesc.weights_desc() != userWMemory.get_desc()) { + wMemory = + dnnl::memory(primDesc.weights_desc(), context->getEngine()); + + prims.push_back(dnnl::reorder(userWMemory, wMemory)); + primArgs.push_back( + {{DNNL_ARG_FROM, 
userWMemory}, {DNNL_ARG_TO, wMemory}}); + } + + if (primDesc.dst_desc() == userDstMd) { + // Create memory for output + auto dstMemory = dnnl::memory(primDesc.dst_desc(), + context->getEngine(), dstData); + + // create convolution primitivee + prims.push_back(dnnl::deconvolution_forward(primDesc)); + primArgs.push_back({{DNNL_ARG_SRC, srcMemory}, + {DNNL_ARG_WEIGHTS, wMemory}, + {DNNL_ARG_DST, dstMemory}}); + } else { + auto dstMemory = + dnnl::memory(primDesc.dst_desc(), context->getEngine()); + + // create convolution primitivee + prims.push_back(dnnl::deconvolution_forward(primDesc)); + primArgs.push_back({{DNNL_ARG_SRC, srcMemory}, + {DNNL_ARG_WEIGHTS, wMemory}, + {DNNL_ARG_DST, dstMemory}}); + + auto output = + dnnl::memory(userDstMd, context->getEngine(), dstData); + + prims.push_back(dnnl::reorder(dstMemory, output)); + primArgs.push_back( + {{DNNL_ARG_FROM, dstMemory}, {DNNL_ARG_TO, output}}); + } + return true; + } + + void compute(const Operator &_op, const PerfRecord &_record, + const RuntimeObj *_context) const { + auto op = as(_op); + auto context = dynamic_cast(_context); + auto record = as(_record); + + dnnl::stream stream(context->getEngine()); + std::vector prims; + std::vector> primArgs; + IT_ASSERT(createPrimitives(op, record, context, true, prims, primArgs)); + + IT_ASSERT(prims.size() == primArgs.size()); + for (size_t i = 0; i < prims.size(); ++i) + prims.at(i).execute(stream, primArgs.at(i)); + stream.wait(); + } + + void compute(const Operator &op, const RuntimeObj *context) const override { + auto record = make_ref(); + compute(op, record, context); + } + + PerfRecord tune(const Operator &_op, + const RuntimeObj *_context) const override { + ConvTransposeMklPerfRecordObj ret; + ret.time = std::numeric_limits::max(); + auto context = dynamic_cast(_context); + auto op = as(_op); + + // Try every possible algorithm of convolution + for (auto algo : {dnnl::algorithm::deconvolution_direct, + dnnl::algorithm::deconvolution_winograd}) { + ConvTransposeMklPerfRecordObj record; + record.algo = algo; + + std::vector prims; + std::vector> primArgs; + if (!createPrimitives( + op, make_ref(record), + context, true, prims, primArgs)) + continue; + + IT_ASSERT(prims.size() == primArgs.size()); + dnnl::stream stream(context->getEngine()); + for (size_t i = 0; i < prims.size(); ++i) + prims.at(i).execute(stream, primArgs.at(i)); + stream.wait(); + + record.time = timeit( + [&]() { + for (size_t i = 0; i < prims.size(); ++i) + prims.at(i).execute(stream, primArgs.at(i)); + }, + [&]() { stream.wait(); }); + + // Update the tune result + if (ret.time > record.time) + ret = record; + } + + IT_ASSERT(ret.time < std::numeric_limits::max(), "No valid " + "algorithm " + "found"); + return make_ref(ret); + } +}; +REGISTER_KERNEL(Device::MKL, OpType::ConvTrans, DataType::Float32, + MklConvTranspose, "MklConvTrans_CPU_float32"); + +} // namespace infini diff --git a/src/kernels/mkl/matmul.cc b/src/kernels/mkl/matmul.cc new file mode 100644 index 00000000..02e6dd53 --- /dev/null +++ b/src/kernels/mkl/matmul.cc @@ -0,0 +1,38 @@ +#include "operators/matmul.h" +#include "core/kernel.h" +#include "mkl/mkl_runtime.h" + +namespace infini { + +template class MklMatmul : public CpuKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *context) const override { + auto op = as(_op); + IT_ASSERT(op->getInputs().size() == 2, "Bias is not supported yet."); + const T *A = op->getInputs(0)->getRawDataPtr(); + const T *B = op->getInputs(1)->getRawDataPtr(); + T *C = 
op->getOutput()->getRawDataPtr(); + IT_ASSERT(op->getAct() == ActType::None); + const int m = op->getM(), n = op->getN(), k = op->getK(), + b = op->getB(); + + auto opA = op->getTransA() ? CblasTrans : CblasNoTrans; + auto opB = op->getTransB() ? CblasTrans : CblasNoTrans; + // lda is always a.col, and ldb is always b.col when row major + const int lda = std::max((opA == CblasNoTrans) ? k : m, 1); + const int ldb = std::max((opB == CblasNoTrans) ? n : k, 1); + const int ldc = std::max(n, 1); + + const float alpha = 1.f, beta = 0.f; + // TODO: Intel MKL ERROR will occur when using cblas_sgemm_batch + for (int i = 0; i < b; ++i) { + cblas_sgemm(CblasRowMajor, opA, opB, m, n, k, alpha, A + m * k * i, + lda, B + k * n * i, ldb, beta, C + m * n * i, ldc); + } + } +}; + +REGISTER_KERNEL(Device::MKL, OpType::Matmul, DataType::Float32, + MklMatmul, "MklMatmul_CPU_float32"); + +} // namespace infini diff --git a/src/mkl/mkl_runtime.cc b/src/mkl/mkl_runtime.cc new file mode 100644 index 00000000..6b868f70 --- /dev/null +++ b/src/mkl/mkl_runtime.cc @@ -0,0 +1,13 @@ +#include "mkl/mkl_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +namespace infini { +MklRuntimeObj::MklRuntimeObj() : CpuRuntimeObj(Device::MKL) { + dnnl_engine_create(&engine, dnnl_engine_kind_t::dnnl_cpu, 0); +} + +MklRuntimeObj::~MklRuntimeObj() { + mkl_free_buffers(); + dnnl_engine_destroy(engine); +} +} // namespace infini diff --git a/src/mkl/operator_timer.cc b/src/mkl/operator_timer.cc new file mode 100644 index 00000000..c6f1c55d --- /dev/null +++ b/src/mkl/operator_timer.cc @@ -0,0 +1,82 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "mkl/mkl_runtime.h" +#include "operators/conv.h" +#include "operators/matmul.h" +#include "utils/data_generator.h" + +namespace infini { +namespace opTimer { + +double getPerfConvMkl(int n, int c, int h, int w, int f, int r, int s, int padh, + int padw, int strideh, int stridew, int dilationh, + int dilationw, int group) { + // const auto &[n, c, h, w, f, r, s, padh, padw, strideh, stridew, + // dilationh, dilationw, group] = + // tuple{1, 512, 14, 14, 512, 3, 3, 2, 2, 1, 1, 2, 2, 1}; + Runtime runtime = MklRuntimeObj::getInstance(); // CPUruntime is singleton + Graph g = make_ref(runtime); + IT_ASSERT(c % group == 0); + Tensor i0 = g->addTensor({n, c, h, w}, DataType::Float32); + Tensor w0 = g->addTensor({f, c / group, r, s}, DataType::Float32); + auto conv = g->addOp(i0, w0, nullptr, padh, padw, strideh, stridew, + dilationh, dilationw); + // Malloc data for all tensors in a graph. Do we need implicit allocation? 
+ g->dataMalloc(); + i0->setData(IncrementalGenerator()); + w0->setData(IncrementalGenerator()); + + bool tune = true; + runtime->run(g, tune); + return runtime->getPerfTime(g); +} + +double getPerfConvTransposed2dMkl(int n, int c, int h, int w, int f, int r, + int s, int padh, int padw, int strideh, + int stridew, int dilationh, int dilationw, + int oph, int opw, int group) { + // const auto &[n, c, h, w, f, r, s, padh, padw, strideh, stridew, + // dilationh, dilationw, group] = + // tuple{1, 512, 14, 14, 512, 3, 3, 2, 2, 1, 1, 2, 2, 1}; + Runtime runtime = MklRuntimeObj::getInstance(); // CPUruntime is singleton + Graph g = make_ref(runtime); + // Set input data on CPU in a CPU Graph + IT_ASSERT(c % group == 0); + Tensor i0 = g->addTensor({n, f, h, w}, DataType::Float32); + Tensor w0 = g->addTensor({f, c / group, r, s}, DataType::Float32); + auto conv = g->addOp(i0, w0, nullptr, padh, padw, + strideh, stridew, dilationh, + dilationw, oph, opw, group); + // Malloc data for all tensors in a graph. Do we need implicit allocation? + g->dataMalloc(); + i0->setData(IncrementalGenerator()); + w0->setData(IncrementalGenerator()); + + bool tune = true; + runtime->run(g, tune); + return runtime->getPerfTime(g); +} + +double getPerfMatmulMkl(int b, int m, int n, int k) { + // const auto &[n, c, h, w, f, r, s, padh, padw, strideh, stridew, + // dilationh, dilationw, group] = + // tuple{1, 512, 14, 14, 512, 3, 3, 2, 2, 1, 1, 2, 2, 1}; + Runtime runtime = MklRuntimeObj::getInstance(); // CPUruntime is singleton + Graph g = make_ref(runtime); + // Set input data on CPU in a CPU Graph + Tensor i0 = g->addTensor({b, m, k}, DataType::Float32); + Tensor w0 = g->addTensor({b, k, n}, DataType::Float32); + auto conv = g->addOp(i0, w0, nullptr); + // Malloc data for all tensors in a graph. Do we need implicit allocation? + g->dataMalloc(); + i0->setData(IncrementalGenerator()); + w0->setData(IncrementalGenerator()); + + bool tune = true; + runtime->run(g, tune); + return runtime->getPerfTime(g); +} + +} // namespace opTimer +} // namespace infini diff --git a/src/nnet/nmutator.cc b/src/nnet/nmutator.cc index 683f8039..b26806c3 100644 --- a/src/nnet/nmutator.cc +++ b/src/nnet/nmutator.cc @@ -607,4 +607,4 @@ double NMutator::memboundTime(const Shape &dims) { // return graph; // } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/src/operators/gather.cc b/src/operators/gather.cc index 2c1cd57f..afb4996e 100644 --- a/src/operators/gather.cc +++ b/src/operators/gather.cc @@ -38,7 +38,7 @@ bool GatherObj::CheckIndexValid() const { if (index->getDataBlob() == nullptr) return true; - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); int *data = (int *)runtime->alloc(index->getBytes()); index->getRuntime()->copyBlobToCPU( (void *)data, index->getRawDataPtr(), index->getBytes()); diff --git a/src/operators/resize.cc b/src/operators/resize.cc index 2826af15..5270abd2 100644 --- a/src/operators/resize.cc +++ b/src/operators/resize.cc @@ -57,7 +57,7 @@ void ResizeObj::init(const Tensor &input, const Tensor &sizes, this->roi.emplace_back(1); } - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); std::shared_ptr dataObj((float *)runtime->alloc(roi->getBytes()), [&](float *p) { runtime->dealloc(p); }); auto data = dataObj.get(); @@ -117,7 +117,7 @@ void ResizeObj::InitBySizes(Tensor input, Tensor sizes, // copy sizes data to host. 
IT_ASSERT(sizes->getDataBlob() != nullptr); - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); std::shared_ptr dataObj((int *)runtime->alloc(sizes->getBytes()), [&](int *p) { runtime->dealloc(p); }); auto data = dataObj.get(); @@ -166,7 +166,7 @@ void ResizeObj::InitByScales(Tensor input, Tensor scales, // copy scales data to host. IT_ASSERT(scales->getDataBlob() != nullptr); - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); std::shared_ptr dataObj((float *)runtime->alloc(scales->getBytes()), [&](float *p) { runtime->dealloc(p); }); auto data = dataObj.get(); diff --git a/test/core/test_graph.cc b/test/core/test_graph.cc index 85c012b9..c2b1ff4c 100644 --- a/test/core/test_graph.cc +++ b/test/core/test_graph.cc @@ -9,7 +9,7 @@ namespace infini { TEST(Graph, build_and_run) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph g = make_ref(runtime); Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32); Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32); @@ -38,7 +38,7 @@ TEST(Graph, build_and_run) { } TEST(Graph, topological) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph g = make_ref(runtime); Tensor a = g->addTensor({1, 2, 3}, DataType::UInt32); Tensor b = g->addTensor({1, 2, 3}, DataType::UInt32); @@ -77,7 +77,7 @@ TEST(Graph, topological) { } // namespace infini TEST(Graph, perf_engine) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph g = make_ref(runtime); Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32); Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32); @@ -99,7 +99,7 @@ TEST(Graph, perf_engine) { } TEST(Graph, test_tensor_id) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph g = make_ref(runtime); Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32); Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32); @@ -117,7 +117,7 @@ TEST(Graph, test_tensor_id) { } TEST(Graph, test_OpVec_ctor) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph g = make_ref(runtime); Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32); Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32); diff --git a/test/core/test_graph_handler.cc b/test/core/test_graph_handler.cc index cac81dab..b5dce89b 100644 --- a/test/core/test_graph_handler.cc +++ b/test/core/test_graph_handler.cc @@ -5,7 +5,7 @@ namespace infini { TEST(Handler, matmul) { - auto runtime = CpuRuntimeObj::getInstance(); + auto runtime = NativeCpuRuntimeObj::getInstance(); auto handler = make_ref(runtime); auto i = handler->tensor({1, 2, 3}, OnnxDType::UINT32); auto w = handler->tensor({1, 3, 4}, OnnxDType::UINT32); diff --git a/test/core/test_search.cc b/test/core/test_search.cc index 5f354fdb..a514f63f 100644 --- a/test/core/test_search.cc +++ b/test/core/test_search.cc @@ -13,7 +13,7 @@ namespace infini { // TEST(Graph, search) { -// Runtime runtime = CpuRuntimeObj::getInstance(); +// Runtime runtime = NativeCpuRuntimeObj::getInstance(); // Graph g = make_ref(runtime); // Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32); // Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32); @@ -30,7 +30,7 @@ namespace infini { // } TEST(Graph, search_withdm) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime 
runtime = NativeCpuRuntimeObj::getInstance(); Graph g = make_ref(runtime); Tensor t0 = g->addTensor({1, 3, 224, 224}); Tensor w0 = g->addTensor({3, 3, 3, 3}); @@ -53,7 +53,7 @@ TEST(Graph, search_withdm) { } // TEST(DummyMutator, run) { -// Runtime runtime = CpuRuntimeObj::getInstance(); +// Runtime runtime = NativeCpuRuntimeObj::getInstance(); // Graph g = make_ref(runtime); // Tensor i0 = g->addTensor({1, 3, 224, 224}); // Tensor w0 = g->addTensor({2, 3, 3, 3}); @@ -67,7 +67,7 @@ TEST(Graph, search_withdm) { // } // TEST(DummyMutator, fuse) { -// Runtime runtime = CpuRuntimeObj::getInstance(); +// Runtime runtime = NativeCpuRuntimeObj::getInstance(); // Graph g = make_ref(runtime); // Tensor i0 = g->addTensor({1, 2, 3}); // Tensor w0 = g->addTensor({1, 3, 4}); diff --git a/test/core/test_tensor_save.cc b/test/core/test_tensor_save.cc index 086e6455..f5a84618 100644 --- a/test/core/test_tensor_save.cc +++ b/test/core/test_tensor_save.cc @@ -7,7 +7,7 @@ namespace infini { TEST(Prtotbuf, save_and_load) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph g = make_ref(runtime); Tensor i0 = g->addTensor({1, 3, 4}, DataType::Float32); Tensor w0 = g->addTensor({1, 3, 4}, DataType::Float32); diff --git a/test/kernels/bang/test_bang_bangcKernel.cc b/test/kernels/bang/test_bang_bangcKernel.cc index f29a0e50..8c4e62e7 100644 --- a/test/kernels/bang/test_bang_bangcKernel.cc +++ b/test/kernels/bang/test_bang_bangcKernel.cc @@ -15,7 +15,7 @@ void testBangcKernel( const std::function &generator, const Shape &shape) { // Runtime - Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); auto bangRuntime = make_ref(); // Build input data on CPU diff --git a/test/kernels/bang/test_bang_element_wise.cc b/test/kernels/bang/test_bang_element_wise.cc index b7629786..7dc6ac3e 100644 --- a/test/kernels/bang/test_bang_element_wise.cc +++ b/test/kernels/bang/test_bang_element_wise.cc @@ -13,7 +13,7 @@ template void testElementWiseCnnl( const std::function &generator, const Shape &shape, const ExpectOutput &ansVec) { - Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); auto bangRuntime = make_ref(); // Build input data on CPU diff --git a/test/kernels/bang/test_bang_optensor.cc b/test/kernels/bang/test_bang_optensor.cc index 9e42f461..436ab6dc 100644 --- a/test/kernels/bang/test_bang_optensor.cc +++ b/test/kernels/bang/test_bang_optensor.cc @@ -13,7 +13,7 @@ void testOptensor( const std::function &generator, const Shape &shape) { // Runtime - Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); auto bangRuntime = make_ref(); // Build input data on CPU diff --git a/test/kernels/cuda/test_cuda_G2BMM.cc b/test/kernels/cuda/test_cuda_G2BMM.cc index 24ee5a4d..e92ec6a2 100644 --- a/test/kernels/cuda/test_cuda_G2BMM.cc +++ b/test/kernels/cuda/test_cuda_G2BMM.cc @@ -13,7 +13,7 @@ using ExpectOutput = vector; TEST(CUDA_G2BMM, ShapeInference) { const int bs = 1, seqlen = 10000, w = 1000, featlen = 512, heads = 8, d = 4; const int hidden = featlen, hiddenPerHead = hidden / heads; - auto cpuRuntime = CpuRuntimeObj::getInstance(); + auto cpuRuntime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(cpuRuntime); auto ACpu = gCpu->addTensor(Shape{bs * heads, seqlen, hiddenPerHead}, DataType::Float32); diff --git a/test/kernels/cuda/test_cuda_GBMM.cc b/test/kernels/cuda/test_cuda_GBMM.cc index 
99d96388..ea6dd3ce 100644 --- a/test/kernels/cuda/test_cuda_GBMM.cc +++ b/test/kernels/cuda/test_cuda_GBMM.cc @@ -12,7 +12,7 @@ using ExpectOutput = vector; TEST(CUDA_GBMM, ShapeInference) { const int bs = 1, seqlen = 10000, w = 1000, featlen = 512, heads = 8, d = 4; const int hidden = featlen, hiddenPerHead = hidden / heads; - auto cpuRuntime = CpuRuntimeObj::getInstance(); + auto cpuRuntime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(cpuRuntime); auto ACpu = gCpu->addTensor(Shape{bs * heads, seqlen, w * 2 + 1}, DataType::Float32); diff --git a/test/kernels/cuda/test_cuda_batch_norm.cc b/test/kernels/cuda/test_cuda_batch_norm.cc index 0f92710d..cf700e4a 100644 --- a/test/kernels/cuda/test_cuda_batch_norm.cc +++ b/test/kernels/cuda/test_cuda_batch_norm.cc @@ -8,16 +8,16 @@ namespace infini { TEST(CUDA_BatchNorm, run) { - Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); auto cudaRuntime = make_ref(); // Build cpu graph Graph gCpu = make_ref(cpuRuntime); auto iCpu = gCpu->addTensor(Shape{1, 3, 2, 2}, DataType::Float32); - auto meanCpu = gCpu->addTensor(Shape{1, 3, 1, 1}, DataType::Float32); - auto varCpu = gCpu->addTensor(Shape{1, 3, 1, 1}, DataType::Float32); - auto scaleCpu = gCpu->addTensor(Shape{1, 3, 1, 1}, DataType::Float32); - auto biasCpu = gCpu->addTensor(Shape{1, 3, 1, 1}, DataType::Float32); + auto meanCpu = gCpu->addTensor(Shape{3}, DataType::Float32); + auto varCpu = gCpu->addTensor(Shape{3}, DataType::Float32); + auto scaleCpu = gCpu->addTensor(Shape{3}, DataType::Float32); + auto biasCpu = gCpu->addTensor(Shape{3}, DataType::Float32); // Build input data on CPU gCpu->dataMalloc(); diff --git a/test/kernels/cuda/test_cuda_concat.cc b/test/kernels/cuda/test_cuda_concat.cc index 12d937c9..41832e82 100644 --- a/test/kernels/cuda/test_cuda_concat.cc +++ b/test/kernels/cuda/test_cuda_concat.cc @@ -44,7 +44,7 @@ TEST(Concat, OffsetTrans) { } */ TEST(Concat, Cuda) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto t1 = gCpu->addTensor({2, 2, 3, 1}, DataType::Float32); diff --git a/test/kernels/cuda/test_cuda_conv.cc b/test/kernels/cuda/test_cuda_conv.cc index 2c0a6419..657ecd17 100644 --- a/test/kernels/cuda/test_cuda_conv.cc +++ b/test/kernels/cuda/test_cuda_conv.cc @@ -13,7 +13,7 @@ void testConvCudnn( const std::function &generator, vector ansVec) { // Construct Runtime and graph for CPU and CUDA - Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton Graph gCpu = make_ref(cpu); Runtime cuda = make_ref(); Graph gCuda = make_ref(cuda); @@ -52,7 +52,7 @@ TEST(cuDNN_Conv, run) { } TEST(cuDNN_Conv, tune) { - Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton Graph gCpu = make_ref(cpu); Runtime cuda = make_ref(); Graph gCuda = make_ref(cuda); diff --git a/test/kernels/cuda/test_cuda_conv_transposed_2d.cc b/test/kernels/cuda/test_cuda_conv_transposed_2d.cc index 692c8cb1..9aef0f39 100644 --- a/test/kernels/cuda/test_cuda_conv_transposed_2d.cc +++ b/test/kernels/cuda/test_cuda_conv_transposed_2d.cc @@ -16,7 +16,7 @@ void testConvTransposedCudnn( const auto &[N, C, H, W, F, R, S] = tuple{1, 1, 2, 2, 1, 4, 4}; const int stride = 1, padding = 0, dilation = 1; // Construct Runtime and graph for CPU and CUDA - Runtime cpu = 
CpuRuntimeObj::getInstance(); // CPUruntime is singleton + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton Graph gCpu = make_ref(cpu); Runtime cuda = make_ref(); Graph gCuda = make_ref(cuda); @@ -50,7 +50,7 @@ void testConvTransposedNHWCCudnn( const auto &[N, C, H, W, F, R, S] = tuple{1, 1, 2, 2, 2, 4, 4}; const int stride = 1, padding = 0, dilation = 1; // Construct Runtime and graph for CPU and CUDA - Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton Graph gCpu = make_ref(cpu); Runtime cuda = make_ref(); Graph gCuda = make_ref(cuda); @@ -94,8 +94,42 @@ TEST(cuDNN_ConvTransposedNHWC, run) { 465, 487, 509, 307}); } +TEST(cuDNN_ConvTransposed, run1) { + // Construct Runtime and graph for CPU and CUDA + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton + Graph gCpu = make_ref(cpu); + Runtime cuda = make_ref(); + Graph gCuda = make_ref(cuda); + // Set input data on CPU in a CPU Graph + Tensor i0Cpu = gCpu->addTensor({1, 2, 3, 3}, DataType::Float32); + Tensor w0Cpu = gCpu->addTensor({2, 2, 3, 3}, DataType::Float32); + // Malloc data for all tensors in a graph. Do we need implicit allocation? + gCpu->dataMalloc(); + i0Cpu->setData(IncrementalGenerator()); + w0Cpu->setData(IncrementalGenerator()); + + // Copy input tensors from CPU to CUDA + Tensor i0Cuda = gCuda->cloneTensor(i0Cpu); + Tensor w0Cuda = gCuda->cloneTensor(w0Cpu); + // Build CUDA graph + auto conv = + gCuda->addOp(i0Cuda, w0Cuda, nullptr, 0, 0); + gCuda->dataMalloc(); + // Execute on CUDA + cuda->run(gCuda); + // copy output from CUDA to CPU + auto o0Cpu = gCpu->cloneTensor(conv->getOutput()); + // check results on CPU + EXPECT_TRUE(o0Cpu->equalData(vector{ + 162, 351, 569, 413, 224, 405, 876, 1417, 1024, 553, + 747, 1611, 2598, 1869, 1005, 639, 1368, 2191, 1564, 835, + 396, 843, 1343, 953, 506, 243, 531, 866, 629, 341, + 621, 1344, 2173, 1564, 841, 1152, 2475, 3975, 2841, 1518, + 963, 2052, 3271, 2320, 1231, 585, 1239, 1964, 1385, 731})); +} + TEST(cuDNN_ConvTransposed, tune) { - Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton Graph gCpu = make_ref(cpu); Runtime cuda = make_ref(); Graph gCuda = make_ref(cuda); @@ -117,8 +151,6 @@ TEST(cuDNN_ConvTransposed, tune) { // Execute on CUDA bool tune = true; cuda->run(gCuda, tune); - // print a tensor/operator/graph by print() - gCuda->print(); // check record auto kernelAttrs = KernelAttrs{Device::CUDA, conv->getOpType(), DataType::Float32}; diff --git a/test/kernels/cuda/test_cuda_element_wise.cc b/test/kernels/cuda/test_cuda_element_wise.cc index b242af48..a5c04f77 100644 --- a/test/kernels/cuda/test_cuda_element_wise.cc +++ b/test/kernels/cuda/test_cuda_element_wise.cc @@ -14,7 +14,7 @@ template void testElementWiseCudnn( const std::function &generator, const Shape &shape, const ExpectOutput &ansVec) { - Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); auto cudaRuntime = make_ref(); // Build input data on CPU diff --git a/test/kernels/cuda/test_cuda_extend.cc b/test/kernels/cuda/test_cuda_extend.cc index 197246c2..a0f431f3 100644 --- a/test/kernels/cuda/test_cuda_extend.cc +++ b/test/kernels/cuda/test_cuda_extend.cc @@ -10,7 +10,7 @@ namespace infini { TEST(CUDA_Extend, run) { - Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Runtime cpuRuntime = 
NativeCpuRuntimeObj::getInstance(); auto cudaRuntime = make_ref(); // Build input data on CPU diff --git a/test/kernels/cuda/test_cuda_gather.cc b/test/kernels/cuda/test_cuda_gather.cc index 807ea0ff..90620a89 100644 --- a/test/kernels/cuda/test_cuda_gather.cc +++ b/test/kernels/cuda/test_cuda_gather.cc @@ -176,7 +176,7 @@ TEST(Gather, offsetTrans) { TEST(Gather, Cuda) { { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({3, 2}, DataType::Float32); auto index = gCpu->addTensor({2, 2}, DataType::UInt32); @@ -197,7 +197,7 @@ TEST(Gather, Cuda) { EXPECT_TRUE(oCpu->equalData(vector{1, 2, 3, 4, 3, 4, 5, 6})); } { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({3, 3}, DataType::Float32); auto index = gCpu->addTensor({1, 2}, DataType::UInt32); @@ -218,7 +218,7 @@ TEST(Gather, Cuda) { EXPECT_TRUE(oCpu->equalData(vector{0, 2, 3, 5, 6, 8})); } { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({2, 4, 2}, DataType::Float32); auto index = gCpu->addTensor({3, 1}, DataType::UInt32); diff --git a/test/kernels/cuda/test_cuda_matmul.cc b/test/kernels/cuda/test_cuda_matmul.cc index 26d5e3d2..f52fc2f1 100644 --- a/test/kernels/cuda/test_cuda_matmul.cc +++ b/test/kernels/cuda/test_cuda_matmul.cc @@ -16,7 +16,7 @@ void testMatmulCuda( const std::function &generatorB, bool transA, bool transB, const Shape &shapeA, const Shape &shapeB, const ExpectOutput &ansVec) { - auto cpuRuntime = CpuRuntimeObj::getInstance(); + auto cpuRuntime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(cpuRuntime); auto ACpu = gCpu->addTensor(shapeA, DataType::Float32); auto BCpu = gCpu->addTensor(shapeB, DataType::Float32); @@ -54,7 +54,7 @@ TEST(cuBLAS_Matmul, run) { } TEST(cuBLAS_Matmul, tune) { - auto cpuRuntime = CpuRuntimeObj::getInstance(); + auto cpuRuntime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(cpuRuntime); auto ACpu = gCpu->addTensor(Shape{1, 3, 5}, DataType::Float32); auto BCpu = gCpu->addTensor(Shape{1, 5, 2}, DataType::Float32); diff --git a/test/kernels/cuda/test_cuda_pad.cc b/test/kernels/cuda/test_cuda_pad.cc index e157114d..ead88962 100644 --- a/test/kernels/cuda/test_cuda_pad.cc +++ b/test/kernels/cuda/test_cuda_pad.cc @@ -7,7 +7,7 @@ namespace infini { TEST(Pad, Cuda) { - Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); auto cudaRuntime = make_ref(); // Build input data on CPU diff --git a/test/kernels/cuda/test_cuda_pooling.cc b/test/kernels/cuda/test_cuda_pooling.cc index 3f341591..f055a881 100644 --- a/test/kernels/cuda/test_cuda_pooling.cc +++ b/test/kernels/cuda/test_cuda_pooling.cc @@ -14,7 +14,7 @@ void testPoolCudnn( const std::function &generator, const Shape &shape, const KDPS &kdps, const ExpectOutput &ansVec) { EXPECT_TRUE(kdps.size() == 8); - Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); auto cudaRuntime = make_ref(); // Build input data on CPU diff --git a/test/kernels/cuda/test_cuda_reduce_mean.cc b/test/kernels/cuda/test_cuda_reduce_mean.cc index ff309635..830c49c4 100644 --- a/test/kernels/cuda/test_cuda_reduce_mean.cc +++ b/test/kernels/cuda/test_cuda_reduce_mean.cc @@ -12,7 +12,7 @@ namespace 
infini { void test_reducemean(const Shape &shape, const vector &data, const optional> &axis, bool keepDims, const vector &ExpectData) { - Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); auto cudaRuntime = make_ref(); // Build input data on CPU diff --git a/test/kernels/cuda/test_cuda_reshape.cc b/test/kernels/cuda/test_cuda_reshape.cc index 4dae2430..843caa4b 100644 --- a/test/kernels/cuda/test_cuda_reshape.cc +++ b/test/kernels/cuda/test_cuda_reshape.cc @@ -10,7 +10,7 @@ namespace infini { TEST(CUDA_Reshape, run) { - Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); auto cudaRuntime = make_ref(); // Build input data on CPU @@ -39,7 +39,7 @@ TEST(CUDA_Reshape, run) { } TEST(CUDA_Flatten, run) { - Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); auto cudaRuntime = make_ref(); // Build input data on CPU @@ -68,7 +68,7 @@ TEST(CUDA_Flatten, run) { } TEST(CUDA_Identity, run) { - Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); auto cudaRuntime = make_ref(); // Build input data on CPU diff --git a/test/kernels/cuda/test_cuda_resize.cc b/test/kernels/cuda/test_cuda_resize.cc index b462f516..7b096790 100644 --- a/test/kernels/cuda/test_cuda_resize.cc +++ b/test/kernels/cuda/test_cuda_resize.cc @@ -7,7 +7,7 @@ #include "test.h" namespace infini { TEST(Resize, Cuda_downsample_sizes_nearest) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32); @@ -32,7 +32,7 @@ TEST(Resize, Cuda_downsample_sizes_nearest) { } TEST(Resize, Cuda_upsample_sizes_nearest_notlarger) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32); @@ -62,7 +62,7 @@ TEST(Resize, Cuda_upsample_sizes_nearest_notlarger) { } TEST(Resize, Cuda_upsample_sizes_nearest_notsmaller) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32); @@ -92,7 +92,7 @@ TEST(Resize, Cuda_upsample_sizes_nearest_notsmaller) { } TEST(Resize, Cuda_upsample_sizes_nearest_ceil_half_pixel) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32); @@ -125,7 +125,7 @@ TEST(Resize, Cuda_upsample_sizes_nearest_ceil_half_pixel) { } TEST(Resize, Cuda_upsample_sizes_nearest_floor_align_corners) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32); @@ -158,7 +158,7 @@ TEST(Resize, Cuda_upsample_sizes_nearest_floor_align_corners) { } TEST(Resize, Cuda_upsample_sizes_nearest_round_prefer_ceil_asymmetri) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32); @@ -191,7 +191,7 @@ TEST(Resize, Cuda_upsample_sizes_nearest_round_prefer_ceil_asymmetri) { } TEST(Resize, 
Cuda_downsample_scales_nearest) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32); @@ -215,7 +215,7 @@ TEST(Resize, Cuda_downsample_scales_nearest) { } TEST(Resize, Cuda_upsample_scales_nearest) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32); @@ -241,7 +241,7 @@ TEST(Resize, Cuda_upsample_scales_nearest) { } TEST(Resize, Cuda_upsample_scales_nearest_axes_3_2) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32); @@ -267,7 +267,7 @@ TEST(Resize, Cuda_upsample_scales_nearest_axes_3_2) { } TEST(Resize, Cuda_downsample_scales_linear) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32); @@ -291,7 +291,7 @@ TEST(Resize, Cuda_downsample_scales_linear) { } TEST(Resize, Cuda_downsample_scales_linear_aligncorners) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32); @@ -317,7 +317,7 @@ TEST(Resize, Cuda_downsample_scales_linear_aligncorners) { } TEST(Resize, Cuda_upsample_scales_linear) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32); @@ -343,7 +343,7 @@ TEST(Resize, Cuda_upsample_scales_linear) { } TEST(Resize, Cuda_upsample_scales_linear_align_corners) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32); @@ -371,7 +371,7 @@ TEST(Resize, Cuda_upsample_scales_linear_align_corners) { } TEST(Resize, Cuda_downsample_sizes_linear_pytorchhalfpixel) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32); @@ -399,7 +399,7 @@ TEST(Resize, Cuda_downsample_sizes_linear_pytorchhalfpixel) { } TEST(Resize, Cuda_tf_crop_and_resize) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32); @@ -430,7 +430,7 @@ TEST(Resize, Cuda_tf_crop_and_resize) { } TEST(Resize, Cuda_tf_crop_and_resize_axes_3_2) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32); @@ -461,7 +461,7 @@ TEST(Resize, Cuda_tf_crop_and_resize_axes_3_2) { } TEST(Resize, Cuda_downsample_scales_cubic) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32); @@ -488,7 +488,7 @@ TEST(Resize, Cuda_downsample_scales_cubic) { } TEST(Resize, 
Cuda_downsample_scales_cubic_align_corners) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32); @@ -516,7 +516,7 @@ TEST(Resize, Cuda_downsample_scales_cubic_align_corners) { } TEST(Resize, Cuda_upsample_scales_cubic) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32); @@ -553,7 +553,7 @@ TEST(Resize, Cuda_upsample_scales_cubic) { } TEST(Resize, Cuda_upsample_scales_cubic_align_corners) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32); @@ -592,7 +592,7 @@ TEST(Resize, Cuda_upsample_scales_cubic_align_corners) { } TEST(Resize, Cuda_upsample_scales_cubic_asymmetric) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32); @@ -627,7 +627,7 @@ TEST(Resize, Cuda_upsample_scales_cubic_asymmetric) { // TEST(Resize, Cuda_downsample_sizes_cubic) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32); @@ -661,7 +661,7 @@ TEST(Resize, Cuda_downsample_sizes_cubic) { } TEST(Resize, Cuda_upsample_sizes_cubic) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32); diff --git a/test/kernels/cuda/test_cuda_slice.cc b/test/kernels/cuda/test_cuda_slice.cc index 0657a559..14926ea3 100644 --- a/test/kernels/cuda/test_cuda_slice.cc +++ b/test/kernels/cuda/test_cuda_slice.cc @@ -7,7 +7,7 @@ namespace infini { TEST(CUDA_Slice, run) { - Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); auto cudaRuntime = make_ref(); // Build input data on CPU diff --git a/test/kernels/cuda/test_cuda_split.cc b/test/kernels/cuda/test_cuda_split.cc index 484d2ad3..9b68a70c 100644 --- a/test/kernels/cuda/test_cuda_split.cc +++ b/test/kernels/cuda/test_cuda_split.cc @@ -9,7 +9,7 @@ namespace infini { TEST(Split, Cuda) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({2, 10, 2, 1}, DataType::Float32); diff --git a/test/kernels/cuda/test_cuda_unary.cc b/test/kernels/cuda/test_cuda_unary.cc index 8b463121..6aac66d4 100644 --- a/test/kernels/cuda/test_cuda_unary.cc +++ b/test/kernels/cuda/test_cuda_unary.cc @@ -13,7 +13,7 @@ template void testUnary(const std::function &generator, const Shape &shape) { // Runtime - Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); auto cudaRuntime = make_ref(); // Build input data on CPU diff --git a/test/kernels/cuda/test_perfengine.cc b/test/kernels/cuda/test_perfengine.cc index 6230f0a2..7bfac62d 100644 --- a/test/kernels/cuda/test_perfengine.cc +++ b/test/kernels/cuda/test_perfengine.cc @@ -11,7 +11,7 @@ namespace infini { TEST(PerfEngine, save_and_load) { - Runtime cpu = 
CpuRuntimeObj::getInstance(); // CPUruntime is singleton
+    Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
     Graph gCpu = make_ref<GraphObj>(cpu);
     Runtime cuda = make_ref<CudaRuntimeObj>();
     { // Conv
diff --git a/test/kernels/mkl/test_mkl_conv.cc b/test/kernels/mkl/test_mkl_conv.cc
new file mode 100644
index 00000000..4ba5fd7f
--- /dev/null
+++ b/test/kernels/mkl/test_mkl_conv.cc
@@ -0,0 +1,66 @@
+#include "core/graph.h"
+#include "core/kernel.h"
+#include "core/perf_engine.h"
+#include "core/runtime.h"
+#include "mkl/mkl_runtime.h"
+#include "operators/conv.h"
+
+#include "test.h"
+
+namespace infini {
+
+void testConvDnnl(
+    const std::function<void(void *, size_t, DataType)> &generator,
+    vector<float> ansVec) {
+    auto mklRuntime = MklRuntimeObj::getInstance();
+    Graph gMkl = make_ref<GraphObj>(mklRuntime);
+
+    Tensor i0 = gMkl->addTensor({1, 3, 4, 4}, DataType::Float32);
+    Tensor w0 = gMkl->addTensor({2, 3, 3, 3}, DataType::Float32);
+    // Malloc data for all tensors in a graph.
+    gMkl->dataMalloc();
+    i0->setData(generator);
+    w0->setData(generator);
+
+    // Build graph
+    auto conv = gMkl->addOp<ConvObj>(i0, w0, nullptr, 1, 1, 2, 1, 1, 2);
+    // allocate MKL (CPU) memory
+    gMkl->dataMalloc();
+    // Execute on the MKL runtime
+    mklRuntime->run(gMkl);
+    // check results on CPU
+    EXPECT_TRUE(conv->getOutput(0)->equalData(ansVec));
+}
+
+TEST(dnnl_Conv, run) {
+    testConvDnnl(OneGenerator(), vector<float>{12, 12, 18, 18, 12, 12, 18, 18});
+    testConvDnnl(
+        IncrementalGenerator(),
+        vector<float>{4794, 4386, 8199, 7506, 11274, 10542, 20835, 19656});
+}
+
+TEST(mkl_Conv, tune) {
+    auto mklRuntime = MklRuntimeObj::getInstance();
+    Graph gMkl = make_ref<GraphObj>(mklRuntime);
+
+    Tensor i0 = gMkl->addTensor({1, 3, 224, 224}, DataType::Float32);
+    Tensor w0 = gMkl->addTensor({2, 3, 3, 3}, DataType::Float32);
+    auto conv = gMkl->addOp<ConvObj>(i0, w0, nullptr, 1, 1, 1, 1, 1, 1);
+    gMkl->dataMalloc();
+
+    i0->setData(IncrementalGenerator());
+    w0->setData(IncrementalGenerator());
+
+    // Execute on the MKL runtime with tuning
+    bool tune = true;
+    mklRuntime->run(gMkl, tune);
+
+    // check record
+    auto kernelAttrs =
+        KernelAttrs{Device::MKL, conv->getOpType(), DataType::Float32};
+    auto perfKey = PerfEngine::Key{kernelAttrs, conv->getOpPerfKey()};
+    std::optional perfData =
+        PerfEngine::getInstance().getPerfData(perfKey);
+    ASSERT_TRUE(perfData.has_value());
+}
+} // namespace infini
diff --git a/test/kernels/mkl/test_mkl_conv_transposed.cc b/test/kernels/mkl/test_mkl_conv_transposed.cc
new file mode 100644
index 00000000..ab869896
--- /dev/null
+++ b/test/kernels/mkl/test_mkl_conv_transposed.cc
@@ -0,0 +1,84 @@
+#include "core/graph.h"
+#include "core/kernel.h"
+#include "core/perf_engine.h"
+#include "mkl/mkl_runtime.h"
+#include "operators/conv.h"
+
+#include "test.h"
+
+namespace infini {
+
+void testConvTransposedMkl(
+    const std::function<void(void *, size_t, DataType)> &generator,
+    vector<float> ansVec) {
+    const auto &[N, C, H, W, F, R, S] = tuple{1, 1, 2, 2, 1, 4, 4};
+    const int stride = 1, padding = 0, dilation = 1;
+
+    Runtime runtime = MklRuntimeObj::getInstance();
+    Graph gMkl = make_ref<GraphObj>(runtime);
+    // Set input data in the MKL (CPU) graph
+    Tensor i0 = gMkl->addTensor({N, F, H, H}, DataType::Float32);
+    Tensor w0 = gMkl->addTensor({F, C, R, S}, DataType::Float32);
+    auto conv = gMkl->addOp<ConvTransposed2dObj>(
+        i0, w0, nullptr, padding, padding, stride, stride, dilation, dilation);
+
+    gMkl->dataMalloc();
+    i0->setData(generator);
+    w0->setData(generator);
+
+    runtime->prepareAndRun(gMkl);
+    EXPECT_TRUE(conv->getOutput()->equalData(ansVec));
+}
+
+TEST(mkl_ConvTransposed, run) {
+    testConvTransposedMkl(IncrementalGenerator(),
+                          vector<float>{0., 0., 1., 2., 3., 0., 6.,
12., 18., 16., 8., 30., 36., 42., + 32., 16., 54., 60., 66., 48., 24., + 62., 67., 72., 45.}); +} + +TEST(mkl_ConvTransposed, run1) { + Runtime runtime = MklRuntimeObj::getInstance(); + Graph gMkl = make_ref(runtime); + // Set input data on CPU in a CPU Graph + Tensor i0 = gMkl->addTensor({1, 2, 3, 3}, DataType::Float32); + Tensor w0 = gMkl->addTensor({2, 2, 3, 3}, DataType::Float32); + auto conv = gMkl->addOp(i0, w0, nullptr, 0, 0); + + gMkl->dataMalloc(); + i0->setData(IncrementalGenerator()); + w0->setData(IncrementalGenerator()); + + runtime->prepareAndRun(gMkl); + EXPECT_TRUE(conv->getOutput()->equalData(vector{ + 162, 351, 569, 413, 224, 405, 876, 1417, 1024, 553, + 747, 1611, 2598, 1869, 1005, 639, 1368, 2191, 1564, 835, + 396, 843, 1343, 953, 506, 243, 531, 866, 629, 341, + 621, 1344, 2173, 1564, 841, 1152, 2475, 3975, 2841, 1518, + 963, 2052, 3271, 2320, 1231, 585, 1239, 1964, 1385, 731})); +} + +TEST(mkl_ConvTransposed, tune) { + Runtime runtime = MklRuntimeObj::getInstance(); + Graph gMkl = make_ref(runtime); + + Tensor i0 = gMkl->addTensor({1, 448, 2, 2}, DataType::Float32); + Tensor w0 = gMkl->addTensor({448, 256, 4, 4}, DataType::Float32); + auto conv = gMkl->addOp(i0, w0, nullptr); + gMkl->dataMalloc(); + i0->setData(IncrementalGenerator()); + w0->setData(IncrementalGenerator()); + + bool tune = true; + runtime->prepareAndRun(gMkl, tune); + // check record + auto kernelAttrs = + KernelAttrs{Device::MKL, conv->getOpType(), DataType::Float32}; + auto perfKey = PerfEngine::Key{kernelAttrs, conv->getOpPerfKey()}; + std::optional perfData = + PerfEngine::getInstance().getPerfData(perfKey); + ASSERT_TRUE(perfData.has_value()); +} + +} // namespace infini diff --git a/test/kernels/mkl/test_mkl_matmul.cc b/test/kernels/mkl/test_mkl_matmul.cc new file mode 100644 index 00000000..e919ffd4 --- /dev/null +++ b/test/kernels/mkl/test_mkl_matmul.cc @@ -0,0 +1,44 @@ + +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "mkl/mkl_runtime.h" +#include "operators/matmul.h" + +#include "test.h" + +namespace infini { +using ExpectOutput = vector; + +void testMatmulMkl( + const std::function &generatorA, + const std::function &generatorB, + bool transA, bool transB, const Shape &shapeA, const Shape &shapeB, + const ExpectOutput &ansVec) { + auto cpuRuntime = MklRuntimeObj::getInstance(); + Graph gCpu = make_ref(cpuRuntime); + auto ACpu = gCpu->addTensor(shapeA, DataType::Float32); + auto BCpu = gCpu->addTensor(shapeB, DataType::Float32); + gCpu->dataMalloc(); + ACpu->setData(generatorA); + BCpu->setData(generatorB); + + auto matmul = gCpu->addOp(ACpu, BCpu, nullptr, transA, transB); + + gCpu->dataMalloc(); + cpuRuntime->run(gCpu); + matmul->getOutput()->printData(); + EXPECT_TRUE(matmul->getOutput()->equalData(ansVec)); +} + +TEST(mkl_Matmul, run) { + testMatmulMkl(IncrementalGenerator(), OneGenerator(), false, false, + Shape{1, 3, 5}, Shape{1, 5, 2}, + ExpectOutput{10, 10, 35, 35, 60, 60}); + testMatmulMkl(IncrementalGenerator(), IncrementalGenerator(), true, false, + Shape{2, 3, 4}, Shape{2, 3, 2}, + ExpectOutput{40, 52, 46, 61, 52, 70, 58, 79, 400, 448, 424, + 475, 448, 502, 472, 529}); +} + +}; // namespace infini diff --git a/test/nnet/test_memboundOp.cc b/test/nnet/test_memboundOp.cc index 46fb2157..9f1847d6 100644 --- a/test/nnet/test_memboundOp.cc +++ b/test/nnet/test_memboundOp.cc @@ -12,7 +12,7 @@ using namespace infini; using namespace std; TEST(nnet, MemboundOpInterpretation) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = 
NativeCpuRuntimeObj::getInstance(); Graph g = make_ref(runtime); Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32); Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32); @@ -42,7 +42,7 @@ TEST(nnet, MemboundOpInterpretation) { TEST(nnet, MemboundOp_Ansor_Codegen) { auto runtime = make_ref(); - Runtime cpu = CpuRuntimeObj::getInstance(); + Runtime cpu = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(cpu); Graph g = make_ref(runtime); Tensor i0 = g->addTensor({1, 2, 3}, DataType::Float32); @@ -76,4 +76,4 @@ TEST(nnet, MemboundOp_Ansor_Codegen) { // Timing // double time = timeit([&]() { runtime->run(gNew, false); }); // tune // kernels std::cout << "Time (ms):" << time << std::endl; -} \ No newline at end of file +} diff --git a/test/nnet/test_mutator.cc b/test/nnet/test_mutator.cc index f7dfe344..0b9411bd 100644 --- a/test/nnet/test_mutator.cc +++ b/test/nnet/test_mutator.cc @@ -13,7 +13,7 @@ namespace infini { TEST(Mutator, NaiveConvWithInterpreter) { // verifyNaiveMembound True: subgraph after transformation // verifyNaiveMembound False: subgraph of one single membound (eOP) - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph g = make_ref(runtime); // const bool verifyNaiveMembound = false; @@ -61,7 +61,7 @@ TEST(Mutator, InfoGAN_TConv_3_correctness) { // const bool verifyNaiveMembound = false; Runtime runtime = make_ref(); Graph g = make_ref(runtime); - Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton Graph gCpu = make_ref(cpu); // {n, h, w, f} * {f, r, s, c} diff --git a/test/operators/test_batch_norm.cc b/test/operators/test_batch_norm.cc index 438db6e6..e2ef15ce 100644 --- a/test/operators/test_batch_norm.cc +++ b/test/operators/test_batch_norm.cc @@ -5,7 +5,7 @@ namespace infini { TEST(BatchNorm, ShapeInference) { - Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); { Graph g = make_ref(cpuRuntime); Tensor i = g->addTensor({1, 3, 2, 2}, DataType::UInt32); diff --git a/test/operators/test_concat.cc b/test/operators/test_concat.cc index 8c5da64b..9a0fe74e 100644 --- a/test/operators/test_concat.cc +++ b/test/operators/test_concat.cc @@ -5,7 +5,7 @@ namespace infini { TEST(Concat, ShapeInfer) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph g = make_ref(runtime); auto t1 = g->addTensor({1, 3, 2, 4}, DataType::Float32); auto t2 = g->addTensor({1, 3, 2, 5}, DataType::Float32); diff --git a/test/operators/test_conv.cc b/test/operators/test_conv.cc index 420b4ab8..8ab50c0e 100644 --- a/test/operators/test_conv.cc +++ b/test/operators/test_conv.cc @@ -8,7 +8,7 @@ namespace infini { TEST(Conv, ShapeInference) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); // Padding modes { Graph g = make_ref(runtime); @@ -43,7 +43,7 @@ TEST(Conv, ShapeInference) { } TEST(Conv, NaiveCPU) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph g = make_ref(runtime); Tensor i0 = g->addTensor({1, 3, 4, 4}, DataType::UInt32); Tensor w0 = g->addTensor({2, 3, 3, 3}, DataType::UInt32); diff --git a/test/operators/test_conv_transposed_2d.cc b/test/operators/test_conv_transposed_2d.cc index 9ce8d7d8..68039da7 100644 --- a/test/operators/test_conv_transposed_2d.cc +++ 
b/test/operators/test_conv_transposed_2d.cc @@ -9,7 +9,7 @@ namespace infini { TEST(ConvTransposed, ShapeInference) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); { // No pad: InfoGAN ConvTranspose_0 Graph g = make_ref(runtime); Tensor i0 = g->addTensor({1, 228, 1, 1}); diff --git a/test/operators/test_element_wise.cc b/test/operators/test_element_wise.cc index 68fdc7b9..a1ffa708 100644 --- a/test/operators/test_element_wise.cc +++ b/test/operators/test_element_wise.cc @@ -9,7 +9,7 @@ namespace infini { using ExpectOutput = vector; TEST(ElementWise, ShapeInference) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); { Graph g = make_ref(runtime); Tensor i0 = g->addTensor({2, 3, 3, 4}, DataType::UInt32); diff --git a/test/operators/test_extend.cc b/test/operators/test_extend.cc index 5fbd4d8a..0cfa0703 100644 --- a/test/operators/test_extend.cc +++ b/test/operators/test_extend.cc @@ -8,7 +8,7 @@ namespace infini { TEST(Extend, ShapeInference) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); { Graph g = make_ref(runtime); Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32); diff --git a/test/operators/test_gather.cc b/test/operators/test_gather.cc index 32dbac64..6d900d6a 100644 --- a/test/operators/test_gather.cc +++ b/test/operators/test_gather.cc @@ -8,7 +8,7 @@ namespace infini { TEST(Gather, ShapeInference) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph g = make_ref(runtime); Tensor i = g->addTensor({1, 3, 4, 4}, DataType::UInt32); diff --git a/test/operators/test_matmul.cc b/test/operators/test_matmul.cc index b17f8a3a..22d07a1a 100644 --- a/test/operators/test_matmul.cc +++ b/test/operators/test_matmul.cc @@ -10,7 +10,7 @@ namespace infini { using ExpectOutput = vector; TEST(Matmul, ShapeInference) { - auto runtime = CpuRuntimeObj::getInstance(); + auto runtime = NativeCpuRuntimeObj::getInstance(); { Graph g = make_ref(runtime); auto A = g->addTensor(Shape{1, 3, 5}); diff --git a/test/operators/test_pad.cc b/test/operators/test_pad.cc index 3d033927..831df437 100644 --- a/test/operators/test_pad.cc +++ b/test/operators/test_pad.cc @@ -5,7 +5,7 @@ namespace infini { TEST(Pad, ShapeInference) { - Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); { Graph g = make_ref(cpuRuntime); Tensor i = g->addTensor({1, 64, 162, 162}, DataType::UInt32); diff --git a/test/operators/test_pooling.cc b/test/operators/test_pooling.cc index 97715a63..8b4c52ed 100644 --- a/test/operators/test_pooling.cc +++ b/test/operators/test_pooling.cc @@ -7,7 +7,7 @@ namespace infini { using KDPS = vector; using ExpectOutput = vector; TEST(MaxPool, ShapeInference) { - Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); { Graph g = make_ref(cpuRuntime); Tensor i = g->addTensor({1, 64, 162, 162}, DataType::UInt32); @@ -27,7 +27,7 @@ TEST(MaxPool, ShapeInference) { } TEST(MaxPool, NaiveCPU) { - Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); Graph g = make_ref(cpuRuntime); Tensor i = g->addTensor({1, 2, 5, 5}, DataType::UInt32); auto op = g->addOp(i, nullptr, 3, 3, 1, 1, 1, 1, 2, 2); @@ -46,7 +46,7 @@ TEST(MaxPool, NaiveCPU) { } TEST(AvgPool, NaiveCPU) { - Runtime cpuRuntime = 
CpuRuntimeObj::getInstance(); + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); Graph g = make_ref(cpuRuntime); Tensor i = g->addTensor({1, 2, 5, 5}, DataType::Float32); auto op = g->addOp(i, nullptr, 3, 3, 1, 1, 1, 1, 2, 2); diff --git a/test/operators/test_reduce_mean.cc b/test/operators/test_reduce_mean.cc index c6f0784a..8c3d477e 100644 --- a/test/operators/test_reduce_mean.cc +++ b/test/operators/test_reduce_mean.cc @@ -8,7 +8,7 @@ namespace infini { TEST(ReduceMean, ShapeInference) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); { Graph g = make_ref(runtime); Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32); diff --git a/test/operators/test_reshape.cc b/test/operators/test_reshape.cc index a0c016c6..457a06ea 100644 --- a/test/operators/test_reshape.cc +++ b/test/operators/test_reshape.cc @@ -8,7 +8,7 @@ namespace infini { TEST(Reshape, ShapeInference) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); { Graph g = make_ref(runtime); Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32); @@ -17,7 +17,7 @@ TEST(Reshape, ShapeInference) { } } TEST(Flatten, ShapeInference) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); { Graph g = make_ref(runtime); Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32); @@ -27,7 +27,7 @@ TEST(Flatten, ShapeInference) { } TEST(Identity, ShapeInference) { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); { Graph g = make_ref(runtime); Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32); diff --git a/test/operators/test_resize.cc b/test/operators/test_resize.cc index 1c36da20..9079c3bf 100644 --- a/test/operators/test_resize.cc +++ b/test/operators/test_resize.cc @@ -5,7 +5,7 @@ namespace infini { TEST(Resize, ShapeInference) { - Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); // downsample_sizes_nearest no axes { Graph g = make_ref(cpuRuntime); diff --git a/test/operators/test_slice.cc b/test/operators/test_slice.cc index a6717b03..436deada 100644 --- a/test/operators/test_slice.cc +++ b/test/operators/test_slice.cc @@ -5,7 +5,7 @@ namespace infini { TEST(Slice, ShapeInference) { - Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); { Graph g = make_ref(cpuRuntime); Tensor i = g->addTensor({10, 64, 162, 162}, DataType::UInt32); diff --git a/test/operators/test_split.cc b/test/operators/test_split.cc index 6b336497..9914e37f 100644 --- a/test/operators/test_split.cc +++ b/test/operators/test_split.cc @@ -7,7 +7,7 @@ namespace infini { TEST(Split, ShapeInfer) { { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph g = make_ref(runtime); auto input = g->addTensor({1, 3, 2, 15}, DataType::Float32); @@ -21,7 +21,7 @@ TEST(Split, ShapeInfer) { } { - Runtime runtime = CpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph g = make_ref(runtime); auto input = g->addTensor({1, 3, 2, 15}, DataType::Float32); diff --git a/test/script/env_lotus.sh b/test/script/env_lotus.sh index d9c2e170..72268491 100644 --- a/test/script/env_lotus.sh +++ b/test/script/env_lotus.sh @@ -1,3 +1,5 @@ . 
/home/spack/spack/share/spack/setup-env.sh
-spack load cuda@11.0.2 cudnn@8.0.3.33-11.0
+spack load cuda@11.0.2 cudnn@8.0.3.33-11.0 intel-oneapi-dnn@2022.1.0 intel-oneapi-mkl@2022.1.0
 export CUDAHOSTCXX=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-9.4.0/gcc-9.4.0-st36klijpsnquihiy463hmedsyhoc3g6/bin/gcc
+# After "spack load", the default dnnl library is cpu_dpcpp_gpu_dpcpp, which requires libsycl.so; switch to the cpu_gomp build explicitly.
+export LD_LIBRARY_PATH=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-12.1.0/intel-oneapi-dnn-2022.1.0-7rs6ht57zozyxhxx6s2qlrqzmqknhgzx/dnnl/2022.1.0/cpu_gomp/lib/:$LD_LIBRARY_PATH
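
A quick sanity check after sourcing the script above is to confirm that the loader resolves the cpu_gomp build of oneDNN rather than the SYCL one. This is a minimal sketch, assuming the first LD_LIBRARY_PATH entry is the cpu_gomp lib directory set above and that it ships a libdnnl.so symlink:

# Hedged check: the gomp build should pull in libgomp and no libsycl.
ldd "${LD_LIBRARY_PATH%%:*}/libdnnl.so" | grep -E 'gomp|sycl'
# Expected: a libgomp.so entry and no libsycl.so entry.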