forked from jiuyuan/InfiniTensor
Add: import ONNX with membound Op
This commit is contained in:
parent
2a343e240e
commit
15d0eb79cd
|
@ -98,6 +98,9 @@ class GraphHandlerObj {
|
|||
const optional<vector<int>> &steps);
|
||||
Tensor pad(Tensor input, Tensor output, const vector<int> &pads,
|
||||
const optional<vector<int>> &axes);
|
||||
/// @brief Import memBound operator from a json
|
||||
TensorVec memBound(const TensorVec &inputs, const Tensor &outputs,
|
||||
const string &jsonString);
|
||||
|
||||
//------ modifiers
|
||||
|
||||
|
@ -110,6 +113,7 @@ class GraphHandlerObj {
|
|||
void data_malloc() { g->dataMalloc(); }
|
||||
|
||||
void run() { g->getRuntime()->run(g); }
|
||||
Graph getGraph() const { return g; }
|
||||
};
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -79,6 +79,11 @@ class Serializer : public Functor<string()> {
|
|||
|
||||
tuple<Expr, vector<Tensor>, double, string>
|
||||
deserializeAsMemobundOp(const string &filePath);
|
||||
|
||||
// FIXME: the order of elements in tuple is not consistent with memboundObj
|
||||
// constructor
|
||||
tuple<Expr, vector<Tensor>, double, string>
|
||||
membundOpFromString(const string &data);
|
||||
};
|
||||
|
||||
} // namespace nnet
|
||||
|
|
|
@ -32,8 +32,6 @@ class NMutator : public Mutator {
|
|||
long long cntStates = 0;
|
||||
long long cntCandidates = 0;
|
||||
|
||||
static void memboundToJson(const Graph &g, const string path);
|
||||
|
||||
private:
|
||||
int maxDepth = 8;
|
||||
nnet::Expr opToExpression(Operator op);
|
||||
|
|
|
@ -36,7 +36,7 @@ class MemBoundObj : public OperatorObj {
|
|||
return {expr, hash};
|
||||
}
|
||||
double getEstimatedTime() const { return exec_time; }
|
||||
void saveAsJson(string path) const;
|
||||
string toJson() const;
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import backend
|
||||
import backend
|
||||
from onnx import (
|
||||
ModelProto,
|
||||
TensorProto,
|
||||
|
@ -40,8 +40,9 @@ class OnnxStub:
|
|||
disable_check: bool
|
||||
|
||||
@classmethod
|
||||
def from_onnx(cls, model: ModelProto, runtime):
|
||||
model = infer_shapes(model)
|
||||
def from_onnx(cls, model: ModelProto, runtime, enable_onnx_shape_infernce = True):
|
||||
if enable_onnx_shape_infernce:
|
||||
model = infer_shapes(model)
|
||||
ans = OnnxStub()
|
||||
ans.handler = backend.GraphHandler(runtime)
|
||||
|
||||
|
@ -516,7 +517,18 @@ class OnnxStub:
|
|||
):
|
||||
tensors[name] = tensor
|
||||
elif node.op_type == "MemBound":
|
||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||
attributes = _parse_attribute(node, {'expr': None})
|
||||
expr: str = attributes['expr']
|
||||
assert expr is not None
|
||||
assert len(node.output) == 1, """MemBound with multiple
|
||||
outputs requires rewrite the logic of tensor creation"""
|
||||
outputs = ans.handler.memBound(
|
||||
[tensors[name] for name in node.input],
|
||||
tensors.get(node.output[0]),
|
||||
expr,
|
||||
)
|
||||
for name, tensor in zip(node.output, outputs):
|
||||
tensors[name] = tensor
|
||||
else:
|
||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||
|
||||
|
@ -572,10 +584,10 @@ class OnnxStub:
|
|||
value_info: List[ValueInfoProto] = []
|
||||
|
||||
enable_check = False
|
||||
|
||||
def __init__(self, enable_check):
|
||||
self.enable_check = enable_check
|
||||
|
||||
|
||||
def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]:
|
||||
ty = op.op_type()
|
||||
name = "{}{}".format(ty.name, self.count_op.setdefault(ty, 0) + 1)
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
#include "core/graph_handler.h"
|
||||
#include "nnet/Visitor/Serializer.h"
|
||||
#include "operators/batch_norm.h"
|
||||
#include "operators/concat.h"
|
||||
#include "operators/conv.h"
|
||||
#include "operators/element_wise.h"
|
||||
#include "operators/gather.h"
|
||||
#include "operators/matmul.h"
|
||||
#include "operators/membound.h"
|
||||
#include "operators/pad.h"
|
||||
#include "operators/pooling.h"
|
||||
#include "operators/reduce_mean.h"
|
||||
|
@ -291,6 +293,22 @@ Tensor GraphHandlerObj::pad(Tensor input, Tensor output,
|
|||
}
|
||||
}
|
||||
|
||||
TensorVec GraphHandlerObj::memBound(const TensorVec &inputs,
|
||||
const Tensor &output,
|
||||
const string &jsonString) {
|
||||
const auto &[expr, nnetInputs, execTime, hint] =
|
||||
nnet::Serializer().membundOpFromString(jsonString);
|
||||
if (output) {
|
||||
g->addOpWithOutputs<MemBoundObj>(std::move(inputs), TensorVec{output},
|
||||
nnetInputs, expr, execTime, hint);
|
||||
return {output};
|
||||
} else
|
||||
return g
|
||||
->addOp<MemBoundObj>(std::move(inputs), TensorVec{nullptr},
|
||||
nnetInputs, expr, execTime, hint)
|
||||
->getOutputs();
|
||||
}
|
||||
|
||||
static DataType dtype_repr_convert(int dtype) {
|
||||
switch ((OnnxDType)dtype) {
|
||||
case OnnxDType::FLOAT:
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
#include "core/graph_handler.h"
|
||||
#include "core/mutator.h"
|
||||
#include "core/search_engine.h"
|
||||
#include "nnet/Visitor/Serializer.h"
|
||||
#include "nnet/nmutator.h"
|
||||
#include "operators/batch_norm.h"
|
||||
#include "operators/concat.h"
|
||||
|
@ -232,9 +231,7 @@ static vector<int> transpose_permute_of(Operator op) {
|
|||
}
|
||||
|
||||
static string membound_expr_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::MemBound);
|
||||
return *nnet::Serializer().toString(
|
||||
dynamic_cast<const MemBoundObj *>(op.get())->getNnetExpr());
|
||||
return as<MemBoundObj>(op)->toJson();
|
||||
}
|
||||
|
||||
void export_functions(py::module &m) {
|
||||
|
@ -276,12 +273,16 @@ void init_graph_builder(py::module &m) {
|
|||
py::class_<Object, Ref<Object>>(m, "_Object")
|
||||
.def("__str__", &Object::toString)
|
||||
.def("guid", &Object::getGuid);
|
||||
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
|
||||
py::class_<RuntimeObj, Ref<RuntimeObj>>(m, "Runtime")
|
||||
.def("run", &RuntimeObj::run, "graph"_a, "tune"_a = false,
|
||||
"profiling"_a = false)
|
||||
.def("getPerfTime", &RuntimeObj::getPerfTime)
|
||||
.def("timeNonCtcOperators", &RuntimeObj::timeNonCtcOperators);
|
||||
py::class_<NativeCpuRuntimeObj, std::shared_ptr<NativeCpuRuntimeObj>,
|
||||
RuntimeObj>(m, "CpuRuntime");
|
||||
#ifdef USE_CUDA
|
||||
py::class_<CudaRuntimeObj, std::shared_ptr<CudaRuntimeObj>, RuntimeObj>(
|
||||
m, "CudaRuntime");
|
||||
py::class_<CudaRuntimeObj, Ref<CudaRuntimeObj>, RuntimeObj>(m,
|
||||
"CudaRuntime");
|
||||
#endif
|
||||
#ifdef USE_BANG
|
||||
py::class_<BangRuntimeObj, std::shared_ptr<BangRuntimeObj>, RuntimeObj>(
|
||||
|
@ -342,19 +343,20 @@ void init_graph_builder(py::module &m) {
|
|||
.def("reduce_mean", &Handler::reduceMean, policy::move)
|
||||
.def("slice", &Handler::slice, policy::move)
|
||||
.def("pad", &Handler::pad, policy::move)
|
||||
.def("memBound", &Handler::memBound, policy::move)
|
||||
.def("topo_sort", &Handler::topo_sort, policy::automatic)
|
||||
.def("optimize", &Handler::optimize, policy::automatic)
|
||||
.def("operators", &Handler::operators, policy::move)
|
||||
.def("data_malloc", &Handler::data_malloc, policy::automatic)
|
||||
.def("run", &Handler::run, policy::automatic);
|
||||
.def("run", &Handler::run, policy::automatic)
|
||||
.def("getGraph", &Handler::getGraph);
|
||||
py::class_<Mutator, Ref<Mutator>>(m, "Mutator").def("run", &Mutator::run);
|
||||
py::enum_<NMutator::Mode>(m, "NMutatorMode")
|
||||
.value("RuleBased", NMutator::Mode::RuleBased);
|
||||
py::class_<NMutator, Ref<NMutator>, Mutator>(m, "NMutator")
|
||||
.def(py::init<NMutator::Mode>())
|
||||
.def(py::init<NMutator::Mode, vector<int>>())
|
||||
.def("run", &NMutator::run)
|
||||
.def_static("memboundToJson", &NMutator::memboundToJson);
|
||||
.def("run", &NMutator::run);
|
||||
py::class_<SearchEngine>(m, "SearchEngine")
|
||||
.def(py::init<Runtime, Ref<Mutator>>())
|
||||
.def("run", &SearchEngine::run);
|
||||
|
|
|
@ -111,7 +111,7 @@ std::optional<std::string> Serializer::toString(const Expr &expr,
|
|||
bool Serializer::toFile(const Expr &expr, const string &filePath,
|
||||
const string &msg, vector<Tensor> inputs,
|
||||
double exec_time, string hint) {
|
||||
if (auto s = toString(expr, msg, inputs, exec_time, hint); s) {
|
||||
if (auto s = toString(expr, msg, inputs, exec_time, hint)) {
|
||||
// Write to file
|
||||
std::ofstream fout(filePath);
|
||||
fout << *s;
|
||||
|
@ -300,4 +300,14 @@ Serializer::deserializeAsMemobundOp(const string &filePath) {
|
|||
return {buildExprTree("0"), inputs, j["exec_time"], j["hint"]};
|
||||
}
|
||||
|
||||
tuple<Expr, vector<Tensor>, double, string>
|
||||
Serializer::membundOpFromString(const string &data) {
|
||||
j = json::parse(data);
|
||||
assert(j["Version"] == VERSION);
|
||||
vector<Tensor> inputs;
|
||||
for (const auto &input : j["nnetInputs"])
|
||||
inputs.emplace_back(as<TensorNode>(buildExprTree(input)));
|
||||
return {buildExprTree("0"), inputs, j["exec_time"], j["hint"]};
|
||||
}
|
||||
|
||||
} // namespace nnet
|
||||
|
|
|
@ -754,15 +754,6 @@ NMutator::generateUnaryExpr(const Operator &op) {
|
|||
NameNToTensorT{{"T", op->getInputs()[0]}}};
|
||||
}
|
||||
|
||||
void NMutator::memboundToJson(const Graph &g, const string path) {
|
||||
for (auto &_op : g->getOperators()) {
|
||||
if (auto op = as<MemBoundObj>(_op)) {
|
||||
op->saveAsJson(path + "/" + "membound_" +
|
||||
std::to_string(op->getGuid()) + ".json");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pair<nnet::Expr, vector<nnet::Tensor>> NMutator::generateRevert(Tensor in) {
|
||||
using namespace nnet;
|
||||
using infini::make_ref;
|
||||
|
|
|
@ -91,10 +91,9 @@ bool MemBoundObj::checkOOB(nnet::Expr expr) {
|
|||
nnet::as<nnet::RangeOpNode>(expr));
|
||||
}
|
||||
|
||||
void MemBoundObj::saveAsJson(string path) const {
|
||||
bool status = nnet::Serializer().toFile(
|
||||
expr, path, "MemBoundObj::saveAsJson", nnetInputs, exec_time, hint);
|
||||
IT_ASSERT(status);
|
||||
string MemBoundObj::toJson() const {
|
||||
return *nnet::Serializer().toString(expr, "MemBoundObj::toJson", nnetInputs,
|
||||
exec_time, hint);
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -88,9 +88,18 @@ def run_InfoGAN_without_tuning(tuning: bool):
|
|||
f.write(stub.to_onnx("optimized").SerializeToString())
|
||||
|
||||
|
||||
def load_onnx_and_run():
|
||||
runtime = ft.cuda_runtime()
|
||||
stub = OnnxStub.from_onnx(onnx.load("optimized.onnx"), runtime, False)
|
||||
g = stub.handler.getGraph()
|
||||
runtime.run(g, True)
|
||||
print(f'Non-ctc time = {runtime.timeNonCtcOperators(g, 1000, 1000)}')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# run_e2e_InfoGAN()
|
||||
run_InfoGAN_without_tuning(True)
|
||||
# runSingleConvT()
|
||||
# read_and_check()
|
||||
|
||||
# run_InfoGAN_without_tuning(False)
|
||||
load_onnx_and_run()
|
||||
|
|
|
@ -103,9 +103,9 @@ TEST(Serializer, Serialization_memboundOp) {
|
|||
string hint = "test";
|
||||
infini::MemBoundObj memboundOp(nullptr, {AT, BT}, {CT}, nnetInputs, expr,
|
||||
execTime, hint);
|
||||
memboundOp.saveAsJson("./test_serializer.json");
|
||||
auto str = memboundOp.toJson();
|
||||
auto [exprLoaded, nnetInputsLoaded, execTimeLoaded, hintLoaded] =
|
||||
Serializer().deserializeAsMemobundOp("./test_serializer.json");
|
||||
Serializer().membundOpFromString(str);
|
||||
EXPECT_EQ(expr->toReadable(), exprLoaded->toReadable());
|
||||
EXPECT_EQ(execTime, execTimeLoaded);
|
||||
EXPECT_EQ(nnetInputs.size(), nnetInputsLoaded.size());
|
||||
|
|
Loading…
Reference in New Issue