fix: Matmul supports inputs with 2 or more dimensions

> resnet18 can now be imported

Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster 2023-02-23 11:51:46 +08:00
parent a27391fcdc
commit 4ffaa44c1e
2 changed files with 27 additions and 20 deletions
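
A minimal sketch (not part of this commit), written in the style of the existing tests, of a model that the new code accepts and the old rank-3-only check rejected; the tensor names and shapes are illustrative, while `from_onnx` is the single-argument entry point used in the tests below:

```python
from onnx import TensorProto
from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
from onnx.checker import check_model
from pyinfinitensor.onnx import from_onnx

# A plain 2-D MatMul; the previous shape inference only handled rank-3 inputs.
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 512])
b = make_tensor_value_info("b", TensorProto.FLOAT, [512, 1000])
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 1000])
node = make_node("MatMul", ["a", "b"], ["y"], name="matmul_2d")
model = make_model(make_graph([node], "matmul_2d", [a, b], [y]))
check_model(model)
from_onnx(model)  # now accepted: both inputs rank >= 2, equal ranks, matching batch dims
```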


@@ -8,7 +8,7 @@ from onnx.helper import (
make_tensor_value_info,
)
from onnx.checker import check_model
-from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime, to_onnx
+from pyinfinitensor.onnx import from_onnx, backend, runtime, to_onnx
def make_and_import_model(graph: onnx.GraphProto):
@@ -28,7 +28,7 @@ class TestStringMethods(unittest.TestCase):
file=model_file, size=os.path.getsize(model_file) / 1024 / 1024
)
)
-parse_onnx(onnx.load(model_file))
+from_onnx(onnx.load(model_file))
def test_tensor(self):
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
@@ -290,7 +290,6 @@ class TestStringMethods(unittest.TestCase):
model = make_model(graph)
check_model(model)
from_onnx(model)
-parse_onnx(model)
def test_frontend(self):
handler = backend.GraphHandler(runtime)


@@ -5,10 +5,26 @@ namespace infini {
MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
bool transB, [[maybe_unused]] Tensor bias, ActType act)
: OperatorObj(OpType::Matmul, {A, B}, {C}), transA(transA), transB(transB),
-      act(act), b(A->getDims()[0]),
-      m(transA ? A->getDims()[2] : A->getDims()[1]),
-      n(transB ? B->getDims()[1] : B->getDims()[2]),
-      k(transA ? A->getDims()[1] : A->getDims()[2]) {
+      act(act), b(1) {
+    auto shape_a = A->getDims();
+    auto shape_b = B->getDims();
+    IT_ASSERT(shape_a.size() == shape_b.size());
+    switch (shape_a.size()) {
+    case 0:
+    case 1:
+        IT_ASSERT(false);
+    case 2:
+        break;
+    default:
+        for (size_t i = 0; i < shape_a.size() - 2; ++i) {
+            IT_ASSERT(shape_a[i] == shape_b[i]);
+            b *= shape_a[i];
+        }
+        break;
+    }
+    m = *(transA ? shape_a.rbegin() : shape_a.rbegin() + 1);
+    n = *(transB ? shape_b.rbegin() + 1 : shape_b.rbegin());
+    k = *(transA ? shape_a.rbegin() + 1 : shape_a.rbegin());
IT_ASSERT(checkValid(graph));
}
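
The new constructor treats every dimension except the last two as a batch dimension: matching leading dims of A and B are folded into b, and m, n, k come from the trailing two dims, honoring transA/transB. A rough Python mirror of that computation, for illustration only (not the project's API):

```python
from functools import reduce
from operator import mul

def matmul_dims(shape_a, shape_b, transA=False, transB=False):
    assert len(shape_a) == len(shape_b) >= 2
    assert shape_a[:-2] == shape_b[:-2]      # leading (batch) dims must match
    b = reduce(mul, shape_a[:-2], 1)         # fold all leading dims into b
    m = shape_a[-1] if transA else shape_a[-2]
    n = shape_b[-2] if transB else shape_b[-1]
    k = shape_a[-2] if transA else shape_a[-1]
    return b, m, n, k

# A: [2, 3, 4, 5], B: [2, 3, 5, 6]  ->  b = 2*3 = 6, m = 4, n = 6, k = 5
print(matmul_dims([2, 3, 4, 5], [2, 3, 5, 6]))
```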
@@ -22,19 +38,11 @@ string MatmulObj::toString() const {
}
optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) const {
-    auto A = inputs[0], B = inputs[1];
-    // if (A->getType() == Tensor::Weight && B->getType() == Tensor::Weight)
-    //     return false;
-    if (!(A->getDims().size() == 3 && B->getDims().size() == 3))
-        return {};
-    if (!(A->getDims()[0] == B->getDims()[0]))
-        return {};
-    if (!((transA ? A->getDims()[1] : A->getDims()[2]) ==
-          (transB ? B->getDims()[2] : B->getDims()[1])))
-        return {};
-    int b(A->getDims()[0]), m(transA ? A->getDims()[2] : A->getDims()[1]),
-        n(transB ? B->getDims()[1] : B->getDims()[2]);
-    return {{{b, m, n}}};
+    auto shape_a = inputs[0]->getDims();
+    auto it = shape_a.rbegin();
+    *it++ = n;
+    *it++ = m;
+    return {{std::move(shape_a)}};
}
vector<int> MatmulObj::getWorkloadVector() const {
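
inferShape now reuses A's shape for the output: every leading (batch) dimension is kept, and the reverse iterator overwrites the last dim with n and the second-to-last with m. Roughly, in Python terms (an illustrative sketch, not project code):

```python
def matmul_infer_shape(shape_a, m, n):
    out = list(shape_a)   # keep every leading (batch) dim of A
    out[-1] = n           # last dim        <- n
    out[-2] = m           # second-to-last  <- m
    return out

# shape_a = [2, 3, 4, 5] with m = 4, n = 6  ->  [2, 3, 4, 6]
print(matmul_infer_shape([2, 3, 4, 5], 4, 6))
```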