From 8b61b0e3979bca4cdeb7ce70fd962b38b286277f Mon Sep 17 00:00:00 2001 From: Hardy <100662313+wanghailu0717@users.noreply.github.com> Date: Tue, 20 Aug 2024 22:09:33 +0800 Subject: [PATCH] support ascend (#165) * fix * fix code * fix format * fix format * fix * fix * addAbs * more Unary * add kernels * fix concat&pooling test code * add softmax/element_wise kernel * fix format * add reshape * support for llama * add maxpooling & flatten * add conv_transpose&&native maxpooling * add conv_transpose * add communication operator * fix * style: fix format * style: fix format * add depthTospace&&resize * add layernorm * format * add gemm * add leakyRelu op * modified format * modified onnx leakyrelu alpha * modified batchnorm * fix gemm & avgpooling * fix: onnx resize op input is none bug * add instancenorm, use layernorm replace instance, error * modiefied format, replace layernorm as instancenorm * fix: onnx resize op input is none bug * add pad2d kernel * modified format * fix op * fix resize * remove sync in op * Update INSTALL_GUIDE_CN.md for ASCEND * Update env.sh * format * fix test_resize * fix resize * fix test_resize_ * fix test_resize_ * add HcclCommDestroy && use default context * install onnxtuntime * install onnx-simplifier * install numpy * fix bug after merge * remove CHECK_RET&LOG_PRINT * fix test_ascend_layernorm * fix test_cuda_resize * fix test_ascend_* * fix format --------- Co-authored-by: Haojie Wang Co-authored-by: wanghailu Co-authored-by: OdinaryWord Co-authored-by: xgqdut2016 Co-authored-by: zhangyunze Co-authored-by: Songxin Co-authored-by: zhangyue <138768300+zhangyue207@users.noreply.github.com> Co-authored-by: zhangyue Co-authored-by: sx941227 <14507528+sx941227@user.noreply.gitee.com> Co-authored-by: zhangyunze Co-authored-by: Chenjie Duan <44265800+kilinchange@users.noreply.github.com> --- .gitignore | 2 + CMakeLists.txt | 46 ++++ Makefile | 2 + README.md | 2 +- docs/INSTALL_GUIDE_CN.md | 22 ++ docs/SUPPORT_MATRIX_CN.md | 1 + docs/USER_GUIDE_CN.md | 1 + env.sh | 30 ++ examples/NNmodel | 2 +- examples/distributed/ascend_launch.py | 198 ++++++++++++++ include/ascend/ascend_common.h | 37 +++ include/ascend/ascend_kernel_without_config.h | 23 ++ include/ascend/ascend_runtime.h | 82 ++++++ include/ascend/hccl_communicator.h | 72 +++++ include/core/common.h | 2 + include/core/graph_handler.h | 4 +- include/core/op_type.h | 3 +- include/core/runtime.h | 3 +- include/operators/instance_norm.h | 24 ++ include/operators/resize.h | 6 +- pyinfinitensor/src/pyinfinitensor/onnx.py | 99 +++++-- scripts/format.py | 2 +- src/ascend/ascend_runtime.cc | 73 +++++ src/core/common.cc | 13 + src/core/graph_handler.cc | 22 +- src/ffi/ffi_infinitensor.cc | 28 +- src/kernels/ascend/all_gather.cc | 50 ++++ src/kernels/ascend/all_reduce.cc | 57 ++++ src/kernels/ascend/batch_norm.cc | 99 +++++++ src/kernels/ascend/concat.cc | 74 +++++ src/kernels/ascend/conv.cc | 92 +++++++ src/kernels/ascend/conv_transpose.cc | 94 +++++++ src/kernels/ascend/element_wise.cc | 219 +++++++++++++++ src/kernels/ascend/gather.cc | 78 ++++++ src/kernels/ascend/instance_norm.cc | 104 +++++++ src/kernels/ascend/layer_norm.cc | 103 +++++++ src/kernels/ascend/matmul.cc | 118 ++++++++ src/kernels/ascend/pad2d.cc | 83 ++++++ src/kernels/ascend/pooling.cc | 154 +++++++++++ src/kernels/ascend/recv.cc | 50 ++++ src/kernels/ascend/reduce.cc | 131 +++++++++ src/kernels/ascend/reshape.cc | 62 +++++ src/kernels/ascend/resize.cc | 82 ++++++ src/kernels/ascend/send.cc | 45 +++ src/kernels/ascend/slice.cc | 84 ++++++ 
src/kernels/ascend/softmax.cc | 61 +++++ src/kernels/ascend/split.cc | 75 +++++ src/kernels/ascend/transpose.cc | 136 +++++++++ src/kernels/ascend/unary.cc | 258 ++++++++++++++++++ src/operators/instance_norm.cc | 48 ++++ src/operators/resize.cc | 10 +- src/operators/unary.cc | 3 +- src/utils/operator_utils.cc | 2 + test/kernels/ascend/test_ascend_all_gather.cc | 55 ++++ test/kernels/ascend/test_ascend_all_reduce.cc | 120 ++++++++ test/kernels/ascend/test_ascend_batch_norm.cc | 58 ++++ test/kernels/ascend/test_ascend_concat.cc | 65 +++++ test/kernels/ascend/test_ascend_conv.cc | 60 ++++ .../ascend/test_ascend_conv_transpose_2d.cc | 58 ++++ .../ascend/test_ascend_element_wise.cc | 69 +++++ test/kernels/ascend/test_ascend_erf.cc | 42 +++ test/kernels/ascend/test_ascend_gather.cc | 71 +++++ .../ascend/test_ascend_instancenorm.cc | 74 +++++ test/kernels/ascend/test_ascend_layernorm.cc | 143 ++++++++++ test/kernels/ascend/test_ascend_matmul.cc | 59 ++++ test/kernels/ascend/test_ascend_pad.cc | 49 ++++ test/kernels/ascend/test_ascend_pooling.cc | 48 ++++ test/kernels/ascend/test_ascend_reduce.cc | 84 ++++++ test/kernels/ascend/test_ascend_reshape.cc | 84 ++++++ test/kernels/ascend/test_ascend_resize.cc | 69 +++++ test/kernels/ascend/test_ascend_sendrecv.cc | 75 +++++ test/kernels/ascend/test_ascend_slice.cc | 41 +++ test/kernels/ascend/test_ascend_softmax.cc | 61 +++++ test/kernels/ascend/test_ascend_split.cc | 50 ++++ test/kernels/ascend/test_ascend_transpose.cc | 49 ++++ test/kernels/ascend/test_ascend_unary.cc | 151 ++++++++++ test/kernels/cuda/test_cuda_resize.cc | 66 ++--- test/operators/test_resize.cc | 16 +- 78 files changed, 4803 insertions(+), 85 deletions(-) create mode 100644 examples/distributed/ascend_launch.py create mode 100644 include/ascend/ascend_common.h create mode 100644 include/ascend/ascend_kernel_without_config.h create mode 100644 include/ascend/ascend_runtime.h create mode 100644 include/ascend/hccl_communicator.h create mode 100644 include/operators/instance_norm.h create mode 100644 src/ascend/ascend_runtime.cc create mode 100644 src/kernels/ascend/all_gather.cc create mode 100644 src/kernels/ascend/all_reduce.cc create mode 100644 src/kernels/ascend/batch_norm.cc create mode 100644 src/kernels/ascend/concat.cc create mode 100644 src/kernels/ascend/conv.cc create mode 100644 src/kernels/ascend/conv_transpose.cc create mode 100644 src/kernels/ascend/element_wise.cc create mode 100644 src/kernels/ascend/gather.cc create mode 100644 src/kernels/ascend/instance_norm.cc create mode 100644 src/kernels/ascend/layer_norm.cc create mode 100644 src/kernels/ascend/matmul.cc create mode 100644 src/kernels/ascend/pad2d.cc create mode 100644 src/kernels/ascend/pooling.cc create mode 100644 src/kernels/ascend/recv.cc create mode 100644 src/kernels/ascend/reduce.cc create mode 100644 src/kernels/ascend/reshape.cc create mode 100644 src/kernels/ascend/resize.cc create mode 100644 src/kernels/ascend/send.cc create mode 100644 src/kernels/ascend/slice.cc create mode 100644 src/kernels/ascend/softmax.cc create mode 100644 src/kernels/ascend/split.cc create mode 100644 src/kernels/ascend/transpose.cc create mode 100644 src/kernels/ascend/unary.cc create mode 100644 src/operators/instance_norm.cc create mode 100644 test/kernels/ascend/test_ascend_all_gather.cc create mode 100644 test/kernels/ascend/test_ascend_all_reduce.cc create mode 100644 test/kernels/ascend/test_ascend_batch_norm.cc create mode 100644 test/kernels/ascend/test_ascend_concat.cc create mode 100644 
test/kernels/ascend/test_ascend_conv.cc create mode 100644 test/kernels/ascend/test_ascend_conv_transpose_2d.cc create mode 100644 test/kernels/ascend/test_ascend_element_wise.cc create mode 100644 test/kernels/ascend/test_ascend_erf.cc create mode 100644 test/kernels/ascend/test_ascend_gather.cc create mode 100644 test/kernels/ascend/test_ascend_instancenorm.cc create mode 100644 test/kernels/ascend/test_ascend_layernorm.cc create mode 100644 test/kernels/ascend/test_ascend_matmul.cc create mode 100644 test/kernels/ascend/test_ascend_pad.cc create mode 100644 test/kernels/ascend/test_ascend_pooling.cc create mode 100644 test/kernels/ascend/test_ascend_reduce.cc create mode 100644 test/kernels/ascend/test_ascend_reshape.cc create mode 100644 test/kernels/ascend/test_ascend_resize.cc create mode 100644 test/kernels/ascend/test_ascend_sendrecv.cc create mode 100644 test/kernels/ascend/test_ascend_slice.cc create mode 100644 test/kernels/ascend/test_ascend_softmax.cc create mode 100644 test/kernels/ascend/test_ascend_split.cc create mode 100644 test/kernels/ascend/test_ascend_transpose.cc create mode 100644 test/kernels/ascend/test_ascend_unary.cc diff --git a/.gitignore b/.gitignore index 98e980ad..47ac8c07 100644 --- a/.gitignore +++ b/.gitignore @@ -44,3 +44,5 @@ build_debug/ *.onnx *.pb *.npy + +*.swp diff --git a/CMakeLists.txt b/CMakeLists.txt index 1dbb007b..fe27aeb6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,6 +2,7 @@ option(USE_CUDA "Support CUDA GPU" OFF) option(USE_BANG "Support BANG MLU" OFF) option(USE_KUNLUN "Support KUNLUN XPU" OFF) +option(USE_ASCEND "Support HUAWEI ASCEND" OFF) option(USE_INTELCPU "Support INTELCPU" OFF) option(USE_BACKTRACE "Print backtrace on exception and segmentation fault" ON) option(USE_PROTOBUF "Serialize and deserialize tensors" OFF) @@ -151,6 +152,11 @@ if(USE_KUNLUN) list (APPEND SRC ${SRC_KUNLUN}) endif() +if(USE_ASCEND) + file(GLOB_RECURSE SRC_ASCEND src/ascend/*.cc src/kernels/ascend/*.cc ) + list (APPEND SRC ${SRC_ASCEND}) +endif() + if(USE_INTELCPU) file(GLOB_RECURSE SRC_INTELCPU src/intelcpu/*.cc src/kernels/intelcpu/*.cc ) list (APPEND SRC ${SRC_INTELCPU}) @@ -297,6 +303,7 @@ if(USE_KUNLUN) else() set(TARGET_CPU_ARCH $ENV{TARGET_CPU_ARCH} CACHE STRING "Target CPU ARCH") endif() + message(STATUS "TARGET_CPU_ARCH: ${TARGET_CPU_ARCH}") if (BUILD_DIST) @@ -309,6 +316,42 @@ if(USE_KUNLUN) target_link_libraries(InfiniTensor ${KUNLUN_RT} ${KUNLUN_DNN} stdc++) endif() +if(USE_ASCEND) + add_compile_definitions(USE_ASCEND=1) + if ((NOT DEFINED ASCEND_HOME) AND (NOT DEFINED ENV{ASCEND_HOME})) + message(FATAL_ERROR "ASCEND_HOME is not defined from cmake or env") + elseif (DEFINED ASCEND_HOME) + set(ASCEND_HOME ${ASCEND_HOME} CACHE STRING "ASCEND_HOME directory for Ascend development") + else() + set(ASCEND_HOME $ENV{ASCEND_HOME} CACHE STRING "ASCEND_HOME directory for Ascend development") + endif() + message(STATUS "ASCEND_HOME: ${ASCEND_HOME}") + + include_directories("${ASCEND_HOME}/include/") + include_directories("${ASCEND_HOME}/include/aclnn") + find_library(ASCEND_CL libascendcl.so "${ASCEND_HOME}/lib64") + find_library(ASCEND_BASE libnnopbase.so "${ASCEND_HOME}/lib64") + find_library(ASCEND_DNN libopapi.so "${ASCEND_HOME}/lib64") + find_library(ASCEND_HCCL libhccl.so "${ASCEND_HOME}/lib64") + find_library(ASCEND_HAL libascend_hal.so "${ASCEND_HOME}/../../driver/lib64/driver") + # find_library(ASCEND_RT libruntime.so "${ASCEND_HOME}/lib64") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror") + if ((NOT DEFINED 
TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH})) + execute_process(COMMAND uname -m OUTPUT_VARIABLE _uname_m OUTPUT_STRIP_TRAILING_WHITESPACE) + set(TARGET_CPU_ARCH "${_uname_m}" CACHE STRING "Target CPU ARCH") + elseif(DEFINED TARGET_CPU_ARCH) + set(TARGET_CPU_ARCH ${TARGET_CPU_ARCH} CACHE STRING "Target CPU ARCH") + else() + set(TARGET_CPU_ARCH $ENV{TARGET_CPU_ARCH} CACHE STRING "Target CPU ARCH") + endif() + message(STATUS "TARGET_CPU_ARCH: ${TARGET_CPU_ARCH}") + target_link_libraries(InfiniTensor ${ASCEND_HAL} ${ASCEND_CL} ${ASCEND_BASE} ${ASCEND_DNN} ${ASCEND_HCCL} stdc++) + if (BUILD_DIST) + message(STATUS "Add BUILD_DIST, use HCCL with ASCEND") + add_compile_definitions(INFINI_USE_HCCL=1) + endif() +endif() + # # Python bindings # pybind11_add_module(infini MODULE ${FFI}) # target_link_libraries(infini PRIVATE infini_cpp) @@ -346,6 +389,9 @@ if(BUILD_TEST) build_test(test/kernels/kunlun/*.cc) build_test(test/kunlun/*.cc) endif() + if (USE_ASCEND) + build_test(test/kernels/ascend/*.cc) + endif() if (USE_INTELCPU) build_test(test/kernels/intelcpu/*.cc) endif() diff --git a/Makefile b/Makefile index c0e0c8b7..f7aa82e5 100644 --- a/Makefile +++ b/Makefile @@ -4,6 +4,7 @@ TYPE ?= Release CUDA ?= OFF BANG ?= OFF KUNLUN ?= OFF +ASCEND ?= OFF INTELCPU ?= off BACKTRACE ?= ON TEST ?= ON @@ -29,6 +30,7 @@ CMAKE_OPT = -DCMAKE_BUILD_TYPE=$(TYPE) CMAKE_OPT += -DUSE_CUDA=$(CUDA) CMAKE_OPT += -DUSE_BANG=$(BANG) CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN) +CMAKE_OPT += -DUSE_ASCEND=$(ASCEND) CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE) CMAKE_OPT += -DBUILD_TEST=$(TEST) CMAKE_OPT += -DBUILD_DIST=$(DIST) diff --git a/README.md b/README.md index 5f1e8bd0..e6aba877 100644 --- a/README.md +++ b/README.md @@ -72,4 +72,4 @@ Please cite EinNet or PET in your publications if it helps your research: pages={37--54}, year={2021} } -``` +``` \ No newline at end of file diff --git a/docs/INSTALL_GUIDE_CN.md b/docs/INSTALL_GUIDE_CN.md index bac2534e..1869f748 100644 --- a/docs/INSTALL_GUIDE_CN.md +++ b/docs/INSTALL_GUIDE_CN.md @@ -14,6 +14,7 @@ | -------- | ------------ | ----------- | ---------- | | X86-64 | Nvidia GPU | Ubuntu-22.04 | Yes | | X86-64 | Cambricon MLU | Ubuntu-22.04 | Yes | +| arm64 | Ascend NPU |OpenEuler-22.03| Yes | 推荐使用 X86-64 机器以及 Ubuntu-22.04,本文以此环境为例。 @@ -68,6 +69,20 @@ 我们强烈建议您规范安装,统一到一个目录下,以免不必要的麻烦。另外请注意,由于 MLU 上层软件建设适配程度有限,如您在其覆盖的机器,操作系统之外运行,需要在安装驱动之后使用上层软件的 Docker。 + - 如您的第三方加速卡为昇腾 NPU,请参考昇腾官方文档进行: + > [驱动及CANN安装](https://www.hiascend.com/document/detail/zh/canncommercial/80RC1/quickstart/quickstart/quickstart_18_0006.html) + > 安装完成后请进行相应的环境变量配置,将可执行文件目录与库目录添加到操作系统识别的路径中,例如 + > + > ```bash + > # 将如下内容写入到你的 bashrc 文件并 source 该文件 + > export ASCEND_HOME=/usr/local/Ascend/ascend-toolkit/latest + > source /usr/local/Ascend/ascend-toolkit/set_env.sh + > # 如您不方便将上述环境变量配置到 bashrc 文件中进行长期使用,你也可以在我们提供的 env.sh 文件中进行正确配置并激活,作为临时使用 + > source env.sh + > ``` + + 我们强烈建议您规范安装,统一到一个目录下,以免不必要的麻烦。 + 4. 确认您安装了 make,build-essential, python-is-python3, python-dev-is-python3, python3-pip, libdw-dev,如您的机器没有上述基础依赖,请自行按需安装。 - 在使用 apt-get 工具情况下,您可以这样执行 @@ -140,6 +155,13 @@ make install-python KUNLUN=ON ``` + 编译 CPU 部分,同时编译昇腾 NPU 部分: + + ```bash + export ASCEND_HOME=/path/to/your/ascend_home + make install-python ASCEND=ON + ``` + 3. 
使用方法 安装成功后,您就可以使用本项目的 Python 接口进行编码并运行。具体使用方式可以参考项目样例代码 example/Resnet/resnet.py 以及用户使用手册 diff --git a/docs/SUPPORT_MATRIX_CN.md b/docs/SUPPORT_MATRIX_CN.md index 2720f2ed..7c3dc759 100644 --- a/docs/SUPPORT_MATRIX_CN.md +++ b/docs/SUPPORT_MATRIX_CN.md @@ -15,6 +15,7 @@ | -------- | ------------ | ----------- | ---------- | | X86-64 | Nvidia GPU | Ubuntu-22.04 | Yes | | X86-64 | Cambricon MLU | Ubuntu-22.04 | Yes | +| arm64 | Ascend NPU |OpenEuler-22.03| Yes | ## 神经网络支持 diff --git a/docs/USER_GUIDE_CN.md b/docs/USER_GUIDE_CN.md index 974579b9..f0bd8ce2 100644 --- a/docs/USER_GUIDE_CN.md +++ b/docs/USER_GUIDE_CN.md @@ -28,6 +28,7 @@ - `CUDA`:是否编译 CUDA 后端,默认为 `OFF`,`ON` 打开 - `BANG`:是否编译寒武纪后端,默认为 `OFF`,`ON` 打开 - `KUNLUN`:是否编译昆仑后端,默认为 `OFF`,`ON` 打开 +- `ASCEND`:是否编译华为后端,默认为 `OFF`,`ON` 打开 - `BACKTRACE`:是否启用栈回溯,默认为 `ON`,`OFF` 关闭,建议调试时打开 - `TEST`:是否编译 `googletest`,默认为 `ON`,`OFF` 关闭,只有 `test-cpp` 时必要 diff --git a/env.sh b/env.sh index 6971436f..4840b0fd 100644 --- a/env.sh +++ b/env.sh @@ -36,3 +36,33 @@ export LD_LIBRARY_PATH="${NEUWARE_HOME}/lib64:${LD_LIBRARY_PATH}" # ├── version # └── XTDK export KUNLUN_HOME=/usr/local/xpu + +# 配置华为ASCEND NPU 的 HOME 路径,请注意 /usr/local/Ascend/ascend-toolkit/latest 是华为ASCEND 软件栈提供的软件包路径。 +# 如若用户有其他的路径安装方式,请自行配置正确的路径。 +# 这里是 ascend 目录下一个可能的结构图,请参考。 +# . +# ├── aarch64-linux +# ├── acllib +# ├── arm64-linux +# ├── atc +# ├── bin +# ├── compiler +# ├── conf +# ├── fwkacllib +# ├── hccl +# ├── include +# ├── lib64 +# ├── mindstudio-toolkit +# ├── opp +# ├── opp_kernel +# ├── ops +# ├── pyACL +# ├── python +# ├── runtime +# ├── test-ops +# ├── toolkit +# └── tools + +export ASCEND_HOME=/usr/local/Ascend/ascend-toolkit/latest +source /usr/local/Ascend/ascend-toolkit/set_env.sh +source /usr/local/Ascend/toolbox/set_env.sh diff --git a/examples/NNmodel b/examples/NNmodel index 51d31052..b896cec2 160000 --- a/examples/NNmodel +++ b/examples/NNmodel @@ -1 +1 @@ -Subproject commit 51d3105277f3774ed31c02ed4cd11fa92925af77 +Subproject commit b896cec2dba5b8522b141ac4f89eb43074ee1b98 diff --git a/examples/distributed/ascend_launch.py b/examples/distributed/ascend_launch.py new file mode 100644 index 00000000..ccaee5e1 --- /dev/null +++ b/examples/distributed/ascend_launch.py @@ -0,0 +1,198 @@ +import argparse +import os +import time +import multiprocessing as mp +from pyinfinitensor.onnx import OnnxStub, backend +import onnx +from onnx.shape_inference import infer_shapes_path +import numpy as np +from parallel_opt import parallel_model +import acl + + +def parse_args(): + parser = argparse.ArgumentParser(description="launch distributed infinitensor") + parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes") + parser.add_argument( + "--nproc_per_node", type=int, default=2, help="number of processes per node" + ) + parser.add_argument( + "--name", type=str, default="test", help="name of this instance." + ) + parser.add_argument( + "--model", type=str, default="/data/onnx_models/llama2/llama_bs1_seq1024.onnx", + help="path to the ONNX model file." 
+ ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size.") + parser.add_argument("--length", type=int, default=1, help="sequence length.") + parser.add_argument( + "--gen_std", + default=False, + action="store_true", + help="whether to generate the standard results.", + ) + args = parser.parse_args() + print("arg setting: ", args) + return ( + args.num_nodes, + args.nproc_per_node, + args.name, + args.model, + args.batch_size, + args.length, + args.gen_std, + ) + + +def run_model(model, runtime, world_size=1, rank=0, n=10): + stub = OnnxStub(model, runtime) + load_inputs(stub, world_size, rank) + # stub.tune() + stub.run() + # get outputs + time.sleep(0.01) + outputs = next(stub.outputs.values().__iter__()).copyout_numpy() + + # bench + begin = time.time() + for _ in range(n): + stub.run() + end = time.time() + avg_time = (end - begin) / n + print(f"average time: {avg_time}") + return outputs + + +def run_and_compare(name, model, runtime, world_size=1, rank = 0): + results = np.load(f"./data/output.npy") + outputs = run_model(model, runtime, world_size, rank) + print("answer argmax:", np.argmax(results)) + print("output argmax:", np.argmax(outputs)) + #np.testing.assert_allclose(outputs, results, rtol=1e-3, atol=1e-3) + getDiff(results, outputs) + + +def start_worker( + name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto +): + dist_name = name + "_dist" + model = parallel_model(model, world_size, rank) + extern_path = f"./{dist_name}_rank{rank}.pb" + if os.path.exists(extern_path): + os.remove(extern_path) + onnx.save_model( + model, + f"./{dist_name}_rank{rank}.onnx", + save_as_external_data=True, + location=extern_path, + ) + infer_shapes_path(f"./{dist_name}_rank{rank}.onnx") + runtime = backend.ASCENDRuntime(local_rank) + # print("init comm") + runtime.init_comm( + dist_name, + world_size, + rank, + ) + run_and_compare(name, model, runtime, world_size, rank) + + +def start_single(name, model): + runtime = backend.ASCENDRuntime(0) + run_and_compare(name, model, runtime) + + +def generate_input_output(model): + os.makedirs(os.path.dirname("./data/"), exist_ok=True) + runtime = backend.ASCENDRuntime(0) + stub = OnnxStub(model, runtime) + position_id = 0 + for i, (name, tensor) in enumerate(stub.inputs.items()): + input = tensor.copyout_numpy() + if np.issubdtype(input.dtype, np.integer): + if input.size == 1: + # input = np.array([position_id]) + input = np.random.randint(0,2,size=input.shape, dtype=input.dtype) + else: + input = np.random.randint(0,2,size=input.shape, dtype=input.dtype) + elif input.dtype == np.bool_: + input = np.random.randint(0,2,size=input.shape) > 0 + else: + if i == 0: + input = np.ones(input.shape).astype(input.dtype) + position_id = input.shape[-1] - 1 + else: + input = np.random.rand(*input.shape).astype(input.dtype) + tensor.copyin_numpy(input) + np.save(f"./data/input_{i}", input) + stub.run() + time.sleep(0.01) + output = next(stub.outputs.values().__iter__()).copyout_numpy() + if np.isnan(output).any(): + print("Nan in output") + np.save(f"./data/output", output) + + +def load_inputs(stub, world_size=1, rank=0): + for i, (name, tensor) in enumerate(stub.inputs.items()): + input = np.load(f"./data/input_{i}.npy") + if all(x == y for x,y in zip(input.shape,tensor.shape())): + tensor.copyin_numpy(input) + else: + tensor.copyin_numpy(np.hsplit(input, world_size)[rank]) + +def getDiff(base, test): + absolute_diff = np.abs(np.subtract(base, test)) + max_absolute_diff = np.max(absolute_diff) + + baseCopy = 
base.astype(np.float64).ravel() + testCopy = test.astype(np.float64).ravel() + upValue = np.sum(np.abs(baseCopy - testCopy)) + downValue = np.sum(np.abs(baseCopy)) + np.float64(1e-9) + max_relative_diff = upValue / downValue + print(f"Max absolute difference: {max_absolute_diff}\n" + f"Max relative difference: {max_relative_diff}") + return max_absolute_diff, max_relative_diff + + +def main(): + acl.init() + nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args() + + model = onnx.load(model_path) + + # generate standart output + if gen_std: + print("Generate inputs and outputs.") + p = mp.Process(target=generate_input_output, args=[model]) + p.start() + p.join() + return + + # run single process. + # use standalone process to isolate cuda. + print("run model by single MLU.") + p = mp.Process(target=start_single, args=(name, model)) + p.start() + p.join() + + # run distributed parallel. + world_size = nnodes * nproc_per_node + print(f"run model by {world_size} MLUs in parallel.") + workers = [ + mp.Process( + target=start_worker, + args=(name, world_size, rank, rank % nproc_per_node, model), + ) + for rank in range(world_size) + ] + + for w in workers: + w.start() + + for w in workers: + w.join() + + +if __name__ == "__main__": + main() diff --git a/include/ascend/ascend_common.h b/include/ascend/ascend_common.h new file mode 100644 index 00000000..246cb09b --- /dev/null +++ b/include/ascend/ascend_common.h @@ -0,0 +1,37 @@ +#pragma once +#include "acl/acl.h" +#include "acl/acl_op.h" +#include "core/common.h" + +#define checkASCENDError(call) \ + { \ + auto err = call; \ + if (ACL_SUCCESS != err) { \ + fprintf(stderr, "ASCEND error in %s:%i : .\n", __FILE__, \ + __LINE__); \ + exit(EXIT_FAILURE); \ + } \ + } + +#define checkHCCLError(call) \ + { \ + auto err = call; \ + if (HCCL_SUCCESS != err) { \ + fprintf(stderr, "HCCL error in %s:%i : .\n", __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ + } \ + } + +#define GetRecentErrMsg() \ + { \ + auto tmp_err_msg = aclGetRecentErrMsg(); \ + if (tmp_err_msg != NULL) { \ + printf(" ERROR Message : %s \n ", tmp_err_msg); \ + } \ + } + +namespace infini { + +using ASCENDPtr = void *; + +} // namespace infini diff --git a/include/ascend/ascend_kernel_without_config.h b/include/ascend/ascend_kernel_without_config.h new file mode 100644 index 00000000..28b1566f --- /dev/null +++ b/include/ascend/ascend_kernel_without_config.h @@ -0,0 +1,23 @@ +#pragma once +#include "ascend/ascend_runtime.h" +#include "core/kernel.h" + +namespace infini { + +class ASCENDKernelWithoutConfig : public Kernel { + public: + virtual void compute(const Operator &op, const PerfRecord &record, + const RuntimeObj *context) const { + compute(op, context); + } + virtual void compute(const Operator &op, + const RuntimeObj *context) const = 0; + // Premise: op is idempotent since it is called multiple times. 
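+    // Descriptive note (not in the original patch): the default tune() below
+    // simply times compute() with timeit(), using the runtime's sync() to wait
+    // for the Ascend stream between runs, and wraps the measured latency in a
+    // PerfRecord.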
+ virtual PerfRecord tune(const Operator &op, + const RuntimeObj *_context) const { + auto context = dynamic_cast(_context); + return make_ref(timeit([&]() { compute(op, _context); }, + [&]() { context->sync(); })); + } +}; +} // namespace infini diff --git a/include/ascend/ascend_runtime.h b/include/ascend/ascend_runtime.h new file mode 100644 index 00000000..cff861bf --- /dev/null +++ b/include/ascend/ascend_runtime.h @@ -0,0 +1,82 @@ +#pragma once +#include "ascend/ascend_common.h" +#include "core/runtime.h" + +namespace infini { + +class ASCENDRuntimeObj : public RuntimeObj { + private: + aclrtContext context; + aclrtStream stream; + std::unique_ptr comm; + ASCENDPtr workspace = nullptr; + uint64_t workspaceSize; + + public: + ASCENDRuntimeObj(int deviceId = 0) : RuntimeObj(Device::ASCEND, deviceId) { + // auto ret = aclInit(nullptr); + // CHECK_RET(ret == ACL_SUCCESS, + // LOG_PRINT("aclInit failed. ERROR: %d\n", ret)); + auto ret = aclrtSetDevice(deviceId); + checkASCENDError(ret); + ret = aclrtCreateStream(&stream); + checkASCENDError(ret); + + // 10GB for Longformer + // size_t longformerNum = 3lu * (1 << 30); + workspaceSize = 3ll * (1 << 30); // 3 GB + + workspace = alloc(workspaceSize); + } + virtual ~ASCENDRuntimeObj() { + dealloc(workspace); + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + // aclFinalize(); + } + string toString() const override; + + void run(const Graph &graph, bool tune = false, + bool profiling = false) const; + + void sync() const; + ASCENDPtr alloc(size_t size) override { + void *ptr; + checkASCENDError( + aclrtMalloc((void **)&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST)); + return ptr; + } + void dealloc(void *ptr) override { checkASCENDError(aclrtFree(ptr)); } + aclrtStream ASCENDHandle() const { return stream; } + ASCENDPtr getWorkspace(uint64_t size) const { + IT_ASSERT(size <= workspaceSize); + return workspace; + } + + void copyBlobFromCPU(void *dst, const void *src, + size_t bytes) const override { + aclrtMemcpy(dst, bytes, const_cast(src), bytes, + ACL_MEMCPY_HOST_TO_DEVICE); + } + + void copyBlobToCPU(void *dst, const void *src, + size_t bytes) const override { + aclrtMemcpy(dst, bytes, const_cast(src), bytes, + ACL_MEMCPY_DEVICE_TO_HOST); + } + + void copyBlobInsideRuntime(void *dst, const void *src, + size_t bytes) const override { + aclrtMemcpy(dst, bytes, const_cast(src), bytes, + ACL_MEMCPY_DEVICE_TO_DEVICE); + } + + void initComm(const string &name, int worldSize, int rank) final; + + CommunicatorObj &getCommunicator() const override { return *comm; } + + private: + void runWithoutSync(const Graph &graph, bool tune, bool profiling) const; +}; + +} // namespace infini diff --git a/include/ascend/hccl_communicator.h b/include/ascend/hccl_communicator.h new file mode 100644 index 00000000..b39dfd1c --- /dev/null +++ b/include/ascend/hccl_communicator.h @@ -0,0 +1,72 @@ +#pragma once +#include "ascend/ascend_common.h" +#include "core/communicator.h" +#include "hccl/hccl.h" +#include "hccl/hccl_types.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace infini { + +class HcclCommunicatorObj final : public CommunicatorObj { + private: + HcclComm comm; + + public: + HcclCommunicatorObj(const string &name, int worldSize, int rank) + : CommunicatorObj(worldSize, rank) { + const std::string filePath("./" + name + "_hccl_id.bin"); + int devId = rank; + int devCount = worldSize; + // get rootInfo in rootRank + HcclRootInfo rootInfo; + int32_t rootRank = 0; + + if (devId == rootRank) { + 
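+            // Root rank: obtain the HCCL root info and publish it to a file on
+            // disk; the non-root branch below polls for that file (10 s limit)
+            // and reads the same root info back before HcclCommInitRootInfo.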
checkHCCLError(HcclGetRootInfo(&rootInfo)); + std::ofstream ofs(filePath, std::ios::binary); + ofs.write((char *)&rootInfo, sizeof(HcclRootInfo)); + } else { + auto begin = std::chrono::steady_clock::now(); + while (!std::filesystem::exists(filePath)) { + auto now = std::chrono::steady_clock::now(); + _IT_ASSERT_2(now < begin + std::chrono::seconds(10), + "time limit (10s) exceeded."); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + std::ifstream ifs(filePath, std::ios::binary); + ifs.read((char *)&rootInfo, sizeof(HcclRootInfo)); + } + + auto ret = HcclCommInitRootInfo(uint32_t(devCount), &rootInfo, + uint32_t(devId), &comm); + checkHCCLError(ret); + + if (rank == 0) { + std::filesystem::remove(filePath); + } + } + + // Get the actual ncclComm_t + HcclComm getHcclComm() { return comm; } + + ~HcclCommunicatorObj() final { + auto ret = HcclCommDestroy(comm); + checkHCCLError(ret); + } + + virtual string toString() const final { + std::ostringstream oss; + oss << "HCCL communicator"; + return oss.str(); + } +}; + +} // namespace infini diff --git a/include/core/common.h b/include/core/common.h index 0df16135..c3e25bb5 100644 --- a/include/core/common.h +++ b/include/core/common.h @@ -93,4 +93,6 @@ double timeit( const std::function &sync = []() {}, int warmupRounds = 10, int timingRounds = 10); +std::vector castTo64(std::vector const &v32); + } // namespace infini diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index 7f29cdd6..8dc4e41e 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -38,6 +38,8 @@ class GraphHandlerObj { float momentum, float eps, bool training); Tensor layerNormalization(Tensor input, Tensor scale, Tensor output, Tensor bias, float eps, int axis, int stash_type); + Tensor instanceNormalization(Tensor input, Tensor output, Tensor scale, + Tensor bias, float eps); Tensor rmsNorm(Tensor input, Tensor weight, Tensor output); Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw, @@ -77,7 +79,7 @@ class GraphHandlerObj { Tensor reshape(Tensor data, Tensor reshaped, Shape shape); Tensor resize(Tensor input, Tensor output, const std::optional> &axes, Tensor sizes, - Tensor scales, Tensor roi, vector sizes_, + Tensor scales, Tensor roi, vector sizes_, vector scales_, vector roi_, string mode, string ratioPolicy, string nearestMode, string coordTransMode); diff --git a/include/core/op_type.h b/include/core/op_type.h index 78ca44b8..59c02783 100644 --- a/include/core/op_type.h +++ b/include/core/op_type.h @@ -1,4 +1,4 @@ -#pragma once +#pragma once #ifndef OP_TYPE_H #define OP_TYPE_H @@ -261,6 +261,7 @@ struct OpType { enum class ActType { None, Relu, + LeakyRelu, Sigmoid, Tanh, }; diff --git a/include/core/runtime.h b/include/core/runtime.h index c5544276..962aa82e 100644 --- a/include/core/runtime.h +++ b/include/core/runtime.h @@ -32,7 +32,7 @@ using OpLists = list; using VType = uint32_t; -enum class Device { CPU = 1, CUDA, BANG, INTELCPU, KUNLUN }; +enum class Device { CPU = 1, CUDA, BANG, INTELCPU, KUNLUN, ASCEND }; /***************** Forward declaration end *****************/ class RuntimeObj : public std::enable_shared_from_this { @@ -75,6 +75,7 @@ class RuntimeObj : public std::enable_shared_from_this { bool isCuda() const { return device == Device::CUDA; } bool isBang() const { return device == Device::BANG; } bool isKUNLUN() const { return device == Device::KUNLUN; } + bool isAscend() const { return device == Device::ASCEND; } void copyBlob(const TensorObj *dst, const 
TensorObj *src) const; // TODO: unify these copy APIs virtual void copyBlobFromCPU(void *dst, const void *src, diff --git a/include/operators/instance_norm.h b/include/operators/instance_norm.h new file mode 100644 index 00000000..a6d7f2e7 --- /dev/null +++ b/include/operators/instance_norm.h @@ -0,0 +1,24 @@ +#pragma once +#include "core/operator.h" + +namespace infini { +class InstanceNormObj : public OperatorObj { + public: + InstanceNormObj(GraphObj *graph, Tensor input, Tensor output, Tensor scale, + Tensor bias, float eps = 1e-5); + OP_CLONE(InstanceNormObj); + optional> inferShape(const TensorVec &inputs) override; + std::string toString() const override; + + int numInputs() const override { return inputs.size(); } + int numOutputs() const override { return outputs.size(); } + float getEps() const { return eps; } + + private: + vector getWorkloadVector() const override; + vector getOpAttrVector() const override; + + vector inferDataType(const TensorVec &inputs) const override; + float eps; +}; +} // namespace infini diff --git a/include/operators/resize.h b/include/operators/resize.h index 220ef719..b8375a8d 100644 --- a/include/operators/resize.h +++ b/include/operators/resize.h @@ -85,6 +85,7 @@ class ResizeObj : public OperatorObj { vector axes; vector scales; vector roi; + bool isGivenSizes = false; ECoordinateTransMode coMode; // compute src coordinate from dst coordinate ECoeffMode mode; // coeff mode,for computing dst value from coordinate src @@ -139,9 +140,8 @@ class ResizeObj : public OperatorObj { } else return 0; } - bool isResizeBySizes() const { - return ratioPolicy != EKeepAspectRatioPolicy::none; - } + void setGivenSizes(bool val) { isGivenSizes = val; } + bool isResizeBySizes() const { return isGivenSizes; } private: vector getWorkloadVector() const override; diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index f2adb71b..558eef3d 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -191,7 +191,7 @@ class OnnxStub: node, { "dilations": [1, 1], - "pads": [0, 0], + "pads": [0, 0, 0, 0], "strides": [1, 1], "output_padding": [0, 0], }, @@ -200,19 +200,63 @@ class OnnxStub: attributes[name] for name in ["dilations", "pads", "strides", "output_padding"] ) - tensors[node.output[0]] = self.handler.convTransposed2d( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - p[0], - p[1], - s[0], - s[1], - d[0], - d[1], - op[0], - op[1], - ) + if p[0] != p[2] or p[1] != p[3]: + adapt = "{}-adapt".format(node.output[0]) + tensors[adapt] = self.handler.pad( + tensors[node.input[0]], None, p, [-2, -1] + ) + p = [0, 0, 0, 0] + else: + adapt = node.input[0] + + if len(node.input) > 2: + bias = "{}-bias".format(node.output[0]) + reshape = "{}-reshape".format(node.output[0]) + tensors[bias] = self.handler.convTransposed2d( + tensors[adapt], + tensors[node.input[1]], + None, + p[0], + p[1], + s[0], + s[1], + d[0], + d[1], + op[0], + op[1], + ) + tensors[reshape] = self.handler.reshape( + tensors[node.input[2]], + None, + [ + 1, + reduce( + lambda acc, x: acc * x, + tensors[node.input[2]].shape(), + ), + 1, + 1, + ], + ) + tensors[node.output[0]] = self.handler.add( + tensors[bias], + tensors[reshape], + tensors.get(node.output[0]), + ) + else: + tensors[node.output[0]] = self.handler.convTransposed2d( + tensors[adapt], + tensors[node.input[1]], + tensors.get(node.output[0]), + p[0], + p[1], + s[0], + s[1], + d[0], + d[1], + op[0], + op[1], + ) elif 
node.op_type == "MatMul": tensors[node.output[0]] = self.handler.matmul( tensors[node.input[0]], # input @@ -286,6 +330,21 @@ class OnnxStub: axis, stash_type, ) + elif node.op_type == "InstanceNormalization": + (input, scale, bias) = (tensors[node.input[i]] for i in [0, 1, 2]) + + output = tensors.get(node.output[0]) + + tensors[node.output[0]] = self.handler.instanceNormalization( + input, + output, + scale, + bias, + next( + (attr.f for attr in node.attribute if attr.name == "epsilon"), + 1e-5, + ), + ) elif node.op_type == "RMSNorm": tensors[node.output[0]] = self.handler.RMSNorm( tensors[node.input[0]], @@ -557,16 +616,6 @@ class OnnxStub: tensors[node.input[1]], tensors.get(node.output[0]), ) - elif node.op_type == "LeakyRelu": - tensors[node.output[0]] = self.handler.leakyRelu( - tensors[node.input[0]], - tensors.get(node.output[0]), - next( - (attr.f for attr in node.attribute if attr.name == "alpha"), - 0.01, - ), - ) - elif node.op_type == "Clip": tensors[node.output[0]] = self.handler.clip( tensors[node.input[0]], @@ -625,7 +674,7 @@ class OnnxStub: "cubic_coeff_a": -0.75, "exclude_outside": 0, "extrapolation_value": 0.0, - "keep_aspect_ratio_policy": "none", + "keep_aspect_ratio_policy": "stretch", "mode": "nearest", "nearest_mode": "none", }, diff --git a/scripts/format.py b/scripts/format.py index e1c5665d..d4593776 100644 --- a/scripts/format.py +++ b/scripts/format.py @@ -11,7 +11,7 @@ proj_path = Path(sys.path[0]).parent def format_file(file): file = Path(proj_path.joinpath(file)) if file.suffix in c_style_file: - run(f"clang-format-14 -style=file -i {file}", cwd=proj_path, shell=True) + run(f"clang-format -style=file -i {file}", cwd=proj_path, shell=True) run(f"git add {file}", cwd=proj_path, shell=True) elif file.suffix == py_file: run(f"black {file}", cwd=proj_path, shell=True) diff --git a/src/ascend/ascend_runtime.cc b/src/ascend/ascend_runtime.cc new file mode 100644 index 00000000..c543997a --- /dev/null +++ b/src/ascend/ascend_runtime.cc @@ -0,0 +1,73 @@ +#include "ascend/ascend_runtime.h" +#include "core/kernel.h" +#include "core/perf_engine.h" +#ifdef INFINI_USE_HCCL +#include "ascend/hccl_communicator.h" +#endif + +namespace infini { + +void ASCENDRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false, + bool profiling = false) const { + const auto &kernelRegistry = KernelRegistry::getInstance(); + auto &perfEngine = PerfEngine::getInstance(); + double totalTime = 0; + std::map opTime; + std::map opCnt; + for (auto &op : graph->getOperators()) { + // HACK: set correct data type + auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()}; + Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); + auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; + auto perfData = perfEngine.getPerfData(perfKey); + if (!perfData && !tune) { + kernel->compute(op, this); + continue; + } + + PerfRecord record; + if (!perfData) { + record = kernel->tune(op, this); + perfEngine.setPerfData(perfKey, record); + } else + record = perfData; + + double t = record->time; + totalTime += t; + + if (profiling) { + double t = timeit([&]() { kernel->compute(op, record, this); }, + [&]() { sync(); }, 1, 1); + op->print(); + printf(" op_time on kunlun xpu %lf\n", t); + totalTime += t; + opTime[op->getOpType()] += t; + opCnt[op->getOpType()]++; + } + } +} + +void ASCENDRuntimeObj::run(const Graph &graph, bool tune, + bool profiling) const { + if (profiling) + IT_TODO_HALT(); + runWithoutSync(graph, tune, profiling); + sync(); +} + +void ASCENDRuntimeObj::sync() 
const { aclrtSynchronizeStream(stream); } + +string ASCENDRuntimeObj::toString() const { return "ASCEND Runtime"; } + +void ASCENDRuntimeObj::initComm(const string &name, int worldSize, int rank) { + IT_ASSERT(worldSize > 0); + IT_ASSERT(rank >= 0); + IT_ASSERT(rank < worldSize); + IT_ASSERT(!comm) << "communicator is already initialized."; +#ifdef INFINI_USE_HCCL + comm = std::make_unique(name, worldSize, rank); +#else + IT_TODO_HALT_MSG("Not compiled with CNCL."); +#endif +} +} // namespace infini diff --git a/src/core/common.cc b/src/core/common.cc index 94e09a10..9c4f80e0 100644 --- a/src/core/common.cc +++ b/src/core/common.cc @@ -21,4 +21,17 @@ double timeit(const std::function &func, timingRounds; } +// transform vector to vector +std::vector castTo64(std::vector const &v32) { + if (v32.size() == 0) { + std::vector v64(1, 1); + return v64; + } + std::vector v64(v32.size(), 1); + for (size_t i = 0; i < v32.size(); ++i) { + v64[i] = int64_t(v32[i]); + } + return v64; +} + } // namespace infini diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index f0d1687f..2fd01f90 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -1,4 +1,4 @@ -#include "core/graph_handler.h" +#include "core/graph_handler.h" #include "operators/all_gather.h" #include "operators/all_reduce.h" #include "operators/attention_kvcache.h" @@ -9,6 +9,7 @@ #include "operators/element_wise.h" #include "operators/expand.h" #include "operators/gather.h" +#include "operators/instance_norm.h" #include "operators/layer_norm.h" #include "operators/lrn.h" #include "operators/matmul.h" @@ -135,6 +136,21 @@ Tensor GraphHandlerObj::layerNormalization(Tensor input, Tensor scale, } } +Tensor GraphHandlerObj::instanceNormalization(Tensor input, Tensor output, + Tensor scale, Tensor bias, + float eps) { + if (output) { + g->addOpWithOutputs( + std::move(input), output, std::move(scale), std::move(bias), eps); + return output; + } else { + return g + ->addOp(std::move(input), output, std::move(scale), + std::move(bias), eps) + ->getOutput(); + } +} + Tensor GraphHandlerObj::rmsNorm(Tensor input, Tensor weight, Tensor output) { if (output) { g->addOpWithOutputs(std::move(input), std::move(weight), @@ -295,13 +311,13 @@ Tensor GraphHandlerObj::reshape(Tensor data, Tensor reshaped, Shape shape) { Tensor GraphHandlerObj::resize(Tensor input, Tensor output, const std::optional> &axes, Tensor sizes, Tensor scales, Tensor roi, - vector sizes_, vector scales_, + vector sizes_, vector scales_, vector roi_, string mode, string ratioPolicy, string nearestMode, string coordTransMode) { if (sizes_.size() > 0) { sizes->dataMalloc(); - sizes->copyin(sizes_); + sizes->copyin(sizes_); } if (scales_.size() > 0) { scales->dataMalloc(); diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index a106074f..e7761f23 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -30,6 +30,9 @@ #ifdef USE_KUNLUN #include "kunlun/kunlun_runtime.h" #endif +#ifdef USE_ASCEND +#include "ascend/ascend_runtime.h" +#endif #ifdef USE_INTELCPU #include "intelcpu/mkl_runtime.h" #include "intelcpu/operator_timer.h" @@ -65,6 +68,7 @@ void export_values(py::module &m) { py::enum_(m, "ActType") .value("Linear", ActType::None) // `None` is Python keyword .VALUE(ActType, Relu) + .VALUE(ActType, LeakyRelu) .VALUE(ActType, Sigmoid) .VALUE(ActType, Tanh) .export_values(); @@ -102,9 +106,9 @@ void export_values(py::module &m) { .VALUE(OpType, BatchNormalization) .VALUE(OpType, Softmax) .VALUE(OpType, Relu) + 
.VALUE(OpType, LeakyRelu) .VALUE(OpType, Gelu) .VALUE(OpType, PRelu) - .VALUE(OpType, LeakyRelu) .VALUE(OpType, Sigmoid) .VALUE(OpType, Tanh) .VALUE(OpType, HardSigmoid) @@ -177,6 +181,12 @@ static Ref kunlun_runtime() { } #endif +#ifdef USE_ASCEND +static Ref ascend_runtime() { + return make_ref(); +} +#endif + #ifdef USE_INTELCPU static Ref intelcpu_runtime() { return make_ref(); } #endif @@ -355,6 +365,10 @@ void export_functions(py::module &m) { #ifdef USE_KUNLUN .FUNCTION(kunlun_runtime) #endif + +#ifdef USE_ASCEND + .FUNCTION(ascend_runtime) +#endif .FUNCTION(conv_attrs_of) .FUNCTION(conv_trans_attrs_of) .FUNCTION(matmul_attrs_of) @@ -442,6 +456,14 @@ void init_graph_builder(py::module &m) { .def(py::init(), py::arg("device") = 0) .def("init_comm", &KUNLUNRuntimeObj::initComm); #endif + +#ifdef USE_ASCEND + py::class_, RuntimeObj>( + m, "ASCENDRuntime") + .def(py::init(), py::arg("device") = 0) + .def("init_comm", &ASCENDRuntimeObj::initComm); + ; +#endif py::class_>(m, "Tensor", py::buffer_protocol()) .def("fuid", &TensorObj::getFuid, policy::automatic) @@ -516,6 +538,8 @@ void init_graph_builder(py::module &m) { .def("matmul", &Handler::matmul, policy::move) .def("batchNormalization", &Handler::batchNormalization, policy::move) .def("layerNormalization", &Handler::layerNormalization, policy::move) + .def("instanceNormalization", &Handler::instanceNormalization, + policy::move) .def("RMSNorm", &Handler::rmsNorm, policy::move) .def("maxPool", &Handler::maxPool, policy::move) .def("avgPool", &Handler::avgPool, policy::move) @@ -528,6 +552,7 @@ void init_graph_builder(py::module &m) { .def("min", &Handler::min, policy::move) .def("max", &Handler::max, policy::move) .def("relu", &Handler::relu, policy::move) + .def("leakyRelu", &Handler::leakyRelu, policy::move) .def("silu", &Handler::silu, policy::move) .def("gelu", &Handler::gelu, policy::move) .def("sigmoid", &Handler::sigmoid, policy::move) @@ -542,7 +567,6 @@ void init_graph_builder(py::module &m) { .def("identity", &Handler::identity, policy::move) .def("flatten", &Handler::flatten, policy::move) .def("pRelu", &Handler::pRelu, policy::move) - .def("leakyRelu", &Handler::leakyRelu, policy::move) .def("clip", &Handler::clip, policy::move) .def("transpose", &Handler::transpose, policy::move) .def("depthToSpace", &Handler::depthToSpace, policy::move) diff --git a/src/kernels/ascend/all_gather.cc b/src/kernels/ascend/all_gather.cc new file mode 100644 index 00000000..845bafbf --- /dev/null +++ b/src/kernels/ascend/all_gather.cc @@ -0,0 +1,50 @@ +#ifdef INFINI_USE_HCCL +#include "operators/all_gather.h" +#include "ascend/ascend_kernel_without_config.h" +#include "ascend/ascend_runtime.h" +#include "ascend/hccl_communicator.h" +#include "hccl/hccl.h" +#include "hccl/hccl_types.h" + +namespace infini { +class AllGatherHCCL : public ASCENDKernelWithoutConfig { + public: + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + int world_size = op->getWorldSize(); + // Check if world size info in operator matches runtime + IT_ASSERT(world_size == context->getCommunicator().getWorldSize()); + + void *input = op->getInputs(0)->getRawDataPtr(); + ASCENDPtr output_temp = + context->getWorkspace(op->getInputs(0)->getBytes() * world_size); + + IT_ASSERT(op->getDType() == DataType::Float32); + size_t bytes = op->getInputs(0)->getBytes(); + size_t count = bytes / op->getDType().getSize(); + + HcclComm comm = + dynamic_cast(context->getCommunicator()) + 
.getHcclComm(); + + checkHCCLError(HcclAllGather(input, output_temp, uint64_t(count), + HCCL_DATA_TYPE_FP32, comm, + context->ASCENDHandle())); + checkASCENDError(aclrtSynchronizeStream(context->ASCENDHandle())); + + for (int i = 0; i < world_size; ++i) { + Tensor output = op->getOutput(i); + context->copyBlobInsideRuntime( + output->getRawDataPtr(), + static_cast(output_temp) + i * count, bytes); + } + } +}; + +REGISTER_KERNEL(Device::ASCEND, OpType::AllGather, AllGatherHCCL, + "AllGather_HCCL_ASCEND"); +} // namespace infini + +#endif diff --git a/src/kernels/ascend/all_reduce.cc b/src/kernels/ascend/all_reduce.cc new file mode 100644 index 00000000..0e9799f6 --- /dev/null +++ b/src/kernels/ascend/all_reduce.cc @@ -0,0 +1,57 @@ +#ifdef INFINI_USE_HCCL +#include "operators/all_reduce.h" +#include "ascend/ascend_kernel_without_config.h" +#include "ascend/ascend_runtime.h" +#include "ascend/hccl_communicator.h" +#include "hccl/hccl.h" +#include "hccl/hccl_types.h" + +namespace infini { +class AllReduceHCCL : public ASCENDKernelWithoutConfig { + public: + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + void *sendBuff = op->getInputs(0)->getRawDataPtr(); + void *recvBuff = op->getOutput()->getRawDataPtr(); + + size_t count = op->getInputs(0)->size(); + + HcclComm comm = + dynamic_cast(context->getCommunicator()) + .getHcclComm(); + + checkHCCLError(HcclAllReduce(sendBuff, recvBuff, count, + HCCL_DATA_TYPE_FP32, getRedOp(), comm, + context->ASCENDHandle())); + checkASCENDError(aclrtSynchronizeStream(context->ASCENDHandle())); + } + + virtual HcclReduceOp getRedOp() const = 0; +}; + +class AllReduceSumHCCL : public AllReduceHCCL { + HcclReduceOp getRedOp() const override { return HCCL_REDUCE_SUM; } +}; +class AllReduceProdHCCL : public AllReduceHCCL { + HcclReduceOp getRedOp() const override { return HCCL_REDUCE_PROD; } +}; +class AllReduceMinHCCL : public AllReduceHCCL { + HcclReduceOp getRedOp() const override { return HCCL_REDUCE_MIN; } +}; +class AllReduceMaxHCCL : public AllReduceHCCL { + HcclReduceOp getRedOp() const override { return HCCL_REDUCE_MAX; } +}; + +REGISTER_KERNEL(Device::ASCEND, OpType::AllReduceSum, AllReduceSumHCCL, + "AllReduce_Sum_HCCL_ASCEND"); +REGISTER_KERNEL(Device::ASCEND, OpType::AllReduceProd, AllReduceProdHCCL, + "AllReduce_Prod_HCCL_ASCEND"); +REGISTER_KERNEL(Device::ASCEND, OpType::AllReduceMin, AllReduceMinHCCL, + "AllReduce_Min_HCCL_ASCEND"); +REGISTER_KERNEL(Device::ASCEND, OpType::AllReduceMax, AllReduceMaxHCCL, + "AllReduce_Max_HCCL_ASCEND"); + +} // namespace infini +#endif diff --git a/src/kernels/ascend/batch_norm.cc b/src/kernels/ascend/batch_norm.cc new file mode 100644 index 00000000..cac1097b --- /dev/null +++ b/src/kernels/ascend/batch_norm.cc @@ -0,0 +1,99 @@ +#include "operators/batch_norm.h" +#include "aclnnop/level2/aclnn_batch_norm.h" +#include "ascend/ascend_kernel_without_config.h" +#include "ascend/ascend_runtime.h" + +namespace infini { + +class BatchNormAclnn : public ASCENDKernelWithoutConfig { + + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + IT_ASSERT(op->getDType() == DataType::Float32); + + void *const inData = (op->getInputs(0)->getRawDataPtr()); + void *const outData = (op->getOutput()->getRawDataPtr()); + void *const meanData = (op->getInputs(1)->getRawDataPtr()); + void *const varData = (op->getInputs(2)->getRawDataPtr()); + void *const 
scaleData = (op->getInputs(3)->getRawDataPtr()); + void *const biasData = (op->getInputs(4)->getRawDataPtr()); + + auto inD = op->getInputs(0)->getDims(); + auto inS = op->getInputs(0)->getStride(); + auto paraD = op->getInputs(1)->getDims(); + auto paraS = op->getInputs(1)->getStride(); + auto outD = op->getOutput()->getDims(); + auto outS = op->getOutput()->getStride(); + + std::vector inputDim = castTo64(inD); + std::vector inputStride = castTo64(inS); + std::vector paraDim = castTo64(paraD); + std::vector paraStride = castTo64(paraS); + std::vector outputDim = castTo64(outD); + std::vector outputStride = castTo64(outS); + + auto inputTensor = + aclCreateTensor(inputDim.data(), inputDim.size(), ACL_FLOAT, + inputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW, + inputDim.data(), inputDim.size(), inData); + auto outputTensor = + aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT, + outputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW, + outputDim.data(), outputDim.size(), outData); + auto meanTensor = aclCreateTensor( + paraDim.data(), paraDim.size(), ACL_FLOAT, paraStride.data(), 0, + aclFormat::ACL_FORMAT_ND, paraDim.data(), paraDim.size(), meanData); + auto varTensor = aclCreateTensor( + paraDim.data(), paraDim.size(), ACL_FLOAT, paraStride.data(), 0, + aclFormat::ACL_FORMAT_ND, paraDim.data(), paraDim.size(), varData); + auto scaleTensor = + aclCreateTensor(paraDim.data(), paraDim.size(), ACL_FLOAT, + paraStride.data(), 0, aclFormat::ACL_FORMAT_ND, + paraDim.data(), paraDim.size(), scaleData); + auto biasTensor = aclCreateTensor( + paraDim.data(), paraDim.size(), ACL_FLOAT, paraStride.data(), 0, + aclFormat::ACL_FORMAT_ND, paraDim.data(), paraDim.size(), biasData); + auto savemeanTensor = + aclCreateTensor(paraDim.data(), paraDim.size(), ACL_FLOAT, + paraStride.data(), 0, aclFormat::ACL_FORMAT_ND, + paraDim.data(), paraDim.size(), scaleData); + auto saveinvstdTensor = aclCreateTensor( + paraDim.data(), paraDim.size(), ACL_FLOAT, paraStride.data(), 0, + aclFormat::ACL_FORMAT_ND, paraDim.data(), paraDim.size(), biasData); + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + auto ret = aclnnBatchNormGetWorkspaceSize( + inputTensor, scaleTensor, biasTensor, meanTensor, varTensor, false, + op->getMomentum(), op->getEps(), outputTensor, savemeanTensor, + saveinvstdTensor, &workspaceSize, &executor); + checkASCENDError(ret); + + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + + ret = aclnnBatchNorm(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + + aclDestroyTensor(inputTensor); + aclDestroyTensor(outputTensor); + aclDestroyTensor(meanTensor); + aclDestroyTensor(varTensor); + aclDestroyTensor(scaleTensor); + aclDestroyTensor(biasTensor); + aclDestroyTensor(savemeanTensor); + aclDestroyTensor(saveinvstdTensor); + + return; + } +}; + +REGISTER_KERNEL(Device::ASCEND, OpType::BatchNormalization, BatchNormAclnn, + "batchnorm_ASCEND_float"); +} // namespace infini diff --git a/src/kernels/ascend/concat.cc b/src/kernels/ascend/concat.cc new file mode 100644 index 00000000..f0f47fec --- /dev/null +++ b/src/kernels/ascend/concat.cc @@ -0,0 +1,74 @@ +#include "operators/concat.h" +#include "aclnnop/level2/aclnn_cat.h" +#include "ascend/ascend_kernel_without_config.h" +#include "ascend/ascend_runtime.h" + +namespace infini { + +class ConcatAclnn : public ASCENDKernelWithoutConfig { + + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + 
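+        // Typical aclnn kernel flow, shared by the other Ascend kernels in this
+        // patch: wrap inputs/outputs as aclTensors (dims/strides widened to
+        // int64_t via castTo64), query the required workspace with
+        // aclnnCatGetWorkspaceSize, fetch it from the runtime, launch aclnnCat
+        // on the runtime stream, then destroy the tensor descriptors.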
auto op = as(_op); + auto context = dynamic_cast(_context); + IT_ASSERT(op->getDType() == DataType::Float32); + + int dim = op->getDim(); + int num = op->numInputs(); + + std::vector inputsData{}; + + for (int i = 0; i < num; ++i) { + auto inD = op->getInputs(i)->getDims(); + auto inS = op->getInputs(i)->getStride(); + std::vector inputDim = castTo64(inD); + std::vector inputStride = castTo64(inS); + + void *const inData = (op->getInputs(i)->getRawDataPtr()); + auto tmpTensor = + aclCreateTensor(inputDim.data(), inputDim.size(), ACL_FLOAT, + inputStride.data(), 0, aclFormat::ACL_FORMAT_ND, + inputDim.data(), inputDim.size(), inData); + + inputsData.push_back(tmpTensor); + } + aclTensorList *tensorList = + aclCreateTensorList(inputsData.data(), inputsData.size()); + + void *const outData = (op->getOutput()->getRawDataPtr()); + auto outD = op->getOutput()->getDims(); + auto outS = op->getOutput()->getStride(); + std::vector outputDim = castTo64(outD); + std::vector outputStride = castTo64(outS); + + auto outputTensor = + aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT, + outputStride.data(), 0, aclFormat::ACL_FORMAT_ND, + outputDim.data(), outputDim.size(), outData); + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + auto ret = aclnnCatGetWorkspaceSize( + tensorList, int64_t(dim), outputTensor, &workspaceSize, &executor); + checkASCENDError(ret); + + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + + ret = aclnnCat(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + + aclDestroyTensorList(tensorList); + aclDestroyTensor(outputTensor); + + return; + } +}; + +REGISTER_KERNEL(Device::ASCEND, OpType::Concat, ConcatAclnn, + "concat_ASCEND_float"); +} // namespace infini diff --git a/src/kernels/ascend/conv.cc b/src/kernels/ascend/conv.cc new file mode 100644 index 00000000..dcae118b --- /dev/null +++ b/src/kernels/ascend/conv.cc @@ -0,0 +1,92 @@ +#include "operators/conv.h" +#include "aclnnop/level2/aclnn_convolution.h" +#include "ascend/ascend_kernel_without_config.h" +#include "ascend/ascend_runtime.h" + +namespace infini { + +class ConvAclnn : public ASCENDKernelWithoutConfig { + + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + IT_ASSERT(op->getDType() == DataType::Float32); + + const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); + const auto [n, c, h, w, f, r, s] = op->getNCHWFRS(); + const int cpg = op->getChannelPerGroup(); + const int g = c / cpg; + + std::vector pads = {ph, pw}; + std::vector stride = {sh, sw}; + std::vector dilation = {dh, dw}; + std::vector outputPadding = {sh - 1, sw - 1}; + + aclIntArray *convpads = aclCreateIntArray(pads.data(), pads.size()); + aclIntArray *convstride = + aclCreateIntArray(stride.data(), stride.size()); + aclIntArray *convdilation = + aclCreateIntArray(dilation.data(), dilation.size()); + aclIntArray *convOutputpadding = + aclCreateIntArray(outputPadding.data(), outputPadding.size()); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto inputD = op->getInputs(0)->getDims(); + auto inputS = op->getInputs(0)->getStride(); + auto weightD = op->getInputs(1)->getDims(); + auto weightS = op->getInputs(1)->getStride(); + auto outD = op->getOutput()->getDims(); + auto outS = 
op->getOutput()->getStride(); + + std::vector inputDim = castTo64(inputD); + std::vector inputStride = castTo64(inputS); + std::vector weightDim = castTo64(weightD); + std::vector weightStride = castTo64(weightS); + std::vector outputDim = castTo64(outD); + std::vector outputStride = castTo64(outS); + + auto inputTensor = + aclCreateTensor(inputDim.data(), inputDim.size(), ACL_FLOAT, + inputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW, + inputDim.data(), inputDim.size(), aData); + auto weightTensor = + aclCreateTensor(weightDim.data(), weightDim.size(), ACL_FLOAT, + weightStride.data(), 0, aclFormat::ACL_FORMAT_NCHW, + weightDim.data(), weightDim.size(), bData); + auto outputTensor = + aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT, + outputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW, + outputDim.data(), outputDim.size(), cData); + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + auto ret = aclnnConvolutionGetWorkspaceSize( + inputTensor, weightTensor, nullptr, convstride, convpads, + convdilation, false, convOutputpadding, int64_t(g), outputTensor, + int8_t(1), &workspaceSize, &executor); + checkASCENDError(ret); + + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + + ret = aclnnConvolution(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + + aclDestroyTensor(inputTensor); + aclDestroyTensor(weightTensor); + aclDestroyTensor(outputTensor); + + return; + } +}; + +REGISTER_KERNEL(Device::ASCEND, OpType::Conv, ConvAclnn, "conv_ASCEND_float"); +} // namespace infini diff --git a/src/kernels/ascend/conv_transpose.cc b/src/kernels/ascend/conv_transpose.cc new file mode 100644 index 00000000..b6692d51 --- /dev/null +++ b/src/kernels/ascend/conv_transpose.cc @@ -0,0 +1,94 @@ +#include "aclnnop/level2/aclnn_convolution.h" +#include "ascend/ascend_kernel_without_config.h" +#include "ascend/ascend_runtime.h" +#include "operators/conv.h" + +namespace infini { + +class ConvTransAclnn : public ASCENDKernelWithoutConfig { + + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + IT_ASSERT(op->getDType() == DataType::Float32); + + const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); + const auto [oph, opw] = op->getOutputPadding(); + const auto [n, c, h, w, f, r, s] = op->getNCHWFRS(); + const int cpg = op->getChannelPerGroup(); + const int g = c / cpg; + + std::vector pads = {ph, pw}; + std::vector stride = {sh, sw}; + std::vector dilation = {dh, dw}; + std::vector outputPadding = {oph, opw}; + + aclIntArray *convpads = aclCreateIntArray(pads.data(), pads.size()); + aclIntArray *convstride = + aclCreateIntArray(stride.data(), stride.size()); + aclIntArray *convdilation = + aclCreateIntArray(dilation.data(), dilation.size()); + aclIntArray *convOutputpadding = + aclCreateIntArray(outputPadding.data(), outputPadding.size()); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto inputD = op->getInputs(0)->getDims(); + auto inputS = op->getInputs(0)->getStride(); + auto weightD = op->getInputs(1)->getDims(); + auto weightS = op->getInputs(1)->getStride(); + auto outD = op->getOutput()->getDims(); + auto outS = op->getOutput()->getStride(); + + std::vector inputDim = castTo64(inputD); + std::vector inputStride = castTo64(inputS); + 
std::vector weightDim = castTo64(weightD); + std::vector weightStride = castTo64(weightS); + std::vector outputDim = castTo64(outD); + std::vector outputStride = castTo64(outS); + + auto inputTensor = + aclCreateTensor(inputDim.data(), inputDim.size(), ACL_FLOAT, + inputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW, + inputDim.data(), inputDim.size(), aData); + auto weightTensor = + aclCreateTensor(weightDim.data(), weightDim.size(), ACL_FLOAT, + weightStride.data(), 0, aclFormat::ACL_FORMAT_NCHW, + weightDim.data(), weightDim.size(), bData); + auto outputTensor = + aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT, + outputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW, + outputDim.data(), outputDim.size(), cData); + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + auto ret = aclnnConvolutionGetWorkspaceSize( + inputTensor, weightTensor, nullptr, convstride, convpads, + convdilation, true, convOutputpadding, int64_t(g), outputTensor, + int8_t(1), &workspaceSize, &executor); + checkASCENDError(ret); + + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + + ret = aclnnConvolution(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + + aclDestroyTensor(inputTensor); + aclDestroyTensor(weightTensor); + aclDestroyTensor(outputTensor); + + return; + } +}; + +REGISTER_KERNEL(Device::ASCEND, OpType::ConvTranspose, ConvTransAclnn, + "ConvTrans_ASCEND_float"); +} // namespace infini diff --git a/src/kernels/ascend/element_wise.cc b/src/kernels/ascend/element_wise.cc new file mode 100644 index 00000000..4ad4a7ba --- /dev/null +++ b/src/kernels/ascend/element_wise.cc @@ -0,0 +1,219 @@ +#include "operators/element_wise.h" +#include "aclnnop/aclnn_maximum.h" +#include "aclnnop/level2/aclnn_add.h" +#include "aclnnop/level2/aclnn_div.h" +#include "aclnnop/level2/aclnn_mul.h" +#include "aclnnop/level2/aclnn_pow_tensor_tensor.h" +#include "aclnnop/level2/aclnn_sub.h" +#include "ascend/ascend_kernel_without_config.h" +#include "ascend/ascend_runtime.h" + +namespace infini { + +#define DEFINE_ELEMENT_WISE_Aclnn(prefix) \ + class prefix##Aclnn : public ASCENDKernelWithoutConfig { \ + void compute(const Operator &_op, \ + const RuntimeObj *_context) const override { \ + auto op = as(_op); \ + auto context = dynamic_cast(_context); \ + IT_ASSERT(op->getDType() == DataType::Float32); \ + \ + void *const aData = (op->getInputs(0)->getRawDataPtr()); \ + void *const bData = (op->getInputs(1)->getRawDataPtr()); \ + void *const cData = (op->getOutput()->getRawDataPtr()); \ + \ + auto a = op->getInputs(0)->getDims(); \ + auto aS = op->getInputs(0)->getStride(); \ + auto b = op->getInputs(1)->getDims(); \ + auto bS = op->getInputs(1)->getStride(); \ + auto c = op->getOutput()->getDims(); \ + auto cS = op->getOutput()->getStride(); \ + \ + std::vector aDim = castTo64(a); \ + std::vector aStride = castTo64(aS); \ + std::vector bDim = castTo64(b); \ + std::vector bStride = castTo64(bS); \ + std::vector cDim = castTo64(c); \ + std::vector cStride = castTo64(cS); \ + \ + auto inputA = aclCreateTensor( \ + aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0, \ + aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData); \ + auto inputB = aclCreateTensor( \ + bDim.data(), bDim.size(), ACL_FLOAT, bStride.data(), 0, \ + aclFormat::ACL_FORMAT_ND, bDim.data(), bDim.size(), bData); \ + auto output = aclCreateTensor( \ + cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0, \ + 
aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData); \ + \ + uint64_t workspaceSize = 0; \ + aclOpExecutor *executor; \ + \ + auto ret = aclnn##prefix##GetWorkspaceSize( \ + inputA, inputB, output, &workspaceSize, &executor); \ + checkASCENDError(ret); \ + void *workspaceAddr = nullptr; \ + if (workspaceSize > 0) { \ + workspaceAddr = context->getWorkspace(workspaceSize); \ + } \ + \ + ret = aclnn##prefix(workspaceAddr, workspaceSize, executor, \ + context->ASCENDHandle()); \ + checkASCENDError(ret); \ + \ + aclDestroyTensor(inputA); \ + aclDestroyTensor(inputB); \ + aclDestroyTensor(output); \ + \ + return; \ + } \ + }; + +class AddAclnn : public ASCENDKernelWithoutConfig { + virtual tuple getAlphBeta() const { + return {1.f, 1.f, 0.f}; + } + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + IT_ASSERT(op->getDType() == DataType::Float32); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto a = op->getInputs(0)->getDims(); + auto aS = op->getInputs(0)->getStride(); + auto b = op->getInputs(1)->getDims(); + auto bS = op->getInputs(1)->getStride(); + auto c = op->getOutput()->getDims(); + auto cS = op->getOutput()->getStride(); + + std::vector aDim = castTo64(a); + std::vector aStride = castTo64(aS); + std::vector bDim = castTo64(b); + std::vector bStride = castTo64(bS); + std::vector cDim = castTo64(c); + std::vector cStride = castTo64(cS); + + auto inputA = aclCreateTensor( + aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0, + aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData); + auto inputB = aclCreateTensor( + bDim.data(), bDim.size(), ACL_FLOAT, bStride.data(), 0, + aclFormat::ACL_FORMAT_ND, bDim.data(), bDim.size(), bData); + auto output = aclCreateTensor( + cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0, + aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData); + + auto [aAlpha, bAlpha, beta] = getAlphBeta(); + auto alpha = aclCreateScalar(&bAlpha, ACL_FLOAT); + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + auto ret = aclnnAddGetWorkspaceSize(inputA, inputB, alpha, output, + &workspaceSize, &executor); + checkASCENDError(ret); + + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + + ret = aclnnAdd(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + + aclDestroyTensor(inputA); + aclDestroyTensor(inputB); + aclDestroyScalar(alpha); + aclDestroyTensor(output); + + return; + } +}; + +class SubAclnn : public ASCENDKernelWithoutConfig { + virtual tuple getAlphBeta() const { + return {1.f, 1.f, 0.f}; + } + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + IT_ASSERT(op->getDType() == DataType::Float32); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto a = op->getInputs(0)->getDims(); + auto aS = op->getInputs(0)->getStride(); + auto b = op->getInputs(1)->getDims(); + auto bS = op->getInputs(1)->getStride(); + auto c = op->getOutput()->getDims(); + auto cS = op->getOutput()->getStride(); + + std::vector aDim = castTo64(a); + std::vector aStride = castTo64(aS); + std::vector bDim = 
castTo64(b); + std::vector bStride = castTo64(bS); + std::vector cDim = castTo64(c); + std::vector cStride = castTo64(cS); + + auto inputA = aclCreateTensor( + aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0, + aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData); + auto inputB = aclCreateTensor( + bDim.data(), bDim.size(), ACL_FLOAT, bStride.data(), 0, + aclFormat::ACL_FORMAT_ND, bDim.data(), bDim.size(), bData); + auto output = aclCreateTensor( + cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0, + aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData); + + auto [aAlpha, bAlpha, beta] = getAlphBeta(); + auto alpha = aclCreateScalar(&bAlpha, ACL_FLOAT); + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + auto ret = aclnnSubGetWorkspaceSize(inputA, inputB, alpha, output, + &workspaceSize, &executor); + checkASCENDError(ret); + + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + + ret = aclnnSub(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + + aclDestroyTensor(inputA); + aclDestroyTensor(inputB); + aclDestroyScalar(alpha); + aclDestroyTensor(output); + + return; + } +}; + +DEFINE_ELEMENT_WISE_Aclnn(PowTensorTensor); +DEFINE_ELEMENT_WISE_Aclnn(Div); +DEFINE_ELEMENT_WISE_Aclnn(Mul); +DEFINE_ELEMENT_WISE_Aclnn(Maximum); + +REGISTER_KERNEL(Device::ASCEND, OpType::Pow, PowTensorTensorAclnn, + "pow_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Div, DivAclnn, "div_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Mul, MulAclnn, "mul_ASCEND_float"); + +REGISTER_KERNEL(Device::ASCEND, OpType::Add, AddAclnn, "add_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Sub, SubAclnn, "sub_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Max, MaximumAclnn, "max_ASCEND_float"); +// REGISTER_KERNEL(Device::ASCEND, OpType::Abs, AbsAclnn, "abs_ASCEND_float"); + +} // namespace infini diff --git a/src/kernels/ascend/gather.cc b/src/kernels/ascend/gather.cc new file mode 100644 index 00000000..09e4f0ae --- /dev/null +++ b/src/kernels/ascend/gather.cc @@ -0,0 +1,78 @@ +#include "operators/gather.h" +#include "aclnnop/level2/aclnn_gather_v2.h" +#include "ascend/ascend_kernel_without_config.h" +#include "ascend/ascend_runtime.h" + +namespace infini { + +class GatherAclnn : public ASCENDKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + IT_ASSERT(op->getInputs(1)->getDType() == DataType::Int32 || + op->getInputs(1)->getDType() == DataType::Int64); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + int64_t axis = int64_t(op->getAxis()); + + auto a = op->getInputs(0)->getDims(); + auto aS = op->getInputs(0)->getStride(); + auto b = op->getInputs(1)->getDims(); + auto bS = op->getInputs(1)->getStride(); + auto c = op->getOutput()->getDims(); + auto cS = op->getOutput()->getStride(); + + if (b.size() == 0) { + c.insert(c.begin() + axis, 1); + cS.insert(cS.begin() + axis, axis > 0 ? 
cS[axis - 1] : cS[0]); + } + + std::vector aDim = castTo64(a); + std::vector aStride = castTo64(aS); + std::vector bDim = castTo64(b); + std::vector bStride = castTo64(bS); + std::vector cDim = castTo64(c); + std::vector cStride = castTo64(cS); + + auto inputA = aclCreateTensor( + aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0, + aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData); + + auto inputB = aclCreateTensor( + bDim.data(), bDim.size(), ACL_INT64, bStride.data(), 0, + aclFormat::ACL_FORMAT_ND, bDim.data(), bDim.size(), bData); + + auto output = aclCreateTensor( + cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0, + aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData); + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + auto ret = aclnnGatherV2GetWorkspaceSize(inputA, axis, inputB, output, + &workspaceSize, &executor); + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + checkASCENDError(ret); + + ret = aclnnGatherV2(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + + aclDestroyTensor(inputA); + aclDestroyTensor(inputB); + aclDestroyTensor(output); + + return; + } +}; + +REGISTER_KERNEL(Device::ASCEND, OpType::Gather, GatherAclnn, + "gather_ASCEND_float"); +} // namespace infini diff --git a/src/kernels/ascend/instance_norm.cc b/src/kernels/ascend/instance_norm.cc new file mode 100644 index 00000000..2e49c18c --- /dev/null +++ b/src/kernels/ascend/instance_norm.cc @@ -0,0 +1,104 @@ +#include "operators/instance_norm.h" +#include "aclnnop/level2/aclnn_layer_norm.h" +#include "ascend/ascend_kernel_without_config.h" +#include "ascend/ascend_runtime.h" +#include "operators/gather.h" + +namespace infini { + +class InstanceNormAclnn : public ASCENDKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + IT_ASSERT(op->getDType() == DataType::Float32); + + void *const inputData = (op->getInputs(0)->getRawDataPtr()); + void *const weightData = (op->getInputs(1)->getRawDataPtr()); + void *const outputData = (op->getOutput()->getRawDataPtr()); + + auto inputD = op->getInputs(0)->getDims(); + auto inputS = op->getInputs(0)->getStride(); + auto weightD = op->getInputs(1)->getDims(); + auto weightS = op->getInputs(1)->getStride(); + auto outD = op->getOutput()->getDims(); + auto outS = op->getOutput()->getStride(); + + double eps = static_cast(op->getEps()); + + std::vector inputDim = castTo64(inputD); + std::vector inputStride = castTo64(inputS); + std::vector weightDim = castTo64(weightD); + std::vector weightStride = castTo64(weightS); + std::vector outputDim = castTo64(outD); + std::vector outputStride = castTo64(outS); + + auto axis = 3; + + auto rank = static_cast(inputDim.size()); + std::vector normalizedShape(rank - axis, 0); + for (auto i = rank; i > axis; --i) { + normalizedShape[i - 1 - axis] = inputDim[i - 1]; + } + + auto inputTensor = + aclCreateTensor(inputDim.data(), inputDim.size(), ACL_FLOAT, + inputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW, + inputDim.data(), inputDim.size(), inputData); + auto weightTensor = + aclCreateTensor(weightDim.data(), weightDim.size(), ACL_FLOAT, + weightStride.data(), 0, aclFormat::ACL_FORMAT_NCHW, + weightDim.data(), weightDim.size(), weightData); + auto outputTensor = + aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT, + outputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW, + 
outputDim.data(), outputDim.size(), outputData); + + auto *normArray = + aclCreateIntArray(normalizedShape.data(), normalizedShape.size()); + + aclTensor *biasTensor = NULL; + if (op->numInputs() == 3) { + void *const biasData = (op->getInputs(2)->getRawDataPtr()); + + auto biasD = op->getInputs(2)->getDims(); + auto biasS = op->getInputs(2)->getStride(); + std::vector biasDim = castTo64(biasD); + std::vector biasStride = castTo64(biasS); + + biasTensor = aclCreateTensor( + biasDim.data(), biasDim.size(), ACL_FLOAT, biasStride.data(), 0, + aclFormat::ACL_FORMAT_NCHW, biasDim.data(), biasDim.size(), + biasData); + } + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + auto ret = aclnnLayerNormGetWorkspaceSize( + inputTensor, normArray, weightTensor, biasTensor, eps, outputTensor, + NULL, NULL, &workspaceSize, &executor); + checkASCENDError(ret); + + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + + ret = aclnnLayerNorm(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + + aclDestroyTensor(inputTensor); + aclDestroyTensor(weightTensor); + aclDestroyIntArray(normArray); + aclDestroyTensor(outputTensor); + + return; + } +}; + +REGISTER_KERNEL(Device::ASCEND, OpType::InstanceNormalization, + InstanceNormAclnn, "InstanceNorm_ASCEND"); + +} // namespace infini diff --git a/src/kernels/ascend/layer_norm.cc b/src/kernels/ascend/layer_norm.cc new file mode 100644 index 00000000..12e8f1b5 --- /dev/null +++ b/src/kernels/ascend/layer_norm.cc @@ -0,0 +1,103 @@ +#include "operators/layer_norm.h" +#include "aclnnop/level2/aclnn_layer_norm.h" +#include "ascend/ascend_kernel_without_config.h" +#include "ascend/ascend_runtime.h" +#include "operators/gather.h" + +namespace infini { + +class LayerNormAclnn : public ASCENDKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + IT_ASSERT(op->getDType() == DataType::Float32); + + void *const inputData = (op->getInputs(0)->getRawDataPtr()); + void *const weightData = (op->getInputs(1)->getRawDataPtr()); + void *const outputData = (op->getOutput()->getRawDataPtr()); + + auto inputD = op->getInputs(0)->getDims(); + auto inputS = op->getInputs(0)->getStride(); + auto weightD = op->getInputs(1)->getDims(); + auto weightS = op->getInputs(1)->getStride(); + auto outD = op->getOutput()->getDims(); + auto outS = op->getOutput()->getStride(); + + double eps = static_cast(op->getEps()); + + std::vector inputDim = castTo64(inputD); + std::vector inputStride = castTo64(inputS); + std::vector weightDim = castTo64(weightD); + std::vector weightStride = castTo64(weightS); + std::vector outputDim = castTo64(outD); + std::vector outputStride = castTo64(outS); + + auto axis = op->getAxis(); + auto rank = static_cast(inputDim.size()); + std::vector normalizedShape(rank - axis, 0); + for (auto i = rank; i > axis; --i) { + normalizedShape[i - 1 - axis] = inputDim[i - 1]; + } + + auto inputTensor = + aclCreateTensor(inputDim.data(), inputDim.size(), ACL_FLOAT, + inputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW, + inputDim.data(), inputDim.size(), inputData); + auto weightTensor = + aclCreateTensor(weightDim.data(), weightDim.size(), ACL_FLOAT, + weightStride.data(), 0, aclFormat::ACL_FORMAT_NCHW, + weightDim.data(), weightDim.size(), weightData); + auto outputTensor = + aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT, + 
outputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW, + outputDim.data(), outputDim.size(), outputData); + + auto *normArray = + aclCreateIntArray(normalizedShape.data(), normalizedShape.size()); + + aclTensor *biasTensor = NULL; + if (op->numInputs() == 3) { + void *const biasData = (op->getInputs(2)->getRawDataPtr()); + + auto biasD = op->getInputs(2)->getDims(); + auto biasS = op->getInputs(2)->getStride(); + std::vector biasDim = castTo64(biasD); + std::vector biasStride = castTo64(biasS); + + biasTensor = aclCreateTensor( + biasDim.data(), biasDim.size(), ACL_FLOAT, biasStride.data(), 0, + aclFormat::ACL_FORMAT_NCHW, biasDim.data(), biasDim.size(), + biasData); + } + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + auto ret = aclnnLayerNormGetWorkspaceSize( + inputTensor, normArray, weightTensor, biasTensor, eps, outputTensor, + NULL, NULL, &workspaceSize, &executor); + checkASCENDError(ret); + + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + + ret = aclnnLayerNorm(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + + aclDestroyTensor(inputTensor); + aclDestroyTensor(weightTensor); + aclDestroyIntArray(normArray); + aclDestroyTensor(outputTensor); + + return; + } +}; + +REGISTER_KERNEL(Device::ASCEND, OpType::LayerNormalization, LayerNormAclnn, + "LayerNorm_ASCEND"); + +} // namespace infini diff --git a/src/kernels/ascend/matmul.cc b/src/kernels/ascend/matmul.cc new file mode 100644 index 00000000..15714cde --- /dev/null +++ b/src/kernels/ascend/matmul.cc @@ -0,0 +1,118 @@ +#include "operators/matmul.h" +#include "aclnnop/level2/aclnn_gemm.h" +#include "aclnnop/level2/aclnn_matmul.h" +#include "ascend/ascend_kernel_without_config.h" +#include "ascend/ascend_runtime.h" + +namespace infini { + +class MatmulAclnn : public ASCENDKernelWithoutConfig { + // unsupport trans for "gemm" whithou biasInput + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + IT_ASSERT(op->getDType() == DataType::Float32); + + auto input_num = op->numInputs(); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + void *biasData = NULL; + if (input_num > 2) { + biasData = (op->getInputs(2)->getRawDataPtr()); + } + + auto selfD = op->getInputs(0)->getDims(); + auto selfS = op->getInputs(0)->getStride(); + auto matD = op->getInputs(1)->getDims(); + auto matS = op->getInputs(1)->getStride(); + auto outD = op->getOutput()->getDims(); + auto outS = op->getOutput()->getStride(); + std::vector biasD; + std::vector biasS; + if (input_num > 2) { + biasD = op->getInputs(2)->getDims(); + biasS = op->getInputs(2)->getStride(); + } + + std::vector selfDim = castTo64(selfD); + std::vector selfStride = castTo64(selfS); + std::vector matDim = castTo64(matD); + std::vector matStride = castTo64(matS); + std::vector outputDim = castTo64(outD); + std::vector outputStride = castTo64(outS); + std::vector biasDim; + std::vector biasStride; + if (input_num > 2) { + biasDim = castTo64(biasD); + biasStride = castTo64(biasS); + } + + auto selfTensor = aclCreateTensor( + selfDim.data(), selfDim.size(), ACL_FLOAT, selfStride.data(), 0, + aclFormat::ACL_FORMAT_ND, selfDim.data(), selfDim.size(), aData); + auto matTensor = aclCreateTensor( + matDim.data(), matDim.size(), ACL_FLOAT, matStride.data(), 0, + 
aclFormat::ACL_FORMAT_ND, matDim.data(), matDim.size(), bData); + auto outputTensor = + aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT, + outputStride.data(), 0, aclFormat::ACL_FORMAT_ND, + outputDim.data(), outputDim.size(), cData); + aclTensor *biasTensor = NULL; + if (input_num > 2) { + biasTensor = + aclCreateTensor(biasDim.data(), biasDim.size(), ACL_FLOAT, + biasStride.data(), 0, aclFormat::ACL_FORMAT_ND, + biasDim.data(), biasDim.size(), biasData); + } + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + if (input_num > 2) { + float alpha = 1.0; + float beta = 1.0; + int32_t transA = op->getTransA(); + int32_t transB = op->getTransB(); + + auto ret = aclnnGemmGetWorkspaceSize( + selfTensor, matTensor, biasTensor, alpha, beta, int64_t(transA), + int64_t(transB), outputTensor, 1, &workspaceSize, &executor); + checkASCENDError(ret); + + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + + ret = aclnnGemm(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + } else { + auto ret = + aclnnMatmulGetWorkspaceSize(selfTensor, matTensor, outputTensor, + 1, &workspaceSize, &executor); + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + checkASCENDError(ret); + + ret = aclnnMatmul(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + } + + // aclDestroyTensor(selfTensor); + // aclDestroyTensor(matTensor); + // aclDestroyTensor(outputTensor); + + return; + } +}; + +REGISTER_KERNEL(Device::ASCEND, OpType::MatMul, MatmulAclnn, + "matmul_ASCEND_float"); +} // namespace infini diff --git a/src/kernels/ascend/pad2d.cc b/src/kernels/ascend/pad2d.cc new file mode 100644 index 00000000..03b75926 --- /dev/null +++ b/src/kernels/ascend/pad2d.cc @@ -0,0 +1,83 @@ +#include "aclnnop/aclnn_constant_pad_nd.h" +#include "ascend/ascend_kernel_without_config.h" +#include "ascend/ascend_runtime.h" +#include "operators/pad.h" + +namespace infini { + +class PadAclnn : public ASCENDKernelWithoutConfig { + + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + IT_ASSERT(op->getDType() == DataType::Float32); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto inputD = op->getInputs(0)->getDims(); + auto inputS = op->getInputs(0)->getStride(); + + auto outD = op->getOutput()->getDims(); + auto outS = op->getOutput()->getStride(); + + std::vector inputDim = castTo64(inputD); + std::vector inputStride = castTo64(inputS); + + std::vector outputDim = castTo64(outD); + std::vector outputStride = castTo64(outS); + + auto inputTensor = aclCreateTensor( + inputDim.data(), inputDim.size(), ACL_FLOAT, inputStride.data(), 0, + aclFormat::ACL_FORMAT_ND, inputDim.data(), inputDim.size(), aData); + + auto outputTensor = + aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT, + outputStride.data(), 0, aclFormat::ACL_FORMAT_ND, + outputDim.data(), outputDim.size(), cData); + + std::vector intPads = op->getPads(); + + std::size_t length = intPads.size(); + std::vector pads(length); + std::size_t halfLen = length / 2; + for (std::size_t i = 0; i < halfLen; ++i) { + pads[2 * i] = intPads[halfLen - i - 1]; + pads[2 * i + 1] = intPads[2 * halfLen - i - 1]; + } + + std::cout << "pads = " << vecToString(pads) << std::endl; + + 
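+        // The loop above reverses the operator's pads, which appear to follow the
+        // ONNX layout [x1_begin, ..., xk_begin, x1_end, ..., xk_end], into
+        // begin/end pairs ordered from the last dimension first, as
+        // aclnnConstantPadNd expects. For example, NCHW pads {0,0,1,1,0,0,1,1}
+        // become {1,1, 1,1, 0,0, 0,0} (W, H, C, N pairs).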
aclIntArray *padding = aclCreateIntArray(pads.data(), length); + float valueValue = 0.0f; + auto value = aclCreateScalar(&valueValue, ACL_FLOAT); + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + auto ret = aclnnConstantPadNdGetWorkspaceSize( + inputTensor, padding, value, outputTensor, &workspaceSize, + &executor); + checkASCENDError(ret); + + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + + ret = aclnnConstantPadNd(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + + aclDestroyTensor(inputTensor); + aclDestroyIntArray(padding); + aclDestroyScalar(value); + aclDestroyTensor(outputTensor); + + return; + } +}; + +REGISTER_KERNEL(Device::ASCEND, OpType::Pad, PadAclnn, "pad_ASCEND_float"); +} // namespace infini diff --git a/src/kernels/ascend/pooling.cc b/src/kernels/ascend/pooling.cc new file mode 100644 index 00000000..fef068bc --- /dev/null +++ b/src/kernels/ascend/pooling.cc @@ -0,0 +1,154 @@ +#include "operators/pooling.h" +#include "aclnnop/level2/aclnn_avgpool2d.h" +#include "aclnnop/level2/aclnn_max_pool.h" +#include "ascend/ascend_kernel_without_config.h" +#include "ascend/ascend_runtime.h" + +namespace infini { + +class AvgPooling : public ASCENDKernelWithoutConfig { + + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + IT_ASSERT(op->getDType() == DataType::Float32); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto [n, c, h, w, kh, kw] = op->getNCHWRS(); + auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); + + std::vector ksize = {kh, kw}; + std::vector stride = {sh, sw}; + std::vector pad = {ph, pw}; + + int64_t divisorOverride = 0; + + auto selfD = op->getInputs(0)->getDims(); + auto selfS = op->getInputs(0)->getStride(); + auto outD = op->getOutput()->getDims(); + auto outS = op->getOutput()->getStride(); + + std::vector selfDim = castTo64(selfD); + std::vector selfStride = castTo64(selfS); + std::vector outputDim = castTo64(outD); + std::vector outputStride = castTo64(outS); + + aclIntArray *kernelSize = aclCreateIntArray(ksize.data(), ksize.size()); + aclIntArray *strides = aclCreateIntArray(stride.data(), stride.size()); + aclIntArray *paddings = aclCreateIntArray(pad.data(), pad.size()); + + auto selfTensor = aclCreateTensor( + selfDim.data(), selfDim.size(), ACL_FLOAT, selfStride.data(), 0, + aclFormat::ACL_FORMAT_NCHW, selfDim.data(), selfDim.size(), aData); + auto outputTensor = + aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT, + outputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW, + outputDim.data(), outputDim.size(), cData); + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + auto ret = aclnnAvgPool2dGetWorkspaceSize( + selfTensor, kernelSize, strides, paddings, false, true, + divisorOverride, int8_t(0), outputTensor, &workspaceSize, + &executor); + checkASCENDError(ret); + + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + + ret = aclnnAvgPool2d(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + + aclDestroyTensor(selfTensor); + aclDestroyIntArray(kernelSize); + aclDestroyIntArray(strides); + aclDestroyIntArray(paddings); + aclDestroyTensor(outputTensor); + + return; + } +}; + +class MaxPooling : public 
ASCENDKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + IT_ASSERT(op->getDType() == DataType::Float32); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto [n, c, h, w, kh, kw] = op->getNCHWRS(); + auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); + int64_t ceilMode = int64_t(op->getCeilMode()); + + std::vector ksize = {kh, kw}; + std::vector stride = {sh, sw}; + std::vector pad = {ph, pw}; + std::vector dilation = {dh, dw}; + + auto selfD = op->getInputs(0)->getDims(); + auto selfS = op->getInputs(0)->getStride(); + auto outD = op->getOutput()->getDims(); + auto outS = op->getOutput()->getStride(); + + std::vector selfDim = castTo64(selfD); + std::vector selfStride = castTo64(selfS); + std::vector outputDim = castTo64(outD); + std::vector outputStride = castTo64(outS); + + aclIntArray *kernelSize = aclCreateIntArray(ksize.data(), ksize.size()); + aclIntArray *strides = aclCreateIntArray(stride.data(), stride.size()); + aclIntArray *paddings = aclCreateIntArray(pad.data(), pad.size()); + aclIntArray *dilations = + aclCreateIntArray(dilation.data(), dilation.size()); + + auto selfTensor = aclCreateTensor( + selfDim.data(), selfDim.size(), ACL_FLOAT, selfStride.data(), 0, + aclFormat::ACL_FORMAT_NCHW, selfDim.data(), selfDim.size(), aData); + auto outputTensor = + aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT, + outputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW, + outputDim.data(), outputDim.size(), cData); + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + auto ret = aclnnMaxPoolGetWorkspaceSize( + selfTensor, kernelSize, strides, 0, paddings, dilations, ceilMode, + outputTensor, &workspaceSize, &executor); + checkASCENDError(ret); + + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + + ret = aclnnMaxPool(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + + aclDestroyTensor(selfTensor); + aclDestroyIntArray(kernelSize); + aclDestroyIntArray(strides); + aclDestroyIntArray(paddings); + aclDestroyIntArray(dilations); + aclDestroyTensor(outputTensor); + + return; + } +}; + +REGISTER_KERNEL(Device::ASCEND, OpType::MaxPool, MaxPooling, + "maxpooling_ASCEND_float"); + +REGISTER_KERNEL(Device::ASCEND, OpType::AveragePool, AvgPooling, + "avgpooling_ASCEND_float"); +} // namespace infini diff --git a/src/kernels/ascend/recv.cc b/src/kernels/ascend/recv.cc new file mode 100644 index 00000000..b5338432 --- /dev/null +++ b/src/kernels/ascend/recv.cc @@ -0,0 +1,50 @@ +#ifdef INFINI_USE_HCCL +#include "operators/recv.h" +#include "ascend/ascend_kernel_without_config.h" +#include "ascend/ascend_runtime.h" +#include "ascend/hccl_communicator.h" +#include "hccl/hccl.h" +#include "hccl/hccl_types.h" + +namespace infini { +class RecvHCCL : public ASCENDKernelWithoutConfig { + public: + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + IT_ASSERT(op->getDType() == DataType::Float32); + + void *output = op->getOutput(0)->getRawDataPtr(); + IT_ASSERT(op->getDType() == DataType::Float32); + const auto shape = op->getShape(); + int nDims = shape.size(); + int outputCount = 1; + for (int i = 0; i < nDims; i++) { + outputCount *= shape[i]; + } + + HcclComm comm = + 
dynamic_cast(context->getCommunicator()) + .getHcclComm(); + + uint32_t rank; + + checkHCCLError(HcclGetRankId(comm, &rank)); + + int source = op->getSourceRank(); + int destination = op->getDestinationRank(); + + if (int(rank) == destination) { + checkHCCLError(HcclRecv(output, uint64_t(outputCount), + HCCL_DATA_TYPE_FP32, uint32_t(source), comm, + context->ASCENDHandle())); + } + checkASCENDError(aclrtSynchronizeStream(context->ASCENDHandle())); + } +}; + +REGISTER_KERNEL(Device::ASCEND, OpType::Recv, RecvHCCL, "Recv_HCCL_ASCEND"); +} // namespace infini + +#endif diff --git a/src/kernels/ascend/reduce.cc b/src/kernels/ascend/reduce.cc new file mode 100644 index 00000000..0af84ec4 --- /dev/null +++ b/src/kernels/ascend/reduce.cc @@ -0,0 +1,131 @@ +#include "operators/reduce.h" +#include "aclnnop/aclnn_mean.h" +#include "aclnnop/aclnn_reduce_sum.h" +#include "ascend/ascend_kernel_without_config.h" +#include "ascend/ascend_runtime.h" + +namespace infini { + +class MeanAclnn : public ASCENDKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto axes_set = op->getAxes(); + std::vector axes; + axes.assign(axes_set.begin(), axes_set.end()); + + bool KeepDim = op->getKeepDims(); + + auto a = op->getInputs(0)->getDims(); + auto aS = op->getInputs(0)->getStride(); + auto c = op->getOutput()->getDims(); + auto cS = op->getOutput()->getStride(); + + std::vector aDim = castTo64(a); + std::vector aStride = castTo64(aS); + std::vector cDim = castTo64(c); + std::vector cStride = castTo64(cS); + std::vector axes_64 = castTo64(axes); + + auto inputA = aclCreateTensor( + aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0, + aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData); + auto output = aclCreateTensor( + cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0, + aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData); + aclIntArray *dim = aclCreateIntArray(axes_64.data(), axes_64.size()); + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + auto ret = aclnnMeanV2GetWorkspaceSize( + inputA, dim, KeepDim, true, output, &workspaceSize, &executor); + checkASCENDError(ret); + + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + + ret = aclnnMeanV2(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + + aclDestroyTensor(inputA); + aclDestroyIntArray(dim); + aclDestroyTensor(output); + + return; + } +}; + +class ReduceSumAclnn : public ASCENDKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto axes_set = op->getAxes(); + std::vector axes; + axes.assign(axes_set.begin(), axes_set.end()); + + bool KeepDim = op->getKeepDims(); + + auto a = op->getInputs(0)->getDims(); + auto aS = op->getInputs(0)->getStride(); + auto c = op->getOutput()->getDims(); + auto cS = op->getOutput()->getStride(); + + std::vector aDim = castTo64(a); + std::vector aStride = castTo64(aS); + std::vector cDim = castTo64(c); + std::vector 
cStride = castTo64(cS); + std::vector axes_64 = castTo64(axes); + + auto inputA = aclCreateTensor( + aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0, + aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData); + auto output = aclCreateTensor( + cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0, + aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData); + aclIntArray *dim = aclCreateIntArray(axes_64.data(), axes_64.size()); + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + auto ret = aclnnReduceSumGetWorkspaceSize( + inputA, dim, KeepDim, ACL_FLOAT, output, &workspaceSize, &executor); + checkASCENDError(ret); + + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + + ret = aclnnReduceSum(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + + aclDestroyTensor(inputA); + aclDestroyIntArray(dim); + aclDestroyTensor(output); + + return; + } +}; + +REGISTER_KERNEL(Device::ASCEND, OpType::ReduceMean, MeanAclnn, + "reduceMean_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::ReduceSum, ReduceSumAclnn, + "reduceSum_ASCEND_float"); +} // namespace infini diff --git a/src/kernels/ascend/reshape.cc b/src/kernels/ascend/reshape.cc new file mode 100644 index 00000000..defb8d3c --- /dev/null +++ b/src/kernels/ascend/reshape.cc @@ -0,0 +1,62 @@ +#include "operators/reshape.h" +#include "aclnnop/level2/aclnn_copy.h" +#include "ascend/ascend_kernel_without_config.h" +#include "ascend/ascend_runtime.h" + +namespace infini { +class CopyAclnn : public ASCENDKernelWithoutConfig { + void compute(const Operator &op, + const RuntimeObj *_context) const override { + auto context = dynamic_cast(_context); + IT_ASSERT(op->getDType() == DataType::Float32); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto aD = op->getInputs(0)->getDims(); + auto aS = op->getInputs(0)->getStride(); + + std::vector aDim = castTo64(aD); + std::vector aStride = castTo64(aS); + + auto srcTensor = aclCreateTensor( + aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0, + aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData); + auto outputTensor = aclCreateTensor( + aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0, + aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), cData); + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + auto ret = aclnnInplaceCopyGetWorkspaceSize(outputTensor, srcTensor, + &workspaceSize, &executor); + checkASCENDError(ret); + + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + + ret = aclnnInplaceCopy(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + + aclDestroyTensor(srcTensor); + aclDestroyTensor(outputTensor); + + return; + } +}; + +REGISTER_KERNEL(Device::ASCEND, OpType::Reshape, CopyAclnn, + "reshape_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Unsqueeze, CopyAclnn, + "unsqueeze_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Squeeze, CopyAclnn, + "squeeze_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Flatten, CopyAclnn, + "Flatten_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Identity, CopyAclnn, + "Identity_ASCEND_float"); +} // namespace infini diff --git a/src/kernels/ascend/resize.cc b/src/kernels/ascend/resize.cc new file mode 100644 index 00000000..8cf7ce73 --- /dev/null +++ b/src/kernels/ascend/resize.cc 
@@ -0,0 +1,82 @@ +#include "operators/resize.h" +#include "aclnnop/level2/aclnn_resize.h" +#include "ascend/ascend_kernel_without_config.h" +#include "ascend/ascend_runtime.h" + +namespace infini { +class ResizeAclnn : public ASCENDKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); + auto context = dynamic_cast(_context); + + int nDims = op->getInputs(0)->getRank(); + if (nDims > 4) + IT_TODO_HALT(); + + vector scalesData = op->getScales(); + + const char *mode; + switch (op->getMode()) { + case ResizeObj::ECoeffMode::nearest: + mode = "nearest"; + break; + case ResizeObj::ECoeffMode::linear: + mode = "bilinear"; + break; + default: + IT_TODO_HALT(); + } + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto a = op->getInputs(0)->getDims(); + auto aS = op->getInputs(0)->getStride(); + auto c = op->getOutput()->getDims(); + auto cS = op->getOutput()->getStride(); + + std::vector aDim = castTo64(a); + std::vector aStride = castTo64(aS); + std::vector cDim = castTo64(c); + std::vector cStride = castTo64(cS); + auto self = aclCreateTensor( + aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0, + aclFormat::ACL_FORMAT_NCHW, aDim.data(), aDim.size(), aData); + + auto output = aclCreateTensor( + cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0, + aclFormat::ACL_FORMAT_NCHW, cDim.data(), cDim.size(), cData); + + aclFloatArray *scales = nullptr; + scales = aclCreateFloatArray(scalesData.data(), scalesData.size()); + assert(scales != nullptr); + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + auto ret = aclnnResizeGetWorkspaceSize(self, scales, mode, output, + &workspaceSize, &executor); + checkASCENDError(ret); + + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + + ret = aclnnResize(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + + aclDestroyTensor(self); + aclDestroyFloatArray(scales); + aclDestroyTensor(output); + + return; + } +}; + +REGISTER_KERNEL(Device::ASCEND, OpType::Resize, ResizeAclnn, "Resize_ASCEND"); + +} // namespace infini diff --git a/src/kernels/ascend/send.cc b/src/kernels/ascend/send.cc new file mode 100644 index 00000000..20195b3d --- /dev/null +++ b/src/kernels/ascend/send.cc @@ -0,0 +1,45 @@ +#ifdef INFINI_USE_HCCL +#include "operators/send.h" +#include "ascend/ascend_kernel_without_config.h" +#include "ascend/ascend_runtime.h" +#include "ascend/hccl_communicator.h" +#include "hccl/hccl.h" +#include "hccl/hccl_types.h" + +namespace infini { +class SendHCCL : public ASCENDKernelWithoutConfig { + public: + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + void *input = op->getInputs(0)->getRawDataPtr(); + + IT_ASSERT(op->getDType() == DataType::Float32); + int inputCount = + op->getInputs(0)->getBytes() / op->getDType().getSize(); + + HcclComm comm = + dynamic_cast(context->getCommunicator()) + .getHcclComm(); + + uint32_t rank; + + checkHCCLError(HcclGetRankId(comm, &rank)); + + int source = op->getSourceRank(); + int destination = op->getDestinationRank(); + + if (int(rank) == source) { + checkHCCLError(HcclSend(input, uint64_t(inputCount), + HCCL_DATA_TYPE_FP32, uint32_t(destination), + comm, context->ASCENDHandle())); + } + 
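+        // Only the source rank issues HcclSend; every rank then synchronizes the
+        // stream below, presumably so the input buffer can be reused safely once
+        // compute() returns (the HcclRecv path in recv.cc mirrors this).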
checkASCENDError(aclrtSynchronizeStream(context->ASCENDHandle())); + } +}; + +REGISTER_KERNEL(Device::ASCEND, OpType::Send, SendHCCL, "Send_HCCL_ASCEND"); +} // namespace infini + +#endif diff --git a/src/kernels/ascend/slice.cc b/src/kernels/ascend/slice.cc new file mode 100644 index 00000000..35467e3c --- /dev/null +++ b/src/kernels/ascend/slice.cc @@ -0,0 +1,84 @@ +#include "operators/slice.h" +#include "aclnnop/aclnn_slice_v2.h" +#include "ascend/ascend_kernel_without_config.h" +#include "ascend/ascend_runtime.h" + +namespace infini { + +class SliceAclnn : public ASCENDKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + IT_ASSERT(op->getDType() == DataType::Float32); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto starts_32 = op->getStarts(); + auto ends_32 = op->getEnds(); + auto steps_32 = op->getSteps(); + + auto a = op->getInputs(0)->getDims(); + auto aS = op->getInputs(0)->getStride(); + auto c = op->getOutput()->getDims(); + auto cS = op->getOutput()->getStride(); + + std::vector aDim = castTo64(a); + std::vector aStride = castTo64(aS); + std::vector cDim = castTo64(c); + std::vector cStride = castTo64(cS); + + std::vector starts_64 = castTo64(starts_32); + std::vector ends_64 = castTo64(ends_32); + std::vector steps_64 = castTo64(steps_32); + + vector axes_64 = vector(starts_32.size(), 0); + for (int i = 0; i < int(starts_32.size()); i++) { + axes_64[i] = i; + } + + auto inputA = aclCreateTensor( + aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0, + aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData); + auto output = aclCreateTensor( + cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0, + aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData); + aclIntArray *starts = + aclCreateIntArray(starts_64.data(), starts_64.size()); + aclIntArray *ends = aclCreateIntArray(ends_64.data(), ends_64.size()); + aclIntArray *steps = + aclCreateIntArray(steps_64.data(), steps_64.size()); + aclIntArray *axes = aclCreateIntArray(axes_64.data(), axes_64.size()); + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + auto ret = + aclnnSliceV2GetWorkspaceSize(inputA, starts, ends, axes, steps, + output, &workspaceSize, &executor); + checkASCENDError(ret); + + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + + ret = aclnnSliceV2(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + + aclDestroyTensor(inputA); + aclDestroyIntArray(starts); + aclDestroyIntArray(ends); + aclDestroyIntArray(axes); + aclDestroyIntArray(steps); + aclDestroyTensor(output); + + return; + } +}; + +REGISTER_KERNEL(Device::ASCEND, OpType::Slice, SliceAclnn, + "slice_ASCEND_float"); +} // namespace infini diff --git a/src/kernels/ascend/softmax.cc b/src/kernels/ascend/softmax.cc new file mode 100644 index 00000000..15162855 --- /dev/null +++ b/src/kernels/ascend/softmax.cc @@ -0,0 +1,61 @@ +#include "operators/softmax.h" +#include "aclnnop/level2/aclnn_softmax.h" +#include "ascend/ascend_kernel_without_config.h" +#include "ascend/ascend_runtime.h" + +namespace infini { +class SoftmaxAclnn : public ASCENDKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + IT_ASSERT(op->getDType() == 
DataType::Float32); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + int64_t axis = int64_t(op->getAxis()); + + auto a = op->getInputs(0)->getDims(); + auto aS = op->getInputs(0)->getStride(); + auto c = op->getOutput()->getDims(); + auto cS = op->getOutput()->getStride(); + + std::vector aDim = castTo64(a); + std::vector aStride = castTo64(aS); + std::vector cDim = castTo64(c); + std::vector cStride = castTo64(cS); + + auto input = aclCreateTensor( + aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0, + aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData); + auto output = aclCreateTensor( + cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0, + aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData); + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + auto ret = aclnnSoftmaxGetWorkspaceSize(input, axis, output, + &workspaceSize, &executor); + checkASCENDError(ret); + + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + + ret = aclnnSoftmax(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + + aclDestroyTensor(input); + aclDestroyTensor(output); + return; + } +}; + +REGISTER_KERNEL(Device::ASCEND, OpType::Softmax, SoftmaxAclnn, + "softmax_ASCEND_float"); + +} // namespace infini diff --git a/src/kernels/ascend/split.cc b/src/kernels/ascend/split.cc new file mode 100644 index 00000000..a2e531c7 --- /dev/null +++ b/src/kernels/ascend/split.cc @@ -0,0 +1,75 @@ +#include "operators/split.h" +#include "aclnnop/aclnn_split_tensor.h" +#include "ascend/ascend_kernel_without_config.h" +#include "ascend/ascend_runtime.h" + +namespace infini { + +class SplitAclnn : public ASCENDKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + IT_ASSERT(op->getDType() == DataType::Float32); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + + auto a = op->getInputs(0)->getDims(); + auto aS = op->getInputs(0)->getStride(); + std::vector aDim = castTo64(a); + std::vector aStride = castTo64(aS); + + int64_t dim = op->getDim(); + int num = op->numOutputs(); + int dimSize = a.at(op->getDim()); + uint64_t splitSections = dimSize / num; + + auto inputA = aclCreateTensor( + aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0, + aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData); + + std::vector outputsData{}; + for (int i = 0; i < num; ++i) { + auto c = op->getOutput(i)->getDims(); + auto cS = op->getOutput(i)->getStride(); + + std::vector cDim = castTo64(c); + std::vector cStride = castTo64(cS); + + void *const cData = (op->getOutput(i)->getRawDataPtr()); + + aclTensor *tmpTensor = aclCreateTensor( + cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0, + aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData); + + outputsData.push_back(tmpTensor); + } + aclTensorList *tensorList = + aclCreateTensorList(outputsData.data(), outputsData.size()); + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + auto ret = aclnnSplitTensorGetWorkspaceSize( + inputA, splitSections, dim, tensorList, &workspaceSize, &executor); + checkASCENDError(ret); + + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + + ret = aclnnSplitTensor(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + 
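+        // Cleanup note: destroying the aclTensorList below is assumed to also
+        // release the per-output aclTensor handles it holds, which is why only
+        // the list and the input tensor are destroyed explicitly (the concat
+        // kernel follows the same convention for its input list).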
checkASCENDError(ret); + + aclDestroyTensor(inputA); + aclDestroyTensorList(tensorList); + + return; + } +}; + +REGISTER_KERNEL(Device::ASCEND, OpType::Split, SplitAclnn, + "split_ASCEND_float"); +} // namespace infini diff --git a/src/kernels/ascend/transpose.cc b/src/kernels/ascend/transpose.cc new file mode 100644 index 00000000..79d78713 --- /dev/null +++ b/src/kernels/ascend/transpose.cc @@ -0,0 +1,136 @@ +#include "operators/transpose.h" +#include "aclnnop/level2/aclnn_permute.h" +#include "ascend/ascend_kernel_without_config.h" +#include "ascend/ascend_runtime.h" + +namespace infini { + +class PermuteAclnn : public ASCENDKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto a = op->getInputs(0)->getDims(); + auto aS = op->getInputs(0)->getStride(); + auto c = op->getOutput()->getDims(); + auto cS = op->getOutput()->getStride(); + + std::vector aDim = castTo64(a); + std::vector aStride = castTo64(aS); + std::vector cDim = castTo64(c); + std::vector cStride = castTo64(cS); + + auto _permute = op->getPermute(); + std::vector permute = castTo64(_permute); + + auto inputA = aclCreateTensor( + aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0, + aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData); + aclIntArray *dims = aclCreateIntArray(permute.data(), permute.size()); + auto output = aclCreateTensor( + cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0, + aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData); + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + auto ret = aclnnPermuteGetWorkspaceSize(inputA, dims, output, + &workspaceSize, &executor); + checkASCENDError(ret); + + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + + ret = aclnnPermute(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + + aclDestroyTensor(inputA); + aclDestroyIntArray(dims); + aclDestroyTensor(output); + + return; + } +}; + +class DepthToSpaceAclnn : public ASCENDKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto getStride = [](std::vector Dim) { + Shape stride(Dim.size()); + ShapeElem p = 1; + for (auto i = Dim.size(); i > 0; --i) { + stride[i - 1] = p; + p = p * Dim[i - 1]; + } + return stride; + }; + auto reshapeDim = op->getReshapeDim(); + auto reshapeStride = getStride(reshapeDim); + auto transposeDim = op->getTransposeDim(); + auto transposeStride = getStride(transposeDim); + + std::vector aDim = castTo64(reshapeDim); + std::vector aStride = castTo64(reshapeStride); + std::vector cDim = castTo64(transposeDim); + std::vector cStride = castTo64(transposeStride); + + auto mode = op->getMode(); + + std::vector permute; + if (mode == 0) { + permute = {0, 3, 4, 1, 5, 2}; + } else { + permute = {0, 1, 4, 2, 5, 3}; + } + + auto inputA = aclCreateTensor( + aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0, + aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData); + aclIntArray *dims = aclCreateIntArray(permute.data(), permute.size()); 
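+        // DepthToSpace is lowered to reshape + permute: the input is viewed with
+        // the reshape dims/strides computed above, then aclnnPermute applies
+        // {0,3,4,1,5,2} for mode 0 (presumably ONNX "DCR") or {0,1,4,2,5,3}
+        // otherwise (presumably "CRD"), writing into the transposed-dim view.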
+ auto output = aclCreateTensor( + cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0, + aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData); + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + auto ret = aclnnPermuteGetWorkspaceSize(inputA, dims, output, + &workspaceSize, &executor); + checkASCENDError(ret); + + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + + ret = aclnnPermute(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + + aclDestroyTensor(inputA); + aclDestroyIntArray(dims); + aclDestroyTensor(output); + + return; + } +}; + +REGISTER_KERNEL(Device::ASCEND, OpType::Transpose, PermuteAclnn, + "transpose_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::DepthToSpace, DepthToSpaceAclnn, + "DepthToSpace_ASCEND_float"); +} // namespace infini diff --git a/src/kernels/ascend/unary.cc b/src/kernels/ascend/unary.cc new file mode 100644 index 00000000..d1b15a66 --- /dev/null +++ b/src/kernels/ascend/unary.cc @@ -0,0 +1,258 @@ +#include "operators/unary.h" +#include "aclnnop/aclnn_erf.h" +#include "aclnnop/level2/aclnn_abs.h" +#include "aclnnop/level2/aclnn_acos.h" +#include "aclnnop/level2/aclnn_atan.h" +#include "aclnnop/level2/aclnn_ceil.h" +#include "aclnnop/level2/aclnn_cos.h" +#include "aclnnop/level2/aclnn_exp.h" +#include "aclnnop/level2/aclnn_floor.h" +#include "aclnnop/level2/aclnn_gelu.h" +#include "aclnnop/level2/aclnn_hardswish.h" +#include "aclnnop/level2/aclnn_leaky_relu.h" +#include "aclnnop/level2/aclnn_neg.h" +#include "aclnnop/level2/aclnn_reciprocal.h" +#include "aclnnop/level2/aclnn_relu.h" +#include "aclnnop/level2/aclnn_round.h" +#include "aclnnop/level2/aclnn_sigmoid.h" +#include "aclnnop/level2/aclnn_sin.h" +#include "aclnnop/level2/aclnn_sqrt.h" +#include "aclnnop/level2/aclnn_tanh.h" +#include "ascend/ascend_kernel_without_config.h" +#include "ascend/ascend_runtime.h" + +namespace infini { +class ReluAclnn : public ASCENDKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + IT_ASSERT(op->getDType() == DataType::Float32); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto a = op->getInputs(0)->getDims(); + std::vector aDim(a.size(), 1); + for (size_t i = 0; i < a.size(); ++i) { + aDim[i] = int64_t(a[i]); + } + auto aS = op->getInputs(0)->getStride(); + std::vector aStride(aS.size(), 1); + for (size_t i = 0; i < aS.size(); ++i) { + aStride[i] = int64_t(aS[i]); + } + auto c = op->getInputs(0)->getDims(); + std::vector cDim(c.size(), 1); + for (size_t i = 0; i < c.size(); ++i) { + cDim[i] = int64_t(c[i]); + } + auto cS = op->getInputs(0)->getStride(); + std::vector cStride(cS.size(), 1); + for (size_t i = 0; i < cS.size(); ++i) { + cStride[i] = int64_t(cS[i]); + } + + auto input = aclCreateTensor( + aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0, + aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData); + auto output = aclCreateTensor( + cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0, + aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData); + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + auto ret = + aclnnReluGetWorkspaceSize(input, output, &workspaceSize, &executor); + checkASCENDError(ret); + + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = 
context->getWorkspace(workspaceSize); + } + + ret = aclnnRelu(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + + aclDestroyTensor(input); + aclDestroyTensor(output); + + return; + } +}; +class LeakyReluAclnn : public ASCENDKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + IT_ASSERT(op->getDType() == DataType::Float32); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto a = op->getInputs(0)->getDims(); + std::vector aDim(a.size(), 1); + for (size_t i = 0; i < a.size(); ++i) { + aDim[i] = int64_t(a[i]); + } + auto aS = op->getInputs(0)->getStride(); + std::vector aStride(aS.size(), 1); + for (size_t i = 0; i < aS.size(); ++i) { + aStride[i] = int64_t(aS[i]); + } + auto c = op->getInputs(0)->getDims(); + std::vector cDim(c.size(), 1); + for (size_t i = 0; i < c.size(); ++i) { + cDim[i] = int64_t(c[i]); + } + auto cS = op->getInputs(0)->getStride(); + std::vector cStride(cS.size(), 1); + for (size_t i = 0; i < cS.size(); ++i) { + cStride[i] = int64_t(cS[i]); + } + + auto input = aclCreateTensor( + aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0, + aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData); + auto output = aclCreateTensor( + cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0, + aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData); + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + float negativeSlopeValue = op->getAlpha(); + aclScalar *negativeSlope = nullptr; + negativeSlope = + aclCreateScalar(&negativeSlopeValue, aclDataType::ACL_FLOAT); + + auto ret = aclnnLeakyReluGetWorkspaceSize(input, negativeSlope, output, + &workspaceSize, &executor); + checkASCENDError(ret); + + void *workspaceAddr = nullptr; + if (workspaceSize > 0) { + workspaceAddr = context->getWorkspace(workspaceSize); + } + + ret = aclnnLeakyRelu(workspaceAddr, workspaceSize, executor, + context->ASCENDHandle()); + checkASCENDError(ret); + + aclDestroyTensor(input); + aclDestroyScalar(negativeSlope); + aclDestroyTensor(output); + + return; + } +}; + +#define DEFINE_UNARY_Aclnn(prefix) \ + class prefix##Aclnn : public ASCENDKernelWithoutConfig { \ + void compute(const Operator &_op, \ + const RuntimeObj *_context) const override { \ + auto op = as(_op); \ + auto context = dynamic_cast(_context); \ + IT_ASSERT(op->getDType() == DataType::Float32); \ + \ + void *const aData = (op->getInputs(0)->getRawDataPtr()); \ + void *const cData = (op->getOutput()->getRawDataPtr()); \ + \ + auto a = op->getInputs(0)->getDims(); \ + std::vector aDim(a.size(), 1); \ + for (size_t i = 0; i < a.size(); ++i) { \ + aDim[i] = int64_t(a[i]); \ + } \ + auto aS = op->getInputs(0)->getStride(); \ + std::vector aStride(aS.size(), 1); \ + for (size_t i = 0; i < aS.size(); ++i) { \ + aStride[i] = int64_t(aS[i]); \ + } \ + auto c = op->getInputs(0)->getDims(); \ + std::vector cDim(c.size(), 1); \ + for (size_t i = 0; i < c.size(); ++i) { \ + cDim[i] = int64_t(c[i]); \ + } \ + auto cS = op->getInputs(0)->getStride(); \ + std::vector cStride(cS.size(), 1); \ + for (size_t i = 0; i < cS.size(); ++i) { \ + cStride[i] = int64_t(cS[i]); \ + } \ + \ + auto input = aclCreateTensor( \ + aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0, \ + aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData); \ + auto output = aclCreateTensor( \ + cDim.data(), cDim.size(), ACL_FLOAT, 
cStride.data(), 0, \ + aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData); \ + \ + uint64_t workspaceSize = 0; \ + aclOpExecutor *executor; \ + \ + auto ret = aclnn##prefix##GetWorkspaceSize( \ + input, output, &workspaceSize, &executor); \ + checkASCENDError(ret); \ + void *workspaceAddr = nullptr; \ + if (workspaceSize > 0) { \ + workspaceAddr = context->getWorkspace(workspaceSize); \ + } \ + \ + ret = aclnn##prefix(workspaceAddr, workspaceSize, executor, \ + context->ASCENDHandle()); \ + checkASCENDError(ret); \ + \ + aclDestroyTensor(input); \ + aclDestroyTensor(output); \ + \ + return; \ + } \ + }; + +DEFINE_UNARY_Aclnn(Abs); +DEFINE_UNARY_Aclnn(Sigmoid); +DEFINE_UNARY_Aclnn(Hardswish); +DEFINE_UNARY_Aclnn(Gelu); + +DEFINE_UNARY_Aclnn(Tanh); +DEFINE_UNARY_Aclnn(Sin); +DEFINE_UNARY_Aclnn(Cos); +DEFINE_UNARY_Aclnn(Acos); +DEFINE_UNARY_Aclnn(Atan); + +DEFINE_UNARY_Aclnn(Ceil); +DEFINE_UNARY_Aclnn(Floor); +DEFINE_UNARY_Aclnn(Exp); +DEFINE_UNARY_Aclnn(Neg); +DEFINE_UNARY_Aclnn(Reciprocal); +DEFINE_UNARY_Aclnn(Sqrt); +DEFINE_UNARY_Aclnn(Round); + +DEFINE_UNARY_Aclnn(Erf); + +REGISTER_KERNEL(Device::ASCEND, OpType::Relu, ReluAclnn, "relu_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::LeakyRelu, LeakyReluAclnn, + "leakyrelu_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Abs, AbsAclnn, "abs_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Sigmoid, SigmoidAclnn, + "sigmoid_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::HardSwish, HardswishAclnn, + "hardswish_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Tanh, TanhAclnn, "tanh_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Gelu, GeluAclnn, "gelu_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Sin, SinAclnn, "sin_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Cos, CosAclnn, "cos_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Acos, AcosAclnn, "acos_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Atan, AtanAclnn, "atan_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Neg, NegAclnn, "neg_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Ceil, CeilAclnn, "ceil_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Floor, FloorAclnn, + "floor_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Exp, ExpAclnn, "exp_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Reciprocal, ReciprocalAclnn, + "reciprocal_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Sqrt, SqrtAclnn, "sqrt_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Round, RoundAclnn, + "round_ASCEND_float"); +REGISTER_KERNEL(Device::ASCEND, OpType::Erf, ErfAclnn, "erf_ASCEND_float"); +} // namespace infini diff --git a/src/operators/instance_norm.cc b/src/operators/instance_norm.cc new file mode 100644 index 00000000..6761f835 --- /dev/null +++ b/src/operators/instance_norm.cc @@ -0,0 +1,48 @@ +#include "operators/instance_norm.h" +#include "utils/operator_utils.h" + +namespace infini { +InstanceNormObj::InstanceNormObj(GraphObj *graph, Tensor input, Tensor output, + Tensor scale, Tensor bias, float eps) + : OperatorObj(OpType::InstanceNormalization, TensorVec{input, scale, bias}, + {output}), + eps(eps) { + + IT_ASSERT(checkValid(graph)); +} + +optional> InstanceNormObj::inferShape(const TensorVec &inputs) { + return {{inputs[0]->getDims()}}; +} + +vector InstanceNormObj::inferDataType(const TensorVec &inputs) const { + + return {inputs[0]->getDType()}; +} + +std::string InstanceNormObj::toString() const { + std::ostringstream os; + os << 
"InstanceNormalization[" << getGuid() << "]"; + os << "("; + os << vecToString(inputs[0]->getDims()) << ","; + os << "eps=" << eps << ","; + os << "input=" << inputs[0]->getGuid() << ","; + os << "scale=" << inputs[1]->getGuid() << ","; + os << "bias=" << inputs[2]->getGuid() << ","; + os << "output="; + for (auto output : outputs) + os << output->getGuid() << ","; + return os.str(); +} + +vector InstanceNormObj::getWorkloadVector() const { + vector ret = inputs[0]->getDims(); + ret.emplace(ret.begin(), type.underlying()); + return ret; +} + +vector InstanceNormObj::getOpAttrVector() const { + return {type.underlying()}; +} + +} // namespace infini diff --git a/src/operators/resize.cc b/src/operators/resize.cc index 0f0b08fe..9fbc7108 100644 --- a/src/operators/resize.cc +++ b/src/operators/resize.cc @@ -33,10 +33,11 @@ void ResizeObj::init(const Tensor &input, const Tensor &sizes, // inputs of operator must not be nullptr, due to the check in // OperatorObj::OperatorObj if (nullptr != sizes) { - IT_ASSERT(isResizeBySizes()); + setGivenSizes(true); inputs.push_back(sizes); InitBySizes(input, sizes, axes); } else if (nullptr != scales) { + setGivenSizes(false); inputs.push_back(scales); InitByScales(input, scales, axes); } @@ -101,8 +102,9 @@ void ResizeObj::InitBySizes(Tensor input, Tensor sizes, // copy sizes data to host. IT_ASSERT(sizes->getDataBlob() != nullptr); Runtime runtime = NativeCpuRuntimeObj::getInstance(); - std::shared_ptr dataObj((int *)runtime->alloc(sizes->getBytes()), - [&](int *p) { runtime->dealloc(p); }); + std::shared_ptr dataObj( + (int64_t *)runtime->alloc(sizes->getBytes()), + [&](int64_t *p) { runtime->dealloc(p); }); auto data = dataObj.get(); sizes->getRuntime()->copyBlobToCPU( (void *)data, sizes->getRawDataPtr(), sizes->getBytes()); @@ -193,7 +195,7 @@ vector ResizeObj::inferDataType(const TensorVec &inputs) const { } if (isResizeBySizes()) { auto sizes = inputs[1]; - IT_ASSERT(sizes && sizes->getDType() == DataType::UInt32); + IT_ASSERT(sizes && sizes->getDType() == DataType::Int64); } else { auto scales = inputs[1]; IT_ASSERT(scales && scales->getDType() == DataType::Float32); diff --git a/src/operators/unary.cc b/src/operators/unary.cc index 6dbf73c3..ff2075b1 100644 --- a/src/operators/unary.cc +++ b/src/operators/unary.cc @@ -300,7 +300,8 @@ std::string LeakyReluObj::toString() const { os << "("; os << vecToString(inputs[0]->getDims()) << ","; os << "input=" << inputs[0]->getGuid() << ","; - os << "output=" << outputs[0]->getGuid() << ")"; + os << "output=" << outputs[0]->getGuid() << ","; + os << "alpha=" << alphaValue << ")"; return os.str(); } diff --git a/src/utils/operator_utils.cc b/src/utils/operator_utils.cc index 502336a8..4d3c1395 100644 --- a/src/utils/operator_utils.cc +++ b/src/utils/operator_utils.cc @@ -104,6 +104,8 @@ std::string device_to_str(Device device) { return "INTELCPU"; case Device::KUNLUN: return "KUNLUN"; + case Device::ASCEND: + return "ASCEND"; default: IT_TODO_HALT(); } diff --git a/test/kernels/ascend/test_ascend_all_gather.cc b/test/kernels/ascend/test_ascend_all_gather.cc new file mode 100644 index 00000000..a872f566 --- /dev/null +++ b/test/kernels/ascend/test_ascend_all_gather.cc @@ -0,0 +1,55 @@ +#ifdef INFINI_USE_HCCL +#include "ascend/ascend_runtime.h" +#include "ascend/hccl_communicator.h" +#include "core/graph.h" +#include "core/runtime.h" +#include "operators/all_gather.h" +#include "test.h" +#include + +static int WORLD_SIZE = 2; + +namespace infini { + +void allGather(const string taskName, int deviceID, vector 
data, + vector<vector<float>> ans) { + // Create Runtimes and initiate communication + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + Runtime ascendRuntime = make_ref<ASCENDRuntimeObj>(deviceID); + ascendRuntime->initComm(taskName, WORLD_SIZE, deviceID); + // Create Graph and insert AllGather operation + Graph g = make_ref<GraphObj>(ascendRuntime); + auto input = + g->addTensor(Shape{static_cast<int>(data.size())}, DataType::Float32); + auto op = g->addOp<AllGatherObj>(input, std::nullopt, WORLD_SIZE); + // Copy data from CPU to NPU + g->dataMalloc(); + input->copyin(data); + // Run operation + ascendRuntime->run(g); + // Copy output from NPU to CPU + for (int i = 0; i < WORLD_SIZE; ++i) { + auto result = op->getOutputs()[i]->clone(cpuRuntime); + EXPECT_TRUE(result->equalData(ans[i])); + } +} + +TEST(ASCEND_AllGather, run) { + aclInit(nullptr); + + vector<float> data[4] = {{2., 3.}, {5., 6.}, {7., 8.}, {9., 10.}}; + vector<vector<float>> ans = {{2., 3.}, {5., 6.}, {7., 8.}, {9., 10.}}; + + std::vector<std::thread> threads; + for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) { + threads.emplace_back(allGather, "test_all_gather", gpu, data[gpu], ans); + } + for (auto &thread : threads) { + thread.join(); + } + + aclFinalize(); +} + +} // namespace infini +#endif diff --git a/test/kernels/ascend/test_ascend_all_reduce.cc b/test/kernels/ascend/test_ascend_all_reduce.cc new file mode 100644 index 00000000..9aae7597 --- /dev/null +++ b/test/kernels/ascend/test_ascend_all_reduce.cc @@ -0,0 +1,120 @@ +#ifdef INFINI_USE_HCCL +#include "ascend/ascend_runtime.h" +#include "ascend/hccl_communicator.h" +#include "core/graph.h" +#include "core/runtime.h" +#include "operators/all_reduce.h" +#include "test.h" +#include <thread> + +static int WORLD_SIZE = 2; + +namespace infini { + +template <typename OperatorObj> +void allReduce(const string taskName, int deviceID, vector<float> data, + vector<float> ans) { + // Create Runtimes and initiate communication + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + Runtime ascendRuntime = make_ref<ASCENDRuntimeObj>(deviceID); + ascendRuntime->initComm(taskName, WORLD_SIZE, deviceID); + // Create Graph and insert allReduce operation + Graph g = make_ref<GraphObj>(ascendRuntime); + auto input = + g->addTensor(Shape{static_cast<int>(data.size())}, DataType::Float32); + auto op = g->addOp<OperatorObj>(input, nullptr); + // Copy data from CPU to NPU + g->dataMalloc(); + input->copyin(data); + // Run operation + ascendRuntime->run(g); + // Copy output from NPU to CPU + auto result = op->getOutput()->clone(cpuRuntime); + + EXPECT_TRUE(result->equalData(ans)); +} + +// TEST(ASCEND_AllReduce, sum) { +// aclInit(nullptr); +// vector<float> data[2] = {{2., 3.}, {5., 6.}}; +// vector<float> ans = {7., 9.}; +// +// std::vector<std::thread> threads; +// for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) { +// threads.emplace_back(allReduce<AllReduceSumObj>, +// "test_allreduce_sum", +// gpu, data[gpu], ans); +// } +// for (auto &thread : threads) { +// thread.join(); +// } +// aclFinalize(); +// } + +// TEST(ASCEND_AllReduce, prod) { +// aclInit(nullptr); +// vector<float> data[2] = {{2., 3.}, {5., 6.}}; +// vector<float> ans = {10., 18.}; +// +// std::vector<std::thread> threads; +// for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) { +// threads.emplace_back(allReduce<AllReduceProdObj>, +// "test_allreduce_prod", +// gpu, data[gpu], ans); +// } +// for (auto &thread : threads) { +// thread.join(); +// } +// aclFinalize(); +// } + +// TEST(ASCEND_AllReduce, min) { +// aclInit(nullptr); +// vector<float> data[2] = {{2., 3.}, {5., 6.}}; +// vector<float> ans = {2., 3.}; +// +// std::vector<std::thread> threads; +// for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) { +// threads.emplace_back(allReduce<AllReduceMinObj>, +// "test_allreduce_min", +// gpu, data[gpu], ans); +// } +// for (auto &thread :
threads) { +// thread.join(); +// } +// aclFinalize(); +// } + +TEST(ASCEND_AllReduce, max) { + aclInit(nullptr); + vector data[2] = {{2., 3.}, {5., 6.}}; + vector ans = {5., 6.}; + + std::vector threads; + for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) { + threads.emplace_back(allReduce, "test_allreduce_max", + gpu, data[gpu], ans); + } + for (auto &thread : threads) { + thread.join(); + } + aclFinalize(); +} + +// TEST(ASCEND_AllReduce, avg) { +// vector data[2] = {{2., 3.}, {5., 6.}}; +// vector ans = {3.5, 4.5}; +// +// std::vector threads; +// for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) { +// threads.emplace_back(allReduce, +// "test_allreduce_avg", +// gpu, data[gpu], ans); +// } +// for (auto &thread : threads) { +// thread.join(); +// } +// } + +} // namespace infini +#endif diff --git a/test/kernels/ascend/test_ascend_batch_norm.cc b/test/kernels/ascend/test_ascend_batch_norm.cc new file mode 100644 index 00000000..112a900d --- /dev/null +++ b/test/kernels/ascend/test_ascend_batch_norm.cc @@ -0,0 +1,58 @@ +#include "ascend/ascend_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/batch_norm.h" + +#include "test.h" + +namespace infini { + +TEST(ascend_BatchNorm, run) { + aclInit(nullptr); + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto npuRuntime = make_ref(); + + // Build cpu graph + Graph gCpu = make_ref(cpuRuntime); + auto iCpu = gCpu->addTensor(Shape{1, 3, 2, 2}, DataType::Float32); + auto meanCpu = gCpu->addTensor(Shape{3}, DataType::Float32); + auto varCpu = gCpu->addTensor(Shape{3}, DataType::Float32); + auto scaleCpu = gCpu->addTensor(Shape{3}, DataType::Float32); + auto biasCpu = gCpu->addTensor(Shape{3}, DataType::Float32); + + // Build input data on CPU + gCpu->dataMalloc(); + iCpu->setData(IncrementalGenerator()); + meanCpu->copyin(vector{1, 6, 9}); + varCpu->copyin(vector{4, 1, 9}); + scaleCpu->setData(OneGenerator()); + biasCpu->setData(ZeroGenerator()); + + // Build CUDA graph + Graph g = make_ref(npuRuntime); + auto i = g->cloneTensor(iCpu); + auto mean = g->cloneTensor(meanCpu); + auto var = g->cloneTensor(varCpu); + auto scale = g->cloneTensor(scaleCpu); + auto bias = g->cloneTensor(biasCpu); + auto op = + g->addOp(i, nullptr, mean, var, scale, bias, 0.9, 0); + + // allocate CUDA memory + g->dataMalloc(); + + // Execute on CUDA + npuRuntime->run(g); + + // clone CUDA output to CPU + auto o = op->getOutput(); + auto ocpu = o->clone(cpuRuntime); + + // check results on CPU + EXPECT_TRUE(ocpu->equalData(vector{ + -0.5, 0, 0.5, 1, -2, -1, 0, 1, -0.333333, 0, 0.333333, 0.666667})); + + aclFinalize(); +} +} // namespace infini diff --git a/test/kernels/ascend/test_ascend_concat.cc b/test/kernels/ascend/test_ascend_concat.cc new file mode 100644 index 00000000..6fd29fa8 --- /dev/null +++ b/test/kernels/ascend/test_ascend_concat.cc @@ -0,0 +1,65 @@ +#include "ascend/ascend_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/concat.h" + +#include "test.h" + +namespace infini { + +template +void testConcat(const std::function &generator, + const Shape &shape) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto npuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu1 = + make_ref(shape, DataType::Float32, cpuRuntime); + inputCpu1->dataMalloc(); + inputCpu1->setData(generator); + Tensor inputCpu2 = + make_ref(shape, DataType::Float32, cpuRuntime); + inputCpu2->dataMalloc(); + 
inputCpu2->setData(generator); + Tensor inputCpu3 = + make_ref(shape, DataType::Float32, cpuRuntime); + inputCpu3->dataMalloc(); + inputCpu3->setData(generator); + + // NPU + Graph npuGraph = make_ref(npuRuntime); + auto inputNpu1 = npuGraph->cloneTensor(inputCpu1); + auto inputNpu2 = npuGraph->cloneTensor(inputCpu2); + auto inputNpu3 = npuGraph->cloneTensor(inputCpu3); + auto npuOp = npuGraph->addOp(TensorVec{inputNpu1, inputNpu2, inputNpu3}, + nullptr, 2); + npuGraph->dataMalloc(); + inputNpu1->setData(generator); + inputNpu2->setData(generator); + inputNpu3->setData(generator); + npuRuntime->run(npuGraph); + auto outputNpu = npuOp->getOutput(); + auto outputNpu2Cpu = outputNpu->clone(cpuRuntime); + + // Check + inputCpu1->print(); + inputCpu1->printData(); + inputCpu2->print(); + inputCpu2->printData(); + inputCpu3->print(); + inputCpu3->printData(); + outputNpu2Cpu->print(); + outputNpu2Cpu->printData(); + EXPECT_TRUE(1); +} + +TEST(ascend_Concat, run) { + aclInit(nullptr); + testConcat(IncrementalGenerator(), Shape{1, 2, 2, 3}); + aclFinalize(); +} + +} // namespace infini diff --git a/test/kernels/ascend/test_ascend_conv.cc b/test/kernels/ascend/test_ascend_conv.cc new file mode 100644 index 00000000..b346f1d1 --- /dev/null +++ b/test/kernels/ascend/test_ascend_conv.cc @@ -0,0 +1,60 @@ +#include "ascend/ascend_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/conv.h" + +#include "test.h" + +namespace infini { + +template +void testConv(const std::function &generatorA, + const std::function &generatorB, + const Shape &shapeA, const Shape &shapeB) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto npuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu1 = + make_ref(shapeA, DataType::Float32, cpuRuntime); + Tensor inputCpu2 = + make_ref(shapeB, DataType::Float32, cpuRuntime); + // NPU + Graph npuGraph = make_ref(npuRuntime); + auto inputNpu1 = npuGraph->cloneTensor(inputCpu1); + auto inputNpu2 = npuGraph->cloneTensor(inputCpu2); + auto npuOp = + npuGraph->addOp(inputNpu1, inputNpu2, nullptr, 1, 1, 1, 1, 1, 1); + npuGraph->dataMalloc(); + inputNpu1->setData(generatorA); + inputNpu2->setData(generatorB); + npuRuntime->run(npuGraph); + auto outputNpu = npuOp->getOutput(); + auto outputNpu2Cpu = outputNpu->clone(cpuRuntime); + // CPU + Graph cpuGraph = make_ref(cpuRuntime); + cpuGraph->addTensor(inputCpu1); + cpuGraph->addTensor(inputCpu2); + auto cpuOp = + cpuGraph->addOp(inputCpu1, inputCpu2, nullptr, 1, 1, 1, 1, 1, 1); + cpuGraph->dataMalloc(); + inputCpu1->setData(generatorA); + inputCpu2->setData(generatorB); + cpuRuntime->run(cpuGraph); + auto outputCpu = cpuOp->getOutput(); + // Check + // outputCpu->printData(); + // outputNpu2Cpu->printData(); + EXPECT_TRUE(outputCpu->equalData(outputNpu2Cpu, 1e-3)); +} + +TEST(ascend_Conv, run) { + aclInit(nullptr); + testConv(IncrementalGenerator(), IncrementalGenerator(), + Shape{1, 3, 128, 128}, Shape{2, 3, 3, 3}); + aclFinalize(); +} + +} // namespace infini diff --git a/test/kernels/ascend/test_ascend_conv_transpose_2d.cc b/test/kernels/ascend/test_ascend_conv_transpose_2d.cc new file mode 100644 index 00000000..dc635b0c --- /dev/null +++ b/test/kernels/ascend/test_ascend_conv_transpose_2d.cc @@ -0,0 +1,58 @@ +#include "ascend/ascend_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/conv.h" + +#include "test.h" + +namespace infini { + +void testConvTransposedAclnn( + const 
std::function<void(void *, size_t, DataType)> &generator, + std::vector<float> ansVec) { + const auto &[N, C, H, W, F, R, S] = tuple{1, 1, 2, 2, 1, 4, 4}; + const int stride = 1, padding = 0, dilation = 1; + // Construct Runtime and graph for CPU and NPU + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPU runtime is a singleton + Graph gCpu = make_ref<GraphObj>(cpu); + Runtime npu = make_ref<ASCENDRuntimeObj>(); + Graph gNpu = make_ref<GraphObj>(npu); + // Set input data on CPU in a CPU Graph + Tensor i0Cpu = gCpu->addTensor({N, F, H, H}, DataType::Float32); + Tensor w0Cpu = gCpu->addTensor({F, C, R, S}, DataType::Float32); + // Malloc data for all tensors in a graph. Do we need implicit allocation? + gCpu->dataMalloc(); + i0Cpu->setData(generator); + w0Cpu->setData(generator); + + // Copy input tensors from CPU to NPU + Tensor i0Npu = gNpu->cloneTensor(i0Cpu); + Tensor w0Npu = gNpu->cloneTensor(w0Cpu); + // Build NPU graph + auto conv = gNpu->addOp<ConvTransposed2dObj>(i0Npu, w0Npu, nullptr, padding, + padding, stride, stride, + dilation, dilation); + gNpu->dataMalloc(); + i0Npu->setData(generator); + w0Npu->setData(generator); + // Execute on NPU + npu->run(gNpu); + // copy output from NPU to CPU + auto o0Cpu = gCpu->cloneTensor(conv->getOutput()); + // check results on CPU + o0Cpu->printData(); + EXPECT_TRUE(o0Cpu->equalData(ansVec)); +} + +TEST(ascend_ConvTransposed, run) { + aclInit(nullptr); + testConvTransposedAclnn( + IncrementalGenerator(), + std::vector<float>{0., 0., 1., 2., 3., 0., 6., 12., 18., + 16., 8., 30., 36., 42., 32., 16., 54., 60., + 66., 48., 24., 62., 67., 72., 45.}); + aclFinalize(); +} + +} // namespace infini diff --git a/test/kernels/ascend/test_ascend_element_wise.cc b/test/kernels/ascend/test_ascend_element_wise.cc new file mode 100644 index 00000000..d55fba7b --- /dev/null +++ b/test/kernels/ascend/test_ascend_element_wise.cc @@ -0,0 +1,69 @@ +#include "ascend/ascend_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/element_wise.h" + +#include "test.h" + +namespace infini { + +template <class T> +void testElementWise( + const std::function<void(void *, size_t, DataType)> &generator, + const Shape &shape0, const Shape &shape) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto npuRuntime = make_ref<ASCENDRuntimeObj>(); + + // Build input data on CPU + Tensor inputCpu1 = + make_ref<TensorObj>(shape0, DataType::Float32, cpuRuntime); + Tensor inputCpu2 = + make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime); + inputCpu1->dataMalloc(); + inputCpu2->dataMalloc(); + inputCpu1->setData(generator); + inputCpu2->setData(generator); + + inputCpu1->print(); + inputCpu1->printData(); + inputCpu2->print(); + inputCpu2->printData(); + // NPU + Graph npuGraph = make_ref<GraphObj>(npuRuntime); + auto inputNpu1 = npuGraph->cloneTensor(inputCpu1); + auto inputNpu2 = npuGraph->cloneTensor(inputCpu2); + auto npuOp = npuGraph->addOp<T>(inputNpu1, inputNpu2, nullptr); + npuGraph->dataMalloc(); + inputNpu1->setData(generator); + inputNpu2->setData(generator); + npuRuntime->run(npuGraph); + auto outputNpu = npuOp->getOutput(); + auto outputNpu2Cpu = outputNpu->clone(cpuRuntime); + + // Check + + outputNpu2Cpu->print(); + outputNpu2Cpu->printData(); + EXPECT_TRUE(1); +} + +TEST(ascend_ElementWise, run) { + aclInit(nullptr); + testElementWise(IncrementalGenerator(), Shape{1, 2, 2, 3}, + Shape{1, 2, 2, 3}); + testElementWise(IncrementalGenerator(), Shape{1, 2, 2, 3}, + Shape{1, 2, 2, 3}); + testElementWise(IncrementalGenerator(), Shape{1, 1, 48, 48}, + Shape{1, 1, 1, 1}); + testElementWise(IncrementalGenerator(), Shape{1, 2, 2, 3}, + Shape{1, 2, 2, 3}); +
testElementWise(IncrementalGenerator(), Shape{1}, + Shape{1, 2, 2, 3}); + testElementWise(IncrementalGenerator(), Shape{1, 2, 2, 3}, + Shape{1, 2, 2, 3}); + aclFinalize(); +} + +} // namespace infini diff --git a/test/kernels/ascend/test_ascend_erf.cc b/test/kernels/ascend/test_ascend_erf.cc new file mode 100644 index 00000000..79f22e81 --- /dev/null +++ b/test/kernels/ascend/test_ascend_erf.cc @@ -0,0 +1,42 @@ +#include "ascend/ascend_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/unary.h" + +#include "test.h" + +namespace infini { + +template +void testErf(const std::function &generator, + const Shape &shape) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto ascendRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = make_ref(shape, DataType::Float32, cpuRuntime); + inputCpu->dataMalloc(); + inputCpu->setData(generator); + + // Npu + Graph npuGraph = make_ref(ascendRuntime); + auto inputNpu = npuGraph->cloneTensor(inputCpu); + auto npuOp = npuGraph->addOp(inputNpu, nullptr); + npuGraph->dataMalloc(); + ascendRuntime->run(npuGraph); + auto outputNpu = npuOp->getOutput(); + auto outputNpu2Cpu = outputNpu->clone(cpuRuntime); + inputCpu->printData(); + outputNpu2Cpu->printData(); + EXPECT_TRUE(1); +} + +TEST(ascend_Erf, run) { + aclInit(nullptr); + testErf(IncrementalGenerator(), Shape{1, 2, 2, 3}); + aclFinalize(); +} + +} // namespace infini diff --git a/test/kernels/ascend/test_ascend_gather.cc b/test/kernels/ascend/test_ascend_gather.cc new file mode 100644 index 00000000..d159cd47 --- /dev/null +++ b/test/kernels/ascend/test_ascend_gather.cc @@ -0,0 +1,71 @@ +#include "ascend/ascend_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/gather.h" + +#include "test.h" + +namespace infini { + +TEST(ascend_Gather, run) { + aclInit(nullptr); + { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto npuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = + make_ref(Shape{3, 3}, DataType::Float32, cpuRuntime); + Tensor indexCpu = + make_ref(Shape{1}, DataType::Int64, cpuRuntime); + + // NPU + Graph npuGraph = make_ref(npuRuntime); + auto inputNpu = npuGraph->cloneTensor(inputCpu); + auto indexNpu = npuGraph->cloneTensor(indexCpu); + auto npuOp = npuGraph->addOp(inputNpu, indexNpu, nullptr, 1); + npuGraph->dataMalloc(); + inputNpu->setData(IncrementalGenerator()); + indexNpu->copyin(vector{0}); + npuRuntime->run(npuGraph); + auto outputNpu = npuOp->getOutput(); + auto outputNpu2Cpu = outputNpu->clone(cpuRuntime); + + outputNpu2Cpu->print(); + outputNpu2Cpu->printData(); + // Check + EXPECT_TRUE(outputNpu2Cpu->equalData(vector{0, 3, 6})); + } + { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto npuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = + make_ref(Shape{3, 2}, DataType::Float32, cpuRuntime); + Tensor indexCpu = + make_ref(Shape{2, 2}, DataType::Int64, cpuRuntime); + + // NPU + Graph npuGraph = make_ref(npuRuntime); + auto inputNpu = npuGraph->cloneTensor(inputCpu); + auto indexNpu = npuGraph->cloneTensor(indexCpu); + auto npuOp = npuGraph->addOp(inputNpu, indexNpu, nullptr, 0); + npuGraph->dataMalloc(); + inputNpu->copyin(std::vector{1.0, 1.2, 2.3, 3.4, 4.5, 5.7}); + indexNpu->copyin(vector{0, 1, 1, 2}); + npuRuntime->run(npuGraph); + auto outputNpu = npuOp->getOutput(); + auto outputNpu2Cpu = outputNpu->clone(cpuRuntime); 
+ + // Check + EXPECT_TRUE(outputNpu2Cpu->equalData( + vector{1.0, 1.2, 2.3, 3.4, 2.3, 3.4, 4.5, 5.7})); + } + aclFinalize(); +} + +} // namespace infini diff --git a/test/kernels/ascend/test_ascend_instancenorm.cc b/test/kernels/ascend/test_ascend_instancenorm.cc new file mode 100644 index 00000000..3e968589 --- /dev/null +++ b/test/kernels/ascend/test_ascend_instancenorm.cc @@ -0,0 +1,74 @@ +#include "ascend/ascend_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/instance_norm.h" + +#include "test.h" + +namespace infini { + +void test_instancenormFp32(const Shape &inputShape, + const vector &inputData, + const Shape &scaleShape, + const vector &scaleData, float eps, + const vector &ExpectData, + const Shape &biasShape, + const vector &biasData) { + + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto bias = gCpu->addTensor(biasShape, DataType::Float32); + auto input = gCpu->addTensor(inputShape, DataType::Float32); + auto scale = gCpu->addTensor(scaleShape, DataType::Float32); + auto ExpectOutput = gCpu->addTensor(inputShape, DataType::Float32); + gCpu->dataMalloc(); + bias->copyin(biasData); // + // bias->printData(); + input->copyin(inputData); + scale->copyin(scaleData); // + ExpectOutput->copyin(ExpectData); + auto ascendRuntime = make_ref(); + Graph gAscend = make_ref(ascendRuntime); + auto biasNpu = gAscend->cloneTensor(bias); + auto inputNpu = gAscend->cloneTensor(input); + auto scaleNpu = gAscend->cloneTensor(scale); + // gCpu->cloneTensor(biasNpu)->printData(); + auto op = + gAscend->addOp(inputNpu, nullptr, scaleNpu, biasNpu, + eps); // InstancenormObj + gAscend->dataMalloc(); + biasNpu->copyin(biasData); + // gCpu->cloneTensor(biasNpu)->printData(); + inputNpu->copyin(inputData); + scaleNpu->copyin(scaleData); + ascendRuntime->run(gAscend); + + auto oCpu = gCpu->cloneTensor(op->getOutput()); // move Data from npu to cpu + oCpu->printData(); //->printData + EXPECT_TRUE(oCpu->equalData(ExpectOutput, 1e-4)); +} + +TEST(CUDA_InstancenormFp32, run) { + aclInit(nullptr); + test_instancenormFp32( + Shape{2, 3, 2, 3}, + vector{0., 1., 2., 3., 4., 5., 6., 7., 8., + 9., 10., 11., 12., 13., 14., 15., 16., 17., + 18., 19., 20., 21., 22., 23., 24., 25., 26., + 27., 28., 29., 30., 31., 32., 33., 34., 35.}, + Shape{3}, vector{0.3, 0.2, 0.5}, 1e-5, + vector{ + -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678, + -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678, + -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678, + -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678, + -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678, + -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678}, + Shape{3}, vector{0, 0, 0}); + + aclFinalize(); +} // python output + +} // namespace infini diff --git a/test/kernels/ascend/test_ascend_layernorm.cc b/test/kernels/ascend/test_ascend_layernorm.cc new file mode 100644 index 00000000..8d84b867 --- /dev/null +++ b/test/kernels/ascend/test_ascend_layernorm.cc @@ -0,0 +1,143 @@ +#include "ascend/ascend_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/layer_norm.h" + +#include "test.h" + +namespace infini { + +void test_layernormFp32( + const Shape &inputShape, const vector &inputData, + const Shape &scaleShape, const vector &scaleData, float eps, + int axis, int stash_type, const vector &ExpectData, + const 
std::optional &bShape = std::nullopt, + const std::optional> &biasData = std::nullopt) { + + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + if (bShape.has_value() && biasData.has_value()) { + Shape biasShape = *bShape; + + auto bias = gCpu->addTensor(biasShape, DataType::Float32); + auto input = gCpu->addTensor(inputShape, DataType::Float32); + auto scale = gCpu->addTensor(scaleShape, DataType::Float32); + auto ExpectOutput = gCpu->addTensor(inputShape, DataType::Float32); + + gCpu->dataMalloc(); + bias->copyin(*biasData); + // bias->printData(); + input->copyin(inputData); + scale->copyin(scaleData); + ExpectOutput->copyin(ExpectData); + + auto ascendRuntime = make_ref(); + Graph gAscend = make_ref(ascendRuntime); + auto biasNpu = gAscend->cloneTensor(bias); + auto inputNpu = gAscend->cloneTensor(input); + auto scaleNpu = gAscend->cloneTensor(scale); + + auto op = + gAscend->addOp(inputNpu, scaleNpu, nullptr, biasNpu, + eps, axis, stash_type); // LayernormObj + gAscend->dataMalloc(); + biasNpu->copyin(*biasData); + // gCpu->cloneTensor(biasNpu)->printData(); + inputNpu->copyin(inputData); + scaleNpu->copyin(scaleData); + ascendRuntime->run(gAscend); + + auto oCpu = + gCpu->cloneTensor(op->getOutput()); // move Data from npu to cpu + oCpu->printData(); //->printData + + EXPECT_TRUE(oCpu->equalData(ExpectOutput, 1e-4)); + } else { + + auto input = gCpu->addTensor(inputShape, DataType::Float32); + auto scale = gCpu->addTensor(scaleShape, DataType::Float32); + auto ExpectOutput = gCpu->addTensor(inputShape, DataType::Float32); + + gCpu->dataMalloc(); + input->copyin(inputData); + scale->copyin(scaleData); // + ExpectOutput->copyin(ExpectData); + + auto ascendRuntime = make_ref(); + Graph gAscend = make_ref(ascendRuntime); + + auto inputNpu = gAscend->cloneTensor(input); + auto scaleNpu = gAscend->cloneTensor(scale); + auto op = + gAscend->addOp(inputNpu, scaleNpu, nullptr, nullptr, + eps, axis, stash_type); // LayernormObj + gAscend->dataMalloc(); + + inputNpu->copyin(inputData); + scaleNpu->copyin(scaleData); + ascendRuntime->run(gAscend); + + auto oCpu = + gCpu->cloneTensor(op->getOutput()); // move Data from npu to cpu + oCpu->printData(); //->printData + + EXPECT_TRUE(oCpu->equalData(ExpectOutput, 1e-4)); + // EXPECT_TRUE(oCpu->equalData(ExpectData)); + } +} + +TEST(CUDA_LayernormFp32, run) { + aclInit(nullptr); + test_layernormFp32( + Shape{2, 3, 2, 3}, + vector{0., 1., 2., 3., 4., 5., 6., 7., 8., + 9., 10., 11., 12., 13., 14., 15., 16., 17., + 18., 19., 20., 21., 22., 23., 24., 25., 26., + 27., 28., 29., 30., 31., 32., 33., 34., 35.}, + Shape{3}, vector{0.3, 0.2, 0.5}, 1e-5, 3, 1, + vector{ + -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678, + -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678, + -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678, + -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678, + -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678, + -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678}, + Shape{3}, vector{0, 0, 0}); + test_layernormFp32( + Shape{2, 3, 2, 3}, + vector{0., 1., 2., 3., 4., 5., 6., 7., 8., + 9., 10., 11., 12., 13., 14., 15., 16., 17., + 18., 19., 20., 21., 22., 23., 24., 25., 26., + 27., 28., 29., 30., 31., 32., 33., 34., 35.}, + Shape{3}, vector{0.3, 0.2, 0.5}, 1e-5, 3, 1, + vector{ + -0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679, + -0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 
1.1123679, + -0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679, + -0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679, + -0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679, + -0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679}, + Shape{3}, vector{0.3, 0.2, 0.5}); + test_layernormFp32( + Shape{2, 3, 2, 3}, + vector{0., 1., 2., 3., 4., 5., 6., 7., 8., + 9., 10., 11., 12., 13., 14., 15., 16., 17., + 18., 19., 20., 21., 22., 23., 24., 25., 26., + 27., 28., 29., 30., 31., 32., 33., 34., 35.}, + Shape{3}, vector{0.3, 0.2, 0.5}, 1e-5, 3, 1, + vector{-0.3674207, 0.0000000, 0.6123678, -0.3674207, + 0.0000000, 0.6123678, -0.3674207, 0.0000000, + 0.6123678, -0.3674207, 0.0000000, 0.6123678, + -0.3674207, 0.0000000, 0.6123678, -0.3674207, + 0.0000000, 0.6123678, -0.3674207, 0.0000000, + 0.6123678, -0.3674207, 0.0000000, 0.6123678, + -0.3674207, 0.0000000, 0.6123678, -0.3674207, + 0.0000000, 0.6123678, -0.3674207, 0.0000000, + 0.6123678, -0.3674207, 0.0000000, 0.6123678}); + + aclFinalize(); +} // python output + +} // namespace infini diff --git a/test/kernels/ascend/test_ascend_matmul.cc b/test/kernels/ascend/test_ascend_matmul.cc new file mode 100644 index 00000000..247e6fc5 --- /dev/null +++ b/test/kernels/ascend/test_ascend_matmul.cc @@ -0,0 +1,59 @@ +#include "ascend/ascend_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/matmul.h" + +#include "test.h" + +namespace infini { + +template +void testMatmul(const std::function &generatorA, + const std::function &generatorB, + bool transA, bool transB, const Shape &shapeA, + const Shape &shapeB) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto npuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu1 = + make_ref(shapeA, DataType::Float32, cpuRuntime); + Tensor inputCpu2 = + make_ref(shapeB, DataType::Float32, cpuRuntime); + + // NPU + Graph npuGraph = make_ref(npuRuntime); + auto inputNpu1 = npuGraph->cloneTensor(inputCpu1); + auto inputNpu2 = npuGraph->cloneTensor(inputCpu2); + auto npuOp = npuGraph->addOp(inputNpu1, inputNpu2, nullptr); + npuGraph->dataMalloc(); + inputNpu1->setData(generatorA); + inputNpu2->setData(generatorB); + npuRuntime->run(npuGraph); + auto outputNpu = npuOp->getOutput(); + auto outputNpu2Cpu = outputNpu->clone(cpuRuntime); + // CPU + Graph cpuGraph = make_ref(cpuRuntime); + auto cpuOp = cpuGraph->addOp(inputCpu1, inputCpu2, nullptr); + cpuGraph->addTensor(inputCpu1); + cpuGraph->addTensor(inputCpu2); + cpuGraph->dataMalloc(); + inputCpu1->setData(generatorA); + inputCpu2->setData(generatorB); + cpuRuntime->run(cpuGraph); + auto outputCpu = cpuOp->getOutput(); + + // Check + EXPECT_TRUE(outputCpu->equalData(outputNpu2Cpu)); +} + +TEST(ascend_Matmul, run) { + aclInit(nullptr); + testMatmul(IncrementalGenerator(), IncrementalGenerator(), false, + false, Shape{1, 2, 3}, Shape{1, 3, 4}); + aclFinalize(); +} + +} // namespace infini diff --git a/test/kernels/ascend/test_ascend_pad.cc b/test/kernels/ascend/test_ascend_pad.cc new file mode 100644 index 00000000..c99f8ecd --- /dev/null +++ b/test/kernels/ascend/test_ascend_pad.cc @@ -0,0 +1,49 @@ +#include "ascend/ascend_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/pad.h" + +#include "test.h" + +namespace infini { + +template +void testPad(const std::function &generator, + const Shape &shape) { + // Runtime + Runtime cpuRuntime = 
NativeCpuRuntimeObj::getInstance(); + auto npuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = make_ref(shape, DataType::Float32, cpuRuntime); + inputCpu->dataMalloc(); + inputCpu->setData(generator); + + // NPU + Graph npuGraph = make_ref(npuRuntime); + auto inputNpu = npuGraph->cloneTensor(inputCpu); + auto npuOp = npuGraph->addOp(inputNpu, nullptr, vector{1, 1, 1, 1}, + vector{0, 3}); + + npuGraph->dataMalloc(); + inputNpu->setData(generator); + std::cout << npuOp->toString() << std::endl; + npuRuntime->run(npuGraph); + auto outputNpu = npuOp->getOutput(); + auto outputNpu2Cpu = outputNpu->clone(cpuRuntime); + // Check + inputCpu->print(); + inputCpu->printData(); + outputNpu2Cpu->print(); + outputNpu2Cpu->printData(); + EXPECT_TRUE(1); +} + +TEST(ascend_Pad, run) { + aclInit(nullptr); + testPad(IncrementalGenerator(), Shape{1, 1, 2, 3}); + aclFinalize(); +} + +} // namespace infini diff --git a/test/kernels/ascend/test_ascend_pooling.cc b/test/kernels/ascend/test_ascend_pooling.cc new file mode 100644 index 00000000..b7f579b2 --- /dev/null +++ b/test/kernels/ascend/test_ascend_pooling.cc @@ -0,0 +1,48 @@ +#include "ascend/ascend_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/pooling.h" + +#include "test.h" + +namespace infini { + +template {}, + int>::type = 0> +void testPooling(const std::function &generator, + const Shape &shape) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto npuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = make_ref(shape, DataType::Float32, cpuRuntime); + inputCpu->dataMalloc(); + inputCpu->setData(generator); + + // GPU + Graph npuGraph = make_ref(npuRuntime); + auto inputNpu = npuGraph->cloneTensor(inputCpu); + auto npuOp = + npuGraph->addOp(inputNpu, nullptr, 3, 3, 1, 1, 1, 1, 2, 2, 0); + // npuGraph->addOp(inputNpu, nullptr, 2, 2, 1, 1, 0, 0, 1, 1, 0); + npuGraph->dataMalloc(); + inputNpu->setData(generator); + npuRuntime->run(npuGraph); + + auto outputNpu = npuOp->getOutput(); + auto outputNpu2Cpu = outputNpu->clone(cpuRuntime); + inputCpu->printData(); + outputNpu2Cpu->printData(); + EXPECT_TRUE(1); +} + +TEST(cnnl_Pooling, run) { + aclInit(nullptr); + testPooling(IncrementalGenerator(), Shape{1, 3, 5, 5}); + testPooling(IncrementalGenerator(), Shape{1, 2, 5, 5}); + aclFinalize(); +} + +} // namespace infini diff --git a/test/kernels/ascend/test_ascend_reduce.cc b/test/kernels/ascend/test_ascend_reduce.cc new file mode 100644 index 00000000..c992dcf3 --- /dev/null +++ b/test/kernels/ascend/test_ascend_reduce.cc @@ -0,0 +1,84 @@ +#include "ascend/ascend_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/reduce.h" + +#include "test.h" + +namespace infini { + +template +void test_reduce(const Shape &shape, const vector &data, + const optional> &axes, bool keepDims, + const vector &ExpectData) { + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto npuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = make_ref(shape, DataType::Float32, cpuRuntime); + + // Build NPU graph + Graph npuGraph = make_ref(npuRuntime); + auto inputNpu = npuGraph->cloneTensor(inputCpu); + auto op = npuGraph->addOp(inputNpu, nullptr, axes, keepDims); + + // allocate NPU memory + npuGraph->dataMalloc(); + inputNpu->copyin(data); + + // Execute on NPU + npuRuntime->run(npuGraph); + + // clone NPU output to CPU + auto outputNpu = op->getOutput(); + auto 
outputNpu2Cpu = outputNpu->clone(cpuRuntime); + + // check results on CPU + EXPECT_TRUE(outputNpu2Cpu->equalData(ExpectData)); +} + +TEST(ascend_ReduceMean, run) { + aclInit(nullptr); + test_reduce( + Shape{3, 2, 2}, vector{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2}, + std::nullopt, true, vector{18.25}); + test_reduce( + Shape{1, 3, 2, 2, 1}, + vector{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2}, std::nullopt, + false, vector{18.25}); + + test_reduce( + Shape{2, 3, 2, 2}, + vector{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, + vector{1, 2}, false, vector{5, 6, 17, 18}); + test_reduce( + Shape{2, 3, 2, 2, 1}, + vector{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, + vector{1, 2}, true, vector{5, 6, 17, 18}); + aclFinalize(); +} + +TEST(ascend_ReduceSum, run) { + test_reduce(Shape{3, 2, 2}, + vector{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + std::nullopt, true, vector{12}); + test_reduce(Shape{1, 3, 2, 2, 1}, + vector{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + std::nullopt, false, vector{12}); + + test_reduce( + Shape{2, 3, 2, 2}, + vector{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, + vector{1, 2}, false, vector{30, 36, 102, 108}); + test_reduce( + Shape{2, 3, 2, 2, 1}, + vector{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, + vector{1, 2}, true, vector{30, 36, 102, 108}); +} + +} // namespace infini diff --git a/test/kernels/ascend/test_ascend_reshape.cc b/test/kernels/ascend/test_ascend_reshape.cc new file mode 100644 index 00000000..6e873db0 --- /dev/null +++ b/test/kernels/ascend/test_ascend_reshape.cc @@ -0,0 +1,84 @@ +#include "ascend/ascend_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/reshape.h" +#include "operators/squeeze.h" +#include "operators/unsqueeze.h" + +#include "test.h" + +namespace infini { + +template +void testReshape(const std::function &generator, + const Shape &shape, const Shape &outputShape) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto npuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = make_ref(shape, DataType::Float32, cpuRuntime); + inputCpu->dataMalloc(); + inputCpu->setData(generator); + + // NPU + Graph npuGraph = make_ref(npuRuntime); + auto inputNpu = npuGraph->cloneTensor(inputCpu); + auto npuOp = npuGraph->addOp(inputNpu, nullptr, outputShape); + npuGraph->dataMalloc(); + inputNpu->setData(generator); + npuRuntime->run(npuGraph); + auto outputNpu = npuOp->getOutput(); + auto outputNpu2Cpu = outputNpu->clone(cpuRuntime); + + // Check + inputCpu->print(); + inputCpu->printData(); + outputNpu2Cpu->print(); + outputNpu2Cpu->printData(); + EXPECT_TRUE(inputCpu->equalData(outputNpu2Cpu, 1e-3)); +} + +void testFlatten(const std::function &generator, + const Shape &shape, int axis) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto npuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = make_ref(shape, DataType::Float32, cpuRuntime); + inputCpu->dataMalloc(); + inputCpu->setData(generator); + + // NPU + Graph npuGraph = make_ref(npuRuntime); + auto inputNpu = npuGraph->cloneTensor(inputCpu); + auto npuOp = npuGraph->addOp(inputNpu, nullptr, axis); + npuGraph->dataMalloc(); + inputNpu->setData(generator); + npuRuntime->run(npuGraph); + auto outputNpu = npuOp->getOutput(); + auto outputNpu2Cpu = outputNpu->clone(cpuRuntime); + + // Check 
+ inputCpu->print(); + inputCpu->printData(); + outputNpu2Cpu->print(); + outputNpu2Cpu->printData(); + EXPECT_TRUE(inputCpu->equalData(outputNpu2Cpu, 1e-3)); +} + +TEST(ascend_Unary, run) { + aclInit(nullptr); + testReshape(IncrementalGenerator(), Shape{1, 2, 2, 3}, + Shape{1, 2, 6}); + testReshape(IncrementalGenerator(), Shape{1, 2, 2, 3}, + Shape{0}); + testReshape(IncrementalGenerator(), Shape{1, 2, 2, 3}, + Shape{4}); + testFlatten(IncrementalGenerator(), Shape{1, 2, 2, 3}, 2); + aclFinalize(); +} + +} // namespace infini diff --git a/test/kernels/ascend/test_ascend_resize.cc b/test/kernels/ascend/test_ascend_resize.cc new file mode 100644 index 00000000..af6a1ea3 --- /dev/null +++ b/test/kernels/ascend/test_ascend_resize.cc @@ -0,0 +1,69 @@ +#include "ascend/ascend_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/resize.h" +#include "test.h" + +namespace infini { + +TEST(Resize, Ascend_downsample_scales_nearest) { + aclInit(nullptr); + { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32); + auto scales = gCpu->addTensor({4}, DataType::Float32); + gCpu->dataMalloc(); + input->copyin(vector{1, 2, 3, 4, 5, 6, 7, 8}); + scales->copyin(vector{1, 1, 0.6, 0.6}); + + auto ascendRuntime = make_ref(); + Graph gNpu = make_ref(ascendRuntime); + + auto inputNpu = gNpu->cloneTensor(input); + auto scalesNpu = gNpu->cloneTensor(scales); + auto op = gNpu->addOp(inputNpu, nullptr, std::nullopt, + nullptr, scalesNpu, nullptr); + gNpu->dataMalloc(); + inputNpu->copyin(vector{1, 2, 3, 4, 5, 6, 7, 8}); + scalesNpu->copyin(vector{1, 1, 0.6, 0.6}); + ascendRuntime->run(gNpu); + + // copy output from NPU to CPU + auto oCpu = gCpu->cloneTensor(op->getOutput(0)); + EXPECT_TRUE(oCpu->equalData(vector{1, 3})); + } + { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32); + auto scales = gCpu->addTensor({4}, DataType::Float32); + gCpu->dataMalloc(); + input->copyin(vector{1, 2, 3, 4}); + scales->copyin(vector{1, 1, 2, 3}); + + auto ascendRuntime = make_ref(); + Graph gNpu = make_ref(ascendRuntime); + + auto inputNpu = gNpu->cloneTensor(input); + auto scalesNpu = gNpu->cloneTensor(scales); + auto op = gNpu->addOp(inputNpu, nullptr, std::nullopt, + nullptr, scalesNpu, nullptr); + gNpu->dataMalloc(); + inputNpu->copyin(vector{1, 2, 3, 4}); + scalesNpu->copyin(vector{1, 1, 2, 3}); + ascendRuntime->run(gNpu); + + // copy output from NPU to CPU + auto oCpu = gCpu->cloneTensor(op->getOutput(0)); + EXPECT_TRUE( + oCpu->equalData(vector{1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, + 3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4})); + } + aclFinalize(); +} + +} // namespace infini diff --git a/test/kernels/ascend/test_ascend_sendrecv.cc b/test/kernels/ascend/test_ascend_sendrecv.cc new file mode 100644 index 00000000..7c9703b9 --- /dev/null +++ b/test/kernels/ascend/test_ascend_sendrecv.cc @@ -0,0 +1,75 @@ +#ifdef INFINI_USE_HCCL +#include "ascend/ascend_runtime.h" +#include "ascend/hccl_communicator.h" +#include "core/graph.h" +#include "core/runtime.h" +#include "operators/recv.h" +#include "operators/send.h" +#include "test.h" +#include + +namespace infini { + +void sendrecv(const string taskName, int deviceID, vector data, + const Shape &dataShape, int WORLD_SIZE, int source, + int destination) { + // Create Runtimes and initiate communication + Runtime cpuRuntime = 
NativeCpuRuntimeObj::getInstance(); + Runtime ascnedRuntime = make_ref(deviceID); + ascnedRuntime->initComm(taskName, WORLD_SIZE, deviceID); + + if (deviceID == source) { + Graph gSend = make_ref(ascnedRuntime); + auto input = gSend->addTensor(Shape{static_cast(data.size())}, + DataType::Float32); + auto opSend = + gSend->addOp(input, source, destination, nullptr); + + // Copy data from CPU to GPU + gSend->dataMalloc(); + input->copyin(data); + ascnedRuntime->run(gSend); + } + + // ---------------- + + if (deviceID == destination) { + Graph gRecv = make_ref(ascnedRuntime); + int outputType = 1; + // auto input = + // gRecv->addTensor(Shape{static_cast(data.size())},DataType::Float32); + auto opRecv = gRecv->addOp(nullptr, source, destination, + dataShape, outputType, nullptr); + gRecv->dataMalloc(); + ascnedRuntime->run(gRecv); + + auto result = opRecv->getOutput()->clone(cpuRuntime); + EXPECT_TRUE(result->equalData(data)); + } +} + +TEST(ASCEND_SendRecv1, run) { + // Only 1 device gets data. Every rank should have the same data after + // sendrecv. + aclInit(nullptr); + { + vector data = {2., 3., 5., 6.}; + + int WORLD_SIZE = 4; + int source = 0; + int destination = 2; + std::vector threads; + for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) { + threads.emplace_back(sendrecv, "test_sendrecv", gpu, data, + Shape{2, 2}, WORLD_SIZE, source, destination); + } + + for (auto &thread : threads) { + thread.join(); + } + } + aclFinalize(); +} + +} // namespace infini +#endif diff --git a/test/kernels/ascend/test_ascend_slice.cc b/test/kernels/ascend/test_ascend_slice.cc new file mode 100644 index 00000000..df8b3a41 --- /dev/null +++ b/test/kernels/ascend/test_ascend_slice.cc @@ -0,0 +1,41 @@ +#include "ascend/ascend_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/slice.h" + +#include "test.h" + +namespace infini { + +TEST(ascend_Unary, run) { + aclInit(nullptr); + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto npuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = + make_ref(Shape{3, 2, 1, 5}, DataType::Float32, cpuRuntime); + // inputCpu->dataMalloc(); + // inputCpu->setData(IncrementalGenerator()); + + // NPU + Graph npuGraph = make_ref(npuRuntime); + auto inputNpu = npuGraph->cloneTensor(inputCpu); + auto npuOp = npuGraph->addOp(inputNpu, nullptr, vector{1, 1}, + vector{2, 5}, vector{0, 3}, + std::nullopt); + npuGraph->dataMalloc(); + inputNpu->setData(IncrementalGenerator()); + npuRuntime->run(npuGraph); + auto outputNpu = npuOp->getOutput(); + auto outputNpu2Cpu = outputNpu->clone(cpuRuntime); + + // Check + EXPECT_TRUE(outputNpu2Cpu->equalData( + vector{11, 12, 13, 14, 16, 17, 18, 19})); + aclFinalize(); +} + +} // namespace infini diff --git a/test/kernels/ascend/test_ascend_softmax.cc b/test/kernels/ascend/test_ascend_softmax.cc new file mode 100644 index 00000000..40b2f611 --- /dev/null +++ b/test/kernels/ascend/test_ascend_softmax.cc @@ -0,0 +1,61 @@ +#include "ascend/ascend_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/softmax.h" + +#include "test.h" + +namespace infini { + +template +void testSoftmax(const std::function &generator, + const Shape &shape, int axis, vector Out) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto npuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu1 = + make_ref(shape, DataType::Float32, cpuRuntime); + inputCpu1->dataMalloc(); + // 
inputCpu1->setData(generator); + + // NPU + Graph npuGraph = make_ref(npuRuntime); + auto inputNpu1 = npuGraph->cloneTensor(inputCpu1); + auto npuOp = npuGraph->addOp(inputNpu1, nullptr, axis); + npuGraph->dataMalloc(); + inputNpu1->setData(generator); + npuRuntime->run(npuGraph); + auto outputNpu = npuOp->getOutput(); + auto outputNpu2Cpu = outputNpu->clone(cpuRuntime); + + // Check + EXPECT_TRUE(outputNpu2Cpu->equalData(Out)); +} + +TEST(ascend_ElementWise, run) { + aclInit(nullptr); + testSoftmax( + IncrementalGenerator(), Shape{2, 2, 2, 2}, 1, + vector{0.0179862, 0.0179862, 0.0179862, 0.0179862, 0.9820138, + 0.9820138, 0.9820138, 0.9820138, 0.0179862, 0.0179862, + 0.0179862, 0.0179862, 0.9820138, 0.9820138, 0.9820138, + 0.9820138}); + testSoftmax( + IncrementalGenerator(), Shape{2, 2, 2, 2}, 2, + vector{0.1192029, 0.1192029, 0.8807971, 0.8807971, 0.1192029, + 0.1192029, 0.8807971, 0.8807971, 0.1192029, 0.1192029, + 0.8807971, 0.8807971, 0.1192029, 0.1192029, 0.8807971, + 0.8807971}); + testSoftmax( + IncrementalGenerator(), Shape{2, 2, 2, 2}, 3, + vector{0.2689414, 0.7310586, 0.2689414, 0.7310586, 0.2689414, + 0.7310586, 0.2689414, 0.7310586, 0.2689414, 0.7310586, + 0.2689414, 0.7310586, 0.2689414, 0.7310586, 0.2689414, + 0.7310586}); + aclFinalize(); +} + +} // namespace infini diff --git a/test/kernels/ascend/test_ascend_split.cc b/test/kernels/ascend/test_ascend_split.cc new file mode 100644 index 00000000..f1912d63 --- /dev/null +++ b/test/kernels/ascend/test_ascend_split.cc @@ -0,0 +1,50 @@ +#include "ascend/ascend_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/split.h" + +#include "test.h" + +namespace infini { + +template +void testSplit(const std::function &generator, + const Shape &shape) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto npuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = make_ref(shape, DataType::Float32, cpuRuntime); + inputCpu->dataMalloc(); + inputCpu->setData(generator); + // GPU + Graph npuGraph = make_ref(npuRuntime); + auto inputNpu = npuGraph->cloneTensor(inputCpu); + auto gpuOp = npuGraph->addOp(inputNpu, std::nullopt, 3, 3); + npuGraph->dataMalloc(); + inputNpu->setData(generator); + npuRuntime->run(npuGraph); + auto o0Cpu = gpuOp->getOutput(0)->clone(cpuRuntime); + auto o1Cpu = gpuOp->getOutput(1)->clone(cpuRuntime); + auto o2Cpu = gpuOp->getOutput(2)->clone(cpuRuntime); + // Check + inputCpu->print(); + inputCpu->printData(); + o0Cpu->print(); + o0Cpu->printData(); + o1Cpu->print(); + o1Cpu->printData(); + o2Cpu->print(); + o2Cpu->printData(); + EXPECT_TRUE(1); +} + +TEST(ascend_Split, run) { + aclInit(nullptr); + testSplit(IncrementalGenerator(), Shape{1, 2, 2, 3}); + aclFinalize(); +} + +} // namespace infini diff --git a/test/kernels/ascend/test_ascend_transpose.cc b/test/kernels/ascend/test_ascend_transpose.cc new file mode 100644 index 00000000..37a4377b --- /dev/null +++ b/test/kernels/ascend/test_ascend_transpose.cc @@ -0,0 +1,49 @@ +#include "ascend/ascend_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/transpose.h" + +#include "test.h" + +namespace infini { + +template +void testTranspose( + const std::function &generator, + const Shape &shape, const Shape &permute) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto npuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = make_ref(shape, DataType::Float32, 
cpuRuntime); + inputCpu->dataMalloc(); + inputCpu->setData(generator); + + // NPU + Graph npuGraph = make_ref(npuRuntime); + auto inputNpu = npuGraph->cloneTensor(inputCpu); + auto npuOp = npuGraph->addOp(inputNpu, nullptr, permute); + npuGraph->dataMalloc(); + inputNpu->setData(generator); + npuRuntime->run(npuGraph); + auto outputNpu = npuOp->getOutput(); + auto outputNpu2Cpu = outputNpu->clone(cpuRuntime); + + // Check + inputCpu->print(); + inputCpu->printData(); + outputNpu2Cpu->print(); + outputNpu2Cpu->printData(); + EXPECT_TRUE(1); +} + +TEST(ascend_Unary, run) { + aclInit(nullptr); + testTranspose(IncrementalGenerator(), Shape{1, 1, 2, 3}, + vector{0, 1, 3, 2}); + aclFinalize(); +} + +} // namespace infini diff --git a/test/kernels/ascend/test_ascend_unary.cc b/test/kernels/ascend/test_ascend_unary.cc new file mode 100644 index 00000000..93dede4f --- /dev/null +++ b/test/kernels/ascend/test_ascend_unary.cc @@ -0,0 +1,151 @@ +#include "ascend/ascend_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/unary.h" + +#include "test.h" + +namespace infini { + +template +void testUnary(const std::function &generator, + const Shape &shape) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto npuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = make_ref(shape, DataType::Float32, cpuRuntime); + + // NPU + Graph npuGraph = make_ref(npuRuntime); + auto inputNpu = npuGraph->cloneTensor(inputCpu); + auto npuOp = npuGraph->addOp(inputNpu, nullptr); + npuGraph->dataMalloc(); + inputNpu->setData(generator); + npuRuntime->run(npuGraph); + auto outputNpu = npuOp->getOutput(); + auto outputNpu2Cpu = outputNpu->clone(cpuRuntime); + // CPU + Graph cpuGraph = make_ref(cpuRuntime); + auto cpuOp = cpuGraph->addOp(inputCpu, nullptr); + cpuGraph->addTensor(inputCpu); + cpuGraph->dataMalloc(); + inputCpu->setData(generator); + cpuRuntime->run(cpuGraph); + auto outputCpu = cpuOp->getOutput(); + // Check + EXPECT_TRUE(outputCpu->equalData(outputNpu2Cpu, 1e-3)); +} + +template +void testUnarywithExpectData( + const std::function &generator, + const Shape &shape, const vector &ExpectData) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(cpuRuntime); + auto npuRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = make_ref(shape, DataType::Float32, cpuRuntime); + auto ExpectOutput = gCpu->addTensor(shape, DataType::Float32); + gCpu->dataMalloc(); + ExpectOutput->copyin(ExpectData); + + // NPU + Graph npuGraph = make_ref(npuRuntime); + auto inputNpu = npuGraph->cloneTensor(inputCpu); + auto npuOp = npuGraph->addOp(inputNpu, nullptr); + npuGraph->dataMalloc(); + inputNpu->setData(generator); + npuRuntime->run(npuGraph); + auto outputNpu = npuOp->getOutput(); + auto outputNpu2Cpu = outputNpu->clone(cpuRuntime); + // CPU + // Graph cpuGraph = make_ref(cpuRuntime); + // auto cpuOp = cpuGraph->addOp(inputCpu, nullptr); + // cpuGraph->addTensor(inputCpu); + // cpuGraph->dataMalloc(); + // inputCpu->setData(generator); + // cpuRuntime->run(cpuGraph); + // auto outputCpu = cpuOp->getOutput(); + // Check + outputNpu2Cpu->printData(); + EXPECT_TRUE(outputNpu2Cpu->equalData(ExpectOutput, 1e-4)); +} + +void testLeakyRelu(const Shape &shape, const vector &inputData, + const vector &ExpectData, float alpha) { + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); 
+ + auto input = gCpu->addTensor(shape, DataType::Float32); + + gCpu->dataMalloc(); + + input->copyin(inputData); + auto npuRuntime = make_ref(); + Graph npuGraph = make_ref(npuRuntime); + // NPU + + auto inputNpu = npuGraph->cloneTensor(input); + auto npuOp = npuGraph->addOp(inputNpu, nullptr, alpha); + npuGraph->dataMalloc(); + inputNpu->copyin(inputData); + npuRuntime->run(npuGraph); + auto outputNpu = npuOp->getOutput(); + auto outputNpu2Cpu = outputNpu->clone(cpuRuntime); + + // Check + EXPECT_TRUE(outputNpu2Cpu->equalData(ExpectData)); +} + +TEST(ascend_Unary, run) { + aclInit(nullptr); + testLeakyRelu(Shape{1, 2, 2, 3}, + vector{-6, -5, -4, -3, -2, -1, 1, 2, 3, 4, 5, 6}, + vector{-0.0600, -0.0500, -0.0400, -0.0300, -0.0200, + -0.0100, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, + 6.0000}, + 0.01); + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + + testUnarywithExpectData(IncrementalGenerator(), Shape{1, 2, 2, 3}, + vector{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, + 6.0, 7.0, 8.0, 9.0, 10.0, + 11.0}); + testUnarywithExpectData(IncrementalGenerator(), Shape{1, 2, 2, 3}, + vector{0.0, 1.0, 2.0, 3.0, 4.0, + 5.0, 6.0, 7.0, 8.0, 9.0, + 10.0, 11.0}); + testUnarywithExpectData( + IncrementalGenerator(), Shape{1, 2, 2, 3}, + vector{1.0, 2.71828, 7.38906, 20.0855, 54.5981, 148.413, 403.429, + 1096.63, 2980.96, 8103.08, 22026.5, 59874.1}); + testUnarywithExpectData( + IncrementalGenerator(), Shape{1, 2, 2, 3}, + vector{std::numeric_limits::infinity(), 1, 0.5, 0.333333, + 0.25, 0.2, 0.166667, 0.142857, 0.125, 0.111111, 0.1, + 0.0909091}); + testUnarywithExpectData(IncrementalGenerator(), Shape{1, 2, 2, 3}, + vector{0.0, 1.0, 2.0, 3.0, 4.0, + 5.0, 6.0, 7.0, 8.0, 9.0, + 10.0, 11.0}); + aclFinalize(); +} + +} // namespace infini diff --git a/test/kernels/cuda/test_cuda_resize.cc b/test/kernels/cuda/test_cuda_resize.cc index 57c3aaf5..55721085 100644 --- a/test/kernels/cuda/test_cuda_resize.cc +++ b/test/kernels/cuda/test_cuda_resize.cc @@ -11,10 +11,10 @@ TEST(Resize, Cuda_downsample_sizes_nearest) { Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32); - auto sizes = gCpu->addTensor({4}, DataType::UInt32); + auto sizes = gCpu->addTensor({4}, DataType::Int64); gCpu->dataMalloc(); input->copyin(vector{1, 2, 3, 4, 5, 6, 7, 8}); - sizes->copyin(vector{1, 1, 1, 3}); + sizes->copyin(vector{1, 1, 1, 3}); auto cudaRuntime = make_ref(); Graph gCuda = make_ref(cudaRuntime); @@ -26,7 +26,7 @@ TEST(Resize, Cuda_downsample_sizes_nearest) { ResizeObj::EKeepAspectRatioPolicy::stretch); gCuda->dataMalloc(); inputCuda->copyin(vector{1, 2, 3, 4, 5, 6, 7, 8}); - sizesCuda->copyin(vector{1, 1, 1, 3}); + sizesCuda->copyin(vector{1, 1, 1, 3}); cudaRuntime->run(gCuda); @@ -40,10 +40,10 @@ TEST(Resize, Cuda_upsample_sizes_nearest_notlarger) { Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32); - auto sizes = gCpu->addTensor({2}, 
+    auto sizes = gCpu->addTensor({2}, DataType::Int64);
     gCpu->dataMalloc();
     input->copyin(vector<float>{1, 2, 3, 4});
-    sizes->copyin(vector<uint32_t>{7, 8});
+    sizes->copyin(vector<int64_t>{7, 8});
 
     auto cudaRuntime = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cudaRuntime);
@@ -57,7 +57,7 @@ TEST(Resize, Cuda_upsample_sizes_nearest_notlarger) {
         ResizeObj::ECoordinateTransMode::halfPixel);
     gCuda->dataMalloc();
     inputCuda->copyin(vector<float>{1, 2, 3, 4});
-    sizesCuda->copyin(vector<uint32_t>{7, 8});
+    sizesCuda->copyin(vector<int64_t>{7, 8});
     cudaRuntime->run(gCuda);
 
     // copy output from CUDA to CPU
@@ -73,10 +73,10 @@ TEST(Resize, Cuda_upsample_sizes_nearest_notsmaller) {
     Graph gCpu = make_ref<GraphObj>(runtime);
 
     auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32);
-    auto sizes = gCpu->addTensor({2}, DataType::UInt32);
+    auto sizes = gCpu->addTensor({2}, DataType::Int64);
     gCpu->dataMalloc();
     input->copyin(vector<float>{1, 2, 3, 4});
-    sizes->copyin(vector<uint32_t>{7, 8});
+    sizes->copyin(vector<int64_t>{7, 8});
 
     auto cudaRuntime = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cudaRuntime);
@@ -90,7 +90,7 @@ TEST(Resize, Cuda_upsample_sizes_nearest_notsmaller) {
         ResizeObj::ECoordinateTransMode::halfPixel);
     gCuda->dataMalloc();
     inputCuda->copyin(vector<float>{1, 2, 3, 4});
-    sizesCuda->copyin(vector<uint32_t>{7, 8});
+    sizesCuda->copyin(vector<int64_t>{7, 8});
     cudaRuntime->run(gCuda);
 
     // copy output from CUDA to CPU
@@ -106,11 +106,11 @@ TEST(Resize, Cuda_upsample_sizes_nearest_ceil_half_pixel) {
     Graph gCpu = make_ref<GraphObj>(runtime);
 
     auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
-    auto sizes = gCpu->addTensor({4}, DataType::UInt32);
+    auto sizes = gCpu->addTensor({4}, DataType::Int64);
     gCpu->dataMalloc();
     input->copyin(
         vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizes->copyin(vector<uint32_t>{1, 1, 8, 8});
+    sizes->copyin(vector<int64_t>{1, 1, 8, 8});
 
     auto cudaRuntime = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cudaRuntime);
@@ -125,7 +125,7 @@ TEST(Resize, Cuda_upsample_sizes_nearest_ceil_half_pixel) {
     gCuda->dataMalloc();
     inputCuda->copyin(
         vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizesCuda->copyin(vector<uint32_t>{1, 1, 8, 8});
+    sizesCuda->copyin(vector<int64_t>{1, 1, 8, 8});
     cudaRuntime->run(gCuda);
 
     // copy output from CUDA to CPU
@@ -143,11 +143,11 @@ TEST(Resize, Cuda_upsample_sizes_nearest_floor_align_corners) {
     Graph gCpu = make_ref<GraphObj>(runtime);
 
     auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
-    auto sizes = gCpu->addTensor({2}, DataType::UInt32);
+    auto sizes = gCpu->addTensor({2}, DataType::Int64);
     gCpu->dataMalloc();
     input->copyin(
         vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizes->copyin(vector<uint32_t>{8, 8});
+    sizes->copyin(vector<int64_t>{8, 8});
 
     auto cudaRuntime = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cudaRuntime);
@@ -162,7 +162,7 @@ TEST(Resize, Cuda_upsample_sizes_nearest_floor_align_corners) {
     gCuda->dataMalloc();
     inputCuda->copyin(
         vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizesCuda->copyin(vector<uint32_t>{8, 8});
+    sizesCuda->copyin(vector<int64_t>{8, 8});
     cudaRuntime->run(gCuda);
 
     // copy output from CUDA to CPU
@@ -180,11 +180,11 @@ TEST(Resize, Cuda_upsample_sizes_nearest_round_prefer_ceil_asymmetri) {
     Graph gCpu = make_ref<GraphObj>(runtime);
 
     auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
-    auto sizes = gCpu->addTensor({4}, DataType::UInt32);
+    auto sizes = gCpu->addTensor({4}, DataType::Int64);
     gCpu->dataMalloc();
     input->copyin(
         vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizes->copyin(vector<uint32_t>{1, 1, 8, 8});
+    sizes->copyin(vector<int64_t>{1, 1, 8, 8});
 
     auto cudaRuntime = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cudaRuntime);
@@ -199,7 +199,7 @@ TEST(Resize, Cuda_upsample_sizes_nearest_round_prefer_ceil_asymmetri) {
     gCuda->dataMalloc();
     inputCuda->copyin(
         vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizesCuda->copyin(vector<uint32_t>{1, 1, 8, 8});
+    sizesCuda->copyin(vector<int64_t>{1, 1, 8, 8});
 
     cudaRuntime->run(gCuda);
 
@@ -421,11 +421,11 @@ TEST(Resize, Cuda_downsample_sizes_linear_pytorchhalfpixel) {
     Graph gCpu = make_ref<GraphObj>(runtime);
 
     auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
-    auto sizes = gCpu->addTensor({4}, DataType::UInt32);
+    auto sizes = gCpu->addTensor({4}, DataType::Int64);
     gCpu->dataMalloc();
     input->copyin(
         vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizes->copyin(vector<uint32_t>{1, 1, 3, 1});
+    sizes->copyin(vector<int64_t>{1, 1, 3, 1});
 
     auto cudaRuntime = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cudaRuntime);
@@ -440,7 +440,7 @@ TEST(Resize, Cuda_downsample_sizes_linear_pytorchhalfpixel) {
     gCuda->dataMalloc();
     inputCuda->copyin(
         vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizesCuda->copyin(vector<uint32_t>{1, 1, 3, 1});
+    sizesCuda->copyin(vector<int64_t>{1, 1, 3, 1});
     cudaRuntime->run(gCuda);
 
     // copy output from CUDA to CPU
@@ -453,12 +453,12 @@ TEST(Resize, Cuda_tf_crop_and_resize) {
     Graph gCpu = make_ref<GraphObj>(runtime);
 
     auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
-    auto sizes = gCpu->addTensor({4}, DataType::UInt32);
+    auto sizes = gCpu->addTensor({4}, DataType::Int64);
     auto roi = gCpu->addTensor({8}, DataType::Float32);
     gCpu->dataMalloc();
     input->copyin(
         vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizes->copyin(vector<uint32_t>{1, 1, 3, 3});
+    sizes->copyin(vector<int64_t>{1, 1, 3, 3});
     roi->copyin(vector<float>{0, 0, 0.4, 0.6, 1, 1, 0.6, 0.8});
 
     auto cudaRuntime = make_ref<CudaRuntimeObj>();
@@ -475,7 +475,7 @@ TEST(Resize, Cuda_tf_crop_and_resize) {
     gCuda->dataMalloc();
     inputCuda->copyin(
         vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizesCuda->copyin(vector<uint32_t>{1, 1, 3, 3});
+    sizesCuda->copyin(vector<int64_t>{1, 1, 3, 3});
     roiCuda->copyin(vector<float>{0, 0, 0.4, 0.6, 1, 1, 0.6, 0.8});
     cudaRuntime->run(gCuda);
 
@@ -490,12 +490,12 @@ TEST(Resize, Cuda_tf_crop_and_resize_axes_3_2) {
     Graph gCpu = make_ref<GraphObj>(runtime);
 
     auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
-    auto sizes = gCpu->addTensor({2}, DataType::UInt32);
+    auto sizes = gCpu->addTensor({2}, DataType::Int64);
     auto roi = gCpu->addTensor({4}, DataType::Float32);
     gCpu->dataMalloc();
     input->copyin(
         vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizes->copyin(vector<uint32_t>{3, 3});
+    sizes->copyin(vector<int64_t>{3, 3});
     roi->copyin(vector<float>{0.6, 0.4, 0.8, 0.6});
 
     auto cudaRuntime = make_ref<CudaRuntimeObj>();
@@ -512,7 +512,7 @@ TEST(Resize, Cuda_tf_crop_and_resize_axes_3_2) {
     gCuda->dataMalloc();
     inputCuda->copyin(
         vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizesCuda->copyin(vector<uint32_t>{3, 3});
+    sizesCuda->copyin(vector<int64_t>{3, 3});
     roiCuda->copyin(vector<float>{0.6, 0.4, 0.8, 0.6});
     cudaRuntime->run(gCuda);
 
@@ -715,11 +715,11 @@ TEST(Resize, Cuda_downsample_sizes_cubic) {
     Graph gCpu = make_ref<GraphObj>(runtime);
 
     auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
-    auto sizes = gCpu->addTensor({4}, DataType::UInt32);
+    auto sizes = gCpu->addTensor({4}, DataType::Int64);
     gCpu->dataMalloc();
     input->copyin(
         vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizes->copyin(vector<uint32_t>{1, 1, 3, 3});
+    sizes->copyin(vector<int64_t>{1, 1, 3, 3});
 
     auto cudaRuntime = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cudaRuntime);
@@ -733,7 +733,7 @@ TEST(Resize, Cuda_downsample_sizes_cubic) {
     gCuda->dataMalloc();
     inputCuda->copyin(
        vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizesCuda->copyin(vector<uint32_t>{1, 1, 3, 3});
+    sizesCuda->copyin(vector<int64_t>{1, 1, 3, 3});
     cudaRuntime->run(gCuda);
 
     // copy output from CUDA to CPU
@@ -753,11 +753,11 @@ TEST(Resize, Cuda_upsample_sizes_cubic) {
     Graph gCpu = make_ref<GraphObj>(runtime);
 
     auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
-    auto sizes = gCpu->addTensor({4}, DataType::UInt32);
+    auto sizes = gCpu->addTensor({4}, DataType::Int64);
     gCpu->dataMalloc();
     input->copyin(
         vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizes->copyin(vector<uint32_t>{1, 1, 9, 10});
+    sizes->copyin(vector<int64_t>{1, 1, 9, 10});
 
     auto cudaRuntime = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cudaRuntime);
@@ -771,7 +771,7 @@ TEST(Resize, Cuda_upsample_sizes_cubic) {
     gCuda->dataMalloc();
     inputCuda->copyin(
         vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizesCuda->copyin(vector<uint32_t>{1, 1, 9, 10});
+    sizesCuda->copyin(vector<int64_t>{1, 1, 9, 10});
     cudaRuntime->run(gCuda);
     // copy output from CUDA to CPU
     auto oCpu = gCpu->cloneTensor(op->getOutput(0));
diff --git a/test/operators/test_resize.cc b/test/operators/test_resize.cc
index 9079c3bf..ce3b20f8 100644
--- a/test/operators/test_resize.cc
+++ b/test/operators/test_resize.cc
@@ -10,9 +10,9 @@ TEST(Resize, ShapeInference) {
     {
         Graph g = make_ref<GraphObj>(cpuRuntime);
         Tensor i = g->addTensor({1, 1, 2, 4}, DataType::UInt32);
-        Tensor sizes = g->addTensor({4}, DataType::UInt32);
+        Tensor sizes = g->addTensor({4}, DataType::Int64);
         sizes->dataMalloc();
-        sizes->copyin(vector<uint32_t>{1, 1, 1, 3});
+        sizes->copyin(vector<int64_t>{1, 1, 1, 3});
         auto op = g->addOp<ResizeObj>(
             i, nullptr, std::nullopt, sizes, nullptr, nullptr,
             ResizeObj::EKeepAspectRatioPolicy::stretch);
@@ -22,9 +22,9 @@ TEST(Resize, ShapeInference) {
     {
         Graph g = make_ref<GraphObj>(cpuRuntime);
         Tensor i = g->addTensor({1, 1, 2, 4}, DataType::UInt32);
-        Tensor sizes = g->addTensor({2}, DataType::UInt32);
+        Tensor sizes = g->addTensor({2}, DataType::Int64);
         sizes->dataMalloc();
-        sizes->copyin(vector<uint32_t>{1, 3});
+        sizes->copyin(vector<int64_t>{1, 3});
         auto op = g->addOp<ResizeObj>(
             i, nullptr, vector<int>{2, 3}, sizes, nullptr, nullptr,
             ResizeObj::EKeepAspectRatioPolicy::stretch);
@@ -34,9 +34,9 @@ TEST(Resize, ShapeInference) {
     {
         Graph g = make_ref<GraphObj>(cpuRuntime);
         Tensor i = g->addTensor({1, 3, 2, 4}, DataType::UInt32);
-        Tensor sizes = g->addTensor({2}, DataType::UInt32);
+        Tensor sizes = g->addTensor({2}, DataType::Int64);
         sizes->dataMalloc();
-        sizes->copyin(vector<uint32_t>{7, 8});
+        sizes->copyin(vector<int64_t>{7, 8});
         auto op = g->addOp<ResizeObj>(
             i, nullptr, vector<int>{2, 3}, sizes, nullptr, nullptr,
             ResizeObj::EKeepAspectRatioPolicy::notLarger);
@@ -46,9 +46,9 @@ TEST(Resize, ShapeInference) {
     {
         Graph g = make_ref<GraphObj>(cpuRuntime);
         Tensor i = g->addTensor({1, 3, 2, 4}, DataType::UInt32);
-        Tensor sizes = g->addTensor({3}, DataType::UInt32);
+        Tensor sizes = g->addTensor({3}, DataType::Int64);
         sizes->dataMalloc();
-        sizes->copyin(vector<uint32_t>{2, 6, 8});
+        sizes->copyin(vector<int64_t>{2, 6, 8});
         auto op = g->addOp<ResizeObj>(
             i, nullptr, vector<int>{1, 2, 3}, sizes, nullptr, nullptr,
             ResizeObj::EKeepAspectRatioPolicy::notSmaller);