Add bangc runtime and element-wise kernels

* add code for cambricon mlu, bang, cnnl

* add code for support cambricon mlu,cnnl,cnrt

* add code for support mlu

* add code for support cambricon cnnl

* add code for support mlu

* add code for mlu

* add code for mlu
`

* Update CMakeLists.txt

Co-authored-by: wanghailu <wanghailu@qiyuanlab.com>
Co-authored-by: zhengly123 <zhengly123@outlook.com>
This commit is contained in:
Hardy 2022-09-22 16:57:39 +08:00 committed by GitHub
parent 90eb9d05a8
commit c7c974f07a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 491 additions and 3 deletions

View File

@ -3,7 +3,8 @@ include(CMakeDependentOption)
project(InfiniTensor C CXX)
# Do not change these options in this file. Use cmake.config, cmake -DOPTION=VALUE, or ccmake to specify them.
option(USE_CUDA "Support CUDA GPU" ON)
option(USE_CUDA "Support CUDA GPU" OFF)
option(USE_BANG "Support BANG MLU" OFF)
option(USE_BACKTRACE "Print backtrace on exception and segmentation fault" ON)
option(USE_PROTOBUF "Serialize and deserialize tensors" ON)
option(BUILD_TEST "Build tests" ON)
@ -81,6 +82,11 @@ if(USE_CUDA)
list (APPEND SRC ${SRC_CUDA})
endif()
if(USE_BANG)
file(GLOB_RECURSE SRC_BANG src/bang/*.cc src/kernels/bang/*.cc )
list (APPEND SRC ${SRC_BANG})
endif()
# Libraries
add_library(InfiniTensor SHARED ${SRC})
if(USE_PROTOBUF)
@ -109,6 +115,86 @@ if(USE_CUDA)
target_link_libraries(InfiniTensor cudnn curand cublas ${CUDA_LIBRARIES})
endif()
if(USE_BANG)
################################################################################
# Neuware Evironment
################################################################################
# cnrt cndrv cnnl
if ((NOT DEFINED NEUWARE_HOME) AND (NOT DEFINED ENV{NEUWARE_HOME}))
message(FATAL_ERROR "NEUWARE_HOME is not defined from cmake or env")
elseif (DEFINED NEUWARE_HOME)
set(NEUWARE_HOME ${NEUWARE_HOME} CACHE STRING "NEUWARE_HOME directory for Cambricon Neuware development")
else()
set(NEUWARE_HOME $ENV{NEUWARE_HOME} CACHE STRING "NEUWARE_HOME directory for Cambricon Neuware development")
endif()
message(STATUS "NEUWARE_HOME: ${NEUWARE_HOME}")
include_directories("${NEUWARE_HOME}/include")
find_library(CAMBRICON_CNNL libcnnl.so "${NEUWARE_HOME}/lib64")
find_library(CAMBRICON_CNRT libcnrt.so "${NEUWARE_HOME}/lib64")
find_library(CAMBRICON_CNDRV libcndrv.so "${NEUWARE_HOME}/lib64")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror")
if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH}))
execute_process(COMMAND uname -m OUTPUT_VARIABLE _uname_m OUTPUT_STRIP_TRAILING_WHITESPACE)
set(TARGET_CPU_ARCH "${_uname_m}" CACHE STRING "Target CPU ARCH")
elseif(DEFINED TARGET_CPU_ARCH)
set(TARGET_CPU_ARCH ${TARGET_CPU_ARCH} CACHE STRING "Target CPU ARCH")
else()
set(TARGET_CPU_ARCH $ENV{TARGET_CPU_ARCH} CACHE STRING "Target CPU ARCH")
endif()
message(STATUS "TARGET_CPU_ARCH: ${TARGET_CPU_ARCH}")
################################################################################
# Sample Kernels
################################################################################
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "$ENV{NEUWARE_HOME}/cmake" "$ENV{NEUWARE_HOME}/cmake/modules")
find_package(BANG)
if(NOT BANG_FOUND)
message(FATAL_ERROR "BANG cannot be found.")
elseif(NOT BANG_CNCC_EXECUTABLE)
message(FATAL_ERROR "cncc not found, please ensure cncc is in your PATH env or set variable BANG_CNCC_EXECUTABLE from cmake. Otherwise you should check path used by find_program(BANG_CNCC_EXECUTABLE) in FindBANG.cmake")
endif()
set(BANG_CNCC_FLAGS "-Wall -Werror -fPIC -std=c++11 --target=${TARGET_CPU_ARCH} -O3")
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS}"
"--bang-arch=compute_20"
"--bang-arch=compute_30"
"--bang-mlu-arch=mtp_322"
"--bang-wram-align64"
)
if(${TARGET_CPU_ARCH} MATCHES "aarch64-linux-gnu")
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=1)
execute_process(
COMMAND uname -m
OUTPUT_VARIABLE _uname_m
OUTPUT_STRIP_TRAILING_WHITESPACE
)
if (NOT ("${TARGET_CPU_ARCH}" MATCHES ".*${_uname_m}.*" AND "${_uname_m}" MATCHES "aarch64"))
execute_process(
COMMAND "${CMAKE_CXX_COMPILER}" "-v" "-c" "-x" "c++" "/dev/null" "-M"
ERROR_VARIABLE _cxx_verbose
)
execute_process(
COMMAND "echo" "${_cxx_verbose}"
COMMAND "sed" "-n" "/include.*search starts here/,/End of search list/{s/^ //p}"
COMMAND "tr" "'\n'" ";"
OUTPUT_VARIABLE _cxx_includes
)
list(REMOVE_ITEM _cxx_includes "/usr/include")
foreach(_include ${_cxx_includes})
message(STATUS "add include path: ${_include}")
set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -idirafter ${_include}")
endforeach()
endif()
endif()
#bang_add_library(bangops SHARED ${SRC_BANG})
#target_link_libraries(bangops ${CAMBRICON_CNDRV})
target_link_libraries(InfiniTensor ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++)
#target_link_libraries(InfiniTensor bangops)
endif()
# # Python bindings
# pybind11_add_module(infini MODULE ${FFI})
# target_link_libraries(infini PRIVATE infini_cpp)
@ -135,6 +221,9 @@ if(BUILD_TEST)
if (USE_CUDA)
build_test(test/kernels/cuda/*.cc)
endif()
if (USE_BANG)
build_test(test/kernels/bang/*.cc)
endif()
endif()
if(BUILD_TEST_PET)
build_test(test/pet/*.cc)

View File

@ -0,0 +1,30 @@
#pragma once
#include "cnnl.h"
#include "cnrt.h"
#include "core/common.h"
#define checkBangError(call) \
{ \
auto err = call; \
if (CNRT_RET_SUCCESS != err) { \
fprintf(stderr, "Bang error in %s:%i : %s.\n", __FILE__, __LINE__, \
cnrtGetErrorStr(err)); \
exit(EXIT_FAILURE); \
} \
}
#define checkCnnlError(call) \
{ \
auto err = call; \
if (CNNL_STATUS_SUCCESS != err) { \
fprintf(stderr, "cnnl error in %s:%i : %s.\n", __FILE__, __LINE__, \
cnnlGetErrorString(err)); \
exit(EXIT_FAILURE); \
} \
}
namespace infini {
using BangPtr = void *;
} // namespace infini

View File

@ -0,0 +1,24 @@
#pragma once
#include "bang/bang_runtime.h"
#include "core/kernel.h"
namespace infini {
class BangKernelWithoutConfig : public Kernel {
public:
virtual void compute(const Operator &op, const PerfRecord &record,
const RuntimeObj *context) const {
compute(op, context);
}
virtual void compute(const Operator &op,
const RuntimeObj *context) const = 0;
// Premise: op is idempotent since it is called multiple times.
virtual PerfRecord tune(const Operator &op,
const RuntimeObj *_context) const {
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
return make_ref<PerfRecordObj>(timeit([&]() { compute(op, _context); },
[&]() { context->sync(); }));
}
};
} // namespace infini

View File

@ -0,0 +1,71 @@
#pragma once
#include "bang/bang_common.h"
#include "core/runtime.h"
namespace infini {
class BangRuntimeObj : public RuntimeObj {
private:
cnnlHandle_t cnnl;
BangPtr workspace;
size_t workspaceSize;
public:
BangRuntimeObj() : RuntimeObj(Device::BANG) {
checkBangError(cnrtInit(0));
cnrtDev_t dev;
checkBangError(cnrtGetDeviceHandle(&dev, 0));
checkBangError(cnrtSetCurrentDevice(dev));
cnrtQueue_t queue;
checkBangError(cnrtCreateQueue(&queue));
checkCnnlError(cnnlCreate(&cnnl));
checkCnnlError(cnnlSetQueue(cnnl, queue));
// 10GB for Longformer
// size_t longformerNum = 3lu * (1 << 30);
workspaceSize = 7ll << 30; // 7 GB
workspace = alloc(workspaceSize);
}
virtual ~BangRuntimeObj() {
dealloc(workspace);
checkCnnlError(cnnlDestroy(cnnl));
}
void run(const Graph &graph, bool tune = false,
bool profiling = false) const;
// double runEvaluation(const Graph &graph, int nWarmups,
// int nEvaluations) const;
void sync() const;
BangPtr alloc(size_t size) override {
void *ptr;
checkBangError(cnrtMalloc(&ptr, size));
return ptr;
}
void dealloc(void *ptr) override { checkBangError(cnrtFree(ptr)); }
cnnlHandle_t cnnlHandle() const { return cnnl; }
BangPtr getWorkspace(size_t size) const {
IT_ASSERT(size <= workspaceSize);
return workspace;
}
void copyBlobFromCPU(void *dst, void *src, size_t bytes) const override {
checkBangError(
cnrtMemcpy(dst, src, bytes, CNRT_MEM_TRANS_DIR_HOST2DEV));
}
void copyBlobToCPU(void *dst, void *src, size_t bytes) const override {
checkBangError(
cnrtMemcpy(dst, src, bytes, CNRT_MEM_TRANS_DIR_DEV2HOST));
}
void copyBlobInsideRuntime(void *dst, void *src,
size_t bytes) const override {
checkBangError(
cnrtMemcpy(dst, src, bytes, CNRT_MEM_TRANS_DIR_PEER2PEER));
}
private:
void runWithoutSync(const Graph &graph, bool tune, bool profiling) const;
};
} // namespace infini

View File

@ -26,7 +26,7 @@ using OpVec = vector<Operator>;
using VType = uint32_t;
enum class Device { CPU = 1, CUDA };
enum class Device { CPU = 1, CUDA, BANG };
/***************** Forward declaration end *****************/
class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
@ -64,6 +64,7 @@ class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
Blob allocBlob(size_t size);
bool isCpu() const { return device == Device::CPU; }
bool isCuda() const { return device == Device::CUDA; }
bool isBang() const { return device == Device::BANG; }
void copyBlob(const TensorObj *dst, const TensorObj *src) const;
protected:
@ -99,4 +100,4 @@ class CpuRuntimeObj : public RuntimeObj {
size_t bytes) const override;
};
} // namespace infini
} // namespace infini

57
src/bang/bang_runtime.cc Normal file
View File

@ -0,0 +1,57 @@
#include "bang/bang_runtime.h"
#include "core/kernel.h"
#include "core/perf_engine.h"
namespace infini {
void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
bool profiling = false) const {
const auto &kernelRegistry = KernelRegistry::getInstance();
auto &perfEngine = PerfEngine::getInstance();
double totalTime = 0;
std::map<OpType, double> opTime;
std::map<OpType, int> opCnt;
for (auto &op : graph->getOperators()) {
// HACK: set correct data type
auto kernelAttrs =
KernelAttrs{device, op->getOpType(), DataType::Float32};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey);
if (!perfData && !tune) {
kernel->compute(op, this);
continue;
}
PerfRecord record;
if (!perfData) {
record = kernel->tune(op, this);
perfEngine.setPerfData(perfKey, record);
} else
record = perfData;
double t = record->time;
totalTime += t;
if (profiling) {
double t = timeit([&]() { kernel->compute(op, record, this); },
[&]() { sync(); }, 1, 1);
op->print();
printf(" op_time on bang %lf\n", t);
totalTime += t;
opTime[op->getOpType()] += t;
opCnt[op->getOpType()]++;
}
}
}
void BangRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
if (profiling)
IT_TODO_HALT();
runWithoutSync(graph, tune, profiling);
sync();
}
void BangRuntimeObj::sync() const { cnrtSyncDevice(); }
} // namespace infini

View File

@ -0,0 +1,103 @@
#include "operators/element_wise.h"
#include "bang/bang_kernel_without_config.h"
#include "bang/bang_runtime.h"
namespace infini {
class ElementWiseCnnl : public BangKernelWithoutConfig {
virtual cnnlOpTensorDesc_t getOpType() const = 0;
virtual tuple<float, float, float> getAlphBeta() const {
return {1.f, 1.f, 0.f};
}
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
cnnlTensorDescriptor_t aDesc, bDesc, cDesc;
auto dim = op->getInputs(0)->getDims();
if (dim.size() != 4)
IT_TODO_HALT();
int dim_array[4] = {dim[0], dim[1], dim[2], dim[3]};
// get inputs
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, 4, dim_array));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, 4, dim_array));
// get outputs
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, 4, dim_array));
// get op descriptor
cnnlOpTensorDescriptor_t opDesc;
checkCnnlError(cnnlCreateOpTensorDescriptor(&opDesc));
checkCnnlError(cnnlSetOpTensorDescriptor(
opDesc, getOpType(), CNNL_DTYPE_FLOAT, CNNL_NOT_PROPAGATE_NAN));
size_t wsSize;
cnnlGetOpTensorWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc,
&wsSize);
BangPtr wsData = context->getWorkspace(wsSize);
auto [aAlpha, bAlpha, beta] = getAlphBeta();
cnnlStatus_t stat = cnnlOpTensor(context->cnnlHandle(), opDesc, &aAlpha,
aDesc, aData, &bAlpha, bDesc, bData,
wsData, wsSize, &beta, cDesc, cData);
if (stat != CNNL_STATUS_SUCCESS)
return;
// Destories in BANG does not require sync. But cnnl does not state
// whether sync is required before destories.
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
checkCnnlError(cnnlDestroyTensorDescriptor(bDesc));
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
checkCnnlError(cnnlDestroyOpTensorDescriptor(opDesc));
}
};
class AddCnnl : public ElementWiseCnnl {
cnnlOpTensorDesc_t getOpType() const override { return CNNL_OP_TENSOR_ADD; }
};
class SubCnnl : public ElementWiseCnnl {
cnnlOpTensorDesc_t getOpType() const override { return CNNL_OP_TENSOR_ADD; }
tuple<float, float, float> getAlphBeta() const override {
return {1.f, -1.f, 0.f};
}
};
class MulCnnl : public ElementWiseCnnl {
cnnlOpTensorDesc_t getOpType() const override { return CNNL_OP_TENSOR_MUL; }
};
// class ElementWiseBang : public BangKernelWithoutConfig {
// void compute(const Operator &_op,
// const RuntimeObj *_context) const override {
// element_wise_kernel(_op);
// }
// };
REGISTER_KERNEL(Device::BANG, OpType::Add, DataType::Float32, AddCnnl,
"Add_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Sub, DataType::Float32, SubCnnl,
"Sub_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Mul, DataType::Float32, MulCnnl,
"Mul_cnnl_BANG_Float32");
// REGISTER_KERNEL(Device::BANG, OpType::Div, DataType::Float32,
// ElementWiseBang,
// "Div_Bang_Float32");
// REGISTER_KERNEL(Device::BANG, OpType::Pow, DataType::Float32,
// ElementWiseBang,
// "Pow_Bang_Float32");
}; // namespace infini

View File

@ -0,0 +1,59 @@
#include "bang/bang_runtime.h"
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/element_wise.h"
#include "test.h"
namespace infini {
using ExpectOutput = vector<float>;
template <class T>
void testElementWiseCnnl(
const std::function<void(void *, size_t, DataType)> &generator,
const Shape &shape, const ExpectOutput &ansVec) {
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
auto bangRuntime = make_ref<BangRuntimeObj>();
// Build input data on CPU
Tensor acpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
acpu->dataMalloc();
acpu->setData(generator);
Tensor bcpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
bcpu->dataMalloc();
bcpu->setData(generator);
// Build BANG graph
Graph g = make_ref<GraphObj>(bangRuntime);
auto a = g->cloneTensor(acpu);
auto b = g->cloneTensor(bcpu);
auto op = g->addOp<T>(a, b, nullptr);
// allocate BANG memory
g->dataMalloc();
// Execute on BANG
bangRuntime->run(g);
// clone BANG output to CPU
auto c = op->getOutput();
auto ccpu = c->clone(cpuRuntime);
// check results on CPU
EXPECT_TRUE(ccpu->equalData(ansVec));
}
TEST(cnnl_ElementWise, run) {
testElementWiseCnnl<AddObj>(
IncrementalGenerator(), Shape{1, 2, 2, 3},
ExpectOutput{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22});
testElementWiseCnnl<SubObj>(
IncrementalGenerator(), Shape{1, 2, 2, 3},
ExpectOutput{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
testElementWiseCnnl<MulObj>(
IncrementalGenerator(), Shape{1, 2, 2, 3},
ExpectOutput{0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121});
}
} // namespace infini

View File

@ -0,0 +1,54 @@
#include "bang/bang_runtime.h"
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/element_wise.h"
#include "test.h"
namespace infini {
template <class T>
void testOptensor(
const std::function<void(void *, size_t, DataType)> &generator,
const Shape &shape) {
// Runtime
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
auto bangRuntime = make_ref<BangRuntimeObj>();
// Build input data on CPU
Tensor inputCpu1 =
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
inputCpu1->dataMalloc();
inputCpu1->setData(generator);
Tensor inputCpu2 =
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
inputCpu2->dataMalloc();
inputCpu2->setData(generator);
// GPU
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
auto inputGpu1 = bangGraph->cloneTensor(inputCpu1);
auto inputGpu2 = bangGraph->cloneTensor(inputCpu2);
auto gpuOp = bangGraph->addOp<T>(inputGpu1, inputGpu2, nullptr);
bangGraph->dataMalloc();
bangRuntime->run(bangGraph);
auto outputGpu = gpuOp->getOutput();
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
// CPU
Graph cpuGraph = make_ref<GraphObj>(cpuRuntime);
auto cpuOp = cpuGraph->addOp<T>(inputCpu1, inputCpu2, nullptr);
cpuGraph->dataMalloc();
cpuRuntime->run(cpuGraph);
auto outputCpu = cpuOp->getOutput();
// Check
EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu));
}
TEST(cuDNN_OpTensor, run) {
testOptensor<AddObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testOptensor<SubObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testOptensor<MulObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
}
} // namespace infini