forked from jiuyuan/InfiniTensor
Fix workspace & bang conv (#183)
* fix bang workspace * fix convbpdata * fix code * add code * fix * fix * fix conv * fix test conv --------- Co-authored-by: wanghailu <wanghailu0717@163.com> Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
parent
a7293c12ba
commit
3ead20a23a
|
@ -118,7 +118,7 @@ if(BUILD_TEST)
|
||||||
include_directories(3rd-party/googletest/googletest/include)
|
include_directories(3rd-party/googletest/googletest/include)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror -Wno-error=deprecated-declarations")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror -Wno-error=deprecated-declarations -Wno-error=pointer-arith")
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion
|
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion
|
||||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion
|
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ class BangRuntimeObj : public RuntimeObj {
|
||||||
cnnlHandle_t cnnl;
|
cnnlHandle_t cnnl;
|
||||||
BangPtr workspace;
|
BangPtr workspace;
|
||||||
size_t workspaceSize;
|
size_t workspaceSize;
|
||||||
|
mutable size_t cursor;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
BangRuntimeObj() : RuntimeObj(Device::BANG) {
|
BangRuntimeObj() : RuntimeObj(Device::BANG) {
|
||||||
|
@ -24,6 +25,7 @@ class BangRuntimeObj : public RuntimeObj {
|
||||||
// 10GB for Longformer
|
// 10GB for Longformer
|
||||||
// size_t longformerNum = 3lu * (1 << 30);
|
// size_t longformerNum = 3lu * (1 << 30);
|
||||||
workspaceSize = 7ll << 30; // 7 GB
|
workspaceSize = 7ll << 30; // 7 GB
|
||||||
|
cursor = 0;
|
||||||
workspace = alloc(workspaceSize);
|
workspace = alloc(workspaceSize);
|
||||||
}
|
}
|
||||||
virtual ~BangRuntimeObj() {
|
virtual ~BangRuntimeObj() {
|
||||||
|
@ -45,10 +47,15 @@ class BangRuntimeObj : public RuntimeObj {
|
||||||
void dealloc(void *ptr) override { checkBangError(cnrtFree(ptr)); }
|
void dealloc(void *ptr) override { checkBangError(cnrtFree(ptr)); }
|
||||||
cnnlHandle_t cnnlHandle() const { return cnnl; }
|
cnnlHandle_t cnnlHandle() const { return cnnl; }
|
||||||
BangPtr getWorkspace(size_t size) const {
|
BangPtr getWorkspace(size_t size) const {
|
||||||
IT_ASSERT(size <= workspaceSize);
|
IT_ASSERT((cursor + size) <= workspaceSize);
|
||||||
return workspace;
|
cursor += size;
|
||||||
|
void *temp = workspace;
|
||||||
|
temp += (cursor - size);
|
||||||
|
return temp;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void resetWorkspace() const { cursor = 0; }
|
||||||
|
|
||||||
void copyBlobFromCPU(void *dst, const void *src,
|
void copyBlobFromCPU(void *dst, const void *src,
|
||||||
size_t bytes) const override {
|
size_t bytes) const override {
|
||||||
checkBangError(cnrtMemcpy(dst, const_cast<void *>(src), bytes,
|
checkBangError(cnrtMemcpy(dst, const_cast<void *>(src), bytes,
|
||||||
|
|
|
@ -20,12 +20,14 @@ void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
|
||||||
auto perfData = perfEngine.getPerfData(perfKey);
|
auto perfData = perfEngine.getPerfData(perfKey);
|
||||||
if (!perfData && !tune) {
|
if (!perfData && !tune) {
|
||||||
kernel->compute(op, this);
|
kernel->compute(op, this);
|
||||||
|
this->resetWorkspace();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
PerfRecord record;
|
PerfRecord record;
|
||||||
if (!perfData) {
|
if (!perfData) {
|
||||||
record = kernel->tune(op, this);
|
record = kernel->tune(op, this);
|
||||||
|
this->resetWorkspace();
|
||||||
perfEngine.setPerfData(perfKey, record);
|
perfEngine.setPerfData(perfKey, record);
|
||||||
} else
|
} else
|
||||||
record = perfData;
|
record = perfData;
|
||||||
|
@ -36,6 +38,7 @@ void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
|
||||||
if (profiling) {
|
if (profiling) {
|
||||||
double t = timeit([&]() { kernel->compute(op, record, this); },
|
double t = timeit([&]() { kernel->compute(op, record, this); },
|
||||||
[&]() { sync(); }, 1, 1);
|
[&]() { sync(); }, 1, 1);
|
||||||
|
this->resetWorkspace();
|
||||||
op->print();
|
op->print();
|
||||||
printf(" op_time on bang %lf\n", t);
|
printf(" op_time on bang %lf\n", t);
|
||||||
totalTime += t;
|
totalTime += t;
|
||||||
|
|
|
@ -118,8 +118,8 @@ class ConvCnnl : public BangKernelWithoutConfig {
|
||||||
cnnlGetTensorElementNum(cInDesc) * sizeof(float));
|
cnnlGetTensorElementNum(cInDesc) * sizeof(float));
|
||||||
|
|
||||||
stat = cnnlConvolutionForward(
|
stat = cnnlConvolutionForward(
|
||||||
context->cnnlHandle(), convDesc, algo, NULL, aDesc, aData, bDesc,
|
context->cnnlHandle(), convDesc, algo, NULL, aDesc, aDataOut, bDesc,
|
||||||
bData, NULL, NULL, wsData, wsSize, NULL, cInDesc, cDataIn);
|
bDataOut, NULL, NULL, wsData, wsSize, NULL, cInDesc, cDataIn);
|
||||||
if (stat != CNNL_STATUS_SUCCESS)
|
if (stat != CNNL_STATUS_SUCCESS)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
|
@ -130,10 +130,10 @@ class ConvCnnl : public BangKernelWithoutConfig {
|
||||||
|
|
||||||
cnnlGetTransposeWorkspaceSize(context->cnnlHandle(), cInDesc, opOutDesc,
|
cnnlGetTransposeWorkspaceSize(context->cnnlHandle(), cInDesc, opOutDesc,
|
||||||
&wsSize);
|
&wsSize);
|
||||||
wsData = context->getWorkspace(wsSize);
|
BangPtr wsData2 = context->getWorkspace(wsSize);
|
||||||
|
|
||||||
stat = cnnlTranspose_v2(context->cnnlHandle(), opOutDesc, cInDesc,
|
stat = cnnlTranspose_v2(context->cnnlHandle(), opOutDesc, cInDesc,
|
||||||
cDataIn, cDesc, cData, wsData, wsSize);
|
cDataIn, cDesc, cData, wsData2, wsSize);
|
||||||
if (stat != CNNL_STATUS_SUCCESS)
|
if (stat != CNNL_STATUS_SUCCESS)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
|
|
|
@ -57,7 +57,7 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
|
||||||
CNNL_CONVOLUTION_BWD_DATA_FASTEST, &algo);
|
CNNL_CONVOLUTION_BWD_DATA_FASTEST, &algo);
|
||||||
size_t wsSize;
|
size_t wsSize;
|
||||||
cnnlGetConvolutionBackwardDataWorkspaceSize(context->cnnlHandle(),
|
cnnlGetConvolutionBackwardDataWorkspaceSize(context->cnnlHandle(),
|
||||||
aDesc, bDesc, convDesc,
|
bDesc, aDesc, convDesc,
|
||||||
cDesc, algo, &wsSize);
|
cDesc, algo, &wsSize);
|
||||||
BangPtr wsData = context->getWorkspace(wsSize);
|
BangPtr wsData = context->getWorkspace(wsSize);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue