forked from jiuyuan/InfiniTensor
fix op
This commit is contained in:
parent
5747eb8f7d
commit
a8443741c4
|
@ -35,10 +35,10 @@ class BatchNormAclnn : public ASCENDKernelWithoutConfig {
|
|||
|
||||
auto inputTensor = aclCreateTensor(
|
||||
inputDim.data(), inputDim.size(), ACL_FLOAT, inputStride.data(), 0,
|
||||
aclFormat::ACL_FORMAT_ND, inputDim.data(), inputDim.size(), inData);
|
||||
aclFormat::ACL_FORMAT_NCHW, inputDim.data(), inputDim.size(), inData);
|
||||
auto outputTensor =
|
||||
aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT,
|
||||
outputStride.data(), 0, aclFormat::ACL_FORMAT_ND,
|
||||
outputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW,
|
||||
outputDim.data(), outputDim.size(), outData);
|
||||
auto meanTensor = aclCreateTensor(
|
||||
paraDim.data(), paraDim.size(), ACL_FLOAT, paraStride.data(), 0,
|
||||
|
|
|
@ -52,7 +52,7 @@ class AvgPooling : public ASCENDKernelWithoutConfig {
|
|||
|
||||
auto ret = aclnnAvgPool2dGetWorkspaceSize(
|
||||
selfTensor, kernelSize, strides, paddings, false, true,
|
||||
divisorOverride, int8_t(1), outputTensor, &workspaceSize,
|
||||
divisorOverride, int8_t(0), outputTensor, &workspaceSize,
|
||||
&executor);
|
||||
assert(ret == ACL_SUCCESS);
|
||||
|
||||
|
|
|
@ -46,11 +46,13 @@ void testConvTransposedAclnn(
|
|||
}
|
||||
|
||||
TEST(ascend_ConvTransposed, run) {
|
||||
aclInit(nullptr);
|
||||
testConvTransposedAclnn(
|
||||
IncrementalGenerator(),
|
||||
std::vector<float>{0., 0., 1., 2., 3., 0., 6., 12., 18.,
|
||||
16., 8., 30., 36., 42., 32., 16., 54., 60.,
|
||||
66., 48., 24., 62., 67., 72., 45.});
|
||||
aclFinalize();
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -39,10 +39,10 @@ void testPooling(const std::function<void(void *, size_t, DataType)> &generator,
|
|||
}
|
||||
|
||||
TEST(cnnl_Pooling, run) {
|
||||
// aclInit(nullptr);
|
||||
aclInit(nullptr);
|
||||
// testPooling<MaxPoolObj>(IncrementalGenerator(), Shape{1, 2, 5, 5});
|
||||
testPooling<AvgPoolObj>(IncrementalGenerator(), Shape{1, 2, 5, 5});
|
||||
// aclFinalize();
|
||||
aclFinalize();
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
Loading…
Reference in New Issue