Add: import ONNX with membound Op

Liyan Zheng 2023-04-20 10:39:42 +08:00
parent 2a343e240e
commit 15d0eb79cd
12 changed files with 83 additions and 35 deletions

View File

@@ -98,6 +98,9 @@ class GraphHandlerObj {
                  const optional<vector<int>> &steps);
     Tensor pad(Tensor input, Tensor output, const vector<int> &pads,
                const optional<vector<int>> &axes);
+    /// @brief Import a memBound operator from a JSON string
+    TensorVec memBound(const TensorVec &inputs, const Tensor &output,
+                       const string &jsonString);

     //------ modifiers
@@ -110,6 +113,7 @@ class GraphHandlerObj {
     void data_malloc() { g->dataMalloc(); }
     void run() { g->getRuntime()->run(g); }
+    Graph getGraph() const { return g; }
 };
 } // namespace infini
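
For orientation, a minimal sketch (not part of this commit) of how the new memBound entry point can be driven once the Python binding added below is in place; `runtime`, the input tensors `t0`/`t1`, and `json_str` are assumed to already exist:

    import backend

    # Hypothetical usage; `json_str` stands for the string produced by
    # MemBoundObj::toJson. Passing None as the output mirrors the importer's
    # tensors.get(...) and asks the handler to infer and create the output.
    handler = backend.GraphHandler(runtime)
    outputs = handler.memBound([t0, t1], None, json_str)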

View File

@@ -79,6 +79,11 @@ class Serializer : public Functor<string()> {
     tuple<Expr, vector<Tensor>, double, string>
     deserializeAsMemobundOp(const string &filePath);

+    // FIXME: the order of elements in the tuple is not consistent with the
+    // MemBoundObj constructor
+    tuple<Expr, vector<Tensor>, double, string>
+    membundOpFromString(const string &data);
 };
 } // namespace nnet

View File

@@ -32,8 +32,6 @@ class NMutator : public Mutator {
     long long cntStates = 0;
     long long cntCandidates = 0;

-    static void memboundToJson(const Graph &g, const string path);
-
   private:
     int maxDepth = 8;
     nnet::Expr opToExpression(Operator op);

View File

@@ -36,7 +36,7 @@ class MemBoundObj : public OperatorObj {
         return {expr, hash};
     }
     double getEstimatedTime() const { return exec_time; }
-    void saveAsJson(string path) const;
+    string toJson() const;

   private:
     vector<int> getWorkloadVector() const override;

View File

@@ -1,4 +1,4 @@
 import backend
 from onnx import (
     ModelProto,
     TensorProto,
@@ -40,7 +40,8 @@ class OnnxStub:
     disable_check: bool

     @classmethod
-    def from_onnx(cls, model: ModelProto, runtime):
-        model = infer_shapes(model)
+    def from_onnx(cls, model: ModelProto, runtime, enable_onnx_shape_infernce=True):
+        if enable_onnx_shape_infernce:
+            model = infer_shapes(model)
         ans = OnnxStub()
         ans.handler = backend.GraphHandler(runtime)
@@ -516,7 +517,18 @@ class OnnxStub:
             ):
                 tensors[name] = tensor
         elif node.op_type == "MemBound":
-            raise Exception('Unsupported operator "{}"'.format(node.op_type))
+            attributes = _parse_attribute(node, {'expr': None})
+            expr: str = attributes['expr']
+            assert expr is not None
+            assert len(node.output) == 1, """MemBound with multiple
+                outputs requires rewriting the logic of tensor creation"""
+            outputs = ans.handler.memBound(
+                [tensors[name] for name in node.input],
+                tensors.get(node.output[0]),
+                expr,
+            )
+            for name, tensor in zip(node.output, outputs):
+                tensors[name] = tensor
         else:
             raise Exception('Unsupported operator "{}"'.format(node.op_type))
@@ -572,10 +584,10 @@ class OnnxStub:
     value_info: List[ValueInfoProto] = []
     enable_check = False

     def __init__(self, enable_check):
         self.enable_check = enable_check

     def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]:
         ty = op.op_type()
         name = "{}{}".format(ty.name, self.count_op.setdefault(ty, 0) + 1)

View File

@@ -1,10 +1,12 @@
 #include "core/graph_handler.h"
+#include "nnet/Visitor/Serializer.h"
 #include "operators/batch_norm.h"
 #include "operators/concat.h"
 #include "operators/conv.h"
 #include "operators/element_wise.h"
 #include "operators/gather.h"
 #include "operators/matmul.h"
+#include "operators/membound.h"
 #include "operators/pad.h"
 #include "operators/pooling.h"
 #include "operators/reduce_mean.h"
@@ -291,6 +293,22 @@ Tensor GraphHandlerObj::pad(Tensor input, Tensor output,
     }
 }

+TensorVec GraphHandlerObj::memBound(const TensorVec &inputs,
+                                    const Tensor &output,
+                                    const string &jsonString) {
+    const auto &[expr, nnetInputs, execTime, hint] =
+        nnet::Serializer().membundOpFromString(jsonString);
+    if (output) {
+        g->addOpWithOutputs<MemBoundObj>(std::move(inputs), TensorVec{output},
+                                         nnetInputs, expr, execTime, hint);
+        return {output};
+    } else
+        return g
+            ->addOp<MemBoundObj>(std::move(inputs), TensorVec{nullptr},
+                                 nnetInputs, expr, execTime, hint)
+            ->getOutputs();
+}
+
 static DataType dtype_repr_convert(int dtype) {
     switch ((OnnxDType)dtype) {
     case OnnxDType::FLOAT:

View File

@@ -1,7 +1,6 @@
 #include "core/graph_handler.h"
 #include "core/mutator.h"
 #include "core/search_engine.h"
-#include "nnet/Visitor/Serializer.h"
 #include "nnet/nmutator.h"
 #include "operators/batch_norm.h"
 #include "operators/concat.h"
@@ -232,9 +231,7 @@ static vector<int> transpose_permute_of(Operator op) {
 }

 static string membound_expr_of(Operator op) {
-    IT_ASSERT(op->getOpType() == OpType::MemBound);
-    return *nnet::Serializer().toString(
-        dynamic_cast<const MemBoundObj *>(op.get())->getNnetExpr());
+    return as<MemBoundObj>(op)->toJson();
 }

 void export_functions(py::module &m) {
@@ -276,12 +273,16 @@ void init_graph_builder(py::module &m) {
     py::class_<Object, Ref<Object>>(m, "_Object")
         .def("__str__", &Object::toString)
         .def("guid", &Object::getGuid);
-    py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
+    py::class_<RuntimeObj, Ref<RuntimeObj>>(m, "Runtime")
+        .def("run", &RuntimeObj::run, "graph"_a, "tune"_a = false,
+             "profiling"_a = false)
+        .def("getPerfTime", &RuntimeObj::getPerfTime)
+        .def("timeNonCtcOperators", &RuntimeObj::timeNonCtcOperators);
     py::class_<NativeCpuRuntimeObj, std::shared_ptr<NativeCpuRuntimeObj>,
                RuntimeObj>(m, "CpuRuntime");
 #ifdef USE_CUDA
-    py::class_<CudaRuntimeObj, std::shared_ptr<CudaRuntimeObj>, RuntimeObj>(
-        m, "CudaRuntime");
+    py::class_<CudaRuntimeObj, Ref<CudaRuntimeObj>, RuntimeObj>(m,
+                                                                "CudaRuntime");
 #endif
 #ifdef USE_BANG
     py::class_<BangRuntimeObj, std::shared_ptr<BangRuntimeObj>, RuntimeObj>(
@@ -342,19 +343,20 @@ void init_graph_builder(py::module &m) {
         .def("reduce_mean", &Handler::reduceMean, policy::move)
         .def("slice", &Handler::slice, policy::move)
         .def("pad", &Handler::pad, policy::move)
+        .def("memBound", &Handler::memBound, policy::move)
         .def("topo_sort", &Handler::topo_sort, policy::automatic)
         .def("optimize", &Handler::optimize, policy::automatic)
         .def("operators", &Handler::operators, policy::move)
         .def("data_malloc", &Handler::data_malloc, policy::automatic)
-        .def("run", &Handler::run, policy::automatic);
+        .def("run", &Handler::run, policy::automatic)
+        .def("getGraph", &Handler::getGraph);
     py::class_<Mutator, Ref<Mutator>>(m, "Mutator").def("run", &Mutator::run);
     py::enum_<NMutator::Mode>(m, "NMutatorMode")
         .value("RuleBased", NMutator::Mode::RuleBased);
     py::class_<NMutator, Ref<NMutator>, Mutator>(m, "NMutator")
         .def(py::init<NMutator::Mode>())
         .def(py::init<NMutator::Mode, vector<int>>())
-        .def("run", &NMutator::run)
-        .def_static("memboundToJson", &NMutator::memboundToJson);
+        .def("run", &NMutator::run);
     py::class_<SearchEngine>(m, "SearchEngine")
         .def(py::init<Runtime, Ref<Mutator>>())
         .def("run", &SearchEngine::run);

View File

@@ -111,7 +111,7 @@ std::optional<std::string> Serializer::toString(const Expr &expr,
 bool Serializer::toFile(const Expr &expr, const string &filePath,
                         const string &msg, vector<Tensor> inputs,
                         double exec_time, string hint) {
-    if (auto s = toString(expr, msg, inputs, exec_time, hint); s) {
+    if (auto s = toString(expr, msg, inputs, exec_time, hint)) {
         // Write to file
         std::ofstream fout(filePath);
         fout << *s;
@@ -300,4 +300,14 @@ Serializer::deserializeAsMemobundOp(const string &filePath) {
     return {buildExprTree("0"), inputs, j["exec_time"], j["hint"]};
 }

+tuple<Expr, vector<Tensor>, double, string>
+Serializer::membundOpFromString(const string &data) {
+    j = json::parse(data);
+    assert(j["Version"] == VERSION);
+    vector<Tensor> inputs;
+    for (const auto &input : j["nnetInputs"])
+        inputs.emplace_back(as<TensorNode>(buildExprTree(input)));
+    return {buildExprTree("0"), inputs, j["exec_time"], j["hint"]};
+}
+
 } // namespace nnet
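
Judging only from the keys membundOpFromString reads back, the payload looks roughly like the Python dict below; the expression-tree encoding rooted at key "0" and the exact form of the nnetInputs entries are not shown in this commit and are assumptions:

    # Inferred, illustrative shape of the serialized MemBound payload.
    payload = {
        "Version": 1,          # assumed value; only checked against VERSION
        "exec_time": 1.0,      # double, round-tripped by the test below
        "hint": "test",
        "nnetInputs": [...],   # entries rebuilt via buildExprTree into tensors
        # "0": root of the serialized expression tree (encoding omitted)
    }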

View File

@@ -754,15 +754,6 @@ NMutator::generateUnaryExpr(const Operator &op) {
             NameNToTensorT{{"T", op->getInputs()[0]}}};
 }

-void NMutator::memboundToJson(const Graph &g, const string path) {
-    for (auto &_op : g->getOperators()) {
-        if (auto op = as<MemBoundObj>(_op)) {
-            op->saveAsJson(path + "/" + "membound_" +
-                           std::to_string(op->getGuid()) + ".json");
-        }
-    }
-}
-
 pair<nnet::Expr, vector<nnet::Tensor>> NMutator::generateRevert(Tensor in) {
     using namespace nnet;
     using infini::make_ref;

View File

@@ -91,10 +91,9 @@ bool MemBoundObj::checkOOB(nnet::Expr expr) {
                        nnet::as<nnet::RangeOpNode>(expr));
 }

-void MemBoundObj::saveAsJson(string path) const {
-    bool status = nnet::Serializer().toFile(
-        expr, path, "MemBoundObj::saveAsJson", nnetInputs, exec_time, hint);
-    IT_ASSERT(status);
+string MemBoundObj::toJson() const {
+    return *nnet::Serializer().toString(expr, "MemBoundObj::toJson",
+                                        nnetInputs, exec_time, hint);
 }

 } // namespace infini

View File

@@ -88,9 +88,18 @@ def run_InfoGAN_without_tuning(tuning: bool):
         f.write(stub.to_onnx("optimized").SerializeToString())


+def load_onnx_and_run():
+    runtime = ft.cuda_runtime()
+    stub = OnnxStub.from_onnx(onnx.load("optimized.onnx"), runtime, False)
+    g = stub.handler.getGraph()
+    runtime.run(g, True)
+    print(f'Non-ctc time = {runtime.timeNonCtcOperators(g, 1000, 1000)}')
+
+
 if __name__ == "__main__":
     # run_e2e_InfoGAN()
-    run_InfoGAN_without_tuning(True)
     # runSingleConvT()
     # read_and_check()
+    # run_InfoGAN_without_tuning(False)
+    load_onnx_and_run()

View File

@@ -103,9 +103,9 @@ TEST(Serializer, Serialization_memboundOp) {
     string hint = "test";
     infini::MemBoundObj memboundOp(nullptr, {AT, BT}, {CT}, nnetInputs, expr,
                                    execTime, hint);
-    memboundOp.saveAsJson("./test_serializer.json");
+    auto str = memboundOp.toJson();
     auto [exprLoaded, nnetInputsLoaded, execTimeLoaded, hintLoaded] =
-        Serializer().deserializeAsMemobundOp("./test_serializer.json");
+        Serializer().membundOpFromString(str);
     EXPECT_EQ(expr->toReadable(), exprLoaded->toReadable());
     EXPECT_EQ(execTime, execTimeLoaded);
     EXPECT_EQ(nnetInputs.size(), nnetInputsLoaded.size());