Add: Longformer models

Liyan Zheng 2023-04-22 16:00:29 +08:00
parent 4f02eeb08c
commit 84f9d6731a
11 changed files with 240 additions and 32 deletions

View File

@@ -26,6 +26,7 @@ class NMutator : public Mutator {
vector<Graph> run(const Graph &in_graph) override;
Graph fuseVertically(const Graph &in_graph) override;
bool isMultiBranchMergable(const Graph &in_graph) override;
void setToNaiveMembound();
void setMaxDepth(int _maxDepth) { maxDepth = _maxDepth; }
@@ -61,7 +62,12 @@ class NMutator : public Mutator {
// Graph transformConvtransposed(Operator op);
// Graph transformDialtedConv(Operator op);
Graph transformConv1x1(Operator op);
Graph transformG2bmm(Operator op);
Graph transformGbmm(Operator op);
// Graph transformConv1xk(Operator op);
Tensor splitTransposeMerge(Graph g, Tensor A, int dim, int chunkSize,
Tensor output = nullptr);
};
} // namespace infini

View File

@@ -6,6 +6,7 @@
namespace infini {
Graph getGANGraph(int batch, Runtime runtime, int nLayers, int modelId);
Graph getLongformer(Runtime runtime, int bs);
vector<Tensor> runInfoGAN(int nLayers);
Graph getConvtransposedNHWC(Runtime runtime, Shape shape, int layerId);
Graph optimizeGraph(Graph g, Runtime _runtime, bool tuning, NMutator::Mode mode,

View File

@@ -19,7 +19,7 @@ class ReshapeObj : public OperatorObj {
* @param output The output tensor.
* @param dims The shape of the output tensor.
*/
ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims);
ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims = {});
OP_CLONE(ReshapeObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
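Note: with dims defaulting to an empty shape, a Reshape can now take its target shape from a pre-allocated output tensor; the constructor change at the end of this commit falls back to output->getDims() when dims is empty. A minimal usage sketch, lifted from the getLongformer graph added below:

auto q0 = g->addTensor({bs, seqlen, hidden}, DataType::Float32, TensorType::Other);
auto q1 = g->addTensor({bs, seqlen, heads, hiddenPerHead}, DataType::Float32, TensorType::Other);
g->addOpWithOutputs<ReshapeObj>(q0, q1); // no dims argument: the output shape comes from q1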

View File

@@ -884,7 +884,8 @@ class OnnxStub:
ctx.push_data_input(name, "max", TensorProto.FLOAT, [], [])
)
ctx.push_node(make_node(ty.name, inputs, outputs, name))
elif ty == backend.OpType.ConvTransNHWC:
elif ty in [backend.OpType.ConvTransNHWC, backend.OpType.GBMM,
backend.OpType.G2BMM]:
ctx.push_node(
make_node(
ty.name,
@@ -983,3 +984,9 @@ def _parse_data(tensor: TensorProto) -> List[Any]:
def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
def save_onnx(opt_g, filename: str):
stub = OnnxStub.from_graph(opt_g)
with open(filename, "wb") as f:
f.write(stub.to_onnx("optimized").SerializeToString())
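Note: GBMM and G2BMM are exported through the same branch as ConvTransNHWC, i.e. as nodes named after their backend op types rather than standard ONNX operators, and the new save_onnx helper round-trips a graph through OnnxStub and writes the serialized model to the given file.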

View File

@@ -70,7 +70,6 @@ Graph SearchEngine::run(const Graph graph) {
}
}
auto tmp = make_ref<GraphObj>(runtimeExec, ops);
tmp->dataMalloc();
nextGraphs.emplace_back(tmp);
}
}
@@ -376,9 +375,6 @@ std::vector<Graph> SearchEngine::searchMutation(const MetaGraph &metaGraph) {
nextGraphs.emplace_back(make_ref<GraphObj>(runtimeExec, ops));
}
}
for (auto g : nextGraphs) {
g->dataMalloc();
}
dbg("===Num" + std::to_string(nextGraphs.size()));
std::sort(nextGraphs.begin(), nextGraphs.end(), graphTimeComparer);
if (nextGraphs.size() > GRAPH_SIZE) {
@@ -441,7 +437,6 @@ std::vector<Graph> SearchEngine::partitionGraph(const Graph graph) {
std::cout << op->toString() << std::endl;
}
auto tmp = make_ref<GraphObj>(runtimeExec, headOps);
tmp->dataMalloc();
partitions.emplace_back(tmp);
headOps.clear();
}
@@ -449,7 +444,6 @@ std::vector<Graph> SearchEngine::partitionGraph(const Graph graph) {
}
if (!headOps.empty()) {
auto tmp = make_ref<GraphObj>(runtimeExec, headOps);
tmp->dataMalloc();
partitions.emplace_back(tmp);
}
std::reverse(partitions.begin(), partitions.end());
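Note: the dropped dataMalloc() calls mean SearchEngine now builds candidate and partition graphs without allocating tensor memory up front; allocation is presumably deferred until a graph is actually executed or profiled.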

View File

@@ -13,7 +13,7 @@ static std::function<void(const Graph &, string)> exportONNXImpl;
void exportONNX(const Graph &graph, const string &path) {
IT_ASSERT(Py_IsInitialized(), "Python interpreter is not running.");
static auto exportONNXImpl =
py::module_::import("infinitensor.if_onnx").attr("export_onnx");
py::module_::import("pyinfinitensor.onnx").attr("save_onnx");
exportONNXImpl(graph, path);
}
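Note: the callback now resolves the save_onnx helper shown earlier in pyinfinitensor.onnx, so dumping a graph from C++ is a single call. A minimal sketch (optGraph is a placeholder Graph; a live Python interpreter is required, as the IT_ASSERT above enforces):

exportONNX(optGraph, "longformer_opt.onnx");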

View File

@@ -385,6 +385,7 @@ void export_test_model(py::module &m) {
#ifdef USE_CUDA
m.def("runInfoGAN", &runInfoGAN)
.def("getGANGraph", &getGANGraph)
.def("getLongformer", &getLongformer)
.def("getConvtransposedNHWC", &getConvtransposedNHWC)
.def("optimizeGraph", &optimizeGraph, "graph"_a, "runtime"_a,
"tuning"_a = false, "mode"_a = NMutator::Mode::Normal,

View File

@@ -25,9 +25,14 @@ class ActivationCudnn : public CudaKernelWithoutConfig {
cudnnTensorDescriptor_t inputDesc, outputDesc;
auto dim = op->getInputs(0)->getDims();
if (dim.size() != 4)
IT_ASSERT_TODO(dim.size() <= 4);
int n, c, h, w;
if (dim.size() == 4) {
n = dim[0], c = dim[1], h = dim[2], w = dim[3];
} else if (dim.size() == 3) {
n = 1, c = dim[0], h = dim[1], w = dim[2];
} else
IT_TODO_HALT();
int n = dim[0], c = dim[1], h = dim[2], w = dim[3];
// get inputs
checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc));
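Note: the kernel now also accepts 3-D inputs by mapping them onto cuDNN's 4-D NCHW descriptor with a batch of 1. A standalone illustration of the mapping, using hypothetical sizes:

std::vector<int> dim = {512, 100, 20};         // hypothetical 3-D input shape
int n = 1, c = dim[0], h = dim[1], w = dim[2]; // same rule as the kernel above
// 4-D inputs keep their layout: n = dim[0], c = dim[1], h = dim[2], w = dim[3].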

View File

@@ -7,7 +7,14 @@
#include "cuda/cuda_runtime.h"
#include "ffi/ffi_callback.h"
#include "nnet/nmutator.h"
#include "operators/G2BMM.h"
#include "operators/GBMM.h"
#include "operators/conv.h"
#include "operators/element_wise.h"
#include "operators/matmul.h"
#include "operators/reshape.h"
#include "operators/softmax.h"
#include "operators/transpose.h"
#include "operators/unary.h"
#include "test.h"
#include <pybind11/stl.h>
@@ -79,6 +86,113 @@ Graph getGANGraph(int batch, Runtime runtime, int nLayers, int modelId) {
return g;
}
Graph getLongformer(Runtime runtime, int bs) {
const int seqlen = 10000, w = 1000, featlen = 512, heads = 8, d = 4;
const int hidden = featlen, hiddenPerHead = hidden / heads;
assert(hidden % heads == 0);
Graph g = make_ref<GraphObj>(runtime);
auto i0 = g->addTensor({bs, seqlen, featlen}, DataType::Float32,
TensorType::Input);
auto w0 = g->addTensor({featlen, hidden}, DataType::Float32,
TensorType::Initialized);
auto w1 =
g->addTensor({512, 512}, DataType::Float32, TensorType::Initialized);
auto w2 =
g->addTensor({512, 512}, DataType::Float32, TensorType::Initialized);
// Feed forward
auto w3 =
g->addTensor({512, 512}, DataType::Float32, TensorType::Initialized);
auto bias3 =
g->addTensor({512}, DataType::Float32, TensorType::Initialized);
auto w4 =
g->addTensor({512, 512}, DataType::Float32, TensorType::Initialized);
auto bias4 =
g->addTensor({512}, DataType::Float32, TensorType::Initialized);
auto q0 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
TensorType::Other);
auto k0 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
TensorType::Other);
auto v0 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
TensorType::Other);
auto q1 = g->addTensor({bs, seqlen, heads, hiddenPerHead},
DataType::Float32, TensorType::Other);
auto k1 = g->addTensor({bs, seqlen, heads, hiddenPerHead},
DataType::Float32, TensorType::Other);
auto v1 = g->addTensor({bs, seqlen, heads, hiddenPerHead},
DataType::Float32, TensorType::Other);
auto q2 = g->addTensor({bs, heads, seqlen, hiddenPerHead},
DataType::Float32, TensorType::Other);
auto k2 = g->addTensor({bs, heads, seqlen, hiddenPerHead},
DataType::Float32, TensorType::Other);
auto v2 = g->addTensor({bs, heads, seqlen, hiddenPerHead},
DataType::Float32, TensorType::Other);
auto q3 = g->addTensor({bs * heads, seqlen, hiddenPerHead},
DataType::Float32, TensorType::Other);
auto k3 = g->addTensor({bs * heads, seqlen, hiddenPerHead},
DataType::Float32, TensorType::Other);
auto v3 = g->addTensor({bs * heads, seqlen, hiddenPerHead},
DataType::Float32, TensorType::Other);
auto prob = g->addTensor({bs * heads, seqlen, 2 * w + 1}, DataType::Float32,
TensorType::Other);
auto probSoftmax = g->addTensor({bs * heads, seqlen, 2 * w + 1},
DataType::Float32, TensorType::Other);
auto attn = g->addTensor({bs * heads, seqlen, hiddenPerHead},
DataType::Float32, TensorType::Other);
auto t00 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
TensorType::Other);
auto t01 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
TensorType::Other);
auto t02 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
TensorType::Other);
// auto t10 = g->addTensor({bs, seqlen, hidden});
auto t11 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
TensorType::Other);
auto t12 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
TensorType::Other);
auto output = g->addTensor({bs, seqlen, featlen}, DataType::Float32,
TensorType::Other);
g->addOpWithOutputs<MatmulObj>(i0, w0, q0, false, true);
g->addOpWithOutputs<MatmulObj>(i0, w1, k0, false, true);
g->addOpWithOutputs<MatmulObj>(i0, w2, v0, false, true);
g->addOpWithOutputs<ReshapeObj>(q0, q1);
g->addOpWithOutputs<ReshapeObj>(k0, k1);
g->addOpWithOutputs<ReshapeObj>(v0, v1);
// For example, when perm=(1, 0, 2), given an input tensor of shape (1, 2,
// 3), the output shape will be (2, 1, 3).
g->addOpWithOutputs<TransposeObj>(q1, q2, vector{0, 2, 1, 3});
g->addOpWithOutputs<TransposeObj>(k1, k2, vector{0, 2, 1, 3});
g->addOpWithOutputs<TransposeObj>(v1, v2, vector{0, 2, 1, 3});
g->addOpWithOutputs<ReshapeObj>(q2, q3);
g->addOpWithOutputs<ReshapeObj>(k2, k3);
g->addOpWithOutputs<ReshapeObj>(v2, v3);
// Attention
g->addOpWithOutputs<G2BMMObj>(q3, k3, prob, w, d);
g->addOpWithOutputs<SoftmaxObj>(prob, probSoftmax, 2);
g->addOpWithOutputs<GBMMObj>(probSoftmax, v3, attn, d);
auto attn2 = g->addOp<ReshapeObj>(attn, nullptr,
vector{bs, heads, seqlen, hiddenPerHead})
->getOutput();
auto t000 =
g->addOp<TransposeObj>(attn2, nullptr, vector{0, 2, 1, 3})->getOutput();
g->addOpWithOutputs<ReshapeObj>(t000, t00);
// Feed forward
g->addOpWithOutputs<MatmulObj>(t00, w3, t01, false, true, bias3);
g->addOpWithOutputs<ReluObj>(t01, t02);
g->addOpWithOutputs<MatmulObj>(t02, w4, t11, false, true, bias4);
g->addOpWithOutputs<ReluObj>(t11, t12);
g->addOpWithOutputs<AddObj>(t12, i0, output);
return g;
}
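Note: getLongformer builds one Longformer encoder block: Q/K/V projections, sliding-window self-attention (G2BMM for the banded Q·K scores with window w = 1000 and dilation d = 4, softmax, GBMM for prob·V), then two bias+ReLU matmuls and a residual add back to the input. A standalone sketch of the shape bookkeeping, assuming bs = 1 and no InfiniTensor dependency:

#include <cassert>
int main() {
    const long long bs = 1, seqlen = 10000, w = 1000, featlen = 512, heads = 8;
    const long long hiddenPerHead = featlen / heads; // 64
    // Q/K/V: [bs, seqlen, 512] -> [bs, seqlen, 8, 64] -> [bs, 8, seqlen, 64] -> [bs*8, seqlen, 64]
    assert(bs * heads * seqlen * hiddenPerHead == bs * seqlen * featlen);
    // G2BMM scores a band of 2*w+1 keys per query: prob is [bs*heads, seqlen, 2*w+1]
    assert(2 * w + 1 == 2001);
    // GBMM(probSoftmax, v3) returns [bs*heads, seqlen, 64], which is transposed and
    // reshaped back to [bs, seqlen, 512] before the feed-forward matmuls.
    return 0;
}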
Graph getConvtransposedNHWC(Runtime runtime, Shape shape, int layerId) {
IT_ASSERT(0 <= layerId && layerId < 5);
Graph g = make_ref<GraphObj>(runtime);

View File

@@ -6,10 +6,13 @@
#include "nnet/Visitor/MatchReshapeVisitor.h"
#include "nnet/Visitor/MergeMemboundMutator.h"
#include "nnet/derivator.h"
#include "operators/G2BMM.h"
#include "operators/GBMM.h"
#include "operators/conv.h"
#include "operators/matmul.h"
#include "operators/membound.h"
#include "operators/reshape.h"
#include "operators/transpose.h"
#include "operators/unary.h"
namespace infini {
@@ -46,6 +49,12 @@ vector<Graph> NMutator::run(const Graph &in_graph) {
return out_graphs;
}
bool NMutator::isMultiBranchMergable(const Graph &in_graph) {
// TODO
// dbg("Skip mergable Multi-Branch", in_graph);
return false;
}
void NMutator::runSingleOpToNaiveMembound(Graph in_graph,
std::vector<Graph> &out_graphs) {
OpVec computeOps = in_graph->getComputeOps();
@@ -89,13 +98,22 @@ void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
out_graphs.emplace_back(g);
return;
}
if (Graph g = transformG2bmm(computeOps[0])) {
out_graphs.emplace_back(g);
return;
}
if (Graph g = transformGbmm(computeOps[0])) {
out_graphs.emplace_back(g);
return;
}
// // if (infini::Graph g = transformConv1xk(computeOps[0])) {
// // Graph graph = new Graph(g->getOperators());
// // out_graphs.emplace_back(graph);
// // return;
// // }
const set<OpType> opSet{OpType::Conv, OpType::ConvTransNHWC};
const set<OpType> opSet{OpType::Conv, OpType::ConvTransNHWC, OpType::G2BMM,
OpType::GBMM};
if (opSet.count(computeOps[0]->getOpType()) == 0)
return;
@@ -273,26 +291,25 @@ pair<nnet::Expr, NMutator::NameNToTensorT> NMutator::extractOp(Operator opT) {
const auto K = nnet::makeTensor("K", KT->getDims());
return {nnet::ConvTransPattern::getExpr(A, K, n, c, h, w, f, r, s),
{{"A", AT}, {"K", KT}}};
// } else if (auto g2bmmOp = dynamic_cast<G2BMMOp *>(opT)) {
// const auto &AT = g2bmmOp->getInputs()[0];
// const auto &BT = g2bmmOp->getInputs()[1];
// const auto [b, m, k, width, dilation] = g2bmmOp->getArgs();
} else if (auto g2bmmOp = as<G2BMMObj>(opT)) {
const auto &AT = g2bmmOp->getInputs()[0];
const auto &BT = g2bmmOp->getInputs()[1];
const auto [b, m, k, width, dilation] = g2bmmOp->getBMKWD();
// const auto &[expr, inputsN] =
// nnet::Sg2bmmPattern::getExpr(b, m, k, width, dilation);
// inputsNameNToTensorT[inputsN.first->getName()] = AT;
// inputsNameNToTensorT[inputsN.second->getName()] = BT;
// return expr;
// } else if (auto gbmmlOp = dynamic_cast<GBMMLOp *>(opT)) {
// const auto &AT = gbmmlOp->getInputs()[0];
// const auto &BT = gbmmlOp->getInputs()[1];
// const auto [b, m, w, k, dilation] = gbmmlOp->getArgs();
// const auto &[expr, inputsN] =
// nnet::LongformerGBMMPattern::getExpr(b, m, w, k, dilation);
// inputsNameNToTensorT[inputsN.first->getName()] = AT;
// inputsNameNToTensorT[inputsN.second->getName()] = BT;
// dbg(b, m, w, k, dilation, expr);
// return expr;
const auto &[expr, inputsN] =
nnet::Sg2bmmPattern::getExpr(b, m, k, width, dilation);
return {
expr,
{{inputsN.first->getName(), AT}, {inputsN.second->getName(), BT}}};
} else if (auto gbmmlOp = as<GBMMObj>(opT)) {
const auto &AT = gbmmlOp->getInputs()[0];
const auto &BT = gbmmlOp->getInputs()[1];
const auto [b, m, w, k, dilation] = gbmmlOp->getBMWND();
const auto &[expr, inputsN] =
nnet::LongformerGBMMPattern::getExpr(b, m, w, k, dilation);
return {
expr,
{{inputsN.first->getName(), AT}, {inputsN.second->getName(), BT}}};
} else if (auto matmulOp = as<MatmulObj>(opT)) {
const auto &AT = matmulOp->getInputs()[0];
const auto &BT = matmulOp->getInputs()[1];
@@ -574,6 +591,44 @@ Graph NMutator::transformConvtransposed1x1(Operator _op) {
// return graph;
// }
Graph NMutator::transformG2bmm(Operator _op) {
auto op = as<G2BMMObj>(_op);
if (!op)
return nullptr;
const auto [b, m, k, width, dilation] = op->getBMKWD();
if (dilation == 1 || m % dilation != 0)
return nullptr;
auto g = make_ref<GraphObj>(runtime);
auto A = g->cloneTensor(op->getInputs(0));
auto B = g->cloneTensor(op->getInputs(1));
auto O = g->cloneTensor(op->getOutput());
auto A3 = splitTransposeMerge(g, A, 1, dilation),
B3 = splitTransposeMerge(g, B, 1, dilation);
auto O3 = g->addOp<G2BMMObj>(A3, B3, nullptr, width, 1)->getOutput();
splitTransposeMerge(g, O3, 1, m / dilation, O);
g->checkValid();
return g;
}
Graph NMutator::transformGbmm(Operator _op) {
auto op = as<GBMMObj>(_op);
if (!op)
return nullptr;
const auto [b, m, width, k, dilation] = op->getBMWND();
if (dilation == 1 || m % dilation != 0)
return nullptr;
auto g = make_ref<GraphObj>(runtime);
auto A = g->cloneTensor(op->getInputs(0)); // [b,m,2w+1]
auto B = g->cloneTensor(op->getInputs(1)); // [b,m,n]
auto O = g->cloneTensor(op->getOutput()); // [b,m,n]
auto A3 = splitTransposeMerge(g, A, 1, dilation),
B3 = splitTransposeMerge(g, B, 1, dilation);
auto O3 = g->addOp<GBMMObj>(A3, B3, nullptr, 1)->getOutput();
splitTransposeMerge(g, O3, 1, m / dilation, O);
g->checkValid();
return g;
}
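Note: both rewrites only fire for dilated operators: transformG2bmm and transformGbmm bail out unless dilation > 1 and the row count m is divisible by the dilation, then replace the op with a dilation-1 G2BMM/GBMM sandwiched between splitTransposeMerge calls on the operands and the output. A minimal predicate mirroring the early-return guard above:

bool dilationRewriteApplies(int m, int dilation) {
    return dilation != 1 && m % dilation == 0;
}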
Graph NMutator::transformConv1x1(Operator _op) {
auto op = as<ConvObj>(_op);
if (!op)
@@ -722,4 +777,28 @@ pair<nnet::Expr, vector<nnet::Tensor>> NMutator::generateRevert(Tensor in) {
return {range, {tensor}};
}
Tensor NMutator::splitTransposeMerge(Graph g, Tensor A, int dim, int chunkSize,
Tensor output) {
IT_ASSERT(A->getDims().size() == 3);
Shape shapeOrignial = A->getDims();
Shape shapeNew;
// Construct new shape
for (int i = 0; i < dim; ++i)
shapeNew.emplace_back(shapeOrignial[i]);
shapeNew.emplace_back(shapeOrignial[dim] / chunkSize);
shapeNew.emplace_back(chunkSize);
for (size_t i = dim + 1; i < shapeOrignial.size(); ++i)
shapeNew.emplace_back(shapeOrignial[i]);
auto A1 = g->addOp<ReshapeObj>(A, nullptr, shapeNew)->getOutput();
auto A2 =
g->addOp<TransposeObj>(A1, nullptr, vector{0, 1, 3, 2})->getOutput();
Tensor A3;
if (output)
A3 = g->addOpWithOutputs<ReshapeObj>(A2, output, shapeOrignial)
->getOutput();
else
A3 = g->addOp<ReshapeObj>(A2, nullptr, shapeOrignial)->getOutput();
return A3;
};
} // namespace infini
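Note: splitTransposeMerge keeps the element count but reshuffles the layout: dimension dim is split into [size/chunkSize, chunkSize], the last two axes of the 4-D intermediate are swapped (perm {0, 1, 3, 2}), and the result is flattened back to the original 3-D shape. A standalone sketch of the shape arithmetic, using hypothetical sizes:

#include <cassert>
#include <vector>
int main() {
    std::vector<long long> shape = {8, 10000, 64}; // [b, m, k], hypothetical
    const long long chunk = 4;                     // e.g. the G2BMM dilation
    // Reshape: [b, m, k] -> [b, m/chunk, chunk, k]
    std::vector<long long> split = {shape[0], shape[1] / chunk, chunk, shape[2]};
    // Transpose perm {0, 1, 3, 2}: swap the last two axes -> [b, m/chunk, k, chunk]
    std::vector<long long> swapped = {split[0], split[1], split[3], split[2]};
    // Reshape back to [b, m, k]: same number of elements, different layout
    assert(swapped[0] * swapped[1] * swapped[2] * swapped[3] ==
           shape[0] * shape[1] * shape[2]);
    return 0;
}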

View File

@@ -2,7 +2,8 @@
namespace infini {
ReshapeObj::ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims)
: OperatorObj(OpType::Reshape, {input}, {output}), dims(std::move(dims)) {
: OperatorObj(OpType::Reshape, {input}, {output}),
dims(dims.size() == 0 ? output->getDims() : dims) {
IT_ASSERT(checkValid(graph));
}