feat: 导出到 python

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-06-25 17:17:14 +08:00
parent 89b3cf7c08
commit 2f64997142
5 changed files with 13 additions and 3 deletions

View File

@ -4,6 +4,6 @@
namespace infini {
Graph convertNCHWtoNHWCModel(Runtime runtime, Graph inG);
Graph convertNCHWtoNHWCModel(Graph inG);
} // namespace infini

View File

@ -1,5 +1,6 @@
#pragma once
#include "app.h"
#include "core/graph.h"
#include "core/runtime.h"
#include <cstdint>
@ -102,6 +103,8 @@ class GraphHandlerObj {
inline void optimize() { g->optimize(); }
inline void convert_nhwc() { g = std::move(convertNCHWtoNHWCModel(g)); }
//------ runtime
inline void data_malloc() { g->dataMalloc(); }

View File

@ -820,6 +820,9 @@ class OnnxStub:
def init(self) -> None:
self.handler.data_malloc()
def convert_nhwc(self) -> None:
self.handler.convert_nhwc()
def optimize(self) -> None:
self.handler.optimize()

View File

@ -29,13 +29,13 @@ Tensor runWeightComputation(const Tensor &weight) {
#endif // BANG
}
Graph convertNCHWtoNHWCModel(Runtime runtime, Graph inG) {
Graph convertNCHWtoNHWCModel(Graph inG) {
// Construct new graph
// IT_ASSERT(inG->getInputs().size() == 1);
IT_ASSERT(inG->getOutputs().size() == 1);
bool status = inG->topo_sort();
IT_ASSERT(status);
auto g = make_ref<GraphObj>(runtime);
auto g = make_ref<GraphObj>(inG->getRuntime());
map<UidBaseType, Tensor> tensors;
// modelStatus: if currently processing Conv-related subgraph
// 0: before processcing Conv-related subgraph
@ -71,6 +71,7 @@ Graph convertNCHWtoNHWCModel(Runtime runtime, Graph inG) {
auto bias =
cOp->getBias() ? g->cloneTensor(cOp->getBias()) : nullptr;
auto weight = runWeightComputation(inputs[1]);
g->addTensor(weight);
g->addOpWithOutputs<ConvNHWCObj>(inputs[0], weight, outputs[0],
ph, pw, sh, sw, dh, dw, bias,
cOp->getAct());
@ -82,6 +83,7 @@ Graph convertNCHWtoNHWCModel(Runtime runtime, Graph inG) {
auto bias =
cOp->getBias() ? g->cloneTensor(cOp->getBias()) : nullptr;
auto weight = runWeightComputation(inputs[1]);
g->addTensor(weight);
g->addOpWithOutputs<ConvTransposed2dNHWCObj>(
inputs[0], weight, outputs[0], ph, pw, sh, sw, dh, dw, oph,
opw, group, bias, cOp->getAct());
@ -139,6 +141,7 @@ Graph convertNCHWtoNHWCModel(Runtime runtime, Graph inG) {
vector<int>{0, 2, 3, 1})
->getOutput();
auto weight = runWeightComputation(inputs[1]);
g->addTensor(weight);
g->addOpWithOutputs<ConvNHWCObj>(t, weight, outputs[0], ph,
pw, sh, sw, dh, dw, bias,
cOp->getAct());

View File

@ -344,6 +344,7 @@ void init_graph_builder(py::module &m) {
.def("pad", &Handler::pad, policy::move)
.def("topo_sort", &Handler::topo_sort, policy::automatic)
.def("optimize", &Handler::optimize, policy::automatic)
.def("convert_nhwc", &Handler::convert_nhwc, policy::automatic)
.def("operators", &Handler::operators, policy::move)
.def("data_malloc", &Handler::data_malloc, policy::automatic)
.def("run", &Handler::run, policy::automatic);