From 1184fa131f67e4418adbb27a9c502f7b0901e3a8 Mon Sep 17 00:00:00 2001 From: Hardy <100662313+wanghailu0717@users.noreply.github.com> Date: Mon, 16 Oct 2023 10:57:08 +0800 Subject: [PATCH] Xpu (#82) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * support kunlun xpu and add an operator named Add * add sub, mul, div, pow, maximum, minimum * add code * add xpu code * add code * add matmul * add transpose * add unary operator * add unary operator * add some operator * add code * support run resnet18 on xpu * add code * add max pool2d * fix xpu code, let it can run. * 添加XPU算子 (#120) * add floordiv for xpu * add batchnorm for xpu * add more cast types for xpu * add conv_trans for xpu * add pad for xpu * add logical ops for xpu * fix format for xpu src and include * fix format for xpu test * fix format for xpu src --------- Co-authored-by: Bolun * Xpu abs (#121) * add: unary kernel for xpu * formatting * format * format * format * fix: pointer jump * fix optype comments * fix bug introduced while resolving conflict * change cmake option for kunlunxin xpu from 'xpu' to 'kunlun'; fix bug after merging distributed infrastructure * Add doc support for xpu (#141) * fix * fix * fix pooling test * format * format * fix * fix * set cmake version requirement * fix cmakelists * rename xpu to kunlun * fix * fix format * fix format * fix format * fix change name to kunlun * format * fix format * clang format * fix format --------- Co-authored-by: root Co-authored-by: wanghailu Co-authored-by: wanghailu Co-authored-by: Bolun Zhang <48948016+Chamberlain0w0@users.noreply.github.com> Co-authored-by: Bolun Co-authored-by: zhangyue207 <138768300+zhangyue207@users.noreply.github.com> Co-authored-by: Haojie Wang Co-authored-by: baominghelly <41820386+baominghelly@users.noreply.github.com> Co-authored-by: Bolun --- CMakeLists.txt | 52 +- Makefile | 2 + docs/INSTALL_GUIDE_CN.md | 7 + docs/USER_GUIDE_CN.md | 1 + env.sh | 2 +- include/core/op_type.h | 8 +- include/core/runtime.h | 3 +- include/core/tensor.h | 5 +- include/kunlun/kunlun_common.h | 20 + include/kunlun/kunlun_kernel_without_config.h | 24 + include/kunlun/kunlun_runtime.h | 73 +++ include/kunlun/operator_timer.h | 10 + pyinfinitensor/src/pyinfinitensor/onnx.py | 18 +- pyinfinitensor/tests/test_onnx.py | 16 +- scripts/format.py | 2 +- src/core/tensor.cc | 3 +- src/ffi/ffi_infinitensor.cc | 17 + src/kernels/cpu/element_wise.cc | 35 ++ src/kernels/cpu/unary.cc | 101 ++++ src/kernels/cuda/gather_elements.cu | 3 +- src/kernels/kunlun/batch_norm.cc | 41 ++ src/kernels/kunlun/cast.cc | 98 ++++ src/kernels/kunlun/concat.cc | 37 ++ src/kernels/kunlun/conv.cc | 37 ++ src/kernels/kunlun/conv_trans.cc | 54 ++ src/kernels/kunlun/element_wise.cc | 476 +++++++++++++++ src/kernels/kunlun/matmul.cc | 38 ++ src/kernels/kunlun/pad.cc | 37 ++ src/kernels/kunlun/pooling.cc | 62 ++ src/kernels/kunlun/split.cc | 41 ++ src/kernels/kunlun/transpose.cc | 32 + src/kernels/kunlun/unary.cc | 550 ++++++++++++++++++ src/kunlun/kunlun_runtime.cc | 60 ++ src/kunlun/operator_timer.cc | 71 +++ test/kernels/kunlun/test_kunlun_add.cc | 61 ++ test/kernels/kunlun/test_kunlun_batch_norm.cc | 61 ++ test/kernels/kunlun/test_kunlun_concat.cc | 54 ++ test/kernels/kunlun/test_kunlun_conv.cc | 56 ++ test/kernels/kunlun/test_kunlun_conv_trans.cc | 136 +++++ .../kunlun/test_kunlun_element_wise.cc | 66 +++ test/kernels/kunlun/test_kunlun_matmul.cc | 58 ++ test/kernels/kunlun/test_kunlun_pad.cc | 40 ++ test/kernels/kunlun/test_kunlun_pooling.cc | 51 ++ test/kernels/kunlun/test_kunlun_split.cc | 48 ++ test/kernels/kunlun/test_kunlun_transpose.cc | 43 ++ test/kernels/kunlun/test_kunlun_unary.cc | 190 ++++++ 46 files changed, 2874 insertions(+), 26 deletions(-) create mode 100644 include/kunlun/kunlun_common.h create mode 100644 include/kunlun/kunlun_kernel_without_config.h create mode 100644 include/kunlun/kunlun_runtime.h create mode 100644 include/kunlun/operator_timer.h create mode 100644 src/kernels/kunlun/batch_norm.cc create mode 100644 src/kernels/kunlun/cast.cc create mode 100644 src/kernels/kunlun/concat.cc create mode 100644 src/kernels/kunlun/conv.cc create mode 100644 src/kernels/kunlun/conv_trans.cc create mode 100644 src/kernels/kunlun/element_wise.cc create mode 100644 src/kernels/kunlun/matmul.cc create mode 100644 src/kernels/kunlun/pad.cc create mode 100644 src/kernels/kunlun/pooling.cc create mode 100644 src/kernels/kunlun/split.cc create mode 100644 src/kernels/kunlun/transpose.cc create mode 100644 src/kernels/kunlun/unary.cc create mode 100644 src/kunlun/kunlun_runtime.cc create mode 100644 src/kunlun/operator_timer.cc create mode 100644 test/kernels/kunlun/test_kunlun_add.cc create mode 100644 test/kernels/kunlun/test_kunlun_batch_norm.cc create mode 100644 test/kernels/kunlun/test_kunlun_concat.cc create mode 100644 test/kernels/kunlun/test_kunlun_conv.cc create mode 100644 test/kernels/kunlun/test_kunlun_conv_trans.cc create mode 100644 test/kernels/kunlun/test_kunlun_element_wise.cc create mode 100644 test/kernels/kunlun/test_kunlun_matmul.cc create mode 100644 test/kernels/kunlun/test_kunlun_pad.cc create mode 100644 test/kernels/kunlun/test_kunlun_pooling.cc create mode 100644 test/kernels/kunlun/test_kunlun_split.cc create mode 100644 test/kernels/kunlun/test_kunlun_transpose.cc create mode 100644 test/kernels/kunlun/test_kunlun_unary.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 291adf92..72df016c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,16 +1,23 @@ -cmake_minimum_required(VERSION 3.17) # FindCUDAToolkit -include(CMakeDependentOption) -project(InfiniTensor C CXX) - # Do not change these options in this file. Use cmake.config, cmake -DOPTION=VALUE, or ccmake to specify them. option(USE_CUDA "Support CUDA GPU" OFF) option(USE_BANG "Support BANG MLU" OFF) +option(USE_KUNLUN "Support KUNLUN XPU" OFF) option(USE_INTELCPU "Support INTELCPU" OFF) option(USE_BACKTRACE "Print backtrace on exception and segmentation fault" ON) option(USE_PROTOBUF "Serialize and deserialize tensors" OFF) option(BUILD_DIST "Build project for distributed running" OFF) option(BUILD_TEST "Build tests" OFF) +if(USE_CUDA) + message("CMake 3.18 or higher is required for setting CUDAToolkit") + cmake_minimum_required(VERSION 3.18) # FindCUDAToolkit +else() + cmake_minimum_required(VERSION 3.12) +endif() + +include(CMakeDependentOption) +project(InfiniTensor C CXX) + cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF) cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" OFF BUILD_TEST OFF) cmake_dependent_option(BUILD_TEST_EINNET "Build tests for EINNET" OFF BUILD_TEST OFF) @@ -128,6 +135,11 @@ if(USE_BANG) list (APPEND SRC ${SRC_BANG}) endif() +if(USE_KUNLUN) + file(GLOB_RECURSE SRC_KUNLUN src/kunlun/*.cc src/kernels/kunlun/*.cc ) + list (APPEND SRC ${SRC_KUNLUN}) +endif() + if(USE_INTELCPU) file(GLOB_RECURSE SRC_INTELCPU src/intelcpu/*.cc src/kernels/intelcpu/*.cc ) list (APPEND SRC ${SRC_INTELCPU}) @@ -243,6 +255,35 @@ if(USE_BANG) target_link_libraries(InfiniTensor ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++) endif() +if(USE_KUNLUN) + add_compile_definitions(USE_KUNLUN=1) + if ((NOT DEFINED KUNLUN_HOME) AND (NOT DEFINED ENV{KUNLUN_HOME})) + message(FATAL_ERROR "KUNLUN_HOME is not defined from cmake or env") + elseif (DEFINED KUNLUN_HOME) + set(KUNLUN_HOME ${KUNLUN_HOME} CACHE STRING "KUNLUN_HOME directory for Kunlun development") + else() + set(KUNLUN_HOME $ENV{KUNLUN_HOME} CACHE STRING "KUNLUN_HOME directory for Kunlun development") + endif() + message(STATUS "KUNLUN_HOME: ${KUNLUN_HOME}") + + include_directories("${KUNLUN_HOME}/XTDK/include/") + find_library(KUNLUN_RT libxpurt.so "${KUNLUN_HOME}/lib64") + find_library(KUNLUN_DNN libxpuapi.so "${KUNLUN_HOME}/XTDK/shlib") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror") + + if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH})) + execute_process(COMMAND uname -m OUTPUT_VARIABLE _uname_m OUTPUT_STRIP_TRAILING_WHITESPACE) + set(TARGET_CPU_ARCH "${_uname_m}" CACHE STRING "Target CPU ARCH") + elseif(DEFINED TARGET_CPU_ARCH) + set(TARGET_CPU_ARCH ${TARGET_CPU_ARCH} CACHE STRING "Target CPU ARCH") + else() + set(TARGET_CPU_ARCH $ENV{TARGET_CPU_ARCH} CACHE STRING "Target CPU ARCH") + endif() + message(STATUS "TARGET_CPU_ARCH: ${TARGET_CPU_ARCH}") + + target_link_libraries(InfiniTensor ${KUNLUN_RT} ${KUNLUN_DNN} stdc++) +endif() + # # Python bindings # pybind11_add_module(infini MODULE ${FFI}) # target_link_libraries(infini PRIVATE infini_cpp) @@ -275,6 +316,9 @@ if(BUILD_TEST) if (USE_BANG) build_test(test/kernels/bang/*.cc) endif() + if (USE_KUNLUN) + build_test(test/kernels/kunlun/*.cc) + endif() if (USE_INTELCPU) build_test(test/kernels/intelcpu/*.cc) endif() diff --git a/Makefile b/Makefile index 01784937..19f1b353 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,7 @@ TYPE ?= Release CUDA ?= OFF BANG ?= OFF +KUNLUN ?= OFF INTELCPU ?= off BACKTRACE ?= ON TEST ?= ON @@ -25,6 +26,7 @@ endif CMAKE_OPT = -DCMAKE_BUILD_TYPE=$(TYPE) CMAKE_OPT += -DUSE_CUDA=$(CUDA) CMAKE_OPT += -DUSE_BANG=$(BANG) +CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN) CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE) CMAKE_OPT += -DBUILD_TEST=$(TEST) diff --git a/docs/INSTALL_GUIDE_CN.md b/docs/INSTALL_GUIDE_CN.md index 285187ef..bac2534e 100644 --- a/docs/INSTALL_GUIDE_CN.md +++ b/docs/INSTALL_GUIDE_CN.md @@ -133,6 +133,13 @@ make install-python BANG=ON ``` + 编译 CPU 部分,同时编译昆仑 XPU 部分: + + ```bash + export KUNLUN_HOME=/path/to/your/kunlun_home + make install-python KUNLUN=ON + ``` + 3. 使用方法 安装成功后,您就可以使用本项目的 Python 接口进行编码并运行。具体使用方式可以参考项目样例代码 example/Resnet/resnet.py 以及用户使用手册 diff --git a/docs/USER_GUIDE_CN.md b/docs/USER_GUIDE_CN.md index b45e07da..e6a12c31 100644 --- a/docs/USER_GUIDE_CN.md +++ b/docs/USER_GUIDE_CN.md @@ -26,6 +26,7 @@ - `TYPE`:编译模式(`debug`/`release`),默认值为 `release` - `CUDA`:是否编译 CUDA 后端,默认为 `OFF`,`ON` 打开 - `BANG`:是否编译寒武纪后端,默认为 `OFF`,`ON` 打开 +- `KUNLUN`:是否编译昆仑后端,默认为 `OFF`,`ON` 打开 - `BACKTRACE`:是否启用栈回溯,默认为 `ON`,`OFF` 关闭,建议调试时打开 - `TEST`:是否编译 `googletest`,默认为 `ON`,`OFF` 关闭,只有 `test-cpp` 时必要 diff --git a/env.sh b/env.sh index 58e74a86..6971436f 100644 --- a/env.sh +++ b/env.sh @@ -35,4 +35,4 @@ export LD_LIBRARY_PATH="${NEUWARE_HOME}/lib64:${LD_LIBRARY_PATH}" # ├── tools # ├── version # └── XTDK -export XPU_HOME=/usr/local/xpu +export KUNLUN_HOME=/usr/local/xpu diff --git a/include/core/op_type.h b/include/core/op_type.h index 82439650..ad2e6acb 100644 --- a/include/core/op_type.h +++ b/include/core/op_type.h @@ -21,10 +21,10 @@ struct OpType { Add, // Binary And, // Binary ArgMax, // - Asin, // Binary - Asinh, // Binary - Atan, // Binary - Atanh, // Binary + Asin, // Unary + Asinh, // Unary + Atan, // Unary + Atanh, // Unary AveragePool, // Pool BatchNormalization, // Bernoulli, // diff --git a/include/core/runtime.h b/include/core/runtime.h index bd9da89a..5bc2123e 100644 --- a/include/core/runtime.h +++ b/include/core/runtime.h @@ -30,7 +30,7 @@ using OpLists = list; using VType = uint32_t; -enum class Device { CPU = 1, CUDA, BANG, INTELCPU }; +enum class Device { CPU = 1, CUDA, BANG, INTELCPU, KUNLUN }; /***************** Forward declaration end *****************/ class RuntimeObj : public std::enable_shared_from_this { @@ -72,6 +72,7 @@ class RuntimeObj : public std::enable_shared_from_this { } bool isCuda() const { return device == Device::CUDA; } bool isBang() const { return device == Device::BANG; } + bool isKUNLUN() const { return device == Device::KUNLUN; } void copyBlob(const TensorObj *dst, const TensorObj *src) const; // TODO: unify these copy APIs virtual void copyBlobFromCPU(void *dst, const void *src, diff --git a/include/core/tensor.h b/include/core/tensor.h index edaa8655..48590fd6 100644 --- a/include/core/tensor.h +++ b/include/core/tensor.h @@ -180,14 +180,15 @@ class TensorObj : public TensorBaseObj { } template - bool equalDataImpl(const T *a, const T *b, size_t size) const { + bool equalDataImpl(const T *a, const T *b, size_t size, + double relativeError = 1e-6) const { for (size_t i = 0; i < size; ++i) { if constexpr (std::is_integral_v) { if (a[i] != b[i]) return false; } else if constexpr (std::is_floating_point_v) { if (fabs(a[i] - b[i]) / std::max(fabs(a[i]), fabs(b[i])) > - 1e-6) { + relativeError) { printf("Error on %lu: %f %f\n", i, a[i], b[i]); return false; } diff --git a/include/kunlun/kunlun_common.h b/include/kunlun/kunlun_common.h new file mode 100644 index 00000000..2350cc93 --- /dev/null +++ b/include/kunlun/kunlun_common.h @@ -0,0 +1,20 @@ +#pragma once +#include "core/common.h" +#include "xpu/runtime_ex.h" +#include "xpu/xdnn.h" + +#define checkKUNLUNError(call) \ + { \ + auto err = call; \ + if (XPU_SUCCESS != err) { \ + fprintf(stderr, "KUNLUN error in %s:%i : %s.\n", __FILE__, \ + __LINE__, xpu_strerror(err)); \ + exit(EXIT_FAILURE); \ + } \ + } + +namespace infini { + +using KUNLUNPtr = void *; + +} // namespace infini diff --git a/include/kunlun/kunlun_kernel_without_config.h b/include/kunlun/kunlun_kernel_without_config.h new file mode 100644 index 00000000..6f53f471 --- /dev/null +++ b/include/kunlun/kunlun_kernel_without_config.h @@ -0,0 +1,24 @@ +#pragma once +#include "core/kernel.h" +#include "kunlun/kunlun_runtime.h" + +namespace infini { + +class KUNLUNKernelWithoutConfig : public Kernel { + public: + virtual void compute(const Operator &op, const PerfRecord &record, + const RuntimeObj *context) const { + compute(op, context); + } + virtual void compute(const Operator &op, + const RuntimeObj *context) const = 0; + // Premise: op is idempotent since it is called multiple times. + virtual PerfRecord tune(const Operator &op, + const RuntimeObj *_context) const { + auto context = dynamic_cast(_context); + return make_ref(timeit([&]() { compute(op, _context); }, + [&]() { context->sync(); })); + } +}; + +} // namespace infini diff --git a/include/kunlun/kunlun_runtime.h b/include/kunlun/kunlun_runtime.h new file mode 100644 index 00000000..6a5be4c9 --- /dev/null +++ b/include/kunlun/kunlun_runtime.h @@ -0,0 +1,73 @@ +#pragma once +#include "core/runtime.h" +#include "kunlun/kunlun_common.h" + +namespace infini { + +class KUNLUNRuntimeObj : public RuntimeObj { + private: + baidu::xpu::api::Context *xdnn; + KUNLUNPtr workspace; + size_t workspaceSize; + + public: + KUNLUNRuntimeObj() : RuntimeObj(Device::KUNLUN) { + xdnn = baidu::xpu::api::create_context(); + // 10GB for Longformer + // size_t longformerNum = 3lu * (1 << 30); + workspaceSize = 3ll << 30; // 3 GB + // std::cout<(workspaceSize)<< std::endl; + workspace = alloc(workspaceSize); + } + virtual ~KUNLUNRuntimeObj() { + dealloc(workspace); + baidu::xpu::api::destroy_context(xdnn); + } + string toString() const override; + + void run(const Graph &graph, bool tune = false, + bool profiling = false) const; + // double runEvaluation(const Graph &graph, int nWarmups, + // int nEvaluations) const; + void sync() const; + KUNLUNPtr alloc(size_t size) override { + void *ptr; + checkKUNLUNError( + xpu_malloc_ex((void **)&ptr, size, XPUMemoryKind::XPU_MEM_MAIN)); + return ptr; + } + void dealloc(void *ptr) override { xpu_free(ptr); } + baidu::xpu::api::Context *KUNLUNHandle() const { return xdnn; } + KUNLUNPtr getWorkspace(size_t size) const { + IT_ASSERT(size <= workspaceSize); + return workspace; + } + + void copyBlobFromCPU(void *dst, const void *src, + size_t bytes) const override { + xpu_memcpy(dst, const_cast(src), bytes, + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + } + + void copyBlobToCPU(void *dst, const void *src, + size_t bytes) const override { + xpu_memcpy(dst, const_cast(src), bytes, + XPUMemcpyKind::XPU_DEVICE_TO_HOST); + } + + void copyBlobInsideRuntime(void *dst, const void *src, + size_t bytes) const override { + xpu_memcpy(dst, const_cast(src), bytes, + XPUMemcpyKind::XPU_DEVICE_TO_DEVICE); + } + + void initComm(const string &, int, int) override { IT_TODO_HALT(); } + + CommunicatorObj &getCommunicator() const override { IT_TODO_HALT(); } + + private: + void runWithoutSync(const Graph &graph, bool tune, bool profiling) const; +}; + +} // namespace infini diff --git a/include/kunlun/operator_timer.h b/include/kunlun/operator_timer.h new file mode 100644 index 00000000..15ebce94 --- /dev/null +++ b/include/kunlun/operator_timer.h @@ -0,0 +1,10 @@ +#pragma once +namespace infini { +namespace opTimer { +double getPerfConvXdnn(int n, int c, int h, int w, int f, int r, int s, + int padh, int padw, int strideh, int stridew, + int dilationh, int dilationw, int group, + const char *name); +double getPerfMatmulXdnn(int b, int m, int n, int k, const char *name); +} // namespace opTimer +} // namespace infini diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index e4336dc4..d11fbb90 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -35,6 +35,7 @@ class OnnxStub: The Onnx model imported into infinitensor. It can be generated from an Onnx model object. """ + def __init__(self, model: ModelProto, runtime): # We use some user-defined operators for distributed inference try: @@ -74,7 +75,6 @@ class OnnxStub: ) tensors[output.name].set_output() - node_name = [] new_node_name = [] for node in model.graph.node: @@ -244,7 +244,13 @@ class OnnxStub: ) (k, d, p, s, ceil_mode) = ( attributes[name] - for name in ["kernel_shape", "dilations", "pads", "strides", "ceil_mode"] + for name in [ + "kernel_shape", + "dilations", + "pads", + "strides", + "ceil_mode", + ] ) if p[0] != p[2] or p[1] != p[3]: adapt = "{}-adapt".format(node.output[0]) @@ -289,7 +295,8 @@ class OnnxStub: }, ) (k, p, s, ceil_mode) = ( - attributes[name] for name in ["kernel_shape", "pads", "strides", "ceil_mode"] + attributes[name] + for name in ["kernel_shape", "pads", "strides", "ceil_mode"] ) if p[0] != p[2] or p[1] != p[3]: adapt = "{}-adapt".format(node.output[0]) @@ -714,10 +721,9 @@ class OnnxStub: elif node.op_type == "Constant": output_name = node.output[0] attributes = _parse_attribute(node) - tensor = attributes['value'] + tensor = attributes["value"] dims = [d for d in tensor.dims] - tensors[output_name] = self.handler.tensor( - dims, tensor.data_type) + tensors[output_name] = self.handler.tensor(dims, tensor.data_type) data[output_name] = tensor tensors[output_name].set_weight() else: diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index 3808f516..035baf34 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -208,7 +208,7 @@ class TestStringMethods(unittest.TestCase): relu = make_node("Relu", ["x"], ["y"], name="relu") make_and_import_model(make_graph([relu], "relu", [x], [y])) - '''Gelu operator is not supported by onnx 14.1 currently.''' + """Gelu operator is not supported by onnx 14.1 currently.""" def test_gelu(self): pass # x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) @@ -239,7 +239,7 @@ class TestStringMethods(unittest.TestCase): y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7]) tanh = make_node("Tanh", ["x"], ["y"], name="tanh") make_and_import_model(make_graph([tanh], "tanh", [x], [y])) - + def test_hard_sigmoid(self): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7]) @@ -263,7 +263,7 @@ class TestStringMethods(unittest.TestCase): y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7]) abs = make_node("Abs", ["x"], ["y"], name="abs") make_and_import_model(make_graph([abs], "abs", [x], [y])) - + def test_neg(self): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7]) @@ -319,9 +319,15 @@ class TestStringMethods(unittest.TestCase): indices = make_tensor_value_info("indices", TensorProto.INT64, [2, 1, 2]) output = make_tensor_value_info("output", TensorProto.FLOAT, [2, 1, 2]) gatherElements = make_node( - "GatherElements", ["data", "indices"], ["output"], axis=1, name="gatherElements" + "GatherElements", + ["data", "indices"], + ["output"], + axis=1, + name="gatherElements", + ) + make_and_import_model( + make_graph([gatherElements], "gatherElements", [data, indices], [output]) ) - make_and_import_model(make_graph([gatherElements], "gatherElements", [data, indices], [output])) def test_reduce_mean(self): data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 3, 4]) diff --git a/scripts/format.py b/scripts/format.py index 7d74d54b..e1c5665d 100644 --- a/scripts/format.py +++ b/scripts/format.py @@ -11,7 +11,7 @@ proj_path = Path(sys.path[0]).parent def format_file(file): file = Path(proj_path.joinpath(file)) if file.suffix in c_style_file: - run(f"clang-format-14 -i {file}", cwd=proj_path, shell=True) + run(f"clang-format-14 -style=file -i {file}", cwd=proj_path, shell=True) run(f"git add {file}", cwd=proj_path, shell=True) elif file.suffix == py_file: run(f"black {file}", cwd=proj_path, shell=True) diff --git a/src/core/tensor.cc b/src/core/tensor.cc index d318f014..e34fb8bc 100644 --- a/src/core/tensor.cc +++ b/src/core/tensor.cc @@ -100,7 +100,8 @@ bool TensorObj::equalData(const Tensor &rhs, double relativeError) const { #define TEST_EQUAL(N) \ if (dtype == DataType(N)) \ return equalDataImpl(getRawDataPtr::t *>(), \ - rhs->getRawDataPtr::t *>(), size()); + rhs->getRawDataPtr::t *>(), size(), \ + relativeError); TEST_EQUAL(0) // fmt: new line else TEST_EQUAL(1) // diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 9881f92a..e1a726c3 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -24,6 +24,9 @@ #ifdef USE_BANG #include "bang/bang_runtime.h" #endif +#ifdef USE_KUNLUN +#include "kunlun/kunlun_runtime.h" +#endif #ifdef USE_INTELCPU #include "intelcpu/mkl_runtime.h" #include "intelcpu/operator_timer.h" @@ -158,6 +161,12 @@ static int tensor_dtype(Tensor t) { static Ref bang_runtime() { return make_ref(); } #endif +#ifdef USE_KUNLUN +static Ref kunlun_runtime() { + return make_ref(); +} +#endif + #ifdef USE_INTELCPU static Ref intelcpu_runtime() { return make_ref(); } #endif @@ -292,6 +301,10 @@ void export_functions(py::module &m) { #ifdef USE_BANG .FUNCTION(bang_runtime) #endif + +#ifdef USE_KUNLUN + .FUNCTION(kunlun_runtime) +#endif .FUNCTION(conv_attrs_of) .FUNCTION(conv_trans_attrs_of) .FUNCTION(matmul_attrs_of) @@ -365,6 +378,10 @@ void init_graph_builder(py::module &m) { #ifdef USE_BANG py::class_, RuntimeObj>( m, "BangRuntime"); +#endif +#ifdef USE_KUNLUN + py::class_, RuntimeObj>( + m, "KUNLUNRuntime"); #endif py::class_>(m, "Tensor", py::buffer_protocol()) diff --git a/src/kernels/cpu/element_wise.cc b/src/kernels/cpu/element_wise.cc index 8657a1fe..8d225779 100644 --- a/src/kernels/cpu/element_wise.cc +++ b/src/kernels/cpu/element_wise.cc @@ -58,6 +58,21 @@ template class NaiveMul : public NativeElementWise { template class NaiveDiv : public NativeElementWise { T doCompute(T val0, T val1) const override { return (T)(val0 / val1); } }; +template class NaiveEqual : public NativeElementWise { + T doCompute(T val0, T val1) const override { return (T)(val0 == val1); } +}; +template class NaiveGreaterEqual : public NativeElementWise { + T doCompute(T val0, T val1) const override { return (T)(val0 >= val1); } +}; +template class NaiveGreaterThan : public NativeElementWise { + T doCompute(T val0, T val1) const override { return (T)(val0 > val1); } +}; +template class NaiveLessEqual : public NativeElementWise { + T doCompute(T val0, T val1) const override { return (T)(val0 <= val1); } +}; +template class NaiveLessThan : public NativeElementWise { + T doCompute(T val0, T val1) const override { return (T)(val0 < val1); } +}; REGISTER_KERNEL(Device::CPU, OpType::Add, DataType::UInt32, NaiveAdd, "addNaive_CPU_uint32"); @@ -75,4 +90,24 @@ REGISTER_KERNEL(Device::CPU, OpType::Div, DataType::UInt32, NaiveDiv, "divNaive_CPU_uint32"); REGISTER_KERNEL(Device::CPU, OpType::Div, DataType::Float32, NaiveDiv, "divNaive_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Equal, DataType::UInt32, + NaiveEqual, "equalNaive_CPU_uint32"); +REGISTER_KERNEL(Device::CPU, OpType::Equal, DataType::Float32, + NaiveEqual, "equalNaive_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::GreaterOrEqual, DataType::UInt32, + NaiveGreaterEqual, "greaterEqualNaive_CPU_uint32"); +REGISTER_KERNEL(Device::CPU, OpType::GreaterOrEqual, DataType::Float32, + NaiveGreaterEqual, "greaterEqualNaive_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Greater, DataType::UInt32, + NaiveGreaterThan, "greaterThanNaive_CPU_uint32"); +REGISTER_KERNEL(Device::CPU, OpType::Greater, DataType::Float32, + NaiveGreaterThan, "greaterThanNaive_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::LessOrEqual, DataType::UInt32, + NaiveLessEqual, "lessEqualNaive_CPU_uint32"); +REGISTER_KERNEL(Device::CPU, OpType::LessOrEqual, DataType::Float32, + NaiveLessEqual, "lessEqualNaive_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Less, DataType::UInt32, + NaiveLessThan, "lessEqualNaive_CPU_uint32"); +REGISTER_KERNEL(Device::CPU, OpType::Less, DataType::Float32, + NaiveLessThan, "lessEqualNaive_CPU_float32"); }; // namespace infini diff --git a/src/kernels/cpu/unary.cc b/src/kernels/cpu/unary.cc index ed2f30c7..8975d7cd 100644 --- a/src/kernels/cpu/unary.cc +++ b/src/kernels/cpu/unary.cc @@ -71,6 +71,26 @@ template class NaiveSqrt : public NativeUnary { T doCompute(T val) const override { return std::sqrt(val); } }; +template class NaiveCos : public NativeUnary { + T doCompute(T val) const override { return std::cos(val); } +}; + +template class NaiveSin : public NativeUnary { + T doCompute(T val) const override { return std::sin(val); } +}; + +template class NaiveTan : public NativeUnary { + T doCompute(T val) const override { return std::tan(val); } +}; + +template class NaiveSinh : public NativeUnary { + T doCompute(T val) const override { return std::sinh(val); } +}; + +template class NaiveCosh : public NativeUnary { + T doCompute(T val) const override { return std::cosh(val); } +}; + template class NaiveGelu : public NativeUnary { T doCompute(T val) const override { return 0.5 * val * (1 + std::erf(val / std::sqrt(2))); @@ -81,6 +101,26 @@ template class NaiveErf : public NativeUnary { T doCompute(T val) const override { return std::erf(val); } }; +template class NaiveACos : public NativeUnary { + T doCompute(T val) const override { return std::acos(val); } +}; + +template class NaiveACosh : public NativeUnary { + T doCompute(T val) const override { return std::acosh(val); } +}; + +template class NaiveASin : public NativeUnary { + T doCompute(T val) const override { return std::asin(val); } +}; + +template class NaiveASinh : public NativeUnary { + T doCompute(T val) const override { return std::asinh(val); } +}; + +template class NaiveATanh : public NativeUnary { + T doCompute(T val) const override { return std::atanh(val); } +}; + template class NaiveNeg : public NativeUnary { T doCompute(T val) const override { return -val; } }; @@ -104,6 +144,43 @@ template class Clip : public CpuKernelWithoutConfig { } }; +template class Log : public CpuKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *context) const override { + auto op = as(_op); + T *inptr = op->getInputs(0)->getRawDataPtr(); + T *outptr = op->getOutput()->getRawDataPtr(); + auto logType = op->getType(); // get log type + + auto len = op->getOutput()->size(); + for (size_t offset = 0; offset < len; offset++) { + T res; + auto val = *inptr++; + switch (logType) { + case LogObj::LogE: + res = std::log(val); + *outptr++ = res; + break; + case LogObj::Log2: + res = std::log2(val); + *outptr++ = res; + break; + case LogObj::Log10: + res = std::log10(val); + *outptr++ = res; + break; + default: + printf("LogType not Defined"); + break; + } + } + } +}; + +template class NaiveATan : public NativeUnary { + T doCompute(T val) const override { return std::atan(val); } +}; + REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::UInt32, NaiveRelu, "reluNaive_CPU_uint32"); REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::Float32, NaiveRelu, @@ -140,4 +217,28 @@ REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::Float32, NaiveSoftmax, "softmaxNaive_CPU_float32"); REGISTER_KERNEL(Device::CPU, OpType::Clip, DataType::Float32, Clip, "Clip_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Atan, DataType::Float32, NaiveATan, + "Atan_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Log, DataType::Float32, Log, + "Log_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Cos, DataType::Float32, NaiveCos, + "Cos_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Sin, DataType::Float32, NaiveSin, + "Sin_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Tan, DataType::Float32, NaiveTan, + "Tan_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Sinh, DataType::Float32, NaiveSinh, + "Sinh_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Cosh, DataType::Float32, NaiveCosh, + "Cosh_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Acos, DataType::Float32, NaiveACos, + "ACos_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Acosh, DataType::Float32, + NaiveACosh, "ACosh_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Asin, DataType::Float32, NaiveASin, + "ASin_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Asinh, DataType::Float32, + NaiveASinh, "ASinh_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Atanh, DataType::Float32, + NaiveATanh, "ATanh_CPU_float32"); }; // namespace infini diff --git a/src/kernels/cuda/gather_elements.cu b/src/kernels/cuda/gather_elements.cu index 675a6b15..0b7817eb 100644 --- a/src/kernels/cuda/gather_elements.cu +++ b/src/kernels/cuda/gather_elements.cu @@ -59,7 +59,8 @@ void gather_elements_kernel(void *in, void *out, GatherMetaData metaData, reinterpret_cast(in), reinterpret_cast(out), metaData, num); } else { - IT_TODO_HALT_MSG("GatherElements Cuda Kernel: Unsupported data type.\n"); + IT_TODO_HALT_MSG( + "GatherElements Cuda Kernel: Unsupported data type.\n"); } } } // namespace infini diff --git a/src/kernels/kunlun/batch_norm.cc b/src/kernels/kunlun/batch_norm.cc new file mode 100644 index 00000000..d1c8c3b4 --- /dev/null +++ b/src/kernels/kunlun/batch_norm.cc @@ -0,0 +1,41 @@ +#include "operators/batch_norm.h" +#include "kunlun/kunlun_kernel_without_config.h" +#include "kunlun/kunlun_runtime.h" + +namespace infini { +class BatchNormXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const input = (op->getInputs(0)->getRawDataPtr()); + void *const mean = (op->getInputs(1)->getRawDataPtr()); + void *const var = (op->getInputs(2)->getRawDataPtr()); + void *const scale = (op->getInputs(3)->getRawDataPtr()); + void *const bias = (op->getInputs(4)->getRawDataPtr()); + void *const output = (op->getOutput()->getRawDataPtr()); + + auto dims = op->getInputs(0)->getDims(); + + if (dims.size() != 4) + IT_TODO_HALT(); + + int w = dims[3]; + int h = dims[2]; + int c = dims[1]; + int n = dims[0]; + auto ret = baidu::xpu::api::batch_norm_infer( + context->KUNLUNHandle(), (float *)input, (float *)output, n, c, h, + w, op->getEps(), (float *)scale, (float *)bias, (float *)mean, + (float *)var, true); + + assert(ret == 0); + return; + } +}; + +REGISTER_KERNEL(Device::KUNLUN, OpType::BatchNormalization, DataType::Float32, + BatchNormXdnn, "BatchNorm_xdnn_KUNLUN_Float32"); + +}; // namespace infini diff --git a/src/kernels/kunlun/cast.cc b/src/kernels/kunlun/cast.cc new file mode 100644 index 00000000..443cc259 --- /dev/null +++ b/src/kernels/kunlun/cast.cc @@ -0,0 +1,98 @@ +#include "kunlun/kunlun_kernel_without_config.h" +#include "kunlun/kunlun_runtime.h" +#include "operators/unary.h" + +namespace infini { +class CastXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + CastType type = op->getType(); + + int ret = 0; + switch (type) { + case CastType::Float2Float16: + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (float *)aData, (float16 *)cData, len); + break; + case CastType::Float2Int64: + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (float *)aData, (int64_t *)cData, len); + break; + case CastType::Float2Int32: + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (float *)aData, (int *)cData, len); + break; + case CastType::Float2Int16: + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (float *)aData, (int16_t *)cData, len); + break; + case CastType::Float2Int8: + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (float *)aData, (int8_t *)cData, len); + break; + case CastType::Int322Float: + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (int *)aData, (float *)cData, len); + break; + case CastType::Int322Int8: + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (int *)aData, (int8_t *)cData, len); + break; + case CastType::Int322Int16: + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (int *)aData, (int16_t *)cData, len); + break; + case CastType::Int162Float: + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (int16_t *)aData, (float *)cData, len); + break; + case CastType::Int162Int32: + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (int16_t *)aData, (int *)cData, len); + break; + case CastType::Int82Float: + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (int8_t *)aData, (float *)cData, len); + break; + case CastType::Int82Int16: + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (int8_t *)aData, (int16_t *)cData, + len); + break; + case CastType::Int82Int32: + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (int8_t *)aData, (int *)cData, len); + break; + case CastType::Int322Int64: + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (int *)aData, (int64_t *)cData, len); + break; + case CastType::Int642Int32: + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (int64_t *)aData, (int *)cData, len); + break; + case CastType::Int642Float: + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (int64_t *)aData, (float *)cData, len); + break; + case CastType::Float162Float: + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (float16 *)aData, (float *)cData, len); + break; + default: + IT_TODO_HALT(); + } + assert(ret == 0); + return; + } +}; + +REGISTER_KERNEL(Device::KUNLUN, OpType::Cast, DataType::Float32, CastXdnn, + "Cast_xdnn_KUNLUN_Float32"); +}; // namespace infini diff --git a/src/kernels/kunlun/concat.cc b/src/kernels/kunlun/concat.cc new file mode 100644 index 00000000..35777cae --- /dev/null +++ b/src/kernels/kunlun/concat.cc @@ -0,0 +1,37 @@ +#include "operators/concat.h" +#include "kunlun/kunlun_kernel_without_config.h" +#include "kunlun/kunlun_runtime.h" + +namespace infini { +class ConcatXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + int axis = op->getDim(); + int num = op->numInputs(); + std::vector inputsData; + for (int i = 0; i < num; ++i) { + inputsData.push_back( + (float *)(op->getInputs(i)->getRawDataPtr())); + } + void *const cData = (op->getOutput()->getRawDataPtr()); + + std::vector> dims; + for (int i = 0; i < num; ++i) { + auto dim = op->getInputs(i)->getDims(); + if (dim.size() != 4) { + IT_TODO_HALT(); + } + dims.push_back(dim); + } + auto ret = baidu::xpu::api::concat( + context->KUNLUNHandle(), inputsData, (float *)cData, dims, axis); + assert(ret == 0); + return; + } +}; + +REGISTER_KERNEL(Device::KUNLUN, OpType::Concat, DataType::Float32, ConcatXdnn, + "Concat_xdnn_KUNLUN_Float32"); +}; // namespace infini diff --git a/src/kernels/kunlun/conv.cc b/src/kernels/kunlun/conv.cc new file mode 100644 index 00000000..80cc37c7 --- /dev/null +++ b/src/kernels/kunlun/conv.cc @@ -0,0 +1,37 @@ +#include "operators/conv.h" +#include "kunlun/kunlun_kernel_without_config.h" +#include "kunlun/kunlun_runtime.h" + +namespace infini { +class ConvXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); + const auto [n, c, h, w, f, r, s] = op->getNCHWFRS(); + const int cpg = op->getChannelPerGroup(); + const int g = c / cpg; + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + std::vector pads = {ph, pw}; + std::vector ksize = {r, s}; + std::vector stride = {sh, sw}; + std::vector dilation = {dh, dw}; + + auto ret = baidu::xpu::api::conv2d( + context->KUNLUNHandle(), (float *)aData, (float *)bData, + (float *)cData, n, c, h, w, f, ksize, stride, pads, dilation, g, + nullptr, nullptr, nullptr, true); + assert(ret == 0); + return; + } +}; + +REGISTER_KERNEL(Device::KUNLUN, OpType::Conv, DataType::Float32, ConvXdnn, + "Conv_xdnn_KUNLUN_Float32"); +}; // namespace infini diff --git a/src/kernels/kunlun/conv_trans.cc b/src/kernels/kunlun/conv_trans.cc new file mode 100644 index 00000000..841955a6 --- /dev/null +++ b/src/kernels/kunlun/conv_trans.cc @@ -0,0 +1,54 @@ +#include "kunlun/kunlun_kernel_without_config.h" +#include "kunlun/kunlun_runtime.h" +#include "operators/conv.h" + +namespace infini { +class ConvTransXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); + const auto [n, c, h, w, f, r, s] = op->getNCHWFRS(); + const int cpg = op->getChannelPerGroup(); + const int g = c / cpg; + const bool isNCHW = + (op->getOpType() == OpType::ConvTransNHWC) ? false : true; + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + std::vector pads = {ph, pw}; + std::vector ksize = {r, s}; + std::vector stride = {sh, sw}; + std::vector dilation = {dh, dw}; + + auto dimInputs0 = op->getInputs(0)->getDims(); + auto dimInputs1 = op->getInputs(1)->getDims(); + auto dimOutput = op->getOutput()->getDims(); + + if (dimInputs0.size() != 4) + IT_TODO_HALT(); + if (dimInputs1.size() != 4) + IT_TODO_HALT(); + if (dimOutput.size() != 4) + IT_TODO_HALT(); + + auto ret = + baidu::xpu::api::conv2d_transpose( + context->KUNLUNHandle(), (float *)aData, (float *)bData, + (float *)cData, n, c, h, w, f, ksize, stride, pads, dilation, g, + nullptr, nullptr, nullptr, isNCHW); + assert(ret == 0); + return; + } +}; + +REGISTER_KERNEL(Device::KUNLUN, OpType::ConvTranspose, DataType::Float32, + ConvTransXdnn, "ConvTrans_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::ConvTransNHWC, DataType::Float32, + ConvTransXdnn, "ConvTranposedNHWC_xdnn_KUNLUN_Float32"); + +}; // namespace infini diff --git a/src/kernels/kunlun/element_wise.cc b/src/kernels/kunlun/element_wise.cc new file mode 100644 index 00000000..03ce74b1 --- /dev/null +++ b/src/kernels/kunlun/element_wise.cc @@ -0,0 +1,476 @@ +#include "operators/element_wise.h" +#include "kunlun/kunlun_kernel_without_config.h" +#include "kunlun/kunlun_runtime.h" + +namespace infini { +class AddXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto aDim = op->getInputs(0)->getDims(); + auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() != 4 || bDim.size() != 4) + IT_TODO_HALT(); + auto ret = baidu::xpu::api::broadcast_add( + context->KUNLUNHandle(), (float *)aData, (float *)bData, + (float *)cData, aDim, bDim); + assert(ret == 0); + return; + } +}; + +class SubXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto aDim = op->getInputs(0)->getDims(); + auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() != 4 || bDim.size() != 4) + IT_TODO_HALT(); + auto ret = baidu::xpu::api::broadcast_sub( + context->KUNLUNHandle(), (float *)aData, (float *)bData, + (float *)cData, aDim, bDim); + assert(ret == 0); + return; + } +}; + +class MulXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto aDim = op->getInputs(0)->getDims(); + auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() != 4 || bDim.size() != 4) + IT_TODO_HALT(); + auto ret = baidu::xpu::api::broadcast_mul( + context->KUNLUNHandle(), (float *)aData, (float *)bData, + (float *)cData, aDim, bDim); + assert(ret == 0); + return; + } +}; + +class DivXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto aDim = op->getInputs(0)->getDims(); + auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() != 4 || bDim.size() != 4) + IT_TODO_HALT(); + auto ret = baidu::xpu::api::broadcast_div( + context->KUNLUNHandle(), (float *)aData, (float *)bData, + (float *)cData, aDim, bDim); + assert(ret == 0); + return; + } +}; + +class PowXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto aDim = op->getInputs(0)->getDims(); + auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() != 4 || bDim.size() != 4) + IT_TODO_HALT(); + auto ret = baidu::xpu::api::broadcast_pow( + context->KUNLUNHandle(), (float *)aData, (float *)bData, + (float *)cData, aDim, bDim); + assert(ret == 0); + return; + } +}; + +class MaxXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto aDim = op->getInputs(0)->getDims(); + auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() != 4 || bDim.size() != 4) + IT_TODO_HALT(); + auto ret = baidu::xpu::api::broadcast_max( + context->KUNLUNHandle(), (float *)aData, (float *)bData, + (float *)cData, aDim, bDim); + assert(ret == 0); + return; + } +}; + +class MinXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto aDim = op->getInputs(0)->getDims(); + auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() != 4 || bDim.size() != 4) + IT_TODO_HALT(); + auto ret = baidu::xpu::api::broadcast_min( + context->KUNLUNHandle(), (float *)aData, (float *)bData, + (float *)cData, aDim, bDim); + assert(ret == 0); + return; + } +}; + +class EqualXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + size_t len = op->getOutput()->size(); + KUNLUNPtr wsData = context->getWorkspace(len); + + auto aDim = op->getInputs(0)->getDims(); + auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() != 4 || bDim.size() != 4) + IT_TODO_HALT(); + auto ret = baidu::xpu::api::broadcast_equal( + context->KUNLUNHandle(), (float *)aData, (float *)bData, + (bool *)wsData, aDim, bDim); + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class GreaterEqualXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + size_t len = op->getOutput()->size(); + KUNLUNPtr wsData = context->getWorkspace(len); + + auto aDim = op->getInputs(0)->getDims(); + auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() != 4 || bDim.size() != 4) + IT_TODO_HALT(); + auto ret = baidu::xpu::api::broadcast_greater_equal( + context->KUNLUNHandle(), (float *)aData, (float *)bData, + (bool *)wsData, aDim, bDim); + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class GreaterThanXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + size_t len = op->getOutput()->size(); + KUNLUNPtr wsData = context->getWorkspace(len); + + auto aDim = op->getInputs(0)->getDims(); + auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() != 4 || bDim.size() != 4) + IT_TODO_HALT(); + auto ret = baidu::xpu::api::broadcast_greater_than( + context->KUNLUNHandle(), (float *)aData, (float *)bData, + (bool *)wsData, aDim, bDim); + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class LessEqualXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + size_t len = op->getOutput()->size(); + KUNLUNPtr wsData = context->getWorkspace(len); + + auto aDim = op->getInputs(0)->getDims(); + auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() != 4 || bDim.size() != 4) + IT_TODO_HALT(); + auto ret = baidu::xpu::api::broadcast_less_equal( + context->KUNLUNHandle(), (float *)aData, (float *)bData, + (bool *)wsData, aDim, bDim); + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class LessThanXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + size_t len = op->getOutput()->size(); + KUNLUNPtr wsData = context->getWorkspace(len); + + auto aDim = op->getInputs(0)->getDims(); + auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() != 4 || bDim.size() != 4) + IT_TODO_HALT(); + auto ret = baidu::xpu::api::broadcast_less_than( + context->KUNLUNHandle(), (float *)aData, (float *)bData, + (bool *)wsData, aDim, bDim); + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class FloorDivXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + size_t len = op->getOutput()->size(); + KUNLUNPtr wsData = context->getWorkspace(len); + + auto aDim = op->getInputs(0)->getDims(); + auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() != 4 || bDim.size() != 4) + IT_TODO_HALT(); + auto ret = baidu::xpu::api::broadcast_floordiv( + context->KUNLUNHandle(), (float *)aData, (float *)bData, + (float *)wsData, aDim, bDim); + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (int *)wsData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class MSELossXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + size_t len = op->getOutput()->size(); + + auto dim = op->getInputs(0)->getDims(); + if (dim.size() != 4) + IT_TODO_HALT(); + + auto ret = baidu::xpu::api::mse_loss( + context->KUNLUNHandle(), (float *)aData, (float *)bData, + (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class AndXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + size_t len = op->getOutput()->size(); + KUNLUNPtr wsData = context->getWorkspace(len); + + auto aDim = op->getInputs(0)->getDims(); + auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() != 4 || bDim.size() != 4) + IT_TODO_HALT(); + auto ret = baidu::xpu::api::logical_and( + context->KUNLUNHandle(), (bool *)aData, (bool *)bData, + (bool *)wsData, len); + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class OrXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + size_t len = op->getOutput()->size(); + KUNLUNPtr wsData = context->getWorkspace(len); + + auto aDim = op->getInputs(0)->getDims(); + auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() != 4 || bDim.size() != 4) + IT_TODO_HALT(); + auto ret = baidu::xpu::api::logical_or( + context->KUNLUNHandle(), (bool *)aData, (bool *)bData, + (bool *)wsData, len); + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class XorXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + size_t len = op->getOutput()->size(); + KUNLUNPtr wsData = context->getWorkspace(len); + + auto aDim = op->getInputs(0)->getDims(); + auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() != 4 || bDim.size() != 4) + IT_TODO_HALT(); + auto ret = baidu::xpu::api::logical_xor( + context->KUNLUNHandle(), (bool *)aData, (bool *)bData, + (bool *)wsData, len); + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class NotXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + size_t len = op->getOutput()->size(); + KUNLUNPtr wsData = context->getWorkspace(len); + + auto aDim = op->getInputs(0)->getDims(); + if (aDim.size() != 4) + IT_TODO_HALT(); + auto ret = baidu::xpu::api::logical_not( + context->KUNLUNHandle(), (bool *)aData, (bool *)wsData, len); + ret = baidu::xpu::api::cast( + context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +REGISTER_KERNEL(Device::KUNLUN, OpType::Add, DataType::Float32, AddXdnn, + "Add_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Sub, DataType::Float32, SubXdnn, + "Sub_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Mul, DataType::Float32, MulXdnn, + "Mul_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Div, DataType::Float32, DivXdnn, + "Div_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Pow, DataType::Float32, PowXdnn, + "Pow_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Max, DataType::Float32, MaxXdnn, + "Max_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Min, DataType::Float32, MinXdnn, + "Min_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Equal, DataType::Float32, EqualXdnn, + "Equal_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::GreaterOrEqual, DataType::Float32, + GreaterEqualXdnn, "GreaterEqual_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Greater, DataType::Float32, + GreaterThanXdnn, "GreaterThan_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::LessOrEqual, DataType::Float32, + LessEqualXdnn, "LessEqual_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Less, DataType::Float32, LessThanXdnn, + "LessThan_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::FloorDiv, DataType::Float32, + FloorDivXdnn, "FloorDiv_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::MSELoss, DataType::Float32, MSELossXdnn, + "MSELoss_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::And, DataType::Float32, AndXdnn, + "And_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Or, DataType::Float32, OrXdnn, + "Or_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Xor, DataType::Float32, XorXdnn, + "Xor_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Not, DataType::Float32, NotXdnn, + "Not_xdnn_KUNLUN_Float32"); +}; // namespace infini diff --git a/src/kernels/kunlun/matmul.cc b/src/kernels/kunlun/matmul.cc new file mode 100644 index 00000000..91240ce3 --- /dev/null +++ b/src/kernels/kunlun/matmul.cc @@ -0,0 +1,38 @@ +#include "operators/matmul.h" +#include "kunlun/kunlun_kernel_without_config.h" +#include "kunlun/kunlun_runtime.h" + +namespace infini { +class MatmulXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + bool transA = op->getTransA(); + bool transB = op->getTransB(); + if (op->getInputs(0)->getDims().size() != 2 || + op->getInputs(1)->getDims().size() != 2) { + IT_TODO_HALT(); + } + + auto m = transA ? op->getInputs(0)->getDims()[1] + : op->getInputs(0)->getDims()[0]; + auto n = transB ? op->getInputs(1)->getDims()[0] + : op->getInputs(1)->getDims()[1]; + auto k = transA ? op->getInputs(0)->getDims()[0] + : op->getInputs(0)->getDims()[1]; + + auto ret = baidu::xpu::api::fc( + context->KUNLUNHandle(), (float *)aData, (float *)bData, + (float *)cData, m, n, k, transA, transB, nullptr, nullptr, nullptr); + assert(ret == 0); + return; + } +}; + +REGISTER_KERNEL(Device::KUNLUN, OpType::MatMul, DataType::Float32, MatmulXdnn, + "Matmul_xdnn_KUNLUN_Float32"); +}; // namespace infini diff --git a/src/kernels/kunlun/pad.cc b/src/kernels/kunlun/pad.cc new file mode 100644 index 00000000..2ae93d99 --- /dev/null +++ b/src/kernels/kunlun/pad.cc @@ -0,0 +1,37 @@ +#include "operators/pad.h" +#include "kunlun/kunlun_kernel_without_config.h" +#include "kunlun/kunlun_runtime.h" + +namespace infini { +class PadXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto dim = op->getInputs(0)->getDims(); + int dim_size = dim.size(); + + std::vector pads = op->getPads(); + + std::cout << std::endl; + std::vector paddings_left(pads.begin(), pads.begin() + dim_size); + std::vector paddings_right(pads.begin() + dim_size, pads.end()); + + float paddingValue = 0.0; + auto ret = baidu::xpu::api::pad( + context->KUNLUNHandle(), (float *)aData, (float *)cData, dim, + paddings_left, paddings_right, paddingValue); + + assert(ret == 0); + return; + } +}; + +REGISTER_KERNEL(Device::KUNLUN, OpType::Pad, DataType::Float32, PadXdnn, + "Pad_xdnn_KUNLUN_Float32"); + +}; // namespace infini diff --git a/src/kernels/kunlun/pooling.cc b/src/kernels/kunlun/pooling.cc new file mode 100644 index 00000000..27b8458a --- /dev/null +++ b/src/kernels/kunlun/pooling.cc @@ -0,0 +1,62 @@ +#include "operators/pooling.h" +#include "kunlun/kunlun_kernel_without_config.h" +#include "kunlun/kunlun_runtime.h" + +namespace infini { +class AvgPooling : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto [n, c, h, w, kh, kw] = op->getNCHWRS(); + auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); + + std::vector ksize = {kh, kw}; + std::vector stride = {sh, sw}; + std::vector pad = {ph, pw}; + + auto ret = baidu::xpu::api::avg_pool2d( + context->KUNLUNHandle(), (float *)aData, (float *)cData, n, c, h, w, + ksize, stride, pad, true, true, nullptr, nullptr); + assert(ret == 0); + return; + } +}; + +class MaxPooling : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto [n, c, h, w, kh, kw] = op->getNCHWRS(); + auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); + + std::vector ksize = {kh, kw}; + std::vector stride = {sh, sw}; + std::vector pad = {ph, pw}; + + int yh = (h + ph * 2 - kh) / sh + 1; + int yw = (w + pw * 2 - kw) / sw + 1; + + KUNLUNPtr indices = context->getWorkspace(yh * yw * 4); + + auto ret = baidu::xpu::api::max_pool2d( + context->KUNLUNHandle(), (float *)aData, (float *)cData, + (int *)indices, n, c, h, w, ksize, stride, pad, true, nullptr, + nullptr, false); + assert(ret == 0); + return; + } +}; + +REGISTER_KERNEL(Device::KUNLUN, OpType::MaxPool, DataType::Float32, MaxPooling, + "MaxPool_xdnn_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::AveragePool, DataType::Float32, + AvgPooling, "AvgPool_xdnn_Float32"); +}; // namespace infini diff --git a/src/kernels/kunlun/split.cc b/src/kernels/kunlun/split.cc new file mode 100644 index 00000000..301ef027 --- /dev/null +++ b/src/kernels/kunlun/split.cc @@ -0,0 +1,41 @@ +#include "operators/split.h" +#include "kunlun/kunlun_kernel_without_config.h" +#include "kunlun/kunlun_runtime.h" + +namespace infini { +class SplitXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + int axis = op->getDim(); + int num = op->numOutputs(); + void *const inputData = (op->getInputs(0)->getRawDataPtr()); + auto inputDim = op->getInputs(0)->getDims(); + + std::vector outputsData; + for (int i = 0; i < num; ++i) { + outputsData.push_back( + (float *)(op->getOutput(i)->getRawDataPtr())); + } + + std::vector splitList; + for (int i = 0; i < num; ++i) { + auto dim = op->getOutput(i)->getDims(); + if (dim.size() != 4) { + IT_TODO_HALT(); + } + splitList.push_back(dim[axis]); + } + + auto ret = baidu::xpu::api::split( + context->KUNLUNHandle(), (float *)inputData, outputsData, inputDim, + splitList, axis); + assert(ret == 0); + return; + } +}; + +REGISTER_KERNEL(Device::KUNLUN, OpType::Split, DataType::Float32, SplitXdnn, + "Split_xdnn_KUNLUN_Float32"); +}; // namespace infini diff --git a/src/kernels/kunlun/transpose.cc b/src/kernels/kunlun/transpose.cc new file mode 100644 index 00000000..443df8d9 --- /dev/null +++ b/src/kernels/kunlun/transpose.cc @@ -0,0 +1,32 @@ +#include "operators/transpose.h" +#include "kunlun/kunlun_kernel_without_config.h" +#include "kunlun/kunlun_runtime.h" + +namespace infini { +class TransposeXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto dimin = op->getInputs(0)->getDims(); + auto permute = op->getPermute(); + + if (dimin.size() != 4) { + IT_TODO_HALT(); + } + + auto ret = baidu::xpu::api::transpose( + context->KUNLUNHandle(), (float *)aData, (float *)cData, dimin, + permute); + assert(ret == 0); + return; + } +}; + +REGISTER_KERNEL(Device::KUNLUN, OpType::Transpose, DataType::Float32, + TransposeXdnn, "Transpose_xdnn_KUNLUN_Float32"); +}; // namespace infini diff --git a/src/kernels/kunlun/unary.cc b/src/kernels/kunlun/unary.cc new file mode 100644 index 00000000..c24fddaf --- /dev/null +++ b/src/kernels/kunlun/unary.cc @@ -0,0 +1,550 @@ +#include "operators/unary.h" +#include "kunlun/kunlun_kernel_without_config.h" +#include "kunlun/kunlun_runtime.h" + +namespace infini { +class ReluXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + + auto ret = baidu::xpu::api::relu( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class SigmoidXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + + auto ret = baidu::xpu::api::sigmoid( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class TanhXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + + auto ret = baidu::xpu::api::tanh( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class SquareXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + + auto ret = baidu::xpu::api::square( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class SqrtXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + + auto ret = baidu::xpu::api::sqrt( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class RsqrtXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + + auto ret = baidu::xpu::api::rsqrt( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class ExpXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + + auto ret = baidu::xpu::api::exp( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class CeilXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + + auto ret = baidu::xpu::api::ceil( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class ClipXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + float min = op->getMin().value(); + float max = op->getMax().value(); + + auto ret = baidu::xpu::api::clip(context->KUNLUNHandle(), + (float *)aData, (float *)cData, + len, min, max); + assert(ret == 0); + return; + } +}; + +class FloorXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + + auto ret = baidu::xpu::api::floor( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class NegXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + + auto ret = baidu::xpu::api::neg( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class CopyXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &op, + const RuntimeObj *_context) const override { + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + + auto ret = baidu::xpu::api::copy( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class ReciprocalXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + + auto ret = baidu::xpu::api::reciprocal( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class AbsXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + + auto ret = baidu::xpu::api::abs( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class ATanXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + + auto ret = baidu::xpu::api::arctan( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class LogXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto aDim = op->getInputs(0)->getDims(); + std::vector divDim = { + 1, + }; + auto len = op->getInputs(0)->size(); + // get ptr of tempspace + KUNLUNPtr temp = context->getWorkspace(len * sizeof(float)); + LogObj::LogType type = op->getType(); + // get output of xpu::api::loge(x) + auto ret = baidu::xpu::api::log( + context->KUNLUNHandle(), (float *)aData, (float *)temp, len); + // get ptr of divider + KUNLUNPtr dd = + (float *)(context->getWorkspace((1 + len) * sizeof(float))) + len; + // choose from logE, log2, log10 + switch (type) { + float constant; + case LogObj::LogE: + // if use loge, copy from temp to cData + ret = baidu::xpu::api::copy( + context->KUNLUNHandle(), (float *)temp, (float *)cData, len); + break; + case LogObj::Log2: + constant = std::log(2); + context->copyBlobFromCPU(dd, &constant, sizeof(float)); + ret = baidu::xpu::api::broadcast_div( + context->KUNLUNHandle(), (float *)temp, (float *)dd, + (float *)cData, aDim, divDim); + break; + case LogObj::Log10: + constant = std::log(10); + context->copyBlobFromCPU(dd, &constant, sizeof(float)); + ret = baidu::xpu::api::broadcast_div( + context->KUNLUNHandle(), (float *)temp, (float *)dd, + (float *)cData, aDim, divDim); + break; + default: + printf("LogType not support!"); + break; + } + assert(ret == 0); + return; + } +}; + +class CosXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + auto ret = baidu::xpu::api::cos( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + + assert(ret == 0); + return; + } +}; + +class SinXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + auto ret = baidu::xpu::api::sin( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + + assert(ret == 0); + return; + } +}; + +class TanXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + auto ret = baidu::xpu::api::tan( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + + assert(ret == 0); + return; + } +}; + +class SinhXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + auto ret = baidu::xpu::api::sinh( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + + assert(ret == 0); + return; + } +}; + +class CoshXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + auto ret = baidu::xpu::api::cosh( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + + assert(ret == 0); + return; + } +}; + +class ErfXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + auto ret = baidu::xpu::api::erf( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + + assert(ret == 0); + return; + } +}; + +class ACosXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + auto ret = baidu::xpu::api::arccos( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + + assert(ret == 0); + return; + } +}; + +class ACoshXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + auto ret = baidu::xpu::api::acosh( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + + assert(ret == 0); + return; + } +}; + +class ASinXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + auto ret = baidu::xpu::api::arcsin( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + + assert(ret == 0); + return; + } +}; + +class ASinhXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + auto ret = baidu::xpu::api::asinh( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + + assert(ret == 0); + return; + } +}; + +class ATanhXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + auto ret = baidu::xpu::api::atanh( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + + assert(ret == 0); + return; + } +}; + +REGISTER_KERNEL(Device::KUNLUN, OpType::Relu, DataType::Float32, ReluXdnn, + "Relu_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Sigmoid, DataType::Float32, SigmoidXdnn, + "Sigmoid_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Tanh, DataType::Float32, TanhXdnn, + "Tanh_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Square, DataType::Float32, SquareXdnn, + "Square_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Sqrt, DataType::Float32, SqrtXdnn, + "Sqrt_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Rsqrt, DataType::Float32, RsqrtXdnn, + "Rsqrt_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Exp, DataType::Float32, ExpXdnn, + "Exp_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Ceil, DataType::Float32, CeilXdnn, + "Ceil_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Clip, DataType::Float32, ClipXdnn, + "Clip_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Floor, DataType::Float32, FloorXdnn, + "Floor_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Neg, DataType::Float32, NegXdnn, + "Neg_xdnn_KUNLUN_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Reciprocal, DataType::Float32, + ReciprocalXdnn, "Reciprocal_xdnn_KUNLUN_Float32"); + +REGISTER_KERNEL(Device::KUNLUN, OpType::Reshape, DataType::Float32, CopyXdnn, + "Reshape_xdnn_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Flatten, DataType::Float32, CopyXdnn, + "Flatten_xdnn_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Identity, DataType::Float32, CopyXdnn, + "Identity_xdnn_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Abs, DataType::Float32, AbsXdnn, + "Abs_xdnn_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Atan, DataType::Float32, ATanXdnn, + "Atan_xdnn_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Log, DataType::Float32, LogXdnn, + "Log_xdnn_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Cos, DataType::Float32, CosXdnn, + "Cos_xdnn_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Sin, DataType::Float32, SinXdnn, + "Sin_xdnn_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Tan, DataType::Float32, TanXdnn, + "Tan_xdnn_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Sinh, DataType::Float32, SinhXdnn, + "Sinh_xdnn_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Cosh, DataType::Float32, CoshXdnn, + "Cosh_xdnn_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Erf, DataType::Float32, ErfXdnn, + "Erf_xdnn_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Acos, DataType::Float32, ACosXdnn, + "ACos_xdnn_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Acosh, DataType::Float32, ACoshXdnn, + "ACosh_xdnn_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Asin, DataType::Float32, ASinXdnn, + "ASin_xdnn_Float32"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Asinh, DataType::Float32, ASinhXdnn, + "ASinh_xdnn_Float3 2"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Atanh, DataType::Float32, ATanhXdnn, + "ATanh_xdnn_Float32"); +}; // namespace infini diff --git a/src/kunlun/kunlun_runtime.cc b/src/kunlun/kunlun_runtime.cc new file mode 100644 index 00000000..b40e772f --- /dev/null +++ b/src/kunlun/kunlun_runtime.cc @@ -0,0 +1,60 @@ +#include "kunlun/kunlun_runtime.h" +#include "core/kernel.h" +#include "core/perf_engine.h" + +namespace infini { + +void KUNLUNRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false, + bool profiling = false) const { + const auto &kernelRegistry = KernelRegistry::getInstance(); + auto &perfEngine = PerfEngine::getInstance(); + double totalTime = 0; + std::map opTime; + std::map opCnt; + for (auto &op : graph->getOperators()) { + // HACK: set correct data type + auto kernelAttrs = + KernelAttrs{device, op->getOpType().underlying(), op->getDType()}; + Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); + auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; + auto perfData = perfEngine.getPerfData(perfKey); + if (!perfData && !tune) { + kernel->compute(op, this); + continue; + } + + PerfRecord record; + if (!perfData) { + record = kernel->tune(op, this); + perfEngine.setPerfData(perfKey, record); + } else + record = perfData; + + double t = record->time; + totalTime += t; + + if (profiling) { + double t = timeit([&]() { kernel->compute(op, record, this); }, + [&]() { sync(); }, 1, 1); + op->print(); + printf(" op_time on kunlun xpu %lf\n", t); + totalTime += t; + opTime[op->getOpType()] += t; + opCnt[op->getOpType()]++; + } + } +} + +void KUNLUNRuntimeObj::run(const Graph &graph, bool tune, + bool profiling) const { + if (profiling) + IT_TODO_HALT(); + runWithoutSync(graph, tune, profiling); + sync(); +} + +void KUNLUNRuntimeObj::sync() const { ; } + +string KUNLUNRuntimeObj::toString() const { return "KUNLUN Runtime"; } + +} // namespace infini diff --git a/src/kunlun/operator_timer.cc b/src/kunlun/operator_timer.cc new file mode 100644 index 00000000..aeb5d18f --- /dev/null +++ b/src/kunlun/operator_timer.cc @@ -0,0 +1,71 @@ +#include "kunlun/operator_timer.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "kunlun/kunlun_runtime.h" +#include "operators/conv.h" +#include "operators/matmul.h" +#include "utils/data_generator.h" + +namespace infini { +namespace opTimer { + +double getPerfConvKunlun(int n, int c, int h, int w, int f, int r, int s, + int padh, int padw, int strideh, int stridew, + int dilationh, int dilationw, int group, + const char *name) { + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton + Graph gCpu = make_ref(cpu); + Runtime kunlun = make_ref(); + Graph gKunlun = make_ref(kunlun); + // Set input data on CPU in a CPU Graph + IT_ASSERT(c % group == 0); + Tensor i0Cpu = gCpu->addTensor({n, h, w, c}, DataType::Float32); + Tensor w0Cpu = gCpu->addTensor({f, r, s, c / group}, DataType::Float32); + // Malloc data for all tensors in a graph. Do we need implicit allocation? + gCpu->dataMalloc(); + i0Cpu->setData(IncrementalGenerator()); + w0Cpu->setData(IncrementalGenerator()); + + // Copy input tensors from CPU to Kunlun + Tensor i0Kunlun = gKunlun->cloneTensor(i0Cpu); + Tensor w0Kunlun = gKunlun->cloneTensor(w0Cpu); + // Build Kunlun graph + auto conv = gKunlun->addOp(i0Kunlun, w0Kunlun, nullptr, padh, padw, + strideh, stridew, dilationh, dilationw); + // allocate Kunlun memory + gKunlun->dataMalloc(); + // Execute on Kunlun + bool tune = true; + kunlun->run(gKunlun, tune); + return kunlun->getPerfTime(gKunlun); +} + +double getPerfMatmulKunlun(int b, int m, int n, int k, const char *name) { + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton + Graph gCpu = make_ref(cpu); + Runtime kunlun = make_ref(); + Graph gKunlun = make_ref(kunlun); + // Set input data on CPU in a CPU Graph + Tensor i0Cpu = gCpu->addTensor({b, m, k}, DataType::Float32); + Tensor w0Cpu = gCpu->addTensor({b, k, n}, DataType::Float32); + // Malloc data for all tensors in a graph. Do we need implicit allocation? + gCpu->dataMalloc(); + i0Cpu->setData(IncrementalGenerator()); + w0Cpu->setData(IncrementalGenerator()); + + // Copy input tensors from CPU to Kunlun + Tensor i0Kunlun = gKunlun->cloneTensor(i0Cpu); + Tensor w0Kunlun = gKunlun->cloneTensor(w0Cpu); + // Build Kunlun graph + auto conv = gKunlun->addOp(i0Kunlun, w0Kunlun, nullptr); + // allocate Kunlun memory + gKunlun->dataMalloc(); + // Execute on Kunlun + bool tune = true; + kunlun->run(gKunlun, tune); + return kunlun->getPerfTime(gKunlun); +} + +} // namespace opTimer +} // namespace infini diff --git a/test/kernels/kunlun/test_kunlun_add.cc b/test/kernels/kunlun/test_kunlun_add.cc new file mode 100644 index 00000000..a6b44a7a --- /dev/null +++ b/test/kernels/kunlun/test_kunlun_add.cc @@ -0,0 +1,61 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "kunlun/kunlun_runtime.h" +#include "operators/element_wise.h" + +#include "test.h" + +namespace infini { + +template +void testAdd(const std::function &generator, + const Shape &shape) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto xpuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu1 = + make_ref(shape, DataType::Float32, cpuRuntime); + Tensor inputCpu2 = + make_ref(shape, DataType::Float32, cpuRuntime); + + // GPU + Graph xpuGraph = make_ref(xpuRuntime); + auto inputGpu1 = xpuGraph->cloneTensor(inputCpu1); + auto inputGpu2 = xpuGraph->cloneTensor(inputCpu2); + auto gpuOp = xpuGraph->addOp(inputGpu1, inputGpu2, nullptr); + xpuGraph->dataMalloc(); + inputGpu1->setData(generator); + inputGpu2->setData(generator); + xpuRuntime->run(xpuGraph); + auto outputGpu = gpuOp->getOutput(); + auto outputGpu2Cpu = outputGpu->clone(cpuRuntime); + // CPU + Graph cpuGraph = make_ref(cpuRuntime); + auto cpuOp = cpuGraph->addOp(inputCpu1, inputCpu2, nullptr); + cpuGraph->addTensor(inputCpu1); + cpuGraph->addTensor(inputCpu2); + cpuGraph->dataMalloc(); + inputCpu1->setData(generator); + inputCpu2->setData(generator); + cpuRuntime->run(cpuGraph); + auto outputCpu = cpuOp->getOutput(); + // Check + EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu)); +} + +TEST(xpu_add, run) { + testAdd(IncrementalGenerator(), Shape{1, 1, 1, 30}); + testAdd(IncrementalGenerator(), Shape{1, 1, 1, 30}); + testAdd(IncrementalGenerator(), Shape{1, 1, 1, 30}); + testAdd(IncrementalGenerator(), Shape{1, 1, 1, 30}); + testAdd(IncrementalGenerator(), Shape{1, 1, 1, 30}); + testAdd(IncrementalGenerator(), Shape{1, 1, 1, 30}); + testAdd(IncrementalGenerator(), Shape{1, 1, 1, 30}); + testAdd(IncrementalGenerator(), Shape{1, 1, 1, 30}); + testAdd(IncrementalGenerator(), Shape{1, 1, 1, 30}); +} + +} // namespace infini diff --git a/test/kernels/kunlun/test_kunlun_batch_norm.cc b/test/kernels/kunlun/test_kunlun_batch_norm.cc new file mode 100644 index 00000000..ef6204be --- /dev/null +++ b/test/kernels/kunlun/test_kunlun_batch_norm.cc @@ -0,0 +1,61 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "kunlun/kunlun_kernel_without_config.h" +#include "kunlun/kunlun_runtime.h" +#include "operators/batch_norm.h" +#include "test.h" + +namespace infini { + +TEST(XPU_BatchNorm, run) { + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto xpuRuntime = make_ref(); + + // Build cpu graph + Graph gCpu = make_ref(cpuRuntime); + auto iCpu = gCpu->addTensor(Shape{1, 3, 2, 2}, DataType::Float32); + auto meanCpu = gCpu->addTensor(Shape{3}, DataType::Float32); + auto varCpu = gCpu->addTensor(Shape{3}, DataType::Float32); + auto scaleCpu = gCpu->addTensor(Shape{3}, DataType::Float32); + auto biasCpu = gCpu->addTensor(Shape{3}, DataType::Float32); + + // Build input data on CPU + gCpu->dataMalloc(); + iCpu->setData(IncrementalGenerator()); + meanCpu->copyin(vector{1, 6, 9}); + varCpu->copyin(vector{4, 1, 9}); + scaleCpu->setData(OneGenerator()); + biasCpu->setData(ZeroGenerator()); + + // Build XPU graph + Graph g = make_ref(xpuRuntime); + + auto i = g->cloneTensor(iCpu); + auto mean = g->cloneTensor(meanCpu); + auto var = g->cloneTensor(varCpu); + auto scale = g->cloneTensor(scaleCpu); + auto bias = g->cloneTensor(biasCpu); + auto op = + g->addOp(i, nullptr, mean, var, scale, bias, 0.9, 0); + + // allocate XPU memory + g->dataMalloc(); + i->setData(IncrementalGenerator()); + mean->copyin(vector{1, 6, 9}); + var->copyin(vector{4, 1, 9}); + scale->setData(OneGenerator()); + bias->setData(ZeroGenerator()); + + // Execute on XPU + xpuRuntime->run(g); + + // clone XPU output to CPU + auto o = op->getOutput(); + auto ocpu = o->clone(cpuRuntime); + + // check results on CPU + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 2})); + EXPECT_TRUE(ocpu->equalData(vector{ + -0.5, 0, 0.5, 1, -2, -1, 0, 1, -0.333333, 0, 0.3333333, 0.6666667})); +} +} // namespace infini diff --git a/test/kernels/kunlun/test_kunlun_concat.cc b/test/kernels/kunlun/test_kunlun_concat.cc new file mode 100644 index 00000000..7e3ee714 --- /dev/null +++ b/test/kernels/kunlun/test_kunlun_concat.cc @@ -0,0 +1,54 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "kunlun/kunlun_runtime.h" +#include "operators/concat.h" + +#include "test.h" + +namespace infini { + +template +void testConcat(const std::function &generator, + const Shape &shape) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto xpuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu1 = + make_ref(shape, DataType::Float32, cpuRuntime); + inputCpu1->dataMalloc(); + inputCpu1->setData(generator); + Tensor inputCpu2 = + make_ref(shape, DataType::Float32, cpuRuntime); + inputCpu2->dataMalloc(); + inputCpu2->setData(generator); + + // GPU + Graph xpuGraph = make_ref(xpuRuntime); + auto inputGpu1 = xpuGraph->cloneTensor(inputCpu1); + auto inputGpu2 = xpuGraph->cloneTensor(inputCpu2); + auto gpuOp = + xpuGraph->addOp(TensorVec{inputGpu1, inputGpu2}, nullptr, 2); + xpuGraph->dataMalloc(); + inputGpu1->setData(generator); + inputGpu2->setData(generator); + xpuRuntime->run(xpuGraph); + auto outputGpu = gpuOp->getOutput(); + auto outputGpu2Cpu = outputGpu->clone(cpuRuntime); + // Check + inputCpu1->print(); + inputCpu1->printData(); + inputCpu2->print(); + inputCpu2->printData(); + outputGpu2Cpu->print(); + outputGpu2Cpu->printData(); + EXPECT_TRUE(1); +} + +TEST(xpu_Concat, run) { + testConcat(IncrementalGenerator(), Shape{1, 2, 2, 3}); +} + +} // namespace infini diff --git a/test/kernels/kunlun/test_kunlun_conv.cc b/test/kernels/kunlun/test_kunlun_conv.cc new file mode 100644 index 00000000..5584853c --- /dev/null +++ b/test/kernels/kunlun/test_kunlun_conv.cc @@ -0,0 +1,56 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "kunlun/kunlun_runtime.h" +#include "operators/conv.h" + +#include "test.h" + +namespace infini { + +template +void testConv(const std::function &generatorA, + const std::function &generatorB, + const Shape &shapeA, const Shape &shapeB) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto xpuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu1 = + make_ref(shapeA, DataType::Float32, cpuRuntime); + Tensor inputCpu2 = + make_ref(shapeB, DataType::Float32, cpuRuntime); + // MLU + Graph xpuGraph = make_ref(xpuRuntime); + auto inputMlu1 = xpuGraph->cloneTensor(inputCpu1); + auto inputMlu2 = xpuGraph->cloneTensor(inputCpu2); + auto mluOp = + xpuGraph->addOp(inputMlu1, inputMlu2, nullptr, 1, 1, 1, 1, 1, 1); + xpuGraph->dataMalloc(); + inputMlu1->setData(generatorA); + inputMlu2->setData(generatorB); + xpuRuntime->run(xpuGraph); + auto outputXpu = mluOp->getOutput(); + auto outputXpu2Cpu = outputXpu->clone(cpuRuntime); + // CPU + Graph cpuGraph = make_ref(cpuRuntime); + cpuGraph->addTensor(inputCpu1); + cpuGraph->addTensor(inputCpu2); + auto cpuOp = + cpuGraph->addOp(inputCpu1, inputCpu2, nullptr, 1, 1, 1, 1, 1, 1); + cpuGraph->dataMalloc(); + inputCpu1->setData(generatorA); + inputCpu2->setData(generatorB); + cpuRuntime->run(cpuGraph); + auto outputCpu = cpuOp->getOutput(); + // Check + EXPECT_TRUE(outputCpu->equalData(outputXpu2Cpu)); +} + +TEST(xpu_Conv, run) { + testConv(IncrementalGenerator(), IncrementalGenerator(), + Shape{1, 3, 32, 32}, Shape{2, 3, 3, 3}); +} + +} // namespace infini diff --git a/test/kernels/kunlun/test_kunlun_conv_trans.cc b/test/kernels/kunlun/test_kunlun_conv_trans.cc new file mode 100644 index 00000000..218e22f4 --- /dev/null +++ b/test/kernels/kunlun/test_kunlun_conv_trans.cc @@ -0,0 +1,136 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/perf_engine.h" +#include "core/runtime.h" +#include "kunlun/kunlun_kernel_without_config.h" +#include "kunlun/kunlun_runtime.h" +#include "operators/conv.h" + +#include "test.h" + +namespace infini { + +void testConvTransposedXdnn( + const std::function &generator, + vector ansVec) { + const auto &[N, C, H, W, F, R, S] = tuple{1, 1, 2, 2, 1, 4, 4}; + const int stride = 1, padding = 0, dilation = 1; + // Construct Runtime and graph for CPU and XPU + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton + Graph gCpu = make_ref(cpu); + Runtime xpu = make_ref(); + Graph gXpu = make_ref(xpu); + // Set input data on CPU in a CPU Graph + Tensor i0Cpu = gCpu->addTensor({N, F, H, H}, DataType::Float32); + Tensor w0Cpu = gCpu->addTensor({F, C, R, S}, DataType::Float32); + // Malloc data for all tensors in a graph. Do we need implicit allocation? + gCpu->dataMalloc(); + i0Cpu->setData(generator); + w0Cpu->setData(generator); + + // Copy input tensors from CPU to XPU + Tensor i0Xpu = gXpu->cloneTensor(i0Cpu); + Tensor w0Xpu = gXpu->cloneTensor(w0Cpu); + // Build XPU graph + auto conv = gXpu->addOp(i0Xpu, w0Xpu, nullptr, padding, + padding, stride, stride, + dilation, dilation); + gXpu->dataMalloc(); + i0Xpu->setData(generator); + w0Xpu->setData(generator); + // Execute on XPU + xpu->run(gXpu); + // copy output from XPU to CPU + auto o0Cpu = gCpu->cloneTensor(conv->getOutput()); + // check results on CPU + EXPECT_TRUE(o0Cpu->equalData(ansVec)); +} + +void testConvTransposedNHWCXdnn( + const std::function &generator, + vector ansVec) { + const auto &[N, C, H, W, F, R, S] = tuple{1, 1, 2, 2, 1, 4, 4}; + const int stride = 1, padding = 0, dilation = 1; + // Construct Runtime and graph for CPU and XPU + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton + Graph gCpu = make_ref(cpu); + Runtime xpu = make_ref(); + Graph gXpu = make_ref(xpu); + // Set input data on CPU in a CPU Graph + Tensor i0Cpu = gCpu->addTensor({N, H, W, F}, DataType::Float32); + Tensor w0Cpu = gCpu->addTensor({F, R, S, C}, DataType::Float32); + // Malloc data for all tensors in a graph. Do we need implicit allocation? + gCpu->dataMalloc(); + i0Cpu->setData(generator); + w0Cpu->setData(generator); + + // Copy input tensors from CPU to XPU + Tensor i0Xpu = gXpu->cloneTensor(i0Cpu); + Tensor w0Xpu = gXpu->cloneTensor(w0Cpu); + // Build XPU graph + auto conv = gXpu->addOp( + i0Xpu, w0Xpu, nullptr, padding, padding, stride, stride, dilation, + dilation); + gXpu->dataMalloc(); + i0Xpu->setData(generator); + w0Xpu->setData(generator); + // Execute on XPU + xpu->run(gXpu); + // copy output from XPU to CPU + auto o0Cpu = gCpu->cloneTensor(conv->getOutput()); + // check results on CPU + EXPECT_TRUE(o0Cpu->equalData(ansVec)); +} + +TEST(XPU_ConvTransposed, run) { + testConvTransposedXdnn(IncrementalGenerator(), + vector{0., 0., 1., 2., 3., 0., 6., + 12., 18., 16., 8., 30., 36., 42., + 32., 16., 54., 60., 66., 48., 24., + 62., 67., 72., 45.}); +} + +TEST(XPU_ConvTransposedNHWC, run) { + testConvTransposedNHWCXdnn(IncrementalGenerator(), + vector{0., 0., 1., 2., 3., 0., 6., + 12., 18., 16., 8., 30., 36., 42., + 32., 16., 54., 60., 66., 48., 24., + 62., 67., 72., 45.}); +} + +TEST(XPU_ConvTransposed, run1) { + // Construct Runtime and graph for CPU and XPU + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton + Graph gCpu = make_ref(cpu); + Runtime xpu = make_ref(); + Graph gXpu = make_ref(xpu); + // Set input data on CPU in a CPU Graph + Tensor i0Cpu = gCpu->addTensor({1, 2, 3, 3}, DataType::Float32); + Tensor w0Cpu = gCpu->addTensor({2, 2, 3, 3}, DataType::Float32); + // Malloc data for all tensors in a graph. Do we need implicit allocation? + gCpu->dataMalloc(); + i0Cpu->setData(IncrementalGenerator()); + w0Cpu->setData(IncrementalGenerator()); + + // Copy input tensors from CPU to XPU + Tensor i0Xpu = gXpu->cloneTensor(i0Cpu); + Tensor w0Xpu = gXpu->cloneTensor(w0Cpu); + // Build XPU graph + auto conv = gXpu->addOp(i0Xpu, w0Xpu, nullptr, 0, 0); + gXpu->dataMalloc(); + i0Xpu->setData(IncrementalGenerator()); + w0Xpu->setData(IncrementalGenerator()); + // Execute on XPU + xpu->run(gXpu); + // copy output from XPU to CPU + auto o0Cpu = gCpu->cloneTensor(conv->getOutput()); + // check results on CPU + EXPECT_TRUE(o0Cpu->equalData(vector{ + 162, 351, 569, 413, 224, 405, 876, 1417, 1024, 553, + 747, 1611, 2598, 1869, 1005, 639, 1368, 2191, 1564, 835, + 396, 843, 1343, 953, 506, 243, 531, 866, 629, 341, + 621, 1344, 2173, 1564, 841, 1152, 2475, 3975, 2841, 1518, + 963, 2052, 3271, 2320, 1231, 585, 1239, 1964, 1385, 731})); +} + +} // namespace infini diff --git a/test/kernels/kunlun/test_kunlun_element_wise.cc b/test/kernels/kunlun/test_kunlun_element_wise.cc new file mode 100644 index 00000000..faf0668b --- /dev/null +++ b/test/kernels/kunlun/test_kunlun_element_wise.cc @@ -0,0 +1,66 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "kunlun/kunlun_runtime.h" +#include "operators/element_wise.h" + +#include "test.h" + +namespace infini { + +using ExpectOutput = vector; +template +void testElementWiseXdnn( + const std::function &generator, + const Shape &shape, const ExpectOutput &ansVec) { + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto xpuRuntime = make_ref(); + + // Build input data on CPU + Tensor acpu = make_ref(shape, DataType::Float32, cpuRuntime); + acpu->dataMalloc(); + acpu->setData(generator); + + Tensor bcpu = make_ref(shape, DataType::Float32, cpuRuntime); + bcpu->dataMalloc(); + bcpu->setData(generator); + + // Build XPU graph + Graph g = make_ref(xpuRuntime); + auto a = g->cloneTensor(acpu); + auto b = g->cloneTensor(bcpu); + auto op = g->addOp(a, b, nullptr); + + // allocate XPU memory + g->dataMalloc(); + a->setData(generator); + b->setData(generator); + + // Execute on XPU + xpuRuntime->run(g); + + // clone XPU output to CPU + auto c = op->getOutput(); + auto ccpu = c->clone(cpuRuntime); + // check results on CPU + EXPECT_TRUE(ccpu->equalData(ansVec)); +} + +TEST(xdnn_ElementWise, run) { + testElementWiseXdnn( + IncrementalGenerator(), Shape{1, 2, 2, 3}, + ExpectOutput{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22}); + testElementWiseXdnn( + IncrementalGenerator(), Shape{1, 2, 2, 3}, + ExpectOutput{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + testElementWiseXdnn( + IncrementalGenerator(), Shape{1, 2, 2, 3}, + ExpectOutput{0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121}); + testElementWiseXdnn( + OneGenerator(), Shape{1, 2, 2, 3}, + ExpectOutput{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + testElementWiseXdnn(IncrementalGenerator(), Shape{1, 2, 2, 1}, + ExpectOutput{1, 1, 4, 27}); +} + +} // namespace infini diff --git a/test/kernels/kunlun/test_kunlun_matmul.cc b/test/kernels/kunlun/test_kunlun_matmul.cc new file mode 100644 index 00000000..dcd2084f --- /dev/null +++ b/test/kernels/kunlun/test_kunlun_matmul.cc @@ -0,0 +1,58 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "kunlun/kunlun_runtime.h" +#include "operators/matmul.h" + +#include "test.h" + +namespace infini { + +template +void testMatmul(const std::function &generatorA, + const std::function &generatorB, + bool transA, bool transB, const Shape &shapeA, + const Shape &shapeB) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto xpuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu1 = + make_ref(shapeA, DataType::Float32, cpuRuntime); + Tensor inputCpu2 = + make_ref(shapeB, DataType::Float32, cpuRuntime); + + // MLU + Graph xpuGraph = make_ref(xpuRuntime); + auto inputMlu1 = xpuGraph->cloneTensor(inputCpu1); + auto inputMlu2 = xpuGraph->cloneTensor(inputCpu2); + auto mluOp = xpuGraph->addOp(inputMlu1, inputMlu2, nullptr); + xpuGraph->dataMalloc(); + inputMlu1->setData(generatorA); + inputMlu2->setData(generatorB); + xpuRuntime->run(xpuGraph); + auto outputMlu = mluOp->getOutput(); + auto outputMlu2Cpu = outputMlu->clone(cpuRuntime); + // CPU + Graph cpuGraph = make_ref(cpuRuntime); + auto cpuOp = cpuGraph->addOp(inputCpu1, inputCpu2, nullptr); + cpuGraph->addTensor(inputCpu1); + cpuGraph->addTensor(inputCpu2); + cpuGraph->dataMalloc(); + inputCpu1->setData(generatorA); + inputCpu2->setData(generatorB); + cpuRuntime->run(cpuGraph); + auto outputCpu = cpuOp->getOutput(); + outputCpu->print(); + outputMlu2Cpu->print(); + // Check + EXPECT_TRUE(outputCpu->equalData(outputMlu2Cpu)); +} + +TEST(xpu_Matmul, run) { + testMatmul(IncrementalGenerator(), IncrementalGenerator(), false, + false, Shape{2, 3}, Shape{3, 4}); +} + +} // namespace infini diff --git a/test/kernels/kunlun/test_kunlun_pad.cc b/test/kernels/kunlun/test_kunlun_pad.cc new file mode 100644 index 00000000..43cc1bd7 --- /dev/null +++ b/test/kernels/kunlun/test_kunlun_pad.cc @@ -0,0 +1,40 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "kunlun/kunlun_kernel_without_config.h" +#include "kunlun/kunlun_runtime.h" +#include "operators/pad.h" +#include "test.h" + +namespace infini { +TEST(xpu_Pad, run) { + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto xpuRuntime = make_ref(); + + // Build input data on CPU + Tensor icpu = + make_ref(Shape{1, 2, 3, 2}, DataType::Float32, cpuRuntime); + + // Build XPU graph; + Graph g = make_ref(xpuRuntime); + auto i = g->cloneTensor(icpu); + auto op = g->addOp(i, nullptr, vector{1, 0, 1, 1}, + vector{0, 3}); + + // allocate XPU memory + g->dataMalloc(); + i->setData(IncrementalGenerator()); + + // Execute on XPU + xpuRuntime->run(g); + + // clone XPU output to CPU + auto o = op->getOutput(); + auto cpuo = o->clone(cpuRuntime); + cpuo->printData(); + // check results on CPU + EXPECT_TRUE(cpuo->equalData( + vector{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 2, 3, 0, 4, 5, 0, 6, 7, 0, 8, 9, 0, 10, 11, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})); +} +} // namespace infini diff --git a/test/kernels/kunlun/test_kunlun_pooling.cc b/test/kernels/kunlun/test_kunlun_pooling.cc new file mode 100644 index 00000000..67050628 --- /dev/null +++ b/test/kernels/kunlun/test_kunlun_pooling.cc @@ -0,0 +1,51 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "kunlun/kunlun_runtime.h" +#include "operators/pooling.h" + +#include "test.h" + +namespace infini { + +template +void testPooling(const std::function &generator, + const Shape &shape) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto xpuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = make_ref(shape, DataType::Float32, cpuRuntime); + inputCpu->dataMalloc(); + inputCpu->setData(generator); + + // GPU + Graph xpuGraph = make_ref(xpuRuntime); + auto inputGpu = xpuGraph->cloneTensor(inputCpu); + auto gpuOp = + xpuGraph->addOp(inputGpu, nullptr, 3, 3, 1, 1, 0, 0, 2, 2, 0); + xpuGraph->dataMalloc(); + inputGpu->setData(generator); + xpuRuntime->run(xpuGraph); + auto outputGpu = gpuOp->getOutput(); + auto outputGpu2Cpu = outputGpu->clone(cpuRuntime); + + // CPU + Graph cpuGraph = make_ref(cpuRuntime); + cpuGraph->addTensor(inputCpu); + auto cpuOp = + cpuGraph->addOp(inputCpu, nullptr, 3, 3, 1, 1, 0, 0, 2, 2, 0); + cpuGraph->dataMalloc(); + inputCpu->setData(generator); + cpuRuntime->run(cpuGraph); + auto outputCpu = cpuOp->getOutput(); + EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu)); +} + +TEST(xdnn_Pooling, run) { + testPooling(IncrementalGenerator(), Shape{1, 1, 5, 5}); + testPooling(IncrementalGenerator(), Shape{1, 1, 5, 5}); +} + +} // namespace infini diff --git a/test/kernels/kunlun/test_kunlun_split.cc b/test/kernels/kunlun/test_kunlun_split.cc new file mode 100644 index 00000000..03b4375a --- /dev/null +++ b/test/kernels/kunlun/test_kunlun_split.cc @@ -0,0 +1,48 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "kunlun/kunlun_runtime.h" +#include "operators/split.h" + +#include "test.h" + +namespace infini { + +template +void testSplit(const std::function &generator, + const Shape &shape) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto xpuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu1 = + make_ref(shape, DataType::Float32, cpuRuntime); + inputCpu1->dataMalloc(); + inputCpu1->setData(generator); + // GPU + Graph xpuGraph = make_ref(xpuRuntime); + auto inputGpu1 = xpuGraph->cloneTensor(inputCpu1); + auto gpuOp = xpuGraph->addOp(inputGpu1, std::nullopt, 3, 3); + xpuGraph->dataMalloc(); + xpuRuntime->run(xpuGraph); + auto o0Cpu = gpuOp->getOutput(0)->clone(cpuRuntime); + auto o1Cpu = gpuOp->getOutput(1)->clone(cpuRuntime); + auto o2Cpu = gpuOp->getOutput(2)->clone(cpuRuntime); + // Check + inputCpu1->print(); + inputCpu1->printData(); + o0Cpu->print(); + o0Cpu->printData(); + o1Cpu->print(); + o1Cpu->printData(); + o2Cpu->print(); + o2Cpu->printData(); + EXPECT_TRUE(1); +} + +TEST(xpu_Split, run) { + testSplit(IncrementalGenerator(), Shape{1, 2, 2, 3}); +} + +} // namespace infini diff --git a/test/kernels/kunlun/test_kunlun_transpose.cc b/test/kernels/kunlun/test_kunlun_transpose.cc new file mode 100644 index 00000000..6671904c --- /dev/null +++ b/test/kernels/kunlun/test_kunlun_transpose.cc @@ -0,0 +1,43 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "kunlun/kunlun_runtime.h" +#include "operators/transpose.h" + +#include "test.h" + +namespace infini { + +template +void testTranspose( + const std::function &generator, + const Shape &shape) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto xpuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = make_ref(shape, DataType::Float32, cpuRuntime); + inputCpu->dataMalloc(); + inputCpu->setData(generator); + + // GPU + Graph xpuGraph = make_ref(xpuRuntime); + auto inputGpu = xpuGraph->cloneTensor(inputCpu); + vector permute = {0, 1, 3, 2}; + auto gpuOp = xpuGraph->addOp(inputGpu, nullptr, permute); + xpuGraph->dataMalloc(); + xpuRuntime->run(xpuGraph); + auto outputGpu = gpuOp->getOutput(); + auto outputGpu2Cpu = outputGpu->clone(cpuRuntime); + // Check + inputCpu->printData(); + outputGpu2Cpu->printData(); + EXPECT_TRUE(1); +} + +TEST(xpu_Transpose, run) { + testTranspose(IncrementalGenerator(), Shape{1, 1, 2, 3}); +} + +} // namespace infini diff --git a/test/kernels/kunlun/test_kunlun_unary.cc b/test/kernels/kunlun/test_kunlun_unary.cc new file mode 100644 index 00000000..319fbfef --- /dev/null +++ b/test/kernels/kunlun/test_kunlun_unary.cc @@ -0,0 +1,190 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "kunlun/kunlun_runtime.h" +#include "operators/unary.h" + +#include "test.h" + +namespace infini { + +template +void testUnary(const std::function &generator, + const Shape &shape) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto xpuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = make_ref(shape, DataType::Float32, cpuRuntime); + + // GPU + Graph xpuGraph = make_ref(xpuRuntime); + auto inputGpu = xpuGraph->cloneTensor(inputCpu); + auto gpuOp = xpuGraph->addOp(inputGpu, nullptr); + xpuGraph->dataMalloc(); + inputGpu->setData(generator); + xpuRuntime->run(xpuGraph); + auto outputGpu = gpuOp->getOutput(); + auto outputGpu2Cpu = outputGpu->clone(cpuRuntime); + // CPU + Graph cpuGraph = make_ref(cpuRuntime); + auto cpuOp = cpuGraph->addOp(inputCpu, nullptr); + cpuGraph->addTensor(inputCpu); + cpuGraph->dataMalloc(); + inputCpu->setData(generator); + cpuRuntime->run(cpuGraph); + auto outputCpu = cpuOp->getOutput(); + // Check + EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu, 1e-6)); +} + +void testClip(const std::function &generator, + const Shape &shape) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto xpuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = make_ref(shape, DataType::Float32, cpuRuntime); + float min = 1.0; + float max = 5.0; + + // GPU + Graph xpuGraph = make_ref(xpuRuntime); + auto inputGpu = xpuGraph->cloneTensor(inputCpu); + auto gpuOp = xpuGraph->addOp(inputGpu, nullptr, min, max); + xpuGraph->dataMalloc(); + inputGpu->setData(generator); + xpuRuntime->run(xpuGraph); + auto outputGpu = gpuOp->getOutput(); + auto outputGpu2Cpu = outputGpu->clone(cpuRuntime); + // CPU + Graph cpuGraph = make_ref(cpuRuntime); + auto cpuOp = cpuGraph->addOp(inputCpu, nullptr, min, max); + cpuGraph->addTensor(inputCpu); + cpuGraph->dataMalloc(); + inputCpu->setData(generator); + cpuRuntime->run(cpuGraph); + auto outputCpu = cpuOp->getOutput(); + // Check + EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu)); +} + +void testCast(const std::function &generator, + const Shape &shape) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto xpuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = make_ref(shape, DataType::Float32, cpuRuntime); + + // GPU + Graph xpuGraph = make_ref(xpuRuntime); + auto inputGpu = xpuGraph->cloneTensor(inputCpu); + auto gpuOp = + xpuGraph->addOp(inputGpu, nullptr, CastType::Float2Int32); + xpuGraph->dataMalloc(); + inputGpu->setData(generator); + xpuRuntime->run(xpuGraph); + auto outputGpu = gpuOp->getOutput(); + auto outputGpu2Cpu = outputGpu->clone(cpuRuntime); + // CPU + Graph cpuGraph = make_ref(cpuRuntime); + auto cpuOp = + cpuGraph->addOp(inputCpu, nullptr, CastType::Float2Int32); + cpuGraph->addTensor(inputCpu); + cpuGraph->dataMalloc(); + inputCpu->setData(generator); + cpuRuntime->run(cpuGraph); + auto outputCpu = cpuOp->getOutput(); + // Check + EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu)); +} + +template +void testLog(const std::function &generator, + const Shape &shape) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto xpuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = make_ref(shape, DataType::Float32, cpuRuntime); + + // GPU + Graph xpuGraph = make_ref(xpuRuntime); + auto inputGpu = xpuGraph->cloneTensor(inputCpu); + auto gpuOp = xpuGraph->addOp(inputGpu, nullptr, T); + xpuGraph->dataMalloc(); + inputGpu->setData(generator); + xpuRuntime->run(xpuGraph); + auto outputGpu = gpuOp->getOutput(); + auto outputGpu2Cpu = outputGpu->clone(cpuRuntime); + // CPU + Graph cpuGraph = make_ref(cpuRuntime); + auto cpuOp = cpuGraph->addOp(inputCpu, nullptr, T); + cpuGraph->addTensor(inputCpu); + cpuGraph->dataMalloc(); + inputCpu->setData(generator); + cpuRuntime->run(cpuGraph); + auto outputCpu = cpuOp->getOutput(); + // Check + EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu)); +} + +template +void testTrigon(const std::function &generator, + const Shape &shape) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto xpuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = make_ref(shape, DataType::Float32, cpuRuntime); + + // GPU + Graph xpuGraph = make_ref(xpuRuntime); + auto inputGpu = xpuGraph->cloneTensor(inputCpu); + auto gpuOp = xpuGraph->addOp(inputGpu, nullptr); + xpuGraph->dataMalloc(); + inputGpu->setData(generator); + xpuRuntime->run(xpuGraph); + auto outputGpu = gpuOp->getOutput(); + auto outputGpu2Cpu = outputGpu->clone(cpuRuntime); + // CPU + Graph cpuGraph = make_ref(cpuRuntime); + auto cpuOp = cpuGraph->addOp(inputCpu, nullptr); + cpuGraph->addTensor(inputCpu); + cpuGraph->dataMalloc(); + inputCpu->setData(generator); + cpuRuntime->run(cpuGraph); + auto outputCpu = cpuOp->getOutput(); + // Check + EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu, 1e-3)); +} + +TEST(xdnn_Unary, run) { + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testUnary(ValGenerator<-1>(), Shape{1, 2, 2, 3}); + testUnary(OneGenerator(), Shape{1, 2, 2, 3}); + testLog(ValGenerator<2>(), Shape{1, 2, 2, 3}); + testLog(ValGenerator<2>(), Shape{1, 2, 2, 3}); + testLog(ValGenerator<2>(), Shape{1, 2, 2, 3}); + testTrigon(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testTrigon(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testTrigon(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testTrigon(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testTrigon(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testTrigon(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testTrigon(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testTrigon(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testTrigon(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testTrigon(IncrementalGenerator(), Shape{1, 2, 2, 3}); +} + +} // namespace infini