forked from jiuyuan/InfiniTensor
feat: 导出到 python
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
89b3cf7c08
commit
2f64997142
|
@ -4,6 +4,6 @@
|
|||
|
||||
namespace infini {
|
||||
|
||||
Graph convertNCHWtoNHWCModel(Runtime runtime, Graph inG);
|
||||
Graph convertNCHWtoNHWCModel(Graph inG);
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
#include "app.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include <cstdint>
|
||||
|
@ -102,6 +103,8 @@ class GraphHandlerObj {
|
|||
|
||||
inline void optimize() { g->optimize(); }
|
||||
|
||||
inline void convert_nhwc() { g = std::move(convertNCHWtoNHWCModel(g)); }
|
||||
|
||||
//------ runtime
|
||||
|
||||
inline void data_malloc() { g->dataMalloc(); }
|
||||
|
|
|
@ -820,6 +820,9 @@ class OnnxStub:
|
|||
def init(self) -> None:
|
||||
self.handler.data_malloc()
|
||||
|
||||
def convert_nhwc(self) -> None:
|
||||
self.handler.convert_nhwc()
|
||||
|
||||
def optimize(self) -> None:
|
||||
self.handler.optimize()
|
||||
|
||||
|
|
|
@ -29,13 +29,13 @@ Tensor runWeightComputation(const Tensor &weight) {
|
|||
#endif // BANG
|
||||
}
|
||||
|
||||
Graph convertNCHWtoNHWCModel(Runtime runtime, Graph inG) {
|
||||
Graph convertNCHWtoNHWCModel(Graph inG) {
|
||||
// Construct new graph
|
||||
// IT_ASSERT(inG->getInputs().size() == 1);
|
||||
IT_ASSERT(inG->getOutputs().size() == 1);
|
||||
bool status = inG->topo_sort();
|
||||
IT_ASSERT(status);
|
||||
auto g = make_ref<GraphObj>(runtime);
|
||||
auto g = make_ref<GraphObj>(inG->getRuntime());
|
||||
map<UidBaseType, Tensor> tensors;
|
||||
// modelStatus: if currently processing Conv-related subgraph
|
||||
// 0: before processcing Conv-related subgraph
|
||||
|
@ -71,6 +71,7 @@ Graph convertNCHWtoNHWCModel(Runtime runtime, Graph inG) {
|
|||
auto bias =
|
||||
cOp->getBias() ? g->cloneTensor(cOp->getBias()) : nullptr;
|
||||
auto weight = runWeightComputation(inputs[1]);
|
||||
g->addTensor(weight);
|
||||
g->addOpWithOutputs<ConvNHWCObj>(inputs[0], weight, outputs[0],
|
||||
ph, pw, sh, sw, dh, dw, bias,
|
||||
cOp->getAct());
|
||||
|
@ -82,6 +83,7 @@ Graph convertNCHWtoNHWCModel(Runtime runtime, Graph inG) {
|
|||
auto bias =
|
||||
cOp->getBias() ? g->cloneTensor(cOp->getBias()) : nullptr;
|
||||
auto weight = runWeightComputation(inputs[1]);
|
||||
g->addTensor(weight);
|
||||
g->addOpWithOutputs<ConvTransposed2dNHWCObj>(
|
||||
inputs[0], weight, outputs[0], ph, pw, sh, sw, dh, dw, oph,
|
||||
opw, group, bias, cOp->getAct());
|
||||
|
@ -139,6 +141,7 @@ Graph convertNCHWtoNHWCModel(Runtime runtime, Graph inG) {
|
|||
vector<int>{0, 2, 3, 1})
|
||||
->getOutput();
|
||||
auto weight = runWeightComputation(inputs[1]);
|
||||
g->addTensor(weight);
|
||||
g->addOpWithOutputs<ConvNHWCObj>(t, weight, outputs[0], ph,
|
||||
pw, sh, sw, dh, dw, bias,
|
||||
cOp->getAct());
|
||||
|
|
|
@ -344,6 +344,7 @@ void init_graph_builder(py::module &m) {
|
|||
.def("pad", &Handler::pad, policy::move)
|
||||
.def("topo_sort", &Handler::topo_sort, policy::automatic)
|
||||
.def("optimize", &Handler::optimize, policy::automatic)
|
||||
.def("convert_nhwc", &Handler::convert_nhwc, policy::automatic)
|
||||
.def("operators", &Handler::operators, policy::move)
|
||||
.def("data_malloc", &Handler::data_malloc, policy::automatic)
|
||||
.def("run", &Handler::run, policy::automatic);
|
||||
|
|
Loading…
Reference in New Issue