forked from jiuyuan/InfiniTensor
Add: tensor fuid
This commit is contained in:
parent
c5966f8d81
commit
e549f21867
|
@ -17,6 +17,7 @@ class GraphObj : public Object {
|
|||
string toString() const override;
|
||||
|
||||
Tensor addTensor(Shape dim, DataType dtype = DataType::Float32);
|
||||
Tensor addTensor(const Tensor &tensor);
|
||||
Tensor cloneTensor(const Tensor &tensor) {
|
||||
auto ret = addTensor(tensor->getDims(), tensor->getDType());
|
||||
ret->dataMalloc();
|
||||
|
|
|
@ -27,6 +27,27 @@ class Guid {
|
|||
operator GuidBaseType() const { return guid; }
|
||||
};
|
||||
|
||||
class Fuid {
|
||||
private:
|
||||
GuidBaseType fuid;
|
||||
|
||||
private:
|
||||
GuidBaseType generateFuid() {
|
||||
static GuidBaseType guidCnt = 0;
|
||||
return ++guidCnt;
|
||||
}
|
||||
|
||||
public:
|
||||
Fuid() { fuid = generateFuid(); }
|
||||
Fuid(const Guid &rhs) { fuid = generateFuid(); }
|
||||
Fuid &operator=(const Guid &rhs) {
|
||||
fuid = generateFuid();
|
||||
return *this;
|
||||
}
|
||||
|
||||
operator GuidBaseType() const { return fuid; }
|
||||
};
|
||||
|
||||
class Object {
|
||||
protected:
|
||||
Guid guid;
|
||||
|
|
|
@ -10,6 +10,8 @@ using Shape = vector<ShapeElem>;
|
|||
class TensorObj : public TensorBaseObj {
|
||||
private:
|
||||
Shape shape;
|
||||
Fuid fuid; // Tensor cloned from a common tensor share the same id. Tensors
|
||||
// constructed from common constructor has a new id.
|
||||
|
||||
public:
|
||||
TensorObj(const Shape &shape, DataType dtype, Runtime runtime);
|
||||
|
@ -25,6 +27,12 @@ class TensorObj : public TensorBaseObj {
|
|||
using TensorBaseObj::getData;
|
||||
VType getData(const Shape &pos) const;
|
||||
void dataMalloc();
|
||||
GuidBaseType getFuid() const { return fuid; }
|
||||
Tensor clone() const {
|
||||
auto ret = make_ref<TensorObj>(*this);
|
||||
ret->freeData();
|
||||
return ret;
|
||||
}
|
||||
|
||||
void load(std::string file_path);
|
||||
void save(std::string file_path);
|
||||
|
|
|
@ -33,6 +33,7 @@ class TensorBaseObj : public Object {
|
|||
data = blob;
|
||||
}
|
||||
Blob getDataBlob() const { return data; }
|
||||
void freeData() { data = nullptr; }
|
||||
template <typename T> T getRawDataPtr() const {
|
||||
static_assert(std::is_pointer_v<T>,
|
||||
"Raw data pointer has a type of pointer");
|
||||
|
|
|
@ -53,6 +53,11 @@ Tensor GraphObj::addTensor(Shape dim, DataType dtype) {
|
|||
return tensor;
|
||||
}
|
||||
|
||||
Tensor GraphObj::addTensor(const Tensor &tensor) {
|
||||
tensors.emplace_back(tensor);
|
||||
return tensor;
|
||||
}
|
||||
|
||||
OpVec GraphObj::getComputeOps() const {
|
||||
OpVec opList;
|
||||
for (auto op : ops)
|
||||
|
|
|
@ -14,8 +14,9 @@ VType TensorObj::getData(const Shape &pos) const {
|
|||
}
|
||||
|
||||
string TensorObj::toString() const {
|
||||
string ret = "Tensor " + std::to_string(guid) + ", shape " +
|
||||
vecToString(shape) + ", dtype " + dtype.toString();
|
||||
string ret = "Tensor " + std::to_string(guid) + ", Fuid " +
|
||||
std::to_string(fuid) + ", shape " + vecToString(shape) +
|
||||
", dtype " + dtype.toString();
|
||||
vector<GuidBaseType> inputOfGuid;
|
||||
for (const auto &op : inputOf)
|
||||
inputOfGuid.emplace_back(op.lock()->getGuid());
|
||||
|
|
|
@ -57,4 +57,22 @@ TEST(Graph, perf_engine) {
|
|||
EXPECT_TRUE(matmul->getOutput()->equalData(ans));
|
||||
}
|
||||
|
||||
TEST(Graph, test_tensor_id) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||
Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
|
||||
Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
|
||||
g->dataMalloc();
|
||||
i0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
w0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
auto i1 = g->addTensor(i0->clone());
|
||||
auto matmul = g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
|
||||
g->print();
|
||||
EXPECT_NE(i0->getGuid(), i1->getGuid());
|
||||
EXPECT_EQ(i0->getFuid(), i1->getFuid());
|
||||
EXPECT_NE(i0->getDataBlob(), nullptr);
|
||||
EXPECT_EQ(i1->getDataBlob(), nullptr);
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
Loading…
Reference in New Issue