forked from jiuyuan/InfiniTensor
Add: convert expression to operator
This commit is contained in:
parent
29071ddcac
commit
1ee4a60af0
|
@ -387,13 +387,13 @@ class FuncNode : public ExprNode {
|
|||
|
||||
// Wrappers for type deduction
|
||||
|
||||
// make subscript
|
||||
// make a subscript operator
|
||||
Subscript mSub(const Expr &tensor, const VecExpr &subscripts);
|
||||
// make range operator
|
||||
// make a range operator
|
||||
RangeOp mL(const vector<VarRangePair> &_loopIters,
|
||||
const vector<VarRangePair> &_sumIters, Expr _summand,
|
||||
const vector<int> &paddings = {});
|
||||
// make tensor
|
||||
// make a tensor
|
||||
Tensor mT(const string &name, const vector<int> &shape,
|
||||
const vector<int> &paddings = {}, const Routine &source = nullptr);
|
||||
|
||||
|
|
|
@ -90,6 +90,10 @@ class NMutator : public Mutator {
|
|||
/// @param op The operator chain. It can have wrong input/output shapes.
|
||||
/// @return
|
||||
Graph constructGraphByOperatorChain(vector<Operator> ops, Graph inputGraph);
|
||||
|
||||
// Convert an nnet::Expr to an infini::Graph containing corresponding
|
||||
// tensors and operators
|
||||
Graph constructGraphFromExpression(Runtime runtime, nnet::Expr expr);
|
||||
};
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -45,6 +45,8 @@ Expr MergeMemboundMutator::merge(bool allowEmptyMembound, bool allowFailure) {
|
|||
curExpr = sub->getObjectPtr();
|
||||
else
|
||||
break;
|
||||
} else if (auto funcOp = as<BinaryOpNode>(summand)) {
|
||||
break;
|
||||
} else {
|
||||
if (allowFailure)
|
||||
return nullptr;
|
||||
|
@ -143,4 +145,4 @@ Expr MergeMemboundMutator::rule4StageMerging(Expr &rCur,
|
|||
return merged;
|
||||
}
|
||||
|
||||
} // namespace nnet
|
||||
} // namespace nnet
|
||||
|
|
|
@ -1232,4 +1232,17 @@ vector<Graph> NMutator::transformConv3x3ONNX(Operator _op) {
|
|||
return ret;
|
||||
}
|
||||
|
||||
Graph NMutator::constructGraphFromExpression(Runtime runtime, nnet::Expr expr) {
|
||||
auto g = make_ref<GraphObj>(runtime);
|
||||
auto nTensors = nnet::GetTensorsVisitor().get(expr);
|
||||
TensorVec inputs, outputs;
|
||||
for (auto &[nName, nTensor] : nTensors) {
|
||||
inputs.emplace_back(g->addTensor(nTensor->getShape()));
|
||||
}
|
||||
outputs.emplace_back(
|
||||
g->addTensor(nnet::as<nnet::RangeOpNode>(expr)->getOutputShape()));
|
||||
g->addOpWithOutputs<MemBoundObj>(inputs, outputs, nTensors, expr, 0);
|
||||
return g;
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -65,3 +65,31 @@ TEST(FuseMembound, mergeNestedStagesInRangeOp) {
|
|||
mSub(K, {f, i / 1024, (i / 256) % 4, i % 256}));
|
||||
EXPECT_EQ(HashVisitor().getHash(merged), HashVisitor().getHash(ans));
|
||||
}
|
||||
|
||||
TEST(FuseMembound, mergeReductionBiasRelu) {
|
||||
DEFINE_VAR(f, i);
|
||||
const int F = 4, H = 16;
|
||||
auto A = mT("A", vector<int>({F, H}));
|
||||
auto B = mT("B", vector<int>({F, H}));
|
||||
auto AB = mT("AB", vector<int>({F, H}));
|
||||
auto C = mT("Bias", vector<int>({F}));
|
||||
auto l0 = // Reduction
|
||||
mL({{f, {0, F}}, {i, {0, H}}}, {}, mSub(A, {f, i}) * mSub(B, {f, i}));
|
||||
auto l1 = // Bias
|
||||
mL({{f, {0, F}}, {i, {0, H}}}, {}, mSub(l0, {f, i}) + mSub(C, {f}));
|
||||
// Relu
|
||||
auto l2 = mL({{f, {0, F}}, {i, {0, H}}}, {},
|
||||
make_ref<FuncNode>(mSub(AB, {f, i}), FuncType::Relu));
|
||||
dbg(l1, l2);
|
||||
|
||||
auto merged = MergeMemboundMutator({l1, l2}).merge();
|
||||
dbg(merged);
|
||||
|
||||
// TODO:
|
||||
// 1. 用NMutator::constructGraphFromExpression,把以上merged表达式变为算子,
|
||||
// 跑通TVM codegen。现在端到端运行模型用的是test/nnet/run_models_nnet.py的
|
||||
// model_e2e_exp(),可将以上函数作为整体暴露给python。
|
||||
// 2. 在NMutator::runSingleOp中,处理带有bias、relu的conv,使得输入一个conv
|
||||
// operator,能得到Gemm加一个Membound算子(Reduce+Bias+Relu)。现在的代码中,
|
||||
// NMutator::opToExpression返回的表达式会直接忽略bias和relu。
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue