forked from jiuyuan/InfiniTensor
Fix workspace & bang conv (#183)
* fix bang workspace
* fix convbpdata
* fix code
* add code
* fix
* fix
* fix conv
* fix test conv

---------

Co-authored-by: wanghailu <wanghailu0717@163.com>
Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
parent
a7293c12ba
commit
3ead20a23a
|
@ -118,7 +118,7 @@ if(BUILD_TEST)
|
|||
include_directories(3rd-party/googletest/googletest/include)
|
||||
endif()
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror -Wno-error=deprecated-declarations")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror -Wno-error=deprecated-declarations -Wno-error=pointer-arith")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion
|
||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ class BangRuntimeObj : public RuntimeObj {
|
|||
cnnlHandle_t cnnl;
|
||||
BangPtr workspace;
|
||||
size_t workspaceSize;
|
||||
mutable size_t cursor;
|
||||
|
||||
public:
|
||||
BangRuntimeObj() : RuntimeObj(Device::BANG) {
|
||||
|
@ -24,6 +25,7 @@ class BangRuntimeObj : public RuntimeObj {
|
|||
// 10GB for Longformer
|
||||
// size_t longformerNum = 3lu * (1 << 30);
|
||||
workspaceSize = 7ll << 30; // 7 GB
|
||||
cursor = 0;
|
||||
workspace = alloc(workspaceSize);
|
||||
}
|
||||
virtual ~BangRuntimeObj() {
|
||||
|
@ -45,10 +47,15 @@ class BangRuntimeObj : public RuntimeObj {
|
|||
void dealloc(void *ptr) override { checkBangError(cnrtFree(ptr)); }
|
||||
cnnlHandle_t cnnlHandle() const { return cnnl; }
|
||||
BangPtr getWorkspace(size_t size) const {
|
||||
IT_ASSERT(size <= workspaceSize);
|
||||
return workspace;
|
||||
IT_ASSERT((cursor + size) <= workspaceSize);
|
||||
cursor += size;
|
||||
void *temp = workspace;
|
||||
temp += (cursor - size);
|
||||
return temp;
|
||||
}
|
||||
|
||||
void resetWorkspace() const { cursor = 0; }
|
||||
|
||||
void copyBlobFromCPU(void *dst, const void *src,
|
||||
size_t bytes) const override {
|
||||
checkBangError(cnrtMemcpy(dst, const_cast<void *>(src), bytes,
|
||||
|
|
|
@ -20,12 +20,14 @@ void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
|
|||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
if (!perfData && !tune) {
|
||||
kernel->compute(op, this);
|
||||
this->resetWorkspace();
|
||||
continue;
|
||||
}
|
||||
|
||||
PerfRecord record;
|
||||
if (!perfData) {
|
||||
record = kernel->tune(op, this);
|
||||
this->resetWorkspace();
|
||||
perfEngine.setPerfData(perfKey, record);
|
||||
} else
|
||||
record = perfData;
|
||||
|
@ -36,6 +38,7 @@ void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
|
|||
if (profiling) {
|
||||
double t = timeit([&]() { kernel->compute(op, record, this); },
|
||||
[&]() { sync(); }, 1, 1);
|
||||
this->resetWorkspace();
|
||||
op->print();
|
||||
printf(" op_time on bang %lf\n", t);
|
||||
totalTime += t;
|
||||
|
|
|
@ -118,8 +118,8 @@ class ConvCnnl : public BangKernelWithoutConfig {
|
|||
cnnlGetTensorElementNum(cInDesc) * sizeof(float));
|
||||
|
||||
stat = cnnlConvolutionForward(
|
||||
context->cnnlHandle(), convDesc, algo, NULL, aDesc, aData, bDesc,
|
||||
bData, NULL, NULL, wsData, wsSize, NULL, cInDesc, cDataIn);
|
||||
context->cnnlHandle(), convDesc, algo, NULL, aDesc, aDataOut, bDesc,
|
||||
bDataOut, NULL, NULL, wsData, wsSize, NULL, cInDesc, cDataIn);
|
||||
if (stat != CNNL_STATUS_SUCCESS)
|
||||
return;
|
||||
|
||||
|
@ -130,10 +130,10 @@ class ConvCnnl : public BangKernelWithoutConfig {
|
|||
|
||||
cnnlGetTransposeWorkspaceSize(context->cnnlHandle(), cInDesc, opOutDesc,
|
||||
&wsSize);
|
||||
wsData = context->getWorkspace(wsSize);
|
||||
BangPtr wsData2 = context->getWorkspace(wsSize);
|
||||
|
||||
stat = cnnlTranspose_v2(context->cnnlHandle(), opOutDesc, cInDesc,
|
||||
cDataIn, cDesc, cData, wsData, wsSize);
|
||||
cDataIn, cDesc, cData, wsData2, wsSize);
|
||||
if (stat != CNNL_STATUS_SUCCESS)
|
||||
return;
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
|
|||
CNNL_CONVOLUTION_BWD_DATA_FASTEST, &algo);
|
||||
size_t wsSize;
|
||||
cnnlGetConvolutionBackwardDataWorkspaceSize(context->cnnlHandle(),
|
||||
aDesc, bDesc, convDesc,
|
||||
bDesc, aDesc, convDesc,
|
||||
cDesc, algo, &wsSize);
|
||||
BangPtr wsData = context->getWorkspace(wsSize);
|
||||
|
||||
|
|
Loading…
Reference in New Issue