forked from jiuyuan/InfiniTensor
ADD:pad/slice operator and cuda kernel. (#39)
fix compile error refector clang format split test. fix compile error. ADD slice cuda kernel. ADD slice operator. ADD:pad operator and cuda kernel.
This commit is contained in:
parent
1aefc1b27e
commit
5560d0f2fb
|
@ -20,7 +20,7 @@ class TensorObj : public TensorBaseObj {
|
||||||
size_t getBytes() const;
|
size_t getBytes() const;
|
||||||
|
|
||||||
Shape getDims() const { return shape; }
|
Shape getDims() const { return shape; }
|
||||||
|
vector<size_t> getStride() const;
|
||||||
size_t getOffset(const Shape &ds) const;
|
size_t getOffset(const Shape &ds) const;
|
||||||
using TensorBaseObj::getData;
|
using TensorBaseObj::getData;
|
||||||
VType getData(const Shape &pos) const;
|
VType getData(const Shape &pos) const;
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
const int MAX_DIM = 4;
|
||||||
|
|
||||||
|
// Pad operator acts like padding small(part) tensor into a big(whole) tensor.
|
||||||
|
// Slice operator acts like spling a big(whole) tensor into a small(part)
|
||||||
|
// tensor.
|
||||||
|
typedef struct {
|
||||||
|
int begNum[MAX_DIM]; // pad or slice number at beginning
|
||||||
|
int wholeNDim[MAX_DIM]; // dim size after padding or before slicing
|
||||||
|
int partNDim[MAX_DIM]; // dim size before padding or after slicing
|
||||||
|
int partStride[MAX_DIM]; // stride before padding or after slicing
|
||||||
|
} TransMetaData;
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
void pad_slice_kernel(float *partData, float *wholeData,
|
||||||
|
const TransMetaData &metadata, int nDims, int num,
|
||||||
|
bool isPad);
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,24 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/operator.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
class PadObj : public OperatorObj {
|
||||||
|
// the number of start and end pad values for all dims.
|
||||||
|
vector<int> pads;
|
||||||
|
|
||||||
|
public:
|
||||||
|
// pad for appointed axises,if axis is empty,then pad for all axises.
|
||||||
|
PadObj(GraphObj *graph, Tensor input, Tensor output,
|
||||||
|
const vector<int> &pads, const optional<const vector<int>> &axis);
|
||||||
|
|
||||||
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||||
|
std::string toString() const override;
|
||||||
|
int numInputs() const override { return 1; }
|
||||||
|
int numOutputs() const override { return 1; }
|
||||||
|
Shape PadObj::getPads() const { return pads; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
vector<int> getWorkloadVector() const override;
|
||||||
|
vector<int> getOpAttrVector() const override;
|
||||||
|
};
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,24 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/operator.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
class SliceObj : public OperatorObj {
|
||||||
|
vector<int> starts, ends; // the start no. and end no. for all dims.
|
||||||
|
|
||||||
|
public:
|
||||||
|
SliceObj(GraphObj *graph, Tensor input, Tensor output,
|
||||||
|
const vector<int> &starts, const vector<int> &ends,
|
||||||
|
const optional<vector<int>> &axis,
|
||||||
|
const optional<vector<int>> &steps);
|
||||||
|
|
||||||
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||||
|
std::string toString() const override;
|
||||||
|
int numInputs() const override { return 1; }
|
||||||
|
int numOutputs() const override { return 1; }
|
||||||
|
Shape getStart() const { return starts; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
vector<int> getWorkloadVector() const override;
|
||||||
|
vector<int> getOpAttrVector() const override;
|
||||||
|
};
|
||||||
|
} // namespace infini
|
|
@ -30,6 +30,17 @@ size_t TensorObj::getOffset(const Shape &pos) const {
|
||||||
return idx;
|
return idx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
vector<size_t> TensorObj::getStride() const {
|
||||||
|
vector<size_t> ret;
|
||||||
|
size_t stride = 1;
|
||||||
|
for (int i = shape.size() - 1; i >= 1; i--) {
|
||||||
|
ret.emplace(ret.begin(), stride);
|
||||||
|
stride *= shape.at(i);
|
||||||
|
}
|
||||||
|
ret.emplace(ret.begin(), stride);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
size_t TensorObj::size() const {
|
size_t TensorObj::size() const {
|
||||||
size_t ret = 1;
|
size_t ret = 1;
|
||||||
for (const auto &d : shape)
|
for (const auto &d : shape)
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||||
|
#include "cuda/cuda_pad_slice.h"
|
||||||
|
#include "operators/pad.h"
|
||||||
|
#include "operators/slice.h"
|
||||||
|
namespace infini {
|
||||||
|
class PadSliceCudaCompute {
|
||||||
|
public:
|
||||||
|
void do_compute(Tensor partTensor, Tensor wholeTensor, const Shape &begNos,
|
||||||
|
bool isPad) const {
|
||||||
|
int nDims = partTensor->getDims().size();
|
||||||
|
IT_ASSERT(MAX_DIM >= nDims);
|
||||||
|
TransMetaData metadata;
|
||||||
|
for (int i = 0; i < nDims; i++) {
|
||||||
|
metadata.begNum[i] = begNos[i];
|
||||||
|
metadata.wholeNDim[i] = wholeTensor->getDims()[i];
|
||||||
|
metadata.partNDim[i] = partTensor->getDims()[i];
|
||||||
|
metadata.partStride[i] = partTensor->getStride()[i];
|
||||||
|
}
|
||||||
|
pad_slice_kernel(partTensor->getRawDataPtr<float *>(),
|
||||||
|
wholeTensor->getRawDataPtr<float *>(), metadata, nDims,
|
||||||
|
wholeTensor->size(), isPad);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class PadCuda : private PadSliceCudaCompute, public CudaKernelWithoutConfig {
|
||||||
|
void compute(const Operator &op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
do_compute(op->getInputs(0), op->getOutput(), as<PadObj>(op)->getPads(),
|
||||||
|
true);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class SliceCuda : private PadSliceCudaCompute, public CudaKernelWithoutConfig {
|
||||||
|
void compute(const Operator &op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
do_compute(op->getOutput(), op->getInputs(0),
|
||||||
|
as<SliceObj>(op)->getStart(), false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::Slice, DataType::Float32, SliceCuda,
|
||||||
|
"Slice__CUDA_Float32");
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::Pad, DataType::Float32, PadCuda,
|
||||||
|
"Pad__CUDA_Float32");
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,52 @@
|
||||||
|
#include "cuda/cuda_common.h"
|
||||||
|
#include "cuda/cuda_pad_slice.h"
|
||||||
|
|
||||||
|
__device__ int WholeTensorOffset2PartTensorOffset(int wholeOffset,
|
||||||
|
TransMetaData metaData,
|
||||||
|
int nDims) {
|
||||||
|
int offset = 0;
|
||||||
|
for (int i = nDims - 1; i >= 0; --i) {
|
||||||
|
auto wholePos = wholeOffset % metaData.wholeNDim[i];
|
||||||
|
auto pos = wholePos - metaData.begNum[i];
|
||||||
|
// if pos belongs to pad range, then return -1
|
||||||
|
if (pos < 0 || pos >= metaData.partNDim[i])
|
||||||
|
return -1;
|
||||||
|
wholeOffset = wholeOffset / metaData.wholeNDim[i];
|
||||||
|
|
||||||
|
offset += pos * metaData.partStride[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
return offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void _pad_slice_kernel(float *part, float *whole,
|
||||||
|
TransMetaData metaData, int nDims, int num,
|
||||||
|
bool isPad) {
|
||||||
|
int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
if (tid >= num)
|
||||||
|
return;
|
||||||
|
|
||||||
|
int stride = blockDim.x * gridDim.x;
|
||||||
|
while (tid < num) {
|
||||||
|
int offset = WholeTensorOffset2PartTensorOffset(tid, metaData, nDims);
|
||||||
|
if (isPad)
|
||||||
|
if (offset < 0)
|
||||||
|
whole[tid] = 0;
|
||||||
|
else
|
||||||
|
whole[tid] = part[offset];
|
||||||
|
else
|
||||||
|
part[offset] = whole[tid];
|
||||||
|
tid += stride;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
void pad_slice_kernel(float *partData, float *wholeData,
|
||||||
|
const TransMetaData &metadata, int nDims, int num,
|
||||||
|
bool isPad) {
|
||||||
|
int blockSize = 32 * 16;
|
||||||
|
int gridSize = (num + blockSize - 1) / blockSize;
|
||||||
|
_pad_slice_kernel<<<gridSize, blockSize>>>(partData, wholeData, metadata,
|
||||||
|
nDims, num, isPad);
|
||||||
|
}
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,63 @@
|
||||||
|
#include "operators/pad.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
PadObj::PadObj(GraphObj *graph, Tensor input, Tensor output,
|
||||||
|
const vector<int> &_pads,
|
||||||
|
const optional<const vector<int>> &axis)
|
||||||
|
: OperatorObj(OpType::Pad, {input}, {output}) {
|
||||||
|
if (axis == std::nullopt)
|
||||||
|
pads = _pads;
|
||||||
|
else {
|
||||||
|
int nAxis = (*axis).size();
|
||||||
|
IT_ASSERT((int)_pads.size() == nAxis * 2);
|
||||||
|
int nDims = input->getDims().size();
|
||||||
|
vector<int> tmp(nDims * 2, 0);
|
||||||
|
|
||||||
|
for (int i = 0; i < nAxis; ++i) {
|
||||||
|
tmp[(*axis)[i]] = _pads[i];
|
||||||
|
tmp[(*axis)[i] + nDims] = _pads[i + nAxis];
|
||||||
|
}
|
||||||
|
pads = tmp;
|
||||||
|
}
|
||||||
|
IT_ASSERT(checkValid(graph));
|
||||||
|
}
|
||||||
|
|
||||||
|
optional<vector<Shape>> PadObj::inferShape(const TensorVec &inputs) const {
|
||||||
|
auto dims = inputs[0]->getDims();
|
||||||
|
int nDims = dims.size();
|
||||||
|
if (nDims * 2 != (int)pads.size())
|
||||||
|
return {};
|
||||||
|
for (int i = 0; i < nDims; ++i) {
|
||||||
|
if (pads[i] < 0 || pads[i + nDims] < 0)
|
||||||
|
return {};
|
||||||
|
dims[i] += pads[i] + pads[i + nDims];
|
||||||
|
}
|
||||||
|
|
||||||
|
return {{dims}};
|
||||||
|
}
|
||||||
|
std::string PadObj::toString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "Pad"
|
||||||
|
<< "[" << getGuid() << "]";
|
||||||
|
os << "(";
|
||||||
|
os << vecToString(inputs[0]->getDims()) << ",";
|
||||||
|
os << "pads=" << vecToString(pads) << ",";
|
||||||
|
os << "input=" << inputs[0]->getGuid() << ",";
|
||||||
|
os << "output=" << outputs[0]->getGuid() << ")";
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<int> PadObj::getWorkloadVector() const {
|
||||||
|
vector<int> ret = inputs[0]->getDims();
|
||||||
|
ret.insert(ret.end(), pads.begin(), pads.end());
|
||||||
|
ret.emplace(ret.begin(), enum_to_underlying(type));
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<int> PadObj::getOpAttrVector() const {
|
||||||
|
vector<int> ret = pads;
|
||||||
|
ret.emplace(ret.begin(), enum_to_underlying(type));
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,80 @@
|
||||||
|
#include "operators/slice.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
SliceObj::SliceObj(GraphObj *graph, Tensor input, Tensor output,
|
||||||
|
const vector<int> &starts, const vector<int> &ends,
|
||||||
|
const optional<vector<int>> &axis,
|
||||||
|
const optional<vector<int>> &steps)
|
||||||
|
: OperatorObj(OpType::Slice, {input}, {output}) {
|
||||||
|
if (steps != std::nullopt)
|
||||||
|
IT_TODO_HALT();
|
||||||
|
IT_ASSERT(starts.size() == ends.size());
|
||||||
|
|
||||||
|
if (axis == std::nullopt) {
|
||||||
|
this->starts = starts;
|
||||||
|
this->ends = ends;
|
||||||
|
} else {
|
||||||
|
int nAxis = (*axis).size();
|
||||||
|
IT_ASSERT((int)starts.size() == nAxis);
|
||||||
|
|
||||||
|
int nDims = input->getDims().size();
|
||||||
|
vector<int> tmpS(nDims, 0), tmpE;
|
||||||
|
for (int i = 0; i < nDims; ++i) {
|
||||||
|
tmpE.emplace_back(input->getDims()[i] - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < nAxis; ++i) {
|
||||||
|
if ((*axis)[i] < 0)
|
||||||
|
IT_TODO_HALT();
|
||||||
|
tmpS[(*axis)[i]] = starts[i];
|
||||||
|
tmpE[(*axis)[i]] = ends[i];
|
||||||
|
}
|
||||||
|
this->starts = tmpS;
|
||||||
|
this->ends = tmpE;
|
||||||
|
}
|
||||||
|
IT_ASSERT(checkValid(graph));
|
||||||
|
}
|
||||||
|
|
||||||
|
optional<vector<Shape>> SliceObj::inferShape(const TensorVec &inputs) const {
|
||||||
|
auto dims = inputs[0]->getDims();
|
||||||
|
int nDims = dims.size();
|
||||||
|
if (nDims != (int)starts.size())
|
||||||
|
return {};
|
||||||
|
for (int i = 0; i < nDims; ++i) {
|
||||||
|
if (starts[i] < 0 || ends[i] >= dims[i] || starts[i] > ends[i])
|
||||||
|
return {};
|
||||||
|
dims[i] = ends[i] - starts[i] + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {{dims}};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string SliceObj::toString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "Slice"
|
||||||
|
<< "[" << getGuid() << "]";
|
||||||
|
os << "(";
|
||||||
|
os << vecToString(inputs[0]->getDims()) << ",";
|
||||||
|
os << "starts=" << vecToString(starts) << ",";
|
||||||
|
os << "ends=" << vecToString(ends) << ",";
|
||||||
|
os << "input=" << inputs[0]->getGuid() << ",";
|
||||||
|
os << "output=" << outputs[0]->getGuid() << ")";
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<int> SliceObj::getWorkloadVector() const {
|
||||||
|
vector<int> ret = inputs[0]->getDims();
|
||||||
|
ret.insert(ret.end(), starts.begin(), starts.end());
|
||||||
|
ret.insert(ret.end(), ends.begin(), ends.end());
|
||||||
|
ret.emplace(ret.begin(), enum_to_underlying(type));
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<int> SliceObj::getOpAttrVector() const {
|
||||||
|
vector<int> ret = starts;
|
||||||
|
ret.insert(ret.end(), ends.begin(), ends.end());
|
||||||
|
ret.emplace(ret.begin(), enum_to_underlying(type));
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,41 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/cuda_utility.h"
|
||||||
|
#include "operators/pad.h"
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
TEST(Pad, Cuda) {
|
||||||
|
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
|
||||||
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
|
|
||||||
|
// Build input data on CPU
|
||||||
|
Tensor icpu =
|
||||||
|
make_ref<TensorObj>(Shape{1, 2, 3, 2}, DataType::Float32, cpuRuntime);
|
||||||
|
icpu->dataMalloc();
|
||||||
|
icpu->setData(IncrementalGenerator());
|
||||||
|
|
||||||
|
// Build CUDA graph;
|
||||||
|
Graph g = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
auto i = g->cloneTensor(icpu);
|
||||||
|
auto op = g->addOp<PadObj>(i, nullptr, vector<int>{1, 0, 1, 1},
|
||||||
|
vector<int>{0, 3});
|
||||||
|
|
||||||
|
// allocate CUDA memory
|
||||||
|
g->dataMalloc();
|
||||||
|
|
||||||
|
// Execute on CUDA
|
||||||
|
cudaRuntime->run(g);
|
||||||
|
|
||||||
|
// clone CUDA output to CPU
|
||||||
|
auto o = op->getOutput();
|
||||||
|
auto cpuo = o->clone(cpuRuntime);
|
||||||
|
// cudaPrintTensor(o);
|
||||||
|
// check results on CPU
|
||||||
|
EXPECT_TRUE(cpuo->equalData(
|
||||||
|
vector<float>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 1, 0, 2, 3, 0, 4, 5, 0, 6, 7, 0, 8, 9, 0, 10, 11, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
|
||||||
|
}
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,39 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/cuda_utility.h"
|
||||||
|
#include "operators/slice.h"
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
TEST(CUDA_Slice, run) {
|
||||||
|
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
|
||||||
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
|
|
||||||
|
// Build input data on CPU
|
||||||
|
Tensor icpu =
|
||||||
|
make_ref<TensorObj>(Shape{3, 2, 1, 5}, DataType::Float32, cpuRuntime);
|
||||||
|
icpu->dataMalloc();
|
||||||
|
icpu->setData(IncrementalGenerator());
|
||||||
|
|
||||||
|
// Build CUDA graph;
|
||||||
|
Graph g = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
auto i = g->cloneTensor(icpu);
|
||||||
|
auto op =
|
||||||
|
g->addOp<SliceObj>(i, nullptr, vector<int>{1, 1}, vector<int>{1, 4},
|
||||||
|
vector<int>{0, 3}, std::nullopt);
|
||||||
|
|
||||||
|
// allocate CUDA memory
|
||||||
|
g->dataMalloc();
|
||||||
|
|
||||||
|
// Execute on CUDA
|
||||||
|
cudaRuntime->run(g);
|
||||||
|
|
||||||
|
// clone CUDA output to CPU
|
||||||
|
auto o = op->getOutput();
|
||||||
|
auto cpuo = o->clone(cpuRuntime);
|
||||||
|
// cudaPrintTensor(o);
|
||||||
|
// check results on CPU
|
||||||
|
EXPECT_TRUE(cpuo->equalData(vector<float>{11, 12, 13, 14, 16, 17, 18, 19}));
|
||||||
|
}
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,25 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "operators/pad.h"
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
TEST(Pad, ShapeInference) {
|
||||||
|
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
|
||||||
|
{
|
||||||
|
Graph g = make_ref<GraphObj>(cpuRuntime);
|
||||||
|
Tensor i = g->addTensor({1, 64, 162, 162}, DataType::UInt32);
|
||||||
|
auto op = g->addOp<PadObj>(
|
||||||
|
i, nullptr, vector<int>{2, 10, 1, 5, 0, 10, 1, 5}, std::nullopt);
|
||||||
|
EXPECT_EQ(op->getOutput()->getDims(), (Shape{3, 84, 164, 172}));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
Graph g = make_ref<GraphObj>(cpuRuntime);
|
||||||
|
Tensor i = g->addTensor({1, 64, 162, 162}, DataType::UInt32);
|
||||||
|
auto op = g->addOp<PadObj>(i, nullptr, vector<int>{2, 10, 1, 5},
|
||||||
|
vector<int>{0, 3});
|
||||||
|
EXPECT_EQ(op->getOutput()->getDims(), (Shape{4, 64, 162, 177}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,27 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "operators/slice.h"
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
TEST(Slice, ShapeInference) {
|
||||||
|
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
|
||||||
|
{
|
||||||
|
Graph g = make_ref<GraphObj>(cpuRuntime);
|
||||||
|
Tensor i = g->addTensor({10, 64, 162, 162}, DataType::UInt32);
|
||||||
|
auto op = g->addOp<SliceObj>(i, nullptr, vector<int>{2, 10, 1, 5},
|
||||||
|
vector<int>{3, 10, 100, 100}, std::nullopt,
|
||||||
|
std::nullopt);
|
||||||
|
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 100, 96}));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
Graph g = make_ref<GraphObj>(cpuRuntime);
|
||||||
|
Tensor i = g->addTensor({10, 64, 162, 162}, DataType::UInt32);
|
||||||
|
auto op = g->addOp<SliceObj>(i, nullptr, vector<int>{2, 5},
|
||||||
|
vector<int>{3, 100}, vector<int>{1, 3},
|
||||||
|
std::nullopt);
|
||||||
|
EXPECT_EQ(op->getOutput()->getDims(), (Shape{10, 2, 162, 96}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
Loading…
Reference in New Issue