Fix workspace & bang conv (#183)

* fix bang workspace

* fix convbpdata

* fix code

* add code

* fix

* fix

* fix conv

* fix test conv

---------

Co-authored-by: wanghailu <wanghailu0717@163.com>
Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
Hardy 2023-11-24 15:16:25 +08:00 committed by GitHub
parent a7293c12ba
commit 3ead20a23a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 18 additions and 8 deletions

View File

@ -118,7 +118,7 @@ if(BUILD_TEST)
include_directories(3rd-party/googletest/googletest/include)
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror -Wno-error=deprecated-declarations")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror -Wno-error=deprecated-declarations -Wno-error=pointer-arith")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion

View File

@ -9,6 +9,7 @@ class BangRuntimeObj : public RuntimeObj {
cnnlHandle_t cnnl;
BangPtr workspace;
size_t workspaceSize;
mutable size_t cursor;
public:
BangRuntimeObj() : RuntimeObj(Device::BANG) {
@ -24,6 +25,7 @@ class BangRuntimeObj : public RuntimeObj {
// 10GB for Longformer
// size_t longformerNum = 3lu * (1 << 30);
workspaceSize = 7ll << 30; // 7 GB
cursor = 0;
workspace = alloc(workspaceSize);
}
virtual ~BangRuntimeObj() {
@ -45,10 +47,15 @@ class BangRuntimeObj : public RuntimeObj {
void dealloc(void *ptr) override { checkBangError(cnrtFree(ptr)); }
cnnlHandle_t cnnlHandle() const { return cnnl; }
BangPtr getWorkspace(size_t size) const {
IT_ASSERT(size <= workspaceSize);
return workspace;
IT_ASSERT((cursor + size) <= workspaceSize);
cursor += size;
void *temp = workspace;
temp += (cursor - size);
return temp;
}
void resetWorkspace() const { cursor = 0; }
void copyBlobFromCPU(void *dst, const void *src,
size_t bytes) const override {
checkBangError(cnrtMemcpy(dst, const_cast<void *>(src), bytes,

View File

@ -20,12 +20,14 @@ void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
auto perfData = perfEngine.getPerfData(perfKey);
if (!perfData && !tune) {
kernel->compute(op, this);
this->resetWorkspace();
continue;
}
PerfRecord record;
if (!perfData) {
record = kernel->tune(op, this);
this->resetWorkspace();
perfEngine.setPerfData(perfKey, record);
} else
record = perfData;
@ -36,6 +38,7 @@ void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
if (profiling) {
double t = timeit([&]() { kernel->compute(op, record, this); },
[&]() { sync(); }, 1, 1);
this->resetWorkspace();
op->print();
printf(" op_time on bang %lf\n", t);
totalTime += t;

View File

@ -118,8 +118,8 @@ class ConvCnnl : public BangKernelWithoutConfig {
cnnlGetTensorElementNum(cInDesc) * sizeof(float));
stat = cnnlConvolutionForward(
context->cnnlHandle(), convDesc, algo, NULL, aDesc, aData, bDesc,
bData, NULL, NULL, wsData, wsSize, NULL, cInDesc, cDataIn);
context->cnnlHandle(), convDesc, algo, NULL, aDesc, aDataOut, bDesc,
bDataOut, NULL, NULL, wsData, wsSize, NULL, cInDesc, cDataIn);
if (stat != CNNL_STATUS_SUCCESS)
return;
@ -130,10 +130,10 @@ class ConvCnnl : public BangKernelWithoutConfig {
cnnlGetTransposeWorkspaceSize(context->cnnlHandle(), cInDesc, opOutDesc,
&wsSize);
wsData = context->getWorkspace(wsSize);
BangPtr wsData2 = context->getWorkspace(wsSize);
stat = cnnlTranspose_v2(context->cnnlHandle(), opOutDesc, cInDesc,
cDataIn, cDesc, cData, wsData, wsSize);
cDataIn, cDesc, cData, wsData2, wsSize);
if (stat != CNNL_STATUS_SUCCESS)
return;

View File

@ -57,7 +57,7 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
CNNL_CONVOLUTION_BWD_DATA_FASTEST, &algo);
size_t wsSize;
cnnlGetConvolutionBackwardDataWorkspaceSize(context->cnnlHandle(),
aDesc, bDesc, convDesc,
bDesc, aDesc, convDesc,
cDesc, algo, &wsSize);
BangPtr wsData = context->getWorkspace(wsSize);