forked from jiuyuan/InfiniTensor
Add: Longformer models
parent 4f02eeb08c
commit 84f9d6731a
@@ -26,6 +26,7 @@ class NMutator : public Mutator {
    vector<Graph> run(const Graph &in_graph) override;
    Graph fuseVertically(const Graph &in_graph) override;
    bool isMultiBranchMergable(const Graph &in_graph) override;

    void setToNaiveMembound();
    void setMaxDepth(int _maxDepth) { maxDepth = _maxDepth; }
@@ -61,7 +62,12 @@ class NMutator : public Mutator {
    // Graph transformConvtransposed(Operator op);
    // Graph transformDialtedConv(Operator op);
    Graph transformConv1x1(Operator op);
    Graph transformG2bmm(Operator op);
    Graph transformGbmm(Operator op);
    // Graph transformConv1xk(Operator op);

    Tensor splitTransposeMerge(Graph g, Tensor A, int dim, int chunkSize,
                               Tensor output = nullptr);
};

} // namespace infini
@@ -6,6 +6,7 @@
namespace infini {

Graph getGANGraph(int batch, Runtime runtime, int nLayers, int modelId);
Graph getLongformer(Runtime runtime, int bs);
vector<Tensor> runInfoGAN(int nLayers);
Graph getConvtransposedNHWC(Runtime runtime, Shape shape, int layerId);
Graph optimizeGraph(Graph g, Runtime _runtime, bool tuning, NMutator::Mode mode,
@@ -19,7 +19,7 @@ class ReshapeObj : public OperatorObj {
     * @param output The output tensor.
     * @param dims The shape of the output tensor.
     */
    ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims);
    ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims = {});
    OP_CLONE(ReshapeObj);

    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
@@ -884,7 +884,8 @@ class OnnxStub:
                    ctx.push_data_input(name, "max", TensorProto.FLOAT, [], [])
                )
                ctx.push_node(make_node(ty.name, inputs, outputs, name))
            elif ty == backend.OpType.ConvTransNHWC:
            elif ty in [backend.OpType.ConvTransNHWC, backend.OpType.GBMM,
                        backend.OpType.G2BMM]:
                ctx.push_node(
                    make_node(
                        ty.name,
@@ -983,3 +984,9 @@ def _parse_data(tensor: TensorProto) -> List[Any]:
def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
    return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]


def save_onnx(opt_g, filename: str):
    stub = OnnxStub.from_graph(opt_g)
    with open(filename, "wb") as f:
        f.write(stub.to_onnx("optimized").SerializeToString())
@@ -70,7 +70,6 @@ Graph SearchEngine::run(const Graph graph) {
                }
            }
            auto tmp = make_ref<GraphObj>(runtimeExec, ops);
            tmp->dataMalloc();
            nextGraphs.emplace_back(tmp);
        }
    }
@@ -376,9 +375,6 @@ std::vector<Graph> SearchEngine::searchMutation(const MetaGraph &metaGraph) {
            nextGraphs.emplace_back(make_ref<GraphObj>(runtimeExec, ops));
        }
    }
    for (auto g : nextGraphs) {
        g->dataMalloc();
    }
    dbg("===Num" + std::to_string(nextGraphs.size()));
    std::sort(nextGraphs.begin(), nextGraphs.end(), graphTimeComparer);
    if (nextGraphs.size() > GRAPH_SIZE) {
@@ -441,7 +437,6 @@ std::vector<Graph> SearchEngine::partitionGraph(const Graph graph) {
            std::cout << op->toString() << std::endl;
        }
        auto tmp = make_ref<GraphObj>(runtimeExec, headOps);
        tmp->dataMalloc();
        partitions.emplace_back(tmp);
        headOps.clear();
    }
@@ -449,7 +444,6 @@ std::vector<Graph> SearchEngine::partitionGraph(const Graph graph) {
    }
    if (!headOps.empty()) {
        auto tmp = make_ref<GraphObj>(runtimeExec, headOps);
        tmp->dataMalloc();
        partitions.emplace_back(tmp);
    }
    std::reverse(partitions.begin(), partitions.end());
@@ -13,7 +13,7 @@ static std::function<void(const Graph &, string)> exportONNXImpl;
void exportONNX(const Graph &graph, const string &path) {
    IT_ASSERT(Py_IsInitialized(), "Python interpreter is not running.");
    static auto exportONNXImpl =
        py::module_::import("infinitensor.if_onnx").attr("export_onnx");
        py::module_::import("pyinfinitensor.onnx").attr("save_onnx");
    exportONNXImpl(graph, path);
}
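With this change the C++-side exportONNX hook resolves to the save_onnx helper added to pyinfinitensor.onnx above, so an optimized graph can be written to an ONNX file from C++. A minimal sketch of a call site follows; it assumes an InfiniTensor build with embedded Python, and the include paths and the declaration of exportONNX in ffi/ffi_callback.h are assumptions, not shown in this diff:

#include "core/graph.h"        // infini::Graph (assumed path)
#include "ffi/ffi_callback.h"  // assumed to declare infini::exportONNX

// Dump a graph (e.g. the one built by getLongformer) through the Python-backed
// hook, which ends up in OnnxStub.from_graph(...).to_onnx(...) as added above.
void dumpGraph(const infini::Graph &g) {
    infini::exportONNX(g, "longformer_optimized.onnx");
}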
@@ -385,6 +385,7 @@ void export_test_model(py::module &m) {
#ifdef USE_CUDA
    m.def("runInfoGAN", &runInfoGAN)
        .def("getGANGraph", &getGANGraph)
        .def("getLongformer", &getLongformer)
        .def("getConvtransposedNHWC", &getConvtransposedNHWC)
        .def("optimizeGraph", &optimizeGraph, "graph"_a, "runtime"_a,
             "tuning"_a = false, "mode"_a = NMutator::Mode::Normal,
@@ -25,9 +25,14 @@ class ActivationCudnn : public CudaKernelWithoutConfig {

        cudnnTensorDescriptor_t inputDesc, outputDesc;
        auto dim = op->getInputs(0)->getDims();
        if (dim.size() != 4)
        IT_ASSERT_TODO(dim.size() <= 4);
        int n, c, h, w;
        if (dim.size() == 4) {
            n = dim[0], c = dim[1], h = dim[2], w = dim[3];
        } else if (dim.size() == 3) {
            n = 1, c = dim[0], h = dim[1], w = dim[2];
        } else
            IT_TODO_HALT();
        int n = dim[0], c = dim[1], h = dim[2], w = dim[3];

        // get inputs
        checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc));
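The cuDNN-backed activation kernel describes tensors with 4-D NCHW descriptors, so the new branch folds a rank-3 input (such as the {bs, seqlen, hidden} tensors in the Longformer graph) into a 4-D shape with a batch of 1. A standalone sketch of that mapping, using an illustrative helper name that is not from the codebase:

#include <array>
#include <stdexcept>
#include <vector>

// Rank-4 shapes map to (n, c, h, w) directly; rank-3 shapes become
// (1, d0, d1, d2), mirroring the branch added above. Other ranks still bail.
std::array<int, 4> toNCHW(const std::vector<int> &dim) {
    if (dim.size() == 4)
        return {dim[0], dim[1], dim[2], dim[3]};
    if (dim.size() == 3)
        return {1, dim[0], dim[1], dim[2]};
    throw std::runtime_error("only rank-3 and rank-4 tensors are handled");
}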
@@ -7,7 +7,14 @@
#include "cuda/cuda_runtime.h"
#include "ffi/ffi_callback.h"
#include "nnet/nmutator.h"
#include "operators/G2BMM.h"
#include "operators/GBMM.h"
#include "operators/conv.h"
#include "operators/element_wise.h"
#include "operators/matmul.h"
#include "operators/reshape.h"
#include "operators/softmax.h"
#include "operators/transpose.h"
#include "operators/unary.h"
#include "test.h"
#include <pybind11/stl.h>
@@ -79,6 +86,113 @@ Graph getGANGraph(int batch, Runtime runtime, int nLayers, int modelId) {
    return g;
}

Graph getLongformer(Runtime runtime, int bs) {
    const int seqlen = 10000, w = 1000, featlen = 512, heads = 8, d = 4;
    const int hidden = featlen, hiddenPerHead = hidden / heads;
    assert(hidden % heads == 0);
    Graph g = make_ref<GraphObj>(runtime);

    auto i0 = g->addTensor({bs, seqlen, featlen}, DataType::Float32,
                           TensorType::Input);
    auto w0 = g->addTensor({featlen, hidden}, DataType::Float32,
                           TensorType::Initialized);
    auto w1 =
        g->addTensor({512, 512}, DataType::Float32, TensorType::Initialized);
    auto w2 =
        g->addTensor({512, 512}, DataType::Float32, TensorType::Initialized);
    // Feed forward
    auto w3 =
        g->addTensor({512, 512}, DataType::Float32, TensorType::Initialized);
    auto bias3 =
        g->addTensor({512}, DataType::Float32, TensorType::Initialized);
    auto w4 =
        g->addTensor({512, 512}, DataType::Float32, TensorType::Initialized);
    auto bias4 =
        g->addTensor({512}, DataType::Float32, TensorType::Initialized);

    auto q0 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
                           TensorType::Other);
    auto k0 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
                           TensorType::Other);
    auto v0 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
                           TensorType::Other);

    auto q1 = g->addTensor({bs, seqlen, heads, hiddenPerHead},
                           DataType::Float32, TensorType::Other);
    auto k1 = g->addTensor({bs, seqlen, heads, hiddenPerHead},
                           DataType::Float32, TensorType::Other);
    auto v1 = g->addTensor({bs, seqlen, heads, hiddenPerHead},
                           DataType::Float32, TensorType::Other);

    auto q2 = g->addTensor({bs, heads, seqlen, hiddenPerHead},
                           DataType::Float32, TensorType::Other);
    auto k2 = g->addTensor({bs, heads, seqlen, hiddenPerHead},
                           DataType::Float32, TensorType::Other);
    auto v2 = g->addTensor({bs, heads, seqlen, hiddenPerHead},
                           DataType::Float32, TensorType::Other);

    auto q3 = g->addTensor({bs * heads, seqlen, hiddenPerHead},
                           DataType::Float32, TensorType::Other);
    auto k3 = g->addTensor({bs * heads, seqlen, hiddenPerHead},
                           DataType::Float32, TensorType::Other);
    auto v3 = g->addTensor({bs * heads, seqlen, hiddenPerHead},
                           DataType::Float32, TensorType::Other);

    auto prob = g->addTensor({bs * heads, seqlen, 2 * w + 1}, DataType::Float32,
                             TensorType::Other);
    auto probSoftmax = g->addTensor({bs * heads, seqlen, 2 * w + 1},
                                    DataType::Float32, TensorType::Other);
    auto attn = g->addTensor({bs * heads, seqlen, hiddenPerHead},
                             DataType::Float32, TensorType::Other);

    auto t00 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
                            TensorType::Other);
    auto t01 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
                            TensorType::Other);
    auto t02 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
                            TensorType::Other);
    // auto t10 = g->addTensor({bs, seqlen, hidden});
    auto t11 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
                            TensorType::Other);
    auto t12 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
                            TensorType::Other);
    auto output = g->addTensor({bs, seqlen, featlen}, DataType::Float32,
                               TensorType::Other);

    g->addOpWithOutputs<MatmulObj>(i0, w0, q0, false, true);
    g->addOpWithOutputs<MatmulObj>(i0, w1, k0, false, true);
    g->addOpWithOutputs<MatmulObj>(i0, w2, v0, false, true);
    g->addOpWithOutputs<ReshapeObj>(q0, q1);
    g->addOpWithOutputs<ReshapeObj>(k0, k1);
    g->addOpWithOutputs<ReshapeObj>(v0, v1);
    // For example, when perm=(1, 0, 2), given an input tensor of shape (1, 2,
    // 3), the output shape will be (2, 1, 3).
    g->addOpWithOutputs<TransposeObj>(q1, q2, vector{0, 2, 1, 3});
    g->addOpWithOutputs<TransposeObj>(k1, k2, vector{0, 2, 1, 3});
    g->addOpWithOutputs<TransposeObj>(v1, v2, vector{0, 2, 1, 3});
    g->addOpWithOutputs<ReshapeObj>(q2, q3);
    g->addOpWithOutputs<ReshapeObj>(k2, k3);
    g->addOpWithOutputs<ReshapeObj>(v2, v3);
    // Attention
    g->addOpWithOutputs<G2BMMObj>(q3, k3, prob, w, d);
    g->addOpWithOutputs<SoftmaxObj>(prob, probSoftmax, 2);
    g->addOpWithOutputs<GBMMObj>(probSoftmax, v3, attn, d);
    auto attn2 = g->addOp<ReshapeObj>(attn, nullptr,
                                      vector{bs, heads, seqlen, hiddenPerHead})
                     ->getOutput();
    auto t000 =
        g->addOp<TransposeObj>(attn2, nullptr, vector{0, 2, 1, 3})->getOutput();
    g->addOpWithOutputs<ReshapeObj>(t000, t00);

    // Feed forward
    g->addOpWithOutputs<MatmulObj>(t00, w3, t01, false, true, bias3);
    g->addOpWithOutputs<ReluObj>(t01, t02);
    g->addOpWithOutputs<MatmulObj>(t02, w4, t11, false, true, bias4);
    g->addOpWithOutputs<ReluObj>(t11, t12);
    g->addOpWithOutputs<AddObj>(t12, i0, output);
    return g;
}

Graph getConvtransposedNHWC(Runtime runtime, Shape shape, int layerId) {
    IT_ASSERT(0 <= layerId && layerId < 5);
    Graph g = make_ref<GraphObj>(runtime);
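In this graph the attention scores live in a banded tensor: prob has shape {bs * heads, seqlen, 2 * w + 1}, so each query is scored only against the 2w + 1 keys of its (dilated) window instead of all seqlen keys; with seqlen = 10000 and w = 1000 that is a 2001-wide band per query rather than a 10000-wide row. The diff does not spell out the G2BMM/GBMM kernel semantics; the reference loop below is one reading of Longformer-style sliding-window attention, with out-of-range positions contributing zero, and should be treated as an assumption rather than as the operators' definition:

#include <vector>

using Tensor3 = std::vector<std::vector<std::vector<float>>>;

// Assumed G2BMM semantics: A, B are [b][m][k]; the result is [b][m][2w+1] with
// prob[bi][i][j] = sum_k A[bi][i][k] * B[bi][i + (j - w) * d][k].
// (GBMM would then contract the band against V the same way to get [b][m][k].)
Tensor3 g2bmmRef(const Tensor3 &A, const Tensor3 &B, int w, int d) {
    const int b = A.size(), m = A[0].size(), k = A[0][0].size();
    Tensor3 prob(b, std::vector<std::vector<float>>(
                        m, std::vector<float>(2 * w + 1, 0.0f)));
    for (int bi = 0; bi < b; ++bi)
        for (int i = 0; i < m; ++i)
            for (int j = 0; j < 2 * w + 1; ++j) {
                const int src = i + (j - w) * d; // key row in the dilated window
                if (src < 0 || src >= m)
                    continue; // out-of-range window positions stay zero
                for (int kk = 0; kk < k; ++kk)
                    prob[bi][i][j] += A[bi][i][kk] * B[bi][src][kk];
            }
    return prob;
}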
@@ -6,10 +6,13 @@
#include "nnet/Visitor/MatchReshapeVisitor.h"
#include "nnet/Visitor/MergeMemboundMutator.h"
#include "nnet/derivator.h"
#include "operators/G2BMM.h"
#include "operators/GBMM.h"
#include "operators/conv.h"
#include "operators/matmul.h"
#include "operators/membound.h"
#include "operators/reshape.h"
#include "operators/transpose.h"
#include "operators/unary.h"

namespace infini {
@@ -46,6 +49,12 @@ vector<Graph> NMutator::run(const Graph &in_graph) {
    return out_graphs;
}

bool NMutator::isMultiBranchMergable(const Graph &in_graph) {
    // TODO
    // dbg("Skip mergable Multi-Branch", in_graph);
    return false;
}

void NMutator::runSingleOpToNaiveMembound(Graph in_graph,
                                          std::vector<Graph> &out_graphs) {
    OpVec computeOps = in_graph->getComputeOps();
@@ -89,13 +98,22 @@ void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
        out_graphs.emplace_back(g);
        return;
    }
    if (Graph g = transformG2bmm(computeOps[0])) {
        out_graphs.emplace_back(g);
        return;
    }
    if (Graph g = transformGbmm(computeOps[0])) {
        out_graphs.emplace_back(g);
        return;
    }
    // // if (infini::Graph g = transformConv1xk(computeOps[0])) {
    // //     Graph graph = new Graph(g->getOperators());
    // //     out_graphs.emplace_back(graph);
    // //     return;
    // // }

    const set<OpType> opSet{OpType::Conv, OpType::ConvTransNHWC};
    const set<OpType> opSet{OpType::Conv, OpType::ConvTransNHWC, OpType::G2BMM,
                            OpType::GBMM};
    if (opSet.count(computeOps[0]->getOpType()) == 0)
        return;
@@ -273,26 +291,25 @@ pair<nnet::Expr, NMutator::NameNToTensorT> NMutator::extractOp(Operator opT) {
        const auto K = nnet::makeTensor("K", KT->getDims());
        return {nnet::ConvTransPattern::getExpr(A, K, n, c, h, w, f, r, s),
                {{"A", AT}, {"K", KT}}};
        // } else if (auto g2bmmOp = dynamic_cast<G2BMMOp *>(opT)) {
        //     const auto &AT = g2bmmOp->getInputs()[0];
        //     const auto &BT = g2bmmOp->getInputs()[1];
        //     const auto [b, m, k, width, dilation] = g2bmmOp->getArgs();
    } else if (auto g2bmmOp = as<G2BMMObj>(opT)) {
        const auto &AT = g2bmmOp->getInputs()[0];
        const auto &BT = g2bmmOp->getInputs()[1];
        const auto [b, m, k, width, dilation] = g2bmmOp->getBMKWD();

        //     const auto &[expr, inputsN] =
        //         nnet::Sg2bmmPattern::getExpr(b, m, k, width, dilation);
        //     inputsNameNToTensorT[inputsN.first->getName()] = AT;
        //     inputsNameNToTensorT[inputsN.second->getName()] = BT;
        //     return expr;
        // } else if (auto gbmmlOp = dynamic_cast<GBMMLOp *>(opT)) {
        //     const auto &AT = gbmmlOp->getInputs()[0];
        //     const auto &BT = gbmmlOp->getInputs()[1];
        //     const auto [b, m, w, k, dilation] = gbmmlOp->getArgs();
        //     const auto &[expr, inputsN] =
        //         nnet::LongformerGBMMPattern::getExpr(b, m, w, k, dilation);
        //     inputsNameNToTensorT[inputsN.first->getName()] = AT;
        //     inputsNameNToTensorT[inputsN.second->getName()] = BT;
        //     dbg(b, m, w, k, dilation, expr);
        //     return expr;
        const auto &[expr, inputsN] =
            nnet::Sg2bmmPattern::getExpr(b, m, k, width, dilation);
        return {
            expr,
            {{inputsN.first->getName(), AT}, {inputsN.second->getName(), BT}}};
    } else if (auto gbmmlOp = as<GBMMObj>(opT)) {
        const auto &AT = gbmmlOp->getInputs()[0];
        const auto &BT = gbmmlOp->getInputs()[1];
        const auto [b, m, w, k, dilation] = gbmmlOp->getBMWND();
        const auto &[expr, inputsN] =
            nnet::LongformerGBMMPattern::getExpr(b, m, w, k, dilation);
        return {
            expr,
            {{inputsN.first->getName(), AT}, {inputsN.second->getName(), BT}}};
    } else if (auto matmulOp = as<MatmulObj>(opT)) {
        const auto &AT = matmulOp->getInputs()[0];
        const auto &BT = matmulOp->getInputs()[1];
@@ -574,6 +591,44 @@ Graph NMutator::transformConvtransposed1x1(Operator _op) {
//     return graph;
// }

Graph NMutator::transformG2bmm(Operator _op) {
    auto op = as<G2BMMObj>(_op);
    if (!op)
        return nullptr;
    const auto [b, m, k, width, dilation] = op->getBMKWD();
    if (dilation == 1 || m % dilation != 0)
        return nullptr;
    auto g = make_ref<GraphObj>(runtime);
    auto A = g->cloneTensor(op->getInputs(0));
    auto B = g->cloneTensor(op->getInputs(1));
    auto O = g->cloneTensor(op->getOutput());
    auto A3 = splitTransposeMerge(g, A, 1, dilation),
         B3 = splitTransposeMerge(g, B, 1, dilation);
    auto O3 = g->addOp<G2BMMObj>(A3, B3, nullptr, width, 1)->getOutput();
    splitTransposeMerge(g, O3, 1, m / dilation, O);
    g->checkValid();
    return g;
}

Graph NMutator::transformGbmm(Operator _op) {
    auto op = as<GBMMObj>(_op);
    if (!op)
        return nullptr;
    const auto [b, m, width, k, dilation] = op->getBMWND();
    if (dilation == 1 || m % dilation != 0)
        return nullptr;
    auto g = make_ref<GraphObj>(runtime);
    auto A = g->cloneTensor(op->getInputs(0)); // [b,m,2w+1]
    auto B = g->cloneTensor(op->getInputs(1)); // [b,m,n]
    auto O = g->cloneTensor(op->getOutput());  // [b,m,n]
    auto A3 = splitTransposeMerge(g, A, 1, dilation),
         B3 = splitTransposeMerge(g, B, 1, dilation);
    auto O3 = g->addOp<GBMMObj>(A3, B3, nullptr, 1)->getOutput();
    splitTransposeMerge(g, O3, 1, m / dilation, O);
    g->checkValid();
    return g;
}

Graph NMutator::transformConv1x1(Operator _op) {
    auto op = as<ConvObj>(_op);
    if (!op)
@@ -722,4 +777,28 @@ pair<nnet::Expr, vector<nnet::Tensor>> NMutator::generateRevert(Tensor in) {
    return {range, {tensor}};
}

Tensor NMutator::splitTransposeMerge(Graph g, Tensor A, int dim, int chunkSize,
                                     Tensor output) {
    IT_ASSERT(A->getDims().size() == 3);
    Shape shapeOrignial = A->getDims();
    Shape shapeNew;
    // Construct new shape
    for (int i = 0; i < dim; ++i)
        shapeNew.emplace_back(shapeOrignial[i]);
    shapeNew.emplace_back(shapeOrignial[dim] / chunkSize);
    shapeNew.emplace_back(chunkSize);
    for (size_t i = dim + 1; i < shapeOrignial.size(); ++i)
        shapeNew.emplace_back(shapeOrignial[i]);
    auto A1 = g->addOp<ReshapeObj>(A, nullptr, shapeNew)->getOutput();
    auto A2 =
        g->addOp<TransposeObj>(A1, nullptr, vector{0, 1, 3, 2})->getOutput();
    Tensor A3;
    if (output)
        A3 = g->addOpWithOutputs<ReshapeObj>(A2, output, shapeOrignial)
                 ->getOutput();
    else
        A3 = g->addOp<ReshapeObj>(A2, nullptr, shapeOrignial)->getOutput();
    return A3;
};

} // namespace infini
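splitTransposeMerge only rearranges data: reading the code above, a rank-3 tensor has the extent at index dim split into (extent/chunkSize, chunkSize), the last two axes of the resulting rank-4 tensor are swapped (perm {0, 1, 3, 2}), and the result is reshaped back to the original extents. The sketch below only tracks those intermediate shapes for the call transformG2bmm makes on the Longformer graph (with bs = 1 the band ops see a batch of bs * heads = 8 and hiddenPerHead = 64); it is a shape walkthrough, not a reimplementation of the data movement:

#include <cstdio>
#include <vector>

// Shape flow of splitTransposeMerge(g, A, /*dim=*/1, chunkSize) from above:
// [d0, d1, d2] -> reshape   -> [d0, d1/chunk, chunk, d2]
//              -> transpose -> [d0, d1/chunk, d2, chunk]   (perm {0, 1, 3, 2})
//              -> reshape   -> [d0, d1, d2]                (original extents)
std::vector<std::vector<int>> shapeFlow(const std::vector<int> &s, int chunk) {
    std::vector<int> split = {s[0], s[1] / chunk, chunk, s[2]};
    std::vector<int> swapped = {split[0], split[1], split[3], split[2]};
    return {s, split, swapped, s};
}

int main() {
    // transformG2bmm input A with dilation 4:
    // [8, 10000, 64] -> [8, 2500, 4, 64] -> [8, 2500, 64, 4] -> [8, 10000, 64]
    for (const auto &shape : shapeFlow({8, 10000, 64}, 4)) {
        for (int v : shape)
            std::printf("%d ", v);
        std::printf("\n");
    }
    return 0;
}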
@@ -2,7 +2,8 @@

namespace infini {
ReshapeObj::ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims)
    : OperatorObj(OpType::Reshape, {input}, {output}), dims(std::move(dims)) {
    : OperatorObj(OpType::Reshape, {input}, {output}),
      dims(dims.size() == 0 ? output->getDims() : dims) {
    IT_ASSERT(checkValid(graph));
}
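With the new default, ReshapeObj can be constructed without an explicit dims argument and then adopts the shape of the output tensor it was given, which is what the addOpWithOutputs<ReshapeObj>(q0, q1) calls in getLongformer rely on. A minimal sketch of just that fallback rule, expressed as an illustrative free function rather than the class itself:

#include <vector>

using Shape = std::vector<int>;

// Mirrors the member initializer above: an empty `dims` argument falls back
// to the shape of the already-constructed output tensor.
Shape resolveReshapeDims(const Shape &dims, const Shape &outputDims) {
    return dims.empty() ? outputDims : dims;
}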