forked from jiuyuan/InfiniTensor
fix: 修正 batchNorm 实现
- onnx 和 pytorch 认为 batchNorm 的 4 个参数是 [c] 形状的,而 cuDNN 可能认为是 [1,c,1,...]。优化已改为 [c],但 cuDNN 推理没有改。 Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
dd5d091dbc
commit
a27391fcdc
|
@ -66,10 +66,10 @@ class TestStringMethods(unittest.TestCase):
|
|||
|
||||
def test_batch_norm(self):
|
||||
x = make_tensor_value_info("x", TensorProto.UINT32, [1, 3, 2, 2])
|
||||
scale = make_tensor_value_info("scale", TensorProto.FLOAT, [1, 3, 1, 1])
|
||||
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 1, 1])
|
||||
mean = make_tensor_value_info("mean", TensorProto.FLOAT, [1, 3, 1, 1])
|
||||
var = make_tensor_value_info("var", TensorProto.FLOAT, [1, 3, 1, 1])
|
||||
scale = make_tensor_value_info("scale", TensorProto.FLOAT, [3])
|
||||
b = make_tensor_value_info("b", TensorProto.FLOAT, [3])
|
||||
mean = make_tensor_value_info("mean", TensorProto.FLOAT, [3])
|
||||
var = make_tensor_value_info("var", TensorProto.FLOAT, [3])
|
||||
y = make_tensor_value_info("y", TensorProto.UINT32, [1, 3, 2, 2])
|
||||
batch_norm = make_node(
|
||||
"BatchNormalization",
|
||||
|
|
|
@ -19,12 +19,9 @@ BatchNormObj::inferShape(const TensorVec &inputs) const {
|
|||
auto var = inputs[2];
|
||||
auto scale = inputs[3];
|
||||
auto bias = inputs[4];
|
||||
if (input->getDims().size() < 2)
|
||||
return {};
|
||||
Shape dims(input->getDims().size(), 1);
|
||||
dims[1] = input->getDims()[1]; //
|
||||
if (mean->getDims() != dims || var->getDims() != dims ||
|
||||
scale->getDims() != dims || bias->getDims() != dims)
|
||||
auto c = std::vector<int>{input->getDims()[1]};
|
||||
if (mean->getDims() != c || var->getDims() != c || scale->getDims() != c ||
|
||||
bias->getDims() != c)
|
||||
return {};
|
||||
return {{input->getDims()}};
|
||||
}
|
||||
|
|
|
@ -9,10 +9,10 @@ TEST(BatchNorm, ShapeInference) {
|
|||
{
|
||||
Graph g = make_ref<GraphObj>(cpuRuntime);
|
||||
Tensor i = g->addTensor({1, 3, 2, 2}, DataType::UInt32);
|
||||
Tensor mean = g->addTensor({1, 3, 1, 1}, DataType::Float32);
|
||||
Tensor var = g->addTensor({1, 3, 1, 1}, DataType::Float32);
|
||||
Tensor scaler = g->addTensor({1, 3, 1, 1}, DataType::Float32);
|
||||
Tensor bias = g->addTensor({1, 3, 1, 1}, DataType::Float32);
|
||||
Tensor mean = g->addTensor({3}, DataType::Float32);
|
||||
Tensor var = g->addTensor({3}, DataType::Float32);
|
||||
Tensor scaler = g->addTensor({3}, DataType::Float32);
|
||||
Tensor bias = g->addTensor({3}, DataType::Float32);
|
||||
auto op = g->addOp<BatchNormObj>(i, nullptr, mean, var, scaler, bias,
|
||||
0.9, 1e-5);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 2}));
|
||||
|
|
Loading…
Reference in New Issue