feat: 完成导入

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-04-23 13:01:42 +08:00
parent cc62a3216d
commit 28287f3782
1 changed files with 139 additions and 61 deletions

View File

@ -2,6 +2,10 @@
#include "operators/concat.h"
#include "operators/conv.h"
#include "operators/gather.h"
#include "operators/matmul.h"
#include "operators/pad.h"
#include "operators/pooling.h"
#include "operators/reduce_mean.h"
#include "operators/unary.h"
#include "optimization/common.h"
#include <algorithm>
@ -137,65 +141,6 @@ optimization::DataType cast(DataType ty) {
#undef IT
}
optimization::OpType cast(OpType ty) {
#define IT(A, B) \
case OpType::A: \
return optimization::OpType::B
switch (ty) {
IT(Abs, Abs);
IT(Add, Add);
IT(And, And);
IT(AvgPool, AveragePool);
IT(BatchNorm, BatchNormalization);
IT(Cast, Cast);
IT(Clip, Clip);
IT(Concat, Concat);
IT(Conv, Conv);
IT(ConvTrans, ConvTranspose);
IT(Cos, Cos);
IT(Div, Div);
IT(Dropout, Dropout);
IT(Erf, Erf);
IT(Exp, Exp);
IT(Flatten, Flatten);
IT(Gather, Gather);
IT(Identity, Identity);
IT(Log, Log);
IT(Matmul, MatMul);
IT(MaxPool, MaxPool);
IT(Mul, Mul);
IT(Neg, Neg);
IT(Not, Not);
IT(Or, Or);
IT(PRelu, PRelu);
IT(Pad, Pad);
IT(Pow, Pow);
IT(ReduceMean, ReduceMean);
IT(Relu, Relu);
IT(Reshape, Reshape);
IT(Resize, Resize);
IT(Shape, Shape);
IT(Sigmoid, Sigmoid);
IT(Sin, Sin);
IT(SinH, Sinh);
IT(Slice, Slice);
IT(Softmax, Softmax);
IT(Split, Split);
IT(Sqrt, Sqrt);
IT(Sub, Sub);
IT(Tan, Tan);
IT(TanH, Tanh);
IT(Transpose, Transpose);
IT(Xor, Xor);
default:
IT_ASSERT(false);
break;
}
#undef IT
}
void GraphObj::optimize() {
namespace opt = optimization;
@ -261,10 +206,19 @@ void GraphObj::optimize() {
ans.push_operator(opt::OpType::Atanh, std::move(in),
std::move(out));
break;
case OpType::AvgPool:
case OpType::AvgPool: {
auto obj = as<AvgPoolObj>(op);
in.push_back(
opt::Tensor::share_vec<int>({obj->getDh(), obj->getDw()}));
in.push_back(
opt::Tensor::share_vec<int>({obj->getKh(), obj->getKw()}));
in.push_back(
opt::Tensor::share_vec<int>({obj->getPh(), obj->getPw()}));
in.push_back(
opt::Tensor::share_vec<int>({obj->getSh(), obj->getSw()}));
ans.push_operator(opt::OpType::AveragePool, std::move(in),
std::move(out));
break;
} break;
case OpType::BatchNorm:
ans.push_operator(opt::OpType::BatchNormalization, std::move(in),
std::move(out));
@ -369,6 +323,130 @@ void GraphObj::optimize() {
ans.push_operator(opt::OpType::Identity, std::move(in),
std::move(out));
break;
case OpType::Log:
ans.push_operator(opt::OpType::Log, std::move(in), std::move(out));
break;
case OpType::Matmul: {
auto obj = as<MatmulObj>(op);
IT_ASSERT(obj->getAct() == ActType::None);
in.push_back(opt::Tensor::share_single<float>(1.0f));
in.push_back(opt::Tensor::share_single<float>(1.0f));
in.push_back(
opt::Tensor::share_single<int>(obj->getTransA() ? 1 : 0));
in.push_back(
opt::Tensor::share_single<int>(obj->getTransB() ? 1 : 0));
ans.push_operator(opt::OpType::Gemm, std::move(in), std::move(out));
} break;
case OpType::Maximum:
ans.push_operator(opt::OpType::Max, std::move(in), std::move(out));
break;
case OpType::MaxPool: {
auto obj = as<MaxPoolObj>(op);
in.push_back(
opt::Tensor::share_vec<int>({obj->getDh(), obj->getDw()}));
in.push_back(
opt::Tensor::share_vec<int>({obj->getKh(), obj->getKw()}));
in.push_back(
opt::Tensor::share_vec<int>({obj->getPh(), obj->getPw()}));
in.push_back(
opt::Tensor::share_vec<int>({obj->getSh(), obj->getSw()}));
ans.push_operator(opt::OpType::AveragePool, std::move(in),
std::move(out));
} break;
case OpType::Minimum:
ans.push_operator(opt::OpType::Min, std::move(in), std::move(out));
break;
case OpType::Mul:
ans.push_operator(opt::OpType::Mul, std::move(in), std::move(out));
break;
case OpType::Neg:
ans.push_operator(opt::OpType::Neg, std::move(in), std::move(out));
break;
case OpType::Not:
ans.push_operator(opt::OpType::Not, std::move(in), std::move(out));
break;
case OpType::Or:
ans.push_operator(opt::OpType::Or, std::move(in), std::move(out));
break;
case OpType::Pad:
in.push_back(
opt::Tensor::share_vec<int>(as<PadObj>(op)->getPads()));
ans.push_operator(opt::OpType::Pad, std::move(in), std::move(out));
break;
case OpType::Reciprocal:
ans.push_operator(opt::OpType::Reciprocal, std::move(in),
std::move(out));
break;
case OpType::ReduceMean: {
const auto obj = as<ReduceMeanObj>(op);
const auto axes = obj->getAxes();
in.push_back(
opt::Tensor::share_vec<int>(vector(axes.begin(), axes.end())));
in.push_back(
opt::Tensor::share_single<int>(obj->getKeepDims() ? 1 : 0));
ans.push_operator(opt::OpType::ReduceMean, std::move(in),
std::move(out));
} break;
case OpType::Relu:
ans.push_operator(opt::OpType::Relu, std::move(in), std::move(out));
break;
case OpType::Reshape:
ans.push_operator(opt::OpType::Reshape, std::move(in),
std::move(out));
break;
case OpType::Resize:
ans.push_operator(opt::OpType::Resize, std::move(in),
std::move(out));
break;
case OpType::Round:
ans.push_operator(opt::OpType::Round, std::move(in),
std::move(out));
break;
case OpType::Shape:
ans.push_operator(opt::OpType::Shape, std::move(in),
std::move(out));
break;
case OpType::Sigmoid:
ans.push_operator(opt::OpType::Sigmoid, std::move(in),
std::move(out));
break;
case OpType::Sin:
ans.push_operator(opt::OpType::Sin, std::move(in), std::move(out));
break;
case OpType::SinH:
ans.push_operator(opt::OpType::Sinh, std::move(in), std::move(out));
break;
case OpType::Slice:
IT_TODO_HALT();
ans.push_operator(opt::OpType::Slice, std::move(in),
std::move(out));
break;
case OpType::Softmax:
ans.push_operator(opt::OpType::Softmax, std::move(in),
std::move(out));
break;
case OpType::Split:
ans.push_operator(opt::OpType::Split, std::move(in),
std::move(out));
break;
case OpType::Sqrt:
ans.push_operator(opt::OpType::Sqrt, std::move(in), std::move(out));
break;
case OpType::Sub:
ans.push_operator(opt::OpType::Sub, std::move(in), std::move(out));
break;
case OpType::Tan:
ans.push_operator(opt::OpType::Tan, std::move(in), std::move(out));
break;
case OpType::TanH:
ans.push_operator(opt::OpType::Tanh, std::move(in), std::move(out));
break;
case OpType::Transpose:
ans.push_operator(opt::OpType::Tanh, std::move(in), std::move(out));
break;
case OpType::Xor:
ans.push_operator(opt::OpType::Xor, std::move(in), std::move(out));
break;
default:
break;
}