forked from jiuyuan/InfiniTensor
feat: add reshape/identity/squeeze/flatten/unsqueeze op cpu kernel (#213)
This commit is contained in:
parent
d1a90ba3e2
commit
4813204a36
|
@ -0,0 +1,29 @@
|
||||||
|
#include "operators/reshape.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
|
#include "operators/squeeze.h"
|
||||||
|
#include "operators/unsqueeze.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
class NaiveIdentity : public CpuKernelWithoutConfig {
|
||||||
|
void compute(const Operator &_op,
|
||||||
|
const RuntimeObj *context) const override {
|
||||||
|
auto size = _op->getInputs()[0]->getBytes();
|
||||||
|
void *inptr = _op->getInputs(0)->getRawDataPtr<void *>();
|
||||||
|
void *outptr = _op->getOutput()->getRawDataPtr<void *>();
|
||||||
|
|
||||||
|
std::memcpy(outptr, inptr, size);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL(Device::CPU, OpType::Reshape, NaiveIdentity,
|
||||||
|
"ReshapeNaive_CPU");
|
||||||
|
REGISTER_KERNEL(Device::CPU, OpType::Identity, NaiveIdentity,
|
||||||
|
"IdentityNaive_CPU");
|
||||||
|
REGISTER_KERNEL(Device::CPU, OpType::Unsqueeze, NaiveIdentity,
|
||||||
|
"UnsqueezeNaive_CPU");
|
||||||
|
REGISTER_KERNEL(Device::CPU, OpType::Squeeze, NaiveIdentity,
|
||||||
|
"SqueezeNaive_CPU");
|
||||||
|
REGISTER_KERNEL(Device::CPU, OpType::Flatten, NaiveIdentity,
|
||||||
|
"FlattenNaive_CPU");
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,21 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "operators/reshape.h"
|
||||||
|
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
TEST(Identity, NativeCpu) {
|
||||||
|
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
Graph g = make_ref<GraphObj>(runtime);
|
||||||
|
|
||||||
|
auto t1 = g->addTensor({2, 2, 3, 1}, DataType::Float32);
|
||||||
|
auto op = g->addOp<IdentityObj>(t1, nullptr);
|
||||||
|
g->dataMalloc();
|
||||||
|
t1->setData(IncrementalGenerator());
|
||||||
|
|
||||||
|
runtime->run(g);
|
||||||
|
EXPECT_TRUE(op->getOutput()->equalData(
|
||||||
|
vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}));
|
||||||
|
}
|
||||||
|
} // namespace infini
|
Loading…
Reference in New Issue