2022-07-31 21:43:26 +08:00
|
|
|
#pragma once
|
2022-08-05 12:50:34 +08:00
|
|
|
#include "core/tensor_base.h"
|
2023-09-18 14:17:02 +08:00
|
|
|
#include "core/tensor_type.h"
|
2023-08-02 16:38:16 +08:00
|
|
|
#include "utils/data_convert.h"
|
2022-08-22 15:01:03 +08:00
|
|
|
#include <cmath>
|
2023-02-24 11:19:37 +08:00
|
|
|
#include <cstring>
|
|
|
|
|
|
|
|
#if USE_CUDA
|
|
|
|
#include "cuda/cuda_runtime.h"
|
|
|
|
#endif
|
2022-07-31 21:43:26 +08:00
|
|
|
|
2022-08-07 21:12:17 +08:00
|
|
|
namespace infini {
|
2022-07-31 21:43:26 +08:00
|
|
|
|
2022-08-05 12:50:34 +08:00
|
|
|
// Type of a single dimension extent stored in a Shape.
// TODO: the original note here only said "how to deal with this"; presumably it
// questions the choice of plain `int` (signed, capped at INT_MAX) as the extent
// type. Confirm intent before widening or changing signedness.
using ShapeElem = int;

// A tensor shape: one ShapeElem per dimension.
using Shape = vector<ShapeElem>;
|
2022-08-15 15:08:56 +08:00
|
|
|
class TensorObj : public TensorBaseObj {
|
2022-07-31 21:43:26 +08:00
|
|
|
private:
|
2022-08-05 12:50:34 +08:00
|
|
|
Shape shape;
|
2023-03-21 14:00:04 +08:00
|
|
|
size_t _size; // Cache of Π(shape).
|
|
|
|
Fuid fuid; // Cloned tensors share the same id. Tensors constructed from
|
|
|
|
// scratch have a new id.
|
2023-09-18 14:17:02 +08:00
|
|
|
TensorType tensorType = TensorType::others;
|
|
|
|
|
2022-07-31 21:43:26 +08:00
|
|
|
public:
|
2023-03-21 14:00:04 +08:00
|
|
|
TensorObj(Shape shape, DataType dtype, Runtime runtime);
|
2022-08-15 15:08:56 +08:00
|
|
|
virtual ~TensorObj() {}
|
2022-07-31 21:43:26 +08:00
|
|
|
string toString() const override;
|
|
|
|
|
2023-03-25 12:04:24 +08:00
|
|
|
size_t size() const { return _size; }
|
|
|
|
size_t getBytes() const { return _size * dtype.getSize(); }
|
2022-07-31 21:43:26 +08:00
|
|
|
|
2022-08-05 12:50:34 +08:00
|
|
|
Shape getDims() const { return shape; }
|
2023-08-16 21:49:43 +08:00
|
|
|
size_t getRank() const { return shape.size(); }
|
2023-08-22 14:21:59 +08:00
|
|
|
Shape getStride() const;
|
2023-03-21 14:00:04 +08:00
|
|
|
size_t getOffset(const vector<int> &ds) const;
|
2022-08-25 11:29:16 +08:00
|
|
|
void dataMalloc();
|
2023-03-25 12:04:24 +08:00
|
|
|
UidBaseType getFuid() const { return fuid; }
|
2023-09-18 14:17:02 +08:00
|
|
|
bool isWeight() const { return tensorType == TensorType::weight; }
|
|
|
|
bool isInput() const { return tensorType == TensorType::input; }
|
|
|
|
bool isOutput() const { return tensorType == TensorType::output; }
|
|
|
|
bool isOthers() const { return tensorType == TensorType::others; }
|
|
|
|
void setWeight() { tensorType = TensorType::weight; }
|
|
|
|
void setInput() { tensorType = TensorType::input; }
|
|
|
|
void setOutput() { tensorType = TensorType::output; }
|
|
|
|
string tensorTypeToString() const {
|
|
|
|
switch (tensorType) {
|
|
|
|
case TensorType::weight:
|
|
|
|
return "weight";
|
|
|
|
break;
|
|
|
|
case TensorType::input:
|
|
|
|
return "input";
|
|
|
|
break;
|
|
|
|
case TensorType::output:
|
|
|
|
return "output";
|
|
|
|
break;
|
|
|
|
case TensorType::others:
|
|
|
|
return "others";
|
|
|
|
break;
|
|
|
|
|
|
|
|
default:
|
|
|
|
return "unknown tensor type";
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
2022-08-23 16:55:59 +08:00
|
|
|
|
2022-09-13 11:27:41 +08:00
|
|
|
void load(std::string file_path);
|
|
|
|
void save(std::string file_path);
|
|
|
|
|
2023-09-01 11:20:26 +08:00
|
|
|
void copyin(const void *ptr, size_t size) {
|
|
|
|
runtime->copyBlobFromCPU(getRawDataPtr<void *>(), ptr, size);
|
|
|
|
}
|
|
|
|
void copyout(void *ptr, size_t size) const {
|
|
|
|
runtime->copyBlobToCPU(ptr, getRawDataPtr<void *>(), size);
|
|
|
|
}
|
|
|
|
|
2023-03-21 14:00:04 +08:00
|
|
|
// Copy elements from `data`.
|
2023-03-25 12:04:24 +08:00
|
|
|
template <typename T> void copyin(const vector<T> &data) {
|
2023-08-02 16:38:16 +08:00
|
|
|
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
|
2023-09-01 11:20:26 +08:00
|
|
|
IT_ASSERT(data.size() == _size);
|
2023-03-21 14:00:04 +08:00
|
|
|
copyin(data.data(), getBytes());
|
2022-08-23 16:55:59 +08:00
|
|
|
}
|
2023-03-21 14:00:04 +08:00
|
|
|
// Copy all the elements to a vector.
|
2023-03-25 12:04:24 +08:00
|
|
|
template <typename T> auto copyout() const {
|
2023-08-02 16:38:16 +08:00
|
|
|
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
|
2023-03-21 14:00:04 +08:00
|
|
|
std::vector<T> ans(_size);
|
|
|
|
copyout(ans.data(), getBytes());
|
|
|
|
return ans;
|
|
|
|
}
|
|
|
|
// Copy the element at `pos`.
|
2023-03-25 12:04:24 +08:00
|
|
|
template <typename T> auto copyOne(const vector<int> &pos) const {
|
2023-08-02 16:38:16 +08:00
|
|
|
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
|
2023-03-21 14:00:04 +08:00
|
|
|
auto offset = getOffset(pos);
|
|
|
|
auto bytes = dtype.getSize();
|
|
|
|
T ans;
|
2023-03-22 09:46:40 +08:00
|
|
|
runtime->copyBlobToCPU(
|
|
|
|
&ans, getRawDataPtr<uint8_t *>() + offset * bytes, bytes);
|
2023-03-21 14:00:04 +08:00
|
|
|
return ans;
|
2022-08-17 14:16:01 +08:00
|
|
|
}
|
2022-08-23 16:55:59 +08:00
|
|
|
|
2022-08-25 11:29:16 +08:00
|
|
|
void copyData(const TensorObj *src);
|
|
|
|
void copyData(const Tensor &src) { copyData(src.get()); }
|
2023-04-18 00:26:36 +08:00
|
|
|
|
2023-08-13 13:39:35 +08:00
|
|
|
// TODO: Rename this function later, because it is confused that it will
|
|
|
|
// change the field data, but actually it generates data and maybe copy to
|
|
|
|
// device.
|
2023-04-18 00:26:36 +08:00
|
|
|
// FIXME: std::fucntion copies the generator instead of passing it by ref.
|
|
|
|
// Thus the internal state of generator cannot be updated.
|
2022-08-22 15:01:03 +08:00
|
|
|
void setData(
|
2023-04-18 15:10:33 +08:00
|
|
|
std::function<void(void *, size_t, DataType)> const &generator) const;
|
2023-08-13 13:39:35 +08:00
|
|
|
|
|
|
|
void setDataBlob(const Blob &blob);
|
|
|
|
|
2023-02-12 18:27:52 +08:00
|
|
|
Tensor clone() const {
|
|
|
|
auto obj = make_ref<TensorObj>(*this);
|
|
|
|
obj->freeData();
|
2023-02-24 15:02:52 +08:00
|
|
|
obj->targets.clear();
|
|
|
|
obj->source.reset();
|
2023-02-12 18:27:52 +08:00
|
|
|
return obj;
|
|
|
|
}
|
|
|
|
Tensor clone(Runtime runtime) const {
|
|
|
|
auto obj = make_ref<TensorObj>(*this);
|
|
|
|
obj->runtime = runtime;
|
|
|
|
obj->freeData();
|
2023-02-24 15:02:52 +08:00
|
|
|
obj->targets.clear();
|
|
|
|
obj->source.reset();
|
2023-02-12 18:27:52 +08:00
|
|
|
if (hasData()) {
|
|
|
|
obj->dataMalloc();
|
|
|
|
obj->copyData(this);
|
|
|
|
}
|
2022-08-31 14:44:53 +08:00
|
|
|
return obj;
|
|
|
|
}
|
2022-08-22 15:01:03 +08:00
|
|
|
|
2022-08-23 16:55:59 +08:00
|
|
|
void printData() const;
|
2023-04-18 00:26:36 +08:00
|
|
|
bool equalData(const Tensor &rhs, double relativeError = 1e-6) const;
|
2022-08-23 16:55:59 +08:00
|
|
|
|
2022-08-25 11:29:16 +08:00
|
|
|
template <typename T> bool equalData(const vector<T> &dataVector) {
|
|
|
|
IT_ASSERT(size() == dataVector.size());
|
2023-08-02 16:38:16 +08:00
|
|
|
if (dtype == DataType::Float16) {
|
|
|
|
return equalDataImpl_fp16(getRawDataPtr<uint16_t *>(),
|
|
|
|
(float *)dataVector.data(), size());
|
|
|
|
}
|
|
|
|
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
|
2023-04-18 15:10:33 +08:00
|
|
|
return equalDataImpl(getRawDataPtr<T *>(), dataVector.data(), size());
|
2022-08-25 11:29:16 +08:00
|
|
|
}
|
|
|
|
|
2022-09-09 13:43:59 +08:00
|
|
|
size_t getOffsetByBroadcastOffset(size_t bcOffset, Shape bcShape) const;
|
|
|
|
|
2022-08-22 15:01:03 +08:00
|
|
|
private:
|
2023-04-18 15:10:33 +08:00
|
|
|
template <class T> string dataToString() const {
|
|
|
|
std::stringstream builder;
|
|
|
|
builder << "Tensor: " << guid << std::endl;
|
|
|
|
|
|
|
|
auto numDims = shape.size();
|
|
|
|
auto dimSzVec = vector<int>(numDims, 1);
|
|
|
|
auto ptr = data->getPtr<T *>();
|
|
|
|
dimSzVec[numDims - 1] = shape[numDims - 1];
|
|
|
|
|
|
|
|
for (int i = numDims - 1; i != 0; --i)
|
|
|
|
dimSzVec[i - 1] = dimSzVec[i] * shape[i - 1];
|
|
|
|
|
|
|
|
for (size_t i = 0, iEnd = size(); i < iEnd; ++i) {
|
|
|
|
for (size_t j = 0; j < numDims; ++j)
|
|
|
|
if (i % dimSzVec[j] == 0)
|
|
|
|
builder << "[";
|
|
|
|
|
|
|
|
builder << ptr[i];
|
|
|
|
for (size_t j = 0; j < numDims; ++j)
|
|
|
|
if ((int)i % dimSzVec[j] == dimSzVec[j] - 1)
|
|
|
|
builder << "]";
|
|
|
|
|
|
|
|
if (i != size() - 1)
|
|
|
|
builder << ", ";
|
|
|
|
|
|
|
|
auto column = (size_t)dimSzVec[numDims - 1];
|
|
|
|
if (i % column == column - 1)
|
|
|
|
builder << std::endl;
|
|
|
|
}
|
|
|
|
return builder.str();
|
|
|
|
}
|
2022-08-25 11:29:16 +08:00
|
|
|
|
|
|
|
template <typename T>
|
2023-10-16 10:57:08 +08:00
|
|
|
bool equalDataImpl(const T *a, const T *b, size_t size,
|
|
|
|
double relativeError = 1e-6) const {
|
2022-08-25 11:29:16 +08:00
|
|
|
for (size_t i = 0; i < size; ++i) {
|
|
|
|
if constexpr (std::is_integral_v<T>) {
|
|
|
|
if (a[i] != b[i])
|
|
|
|
return false;
|
|
|
|
} else if constexpr (std::is_floating_point_v<T>) {
|
|
|
|
if (fabs(a[i] - b[i]) / std::max(fabs(a[i]), fabs(b[i])) >
|
2023-10-16 10:57:08 +08:00
|
|
|
relativeError) {
|
2022-08-25 11:29:16 +08:00
|
|
|
printf("Error on %lu: %f %f\n", i, a[i], b[i]);
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
} else
|
|
|
|
static_assert(!sizeof(T), "Unsupported data type");
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
}
|
2022-09-09 13:43:59 +08:00
|
|
|
|
2023-08-02 16:38:16 +08:00
|
|
|
bool equalDataImpl_fp16(const uint16_t *a, const float *b,
|
|
|
|
size_t size) const {
|
|
|
|
for (size_t i = 0; i < size; ++i) {
|
|
|
|
auto a_fp32 = fp16_to_float(a[i]);
|
|
|
|
auto b_fp32 = b[i];
|
|
|
|
if (fabs(a_fp32 - b_fp32) / std::max(fabs(a_fp32), fabs(b_fp32)) >
|
|
|
|
1e-6) {
|
|
|
|
printf("Error on %lu: %f %f\n", i, a_fp32, b_fp32);
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
2022-09-09 13:43:59 +08:00
|
|
|
Shape getPosByOffset(size_t offset, Shape dim) const;
|
|
|
|
size_t getOffsetByPos(Shape pos, Shape dim) const;
|
|
|
|
|
2022-08-05 12:50:34 +08:00
|
|
|
// void setDims(const Dim &dms) { dims = dms; }
|
2022-07-31 21:43:26 +08:00
|
|
|
|
|
|
|
// bool dataRand(int seed = 0) {
|
|
|
|
// if (data == nullptr)
|
|
|
|
// data = new VType[size()];
|
|
|
|
// if (!random_inited)
|
|
|
|
// initFastrand();
|
|
|
|
// // srand(seed);
|
|
|
|
// // faster rand generator; parallel
|
|
|
|
// size_t iEnd = size();
|
|
|
|
// // std::cerr << "Init beginned " << std::endl;
|
|
|
|
// #pragma omp parallel for
|
|
|
|
// for (size_t i = 0; i < iEnd; ++i)
|
|
|
|
// data[i] = fastrand(random_seed[omp_get_thread_num() * 16]) %
|
|
|
|
// 10000;
|
|
|
|
// // std::cerr << "Init finished" << std::endl;
|
|
|
|
// computed = ComputedFull;
|
|
|
|
// return true;
|
|
|
|
// }
|
|
|
|
|
|
|
|
// bool setScalar(VType val) {
|
|
|
|
// if (data == nullptr || !dims.empty())
|
|
|
|
// return false;
|
|
|
|
// data[0] = val;
|
|
|
|
// return true;
|
|
|
|
// }
|
|
|
|
|
|
|
|
// bool setData(const Dim &ds, VType val) {
|
|
|
|
// if (data == nullptr || ds.size() != dims.size())
|
|
|
|
// return false;
|
|
|
|
// data[getOffset(ds)] = val;
|
|
|
|
// return true;
|
|
|
|
// }
|
|
|
|
|
|
|
|
// bool setData(size_t pos, VType val) {
|
|
|
|
// if (data == nullptr || pos >= size())
|
|
|
|
// return false;
|
|
|
|
// data[pos] = val;
|
|
|
|
// return true;
|
|
|
|
// }
|
|
|
|
|
|
|
|
// VType getScalar() { return data == nullptr ? 0 : data[0]; }
|
|
|
|
|
|
|
|
// VType getBroadcastData(const Dim &ds) {
|
|
|
|
// assert(data != nullptr);
|
|
|
|
// auto offset = getBroadcastOffset(ds);
|
|
|
|
// return offset == (size_t)-1 ? 0 : data[getOffset(ds)];
|
|
|
|
// }
|
|
|
|
|
|
|
|
// VType getBroadcastData(size_t pos) {
|
|
|
|
// assert(data != nullptr);
|
|
|
|
// return data[pos % size()];
|
|
|
|
// }
|
|
|
|
|
|
|
|
// size_t getBroadcastOffset(const Dim &ds) {
|
|
|
|
// assert(ds.size() >= dims.size());
|
|
|
|
// auto nDim = dims.size();
|
|
|
|
// auto nBroadcastDim = ds.size() - nDim;
|
|
|
|
// for (size_t i = 0; i < nDim; ++i)
|
|
|
|
// if (ds[nBroadcastDim + i] < 0 || ds[nBroadcastDim + i] >=
|
|
|
|
// dims[i])
|
|
|
|
// return (size_t)-1;
|
|
|
|
// size_t idx = 0;
|
|
|
|
// for (size_t i = 0; i < nDim; ++i)
|
|
|
|
// idx = idx * dims[i] + ds[nBroadcastDim + i];
|
|
|
|
// return idx;
|
|
|
|
// }
|
|
|
|
|
|
|
|
// void itInit() { it = Dim(dims.size(), 0); }
|
|
|
|
|
|
|
|
// void itReset() {
|
|
|
|
// itInit();
|
|
|
|
// for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i)
|
|
|
|
// it[i] = 0;
|
|
|
|
// }
|
|
|
|
|
|
|
|
// bool itValid() {
|
|
|
|
// if (it.size() != dims.size())
|
|
|
|
// return false;
|
|
|
|
// for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i)
|
|
|
|
// if (it[i] >= dims[i])
|
|
|
|
// return false;
|
|
|
|
// return true;
|
|
|
|
// }
|
|
|
|
|
|
|
|
// const Dim &itGet() { return it; }
|
|
|
|
|
|
|
|
// void itNext() {
|
|
|
|
// auto p = it.size() - 1;
|
|
|
|
// it[p] += 1;
|
|
|
|
// while (p >= 1) {
|
|
|
|
// if (it[p] == dims[p]) {
|
|
|
|
// it[p] = 0;
|
|
|
|
// it[--p] += 1;
|
|
|
|
// } else
|
|
|
|
// break;
|
|
|
|
// }
|
|
|
|
// }
|
|
|
|
|
|
|
|
// TensorType getType() const { return type; }
|
|
|
|
// void setType(TensorType ty) { type = ty; }
|
|
|
|
|
|
|
|
// static inline void initFastrand() {
|
|
|
|
// assert(omp_get_max_threads() <= 256);
|
|
|
|
// // srand(0); // constant seed for test
|
|
|
|
// // align random_seed to avoid false sharing
|
|
|
|
// for (int i = 0; i < 256 * 16; ++i) {
|
|
|
|
// // random_seed[i] = rand();
|
|
|
|
// // constant random seed for test
|
|
|
|
// random_seed[i] = i;
|
|
|
|
// }
|
|
|
|
// random_inited = true;
|
|
|
|
// }
|
|
|
|
|
|
|
|
// static inline int fastrand(int &g_seed) {
|
|
|
|
// g_seed = (214013 * g_seed + 2531011);
|
|
|
|
// return (g_seed >> 16) & 0x7FFF;
|
|
|
|
// }
|
|
|
|
|
|
|
|
// std::vector<std::vector<int>> const *getSplittingPoints() const {
|
|
|
|
// assert(!splittingPoints.empty());
|
|
|
|
// return &splittingPoints;
|
|
|
|
// }
|
|
|
|
|
|
|
|
// bool setSplittingPoints(std::vector<std::vector<int>> value) {
|
|
|
|
// assert(!value.empty());
|
|
|
|
// splittingPoints = value;
|
|
|
|
// return true;
|
|
|
|
// }
|
|
|
|
|
|
|
|
// void printSplittingPoints() {
|
|
|
|
// if (splittingPoints.empty())
|
|
|
|
// printf("Empty SplittingPoints");
|
|
|
|
// else {
|
|
|
|
// printf("[");
|
|
|
|
// for (auto &vs : splittingPoints) {
|
|
|
|
// printf("[");
|
|
|
|
// for (auto v : vs)
|
|
|
|
// printf("%2d,", v);
|
|
|
|
// printf("],");
|
|
|
|
// }
|
|
|
|
// printf("]");
|
|
|
|
// }
|
|
|
|
// }
|
|
|
|
|
|
|
|
// void initSplittingPoints() {
|
2023-08-16 21:49:43 +08:00
|
|
|
// splittingPoints.resize(getRank()); }
|
2022-07-31 21:43:26 +08:00
|
|
|
|
|
|
|
// void printShape();
|
|
|
|
};
|
|
|
|
|
2022-09-13 11:27:41 +08:00
|
|
|
} // namespace infini
|