forked from jiuyuan/InfiniTensor
[Intermediate state] Add: Graph ctor for OpVec
This commit is contained in:
parent
e549f21867
commit
f133f00478
|
@ -14,6 +14,7 @@ class GraphObj : public Object {
|
|||
|
||||
public:
|
||||
GraphObj(Runtime runtime) : runtime(runtime){};
|
||||
GraphObj(Runtime runtime, OpVec ops_in);
|
||||
string toString() const override;
|
||||
|
||||
Tensor addTensor(Shape dim, DataType dtype = DataType::Float32);
|
||||
|
@ -46,8 +47,14 @@ class GraphObj : public Object {
|
|||
}
|
||||
|
||||
const TensorVec &getTensors() const { return tensors; }
|
||||
const TensorVec &getInputs() const { return inputs; }
|
||||
const TensorVec &getOutputs() const { return outputs; }
|
||||
const TensorVec &getInputs() const {
|
||||
IT_TODO_HALT();
|
||||
return inputs;
|
||||
}
|
||||
const TensorVec &getOutputs() const {
|
||||
IT_TODO_HALT();
|
||||
return outputs;
|
||||
}
|
||||
const OpVec &getOperators() const { return ops; }
|
||||
OpVec getComputeOps() const;
|
||||
// TensorVec &getInputs();
|
||||
|
|
|
@ -197,6 +197,9 @@ class OperatorObj : public Object {
|
|||
virtual int numInputs() const = 0;
|
||||
virtual int numOutputs() const = 0;
|
||||
|
||||
Operator cloneAndResetConnections(const TensorVec &newInputs,
|
||||
const TensorVec &newOutputs);
|
||||
|
||||
protected:
|
||||
optional<vector<Shape>> inferShape() const;
|
||||
vector<DataType> inferDataType() const;
|
||||
|
@ -213,8 +216,17 @@ class OperatorObj : public Object {
|
|||
* and output shapes.
|
||||
*/
|
||||
virtual vector<int> getWorkloadVector() const { IT_TODO_HALT(); }
|
||||
virtual Operator clone() const {
|
||||
IT_TODO_HALT();
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
#define OP_CLONE(OpObj) \
|
||||
virtual Operator clone() const override { \
|
||||
return infini::make_ref<OpObj>(*this); \
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
||||
namespace std {
|
||||
|
|
|
@ -28,11 +28,6 @@ class TensorObj : public TensorBaseObj {
|
|||
VType getData(const Shape &pos) const;
|
||||
void dataMalloc();
|
||||
GuidBaseType getFuid() const { return fuid; }
|
||||
Tensor clone() const {
|
||||
auto ret = make_ref<TensorObj>(*this);
|
||||
ret->freeData();
|
||||
return ret;
|
||||
}
|
||||
|
||||
void load(std::string file_path);
|
||||
void save(std::string file_path);
|
||||
|
@ -59,7 +54,15 @@ class TensorObj : public TensorBaseObj {
|
|||
}
|
||||
generator(data->getPtr<void *>(), size(), dtype);
|
||||
}
|
||||
Tensor clone(Runtime runtime) {
|
||||
Tensor clone() const {
|
||||
auto obj = make_ref<TensorObj>(*this);
|
||||
obj->freeData();
|
||||
obj->inputOf.clear();
|
||||
obj->outputOf.reset();
|
||||
return obj;
|
||||
}
|
||||
Tensor clone(Runtime runtime) const {
|
||||
// TODO: use copy constructor
|
||||
auto obj = make_ref<TensorObj>(shape, dtype, runtime);
|
||||
obj->dataMalloc();
|
||||
obj->copyData(this);
|
||||
|
|
|
@ -1,7 +1,32 @@
|
|||
#include "core/graph.h"
|
||||
#include <queue>
|
||||
|
||||
namespace infini {
|
||||
|
||||
GraphObj::GraphObj(Runtime runtime, OpVec ops_in) : runtime(runtime) {
|
||||
map<GuidBaseType, Tensor> tensorPool;
|
||||
// Clone tensors
|
||||
for (const auto &op : ops_in) {
|
||||
for (const auto &t : op->getInputs())
|
||||
if (tensorPool.find(t->getFuid()) == tensorPool.end())
|
||||
tensorPool[t->getFuid()] = t->clone();
|
||||
for (const auto &t : op->getOutputs())
|
||||
if (tensorPool.find(t->getFuid()) == tensorPool.end())
|
||||
tensorPool[t->getFuid()] = t->clone();
|
||||
}
|
||||
for (const auto &[_, t] : tensorPool)
|
||||
addTensor(t);
|
||||
// Clone operators and add connections
|
||||
for (const auto &op : ops_in) {
|
||||
TensorVec inputs, outputs;
|
||||
for (const auto &t : op->getInputs())
|
||||
inputs.emplace_back(tensorPool.at(t->getFuid()));
|
||||
for (const auto &t : op->getOutputs())
|
||||
outputs.emplace_back(tensorPool.at(t->getFuid()));
|
||||
addOperatorAndConnect(op->cloneAndResetConnections(inputs, outputs));
|
||||
}
|
||||
}
|
||||
|
||||
void GraphObj::addOperatorAndConnect(const Operator &op) {
|
||||
ops.push_back(op);
|
||||
for (auto &input : op->getInputs()) {
|
||||
|
|
|
@ -93,4 +93,15 @@ vector<DataType> OperatorObj::inferDataType() const {
|
|||
return inferDataType(inputs);
|
||||
}
|
||||
|
||||
Operator OperatorObj::cloneAndResetConnections(const TensorVec &newInputs,
|
||||
const TensorVec &newOutputs) {
|
||||
Operator op = clone();
|
||||
op->inputs = newInputs;
|
||||
op->outputs = newOutputs;
|
||||
op->predecessors.clear();
|
||||
op->successors.clear();
|
||||
IT_ASSERT(op->checkValid(nullptr));
|
||||
return op;
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/matmul.h"
|
||||
#include "operators/unary.h"
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
@ -75,4 +76,37 @@ TEST(Graph, test_tensor_id) {
|
|||
EXPECT_EQ(i1->getDataBlob(), nullptr);
|
||||
}
|
||||
|
||||
TEST(Graph, test_OpVec_ctor) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||
Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
|
||||
Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
|
||||
g->dataMalloc();
|
||||
i0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
w0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
auto o1 = g->addTensor(o0->clone());
|
||||
auto matmul = g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
|
||||
g->addOp<ReluObj>(o1, nullptr);
|
||||
g->print();
|
||||
puts("=========");
|
||||
OpVec ops = g->getOperators();
|
||||
Graph g2 = make_ref<GraphObj>(runtime, ops);
|
||||
g2->print();
|
||||
// Check if the two tensors with the same FUID (o0,o1) remain only one in g2
|
||||
EXPECT_EQ(g2->getTensors().size(), 4u);
|
||||
EXPECT_EQ(g2->getOperators().size(), 2u);
|
||||
map<pair<int, int>, int> inputOutput2Cnt = {
|
||||
{{1, 0}, 2}, {{1, 1}, 1}, {{0, 1}, 1}};
|
||||
for (auto t : g2->getTensors()) {
|
||||
pair<int, int> key = {t->getInputOf().size(),
|
||||
t->getOutputOf() != nullptr};
|
||||
EXPECT_GE(inputOutput2Cnt[key], 0);
|
||||
inputOutput2Cnt[key]--;
|
||||
}
|
||||
for (auto [u, v] : inputOutput2Cnt) {
|
||||
EXPECT_EQ(v, 0);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
Loading…
Reference in New Issue