feat: 前端支持 gemm 及单元测试

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-15 13:20:34 +08:00
parent 315763a83a
commit afa90ec9c9
3 changed files with 30 additions and 3 deletions

View File

@@ -32,6 +32,25 @@ def from_onnx(model: onnx.ModelProto):
None,
backend.ActType.Linear,
)
elif node.op_type == "Gemm":
attributes = _parse_attribute(
node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
)
(alpha, beta, transA, transB) = (
attributes[name] for name in ["alpha", "beta", "transA", "transB"]
)
# TODO 不支持这些参数
assert alpha == 1.0
assert beta == 1.0
tensors[node.output[0]] = handler.matmul(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
transA == 1,
transB == 1,
tensors[node.input[2]] if len(node.input) > 2 else None,
backend.ActType.Linear,
)
elif node.op_type == "BatchNormalization":
(input, mean, var, scale, bias) = (
tensors[node.input[i]] for i in [0, 3, 4, 1, 2]

View File

@@ -37,10 +37,18 @@ class TestStringMethods(unittest.TestCase):
# Builds a minimal single-node ONNX graph ("MatMul") and round-trips it
# through the importer: x (1, 2, 3) @ a (1, 3, 4) -> xa (1, 2, 4).
def test_matmul(self):
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4])
# NOTE(review): the next two lines are the removed/added pair of this diff
# hunk (output value-info renamed from "b" to "xa"); as plain Python only
# the second assignment to `xa` would take effect.
xa = make_tensor_value_info("b", TensorProto.FLOAT, [1, 2, 4])
xa = make_tensor_value_info("xa", TensorProto.FLOAT, [1, 2, 4])
matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul")
make_and_import_model(make_graph([matmul], "matmul", [x, a], [xa]))
# Builds a single-node "Gemm" graph with transB=1 (so b is transposed:
# a (1, 2, 3) @ b^T (1, 3, 4) + c -> y (1, 2, 4)) and imports it, exercising
# the new Gemm branch of from_onnx.
# NOTE(review): ONNX Gemm is specified for 2-D inputs; these value infos are
# rank-3 — presumably the importer lowers Gemm to a batched matmul that
# tolerates the leading dimension. Confirm against the frontend.
def test_gemm(self):
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 2, 3])
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 4, 3])
c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 2, 4])
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 2, 4])
gemm = make_node("Gemm", ["a", "b", "c"], ["y"], transB=1, name="gemm")
make_and_import_model(make_graph([gemm], "gemm", [a, b, c], [y]))
def test_batch_norm(self):
x = make_tensor_value_info("x", TensorProto.UINT32, [1, 3, 2, 2])
scale = make_tensor_value_info("scale", TensorProto.FLOAT, [1, 3, 1, 1])

View File

@@ -6,8 +6,8 @@ namespace infini {
// Base-operator constructor: records the op type and the input/output tensor
// vectors, then sanity-checks that every input tensor handle is non-null.
OperatorObj::OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs)
: type(opType), inputs(inputs), outputs(outputs) {
// NOTE(review): the two loops below are the removed/added halves of this
// diff hunk, not two real loops. The surviving code is the second one:
// const-ref iteration with IT_ASSERT(t) — presumably t is a smart-pointer-
// like Tensor ref whose boolean conversion tests for null.
for (auto &t : inputs)
IT_ASSERT(t != nullptr);
for (const auto &t : inputs)
IT_ASSERT(t);
}
bool OperatorObj::isLinearOp() const {