From fe1afe38fa4472e28c0c45e1f91d12d592cb4dc5 Mon Sep 17 00:00:00 2001 From: Hardy <100662313+wanghailu0717@users.noreply.github.com> Date: Wed, 29 Mar 2023 15:47:32 +0800 Subject: [PATCH] fix code of bang conv (#76) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix code of bang conv * test: 向 master push 时也执行 ci Signed-off-by: YdrMaster --------- Signed-off-by: YdrMaster Co-authored-by: wanghailu Co-authored-by: YdrMaster --- .github/workflows/build.yml | 2 ++ .github/workflows/clang-format-check.yml | 2 ++ include/bang/bang_runtime.h | 10 +++++----- src/bang/operator_timer.cc | 4 ++-- test/kernels/bang/test_bang_conv.cc | 2 +- test/kernels/bang/test_bang_matmul.cc | 2 +- 6 files changed, 13 insertions(+), 9 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7dae8509..142bf78b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,5 +1,7 @@ name: Build and test cpu on: + push: + branch: 'master' pull_request: paths-ignore: - '**.md' diff --git a/.github/workflows/clang-format-check.yml b/.github/workflows/clang-format-check.yml index 536c74b7..02f6cd1c 100644 --- a/.github/workflows/clang-format-check.yml +++ b/.github/workflows/clang-format-check.yml @@ -1,5 +1,7 @@ name: clang-format Check on: + push: + branch: 'master' pull_request: paths-ignore: - '**.md' diff --git a/include/bang/bang_runtime.h b/include/bang/bang_runtime.h index 7e2bad1c..6a40ae37 100644 --- a/include/bang/bang_runtime.h +++ b/include/bang/bang_runtime.h @@ -12,12 +12,12 @@ class BangRuntimeObj : public RuntimeObj { public: BangRuntimeObj() : RuntimeObj(Device::BANG) { - checkBangError(cnrtInit(0)); - cnrtDev_t dev; - checkBangError(cnrtGetDeviceHandle(&dev, 0)); - checkBangError(cnrtSetCurrentDevice(dev)); + cnInit(0); + CNdev dev; + cnDeviceGet(&dev, 0); + checkBangError(cnrtSetDevice(dev)); cnrtQueue_t queue; - checkBangError(cnrtCreateQueue(&queue)); + checkBangError(cnrtQueueCreate(&queue)); checkCnnlError(cnnlCreate(&cnnl)); checkCnnlError(cnnlSetQueue(cnnl, queue)); diff --git a/src/bang/operator_timer.cc b/src/bang/operator_timer.cc index d5c6782b..5c9a63cd 100644 --- a/src/bang/operator_timer.cc +++ b/src/bang/operator_timer.cc @@ -14,7 +14,7 @@ double getPerfConvCnnl(int n, int c, int h, int w, int f, int r, int s, int padh, int padw, int strideh, int stridew, int dilationh, int dilationw, int group, const char *name) { - Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton Graph gCpu = make_ref(cpu); Runtime bang = make_ref(); Graph gBang = make_ref(bang); @@ -42,7 +42,7 @@ double getPerfConvCnnl(int n, int c, int h, int w, int f, int r, int s, } double getPerfMatmulCnnl(int b, int m, int n, int k, const char *name) { - Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton Graph gCpu = make_ref(cpu); Runtime bang = make_ref(); Graph gBang = make_ref(bang); diff --git a/test/kernels/bang/test_bang_conv.cc b/test/kernels/bang/test_bang_conv.cc index c67b62b6..0b415b0f 100644 --- a/test/kernels/bang/test_bang_conv.cc +++ b/test/kernels/bang/test_bang_conv.cc @@ -13,7 +13,7 @@ void testConv(const std::function &generatorA, const std::function &generatorB, const Shape &shapeA, const Shape &shapeB) { // Runtime - Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); auto bangRuntime = make_ref(); // Build input data on CPU diff --git a/test/kernels/bang/test_bang_matmul.cc b/test/kernels/bang/test_bang_matmul.cc index 77acf4ab..f6a47802 100644 --- a/test/kernels/bang/test_bang_matmul.cc +++ b/test/kernels/bang/test_bang_matmul.cc @@ -14,7 +14,7 @@ void testMatmul(const std::function &generatorA, bool transA, bool transB, const Shape &shapeA, const Shape &shapeB) { // Runtime - Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); auto bangRuntime = make_ref(); // Build input data on CPU