stream kernel

This commit is contained in:
xgqdut2016 2024-03-07 09:01:00 +00:00
parent d4721cb40c
commit a6c919b61d
1 changed files with 330 additions and 249 deletions

View File

@ -4,7 +4,7 @@
const int NRAM_MAX_SIZE = 1024 * 512;//the maximum NRAM memory is 1024 * 768 const int NRAM_MAX_SIZE = 1024 * 512;//the maximum NRAM memory is 1024 * 768
const int nramNum = NRAM_MAX_SIZE/sizeof(float); const int nramNum = NRAM_MAX_SIZE/sizeof(float);
__nram__ float nram_buffer[nramNum]; __nram__ float nram_buffer[nramNum];
const int SRC_MAX_SIZE = 1024 * 128;//The subsequent tree summation must ensure that SRC-MAX-SIZE is a power of 2 const int SRC_MAX_SIZE = 1024 * 64;//The subsequent tree summation must ensure that SRC-MAX-SIZE is a power of 2
//4 * SRC_MAX_SIZE must <= NRAM_MAX_SIZE //4 * SRC_MAX_SIZE must <= NRAM_MAX_SIZE
const int maxNum = SRC_MAX_SIZE/sizeof(float); const int maxNum = SRC_MAX_SIZE/sizeof(float);
const int warpSize = 32; const int warpSize = 32;
@ -98,7 +98,7 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
//-----------------------------------------allocate memory //-----------------------------------------allocate memory
float* src = nram_buffer; float* src = nram_buffer;
float* tmp = src + maxNum; float* tmp = src + 3 * maxNum;
float* tmpOldMax = tmp + strideS; float* tmpOldMax = tmp + strideS;
float* tmpNewMax = tmpOldMax + strideS; float* tmpNewMax = tmpOldMax + strideS;
float* tmpSum = tmpNewMax + strideS; float* tmpSum = tmpNewMax + strideS;
@ -123,26 +123,31 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
__bang_write_value(tmp, strideS, -INFINITY);//Must be initialized to negative infinity __bang_write_value(tmp, strideS, -INFINITY);//Must be initialized to negative infinity
__bang_write_zero(tmpSum, strideS);//Must be initialized to zero __bang_write_zero(tmpSum, strideS);//Must be initialized to zero
for(int j = 0; j < repeat; j++){ for(int j = 0; j < repeat + 1; j++){
__memcpy(src, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM); if(j < repeat){
__memcpy_async(src + j % 2 * maxNum, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM);
}
if(j > 0){
for(int m = 0; m < multiple; m++){ for(int m = 0; m < multiple; m++){
__memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM); __memcpy(tmp, src + (j - 1) % 2 * maxNum + m * stride, stride * sizeof(float), NRAM2NRAM);
__bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);//Although the stream S stream section after tmpNewMax is 0, there is no need to write back to GDRAM, which does not affect the result __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);//Although the stream S stream section after tmpNewMax is 0, there is no need to write back to GDRAM, which does not affect the result
__bang_sub(tmp, tmp, tmpNewMax, strideS);//The stripe S stripe section after tmp is 0 __bang_sub(tmp, tmp, tmpNewMax, strideS);//The stripe S stripe section after tmp is 0
__bang_active_exp_less_0(tmp, tmp, strideS); __bang_active_exp_less_0(tmp, tmp, strideS);
if(j != 0 || m != 0){ if(j != 1 || m != 0){
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM) __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
__bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);//sum = sum * exp(oldM - newM) __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);//sum = sum * exp(oldM - newM)
} }
__bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M) __bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
//if(m == 0) __bang_printf("tmp:%.2f, tmpMax[0]:%.2f,tmpSum[0]:%.2f\n", tmp[1], tmpNewMax[1],tmpSum[0]);
__memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM
} }
} }
//__bang_printf("tmpOldMax[0]:%.2f,tmpSum[0]:%.2f\n", tmpNewMax[0],tmpSum[0]); __sync_all_ipu();
}
if(remain){ if(remain){
__memcpy(src, source + frontIdx + repeat * multiple * stride, remain * stride * sizeof(float), GDRAM2NRAM); __memcpy(src, source + frontIdx + repeat * multiple * stride, remain * stride * sizeof(float), GDRAM2NRAM);
for(int m = 0; m < remain; m++){ for(int m = 0; m < remain; m++){
@ -161,9 +166,9 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
} }
//At this point, tmpNewMax stores the maximum value of the data corresponding to a fixed frontIdx and bedsize, while tmpSum stores the corresponding value sum //At this point, tmpNewMax stores the maximum value of the data corresponding to a fixed frontIdx and bedsize, while tmpSum stores the corresponding value sum
//__bang_printf("tmpOldMax[0]:%.2f,tmpSum[0]:%.2f\n", tmpNewMax[2],tmpSum[2]);
__bang_active_reciphp(tmpSum, tmpSum, strideS); __bang_active_reciphp(tmpSum, tmpSum, strideS);
//__bang_printf("tmpOldMax[0]:%.2f,tmpSum[0]:%.2f\n", tmpNewMax[2],tmpSum[2]);
if(remain){ if(remain){
for(int m = 0; m < remain; m++){ for(int m = 0; m < remain; m++){
__memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM); __memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM);
@ -174,23 +179,31 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
} }
} }
for(int j = 0 ; j < repeat; j++){ for(int j = 0 ; j < repeat + 2; j++){
__memcpy(src, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM); if(j < repeat){
__memcpy_async(src + j % 3 * maxNum, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM);
}
if(j > 0 && j < repeat + 1){
for(int m = 0; m < multiple; m++){ for(int m = 0; m < multiple; m++){
__memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM); __memcpy(tmp, src + (j - 1) % 3 * maxNum + m * stride, stride * sizeof(float), NRAM2NRAM);
__bang_sub(tmp, tmp, tmpNewMax, strideS); __bang_sub(tmp, tmp, tmpNewMax, strideS);
__bang_active_exp_less_0(tmp, tmp, strideS); __bang_active_exp_less_0(tmp, tmp, strideS);
__bang_mul(tmp, tmp, tmpSum, strideS); __bang_mul(tmp, tmp, tmpSum, strideS);
__memcpy(destination + frontIdx + j * multiple * stride + m * stride, tmp, stride * sizeof(float), NRAM2GDRAM); __memcpy(src + (j - 1) % 3 * maxNum + m * stride, tmp, stride * sizeof(float), NRAM2NRAM);
} }
} }
if(j > 1){
__memcpy_async(destination + frontIdx + (j - 2) * multiple * stride, src + (j - 2) % 3 * maxNum, size * sizeof(float), NRAM2GDRAM);
}
__sync_all_ipu();
}
} }
} }
else if(dimsize * stride < maxNum){ else if(dimsize * stride < maxNum){
//-----------------------------------------allocate memory //-----------------------------------------allocate memory
float* src = nram_buffer; float* src = nram_buffer;
float* tmp = src + maxNum; float* tmp = src + 3 * maxNum;
float* tmpOldMax = tmp + strideS; float* tmpOldMax = tmp + strideS;
float* tmpNewMax = tmpOldMax + strideS; float* tmpNewMax = tmpOldMax + strideS;
float* tmpSum = tmpNewMax + strideS; float* tmpSum = tmpNewMax + strideS;
@ -211,15 +224,18 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
source = source + indStart * behindsize;//indStart * behindsize Indicates the offset corresponding to different taskIds source = source + indStart * behindsize;//indStart * behindsize Indicates the offset corresponding to different taskIds
destination = destination + indStart * behindsize; destination = destination + indStart * behindsize;
int tid; int tid;
for(int s = 0; s < taskRepeat; s++){ for(int s = 0; s < taskRepeat + 2; s++){
if(s < taskRepeat){
tid = s * multiple * behindsize; tid = s * multiple * behindsize;
__memcpy(src, source + tid, multiple * behindsize * sizeof(float), GDRAM2NRAM); __memcpy_async(src + s % 3 * maxNum, source + tid, multiple * behindsize * sizeof(float), GDRAM2NRAM);
}
if(s > 0 && s < taskRepeat + 1){
for(int m = 0; m < multiple; m++){ for(int m = 0; m < multiple; m++){
__bang_write_zero(tmpSum, strideS); __bang_write_zero(tmpSum, strideS);
__bang_write_value(tmp, strideS, -INFINITY); __bang_write_value(tmp, strideS, -INFINITY);
__bang_write_value(tmpNewMax, strideS, -INFINITY); __bang_write_value(tmpNewMax, strideS, -INFINITY);
for(int i = 0; i < dimsize; i++){ for(int i = 0; i < dimsize; i++){
__memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM); __memcpy(tmp, src + (s - 1) % 3 * maxNum + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
__bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS); __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);
__bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M __bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
__bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M) __bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
@ -233,18 +249,23 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
} }
__bang_active_reciphp(tmpSum, tmpSum, strideS); __bang_active_reciphp(tmpSum, tmpSum, strideS);
__bang_mul(tmp, tmp, tmpSum, strideS);//The data stored in tmp at the end of the loop above can be utilized __bang_mul(tmp, tmp, tmpSum, strideS);//The data stored in tmp at the end of the loop above can be utilized
//__memcpy(destination + tid + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2GDRAM);
__memcpy(src + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2NRAM); __memcpy(src + (s - 1) % 3 * maxNum + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2NRAM);
for(int i = 0; i < dimsize - 1; i++){ for(int i = 0; i < dimsize - 1; i++){
__memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM); __memcpy(tmp, src + (s - 1) % 3 * maxNum + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
__bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M __bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
__bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M) __bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
__bang_mul(tmp, tmp, tmpSum, strideS); __bang_mul(tmp, tmp, tmpSum, strideS);
//__memcpy(destination + tid + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2GDRAM);
__memcpy(src + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2NRAM); __memcpy(src + (s - 1) % 3 * maxNum + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2NRAM);
} }
} }
__memcpy(destination + tid, src, multiple * behindsize * sizeof(float), NRAM2GDRAM); }
if(s > 1){
tid = (s - 2) * multiple * behindsize;
__memcpy_async(destination + tid, src + (s - 2) % 3 * maxNum, multiple * behindsize * sizeof(float), NRAM2GDRAM);
}
__sync_all_ipu();
} }
//__bang_printf("taskId:%d, multiple:%d, taskRepeat:%d, step:%d, indStart:%d\n",taskId, multiple, taskRepeat, step, indStart * behindsize); //__bang_printf("taskId:%d, multiple:%d, taskRepeat:%d, step:%d, indStart:%d\n",taskId, multiple, taskRepeat, step, indStart * behindsize);
if(step){ if(step){
@ -288,6 +309,111 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
} }
__mlu_device__ void softmaxKernelAxis_e(float* destination, float* source, int othersize, int dimsize, int dimS) {// axis = -1 __mlu_device__ void softmaxKernelAxis_e(float* destination, float* source, int othersize, int dimsize, int dimS) {// axis = -1
__nram__ float destSumFinal[warpSize];
__nram__ float srcMax[2];
__nram__ float destOldMax;
__nram__ float destNewMax;
if(dimsize >= maxNum){
float *src = nram_buffer;
float *destSum = src + 3 * maxNum;
int remain = dimsize % maxNum;
int repeat = (dimsize - remain)/maxNum;
int otherRemain = othersize % taskDim;
int stepEasy = (othersize - otherRemain) / taskDim;
int stepHard = stepEasy + 1;
int startHard = taskId * stepHard;
int startEasy = otherRemain * stepHard + (taskId - otherRemain) * stepEasy;
int indStart = (taskId < otherRemain ? startHard : startEasy);
source = source + indStart * dimsize;
destination = destination + indStart * dimsize;
destOldMax = -INFINITY;
destNewMax = -INFINITY;
__bang_write_zero(destSum, maxNum);
for(int i = 0; i < repeat + 1; i++){
if(i < repeat){
__memcpy_async(src + i % 2 * maxNum, source + i * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
}
if(i > 0){
__bang_argmax(srcMax, src + (i - 1) % 2 * maxNum, maxNum);
if(destNewMax < srcMax[0]){
destNewMax = srcMax[0];
}
__bang_sub_scalar(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, destNewMax, maxNum);
__bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, maxNum);
if(i > 1){
__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), maxNum);
}
__bang_add(destSum, destSum, src + (i - 1) % 2 * maxNum, maxNum);
destOldMax = destNewMax;
}
__sync_all_ipu();
}
//------------
if(remain){
__bang_write_value(src, maxNum, -INFINITY);
__memcpy(src, source + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);
__bang_argmax(srcMax, src, maxNum);
if(destNewMax < srcMax[0]){
destNewMax = srcMax[0];
}
__bang_write_value(src, maxNum, destNewMax);
__memcpy(src, source + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);
__bang_sub_scalar(src, src, destNewMax, maxNum);
__bang_active_exp_less_0(src, src, maxNum);
if(repeat > 0){
__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), maxNum);
}
__bang_add(destSum, destSum, src, maxNum);
destOldMax = destNewMax;
}
//--------------
//--------------------------------
__bang_write_zero(destSumFinal, warpSize);
int segNum = maxNum / warpSize;
for(int strip = segNum/2; strip > 0; strip = strip / 2){
for(int i = 0; i < strip ; i++){
__bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
}
}
__bang_reduce_sum(destSumFinal, destSum, warpSize);
if(remain){
destSumFinal[0] = destSumFinal[0] - (maxNum - remain);
}
//-----------
float globalSumInv = 1.0/destSumFinal[0];
for(int i = 0; i < repeat + 2; i++){
if(i < repeat){
__memcpy_async(src + i % 3 * maxNum, source + i * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
}
if(i > 0 && i < repeat){
__bang_sub_scalar(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, destNewMax, maxNum);
__bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, maxNum);
__bang_mul_scalar(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, globalSumInv, maxNum);
}
if(i > 1){
__memcpy_async(destination + (i - 2) * maxNum, src + (i - 2) % 3 * maxNum, maxNum * sizeof(float), NRAM2GDRAM);
}
__sync_all_ipu();
}
if(remain){
__bang_write_value(src, maxNum, destNewMax);
__memcpy(src, source + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);
__bang_sub_scalar(src, src, destNewMax, maxNum);
__bang_active_exp_less_0(src, src, maxNum);
__bang_mul_scalar(src, src, globalSumInv, maxNum);
__memcpy(destination + repeat * maxNum, src, remain * sizeof(float), NRAM2GDRAM);
}
}
else{
int multiple = maxNum / dimsize; int multiple = maxNum / dimsize;
int size = taskDim * multiple; int size = taskDim * multiple;
int remainS = othersize % size; int remainS = othersize % size;
@ -305,58 +431,31 @@ __mlu_device__ void softmaxKernelAxis_e(float* destination, float* source, int o
destination = destination + indStart * dimsize; destination = destination + indStart * dimsize;
//-----------------------------------------allocate memory //-----------------------------------------allocate memory
float* src = nram_buffer; float* src = nram_buffer;//src[maxNum]
float* tmp = src + maxNum; float* tmp = src + 3 * maxNum;//tmp[dimS]
float* destSum = tmp + dimS; float* destSum = tmp + dimS;//destSum[dimS],dimS >= max(dimsize, warpSize), dimS = pow(2,K) ,pow(2,K - 1) < dimsize
int remainDim = dimsize % dimS;//Dimsize may not be a power of 2
int repeatDim = (dimsize - remainDim) / dimS;
__nram__ float destSumFinal[warpSize];//Reduce destSum to destFinal [0]
__nram__ float srcMax[2];
__nram__ float destOldMax;
__nram__ float destNewMax;
//----------------------------------------- //-----------------------------------------
//printf("taskId:%d, taskRepeat:%d, step:%d, repeatDim:%d, indstart:%d, %d\n", taskId, taskRepeat, step, repeatDim, indStart, indStart * dimsize); //printf("taskId:%d, taskRepeat:%d, step:%d, repeatDim:%d, indstart:%d, %d\n", taskId, taskRepeat, step, repeatDim, indStart, indStart * dimsize);
int tid; int tid;
for(int s = 0; s < taskRepeat; s++){ for(int s = 0; s < taskRepeat + 2; s++){
if(s < taskRepeat){
tid = s * multiple * dimsize; tid = s * multiple * dimsize;
__memcpy(src, source + tid, multiple * dimsize * sizeof(float), GDRAM2NRAM); __memcpy_async(src + s % 3 * maxNum, source + tid, multiple * dimsize * sizeof(float), GDRAM2NRAM);
}
if(s > 0 && s < taskRepeat + 1){
for(int j = 0; j < multiple; j++){ for(int j = 0; j < multiple; j++){
__bang_write_zero(destSum, dimS); __bang_write_zero(destSum, dimS);
__bang_write_zero(destSumFinal, warpSize); __bang_write_zero(destSumFinal, warpSize);
destNewMax = -INFINITY;
for(int i = 0; i < repeatDim; i++){
__memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM);
__bang_argmax(srcMax, tmp, dimS);
if(destNewMax < srcMax[0]){
destNewMax = srcMax[0];
}
__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
__bang_active_exp_less_0(tmp, tmp, dimS);
if(i > 0){
__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS);
}
__bang_add(destSum, destSum, tmp, dimS);
destOldMax = destNewMax;
}
if(remainDim){
__bang_write_value(tmp, dimS, -INFINITY); __bang_write_value(tmp, dimS, -INFINITY);
__memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM);
__memcpy(tmp, src + (s - 1) %3 * maxNum + j * dimsize, dimsize * sizeof(float), NRAM2NRAM);
__bang_argmax(srcMax, tmp, dimS); __bang_argmax(srcMax, tmp, dimS);
if(destNewMax < srcMax[0]){ __bang_write_value(tmp, dimS, srcMax[0]);
destNewMax = srcMax[0]; __memcpy(tmp, src + (s - 1) %3 * maxNum + j * dimsize, dimsize * sizeof(float), NRAM2NRAM);
} __bang_sub_scalar(tmp, tmp, srcMax[0], dimS);
__bang_write_value(tmp, dimS, destNewMax);//Must be reinitialized to NewMax __bang_active_exp_less_0(tmp, tmp, dimS);//tmp[dimsize:dimS] = exp(0)
__memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM);
__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
__bang_active_exp_less_0(tmp, tmp, dimS);
if(repeatDim > 0){
__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS);
}
__bang_add(destSum, destSum, tmp, dimS); __bang_add(destSum, destSum, tmp, dimS);
destOldMax = destNewMax;
}
int segNum = dimS / warpSize;//Starting numerical summation int segNum = dimS / warpSize;//Starting numerical summation
for(int strip = segNum/2; strip > 0; strip = strip / 2){ for(int strip = segNum/2; strip > 0; strip = strip / 2){
@ -365,98 +464,58 @@ __mlu_device__ void softmaxKernelAxis_e(float* destination, float* source, int o
} }
} }
__bang_reduce_sum(destSumFinal, destSum, warpSize);//At this point, destSumFinal [0] saves the numerical value of the current dimsize length data sum __bang_reduce_sum(destSumFinal, destSum, warpSize);//At this point, destSumFinal [0] saves the numerical value of the current dimsize length data sum
if(remainDim){ destSumFinal[0] = destSumFinal[0] - (dimS - dimsize);
destSumFinal[0] = destSumFinal[0] - (dimS - remainDim);
}
//Now let's start writing back the data //Now let's start writing back the data
float globalSumInv = 1.0/destSumFinal[0]; float globalSumInv = 1.0/destSumFinal[0];
if(remainDim){ __bang_mul_scalar(tmp, tmp, globalSumInv, maxNum);
__bang_mul_scalar(tmp, tmp, globalSumInv, dimS); __memcpy(src + (s - 1) %3 * maxNum + j * dimsize, tmp, dimsize * sizeof(float), NRAM2NRAM);
__memcpy(destination + tid + j * dimsize + repeatDim * dimS, tmp, remainDim * sizeof(float), NRAM2GDRAM);
}
for(int i = 0; i < repeatDim; i++){
__memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM);
__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
__bang_active_exp_less_0(tmp, tmp, dimS);
__bang_mul_scalar(tmp, tmp, globalSumInv, dimS);
__memcpy(destination + tid + j * dimsize + i * dimS, tmp, dimS * sizeof(float), NRAM2GDRAM);
} }
} }
if(s > 1){
tid = (s - 2) * multiple * dimsize;
__memcpy_async(destination + tid, src + (s - 2) % 3 * maxNum, multiple * dimsize * sizeof(float), NRAM2GDRAM);
}
__sync_all_ipu();
//it is necessary to write back to GDRAM immediately. If you first write back to src and then write back to GDRAM, //it is necessary to write back to GDRAM immediately. If you first write back to src and then write back to GDRAM,
//there may be a situation where src writes back to GDRAM before modifying the src data //there may be a situation where src writes back to GDRAM before modifying the src data
} }
if(step){//Step targets parts of othersize that cannot be divided by multiple * dimsize for(int s = 0; s < step; s++){//Step targets parts of othersize that cannot be divided by multiple * dimsize
tid = taskRepeat * multiple * dimsize; tid = taskRepeat * multiple * dimsize + s * dimsize;
__memcpy(src, source + tid, step * dimsize * sizeof(float), GDRAM2NRAM);
for(int j = 0; j < step; j++){
__bang_write_zero(destSum, dimS); __bang_write_zero(destSum, dimS);
__bang_write_zero(destSumFinal, warpSize); __bang_write_zero(destSumFinal, warpSize);
destNewMax = -INFINITY;
for(int i = 0; i < repeatDim; i++){//RepeatDim refers to the total number of cycles required to read the current dimsize data using dimS after fixing otherIdx
__memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM);
__bang_argmax(srcMax, tmp, dimS);
if(destNewMax < srcMax[0]){
destNewMax = srcMax[0];
}
__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
__bang_active_exp_less_0(tmp, tmp, dimS);
if(i > 0){
__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS);
}
__bang_add(destSum, destSum, tmp, dimS);
destOldMax = destNewMax;
}
if(remainDim){//RemainDim refers to the part of dimsize that cannot be divided by dimS after fixing otherIdx
__bang_write_value(tmp, dimS, -INFINITY); __bang_write_value(tmp, dimS, -INFINITY);
__memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM); __memcpy(tmp, source + tid, dimsize * sizeof(float), GDRAM2NRAM);
__bang_argmax(srcMax, tmp, dimS);
if(destNewMax < srcMax[0]){ __bang_argmax(srcMax, tmp, dimS);
destNewMax = srcMax[0]; __bang_write_value(tmp, dimS, srcMax[0]);
} __memcpy(tmp, source + tid, dimsize * sizeof(float), GDRAM2NRAM);
__bang_sub_scalar(tmp, tmp, srcMax[0], dimS);
__bang_write_value(tmp, dimS, destNewMax);//Must be reinitialized to NewMax
__memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM);
__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
__bang_active_exp_less_0(tmp, tmp, dimS); __bang_active_exp_less_0(tmp, tmp, dimS);
if(repeatDim > 0){
__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS);
}
__bang_add(destSum, destSum, tmp, dimS); __bang_add(destSum, destSum, tmp, dimS);
destOldMax = destNewMax;
} int segNum = dimS / warpSize;
int segNum = dimS / warpSize;//Starting numerical summation
for(int strip = segNum/2; strip > 0; strip = strip / 2){ for(int strip = segNum/2; strip > 0; strip = strip / 2){
for(int i = 0; i < strip ; i++){ for(int i = 0; i < strip ; i++){
__bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize); __bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
} }
} }
__bang_reduce_sum(destSumFinal, destSum, warpSize); __bang_reduce_sum(destSumFinal, destSum, warpSize);
//At this point, destSumFinal [0] saves the numerical value of the current dimsize length data sum destSumFinal[0] = destSumFinal[0] - (dimS - dimsize);
if(remainDim){ //__bang_printf(":%.2f,max:%.2f, sum:%.2f, final:%.2f\n",tmp[1], srcMax[0], destSum[1], destSumFinal[0]);
destSumFinal[0] = destSumFinal[0] - (dimS - remainDim);
}
//__bang_printf("taskId:%d, max:%.2f, sum:%.2f\n", taskId, destNewMax, destSumFinal[0]);
float globalSumInv = 1.0/destSumFinal[0]; float globalSumInv = 1.0/destSumFinal[0];
if(remainDim){ __bang_mul_scalar(tmp, tmp, globalSumInv, maxNum);
__bang_mul_scalar(tmp, tmp, globalSumInv, dimS); __memcpy(destination + tid, tmp, dimsize * sizeof(float), NRAM2GDRAM);
__memcpy(destination + tid + j * dimsize + repeatDim * dimS, tmp, remainDim * sizeof(float), NRAM2GDRAM);
}
for(int i = 0; i < repeatDim; i++){
__memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM);
__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
__bang_active_exp_less_0(tmp, tmp, dimS);
__bang_mul_scalar(tmp, tmp, globalSumInv, dimS);
__memcpy(destination + tid + j * dimsize + i * dimS, tmp, dimS * sizeof(float), NRAM2GDRAM);
}
} }
} }
} }
__mlu_device__ void softmaxKernelAxis_s(float* destination, float* source, int othersize, int dimsize, int stride) {// axis = 0 __mlu_device__ void softmaxKernelAxis_s(float* destination, float* source, int othersize, int dimsize, int stride) {// axis = 0
//-----------------------------------------allocate memory //-----------------------------------------allocate memory
float* src = nram_buffer; float* src = nram_buffer;// src[3 * maxNum]
float* tmpSum = src + maxNum; float* tmpSum = src + 3 * maxNum;//tmpSum[maxNum]
float* tmpNewMax = src + 2 * maxNum; float* tmpNewMax = src + 4 * maxNum;//tmpNewMax[maxNum]
float* tmpOldMax = src + 3 * maxNum; float* tmpOldMax = src + 5 * maxNum;//tmpOldMax[maxNum]
//----------------------------------------- //-----------------------------------------
int remain = othersize % taskDim; int remain = othersize % taskDim;
int stepEasy = (othersize - remain)/taskDim; int stepEasy = (othersize - remain)/taskDim;
@ -470,67 +529,89 @@ __mlu_device__ void softmaxKernelAxis_s(float* destination, float* source, int o
for(int j = 0; j < repeat; j++){ for(int j = 0; j < repeat; j++){
__bang_write_value(tmpNewMax, maxNum, -INFINITY); __bang_write_value(tmpNewMax, maxNum, -INFINITY);
__bang_write_zero(tmpSum, maxNum); __bang_write_zero(tmpSum, maxNum);
for(int i = 0; i < dimsize; i++){ for(int i = 0; i < dimsize + 1; i++){
__memcpy(src, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM); if(i < dimsize){
__bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);//Continuously updating the maximum value __memcpy_async(src + i % 2 * maxNum, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
__bang_sub(src, src, tmpNewMax, maxNum);//x - M }
__bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
if(i > 0){ if(i > 0){
__bang_maxequal(tmpNewMax, tmpNewMax, src + (i - 1) % 2 * maxNum, maxNum);//Continuously updating the maximum value
__bang_sub(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, tmpNewMax, maxNum);//x - M
__bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, maxNum);//exp(x - M)
if(i > 1){
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM) __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
__bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);//sum = sum * exp(oldM - newM) __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);//sum = sum * exp(oldM - newM)
} }
__bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M) __bang_add(tmpSum, tmpSum, src + (i - 1) % 2 * maxNum, maxNum);//sum += exp(x - M)
__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
} }
__sync_all_ipu();
}
__bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum __bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum
//Start exponential transformation and write back to GDRAM //Start exponential transformation and write back to GDRAM
__bang_mul(src, src, tmpSum, maxNum);//The data stored in the src at the end of the loop above can be utilized for(int i = 0; i < dimsize + 2; i++){
__memcpy(destination + (dimsize - 1) * stride + indStart + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM); if(i < dimsize){
for(int i = 0; i < dimsize - 1; i++){ __memcpy_async(src + i % 3 * maxNum, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
__memcpy(src, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM); }
__bang_sub(src, src, tmpNewMax, maxNum);//x - M if(i > 0 && i < dimsize + 1){
__bang_active_exp_less_0(src, src, maxNum);//exp(x - M) __bang_sub(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpNewMax, maxNum);//x - M
__bang_mul(src, src, tmpSum, maxNum); __bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, maxNum);//exp(x - M)
__memcpy(destination + i * stride + indStart + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM); __bang_mul(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpSum, maxNum);
}
if(i > 1){
__memcpy_async(destination + (i - 2) * stride + indStart + j * maxNum, src + (i - 2) % 3 * maxNum, maxNum * sizeof(float), NRAM2GDRAM);
}
__sync_all_ipu();
} }
} }
if(remainNram){ if(remainNram){
__bang_write_value(tmpNewMax, maxNum, -INFINITY); __bang_write_value(tmpNewMax, maxNum, -INFINITY);
__bang_write_zero(tmpSum, maxNum); __bang_write_zero(tmpSum, maxNum);
__bang_write_zero(src, maxNum); __bang_write_zero(src, 3 * maxNum);
for(int i = 0; i < dimsize; i++){ for(int i = 0; i < dimsize + 1; i++){
__memcpy(src, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM); if(i < dimsize){
__bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum); __memcpy_async(src + i % 2 * maxNum, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
__bang_sub(src, src, tmpNewMax, maxNum);//x - M }
__bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
if(i > 0){ if(i > 0){
__bang_maxequal(tmpNewMax, tmpNewMax, src + (i - 1) % 2 * maxNum, maxNum);
__bang_sub(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, tmpNewMax, maxNum);//x - M
__bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, maxNum);//exp(x - M)
if(i > 1){
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM) __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
__bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum); //sum = sum * exp(oldM - newM) __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum); //sum = sum * exp(oldM - newM)
} }
__bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M) __bang_add(tmpSum, tmpSum, src + (i - 1) % 2 * maxNum, maxNum);//sum += exp(x - M)
__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
} }
__sync_all_ipu();
}
__bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum __bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum
//Start exponential transformation and write back to GDRAM //Start exponential transformation and write back to GDRAM
__bang_mul(src, src, tmpSum, maxNum);//The data stored in the src at the end of the loop above can be utilized
__memcpy(destination + (dimsize - 1) * stride + indStart + repeat * maxNum, src, remainNram * sizeof(float), NRAM2GDRAM); for(int i = 0; i < dimsize + 2; i++){
for(int i = 0; i < dimsize - 1; i++){ if(i < dimsize){
__memcpy(src, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM); __memcpy_async(src + i % 3 * maxNum, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
__bang_sub(src, src, tmpNewMax, maxNum);//x - M }
__bang_active_exp_less_0(src, src, maxNum);//exp(x - M) if(i > 0 && i < dimsize + 1){
__bang_mul(src, src, tmpSum, maxNum); __bang_sub(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpNewMax, maxNum);//x - M
__memcpy(destination + i * stride + indStart + repeat * maxNum, src, remainNram * sizeof(float), NRAM2GDRAM); __bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, maxNum);//exp(x - M)
__bang_mul(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpSum, maxNum);
}
if(i > 1){
__memcpy_async(destination + (i - 2) * stride + indStart + repeat * maxNum, src + (i - 2) % 3 * maxNum, remainNram * sizeof(float), NRAM2GDRAM);
}
__sync_all_ipu();
} }
} }
} }
__mlu_global__ void softmaxUnion1(float *mlu_destination, float *mlu_src, int nDim, int axis, int othersize, int frontsize, int dimsize, int stride){ __mlu_global__ void softmaxUnion1(float *mlu_destination, float *mlu_src, int nDim, int axis, int othersize, int frontsize, int dimsize, int stride){
if(axis == nDim - 1){ if(axis == nDim - 1){
int dimS; int dimS;