Add: comments for Matmul

This commit is contained in:
Liyan Zheng 2023-04-12 11:19:09 +08:00
parent 31b03ef91a
commit f939ebf8bb
6 changed files with 20 additions and 38 deletions

View File

@ -27,4 +27,4 @@ class MergeMemboundMutator : public Mutator {
Expr merge(bool allowEmptyMembound = false, bool allowFailure = false);
};
} // namespace nnet
} // namespace nnet

View File

@ -19,9 +19,15 @@ class MatmulObj : public OperatorObj {
public:
/**
* @brief Construct a new Matmul object. This comments show how operators is
* defined in InfiniTensor. The constructor can create output tensors for
* the operator or not, which depends on `graph`.
* @brief Matmul operator with batch broadcast and tensor transpose
* supports. Only one tensor with a single batch can be broadcasted due to the
* BLAS interface restriction. Transpose indicates whether the last two
* dimensions should be transposed before Matmul and does not affect other
* leading dimensions.
*
* Matmul shows how operators are defined in InfiniTensor. The constructor of
* an operator can create output tensors for the operator or not, which
* depends on `graph`.
*
* @param graph The computation graph that this operator belongs to.
* @param A The input tensor.

View File

@ -1,4 +1,4 @@
print("__init__")
from .gen_ansor_op import gen_ansor_op
print("__init2__")
from .gen_ansor_so import gen_ansor_so
from .gen_ansor_so import gen_ansor_so

View File

@ -5,32 +5,13 @@ namespace infini {
MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
bool transB, [[maybe_unused]] Tensor bias, ActType act)
: OperatorObj(OpType::Matmul, {A, B}, {C}), transA(transA), transB(transB),
// auto shape_a = A->getDims();
// auto shape_b = B->getDims();
// IT_ASSERT(shape_a.size() == shape_b.size());
// switch (shape_a.size()) {
// case 0:
// case 1:
// IT_ASSERT(false);
// case 2:
// break;
// default:
// for (size_t i = 0; i < shape_a.size() - 2; ++i) {
// IT_ASSERT(shape_a[i] == shape_b[i]);
// b *= shape_a[i];
// }
// break;
// }
// m = *(transA ? shape_a.rbegin() : shape_a.rbegin() + 1);
// n = *(transB ? shape_b.rbegin() + 1 : shape_b.rbegin());
// k = *(transA ? shape_a.rbegin() + 1 : shape_a.rbegin());
act(act), b(1)
{
act(act) {
auto shape_a = A->getDims();
auto shape_b = B->getDims();
int dimA = shape_a.size(), dimB = shape_b.size();
IT_ASSERT(dimA >= 2 && dimB >= 2);
b = 1;
if (dimA <= 3 && dimB <= 3) {
int b1 = dimA == 2 ? 1 : A->getDims()[0];
int b2 = dimB == 2 ? 1 : B->getDims()[0];
@ -46,11 +27,6 @@ MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
m = *(transA ? shape_a.rbegin() : shape_a.rbegin() + 1);
n = *(transB ? shape_b.rbegin() + 1 : shape_b.rbegin());
k = *(transA ? shape_a.rbegin() + 1 : shape_a.rbegin());
// std::cout << A->toString() << "\n"
// << B->toString() << "\n";
// if (C) {
// std::cout << C->toString() << std::endl;
// }
IT_ASSERT(checkValid(graph));
}
@ -80,10 +56,10 @@ optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) const {
int b2 = dimB == 2 ? 1 : B->getDims()[0];
int b = std::max(b1, b2);
int m = transA ? A->getDims()[dimA-1] : A->getDims()[dimA-2];
int n = transB ? B->getDims()[dimB-2] : B->getDims()[dimB-1];
int kA = transA ? A->getDims()[dimA-2] : A->getDims()[dimA-1];
int kB = transB ? B->getDims()[dimB-1] : B->getDims()[dimB-2];
int m = transA ? A->getDims()[dimA - 1] : A->getDims()[dimA - 2];
int n = transB ? B->getDims()[dimB - 2] : B->getDims()[dimB - 1];
int kA = transA ? A->getDims()[dimA - 2] : A->getDims()[dimA - 1];
int kB = transB ? B->getDims()[dimB - 1] : B->getDims()[dimB - 2];
if ((dimA != 2 && dimA != 3) || (dimB != 2 && dimB != 3)) {
printf("Bad input dim: dimA = %d, dimB = %d\n", dimA, dimB);

View File

@ -8,7 +8,7 @@ namespace infini {
TEST(Hash, OperatorHash) {
OpPerfKey key1(0, OpType::Unknown), key2(0, OpType::Unknown);
{ // build with addOpWithOutputs
Graph g = make_ref<GraphObj>(CpuRuntimeObj::getInstance());
Graph g = make_ref<GraphObj>(NativeCpuRuntimeObj::getInstance());
Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32);
Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
@ -18,7 +18,7 @@ TEST(Hash, OperatorHash) {
EXPECT_GT(key1.attrs.size(), (size_t)5);
}
{ // build with addOp
Graph g = make_ref<GraphObj>(CpuRuntimeObj::getInstance());
Graph g = make_ref<GraphObj>(NativeCpuRuntimeObj::getInstance());
Tensor i0 = g->addTensor({2, 2, 3}, DataType::UInt32);
Tensor w0 = g->addTensor({2, 3, 4}, DataType::UInt32);
auto matmul = g->addOp<MatmulObj>(i0, w0, nullptr);

View File

@ -71,4 +71,4 @@ TEST(FuseMembound, mergeNestedStagesInRangeOp) {
makeSubscript(K, {f, i / 1024, (i / 256) % 4, i % 256}));
dbg(merged, ans);
EXPECT_EQ(HashVisitor().getHash(merged), HashVisitor().getHash(ans));
}
}