forked from jiuyuan/InfiniTensor
parent
cc62a3216d
commit
28287f3782
|
@ -2,6 +2,10 @@
|
|||
#include "operators/concat.h"
|
||||
#include "operators/conv.h"
|
||||
#include "operators/gather.h"
|
||||
#include "operators/matmul.h"
|
||||
#include "operators/pad.h"
|
||||
#include "operators/pooling.h"
|
||||
#include "operators/reduce_mean.h"
|
||||
#include "operators/unary.h"
|
||||
#include "optimization/common.h"
|
||||
#include <algorithm>
|
||||
|
@ -137,65 +141,6 @@ optimization::DataType cast(DataType ty) {
|
|||
#undef IT
|
||||
}
|
||||
|
||||
optimization::OpType cast(OpType ty) {
|
||||
#define IT(A, B) \
|
||||
case OpType::A: \
|
||||
return optimization::OpType::B
|
||||
|
||||
switch (ty) {
|
||||
IT(Abs, Abs);
|
||||
IT(Add, Add);
|
||||
IT(And, And);
|
||||
IT(AvgPool, AveragePool);
|
||||
IT(BatchNorm, BatchNormalization);
|
||||
IT(Cast, Cast);
|
||||
IT(Clip, Clip);
|
||||
IT(Concat, Concat);
|
||||
IT(Conv, Conv);
|
||||
IT(ConvTrans, ConvTranspose);
|
||||
IT(Cos, Cos);
|
||||
IT(Div, Div);
|
||||
IT(Dropout, Dropout);
|
||||
IT(Erf, Erf);
|
||||
IT(Exp, Exp);
|
||||
IT(Flatten, Flatten);
|
||||
IT(Gather, Gather);
|
||||
IT(Identity, Identity);
|
||||
IT(Log, Log);
|
||||
IT(Matmul, MatMul);
|
||||
IT(MaxPool, MaxPool);
|
||||
IT(Mul, Mul);
|
||||
IT(Neg, Neg);
|
||||
IT(Not, Not);
|
||||
IT(Or, Or);
|
||||
IT(PRelu, PRelu);
|
||||
IT(Pad, Pad);
|
||||
IT(Pow, Pow);
|
||||
IT(ReduceMean, ReduceMean);
|
||||
IT(Relu, Relu);
|
||||
IT(Reshape, Reshape);
|
||||
IT(Resize, Resize);
|
||||
IT(Shape, Shape);
|
||||
IT(Sigmoid, Sigmoid);
|
||||
IT(Sin, Sin);
|
||||
IT(SinH, Sinh);
|
||||
IT(Slice, Slice);
|
||||
IT(Softmax, Softmax);
|
||||
IT(Split, Split);
|
||||
IT(Sqrt, Sqrt);
|
||||
IT(Sub, Sub);
|
||||
IT(Tan, Tan);
|
||||
IT(TanH, Tanh);
|
||||
IT(Transpose, Transpose);
|
||||
IT(Xor, Xor);
|
||||
default:
|
||||
IT_ASSERT(false);
|
||||
break;
|
||||
}
|
||||
|
||||
#undef IT
|
||||
}
|
||||
|
||||
void GraphObj::optimize() {
|
||||
namespace opt = optimization;
|
||||
|
||||
|
@ -261,10 +206,19 @@ void GraphObj::optimize() {
|
|||
ans.push_operator(opt::OpType::Atanh, std::move(in),
|
||||
std::move(out));
|
||||
break;
|
||||
case OpType::AvgPool:
|
||||
case OpType::AvgPool: {
|
||||
auto obj = as<AvgPoolObj>(op);
|
||||
in.push_back(
|
||||
opt::Tensor::share_vec<int>({obj->getDh(), obj->getDw()}));
|
||||
in.push_back(
|
||||
opt::Tensor::share_vec<int>({obj->getKh(), obj->getKw()}));
|
||||
in.push_back(
|
||||
opt::Tensor::share_vec<int>({obj->getPh(), obj->getPw()}));
|
||||
in.push_back(
|
||||
opt::Tensor::share_vec<int>({obj->getSh(), obj->getSw()}));
|
||||
ans.push_operator(opt::OpType::AveragePool, std::move(in),
|
||||
std::move(out));
|
||||
break;
|
||||
} break;
|
||||
case OpType::BatchNorm:
|
||||
ans.push_operator(opt::OpType::BatchNormalization, std::move(in),
|
||||
std::move(out));
|
||||
|
@ -369,6 +323,130 @@ void GraphObj::optimize() {
|
|||
ans.push_operator(opt::OpType::Identity, std::move(in),
|
||||
std::move(out));
|
||||
break;
|
||||
case OpType::Log:
|
||||
ans.push_operator(opt::OpType::Log, std::move(in), std::move(out));
|
||||
break;
|
||||
case OpType::Matmul: {
|
||||
auto obj = as<MatmulObj>(op);
|
||||
IT_ASSERT(obj->getAct() == ActType::None);
|
||||
in.push_back(opt::Tensor::share_single<float>(1.0f));
|
||||
in.push_back(opt::Tensor::share_single<float>(1.0f));
|
||||
in.push_back(
|
||||
opt::Tensor::share_single<int>(obj->getTransA() ? 1 : 0));
|
||||
in.push_back(
|
||||
opt::Tensor::share_single<int>(obj->getTransB() ? 1 : 0));
|
||||
ans.push_operator(opt::OpType::Gemm, std::move(in), std::move(out));
|
||||
} break;
|
||||
case OpType::Maximum:
|
||||
ans.push_operator(opt::OpType::Max, std::move(in), std::move(out));
|
||||
break;
|
||||
case OpType::MaxPool: {
|
||||
auto obj = as<MaxPoolObj>(op);
|
||||
in.push_back(
|
||||
opt::Tensor::share_vec<int>({obj->getDh(), obj->getDw()}));
|
||||
in.push_back(
|
||||
opt::Tensor::share_vec<int>({obj->getKh(), obj->getKw()}));
|
||||
in.push_back(
|
||||
opt::Tensor::share_vec<int>({obj->getPh(), obj->getPw()}));
|
||||
in.push_back(
|
||||
opt::Tensor::share_vec<int>({obj->getSh(), obj->getSw()}));
|
||||
ans.push_operator(opt::OpType::AveragePool, std::move(in),
|
||||
std::move(out));
|
||||
} break;
|
||||
case OpType::Minimum:
|
||||
ans.push_operator(opt::OpType::Min, std::move(in), std::move(out));
|
||||
break;
|
||||
case OpType::Mul:
|
||||
ans.push_operator(opt::OpType::Mul, std::move(in), std::move(out));
|
||||
break;
|
||||
case OpType::Neg:
|
||||
ans.push_operator(opt::OpType::Neg, std::move(in), std::move(out));
|
||||
break;
|
||||
case OpType::Not:
|
||||
ans.push_operator(opt::OpType::Not, std::move(in), std::move(out));
|
||||
break;
|
||||
case OpType::Or:
|
||||
ans.push_operator(opt::OpType::Or, std::move(in), std::move(out));
|
||||
break;
|
||||
case OpType::Pad:
|
||||
in.push_back(
|
||||
opt::Tensor::share_vec<int>(as<PadObj>(op)->getPads()));
|
||||
ans.push_operator(opt::OpType::Pad, std::move(in), std::move(out));
|
||||
break;
|
||||
case OpType::Reciprocal:
|
||||
ans.push_operator(opt::OpType::Reciprocal, std::move(in),
|
||||
std::move(out));
|
||||
break;
|
||||
case OpType::ReduceMean: {
|
||||
const auto obj = as<ReduceMeanObj>(op);
|
||||
const auto axes = obj->getAxes();
|
||||
in.push_back(
|
||||
opt::Tensor::share_vec<int>(vector(axes.begin(), axes.end())));
|
||||
in.push_back(
|
||||
opt::Tensor::share_single<int>(obj->getKeepDims() ? 1 : 0));
|
||||
ans.push_operator(opt::OpType::ReduceMean, std::move(in),
|
||||
std::move(out));
|
||||
} break;
|
||||
case OpType::Relu:
|
||||
ans.push_operator(opt::OpType::Relu, std::move(in), std::move(out));
|
||||
break;
|
||||
case OpType::Reshape:
|
||||
ans.push_operator(opt::OpType::Reshape, std::move(in),
|
||||
std::move(out));
|
||||
break;
|
||||
case OpType::Resize:
|
||||
ans.push_operator(opt::OpType::Resize, std::move(in),
|
||||
std::move(out));
|
||||
break;
|
||||
case OpType::Round:
|
||||
ans.push_operator(opt::OpType::Round, std::move(in),
|
||||
std::move(out));
|
||||
break;
|
||||
case OpType::Shape:
|
||||
ans.push_operator(opt::OpType::Shape, std::move(in),
|
||||
std::move(out));
|
||||
break;
|
||||
case OpType::Sigmoid:
|
||||
ans.push_operator(opt::OpType::Sigmoid, std::move(in),
|
||||
std::move(out));
|
||||
break;
|
||||
case OpType::Sin:
|
||||
ans.push_operator(opt::OpType::Sin, std::move(in), std::move(out));
|
||||
break;
|
||||
case OpType::SinH:
|
||||
ans.push_operator(opt::OpType::Sinh, std::move(in), std::move(out));
|
||||
break;
|
||||
case OpType::Slice:
|
||||
IT_TODO_HALT();
|
||||
ans.push_operator(opt::OpType::Slice, std::move(in),
|
||||
std::move(out));
|
||||
break;
|
||||
case OpType::Softmax:
|
||||
ans.push_operator(opt::OpType::Softmax, std::move(in),
|
||||
std::move(out));
|
||||
break;
|
||||
case OpType::Split:
|
||||
ans.push_operator(opt::OpType::Split, std::move(in),
|
||||
std::move(out));
|
||||
break;
|
||||
case OpType::Sqrt:
|
||||
ans.push_operator(opt::OpType::Sqrt, std::move(in), std::move(out));
|
||||
break;
|
||||
case OpType::Sub:
|
||||
ans.push_operator(opt::OpType::Sub, std::move(in), std::move(out));
|
||||
break;
|
||||
case OpType::Tan:
|
||||
ans.push_operator(opt::OpType::Tan, std::move(in), std::move(out));
|
||||
break;
|
||||
case OpType::TanH:
|
||||
ans.push_operator(opt::OpType::Tanh, std::move(in), std::move(out));
|
||||
break;
|
||||
case OpType::Transpose:
|
||||
ans.push_operator(opt::OpType::Tanh, std::move(in), std::move(out));
|
||||
break;
|
||||
case OpType::Xor:
|
||||
ans.push_operator(opt::OpType::Xor, std::move(in), std::move(out));
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue