style: 修改 graph.h/graph.cc

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-16 14:58:04 +08:00
parent 0f52d04882
commit f20e791cf5
2 changed files with 22 additions and 25 deletions

View File

@ -8,13 +8,10 @@ class GraphObj : public Object {
protected: protected:
Runtime runtime; Runtime runtime;
TensorVec tensors; TensorVec tensors;
// TODO: whether to record input and output tensors
// TensorVec inputs;
// TensorVec outputs;
OpVec ops; OpVec ops;
public: public:
GraphObj(Runtime runtime) : runtime(runtime){}; explicit GraphObj(Runtime runtime) : runtime(runtime){};
GraphObj(Runtime runtime, OpVec ops_in); GraphObj(Runtime runtime, OpVec ops_in);
string toString() const override; string toString() const override;
Runtime getRuntime() const { return runtime; } Runtime getRuntime() const { return runtime; }
@ -23,10 +20,15 @@ class GraphObj : public Object {
Tensor addTensor(const Tensor &tensor); Tensor addTensor(const Tensor &tensor);
TensorVec addTensor(const TensorVec &tensors); TensorVec addTensor(const TensorVec &tensors);
Tensor cloneTensor(const Tensor &tensor) { Tensor cloneTensor(const Tensor &tensor) {
auto ret = addTensor(tensor->clone(runtime)); return addTensor(tensor->clone(runtime));
return ret;
} }
const TensorVec &getTensors() const { return tensors; }
const OpVec &getOperators() const { return ops; }
OpVec getComputeOps() const;
void dataMalloc();
/** /**
* @brief Add an operator and create its outputs. Output tensor arguments * @brief Add an operator and create its outputs. Output tensor arguments
* should be empty Refs (e.g., nullptr). * should be empty Refs (e.g., nullptr).
@ -47,35 +49,33 @@ class GraphObj : public Object {
return op; return op;
} }
const TensorVec &getTensors() const { return tensors; } /**
const TensorVec getInputs() const { * @brief Gets input tensors of this graph.
*/
inline TensorVec getInputs() const {
TensorVec ret; TensorVec ret;
for (auto t : tensors) for (const auto &t : tensors)
if (!t->getOutputOf()) if (!t->getOutputOf())
ret.emplace_back(t); ret.emplace_back(t);
return ret; return ret;
} }
const TensorVec getOutputs() const {
/**
* @brief Gets output tensors of this graph.
*/
inline TensorVec getOutputs() const {
TensorVec ret; TensorVec ret;
for (auto t : tensors) for (const auto &t : tensors)
if (t->getInputOf().empty()) if (t->getInputOf().empty())
ret.emplace_back(t); ret.emplace_back(t);
return ret; return ret;
} }
const OpVec &getOperators() const { return ops; }
OpVec getComputeOps() const;
void dataMalloc();
private: private:
/** /**
* @brief Add reverse connections and Op relationship in ctor. * @brief Add reverse connections and Op relationship in ctor.
*/ */
void addOperatorAndConnect(const Operator &op); void addOperatorAndConnect(const Operator &op);
// TODO: move to another class
// bool exportOnnx(const char *path);
// bool importOnnx(const char *net);
}; };
} // namespace infini } // namespace infini

View File

@ -73,15 +73,12 @@ void GraphObj::dataMalloc() {
} }
Tensor GraphObj::addTensor(Shape dim, DataType dtype) { Tensor GraphObj::addTensor(Shape dim, DataType dtype) {
Tensor tensor = make_ref<TensorObj>(dim, dtype, runtime); return tensors.emplace_back(make_ref<TensorObj>(dim, dtype, runtime));
tensors.emplace_back(tensor);
return tensor;
} }
Tensor GraphObj::addTensor(const Tensor &tensor) { Tensor GraphObj::addTensor(const Tensor &tensor) {
IT_ASSERT(tensor->getRuntime() == runtime, "Tensor runtime mismatch"); IT_ASSERT(tensor->getRuntime() == runtime, "Tensor runtime mismatch");
tensors.emplace_back(tensor); return tensors.emplace_back(tensor);
return tensor;
} }
TensorVec GraphObj::addTensor(const TensorVec &tensors) { TensorVec GraphObj::addTensor(const TensorVec &tensors) {
@ -98,4 +95,4 @@ OpVec GraphObj::getComputeOps() const {
return opList; return opList;
}; };
} // namespace infini } // namespace infini