forked from jiuyuan/InfiniTensor
feat: 前端支持 relu sigmoid tanh softmax abs 及单元测试
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
6e5beceadd
commit
7f0c8ebae3
|
@ -45,6 +45,12 @@ class GraphHandlerObj {
|
|||
Tensor mul(Tensor a, Tensor b, Tensor c);
|
||||
Tensor div(Tensor a, Tensor b, Tensor c);
|
||||
Tensor pow(Tensor a, Tensor b, Tensor c);
|
||||
|
||||
Tensor relu(Tensor x, Tensor y);
|
||||
Tensor sigmoid(Tensor x, Tensor y);
|
||||
Tensor tanh(Tensor x, Tensor y);
|
||||
Tensor softmax(Tensor x, Tensor y);
|
||||
Tensor abs(Tensor x, Tensor y);
|
||||
};
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -85,6 +85,51 @@ class TestStringMethods(unittest.TestCase):
|
|||
check_model(model)
|
||||
from_onnx(model)
|
||||
|
||||
def test_relu(self):
|
||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
relu = make_node("Relu", ["x"], ["y"], name="relu")
|
||||
graph = make_graph([relu], "relu", [x], [y])
|
||||
model = make_model(graph)
|
||||
check_model(model)
|
||||
from_onnx(model)
|
||||
|
||||
def test_sigmoid(self):
|
||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
sigmoid = make_node("Sigmoid", ["x"], ["y"], name="sigmoid")
|
||||
graph = make_graph([sigmoid], "sigmoid", [x], [y])
|
||||
model = make_model(graph)
|
||||
check_model(model)
|
||||
from_onnx(model)
|
||||
|
||||
def test_tanh(self):
|
||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
tanh = make_node("Tanh", ["x"], ["y"], name="tanh")
|
||||
graph = make_graph([tanh], "tanh", [x], [y])
|
||||
model = make_model(graph)
|
||||
check_model(model)
|
||||
from_onnx(model)
|
||||
|
||||
def test_softmax(self):
|
||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
softmax = make_node("Softmax", ["x"], ["y"], name="softmax")
|
||||
graph = make_graph([softmax], "softmax", [x], [y])
|
||||
model = make_model(graph)
|
||||
check_model(model)
|
||||
from_onnx(model)
|
||||
|
||||
def test_abs(self):
|
||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
abs = make_node("Abs", ["x"], ["y"], name="abs")
|
||||
graph = make_graph([abs], "abs", [x], [y])
|
||||
model = make_model(graph)
|
||||
check_model(model)
|
||||
from_onnx(model)
|
||||
|
||||
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
|
||||
def test_linear(self):
|
||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#include "core/graph_handler.h"
|
||||
#include "operators/element_wise.h"
|
||||
#include "operators/matmul.h"
|
||||
#include "operators/unary.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
|
@ -24,6 +25,7 @@ Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA,
|
|||
}
|
||||
}
|
||||
|
||||
// see operators/element_wise.h
|
||||
#define DEFINE_ELEMENT_WISE_METHOD(name, obj) \
|
||||
Tensor GraphHandlerObj::name(Tensor a, Tensor b, Tensor c) { \
|
||||
if (c) { \
|
||||
|
@ -40,6 +42,23 @@ DEFINE_ELEMENT_WISE_METHOD(mul, Mul)
|
|||
DEFINE_ELEMENT_WISE_METHOD(div, Div)
|
||||
DEFINE_ELEMENT_WISE_METHOD(pow, Pow)
|
||||
|
||||
// see operators/unary.h
|
||||
#define DEFINE_UNARY_METHOD(name, obj) \
|
||||
Tensor GraphHandlerObj::name(Tensor x, Tensor y) { \
|
||||
if (y) { \
|
||||
g->addOpWithOutputs<obj##Obj>(x, y); \
|
||||
return y; \
|
||||
} else { \
|
||||
return g->addOp<obj##Obj>(x, y)->getOutput(); \
|
||||
} \
|
||||
}
|
||||
|
||||
DEFINE_UNARY_METHOD(relu, Relu)
|
||||
DEFINE_UNARY_METHOD(sigmoid, Sigmoid)
|
||||
DEFINE_UNARY_METHOD(tanh, Tanh)
|
||||
DEFINE_UNARY_METHOD(softmax, Softmax)
|
||||
DEFINE_UNARY_METHOD(abs, Abs)
|
||||
|
||||
static DataType dtype_repr_convert(int dtype) {
|
||||
switch ((OnnxDType)dtype) {
|
||||
case OnnxDType::FLOAT:
|
||||
|
|
|
@ -43,22 +43,34 @@ void init_graph_builder(py::module &m) {
|
|||
.def("matmul",
|
||||
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
|
||||
ActType>(&GraphHandlerObj::matmul),
|
||||
policy::reference_internal)
|
||||
policy::move)
|
||||
.def("add",
|
||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::add),
|
||||
policy::reference_internal)
|
||||
policy::move)
|
||||
.def("sub",
|
||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::sub),
|
||||
policy::reference_internal)
|
||||
policy::move)
|
||||
.def("mul",
|
||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::mul),
|
||||
policy::reference_internal)
|
||||
policy::move)
|
||||
.def("div",
|
||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::div),
|
||||
policy::reference_internal)
|
||||
policy::move)
|
||||
.def("pow",
|
||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::pow),
|
||||
policy::reference_internal);
|
||||
policy::move)
|
||||
.def("relu", py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::relu),
|
||||
policy::move)
|
||||
.def("sigmoid",
|
||||
py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::sigmoid),
|
||||
policy::move)
|
||||
.def("tanh", py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::tanh),
|
||||
policy::reference_internal)
|
||||
.def("softmax",
|
||||
py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::softmax),
|
||||
policy::move)
|
||||
.def("abs", py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::abs),
|
||||
policy::move);
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
Loading…
Reference in New Issue