forked from jiuyuan/InfiniTensor
fix bug of unary_list_kernel
This commit is contained in:
parent
e3da455aba
commit
55c69978ef
|
@ -23,34 +23,36 @@ __mlu_device__ void UnaryFunction(T *output, T *input, size_t num,
|
|||
|
||||
int my_repeat = my / deal_align;
|
||||
int my_rem = my % deal_align;
|
||||
for (int i = 0; i < my_repeat; ++i) {
|
||||
size_t my_op_list = op_list;
|
||||
for(int i = 0; i < my_repeat; ++i) {
|
||||
__memcpy(left, input_start, use_nram_size, GDRAM2NRAM);
|
||||
while (op_list) {
|
||||
int op = op_list % 10;
|
||||
switch (op) {
|
||||
while(my_op_list){
|
||||
int op = my_op_list % 10;
|
||||
switch(op){
|
||||
case 1:
|
||||
__bang_active_abs((T *)left, (T *)left, num);
|
||||
break;
|
||||
__bang_active_abs((T*)left, (T*)left, deal_align);
|
||||
break;
|
||||
case 2:
|
||||
__bang_active_relu((T *)left, (T *)left, num);
|
||||
break;
|
||||
__bang_active_relu((T*)left, (T*)left, deal_align);
|
||||
break;
|
||||
case 3:
|
||||
__bang_active_sigmoid((T *)left, (T *)left, num);
|
||||
break;
|
||||
__bang_active_sigmoid((T*)left, (T*)left, deal_align);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
op_list /= 10;
|
||||
break;
|
||||
}
|
||||
my_op_list /= 10;
|
||||
}
|
||||
__memcpy(output_start, left, use_nram_size, NRAM2GDRAM);
|
||||
input_start += use_nram_size;
|
||||
output_start += use_nram_size;
|
||||
}
|
||||
if (my_rem) {
|
||||
if(my_rem) {
|
||||
my_op_list = op_list;
|
||||
__memcpy(left, input_start, my_rem * sizeof(T), GDRAM2NRAM);
|
||||
while (op_list) {
|
||||
int op = op_list % 10;
|
||||
switch (op) {
|
||||
while(my_op_list){
|
||||
int op = my_op_list % 10;
|
||||
switch(op){
|
||||
case 1:
|
||||
__bang_active_abs((T *)left, (T *)left, my_rem);
|
||||
break;
|
||||
|
@ -61,9 +63,9 @@ __mlu_device__ void UnaryFunction(T *output, T *input, size_t num,
|
|||
__bang_active_sigmoid((T *)left, (T *)left, my_rem);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
op_list /= 10;
|
||||
break;
|
||||
}
|
||||
my_op_list /= 10;
|
||||
}
|
||||
__memcpy(output_start, left, my_rem * sizeof(T), NRAM2GDRAM);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue