InfiniTensor/include/core/tensor_base.h

103 lines
3.0 KiB
C
Raw Normal View History

2022-08-05 12:50:34 +08:00
#pragma once
#include "core/blob.h"
#include "core/data_type.h"
2022-08-05 12:50:34 +08:00
#include "core/object.h"
#include "core/runtime.h"
2022-08-05 12:50:34 +08:00
2022-08-07 21:12:17 +08:00
namespace infini {
2022-08-05 12:50:34 +08:00
class TensorBaseObj : public Object {
2022-08-05 12:50:34 +08:00
public:
// enum TensorType {
// Input,
// Weight,
// Invalid,
// NotCounted,
// };
protected:
int dim;
DataType dtype;
vector<WRef<OperatorObj>> inputOf;
WRef<OperatorObj> outputOf;
Blob data;
Runtime runtime;
2022-08-05 12:50:34 +08:00
public:
TensorBaseObj(int dim, DataType dtype, Runtime runtime);
virtual ~TensorBaseObj() {}
2022-08-05 12:50:34 +08:00
void dataMalloc(const Blob &blob) {
IT_ASSERT(data == nullptr);
data = blob;
}
Blob getDataBlob() const { return data; }
template <typename T> T getRawDataPtr() const {
static_assert(std::is_pointer_v<T>,
"Raw data pointer has a type of pointer");
IT_ASSERT(data != nullptr);
return data->getPtr<T>();
}
2022-08-05 12:50:34 +08:00
VType getData(size_t offset) const;
DataType getDType() const { return dtype; }
Runtime getRuntime() const { return runtime; }
2022-08-05 12:50:34 +08:00
void addInputOf(const Operator &op) { inputOf.emplace_back(op); }
void setOutputOf(const Operator &op) { outputOf = op; }
OpVec getInputOf() { return wrefs_to_refs(inputOf); }
Operator getOutputOf() { return outputOf.lock(); }
2022-08-05 12:50:34 +08:00
// std::pair<Operator *, int> getOutputOfWithIndex();
// bool setScalar(VType val) {
// if (data == nullptr || !dims.empty())
// return false;
// data[0] = val;
// return true;
// }
// bool setData(const Dim &ds, VType val) {
// if (data == nullptr || ds.size() != dims.size())
// return false;
// data[getOffset(ds)] = val;
// return true;
// }
// bool setData(size_t pos, VType val) {
// if (data == nullptr || pos >= size())
// return false;
// data[pos] = val;
// return true;
// }
// VType getScalar() { return data == nullptr ? 0 : data[0]; }
// VType getBroadcastData(const Dim &ds) {
// assert(data != nullptr);
// auto offset = getBroadcastOffset(ds);
// return offset == (size_t)-1 ? 0 : data[getOffset(ds)];
// }
// VType getBroadcastData(size_t pos) {
// assert(data != nullptr);
// return data[pos % size()];
// }
// size_t getBroadcastOffset(const Dim &ds) {
// assert(ds.size() >= dims.size());
// auto nDim = dims.size();
// auto nBroadcastDim = ds.size() - nDim;
// for (size_t i = 0; i < nDim; ++i)
// if (ds[nBroadcastDim + i] < 0 || ds[nBroadcastDim + i] >=
// dims[i])
// return (size_t)-1;
// size_t idx = 0;
// for (size_t i = 0; i < nDim; ++i)
// idx = idx * dims[i] + ds[nBroadcastDim + i];
// return idx;
// }
};
2022-08-09 14:58:45 +08:00
} // namespace infini