Add: reduce in Any

Liyan Zheng 2023-04-23 21:36:12 +08:00
parent 131a679340
commit 1ba78d7f89
12 changed files with 383 additions and 50 deletions

View File

@ -87,9 +87,10 @@ class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
int repeat = 1000) const;
protected:
void printProfilingData(double totTime,
void printProfilingData(double totalTime,
const std::map<OpType, double> &opTime,
const std::map<OpType, int> &opCnt) const;
const std::map<OpType, int> &opCnt,
const std::map<OpType, int> &opNonCtcCnt) const;
virtual void copyBlobInsideRuntime(void *dst, const void *src,
size_t bytes) const = 0;
};

View File

@ -11,4 +11,9 @@ void convTranspose2dreduce_kernel(float *input, float *bias, float *output,
int act, int n, int h, int w, int f, int r,
int s, int oh, int ow, int ph, int pw, int sh,
int sw, int dh, int dw);
} // namespace infini
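// Reduces a matmul output laid out as [n, h, w, f, r, s] over the r and s taps
// and writes an NCHW tensor of shape [n, f, oh, ow]. bias may be nullptr; a
// non-zero act applies ReLU. Only sh == sw == 1 and dh == dw == 1 are
// supported (see the CUDA implementation).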
void reduceConvRxSToNCHW(float *input, float *bias, float *output, int act,
int n, int h, int w, int f, int r, int s, int oh,
int ow, int ph, int pw, int sh, int sw, int dh,
int dw);
} // namespace infini

View File

@ -31,11 +31,13 @@ class CudaRuntimeObj : public RuntimeObj {
void sync() const;
CudaPtr alloc(size_t size) override {
void *ptr;
// printf("Try to cudaMalloc: %lu bytes\n", size);
checkCudaError(cudaMalloc(&ptr, size));
allocatedGPUMemorySize += size;
allocationMap[ptr] = size;
// printf("cuda malloc: %p %lu bytes, total %lu bytes (%.2lf GB)\n",
// ptr, size, allocatedGPUMemorySize,
// ptr,
// size, allocatedGPUMemorySize,
// double(allocatedGPUMemorySize) / 1024 / 1024 / 1024);
return ptr;
}

View File

@ -10,7 +10,7 @@ class AnyObj : public OperatorObj {
public:
AnyObj(GraphObj *graph, const TensorVec &inputs, const TensorVec &outputs,
string &kernelName, const vector<int> &attr);
const string &kernelName, const vector<int> &attr);
OP_CLONE(AnyObj);

View File

@ -885,7 +885,7 @@ class OnnxStub:
)
ctx.push_node(make_node(ty.name, inputs, outputs, name))
elif ty in [backend.OpType.ConvTransNHWC, backend.OpType.GBMM,
backend.OpType.G2BMM]:
backend.OpType.G2BMM, backend.OpType.Any]:
ctx.push_node(
make_node(
ty.name,

View File

@ -58,8 +58,8 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
opCnt[op->getOpType()]++;
}
}
if (profiling)
printProfilingData(totalTime, opTime, opCnt);
// if (profiling)
// printProfilingData(totalTime, opTime, opCnt);
}
map<UidBaseType, bool>
@ -89,7 +89,7 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling,
// Statistics
double totalTime = 0;
std::map<OpType, double> opTime;
std::map<OpType, int> opCnt;
std::map<OpType, int> opCnt, opNonCtcCnt;
// compile-time computable
map<UidBaseType, bool> ctcMap = getCompileTimeComputableAttribute(graph);
@ -144,21 +144,26 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling,
printf(" op_time %lf\n", time);
opTime[op->getOpType()] += time;
opCnt[op->getOpType()]++;
if (!ctcMap[op->getGuid()])
opNonCtcCnt[op->getOpType()]++;
}
}
if (profiling)
printProfilingData(totalTime, opTime, opCnt);
printProfilingData(totalTime, opTime, opCnt, opNonCtcCnt);
return totalTime;
}
void RuntimeObj::printProfilingData(double totalTime,
const std::map<OpType, double> &opTime,
const std::map<OpType, int> &opCnt) const {
printf("%11s %3s %7s %7s %7s\n", "Op", "Cnt", "T_tot", "Percent", "T_mean");
void RuntimeObj::printProfilingData(
double totalTime, const std::map<OpType, double> &opTime,
const std::map<OpType, int> &opCnt,
const std::map<OpType, int> &opNonCtcCnt) const {
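// #NCtc reports, per operator type, how many instances are not
// compile-time computable and therefore actually run at execution time.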
printf("%11s %3s %5s %7s %7s %7s\n", "Op", "Cnt", "#NCtc", "T_tot",
"Percent", "T_mean");
for (const auto &[type, t] : opTime) {
printf("%11s %3d %7.3f %7.1f %7.3f\n",
OpRegistry::getOpName(type).data(), opCnt.at(type), t,
t / totalTime * 100, t / opCnt.at(type));
printf("%11s %3d %5d %7.3f %7.1f %7.3f\n",
OpRegistry::getOpName(type).data(), opCnt.at(type),
opNonCtcCnt.at(type), t, t / totalTime * 100,
t / opCnt.at(type));
}
}

View File

@ -346,10 +346,10 @@ std::vector<Graph> SearchEngine::searchMutation(const MetaGraph &metaGraph) {
if (node.type == 1) { // If it has computing OPs
auto mutatedGraphs = mutator->run(node.graph);
// // HACK: only try the first one for debug
if (mutatedGraphs.size() >= 2) {
mutatedGraphs.resize(2);
// if (mutatedGraphs.size() > 2)
// mutatedGraphs.resize(2);
if (mutatedGraphs.size() >= 2)
mutatedGraphs = {mutatedGraphs[1]};
}
for (auto graph : graphs) {
for (auto mutatedGraph : mutatedGraphs) {
std::vector<Operator> ops;

View File

@ -101,6 +101,7 @@ void export_values(py::module &m) {
.VALUE(OpType, Resize)
.VALUE(OpType, Dropout)
.VALUE(OpType, MemBound)
.VALUE(OpType, Any)
.export_values();
py::enum_<TensorType>(m, "TensorType")

View File

@ -39,6 +39,19 @@ void any_kernel_mapping(vector<float *> inputs, vector<float *> outputs,
attr[4], attr[5], attr[6], attr[7], attr[8],
attr[9], attr[10], attr[11], attr[12], attr[13],
attr[14]);
} else if (kernelName == "reduceConvRxSToNCHW") {
IT_ASSERT(attr.size() == 15);
IT_ASSERT(inputs.size() == 1 || inputs.size() == 2);
IT_ASSERT(outputs.size() == 1);
// float *input, float *bias, float *output, int act,
// int n, int h, int w, int f, int r, int s,
// int oh, int ow, int ph, int pw, int sh, int
// sw, int dh, int dw
reduceConvRxSToNCHW(inputs[0], inputs.size() > 1 ? inputs[1] : nullptr,
outputs[0], attr[0], attr[1], attr[2], attr[3],
attr[4], attr[5], attr[6], attr[7], attr[8],
attr[9], attr[10], attr[11], attr[12], attr[13],
attr[14]);
} else {
std::cout << "Unimplemented AnyOp cuda kernel: " << kernelName
<< std::endl;

View File

@ -0,0 +1,280 @@
#include "core/common.h"
#include <vector>
using namespace std;
template <class T>
__global__ void reduce_merge_conv_3x3_1x1(
T *__restrict__ input, T *__restrict__ output, T *__restrict__ bias,
const int N, const int H, const int W, const int F, const int N_offset,
const int H_offset, const int W_offset, const int F_offset,
const int out_N_offset, const int out_F_offset, const int out_H_offset,
const int out_W_offset, const int num) {
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
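// Decompose the flat output index tid into (n, f, h, w) using the output strides.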
int tmptid = tid;
const int n = (tmptid / out_N_offset);
tmptid -= n * out_N_offset;
const int f = tmptid / out_F_offset;
tmptid -= f * out_F_offset;
const int h = tmptid / out_H_offset;
tmptid -= h * out_H_offset;
const int w = tmptid / out_W_offset;
const int noff = n * N_offset;
const int hoff = h * H_offset;
const int woff = w * W_offset;
const int foff = f * F_offset;
input += noff + foff + woff + hoff;
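// The innermost dimension packs 10 partial products per output element:
// the nine 3x3 taps (indices 0-8) followed by the merged 1x1 tap (index 9).
// input[4] is the 3x3 center tap; the remaining taps are gathered from
// neighboring pixels with boundary checks.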
T res = 0;
res += input[4];
res += input[9];
if (h < H - 1) {
res += input[H_offset + 7];
if (w < W - 1)
res += input[H_offset + W_offset + 8];
if (w > 0)
res += input[H_offset - W_offset + 6];
}
if (h > 0) {
res += input[1 - H_offset];
if (w < W - 1)
res += input[W_offset - H_offset + 2];
if (w > 0)
res += input[-1 * H_offset - W_offset];
}
if (w < W - 1)
res += input[5 + W_offset];
if (w > 0)
res += input[3 - W_offset];
output[tid] = max(res + bias[f], 0.f);
}
}
template <class T>
__global__ void reduce_merge_conv_3x3(
T *__restrict__ input, T *__restrict__ output, T *__restrict__ bias,
const int N, const int H, const int W, const int F, const int N_offset,
const int H_offset, const int W_offset, const int F_offset,
const int out_N_offset, const int out_F_offset, const int out_H_offset,
const int out_W_offset, const int num, const int act) {
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
int tmptid = tid;
const int n = (tmptid / out_N_offset);
tmptid -= n * out_N_offset;
const int f = tmptid / out_F_offset;
tmptid -= f * out_F_offset;
const int h = tmptid / out_H_offset;
tmptid -= h * out_H_offset;
const int w = tmptid / out_W_offset;
const int noff = n * N_offset;
const int hoff = h * H_offset;
const int woff = w * W_offset;
const int foff = f * F_offset;
input += noff + foff + woff + hoff;
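// Same gather as reduce_merge_conv_3x3_1x1, but the innermost dimension holds
// only the nine 3x3 taps and ReLU is applied only when act != 0.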
T res = 0;
res += input[4];
if (h < H - 1) {
res += input[H_offset + 7];
if (w < W - 1)
res += input[H_offset + W_offset + 8];
if (w > 0)
res += input[H_offset - W_offset + 6];
}
if (h > 0) {
res += input[1 - H_offset];
if (w < W - 1)
res += input[W_offset - H_offset + 2];
if (w > 0)
res += input[-1 * H_offset - W_offset];
}
if (w < W - 1)
res += input[5 + W_offset];
if (w > 0)
res += input[3 - W_offset];
if (act)
output[tid] = max(res + bias[f], 0.f);
else
output[tid] = res + bias[f];
}
}
template <class T>
__global__ void
reduce_2(T *__restrict__ input, T *__restrict__ output, T *__restrict__ bias,
const int N, const int F, const int H, const int W, const int N_offset,
const int F_offset, const int H_offset, const int W_offset,
const int out_N_offset, const int out_F_offset, const int out_H_offset,
const int out_W_offset, const int num) {
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
int tmptid = tid;
const int n = tmptid / out_N_offset;
tmptid -= n * out_N_offset;
const int f = tmptid / out_F_offset;
tmptid -= f * out_F_offset;
const int h = tmptid / out_H_offset;
tmptid -= h * out_H_offset;
const int w = tmptid / out_W_offset;
const int noff = n * N_offset;
const int foff = f * F_offset * 4;
const int hoff = h * H_offset;
const int woff = w * W_offset;
input += noff + foff + woff + hoff;
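// foff selects channel group 0 of four (stride F_offset * 4 per f); the
// guarded reads below add the corresponding elements of groups 1-3 at offsets
// of 3 rows and/or 3 columns.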
T res = input[0];
if (w != W - 1)
res += input[F_offset * 2 + 3];
if (h != H - 1) {
res += input[F_offset + 3 * H_offset];
if (w != W - 1)
res += input[F_offset * 3 + 3 * H_offset + 3];
}
output[tid] = max(res + bias[f], 0.f);
}
}
__global__ void reduceConvRxSToNCHWKernel(
float *__restrict__ input, float *__restrict__ bias,
float *__restrict__ output, const int act, const int n, const int f,
const int h, const int w, const int oh, const int ow, const int r,
const int s, const int ph, const int pw, const int dh, const int dw) {
// input shape: (n, h, w, f, r, s)
// output shape: (n, f, oh, ow)
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const int out_N_offset = f * oh * ow, out_F_offset = oh * ow,
out_H_offset = ow, out_W_offset = 1;
const int num = out_N_offset * n;
if (tid < num) {
// output index
int tmptid = tid;
const int nid = (tmptid / out_N_offset);
tmptid -= nid * out_N_offset;
const int fid = tmptid / out_F_offset;
tmptid -= fid * out_F_offset;
const int hid = tmptid / out_H_offset;
tmptid -= hid * out_H_offset;
const int wid = tmptid / out_W_offset;
// Input index
const int fchunck = r * s, wchunk = f * fchunck, hchunk = w * wchunk,
nchunck = h * hchunk;
float *__restrict__ nfinput = input + nid * nchunck + fid * fchunck;
float imm = 0.0;
const int ihst = hid, iwst = wid;
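// Center the RxS window on the output pixel via (ri - r/2) * dh, which
// implicitly assumes ph == r/2 and pw == s/2; ph and pw are accepted but not
// otherwise used. Taps falling outside the input are skipped.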
for (int ri = 0; ri < r; ++ri) {
for (int si = 0; si < s; ++si) {
int ihid = ihst + (ri - r / 2) * dh;
int iwid = iwst + (si - s / 2) * dw;
if (ihid >= 0 && ihid < h && iwid >= 0 && iwid < w) {
imm += *(nfinput + ihid * hchunk + iwid * wchunk + ri * s +
si);
}
}
}
if (bias) {
imm += bias[fid];
}
if (act) {
imm = imm > 0.0 ? imm : 0;
}
output[tid] = imm;
}
}
namespace infini {
void hetConvToMMReduce(int n, int h, int w, int f, float *input, float *output,
float *bias) {
const int kBlockSize = 128;
vector<int> in_params = {n, h, w, f}; // NHWF
vector<int> out_params = {n, f, h, w};
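// Innermost input stride is 10: the nine 3x3 taps plus the merged 1x1 tap per
// output element.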
int in_base = 10;
int out_base = 1;
vector<int> in_offsets;
vector<int> out_offsets;
for (int i = 0; i < 4; ++i) {
in_offsets.push_back(in_base);
in_base *= in_params[3 - i];
out_offsets.push_back(out_base);
out_base *= out_params[3 - i];
}
reduce_merge_conv_3x3_1x1<float>
<<<(out_base + kBlockSize - 1) / kBlockSize, kBlockSize>>>(
input, output, bias, in_params[0], in_params[1], in_params[2],
in_params[3], in_offsets[3], in_offsets[2], in_offsets[1],
in_offsets[0], out_offsets[3], out_offsets[2], out_offsets[1],
out_offsets[0], out_base);
}
void conv5x5ToConv3x3Reduce(int n, int f, int h, int w, float *input,
float *output, float *bias) {
const int kBlockSize = 128;
vector<int> params{n, f, h, w}; // NFHW
vector<int> ranges(4);
ranges[3] = params[3] + 2;
ranges[2] = params[2] + 2;
ranges[1] = params[1] * 4;
ranges[0] = params[0];
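// The intermediate input tensor is [N, 4*F, H+2, W+2]: four 3x3 partial
// outputs per output channel on an enlarged spatial grid, combined by reduce_2.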
int in_base = 1;
int out_base = 1;
vector<int> in_offsets;
vector<int> out_offsets;
for (int i = 0; i < 4; ++i) {
in_offsets.push_back(in_base);
in_base *= ranges[3 - i];
out_offsets.push_back(out_base);
out_base *= params[3 - i];
}
reduce_2<float><<<(out_base + kBlockSize - 1) / kBlockSize, kBlockSize>>>(
input, output, bias, params[0], params[1], params[2], params[3],
in_offsets[3], in_offsets[2], in_offsets[1], in_offsets[0],
out_offsets[3], out_offsets[2], out_offsets[1], out_offsets[0],
out_base);
}
// [NHW,FRS] -> [NFHW]
void conv3x3ToReduce(int n, int h, int w, int f, float *input, float *output,
float *bias) {
const int kBlockSize = 128;
vector<int> in_params = {n, h, w, f}; // NHWF
vector<int> out_params = {n, f, h, w};
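// Innermost input stride is 9: one partial product per tap of the 3x3 kernel.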
int in_base = 9;
int out_base = 1;
vector<int> in_offsets;
vector<int> out_offsets;
for (int i = 0; i < 4; ++i) {
in_offsets.push_back(in_base);
in_base *= in_params[3 - i];
out_offsets.push_back(out_base);
out_base *= out_params[3 - i];
}
reduce_merge_conv_3x3<float>
<<<(out_base + kBlockSize - 1) / kBlockSize, kBlockSize>>>(
input, output, bias, in_params[0], in_params[1], in_params[2],
in_params[3], in_offsets[3], in_offsets[2], in_offsets[1],
in_offsets[0], out_offsets[3], out_offsets[2], out_offsets[1],
out_offsets[0], out_base, 0);
}
void reduceConvRxSToNCHW(float *input, float *bias, float *output, int act,
int n, int h, int w, int f, int r, int s, int oh,
int ow, int ph, int pw, int sh, int sw, int dh,
int dw) {
IT_ASSERT(sh == 1 && sw == 1,
"reduceConvRxSToNCHW only supports sh=sw=1");
IT_ASSERT(dh == 1 && dw == 1,
"reduceConvRxSToNCHW only supports dh=dw=1");
const int blocksize = 512;
const int gridsize = (n * f * oh * ow + blocksize - 1) / blocksize;
cudaStream_t stream(cudaStreamPerThread);
reduceConvRxSToNCHWKernel<<<gridsize, blocksize, 0, stream>>>(
input, bias, output, act, n, f, h, w, oh, ow, r, s, ph, pw, dh, dw);
}
} // namespace infini

View File

@ -9,6 +9,7 @@
#include "nnet/derivator.h"
#include "operators/G2BMM.h"
#include "operators/GBMM.h"
#include "operators/any.h"
#include "operators/conv.h"
#include "operators/matmul.h"
#include "operators/membound.h"
@ -249,6 +250,11 @@ void NMutator::runMultipleOps(Graph in_graph, std::vector<Graph> &out_graphs) {
// }
nnet::Expr NMutator::opToExpression(Operator opT) {
if (auto op = as<ConvObj>(opT)) {
if (op->getSh() != 1 || op->getSw() != 1 || op->getDh() != 1 ||
op->getDw() != 1)
return nullptr;
}
auto [expr, mapNameNToTensorT] = extractOp(opT);
IT_ASSERT(expr,
"Cannot convert " + opT->toString() + " to an NNet expression");
@ -625,16 +631,14 @@ Graph NMutator::transformConv1x1(Operator _op) {
if (!op)
return nullptr;
const auto &[ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
if (sh != 1 || sw != 1 || dh != 1 || dw != 1)
return nullptr;
Shape shapeA = op->getInputs(0)->getDims();
Shape shapeW = op->getInputs(1)->getDims();
// TODO: support batch size > 1
if (shapeA[0] != 1)
Shape shapeO = op->getOutput()->getDims();
if (sh != 1 || sw != 1 || dh != 1 || dw != 1 || shapeW[2] != 1 ||
shapeW[3] != 1)
return nullptr;
if (op->getPh() == 0 && op->getSh() == 1 && shapeW[2] == 1 &&
shapeW[3] == 1) {
auto g = make_ref<GraphObj>(runtime);
auto g = make_ref<GraphObj>(runtime);
if (shapeA[0] == 1) {
auto A =
g->addOp<ReshapeObj>(g->cloneTensor(op->getInputs(0)), nullptr,
vector{shapeA[1], shapeA[0] * shapeA[2] *
@ -647,9 +651,27 @@ Graph NMutator::transformConv1x1(Operator _op) {
g->addOp<MatmulObj>(B, A, nullptr, 0, 0)->getOutput(); // [F, N*H*W]
g->addOpWithOutputs<ReshapeObj>(O, g->cloneTensor(op->getOutput()),
op->getOutput()->getDims());
return g;
} else {
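// Batch size > 1: move C to the outermost axis so the 1x1 conv becomes a
// single [F, C] x [C, N*H*W] matmul, then restore NCHW with a reshape and
// transpose.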
auto A = g->addOp<TransposeObj>(g->cloneTensor(op->getInputs(0)),
nullptr, vector{1, 0, 2, 3})
->getOutput(); // [C,N,H,W]
A = g->addOp<ReshapeObj>(A, nullptr,
vector{shapeA[1], shapeA[0] * shapeA[2] *
shapeA[3]}) // [C, N*H*W]
->getOutput();
auto B = g->addOp<ReshapeObj>(g->cloneTensor(op->getInputs(1)), nullptr,
vector{shapeW[0], shapeW[1]}) // [F, C]
->getOutput();
auto O =
g->addOp<MatmulObj>(B, A, nullptr, 0, 0)->getOutput(); // [F, NHW]
O = g->addOp<ReshapeObj>(
O, nullptr, Shape{shapeO[1], shapeO[0], shapeO[2], shapeO[3]})
->getOutput(); // [F, N, H, W]
O = g->addOpWithOutputs<TransposeObj>(
O, g->cloneTensor(op->getOutput()), vector{1, 0, 2, 3})
->getOutput(); // [N, F, H, W]
}
return nullptr;
return g;
}
Graph NMutator::transformConv1xk(Operator _op) {
@ -659,6 +681,7 @@ Graph NMutator::transformConv1xk(Operator _op) {
const auto &[ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
if (sh != 1 || sw != 1 || dh != 1 || dw != 1)
return nullptr;
const auto &[n, c, h, w, f, r, s] = op->getNCHWFRS();
op->print();
const auto &A = op->getInputs(0);
const auto &W = op->getInputs(1);
@ -668,7 +691,7 @@ Graph NMutator::transformConv1xk(Operator _op) {
if (shapeW[2] == 1 || shapeW[3] == 1) {
auto g = make_ref<GraphObj>(runtime);
auto A0 = g->cloneTensor(A);
auto W0 = g->cloneTensor(W);
auto W0 = g->cloneTensor(W); // [F, C, R, S]
auto A1 = g->addOp<TransposeObj>(A0, nullptr, vector<int>{0, 2, 3, 1})
->getOutput(); // [N, H, W, C]
auto A2 =
@ -677,29 +700,32 @@ Graph NMutator::transformConv1xk(Operator _op) {
vector<int>{shapeA[0] * shapeA[2] * shapeA[3], shapeA[1]})
->getOutput(); // [N*H*W, C]
auto W1 = g->addOp<TransposeObj>(W0, nullptr, vector<int>{0, 2, 3, 1})
->getOutput(); // [R, S, F, C]
->getOutput(); // [F,R,S,C]
auto W2 =
g->addOp<ReshapeObj>(
W1, nullptr,
vector<int>{shapeW[2] * shapeW[3] * shapeW[0], shapeW[1]})
->getOutput(); // [R*S*F, C]
auto O0 = g->addOp<MatmulObj>(W2, A2, nullptr, 0, 1)
->getOutput(); // [R*S*F, N*H*W]
auto O1 =
g->addOp<ReshapeObj>(O0, nullptr,
vector<int>{shapeW[2] * shapeW[3], shapeW[0],
shapeA[0] * shapeA[2] * shapeA[3]})
->getOutput(); // [R*S, F, N*H*W]
auto O2 = g->addOp<ReduceMeanObj>(O1, nullptr, optional{vector<int>{0}},
false)
->getOutput(); // [F, N*H*W]
auto O3 = g->addOp<ReshapeObj>(
O2, nullptr,
vector<int>{shapeW[0], shapeA[0], shapeA[2], shapeA[3]})
->getOutput(); // [F, N, H, W]
g->addOpWithOutputs<TransposeObj>(O3, g->cloneTensor(O),
vector<int>{1, 0, 2, 3});
std::cout << "Replace 1xk/kx1 conv successfully" << std::endl;
->getOutput(); // [F*R*S, C]
auto O0 = g->addOp<MatmulObj>(A2, W2, nullptr, 0, 1)
->getOutput(); // [N*H*W, F*R*S]
vector<int> args{op->getAct() != ActType::None,
n,
h,
w,
f,
r,
s,
O->getDims()[2],
O->getDims()[3],
ph,
pw,
sh,
sw,
dh,
dw};
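// 15 attributes, unpacked as attr[0..14] by the reduceConvRxSToNCHW branch
// of any_kernel_mapping.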
const string kernelName = "reduceConvRxSToNCHW";
auto O3 = g->addOpWithOutputs<AnyObj>(
vector{O0}, vector{g->cloneTensor(O)}, kernelName, args);
return g;
}
return nullptr;

View File

@ -3,7 +3,7 @@
namespace infini {
AnyObj::AnyObj(GraphObj *graph, const TensorVec &inputs,
const TensorVec &outputs, string &kernelName,
const TensorVec &outputs, const string &kernelName,
const vector<int> &attr)
: OperatorObj(OpType::Any, inputs, outputs), kernelName(kernelName),
attr(attr) {