forked from jiuyuan/InfiniTensor
Add: comments for Matmul
This commit is contained in:
parent 31b03ef91a
commit f939ebf8bb
@@ -27,4 +27,4 @@ class MergeMemboundMutator : public Mutator {
     Expr merge(bool allowEmptyMembound = false, bool allowFailure = false);
 };
 
-} // namespace nnet
\ No newline at end of file
+} // namespace nnet
@@ -19,9 +19,15 @@ class MatmulObj : public OperatorObj {
 
   public:
     /**
-     * @brief Construct a new Matmul object. This comments show how operators is
-     * defined in InfiniTensor. The constructor can create output tensors for
-     * the operator or not, which depends on `graph`.
+     * @brief Matmul operator with batch broadcast and tensor transpose
+     * support. Only one tensor with a single batch can be broadcast, due to
+     * the BLAS interface restriction. Transpose indicates whether the last
+     * two dimensions should be transposed before Matmul; other leading
+     * dimensions are unaffected.
+     *
+     * Matmul shows how operators are defined in InfiniTensor. The constructor
+     * of an operator can create output tensors for the operator or not, which
+     * depends on `graph`.
+     *
+     * @param graph The computation graph that this operator belongs to.
+     * @param A The input tensor.
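For reference, the semantics described in this comment can be exercised exactly as the tests later in this diff do. A usage sketch, assuming InfiniTensor's core headers that declare GraphObj, MatmulObj, Tensor, DataType, and NativeCpuRuntimeObj (the names i0 and w0 follow the test code):

    // Inside a test or any function:
    Graph g = make_ref<GraphObj>(NativeCpuRuntimeObj::getInstance());
    Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32); // A: batch 1, 2x3
    Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32); // B: batch 1, 3x4
    // Passing nullptr as C asks the constructor to create the output tensor,
    // which is the `graph`-dependent behavior the comment mentions.
    auto matmul = g->addOp<MatmulObj>(i0, w0, nullptr);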
@@ -1,4 +1,4 @@
 print("__init__")
 from .gen_ansor_op import gen_ansor_op
 print("__init2__")
-from .gen_ansor_so import gen_ansor_so
\ No newline at end of file
+from .gen_ansor_so import gen_ansor_so
@@ -5,32 +5,13 @@ namespace infini {
 MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
                      bool transB, [[maybe_unused]] Tensor bias, ActType act)
     : OperatorObj(OpType::Matmul, {A, B}, {C}), transA(transA), transB(transB),
-      // auto shape_a = A->getDims();
-      // auto shape_b = B->getDims();
-      // IT_ASSERT(shape_a.size() == shape_b.size());
-      // switch (shape_a.size()) {
-      // case 0:
-      // case 1:
-      //     IT_ASSERT(false);
-      // case 2:
-      //     break;
-      // default:
-      //     for (size_t i = 0; i < shape_a.size() - 2; ++i) {
-      //         IT_ASSERT(shape_a[i] == shape_b[i]);
-      //         b *= shape_a[i];
-      //     }
-      //     break;
-      // }
-      // m = *(transA ? shape_a.rbegin() : shape_a.rbegin() + 1);
-      // n = *(transB ? shape_b.rbegin() + 1 : shape_b.rbegin());
-      // k = *(transA ? shape_a.rbegin() + 1 : shape_a.rbegin());
-      act(act), b(1)
-{
+      act(act) {
     auto shape_a = A->getDims();
     auto shape_b = B->getDims();
     int dimA = shape_a.size(), dimB = shape_b.size();
     IT_ASSERT(dimA >= 2 && dimB >= 2);
 
+    b = 1;
     if (dimA <= 3 && dimB <= 3) {
         int b1 = dimA == 2 ? 1 : A->getDims()[0];
         int b2 = dimB == 2 ? 1 : B->getDims()[0];
@@ -46,11 +27,6 @@ MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
     m = *(transA ? shape_a.rbegin() : shape_a.rbegin() + 1);
     n = *(transB ? shape_b.rbegin() + 1 : shape_b.rbegin());
     k = *(transA ? shape_a.rbegin() + 1 : shape_a.rbegin());
-    // std::cout << A->toString() << "\n"
-    //           << B->toString() << "\n";
-    // if (C) {
-    //     std::cout << C->toString() << std::endl;
-    // }
     IT_ASSERT(checkValid(graph));
 }
 
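The three assignments kept above read m, n, and k off the last two dimensions of each shape via reverse iterators. A minimal standalone sketch of that indexing (plain C++, not InfiniTensor code; the concrete shape is an assumption for illustration):

    #include <cassert>
    #include <vector>

    int main() {
        std::vector<int> shape_a = {16, 2, 3}; // batch 16, each A is 2x3
        bool transA = false;
        // rbegin() is the last dimension; rbegin() + 1 the second-to-last.
        int m = *(transA ? shape_a.rbegin() : shape_a.rbegin() + 1); // m = 2
        int k = *(transA ? shape_a.rbegin() + 1 : shape_a.rbegin()); // k = 3
        assert(m == 2 && k == 3);
        return 0;
    }

With transA = true the two picks swap, which is exactly the "transpose the last two dimensions" behavior the header comment describes.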
@@ -80,10 +56,10 @@ optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) const {
     int b2 = dimB == 2 ? 1 : B->getDims()[0];
 
     int b = std::max(b1, b2);
-    int m = transA ? A->getDims()[dimA-1] : A->getDims()[dimA-2];
-    int n = transB ? B->getDims()[dimB-2] : B->getDims()[dimB-1];
-    int kA = transA ? A->getDims()[dimA-2] : A->getDims()[dimA-1];
-    int kB = transB ? B->getDims()[dimB-1] : B->getDims()[dimB-2];
+    int m = transA ? A->getDims()[dimA - 1] : A->getDims()[dimA - 2];
+    int n = transB ? B->getDims()[dimB - 2] : B->getDims()[dimB - 1];
+    int kA = transA ? A->getDims()[dimA - 2] : A->getDims()[dimA - 1];
+    int kB = transB ? B->getDims()[dimB - 1] : B->getDims()[dimB - 2];
 
     if ((dimA != 2 && dimA != 3) || (dimB != 2 && dimB != 3)) {
         printf("Bad input dim: dimA = %d, dimB = %d\n", dimA, dimB);
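The hunk above only changes spacing, but it contains the whole shape-inference rule: the batch dimension broadcasts to max(b1, b2), while m, n, and kA/kB come from the trailing two dimensions. A hedged standalone restatement (the function name and the plain std::vector return type are assumptions for this sketch; the real method returns optional<vector<Shape>>):

    #include <algorithm>
    #include <cassert>
    #include <vector>

    std::vector<int> inferMatmulShape(const std::vector<int> &a,
                                      const std::vector<int> &b,
                                      bool transA, bool transB) {
        int dimA = a.size(), dimB = b.size();
        int b1 = dimA == 2 ? 1 : a[0];   // batch of A (1 if rank 2)
        int b2 = dimB == 2 ? 1 : b[0];   // batch of B (1 if rank 2)
        int m = transA ? a[dimA - 1] : a[dimA - 2];
        int n = transB ? b[dimB - 2] : b[dimB - 1];
        int kA = transA ? a[dimA - 2] : a[dimA - 1];
        int kB = transB ? b[dimB - 1] : b[dimB - 2];
        assert(kA == kB);                // reduction dims must agree
        return {std::max(b1, b2), m, n}; // broadcast the batch dimension
    }

    int main() {
        // {1,2,3} x {1,3,4} -> {1,2,4}, matching tensors i0, w0, o0 in the
        // hash test below.
        assert((inferMatmulShape({1, 2, 3}, {1, 3, 4}, false, false) ==
                std::vector<int>{1, 2, 4}));
        return 0;
    }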
@@ -8,7 +8,7 @@ namespace infini {
 TEST(Hash, OperatorHash) {
     OpPerfKey key1(0, OpType::Unknown), key2(0, OpType::Unknown);
     { // build with addOpWithOutputs
-        Graph g = make_ref<GraphObj>(CpuRuntimeObj::getInstance());
+        Graph g = make_ref<GraphObj>(NativeCpuRuntimeObj::getInstance());
         Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32);
         Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
         Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
@@ -18,7 +18,7 @@ TEST(Hash, OperatorHash) {
         EXPECT_GT(key1.attrs.size(), (size_t)5);
     }
     { // build with addOp
-        Graph g = make_ref<GraphObj>(CpuRuntimeObj::getInstance());
+        Graph g = make_ref<GraphObj>(NativeCpuRuntimeObj::getInstance());
         Tensor i0 = g->addTensor({2, 2, 3}, DataType::UInt32);
         Tensor w0 = g->addTensor({2, 3, 4}, DataType::UInt32);
         auto matmul = g->addOp<MatmulObj>(i0, w0, nullptr);
@@ -71,4 +71,4 @@ TEST(FuseMembound, mergeNestedStagesInRangeOp) {
         makeSubscript(K, {f, i / 1024, (i / 256) % 4, i % 256}));
     dbg(merged, ans);
     EXPECT_EQ(HashVisitor().getHash(merged), HashVisitor().getHash(ans));
-}
\ No newline at end of file
+}