fix concat and split operation

This commit is contained in:
wanghailu 2023-01-10 02:02:19 +00:00
parent 04d0e1a560
commit 6f1c7d0e82
2 changed files with 31 additions and 12 deletions

View File

@ -17,13 +17,22 @@ class ConcatCnnl : public BangKernelWithoutConfig {
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
cnnlTensorDescriptor_t desc;
auto dim = op->getInputs(0)->getDims();
if (dim.size() != 4)
IT_TODO_HALT();
int dim_array[4] = {dim[0], dim[1], dim[2], dim[3]};
int dim_array[num][4];
for (int i = 0; i < num; ++i) {
auto dim = op->getInputs(i)->getDims();
if (dim.size() != 4) {
IT_TODO_HALT();
}
dim_array[i][0] = dim[0];
dim_array[i][1] = dim[1];
dim_array[i][2] = dim[2];
dim_array[i][3] = dim[3];
}
auto dim = op->getOutput()->getDims();
int dimout_array[4] = {dim[0], dim[1], dim[2], dim[3]};
dimout_array[axis] *= num;
checkCnnlError(cnnlCreateTensorDescriptor(&desc));
checkCnnlError(cnnlSetTensorDescriptor(
desc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimout_array));
@ -32,7 +41,7 @@ class ConcatCnnl : public BangKernelWithoutConfig {
checkCnnlError(cnnlCreateTensorDescriptor(&descArray[i]));
checkCnnlError(
cnnlSetTensorDescriptor(descArray[i], CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, 4, dim_array));
CNNL_DTYPE_FLOAT, 4, dim_array[i]));
}
size_t wsSize;

View File

@ -17,13 +17,23 @@ class SplitCnnl : public BangKernelWithoutConfig {
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
cnnlTensorDescriptor_t desc;
auto dim = op->getInputs(0)->getDims();
if (dim.size() != 4)
IT_TODO_HALT();
int dimout_array[num][4];
for (int i = 0; i < num; ++i) {
auto dim = op->getOutput(i)->getDims();
if (dim.size() != 4) {
IT_TODO_HALT();
}
dimout_array[i][0] = dim[0];
dimout_array[i][1] = dim[1];
dimout_array[i][2] = dim[2];
dimout_array[i][3] = dim[3];
}
auto dim = op->getInputs(0)->getDims();
if (dim.size() != 4) {
IT_TODO_HALT();
}
int dim_array[4] = {dim[0], dim[1], dim[2], dim[3]};
int dimout_array[4] = {dim[0], dim[1], dim[2], dim[3]};
dimout_array[axis] /= num;
checkCnnlError(cnnlCreateTensorDescriptor(&desc));
checkCnnlError(cnnlSetTensorDescriptor(
desc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
@ -32,7 +42,7 @@ class SplitCnnl : public BangKernelWithoutConfig {
checkCnnlError(cnnlCreateTensorDescriptor(&descArray[i]));
checkCnnlError(
cnnlSetTensorDescriptor(descArray[i], CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, 4, dimout_array));
CNNL_DTYPE_FLOAT, 4, dimout_array[i]));
}
size_t wsSize;