wanghailu0717 2023-10-20 14:08:39 +08:00
parent 1184fa131f
commit b6ff4514fe
13 changed files with 394 additions and 4 deletions

CMakeLists.txt
View File

@@ -2,8 +2,9 @@
option(USE_CUDA "Support CUDA GPU" OFF)
option(USE_BANG "Support BANG MLU" OFF)
option(USE_KUNLUN "Support KUNLUN XPU" OFF)
option(USE_ASCEND "Support HUAWEI ASCEND" OFF)
option(USE_INTELCPU "Support INTELCPU" OFF)
option(USE_BACKTRACE "Print backtrace on exception and segmentation fault" ON)
option(USE_BACKTRACE "Print backtrace on exception and segmentation fault" OFF)
option(USE_PROTOBUF "Serialize and deserialize tensors" OFF)
option(BUILD_DIST "Build project for distributed running" OFF)
option(BUILD_TEST "Build tests" OFF)
@@ -140,6 +141,11 @@ if(USE_KUNLUN)
list (APPEND SRC ${SRC_KUNLUN})
endif()
if(USE_ASCEND)
file(GLOB_RECURSE SRC_ASCEND src/ascend/*.cc src/kernels/ascend/*.cc )
list (APPEND SRC ${SRC_ASCEND})
endif()
if(USE_INTELCPU)
file(GLOB_RECURSE SRC_INTELCPU src/intelcpu/*.cc src/kernels/intelcpu/*.cc )
list (APPEND SRC ${SRC_INTELCPU})
@@ -270,7 +276,6 @@ if(USE_KUNLUN)
find_library(KUNLUN_RT libxpurt.so "${KUNLUN_HOME}/lib64")
find_library(KUNLUN_DNN libxpuapi.so "${KUNLUN_HOME}/XTDK/shlib")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror")
if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH}))
execute_process(COMMAND uname -m OUTPUT_VARIABLE _uname_m OUTPUT_STRIP_TRAILING_WHITESPACE)
set(TARGET_CPU_ARCH "${_uname_m}" CACHE STRING "Target CPU ARCH")
@@ -280,10 +285,40 @@ if(USE_KUNLUN)
set(TARGET_CPU_ARCH $ENV{TARGET_CPU_ARCH} CACHE STRING "Target CPU ARCH")
endif()
message(STATUS "TARGET_CPU_ARCH: ${TARGET_CPU_ARCH}")
target_link_libraries(InfiniTensor ${KUNLUN_RT} ${KUNLUN_DNN} stdc++)
endif()
if(USE_ASCEND)
add_compile_definitions(USE_ASCEND=1)
if ((NOT DEFINED ASCEND_HOME) AND (NOT DEFINED ENV{ASCEND_HOME}))
message(FATAL_ERROR "ASCEND_HOME is not defined from cmake or env")
elseif (DEFINED ASCEND_HOME)
set(ASCEND_HOME ${ASCEND_HOME} CACHE STRING "ASCEND_HOME directory for ASCEND development")
else()
set(ASCEND_HOME $ENV{ASCEND_HOME} CACHE STRING "ASCEND_HOME directory for ASCEND development")
endif()
message(STATUS "ASCEND_HOME: ${ASCEND_HOME}")
include_directories("${ASCEND_HOME}/include/")
include_directories("${ASCEND_HOME}/include/aclnn")
find_library(ASCEND_CL libascendcl.so "${ASCEND_HOME}/lib64")
find_library(ASCEND_BASE libnnopbase.so "${ASCEND_HOME}/lib64")
find_library(ASCEND_DNN libopapi.so "${ASCEND_HOME}/lib64")
find_library(ASCEND_HAL libascend_hal.so "${ASCEND_HOME}/../../driver/lib64/driver")
find_library(ASCEND_RT libruntime.so "${ASCEND_HOME}/lib64")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror")
if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH}))
execute_process(COMMAND uname -m OUTPUT_VARIABLE _uname_m OUTPUT_STRIP_TRAILING_WHITESPACE)
set(TARGET_CPU_ARCH "${_uname_m}" CACHE STRING "Target CPU ARCH")
elseif(DEFINED TARGET_CPU_ARCH)
set(TARGET_CPU_ARCH ${TARGET_CPU_ARCH} CACHE STRING "Target CPU ARCH")
else()
set(TARGET_CPU_ARCH $ENV{TARGET_CPU_ARCH} CACHE STRING "Target CPU ARCH")
endif()
message(STATUS "TARGET_CPU_ARCH: ${TARGET_CPU_ARCH}")
target_link_libraries(InfiniTensor ${ASCEND_HAL} ${ASCEND_RT} ${ASCEND_CL} ${ASCEND_BASE} ${ASCEND_DNN} stdc++)
endif()
# # Python bindings
# pybind11_add_module(infini MODULE ${FFI})
# target_link_libraries(infini PRIVATE infini_cpp)
@@ -319,6 +354,9 @@ if(BUILD_TEST)
if (USE_KUNLUN)
build_test(test/kernels/kunlun/*.cc)
endif()
if (USE_ASCEND)
build_test(test/kernels/ascend/*.cc)
endif()
if (USE_INTELCPU)
build_test(test/kernels/intelcpu/*.cc)
endif()

Makefile
View File

@@ -4,6 +4,7 @@ TYPE ?= Release
CUDA ?= OFF
BANG ?= OFF
KUNLUN ?= OFF
ASCEND ?= OFF
INTELCPU ?= off
BACKTRACE ?= ON
TEST ?= ON
@@ -27,6 +28,7 @@ CMAKE_OPT = -DCMAKE_BUILD_TYPE=$(TYPE)
CMAKE_OPT += -DUSE_CUDA=$(CUDA)
CMAKE_OPT += -DUSE_BANG=$(BANG)
CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN)
CMAKE_OPT += -DUSE_ASCEND=$(ASCEND)
CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE)
CMAKE_OPT += -DBUILD_TEST=$(TEST)

View File

@@ -140,6 +140,13 @@
make install-python KUNLUN=ON
```
Build the CPU part together with the Huawei ASCEND part:
```bash
export ASCEND_HOME=/path/to/your/ascend_home
make install-python ASCEND=ON
```
3. Usage
After a successful installation, you can write and run code through the project's Python API. See the sample code example/Resnet/resnet.py and the user manual for details.

View File

@@ -27,6 +27,7 @@
- `CUDA`: whether to build the CUDA backend; defaults to `OFF`, set to `ON` to enable
- `BANG`: whether to build the Cambricon BANG backend; defaults to `OFF`, set to `ON` to enable
- `KUNLUN`: whether to build the Kunlun backend; defaults to `OFF`, set to `ON` to enable
- `ASCEND`: whether to build the Huawei ASCEND backend; defaults to `OFF`, set to `ON` to enable
- `BACKTRACE`: whether to enable stack backtraces; defaults to `ON`, set to `OFF` to disable; recommended when debugging
- `TEST`: whether to build `googletest`; defaults to `ON`, set to `OFF` to disable; only needed for `test-cpp`

env.sh
View File

@@ -36,3 +36,15 @@ export LD_LIBRARY_PATH="${NEUWARE_HOME}/lib64:${LD_LIBRARY_PATH}"
# ├── version
# └── XTDK
export KUNLUN_HOME=/usr/local/xpu
# Configure the HOME path of the Huawei ASCEND NPU. Note that the path below
# is where the ASCEND toolkit installer places its packages; if you installed
# the toolkit some other way, set the correct path yourself.
# Below is one possible layout of the toolkit directory, for reference.
# .
# ├── bin
# ├── include
# ├── lib64
# ├── tools
# └── version
export ASCEND_HOME=/usr/local/Ascend/ascend-toolkit/6.3
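# The ASCEND runtime libraries live in ${ASCEND_HOME}/lib64 (the directory the
# find_library calls in CMakeLists.txt search). If that directory is not
# already on the loader's search path, it can be added here too. This line is
# an assumption mirroring the NEUWARE_HOME entry above, not toolkit guidance:
export LD_LIBRARY_PATH="${ASCEND_HOME}/lib64:${LD_LIBRARY_PATH}"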

include/ascend/ascend_common.h
View File

@@ -0,0 +1,20 @@
#pragma once
#include "core/common.h"
#include "acl/acl.h"
#include "acl/acl_op.h"
#define checkASCENDError(call)                                                 \
    {                                                                          \
        auto err = call;                                                       \
        if (ACL_SUCCESS != err) {                                              \
            fprintf(stderr, "ASCEND error in %s:%i : %d.\n", __FILE__,         \
                    __LINE__, static_cast<int>(err));                          \
            exit(EXIT_FAILURE);                                                \
        }                                                                      \
    }
namespace infini {
using ASCENDPtr = void *;
} // namespace infini
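A minimal usage sketch of `checkASCENDError`: wrap any ACL call that returns an `aclError` status. The helper name and buffer size below are illustrative only, not part of the commit.

```cpp
#include "ascend/ascend_common.h"

// Hypothetical helper: allocate and free device memory, aborting with a
// file/line/error-code message if either ACL call fails.
void demoAlloc() {
    void *ptr = nullptr;
    checkASCENDError(aclrtMalloc(&ptr, 1024, ACL_MEM_MALLOC_HUGE_FIRST));
    checkASCENDError(aclrtFree(ptr));
}
```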

include/ascend/ascend_kernel_without_config.h
View File

@@ -0,0 +1,24 @@
#pragma once
#include "core/kernel.h"
#include "ascend/ascend_runtime.h"
namespace infini {
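// Base class for ASCEND kernels that carry no tunable configuration: the
// three-argument compute() ignores its PerfRecord, and tune() simply times
// compute() with a stream sync and wraps the result in a PerfRecord.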
class ASCENDKernelWithoutConfig : public Kernel {
public:
virtual void compute(const Operator &op, const PerfRecord &record,
const RuntimeObj *context) const {
compute(op, context);
}
virtual void compute(const Operator &op,
const RuntimeObj *context) const = 0;
// Premise: op must be idempotent, since tune() executes it multiple times.
virtual PerfRecord tune(const Operator &op,
const RuntimeObj *_context) const {
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
return make_ref<PerfRecordObj>(timeit([&]() { compute(op, _context); },
[&]() { context->sync(); }));
}
};
} // namespace infini
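For reference, a minimal kernel on top of this base class only has to override the two-argument `compute()`; the three-argument overload and `tune()` are inherited. `DemoKernel` is a hypothetical name, not part of the commit.

```cpp
#include "ascend/ascend_kernel_without_config.h"

namespace infini {
class DemoKernel : public ASCENDKernelWithoutConfig {
    // A real kernel would unpack the operator's tensors and launch an aclnn
    // op on the runtime's stream; this sketch deliberately does nothing.
    void compute(const Operator &op, const RuntimeObj *context) const override {
        (void)op;
        (void)context;
    }
};
} // namespace infini
```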

include/ascend/ascend_runtime.h
View File

@@ -0,0 +1,103 @@
#pragma once
#include "core/runtime.h"
#include "ascend/ascend_common.h"
#define CHECK_RET(cond, return_expr) \
do { \
if (!(cond)) { \
return_expr; \
} \
} while (0)
#define LOG_PRINT(message, ...) \
do { \
printf(message, ##__VA_ARGS__); \
} while (0)
namespace infini {
class ASCENDRuntimeObj : public RuntimeObj {
private:
aclrtContext aclnn;
aclrtStream stream;
ASCENDPtr workspace;
size_t workspaceSize;
public:
ASCENDRuntimeObj(int deviceId = 0)
: RuntimeObj(Device::ASCEND, deviceId) {
// aclInit must be called before any other ACL API.
auto ret = aclInit(nullptr);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret));
ret = aclrtSetDevice(deviceId);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret));
ret = aclrtCreateContext(&aclnn, deviceId);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateContext failed. ERROR: %d\n", ret));
ret = aclrtSetCurrentContext(aclnn);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetCurrentContext failed. ERROR: %d\n", ret));
ret = aclrtCreateStream(&stream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret));
// Preallocate a fixed workspace shared by all kernels.
workspaceSize = 3ll << 30; // 3 GB
workspace = alloc(workspaceSize);
}
virtual ~ASCENDRuntimeObj() {
dealloc(workspace);
aclrtDestroyStream(stream);
aclrtDestroyContext(aclnn);
aclrtResetDevice(deviceId);
aclFinalize();
}
string toString() const override;
void run(const Graph &graph, bool tune = false,
bool profiling = false) const;
// double runEvaluation(const Graph &graph, int nWarmups,
// int nEvaluations) const;
void sync() const;
ASCENDPtr alloc(size_t size) override {
void *ptr;
checkASCENDError(
aclrtMalloc((void **)&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST));
return ptr;
}
void dealloc(void *ptr) override { aclrtFree(ptr); }
// Stream that kernels launch on and synchronize with.
aclrtStream ASCENDHandle() const { return stream; }
ASCENDPtr getWorkspace(size_t size) const {
IT_ASSERT(size <= workspaceSize);
return workspace;
}
void copyBlobFromCPU(void *dst, const void *src,
size_t bytes) const override {
// The second argument is the capacity of dst, which is `bytes` here.
checkASCENDError(aclrtMemcpy(dst, bytes, const_cast<void *>(src), bytes,
ACL_MEMCPY_HOST_TO_DEVICE));
}
void copyBlobToCPU(void *dst, const void *src,
size_t bytes) const override {
checkASCENDError(aclrtMemcpy(dst, bytes, const_cast<void *>(src), bytes,
ACL_MEMCPY_DEVICE_TO_HOST));
}
void copyBlobInsideRuntime(void *dst, const void *src,
size_t bytes) const override {
checkASCENDError(aclrtMemcpy(dst, bytes, const_cast<void *>(src), bytes,
ACL_MEMCPY_DEVICE_TO_DEVICE));
}
void initComm(const string &, int, int) override { IT_TODO_HALT(); }
CommunicatorObj &getCommunicator() const override { IT_TODO_HALT(); }
private:
void runWithoutSync(const Graph &graph, bool tune, bool profiling) const;
};
} // namespace infini
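A minimal sketch of driving the runtime directly, assuming an ASCEND device 0 and this header on the include path; it exercises alloc, the host/device copy helpers, and dealloc:

```cpp
#include "ascend/ascend_runtime.h"
#include <vector>

int main() {
    auto runtime = infini::make_ref<infini::ASCENDRuntimeObj>(0);
    std::vector<float> host(16, 1.0f), back(16, 0.0f);
    const size_t bytes = host.size() * sizeof(float);
    void *dev = runtime->alloc(bytes);                 // device memory
    runtime->copyBlobFromCPU(dev, host.data(), bytes); // H2D
    runtime->copyBlobToCPU(back.data(), dev, bytes);   // D2H round-trip
    runtime->dealloc(dev);
    return 0;
}
```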

include/core/runtime.h
View File

@@ -30,7 +30,7 @@ using OpLists = list<Operator>;
using VType = uint32_t;
enum class Device { CPU = 1, CUDA, BANG, INTELCPU, KUNLUN };
enum class Device { CPU = 1, CUDA, BANG, INTELCPU, KUNLUN, ASCEND };
/***************** Forward declaration end *****************/
class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
@@ -73,6 +73,7 @@ class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
bool isCuda() const { return device == Device::CUDA; }
bool isBang() const { return device == Device::BANG; }
bool isKUNLUN() const { return device == Device::KUNLUN; }
bool isAscend() const { return device == Device::ASCEND; }
void copyBlob(const TensorObj *dst, const TensorObj *src) const;
// TODO: unify these copy APIs
virtual void copyBlobFromCPU(void *dst, const void *src,

src/ascend/ascend_runtime.cc
View File

@@ -0,0 +1,60 @@
#include "ascend/ascend_runtime.h"
#include "core/kernel.h"
#include "core/perf_engine.h"
namespace infini {
void ASCENDRuntimeObj::runWithoutSync(const Graph &graph, bool tune,
bool profiling) const {
const auto &kernelRegistry = KernelRegistry::getInstance();
auto &perfEngine = PerfEngine::getInstance();
double totalTime = 0;
std::map<OpType, double> opTime;
std::map<OpType, int> opCnt;
for (auto &op : graph->getOperators()) {
// HACK: set correct data type
auto kernelAttrs =
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey);
if (!perfData && !tune) {
kernel->compute(op, this);
continue;
}
PerfRecord record;
if (!perfData) {
record = kernel->tune(op, this);
perfEngine.setPerfData(perfKey, record);
} else
record = perfData;
double t = record->time;
totalTime += t;
if (profiling) {
double t = timeit([&]() { kernel->compute(op, record, this); },
[&]() { sync(); }, 1, 1);
op->print();
printf(" op_time on kunlun xpu %lf\n", t);
totalTime += t;
opTime[op->getOpType()] += t;
opCnt[op->getOpType()]++;
}
}
}
void ASCENDRuntimeObj::run(const Graph &graph, bool tune,
bool profiling) const {
if (profiling)
IT_TODO_HALT();
runWithoutSync(graph, tune, profiling);
sync();
}
void ASCENDRuntimeObj::sync() const {
// Block until all work queued on the runtime's stream has finished.
checkASCENDError(aclrtSynchronizeStream(stream));
}
string ASCENDRuntimeObj::toString() const { return "ASCEND Runtime"; }
} // namespace infini

src/ffi/ffi_infinitensor.cc
View File

@@ -27,6 +27,9 @@
#ifdef USE_KUNLUN
#include "kunlun/kunlun_runtime.h"
#endif
#ifdef USE_ASCEND
#include "ascend/ascend_runtime.h"
#endif
#ifdef USE_INTELCPU
#include "intelcpu/mkl_runtime.h"
#include "intelcpu/operator_timer.h"
@@ -167,6 +170,12 @@ static Ref<KUNLUNRuntimeObj> kunlun_runtime() {
}
#endif
#ifdef USE_ASCEND
static Ref<ASCENDRuntimeObj> ascend_runtime() {
return make_ref<ASCENDRuntimeObj>();
}
#endif
#ifdef USE_INTELCPU
static Ref<RuntimeObj> intelcpu_runtime() { return make_ref<MklRuntimeObj>(); }
#endif
@@ -305,6 +314,10 @@ void export_functions(py::module &m) {
#ifdef USE_KUNLUN
.FUNCTION(kunlun_runtime)
#endif
#ifdef USE_ASCEND
.FUNCTION(ascend_runtime)
#endif
.FUNCTION(conv_attrs_of)
.FUNCTION(conv_trans_attrs_of)
.FUNCTION(matmul_attrs_of)
@@ -383,6 +396,11 @@ void init_graph_builder(py::module &m) {
py::class_<KUNLUNRuntimeObj, std::shared_ptr<KUNLUNRuntimeObj>, RuntimeObj>(
m, "KUNLUNRuntime");
#endif
#ifdef USE_ASCEND
py::class_<ASCENDRuntimeObj, std::shared_ptr<ASCENDRuntimeObj>, RuntimeObj>(
m, "ASCENDRuntime");
#endif
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor",
py::buffer_protocol())
.def("fuid", &TensorObj::getFuid, policy::automatic)

src/kernels/ascend/unary.cc
View File

@@ -0,0 +1,57 @@
#include "operators/unary.h"
#include "ascend/ascend_kernel_without_config.h"
#include "ascend/ascend_runtime.h"
#include "aclnnop/level2/aclnn_relu.h"
namespace infini {
class ReluAclnn : public ASCENDKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
auto a = op->getInputs(0)->getDims();
std::vector<int64_t> aDim(a.size(), 1);
for (size_t i = 0; i < a.size(); ++i) {
aDim[i] = int64_t(a[i]);
}
auto aS = op->getInputs(0)->getStride();
std::vector<int64_t> aStride(aS.size(), 1);
for (size_t i = 0; i < aS.size(); ++i) {
aStride[i] = int64_t(aS[i]);
}
auto c = op->getOutput()->getDims();
std::vector<int64_t> cDim(c.size(), 1);
for (size_t i = 0; i < c.size(); ++i) {
cDim[i] = int64_t(c[i]);
}
auto cS = op->getOutput()->getStride();
std::vector<int64_t> cStride(cS.size(), 1);
for (size_t i = 0; i < cS.size(); ++i) {
cStride[i] = int64_t(cS[i]);
}
auto input = aclCreateTensor(aDim.data(), aDim.size(), ACL_FLOAT,
aStride.data(), 0, aclFormat::ACL_FORMAT_ND,
aDim.data(), aDim.size(), aData);
auto output = aclCreateTensor(cDim.data(), cDim.size(), ACL_FLOAT,
cStride.data(), 0, aclFormat::ACL_FORMAT_ND,
cDim.data(), cDim.size(), cData);
uint64_t workspaceSize = 0;
aclOpExecutor *executor;
// Phase 1 of the aclnn two-phase API: query the workspace size and build
// the executor.
auto ret = aclnnReluGetWorkspaceSize(input, output, &workspaceSize, &executor);
assert(ret == ACL_SUCCESS);
void *workspaceAddr = nullptr;
if (workspaceSize > 0) {
ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
assert(ret == ACL_SUCCESS);
}
// Phase 2: launch the op on the runtime's stream and wait for it.
ret = aclnnRelu(workspaceAddr, workspaceSize, executor, context->ASCENDHandle());
assert(ret == ACL_SUCCESS);
ret = aclrtSynchronizeStream(context->ASCENDHandle());
assert(ret == ACL_SUCCESS);
aclDestroyTensor(input);
aclDestroyTensor(output);
if (workspaceAddr != nullptr) {
aclrtFree(workspaceAddr);
}
return;
}
};
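// Note: to be found by the KernelRegistry lookup in ascend_runtime.cc, the
// kernel presumably still has to be registered. The exact macro arguments
// below are an assumption modeled on the other backends, so the call is
// left commented out:
// REGISTER_KERNEL(Device::ASCEND, OpType::Relu, DataType::Float32, ReluAclnn,
//                 "relu_ASCEND_float");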
} // namespace infini

View File

@@ -0,0 +1,47 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "ascend/ascend_runtime.h"
#include "operators/unary.h"
#include "test.h"
namespace infini {
template <class T>
void testUnary(const std::function<void(void *, size_t, DataType)> &generator,
const Shape &shape) {
// Runtime
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
auto npuRuntime = make_ref<ASCENDRuntimeObj>();
// Build input data on CPU
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
// NPU
Graph npuGraph = make_ref<GraphObj>(npuRuntime);
auto inputNpu = npuGraph->cloneTensor(inputCpu);
auto npuOp = npuGraph->addOp<T>(inputNpu, nullptr);
npuGraph->dataMalloc();
inputNpu->setData(generator);
npuRuntime->run(npuGraph);
auto outputNpu = npuOp->getOutput();
auto outputNpu2Cpu = outputNpu->clone(cpuRuntime);
// CPU
Graph cpuGraph = make_ref<GraphObj>(cpuRuntime);
auto cpuOp = cpuGraph->addOp<T>(inputCpu, nullptr);
cpuGraph->addTensor(inputCpu);
cpuGraph->dataMalloc();
inputCpu->setData(generator);
cpuRuntime->run(cpuGraph);
auto outputCpu = cpuOp->getOutput();
// Check
EXPECT_TRUE(outputCpu->equalData(outputNpu2Cpu, 1e-6));
}
TEST(ascend_Unary, run) {
testUnary<ReluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
}
} // namespace infini