forked from jiuyuan/InfiniTensor
format
This commit is contained in:
parent
6ba1a0648a
commit
775ce5040d
|
@ -92,53 +92,52 @@ class ReluAclnn : public ASCENDKernelWithoutConfig {
|
|||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); \
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>()); \
|
||||
\
|
||||
auto a = op->getInputs(0) -> getDims();
|
||||
|
||||
std::vector<int64_t> aDim(a.size(), 1);
|
||||
for (size_t i = 0; i < a.size(); ++i) {
|
||||
aDim[i] = int64_t(a[i]);
|
||||
}
|
||||
auto aS = op->getInputs(0)->getStride();
|
||||
std::vector<int64_t> aStride(aS.size(), 1);
|
||||
for (size_t i = 0; i < aS.size(); ++i) {
|
||||
aStride[i] = int64_t(aS[i]);
|
||||
}
|
||||
auto c = op->getInputs(0)->getDims();
|
||||
std::vector<int64_t> cDim(c.size(), 1);
|
||||
for (size_t i = 0; i < c.size(); ++i) {
|
||||
cDim[i] = int64_t(c[i]);
|
||||
}
|
||||
auto cS = op->getInputs(0)->getStride();
|
||||
std::vector<int64_t> cStride(cS.size(), 1);
|
||||
for (size_t i = 0; i < cS.size(); ++i) {
|
||||
cStride[i] = int64_t(cS[i]);
|
||||
}
|
||||
|
||||
auto input =
|
||||
aclCreateTensor(aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData);
|
||||
auto output =
|
||||
aclCreateTensor(cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData);
|
||||
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor *executor;
|
||||
|
||||
auto ret =
|
||||
aclnn##prefix##GetWorkspaceSize(input, output, &workspaceSize, &executor);
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0) {
|
||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||
}
|
||||
assert(ret == ACL_SUCCESS);
|
||||
ret = aclnn##prefix(workspaceAddr, workspaceSize, executor,
|
||||
context->ASCENDHandle());
|
||||
assert(ret == ACL_SUCCESS);
|
||||
ret = aclrtSynchronizeStream(context->ASCENDHandle());
|
||||
assert(ret == ACL_SUCCESS);
|
||||
|
||||
return;
|
||||
} // namespace infini \
|
||||
auto a = op->getInputs(0) -> getDims(); \
|
||||
std::vector<int64_t> aDim(a.size(), 1); \
|
||||
for (size_t i = 0; i < a.size(); ++i) { \
|
||||
aDim[i] = int64_t(a[i]); \
|
||||
} \
|
||||
auto aS = op->getInputs(0) -> getStride(); \
|
||||
std::vector<int64_t> aStride(aS.size(), 1); \
|
||||
for (size_t i = 0; i < aS.size(); ++i) { \
|
||||
aStride[i] = int64_t(aS[i]); \
|
||||
} \
|
||||
auto c = op->getInputs(0) -> getDims(); \
|
||||
std::vector<int64_t> cDim(c.size(), 1); \
|
||||
for (size_t i = 0; i < c.size(); ++i) { \
|
||||
cDim[i] = int64_t(c[i]); \
|
||||
} \
|
||||
auto cS = op->getInputs(0) -> getStride(); \
|
||||
std::vector<int64_t> cStride(cS.size(), 1); \
|
||||
for (size_t i = 0; i < cS.size(); ++i) { \
|
||||
cStride[i] = int64_t(cS[i]); \
|
||||
} \
|
||||
\
|
||||
auto input = aclCreateTensor( \
|
||||
aDim.data(), aDim.size(), ACL_FLOAT, aStride.data(), 0, \
|
||||
aclFormat::ACL_FORMAT_ND, aDim.data(), aDim.size(), aData); \
|
||||
auto output = aclCreateTensor( \
|
||||
cDim.data(), cDim.size(), ACL_FLOAT, cStride.data(), 0, \
|
||||
aclFormat::ACL_FORMAT_ND, cDim.data(), cDim.size(), cData); \
|
||||
\
|
||||
uint64_t workspaceSize = 0; \
|
||||
aclOpExecutor *executor; \
|
||||
\
|
||||
auto ret = aclnn##prefix##GetWorkspaceSize( \
|
||||
input, output, &workspaceSize, &executor); \
|
||||
void *workspaceAddr = nullptr; \
|
||||
if (workspaceSize > 0) { \
|
||||
workspaceAddr = context->getWorkspace(workspaceSize); \
|
||||
} \
|
||||
assert(ret == ACL_SUCCESS); \
|
||||
ret = aclnn##prefix(workspaceAddr, workspaceSize, executor, \
|
||||
context->ASCENDHandle()); \
|
||||
assert(ret == ACL_SUCCESS); \
|
||||
ret = aclrtSynchronizeStream(context->ASCENDHandle()); \
|
||||
assert(ret == ACL_SUCCESS); \
|
||||
\
|
||||
return; \
|
||||
} \
|
||||
};
|
||||
|
||||
DEFINE_UNARY_Aclnn(Abs);
|
||||
|
@ -185,5 +184,4 @@ REGISTER_KERNEL(Device::ASCEND, OpType::Sqrt, SqrtAclnn, "sqrt_ASCEND_float");
|
|||
REGISTER_KERNEL(Device::ASCEND, OpType::Round, RoundAclnn,
|
||||
"round_ASCEND_float");
|
||||
REGISTER_KERNEL(Device::ASCEND, OpType::Erf, ErfAclnn, "erf_ASCEND_float");
|
||||
}
|
||||
; // namespace infini
|
||||
}; // namespace infini
|
||||
|
|
Loading…
Reference in New Issue