From 537b3b4ea444cba4e680fd167d2ad1c246c35c3d Mon Sep 17 00:00:00 2001 From: Liyan Zheng Date: Tue, 18 Apr 2023 21:53:48 +0800 Subject: [PATCH] Add: Membound operator serialization --- include/nnet/Visitor/Serializer.h | 9 +++++- include/operators/membound.h | 9 ++++-- src/nnet/Visitor/Serializer.cc | 26 ++++++++++++++- src/operators/membound.cc | 18 ++++++++--- test/nnet/test_serializer.cc | 54 ++++++++++++++++++++----------- 5 files changed, 87 insertions(+), 29 deletions(-) diff --git a/include/nnet/Visitor/Serializer.h b/include/nnet/Visitor/Serializer.h index b2c3ff8a..dc0efe17 100644 --- a/include/nnet/Visitor/Serializer.h +++ b/include/nnet/Visitor/Serializer.h @@ -36,10 +36,14 @@ class Serializer : public Functor { * @param expr The expression to be serialized * @param filePath The path of json file to be output * @param msg Message of derivation + * @param inputs membound operator attributes + * @param exec_time membound operator attributes + * @param hint membound operator attributes * @return bool Whether the serialization succeed */ bool serialize(const Expr &expr, const string &filePath, - const string &msg = ""); + const string &msg = "", vector inputs = {}, + double exec_time = -1e9, string hint = ""); /** * @brief Deserialize the given json file to expression @@ -48,6 +52,9 @@ class Serializer : public Functor { * @return Expression deserialized from the given json file */ Expr deserialize(const string &filePath); + + tuple, double, string> + deserializeAsMemobundOp(const string &filePath); }; } // namespace nnet \ No newline at end of file diff --git a/include/operators/membound.h b/include/operators/membound.h index df42e5b2..fc289bdc 100644 --- a/include/operators/membound.h +++ b/include/operators/membound.h @@ -6,14 +6,17 @@ namespace infini { class MemBoundObj : public OperatorObj { private: + nnet::Expr expr; std::vector nnetInputs; // The order of inputs in nnetInputs should be consistant // with inputs in infinitensor - nnet::Expr expr, simplifiedExpr; double exec_time; std::string hint; - HashType hash, simplifiedHash; - int n, f, h, w; + + // Generated attributes + HashType hash; + nnet::Expr simplifiedExpr; + HashType simplifiedHash; public: MemBoundObj(GraphObj *graph, const TensorVec &input, diff --git a/src/nnet/Visitor/Serializer.cc b/src/nnet/Visitor/Serializer.cc index 423487e7..670f5a8c 100644 --- a/src/nnet/Visitor/Serializer.cc +++ b/src/nnet/Visitor/Serializer.cc @@ -82,13 +82,26 @@ string Serializer::visit_(const Func &c) { } bool Serializer::serialize(const Expr &expr, const string &filePath, - const string &msg) { + const string &msg, vector inputs, + double exec_time, string hint) { // Metadata j["Version"] = VERSION; j["Msg"] = msg; + j["exec_time"] = exec_time; + j["hint"] = hint; // Expressions and routines id = 0; dispatch(expr); + + // Input tensors + vector inputsIndices; + for (const auto &tensor : inputs) { + inputsIndices.emplace_back(std::to_string(id)); + dispatch(tensor); + } + j["nnetInputs"] = inputsIndices; + + // Write to file std::ofstream fout(filePath); fout << std::setw(4) << j << std::endl; return true; @@ -254,4 +267,15 @@ Routine Serializer::buildRoutine(string key) { return nullptr; } +tuple, double, string> +Serializer::deserializeAsMemobundOp(const string &filePath) { + std::ifstream fin(filePath); + fin >> j; + assert(j["Version"] == VERSION); + vector inputs; + for (const auto &input : j["nnetInputs"]) + inputs.emplace_back(as(buildExprTree(input))); + return {buildExprTree("0"), inputs, j["exec_time"], j["hint"]}; +} + } // namespace nnet \ No newline at end of file diff --git a/src/operators/membound.cc b/src/operators/membound.cc index 3afa9134..6a348879 100644 --- a/src/operators/membound.cc +++ b/src/operators/membound.cc @@ -10,8 +10,8 @@ MemBoundObj::MemBoundObj(GraphObj *graph, const TensorVec &input, const TensorVec &output, const std::vector &nnetInputs, nnet::Expr expr, double exec_time, std::string hint) - : OperatorObj(OpType::MemBound, input, output), nnetInputs(nnetInputs), - expr(expr), exec_time(exec_time), hint(hint) { + : OperatorObj(OpType::MemBound, input, output), expr(expr), + nnetInputs(nnetInputs), exec_time(exec_time), hint(hint) { IT_ASSERT(checkValid(graph)); IT_ASSERT(!checkOOB(expr)); hash = calcHash(expr); @@ -61,11 +61,18 @@ string MemBoundObj::toString() const { optional> MemBoundObj::inferShape(const TensorVec &inputs) const { // inputs have to match nnetInputs excatly - if (inputs.size() != nnetInputs.size()) + if (inputs.size() != nnetInputs.size()) { + std::cout << "Num mismatch" << inputs.size() << " " + << nnetInputs.size(); return {}; + } for (size_t i = 0; i < inputs.size(); ++i) - if (inputs[i]->getDims() != nnetInputs[i]->getShape()) + if (inputs[i]->getDims() != nnetInputs[i]->getShape()) { + std::cout << "Shape mismatch " << inputs[i] + << vecToString(inputs[i]->getDims()) << " " + << vecToString(nnetInputs[i]->getShape()); return {}; + } return {{nnet::as(expr)->getOutputShape()}}; } @@ -85,7 +92,8 @@ bool MemBoundObj::checkOOB(nnet::Expr expr) { } void MemBoundObj::saveAsJson(string path) const { - bool status = nnet::Serializer().serialize(expr, path); + bool status = nnet::Serializer().serialize( + expr, path, "MemBoundObj::saveAsJson", nnetInputs, exec_time, hint); IT_ASSERT(status); } diff --git a/test/nnet/test_serializer.cc b/test/nnet/test_serializer.cc index 0ec3982a..7363778b 100644 --- a/test/nnet/test_serializer.cc +++ b/test/nnet/test_serializer.cc @@ -1,9 +1,11 @@ +#include "core/graph.h" #include "nnet/Visitor/FullPrinterVisitor.h" #include "nnet/Visitor/Serializer.h" +#include "nnet/test.h" +#include "operators/membound.h" #include "gtest/gtest.h" using namespace nnet; using namespace std; -#define DEFINE_VAR(name) auto name = make_ref(#name); //{LSum //{({A}[b, (i3 + (2500 * i4)), k] * {B}[b, ((i3 + (2500 * i4)) + @@ -12,11 +14,7 @@ using namespace std; // ==> B : Input Tensor shape=[8,10000,512] pad=[0,128,0] Expr buildSimpleExpr() { - DEFINE_VAR(b); - DEFINE_VAR(w); - DEFINE_VAR(k); - DEFINE_VAR(i3); - DEFINE_VAR(i4); + DEFINE_VAR(b, w, k, i3, i4); auto A = makeTensor("A", {8, 10000, 512}, {0, 0, 0}); auto B = makeTensor("B", {8, 10000, 512}, {0, 128, 0}); auto subA = makeSubscript(A, {b, (i3 + (2500 * i4)), k}); @@ -28,9 +26,7 @@ Expr buildSimpleExpr() { } Expr buildNestedExpr() { - DEFINE_VAR(j1); - DEFINE_VAR(j2); - DEFINE_VAR(j3); + DEFINE_VAR(j1, j2, j3); // Build a Matmul to verify. const int M = 10000, N = 512, K = 3; auto C = make_ref("C", vector({M, K})); @@ -46,11 +42,7 @@ Expr buildNestedExpr() { auto E = make_ref("E", shapeE, shapeE, ele2); auto ele1 = make_ref(expr, vector{E}, shapeE); - DEFINE_VAR(b); - DEFINE_VAR(w); - DEFINE_VAR(k); - DEFINE_VAR(i3); - DEFINE_VAR(i4); + DEFINE_VAR(b, w, k, i3, i4); auto A = makeTensor("A", {8, 10000, 512}, {0, 0, 0}, matmul); auto B = makeTensor("B", {8, 10000, 512}, {0, 128, 0}, ele1); auto subA = makeSubscript(A, {b, (i3 + (2500 * i4)), k}); @@ -68,11 +60,7 @@ TEST(Serializer, Serialization) { } TEST(Serializer, CompareTwoExprs) { - DEFINE_VAR(b); - DEFINE_VAR(w); - DEFINE_VAR(k); - DEFINE_VAR(i3); - DEFINE_VAR(i4); + DEFINE_VAR(b, w, k, i3, i4); auto A = makeTensor("A", {8, 10000, 512}, {0, 0, 0}); auto B = makeTensor("B", {8, 10000, 512}, {0, 128, 0}); auto subA = makeSubscript(A, {b, (i3 + (2500 * i4)), k}); @@ -97,4 +85,32 @@ TEST(Serializer, Serialization_NestedTensor) { auto exprDeserialized = Serializer().deserialize("./test_serializer.json"); auto output = printer.print(exprDeserialized); EXPECT_EQ(output, ans); +} + +TEST(Serializer, Serialization_memboundOp) { + auto expr = buildSimpleExpr(); + auto A = makeTensor("A", {8, 10000, 512}, {0, 0, 0}); + auto B = makeTensor("B", {8, 10000, 512}, {0, 128, 0}); + // using namespace infini; + auto runtime = infini::NativeCpuRuntimeObj::getInstance(); + auto g = infini::make_ref(runtime); + auto AT = g->addTensor({8, 10000, 512}); + auto BT = g->addTensor({8, 10000, 512}); + auto CT = g->addTensor({2500, 4, 8, 65}); + + vector nnetInputs{A, B}; + double execTime = 1; + string hint = "test"; + infini::MemBoundObj memboundOp(nullptr, {AT, BT}, {CT}, nnetInputs, expr, + execTime, hint); + memboundOp.saveAsJson("./test_serializer.json"); + auto [exprLoaded, nnetInputsLoaded, execTimeLoaded, hintLoaded] = + Serializer().deserializeAsMemobundOp("./test_serializer.json"); + EXPECT_EQ(expr->toReadable(), exprLoaded->toReadable()); + EXPECT_EQ(execTime, execTimeLoaded); + EXPECT_EQ(nnetInputs.size(), nnetInputsLoaded.size()); + for (size_t i = 0; i < nnetInputs.size(); ++i) + EXPECT_EQ(nnetInputs[i]->toReadable(), + nnetInputsLoaded[i]->toReadable()); + EXPECT_EQ(hint, hintLoaded); } \ No newline at end of file