style: 修改 graph.h/graph.cc

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-16 14:58:04 +08:00
parent 0f52d04882
commit f20e791cf5
2 changed files with 22 additions and 25 deletions

View File

@ -8,13 +8,10 @@ class GraphObj : public Object {
protected:
Runtime runtime;
TensorVec tensors;
// TODO: whether to record input and output tensors
// TensorVec inputs;
// TensorVec outputs;
OpVec ops;
public:
GraphObj(Runtime runtime) : runtime(runtime){};
explicit GraphObj(Runtime runtime) : runtime(runtime){};
GraphObj(Runtime runtime, OpVec ops_in);
string toString() const override;
Runtime getRuntime() const { return runtime; }
@ -23,10 +20,15 @@ class GraphObj : public Object {
Tensor addTensor(const Tensor &tensor);
TensorVec addTensor(const TensorVec &tensors);
Tensor cloneTensor(const Tensor &tensor) {
auto ret = addTensor(tensor->clone(runtime));
return ret;
return addTensor(tensor->clone(runtime));
}
const TensorVec &getTensors() const { return tensors; }
const OpVec &getOperators() const { return ops; }
OpVec getComputeOps() const;
void dataMalloc();
/**
* @brief Add an operator and create its outputs. Output tensor arguments
* should be empty Refs (e.g., nullptr).
@ -47,35 +49,33 @@ class GraphObj : public Object {
return op;
}
const TensorVec &getTensors() const { return tensors; }
const TensorVec getInputs() const {
/**
* @brief Gets input tensors of this graph.
*/
inline TensorVec getInputs() const {
TensorVec ret;
for (auto t : tensors)
for (const auto &t : tensors)
if (!t->getOutputOf())
ret.emplace_back(t);
return ret;
}
const TensorVec getOutputs() const {
/**
* @brief Gets output tensors of this graph.
*/
inline TensorVec getOutputs() const {
TensorVec ret;
for (auto t : tensors)
for (const auto &t : tensors)
if (t->getInputOf().empty())
ret.emplace_back(t);
return ret;
}
const OpVec &getOperators() const { return ops; }
OpVec getComputeOps() const;
void dataMalloc();
private:
/**
* @brief Add reverse connections and Op relationship in ctor.
*/
void addOperatorAndConnect(const Operator &op);
// TODO: move to another class
// bool exportOnnx(const char *path);
// bool importOnnx(const char *net);
};
} // namespace infini

View File

@ -73,15 +73,12 @@ void GraphObj::dataMalloc() {
}
Tensor GraphObj::addTensor(Shape dim, DataType dtype) {
Tensor tensor = make_ref<TensorObj>(dim, dtype, runtime);
tensors.emplace_back(tensor);
return tensor;
return tensors.emplace_back(make_ref<TensorObj>(dim, dtype, runtime));
}
Tensor GraphObj::addTensor(const Tensor &tensor) {
IT_ASSERT(tensor->getRuntime() == runtime, "Tensor runtime mismatch");
tensors.emplace_back(tensor);
return tensor;
return tensors.emplace_back(tensor);
}
TensorVec GraphObj::addTensor(const TensorVec &tensors) {
@ -98,4 +95,4 @@ OpVec GraphObj::getComputeOps() const {
return opList;
};
} // namespace infini
} // namespace infini