forked from jiuyuan/InfiniTensor
fix concat and split operation
This commit is contained in:
parent
04d0e1a560
commit
6f1c7d0e82
|
@ -17,13 +17,22 @@ class ConcatCnnl : public BangKernelWithoutConfig {
|
|||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
cnnlTensorDescriptor_t desc;
|
||||
auto dim = op->getInputs(0)->getDims();
|
||||
if (dim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
|
||||
int dim_array[4] = {dim[0], dim[1], dim[2], dim[3]};
|
||||
int dim_array[num][4];
|
||||
for (int i = 0; i < num; ++i) {
|
||||
auto dim = op->getInputs(i)->getDims();
|
||||
if (dim.size() != 4) {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
dim_array[i][0] = dim[0];
|
||||
dim_array[i][1] = dim[1];
|
||||
dim_array[i][2] = dim[2];
|
||||
dim_array[i][3] = dim[3];
|
||||
}
|
||||
|
||||
auto dim = op->getOutput()->getDims();
|
||||
int dimout_array[4] = {dim[0], dim[1], dim[2], dim[3]};
|
||||
dimout_array[axis] *= num;
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&desc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
desc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimout_array));
|
||||
|
@ -32,7 +41,7 @@ class ConcatCnnl : public BangKernelWithoutConfig {
|
|||
checkCnnlError(cnnlCreateTensorDescriptor(&descArray[i]));
|
||||
checkCnnlError(
|
||||
cnnlSetTensorDescriptor(descArray[i], CNNL_LAYOUT_NCHW,
|
||||
CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||
CNNL_DTYPE_FLOAT, 4, dim_array[i]));
|
||||
}
|
||||
|
||||
size_t wsSize;
|
||||
|
|
|
@ -17,13 +17,23 @@ class SplitCnnl : public BangKernelWithoutConfig {
|
|||
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
||||
cnnlTensorDescriptor_t desc;
|
||||
auto dim = op->getInputs(0)->getDims();
|
||||
if (dim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
|
||||
int dimout_array[num][4];
|
||||
for (int i = 0; i < num; ++i) {
|
||||
auto dim = op->getOutput(i)->getDims();
|
||||
if (dim.size() != 4) {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
dimout_array[i][0] = dim[0];
|
||||
dimout_array[i][1] = dim[1];
|
||||
dimout_array[i][2] = dim[2];
|
||||
dimout_array[i][3] = dim[3];
|
||||
}
|
||||
auto dim = op->getInputs(0)->getDims();
|
||||
if (dim.size() != 4) {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
int dim_array[4] = {dim[0], dim[1], dim[2], dim[3]};
|
||||
int dimout_array[4] = {dim[0], dim[1], dim[2], dim[3]};
|
||||
dimout_array[axis] /= num;
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&desc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
desc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||
|
@ -32,7 +42,7 @@ class SplitCnnl : public BangKernelWithoutConfig {
|
|||
checkCnnlError(cnnlCreateTensorDescriptor(&descArray[i]));
|
||||
checkCnnlError(
|
||||
cnnlSetTensorDescriptor(descArray[i], CNNL_LAYOUT_NCHW,
|
||||
CNNL_DTYPE_FLOAT, 4, dimout_array));
|
||||
CNNL_DTYPE_FLOAT, 4, dimout_array[i]));
|
||||
}
|
||||
|
||||
size_t wsSize;
|
||||
|
|
Loading…
Reference in New Issue