feat: front end supports relu, sigmoid, tanh, softmax, abs, plus unit tests

Signed-off-by: YdrMaster <ydrml@hotmail.com>
Author: YdrMaster
Date:   2023-02-13 11:54:54 +08:00
commit 7f0c8ebae3
parent 6e5beceadd

4 changed files with 88 additions and 6 deletions

@@ -45,6 +45,12 @@ class GraphHandlerObj {
     Tensor mul(Tensor a, Tensor b, Tensor c);
     Tensor div(Tensor a, Tensor b, Tensor c);
     Tensor pow(Tensor a, Tensor b, Tensor c);
+    Tensor relu(Tensor x, Tensor y);
+    Tensor sigmoid(Tensor x, Tensor y);
+    Tensor tanh(Tensor x, Tensor y);
+    Tensor softmax(Tensor x, Tensor y);
+    Tensor abs(Tensor x, Tensor y);
 };
 
 } // namespace infini
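
The implementation later in this commit gives these declarations a two-mode convention: pass an existing output tensor to record the op onto it, or pass a null handle to let the graph allocate the output. A usage sketch under that assumption (`handler`, `x`, and `t` are illustrative names for an already-constructed graph handler and tensors, not code from this commit):

// Sketch only. `Tensor` is a shared-pointer-like handle (the implementation
// below tests it with `if (y)`), so a null handle asks for a fresh output.
Tensor y = handler.relu(x, nullptr); // output allocated by the graph, returned
Tensor z = handler.tanh(x, t);       // caller-supplied output `t`: the op is
                                     // recorded with `t`, and `t` is returned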

@@ -85,6 +85,51 @@ class TestStringMethods(unittest.TestCase):
         check_model(model)
         from_onnx(model)
 
+    def test_relu(self):
+        x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
+        y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
+        relu = make_node("Relu", ["x"], ["y"], name="relu")
+        graph = make_graph([relu], "relu", [x], [y])
+        model = make_model(graph)
+        check_model(model)
+        from_onnx(model)
+
+    def test_sigmoid(self):
+        x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
+        y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
+        sigmoid = make_node("Sigmoid", ["x"], ["y"], name="sigmoid")
+        graph = make_graph([sigmoid], "sigmoid", [x], [y])
+        model = make_model(graph)
+        check_model(model)
+        from_onnx(model)
+
+    def test_tanh(self):
+        x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
+        y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
+        tanh = make_node("Tanh", ["x"], ["y"], name="tanh")
+        graph = make_graph([tanh], "tanh", [x], [y])
+        model = make_model(graph)
+        check_model(model)
+        from_onnx(model)
+
+    def test_softmax(self):
+        x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
+        y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
+        softmax = make_node("Softmax", ["x"], ["y"], name="softmax")
+        graph = make_graph([softmax], "softmax", [x], [y])
+        model = make_model(graph)
+        check_model(model)
+        from_onnx(model)
+
+    def test_abs(self):
+        x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
+        y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
+        abs = make_node("Abs", ["x"], ["y"], name="abs")
+        graph = make_graph([abs], "abs", [x], [y])
+        model = make_model(graph)
+        check_model(model)
+        from_onnx(model)
+
     # see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
     def test_linear(self):
         x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])

@@ -1,6 +1,7 @@
 #include "core/graph_handler.h"
 #include "operators/element_wise.h"
 #include "operators/matmul.h"
+#include "operators/unary.h"
 
 namespace infini {

@@ -24,6 +25,7 @@ Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA,
     }
 }
 
+// see operators/element_wise.h
 #define DEFINE_ELEMENT_WISE_METHOD(name, obj) \
     Tensor GraphHandlerObj::name(Tensor a, Tensor b, Tensor c) { \
         if (c) { \

@@ -40,6 +42,23 @@ DEFINE_ELEMENT_WISE_METHOD(mul, Mul)
 DEFINE_ELEMENT_WISE_METHOD(div, Div)
 DEFINE_ELEMENT_WISE_METHOD(pow, Pow)
 
+// see operators/unary.h
+#define DEFINE_UNARY_METHOD(name, obj) \
+    Tensor GraphHandlerObj::name(Tensor x, Tensor y) { \
+        if (y) { \
+            g->addOpWithOutputs<obj##Obj>(x, y); \
+            return y; \
+        } else { \
+            return g->addOp<obj##Obj>(x, y)->getOutput(); \
+        } \
+    }
+
+DEFINE_UNARY_METHOD(relu, Relu)
+DEFINE_UNARY_METHOD(sigmoid, Sigmoid)
+DEFINE_UNARY_METHOD(tanh, Tanh)
+DEFINE_UNARY_METHOD(softmax, Softmax)
+DEFINE_UNARY_METHOD(abs, Abs)
+
 static DataType dtype_repr_convert(int dtype) {
     switch ((OnnxDType)dtype) {
     case OnnxDType::FLOAT:
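
Since the methods are generated by token pasting (`obj##Obj` splices the operator class name, so `Relu` becomes `ReluObj`), it may help to see one instantiation written out. `DEFINE_UNARY_METHOD(relu, Relu)` expands, modulo whitespace, to roughly:

// Approximate preprocessor expansion of DEFINE_UNARY_METHOD(relu, Relu):
Tensor GraphHandlerObj::relu(Tensor x, Tensor y) {
    if (y) {
        // caller supplied the output tensor: record the op onto it
        g->addOpWithOutputs<ReluObj>(x, y);
        return y;
    } else {
        // no output supplied (y is null): the graph creates the op with a
        // fresh output tensor, and that tensor is returned to the caller
        return g->addOp<ReluObj>(x, y)->getOutput();
    }
}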

@@ -43,22 +43,34 @@ void init_graph_builder(py::module &m) {
         .def("matmul",
              py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
                                ActType>(&GraphHandlerObj::matmul),
-             policy::reference_internal)
+             policy::move)
         .def("add",
              py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::add),
-             policy::reference_internal)
+             policy::move)
         .def("sub",
              py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::sub),
-             policy::reference_internal)
+             policy::move)
         .def("mul",
              py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::mul),
-             policy::reference_internal)
+             policy::move)
         .def("div",
              py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::div),
-             policy::reference_internal)
+             policy::move)
         .def("pow",
              py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::pow),
-             policy::reference_internal);
+             policy::move)
+        .def("relu", py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::relu),
+             policy::move)
+        .def("sigmoid",
+             py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::sigmoid),
+             policy::move)
.def("tanh", py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::tanh),
policy::reference_internal)
.def("softmax",
py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::softmax),
policy::move)
.def("abs", py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::abs),
policy::move);
}
} // namespace infini
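
A note on the binding change: every `def` returning a `Tensor` moves from `policy::reference_internal` to `policy::move`. `reference_internal` ties the lifetime of the returned object to the `GraphHandler` it came from, while `move` hands a by-value return to Python outright, which suits methods that return a fresh `Tensor` handle. A minimal, self-contained pybind11 sketch of the distinction, with illustrative names that are not from this repository:

#include <pybind11/pybind11.h>

namespace py = pybind11;

struct Widget {
    int id;
};

struct Factory {
    // Returned by value, so the binding can safely move the result out.
    Widget make(int id) { return Widget{id}; }
};

PYBIND11_MODULE(example, m) {
    py::class_<Widget>(m, "Widget").def_readonly("id", &Widget::id);
    py::class_<Factory>(m, "Factory")
        .def(py::init<>())
        // policy::move transfers ownership of the returned Widget to Python;
        // reference_internal would instead keep it alive only through the
        // Factory instance that produced it.
        .def("make", &Factory::make, py::return_value_policy::move);
}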