forked from jiuyuan/InfiniTensor
- Remove dataType from the kernel registration.
This commit is contained in:
parent
67974aee8a
commit
4db6699e09
|
@ -29,7 +29,6 @@ class Kernel {
|
|||
public:
|
||||
Kernel() {}
|
||||
virtual ~Kernel() {}
|
||||
|
||||
/**
|
||||
* @param op The operator to be executed.
|
||||
* @param record The parameters for kernel execution. If extra parameters
|
||||
|
@ -105,8 +104,7 @@ class KernelRegistry {
|
|||
IT_ASSERT(it != kernels.end(),
|
||||
"Kernel not found for key {" +
|
||||
to_string(enum_to_underlying(std::get<0>(kernelAttrs))) +
|
||||
", " + std::to_string(std::get<1>(kernelAttrs)) + ", " +
|
||||
std::get<2>(kernelAttrs).toString() + "}");
|
||||
", " + std::to_string(std::get<1>(kernelAttrs)) + "}");
|
||||
return std::get<0>(it->second);
|
||||
}
|
||||
const KernelRecord &getKernelItem(const KernelAttrs &kernelAttrs) const {
|
||||
|
@ -131,15 +129,16 @@ class CpuKernelWithoutConfig : public Kernel {
|
|||
|
||||
} // namespace infini
|
||||
|
||||
#define _REGISTER_KERNEL_1(device, opType, dataType, kernel, name, cnt) \
|
||||
#define _REGISTER_KERNEL_1(device, opType, kernel, name, cnt) \
|
||||
namespace infini { \
|
||||
static const bool _CAT(_register_kernel_, cnt) = \
|
||||
KernelRegistry::getInstance().registerKernel( \
|
||||
KernelAttrs{device, opType, dataType}, new kernel(), name); \
|
||||
KernelRegistry::getInstance().registerKernel(KernelAttrs{device, \
|
||||
opType}, \
|
||||
new kernel(), name); \
|
||||
}
|
||||
|
||||
#define REGISTER_KERNEL(device, opType, dataType, kernel, name) \
|
||||
_REGISTER_KERNEL_1(device, opType, dataType, kernel, name, __COUNTER__)
|
||||
#define REGISTER_KERNEL(device, opType, kernel, name) \
|
||||
_REGISTER_KERNEL_1(device, opType, kernel, name, __COUNTER__)
|
||||
|
||||
#define _REGISTER_CONSTRUCTOR_1(type, constructor, cnt) \
|
||||
namespace infini { \
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
#include "core/tensor.h"
|
||||
|
||||
namespace infini {
|
||||
using KernelAttrs = std::tuple<Device, OpType::underlying_t, DataType>;
|
||||
using KernelAttrs = std::tuple<Device, OpType::underlying_t>;
|
||||
|
||||
struct OpPerfKey {
|
||||
HashType hash;
|
||||
|
@ -90,6 +90,8 @@ class OperatorObj : public Object {
|
|||
OpType getOpType() const { return type; }
|
||||
// HACK: set correct data type
|
||||
DataType getDType() const { return getInputs(0)->getDType(); }
|
||||
DataType getInDType() const { return getInputs(0)->getDType(); }
|
||||
DataType getOutDType() const { return getOutput()->getDType(); }
|
||||
virtual int numInputs() const = 0;
|
||||
virtual int numOutputs() const = 0;
|
||||
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
#pragma once
|
||||
#include "core/tensor.h"
|
||||
#include "cuda/cuda_common.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
void cudaPrintFloat(float *x, int len);
|
||||
|
||||
void cudaPrintTensor(const Tensor &tensor) {
|
||||
cudaPrintFloat(tensor->getRawDataPtr<float *>(), tensor->size());
|
||||
}
|
||||
void cudaPrintTensor(const Tensor &tensor);
|
||||
|
||||
} // namespace infini
|
||||
cudnnDataType_t cudnnDataTypeConvert(DataType dataType);
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -17,8 +17,7 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
|
|||
std::map<OpType, int> opCnt;
|
||||
|
||||
for (auto &op : graph->getOperators()) {
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
|
||||
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
|
@ -66,8 +65,7 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const {
|
|||
std::map<OpType, int> opCnt;
|
||||
|
||||
for (auto &op : graph->getOperators()) {
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
|
||||
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
|
|
|
@ -25,8 +25,7 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
|
|||
auto &perfEngine = PerfEngine::getInstance();
|
||||
for (auto &op : graph->getOperators()) {
|
||||
// HACK: set correct data type
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
|
||||
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
|
@ -48,8 +47,7 @@ void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const {
|
|||
std::map<OpType, int> opCnt;
|
||||
for (auto &op : graph->getOperators()) {
|
||||
// HACK: set correct data type
|
||||
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying(),
|
||||
DataType::Float32};
|
||||
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
#include "core/data_type.h"
|
||||
#include "cuda/cuda_common.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include <cstdio>
|
||||
|
||||
__global__ void cudaPrintFloatImpl(float *x, int len) {
|
||||
|
@ -18,4 +20,39 @@ void cudaPrintFloat(float *x, int len) {
|
|||
cudaDeviceSynchronize();
|
||||
}
|
||||
|
||||
void cudaPrintTensor(const Tensor &tensor) {
|
||||
cudaPrintFloat(tensor->getRawDataPtr<float *>(), tensor->size());
|
||||
}
|
||||
|
||||
cudnnDataType_t cudnnDataTypeConvert(DataType dataType) {
|
||||
if (dataType == DataType::Float32) {
|
||||
return CUDNN_DATA_FLOAT;
|
||||
}
|
||||
if (dataType == DataType::Double) {
|
||||
return CUDNN_DATA_DOUBLE;
|
||||
}
|
||||
if (dataType == DataType::Float16) {
|
||||
return CUDNN_DATA_HALF;
|
||||
}
|
||||
if (dataType == DataType::Int8) {
|
||||
return CUDNN_DATA_INT8;
|
||||
}
|
||||
if (dataType == DataType::Int32) {
|
||||
return CUDNN_DATA_INT32;
|
||||
}
|
||||
if (dataType == DataType::UInt8) {
|
||||
return CUDNN_DATA_UINT8;
|
||||
}
|
||||
if (dataType == DataType::BFloat16) {
|
||||
return CUDNN_DATA_BFLOAT16;
|
||||
}
|
||||
if (dataType == DataType::Int64) {
|
||||
return CUDNN_DATA_INT64;
|
||||
}
|
||||
if (dataType == DataType::Bool) {
|
||||
return CUDNN_DATA_BOOLEAN;
|
||||
}
|
||||
IT_ASSERT(false, "Unsupported data type");
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
Loading…
Reference in New Issue