forked from jiuyuan/InfiniTensor
Fix the Gemm (MatmulAclnn) and AvgPooling Ascend kernels
This commit is contained in:
parent
47fc0bfa99
commit
907239cf34
|
@ -1 +1 @@
|
||||||
Subproject commit 51d3105277f3774ed31c02ed4cd11fa92925af77
|
Subproject commit b896cec2dba5b8522b141ac4f89eb43074ee1b98
|
|
@ -78,7 +78,7 @@ class MatmulAclnn : public ASCENDKernelWithoutConfig {
|
||||||
|
|
||||||
auto ret = aclnnGemmGetWorkspaceSize(
|
auto ret = aclnnGemmGetWorkspaceSize(
|
||||||
selfTensor, matTensor, biasTensor, alpha, beta, int64_t(transA),
|
selfTensor, matTensor, biasTensor, alpha, beta, int64_t(transA),
|
||||||
int64_t(transB), outputTensor, 0, &workspaceSize, &executor);
|
int64_t(transB), outputTensor, 1, &workspaceSize, &executor);
|
||||||
void *workspaceAddr = nullptr;
|
void *workspaceAddr = nullptr;
|
||||||
if (workspaceSize > 0) {
|
if (workspaceSize > 0) {
|
||||||
workspaceAddr = context->getWorkspace(workspaceSize);
|
workspaceAddr = context->getWorkspace(workspaceSize);
|
||||||
|
@ -87,10 +87,9 @@ class MatmulAclnn : public ASCENDKernelWithoutConfig {
|
||||||
// if (tmp_err_msg != NULL) {
|
// if (tmp_err_msg != NULL) {
|
||||||
// printf(" ERROR Message : %s \n ", tmp_err_msg);
|
// printf(" ERROR Message : %s \n ", tmp_err_msg);
|
||||||
// }
|
// }
|
||||||
CHECK_RET(
|
CHECK_RET(ret == ACL_SUCCESS,
|
||||||
ret == ACL_SUCCESS,
|
LOG_PRINT("aclnnGemmGetWorkspaceSize failed. ERROR: %d\n",
|
||||||
LOG_PRINT("aclnnMatmulGetWorkspaceSize failed. ERROR: %d\n",
|
ret));
|
||||||
ret));
|
|
||||||
ret = aclnnGemm(workspaceAddr, workspaceSize, executor,
|
ret = aclnnGemm(workspaceAddr, workspaceSize, executor,
|
||||||
context->ASCENDHandle());
|
context->ASCENDHandle());
|
||||||
CHECK_RET(ret == ACL_SUCCESS,
|
CHECK_RET(ret == ACL_SUCCESS,
|
||||||
|
|
|
@ -52,7 +52,8 @@ class AvgPooling : public ASCENDKernelWithoutConfig {
|
||||||
|
|
||||||
auto ret = aclnnAvgPool2dGetWorkspaceSize(
|
auto ret = aclnnAvgPool2dGetWorkspaceSize(
|
||||||
selfTensor, kernelSize, strides, paddings, false, true,
|
selfTensor, kernelSize, strides, paddings, false, true,
|
||||||
divisorOverride, 0, outputTensor, &workspaceSize, &executor);
|
divisorOverride, int8_t(1), outputTensor, &workspaceSize,
|
||||||
|
&executor);
|
||||||
assert(ret == ACL_SUCCESS);
|
assert(ret == ACL_SUCCESS);
|
||||||
|
|
||||||
void *workspaceAddr = nullptr;
|
void *workspaceAddr = nullptr;
|
||||||
|
|
Loading…
Reference in New Issue