BANG runtime: hand out workspace memory as consecutive slices

Summary of the changes below:
  * BangRuntimeObj gains a `cursor` offset. getWorkspace(size) returns the
    slice that starts at the current cursor and advances the cursor by
    `size`; resetWorkspace() sets the cursor back to 0.
  * runWithoutSync() calls resetWorkspace() after each kernel compute/tune
    call shown in the hunks.
  * ConvCnnl passes aDataOut/bDataOut to cnnlConvolutionForward and takes a
    separate slice (wsData2) for cnnlTranspose_v2.
  * ConvTransCnnl swaps the first two descriptors passed to
    cnnlGetConvolutionBackwardDataWorkspaceSize.
  * CMAKE_CXX_FLAGS gains -Wno-error=pointer-arith.

Review notes:
  * getWorkspace() offsets through `char *` instead of doing arithmetic on
    `void *` (a GNU extension), and checks `size` against the remaining
    space so the bounds check cannot be defeated by `cursor + size` wrapping.
  * NOTE(review): with that change getWorkspace() no longer needs
    -Wno-error=pointer-arith; the flag is kept in case other code depends on
    it. Confirm, and drop the CMake hunk if nothing does.
  * NOTE(review): line breaks, indentation and blank context lines were
    reconstructed from a whitespace-collapsed copy of this patch; verify
    against the tree or apply with whitespace tolerance.
  * The `index` line of include/bang/bang_runtime.h is omitted because the
    new file content no longer matches the originally recorded blob id.

diff --git a/CMakeLists.txt b/CMakeLists.txt
index b08d9f85..1101a8c2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -118,7 +118,7 @@ if(BUILD_TEST)
   include_directories(3rd-party/googletest/googletest/include)
 endif()
 
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror -Wno-error=deprecated-declarations")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror -Wno-error=deprecated-declarations -Wno-error=pointer-arith")
 set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion
 set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion
 
diff --git a/include/bang/bang_runtime.h b/include/bang/bang_runtime.h
--- a/include/bang/bang_runtime.h
+++ b/include/bang/bang_runtime.h
@@ -9,6 +9,7 @@ class BangRuntimeObj : public RuntimeObj {
     cnnlHandle_t cnnl;
     BangPtr workspace;
     size_t workspaceSize;
+    mutable size_t cursor;
 
   public:
     BangRuntimeObj() : RuntimeObj(Device::BANG) {
@@ -24,6 +25,7 @@ class BangRuntimeObj : public RuntimeObj {
         // 10GB for Longformer
         // size_t longformerNum = 3lu * (1 << 30);
         workspaceSize = 7ll << 30; // 7 GB
+        cursor = 0;
         workspace = alloc(workspaceSize);
     }
     virtual ~BangRuntimeObj() {
@@ -45,10 +47,18 @@ class BangRuntimeObj : public RuntimeObj {
     void dealloc(void *ptr) override { checkBangError(cnrtFree(ptr)); }
     cnnlHandle_t cnnlHandle() const { return cnnl; }
     BangPtr getWorkspace(size_t size) const {
-        IT_ASSERT(size <= workspaceSize);
-        return workspace;
+        // The remaining space is compared against `size` so that the check
+        // cannot be bypassed by `cursor + size` wrapping around.
+        IT_ASSERT(size <= workspaceSize - cursor);
+        // Offset through char *: arithmetic on void * is a GNU extension.
+        void *base = workspace;
+        void *slice = static_cast<char *>(base) + cursor;
+        cursor += size;
+        return slice;
     }
 
+    void resetWorkspace() const { cursor = 0; }
+
     void copyBlobFromCPU(void *dst, const void *src,
                          size_t bytes) const override {
         checkBangError(cnrtMemcpy(dst, const_cast<void *>(src), bytes,
diff --git a/src/bang/bang_runtime.cc b/src/bang/bang_runtime.cc
index 9d422a56..c9f9a933 100644
--- a/src/bang/bang_runtime.cc
+++ b/src/bang/bang_runtime.cc
@@ -20,12 +20,14 @@ void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
         auto perfData = perfEngine.getPerfData(perfKey);
         if (!perfData && !tune) {
             kernel->compute(op, this);
+            this->resetWorkspace();
             continue;
         }
 
         PerfRecord record;
         if (!perfData) {
             record = kernel->tune(op, this);
+            this->resetWorkspace();
             perfEngine.setPerfData(perfKey, record);
         } else
             record = perfData;
@@ -36,6 +38,7 @@ void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
         if (profiling) {
             double t = timeit([&]() { kernel->compute(op, record, this); },
                               [&]() { sync(); }, 1, 1);
+            this->resetWorkspace();
             op->print();
             printf(" op_time on bang %lf\n", t);
             totalTime += t;
diff --git a/src/kernels/bang/conv.cc b/src/kernels/bang/conv.cc
index e55c749e..d9ff3df8 100644
--- a/src/kernels/bang/conv.cc
+++ b/src/kernels/bang/conv.cc
@@ -118,8 +118,8 @@ class ConvCnnl : public BangKernelWithoutConfig {
             cnnlGetTensorElementNum(cInDesc) * sizeof(float));
 
         stat = cnnlConvolutionForward(
-            context->cnnlHandle(), convDesc, algo, NULL, aDesc, aData, bDesc,
-            bData, NULL, NULL, wsData, wsSize, NULL, cInDesc, cDataIn);
+            context->cnnlHandle(), convDesc, algo, NULL, aDesc, aDataOut, bDesc,
+            bDataOut, NULL, NULL, wsData, wsSize, NULL, cInDesc, cDataIn);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
 
@@ -130,10 +130,10 @@ class ConvCnnl : public BangKernelWithoutConfig {
 
         cnnlGetTransposeWorkspaceSize(context->cnnlHandle(), cInDesc, opOutDesc,
                                       &wsSize);
-        wsData = context->getWorkspace(wsSize);
+        BangPtr wsData2 = context->getWorkspace(wsSize);
 
         stat = cnnlTranspose_v2(context->cnnlHandle(), opOutDesc, cInDesc,
-                                cDataIn, cDesc, cData, wsData, wsSize);
+                                cDataIn, cDesc, cData, wsData2, wsSize);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
 
diff --git a/src/kernels/bang/conv_trans.cc b/src/kernels/bang/conv_trans.cc
index baa84c8a..a081e279 100644
--- a/src/kernels/bang/conv_trans.cc
+++ b/src/kernels/bang/conv_trans.cc
@@ -57,7 +57,7 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
             CNNL_CONVOLUTION_BWD_DATA_FASTEST, &algo);
         size_t wsSize;
         cnnlGetConvolutionBackwardDataWorkspaceSize(context->cnnlHandle(),
-                                                    aDesc, bDesc, convDesc,
+                                                    bDesc, aDesc, convDesc,
                                                     cDesc, algo, &wsSize);
         BangPtr wsData = context->getWorkspace(wsSize);
 