diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py
index 0db20fc9..8be02f11 100644
--- a/pyinfinitensor/tests/test_onnx.py
+++ b/pyinfinitensor/tests/test_onnx.py
@@ -8,7 +8,7 @@ from onnx.helper import (
     make_tensor_value_info,
 )
 from onnx.checker import check_model
-from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime, to_onnx
+from pyinfinitensor.onnx import from_onnx, backend, runtime, to_onnx
 
 
 def make_and_import_model(graph: onnx.GraphProto):
@@ -28,7 +28,7 @@ class TestStringMethods(unittest.TestCase):
                    file=model_file, size=os.path.getsize(model_file) / 1024 / 1024
                )
            )
-            parse_onnx(onnx.load(model_file))
+            from_onnx(onnx.load(model_file))
 
     def test_tensor(self):
         x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
@@ -290,7 +290,6 @@ class TestStringMethods(unittest.TestCase):
         model = make_model(graph)
         check_model(model)
         from_onnx(model)
-        parse_onnx(model)
 
     def test_frontend(self):
         handler = backend.GraphHandler(runtime)
diff --git a/src/operators/matmul.cc b/src/operators/matmul.cc
index 8e82ccc7..07708d07 100644
--- a/src/operators/matmul.cc
+++ b/src/operators/matmul.cc
@@ -5,10 +5,26 @@ namespace infini {
 MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
                      bool transB, [[maybe_unused]] Tensor bias, ActType act)
     : OperatorObj(OpType::Matmul, {A, B}, {C}), transA(transA), transB(transB),
-      act(act), b(A->getDims()[0]),
-      m(transA ? A->getDims()[2] : A->getDims()[1]),
-      n(transB ? B->getDims()[1] : B->getDims()[2]),
-      k(transA ? A->getDims()[1] : A->getDims()[2]) {
+      act(act), b(1) {
+    auto shape_a = A->getDims();
+    auto shape_b = B->getDims();
+    IT_ASSERT(shape_a.size() == shape_b.size());
+    switch (shape_a.size()) {
+    case 0:
+    case 1:
+        IT_ASSERT(false);
+    case 2:
+        break;
+    default:
+        for (size_t i = 0; i < shape_a.size() - 2; ++i) {
+            IT_ASSERT(shape_a[i] == shape_b[i]);
+            b *= shape_a[i];
+        }
+        break;
+    }
+    m = *(transA ? shape_a.rbegin() : shape_a.rbegin() + 1);
+    n = *(transB ? shape_b.rbegin() + 1 : shape_b.rbegin());
+    k = *(transA ? shape_a.rbegin() + 1 : shape_a.rbegin());
     IT_ASSERT(checkValid(graph));
 }
 
@@ -22,19 +38,11 @@ string MatmulObj::toString() const {
 }
 
 optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) const {
-    auto A = inputs[0], B = inputs[1];
-    // if (A->getType() == Tensor::Weight && B->getType() == Tensor::Weight)
-    //     return false;
-    if (!(A->getDims().size() == 3 && B->getDims().size() == 3))
-        return {};
-    if (!(A->getDims()[0] == B->getDims()[0]))
-        return {};
-    if (!((transA ? A->getDims()[1] : A->getDims()[2]) ==
-          (transB ? B->getDims()[2] : B->getDims()[1])))
-        return {};
-    int b(A->getDims()[0]), m(transA ? A->getDims()[2] : A->getDims()[1]),
-        n(transB ? B->getDims()[1] : B->getDims()[2]);
-    return {{{b, m, n}}};
+    auto shape_a = inputs[0]->getDims();
+    auto it = shape_a.rbegin();
+    *it++ = n;
+    *it++ = m;
+    return {{std::move(shape_a)}};
 }
 
 vector<int> MatmulObj::getWorkloadVector() const {