style: fix style

This commit is contained in:
OdinaryWord 2023-12-14 11:32:07 +08:00
parent c29dcf1e6d
commit db8c3eec15
7 changed files with 86 additions and 75 deletions

View File

@ -15,7 +15,8 @@ template <typename T> void erf_kernel(T *input, T *output, size_t num);
template <typename T> void hard_sigmoid_kernel(T *input, T *output, size_t num); template <typename T> void hard_sigmoid_kernel(T *input, T *output, size_t num);
template <typename T> void hard_swish_kernel(T *input, T *output, size_t num); template <typename T> void hard_swish_kernel(T *input, T *output, size_t num);
template<typename INPUT, typename OUTPUT> void cast_kernel(INPUT *input, OUTPUT *output, size_t num); template <typename INPUT, typename OUTPUT>
void cast_kernel(INPUT *input, OUTPUT *output, size_t num);
void unary_kernel(const Operator &_op); void unary_kernel(const Operator &_op);

View File

@ -53,7 +53,7 @@ inline void initGatherMetaData(GatherMetaData &metaData,
metaData.inStride[i] = in->getStride()[i]; metaData.inStride[i] = in->getStride()[i];
} }
} }
template<typename T> template <typename T>
void gather_kernel(T *in, T *out, GatherMetaData metaData, size_t num); void gather_kernel(T *in, T *out, GatherMetaData metaData, size_t num);
void gather_elements_kernel(void *in, void *out, GatherMetaData metaData, void gather_elements_kernel(void *in, void *out, GatherMetaData metaData,

View File

@ -18,15 +18,13 @@ class GatherCuda : public CudaKernelWithoutConfig {
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const outputData = (op->getOutput()->getRawDataPtr<void *>()); void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
if (op->getDType() == DataType::Float32) { if (op->getDType() == DataType::Float32) {
gather_kernel<float>((float*)inputData, (float*)outputData, metaData, op->getOutput()->size()); gather_kernel<float>((float *)inputData, (float *)outputData,
metaData, op->getOutput()->size());
} else if (op->getDType() == DataType::Float32) {
gather_kernel<half>((half *)inputData, (half *)outputData, metaData,
op->getOutput()->size());
} }
else if(op->getDType() == DataType::Float32){
gather_kernel<half>((half*)inputData, (half*)outputData, metaData, op->getOutput()->size());
}
} }
}; };

View File

@ -39,7 +39,7 @@ __global__ void _gather_kernel(dataT *in, dataT *out,
} }
namespace infini { namespace infini {
template<typename T> template <typename T>
void gather_kernel(T *in, T *out, GatherMetaData metaData, size_t num) { void gather_kernel(T *in, T *out, GatherMetaData metaData, size_t num) {
int blockSize = 32 * 16; int blockSize = 32 * 16;
int gridSize = (num + blockSize - 1) / blockSize; int gridSize = (num + blockSize - 1) / blockSize;
@ -47,10 +47,13 @@ void gather_kernel(T *in, T *out, GatherMetaData metaData, size_t num) {
_gather_kernel<T, int64_t> _gather_kernel<T, int64_t>
<<<gridSize, blockSize>>>(in, out, metaData, num); <<<gridSize, blockSize>>>(in, out, metaData, num);
} else { } else {
_gather_kernel<T,int><<<gridSize, blockSize>>>(in, out, metaData, num); _gather_kernel<T, int><<<gridSize, blockSize>>>(in, out, metaData, num);
} }
} }
template void gather_kernel<float>(float *in, float *out, GatherMetaData metaData, size_t num); template void gather_kernel<float>(float *in, float *out,
template void gather_kernel<half>(half *in, half *out, GatherMetaData metaData, size_t num); GatherMetaData metaData, size_t num);
template void gather_kernel<int8_t>(int8_t *in, int8_t *out, GatherMetaData metaData, size_t num); template void gather_kernel<half>(half *in, half *out, GatherMetaData metaData,
size_t num);
template void gather_kernel<int8_t>(int8_t *in, int8_t *out,
GatherMetaData metaData, size_t num);
} // namespace infini } // namespace infini

View File

@ -22,28 +22,32 @@ class CastCuda : public CudaKernelWithoutConfig {
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const outputData = (op->getOutput()->getRawDataPtr<void *>()); void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
if (op->getType() == CastType::Float162Float) { if (op->getType() == CastType::Float162Float) {
IT_ASSERT(op->getInDType() == DataType::Float16 && op->getOutDType() == DataType::Float32); IT_ASSERT(op->getInDType() == DataType::Float16 &&
cast_kernel<half, float>((half *)inputData, (float*)outputData, num); op->getOutDType() == DataType::Float32);
} cast_kernel<half, float>((half *)inputData, (float *)outputData,
else if(op->getType() == CastType::Float2Float16){ num);
IT_ASSERT(op->getInDType() == DataType::Float32 && op->getOutDType() == DataType::Float16); } else if (op->getType() == CastType::Float2Float16) {
cast_kernel<float, half>((float*)inputData, (half *)outputData, num); IT_ASSERT(op->getInDType() == DataType::Float32 &&
} op->getOutDType() == DataType::Float16);
else if(op->getType() == CastType::Float2Int32){ cast_kernel<float, half>((float *)inputData, (half *)outputData,
IT_ASSERT(op->getInDType() == DataType::Float32 && op->getOutDType() == DataType::Int32); num);
cast_kernel<float, int32_t>((float*)inputData, (int32_t *)outputData, num); } else if (op->getType() == CastType::Float2Int32) {
} IT_ASSERT(op->getInDType() == DataType::Float32 &&
else if(op->getType() == CastType::Float2Int8){ op->getOutDType() == DataType::Int32);
IT_ASSERT(op->getInDType() == DataType::Float32 && op->getOutDType() == DataType::Int8); cast_kernel<float, int32_t>((float *)inputData,
cast_kernel<float, int8_t>((float*)inputData, (int8_t *)outputData, num); (int32_t *)outputData, num);
} } else if (op->getType() == CastType::Float2Int8) {
else if(op->getType() == CastType::Int82Float){ IT_ASSERT(op->getInDType() == DataType::Float32 &&
IT_ASSERT(op->getInDType() == DataType::Int8 && op->getOutDType() == DataType::Float32); op->getOutDType() == DataType::Int8);
cast_kernel<int8_t, float>((int8_t*)inputData, (float *)outputData, num); cast_kernel<float, int8_t>((float *)inputData, (int8_t *)outputData,
} num);
else{ } else if (op->getType() == CastType::Int82Float) {
IT_ASSERT(op->getInDType() == DataType::Int8 &&
op->getOutDType() == DataType::Float32);
cast_kernel<int8_t, float>((int8_t *)inputData, (float *)outputData,
num);
} else {
IT_ASSERT(false); IT_ASSERT(false);
} }
} }

View File

@ -2,8 +2,8 @@
#include "core/constants.h" #include "core/constants.h"
#include "cuda/cuda_common.h" #include "cuda/cuda_common.h"
#include "cuda/cuda_unary.h" #include "cuda/cuda_unary.h"
#include <math.h>
#include <cub/cub.cuh> #include <cub/cub.cuh>
#include <math.h>
using infini::E_CONSTANT; using infini::E_CONSTANT;
constexpr unsigned int num_threads() { return 32 * 4; } constexpr unsigned int num_threads() { return 32 * 4; }
@ -121,12 +121,12 @@ __global__ void _neg_kernel(T *input, T *output, size_t n) {
} }
} }
template<typename INPUT, typename OUTPUT> template <typename INPUT, typename OUTPUT>
__global__ void _cast_kernel(INPUT *input, OUTPUT *output, size_t n) { __global__ void _cast_kernel(INPUT *input, OUTPUT *output, size_t n) {
size_t index = threadIdx.x + blockIdx.x * blockDim.x; size_t index = threadIdx.x + blockIdx.x * blockDim.x;
if(index < n){ if (index < n) {
cub::CastOp<OUTPUT> _CastOp; cub::CastOp<OUTPUT> _CastOp;
output[index] = _CastOp(input[index]); output[index] = _CastOp(input[index]);
} }
@ -287,7 +287,7 @@ void unary_kernel(const Operator &_op) {
IT_TODO_HALT(); IT_TODO_HALT();
} }
template<typename INPUT, typename OUTPUT> template <typename INPUT, typename OUTPUT>
void cast_kernel(INPUT *input, OUTPUT *output, size_t num) { void cast_kernel(INPUT *input, OUTPUT *output, size_t num) {
int blocksize = block_work_size(); int blocksize = block_work_size();
@ -297,8 +297,11 @@ void cast_kernel(INPUT *input, OUTPUT *output, size_t num) {
template void cast_kernel<float, half>(float *input, half *output, size_t num); template void cast_kernel<float, half>(float *input, half *output, size_t num);
template void cast_kernel<half, float>(half *input, float *output, size_t num); template void cast_kernel<half, float>(half *input, float *output, size_t num);
template void cast_kernel<float, int32_t>(float *input, int32_t *output, size_t num); template void cast_kernel<float, int32_t>(float *input, int32_t *output,
template void cast_kernel<float, int8_t>(float *input, int8_t *output, size_t num); size_t num);
template void cast_kernel<int8_t, float>(int8_t *input, float *output, size_t num); template void cast_kernel<float, int8_t>(float *input, int8_t *output,
size_t num);
template void cast_kernel<int8_t, float>(int8_t *input, float *output,
size_t num);
}; // namespace infini }; // namespace infini

View File

@ -42,7 +42,7 @@ void testUnary(const std::function<void(void *, size_t, DataType)> &generator,
template <class T> template <class T>
void testCast(const std::function<void(void *, size_t, DataType)> &generator, void testCast(const std::function<void(void *, size_t, DataType)> &generator,
const Shape &shape, vector<float> ansVec) { const Shape &shape, vector<float> ansVec) {
// Runtime // Runtime
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
auto cudaRuntime = make_ref<CudaRuntimeObj>(); auto cudaRuntime = make_ref<CudaRuntimeObj>();
@ -55,7 +55,8 @@ void testCast(const std::function<void(void *, size_t, DataType)> &generator,
// GPU // GPU
Graph cudaGraph = make_ref<GraphObj>(cudaRuntime); Graph cudaGraph = make_ref<GraphObj>(cudaRuntime);
auto inputGpu = cudaGraph->cloneTensor(inputCpu); auto inputGpu = cudaGraph->cloneTensor(inputCpu);
auto gpuOp = cudaGraph->addOp<T>(inputGpu, nullptr, CastType::Float2Float16); auto gpuOp =
cudaGraph->addOp<T>(inputGpu, nullptr, CastType::Float2Float16);
cudaGraph->dataMalloc(); cudaGraph->dataMalloc();
inputGpu->setData(generator); inputGpu->setData(generator);
cudaRuntime->run(cudaGraph); cudaRuntime->run(cudaGraph);
@ -67,40 +68,41 @@ void testCast(const std::function<void(void *, size_t, DataType)> &generator,
EXPECT_TRUE(outputGpu2Cpu->equalData(ansVec)); EXPECT_TRUE(outputGpu2Cpu->equalData(ansVec));
// GPU // GPU
//Graph cudaGraph2 = make_ref<GraphObj>(cudaRuntime); // Graph cudaGraph2 = make_ref<GraphObj>(cudaRuntime);
////auto inputGpu2 = cudaGraph2->cloneTensor(outputGpu2Cpu); ////auto inputGpu2 = cudaGraph2->cloneTensor(outputGpu2Cpu);
//auto gpuOp2 = cudaGraph2->addOp<T>(outputGpu, nullptr, CastType::Float162Float); // auto gpuOp2 = cudaGraph2->addOp<T>(outputGpu, nullptr,
//cudaGraph2->dataMalloc(); // CastType::Float162Float); cudaGraph2->dataMalloc();
////inputGpu2->setData(generator); ////inputGpu2->setData(generator);
//cudaRuntime->run(cudaGraph2); // cudaRuntime->run(cudaGraph2);
//auto outputGpu2 = gpuOp2->getOutput(); // auto outputGpu2 = gpuOp2->getOutput();
//auto outputGpu2Cpu2 = outputGpu2->clone(cpuRuntime); // auto outputGpu2Cpu2 = outputGpu2->clone(cpuRuntime);
//outputGpu2Cpu->printData(); // outputGpu2Cpu->printData();
//outputGpu2Cpu2->printData(); // outputGpu2Cpu2->printData();
//EXPECT_TRUE(1); // EXPECT_TRUE(1);
} }
TEST(cuDNN_Unary, run) { TEST(cuDNN_Unary, run) {
//testUnary<ReluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3}); // testUnary<ReluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
//testUnary<AbsObj>(IncrementalGenerator(), Shape{1, 2, 2, 3}); // testUnary<AbsObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
//testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3}); // testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
//testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3}); // testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
//testUnary<HardSigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3}); // testUnary<HardSigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
//testUnary<HardSwishObj>(IncrementalGenerator(), Shape{1, 2, 2, 3}); // testUnary<HardSwishObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
//testUnary<SqrtObj>(IncrementalGenerator(), Shape{1, 2, 2, 3}); // testUnary<SqrtObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
//testUnary<NegObj>(IncrementalGenerator(), Shape{1, 2, 2, 3}); // testUnary<NegObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
//testUnary<ErfObj>(IncrementalGenerator(), Shape{1, 2, 2, 3}); // testUnary<ErfObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
//// more shapes //// more shapes
//testUnary<SqrtObj>(IncrementalGenerator(), Shape{13}); // testUnary<SqrtObj>(IncrementalGenerator(), Shape{13});
//testUnary<SqrtObj>(IncrementalGenerator(), Shape{4, 3}); // testUnary<SqrtObj>(IncrementalGenerator(), Shape{4, 3});
//testUnary<SqrtObj>(IncrementalGenerator(), Shape{2, 3, 4, 5, 6}); // testUnary<SqrtObj>(IncrementalGenerator(), Shape{2, 3, 4, 5, 6});
//testUnary<GeluObj>(IncrementalGenerator(), Shape{1}); // testUnary<GeluObj>(IncrementalGenerator(), Shape{1});
//testUnary<GeluObj>(IncrementalGenerator(), Shape{1, 2}); // testUnary<GeluObj>(IncrementalGenerator(), Shape{1, 2});
//testUnary<GeluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3}); // testUnary<GeluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testCast<CastObj>(IncrementalGenerator(), Shape{8, 1}, vector<float>{0, 1, 2, 3, 4, 5, 6, 7}); testCast<CastObj>(IncrementalGenerator(), Shape{8, 1},
vector<float>{0, 1, 2, 3, 4, 5, 6, 7});
} }
} // namespace infini } // namespace infini