2022-08-08 15:52:07 +08:00
|
|
|
#pragma once
|
|
|
|
#include "core/operator.h"
|
|
|
|
|
|
|
|
namespace infini {
|
2023-02-13 22:48:20 +08:00
|
|
|
/**
|
|
|
|
* @brief Matrix multiplication.
|
|
|
|
*
|
|
|
|
*/
|
2022-08-15 15:08:56 +08:00
|
|
|
class MatmulObj : public OperatorObj {
|
2022-08-08 15:52:07 +08:00
|
|
|
private:
|
2022-08-15 15:08:56 +08:00
|
|
|
// InfiniTensor assumes a row-major tensor layout. `transA`=false means
|
|
|
|
// default dims, true means A should be transposed before matmul. This is in
|
|
|
|
// oppsite to the column-major BLAS.
|
2022-08-09 14:58:45 +08:00
|
|
|
bool transA, transB;
|
|
|
|
ActType act;
|
|
|
|
|
2022-08-15 15:08:56 +08:00
|
|
|
// Auxiliary attributes which are not a part of operator attributes.
|
2022-08-09 14:58:45 +08:00
|
|
|
int b, m, n, k;
|
2022-08-08 15:52:07 +08:00
|
|
|
|
2024-03-26 09:00:45 +08:00
|
|
|
// Specifies the data precision for the matrix multiply.
|
|
|
|
std::string computeType = "default";
|
|
|
|
|
2022-08-08 15:52:07 +08:00
|
|
|
public:
|
2022-08-15 15:08:56 +08:00
|
|
|
/**
|
2023-04-18 00:26:36 +08:00
|
|
|
* @brief Matmul operator with batch broadcast and tensor transpose
|
|
|
|
* supports. Only one tensor with singe batch can be broadcasted due to the
|
|
|
|
* BLAS interface restriction. Tranpose indicates whether the last two
|
|
|
|
* dimensions should be transposed before Matmul and does not affect other
|
|
|
|
* leading dimensions.
|
|
|
|
*
|
|
|
|
* Matmul show how operators are defined in InfiniTensor. The constructor of
|
|
|
|
* an operator can create output tensors for the operator or not, which
|
|
|
|
* depends on `graph`.
|
2022-08-15 15:08:56 +08:00
|
|
|
*
|
2023-02-13 22:48:20 +08:00
|
|
|
* @param graph The computation graph that this operator belongs to.
|
|
|
|
* @param A The input tensor.
|
|
|
|
* @param B The input tensor.
|
2022-08-15 15:08:56 +08:00
|
|
|
* @param C C is the output of Matmul. If outputs are going to be created in
|
|
|
|
* the constructor, C should be an empty Ref.
|
2023-02-13 22:48:20 +08:00
|
|
|
* @param transA If matrix A should be transposed when computing.
|
|
|
|
* @param transB If matrix B should be transposed when computing.
|
|
|
|
* @param bias The bias tensor.
|
|
|
|
* @param act The activation function.
|
2024-03-26 09:00:45 +08:00
|
|
|
* @param computeType Specifies the data precision for the matrix multiply.
|
2022-08-15 15:08:56 +08:00
|
|
|
*/
|
|
|
|
MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C,
|
|
|
|
bool transA = false, bool transB = false, Tensor bias = nullptr,
|
2024-03-26 09:00:45 +08:00
|
|
|
ActType act = ActType::None, std::string computeType = "default");
|
2023-02-12 18:27:52 +08:00
|
|
|
OP_CLONE(MatmulObj);
|
2022-08-08 15:52:07 +08:00
|
|
|
|
|
|
|
std::string toString() const override;
|
2023-11-23 13:11:50 +08:00
|
|
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
2022-08-08 15:52:07 +08:00
|
|
|
|
2023-04-18 15:10:33 +08:00
|
|
|
int numInputs() const override { return inputs.size(); }
|
2022-08-08 15:52:07 +08:00
|
|
|
int numOutputs() const override { return 1; }
|
|
|
|
|
2023-04-18 15:10:33 +08:00
|
|
|
Tensor getBias() const { return inputs.size() > 2 ? inputs[2] : nullptr; }
|
2022-08-09 14:58:45 +08:00
|
|
|
ActType getAct() const { return act; }
|
2022-09-13 15:17:22 +08:00
|
|
|
auto getBMNKTransAB() const { return tuple(b, m, n, k, transA, transB); }
|
2022-08-09 14:58:45 +08:00
|
|
|
bool getTransA() const { return transA; }
|
|
|
|
bool getTransB() const { return transB; }
|
|
|
|
int getB() const { return b; }
|
|
|
|
int getM() const { return m; }
|
|
|
|
int getN() const { return n; }
|
|
|
|
int getK() const { return k; }
|
2022-09-01 21:06:55 +08:00
|
|
|
auto getBMNK() const { return tuple{b, m, n, k}; }
|
2024-03-26 09:00:45 +08:00
|
|
|
std::string getComputeType() const { return computeType; }
|
2022-08-08 15:52:07 +08:00
|
|
|
|
|
|
|
private:
|
2022-08-15 15:08:56 +08:00
|
|
|
vector<int> getWorkloadVector() const override;
|
|
|
|
vector<int> getOpAttrVector() const override;
|
2022-08-08 15:52:07 +08:00
|
|
|
};
|
|
|
|
|
2022-08-09 14:58:45 +08:00
|
|
|
} // namespace infini
|