diff --git a/include/core/graph.h b/include/core/graph.h index d8126a84..1a3261f8 100644 --- a/include/core/graph.h +++ b/include/core/graph.h @@ -14,6 +14,7 @@ class GraphObj : public Object { public: GraphObj(Runtime runtime) : runtime(runtime){}; + GraphObj(Runtime runtime, OpVec ops_in); string toString() const override; Tensor addTensor(Shape dim, DataType dtype = DataType::Float32); @@ -46,8 +47,14 @@ class GraphObj : public Object { } const TensorVec &getTensors() const { return tensors; } - const TensorVec &getInputs() const { return inputs; } - const TensorVec &getOutputs() const { return outputs; } + const TensorVec &getInputs() const { + IT_TODO_HALT(); + return inputs; + } + const TensorVec &getOutputs() const { + IT_TODO_HALT(); + return outputs; + } const OpVec &getOperators() const { return ops; } OpVec getComputeOps() const; // TensorVec &getInputs(); diff --git a/include/core/operator.h b/include/core/operator.h index f5efc8e7..4e1885ca 100644 --- a/include/core/operator.h +++ b/include/core/operator.h @@ -197,6 +197,9 @@ class OperatorObj : public Object { virtual int numInputs() const = 0; virtual int numOutputs() const = 0; + Operator cloneAndResetConnections(const TensorVec &newInputs, + const TensorVec &newOutputs); + protected: optional> inferShape() const; vector inferDataType() const; @@ -213,8 +216,17 @@ class OperatorObj : public Object { * and output shapes. */ virtual vector getWorkloadVector() const { IT_TODO_HALT(); } + virtual Operator clone() const { + IT_TODO_HALT(); + return nullptr; + } }; +#define OP_CLONE(OpObj) \ + virtual Operator clone() const override { \ + return infini::make_ref(*this); \ + } + } // namespace infini namespace std { diff --git a/include/core/tensor.h b/include/core/tensor.h index 7b4eb826..a9edb38b 100644 --- a/include/core/tensor.h +++ b/include/core/tensor.h @@ -28,11 +28,6 @@ class TensorObj : public TensorBaseObj { VType getData(const Shape &pos) const; void dataMalloc(); GuidBaseType getFuid() const { return fuid; } - Tensor clone() const { - auto ret = make_ref(*this); - ret->freeData(); - return ret; - } void load(std::string file_path); void save(std::string file_path); @@ -59,7 +54,15 @@ class TensorObj : public TensorBaseObj { } generator(data->getPtr(), size(), dtype); } - Tensor clone(Runtime runtime) { + Tensor clone() const { + auto obj = make_ref(*this); + obj->freeData(); + obj->inputOf.clear(); + obj->outputOf.reset(); + return obj; + } + Tensor clone(Runtime runtime) const { + // TODO: use copy constructor auto obj = make_ref(shape, dtype, runtime); obj->dataMalloc(); obj->copyData(this); diff --git a/src/core/graph.cc b/src/core/graph.cc index b45c0839..42c4c20b 100644 --- a/src/core/graph.cc +++ b/src/core/graph.cc @@ -1,7 +1,32 @@ #include "core/graph.h" +#include namespace infini { +GraphObj::GraphObj(Runtime runtime, OpVec ops_in) : runtime(runtime) { + map tensorPool; + // Clone tensors + for (const auto &op : ops_in) { + for (const auto &t : op->getInputs()) + if (tensorPool.find(t->getFuid()) == tensorPool.end()) + tensorPool[t->getFuid()] = t->clone(); + for (const auto &t : op->getOutputs()) + if (tensorPool.find(t->getFuid()) == tensorPool.end()) + tensorPool[t->getFuid()] = t->clone(); + } + for (const auto &[_, t] : tensorPool) + addTensor(t); + // Clone operators and add connections + for (const auto &op : ops_in) { + TensorVec inputs, outputs; + for (const auto &t : op->getInputs()) + inputs.emplace_back(tensorPool.at(t->getFuid())); + for (const auto &t : op->getOutputs()) + outputs.emplace_back(tensorPool.at(t->getFuid())); + addOperatorAndConnect(op->cloneAndResetConnections(inputs, outputs)); + } +} + void GraphObj::addOperatorAndConnect(const Operator &op) { ops.push_back(op); for (auto &input : op->getInputs()) { diff --git a/src/core/operator.cc b/src/core/operator.cc index b8e69af8..d7ab78e2 100644 --- a/src/core/operator.cc +++ b/src/core/operator.cc @@ -93,4 +93,15 @@ vector OperatorObj::inferDataType() const { return inferDataType(inputs); } +Operator OperatorObj::cloneAndResetConnections(const TensorVec &newInputs, + const TensorVec &newOutputs) { + Operator op = clone(); + op->inputs = newInputs; + op->outputs = newOutputs; + op->predecessors.clear(); + op->successors.clear(); + IT_ASSERT(op->checkValid(nullptr)); + return op; +} + } // namespace infini diff --git a/test/core/test_graph.cc b/test/core/test_graph.cc index 516adfda..d208c21f 100644 --- a/test/core/test_graph.cc +++ b/test/core/test_graph.cc @@ -2,6 +2,7 @@ #include "core/graph.h" #include "core/runtime.h" #include "operators/matmul.h" +#include "operators/unary.h" #include "test.h" namespace infini { @@ -75,4 +76,37 @@ TEST(Graph, test_tensor_id) { EXPECT_EQ(i1->getDataBlob(), nullptr); } +TEST(Graph, test_OpVec_ctor) { + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph g = make_ref(runtime); + Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32); + Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32); + Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32); + g->dataMalloc(); + i0->copyData(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + w0->copyData(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto o1 = g->addTensor(o0->clone()); + auto matmul = g->addOpWithOutputs(i0, w0, o0); + g->addOp(o1, nullptr); + g->print(); + puts("========="); + OpVec ops = g->getOperators(); + Graph g2 = make_ref(runtime, ops); + g2->print(); + // Check if the two tensors with the same FUID (o0,o1) remain only one in g2 + EXPECT_EQ(g2->getTensors().size(), 4u); + EXPECT_EQ(g2->getOperators().size(), 2u); + map, int> inputOutput2Cnt = { + {{1, 0}, 2}, {{1, 1}, 1}, {{0, 1}, 1}}; + for (auto t : g2->getTensors()) { + pair key = {t->getInputOf().size(), + t->getOutputOf() != nullptr}; + EXPECT_GE(inputOutput2Cnt[key], 0); + inputOutput2Cnt[key]--; + } + for (auto [u, v] : inputOutput2Cnt) { + EXPECT_EQ(v, 0); + } +} + } // namespace infini