forked from jiuyuan/InfiniTensor
Add: nnet::Serializer supports FuncNode
This commit is contained in:
parent
9d50b30af8
commit
99b5c95455
|
@ -73,6 +73,14 @@ string Serializer::visit_(const Tensor &c) {
|
|||
return key;
|
||||
}
|
||||
|
||||
string Serializer::visit_(const Func &c) {
|
||||
const string key = std::to_string(id++);
|
||||
j[key]["type"] = c->getType();
|
||||
j[key]["funcType"] = c->getFuncType();
|
||||
j[key]["object"] = dispatch(c->getObject());
|
||||
return key;
|
||||
}
|
||||
|
||||
bool Serializer::serialize(const Expr &expr, const string &filePath,
|
||||
const string &msg) {
|
||||
// Metadata
|
||||
|
@ -180,6 +188,10 @@ Expr Serializer::buildExprTree(string key) {
|
|||
return make_ref<TensorNode>(j[key]["name"], j[key]["shape"],
|
||||
j[key]["paddings"], source);
|
||||
}
|
||||
case NodeType::FuncNodeType: {
|
||||
auto object = buildExprTree(j[key]["object"]);
|
||||
return make_ref<FuncNode>(object, j[key]["funcType"]);
|
||||
}
|
||||
default: {
|
||||
nnet_unimplemented_halt();
|
||||
break;
|
||||
|
|
|
@ -76,12 +76,14 @@ TEST(Serializer, CompareTwoExprs) {
|
|||
auto A = makeTensor("A", {8, 10000, 512}, {0, 0, 0});
|
||||
auto B = makeTensor("B", {8, 10000, 512}, {0, 128, 0});
|
||||
auto subA = makeSubscript(A, {b, (i3 + (2500 * i4)), k});
|
||||
auto funcA = make_ref<FuncNode>(subA, FuncType::Relu);
|
||||
auto subB = makeSubscript(B, {b, ((i3 + (2500 * i4)) + w), k});
|
||||
auto range = makeRangeOperator(
|
||||
{{i3, {0, 2500}}, {i4, {0, 4}}, {b, {0, 8}}, {w, {0, 65}}},
|
||||
{{k, {0, 512}}}, subA * subB);
|
||||
{{k, {0, 512}}}, funcA * subB);
|
||||
Serializer().serialize(range, "./test_serializer.json");
|
||||
auto expr = Serializer().deserialize("./test_serializer.json");
|
||||
dbg(expr);
|
||||
|
||||
EXPECT_EQ(range->toReadable(), expr->toReadable());
|
||||
}
|
||||
|
@ -90,11 +92,9 @@ TEST(Serializer, Serialization_NestedTensor) {
|
|||
FullPrinterVisitor printer;
|
||||
auto range = buildNestedExpr();
|
||||
auto ans = printer.print(range);
|
||||
dbg(ans);
|
||||
auto isSuccessful = Serializer().serialize(range, "./test_serializer.json");
|
||||
EXPECT_TRUE(isSuccessful);
|
||||
auto exprDeserialized = Serializer().deserialize("./test_serializer.json");
|
||||
auto output = printer.print(exprDeserialized);
|
||||
dbg(output);
|
||||
EXPECT_EQ(output, ans);
|
||||
}
|
Loading…
Reference in New Issue