forked from jiuyuan/InfiniTensor
format
This commit is contained in:
parent
cd703e5679
commit
d216b529e7
|
@ -1,6 +1,6 @@
|
|||
#include "operators/concat.h"
|
||||
#include "bang/bang_kernel_without_config.h"
|
||||
#include "bang/bang_runtime.h"
|
||||
#include "operators/concat.h"
|
||||
|
||||
namespace infini {
|
||||
class ConcatCnnl : public BangKernelWithoutConfig {
|
||||
|
@ -23,10 +23,10 @@ class ConcatCnnl : public BangKernelWithoutConfig {
|
|||
|
||||
int dim_array[4] = {dim[0], dim[1], dim[2], dim[3]};
|
||||
int dimout_array[4] = {dim[0], dim[1], dim[2], dim[3]};
|
||||
dimout_array[axis] *= num;
|
||||
dimout_array[axis] *= num;
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&desc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(desc, CNNL_LAYOUT_NCHW,
|
||||
CNNL_DTYPE_FLOAT, 4, dimout_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
desc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimout_array));
|
||||
cnnlTensorDescriptor_t descArray[num];
|
||||
for (int i = 0; i < num; ++i) {
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&descArray[i]));
|
||||
|
@ -40,7 +40,8 @@ class ConcatCnnl : public BangKernelWithoutConfig {
|
|||
BangPtr wsData = context->getWorkspace(wsSize);
|
||||
|
||||
cnnlStatus_t stat =
|
||||
cnnlConcat(context->cnnlHandle(), num, axis, descArray, argv, wsData, wsSize, desc, cData);
|
||||
cnnlConcat(context->cnnlHandle(), num, axis, descArray, argv,
|
||||
wsData, wsSize, desc, cData);
|
||||
if (stat != CNNL_STATUS_SUCCESS)
|
||||
return;
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ namespace infini {
|
|||
|
||||
template <class T>
|
||||
void testConcat(const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape) {
|
||||
const Shape &shape) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
|
||||
auto bangRuntime = make_ref<BangRuntimeObj>();
|
||||
|
@ -29,7 +29,8 @@ void testConcat(const std::function<void(void *, size_t, DataType)> &generator,
|
|||
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
|
||||
auto inputGpu1 = bangGraph->cloneTensor(inputCpu1);
|
||||
auto inputGpu2 = bangGraph->cloneTensor(inputCpu2);
|
||||
auto gpuOp = bangGraph->addOp<T>(TensorVec{inputGpu1, inputGpu2}, nullptr, 2);
|
||||
auto gpuOp =
|
||||
bangGraph->addOp<T>(TensorVec{inputGpu1, inputGpu2}, nullptr, 2);
|
||||
bangGraph->dataMalloc();
|
||||
bangRuntime->run(bangGraph);
|
||||
auto outputGpu = gpuOp->getOutput();
|
||||
|
|
Loading…
Reference in New Issue