forked from jiuyuan/InfiniTensor
Pooling ceil mode (#155)
* add ceil mode for pooling * do not print debug info for allocator by default * fix test bugs after introducing pooling ceil mode * fix onnx import bug
This commit is contained in:
parent
785853b0a3
commit
7a9fcd93b2
|
@ -32,9 +32,9 @@ class GraphHandlerObj {
|
||||||
float momentum, float eps, bool training);
|
float momentum, float eps, bool training);
|
||||||
|
|
||||||
Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
|
Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
|
||||||
int ph, int pw, int sh, int sw);
|
int ph, int pw, int sh, int sw, int ceilMode);
|
||||||
Tensor avgPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
|
Tensor avgPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
|
||||||
int ph, int pw, int sh, int sw);
|
int ph, int pw, int sh, int sw, int ceilMode);
|
||||||
|
|
||||||
Tensor add(Tensor a, Tensor b, Tensor c);
|
Tensor add(Tensor a, Tensor b, Tensor c);
|
||||||
Tensor sub(Tensor a, Tensor b, Tensor c);
|
Tensor sub(Tensor a, Tensor b, Tensor c);
|
||||||
|
|
|
@ -12,6 +12,7 @@ class PoolingObj : public OperatorObj {
|
||||||
int dh, dw;
|
int dh, dw;
|
||||||
int ph, pw;
|
int ph, pw;
|
||||||
int sh, sw;
|
int sh, sw;
|
||||||
|
int ceilMode;
|
||||||
int n, c, h, w;
|
int n, c, h, w;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
@ -32,9 +33,12 @@ class PoolingObj : public OperatorObj {
|
||||||
* @param pw Padding at the width dimension.
|
* @param pw Padding at the width dimension.
|
||||||
* @param sh Stride at the height dimension.
|
* @param sh Stride at the height dimension.
|
||||||
* @param sw Stride at the width dimension.
|
* @param sw Stride at the width dimension.
|
||||||
|
* @param ceilMode Whether to use ceil(1) or floor(0) to compute the output
|
||||||
|
* shape.
|
||||||
*/
|
*/
|
||||||
PoolingObj(GraphObj *graph, OpType optype, Tensor input, Tensor output,
|
PoolingObj(GraphObj *graph, OpType optype, Tensor input, Tensor output,
|
||||||
int kh, int kw, int dh, int dw, int ph, int pw, int sh, int sw);
|
int kh, int kw, int dh, int dw, int ph, int pw, int sh, int sw,
|
||||||
|
int ceilMode);
|
||||||
OP_CLONE(PoolingObj);
|
OP_CLONE(PoolingObj);
|
||||||
|
|
||||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||||
|
@ -50,6 +54,7 @@ class PoolingObj : public OperatorObj {
|
||||||
int getPw() const { return pw; }
|
int getPw() const { return pw; }
|
||||||
int getSh() const { return sh; }
|
int getSh() const { return sh; }
|
||||||
int getSw() const { return sw; }
|
int getSw() const { return sw; }
|
||||||
|
int getCeilMode() const { return ceilMode; }
|
||||||
|
|
||||||
auto getPadStrideDilation() const { return tuple(ph, pw, sh, sw, dh, dw); }
|
auto getPadStrideDilation() const { return tuple(ph, pw, sh, sw, dh, dw); }
|
||||||
auto getNCHWRS() const { return tuple(n, c, h, w, kh, kw); }
|
auto getNCHWRS() const { return tuple(n, c, h, w, kh, kw); }
|
||||||
|
@ -62,15 +67,15 @@ class PoolingObj : public OperatorObj {
|
||||||
class MaxPoolObj : public PoolingObj {
|
class MaxPoolObj : public PoolingObj {
|
||||||
public:
|
public:
|
||||||
MaxPoolObj(GraphObj *graph, Tensor input, Tensor output, int kh, int kw,
|
MaxPoolObj(GraphObj *graph, Tensor input, Tensor output, int kh, int kw,
|
||||||
int dh, int dw, int ph, int pw, int sh, int sw)
|
int dh, int dw, int ph, int pw, int sh, int sw, int ceilMode)
|
||||||
: PoolingObj(graph, OpType::MaxPool, input, output, kh, kw, dh, dw, ph,
|
: PoolingObj(graph, OpType::MaxPool, input, output, kh, kw, dh, dw, ph,
|
||||||
pw, sh, sw) {}
|
pw, sh, sw, ceilMode) {}
|
||||||
};
|
};
|
||||||
class AvgPoolObj : public PoolingObj {
|
class AvgPoolObj : public PoolingObj {
|
||||||
public:
|
public:
|
||||||
AvgPoolObj(GraphObj *graph, Tensor input, Tensor output, int kh, int kw,
|
AvgPoolObj(GraphObj *graph, Tensor input, Tensor output, int kh, int kw,
|
||||||
int dh, int dw, int ph, int pw, int sh, int sw)
|
int dh, int dw, int ph, int pw, int sh, int sw, int ceilMode)
|
||||||
: PoolingObj(graph, OpType::AveragePool, input, output, kh, kw, dh, dw,
|
: PoolingObj(graph, OpType::AveragePool, input, output, kh, kw, dh, dw,
|
||||||
ph, pw, sh, sw) {}
|
ph, pw, sh, sw, ceilMode) {}
|
||||||
};
|
};
|
||||||
}; // namespace infini
|
}; // namespace infini
|
||||||
|
|
|
@ -228,11 +228,12 @@ class OnnxStub:
|
||||||
"dilations": [1, 1],
|
"dilations": [1, 1],
|
||||||
"pads": [0, 0, 0, 0],
|
"pads": [0, 0, 0, 0],
|
||||||
"strides": [1, 1],
|
"strides": [1, 1],
|
||||||
|
"ceil_mode": 0,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
(k, d, p, s) = (
|
(k, d, p, s, ceil_mode) = (
|
||||||
attributes[name]
|
attributes[name]
|
||||||
for name in ["kernel_shape", "dilations", "pads", "strides"]
|
for name in ["kernel_shape", "dilations", "pads", "strides", "ceil_mode"]
|
||||||
)
|
)
|
||||||
if p[0] != p[2] or p[1] != p[3]:
|
if p[0] != p[2] or p[1] != p[3]:
|
||||||
adapt = "{}-adapt".format(node.output[0])
|
adapt = "{}-adapt".format(node.output[0])
|
||||||
|
@ -250,6 +251,7 @@ class OnnxStub:
|
||||||
0,
|
0,
|
||||||
s[0],
|
s[0],
|
||||||
s[1],
|
s[1],
|
||||||
|
ceil_mode,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
tensors[node.output[0]] = self.handler.maxPool(
|
tensors[node.output[0]] = self.handler.maxPool(
|
||||||
|
@ -263,6 +265,7 @@ class OnnxStub:
|
||||||
p[1],
|
p[1],
|
||||||
s[0],
|
s[0],
|
||||||
s[1],
|
s[1],
|
||||||
|
ceil_mode,
|
||||||
)
|
)
|
||||||
elif node.op_type == "AveragePool":
|
elif node.op_type == "AveragePool":
|
||||||
attributes = _parse_attribute(
|
attributes = _parse_attribute(
|
||||||
|
@ -271,10 +274,11 @@ class OnnxStub:
|
||||||
"kernel_shape": None,
|
"kernel_shape": None,
|
||||||
"pads": [0, 0, 0, 0],
|
"pads": [0, 0, 0, 0],
|
||||||
"strides": [1, 1],
|
"strides": [1, 1],
|
||||||
|
"ceil_mode": 0,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
(k, p, s) = (
|
(k, p, s, ceil_mode) = (
|
||||||
attributes[name] for name in ["kernel_shape", "pads", "strides"]
|
attributes[name] for name in ["kernel_shape", "pads", "strides", "ceil_mode"]
|
||||||
)
|
)
|
||||||
if p[0] != p[2] or p[1] != p[3]:
|
if p[0] != p[2] or p[1] != p[3]:
|
||||||
adapt = "{}-adapt".format(node.output[0])
|
adapt = "{}-adapt".format(node.output[0])
|
||||||
|
@ -292,6 +296,7 @@ class OnnxStub:
|
||||||
0,
|
0,
|
||||||
s[0],
|
s[0],
|
||||||
s[1],
|
s[1],
|
||||||
|
ceil_mode,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
tensors[node.output[0]] = self.handler.avgPool(
|
tensors[node.output[0]] = self.handler.avgPool(
|
||||||
|
@ -305,6 +310,7 @@ class OnnxStub:
|
||||||
p[1],
|
p[1],
|
||||||
s[0],
|
s[0],
|
||||||
s[1],
|
s[1],
|
||||||
|
ceil_mode,
|
||||||
)
|
)
|
||||||
elif node.op_type == "GlobalAveragePool":
|
elif node.op_type == "GlobalAveragePool":
|
||||||
[_, _, h, w] = _search_shape(model, node.input[0])
|
[_, _, h, w] = _search_shape(model, node.input[0])
|
||||||
|
@ -319,6 +325,7 @@ class OnnxStub:
|
||||||
0,
|
0,
|
||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
|
0,
|
||||||
)
|
)
|
||||||
elif node.op_type == "Add":
|
elif node.op_type == "Add":
|
||||||
tensors[node.output[0]] = self.handler.add(
|
tensors[node.output[0]] = self.handler.add(
|
||||||
|
@ -866,7 +873,7 @@ class OnnxStub:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif ty == backend.OpTypeId.MaxPool:
|
elif ty == backend.OpTypeId.MaxPool:
|
||||||
kh, kw, dh, dw, ph, pw, sh, sw = backend.pool_attrs_of(op)
|
kh, kw, dh, dw, ph, pw, sh, sw, ceil_mode = backend.pool_attrs_of(op)
|
||||||
ctx.push_node(
|
ctx.push_node(
|
||||||
make_node(
|
make_node(
|
||||||
ty.name,
|
ty.name,
|
||||||
|
@ -877,10 +884,11 @@ class OnnxStub:
|
||||||
pads=[ph, pw, ph, pw],
|
pads=[ph, pw, ph, pw],
|
||||||
dilations=[dh, dw],
|
dilations=[dh, dw],
|
||||||
strides=[sh, sw],
|
strides=[sh, sw],
|
||||||
|
ceil_mode=ceil_mode,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif ty == backend.OpTypeId.AveragePool:
|
elif ty == backend.OpTypeId.AveragePool:
|
||||||
kh, kw, dh, dw, ph, pw, sh, sw = backend.pool_attrs_of(op)
|
kh, kw, dh, dw, ph, pw, sh, sw, ceil_mode = backend.pool_attrs_of(op)
|
||||||
ctx.push_node(
|
ctx.push_node(
|
||||||
make_node(
|
make_node(
|
||||||
"AveragePool",
|
"AveragePool",
|
||||||
|
@ -890,6 +898,7 @@ class OnnxStub:
|
||||||
kernel_shape=[kh, kw],
|
kernel_shape=[kh, kw],
|
||||||
pads=[ph, pw, ph, pw],
|
pads=[ph, pw, ph, pw],
|
||||||
strides=[sh, sw],
|
strides=[sh, sw],
|
||||||
|
ceil_mode=ceil_mode,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif ty in [
|
elif ty in [
|
||||||
|
|
|
@ -210,10 +210,6 @@ void GraphObj::dataMalloc() {
|
||||||
tensorToOffset[tensor.get()]));
|
tensorToOffset[tensor.get()]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef DEBUG_MODE
|
|
||||||
allocator.info();
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor GraphObj::addTensor(Shape dim, DataType dtype) {
|
Tensor GraphObj::addTensor(Shape dim, DataType dtype) {
|
||||||
|
|
|
@ -95,30 +95,30 @@ Tensor GraphHandlerObj::batchNormalization(Tensor input, Tensor output,
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor GraphHandlerObj::maxPool(Tensor input, Tensor output, int kh, int kw,
|
Tensor GraphHandlerObj::maxPool(Tensor input, Tensor output, int kh, int kw,
|
||||||
int dh, int dw, int ph, int pw, int sh,
|
int dh, int dw, int ph, int pw, int sh, int sw,
|
||||||
int sw) {
|
int ceilMode) {
|
||||||
if (output) {
|
if (output) {
|
||||||
g->addOpWithOutputs<MaxPoolObj>(std::move(input), output, kh, kw, dh,
|
g->addOpWithOutputs<MaxPoolObj>(std::move(input), output, kh, kw, dh,
|
||||||
dw, ph, pw, sh, sw);
|
dw, ph, pw, sh, sw, ceilMode);
|
||||||
return output;
|
return output;
|
||||||
} else {
|
} else {
|
||||||
return g
|
return g
|
||||||
->addOp<MaxPoolObj>(std::move(input), output, kh, kw, dh, dw, ph,
|
->addOp<MaxPoolObj>(std::move(input), output, kh, kw, dh, dw, ph,
|
||||||
pw, sh, sw)
|
pw, sh, sw, ceilMode)
|
||||||
->getOutput();
|
->getOutput();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Tensor GraphHandlerObj::avgPool(Tensor input, Tensor output, int kh, int kw,
|
Tensor GraphHandlerObj::avgPool(Tensor input, Tensor output, int kh, int kw,
|
||||||
int dh, int dw, int ph, int pw, int sh,
|
int dh, int dw, int ph, int pw, int sh, int sw,
|
||||||
int sw) {
|
int ceilMode) {
|
||||||
if (output) {
|
if (output) {
|
||||||
g->addOpWithOutputs<AvgPoolObj>(std::move(input), output, kh, kw, dh,
|
g->addOpWithOutputs<AvgPoolObj>(std::move(input), output, kh, kw, dh,
|
||||||
dw, ph, pw, sh, sw);
|
dw, ph, pw, sh, sw, ceilMode);
|
||||||
return output;
|
return output;
|
||||||
} else {
|
} else {
|
||||||
return g
|
return g
|
||||||
->addOp<AvgPoolObj>(std::move(input), output, kh, kw, dh, dw, ph,
|
->addOp<AvgPoolObj>(std::move(input), output, kh, kw, dh, dw, ph,
|
||||||
pw, sh, sw)
|
pw, sh, sw, ceilMode)
|
||||||
->getOutput();
|
->getOutput();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -145,10 +145,10 @@ void LazyAllocator::free(size_t addr, size_t size) {
|
||||||
void *LazyAllocator::getPtr() {
|
void *LazyAllocator::getPtr() {
|
||||||
if (this->ptr == nullptr) {
|
if (this->ptr == nullptr) {
|
||||||
this->ptr = runtime->alloc(this->peak);
|
this->ptr = runtime->alloc(this->peak);
|
||||||
#ifdef DEBUG_MODE
|
// #ifdef DEBUG_MODE
|
||||||
printf("LazyAllocator really alloc non-weight: %p %lu bytes\n",
|
// printf("LazyAllocator really alloc non-weight: %p %lu
|
||||||
this->ptr, peak);
|
// bytes\n", this->ptr, peak);
|
||||||
#endif
|
// #endif
|
||||||
}
|
}
|
||||||
return this->ptr;
|
return this->ptr;
|
||||||
}
|
}
|
||||||
|
@ -156,10 +156,10 @@ void *LazyAllocator::getPtr() {
|
||||||
void *LazyAllocator::getWeightPtr() {
|
void *LazyAllocator::getWeightPtr() {
|
||||||
if (this->weightPtr == nullptr) {
|
if (this->weightPtr == nullptr) {
|
||||||
this->weightPtr = runtime->alloc(this->weightPeak);
|
this->weightPtr = runtime->alloc(this->weightPeak);
|
||||||
#ifdef DEBUG_MODE
|
// #ifdef DEBUG_MODE
|
||||||
printf("LazyAllocator really alloc weight: %p %lu bytes\n",
|
// printf("LazyAllocator really alloc weight: %p %lu bytes\n",
|
||||||
this->weightPtr, weightPeak);
|
// this->weightPtr, weightPeak);
|
||||||
#endif
|
// #endif
|
||||||
}
|
}
|
||||||
return this->weightPtr;
|
return this->weightPtr;
|
||||||
}
|
}
|
||||||
|
|
|
@ -187,14 +187,14 @@ static std::tuple<float, float, bool> batch_norm_attrs_of(Operator op) {
|
||||||
batchnorm->getTrainingMode());
|
batchnorm->getTrainingMode());
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::tuple<int, int, int, int, int, int, int, int>
|
static std::tuple<int, int, int, int, int, int, int, int, int>
|
||||||
pool_attrs_of(Operator op) {
|
pool_attrs_of(Operator op) {
|
||||||
IT_ASSERT(op->getOpType() == OpType::MaxPool ||
|
IT_ASSERT(op->getOpType() == OpType::MaxPool ||
|
||||||
op->getOpType() == OpType::AveragePool);
|
op->getOpType() == OpType::AveragePool);
|
||||||
auto pool = dynamic_cast<const PoolingObj *>(op.get());
|
auto pool = dynamic_cast<const PoolingObj *>(op.get());
|
||||||
return std::make_tuple(pool->getKh(), pool->getKw(), pool->getDh(),
|
return std::make_tuple(pool->getKh(), pool->getKw(), pool->getDh(),
|
||||||
pool->getDw(), pool->getPh(), pool->getPw(),
|
pool->getDw(), pool->getPh(), pool->getPw(),
|
||||||
pool->getSh(), pool->getSw());
|
pool->getSh(), pool->getSw(), pool->getCeilMode());
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::tuple<std::optional<float>, std::optional<float>>
|
static std::tuple<std::optional<float>, std::optional<float>>
|
||||||
|
|
|
@ -30,6 +30,7 @@ class PoolingCnnl : public BangKernelWithoutConfig {
|
||||||
ph, pw, pw, sh, sw, dh, dw, false));
|
ph, pw, pw, sh, sw, dh, dw, false));
|
||||||
|
|
||||||
// get outputs
|
// get outputs
|
||||||
|
// TODO: verify ceiling mode
|
||||||
auto outVec = op->getOutput()->getDims();
|
auto outVec = op->getOutput()->getDims();
|
||||||
int outArray[4] = {outVec[0], outVec[1], outVec[2], outVec[3]};
|
int outArray[4] = {outVec[0], outVec[1], outVec[2], outVec[3]};
|
||||||
cnnlTensorDescriptor_t outDesc;
|
cnnlTensorDescriptor_t outDesc;
|
||||||
|
|
|
@ -21,6 +21,7 @@ template <typename T> class NativePooling : public CpuKernelWithoutConfig {
|
||||||
auto inoffset = i * (c * ih * iw) + j * ih * iw;
|
auto inoffset = i * (c * ih * iw) + j * ih * iw;
|
||||||
for (auto h = 0; h < oh; h++) {
|
for (auto h = 0; h < oh; h++) {
|
||||||
for (auto w = 0; w < ow; w++) {
|
for (auto w = 0; w < ow; w++) {
|
||||||
|
// TODO: verify ceil mode
|
||||||
T val =
|
T val =
|
||||||
getPoolingValue(kh, kw, h * sh - ph, w * sw - pw,
|
getPoolingValue(kh, kw, h * sh - ph, w * sw - pw,
|
||||||
ih, iw, inptr + inoffset);
|
ih, iw, inptr + inoffset);
|
||||||
|
|
|
@ -29,17 +29,27 @@ class poolingCudnn : public CudaKernelWithoutConfig {
|
||||||
pw, sh, sw));
|
pw, sh, sw));
|
||||||
|
|
||||||
// get outputs
|
// get outputs
|
||||||
int outn, outc, outh, outw;
|
auto outDims = op->getOutput()->getDims();
|
||||||
checkCudnnError(cudnnGetPooling2dForwardOutputDim(
|
int outn = outDims[0], outc = outDims[1], outh = outDims[2],
|
||||||
poolingDesc, inDesc, &outn, &outc, &outh, &outw));
|
outw = outDims[3];
|
||||||
|
// NOTICE: cudnn pooling does not support ceil mode, so the shape
|
||||||
|
// inference of cudnn pooling is not consistant with our framework. Ceil
|
||||||
|
// mode is also supported in Pytorch and ONNX. See
|
||||||
|
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
|
||||||
|
// and https://github.com/onnx/onnx/blob/main/docs/Operators.md#MaxPool
|
||||||
|
// for reference.
|
||||||
|
// TODO: Make sure the result after considering ceil mode is correct.
|
||||||
|
// int outn, outc, outh, outw;
|
||||||
|
// checkCudnnError(cudnnGetPooling2dForwardOutputDim(poolingDesc,
|
||||||
|
// inDesc, &outn, &outc, &outh, &outw));
|
||||||
cudnnTensorDescriptor_t outDesc;
|
cudnnTensorDescriptor_t outDesc;
|
||||||
checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
|
checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
|
||||||
checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW,
|
checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW,
|
||||||
CUDNN_DATA_FLOAT, outn, outc,
|
CUDNN_DATA_FLOAT, outn, outc,
|
||||||
outh, outw));
|
outh, outw));
|
||||||
IT_ASSERT((vector{outn, outc, outh, outw}) ==
|
// IT_ASSERT((vector{outn, outc, outh, outw}) ==
|
||||||
op->getOutput()->getDims(),
|
// op->getOutput()->getDims(),
|
||||||
"cuDNN output shape mismatches with OP output shape");
|
// "cuDNN output shape mismatches with OP output shape");
|
||||||
|
|
||||||
float alpha = 1.f, beta = 0.f;
|
float alpha = 1.f, beta = 0.f;
|
||||||
checkCudnnError(cudnnPoolingForward(context->cudnnHandle(), poolingDesc,
|
checkCudnnError(cudnnPoolingForward(context->cudnnHandle(), poolingDesc,
|
||||||
|
|
|
@ -4,11 +4,9 @@ namespace infini {
|
||||||
|
|
||||||
PoolingObj::PoolingObj(GraphObj *graph, OpType optype, Tensor input,
|
PoolingObj::PoolingObj(GraphObj *graph, OpType optype, Tensor input,
|
||||||
Tensor output, int kh, int kw, int dh, int dw, int ph,
|
Tensor output, int kh, int kw, int dh, int dw, int ph,
|
||||||
int pw, int sh, int sw)
|
int pw, int sh, int sw, int ceilMode)
|
||||||
: OperatorObj(optype, {input}, {output}),
|
: OperatorObj(optype, {input}, {output}), kh(kh), kw(kw), dh(dh), dw(dw),
|
||||||
|
ph(ph), pw(pw), sh(sh), sw(sw), ceilMode(ceilMode),
|
||||||
kh(kh), kw(kw), dh(dh), dw(dw), ph(ph), pw(pw), sh(sh), sw(sw),
|
|
||||||
|
|
||||||
n(input->getDims()[0]), c(input->getDims()[1]), h(input->getDims()[2]),
|
n(input->getDims()[0]), c(input->getDims()[1]), h(input->getDims()[2]),
|
||||||
w(input->getDims()[3]) {
|
w(input->getDims()[3]) {
|
||||||
IT_ASSERT(checkValid(graph));
|
IT_ASSERT(checkValid(graph));
|
||||||
|
@ -18,8 +16,14 @@ optional<vector<Shape>> PoolingObj::inferShape(const TensorVec &inputs) const {
|
||||||
const auto &input = inputs[0];
|
const auto &input = inputs[0];
|
||||||
auto h = input->getDims()[input->getRank() - 2],
|
auto h = input->getDims()[input->getRank() - 2],
|
||||||
w = input->getDims()[input->getRank() - 1];
|
w = input->getDims()[input->getRank() - 1];
|
||||||
int oh = (h - (kh - sh) + ph * 2) / sh;
|
int oh, ow;
|
||||||
int ow = (w - (kw - sw) + pw * 2) / sw;
|
if (ceilMode) {
|
||||||
|
oh = ceil(((float)(h + 2 * ph - dh * (kh - 1) - 1)) / sh + 1);
|
||||||
|
ow = ceil(((float)(w + 2 * pw - dw * (kw - 1) - 1)) / sw + 1);
|
||||||
|
} else {
|
||||||
|
oh = floor(((float)(h + 2 * ph - dh * (kh - 1) - 1)) / sh + 1);
|
||||||
|
ow = floor(((float)(w + 2 * pw - dw * (kw - 1) - 1)) / sw + 1);
|
||||||
|
}
|
||||||
auto ret = input->getDims();
|
auto ret = input->getDims();
|
||||||
ret[input->getRank() - 2] = oh;
|
ret[input->getRank() - 2] = oh;
|
||||||
ret[input->getRank() - 1] = ow;
|
ret[input->getRank() - 1] = ow;
|
||||||
|
@ -34,17 +38,19 @@ std::string PoolingObj::toString() const {
|
||||||
os << "p=[" << ph << "," << pw << "],";
|
os << "p=[" << ph << "," << pw << "],";
|
||||||
os << "s=[" << sh << "," << sw << "],";
|
os << "s=[" << sh << "," << sw << "],";
|
||||||
os << "d=[" << dh << "," << dw << "],";
|
os << "d=[" << dh << "," << dw << "],";
|
||||||
|
os << "ceil mode=" << ceilMode << ",";
|
||||||
os << "input=" << inputs[0]->getGuid() << ",";
|
os << "input=" << inputs[0]->getGuid() << ",";
|
||||||
os << "output=" << outputs[0]->getGuid() << ")";
|
os << "output=" << outputs[0]->getGuid() << ")";
|
||||||
return os.str();
|
return os.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
vector<int> PoolingObj::getWorkloadVector() const {
|
vector<int> PoolingObj::getWorkloadVector() const {
|
||||||
return {type.underlying(), n, c, h, w, kh, kw, ph, pw, sh, sw, dh, dw};
|
return {type.underlying(), n, c, h, w, kh, kw, ph, pw, sh, sw, dh, dw,
|
||||||
|
ceilMode};
|
||||||
}
|
}
|
||||||
|
|
||||||
vector<int> PoolingObj::getOpAttrVector() const {
|
vector<int> PoolingObj::getOpAttrVector() const {
|
||||||
return {type.underlying(), kh, kw, ph, pw, sh, sw, dh, dw};
|
return {type.underlying(), kh, kw, ph, pw, sh, sw, dh, dw, ceilMode};
|
||||||
}
|
}
|
||||||
|
|
||||||
}; // namespace infini
|
}; // namespace infini
|
||||||
|
|
|
@ -208,16 +208,13 @@ TEST(MatchGraph, multi_output) {
|
||||||
SubGraph subg0 = make_ref<SubGraphObj>(runtime, TensorVec{i});
|
SubGraph subg0 = make_ref<SubGraphObj>(runtime, TensorVec{i});
|
||||||
{
|
{
|
||||||
auto maxpool =
|
auto maxpool =
|
||||||
subg0->addOp<MaxPoolObj>(i, nullptr, 3, 3, 0, 0, 0, 0, 2, 2);
|
subg0->addOp<MaxPoolObj>(i, nullptr, 3, 3, 1, 1, 0, 0, 2, 2, 0);
|
||||||
Tensor w0 = subg0->addTensor(Shape{64, 192, 1, 1}, DataType::UInt32);
|
Tensor w0 = subg0->addTensor(Shape{64, 192, 1, 1}, DataType::UInt32);
|
||||||
auto conv0 = subg0->addOp<ConvObj>(maxpool->getOutput(0), w0, nullptr);
|
auto conv0 = subg0->addOp<ConvObj>(maxpool->getOutput(0), w0, nullptr);
|
||||||
auto relu0 = subg0->addOp<ReluObj>(conv0->getOutput(0), nullptr);
|
auto relu0 = subg0->addOp<ReluObj>(conv0->getOutput(0), nullptr);
|
||||||
|
|
||||||
auto pad = subg0->addOp<PadObj>(maxpool->getOutput(0), nullptr,
|
auto avgpool = subg0->addOp<AvgPoolObj>(maxpool->getOutput(0), nullptr,
|
||||||
vector<int>{0, 0, 1, 1, 0, 0, 1, 1},
|
3, 3, 0, 0, 0, 0, 1, 1, 0);
|
||||||
std::nullopt);
|
|
||||||
auto avgpool = subg0->addOp<AvgPoolObj>(pad->getOutput(0), nullptr, 3,
|
|
||||||
3, 0, 0, 0, 0, 1, 1);
|
|
||||||
subg0->setOutputs(
|
subg0->setOutputs(
|
||||||
TensorVec{relu0->getOutput(0), avgpool->getOutput(0)});
|
TensorVec{relu0->getOutput(0), avgpool->getOutput(0)});
|
||||||
}
|
}
|
||||||
|
@ -225,8 +222,9 @@ TEST(MatchGraph, multi_output) {
|
||||||
SubGraph subg1 =
|
SubGraph subg1 =
|
||||||
make_ref<SubGraphObj>(runtime, TensorVec{i->clone(runtime)});
|
make_ref<SubGraphObj>(runtime, TensorVec{i->clone(runtime)});
|
||||||
{
|
{
|
||||||
auto avgpool = subg1->addOp<AvgPoolObj>(
|
auto avgpool =
|
||||||
subg1->getInputsFromOutside()[0], nullptr, 3, 3, 0, 0, 0, 0, 2, 2);
|
subg1->addOp<AvgPoolObj>(subg1->getInputsFromOutside()[0], nullptr,
|
||||||
|
3, 3, 1, 1, 0, 0, 2, 2, 0);
|
||||||
|
|
||||||
auto relu0 = subg1->addOp<ReluObj>(avgpool->getOutput(0), nullptr);
|
auto relu0 = subg1->addOp<ReluObj>(avgpool->getOutput(0), nullptr);
|
||||||
|
|
||||||
|
@ -295,7 +293,7 @@ TEST(MatchGraph, multi_input_output) {
|
||||||
Tensor w2 = subg0->addTensor(Shape{128, 256, 1, 1}, DataType::UInt32);
|
Tensor w2 = subg0->addTensor(Shape{128, 256, 1, 1}, DataType::UInt32);
|
||||||
auto conv2 = subg0->addOp<ConvObj>(relu1->getOutput(0), w2, nullptr);
|
auto conv2 = subg0->addOp<ConvObj>(relu1->getOutput(0), w2, nullptr);
|
||||||
auto maxpool = subg0->addOp<MaxPoolObj>(relu1->getOutput(0), nullptr, 3,
|
auto maxpool = subg0->addOp<MaxPoolObj>(relu1->getOutput(0), nullptr, 3,
|
||||||
3, 0, 0, 0, 0, 2, 2);
|
3, 1, 1, 0, 0, 2, 2, 0);
|
||||||
subg0->setOutputs(
|
subg0->setOutputs(
|
||||||
TensorVec{conv2->getOutput(0), maxpool->getOutput(0)});
|
TensorVec{conv2->getOutput(0), maxpool->getOutput(0)});
|
||||||
}
|
}
|
||||||
|
@ -317,7 +315,7 @@ TEST(MatchGraph, multi_input_output) {
|
||||||
Tensor w2 = subg1->addTensor(Shape{128, 256, 1, 1}, DataType::UInt32);
|
Tensor w2 = subg1->addTensor(Shape{128, 256, 1, 1}, DataType::UInt32);
|
||||||
auto conv2 = subg1->addOp<ConvObj>(relu1->getOutput(0), w2, nullptr);
|
auto conv2 = subg1->addOp<ConvObj>(relu1->getOutput(0), w2, nullptr);
|
||||||
auto maxpool = subg1->addOp<MaxPoolObj>(relu1->getOutput(0), nullptr, 3,
|
auto maxpool = subg1->addOp<MaxPoolObj>(relu1->getOutput(0), nullptr, 3,
|
||||||
3, 0, 0, 0, 0, 2, 2);
|
3, 1, 1, 0, 0, 2, 2, 0);
|
||||||
subg1->setOutputs(
|
subg1->setOutputs(
|
||||||
TensorVec{maxpool->getOutput(0), conv2->getOutput(0)});
|
TensorVec{maxpool->getOutput(0), conv2->getOutput(0)});
|
||||||
}
|
}
|
||||||
|
@ -338,7 +336,7 @@ TEST(MatchGraph, multi_input_output) {
|
||||||
Tensor w2 = subg2->addTensor(Shape{128, 256, 1, 1}, DataType::UInt32);
|
Tensor w2 = subg2->addTensor(Shape{128, 256, 1, 1}, DataType::UInt32);
|
||||||
auto conv2 = subg2->addOp<ConvObj>(relu1->getOutput(0), w2, nullptr);
|
auto conv2 = subg2->addOp<ConvObj>(relu1->getOutput(0), w2, nullptr);
|
||||||
auto avgpool = subg2->addOp<AvgPoolObj>(relu1->getOutput(0), nullptr, 3,
|
auto avgpool = subg2->addOp<AvgPoolObj>(relu1->getOutput(0), nullptr, 3,
|
||||||
3, 0, 0, 0, 0, 2, 2);
|
3, 1, 1, 0, 0, 2, 2, 0);
|
||||||
subg2->setOutputs(
|
subg2->setOutputs(
|
||||||
TensorVec{conv2->getOutput(0), avgpool->getOutput(0)});
|
TensorVec{conv2->getOutput(0), avgpool->getOutput(0)});
|
||||||
}
|
}
|
||||||
|
@ -349,7 +347,7 @@ TEST(MatchGraph, multi_input_output) {
|
||||||
auto i = g->addTensor(Shape{1, 64, 112, 112}, DataType::UInt32);
|
auto i = g->addTensor(Shape{1, 64, 112, 112}, DataType::UInt32);
|
||||||
auto relu = g->addOp<ReluObj>(i, nullptr);
|
auto relu = g->addOp<ReluObj>(i, nullptr);
|
||||||
auto maxPool = g->addOp<MaxPoolObj>(relu->getOutput(0), nullptr, 3, 3,
|
auto maxPool = g->addOp<MaxPoolObj>(relu->getOutput(0), nullptr, 3, 3,
|
||||||
0, 0, 1, 1, 2, 2);
|
1, 1, 1, 1, 2, 2, 0);
|
||||||
auto out0 =
|
auto out0 =
|
||||||
v.addSubGraph(subg0, {relu->getOutput(0), maxPool->getOutput(0)});
|
v.addSubGraph(subg0, {relu->getOutput(0), maxPool->getOutput(0)});
|
||||||
auto out1 =
|
auto out1 =
|
||||||
|
|
|
@ -8,7 +8,8 @@
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
template <class T>
|
template <class T, typename std::enable_if<std::is_base_of<PoolingObj, T>{},
|
||||||
|
int>::type = 0>
|
||||||
void testPooling(const std::function<void(void *, size_t, DataType)> &generator,
|
void testPooling(const std::function<void(void *, size_t, DataType)> &generator,
|
||||||
const Shape &shape) {
|
const Shape &shape) {
|
||||||
// Runtime
|
// Runtime
|
||||||
|
@ -23,7 +24,8 @@ void testPooling(const std::function<void(void *, size_t, DataType)> &generator,
|
||||||
// GPU
|
// GPU
|
||||||
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
|
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
|
||||||
auto inputGpu = bangGraph->cloneTensor(inputCpu);
|
auto inputGpu = bangGraph->cloneTensor(inputCpu);
|
||||||
auto gpuOp = bangGraph->addOp<T>(inputGpu, nullptr, 3, 3, 1, 1, 1, 1, 2, 2);
|
auto gpuOp =
|
||||||
|
bangGraph->addOp<T>(inputGpu, nullptr, 3, 3, 1, 1, 1, 1, 2, 2, 0);
|
||||||
bangGraph->dataMalloc();
|
bangGraph->dataMalloc();
|
||||||
bangRuntime->run(bangGraph);
|
bangRuntime->run(bangGraph);
|
||||||
auto outputGpu = gpuOp->getOutput();
|
auto outputGpu = gpuOp->getOutput();
|
||||||
|
|
|
@ -29,7 +29,7 @@ TEST(CUDA_Inception_v3_block, run) {
|
||||||
TensorVec outputs;
|
TensorVec outputs;
|
||||||
vector<OpVec> ops;
|
vector<OpVec> ops;
|
||||||
auto maxpool =
|
auto maxpool =
|
||||||
g->addOp<MaxPoolObj>(blockInput, nullptr, 3, 3, 1, 1, 1, 1, 1, 1);
|
g->addOp<MaxPoolObj>(blockInput, nullptr, 3, 3, 1, 1, 1, 1, 1, 1, 0);
|
||||||
auto chainInput = maxpool->getOutput();
|
auto chainInput = maxpool->getOutput();
|
||||||
for (auto &pathConfig : configs) {
|
for (auto &pathConfig : configs) {
|
||||||
int inputChannels = initialChannels;
|
int inputChannels = initialChannels;
|
||||||
|
@ -52,7 +52,7 @@ TEST(CUDA_Inception_v3_block, run) {
|
||||||
inputChannels = f;
|
inputChannels = f;
|
||||||
} else { // Add AveragePool
|
} else { // Add AveragePool
|
||||||
auto pool = g->addOp<AvgPoolObj>(input, nullptr, r, r, 1, 1,
|
auto pool = g->addOp<AvgPoolObj>(input, nullptr, r, r, 1, 1,
|
||||||
r / 2, r / 2, 1, 1);
|
r / 2, r / 2, 1, 1, 0);
|
||||||
input = pool->getOutput();
|
input = pool->getOutput();
|
||||||
ops.back().emplace_back(pool);
|
ops.back().emplace_back(pool);
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,8 @@ namespace infini {
|
||||||
using KDPS = vector<int>;
|
using KDPS = vector<int>;
|
||||||
using ExpectOutput = vector<float>;
|
using ExpectOutput = vector<float>;
|
||||||
|
|
||||||
template <class T>
|
template <class T, typename std::enable_if<std::is_base_of<PoolingObj, T>{},
|
||||||
|
int>::type = 0>
|
||||||
void testPoolCudnn(
|
void testPoolCudnn(
|
||||||
const std::function<void(void *, size_t, DataType)> &generator,
|
const std::function<void(void *, size_t, DataType)> &generator,
|
||||||
const Shape &shape, const KDPS &kdps, const ExpectOutput &ansVec) {
|
const Shape &shape, const KDPS &kdps, const ExpectOutput &ansVec) {
|
||||||
|
@ -24,7 +25,7 @@ void testPoolCudnn(
|
||||||
Graph g = make_ref<GraphObj>(cudaRuntime);
|
Graph g = make_ref<GraphObj>(cudaRuntime);
|
||||||
auto i0 = g->cloneTensor(i0cpu);
|
auto i0 = g->cloneTensor(i0cpu);
|
||||||
auto pool = g->addOp<T>(i0, nullptr, kdps[0], kdps[1], kdps[2], kdps[3],
|
auto pool = g->addOp<T>(i0, nullptr, kdps[0], kdps[1], kdps[2], kdps[3],
|
||||||
kdps[4], kdps[5], kdps[6], kdps[7]);
|
kdps[4], kdps[5], kdps[6], kdps[7], 0);
|
||||||
|
|
||||||
// allocate CUDA memory
|
// allocate CUDA memory
|
||||||
g->dataMalloc();
|
g->dataMalloc();
|
||||||
|
|
|
@ -12,16 +12,16 @@ TEST(MaxPool, ShapeInference) {
|
||||||
Graph g = make_ref<GraphObj>(cpuRuntime);
|
Graph g = make_ref<GraphObj>(cpuRuntime);
|
||||||
Tensor i = g->addTensor({1, 64, 162, 162}, DataType::UInt32);
|
Tensor i = g->addTensor({1, 64, 162, 162}, DataType::UInt32);
|
||||||
const int kh = 3, kw = 3, dh = 1, dw = 1, ph = 0, pw = 0, sh = 2,
|
const int kh = 3, kw = 3, dh = 1, dw = 1, ph = 0, pw = 0, sh = 2,
|
||||||
sw = 2;
|
sw = 2, ceilMode = 0;
|
||||||
auto op =
|
auto op = g->addOp<MaxPoolObj>(i, nullptr, kh, kw, dh, dw, ph, pw, sh,
|
||||||
g->addOp<MaxPoolObj>(i, nullptr, kh, kw, dh, dw, ph, pw, sh, sw);
|
sw, ceilMode);
|
||||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 64, 80, 80}));
|
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 64, 80, 80}));
|
||||||
}
|
}
|
||||||
|
|
||||||
{ // dilation & stride
|
{ // dilation & stride
|
||||||
Graph g = make_ref<GraphObj>(cpuRuntime);
|
Graph g = make_ref<GraphObj>(cpuRuntime);
|
||||||
Tensor i = g->addTensor({1, 64, 162, 162}, DataType::UInt32);
|
Tensor i = g->addTensor({1, 64, 162, 162}, DataType::UInt32);
|
||||||
auto op = g->addOp<MaxPoolObj>(i, nullptr, 4, 3, 1, 1, 2, 1, 1, 2);
|
auto op = g->addOp<MaxPoolObj>(i, nullptr, 4, 3, 1, 1, 2, 1, 1, 2, 0);
|
||||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 64, 163, 81}));
|
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 64, 163, 81}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -30,7 +30,7 @@ TEST(MaxPool, NaiveCPU) {
|
||||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||||
Graph g = make_ref<GraphObj>(cpuRuntime);
|
Graph g = make_ref<GraphObj>(cpuRuntime);
|
||||||
Tensor i = g->addTensor({1, 2, 5, 5}, DataType::UInt32);
|
Tensor i = g->addTensor({1, 2, 5, 5}, DataType::UInt32);
|
||||||
auto op = g->addOp<MaxPoolObj>(i, nullptr, 3, 3, 1, 1, 1, 1, 2, 2);
|
auto op = g->addOp<MaxPoolObj>(i, nullptr, 3, 3, 1, 1, 1, 1, 2, 2, 0);
|
||||||
|
|
||||||
g->dataMalloc();
|
g->dataMalloc();
|
||||||
i->setData(IncrementalGenerator());
|
i->setData(IncrementalGenerator());
|
||||||
|
@ -49,7 +49,7 @@ TEST(AvgPool, NaiveCPU) {
|
||||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||||
Graph g = make_ref<GraphObj>(cpuRuntime);
|
Graph g = make_ref<GraphObj>(cpuRuntime);
|
||||||
Tensor i = g->addTensor({1, 2, 5, 5}, DataType::Float32);
|
Tensor i = g->addTensor({1, 2, 5, 5}, DataType::Float32);
|
||||||
auto op = g->addOp<AvgPoolObj>(i, nullptr, 3, 3, 1, 1, 1, 1, 2, 2);
|
auto op = g->addOp<AvgPoolObj>(i, nullptr, 3, 3, 1, 1, 1, 1, 2, 2, 0);
|
||||||
|
|
||||||
g->dataMalloc();
|
g->dataMalloc();
|
||||||
i->setData(IncrementalGenerator());
|
i->setData(IncrementalGenerator());
|
||||||
|
|
Loading…
Reference in New Issue