fix: Matmul 支持 2 维或以上的输入

> 现在能导入 resnet18

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-23 11:51:46 +08:00
parent a27391fcdc
commit 4ffaa44c1e
2 changed files with 27 additions and 20 deletions

View File

@ -8,7 +8,7 @@ from onnx.helper import (
make_tensor_value_info, make_tensor_value_info,
) )
from onnx.checker import check_model from onnx.checker import check_model
from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime, to_onnx from pyinfinitensor.onnx import from_onnx, backend, runtime, to_onnx
def make_and_import_model(graph: onnx.GraphProto): def make_and_import_model(graph: onnx.GraphProto):
@ -28,7 +28,7 @@ class TestStringMethods(unittest.TestCase):
file=model_file, size=os.path.getsize(model_file) / 1024 / 1024 file=model_file, size=os.path.getsize(model_file) / 1024 / 1024
) )
) )
parse_onnx(onnx.load(model_file)) from_onnx(onnx.load(model_file))
def test_tensor(self): def test_tensor(self):
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3]) x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
@ -290,7 +290,6 @@ class TestStringMethods(unittest.TestCase):
model = make_model(graph) model = make_model(graph)
check_model(model) check_model(model)
from_onnx(model) from_onnx(model)
parse_onnx(model)
def test_frontend(self): def test_frontend(self):
handler = backend.GraphHandler(runtime) handler = backend.GraphHandler(runtime)

View File

@@ -5,10 +5,26 @@ namespace infini {
/// Construct a matrix-multiplication operator.
///
/// Accepts inputs of rank >= 2. For rank > 2 the leading dimensions are
/// treated as batch dimensions: they must match element-wise between A and B
/// and are folded into the single batch count `b`. Ranks 0 and 1 are rejected.
///
/// @param graph  owning graph, used by checkValid().
/// @param A      left input, shape [batch..., m, k] ([batch..., k, m] if transA).
/// @param B      right input, shape [batch..., k, n] ([batch..., n, k] if transB).
/// @param C      output tensor, shape [batch..., m, n].
/// @param transA whether A's trailing two dims are transposed.
/// @param transB whether B's trailing two dims are transposed.
/// @param bias   currently unused.
/// @param act    activation fused after the matmul.
MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
                     bool transB, [[maybe_unused]] Tensor bias, ActType act)
    : OperatorObj(OpType::Matmul, {A, B}, {C}), transA(transA), transB(transB),
      act(act), b(1) {
    auto shape_a = A->getDims();
    auto shape_b = B->getDims();
    // A and B must have the same rank (no rank broadcasting here).
    IT_ASSERT(shape_a.size() == shape_b.size());
    switch (shape_a.size()) {
    case 0:
    case 1:
        // Matmul needs at least rank-2 inputs; IT_ASSERT(false) aborts here,
        // so the fall-through into case 2 is never taken.
        IT_ASSERT(false);
    case 2:
        // Plain 2-D matmul: batch count stays 1.
        break;
    default:
        // Rank > 2: all leading dims are batch dims; they must match and are
        // folded into a single batch count.
        for (size_t i = 0; i < shape_a.size() - 2; ++i) {
            IT_ASSERT(shape_a[i] == shape_b[i]);
            b *= shape_a[i];
        }
        break;
    }
    // Read m/n/k from the trailing two dims, honoring the transpose flags
    // (rbegin() is the last dim, rbegin() + 1 the second-to-last).
    m = *(transA ? shape_a.rbegin() : shape_a.rbegin() + 1);
    n = *(transB ? shape_b.rbegin() + 1 : shape_b.rbegin());
    k = *(transA ? shape_a.rbegin() + 1 : shape_a.rbegin());
    IT_ASSERT(checkValid(graph));
}
@@ -22,19 +38,11 @@ string MatmulObj::toString() const {
} }
/// Infer the matmul output shape.
///
/// The output keeps input A's leading (batch) dimensions and replaces its
/// trailing two dimensions with (m, n), which the constructor already derived
/// from the transpose flags. Rank >= 2 is guaranteed by the constructor.
///
/// @param inputs operator inputs; only inputs[0] (A) is consulted.
/// @return the single output shape [batch..., m, n].
optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) const {
    auto shape_a = inputs[0]->getDims();
    // Overwrite the trailing two dims in place: last -> n, second-to-last -> m.
    auto it = shape_a.rbegin();
    *it++ = n;
    *it++ = m;
    return {{std::move(shape_a)}};
}
vector<int> MatmulObj::getWorkloadVector() const { vector<int> MatmulObj::getWorkloadVector() const {