Add: nnet::Serializer supports FuncNode

This commit is contained in:
Liyan Zheng 2023-04-17 20:15:40 +08:00
parent 9d50b30af8
commit 99b5c95455
2 changed files with 15 additions and 3 deletions

View File

@ -73,6 +73,14 @@ string Serializer::visit_(const Tensor &c) {
return key;
}
string Serializer::visit_(const Func &c) {
const string key = std::to_string(id++);
j[key]["type"] = c->getType();
j[key]["funcType"] = c->getFuncType();
j[key]["object"] = dispatch(c->getObject());
return key;
}
bool Serializer::serialize(const Expr &expr, const string &filePath,
const string &msg) {
// Metadata
@ -180,6 +188,10 @@ Expr Serializer::buildExprTree(string key) {
return make_ref<TensorNode>(j[key]["name"], j[key]["shape"],
j[key]["paddings"], source);
}
case NodeType::FuncNodeType: {
auto object = buildExprTree(j[key]["object"]);
return make_ref<FuncNode>(object, j[key]["funcType"]);
}
default: {
nnet_unimplemented_halt();
break;

View File

@ -76,12 +76,14 @@ TEST(Serializer, CompareTwoExprs) {
auto A = makeTensor("A", {8, 10000, 512}, {0, 0, 0});
auto B = makeTensor("B", {8, 10000, 512}, {0, 128, 0});
auto subA = makeSubscript(A, {b, (i3 + (2500 * i4)), k});
auto funcA = make_ref<FuncNode>(subA, FuncType::Relu);
auto subB = makeSubscript(B, {b, ((i3 + (2500 * i4)) + w), k});
auto range = makeRangeOperator(
{{i3, {0, 2500}}, {i4, {0, 4}}, {b, {0, 8}}, {w, {0, 65}}},
{{k, {0, 512}}}, subA * subB);
{{k, {0, 512}}}, funcA * subB);
Serializer().serialize(range, "./test_serializer.json");
auto expr = Serializer().deserialize("./test_serializer.json");
dbg(expr);
EXPECT_EQ(range->toReadable(), expr->toReadable());
}
@ -90,11 +92,9 @@ TEST(Serializer, Serialization_NestedTensor) {
FullPrinterVisitor printer;
auto range = buildNestedExpr();
auto ans = printer.print(range);
dbg(ans);
auto isSuccessful = Serializer().serialize(range, "./test_serializer.json");
EXPECT_TRUE(isSuccessful);
auto exprDeserialized = Serializer().deserialize("./test_serializer.json");
auto output = printer.print(exprDeserialized);
dbg(output);
EXPECT_EQ(output, ans);
}