forked from jiuyuan/InfiniTensor
Add: import ONNX with membound Op
This commit is contained in:
parent 2a343e240e
commit 15d0eb79cd
@@ -98,6 +98,9 @@ class GraphHandlerObj {
                  const optional<vector<int>> &steps);
     Tensor pad(Tensor input, Tensor output, const vector<int> &pads,
                const optional<vector<int>> &axes);
+    /// @brief Import a memBound operator from a JSON string
+    TensorVec memBound(const TensorVec &inputs, const Tensor &outputs,
+                       const string &jsonString);
 
     //------ modifiers
@@ -110,6 +113,7 @@ class GraphHandlerObj {
     void data_malloc() { g->dataMalloc(); }
 
     void run() { g->getRuntime()->run(g); }
+    Graph getGraph() const { return g; }
 };
 
 } // namespace infini
@@ -79,6 +79,11 @@ class Serializer : public Functor<string()> {
     tuple<Expr, vector<Tensor>, double, string>
     deserializeAsMemobundOp(const string &filePath);
 
+    // FIXME: the order of elements in the tuple is not consistent with the
+    // MemBoundObj constructor
+    tuple<Expr, vector<Tensor>, double, string>
+    membundOpFromString(const string &data);
 };
 
 } // namespace nnet
@@ -32,8 +32,6 @@ class NMutator : public Mutator {
     long long cntStates = 0;
     long long cntCandidates = 0;
 
-    static void memboundToJson(const Graph &g, const string path);
-
   private:
     int maxDepth = 8;
     nnet::Expr opToExpression(Operator op);
@@ -36,7 +36,7 @@ class MemBoundObj : public OperatorObj {
         return {expr, hash};
     }
     double getEstimatedTime() const { return exec_time; }
-    void saveAsJson(string path) const;
+    string toJson() const;
 
   private:
     vector<int> getWorkloadVector() const override;
@@ -1,4 +1,4 @@
 import backend
 from onnx import (
     ModelProto,
     TensorProto,
@@ -40,8 +40,9 @@ class OnnxStub:
     disable_check: bool
 
     @classmethod
-    def from_onnx(cls, model: ModelProto, runtime):
-        model = infer_shapes(model)
+    def from_onnx(cls, model: ModelProto, runtime, enable_onnx_shape_inference=True):
+        if enable_onnx_shape_inference:
+            model = infer_shapes(model)
         ans = OnnxStub()
         ans.handler = backend.GraphHandler(runtime)
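A hedged note on the new flag (not part of the commit): onnx.shape_inference.infer_shapes only understands standard operators, so a model that already carries the custom MemBound node is loaded with the flag off, as the test script at the end of this diff does. The import path and the runtime factory in the sketch below are assumptions based on that script.

import onnx
import backend as ft  # assumption: the compiled pybind module, aliased as in the scripts
from pyinfinitensor.onnx import OnnxStub  # assumed package path of the file patched above

runtime = ft.cuda_runtime()                      # runtime factory, as used in the test script below
model = onnx.load("model_with_membound.onnx")    # hypothetical file name
stub = OnnxStub.from_onnx(model, runtime, False) # False: skip ONNX shape inference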
@@ -516,7 +517,18 @@ class OnnxStub:
             ):
                 tensors[name] = tensor
             elif node.op_type == "MemBound":
-                raise Exception('Unsupported operator "{}"'.format(node.op_type))
+                attributes = _parse_attribute(node, {'expr': None})
+                expr: str = attributes['expr']
+                assert expr is not None
+                assert len(node.output) == 1, """MemBound with multiple
+                    outputs requires rewriting the logic of tensor creation"""
+                outputs = ans.handler.memBound(
+                    [tensors[name] for name in node.input],
+                    tensors.get(node.output[0]),
+                    expr,
+                )
+                for name, tensor in zip(node.output, outputs):
+                    tensors[name] = tensor
             else:
                 raise Exception('Unsupported operator "{}"'.format(node.op_type))
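For orientation, a hedged sketch (not part of the commit) of the node this branch consumes: an ONNX node with op_type "MemBound" whose serialized nnet expression travels in a string attribute named "expr". Tensor names, shapes, and the placeholder expression string are illustrative assumptions.

from onnx import TensorProto, helper

# Serialized nnet expression as produced by the Serializer / MemBoundObj::toJson
# (placeholder contents here).
expr_json = "..."

membound = helper.make_node(
    "MemBound",        # custom op type matched by the importer above
    inputs=["X"],      # every input must already exist in `tensors`
    outputs=["Y"],     # exactly one output (asserted by the importer)
    expr=expr_json,    # string attribute read via _parse_attribute
)
x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 8])
y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 8])
graph = helper.make_graph([membound], "membound_graph", [x], [y])
model = helper.make_model(graph)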
@@ -572,10 +584,10 @@ class OnnxStub:
         value_info: List[ValueInfoProto] = []
 
         enable_check = False
 
         def __init__(self, enable_check):
            self.enable_check = enable_check
 
         def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]:
             ty = op.op_type()
             name = "{}{}".format(ty.name, self.count_op.setdefault(ty, 0) + 1)
@@ -1,10 +1,12 @@
 #include "core/graph_handler.h"
+#include "nnet/Visitor/Serializer.h"
 #include "operators/batch_norm.h"
 #include "operators/concat.h"
 #include "operators/conv.h"
 #include "operators/element_wise.h"
 #include "operators/gather.h"
 #include "operators/matmul.h"
+#include "operators/membound.h"
 #include "operators/pad.h"
 #include "operators/pooling.h"
 #include "operators/reduce_mean.h"
@@ -291,6 +293,22 @@ Tensor GraphHandlerObj::pad(Tensor input, Tensor output,
     }
 }
 
+TensorVec GraphHandlerObj::memBound(const TensorVec &inputs,
+                                    const Tensor &output,
+                                    const string &jsonString) {
+    const auto &[expr, nnetInputs, execTime, hint] =
+        nnet::Serializer().membundOpFromString(jsonString);
+    if (output) {
+        g->addOpWithOutputs<MemBoundObj>(std::move(inputs), TensorVec{output},
+                                         nnetInputs, expr, execTime, hint);
+        return {output};
+    } else
+        return g
+            ->addOp<MemBoundObj>(std::move(inputs), TensorVec{nullptr},
+                                 nnetInputs, expr, execTime, hint)
+            ->getOutputs();
+}
+
 static DataType dtype_repr_convert(int dtype) {
     switch ((OnnxDType)dtype) {
     case OnnxDType::FLOAT:
@@ -1,7 +1,6 @@
 #include "core/graph_handler.h"
 #include "core/mutator.h"
 #include "core/search_engine.h"
-#include "nnet/Visitor/Serializer.h"
 #include "nnet/nmutator.h"
 #include "operators/batch_norm.h"
 #include "operators/concat.h"
@@ -232,9 +231,7 @@ static vector<int> transpose_permute_of(Operator op) {
 }
 
 static string membound_expr_of(Operator op) {
-    IT_ASSERT(op->getOpType() == OpType::MemBound);
-    return *nnet::Serializer().toString(
-        dynamic_cast<const MemBoundObj *>(op.get())->getNnetExpr());
+    return as<MemBoundObj>(op)->toJson();
 }
 
 void export_functions(py::module &m) {
@@ -276,12 +273,16 @@ void init_graph_builder(py::module &m) {
     py::class_<Object, Ref<Object>>(m, "_Object")
         .def("__str__", &Object::toString)
         .def("guid", &Object::getGuid);
-    py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
+    py::class_<RuntimeObj, Ref<RuntimeObj>>(m, "Runtime")
+        .def("run", &RuntimeObj::run, "graph"_a, "tune"_a = false,
+             "profiling"_a = false)
+        .def("getPerfTime", &RuntimeObj::getPerfTime)
+        .def("timeNonCtcOperators", &RuntimeObj::timeNonCtcOperators);
     py::class_<NativeCpuRuntimeObj, std::shared_ptr<NativeCpuRuntimeObj>,
                RuntimeObj>(m, "CpuRuntime");
 #ifdef USE_CUDA
-    py::class_<CudaRuntimeObj, std::shared_ptr<CudaRuntimeObj>, RuntimeObj>(
-        m, "CudaRuntime");
+    py::class_<CudaRuntimeObj, Ref<CudaRuntimeObj>, RuntimeObj>(m,
+                                                                "CudaRuntime");
 #endif
 #ifdef USE_BANG
     py::class_<BangRuntimeObj, std::shared_ptr<BangRuntimeObj>, RuntimeObj>(
@@ -342,19 +343,20 @@ void init_graph_builder(py::module &m) {
         .def("reduce_mean", &Handler::reduceMean, policy::move)
         .def("slice", &Handler::slice, policy::move)
         .def("pad", &Handler::pad, policy::move)
+        .def("memBound", &Handler::memBound, policy::move)
         .def("topo_sort", &Handler::topo_sort, policy::automatic)
         .def("optimize", &Handler::optimize, policy::automatic)
         .def("operators", &Handler::operators, policy::move)
         .def("data_malloc", &Handler::data_malloc, policy::automatic)
-        .def("run", &Handler::run, policy::automatic);
+        .def("run", &Handler::run, policy::automatic)
+        .def("getGraph", &Handler::getGraph);
     py::class_<Mutator, Ref<Mutator>>(m, "Mutator").def("run", &Mutator::run);
     py::enum_<NMutator::Mode>(m, "NMutatorMode")
         .value("RuleBased", NMutator::Mode::RuleBased);
     py::class_<NMutator, Ref<NMutator>, Mutator>(m, "NMutator")
         .def(py::init<NMutator::Mode>())
         .def(py::init<NMutator::Mode, vector<int>>())
-        .def("run", &NMutator::run)
-        .def_static("memboundToJson", &NMutator::memboundToJson);
+        .def("run", &NMutator::run);
     py::class_<SearchEngine>(m, "SearchEngine")
         .def(py::init<Runtime, Ref<Mutator>>())
        .def("run", &SearchEngine::run);
@@ -111,7 +111,7 @@ std::optional<std::string> Serializer::toString(const Expr &expr,
 bool Serializer::toFile(const Expr &expr, const string &filePath,
                         const string &msg, vector<Tensor> inputs,
                         double exec_time, string hint) {
-    if (auto s = toString(expr, msg, inputs, exec_time, hint); s) {
+    if (auto s = toString(expr, msg, inputs, exec_time, hint)) {
         // Write to file
         std::ofstream fout(filePath);
         fout << *s;
@@ -300,4 +300,14 @@ Serializer::deserializeAsMemobundOp(const string &filePath) {
     return {buildExprTree("0"), inputs, j["exec_time"], j["hint"]};
 }
 
+tuple<Expr, vector<Tensor>, double, string>
+Serializer::membundOpFromString(const string &data) {
+    j = json::parse(data);
+    assert(j["Version"] == VERSION);
+    vector<Tensor> inputs;
+    for (const auto &input : j["nnetInputs"])
+        inputs.emplace_back(as<TensorNode>(buildExprTree(input)));
+    return {buildExprTree("0"), inputs, j["exec_time"], j["hint"]};
+}
+
 } // namespace nnet
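A hedged sketch of the JSON layout this reader implies. Only the keys it actually touches ("Version", "nnetInputs", "exec_time", "hint", and an expression tree looked up by node id starting at "0") come from the code above; the per-node contents are assumptions.

# Assumed shape of the string handed to membundOpFromString / GraphHandlerObj::memBound.
membound_json = {
    "Version": "...",          # must match the Serializer's VERSION constant
    "exec_time": 0.123,        # estimated kernel time, forwarded to MemBoundObj
    "hint": "...",             # free-form hint string
    "nnetInputs": ["1", "2"],  # node ids that deserialize into nnet TensorNodes
    "0": {"...": "..."},       # root of the serialized expression tree
    # "1", "2", ...: remaining nodes of the expression tree, keyed by id
}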
@@ -754,15 +754,6 @@ NMutator::generateUnaryExpr(const Operator &op) {
             NameNToTensorT{{"T", op->getInputs()[0]}}};
 }
 
-void NMutator::memboundToJson(const Graph &g, const string path) {
-    for (auto &_op : g->getOperators()) {
-        if (auto op = as<MemBoundObj>(_op)) {
-            op->saveAsJson(path + "/" + "membound_" +
-                           std::to_string(op->getGuid()) + ".json");
-        }
-    }
-}
-
 pair<nnet::Expr, vector<nnet::Tensor>> NMutator::generateRevert(Tensor in) {
     using namespace nnet;
     using infini::make_ref;
@@ -91,10 +91,9 @@ bool MemBoundObj::checkOOB(nnet::Expr expr) {
                               nnet::as<nnet::RangeOpNode>(expr));
 }
 
-void MemBoundObj::saveAsJson(string path) const {
-    bool status = nnet::Serializer().toFile(
-        expr, path, "MemBoundObj::saveAsJson", nnetInputs, exec_time, hint);
-    IT_ASSERT(status);
+string MemBoundObj::toJson() const {
+    return *nnet::Serializer().toString(expr, "MemBoundObj::toJson", nnetInputs,
+                                        exec_time, hint);
 }
 
 } // namespace infini
@@ -88,9 +88,18 @@ def run_InfoGAN_without_tuning(tuning: bool):
         f.write(stub.to_onnx("optimized").SerializeToString())
 
 
+def load_onnx_and_run():
+    runtime = ft.cuda_runtime()
+    stub = OnnxStub.from_onnx(onnx.load("optimized.onnx"), runtime, False)
+    g = stub.handler.getGraph()
+    runtime.run(g, True)
+    print(f'Non-ctc time = {runtime.timeNonCtcOperators(g, 1000, 1000)}')
+
+
 if __name__ == "__main__":
     # run_e2e_InfoGAN()
-    run_InfoGAN_without_tuning(True)
     # runSingleConvT()
     # read_and_check()
 
+    # run_InfoGAN_without_tuning(False)
+    load_onnx_and_run()
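Taken together, the new load_onnx_and_run path exercises the whole import route added by this commit: it reloads the previously exported optimized.onnx with ONNX shape inference disabled (the third argument), pulls the underlying graph out through the new getGraph binding, and times it with the freshly exposed Runtime.run and timeNonCtcOperators bindings.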
@@ -103,9 +103,9 @@ TEST(Serializer, Serialization_memboundOp) {
     string hint = "test";
     infini::MemBoundObj memboundOp(nullptr, {AT, BT}, {CT}, nnetInputs, expr,
                                    execTime, hint);
-    memboundOp.saveAsJson("./test_serializer.json");
+    auto str = memboundOp.toJson();
     auto [exprLoaded, nnetInputsLoaded, execTimeLoaded, hintLoaded] =
-        Serializer().deserializeAsMemobundOp("./test_serializer.json");
+        Serializer().membundOpFromString(str);
     EXPECT_EQ(expr->toReadable(), exprLoaded->toReadable());
     EXPECT_EQ(execTime, execTimeLoaded);
     EXPECT_EQ(nnetInputs.size(), nnetInputsLoaded.size());