diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index c66995d2..9d7a4a2d 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -32,6 +32,25 @@ def from_onnx(model: onnx.ModelProto):
                 None,
                 backend.ActType.Linear,
             )
+        elif node.op_type == "Gemm":
+            attributes = _parse_attribute(
+                node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
+            )
+            (alpha, beta, transA, transB) = (
+                attributes[name] for name in ["alpha", "beta", "transA", "transB"]
+            )
+            # TODO: these parameters are not supported yet
+            assert alpha == 1.0
+            assert beta == 1.0
+            tensors[node.output[0]] = handler.matmul(
+                tensors[node.input[0]],
+                tensors[node.input[1]],
+                tensors.get(node.output[0]),
+                transA == 1,
+                transB == 1,
+                tensors[node.input[2]] if len(node.input) > 2 else None,
+                backend.ActType.Linear,
+            )
         elif node.op_type == "BatchNormalization":
             (input, mean, var, scale, bias) = (
                 tensors[node.input[i]] for i in [0, 3, 4, 1, 2]
diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py
index 4f86cd8d..150ce3af 100644
--- a/pyinfinitensor/tests/test_onnx.py
+++ b/pyinfinitensor/tests/test_onnx.py
@@ -37,10 +37,18 @@ class TestStringMethods(unittest.TestCase):
     def test_matmul(self):
         x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
         a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4])
-        xa = make_tensor_value_info("b", TensorProto.FLOAT, [1, 2, 4])
+        xa = make_tensor_value_info("xa", TensorProto.FLOAT, [1, 2, 4])
         matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul")
         make_and_import_model(make_graph([matmul], "matmul", [x, a], [xa]))
 
+    def test_gemm(self):
+        a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 2, 3])
+        b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 4, 3])
+        c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 2, 4])
+        y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 2, 4])
+        gemm = make_node("Gemm", ["a", "b", "c"], ["y"], transB=1, name="gemm")
+        make_and_import_model(make_graph([gemm], "gemm", [a, b, c], [y]))
+
     def test_batch_norm(self):
         x = make_tensor_value_info("x", TensorProto.UINT32, [1, 3, 2, 2])
         scale = make_tensor_value_info("scale", TensorProto.FLOAT, [1, 3, 1, 1])
diff --git a/src/core/operator.cc b/src/core/operator.cc
index b8e69af8..51568f8f 100644
--- a/src/core/operator.cc
+++ b/src/core/operator.cc
@@ -6,8 +6,8 @@ namespace infini {
 
 OperatorObj::OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs)
     : type(opType), inputs(inputs), outputs(outputs) {
-    for (auto &t : inputs)
-        IT_ASSERT(t != nullptr);
+    for (const auto &t : inputs)
+        IT_ASSERT(t);
 }
 
 bool OperatorObj::isLinearOp() const {
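
Note on the Gemm branch above: once alpha and beta are pinned to 1.0, ONNX Gemm reduces to an optionally transposed matrix product plus a broadcast bias, which is what the handler.matmul call with a bias tensor computes. A minimal NumPy sketch of that equivalence for the 2-D case follows; reference_gemm is a hypothetical helper written for illustration only and is not part of this repository.

    import numpy as np

    def reference_gemm(a, b, c=None, trans_a=False, trans_b=False):
        # Gemm with alpha == beta == 1.0: Y = op(A) @ op(B) + C
        a = a.T if trans_a else a
        b = b.T if trans_b else b
        y = a @ b
        return y if c is None else y + c  # C broadcasts as in the ONNX spec

    # Same layout as test_gemm above, without the leading batch dimension.
    a = np.random.rand(2, 3).astype(np.float32)
    b = np.random.rand(4, 3).astype(np.float32)  # transB == 1, so B is stored as (N, K)
    c = np.random.rand(2, 4).astype(np.float32)
    y = reference_gemm(a, b, c, trans_b=True)
    assert y.shape == (2, 4)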