fix gemm & avgpooling

This commit is contained in:
OdinaryWord 2024-04-29 16:10:32 +08:00
parent 47fc0bfa99
commit 907239cf34
3 changed files with 7 additions and 7 deletions

@ -1 +1 @@
Subproject commit 51d3105277f3774ed31c02ed4cd11fa92925af77 Subproject commit b896cec2dba5b8522b141ac4f89eb43074ee1b98

View File

@ -78,7 +78,7 @@ class MatmulAclnn : public ASCENDKernelWithoutConfig {
auto ret = aclnnGemmGetWorkspaceSize( auto ret = aclnnGemmGetWorkspaceSize(
selfTensor, matTensor, biasTensor, alpha, beta, int64_t(transA), selfTensor, matTensor, biasTensor, alpha, beta, int64_t(transA),
int64_t(transB), outputTensor, 0, &workspaceSize, &executor); int64_t(transB), outputTensor, 1, &workspaceSize, &executor);
void *workspaceAddr = nullptr; void *workspaceAddr = nullptr;
if (workspaceSize > 0) { if (workspaceSize > 0) {
workspaceAddr = context->getWorkspace(workspaceSize); workspaceAddr = context->getWorkspace(workspaceSize);
@ -87,10 +87,9 @@ class MatmulAclnn : public ASCENDKernelWithoutConfig {
// if (tmp_err_msg != NULL) { // if (tmp_err_msg != NULL) {
// printf(" ERROR Message : %s \n ", tmp_err_msg); // printf(" ERROR Message : %s \n ", tmp_err_msg);
// } // }
CHECK_RET( CHECK_RET(ret == ACL_SUCCESS,
ret == ACL_SUCCESS, LOG_PRINT("aclnnGemmGetWorkspaceSize failed. ERROR: %d\n",
LOG_PRINT("aclnnMatmulGetWorkspaceSize failed. ERROR: %d\n", ret));
ret));
ret = aclnnGemm(workspaceAddr, workspaceSize, executor, ret = aclnnGemm(workspaceAddr, workspaceSize, executor,
context->ASCENDHandle()); context->ASCENDHandle());
CHECK_RET(ret == ACL_SUCCESS, CHECK_RET(ret == ACL_SUCCESS,

View File

@ -52,7 +52,8 @@ class AvgPooling : public ASCENDKernelWithoutConfig {
auto ret = aclnnAvgPool2dGetWorkspaceSize( auto ret = aclnnAvgPool2dGetWorkspaceSize(
selfTensor, kernelSize, strides, paddings, false, true, selfTensor, kernelSize, strides, paddings, false, true,
divisorOverride, 0, outputTensor, &workspaceSize, &executor); divisorOverride, int8_t(1), outputTensor, &workspaceSize,
&executor);
assert(ret == ACL_SUCCESS); assert(ret == ACL_SUCCESS);
void *workspaceAddr = nullptr; void *workspaceAddr = nullptr;