modified format
commit 7861d38048 (parent 4ef734eb77)
@@ -62,7 +62,7 @@ class GraphHandlerObj {
Tensor tanh(Tensor x, Tensor y);
Tensor erf(Tensor x, Tensor y);
Tensor softmax(Tensor x, Tensor y, int axis);

Tensor abs(Tensor x, Tensor y);
Tensor sqrt(Tensor x, Tensor y);
Tensor neg(Tensor x, Tensor y);
@@ -526,7 +526,7 @@ void init_graph_builder(py::module &m) {
.def("hardSigmoid", &Handler::hardSigmoid, policy::move)
.def("hardSwish", &Handler::hardSwish, policy::move)
.def("softmax", &Handler::softmax, policy::move)

.def("abs", &Handler::abs, policy::move)
.def("sqrt", &Handler::sqrt, policy::move)
.def("neg", &Handler::neg, policy::move)
@@ -139,7 +139,7 @@ class SoftmaxCnnl : public BangKernelWithoutConfig {
auto aDim = op->getInputs(0)->getDims();
int axis = op->getAxis();
int nDim = aDim.size();
if(axis == 0 || axis == nDim - 1){
if (axis == 0 || axis == nDim - 1) {
int stride = 1;
int dimsize = aDim[axis];
int num = 1;
@@ -158,13 +158,11 @@ class SoftmaxCnnl : public BangKernelWithoutConfig {
othersize *= aDim[s];
}
}
softmaxKernel(context->cnnlHandle(), (float *)cData,
(float *)aData, othersize, dimsize, frontsize,
stride, axis, nDim);
}
else{
softmaxKernel(context->cnnlHandle(), (float *)cData, (float *)aData,
othersize, dimsize, frontsize, stride, axis, nDim);
} else {
cnnlSoftmaxMode_t mode;

std::vector<int> inDim = {1, 1, 1};
std::vector<int> outDim = inDim;
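The loops that actually fill stride, frontsize and othersize are outside this hunk; only `othersize *= aDim[s];` is visible. A plausible reconstruction of the geometry the custom kernel call above expects (one softmax over dimsize elements spaced stride apart, repeated othersize times, with frontsize slices in front of the axis) is sketched below. The helper is hypothetical; the names mirror the call site.

// Sketch only: derive the softmaxKernel geometry from dims and axis
// (row-major layout assumed).
// dimsize   - extent of the softmax axis
// stride    - product of dims behind the axis
// frontsize - product of dims in front of the axis
// othersize - product of all dims except the axis (= frontsize * stride)
#include <vector>

struct SoftmaxGeom {
    int dimsize, stride, frontsize, othersize;
};

SoftmaxGeom computeGeom(const std::vector<int> &aDim, int axis) {
    SoftmaxGeom g{aDim[axis], 1, 1, 1};
    for (size_t s = axis + 1; s < aDim.size(); ++s)
        g.stride *= aDim[s];
    for (int s = 0; s < axis; ++s)
        g.frontsize *= aDim[s];
    for (size_t s = 0; s < aDim.size(); ++s)
        if ((int)s != axis)
            g.othersize *= aDim[s];
    return g;
}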
@@ -226,16 +224,15 @@ class SoftmaxCnnl : public BangKernelWithoutConfig {
outDim.size(), outDim.data()));
float alpha = 1.0;
float beta = 0.0;
cnnlStatus_t stat =
cnnlSoftmaxForward_v2(context->cnnlHandle(), CNNL_SOFTMAX_ACCURATE,
mode, CNNL_COMPUTATION_ULTRAHIGH_PRECISION,
&alpha, aDesc, aData, &beta, cDesc, cData);
cnnlStatus_t stat = cnnlSoftmaxForward_v2(
context->cnnlHandle(), CNNL_SOFTMAX_ACCURATE, mode,
CNNL_COMPUTATION_ULTRAHIGH_PRECISION, &alpha, aDesc, aData,
&beta, cDesc, cData);
if (stat != CNNL_STATUS_SUCCESS)
return;
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
}

}
};
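For the fallback path above, the library softmax is fed a rank-3 view of the tensor: the `{1, 1, 1}` inDim/outDim vectors are presumably filled by collapsing every dimension before the axis, the axis itself, and every dimension after it. The exact assignment and the matching cnnlSoftmaxMode_t selection are outside the visible hunks, so the following is a hedged sketch of the collapse only.

// Sketch only: collapse an arbitrary shape to {pre, dim, post} around `axis`
// so that a rank-3 softmax over the middle extent matches a softmax over
// `axis` in the original shape. The CNNL mode that pairs with this layout is
// not shown in the diff and is omitted here.
#include <vector>

std::vector<int> collapseTo3d(const std::vector<int> &dims, int axis) {
    int pre = 1, post = 1;
    for (int i = 0; i < axis; ++i)
        pre *= dims[i];
    for (size_t i = axis + 1; i < dims.size(); ++i)
        post *= dims[i];
    return {pre, dims[axis], post};
}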
@@ -389,287 +389,403 @@ __mlu_device__ void softmaxKernelAxis_e(float *destination, float *source,
}
}
}
__mlu_device__ void softmaxKernelAxis_s(float* destination, float* source, float* tmpGdram, int othersize, int dimsize, int stride) {// axis = 0

const int SRC_MAX_SIZE = 1024 * 64;//The subsequent tree summation must ensure that SRC-MAX-SIZE is a power of 2
const int maxNum = SRC_MAX_SIZE/sizeof(float);
if(othersize > taskDim * maxNum){
//-----------------------------------------allocate memory
float* src = nram_buffer;// src[3 * maxNum]
float* tmpSum = src + 3 * maxNum;//tmpSum[maxNum]
float* tmpNewMax = src + 4 * maxNum;//tmpNewMax[maxNum]
float* tmpOldMax = src + 5 * maxNum;//tmpOldMax[maxNum]
//-----------------------------------------
int remain = othersize % taskDim;
int stepEasy = (othersize - remain)/taskDim;
int stepHard = stepEasy + 1;
int step = (taskId < remain ? stepHard : stepEasy);//The first part of taskId handles an additional element
int indStart = (taskId < remain ? taskId * stepHard : remain * stepHard + (taskId - remain) * stepEasy);
int remainNram = step%maxNum;
int repeat = (step - remainNram)/maxNum;
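The remain/stepEasy/stepHard arithmetic above is the recurring work-splitting idiom in these kernels: `othersize` columns are divided across `taskDim` cores, and the first `remain` cores take one extra column. A scalar restatement of the same split, as a sketch rather than code from the commit:

// Sketch: split `total` items over `numTasks` cores; cores [0, remain)
// get one extra item. Returns the contiguous range owned by `taskId`.
struct Range {
    int begin, count;
};

Range splitWork(int total, int numTasks, int taskId) {
    int remain = total % numTasks;
    int stepEasy = (total - remain) / numTasks;
    int stepHard = stepEasy + 1;
    int count = taskId < remain ? stepHard : stepEasy;
    int begin = taskId < remain
                    ? taskId * stepHard
                    : remain * stepHard + (taskId - remain) * stepEasy;
    return {begin, count};
}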
for(int j = 0; j < repeat; j++){
__bang_write_value(tmpNewMax, maxNum, -INFINITY);
__bang_write_zero(tmpSum, maxNum);
for(int i = 0; i < dimsize + 1; i++){
if(i < dimsize){
__memcpy_async(src + i % 2 * maxNum, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
}
if(i > 0){
__bang_maxequal(tmpNewMax, tmpNewMax, src + (i - 1) % 2 * maxNum, maxNum);//Continuously updating the maximum value
__bang_sub(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, tmpNewMax, maxNum);//x - M
__bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, maxNum);//exp(x - M)
if(i > 1){
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
__bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);//sum = sum * exp(oldM - newM)
}
__bang_add(tmpSum, tmpSum, src + (i - 1) % 2 * maxNum, maxNum);//sum += exp(x - M)
__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
}
__sync_all_ipu();
}
__bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum

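The loop above is an online (streaming) softmax: each pass over the axis keeps a running maximum M and a running sum S of exp(x - M), and rescales S by exp(oldM - newM) whenever the maximum grows, exactly as the inline comments state. A plain scalar reference of the same recurrence, useful for checking the vectorized BANG version against (a sketch, not code from the commit):

#include <cmath>

// Reference softmax over a strided axis using the same running-max /
// rescaled-sum recurrence as the kernel:
//   M_i = max(M_{i-1}, x_i)
//   S_i = S_{i-1} * exp(M_{i-1} - M_i) + exp(x_i - M_i)
//   y_i = exp(x_i - M) / S
void softmaxOnline(const float *x, float *y, int dimsize, int stride) {
    float M = -INFINITY, S = 0.0f;
    for (int i = 0; i < dimsize; ++i) {
        float v = x[i * stride];
        float newM = std::max(M, v);
        S = S * std::exp(M - newM) + std::exp(v - newM);
        M = newM;
    }
    for (int i = 0; i < dimsize; ++i)
        y[i * stride] = std::exp(x[i * stride] - M) / S;
}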
for(int i = 0; i < dimsize + 2; i++){
|
||||
if(i < dimsize){
|
||||
__memcpy_async(src + i % 3 * maxNum, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
|
||||
}
|
||||
if(i > 0 && i < dimsize + 1){
|
||||
__bang_sub(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpNewMax, maxNum);//x - M
|
||||
__bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, maxNum);//exp(x - M)
|
||||
__bang_mul(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpSum, maxNum);
|
||||
}
|
||||
if(i > 1){
|
||||
__memcpy_async(destination + (i - 2) * stride + indStart + j * maxNum, src + (i - 2) % 3 * maxNum, maxNum * sizeof(float), NRAM2GDRAM);
|
||||
}
|
||||
__sync_all_ipu();
|
||||
}
|
||||
}
|
||||
if(remainNram){
|
||||
__bang_write_value(tmpNewMax, maxNum, -INFINITY);
|
||||
__bang_write_zero(tmpSum, maxNum);
|
||||
__bang_write_zero(src, 3 * maxNum);
|
||||
|
||||
for(int i = 0; i < dimsize + 1; i++){
|
||||
if(i < dimsize){
|
||||
__memcpy_async(src + i % 2 * maxNum, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
|
||||
}
|
||||
if(i > 0){
|
||||
__bang_maxequal(tmpNewMax, tmpNewMax, src + (i - 1) % 2 * maxNum, maxNum);
|
||||
__bang_sub(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, tmpNewMax, maxNum);//x - M
|
||||
__bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, maxNum);//exp(x - M)
|
||||
if(i > 1){
|
||||
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
|
||||
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
|
||||
__bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum); //sum = sum * exp(oldM - newM)
|
||||
}
|
||||
__bang_add(tmpSum, tmpSum, src + (i - 1) % 2 * maxNum, maxNum);//sum += exp(x - M)
|
||||
__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
|
||||
}
|
||||
__sync_all_ipu();
|
||||
}
|
||||
|
||||
__bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum
|
||||
//Start exponential transformation and write back to GDRAM
|
||||
|
||||
for(int i = 0; i < dimsize + 2; i++){
|
||||
if(i < dimsize){
|
||||
__memcpy_async(src + i % 3 * maxNum, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
|
||||
}
|
||||
if(i > 0 && i < dimsize + 1){
|
||||
__bang_sub(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpNewMax, maxNum);//x - M
|
||||
__bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, maxNum);//exp(x - M)
|
||||
__bang_mul(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpSum, maxNum);
|
||||
}
|
||||
if(i > 1){
|
||||
__memcpy_async(destination + (i - 2) * stride + indStart + repeat * maxNum, src + (i - 2) % 3 * maxNum, remainNram * sizeof(float), NRAM2GDRAM);
|
||||
}
|
||||
__sync_all_ipu();
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (othersize > maxNum && othersize <= taskDim * maxNum){
|
||||
float* src = nram_buffer;// src[3 * maxNum]
|
||||
float* tmpSum = src + 3 * maxNum;//tmpSum[maxNum]
|
||||
float* tmpNewMax = src + 4 * maxNum;//tmpNewMax[maxNum]
|
||||
float* tmpOldMax = src + 5 * maxNum;//tmpOldMax[maxNum]
|
||||
//-----------------------------------------
|
||||
int remain = othersize % taskDim;
|
||||
int stepEasy = (othersize - remain)/taskDim;
|
||||
int stepHard = stepEasy + 1;
|
||||
int step = (taskId < remain ? stepHard : stepEasy);//The first part of taskId handles an additional element
|
||||
int indStart = (taskId < remain ? taskId * stepHard : remain * stepHard + (taskId - remain) * stepEasy);
|
||||
|
||||
__bang_write_value(tmpNewMax, maxNum, -INFINITY);
|
||||
__bang_write_zero(tmpSum, maxNum);
|
||||
__bang_write_zero(src, 3 * maxNum);
|
||||
|
||||
for(int i = 0; i < dimsize + 1; i++){
|
||||
if(i < dimsize){
|
||||
__memcpy_async(src + i % 2 * maxNum, source + i * stride + indStart, step * sizeof(float), GDRAM2NRAM);
|
||||
}
|
||||
if(i > 0){
|
||||
__bang_maxequal(tmpNewMax, tmpNewMax, src + (i - 1) % 2 * maxNum, maxNum);
|
||||
__bang_sub(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, tmpNewMax, maxNum);//x - M
|
||||
__bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, maxNum);//exp(x - M)
|
||||
if(i > 1){
|
||||
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
|
||||
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
|
||||
__bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum); //sum = sum * exp(oldM - newM)
|
||||
}
|
||||
__bang_add(tmpSum, tmpSum, src + (i - 1) % 2 * maxNum, maxNum);//sum += exp(x - M)
|
||||
__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
|
||||
}
|
||||
__sync_all_ipu();
|
||||
}
|
||||
|
||||
__bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum
|
||||
//Start exponential transformation and write back to GDRAM
|
||||
|
||||
for(int i = 0; i < dimsize + 2; i++){
|
||||
if(i < dimsize){
|
||||
__memcpy_async(src + i % 3 * maxNum, source + i * stride + indStart, step * sizeof(float), GDRAM2NRAM);
|
||||
}
|
||||
if(i > 0 && i < dimsize + 1){
|
||||
__bang_sub(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpNewMax, maxNum);//x - M
|
||||
__bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, maxNum);//exp(x - M)
|
||||
__bang_mul(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpSum, maxNum);
|
||||
}
|
||||
if(i > 1){
|
||||
__memcpy_async(destination + (i - 2) * stride + indStart, src + (i - 2) % 3 * maxNum, step * sizeof(float), NRAM2GDRAM);
|
||||
}
|
||||
__sync_all_ipu();
|
||||
}
|
||||
}
|
||||
else{
|
||||
|
||||
int multiple = maxNum / othersize;
|
||||
int size = taskDim * multiple;
|
||||
int remain = dimsize % size;
|
||||
int repeat = (dimsize - remain) / size;
|
||||
__mlu_device__ void softmaxKernelAxis_s(float *destination, float *source,
|
||||
float *tmpGdram, int othersize,
|
||||
int dimsize, int stride) { // axis = 0
|
||||
|
||||
int remainT = remain % taskDim;
|
||||
int stepEasy = (remain - remainT) / taskDim;
|
||||
int stepHard = stepEasy + 1;
|
||||
int step = (taskId < remainT ? stepHard : stepEasy);
|
||||
int indStart = (taskId < remainT ? taskId * stepHard : remainT * stepHard + (taskId - remainT) * stepEasy);
|
||||
|
||||
float* src = nram_buffer;// src[3 * maxNum]
|
||||
float* tmpSum = src + 3 * maxNum;//tmpSum[othersize]
|
||||
float* tmpNewMax = tmpSum + othersize;//tmpNewMax[othersize]
|
||||
float* tmpOldMax = tmpNewMax + othersize;//tmpOldMax[othersize]
|
||||
float* tmpGlobal = tmpOldMax + othersize;
|
||||
__bang_write_value(tmpNewMax, othersize, -INFINITY);
|
||||
|
||||
__bang_write_zero(tmpSum, othersize);
|
||||
__bang_write_zero(src, 3 * maxNum);
|
||||
for(int i = 0; i < repeat; i++){
|
||||
__memcpy(src, source + (i * size + taskId * multiple) * stride, multiple * othersize * sizeof(float), GDRAM2NRAM);//stride=othersize
|
||||
for(int m = 0; m < multiple; m++){
|
||||
__bang_maxequal(tmpNewMax, tmpNewMax, src + m * othersize, othersize);
|
||||
}
|
||||
for(int m = 0; m < multiple; m++){
|
||||
__bang_sub(src + m * othersize, src + m * othersize, tmpNewMax, othersize);//x - M
|
||||
}
|
||||
__bang_active_exp_less_0(src, src, multiple * othersize);//exp(x - M)
|
||||
if(i > 0){
|
||||
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, othersize);//oldM = oldM - newM
|
||||
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, othersize);//exp(oldM - newM)
|
||||
__bang_mul(tmpSum, tmpSum, tmpOldMax, othersize); //sum = sum * exp(oldM - newM)
|
||||
}
|
||||
for(int m = 0; m < multiple; m++){
|
||||
__bang_add(tmpSum, tmpSum, src + m * othersize, othersize);
|
||||
}
|
||||
__memcpy(tmpOldMax, tmpNewMax, othersize * sizeof(float), NRAM2NRAM);
|
||||
const int SRC_MAX_SIZE =
|
||||
1024 * 64; // The subsequent tree summation must ensure that
|
||||
// SRC-MAX-SIZE is a power of 2
|
||||
const int maxNum = SRC_MAX_SIZE / sizeof(float);
|
||||
if (othersize > taskDim * maxNum) {
|
||||
//-----------------------------------------allocate memory
|
||||
float *src = nram_buffer; // src[3 * maxNum]
|
||||
float *tmpSum = src + 3 * maxNum; // tmpSum[maxNum]
|
||||
float *tmpNewMax = src + 4 * maxNum; // tmpNewMax[maxNum]
|
||||
float *tmpOldMax = src + 5 * maxNum; // tmpOldMax[maxNum]
|
||||
//-----------------------------------------
|
||||
int remain = othersize % taskDim;
|
||||
int stepEasy = (othersize - remain) / taskDim;
|
||||
int stepHard = stepEasy + 1;
|
||||
int step =
|
||||
(taskId < remain ? stepHard
|
||||
: stepEasy); // The first part of taskId handles an
|
||||
// additional element
|
||||
int indStart = (taskId < remain
|
||||
? taskId * stepHard
|
||||
: remain * stepHard + (taskId - remain) * stepEasy);
|
||||
int remainNram = step % maxNum;
|
||||
int repeat = (step - remainNram) / maxNum;
|
||||
|
||||
for (int j = 0; j < repeat; j++) {
|
||||
__bang_write_value(tmpNewMax, maxNum, -INFINITY);
|
||||
__bang_write_zero(tmpSum, maxNum);
|
||||
for (int i = 0; i < dimsize + 1; i++) {
|
||||
if (i < dimsize) {
|
||||
__memcpy_async(src + i % 2 * maxNum,
|
||||
source + i * stride + indStart + j * maxNum,
|
||||
maxNum * sizeof(float), GDRAM2NRAM);
|
||||
}
|
||||
if (i > 0) {
|
||||
__bang_maxequal(
|
||||
tmpNewMax, tmpNewMax, src + (i - 1) % 2 * maxNum,
|
||||
maxNum); // Continuously updating the maximum value
|
||||
__bang_sub(src + (i - 1) % 2 * maxNum,
|
||||
src + (i - 1) % 2 * maxNum, tmpNewMax,
|
||||
maxNum); // x - M
|
||||
__bang_active_exp_less_0(src + (i - 1) % 2 * maxNum,
|
||||
src + (i - 1) % 2 * maxNum,
|
||||
maxNum); // exp(x - M)
|
||||
if (i > 1) {
|
||||
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax,
|
||||
maxNum); // oldM = oldM - newM
|
||||
__bang_active_exp_less_0(tmpOldMax, tmpOldMax,
|
||||
maxNum); // exp(oldM - newM)
|
||||
__bang_mul(tmpSum, tmpSum, tmpOldMax,
|
||||
maxNum); // sum = sum * exp(oldM - newM)
|
||||
}
|
||||
__bang_add(tmpSum, tmpSum, src + (i - 1) % 2 * maxNum,
|
||||
maxNum); // sum += exp(x - M)
|
||||
__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float),
|
||||
NRAM2NRAM); // oldM = newM
|
||||
}
|
||||
__sync_all_ipu();
|
||||
}
|
||||
__bang_active_reciphp(tmpSum, tmpSum, maxNum); // compute 1/sum
|
||||
|
||||
for (int i = 0; i < dimsize + 2; i++) {
|
||||
if (i < dimsize) {
|
||||
__memcpy_async(src + i % 3 * maxNum,
|
||||
source + i * stride + indStart + j * maxNum,
|
||||
maxNum * sizeof(float), GDRAM2NRAM);
|
||||
}
|
||||
if (i > 0 && i < dimsize + 1) {
|
||||
__bang_sub(src + (i - 1) % 3 * maxNum,
|
||||
src + (i - 1) % 3 * maxNum, tmpNewMax,
|
||||
maxNum); // x - M
|
||||
__bang_active_exp_less_0(src + (i - 1) % 3 * maxNum,
|
||||
src + (i - 1) % 3 * maxNum,
|
||||
maxNum); // exp(x - M)
|
||||
__bang_mul(src + (i - 1) % 3 * maxNum,
|
||||
src + (i - 1) % 3 * maxNum, tmpSum, maxNum);
|
||||
}
|
||||
if (i > 1) {
|
||||
__memcpy_async(destination + (i - 2) * stride + indStart +
|
||||
j * maxNum,
|
||||
src + (i - 2) % 3 * maxNum,
|
||||
maxNum * sizeof(float), NRAM2GDRAM);
|
||||
}
|
||||
__sync_all_ipu();
|
||||
}
|
||||
}
|
||||
if (remainNram) {
|
||||
__bang_write_value(tmpNewMax, maxNum, -INFINITY);
|
||||
__bang_write_zero(tmpSum, maxNum);
|
||||
__bang_write_zero(src, 3 * maxNum);
|
||||
|
||||
for (int i = 0; i < dimsize + 1; i++) {
|
||||
if (i < dimsize) {
|
||||
__memcpy_async(src + i % 2 * maxNum,
|
||||
source + i * stride + indStart +
|
||||
repeat * maxNum,
|
||||
remainNram * sizeof(float), GDRAM2NRAM);
|
||||
}
|
||||
if (i > 0) {
|
||||
__bang_maxequal(tmpNewMax, tmpNewMax,
|
||||
src + (i - 1) % 2 * maxNum, maxNum);
|
||||
__bang_sub(src + (i - 1) % 2 * maxNum,
|
||||
src + (i - 1) % 2 * maxNum, tmpNewMax,
|
||||
maxNum); // x - M
|
||||
__bang_active_exp_less_0(src + (i - 1) % 2 * maxNum,
|
||||
src + (i - 1) % 2 * maxNum,
|
||||
maxNum); // exp(x - M)
|
||||
if (i > 1) {
|
||||
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax,
|
||||
maxNum); // oldM = oldM - newM
|
||||
__bang_active_exp_less_0(tmpOldMax, tmpOldMax,
|
||||
maxNum); // exp(oldM - newM)
|
||||
__bang_mul(tmpSum, tmpSum, tmpOldMax,
|
||||
maxNum); // sum = sum * exp(oldM - newM)
|
||||
}
|
||||
__bang_add(tmpSum, tmpSum, src + (i - 1) % 2 * maxNum,
|
||||
maxNum); // sum += exp(x - M)
|
||||
__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float),
|
||||
NRAM2NRAM); // oldM = newM
|
||||
}
|
||||
__sync_all_ipu();
|
||||
}
|
||||
|
||||
__bang_active_reciphp(tmpSum, tmpSum, maxNum); // compute 1/sum
|
||||
// Start exponential transformation and write back to GDRAM
|
||||
|
||||
for (int i = 0; i < dimsize + 2; i++) {
|
||||
if (i < dimsize) {
|
||||
__memcpy_async(src + i % 3 * maxNum,
|
||||
source + i * stride + indStart +
|
||||
repeat * maxNum,
|
||||
remainNram * sizeof(float), GDRAM2NRAM);
|
||||
}
|
||||
if (i > 0 && i < dimsize + 1) {
|
||||
__bang_sub(src + (i - 1) % 3 * maxNum,
|
||||
src + (i - 1) % 3 * maxNum, tmpNewMax,
|
||||
maxNum); // x - M
|
||||
__bang_active_exp_less_0(src + (i - 1) % 3 * maxNum,
|
||||
src + (i - 1) % 3 * maxNum,
|
||||
maxNum); // exp(x - M)
|
||||
__bang_mul(src + (i - 1) % 3 * maxNum,
|
||||
src + (i - 1) % 3 * maxNum, tmpSum, maxNum);
|
||||
}
|
||||
if (i > 1) {
|
||||
__memcpy_async(destination + (i - 2) * stride + indStart +
|
||||
repeat * maxNum,
|
||||
src + (i - 2) % 3 * maxNum,
|
||||
remainNram * sizeof(float), NRAM2GDRAM);
|
||||
}
|
||||
__sync_all_ipu();
|
||||
}
|
||||
}
|
||||
} else if (othersize > maxNum && othersize <= taskDim * maxNum) {
|
||||
float *src = nram_buffer; // src[3 * maxNum]
|
||||
float *tmpSum = src + 3 * maxNum; // tmpSum[maxNum]
|
||||
float *tmpNewMax = src + 4 * maxNum; // tmpNewMax[maxNum]
|
||||
float *tmpOldMax = src + 5 * maxNum; // tmpOldMax[maxNum]
|
||||
//-----------------------------------------
|
||||
int remain = othersize % taskDim;
|
||||
int stepEasy = (othersize - remain) / taskDim;
|
||||
int stepHard = stepEasy + 1;
|
||||
int step =
|
||||
(taskId < remain ? stepHard
|
||||
: stepEasy); // The first part of taskId handles an
|
||||
// additional element
|
||||
int indStart = (taskId < remain
|
||||
? taskId * stepHard
|
||||
: remain * stepHard + (taskId - remain) * stepEasy);
|
||||
|
||||
__bang_write_value(tmpNewMax, maxNum, -INFINITY);
|
||||
__bang_write_zero(tmpSum, maxNum);
|
||||
__bang_write_zero(src, 3 * maxNum);
|
||||
|
||||
for (int i = 0; i < dimsize + 1; i++) {
|
||||
if (i < dimsize) {
|
||||
__memcpy_async(src + i % 2 * maxNum,
|
||||
source + i * stride + indStart,
|
||||
step * sizeof(float), GDRAM2NRAM);
|
||||
}
|
||||
if (i > 0) {
|
||||
__bang_maxequal(tmpNewMax, tmpNewMax,
|
||||
src + (i - 1) % 2 * maxNum, maxNum);
|
||||
__bang_sub(src + (i - 1) % 2 * maxNum,
|
||||
src + (i - 1) % 2 * maxNum, tmpNewMax,
|
||||
maxNum); // x - M
|
||||
__bang_active_exp_less_0(src + (i - 1) % 2 * maxNum,
|
||||
src + (i - 1) % 2 * maxNum,
|
||||
maxNum); // exp(x - M)
|
||||
if (i > 1) {
|
||||
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax,
|
||||
maxNum); // oldM = oldM - newM
|
||||
__bang_active_exp_less_0(tmpOldMax, tmpOldMax,
|
||||
maxNum); // exp(oldM - newM)
|
||||
__bang_mul(tmpSum, tmpSum, tmpOldMax,
|
||||
maxNum); // sum = sum * exp(oldM - newM)
|
||||
}
|
||||
__bang_add(tmpSum, tmpSum, src + (i - 1) % 2 * maxNum,
|
||||
maxNum); // sum += exp(x - M)
|
||||
__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float),
|
||||
NRAM2NRAM); // oldM = newM
|
||||
}
|
||||
__sync_all_ipu();
|
||||
}
|
||||
|
||||
__bang_active_reciphp(tmpSum, tmpSum, maxNum); // compute 1/sum
|
||||
// Start exponential transformation and write back to GDRAM
|
||||
|
||||
for (int i = 0; i < dimsize + 2; i++) {
|
||||
if (i < dimsize) {
|
||||
__memcpy_async(src + i % 3 * maxNum,
|
||||
source + i * stride + indStart,
|
||||
step * sizeof(float), GDRAM2NRAM);
|
||||
}
|
||||
if (i > 0 && i < dimsize + 1) {
|
||||
__bang_sub(src + (i - 1) % 3 * maxNum,
|
||||
src + (i - 1) % 3 * maxNum, tmpNewMax,
|
||||
maxNum); // x - M
|
||||
__bang_active_exp_less_0(src + (i - 1) % 3 * maxNum,
|
||||
src + (i - 1) % 3 * maxNum,
|
||||
maxNum); // exp(x - M)
|
||||
__bang_mul(src + (i - 1) % 3 * maxNum,
|
||||
src + (i - 1) % 3 * maxNum, tmpSum, maxNum);
|
||||
}
|
||||
if (i > 1) {
|
||||
__memcpy_async(destination + (i - 2) * stride + indStart,
|
||||
src + (i - 2) % 3 * maxNum, step * sizeof(float),
|
||||
NRAM2GDRAM);
|
||||
}
|
||||
__sync_all_ipu();
|
||||
}
|
||||
} else {
|
||||
|
||||
int multiple = maxNum / othersize;
|
||||
int size = taskDim * multiple;
|
||||
int remain = dimsize % size;
|
||||
int repeat = (dimsize - remain) / size;
|
||||
|
||||
int remainT = remain % taskDim;
|
||||
int stepEasy = (remain - remainT) / taskDim;
|
||||
int stepHard = stepEasy + 1;
|
||||
int step = (taskId < remainT ? stepHard : stepEasy);
|
||||
int indStart = (taskId < remainT ? taskId * stepHard
|
||||
: remainT * stepHard +
|
||||
(taskId - remainT) * stepEasy);
|
||||
|
||||
float *src = nram_buffer; // src[3 * maxNum]
|
||||
float *tmpSum = src + 3 * maxNum; // tmpSum[othersize]
|
||||
float *tmpNewMax = tmpSum + othersize; // tmpNewMax[othersize]
|
||||
float *tmpOldMax = tmpNewMax + othersize; // tmpOldMax[othersize]
|
||||
float *tmpGlobal = tmpOldMax + othersize;
|
||||
__bang_write_value(tmpNewMax, othersize, -INFINITY);
|
||||
|
||||
__bang_write_zero(tmpSum, othersize);
|
||||
__bang_write_zero(src, 3 * maxNum);
|
||||
for (int i = 0; i < repeat; i++) {
|
||||
__memcpy(src, source + (i * size + taskId * multiple) * stride,
|
||||
multiple * othersize * sizeof(float),
|
||||
GDRAM2NRAM); // stride=othersize
|
||||
for (int m = 0; m < multiple; m++) {
|
||||
__bang_maxequal(tmpNewMax, tmpNewMax, src + m * othersize,
|
||||
othersize);
|
||||
}
|
||||
for (int m = 0; m < multiple; m++) {
|
||||
__bang_sub(src + m * othersize, src + m * othersize, tmpNewMax,
|
||||
othersize); // x - M
|
||||
}
|
||||
__bang_active_exp_less_0(src, src,
|
||||
multiple * othersize); // exp(x - M)
|
||||
if (i > 0) {
|
||||
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax,
|
||||
othersize); // oldM = oldM - newM
|
||||
__bang_active_exp_less_0(tmpOldMax, tmpOldMax,
|
||||
othersize); // exp(oldM - newM)
|
||||
__bang_mul(tmpSum, tmpSum, tmpOldMax,
|
||||
othersize); // sum = sum * exp(oldM - newM)
|
||||
}
|
||||
for (int m = 0; m < multiple; m++) {
|
||||
__bang_add(tmpSum, tmpSum, src + m * othersize, othersize);
|
||||
}
|
||||
__memcpy(tmpOldMax, tmpNewMax, othersize * sizeof(float),
|
||||
NRAM2NRAM);
|
||||
}
|
||||
|
||||
if (step) {
|
||||
__memcpy(src, source + repeat * size * stride + indStart * stride,
|
||||
step * othersize * sizeof(float),
|
||||
GDRAM2NRAM); // stride=othersize
|
||||
|
||||
for (int m = 0; m < step; m++) {
|
||||
__bang_maxequal(tmpNewMax, tmpNewMax, src + m * othersize,
|
||||
othersize);
|
||||
}
|
||||
for (int m = 0; m < step; m++) {
|
||||
__bang_sub(src + m * othersize, src + m * othersize, tmpNewMax,
|
||||
othersize); // x - M
|
||||
}
|
||||
__bang_active_exp_less_0(src, src, step * othersize); // exp(x - M)
|
||||
if (repeat > 0) {
|
||||
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax,
|
||||
othersize); // oldM = oldM - newM
|
||||
__bang_active_exp_less_0(tmpOldMax, tmpOldMax,
|
||||
othersize); // exp(oldM - newM)
|
||||
__bang_mul(tmpSum, tmpSum, tmpOldMax,
|
||||
othersize); // sum = sum * exp(oldM - newM)
|
||||
}
|
||||
for (int m = 0; m < step; m++) {
|
||||
__bang_add(tmpSum, tmpSum, src + m * othersize, othersize);
|
||||
}
|
||||
__memcpy(tmpOldMax, tmpNewMax, othersize * sizeof(float),
|
||||
NRAM2NRAM);
|
||||
}
|
||||
//----------------
|
||||
if (repeat > 0 || dimsize >= taskDim) {
|
||||
__memcpy(tmpGdram + taskId * othersize, tmpNewMax,
|
||||
othersize * sizeof(float), NRAM2GDRAM);
|
||||
__sync_all();
|
||||
__bang_write_value(tmpNewMax, othersize, -INFINITY);
|
||||
for (int id = 0; id < taskDim; id++) {
|
||||
__memcpy(tmpGlobal, tmpGdram + id * othersize,
|
||||
othersize * sizeof(float), GDRAM2NRAM);
|
||||
__bang_maxequal(tmpNewMax, tmpNewMax, tmpGlobal, othersize);
|
||||
}
|
||||
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, othersize);
|
||||
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, othersize);
|
||||
__bang_mul(tmpSum, tmpSum, tmpOldMax, othersize);
|
||||
__memcpy(tmpGdram + taskId * othersize, tmpSum,
|
||||
othersize * sizeof(float), NRAM2GDRAM);
|
||||
__sync_all();
|
||||
__bang_write_zero(tmpSum, othersize);
|
||||
for (int id = 0; id < taskDim; id++) {
|
||||
__memcpy(tmpGlobal, tmpGdram + id * othersize,
|
||||
othersize * sizeof(float), GDRAM2NRAM);
|
||||
__bang_add(tmpSum, tmpSum, tmpGlobal, othersize);
|
||||
}
|
||||
__bang_active_reciphp(tmpSum, tmpSum, othersize);
|
||||
} else {
|
||||
__memcpy(tmpGdram + taskId * othersize, tmpNewMax,
|
||||
othersize * sizeof(float), NRAM2GDRAM);
|
||||
__sync_all();
|
||||
__bang_write_value(tmpNewMax, othersize, -INFINITY);
|
||||
for (int id = 0; id < dimsize; id++) {
|
||||
__memcpy(tmpGlobal, tmpGdram + id * othersize,
|
||||
othersize * sizeof(float), GDRAM2NRAM);
|
||||
__bang_maxequal(tmpNewMax, tmpNewMax, tmpGlobal, othersize);
|
||||
}
|
||||
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, othersize);
|
||||
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, othersize);
|
||||
__bang_mul(tmpSum, tmpSum, tmpOldMax, othersize);
|
||||
__memcpy(tmpGdram + taskId * othersize, tmpSum,
|
||||
othersize * sizeof(float), NRAM2GDRAM);
|
||||
__sync_all();
|
||||
__bang_write_zero(tmpSum, othersize);
|
||||
for (int id = 0; id < dimsize; id++) {
|
||||
__memcpy(tmpGlobal, tmpGdram + id * othersize,
|
||||
othersize * sizeof(float), GDRAM2NRAM);
|
||||
__bang_add(tmpSum, tmpSum, tmpGlobal, othersize);
|
||||
}
|
||||
__bang_active_reciphp(tmpSum, tmpSum, othersize);
|
||||
}
|
||||
|
||||
//-------------------
|
||||
for (int i = 0; i < repeat; i++) {
|
||||
__memcpy(src, source + (i * size + taskId * multiple) * stride,
|
||||
multiple * othersize * sizeof(float),
|
||||
GDRAM2NRAM); // stride=othersize
|
||||
for (int m = 0; m < multiple; m++) {
|
||||
__bang_sub(src + m * othersize, src + m * othersize, tmpNewMax,
|
||||
othersize);
|
||||
}
|
||||
__bang_active_exp_less_0(src, src, multiple * othersize);
|
||||
for (int m = 0; m < multiple; m++) {
|
||||
__bang_mul(src + m * othersize, src + m * othersize, tmpSum,
|
||||
othersize);
|
||||
}
|
||||
__memcpy(destination + (i * size + taskId * multiple) * stride, src,
|
||||
multiple * othersize * sizeof(float), NRAM2GDRAM);
|
||||
}
|
||||
if (step) {
|
||||
__memcpy(src, source + repeat * size * stride + indStart * stride,
|
||||
step * othersize * sizeof(float),
|
||||
GDRAM2NRAM); // stride=othersize
|
||||
for (int m = 0; m < step; m++) {
|
||||
__bang_sub(src + m * othersize, src + m * othersize, tmpNewMax,
|
||||
othersize);
|
||||
}
|
||||
__bang_active_exp_less_0(src, src, step * othersize);
|
||||
for (int m = 0; m < step; m++) {
|
||||
__bang_mul(src + m * othersize, src + m * othersize, tmpSum,
|
||||
othersize);
|
||||
}
|
||||
__memcpy(destination + repeat * size * stride + indStart * stride,
|
||||
src, step * othersize * sizeof(float), NRAM2GDRAM);
|
||||
}
|
||||
}
|
||||
|
||||
if(step) {
|
||||
__memcpy(src, source + repeat * size * stride + indStart * stride, step * othersize * sizeof(float), GDRAM2NRAM);//stride=othersize
|
||||
|
||||
for(int m = 0; m < step; m++){
|
||||
__bang_maxequal(tmpNewMax, tmpNewMax, src + m * othersize, othersize);
|
||||
}
|
||||
for(int m = 0; m < step; m++){
|
||||
__bang_sub(src + m * othersize, src + m * othersize, tmpNewMax, othersize);//x - M
|
||||
}
|
||||
__bang_active_exp_less_0(src, src, step * othersize);//exp(x - M)
|
||||
if(repeat > 0){
|
||||
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, othersize);//oldM = oldM - newM
|
||||
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, othersize);//exp(oldM - newM)
|
||||
__bang_mul(tmpSum, tmpSum, tmpOldMax, othersize); //sum = sum * exp(oldM - newM)
|
||||
}
|
||||
for(int m = 0; m < step; m++){
|
||||
__bang_add(tmpSum, tmpSum, src + m * othersize, othersize);
|
||||
}
|
||||
__memcpy(tmpOldMax, tmpNewMax, othersize * sizeof(float), NRAM2NRAM);
|
||||
}
|
||||
//----------------
|
||||
if(repeat > 0 || dimsize >= taskDim){
|
||||
__memcpy(tmpGdram + taskId * othersize, tmpNewMax, othersize * sizeof(float), NRAM2GDRAM);
|
||||
__sync_all();
|
||||
__bang_write_value(tmpNewMax, othersize, -INFINITY);
|
||||
for(int id = 0; id < taskDim; id++){
|
||||
__memcpy(tmpGlobal, tmpGdram + id * othersize, othersize * sizeof(float), GDRAM2NRAM);
|
||||
__bang_maxequal(tmpNewMax, tmpNewMax, tmpGlobal, othersize);
|
||||
}
|
||||
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, othersize);
|
||||
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, othersize);
|
||||
__bang_mul(tmpSum, tmpSum, tmpOldMax, othersize);
|
||||
__memcpy(tmpGdram + taskId * othersize, tmpSum, othersize * sizeof(float), NRAM2GDRAM);
|
||||
__sync_all();
|
||||
__bang_write_zero(tmpSum, othersize);
|
||||
for(int id = 0; id < taskDim; id++){
|
||||
__memcpy(tmpGlobal, tmpGdram + id * othersize, othersize * sizeof(float), GDRAM2NRAM);
|
||||
__bang_add(tmpSum, tmpSum, tmpGlobal, othersize);
|
||||
}
|
||||
__bang_active_reciphp(tmpSum, tmpSum, othersize);
|
||||
}
|
||||
else{
|
||||
__memcpy(tmpGdram + taskId * othersize, tmpNewMax, othersize * sizeof(float), NRAM2GDRAM);
|
||||
__sync_all();
|
||||
__bang_write_value(tmpNewMax, othersize, -INFINITY);
|
||||
for(int id = 0; id < dimsize; id++){
|
||||
__memcpy(tmpGlobal, tmpGdram + id * othersize, othersize * sizeof(float), GDRAM2NRAM);
|
||||
__bang_maxequal(tmpNewMax, tmpNewMax, tmpGlobal, othersize);
|
||||
}
|
||||
__bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, othersize);
|
||||
__bang_active_exp_less_0(tmpOldMax, tmpOldMax, othersize);
|
||||
__bang_mul(tmpSum, tmpSum, tmpOldMax, othersize);
|
||||
__memcpy(tmpGdram + taskId * othersize, tmpSum, othersize * sizeof(float), NRAM2GDRAM);
|
||||
__sync_all();
|
||||
__bang_write_zero(tmpSum, othersize);
|
||||
for(int id = 0; id < dimsize; id++){
|
||||
__memcpy(tmpGlobal, tmpGdram + id * othersize, othersize * sizeof(float), GDRAM2NRAM);
|
||||
__bang_add(tmpSum, tmpSum, tmpGlobal, othersize);
|
||||
}
|
||||
__bang_active_reciphp(tmpSum, tmpSum, othersize);
|
||||
}
|
||||
|
||||
//-------------------
|
||||
for(int i = 0; i < repeat; i++){
|
||||
__memcpy(src, source + (i * size + taskId * multiple) * stride, multiple * othersize * sizeof(float), GDRAM2NRAM);//stride=othersize
|
||||
for(int m = 0; m < multiple; m++){
|
||||
__bang_sub(src + m * othersize, src + m * othersize, tmpNewMax, othersize);
|
||||
}
|
||||
__bang_active_exp_less_0(src, src, multiple * othersize);
|
||||
for(int m = 0; m < multiple; m++){
|
||||
__bang_mul(src + m * othersize, src + m * othersize, tmpSum, othersize);
|
||||
}
|
||||
__memcpy(destination + (i * size + taskId * multiple) * stride, src, multiple * othersize * sizeof(float), NRAM2GDRAM);
|
||||
}
|
||||
if(step) {
|
||||
__memcpy(src, source + repeat * size * stride + indStart * stride, step * othersize * sizeof(float), GDRAM2NRAM);//stride=othersize
|
||||
for(int m = 0; m < step; m++){
|
||||
__bang_sub(src + m * othersize, src + m * othersize, tmpNewMax, othersize);
|
||||
}
|
||||
__bang_active_exp_less_0(src, src, step * othersize);
|
||||
for(int m = 0; m < step; m++){
|
||||
__bang_mul(src + m * othersize, src + m * othersize, tmpSum, othersize);
|
||||
}
|
||||
__memcpy(destination + repeat * size * stride + indStart * stride, src, step * othersize * sizeof(float), NRAM2GDRAM);
}
}
}
__mlu_device__ void softmaxKernelAxis_m(float *destination, float *source,
int frontsize, int dimsize, int stride,
@@ -1036,8 +1152,8 @@ __mlu_device__ void softmaxKernelAxis_m(float *destination, float *source,
__sync_all_ipu();
}
//__bang_printf("taskId:%d, multiple:%d, taskRepeat:%d, step:%d,
//indStart:%d\n",taskId, multiple, taskRepeat, step, indStart *
//behindsize);
// indStart:%d\n",taskId, multiple, taskRepeat, step, indStart *
// behindsize);
if (step) {
tid = taskRepeat * multiple * behindsize;
__memcpy(src, source + tid, step * behindsize * sizeof(float),
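The offset `tid = taskRepeat * multiple * behindsize` suggests that in the middle-axis kernel each "front slice" occupies one contiguous block of `behindsize` floats, with `frontsize` such slices handed out to the cores in groups of `multiple`. Assuming behindsize = dimsize * stride (its definition is not in the visible hunks), the addressing reduces to ordinary row-major indexing, sketched here:

// Sketch: index of element (f, d, s) in a contiguous
// (frontsize, dimsize, stride) view, which is how the middle-axis kernel
// appears to walk memory. Assumption: behindsize == dimsize * stride.
inline int softmaxMidIdx(int f, int d, int s, int dimsize, int stride) {
    return f * dimsize * stride + d * stride + s;
}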
@@ -1066,13 +1182,13 @@ __mlu_device__ void softmaxKernelAxis_m(float *destination, float *source,
NRAM2NRAM); // oldM = newM
}
//__bang_printf("max:%.2f,%.2f, sum:%.2f,sum:%.2f\n",
//tmpNewMax[0], tmpNewMax[1], tmpSum[0], tmpSum[0]);
// tmpNewMax[0], tmpNewMax[1], tmpSum[0], tmpSum[0]);
__bang_active_reciphp(tmpSum, tmpSum, strideS);
__bang_mul(tmp, tmp, tmpSum,
strideS); // The data stored in tmp at the end of the
// loop above can be utilized
//__memcpy(destination + tid + m * behindsize + (dimsize - 1) *
//stride, tmp, stride * sizeof(float), NRAM2GDRAM);
// stride, tmp, stride * sizeof(float), NRAM2GDRAM);
__memcpy(src + m * behindsize + (dimsize - 1) * stride, tmp,
stride * sizeof(float), NRAM2NRAM);
for (int i = 0; i < dimsize - 1; i++) {
@@ -1082,7 +1198,7 @@ __mlu_device__ void softmaxKernelAxis_m(float *destination, float *source,
__bang_active_exp_less_0(tmp, tmp, strideS); // exp(x - M)
__bang_mul(tmp, tmp, tmpSum, strideS);
//__memcpy(destination + tid + m * behindsize + i * stride,
//tmp, stride * sizeof(float), NRAM2GDRAM);
// tmp, stride * sizeof(float), NRAM2GDRAM);
__memcpy(src + m * behindsize + i * stride, tmp,
stride * sizeof(float), NRAM2NRAM);
}
@@ -1092,9 +1208,10 @@ __mlu_device__ void softmaxKernelAxis_m(float *destination, float *source,
}
}
}
__mlu_global__ void softmaxUnion1(float *mlu_destination, float *mlu_src, float *tmpGdram,
int othersize, int dimsize, int frontsize,
int stride, int axis, int nDim) {
__mlu_global__ void softmaxUnion1(float *mlu_destination, float *mlu_src,
float *tmpGdram, int othersize, int dimsize,
int frontsize, int stride, int axis,
int nDim) {
if (axis == nDim - 1) {
int dimS;
float mi = log2(dimsize);
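dimS is presumably dimsize rounded up to the next power of two, computed from the `log2(dimsize)` value above; the last-axis kernel's tree summation needs a power-of-two width, as the SRC-MAX-SIZE comment in softmaxKernelAxis_s also notes. The lines that actually set dimS fall outside the hunk, so the following is only one common way to write that rounding (an integer loop is used instead of the kernel's float log2 to avoid rounding edge cases):

// Sketch: smallest power of two >= n, e.g. for sizing a tree reduction.
int nextPow2(int n) {
    int p = 1;
    while (p < n)
        p <<= 1;
    return p;
}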
@@ -1108,8 +1225,8 @@ __mlu_global__ void softmaxUnion1(float *mlu_destination, float *mlu_src, float
}
softmaxKernelAxis_e(mlu_destination, mlu_src, othersize, dimsize, dimS);
} else if (axis == 0) {
softmaxKernelAxis_s(mlu_destination, mlu_src, tmpGdram, othersize, dimsize,
stride);
softmaxKernelAxis_s(mlu_destination, mlu_src, tmpGdram, othersize,
dimsize, stride);
} else {
float mi = log2(stride);
int strideS;
@@ -1136,10 +1253,11 @@ void softmaxKernel(cnnlHandle_t handle, float *mlu_destination, float *mlu_src,
k_type = CNRT_FUNC_TYPE_UNION1;
// launch kernel
float *tmpGdram;
CNRT_CHECK(cnrtMalloc((void **)&tmpGdram, k_dim.x * k_dim.y * k_dim.z * othersize * sizeof(float)));
softmaxUnion1<<<k_dim, k_type, queue>>>(mlu_destination, mlu_src, tmpGdram, othersize,
dimsize, frontsize, stride, axis,
nDim);
cnrtFree(tmpGdram);
CNRT_CHECK(cnrtMalloc((void **)&tmpGdram, k_dim.x * k_dim.y * k_dim.z *
othersize * sizeof(float)));
softmaxUnion1<<<k_dim, k_type, queue>>>(mlu_destination, mlu_src, tmpGdram,
othersize, dimsize, frontsize,
stride, axis, nDim);
cnrtFree(tmpGdram);
}
}; // namespace infini
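The scratch buffer is sized so that every task of the UNION1 launch (k_dim.x * k_dim.y * k_dim.z cores) owns one othersize-float stripe, which is what softmaxKernelAxis_s's cross-core max/sum exchange through tmpGdram writes (`tmpGdram + taskId * othersize`). A hedged restatement of that sizing follows; note also that if the `<<<...>>>` launch is asynchronous on `queue`, the buffer must not be freed before the kernel has finished, and the synchronization call that would guarantee that is not shown in this diff.

#include <cstddef>

// Sketch only: bytes of global scratch needed when each of the launched
// cores publishes one `othersize`-float stripe (the max pass and the sum
// pass reuse the same stripe).
size_t softmaxScratchBytes(int tasksX, int tasksY, int tasksZ, int othersize) {
    size_t numTasks = (size_t)tasksX * tasksY * tasksZ;
    return numTasks * (size_t)othersize * sizeof(float);
}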
@@ -30,7 +30,7 @@ void cnnlSoftmaxFp32(const Shape &inputShape, const vector<float> &inputData,
EXPECT_TRUE(outputGpu2Cpu->equalData(expectData));
}
TEST(cnnlSoftmaxFp32, run) {

cnnlSoftmaxFp32(Shape{2, 2, 2, 2},
vector<float>{
0.,
@@ -79,7 +79,6 @@ TEST(cnnlSoftmaxFp32, run) {
0.1192029, 0.1192029, 0.8807971, 0.8807971,
0.1192029, 0.1192029, 0.8807971, 0.8807971,
0.1192029, 0.1192029, 0.8807971, 0.8807971});

}
void bangSoftmaxFp32(const Shape &inputShape, const vector<float> &inputData,
int axis, const vector<float> &expectData) {
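The 0.1192029 / 0.8807971 pairs in the expected data are easy to verify by hand: for any two logits that differ by 2,

\mathrm{softmax}(a,\ a+2) = \left(\tfrac{1}{1+e^{2}},\ \tfrac{e^{2}}{1+e^{2}}\right) \approx (0.1192029,\ 0.8807971),

independent of a. That the tested axis pairs up inputs exactly two apart is inferred from the truncated input literal (which starts at 0.) and is an assumption, but the identity itself holds for any such pair.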
@@ -118,7 +117,7 @@ TEST(bangSoftmaxFp32, run) {
9.99993801e-01, 9.99993801e-01, 9.99993801e-01,
9.99993801e-01, 9.99993801e-01, 9.99993801e-01,
9.99993801e-01, 9.99993801e-01, 9.99993801e-01});

bangSoftmaxFp32(Shape{2, 2, 2, 2},
vector<float>{
0.,