#pragma once #include "core/common.h" #include "core/tensor_base.h" #include "utils/data_convert.h" #include namespace infini { // TODO: isolate these class class DataGenerator { private: virtual void fill(uint32_t *data, size_t size) { IT_TODO_HALT(); } virtual void fill(float *data, size_t size) { IT_TODO_HALT(); } virtual void fill_fp16(uint16_t *data, size_t size) { IT_TODO_HALT(); } public: virtual ~DataGenerator() {} void operator()(void *data, size_t size, DataType dataType) { if (dataType == DataType::UInt32) fill(reinterpret_cast(data), size); else if (dataType == DataType::Float32) fill(reinterpret_cast(data), size); else if (dataType == DataType::Float16) fill_fp16(reinterpret_cast(data), size); else IT_TODO_HALT(); } }; class IncrementalGenerator : public DataGenerator { public: virtual ~IncrementalGenerator() {} private: template void fill(T *data, size_t size) { for (size_t i = 0; i < size; i++) { data[i] = i; } } void fill(uint32_t *data, size_t size) override { fill(data, size); } void fill(float *data, size_t size) override { fill(data, size); } // FIXME: fix the accuracy standards when dtype is float16 void fill_fp16(uint16_t *data, size_t size) { for (size_t i = 0; i < size; i++) { float x = 2.0f; data[i] = float_to_fp16(x); } } }; class RandomGenerator : public DataGenerator { private: double l, r; std::mt19937 e; std::uniform_int_distribution di; std::uniform_real_distribution dr; public: RandomGenerator(double l = 0, double r = 1, unsigned int seed = 0) : l(l), r(r), e(seed), di(l, r), dr(l, r) {} virtual ~RandomGenerator() {} private: void fill(uint32_t *data, size_t size) override { for (size_t i = 0; i < size; i++) { data[i] = di(e); } } void fill(float *data, size_t size) override { for (size_t i = 0; i < size; i++) { data[i] = dr(e); } } }; template class ValGenerator : public DataGenerator { public: virtual ~ValGenerator() {} private: template void fill(T *data, size_t size) { for (size_t i = 0; i < size; i++) { data[i] = val; } } void fill(uint32_t *data, size_t size) override { fill(data, size); } void fill(float *data, size_t size) override { fill(data, size); } }; typedef ValGenerator<1> OneGenerator; typedef ValGenerator<0> ZeroGenerator; } // namespace infini