fix format

This commit is contained in:
wanghailu 2023-10-23 10:48:35 +08:00
parent b1bdbbf478
commit 412f301323
2 changed files with 73 additions and 63 deletions

View File

@ -1,7 +1,6 @@
#pragma once #pragma once
#include "core/runtime.h"
#include "ascend/ascend_common.h" #include "ascend/ascend_common.h"
#include "core/runtime.h"
#define CHECK_RET(cond, return_expr) \ #define CHECK_RET(cond, return_expr) \
do { \ do { \
@ -25,18 +24,22 @@ class ASCENDRuntimeObj : public RuntimeObj {
size_t workspaceSize; size_t workspaceSize;
public: public:
ASCENDRuntimeObj(int deviceId = 0) ASCENDRuntimeObj(int deviceId = 0) : RuntimeObj(Device::ASCEND, deviceId) {
: RuntimeObj(Device::ASCEND, deviceId) {
auto ret = aclrtSetDevice(deviceId); auto ret = aclrtSetDevice(deviceId);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret)); CHECK_RET(ret == ACL_SUCCESS,
LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret));
ret = aclrtCreateContext(&aclnn, deviceId); ret = aclrtCreateContext(&aclnn, deviceId);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateContext failed. ERROR: %d\n", ret)); CHECK_RET(ret == ACL_SUCCESS,
LOG_PRINT("aclrtCreateContext failed. ERROR: %d\n", ret));
ret = aclrtSetCurrentContext(aclnn); ret = aclrtSetCurrentContext(aclnn);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetCurrentContext failed. ERROR: %d\n", ret)); CHECK_RET(ret == ACL_SUCCESS,
LOG_PRINT("aclrtSetCurrentContext failed. ERROR: %d\n", ret));
ret = aclrtCreateStream(&stream); ret = aclrtCreateStream(&stream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret)); CHECK_RET(ret == ACL_SUCCESS,
LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret));
ret = aclInit(nullptr); ret = aclInit(nullptr);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret)); CHECK_RET(ret == ACL_SUCCESS,
LOG_PRINT("aclInit failed. ERROR: %d\n", ret));
// 10GB for Longformer // 10GB for Longformer
// size_t longformerNum = 3lu * (1 << 30); // size_t longformerNum = 3lu * (1 << 30);
workspaceSize = 3ll << 30; // 3 GB workspaceSize = 3ll << 30; // 3 GB

View File

@ -1,7 +1,7 @@
#include "operators/unary.h" #include "operators/unary.h"
#include "aclnnop/level2/aclnn_relu.h"
#include "ascend/ascend_kernel_without_config.h" #include "ascend/ascend_kernel_without_config.h"
#include "ascend/ascend_runtime.h" #include "ascend/ascend_runtime.h"
#include "aclnnop/level2/aclnn_relu.h"
namespace infini { namespace infini {
class ReluAclnn : public ASCENDKernelWithoutConfig { class ReluAclnn : public ASCENDKernelWithoutConfig {
@ -34,19 +34,26 @@ class ReluAclnn : public ASCENDKernelWithoutConfig {
cStride[i] = int64_t(cS[i]); cStride[i] = int64_t(cS[i]);
} }
auto input = aclCreateTensor(aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0, aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData); auto input = aclCreateTensor(
auto output = aclCreateTensor(cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0, aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData); aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0,
aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData);
auto output = aclCreateTensor(
cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0,
aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData);
uint64_t workspaceSize = 0; uint64_t workspaceSize = 0;
aclOpExecutor *executor; aclOpExecutor *executor;
auto ret = aclnnReluGetWorkspaceSize(input, output, &workspaceSize, &executor); auto ret =
aclnnReluGetWorkspaceSize(input, output, &workspaceSize, &executor);
void *workspaceAddr = nullptr; void *workspaceAddr = nullptr;
if (workspaceSize > 0) { if (workspaceSize > 0) {
ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST); ret = aclrtMalloc(&workspaceAddr, workspaceSize,
ACL_MEM_MALLOC_HUGE_FIRST);
} }
assert(ret == ACL_SUCCESS); assert(ret == ACL_SUCCESS);
ret = aclnnRelu(workspaceAddr, workspaceSize, executor, context->ASCENDHandle()); ret = aclnnRelu(workspaceAddr, workspaceSize, executor,
context->ASCENDHandle());
assert(ret == ACL_SUCCESS); assert(ret == ACL_SUCCESS);
ret = aclrtSynchronizeStream(context->ASCENDHandle()); ret = aclrtSynchronizeStream(context->ASCENDHandle());
assert(ret == ACL_SUCCESS); assert(ret == ACL_SUCCESS);
@ -57,4 +64,4 @@ class ReluAclnn : public ASCENDKernelWithoutConfig {
REGISTER_KERNEL(Device::ASCEND, OpType::Relu, DataType::Float32, ReluAclnn, REGISTER_KERNEL(Device::ASCEND, OpType::Relu, DataType::Float32, ReluAclnn,
"relu_ASCEND_float"); "relu_ASCEND_float");
}; }; // namespace infini