forked from jiuyuan/InfiniTensor
Compare commits
49 Commits
Author | SHA1 | Date |
---|---|---|
zhangyue | bc2440cb98 | |
zhangyue | 7b04957699 | |
Songxin | abccff0829 | |
OdinaryWord | faed0682fc | |
OdinaryWord | c8d60d76e8 | |
OdinaryWord | a1a68d3624 | |
OdinaryWord | a8443741c4 | |
xgqdut2016 | 5747eb8f7d | |
xgqdut2016 | 9384cec7de | |
Haojie Wang | f0509facc6 | |
zhangyunze | 6ad05da684 | |
xgqdut2016 | 6a89946736 | |
xgqdut2016 | 0fcaf001c4 | |
zhangyunze | 377b3bf391 | |
OdinaryWord | 907239cf34 | |
xgqdut2016 | 47fc0bfa99 | |
xgqdut2016 | ef4646ec89 | |
xgqdut2016 | e6b98fd652 | |
xgqdut2016 | 4d078967e0 | |
OdinaryWord | 0c94b75a65 | |
OdinaryWord | 775ce5040d | |
OdinaryWord | 6ba1a0648a | |
OdinaryWord | a765cd2a3d | |
OdinaryWord | 8b8f165158 | |
OdinaryWord | 5b89c699dc | |
OdinaryWord | 2b8823515e | |
OdinaryWord | 87f975d969 | |
OdinaryWord | 33e1521754 | |
OdinaryWord | ec549d260b | |
OdinaryWord | dddb40cd93 | |
OdinaryWord | a5ccf06551 | |
OdinaryWord | fc4b62a88c | |
OdinaryWord | 36e0840f2f | |
OdinaryWord | 9db6703b58 | |
OdinaryWord | e7d34badfb | |
OdinaryWord | f6176124ec | |
OdinaryWord | c970c93ba1 | |
OdinaryWord | dcbbc82d5b | |
OdinaryWord | 70950e3fbb | |
OdinaryWord | 39484e0cc4 | |
OdinaryWord | a9bd73528d | |
OdinaryWord | 95ee579338 | |
wanghailu0717 | 11e2b08be3 | |
wanghailu0717 | cc057bcf80 | |
wanghailu | 6b06ab0534 | |
wanghailu | 412f301323 | |
Haojie Wang | b1bdbbf478 | |
wanghailu | 56634b3b19 | |
wanghailu0717 | b6ff4514fe |
|
@ -44,3 +44,5 @@ build_debug/
|
|||
*.onnx
|
||||
*.pb
|
||||
*.npy
|
||||
|
||||
*.swp
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
option(USE_CUDA "Support CUDA GPU" OFF)
|
||||
option(USE_BANG "Support BANG MLU" OFF)
|
||||
option(USE_KUNLUN "Support KUNLUN XPU" OFF)
|
||||
option(USE_ASCEND "Support HUAWEI ASCEND" OFF)
|
||||
option(USE_INTELCPU "Support INTELCPU" OFF)
|
||||
option(USE_BACKTRACE "Print backtrace on exception and segmentation fault" ON)
|
||||
option(USE_PROTOBUF "Serialize and deserialize tensors" OFF)
|
||||
|
@ -151,6 +152,11 @@ if(USE_KUNLUN)
|
|||
list (APPEND SRC ${SRC_KUNLUN})
|
||||
endif()
|
||||
|
||||
if(USE_ASCEND)
|
||||
file(GLOB_RECURSE SRC_ASCEND src/ascend/*.cc src/kernels/ascend/*.cc )
|
||||
list (APPEND SRC ${SRC_ASCEND})
|
||||
endif()
|
||||
|
||||
if(USE_INTELCPU)
|
||||
file(GLOB_RECURSE SRC_INTELCPU src/intelcpu/*.cc src/kernels/intelcpu/*.cc )
|
||||
list (APPEND SRC ${SRC_INTELCPU})
|
||||
|
@ -297,6 +303,7 @@ if(USE_KUNLUN)
|
|||
else()
|
||||
set(TARGET_CPU_ARCH $ENV{TARGET_CPU_ARCH} CACHE STRING "Target CPU ARCH")
|
||||
endif()
|
||||
|
||||
message(STATUS "TARGET_CPU_ARCH: ${TARGET_CPU_ARCH}")
|
||||
|
||||
if (BUILD_DIST)
|
||||
|
@ -309,6 +316,42 @@ if(USE_KUNLUN)
|
|||
target_link_libraries(InfiniTensor ${KUNLUN_RT} ${KUNLUN_DNN} stdc++)
|
||||
endif()
|
||||
|
||||
if(USE_ASCEND)
|
||||
add_compile_definitions(USE_ASCEND=1)
|
||||
if ((NOT DEFINED ASCEND_HOME) AND (NOT DEFINED ENV{ASCEND_HOME}))
|
||||
message(FATAL_ERROR "ASCEND_HOME is not defined from cmake or env")
|
||||
elseif (DEFINED ASCEND_HOME)
|
||||
set(ASCEND_HOME ${ASCEND_HOME} CACHE STRING "ASCEND_HOME directory for Kunlun development")
|
||||
else()
|
||||
set(ASCEND_HOME $ENV{ASCEND_HOME} CACHE STRING "ASCEND_HOME directory for Kunlun development")
|
||||
endif()
|
||||
message(STATUS "ASCEND_HOME: ${ASCEND_HOME}")
|
||||
|
||||
include_directories("${ASCEND_HOME}/include/")
|
||||
include_directories("${ASCEND_HOME}/include/aclnn")
|
||||
find_library(ASCEND_CL libascendcl.so "${ASCEND_HOME}/lib64")
|
||||
find_library(ASCEND_BASE libnnopbase.so "${ASCEND_HOME}/lib64")
|
||||
find_library(ASCEND_DNN libopapi.so "${ASCEND_HOME}/lib64")
|
||||
find_library(ASCEND_HCCL libhccl.so "${ASCEND_HOME}/lib64")
|
||||
find_library(ASCEND_HAL libascend_hal.so "${ASCEND_HOME}/../../driver/lib64/driver")
|
||||
# find_library(ASCEND_RT libruntime.so "${ASCEND_HOME}/lib64")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror")
|
||||
if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH}))
|
||||
execute_process(COMMAND uname -m OUTPUT_VARIABLE _uname_m OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||
set(TARGET_CPU_ARCH "${_uname_m}" CACHE STRING "Target CPU ARCH")
|
||||
elseif(DEFINED TARGET_CPU_ARCH)
|
||||
set(TARGET_CPU_ARCH ${TARGET_CPU_ARCH} CACHE STRING "Target CPU ARCH")
|
||||
else()
|
||||
set(TARGET_CPU_ARCH $ENV{TARGET_CPU_ARCH} CACHE STRING "Target CPU ARCH")
|
||||
endif()
|
||||
message(STATUS "TARGET_CPU_ARCH: ${TARGET_CPU_ARCH}")
|
||||
target_link_libraries(InfiniTensor ${ASCEND_HAL} ${ASCEND_CL} ${ASCEND_BASE} ${ASCEND_DNN} ${ASCEND_HCCL} stdc++)
|
||||
if (BUILD_DIST)
|
||||
message(STATUS "Add BUILD_DIST, use HCCL with ASCEND")
|
||||
add_compile_definitions(INFINI_USE_HCCL=1)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# # Python bindings
|
||||
# pybind11_add_module(infini MODULE ${FFI})
|
||||
# target_link_libraries(infini PRIVATE infini_cpp)
|
||||
|
@ -346,6 +389,9 @@ if(BUILD_TEST)
|
|||
build_test(test/kernels/kunlun/*.cc)
|
||||
build_test(test/kunlun/*.cc)
|
||||
endif()
|
||||
if (USE_ASCEND)
|
||||
build_test(test/kernels/ascend/*.cc)
|
||||
endif()
|
||||
if (USE_INTELCPU)
|
||||
build_test(test/kernels/intelcpu/*.cc)
|
||||
endif()
|
||||
|
@ -360,4 +406,4 @@ if(BUILD_TEST)
|
|||
add_executable(nnet_reader test/nnet/readlog.cc)
|
||||
target_link_libraries(nnet_reader InfiniTensor)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
2
Makefile
2
Makefile
|
@ -4,6 +4,7 @@ TYPE ?= Release
|
|||
CUDA ?= OFF
|
||||
BANG ?= OFF
|
||||
KUNLUN ?= OFF
|
||||
ASCEND ?= OFF
|
||||
INTELCPU ?= off
|
||||
BACKTRACE ?= ON
|
||||
TEST ?= ON
|
||||
|
@ -29,6 +30,7 @@ CMAKE_OPT = -DCMAKE_BUILD_TYPE=$(TYPE)
|
|||
CMAKE_OPT += -DUSE_CUDA=$(CUDA)
|
||||
CMAKE_OPT += -DUSE_BANG=$(BANG)
|
||||
CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN)
|
||||
CMAKE_OPT += -DUSE_ASCEND=$(ASCEND)
|
||||
CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE)
|
||||
CMAKE_OPT += -DBUILD_TEST=$(TEST)
|
||||
CMAKE_OPT += -DBUILD_DIST=$(DIST)
|
||||
|
|
|
@ -72,4 +72,4 @@ Please cite EinNet or PET in your publications if it helps your research:
|
|||
pages={37--54},
|
||||
year={2021}
|
||||
}
|
||||
```
|
||||
```
|
|
@ -14,6 +14,7 @@
|
|||
| -------- | ------------ | ----------- | ---------- |
|
||||
| X86-64 | Nvidia GPU | Ubuntu-22.04 | Yes |
|
||||
| X86-64 | Cambricon MLU | Ubuntu-22.04 | Yes |
|
||||
| arm64 | Ascend NPU |OpenEuler 22.03| Yes |
|
||||
|
||||
推荐使用 X86-64 机器以及 Ubuntu-22.04,本文以此环境为例。
|
||||
|
||||
|
@ -68,6 +69,20 @@
|
|||
|
||||
我们强烈建议您规范安装,统一到一个目录下,以免不必要的麻烦。另外请注意,由于 MLU 上层软件建设适配程度有限,如您在其覆盖的机器,操作系统之外运行,需要在安装驱动之后使用上层软件的 Docker。
|
||||
|
||||
- 如您的第三方加速卡为昇腾 NPU,请参考昇腾官方文档进行:
|
||||
> [驱动及CANN安装](https://www.hiascend.com/document/detail/zh/canncommercial/80RC1/quickstart/quickstart/quickstart_18_0006.html)
|
||||
> 安装完成后请进行相应的环境变量配置,将可执行文件目录与库目录添加到操作系统识别的路径中,例如
|
||||
>
|
||||
> ```bash
|
||||
> # 将如下内容写入到你的 bashrc 文件并 source 该文件
|
||||
> export ASCEND_HOME=/usr/local/Ascend/ascend-toolkit/latest
|
||||
> source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
> # 如您不方便将上述环境变量配置到 bashrc 文件中进行长期使用,你也可以在我们提供的 env.sh 文件中进行正确配置并激活,作为临时使用
|
||||
> source env.sh
|
||||
> ```
|
||||
|
||||
我们强烈建议您规范安装,统一到一个目录下,以免不必要的麻烦。
|
||||
|
||||
4. 确认您安装了 make,build-essential, python-is-python3, python-dev-is-python3, python3-pip, libdw-dev,如您的机器没有上述基础依赖,请自行按需安装。
|
||||
|
||||
- 在使用 apt-get 工具情况下,您可以这样执行
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
- `CUDA`:是否编译 CUDA 后端,默认为 `OFF`,`ON` 打开
|
||||
- `BANG`:是否编译寒武纪后端,默认为 `OFF`,`ON` 打开
|
||||
- `KUNLUN`:是否编译昆仑后端,默认为 `OFF`,`ON` 打开
|
||||
- `ASCEND`:是否编译华为后端,默认为 `OFF`,`ON` 打开
|
||||
- `BACKTRACE`:是否启用栈回溯,默认为 `ON`,`OFF` 关闭,建议调试时打开
|
||||
- `TEST`:是否编译 `googletest`,默认为 `ON`,`OFF` 关闭,只有 `test-cpp` 时必要
|
||||
|
||||
|
|
14
env.sh
14
env.sh
|
@ -36,3 +36,17 @@ export LD_LIBRARY_PATH="${NEUWARE_HOME}/lib64:${LD_LIBRARY_PATH}"
|
|||
# ├── version
|
||||
# └── XTDK
|
||||
export KUNLUN_HOME=/usr/local/xpu
|
||||
|
||||
# 配置华为ASCEND NPU 的 HOME 路径,请注意 /usr/local/ascend 是华为ASCEND 软件栈提供的软件包路径。
|
||||
# 如若用户有其他的路径安装方式,请自行配置正确的路径。
|
||||
# 这里是 ascend 目录下一个可能的结构图,请参考。
|
||||
# .
|
||||
# ├── bin
|
||||
# ├── include
|
||||
# ├── lib64
|
||||
# ├── tools
|
||||
# ├── version
|
||||
# └── XTDK
|
||||
export ASCEND_HOME=/usr/local/Ascend/ascend-toolkit/latest
|
||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
source /usr/local/Ascend/toolbox/set_env.sh
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit 51d3105277f3774ed31c02ed4cd11fa92925af77
|
||||
Subproject commit b896cec2dba5b8522b141ac4f89eb43074ee1b98
|
|
@ -0,0 +1,198 @@
|
|||
import argparse
|
||||
import os
|
||||
import time
|
||||
import multiprocessing as mp
|
||||
from pyinfinitensor.onnx import OnnxStub, backend
|
||||
import onnx
|
||||
from onnx.shape_inference import infer_shapes_path
|
||||
import numpy as np
|
||||
from parallel_opt import parallel_model
|
||||
import acl
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="launch distributed infinitensor")
|
||||
parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
|
||||
parser.add_argument(
|
||||
"--nproc_per_node", type=int, default=2, help="number of processes per node"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--name", type=str, default="test", help="name of this instance."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model", type=str, default="/data/onnx_models/llama2/llama_bs1_seq1024.onnx",
|
||||
help="path to the ONNX model file."
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
|
||||
parser.add_argument("--length", type=int, default=1, help="sequence length.")
|
||||
parser.add_argument(
|
||||
"--gen_std",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="whether to generate the standard results.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print("arg setting: ", args)
|
||||
return (
|
||||
args.num_nodes,
|
||||
args.nproc_per_node,
|
||||
args.name,
|
||||
args.model,
|
||||
args.batch_size,
|
||||
args.length,
|
||||
args.gen_std,
|
||||
)
|
||||
|
||||
|
||||
def run_model(model, runtime, world_size=1, rank=0, n=10):
|
||||
stub = OnnxStub(model, runtime)
|
||||
load_inputs(stub, world_size, rank)
|
||||
# stub.tune()
|
||||
stub.run()
|
||||
# get outputs
|
||||
time.sleep(0.01)
|
||||
outputs = next(stub.outputs.values().__iter__()).copyout_numpy()
|
||||
|
||||
# bench
|
||||
begin = time.time()
|
||||
for _ in range(n):
|
||||
stub.run()
|
||||
end = time.time()
|
||||
avg_time = (end - begin) / n
|
||||
print(f"average time: {avg_time}")
|
||||
return outputs
|
||||
|
||||
|
||||
def run_and_compare(name, model, runtime, world_size=1, rank = 0):
|
||||
results = np.load(f"./data/output.npy")
|
||||
outputs = run_model(model, runtime, world_size, rank)
|
||||
print("answer argmax:", np.argmax(results))
|
||||
print("output argmax:", np.argmax(outputs))
|
||||
#np.testing.assert_allclose(outputs, results, rtol=1e-3, atol=1e-3)
|
||||
getDiff(results, outputs)
|
||||
|
||||
|
||||
def start_worker(
|
||||
name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
|
||||
):
|
||||
dist_name = name + "_dist"
|
||||
model = parallel_model(model, world_size, rank)
|
||||
extern_path = f"./{dist_name}_rank{rank}.pb"
|
||||
if os.path.exists(extern_path):
|
||||
os.remove(extern_path)
|
||||
onnx.save_model(
|
||||
model,
|
||||
f"./{dist_name}_rank{rank}.onnx",
|
||||
save_as_external_data=True,
|
||||
location=extern_path,
|
||||
)
|
||||
infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
|
||||
runtime = backend.ASCENDRuntime(local_rank)
|
||||
# print("init comm")
|
||||
runtime.init_comm(
|
||||
dist_name,
|
||||
world_size,
|
||||
rank,
|
||||
)
|
||||
run_and_compare(name, model, runtime, world_size, rank)
|
||||
|
||||
|
||||
def start_single(name, model):
|
||||
runtime = backend.ASCENDRuntime(0)
|
||||
run_and_compare(name, model, runtime)
|
||||
|
||||
|
||||
def generate_input_output(model):
|
||||
os.makedirs(os.path.dirname("./data/"), exist_ok=True)
|
||||
runtime = backend.ASCENDRuntime(0)
|
||||
stub = OnnxStub(model, runtime)
|
||||
position_id = 0
|
||||
for i, (name, tensor) in enumerate(stub.inputs.items()):
|
||||
input = tensor.copyout_numpy()
|
||||
if np.issubdtype(input.dtype, np.integer):
|
||||
if input.size == 1:
|
||||
# input = np.array([position_id])
|
||||
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
|
||||
else:
|
||||
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
|
||||
elif input.dtype == np.bool_:
|
||||
input = np.random.randint(0,2,size=input.shape) > 0
|
||||
else:
|
||||
if i == 0:
|
||||
input = np.ones(input.shape).astype(input.dtype)
|
||||
position_id = input.shape[-1] - 1
|
||||
else:
|
||||
input = np.random.rand(*input.shape).astype(input.dtype)
|
||||
tensor.copyin_numpy(input)
|
||||
np.save(f"./data/input_{i}", input)
|
||||
stub.run()
|
||||
time.sleep(0.01)
|
||||
output = next(stub.outputs.values().__iter__()).copyout_numpy()
|
||||
if np.isnan(output).any():
|
||||
print("Nan in output")
|
||||
np.save(f"./data/output", output)
|
||||
|
||||
|
||||
def load_inputs(stub, world_size=1, rank=0):
|
||||
for i, (name, tensor) in enumerate(stub.inputs.items()):
|
||||
input = np.load(f"./data/input_{i}.npy")
|
||||
if all(x == y for x,y in zip(input.shape,tensor.shape())):
|
||||
tensor.copyin_numpy(input)
|
||||
else:
|
||||
tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
|
||||
|
||||
def getDiff(base, test):
|
||||
absolute_diff = np.abs(np.subtract(base, test))
|
||||
max_absolute_diff = np.max(absolute_diff)
|
||||
|
||||
baseCopy = base.astype(np.float64).ravel()
|
||||
testCopy = test.astype(np.float64).ravel()
|
||||
upValue = np.sum(np.abs(baseCopy - testCopy))
|
||||
downValue = np.sum(np.abs(baseCopy)) + np.float64(1e-9)
|
||||
max_relative_diff = upValue / downValue
|
||||
print(f"Max absolute difference: {max_absolute_diff}\n"
|
||||
f"Max relative difference: {max_relative_diff}")
|
||||
return max_absolute_diff, max_relative_diff
|
||||
|
||||
|
||||
def main():
|
||||
acl.init()
|
||||
nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args()
|
||||
|
||||
model = onnx.load(model_path)
|
||||
|
||||
# generate standart output
|
||||
if gen_std:
|
||||
print("Generate inputs and outputs.")
|
||||
p = mp.Process(target=generate_input_output, args=[model])
|
||||
p.start()
|
||||
p.join()
|
||||
return
|
||||
|
||||
# run single process.
|
||||
# use standalone process to isolate cuda.
|
||||
print("run model by single MLU.")
|
||||
p = mp.Process(target=start_single, args=(name, model))
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
# run distributed parallel.
|
||||
world_size = nnodes * nproc_per_node
|
||||
print(f"run model by {world_size} MLUs in parallel.")
|
||||
workers = [
|
||||
mp.Process(
|
||||
target=start_worker,
|
||||
args=(name, world_size, rank, rank % nproc_per_node, model),
|
||||
)
|
||||
for rank in range(world_size)
|
||||
]
|
||||
|
||||
for w in workers:
|
||||
w.start()
|
||||
|
||||
for w in workers:
|
||||
w.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,20 @@
|
|||
#pragma once
|
||||
#include "acl/acl.h"
|
||||
#include "acl/acl_op.h"
|
||||
#include "core/common.h"
|
||||
|
||||
#define checkASCENDError(call) \
|
||||
{ \
|
||||
auto err = call; \
|
||||
if (ACL_SUCCESS != err) { \
|
||||
fprintf(stderr, "ASCEND error in %s:%i : .\n", __FILE__, \
|
||||
__LINE__); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
|
||||
using ASCENDPtr = void *;
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,46 @@
|
|||
#pragma once
|
||||
#include "ascend/ascend_runtime.h"
|
||||
#include "core/kernel.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class ASCENDKernelWithoutConfig : public Kernel {
|
||||
public:
|
||||
virtual void compute(const Operator &op, const PerfRecord &record,
|
||||
const RuntimeObj *context) const {
|
||||
compute(op, context);
|
||||
}
|
||||
virtual void compute(const Operator &op,
|
||||
const RuntimeObj *context) const = 0;
|
||||
// Premise: op is idempotent since it is called multiple times.
|
||||
virtual PerfRecord tune(const Operator &op,
|
||||
const RuntimeObj *_context) const {
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
return make_ref<PerfRecordObj>(timeit([&]() { compute(op, _context); },
|
||||
[&]() { context->sync(); }));
|
||||
}
|
||||
// transform vector<int> to vector<int64_t>
|
||||
std::vector<int64_t> castTo64(std::vector<int> const &v32) const {
|
||||
if (v32.size() == 0) {
|
||||
std::vector<int64_t> v64(1, 1);
|
||||
return v64;
|
||||
}
|
||||
std::vector<int64_t> v64(v32.size(), 1);
|
||||
for (size_t i = 0; i < v32.size(); ++i) {
|
||||
v64[i] = int64_t(v32[i]);
|
||||
}
|
||||
|
||||
return v64;
|
||||
}
|
||||
|
||||
Shape getStride(std::vector<int> Dim) const {
|
||||
Shape stride(Dim.size());
|
||||
ShapeElem p = 1;
|
||||
for (auto i = Dim.size(); i > 0; --i) {
|
||||
stride[i - 1] = p;
|
||||
p = p * Dim[i - 1];
|
||||
}
|
||||
return stride;
|
||||
}
|
||||
};
|
||||
} // namespace infini
|
|
@ -0,0 +1,105 @@
|
|||
#pragma once
|
||||
#include "ascend/ascend_common.h"
|
||||
#include "core/runtime.h"
|
||||
|
||||
#define CHECK_RET(cond, return_expr) \
|
||||
do { \
|
||||
if (!(cond)) { \
|
||||
return_expr; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define LOG_PRINT(message, ...) \
|
||||
do { \
|
||||
printf(message, ##__VA_ARGS__); \
|
||||
} while (0)
|
||||
|
||||
namespace infini {
|
||||
|
||||
class ASCENDRuntimeObj : public RuntimeObj {
|
||||
private:
|
||||
aclrtContext context;
|
||||
aclrtStream stream;
|
||||
std::unique_ptr<CommunicatorObj> comm;
|
||||
ASCENDPtr workspace = nullptr;
|
||||
uint64_t workspaceSize;
|
||||
|
||||
public:
|
||||
ASCENDRuntimeObj(int deviceId = 0) : RuntimeObj(Device::ASCEND, deviceId) {
|
||||
// auto ret = aclInit(nullptr);
|
||||
// CHECK_RET(ret == ACL_SUCCESS,
|
||||
// LOG_PRINT("aclInit failed. ERROR: %d\n", ret));
|
||||
auto ret = aclrtSetDevice(deviceId);
|
||||
CHECK_RET(ret == ACL_SUCCESS,
|
||||
LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret));
|
||||
ret = aclrtCreateContext(&context, deviceId);
|
||||
CHECK_RET(ret == ACL_SUCCESS,
|
||||
LOG_PRINT("aclrtCreateContext failed. ERROR: %d\n", ret));
|
||||
ret = aclrtSetCurrentContext(context);
|
||||
CHECK_RET(ret == ACL_SUCCESS,
|
||||
LOG_PRINT("aclrtSetCurrentContext failed. ERROR: %d\n", ret));
|
||||
ret = aclrtCreateStream(&stream);
|
||||
CHECK_RET(ret == ACL_SUCCESS,
|
||||
LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret));
|
||||
|
||||
// 10GB for Longformer
|
||||
// size_t longformerNum = 3lu * (1 << 30);
|
||||
workspaceSize = 3ll << 33; // 3 GB
|
||||
// std::cout<<workspaceSize/1024/1024/1024<< std::endl;
|
||||
// std::cout<<std::bitset<64>(workspaceSize)<< std::endl;
|
||||
workspace = alloc(workspaceSize);
|
||||
}
|
||||
virtual ~ASCENDRuntimeObj() {
|
||||
dealloc(workspace);
|
||||
aclrtDestroyStream(stream);
|
||||
aclrtDestroyContext(context);
|
||||
aclrtResetDevice(deviceId);
|
||||
// aclFinalize();
|
||||
}
|
||||
string toString() const override;
|
||||
|
||||
void run(const Graph &graph, bool tune = false,
|
||||
bool profiling = false) const;
|
||||
// double runEvaluation(const Graph &graph, int nWarmups,
|
||||
// int nEvaluations) const;
|
||||
void sync() const;
|
||||
ASCENDPtr alloc(size_t size) override {
|
||||
void *ptr;
|
||||
checkASCENDError(
|
||||
aclrtMalloc((void **)&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
return ptr;
|
||||
}
|
||||
void dealloc(void *ptr) override { aclrtFree(ptr); }
|
||||
aclrtStream ASCENDHandle() const { return stream; }
|
||||
ASCENDPtr getWorkspace(uint64_t size) const {
|
||||
IT_ASSERT(size <= workspaceSize);
|
||||
return workspace;
|
||||
}
|
||||
|
||||
void copyBlobFromCPU(void *dst, const void *src,
|
||||
size_t bytes) const override {
|
||||
aclrtMemcpy(dst, bytes, const_cast<void *>(src), bytes,
|
||||
ACL_MEMCPY_HOST_TO_DEVICE);
|
||||
}
|
||||
|
||||
void copyBlobToCPU(void *dst, const void *src,
|
||||
size_t bytes) const override {
|
||||
aclrtMemcpy(dst, bytes, const_cast<void *>(src), bytes,
|
||||
ACL_MEMCPY_DEVICE_TO_HOST);
|
||||
}
|
||||
|
||||
void copyBlobInsideRuntime(void *dst, const void *src,
|
||||
size_t bytes) const override {
|
||||
aclrtMemcpy(dst, bytes, const_cast<void *>(src), bytes,
|
||||
ACL_MEMCPY_DEVICE_TO_DEVICE);
|
||||
}
|
||||
|
||||
void initComm(const string &name, int worldSize, int rank) final;
|
||||
|
||||
CommunicatorObj &getCommunicator() const override { return *comm; }
|
||||
|
||||
private:
|
||||
void runWithoutSync(const Graph &graph, bool tune, bool profiling) const;
|
||||
};
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,88 @@
|
|||
#pragma once
|
||||
#include "core/communicator.h"
|
||||
#include "hccl/hccl.h"
|
||||
#include "hccl/hccl_types.h"
|
||||
#include <chrono>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#define ACLCHECK(ret) \
|
||||
do { \
|
||||
assert(ret == ACL_SUCCESS); \
|
||||
} while (0)
|
||||
#define HCCLCHECK(ret) \
|
||||
do { \
|
||||
assert(ret == HCCL_SUCCESS); \
|
||||
} while (0)
|
||||
|
||||
namespace infini {
|
||||
|
||||
class HcclCommunicatorObj final : public CommunicatorObj {
|
||||
private:
|
||||
HcclComm comm;
|
||||
|
||||
public:
|
||||
HcclCommunicatorObj(const string &name, int worldSize, int rank)
|
||||
: CommunicatorObj(worldSize, rank) {
|
||||
const std::string filePath("./" + name + "_hccl_id.bin");
|
||||
int devId = rank;
|
||||
int devCount = worldSize;
|
||||
// 在 rootRank 获取 rootInfo
|
||||
HcclRootInfo rootInfo;
|
||||
int32_t rootRank = 0;
|
||||
|
||||
if (devId == rootRank) {
|
||||
HCCLCHECK(HcclGetRootInfo(&rootInfo));
|
||||
std::ofstream ofs(filePath, std::ios::binary);
|
||||
ofs.write((char *)&rootInfo, sizeof(HcclRootInfo));
|
||||
} else {
|
||||
auto begin = std::chrono::steady_clock::now();
|
||||
while (!std::filesystem::exists(filePath)) {
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
_IT_ASSERT_2(now < begin + std::chrono::seconds(10),
|
||||
"time limit (10s) exceeded.");
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
}
|
||||
std::ifstream ifs(filePath, std::ios::binary);
|
||||
ifs.read((char *)&rootInfo, sizeof(HcclRootInfo));
|
||||
}
|
||||
|
||||
auto ret = HcclCommInitRootInfo(uint32_t(devCount), &rootInfo,
|
||||
uint32_t(devId), &comm);
|
||||
|
||||
assert(ret == HCCL_SUCCESS);
|
||||
|
||||
if (rank == 0) {
|
||||
std::filesystem::remove(filePath);
|
||||
}
|
||||
}
|
||||
|
||||
// Get the actual ncclComm_t
|
||||
HcclComm getHcclComm() { return comm; }
|
||||
|
||||
// void finalize() { HCCLCHECK(HcclCommFinalize(comm)); }
|
||||
|
||||
~HcclCommunicatorObj() final {
|
||||
// finalize();
|
||||
// auto ret = HcclCommDestroy(comm);
|
||||
// auto tmp_err_msg = HcclGetErrorString(ret);
|
||||
// if (tmp_err_msg != NULL) {
|
||||
// printf(" ERROR Message : %s \n ", tmp_err_msg);
|
||||
//}
|
||||
// assert(ret == HCCL_SUCCESS);
|
||||
}
|
||||
|
||||
virtual string toString() const final {
|
||||
std::ostringstream oss;
|
||||
oss << "HCCL communicator";
|
||||
return oss.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace infini
|
|
@ -1,4 +1,4 @@
|
|||
#pragma once
|
||||
#pragma once
|
||||
|
||||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
|
@ -37,6 +37,8 @@ class GraphHandlerObj {
|
|||
float momentum, float eps, bool training);
|
||||
Tensor layerNormalization(Tensor input, Tensor scale, Tensor output,
|
||||
Tensor bias, float eps, int axis, int stash_type);
|
||||
Tensor instanceNormalization(Tensor input, Tensor output, Tensor scale,
|
||||
Tensor bias, float eps);
|
||||
Tensor rmsNorm(Tensor input, Tensor weight, Tensor output);
|
||||
|
||||
Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
|
||||
|
@ -55,6 +57,7 @@ class GraphHandlerObj {
|
|||
Tensor relu(Tensor x, Tensor y);
|
||||
Tensor silu(Tensor x, Tensor y);
|
||||
Tensor gelu(Tensor x, Tensor y);
|
||||
Tensor leakyrelu(Tensor x, Tensor y, float alpha);
|
||||
Tensor sigmoid(Tensor x, Tensor y);
|
||||
Tensor hardSigmoid(Tensor x, Tensor y);
|
||||
Tensor hardSwish(Tensor x, Tensor y);
|
||||
|
@ -74,7 +77,7 @@ class GraphHandlerObj {
|
|||
Tensor reshape(Tensor data, Tensor reshaped, Shape shape);
|
||||
Tensor resize(Tensor input, Tensor output,
|
||||
const std::optional<vector<int>> &axes, Tensor sizes,
|
||||
Tensor scales, Tensor roi, vector<uint32_t> sizes_,
|
||||
Tensor scales, Tensor roi, vector<int64_t> sizes_,
|
||||
vector<float> scales_, vector<float> roi_, string mode,
|
||||
string ratioPolicy, string nearestMode,
|
||||
string coordTransMode);
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
#pragma once
|
||||
#pragma once
|
||||
#ifndef OP_TYPE_H
|
||||
#define OP_TYPE_H
|
||||
|
||||
|
@ -260,6 +260,7 @@ struct OpType {
|
|||
enum class ActType {
|
||||
None,
|
||||
Relu,
|
||||
LeakyRelu,
|
||||
Sigmoid,
|
||||
Tanh,
|
||||
};
|
||||
|
|
|
@ -32,7 +32,7 @@ using OpLists = list<Operator>;
|
|||
|
||||
using VType = uint32_t;
|
||||
|
||||
enum class Device { CPU = 1, CUDA, BANG, INTELCPU, KUNLUN };
|
||||
enum class Device { CPU = 1, CUDA, BANG, INTELCPU, KUNLUN, ASCEND };
|
||||
/***************** Forward declaration end *****************/
|
||||
|
||||
class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
|
||||
|
@ -75,6 +75,7 @@ class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
|
|||
bool isCuda() const { return device == Device::CUDA; }
|
||||
bool isBang() const { return device == Device::BANG; }
|
||||
bool isKUNLUN() const { return device == Device::KUNLUN; }
|
||||
bool isAscend() const { return device == Device::ASCEND; }
|
||||
void copyBlob(const TensorObj *dst, const TensorObj *src) const;
|
||||
// TODO: unify these copy APIs
|
||||
virtual void copyBlobFromCPU(void *dst, const void *src,
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
#pragma once
|
||||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
class InstanceNormObj : public OperatorObj {
|
||||
float eps;
|
||||
|
||||
public:
|
||||
InstanceNormObj(GraphObj *graph, Tensor input, Tensor output, Tensor scale,
|
||||
Tensor bias, float eps = 1e-5);
|
||||
OP_CLONE(InstanceNormObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
std::string toString() const override;
|
||||
|
||||
int numInputs() const override { return inputs.size(); }
|
||||
int numOutputs() const override { return outputs.size(); }
|
||||
float getEps() const { return eps; }
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
|
||||
vector<DataType> inferDataType(const TensorVec &inputs) const override;
|
||||
};
|
||||
} // namespace infini
|
|
@ -227,7 +227,24 @@ class PReluObj : public OperatorObj {
|
|||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
class LeakyReluObj : public OperatorObj {
|
||||
float alpha;
|
||||
|
||||
public:
|
||||
LeakyReluObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
float alpha = 0.01);
|
||||
OP_CLONE(LeakyReluObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
float getAlpha() const { return alpha; }
|
||||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
class LogObj : public OperatorObj {
|
||||
public:
|
||||
enum LogType {
|
||||
|
|
|
@ -85,7 +85,7 @@ class OnnxStub:
|
|||
while len(sorted_nodes) < len(model.graph.node):
|
||||
updated = False
|
||||
for i, node in enumerate(model.graph.node):
|
||||
if all(t in known_edge for t in node.input):
|
||||
if all(t in known_edge or t == "" for t in node.input):
|
||||
node.name = str(len(sorted_nodes)) + "_" + node.name
|
||||
sorted_nodes.append(i)
|
||||
known_edge.update(node.output)
|
||||
|
@ -112,7 +112,6 @@ class OnnxStub:
|
|||
)
|
||||
tensors[input.name].set_input()
|
||||
|
||||
|
||||
for node_idx in sorted_nodes:
|
||||
node = model.graph.node[node_idx]
|
||||
if node.op_type == "Conv":
|
||||
|
@ -185,7 +184,7 @@ class OnnxStub:
|
|||
node,
|
||||
{
|
||||
"dilations": [1, 1],
|
||||
"pads": [0, 0],
|
||||
"pads": [0, 0, 0, 0],
|
||||
"strides": [1, 1],
|
||||
"output_padding": [0, 0],
|
||||
},
|
||||
|
@ -194,23 +193,67 @@ class OnnxStub:
|
|||
attributes[name]
|
||||
for name in ["dilations", "pads", "strides", "output_padding"]
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.convTransposed2d(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
p[0],
|
||||
p[1],
|
||||
s[0],
|
||||
s[1],
|
||||
d[0],
|
||||
d[1],
|
||||
op[0],
|
||||
op[1],
|
||||
)
|
||||
if p[0] != p[2] or p[1] != p[3]:
|
||||
adapt = "{}-adapt".format(node.output[0])
|
||||
tensors[adapt] = self.handler.pad(
|
||||
tensors[node.input[0]], None, p, [-2, -1]
|
||||
)
|
||||
p = [0, 0, 0, 0]
|
||||
else:
|
||||
adapt = node.input[0]
|
||||
|
||||
if len(node.input) > 2:
|
||||
bias = "{}-bias".format(node.output[0])
|
||||
reshape = "{}-reshape".format(node.output[0])
|
||||
tensors[bias] = self.handler.convTransposed2d(
|
||||
tensors[adapt],
|
||||
tensors[node.input[1]],
|
||||
None,
|
||||
p[0],
|
||||
p[1],
|
||||
s[0],
|
||||
s[1],
|
||||
d[0],
|
||||
d[1],
|
||||
op[0],
|
||||
op[1],
|
||||
)
|
||||
tensors[reshape] = self.handler.reshape(
|
||||
tensors[node.input[2]],
|
||||
None,
|
||||
[
|
||||
1,
|
||||
reduce(
|
||||
lambda acc, x: acc * x,
|
||||
tensors[node.input[2]].shape(),
|
||||
),
|
||||
1,
|
||||
1,
|
||||
],
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.add(
|
||||
tensors[bias],
|
||||
tensors[reshape],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
else:
|
||||
tensors[node.output[0]] = self.handler.convTransposed2d(
|
||||
tensors[adapt],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
p[0],
|
||||
p[1],
|
||||
s[0],
|
||||
s[1],
|
||||
d[0],
|
||||
d[1],
|
||||
op[0],
|
||||
op[1],
|
||||
)
|
||||
elif node.op_type == "MatMul":
|
||||
tensors[node.output[0]] = self.handler.matmul(
|
||||
tensors[node.input[0]], # input
|
||||
tensors[node.input[1]], # weight
|
||||
tensors[node.input[0]], # input
|
||||
tensors[node.input[1]], # weight
|
||||
tensors.get(node.output[0]),
|
||||
False,
|
||||
False,
|
||||
|
@ -280,6 +323,21 @@ class OnnxStub:
|
|||
axis,
|
||||
stash_type,
|
||||
)
|
||||
elif node.op_type == "InstanceNormalization":
|
||||
(input, scale, bias) = (tensors[node.input[i]] for i in [0, 1, 2])
|
||||
|
||||
output = tensors.get(node.output[0])
|
||||
|
||||
tensors[node.output[0]] = self.handler.instanceNormalization(
|
||||
input,
|
||||
output,
|
||||
scale,
|
||||
bias,
|
||||
next(
|
||||
(attr.f for attr in node.attribute if attr.name == "epsilon"),
|
||||
1e-5,
|
||||
),
|
||||
)
|
||||
elif node.op_type == "RMSNorm":
|
||||
tensors[node.output[0]] = self.handler.RMSNorm(
|
||||
tensors[node.input[0]],
|
||||
|
@ -447,6 +505,15 @@ class OnnxStub:
|
|||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "LeakyRelu":
|
||||
tensors[node.output[0]] = self.handler.leakyrelu(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
next(
|
||||
(attr.f for attr in node.attribute if attr.name == "alpha"),
|
||||
0.01,
|
||||
),
|
||||
)
|
||||
elif node.op_type == "Silu":
|
||||
tensors[node.output[0]] = self.handler.silu(
|
||||
tensors[node.input[0]],
|
||||
|
@ -580,7 +647,7 @@ class OnnxStub:
|
|||
"cubic_coeff_a": -0.75,
|
||||
"exclude_outside": 0,
|
||||
"extrapolation_value": 0.0,
|
||||
"keep_aspect_ratio_policy": "none",
|
||||
"keep_aspect_ratio_policy": "stretch",
|
||||
"mode": "nearest",
|
||||
"nearest_mode": "none",
|
||||
},
|
||||
|
@ -601,15 +668,15 @@ class OnnxStub:
|
|||
"nearest_mode",
|
||||
]
|
||||
)
|
||||
if len(node.input) > 1:
|
||||
if len(node.input) > 1 and node.input[1] in data:
|
||||
roiVal = _parse_data(data[node.input[1]])
|
||||
else:
|
||||
roiVal = []
|
||||
if len(node.input) > 2:
|
||||
if len(node.input) > 2 and node.input[2] in data:
|
||||
scalesVal = _parse_data(data[node.input[2]])
|
||||
else:
|
||||
scalesVal = []
|
||||
if len(node.input) > 3:
|
||||
if len(node.input) > 3 and node.input[3] in data:
|
||||
sizesVal = _parse_data(data[node.input[3]])
|
||||
else:
|
||||
sizesVal = []
|
||||
|
@ -617,9 +684,21 @@ class OnnxStub:
|
|||
tensors[node.input[0]],
|
||||
output,
|
||||
axes,
|
||||
tensors[node.input[3]] if len(node.input) > 3 else None,
|
||||
tensors[node.input[2]] if len(node.input) > 2 else None,
|
||||
tensors[node.input[1]] if len(node.input) > 1 else None,
|
||||
(
|
||||
tensors[node.input[3]]
|
||||
if len(node.input) > 3 and node.input[3] != ""
|
||||
else None
|
||||
),
|
||||
(
|
||||
tensors[node.input[2]]
|
||||
if len(node.input) > 2 and node.input[2] != ""
|
||||
else None
|
||||
),
|
||||
(
|
||||
tensors[node.input[1]]
|
||||
if len(node.input) > 1 and node.input[1] != ""
|
||||
else None
|
||||
),
|
||||
sizesVal,
|
||||
scalesVal,
|
||||
roiVal,
|
||||
|
@ -629,18 +708,10 @@ class OnnxStub:
|
|||
coordinate_transformation_mode,
|
||||
)
|
||||
elif node.op_type == "Squeeze":
|
||||
axes = (
|
||||
_parse_data(data[node.input[1]])
|
||||
if len(node.input) > 1
|
||||
else None
|
||||
)
|
||||
axes = _parse_data(data[node.input[1]]) if len(node.input) > 1 else None
|
||||
if axes is None:
|
||||
axes = next(
|
||||
(
|
||||
attr.ints
|
||||
for attr in node.attribute
|
||||
if attr.name == "axes"
|
||||
),
|
||||
(attr.ints for attr in node.attribute if attr.name == "axes"),
|
||||
[],
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.squeeze(
|
||||
|
@ -649,18 +720,10 @@ class OnnxStub:
|
|||
axes,
|
||||
)
|
||||
elif node.op_type == "Unsqueeze":
|
||||
axes = (
|
||||
_parse_data(data[node.input[1]])
|
||||
if len(node.input) > 1
|
||||
else None
|
||||
)
|
||||
axes = _parse_data(data[node.input[1]]) if len(node.input) > 1 else None
|
||||
if axes is None:
|
||||
axes = next(
|
||||
(
|
||||
attr.ints
|
||||
for attr in node.attribute
|
||||
if attr.name == "axes"
|
||||
)
|
||||
(attr.ints for attr in node.attribute if attr.name == "axes")
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.unsqueeze(
|
||||
tensors[node.input[0]],
|
||||
|
@ -684,24 +747,18 @@ class OnnxStub:
|
|||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "RoPE":
|
||||
tensors[node.output[0]]= self.handler.RoPE(
|
||||
tensors[node.output[0]] = self.handler.RoPE(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Split":
|
||||
split = (
|
||||
_parse_data(data[node.input[1]])
|
||||
if (len(node.input) > 1)
|
||||
else None
|
||||
_parse_data(data[node.input[1]]) if (len(node.input) > 1) else None
|
||||
)
|
||||
if split is None:
|
||||
split = next(
|
||||
(
|
||||
attr.ints
|
||||
for attr in node.attribute
|
||||
if attr.name == "split"
|
||||
),
|
||||
(attr.ints for attr in node.attribute if attr.name == "split"),
|
||||
None,
|
||||
)
|
||||
for name, tensor in zip(
|
||||
|
@ -710,11 +767,7 @@ class OnnxStub:
|
|||
tensors[node.input[0]],
|
||||
None,
|
||||
next(
|
||||
(
|
||||
attr.i
|
||||
for attr in node.attribute
|
||||
if attr.name == "axis"
|
||||
),
|
||||
(attr.i for attr in node.attribute if attr.name == "axis"),
|
||||
0,
|
||||
),
|
||||
split if split is not None else len(node.output),
|
||||
|
@ -942,18 +995,25 @@ class OnnxStub:
|
|||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Where":
|
||||
## If Y is single -inf, treat Where as Add
|
||||
## If Y is single -inf, treat Where as Add
|
||||
## TODO: deal with cases where Y is single inf or 0
|
||||
if node.input[0] in data and node.input[2] in data:
|
||||
where_condition = to_array(data[node.input[0]])
|
||||
where_alt = to_array(data[node.input[2]])
|
||||
where_alt = to_array(data[node.input[2]])
|
||||
if where_alt.size == 1:
|
||||
if np.isneginf(where_alt) or np.all(where_alt < -3e38):
|
||||
node.input[0] = node.input[0] + "_alt"
|
||||
if node.input[0] not in data:
|
||||
where_value = np.where(where_condition, 0, -np.inf).astype(where_alt.dtype)
|
||||
data[node.input[0]] = from_array(where_value, node.input[0])
|
||||
tensors[node.input[0]] = self.handler.tensor(list(where_value.shape), data[node.input[0]].data_type)
|
||||
where_value = np.where(
|
||||
where_condition, 0, -np.inf
|
||||
).astype(where_alt.dtype)
|
||||
data[node.input[0]] = from_array(
|
||||
where_value, node.input[0]
|
||||
)
|
||||
tensors[node.input[0]] = self.handler.tensor(
|
||||
list(where_value.shape),
|
||||
data[node.input[0]].data_type,
|
||||
)
|
||||
tensors[node.input[0]].set_weight()
|
||||
tensors[node.output[0]] = self.handler.add(
|
||||
tensors[node.input[1]],
|
||||
|
@ -980,8 +1040,7 @@ class OnnxStub:
|
|||
node, {"alpha": 0.0001, "beta": 0.75, "bias": 1.0, "size": 1}
|
||||
)
|
||||
(alpha, beta, bias, size) = (
|
||||
attributes[name]
|
||||
for name in ["alpha", "beta", "bias", "size"]
|
||||
attributes[name] for name in ["alpha", "beta", "bias", "size"]
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.lrn(
|
||||
tensors[node.input[0]],
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
#include "ascend/ascend_runtime.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/perf_engine.h"
|
||||
#ifdef INFINI_USE_HCCL
|
||||
#include "ascend/hccl_communicator.h"
|
||||
#endif
|
||||
|
||||
namespace infini {
|
||||
|
||||
void ASCENDRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
|
||||
bool profiling = false) const {
|
||||
const auto &kernelRegistry = KernelRegistry::getInstance();
|
||||
auto &perfEngine = PerfEngine::getInstance();
|
||||
double totalTime = 0;
|
||||
std::map<OpType, double> opTime;
|
||||
std::map<OpType, int> opCnt;
|
||||
for (auto &op : graph->getOperators()) {
|
||||
// HACK: set correct data type
|
||||
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
if (!perfData && !tune) {
|
||||
kernel->compute(op, this);
|
||||
continue;
|
||||
}
|
||||
|
||||
PerfRecord record;
|
||||
if (!perfData) {
|
||||
record = kernel->tune(op, this);
|
||||
perfEngine.setPerfData(perfKey, record);
|
||||
} else
|
||||
record = perfData;
|
||||
|
||||
double t = record->time;
|
||||
totalTime += t;
|
||||
|
||||
if (profiling) {
|
||||
double t = timeit([&]() { kernel->compute(op, record, this); },
|
||||
[&]() { sync(); }, 1, 1);
|
||||
op->print();
|
||||
printf(" op_time on kunlun xpu %lf\n", t);
|
||||
totalTime += t;
|
||||
opTime[op->getOpType()] += t;
|
||||
opCnt[op->getOpType()]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ASCENDRuntimeObj::run(const Graph &graph, bool tune,
|
||||
bool profiling) const {
|
||||
if (profiling)
|
||||
IT_TODO_HALT();
|
||||
runWithoutSync(graph, tune, profiling);
|
||||
sync();
|
||||
}
|
||||
|
||||
void ASCENDRuntimeObj::sync() const { aclrtSynchronizeStream(stream); }
|
||||
|
||||
string ASCENDRuntimeObj::toString() const { return "ASCEND Runtime"; }
|
||||
|
||||
void ASCENDRuntimeObj::initComm(const string &name, int worldSize, int rank) {
|
||||
IT_ASSERT(worldSize > 0);
|
||||
IT_ASSERT(rank >= 0);
|
||||
IT_ASSERT(rank < worldSize);
|
||||
IT_ASSERT(!comm) << "communicator is already initialized.";
|
||||
#ifdef INFINI_USE_HCCL
|
||||
comm = std::make_unique<HcclCommunicatorObj>(name, worldSize, rank);
|
||||
#else
|
||||
IT_TODO_HALT_MSG("Not compiled with CNCL.");
|
||||
#endif
|
||||
}
|
||||
} // namespace infini
|
|
@ -1,4 +1,4 @@
|
|||
#include "core/graph_handler.h"
|
||||
#include "core/graph_handler.h"
|
||||
#include "operators/all_gather.h"
|
||||
#include "operators/all_reduce.h"
|
||||
#include "operators/attention_kvcache.h"
|
||||
|
@ -9,6 +9,7 @@
|
|||
#include "operators/element_wise.h"
|
||||
#include "operators/expand.h"
|
||||
#include "operators/gather.h"
|
||||
#include "operators/instance_norm.h"
|
||||
#include "operators/layer_norm.h"
|
||||
#include "operators/lrn.h"
|
||||
#include "operators/matmul.h"
|
||||
|
@ -124,7 +125,29 @@ Tensor GraphHandlerObj::layerNormalization(Tensor input, Tensor scale,
|
|||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::instanceNormalization(Tensor input, Tensor output,
|
||||
Tensor scale, Tensor bias,
|
||||
float eps) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<InstanceNormObj>(
|
||||
std::move(input), output, std::move(scale), std::move(bias), eps);
|
||||
return output;
|
||||
} else {
|
||||
return g
|
||||
->addOp<InstanceNormObj>(std::move(input), output, std::move(scale),
|
||||
std::move(bias), eps)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
Tensor GraphHandlerObj::leakyrelu(Tensor input, Tensor output, float alpha) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<LeakyReluObj>(std::move(input), output, alpha);
|
||||
return output;
|
||||
} else {
|
||||
return g->addOp<LeakyReluObj>(std::move(input), output, alpha)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
Tensor GraphHandlerObj::rmsNorm(Tensor input, Tensor weight, Tensor output) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<RMSNormObj>(std::move(input), std::move(weight),
|
||||
|
@ -276,13 +299,13 @@ Tensor GraphHandlerObj::reshape(Tensor data, Tensor reshaped, Shape shape) {
|
|||
Tensor GraphHandlerObj::resize(Tensor input, Tensor output,
|
||||
const std::optional<vector<int>> &axes,
|
||||
Tensor sizes, Tensor scales, Tensor roi,
|
||||
vector<uint32_t> sizes_, vector<float> scales_,
|
||||
vector<int64_t> sizes_, vector<float> scales_,
|
||||
vector<float> roi_, string mode,
|
||||
string ratioPolicy, string nearestMode,
|
||||
string coordTransMode) {
|
||||
if (sizes_.size() > 0) {
|
||||
sizes->dataMalloc();
|
||||
sizes->copyin<uint32_t>(sizes_);
|
||||
sizes->copyin<int64_t>(sizes_);
|
||||
}
|
||||
if (scales_.size() > 0) {
|
||||
scales->dataMalloc();
|
||||
|
|
|
@ -30,6 +30,9 @@
|
|||
#ifdef USE_KUNLUN
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
#endif
|
||||
#ifdef USE_ASCEND
|
||||
#include "ascend/ascend_runtime.h"
|
||||
#endif
|
||||
#ifdef USE_INTELCPU
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "intelcpu/operator_timer.h"
|
||||
|
@ -65,6 +68,7 @@ void export_values(py::module &m) {
|
|||
py::enum_<ActType>(m, "ActType")
|
||||
.value("Linear", ActType::None) // `None` is Python keyword
|
||||
.VALUE(ActType, Relu)
|
||||
.VALUE(ActType, LeakyRelu)
|
||||
.VALUE(ActType, Sigmoid)
|
||||
.VALUE(ActType, Tanh)
|
||||
.export_values();
|
||||
|
@ -102,6 +106,7 @@ void export_values(py::module &m) {
|
|||
.VALUE(OpType, BatchNormalization)
|
||||
.VALUE(OpType, Softmax)
|
||||
.VALUE(OpType, Relu)
|
||||
.VALUE(OpType, LeakyRelu)
|
||||
.VALUE(OpType, Gelu)
|
||||
.VALUE(OpType, PRelu)
|
||||
.VALUE(OpType, Sigmoid)
|
||||
|
@ -175,6 +180,12 @@ static Ref<KUNLUNRuntimeObj> kunlun_runtime() {
|
|||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_ASCEND
|
||||
static Ref<ASCENDRuntimeObj> ascend_runtime() {
|
||||
return make_ref<ASCENDRuntimeObj>();
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_INTELCPU
|
||||
static Ref<RuntimeObj> intelcpu_runtime() { return make_ref<MklRuntimeObj>(); }
|
||||
#endif
|
||||
|
@ -347,6 +358,10 @@ void export_functions(py::module &m) {
|
|||
#ifdef USE_KUNLUN
|
||||
.FUNCTION(kunlun_runtime)
|
||||
#endif
|
||||
|
||||
#ifdef USE_ASCEND
|
||||
.FUNCTION(ascend_runtime)
|
||||
#endif
|
||||
.FUNCTION(conv_attrs_of)
|
||||
.FUNCTION(conv_trans_attrs_of)
|
||||
.FUNCTION(matmul_attrs_of)
|
||||
|
@ -433,6 +448,14 @@ void init_graph_builder(py::module &m) {
|
|||
.def(py::init<int>(), py::arg("device") = 0)
|
||||
.def("init_comm", &KUNLUNRuntimeObj::initComm);
|
||||
#endif
|
||||
|
||||
#ifdef USE_ASCEND
|
||||
py::class_<ASCENDRuntimeObj, std::shared_ptr<ASCENDRuntimeObj>, RuntimeObj>(
|
||||
m, "ASCENDRuntime")
|
||||
.def(py::init<int>(), py::arg("device") = 0)
|
||||
.def("init_comm", &ASCENDRuntimeObj::initComm);
|
||||
;
|
||||
#endif
|
||||
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor",
|
||||
py::buffer_protocol())
|
||||
.def("fuid", &TensorObj::getFuid, policy::automatic)
|
||||
|
@ -506,6 +529,8 @@ void init_graph_builder(py::module &m) {
|
|||
.def("matmul", &Handler::matmul, policy::move)
|
||||
.def("batchNormalization", &Handler::batchNormalization, policy::move)
|
||||
.def("layerNormalization", &Handler::layerNormalization, policy::move)
|
||||
.def("instanceNormalization", &Handler::instanceNormalization,
|
||||
policy::move)
|
||||
.def("RMSNorm", &Handler::rmsNorm, policy::move)
|
||||
.def("maxPool", &Handler::maxPool, policy::move)
|
||||
.def("avgPool", &Handler::avgPool, policy::move)
|
||||
|
@ -518,6 +543,7 @@ void init_graph_builder(py::module &m) {
|
|||
.def("min", &Handler::min, policy::move)
|
||||
.def("max", &Handler::max, policy::move)
|
||||
.def("relu", &Handler::relu, policy::move)
|
||||
.def("leakyrelu", &Handler::leakyrelu, policy::move)
|
||||
.def("silu", &Handler::silu, policy::move)
|
||||
.def("gelu", &Handler::gelu, policy::move)
|
||||
.def("sigmoid", &Handler::sigmoid, policy::move)
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
#ifdef INFINI_USE_HCCL
|
||||
#include "operators/all_gather.h"
|
||||
#include "ascend/ascend_kernel_without_config.h"
|
||||
#include "ascend/ascend_runtime.h"
|
||||
#include "ascend/hccl_communicator.h"
|
||||
#include "hccl/hccl.h"
|
||||
#include "hccl/hccl_types.h"
|
||||
|
||||
namespace infini {
|
||||
class AllGatherHCCL : public ASCENDKernelWithoutConfig {
|
||||
public:
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<AllGatherObj>(_op);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
int world_size = op->getWorldSize();
|
||||
// Check if world size info in operator matches runtime
|
||||
IT_ASSERT(world_size == context->getCommunicator().getWorldSize());
|
||||
|
||||
void *input = op->getInputs(0)->getRawDataPtr<void *>();
|
||||
ASCENDPtr output_temp =
|
||||
context->getWorkspace(op->getInputs(0)->getBytes() * world_size);
|
||||
// void *output = op->getOutput()->getRawDataPtr<void *>();
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
size_t bytes = op->getInputs(0)->getBytes();
|
||||
size_t count = bytes / op->getDType().getSize();
|
||||
|
||||
HcclComm comm =
|
||||
dynamic_cast<HcclCommunicatorObj &>(context->getCommunicator())
|
||||
.getHcclComm();
|
||||
// TODO: Using default stream 0 for now.
|
||||
HCCLCHECK(HcclAllGather(input, output_temp, uint64_t(count),
|
||||
HCCL_DATA_TYPE_FP32, comm,
|
||||
context->ASCENDHandle()));
|
||||
ACLCHECK(aclrtSynchronizeStream(context->ASCENDHandle()));
|
||||
|
||||
for (int i = 0; i < world_size; ++i) {
|
||||
Tensor output = op->getOutput(i);
|
||||
context->copyBlobInsideRuntime(
|
||||
output->getRawDataPtr<float *>(),
|
||||
static_cast<float *>(output_temp) + i * count, bytes);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::AllGather, AllGatherHCCL,
|
||||
"AllGather_HCCL_ASCEND");
|
||||
} // namespace infini
|
||||
|
||||
#endif
|
|
@ -0,0 +1,58 @@
|
|||
#ifdef INFINI_USE_HCCL
|
||||
#include "operators/all_reduce.h"
|
||||
#include "ascend/ascend_kernel_without_config.h"
|
||||
#include "ascend/ascend_runtime.h"
|
||||
#include "ascend/hccl_communicator.h"
|
||||
#include "hccl/hccl.h"
|
||||
#include "hccl/hccl_types.h"
|
||||
|
||||
namespace infini {
|
||||
class AllReduceHCCL : public ASCENDKernelWithoutConfig {
|
||||
public:
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<AllReduceBaseObj>(_op);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
void *sendBuff = op->getInputs(0)->getRawDataPtr<void *>();
|
||||
void *recvBuff = op->getOutput()->getRawDataPtr<void *>();
|
||||
|
||||
// HcclDataType
|
||||
|
||||
size_t count = op->getInputs(0)->size();
|
||||
|
||||
HcclComm comm =
|
||||
dynamic_cast<HcclCommunicatorObj &>(context->getCommunicator())
|
||||
.getHcclComm();
|
||||
// TODO: Using default stream 0 for now.
|
||||
HCCLCHECK(HcclAllReduce(sendBuff, recvBuff, count, HCCL_DATA_TYPE_FP32,
|
||||
getRedOp(), comm, context->ASCENDHandle()));
|
||||
ACLCHECK(aclrtSynchronizeStream(context->ASCENDHandle()));
|
||||
}
|
||||
|
||||
virtual HcclReduceOp getRedOp() const = 0;
|
||||
};
|
||||
|
||||
class AllReduceSumHCCL : public AllReduceHCCL {
|
||||
HcclReduceOp getRedOp() const override { return HCCL_REDUCE_SUM; }
|
||||
};
|
||||
class AllReduceProdHCCL : public AllReduceHCCL {
|
||||
HcclReduceOp getRedOp() const override { return HCCL_REDUCE_PROD; }
|
||||
};
|
||||
class AllReduceMinHCCL : public AllReduceHCCL {
|
||||
HcclReduceOp getRedOp() const override { return HCCL_REDUCE_MIN; }
|
||||
};
|
||||
class AllReduceMaxHCCL : public AllReduceHCCL {
|
||||
HcclReduceOp getRedOp() const override { return HCCL_REDUCE_MAX; }
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::AllReduceSum, AllReduceSumHCCL,
|
||||
"AllReduce_Sum_HCCL_ASCEND");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::AllReduceProd, AllReduceProdHCCL,
|
||||
"AllReduce_Prod_HCCL_ASCEND");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::AllReduceMin, AllReduceMinHCCL,
|
||||
"AllReduce_Min_HCCL_ASCEND");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::AllReduceMax, AllReduceMaxHCCL,
|
||||
"AllReduce_Max_HCCL_ASCEND");
|
||||
|
||||
} // namespace infini
|
||||
#endif
|
|
@ -0,0 +1,100 @@
|
|||
#include "operators/batch_norm.h"
|
||||
#include "aclnnop/level2/aclnn_batch_norm.h"
|
||||
#include "ascend/ascend_kernel_without_config.h"
|
||||
#include "ascend/ascend_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class BatchNormAclnn : public ASCENDKernelWithoutConfig {
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<BatchNormObj>(_op);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
void *const meanData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const varData = (op->getInputs(2)->getRawDataPtr<void *>());
|
||||
void *const scaleData = (op->getInputs(3)->getRawDataPtr<void *>());
|
||||
void *const biasData = (op->getInputs(4)->getRawDataPtr<void *>());
|
||||
|
||||
auto inD = op->getInputs(0)->getDims();
|
||||
auto inS = op->getInputs(0)->getStride();
|
||||
auto paraD = op->getInputs(1)->getDims();
|
||||
auto paraS = op->getInputs(1)->getStride();
|
||||
auto outD = op->getOutput()->getDims();
|
||||
auto outS = op->getOutput()->getStride();
|
||||
|
||||
std::vector<int64_t> inputDim = castTo64(inD);
|
||||
std::vector<int64_t> inputStride = castTo64(inS);
|
||||
std::vector<int64_t> paraDim = castTo64(paraD);
|
||||
std::vector<int64_t> paraStride = castTo64(paraS);
|
||||
std::vector<int64_t> outputDim = castTo64(outD);
|
||||
std::vector<int64_t> outputStride = castTo64(outS);
|
||||
|
||||
auto inputTensor =
|
||||
aclCreateTensor(inputDim.data(), inputDim.size(), ACL_FLOAT,
|
||||
inputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW,
|
||||
inputDim.data(), inputDim.size(), inData);
|
||||
auto outputTensor =
|
||||
aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT,
|
||||
outputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW,
|
||||
outputDim.data(), outputDim.size(), outData);
|
||||
auto meanTensor = aclCreateTensor(
|
||||
paraDim.data(), paraDim.size(), ACL_FLOAT, paraStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, paraDim.data(), paraDim.size(), meanData);
|
||||
auto varTensor = aclCreateTensor(
|
||||
paraDim.data(), paraDim.size(), ACL_FLOAT, paraStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, paraDim.data(), paraDim.size(), varData);
|
||||
auto scaleTensor =
|
||||
aclCreateTensor(paraDim.data(), paraDim.size(), ACL_FLOAT,
|
||||
paraStride.data(), 0, aclFormat::ACL_FORMAT_ND,
|
||||
paraDim.data(), paraDim.size(), scaleData);
|
||||
auto biasTensor = aclCreateTensor(
|
||||
paraDim.data(), paraDim.size(), ACL_FLOAT, paraStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, paraDim.data(), paraDim.size(), biasData);
|
||||
auto savemeanTensor =
|
||||
aclCreateTensor(paraDim.data(), paraDim.size(), ACL_FLOAT,
|
||||
paraStride.data(), 0, aclFormat::ACL_FORMAT_ND,
|
||||
paraDim.data(), paraDim.size(), scaleData);
|
||||
auto saveinvstdTensor = aclCreateTensor(
|
||||
paraDim.data(), paraDim.size(), ACL_FLOAT, paraStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, paraDim.data(), paraDim.size(), biasData);
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
|
||||
auto ret = aclnnBatchNormGetWorkspaceSize(
|
||||
inputTensor, scaleTensor, biasTensor, meanTensor, varTensor, false,
|
||||
op->getMomentum(), op->getEps(), outputTensor, savemeanTensor,
|
||||
saveinvstdTensor, &workspaceSize, &executor);
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
// auto tmp_err_msg = aclGetRecentErrMsg();
|
||||
// if (tmp_err_msg != NULL) {
|
||||
// printf(" ERROR Message : %s \n ", tmp_err_msg);
|
||||
// }
|
||||
assert(ret == ACL_SUCCESS);
|
||||
ret = aclnnBatchNorm(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
assert(ret == ACL_SUCCESS);
|
||||
|
||||
// aclDestroyTensor(inputTensor);
|
||||
// aclDestroyTensor(outputTensor);
|
||||
// aclDestroyTensor(meanTensor);
|
||||
// aclDestroyTensor(varTensor);
|
||||
// aclDestroyTensor(scaleTensor);
|
||||
// aclDestroyTensor(biasTensor);
|
||||
// aclDestroyTensor(savemeanTensor);
|
||||
// aclDestroyTensor(saveinvstdTensor);
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::BatchNormalization, BatchNormAclnn,
|
||||
"batchnorm_ASCEND_float");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,70 @@
|
|||
#include "operators/concat.h"
|
||||
#include "aclnnop/level2/aclnn_cat.h"
|
||||
#include "ascend/ascend_kernel_without_config.h"
|
||||
#include "ascend/ascend_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class ConcatAclnn : public ASCENDKernelWithoutConfig {
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ConcatObj>(_op);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
int dim = op->getDim();
|
||||
int num = op->numInputs();
|
||||
|
||||
std::vector<aclTensor *> inputsData{};
|
||||
|
||||
for (int i = 0; i < num; ++i) {
|
||||
auto inD = op->getInputs(i)->getDims();
|
||||
auto inS = op->getInputs(i)->getStride();
|
||||
std::vector<int64_t> inputDim = castTo64(inD);
|
||||
std::vector<int64_t> inputStride = castTo64(inS);
|
||||
|
||||
void *const inData = (op->getInputs(i)->getRawDataPtr<void *>());
|
||||
auto tmpTensor =
|
||||
aclCreateTensor(inputDim.data(), inputDim.size(), ACL_FLOAT,
|
||||
inputStride.data(), 0, aclFormat::ACL_FORMAT_ND,
|
||||
inputDim.data(), inputDim.size(), inData);
|
||||
|
||||
inputsData.push_back(tmpTensor);
|
||||
}
|
||||
aclTensorList *tensorList =
|
||||
aclCreateTensorList(inputsData.data(), inputsData.size());
|
||||
|
||||
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto outD = op->getOutput()->getDims();
|
||||
auto outS = op->getOutput()->getStride();
|
||||
std::vector<int64_t> outputDim = castTo64(outD);
|
||||
std::vector<int64_t> outputStride = castTo64(outS);
|
||||
|
||||
auto outputTensor =
|
||||
aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT,
|
||||
outputStride.data(), 0, aclFormat::ACL_FORMAT_ND,
|
||||
outputDim.data(), outputDim.size(), outData);
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
|
||||
auto ret = aclnnCatGetWorkspaceSize(
|
||||
tensorList, int64_t(dim), outputTensor, &workspaceSize, &executor);
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
assert(ret == ACL_SUCCESS);
|
||||
ret = aclnnCat(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
assert(ret == ACL_SUCCESS);
|
||||
|
||||
// aclDestroyTensorList(tensorList);
|
||||
// aclDestroyTensor(outputTensor);
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Concat, ConcatAclnn,
|
||||
"concat_ASCEND_float");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,94 @@
|
|||
#include "operators/conv.h"
|
||||
#include "aclnnop/level2/aclnn_convolution.h"
|
||||
#include "ascend/ascend_kernel_without_config.h"
|
||||
#include "ascend/ascend_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class ConvAclnn : public ASCENDKernelWithoutConfig {
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ConvObj>(_op);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||
const auto [n, c, h, w, f, r, s] = op->getNCHWFRS();
|
||||
const int cpg = op->getChannelPerGroup();
|
||||
const int g = c / cpg;
|
||||
|
||||
std::vector<int64_t> pads = {ph, pw};
|
||||
// std::vector<int64_t> ksize = {r, s};
|
||||
std::vector<int64_t> stride = {sh, sw};
|
||||
std::vector<int64_t> dilation = {dh, dw};
|
||||
std::vector<int64_t> outputPadding = {sh - 1, sw - 1};
|
||||
|
||||
aclIntArray *convpads = aclCreateIntArray(pads.data(), pads.size());
|
||||
aclIntArray *convstride =
|
||||
aclCreateIntArray(stride.data(), stride.size());
|
||||
aclIntArray *convdilation =
|
||||
aclCreateIntArray(dilation.data(), dilation.size());
|
||||
aclIntArray *convOutputpadding =
|
||||
aclCreateIntArray(outputPadding.data(), outputPadding.size());
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto inputD = op->getInputs(0)->getDims();
|
||||
auto inputS = op->getInputs(0)->getStride();
|
||||
auto weightD = op->getInputs(1)->getDims();
|
||||
auto weightS = op->getInputs(1)->getStride();
|
||||
auto outD = op->getOutput()->getDims();
|
||||
auto outS = op->getOutput()->getStride();
|
||||
|
||||
std::vector<int64_t> inputDim = castTo64(inputD);
|
||||
std::vector<int64_t> inputStride = castTo64(inputS);
|
||||
std::vector<int64_t> weightDim = castTo64(weightD);
|
||||
std::vector<int64_t> weightStride = castTo64(weightS);
|
||||
std::vector<int64_t> outputDim = castTo64(outD);
|
||||
std::vector<int64_t> outputStride = castTo64(outS);
|
||||
|
||||
auto inputTensor =
|
||||
aclCreateTensor(inputDim.data(), inputDim.size(), ACL_FLOAT,
|
||||
inputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW,
|
||||
inputDim.data(), inputDim.size(), aData);
|
||||
auto weightTensor =
|
||||
aclCreateTensor(weightDim.data(), weightDim.size(), ACL_FLOAT,
|
||||
weightStride.data(), 0, aclFormat::ACL_FORMAT_NCHW,
|
||||
weightDim.data(), weightDim.size(), bData);
|
||||
auto outputTensor =
|
||||
aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT,
|
||||
outputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW,
|
||||
outputDim.data(), outputDim.size(), cData);
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
|
||||
auto ret = aclnnConvolutionGetWorkspaceSize(
|
||||
inputTensor, weightTensor, nullptr, convstride, convpads,
|
||||
convdilation, false, convOutputpadding, int64_t(g), outputTensor,
|
||||
int8_t(1), &workspaceSize, &executor);
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
// auto tmp_err_msg = aclGetRecentErrMsg();
|
||||
// if (tmp_err_msg != NULL) {
|
||||
// printf(" ERROR Message : %s \n ", tmp_err_msg);
|
||||
// }
|
||||
assert(ret == ACL_SUCCESS);
|
||||
ret = aclnnConvolution(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
assert(ret == ACL_SUCCESS);
|
||||
|
||||
// aclDestroyTensor(inputTensor);
|
||||
// aclDestroyTensor(weightTensor);
|
||||
// aclDestroyTensor(outputTensor);
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Conv, ConvAclnn, "conv_ASCEND_float");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,95 @@
|
|||
#include "aclnnop/level2/aclnn_convolution.h"
|
||||
#include "ascend/ascend_kernel_without_config.h"
|
||||
#include "ascend/ascend_runtime.h"
|
||||
#include "operators/conv.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class ConvTransAclnn : public ASCENDKernelWithoutConfig {
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ConvTransposed2dObj>(_op);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||
const auto [oph, opw] = op->getOutputPadding();
|
||||
const auto [n, c, h, w, f, r, s] = op->getNCHWFRS();
|
||||
const int cpg = op->getChannelPerGroup();
|
||||
const int g = c / cpg;
|
||||
|
||||
std::vector<int64_t> pads = {ph, pw};
|
||||
// std::vector<int64_t> ksize = {r, s};
|
||||
std::vector<int64_t> stride = {sh, sw};
|
||||
std::vector<int64_t> dilation = {dh, dw};
|
||||
std::vector<int64_t> outputPadding = {oph, opw};
|
||||
|
||||
aclIntArray *convpads = aclCreateIntArray(pads.data(), pads.size());
|
||||
aclIntArray *convstride =
|
||||
aclCreateIntArray(stride.data(), stride.size());
|
||||
aclIntArray *convdilation =
|
||||
aclCreateIntArray(dilation.data(), dilation.size());
|
||||
aclIntArray *convOutputpadding =
|
||||
aclCreateIntArray(outputPadding.data(), outputPadding.size());
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
// void *const biasData = (op->getBias()->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto inputD = op->getInputs(0)->getDims();
|
||||
auto inputS = op->getInputs(0)->getStride();
|
||||
auto weightD = op->getInputs(1)->getDims();
|
||||
auto weightS = op->getInputs(1)->getStride();
|
||||
// auto biasD = op->getBias()->getDims();
|
||||
// auto biasS = op->getBias()->getStride();
|
||||
auto outD = op->getOutput()->getDims();
|
||||
auto outS = op->getOutput()->getStride();
|
||||
|
||||
std::vector<int64_t> inputDim = castTo64(inputD);
|
||||
std::vector<int64_t> inputStride = castTo64(inputS);
|
||||
std::vector<int64_t> weightDim = castTo64(weightD);
|
||||
std::vector<int64_t> weightStride = castTo64(weightS);
|
||||
std::vector<int64_t> outputDim = castTo64(outD);
|
||||
std::vector<int64_t> outputStride = castTo64(outS);
|
||||
|
||||
auto inputTensor =
|
||||
aclCreateTensor(inputDim.data(), inputDim.size(), ACL_FLOAT,
|
||||
inputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW,
|
||||
inputDim.data(), inputDim.size(), aData);
|
||||
auto weightTensor =
|
||||
aclCreateTensor(weightDim.data(), weightDim.size(), ACL_FLOAT,
|
||||
weightStride.data(), 0, aclFormat::ACL_FORMAT_NCHW,
|
||||
weightDim.data(), weightDim.size(), bData);
|
||||
auto outputTensor =
|
||||
aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT,
|
||||
outputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW,
|
||||
outputDim.data(), outputDim.size(), cData);
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
|
||||
auto ret = aclnnConvolutionGetWorkspaceSize(
|
||||
inputTensor, weightTensor, nullptr, convstride, convpads,
|
||||
convdilation, true, convOutputpadding, int64_t(g), outputTensor,
|
||||
int8_t(1), &workspaceSize, &executor);
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
assert(ret == ACL_SUCCESS);
|
||||
ret = aclnnConvolution(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
assert(ret == ACL_SUCCESS);
|
||||
|
||||
// aclDestroyTensor(inputTensor);
|
||||
// aclDestroyTensor(weightTensor);
|
||||
// aclDestroyTensor(outputTensor);
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::ConvTranspose, ConvTransAclnn,
|
||||
"ConvTrans_ASCEND_float");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,206 @@
|
|||
#include "operators/element_wise.h"
|
||||
#include "aclnnop/aclnn_maximum.h"
|
||||
#include "aclnnop/level2/aclnn_add.h"
|
||||
#include "aclnnop/level2/aclnn_div.h"
|
||||
#include "aclnnop/level2/aclnn_mul.h"
|
||||
#include "aclnnop/level2/aclnn_pow_tensor_tensor.h"
|
||||
#include "aclnnop/level2/aclnn_sub.h"
|
||||
#include "ascend/ascend_kernel_without_config.h"
|
||||
#include "ascend/ascend_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
#define DEFINE_ELEMENT_WISE_Aclnn(prefix) \
|
||||
class prefix##Aclnn : public ASCENDKernelWithoutConfig { \
|
||||
void compute(const Operator &_op, \
|
||||
const RuntimeObj *_context) const override { \
|
||||
auto op = as<ElementWiseObj>(_op); \
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context); \
|
||||
\
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); \
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>()); \
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>()); \
|
||||
\
|
||||
auto a = op->getInputs(0)->getDims(); \
|
||||
auto aS = op->getInputs(0)->getStride(); \
|
||||
auto b = op->getInputs(1)->getDims(); \
|
||||
auto bS = op->getInputs(1)->getStride(); \
|
||||
auto c = op->getOutput()->getDims(); \
|
||||
auto cS = op->getOutput()->getStride(); \
|
||||
\
|
||||
std::vector<int64_t> aDim = castTo64(a); \
|
||||
std::vector<int64_t> aStride = castTo64(aS); \
|
||||
std::vector<int64_t> bDim = castTo64(b); \
|
||||
std::vector<int64_t> bStride = castTo64(bS); \
|
||||
std::vector<int64_t> cDim = castTo64(c); \
|
||||
std::vector<int64_t> cStride = castTo64(cS); \
|
||||
\
|
||||
auto inputA = aclCreateTensor( \
|
||||
aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0, \
|
||||
aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData); \
|
||||
auto inputB = aclCreateTensor( \
|
||||
bDim.data(), bDim.size(), ACL_FLOAT, bStride.data(), 0, \
|
||||
aclFormat::ACL_FORMAT_ND, bDim.data(), bDim.size(), bData); \
|
||||
auto output = aclCreateTensor( \
|
||||
cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0, \
|
||||
aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData); \
|
||||
\
|
||||
uint64_t workspaceSize = 0; \
|
||||
aclOpExecutor *executor; \
|
||||
\
|
||||
auto ret = aclnn##prefix##GetWorkspaceSize( \
|
||||
inputA, inputB, output, &workspaceSize, &executor); \
|
||||
void *workspaceAddr = nullptr; \
|
||||
if (workspaceSize > 0) { \
|
||||
workspaceAddr = context->getWorkspace(workspaceSize); \
|
||||
} \
|
||||
assert(ret == ACL_SUCCESS); \
|
||||
ret = aclnn##prefix(workspaceAddr, workspaceSize, executor, \
|
||||
context->ASCENDHandle()); \
|
||||
assert(ret == ACL_SUCCESS); \
|
||||
\
|
||||
ret = aclDestroyTensor(inputA); \
|
||||
ret = aclDestroyTensor(inputB); \
|
||||
ret = aclDestroyTensor(output); \
|
||||
\
|
||||
return; \
|
||||
} \
|
||||
};
|
||||
|
||||
class AddAclnn : public ASCENDKernelWithoutConfig {
|
||||
virtual tuple<float, float, float> getAlphBeta() const {
|
||||
return {1.f, 1.f, 0.f};
|
||||
}
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto a = op->getInputs(0)->getDims();
|
||||
auto aS = op->getInputs(0)->getStride();
|
||||
auto b = op->getInputs(1)->getDims();
|
||||
auto bS = op->getInputs(1)->getStride();
|
||||
auto c = op->getOutput()->getDims();
|
||||
auto cS = op->getOutput()->getStride();
|
||||
|
||||
std::vector<int64_t> aDim = castTo64(a);
|
||||
std::vector<int64_t> aStride = castTo64(aS);
|
||||
std::vector<int64_t> bDim = castTo64(b);
|
||||
std::vector<int64_t> bStride = castTo64(bS);
|
||||
std::vector<int64_t> cDim = castTo64(c);
|
||||
std::vector<int64_t> cStride = castTo64(cS);
|
||||
|
||||
auto inputA = aclCreateTensor(
|
||||
aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData);
|
||||
auto inputB = aclCreateTensor(
|
||||
bDim.data(), bDim.size(), ACL_FLOAT, bStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, bDim.data(), bDim.size(), bData);
|
||||
auto output = aclCreateTensor(
|
||||
cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData);
|
||||
|
||||
auto [aAlpha, bAlpha, beta] = getAlphBeta();
|
||||
auto alpha = aclCreateScalar(&bAlpha, ACL_FLOAT);
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
|
||||
auto ret = aclnnAddGetWorkspaceSize(inputA, inputB, alpha, output,
|
||||
&workspaceSize, &executor);
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
assert(ret == ACL_SUCCESS);
|
||||
ret = aclnnAdd(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
assert(ret == ACL_SUCCESS);
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class SubAclnn : public ASCENDKernelWithoutConfig {
|
||||
virtual tuple<float, float, float> getAlphBeta() const {
|
||||
return {1.f, 1.f, 0.f};
|
||||
}
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto a = op->getInputs(0)->getDims();
|
||||
auto aS = op->getInputs(0)->getStride();
|
||||
auto b = op->getInputs(1)->getDims();
|
||||
auto bS = op->getInputs(1)->getStride();
|
||||
auto c = op->getOutput()->getDims();
|
||||
auto cS = op->getOutput()->getStride();
|
||||
|
||||
std::vector<int64_t> aDim = castTo64(a);
|
||||
std::vector<int64_t> aStride = castTo64(aS);
|
||||
std::vector<int64_t> bDim = castTo64(b);
|
||||
std::vector<int64_t> bStride = castTo64(bS);
|
||||
std::vector<int64_t> cDim = castTo64(c);
|
||||
std::vector<int64_t> cStride = castTo64(cS);
|
||||
|
||||
auto inputA = aclCreateTensor(
|
||||
aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData);
|
||||
auto inputB = aclCreateTensor(
|
||||
bDim.data(), bDim.size(), ACL_FLOAT, bStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, bDim.data(), bDim.size(), bData);
|
||||
auto output = aclCreateTensor(
|
||||
cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData);
|
||||
|
||||
auto [aAlpha, bAlpha, beta] = getAlphBeta();
|
||||
auto alpha = aclCreateScalar(&bAlpha, ACL_FLOAT);
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
|
||||
auto ret = aclnnSubGetWorkspaceSize(inputA, inputB, alpha, output,
|
||||
&workspaceSize, &executor);
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
assert(ret == ACL_SUCCESS);
|
||||
ret = aclnnSub(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
assert(ret == ACL_SUCCESS);
|
||||
|
||||
ret = aclDestroyTensor(inputA);
|
||||
ret = aclDestroyTensor(inputB);
|
||||
ret = aclDestroyScalar(alpha);
|
||||
ret = aclDestroyTensor(output);
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
DEFINE_ELEMENT_WISE_Aclnn(PowTensorTensor);
|
||||
DEFINE_ELEMENT_WISE_Aclnn(Div);
|
||||
DEFINE_ELEMENT_WISE_Aclnn(Mul);
|
||||
DEFINE_ELEMENT_WISE_Aclnn(Maximum);
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Pow, PowTensorTensorAclnn,
|
||||
"pow_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Div, DivAclnn, "div_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Mul, MulAclnn, "mul_ASCEND_float");
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Add, AddAclnn, "add_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Sub, SubAclnn, "sub_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Max, MaximumAclnn, "max_ASCEND_float");
|
||||
// REGISTER_KERNEL(Device::ASCEND, OpType::Abs, AbsAclnn, "abs_ASCEND_float");
|
||||
|
||||
}; // namespace infini
|
|
@ -0,0 +1,84 @@
|
|||
#include "operators/gather.h"
|
||||
#include "aclnnop/level2/aclnn_gather_v2.h"
|
||||
#include "ascend/ascend_kernel_without_config.h"
|
||||
#include "ascend/ascend_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class GatherAclnn : public ASCENDKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<GatherObj>(_op);
|
||||
IT_ASSERT(op->getInputs(1)->getDType() == DataType::Int32 ||
|
||||
op->getInputs(1)->getDType() == DataType::Int64);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
int64_t axis = int64_t(op->getAxis());
|
||||
|
||||
auto a = op->getInputs(0)->getDims();
|
||||
auto aS = op->getInputs(0)->getStride();
|
||||
auto b = op->getInputs(1)->getDims();
|
||||
auto bS = op->getInputs(1)->getStride();
|
||||
auto c = op->getOutput()->getDims();
|
||||
auto cS = op->getOutput()->getStride();
|
||||
|
||||
if (b.size() == 0) {
|
||||
c.insert(c.begin() + axis, 1);
|
||||
cS.insert(cS.begin() + axis, axis > 0 ? cS[axis - 1] : cS[0]);
|
||||
}
|
||||
|
||||
std::vector<int64_t> aDim = castTo64(a);
|
||||
std::vector<int64_t> aStride = castTo64(aS);
|
||||
std::vector<int64_t> bDim = castTo64(b);
|
||||
std::vector<int64_t> bStride = castTo64(bS);
|
||||
std::vector<int64_t> cDim = castTo64(c);
|
||||
std::vector<int64_t> cStride = castTo64(cS);
|
||||
|
||||
auto inputA = aclCreateTensor(
|
||||
aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData);
|
||||
|
||||
auto inputB = aclCreateTensor(
|
||||
bDim.data(), bDim.size(),
|
||||
// op->getInputs(1)->getDType() == DataType::Int32 ? ACL_INT32
|
||||
// : ACL_INT64,
|
||||
ACL_INT64, bStride.data(), 0, aclFormat::ACL_FORMAT_ND, bDim.data(),
|
||||
bDim.size(), bData);
|
||||
|
||||
auto output = aclCreateTensor(
|
||||
cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData);
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
|
||||
auto ret = aclnnGatherV2GetWorkspaceSize(inputA, axis, inputB, output,
|
||||
&workspaceSize, &executor);
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
CHECK_RET(ret == ACL_SUCCESS,
|
||||
LOG_PRINT("aclnnGatherV2GetWorkspaceSize failed. ERROR: %d\n",
|
||||
ret));
|
||||
|
||||
ret = aclnnGatherV2(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
CHECK_RET(ret == ACL_SUCCESS,
|
||||
LOG_PRINT("aclnnGatherV2 failed. ERROR: %d\n", ret));
|
||||
// auto tmp_err_msg = aclGetRecentErrMsg();
|
||||
// if (tmp_err_msg != NULL) {
|
||||
// printf(" ERROR Message : %s \n ", tmp_err_msg);
|
||||
// }
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Gather, GatherAclnn,
|
||||
"gather_ASCEND_float");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,105 @@
|
|||
#include "operators/instance_norm.h"
|
||||
#include "aclnnop/level2/aclnn_layer_norm.h"
|
||||
#include "ascend/ascend_kernel_without_config.h"
|
||||
#include "ascend/ascend_runtime.h"
|
||||
#include "operators/gather.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class InstanceNormAclnn : public ASCENDKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<InstanceNormObj>(_op);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const weightData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto inputD = op->getInputs(0)->getDims();
|
||||
auto inputS = op->getInputs(0)->getStride();
|
||||
auto weightD = op->getInputs(1)->getDims();
|
||||
auto weightS = op->getInputs(1)->getStride();
|
||||
auto outD = op->getOutput()->getDims();
|
||||
auto outS = op->getOutput()->getStride();
|
||||
|
||||
double eps = static_cast<double>(op->getEps());
|
||||
|
||||
std::vector<int64_t> inputDim = castTo64(inputD);
|
||||
std::vector<int64_t> inputStride = castTo64(inputS);
|
||||
std::vector<int64_t> weightDim = castTo64(weightD);
|
||||
std::vector<int64_t> weightStride = castTo64(weightS);
|
||||
std::vector<int64_t> outputDim = castTo64(outD);
|
||||
std::vector<int64_t> outputStride = castTo64(outS);
|
||||
|
||||
auto axis = 3;
|
||||
|
||||
auto rank = static_cast<int>(inputDim.size());
|
||||
std::vector<int64_t> normalizedShape(rank - axis, 0);
|
||||
for (auto i = rank; i > axis; --i) {
|
||||
normalizedShape[i - 1 - axis] = inputDim[i - 1];
|
||||
}
|
||||
|
||||
auto inputTensor =
|
||||
aclCreateTensor(inputDim.data(), inputDim.size(), ACL_FLOAT,
|
||||
inputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW,
|
||||
inputDim.data(), inputDim.size(), inputData);
|
||||
auto weightTensor =
|
||||
aclCreateTensor(weightDim.data(), weightDim.size(), ACL_FLOAT,
|
||||
weightStride.data(), 0, aclFormat::ACL_FORMAT_NCHW,
|
||||
weightDim.data(), weightDim.size(), weightData);
|
||||
auto outputTensor =
|
||||
aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT,
|
||||
outputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW,
|
||||
outputDim.data(), outputDim.size(), outputData);
|
||||
|
||||
auto *normArray =
|
||||
aclCreateIntArray(normalizedShape.data(), normalizedShape.size());
|
||||
|
||||
aclTensor *biasTensor = NULL;
|
||||
if (op->numInputs() == 3) {
|
||||
void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>());
|
||||
|
||||
auto biasD = op->getInputs(2)->getDims();
|
||||
auto biasS = op->getInputs(2)->getStride();
|
||||
std::vector<int64_t> biasDim = castTo64(biasD);
|
||||
std::vector<int64_t> biasStride = castTo64(biasS);
|
||||
|
||||
biasTensor = aclCreateTensor(
|
||||
biasDim.data(), biasDim.size(), ACL_FLOAT, biasStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_NCHW, biasDim.data(), biasDim.size(),
|
||||
biasData);
|
||||
}
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
|
||||
auto ret = aclnnLayerNormGetWorkspaceSize(
|
||||
inputTensor, normArray, weightTensor, biasTensor, eps, outputTensor,
|
||||
NULL, NULL, &workspaceSize, &executor);
|
||||
|
||||
CHECK_RET(
|
||||
ret == ACL_SUCCESS,
|
||||
LOG_PRINT("aclnnLayerNormGetWorkspaceSize failed. ERROR: %d\n",
|
||||
ret));
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
// auto tmp_err_msg = aclGetRecentErrMsg();
|
||||
// if (tmp_err_msg != NULL) {
|
||||
// printf(" ERROR Message : %s \n ", tmp_err_msg);
|
||||
// }
|
||||
ret = aclnnLayerNorm(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
CHECK_RET(ret == ACL_SUCCESS,
|
||||
LOG_PRINT("aclnnLayerNorm failed. ERROR: %d\n", ret));
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::InstanceNormalization,
|
||||
InstanceNormAclnn, "InstanceNorm_ASCEND");
|
||||
|
||||
}; // namespace infini
|
|
@ -0,0 +1,101 @@
|
|||
#include "operators/layer_norm.h"
|
||||
#include "aclnnop/level2/aclnn_layer_norm.h"
|
||||
#include "ascend/ascend_kernel_without_config.h"
|
||||
#include "ascend/ascend_runtime.h"
|
||||
#include "operators/gather.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class LayerNormAclnn : public ASCENDKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<LayerNormObj>(_op);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const weightData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto inputD = op->getInputs(0)->getDims();
|
||||
auto inputS = op->getInputs(0)->getStride();
|
||||
auto weightD = op->getInputs(1)->getDims();
|
||||
auto weightS = op->getInputs(1)->getStride();
|
||||
auto outD = op->getOutput()->getDims();
|
||||
auto outS = op->getOutput()->getStride();
|
||||
|
||||
double eps = static_cast<double>(op->getEps());
|
||||
|
||||
std::vector<int64_t> inputDim = castTo64(inputD);
|
||||
std::vector<int64_t> inputStride = castTo64(inputS);
|
||||
std::vector<int64_t> weightDim = castTo64(weightD);
|
||||
std::vector<int64_t> weightStride = castTo64(weightS);
|
||||
std::vector<int64_t> outputDim = castTo64(outD);
|
||||
std::vector<int64_t> outputStride = castTo64(outS);
|
||||
|
||||
auto axis = op->getAxis();
|
||||
auto rank = static_cast<int>(inputDim.size());
|
||||
std::vector<int64_t> normalizedShape(rank - axis, 0);
|
||||
for (auto i = rank; i > axis; --i) {
|
||||
normalizedShape[i - 1 - axis] = inputDim[i - 1];
|
||||
}
|
||||
|
||||
auto inputTensor =
|
||||
aclCreateTensor(inputDim.data(), inputDim.size(), ACL_FLOAT,
|
||||
inputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW,
|
||||
inputDim.data(), inputDim.size(), inputData);
|
||||
auto weightTensor =
|
||||
aclCreateTensor(weightDim.data(), weightDim.size(), ACL_FLOAT,
|
||||
weightStride.data(), 0, aclFormat::ACL_FORMAT_NCHW,
|
||||
weightDim.data(), weightDim.size(), weightData);
|
||||
auto outputTensor =
|
||||
aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT,
|
||||
outputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW,
|
||||
outputDim.data(), outputDim.size(), outputData);
|
||||
|
||||
auto *normArray =
|
||||
aclCreateIntArray(normalizedShape.data(), normalizedShape.size());
|
||||
|
||||
aclTensor *biasTensor = NULL;
|
||||
if (op->numInputs() == 3) {
|
||||
void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>());
|
||||
|
||||
auto biasD = op->getInputs(2)->getDims();
|
||||
auto biasS = op->getInputs(2)->getStride();
|
||||
std::vector<int64_t> biasDim = castTo64(biasD);
|
||||
std::vector<int64_t> biasStride = castTo64(biasS);
|
||||
|
||||
biasTensor = aclCreateTensor(
|
||||
biasDim.data(), biasDim.size(), ACL_FLOAT, biasStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_NCHW, biasDim.data(), biasDim.size(),
|
||||
biasData);
|
||||
}
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
|
||||
auto ret = aclnnLayerNormGetWorkspaceSize(
|
||||
inputTensor, normArray, weightTensor, biasTensor, eps, outputTensor,
|
||||
NULL, NULL, &workspaceSize, &executor);
|
||||
|
||||
CHECK_RET(
|
||||
ret == ACL_SUCCESS,
|
||||
LOG_PRINT("aclnnLayerNormGetWorkspaceSize failed. ERROR: %d\n",
|
||||
ret));
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
|
||||
ret = aclnnLayerNorm(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
CHECK_RET(ret == ACL_SUCCESS,
|
||||
LOG_PRINT("aclnnLayerNorm failed. ERROR: %d\n", ret));
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::LayerNormalization, LayerNormAclnn,
|
||||
"LayerNorm_ASCEND");
|
||||
|
||||
}; // namespace infini
|
|
@ -0,0 +1,126 @@
|
|||
#include "operators/matmul.h"
|
||||
#include "aclnnop/level2/aclnn_gemm.h"
|
||||
#include "aclnnop/level2/aclnn_matmul.h"
|
||||
#include "ascend/ascend_kernel_without_config.h"
|
||||
#include "ascend/ascend_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class MatmulAclnn : public ASCENDKernelWithoutConfig {
|
||||
// unsupport trans for "gemm" whithou biasInput
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<MatmulObj>(_op);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
auto input_num = op->numInputs();
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
void *biasData = NULL;
|
||||
if (input_num > 2) {
|
||||
biasData = (op->getInputs(2)->getRawDataPtr<void *>());
|
||||
}
|
||||
|
||||
auto selfD = op->getInputs(0)->getDims();
|
||||
auto selfS = op->getInputs(0)->getStride();
|
||||
auto matD = op->getInputs(1)->getDims();
|
||||
auto matS = op->getInputs(1)->getStride();
|
||||
auto outD = op->getOutput()->getDims();
|
||||
auto outS = op->getOutput()->getStride();
|
||||
std::vector<int> biasD;
|
||||
std::vector<int> biasS;
|
||||
if (input_num > 2) {
|
||||
biasD = op->getInputs(2)->getDims();
|
||||
biasS = op->getInputs(2)->getStride();
|
||||
}
|
||||
|
||||
std::vector<int64_t> selfDim = castTo64(selfD);
|
||||
std::vector<int64_t> selfStride = castTo64(selfS);
|
||||
std::vector<int64_t> matDim = castTo64(matD);
|
||||
std::vector<int64_t> matStride = castTo64(matS);
|
||||
std::vector<int64_t> outputDim = castTo64(outD);
|
||||
std::vector<int64_t> outputStride = castTo64(outS);
|
||||
std::vector<int64_t> biasDim;
|
||||
std::vector<int64_t> biasStride;
|
||||
if (input_num > 2) {
|
||||
biasDim = castTo64(biasD);
|
||||
biasStride = castTo64(biasS);
|
||||
}
|
||||
|
||||
auto selfTensor = aclCreateTensor(
|
||||
selfDim.data(), selfDim.size(), ACL_FLOAT, selfStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, selfDim.data(), selfDim.size(), aData);
|
||||
auto matTensor = aclCreateTensor(
|
||||
matDim.data(), matDim.size(), ACL_FLOAT, matStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, matDim.data(), matDim.size(), bData);
|
||||
auto outputTensor =
|
||||
aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT,
|
||||
outputStride.data(), 0, aclFormat::ACL_FORMAT_ND,
|
||||
outputDim.data(), outputDim.size(), cData);
|
||||
aclTensor *biasTensor = NULL;
|
||||
if (input_num > 2) {
|
||||
biasTensor =
|
||||
aclCreateTensor(biasDim.data(), biasDim.size(), ACL_FLOAT,
|
||||
biasStride.data(), 0, aclFormat::ACL_FORMAT_ND,
|
||||
biasDim.data(), biasDim.size(), biasData);
|
||||
}
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
|
||||
if (input_num > 2) {
|
||||
float alpha = 1.0;
|
||||
float beta = 1.0;
|
||||
int32_t transA = op->getTransA();
|
||||
int32_t transB = op->getTransB();
|
||||
|
||||
auto ret = aclnnGemmGetWorkspaceSize(
|
||||
selfTensor, matTensor, biasTensor, alpha, beta, int64_t(transA),
|
||||
int64_t(transB), outputTensor, 1, &workspaceSize, &executor);
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
// auto tmp_err_msg = aclGetRecentErrMsg();
|
||||
// if (tmp_err_msg != NULL) {
|
||||
// printf(" ERROR Message : %s \n ", tmp_err_msg);
|
||||
// }
|
||||
CHECK_RET(ret == ACL_SUCCESS,
|
||||
LOG_PRINT("aclnnGemmGetWorkspaceSize failed. ERROR: %d\n",
|
||||
ret));
|
||||
ret = aclnnGemm(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
CHECK_RET(ret == ACL_SUCCESS,
|
||||
LOG_PRINT("aclnnGemm failed. ERROR: %d\n", ret));
|
||||
} else {
|
||||
auto ret =
|
||||
aclnnMatmulGetWorkspaceSize(selfTensor, matTensor, outputTensor,
|
||||
1, &workspaceSize, &executor);
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
CHECK_RET(
|
||||
ret == ACL_SUCCESS,
|
||||
LOG_PRINT("aclnnMatmulGetWorkspaceSize failed. ERROR: %d\n",
|
||||
ret));
|
||||
|
||||
ret = aclnnMatmul(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
CHECK_RET(ret == ACL_SUCCESS,
|
||||
LOG_PRINT("aclnnMatmul failed. ERROR: %d\n", ret));
|
||||
}
|
||||
|
||||
// aclDestroyTensor(selfTensor);
|
||||
// aclDestroyTensor(matTensor);
|
||||
// aclDestroyTensor(outputTensor);
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::MatMul, MatmulAclnn,
|
||||
"matmul_ASCEND_float");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,96 @@
|
|||
#include "aclnnop/level2/aclnn_reflection_pad2d.h"
|
||||
#include "ascend/ascend_kernel_without_config.h"
|
||||
#include "ascend/ascend_runtime.h"
|
||||
#include "operators/pad.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class PadAclnn : public ASCENDKernelWithoutConfig {
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<PadObj>(_op);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto inputD = op->getInputs(0)->getDims();
|
||||
auto inputS = op->getInputs(0)->getStride();
|
||||
|
||||
auto outD = op->getOutput()->getDims();
|
||||
auto outS = op->getOutput()->getStride();
|
||||
|
||||
std::vector<int64_t> inputDim = castTo64(inputD);
|
||||
std::vector<int64_t> inputStride = castTo64(inputS);
|
||||
|
||||
std::vector<int64_t> outputDim = castTo64(outD);
|
||||
std::vector<int64_t> outputStride = castTo64(outS);
|
||||
|
||||
auto inputTensor =
|
||||
aclCreateTensor(inputDim.data(), inputDim.size(), ACL_FLOAT,
|
||||
inputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW,
|
||||
inputDim.data(), inputDim.size(), aData);
|
||||
|
||||
auto outputTensor =
|
||||
aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT,
|
||||
outputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW,
|
||||
outputDim.data(), outputDim.size(), cData);
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
|
||||
std::vector<int> intPads = op->getPads();
|
||||
|
||||
std::size_t length = intPads.size();
|
||||
std::vector<int64_t> pads(4);
|
||||
if (length == 8) {
|
||||
std::size_t halfLen = intPads.size() / 2;
|
||||
bool condition = true;
|
||||
// std::cout << "Length of intPads: " << length << std::endl;
|
||||
|
||||
for (std::size_t i = 0; i < halfLen; ++i) {
|
||||
condition = (intPads[i] == intPads[i + 4]);
|
||||
|
||||
// std::cout << "intPads[" << i << "]: " << intPads[i] <<
|
||||
// std::endl;
|
||||
}
|
||||
assert(condition);
|
||||
|
||||
pads[0] = intPads[2];
|
||||
pads[1] = intPads[3];
|
||||
pads[2] = intPads[6];
|
||||
pads[3] = intPads[7];
|
||||
} else if (length == 4) {
|
||||
for (std::size_t i = 0; i < 4; ++i) {
|
||||
|
||||
pads[i] = intPads[i];
|
||||
}
|
||||
}
|
||||
|
||||
aclIntArray *padding = aclCreateIntArray(pads.data(), 4);
|
||||
auto ret = aclnnReflectionPad2dGetWorkspaceSize(
|
||||
inputTensor, padding, outputTensor, &workspaceSize, &executor);
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
// auto tmp_err_msg = aclGetRecentErrMsg();
|
||||
// if (tmp_err_msg != NULL) {
|
||||
// printf(" ERROR Message : %s \n ", tmp_err_msg);
|
||||
// }
|
||||
assert(ret == ACL_SUCCESS);
|
||||
ret = aclnnReflectionPad2d(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
assert(ret == ACL_SUCCESS);
|
||||
|
||||
// aclDestroyTensor(inputTensor);
|
||||
// aclDestroyTensor(weightTensor);
|
||||
// aclDestroyTensor(outputTensor);
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Pad, PadAclnn, "pad_ASCEND_float");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,146 @@
|
|||
#include "operators/pooling.h"
|
||||
#include "aclnnop/level2/aclnn_avgpool2d.h"
|
||||
#include "aclnnop/level2/aclnn_max_pool.h"
|
||||
#include "ascend/ascend_kernel_without_config.h"
|
||||
#include "ascend/ascend_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class AvgPooling : public ASCENDKernelWithoutConfig {
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<PoolingObj>(_op);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto [n, c, h, w, kh, kw] = op->getNCHWRS();
|
||||
auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||
|
||||
std::vector<int64_t> ksize = {kh, kw};
|
||||
std::vector<int64_t> stride = {sh, sw};
|
||||
std::vector<int64_t> pad = {ph, pw};
|
||||
|
||||
int64_t divisorOverride = 0;
|
||||
|
||||
auto selfD = op->getInputs(0)->getDims();
|
||||
auto selfS = op->getInputs(0)->getStride();
|
||||
auto outD = op->getOutput()->getDims();
|
||||
auto outS = op->getOutput()->getStride();
|
||||
|
||||
std::vector<int64_t> selfDim = castTo64(selfD);
|
||||
std::vector<int64_t> selfStride = castTo64(selfS);
|
||||
std::vector<int64_t> outputDim = castTo64(outD);
|
||||
std::vector<int64_t> outputStride = castTo64(outS);
|
||||
|
||||
aclIntArray *kernelSize = aclCreateIntArray(ksize.data(), ksize.size());
|
||||
aclIntArray *strides = aclCreateIntArray(stride.data(), stride.size());
|
||||
aclIntArray *paddings = aclCreateIntArray(pad.data(), pad.size());
|
||||
|
||||
auto selfTensor = aclCreateTensor(
|
||||
selfDim.data(), selfDim.size(), ACL_FLOAT, selfStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_NCHW, selfDim.data(), selfDim.size(), aData);
|
||||
auto outputTensor =
|
||||
aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT,
|
||||
outputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW,
|
||||
outputDim.data(), outputDim.size(), cData);
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
|
||||
auto ret = aclnnAvgPool2dGetWorkspaceSize(
|
||||
selfTensor, kernelSize, strides, paddings, false, true,
|
||||
divisorOverride, int8_t(0), outputTensor, &workspaceSize,
|
||||
&executor);
|
||||
CHECK_RET(
|
||||
ret == ACL_SUCCESS,
|
||||
LOG_PRINT("aclnnAvgPool2dGetWorkspaceSize failed. ERROR: %d\n",
|
||||
ret));
|
||||
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
|
||||
ret = aclnnAvgPool2d(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
CHECK_RET(ret == ACL_SUCCESS,
|
||||
LOG_PRINT("aclnnAvgPool2d failed. ERROR: %d\n", ret));
|
||||
|
||||
// aclDestroyTensor(selfTensor);
|
||||
// aclDestroyTensor(outputTensor);
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class MaxPooling : public ASCENDKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<PoolingObj>(_op);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto [n, c, h, w, kh, kw] = op->getNCHWRS();
|
||||
auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||
int64_t ceilMode = int64_t(op->getCeilMode());
|
||||
|
||||
std::vector<int64_t> ksize = {kh, kw};
|
||||
std::vector<int64_t> stride = {sh, sw};
|
||||
std::vector<int64_t> pad = {ph, pw};
|
||||
std::vector<int64_t> dilation = {dh, dw};
|
||||
|
||||
auto selfD = op->getInputs(0)->getDims();
|
||||
auto selfS = op->getInputs(0)->getStride();
|
||||
auto outD = op->getOutput()->getDims();
|
||||
auto outS = op->getOutput()->getStride();
|
||||
|
||||
std::vector<int64_t> selfDim = castTo64(selfD);
|
||||
std::vector<int64_t> selfStride = castTo64(selfS);
|
||||
std::vector<int64_t> outputDim = castTo64(outD);
|
||||
std::vector<int64_t> outputStride = castTo64(outS);
|
||||
|
||||
aclIntArray *kernelSize = aclCreateIntArray(ksize.data(), ksize.size());
|
||||
aclIntArray *strides = aclCreateIntArray(stride.data(), stride.size());
|
||||
aclIntArray *paddings = aclCreateIntArray(pad.data(), pad.size());
|
||||
aclIntArray *dilations =
|
||||
aclCreateIntArray(dilation.data(), dilation.size());
|
||||
|
||||
auto selfTensor = aclCreateTensor(
|
||||
selfDim.data(), selfDim.size(), ACL_FLOAT, selfStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_NCHW, selfDim.data(), selfDim.size(), aData);
|
||||
auto outputTensor =
|
||||
aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT,
|
||||
outputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW,
|
||||
outputDim.data(), outputDim.size(), cData);
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
auto ret = aclnnMaxPoolGetWorkspaceSize(
|
||||
selfTensor, kernelSize, strides, 0, paddings, dilations, ceilMode,
|
||||
outputTensor, &workspaceSize, &executor);
|
||||
assert(ret == ACL_SUCCESS);
|
||||
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
|
||||
ret = aclnnMaxPool(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
assert(ret == ACL_SUCCESS);
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::MaxPool, MaxPooling,
|
||||
"maxpooling_ASCEND_float");
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::AveragePool, AvgPooling,
|
||||
"avgpooling_ASCEND_float");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,51 @@
|
|||
#ifdef INFINI_USE_HCCL
|
||||
#include "operators/recv.h"
|
||||
#include "ascend/ascend_kernel_without_config.h"
|
||||
#include "ascend/ascend_runtime.h"
|
||||
#include "ascend/hccl_communicator.h"
|
||||
#include "hccl/hccl.h"
|
||||
#include "hccl/hccl_types.h"
|
||||
|
||||
namespace infini {
|
||||
class RecvHCCL : public ASCENDKernelWithoutConfig {
|
||||
public:
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<RecvObj>(_op);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
void *output = op->getOutput(0)->getRawDataPtr<void *>();
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
const auto shape = op->getShape();
|
||||
int nDims = shape.size();
|
||||
int outputCount = 1;
|
||||
for (int i = 0; i < nDims; i++) {
|
||||
outputCount *= shape[i];
|
||||
}
|
||||
|
||||
HcclComm comm =
|
||||
dynamic_cast<HcclCommunicatorObj &>(context->getCommunicator())
|
||||
.getHcclComm();
|
||||
// TODO: Using default stream 0 for now.
|
||||
uint32_t rank;
|
||||
|
||||
HCCLCHECK(HcclGetRankId(comm, &rank));
|
||||
|
||||
int source = op->getSourceRank();
|
||||
int destination = op->getDestinationRank();
|
||||
|
||||
// printf("###rank:%u,source:%d,outputCount:%d,destination:%d\n", rank,
|
||||
// source, outputCount, destination);
|
||||
if (int(rank) == destination) {
|
||||
HCCLCHECK(HcclRecv(output, uint64_t(outputCount),
|
||||
HCCL_DATA_TYPE_FP32, uint32_t(source), comm,
|
||||
context->ASCENDHandle()));
|
||||
}
|
||||
ACLCHECK(aclrtSynchronizeStream(context->ASCENDHandle()));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Recv, RecvHCCL, "Recv_HCCL_ASCEND");
|
||||
} // namespace infini
|
||||
|
||||
#endif
|
|
@ -0,0 +1,121 @@
|
|||
#include "operators/reduce.h"
|
||||
#include "aclnnop/aclnn_mean.h"
|
||||
#include "aclnnop/aclnn_reduce_sum.h"
|
||||
#include "ascend/ascend_kernel_without_config.h"
|
||||
#include "ascend/ascend_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class MeanAclnn : public ASCENDKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ReduceBaseObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto axes_set = op->getAxes();
|
||||
std::vector<int> axes;
|
||||
axes.assign(axes_set.begin(), axes_set.end());
|
||||
|
||||
bool KeepDim = op->getKeepDims();
|
||||
|
||||
auto a = op->getInputs(0)->getDims();
|
||||
auto aS = op->getInputs(0)->getStride();
|
||||
auto c = op->getOutput()->getDims();
|
||||
auto cS = op->getOutput()->getStride();
|
||||
|
||||
std::vector<int64_t> aDim = castTo64(a);
|
||||
std::vector<int64_t> aStride = castTo64(aS);
|
||||
std::vector<int64_t> cDim = castTo64(c);
|
||||
std::vector<int64_t> cStride = castTo64(cS);
|
||||
std::vector<int64_t> axes_64 = castTo64(axes);
|
||||
|
||||
auto inputA = aclCreateTensor(
|
||||
aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData);
|
||||
auto output = aclCreateTensor(
|
||||
cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData);
|
||||
aclIntArray *dim = aclCreateIntArray(axes_64.data(), axes_64.size());
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
|
||||
auto ret = aclnnMeanV2GetWorkspaceSize(
|
||||
inputA, dim, KeepDim, true, output, &workspaceSize, &executor);
|
||||
assert(ret == ACL_SUCCESS);
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
assert(ret == ACL_SUCCESS);
|
||||
ret = aclnnMeanV2(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
assert(ret == ACL_SUCCESS);
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class ReduceSumAclnn : public ASCENDKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ReduceBaseObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto axes_set = op->getAxes();
|
||||
std::vector<int> axes;
|
||||
axes.assign(axes_set.begin(), axes_set.end());
|
||||
|
||||
bool KeepDim = op->getKeepDims();
|
||||
|
||||
auto a = op->getInputs(0)->getDims();
|
||||
auto aS = op->getInputs(0)->getStride();
|
||||
auto c = op->getOutput()->getDims();
|
||||
auto cS = op->getOutput()->getStride();
|
||||
|
||||
std::vector<int64_t> aDim = castTo64(a);
|
||||
std::vector<int64_t> aStride = castTo64(aS);
|
||||
std::vector<int64_t> cDim = castTo64(c);
|
||||
std::vector<int64_t> cStride = castTo64(cS);
|
||||
std::vector<int64_t> axes_64 = castTo64(axes);
|
||||
|
||||
auto inputA = aclCreateTensor(
|
||||
aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData);
|
||||
auto output = aclCreateTensor(
|
||||
cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData);
|
||||
aclIntArray *dim = aclCreateIntArray(axes_64.data(), axes_64.size());
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
|
||||
auto ret = aclnnReduceSumGetWorkspaceSize(
|
||||
inputA, dim, KeepDim, ACL_FLOAT, output, &workspaceSize, &executor);
|
||||
assert(ret == ACL_SUCCESS);
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
assert(ret == ACL_SUCCESS);
|
||||
ret = aclnnReduceSum(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
assert(ret == ACL_SUCCESS);
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::ReduceMean, MeanAclnn,
|
||||
"reduceMean_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::ReduceSum, ReduceSumAclnn,
|
||||
"reduceSum_ASCEND_float");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,56 @@
|
|||
#include "operators/reshape.h"
|
||||
#include "aclnnop/level2/aclnn_copy.h"
|
||||
#include "ascend/ascend_kernel_without_config.h"
|
||||
#include "ascend/ascend_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class CopyAclnn : public ASCENDKernelWithoutConfig {
|
||||
void compute(const Operator &op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto aD = op->getInputs(0)->getDims();
|
||||
auto aS = op->getInputs(0)->getStride();
|
||||
|
||||
std::vector<int64_t> aDim = castTo64(aD);
|
||||
std::vector<int64_t> aStride = castTo64(aS);
|
||||
|
||||
auto srcTensor = aclCreateTensor(
|
||||
aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData);
|
||||
auto outputTensor = aclCreateTensor(
|
||||
aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), cData);
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
|
||||
auto ret = aclnnInplaceCopyGetWorkspaceSize(outputTensor, srcTensor,
|
||||
&workspaceSize, &executor);
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
assert(ret == ACL_SUCCESS);
|
||||
ret = aclnnInplaceCopy(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
assert(ret == ACL_SUCCESS);
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Reshape, CopyAclnn,
|
||||
"reshape_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Unsqueeze, CopyAclnn,
|
||||
"unsqueeze_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Squeeze, CopyAclnn,
|
||||
"squeeze_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Flatten, CopyAclnn,
|
||||
"Flatten_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Identity, CopyAclnn,
|
||||
"Identity_ASCEND_float");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,85 @@
|
|||
#include "operators/resize.h"
|
||||
#include "aclnnop/level2/aclnn_resize.h"
|
||||
#include "ascend/ascend_kernel_without_config.h"
|
||||
#include "ascend/ascend_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class ResizeAclnn : public ASCENDKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ResizeObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
int nDims = op->getInputs(0)->getRank();
|
||||
if (nDims > 4)
|
||||
IT_TODO_HALT();
|
||||
|
||||
vector<float> scalesData = op->getScales();
|
||||
|
||||
const char *mode;
|
||||
switch (op->getMode()) {
|
||||
case ResizeObj::ECoeffMode::nearest:
|
||||
mode = "nearest";
|
||||
break;
|
||||
case ResizeObj::ECoeffMode::linear:
|
||||
mode = "bilinear";
|
||||
break;
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto a = op->getInputs(0)->getDims();
|
||||
auto aS = op->getInputs(0)->getStride();
|
||||
auto c = op->getOutput()->getDims();
|
||||
auto cS = op->getOutput()->getStride();
|
||||
|
||||
std::vector<int64_t> aDim = castTo64(a);
|
||||
std::vector<int64_t> aStride = castTo64(aS);
|
||||
std::vector<int64_t> cDim = castTo64(c);
|
||||
std::vector<int64_t> cStride = castTo64(cS);
|
||||
auto self = aclCreateTensor(
|
||||
aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_NCHW, aDim.data(), aDim.size(), aData);
|
||||
|
||||
auto output = aclCreateTensor(
|
||||
cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_NCHW, cDim.data(), cDim.size(), cData);
|
||||
|
||||
aclFloatArray *scales = nullptr;
|
||||
scales = aclCreateFloatArray(scalesData.data(), scalesData.size());
|
||||
CHECK_RET(scales != nullptr,
|
||||
LOG_PRINT("aclCreateFloatArray failed.\n"));
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
|
||||
auto ret = aclnnResizeGetWorkspaceSize(self, scales, mode, output,
|
||||
&workspaceSize, &executor);
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
CHECK_RET(
|
||||
ret == ACL_SUCCESS,
|
||||
LOG_PRINT("aclnnResizeGetWorkspaceSize failed. ERROR: %d\n", ret));
|
||||
|
||||
ret = aclnnResize(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
CHECK_RET(ret == ACL_SUCCESS,
|
||||
LOG_PRINT("aclnnResize failed. ERROR: %d\n", ret));
|
||||
// auto tmp_err_msg = aclGetRecentErrMsg();
|
||||
// if (tmp_err_msg != NULL) {
|
||||
// printf(" ERROR Message : %s \n ", tmp_err_msg);
|
||||
// }
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Resize, ResizeAclnn, "Resize_ASCEND");
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,47 @@
|
|||
#ifdef INFINI_USE_HCCL
|
||||
#include "operators/send.h"
|
||||
#include "ascend/ascend_kernel_without_config.h"
|
||||
#include "ascend/ascend_runtime.h"
|
||||
#include "ascend/hccl_communicator.h"
|
||||
#include "hccl/hccl.h"
|
||||
#include "hccl/hccl_types.h"
|
||||
|
||||
namespace infini {
|
||||
class SendHCCL : public ASCENDKernelWithoutConfig {
|
||||
public:
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<SendObj>(_op);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
void *input = op->getInputs(0)->getRawDataPtr<void *>();
|
||||
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
int inputCount =
|
||||
op->getInputs(0)->getBytes() / op->getDType().getSize();
|
||||
|
||||
HcclComm comm =
|
||||
dynamic_cast<HcclCommunicatorObj &>(context->getCommunicator())
|
||||
.getHcclComm();
|
||||
// TODO: Using default stream 0 for now.
|
||||
uint32_t rank;
|
||||
|
||||
HCCLCHECK(HcclGetRankId(comm, &rank));
|
||||
|
||||
int source = op->getSourceRank();
|
||||
int destination = op->getDestinationRank();
|
||||
|
||||
// printf("***rank:%u,source:%d,inputCount:%d,destination:%d\n", rank,
|
||||
// source, inputCount, destination);
|
||||
if (int(rank) == source) {
|
||||
HCCLCHECK(HcclSend(input, uint64_t(inputCount), HCCL_DATA_TYPE_FP32,
|
||||
uint32_t(destination), comm,
|
||||
context->ASCENDHandle()));
|
||||
}
|
||||
ACLCHECK(aclrtSynchronizeStream(context->ASCENDHandle()));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Send, SendHCCL, "Send_HCCL_ASCEND");
|
||||
} // namespace infini
|
||||
|
||||
#endif
|
|
@ -0,0 +1,74 @@
|
|||
#include "operators/slice.h"
|
||||
#include "aclnnop/aclnn_slice_v2.h"
|
||||
#include "ascend/ascend_kernel_without_config.h"
|
||||
#include "ascend/ascend_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class SliceAclnn : public ASCENDKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<SliceObj>(_op);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto starts_32 = op->getStarts();
|
||||
auto ends_32 = op->getEnds();
|
||||
auto steps_32 = op->getSteps();
|
||||
|
||||
auto a = op->getInputs(0)->getDims();
|
||||
auto aS = op->getInputs(0)->getStride();
|
||||
auto c = op->getOutput()->getDims();
|
||||
auto cS = op->getOutput()->getStride();
|
||||
|
||||
std::vector<int64_t> aDim = castTo64(a);
|
||||
std::vector<int64_t> aStride = castTo64(aS);
|
||||
std::vector<int64_t> cDim = castTo64(c);
|
||||
std::vector<int64_t> cStride = castTo64(cS);
|
||||
|
||||
std::vector<int64_t> starts_64 = castTo64(starts_32);
|
||||
std::vector<int64_t> ends_64 = castTo64(ends_32);
|
||||
std::vector<int64_t> steps_64 = castTo64(steps_32);
|
||||
|
||||
vector<int64_t> axes_64 = vector<int64_t>(starts_32.size(), 0);
|
||||
for (int i = 0; i < int(starts_32.size()); i++) {
|
||||
axes_64[i] = i;
|
||||
}
|
||||
|
||||
auto inputA = aclCreateTensor(
|
||||
aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData);
|
||||
auto output = aclCreateTensor(
|
||||
cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData);
|
||||
aclIntArray *starts =
|
||||
aclCreateIntArray(starts_64.data(), starts_64.size());
|
||||
aclIntArray *ends = aclCreateIntArray(ends_64.data(), ends_64.size());
|
||||
aclIntArray *steps =
|
||||
aclCreateIntArray(steps_64.data(), steps_64.size());
|
||||
aclIntArray *axes = aclCreateIntArray(axes_64.data(), axes_64.size());
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
|
||||
auto ret =
|
||||
aclnnSliceV2GetWorkspaceSize(inputA, starts, ends, axes, steps,
|
||||
output, &workspaceSize, &executor);
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
assert(ret == ACL_SUCCESS);
|
||||
ret = aclnnSliceV2(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
assert(ret == ACL_SUCCESS);
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Slice, SliceAclnn,
|
||||
"slice_ASCEND_float");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,58 @@
|
|||
#include "operators/softmax.h"
|
||||
#include "aclnnop/level2/aclnn_softmax.h"
|
||||
#include "ascend/ascend_kernel_without_config.h"
|
||||
#include "ascend/ascend_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class SoftmaxAclnn : public ASCENDKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<SoftmaxObj>(_op);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
int64_t axis = int64_t(op->getAxis());
|
||||
|
||||
auto a = op->getInputs(0)->getDims();
|
||||
auto aS = op->getInputs(0)->getStride();
|
||||
auto c = op->getOutput()->getDims();
|
||||
auto cS = op->getOutput()->getStride();
|
||||
|
||||
std::vector<int64_t> aDim = castTo64(a);
|
||||
std::vector<int64_t> aStride = castTo64(aS);
|
||||
std::vector<int64_t> cDim = castTo64(c);
|
||||
std::vector<int64_t> cStride = castTo64(cS);
|
||||
|
||||
auto input = aclCreateTensor(
|
||||
aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData);
|
||||
auto output = aclCreateTensor(
|
||||
cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData);
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
|
||||
auto ret = aclnnSoftmaxGetWorkspaceSize(input, axis, output,
|
||||
&workspaceSize, &executor);
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
assert(ret == ACL_SUCCESS);
|
||||
ret = aclnnSoftmax(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
assert(ret == ACL_SUCCESS);
|
||||
|
||||
// aclDestroyTensor(input);
|
||||
// aclDestroyTensor(output);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Softmax, SoftmaxAclnn,
|
||||
"softmax_ASCEND_float");
|
||||
|
||||
}; // namespace infini
|
|
@ -0,0 +1,69 @@
|
|||
#include "operators/split.h"
|
||||
#include "aclnnop/aclnn_split_tensor.h"
|
||||
#include "ascend/ascend_kernel_without_config.h"
|
||||
#include "ascend/ascend_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class SplitAclnn : public ASCENDKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<SplitObj>(_op);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
||||
auto a = op->getInputs(0)->getDims();
|
||||
auto aS = op->getInputs(0)->getStride();
|
||||
std::vector<int64_t> aDim = castTo64(a);
|
||||
std::vector<int64_t> aStride = castTo64(aS);
|
||||
|
||||
int64_t dim = op->getDim();
|
||||
int num = op->numOutputs();
|
||||
int dimSize = a.at(op->getDim());
|
||||
uint64_t splitSections = dimSize / num;
|
||||
|
||||
auto inputA = aclCreateTensor(
|
||||
aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData);
|
||||
|
||||
std::vector<aclTensor *> outputsData{};
|
||||
for (int i = 0; i < num; ++i) {
|
||||
auto c = op->getOutput(i)->getDims();
|
||||
auto cS = op->getOutput(i)->getStride();
|
||||
|
||||
std::vector<int64_t> cDim = castTo64(c);
|
||||
std::vector<int64_t> cStride = castTo64(cS);
|
||||
|
||||
void *const cData = (op->getOutput(i)->getRawDataPtr<void *>());
|
||||
|
||||
aclTensor *tmpTensor = aclCreateTensor(
|
||||
cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData);
|
||||
|
||||
outputsData.push_back(tmpTensor);
|
||||
}
|
||||
aclTensorList *tensorList =
|
||||
aclCreateTensorList(outputsData.data(), outputsData.size());
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
|
||||
auto ret = aclnnSplitTensorGetWorkspaceSize(
|
||||
inputA, splitSections, dim, tensorList, &workspaceSize, &executor);
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
assert(ret == ACL_SUCCESS);
|
||||
ret = aclnnSplitTensor(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
assert(ret == ACL_SUCCESS);
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Split, SplitAclnn,
|
||||
"split_ASCEND_float");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,116 @@
|
|||
#include "operators/transpose.h"
|
||||
#include "aclnnop/level2/aclnn_permute.h"
|
||||
#include "ascend/ascend_kernel_without_config.h"
|
||||
#include "ascend/ascend_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class PermuteAclnn : public ASCENDKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<TransposeObj>(_op);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto a = op->getInputs(0)->getDims();
|
||||
auto aS = op->getInputs(0)->getStride();
|
||||
auto c = op->getOutput()->getDims();
|
||||
auto cS = op->getOutput()->getStride();
|
||||
|
||||
std::vector<int64_t> aDim = castTo64(a);
|
||||
std::vector<int64_t> aStride = castTo64(aS);
|
||||
std::vector<int64_t> cDim = castTo64(c);
|
||||
std::vector<int64_t> cStride = castTo64(cS);
|
||||
|
||||
auto _permute = op->getPermute();
|
||||
std::vector<int64_t> permute = castTo64(_permute);
|
||||
|
||||
auto inputA = aclCreateTensor(
|
||||
aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData);
|
||||
aclIntArray *dims = aclCreateIntArray(permute.data(), permute.size());
|
||||
auto output = aclCreateTensor(
|
||||
cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData);
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
|
||||
auto ret = aclnnPermuteGetWorkspaceSize(inputA, dims, output,
|
||||
&workspaceSize, &executor);
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
assert(ret == ACL_SUCCESS);
|
||||
ret = aclnnPermute(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
assert(ret == ACL_SUCCESS);
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class DepthToSpaceAclnn : public ASCENDKernelWithoutConfig {
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<DepthToSpaceObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto reshapeDim = op->getReshapeDim();
|
||||
auto reshapeStride = getStride(reshapeDim);
|
||||
auto transposeDim = op->getTransposeDim();
|
||||
auto transposeStride = getStride(transposeDim);
|
||||
|
||||
std::vector<int64_t> aDim = castTo64(reshapeDim);
|
||||
std::vector<int64_t> aStride = castTo64(reshapeStride);
|
||||
std::vector<int64_t> cDim = castTo64(transposeDim);
|
||||
std::vector<int64_t> cStride = castTo64(transposeStride);
|
||||
|
||||
auto mode = op->getMode();
|
||||
|
||||
std::vector<int64_t> permute;
|
||||
if (mode == 0) {
|
||||
permute = {0, 3, 4, 1, 5, 2};
|
||||
} else {
|
||||
permute = {0, 1, 4, 2, 5, 3};
|
||||
}
|
||||
|
||||
auto inputA = aclCreateTensor(
|
||||
aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData);
|
||||
aclIntArray *dims = aclCreateIntArray(permute.data(), permute.size());
|
||||
auto output = aclCreateTensor(
|
||||
cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData);
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
|
||||
auto ret = aclnnPermuteGetWorkspaceSize(inputA, dims, output,
|
||||
&workspaceSize, &executor);
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
assert(ret == ACL_SUCCESS);
|
||||
ret = aclnnPermute(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
assert(ret == ACL_SUCCESS);
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Transpose, PermuteAclnn,
|
||||
"transpose_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::DepthToSpace, DepthToSpaceAclnn,
|
||||
"DepthToSpace_ASCEND_float");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,246 @@
|
|||
#include "operators/unary.h"
|
||||
#include "aclnnop/aclnn_erf.h"
|
||||
#include "aclnnop/level2/aclnn_abs.h"
|
||||
#include "aclnnop/level2/aclnn_acos.h"
|
||||
#include "aclnnop/level2/aclnn_atan.h"
|
||||
#include "aclnnop/level2/aclnn_ceil.h"
|
||||
#include "aclnnop/level2/aclnn_cos.h"
|
||||
#include "aclnnop/level2/aclnn_exp.h"
|
||||
#include "aclnnop/level2/aclnn_floor.h"
|
||||
#include "aclnnop/level2/aclnn_gelu.h"
|
||||
#include "aclnnop/level2/aclnn_hardswish.h"
|
||||
#include "aclnnop/level2/aclnn_leaky_relu.h"
|
||||
#include "aclnnop/level2/aclnn_neg.h"
|
||||
#include "aclnnop/level2/aclnn_reciprocal.h"
|
||||
#include "aclnnop/level2/aclnn_relu.h"
|
||||
#include "aclnnop/level2/aclnn_round.h"
|
||||
#include "aclnnop/level2/aclnn_sigmoid.h"
|
||||
#include "aclnnop/level2/aclnn_sin.h"
|
||||
#include "aclnnop/level2/aclnn_sqrt.h"
|
||||
#include "aclnnop/level2/aclnn_tanh.h"
|
||||
#include "ascend/ascend_kernel_without_config.h"
|
||||
#include "ascend/ascend_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class ReluAclnn : public ASCENDKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto a = op->getInputs(0)->getDims();
|
||||
std::vector<int64_t> aDim(a.size(), 1);
|
||||
for (size_t i = 0; i < a.size(); ++i) {
|
||||
aDim[i] = int64_t(a[i]);
|
||||
}
|
||||
auto aS = op->getInputs(0)->getStride();
|
||||
std::vector<int64_t> aStride(aS.size(), 1);
|
||||
for (size_t i = 0; i < aS.size(); ++i) {
|
||||
aStride[i] = int64_t(aS[i]);
|
||||
}
|
||||
auto c = op->getInputs(0)->getDims();
|
||||
std::vector<int64_t> cDim(c.size(), 1);
|
||||
for (size_t i = 0; i < c.size(); ++i) {
|
||||
cDim[i] = int64_t(c[i]);
|
||||
}
|
||||
auto cS = op->getInputs(0)->getStride();
|
||||
std::vector<int64_t> cStride(cS.size(), 1);
|
||||
for (size_t i = 0; i < cS.size(); ++i) {
|
||||
cStride[i] = int64_t(cS[i]);
|
||||
}
|
||||
|
||||
auto input = aclCreateTensor(
|
||||
aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData);
|
||||
auto output = aclCreateTensor(
|
||||
cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData);
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
|
||||
auto ret =
|
||||
aclnnReluGetWorkspaceSize(input, output, &workspaceSize, &executor);
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
assert(ret == ACL_SUCCESS);
|
||||
ret = aclnnRelu(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
assert(ret == ACL_SUCCESS);
|
||||
|
||||
// aclDestroyTensor(input);
|
||||
// aclDestroyTensor(output);
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
class LeakyReluAclnn : public ASCENDKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<LeakyReluObj>(_op);
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto a = op->getInputs(0)->getDims();
|
||||
std::vector<int64_t> aDim(a.size(), 1);
|
||||
for (size_t i = 0; i < a.size(); ++i) {
|
||||
aDim[i] = int64_t(a[i]);
|
||||
}
|
||||
auto aS = op->getInputs(0)->getStride();
|
||||
std::vector<int64_t> aStride(aS.size(), 1);
|
||||
for (size_t i = 0; i < aS.size(); ++i) {
|
||||
aStride[i] = int64_t(aS[i]);
|
||||
}
|
||||
auto c = op->getInputs(0)->getDims();
|
||||
std::vector<int64_t> cDim(c.size(), 1);
|
||||
for (size_t i = 0; i < c.size(); ++i) {
|
||||
cDim[i] = int64_t(c[i]);
|
||||
}
|
||||
auto cS = op->getInputs(0)->getStride();
|
||||
std::vector<int64_t> cStride(cS.size(), 1);
|
||||
for (size_t i = 0; i < cS.size(); ++i) {
|
||||
cStride[i] = int64_t(cS[i]);
|
||||
}
|
||||
|
||||
auto input = aclCreateTensor(
|
||||
aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData);
|
||||
auto output = aclCreateTensor(
|
||||
cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData);
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
float negativeSlopeValue = op->getAlpha();
|
||||
aclScalar *negativeSlope = nullptr;
|
||||
negativeSlope =
|
||||
aclCreateScalar(&negativeSlopeValue, aclDataType::ACL_FLOAT);
|
||||
|
||||
auto ret = aclnnLeakyReluGetWorkspaceSize(input, negativeSlope, output,
|
||||
&workspaceSize, &executor);
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
assert(ret == ACL_SUCCESS);
|
||||
ret = aclnnLeakyRelu(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
assert(ret == ACL_SUCCESS);
|
||||
|
||||
// aclDestroyTensor(input);
|
||||
// aclDestroyTensor(output);
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
#define DEFINE_UNARY_Aclnn(prefix) \
|
||||
class prefix##Aclnn : public ASCENDKernelWithoutConfig { \
|
||||
void compute(const Operator &_op, \
|
||||
const RuntimeObj *_context) const override { \
|
||||
auto op = as<UnaryObj>(_op); \
|
||||
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context); \
|
||||
\
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); \
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>()); \
|
||||
\
|
||||
auto a = op->getInputs(0)->getDims(); \
|
||||
std::vector<int64_t> aDim(a.size(), 1); \
|
||||
for (size_t i = 0; i < a.size(); ++i) { \
|
||||
aDim[i] = int64_t(a[i]); \
|
||||
} \
|
||||
auto aS = op->getInputs(0)->getStride(); \
|
||||
std::vector<int64_t> aStride(aS.size(), 1); \
|
||||
for (size_t i = 0; i < aS.size(); ++i) { \
|
||||
aStride[i] = int64_t(aS[i]); \
|
||||
} \
|
||||
auto c = op->getInputs(0)->getDims(); \
|
||||
std::vector<int64_t> cDim(c.size(), 1); \
|
||||
for (size_t i = 0; i < c.size(); ++i) { \
|
||||
cDim[i] = int64_t(c[i]); \
|
||||
} \
|
||||
auto cS = op->getInputs(0)->getStride(); \
|
||||
std::vector<int64_t> cStride(cS.size(), 1); \
|
||||
for (size_t i = 0; i < cS.size(); ++i) { \
|
||||
cStride[i] = int64_t(cS[i]); \
|
||||
} \
|
||||
\
|
||||
auto input = aclCreateTensor( \
|
||||
aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0, \
|
||||
aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData); \
|
||||
auto output = aclCreateTensor( \
|
||||
cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0, \
|
||||
aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData); \
|
||||
\
|
||||
uint64_t workspaceSize = 0; \
|
||||
aclOpExecutor *executor; \
|
||||
\
|
||||
auto ret = aclnn##prefix##GetWorkspaceSize( \
|
||||
input, output, &workspaceSize, &executor); \
|
||||
void *workspaceAddr = nullptr; \
|
||||
if (workspaceSize > 0) { \
|
||||
workspaceAddr = context->getWorkspace(workspaceSize); \
|
||||
} \
|
||||
assert(ret == ACL_SUCCESS); \
|
||||
ret = aclnn##prefix(workspaceAddr, workspaceSize, executor, \
|
||||
context->ASCENDHandle()); \
|
||||
assert(ret == ACL_SUCCESS); \
|
||||
\
|
||||
return; \
|
||||
} \
|
||||
};
|
||||
|
||||
DEFINE_UNARY_Aclnn(Abs);
|
||||
DEFINE_UNARY_Aclnn(Sigmoid);
|
||||
DEFINE_UNARY_Aclnn(Hardswish);
|
||||
DEFINE_UNARY_Aclnn(Gelu);
|
||||
|
||||
DEFINE_UNARY_Aclnn(Tanh);
|
||||
DEFINE_UNARY_Aclnn(Sin);
|
||||
DEFINE_UNARY_Aclnn(Cos);
|
||||
DEFINE_UNARY_Aclnn(Acos);
|
||||
DEFINE_UNARY_Aclnn(Atan);
|
||||
|
||||
DEFINE_UNARY_Aclnn(Ceil);
|
||||
DEFINE_UNARY_Aclnn(Floor);
|
||||
DEFINE_UNARY_Aclnn(Exp);
|
||||
DEFINE_UNARY_Aclnn(Neg);
|
||||
DEFINE_UNARY_Aclnn(Reciprocal);
|
||||
DEFINE_UNARY_Aclnn(Sqrt);
|
||||
DEFINE_UNARY_Aclnn(Round);
|
||||
|
||||
DEFINE_UNARY_Aclnn(Erf);
|
||||
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Relu, ReluAclnn, "relu_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::LeakyRelu, LeakyReluAclnn,
|
||||
"leakyrelu_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Abs, AbsAclnn, "abs_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Sigmoid, SigmoidAclnn,
|
||||
"sigmoid_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::HardSwish, HardswishAclnn,
|
||||
"hardswish_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Tanh, TanhAclnn, "tanh_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Gelu, GeluAclnn, "gelu_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Sin, SinAclnn, "sin_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Cos, CosAclnn, "cos_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Acos, AcosAclnn, "acos_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Atan, AtanAclnn, "atan_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Neg, NegAclnn, "neg_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Ceil, CeilAclnn, "ceil_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Floor, FloorAclnn,
|
||||
"floor_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Exp, ExpAclnn, "exp_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Reciprocal, ReciprocalAclnn,
|
||||
"reciprocal_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Sqrt, SqrtAclnn, "sqrt_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Round, RoundAclnn,
|
||||
"round_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Erf, ErfAclnn, "erf_ASCEND_float");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,48 @@
|
|||
#include "operators/instance_norm.h"
|
||||
#include "utils/operator_utils.h"
|
||||
|
||||
namespace infini {
|
||||
InstanceNormObj::InstanceNormObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
Tensor scale, Tensor bias, float eps)
|
||||
: OperatorObj(OpType::InstanceNormalization, TensorVec{input, scale, bias},
|
||||
{output}),
|
||||
eps(eps) {
|
||||
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> InstanceNormObj::inferShape(const TensorVec &inputs) {
|
||||
return {{inputs[0]->getDims()}};
|
||||
}
|
||||
|
||||
vector<DataType> InstanceNormObj::inferDataType(const TensorVec &inputs) const {
|
||||
|
||||
return {inputs[0]->getDType()};
|
||||
}
|
||||
|
||||
std::string InstanceNormObj::toString() const {
|
||||
std::ostringstream os;
|
||||
os << "InstanceNormalization[" << getGuid() << "]";
|
||||
os << "(";
|
||||
os << vecToString(inputs[0]->getDims()) << ",";
|
||||
os << "eps=" << eps << ",";
|
||||
os << "input=" << inputs[0]->getGuid() << ",";
|
||||
os << "scale=" << inputs[1]->getGuid() << ",";
|
||||
os << "bias=" << inputs[2]->getGuid() << ",";
|
||||
os << "output=";
|
||||
for (auto output : outputs)
|
||||
os << output->getGuid() << ",";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
vector<int> InstanceNormObj::getWorkloadVector() const {
|
||||
vector<int> ret = inputs[0]->getDims();
|
||||
ret.emplace(ret.begin(), type.underlying());
|
||||
return ret;
|
||||
}
|
||||
|
||||
vector<int> InstanceNormObj::getOpAttrVector() const {
|
||||
return {type.underlying()};
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -101,8 +101,9 @@ void ResizeObj::InitBySizes(Tensor input, Tensor sizes,
|
|||
// copy sizes data to host.
|
||||
IT_ASSERT(sizes->getDataBlob() != nullptr);
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
std::shared_ptr<int> dataObj((int *)runtime->alloc(sizes->getBytes()),
|
||||
[&](int *p) { runtime->dealloc(p); });
|
||||
std::shared_ptr<int64_t> dataObj(
|
||||
(int64_t *)runtime->alloc(sizes->getBytes()),
|
||||
[&](int64_t *p) { runtime->dealloc(p); });
|
||||
auto data = dataObj.get();
|
||||
sizes->getRuntime()->copyBlobToCPU(
|
||||
(void *)data, sizes->getRawDataPtr<void *>(), sizes->getBytes());
|
||||
|
@ -193,7 +194,7 @@ vector<DataType> ResizeObj::inferDataType(const TensorVec &inputs) const {
|
|||
}
|
||||
if (isResizeBySizes()) {
|
||||
auto sizes = inputs[1];
|
||||
IT_ASSERT(sizes && sizes->getDType() == DataType::UInt32);
|
||||
IT_ASSERT(sizes && sizes->getDType() == DataType::Int64);
|
||||
} else {
|
||||
auto scales = inputs[1];
|
||||
IT_ASSERT(scales && scales->getDType() == DataType::Float32);
|
||||
|
@ -220,8 +221,7 @@ optional<vector<Shape>> ResizeObj::inferShape(const TensorVec &inputs) {
|
|||
|
||||
std::string ResizeObj::toString() const {
|
||||
std::ostringstream os;
|
||||
os << "Resize"
|
||||
<< "[" << getGuid() << "]";
|
||||
os << "Resize" << "[" << getGuid() << "]";
|
||||
os << "(";
|
||||
os << vecToString(inputs[0]->getDims()) << ",";
|
||||
if (inputs.size() == 3) {
|
||||
|
|
|
@ -283,6 +283,39 @@ vector<int> PReluObj::getWorkloadVector() const {
|
|||
|
||||
vector<int> PReluObj::getOpAttrVector() const { return {type.underlying()}; }
|
||||
|
||||
LeakyReluObj::LeakyReluObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
float _alpha)
|
||||
: OperatorObj(OpType::LeakyRelu, {input}, {output}), alpha(_alpha) {
|
||||
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
std::string LeakyReluObj::toString() const {
|
||||
std::ostringstream os;
|
||||
os << type.toString() << "[" << getGuid() << "]";
|
||||
os << "(";
|
||||
os << vecToString(inputs[0]->getDims()) << ",";
|
||||
os << "input=" << inputs[0]->getGuid() << ",";
|
||||
os << "output=" << outputs[0]->getGuid() << ",";
|
||||
os << "alpha=" << alpha << ")";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
optional<vector<Shape>> LeakyReluObj::inferShape(const TensorVec &inputs) {
|
||||
const auto A = inputs[0];
|
||||
return {{A->getDims()}};
|
||||
}
|
||||
|
||||
vector<int> LeakyReluObj::getWorkloadVector() const {
|
||||
vector<int> ret{type.underlying()};
|
||||
const Shape shape = outputs[0]->getDims();
|
||||
ret.insert(ret.end(), shape.begin(), shape.end());
|
||||
return ret;
|
||||
}
|
||||
|
||||
vector<int> LeakyReluObj::getOpAttrVector() const {
|
||||
return {type.underlying()};
|
||||
}
|
||||
LogObj::LogObj(GraphObj *graph, Tensor input, Tensor output, LogType type)
|
||||
: OperatorObj(OpType::Log, {input}, {output}), logType(type) {
|
||||
IT_ASSERT(checkValid(graph));
|
||||
|
|
|
@ -104,6 +104,8 @@ std::string device_to_str(Device device) {
|
|||
return "INTELCPU";
|
||||
case Device::KUNLUN:
|
||||
return "KUNLUN";
|
||||
case Device::ASCEND:
|
||||
return "ASCEND";
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
#ifdef INFINI_USE_HCCL
|
||||
#include "ascend/ascend_runtime.h"
|
||||
#include "ascend/hccl_communicator.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/all_gather.h"
|
||||
#include "test.h"
|
||||
#include <thread>
|
||||
|
||||
static int WORLD_SIZE = 2;
|
||||
|
||||
namespace infini {
|
||||
|
||||
void allGather(const string taskName, int deviceID, vector<float> data,
|
||||
vector<vector<float>> ans) {
|
||||
// Create Runtimes and initiate communication
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
Runtime ascnedRuntime = make_ref<ASCENDRuntimeObj>(deviceID);
|
||||
ascnedRuntime->initComm(taskName, WORLD_SIZE, deviceID);
|
||||
// Create Graph and insert allReduce operation
|
||||
Graph g = make_ref<GraphObj>(ascnedRuntime);
|
||||
auto input =
|
||||
g->addTensor(Shape{static_cast<int>(data.size())}, DataType::Float32);
|
||||
auto op = g->addOp<AllGatherObj>(input, std::nullopt, WORLD_SIZE);
|
||||
// Copy data from CPU to MLU
|
||||
g->dataMalloc();
|
||||
input->copyin(data);
|
||||
// Run operation
|
||||
ascnedRuntime->run(g);
|
||||
// Copy output from MLU to CPU
|
||||
for (int i = 0; i < WORLD_SIZE; ++i) {
|
||||
auto result = op->getOutputs()[i]->clone(cpuRuntime);
|
||||
EXPECT_TRUE(result->equalData(ans[i]));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ASCEND_AllGather, run) {
|
||||
aclInit(nullptr);
|
||||
|
||||
vector<float> data[4] = {{2., 3.}, {5., 6.}, {7., 8.}, {9., 10.}};
|
||||
vector<vector<float>> ans = {{2., 3.}, {5., 6.}, {7., 8.}, {9., 10.}};
|
||||
|
||||
std::vector<std::thread> threads;
|
||||
for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) {
|
||||
threads.emplace_back(allGather, "test_all_gather", gpu, data[gpu], ans);
|
||||
}
|
||||
for (auto &thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
|
||||
aclFinalize();
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
#endif
|
|
@ -0,0 +1,120 @@
|
|||
#ifdef INFINI_USE_HCCL
|
||||
#include "ascend/ascend_runtime.h"
|
||||
#include "ascend/hccl_communicator.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/all_reduce.h"
|
||||
#include "test.h"
|
||||
#include <thread>
|
||||
|
||||
static int WORLD_SIZE = 2;
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <typename OperatorObj>
|
||||
void allReduce(const string taskName, int deviceID, vector<float> data,
|
||||
vector<float> ans) {
|
||||
// Create Runtimes and initiate communication
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
Runtime ascendRuntime = make_ref<ASCENDRuntimeObj>(deviceID);
|
||||
ascendRuntime->initComm(taskName, WORLD_SIZE, deviceID);
|
||||
// Create Graph and insert allReduce operation
|
||||
Graph g = make_ref<GraphObj>(ascendRuntime);
|
||||
auto input =
|
||||
g->addTensor(Shape{static_cast<int>(data.size())}, DataType::Float32);
|
||||
auto op = g->addOp<OperatorObj>(input, nullptr);
|
||||
// Copy data from CPU to GPU
|
||||
g->dataMalloc();
|
||||
input->copyin(data);
|
||||
// Run operation
|
||||
ascendRuntime->run(g);
|
||||
// Copy output from GPU to CPU
|
||||
auto result = op->getOutput()->clone(cpuRuntime);
|
||||
|
||||
EXPECT_TRUE(result->equalData(ans));
|
||||
}
|
||||
|
||||
// TEST(ASCEND_AllReduce, sum) {
|
||||
// aclInit(nullptr);
|
||||
// vector<float> data[2] = {{2., 3.}, {5., 6.}};
|
||||
// vector<float> ans = {7., 9.};
|
||||
//
|
||||
// std::vector<std::thread> threads;
|
||||
// for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) {
|
||||
// threads.emplace_back(allReduce<AllReduceSumObj>,
|
||||
// "test_allreduce_sum",
|
||||
// gpu, data[gpu], ans);
|
||||
// }
|
||||
// for (auto &thread : threads) {
|
||||
// thread.join();
|
||||
// }
|
||||
// aclFinalize();
|
||||
// }
|
||||
|
||||
// TEST(ASCEND_AllReduce, prod) {
|
||||
// aclInit(nullptr);
|
||||
// vector<float> data[2] = {{2., 3.}, {5., 6.}};
|
||||
// vector<float> ans = {10., 18.};
|
||||
//
|
||||
// std::vector<std::thread> threads;
|
||||
// for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) {
|
||||
// threads.emplace_back(allReduce<AllReduceProdObj>,
|
||||
// "test_allreduce_prod",
|
||||
// gpu, data[gpu], ans);
|
||||
// }
|
||||
// for (auto &thread : threads) {
|
||||
// thread.join();
|
||||
// }
|
||||
// aclFinalize();
|
||||
// }
|
||||
|
||||
// TEST(ASCEND_AllReduce, min) {
|
||||
// aclInit(nullptr);
|
||||
// vector<float> data[2] = {{2., 3.}, {5., 6.}};
|
||||
// vector<float> ans = {2., 3.};
|
||||
//
|
||||
// std::vector<std::thread> threads;
|
||||
// for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) {
|
||||
// threads.emplace_back(allReduce<AllReduceMinObj>,
|
||||
// "test_allreduce_min",
|
||||
// gpu, data[gpu], ans);
|
||||
// }
|
||||
// for (auto &thread : threads) {
|
||||
// thread.join();
|
||||
// }
|
||||
// aclFinalize();
|
||||
// }
|
||||
|
||||
TEST(ASCEND_AllReduce, max) {
|
||||
aclInit(nullptr);
|
||||
vector<float> data[2] = {{2., 3.}, {5., 6.}};
|
||||
vector<float> ans = {5., 6.};
|
||||
|
||||
std::vector<std::thread> threads;
|
||||
for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) {
|
||||
threads.emplace_back(allReduce<AllReduceMaxObj>, "test_allreduce_max",
|
||||
gpu, data[gpu], ans);
|
||||
}
|
||||
for (auto &thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
aclFinalize();
|
||||
}
|
||||
|
||||
// TEST(ASCEND_AllReduce, avg) {
|
||||
// vector<float> data[2] = {{2., 3.}, {5., 6.}};
|
||||
// vector<float> ans = {3.5, 4.5};
|
||||
//
|
||||
// std::vector<std::thread> threads;
|
||||
// for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) {
|
||||
// threads.emplace_back(allReduce<AllReduceAvgObj>,
|
||||
// "test_allreduce_avg",
|
||||
// gpu, data[gpu], ans);
|
||||
// }
|
||||
// for (auto &thread : threads) {
|
||||
// thread.join();
|
||||
// }
|
||||
// }
|
||||
|
||||
} // namespace infini
|
||||
#endif
|
|
@ -0,0 +1,58 @@
|
|||
#include "ascend/ascend_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/batch_norm.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
TEST(ascend_BatchNorm, run) {
|
||||
aclInit(nullptr);
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto npuRuntime = make_ref<ASCENDRuntimeObj>();
|
||||
|
||||
// Build cpu graph
|
||||
Graph gCpu = make_ref<GraphObj>(cpuRuntime);
|
||||
auto iCpu = gCpu->addTensor(Shape{1, 3, 2, 2}, DataType::Float32);
|
||||
auto meanCpu = gCpu->addTensor(Shape{3}, DataType::Float32);
|
||||
auto varCpu = gCpu->addTensor(Shape{3}, DataType::Float32);
|
||||
auto scaleCpu = gCpu->addTensor(Shape{3}, DataType::Float32);
|
||||
auto biasCpu = gCpu->addTensor(Shape{3}, DataType::Float32);
|
||||
|
||||
// Build input data on CPU
|
||||
gCpu->dataMalloc();
|
||||
iCpu->setData(IncrementalGenerator());
|
||||
meanCpu->copyin(vector<float>{1, 6, 9});
|
||||
varCpu->copyin(vector<float>{4, 1, 9});
|
||||
scaleCpu->setData(OneGenerator());
|
||||
biasCpu->setData(ZeroGenerator());
|
||||
|
||||
// Build CUDA graph
|
||||
Graph g = make_ref<GraphObj>(npuRuntime);
|
||||
auto i = g->cloneTensor(iCpu);
|
||||
auto mean = g->cloneTensor(meanCpu);
|
||||
auto var = g->cloneTensor(varCpu);
|
||||
auto scale = g->cloneTensor(scaleCpu);
|
||||
auto bias = g->cloneTensor(biasCpu);
|
||||
auto op =
|
||||
g->addOp<BatchNormObj>(i, nullptr, mean, var, scale, bias, 0.9, 0);
|
||||
|
||||
// allocate CUDA memory
|
||||
g->dataMalloc();
|
||||
|
||||
// Execute on CUDA
|
||||
npuRuntime->run(g);
|
||||
|
||||
// clone CUDA output to CPU
|
||||
auto o = op->getOutput();
|
||||
auto ocpu = o->clone(cpuRuntime);
|
||||
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(ocpu->equalData(vector<float>{
|
||||
-0.5, 0, 0.5, 1, -2, -1, 0, 1, -0.333333, 0, 0.333333, 0.666667}));
|
||||
|
||||
aclFinalize();
|
||||
}
|
||||
} // namespace infini
|
|
@ -0,0 +1,65 @@
|
|||
#include "ascend/ascend_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/concat.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <class T>
|
||||
void testConcat(const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto npuRuntime = make_ref<ASCENDRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu1 =
|
||||
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
inputCpu1->dataMalloc();
|
||||
inputCpu1->setData(generator);
|
||||
Tensor inputCpu2 =
|
||||
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
inputCpu2->dataMalloc();
|
||||
inputCpu2->setData(generator);
|
||||
Tensor inputCpu3 =
|
||||
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
inputCpu3->dataMalloc();
|
||||
inputCpu3->setData(generator);
|
||||
|
||||
// NPU
|
||||
Graph npuGraph = make_ref<GraphObj>(npuRuntime);
|
||||
auto inputNpu1 = npuGraph->cloneTensor(inputCpu1);
|
||||
auto inputNpu2 = npuGraph->cloneTensor(inputCpu2);
|
||||
auto inputNpu3 = npuGraph->cloneTensor(inputCpu3);
|
||||
auto npuOp = npuGraph->addOp<T>(TensorVec{inputNpu1, inputNpu2, inputNpu3},
|
||||
nullptr, 2);
|
||||
npuGraph->dataMalloc();
|
||||
inputNpu1->setData(generator);
|
||||
inputNpu2->setData(generator);
|
||||
inputNpu3->setData(generator);
|
||||
npuRuntime->run(npuGraph);
|
||||
auto outputNpu = npuOp->getOutput();
|
||||
auto outputNpu2Cpu = outputNpu->clone(cpuRuntime);
|
||||
|
||||
// Check
|
||||
inputCpu1->print();
|
||||
inputCpu1->printData();
|
||||
inputCpu2->print();
|
||||
inputCpu2->printData();
|
||||
inputCpu3->print();
|
||||
inputCpu3->printData();
|
||||
outputNpu2Cpu->print();
|
||||
outputNpu2Cpu->printData();
|
||||
EXPECT_TRUE(1);
|
||||
}
|
||||
|
||||
TEST(ascend_Concat, run) {
|
||||
aclInit(nullptr);
|
||||
testConcat<ConcatObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
aclFinalize();
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,60 @@
|
|||
#include "ascend/ascend_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/conv.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <class T>
|
||||
void testConv(const std::function<void(void *, size_t, DataType)> &generatorA,
|
||||
const std::function<void(void *, size_t, DataType)> &generatorB,
|
||||
const Shape &shapeA, const Shape &shapeB) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto npuRuntime = make_ref<ASCENDRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu1 =
|
||||
make_ref<TensorObj>(shapeA, DataType::Float32, cpuRuntime);
|
||||
Tensor inputCpu2 =
|
||||
make_ref<TensorObj>(shapeB, DataType::Float32, cpuRuntime);
|
||||
// NPU
|
||||
Graph npuGraph = make_ref<GraphObj>(npuRuntime);
|
||||
auto inputNpu1 = npuGraph->cloneTensor(inputCpu1);
|
||||
auto inputNpu2 = npuGraph->cloneTensor(inputCpu2);
|
||||
auto npuOp =
|
||||
npuGraph->addOp<T>(inputNpu1, inputNpu2, nullptr, 1, 1, 1, 1, 1, 1);
|
||||
npuGraph->dataMalloc();
|
||||
inputNpu1->setData(generatorA);
|
||||
inputNpu2->setData(generatorB);
|
||||
npuRuntime->run(npuGraph);
|
||||
auto outputNpu = npuOp->getOutput();
|
||||
auto outputNpu2Cpu = outputNpu->clone(cpuRuntime);
|
||||
// CPU
|
||||
Graph cpuGraph = make_ref<GraphObj>(cpuRuntime);
|
||||
cpuGraph->addTensor(inputCpu1);
|
||||
cpuGraph->addTensor(inputCpu2);
|
||||
auto cpuOp =
|
||||
cpuGraph->addOp<T>(inputCpu1, inputCpu2, nullptr, 1, 1, 1, 1, 1, 1);
|
||||
cpuGraph->dataMalloc();
|
||||
inputCpu1->setData(generatorA);
|
||||
inputCpu2->setData(generatorB);
|
||||
cpuRuntime->run(cpuGraph);
|
||||
auto outputCpu = cpuOp->getOutput();
|
||||
// Check
|
||||
// outputCpu->printData();
|
||||
// outputNpu2Cpu->printData();
|
||||
EXPECT_TRUE(outputCpu->equalData(outputNpu2Cpu, 1e-3));
|
||||
}
|
||||
|
||||
TEST(ascend_Conv, run) {
|
||||
aclInit(nullptr);
|
||||
testConv<ConvObj>(IncrementalGenerator(), IncrementalGenerator(),
|
||||
Shape{1, 3, 128, 128}, Shape{2, 3, 3, 3});
|
||||
aclFinalize();
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,58 @@
|
|||
#include "ascend/ascend_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/conv.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
void testConvTransposedAclnn(
|
||||
const std::function<void(void *, size_t, DataType)> &generator,
|
||||
std::vector<float> ansVec) {
|
||||
const auto &[N, C, H, W, F, R, S] = tuple{1, 1, 2, 2, 1, 4, 4};
|
||||
const int stride = 1, padding = 0, dilation = 1;
|
||||
// Construct Runtime and graph for CPU and CUDA
|
||||
Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
Runtime npu = make_ref<ASCENDRuntimeObj>();
|
||||
Graph gNpu = make_ref<GraphObj>(npu);
|
||||
// Set input data on CPU in a CPU Graph
|
||||
Tensor i0Cpu = gCpu->addTensor({N, F, H, H}, DataType::Float32);
|
||||
Tensor w0Cpu = gCpu->addTensor({F, C, R, S}, DataType::Float32);
|
||||
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||
gCpu->dataMalloc();
|
||||
i0Cpu->setData(generator);
|
||||
w0Cpu->setData(generator);
|
||||
|
||||
// Copy input tensors from CPU to CUDA
|
||||
Tensor i0Npu = gNpu->cloneTensor(i0Cpu);
|
||||
Tensor w0Npu = gNpu->cloneTensor(w0Cpu);
|
||||
// Build CUDA graph
|
||||
auto conv = gNpu->addOp<ConvTransposed2dObj>(i0Npu, w0Npu, nullptr, padding,
|
||||
padding, stride, stride,
|
||||
dilation, dilation);
|
||||
gNpu->dataMalloc();
|
||||
i0Npu->setData(generator);
|
||||
w0Npu->setData(generator);
|
||||
// Execute on CUDA
|
||||
npu->run(gNpu);
|
||||
// copy output from CUDA to CPU
|
||||
auto o0Cpu = gCpu->cloneTensor(conv->getOutput());
|
||||
// check results on CPU
|
||||
o0Cpu->printData();
|
||||
EXPECT_TRUE(o0Cpu->equalData(ansVec));
|
||||
}
|
||||
|
||||
TEST(ascend_ConvTransposed, run) {
|
||||
aclInit(nullptr);
|
||||
testConvTransposedAclnn(
|
||||
IncrementalGenerator(),
|
||||
std::vector<float>{0., 0., 1., 2., 3., 0., 6., 12., 18.,
|
||||
16., 8., 30., 36., 42., 32., 16., 54., 60.,
|
||||
66., 48., 24., 62., 67., 72., 45.});
|
||||
aclFinalize();
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,67 @@
|
|||
#include "ascend/ascend_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/element_wise.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <class T>
|
||||
void testElementWise(
|
||||
const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape0, const Shape &shape) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto npuRuntime = make_ref<ASCENDRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu1 =
|
||||
make_ref<TensorObj>(shape0, DataType::Float32, cpuRuntime);
|
||||
Tensor inputCpu2 =
|
||||
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
inputCpu1->dataMalloc();
|
||||
inputCpu2->dataMalloc();
|
||||
inputCpu1->setData(generator);
|
||||
inputCpu2->setData(generator);
|
||||
|
||||
inputCpu1->print();
|
||||
inputCpu1->printData();
|
||||
inputCpu2->print();
|
||||
inputCpu2->printData();
|
||||
// NPU
|
||||
Graph npuGraph = make_ref<GraphObj>(npuRuntime);
|
||||
auto inputNpu1 = npuGraph->cloneTensor(inputCpu1);
|
||||
auto inputNpu2 = npuGraph->cloneTensor(inputCpu2);
|
||||
auto npuOp = npuGraph->addOp<T>(inputNpu1, inputNpu2, nullptr);
|
||||
npuGraph->dataMalloc();
|
||||
inputNpu1->setData(generator);
|
||||
inputNpu2->setData(generator);
|
||||
npuRuntime->run(npuGraph);
|
||||
auto outputNpu = npuOp->getOutput();
|
||||
auto outputNpu2Cpu = outputNpu->clone(cpuRuntime);
|
||||
|
||||
// Check
|
||||
|
||||
outputNpu2Cpu->print();
|
||||
outputNpu2Cpu->printData();
|
||||
EXPECT_TRUE(1);
|
||||
}
|
||||
|
||||
TEST(ascend_ElementWise, run) {
|
||||
aclInit(nullptr);
|
||||
// testElementWise<PowObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
// testElementWise<AddObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
// testElementWise<SubObj>(IncrementalGenerator(), Shape{1, 1, 48, 48},
|
||||
// Shape{1, 1, 1, 1});
|
||||
testElementWise<MaximumObj>(IncrementalGenerator(), Shape{1, 2, 2, 3},
|
||||
Shape{1, 2, 2, 3});
|
||||
// testElementWise<DivObj>(IncrementalGenerator(),
|
||||
// Shape{1}, Shape{1, 2, 2, 3});
|
||||
// testElementWise<MulObj>(IncrementalGenerator(),
|
||||
// Shape{1, 2, 2, 3});
|
||||
aclFinalize();
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,42 @@
|
|||
#include "ascend/ascend_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/unary.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <class T>
|
||||
void testErf(const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto ascendRuntime = make_ref<ASCENDRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
inputCpu->dataMalloc();
|
||||
inputCpu->setData(generator);
|
||||
|
||||
// Npu
|
||||
Graph npuGraph = make_ref<GraphObj>(ascendRuntime);
|
||||
auto inputNpu = npuGraph->cloneTensor(inputCpu);
|
||||
auto npuOp = npuGraph->addOp<T>(inputNpu, nullptr);
|
||||
npuGraph->dataMalloc();
|
||||
ascendRuntime->run(npuGraph);
|
||||
auto outputNpu = npuOp->getOutput();
|
||||
auto outputNpu2Cpu = outputNpu->clone(cpuRuntime);
|
||||
inputCpu->printData();
|
||||
outputNpu2Cpu->printData();
|
||||
EXPECT_TRUE(1);
|
||||
}
|
||||
|
||||
TEST(ascend_Erf, run) {
|
||||
aclInit(nullptr);
|
||||
testErf<ErfObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
aclFinalize();
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,100 @@
|
|||
#include "ascend/ascend_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/gather.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
TEST(ascend_Gather, run) {
|
||||
aclInit(nullptr);
|
||||
//{
|
||||
// // Runtime
|
||||
// Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
// auto npuRuntime = make_ref<ASCENDRuntimeObj>();
|
||||
|
||||
// // Build input data on CPU
|
||||
// Tensor inputCpu =
|
||||
// make_ref<TensorObj>(Shape{3, 2}, DataType::Float32, cpuRuntime);
|
||||
// Tensor indexCpu =
|
||||
// make_ref<TensorObj>(Shape{2, 2}, DataType::Int32, cpuRuntime);
|
||||
|
||||
// // NPU
|
||||
// Graph npuGraph = make_ref<GraphObj>(npuRuntime);
|
||||
// auto inputNpu = npuGraph->cloneTensor(inputCpu);
|
||||
// auto indexNpu = npuGraph->cloneTensor(indexCpu);
|
||||
// auto npuOp = npuGraph->addOp<GatherObj>(inputNpu, indexNpu, nullptr,
|
||||
// 0); npuGraph->dataMalloc(); inputNpu->copyin(vector<float>{1, 2, 3, 4,
|
||||
// 5, 6}); indexNpu->copyin(vector<int>{0, 1, 1, 2});
|
||||
// npuRuntime->run(npuGraph);
|
||||
// auto outputNpu = npuOp->getOutput();
|
||||
// auto outputNpu2Cpu = outputNpu->clone(cpuRuntime);
|
||||
|
||||
// // Check
|
||||
// EXPECT_TRUE(
|
||||
// outputNpu2Cpu->equalData(vector<float>{1, 2, 3, 4, 3, 4, 5, 6}));
|
||||
//}
|
||||
{
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto npuRuntime = make_ref<ASCENDRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu =
|
||||
// make_ref<TensorObj>(Shape{3, 3}, DataType::Float32, cpuRuntime);
|
||||
make_ref<TensorObj>(Shape{1, 2, 1024, 64, 4}, DataType::Float32,
|
||||
cpuRuntime);
|
||||
Tensor indexCpu =
|
||||
make_ref<TensorObj>(Shape{1}, DataType::Int64, cpuRuntime);
|
||||
|
||||
// NPU
|
||||
Graph npuGraph = make_ref<GraphObj>(npuRuntime);
|
||||
auto inputNpu = npuGraph->cloneTensor(inputCpu);
|
||||
auto indexNpu = npuGraph->cloneTensor(indexCpu);
|
||||
auto npuOp = npuGraph->addOp<GatherObj>(inputNpu, indexNpu, nullptr, 1);
|
||||
npuGraph->dataMalloc();
|
||||
inputNpu->setData(IncrementalGenerator());
|
||||
indexNpu->copyin(vector<int64_t>{0});
|
||||
npuRuntime->run(npuGraph);
|
||||
auto outputNpu = npuOp->getOutput();
|
||||
auto outputNpu2Cpu = outputNpu->clone(cpuRuntime);
|
||||
|
||||
// Check
|
||||
// EXPECT_TRUE(outputNpu2Cpu->equalData(vector<float>{0, 2, 3, 5, 6,
|
||||
// 8}));
|
||||
// EXPECT_TRUE(outputNpu2Cpu->equalData(vector<float>{0, 3, 6}));
|
||||
EXPECT_TRUE(1);
|
||||
}
|
||||
//{
|
||||
// // Runtime
|
||||
// Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
// auto npuRuntime = make_ref<ASCENDRuntimeObj>();
|
||||
//
|
||||
// // Build input data on CPU
|
||||
// Tensor inputCpu =
|
||||
// make_ref<TensorObj>(Shape{3, 2}, DataType::Float32, cpuRuntime);
|
||||
// Tensor indexCpu =
|
||||
// make_ref<TensorObj>(Shape{2, 2}, DataType::Int64, cpuRuntime);
|
||||
//
|
||||
// // NPU
|
||||
// Graph npuGraph = make_ref<GraphObj>(npuRuntime);
|
||||
// auto inputNpu = npuGraph->cloneTensor(inputCpu);
|
||||
// auto indexNpu = npuGraph->cloneTensor(indexCpu);
|
||||
// auto npuOp = npuGraph->addOp<GatherObj>(inputNpu, indexNpu, nullptr,
|
||||
// 0); npuGraph->dataMalloc();
|
||||
// inputNpu->copyin(std::vector<float>{1.0, 1.2, 2.3, 3.4, 4.5, 5.7});
|
||||
// indexNpu->copyin(vector<int64_t>{0, 1, 1, 2});
|
||||
// npuRuntime->run(npuGraph);
|
||||
// auto outputNpu = npuOp->getOutput();
|
||||
// auto outputNpu2Cpu = outputNpu->clone(cpuRuntime);
|
||||
//
|
||||
// // Check
|
||||
// EXPECT_TRUE(outputNpu2Cpu->equalData(
|
||||
// vector<float>{1.0, 1.2, 2.3, 3.4, 2.3, 3.4, 4.5, 5.7}));
|
||||
//}
|
||||
aclFinalize();
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,72 @@
|
|||
#include "ascend/ascend_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/instance_norm.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
void test_instancenormFp32(const Shape &inputShape,
|
||||
const vector<float> &inputData,
|
||||
const Shape &scaleShape,
|
||||
const vector<float> &scaleData, float eps,
|
||||
const vector<float> &ExpectData,
|
||||
const Shape &biasShape,
|
||||
const vector<float> &biasData) {
|
||||
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto bias = gCpu->addTensor(biasShape, DataType::Float32);
|
||||
auto input = gCpu->addTensor(inputShape, DataType::Float32);
|
||||
auto scale = gCpu->addTensor(scaleShape, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
bias->copyin(biasData); //
|
||||
// bias->printData();
|
||||
input->copyin(inputData);
|
||||
scale->copyin(scaleData); //
|
||||
auto ascendRuntime = make_ref<ASCENDRuntimeObj>();
|
||||
Graph gAscend = make_ref<GraphObj>(ascendRuntime);
|
||||
auto biasNpu = gAscend->cloneTensor(bias);
|
||||
auto inputNpu = gAscend->cloneTensor(input);
|
||||
auto scaleNpu = gAscend->cloneTensor(scale);
|
||||
// gCpu->cloneTensor(biasNpu)->printData();
|
||||
auto op =
|
||||
gAscend->addOp<InstanceNormObj>(inputNpu, nullptr, scaleNpu, biasNpu,
|
||||
eps); // InstancenormObj
|
||||
gAscend->dataMalloc();
|
||||
biasNpu->copyin(biasData);
|
||||
// gCpu->cloneTensor(biasNpu)->printData();
|
||||
inputNpu->copyin(inputData);
|
||||
scaleNpu->copyin(scaleData);
|
||||
ascendRuntime->run(gAscend);
|
||||
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput()); // move Data from npu to cpu
|
||||
oCpu->printData(); //->printData
|
||||
EXPECT_TRUE(oCpu->equalData(ExpectData));
|
||||
}
|
||||
|
||||
TEST(CUDA_InstancenormFp32, run) {
|
||||
aclInit(nullptr);
|
||||
test_instancenormFp32(
|
||||
Shape{2, 3, 2, 3},
|
||||
vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
|
||||
9., 10., 11., 12., 13., 14., 15., 16., 17.,
|
||||
18., 19., 20., 21., 22., 23., 24., 25., 26.,
|
||||
27., 28., 29., 30., 31., 32., 33., 34., 35.},
|
||||
Shape{3}, vector<float>{0.3, 0.2, 0.5}, 1e-5,
|
||||
vector<float>{
|
||||
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
|
||||
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
|
||||
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
|
||||
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
|
||||
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
|
||||
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678},
|
||||
Shape{3}, vector<float>{0, 0, 0});
|
||||
|
||||
aclFinalize();
|
||||
} // python output
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,152 @@
|
|||
#include "ascend/ascend_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/layer_norm.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
void test_layernormFp32(
|
||||
const Shape &inputShape, const vector<float> &inputData,
|
||||
const Shape &scaleShape, const vector<float> &scaleData, float eps,
|
||||
int axis, int stash_type, const vector<float> &ExpectData,
|
||||
const std::optional<Shape> &bShape = std::nullopt,
|
||||
const std::optional<std::vector<float>> &biasData = std::nullopt) {
|
||||
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
if (bShape.has_value() && biasData.has_value()) {
|
||||
Shape biasShape = *bShape;
|
||||
|
||||
auto bias = gCpu->addTensor(biasShape, DataType::Float32);
|
||||
auto input = gCpu->addTensor(inputShape, DataType::Float32);
|
||||
auto scale = gCpu->addTensor(scaleShape, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
bias->copyin(*biasData); //
|
||||
// bias->printData();
|
||||
input->copyin(inputData);
|
||||
scale->copyin(scaleData); //
|
||||
auto ascendRuntime = make_ref<ASCENDRuntimeObj>();
|
||||
Graph gAscend = make_ref<GraphObj>(ascendRuntime);
|
||||
auto biasNpu = gAscend->cloneTensor(bias);
|
||||
auto inputNpu = gAscend->cloneTensor(input);
|
||||
auto scaleNpu = gAscend->cloneTensor(scale);
|
||||
// gCpu->cloneTensor(biasNpu)->printData();
|
||||
auto op =
|
||||
gAscend->addOp<LayerNormObj>(inputNpu, scaleNpu, nullptr, biasNpu,
|
||||
eps, axis, stash_type); // LayernormObj
|
||||
gAscend->dataMalloc();
|
||||
biasNpu->copyin(*biasData);
|
||||
// gCpu->cloneTensor(biasNpu)->printData();
|
||||
inputNpu->copyin(inputData);
|
||||
scaleNpu->copyin(scaleData);
|
||||
ascendRuntime->run(gAscend);
|
||||
|
||||
auto oCpu =
|
||||
gCpu->cloneTensor(op->getOutput()); // move Data from npu to cpu
|
||||
oCpu->printData(); //->printData
|
||||
EXPECT_TRUE(oCpu->equalData(ExpectData));
|
||||
} else {
|
||||
|
||||
auto input = gCpu->addTensor(inputShape, DataType::Float32);
|
||||
auto scale = gCpu->addTensor(scaleShape, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
|
||||
input->copyin(inputData);
|
||||
scale->copyin(scaleData); //
|
||||
auto ascendRuntime = make_ref<ASCENDRuntimeObj>();
|
||||
Graph gAscend = make_ref<GraphObj>(ascendRuntime);
|
||||
|
||||
auto inputNpu = gAscend->cloneTensor(input);
|
||||
auto scaleNpu = gAscend->cloneTensor(scale);
|
||||
auto op =
|
||||
gAscend->addOp<LayerNormObj>(inputNpu, scaleNpu, nullptr, nullptr,
|
||||
eps, axis, stash_type); // LayernormObj
|
||||
gAscend->dataMalloc();
|
||||
|
||||
inputNpu->copyin(inputData);
|
||||
scaleNpu->copyin(scaleData);
|
||||
ascendRuntime->run(gAscend);
|
||||
|
||||
auto oCpu =
|
||||
gCpu->cloneTensor(op->getOutput()); // move Data from npu to cpu
|
||||
oCpu->printData(); //->printData
|
||||
EXPECT_TRUE(oCpu->equalData(ExpectData));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CUDA_LayernormFp32, run) {
|
||||
aclInit(nullptr);
|
||||
test_layernormFp32(
|
||||
Shape{2, 3, 2, 3},
|
||||
vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
|
||||
9., 10., 11., 12., 13., 14., 15., 16., 17.,
|
||||
18., 19., 20., 21., 22., 23., 24., 25., 26.,
|
||||
27., 28., 29., 30., 31., 32., 33., 34., 35.},
|
||||
Shape{3}, vector<float>{0.3, 0.2, 0.5}, 1e-5, 3, 1,
|
||||
vector<float>{
|
||||
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
|
||||
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
|
||||
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
|
||||
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
|
||||
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
|
||||
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678},
|
||||
Shape{3}, vector<float>{0, 0, 0});
|
||||
// test_layernormFp32(
|
||||
// Shape{2, 3, 2, 3},
|
||||
// vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
|
||||
// 9., 10., 11., 12., 13., 14., 15., 16., 17.,
|
||||
// 18., 19., 20., 21., 22., 23., 24., 25., 26.,
|
||||
// 27., 28., 29., 30., 31., 32., 33., 34., 35.},
|
||||
// Shape{3}, vector<float>{0.3, 0.2, 0.5}, 1e-5, 3, 1,
|
||||
// vector<float>{
|
||||
// -0.0674207, 0.2000000, 1.1123679, -0.0674207,
|
||||
// 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679,
|
||||
// -0.0674207, 0.2000000, 1.1123679, -0.0674207,
|
||||
// 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679,
|
||||
// -0.0674207, 0.2000000, 1.1123679, -0.0674207,
|
||||
// 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679,
|
||||
// -0.0674207, 0.2000000, 1.1123679, -0.0674207,
|
||||
// 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679},
|
||||
// Shape{3}, vector<float>{0.3, 0.2, 0.5});
|
||||
// test_layernormFp32(
|
||||
// Shape{2, 3, 2, 3},
|
||||
// vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
|
||||
// 9., 10., 11., 12., 13., 14., 15., 16., 17.,
|
||||
// 18., 19., 20., 21., 22., 23., 24., 25., 26.,
|
||||
// 27., 28., 29., 30., 31., 32., 33., 34., 35.},
|
||||
// Shape{1}, vector<float>{0.3}, 1e-5, 3, 1,
|
||||
// vector<float>{
|
||||
// -0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000,
|
||||
// 0.8674207, -0.0674207, 0.2000000, 0.8674207, -0.0674207,
|
||||
// 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207,
|
||||
// -0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000,
|
||||
// 0.8674207, -0.0674207, 0.2000000, 0.8674207, -0.0674207,
|
||||
// 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207,
|
||||
// -0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000,
|
||||
// 0.8674207},
|
||||
// Shape{3}, vector<float>{0.3, 0.2, 0.5});
|
||||
// test_layernormFp32(
|
||||
// Shape{2, 3, 2, 3},
|
||||
// vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
|
||||
// 9., 10., 11., 12., 13., 14., 15., 16., 17.,
|
||||
// 18., 19., 20., 21., 22., 23., 24., 25., 26.,
|
||||
// 27., 28., 29., 30., 31., 32., 33., 34., 35.},
|
||||
// Shape{3}, vector<float>{0.3, 0.2, 0.5}, 1e-5, 3, 1,
|
||||
// vector<float>{-0.3674207, 0.0000000, 0.6123678, -0.3674207,
|
||||
// 0.0000000, 0.6123678, -0.3674207, 0.0000000,
|
||||
// 0.6123678, -0.3674207, 0.0000000, 0.6123678,
|
||||
// -0.3674207, 0.0000000, 0.6123678, -0.3674207,
|
||||
// 0.0000000, 0.6123678, -0.3674207, 0.0000000,
|
||||
// 0.6123678, -0.3674207, 0.0000000, 0.6123678,
|
||||
// -0.3674207, 0.0000000, 0.6123678, -0.3674207,
|
||||
// 0.0000000, 0.6123678, -0.3674207, 0.0000000,
|
||||
// 0.6123678, -0.3674207, 0.0000000, 0.6123678});
|
||||
|
||||
aclFinalize();
|
||||
} // python output
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,59 @@
|
|||
#include "ascend/ascend_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/matmul.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <class T>
|
||||
void testMatmul(const std::function<void(void *, size_t, DataType)> &generatorA,
|
||||
const std::function<void(void *, size_t, DataType)> &generatorB,
|
||||
bool transA, bool transB, const Shape &shapeA,
|
||||
const Shape &shapeB) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto npuRuntime = make_ref<ASCENDRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu1 =
|
||||
make_ref<TensorObj>(shapeA, DataType::Float32, cpuRuntime);
|
||||
Tensor inputCpu2 =
|
||||
make_ref<TensorObj>(shapeB, DataType::Float32, cpuRuntime);
|
||||
|
||||
// NPU
|
||||
Graph npuGraph = make_ref<GraphObj>(npuRuntime);
|
||||
auto inputNpu1 = npuGraph->cloneTensor(inputCpu1);
|
||||
auto inputNpu2 = npuGraph->cloneTensor(inputCpu2);
|
||||
auto npuOp = npuGraph->addOp<T>(inputNpu1, inputNpu2, nullptr);
|
||||
npuGraph->dataMalloc();
|
||||
inputNpu1->setData(generatorA);
|
||||
inputNpu2->setData(generatorB);
|
||||
npuRuntime->run(npuGraph);
|
||||
auto outputNpu = npuOp->getOutput();
|
||||
auto outputNpu2Cpu = outputNpu->clone(cpuRuntime);
|
||||
// CPU
|
||||
Graph cpuGraph = make_ref<GraphObj>(cpuRuntime);
|
||||
auto cpuOp = cpuGraph->addOp<T>(inputCpu1, inputCpu2, nullptr);
|
||||
cpuGraph->addTensor(inputCpu1);
|
||||
cpuGraph->addTensor(inputCpu2);
|
||||
cpuGraph->dataMalloc();
|
||||
inputCpu1->setData(generatorA);
|
||||
inputCpu2->setData(generatorB);
|
||||
cpuRuntime->run(cpuGraph);
|
||||
auto outputCpu = cpuOp->getOutput();
|
||||
|
||||
// Check
|
||||
EXPECT_TRUE(outputCpu->equalData(outputNpu2Cpu));
|
||||
}
|
||||
|
||||
TEST(ascend_Matmul, run) {
|
||||
aclInit(nullptr);
|
||||
testMatmul<MatmulObj>(IncrementalGenerator(), IncrementalGenerator(), false,
|
||||
false, Shape{1, 2, 3}, Shape{1, 3, 4});
|
||||
aclFinalize();
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,48 @@
|
|||
#include "ascend/ascend_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/pooling.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <class T, typename std::enable_if<std::is_base_of<PoolingObj, T>{},
|
||||
int>::type = 0>
|
||||
void testPooling(const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto npuRuntime = make_ref<ASCENDRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
inputCpu->dataMalloc();
|
||||
inputCpu->setData(generator);
|
||||
|
||||
// GPU
|
||||
Graph npuGraph = make_ref<GraphObj>(npuRuntime);
|
||||
auto inputNpu = npuGraph->cloneTensor(inputCpu);
|
||||
auto npuOp =
|
||||
npuGraph->addOp<T>(inputNpu, nullptr, 3, 3, 1, 1, 1, 1, 2, 2, 0);
|
||||
// npuGraph->addOp<T>(inputNpu, nullptr, 2, 2, 1, 1, 0, 0, 1, 1, 0);
|
||||
npuGraph->dataMalloc();
|
||||
inputNpu->setData(generator);
|
||||
npuRuntime->run(npuGraph);
|
||||
|
||||
auto outputNpu = npuOp->getOutput();
|
||||
auto outputNpu2Cpu = outputNpu->clone(cpuRuntime);
|
||||
inputCpu->printData();
|
||||
outputNpu2Cpu->printData();
|
||||
EXPECT_TRUE(1);
|
||||
}
|
||||
|
||||
TEST(cnnl_Pooling, run) {
|
||||
aclInit(nullptr);
|
||||
testPooling<MaxPoolObj>(IncrementalGenerator(), Shape{1, 3, 5, 5});
|
||||
testPooling<AvgPoolObj>(IncrementalGenerator(), Shape{1, 2, 5, 5});
|
||||
aclFinalize();
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,84 @@
|
|||
#include "ascend/ascend_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/reduce.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <typename ReduceObjT>
|
||||
void test_reduce(const Shape &shape, const vector<float> &data,
|
||||
const optional<const vector<int>> &axes, bool keepDims,
|
||||
const vector<float> &ExpectData) {
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto npuRuntime = make_ref<ASCENDRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
|
||||
// Build NPU graph
|
||||
Graph npuGraph = make_ref<GraphObj>(npuRuntime);
|
||||
auto inputNpu = npuGraph->cloneTensor(inputCpu);
|
||||
auto op = npuGraph->addOp<ReduceObjT>(inputNpu, nullptr, axes, keepDims);
|
||||
|
||||
// allocate NPU memory
|
||||
npuGraph->dataMalloc();
|
||||
inputNpu->copyin(data);
|
||||
|
||||
// Execute on NPU
|
||||
npuRuntime->run(npuGraph);
|
||||
|
||||
// clone NPU output to CPU
|
||||
auto outputNpu = op->getOutput();
|
||||
auto outputNpu2Cpu = outputNpu->clone(cpuRuntime);
|
||||
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(outputNpu2Cpu->equalData(ExpectData));
|
||||
}
|
||||
|
||||
TEST(ascend_ReduceMean, run) {
|
||||
aclInit(nullptr);
|
||||
test_reduce<ReduceMeanObj>(
|
||||
Shape{3, 2, 2}, vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2},
|
||||
std::nullopt, true, vector<float>{18.25});
|
||||
test_reduce<ReduceMeanObj>(
|
||||
Shape{1, 3, 2, 2, 1},
|
||||
vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2}, std::nullopt,
|
||||
false, vector<float>{18.25});
|
||||
|
||||
test_reduce<ReduceMeanObj>(
|
||||
Shape{2, 3, 2, 2},
|
||||
vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
|
||||
vector<int>{1, 2}, false, vector<float>{5, 6, 17, 18});
|
||||
test_reduce<ReduceMeanObj>(
|
||||
Shape{2, 3, 2, 2, 1},
|
||||
vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
|
||||
vector<int>{1, 2}, true, vector<float>{5, 6, 17, 18});
|
||||
aclFinalize();
|
||||
}
|
||||
|
||||
TEST(ascend_ReduceSum, run) {
|
||||
test_reduce<ReduceSumObj>(Shape{3, 2, 2},
|
||||
vector<float>{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
std::nullopt, true, vector<float>{12});
|
||||
test_reduce<ReduceSumObj>(Shape{1, 3, 2, 2, 1},
|
||||
vector<float>{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
std::nullopt, false, vector<float>{12});
|
||||
|
||||
test_reduce<ReduceSumObj>(
|
||||
Shape{2, 3, 2, 2},
|
||||
vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
|
||||
vector<int>{1, 2}, false, vector<float>{30, 36, 102, 108});
|
||||
test_reduce<ReduceSumObj>(
|
||||
Shape{2, 3, 2, 2, 1},
|
||||
vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
|
||||
vector<int>{1, 2}, true, vector<float>{30, 36, 102, 108});
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,84 @@
|
|||
#include "ascend/ascend_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/reshape.h"
|
||||
#include "operators/squeeze.h"
|
||||
#include "operators/unsqueeze.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <class T>
|
||||
void testReshape(const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape, const Shape &outputShape) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto npuRuntime = make_ref<ASCENDRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
inputCpu->dataMalloc();
|
||||
inputCpu->setData(generator);
|
||||
|
||||
// NPU
|
||||
Graph npuGraph = make_ref<GraphObj>(npuRuntime);
|
||||
auto inputNpu = npuGraph->cloneTensor(inputCpu);
|
||||
auto npuOp = npuGraph->addOp<T>(inputNpu, nullptr, outputShape);
|
||||
npuGraph->dataMalloc();
|
||||
inputNpu->setData(generator);
|
||||
npuRuntime->run(npuGraph);
|
||||
auto outputNpu = npuOp->getOutput();
|
||||
auto outputNpu2Cpu = outputNpu->clone(cpuRuntime);
|
||||
|
||||
// Check
|
||||
inputCpu->print();
|
||||
inputCpu->printData();
|
||||
outputNpu2Cpu->print();
|
||||
outputNpu2Cpu->printData();
|
||||
EXPECT_TRUE(inputCpu->equalData(outputNpu2Cpu, 1e-3));
|
||||
}
|
||||
|
||||
void testFlatten(const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape, int axis) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto npuRuntime = make_ref<ASCENDRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
inputCpu->dataMalloc();
|
||||
inputCpu->setData(generator);
|
||||
|
||||
// NPU
|
||||
Graph npuGraph = make_ref<GraphObj>(npuRuntime);
|
||||
auto inputNpu = npuGraph->cloneTensor(inputCpu);
|
||||
auto npuOp = npuGraph->addOp<FlattenObj>(inputNpu, nullptr, axis);
|
||||
npuGraph->dataMalloc();
|
||||
inputNpu->setData(generator);
|
||||
npuRuntime->run(npuGraph);
|
||||
auto outputNpu = npuOp->getOutput();
|
||||
auto outputNpu2Cpu = outputNpu->clone(cpuRuntime);
|
||||
|
||||
// Check
|
||||
inputCpu->print();
|
||||
inputCpu->printData();
|
||||
outputNpu2Cpu->print();
|
||||
outputNpu2Cpu->printData();
|
||||
EXPECT_TRUE(inputCpu->equalData(outputNpu2Cpu, 1e-3));
|
||||
}
|
||||
|
||||
TEST(ascend_Unary, run) {
|
||||
aclInit(nullptr);
|
||||
testReshape<ReshapeObj>(IncrementalGenerator(), Shape{1, 2, 2, 3},
|
||||
Shape{1, 2, 6});
|
||||
testReshape<SqueezeObj>(IncrementalGenerator(), Shape{1, 2, 2, 3},
|
||||
Shape{0});
|
||||
testReshape<UnsqueezeObj>(IncrementalGenerator(), Shape{1, 2, 2, 3},
|
||||
Shape{4});
|
||||
testFlatten(IncrementalGenerator(), Shape{1, 2, 2, 3}, 2);
|
||||
aclFinalize();
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,68 @@
|
|||
#include "ascend/ascend_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/resize.h"
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
TEST(Resize, Ascend_downsample_scales_nearest) {
|
||||
aclInit(nullptr);
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32);
|
||||
auto scales = gCpu->addTensor({4}, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyin(vector<float>{1, 2, 3, 4, 5, 6, 7, 8});
|
||||
scales->copyin(vector<float>{1, 1, 0.6, 0.6});
|
||||
|
||||
auto ascendRuntime = make_ref<ASCENDRuntimeObj>();
|
||||
Graph gNpu = make_ref<GraphObj>(ascendRuntime);
|
||||
|
||||
auto inputNpu = gNpu->cloneTensor(input);
|
||||
auto scalesNpu = gNpu->cloneTensor(scales);
|
||||
auto op = gNpu->addOp<ResizeObj>(inputNpu, nullptr, std::nullopt, nullptr,
|
||||
scalesNpu, nullptr);
|
||||
gNpu->dataMalloc();
|
||||
inputNpu->copyin(vector<float>{1, 2, 3, 4, 5, 6, 7, 8});
|
||||
scalesNpu->copyin(vector<float>{1, 1, 0.6, 0.6});
|
||||
ascendRuntime->run(gNpu);
|
||||
|
||||
// copy output from NPU to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{1, 3}));
|
||||
aclFinalize();
|
||||
}
|
||||
|
||||
// TEST(Resize, Ascend_upsample_scales_nearest) {
|
||||
// Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
// Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
//
|
||||
// auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32);
|
||||
// auto scales = gCpu->addTensor({4}, DataType::Float32);
|
||||
// gCpu->dataMalloc();
|
||||
// input->copyin(vector<float>{1, 2, 3, 4});
|
||||
// scales->copyin(vector<float>{1, 1, 2, 3});
|
||||
//
|
||||
// auto ascendRuntime = make_ref<ascendRuntimeObj>();
|
||||
// Graph gNpu = make_ref<GraphObj>(ascendRuntime);
|
||||
//
|
||||
// auto inputNpu = gNpu->cloneTensor(input);
|
||||
// auto scalesNpu = gNpu->cloneTensor(scales);
|
||||
// auto op = gNpu->addOp<ResizeObj>(inputNpu, nullptr, std::nullopt,
|
||||
// nullptr,
|
||||
// scalesNpu, nullptr);
|
||||
// gNpu->dataMalloc();
|
||||
// inputNpu->copyin(vector<float>{1, 2, 3, 4});
|
||||
// scalesNpu->copyin(vector<float>{1, 1, 2, 3});
|
||||
// ascendRuntime->run(gNpu);
|
||||
//
|
||||
// // copy output from NPU to CPU
|
||||
// auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
// EXPECT_TRUE(
|
||||
// oCpu->equalData(vector<float>{1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2,
|
||||
// 3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4}));
|
||||
// }
|
||||
} // namespace infini
|
|
@ -0,0 +1,92 @@
|
|||
#ifdef INFINI_USE_HCCL
|
||||
#include "ascend/ascend_runtime.h"
|
||||
#include "ascend/hccl_communicator.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/recv.h"
|
||||
#include "operators/send.h"
|
||||
#include "test.h"
|
||||
#include <thread>
|
||||
|
||||
namespace infini {
|
||||
|
||||
void sendrecv(const string taskName, int deviceID, vector<float> data,
|
||||
const Shape &dataShape, int WORLD_SIZE, int source,
|
||||
int destination) {
|
||||
// Create Runtimes and initiate communication
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
Runtime ascnedRuntime = make_ref<ASCENDRuntimeObj>(deviceID);
|
||||
ascnedRuntime->initComm(taskName, WORLD_SIZE, deviceID);
|
||||
|
||||
if (deviceID == source) {
|
||||
Graph gSend = make_ref<GraphObj>(ascnedRuntime);
|
||||
auto input = gSend->addTensor(Shape{static_cast<int>(data.size())},
|
||||
DataType::Float32);
|
||||
auto opSend =
|
||||
gSend->addOp<SendObj>(input, source, destination, nullptr);
|
||||
|
||||
// Copy data from CPU to GPU
|
||||
gSend->dataMalloc();
|
||||
input->copyin(data);
|
||||
ascnedRuntime->run(gSend);
|
||||
}
|
||||
|
||||
// ----------------
|
||||
|
||||
if (deviceID == destination) {
|
||||
Graph gRecv = make_ref<GraphObj>(ascnedRuntime);
|
||||
int outputType = 1;
|
||||
// auto input =
|
||||
// gRecv->addTensor(Shape{static_cast<int>(data.size())},DataType::Float32);
|
||||
auto opRecv = gRecv->addOp<RecvObj>(nullptr, source, destination,
|
||||
dataShape, outputType, nullptr);
|
||||
gRecv->dataMalloc();
|
||||
ascnedRuntime->run(gRecv);
|
||||
|
||||
auto result = opRecv->getOutput()->clone(cpuRuntime);
|
||||
EXPECT_TRUE(result->equalData(data));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ASCEND_SendRecv1, run) {
|
||||
// Only 1 device gets data. Every rank should have the same data after
|
||||
// sendrecv.
|
||||
aclInit(nullptr);
|
||||
vector<float> data = {2., 3., 5., 6.};
|
||||
|
||||
int WORLD_SIZE = 4;
|
||||
int source = 0;
|
||||
int destination = 2;
|
||||
std::vector<std::thread> threads;
|
||||
for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) {
|
||||
threads.emplace_back(sendrecv, "test_sendrecv", gpu, data, Shape{2, 2},
|
||||
WORLD_SIZE, source, destination);
|
||||
}
|
||||
|
||||
for (auto &thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
aclFinalize();
|
||||
}
|
||||
|
||||
// TEST(ASCEND_SendRecv2, run) {
|
||||
// // Only 1 device gets data. Every rank should have the same data after
|
||||
// // sendrecv.
|
||||
// vector<float> data = {2., 3., 5., 6.};//
|
||||
|
||||
// int WORLD_SIZE = 3;
|
||||
// int source = 0;
|
||||
// int destination = 2;
|
||||
// std::vector<std::thread> threads;
|
||||
// for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) {
|
||||
// threads.emplace_back(sendrecv, "test_sendrecv", gpu, data, Shape{2,
|
||||
// 2},
|
||||
// WORLD_SIZE, source, destination);
|
||||
// }//
|
||||
|
||||
// for (auto &thread : threads) {
|
||||
// thread.join();
|
||||
// }
|
||||
//}
|
||||
} // namespace infini
|
||||
#endif
|
|
@ -0,0 +1,41 @@
|
|||
#include "ascend/ascend_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/slice.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
TEST(ascend_Unary, run) {
|
||||
aclInit(nullptr);
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto npuRuntime = make_ref<ASCENDRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu =
|
||||
make_ref<TensorObj>(Shape{3, 2, 1, 5}, DataType::Float32, cpuRuntime);
|
||||
// inputCpu->dataMalloc();
|
||||
// inputCpu->setData(IncrementalGenerator());
|
||||
|
||||
// NPU
|
||||
Graph npuGraph = make_ref<GraphObj>(npuRuntime);
|
||||
auto inputNpu = npuGraph->cloneTensor(inputCpu);
|
||||
auto npuOp = npuGraph->addOp<SliceObj>(inputNpu, nullptr, vector<int>{1, 1},
|
||||
vector<int>{2, 5}, vector<int>{0, 3},
|
||||
std::nullopt);
|
||||
npuGraph->dataMalloc();
|
||||
inputNpu->setData(IncrementalGenerator());
|
||||
npuRuntime->run(npuGraph);
|
||||
auto outputNpu = npuOp->getOutput();
|
||||
auto outputNpu2Cpu = outputNpu->clone(cpuRuntime);
|
||||
|
||||
// Check
|
||||
EXPECT_TRUE(outputNpu2Cpu->equalData(
|
||||
vector<float>{11, 12, 13, 14, 16, 17, 18, 19}));
|
||||
aclFinalize();
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,61 @@
|
|||
#include "ascend/ascend_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/softmax.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <class T>
|
||||
void testSoftmax(const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape, int axis, vector<float> Out) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto npuRuntime = make_ref<ASCENDRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu1 =
|
||||
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
inputCpu1->dataMalloc();
|
||||
// inputCpu1->setData(generator);
|
||||
|
||||
// NPU
|
||||
Graph npuGraph = make_ref<GraphObj>(npuRuntime);
|
||||
auto inputNpu1 = npuGraph->cloneTensor(inputCpu1);
|
||||
auto npuOp = npuGraph->addOp<T>(inputNpu1, nullptr, axis);
|
||||
npuGraph->dataMalloc();
|
||||
inputNpu1->setData(generator);
|
||||
npuRuntime->run(npuGraph);
|
||||
auto outputNpu = npuOp->getOutput();
|
||||
auto outputNpu2Cpu = outputNpu->clone(cpuRuntime);
|
||||
|
||||
// Check
|
||||
EXPECT_TRUE(outputNpu2Cpu->equalData(Out));
|
||||
}
|
||||
|
||||
TEST(ascend_ElementWise, run) {
|
||||
aclInit(nullptr);
|
||||
testSoftmax<SoftmaxObj>(
|
||||
IncrementalGenerator(), Shape{2, 2, 2, 2}, 1,
|
||||
vector<float>{0.0179862, 0.0179862, 0.0179862, 0.0179862, 0.9820138,
|
||||
0.9820138, 0.9820138, 0.9820138, 0.0179862, 0.0179862,
|
||||
0.0179862, 0.0179862, 0.9820138, 0.9820138, 0.9820138,
|
||||
0.9820138});
|
||||
testSoftmax<SoftmaxObj>(
|
||||
IncrementalGenerator(), Shape{2, 2, 2, 2}, 2,
|
||||
vector<float>{0.1192029, 0.1192029, 0.8807971, 0.8807971, 0.1192029,
|
||||
0.1192029, 0.8807971, 0.8807971, 0.1192029, 0.1192029,
|
||||
0.8807971, 0.8807971, 0.1192029, 0.1192029, 0.8807971,
|
||||
0.8807971});
|
||||
testSoftmax<SoftmaxObj>(
|
||||
IncrementalGenerator(), Shape{2, 2, 2, 2}, 3,
|
||||
vector<float>{0.2689414, 0.7310586, 0.2689414, 0.7310586, 0.2689414,
|
||||
0.7310586, 0.2689414, 0.7310586, 0.2689414, 0.7310586,
|
||||
0.2689414, 0.7310586, 0.2689414, 0.7310586, 0.2689414,
|
||||
0.7310586});
|
||||
aclFinalize();
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,50 @@
|
|||
#include "ascend/ascend_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/split.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <class T>
|
||||
void testSplit(const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto npuRuntime = make_ref<ASCENDRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
inputCpu->dataMalloc();
|
||||
inputCpu->setData(generator);
|
||||
// GPU
|
||||
Graph npuGraph = make_ref<GraphObj>(npuRuntime);
|
||||
auto inputNpu = npuGraph->cloneTensor(inputCpu);
|
||||
auto gpuOp = npuGraph->addOp<T>(inputNpu, std::nullopt, 3, 3);
|
||||
npuGraph->dataMalloc();
|
||||
inputNpu->setData(generator);
|
||||
npuRuntime->run(npuGraph);
|
||||
auto o0Cpu = gpuOp->getOutput(0)->clone(cpuRuntime);
|
||||
auto o1Cpu = gpuOp->getOutput(1)->clone(cpuRuntime);
|
||||
auto o2Cpu = gpuOp->getOutput(2)->clone(cpuRuntime);
|
||||
// Check
|
||||
inputCpu->print();
|
||||
inputCpu->printData();
|
||||
o0Cpu->print();
|
||||
o0Cpu->printData();
|
||||
o1Cpu->print();
|
||||
o1Cpu->printData();
|
||||
o2Cpu->print();
|
||||
o2Cpu->printData();
|
||||
EXPECT_TRUE(1);
|
||||
}
|
||||
|
||||
TEST(ascend_Split, run) {
|
||||
aclInit(nullptr);
|
||||
testSplit<SplitObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
aclFinalize();
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,49 @@
|
|||
#include "ascend/ascend_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/transpose.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <class T>
|
||||
void testTranspose(
|
||||
const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape, const Shape &permute) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto npuRuntime = make_ref<ASCENDRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
inputCpu->dataMalloc();
|
||||
inputCpu->setData(generator);
|
||||
|
||||
// NPU
|
||||
Graph npuGraph = make_ref<GraphObj>(npuRuntime);
|
||||
auto inputNpu = npuGraph->cloneTensor(inputCpu);
|
||||
auto npuOp = npuGraph->addOp<T>(inputNpu, nullptr, permute);
|
||||
npuGraph->dataMalloc();
|
||||
inputNpu->setData(generator);
|
||||
npuRuntime->run(npuGraph);
|
||||
auto outputNpu = npuOp->getOutput();
|
||||
auto outputNpu2Cpu = outputNpu->clone(cpuRuntime);
|
||||
|
||||
// Check
|
||||
inputCpu->print();
|
||||
inputCpu->printData();
|
||||
outputNpu2Cpu->print();
|
||||
outputNpu2Cpu->printData();
|
||||
EXPECT_TRUE(1);
|
||||
}
|
||||
|
||||
TEST(ascend_Unary, run) {
|
||||
aclInit(nullptr);
|
||||
testTranspose<TransposeObj>(IncrementalGenerator(), Shape{1, 1, 2, 3},
|
||||
vector<int>{0, 1, 3, 2});
|
||||
aclFinalize();
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,97 @@
|
|||
#include "ascend/ascend_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/unary.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <class T>
|
||||
void testUnary(const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto npuRuntime = make_ref<ASCENDRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
|
||||
// NPU
|
||||
Graph npuGraph = make_ref<GraphObj>(npuRuntime);
|
||||
auto inputNpu = npuGraph->cloneTensor(inputCpu);
|
||||
auto npuOp = npuGraph->addOp<T>(inputNpu, nullptr);
|
||||
npuGraph->dataMalloc();
|
||||
inputNpu->setData(generator);
|
||||
npuRuntime->run(npuGraph);
|
||||
auto outputNpu = npuOp->getOutput();
|
||||
auto outputNpu2Cpu = outputNpu->clone(cpuRuntime);
|
||||
// CPU
|
||||
Graph cpuGraph = make_ref<GraphObj>(cpuRuntime);
|
||||
auto cpuOp = cpuGraph->addOp<T>(inputCpu, nullptr);
|
||||
cpuGraph->addTensor(inputCpu);
|
||||
cpuGraph->dataMalloc();
|
||||
inputCpu->setData(generator);
|
||||
cpuRuntime->run(cpuGraph);
|
||||
auto outputCpu = cpuOp->getOutput();
|
||||
// Check
|
||||
EXPECT_TRUE(outputCpu->equalData(outputNpu2Cpu, 1e-3));
|
||||
}
|
||||
|
||||
void testLeakyRelu(const Shape &shape, const vector<float> &inputData,
|
||||
const vector<float> &ExpectData, float alpha) {
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor(shape, DataType::Float32);
|
||||
|
||||
gCpu->dataMalloc();
|
||||
|
||||
input->copyin(inputData);
|
||||
auto npuRuntime = make_ref<ASCENDRuntimeObj>();
|
||||
Graph npuGraph = make_ref<GraphObj>(npuRuntime);
|
||||
// NPU
|
||||
|
||||
auto inputNpu = npuGraph->cloneTensor(input);
|
||||
auto npuOp = npuGraph->addOp<LeakyReluObj>(inputNpu, nullptr, alpha);
|
||||
npuGraph->dataMalloc();
|
||||
inputNpu->copyin(inputData);
|
||||
npuRuntime->run(npuGraph);
|
||||
auto outputNpu = npuOp->getOutput();
|
||||
auto outputNpu2Cpu = outputNpu->clone(cpuRuntime);
|
||||
|
||||
// Check
|
||||
EXPECT_TRUE(outputNpu2Cpu->equalData(ExpectData));
|
||||
}
|
||||
|
||||
TEST(ascend_Unary, run) {
|
||||
aclInit(nullptr);
|
||||
testLeakyRelu(Shape{1, 2, 2, 3},
|
||||
vector<float>{-6, -5, -4, -3, -2, -1, 1, 2, 3, 4, 5, 6},
|
||||
vector<float>{-0.0600, -0.0500, -0.0400, -0.0300, -0.0200,
|
||||
-0.0100, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000,
|
||||
6.0000},
|
||||
0.01);
|
||||
testUnary<ReluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<AbsObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<HardSwishObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<SinObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<GeluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<CosObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<ACosObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<ATanObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
// testUnary<CeilObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
// testUnary<FloorObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
// testUnary<ExpObj>(IncrementalGenerators(), Shape{1, 2, 2, 3});
|
||||
testUnary<NegObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
// testUnary<ReciprocalObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<SqrtObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
// testUnary<RoundObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
aclFinalize();
|
||||
}
|
||||
|
||||
} // namespace infini
|
Loading…
Reference in New Issue