forked from jiuyuan/InfiniTensor
fix: Matmul 支持 2 维或以上的输入
> 现在能导入 resnet18 Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
a27391fcdc
commit
4ffaa44c1e
|
@ -8,7 +8,7 @@ from onnx.helper import (
|
|||
make_tensor_value_info,
|
||||
)
|
||||
from onnx.checker import check_model
|
||||
from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime, to_onnx
|
||||
from pyinfinitensor.onnx import from_onnx, backend, runtime, to_onnx
|
||||
|
||||
|
||||
def make_and_import_model(graph: onnx.GraphProto):
|
||||
|
@ -28,7 +28,7 @@ class TestStringMethods(unittest.TestCase):
|
|||
file=model_file, size=os.path.getsize(model_file) / 1024 / 1024
|
||||
)
|
||||
)
|
||||
parse_onnx(onnx.load(model_file))
|
||||
from_onnx(onnx.load(model_file))
|
||||
|
||||
def test_tensor(self):
|
||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
||||
|
@ -290,7 +290,6 @@ class TestStringMethods(unittest.TestCase):
|
|||
model = make_model(graph)
|
||||
check_model(model)
|
||||
from_onnx(model)
|
||||
parse_onnx(model)
|
||||
|
||||
def test_frontend(self):
|
||||
handler = backend.GraphHandler(runtime)
|
||||
|
|
|
@ -5,10 +5,26 @@ namespace infini {
|
|||
MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
|
||||
bool transB, [[maybe_unused]] Tensor bias, ActType act)
|
||||
: OperatorObj(OpType::Matmul, {A, B}, {C}), transA(transA), transB(transB),
|
||||
act(act), b(A->getDims()[0]),
|
||||
m(transA ? A->getDims()[2] : A->getDims()[1]),
|
||||
n(transB ? B->getDims()[1] : B->getDims()[2]),
|
||||
k(transA ? A->getDims()[1] : A->getDims()[2]) {
|
||||
act(act), b(1) {
|
||||
auto shape_a = A->getDims();
|
||||
auto shape_b = B->getDims();
|
||||
IT_ASSERT(shape_a.size() == shape_b.size());
|
||||
switch (shape_a.size()) {
|
||||
case 0:
|
||||
case 1:
|
||||
IT_ASSERT(false);
|
||||
case 2:
|
||||
break;
|
||||
default:
|
||||
for (size_t i = 0; i < shape_a.size() - 2; ++i) {
|
||||
IT_ASSERT(shape_a[i] == shape_b[i]);
|
||||
b *= shape_a[i];
|
||||
}
|
||||
break;
|
||||
}
|
||||
m = *(transA ? shape_a.rbegin() : shape_a.rbegin() + 1);
|
||||
n = *(transB ? shape_b.rbegin() + 1 : shape_b.rbegin());
|
||||
k = *(transA ? shape_a.rbegin() + 1 : shape_a.rbegin());
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
|
@ -22,19 +38,11 @@ string MatmulObj::toString() const {
|
|||
}
|
||||
|
||||
optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) const {
|
||||
auto A = inputs[0], B = inputs[1];
|
||||
// if (A->getType() == Tensor::Weight && B->getType() == Tensor::Weight)
|
||||
// return false;
|
||||
if (!(A->getDims().size() == 3 && B->getDims().size() == 3))
|
||||
return {};
|
||||
if (!(A->getDims()[0] == B->getDims()[0]))
|
||||
return {};
|
||||
if (!((transA ? A->getDims()[1] : A->getDims()[2]) ==
|
||||
(transB ? B->getDims()[2] : B->getDims()[1])))
|
||||
return {};
|
||||
int b(A->getDims()[0]), m(transA ? A->getDims()[2] : A->getDims()[1]),
|
||||
n(transB ? B->getDims()[1] : B->getDims()[2]);
|
||||
return {{{b, m, n}}};
|
||||
auto shape_a = inputs[0]->getDims();
|
||||
auto it = shape_a.rbegin();
|
||||
*it++ = n;
|
||||
*it++ = m;
|
||||
return {{std::move(shape_a)}};
|
||||
}
|
||||
|
||||
vector<int> MatmulObj::getWorkloadVector() const {
|
||||
|
|
Loading…
Reference in New Issue