forked from jiuyuan/InfiniTensor
fix: Matmul 支持 2 维或以上的输入
> 现在能导入 resnet18

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
a27391fcdc
commit
4ffaa44c1e
|
@ -8,7 +8,7 @@ from onnx.helper import (
|
||||||
make_tensor_value_info,
|
make_tensor_value_info,
|
||||||
)
|
)
|
||||||
from onnx.checker import check_model
|
from onnx.checker import check_model
|
||||||
from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime, to_onnx
|
from pyinfinitensor.onnx import from_onnx, backend, runtime, to_onnx
|
||||||
|
|
||||||
|
|
||||||
def make_and_import_model(graph: onnx.GraphProto):
|
def make_and_import_model(graph: onnx.GraphProto):
|
||||||
|
@ -28,7 +28,7 @@ class TestStringMethods(unittest.TestCase):
|
||||||
file=model_file, size=os.path.getsize(model_file) / 1024 / 1024
|
file=model_file, size=os.path.getsize(model_file) / 1024 / 1024
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
parse_onnx(onnx.load(model_file))
|
from_onnx(onnx.load(model_file))
|
||||||
|
|
||||||
def test_tensor(self):
|
def test_tensor(self):
|
||||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
||||||
|
@ -290,7 +290,6 @@ class TestStringMethods(unittest.TestCase):
|
||||||
model = make_model(graph)
|
model = make_model(graph)
|
||||||
check_model(model)
|
check_model(model)
|
||||||
from_onnx(model)
|
from_onnx(model)
|
||||||
parse_onnx(model)
|
|
||||||
|
|
||||||
def test_frontend(self):
|
def test_frontend(self):
|
||||||
handler = backend.GraphHandler(runtime)
|
handler = backend.GraphHandler(runtime)
|
||||||
|
|
|
@ -5,10 +5,26 @@ namespace infini {
|
||||||
MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
|
MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
|
||||||
bool transB, [[maybe_unused]] Tensor bias, ActType act)
|
bool transB, [[maybe_unused]] Tensor bias, ActType act)
|
||||||
: OperatorObj(OpType::Matmul, {A, B}, {C}), transA(transA), transB(transB),
|
: OperatorObj(OpType::Matmul, {A, B}, {C}), transA(transA), transB(transB),
|
||||||
act(act), b(A->getDims()[0]),
|
act(act), b(1) {
|
||||||
m(transA ? A->getDims()[2] : A->getDims()[1]),
|
auto shape_a = A->getDims();
|
||||||
n(transB ? B->getDims()[1] : B->getDims()[2]),
|
auto shape_b = B->getDims();
|
||||||
k(transA ? A->getDims()[1] : A->getDims()[2]) {
|
IT_ASSERT(shape_a.size() == shape_b.size());
|
||||||
|
switch (shape_a.size()) {
|
||||||
|
case 0:
|
||||||
|
case 1:
|
||||||
|
IT_ASSERT(false);
|
||||||
|
case 2:
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
for (size_t i = 0; i < shape_a.size() - 2; ++i) {
|
||||||
|
IT_ASSERT(shape_a[i] == shape_b[i]);
|
||||||
|
b *= shape_a[i];
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
m = *(transA ? shape_a.rbegin() : shape_a.rbegin() + 1);
|
||||||
|
n = *(transB ? shape_b.rbegin() + 1 : shape_b.rbegin());
|
||||||
|
k = *(transA ? shape_a.rbegin() + 1 : shape_a.rbegin());
|
||||||
IT_ASSERT(checkValid(graph));
|
IT_ASSERT(checkValid(graph));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -22,19 +38,11 @@ string MatmulObj::toString() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) const {
|
optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) const {
|
||||||
auto A = inputs[0], B = inputs[1];
|
auto shape_a = inputs[0]->getDims();
|
||||||
// if (A->getType() == Tensor::Weight && B->getType() == Tensor::Weight)
|
auto it = shape_a.rbegin();
|
||||||
// return false;
|
*it++ = n;
|
||||||
if (!(A->getDims().size() == 3 && B->getDims().size() == 3))
|
*it++ = m;
|
||||||
return {};
|
return {{std::move(shape_a)}};
|
||||||
if (!(A->getDims()[0] == B->getDims()[0]))
|
|
||||||
return {};
|
|
||||||
if (!((transA ? A->getDims()[1] : A->getDims()[2]) ==
|
|
||||||
(transB ? B->getDims()[2] : B->getDims()[1])))
|
|
||||||
return {};
|
|
||||||
int b(A->getDims()[0]), m(transA ? A->getDims()[2] : A->getDims()[1]),
|
|
||||||
n(transB ? B->getDims()[1] : B->getDims()[2]);
|
|
||||||
return {{{b, m, n}}};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
vector<int> MatmulObj::getWorkloadVector() const {
|
vector<int> MatmulObj::getWorkloadVector() const {
|
||||||
|
|
Loading…
Reference in New Issue