forked from jiuyuan/InfiniTensor
feat: 前端支持 gemm 及单元测试
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
315763a83a
commit
afa90ec9c9
|
@ -32,6 +32,25 @@ def from_onnx(model: onnx.ModelProto):
|
|||
None,
|
||||
backend.ActType.Linear,
|
||||
)
|
||||
elif node.op_type == "Gemm":
|
||||
attributes = _parse_attribute(
|
||||
node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
|
||||
)
|
||||
(alpha, beta, transA, transB) = (
|
||||
attributes[name] for name in ["alpha", "beta", "transA", "transB"]
|
||||
)
|
||||
# TODO 不支持这些参数
|
||||
assert alpha == 1.0
|
||||
assert beta == 1.0
|
||||
tensors[node.output[0]] = handler.matmul(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
transA == 1,
|
||||
transB == 1,
|
||||
tensors[node.input[2]] if len(node.input) > 2 else None,
|
||||
backend.ActType.Linear,
|
||||
)
|
||||
elif node.op_type == "BatchNormalization":
|
||||
(input, mean, var, scale, bias) = (
|
||||
tensors[node.input[i]] for i in [0, 3, 4, 1, 2]
|
||||
|
|
|
@ -37,10 +37,18 @@ class TestStringMethods(unittest.TestCase):
|
|||
def test_matmul(self):
|
||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
||||
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4])
|
||||
xa = make_tensor_value_info("b", TensorProto.FLOAT, [1, 2, 4])
|
||||
xa = make_tensor_value_info("xa", TensorProto.FLOAT, [1, 2, 4])
|
||||
matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul")
|
||||
make_and_import_model(make_graph([matmul], "matmul", [x, a], [xa]))
|
||||
|
||||
def test_gemm(self):
|
||||
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 2, 3])
|
||||
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 4, 3])
|
||||
c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 2, 4])
|
||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 2, 4])
|
||||
gemm = make_node("Gemm", ["a", "b", "c"], ["y"], transB=1, name="gemm")
|
||||
make_and_import_model(make_graph([gemm], "gemm", [a, b, c], [y]))
|
||||
|
||||
def test_batch_norm(self):
|
||||
x = make_tensor_value_info("x", TensorProto.UINT32, [1, 3, 2, 2])
|
||||
scale = make_tensor_value_info("scale", TensorProto.FLOAT, [1, 3, 1, 1])
|
||||
|
|
|
@ -6,8 +6,8 @@ namespace infini {
|
|||
|
||||
OperatorObj::OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs)
|
||||
: type(opType), inputs(inputs), outputs(outputs) {
|
||||
for (auto &t : inputs)
|
||||
IT_ASSERT(t != nullptr);
|
||||
for (const auto &t : inputs)
|
||||
IT_ASSERT(t);
|
||||
}
|
||||
|
||||
bool OperatorObj::isLinearOp() const {
|
||||
|
|
Loading…
Reference in New Issue