fix: 正确导入 conv 的 bias

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-06-27 17:37:21 +08:00
parent 025280040f
commit 8dd9f30ef4
5 changed files with 30 additions and 58 deletions

View File

@ -45,8 +45,8 @@ class GraphHandlerObj {
inline OpVec operators() { return g->getOperators(); }
Tensor conv(Tensor input, Tensor weight, Tensor output, int ph, int pw,
int sh, int sw, int dh, int dw);
Tensor conv(Tensor input, Tensor weight, Tensor bias, Tensor output, int ph,
int pw, int sh, int sw, int dh, int dw);
Tensor convTransposed2d(Tensor input, Tensor weight, Tensor output, int ph,
int pw, int sh, int sw, int dh, int dw, int oph,
int opw);

View File

@ -84,50 +84,18 @@ class OnnxStub:
else:
adapt = node.input[0]
if len(node.input) > 2:
bias = "{}-bias".format(node.output[0])
reshape = "{}-reshape".format(node.output[0])
tensors[bias] = self.handler.conv(
tensors[adapt],
tensors[node.input[1]],
None,
p[0],
p[1],
s[0],
s[1],
d[0],
d[1],
)
tensors[reshape] = self.handler.reshape(
tensors[node.input[2]],
None,
[
1,
reduce(
lambda acc, x: acc * x,
_search_shape(model, node.input[2]),
),
1,
1,
],
)
tensors[node.output[0]] = self.handler.add(
tensors[bias],
tensors[reshape],
tensors.get(node.output[0]),
)
else:
tensors[node.output[0]] = self.handler.conv(
tensors[adapt],
tensors[node.input[1]],
tensors.get(node.output[0]),
p[0],
p[1],
s[0],
s[1],
d[0],
d[1],
)
tensors[node.output[0]] = self.handler.conv(
tensors[adapt],
tensors[node.input[1]],
tensors[node.input[2]] if len(node.input) > 2 else None,
tensors.get(node.output[0]),
p[0],
p[1],
s[0],
s[1],
d[0],
d[1],
)
elif node.op_type == "ConvTranspose":
attributes = _parse_attribute(
node,

View File

@ -43,6 +43,8 @@ class TestStringMethods(unittest.TestCase):
model = OnnxStub(onnx.load(model_file), backend.cpu_runtime()).to_onnx(
"new"
)
# with open("modified.onnx", "wb") as f:
# f.write(model.SerializeToString())
model = infer_shapes(model)
def test_tensor(self):

View File

@ -23,16 +23,18 @@ Tensor GraphHandlerObj::tensor(Shape dims, int dtype) {
return g->addTensor(std::move(dims), dtype_repr_convert(dtype));
}
Tensor GraphHandlerObj::conv(Tensor input, Tensor weight, Tensor output, int ph,
int pw, int sh, int sw, int dh, int dw) {
Tensor GraphHandlerObj::conv(Tensor input, Tensor weight, Tensor bias,
Tensor output, int ph, int pw, int sh, int sw,
int dh, int dw) {
if (output) {
g->addOpWithOutputs<ConvObj>(std::move(input), std::move(weight),
output, ph, pw, sh, sw, dh, dw);
output, ph, pw, sh, sw, dh, dw, bias,
ActType::None);
return output;
} else {
return g
->addOp<ConvObj>(std::move(input), std::move(weight), output, ph,
pw, sh, sw, dh, dw)
pw, sh, sw, dh, dw, bias, ActType::None)
->getOutput();
}
}

View File

@ -64,10 +64,10 @@ void ConvObj::setAuxilaryAttributes(PaddingMode mode) {
ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
int ph, int pw, int sh, int sw, int dh, int dw, Tensor bias,
ActType act)
: ConvBaseObj(OpType::Conv, {input, weight}, output, ph, pw, sh, sw, dh, dw,
input, weight, act) {
if (bias)
IT_TODO_HALT();
: ConvBaseObj(OpType::Conv,
bias ? TensorVec{input, weight, bias}
: TensorVec{input, weight},
output, ph, pw, sh, sw, dh, dw, input, weight, act) {
setAuxilaryAttributes(PaddingMode::Other);
IT_ASSERT(checkValid(graph));
}
@ -75,10 +75,10 @@ ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
PaddingMode mode, int sh, int sw, int dh, int dw, Tensor bias,
ActType act)
: ConvBaseObj(OpType::Conv, {input, weight}, output, mode, sh, sw, dh, dw,
input, weight, act) {
if (bias)
IT_TODO_HALT();
: ConvBaseObj(OpType::Conv,
bias ? TensorVec{input, weight, bias}
: TensorVec{input, weight},
output, mode, sh, sw, dh, dw, input, weight, act) {
setAuxilaryAttributes(mode);
IT_ASSERT(checkValid(graph));
}