This commit is contained in:
wanghailu 2023-01-09 15:16:43 +08:00
parent cd703e5679
commit d216b529e7
2 changed files with 9 additions and 7 deletions

View File

@ -1,6 +1,6 @@
#include "operators/concat.h"
#include "bang/bang_kernel_without_config.h"
#include "bang/bang_runtime.h"
#include "operators/concat.h"
namespace infini {
class ConcatCnnl : public BangKernelWithoutConfig {
@ -23,10 +23,10 @@ class ConcatCnnl : public BangKernelWithoutConfig {
int dim_array[4] = {dim[0], dim[1], dim[2], dim[3]};
int dimout_array[4] = {dim[0], dim[1], dim[2], dim[3]};
dimout_array[axis] *= num;
dimout_array[axis] *= num;
checkCnnlError(cnnlCreateTensorDescriptor(&desc));
checkCnnlError(cnnlSetTensorDescriptor(desc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, 4, dimout_array));
checkCnnlError(cnnlSetTensorDescriptor(
desc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimout_array));
cnnlTensorDescriptor_t descArray[num];
for (int i = 0; i < num; ++i) {
checkCnnlError(cnnlCreateTensorDescriptor(&descArray[i]));
@ -40,7 +40,8 @@ class ConcatCnnl : public BangKernelWithoutConfig {
BangPtr wsData = context->getWorkspace(wsSize);
cnnlStatus_t stat =
cnnlConcat(context->cnnlHandle(), num, axis, descArray, argv, wsData, wsSize, desc, cData);
cnnlConcat(context->cnnlHandle(), num, axis, descArray, argv,
wsData, wsSize, desc, cData);
if (stat != CNNL_STATUS_SUCCESS)
return;

View File

@ -10,7 +10,7 @@ namespace infini {
template <class T>
void testConcat(const std::function<void(void *, size_t, DataType)> &generator,
const Shape &shape) {
const Shape &shape) {
// Runtime
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
auto bangRuntime = make_ref<BangRuntimeObj>();
@ -29,7 +29,8 @@ void testConcat(const std::function<void(void *, size_t, DataType)> &generator,
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
auto inputGpu1 = bangGraph->cloneTensor(inputCpu1);
auto inputGpu2 = bangGraph->cloneTensor(inputCpu2);
auto gpuOp = bangGraph->addOp<T>(TensorVec{inputGpu1, inputGpu2}, nullptr, 2);
auto gpuOp =
bangGraph->addOp<T>(TensorVec{inputGpu1, inputGpu2}, nullptr, 2);
bangGraph->dataMalloc();
bangRuntime->run(bangGraph);
auto outputGpu = gpuOp->getOutput();