2022-07-31 21:43:26 +08:00
|
|
|
#pragma once
|
|
|
|
#include "core/common.h"
|
2022-08-17 14:16:01 +08:00
|
|
|
#include "core/tensor_base.h"
|
2022-07-31 21:43:26 +08:00
|
|
|
#include "gtest/gtest.h"
|
2022-08-17 14:16:01 +08:00
|
|
|
|
|
|
|
namespace infini {
|
|
|
|
|
2022-08-22 15:01:03 +08:00
|
|
|
// TODO: isolate these classes
|
2022-08-17 14:16:01 +08:00
|
|
|
class DataGenerator {
|
|
|
|
private:
|
2022-08-22 15:01:03 +08:00
|
|
|
virtual void fill(uint32_t *data, size_t size) { IT_TODO_HALT(); }
|
|
|
|
virtual void fill(float *data, size_t size) { IT_TODO_HALT(); }
|
2022-08-17 14:16:01 +08:00
|
|
|
|
|
|
|
public:
|
2022-08-22 15:01:03 +08:00
|
|
|
virtual ~DataGenerator() {}
|
2022-08-17 14:16:01 +08:00
|
|
|
void operator()(void *data, size_t size, DataType dataType) {
|
2022-08-23 16:55:59 +08:00
|
|
|
if (dataType == DataType::UInt32)
|
2022-08-17 14:16:01 +08:00
|
|
|
fill(reinterpret_cast<uint32_t *>(data), size);
|
2022-08-23 16:55:59 +08:00
|
|
|
else if (dataType == DataType::Float32)
|
2022-08-17 14:16:01 +08:00
|
|
|
fill(reinterpret_cast<float *>(data), size);
|
2022-08-23 16:55:59 +08:00
|
|
|
else
|
2022-08-17 14:16:01 +08:00
|
|
|
IT_TODO_HALT();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
class IncrementalGenerator : public DataGenerator {
|
2022-08-22 15:01:03 +08:00
|
|
|
public:
|
|
|
|
virtual ~IncrementalGenerator() {}
|
|
|
|
|
|
|
|
private:
|
|
|
|
template <typename T> void fill(T *data, size_t size) {
|
2022-08-17 14:16:01 +08:00
|
|
|
for (size_t i = 0; i < size; i++) {
|
|
|
|
data[i] = i;
|
|
|
|
}
|
|
|
|
}
|
2022-08-22 15:01:03 +08:00
|
|
|
|
|
|
|
void fill(uint32_t *data, size_t size) override {
|
|
|
|
fill<uint32_t>(data, size);
|
|
|
|
}
|
|
|
|
void fill(float *data, size_t size) override { fill<float>(data, size); }
|
|
|
|
};
|
|
|
|
|
|
|
|
class OneGenerator : public DataGenerator {
|
|
|
|
public:
|
|
|
|
virtual ~OneGenerator() {}
|
|
|
|
|
|
|
|
private:
|
|
|
|
template <typename T> void fill(T *data, size_t size) {
|
|
|
|
for (size_t i = 0; i < size; i++) {
|
|
|
|
data[i] = 1;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
void fill(uint32_t *data, size_t size) override {
|
|
|
|
fill<uint32_t>(data, size);
|
|
|
|
}
|
|
|
|
void fill(float *data, size_t size) override { fill<float>(data, size); }
|
2022-08-17 14:16:01 +08:00
|
|
|
};
|
|
|
|
} // namespace infini
|