fix bug of unary_list_kernel

This commit is contained in:
wanghailu 2023-08-03 15:38:21 +08:00
parent e3da455aba
commit 55c69978ef
1 changed files with 22 additions and 20 deletions

View File

@ -23,34 +23,36 @@ __mlu_device__ void UnaryFunction(T *output, T *input, size_t num,
int my_repeat = my / deal_align;
int my_rem = my % deal_align;
for (int i = 0; i < my_repeat; ++i) {
size_t my_op_list = op_list;
for(int i = 0; i < my_repeat; ++i) {
__memcpy(left, input_start, use_nram_size, GDRAM2NRAM);
while (op_list) {
int op = op_list % 10;
switch (op) {
while(my_op_list){
int op = my_op_list % 10;
switch(op){
case 1:
__bang_active_abs((T *)left, (T *)left, num);
break;
__bang_active_abs((T*)left, (T*)left, deal_align);
break;
case 2:
__bang_active_relu((T *)left, (T *)left, num);
break;
__bang_active_relu((T*)left, (T*)left, deal_align);
break;
case 3:
__bang_active_sigmoid((T *)left, (T *)left, num);
break;
__bang_active_sigmoid((T*)left, (T*)left, deal_align);
break;
default:
break;
}
op_list /= 10;
break;
}
my_op_list /= 10;
}
__memcpy(output_start, left, use_nram_size, NRAM2GDRAM);
input_start += use_nram_size;
output_start += use_nram_size;
}
if (my_rem) {
if(my_rem) {
my_op_list = op_list;
__memcpy(left, input_start, my_rem * sizeof(T), GDRAM2NRAM);
while (op_list) {
int op = op_list % 10;
switch (op) {
while(my_op_list){
int op = my_op_list % 10;
switch(op){
case 1:
__bang_active_abs((T *)left, (T *)left, my_rem);
break;
@ -61,9 +63,9 @@ __mlu_device__ void UnaryFunction(T *output, T *input, size_t num,
__bang_active_sigmoid((T *)left, (T *)left, my_rem);
break;
default:
break;
}
op_list /= 10;
break;
}
my_op_list /= 10;
}
__memcpy(output_start, left, my_rem * sizeof(T), NRAM2GDRAM);
}