forked from jiuyuan/InfiniTensor
fix format
This commit is contained in:
parent
b1bdbbf478
commit
412f301323
|
@ -1,7 +1,6 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
#include "core/runtime.h"
|
|
||||||
#include "ascend/ascend_common.h"
|
#include "ascend/ascend_common.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
|
||||||
#define CHECK_RET(cond, return_expr) \
|
#define CHECK_RET(cond, return_expr) \
|
||||||
do { \
|
do { \
|
||||||
|
@ -25,18 +24,22 @@ class ASCENDRuntimeObj : public RuntimeObj {
|
||||||
size_t workspaceSize;
|
size_t workspaceSize;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
ASCENDRuntimeObj(int deviceId = 0)
|
ASCENDRuntimeObj(int deviceId = 0) : RuntimeObj(Device::ASCEND, deviceId) {
|
||||||
: RuntimeObj(Device::ASCEND, deviceId) {
|
|
||||||
auto ret = aclrtSetDevice(deviceId);
|
auto ret = aclrtSetDevice(deviceId);
|
||||||
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret));
|
CHECK_RET(ret == ACL_SUCCESS,
|
||||||
|
LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret));
|
||||||
ret = aclrtCreateContext(&aclnn, deviceId);
|
ret = aclrtCreateContext(&aclnn, deviceId);
|
||||||
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateContext failed. ERROR: %d\n", ret));
|
CHECK_RET(ret == ACL_SUCCESS,
|
||||||
|
LOG_PRINT("aclrtCreateContext failed. ERROR: %d\n", ret));
|
||||||
ret = aclrtSetCurrentContext(aclnn);
|
ret = aclrtSetCurrentContext(aclnn);
|
||||||
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetCurrentContext failed. ERROR: %d\n", ret));
|
CHECK_RET(ret == ACL_SUCCESS,
|
||||||
|
LOG_PRINT("aclrtSetCurrentContext failed. ERROR: %d\n", ret));
|
||||||
ret = aclrtCreateStream(&stream);
|
ret = aclrtCreateStream(&stream);
|
||||||
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret));
|
CHECK_RET(ret == ACL_SUCCESS,
|
||||||
|
LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret));
|
||||||
ret = aclInit(nullptr);
|
ret = aclInit(nullptr);
|
||||||
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret));
|
CHECK_RET(ret == ACL_SUCCESS,
|
||||||
|
LOG_PRINT("aclInit failed. ERROR: %d\n", ret));
|
||||||
// 10GB for Longformer
|
// 10GB for Longformer
|
||||||
// size_t longformerNum = 3lu * (1 << 30);
|
// size_t longformerNum = 3lu * (1 << 30);
|
||||||
workspaceSize = 3ll << 30; // 3 GB
|
workspaceSize = 3ll << 30; // 3 GB
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
#include "operators/unary.h"
|
#include "operators/unary.h"
|
||||||
|
#include "aclnnop/level2/aclnn_relu.h"
|
||||||
#include "ascend/ascend_kernel_without_config.h"
|
#include "ascend/ascend_kernel_without_config.h"
|
||||||
#include "ascend/ascend_runtime.h"
|
#include "ascend/ascend_runtime.h"
|
||||||
#include "aclnnop/level2/aclnn_relu.h"
|
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
class ReluAclnn : public ASCENDKernelWithoutConfig {
|
class ReluAclnn : public ASCENDKernelWithoutConfig {
|
||||||
|
@ -34,19 +34,26 @@ class ReluAclnn : public ASCENDKernelWithoutConfig {
|
||||||
cStride[i] = int64_t(cS[i]);
|
cStride[i] = int64_t(cS[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto input = aclCreateTensor(aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0, aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData);
|
auto input = aclCreateTensor(
|
||||||
auto output = aclCreateTensor(cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0, aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData);
|
aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0,
|
||||||
|
aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData);
|
||||||
|
auto output = aclCreateTensor(
|
||||||
|
cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0,
|
||||||
|
aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData);
|
||||||
|
|
||||||
uint64_t workspaceSize = 0;
|
uint64_t workspaceSize = 0;
|
||||||
aclOpExecutor *executor;
|
aclOpExecutor *executor;
|
||||||
|
|
||||||
auto ret = aclnnReluGetWorkspaceSize(input, output, &workspaceSize, &executor);
|
auto ret =
|
||||||
|
aclnnReluGetWorkspaceSize(input, output, &workspaceSize, &executor);
|
||||||
void *workspaceAddr = nullptr;
|
void *workspaceAddr = nullptr;
|
||||||
if (workspaceSize > 0) {
|
if (workspaceSize > 0) {
|
||||||
ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
|
ret = aclrtMalloc(&workspaceAddr, workspaceSize,
|
||||||
|
ACL_MEM_MALLOC_HUGE_FIRST);
|
||||||
}
|
}
|
||||||
assert(ret == ACL_SUCCESS);
|
assert(ret == ACL_SUCCESS);
|
||||||
ret = aclnnRelu(workspaceAddr, workspaceSize, executor, context->ASCENDHandle());
|
ret = aclnnRelu(workspaceAddr, workspaceSize, executor,
|
||||||
|
context->ASCENDHandle());
|
||||||
assert(ret == ACL_SUCCESS);
|
assert(ret == ACL_SUCCESS);
|
||||||
ret = aclrtSynchronizeStream(context->ASCENDHandle());
|
ret = aclrtSynchronizeStream(context->ASCENDHandle());
|
||||||
assert(ret == ACL_SUCCESS);
|
assert(ret == ACL_SUCCESS);
|
||||||
|
@ -57,4 +64,4 @@ class ReluAclnn : public ASCENDKernelWithoutConfig {
|
||||||
|
|
||||||
REGISTER_KERNEL(Device::ASCEND, OpType::Relu, DataType::Float32, ReluAclnn,
|
REGISTER_KERNEL(Device::ASCEND, OpType::Relu, DataType::Float32, ReluAclnn,
|
||||||
"relu_ASCEND_float");
|
"relu_ASCEND_float");
|
||||||
};
|
}; // namespace infini
|
||||||
|
|
Loading…
Reference in New Issue