2022-09-29 11:01:30 +08:00
|
|
|
#include "core/graph.h"
|
|
|
|
#include "core/runtime.h"
|
|
|
|
#include "operators/split.h"
|
|
|
|
|
|
|
|
#include "test.h"
|
|
|
|
|
|
|
|
namespace infini {
|
|
|
|
TEST(Split, ShapeInfer) {
|
|
|
|
{
|
2023-03-27 21:28:49 +08:00
|
|
|
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
2022-09-29 11:01:30 +08:00
|
|
|
Graph g = make_ref<GraphObj>(runtime);
|
|
|
|
auto input = g->addTensor({1, 3, 2, 15}, DataType::Float32);
|
|
|
|
|
|
|
|
auto op = g->addOp<SplitObj>(input, std::nullopt, 3, 4);
|
|
|
|
EXPECT_EQ(op->numOutputs(), 4);
|
|
|
|
EXPECT_EQ(op->getOutputs().size(), (size_t)4);
|
|
|
|
EXPECT_EQ(op->getOutput(0)->getDims(), (Shape{1, 3, 2, 3}));
|
|
|
|
EXPECT_EQ(op->getOutput(1)->getDims(), (Shape{1, 3, 2, 3}));
|
|
|
|
EXPECT_EQ(op->getOutput(2)->getDims(), (Shape{1, 3, 2, 3}));
|
|
|
|
EXPECT_EQ(op->getOutput(3)->getDims(), (Shape{1, 3, 2, 6}));
|
|
|
|
}
|
|
|
|
|
2023-08-18 12:17:47 +08:00
|
|
|
{
|
|
|
|
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
|
|
|
Graph g = make_ref<GraphObj>(runtime);
|
|
|
|
auto input = g->addTensor({1, 3, 2, 15}, DataType::Float32);
|
|
|
|
|
|
|
|
auto op = g->addOp<SplitObj>(input, std::nullopt, -1, 4);
|
|
|
|
EXPECT_EQ(op->numOutputs(), 4);
|
|
|
|
EXPECT_EQ(op->getOutputs().size(), (size_t)4);
|
|
|
|
EXPECT_EQ(op->getOutput(0)->getDims(), (Shape{1, 3, 2, 3}));
|
|
|
|
EXPECT_EQ(op->getOutput(1)->getDims(), (Shape{1, 3, 2, 3}));
|
|
|
|
EXPECT_EQ(op->getOutput(2)->getDims(), (Shape{1, 3, 2, 3}));
|
|
|
|
EXPECT_EQ(op->getOutput(3)->getDims(), (Shape{1, 3, 2, 6}));
|
|
|
|
}
|
|
|
|
|
2022-09-29 11:01:30 +08:00
|
|
|
{
|
2023-03-27 21:28:49 +08:00
|
|
|
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
2022-09-29 11:01:30 +08:00
|
|
|
Graph g = make_ref<GraphObj>(runtime);
|
|
|
|
auto input = g->addTensor({1, 3, 2, 15}, DataType::Float32);
|
|
|
|
|
|
|
|
auto op =
|
|
|
|
g->addOp<SplitObj>(input, std::nullopt, 3, vector<int>{1, 2, 2});
|
|
|
|
EXPECT_EQ(op->getOutputs().size(), (size_t)3);
|
|
|
|
EXPECT_EQ(op->numOutputs(), 3);
|
|
|
|
EXPECT_EQ(op->getOutput(0)->getDims(), (Shape{1, 3, 2, 3}));
|
|
|
|
EXPECT_EQ(op->getOutput(1)->getDims(), (Shape{1, 3, 2, 6}));
|
|
|
|
EXPECT_EQ(op->getOutput(2)->getDims(), (Shape{1, 3, 2, 6}));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-10-15 16:29:28 +08:00
|
|
|
} // namespace infini
|