#include <bang.h>
#include <bang_device_functions.h>
#define EPS 1e-7
const int NRAM_MAX_SIZE = 1024 * 64;// the tree reduction below requires NRAM_MAX_SIZE to be a power of two
const int maxNum = NRAM_MAX_SIZE/sizeof(float); // maximum number of float elements NRAM holds at once
const int warpSize = 32;
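/*
 * All kernels below compute softmax in a single streaming pass using the
 * online recurrence: while scanning x_0..x_{n-1} along the reduction axis,
 * maintain a running maximum M and a rescaled sum S with
 *   M_new = max(M_old, x_i)
 *   S_new = S_old * exp(M_old - M_new) + exp(x_i - M_new)
 * after which softmax(x_i) = exp(x_i - M) / S. Every
 * "sum = sum * exp(oldM - newM)" step below is an instance of this update.
 */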
//strideS is a power of two no smaller than stride; it is the vector tile width used below
__mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int frontsize, int dimsize, int stride) {
// handles 0 < axis < dim - 1
__nram__ float src[maxNum];
if(stride >= maxNum){
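// Branch 1: a single row of length stride does not fit in NRAM. Tile the
// stride axis into maxNum-wide columns and, for each column tile, stream the
// dimsize axis through NRAM, applying the online recurrence lane by lane.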
__nram__ float tmpSum[maxNum];
__nram__ float tmpNewMax[maxNum];
__nram__ float tmpOldMax[maxNum];
int remain = stride % maxNum;
int repeat = (stride - remain) / maxNum;
int taskRemain = frontsize % taskDim;
int stepEasy = (frontsize - taskRemain) / taskDim;
int stepHard = stepEasy + 1;
int step = (taskId < taskRemain ? stepHard : stepEasy);// number of frontsize slices this taskId handles
int indStart = (taskId < taskRemain ? taskId * stepHard : taskRemain * stepHard + (taskId - taskRemain) * stepEasy);
source = source + indStart * dimsize * stride;
destination = destination + indStart * dimsize * stride;
for(int ind = 0; ind < step; ind++){// iterate only over this taskId's slice: the pointers above are already offset by indStart
int frontIdx = ind * dimsize * stride;
for(int j = 0; j < repeat; j++){
__bang_write_value(tmpNewMax, maxNum, -INFINITY);
__bang_write_zero(tmpSum, maxNum);
__bang_write_zero(src, maxNum);
for(int i = 0; i < dimsize; i++){
__memcpy(src, source + frontIdx + i * stride + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
__bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);// running maximum update
__bang_sub(src, src, tmpNewMax, maxNum);// x - M
__bang_active_exp_less_0(src, src, maxNum);// exp(x - M)
if(i > 0){
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
__bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);//sum = sum * exp(oldM - newM)
}
__bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M)
__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
}
__bang_active_recip(tmpSum, tmpSum, maxNum);// compute 1/sum
// apply the exponential transform and write back to GDRAM
__bang_mul(src, src, tmpSum, maxNum);// src still holds exp(x - M) for i = dimsize - 1, so reuse it
__memcpy(destination + (dimsize - 1) * stride + frontIdx + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
for(int i = 0; i < dimsize - 1; i++){
__memcpy(src, source + frontIdx + i * stride + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
__bang_sub(src, src, tmpNewMax, maxNum);//x - M
__bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
__bang_mul(src, src, tmpSum, maxNum);
__memcpy(destination + frontIdx + i * stride + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
}
}
if(remain){
__bang_write_value(tmpNewMax, maxNum, -INFINITY);
__bang_write_zero(tmpSum, maxNum);
__bang_write_value(src, maxNum, -INFINITY);
for(int i = 0; i < dimsize; i++){
__memcpy(src, source + frontIdx + i * stride + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);
__bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);
__bang_sub(src, src, tmpNewMax, maxNum);//x - M
__bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
if(i > 0){
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
__bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum); //sum = sum * exp(oldM - newM)
}
__bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M)
__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
}
//-------------------
__bang_active_recip(tmpSum, tmpSum, maxNum);// compute 1/sum
// apply the exponential transform and write back to GDRAM
__bang_mul(src, src, tmpSum, maxNum);// src still holds exp(x - M) for i = dimsize - 1, so reuse it
__memcpy(destination + (dimsize - 1) * stride + frontIdx + repeat * maxNum, src, remain * sizeof(float), NRAM2GDRAM);
for(int i = 0; i < dimsize - 1; i++){
__memcpy(src, source + i * stride + frontIdx + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);
__bang_sub(src, src, tmpNewMax, maxNum);//x - M
__bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
__bang_mul(src, src, tmpSum, maxNum);
__memcpy(destination + i * stride + frontIdx + repeat * maxNum, src, remain * sizeof(float), NRAM2GDRAM);
}
//---------------------
}
}
}
else if(stride < maxNum && dimsize * stride >= maxNum){
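// Branch 2: one row of length stride fits in NRAM, but a full dimsize * stride
// slice does not. Pack `multiple` rows of length stride into each GDRAM load
// and run the online recurrence on strideS-wide vectors, one row at a time.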
const int strideS = 1024;// fixed power-of-two tile width; this branch assumes stride <= 1024
__nram__ float tmp[strideS];
__nram__ float tmpOldMax[strideS];
__nram__ float tmpNewMax[strideS];
__nram__ float tmpSum[strideS];
int multiple = maxNum / stride;
int size = multiple * stride;// maximum payload one src buffer holds
int remain = dimsize % multiple;// leftover rows when dimsize is not divisible by multiple
int repeat = (dimsize - remain) / multiple;// loads needed to cover the whole dimsize
int taskRemain = frontsize % taskDim;
int stepEasy = (frontsize - taskRemain) / taskDim;
int stepHard = stepEasy + 1;
int step = (taskId < taskRemain ? stepHard : stepEasy);// number of frontsize slices this taskId handles
int indStart = (taskId < taskRemain ? taskId * stepHard : taskRemain * stepHard + (taskId - taskRemain) * stepEasy);
source = source + indStart * dimsize * stride;
destination = destination + indStart * dimsize * stride;
for(int ind = 0; ind < step; ind++){
int frontIdx = ind * dimsize * stride;
__bang_write_value(tmpNewMax, strideS, -INFINITY);// must start at -INFINITY
__bang_write_value(tmp, strideS, -INFINITY);// must start at -INFINITY
__bang_write_zero(tmpSum, strideS);// must start at 0
for(int j = 0; j < repeat; j++){
__memcpy(src, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM);
for(int m = 0; m < multiple; m++){
__memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM);
__bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);// the trailing strideS - stride lanes of tmpNewMax are meaningless, but they are never written back to GDRAM, so the result is unaffected
__bang_sub(tmp, tmp, tmpNewMax, strideS);// the trailing strideS - stride lanes of tmp are zero
__bang_active_exp_less_0(tmp, tmp, strideS);
if(j != 0 || m != 0){
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
__bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);//sum = sum * exp(oldM - newM)
}
__bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
__memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM
}
}
if(remain){
__memcpy(src, source + frontIdx + repeat * multiple * stride, remain * stride * sizeof(float), GDRAM2NRAM);
for(int m = 0; m < remain; m++){
__memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM);
__bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);
__bang_sub(tmp, tmp, tmpNewMax, strideS);// the trailing strideS - stride lanes of tmp are zero
__bang_active_exp_less_0(tmp, tmp, strideS);
if(repeat != 0 || m != 0){
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
__bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);//sum = sum * exp(oldM - newM)
}
__bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
__memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM
}
}
// tmpNewMax now holds the per-lane maximum along the dimsize axis for this frontIdx, and tmpSum the corresponding sum
__bang_active_recip(tmpSum, tmpSum, strideS);
if(remain){// handle the tail first: src still holds the last-loaded chunk
for(int m = 0; m < remain; m++){
__memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM);
__bang_sub(tmp, tmp, tmpNewMax, strideS);
__bang_active_exp_less_0(tmp, tmp, strideS);
__bang_mul(tmp, tmp, tmpSum, strideS);
__memcpy(destination + frontIdx + repeat * multiple * stride + m * stride, tmp, stride * sizeof(float), NRAM2GDRAM);
}
}
for(int j = 0 ; j < repeat; j++){
__memcpy(src, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM);
for(int m = 0; m < multiple; m++){
__memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM);
__bang_sub(tmp, tmp, tmpNewMax, strideS);
__bang_active_exp_less_0(tmp, tmp, strideS);
__bang_mul(tmp, tmp, tmpSum, strideS);
__memcpy(destination + frontIdx + j * multiple * stride + m * stride, tmp, stride * sizeof(float), NRAM2GDRAM);
}
}
}
}
else if(dimsize * stride < maxNum){
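// Branch 3: a whole dimsize * stride slice fits in NRAM. Batch `multiple`
// slices per load, normalize each slice in place inside src, and flush the
// entire batch back to GDRAM with a single transfer.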
const int strideS = 32;// fixed power-of-two tile width; this branch assumes stride <= 32
__nram__ float tmp[strideS];
__nram__ float tmpOldMax[strideS];
__nram__ float tmpNewMax[strideS];
__nram__ float tmpSum[strideS];
int behindsize = dimsize * stride;
int multiple = maxNum / behindsize;// how many complete dimsize * stride slices fit in one NRAM buffer
int remainF = frontsize % (taskDim * multiple);
int remainT = remainF % taskDim;
int stepEasy = (remainF - remainT) / taskDim;
int stepHard = stepEasy + 1;
int step = (taskId < remainT ? stepHard : stepEasy);
int taskRepeat = (frontsize - remainF) / (taskDim * multiple);
// per taskId, the number of frontsize slices handled is taskRepeat * multiple + step
int startHard = taskId * (taskRepeat * multiple + stepHard);
int startEasy = remainT * (taskRepeat * multiple + stepHard) + (taskId - remainT) * (taskRepeat * multiple + stepEasy);
int indStart = (taskId < remainT ? startHard: startEasy);
source = source + indStart * behindsize;// indStart * behindsize is this taskId's offset
destination = destination + indStart * behindsize;
int tid;
for(int s = 0; s < taskRepeat; s++){
tid = s * multiple * behindsize;
__memcpy(src, source + tid, multiple * behindsize * sizeof(float), GDRAM2NRAM);
for(int m = 0; m < multiple; m++){
__bang_write_zero(tmpSum, strideS);
__bang_write_value(tmp, strideS, -INFINITY);
__bang_write_value(tmpNewMax, strideS, -INFINITY);
for(int i = 0; i < dimsize; i++){
__memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
__bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);
__bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
__bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
if(i > 0){
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
__bang_mul(tmpSum, tmpSum, tmpOldMax, strideS); //sum = sum * exp(oldM - newM)
}
__bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
__memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM
}
__bang_active_recip(tmpSum, tmpSum, strideS);
__bang_mul(tmp, tmp, tmpSum, strideS);// tmp still holds exp(x - M) for i = dimsize - 1, so reuse it
__memcpy(src + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2NRAM);
for(int i = 0; i < dimsize - 1; i++){
__memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
__bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
__bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
__bang_mul(tmp, tmp, tmpSum, strideS);
__memcpy(src + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2NRAM);
}
}
__memcpy(destination + tid, src, multiple * behindsize * sizeof(float), NRAM2GDRAM);
}
//__bang_printf("taskId:%d, multiple:%d, taskRepeat:%d, step:%d, indStart:%d\n",taskId, multiple, taskRepeat, step, indStart * behindsize);
if(step){
tid = taskRepeat * multiple * behindsize;
__memcpy(src, source + tid, step * behindsize * sizeof(float), GDRAM2NRAM);
for(int m = 0; m < step; m++){
__bang_write_zero(tmpSum, strideS);
__bang_write_value(tmp, strideS, -INFINITY);
__bang_write_value(tmpNewMax, strideS, -INFINITY);
for(int i = 0; i < dimsize; i++){
__memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
__bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);
__bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
__bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
if(i > 0){
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
__bang_mul(tmpSum, tmpSum, tmpOldMax, strideS); //sum = sum * exp(oldM - newM)
}
__bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
__memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM
}
//__bang_printf("max:%.2f,%.2f, sum:%.2f,sum:%.2f\n", tmpNewMax[0], tmpNewMax[1], tmpSum[0], tmpSum[0]);
__bang_active_recip(tmpSum, tmpSum, strideS);
__bang_mul(tmp, tmp, tmpSum, strideS);// tmp still holds exp(x - M) for i = dimsize - 1, so reuse it
__memcpy(src + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2NRAM);
for(int i = 0; i < dimsize - 1; i++){
__memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
__bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
__bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
__bang_mul(tmp, tmp, tmpSum, strideS);
__memcpy(src + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2NRAM);
}
}
__memcpy(destination + tid, src, step * behindsize * sizeof(float), NRAM2GDRAM);
}
}
}
const int dimS = 32;// chunk width for the axis = -1 kernel; must be a power-of-two multiple of warpSize
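// softmaxKernelAxis_e reduces along the innermost, contiguous axis. Each row
// of length dimsize is scanned in dimS-wide chunks: __bang_argmax tracks a
// scalar running maximum, the online recurrence rescales the partial sums in
// destSum, and a warp fold plus __bang_reduce_sum collapses destSum to one
// scalar per row.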
__mlu_device__ void softmaxKernelAxis_e(float* destination, float* source, int othersize, int dimsize) {// axis = -1
int multiple = maxNum / dimsize;
int size = taskDim * multiple;
int remainS = othersize % size;
int taskRepeat = (othersize - remainS) / size;
int remainT = remainS % taskDim;
int stepEasy = (remainS - remainT) / taskDim;
int stepHard = stepEasy + 1;
int step = (taskId < remainT ? stepHard : stepEasy);
// each taskId is assigned taskRepeat * multiple + step rows of othersize
// so in total each taskId touches (taskRepeat * multiple + step) * dimsize elements
int startHard = taskId * (taskRepeat * multiple + stepHard);
int startEasy = remainT * (taskRepeat * multiple + stepHard) + (taskId - remainT) * (taskRepeat * multiple + stepEasy);
int indStart = (taskId < remainT ? startHard: startEasy);
source = source + indStart * dimsize;
destination = destination + indStart * dimsize;
__nram__ float src[maxNum];
__nram__ float tmp[dimS];
__nram__ float destSum[dimS];// partial numerator sums, reduced below
int remainDim = dimsize % dimS;// dimsize need not be a multiple of dimS
int repeatDim = (dimsize - remainDim) / dimS;
__nram__ float destSumFinal[warpSize];// destSum is folded into destSumFinal[0]
__nram__ float srcMax[2];// __bang_argmax writes {max value, max index}
__nram__ float destOldMax;
__nram__ float destNewMax;
int tid;
for(int s = 0; s < taskRepeat; s++){
tid = s * multiple * dimsize;
__memcpy(src, source + tid, multiple * dimsize * sizeof(float), GDRAM2NRAM);
for(int j = 0; j < multiple; j++){
__bang_write_zero(destSum, dimS);
__bang_write_zero(destSumFinal, warpSize);
destNewMax = -INFINITY;
for(int i = 0; i < repeatDim; i++){
__memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM);
__bang_argmax(srcMax, tmp, dimS);
if(destNewMax < srcMax[0]){
destNewMax = srcMax[0];
}
__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
__bang_active_exp_less_0(tmp, tmp, dimS);
if(i > 0){
__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS);
}
__bang_add(destSum, destSum, tmp, dimS);
destOldMax = destNewMax;
}
if(remainDim){
__bang_write_value(tmp, dimS, -INFINITY);
__memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM);
__bang_argmax(srcMax, tmp, dimS);
if(destNewMax < srcMax[0]){
destNewMax = srcMax[0];
}
__bang_write_value(tmp, dimS, destNewMax);// refill with destNewMax so the dimS - remainDim padding lanes become exp(0) = 1; their surplus is subtracted after the reduction
__memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM);
__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
__bang_active_exp_less_0(tmp, tmp, dimS);
if(repeatDim > 0){
__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS);
}
__bang_add(destSum, destSum, tmp, dimS);
destOldMax = destNewMax;
}
int segNum = dimS / warpSize;// fold destSum down to one warpSize-wide segment (a no-op while dimS == warpSize)
for(int strip = segNum/2; strip > 0; strip = strip / 2){
for(int i = 0; i < strip ; i++){
__bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
}
}
__bang_reduce_sum(destSumFinal, destSum, warpSize);// destSumFinal[0] now holds the sum over this row's dimsize elements
if(remainDim){
destSumFinal[0] = destSumFinal[0] - (dimS - remainDim);// remove the exp(0) = 1 contributions of the padding lanes
}
// write the normalized results back
float globalSumInv = 1.0f/destSumFinal[0];
if(remainDim){
__bang_mul_scalar(tmp, tmp, globalSumInv, dimS);
__memcpy(destination + tid + j * dimsize + repeatDim * dimS, tmp, remainDim * sizeof(float), NRAM2GDRAM);
}
for(int i = 0; i < repeatDim; i++){
__memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM);
__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
__bang_active_exp_less_0(tmp, tmp, dimS);
__bang_mul_scalar(tmp, tmp, globalSumInv, dimS);
__memcpy(destination + tid + j * dimsize + i * dimS, tmp, dimS * sizeof(float), NRAM2GDRAM);
}
}// results must go straight to GDRAM: if they were staged in src and flushed later, a subsequent iteration could overwrite src before its flush to GDRAM completed
}
if(step){// step covers the part of othersize not divisible by taskDim * multiple
tid = taskRepeat * multiple * dimsize;
__memcpy(src, source + tid, step * dimsize * sizeof(float), GDRAM2NRAM);
for(int j = 0; j < step; j++){
__bang_write_zero(destSum, dimS);
__bang_write_zero(destSumFinal, warpSize);
destNewMax = -INFINITY;
for(int i = 0; i < repeatDim; i++){// repeatDim = number of dimS-wide chunks needed to cover dimsize for a fixed row
__memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM);
__bang_argmax(srcMax, tmp, dimS);
if(destNewMax < srcMax[0]){
destNewMax = srcMax[0];
}
__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
__bang_active_exp_less_0(tmp, tmp, dimS);
if(i > 0){
__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS);
}
__bang_add(destSum, destSum, tmp, dimS);
destOldMax = destNewMax;
}
if(remainDim){// tail of dimsize that is not a multiple of dimS
__bang_write_value(tmp, dimS, -INFINITY);
__memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM);
__bang_argmax(srcMax, tmp, dimS);
if(destNewMax < srcMax[0]){
destNewMax = srcMax[0];
}
__bang_write_value(tmp, dimS, destNewMax);// refill with destNewMax so the dimS - remainDim padding lanes become exp(0) = 1; their surplus is subtracted after the reduction
__memcpy(tmp, src + j * dimsize + repeatDim * dimS, remainDim * sizeof(float), NRAM2NRAM);
__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
__bang_active_exp_less_0(tmp, tmp, dimS);
if(repeatDim > 0){
__bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), dimS);
}
__bang_add(destSum, destSum, tmp, dimS);
destOldMax = destNewMax;
}
int segNum = dimS / warpSize;// fold destSum down to one warpSize-wide segment (a no-op while dimS == warpSize)
for(int strip = segNum/2; strip > 0; strip = strip / 2){
for(int i = 0; i < strip ; i++){
__bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
}
}
__bang_reduce_sum(destSumFinal, destSum, warpSize);// destSumFinal[0] now holds the sum over this row's dimsize elements
if(remainDim){
destSumFinal[0] = destSumFinal[0] - (dimS - remainDim);// remove the exp(0) = 1 contributions of the padding lanes
}
float globalSumInv = 1.0f/destSumFinal[0];
if(remainDim){
__bang_mul_scalar(tmp, tmp, globalSumInv, dimS);
__memcpy(destination + tid + j * dimsize + repeatDim * dimS, tmp, remainDim * sizeof(float), NRAM2GDRAM);
}
for(int i = 0; i < repeatDim; i++){
__memcpy(tmp, src + j * dimsize + i * dimS, dimS * sizeof(float), NRAM2NRAM);
__bang_sub_scalar(tmp, tmp, destNewMax, dimS);
__bang_active_exp_less_0(tmp, tmp, dimS);
__bang_mul_scalar(tmp, tmp, globalSumInv, dimS);
__memcpy(destination + tid + j * dimsize + i * dimS, tmp, dimS * sizeof(float), NRAM2GDRAM);
}
}
}
}
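// softmaxKernelAxis_s reduces along the outermost axis (axis = 0), where
// consecutive elements of a reduction column are stride apart in GDRAM while
// the othersize positions are contiguous: each NRAM lane therefore carries an
// independent online softmax, vectorized across othersize.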
__mlu_device__ void softmaxKernelAxis_s(float* destination, float* source, int othersize, int dimsize, int stride) {// axis = 0
__nram__ float src[maxNum];// stage maxNum elements in NRAM per transfer
__nram__ float tmpSum[maxNum];
__nram__ float tmpNewMax[maxNum];
__nram__ float tmpOldMax[maxNum];
int remain = othersize % taskDim;
int stepEasy = (othersize - remain)/taskDim;
int stepHard = stepEasy + 1;
int step = (taskId < remain ? stepHard : stepEasy);// the leading taskIds handle one extra element
int indStart = (taskId < remain ? taskId * stepHard : remain * stepHard + (taskId - remain) * stepEasy);
int remainNram = step%maxNum;
int repeat = (step - remainNram)/maxNum;
for(int j = 0; j < repeat; j++){
__bang_write_value(tmpNewMax, maxNum, -INFINITY);
__bang_write_zero(tmpSum, maxNum);
for(int i = 0; i < dimsize; i++){
__memcpy(src, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
__bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);// running maximum update
__bang_sub(src, src, tmpNewMax, maxNum);//x - M
__bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
if(i > 0){
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
__bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);//sum = sum * exp(oldM - newM)
}
__bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M)
__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
}
__bang_active_recip(tmpSum, tmpSum, maxNum);// compute 1/sum
// apply the exponential transform and write back to GDRAM
__bang_mul(src, src, tmpSum, maxNum);// src still holds exp(x - M) for i = dimsize - 1, so reuse it
__memcpy(destination + (dimsize - 1) * stride + indStart + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
for(int i = 0; i < dimsize - 1; i++){
__memcpy(src, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
__bang_sub(src, src, tmpNewMax, maxNum);//x - M
__bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
__bang_mul(src, src, tmpSum, maxNum);
__memcpy(destination + i * stride + indStart + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
}
}
if(remainNram){
__bang_write_value(tmpNewMax, maxNum, -INFINITY);
__bang_write_zero(tmpSum, maxNum);
__bang_write_zero(src, maxNum);
for(int i = 0; i < dimsize; i++){
__memcpy(src, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
__bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);
__bang_sub(src, src, tmpNewMax, maxNum);//x - M
__bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
if(i > 0){
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
__bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum); //sum = sum * exp(oldM - newM)
}
__bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M)
__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
}
__bang_active_recip(tmpSum, tmpSum, maxNum);// compute 1/sum
// apply the exponential transform and write back to GDRAM
__bang_mul(src, src, tmpSum, maxNum);// src still holds exp(x - M) for i = dimsize - 1, so reuse it
__memcpy(destination + (dimsize - 1) * stride + indStart + repeat * maxNum, src, remainNram * sizeof(float), NRAM2GDRAM);
for(int i = 0; i < dimsize - 1; i++){
__memcpy(src, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
__bang_sub(src, src, tmpNewMax, maxNum);//x - M
__bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
__bang_mul(src, src, tmpSum, maxNum);
__memcpy(destination + i * stride + indStart + repeat * maxNum, src, remainNram * sizeof(float), NRAM2GDRAM);
}
}
}
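// Entry point: dispatch on the softmax axis of a 4-D tensor. axis == 3 is the
// innermost contiguous axis, axis == 0 the outermost; the middle axes go to
// the strided kernel.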
__mlu_global__ void softmaxUnion1(float *mlu_destination, float *mlu_src, int axis, int othersize, int frontsize, int dimsize, int stride){
if(axis == 3){
softmaxKernelAxis_e(mlu_destination, mlu_src, othersize, dimsize);
}
else if (axis > 0 && axis < 3){
softmaxKernelAxis_m(mlu_destination, mlu_src, frontsize, dimsize, stride);
}
else if(axis == 0){
softmaxKernelAxis_s(mlu_destination, mlu_src, othersize, dimsize, stride);
}
}
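/*
 * Host-side reference for validating kernel output: a minimal sketch assuming
 * the same flat (frontsize, dimsize, stride) layout the kernels above use. The
 * helper name softmaxRef is hypothetical, not an API defined in this project.
 * It runs the identical online recurrence one reduction column at a time; for
 * axis = 3, call it with stride = 1 and frontsize = othersize.
 */
#include <math.h>
static void softmaxRef(float *dst, const float *src, int frontsize, int dimsize, int stride){
for(int f = 0; f < frontsize; f++){
for(int s = 0; s < stride; s++){
const float *in = src + (long)f * dimsize * stride + s;
float *out = dst + (long)f * dimsize * stride + s;
float M = -INFINITY, S = 0.0f;
for(int i = 0; i < dimsize; i++){// single pass: update running max M and rescaled sum S
float x = in[(long)i * stride];
float Mnew = x > M ? x : M;
S = S * expf(M - Mnew) + expf(x - Mnew);
M = Mnew;
}
for(int i = 0; i < dimsize; i++){// normalize: softmax = exp(x - M) / S
out[(long)i * stride] = expf(in[(long)i * stride] - M) / S;
}
}
}
}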