forked from jiuyuan/InfiniTensor
feat: 前端支持 gemm 及单元测试
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
315763a83a
commit
afa90ec9c9
|
@ -32,6 +32,25 @@ def from_onnx(model: onnx.ModelProto):
|
||||||
None,
|
None,
|
||||||
backend.ActType.Linear,
|
backend.ActType.Linear,
|
||||||
)
|
)
|
||||||
|
elif node.op_type == "Gemm":
|
||||||
|
attributes = _parse_attribute(
|
||||||
|
node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
|
||||||
|
)
|
||||||
|
(alpha, beta, transA, transB) = (
|
||||||
|
attributes[name] for name in ["alpha", "beta", "transA", "transB"]
|
||||||
|
)
|
||||||
|
# TODO 不支持这些参数
|
||||||
|
assert alpha == 1.0
|
||||||
|
assert beta == 1.0
|
||||||
|
tensors[node.output[0]] = handler.matmul(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors[node.input[1]],
|
||||||
|
tensors.get(node.output[0]),
|
||||||
|
transA == 1,
|
||||||
|
transB == 1,
|
||||||
|
tensors[node.input[2]] if len(node.input) > 2 else None,
|
||||||
|
backend.ActType.Linear,
|
||||||
|
)
|
||||||
elif node.op_type == "BatchNormalization":
|
elif node.op_type == "BatchNormalization":
|
||||||
(input, mean, var, scale, bias) = (
|
(input, mean, var, scale, bias) = (
|
||||||
tensors[node.input[i]] for i in [0, 3, 4, 1, 2]
|
tensors[node.input[i]] for i in [0, 3, 4, 1, 2]
|
||||||
|
|
|
@ -37,10 +37,18 @@ class TestStringMethods(unittest.TestCase):
|
||||||
def test_matmul(self):
|
def test_matmul(self):
|
||||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
||||||
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4])
|
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4])
|
||||||
xa = make_tensor_value_info("b", TensorProto.FLOAT, [1, 2, 4])
|
xa = make_tensor_value_info("xa", TensorProto.FLOAT, [1, 2, 4])
|
||||||
matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul")
|
matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul")
|
||||||
make_and_import_model(make_graph([matmul], "matmul", [x, a], [xa]))
|
make_and_import_model(make_graph([matmul], "matmul", [x, a], [xa]))
|
||||||
|
|
||||||
|
def test_gemm(self):
|
||||||
|
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 2, 3])
|
||||||
|
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 4, 3])
|
||||||
|
c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 2, 4])
|
||||||
|
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 2, 4])
|
||||||
|
gemm = make_node("Gemm", ["a", "b", "c"], ["y"], transB=1, name="gemm")
|
||||||
|
make_and_import_model(make_graph([gemm], "gemm", [a, b, c], [y]))
|
||||||
|
|
||||||
def test_batch_norm(self):
|
def test_batch_norm(self):
|
||||||
x = make_tensor_value_info("x", TensorProto.UINT32, [1, 3, 2, 2])
|
x = make_tensor_value_info("x", TensorProto.UINT32, [1, 3, 2, 2])
|
||||||
scale = make_tensor_value_info("scale", TensorProto.FLOAT, [1, 3, 1, 1])
|
scale = make_tensor_value_info("scale", TensorProto.FLOAT, [1, 3, 1, 1])
|
||||||
|
|
|
@ -6,8 +6,8 @@ namespace infini {
|
||||||
|
|
||||||
OperatorObj::OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs)
|
OperatorObj::OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs)
|
||||||
: type(opType), inputs(inputs), outputs(outputs) {
|
: type(opType), inputs(inputs), outputs(outputs) {
|
||||||
for (auto &t : inputs)
|
for (const auto &t : inputs)
|
||||||
IT_ASSERT(t != nullptr);
|
IT_ASSERT(t);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool OperatorObj::isLinearOp() const {
|
bool OperatorObj::isLinearOp() const {
|
||||||
|
|
Loading…
Reference in New Issue