// InfiniTensor/include/operators/matmul.h
#pragma once
#include "core/operator.h"
namespace infini {
class MatmulObj : public OperatorObj {
2022-08-08 15:52:07 +08:00
private:
// InfiniTensor assumes a row-major tensor layout. `transA`=false means
// default dims, true means A should be transposed before matmul. This is in
// oppsite to the column-major BLAS.
2022-08-09 14:58:45 +08:00
bool transA, transB;
ActType act;
// Auxiliary attributes which are not a part of operator attributes.
2022-08-09 14:58:45 +08:00
int b, m, n, k;
2022-08-08 15:52:07 +08:00
public:
/**
* @brief This comments show how operators is defined in InfiniTensor. The
* constructor can create output tensors for the operator or not, which
* depends on `graph`.
*
* @param graph If graph is not empty, create outputs in the constructor.
* Otherwise, check the provided shape with the results of `inferShape` in
* `checkValid`.
* @param C C is the output of Matmul. If outputs are going to be created in
* the constructor, C should be an empty Ref.
*/
MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C,
bool transA = false, bool transB = false, Tensor bias = nullptr,
ActType act = ActType::None);
2022-08-08 15:52:07 +08:00
std::string toString() const override;
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
2022-08-08 15:52:07 +08:00
int numInputs() const override { return 3; }
2022-08-08 15:52:07 +08:00
int numOutputs() const override { return 1; }
Tensor getBias() const { return inputs[2]; }
2022-08-09 14:58:45 +08:00
ActType getAct() const { return act; }
bool getTransA() const { return transA; }
bool getTransB() const { return transB; }
int getB() const { return b; }
int getM() const { return m; }
int getN() const { return n; }
int getK() const { return k; }
2022-08-08 15:52:07 +08:00
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
2022-08-08 15:52:07 +08:00
};
} // namespace infini