Add: Matmul Transpose plans

Liyan Zheng 2023-04-28 19:13:01 +08:00
parent 95a8b90fa7
commit c875f3cbb8
6 changed files with 172 additions and 70 deletions


@@ -33,9 +33,9 @@ class NMutator : public Mutator {
     bool isMultiBranchMergable(const Graph &in_graph) override;
     void setToNaiveMembound();
-    void setMaxDepth(int _maxDepth) {
-        maxDepth = _maxDepth;
-        searchFilter = true;
+    void setMaxDepth(int _maxDepth) {
+        maxDepth = _maxDepth;
+        searchFilter = true;
     }
     long long cntStates = 0;
     long long cntCandidates = 0;
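Note on the setter kept in this hunk: capping the mutation search depth also switches on `searchFilter`, so a depth-limited search automatically prunes candidates. A minimal usage sketch (the construction details are hypothetical):

    NMutator mutator;
    mutator.setMaxDepth(4); // sets maxDepth = 4 and searchFilter = true together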
@@ -67,11 +67,11 @@ class NMutator : public Mutator {
     // TODO: recover these rules
     Graph transformConvtransposed1x1(Operator _op);
     // Graph transformConvtransposed(Operator op);
-    Graph transformConv1x1(Operator op);
+    vector<Graph> transformConv1x1(Operator op);
     Graph transformG2bmm(Operator op);
     Graph transformGbmm(Operator op);
     Graph transformDialtedConv(Operator _op);
-    Graph transformConv1xk(Operator op);
+    vector<Graph> transformConv1xk(Operator op);
     // Graph transformConv1xk(Operator op);
     Graph transformConvToGEMMReduce(Operator _op);
     Graph transformConvTranposeToGEMMReduce(Operator _op);
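This signature change is the core of the commit: a transform can now return several alternative plans instead of at most one, and "rule not applicable" becomes an empty vector rather than a null graph. A sketch of the resulting calling convention, which the `runSingleOp` hunk later in this commit follows (caller names here are hypothetical):

    // Each rule yields zero or more candidate graphs; all are collected.
    for (const Graph &g : mutator.transformConv1x1(op))
        candidates.emplace_back(g); // an empty result adds nothing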


@@ -146,6 +146,8 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling,
             opCnt[op->getOpType()]++;
             if (!ctcMap[op->getGuid()])
                 opNonCtcCnt[op->getOpType()]++;
+            else
+                opNonCtcCnt[op->getOpType()]; // Create a new entry
         }
     }
     if (profiling)
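The bare `opNonCtcCnt[op->getOpType()];` in the new `else` branch is not a no-op: map subscripting inserts a value-initialized entry (0 for an `int` count) when the key is missing, so compile-time-constant op types still appear in the printed statistics with a zero count. A standalone illustration of that semantics (assuming a `std::map`-style counter, which the subscript-insert trick requires):

    #include <iostream>
    #include <map>
    #include <string>

    int main() {
        std::map<std::string, int> cnt;
        cnt["Conv"]++; // entry created with 0, then incremented to 1
        cnt["Matmul"]; // bare operator[]: entry created and left at 0
        for (const auto &[name, v] : cnt)
            std::cout << name << ": " << v << "\n"; // Conv: 1, Matmul: 0
        return 0;
    }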


@@ -100,6 +100,15 @@ Graph SearchEngine::run(const Graph graph) {
     }
     std::sort(bestGraphs.begin(), bestGraphs.end(), graphTimeComparer);
+    // Check optimized graphs are legal
+    for (auto g : bestGraphs) {
+        g->checkValid();
+        IT_ASSERT(graph->getInputs().size() == g->getInputs().size(),
+                  graph->toString() + string("\n") + g->toString());
+        IT_ASSERT(graph->getOutputs().size() == g->getOutputs().size(),
+                  graph->toString() + string("\n") + g->toString());
+    }
     std::cout << "[INFO] best fused graph: " << std::endl;
     std::cout << "[INFO] perf: " << getEstimatedGraphPerf(bestGraphs[0])
               << std::endl;
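The new check encodes a simple invariant: whatever the mutator does internally, an optimized graph must keep the source graph's external interface, and a violation aborts with both graphs printed for comparison. The asserted condition, restated as a hypothetical standalone helper:

    // Hypothetical helper restating the invariant asserted above.
    bool keepsInterface(const Graph &src, const Graph &opt) {
        return src->getInputs().size() == opt->getInputs().size() &&
               src->getOutputs().size() == opt->getOutputs().size();
    }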


@@ -6,6 +6,7 @@
 #include "nnet/dbg.h"
 #include "operators/conv.h"
 #include "operators/matmul.h"
+#include "operators/any.h"
 #ifdef INFINI_USE_TVM
 #include "tvm/runtime/device_api.h"
 #endif
@@ -140,8 +141,12 @@ double CudaRuntimeObj::timeWithCudaGraph(Graph graph, int rounds) {
             kernel->compute(op, perfData, this);
         else
             kernel->compute(op, this);
-        // FIXME: transpose
-        if (!ctcMap.at(op->getGuid()) && op->getOpType() != OpType::Reshape)
+        bool isFakeOp = (as<AnyObj>(op) &&
+                         as<AnyObj>(op)->getKernelName() == string("FakeOp"));
+        if (as<AnyObj>(op))
+            dbg(op, as<AnyObj>(op)->getKernelName() == string("FakeOp"));
+        if (!ctcMap.at(op->getGuid()) && op->getOpType() != OpType::Reshape &&
+            !isFakeOp)
             kernels.emplace_back(op, kernel, perfData);
     }
     for (auto &[op, kernel, perfData] : kernels) {
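When capturing a CUDA graph for replay, the runtime only records operators that launch real work: compile-time-constant ops, `Reshape` (no kernel behind it), and the new placeholder `FakeOp` are all filtered out, and the leftover `dbg(...)` call logs which `AnyObj` ops were classified as fake. The filter, restated as a hypothetical free function (`AnyObj`, `getKernelName`, and `ctcMap` are the project's own names; the map's key/value types here are an assumption):

    // Sketch of the capture filter used when building the kernel list.
    bool shouldCapture(const Operator &op, const std::map<int, bool> &ctcMap) {
        auto any = as<AnyObj>(op);
        bool isFakeOp = any && any->getKernelName() == "FakeOp";
        return !ctcMap.at(op->getGuid()) &&
               op->getOpType() != OpType::Reshape && !isFakeOp;
    }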


@@ -74,6 +74,7 @@ void any_kernel_mapping(vector<float *> inputs, vector<float *> outputs,
         IT_ASSERT(outputs.size() == 1);
         conv3x3ToReduce(attr[0], attr[1], attr[2], attr[3], inputs[0],
                         outputs[0], inputs.size() > 1 ? inputs[1] : nullptr);
+    } else if (kernelName == "FakeOp") {
     } else {
         std::cout << "Unimplemented AnyOp cuda kernel: " << kernelName
                   << std::endl;


@@ -97,12 +97,10 @@ void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
     if (Graph g = transformConvtransposed1x1(computeOps[0])) {
         out_graphs.emplace_back(g);
     }
-    if (Graph g = transformConv1x1(computeOps[0])) {
+    for (auto g : transformConv1x1(computeOps[0]))
         out_graphs.emplace_back(g);
-    }
-    if (infini::Graph g = transformConv1xk(computeOps[0])) {
+    for (auto g : transformConv1xk(computeOps[0]))
         out_graphs.emplace_back(g);
-    }
     if (Graph g = transformG2bmm(computeOps[0])) {
         out_graphs.emplace_back(g);
     }
@@ -711,31 +709,54 @@ Graph NMutator::transformGbmm(Operator _op) {
     return g;
 }
 
-Graph NMutator::transformConv1x1(Operator _op) {
+vector<Graph> NMutator::transformConv1x1(Operator _op) {
+    vector<Graph> ret;
     auto op = as<ConvObj>(_op);
     if (!op)
-        return nullptr;
+        return {};
     const auto &[ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
     Shape shapeA = op->getInputs(0)->getDims();
     Shape shapeW = op->getInputs(1)->getDims();
     Shape shapeO = op->getOutput()->getDims();
     if (sh != 1 || sw != 1 || dh != 1 || dw != 1 || shapeW[2] != 1 ||
         shapeW[3] != 1)
-        return nullptr;
-    auto g = make_ref<GraphObj>(runtime);
+        return {};
     if (shapeA[0] == 1) {
-        auto A =
-            g->addOp<ReshapeObj>(g->cloneTensor(op->getInputs(0)), nullptr,
-                                 vector{shapeA[1], shapeA[0] * shapeA[2] *
-                                                       shapeA[3]}) // [C, N*H*W]
-                ->getOutput();
-        auto B = g->addOp<ReshapeObj>(g->cloneTensor(op->getInputs(1)), nullptr,
-                                      vector{shapeW[0], shapeW[1]}) // [F, C]
-                     ->getOutput();
-        auto O =
-            g->addOp<MatmulObj>(B, A, nullptr, 0, 0)->getOutput(); // [F, N*H*W]
-        g->addOpWithOutputs<ReshapeObj>(O, g->cloneTensor(op->getOutput()),
-                                        op->getOutput()->getDims());
+        {
+            auto g = make_ref<GraphObj>(runtime);
+            auto A =
+                g->addOp<ReshapeObj>(g->cloneTensor(op->getInputs(0)), nullptr,
+                                     vector{shapeA[1], shapeA[2] * shapeA[3]})
+                    ->getOutput(); // [C, H*W]
+            auto B =
+                g->addOp<ReshapeObj>(g->cloneTensor(op->getInputs(1)), nullptr,
+                                     vector{shapeW[0], shapeW[1]}) // [F, C]
+                    ->getOutput();
+            auto O = g->addOp<MatmulObj>(B, A, nullptr, false, false)
+                         ->getOutput(); // [F, N*H*W]
+            g->addOpWithOutputs<ReshapeObj>(O, g->cloneTensor(op->getOutput()),
+                                            op->getOutput()->getDims());
+            ret.emplace_back(g);
+        }
+        {
+            auto g = make_ref<GraphObj>(runtime);
+            auto A = g->addOp<ReshapeObj>(
+                          g->cloneTensor(op->getInputs(0)), nullptr,
+                          vector{shapeA[1], shapeA[2] * shapeA[3]}) // [C, HW]
+                         ->getOutput();
+            auto B = g->addOp<TransposeObj>(
+                          g->cloneTensor(op->getInputs(1)), //[C,F,1,1]
+                          nullptr, vector{1, 0, 2, 3})
+                         ->getOutput();
+            B = g->addOp<ReshapeObj>(B, nullptr,
+                                     vector{shapeW[1], shapeW[0]}) // [C, F]
+                    ->getOutput();
+            auto O = g->addOp<MatmulObj>(B, A, nullptr, true, false)
+                         ->getOutput(); // [F, N*H*W]
+            g->addOpWithOutputs<ReshapeObj>(O, g->cloneTensor(op->getOutput()),
+                                            op->getOutput()->getDims());
+            ret.emplace_back(g);
+        }
         // } else { // Tranpose + Matmul + Transpose
         //     auto A = g->addOp<TransposeObj>(g->cloneTensor(op->getInputs(0)),
         //                                     nullptr, vector{1, 0, 2, 3})
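The two plans above are numerically interchangeable: for N == 1 and a 1x1 kernel with unit stride and dilation, the convolution collapses to O[f][hw] = sum_c W[f][c] * A[c][hw], and the second plan merely feeds the matmul a pre-transposed weight with transA = true -- the "Matmul Transpose plans" of the commit title, which lets the tuner pick whichever operand layout the GEMM backend prefers. A standalone sanity check of that identity (hypothetical sizes, plain loops instead of the framework):

    #include <cassert>
    #include <vector>

    int main() {
        const int F = 2, C = 3, H = 2, W = 2; // N = 1, 1x1 kernel
        std::vector<float> w(F * C), a(C * H * W);
        for (int i = 0; i < F * C; i++) w[i] = float(i + 1);
        for (int i = 0; i < C * H * W; i++) a[i] = 0.5f * float(i);

        // Plan 1: O = W[F,C] x A[C,HW]
        std::vector<float> o1(F * H * W, 0.f);
        for (int f = 0; f < F; f++)
            for (int p = 0; p < H * W; p++)
                for (int c = 0; c < C; c++)
                    o1[f * H * W + p] += w[f * C + c] * a[c * H * W + p];

        // Plan 2: transpose W to Wt[C,F] first, then O = Wt^T x A (transA = true)
        std::vector<float> wt(C * F), o2(F * H * W, 0.f);
        for (int f = 0; f < F; f++)
            for (int c = 0; c < C; c++)
                wt[c * F + f] = w[f * C + c];
        for (int f = 0; f < F; f++)
            for (int p = 0; p < H * W; p++)
                for (int c = 0; c < C; c++)
                    o2[f * H * W + p] += wt[c * F + f] * a[c * H * W + p];

        for (int i = 0; i < F * H * W; i++)
            assert(o1[i] == o2[i]); // both plans produce identical outputs
        return 0;
    }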
@@ -762,6 +783,7 @@ Graph NMutator::transformConv1x1(Operator _op) {
         //         O, g->cloneTensor(op->getOutput()), vector{1, 0, 2, 3})
         //         ->getOutput(); // [F, N*H*W]
     } else { // BGemm
+        auto g = make_ref<GraphObj>(runtime);
         auto A =
             g->addOp<ReshapeObj>(g->cloneTensor(op->getInputs(0)), nullptr,
                                  vector{shapeA[0], shapeA[1],
@@ -775,17 +797,19 @@ Graph NMutator::transformConv1x1(Operator _op) {
             g->addOp<MatmulObj>(B, A, nullptr, 0, 0)->getOutput(); // [F, N*H*W]
         g->addOpWithOutputs<ReshapeObj>(O, g->cloneTensor(op->getOutput()),
                                         op->getOutput()->getDims());
+        ret.emplace_back(g);
     }
-    return g;
+    return ret;
 }
 
-Graph NMutator::transformConv1xk(Operator _op) {
+vector<Graph> NMutator::transformConv1xk(Operator _op) {
+    vector<Graph> ret;
     auto op = as<ConvObj>(_op);
     if (!op)
-        return nullptr;
+        return {};
     const auto &[ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
     if (sh != 1 || sw != 1 || dh != 1 || dw != 1)
-        return nullptr;
+        return {};
     const auto &[n, c, h, w, f, r, s] = op->getNCHWFRS();
     op->print();
     const auto &A = op->getInputs(0);
@@ -794,46 +818,107 @@ Graph NMutator::transformConv1xk(Operator _op) {
     const Shape &shapeA = A->getDims();
     const Shape &shapeW = W->getDims();
     if (shapeW[2] == 1 || shapeW[3] == 1) {
-        auto g = make_ref<GraphObj>(runtime);
-        auto A0 = g->cloneTensor(A);
-        auto W0 = g->cloneTensor(W); // [F, C, R, S]
-        auto A1 = g->addOp<TransposeObj>(A0, nullptr, vector<int>{0, 2, 3, 1})
-                      ->getOutput(); // [N, H, W, C]
-        auto A2 =
-            g->addOp<ReshapeObj>(
-                 A1, nullptr,
-                 vector<int>{shapeA[0] * shapeA[2] * shapeA[3], shapeA[1]})
-                ->getOutput(); // [N*H*W, C]
-        auto W1 = g->addOp<TransposeObj>(W0, nullptr, vector<int>{0, 2, 3, 1})
-                      ->getOutput(); // [F,R,S,C]
-        auto W2 =
-            g->addOp<ReshapeObj>(
-                 W1, nullptr,
-                 vector<int>{shapeW[2] * shapeW[3] * shapeW[0], shapeW[1]})
-                ->getOutput(); // [F*R*S, C]
-        auto O0 = g->addOp<MatmulObj>(A2, W2, nullptr, 0, 1)
-                      ->getOutput(); // [N*H*W, F*R*S]
-        vector<int> args{op->getAct() != ActType::None,
-                         n,
-                         h,
-                         w,
-                         f,
-                         r,
-                         s,
-                         O->getDims()[2],
-                         O->getDims()[3],
-                         ph,
-                         pw,
-                         sh,
-                         sw,
-                         dh,
-                         dw};
-        const string kernelName = "reduceConvRxSToNCHW";
-        auto O3 = g->addOpWithOutputs<AnyObj>(
-            vector{O0}, vector{g->cloneTensor(O)}, kernelName, args);
-        return g;
+        {
+            auto g = make_ref<GraphObj>(runtime);
+            auto A0 = g->cloneTensor(A);
+            auto W0 = g->cloneTensor(W); // [F, C, R, S]
+            auto A1 =
+                g->addOp<TransposeObj>(A0, nullptr, vector<int>{0, 2, 3, 1})
+                    ->getOutput(); // [N, H, W, C]
+            auto A2 =
+                g->addOp<ReshapeObj>(
+                     A1, nullptr,
+                     vector<int>{shapeA[0] * shapeA[2] * shapeA[3], shapeA[1]})
+                    ->getOutput(); // [N*H*W, C]
+            // auto A2 =
+            //     g->addTensor({shapeA[0] * shapeA[2] * shapeA[3], shapeA[1]});
+            // dbg(A0, A2);
+            // g->addOpWithOutputs<AnyObj>(vector{A0}, vector{A2},
+            //                             string("FakeOp"), vector<int>{});
+            auto W1 =
+                g->addOp<TransposeObj>(W0, nullptr, vector<int>{0, 2, 3, 1})
+                    ->getOutput(); // [F,R,S,C]
+            auto W2 =
+                g->addOp<ReshapeObj>(
+                     W1, nullptr,
+                     vector<int>{shapeW[2] * shapeW[3] * shapeW[0], shapeW[1]})
+                    ->getOutput(); // [F*R*S, C]
+            auto O0 =
+                g->addOp<MatmulObj>(A2, W2, nullptr, 0, 1) // Original: W X A
+                    ->getOutput(); // [N*H*W, F*R*S]
+            vector<int> args{op->getAct() != ActType::None,
+                             n,
+                             h,
+                             w,
+                             f,
+                             r,
+                             s,
+                             O->getDims()[2],
+                             O->getDims()[3],
+                             ph,
+                             pw,
+                             sh,
+                             sw,
+                             dh,
+                             dw};
+            const string kernelName = "reduceConvRxSToNCHW";
+            // const string kernelName = "FakeOp";
+            auto O3 = g->addOpWithOutputs<AnyObj>(
+                vector{O0}, vector{g->cloneTensor(O)}, kernelName, args);
+            ret.emplace_back(g);
+        }
+        {
+            auto g = make_ref<GraphObj>(runtime);
+            auto A0 = g->cloneTensor(A);
+            auto W0 = g->cloneTensor(W); // [F, C, R, S]
+            auto A1 =
+                g->addOp<TransposeObj>(A0, nullptr, vector<int>{0, 2, 3, 1})
+                    ->getOutput(); // [N, H, W, C]
+            auto A2 =
+                g->addOp<ReshapeObj>(
+                     A1, nullptr,
+                     vector<int>{shapeA[0] * shapeA[2] * shapeA[3], shapeA[1]})
+                    ->getOutput(); // [N*H*W, C]
+            // auto A2 = // [N*H*W, C]
+            //     g->addTensor({shapeA[0] * shapeA[2] * shapeA[3], shapeA[1]});
+            // dbg(A0, A2);
+            // g->addOpWithOutputs<AnyObj>(vector{A0}, vector{A2},
+            //                             string("FakeOp"), vector<int>{});
+            auto W1 =
+                g->addOp<TransposeObj>(W0, nullptr, vector<int>{0, 2, 3, 1})
+                    ->getOutput(); // [F,R,S,C]
+            auto W2 =
+                g->addOp<ReshapeObj>(
+                     W1, nullptr,
+                     vector<int>{shapeW[2] * shapeW[3] * shapeW[0], shapeW[1]})
+                    ->getOutput(); // [F*R*S, C]
+            auto O0 =
+                g->addOp<MatmulObj>(W2, A2, nullptr, 0, 1) // Original: W X A
+                    ->getOutput(); // [F*R*S, N*H*W]
+            vector<int> args{op->getAct() != ActType::None,
+                             n,
+                             h,
+                             w,
+                             f,
+                             r,
+                             s,
+                             O->getDims()[2],
+                             O->getDims()[3],
+                             ph,
+                             pw,
+                             sh,
+                             sw,
+                             dh,
+                             dw};
+            // FIXME: FRS,NHW->NHWF
+            const string kernelName = "reduceConvRxSToNCHW";
+            // const string kernelName = "FakeOp";
+            auto O3 = g->addOpWithOutputs<AnyObj>(
+                vector{O0}, vector{g->cloneTensor(O)}, kernelName, args);
+            ret.emplace_back(g);
+        }
     }
-    return nullptr;
+    return ret;
 }
Graph NMutator::constructGraphByOperatorChain(vector<Operator> ops,
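Both 1xk plans flatten the input to [N*H*W, C] and the weight to [F*R*S, C] and differ only in operand order: Matmul(A2, W2, transB = 1) yields [N*H*W, F*R*S], while Matmul(W2, A2, transB = 1) yields the transposed [F*R*S, N*H*W]. The products are transposes of each other, so only the reduce/permute epilogue differs between plans -- which is what the `// FIXME: FRS,NHW->NHWF` note flags, since `reduceConvRxSToNCHW` as reused in the second plan still assumes the first layout. A standalone demo of the transpose relationship (hypothetical sizes):

    #include <cassert>
    #include <vector>

    int main() {
        const int M = 4, K = 3, N = 2; // M ~ N*H*W, K ~ C, N ~ F*R*S
        std::vector<float> a(M * K), b(N * K);
        for (int i = 0; i < M * K; i++) a[i] = float(i + 1);
        for (int i = 0; i < N * K; i++) b[i] = float(2 * i - 3);

        // Plan 1 product: C1 = A x B^T, shape [M, N]
        std::vector<float> c1(M * N, 0.f);
        for (int m = 0; m < M; m++)
            for (int n = 0; n < N; n++)
                for (int k = 0; k < K; k++)
                    c1[m * N + n] += a[m * K + k] * b[n * K + k];

        // Plan 2 product: C2 = B x A^T, shape [N, M]
        std::vector<float> c2(N * M, 0.f);
        for (int n = 0; n < N; n++)
            for (int m = 0; m < M; m++)
                for (int k = 0; k < K; k++)
                    c2[n * M + m] += b[n * K + k] * a[m * K + k];

        // C2 is exactly C1 transposed, so the matmuls are interchangeable
        // as long as the epilogue kernel accounts for the layout.
        for (int m = 0; m < M; m++)
            for (int n = 0; n < N; n++)
                assert(c1[m * N + n] == c2[n * M + m]);
        return 0;
    }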