forked from jiuyuan/InfiniTensor
style: 修改 graph.h/graph.cc
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
0f52d04882
commit
f20e791cf5
|
@ -8,13 +8,10 @@ class GraphObj : public Object {
|
||||||
protected:
|
protected:
|
||||||
Runtime runtime;
|
Runtime runtime;
|
||||||
TensorVec tensors;
|
TensorVec tensors;
|
||||||
// TODO: whether to record input and output tensors
|
|
||||||
// TensorVec inputs;
|
|
||||||
// TensorVec outputs;
|
|
||||||
OpVec ops;
|
OpVec ops;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
GraphObj(Runtime runtime) : runtime(runtime){};
|
explicit GraphObj(Runtime runtime) : runtime(runtime){};
|
||||||
GraphObj(Runtime runtime, OpVec ops_in);
|
GraphObj(Runtime runtime, OpVec ops_in);
|
||||||
string toString() const override;
|
string toString() const override;
|
||||||
Runtime getRuntime() const { return runtime; }
|
Runtime getRuntime() const { return runtime; }
|
||||||
|
@ -23,10 +20,15 @@ class GraphObj : public Object {
|
||||||
Tensor addTensor(const Tensor &tensor);
|
Tensor addTensor(const Tensor &tensor);
|
||||||
TensorVec addTensor(const TensorVec &tensors);
|
TensorVec addTensor(const TensorVec &tensors);
|
||||||
Tensor cloneTensor(const Tensor &tensor) {
|
Tensor cloneTensor(const Tensor &tensor) {
|
||||||
auto ret = addTensor(tensor->clone(runtime));
|
return addTensor(tensor->clone(runtime));
|
||||||
return ret;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const TensorVec &getTensors() const { return tensors; }
|
||||||
|
const OpVec &getOperators() const { return ops; }
|
||||||
|
OpVec getComputeOps() const;
|
||||||
|
|
||||||
|
void dataMalloc();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Add an operator and create its outputs. Output tensor arguments
|
* @brief Add an operator and create its outputs. Output tensor arguments
|
||||||
* should be empty Refs (e.g., nullptr).
|
* should be empty Refs (e.g., nullptr).
|
||||||
|
@ -47,35 +49,33 @@ class GraphObj : public Object {
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
const TensorVec &getTensors() const { return tensors; }
|
/**
|
||||||
const TensorVec getInputs() const {
|
* @brief Gets input tensors of this graph.
|
||||||
|
*/
|
||||||
|
inline TensorVec getInputs() const {
|
||||||
TensorVec ret;
|
TensorVec ret;
|
||||||
for (auto t : tensors)
|
for (const auto &t : tensors)
|
||||||
if (!t->getOutputOf())
|
if (!t->getOutputOf())
|
||||||
ret.emplace_back(t);
|
ret.emplace_back(t);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
const TensorVec getOutputs() const {
|
|
||||||
|
/**
|
||||||
|
* @brief Gets output tensors of this graph.
|
||||||
|
*/
|
||||||
|
inline TensorVec getOutputs() const {
|
||||||
TensorVec ret;
|
TensorVec ret;
|
||||||
for (auto t : tensors)
|
for (const auto &t : tensors)
|
||||||
if (t->getInputOf().empty())
|
if (t->getInputOf().empty())
|
||||||
ret.emplace_back(t);
|
ret.emplace_back(t);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
const OpVec &getOperators() const { return ops; }
|
|
||||||
OpVec getComputeOps() const;
|
|
||||||
|
|
||||||
void dataMalloc();
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/**
|
/**
|
||||||
* @brief Add reverse connections and Op relationship in ctor.
|
* @brief Add reverse connections and Op relationship in ctor.
|
||||||
*/
|
*/
|
||||||
void addOperatorAndConnect(const Operator &op);
|
void addOperatorAndConnect(const Operator &op);
|
||||||
|
|
||||||
// TODO: move to another class
|
|
||||||
// bool exportOnnx(const char *path);
|
|
||||||
// bool importOnnx(const char *net);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -73,15 +73,12 @@ void GraphObj::dataMalloc() {
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor GraphObj::addTensor(Shape dim, DataType dtype) {
|
Tensor GraphObj::addTensor(Shape dim, DataType dtype) {
|
||||||
Tensor tensor = make_ref<TensorObj>(dim, dtype, runtime);
|
return tensors.emplace_back(make_ref<TensorObj>(dim, dtype, runtime));
|
||||||
tensors.emplace_back(tensor);
|
|
||||||
return tensor;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor GraphObj::addTensor(const Tensor &tensor) {
|
Tensor GraphObj::addTensor(const Tensor &tensor) {
|
||||||
IT_ASSERT(tensor->getRuntime() == runtime, "Tensor runtime mismatch");
|
IT_ASSERT(tensor->getRuntime() == runtime, "Tensor runtime mismatch");
|
||||||
tensors.emplace_back(tensor);
|
return tensors.emplace_back(tensor);
|
||||||
return tensor;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TensorVec GraphObj::addTensor(const TensorVec &tensors) {
|
TensorVec GraphObj::addTensor(const TensorVec &tensors) {
|
||||||
|
@ -98,4 +95,4 @@ OpVec GraphObj::getComputeOps() const {
|
||||||
return opList;
|
return opList;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
Loading…
Reference in New Issue