forked from jiuyuan/InfiniTensor

stream kernel

This commit is contained in:
parent d4721cb40c
commit a6c919b61d
@@ -4,7 +4,7 @@
 const int NRAM_MAX_SIZE = 1024 * 512;//the maximum NRAM memory is 1024 * 768
 const int nramNum = NRAM_MAX_SIZE/sizeof(float);
 __nram__ float nram_buffer[nramNum];
-const int SRC_MAX_SIZE = 1024 * 128;//The subsequent tree summation must ensure that SRC-MAX-SIZE is a power of 2
+const int SRC_MAX_SIZE = 1024 * 64;//The subsequent tree summation must ensure that SRC-MAX-SIZE is a power of 2
 //4 * SRC_MAX_SIZE must <= NRAM_MAX_SIZE
 const int maxNum = SRC_MAX_SIZE/sizeof(float);
 const int warpSize = 32;
@@ -98,7 +98,7 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
 
 //-----------------------------------------allocate memory
 float* src = nram_buffer;
-float* tmp = src + maxNum;
+float* tmp = src + 3 * maxNum;
 float* tmpOldMax = tmp + strideS;
 float* tmpNewMax = tmpOldMax + strideS;
 float* tmpSum = tmpNewMax + strideS;
@@ -123,26 +123,31 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
 __bang_write_value(tmp, strideS, -INFINITY);//Must be initialized to negative infinity
 __bang_write_zero(tmpSum, strideS);//Must be initialized to zero
 
-for(int j = 0; j < repeat; j++){
-__memcpy(src, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM);
-for(int m = 0; m < multiple; m++){
-__memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM);
-
-__bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);//Although the stream S stream section after tmpNewMax is 0, there is no need to write back to GDRAM, which does not affect the result
-
-__bang_sub(tmp, tmp, tmpNewMax, strideS);//The stripe S stripe section after tmp is 0
-__bang_active_exp_less_0(tmp, tmp, strideS);
-if(j != 0 || m != 0){
-__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
-__bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
-__bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);//sum = sum * exp(oldM - newM)
-}
-__bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
-//if(m == 0) __bang_printf("tmp:%.2f, tmpMax[0]:%.2f,tmpSum[0]:%.2f\n", tmp[1], tmpNewMax[1],tmpSum[0]);
-__memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM
+for(int j = 0; j < repeat + 1; j++){
+if(j < repeat){
+__memcpy_async(src + j % 2 * maxNum, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM);
 }
+if(j > 0){
+for(int m = 0; m < multiple; m++){
+__memcpy(tmp, src + (j - 1) % 2 * maxNum + m * stride, stride * sizeof(float), NRAM2NRAM);
+
+__bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);//Although the stream S stream section after tmpNewMax is 0, there is no need to write back to GDRAM, which does not affect the result
+
+__bang_sub(tmp, tmp, tmpNewMax, strideS);//The stripe S stripe section after tmp is 0
+__bang_active_exp_less_0(tmp, tmp, strideS);
+if(j != 1 || m != 0){
+__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
+__bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
+__bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);//sum = sum * exp(oldM - newM)
+}
+__bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
+
+__memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM
+}
+}
+__sync_all_ipu();
 }
-//__bang_printf("tmpOldMax[0]:%.2f,tmpSum[0]:%.2f\n", tmpNewMax[0],tmpSum[0]);
+
 if(remain){
 __memcpy(src, source + frontIdx + repeat * multiple * stride, remain * stride * sizeof(float), GDRAM2NRAM);
 for(int m = 0; m < remain; m++){
@@ -161,9 +166,9 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
 }
 
 //At this point, tmpNewMax stores the maximum value of the data corresponding to a fixed frontIdx and bedsize, while tmpSum stores the corresponding value sum
-//__bang_printf("tmpOldMax[0]:%.2f,tmpSum[0]:%.2f\n", tmpNewMax[2],tmpSum[2]);
+
 __bang_active_reciphp(tmpSum, tmpSum, strideS);
-//__bang_printf("tmpOldMax[0]:%.2f,tmpSum[0]:%.2f\n", tmpNewMax[2],tmpSum[2]);
+
 if(remain){
 for(int m = 0; m < remain; m++){
 __memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM);
@@ -174,23 +179,31 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
 }
 
 }
-for(int j = 0 ; j < repeat; j++){
-__memcpy(src, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM);
-for(int m = 0; m < multiple; m++){
-__memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM);
-
-__bang_sub(tmp, tmp, tmpNewMax, strideS);
-__bang_active_exp_less_0(tmp, tmp, strideS);
-__bang_mul(tmp, tmp, tmpSum, strideS);
-__memcpy(destination + frontIdx + j * multiple * stride + m * stride, tmp, stride * sizeof(float), NRAM2GDRAM);
+for(int j = 0 ; j < repeat + 2; j++){
+if(j < repeat){
+__memcpy_async(src + j % 3 * maxNum, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM);
 }
+if(j > 0 && j < repeat + 1){
+for(int m = 0; m < multiple; m++){
+__memcpy(tmp, src + (j - 1) % 3 * maxNum + m * stride, stride * sizeof(float), NRAM2NRAM);
+
+__bang_sub(tmp, tmp, tmpNewMax, strideS);
+__bang_active_exp_less_0(tmp, tmp, strideS);
+__bang_mul(tmp, tmp, tmpSum, strideS);
+__memcpy(src + (j - 1) % 3 * maxNum + m * stride, tmp, stride * sizeof(float), NRAM2NRAM);
+}
+}
+if(j > 1){
+__memcpy_async(destination + frontIdx + (j - 2) * multiple * stride, src + (j - 2) % 3 * maxNum, size * sizeof(float), NRAM2GDRAM);
+}
+__sync_all_ipu();
 }
 }
 }
 else if(dimsize * stride < maxNum){
 //-----------------------------------------allocate memory
 float* src = nram_buffer;
-float* tmp = src + maxNum;
+float* tmp = src + 3 * maxNum;
 float* tmpOldMax = tmp + strideS;
 float* tmpNewMax = tmpOldMax + strideS;
 float* tmpSum = tmpNewMax + strideS;
@@ -211,40 +224,48 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
 source = source + indStart * behindsize;//indStart * behindsize Indicates the offset corresponding to different taskIds
 destination = destination + indStart * behindsize;
 int tid;
-for(int s = 0; s < taskRepeat; s++){
-tid = s * multiple * behindsize;
-__memcpy(src, source + tid, multiple * behindsize * sizeof(float), GDRAM2NRAM);
-for(int m = 0; m < multiple; m++){
-__bang_write_zero(tmpSum, strideS);
-__bang_write_value(tmp, strideS, -INFINITY);
-__bang_write_value(tmpNewMax, strideS, -INFINITY);
-for(int i = 0; i < dimsize; i++){
-__memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
-__bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);
-__bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
-__bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
-if(i > 0){
-__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
-__bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
-__bang_mul(tmpSum, tmpSum, tmpOldMax, strideS); //sum = sum * exp(oldM - newM)
+for(int s = 0; s < taskRepeat + 2; s++){
+if(s < taskRepeat){
+tid = s * multiple * behindsize;
+__memcpy_async(src + s % 3 * maxNum, source + tid, multiple * behindsize * sizeof(float), GDRAM2NRAM);
+}
+if(s > 0 && s < taskRepeat + 1){
+for(int m = 0; m < multiple; m++){
+__bang_write_zero(tmpSum, strideS);
+__bang_write_value(tmp, strideS, -INFINITY);
+__bang_write_value(tmpNewMax, strideS, -INFINITY);
+for(int i = 0; i < dimsize; i++){
+__memcpy(tmp, src + (s - 1) % 3 * maxNum + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
+__bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);
+__bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
+__bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
+if(i > 0){
+__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
+__bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
+__bang_mul(tmpSum, tmpSum, tmpOldMax, strideS); //sum = sum * exp(oldM - newM)
+}
+__bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
+__memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM
+}
+__bang_active_reciphp(tmpSum, tmpSum, strideS);
+__bang_mul(tmp, tmp, tmpSum, strideS);//The data stored in tmp at the end of the loop above can be utilized
+
+__memcpy(src + (s - 1) % 3 * maxNum + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2NRAM);
+for(int i = 0; i < dimsize - 1; i++){
+__memcpy(tmp, src + (s - 1) % 3 * maxNum + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
+__bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
+__bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
+__bang_mul(tmp, tmp, tmpSum, strideS);
+
+__memcpy(src + (s - 1) % 3 * maxNum + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2NRAM);
 }
-__bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
-__memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM
-}
-__bang_active_reciphp(tmpSum, tmpSum, strideS);
-__bang_mul(tmp, tmp, tmpSum, strideS);//The data stored in tmp at the end of the loop above can be utilized
-//__memcpy(destination + tid + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2GDRAM);
-__memcpy(src + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2NRAM);
-for(int i = 0; i < dimsize - 1; i++){
-__memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
-__bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
-__bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
-__bang_mul(tmp, tmp, tmpSum, strideS);
-//__memcpy(destination + tid + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2GDRAM);
-__memcpy(src + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2NRAM);
 }
 }
-__memcpy(destination + tid, src, multiple * behindsize * sizeof(float), NRAM2GDRAM);
+if(s > 1){
+tid = (s - 2) * multiple * behindsize;
+__memcpy_async(destination + tid, src + (s - 2) % 3 * maxNum, multiple * behindsize * sizeof(float), NRAM2GDRAM);
+}
+__sync_all_ipu();
 }
 //__bang_printf("taskId:%d, multiple:%d, taskRepeat:%d, step:%d, indStart:%d\n",taskId, multiple, taskRepeat, step, indStart * behindsize);
 if(step){
@@ -288,175 +309,213 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
 }
 
 __mlu_device__ void softmaxKernelAxis_e(float* destination, float* source, int othersize, int dimsize, int dimS) {// axis = -1
-int multiple = maxNum / dimsize;
-int size = taskDim * multiple;
-int remainS = othersize % size;
-int taskRepeat = (othersize - remainS) / size;
-int remainT = remainS % taskDim;
-int stepEasy = (remainS - remainT) / taskDim;
-int stepHard = stepEasy + 1;
-int step = (taskId < remainT ? stepHard : stepEasy);
-//The amount allocated for processing othersize for each taskId is taskRepeat * multiple+step
-//Overall, the amount of data processed by each taskId is (taskRepeat * multiple+step) * dimsize
-int startHard = taskId * (taskRepeat * multiple + stepHard);
-int startEasy = remainT * (taskRepeat * multiple + stepHard) + (taskId - remainT) * (taskRepeat * multiple + stepEasy);
-int indStart = (taskId < remainT ? startHard: startEasy);
-source = source + indStart * dimsize;
-destination = destination + indStart * dimsize;
-
-//-----------------------------------------allocate memory
-float* src = nram_buffer;
-float* tmp = src + maxNum;
-float* destSum = tmp + dimS;
-int remainDim = dimsize % dimS;//Dimsize may not be a power of 2
-int repeatDim = (dimsize - remainDim) / dimS;
-
-__nram__ float destSumFinal[warpSize];//Reduce destSum to destFinal [0]
+__nram__ float destSumFinal[warpSize];
 __nram__ float srcMax[2];
 __nram__ float destOldMax;
 __nram__ float destNewMax;
-//-----------------------------------------
-//printf("taskId:%d, taskRepeat:%d, step:%d, repeatDim:%d, indstart:%d, %d\n", taskId, taskRepeat, step, repeatDim, indStart, indStart * dimsize);
-int tid;
-for(int s = 0; s < taskRepeat; s++){
-tid = s * multiple * dimsize;
-__memcpy(src, source + tid, multiple * dimsize * sizeof(float), GDRAM2NRAM);
-for(int j = 0; j < multiple; j++){
-__bang_write_zero(destSum, dimS);
-__bang_write_zero(destSumFinal, warpSize);
-destNewMax = -INFINITY;
-
-for(int i = 0; i < repeatDim; i++){
-__memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM);
-__bang_argmax(srcMax, tmp, dimS);
+if(dimsize >= maxNum){
+float *src = nram_buffer;
+float *destSum = src + 3 * maxNum;
+int remain = dimsize % maxNum;
+int repeat = (dimsize - remain)/maxNum;
+int otherRemain = othersize % taskDim;
+int stepEasy = (othersize - otherRemain) / taskDim;
+int stepHard = stepEasy + 1;
+
+int startHard = taskId * stepHard;
+int startEasy = otherRemain * stepHard + (taskId - otherRemain) * stepEasy;
+int indStart = (taskId < otherRemain ? startHard : startEasy);
+source = source + indStart * dimsize;
+destination = destination + indStart * dimsize;
+
+
+destOldMax = -INFINITY;
+destNewMax = -INFINITY;
+__bang_write_zero(destSum, maxNum);
+for(int i = 0; i < repeat + 1; i++){
+if(i < repeat){
+__memcpy_async(src + i % 2 * maxNum, source + i * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
+}
+if(i > 0){
+__bang_argmax(srcMax, src + (i - 1) % 2 * maxNum, maxNum);
 if(destNewMax < srcMax[0]){
 destNewMax = srcMax[0];
 }
-__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
-__bang_active_exp_less_0(tmp, tmp, dimS);
-if(i > 0){
-__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS);
+__bang_sub_scalar(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, destNewMax, maxNum);
+__bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, maxNum);
+if(i > 1){
+__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), maxNum);
 }
-__bang_add(destSum, destSum, tmp, dimS);
-destOldMax = destNewMax;
-}
-if(remainDim){
-__bang_write_value(tmp, dimS, -INFINITY);
-__memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM);
-__bang_argmax(srcMax, tmp, dimS);
-if(destNewMax < srcMax[0]){
-destNewMax = srcMax[0];
-}
-__bang_write_value(tmp, dimS, destNewMax);//Must be reinitialized to NewMax
-__memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM);
-__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
-__bang_active_exp_less_0(tmp, tmp, dimS);
-if(repeatDim > 0){
-__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS);
-}
-__bang_add(destSum, destSum, tmp, dimS);
+__bang_add(destSum, destSum, src + (i - 1) % 2 * maxNum, maxNum);
 destOldMax = destNewMax;
 }
-int segNum = dimS / warpSize;//Starting numerical summation
-for(int strip = segNum/2; strip > 0; strip = strip / 2){
-for(int i = 0; i < strip ; i++){
-__bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
-}
+__sync_all_ipu();
+}
+//------------
+if(remain){
+__bang_write_value(src, maxNum, -INFINITY);
+__memcpy(src, source + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);
+
+__bang_argmax(srcMax, src, maxNum);
+if(destNewMax < srcMax[0]){
+destNewMax = srcMax[0];
 }
-__bang_reduce_sum(destSumFinal, destSum, warpSize);//At this point, destSumFinal [0] saves the numerical value of the current dimsize length data sum
-if(remainDim){
-destSumFinal[0] = destSumFinal[0] - (dimS - remainDim);
+__bang_write_value(src, maxNum, destNewMax);
+__memcpy(src, source + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);
+__bang_sub_scalar(src, src, destNewMax, maxNum);
+__bang_active_exp_less_0(src, src, maxNum);
+if(repeat > 0){
+__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), maxNum);
 }
-//Now let's start writing back the data
-float globalSumInv = 1.0/destSumFinal[0];
-if(remainDim){
-__bang_mul_scalar(tmp, tmp, globalSumInv, dimS);
-__memcpy(destination + tid + j * dimsize + repeatDim * dimS, tmp, remainDim * sizeof(float), NRAM2GDRAM);
-}
-for(int i = 0; i < repeatDim; i++){
-__memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM);
-__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
-__bang_active_exp_less_0(tmp, tmp, dimS);
-__bang_mul_scalar(tmp, tmp, globalSumInv, dimS);
-__memcpy(destination + tid + j * dimsize + i * dimS, tmp, dimS * sizeof(float), NRAM2GDRAM);
+__bang_add(destSum, destSum, src, maxNum);
+destOldMax = destNewMax;
+}
+//--------------
+//--------------------------------
+__bang_write_zero(destSumFinal, warpSize);
+int segNum = maxNum / warpSize;
+for(int strip = segNum/2; strip > 0; strip = strip / 2){
+for(int i = 0; i < strip ; i++){
+__bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
 }
 }
-//it is necessary to write back to GDRAM immediately. If you first write back to src and then write back to GDRAM,
-//there may be a situation where src writes back to GDRAM before modifying the src data
+__bang_reduce_sum(destSumFinal, destSum, warpSize);
+if(remain){
+destSumFinal[0] = destSumFinal[0] - (maxNum - remain);
+}
+//-----------
+float globalSumInv = 1.0/destSumFinal[0];
+for(int i = 0; i < repeat + 2; i++){
+if(i < repeat){
+__memcpy_async(src + i % 3 * maxNum, source + i * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
+}
+if(i > 0 && i < repeat){
+__bang_sub_scalar(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, destNewMax, maxNum);
+__bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, maxNum);
+__bang_mul_scalar(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, globalSumInv, maxNum);
+}
+if(i > 1){
+__memcpy_async(destination + (i - 2) * maxNum, src + (i - 2) % 3 * maxNum, maxNum * sizeof(float), NRAM2GDRAM);
+}
+__sync_all_ipu();
+
+}
+if(remain){
+__bang_write_value(src, maxNum, destNewMax);
+__memcpy(src, source + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);
+__bang_sub_scalar(src, src, destNewMax, maxNum);
+__bang_active_exp_less_0(src, src, maxNum);
+__bang_mul_scalar(src, src, globalSumInv, maxNum);
+__memcpy(destination + repeat * maxNum, src, remain * sizeof(float), NRAM2GDRAM);
+}
+
 }
-if(step){//Step targets parts of othersize that cannot be divided by multiple * dimsize
-tid = taskRepeat * multiple * dimsize;
-__memcpy(src, source + tid, step * dimsize * sizeof(float), GDRAM2NRAM);
-for(int j = 0; j < step; j++){
+else{
+int multiple = maxNum / dimsize;
+int size = taskDim * multiple;
+int remainS = othersize % size;
+int taskRepeat = (othersize - remainS) / size;
+int remainT = remainS % taskDim;
+int stepEasy = (remainS - remainT) / taskDim;
+int stepHard = stepEasy + 1;
+int step = (taskId < remainT ? stepHard : stepEasy);
+//The amount allocated for processing othersize for each taskId is taskRepeat * multiple+step
+//Overall, the amount of data processed by each taskId is (taskRepeat * multiple+step) * dimsize
+int startHard = taskId * (taskRepeat * multiple + stepHard);
+int startEasy = remainT * (taskRepeat * multiple + stepHard) + (taskId - remainT) * (taskRepeat * multiple + stepEasy);
+int indStart = (taskId < remainT ? startHard: startEasy);
+source = source + indStart * dimsize;
+destination = destination + indStart * dimsize;
+
+//-----------------------------------------allocate memory
+float* src = nram_buffer;//src[maxNum]
+float* tmp = src + 3 * maxNum;//tmp[dimS]
+float* destSum = tmp + dimS;//destSum[dimS],dimS >= max(dimsize, warpSize), dimS = pow(2,K) ,pow(2,K - 1) < dimsize
+
+//-----------------------------------------
+//printf("taskId:%d, taskRepeat:%d, step:%d, repeatDim:%d, indstart:%d, %d\n", taskId, taskRepeat, step, repeatDim, indStart, indStart * dimsize);
+int tid;
+for(int s = 0; s < taskRepeat + 2; s++){
+if(s < taskRepeat){
+tid = s * multiple * dimsize;
+__memcpy_async(src + s % 3 * maxNum, source + tid, multiple * dimsize * sizeof(float), GDRAM2NRAM);
+}
+if(s > 0 && s < taskRepeat + 1){
+for(int j = 0; j < multiple; j++){
+__bang_write_zero(destSum, dimS);
+__bang_write_zero(destSumFinal, warpSize);
+__bang_write_value(tmp, dimS, -INFINITY);
+
+__memcpy(tmp, src + (s - 1) %3 * maxNum + j * dimsize, dimsize * sizeof(float), NRAM2NRAM);
+__bang_argmax(srcMax, tmp, dimS);
+__bang_write_value(tmp, dimS, srcMax[0]);
+__memcpy(tmp, src + (s - 1) %3 * maxNum + j * dimsize, dimsize * sizeof(float), NRAM2NRAM);
+__bang_sub_scalar(tmp, tmp, srcMax[0], dimS);
+__bang_active_exp_less_0(tmp, tmp, dimS);//tmp[dimsize:dimS] = exp(0)
+__bang_add(destSum, destSum, tmp, dimS);
+
+int segNum = dimS / warpSize;//Starting numerical summation
+for(int strip = segNum/2; strip > 0; strip = strip / 2){
+for(int i = 0; i < strip ; i++){
+__bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
+}
+}
+__bang_reduce_sum(destSumFinal, destSum, warpSize);//At this point, destSumFinal [0] saves the numerical value of the current dimsize length data sum
+destSumFinal[0] = destSumFinal[0] - (dimS - dimsize);
+//Now let's start writing back the data
+float globalSumInv = 1.0/destSumFinal[0];
+__bang_mul_scalar(tmp, tmp, globalSumInv, maxNum);
+__memcpy(src + (s - 1) %3 * maxNum + j * dimsize, tmp, dimsize * sizeof(float), NRAM2NRAM);
+}
+}
+if(s > 1){
+tid = (s - 2) * multiple * dimsize;
+__memcpy_async(destination + tid, src + (s - 2) % 3 * maxNum, multiple * dimsize * sizeof(float), NRAM2GDRAM);
+}
+__sync_all_ipu();
+//it is necessary to write back to GDRAM immediately. If you first write back to src and then write back to GDRAM,
+//there may be a situation where src writes back to GDRAM before modifying the src data
+}
+for(int s = 0; s < step; s++){//Step targets parts of othersize that cannot be divided by multiple * dimsize
+tid = taskRepeat * multiple * dimsize + s * dimsize;
 __bang_write_zero(destSum, dimS);
 __bang_write_zero(destSumFinal, warpSize);
-destNewMax = -INFINITY;
-for(int i = 0; i < repeatDim; i++){//RepeatDim refers to the total number of cycles required to read the current dimsize data using dimS after fixing otherIdx
-__memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM);
-__bang_argmax(srcMax, tmp, dimS);
-if(destNewMax < srcMax[0]){
-destNewMax = srcMax[0];
-}
-__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
-__bang_active_exp_less_0(tmp, tmp, dimS);
-if(i > 0){
-__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS);
-}
-__bang_add(destSum, destSum, tmp, dimS);
-destOldMax = destNewMax;
-}
-if(remainDim){//RemainDim refers to the part of dimsize that cannot be divided by dimS after fixing otherIdx
-__bang_write_value(tmp, dimS, -INFINITY);
-__memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM);
-__bang_argmax(srcMax, tmp, dimS);
-if(destNewMax < srcMax[0]){
-destNewMax = srcMax[0];
-}
-
-__bang_write_value(tmp, dimS, destNewMax);//Must be reinitialized to NewMax
-__memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM);
-__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
-__bang_active_exp_less_0(tmp, tmp, dimS);
-if(repeatDim > 0){
-__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS);
-}
-__bang_add(destSum, destSum, tmp, dimS);
-destOldMax = destNewMax;
-}
-int segNum = dimS / warpSize;//Starting numerical summation
+__bang_write_value(tmp, dimS, -INFINITY);
+__memcpy(tmp, source + tid, dimsize * sizeof(float), GDRAM2NRAM);
+__bang_argmax(srcMax, tmp, dimS);
+__bang_write_value(tmp, dimS, srcMax[0]);
+__memcpy(tmp, source + tid, dimsize * sizeof(float), GDRAM2NRAM);
+__bang_sub_scalar(tmp, tmp, srcMax[0], dimS);
+__bang_active_exp_less_0(tmp, tmp, dimS);
+__bang_add(destSum, destSum, tmp, dimS);
+int segNum = dimS / warpSize;
 for(int strip = segNum/2; strip > 0; strip = strip / 2){
 for(int i = 0; i < strip ; i++){
 __bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
 }
 }
 __bang_reduce_sum(destSumFinal, destSum, warpSize);
-//At this point, destSumFinal [0] saves the numerical value of the current dimsize length data sum
-if(remainDim){
-destSumFinal[0] = destSumFinal[0] - (dimS - remainDim);
-}
-//__bang_printf("taskId:%d, max:%.2f, sum:%.2f\n", taskId, destNewMax, destSumFinal[0]);
+destSumFinal[0] = destSumFinal[0] - (dimS - dimsize);
+//__bang_printf(":%.2f,max:%.2f, sum:%.2f, final:%.2f\n",tmp[1], srcMax[0], destSum[1], destSumFinal[0]);
 float globalSumInv = 1.0/destSumFinal[0];
-if(remainDim){
-__bang_mul_scalar(tmp, tmp, globalSumInv, dimS);
-__memcpy(destination + tid + j * dimsize + repeatDim * dimS, tmp, remainDim * sizeof(float), NRAM2GDRAM);
-}
-for(int i = 0; i < repeatDim; i++){
-__memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM);
-__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
-__bang_active_exp_less_0(tmp, tmp, dimS);
-__bang_mul_scalar(tmp, tmp, globalSumInv, dimS);
-__memcpy(destination + tid + j * dimsize + i * dimS, tmp, dimS * sizeof(float), NRAM2GDRAM);
-}
+__bang_mul_scalar(tmp, tmp, globalSumInv, maxNum);
+__memcpy(destination + tid, tmp, dimsize * sizeof(float), NRAM2GDRAM);
 }
 }
 }
 __mlu_device__ void softmaxKernelAxis_s(float* destination, float* source, int othersize, int dimsize, int stride) {// axis = 0
 //-----------------------------------------allocate memory
-float* src = nram_buffer;
-float* tmpSum = src + maxNum;
-float* tmpNewMax = src + 2 * maxNum;
-float* tmpOldMax = src + 3 * maxNum;
+float* src = nram_buffer;// src[3 * maxNum]
+float* tmpSum = src + 3 * maxNum;//tmpSum[maxNum]
+float* tmpNewMax = src + 4 * maxNum;//tmpNewMax[maxNum]
+float* tmpOldMax = src + 5 * maxNum;//tmpOldMax[maxNum]
 //-----------------------------------------
 int remain = othersize % taskDim;
 int stepEasy = (othersize - remain)/taskDim;
@@ -470,67 +529,89 @@ __mlu_device__ void softmaxKernelAxis_s(float* destination, float* source, int o
 for(int j = 0; j < repeat; j++){
 __bang_write_value(tmpNewMax, maxNum, -INFINITY);
 __bang_write_zero(tmpSum, maxNum);
-for(int i = 0; i < dimsize; i++){
-__memcpy(src, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
-__bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);//Continuously updating the maximum value
-__bang_sub(src, src, tmpNewMax, maxNum);//x - M
-__bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
-if(i > 0){
-__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
-__bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
-__bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);//sum = sum * exp(oldM - newM)
+for(int i = 0; i < dimsize + 1; i++){
+if(i < dimsize){
+__memcpy_async(src + i % 2 * maxNum, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
 }
-__bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M)
-__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
+if(i > 0){
+__bang_maxequal(tmpNewMax, tmpNewMax, src + (i - 1) % 2 * maxNum, maxNum);//Continuously updating the maximum value
+__bang_sub(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, tmpNewMax, maxNum);//x - M
+__bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, maxNum);//exp(x - M)
+if(i > 1){
+__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
+__bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
+__bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);//sum = sum * exp(oldM - newM)
+}
+__bang_add(tmpSum, tmpSum, src + (i - 1) % 2 * maxNum, maxNum);//sum += exp(x - M)
+__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
+}
+__sync_all_ipu();
 }
 __bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum
 //Start exponential transformation and write back to GDRAM
-__bang_mul(src, src, tmpSum, maxNum);//The data stored in the src at the end of the loop above can be utilized
-__memcpy(destination + (dimsize - 1) * stride + indStart + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
-for(int i = 0; i < dimsize - 1; i++){
-__memcpy(src, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
-__bang_sub(src, src, tmpNewMax, maxNum);//x - M
-__bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
-__bang_mul(src, src, tmpSum, maxNum);
-__memcpy(destination + i * stride + indStart + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
+for(int i = 0; i < dimsize + 2; i++){
+if(i < dimsize){
+__memcpy_async(src + i % 3 * maxNum, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
+}
+if(i > 0 && i < dimsize + 1){
+__bang_sub(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpNewMax, maxNum);//x - M
+__bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, maxNum);//exp(x - M)
+__bang_mul(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpSum, maxNum);
+}
+if(i > 1){
+__memcpy_async(destination + (i - 2) * stride + indStart + j * maxNum, src + (i - 2) % 3 * maxNum, maxNum * sizeof(float), NRAM2GDRAM);
+}
+__sync_all_ipu();
 }
 }
 if(remainNram){
 __bang_write_value(tmpNewMax, maxNum, -INFINITY);
 __bang_write_zero(tmpSum, maxNum);
-__bang_write_zero(src, maxNum);
+__bang_write_zero(src, 3 * maxNum);
 
 
-for(int i = 0; i < dimsize; i++){
-__memcpy(src, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
-__bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);
-__bang_sub(src, src, tmpNewMax, maxNum);//x - M
-__bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
-if(i > 0){
-__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
-__bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
-__bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum); //sum = sum * exp(oldM - newM)
+for(int i = 0; i < dimsize + 1; i++){
+if(i < dimsize){
+__memcpy_async(src + i % 2 * maxNum, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
 }
-__bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M)
-__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
+if(i > 0){
+__bang_maxequal(tmpNewMax, tmpNewMax, src + (i - 1) % 2 * maxNum, maxNum);
+__bang_sub(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, tmpNewMax, maxNum);//x - M
+__bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, maxNum);//exp(x - M)
+if(i > 1){
+__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
+__bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
+__bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum); //sum = sum * exp(oldM - newM)
+}
+__bang_add(tmpSum, tmpSum, src + (i - 1) % 2 * maxNum, maxNum);//sum += exp(x - M)
+__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
+}
+__sync_all_ipu();
 }
 
 __bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum
 //Start exponential transformation and write back to GDRAM
-__bang_mul(src, src, tmpSum, maxNum);//The data stored in the src at the end of the loop above can be utilized
-__memcpy(destination + (dimsize - 1) * stride + indStart + repeat * maxNum, src, remainNram * sizeof(float), NRAM2GDRAM);
-for(int i = 0; i < dimsize - 1; i++){
-__memcpy(src, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
-__bang_sub(src, src, tmpNewMax, maxNum);//x - M
-__bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
-__bang_mul(src, src, tmpSum, maxNum);
-__memcpy(destination + i * stride + indStart + repeat * maxNum, src, remainNram * sizeof(float), NRAM2GDRAM);
+for(int i = 0; i < dimsize + 2; i++){
+if(i < dimsize){
+__memcpy_async(src + i % 3 * maxNum, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
+}
+if(i > 0 && i < dimsize + 1){
+__bang_sub(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpNewMax, maxNum);//x - M
+__bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, maxNum);//exp(x - M)
+__bang_mul(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpSum, maxNum);
+}
+if(i > 1){
+__memcpy_async(destination + (i - 2) * stride + indStart + repeat * maxNum, src + (i - 2) % 3 * maxNum, remainNram * sizeof(float), NRAM2GDRAM);
+}
+__sync_all_ipu();
 }
 
 }
 
 }
 
 
 __mlu_global__ void softmaxUnion1(float *mlu_destination, float *mlu_src, int nDim, int axis, int othersize, int frontsize, int dimsize, int stride){
 if(axis == nDim - 1){
 int dimS;
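
All three kernels rely on the online-softmax recurrence that the code comments spell out: keep a running maximum M and a running sum S, and whenever a new chunk raises the maximum, rescale the accumulated sum by exp(oldM - newM) before adding the new exponentials. A minimal sketch of that recurrence, assuming plain host-side C++ rather than BANG C, with a made-up input and a chunk size standing in for maxNum/strideS:

// Sketch only: scalar C++ illustration of the running max/sum update used above.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> x = {1.0f, 3.0f, -2.0f, 0.5f, 4.0f, 2.0f};
    const size_t chunk = 2;                   // stands in for the per-iteration tile size
    float M = -INFINITY, S = 0.0f;            // running max and running sum

    for (size_t base = 0; base < x.size(); base += chunk) {
        float newM = M;
        for (size_t i = base; i < base + chunk && i < x.size(); ++i)
            newM = std::max(newM, x[i]);      // corresponds to the __bang_maxequal step
        S *= std::exp(M - newM);              // sum = sum * exp(oldM - newM)
        for (size_t i = base; i < base + chunk && i < x.size(); ++i)
            S += std::exp(x[i] - newM);       // sum += exp(x - M)
        M = newM;                             // oldM = newM
    }
    for (float v : x) std::printf("%f ", std::exp(v - M) / S);  // second pass: normalize
    std::printf("\n");
    return 0;
}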
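The "stream kernel" restructuring in this commit is a software pipeline: each loop now runs repeat + 2 (or dimsize + 2, taskRepeat + 2) iterations, issues the asynchronous load for chunk i, computes on chunk i - 1, writes back chunk i - 2, and synchronizes once per iteration, with the NRAM buffer split into rotating slices indexed by % 3 (or % 2 for the two-stage load/compute variants). A hypothetical skeleton of that pattern, again in plain C++ with stand-in load/compute/store routines replacing __memcpy_async and __sync_all_ipu:

// Sketch only: the three-stage rotating-buffer pipeline shape, not the MLU API.
#include <cstdio>
#include <vector>

void load(float* dst, int i)    { std::printf("load %d\n", i); }    // stands in for GDRAM2NRAM __memcpy_async
void compute(float* buf, int i) { std::printf("compute %d\n", i); } // exp/scale work on the resident chunk
void store(float* src, int i)   { std::printf("store %d\n", i); }   // stands in for NRAM2GDRAM __memcpy_async
void barrier()                  { }                                  // stands in for __sync_all_ipu()

int main() {
    const int repeat = 4, maxNum = 8;
    std::vector<float> nram(3 * maxNum);                 // three rotating NRAM slices
    for (int i = 0; i < repeat + 2; ++i) {
        if (i < repeat)              load(&nram[i % 3 * maxNum], i);          // stage 1: fetch chunk i
        if (i > 0 && i < repeat + 1) compute(&nram[(i - 1) % 3 * maxNum], i - 1); // stage 2: work on chunk i-1
        if (i > 1)                   store(&nram[(i - 2) % 3 * maxNum], i - 2);   // stage 3: flush chunk i-2
        barrier();                                       // all three stages of this step finish here
    }
    return 0;
}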