2023-04-18 00:26:36 +08:00
|
|
|
#pragma once
|
2022-09-27 10:41:12 +08:00
|
|
|
#include "core/common.h"
|
|
|
|
#include "core/tensor_base.h"
|
2023-08-02 16:38:16 +08:00
|
|
|
#include "utils/data_convert.h"
|
2023-04-18 00:26:36 +08:00
|
|
|
#include <random>
|
2022-09-27 10:41:12 +08:00
|
|
|
|
|
|
|
namespace infini {
|
|
|
|
|
|
|
|
// TODO: isolate these classes
|
|
|
|
class DataGenerator {
|
|
|
|
private:
|
|
|
|
virtual void fill(uint32_t *data, size_t size) { IT_TODO_HALT(); }
|
|
|
|
virtual void fill(float *data, size_t size) { IT_TODO_HALT(); }
|
2023-08-02 16:38:16 +08:00
|
|
|
virtual void fill_fp16(uint16_t *data, size_t size) { IT_TODO_HALT(); }
|
2022-09-27 10:41:12 +08:00
|
|
|
|
|
|
|
public:
|
|
|
|
virtual ~DataGenerator() {}
|
|
|
|
void operator()(void *data, size_t size, DataType dataType) {
|
|
|
|
if (dataType == DataType::UInt32)
|
|
|
|
fill(reinterpret_cast<uint32_t *>(data), size);
|
|
|
|
else if (dataType == DataType::Float32)
|
|
|
|
fill(reinterpret_cast<float *>(data), size);
|
2023-08-02 16:38:16 +08:00
|
|
|
else if (dataType == DataType::Float16)
|
|
|
|
fill_fp16(reinterpret_cast<uint16_t *>(data), size);
|
2022-09-27 10:41:12 +08:00
|
|
|
else
|
|
|
|
IT_TODO_HALT();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
class IncrementalGenerator : public DataGenerator {
|
|
|
|
public:
|
|
|
|
virtual ~IncrementalGenerator() {}
|
|
|
|
|
|
|
|
private:
|
|
|
|
template <typename T> void fill(T *data, size_t size) {
|
|
|
|
for (size_t i = 0; i < size; i++) {
|
|
|
|
data[i] = i;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
void fill(uint32_t *data, size_t size) override {
|
|
|
|
fill<uint32_t>(data, size);
|
|
|
|
}
|
|
|
|
void fill(float *data, size_t size) override { fill<float>(data, size); }
|
2023-08-02 16:38:16 +08:00
|
|
|
// FIXME: fix the accuracy standards when dtype is float16
|
|
|
|
void fill_fp16(uint16_t *data, size_t size) {
|
|
|
|
for (size_t i = 0; i < size; i++) {
|
|
|
|
float x = 2.0f;
|
|
|
|
data[i] = float_to_fp16(x);
|
|
|
|
}
|
|
|
|
}
|
2022-09-27 10:41:12 +08:00
|
|
|
};
|
|
|
|
|
2023-04-18 00:26:36 +08:00
|
|
|
class RandomGenerator : public DataGenerator {
|
|
|
|
private:
|
|
|
|
double l, r;
|
|
|
|
std::mt19937 e;
|
|
|
|
std::uniform_int_distribution<int> di;
|
|
|
|
std::uniform_real_distribution<float> dr;
|
|
|
|
|
|
|
|
public:
|
|
|
|
RandomGenerator(double l = 0, double r = 1, unsigned int seed = 0)
|
|
|
|
: l(l), r(r), e(seed), di(l, r), dr(l, r) {}
|
|
|
|
virtual ~RandomGenerator() {}
|
|
|
|
|
|
|
|
private:
|
|
|
|
void fill(uint32_t *data, size_t size) override {
|
|
|
|
for (size_t i = 0; i < size; i++) {
|
|
|
|
data[i] = di(e);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
void fill(float *data, size_t size) override {
|
|
|
|
for (size_t i = 0; i < size; i++) {
|
|
|
|
data[i] = dr(e);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2022-10-15 16:29:28 +08:00
|
|
|
template <int val> class ValGenerator : public DataGenerator {
|
2022-09-27 10:41:12 +08:00
|
|
|
public:
|
2022-10-15 16:29:28 +08:00
|
|
|
virtual ~ValGenerator() {}
|
2022-09-27 10:41:12 +08:00
|
|
|
|
|
|
|
private:
|
|
|
|
template <typename T> void fill(T *data, size_t size) {
|
|
|
|
for (size_t i = 0; i < size; i++) {
|
2022-10-15 16:29:28 +08:00
|
|
|
data[i] = val;
|
2022-09-27 10:41:12 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
void fill(uint32_t *data, size_t size) override {
|
|
|
|
fill<uint32_t>(data, size);
|
|
|
|
}
|
|
|
|
void fill(float *data, size_t size) override { fill<float>(data, size); }
|
2024-01-15 11:02:13 +08:00
|
|
|
void fill_fp16(uint16_t *data, size_t size) {
|
|
|
|
for (size_t i = 0; i < size; i++) {
|
|
|
|
float x = 1.0f * val;
|
|
|
|
data[i] = float_to_fp16(x);
|
|
|
|
}
|
|
|
|
}
|
2022-09-27 10:41:12 +08:00
|
|
|
};
|
2022-10-15 16:29:28 +08:00
|
|
|
// Convenience aliases for the two most common constant fills.
using OneGenerator = ValGenerator<1>;
using ZeroGenerator = ValGenerator<0>;
|
|
|
|
} // namespace infini
|