forked from jiuyuan/InfiniTensor
feat: 前端支持 relu sigmoid tanh softmax abs 及单元测试
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
6e5beceadd
commit
7f0c8ebae3
|
@ -45,6 +45,12 @@ class GraphHandlerObj {
|
||||||
Tensor mul(Tensor a, Tensor b, Tensor c);
|
Tensor mul(Tensor a, Tensor b, Tensor c);
|
||||||
Tensor div(Tensor a, Tensor b, Tensor c);
|
Tensor div(Tensor a, Tensor b, Tensor c);
|
||||||
Tensor pow(Tensor a, Tensor b, Tensor c);
|
Tensor pow(Tensor a, Tensor b, Tensor c);
|
||||||
|
|
||||||
|
Tensor relu(Tensor x, Tensor y);
|
||||||
|
Tensor sigmoid(Tensor x, Tensor y);
|
||||||
|
Tensor tanh(Tensor x, Tensor y);
|
||||||
|
Tensor softmax(Tensor x, Tensor y);
|
||||||
|
Tensor abs(Tensor x, Tensor y);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -85,6 +85,51 @@ class TestStringMethods(unittest.TestCase):
|
||||||
check_model(model)
|
check_model(model)
|
||||||
from_onnx(model)
|
from_onnx(model)
|
||||||
|
|
||||||
|
def test_relu(self):
|
||||||
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
relu = make_node("Relu", ["x"], ["y"], name="relu")
|
||||||
|
graph = make_graph([relu], "relu", [x], [y])
|
||||||
|
model = make_model(graph)
|
||||||
|
check_model(model)
|
||||||
|
from_onnx(model)
|
||||||
|
|
||||||
|
def test_sigmoid(self):
|
||||||
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
sigmoid = make_node("Sigmoid", ["x"], ["y"], name="sigmoid")
|
||||||
|
graph = make_graph([sigmoid], "sigmoid", [x], [y])
|
||||||
|
model = make_model(graph)
|
||||||
|
check_model(model)
|
||||||
|
from_onnx(model)
|
||||||
|
|
||||||
|
def test_tanh(self):
|
||||||
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
tanh = make_node("Tanh", ["x"], ["y"], name="tanh")
|
||||||
|
graph = make_graph([tanh], "tanh", [x], [y])
|
||||||
|
model = make_model(graph)
|
||||||
|
check_model(model)
|
||||||
|
from_onnx(model)
|
||||||
|
|
||||||
|
def test_softmax(self):
|
||||||
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
softmax = make_node("Softmax", ["x"], ["y"], name="softmax")
|
||||||
|
graph = make_graph([softmax], "softmax", [x], [y])
|
||||||
|
model = make_model(graph)
|
||||||
|
check_model(model)
|
||||||
|
from_onnx(model)
|
||||||
|
|
||||||
|
def test_abs(self):
|
||||||
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
abs = make_node("Abs", ["x"], ["y"], name="abs")
|
||||||
|
graph = make_graph([abs], "abs", [x], [y])
|
||||||
|
model = make_model(graph)
|
||||||
|
check_model(model)
|
||||||
|
from_onnx(model)
|
||||||
|
|
||||||
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
|
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
|
||||||
def test_linear(self):
|
def test_linear(self):
|
||||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
#include "core/graph_handler.h"
|
#include "core/graph_handler.h"
|
||||||
#include "operators/element_wise.h"
|
#include "operators/element_wise.h"
|
||||||
#include "operators/matmul.h"
|
#include "operators/matmul.h"
|
||||||
|
#include "operators/unary.h"
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
|
@ -24,6 +25,7 @@ Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// see operators/element_wise.h
|
||||||
#define DEFINE_ELEMENT_WISE_METHOD(name, obj) \
|
#define DEFINE_ELEMENT_WISE_METHOD(name, obj) \
|
||||||
Tensor GraphHandlerObj::name(Tensor a, Tensor b, Tensor c) { \
|
Tensor GraphHandlerObj::name(Tensor a, Tensor b, Tensor c) { \
|
||||||
if (c) { \
|
if (c) { \
|
||||||
|
@ -40,6 +42,23 @@ DEFINE_ELEMENT_WISE_METHOD(mul, Mul)
|
||||||
DEFINE_ELEMENT_WISE_METHOD(div, Div)
|
DEFINE_ELEMENT_WISE_METHOD(div, Div)
|
||||||
DEFINE_ELEMENT_WISE_METHOD(pow, Pow)
|
DEFINE_ELEMENT_WISE_METHOD(pow, Pow)
|
||||||
|
|
||||||
|
// see operators/unary.h
|
||||||
|
#define DEFINE_UNARY_METHOD(name, obj) \
|
||||||
|
Tensor GraphHandlerObj::name(Tensor x, Tensor y) { \
|
||||||
|
if (y) { \
|
||||||
|
g->addOpWithOutputs<obj##Obj>(x, y); \
|
||||||
|
return y; \
|
||||||
|
} else { \
|
||||||
|
return g->addOp<obj##Obj>(x, y)->getOutput(); \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
|
DEFINE_UNARY_METHOD(relu, Relu)
|
||||||
|
DEFINE_UNARY_METHOD(sigmoid, Sigmoid)
|
||||||
|
DEFINE_UNARY_METHOD(tanh, Tanh)
|
||||||
|
DEFINE_UNARY_METHOD(softmax, Softmax)
|
||||||
|
DEFINE_UNARY_METHOD(abs, Abs)
|
||||||
|
|
||||||
static DataType dtype_repr_convert(int dtype) {
|
static DataType dtype_repr_convert(int dtype) {
|
||||||
switch ((OnnxDType)dtype) {
|
switch ((OnnxDType)dtype) {
|
||||||
case OnnxDType::FLOAT:
|
case OnnxDType::FLOAT:
|
||||||
|
|
|
@ -43,22 +43,34 @@ void init_graph_builder(py::module &m) {
|
||||||
.def("matmul",
|
.def("matmul",
|
||||||
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
|
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
|
||||||
ActType>(&GraphHandlerObj::matmul),
|
ActType>(&GraphHandlerObj::matmul),
|
||||||
policy::reference_internal)
|
policy::move)
|
||||||
.def("add",
|
.def("add",
|
||||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::add),
|
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::add),
|
||||||
policy::reference_internal)
|
policy::move)
|
||||||
.def("sub",
|
.def("sub",
|
||||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::sub),
|
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::sub),
|
||||||
policy::reference_internal)
|
policy::move)
|
||||||
.def("mul",
|
.def("mul",
|
||||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::mul),
|
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::mul),
|
||||||
policy::reference_internal)
|
policy::move)
|
||||||
.def("div",
|
.def("div",
|
||||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::div),
|
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::div),
|
||||||
policy::reference_internal)
|
policy::move)
|
||||||
.def("pow",
|
.def("pow",
|
||||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::pow),
|
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::pow),
|
||||||
policy::reference_internal);
|
policy::move)
|
||||||
|
.def("relu", py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::relu),
|
||||||
|
policy::move)
|
||||||
|
.def("sigmoid",
|
||||||
|
py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::sigmoid),
|
||||||
|
policy::move)
|
||||||
|
.def("tanh", py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::tanh),
|
||||||
|
policy::reference_internal)
|
||||||
|
.def("softmax",
|
||||||
|
py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::softmax),
|
||||||
|
policy::move)
|
||||||
|
.def("abs", py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::abs),
|
||||||
|
policy::move);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
Loading…
Reference in New Issue