diff --git a/CMakeLists.txt b/CMakeLists.txt index 6f3dca02..13ce9cd1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,7 +5,7 @@ project(InfiniTensor C CXX) # Do not change these options in this file. Use cmake.config, cmake -DOPTION=VALUE, or ccmake to specify them. option(USE_CUDA "Support CUDA GPU" OFF) option(USE_BANG "Support BANG MLU" OFF) -option(USE_MKL "Support MKL" OFF) +option(USE_INTELCPU "Support INTELCPU" OFF) option(USE_BACKTRACE "Print backtrace on exception and segmentation fault" ON) option(USE_PROTOBUF "Serialize and deserialize tensors" ON) option(BUILD_TEST "Build tests" ON) @@ -19,10 +19,6 @@ set(DEFAULT_BUILD_TYPE "RelWithDebInfo") set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_EXTENSIONS OFF) # -std=gnu++11 when on, -std=c++11 when off -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror -Wno-error=deprecated-declarations") -set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion -set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion - find_package( Python COMPONENTS Interpreter Development @@ -35,6 +31,20 @@ endif() if(OpenMP_CXX_FOUND) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") endif() + + +if(BUILD_TEST) + set(BUILD_GMOCK + OFF + CACHE BOOL "Do not build gmock" FORCE) + set(INSTALL_GTEST + OFF + CACHE BOOL "Do not install gtest" FORCE) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall ") + add_subdirectory(3rd-party/googletest) + include_directories(SYSTEM 3rd-party/googletest/googletest/include) +endif() + #Protobuf if(USE_PROTOBUF) add_definitions(-D TENSOR_PROTOBUF) @@ -47,14 +57,12 @@ if(USE_PROTOBUF) set(PROTO_PATH "${CMAKE_CURRENT_SOURCE_DIR}/proto") file(GLOB PROTO_FILES "${PROTO_PATH}/data.proto") protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS ${PROTO_FILES}) - message(${PROTO_SRCS} "-----------" ${PROTO_FILES}) - message(${PROTO_HDRS} "-----------" ${PROTO_FILES}) + set_source_files_properties (${PROTO_SRCS} PROPERTIES COMPILE_FLAGS -Wno-unused-variable) add_library(tensor_proto SHARED ${PROTO_SRCS} ${PROTO_HDRS}) target_link_libraries(tensor_proto PUBLIC ${PROTOBUF_LIBRARIES}) endif() include_directories(include) - # Pybind11 add_subdirectory(3rd-party/pybind11) include_directories(3rd-party/pybind11/include) @@ -63,16 +71,9 @@ include_directories(3rd-party/pybind11/include) add_subdirectory(3rd-party/nlohmann_json_cmake_fetchcontent) include_directories(3rd-party/nlohmann_json_cmake_fetchcontent/single_include) -if(BUILD_TEST) - set(BUILD_GMOCK - OFF - CACHE BOOL "Do not build gmock" FORCE) - set(INSTALL_GTEST - OFF - CACHE BOOL "Do not install gtest" FORCE) - add_subdirectory(3rd-party/googletest) - include_directories(3rd-party/googletest/googletest/include) -endif() +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror -Wno-error=deprecated-declarations") +set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion +set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion # Source files file(GLOB_RECURSE SRC src/ffi/*.cc src/core/*.cc src/kernels/cpu/*.cc src/nnet/*.cc src/operators/*.cc src/utils/*.cc) @@ -87,9 +88,9 @@ if(USE_BANG) list (APPEND SRC ${SRC_BANG}) endif() -if(USE_MKL) - file(GLOB_RECURSE SRC_MKL src/mkl/*.cc src/kernels/mkl/*.cc ) - list (APPEND SRC ${SRC_MKL}) +if(USE_INTELCPU) + file(GLOB_RECURSE SRC_INTELCPU src/intelcpu/*.cc src/kernels/intelcpu/*.cc ) + list (APPEND SRC ${SRC_INTELCPU}) endif() # Libraries @@ -113,19 +114,28 @@ if(USE_BACKTRACE) 
target_link_libraries(InfiniTensor dw) endif() -if(USE_MKL) +if(USE_INTELCPU) + add_compile_definitions(USE_INTELCPU=1) find_package(MKL CONFIG REQUIRED) - target_link_libraries(InfiniTensor $) + + # Refer to https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl-link-line-advisor.html + target_link_libraries(InfiniTensor sycl OpenCL) + set(DNNL_CONFIGURATION "cpu_gomp") find_package(dnnl CONFIG REQUIRED) - if(dnnl_FOUND) - add_compile_definitions(USE_MKL=1) + if(dnnl_FOUND) include_directories(BEFORE ${dnnl_DIR}/../../../cpu_gomp/include/) link_directories(${dnnl_DIR}/../../../cpu_gomp/lib) - target_link_libraries(InfiniTensor dnnl) + target_link_libraries(InfiniTensor dnnl) else() - message(FATAL_ERROR ”dnnl library not found”) + message(FATAL_ERROR "dnnl library not found") endif() + set(WNO_ERRORS "-Wno-error=unused-parameter -Wno-error=unused-function -Wno-error=unused-private-field -Wno-error=ignored-attributes -Wno-error=unused-const-variable -Wno-error=inconsistent-missing-override -Wno-error=unused-variable -Wno-error=tautological-constant-compare") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMKL_ILP64 -qmkl=parallel -Werror ${WNO_ERRORS}") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DMKL_ILP64 -qmkl=parallel ${WNO_ERRORS}") # Enable assertion + set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -DMKL_ILP64 -qmkl=parallel ${WNO_ERRORS}") # Enable assertion + + find_package(IntelDPCPP REQUIRED) endif() if(USE_CUDA) @@ -210,8 +220,8 @@ if(BUILD_TEST) if (USE_BANG) build_test(test/kernels/bang/*.cc) endif() - if (USE_MKL) - build_test(test/kernels/mkl/*.cc) + if (USE_INTELCPU) + build_test(test/kernels/intelcpu/*.cc) endif() endif() if(BUILD_TEST_PET) diff --git a/Makefile b/Makefile index 4de85392..3df4c34d 100644 --- a/Makefile +++ b/Makefile @@ -2,16 +2,21 @@ TYPE ?= release CUDA ?= off +INTELCPU ?= off -CMAKE_OPT = -DCMAKE_BUILD_TYPE=$(TYPE) +CMAKE_OPT = -DCMAKE_BUILD_TYPE=$(TYPE) ifeq ($(CUDA), ON) CMAKE_OPT += -DUSE_CUDA=ON endif +ifeq ($(INTELCPU), ON) + CMAKE_OPT += -DUSE_INTELCPU=ON -DCMAKE_CXX_COMPILER=dpcpp +endif + build: mkdir -p build/$(TYPE) - cd build/$(TYPE) && cmake $(CMAKE_OPT) ../.. && make -j8 + cd build/$(TYPE) && cmake $(CMAKE_OPT) ../.. && make -j22 clean: rm -rf build diff --git a/README.md b/README.md index 1f6c07fa..ea404705 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,19 @@ # InfiniTensor ## Compilation on Lotus - +# Compilation for cuda ``` bash # Enter the root of InfiniTensor source test/script/env_lotus.sh make CUDA=ON ``` +## Compilation for intelcpu +``` bash +# Enter the root of InfiniTensor +source test/script/env_lotus.sh intelcpu +mkdir build && cd build +cmake -DUSE_INTELCPU=ON -DCMAKE_CXX_COMPILER=dpcpp .. 
&& make -j 12 +``` ### Make Commands diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index 9936c637..a76197de 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -66,10 +66,10 @@ class GraphHandlerObj { Tensor relu(Tensor x, Tensor y); Tensor sigmoid(Tensor x, Tensor y); Tensor tanh(Tensor x, Tensor y); - Tensor softmax(Tensor x, Tensor y); + Tensor softmax(Tensor x, Tensor y, int axis); Tensor abs(Tensor x, Tensor y); Tensor identity(Tensor x, Tensor y); - Tensor flatten(Tensor s, Tensor y); + Tensor flatten(Tensor s, Tensor y, int axis); Tensor reshape(Tensor data, Tensor reshaped, Shape shape); Tensor concat(TensorVec inputs, Tensor output, int dim); Tensor gather(Tensor data, Tensor indices, Tensor output, int axis); diff --git a/include/core/runtime.h b/include/core/runtime.h index 8e7be034..60b7ad72 100644 --- a/include/core/runtime.h +++ b/include/core/runtime.h @@ -28,7 +28,7 @@ using OpVec = vector; using VType = uint32_t; -enum class Device { CPU = 1, CUDA, BANG, MKL }; +enum class Device { CPU = 1, CUDA, BANG, INTELCPU }; /***************** Forward declaration end *****************/ class RuntimeObj : public std::enable_shared_from_this { @@ -53,7 +53,6 @@ class RuntimeObj : public std::enable_shared_from_this { bool profiling = false) const = 0; virtual void *alloc(size_t size) = 0; virtual void dealloc(void *ptr) = 0; - void prepareAndRun(Graph &graph, bool tune = false, bool profiling = false); /** * @brief Get the execution time of each operator in performance record. No * execution happens. @@ -65,7 +64,7 @@ class RuntimeObj : public std::enable_shared_from_this { double getPerfTime(const Graph &graph, bool profiling = false) const; Blob allocBlob(size_t size); bool isCpu() const { - return device == Device::CPU || device == Device::MKL; + return device == Device::CPU || device == Device::INTELCPU; } bool isCuda() const { return device == Device::CUDA; } bool isBang() const { return device == Device::BANG; } diff --git a/include/cuda/softmax.h b/include/cuda/softmax.h new file mode 100644 index 00000000..5c0eccf9 --- /dev/null +++ b/include/cuda/softmax.h @@ -0,0 +1,6 @@ +#pragma once + +namespace infini { +void softmax_kernel(int max_threadblock_size, int batch_size, float *x, + float *y, int dim, int stride); +} diff --git a/include/intelcpu/mkl_kernel_without_config.h b/include/intelcpu/mkl_kernel_without_config.h new file mode 100644 index 00000000..d197f675 --- /dev/null +++ b/include/intelcpu/mkl_kernel_without_config.h @@ -0,0 +1,40 @@ +#pragma once +#include "core/kernel.h" +#include "intelcpu/mkl_runtime.h" + +namespace infini { + +class MklKernelWithoutConfig : public Kernel { + public: + virtual void compute(const Operator &op, const PerfRecord &record, + const RuntimeObj *_context) const override { + compute(op, _context); + auto context = dynamic_cast(_context); + context->sync(); + } + virtual void compute(const Operator &op, + const RuntimeObj *context) const = 0; + // Premise: op is idempotent since it is called multiple times. 
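+    // Note: tune() below times repeated executions of compute() via
+    // timeit(), using context->sync() as the per-round barrier so queued
+    // oneDNN work finishes before each measurement ends; that is why the
+    // idempotence premise above is required.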
+ virtual PerfRecord tune(const Operator &op, + const RuntimeObj *_context) const override { + auto context = dynamic_cast(_context); + return make_ref(timeit([&]() { compute(op, _context); }, + [&]() { context->sync(); })); + } + + protected: + dnnl::memory::format_tag getUserFormatTag(int nDim) const { + if (nDim == 2) + return dnnl::memory::format_tag::nc; + else if (nDim == 3) + return dnnl::memory::format_tag::ncw; + else if (nDim == 4) + return dnnl::memory::format_tag::nchw; + else if (nDim == 5) + return dnnl::memory::format_tag::ncdhw; + else + IT_TODO_HALT(); + } +}; + +} // namespace infini diff --git a/include/mkl/mkl_runtime.h b/include/intelcpu/mkl_runtime.h similarity index 81% rename from include/mkl/mkl_runtime.h rename to include/intelcpu/mkl_runtime.h index 6cfc7993..e8be877f 100644 --- a/include/mkl/mkl_runtime.h +++ b/include/intelcpu/mkl_runtime.h @@ -7,9 +7,9 @@ #include #include namespace infini { -// TODO move utility function to alone file class MklRuntimeObj : public CpuRuntimeObj { dnnl_engine_t engine; + dnnl_stream_t stream; public: MklRuntimeObj(); @@ -26,8 +26,10 @@ class MklRuntimeObj : public CpuRuntimeObj { sizeof(uint64_t), 64); }; - string toString() const override { return "CPU MKL Runtime"; }; + string toString() const override { return "INTELCPU Runtime"; }; dnnl::engine getEngine() const { return dnnl::engine(engine, true); } + dnnl::stream getStream() const { return dnnl::stream(stream, true); } + void sync() const; }; } // namespace infini diff --git a/include/mkl/operator_timer.h b/include/intelcpu/operator_timer.h similarity index 100% rename from include/mkl/operator_timer.h rename to include/intelcpu/operator_timer.h diff --git a/include/nnet/routine.h b/include/nnet/routine.h index 52e0d637..e5bd89f4 100644 --- a/include/nnet/routine.h +++ b/include/nnet/routine.h @@ -22,6 +22,7 @@ class RoutineNode { public: RoutineNode(Expr _expr, const vector &_inputs); + virtual ~RoutineNode() {} virtual string toReadable() const = 0; const Expr &getExpr() const { return expr; } const vector &getInputs() const { return inputs; } @@ -147,4 +148,4 @@ std::ostream &operator<<(std::ostream &os, const Ref &a) { return os; } -} // namespace nnet \ No newline at end of file +} // namespace nnet diff --git a/include/operators/reshape.h b/include/operators/reshape.h index 31cc3576..907bbcbb 100644 --- a/include/operators/reshape.h +++ b/include/operators/reshape.h @@ -42,6 +42,7 @@ class ReshapeObj : public OperatorObj { * */ class FlattenObj : public OperatorObj { + int axis; public: /** @@ -51,7 +52,7 @@ class FlattenObj : public OperatorObj { * @param input The input tensor. * @param output The output one-dimensional tensor. 
*/ - FlattenObj(GraphObj *graph, Tensor input, Tensor output); + FlattenObj(GraphObj *graph, Tensor input, Tensor output, int axis); OP_CLONE(FlattenObj); optional> inferShape(const TensorVec &inputs) const override; diff --git a/include/operators/resize.h b/include/operators/resize.h index 4cc328dc..a762ea30 100644 --- a/include/operators/resize.h +++ b/include/operators/resize.h @@ -75,6 +75,9 @@ class ResizeObj : public OperatorObj { IT_ASSERT((size_t)i < scales.size()); return scales.at(i); } + + vector getScales() const { return scales; } + float getRoi(int i) const { if (coMode == ECoordinateTransMode::tfCropAndResize) { IT_ASSERT(size_t(i) < roi.size()); diff --git a/include/operators/softmax.h b/include/operators/softmax.h new file mode 100644 index 00000000..0611f63f --- /dev/null +++ b/include/operators/softmax.h @@ -0,0 +1,27 @@ +#pragma once +#include "core/operator.h" + +namespace infini { +class SoftmaxObj : public OperatorObj { + int axis; + + public: + SoftmaxObj(GraphObj *graph, Tensor input, Tensor output, int axis); + + OP_CLONE(SoftmaxObj); + + optional> inferShape(const TensorVec &inputs) const override { + return {{inputs[0]->getDims()}}; + }; + + std::string toString() const override; + int numInputs() const override { return 1; } + int numOutputs() const override { return 1; } + + int getAxis() const { return axis; } + + private: + vector getWorkloadVector() const override; + vector getOpAttrVector() const override; +}; +} // namespace infini diff --git a/include/operators/unary.h b/include/operators/unary.h index 111a9314..e75025c5 100644 --- a/include/operators/unary.h +++ b/include/operators/unary.h @@ -39,6 +39,7 @@ class UnaryObj : public OperatorObj { DEFINE_UNARY_OBJ(Relu, OpType::Relu) DEFINE_UNARY_OBJ(Sigmoid, OpType::Sigmoid) DEFINE_UNARY_OBJ(Tanh, OpType::Tanh) -DEFINE_UNARY_OBJ(Softmax, OpType::Softmax) +// DEFINE_UNARY_OBJ(Softmax, OpType::Softmax) DEFINE_UNARY_OBJ(Abs, OpType::Abs) + }; // namespace infini diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 9a6afd21..360f5aaa 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -25,12 +25,7 @@ from onnx.shape_inference import infer_shapes from typing import Dict, List, Any, Tuple, Sequence, Union, Optional from functools import reduce -cpu_runtime = backend.cpu_runtime() - - -def cuda_runtime(): - return backend.cuda_runtime() - +runtime = backend.runtime() class OnnxStub: inputs: Dict[str, backend.Tensor] = {} @@ -253,6 +248,7 @@ class OnnxStub: tensors[node.output[0]] = self.handler.softmax( tensors[node.input[0]], tensors.get(node.output[0]), + next((attr.i for attr in node.attribute if attr.name == "axis")), ) elif node.op_type == "Abs": tensors[node.output[0]] = self.handler.abs( @@ -265,14 +261,11 @@ class OnnxStub: tensors.get(node.output[0]), ) elif node.op_type == "Flatten": - # FIXME axis must be 1 - axis = next( - (attr.i for attr in node.attribute if attr.name == "axis"), None - ) - assert axis == None or axis == 1 + tensors[node.output[0]] = self.handler.flatten( tensors[node.input[0]], tensors.get(node.output[0]), + next((attr.i for attr in node.attribute if attr.name == "axis")), ) elif node.op_type == "Reshape": input_shape = next( @@ -583,6 +576,9 @@ def from_onnx(model: ModelProto, runtime): stub = OnnxStub(model, runtime) return stub.inputs, stub.outputs, stub.handler +def run_onnx(model: ModelProto, runtime): + stub = OnnxStub(model, runtime) + stub.run() def 
_parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]: for attr in node.attribute: diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index 150a96e7..99ad3b9d 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -8,16 +8,28 @@ from onnx.helper import ( make_tensor_value_info, ) from onnx.checker import check_model -from pyinfinitensor.onnx import from_onnx, backend, cpu_runtime +from pyinfinitensor.onnx import from_onnx, backend, runtime, run_onnx def make_and_import_model(graph: onnx.GraphProto): model = make_model(graph) check_model(model) - from_onnx(model, cpu_runtime) + from_onnx(model, runtime) class TestStringMethods(unittest.TestCase): + #def test_run(self): + # model_file = next( + # (name for name in os.listdir() if name.endswith(".onnx")), None + # ) + # if model_file != None: + # print( + # "model: {file}({size:.2f} MiB)".format( + # file=model_file, size=os.path.getsize(model_file) / 1024 / 1024 + # ) + # ) + # run_onnx(onnx.load(model_file), runtime) + def test_load(self): model_file = next( (name for name in os.listdir() if name.endswith(".onnx")), None @@ -28,7 +40,7 @@ class TestStringMethods(unittest.TestCase): file=model_file, size=os.path.getsize(model_file) / 1024 / 1024 ) ) - from_onnx(onnx.load(model_file), cpu_runtime) + from_onnx(onnx.load(model_file), runtime) def test_tensor(self): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3]) @@ -177,7 +189,7 @@ class TestStringMethods(unittest.TestCase): def test_softmax(self): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7]) - softmax = make_node("Softmax", ["x"], ["y"], name="softmax") + softmax = make_node("Softmax", ["x"], ["y"], axis=2, name="softmax") make_and_import_model(make_graph([softmax], "softmax", [x], [y])) def test_abs(self): @@ -194,9 +206,8 @@ class TestStringMethods(unittest.TestCase): def test_flatten(self): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) - y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 1 * 3 * 5 * 7]) - flatten = make_node("Flatten", ["x"], ["y"], name="flatten") - # FIXME 后端要求产生 Π(dims) 长的一维张量,onnx 产生 1×Π(dims) 的二维张量 + y = make_tensor_value_info("y", TensorProto.FLOAT, [1*3, 5 * 7]) + flatten = make_node("Flatten", ["x"], ["y"], axis=2, name="flatten") # make_and_import_model( make_graph([flatten], "flatten", [x], [y]) # ) @@ -289,10 +300,10 @@ class TestStringMethods(unittest.TestCase): graph = make_graph([matmul, add], "lr", [x, a, b], [y]) model = make_model(graph) check_model(model) - from_onnx(model, cpu_runtime) + from_onnx(model, runtime) def test_frontend(self): - handler = backend.GraphHandler(cpu_runtime) + handler = backend.GraphHandler(runtime) a = handler.tensor([1, 2, 3], 12) b = handler.tensor([1, 2, 3], 12) c = handler.tensor([1, 2, 3], 12) diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index ac8bfb21..424ca276 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -10,6 +10,7 @@ #include "operators/reduce_mean.h" #include "operators/reshape.h" #include "operators/slice.h" +#include "operators/softmax.h" #include "operators/unary.h" namespace infini { @@ -126,11 +127,29 @@ DEFINE_ELEMENT_WISE_METHOD(pow, Pow) DEFINE_UNARY_METHOD(relu, Relu) DEFINE_UNARY_METHOD(sigmoid, Sigmoid) DEFINE_UNARY_METHOD(tanh, Tanh) -DEFINE_UNARY_METHOD(softmax, Softmax) DEFINE_UNARY_METHOD(abs, Abs) // see operators/reshape.h 
DEFINE_UNARY_METHOD(identity, Identity) -DEFINE_UNARY_METHOD(flatten, Flatten) + +Tensor GraphHandlerObj::softmax(Tensor input, Tensor output, int axis) { + if (output) { + g->addOpWithOutputs(std::move(input), output, axis); + return output; + } else { + return g->addOp(std::move(input), output, axis) + ->getOutput(); + } +} + +Tensor GraphHandlerObj::flatten(Tensor input, Tensor output, int axis) { + if (output) { + g->addOpWithOutputs(std::move(input), output, axis); + return output; + } else { + return g->addOp(std::move(input), output, axis) + ->getOutput(); + } +} Tensor GraphHandlerObj::reshape(Tensor data, Tensor reshaped, Shape shape) { if (reshaped) { diff --git a/src/core/runtime.cc b/src/core/runtime.cc index 71ee6bdd..1e1e7c1d 100644 --- a/src/core/runtime.cc +++ b/src/core/runtime.cc @@ -6,10 +6,6 @@ #include #include namespace infini { -void RuntimeObj::prepareAndRun(Graph &graph, bool tune, bool profiling) { - run(graph, tune, profiling); -} - void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const { if (!tune && profiling) IT_TODO_HALT(); diff --git a/src/core/tensor.cc b/src/core/tensor.cc index cc2de201..cdcd9e28 100644 --- a/src/core/tensor.cc +++ b/src/core/tensor.cc @@ -9,7 +9,7 @@ namespace infini { TensorObj::TensorObj(Shape shape_, DataType dtype, Runtime runtime) - : TensorBaseObj(shape.size(), dtype, runtime), shape(std::move(shape_)), + : TensorBaseObj(shape_.size(), dtype, runtime), shape(std::move(shape_)), _size(shape.empty() ? 0 : std::accumulate(shape.begin(), shape.end(), 1, diff --git a/src/cuda/cuda_utility.cu b/src/cuda/cuda_utility.cu index cfbdcb9f..83490959 100644 --- a/src/cuda/cuda_utility.cu +++ b/src/cuda/cuda_utility.cu @@ -5,7 +5,7 @@ __global__ void cudaPrintFloatImpl(float *x, int len) { int start = threadIdx.x + blockDim.x * blockIdx.x; if (start == 0) { for (int i = 0; i < len; ++i) { - printf("%.3f ", x[i]); + printf("%.7f ", x[i]); } printf("\n"); } @@ -18,4 +18,4 @@ void cudaPrintFloat(float *x, int len) { cudaDeviceSynchronize(); } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 7414d4d5..1b58abd5 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -12,8 +12,9 @@ #include "cuda/cuda_runtime.h" #include "cuda/operator_timer.h" #endif -#ifdef USE_MKL -#include "mkl/operator_timer.h" +#ifdef USE_INTELCPU +#include "intelcpu/mkl_runtime.h" +#include "intelcpu/operator_timer.h" #endif namespace py = pybind11; @@ -30,7 +31,7 @@ void register_operator_timer(py::module &m) { m.def("getPerfMatmulCublas", &getPerfMatmulCublas); #endif -#ifdef USE_MKL +#ifdef USE_INTELCPU using namespace opTimer; m.def("getPerfConvMkl", &getPerfConvMkl); m.def("getPerfConvTransposed2dMkl", &getPerfConvTransposed2dMkl); @@ -111,6 +112,10 @@ static int tensor_dtype(Tensor t) { static Ref cuda_runtime() { return make_ref(); } #endif +#ifdef USE_INTELCPU +static Ref intelcpu_runtime() { return make_ref(); } +#endif + static std::tuple conv_attrs_of(Operator op) { IT_ASSERT(op->getOpType() == OpType::Conv); auto conv = dynamic_cast(op.get()); @@ -158,10 +163,14 @@ static Shape reshape_shape_of(Operator op) { void export_functions(py::module &m) { #define FUNCTION(NAME) def(#NAME, &NAME) - m.def("cpu_runtime", &NativeCpuRuntimeObj::getInstance) #ifdef USE_CUDA - .FUNCTION(cuda_runtime) + m.def("runtime", cuda_runtime) +#elif USE_INTELCPU + m.def("runtime", intelcpu_runtime) +#else + m.def("runtime", 
&NativeCpuRuntimeObj::getInstance) #endif + .FUNCTION(conv_attrs_of) .FUNCTION(batch_norm_attrs_of) .FUNCTION(pool_attrs_of) diff --git a/src/intelcpu/mkl_runtime.cc b/src/intelcpu/mkl_runtime.cc new file mode 100644 index 00000000..7ae6b03d --- /dev/null +++ b/src/intelcpu/mkl_runtime.cc @@ -0,0 +1,19 @@ +#include "intelcpu/mkl_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +namespace infini { +MklRuntimeObj::MklRuntimeObj() : CpuRuntimeObj(Device::INTELCPU) { + dnnl_engine_create(&engine, dnnl_engine_kind_t::dnnl_cpu, 0); + dnnl_stream_create( + &stream, engine, + static_cast(dnnl_stream_default_flags)); +} + +MklRuntimeObj::~MklRuntimeObj() { + mkl_free_buffers(); + dnnl_stream_destroy(stream); + dnnl_engine_destroy(engine); +} + +void MklRuntimeObj::sync() const { getStream().wait(); } +} // namespace infini diff --git a/src/mkl/operator_timer.cc b/src/intelcpu/operator_timer.cc similarity index 99% rename from src/mkl/operator_timer.cc rename to src/intelcpu/operator_timer.cc index c6f1c55d..c179f56d 100644 --- a/src/mkl/operator_timer.cc +++ b/src/intelcpu/operator_timer.cc @@ -1,7 +1,7 @@ #include "core/graph.h" #include "core/kernel.h" #include "core/runtime.h" -#include "mkl/mkl_runtime.h" +#include "intelcpu/mkl_runtime.h" #include "operators/conv.h" #include "operators/matmul.h" #include "utils/data_generator.h" diff --git a/src/kernels/cpu/conv.cc b/src/kernels/cpu/conv.cc index 53670c5b..b0ffa724 100644 --- a/src/kernels/cpu/conv.cc +++ b/src/kernels/cpu/conv.cc @@ -10,8 +10,13 @@ template class NaiveConv : public CpuKernelWithoutConfig { T *iptr = op->getInputs(0)->getRawDataPtr(); T *wptr = op->getInputs(1)->getRawDataPtr(); T *optr = op->getOutput()->getRawDataPtr(); - auto [n, c, h, w, f, r, s] = op->getNCHWFRS(); - auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); + // Clang will give an error of " reference to local binding 'sh' + // declared in enclosing function" if we write like this: + // auto [n, c, h, w, f, r, s] = op->getNCHWFRS(); + int n, c, h, w, f, r, s; + std::tie(n, c, h, w, f, r, s) = op->getNCHWFRS(); + int ph, pw, sh, sw, dh, dw; + std::tie(ph, pw, sh, sw, dh, dw) = op->getPadStrideDilation(); int cpg = op->getChannelPerGroup(); int g = op->getNumGroups(); IT_ASSERT(f % g == 0, "Illegal number of channel"); @@ -23,7 +28,7 @@ template class NaiveConv : public CpuKernelWithoutConfig { for (int hh = 0; hh < oh; hh++) for (int ww = 0; ww < ow; ww++) { int gidx = ff / (f / g); - VType val = 0; + T val = 0; for (int cc = 0; cc < cpg; cc++) for (int rr = 0; rr < r; rr++) for (int ss = 0; ss < s; ss++) { @@ -52,4 +57,4 @@ REGISTER_KERNEL(Device::CPU, OpType::Conv, DataType::UInt32, REGISTER_KERNEL(Device::CPU, OpType::Conv, DataType::Float32, NaiveConv, "ConvNaive_CPU_float32"); -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/src/kernels/cpu/membound.cc b/src/kernels/cpu/membound.cc index 31755389..18fb304c 100644 --- a/src/kernels/cpu/membound.cc +++ b/src/kernels/cpu/membound.cc @@ -30,8 +30,8 @@ class MemboundInterpreter : public Kernel { // } nnet::RangeOp range = nnet::as(op->getNnetExpr()); - const auto &rangeShape = range->getOutputShape(); - const auto &outputShape = output->getDims(); + // const auto &rangeShape = range->getOutputShape(); + // const auto &outputShape = output->getDims(); // rangeShape and outputShape may extra dims of length 1. // But their sizes should be the same. 
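+        // e.g. a range of shape [1, N] versus an output of shape [N]:
+        // different ranks, same element count, which is all the assertion
+        // below compares.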
IT_ASSERT((ssize_t)range->getOutputSize() == (ssize_t)output->size()); diff --git a/src/kernels/cuda/resize.cu b/src/kernels/cuda/resize.cu index 11fe360e..3f985dde 100644 --- a/src/kernels/cuda/resize.cu +++ b/src/kernels/cuda/resize.cu @@ -213,7 +213,7 @@ void resize_kernel_nearest(float *in, float *out, const MetaData &metaData, sizeof(p_cooridnate_trans_mode_func[0])); IT_ASSERT(nearestMode < sizeof(p_nearest_mode_fun) / sizeof(p_nearest_mode_fun[0])); - _resize_kernel_nearest<<>>( + _resize_kernel_nearest<<>>( in, out, metaData, num, coordinateMode, nearestMode); } @@ -223,7 +223,7 @@ void resize_kernel_linear(float *in, float *out, const MetaData &metaData, auto gridsize = (num + blocksize - 1) / blocksize; IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) / sizeof(p_cooridnate_trans_mode_func[0])); - _resize_kernel_linear_coeff<<>>(in, out, metaData, num, + _resize_kernel_linear_coeff<<>>(in, out, metaData, num, coordinateMode); } @@ -233,7 +233,7 @@ void resize_kernel_cubic(float *in, float *out, const MetaData &metaData, auto gridsize = (num + blocksize - 1) / blocksize; IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) / sizeof(p_cooridnate_trans_mode_func[0])); - _resize_kernel_cubic_coeff<<>>(in, out, metaData, num, + _resize_kernel_cubic_coeff<<>>(in, out, metaData, num, coordinateMode); } } // namespace infini diff --git a/src/kernels/cuda/softmax.cc b/src/kernels/cuda/softmax.cc new file mode 100644 index 00000000..437ed849 --- /dev/null +++ b/src/kernels/cuda/softmax.cc @@ -0,0 +1,30 @@ +#include "operators/softmax.h" +#include "cuda/cuda_kernel_wihtout_config.h" +#include "cuda/cuda_runtime.h" +#include "cuda/softmax.h" + +namespace infini { +class SoftmaxCudnn : public CudaKernelWithoutConfig { + + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto x = op->getInputs(0)->getRawDataPtr(); + auto y = op->getOutput(0)->getRawDataPtr(); + auto dims = op->getInputs(0)->getDims(); + + int batch_size = 1; + for (size_t i = 0; i < dims.size(); ++i) + batch_size *= dims[i]; + int dim = dims[op->getAxis()]; + + int block_num = batch_size / dim; + int max_threadblock_size = batch_size / block_num; + softmax_kernel(max_threadblock_size, block_num, x, y, dim, + op->getInputs(0)->getStride().at(op->getAxis())); + } +}; + +REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, SoftmaxCudnn, + "Softmax_CUDA_Float32"); +} // namespace infini diff --git a/src/kernels/cuda/softmax.cu b/src/kernels/cuda/softmax.cu new file mode 100644 index 00000000..1f7f39e6 --- /dev/null +++ b/src/kernels/cuda/softmax.cu @@ -0,0 +1,77 @@ +#include "cuda/cuda_common.h" +#include "cuda/softmax.h" +#include + +struct __align__(8) MD { + float data; + float d; +}; + +__device__ __forceinline__ MD reduce_md_op(MD a, MD b) { + bool a_bigger = (a.data > b.data); + MD bigger_m = a_bigger ? a : b; + MD smaller_m = a_bigger ? 
b : a; + MD res; + res.d = bigger_m.d + smaller_m.d * __expf(smaller_m.data - bigger_m.data); + res.data = bigger_m.data; + return res; +} + +template +__launch_bounds__(THREADBLOCK_SIZE) __global__ + void online_softmax(const float *__restrict in, float *__restrict out, + int dimSize, int stride) { + + // reposition in and out to data for the current vector + int blockOffset = blockIdx.x; + if (blockIdx.x >= stride) { + int tmp = blockIdx.x % stride; + blockOffset = tmp + (blockIdx.x - tmp) * dimSize; + } + in += blockOffset; + out += blockOffset; + + MD md_partial; + md_partial.data = -FLT_MAX; + md_partial.d = 0.0F; + + for (int elem_id = threadIdx.x; elem_id < dimSize; + elem_id += THREADBLOCK_SIZE) { + MD new_elem; + new_elem.data = in[elem_id * stride]; + new_elem.d = 1.0F; + md_partial = reduce_md_op(md_partial, new_elem); + } + + // blockreduce for THREADBLOCK_SIZE threads. + // The actrual threads num used in the block is "dimsSize" + typedef cub::BlockReduce BlockReduce; + + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ MD md_total; + + MD md = BlockReduce(temp_storage).Reduce(md_partial, reduce_md_op); + if (threadIdx.x == 0) + md_total = md; + __syncthreads(); + + float d_total_inverse = __fdividef(1.0F, md_total.d); + for (int elem_id = threadIdx.x; elem_id < dimSize; + elem_id += THREADBLOCK_SIZE) + out[elem_id * stride] = + __expf(in[elem_id * stride] - md_total.data) * d_total_inverse; +} + +namespace infini { +void softmax_kernel(int max_threadblock_size, int blockNum, float *in, + float *out, int dimSize, int stride) { + if (max_threadblock_size >= 255) + online_softmax<256><<>>(in, out, dimSize, stride); + else if (max_threadblock_size >= 128) + online_softmax<128><<>>(in, out, dimSize, stride); + else if (max_threadblock_size >= 64) + online_softmax<64><<>>(in, out, dimSize, stride); + else + online_softmax<32><<>>(in, out, dimSize, stride); +} +} // namespace infini diff --git a/src/kernels/cuda/unary.cc b/src/kernels/cuda/unary.cc index 805b8fe0..a944b8af 100644 --- a/src/kernels/cuda/unary.cc +++ b/src/kernels/cuda/unary.cc @@ -60,48 +60,6 @@ class ActivationCudnn : public CudaKernelWithoutConfig { } }; -class SoftmaxCudnn : public CudaKernelWithoutConfig { - virtual cudnnSoftmaxAlgorithm_t getAlgorithmType() const = 0; - virtual cudnnSoftmaxMode_t getModeType() const = 0; - virtual tuple getAlphBeta() const { return {1.f, 0.f}; } - void compute(const Operator &_op, - const RuntimeObj *_context) const override { - auto op = as(_op); - auto context = dynamic_cast(_context); - - void *const inputData = (op->getInputs(0)->getRawDataPtr()); - void *const outputData = (op->getOutput()->getRawDataPtr()); - - cudnnTensorDescriptor_t inputDesc, outputDesc; - auto dim = op->getInputs(0)->getDims(); - if (dim.size() != 4) - IT_TODO_HALT(); - int n = dim[0], c = dim[1], h = dim[2], w = dim[3]; - - // get inputs - checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc)); - checkCudnnError(cudnnSetTensor4dDescriptor( - inputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w)); - - // get outputs - checkCudnnError(cudnnCreateTensorDescriptor(&outputDesc)); - checkCudnnError(cudnnSetTensor4dDescriptor( - outputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w)); - - auto [alpha, beta] = getAlphBeta(); - cudnnStatus_t stat = cudnnSoftmaxForward( - context->cudnnHandle(), getAlgorithmType(), getModeType(), &alpha, - inputDesc, inputData, &beta, outputDesc, outputData); - if (stat != CUDNN_STATUS_SUCCESS) - return; - - // Destories in CUDA does not 
require sync. But cuDNN does not state - // whether sync is required before destories. - checkCudnnError(cudnnDestroyTensorDescriptor(inputDesc)); - checkCudnnError(cudnnDestroyTensorDescriptor(outputDesc)); - } -}; - class ReluCudnn : public ActivationCudnn { cudnnActivationMode_t getOpType() const override { return CUDNN_ACTIVATION_RELU; @@ -120,17 +78,6 @@ class TanhCudnn : public ActivationCudnn { } }; -class NormalSoftmaxCudnn : public SoftmaxCudnn { - cudnnSoftmaxAlgorithm_t getAlgorithmType() const override { - return CUDNN_SOFTMAX_ACCURATE; - } - cudnnSoftmaxMode_t getModeType() const override { - return CUDNN_SOFTMAX_MODE_INSTANCE; - } -}; - -REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, - NormalSoftmaxCudnn, "Softmax_CUDA_Float32"); REGISTER_KERNEL(Device::CUDA, OpType::Relu, DataType::Float32, ReluCudnn, "Relu_CUDA_Float32"); REGISTER_KERNEL(Device::CUDA, OpType::Sigmoid, DataType::Float32, SigmoidCudnn, diff --git a/src/kernels/intelcpu/batch_norm.cc b/src/kernels/intelcpu/batch_norm.cc new file mode 100644 index 00000000..88296605 --- /dev/null +++ b/src/kernels/intelcpu/batch_norm.cc @@ -0,0 +1,68 @@ +#include "operators/batch_norm.h" +#include "intelcpu/mkl_kernel_without_config.h" +#include "intelcpu/mkl_runtime.h" + +namespace infini { +class MklBatchNorm : public MklKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + float *const srcData = op->getInputs(0)->getRawDataPtr(); + float *const dstData = op->getOutput()->getRawDataPtr(); + + // create user memory that describes data layout in the buffers + std::vector dims; + for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) + dims.push_back(op->getInputs(0)->getDims()[i]); + + auto srcMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, + getUserFormatTag(dims.size())); + auto srcMemory = dnnl::memory(srcMd, context->getEngine(), srcData); + + auto dstMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, + getUserFormatTag(dims.size())); + auto output = dnnl::memory(dstMd, context->getEngine(), dstData); + + std::vector meanDims(op->getInputs(0)->getDims().size(), 1); + meanDims[1] = op->getInputs(0)->getDims()[1]; + auto meanMd = dnnl::memory::desc(meanDims, dnnl::memory::data_type::f32, + getUserFormatTag(meanDims.size())); + + auto meanMemory = + dnnl::memory(meanMd, context->getEngine(), + op->getInputs(1)->getRawDataPtr()); + auto varMemory = + dnnl::memory(meanMd, context->getEngine(), + op->getInputs(2)->getRawDataPtr()); + auto scaleMemory = + dnnl::memory(meanMd, context->getEngine(), + op->getInputs(3)->getRawDataPtr()); + auto baisMemory = + dnnl::memory(meanMd, context->getEngine(), + op->getInputs(4)->getRawDataPtr()); + using op_desc_t = dnnl::batch_normalization_forward::desc; + using pd_t = dnnl::batch_normalization_forward::primitive_desc; + + // use_global_stats stands for use mean and var as inputs + auto opDesc = + op_desc_t(dnnl::prop_kind::forward_inference, srcMd, op->getEps(), + dnnl::normalization_flags::use_global_stats | + dnnl::normalization_flags::use_shift | + dnnl::normalization_flags::use_scale); + auto primDesc = pd_t(opDesc, context->getEngine()); + + // create and execute primitive + dnnl::batch_normalization_forward(primDesc).execute( + context->getStream(), {{DNNL_ARG_SRC, srcMemory}, + {DNNL_ARG_DST, output}, + {DNNL_ARG_MEAN, meanMemory}, + {DNNL_ARG_VARIANCE, varMemory}, + {DNNL_ARG_SCALE, scaleMemory}, + {DNNL_ARG_SHIFT, 
baisMemory}}); + } +}; +REGISTER_KERNEL(Device::INTELCPU, OpType::BatchNorm, DataType::Float32, + MklBatchNorm, "BatchNorm_Mkl_Float32"); +}; // namespace infini diff --git a/src/kernels/intelcpu/concat.cc b/src/kernels/intelcpu/concat.cc new file mode 100644 index 00000000..b4e7b24b --- /dev/null +++ b/src/kernels/intelcpu/concat.cc @@ -0,0 +1,58 @@ +#include "operators/concat.h" +#include "intelcpu/mkl_kernel_without_config.h" +#include "intelcpu/mkl_runtime.h" + +namespace infini { +class MklConcat : public MklKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + // create user memory that describes data layout in the buffers + std::vector srcsMd; + std::vector srcs; + + for (size_t i = 0; i < op->getInputs().size(); i++) { + std::vector dims; + auto inDims = op->getInputs(i)->getDims(); + int ndim = inDims.size(); + for (int j = 0; j < ndim; ++j) + dims.push_back(inDims.at(j)); + + auto md = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, + getUserFormatTag(dims.size())); + srcsMd.push_back(md); + + auto srcMemory = + dnnl::memory(md, context->getEngine(), + op->getInputs(i)->getRawDataPtr()); + srcs.push_back(srcMemory); + } + + std::vector dims; + auto oDims = op->getOutput(0)->getDims(); + int ndim = oDims.size(); + for (int i = 0; i < ndim; ++i) + dims.push_back(oDims.at(i)); + + auto dstMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, + getUserFormatTag(dims.size())); + auto primDesc = + dnnl::concat::primitive_desc(dstMd, static_cast(op->getDim()), + srcsMd, context->getEngine()); + + float *const dstData = op->getOutput()->getRawDataPtr(); + auto output = dnnl::memory(dstMd, context->getEngine(), dstData); + + // create and execute primitive + std::unordered_map args = {{DNNL_ARG_DST, output}}; + for (int i = 0; i < (int)srcs.size(); i++) { + args.insert({DNNL_ARG_MULTIPLE_SRC + i, srcs.at(i)}); + } + dnnl::concat(primDesc).execute(context->getStream(), args); + } +}; +REGISTER_KERNEL(Device::INTELCPU, OpType::Concat, DataType::Float32, MklConcat, + "Concat_Mkl_Float32"); +}; // namespace infini diff --git a/src/kernels/mkl/conv.cc b/src/kernels/intelcpu/conv.cc similarity index 93% rename from src/kernels/mkl/conv.cc rename to src/kernels/intelcpu/conv.cc index 18cc4ca2..77749e09 100644 --- a/src/kernels/mkl/conv.cc +++ b/src/kernels/intelcpu/conv.cc @@ -1,6 +1,6 @@ #include "operators/conv.h" #include "core/kernel.h" -#include "mkl/mkl_runtime.h" +#include "intelcpu/mkl_runtime.h" namespace infini { struct ConvMklPerfRecordObj : public PerfRecordObj { @@ -167,20 +167,19 @@ class MklConv : public Kernel { } void compute(const Operator &_op, const PerfRecord &_record, - const RuntimeObj *_context) const { + const RuntimeObj *_context) const override { auto op = as(_op); auto context = dynamic_cast(_context); auto record = as(_record); - dnnl::stream stream(context->getEngine()); std::vector prims; std::vector> primArgs; IT_ASSERT(createPrimitives(op, record, context, true, prims, primArgs)); IT_ASSERT(prims.size() == primArgs.size()); for (size_t i = 0; i < prims.size(); ++i) - prims.at(i).execute(stream, primArgs.at(i)); - stream.wait(); + prims.at(i).execute(context->getStream(), primArgs.at(i)); + context->getStream().wait(); } void compute(const Operator &op, const RuntimeObj *context) const override { @@ -209,17 +208,19 @@ class MklConv : public Kernel { continue; IT_ASSERT(prims.size() == primArgs.size()); - dnnl::stream 
stream(context->getEngine()); + // does context->getStream() need to be attached to runtime, and + // delete after each use? for (size_t i = 0; i < prims.size(); ++i) - prims.at(i).execute(stream, primArgs.at(i)); - stream.wait(); + prims.at(i).execute(context->getStream(), primArgs.at(i)); + context->getStream().wait(); record.time = timeit( [&]() { for (size_t i = 0; i < prims.size(); ++i) - prims.at(i).execute(stream, primArgs.at(i)); + prims.at(i).execute(context->getStream(), + primArgs.at(i)); }, - [&]() { stream.wait(); }); + [&]() { context->getStream().wait(); }); // Update the tune result if (ret.time > record.time) @@ -232,6 +233,6 @@ class MklConv : public Kernel { return make_ref(ret); } }; -REGISTER_KERNEL(Device::MKL, OpType::Conv, DataType::Float32, MklConv, +REGISTER_KERNEL(Device::INTELCPU, OpType::Conv, DataType::Float32, MklConv, "MklConv_CPU_float32"); } // namespace infini diff --git a/src/kernels/mkl/conv_transposed.cc b/src/kernels/intelcpu/conv_transposed.cc similarity index 98% rename from src/kernels/mkl/conv_transposed.cc rename to src/kernels/intelcpu/conv_transposed.cc index 3c45ddd4..aca5cca5 100644 --- a/src/kernels/mkl/conv_transposed.cc +++ b/src/kernels/intelcpu/conv_transposed.cc @@ -1,5 +1,5 @@ #include "core/kernel.h" -#include "mkl/mkl_runtime.h" +#include "intelcpu/mkl_runtime.h" #include "operators/conv.h" namespace infini { @@ -244,7 +244,7 @@ class MklConvTranspose : public Kernel { return make_ref(ret); } }; -REGISTER_KERNEL(Device::MKL, OpType::ConvTrans, DataType::Float32, +REGISTER_KERNEL(Device::INTELCPU, OpType::ConvTrans, DataType::Float32, MklConvTranspose, "MklConvTrans_CPU_float32"); } // namespace infini diff --git a/src/kernels/intelcpu/element_wise.cc b/src/kernels/intelcpu/element_wise.cc new file mode 100644 index 00000000..dbc19b32 --- /dev/null +++ b/src/kernels/intelcpu/element_wise.cc @@ -0,0 +1,133 @@ +#include "operators/element_wise.h" +#include "intelcpu/mkl_kernel_without_config.h" +#include "intelcpu/mkl_runtime.h" +#include "operators/unary.h" + +namespace infini { +class MklBinary : public MklKernelWithoutConfig { + dnnl::algorithm getAlgorithem(const Ref &op) const { + switch (op->getOpType()) { + case OpType::Add: + return dnnl::algorithm::binary_add; + case OpType::Sub: + return dnnl::algorithm::binary_sub; + case OpType::Mul: + return dnnl::algorithm::binary_mul; + case OpType::Div: + return dnnl::algorithm::binary_div; + + default: + IT_TODO_HALT(); + } + return dnnl::algorithm::undef; + } + + // Binary primitives support elementwise broadcast + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + // create user memory that describes data layout in the buffers + std::vector dims; + for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) + dims.push_back(op->getInputs(0)->getDims()[i]); + + auto srcMd1 = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, + getUserFormatTag(dims.size())); + auto srcMemory1 = dnnl::memory(srcMd1, context->getEngine(), aData); + + auto srcMd2 = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, + getUserFormatTag(dims.size())); + auto srcMemory2 = dnnl::memory(srcMd2, context->getEngine(), bData); + + auto dstMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, + getUserFormatTag(dims.size())); + 
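+        // Note: srcMd1 and srcMd2 above both reuse input 0's dims, so this
+        // kernel assumes both operands have identical shapes even though
+        // the binary primitive itself could broadcast the second source.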
auto output = dnnl::memory(dstMd, context->getEngine(), cData); + + auto binaryDesc = + dnnl::binary::desc(getAlgorithem(op), srcMd1, srcMd2, dstMd); + auto primDesc = + dnnl::binary::primitive_desc(binaryDesc, context->getEngine()); + + // create and execute binary primitive + dnnl::binary(primDesc).execute(context->getStream(), + {{DNNL_ARG_SRC_0, srcMemory1}, + {DNNL_ARG_SRC_1, srcMemory2}, + {DNNL_ARG_DST, output}}); + } +}; + +class MklUnary : public MklKernelWithoutConfig { + dnnl::algorithm getAlgorithem(const Ref &op) const { + switch (op->getOpType()) { + case OpType::Relu: + return dnnl::algorithm::eltwise_relu; + case OpType::Tanh: + return dnnl::algorithm::eltwise_tanh; + case OpType::Abs: + return dnnl::algorithm::eltwise_abs; + case OpType::Sigmoid: + return dnnl::algorithm::eltwise_logistic; + default: + IT_TODO_HALT(); + } + return dnnl::algorithm::undef; + } + + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const srcData = (op->getInputs(0)->getRawDataPtr()); + void *const dstData = (op->getOutput()->getRawDataPtr()); + + // create user memory that describes data layout in the buffers + std::vector dims; + for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) + dims.push_back(op->getInputs(0)->getDims()[i]); + + auto srcMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, + getUserFormatTag(dims.size()), false); + auto srcMemory = dnnl::memory(srcMd, context->getEngine(), srcData); + + auto output = dnnl::memory(srcMd, context->getEngine(), dstData); + + const float negative1_slope = 0.0f; + + auto unaryDesc = dnnl::eltwise_forward::desc( + dnnl::prop_kind::forward_inference, getAlgorithem(op), srcMd, + negative1_slope); + auto primDesc = dnnl::eltwise_forward::primitive_desc( + unaryDesc, context->getEngine()); + + // create and execute binary primitive + dnnl::eltwise_forward(primDesc).execute( + context->getStream(), + {{DNNL_ARG_SRC, srcMemory}, {DNNL_ARG_DST, output}}); + } +}; + +REGISTER_KERNEL(Device::INTELCPU, OpType::Add, DataType::Float32, MklBinary, + "Add_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Sub, DataType::Float32, MklBinary, + "Sub_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Mul, DataType::Float32, MklBinary, + "Mul_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Div, DataType::Float32, MklBinary, + "Div_Mkl_Float32"); + +REGISTER_KERNEL(Device::INTELCPU, OpType::Relu, DataType::Float32, MklUnary, + "Relu_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Sigmoid, DataType::Float32, MklUnary, + "Sigmoid_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Tanh, DataType::Float32, MklUnary, + "Tanh_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Abs, DataType::Float32, MklUnary, + "Abs_Mkl_Float32"); +} // namespace infini diff --git a/src/kernels/intelcpu/extend.cc b/src/kernels/intelcpu/extend.cc new file mode 100644 index 00000000..dff2ebc1 --- /dev/null +++ b/src/kernels/intelcpu/extend.cc @@ -0,0 +1,45 @@ +#include "operators/extend.h" +#include "core/kernel.h" +#include "intelcpu/mkl_kernel_without_config.h" +#include "intelcpu/mkl_runtime.h" +#include +#include + +namespace infini { +class MklExtend : public MklKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto inData = op->getInputs(0)->getRawDataPtr(); + auto outData = op->getOutput(0)->getRawDataPtr(); + int iSize = 
op->getInputs(0)->size(); + int oSize = op->getOutput(0)->size(); + + sycl::queue q(sycl::cpu_selector{}); + auto inDevice = sycl::malloc_device(iSize, q); + auto outDevice = sycl::malloc_device(oSize, q); + + q.memcpy(inDevice, inData, iSize * sizeof(float)); + q.wait(); + + int blockSize = 1; + auto iDim = op->getInputs(0)->getDims(); + for (size_t i = iDim.size() - 1; + i >= (size_t)op->getDim() && i != (size_t)-1; --i) + blockSize *= iDim[i]; + auto blockSizeOuter = (op->getNum() + 1) * blockSize; + + q.parallel_for(sycl::range<1>(oSize), [=](sycl::id<1> index) { + auto iIdx = index % blockSize + index / blockSizeOuter * blockSize; + outDevice[index] = inDevice[iIdx]; + }).wait(); + + q.memcpy(outData, outDevice, oSize * sizeof(float)); + q.wait(); + sycl::free(inDevice, q); + sycl::free(outDevice, q); + } +}; +REGISTER_KERNEL(Device::INTELCPU, OpType::Extend, DataType::Float32, MklExtend, + "Extend_Mkl_Float32"); +}; // namespace infini diff --git a/src/kernels/intelcpu/gather.cc b/src/kernels/intelcpu/gather.cc new file mode 100644 index 00000000..a95ece4e --- /dev/null +++ b/src/kernels/intelcpu/gather.cc @@ -0,0 +1,86 @@ +#include "operators/gather.h" +#include "core/kernel.h" +#include "intelcpu/mkl_kernel_without_config.h" +#include "intelcpu/mkl_runtime.h" +#include +#include + +namespace infini { +class MklGather : public MklKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto in = op->getInputs(0); + auto index = op->getInputs(1); + auto out = op->getOutput(); + int iSize = in->size(); + int oSize = out->size(); + int idxSize = index->size(); + + int inNDim = in->getDims().size(); + int oNDim = out->getDims().size(); + int idxNDim = index->getDims().size(); + int axis = op->getAxis(); + + int outDim[4] = {0}; + int idxDim[4] = {0}; + int idxStride[4] = {0}; + int inStride[4] = {0}; + for (int i = 0; i < oNDim; ++i) + outDim[i] = out->getDims()[i]; + for (int i = 0; i < idxNDim; ++i) { + idxDim[i] = index->getDims()[i]; + idxStride[i] = index->getStride()[i]; + } + for (int i = 0; i < inNDim; ++i) { + inStride[i] = in->getStride()[i]; + } + + sycl::queue q(sycl::cpu_selector{}); + auto inDevice = sycl::malloc_device(iSize, q); + auto indexDevice = sycl::malloc_device(idxSize, q); + auto outDevice = sycl::malloc_device(oSize, q); + + q.memcpy(inDevice, in->getRawDataPtr(), iSize * sizeof(float)); + q.memcpy(indexDevice, index->getRawDataPtr(), + idxSize * sizeof(uint32_t)); + q.wait(); + + q.parallel_for(sycl::range<1>(oSize), [=](sycl::id<1> index) { + int offset = 0; + int gOffset = index; + for (int i = inNDim - 1, k = oNDim - 1; i >= 0; --i) { + int idx = 0; + if (i == axis) { + int idxOffset = 0; + for (int j = idxNDim - 1; j >= 0; --j) { + int p = gOffset % idxDim[j]; + gOffset = gOffset / idxDim[j]; + idxOffset += p * idxStride[j]; + } + + idx = indexDevice[idxOffset]; + k = k - idxNDim; + + } else { + idx = gOffset % outDim[k]; + gOffset = gOffset / outDim[k]; + --k; + } + offset += idx * inStride[i]; + } + + outDevice[index] = inDevice[offset]; + }).wait(); + + q.memcpy(out->getRawDataPtr(), outDevice, + oSize * sizeof(float)); + q.wait(); + sycl::free(inDevice, q); + sycl::free(outDevice, q); + sycl::free(indexDevice, q); + } +}; +REGISTER_KERNEL(Device::INTELCPU, OpType::Gather, DataType::Float32, MklGather, + "Gather_Mkl_Float32"); +}; // namespace infini diff --git a/src/kernels/mkl/matmul.cc b/src/kernels/intelcpu/matmul.cc similarity index 89% rename from src/kernels/mkl/matmul.cc 
rename to src/kernels/intelcpu/matmul.cc index 02e6dd53..61cf5c94 100644 --- a/src/kernels/mkl/matmul.cc +++ b/src/kernels/intelcpu/matmul.cc @@ -1,9 +1,8 @@ #include "operators/matmul.h" #include "core/kernel.h" -#include "mkl/mkl_runtime.h" +#include "intelcpu/mkl_runtime.h" namespace infini { - template class MklMatmul : public CpuKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *context) const override { @@ -32,7 +31,7 @@ template class MklMatmul : public CpuKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::MKL, OpType::Matmul, DataType::Float32, - MklMatmul, "MklMatmul_CPU_float32"); +/*REGISTER_KERNEL(Device::INTELCPU, OpType::Matmul, DataType::Float32, + MklMatmul, "MklMatmul_CPU_float32");*/ } // namespace infini diff --git a/src/kernels/intelcpu/matmul_dpcpp.cc b/src/kernels/intelcpu/matmul_dpcpp.cc new file mode 100644 index 00000000..fd77ee39 --- /dev/null +++ b/src/kernels/intelcpu/matmul_dpcpp.cc @@ -0,0 +1,75 @@ +#include "core/kernel.h" +#include "intelcpu/mkl_runtime.h" +#include "mkl.h" +#include "oneapi/mkl/blas.hpp" +#include "operators/matmul.h" +#include + +namespace infini { +template class MklDpcppMatmul : public CpuKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *context) const override { + auto op = as(_op); + IT_ASSERT(op->getInputs().size() == 2, "Bias is not supported yet."); + const T *A = op->getInputs(0)->getRawDataPtr(); + const T *B = op->getInputs(1)->getRawDataPtr(); + T *C = op->getOutput()->getRawDataPtr(); + IT_ASSERT(op->getAct() == ActType::None); + const int m = op->getM(), n = op->getN(), k = op->getK(), + b = op->getB(); + + auto opA = op->getTransA() ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + auto opB = op->getTransB() ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + // ldA is always a.col, and ldB is always b.col when row major + const int ldA = + std::max((opA == oneapi::mkl::transpose::nontrans) ? k : m, 1); + const int ldB = + std::max((opB == oneapi::mkl::transpose::nontrans) ? 
n : k, 1); + const int ldC = std::max(n, 1); + + const float alpha = 1.f, beta = 0.f; + // TODO: Intel MKL ERROR will occur when using cblas_sgemm_batch + /*for (int i = 0; i < b; ++i) { + cblas_sgemm(CblasRowMajor, opA, opB, m, n, k, alpha, A + m * k * i, + ldA, B + k * n * i, ldB, beta, C + m * n * i, ldC); + }*/ + + sycl::queue q(sycl::cpu_selector{}); + // Catch asynchronous exceptions + auto exception_handler = [](cl::sycl::exception_list exceptions) { + for (std::exception_ptr const &e : exceptions) { + try { + std::rethrow_exception(e); + } catch (cl::sycl::exception const &e) { + std::cout + << "Caught asynchronous SYCL exception during GEMM:\n" + << e.what() << std::endl; + } + } + }; + + // create execution queue and buffers of matrix data + cl::sycl::queue main_queue(sycl::cpu_selector{}, exception_handler); + + cl::sycl::buffer A_buffer(A, op->getInputs(0)->size()); + cl::sycl::buffer B_buffer(B, op->getInputs(1)->size()); + cl::sycl::buffer C_buffer(C, op->getOutput(0)->size()); + + // add oneapi::mkl::blas::gemm to execution queue + try { + oneapi::mkl::blas::row_major::gemm_batch( + main_queue, opA, opB, m, n, k, alpha, A_buffer, ldA, m * k, + B_buffer, ldB, k * n, beta, C_buffer, ldC, m * n, b); + } catch (cl::sycl::exception const &e) { + std::cout << "\t\tCaught synchronous SYCL exception during GEMM:\n" + << e.what() << std::endl; + } + } +}; + +REGISTER_KERNEL(Device::INTELCPU, OpType::Matmul, DataType::Float32, + MklDpcppMatmul, "MklDpcppMatmul_CPU_float32"); + +} // namespace infini diff --git a/src/kernels/intelcpu/pad.cc b/src/kernels/intelcpu/pad.cc new file mode 100644 index 00000000..02dc4143 --- /dev/null +++ b/src/kernels/intelcpu/pad.cc @@ -0,0 +1,58 @@ +#include "operators/pad.h" +#include "intelcpu/mkl_kernel_without_config.h" +#include "intelcpu/mkl_runtime.h" + +namespace infini { +class MklPad : public MklKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + std::vector dims; + for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) { + dims.push_back(op->getInputs(0)->getDims()[i]); + } + auto paddedMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, + getUserFormatTag(dims.size())); + + // dst md + auto oDims = op->getOutput(0)->getDims(); + int ndim = oDims.size(); + std::vector paddedDims, offsets; + for (int i = 0; i < ndim; ++i) { + paddedDims.push_back(oDims.at(i)); + paddedMd.data.padded_dims[i] = oDims.at(i); + paddedMd.data.padded_offsets[i] = op->getPads().at(i); + offsets.push_back(op->getPads().at(i)); + } + // will fill padded area with zero. 
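+        // In other words: the output buffer is viewed through paddedMd as
+        // the input-sized tensor embedded (at the pad offsets) inside the
+        // padded output dims; dstMd below describes the full output, a
+        // submemory_desc at `offsets` selects that interior region, and the
+        // final reorder() copies the input into it only.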
+ auto paddedMemory = + dnnl::memory(paddedMd, context->getEngine(), + op->getOutput(0)->getRawDataPtr()); + + auto dstMd = + dnnl::memory::desc(paddedDims, dnnl::memory::data_type::f32, + getUserFormatTag(paddedDims.size())); + + // copy src to the submemory of dst + // create submemory + auto md = dstMd.submemory_desc(dims, offsets); + auto mem = dnnl::memory(md, context->getEngine(), + op->getOutput(0)->getRawDataPtr()); + + auto srcMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, + getUserFormatTag(dims.size())); + auto srcMemory = + dnnl::memory(srcMd, context->getEngine(), + op->getInputs(0)->getRawDataPtr()); + + // copy data to submemory + dnnl::reorder(srcMemory, mem) + .execute(context->getStream(), + {{DNNL_ARG_FROM, srcMemory}, {DNNL_ARG_TO, mem}}); + } +}; +REGISTER_KERNEL(Device::INTELCPU, OpType::Pad, DataType::Float32, MklPad, + "Pad_Mkl_Float32"); +} // namespace infini diff --git a/src/kernels/intelcpu/pooling.cc b/src/kernels/intelcpu/pooling.cc new file mode 100644 index 00000000..d27238fe --- /dev/null +++ b/src/kernels/intelcpu/pooling.cc @@ -0,0 +1,84 @@ +#include "operators/pooling.h" +#include "intelcpu/mkl_kernel_without_config.h" +#include "intelcpu/mkl_runtime.h" + +namespace infini { +class MklPooling : public MklKernelWithoutConfig { + virtual dnnl::algorithm getAlgorithm() const = 0; + + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + float *const srcData = op->getInputs(0)->getRawDataPtr(); + float *const dstData = op->getOutput()->getRawDataPtr(); + + // create user memory that describes data layout in the buffers + auto [n, c, h, w, r, s] = op->getNCHWRS(); + auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); + auto nDim = op->getOutput()->getDims().size(); + auto oh = op->getOutput()->getDims()[nDim - 2]; + auto ow = op->getOutput()->getDims()[nDim - 1]; + + auto srcMd = dnnl::memory::desc( + {n, c, h, w}, dnnl::memory::data_type::f32, getUserFormatTag(nDim)); + auto srcMemory = dnnl::memory(srcMd, context->getEngine(), srcData); + + auto userDstMd = + dnnl::memory::desc({n, c, oh, ow}, dnnl::memory::data_type::f32, + getUserFormatTag(nDim)); + + auto dstMd = + dnnl::memory::desc({n, c, oh, ow}, dnnl::memory::data_type::f32, + dnnl::memory::format_tag::any); + + using op_desc_t = dnnl::pooling_v2_forward::desc; + using pd_t = dnnl::pooling_v2_forward::primitive_desc; + + auto opDesc = op_desc_t(dnnl::prop_kind::forward_inference, + getAlgorithm(), srcMd, dstMd, {sh, sw}, {r, s}, + {dh - 1, dw - 1}, {ph, pw}, {ph, pw}); + auto primDesc = pd_t(opDesc, context->getEngine()); + + if (primDesc.dst_desc() == userDstMd) { + auto output = dnnl::memory(primDesc.dst_desc(), + context->getEngine(), dstData); + + dnnl::pooling_v2_forward(primDesc).execute( + context->getStream(), + {{DNNL_ARG_SRC, srcMemory}, {DNNL_ARG_DST, output}}); + } else { + auto dstMemory = + dnnl::memory(primDesc.dst_desc(), context->getEngine()); + + dnnl::pooling_v2_forward(primDesc).execute( + context->getStream(), + {{DNNL_ARG_SRC, srcMemory}, {DNNL_ARG_DST, dstMemory}}); + + auto output = + dnnl::memory(userDstMd, context->getEngine(), dstData); + dnnl::reorder(dstMemory, output) + .execute(context->getStream(), + {{DNNL_ARG_FROM, dstMemory}, {DNNL_ARG_TO, output}}); + } + } +}; + +class MklAvgPool : public MklPooling { + dnnl::algorithm getAlgorithm() const override { + return dnnl::algorithm::pooling_avg_include_padding; + } +}; + +class MklMaxPool : public MklPooling 
{ + dnnl::algorithm getAlgorithm() const override { + return dnnl::algorithm::pooling_max; + } +}; + +REGISTER_KERNEL(Device::INTELCPU, OpType::AvgPool, DataType::Float32, + MklAvgPool, "AvgPool_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::MaxPool, DataType::Float32, + MklMaxPool, "MaxPool_Mkl_Float32"); +} // namespace infini diff --git a/src/kernels/intelcpu/pow.cc b/src/kernels/intelcpu/pow.cc new file mode 100644 index 00000000..166d0a75 --- /dev/null +++ b/src/kernels/intelcpu/pow.cc @@ -0,0 +1,43 @@ +#include "core/kernel.h" +#include "intelcpu/mkl_kernel_without_config.h" +#include "intelcpu/mkl_runtime.h" +#include "operators/element_wise.h" +#include +#include + +namespace infini { +class MklPow : public MklKernelWithoutConfig { + // TODO: is the copy into SYCL device memory really necessary? + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto in0Data = op->getInputs(0)->getRawDataPtr(); + auto in1Data = op->getInputs(1)->getRawDataPtr(); + auto outData = op->getOutput(0)->getRawDataPtr(); + int size = op->getInputs(0)->size(); + + // cpu_selector uses OpenCL as the backend; host_selector bypasses the + // OpenCL backend and runs directly on the CPU hardware + sycl::queue q(sycl::cpu_selector{}); + auto in0Device = sycl::malloc_device(size, q); + auto in1Device = sycl::malloc_device(size, q); + auto outDevice = sycl::malloc_device(size, q); + q.memcpy(in0Device, in0Data, size * sizeof(float)); + q.wait(); + q.memcpy(in1Device, in1Data, size * sizeof(float)); + q.wait(); + + q.parallel_for(sycl::range<1>(size), [=](sycl::id<1> i) { + outDevice[i] = pow(in0Device[i], in1Device[i]); + }).wait(); + q.memcpy(outData, outDevice, size * sizeof(float)); + q.wait(); + sycl::free(in0Device, q); + sycl::free(in1Device, q); + sycl::free(outDevice, q); + } +}; +REGISTER_KERNEL(Device::INTELCPU, OpType::Pow, DataType::Float32, MklPow, + "Pow_Mkl_Float32"); + +}; // namespace infini diff --git a/src/kernels/intelcpu/reduce.cc b/src/kernels/intelcpu/reduce.cc new file mode 100644 index 00000000..23202fec --- /dev/null +++ b/src/kernels/intelcpu/reduce.cc @@ -0,0 +1,69 @@ +#include "intelcpu/mkl_kernel_without_config.h" +#include "intelcpu/mkl_runtime.h" +#include "operators/reduce_mean.h" + +namespace infini { +class MklReduce : public MklKernelWithoutConfig { + dnnl::algorithm getAlgorithm() const { + return dnnl::algorithm::reduction_mean; + } + + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + float *const srcData = op->getInputs(0)->getRawDataPtr(); + float *const dstData = op->getOutput()->getRawDataPtr(); + + // create user memory that describes data layout in the buffers + std::vector inDims, inStrides; + for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) { + inDims.push_back(op->getInputs(0)->getDims()[i]); + inStrides.push_back(op->getInputs(0)->getStride()[i]); + } + + std::vector oDims(op->getInputs(0)->getDims().size(), 0), + oStrides(op->getInputs(0)->getDims().size(), 1); + if (!op->getKeepDims()) { + oDims = inDims; + for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) { + if (op->isReduced(i)) { + oDims[i] = 1; + } + } + int stride = 1; + for (int i = (int)oDims.size() - 1; i >= 0; --i) { + oStrides[i] = stride; + stride *= oDims[i]; + } + } else { + for (size_t i = 0; i < op->getOutput(0)->getDims().size(); ++i) { + oDims[i] = op->getOutput(0)->getDims()[i]; + oStrides[i] = op->getOutput(0)->getStride()[i]; + } + } + + 
auto srcMd = + dnnl::memory::desc(inDims, dnnl::memory::data_type::f32, inStrides); + auto srcMemory = dnnl::memory(srcMd, context->getEngine(), srcData); + + auto dstMd = + dnnl::memory::desc(oDims, dnnl::memory::data_type::f32, oStrides); + auto output = dnnl::memory(dstMd, context->getEngine(), dstData); + + using op_desc_t = dnnl::reduction::desc; + using pd_t = dnnl::reduction::primitive_desc; + + auto opDesc = op_desc_t(getAlgorithm(), srcMd, dstMd, 0, 0); + auto primDesc = pd_t(opDesc, context->getEngine()); + + // create and execute primitive + dnnl::reduction(primDesc).execute( + context->getStream(), + {{DNNL_ARG_SRC, srcMemory}, {DNNL_ARG_DST, output}}); + } +}; +REGISTER_KERNEL(Device::INTELCPU, OpType::ReduceMean, DataType::Float32, + MklReduce, "ReduceMean_Mkl_Float32"); +}; // namespace infini diff --git a/src/kernels/intelcpu/reshape.cc b/src/kernels/intelcpu/reshape.cc new file mode 100644 index 00000000..bddef40f --- /dev/null +++ b/src/kernels/intelcpu/reshape.cc @@ -0,0 +1,50 @@ +#include "operators/reshape.h" +#include "intelcpu/mkl_kernel_without_config.h" +#include "intelcpu/mkl_runtime.h" + +namespace infini { +class MklReshape : public MklKernelWithoutConfig { + void compute(const Operator &op, + const RuntimeObj *_context) const override { + + auto context = dynamic_cast(_context); + + std::vector dims; + for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) + dims.push_back(op->getInputs(0)->getDims()[i]); + + // create src md and src memory + auto srcMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, + getUserFormatTag(dims.size())); + + // dst md + auto oDims = op->getOutput(0)->getDims(); + int ndim = oDims.size(); + std::vector reshapeDims; + for (int i = 0; i < ndim; ++i) { + reshapeDims.push_back(oDims.at(i)); + } + auto reshapeMd = srcMd.reshape(reshapeDims); + auto reshapeMemory = + dnnl::memory(reshapeMd, context->getEngine(), + op->getInputs(0)->getRawDataPtr()); + + auto dstMd = + dnnl::memory::desc(reshapeDims, dnnl::memory::data_type::f32, + getUserFormatTag(reshapeDims.size())); + auto output = dnnl::memory(dstMd, context->getEngine(), + op->getOutput(0)->getRawDataPtr()); + + // copy data to dst + dnnl::reorder(reshapeMemory, output) + .execute(context->getStream(), + {{DNNL_ARG_FROM, reshapeMemory}, {DNNL_ARG_TO, output}}); + } +}; +REGISTER_KERNEL(Device::INTELCPU, OpType::Reshape, DataType::Float32, + MklReshape, "Reshape_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Identity, DataType::Float32, + MklReshape, "Identify_Mkl_Float32"); +REGISTER_KERNEL(Device::INTELCPU, OpType::Flatten, DataType::Float32, + MklReshape, "Flatten_Mkl_Float32"); +}; // namespace infini diff --git a/src/kernels/intelcpu/resize.cc b/src/kernels/intelcpu/resize.cc new file mode 100644 index 00000000..e7b3eea4 --- /dev/null +++ b/src/kernels/intelcpu/resize.cc @@ -0,0 +1,80 @@ +#include "operators/resize.h" +#include "intelcpu/mkl_kernel_without_config.h" +#include "intelcpu/mkl_runtime.h" + +namespace infini { +class MklResize : public MklKernelWithoutConfig { + dnnl::algorithm getAlgorithm(Ref op) const { + switch (op->getMode()) { + case ResizeObj::ECoeffMode::nearest: { + if (op->getNearestMode() != + enum_to_underlying(ResizeObj::ENearestMode::ceil)) + IT_TODO_HALT(); + return dnnl::algorithm::resampling_nearest; + } + case ResizeObj::ECoeffMode::linear: + return dnnl::algorithm::resampling_linear; + + default: + IT_TODO_HALT(); + } + return dnnl::algorithm::resampling_nearest; + } + + void compute(const Operator &_op, + const 
RuntimeObj *_context) const override { + auto op = as(_op); + + // only support default coordinate transmode?? + if (op->getCoordinateTransMode() != + enum_to_underlying(ResizeObj::ECoordinateTransMode::halfPixel)) + IT_TODO_HALT(); + + int nDim = op->getInputs(0)->getDims().size(); + IT_ASSERT(nDim == 3 || nDim == 4 || + nDim == 5 && + (op->getInputs(0)->getDims()[0] == 1 && + op->getInputs(0)->getDims()[1] == 1) && + (op->getOutput(0)->getDims()[0] == 1 && + op->getOutput(0)->getDims()[1] == 1)); + + IT_ASSERT(op->getScales().size() == nDim); + std::vector::iterator beg = op->getScales().begin() + 2; + std::vector scales(beg, op->getScales().end()); + + // create user memory that describes data layout in the buffers + std::vector idims, odims; + for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) { + idims.push_back(op->getInputs(0)->getDims()[i]); + odims.push_back(op->getOutput(0)->getDims()[i]); + } + + auto context = dynamic_cast(_context); + + float *const srcData = op->getInputs(0)->getRawDataPtr(); + float *const dstData = op->getOutput()->getRawDataPtr(); + + auto srcMd = dnnl::memory::desc(idims, dnnl::memory::data_type::f32, + getUserFormatTag(idims.size())); + auto srcMemory = dnnl::memory(srcMd, context->getEngine(), srcData); + + auto dstMd = dnnl::memory::desc(odims, dnnl::memory::data_type::f32, + getUserFormatTag(odims.size())); + auto output = dnnl::memory(dstMd, context->getEngine(), dstData); + + using op_desc_t = dnnl::resampling_forward::desc; + using pd_t = dnnl::resampling_forward::primitive_desc; + + auto opDesc = op_desc_t(dnnl::prop_kind::forward_inference, + getAlgorithm(op), scales, srcMd, dstMd); + auto primDesc = pd_t(opDesc, context->getEngine()); + + // create and execute primitive + dnnl::resampling_forward(primDesc).execute( + context->getStream(), + {{DNNL_ARG_SRC, srcMemory}, {DNNL_ARG_DST, output}}); + } +}; +REGISTER_KERNEL(Device::INTELCPU, OpType::Resize, DataType::Float32, MklResize, + "Resize_Mkl_Float32"); +}; // namespace infini diff --git a/src/kernels/intelcpu/slice.cc b/src/kernels/intelcpu/slice.cc new file mode 100644 index 00000000..8a5a489f --- /dev/null +++ b/src/kernels/intelcpu/slice.cc @@ -0,0 +1,46 @@ +#include "operators/slice.h" +#include "intelcpu/mkl_kernel_without_config.h" +#include "intelcpu/mkl_runtime.h" + +namespace infini { +class MklSlice : public MklKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + std::vector dims; + for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) + dims.push_back(op->getInputs(0)->getDims()[i]); + + // create src md + auto srcMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, + getUserFormatTag(dims.size())); + + // dst md + auto oDims = op->getOutput(0)->getDims(); + int ndim = oDims.size(); + std::vector sDims, offsets; + for (int i = 0; i < ndim; ++i) { + sDims.push_back(oDims.at(i)); + offsets.push_back(op->getStart().at(i)); + } + auto sliceMd = srcMd.submemory_desc(sDims, offsets); + auto sliceMemory = + dnnl::memory(sliceMd, context->getEngine(), + op->getInputs(0)->getRawDataPtr()); + + auto dstMd = dnnl::memory::desc(sDims, dnnl::memory::data_type::f32, + getUserFormatTag(sDims.size())); + auto output = dnnl::memory(dstMd, context->getEngine(), + op->getOutput(0)->getRawDataPtr()); + + // copy data to dst + dnnl::reorder(sliceMemory, output) + .execute(context->getStream(), + {{DNNL_ARG_FROM, sliceMemory}, {DNNL_ARG_TO, output}}); + } 
+}; +REGISTER_KERNEL(Device::INTELCPU, OpType::Slice, DataType::Float32, MklSlice, + "Slice_Mkl_Float32"); +} // namespace infini diff --git a/src/kernels/intelcpu/softmax.cc b/src/kernels/intelcpu/softmax.cc new file mode 100644 index 00000000..f8ce568c --- /dev/null +++ b/src/kernels/intelcpu/softmax.cc @@ -0,0 +1,43 @@ +#include "operators/softmax.h" +#include "intelcpu/mkl_kernel_without_config.h" +#include "intelcpu/mkl_runtime.h" + +namespace infini { +class MklSoftmax : public MklKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + float *const srcData = op->getInputs(0)->getRawDataPtr(); + float *const dstData = op->getOutput()->getRawDataPtr(); + + // create user memory that describes data layout in the buffers + std::vector dims; + for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) + dims.push_back(op->getInputs(0)->getDims()[i]); + + auto srcMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, + getUserFormatTag(dims.size())); + auto srcMemory = dnnl::memory(srcMd, context->getEngine(), srcData); + + auto dstMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, + getUserFormatTag(dims.size())); + auto output = dnnl::memory(dstMd, context->getEngine(), dstData); + + using op_desc_t = dnnl::softmax_forward::desc; + using pd_t = dnnl::softmax_forward::primitive_desc; + + auto opDesc = + op_desc_t(dnnl::prop_kind::forward_inference, srcMd, op->getAxis()); + auto primDesc = pd_t(opDesc, context->getEngine()); + + // create and execute primitive + dnnl::softmax_forward(primDesc).execute( + context->getStream(), + {{DNNL_ARG_SRC, srcMemory}, {DNNL_ARG_DST, output}}); + } +}; +REGISTER_KERNEL(Device::INTELCPU, OpType::Softmax, DataType::Float32, + MklSoftmax, "Softmax_Mkl_Float32"); +}; // namespace infini diff --git a/src/kernels/intelcpu/split.cc b/src/kernels/intelcpu/split.cc new file mode 100644 index 00000000..654cf9a8 --- /dev/null +++ b/src/kernels/intelcpu/split.cc @@ -0,0 +1,54 @@ +#include "operators/split.h" +#include "intelcpu/mkl_kernel_without_config.h" +#include "intelcpu/mkl_runtime.h" + +namespace infini { +class MklSplit : public MklKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + std::vector dims; + for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) + dims.push_back(op->getInputs(0)->getDims()[i]); + + // create src md + auto srcMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, + getUserFormatTag(dims.size())); + + // dst md + std::vector dstsMd; + std::vector dsts; + int offset = 0; + for (size_t i = 0; i < op->getOutputs().size(); i++) { + auto oDims = op->getOutput(i)->getDims(); + int ndim = oDims.size(); + std::vector dims, offsets(ndim, 0); + for (int i = 0; i < ndim; ++i) { + dims.push_back(oDims.at(i)); + } + offsets[op->getDim()] = offset; + auto splitMd = srcMd.submemory_desc(dims, offsets); + auto splitMemory = + dnnl::memory(splitMd, context->getEngine(), + op->getInputs(0)->getRawDataPtr()); + + auto dstMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, + getUserFormatTag(dims.size())); + auto output = + dnnl::memory(dstMd, context->getEngine(), + op->getOutput(i)->getRawDataPtr()); + + // copy data to dst + dnnl::reorder(splitMemory, output) + .execute(context->getStream(), + {{DNNL_ARG_FROM, splitMemory}, {DNNL_ARG_TO, output}}); + + offset += 
dims.at(op->getDim()); + } + } +}; +REGISTER_KERNEL(Device::INTELCPU, OpType::Split, DataType::Float32, MklSplit, + "Split_Mkl_Float32"); +}; // namespace infini diff --git a/src/mkl/mkl_runtime.cc b/src/mkl/mkl_runtime.cc deleted file mode 100644 index 6b868f70..00000000 --- a/src/mkl/mkl_runtime.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include "mkl/mkl_runtime.h" -#include "core/graph.h" -#include "core/kernel.h" -namespace infini { -MklRuntimeObj::MklRuntimeObj() : CpuRuntimeObj(Device::MKL) { - dnnl_engine_create(&engine, dnnl_engine_kind_t::dnnl_cpu, 0); -} - -MklRuntimeObj::~MklRuntimeObj() { - mkl_free_buffers(); - dnnl_engine_destroy(engine); -} -} // namespace infini diff --git a/src/operators/reshape.cc b/src/operators/reshape.cc index 6ae7673b..fa45d48e 100644 --- a/src/operators/reshape.cc +++ b/src/operators/reshape.cc @@ -39,18 +39,25 @@ vector ReshapeObj::getOpAttrVector() const { return ret; } -FlattenObj::FlattenObj(GraphObj *graph, Tensor input, Tensor output) +FlattenObj::FlattenObj(GraphObj *graph, Tensor input, Tensor output, int _axis) : OperatorObj(OpType::Flatten, {input}, {output}) { + if (_axis >= 0 && (size_t)_axis < input->getDims().size()) + axis = _axis; + else if (_axis <= -1 && (size_t)_axis >= -input->getDims().size()) + axis = _axis + input->getDims().size(); + else + IT_ASSERT(0); IT_ASSERT(checkValid(graph)); } optional> FlattenObj::inferShape(const TensorVec &inputs) const { - int size = 1; + int sizeB = 1, sizeE = 1; auto dims = getInputs(0)->getDims(); - for (size_t i = 0; i < dims.size(); ++i) - size *= dims.at(i); + int ndim = dims.size(); + for (int i = 0; i < ndim; ++i) + ((i < axis) ? sizeB : sizeE) *= dims.at(i); - return {{{size}}}; + return {{{sizeB, sizeE}}}; } std::string FlattenObj::toString() const { @@ -59,18 +66,20 @@ std::string FlattenObj::toString() const { os << "("; os << vecToString(inputs[0]->getDims()) << ","; os << "input=" << inputs[0]->getGuid() << ","; - os << "output=" << outputs[0]->getGuid() << ")"; + os << "output=" << outputs[0]->getGuid() << ","; + os << "axis=" << axis << ")"; return os.str(); } vector FlattenObj::getWorkloadVector() const { vector ret = inputs[0]->getDims(); + ret.emplace(ret.begin(), axis); ret.emplace(ret.begin(), enum_to_underlying(type)); return ret; } vector FlattenObj::getOpAttrVector() const { - return {enum_to_underlying(type)}; + return {enum_to_underlying(type), axis}; } IdentityObj::IdentityObj(GraphObj *graph, Tensor input, Tensor output) diff --git a/src/operators/resize.cc b/src/operators/resize.cc index 5270abd2..2b04664f 100644 --- a/src/operators/resize.cc +++ b/src/operators/resize.cc @@ -70,25 +70,6 @@ void ResizeObj::init(const Tensor &input, const Tensor &sizes, } } } -/* -Operator ResizeObj::clone(TensorVec inputs, TensorVec outputs) { - Tensor roi{nullptr}, sizes{nullptr}, scales{nullptr}; - if (inputs.size() == 3) - roi = inputs[2]; - if (isResizeBySizes()) - sizes = inputs[1]; - else - scales = inputs[1]; - - if (mode == ECoeffMode::nearest) - return make_ref(nullptr, inputs[0], outputs[0], axes, - inputs[1], nullptr, roi, ratioPolicy, - nearestMode, coMode); - else - return make_ref(nullptr, inputs[0], outputs[0], axes, - inputs[1], nullptr, roi, mode, ratioPolicy, - coMode); -}*/ void ResizeObj::InitBySizes(Tensor input, Tensor sizes, const std::optional> &axes) { diff --git a/src/operators/softmax.cc b/src/operators/softmax.cc new file mode 100644 index 00000000..c5bd7d25 --- /dev/null +++ b/src/operators/softmax.cc @@ -0,0 +1,37 @@ +#include "operators/softmax.h" + +namespace 
infini { + +SoftmaxObj::SoftmaxObj(GraphObj *graph, Tensor input, Tensor output, int _axis) + : OperatorObj(OpType::Softmax, {input}, {output}) { + if (_axis >= 0 && (size_t)_axis < input->getDims().size()) + axis = _axis; + else if (_axis <= -1 && (size_t)_axis >= -input->getDims().size()) + axis = _axis + input->getDims().size(); + else + IT_ASSERT(0); + IT_ASSERT(checkValid(graph)); +} + +std::string SoftmaxObj::toString() const { + std::ostringstream os; + os << OpRegistry::getOpName(type) << "[" << getGuid() << "]"; + os << "("; + os << vecToString(inputs[0]->getDims()) << ","; + os << "input=" << inputs[0]->getGuid() << ","; + os << "output=" << outputs[0]->getGuid() << ","; + os << "axis=" << axis << ")"; + return os.str(); +} + +vector SoftmaxObj::getWorkloadVector() const { + vector ret{enum_to_underlying(type), axis}; + const Shape shape = outputs[0]->getDims(); + ret.insert(ret.end(), shape.begin(), shape.end()); + return ret; +} + +vector SoftmaxObj::getOpAttrVector() const { + return {enum_to_underlying(type), axis}; +} +} // namespace infini diff --git a/test/kernels/cuda/test_cuda_reshape.cc b/test/kernels/cuda/test_cuda_reshape.cc index 843caa4b..7e4a9c0c 100644 --- a/test/kernels/cuda/test_cuda_reshape.cc +++ b/test/kernels/cuda/test_cuda_reshape.cc @@ -51,7 +51,7 @@ TEST(CUDA_Flatten, run) { // Build CUDA graph Graph g = make_ref(cudaRuntime); auto i = g->cloneTensor(icpu); - auto op = g->addOp(i, nullptr); + auto op = g->addOp(i, nullptr, 2); // allocate CUDA memory g->dataMalloc(); diff --git a/test/kernels/cuda/test_cuda_softmax.cc b/test/kernels/cuda/test_cuda_softmax.cc new file mode 100644 index 00000000..5a07ca78 --- /dev/null +++ b/test/kernels/cuda/test_cuda_softmax.cc @@ -0,0 +1,142 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "operators/softmax.h" +#include "test.h" +#include +namespace infini { + +TEST(cuDNN_Softmax, run_axis1) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto cudaRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = + make_ref(Shape{2, 4}, DataType::Float32, cpuRuntime); + inputCpu->dataMalloc(); + inputCpu->copyin(vector{0, 1, 2, 3, 10000, 10001, 10002, 10003}); + + // GPU + Graph cudaGraph = make_ref(cudaRuntime); + auto inputGpu = cudaGraph->cloneTensor(inputCpu); + auto gpuOp = cudaGraph->addOp(inputGpu, nullptr, 1); + cudaGraph->dataMalloc(); + cudaRuntime->run(cudaGraph); + auto outputGpu = gpuOp->getOutput(); + auto outputGpu2Cpu = outputGpu->clone(cpuRuntime); + cudaPrintTensor(outputGpu); + // Check + EXPECT_TRUE(outputGpu2Cpu->equalData( + vector{0.032058604, 0.08714432, 0.23688284, 0.6439143, + 0.032058604, 0.08714432, 0.23688284, 0.6439143})); +} + +TEST(cuDNN_Softmax, run_axis0) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto cudaRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = + make_ref(Shape{2, 4}, DataType::Float32, cpuRuntime); + inputCpu->dataMalloc(); + inputCpu->copyin(vector{0, 1, 2, 3, 10000, 10001, 10002, 10003}); + + // GPU + Graph cudaGraph = make_ref(cudaRuntime); + auto inputGpu = cudaGraph->cloneTensor(inputCpu); + auto gpuOp = cudaGraph->addOp(inputGpu, nullptr, 0); + cudaGraph->dataMalloc(); + cudaRuntime->run(cudaGraph); + auto outputGpu = gpuOp->getOutput(); + auto outputGpu2Cpu = outputGpu->clone(cpuRuntime); + cudaPrintTensor(outputGpu); + // Check + EXPECT_TRUE( + 
outputGpu2Cpu->equalData(vector{0., 0., 0., 0., 1, 1, 1, 1})); +} + +TEST(cuDNN_Softmax2, run_axis1) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto cudaRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = + make_ref(Shape{2, 2, 2, 2}, DataType::Float32, cpuRuntime); + inputCpu->dataMalloc(); + inputCpu->setData(IncrementalGenerator()); + + // GPU + Graph cudaGraph = make_ref(cudaRuntime); + auto inputGpu = cudaGraph->cloneTensor(inputCpu); + auto gpuOp = cudaGraph->addOp(inputGpu, nullptr, 1); + cudaGraph->dataMalloc(); + cudaRuntime->run(cudaGraph); + auto outputGpu = gpuOp->getOutput(); + auto outputGpu2Cpu = outputGpu->clone(cpuRuntime); + cudaPrintTensor(outputGpu); + // Check + EXPECT_TRUE(outputGpu2Cpu->equalData(vector{ + 0.0179862, 0.0179862, 0.0179862, 0.0179862, 0.9820138, 0.9820138, + 0.9820138, 0.9820138, 0.0179862, 0.0179862, 0.0179862, 0.0179862, + 0.9820138, 0.9820138, 0.9820138, 0.9820138})); +} + +TEST(cuDNN_Softmax2, run_axis2) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto cudaRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = + make_ref(Shape{2, 2, 2, 2}, DataType::Float32, cpuRuntime); + inputCpu->dataMalloc(); + inputCpu->setData(IncrementalGenerator()); + + // GPU + Graph cudaGraph = make_ref(cudaRuntime); + auto inputGpu = cudaGraph->cloneTensor(inputCpu); + auto gpuOp = cudaGraph->addOp(inputGpu, nullptr, 2); + cudaGraph->dataMalloc(); + cudaRuntime->run(cudaGraph); + auto outputGpu = gpuOp->getOutput(); + auto outputGpu2Cpu = outputGpu->clone(cpuRuntime); + cudaPrintTensor(outputGpu); + // Check + EXPECT_TRUE(outputGpu2Cpu->equalData(vector{ + 0.1192029, 0.1192029, 0.8807971, 0.8807971, 0.1192029, 0.1192029, + 0.8807971, 0.8807971, 0.1192029, 0.1192029, 0.8807971, 0.8807971, + 0.1192029, 0.1192029, 0.8807971, 0.8807971})); +} + +TEST(cuDNN_Softmax2, run_axis3) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto cudaRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = + make_ref(Shape{2, 2, 2, 2}, DataType::Float32, cpuRuntime); + inputCpu->dataMalloc(); + inputCpu->setData(IncrementalGenerator()); + + // GPU + Graph cudaGraph = make_ref(cudaRuntime); + auto inputGpu = cudaGraph->cloneTensor(inputCpu); + auto gpuOp = cudaGraph->addOp(inputGpu, nullptr, 3); + cudaGraph->dataMalloc(); + cudaRuntime->run(cudaGraph); + auto outputGpu = gpuOp->getOutput(); + auto outputGpu2Cpu = outputGpu->clone(cpuRuntime); + cudaPrintTensor(outputGpu); + // Check + EXPECT_TRUE(outputGpu2Cpu->equalData(vector{ + 0.2689414, 0.7310586, 0.2689414, 0.7310586, 0.2689414, 0.7310586, + 0.2689414, 0.7310586, 0.2689414, 0.7310586, 0.2689414, 0.7310586, + 0.2689414, 0.7310586, 0.2689414, 0.7310586})); +} +} // namespace infini diff --git a/test/kernels/cuda/test_cuda_unary.cc b/test/kernels/cuda/test_cuda_unary.cc index 6aac66d4..c7beb760 100644 --- a/test/kernels/cuda/test_cuda_unary.cc +++ b/test/kernels/cuda/test_cuda_unary.cc @@ -41,7 +41,6 @@ void testUnary(const std::function &generator, TEST(cuDNN_Unary, run) { testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); - testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); diff --git a/test/kernels/intelcpu/test_mkl_batch_norm.cc b/test/kernels/intelcpu/test_mkl_batch_norm.cc new file mode 100644 index 00000000..24c87474 --- 
/dev/null +++ b/test/kernels/intelcpu/test_mkl_batch_norm.cc @@ -0,0 +1,36 @@ + +#include "core/graph.h" +#include "core/kernel.h" +#include "intelcpu/mkl_runtime.h" +#include "operators/batch_norm.h" +#include "test.h" + +namespace infini { +TEST(MklBatchNorm, run) { + // Runtime + auto runtime = make_ref(); + + // Build graph + Graph g = make_ref(runtime); + auto i = g->addTensor(Shape{1, 3, 2, 2}, DataType::Float32); + auto mean = g->addTensor(Shape{1, 3, 1, 1}, DataType::Float32); + auto var = g->addTensor(Shape{1, 3, 1, 1}, DataType::Float32); + auto scale = g->addTensor(Shape{1, 3, 1, 1}, DataType::Float32); + auto bias = g->addTensor(Shape{1, 3, 1, 1}, DataType::Float32); + auto op = + g->addOp(i, nullptr, mean, var, scale, bias, 0.9, 0); + g->dataMalloc(); + i->setData(IncrementalGenerator()); + mean->copyin(vector{1, 6, 9}); + var->copyin(vector{4, 1, 9}); + scale->setData(OneGenerator()); + bias->setData(ZeroGenerator()); + + runtime->run(g); + + auto o = op->getOutput(); + EXPECT_TRUE(o->equalData(vector{ + -0.5, 0, 0.5, 1, -2, -1, 0, 1, -0.3333333, 0, 0.3333333, 0.6666667})); +} + +} // namespace infini diff --git a/test/kernels/intelcpu/test_mkl_concat.cc b/test/kernels/intelcpu/test_mkl_concat.cc new file mode 100644 index 00000000..63760e43 --- /dev/null +++ b/test/kernels/intelcpu/test_mkl_concat.cc @@ -0,0 +1,29 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "intelcpu/mkl_runtime.h" +#include "operators/concat.h" + +#include "test.h" + +namespace infini { + +TEST(Concat, Mkl) { + Runtime runtime = MklRuntimeObj::getInstance(); + Graph g = make_ref(runtime); + + auto t1 = g->addTensor({2, 2, 3, 1}, DataType::Float32); + auto t2 = g->addTensor({2, 2, 1, 1}, DataType::Float32); + auto t3 = g->addTensor({2, 2, 2, 1}, DataType::Float32); + auto op = g->addOp(TensorVec{t1, t2, t3}, nullptr, 2); + g->dataMalloc(); + t1->setData(IncrementalGenerator()); + t2->setData(OneGenerator()); + t3->setData(OneGenerator()); + + runtime->run(g); + EXPECT_TRUE(op->getOutput()->equalData( + vector{0, 1, 2, 1, 1, 1, 3, 4, 5, 1, 1, 1, + 6, 7, 8, 1, 1, 1, 9, 10, 11, 1, 1, 1})); +} + +} // namespace infini diff --git a/test/kernels/mkl/test_mkl_conv.cc b/test/kernels/intelcpu/test_mkl_conv.cc similarity index 90% rename from test/kernels/mkl/test_mkl_conv.cc rename to test/kernels/intelcpu/test_mkl_conv.cc index 4ba5fd7f..96fd5498 100644 --- a/test/kernels/mkl/test_mkl_conv.cc +++ b/test/kernels/intelcpu/test_mkl_conv.cc @@ -2,7 +2,7 @@ #include "core/kernel.h" #include "core/perf_engine.h" #include "core/runtime.h" -#include "mkl/mkl_runtime.h" +#include "intelcpu/mkl_runtime.h" #include "operators/conv.h" #include "test.h" @@ -17,18 +17,15 @@ void testConvDnnl( Tensor i0 = gMkl->addTensor({1, 3, 4, 4}, DataType::Float32); Tensor w0 = gMkl->addTensor({2, 3, 3, 3}, DataType::Float32); + + // Build graph + auto conv = gMkl->addOp(i0, w0, nullptr, 1, 1, 2, 1, 1, 2); // Malloc data for all tensors in a graph. 
gMkl->dataMalloc(); i0->setData(generator); w0->setData(generator); - // Build graph - auto conv = gMkl->addOp(i0, w0, nullptr, 1, 1, 2, 1, 1, 2); - // allocate CUDA memory - gMkl->dataMalloc(); - // Execute on CUDA mklRuntime->run(gMkl); - // check results on CPU EXPECT_TRUE(conv->getOutput(0)->equalData(ansVec)); } @@ -57,7 +54,7 @@ TEST(mkl_Conv, tune) { // check record auto kernelAttrs = - KernelAttrs{Device::MKL, conv->getOpType(), DataType::Float32}; + KernelAttrs{Device::INTELCPU, conv->getOpType(), DataType::Float32}; auto perfKey = PerfEngine::Key{kernelAttrs, conv->getOpPerfKey()}; std::optional perfData = PerfEngine::getInstance().getPerfData(perfKey); diff --git a/test/kernels/mkl/test_mkl_conv_transposed.cc b/test/kernels/intelcpu/test_mkl_conv_transposed.cc similarity index 93% rename from test/kernels/mkl/test_mkl_conv_transposed.cc rename to test/kernels/intelcpu/test_mkl_conv_transposed.cc index ab869896..44b04174 100644 --- a/test/kernels/mkl/test_mkl_conv_transposed.cc +++ b/test/kernels/intelcpu/test_mkl_conv_transposed.cc @@ -1,7 +1,7 @@ #include "core/graph.h" #include "core/kernel.h" #include "core/perf_engine.h" -#include "mkl/mkl_runtime.h" +#include "intelcpu/mkl_runtime.h" #include "operators/conv.h" #include "test.h" @@ -26,7 +26,7 @@ void testConvTransposedMkl( i0->setData(generator); w0->setData(generator); - runtime->prepareAndRun(gMkl); + runtime->run(gMkl); EXPECT_TRUE(conv->getOutput()->equalData(ansVec)); } @@ -50,7 +50,7 @@ TEST(mkl_ConvTransposed, run1) { i0->setData(IncrementalGenerator()); w0->setData(IncrementalGenerator()); - runtime->prepareAndRun(gMkl); + runtime->run(gMkl); EXPECT_TRUE(conv->getOutput()->equalData(vector{ 162, 351, 569, 413, 224, 405, 876, 1417, 1024, 553, 747, 1611, 2598, 1869, 1005, 639, 1368, 2191, 1564, 835, @@ -71,10 +71,10 @@ TEST(mkl_ConvTransposed, tune) { w0->setData(IncrementalGenerator()); bool tune = true; - runtime->prepareAndRun(gMkl, tune); + runtime->run(gMkl, tune); // check record auto kernelAttrs = - KernelAttrs{Device::MKL, conv->getOpType(), DataType::Float32}; + KernelAttrs{Device::INTELCPU, conv->getOpType(), DataType::Float32}; auto perfKey = PerfEngine::Key{kernelAttrs, conv->getOpPerfKey()}; std::optional perfData = PerfEngine::getInstance().getPerfData(perfKey); diff --git a/test/kernels/intelcpu/test_mkl_element_wise.cc b/test/kernels/intelcpu/test_mkl_element_wise.cc new file mode 100644 index 00000000..9b5a978c --- /dev/null +++ b/test/kernels/intelcpu/test_mkl_element_wise.cc @@ -0,0 +1,84 @@ + +#include "core/graph.h" +#include "core/kernel.h" +#include "intelcpu/mkl_runtime.h" +#include "operators/element_wise.h" +#include "operators/unary.h" +#include "test.h" + +namespace infini { + +using ExpectOutput = vector; +template +void testBinary(const std::function &generator, + const Shape &shape, const ExpectOutput &ansVec) { + Runtime runtime = MklRuntimeObj::getInstance(); + Graph g = make_ref(runtime); + auto a = g->addTensor(shape, DataType::Float32); + auto b = g->addTensor(shape, DataType::Float32); + auto op = g->addOp(a, b, nullptr); + g->dataMalloc(); + a->setData(generator); + b->setData(generator); + + runtime->run(g); + + auto c = op->getOutput(); + // check results on CPU + EXPECT_TRUE(c->equalData(ansVec)); +} + +TEST(dnnl_Binary, run) { + testBinary(IncrementalGenerator(), Shape{1, 2, 2, 3}, + ExpectOutput{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22}); + testBinary(IncrementalGenerator(), Shape{1, 2, 2, 3}, + ExpectOutput{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + testBinary( + 
IncrementalGenerator(), Shape{1, 2, 2, 3}, + ExpectOutput{0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121}); + + testBinary(OneGenerator(), Shape{1, 2, 2, 3}, + ExpectOutput{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); +} + +TEST(sycl_Pow, run) { + testBinary(IncrementalGenerator(), Shape{1, 2, 2, 1}, + ExpectOutput{1, 1, 4, 27}); +} + +template +void testUnary(const std::function &generator, + const Shape &shape) { + // Runtime + Runtime rCpu = NativeCpuRuntimeObj::getInstance(); + auto rMkl = make_ref(); + + // Build input data on CPU + + Graph gCpu = make_ref(rCpu); + Tensor iCpu = gCpu->addTensor(shape, DataType::Float32); + auto opCpu = gCpu->addOp(iCpu, nullptr); + gCpu->dataMalloc(); + iCpu->setData(generator); + rCpu->run(gCpu); + + // MKL + Graph gMkl = make_ref(rMkl); + auto iMkl = gMkl->addTensor(shape, DataType::Float32); + auto opMkl = gMkl->addOp(iMkl, nullptr); + gMkl->dataMalloc(); + iMkl->setData(generator); + rMkl->run(gMkl); + + // Check + EXPECT_TRUE(opCpu->getOutput()->equalData(opMkl->getOutput())); +} + +TEST(dnnl_Unary, run) { + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); +} + +}; // namespace infini diff --git a/test/kernels/intelcpu/test_mkl_extend.cc b/test/kernels/intelcpu/test_mkl_extend.cc new file mode 100644 index 00000000..8375ad51 --- /dev/null +++ b/test/kernels/intelcpu/test_mkl_extend.cc @@ -0,0 +1,31 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "intelcpu/mkl_runtime.h" +#include "operators/extend.h" + +#include "test.h" + +namespace infini { + +TEST(MKL_Extend, run) { + Runtime runtime = MklRuntimeObj::getInstance(); + + Graph g = make_ref(runtime); + Tensor i = g->addTensor(Shape{2, 3, 2, 2}, DataType::Float32); + auto op = g->addOp(i, nullptr, 1, 1); + g->dataMalloc(); + i->setData(IncrementalGenerator()); + + // Execute + runtime->run(g); + + auto o = op->getOutput(); + + // check results on CPU + EXPECT_TRUE(o->equalData(vector{ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23})); +} +} // namespace infini diff --git a/test/kernels/intelcpu/test_mkl_gather.cc b/test/kernels/intelcpu/test_mkl_gather.cc new file mode 100644 index 00000000..1fc1f09b --- /dev/null +++ b/test/kernels/intelcpu/test_mkl_gather.cc @@ -0,0 +1,60 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "intelcpu/mkl_runtime.h" +#include "operators/gather.h" + +#include "test.h" + +namespace infini { +TEST(Gather, Cuda) { + { + Runtime runtime = MklRuntimeObj::getInstance(); + Graph g = make_ref(runtime); + auto input = g->addTensor({3, 2}, DataType::Float32); + auto index = g->addTensor({2, 2}, DataType::UInt32); + g->dataMalloc(); + input->copyin(vector{1, 2, 3, 4, 5, 6}); + index->copyin(vector{0, 1, 1, 2}); + + auto op = g->addOp(input, index, nullptr, 0); + g->dataMalloc(); + runtime->run(g); + + EXPECT_TRUE( + op->getOutput()->equalData(vector{1, 2, 3, 4, 3, 4, 5, 6})); + } + { + Runtime runtime = MklRuntimeObj::getInstance(); + Graph g = make_ref(runtime); + auto input = g->addTensor({3, 3}, DataType::Float32); + auto index = g->addTensor({1, 2}, DataType::UInt32); + g->dataMalloc(); + input->setData(IncrementalGenerator()); + index->copyin(vector{0, 2}); + + auto op = g->addOp(input, index, nullptr, 
1); + g->dataMalloc(); + runtime->run(g); + + EXPECT_TRUE( + op->getOutput()->equalData(vector{0, 2, 3, 5, 6, 8})); + } + { + Runtime runtime = MklRuntimeObj::getInstance(); + Graph g = make_ref(runtime); + auto input = g->addTensor({2, 4, 2}, DataType::Float32); + auto index = g->addTensor({3, 1}, DataType::UInt32); + g->dataMalloc(); + input->setData(IncrementalGenerator()); + index->copyin(vector{0, 3, 1}); + + auto op = g->addOp(input, index, nullptr, 1); + g->dataMalloc(); + runtime->run(g); + + EXPECT_TRUE(op->getOutput()->equalData( + vector{0, 1, 6, 7, 2, 3, 8, 9, 14, 15, 10, 11})); + } +} +} // namespace infini diff --git a/test/kernels/mkl/test_mkl_matmul.cc b/test/kernels/intelcpu/test_mkl_matmul.cc similarity index 95% rename from test/kernels/mkl/test_mkl_matmul.cc rename to test/kernels/intelcpu/test_mkl_matmul.cc index e919ffd4..8fcfe964 100644 --- a/test/kernels/mkl/test_mkl_matmul.cc +++ b/test/kernels/intelcpu/test_mkl_matmul.cc @@ -2,7 +2,7 @@ #include "core/graph.h" #include "core/kernel.h" #include "core/runtime.h" -#include "mkl/mkl_runtime.h" +#include "intelcpu/mkl_runtime.h" #include "operators/matmul.h" #include "test.h" @@ -27,7 +27,6 @@ void testMatmulMkl( gCpu->dataMalloc(); cpuRuntime->run(gCpu); - matmul->getOutput()->printData(); EXPECT_TRUE(matmul->getOutput()->equalData(ansVec)); } diff --git a/test/kernels/intelcpu/test_mkl_pad.cc b/test/kernels/intelcpu/test_mkl_pad.cc new file mode 100644 index 00000000..d0859421 --- /dev/null +++ b/test/kernels/intelcpu/test_mkl_pad.cc @@ -0,0 +1,30 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "intelcpu/mkl_runtime.h" +#include "operators/pad.h" +#include "test.h" + +namespace infini { +TEST(Pad, Mkl) { + Runtime runtime = MklRuntimeObj::getInstance(); + Graph g = make_ref(runtime); + + // Build input data + Tensor i = g->addTensor(Shape{1, 2, 3, 2}, DataType::Float32); + auto op = g->addOp(i, nullptr, vector{1, 0, 1, 1}, + vector{0, 3}); + g->dataMalloc(); + i->setData(IncrementalGenerator()); + + // Execute + runtime->run(g); + + auto o = op->getOutput(); + + // check results + EXPECT_TRUE(o->equalData( + vector{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 2, 3, 0, 4, 5, 0, 6, 7, 0, 8, 9, 0, 10, 11, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})); +} +} // namespace infini diff --git a/test/kernels/intelcpu/test_mkl_pooling.cc b/test/kernels/intelcpu/test_mkl_pooling.cc new file mode 100644 index 00000000..5d25bb22 --- /dev/null +++ b/test/kernels/intelcpu/test_mkl_pooling.cc @@ -0,0 +1,47 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "intelcpu/mkl_runtime.h" +#include "operators/pooling.h" +#include "test.h" + +namespace infini { +using KDPS = vector; +using ExpectOutput = vector; + +template +void testPoolMkl(const std::function &generator, + const Shape &shape, const KDPS &kdps, + const ExpectOutput &ansVec) { + EXPECT_TRUE(kdps.size() == 8); + Runtime runtime = MklRuntimeObj::getInstance(); + + Graph g = make_ref(runtime); + // Build input data + Tensor i0 = g->addTensor(shape, DataType::Float32); + auto pool = g->addOp(i0, nullptr, kdps[0], kdps[1], kdps[2], kdps[3], + kdps[4], kdps[5], kdps[6], kdps[7]); + g->dataMalloc(); + i0->setData(generator); + + runtime->run(g); + // check results on CPU + EXPECT_TRUE(pool->getOutput()->equalData(ansVec)); +} + +TEST(mkl_MaxPool, run) { + testPoolMkl(IncrementalGenerator(), Shape{1, 2, 5, 5}, + KDPS{3, 3, 1, 1, 1, 1, 2, 2}, + ExpectOutput{6, 8, 9, 16, 18, 19, 21, 23, 24, 31, + 33, 34, 41, 43, 44, 46, 
48, 49}); +} + +TEST(mkl_AvgPool, run) { + testPoolMkl( + IncrementalGenerator(), Shape{1, 2, 5, 5}, KDPS{3, 3, 1, 1, 1, 1, 2, 2}, + ExpectOutput{1.333333, 3.0000, 2.666667, 7.0000, 12.0000, 9.0000, + 8.0000, 13.0000, 9.333333, 12.44444, 19.666667, 13.777778, + 23.666667, 37.0000, 25.666667, 19.111111, 29.666667, + 20.444444}); +} + +} // namespace infini diff --git a/test/kernels/intelcpu/test_mkl_reduce.cc b/test/kernels/intelcpu/test_mkl_reduce.cc new file mode 100644 index 00000000..859a1f91 --- /dev/null +++ b/test/kernels/intelcpu/test_mkl_reduce.cc @@ -0,0 +1,52 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "intelcpu/mkl_runtime.h" +#include "operators/reduce_mean.h" + +#include "test.h" + +namespace infini { + +void test_reducemean(const Shape &shape, const vector &data, + const optional> &axis, bool keepDims, + const vector &ExpectData) { + Runtime runtime = MklRuntimeObj::getInstance(); + + Graph g = make_ref(runtime); + Tensor i = g->addTensor(shape, DataType::Float32); + auto op = g->addOp(i, nullptr, axis, keepDims); + + g->dataMalloc(); + i->copyin(data); + + // Execute + runtime->run(g); + + auto o = op->getOutput(); + + // check results + EXPECT_TRUE(o->equalData(ExpectData)); +} + +TEST(MKL_ReduceMean, run) { + test_reducemean(Shape{3, 2, 2}, + vector{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2}, + std::nullopt, true, vector{18.25}); + test_reducemean(Shape{1, 3, 2, 2, 1}, + vector{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2}, + std::nullopt, false, vector{18.25}); + + test_reducemean(Shape{2, 3, 2, 2}, + vector{0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23}, + vector{1, 2}, false, vector{5, 6, 17, 18}); + test_reducemean(Shape{2, 3, 2, 2, 1}, + vector{0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23}, + vector{1, 2}, true, vector{5, 6, 17, 18}); +} + +} // namespace infini diff --git a/test/kernels/intelcpu/test_mkl_reshape.cc b/test/kernels/intelcpu/test_mkl_reshape.cc new file mode 100644 index 00000000..8baec112 --- /dev/null +++ b/test/kernels/intelcpu/test_mkl_reshape.cc @@ -0,0 +1,57 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "intelcpu/mkl_runtime.h" +#include "operators/reshape.h" + +#include "test.h" + +namespace infini { + +TEST(Reshape, Mkl) { + Runtime runtime = MklRuntimeObj::getInstance(); + Graph g = make_ref(runtime); + + auto input = g->addTensor({2, 3, 3, 4}, DataType::Float32); + auto op = g->addOp(input, nullptr, Shape{3, 2, 4, 3}); + g->dataMalloc(); + input->setData(IncrementalGenerator()); + + runtime->run(g); + + auto o = g->cloneTensor(op->getOutput(0)); + // check results + EXPECT_TRUE(o->equalData(input)); +} + +TEST(Flatten, Mkl) { + Runtime runtime = MklRuntimeObj::getInstance(); + Graph g = make_ref(runtime); + + auto input = g->addTensor({2, 3, 3, 4}, DataType::Float32); + auto op = g->addOp(input, nullptr, 2); + g->dataMalloc(); + input->setData(IncrementalGenerator()); + + runtime->run(g); + + auto o = g->cloneTensor(op->getOutput(0)); + // check results + EXPECT_TRUE(o->equalData(input)); +} + +TEST(Identify, Mkl) { + Runtime runtime = MklRuntimeObj::getInstance(); + Graph g = make_ref(runtime); + + auto input = g->addTensor({2, 3, 3, 4}, DataType::Float32); + auto op = g->addOp(input, nullptr); + g->dataMalloc(); + input->setData(IncrementalGenerator()); + + runtime->run(g); + + auto o = g->cloneTensor(op->getOutput(0)); + // check results + EXPECT_TRUE(o->equalData(input)); +} +} // namespace 
infini diff --git a/test/kernels/intelcpu/test_mkl_resize.cc b/test/kernels/intelcpu/test_mkl_resize.cc new file mode 100644 index 00000000..c3c71a9d --- /dev/null +++ b/test/kernels/intelcpu/test_mkl_resize.cc @@ -0,0 +1,30 @@ +#include "cmath" +#include "core/graph.h" +#include "core/runtime.h" +#include "intelcpu/mkl_runtime.h" +#include "operators/resize.h" +#include "test.h" +namespace infini { +TEST(Resize, Mkl_downsample_sizes_nearest) { + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(cpuRuntime); + + auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32); + auto sizes = gCpu->addTensor({4}, DataType::UInt32); + gCpu->dataMalloc(); + input->copyin(vector{1, 2, 3, 4, 5, 6, 7, 8}); + sizes->copyin(vector{1, 1, 1, 3}); + + auto runtime = make_ref(); + Graph g = make_ref(runtime); + + auto op = g->addOp(g->cloneTensor(input), nullptr, std::nullopt, + g->cloneTensor(sizes), nullptr, nullptr, + ResizeObj::EKeepAspectRatioPolicy::stretch, + ResizeObj::ENearestMode::ceil); + g->dataMalloc(); + runtime->run(g); + + EXPECT_TRUE(op->getOutput(0)->equalData(vector{5, 7, 8})); +} +} // namespace infini diff --git a/test/kernels/intelcpu/test_mkl_slice.cc b/test/kernels/intelcpu/test_mkl_slice.cc new file mode 100644 index 00000000..04a5ae86 --- /dev/null +++ b/test/kernels/intelcpu/test_mkl_slice.cc @@ -0,0 +1,26 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "intelcpu/mkl_runtime.h" +#include "operators/slice.h" +#include "test.h" + +namespace infini { +TEST(MKL_Slice, run) { + Runtime runtime = MklRuntimeObj::getInstance(); + Graph g = make_ref(runtime); + + // Build input data + Tensor i = g->addTensor(Shape{3, 2, 1, 5}, DataType::Float32); + auto op = + g->addOp(i, nullptr, vector{1, 1}, vector{1, 4}, + vector{0, 3}, std::nullopt); + g->dataMalloc(); + i->setData(IncrementalGenerator()); + + // Execute + runtime->run(g); + + auto o = op->getOutput(); + EXPECT_TRUE(o->equalData(vector{11, 12, 13, 14, 16, 17, 18, 19})); +} +} // namespace infini diff --git a/test/kernels/intelcpu/test_mkl_softmax.cc b/test/kernels/intelcpu/test_mkl_softmax.cc new file mode 100644 index 00000000..e802e7d9 --- /dev/null +++ b/test/kernels/intelcpu/test_mkl_softmax.cc @@ -0,0 +1,83 @@ + +#include "core/graph.h" +#include "core/kernel.h" +#include "intelcpu/mkl_runtime.h" +#include "operators/softmax.h" +#include "test.h" + +namespace infini { +TEST(MklSoftmax, run) { + // Runtime + auto runtime = make_ref(); + + // Build input data on intelcpu + Graph g = make_ref(runtime); + Tensor i = g->addTensor(Shape{2, 4}, DataType::Float32); + auto op = g->addOp(i, nullptr, 1); + g->dataMalloc(); + i->copyin(vector{0, 1, 2, 3, 10000, 10001, 10002, 10003}); + runtime->run(g); + + // Check + EXPECT_TRUE(op->getOutput(0)->equalData( + vector{0.032058604, 0.08714432, 0.23688284, 0.6439143, + 0.032058604, 0.08714432, 0.23688284, 0.6439143})); +} + +TEST(MklSoftmax, run_axis1) { + // Runtime + auto runtime = make_ref(); + + // Build input data on intelcpu + Graph g = make_ref(runtime); + Tensor i = g->addTensor(Shape{2, 2, 2, 2}, DataType::Float32); + auto op = g->addOp(i, nullptr, 1); + g->dataMalloc(); + i->setData(IncrementalGenerator()); + runtime->run(g); + + // Check + EXPECT_TRUE(op->getOutput(0)->equalData(vector{ + 0.0179862, 0.0179862, 0.0179862, 0.0179862, 0.9820138, 0.9820138, + 0.9820138, 0.9820138, 0.0179862, 0.0179862, 0.0179862, 0.0179862, + 0.9820138, 0.9820138, 0.9820138, 0.9820138})); +} + +TEST(MklSoftmax, run_axis2) { + // Runtime + auto runtime = 
make_ref(); + + // Build input data on intelcpu + Graph g = make_ref(runtime); + Tensor i = g->addTensor(Shape{2, 2, 2, 2}, DataType::Float32); + auto op = g->addOp(i, nullptr, 2); + g->dataMalloc(); + i->setData(IncrementalGenerator()); + runtime->run(g); + + // Check + EXPECT_TRUE(op->getOutput(0)->equalData(vector{ + 0.119203, 0.119203, 0.880797, 0.880797, 0.119203, 0.119203, 0.880797, + 0.880797, 0.119203, 0.119203, 0.880797, 0.880797, 0.119203, 0.119203, + 0.880797, 0.880797})); +} + +TEST(MklSoftmax, run_axis3) { + // Runtime + auto runtime = make_ref(); + + // Build input data on intelcpu + Graph g = make_ref(runtime); + Tensor i = g->addTensor(Shape{2, 2, 2, 2}, DataType::Float32); + auto op = g->addOp(i, nullptr, 3); + g->dataMalloc(); + i->setData(IncrementalGenerator()); + runtime->run(g); + + // Check + EXPECT_TRUE(op->getOutput(0)->equalData(vector{ + 0.2689414, 0.7310585, 0.2689414, 0.7310585, 0.2689414, 0.7310585, + 0.2689414, 0.7310585, 0.2689414, 0.7310585, 0.2689414, 0.7310585, + 0.2689414, 0.7310585, 0.2689414, 0.7310585})); +} +} // namespace infini diff --git a/test/kernels/intelcpu/test_mkl_split.cc b/test/kernels/intelcpu/test_mkl_split.cc new file mode 100644 index 00000000..94509672 --- /dev/null +++ b/test/kernels/intelcpu/test_mkl_split.cc @@ -0,0 +1,33 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "intelcpu/mkl_runtime.h" +#include "operators/split.h" + +#include "test.h" + +namespace infini { + +TEST(Split, Mkl) { + Runtime runtime = MklRuntimeObj::getInstance(); + Graph g = make_ref(runtime); + + auto input = g->addTensor({2, 10, 2, 1}, DataType::Float32); + auto op = g->addOp(input, std::nullopt, 1, 3); + g->dataMalloc(); + input->setData(IncrementalGenerator()); + + runtime->run(g); + + EXPECT_EQ(op->getOutputs().size(), (size_t)3); + auto o0 = g->cloneTensor(op->getOutput(0)); + auto o1 = g->cloneTensor(op->getOutput(1)); + auto o2 = g->cloneTensor(op->getOutput(2)); + EXPECT_TRUE( + o0->equalData(vector{0, 1, 2, 3, 4, 5, 20, 21, 22, 23, 24, 25})); + EXPECT_TRUE(o1->equalData( + vector{6, 7, 8, 9, 10, 11, 26, 27, 28, 29, 30, 31})); + EXPECT_TRUE(o2->equalData(vector{12, 13, 14, 15, 16, 17, 18, 19, 32, + 33, 34, 35, 36, 37, 38, 39})); +} + +} // namespace infini diff --git a/test/operators/test_reshape.cc b/test/operators/test_reshape.cc index 457a06ea..00ab514f 100644 --- a/test/operators/test_reshape.cc +++ b/test/operators/test_reshape.cc @@ -21,8 +21,26 @@ TEST(Flatten, ShapeInference) { { Graph g = make_ref(runtime); Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32); - auto op = g->addOp(i, nullptr); - EXPECT_EQ(op->getOutput()->getDims(), (Shape{72})); + auto op = g->addOp(i, nullptr, 1); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 36})); + } + { + Graph g = make_ref(runtime); + Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32); + auto op = g->addOp(i, nullptr, 0); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 72})); + } + { + Graph g = make_ref(runtime); + Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32); + auto op = g->addOp(i, nullptr, -1); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{18, 4})); + } + { + Graph g = make_ref(runtime); + Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32); + auto op = g->addOp(i, nullptr, -2); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{6, 12})); } } diff --git a/test/script/env_lotus.sh b/test/script/env_lotus.sh index 72268491..428024f1 100644 --- a/test/script/env_lotus.sh +++ b/test/script/env_lotus.sh @@ -1,5 +1,26 @@ +#!/bin/bash + . 
/home/spack/spack/share/spack/setup-env.sh -spack load cuda@11.0.2 cudnn@8.0.3.33-11.0 intel-oneapi-dnn@2022.1.0 intel-oneapi-mkl@2022.1.0 -export CUDAHOSTCXX=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-9.4.0/gcc-9.4.0-st36klijpsnquihiy463hmedsyhoc3g6/bin/gcc -# The default dnnl library is cpu_dpcpp_gpu_dpcpp which requires libsycl.so, after "spack load", and need to change to gomp explicitly. -export LD_LIBRARY_PATH=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-12.1.0/intel-oneapi-dnn-2022.1.0-7rs6ht57zozyxhxx6s2qlrqzmqknhgzx/dnnl/2022.1.0/cpu_gomp/lib/:$LD_LIBRARY_PATH +if [ "$#" == 0 ] || [ "$1" == "cuda" ] +then + echo "Loading CUDA environment." + spack load cuda@11.0.2 cudnn@8.0.3.33-11.0 + export CUDAHOSTCXX=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-9.4.0/gcc-9.4.0-st36klijpsnquihiy463hmedsyhoc3g6/bin/gcc +elif [ "$1" == "intelcpu" ] +then + echo "Loading INTELCPU environment." + spack load intel-oneapi-dnn@2022.1.0 intel-oneapi-mkl@2022.1.0 intel-oneapi-compilers@2022.1.0 + # After "spack load", the default dnnl library is cpu_dpcpp_gpu_dpcpp, which requires libsycl.so; it needs to be switched to the cpu_gomp build explicitly. + export LD_LIBRARY_PATH=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-12.1.0/intel-oneapi-dnn-2022.1.0-7rs6ht57zozyxhxx6s2qlrqzmqknhgzx/dnnl/2022.1.0/cpu_gomp/lib/:$LD_LIBRARY_PATH + + + # Dynamically loading (dlopen) the MKL libs fails when they are used from Python. + # Referring to "https://groups.google.com/g/kaldi-help/c/m3nyQke0HS0/m/4fj8gkSWAgAJ", it is recommended to use mkl_rt instead, + # but mkl_rt does not support DPC++; see https://www.intel.com/content/www/us/en/docs/onemkl/developer-guide-linux/2023-0/using-the-single-dynamic-library.html + # Preloading the missing libs works; see https://community.intel.com/t5/Intel-oneAPI-Math-Kernel-Library/mkl-fails-to-load/m-p/1155538 + + export MKLLIB_PATH=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-12.1.0/intel-oneapi-mkl-2022.1.0-mf6te62fo6wxlo33jwwwgg5kljoagc6g/mkl/2022.1.0/ + export LD_PRELOAD=$MKLLIB_PATH/lib/intel64/libmkl_def.so.2:$MKLLIB_PATH/lib/intel64/libmkl_avx2.so.2:$MKLLIB_PATH/lib/intel64/libmkl_core.so:$MKLLIB_PATH/lib/intel64/libmkl_intel_lp64.so:$MKLLIB_PATH/lib/intel64/libmkl_intel_thread.so:/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-11.3.0/intel-oneapi-compilers-2022.1.0-qrq4a63scjip455bpxvl5ipgqbllwecj/compiler/2022.1.0/linux/compiler/lib/intel64_lin/libiomp5.so +else + echo "Bad option. Please enter 'cuda' or 'intelcpu'. CUDA is loaded by default if nothing is specified." +fi
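The Pad, Slice, and Split kernels added above all move data through the same oneDNN idiom: carve a strided view out of the full tensor with `submemory_desc` and let a `reorder` primitive copy between the view and a dense buffer. The following standalone sketch of that pattern is an editorial illustration only, not part of the patch; it assumes the oneDNN 2.x C++ API that the `dnnl::pooling_v2_forward` usage above implies, and the shapes, buffer contents, and `main` harness are made up for the example.

``` cpp
// Illustrative sketch: copy a 2x2 window out of a 1x1x2x4 tensor with
// submemory_desc + reorder, the same pattern MklPad/MklSlice/MklSplit use.
#include <dnnl.hpp>
#include <cstdio>
#include <vector>

int main() {
    using namespace dnnl;
    engine eng(engine::kind::cpu, 0);
    stream s(eng);

    std::vector<float> src(8), dst(4, -1.f); // src holds 0..7
    for (size_t i = 0; i < src.size(); ++i)
        src[i] = float(i);

    memory::desc srcMd({1, 1, 2, 4}, memory::data_type::f32,
                       memory::format_tag::nchw);
    // A {1,1,2,2} view starting at offset 2 in the last dimension.
    memory::desc viewMd = srcMd.submemory_desc({1, 1, 2, 2}, {0, 0, 0, 2});
    memory::desc dstMd({1, 1, 2, 2}, memory::data_type::f32,
                       memory::format_tag::nchw);

    memory srcView(viewMd, eng, src.data());
    memory dstMem(dstMd, eng, dst.data());

    // reorder copies the strided view into the dense destination buffer.
    reorder(srcView, dstMem)
        .execute(s, {{DNNL_ARG_FROM, srcView}, {DNNL_ARG_TO, dstMem}});
    s.wait();

    for (float v : dst)
        std::printf("%g ", v); // expected: 2 3 6 7
    std::printf("\n");
    return 0;
}
```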
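MklPow above skips oneDNN entirely and computes the element-wise power through SYCL unified shared memory: allocate device buffers, copy the operands in, run one work-item per element, and copy the result back. Below is a self-contained sketch of that flow using the same `cpu_selector` style as the kernel; it assumes a DPC++-compatible compiler (the Makefile selects `dpcpp` for the INTELCPU build), and the input values and `main` wrapper are illustrative.

``` cpp
// Illustrative sketch of the USM malloc_device / memcpy / parallel_for flow
// used by MklPow; computes out[i] = a[i] ^ b[i] on the CPU device.
#include <CL/sycl.hpp>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> a{1, 2, 3, 4}, b{2, 2, 2, 2}, out(4);
    const size_t n = a.size();

    sycl::queue q(sycl::cpu_selector{}); // same selector style as pow.cc
    float *dA = sycl::malloc_device<float>(n, q);
    float *dB = sycl::malloc_device<float>(n, q);
    float *dOut = sycl::malloc_device<float>(n, q);

    q.memcpy(dA, a.data(), n * sizeof(float)).wait();
    q.memcpy(dB, b.data(), n * sizeof(float)).wait();

    // One work-item per element.
    q.parallel_for(sycl::range<1>(n), [=](sycl::id<1> i) {
         dOut[i] = sycl::pow(dA[i], dB[i]);
     }).wait();

    q.memcpy(out.data(), dOut, n * sizeof(float)).wait();
    sycl::free(dA, q);
    sycl::free(dB, q);
    sycl::free(dOut, q);

    for (float v : out)
        std::printf("%g ", v); // expected: 1 4 9 16
    std::printf("\n");
    return 0;
}
```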
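MklReduce above has to hand oneDNN a destination descriptor with the same rank as the source, so reduced axes keep extent 1; when `keepDims` is false the kernel recomputes dense row-major strides for that padded shape itself. The helper below is a hypothetical, standalone restatement of that shape/stride derivation for reference; `dnnlDstShape` and its `reduced` parameter are not part of the patch's API.

``` cpp
// Illustrative helper: destination dims/strides for a oneDNN reduction.
#include <cstdint>
#include <cstdio>
#include <set>
#include <utility>
#include <vector>

using dims_t = std::vector<int64_t>;

// Reduced axes keep extent 1 (rank is preserved); strides are the dense
// row-major strides of that padded shape.
static std::pair<dims_t, dims_t> dnnlDstShape(const dims_t &inDims,
                                              const std::set<size_t> &reduced) {
    dims_t oDims = inDims, oStrides(inDims.size(), 1);
    for (size_t i = 0; i < inDims.size(); ++i)
        if (reduced.count(i))
            oDims[i] = 1;
    int64_t stride = 1;
    for (int i = static_cast<int>(oDims.size()) - 1; i >= 0; --i) {
        oStrides[i] = stride;
        stride *= oDims[i];
    }
    return {oDims, oStrides};
}

int main() {
    // Reducing axes {1, 2} of a 2x3x2x2 tensor, as in the ReduceMean test.
    auto [d, s] = dnnlDstShape({2, 3, 2, 2}, {1, 2});
    for (auto v : d)
        std::printf("%lld ", static_cast<long long>(v)); // 2 1 1 2
    std::printf("| ");
    for (auto v : s)
        std::printf("%lld ", static_cast<long long>(v)); // 2 2 2 1
    std::printf("\n");
    return 0;
}
```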
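The FlattenObj change in src/operators/reshape.cc above replaces the old flatten-to-1D behaviour with an ONNX-style `axis`: dimensions before `axis` fold into the first output extent and the rest into the second, with negative axes counted from the end. A minimal sketch of that shape rule follows; `flattenShape` is a hypothetical helper written for illustration, but the two examples in `main` match the expectations added to test/operators/test_reshape.cc.

``` cpp
// Illustrative restatement of the Flatten shape inference added above.
#include <cassert>
#include <cstdio>
#include <vector>

static std::vector<int> flattenShape(std::vector<int> dims, int axis) {
    const int ndim = static_cast<int>(dims.size());
    if (axis < 0)
        axis += ndim; // negative axes count from the end
    assert(axis >= 0 && axis < ndim);
    int sizeB = 1, sizeE = 1;
    for (int i = 0; i < ndim; ++i)
        ((i < axis) ? sizeB : sizeE) *= dims[i];
    return {sizeB, sizeE};
}

int main() {
    auto a = flattenShape({2, 3, 3, 4}, 1);  // -> {2, 36}
    auto b = flattenShape({2, 3, 3, 4}, -1); // -> {18, 4}
    std::printf("%d %d | %d %d\n", a[0], a[1], b[0], b[1]);
    return 0;
}
```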
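The softmax tests above expect identical rows for the inputs {0, 1, 2, 3} and {10000, 10001, 10002, 10003}: softmax is invariant to adding a constant to every element of a row, which is also why numerically stable implementations subtract the row maximum before exponentiating. The sketch below is only a reference computation that reproduces the expected values; it is not how the cuDNN or oneDNN primitives are implemented internally.

``` cpp
// Reference softmax with the max-subtraction trick; reproduces the values
// expected by the softmax tests above.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

static std::vector<double> softmax(std::vector<double> x) {
    const double m = *std::max_element(x.begin(), x.end());
    double sum = 0;
    for (double &v : x) {
        v = std::exp(v - m); // shifting by m does not change the result
        sum += v;
    }
    for (double &v : x)
        v /= sum;
    return x;
}

int main() {
    for (double v : softmax({0, 1, 2, 3}))
        std::printf("%.7f ", v); // ~0.0320586 0.0871443 0.2368828 0.6439143
    std::printf("\n");
    for (double v : softmax({10000, 10001, 10002, 10003}))
        std::printf("%.7f ", v); // identical to the row above
    std::printf("\n");
    return 0;
}
```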