[Intermediate state] Add: Graph ctor for OpVec

This commit is contained in:
Liyan Zheng 2022-11-15 21:09:03 +08:00
parent e549f21867
commit f133f00478
6 changed files with 100 additions and 8 deletions

View File

@ -14,6 +14,7 @@ class GraphObj : public Object {
public:
GraphObj(Runtime runtime) : runtime(runtime){};
GraphObj(Runtime runtime, OpVec ops_in);
string toString() const override;
Tensor addTensor(Shape dim, DataType dtype = DataType::Float32);
@ -46,8 +47,14 @@ class GraphObj : public Object {
}
const TensorVec &getTensors() const { return tensors; }
const TensorVec &getInputs() const { return inputs; }
const TensorVec &getOutputs() const { return outputs; }
// Accessor for the graph's input tensors.
// NOTE(review): halts unconditionally — IT_TODO_HALT() fires before the
// return, so `inputs` is never reached. Per the commit message this is a
// deliberate intermediate state while input bookkeeping is reworked.
const TensorVec &getInputs() const {
IT_TODO_HALT();
return inputs;
}
// Accessor for the graph's output tensors.
// NOTE(review): halts unconditionally — IT_TODO_HALT() fires before the
// return, so `outputs` is never reached; intentional in this intermediate
// commit. Confirm before re-enabling callers.
const TensorVec &getOutputs() const {
IT_TODO_HALT();
return outputs;
}
const OpVec &getOperators() const { return ops; }
OpVec getComputeOps() const;
// TensorVec &getInputs();

View File

@ -197,6 +197,9 @@ class OperatorObj : public Object {
virtual int numInputs() const = 0;
virtual int numOutputs() const = 0;
Operator cloneAndResetConnections(const TensorVec &newInputs,
const TensorVec &newOutputs);
protected:
optional<vector<Shape>> inferShape() const;
vector<DataType> inferDataType() const;
@ -213,8 +216,17 @@ class OperatorObj : public Object {
* and output shapes.
*/
virtual vector<int> getWorkloadVector() const { IT_TODO_HALT(); }
// Default clone() for operator subclasses that have not yet adopted the
// OP_CLONE macro: halts at runtime; the nullptr return is unreachable and
// exists only to satisfy the signature.
virtual Operator clone() const {
IT_TODO_HALT();
return nullptr;
}
};
// Expands (inside an OperatorObj subclass) to a clone() override that
// copy-constructs the concrete operator type via make_ref. The infini::
// qualification keeps the expansion valid regardless of the using class's
// namespace. (Comments must stay outside the macro: a `//` on a
// backslash-continued line would swallow the rest of the definition.)
#define OP_CLONE(OpObj) \
virtual Operator clone() const override { \
return infini::make_ref<OpObj>(*this); \
}
} // namespace infini
namespace std {

View File

@ -28,11 +28,6 @@ class TensorObj : public TensorBaseObj {
VType getData(const Shape &pos) const;
void dataMalloc();
GuidBaseType getFuid() const { return fuid; }
Tensor clone() const {
auto ret = make_ref<TensorObj>(*this);
ret->freeData();
return ret;
}
void load(std::string file_path);
void save(std::string file_path);
@ -59,7 +54,15 @@ class TensorObj : public TensorBaseObj {
}
generator(data->getPtr<void *>(), size(), dtype);
}
Tensor clone(Runtime runtime) {
// Produce a detached copy of this tensor: the copy carries no data blob
// (freeData) and no graph connections (inputOf/outputOf are emptied), so
// it can be inserted into a fresh graph.
Tensor clone() const {
    auto detached = make_ref<TensorObj>(*this);
    // Strip state that must not leak from the original into the copy.
    detached->freeData();
    detached->outputOf.reset();
    detached->inputOf.clear();
    return detached;
}
Tensor clone(Runtime runtime) const {
// TODO: use copy constructor
auto obj = make_ref<TensorObj>(shape, dtype, runtime);
obj->dataMalloc();
obj->copyData(this);

View File

@ -1,7 +1,32 @@
#include "core/graph.h"
#include <queue>
namespace infini {
// Construct a graph from an existing operator list. Tensors are
// deduplicated by FUID — all operators referring to the same FUID end up
// sharing one cloned tensor — then each operator is cloned and rewired to
// the pooled clones.
GraphObj::GraphObj(Runtime runtime, OpVec ops_in) : runtime(runtime) {
    map<GuidBaseType, Tensor> tensorPool;
    // Clone a tensor into the pool exactly once per FUID.
    auto poolClone = [&tensorPool](const Tensor &t) {
        if (tensorPool.find(t->getFuid()) == tensorPool.end())
            tensorPool[t->getFuid()] = t->clone();
    };
    // Pass 1: collect and clone every distinct tensor.
    for (const auto &op : ops_in) {
        for (const auto &t : op->getInputs())
            poolClone(t);
        for (const auto &t : op->getOutputs())
            poolClone(t);
    }
    for (const auto &[_, t] : tensorPool)
        addTensor(t);
    // Pass 2: clone each operator, swapping its endpoints for the pooled
    // clones, and wire it into this graph.
    for (const auto &op : ops_in) {
        TensorVec inputs, outputs;
        for (const auto &t : op->getInputs())
            inputs.emplace_back(tensorPool.at(t->getFuid()));
        for (const auto &t : op->getOutputs())
            outputs.emplace_back(tensorPool.at(t->getFuid()));
        addOperatorAndConnect(op->cloneAndResetConnections(inputs, outputs));
    }
}
void GraphObj::addOperatorAndConnect(const Operator &op) {
ops.push_back(op);
for (auto &input : op->getInputs()) {

View File

@ -93,4 +93,15 @@ vector<DataType> OperatorObj::inferDataType() const {
return inferDataType(inputs);
}
// Duplicate this operator and attach it to the given endpoint tensors,
// discarding any predecessor/successor links inherited from the source.
Operator OperatorObj::cloneAndResetConnections(const TensorVec &newInputs,
                                               const TensorVec &newOutputs) {
    auto rewired = clone();
    // Sever graph links first, then swap in the new data endpoints; the
    // statements are independent of each other.
    rewired->predecessors.clear();
    rewired->successors.clear();
    rewired->inputs = newInputs;
    rewired->outputs = newOutputs;
    // Sanity-check the rewired operator; NOTE(review): checkValid is called
    // with a null runtime — confirm that skips runtime-bound validation.
    IT_ASSERT(rewired->checkValid(nullptr));
    return rewired;
}
} // namespace infini

View File

@ -2,6 +2,7 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "operators/matmul.h"
#include "operators/unary.h"
#include "test.h"
namespace infini {
@ -75,4 +76,37 @@ TEST(Graph, test_tensor_id) {
EXPECT_EQ(i1->getDataBlob(), nullptr);
}
// Verify GraphObj's OpVec constructor: tensors sharing a FUID (o0 and its
// clone o1) must be merged into a single tensor in the rebuilt graph, and
// operator connectivity must be reproduced.
TEST(Graph, test_OpVec_ctor) {
    Runtime runtime = CpuRuntimeObj::getInstance();
    Graph g = make_ref<GraphObj>(runtime);
    Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32);
    Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
    Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
    g->dataMalloc();
    i0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
    w0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
    // o1 is a clone of o0, so it carries the same FUID.
    auto o1 = g->addTensor(o0->clone());
    auto matmul = g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
    g->addOp<ReluObj>(o1, nullptr);
    g->print();
    puts("=========");
    OpVec ops = g->getOperators();
    Graph g2 = make_ref<GraphObj>(runtime, ops);
    g2->print();
    // Check if the two tensors with the same FUID (o0,o1) remain only one in g2
    EXPECT_EQ(g2->getTensors().size(), 4u);
    EXPECT_EQ(g2->getOperators().size(), 2u);
    // Expected connectivity histogram: {#consumers, has-producer} -> count.
    map<pair<int, int>, int> inputOutput2Cnt = {
        {{1, 0}, 2}, {{1, 1}, 1}, {{0, 1}, 1}};
    for (auto t : g2->getTensors()) {
        pair<int, int> key = {t->getInputOf().size(),
                              t->getOutputOf() != nullptr};
        // Fix: require a remaining count (> 0). The previous EXPECT_GE(.., 0)
        // passed even for an unexpected key — operator[] default-inserts 0 —
        // deferring the failure to the aggregate loop below and hiding which
        // tensor's connectivity was wrong.
        EXPECT_GT(inputOutput2Cnt[key], 0);
        inputOutput2Cnt[key]--;
    }
    // Every expected connectivity pattern must be fully consumed.
    for (auto [u, v] : inputOutput2Cnt) {
        EXPECT_EQ(v, 0);
    }
}
} // namespace infini