This commit is contained in:
OdinaryWord 2024-05-15 22:29:52 +08:00
parent 5747eb8f7d
commit a8443741c4
4 changed files with 7 additions and 5 deletions

View File

@ -35,10 +35,10 @@ class BatchNormAclnn : public ASCENDKernelWithoutConfig {
auto inputTensor = aclCreateTensor(
inputDim.data(), inputDim.size(), ACL_FLOAT, inputStride.data(), 0,
aclFormat::ACL_FORMAT_ND, inputDim.data(), inputDim.size(), inData);
aclFormat::ACL_FORMAT_NCHW, inputDim.data(), inputDim.size(), inData);
auto outputTensor =
aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT,
outputStride.data(), 0, aclFormat::ACL_FORMAT_ND,
outputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW,
outputDim.data(), outputDim.size(), outData);
auto meanTensor = aclCreateTensor(
paraDim.data(), paraDim.size(), ACL_FLOAT, paraStride.data(), 0,

View File

@ -52,7 +52,7 @@ class AvgPooling : public ASCENDKernelWithoutConfig {
auto ret = aclnnAvgPool2dGetWorkspaceSize(
selfTensor, kernelSize, strides, paddings, false, true,
divisorOverride, int8_t(1), outputTensor, &workspaceSize,
divisorOverride, int8_t(0), outputTensor, &workspaceSize,
&executor);
assert(ret == ACL_SUCCESS);

View File

@ -46,11 +46,13 @@ void testConvTransposedAclnn(
}
TEST(ascend_ConvTransposed, run) {
aclInit(nullptr);
testConvTransposedAclnn(
IncrementalGenerator(),
std::vector<float>{0., 0., 1., 2., 3., 0., 6., 12., 18.,
16., 8., 30., 36., 42., 32., 16., 54., 60.,
66., 48., 24., 62., 67., 72., 45.});
aclFinalize();
}
} // namespace infini

View File

@ -39,10 +39,10 @@ void testPooling(const std::function<void(void *, size_t, DataType)> &generator,
}
TEST(cnnl_Pooling, run) {
// aclInit(nullptr);
aclInit(nullptr);
// testPooling<MaxPoolObj>(IncrementalGenerator(), Shape{1, 2, 5, 5});
testPooling<AvgPoolObj>(IncrementalGenerator(), Shape{1, 2, 5, 5});
// aclFinalize();
aclFinalize();
}
} // namespace infini