forked from jiuyuan/InfiniTensor
add gemm
This commit is contained in:
parent 775ce5040d
commit 0c94b75a65
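In outline: three-input matmuls now go through aclnnGemm, which honors the alpha/beta scalars and the transA/transB flags, while the two-input case keeps the aclnnMatmul path. Both use the standard two-phase aclnn calling pattern: query the workspace size, launch through the returned executor, then synchronize the stream. A minimal sketch of that pattern under the same setup as the kernel below (illustrative only; the trailing 0 scalar is assumed to be the default cube math type):

    // Phase 1: ask the op how much device workspace it needs.
    uint64_t workspaceSize = 0;
    aclOpExecutor *executor = nullptr;
    auto ret = aclnnGemmGetWorkspaceSize(
        selfTensor, matTensor, biasTensor, /*alpha=*/1.0f, /*beta=*/1.0f,
        /*transA=*/int64_t(0), /*transB=*/int64_t(0), outputTensor,
        /*cubeMathType=*/0, &workspaceSize, &executor);
    CHECK_RET(ret == ACL_SUCCESS,
              LOG_PRINT("aclnnGemmGetWorkspaceSize failed. ERROR: %d\n", ret));

    // Phase 2: launch with runtime-owned scratch, then block on the stream.
    void *workspaceAddr =
        workspaceSize > 0 ? context->getWorkspace(workspaceSize) : nullptr;
    ret = aclnnGemm(workspaceAddr, workspaceSize, executor,
                    context->ASCENDHandle());
    CHECK_RET(ret == ACL_SUCCESS,
              LOG_PRINT("aclnnGemm failed. ERROR: %d\n", ret));
    ret = aclrtSynchronizeStream(context->ASCENDHandle());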
@@ -1,4 +1,5 @@
 #include "operators/matmul.h"
+#include "aclnnop/level2/aclnn_gemm.h"
 #include "aclnnop/level2/aclnn_matmul.h"
 #include "ascend/ascend_kernel_without_config.h"
 #include "ascend/ascend_runtime.h"
@@ -6,15 +7,21 @@
 namespace infini {
 
 class MatmulAclnn : public ASCENDKernelWithoutConfig {
+    // transpose is unsupported for "gemm" without biasInput
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<MatmulObj>(_op);
         auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
 
+        auto input_num = op->numInputs();
 
         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
+        void *biasData = NULL;
+        if (input_num > 2) {
+            biasData = (op->getInputs(2)->getRawDataPtr<void *>());
+        }
 
         auto selfD = op->getInputs(0)->getDims();
         auto selfS = op->getInputs(0)->getStride();
@@ -22,6 +29,12 @@ class MatmulAclnn : public ASCENDKernelWithoutConfig {
         auto matS = op->getInputs(1)->getStride();
         auto outD = op->getOutput()->getDims();
         auto outS = op->getOutput()->getStride();
+        std::vector<int> biasD;
+        std::vector<int> biasS;
+        if (input_num > 2) {
+            biasD = op->getInputs(2)->getDims();
+            biasS = op->getInputs(2)->getStride();
+        }
 
         std::vector<int64_t> selfDim = castTo64(selfD);
         std::vector<int64_t> selfStride = castTo64(selfS);
@@ -29,6 +42,12 @@ class MatmulAclnn : public ASCENDKernelWithoutConfig {
         std::vector<int64_t> matStride = castTo64(matS);
         std::vector<int64_t> outputDim = castTo64(outD);
         std::vector<int64_t> outputStride = castTo64(outS);
+        std::vector<int64_t> biasDim;
+        std::vector<int64_t> biasStride;
+        if (input_num > 2) {
+            biasDim = castTo64(biasD);
+            biasStride = castTo64(biasS);
+        }
 
         auto selfTensor = aclCreateTensor(
             selfDim.data(), selfDim.size(), ACL_FLOAT, selfStride.data(), 0,
@@ -40,29 +59,61 @@ class MatmulAclnn : public ASCENDKernelWithoutConfig {
             aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT,
                             outputStride.data(), 0, aclFormat::ACL_FORMAT_ND,
                             outputDim.data(), outputDim.size(), cData);
+        aclTensor *biasTensor = NULL;
+        if (input_num > 2) {
+            biasTensor =
+                aclCreateTensor(biasDim.data(), biasDim.size(), ACL_FLOAT,
+                                biasStride.data(), 0, aclFormat::ACL_FORMAT_ND,
+                                biasDim.data(), biasDim.size(), biasData);
+        }
 
         uint64_t workspaceSize = 0;
         aclOpExecutor *executor;
 
-        auto ret = aclnnMatmulGetWorkspaceSize(
-            selfTensor, matTensor, outputTensor, 1, &workspaceSize, &executor);
-        void *workspaceAddr = nullptr;
-        if (workspaceSize > 0) {
-            workspaceAddr = context->getWorkspace(workspaceSize);
-        }
-        // auto tmp_err_msg = aclGetRecentErrMsg();
-        // if (tmp_err_msg != NULL) {
-        //     printf(" ERROR Message : %s \n ", tmp_err_msg);
-        // }
-        CHECK_RET(
-            ret == ACL_SUCCESS,
-            LOG_PRINT("aclnnMatmulGetWorkspaceSize failed. ERROR: %d\n", ret));
-        ret = aclnnMatmul(workspaceAddr, workspaceSize, executor,
-                          context->ASCENDHandle());
-        CHECK_RET(ret == ACL_SUCCESS,
-                  LOG_PRINT("aclnnMatmul failed. ERROR: %d\n", ret));
-
-        ret = aclrtSynchronizeStream(context->ASCENDHandle());
+        if (input_num > 2) {
+            float alpha = 1.0;
+            float beta = 1.0;
+            int32_t transA = op->getTransA();
+            int32_t transB = op->getTransB();
+
+            auto ret = aclnnGemmGetWorkspaceSize(
+                selfTensor, matTensor, biasTensor, alpha, beta, int64_t(transA),
+                int64_t(transB), outputTensor, 0, &workspaceSize, &executor);
+            void *workspaceAddr = nullptr;
+            if (workspaceSize > 0) {
+                workspaceAddr = context->getWorkspace(workspaceSize);
+            }
+            // auto tmp_err_msg = aclGetRecentErrMsg();
+            // if (tmp_err_msg != NULL) {
+            //     printf(" ERROR Message : %s \n ", tmp_err_msg);
+            // }
+            CHECK_RET(
+                ret == ACL_SUCCESS,
+                LOG_PRINT("aclnnGemmGetWorkspaceSize failed. ERROR: %d\n",
+                          ret));
+            ret = aclnnGemm(workspaceAddr, workspaceSize, executor,
+                            context->ASCENDHandle());
+            CHECK_RET(ret == ACL_SUCCESS,
+                      LOG_PRINT("aclnnGemm failed. ERROR: %d\n", ret));
+        } else {
+            auto ret =
+                aclnnMatmulGetWorkspaceSize(selfTensor, matTensor, outputTensor,
+                                            1, &workspaceSize, &executor);
+            void *workspaceAddr = nullptr;
+            if (workspaceSize > 0) {
+                workspaceAddr = context->getWorkspace(workspaceSize);
+            }
+            CHECK_RET(
+                ret == ACL_SUCCESS,
+                LOG_PRINT("aclnnMatmulGetWorkspaceSize failed. ERROR: %d\n",
+                          ret));
+
+            ret = aclnnMatmul(workspaceAddr, workspaceSize, executor,
+                              context->ASCENDHandle());
+            CHECK_RET(ret == ACL_SUCCESS,
+                      LOG_PRINT("aclnnMatmul failed. ERROR: %d\n", ret));
+        }
+        auto ret = aclrtSynchronizeStream(context->ASCENDHandle());
         CHECK_RET(ret == ACL_SUCCESS,
                   LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret));
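Design note: with alpha = beta = 1.0 the Gemm path presumably computes op(A) * op(B) + bias in the usual BLAS sense, which matches a Matmul with bias; the bias-less case still falls back to aclnnMatmul, which takes no transpose flags, hence the comment above that transpose is unsupported for "gemm" without biasInput.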