Add: convert expression to operator

This commit is contained in:
Liyan Zheng 2023-06-28 11:06:17 +08:00
parent 29071ddcac
commit 1ee4a60af0
5 changed files with 51 additions and 4 deletions

View File

@ -387,13 +387,13 @@ class FuncNode : public ExprNode {
// Wrappers for type deduction
// make subscript
// make a subscript operator
Subscript mSub(const Expr &tensor, const VecExpr &subscripts);
// make range operator
// make a range operator
RangeOp mL(const vector<VarRangePair> &_loopIters,
const vector<VarRangePair> &_sumIters, Expr _summand,
const vector<int> &paddings = {});
// make tensor
// make a tensor
Tensor mT(const string &name, const vector<int> &shape,
const vector<int> &paddings = {}, const Routine &source = nullptr);

View File

@ -90,6 +90,10 @@ class NMutator : public Mutator {
/// @param op The operator chain. It can have wrong input/output shapes.
/// @return
Graph constructGraphByOperatorChain(vector<Operator> ops, Graph inputGraph);
// Convert an nnet::Expr to an infini::Graph containing corresponding
// tensors and operators
Graph constructGraphFromExpression(Runtime runtime, nnet::Expr expr);
};
} // namespace infini

View File

@ -45,6 +45,8 @@ Expr MergeMemboundMutator::merge(bool allowEmptyMembound, bool allowFailure) {
curExpr = sub->getObjectPtr();
else
break;
} else if (auto funcOp = as<BinaryOpNode>(summand)) {
break;
} else {
if (allowFailure)
return nullptr;
@ -143,4 +145,4 @@ Expr MergeMemboundMutator::rule4StageMerging(Expr &rCur,
return merged;
}
} // namespace nnet
} // namespace nnet

View File

@ -1232,4 +1232,17 @@ vector<Graph> NMutator::transformConv3x3ONNX(Operator _op) {
return ret;
}
// Build an infini::Graph that wraps an nnet expression in a single MemBound
// operator. Every tensor referenced by the expression becomes one graph
// input tensor; the single output tensor takes the shape of the top-level
// RangeOp (assumes `expr` is rooted at a RangeOpNode — TODO confirm callers
// guarantee this, as the cast is unchecked).
Graph NMutator::constructGraphFromExpression(Runtime runtime, nnet::Expr expr) {
    auto graph = make_ref<GraphObj>(runtime);
    // Map from tensor name to nnet tensor, collected from the expression.
    auto nameToTensor = nnet::GetTensorsVisitor().get(expr);
    TensorVec inTensors, outTensors;
    for (auto &kv : nameToTensor)
        inTensors.emplace_back(graph->addTensor(kv.second->getShape()));
    // Output shape comes from the outermost RangeOp of the expression.
    auto outShape = nnet::as<nnet::RangeOpNode>(expr)->getOutputShape();
    outTensors.emplace_back(graph->addTensor(outShape));
    // Wrap the whole expression as one MemBound operator (exec time 0).
    graph->addOpWithOutputs<MemBoundObj>(inTensors, outTensors, nameToTensor,
                                         expr, 0);
    return graph;
}
} // namespace infini

View File

@ -65,3 +65,31 @@ TEST(FuseMembound, mergeNestedStagesInRangeOp) {
mSub(K, {f, i / 1024, (i / 256) % 4, i % 256}));
EXPECT_EQ(HashVisitor().getHash(merged), HashVisitor().getHash(ans));
}
// Fuse a three-stage chain (reduction -> bias -> relu) into one expression
// via MergeMemboundMutator.
// NOTE(review): this test has no EXPECT_* — it only checks that merge()
// completes without failing; consider asserting a hash of `merged` as the
// sibling tests above do.
TEST(FuseMembound, mergeReductionBiasRelu) {
DEFINE_VAR(f, i);
const int F = 4, H = 16;
auto A = mT("A", vector<int>({F, H}));
auto B = mT("B", vector<int>({F, H}));
auto AB = mT("AB", vector<int>({F, H}));
auto C = mT("Bias", vector<int>({F}));
auto l0 = // Reduction
mL({{f, {0, F}}, {i, {0, H}}}, {}, mSub(A, {f, i}) * mSub(B, {f, i}));
auto l1 = // Bias
mL({{f, {0, F}}, {i, {0, H}}}, {}, mSub(l0, {f, i}) + mSub(C, {f}));
// Relu
auto l2 = mL({{f, {0, F}}, {i, {0, H}}}, {},
make_ref<FuncNode>(mSub(AB, {f, i}), FuncType::Relu));
dbg(l1, l2);
auto merged = MergeMemboundMutator({l1, l2}).merge();
dbg(merged);
// TODO:
// 1. Use NMutator::constructGraphFromExpression to turn the merged
//    expression above into an operator and run it through the TVM codegen.
//    End-to-end model runs currently use model_e2e_exp() in
//    test/nnet/run_models_nnet.py, which can expose the function above to
//    Python as a whole.
// 2. Handle conv with bias/relu in NMutator::runSingleOp so that feeding in
//    one conv operator yields a Gemm plus one MemBound operator
//    (Reduce+Bias+Relu). In the current code, the expression returned by
//    NMutator::opToExpression simply ignores bias and relu.
}