forked from jiuyuan/InfiniTensor
fix: correctly import the bias of conv
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent 025280040f
commit 8dd9f30ef4
@@ -45,8 +45,8 @@ class GraphHandlerObj {
 
     inline OpVec operators() { return g->getOperators(); }
 
-    Tensor conv(Tensor input, Tensor weight, Tensor output, int ph, int pw,
-                int sh, int sw, int dh, int dw);
+    Tensor conv(Tensor input, Tensor weight, Tensor bias, Tensor output, int ph,
+                int pw, int sh, int sw, int dh, int dw);
     Tensor convTransposed2d(Tensor input, Tensor weight, Tensor output, int ph,
                             int pw, int sh, int sw, int dh, int dw, int oph,
                             int opw);

@@ -84,50 +84,18 @@ class OnnxStub:
                 else:
                     adapt = node.input[0]
 
-                if len(node.input) > 2:
-                    bias = "{}-bias".format(node.output[0])
-                    reshape = "{}-reshape".format(node.output[0])
-                    tensors[bias] = self.handler.conv(
-                        tensors[adapt],
-                        tensors[node.input[1]],
-                        None,
-                        p[0],
-                        p[1],
-                        s[0],
-                        s[1],
-                        d[0],
-                        d[1],
-                    )
-                    tensors[reshape] = self.handler.reshape(
-                        tensors[node.input[2]],
-                        None,
-                        [
-                            1,
-                            reduce(
-                                lambda acc, x: acc * x,
-                                _search_shape(model, node.input[2]),
-                            ),
-                            1,
-                            1,
-                        ],
-                    )
-                    tensors[node.output[0]] = self.handler.add(
-                        tensors[bias],
-                        tensors[reshape],
-                        tensors.get(node.output[0]),
-                    )
-                else:
-                    tensors[node.output[0]] = self.handler.conv(
-                        tensors[adapt],
-                        tensors[node.input[1]],
-                        tensors.get(node.output[0]),
-                        p[0],
-                        p[1],
-                        s[0],
-                        s[1],
-                        d[0],
-                        d[1],
-                    )
+                tensors[node.output[0]] = self.handler.conv(
+                    tensors[adapt],
+                    tensors[node.input[1]],
+                    tensors[node.input[2]] if len(node.input) > 2 else None,
+                    tensors.get(node.output[0]),
+                    p[0],
+                    p[1],
+                    s[0],
+                    s[1],
+                    d[0],
+                    d[1],
+                )
             elif node.op_type == "ConvTranspose":
                 attributes = _parse_attribute(
                     node,

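With this hunk, the ONNX importer forwards a Conv node's optional third input directly as the bias argument of handler.conv, instead of emulating it with an extra Reshape and Add pair. A minimal sketch of a model that exercises the new path (assuming onnx and numpy are installed and the pyinfinitensor package has been built; the import path for OnnxStub and backend is an assumption, taken from their use in the test hunk below):

    import numpy as np
    import onnx
    from onnx import TensorProto, helper
    from pyinfinitensor.onnx import OnnxStub, backend  # import path assumed

    # Conv node with three inputs: data, weight and the optional bias.
    conv = helper.make_node("Conv", ["x", "w", "b"], ["y"], pads=[1, 1, 1, 1])
    graph = helper.make_graph(
        [conv],
        "conv_with_bias",
        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 32, 32])],
        [helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 8, 32, 32])],
        [
            helper.make_tensor(
                "w", TensorProto.FLOAT, [8, 3, 3, 3],
                np.random.rand(8, 3, 3, 3).astype(np.float32).flatten().tolist(),
            ),
            helper.make_tensor(
                "b", TensorProto.FLOAT, [8],
                np.random.rand(8).astype(np.float32).flatten().tolist(),
            ),
        ],
    )
    # The importer now emits a single conv op whose bias input is tensors["b"].
    stub = OnnxStub(helper.make_model(graph), backend.cpu_runtime())
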
@@ -43,6 +43,8 @@ class TestStringMethods(unittest.TestCase):
         model = OnnxStub(onnx.load(model_file), backend.cpu_runtime()).to_onnx(
             "new"
         )
+        # with open("modified.onnx", "wb") as f:
+        #     f.write(model.SerializeToString())
         model = infer_shapes(model)
 
     def test_tensor(self):

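The two commented-out lines added to the test are a debugging aid: enabled, they dump the re-exported model to disk. Continuing the hypothetical sketch above (the name stub comes from that sketch, not from this commit):

    # Export the imported graph back to ONNX and write it out so it can be
    # inspected offline, e.g. with Netron; this mirrors the commented lines.
    exported = stub.to_onnx("new")
    with open("modified.onnx", "wb") as f:
        f.write(exported.SerializeToString())
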
@@ -23,16 +23,18 @@ Tensor GraphHandlerObj::tensor(Shape dims, int dtype) {
     return g->addTensor(std::move(dims), dtype_repr_convert(dtype));
 }
 
-Tensor GraphHandlerObj::conv(Tensor input, Tensor weight, Tensor output, int ph,
-                             int pw, int sh, int sw, int dh, int dw) {
+Tensor GraphHandlerObj::conv(Tensor input, Tensor weight, Tensor bias,
+                             Tensor output, int ph, int pw, int sh, int sw,
+                             int dh, int dw) {
     if (output) {
         g->addOpWithOutputs<ConvObj>(std::move(input), std::move(weight),
-                                     output, ph, pw, sh, sw, dh, dw);
+                                     output, ph, pw, sh, sw, dh, dw, bias,
+                                     ActType::None);
         return output;
     } else {
         return g
             ->addOp<ConvObj>(std::move(input), std::move(weight), output, ph,
-                             pw, sh, sw, dh, dw)
+                             pw, sh, sw, dh, dw, bias, ActType::None)
             ->getOutput();
     }
 }

@@ -64,10 +64,10 @@ void ConvObj::setAuxilaryAttributes(PaddingMode mode) {
 ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
                  int ph, int pw, int sh, int sw, int dh, int dw, Tensor bias,
                  ActType act)
-    : ConvBaseObj(OpType::Conv, {input, weight}, output, ph, pw, sh, sw, dh, dw,
-                  input, weight, act) {
-    if (bias)
-        IT_TODO_HALT();
+    : ConvBaseObj(OpType::Conv,
+                  bias ? TensorVec{input, weight, bias}
+                       : TensorVec{input, weight},
+                  output, ph, pw, sh, sw, dh, dw, input, weight, act) {
     setAuxilaryAttributes(PaddingMode::Other);
     IT_ASSERT(checkValid(graph));
 }

@@ -75,10 +75,10 @@ ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
 ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
                  PaddingMode mode, int sh, int sw, int dh, int dw, Tensor bias,
                  ActType act)
-    : ConvBaseObj(OpType::Conv, {input, weight}, output, mode, sh, sw, dh, dw,
-                  input, weight, act) {
-    if (bias)
-        IT_TODO_HALT();
+    : ConvBaseObj(OpType::Conv,
+                  bias ? TensorVec{input, weight, bias}
+                       : TensorVec{input, weight},
+                  output, mode, sh, sw, dh, dw, input, weight, act) {
     setAuxilaryAttributes(mode);
     IT_ASSERT(checkValid(graph));
 }

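The two constructor hunks are what actually accept the bias: instead of aborting with IT_TODO_HALT(), ConvObj now registers the bias as a third operator input when it is present. As a rough, hypothetical end-to-end check (reusing the sketch above and assuming the exporter in to_onnx copes with a three-input Conv), the model built earlier should round-trip as one Conv node rather than the old Conv + Reshape + Add chain:

    # Hypothetical check, not part of this commit's tests.
    print([n.op_type for n in exported.graph.node])  # expected: ["Conv"]
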