2022-08-05 12:50:34 +08:00
|
|
|
#pragma once
|
2022-08-22 15:01:03 +08:00
|
|
|
#include "core/blob.h"
#include "core/data_type.h"
#include "core/object.h"
#include "core/runtime.h"

#include <algorithm>
#include <type_traits>
|
2022-08-07 21:12:17 +08:00
|
|
|
namespace infini {
|
2023-04-17 13:09:07 +08:00
|
|
|
class GraphObj;
|
2022-08-15 15:08:56 +08:00
|
|
|
class TensorBaseObj : public Object {
|
2023-04-17 13:09:07 +08:00
|
|
|
friend class GraphObj;
|
|
|
|
|
2022-08-05 12:50:34 +08:00
|
|
|
public:
|
|
|
|
// enum TensorType {
|
|
|
|
// Input,
|
|
|
|
// Weight,
|
|
|
|
// Invalid,
|
|
|
|
// NotCounted,
|
|
|
|
// };
|
|
|
|
|
|
|
|
protected:
|
|
|
|
int dim;
|
|
|
|
|
|
|
|
DataType dtype;
|
2023-02-24 15:02:52 +08:00
|
|
|
vector<WRef<OperatorObj>> targets;
|
|
|
|
WRef<OperatorObj> source;
|
2022-08-22 15:01:03 +08:00
|
|
|
Blob data;
|
2022-08-23 16:55:59 +08:00
|
|
|
Runtime runtime;
|
2022-08-05 12:50:34 +08:00
|
|
|
|
|
|
|
public:
|
2022-08-23 16:55:59 +08:00
|
|
|
TensorBaseObj(int dim, DataType dtype, Runtime runtime);
|
2022-08-15 15:08:56 +08:00
|
|
|
virtual ~TensorBaseObj() {}
|
2022-08-05 12:50:34 +08:00
|
|
|
|
2022-08-22 15:01:03 +08:00
|
|
|
void dataMalloc(const Blob &blob) {
|
|
|
|
IT_ASSERT(data == nullptr);
|
|
|
|
data = blob;
|
|
|
|
}
|
2022-08-25 11:29:16 +08:00
|
|
|
Blob getDataBlob() const { return data; }
|
2023-02-12 18:27:52 +08:00
|
|
|
bool hasData() const { return data != nullptr; }
|
|
|
|
void freeData() { data = nullptr; }
|
2022-08-25 11:29:16 +08:00
|
|
|
template <typename T> T getRawDataPtr() const {
|
|
|
|
static_assert(std::is_pointer_v<T>,
|
|
|
|
"Raw data pointer has a type of pointer");
|
2022-08-22 15:01:03 +08:00
|
|
|
IT_ASSERT(data != nullptr);
|
|
|
|
return data->getPtr<T>();
|
|
|
|
}
|
2022-08-05 12:50:34 +08:00
|
|
|
|
|
|
|
DataType getDType() const { return dtype; }
|
2023-09-15 16:40:44 +08:00
|
|
|
int getDTypeIndex() const { return dtype.getIndex(); }
|
2022-08-23 16:55:59 +08:00
|
|
|
Runtime getRuntime() const { return runtime; }
|
2022-08-05 12:50:34 +08:00
|
|
|
|
2023-04-17 13:09:07 +08:00
|
|
|
// std::pair<Operator *, int> getOutputOfWithIndex();
|
2023-02-24 15:02:52 +08:00
|
|
|
|
|
|
|
bool hasTarget() const { return !targets.empty(); }
|
|
|
|
|
2023-03-25 12:04:24 +08:00
|
|
|
OpVec getTargets() const { return wrefs_to_refs(targets); }
|
|
|
|
Operator getSource() const { return source.lock(); }
|
2023-04-17 13:09:07 +08:00
|
|
|
|
|
|
|
private:
|
|
|
|
void addTarget(const Operator &op) { targets.emplace_back(op); }
|
|
|
|
void setSource(const Operator &op) { source = op; }
|
|
|
|
void removeTarget(const Operator &op) {
|
|
|
|
for (auto itr = targets.begin(); itr != targets.end();) {
|
|
|
|
if (itr->lock() == op)
|
|
|
|
itr = targets.erase(itr);
|
|
|
|
else
|
|
|
|
++itr;
|
|
|
|
}
|
|
|
|
}
|
2023-03-25 12:04:24 +08:00
|
|
|
// std::pair<Operator *, int> getSourceWithIndex();
|
2022-08-05 12:50:34 +08:00
|
|
|
|
|
|
|
// bool setScalar(VType val) {
|
|
|
|
// if (data == nullptr || !dims.empty())
|
|
|
|
// return false;
|
|
|
|
// data[0] = val;
|
|
|
|
// return true;
|
|
|
|
// }
|
|
|
|
|
|
|
|
// bool setData(const Dim &ds, VType val) {
|
|
|
|
// if (data == nullptr || ds.size() != dims.size())
|
|
|
|
// return false;
|
|
|
|
// data[getOffset(ds)] = val;
|
|
|
|
// return true;
|
|
|
|
// }
|
|
|
|
|
|
|
|
// bool setData(size_t pos, VType val) {
|
|
|
|
// if (data == nullptr || pos >= size())
|
|
|
|
// return false;
|
|
|
|
// data[pos] = val;
|
|
|
|
// return true;
|
|
|
|
// }
|
|
|
|
|
|
|
|
// VType getScalar() { return data == nullptr ? 0 : data[0]; }
|
|
|
|
|
|
|
|
// VType getBroadcastData(const Dim &ds) {
|
|
|
|
// assert(data != nullptr);
|
|
|
|
// auto offset = getBroadcastOffset(ds);
|
|
|
|
// return offset == (size_t)-1 ? 0 : data[getOffset(ds)];
|
|
|
|
// }
|
|
|
|
|
|
|
|
// VType getBroadcastData(size_t pos) {
|
|
|
|
// assert(data != nullptr);
|
|
|
|
// return data[pos % size()];
|
|
|
|
// }
|
|
|
|
|
|
|
|
// size_t getBroadcastOffset(const Dim &ds) {
|
|
|
|
// assert(ds.size() >= dims.size());
|
|
|
|
// auto nDim = dims.size();
|
|
|
|
// auto nBroadcastDim = ds.size() - nDim;
|
|
|
|
// for (size_t i = 0; i < nDim; ++i)
|
|
|
|
// if (ds[nBroadcastDim + i] < 0 || ds[nBroadcastDim + i] >=
|
|
|
|
// dims[i])
|
|
|
|
// return (size_t)-1;
|
|
|
|
// size_t idx = 0;
|
|
|
|
// for (size_t i = 0; i < nDim; ++i)
|
|
|
|
// idx = idx * dims[i] + ds[nBroadcastDim + i];
|
|
|
|
// return idx;
|
|
|
|
// }
|
|
|
|
};
|
|
|
|
|
2022-08-09 14:58:45 +08:00
|
|
|
} // namespace infini
|