2022-09-01 21:06:55 +08:00
|
|
|
|
|
|
|
#include "core/graph.h"
|
|
|
|
#include "core/kernel.h"
|
|
|
|
#include "core/runtime.h"
|
|
|
|
#include "operators/matmul.h"
|
|
|
|
|
|
|
|
#include "test.h"
|
|
|
|
|
|
|
|
namespace infini {
|
|
|
|
using ExpectOutput = vector<float>;
|
|
|
|
|
|
|
|
TEST(Matmul, ShapeInference) {
|
2023-03-27 21:28:49 +08:00
|
|
|
auto runtime = NativeCpuRuntimeObj::getInstance();
|
2022-09-01 21:06:55 +08:00
|
|
|
{
|
|
|
|
Graph g = make_ref<GraphObj>(runtime);
|
|
|
|
auto A = g->addTensor(Shape{1, 3, 5});
|
|
|
|
auto B = g->addTensor(Shape{1, 5, 2});
|
|
|
|
auto matmul = g->addOp<MatmulObj>(A, B, nullptr);
|
|
|
|
auto C = matmul->getOutputs()[0];
|
|
|
|
EXPECT_EQ(C->getDims(), (Shape{1, 3, 2}));
|
|
|
|
}
|
|
|
|
{
|
|
|
|
Graph g = make_ref<GraphObj>(runtime);
|
|
|
|
auto A = g->addTensor(Shape{3, 5, 4});
|
|
|
|
auto B = g->addTensor(Shape{3, 5, 2});
|
|
|
|
auto matmul = g->addOp<MatmulObj>(A, B, nullptr, true, false);
|
|
|
|
auto C = matmul->getOutputs()[0];
|
|
|
|
EXPECT_EQ(C->getDims(), (Shape{3, 4, 2}));
|
|
|
|
}
|
2023-08-16 21:49:43 +08:00
|
|
|
{
|
|
|
|
Graph g = make_ref<GraphObj>(runtime);
|
|
|
|
auto A = g->addTensor(Shape{1, 2, 3, 5});
|
|
|
|
auto B = g->addTensor(Shape{1, 1, 5, 2});
|
|
|
|
auto matmul = g->addOp<MatmulObj>(A, B, nullptr);
|
|
|
|
auto C = matmul->getOutputs()[0];
|
|
|
|
EXPECT_EQ(C->getDims(), (Shape{1, 2, 3, 2}));
|
|
|
|
}
|
|
|
|
{
|
|
|
|
Graph g = make_ref<GraphObj>(runtime);
|
|
|
|
auto A = g->addTensor(Shape{2, 3, 5, 4});
|
|
|
|
auto B = g->addTensor(Shape{1, 3, 5, 2});
|
|
|
|
auto matmul = g->addOp<MatmulObj>(A, B, nullptr, true, false);
|
|
|
|
auto C = matmul->getOutputs()[0];
|
|
|
|
EXPECT_EQ(C->getDims(), (Shape{2, 3, 4, 2}));
|
|
|
|
}
|
|
|
|
{
|
|
|
|
Graph g = make_ref<GraphObj>(runtime);
|
|
|
|
auto A = g->addTensor(Shape{2, 3, 5, 4});
|
|
|
|
auto B = g->addTensor(Shape{1, 3, 2, 5});
|
|
|
|
auto matmul = g->addOp<MatmulObj>(A, B, nullptr, true, true);
|
|
|
|
auto C = matmul->getOutputs()[0];
|
|
|
|
EXPECT_EQ(C->getDims(), (Shape{2, 3, 4, 2}));
|
|
|
|
}
|
2022-09-01 21:06:55 +08:00
|
|
|
}
|
|
|
|
|
2022-10-15 16:29:28 +08:00
|
|
|
}; // namespace infini
|