forked from jiuyuan/InfiniTensor
add cuda cast & support half-precision for gather
This commit is contained in:
parent
5ed7db1506
commit
c29dcf1e6d
|
@ -15,6 +15,8 @@ template <typename T> void erf_kernel(T *input, T *output, size_t num);
|
|||
template <typename T> void hard_sigmoid_kernel(T *input, T *output, size_t num);
|
||||
template <typename T> void hard_swish_kernel(T *input, T *output, size_t num);
|
||||
|
||||
template<typename INPUT, typename OUTPUT> void cast_kernel(INPUT *input, OUTPUT *output, size_t num);
|
||||
|
||||
void unary_kernel(const Operator &_op);
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -53,7 +53,8 @@ inline void initGatherMetaData(GatherMetaData &metaData,
|
|||
metaData.inStride[i] = in->getStride()[i];
|
||||
}
|
||||
}
|
||||
void gather_kernel(float *in, float *out, GatherMetaData metaData, size_t num);
|
||||
template<typename T>
|
||||
void gather_kernel(T *in, T *out, GatherMetaData metaData, size_t num);
|
||||
|
||||
void gather_elements_kernel(void *in, void *out, GatherMetaData metaData,
|
||||
size_t num);
|
||||
|
|
|
@ -10,14 +10,23 @@ class GatherCuda : public CudaKernelWithoutConfig {
|
|||
const RuntimeObj *_context) const override {
|
||||
|
||||
auto input = op->getInputs(0);
|
||||
auto index = op->getInputs(1);
|
||||
auto index = op->getInputs(1);
|
||||
|
||||
GatherMetaData metaData;
|
||||
initGatherMetaData(metaData, op);
|
||||
|
||||
auto inData = input->getRawDataPtr<float *>();
|
||||
auto outData = op->getOutput()->getRawDataPtr<float *>();
|
||||
gather_kernel(inData, outData, metaData, op->getOutput()->size());
|
||||
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
|
||||
|
||||
if (op->getDType() == DataType::Float32) {
|
||||
gather_kernel<float>((float*)inputData, (float*)outputData, metaData, op->getOutput()->size());
|
||||
}
|
||||
else if(op->getDType() == DataType::Float32){
|
||||
gather_kernel<half>((half*)inputData, (half*)outputData, metaData, op->getOutput()->size());
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -28,27 +28,29 @@ __device__ T gatheredOffset2Offset(int gOffset,
|
|||
return offset;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void _gather_kernel(float *in, float *out,
|
||||
template <typename dataT, typename T>
|
||||
__global__ void _gather_kernel(dataT *in, dataT *out,
|
||||
infini::GatherMetaData metaData, size_t num) {
|
||||
T tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
while (tid < num) {
|
||||
if (tid < num) {
|
||||
T offset = gatheredOffset2Offset<T>(tid, metaData);
|
||||
out[tid] = in[offset];
|
||||
tid += stride;
|
||||
}
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
void gather_kernel(float *in, float *out, GatherMetaData metaData, size_t num) {
|
||||
template<typename T>
|
||||
void gather_kernel(T *in, T *out, GatherMetaData metaData, size_t num) {
|
||||
int blockSize = 32 * 16;
|
||||
int gridSize = (num + blockSize - 1) / blockSize;
|
||||
if (metaData.indexType == DataType::Int64) {
|
||||
_gather_kernel<int64_t>
|
||||
_gather_kernel<T, int64_t>
|
||||
<<<gridSize, blockSize>>>(in, out, metaData, num);
|
||||
} else {
|
||||
_gather_kernel<int><<<gridSize, blockSize>>>(in, out, metaData, num);
|
||||
_gather_kernel<T,int><<<gridSize, blockSize>>>(in, out, metaData, num);
|
||||
}
|
||||
}
|
||||
template void gather_kernel<float>(float *in, float *out, GatherMetaData metaData, size_t num);
|
||||
template void gather_kernel<half>(half *in, half *out, GatherMetaData metaData, size_t num);
|
||||
template void gather_kernel<int8_t>(int8_t *in, int8_t *out, GatherMetaData metaData, size_t num);
|
||||
} // namespace infini
|
||||
|
|
|
@ -13,6 +13,42 @@ class UnaryCuda : public CudaKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
class CastCuda : public CudaKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<CastObj>(_op);
|
||||
|
||||
size_t num = op->getOutput()->size();
|
||||
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
|
||||
if (op->getType() == CastType::Float162Float) {
|
||||
IT_ASSERT(op->getInDType() == DataType::Float16 && op->getOutDType() == DataType::Float32);
|
||||
cast_kernel<half, float>((half *)inputData, (float*)outputData, num);
|
||||
}
|
||||
else if(op->getType() == CastType::Float2Float16){
|
||||
IT_ASSERT(op->getInDType() == DataType::Float32 && op->getOutDType() == DataType::Float16);
|
||||
cast_kernel<float, half>((float*)inputData, (half *)outputData, num);
|
||||
}
|
||||
else if(op->getType() == CastType::Float2Int32){
|
||||
IT_ASSERT(op->getInDType() == DataType::Float32 && op->getOutDType() == DataType::Int32);
|
||||
cast_kernel<float, int32_t>((float*)inputData, (int32_t *)outputData, num);
|
||||
}
|
||||
else if(op->getType() == CastType::Float2Int8){
|
||||
IT_ASSERT(op->getInDType() == DataType::Float32 && op->getOutDType() == DataType::Int8);
|
||||
cast_kernel<float, int8_t>((float*)inputData, (int8_t *)outputData, num);
|
||||
}
|
||||
else if(op->getType() == CastType::Int82Float){
|
||||
IT_ASSERT(op->getInDType() == DataType::Int8 && op->getOutDType() == DataType::Float32);
|
||||
cast_kernel<int8_t, float>((int8_t*)inputData, (float *)outputData, num);
|
||||
}
|
||||
else{
|
||||
IT_ASSERT(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class ActivationCudnn : public CudaKernelWithoutConfig {
|
||||
virtual cudnnActivationMode_t getOpType() const = 0;
|
||||
virtual tuple<float, float> getAlphBeta() const { return {1.f, 0.f}; }
|
||||
|
@ -145,6 +181,8 @@ REGISTER_KERNEL(Device::CUDA, OpType::Gelu, UnaryCuda, "Gelu_CUDA");
|
|||
REGISTER_KERNEL(Device::CUDA, OpType::Neg, UnaryCuda, "Neg_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Erf, UnaryCuda, "Erf_CUDA");
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Cast, CastCuda, "Cast_CUDA");
|
||||
|
||||
// REGISTER_KERNEL(Device::CUDA, OpType::Softmax, UnaryCuda, "Softmax_CUDA");
|
||||
// REGISTER_KERNEL(Device::CUDA, OpType::Relu, UnaryCuda,
|
||||
// "Relu_CUDA");
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
#include "cuda/cuda_common.h"
|
||||
#include "cuda/cuda_unary.h"
|
||||
#include <math.h>
|
||||
#include <cub/cub.cuh>
|
||||
|
||||
using infini::E_CONSTANT;
|
||||
constexpr unsigned int num_threads() { return 32 * 4; }
|
||||
|
@ -120,6 +121,17 @@ __global__ void _neg_kernel(T *input, T *output, size_t n) {
|
|||
}
|
||||
}
|
||||
|
||||
template<typename INPUT, typename OUTPUT>
|
||||
__global__ void _cast_kernel(INPUT *input, OUTPUT *output, size_t n) {
|
||||
|
||||
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
|
||||
if(index < n){
|
||||
cub::CastOp<OUTPUT> _CastOp;
|
||||
output[index] = _CastOp(input[index]);
|
||||
}
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
template <typename T> void softmax_kernel(T *input, T *output, size_t num) {
|
||||
|
||||
|
@ -275,4 +287,18 @@ void unary_kernel(const Operator &_op) {
|
|||
IT_TODO_HALT();
|
||||
}
|
||||
|
||||
template<typename INPUT, typename OUTPUT>
|
||||
void cast_kernel(INPUT *input, OUTPUT *output, size_t num) {
|
||||
|
||||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_cast_kernel<INPUT, OUTPUT><<<gridsize, blocksize>>>(input, output, num);
|
||||
}
|
||||
|
||||
template void cast_kernel<float, half>(float *input, half *output, size_t num);
|
||||
template void cast_kernel<half, float>(half *input, float *output, size_t num);
|
||||
template void cast_kernel<float, int32_t>(float *input, int32_t *output, size_t num);
|
||||
template void cast_kernel<float, int8_t>(float *input, int8_t *output, size_t num);
|
||||
template void cast_kernel<int8_t, float>(int8_t *input, float *output, size_t num);
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -40,24 +40,67 @@ void testUnary(const std::function<void(void *, size_t, DataType)> &generator,
|
|||
EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu));
|
||||
}
|
||||
|
||||
TEST(cuDNN_Unary, run) {
|
||||
testUnary<ReluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<AbsObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<HardSigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<HardSwishObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<SqrtObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<NegObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<ErfObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
// more shapes
|
||||
testUnary<SqrtObj>(IncrementalGenerator(), Shape{13});
|
||||
testUnary<SqrtObj>(IncrementalGenerator(), Shape{4, 3});
|
||||
testUnary<SqrtObj>(IncrementalGenerator(), Shape{2, 3, 4, 5, 6});
|
||||
template <class T>
|
||||
void testCast(const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape, vector<float> ansVec) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
|
||||
testUnary<GeluObj>(IncrementalGenerator(), Shape{1});
|
||||
testUnary<GeluObj>(IncrementalGenerator(), Shape{1, 2});
|
||||
testUnary<GeluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
inputCpu->dataMalloc();
|
||||
inputCpu->setData(generator);
|
||||
|
||||
// GPU
|
||||
Graph cudaGraph = make_ref<GraphObj>(cudaRuntime);
|
||||
auto inputGpu = cudaGraph->cloneTensor(inputCpu);
|
||||
auto gpuOp = cudaGraph->addOp<T>(inputGpu, nullptr, CastType::Float2Float16);
|
||||
cudaGraph->dataMalloc();
|
||||
inputGpu->setData(generator);
|
||||
cudaRuntime->run(cudaGraph);
|
||||
auto outputGpu = gpuOp->getOutput();
|
||||
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||
|
||||
inputCpu->printData();
|
||||
outputGpu2Cpu->printData();
|
||||
EXPECT_TRUE(outputGpu2Cpu->equalData(ansVec));
|
||||
|
||||
// GPU
|
||||
//Graph cudaGraph2 = make_ref<GraphObj>(cudaRuntime);
|
||||
////auto inputGpu2 = cudaGraph2->cloneTensor(outputGpu2Cpu);
|
||||
//auto gpuOp2 = cudaGraph2->addOp<T>(outputGpu, nullptr, CastType::Float162Float);
|
||||
//cudaGraph2->dataMalloc();
|
||||
////inputGpu2->setData(generator);
|
||||
//cudaRuntime->run(cudaGraph2);
|
||||
//auto outputGpu2 = gpuOp2->getOutput();
|
||||
//auto outputGpu2Cpu2 = outputGpu2->clone(cpuRuntime);
|
||||
|
||||
//outputGpu2Cpu->printData();
|
||||
//outputGpu2Cpu2->printData();
|
||||
//EXPECT_TRUE(1);
|
||||
}
|
||||
|
||||
TEST(cuDNN_Unary, run) {
|
||||
//testUnary<ReluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
//testUnary<AbsObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
//testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
//testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
//testUnary<HardSigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
//testUnary<HardSwishObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
//testUnary<SqrtObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
//testUnary<NegObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
//testUnary<ErfObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
//// more shapes
|
||||
//testUnary<SqrtObj>(IncrementalGenerator(), Shape{13});
|
||||
//testUnary<SqrtObj>(IncrementalGenerator(), Shape{4, 3});
|
||||
//testUnary<SqrtObj>(IncrementalGenerator(), Shape{2, 3, 4, 5, 6});
|
||||
|
||||
//testUnary<GeluObj>(IncrementalGenerator(), Shape{1});
|
||||
//testUnary<GeluObj>(IncrementalGenerator(), Shape{1, 2});
|
||||
//testUnary<GeluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
|
||||
testCast<CastObj>(IncrementalGenerator(), Shape{8, 1}, vector<float>{0, 1, 2, 3, 4, 5, 6, 7});
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -13,8 +13,9 @@ TEST(Unary, ShapeInference) {
|
|||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i0 = g->addTensor({2}, DataType::Float32);
|
||||
auto op = g->addOp<GeluObj>(i0, nullptr);
|
||||
auto op = g->addOp<CastObj>(i0, nullptr, CastType::Float2Float16);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2}));
|
||||
EXPECT_EQ(op->getOutDType(), (DataType::Float16));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue