stream kernel

xgqdut2016 2024-03-07 09:01:00 +00:00
parent d4721cb40c
commit a6c919b61d
1 changed file with 330 additions and 249 deletions

@@ -4,7 +4,7 @@
const int NRAM_MAX_SIZE = 1024 * 512;//reserve 512 KB of NRAM; the hardware maximum is 1024 * 768 bytes
const int nramNum = NRAM_MAX_SIZE/sizeof(float);
__nram__ float nram_buffer[nramNum];
const int SRC_MAX_SIZE = 1024 * 128;//The subsequent tree summation requires that SRC_MAX_SIZE be a power of 2
const int SRC_MAX_SIZE = 1024 * 64;//The subsequent tree summation requires that SRC_MAX_SIZE be a power of 2
//4 * SRC_MAX_SIZE must be <= NRAM_MAX_SIZE
const int maxNum = SRC_MAX_SIZE/sizeof(float);
const int warpSize = 32;
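
For orientation, here is a small host-side sketch (not device code) that mirrors the constants above: the tree summation relies on maxNum being a power of two, and the softmaxKernelAxis_s layout later in this file carves six maxNum-sized vectors out of the arena.

```c
#include <assert.h>
#include <stdio.h>

/* Host-side sketch (not device code) mirroring the kernel's NRAM budget. */
int main(void) {
    const int NRAM_MAX_SIZE = 1024 * 512; /* reserved NRAM arena in bytes */
    const int SRC_MAX_SIZE  = 1024 * 64;  /* one tile in bytes */
    const int maxNum = (int)(SRC_MAX_SIZE / sizeof(float));

    /* the tree summation halves the buffer each pass, so the slot count
       must be a power of two */
    assert((maxNum & (maxNum - 1)) == 0);

    /* softmaxKernelAxis_s lays out six maxNum-sized vectors */
    assert(6 * SRC_MAX_SIZE <= NRAM_MAX_SIZE);

    printf("maxNum = %d floats per tile\n", maxNum);
    return 0;
}
```
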
@@ -98,7 +98,7 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
//-----------------------------------------allocate memory
float* src = nram_buffer;
float* tmp = src + maxNum;
float* tmp = src + 3 * maxNum;
float* tmpOldMax = tmp + strideS;
float* tmpNewMax = tmpOldMax + strideS;
float* tmpSum = tmpNewMax + strideS;
@@ -123,26 +123,31 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
__bang_write_value(tmp, strideS, -INFINITY);//Must be initialized to negative infinity
__bang_write_zero(tmpSum, strideS);//Must be initialized to zero
for(int j = 0; j < repeat; j++){
__memcpy(src, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM);
for(int m = 0; m < multiple; m++){
__memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM);
__bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);//the [stride, strideS) tail of tmpNewMax is 0, but it is never written back to GDRAM, so the result is unaffected
__bang_sub(tmp, tmp, tmpNewMax, strideS);//the [stride, strideS) tail of tmp is 0
__bang_active_exp_less_0(tmp, tmp, strideS);
if(j != 0 || m != 0){
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
__bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);//sum = sum * exp(oldM - newM)
}
__bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
//if(m == 0) __bang_printf("tmp:%.2f, tmpMax[0]:%.2f,tmpSum[0]:%.2f\n", tmp[1], tmpNewMax[1],tmpSum[0]);
__memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM
for(int j = 0; j < repeat + 1; j++){
if(j < repeat){
__memcpy_async(src + j % 2 * maxNum, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM);
}
if(j > 0){
for(int m = 0; m < multiple; m++){
__memcpy(tmp, src + (j - 1) % 2 * maxNum + m * stride, stride * sizeof(float), NRAM2NRAM);
__bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);//the [stride, strideS) tail of tmpNewMax is 0, but it is never written back to GDRAM, so the result is unaffected
__bang_sub(tmp, tmp, tmpNewMax, strideS);//the [stride, strideS) tail of tmp is 0
__bang_active_exp_less_0(tmp, tmp, strideS);
if(j != 1 || m != 0){
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
__bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);//sum = sum * exp(oldM - newM)
}
__bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
__memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM
}
}
__sync_all_ipu();
}
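
The loop above is the streaming ("online") softmax recurrence: each chunk updates a running max, rescales the running sum by exp(oldM - newM), and then accumulates exp(x - newM). A scalar C sketch of the same recurrence, illustrative only (the kernel performs it per lane with __bang_maxequal / __bang_active_exp_less_0):

```c
#include <math.h>
#include <stdio.h>

/* Scalar model of the running max/sum update performed chunk by chunk. */
static void online_softmax_pass(const float *x, int n, int chunk,
                                float *outMax, float *outSum) {
    float m = -INFINITY, s = 0.0f;
    for (int start = 0; start < n; start += chunk) {
        int end = start + chunk < n ? start + chunk : n;
        float newM = m;
        for (int i = start; i < end; i++)      /* chunk max */
            if (x[i] > newM) newM = x[i];
        s *= expf(m - newM);                   /* rescale the old sum */
        for (int i = start; i < end; i++)
            s += expf(x[i] - newM);            /* accumulate this chunk */
        m = newM;
    }
    *outMax = m; *outSum = s;
}

int main(void) {
    float x[] = {1.0f, 3.0f, 2.0f, 5.0f, 4.0f, 0.5f};
    float m, s;
    online_softmax_pass(x, 6, 2, &m, &s);
    printf("max=%f sum=%f\n", m, s); /* matches a single-pass max/sum */
    return 0;
}
```
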
//__bang_printf("tmpOldMax[0]:%.2f,tmpSum[0]:%.2f\n", tmpNewMax[0],tmpSum[0]);
if(remain){
__memcpy(src, source + frontIdx + repeat * multiple * stride, remain * stride * sizeof(float), GDRAM2NRAM);
for(int m = 0; m < remain; m++){
@@ -161,9 +166,9 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
}
//At this point, tmpNewMax holds the maximum of the data for the fixed frontIdx and behindsize, and tmpSum holds the corresponding sum
//__bang_printf("tmpOldMax[0]:%.2f,tmpSum[0]:%.2f\n", tmpNewMax[2],tmpSum[2]);
__bang_active_reciphp(tmpSum, tmpSum, strideS);
//__bang_printf("tmpOldMax[0]:%.2f,tmpSum[0]:%.2f\n", tmpNewMax[2],tmpSum[2]);
if(remain){
for(int m = 0; m < remain; m++){
__memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM);
@@ -174,23 +179,31 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
}
}
for(int j = 0 ; j < repeat; j++){
__memcpy(src, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM);
for(int m = 0; m < multiple; m++){
__memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM);
__bang_sub(tmp, tmp, tmpNewMax, strideS);
__bang_active_exp_less_0(tmp, tmp, strideS);
__bang_mul(tmp, tmp, tmpSum, strideS);
__memcpy(destination + frontIdx + j * multiple * stride + m * stride, tmp, stride * sizeof(float), NRAM2GDRAM);
for(int j = 0 ; j < repeat + 2; j++){
if(j < repeat){
__memcpy_async(src + j % 3 * maxNum, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM);
}
if(j > 0 && j < repeat + 1){
for(int m = 0; m < multiple; m++){
__memcpy(tmp, src + (j - 1) % 3 * maxNum + m * stride, stride * sizeof(float), NRAM2NRAM);
__bang_sub(tmp, tmp, tmpNewMax, strideS);
__bang_active_exp_less_0(tmp, tmp, strideS);
__bang_mul(tmp, tmp, tmpSum, strideS);
__memcpy(src + (j - 1) % 3 * maxNum + m * stride, tmp, stride * sizeof(float), NRAM2NRAM);
}
}
if(j > 1){
__memcpy_async(destination + frontIdx + (j - 2) * multiple * stride, src + (j - 2) % 3 * maxNum, size * sizeof(float), NRAM2GDRAM);
}
__sync_all_ipu();
}
}
}
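
The rewritten loops run repeat + 2 iterations so that iteration j loads tile j, transforms tile j - 1, and writes back tile j - 2, each stage in its own third of the buffer (j % 3). A host-side C sketch of that indexing, with plain memcpy standing in for __memcpy_async:

```c
#include <stdio.h>
#include <string.h>

/* Three-stage software pipeline: because the stages of iteration j touch
   disjoint thirds of buf (j % 3), the real async copies never race with
   the compute stage. */
int main(void) {
    enum { TILES = 4, LEN = 8 };
    float srcG[TILES * LEN], dstG[TILES * LEN], buf[3 * LEN];
    for (int i = 0; i < TILES * LEN; i++) srcG[i] = (float)i;
    for (int j = 0; j < TILES + 2; j++) {
        if (j < TILES)                      /* stage 1: load tile j */
            memcpy(buf + j % 3 * LEN, srcG + j * LEN, LEN * sizeof(float));
        if (j > 0 && j < TILES + 1)         /* stage 2: transform tile j-1 */
            for (int k = 0; k < LEN; k++)
                buf[(j - 1) % 3 * LEN + k] *= 2.0f;
        if (j > 1)                          /* stage 3: store tile j-2 */
            memcpy(dstG + (j - 2) * LEN, buf + (j - 2) % 3 * LEN,
                   LEN * sizeof(float));
        /* the device loop ends each iteration with __sync_all_ipu() */
    }
    printf("dstG[5] = %f (expect 10)\n", dstG[5]);
    return 0;
}
```
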
else if(dimsize * stride < maxNum){
//-----------------------------------------allocate memory
float* src = nram_buffer;
float* tmp = src + maxNum;
float* tmp = src + 3 * maxNum;
float* tmpOldMax = tmp + strideS;
float* tmpNewMax = tmpOldMax + strideS;
float* tmpSum = tmpNewMax + strideS;
@@ -211,40 +224,48 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
source = source + indStart * behindsize;//indStart * behindsize is the starting offset for this taskId
destination = destination + indStart * behindsize;
int tid;
for(int s = 0; s < taskRepeat; s++){
tid = s * multiple * behindsize;
__memcpy(src, source + tid, multiple * behindsize * sizeof(float), GDRAM2NRAM);
for(int m = 0; m < multiple; m++){
__bang_write_zero(tmpSum, strideS);
__bang_write_value(tmp, strideS, -INFINITY);
__bang_write_value(tmpNewMax, strideS, -INFINITY);
for(int i = 0; i < dimsize; i++){
__memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
__bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);
__bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
__bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
if(i > 0){
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
__bang_mul(tmpSum, tmpSum, tmpOldMax, strideS); //sum = sum * exp(oldM - newM)
for(int s = 0; s < taskRepeat + 2; s++){
if(s < taskRepeat){
tid = s * multiple * behindsize;
__memcpy_async(src + s % 3 * maxNum, source + tid, multiple * behindsize * sizeof(float), GDRAM2NRAM);
}
if(s > 0 && s < taskRepeat + 1){
for(int m = 0; m < multiple; m++){
__bang_write_zero(tmpSum, strideS);
__bang_write_value(tmp, strideS, -INFINITY);
__bang_write_value(tmpNewMax, strideS, -INFINITY);
for(int i = 0; i < dimsize; i++){
__memcpy(tmp, src + (s - 1) % 3 * maxNum + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
__bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);
__bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
__bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
if(i > 0){
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
__bang_mul(tmpSum, tmpSum, tmpOldMax, strideS); //sum = sum * exp(oldM - newM)
}
__bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
__memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM
}
__bang_active_reciphp(tmpSum, tmpSum, strideS);
__bang_mul(tmp, tmp, tmpSum, strideS);//reuse the data left in tmp by the last iteration of the loop above
__memcpy(src + (s - 1) % 3 * maxNum + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2NRAM);
for(int i = 0; i < dimsize - 1; i++){
__memcpy(tmp, src + (s - 1) % 3 * maxNum + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
__bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
__bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
__bang_mul(tmp, tmp, tmpSum, strideS);
__memcpy(src + (s - 1) % 3 * maxNum + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2NRAM);
}
__bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
__memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM
}
__bang_active_reciphp(tmpSum, tmpSum, strideS);
__bang_mul(tmp, tmp, tmpSum, strideS);//reuse the data left in tmp by the last iteration of the loop above
//__memcpy(destination + tid + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2GDRAM);
__memcpy(src + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2NRAM);
for(int i = 0; i < dimsize - 1; i++){
__memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
__bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
__bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
__bang_mul(tmp, tmp, tmpSum, strideS);
//__memcpy(destination + tid + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2GDRAM);
__memcpy(src + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2NRAM);
}
}
__memcpy(destination + tid, src, multiple * behindsize * sizeof(float), NRAM2GDRAM);
if(s > 1){
tid = (s - 2) * multiple * behindsize;
__memcpy_async(destination + tid, src + (s - 2) % 3 * maxNum, multiple * behindsize * sizeof(float), NRAM2GDRAM);
}
__sync_all_ipu();
}
//__bang_printf("taskId:%d, multiple:%d, taskRepeat:%d, step:%d, indStart:%d\n",taskId, multiple, taskRepeat, step, indStart * behindsize);
if(step){
@@ -288,175 +309,213 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
}
__mlu_device__ void softmaxKernelAxis_e(float* destination, float* source, int othersize, int dimsize, int dimS) {// axis = -1
int multiple = maxNum / dimsize;
int size = taskDim * multiple;
int remainS = othersize % size;
int taskRepeat = (othersize - remainS) / size;
int remainT = remainS % taskDim;
int stepEasy = (remainS - remainT) / taskDim;
int stepHard = stepEasy + 1;
int step = (taskId < remainT ? stepHard : stepEasy);
//The amount of othersize allocated to each taskId is taskRepeat * multiple + step
//Overall, each taskId processes (taskRepeat * multiple + step) * dimsize elements
int startHard = taskId * (taskRepeat * multiple + stepHard);
int startEasy = remainT * (taskRepeat * multiple + stepHard) + (taskId - remainT) * (taskRepeat * multiple + stepEasy);
int indStart = (taskId < remainT ? startHard: startEasy);
source = source + indStart * dimsize;
destination = destination + indStart * dimsize;
//-----------------------------------------allocate memory
float* src = nram_buffer;
float* tmp = src + maxNum;
float* destSum = tmp + dimS;
int remainDim = dimsize % dimS;//Dimsize may not be a power of 2
int repeatDim = (dimsize - remainDim) / dimS;
__nram__ float destSumFinal[warpSize];//reduce destSum into destSumFinal[0]
__nram__ float destSumFinal[warpSize];
__nram__ float srcMax[2];
__nram__ float destOldMax;
__nram__ float destNewMax;
//-----------------------------------------
//printf("taskId:%d, taskRepeat:%d, step:%d, repeatDim:%d, indstart:%d, %d\n", taskId, taskRepeat, step, repeatDim, indStart, indStart * dimsize);
int tid;
for(int s = 0; s < taskRepeat; s++){
tid = s * multiple * dimsize;
__memcpy(src, source + tid, multiple * dimsize * sizeof(float), GDRAM2NRAM);
for(int j = 0; j < multiple; j++){
__bang_write_zero(destSum, dimS);
__bang_write_zero(destSumFinal, warpSize);
destNewMax = -INFINITY;
for(int i = 0; i < repeatDim; i++){
__memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM);
__bang_argmax(srcMax, tmp, dimS);
if(dimsize >= maxNum){
float *src = nram_buffer;
float *destSum = src + 3 * maxNum;
int remain = dimsize % maxNum;
int repeat = (dimsize - remain)/maxNum;
int otherRemain = othersize % taskDim;
int stepEasy = (othersize - otherRemain) / taskDim;
int stepHard = stepEasy + 1;
int startHard = taskId * stepHard;
int startEasy = otherRemain * stepHard + (taskId - otherRemain) * stepEasy;
int indStart = (taskId < otherRemain ? startHard : startEasy);
source = source + indStart * dimsize;
destination = destination + indStart * dimsize;
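
The index arithmetic above balances othersize rows across the taskDim cores: the first otherRemain cores take one extra row (stepHard = stepEasy + 1), and each core's start offset is the prefix sum of the rows owned by lower-numbered cores. A sketch with made-up sizes:

```c
#include <stdio.h>

/* Easy/hard row split: cores below `remain` process one extra row each. */
int main(void) {
    int othersize = 10, taskDim = 4;
    int remain = othersize % taskDim;
    int stepEasy = (othersize - remain) / taskDim;
    int stepHard = stepEasy + 1;
    for (int taskId = 0; taskId < taskDim; taskId++) {
        int step = taskId < remain ? stepHard : stepEasy;
        int indStart = taskId < remain
                           ? taskId * stepHard
                           : remain * stepHard + (taskId - remain) * stepEasy;
        printf("taskId %d: rows [%d, %d)\n", taskId, indStart, indStart + step);
    }
    return 0;
}
```
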
destOldMax = -INFINITY;
destNewMax = -INFINITY;
__bang_write_zero(destSum, maxNum);
for(int i = 0; i < repeat + 1; i++){
if(i < repeat){
__memcpy_async(src + i % 2 * maxNum, source + i * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
}
if(i > 0){
__bang_argmax(srcMax, src + (i - 1) % 2 * maxNum, maxNum);
if(destNewMax < srcMax[0]){
destNewMax = srcMax[0];
}
__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
__bang_active_exp_less_0(tmp, tmp, dimS);
if(i > 0){
__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS);
__bang_sub_scalar(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, destNewMax, maxNum);
__bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, maxNum);
if(i > 1){
__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), maxNum);
}
__bang_add(destSum, destSum, tmp, dimS);
__bang_add(destSum, destSum, src + (i - 1) % 2 * maxNum, maxNum);
destOldMax = destNewMax;
}
if(remainDim){
__bang_write_value(tmp, dimS, -INFINITY);
__memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM);
__bang_argmax(srcMax, tmp, dimS);
if(destNewMax < srcMax[0]){
destNewMax = srcMax[0];
}
__bang_write_value(tmp, dimS, destNewMax);//must be re-initialized to destNewMax so the padded tail contributes exp(0) = 1
__memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM);
__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
__bang_active_exp_less_0(tmp, tmp, dimS);
if(repeatDim > 0){
__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS);
}
__bang_add(destSum, destSum, tmp, dimS);
destOldMax = destNewMax;
}
int segNum = dimS / warpSize;//start of the tree summation
for(int strip = segNum/2; strip > 0; strip = strip / 2){
for(int i = 0; i < strip ; i++){
__bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
}
}
__bang_reduce_sum(destSumFinal, destSum, warpSize);//destSumFinal[0] now holds the sum of the current dimsize elements
if(remainDim){
destSumFinal[0] = destSumFinal[0] - (dimS - remainDim);
}
//Now let's start writing back the data
float globalSumInv = 1.0/destSumFinal[0];
if(remainDim){
__bang_mul_scalar(tmp, tmp, globalSumInv, dimS);
__memcpy(destination + tid + j * dimsize + repeatDim * dimS, tmp, remainDim * sizeof(float), NRAM2GDRAM);
}
for(int i = 0; i < repeatDim; i++){
__memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM);
__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
__bang_active_exp_less_0(tmp, tmp, dimS);
__bang_mul_scalar(tmp, tmp, globalSumInv, dimS);
__memcpy(destination + tid + j * dimsize + i * dimS, tmp, dimS * sizeof(float), NRAM2GDRAM);
}
__sync_all_ipu();
}
//the normalized result must be written straight back to GDRAM here; if it were first staged in src
//and then copied out, the copy to GDRAM could run before the compute stage had updated src
//------------
if(remain){
__bang_write_value(src, maxNum, -INFINITY);
__memcpy(src, source + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);
__bang_argmax(srcMax, src, maxNum);
if(destNewMax < srcMax[0]){
destNewMax = srcMax[0];
}
__bang_write_value(src, maxNum, destNewMax);
__memcpy(src, source + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);
__bang_sub_scalar(src, src, destNewMax, maxNum);
__bang_active_exp_less_0(src, src, maxNum);
if(repeat > 0){
__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), maxNum);
}
__bang_add(destSum, destSum, src, maxNum);
destOldMax = destNewMax;
}
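
The remainder tile is read twice on purpose: first padded with -INFINITY so the padding cannot win the max, then re-padded with destNewMax so every padded lane contributes exactly exp(0) = 1 to the sum, which the code later removes by subtracting maxNum - remain. A scalar sketch of that correction:

```c
#include <math.h>
#include <stdio.h>

/* Pad the tail with the running max so each pad lane adds exp(0) = 1,
   then subtract the pad count from the final sum. */
int main(void) {
    float tile[8];
    int remain = 5, width = 8;
    float data[] = {0.2f, 1.5f, -0.3f, 2.0f, 0.9f};
    float newMax = -INFINITY;
    for (int i = 0; i < remain; i++)
        if (data[i] > newMax) newMax = data[i];
    for (int i = 0; i < width; i++)       /* pad with the max */
        tile[i] = i < remain ? data[i] : newMax;
    float sum = 0.0f;
    for (int i = 0; i < width; i++)
        sum += expf(tile[i] - newMax);    /* pads each add exp(0) = 1 */
    sum -= (float)(width - remain);       /* remove the pad contribution */
    printf("corrected sum = %f\n", sum);
    return 0;
}
```
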
//--------------
//--------------------------------
__bang_write_zero(destSumFinal, warpSize);
int segNum = maxNum / warpSize;
for(int strip = segNum/2; strip > 0; strip = strip / 2){
for(int i = 0; i < strip ; i++){
__bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
}
}
__bang_reduce_sum(destSumFinal, destSum, warpSize);
if(remain){
destSumFinal[0] = destSumFinal[0] - (maxNum - remain);
}
//-----------
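
This is the tree summation the power-of-two buffer size exists for: destSum is folded in half until a single warpSize segment remains, which __bang_reduce_sum then collapses into destSumFinal[0]. A C sketch, with a serial loop standing in for the final reduce:

```c
#include <stdio.h>

/* Fold a power-of-two buffer in half until one warp-sized segment
   remains, then reduce that segment serially. */
int main(void) {
    enum { WARP = 32, N = 128 };          /* N must be a power of two */
    float buf[N];
    for (int i = 0; i < N; i++) buf[i] = 1.0f;
    int segNum = N / WARP;
    for (int strip = segNum / 2; strip > 0; strip /= 2)
        for (int i = 0; i < strip; i++)
            for (int k = 0; k < WARP; k++)   /* models __bang_add */
                buf[i * WARP + k] += buf[(i + strip) * WARP + k];
    float total = 0.0f;
    for (int k = 0; k < WARP; k++)           /* models __bang_reduce_sum */
        total += buf[k];
    printf("total = %f (expect %d)\n", total, N);
    return 0;
}
```
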
float globalSumInv = 1.0/destSumFinal[0];
for(int i = 0; i < repeat + 2; i++){
if(i < repeat){
__memcpy_async(src + i % 3 * maxNum, source + i * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
}
if(i > 0 && i < repeat){
__bang_sub_scalar(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, destNewMax, maxNum);
__bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, maxNum);
__bang_mul_scalar(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, globalSumInv, maxNum);
}
if(i > 1){
__memcpy_async(destination + (i - 2) * maxNum, src + (i - 2) % 3 * maxNum, maxNum * sizeof(float), NRAM2GDRAM);
}
__sync_all_ipu();
}
if(remain){
__bang_write_value(src, maxNum, destNewMax);
__memcpy(src, source + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);
__bang_sub_scalar(src, src, destNewMax, maxNum);
__bang_active_exp_less_0(src, src, maxNum);
__bang_mul_scalar(src, src, globalSumInv, maxNum);
__memcpy(destination + repeat * maxNum, src, remain * sizeof(float), NRAM2GDRAM);
}
}
if(step){//step handles the leftover rows of othersize that do not fill a full multiple-row tile
tid = taskRepeat * multiple * dimsize;
__memcpy(src, source + tid, step * dimsize * sizeof(float), GDRAM2NRAM);
for(int j = 0; j < step; j++){
else{
int multiple = maxNum / dimsize;
int size = taskDim * multiple;
int remainS = othersize % size;
int taskRepeat = (othersize - remainS) / size;
int remainT = remainS % taskDim;
int stepEasy = (remainS - remainT) / taskDim;
int stepHard = stepEasy + 1;
int step = (taskId < remainT ? stepHard : stepEasy);
//The amount of othersize allocated to each taskId is taskRepeat * multiple + step
//Overall, each taskId processes (taskRepeat * multiple + step) * dimsize elements
int startHard = taskId * (taskRepeat * multiple + stepHard);
int startEasy = remainT * (taskRepeat * multiple + stepHard) + (taskId - remainT) * (taskRepeat * multiple + stepEasy);
int indStart = (taskId < remainT ? startHard: startEasy);
source = source + indStart * dimsize;
destination = destination + indStart * dimsize;
//-----------------------------------------allocate memory
float* src = nram_buffer;//src[3 * maxNum]
float* tmp = src + 3 * maxNum;//tmp[dimS]
float* destSum = tmp + dimS;//destSum[dimS]; dimS is the smallest power of two with dimS >= max(dimsize, warpSize)
//-----------------------------------------
//printf("taskId:%d, taskRepeat:%d, step:%d, repeatDim:%d, indstart:%d, %d\n", taskId, taskRepeat, step, repeatDim, indStart, indStart * dimsize);
int tid;
for(int s = 0; s < taskRepeat + 2; s++){
if(s < taskRepeat){
tid = s * multiple * dimsize;
__memcpy_async(src + s % 3 * maxNum, source + tid, multiple * dimsize * sizeof(float), GDRAM2NRAM);
}
if(s > 0 && s < taskRepeat + 1){
for(int j = 0; j < multiple; j++){
__bang_write_zero(destSum, dimS);
__bang_write_zero(destSumFinal, warpSize);
__bang_write_value(tmp, dimS, -INFINITY);
__memcpy(tmp, src + (s - 1) %3 * maxNum + j * dimsize, dimsize * sizeof(float), NRAM2NRAM);
__bang_argmax(srcMax, tmp, dimS);
__bang_write_value(tmp, dimS, srcMax[0]);
__memcpy(tmp, src + (s - 1) %3 * maxNum + j * dimsize, dimsize * sizeof(float), NRAM2NRAM);
__bang_sub_scalar(tmp, tmp, srcMax[0], dimS);
__bang_active_exp_less_0(tmp, tmp, dimS);//tmp[dimsize:dimS] = exp(0)
__bang_add(destSum, destSum, tmp, dimS);
int segNum = dimS / warpSize;//start of the tree summation
for(int strip = segNum/2; strip > 0; strip = strip / 2){
for(int i = 0; i < strip ; i++){
__bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
}
}
__bang_reduce_sum(destSumFinal, destSum, warpSize);//destSumFinal[0] now holds the sum of the current dimsize elements
destSumFinal[0] = destSumFinal[0] - (dimS - dimsize);
//Now let's start writing back the data
float globalSumInv = 1.0/destSumFinal[0];
__bang_mul_scalar(tmp, tmp, globalSumInv, maxNum);
__memcpy(src + (s - 1) %3 * maxNum + j * dimsize, tmp, dimsize * sizeof(float), NRAM2NRAM);
}
}
if(s > 1){
tid = (s - 2) * multiple * dimsize;
__memcpy_async(destination + tid, src + (s - 2) % 3 * maxNum, multiple * dimsize * sizeof(float), NRAM2GDRAM);
}
__sync_all_ipu();
//the normalized result must be written straight back to GDRAM here; if it were first staged in src
//and then copied out, the copy to GDRAM could run before the compute stage had updated src
}
for(int s = 0; s < step; s++){//step handles the leftover rows of othersize that do not fill a full multiple-row tile
tid = taskRepeat * multiple * dimsize + s * dimsize;
__bang_write_zero(destSum, dimS);
__bang_write_zero(destSumFinal, warpSize);
destNewMax = -INFINITY;
for(int i = 0; i < repeatDim; i++){//repeatDim is the number of dimS-sized chunks needed to cover dimsize for a fixed otherIdx
__memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM);
__bang_argmax(srcMax, tmp, dimS);
if(destNewMax < srcMax[0]){
destNewMax = srcMax[0];
}
__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
__bang_active_exp_less_0(tmp, tmp, dimS);
if(i > 0){
__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS);
}
__bang_add(destSum, destSum, tmp, dimS);
destOldMax = destNewMax;
}
if(remainDim){//remainDim is the tail of dimsize that is not divisible by dimS for a fixed otherIdx
__bang_write_value(tmp, dimS, -INFINITY);
__memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM);
__bang_argmax(srcMax, tmp, dimS);
if(destNewMax < srcMax[0]){
destNewMax = srcMax[0];
}
__bang_write_value(tmp, dimS, destNewMax);//must be re-initialized to destNewMax so the padded tail contributes exp(0) = 1
__memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM);
__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
__bang_active_exp_less_0(tmp, tmp, dimS);
if(repeatDim > 0){
__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS);
}
__bang_add(destSum, destSum, tmp, dimS);
destOldMax = destNewMax;
}
int segNum = dimS / warpSize;//start of the tree summation
__bang_write_value(tmp, dimS, -INFINITY);
__memcpy(tmp, source + tid, dimsize * sizeof(float), GDRAM2NRAM);
__bang_argmax(srcMax, tmp, dimS);
__bang_write_value(tmp, dimS, srcMax[0]);
__memcpy(tmp, source + tid, dimsize * sizeof(float), GDRAM2NRAM);
__bang_sub_scalar(tmp, tmp, srcMax[0], dimS);
__bang_active_exp_less_0(tmp, tmp, dimS);
__bang_add(destSum, destSum, tmp, dimS);
int segNum = dimS / warpSize;
for(int strip = segNum/2; strip > 0; strip = strip / 2){
for(int i = 0; i < strip ; i++){
__bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
}
}
}
__bang_reduce_sum(destSumFinal, destSum, warpSize);
//destSumFinal[0] now holds the sum of the current dimsize elements
if(remainDim){
destSumFinal[0] = destSumFinal[0] - (dimS - remainDim);
}
//__bang_printf("taskId:%d, max:%.2f, sum:%.2f\n", taskId, destNewMax, destSumFinal[0]);
destSumFinal[0] = destSumFinal[0] - (dimS - dimsize);
//__bang_printf(":%.2f,max:%.2f, sum:%.2f, final:%.2f\n",tmp[1], srcMax[0], destSum[1], destSumFinal[0]);
float globalSumInv = 1.0/destSumFinal[0];
if(remainDim){
__bang_mul_scalar(tmp, tmp, globalSumInv, dimS);
__memcpy(destination + tid + j * dimsize + repeatDim * dimS, tmp, remainDim * sizeof(float), NRAM2GDRAM);
}
for(int i = 0; i < repeatDim; i++){
__memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM);
__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
__bang_active_exp_less_0(tmp, tmp, dimS);
__bang_mul_scalar(tmp, tmp, globalSumInv, dimS);
__memcpy(destination + tid + j * dimsize + i * dimS, tmp, dimS * sizeof(float), NRAM2GDRAM);
}
}
}
__bang_mul_scalar(tmp, tmp, globalSumInv, maxNum);
__memcpy(destination + tid, tmp, dimsize * sizeof(float), NRAM2GDRAM);
}
}
}
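
For checking the device output of softmaxKernelAxis_e, a plain host-side reference for axis = -1 may be useful; this helper is hypothetical and not part of the commit:

```c
#include <math.h>
#include <stddef.h>

/* Hypothetical host-side reference for axis = -1, handy for diffing
   against the kernel output. */
void softmax_last_axis_ref(float *dst, const float *src,
                           int othersize, int dimsize) {
    for (int o = 0; o < othersize; o++) {
        const float *x = src + (size_t)o * dimsize;
        float *y = dst + (size_t)o * dimsize;
        float m = x[0];
        for (int i = 1; i < dimsize; i++)   /* row max */
            if (x[i] > m) m = x[i];
        float s = 0.0f;
        for (int i = 0; i < dimsize; i++)   /* row sum of exp(x - m) */
            s += expf(x[i] - m);
        float inv = 1.0f / s;
        for (int i = 0; i < dimsize; i++)
            y[i] = expf(x[i] - m) * inv;
    }
}
```
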
__mlu_device__ void softmaxKernelAxis_s(float* destination, float* source, int othersize, int dimsize, int stride) {// axis = 0
//-----------------------------------------allocate memory
float* src = nram_buffer;
float* tmpSum = src + maxNum;
float* tmpNewMax = src + 2 * maxNum;
float* tmpOldMax = src + 3 * maxNum;
float* src = nram_buffer;// src[3 * maxNum]
float* tmpSum = src + 3 * maxNum;//tmpSum[maxNum]
float* tmpNewMax = src + 4 * maxNum;//tmpNewMax[maxNum]
float* tmpOldMax = src + 5 * maxNum;//tmpOldMax[maxNum]
//-----------------------------------------
int remain = othersize % taskDim;
int stepEasy = (othersize - remain)/taskDim;
@@ -470,67 +529,89 @@ __mlu_device__ void softmaxKernelAxis_s(float* destination, float* source, int o
for(int j = 0; j < repeat; j++){
__bang_write_value(tmpNewMax, maxNum, -INFINITY);
__bang_write_zero(tmpSum, maxNum);
for(int i = 0; i < dimsize; i++){
__memcpy(src, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
__bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);//Continuously updating the maximum value
__bang_sub(src, src, tmpNewMax, maxNum);//x - M
__bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
if(i > 0){
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
__bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);//sum = sum * exp(oldM - newM)
for(int i = 0; i < dimsize + 1; i++){
if(i < dimsize){
__memcpy_async(src + i % 2 * maxNum, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
}
__bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M)
__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
if(i > 0){
__bang_maxequal(tmpNewMax, tmpNewMax, src + (i - 1) % 2 * maxNum, maxNum);//Continuously updating the maximum value
__bang_sub(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, tmpNewMax, maxNum);//x - M
__bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, maxNum);//exp(x - M)
if(i > 1){
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
__bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);//sum = sum * exp(oldM - newM)
}
__bang_add(tmpSum, tmpSum, src + (i - 1) % 2 * maxNum, maxNum);//sum += exp(x - M)
__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
}
__sync_all_ipu();
}
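
This first pass uses the two-stage variant of the pipeline: the loop runs dimsize + 1 iterations, and iteration i loads tile i into half i % 2 while the compute stage consumes tile i - 1 from the other half. A host-side sketch of the indexing, with memcpy standing in for __memcpy_async:

```c
#include <stdio.h>
#include <string.h>

/* Two-stage (double-buffered) pipeline: load stage and compute stage of
   the same iteration touch different halves of buf (i % 2). */
int main(void) {
    enum { TILES = 4, LEN = 8 };
    float srcG[TILES * LEN], acc[LEN] = {0}, buf[2 * LEN];
    for (int i = 0; i < TILES * LEN; i++) srcG[i] = 1.0f;
    for (int i = 0; i < TILES + 1; i++) {
        if (i < TILES)                      /* stage 1: load tile i */
            memcpy(buf + i % 2 * LEN, srcG + i * LEN, LEN * sizeof(float));
        if (i > 0)                          /* stage 2: consume tile i-1 */
            for (int k = 0; k < LEN; k++)
                acc[k] += buf[(i - 1) % 2 * LEN + k];
        /* the device loop ends each iteration with __sync_all_ipu() */
    }
    printf("acc[0] = %f (expect %d)\n", acc[0], TILES);
    return 0;
}
```
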
__bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum
//Start exponential transformation and write back to GDRAM
__bang_mul(src, src, tmpSum, maxNum);//reuse the data left in src by the last iteration of the loop above
__memcpy(destination + (dimsize - 1) * stride + indStart + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
for(int i = 0; i < dimsize - 1; i++){
__memcpy(src, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
__bang_sub(src, src, tmpNewMax, maxNum);//x - M
__bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
__bang_mul(src, src, tmpSum, maxNum);
__memcpy(destination + i * stride + indStart + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
for(int i = 0; i < dimsize + 2; i++){
if(i < dimsize){
__memcpy_async(src + i % 3 * maxNum, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
}
if(i > 0 && i < dimsize + 1){
__bang_sub(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpNewMax, maxNum);//x - M
__bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, maxNum);//exp(x - M)
__bang_mul(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpSum, maxNum);
}
if(i > 1){
__memcpy_async(destination + (i - 2) * stride + indStart + j * maxNum, src + (i - 2) % 3 * maxNum, maxNum * sizeof(float), NRAM2GDRAM);
}
__sync_all_ipu();
}
}
if(remainNram){
__bang_write_value(tmpNewMax, maxNum, -INFINITY);
__bang_write_zero(tmpSum, maxNum);
__bang_write_zero(src, maxNum);
__bang_write_zero(src, 3 * maxNum);
for(int i = 0; i < dimsize; i++){
__memcpy(src, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
__bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);
__bang_sub(src, src, tmpNewMax, maxNum);//x - M
__bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
if(i > 0){
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
__bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum); //sum = sum * exp(oldM - newM)
for(int i = 0; i < dimsize + 1; i++){
if(i < dimsize){
__memcpy_async(src + i % 2 * maxNum, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
}
__bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M)
__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
if(i > 0){
__bang_maxequal(tmpNewMax, tmpNewMax, src + (i - 1) % 2 * maxNum, maxNum);
__bang_sub(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, tmpNewMax, maxNum);//x - M
__bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, maxNum);//exp(x - M)
if(i > 1){
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
__bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum); //sum = sum * exp(oldM - newM)
}
__bang_add(tmpSum, tmpSum, src + (i - 1) % 2 * maxNum, maxNum);//sum += exp(x - M)
__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
}
__sync_all_ipu();
}
__bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum
//Start exponential transformation and write back to GDRAM
__bang_mul(src, src, tmpSum, maxNum);//reuse the data left in src by the last iteration of the loop above
__memcpy(destination + (dimsize - 1) * stride + indStart + repeat * maxNum, src, remainNram * sizeof(float), NRAM2GDRAM);
for(int i = 0; i < dimsize - 1; i++){
__memcpy(src, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
__bang_sub(src, src, tmpNewMax, maxNum);//x - M
__bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
__bang_mul(src, src, tmpSum, maxNum);
__memcpy(destination + i * stride + indStart + repeat * maxNum, src, remainNram * sizeof(float), NRAM2GDRAM);
for(int i = 0; i < dimsize + 2; i++){
if(i < dimsize){
__memcpy_async(src + i % 3 * maxNum, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
}
if(i > 0 && i < dimsize + 1){
__bang_sub(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpNewMax, maxNum);//x - M
__bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, maxNum);//exp(x - M)
__bang_mul(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpSum, maxNum);
}
if(i > 1){
__memcpy_async(destination + (i - 2) * stride + indStart + repeat * maxNum, src + (i - 2) % 3 * maxNum, remainNram * sizeof(float), NRAM2GDRAM);
}
__sync_all_ipu();
}
}
}
__mlu_global__ void softmaxUnion1(float *mlu_destination, float *mlu_src, int nDim, int axis, int othersize, int frontsize, int dimsize, int stride){
if(axis == nDim - 1){
int dimS;