forked from jiuyuan/InfiniTensor
fix: 正确导入 conv 的 bias
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
025280040f
commit
8dd9f30ef4
|
@ -45,8 +45,8 @@ class GraphHandlerObj {
|
|||
|
||||
inline OpVec operators() { return g->getOperators(); }
|
||||
|
||||
Tensor conv(Tensor input, Tensor weight, Tensor output, int ph, int pw,
|
||||
int sh, int sw, int dh, int dw);
|
||||
Tensor conv(Tensor input, Tensor weight, Tensor bias, Tensor output, int ph,
|
||||
int pw, int sh, int sw, int dh, int dw);
|
||||
Tensor convTransposed2d(Tensor input, Tensor weight, Tensor output, int ph,
|
||||
int pw, int sh, int sw, int dh, int dw, int oph,
|
||||
int opw);
|
||||
|
|
|
@ -84,50 +84,18 @@ class OnnxStub:
|
|||
else:
|
||||
adapt = node.input[0]
|
||||
|
||||
if len(node.input) > 2:
|
||||
bias = "{}-bias".format(node.output[0])
|
||||
reshape = "{}-reshape".format(node.output[0])
|
||||
tensors[bias] = self.handler.conv(
|
||||
tensors[adapt],
|
||||
tensors[node.input[1]],
|
||||
None,
|
||||
p[0],
|
||||
p[1],
|
||||
s[0],
|
||||
s[1],
|
||||
d[0],
|
||||
d[1],
|
||||
)
|
||||
tensors[reshape] = self.handler.reshape(
|
||||
tensors[node.input[2]],
|
||||
None,
|
||||
[
|
||||
1,
|
||||
reduce(
|
||||
lambda acc, x: acc * x,
|
||||
_search_shape(model, node.input[2]),
|
||||
),
|
||||
1,
|
||||
1,
|
||||
],
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.add(
|
||||
tensors[bias],
|
||||
tensors[reshape],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
else:
|
||||
tensors[node.output[0]] = self.handler.conv(
|
||||
tensors[adapt],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
p[0],
|
||||
p[1],
|
||||
s[0],
|
||||
s[1],
|
||||
d[0],
|
||||
d[1],
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.conv(
|
||||
tensors[adapt],
|
||||
tensors[node.input[1]],
|
||||
tensors[node.input[2]] if len(node.input) > 2 else None,
|
||||
tensors.get(node.output[0]),
|
||||
p[0],
|
||||
p[1],
|
||||
s[0],
|
||||
s[1],
|
||||
d[0],
|
||||
d[1],
|
||||
)
|
||||
elif node.op_type == "ConvTranspose":
|
||||
attributes = _parse_attribute(
|
||||
node,
|
||||
|
|
|
@ -43,6 +43,8 @@ class TestStringMethods(unittest.TestCase):
|
|||
model = OnnxStub(onnx.load(model_file), backend.cpu_runtime()).to_onnx(
|
||||
"new"
|
||||
)
|
||||
# with open("modified.onnx", "wb") as f:
|
||||
# f.write(model.SerializeToString())
|
||||
model = infer_shapes(model)
|
||||
|
||||
def test_tensor(self):
|
||||
|
|
|
@ -23,16 +23,18 @@ Tensor GraphHandlerObj::tensor(Shape dims, int dtype) {
|
|||
return g->addTensor(std::move(dims), dtype_repr_convert(dtype));
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::conv(Tensor input, Tensor weight, Tensor output, int ph,
|
||||
int pw, int sh, int sw, int dh, int dw) {
|
||||
Tensor GraphHandlerObj::conv(Tensor input, Tensor weight, Tensor bias,
|
||||
Tensor output, int ph, int pw, int sh, int sw,
|
||||
int dh, int dw) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<ConvObj>(std::move(input), std::move(weight),
|
||||
output, ph, pw, sh, sw, dh, dw);
|
||||
output, ph, pw, sh, sw, dh, dw, bias,
|
||||
ActType::None);
|
||||
return output;
|
||||
} else {
|
||||
return g
|
||||
->addOp<ConvObj>(std::move(input), std::move(weight), output, ph,
|
||||
pw, sh, sw, dh, dw)
|
||||
pw, sh, sw, dh, dw, bias, ActType::None)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -64,10 +64,10 @@ void ConvObj::setAuxilaryAttributes(PaddingMode mode) {
|
|||
ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
|
||||
int ph, int pw, int sh, int sw, int dh, int dw, Tensor bias,
|
||||
ActType act)
|
||||
: ConvBaseObj(OpType::Conv, {input, weight}, output, ph, pw, sh, sw, dh, dw,
|
||||
input, weight, act) {
|
||||
if (bias)
|
||||
IT_TODO_HALT();
|
||||
: ConvBaseObj(OpType::Conv,
|
||||
bias ? TensorVec{input, weight, bias}
|
||||
: TensorVec{input, weight},
|
||||
output, ph, pw, sh, sw, dh, dw, input, weight, act) {
|
||||
setAuxilaryAttributes(PaddingMode::Other);
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
@ -75,10 +75,10 @@ ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
|
|||
ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
|
||||
PaddingMode mode, int sh, int sw, int dh, int dw, Tensor bias,
|
||||
ActType act)
|
||||
: ConvBaseObj(OpType::Conv, {input, weight}, output, mode, sh, sw, dh, dw,
|
||||
input, weight, act) {
|
||||
if (bias)
|
||||
IT_TODO_HALT();
|
||||
: ConvBaseObj(OpType::Conv,
|
||||
bias ? TensorVec{input, weight, bias}
|
||||
: TensorVec{input, weight},
|
||||
output, mode, sh, sw, dh, dw, input, weight, act) {
|
||||
setAuxilaryAttributes(mode);
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue