finish fusing conv add bias on onnx

This commit is contained in:
xxcclong 2023-04-22 14:40:41 +08:00
parent 225a42f22d
commit 0657938139
10 changed files with 79 additions and 25 deletions

View File

@ -31,6 +31,7 @@ class Mutator {
IT_TODO_HALT();
}
virtual Graph fuseVertically(const Graph &inputGraph) { IT_TODO_HALT(); }
virtual Graph fuseConvBiasAct(const Graph &inputGraph) { IT_TODO_HALT(); }
};
} // namespace infini

View File

@ -26,6 +26,7 @@ class NMutator : public Mutator {
vector<Graph> run(const Graph &in_graph) override;
Graph fuseVertically(const Graph &in_graph) override;
Graph fuseConvBiasAct(const Graph &in_graph) override;
void setToNaiveMembound();
void setMaxDepth(int _maxDepth) { maxDepth = _maxDepth; }

View File

@ -98,7 +98,10 @@ class ConvBaseObj : public OperatorObj {
int numInputs() const override { return 2; }
int numOutputs() const override { return 1; }
Tensor getBias() const { return inputs[2]; }
Tensor getBias() const {
if (inputs.size() >= 3) return inputs[2];
else return nullptr;
}
PaddingMode getPaddingMode() const { return padding; }
pair<int, int> inferPaddingSize() const;

View File

@ -104,7 +104,7 @@ class OnnxStub:
adapt = node.input[0]
# HACK: ignore bias
if len(node.input) > 3:
if len(node.input) > 2: # 2023.04.22-01:11:24
bias = "{}-bias".format(node.output[0])
reshape = "{}-reshape".format(node.output[0])
tensors[bias] = ans.handler.conv(
@ -701,6 +701,7 @@ class OnnxStub:
for op in ops:
ty, name = ctx.name_op(op)
print(f"hkz: onnx {name} {len(op.inputs())}")
inputs = [
ctx.push_input(it, self.initializer.get(it.fuid()))
for it in op.inputs()
@ -710,7 +711,13 @@ class OnnxStub:
for (i, it) in enumerate(op.outputs())
]
if ty == backend.OpType.Conv:
ph, pw, dh, dw, sh, sw = backend.conv_attrs_of(op)
ph, pw, dh, dw, sh, sw, bias = backend.conv_attrs_of(op)
if bias is not None:
inputs.append(ctx.push_input(bias, self.initializer.get(bias.fuid())))
for item in inputs:
print(type(item), str(item))
else:
print("hkz: bias is none")
ctx.push_node(
make_node(
ty.name,

View File

@ -36,7 +36,7 @@ bool OperatorObj::isReshapeOp() const { return type == OpType::Reshape; }
bool OperatorObj::isMemBoundOp() const {
return type == OpType::MemBound || type == OpType::Activation ||
type == OpType::Transpose || type == OpType::Relu ||
type == OpType::Tanh;
type == OpType::Tanh || type == OpType::Add; // TODO: is Add memory bound? 2023.04.22-14:35:32
}
void OperatorObj::removePredecessors(const Operator &op) {

View File

@ -471,9 +471,11 @@ Graph SearchEngine::fuseVertically(const Graph &graph) {
if (visitTime.find(op->getGuid()) != visitTime.end()) {
continue;
}
// if is conv, we can still vertical fuse it
bool conv_flag = op->isComputeOp() && op->getSuccessors().size() == 1;
// Skip compute OP and multi-input/output OP
if (!op->isMemBoundOp() || (op->getPredecessors().size() != 1 &&
op->getSuccessors().size() != 1)) {
if (!conv_flag && (!op->isMemBoundOp() || (op->getPredecessors().size() != 1 &&
op->getSuccessors().size() != 1))) {
visitTime.emplace(op->getGuid(), ++cnt);
ops.emplace_back(op);
continue;
@ -483,14 +485,17 @@ Graph SearchEngine::fuseVertically(const Graph &graph) {
vector<Operator> tmp;
auto cur = op;
while (cur->getPredecessors().size() == 1 &&
cur->getPredecessors()[0]->isMemBoundOp()) {
cur = cur->getPredecessors()[0];
tmp.emplace_back(cur);
visitTime.emplace(cur->getGuid(), cnt);
}
for (int i = tmp.size() - 1; i >= 0; i--) {
chainOps.emplace_back(tmp[i]);
if (!conv_flag) {
while (cur->getPredecessors().size() == 1 &&
cur->getPredecessors()[0]->isMemBoundOp()) {
cur = cur->getPredecessors()[0];
tmp.emplace_back(cur);
visitTime.emplace(cur->getGuid(), cnt);
}
for (int i = tmp.size() - 1; i >= 0; i--) {
chainOps.emplace_back(tmp[i]);
}
}
chainOps.emplace_back(op);
cur = op;
@ -502,10 +507,19 @@ Graph SearchEngine::fuseVertically(const Graph &graph) {
}
make_ref<GraphObj>(runtimeExec, chainOps)->print();
Graph optGraph =
mutator->fuseVertically(make_ref<GraphObj>(runtimeExec, chainOps));
for (auto op : optGraph->getOperators()) {
ops.emplace_back(op);
if (conv_flag && chainOps.size() > 1) {
Graph optGraph =
mutator->fuseConvBiasAct(make_ref<GraphObj>(runtimeExec, chainOps));
for (auto op : optGraph->getOperators()) {
ops.emplace_back(op);
}
}
else {
Graph optGraph =
mutator->fuseVertically(make_ref<GraphObj>(runtimeExec, chainOps));
for (auto op : optGraph->getOperators()) {
ops.emplace_back(op);
}
}
}

View File

@ -142,11 +142,11 @@ static Ref<BangRuntimeObj> bang_runtime() { return make_ref<BangRuntimeObj>(); }
static Ref<RuntimeObj> intelcpu_runtime() { return make_ref<MklRuntimeObj>(); }
#endif
static std::tuple<int, int, int, int, int, int> conv_attrs_of(Operator op) {
static std::tuple<int, int, int, int, int, int, Tensor> conv_attrs_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::Conv);
auto conv = dynamic_cast<const ConvObj *>(op.get());
return std::make_tuple(conv->getPh(), conv->getPw(), conv->getDh(),
conv->getDw(), conv->getSh(), conv->getSw());
conv->getDw(), conv->getSh(), conv->getSw(), conv->getBias());
}
static std::tuple<int, int, int, int, int, int, int, int>

View File

@ -29,6 +29,8 @@ void NMutator::setToNaiveMembound() { mode = Mode::ToNaiveMembound; }
vector<Graph> NMutator::run(const Graph &in_graph) {
vector<Graph> out_graphs{in_graph};
printf("directly return out_graph\n");// Test helper: naively transform one Op to Membound
return out_graphs;
// Test helper: naively transform one Op to Membound
if (mode == Mode::ToNaiveMembound) {
runSingleOpToNaiveMembound(in_graph, out_graphs);
@ -680,14 +682,38 @@ Graph NMutator::transformConvtransposed1x1(Operator _op) {
// return graph;
// }
Graph NMutator::fuseConvBiasAct(const Graph &inputGraph) {
Graph optGraph = make_ref<GraphObj>(runtime);
auto chainOps = inputGraph->getOperators();
IT_ASSERT(!chainOps.empty());
if (chainOps.size() == 1) {
return make_ref<GraphObj>(runtime, chainOps);
}
auto conv_op = as<ConvObj>(chainOps[0]);
auto bias_tensor = conv_op->getInputs()[0]; // init bias tensor
auto add_op = chainOps[1];
bool fuse_bias = chainOps.size() >=2 && add_op->getInputs().size() == 2;
if (fuse_bias) { // conv add act
bias_tensor = add_op->getInputs()[1];
}
IT_ASSERT(conv_op != nullptr);
const auto &A = conv_op->getInputs()[0];
const auto &W = conv_op->getInputs()[1];
const auto &[ph, pw, sh, sw, dh, dw] = conv_op->getPadStrideDilation();
auto g = make_ref<GraphObj>(runtime);
g->addOpWithOutputs<ConvObj>(g->cloneTensor(A), g->cloneTensor(W), g->cloneTensor(chainOps.back()->getOutput()),
ph, pw, sh, sw, dh, dw, fuse_bias ? g->cloneTensor(bias_tensor) : nullptr, chainOps.size() == 3 ? ActType::Relu : ActType::None);
return g;
}
Graph NMutator::fuseVertically(const Graph &inputGraph) {
Graph optGraph = make_ref<GraphObj>(runtime);
auto chainOps = inputGraph->getOperators();
IT_ASSERT(!chainOps.empty());
for (auto &op : chainOps) {
IT_ASSERT(op->isMemBoundOp());
IT_ASSERT_TODO(op->getInputs().size() == 1);
IT_ASSERT(op->isMemBoundOp() || chainOps.size() == 1); // it is OK if a single non-mem-bound op is in
IT_ASSERT_TODO(op->getInputs().size() == 1 || chainOps.size() == 1); // it is OK if a single multi-input op is in
IT_ASSERT(op->getOutputs().size() == 1);
}
if (chainOps.size() == 1) {

View File

@ -66,8 +66,10 @@ ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
ActType act)
: ConvBaseObj(OpType::Conv, {input, weight}, output, ph, pw, sh, sw, dh, dw,
input, weight, act) {
// if (bias)
// IT_TODO_HALT();
if (bias)
IT_TODO_HALT();
inputs.emplace_back(bias);
setAuxilaryAttributes(PaddingMode::Other);
IT_ASSERT(checkValid(graph));
}

View File

@ -123,5 +123,5 @@ if __name__ == "__main__":
# g = ft.optimizeGraph(original_g, runtime, False, ft.NMutatorMode.Normal)
save_onnx(g, f"optimized_{name}.onnx")
verify_graphs(runtime, original_g, g)
run_and_evaluate(runtime, g)
# verify_graphs(runtime, original_g, g)
# run_and_evaluate(runtime, g)