From 9303ddda8eac88962edc3817c740e7327cbe2a92 Mon Sep 17 00:00:00 2001 From: zhengly123 Date: Wed, 17 Aug 2022 14:16:01 +0800 Subject: [PATCH] Add Conv operator and naive CPU implementation (#5) * Add: Conv definition * Add: tensor copy data from vector * Add: CPU conv kernel * Fix: replace Int32 with UInt32 in DataType Co-authored-by: Liyan Zheng --- CMakeLists.txt | 1 + include/core/common.h | 12 +++++ include/core/graph.h | 2 +- include/core/tensor.h | 5 ++ include/core/tensor_base.h | 2 +- include/operators/conv.h | 68 ++++++++++++++++++++++++ include/test.h | 32 ++++++++++++ src/core/run_engine.cc | 4 +- src/core/tensor.cc | 4 ++ src/kerels/cpu/conv.cc | 60 +++++++++++++++++++++ src/kerels/cpu/matmul.cc | 2 +- src/operators/conv.cc | 101 ++++++++++++++++++++++++++++++++++++ test/core/test_graph.cc | 26 +++++----- test/core/test_hash.cc | 10 ++-- test/operators/test_conv.cc | 63 ++++++++++++++++++++++ 15 files changed, 369 insertions(+), 23 deletions(-) create mode 100644 include/operators/conv.h create mode 100644 src/kerels/cpu/conv.cc create mode 100644 src/operators/conv.cc create mode 100644 test/operators/test_conv.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index f1079f65..70935137 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -81,6 +81,7 @@ if(BUILD_TEST) enable_testing() if(BUILD_TEST_CORE) build_test(test/core/*.cc) + build_test(test/operators/*.cc) endif() if(BUILD_TEST_PET) build_test(test/pet/*.cc) diff --git a/include/core/common.h b/include/core/common.h index 46eb6922..54ad8065 100644 --- a/include/core/common.h +++ b/include/core/common.h @@ -57,6 +57,18 @@ template auto enum_to_underlying(T e) { return static_cast>(e); } +template std::string vecToString(const std::vector &vec) { + std::string ret; + ret.append("["); + for (auto d : vec) { + ret.append(std::to_string(d)); + ret.append(", "); + } + ret.pop_back(); + ret.append("]"); + return ret; +} + double timeit(const std::function &func); } // namespace infini diff --git 
a/include/core/graph.h b/include/core/graph.h index 8cb8cea1..7abe5f0e 100644 --- a/include/core/graph.h +++ b/include/core/graph.h @@ -16,7 +16,7 @@ class GraphObj : public Object { // Graph(OpVec oplist); string toString() const override; - Tensor addTensor(Shape dim, DataType dtype = DataType::Int32); + Tensor addTensor(Shape dim, DataType dtype = DataType::UInt32); /** * @brief Add an operator and create its outputs. Output tensor arguments diff --git a/include/core/tensor.h b/include/core/tensor.h index e4bbc6c2..dcdd4aef 100644 --- a/include/core/tensor.h +++ b/include/core/tensor.h @@ -24,8 +24,13 @@ class TensorObj : public TensorBaseObj { using TensorBaseObj::getData; VType getData(const Shape &pos) const; void copyData(VType *dptr); + void copyData(vector dataVector); void printData() const; bool equalData(const Tensor &rhs) const; + void + setData(std::function generator) const { + generator((void *)(data.get()), size(), dtype); + } // void setDims(const Dim &dms) { dims = dms; } // bool dataRand(int seed = 0) { diff --git a/include/core/tensor_base.h b/include/core/tensor_base.h index eefd300f..d746f273 100644 --- a/include/core/tensor_base.h +++ b/include/core/tensor_base.h @@ -22,7 +22,7 @@ using VType = uint32_t; enum class DataType { Float32, - Int32, + UInt32, }; class TensorBaseObj : public Object { diff --git a/include/operators/conv.h b/include/operators/conv.h new file mode 100644 index 00000000..a2b80815 --- /dev/null +++ b/include/operators/conv.h @@ -0,0 +1,68 @@ +#pragma once +#include "core/operator.h" + +namespace infini { + +class ConvObj : public OperatorObj { + public: + // When PaddingMode is Other, ConvObj will use padding size (ph, pw) + // Otherwise, padding size (ph, pw) will be computed by padding mode + enum class PaddingMode { + Other, + Same, + Valid, + }; + + private: + int ph, pw; + int sh, sw; + int dh, dw; + ActType act; + PaddingMode padding; + // auxiliary attributes + int n, c, h, w, f, r, s; + + public: + // 
Constructors for explicitly setting padding size + ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, int ph, + int pw, int sh = 1, int sw = 1, int dh = 1, int dw = 1, + Tensor bias = nullptr, ActType act = ActType::None); + // Constructors for setting padding mode + ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, + PaddingMode mode = PaddingMode::Same, int sh = 1, int sw = 1, + int dh = 1, int dw = 1, Tensor bias = nullptr, + ActType act = ActType::None); + + optional> inferShape(const TensorVec &inputs) const override; + + std::string toString() const override; + int numInputs() const override { return 3; } + int numOutputs() const override { return 1; } + + Tensor getBias() const { return inputs[2]; } + ActType getAct() const { return act; } + PaddingMode getPaddingMode() const { return padding; } + pair inferPaddingSize() const; + + int getDh() const { return dh; } + int getDw() const { return dw; } + int getPh() const { return ph; } + int getPw() const { return pw; } + int getSh() const { return sh; } + int getSw() const { return sw; } + auto getNCHWFRS() const { return tuple(n, c, h, w, f, r, s); } + auto getPadStrideDilation() const { return tuple(ph, pw, sh, sw, dh, dw); } + int getChannelPerGroup() const { return inputs[1]->getDims()[1]; } + int getNumGroups() const { return c / getChannelPerGroup(); } + + private: + vector getWorkloadVector() const override; + vector getOpAttrVector() const override; + /** + * @brief Set the Auxiliary Attributes: nchwfrs and padding (ph, pw) if + * padding mode is set. This function should be called in the constructor. 
+ */ + void setAuxilaryAttributes(PaddingMode mode); +}; + +} // namespace infini diff --git a/include/test.h b/include/test.h index 05bcdcb1..ad8fc8a5 100644 --- a/include/test.h +++ b/include/test.h @@ -1,3 +1,35 @@ #pragma once #include "core/common.h" +#include "core/tensor_base.h" #include "gtest/gtest.h" + +namespace infini { + +class DataGenerator { + private: + virtual void fill(uint32_t *data, size_t size) { IT_TODO_HALT(); }; + virtual void fill(float *data, size_t size) { IT_TODO_HALT(); }; + + public: + void operator()(void *data, size_t size, DataType dataType) { + switch (dataType) { + case DataType::UInt32: + fill(reinterpret_cast(data), size); + break; + case DataType::Float32: + fill(reinterpret_cast(data), size); + break; + default: + IT_TODO_HALT(); + } + } +}; + +class IncrementalGenerator : public DataGenerator { + void fill(uint32_t *data, size_t size) override { + for (size_t i = 0; i < size; i++) { + data[i] = i; + } + } +}; +} // namespace infini \ No newline at end of file diff --git a/src/core/run_engine.cc b/src/core/run_engine.cc index ba6878bc..eace9218 100644 --- a/src/core/run_engine.cc +++ b/src/core/run_engine.cc @@ -17,7 +17,7 @@ void RunEngine::run(const Graph &graph, bool tune, bool profiling) const { for (auto &op : graph->getOperators()) { // HACK: set correct data type auto kernelAttrs = - KernelAttrs{device, op->getOpType(), DataType::Int32}; + KernelAttrs{device, op->getOpType(), DataType::UInt32}; Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; std::optional perfData = perfEngine.getPerfData(perfKey); @@ -64,7 +64,7 @@ double RunEngine::getPerfTime(const Graph &graph, bool profiling) const { for (auto &op : graph->getOperators()) { // HACK: set correct data type auto kernelAttrs = - KernelAttrs{device, op->getOpType(), DataType::Int32}; + KernelAttrs{device, op->getOpType(), DataType::UInt32}; Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); 
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; std::optional perfData = perfEngine.getPerfData(perfKey); diff --git a/src/core/tensor.cc b/src/core/tensor.cc index 41aa0aac..c6aac89d 100644 --- a/src/core/tensor.cc +++ b/src/core/tensor.cc @@ -45,6 +45,10 @@ void TensorObj::copyData(VType *dptr) { data[i] = dptr[i]; } } +void TensorObj::copyData(vector dataVector) { + IT_ASSERT(dataVector.size() >= size()); + copyData(dataVector.data()); +} void TensorObj::printData() const { IT_ASSERT(data != nullptr); diff --git a/src/kerels/cpu/conv.cc b/src/kerels/cpu/conv.cc new file mode 100644 index 00000000..f68479d4 --- /dev/null +++ b/src/kerels/cpu/conv.cc @@ -0,0 +1,60 @@ +#include "operators/conv.h" +#include "core/kernel.h" + +namespace infini { + +template class NaiveConv : public Kernel { + void compute(const Operator &_op, const PerfRecord &record) const override { + auto op = as(_op); + T *iptr = reinterpret_cast(op->getInputs(0)->getDataPtr().get()); + T *wptr = reinterpret_cast(op->getInputs(1)->getDataPtr().get()); + T *optr = reinterpret_cast(op->getOutput()->getDataPtr().get()); + auto [n, c, h, w, f, r, s] = op->getNCHWFRS(); + auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); + int cpg = op->getChannelPerGroup(); + int g = op->getNumGroups(); + IT_ASSERT(f % g == 0, "Illegal number of channel"); + auto outDim = op->getOutput()->getDims(); + int oh = outDim[2], ow = outDim[3]; + for (int nn = 0; nn < n; nn++) { +#pragma omp parallel for + for (int ff = 0; ff < f; ff++) { + for (int hh = 0; hh < oh; hh++) + for (int ww = 0; ww < ow; ww++) { + int gidx = ff / (f / g); /* group index of output channel ff, with f / g output channels per group */ + VType val = 0; /* NOTE(review): the accumulator is VType (uint32_t per tensor_base.h) although the data pointers are T*; for the Float32 registration of this kernel each product would be converted to an unsigned integer when added -- presumably this should be T, confirm */ + for (int cc = 0; cc < cpg; cc++) + for (int rr = 0; rr < r; rr++) + for (int ss = 0; ss < s; ss++) { + // clang-format off + int posH = hh * sh + rr * dh - ph; + int posW = ww * sw + ss * dw - pw; + if (posH < 0 || posH >= h || posW < 0 || posW >= w) + continue; + auto iOffset = posW + w * (posH + h * ((cc + gidx * cpg) + c * nn)), + wOffset = ss + s 
* (rr + r * (cc + cpg * ff)); + auto inputVal = iptr[iOffset], weightVal = wptr[wOffset]; + val += weightVal * inputVal; + // clang-format on + } + // TODO: check correctness, oh & ow or h & w? + auto oOffset = ww + ow * (hh + oh * (ff + f * nn)); + optr[oOffset] = val; + } + } + } + } + + void compute(const Operator &op) const override { compute(op, {}); } + + PerfRecord tune(const Operator &op) const override { + return PerfRecord{.time = timeit([this, &op]() { compute(op); })}; + } +}; + +REGISTER_KERNEL(Device::CPU, OpType::Conv, DataType::UInt32, + NaiveConv, "ConvNaive_CPU_uint32"); +REGISTER_KERNEL(Device::CPU, OpType::Conv, DataType::Float32, NaiveConv, + "ConvNaive_CPU_float32"); + +} // namespace infini \ No newline at end of file diff --git a/src/kerels/cpu/matmul.cc b/src/kerels/cpu/matmul.cc index 527fd66c..a1e83103 100644 --- a/src/kerels/cpu/matmul.cc +++ b/src/kerels/cpu/matmul.cc @@ -30,7 +30,7 @@ template class NaiveMatmul : public Kernel { } }; -REGISTER_KERNEL(Device::CPU, OpType::Matmul, DataType::Int32, +REGISTER_KERNEL(Device::CPU, OpType::Matmul, DataType::UInt32, NaiveMatmul, "MatmulNaive_CPU_uint32"); REGISTER_KERNEL(Device::CPU, OpType::Matmul, DataType::Float32, NaiveMatmul, "MatmulNaive_CPU_float32"); diff --git a/src/operators/conv.cc b/src/operators/conv.cc new file mode 100644 index 00000000..e2f9dc27 --- /dev/null +++ b/src/operators/conv.cc @@ -0,0 +1,101 @@ +#include "operators/conv.h" + +namespace infini { + +ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, + int ph, int pw, int sh, int sw, int dh, int dw, Tensor bias, + ActType act) + : OperatorObj(OpType::Conv, {input, weight, bias}, {output}), ph(ph), + pw(pw), sh(sh), sw(sw), dh(dh), dw(dw), act(act), + padding(PaddingMode::Other) { + setAuxilaryAttributes(PaddingMode::Other); + IT_ASSERT(checkValid(graph)); +} + +ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, + PaddingMode mode, int sh, int sw, int dh, int dw, 
Tensor bias, + ActType act) + : OperatorObj(OpType::Conv, {input, weight, bias}, {output}), ph(-1), + pw(-1), sh(sh), sw(sw), dh(dh), dw(dw), act(act), padding(mode) { + IT_ASSERT(mode != PaddingMode::Other); + setAuxilaryAttributes(mode); + IT_ASSERT(checkValid(graph)); +} + +string ConvObj::toString() const { + std::ostringstream os; + os << "Conv[" << getGuid() << "]"; + os << "("; + if (inputs.size() == 2) { /* NOTE(review): both constructors build the operator with {input, weight, bias} and numInputs() returns 3, so unless OperatorObj drops entries this branch never runs and the dims are never printed -- confirm */ + os << vecToString(inputs[0]->getDims()) << ","; + os << vecToString(inputs[1]->getDims()) << ","; + } + os << "p=[" << ph << "," << pw << "],"; + os << "s=[" << sh << "," << sw << "],"; + os << "d=[" << dh << "," << dw << "],"; + os << "act=" << enum_to_underlying(act) << ","; + os << "input=" << inputs[0]->getGuid() << ","; + os << "weight=" << inputs[1]->getGuid() << ","; + os << "output=" << outputs[0]->getGuid() << ")"; + return os.str(); +} + +optional> ConvObj::inferShape(const TensorVec &inputs) const { + const auto &input = inputs[0], &weight = inputs[1]; + auto n = input->getDims()[0]; + auto h = input->getDims()[2]; + auto w = input->getDims()[3]; + auto f = weight->getDims()[0]; + auto r = weight->getDims()[2]; + auto s = weight->getDims()[3]; + int on = n, oc = f; + int oh = 0, ow = 0; + // For NCHW+FCRS layout, C of input must be divisible by C of weight + if (input->getDims()[1] % weight->getDims()[1] != 0) + return {}; + // Compute the output height and width according to the padding mode + if (padding == PaddingMode::Other) { + /* NOTE(review): this expression is guaranteed to agree with the usual floor((h + 2*ph - dh*(r-1) - 1) / sh) + 1 only when sh == 1 or dh == 1; with h=6, r=3, sh=2, dh=2, ph=0 it gives 2 where the usual formula gives 1 -- confirm intended */ + oh = (h - (r - sh) * dh + ph * 2) / sh; + ow = (w - (s - sw) * dw + pw * 2) / sw; + } else if (padding == PaddingMode::Same) { + oh = h / sh; + ow = w / sw; + // ph = (h - oh * sh + (r - sh) * dh) / 2; + // pw = (w - ow * sw + (s - sw) * dw) / 2; + } else if (padding == PaddingMode::Valid) { + /* these locals shadow the ph and pw members, so the expression used for PaddingMode::Other is evaluated with zero padding */ + int ph = 0; + int pw = 0; + oh = (h - (r - sh) * dh + ph * 2) / sh; + ow = (w - (s - sw) * dw + pw * 2) / sw; + } + return {{{on, oc, oh, ow}}}; +} + +vector ConvObj::getWorkloadVector() const { + return { + enum_to_underlying(type), n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, 
enum_to_underlying(act)}; +} + +vector ConvObj::getOpAttrVector() const { + IT_TODO_HALT(); // should padding mode / ph+pw be in attrs? + return {enum_to_underlying(type), c, f, r, s, ph, pw, sh, sw, dh, dw, + enum_to_underlying(act)}; +} + +void ConvObj::setAuxilaryAttributes(PaddingMode mode) { + n = inputs[0]->getDims()[0], c = inputs[0]->getDims()[1], + h = inputs[0]->getDims()[2], w = inputs[0]->getDims()[3], + f = inputs[1]->getDims()[0], r = inputs[1]->getDims()[2], + s = inputs[1]->getDims()[3]; + if (mode == PaddingMode::Same) { + int oh = h / sh; + int ow = w / sw; + ph = (h - oh * sh + (r - sh) * dh) / 2; + pw = (w - ow * sw + (s - sw) * dw) / 2; + } else if (mode == PaddingMode::Valid) { + ph = pw = 0; + } +} + +} // namespace infini \ No newline at end of file diff --git a/test/core/test_graph.cc b/test/core/test_graph.cc index 21db982b..7c764141 100644 --- a/test/core/test_graph.cc +++ b/test/core/test_graph.cc @@ -7,39 +7,39 @@ namespace infini { TEST(Graph, build_and_run) { Graph g = make_ref(); - Tensor i0 = g->addTensor({1, 2, 3}, DataType::Int32); - Tensor w0 = g->addTensor({1, 3, 4}, DataType::Int32); - Tensor o0 = g->addTensor({1, 2, 4}, DataType::Int32); + Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32); + Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32); + Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32); g->dataMalloc(); - i0->copyData(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}.data()); - w0->copyData(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}.data()); + i0->copyData({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + w0->copyData({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); g->addOpWithOutputs(i0, w0, o0); RunEngine(Device::CPU).run(g); // check answer - auto ans = make_ref(Shape{1, 2, 4}, DataType::Int32); + auto ans = make_ref(Shape{1, 2, 4}, DataType::UInt32); ans->dataMalloc(); - ans->copyData(vector{38, 44, 50, 56, 83, 98, 113, 128}.data()); + ans->copyData({38, 44, 50, 56, 83, 98, 113, 128}); 
EXPECT_TRUE(o0->equalData(ans)); } TEST(Graph, perf_engine) { Graph g = make_ref(); - Tensor i0 = g->addTensor({1, 2, 3}, DataType::Int32); - Tensor w0 = g->addTensor({1, 3, 4}, DataType::Int32); + Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32); + Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32); auto matmul = g->addOp(i0, w0, nullptr); g->dataMalloc(); - i0->copyData(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}.data()); - w0->copyData(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}.data()); + i0->copyData({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + w0->copyData({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); RunEngine(Device::CPU).run(g, true, true); double perfTime = RunEngine(Device::CPU).getPerfTime(g); // The example matmul takes 0.0036ms with one core EXPECT_GT(perfTime, 0); EXPECT_LT(perfTime, 0.01); // check answer - auto ans = make_ref(Shape{1, 2, 4}, DataType::Int32); + auto ans = make_ref(Shape{1, 2, 4}, DataType::UInt32); ans->dataMalloc(); - ans->copyData(vector{38, 44, 50, 56, 83, 98, 113, 128}.data()); + ans->copyData({38, 44, 50, 56, 83, 98, 113, 128}); EXPECT_TRUE(matmul->getOutput()->equalData(ans)); } diff --git a/test/core/test_hash.cc b/test/core/test_hash.cc index 22955ec0..4f3f29d4 100644 --- a/test/core/test_hash.cc +++ b/test/core/test_hash.cc @@ -9,9 +9,9 @@ TEST(Hash, OperatorHash) { OpPerfKey key1(0, OpType::Unknown), key2(0, OpType::Unknown); { // build with addOpWithOutputs Graph g = make_ref(); - Tensor i0 = g->addTensor({1, 2, 3}, DataType::Int32); - Tensor w0 = g->addTensor({1, 3, 4}, DataType::Int32); - Tensor o0 = g->addTensor({1, 2, 4}, DataType::Int32); + Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32); + Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32); + Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32); auto matmul = g->addOpWithOutputs(i0, w0, o0); key1 = matmul->getOpPerfKey(); EXPECT_NE(key1.hash, 0); @@ -19,8 +19,8 @@ TEST(Hash, OperatorHash) { } { // build with addOp Graph g = make_ref(); - Tensor i0 
= g->addTensor({2, 2, 3}, DataType::Int32); - Tensor w0 = g->addTensor({2, 3, 4}, DataType::Int32); + Tensor i0 = g->addTensor({2, 2, 3}, DataType::UInt32); + Tensor w0 = g->addTensor({2, 3, 4}, DataType::UInt32); auto matmul = g->addOp(i0, w0, nullptr); key2 = matmul->getOpPerfKey(); EXPECT_NE(key2.hash, 0); diff --git a/test/operators/test_conv.cc b/test/operators/test_conv.cc new file mode 100644 index 00000000..4584c660 --- /dev/null +++ b/test/operators/test_conv.cc @@ -0,0 +1,63 @@ +#include "core/graph.h" +#include "core/run_enigne.h" +#include "operators/conv.h" +#include "test.h" + +namespace infini { + +TEST(Conv, ShapeInference) { + // Padding modes + { + Graph g = make_ref(); + Tensor i0 = g->addTensor({1, 3, 4, 4}, DataType::UInt32); + Tensor w0 = g->addTensor({2, 3, 3, 3}, DataType::UInt32); + auto conv = g->addOp(i0, w0, nullptr, 1, 1); + EXPECT_EQ(conv->getOutput()->getDims(), (Shape{1, 2, 4, 4})); + } + { + Graph g = make_ref(); + Tensor i0 = g->addTensor({1, 3, 4, 4}, DataType::UInt32); + Tensor w0 = g->addTensor({2, 3, 3, 3}, DataType::UInt32); + auto conv = + g->addOp(i0, w0, nullptr, ConvObj::PaddingMode::Same); + EXPECT_EQ(conv->getOutput()->getDims(), (Shape{1, 2, 4, 4})); + } + { + Graph g = make_ref(); + Tensor i0 = g->addTensor({1, 3, 4, 4}, DataType::UInt32); + Tensor w0 = g->addTensor({2, 3, 3, 3}, DataType::UInt32); + auto conv = + g->addOp(i0, w0, nullptr, ConvObj::PaddingMode::Valid); + EXPECT_EQ(conv->getOutput()->getDims(), (Shape{1, 2, 2, 2})); + } + { // dilation & stride + Graph g = make_ref(); + Tensor i0 = g->addTensor({1, 3, 4, 4}, DataType::UInt32); + Tensor w0 = g->addTensor({2, 3, 3, 3}, DataType::UInt32); + auto conv = g->addOp(i0, w0, nullptr, 1, 1, 2, 1, 1, 2); + EXPECT_EQ(conv->getOutput()->getDims(), (Shape{1, 2, 2, 2})); + } +} + +TEST(Conv, NaiveCPU) { + Graph g = make_ref(); + Tensor i0 = g->addTensor({1, 3, 4, 4}, DataType::UInt32); + Tensor w0 = g->addTensor({2, 3, 3, 3}, DataType::UInt32); + auto conv = 
g->addOp(i0, w0, nullptr, 1, 1, 2, 1, 1, 2); + + g->dataMalloc(); + i0->setData(IncrementalGenerator()); + w0->setData(IncrementalGenerator()); + RunEngine(Device::CPU).run(g, true, true); + double perfTime = RunEngine(Device::CPU).getPerfTime(g); + // Timing sanity check for the example conv (the 0.0036ms figure in the matmul test does not apply here) + EXPECT_GT(perfTime, 0); + EXPECT_LT(perfTime, 5); + // check answer + auto ans = make_ref(Shape{1, 2, 2, 2}, DataType::UInt32); + ans->dataMalloc(); + ans->copyData({4794, 4386, 8199, 7506, 11274, 10542, 20835, 19656}); + EXPECT_TRUE(conv->getOutput()->equalData(ans)); +} + +} // namespace infini \ No newline at end of file