forked from jiuyuan/InfiniTensor
Xpu (#82)
* support kunlun xpu and add an operator named Add * add sub, mul, div, pow, maximum, minimum * add code * add xpu code * add code * add matmul * add transpose * add unary operator * add unary operator * add some operator * add code * support run resnet18 on xpu * add code * add max pool2d * fix xpu code, let it can run. * 添加XPU算子 (#120) * add floordiv for xpu * add batchnorm for xpu * add more cast types for xpu * add conv_trans for xpu * add pad for xpu * add logical ops for xpu * fix format for xpu src and include * fix format for xpu test * fix format for xpu src --------- Co-authored-by: Bolun <bolunz@u.nus.edu> * Xpu abs (#121) * add: unary kernel for xpu * formatting * format * format * format * fix: pointer jump * fix optype comments * fix bug introduced while resolving conflict * change cmake option for kunlunxin xpu from 'xpu' to 'kunlun'; fix bug after merging distributed infrastructure * Add doc support for xpu (#141) * fix * fix * fix pooling test * format * format * fix * fix * set cmake version requirement * fix cmakelists * rename xpu to kunlun * fix * fix format * fix format * fix format * fix change name to kunlun * format * fix format * clang format * fix format --------- Co-authored-by: root <root@localhost.localdomain> Co-authored-by: wanghailu <wanghailu@qiyuanlab.com> Co-authored-by: wanghailu <wanghailu0717@163.com> Co-authored-by: Bolun Zhang <48948016+Chamberlain0w0@users.noreply.github.com> Co-authored-by: Bolun <bolunz@u.nus.edu> Co-authored-by: zhangyue207 <138768300+zhangyue207@users.noreply.github.com> Co-authored-by: Haojie Wang <haojie0429@gmail.com> Co-authored-by: baominghelly <41820386+baominghelly@users.noreply.github.com> Co-authored-by: Bolun <chamberlain0w0@gmail.com>
This commit is contained in:
parent
8e4d88fb9f
commit
1184fa131f
|
@ -1,16 +1,23 @@
|
|||
cmake_minimum_required(VERSION 3.17) # FindCUDAToolkit
|
||||
include(CMakeDependentOption)
|
||||
project(InfiniTensor C CXX)
|
||||
|
||||
# Do not change these options in this file. Use cmake.config, cmake -DOPTION=VALUE, or ccmake to specify them.
|
||||
option(USE_CUDA "Support CUDA GPU" OFF)
|
||||
option(USE_BANG "Support BANG MLU" OFF)
|
||||
option(USE_KUNLUN "Support KUNLUN XPU" OFF)
|
||||
option(USE_INTELCPU "Support INTELCPU" OFF)
|
||||
option(USE_BACKTRACE "Print backtrace on exception and segmentation fault" ON)
|
||||
option(USE_PROTOBUF "Serialize and deserialize tensors" OFF)
|
||||
option(BUILD_DIST "Build project for distributed running" OFF)
|
||||
option(BUILD_TEST "Build tests" OFF)
|
||||
|
||||
if(USE_CUDA)
|
||||
message("CMake 3.18 or higher is required for setting CUDAToolkit")
|
||||
cmake_minimum_required(VERSION 3.18) # FindCUDAToolkit
|
||||
else()
|
||||
cmake_minimum_required(VERSION 3.12)
|
||||
endif()
|
||||
|
||||
include(CMakeDependentOption)
|
||||
project(InfiniTensor C CXX)
|
||||
|
||||
cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF)
|
||||
cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" OFF BUILD_TEST OFF)
|
||||
cmake_dependent_option(BUILD_TEST_EINNET "Build tests for EINNET" OFF BUILD_TEST OFF)
|
||||
|
@ -128,6 +135,11 @@ if(USE_BANG)
|
|||
list (APPEND SRC ${SRC_BANG})
|
||||
endif()
|
||||
|
||||
if(USE_KUNLUN)
|
||||
file(GLOB_RECURSE SRC_KUNLUN src/kunlun/*.cc src/kernels/kunlun/*.cc )
|
||||
list (APPEND SRC ${SRC_KUNLUN})
|
||||
endif()
|
||||
|
||||
if(USE_INTELCPU)
|
||||
file(GLOB_RECURSE SRC_INTELCPU src/intelcpu/*.cc src/kernels/intelcpu/*.cc )
|
||||
list (APPEND SRC ${SRC_INTELCPU})
|
||||
|
@ -243,6 +255,35 @@ if(USE_BANG)
|
|||
target_link_libraries(InfiniTensor ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++)
|
||||
endif()
|
||||
|
||||
if(USE_KUNLUN)
|
||||
add_compile_definitions(USE_KUNLUN=1)
|
||||
if ((NOT DEFINED KUNLUN_HOME) AND (NOT DEFINED ENV{KUNLUN_HOME}))
|
||||
message(FATAL_ERROR "KUNLUN_HOME is not defined from cmake or env")
|
||||
elseif (DEFINED KUNLUN_HOME)
|
||||
set(KUNLUN_HOME ${KUNLUN_HOME} CACHE STRING "KUNLUN_HOME directory for Kunlun development")
|
||||
else()
|
||||
set(KUNLUN_HOME $ENV{KUNLUN_HOME} CACHE STRING "KUNLUN_HOME directory for Kunlun development")
|
||||
endif()
|
||||
message(STATUS "KUNLUN_HOME: ${KUNLUN_HOME}")
|
||||
|
||||
include_directories("${KUNLUN_HOME}/XTDK/include/")
|
||||
find_library(KUNLUN_RT libxpurt.so "${KUNLUN_HOME}/lib64")
|
||||
find_library(KUNLUN_DNN libxpuapi.so "${KUNLUN_HOME}/XTDK/shlib")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror")
|
||||
|
||||
if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH}))
|
||||
execute_process(COMMAND uname -m OUTPUT_VARIABLE _uname_m OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||
set(TARGET_CPU_ARCH "${_uname_m}" CACHE STRING "Target CPU ARCH")
|
||||
elseif(DEFINED TARGET_CPU_ARCH)
|
||||
set(TARGET_CPU_ARCH ${TARGET_CPU_ARCH} CACHE STRING "Target CPU ARCH")
|
||||
else()
|
||||
set(TARGET_CPU_ARCH $ENV{TARGET_CPU_ARCH} CACHE STRING "Target CPU ARCH")
|
||||
endif()
|
||||
message(STATUS "TARGET_CPU_ARCH: ${TARGET_CPU_ARCH}")
|
||||
|
||||
target_link_libraries(InfiniTensor ${KUNLUN_RT} ${KUNLUN_DNN} stdc++)
|
||||
endif()
|
||||
|
||||
# # Python bindings
|
||||
# pybind11_add_module(infini MODULE ${FFI})
|
||||
# target_link_libraries(infini PRIVATE infini_cpp)
|
||||
|
@ -275,6 +316,9 @@ if(BUILD_TEST)
|
|||
if (USE_BANG)
|
||||
build_test(test/kernels/bang/*.cc)
|
||||
endif()
|
||||
if (USE_KUNLUN)
|
||||
build_test(test/kernels/kunlun/*.cc)
|
||||
endif()
|
||||
if (USE_INTELCPU)
|
||||
build_test(test/kernels/intelcpu/*.cc)
|
||||
endif()
|
||||
|
|
2
Makefile
2
Makefile
|
@ -3,6 +3,7 @@
|
|||
TYPE ?= Release
|
||||
CUDA ?= OFF
|
||||
BANG ?= OFF
|
||||
KUNLUN ?= OFF
|
||||
INTELCPU ?= off
|
||||
BACKTRACE ?= ON
|
||||
TEST ?= ON
|
||||
|
@ -25,6 +26,7 @@ endif
|
|||
CMAKE_OPT = -DCMAKE_BUILD_TYPE=$(TYPE)
|
||||
CMAKE_OPT += -DUSE_CUDA=$(CUDA)
|
||||
CMAKE_OPT += -DUSE_BANG=$(BANG)
|
||||
CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN)
|
||||
CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE)
|
||||
CMAKE_OPT += -DBUILD_TEST=$(TEST)
|
||||
|
||||
|
|
|
@ -133,6 +133,13 @@
|
|||
make install-python BANG=ON
|
||||
```
|
||||
|
||||
编译 CPU 部分,同时编译昆仑 XPU 部分:
|
||||
|
||||
```bash
|
||||
export KUNLUN_HOME=/path/to/your/kunlun_home
|
||||
make install-python KUNLUN=ON
|
||||
```
|
||||
|
||||
3. 使用方法
|
||||
|
||||
安装成功后,您就可以使用本项目的 Python 接口进行编码并运行。具体使用方式可以参考项目样例代码 example/Resnet/resnet.py 以及用户使用手册
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
- `TYPE`:编译模式(`debug`/`release`),默认值为 `release`
|
||||
- `CUDA`:是否编译 CUDA 后端,默认为 `OFF`,`ON` 打开
|
||||
- `BANG`:是否编译寒武纪后端,默认为 `OFF`,`ON` 打开
|
||||
- `KUNLUN`:是否编译昆仑后端,默认为 `OFF`,`ON` 打开
|
||||
- `BACKTRACE`:是否启用栈回溯,默认为 `ON`,`OFF` 关闭,建议调试时打开
|
||||
- `TEST`:是否编译 `googletest`,默认为 `ON`,`OFF` 关闭,只有 `test-cpp` 时必要
|
||||
|
||||
|
|
2
env.sh
2
env.sh
|
@ -35,4 +35,4 @@ export LD_LIBRARY_PATH="${NEUWARE_HOME}/lib64:${LD_LIBRARY_PATH}"
|
|||
# ├── tools
|
||||
# ├── version
|
||||
# └── XTDK
|
||||
export XPU_HOME=/usr/local/xpu
|
||||
export KUNLUN_HOME=/usr/local/xpu
|
||||
|
|
|
@ -21,10 +21,10 @@ struct OpType {
|
|||
Add, // Binary
|
||||
And, // Binary
|
||||
ArgMax, //
|
||||
Asin, // Binary
|
||||
Asinh, // Binary
|
||||
Atan, // Binary
|
||||
Atanh, // Binary
|
||||
Asin, // Unary
|
||||
Asinh, // Unary
|
||||
Atan, // Unary
|
||||
Atanh, // Unary
|
||||
AveragePool, // Pool
|
||||
BatchNormalization, //
|
||||
Bernoulli, //
|
||||
|
|
|
@ -30,7 +30,7 @@ using OpLists = list<Operator>;
|
|||
|
||||
using VType = uint32_t;
|
||||
|
||||
enum class Device { CPU = 1, CUDA, BANG, INTELCPU };
|
||||
enum class Device { CPU = 1, CUDA, BANG, INTELCPU, KUNLUN };
|
||||
/***************** Forward declaration end *****************/
|
||||
|
||||
class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
|
||||
|
@ -72,6 +72,7 @@ class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
|
|||
}
|
||||
bool isCuda() const { return device == Device::CUDA; }
|
||||
bool isBang() const { return device == Device::BANG; }
|
||||
bool isKUNLUN() const { return device == Device::KUNLUN; }
|
||||
void copyBlob(const TensorObj *dst, const TensorObj *src) const;
|
||||
// TODO: unify these copy APIs
|
||||
virtual void copyBlobFromCPU(void *dst, const void *src,
|
||||
|
|
|
@ -180,14 +180,15 @@ class TensorObj : public TensorBaseObj {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
bool equalDataImpl(const T *a, const T *b, size_t size) const {
|
||||
bool equalDataImpl(const T *a, const T *b, size_t size,
|
||||
double relativeError = 1e-6) const {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
if constexpr (std::is_integral_v<T>) {
|
||||
if (a[i] != b[i])
|
||||
return false;
|
||||
} else if constexpr (std::is_floating_point_v<T>) {
|
||||
if (fabs(a[i] - b[i]) / std::max(fabs(a[i]), fabs(b[i])) >
|
||||
1e-6) {
|
||||
relativeError) {
|
||||
printf("Error on %lu: %f %f\n", i, a[i], b[i]);
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
#pragma once
|
||||
#include "core/common.h"
|
||||
#include "xpu/runtime_ex.h"
|
||||
#include "xpu/xdnn.h"
|
||||
|
||||
#define checkKUNLUNError(call) \
|
||||
{ \
|
||||
auto err = call; \
|
||||
if (XPU_SUCCESS != err) { \
|
||||
fprintf(stderr, "KUNLUN error in %s:%i : %s.\n", __FILE__, \
|
||||
__LINE__, xpu_strerror(err)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
|
||||
using KUNLUNPtr = void *;
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,24 @@
|
|||
#pragma once
|
||||
#include "core/kernel.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class KUNLUNKernelWithoutConfig : public Kernel {
|
||||
public:
|
||||
virtual void compute(const Operator &op, const PerfRecord &record,
|
||||
const RuntimeObj *context) const {
|
||||
compute(op, context);
|
||||
}
|
||||
virtual void compute(const Operator &op,
|
||||
const RuntimeObj *context) const = 0;
|
||||
// Premise: op is idempotent since it is called multiple times.
|
||||
virtual PerfRecord tune(const Operator &op,
|
||||
const RuntimeObj *_context) const {
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
return make_ref<PerfRecordObj>(timeit([&]() { compute(op, _context); },
|
||||
[&]() { context->sync(); }));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,73 @@
|
|||
#pragma once
|
||||
#include "core/runtime.h"
|
||||
#include "kunlun/kunlun_common.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class KUNLUNRuntimeObj : public RuntimeObj {
|
||||
private:
|
||||
baidu::xpu::api::Context *xdnn;
|
||||
KUNLUNPtr workspace;
|
||||
size_t workspaceSize;
|
||||
|
||||
public:
|
||||
KUNLUNRuntimeObj() : RuntimeObj(Device::KUNLUN) {
|
||||
xdnn = baidu::xpu::api::create_context();
|
||||
// 10GB for Longformer
|
||||
// size_t longformerNum = 3lu * (1 << 30);
|
||||
workspaceSize = 3ll << 30; // 3 GB
|
||||
// std::cout<<workspaceSize/1024/1024/1024<< std::endl;
|
||||
// std::cout<<std::bitset<64>(workspaceSize)<< std::endl;
|
||||
workspace = alloc(workspaceSize);
|
||||
}
|
||||
virtual ~KUNLUNRuntimeObj() {
|
||||
dealloc(workspace);
|
||||
baidu::xpu::api::destroy_context(xdnn);
|
||||
}
|
||||
string toString() const override;
|
||||
|
||||
void run(const Graph &graph, bool tune = false,
|
||||
bool profiling = false) const;
|
||||
// double runEvaluation(const Graph &graph, int nWarmups,
|
||||
// int nEvaluations) const;
|
||||
void sync() const;
|
||||
KUNLUNPtr alloc(size_t size) override {
|
||||
void *ptr;
|
||||
checkKUNLUNError(
|
||||
xpu_malloc_ex((void **)&ptr, size, XPUMemoryKind::XPU_MEM_MAIN));
|
||||
return ptr;
|
||||
}
|
||||
void dealloc(void *ptr) override { xpu_free(ptr); }
|
||||
baidu::xpu::api::Context *KUNLUNHandle() const { return xdnn; }
|
||||
KUNLUNPtr getWorkspace(size_t size) const {
|
||||
IT_ASSERT(size <= workspaceSize);
|
||||
return workspace;
|
||||
}
|
||||
|
||||
void copyBlobFromCPU(void *dst, const void *src,
|
||||
size_t bytes) const override {
|
||||
xpu_memcpy(dst, const_cast<void *>(src), bytes,
|
||||
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
|
||||
}
|
||||
|
||||
void copyBlobToCPU(void *dst, const void *src,
|
||||
size_t bytes) const override {
|
||||
xpu_memcpy(dst, const_cast<void *>(src), bytes,
|
||||
XPUMemcpyKind::XPU_DEVICE_TO_HOST);
|
||||
}
|
||||
|
||||
void copyBlobInsideRuntime(void *dst, const void *src,
|
||||
size_t bytes) const override {
|
||||
xpu_memcpy(dst, const_cast<void *>(src), bytes,
|
||||
XPUMemcpyKind::XPU_DEVICE_TO_DEVICE);
|
||||
}
|
||||
|
||||
void initComm(const string &, int, int) override { IT_TODO_HALT(); }
|
||||
|
||||
CommunicatorObj &getCommunicator() const override { IT_TODO_HALT(); }
|
||||
|
||||
private:
|
||||
void runWithoutSync(const Graph &graph, bool tune, bool profiling) const;
|
||||
};
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,10 @@
|
|||
#pragma once
|
||||
namespace infini {
|
||||
namespace opTimer {
|
||||
double getPerfConvXdnn(int n, int c, int h, int w, int f, int r, int s,
|
||||
int padh, int padw, int strideh, int stridew,
|
||||
int dilationh, int dilationw, int group,
|
||||
const char *name);
|
||||
double getPerfMatmulXdnn(int b, int m, int n, int k, const char *name);
|
||||
} // namespace opTimer
|
||||
} // namespace infini
|
|
@ -35,6 +35,7 @@ class OnnxStub:
|
|||
The Onnx model imported into infinitensor.
|
||||
It can be generated from an Onnx model object.
|
||||
"""
|
||||
|
||||
def __init__(self, model: ModelProto, runtime):
|
||||
# We use some user-defined operators for distributed inference
|
||||
try:
|
||||
|
@ -74,7 +75,6 @@ class OnnxStub:
|
|||
)
|
||||
tensors[output.name].set_output()
|
||||
|
||||
|
||||
node_name = []
|
||||
new_node_name = []
|
||||
for node in model.graph.node:
|
||||
|
@ -244,7 +244,13 @@ class OnnxStub:
|
|||
)
|
||||
(k, d, p, s, ceil_mode) = (
|
||||
attributes[name]
|
||||
for name in ["kernel_shape", "dilations", "pads", "strides", "ceil_mode"]
|
||||
for name in [
|
||||
"kernel_shape",
|
||||
"dilations",
|
||||
"pads",
|
||||
"strides",
|
||||
"ceil_mode",
|
||||
]
|
||||
)
|
||||
if p[0] != p[2] or p[1] != p[3]:
|
||||
adapt = "{}-adapt".format(node.output[0])
|
||||
|
@ -289,7 +295,8 @@ class OnnxStub:
|
|||
},
|
||||
)
|
||||
(k, p, s, ceil_mode) = (
|
||||
attributes[name] for name in ["kernel_shape", "pads", "strides", "ceil_mode"]
|
||||
attributes[name]
|
||||
for name in ["kernel_shape", "pads", "strides", "ceil_mode"]
|
||||
)
|
||||
if p[0] != p[2] or p[1] != p[3]:
|
||||
adapt = "{}-adapt".format(node.output[0])
|
||||
|
@ -714,10 +721,9 @@ class OnnxStub:
|
|||
elif node.op_type == "Constant":
|
||||
output_name = node.output[0]
|
||||
attributes = _parse_attribute(node)
|
||||
tensor = attributes['value']
|
||||
tensor = attributes["value"]
|
||||
dims = [d for d in tensor.dims]
|
||||
tensors[output_name] = self.handler.tensor(
|
||||
dims, tensor.data_type)
|
||||
tensors[output_name] = self.handler.tensor(dims, tensor.data_type)
|
||||
data[output_name] = tensor
|
||||
tensors[output_name].set_weight()
|
||||
else:
|
||||
|
|
|
@ -208,7 +208,7 @@ class TestStringMethods(unittest.TestCase):
|
|||
relu = make_node("Relu", ["x"], ["y"], name="relu")
|
||||
make_and_import_model(make_graph([relu], "relu", [x], [y]))
|
||||
|
||||
'''Gelu operator is not supported by onnx 14.1 currently.'''
|
||||
"""Gelu operator is not supported by onnx 14.1 currently."""
|
||||
def test_gelu(self):
|
||||
pass
|
||||
# x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
|
@ -239,7 +239,7 @@ class TestStringMethods(unittest.TestCase):
|
|||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
tanh = make_node("Tanh", ["x"], ["y"], name="tanh")
|
||||
make_and_import_model(make_graph([tanh], "tanh", [x], [y]))
|
||||
|
||||
|
||||
def test_hard_sigmoid(self):
|
||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
|
@ -263,7 +263,7 @@ class TestStringMethods(unittest.TestCase):
|
|||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
abs = make_node("Abs", ["x"], ["y"], name="abs")
|
||||
make_and_import_model(make_graph([abs], "abs", [x], [y]))
|
||||
|
||||
|
||||
def test_neg(self):
|
||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
|
@ -319,9 +319,15 @@ class TestStringMethods(unittest.TestCase):
|
|||
indices = make_tensor_value_info("indices", TensorProto.INT64, [2, 1, 2])
|
||||
output = make_tensor_value_info("output", TensorProto.FLOAT, [2, 1, 2])
|
||||
gatherElements = make_node(
|
||||
"GatherElements", ["data", "indices"], ["output"], axis=1, name="gatherElements"
|
||||
"GatherElements",
|
||||
["data", "indices"],
|
||||
["output"],
|
||||
axis=1,
|
||||
name="gatherElements",
|
||||
)
|
||||
make_and_import_model(
|
||||
make_graph([gatherElements], "gatherElements", [data, indices], [output])
|
||||
)
|
||||
make_and_import_model(make_graph([gatherElements], "gatherElements", [data, indices], [output]))
|
||||
|
||||
def test_reduce_mean(self):
|
||||
data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 3, 4])
|
||||
|
|
|
@ -11,7 +11,7 @@ proj_path = Path(sys.path[0]).parent
|
|||
def format_file(file):
|
||||
file = Path(proj_path.joinpath(file))
|
||||
if file.suffix in c_style_file:
|
||||
run(f"clang-format-14 -i {file}", cwd=proj_path, shell=True)
|
||||
run(f"clang-format-14 -style=file -i {file}", cwd=proj_path, shell=True)
|
||||
run(f"git add {file}", cwd=proj_path, shell=True)
|
||||
elif file.suffix == py_file:
|
||||
run(f"black {file}", cwd=proj_path, shell=True)
|
||||
|
|
|
@ -100,7 +100,8 @@ bool TensorObj::equalData(const Tensor &rhs, double relativeError) const {
|
|||
#define TEST_EQUAL(N) \
|
||||
if (dtype == DataType(N)) \
|
||||
return equalDataImpl(getRawDataPtr<DT<N>::t *>(), \
|
||||
rhs->getRawDataPtr<DT<N>::t *>(), size());
|
||||
rhs->getRawDataPtr<DT<N>::t *>(), size(), \
|
||||
relativeError);
|
||||
|
||||
TEST_EQUAL(0) // fmt: new line
|
||||
else TEST_EQUAL(1) //
|
||||
|
|
|
@ -24,6 +24,9 @@
|
|||
#ifdef USE_BANG
|
||||
#include "bang/bang_runtime.h"
|
||||
#endif
|
||||
#ifdef USE_KUNLUN
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
#endif
|
||||
#ifdef USE_INTELCPU
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "intelcpu/operator_timer.h"
|
||||
|
@ -158,6 +161,12 @@ static int tensor_dtype(Tensor t) {
|
|||
static Ref<BangRuntimeObj> bang_runtime() { return make_ref<BangRuntimeObj>(); }
|
||||
#endif
|
||||
|
||||
#ifdef USE_KUNLUN
|
||||
static Ref<KUNLUNRuntimeObj> kunlun_runtime() {
|
||||
return make_ref<KUNLUNRuntimeObj>();
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_INTELCPU
|
||||
static Ref<RuntimeObj> intelcpu_runtime() { return make_ref<MklRuntimeObj>(); }
|
||||
#endif
|
||||
|
@ -292,6 +301,10 @@ void export_functions(py::module &m) {
|
|||
#ifdef USE_BANG
|
||||
.FUNCTION(bang_runtime)
|
||||
#endif
|
||||
|
||||
#ifdef USE_KUNLUN
|
||||
.FUNCTION(kunlun_runtime)
|
||||
#endif
|
||||
.FUNCTION(conv_attrs_of)
|
||||
.FUNCTION(conv_trans_attrs_of)
|
||||
.FUNCTION(matmul_attrs_of)
|
||||
|
@ -365,6 +378,10 @@ void init_graph_builder(py::module &m) {
|
|||
#ifdef USE_BANG
|
||||
py::class_<BangRuntimeObj, std::shared_ptr<BangRuntimeObj>, RuntimeObj>(
|
||||
m, "BangRuntime");
|
||||
#endif
|
||||
#ifdef USE_KUNLUN
|
||||
py::class_<KUNLUNRuntimeObj, std::shared_ptr<KUNLUNRuntimeObj>, RuntimeObj>(
|
||||
m, "KUNLUNRuntime");
|
||||
#endif
|
||||
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor",
|
||||
py::buffer_protocol())
|
||||
|
|
|
@ -58,6 +58,21 @@ template <typename T> class NaiveMul : public NativeElementWise<T> {
|
|||
template <typename T> class NaiveDiv : public NativeElementWise<T> {
|
||||
T doCompute(T val0, T val1) const override { return (T)(val0 / val1); }
|
||||
};
|
||||
template <typename T> class NaiveEqual : public NativeElementWise<T> {
|
||||
T doCompute(T val0, T val1) const override { return (T)(val0 == val1); }
|
||||
};
|
||||
template <typename T> class NaiveGreaterEqual : public NativeElementWise<T> {
|
||||
T doCompute(T val0, T val1) const override { return (T)(val0 >= val1); }
|
||||
};
|
||||
template <typename T> class NaiveGreaterThan : public NativeElementWise<T> {
|
||||
T doCompute(T val0, T val1) const override { return (T)(val0 > val1); }
|
||||
};
|
||||
template <typename T> class NaiveLessEqual : public NativeElementWise<T> {
|
||||
T doCompute(T val0, T val1) const override { return (T)(val0 <= val1); }
|
||||
};
|
||||
template <typename T> class NaiveLessThan : public NativeElementWise<T> {
|
||||
T doCompute(T val0, T val1) const override { return (T)(val0 < val1); }
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Add, DataType::UInt32, NaiveAdd<uint32_t>,
|
||||
"addNaive_CPU_uint32");
|
||||
|
@ -75,4 +90,24 @@ REGISTER_KERNEL(Device::CPU, OpType::Div, DataType::UInt32, NaiveDiv<uint32_t>,
|
|||
"divNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Div, DataType::Float32, NaiveDiv<float>,
|
||||
"divNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Equal, DataType::UInt32,
|
||||
NaiveEqual<uint32_t>, "equalNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Equal, DataType::Float32,
|
||||
NaiveEqual<float>, "equalNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::GreaterOrEqual, DataType::UInt32,
|
||||
NaiveGreaterEqual<uint32_t>, "greaterEqualNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::GreaterOrEqual, DataType::Float32,
|
||||
NaiveGreaterEqual<float>, "greaterEqualNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Greater, DataType::UInt32,
|
||||
NaiveGreaterThan<uint32_t>, "greaterThanNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Greater, DataType::Float32,
|
||||
NaiveGreaterThan<float>, "greaterThanNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::LessOrEqual, DataType::UInt32,
|
||||
NaiveLessEqual<uint32_t>, "lessEqualNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::LessOrEqual, DataType::Float32,
|
||||
NaiveLessEqual<float>, "lessEqualNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Less, DataType::UInt32,
|
||||
NaiveLessThan<uint32_t>, "lessEqualNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Less, DataType::Float32,
|
||||
NaiveLessThan<float>, "lessEqualNaive_CPU_float32");
|
||||
}; // namespace infini
|
||||
|
|
|
@ -71,6 +71,26 @@ template <typename T> class NaiveSqrt : public NativeUnary<T> {
|
|||
T doCompute(T val) const override { return std::sqrt(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveCos : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::cos(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveSin : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::sin(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveTan : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::tan(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveSinh : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::sinh(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveCosh : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::cosh(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveGelu : public NativeUnary<T> {
|
||||
T doCompute(T val) const override {
|
||||
return 0.5 * val * (1 + std::erf(val / std::sqrt(2)));
|
||||
|
@ -81,6 +101,26 @@ template <typename T> class NaiveErf : public NativeUnary<T> {
|
|||
T doCompute(T val) const override { return std::erf(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveACos : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::acos(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveACosh : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::acosh(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveASin : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::asin(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveASinh : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::asinh(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveATanh : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::atanh(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveNeg : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return -val; }
|
||||
};
|
||||
|
@ -104,6 +144,43 @@ template <typename T> class Clip : public CpuKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
template <typename T> class Log : public CpuKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
auto op = as<LogObj>(_op);
|
||||
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
|
||||
T *outptr = op->getOutput()->getRawDataPtr<T *>();
|
||||
auto logType = op->getType(); // get log type
|
||||
|
||||
auto len = op->getOutput()->size();
|
||||
for (size_t offset = 0; offset < len; offset++) {
|
||||
T res;
|
||||
auto val = *inptr++;
|
||||
switch (logType) {
|
||||
case LogObj::LogE:
|
||||
res = std::log(val);
|
||||
*outptr++ = res;
|
||||
break;
|
||||
case LogObj::Log2:
|
||||
res = std::log2(val);
|
||||
*outptr++ = res;
|
||||
break;
|
||||
case LogObj::Log10:
|
||||
res = std::log10(val);
|
||||
*outptr++ = res;
|
||||
break;
|
||||
default:
|
||||
printf("LogType not Defined");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> class NaiveATan : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::atan(val); }
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::UInt32,
|
||||
NaiveRelu<uint32_t>, "reluNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::Float32, NaiveRelu<float>,
|
||||
|
@ -140,4 +217,28 @@ REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::Float32,
|
|||
NaiveSoftmax<float>, "softmaxNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Clip, DataType::Float32, Clip<float>,
|
||||
"Clip_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Atan, DataType::Float32, NaiveATan<float>,
|
||||
"Atan_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Log, DataType::Float32, Log<float>,
|
||||
"Log_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Cos, DataType::Float32, NaiveCos<float>,
|
||||
"Cos_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Sin, DataType::Float32, NaiveSin<float>,
|
||||
"Sin_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Tan, DataType::Float32, NaiveTan<float>,
|
||||
"Tan_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Sinh, DataType::Float32, NaiveSinh<float>,
|
||||
"Sinh_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Cosh, DataType::Float32, NaiveCosh<float>,
|
||||
"Cosh_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Acos, DataType::Float32, NaiveACos<float>,
|
||||
"ACos_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Acosh, DataType::Float32,
|
||||
NaiveACosh<float>, "ACosh_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Asin, DataType::Float32, NaiveASin<float>,
|
||||
"ASin_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Asinh, DataType::Float32,
|
||||
NaiveASinh<float>, "ASinh_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Atanh, DataType::Float32,
|
||||
NaiveATanh<float>, "ATanh_CPU_float32");
|
||||
}; // namespace infini
|
||||
|
|
|
@ -59,7 +59,8 @@ void gather_elements_kernel(void *in, void *out, GatherMetaData metaData,
|
|||
reinterpret_cast<int *>(in), reinterpret_cast<int *>(out), metaData,
|
||||
num);
|
||||
} else {
|
||||
IT_TODO_HALT_MSG("GatherElements Cuda Kernel: Unsupported data type.\n");
|
||||
IT_TODO_HALT_MSG(
|
||||
"GatherElements Cuda Kernel: Unsupported data type.\n");
|
||||
}
|
||||
}
|
||||
} // namespace infini
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
#include "operators/batch_norm.h"
|
||||
#include "kunlun/kunlun_kernel_without_config.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class BatchNormXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<BatchNormObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const input = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const mean = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const var = (op->getInputs(2)->getRawDataPtr<void *>());
|
||||
void *const scale = (op->getInputs(3)->getRawDataPtr<void *>());
|
||||
void *const bias = (op->getInputs(4)->getRawDataPtr<void *>());
|
||||
void *const output = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto dims = op->getInputs(0)->getDims();
|
||||
|
||||
if (dims.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
|
||||
int w = dims[3];
|
||||
int h = dims[2];
|
||||
int c = dims[1];
|
||||
int n = dims[0];
|
||||
auto ret = baidu::xpu::api::batch_norm_infer<float>(
|
||||
context->KUNLUNHandle(), (float *)input, (float *)output, n, c, h,
|
||||
w, op->getEps(), (float *)scale, (float *)bias, (float *)mean,
|
||||
(float *)var, true);
|
||||
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::BatchNormalization, DataType::Float32,
|
||||
BatchNormXdnn, "BatchNorm_xdnn_KUNLUN_Float32");
|
||||
|
||||
}; // namespace infini
|
|
@ -0,0 +1,98 @@
|
|||
#include "kunlun/kunlun_kernel_without_config.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
#include "operators/unary.h"
|
||||
|
||||
namespace infini {
|
||||
class CastXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<CastObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
CastType type = op->getType();
|
||||
|
||||
int ret = 0;
|
||||
switch (type) {
|
||||
case CastType::Float2Float16:
|
||||
ret = baidu::xpu::api::cast<float, float16>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float16 *)cData, len);
|
||||
break;
|
||||
case CastType::Float2Int64:
|
||||
ret = baidu::xpu::api::cast<float, int64_t>(
|
||||
context->KUNLUNHandle(), (float *)aData, (int64_t *)cData, len);
|
||||
break;
|
||||
case CastType::Float2Int32:
|
||||
ret = baidu::xpu::api::cast<float, int>(
|
||||
context->KUNLUNHandle(), (float *)aData, (int *)cData, len);
|
||||
break;
|
||||
case CastType::Float2Int16:
|
||||
ret = baidu::xpu::api::cast<float, int16_t>(
|
||||
context->KUNLUNHandle(), (float *)aData, (int16_t *)cData, len);
|
||||
break;
|
||||
case CastType::Float2Int8:
|
||||
ret = baidu::xpu::api::cast<float, int8_t>(
|
||||
context->KUNLUNHandle(), (float *)aData, (int8_t *)cData, len);
|
||||
break;
|
||||
case CastType::Int322Float:
|
||||
ret = baidu::xpu::api::cast<int, float>(
|
||||
context->KUNLUNHandle(), (int *)aData, (float *)cData, len);
|
||||
break;
|
||||
case CastType::Int322Int8:
|
||||
ret = baidu::xpu::api::cast<int, int8_t>(
|
||||
context->KUNLUNHandle(), (int *)aData, (int8_t *)cData, len);
|
||||
break;
|
||||
case CastType::Int322Int16:
|
||||
ret = baidu::xpu::api::cast<int, int16_t>(
|
||||
context->KUNLUNHandle(), (int *)aData, (int16_t *)cData, len);
|
||||
break;
|
||||
case CastType::Int162Float:
|
||||
ret = baidu::xpu::api::cast<int16_t, float>(
|
||||
context->KUNLUNHandle(), (int16_t *)aData, (float *)cData, len);
|
||||
break;
|
||||
case CastType::Int162Int32:
|
||||
ret = baidu::xpu::api::cast<int16_t, int>(
|
||||
context->KUNLUNHandle(), (int16_t *)aData, (int *)cData, len);
|
||||
break;
|
||||
case CastType::Int82Float:
|
||||
ret = baidu::xpu::api::cast<int8_t, float>(
|
||||
context->KUNLUNHandle(), (int8_t *)aData, (float *)cData, len);
|
||||
break;
|
||||
case CastType::Int82Int16:
|
||||
ret = baidu::xpu::api::cast<int8_t, int16_t>(
|
||||
context->KUNLUNHandle(), (int8_t *)aData, (int16_t *)cData,
|
||||
len);
|
||||
break;
|
||||
case CastType::Int82Int32:
|
||||
ret = baidu::xpu::api::cast<int8_t, int>(
|
||||
context->KUNLUNHandle(), (int8_t *)aData, (int *)cData, len);
|
||||
break;
|
||||
case CastType::Int322Int64:
|
||||
ret = baidu::xpu::api::cast<int, int64_t>(
|
||||
context->KUNLUNHandle(), (int *)aData, (int64_t *)cData, len);
|
||||
break;
|
||||
case CastType::Int642Int32:
|
||||
ret = baidu::xpu::api::cast<int64_t, int>(
|
||||
context->KUNLUNHandle(), (int64_t *)aData, (int *)cData, len);
|
||||
break;
|
||||
case CastType::Int642Float:
|
||||
ret = baidu::xpu::api::cast<int64_t, float>(
|
||||
context->KUNLUNHandle(), (int64_t *)aData, (float *)cData, len);
|
||||
break;
|
||||
case CastType::Float162Float:
|
||||
ret = baidu::xpu::api::cast<float16, float>(
|
||||
context->KUNLUNHandle(), (float16 *)aData, (float *)cData, len);
|
||||
break;
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Cast, DataType::Float32, CastXdnn,
|
||||
"Cast_xdnn_KUNLUN_Float32");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,37 @@
|
|||
#include "operators/concat.h"
|
||||
#include "kunlun/kunlun_kernel_without_config.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class ConcatXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ConcatObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
int axis = op->getDim();
|
||||
int num = op->numInputs();
|
||||
std::vector<const float *> inputsData;
|
||||
for (int i = 0; i < num; ++i) {
|
||||
inputsData.push_back(
|
||||
(float *)(op->getInputs(i)->getRawDataPtr<void *>()));
|
||||
}
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
std::vector<std::vector<int>> dims;
|
||||
for (int i = 0; i < num; ++i) {
|
||||
auto dim = op->getInputs(i)->getDims();
|
||||
if (dim.size() != 4) {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
dims.push_back(dim);
|
||||
}
|
||||
auto ret = baidu::xpu::api::concat<float>(
|
||||
context->KUNLUNHandle(), inputsData, (float *)cData, dims, axis);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Concat, DataType::Float32, ConcatXdnn,
|
||||
"Concat_xdnn_KUNLUN_Float32");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,37 @@
|
|||
#include "operators/conv.h"
|
||||
#include "kunlun/kunlun_kernel_without_config.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class ConvXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ConvObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||
const auto [n, c, h, w, f, r, s] = op->getNCHWFRS();
|
||||
const int cpg = op->getChannelPerGroup();
|
||||
const int g = c / cpg;
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
std::vector<int> pads = {ph, pw};
|
||||
std::vector<int> ksize = {r, s};
|
||||
std::vector<int> stride = {sh, sw};
|
||||
std::vector<int> dilation = {dh, dw};
|
||||
|
||||
auto ret = baidu::xpu::api::conv2d<float, float, float, float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(float *)cData, n, c, h, w, f, ksize, stride, pads, dilation, g,
|
||||
nullptr, nullptr, nullptr, true);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Conv, DataType::Float32, ConvXdnn,
|
||||
"Conv_xdnn_KUNLUN_Float32");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,54 @@
|
|||
#include "kunlun/kunlun_kernel_without_config.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
#include "operators/conv.h"
|
||||
|
||||
namespace infini {
|
||||
class ConvTransXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ConvBaseObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||
const auto [n, c, h, w, f, r, s] = op->getNCHWFRS();
|
||||
const int cpg = op->getChannelPerGroup();
|
||||
const int g = c / cpg;
|
||||
const bool isNCHW =
|
||||
(op->getOpType() == OpType::ConvTransNHWC) ? false : true;
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
std::vector<int> pads = {ph, pw};
|
||||
std::vector<int> ksize = {r, s};
|
||||
std::vector<int> stride = {sh, sw};
|
||||
std::vector<int> dilation = {dh, dw};
|
||||
|
||||
auto dimInputs0 = op->getInputs(0)->getDims();
|
||||
auto dimInputs1 = op->getInputs(1)->getDims();
|
||||
auto dimOutput = op->getOutput()->getDims();
|
||||
|
||||
if (dimInputs0.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
if (dimInputs1.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
if (dimOutput.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
|
||||
auto ret =
|
||||
baidu::xpu::api::conv2d_transpose<float, float, float, float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(float *)cData, n, c, h, w, f, ksize, stride, pads, dilation, g,
|
||||
nullptr, nullptr, nullptr, isNCHW);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::ConvTranspose, DataType::Float32,
|
||||
ConvTransXdnn, "ConvTrans_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::ConvTransNHWC, DataType::Float32,
|
||||
ConvTransXdnn, "ConvTranposedNHWC_xdnn_KUNLUN_Float32");
|
||||
|
||||
}; // namespace infini
|
|
@ -0,0 +1,476 @@
|
|||
#include "operators/element_wise.h"
|
||||
#include "kunlun/kunlun_kernel_without_config.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class AddXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_add<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(float *)cData, aDim, bDim);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class SubXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_sub<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(float *)cData, aDim, bDim);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class MulXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_mul<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(float *)cData, aDim, bDim);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class DivXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_div<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(float *)cData, aDim, bDim);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class PowXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_pow<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(float *)cData, aDim, bDim);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class MaxXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_max<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(float *)cData, aDim, bDim);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class MinXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_min<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(float *)cData, aDim, bDim);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class EqualXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
size_t len = op->getOutput()->size();
|
||||
KUNLUNPtr wsData = context->getWorkspace(len);
|
||||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_equal<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(bool *)wsData, aDim, bDim);
|
||||
ret = baidu::xpu::api::cast<bool, float>(
|
||||
context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class GreaterEqualXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
size_t len = op->getOutput()->size();
|
||||
KUNLUNPtr wsData = context->getWorkspace(len);
|
||||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_greater_equal<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(bool *)wsData, aDim, bDim);
|
||||
ret = baidu::xpu::api::cast<bool, float>(
|
||||
context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class GreaterThanXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
size_t len = op->getOutput()->size();
|
||||
KUNLUNPtr wsData = context->getWorkspace(len);
|
||||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_greater_than<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(bool *)wsData, aDim, bDim);
|
||||
ret = baidu::xpu::api::cast<bool, float>(
|
||||
context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class LessEqualXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
size_t len = op->getOutput()->size();
|
||||
KUNLUNPtr wsData = context->getWorkspace(len);
|
||||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_less_equal<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(bool *)wsData, aDim, bDim);
|
||||
ret = baidu::xpu::api::cast<bool, float>(
|
||||
context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class LessThanXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
size_t len = op->getOutput()->size();
|
||||
KUNLUNPtr wsData = context->getWorkspace(len);
|
||||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_less_than<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(bool *)wsData, aDim, bDim);
|
||||
ret = baidu::xpu::api::cast<bool, float>(
|
||||
context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class FloorDivXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
size_t len = op->getOutput()->size();
|
||||
KUNLUNPtr wsData = context->getWorkspace(len);
|
||||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_floordiv<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(float *)wsData, aDim, bDim);
|
||||
ret = baidu::xpu::api::cast<int, float>(
|
||||
context->KUNLUNHandle(), (int *)wsData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class MSELossXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<MSELossObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
size_t len = op->getOutput()->size();
|
||||
|
||||
auto dim = op->getInputs(0)->getDims();
|
||||
if (dim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
|
||||
auto ret = baidu::xpu::api::mse_loss<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class AndXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
size_t len = op->getOutput()->size();
|
||||
KUNLUNPtr wsData = context->getWorkspace(len);
|
||||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::logical_and<bool>(
|
||||
context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
|
||||
(bool *)wsData, len);
|
||||
ret = baidu::xpu::api::cast<bool, float>(
|
||||
context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class OrXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
size_t len = op->getOutput()->size();
|
||||
KUNLUNPtr wsData = context->getWorkspace(len);
|
||||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::logical_or<bool>(
|
||||
context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
|
||||
(bool *)wsData, len);
|
||||
ret = baidu::xpu::api::cast<bool, float>(
|
||||
context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class XorXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
size_t len = op->getOutput()->size();
|
||||
KUNLUNPtr wsData = context->getWorkspace(len);
|
||||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::logical_xor<bool>(
|
||||
context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
|
||||
(bool *)wsData, len);
|
||||
ret = baidu::xpu::api::cast<bool, float>(
|
||||
context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class NotXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
size_t len = op->getOutput()->size();
|
||||
KUNLUNPtr wsData = context->getWorkspace(len);
|
||||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
if (aDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::logical_not<bool>(
|
||||
context->KUNLUNHandle(), (bool *)aData, (bool *)wsData, len);
|
||||
ret = baidu::xpu::api::cast<bool, float>(
|
||||
context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Add, DataType::Float32, AddXdnn,
|
||||
"Add_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Sub, DataType::Float32, SubXdnn,
|
||||
"Sub_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Mul, DataType::Float32, MulXdnn,
|
||||
"Mul_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Div, DataType::Float32, DivXdnn,
|
||||
"Div_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Pow, DataType::Float32, PowXdnn,
|
||||
"Pow_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Max, DataType::Float32, MaxXdnn,
|
||||
"Max_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Min, DataType::Float32, MinXdnn,
|
||||
"Min_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Equal, DataType::Float32, EqualXdnn,
|
||||
"Equal_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::GreaterOrEqual, DataType::Float32,
|
||||
GreaterEqualXdnn, "GreaterEqual_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Greater, DataType::Float32,
|
||||
GreaterThanXdnn, "GreaterThan_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::LessOrEqual, DataType::Float32,
|
||||
LessEqualXdnn, "LessEqual_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Less, DataType::Float32, LessThanXdnn,
|
||||
"LessThan_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::FloorDiv, DataType::Float32,
|
||||
FloorDivXdnn, "FloorDiv_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::MSELoss, DataType::Float32, MSELossXdnn,
|
||||
"MSELoss_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::And, DataType::Float32, AndXdnn,
|
||||
"And_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Or, DataType::Float32, OrXdnn,
|
||||
"Or_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Xor, DataType::Float32, XorXdnn,
|
||||
"Xor_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Not, DataType::Float32, NotXdnn,
|
||||
"Not_xdnn_KUNLUN_Float32");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,38 @@
|
|||
#include "operators/matmul.h"
|
||||
#include "kunlun/kunlun_kernel_without_config.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class MatmulXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<MatmulObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
bool transA = op->getTransA();
|
||||
bool transB = op->getTransB();
|
||||
if (op->getInputs(0)->getDims().size() != 2 ||
|
||||
op->getInputs(1)->getDims().size() != 2) {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
|
||||
auto m = transA ? op->getInputs(0)->getDims()[1]
|
||||
: op->getInputs(0)->getDims()[0];
|
||||
auto n = transB ? op->getInputs(1)->getDims()[0]
|
||||
: op->getInputs(1)->getDims()[1];
|
||||
auto k = transA ? op->getInputs(0)->getDims()[0]
|
||||
: op->getInputs(0)->getDims()[1];
|
||||
|
||||
auto ret = baidu::xpu::api::fc<float, float, float, int>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(float *)cData, m, n, k, transA, transB, nullptr, nullptr, nullptr);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::MatMul, DataType::Float32, MatmulXdnn,
|
||||
"Matmul_xdnn_KUNLUN_Float32");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,37 @@
|
|||
#include "operators/pad.h"
|
||||
#include "kunlun/kunlun_kernel_without_config.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class PadXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<PadObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto dim = op->getInputs(0)->getDims();
|
||||
int dim_size = dim.size();
|
||||
|
||||
std::vector<int> pads = op->getPads();
|
||||
|
||||
std::cout << std::endl;
|
||||
std::vector<int> paddings_left(pads.begin(), pads.begin() + dim_size);
|
||||
std::vector<int> paddings_right(pads.begin() + dim_size, pads.end());
|
||||
|
||||
float paddingValue = 0.0;
|
||||
auto ret = baidu::xpu::api::pad<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, dim,
|
||||
paddings_left, paddings_right, paddingValue);
|
||||
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Pad, DataType::Float32, PadXdnn,
|
||||
"Pad_xdnn_KUNLUN_Float32");
|
||||
|
||||
}; // namespace infini
|
|
@ -0,0 +1,62 @@
|
|||
#include "operators/pooling.h"
|
||||
#include "kunlun/kunlun_kernel_without_config.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class AvgPooling : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<PoolingObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto [n, c, h, w, kh, kw] = op->getNCHWRS();
|
||||
auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||
|
||||
std::vector<int> ksize = {kh, kw};
|
||||
std::vector<int> stride = {sh, sw};
|
||||
std::vector<int> pad = {ph, pw};
|
||||
|
||||
auto ret = baidu::xpu::api::avg_pool2d<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, n, c, h, w,
|
||||
ksize, stride, pad, true, true, nullptr, nullptr);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class MaxPooling : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<PoolingObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto [n, c, h, w, kh, kw] = op->getNCHWRS();
|
||||
auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||
|
||||
std::vector<int> ksize = {kh, kw};
|
||||
std::vector<int> stride = {sh, sw};
|
||||
std::vector<int> pad = {ph, pw};
|
||||
|
||||
int yh = (h + ph * 2 - kh) / sh + 1;
|
||||
int yw = (w + pw * 2 - kw) / sw + 1;
|
||||
|
||||
KUNLUNPtr indices = context->getWorkspace(yh * yw * 4);
|
||||
|
||||
auto ret = baidu::xpu::api::max_pool2d<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData,
|
||||
(int *)indices, n, c, h, w, ksize, stride, pad, true, nullptr,
|
||||
nullptr, false);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::MaxPool, DataType::Float32, MaxPooling,
|
||||
"MaxPool_xdnn_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::AveragePool, DataType::Float32,
|
||||
AvgPooling, "AvgPool_xdnn_Float32");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,41 @@
|
|||
#include "operators/split.h"
|
||||
#include "kunlun/kunlun_kernel_without_config.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class SplitXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<SplitObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
int axis = op->getDim();
|
||||
int num = op->numOutputs();
|
||||
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
auto inputDim = op->getInputs(0)->getDims();
|
||||
|
||||
std::vector<float *> outputsData;
|
||||
for (int i = 0; i < num; ++i) {
|
||||
outputsData.push_back(
|
||||
(float *)(op->getOutput(i)->getRawDataPtr<void *>()));
|
||||
}
|
||||
|
||||
std::vector<int> splitList;
|
||||
for (int i = 0; i < num; ++i) {
|
||||
auto dim = op->getOutput(i)->getDims();
|
||||
if (dim.size() != 4) {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
splitList.push_back(dim[axis]);
|
||||
}
|
||||
|
||||
auto ret = baidu::xpu::api::split<float>(
|
||||
context->KUNLUNHandle(), (float *)inputData, outputsData, inputDim,
|
||||
splitList, axis);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Split, DataType::Float32, SplitXdnn,
|
||||
"Split_xdnn_KUNLUN_Float32");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,32 @@
|
|||
#include "operators/transpose.h"
|
||||
#include "kunlun/kunlun_kernel_without_config.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class TransposeXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<TransposeObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto dimin = op->getInputs(0)->getDims();
|
||||
auto permute = op->getPermute();
|
||||
|
||||
if (dimin.size() != 4) {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
|
||||
auto ret = baidu::xpu::api::transpose<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, dimin,
|
||||
permute);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Transpose, DataType::Float32,
|
||||
TransposeXdnn, "Transpose_xdnn_KUNLUN_Float32");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,550 @@
|
|||
#include "operators/unary.h"
|
||||
#include "kunlun/kunlun_kernel_without_config.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class ReluXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
|
||||
auto ret = baidu::xpu::api::relu<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class SigmoidXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
|
||||
auto ret = baidu::xpu::api::sigmoid<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class TanhXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
|
||||
auto ret = baidu::xpu::api::tanh<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class SquareXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
|
||||
auto ret = baidu::xpu::api::square<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class SqrtXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
|
||||
auto ret = baidu::xpu::api::sqrt<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class RsqrtXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
|
||||
auto ret = baidu::xpu::api::rsqrt<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class ExpXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
|
||||
auto ret = baidu::xpu::api::exp<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class CeilXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
|
||||
auto ret = baidu::xpu::api::ceil<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class ClipXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ClipObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
float min = op->getMin().value();
|
||||
float max = op->getMax().value();
|
||||
|
||||
auto ret = baidu::xpu::api::clip<float>(context->KUNLUNHandle(),
|
||||
(float *)aData, (float *)cData,
|
||||
len, min, max);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class FloorXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
|
||||
auto ret = baidu::xpu::api::floor<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class NegXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
|
||||
auto ret = baidu::xpu::api::neg<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class CopyXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
|
||||
auto ret = baidu::xpu::api::copy<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class ReciprocalXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
|
||||
auto ret = baidu::xpu::api::reciprocal<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class AbsXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
|
||||
auto ret = baidu::xpu::api::abs<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class ATanXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
|
||||
auto ret = baidu::xpu::api::arctan<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class LogXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<LogObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
std::vector<int> divDim = {
|
||||
1,
|
||||
};
|
||||
auto len = op->getInputs(0)->size();
|
||||
// get ptr of tempspace
|
||||
KUNLUNPtr temp = context->getWorkspace(len * sizeof(float));
|
||||
LogObj::LogType type = op->getType();
|
||||
// get output of xpu::api::loge(x)
|
||||
auto ret = baidu::xpu::api::log<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)temp, len);
|
||||
// get ptr of divider
|
||||
KUNLUNPtr dd =
|
||||
(float *)(context->getWorkspace((1 + len) * sizeof(float))) + len;
|
||||
// choose from logE, log2, log10
|
||||
switch (type) {
|
||||
float constant;
|
||||
case LogObj::LogE:
|
||||
// if use loge, copy from temp to cData
|
||||
ret = baidu::xpu::api::copy<float>(
|
||||
context->KUNLUNHandle(), (float *)temp, (float *)cData, len);
|
||||
break;
|
||||
case LogObj::Log2:
|
||||
constant = std::log(2);
|
||||
context->copyBlobFromCPU(dd, &constant, sizeof(float));
|
||||
ret = baidu::xpu::api::broadcast_div<float>(
|
||||
context->KUNLUNHandle(), (float *)temp, (float *)dd,
|
||||
(float *)cData, aDim, divDim);
|
||||
break;
|
||||
case LogObj::Log10:
|
||||
constant = std::log(10);
|
||||
context->copyBlobFromCPU(dd, &constant, sizeof(float));
|
||||
ret = baidu::xpu::api::broadcast_div<float>(
|
||||
context->KUNLUNHandle(), (float *)temp, (float *)dd,
|
||||
(float *)cData, aDim, divDim);
|
||||
break;
|
||||
default:
|
||||
printf("LogType not support!");
|
||||
break;
|
||||
}
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class CosXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<CosObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
auto ret = baidu::xpu::api::cos<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class SinXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<SinObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
auto ret = baidu::xpu::api::sin<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class TanXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<TanObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
auto ret = baidu::xpu::api::tan<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class SinhXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<SinHObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
auto ret = baidu::xpu::api::sinh<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class CoshXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<CosHObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
auto ret = baidu::xpu::api::cosh<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class ErfXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ErfObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
auto ret = baidu::xpu::api::erf<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class ACosXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ACosObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
auto ret = baidu::xpu::api::arccos<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class ACoshXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ACosHObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
auto ret = baidu::xpu::api::acosh<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class ASinXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ASinObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
auto ret = baidu::xpu::api::arcsin<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class ASinhXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ASinHObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
auto ret = baidu::xpu::api::asinh<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class ATanhXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ATanHObj>(_op);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
auto ret = baidu::xpu::api::atanh<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
|
||||
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Relu, DataType::Float32, ReluXdnn,
|
||||
"Relu_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Sigmoid, DataType::Float32, SigmoidXdnn,
|
||||
"Sigmoid_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Tanh, DataType::Float32, TanhXdnn,
|
||||
"Tanh_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Square, DataType::Float32, SquareXdnn,
|
||||
"Square_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Sqrt, DataType::Float32, SqrtXdnn,
|
||||
"Sqrt_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Rsqrt, DataType::Float32, RsqrtXdnn,
|
||||
"Rsqrt_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Exp, DataType::Float32, ExpXdnn,
|
||||
"Exp_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Ceil, DataType::Float32, CeilXdnn,
|
||||
"Ceil_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Clip, DataType::Float32, ClipXdnn,
|
||||
"Clip_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Floor, DataType::Float32, FloorXdnn,
|
||||
"Floor_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Neg, DataType::Float32, NegXdnn,
|
||||
"Neg_xdnn_KUNLUN_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Reciprocal, DataType::Float32,
|
||||
ReciprocalXdnn, "Reciprocal_xdnn_KUNLUN_Float32");
|
||||
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Reshape, DataType::Float32, CopyXdnn,
|
||||
"Reshape_xdnn_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Flatten, DataType::Float32, CopyXdnn,
|
||||
"Flatten_xdnn_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Identity, DataType::Float32, CopyXdnn,
|
||||
"Identity_xdnn_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Abs, DataType::Float32, AbsXdnn,
|
||||
"Abs_xdnn_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Atan, DataType::Float32, ATanXdnn,
|
||||
"Atan_xdnn_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Log, DataType::Float32, LogXdnn,
|
||||
"Log_xdnn_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Cos, DataType::Float32, CosXdnn,
|
||||
"Cos_xdnn_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Sin, DataType::Float32, SinXdnn,
|
||||
"Sin_xdnn_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Tan, DataType::Float32, TanXdnn,
|
||||
"Tan_xdnn_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Sinh, DataType::Float32, SinhXdnn,
|
||||
"Sinh_xdnn_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Cosh, DataType::Float32, CoshXdnn,
|
||||
"Cosh_xdnn_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Erf, DataType::Float32, ErfXdnn,
|
||||
"Erf_xdnn_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Acos, DataType::Float32, ACosXdnn,
|
||||
"ACos_xdnn_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Acosh, DataType::Float32, ACoshXdnn,
|
||||
"ACosh_xdnn_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Asin, DataType::Float32, ASinXdnn,
|
||||
"ASin_xdnn_Float32");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Asinh, DataType::Float32, ASinhXdnn,
|
||||
"ASinh_xdnn_Float3 2");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Atanh, DataType::Float32, ATanhXdnn,
|
||||
"ATanh_xdnn_Float32");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,60 @@
|
|||
#include "kunlun/kunlun_runtime.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/perf_engine.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
void KUNLUNRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
|
||||
bool profiling = false) const {
|
||||
const auto &kernelRegistry = KernelRegistry::getInstance();
|
||||
auto &perfEngine = PerfEngine::getInstance();
|
||||
double totalTime = 0;
|
||||
std::map<OpType, double> opTime;
|
||||
std::map<OpType, int> opCnt;
|
||||
for (auto &op : graph->getOperators()) {
|
||||
// HACK: set correct data type
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
if (!perfData && !tune) {
|
||||
kernel->compute(op, this);
|
||||
continue;
|
||||
}
|
||||
|
||||
PerfRecord record;
|
||||
if (!perfData) {
|
||||
record = kernel->tune(op, this);
|
||||
perfEngine.setPerfData(perfKey, record);
|
||||
} else
|
||||
record = perfData;
|
||||
|
||||
double t = record->time;
|
||||
totalTime += t;
|
||||
|
||||
if (profiling) {
|
||||
double t = timeit([&]() { kernel->compute(op, record, this); },
|
||||
[&]() { sync(); }, 1, 1);
|
||||
op->print();
|
||||
printf(" op_time on kunlun xpu %lf\n", t);
|
||||
totalTime += t;
|
||||
opTime[op->getOpType()] += t;
|
||||
opCnt[op->getOpType()]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void KUNLUNRuntimeObj::run(const Graph &graph, bool tune,
|
||||
bool profiling) const {
|
||||
if (profiling)
|
||||
IT_TODO_HALT();
|
||||
runWithoutSync(graph, tune, profiling);
|
||||
sync();
|
||||
}
|
||||
|
||||
void KUNLUNRuntimeObj::sync() const { ; }
|
||||
|
||||
string KUNLUNRuntimeObj::toString() const { return "KUNLUN Runtime"; }
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,71 @@
|
|||
#include "kunlun/operator_timer.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
#include "operators/conv.h"
|
||||
#include "operators/matmul.h"
|
||||
#include "utils/data_generator.h"
|
||||
|
||||
namespace infini {
|
||||
namespace opTimer {
|
||||
|
||||
double getPerfConvKunlun(int n, int c, int h, int w, int f, int r, int s,
|
||||
int padh, int padw, int strideh, int stridew,
|
||||
int dilationh, int dilationw, int group,
|
||||
const char *name) {
|
||||
Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
Runtime kunlun = make_ref<KUNLUNRuntimeObj>();
|
||||
Graph gKunlun = make_ref<GraphObj>(kunlun);
|
||||
// Set input data on CPU in a CPU Graph
|
||||
IT_ASSERT(c % group == 0);
|
||||
Tensor i0Cpu = gCpu->addTensor({n, h, w, c}, DataType::Float32);
|
||||
Tensor w0Cpu = gCpu->addTensor({f, r, s, c / group}, DataType::Float32);
|
||||
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||
gCpu->dataMalloc();
|
||||
i0Cpu->setData(IncrementalGenerator());
|
||||
w0Cpu->setData(IncrementalGenerator());
|
||||
|
||||
// Copy input tensors from CPU to Kunlun
|
||||
Tensor i0Kunlun = gKunlun->cloneTensor(i0Cpu);
|
||||
Tensor w0Kunlun = gKunlun->cloneTensor(w0Cpu);
|
||||
// Build Kunlun graph
|
||||
auto conv = gKunlun->addOp<ConvObj>(i0Kunlun, w0Kunlun, nullptr, padh, padw,
|
||||
strideh, stridew, dilationh, dilationw);
|
||||
// allocate Kunlun memory
|
||||
gKunlun->dataMalloc();
|
||||
// Execute on Kunlun
|
||||
bool tune = true;
|
||||
kunlun->run(gKunlun, tune);
|
||||
return kunlun->getPerfTime(gKunlun);
|
||||
}
|
||||
|
||||
double getPerfMatmulKunlun(int b, int m, int n, int k, const char *name) {
|
||||
Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
Runtime kunlun = make_ref<KUNLUNRuntimeObj>();
|
||||
Graph gKunlun = make_ref<GraphObj>(kunlun);
|
||||
// Set input data on CPU in a CPU Graph
|
||||
Tensor i0Cpu = gCpu->addTensor({b, m, k}, DataType::Float32);
|
||||
Tensor w0Cpu = gCpu->addTensor({b, k, n}, DataType::Float32);
|
||||
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||
gCpu->dataMalloc();
|
||||
i0Cpu->setData(IncrementalGenerator());
|
||||
w0Cpu->setData(IncrementalGenerator());
|
||||
|
||||
// Copy input tensors from CPU to Kunlun
|
||||
Tensor i0Kunlun = gKunlun->cloneTensor(i0Cpu);
|
||||
Tensor w0Kunlun = gKunlun->cloneTensor(w0Cpu);
|
||||
// Build Kunlun graph
|
||||
auto conv = gKunlun->addOp<MatmulObj>(i0Kunlun, w0Kunlun, nullptr);
|
||||
// allocate Kunlun memory
|
||||
gKunlun->dataMalloc();
|
||||
// Execute on Kunlun
|
||||
bool tune = true;
|
||||
kunlun->run(gKunlun, tune);
|
||||
return kunlun->getPerfTime(gKunlun);
|
||||
}
|
||||
|
||||
} // namespace opTimer
|
||||
} // namespace infini
|
|
@ -0,0 +1,61 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
#include "operators/element_wise.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <class T>
|
||||
void testAdd(const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto xpuRuntime = make_ref<KUNLUNRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu1 =
|
||||
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
Tensor inputCpu2 =
|
||||
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
|
||||
// GPU
|
||||
Graph xpuGraph = make_ref<GraphObj>(xpuRuntime);
|
||||
auto inputGpu1 = xpuGraph->cloneTensor(inputCpu1);
|
||||
auto inputGpu2 = xpuGraph->cloneTensor(inputCpu2);
|
||||
auto gpuOp = xpuGraph->addOp<T>(inputGpu1, inputGpu2, nullptr);
|
||||
xpuGraph->dataMalloc();
|
||||
inputGpu1->setData(generator);
|
||||
inputGpu2->setData(generator);
|
||||
xpuRuntime->run(xpuGraph);
|
||||
auto outputGpu = gpuOp->getOutput();
|
||||
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||
// CPU
|
||||
Graph cpuGraph = make_ref<GraphObj>(cpuRuntime);
|
||||
auto cpuOp = cpuGraph->addOp<T>(inputCpu1, inputCpu2, nullptr);
|
||||
cpuGraph->addTensor(inputCpu1);
|
||||
cpuGraph->addTensor(inputCpu2);
|
||||
cpuGraph->dataMalloc();
|
||||
inputCpu1->setData(generator);
|
||||
inputCpu2->setData(generator);
|
||||
cpuRuntime->run(cpuGraph);
|
||||
auto outputCpu = cpuOp->getOutput();
|
||||
// Check
|
||||
EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu));
|
||||
}
|
||||
|
||||
TEST(xpu_add, run) {
|
||||
testAdd<AddObj>(IncrementalGenerator(), Shape{1, 1, 1, 30});
|
||||
testAdd<SubObj>(IncrementalGenerator(), Shape{1, 1, 1, 30});
|
||||
testAdd<MulObj>(IncrementalGenerator(), Shape{1, 1, 1, 30});
|
||||
testAdd<DivObj>(IncrementalGenerator(), Shape{1, 1, 1, 30});
|
||||
testAdd<EqualObj>(IncrementalGenerator(), Shape{1, 1, 1, 30});
|
||||
testAdd<GreaterEqualObj>(IncrementalGenerator(), Shape{1, 1, 1, 30});
|
||||
testAdd<GreaterThanObj>(IncrementalGenerator(), Shape{1, 1, 1, 30});
|
||||
testAdd<LessEqualObj>(IncrementalGenerator(), Shape{1, 1, 1, 30});
|
||||
testAdd<LessThanObj>(IncrementalGenerator(), Shape{1, 1, 1, 30});
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,61 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "kunlun/kunlun_kernel_without_config.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
#include "operators/batch_norm.h"
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
TEST(XPU_BatchNorm, run) {
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto xpuRuntime = make_ref<KUNLUNRuntimeObj>();
|
||||
|
||||
// Build cpu graph
|
||||
Graph gCpu = make_ref<GraphObj>(cpuRuntime);
|
||||
auto iCpu = gCpu->addTensor(Shape{1, 3, 2, 2}, DataType::Float32);
|
||||
auto meanCpu = gCpu->addTensor(Shape{3}, DataType::Float32);
|
||||
auto varCpu = gCpu->addTensor(Shape{3}, DataType::Float32);
|
||||
auto scaleCpu = gCpu->addTensor(Shape{3}, DataType::Float32);
|
||||
auto biasCpu = gCpu->addTensor(Shape{3}, DataType::Float32);
|
||||
|
||||
// Build input data on CPU
|
||||
gCpu->dataMalloc();
|
||||
iCpu->setData(IncrementalGenerator());
|
||||
meanCpu->copyin(vector<float>{1, 6, 9});
|
||||
varCpu->copyin(vector<float>{4, 1, 9});
|
||||
scaleCpu->setData(OneGenerator());
|
||||
biasCpu->setData(ZeroGenerator());
|
||||
|
||||
// Build XPU graph
|
||||
Graph g = make_ref<GraphObj>(xpuRuntime);
|
||||
|
||||
auto i = g->cloneTensor(iCpu);
|
||||
auto mean = g->cloneTensor(meanCpu);
|
||||
auto var = g->cloneTensor(varCpu);
|
||||
auto scale = g->cloneTensor(scaleCpu);
|
||||
auto bias = g->cloneTensor(biasCpu);
|
||||
auto op =
|
||||
g->addOp<BatchNormObj>(i, nullptr, mean, var, scale, bias, 0.9, 0);
|
||||
|
||||
// allocate XPU memory
|
||||
g->dataMalloc();
|
||||
i->setData(IncrementalGenerator());
|
||||
mean->copyin(vector<float>{1, 6, 9});
|
||||
var->copyin(vector<float>{4, 1, 9});
|
||||
scale->setData(OneGenerator());
|
||||
bias->setData(ZeroGenerator());
|
||||
|
||||
// Execute on XPU
|
||||
xpuRuntime->run(g);
|
||||
|
||||
// clone XPU output to CPU
|
||||
auto o = op->getOutput();
|
||||
auto ocpu = o->clone(cpuRuntime);
|
||||
|
||||
// check results on CPU
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 2}));
|
||||
EXPECT_TRUE(ocpu->equalData(vector<float>{
|
||||
-0.5, 0, 0.5, 1, -2, -1, 0, 1, -0.333333, 0, 0.3333333, 0.6666667}));
|
||||
}
|
||||
} // namespace infini
|
|
@ -0,0 +1,54 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
#include "operators/concat.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <class T>
|
||||
void testConcat(const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto xpuRuntime = make_ref<KUNLUNRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu1 =
|
||||
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
inputCpu1->dataMalloc();
|
||||
inputCpu1->setData(generator);
|
||||
Tensor inputCpu2 =
|
||||
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
inputCpu2->dataMalloc();
|
||||
inputCpu2->setData(generator);
|
||||
|
||||
// GPU
|
||||
Graph xpuGraph = make_ref<GraphObj>(xpuRuntime);
|
||||
auto inputGpu1 = xpuGraph->cloneTensor(inputCpu1);
|
||||
auto inputGpu2 = xpuGraph->cloneTensor(inputCpu2);
|
||||
auto gpuOp =
|
||||
xpuGraph->addOp<T>(TensorVec{inputGpu1, inputGpu2}, nullptr, 2);
|
||||
xpuGraph->dataMalloc();
|
||||
inputGpu1->setData(generator);
|
||||
inputGpu2->setData(generator);
|
||||
xpuRuntime->run(xpuGraph);
|
||||
auto outputGpu = gpuOp->getOutput();
|
||||
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||
// Check
|
||||
inputCpu1->print();
|
||||
inputCpu1->printData();
|
||||
inputCpu2->print();
|
||||
inputCpu2->printData();
|
||||
outputGpu2Cpu->print();
|
||||
outputGpu2Cpu->printData();
|
||||
EXPECT_TRUE(1);
|
||||
}
|
||||
|
||||
TEST(xpu_Concat, run) {
|
||||
testConcat<ConcatObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,56 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
#include "operators/conv.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <class T>
|
||||
void testConv(const std::function<void(void *, size_t, DataType)> &generatorA,
|
||||
const std::function<void(void *, size_t, DataType)> &generatorB,
|
||||
const Shape &shapeA, const Shape &shapeB) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto xpuRuntime = make_ref<KUNLUNRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu1 =
|
||||
make_ref<TensorObj>(shapeA, DataType::Float32, cpuRuntime);
|
||||
Tensor inputCpu2 =
|
||||
make_ref<TensorObj>(shapeB, DataType::Float32, cpuRuntime);
|
||||
// MLU
|
||||
Graph xpuGraph = make_ref<GraphObj>(xpuRuntime);
|
||||
auto inputMlu1 = xpuGraph->cloneTensor(inputCpu1);
|
||||
auto inputMlu2 = xpuGraph->cloneTensor(inputCpu2);
|
||||
auto mluOp =
|
||||
xpuGraph->addOp<T>(inputMlu1, inputMlu2, nullptr, 1, 1, 1, 1, 1, 1);
|
||||
xpuGraph->dataMalloc();
|
||||
inputMlu1->setData(generatorA);
|
||||
inputMlu2->setData(generatorB);
|
||||
xpuRuntime->run(xpuGraph);
|
||||
auto outputXpu = mluOp->getOutput();
|
||||
auto outputXpu2Cpu = outputXpu->clone(cpuRuntime);
|
||||
// CPU
|
||||
Graph cpuGraph = make_ref<GraphObj>(cpuRuntime);
|
||||
cpuGraph->addTensor(inputCpu1);
|
||||
cpuGraph->addTensor(inputCpu2);
|
||||
auto cpuOp =
|
||||
cpuGraph->addOp<T>(inputCpu1, inputCpu2, nullptr, 1, 1, 1, 1, 1, 1);
|
||||
cpuGraph->dataMalloc();
|
||||
inputCpu1->setData(generatorA);
|
||||
inputCpu2->setData(generatorB);
|
||||
cpuRuntime->run(cpuGraph);
|
||||
auto outputCpu = cpuOp->getOutput();
|
||||
// Check
|
||||
EXPECT_TRUE(outputCpu->equalData(outputXpu2Cpu));
|
||||
}
|
||||
|
||||
TEST(xpu_Conv, run) {
|
||||
testConv<ConvObj>(IncrementalGenerator(), IncrementalGenerator(),
|
||||
Shape{1, 3, 32, 32}, Shape{2, 3, 3, 3});
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,136 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/perf_engine.h"
|
||||
#include "core/runtime.h"
|
||||
#include "kunlun/kunlun_kernel_without_config.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
#include "operators/conv.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
void testConvTransposedXdnn(
|
||||
const std::function<void(void *, size_t, DataType)> &generator,
|
||||
vector<float> ansVec) {
|
||||
const auto &[N, C, H, W, F, R, S] = tuple{1, 1, 2, 2, 1, 4, 4};
|
||||
const int stride = 1, padding = 0, dilation = 1;
|
||||
// Construct Runtime and graph for CPU and XPU
|
||||
Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
Runtime xpu = make_ref<KUNLUNRuntimeObj>();
|
||||
Graph gXpu = make_ref<GraphObj>(xpu);
|
||||
// Set input data on CPU in a CPU Graph
|
||||
Tensor i0Cpu = gCpu->addTensor({N, F, H, H}, DataType::Float32);
|
||||
Tensor w0Cpu = gCpu->addTensor({F, C, R, S}, DataType::Float32);
|
||||
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||
gCpu->dataMalloc();
|
||||
i0Cpu->setData(generator);
|
||||
w0Cpu->setData(generator);
|
||||
|
||||
// Copy input tensors from CPU to XPU
|
||||
Tensor i0Xpu = gXpu->cloneTensor(i0Cpu);
|
||||
Tensor w0Xpu = gXpu->cloneTensor(w0Cpu);
|
||||
// Build XPU graph
|
||||
auto conv = gXpu->addOp<ConvTransposed2dObj>(i0Xpu, w0Xpu, nullptr, padding,
|
||||
padding, stride, stride,
|
||||
dilation, dilation);
|
||||
gXpu->dataMalloc();
|
||||
i0Xpu->setData(generator);
|
||||
w0Xpu->setData(generator);
|
||||
// Execute on XPU
|
||||
xpu->run(gXpu);
|
||||
// copy output from XPU to CPU
|
||||
auto o0Cpu = gCpu->cloneTensor(conv->getOutput());
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(o0Cpu->equalData(ansVec));
|
||||
}
|
||||
|
||||
void testConvTransposedNHWCXdnn(
|
||||
const std::function<void(void *, size_t, DataType)> &generator,
|
||||
vector<float> ansVec) {
|
||||
const auto &[N, C, H, W, F, R, S] = tuple{1, 1, 2, 2, 1, 4, 4};
|
||||
const int stride = 1, padding = 0, dilation = 1;
|
||||
// Construct Runtime and graph for CPU and XPU
|
||||
Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
Runtime xpu = make_ref<KUNLUNRuntimeObj>();
|
||||
Graph gXpu = make_ref<GraphObj>(xpu);
|
||||
// Set input data on CPU in a CPU Graph
|
||||
Tensor i0Cpu = gCpu->addTensor({N, H, W, F}, DataType::Float32);
|
||||
Tensor w0Cpu = gCpu->addTensor({F, R, S, C}, DataType::Float32);
|
||||
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||
gCpu->dataMalloc();
|
||||
i0Cpu->setData(generator);
|
||||
w0Cpu->setData(generator);
|
||||
|
||||
// Copy input tensors from CPU to XPU
|
||||
Tensor i0Xpu = gXpu->cloneTensor(i0Cpu);
|
||||
Tensor w0Xpu = gXpu->cloneTensor(w0Cpu);
|
||||
// Build XPU graph
|
||||
auto conv = gXpu->addOp<ConvTransposed2dNHWCObj>(
|
||||
i0Xpu, w0Xpu, nullptr, padding, padding, stride, stride, dilation,
|
||||
dilation);
|
||||
gXpu->dataMalloc();
|
||||
i0Xpu->setData(generator);
|
||||
w0Xpu->setData(generator);
|
||||
// Execute on XPU
|
||||
xpu->run(gXpu);
|
||||
// copy output from XPU to CPU
|
||||
auto o0Cpu = gCpu->cloneTensor(conv->getOutput());
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(o0Cpu->equalData(ansVec));
|
||||
}
|
||||
|
||||
TEST(XPU_ConvTransposed, run) {
|
||||
testConvTransposedXdnn(IncrementalGenerator(),
|
||||
vector<float>{0., 0., 1., 2., 3., 0., 6.,
|
||||
12., 18., 16., 8., 30., 36., 42.,
|
||||
32., 16., 54., 60., 66., 48., 24.,
|
||||
62., 67., 72., 45.});
|
||||
}
|
||||
|
||||
TEST(XPU_ConvTransposedNHWC, run) {
|
||||
testConvTransposedNHWCXdnn(IncrementalGenerator(),
|
||||
vector<float>{0., 0., 1., 2., 3., 0., 6.,
|
||||
12., 18., 16., 8., 30., 36., 42.,
|
||||
32., 16., 54., 60., 66., 48., 24.,
|
||||
62., 67., 72., 45.});
|
||||
}
|
||||
|
||||
TEST(XPU_ConvTransposed, run1) {
|
||||
// Construct Runtime and graph for CPU and XPU
|
||||
Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
Runtime xpu = make_ref<KUNLUNRuntimeObj>();
|
||||
Graph gXpu = make_ref<GraphObj>(xpu);
|
||||
// Set input data on CPU in a CPU Graph
|
||||
Tensor i0Cpu = gCpu->addTensor({1, 2, 3, 3}, DataType::Float32);
|
||||
Tensor w0Cpu = gCpu->addTensor({2, 2, 3, 3}, DataType::Float32);
|
||||
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||
gCpu->dataMalloc();
|
||||
i0Cpu->setData(IncrementalGenerator());
|
||||
w0Cpu->setData(IncrementalGenerator());
|
||||
|
||||
// Copy input tensors from CPU to XPU
|
||||
Tensor i0Xpu = gXpu->cloneTensor(i0Cpu);
|
||||
Tensor w0Xpu = gXpu->cloneTensor(w0Cpu);
|
||||
// Build XPU graph
|
||||
auto conv = gXpu->addOp<ConvTransposed2dObj>(i0Xpu, w0Xpu, nullptr, 0, 0);
|
||||
gXpu->dataMalloc();
|
||||
i0Xpu->setData(IncrementalGenerator());
|
||||
w0Xpu->setData(IncrementalGenerator());
|
||||
// Execute on XPU
|
||||
xpu->run(gXpu);
|
||||
// copy output from XPU to CPU
|
||||
auto o0Cpu = gCpu->cloneTensor(conv->getOutput());
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(o0Cpu->equalData(vector<float>{
|
||||
162, 351, 569, 413, 224, 405, 876, 1417, 1024, 553,
|
||||
747, 1611, 2598, 1869, 1005, 639, 1368, 2191, 1564, 835,
|
||||
396, 843, 1343, 953, 506, 243, 531, 866, 629, 341,
|
||||
621, 1344, 2173, 1564, 841, 1152, 2475, 3975, 2841, 1518,
|
||||
963, 2052, 3271, 2320, 1231, 585, 1239, 1964, 1385, 731}));
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,66 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
#include "operators/element_wise.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
using ExpectOutput = vector<float>;
|
||||
template <class T>
|
||||
void testElementWiseXdnn(
|
||||
const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape, const ExpectOutput &ansVec) {
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto xpuRuntime = make_ref<KUNLUNRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor acpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
acpu->dataMalloc();
|
||||
acpu->setData(generator);
|
||||
|
||||
Tensor bcpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
bcpu->dataMalloc();
|
||||
bcpu->setData(generator);
|
||||
|
||||
// Build XPU graph
|
||||
Graph g = make_ref<GraphObj>(xpuRuntime);
|
||||
auto a = g->cloneTensor(acpu);
|
||||
auto b = g->cloneTensor(bcpu);
|
||||
auto op = g->addOp<T>(a, b, nullptr);
|
||||
|
||||
// allocate XPU memory
|
||||
g->dataMalloc();
|
||||
a->setData(generator);
|
||||
b->setData(generator);
|
||||
|
||||
// Execute on XPU
|
||||
xpuRuntime->run(g);
|
||||
|
||||
// clone XPU output to CPU
|
||||
auto c = op->getOutput();
|
||||
auto ccpu = c->clone(cpuRuntime);
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(ccpu->equalData(ansVec));
|
||||
}
|
||||
|
||||
TEST(xdnn_ElementWise, run) {
|
||||
testElementWiseXdnn<AddObj>(
|
||||
IncrementalGenerator(), Shape{1, 2, 2, 3},
|
||||
ExpectOutput{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22});
|
||||
testElementWiseXdnn<SubObj>(
|
||||
IncrementalGenerator(), Shape{1, 2, 2, 3},
|
||||
ExpectOutput{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
testElementWiseXdnn<MulObj>(
|
||||
IncrementalGenerator(), Shape{1, 2, 2, 3},
|
||||
ExpectOutput{0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121});
|
||||
testElementWiseXdnn<DivObj>(
|
||||
OneGenerator(), Shape{1, 2, 2, 3},
|
||||
ExpectOutput{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
|
||||
testElementWiseXdnn<PowObj>(IncrementalGenerator(), Shape{1, 2, 2, 1},
|
||||
ExpectOutput{1, 1, 4, 27});
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,58 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
#include "operators/matmul.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <class T>
|
||||
void testMatmul(const std::function<void(void *, size_t, DataType)> &generatorA,
|
||||
const std::function<void(void *, size_t, DataType)> &generatorB,
|
||||
bool transA, bool transB, const Shape &shapeA,
|
||||
const Shape &shapeB) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto xpuRuntime = make_ref<KUNLUNRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu1 =
|
||||
make_ref<TensorObj>(shapeA, DataType::Float32, cpuRuntime);
|
||||
Tensor inputCpu2 =
|
||||
make_ref<TensorObj>(shapeB, DataType::Float32, cpuRuntime);
|
||||
|
||||
// MLU
|
||||
Graph xpuGraph = make_ref<GraphObj>(xpuRuntime);
|
||||
auto inputMlu1 = xpuGraph->cloneTensor(inputCpu1);
|
||||
auto inputMlu2 = xpuGraph->cloneTensor(inputCpu2);
|
||||
auto mluOp = xpuGraph->addOp<T>(inputMlu1, inputMlu2, nullptr);
|
||||
xpuGraph->dataMalloc();
|
||||
inputMlu1->setData(generatorA);
|
||||
inputMlu2->setData(generatorB);
|
||||
xpuRuntime->run(xpuGraph);
|
||||
auto outputMlu = mluOp->getOutput();
|
||||
auto outputMlu2Cpu = outputMlu->clone(cpuRuntime);
|
||||
// CPU
|
||||
Graph cpuGraph = make_ref<GraphObj>(cpuRuntime);
|
||||
auto cpuOp = cpuGraph->addOp<T>(inputCpu1, inputCpu2, nullptr);
|
||||
cpuGraph->addTensor(inputCpu1);
|
||||
cpuGraph->addTensor(inputCpu2);
|
||||
cpuGraph->dataMalloc();
|
||||
inputCpu1->setData(generatorA);
|
||||
inputCpu2->setData(generatorB);
|
||||
cpuRuntime->run(cpuGraph);
|
||||
auto outputCpu = cpuOp->getOutput();
|
||||
outputCpu->print();
|
||||
outputMlu2Cpu->print();
|
||||
// Check
|
||||
EXPECT_TRUE(outputCpu->equalData(outputMlu2Cpu));
|
||||
}
|
||||
|
||||
TEST(xpu_Matmul, run) {
|
||||
testMatmul<MatmulObj>(IncrementalGenerator(), IncrementalGenerator(), false,
|
||||
false, Shape{2, 3}, Shape{3, 4});
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,40 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "kunlun/kunlun_kernel_without_config.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
#include "operators/pad.h"
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
TEST(xpu_Pad, run) {
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto xpuRuntime = make_ref<KUNLUNRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor icpu =
|
||||
make_ref<TensorObj>(Shape{1, 2, 3, 2}, DataType::Float32, cpuRuntime);
|
||||
|
||||
// Build XPU graph;
|
||||
Graph g = make_ref<GraphObj>(xpuRuntime);
|
||||
auto i = g->cloneTensor(icpu);
|
||||
auto op = g->addOp<PadObj>(i, nullptr, vector<int>{1, 0, 1, 1},
|
||||
vector<int>{0, 3});
|
||||
|
||||
// allocate XPU memory
|
||||
g->dataMalloc();
|
||||
i->setData(IncrementalGenerator());
|
||||
|
||||
// Execute on XPU
|
||||
xpuRuntime->run(g);
|
||||
|
||||
// clone XPU output to CPU
|
||||
auto o = op->getOutput();
|
||||
auto cpuo = o->clone(cpuRuntime);
|
||||
cpuo->printData();
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(cpuo->equalData(
|
||||
vector<float>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 0, 2, 3, 0, 4, 5, 0, 6, 7, 0, 8, 9, 0, 10, 11, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
|
||||
}
|
||||
} // namespace infini
|
|
@ -0,0 +1,51 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
#include "operators/pooling.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <class T>
|
||||
void testPooling(const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto xpuRuntime = make_ref<KUNLUNRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
inputCpu->dataMalloc();
|
||||
inputCpu->setData(generator);
|
||||
|
||||
// GPU
|
||||
Graph xpuGraph = make_ref<GraphObj>(xpuRuntime);
|
||||
auto inputGpu = xpuGraph->cloneTensor(inputCpu);
|
||||
auto gpuOp =
|
||||
xpuGraph->addOp<T>(inputGpu, nullptr, 3, 3, 1, 1, 0, 0, 2, 2, 0);
|
||||
xpuGraph->dataMalloc();
|
||||
inputGpu->setData(generator);
|
||||
xpuRuntime->run(xpuGraph);
|
||||
auto outputGpu = gpuOp->getOutput();
|
||||
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||
|
||||
// CPU
|
||||
Graph cpuGraph = make_ref<GraphObj>(cpuRuntime);
|
||||
cpuGraph->addTensor(inputCpu);
|
||||
auto cpuOp =
|
||||
cpuGraph->addOp<T>(inputCpu, nullptr, 3, 3, 1, 1, 0, 0, 2, 2, 0);
|
||||
cpuGraph->dataMalloc();
|
||||
inputCpu->setData(generator);
|
||||
cpuRuntime->run(cpuGraph);
|
||||
auto outputCpu = cpuOp->getOutput();
|
||||
EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu));
|
||||
}
|
||||
|
||||
TEST(xdnn_Pooling, run) {
|
||||
testPooling<MaxPoolObj>(IncrementalGenerator(), Shape{1, 1, 5, 5});
|
||||
testPooling<AvgPoolObj>(IncrementalGenerator(), Shape{1, 1, 5, 5});
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,48 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
#include "operators/split.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <class T>
|
||||
void testSplit(const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto xpuRuntime = make_ref<KUNLUNRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu1 =
|
||||
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
inputCpu1->dataMalloc();
|
||||
inputCpu1->setData(generator);
|
||||
// GPU
|
||||
Graph xpuGraph = make_ref<GraphObj>(xpuRuntime);
|
||||
auto inputGpu1 = xpuGraph->cloneTensor(inputCpu1);
|
||||
auto gpuOp = xpuGraph->addOp<T>(inputGpu1, std::nullopt, 3, 3);
|
||||
xpuGraph->dataMalloc();
|
||||
xpuRuntime->run(xpuGraph);
|
||||
auto o0Cpu = gpuOp->getOutput(0)->clone(cpuRuntime);
|
||||
auto o1Cpu = gpuOp->getOutput(1)->clone(cpuRuntime);
|
||||
auto o2Cpu = gpuOp->getOutput(2)->clone(cpuRuntime);
|
||||
// Check
|
||||
inputCpu1->print();
|
||||
inputCpu1->printData();
|
||||
o0Cpu->print();
|
||||
o0Cpu->printData();
|
||||
o1Cpu->print();
|
||||
o1Cpu->printData();
|
||||
o2Cpu->print();
|
||||
o2Cpu->printData();
|
||||
EXPECT_TRUE(1);
|
||||
}
|
||||
|
||||
TEST(xpu_Split, run) {
|
||||
testSplit<SplitObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,43 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
#include "operators/transpose.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <class T>
|
||||
void testTranspose(
|
||||
const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto xpuRuntime = make_ref<KUNLUNRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
inputCpu->dataMalloc();
|
||||
inputCpu->setData(generator);
|
||||
|
||||
// GPU
|
||||
Graph xpuGraph = make_ref<GraphObj>(xpuRuntime);
|
||||
auto inputGpu = xpuGraph->cloneTensor(inputCpu);
|
||||
vector<int> permute = {0, 1, 3, 2};
|
||||
auto gpuOp = xpuGraph->addOp<T>(inputGpu, nullptr, permute);
|
||||
xpuGraph->dataMalloc();
|
||||
xpuRuntime->run(xpuGraph);
|
||||
auto outputGpu = gpuOp->getOutput();
|
||||
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||
// Check
|
||||
inputCpu->printData();
|
||||
outputGpu2Cpu->printData();
|
||||
EXPECT_TRUE(1);
|
||||
}
|
||||
|
||||
TEST(xpu_Transpose, run) {
|
||||
testTranspose<TransposeObj>(IncrementalGenerator(), Shape{1, 1, 2, 3});
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,190 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
#include "operators/unary.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <class T>
|
||||
void testUnary(const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto xpuRuntime = make_ref<KUNLUNRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
|
||||
// GPU
|
||||
Graph xpuGraph = make_ref<GraphObj>(xpuRuntime);
|
||||
auto inputGpu = xpuGraph->cloneTensor(inputCpu);
|
||||
auto gpuOp = xpuGraph->addOp<T>(inputGpu, nullptr);
|
||||
xpuGraph->dataMalloc();
|
||||
inputGpu->setData(generator);
|
||||
xpuRuntime->run(xpuGraph);
|
||||
auto outputGpu = gpuOp->getOutput();
|
||||
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||
// CPU
|
||||
Graph cpuGraph = make_ref<GraphObj>(cpuRuntime);
|
||||
auto cpuOp = cpuGraph->addOp<T>(inputCpu, nullptr);
|
||||
cpuGraph->addTensor(inputCpu);
|
||||
cpuGraph->dataMalloc();
|
||||
inputCpu->setData(generator);
|
||||
cpuRuntime->run(cpuGraph);
|
||||
auto outputCpu = cpuOp->getOutput();
|
||||
// Check
|
||||
EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu, 1e-6));
|
||||
}
|
||||
|
||||
void testClip(const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto xpuRuntime = make_ref<KUNLUNRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
float min = 1.0;
|
||||
float max = 5.0;
|
||||
|
||||
// GPU
|
||||
Graph xpuGraph = make_ref<GraphObj>(xpuRuntime);
|
||||
auto inputGpu = xpuGraph->cloneTensor(inputCpu);
|
||||
auto gpuOp = xpuGraph->addOp<ClipObj>(inputGpu, nullptr, min, max);
|
||||
xpuGraph->dataMalloc();
|
||||
inputGpu->setData(generator);
|
||||
xpuRuntime->run(xpuGraph);
|
||||
auto outputGpu = gpuOp->getOutput();
|
||||
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||
// CPU
|
||||
Graph cpuGraph = make_ref<GraphObj>(cpuRuntime);
|
||||
auto cpuOp = cpuGraph->addOp<ClipObj>(inputCpu, nullptr, min, max);
|
||||
cpuGraph->addTensor(inputCpu);
|
||||
cpuGraph->dataMalloc();
|
||||
inputCpu->setData(generator);
|
||||
cpuRuntime->run(cpuGraph);
|
||||
auto outputCpu = cpuOp->getOutput();
|
||||
// Check
|
||||
EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu));
|
||||
}
|
||||
|
||||
void testCast(const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto xpuRuntime = make_ref<KUNLUNRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
|
||||
// GPU
|
||||
Graph xpuGraph = make_ref<GraphObj>(xpuRuntime);
|
||||
auto inputGpu = xpuGraph->cloneTensor(inputCpu);
|
||||
auto gpuOp =
|
||||
xpuGraph->addOp<CastObj>(inputGpu, nullptr, CastType::Float2Int32);
|
||||
xpuGraph->dataMalloc();
|
||||
inputGpu->setData(generator);
|
||||
xpuRuntime->run(xpuGraph);
|
||||
auto outputGpu = gpuOp->getOutput();
|
||||
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||
// CPU
|
||||
Graph cpuGraph = make_ref<GraphObj>(cpuRuntime);
|
||||
auto cpuOp =
|
||||
cpuGraph->addOp<CastObj>(inputCpu, nullptr, CastType::Float2Int32);
|
||||
cpuGraph->addTensor(inputCpu);
|
||||
cpuGraph->dataMalloc();
|
||||
inputCpu->setData(generator);
|
||||
cpuRuntime->run(cpuGraph);
|
||||
auto outputCpu = cpuOp->getOutput();
|
||||
// Check
|
||||
EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu));
|
||||
}
|
||||
|
||||
template <LogObj::LogType T>
|
||||
void testLog(const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto xpuRuntime = make_ref<KUNLUNRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
|
||||
// GPU
|
||||
Graph xpuGraph = make_ref<GraphObj>(xpuRuntime);
|
||||
auto inputGpu = xpuGraph->cloneTensor(inputCpu);
|
||||
auto gpuOp = xpuGraph->addOp<LogObj>(inputGpu, nullptr, T);
|
||||
xpuGraph->dataMalloc();
|
||||
inputGpu->setData(generator);
|
||||
xpuRuntime->run(xpuGraph);
|
||||
auto outputGpu = gpuOp->getOutput();
|
||||
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||
// CPU
|
||||
Graph cpuGraph = make_ref<GraphObj>(cpuRuntime);
|
||||
auto cpuOp = cpuGraph->addOp<LogObj>(inputCpu, nullptr, T);
|
||||
cpuGraph->addTensor(inputCpu);
|
||||
cpuGraph->dataMalloc();
|
||||
inputCpu->setData(generator);
|
||||
cpuRuntime->run(cpuGraph);
|
||||
auto outputCpu = cpuOp->getOutput();
|
||||
// Check
|
||||
EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu));
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void testTrigon(const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto xpuRuntime = make_ref<KUNLUNRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
|
||||
// GPU
|
||||
Graph xpuGraph = make_ref<GraphObj>(xpuRuntime);
|
||||
auto inputGpu = xpuGraph->cloneTensor(inputCpu);
|
||||
auto gpuOp = xpuGraph->addOp<T>(inputGpu, nullptr);
|
||||
xpuGraph->dataMalloc();
|
||||
inputGpu->setData(generator);
|
||||
xpuRuntime->run(xpuGraph);
|
||||
auto outputGpu = gpuOp->getOutput();
|
||||
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||
// CPU
|
||||
Graph cpuGraph = make_ref<GraphObj>(cpuRuntime);
|
||||
auto cpuOp = cpuGraph->addOp<T>(inputCpu, nullptr);
|
||||
cpuGraph->addTensor(inputCpu);
|
||||
cpuGraph->dataMalloc();
|
||||
inputCpu->setData(generator);
|
||||
cpuRuntime->run(cpuGraph);
|
||||
auto outputCpu = cpuOp->getOutput();
|
||||
// Check
|
||||
EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu, 1e-3));
|
||||
}
|
||||
|
||||
TEST(xdnn_Unary, run) {
|
||||
testUnary<ReluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<AbsObj>(ValGenerator<-1>(), Shape{1, 2, 2, 3});
|
||||
testUnary<ATanObj>(OneGenerator(), Shape{1, 2, 2, 3});
|
||||
testLog<LogObj::Log10>(ValGenerator<2>(), Shape{1, 2, 2, 3});
|
||||
testLog<LogObj::Log2>(ValGenerator<2>(), Shape{1, 2, 2, 3});
|
||||
testLog<LogObj::LogE>(ValGenerator<2>(), Shape{1, 2, 2, 3});
|
||||
testTrigon<CosObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testTrigon<SinObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testTrigon<TanObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testTrigon<SinHObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testTrigon<CosHObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<ErfObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testTrigon<ACosObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testTrigon<ACosHObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testTrigon<ASinObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testTrigon<ASinHObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testTrigon<ATanHObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
}
|
||||
|
||||
} // namespace infini
|
Loading…
Reference in New Issue