- Remove dataType from the kernel registration.

This commit is contained in:
kilinchange 2023-11-30 13:51:24 +08:00
parent 67974aee8a
commit 4db6699e09
6 changed files with 57 additions and 21 deletions

View File

@ -29,7 +29,6 @@ class Kernel {
public:
Kernel() {}
virtual ~Kernel() {}
/**
* @param op The operator to be executed.
* @param record The parameters for kernel execution. If extra parameters
@ -105,8 +104,7 @@ class KernelRegistry {
IT_ASSERT(it != kernels.end(),
"Kernel not found for key {" +
to_string(enum_to_underlying(std::get<0>(kernelAttrs))) +
", " + std::to_string(std::get<1>(kernelAttrs)) + ", " +
std::get<2>(kernelAttrs).toString() + "}");
", " + std::to_string(std::get<1>(kernelAttrs)) + "}");
return std::get<0>(it->second);
}
const KernelRecord &getKernelItem(const KernelAttrs &kernelAttrs) const {
@ -131,15 +129,16 @@ class CpuKernelWithoutConfig : public Kernel {
} // namespace infini
#define _REGISTER_KERNEL_1(device, opType, dataType, kernel, name, cnt) \
#define _REGISTER_KERNEL_1(device, opType, kernel, name, cnt) \
namespace infini { \
static const bool _CAT(_register_kernel_, cnt) = \
KernelRegistry::getInstance().registerKernel( \
KernelAttrs{device, opType, dataType}, new kernel(), name); \
KernelRegistry::getInstance().registerKernel(KernelAttrs{device, \
opType}, \
new kernel(), name); \
}
#define REGISTER_KERNEL(device, opType, dataType, kernel, name) \
_REGISTER_KERNEL_1(device, opType, dataType, kernel, name, __COUNTER__)
#define REGISTER_KERNEL(device, opType, kernel, name) \
_REGISTER_KERNEL_1(device, opType, kernel, name, __COUNTER__)
#define _REGISTER_CONSTRUCTOR_1(type, constructor, cnt) \
namespace infini { \

View File

@ -4,7 +4,7 @@
#include "core/tensor.h"
namespace infini {
using KernelAttrs = std::tuple<Device, OpType::underlying_t, DataType>;
using KernelAttrs = std::tuple<Device, OpType::underlying_t>;
struct OpPerfKey {
HashType hash;
@ -90,6 +90,8 @@ class OperatorObj : public Object {
OpType getOpType() const { return type; }
// HACK: set correct data type
DataType getDType() const { return getInputs(0)->getDType(); }
DataType getInDType() const { return getInputs(0)->getDType(); }
DataType getOutDType() const { return getOutput()->getDType(); }
virtual int numInputs() const = 0;
virtual int numOutputs() const = 0;

View File

@ -1,11 +1,13 @@
#pragma once
#include "core/tensor.h"
#include "cuda/cuda_common.h"
namespace infini {
void cudaPrintFloat(float *x, int len);
void cudaPrintTensor(const Tensor &tensor) {
cudaPrintFloat(tensor->getRawDataPtr<float *>(), tensor->size());
}
void cudaPrintTensor(const Tensor &tensor);
} // namespace infini
cudnnDataType_t cudnnDataTypeConvert(DataType dataType);
} // namespace infini

View File

@ -17,8 +17,7 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
std::map<OpType, int> opCnt;
for (auto &op : graph->getOperators()) {
auto kernelAttrs =
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey);
@ -66,8 +65,7 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const {
std::map<OpType, int> opCnt;
for (auto &op : graph->getOperators()) {
auto kernelAttrs =
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey);

View File

@ -25,8 +25,7 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
auto &perfEngine = PerfEngine::getInstance();
for (auto &op : graph->getOperators()) {
// HACK: set correct data type
auto kernelAttrs =
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey);
@ -48,8 +47,7 @@ void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const {
std::map<OpType, int> opCnt;
for (auto &op : graph->getOperators()) {
// HACK: set correct data type
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying(),
DataType::Float32};
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey);

View File

@ -1,4 +1,6 @@
#include "core/data_type.h"
#include "cuda/cuda_common.h"
#include "cuda/cuda_utility.h"
#include <cstdio>
__global__ void cudaPrintFloatImpl(float *x, int len) {
@ -18,4 +20,39 @@ void cudaPrintFloat(float *x, int len) {
cudaDeviceSynchronize();
}
void cudaPrintTensor(const Tensor &tensor) {
cudaPrintFloat(tensor->getRawDataPtr<float *>(), tensor->size());
}
cudnnDataType_t cudnnDataTypeConvert(DataType dataType) {
if (dataType == DataType::Float32) {
return CUDNN_DATA_FLOAT;
}
if (dataType == DataType::Double) {
return CUDNN_DATA_DOUBLE;
}
if (dataType == DataType::Float16) {
return CUDNN_DATA_HALF;
}
if (dataType == DataType::Int8) {
return CUDNN_DATA_INT8;
}
if (dataType == DataType::Int32) {
return CUDNN_DATA_INT32;
}
if (dataType == DataType::UInt8) {
return CUDNN_DATA_UINT8;
}
if (dataType == DataType::BFloat16) {
return CUDNN_DATA_BFLOAT16;
}
if (dataType == DataType::Int64) {
return CUDNN_DATA_INT64;
}
if (dataType == DataType::Bool) {
return CUDNN_DATA_BOOLEAN;
}
IT_ASSERT(false, "Unsupported data type");
}
} // namespace infini