From 4813204a3622b04c3cbcc6d53b0824197b5dfe45 Mon Sep 17 00:00:00 2001 From: zhangyunze <93699316+bitzyz@users.noreply.github.com> Date: Tue, 30 Jan 2024 10:29:59 +0800 Subject: [PATCH] feat: add reshape/identity/squeeze/flatten/unsqueeze op cpu kernel (#213) --- src/kernels/cpu/reshape.cc | 29 +++++++++++++++++++ .../nativecpu/test_nativecpu_identity.cc | 21 ++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 src/kernels/cpu/reshape.cc create mode 100644 test/kernels/nativecpu/test_nativecpu_identity.cc diff --git a/src/kernels/cpu/reshape.cc b/src/kernels/cpu/reshape.cc new file mode 100644 index 00000000..4eea18c4 --- /dev/null +++ b/src/kernels/cpu/reshape.cc @@ -0,0 +1,29 @@ +#include "operators/reshape.h" +#include "core/kernel.h" +#include "operators/squeeze.h" +#include "operators/unsqueeze.h" + +namespace infini { +class NaiveIdentity : public CpuKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *context) const override { + auto size = _op->getInputs()[0]->getBytes(); + void *inptr = _op->getInputs(0)->getRawDataPtr(); + void *outptr = _op->getOutput()->getRawDataPtr(); + + std::memcpy(outptr, inptr, size); + } +}; + +REGISTER_KERNEL(Device::CPU, OpType::Reshape, NaiveIdentity, + "ReshapeNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Identity, NaiveIdentity, + "IdentityNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Unsqueeze, NaiveIdentity, + "UnsqueezeNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Squeeze, NaiveIdentity, + "SqueezeNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Flatten, NaiveIdentity, + "FlattenNaive_CPU"); + +} // namespace infini diff --git a/test/kernels/nativecpu/test_nativecpu_identity.cc b/test/kernels/nativecpu/test_nativecpu_identity.cc new file mode 100644 index 00000000..82ddf2e0 --- /dev/null +++ b/test/kernels/nativecpu/test_nativecpu_identity.cc @@ -0,0 +1,21 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "operators/reshape.h" + +#include "test.h" + +namespace infini { +TEST(Identity, NativeCpu) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph g = make_ref(runtime); + + auto t1 = g->addTensor({2, 2, 3, 1}, DataType::Float32); + auto op = g->addOp(t1, nullptr); + g->dataMalloc(); + t1->setData(IncrementalGenerator()); + + runtime->run(g); + EXPECT_TRUE(op->getOutput()->equalData( + vector{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11})); +} +} // namespace infini