diff --git a/include/core/graph.h b/include/core/graph.h index 4ce8697a..eaa6f4a5 100644 --- a/include/core/graph.h +++ b/include/core/graph.h @@ -8,13 +8,10 @@ class GraphObj : public Object { protected: Runtime runtime; TensorVec tensors; - // TODO: whether to record input and output tensors - // TensorVec inputs; - // TensorVec outputs; OpVec ops; public: - GraphObj(Runtime runtime) : runtime(runtime){}; + explicit GraphObj(Runtime runtime) : runtime(runtime){}; GraphObj(Runtime runtime, OpVec ops_in); string toString() const override; Runtime getRuntime() const { return runtime; } @@ -23,10 +20,15 @@ class GraphObj : public Object { Tensor addTensor(const Tensor &tensor); TensorVec addTensor(const TensorVec &tensors); Tensor cloneTensor(const Tensor &tensor) { - auto ret = addTensor(tensor->clone(runtime)); - return ret; + return addTensor(tensor->clone(runtime)); } + const TensorVec &getTensors() const { return tensors; } + const OpVec &getOperators() const { return ops; } + OpVec getComputeOps() const; + + void dataMalloc(); + /** * @brief Add an operator and create its outputs. Output tensor arguments * should be empty Refs (e.g., nullptr). @@ -47,35 +49,33 @@ class GraphObj : public Object { return op; } - const TensorVec &getTensors() const { return tensors; } - const TensorVec getInputs() const { + /** + * @brief Gets input tensors of this graph. + */ + inline TensorVec getInputs() const { TensorVec ret; - for (auto t : tensors) + for (const auto &t : tensors) if (!t->getOutputOf()) ret.emplace_back(t); return ret; } - const TensorVec getOutputs() const { + + /** + * @brief Gets output tensors of this graph. + */ + inline TensorVec getOutputs() const { TensorVec ret; - for (auto t : tensors) + for (const auto &t : tensors) if (t->getInputOf().empty()) ret.emplace_back(t); return ret; } - const OpVec &getOperators() const { return ops; } - OpVec getComputeOps() const; - - void dataMalloc(); private: /** * @brief Add reverse connections and Op relationship in ctor. */ void addOperatorAndConnect(const Operator &op); - - // TODO: move to another class - // bool exportOnnx(const char *path); - // bool importOnnx(const char *net); }; } // namespace infini diff --git a/src/core/graph.cc b/src/core/graph.cc index a9edda64..8b2a6bbc 100644 --- a/src/core/graph.cc +++ b/src/core/graph.cc @@ -73,15 +73,12 @@ void GraphObj::dataMalloc() { } Tensor GraphObj::addTensor(Shape dim, DataType dtype) { - Tensor tensor = make_ref(dim, dtype, runtime); - tensors.emplace_back(tensor); - return tensor; + return tensors.emplace_back(make_ref(dim, dtype, runtime)); } Tensor GraphObj::addTensor(const Tensor &tensor) { IT_ASSERT(tensor->getRuntime() == runtime, "Tensor runtime mismatch"); - tensors.emplace_back(tensor); - return tensor; + return tensors.emplace_back(tensor); } TensorVec GraphObj::addTensor(const TensorVec &tensors) { @@ -98,4 +95,4 @@ OpVec GraphObj::getComputeOps() const { return opList; }; -} // namespace infini \ No newline at end of file +} // namespace infini