From a6c919b61df6ba82b0dba123e0ce00289dfaa382 Mon Sep 17 00:00:00 2001
From: xgqdut2016
Date: Thu, 7 Mar 2024 09:01:00 +0000
Subject: [PATCH] softmax stream kernels: overlap GDRAM-NRAM copies with
 compute

---
 src/kernels/mlu/src/bangSoftmax_device.mlu | 579 ++++++++++++---------
 1 file changed, 330 insertions(+), 249 deletions(-)
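Every kernel in this patch streams blocks of the input through NRAM while
maintaining a running (max, sum) pair and rescaling the partial sum whenever
the maximum grows (online softmax). As a reference for review, a minimal
host-side C sketch of that recurrence follows; the names are illustrative
and do not appear in the kernel:

#include <math.h>

/* One pass keeps (m, s) consistent with the prefix of x seen so far. */
void online_softmax(const float *x, float *y, int n) {
    float m = -INFINITY, s = 0.0f;
    for (int i = 0; i < n; i++) {
        float mNew = fmaxf(m, x[i]);
        s = s * expf(m - mNew) + expf(x[i] - mNew); /* rescale old sum */
        m = mNew;
    }
    for (int i = 0; i < n; i++)
        y[i] = expf(x[i] - m) / s;                  /* second pass writes */
}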
diff --git a/src/kernels/mlu/src/bangSoftmax_device.mlu b/src/kernels/mlu/src/bangSoftmax_device.mlu
index 597ea464..f0bdd30b 100644
--- a/src/kernels/mlu/src/bangSoftmax_device.mlu
+++ b/src/kernels/mlu/src/bangSoftmax_device.mlu
@@ -4,7 +4,7 @@
 const int NRAM_MAX_SIZE = 1024 * 512;//the maximum NRAM memory is 1024 * 768
 const int nramNum = NRAM_MAX_SIZE/sizeof(float);
 __nram__ float nram_buffer[nramNum];
-const int SRC_MAX_SIZE = 1024 * 128;//The subsequent tree summation must ensure that SRC-MAX-SIZE is a power of 2
+const int SRC_MAX_SIZE = 1024 * 64;//The subsequent tree summation requires that SRC_MAX_SIZE is a power of 2
 //4 * SRC_MAX_SIZE must <= NRAM_MAX_SIZE
 const int maxNum = SRC_MAX_SIZE/sizeof(float);
 const int warpSize = 32;
@@ -98,7 +98,7 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
     //-----------------------------------------allocate memory
     float* src = nram_buffer;
-    float* tmp = src + maxNum;
+    float* tmp = src + 3 * maxNum;
     float* tmpOldMax = tmp + strideS;
     float* tmpNewMax = tmpOldMax + strideS;
     float* tmpSum = tmpNewMax + strideS;
@@ -123,26 +123,31 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
     __bang_write_value(tmp, strideS, -INFINITY);//Must be initialized to negative infinity
     __bang_write_zero(tmpSum, strideS);//Must be initialized to zero
-    for(int j = 0; j < repeat; j++){
-      __memcpy(src, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM);
-      for(int m = 0; m < multiple; m++){
-        __memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM);
-
-        __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);//Although the stream S stream section after tmpNewMax is 0, there is no need to write back to GDRAM, which does not affect the result
-
-        __bang_sub(tmp, tmp, tmpNewMax, strideS);//The stripe S stripe section after tmp is 0
-        __bang_active_exp_less_0(tmp, tmp, strideS);
-        if(j != 0 || m != 0){
-          __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
-          __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
-          __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);//sum = sum * exp(oldM - newM)
-        }
-        __bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
-        //if(m == 0) __bang_printf("tmp:%.2f, tmpMax[0]:%.2f,tmpSum[0]:%.2f\n", tmp[1], tmpNewMax[1],tmpSum[0]);
-        __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM
+    for(int j = 0; j < repeat + 1; j++){
+      if(j < repeat){
+        __memcpy_async(src + j % 2 * maxNum, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM);
       }
+      if(j > 0){
+        for(int m = 0; m < multiple; m++){
+          __memcpy(tmp, src + (j - 1) % 2 * maxNum + m * stride, stride * sizeof(float), NRAM2NRAM);
+
+          __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);//the tail of tmpNewMax in [stride, strideS) stays 0; it is never written back to GDRAM, so it does not affect the result
+
+          __bang_sub(tmp, tmp, tmpNewMax, strideS);//the tail of tmp in [stride, strideS) is 0
+          __bang_active_exp_less_0(tmp, tmp, strideS);
+          if(j != 1 || m != 0){
+            __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
+            __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
+            __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);//sum = sum * exp(oldM - newM)
+          }
+          __bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
+
+          __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM
+        }
+      }
+      __sync_all_ipu();
     }
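The loop above is a two-stage software pipeline: iteration j prefetches block
j into src + (j % 2) * maxNum while the vector unit reduces block j - 1, and
one trailing iteration drains the pipeline. A host-side C sketch of the same
buffer rotation, with memcpy standing in for __memcpy_async and all names
illustrative:

#include <stdio.h>
#include <string.h>
#define BLK 4
int main(void) {
    float source[12] = {1,2,3,4,5,6,7,8,9,10,11,12};
    float buf[2][BLK];
    int repeat = 3;
    float sum = 0.0f;
    for (int j = 0; j < repeat + 1; j++) {
        if (j < repeat)  /* stage 1: prefetch block j */
            memcpy(buf[j % 2], source + j * BLK, BLK * sizeof(float));
        if (j > 0)       /* stage 2: consume block j - 1 */
            for (int k = 0; k < BLK; k++) sum += buf[(j - 1) % 2][k];
        /* the kernel ends each iteration with __sync_all_ipu(), so both
           stages finish before the buffers rotate */
    }
    printf("%.1f\n", sum); /* 78.0 */
    return 0;
}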
-    //__bang_printf("tmpOldMax[0]:%.2f,tmpSum[0]:%.2f\n", tmpNewMax[0],tmpSum[0]);
+
     if(remain){
       __memcpy(src, source + frontIdx + repeat * multiple * stride, remain * stride * sizeof(float), GDRAM2NRAM);
       for(int m = 0; m < remain; m++){
@@ -161,9 +166,9 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
     }
     //At this point, tmpNewMax stores the maximum value of the data corresponding to a fixed frontIdx and bedsize, while tmpSum stores the corresponding value sum
-    //__bang_printf("tmpOldMax[0]:%.2f,tmpSum[0]:%.2f\n", tmpNewMax[2],tmpSum[2]);
+
     __bang_active_reciphp(tmpSum, tmpSum, strideS);
-    //__bang_printf("tmpOldMax[0]:%.2f,tmpSum[0]:%.2f\n", tmpNewMax[2],tmpSum[2]);
+
     if(remain){
       for(int m = 0; m < remain; m++){
        __memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM);
@@ -174,23 +179,31 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
       }
     }
-    for(int j = 0 ; j < repeat; j++){
-      __memcpy(src, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM);
-      for(int m = 0; m < multiple; m++){
-        __memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM);
-
-        __bang_sub(tmp, tmp, tmpNewMax, strideS);
-        __bang_active_exp_less_0(tmp, tmp, strideS);
-        __bang_mul(tmp, tmp, tmpSum, strideS);
-        __memcpy(destination + frontIdx + j * multiple * stride + m * stride, tmp, stride * sizeof(float), NRAM2GDRAM);
+    for(int j = 0; j < repeat + 2; j++){
+      if(j < repeat){
+        __memcpy_async(src + j % 3 * maxNum, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM);
       }
+      if(j > 0 && j < repeat + 1){
+        for(int m = 0; m < multiple; m++){
+          __memcpy(tmp, src + (j - 1) % 3 * maxNum + m * stride, stride * sizeof(float), NRAM2NRAM);
+
+          __bang_sub(tmp, tmp, tmpNewMax, strideS);
+          __bang_active_exp_less_0(tmp, tmp, strideS);
+          __bang_mul(tmp, tmp, tmpSum, strideS);
+          __memcpy(src + (j - 1) % 3 * maxNum + m * stride, tmp, stride * sizeof(float), NRAM2NRAM);
+        }
+      }
+      if(j > 1){
+        __memcpy_async(destination + frontIdx + (j - 2) * multiple * stride, src + (j - 2) % 3 * maxNum, size * sizeof(float), NRAM2GDRAM);
+      }
+      __sync_all_ipu();
     }
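The write-back loop above extends the rotation to three stages: load block j,
transform block j - 1 in place, store block j - 2, with two drain iterations.
The same schedule in plain C (memcpy stands in for the async copies; names
are illustrative):

#include <stdio.h>
#include <string.h>
#define BLK 4
int main(void) {
    float src[12] = {1,2,3,4,5,6,7,8,9,10,11,12};
    float dst[12] = {0};
    float buf[3][BLK];
    int repeat = 3;
    for (int j = 0; j < repeat + 2; j++) {
        if (j < repeat)                 /* stage 1: load block j         */
            memcpy(buf[j % 3], src + j * BLK, BLK * sizeof(float));
        if (j > 0 && j < repeat + 1)    /* stage 2: transform block j - 1 */
            for (int k = 0; k < BLK; k++) buf[(j - 1) % 3][k] *= 2.0f;
        if (j > 1)                      /* stage 3: store block j - 2    */
            memcpy(dst + (j - 2) * BLK, buf[(j - 2) % 3], BLK * sizeof(float));
        /* kernel: __sync_all_ipu() here */
    }
    for (int k = 0; k < 12; k++) printf("%.0f ", dst[k]); /* input doubled */
    printf("\n");
    return 0;
}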
   }
   else if(dimsize * stride < maxNum){
     //-----------------------------------------allocate memory
     float* src = nram_buffer;
-    float* tmp = src + maxNum;
+    float* tmp = src + 3 * maxNum;
     float* tmpOldMax = tmp + strideS;
     float* tmpNewMax = tmpOldMax + strideS;
     float* tmpSum = tmpNewMax + strideS;
@@ -211,40 +224,48 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
     source = source + indStart * behindsize;//indStart * behindsize Indicates the offset corresponding to different taskIds
     destination = destination + indStart * behindsize;
     int tid;
-    for(int s = 0; s < taskRepeat; s++){
-      tid = s * multiple * behindsize;
-      __memcpy(src, source + tid, multiple * behindsize * sizeof(float), GDRAM2NRAM);
-      for(int m = 0; m < multiple; m++){
-        __bang_write_zero(tmpSum, strideS);
-        __bang_write_value(tmp, strideS, -INFINITY);
-        __bang_write_value(tmpNewMax, strideS, -INFINITY);
-        for(int i = 0; i < dimsize; i++){
-          __memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
-          __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);
-          __bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
-          __bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
-          if(i > 0){
-            __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
-            __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
-            __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS); //sum = sum * exp(oldM - newM)
+    for(int s = 0; s < taskRepeat + 2; s++){
+      if(s < taskRepeat){
+        tid = s * multiple * behindsize;
+        __memcpy_async(src + s % 3 * maxNum, source + tid, multiple * behindsize * sizeof(float), GDRAM2NRAM);
+      }
+      if(s > 0 && s < taskRepeat + 1){
+        for(int m = 0; m < multiple; m++){
+          __bang_write_zero(tmpSum, strideS);
+          __bang_write_value(tmp, strideS, -INFINITY);
+          __bang_write_value(tmpNewMax, strideS, -INFINITY);
+          for(int i = 0; i < dimsize; i++){
+            __memcpy(tmp, src + (s - 1) % 3 * maxNum + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
+            __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);
+            __bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
+            __bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
+            if(i > 0){
+              __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
+              __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
+              __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);//sum = sum * exp(oldM - newM)
+            }
+            __bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
+            __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM
+          }
+          __bang_active_reciphp(tmpSum, tmpSum, strideS);
+          __bang_mul(tmp, tmp, tmpSum, strideS);//tmp still holds exp(x - M) for i = dimsize - 1 from the loop above, so it can be reused here
+
+          __memcpy(src + (s - 1) % 3 * maxNum + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2NRAM);
+          for(int i = 0; i < dimsize - 1; i++){
+            __memcpy(tmp, src + (s - 1) % 3 * maxNum + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
+            __bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
+            __bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
+            __bang_mul(tmp, tmp, tmpSum, strideS);
+
+            __memcpy(src + (s - 1) % 3 * maxNum + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2NRAM);
           }
-          __bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
-          __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM
-        }
-        __bang_active_reciphp(tmpSum, tmpSum, strideS);
-        __bang_mul(tmp, tmp, tmpSum, strideS);//The data stored in tmp at the end of the loop above can be utilized
-        //__memcpy(destination + tid + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2GDRAM);
-        __memcpy(src + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2NRAM);
-        for(int i = 0; i < dimsize - 1; i++){
-          __memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
-          __bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
-          __bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
-          __bang_mul(tmp, tmp, tmpSum, strideS);
-          //__memcpy(destination + tid + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2GDRAM);
-          __memcpy(src + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2NRAM);
         }
       }
-      __memcpy(destination + tid, src, multiple * behindsize * sizeof(float), NRAM2GDRAM);
+      if(s > 1){
+        tid = (s - 2) * multiple * behindsize;
+        __memcpy_async(destination + tid, src + (s - 2) % 3 * maxNum, multiple * behindsize * sizeof(float), NRAM2GDRAM);
+      }
+      __sync_all_ipu();
     }
     //__bang_printf("taskId:%d, multiple:%d, taskRepeat:%d, step:%d, indStart:%d\n",taskId, multiple, taskRepeat, step, indStart * behindsize);
     if(step){
@@ -288,175 +309,213 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
 }
 __mlu_device__ void softmaxKernelAxis_e(float* destination, float* source, int othersize, int dimsize, int dimS) {// axis = -1
-  int multiple = maxNum / dimsize;
-  int size = taskDim * multiple;
-  int remainS = othersize % size;
-  int taskRepeat = (othersize - remainS) / size;
-  int remainT = remainS % taskDim;
-  int stepEasy = (remainS - remainT) / taskDim;
-  int stepHard = stepEasy + 1;
-  int step = (taskId < remainT ? stepHard : stepEasy);
-  //The amount allocated for processing othersize for each taskId is taskRepeat * multiple+step
-  //Overall, the amount of data processed by each taskId is (taskRepeat * multiple+step) * dimsize
-  int startHard = taskId * (taskRepeat * multiple + stepHard);
-  int startEasy = remainT * (taskRepeat * multiple + stepHard) + (taskId - remainT) * (taskRepeat * multiple + stepEasy);
-  int indStart = (taskId < remainT ? startHard: startEasy);
-  source = source + indStart * dimsize;
-  destination = destination + indStart * dimsize;
-
-  //-----------------------------------------allocate memory
-  float* src = nram_buffer;
-  float* tmp = src + maxNum;
-  float* destSum = tmp + dimS;
-  int remainDim = dimsize % dimS;//Dimsize may not be a power of 2
-  int repeatDim = (dimsize - remainDim) / dimS;
-
-  __nram__ float destSumFinal[warpSize];//Reduce destSum to destFinal [0]
+  __nram__ float destSumFinal[warpSize];
   __nram__ float srcMax[2];
   __nram__ float destOldMax;
   __nram__ float destNewMax;
-  //-----------------------------------------
-  //printf("taskId:%d, taskRepeat:%d, step:%d, repeatDim:%d, indstart:%d, %d\n", taskId, taskRepeat, step, repeatDim, indStart, indStart * dimsize);
-  int tid;
-  for(int s = 0; s < taskRepeat; s++){
-    tid = s * multiple * dimsize;
-    __memcpy(src, source + tid, multiple * dimsize * sizeof(float), GDRAM2NRAM);
-    for(int j = 0; j < multiple; j++){
-      __bang_write_zero(destSum, dimS);
-      __bang_write_zero(destSumFinal, warpSize);
-      destNewMax = -INFINITY;
-
-      for(int i = 0; i < repeatDim; i++){
-        __memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM);
-        __bang_argmax(srcMax, tmp, dimS);
+  if(dimsize >= maxNum){
+    float *src = nram_buffer;
+    float *destSum = src + 3 * maxNum;
+
+    int remain = dimsize % maxNum;
+    int repeat = (dimsize - remain)/maxNum;
+
+    int otherRemain = othersize % taskDim;
+    int stepEasy = (othersize - otherRemain) / taskDim;
+    int stepHard = stepEasy + 1;
+
+    int startHard = taskId * stepHard;
+    int startEasy = otherRemain * stepHard + (taskId - otherRemain) * stepEasy;
+    int indStart = (taskId < otherRemain ? startHard : startEasy);
+    source = source + indStart * dimsize;
+    destination = destination + indStart * dimsize;
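Both branches of this kernel split othersize rows over taskDim cores with the
same remainder trick: the first otherRemain (or remainT) cores take one extra
row, and indStart is the prefix sum of the steps before taskId. A standalone
C check of that arithmetic, with made-up values for illustration:

#include <stdio.h>
int main(void) {
    int othersize = 10, taskDim = 4;
    int remain = othersize % taskDim;  /* cores that take one extra row */
    int stepEasy = (othersize - remain) / taskDim;
    int stepHard = stepEasy + 1;
    for (int taskId = 0; taskId < taskDim; taskId++) {
        int step = (taskId < remain ? stepHard : stepEasy);
        int indStart = (taskId < remain
                          ? taskId * stepHard
                          : remain * stepHard + (taskId - remain) * stepEasy);
        printf("taskId %d: rows [%d, %d)\n", taskId, indStart, indStart + step);
    }
    return 0; /* prints [0,3) [3,6) [6,8) [8,10): all 10 rows, no overlap */
}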
+
+    destOldMax = -INFINITY;
+    destNewMax = -INFINITY;
+    __bang_write_zero(destSum, maxNum);
+    for(int i = 0; i < repeat + 1; i++){
+      if(i < repeat){
+        __memcpy_async(src + i % 2 * maxNum, source + i * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
+      }
+      if(i > 0){
+        __bang_argmax(srcMax, src + (i - 1) % 2 * maxNum, maxNum);
        if(destNewMax < srcMax[0]){
          destNewMax = srcMax[0];
        }
-        __bang_sub_scalar(tmp, tmp, destNewMax, dimS);
-        __bang_active_exp_less_0(tmp, tmp, dimS);
-        if(i > 0){
-          __bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS);
+        __bang_sub_scalar(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, destNewMax, maxNum);
+        __bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, maxNum);
+        if(i > 1){
+          __bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), maxNum);
         }
-        __bang_add(destSum, destSum, tmp, dimS);
+        __bang_add(destSum, destSum, src + (i - 1) % 2 * maxNum, maxNum);
         destOldMax = destNewMax;
       }
-      if(remainDim){
-        __bang_write_value(tmp, dimS, -INFINITY);
-        __memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM);
-        __bang_argmax(srcMax, tmp, dimS);
-        if(destNewMax < srcMax[0]){
-          destNewMax = srcMax[0];
-        }
-        __bang_write_value(tmp, dimS, destNewMax);//Must be reinitialized to NewMax
-        __memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM);
-        __bang_sub_scalar(tmp, tmp, destNewMax, dimS);
-        __bang_active_exp_less_0(tmp, tmp, dimS);
-        if(repeatDim > 0){
-          __bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS);
-        }
-        __bang_add(destSum, destSum, tmp, dimS);
-        destOldMax = destNewMax;
-      }
-
-      int segNum = dimS / warpSize;//Starting numerical summation
-      for(int strip = segNum/2; strip > 0; strip = strip / 2){
-        for(int i = 0; i < strip ; i++){
-          __bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
-        }
-      }
-      __bang_reduce_sum(destSumFinal, destSum, warpSize);//At this point, destSumFinal [0] saves the numerical value of the current dimsize length data sum
-      if(remainDim){
-        destSumFinal[0] = destSumFinal[0] - (dimS - remainDim);
-      }
-      //Now let's start writing back the data
-      float globalSumInv = 1.0/destSumFinal[0];
-      if(remainDim){
-        __bang_mul_scalar(tmp, tmp, globalSumInv, dimS);
-        __memcpy(destination + tid + j * dimsize + repeatDim * dimS, tmp, remainDim * sizeof(float), NRAM2GDRAM);
-      }
-      for(int i = 0; i < repeatDim; i++){
-        __memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM);
-        __bang_sub_scalar(tmp, tmp, destNewMax, dimS);
-        __bang_active_exp_less_0(tmp, tmp, dimS);
-        __bang_mul_scalar(tmp, tmp, globalSumInv, dimS);
-        __memcpy(destination + tid + j * dimsize + i * dimS, tmp, dimS * sizeof(float), NRAM2GDRAM);
-      }
-    }
+      __sync_all_ipu();
     }
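This first pass applies the online recurrence at block granularity: each
maxNum-block contributes lane-wise partial sums that are rescaled by
exp(destOldMax - destNewMax) whenever a later block raises the maximum; the
lanes are only reduced to a scalar afterwards. A plain-C sketch of that
block-wise form (names illustrative):

#include <math.h>
#include <stdio.h>
#define BLK 4
int main(void) {
    float x[8] = {0.1f, 0.7f, 0.3f, 0.2f, 1.5f, 0.9f, 0.4f, 1.1f};
    float partial[BLK] = {0};
    float oldM = -INFINITY, newM = -INFINITY;
    for (int b = 0; b < 2; b++) {
        const float *blk = x + b * BLK;
        for (int k = 0; k < BLK; k++) if (blk[k] > newM) newM = blk[k];
        if (b > 0)  /* a larger max invalidates the old scaling */
            for (int k = 0; k < BLK; k++) partial[k] *= expf(oldM - newM);
        for (int k = 0; k < BLK; k++) partial[k] += expf(blk[k] - newM);
        oldM = newM;
    }
    float sum = 0.0f, ref = 0.0f, m = x[0];
    for (int k = 0; k < BLK; k++) sum += partial[k];
    for (int i = 1; i < 8; i++) if (x[i] > m) m = x[i];
    for (int i = 0; i < 8; i++) ref += expf(x[i] - m);
    printf("blockwise: %f  direct: %f\n", sum, ref); /* identical */
    return 0;
}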
-    //it is necessary to write back to GDRAM immediately. If you first write back to src and then write back to GDRAM,
-    //there may be a situation where src writes back to GDRAM before modifying the src data
+    //------------
+    if(remain){
+      __bang_write_value(src, maxNum, -INFINITY);
+      __memcpy(src, source + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);
+
+      __bang_argmax(srcMax, src, maxNum);
+      if(destNewMax < srcMax[0]){
+        destNewMax = srcMax[0];
+      }
+      __bang_write_value(src, maxNum, destNewMax);
+      __memcpy(src, source + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);
+      __bang_sub_scalar(src, src, destNewMax, maxNum);
+      __bang_active_exp_less_0(src, src, maxNum);
+      if(repeat > 0){
+        __bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), maxNum);
+      }
+      __bang_add(destSum, destSum, src, maxNum);
+      destOldMax = destNewMax;
+    }
+    //--------------------------------
+    __bang_write_zero(destSumFinal, warpSize);
+    int segNum = maxNum / warpSize;
+    for(int strip = segNum/2; strip > 0; strip = strip / 2){
+      for(int i = 0; i < strip; i++){
+        __bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
+      }
+    }
+    __bang_reduce_sum(destSumFinal, destSum, warpSize);
+
+    if(remain){
+      destSumFinal[0] = destSumFinal[0] - (maxNum - remain);
+    }
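The strip loop halves the active region of destSum until warpSize lanes
remain, and __bang_reduce_sum collapses those; because the padded tail lanes
each contributed exp(0) = 1, the scalar result is corrected by
maxNum - remain. The folding plus the correction in plain C, assuming a
power-of-two length as the kernel does:

#include <stdio.h>
#define WARP 32
float tree_sum(float *v, int n) { /* n is a power of two, n >= WARP */
    for (int strip = (n / WARP) / 2; strip > 0; strip /= 2)
        for (int i = 0; i < strip; i++)
            for (int k = 0; k < WARP; k++)
                v[i * WARP + k] += v[(i + strip) * WARP + k];
    float s = 0.0f; /* stands in for __bang_reduce_sum */
    for (int k = 0; k < WARP; k++) s += v[k];
    return s;
}
int main(void) {
    float v[128];
    for (int i = 0; i < 128; i++) v[i] = 1.0f;
    /* 100 real lanes plus 28 padded lanes holding exp(0) = 1 */
    printf("%.0f\n", tree_sum(v, 128) - (128 - 100)); /* 100 */
    return 0;
}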
+    //-----------
+    float globalSumInv = 1.0/destSumFinal[0];
+    for(int i = 0; i < repeat + 2; i++){
+      if(i < repeat){
+        __memcpy_async(src + i % 3 * maxNum, source + i * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
+      }
+      if(i > 0 && i < repeat + 1){
+        __bang_sub_scalar(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, destNewMax, maxNum);
+        __bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, maxNum);
+        __bang_mul_scalar(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, globalSumInv, maxNum);
+      }
+      if(i > 1){
+        __memcpy_async(destination + (i - 2) * maxNum, src + (i - 2) % 3 * maxNum, maxNum * sizeof(float), NRAM2GDRAM);
+      }
+      __sync_all_ipu();
+    }
+    if(remain){
+      __bang_write_value(src, maxNum, destNewMax);
+      __memcpy(src, source + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);
+      __bang_sub_scalar(src, src, destNewMax, maxNum);
+      __bang_active_exp_less_0(src, src, maxNum);
+      __bang_mul_scalar(src, src, globalSumInv, maxNum);
+      __memcpy(destination + repeat * maxNum, src, remain * sizeof(float), NRAM2GDRAM);
+    }
+  }
-  if(step){//Step targets parts of othersize that cannot be divided by multiple * dimsize
-    tid = taskRepeat * multiple * dimsize;
-    __memcpy(src, source + tid, step * dimsize * sizeof(float), GDRAM2NRAM);
-    for(int j = 0; j < step; j++){
+  else{
+    int multiple = maxNum / dimsize;
+    int size = taskDim * multiple;
+    int remainS = othersize % size;
+    int taskRepeat = (othersize - remainS) / size;
+    int remainT = remainS % taskDim;
+    int stepEasy = (remainS - remainT) / taskDim;
+    int stepHard = stepEasy + 1;
+    int step = (taskId < remainT ? stepHard : stepEasy);
+    //Each taskId handles taskRepeat * multiple + step rows of othersize,
+    //so each taskId processes (taskRepeat * multiple + step) * dimsize elements in total
+    int startHard = taskId * (taskRepeat * multiple + stepHard);
+    int startEasy = remainT * (taskRepeat * multiple + stepHard) + (taskId - remainT) * (taskRepeat * multiple + stepEasy);
+    int indStart = (taskId < remainT ? startHard : startEasy);
+    source = source + indStart * dimsize;
+    destination = destination + indStart * dimsize;
+
+    //-----------------------------------------allocate memory
+    float* src = nram_buffer;//src[3 * maxNum]
+    float* tmp = src + 3 * maxNum;//tmp[dimS]
+    float* destSum = tmp + dimS;//destSum[dimS]; dimS >= max(dimsize, warpSize), dimS = pow(2, K) with pow(2, K - 1) < dimsize
+    //-----------------------------------------
+    int tid;
+    for(int s = 0; s < taskRepeat + 2; s++){
+      if(s < taskRepeat){
+        tid = s * multiple * dimsize;
+        __memcpy_async(src + s % 3 * maxNum, source + tid, multiple * dimsize * sizeof(float), GDRAM2NRAM);
+      }
+      if(s > 0 && s < taskRepeat + 1){
+        for(int j = 0; j < multiple; j++){
+          __bang_write_zero(destSum, dimS);
+          __bang_write_zero(destSumFinal, warpSize);
+          __bang_write_value(tmp, dimS, -INFINITY);
+
+          __memcpy(tmp, src + (s - 1) % 3 * maxNum + j * dimsize, dimsize * sizeof(float), NRAM2NRAM);
+          __bang_argmax(srcMax, tmp, dimS);
+          __bang_write_value(tmp, dimS, srcMax[0]);
+          __memcpy(tmp, src + (s - 1) % 3 * maxNum + j * dimsize, dimsize * sizeof(float), NRAM2NRAM);
+          __bang_sub_scalar(tmp, tmp, srcMax[0], dimS);
+          __bang_active_exp_less_0(tmp, tmp, dimS);//tmp[dimsize:dimS] = exp(0)
+          __bang_add(destSum, destSum, tmp, dimS);
+
+          int segNum = dimS / warpSize;//start the tree summation
+          for(int strip = segNum/2; strip > 0; strip = strip / 2){
+            for(int i = 0; i < strip; i++){
+              __bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
+            }
+          }
+          __bang_reduce_sum(destSumFinal, destSum, warpSize);//destSumFinal[0] now holds the sum of the current dimsize-length row
+          destSumFinal[0] = destSumFinal[0] - (dimS - dimsize);
+          //write the result back
+          float globalSumInv = 1.0/destSumFinal[0];
+          __bang_mul_scalar(tmp, tmp, globalSumInv, dimS);
+          __memcpy(src + (s - 1) % 3 * maxNum + j * dimsize, tmp, dimsize * sizeof(float), NRAM2NRAM);
+        }
+      }
+      if(s > 1){
+        tid = (s - 2) * multiple * dimsize;
+        __memcpy_async(destination + tid, src + (s - 2) % 3 * maxNum, multiple * dimsize * sizeof(float), NRAM2GDRAM);
+      }
+      __sync_all_ipu();
+      //__sync_all_ipu() must run every iteration: the NRAM2GDRAM store of block s - 2 has to finish before its src buffer is overwritten by a later load
+    }
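When a whole row fits in NRAM, the code above pads tmp to the power-of-two
length dimS: the tail is first set to -INFINITY so it cannot win the argmax,
then overwritten with the row maximum so that sub + exp turns it into
exp(0) = 1, which the dimS - dimsize correction later removes. The padding
trick in isolation, in plain C with illustrative values:

#include <math.h>
#include <stdio.h>
#define DIMS 8 /* padded power-of-two length, stands in for dimS */
int main(void) {
    float x[5] = {0.5f, 2.0f, 1.0f, -1.0f, 0.0f}; /* dimsize = 5 */
    float tmp[DIMS], mx = x[0], sum = 0.0f;
    for (int i = 1; i < 5; i++) if (x[i] > mx) mx = x[i];
    for (int i = 0; i < DIMS; i++) tmp[i] = mx; /* pad with the max */
    for (int i = 0; i < 5; i++) tmp[i] = x[i];
    for (int i = 0; i < DIMS; i++) sum += expf(tmp[i] - mx);
    sum -= (float)(DIMS - 5); /* drop the padded exp(0) contributions */
    printf("corrected sum = %f\n", sum);
    return 0;
}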
+    for(int s = 0; s < step; s++){//step covers the rows of othersize that are not divisible by taskDim * multiple
+      tid = taskRepeat * multiple * dimsize + s * dimsize;
       __bang_write_zero(destSum, dimS);
       __bang_write_zero(destSumFinal, warpSize);
-      destNewMax = -INFINITY;
-      for(int i = 0; i < repeatDim; i++){//RepeatDim refers to the total number of cycles required to read the current dimsize data using dimS after fixing otherIdx
-        __memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM);
-        __bang_argmax(srcMax, tmp, dimS);
-        if(destNewMax < srcMax[0]){
-          destNewMax = srcMax[0];
-        }
-        __bang_sub_scalar(tmp, tmp, destNewMax, dimS);
-        __bang_active_exp_less_0(tmp, tmp, dimS);
-        if(i > 0){
-          __bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS);
-        }
-        __bang_add(destSum, destSum, tmp, dimS);
-        destOldMax = destNewMax;
-      }
-      if(remainDim){//RemainDim refers to the part of dimsize that cannot be divided by dimS after fixing otherIdx
-        __bang_write_value(tmp, dimS, -INFINITY);
-        __memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM);
-        __bang_argmax(srcMax, tmp, dimS);
-        if(destNewMax < srcMax[0]){
-          destNewMax = srcMax[0];
-        }
-
-        __bang_write_value(tmp, dimS, destNewMax);//Must be reinitialized to NewMax
-        __memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM);
-        __bang_sub_scalar(tmp, tmp, destNewMax, dimS);
-        __bang_active_exp_less_0(tmp, tmp, dimS);
-        if(repeatDim > 0){
-          __bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS);
-        }
-        __bang_add(destSum, destSum, tmp, dimS);
-        destOldMax = destNewMax;
-      }
+      __bang_write_value(tmp, dimS, -INFINITY);
+      __memcpy(tmp, source + tid, dimsize * sizeof(float), GDRAM2NRAM);
+
+      __bang_argmax(srcMax, tmp, dimS);
+      __bang_write_value(tmp, dimS, srcMax[0]);
+      __memcpy(tmp, source + tid, dimsize * sizeof(float), GDRAM2NRAM);
+      __bang_sub_scalar(tmp, tmp, srcMax[0], dimS);
+
+      __bang_active_exp_less_0(tmp, tmp, dimS);
+      __bang_add(destSum, destSum, tmp, dimS);
+
-      int segNum = dimS / warpSize;//Starting numerical summation
+      int segNum = dimS / warpSize;
       for(int strip = segNum/2; strip > 0; strip = strip / 2){
         for(int i = 0; i < strip ; i++){
           __bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
-        }
+        }
       }
       __bang_reduce_sum(destSumFinal, destSum, warpSize);
-      //At this point, destSumFinal [0] saves the numerical value of the current dimsize length data sum
-      if(remainDim){
-        destSumFinal[0] = destSumFinal[0] - (dimS - remainDim);
-      }
-      //__bang_printf("taskId:%d, max:%.2f, sum:%.2f\n", taskId, destNewMax, destSumFinal[0]);
+      destSumFinal[0] = destSumFinal[0] - (dimS - dimsize);
       float globalSumInv = 1.0/destSumFinal[0];
-      if(remainDim){
-        __bang_mul_scalar(tmp, tmp, globalSumInv, dimS);
-        __memcpy(destination + tid + j * dimsize + repeatDim * dimS, tmp, remainDim * sizeof(float), NRAM2GDRAM);
-      }
-      for(int i = 0; i < repeatDim; i++){
-        __memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM);
-        __bang_sub_scalar(tmp, tmp, destNewMax, dimS);
-        __bang_active_exp_less_0(tmp, tmp, dimS);
-        __bang_mul_scalar(tmp, tmp, globalSumInv, dimS);
-        __memcpy(destination + tid + j * dimsize + i * dimS, tmp, dimS * sizeof(float), NRAM2GDRAM);
-      }
-    }
+      __bang_mul_scalar(tmp, tmp, globalSumInv, dimS);
+      __memcpy(destination + tid, tmp, dimsize * sizeof(float), NRAM2GDRAM);
+
+    }
+  }
 }
 __mlu_device__ void softmaxKernelAxis_s(float* destination, float* source, int othersize, int dimsize, int stride) {// axis = 0
   //-----------------------------------------allocate memory
-  float* src = nram_buffer;
-  float* tmpSum = src + maxNum;
-  float* tmpNewMax = src + 2 * maxNum;
-  float* tmpOldMax = src + 3 * maxNum;
+  float* src = nram_buffer;//src[3 * maxNum]
+  float* tmpSum = src + 3 * maxNum;//tmpSum[maxNum]
+  float* tmpNewMax = src + 4 * maxNum;//tmpNewMax[maxNum]
+  float* tmpOldMax = src + 5 * maxNum;//tmpOldMax[maxNum]
   //-----------------------------------------
   int remain = othersize % taskDim;
   int stepEasy = (othersize - remain)/taskDim;
@@ -470,67 +529,89 @@ __mlu_device__ void softmaxKernelAxis_s(float* destination, float* source, int o
     for(int j = 0; j < repeat; j++){
       __bang_write_value(tmpNewMax, maxNum, -INFINITY);
       __bang_write_zero(tmpSum, maxNum);
-      for(int i = 0; i < dimsize; i++){
-        __memcpy(src, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
-        __bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);//Continuously updating the maximum value
-        __bang_sub(src, src, tmpNewMax, maxNum);//x - M
-        __bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
-        if(i > 0){
-          __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
-          __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
-          __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);//sum = sum * exp(oldM - newM)
+      for(int i = 0; i < dimsize + 1; i++){
+        if(i < dimsize){
+          __memcpy_async(src + i % 2 * maxNum, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
         }
-        __bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M)
-        __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
+        if(i > 0){
+          __bang_maxequal(tmpNewMax, tmpNewMax, src + (i - 1) % 2 * maxNum, maxNum);//continuously update the maximum value
+          __bang_sub(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, tmpNewMax, maxNum);//x - M
+          __bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, maxNum);//exp(x - M)
+          if(i > 1){
+            __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
+            __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
+            __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);//sum = sum * exp(oldM - newM)
+          }
+          __bang_add(tmpSum, tmpSum, src + (i - 1) % 2 * maxNum, maxNum);//sum += exp(x - M)
+          __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
+        }
+        __sync_all_ipu();
       }
       __bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum
       //Start exponential transformation and write back to GDRAM
-      __bang_mul(src, src, tmpSum, maxNum);//The data stored in the src at the end of the loop above can be utilized
-      __memcpy(destination + (dimsize - 1) * stride + indStart + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
-      for(int i = 0; i < dimsize - 1; i++){
-        __memcpy(src, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
-        __bang_sub(src, src, tmpNewMax, maxNum);//x - M
-        __bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
-        __bang_mul(src, src, tmpSum, maxNum);
-        __memcpy(destination + i * stride + indStart + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
+      for(int i = 0; i < dimsize + 2; i++){
+        if(i < dimsize){
+          __memcpy_async(src + i % 3 * maxNum, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
+        }
+        if(i > 0 && i < dimsize + 1){
+          __bang_sub(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpNewMax, maxNum);//x - M
+          __bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, maxNum);//exp(x - M)
+          __bang_mul(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpSum, maxNum);
+        }
+        if(i > 1){
+          __memcpy_async(destination + (i - 2) * stride + indStart + j * maxNum, src + (i - 2) % 3 * maxNum, maxNum * sizeof(float), NRAM2GDRAM);
+        }
+        __sync_all_ipu();
       }
     }
     if(remainNram){
       __bang_write_value(tmpNewMax, maxNum, -INFINITY);
       __bang_write_zero(tmpSum, maxNum);
-      __bang_write_zero(src, maxNum);
+      __bang_write_zero(src, 3 * maxNum);
-      for(int i = 0; i < dimsize; i++){
-        __memcpy(src, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
-        __bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);
-        __bang_sub(src, src, tmpNewMax, maxNum);//x - M
-        __bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
-        if(i > 0){
-          __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
-          __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
-          __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum); //sum = sum * exp(oldM - newM)
+      for(int i = 0; i < dimsize + 1; i++){
+        if(i < dimsize){
+          __memcpy_async(src + i % 2 * maxNum, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
         }
-        __bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M)
-        __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
+        if(i > 0){
+          __bang_maxequal(tmpNewMax, tmpNewMax, src + (i - 1) % 2 * maxNum, maxNum);
+          __bang_sub(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, tmpNewMax, maxNum);//x - M
+          __bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, maxNum);//exp(x - M)
+          if(i > 1){
+            __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
+            __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
+            __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);//sum = sum * exp(oldM - newM)
+          }
+          __bang_add(tmpSum, tmpSum, src + (i - 1) % 2 * maxNum, maxNum);//sum += exp(x - M)
+          __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
+        }
+        __sync_all_ipu();
       }
       __bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum
       //Start exponential transformation and write back to GDRAM
-      __bang_mul(src, src, tmpSum, maxNum);//The data stored in the src at the end of the loop above can be utilized
-      __memcpy(destination + (dimsize - 1) * stride + indStart + repeat * maxNum, src, remainNram * sizeof(float), NRAM2GDRAM);
-      for(int i = 0; i < dimsize - 1; i++){
-        __memcpy(src, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
-        __bang_sub(src, src, tmpNewMax, maxNum);//x - M
-        __bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
-        __bang_mul(src, src, tmpSum, maxNum);
-        __memcpy(destination + i * stride + indStart + repeat * maxNum, src, remainNram * sizeof(float), NRAM2GDRAM);
+
+      for(int i = 0; i < dimsize + 2; i++){
+        if(i < dimsize){
+          __memcpy_async(src + i % 3 * maxNum, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
+        }
+        if(i > 0 && i < dimsize + 1){
+          __bang_sub(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpNewMax, maxNum);//x - M
+          __bang_active_exp_less_0(src + (i - 1) % 3 * maxNum,
+                                   src + (i - 1) % 3 * maxNum, maxNum);//exp(x - M)
+          __bang_mul(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpSum, maxNum);
+        }
+        if(i > 1){
+          __memcpy_async(destination + (i - 2) * stride + indStart + repeat * maxNum, src + (i - 2) % 3 * maxNum, remainNram * sizeof(float), NRAM2GDRAM);
+        }
+        __sync_all_ipu();
       }
     }
   }
+
 __mlu_global__ void softmaxUnion1(float *mlu_destination, float *mlu_src, int nDim, int axis, int othersize, int frontsize, int dimsize, int stride){
   if(axis == nDim - 1){
     int dimS;