refactor(core): add the new `OpType` definition (#99)

* feat: add the new OpType definition

Signed-off-by: YdrMaster <ydrml@hotmail.com>

* refactor: replace the old OpType with the new one and update the whole project

Signed-off-by: YdrMaster <ydrml@hotmail.com>

* fix: ONNX import

Signed-off-by: YdrMaster <ydrml@hotmail.com>

* fix: resolve issues in the CUDA and BANG kernels

Signed-off-by: YdrMaster <ydrml@hotmail.com>

* fix: filter out BANG tests

Signed-off-by: YdrMaster <ydrml@hotmail.com>

* fix: filter out BANG tests

Signed-off-by: YdrMaster <ydrml@hotmail.com>

* fix: BANG code

* fix: code on BANG

* fmt

Signed-off-by: YdrMaster <ydrml@hotmail.com>

* fix: remove the specified files

Signed-off-by: YdrMaster <ydrml@hotmail.com>

* fix: remove two unused files and drop a comment of unclear purpose

Signed-off-by: YdrMaster <ydrml@hotmail.com>

---------

Signed-off-by: YdrMaster <ydrml@hotmail.com>
Co-authored-by: wanghailu <wanghailu@qiyuanlab.com>
Derui Yang 2023-08-07 11:17:05 +08:00 committed by GitHub
parent 9b10a74788
commit 57ac94d893
70 changed files with 776 additions and 907 deletions


@@ -27,9 +27,9 @@ class GraphHandlerObj {
                             int opw);
     Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
                   Tensor bias, ActType act);
-    Tensor batchNorm(Tensor input, Tensor output, Tensor mean, Tensor var,
-                     Tensor scale, Tensor bias, float momentum, float eps,
-                     bool training);
+    Tensor batchNormalization(Tensor input, Tensor output, Tensor mean,
+                              Tensor var, Tensor scale, Tensor bias,
+                              float momentum, float eps, bool training);
     Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
                    int ph, int pw, int sh, int sw);


@@ -105,8 +105,8 @@ class KernelRegistry {
         IT_ASSERT(it != kernels.end(),
                   "Kernel not found for key {" +
                       to_string(enum_to_underlying(std::get<0>(kernelAttrs))) +
-                      ", " + OpRegistry::getOpName(std::get<1>(kernelAttrs)) +
-                      ", " + std::get<2>(kernelAttrs).toString() + "}");
+                      ", " + std::to_string(std::get<1>(kernelAttrs)) + ", " +
+                      std::get<2>(kernelAttrs).toString() + "}");
         return std::get<0>(it->second);
     }
     const KernelRecord &getKernelItem(const KernelAttrs &kernelAttrs) const {

include/core/op_type.h (new file, 253 lines added)

@ -0,0 +1,253 @@
#pragma once
#ifndef OP_TYPE_H
#define OP_TYPE_H
#include <string>
#include <unordered_set>
namespace infini {
struct OpType {
using underlying_t = uint16_t;
// Clang-format is ambiguous when formatting comment alignment.
// To disambiguate, every enum element must carry a comment.
enum : underlying_t {
Unknown,
Abs, // Unary
Acos, // Unary
Acosh, // Unary
Add, // Binary
And, // Binary
ArgMax, //
Asin, // Unary
Asinh, // Unary
Atan, // Unary
Atanh, // Unary
AveragePool, // Pool
BatchNormalization, //
Bernoulli, //
BitShift, // Binary
BitwiseAnd, // Binary
BitwiseNot, // Binary
BitwiseOr, // Binary
BitwiseXor, // Binary
BlackmanWindow, //
Cast, // Unary
CastLike, //
Ceil, // Unary
Celu, //
CenterCropPad, //
Clip, // Unary
Col2lm,
Compress,
Concat,
ConcatFromSequence,
ConstantOfShape,
Conv, // ComputationIntensive
ConvInteger, // ComputationIntensive
ConvTranspose, // ComputationIntensive
Cos, // Unary
Cosh, // Unary
CumSum,
DFT,
DeformConv, // ComputationIntensive
DepthToSpace,
DequantizeLinear,
Det,
Div, // Binary
Dropout,
DynamicQuantizeLinear,
Einsum,
Elu,
Equal, // Compair
Erf, // Unary
Exp, // Unary
Expand,
EyeLike,
Flatten,
Floor, // Unary
GRU,
Gather,
GatherElements,
GatherND,
Gemm,
GlobalAveragePool, // GlobalPool
GlobalLpPool, // GlobalPool
GlobalMaxPool, // GlobalPool
Greater, // Compair
GreaterOrEqual, // Compair
GridSample,
GroupNormalization,
HammingWindow,
HannWindow,
HardSigmoid,
HardSwish,
Hardmax,
Identity,
If,
InstanceNormalization,
IsInf,
IsNaN,
LRN,
LSTM,
LayerNormalization,
LeakyRelu,
Less, // Compair
LessOrEqual, // Compair
Log, // Unary
LogSoftmax,
Loop,
LpNormalization,
LpPool,
MatMul, // ComputationIntensive
MatMulInteger, // ComputationIntensive
Max,
MaxPool,
MaxRoiPool,
MaxUnpool,
Mean,
MeanVarianceNormalization,
MelWeightMatrix,
Min,
Mish,
Mod, // Binary
Mul, // Binary
Multinomial, //
Neg, // Unary
NegativeLogLikelihoodLoss,
NonMaxSuppression,
NonZero,
Not, // Unary
OneHot,
Optional,
OptionalGetElement,
OptionalHasElement,
Or, // Binary
PRelu, //
Pad, //
Pow, // Binary
QLinearConv, // ComputationIntensive
QLinearMatMul, // ComputationIntensive
QuantizeLinear,
RNN,
RandomNormal,
RandomNormalLike,
RandomUniform,
RandomUniformLike,
Range,
Reciprocal,
ReduceL1, // Reduce
ReduceL2, // Reduce
ReduceLogSum, // Reduce
ReduceLogSumExp, // Reduce
ReduceMax, // Reduce
ReduceMean, // Reduce
ReduceMin, // Reduce
ReduceProd, // Reduce
ReduceSum, // Reduce
ReduceSumSquare, // Reduce
Relu, // Unary
Reshape,
Resize,
ReverseSequence,
RoiAlign,
Round, // Unary
STFT,
Scan,
Scatter,
ScatterElements,
ScatterND,
Selu,
SequenceAt,
SequenceConstruct,
SequenceEmpty,
SequenceErase,
SequenceInsert,
SequenceLength,
SequenceMap,
Shape,
Shrink,
Sigmoid,
Sign,
Sin, // Unary
Sinh, // Unary
Size,
Slice,
Softmax,
SoftmaxCrossEntropyLoss,
Softplus,
Softsign,
SpaceToDepth,
Split,
SplitToSequence,
Sqrt,
Squeeze,
StringNormalizer,
Sub, // Binary
Sum, //
Tan, // Unary
Tanh, // Unary
TfIdfVectorizer,
ThresholdedRelu,
Tile,
TopK,
Transpose,
Trilu,
Unique,
Unsqueeze,
Upsample,
Where,
Xor, // Binary
// CUSTOM DEFINED
G2BMM,
GBMM,
MemBound,
// TODO
ConvTransNHWC,
ConvBackwardFilter,
ReluBackward,
SigmoidBackward,
TanhBackward,
Fill,
Extend,
MSELoss,
Hardtanh,
L2Loss,
Rsqrt,
FloorDiv,
FloorMod,
Square,
SquaredDifference,
} type;
constexpr OpType(decltype(type) t) : type(t) {}
constexpr explicit OpType(underlying_t val) : type((decltype(type))val) {}
constexpr underlying_t underlying() const { return type; }
bool operator==(OpType others) const { return type == others.type; }
bool operator!=(OpType others) const { return type != others.type; }
bool operator<(OpType others) const { return type < others.type; }
const char *toString() const;
bool isUnary() const;
bool isBinary() const;
bool isElementWise() const;
bool isCompair() const;
bool isPool() const;
bool isGlobalPool() const;
bool isMatMulOrConv() const;
};
enum class ActType {
None,
Relu,
Sigmoid,
Tanh,
};
} // namespace infini
#endif // OP_TYPE_H
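
As a quick orientation for reviewers, here is a minimal usage sketch of the OpType interface declared above (my illustration, not part of this commit; it assumes the src/core/op_type.cc implementation further below is linked in):

    #include "core/op_type.h"
    #include <cstdio>

    int main() {
        infini::OpType op = infini::OpType::Relu; // implicit construction from the enumerator
        std::printf("%s\n", op.toString());       // prints "Relu"
        std::printf("%d\n", op.isUnary());        // 1, since Relu is in the unary set
        // KernelAttrs and OpPerfKey now store this raw integer instead of an enum value:
        infini::OpType::underlying_t id = op.underlying();
        infini::OpType roundTrip{id};             // explicit construction from the raw value
        return op == roundTrip ? 0 : 1;           // 0: the round trip preserves the type
    }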


@ -1,231 +1,14 @@
#pragma once #pragma once
#include "core/op_type.h"
#include "core/tensor.h" #include "core/tensor.h"
namespace infini { namespace infini {
using KernelAttrs = std::tuple<Device, OpType::underlying_t, DataType>;
enum class OpType {
Unknown = 0,
// linear
Conv = 100,
ConvBackwardFilter,
ConvBackwardData,
Matmul,
ConvTrans,
ConvTransNHWC,
G2BMM,
GBMM,
Pad,
Slice,
Concat,
Split,
Transpose,
Extend,
MaxPool,
AvgPool,
Add,
Sub,
Mul,
Div,
Pow,
Gather,
ReduceMean,
Reshape,
Flatten,
Identity,
// element wise
BatchNorm = 200,
Softmax,
Activation,
Relu,
ReluBackward,
PRelu,
Sigmoid,
SigmoidBackward,
Tanh,
TanhBackward,
Abs,
Sin,
Cos,
Tan,
ASin,
ACos,
ATan,
SinH,
CosH,
TanH,
ASinH,
ACosH,
ATanH,
Resize,
Arange,
Shape,
Copy,
Ceil,
Floor,
Clip,
Erf,
Exp,
Fill,
Log,
L2Loss,
Maximum,
Minimum,
MSELoss,
Neg,
Power,
Reciprocal,
Sqrt,
Rsqrt,
Cast,
FloorDiv,
FloorMod,
Det,
Round,
Square,
SquaredDifference,
Hardtanh,
Equal,
NotEqual,
GreaterThan,
GreaterEqual,
LessThan,
LessEqual,
And,
Or,
Xor,
Not,
BitAnd,
BitOr,
BitXor,
BitNot,
BitLeftShift,
BitRightShift,
Dropout,
//
MemBound = 300,
};
using KernelAttrs = std::tuple<Device, OpType, DataType>;
class OpRegistry {
public:
static std::string getOpName(OpType opType) {
#define FOP(op) \
case OpType::op: \
return #op
switch (opType) {
FOP(Unknown);
// linear
FOP(Conv);
FOP(ConvBackwardFilter);
FOP(ConvBackwardData);
FOP(Matmul);
FOP(ConvTrans);
FOP(G2BMM);
FOP(GBMM);
FOP(Pad);
FOP(Slice);
FOP(Concat);
FOP(Split);
FOP(Transpose);
FOP(Extend);
FOP(MaxPool);
FOP(AvgPool);
FOP(Add);
FOP(Sub);
FOP(Mul);
FOP(Div);
FOP(Pow);
FOP(Gather);
FOP(ReduceMean);
FOP(Reshape);
FOP(Identity);
FOP(Shape);
// element wise
FOP(BatchNorm);
FOP(Softmax);
FOP(Activation);
FOP(Relu);
FOP(ReluBackward);
FOP(PRelu);
FOP(Sigmoid);
FOP(SigmoidBackward);
FOP(Tanh);
FOP(TanhBackward);
FOP(Abs);
FOP(Sin);
FOP(Cos);
FOP(Tan);
FOP(ASin);
FOP(ACos);
FOP(ATan);
FOP(SinH);
FOP(CosH);
FOP(TanH);
FOP(ASinH);
FOP(ACosH);
FOP(ATanH);
FOP(Copy);
FOP(Ceil);
FOP(Floor);
FOP(Clip);
FOP(Erf);
FOP(Exp);
FOP(Fill);
FOP(Log);
FOP(L2Loss);
FOP(Maximum);
FOP(Minimum);
FOP(MSELoss);
FOP(Neg);
FOP(Power);
FOP(Reciprocal);
FOP(Sqrt);
FOP(Rsqrt);
FOP(Cast);
FOP(FloorDiv);
FOP(FloorMod);
FOP(Det);
FOP(Round);
FOP(Square);
FOP(SquaredDifference);
FOP(Hardtanh);
FOP(Equal);
FOP(NotEqual);
FOP(GreaterThan);
FOP(GreaterEqual);
FOP(LessThan);
FOP(LessEqual);
FOP(And);
FOP(Or);
FOP(Xor);
FOP(Not);
FOP(BitAnd);
FOP(BitOr);
FOP(BitXor);
FOP(BitNot);
FOP(BitLeftShift);
FOP(BitRightShift);
//
FOP(MemBound);
default:
IT_ASSERT(false);
break;
}
#undef FOP
}
};
enum class ActType {
None,
Relu,
Sigmoid,
Tanh,
};
struct OpPerfKey { struct OpPerfKey {
HashType hash; HashType hash;
OpType opType; OpType::underlying_t opType;
vector<int> attrs; vector<int> attrs;
public: public:
@ -233,7 +16,7 @@ struct OpPerfKey {
// https://github.com/nlohmann/json#how-can-i-use-get-for-non-default-constructiblenon-copyable-types // https://github.com/nlohmann/json#how-can-i-use-get-for-non-default-constructiblenon-copyable-types
OpPerfKey() = default; OpPerfKey() = default;
OpPerfKey(HashType hash, OpType opType, vector<int> attrs = {}) OpPerfKey(HashType hash, OpType opType, vector<int> attrs = {})
: hash(hash), opType(opType), attrs(attrs) {} : hash(hash), opType(opType.underlying()), attrs(attrs) {}
bool operator==(const OpPerfKey &rhs) const { bool operator==(const OpPerfKey &rhs) const {
if (hash != rhs.hash) if (hash != rhs.hash)
return false; return false;
@ -290,16 +73,7 @@ class OperatorObj : public Object {
*/ */
HashType hash() const; HashType hash() const;
public: // check Op type public:
bool isLinearOp() const;
bool isElementWiseOp() const;
bool isSplitOp() const;
bool isConcatOp() const;
bool isComputeOp() const;
bool isTransposeOp() const;
bool isReshapeOp() const;
bool isMemBoundOp() const;
public: // getter and setter public: // getter and setter
const TensorVec &getInputs() const { return inputs; } const TensorVec &getInputs() const { return inputs; }
const TensorVec &getOutputs() const { return outputs; } const TensorVec &getOutputs() const { return outputs; }


@ -1,5 +1,6 @@
#pragma once #pragma once
#include "core/common.h" #include "core/common.h"
#include "core/op_type.h"
#include "core/ref.h" #include "core/ref.h"
#include <memory> #include <memory>
@ -21,7 +22,6 @@ using Graph = Ref<GraphObj>;
using GraphHandler = Ref<GraphHandlerObj>; using GraphHandler = Ref<GraphHandlerObj>;
using Runtime = Ref<RuntimeObj>; using Runtime = Ref<RuntimeObj>;
using Blob = Ref<BlobObj>; using Blob = Ref<BlobObj>;
enum class OpType;
using TensorVec = vector<Tensor>; using TensorVec = vector<Tensor>;
using OpVec = vector<Operator>; using OpVec = vector<Operator>;


@ -65,26 +65,24 @@ DEFINE_ELEMENT_WISE_OBJ(Sub, OpType::Sub)
DEFINE_ELEMENT_WISE_OBJ(Mul, OpType::Mul) DEFINE_ELEMENT_WISE_OBJ(Mul, OpType::Mul)
DEFINE_ELEMENT_WISE_OBJ(Div, OpType::Div) DEFINE_ELEMENT_WISE_OBJ(Div, OpType::Div)
DEFINE_ELEMENT_WISE_OBJ(Pow, OpType::Pow) DEFINE_ELEMENT_WISE_OBJ(Pow, OpType::Pow)
DEFINE_ELEMENT_WISE_OBJ(Maximum, OpType::Maximum) DEFINE_ELEMENT_WISE_OBJ(Maximum, OpType::Max)
DEFINE_ELEMENT_WISE_OBJ(Minimum, OpType::Minimum) DEFINE_ELEMENT_WISE_OBJ(Minimum, OpType::Min)
DEFINE_ELEMENT_WISE_OBJ(Power, OpType::Power) DEFINE_ELEMENT_WISE_OBJ(Power, OpType::Pow)
DEFINE_ELEMENT_WISE_OBJ(FloorDiv, OpType::FloorDiv) DEFINE_ELEMENT_WISE_OBJ(FloorDiv, OpType::FloorDiv)
DEFINE_ELEMENT_WISE_OBJ(FloorMod, OpType::FloorMod) DEFINE_ELEMENT_WISE_OBJ(FloorMod, OpType::FloorMod)
DEFINE_ELEMENT_WISE_OBJ(SquaredDifference, OpType::SquaredDifference) DEFINE_ELEMENT_WISE_OBJ(SquaredDifference, OpType::SquaredDifference)
DEFINE_ELEMENT_WISE_OBJ(Equal, OpType::Equal) DEFINE_ELEMENT_WISE_OBJ(Equal, OpType::Equal)
DEFINE_ELEMENT_WISE_OBJ(NotEqual, OpType::NotEqual) DEFINE_ELEMENT_WISE_OBJ(GreaterThan, OpType::Greater)
DEFINE_ELEMENT_WISE_OBJ(GreaterThan, OpType::GreaterThan) DEFINE_ELEMENT_WISE_OBJ(GreaterEqual, OpType::GreaterOrEqual)
DEFINE_ELEMENT_WISE_OBJ(GreaterEqual, OpType::GreaterEqual) DEFINE_ELEMENT_WISE_OBJ(LessThan, OpType::Less)
DEFINE_ELEMENT_WISE_OBJ(LessThan, OpType::LessThan) DEFINE_ELEMENT_WISE_OBJ(LessEqual, OpType::LessOrEqual)
DEFINE_ELEMENT_WISE_OBJ(LessEqual, OpType::LessEqual)
DEFINE_ELEMENT_WISE_OBJ(And, OpType::And) DEFINE_ELEMENT_WISE_OBJ(And, OpType::And)
DEFINE_ELEMENT_WISE_OBJ(Or, OpType::Or) DEFINE_ELEMENT_WISE_OBJ(Or, OpType::Or)
DEFINE_ELEMENT_WISE_OBJ(Xor, OpType::Xor) DEFINE_ELEMENT_WISE_OBJ(Xor, OpType::Xor)
DEFINE_ELEMENT_WISE_OBJ(Not, OpType::Not) DEFINE_ELEMENT_WISE_OBJ(Not, OpType::Not)
DEFINE_ELEMENT_WISE_OBJ(BitAnd, OpType::BitAnd) DEFINE_ELEMENT_WISE_OBJ(BitAnd, OpType::BitwiseAnd)
DEFINE_ELEMENT_WISE_OBJ(BitOr, OpType::BitOr) DEFINE_ELEMENT_WISE_OBJ(BitOr, OpType::BitwiseOr)
DEFINE_ELEMENT_WISE_OBJ(BitXor, OpType::BitXor) DEFINE_ELEMENT_WISE_OBJ(BitXor, OpType::BitwiseXor)
DEFINE_ELEMENT_WISE_OBJ(BitNot, OpType::BitNot) DEFINE_ELEMENT_WISE_OBJ(BitNot, OpType::BitwiseNot)
DEFINE_ELEMENT_WISE_OBJ(BitLeftShift, OpType::BitLeftShift) DEFINE_ELEMENT_WISE_OBJ(BitLeftShift, OpType::BitShift)
DEFINE_ELEMENT_WISE_OBJ(BitRightShift, OpType::BitRightShift)
}; // namespace infini }; // namespace infini


@ -70,7 +70,7 @@ class AvgPoolObj : public PoolingObj {
public: public:
AvgPoolObj(GraphObj *graph, Tensor input, Tensor output, int kh, int kw, AvgPoolObj(GraphObj *graph, Tensor input, Tensor output, int kh, int kw,
int dh, int dw, int ph, int pw, int sh, int sw) int dh, int dw, int ph, int pw, int sh, int sw)
: PoolingObj(graph, OpType::AvgPool, input, output, kh, kw, dh, dw, ph, : PoolingObj(graph, OpType::AveragePool, input, output, kh, kw, dh, dw,
pw, sh, sw) {} ph, pw, sh, sw) {}
}; };
}; // namespace infini }; // namespace infini


@ -197,27 +197,6 @@ class CumsumObj : public OperatorObj {
vector<int> getOpAttrVector() const override; vector<int> getOpAttrVector() const override;
}; };
class ArangeObj : public OperatorObj {
public:
ArangeObj(GraphObj *graph, float start, float step, int length,
Tensor output);
OP_CLONE(ArangeObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;
int numInputs() const override { return 0; }
int numOutputs() const override { return 1; }
float getStartValue() { return startValue; }
float getStepValue() { return stepValue; }
int getLength() { return lengthValue; }
private:
float startValue, stepValue;
int lengthValue;
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
class ShapeObj : public OperatorObj { class ShapeObj : public OperatorObj {
public: public:
ShapeObj(GraphObj *graph, Tensor input, Tensor output); ShapeObj(GraphObj *graph, Tensor input, Tensor output);
@ -283,17 +262,16 @@ DEFINE_UNARY_OBJ(Abs, OpType::Abs)
DEFINE_UNARY_OBJ(Sin, OpType::Sin) DEFINE_UNARY_OBJ(Sin, OpType::Sin)
DEFINE_UNARY_OBJ(Cos, OpType::Cos) DEFINE_UNARY_OBJ(Cos, OpType::Cos)
DEFINE_UNARY_OBJ(Tan, OpType::Tan) DEFINE_UNARY_OBJ(Tan, OpType::Tan)
DEFINE_UNARY_OBJ(ASin, OpType::ASin) DEFINE_UNARY_OBJ(ASin, OpType::Asin)
DEFINE_UNARY_OBJ(ACos, OpType::ACos) DEFINE_UNARY_OBJ(ACos, OpType::Acos)
DEFINE_UNARY_OBJ(ATan, OpType::ATan) DEFINE_UNARY_OBJ(ATan, OpType::Atan)
DEFINE_UNARY_OBJ(SinH, OpType::SinH) DEFINE_UNARY_OBJ(SinH, OpType::Sinh)
DEFINE_UNARY_OBJ(CosH, OpType::CosH) DEFINE_UNARY_OBJ(CosH, OpType::Cosh)
DEFINE_UNARY_OBJ(TanH, OpType::TanH) DEFINE_UNARY_OBJ(TanH, OpType::Tanh)
DEFINE_UNARY_OBJ(ASinH, OpType::ASinH) DEFINE_UNARY_OBJ(ASinH, OpType::Asinh)
DEFINE_UNARY_OBJ(ACosH, OpType::ACosH) DEFINE_UNARY_OBJ(ACosH, OpType::Acosh)
DEFINE_UNARY_OBJ(ATanH, OpType::ATanH) DEFINE_UNARY_OBJ(ATanH, OpType::Atanh)
DEFINE_UNARY_OBJ(Copy, OpType::Copy)
DEFINE_UNARY_OBJ(Ceil, OpType::Ceil) DEFINE_UNARY_OBJ(Ceil, OpType::Ceil)
DEFINE_UNARY_OBJ(Floor, OpType::Floor) DEFINE_UNARY_OBJ(Floor, OpType::Floor)
DEFINE_UNARY_OBJ(Erf, OpType::Erf) DEFINE_UNARY_OBJ(Erf, OpType::Erf)
@ -301,7 +279,5 @@ DEFINE_UNARY_OBJ(Exp, OpType::Exp)
DEFINE_UNARY_OBJ(Neg, OpType::Neg) DEFINE_UNARY_OBJ(Neg, OpType::Neg)
DEFINE_UNARY_OBJ(Reciprocal, OpType::Reciprocal) DEFINE_UNARY_OBJ(Reciprocal, OpType::Reciprocal)
DEFINE_UNARY_OBJ(Sqrt, OpType::Sqrt) DEFINE_UNARY_OBJ(Sqrt, OpType::Sqrt)
DEFINE_UNARY_OBJ(Rsqrt, OpType::Rsqrt)
DEFINE_UNARY_OBJ(Round, OpType::Round) DEFINE_UNARY_OBJ(Round, OpType::Round)
DEFINE_UNARY_OBJ(Square, OpType::Square)
}; // namespace infini }; // namespace infini


@ -196,7 +196,7 @@ class OnnxStub:
attributes[name] attributes[name]
for name in ["momentum", "epsilon", "training_mode"] for name in ["momentum", "epsilon", "training_mode"]
) )
tensors[node.output[0]] = self.handler.batchNorm( tensors[node.output[0]] = self.handler.batchNormalization(
input, output, mean, var, scale, bias, momentum, eps, training != 0 input, output, mean, var, scale, bias, momentum, eps, training != 0
) )
elif node.op_type == "MaxPool": elif node.op_type == "MaxPool":
@ -551,7 +551,7 @@ class OnnxStub:
# saves object names, including tensors and operators # saves object names, including tensors and operators
names: Dict[Union[backend.Tensor, backend.Operator], str] = dict() names: Dict[Union[backend.Tensor, backend.Operator], str] = dict()
# counts the occurrence times of each operator for naming # counts the occurrence times of each operator for naming
count_op: Dict[backend.OpType, int] = dict() count_op: Dict[backend.OpTypeId, int] = dict()
# counts input and output tensors for naming # counts input and output tensors for naming
count_in, count_out = 0, 0 count_in, count_out = 0, 0
# saves nodes (operators) # saves nodes (operators)
@ -563,8 +563,8 @@ class OnnxStub:
# saves global input tensors # saves global input tensors
initializers: List[TensorProto] = [] initializers: List[TensorProto] = []
def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]: def name_op(self, op: backend.Operator) -> Tuple[backend.OpTypeId, str]:
ty = op.op_type() ty = op.op_type().id()
name = "{}{}".format(ty.name, self.count_op.setdefault(ty, 0) + 1) name = "{}{}".format(ty.name, self.count_op.setdefault(ty, 0) + 1)
self.names[op] = name self.names[op] = name
self.count_op[ty] += 1 self.count_op[ty] += 1
@ -647,7 +647,7 @@ class OnnxStub:
ctx.push_output("{}_{}".format(name, i), it) ctx.push_output("{}_{}".format(name, i), it)
for (i, it) in enumerate(op.outputs()) for (i, it) in enumerate(op.outputs())
] ]
if ty == backend.OpType.Conv: if ty == backend.OpTypeId.Conv:
ph, pw, dh, dw, sh, sw = backend.conv_attrs_of(op) ph, pw, dh, dw, sh, sw = backend.conv_attrs_of(op)
ctx.push_node( ctx.push_node(
make_node( make_node(
@ -661,11 +661,11 @@ class OnnxStub:
group=op.inputs()[0].shape()[1] // op.inputs()[1].shape()[1], group=op.inputs()[0].shape()[1] // op.inputs()[1].shape()[1],
) )
) )
elif ty == backend.OpType.ConvTrans: elif ty == backend.OpTypeId.ConvTranspose:
ph, pw, sh, sw, dh, dw, oph, opw = backend.conv_trans_attrs_of(op) ph, pw, sh, sw, dh, dw, oph, opw = backend.conv_trans_attrs_of(op)
ctx.push_node( ctx.push_node(
make_node( make_node(
"ConvTranspose", ty.name,
inputs, inputs,
outputs, outputs,
name, name,
@ -675,14 +675,14 @@ class OnnxStub:
output_padding=[oph, opw], output_padding=[oph, opw],
) )
) )
elif ty == backend.OpType.Matmul: elif ty == backend.OpTypeId.MatMul:
transA, transB = backend.matmul_attrs_of(op) transA, transB = backend.matmul_attrs_of(op)
ctx.push_node( ctx.push_node(
make_node( make_node(
"Gemm", inputs, outputs, name, transA=transA, transB=transB "Gemm", inputs, outputs, name, transA=transA, transB=transB
) )
) )
elif ty == backend.OpType.BatchNorm: elif ty == backend.OpTypeId.BatchNormalization:
inputs = [inputs[i] for i in [0, 3, 4, 1, 2]] inputs = [inputs[i] for i in [0, 3, 4, 1, 2]]
momentum, eps, training = backend.batch_norm_attrs_of(op) momentum, eps, training = backend.batch_norm_attrs_of(op)
ctx.push_node( ctx.push_node(
@ -696,7 +696,7 @@ class OnnxStub:
training_mode=training, training_mode=training,
) )
) )
elif ty == backend.OpType.MaxPool: elif ty == backend.OpTypeId.MaxPool:
kh, kw, dh, dw, ph, pw, sh, sw = backend.pool_attrs_of(op) kh, kw, dh, dw, ph, pw, sh, sw = backend.pool_attrs_of(op)
ctx.push_node( ctx.push_node(
make_node( make_node(
@ -710,7 +710,7 @@ class OnnxStub:
strides=[sh, sw], strides=[sh, sw],
) )
) )
elif ty == backend.OpType.AvgPool: elif ty == backend.OpTypeId.AveragePool:
kh, kw, dh, dw, ph, pw, sh, sw = backend.pool_attrs_of(op) kh, kw, dh, dw, ph, pw, sh, sw = backend.pool_attrs_of(op)
ctx.push_node( ctx.push_node(
make_node( make_node(
@ -724,27 +724,27 @@ class OnnxStub:
) )
) )
elif ty in [ elif ty in [
backend.OpType.Add, backend.OpTypeId.Add,
backend.OpType.Sub, backend.OpTypeId.Sub,
backend.OpType.Mul, backend.OpTypeId.Mul,
backend.OpType.Div, backend.OpTypeId.Div,
backend.OpType.Pow, backend.OpTypeId.Pow,
backend.OpType.Relu, backend.OpTypeId.Relu,
backend.OpType.Sigmoid, backend.OpTypeId.Sigmoid,
backend.OpType.Tanh, backend.OpTypeId.Tanh,
backend.OpType.Softmax, backend.OpTypeId.Softmax,
backend.OpType.Abs, backend.OpTypeId.Abs,
backend.OpType.Identity, backend.OpTypeId.Identity,
backend.OpType.PRelu, backend.OpTypeId.PRelu,
]: ]:
ctx.push_node(make_node(ty.name, inputs, outputs, name)) ctx.push_node(make_node(ty.name, inputs, outputs, name))
elif ty == backend.OpType.Flatten: elif ty == backend.OpTypeId.Flatten:
axis = backend.flatten_axis_of(op) axis = backend.flatten_axis_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis)) ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
elif ty == backend.OpType.Transpose: elif ty == backend.OpTypeId.Transpose:
perm = backend.transpose_permute_of(op) perm = backend.transpose_permute_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, perm=perm)) ctx.push_node(make_node(ty.name, inputs, outputs, name, perm=perm))
elif ty == backend.OpType.Reshape: elif ty == backend.OpTypeId.Reshape:
shape = backend.reshape_shape_of(op) shape = backend.reshape_shape_of(op)
inputs.append( inputs.append(
ctx.push_data_input( ctx.push_data_input(
@ -756,10 +756,10 @@ class OnnxStub:
) )
) )
ctx.push_node(make_node(ty.name, inputs, outputs, name)) ctx.push_node(make_node(ty.name, inputs, outputs, name))
elif ty == backend.OpType.Concat: elif ty == backend.OpTypeId.Concat:
axis = backend.concat_axis_of(op) axis = backend.concat_axis_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis)) ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
elif ty == backend.OpType.Split: elif ty == backend.OpTypeId.Split:
axis = backend.split_axis_of(op) axis = backend.split_axis_of(op)
num_outputs = len(outputs) num_outputs = len(outputs)
split = op.inputs()[0].shape()[axis] // num_outputs split = op.inputs()[0].shape()[axis] // num_outputs
@ -781,10 +781,10 @@ class OnnxStub:
axis=axis, axis=axis,
) )
) )
elif ty == backend.OpType.Gather: elif ty == backend.OpTypeId.Gather:
axis = backend.gather_axis_of(op) axis = backend.gather_axis_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis)) ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
elif ty == backend.OpType.ReduceMean: elif ty == backend.OpTypeId.ReduceMean:
axes, keepdims = backend.reduce_mean_attrs_of(op) axes, keepdims = backend.reduce_mean_attrs_of(op)
inputs.append( inputs.append(
ctx.push_data_input( ctx.push_data_input(
@ -794,9 +794,9 @@ class OnnxStub:
ctx.push_node( ctx.push_node(
make_node(ty.name, inputs, outputs, name, keepdims=keepdims) make_node(ty.name, inputs, outputs, name, keepdims=keepdims)
) )
elif ty == backend.OpType.Slice: elif ty == backend.OpTypeId.Slice:
raise Exception("TODO") raise Exception("TODO")
elif ty == backend.OpType.Pad: elif ty == backend.OpTypeId.Pad:
pads = backend.pad_pads_of(op) pads = backend.pad_pads_of(op)
inputs.append( inputs.append(
ctx.push_data_input( ctx.push_data_input(
@ -804,7 +804,7 @@ class OnnxStub:
) )
) )
ctx.push_node(make_node(ty.name, inputs, outputs, name)) ctx.push_node(make_node(ty.name, inputs, outputs, name))
elif ty == backend.OpType.Clip: elif ty == backend.OpTypeId.Clip:
min, max = backend.clip_attrs_of(op) min, max = backend.clip_attrs_of(op)
if min != None: if min != None:
inputs.append( inputs.append(


@ -108,7 +108,7 @@ class TestStringMethods(unittest.TestCase):
name="batchNormalization", name="batchNormalization",
) )
make_and_import_model( make_and_import_model(
make_graph([batch_norm], "batchNorm", [x, scale, b, mean, var], [y]) make_graph([batch_norm], "batchNormalzation", [x, scale, b, mean, var], [y])
) )
def test_max_pool(self): def test_max_pool(self):


@ -13,7 +13,8 @@ void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
std::map<OpType, int> opCnt; std::map<OpType, int> opCnt;
for (auto &op : graph->getOperators()) { for (auto &op : graph->getOperators()) {
// HACK: set correct data type // HACK: set correct data type
auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()}; auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying(),
DataType::Float32};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey); auto perfData = perfEngine.getPerfData(perfKey);


@@ -48,7 +48,7 @@ bool DummyMutator::isMultiBranchMergable(const Graph &inGraph) {
     if (inGraph->getOperators().size() != 2)
         return false;
     for (auto op : inGraph->getOperators()) {
-        if (op->getOpType() != OpType::Matmul)
+        if (op->getOpType() != OpType::MatMul)
             return false;
         if (op->getPredecessors().size() > 0)
             return false;


@ -116,7 +116,7 @@ bool GraphObj::topo_sort() {
void GraphObj::optimize() { void GraphObj::optimize() {
for (auto &op : ops) { for (auto &op : ops) {
switch (op->getOpType()) { switch (op->getOpType().underlying()) {
default: default:
break; break;
} }
@ -151,7 +151,7 @@ TensorVec GraphObj::addTensor(const TensorVec &tensors) {
OpVec GraphObj::getComputeOps() const { OpVec GraphObj::getComputeOps() const {
OpVec opList; OpVec opList;
for (auto op : ops) for (auto op : ops)
if (op->isComputeOp()) if (op->getOpType().isMatMulOrConv())
opList.emplace_back(op); opList.emplace_back(op);
return opList; return opList;
} }


@ -69,9 +69,11 @@ Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA,
} }
} }
Tensor GraphHandlerObj::batchNorm(Tensor input, Tensor output, Tensor mean, Tensor GraphHandlerObj::batchNormalization(Tensor input, Tensor output,
Tensor var, Tensor scale, Tensor bias, Tensor mean, Tensor var,
float momentum, float eps, bool training) { Tensor scale, Tensor bias,
float momentum, float eps,
bool training) {
if (output) { if (output) {
g->addOpWithOutputs<BatchNormObj>( g->addOpWithOutputs<BatchNormObj>(
std::move(input), output, std::move(mean), std::move(var), std::move(input), output, std::move(mean), std::move(var),

src/core/op_type.cc (new file, 278 lines added)

@ -0,0 +1,278 @@
#include "core/op_type.h"
namespace infini {
const char *OpType::toString() const {
#define CASE(NAME) \
case OpType::NAME: \
return #NAME
switch (type) {
CASE(Unknown);
CASE(Abs);
CASE(Acos);
CASE(Acosh);
CASE(Add);
CASE(And);
CASE(ArgMax);
CASE(Asin);
CASE(Asinh);
CASE(Atan);
CASE(Atanh);
CASE(AveragePool);
CASE(BatchNormalization);
CASE(Bernoulli);
CASE(BitShift);
CASE(BitwiseAnd);
CASE(BitwiseNot);
CASE(BitwiseOr);
CASE(BitwiseXor);
CASE(BlackmanWindow);
CASE(Cast);
CASE(CastLike);
CASE(Ceil);
CASE(Celu);
CASE(CenterCropPad);
CASE(Clip);
CASE(Col2lm);
CASE(Compress);
CASE(Concat);
CASE(ConcatFromSequence);
CASE(ConstantOfShape);
CASE(Conv);
CASE(ConvInteger);
CASE(ConvTranspose);
CASE(Cos);
CASE(Cosh);
CASE(CumSum);
CASE(DFT);
CASE(DeformConv);
CASE(DepthToSpace);
CASE(DequantizeLinear);
CASE(Det);
CASE(Div);
CASE(Dropout);
CASE(DynamicQuantizeLinear);
CASE(Einsum);
CASE(Elu);
CASE(Equal);
CASE(Erf);
CASE(Exp);
CASE(Expand);
CASE(EyeLike);
CASE(Flatten);
CASE(Floor);
CASE(GRU);
CASE(Gather);
CASE(GatherElements);
CASE(GatherND);
CASE(Gemm);
CASE(GlobalAveragePool);
CASE(GlobalLpPool);
CASE(GlobalMaxPool);
CASE(Greater);
CASE(GreaterOrEqual);
CASE(GridSample);
CASE(GroupNormalization);
CASE(HammingWindow);
CASE(HannWindow);
CASE(HardSigmoid);
CASE(HardSwish);
CASE(Hardmax);
CASE(Identity);
CASE(If);
CASE(InstanceNormalization);
CASE(IsInf);
CASE(IsNaN);
CASE(LRN);
CASE(LSTM);
CASE(LayerNormalization);
CASE(LeakyRelu);
CASE(Less);
CASE(LessOrEqual);
CASE(Log);
CASE(LogSoftmax);
CASE(Loop);
CASE(LpNormalization);
CASE(LpPool);
CASE(MatMul);
CASE(MatMulInteger);
CASE(Max);
CASE(MaxPool);
CASE(MaxRoiPool);
CASE(MaxUnpool);
CASE(Mean);
CASE(MeanVarianceNormalization);
CASE(MelWeightMatrix);
CASE(Min);
CASE(Mish);
CASE(Mod);
CASE(Mul);
CASE(Multinomial);
CASE(Neg);
CASE(NegativeLogLikelihoodLoss);
CASE(NonMaxSuppression);
CASE(NonZero);
CASE(Not);
CASE(OneHot);
CASE(Optional);
CASE(OptionalGetElement);
CASE(OptionalHasElement);
CASE(Or);
CASE(PRelu);
CASE(Pad);
CASE(Pow);
CASE(QLinearConv);
CASE(QLinearMatMul);
CASE(QuantizeLinear);
CASE(RNN);
CASE(RandomNormal);
CASE(RandomNormalLike);
CASE(RandomUniform);
CASE(RandomUniformLike);
CASE(Range);
CASE(Reciprocal);
CASE(ReduceL1);
CASE(ReduceL2);
CASE(ReduceLogSum);
CASE(ReduceLogSumExp);
CASE(ReduceMax);
CASE(ReduceMean);
CASE(ReduceMin);
CASE(ReduceProd);
CASE(ReduceSum);
CASE(ReduceSumSquare);
CASE(Relu);
CASE(Reshape);
CASE(Resize);
CASE(ReverseSequence);
CASE(RoiAlign);
CASE(Round);
CASE(STFT);
CASE(Scan);
CASE(Scatter);
CASE(ScatterElements);
CASE(ScatterND);
CASE(Selu);
CASE(SequenceAt);
CASE(SequenceConstruct);
CASE(SequenceEmpty);
CASE(SequenceErase);
CASE(SequenceInsert);
CASE(SequenceLength);
CASE(SequenceMap);
CASE(Shape);
CASE(Shrink);
CASE(Sigmoid);
CASE(Sign);
CASE(Sin);
CASE(Sinh);
CASE(Size);
CASE(Slice);
CASE(Softmax);
CASE(SoftmaxCrossEntropyLoss);
CASE(Softplus);
CASE(Softsign);
CASE(SpaceToDepth);
CASE(Split);
CASE(SplitToSequence);
CASE(Sqrt);
CASE(Squeeze);
CASE(StringNormalizer);
CASE(Sub);
CASE(Sum);
CASE(Tan);
CASE(Tanh);
CASE(TfIdfVectorizer);
CASE(ThresholdedRelu);
CASE(Tile);
CASE(TopK);
CASE(Transpose);
CASE(Trilu);
CASE(Unique);
CASE(Unsqueeze);
CASE(Upsample);
CASE(Where);
CASE(Xor);
// CUSTOM DEFINED
CASE(G2BMM);
CASE(GBMM);
CASE(MemBound);
// TODO
CASE(ConvTransNHWC);
CASE(ConvBackwardFilter);
CASE(ReluBackward);
CASE(SigmoidBackward);
CASE(TanhBackward);
CASE(Fill);
CASE(Extend);
CASE(MSELoss);
CASE(Hardtanh);
CASE(L2Loss);
CASE(Rsqrt);
CASE(FloorDiv);
CASE(FloorMod);
CASE(Square);
CASE(SquaredDifference);
default:
return "Unknown";
}
#undef CASE
}
bool OpType::isUnary() const {
static const std::unordered_set<decltype(type)> set{
Abs, Acos, Acosh, Asin, Asinh, Atan, Atanh, Cast, Ceil,
Clip, Cos, Cosh, Erf, Exp, Floor, Log, Neg, Not,
Relu, Round, Sigmoid, Sin, Sinh, Sqrt, Tan, Tanh,
};
return set.find(type) != set.end();
}
bool OpType::isBinary() const {
static const std::unordered_set<decltype(type)> set{
Add, And, BitShift, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor,
Div, Mod, Mul, Or, Pow, Sub, Xor,
};
return set.find(type) != set.end() || isCompair();
}
bool OpType::isElementWise() const { return isUnary() || isBinary(); }
bool OpType::isCompair() const {
static const std::unordered_set<decltype(type)> set{
Equal, Greater, GreaterOrEqual, Less, LessOrEqual,
};
return set.find(type) != set.end();
}
bool OpType::isPool() const {
static const std::unordered_set<decltype(type)> set{};
return set.find(type) != set.end();
}
bool OpType::isGlobalPool() const {
static const std::unordered_set<decltype(type)> set{
GlobalAveragePool,
GlobalLpPool,
GlobalMaxPool,
};
return set.find(type) != set.end();
}
bool OpType::isMatMulOrConv() const {
static const std::unordered_set<decltype(type)> set{
Conv, ConvInteger, ConvTranspose, DeformConv,
QLinearConv, MatMul, MatMulInteger, QLinearMatMul,
};
return set.find(type) != set.end();
}
} // namespace infini
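
The per-operator type checks removed from OperatorObj later in this commit (isComputeOp(), isElementWiseOp(), and so on) are now answered by these OpType predicates. A hedged sketch of the intended call pattern (my own example with hypothetical helper names, assuming the project's core/operator.h header):

    #include "core/operator.h"

    namespace infini {
    // Replaces the removed OperatorObj::isComputeOp().
    bool isComputeLike(const Operator &op) {
        return op->getOpType().isMatMulOrConv();
    }
    // Unary or binary (comparisons included), per the sets defined above.
    bool isElementWiseLike(const Operator &op) {
        return op->getOpType().isElementWise();
    }
    } // namespace infini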


@ -10,33 +10,6 @@ OperatorObj::OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs)
IT_ASSERT(t); IT_ASSERT(t);
} }
bool OperatorObj::isLinearOp() const {
return enum_to_underlying(type) >= 100 && enum_to_underlying(type) < 200;
}
bool OperatorObj::isElementWiseOp() const {
return enum_to_underlying(type) >= 200 && enum_to_underlying(type) < 300;
}
bool OperatorObj::isSplitOp() const { return type == OpType::Split; }
bool OperatorObj::isConcatOp() const { return type == OpType::Concat; }
bool OperatorObj::isComputeOp() const {
return type == OpType::Conv || type == OpType::Matmul ||
type == OpType::ConvTrans || type == OpType::ConvTransNHWC ||
type == OpType::G2BMM || type == OpType::GBMM;
}
bool OperatorObj::isTransposeOp() const { return type == OpType::Transpose; }
bool OperatorObj::isReshapeOp() const { return type == OpType::Reshape; }
bool OperatorObj::isMemBoundOp() const {
return type == OpType::MemBound || type == OpType::Activation ||
type == OpType::Transpose;
}
void OperatorObj::removePredecessors(const Operator &op) { void OperatorObj::removePredecessors(const Operator &op) {
for (auto it = predecessors.begin(); it != predecessors.end();) { for (auto it = predecessors.begin(); it != predecessors.end();) {
if (it->lock() == op) if (it->lock() == op)
@ -69,14 +42,14 @@ OpPerfKey OperatorObj::getOpPerfKey() const {
// Operator::hash, which hashes operator attributes and ignores tensor // Operator::hash, which hashes operator attributes and ignores tensor
// shapes. // shapes.
HashType hash = 0; HashType hash = 0;
hash = hashAppend(hash, enum_to_underlying(type)); hash = hashAppend(hash, type.underlying());
hash = hashAppend(hash, hashVector(workloadVector)); hash = hashAppend(hash, hashVector(workloadVector));
return OpPerfKey(hash, type, workloadVector); return OpPerfKey(hash, type, workloadVector);
} }
HashType OperatorObj::hash() const { HashType OperatorObj::hash() const {
HashType hash = 0; HashType hash = 0;
hash = hashAppend(hash, enum_to_underlying(type)); hash = hashAppend(hash, type.underlying());
hash = hashAppend(hash, hashVector(getOpAttrVector())); hash = hashAppend(hash, hashVector(getOpAttrVector()));
return hash; return hash;
} }


@ -17,7 +17,8 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
std::map<OpType, int> opCnt; std::map<OpType, int> opCnt;
for (auto &op : graph->getOperators()) { for (auto &op : graph->getOperators()) {
auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()}; auto kernelAttrs =
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey); auto perfData = perfEngine.getPerfData(perfKey);
@ -65,7 +66,8 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const {
std::map<OpType, int> opCnt; std::map<OpType, int> opCnt;
for (auto &op : graph->getOperators()) { for (auto &op : graph->getOperators()) {
auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()}; auto kernelAttrs =
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey); auto perfData = perfEngine.getPerfData(perfKey);
@ -116,9 +118,8 @@ void RuntimeObj::printProfilingData(double totalTime,
const std::map<OpType, int> &opCnt) const { const std::map<OpType, int> &opCnt) const {
printf("%11s %3s %7s %7s %7s\n", "Op", "Cnt", "T_tot", "Percent", "T_mean"); printf("%11s %3s %7s %7s %7s\n", "Op", "Cnt", "T_tot", "Percent", "T_mean");
for (const auto &[type, t] : opTime) { for (const auto &[type, t] : opTime) {
printf("%11s %3d %7.3f %7.1f %7.3f\n", printf("%11s %3d %7.3f %7.1f %7.3f\n", type.toString(), opCnt.at(type),
OpRegistry::getOpName(type).data(), opCnt.at(type), t, t, t / totalTime * 100, t / opCnt.at(type));
t / totalTime * 100, t / opCnt.at(type));
} }
} }


@ -127,7 +127,7 @@ SearchEngine::buildMetaGraphWithGraph(const Graph graph) {
std::vector<Operator> ops; std::vector<Operator> ops;
ops.emplace_back(op); ops.emplace_back(op);
node.graph = make_ref<GraphObj>(runtimeExec, ops); node.graph = make_ref<GraphObj>(runtimeExec, ops);
node.type = op->isComputeOp(); node.type = op->getOpType().isMatMulOrConv();
node.cnt = op->getPredecessors().size(); node.cnt = op->getPredecessors().size();
opMap.emplace(op->getGuid(), i); opMap.emplace(op->getGuid(), i);
metaGraph->nodes.emplace_back(node); metaGraph->nodes.emplace_back(node);
@ -196,7 +196,7 @@ std::shared_ptr<SearchEngine::MetaGraph> SearchEngine::buildMetaGraphWithPlan(
} }
node.graph = make_ref<GraphObj>(runtimeExec, ops); node.graph = make_ref<GraphObj>(runtimeExec, ops);
node.cnt = node.pre.size(); node.cnt = node.pre.size();
node.type = ops[0]->isComputeOp(); node.type = ops[0]->getOpType().isMatMulOrConv();
resultMetaGraph->nodes.emplace_back(node); resultMetaGraph->nodes.emplace_back(node);
} }
} }
@ -404,7 +404,7 @@ std::vector<Graph> SearchEngine::partitionGraph(const Graph graph) {
headOps.emplace_back(op); headOps.emplace_back(op);
if (op->getPredecessors().size() + op->getSuccessors().size() >= if (op->getPredecessors().size() + op->getSuccessors().size() >=
(size_t)partitionThreshold && (size_t)partitionThreshold &&
!op->isComputeOp()) { !op->getOpType().isMatMulOrConv()) {
auto preOrderI = preOrder[op->getGuid()]; auto preOrderI = preOrder[op->getGuid()];
auto postOrderI = postOrder[op->getGuid()]; auto postOrderI = postOrder[op->getGuid()];
for (size_t j = 0; j < i; j++) { for (size_t j = 0; j < i; j++) {


@ -11,7 +11,8 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
auto &perfEngine = PerfEngine::getInstance(); auto &perfEngine = PerfEngine::getInstance();
for (auto &op : graph->getOperators()) { for (auto &op : graph->getOperators()) {
// HACK: set correct data type // HACK: set correct data type
auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()}; auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying(),
DataType::Float32};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey); auto perfData = perfEngine.getPerfData(perfKey);
@ -32,7 +33,8 @@ void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const {
std::map<OpType, int> opCnt; std::map<OpType, int> opCnt;
for (auto &op : graph->getOperators()) { for (auto &op : graph->getOperators()) {
// HACK: set correct data type // HACK: set correct data type
auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()}; auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying(),
DataType::Float32};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey); auto perfData = perfEngine.getPerfData(perfKey);


@ -48,6 +48,8 @@ void register_operator_timer(py::module &m) {
#endif #endif
} }
decltype(OpType::type) getId(OpType const *const ptr) { return ptr->type; }
void export_values(py::module &m) { void export_values(py::module &m) {
#define VALUE(TYPE, NAME) value(#NAME, TYPE::NAME) #define VALUE(TYPE, NAME) value(#NAME, TYPE::NAME)
@ -58,13 +60,13 @@ void export_values(py::module &m) {
.VALUE(ActType, Tanh) .VALUE(ActType, Tanh)
.export_values(); .export_values();
py::enum_<OpType>(m, "OpType") py::class_<OpType>(m, "OpType")
.VALUE(OpType, Unknown) .def(py::init<decltype(OpType::type)>())
.def("id", getId, policy::automatic);
py::enum_<decltype(OpType::type)>(m, "OpTypeId")
.VALUE(OpType, Conv) .VALUE(OpType, Conv)
.VALUE(OpType, Matmul) .VALUE(OpType, MatMul)
.VALUE(OpType, ConvTrans) .VALUE(OpType, ConvTranspose)
.VALUE(OpType, G2BMM)
.VALUE(OpType, GBMM)
.VALUE(OpType, Pad) .VALUE(OpType, Pad)
.VALUE(OpType, Clip) .VALUE(OpType, Clip)
.VALUE(OpType, Slice) .VALUE(OpType, Slice)
@ -73,7 +75,7 @@ void export_values(py::module &m) {
.VALUE(OpType, Transpose) .VALUE(OpType, Transpose)
.VALUE(OpType, Extend) .VALUE(OpType, Extend)
.VALUE(OpType, MaxPool) .VALUE(OpType, MaxPool)
.VALUE(OpType, AvgPool) .VALUE(OpType, AveragePool)
.VALUE(OpType, Add) .VALUE(OpType, Add)
.VALUE(OpType, Sub) .VALUE(OpType, Sub)
.VALUE(OpType, Mul) .VALUE(OpType, Mul)
@ -84,9 +86,8 @@ void export_values(py::module &m) {
.VALUE(OpType, Reshape) .VALUE(OpType, Reshape)
.VALUE(OpType, Flatten) .VALUE(OpType, Flatten)
.VALUE(OpType, Identity) .VALUE(OpType, Identity)
.VALUE(OpType, BatchNorm) .VALUE(OpType, BatchNormalization)
.VALUE(OpType, Softmax) .VALUE(OpType, Softmax)
.VALUE(OpType, Activation)
.VALUE(OpType, Relu) .VALUE(OpType, Relu)
.VALUE(OpType, PRelu) .VALUE(OpType, PRelu)
.VALUE(OpType, Sigmoid) .VALUE(OpType, Sigmoid)
@ -152,7 +153,7 @@ static std::tuple<int, int, int, int, int, int> conv_attrs_of(Operator op) {
static std::tuple<int, int, int, int, int, int, int, int> static std::tuple<int, int, int, int, int, int, int, int>
conv_trans_attrs_of(Operator op) { conv_trans_attrs_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::ConvTrans); IT_ASSERT(op->getOpType() == OpType::ConvTranspose);
auto conv = dynamic_cast<const ConvTransposed2dObj *>(op.get()); auto conv = dynamic_cast<const ConvTransposed2dObj *>(op.get());
auto [oph, opw] = conv->getOutputPadding(); auto [oph, opw] = conv->getOutputPadding();
return std::make_tuple(conv->getPh(), conv->getPw(), conv->getDh(), return std::make_tuple(conv->getPh(), conv->getPw(), conv->getDh(),
@ -161,13 +162,13 @@ conv_trans_attrs_of(Operator op) {
} }
static std::tuple<bool, bool> matmul_attrs_of(Operator op) { static std::tuple<bool, bool> matmul_attrs_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::Matmul); IT_ASSERT(op->getOpType() == OpType::MatMul);
auto matmul = dynamic_cast<const MatmulObj *>(op.get()); auto matmul = dynamic_cast<const MatmulObj *>(op.get());
return std::make_tuple(matmul->getTransA(), matmul->getTransB()); return std::make_tuple(matmul->getTransA(), matmul->getTransB());
} }
static std::tuple<float, float, bool> batch_norm_attrs_of(Operator op) { static std::tuple<float, float, bool> batch_norm_attrs_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::BatchNorm); IT_ASSERT(op->getOpType() == OpType::BatchNormalization);
auto batchnorm = dynamic_cast<const BatchNormObj *>(op.get()); auto batchnorm = dynamic_cast<const BatchNormObj *>(op.get());
return std::make_tuple(batchnorm->getMomentum(), batchnorm->getEps(), return std::make_tuple(batchnorm->getMomentum(), batchnorm->getEps(),
batchnorm->getTrainingMode()); batchnorm->getTrainingMode());
@ -176,7 +177,7 @@ static std::tuple<float, float, bool> batch_norm_attrs_of(Operator op) {
static std::tuple<int, int, int, int, int, int, int, int> static std::tuple<int, int, int, int, int, int, int, int>
pool_attrs_of(Operator op) { pool_attrs_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::MaxPool || IT_ASSERT(op->getOpType() == OpType::MaxPool ||
op->getOpType() == OpType::AvgPool); op->getOpType() == OpType::AveragePool);
auto pool = dynamic_cast<const PoolingObj *>(op.get()); auto pool = dynamic_cast<const PoolingObj *>(op.get());
return std::make_tuple(pool->getKh(), pool->getKw(), pool->getDh(), return std::make_tuple(pool->getKh(), pool->getKw(), pool->getDh(),
pool->getDw(), pool->getPh(), pool->getPw(), pool->getDw(), pool->getPh(), pool->getPw(),
@ -319,7 +320,7 @@ void init_graph_builder(py::module &m) {
.def("conv", &Handler::conv, policy::move) .def("conv", &Handler::conv, policy::move)
.def("convTransposed2d", &Handler::convTransposed2d, policy::move) .def("convTransposed2d", &Handler::convTransposed2d, policy::move)
.def("matmul", &Handler::matmul, policy::move) .def("matmul", &Handler::matmul, policy::move)
.def("batchNorm", &Handler::batchNorm, policy::move) .def("batchNormalization", &Handler::batchNormalization, policy::move)
.def("maxPool", &Handler::maxPool, policy::move) .def("maxPool", &Handler::maxPool, policy::move)
.def("avgPool", &Handler::avgPool, policy::move) .def("avgPool", &Handler::avgPool, policy::move)
.def("add", &Handler::add, policy::move) .def("add", &Handler::add, policy::move)


@ -92,43 +92,6 @@ class RoundCnnl : public BangKernelWithoutConfig {
} }
}; };
class SquareCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
cnnlTensorDescriptor_t aDesc, cDesc;
auto dim = op->getInputs(0)->getDims();
if (dim.size() != 4)
IT_TODO_HALT();
int dim_array[4] = {dim[0], dim[1], dim[2], dim[3]};
// get inputs
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, 4, dim_array));
// get outputs
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, 4, dim_array));
cnnlStatus_t stat =
cnnlSquare(context->cnnlHandle(), aDesc, aData, cDesc, cData);
if (stat != CNNL_STATUS_SUCCESS)
return;
// Destories in BANG does not require sync. But cnnl does not state
// whether sync is required before destories.
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
}
};
class PReluCnnl : public BangKernelWithoutConfig { class PReluCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
@ -185,24 +148,13 @@ class SigmoidCnnl : public UnaryCnnl {
float getCoef() const override { return 0.0; } float getCoef() const override { return 0.0; }
}; };
class TanhCnnl : public UnaryCnnl {
cnnlActivationMode_t getOpType() const override {
return CNNL_ACTIVATION_TANH;
}
float getCoef() const override { return 0.0; }
};
REGISTER_KERNEL(Device::BANG, OpType::Relu, DataType::Float32, ReluCnnl, REGISTER_KERNEL(Device::BANG, OpType::Relu, DataType::Float32, ReluCnnl,
"Relu_cnnl_BANG_Float32"); "Relu_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::PRelu, DataType::Float32, PReluCnnl, REGISTER_KERNEL(Device::BANG, OpType::PRelu, DataType::Float32, PReluCnnl,
"PRelu_cnnl_BANG_Float32"); "PRelu_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, DataType::Float32, SigmoidCnnl, REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, DataType::Float32, SigmoidCnnl,
"Sigmoid_cnnl_BANG_Float32"); "Sigmoid_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Tanh, DataType::Float32, TanhCnnl,
"Tanh_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Round, DataType::Float32, RoundCnnl, REGISTER_KERNEL(Device::BANG, OpType::Round, DataType::Float32, RoundCnnl,
"Round_cnnl_BANG_Float32"); "Round_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Square, DataType::Float32, SquareCnnl,
"Square_cnnl_BANG_Float32");
}; // namespace infini }; // namespace infini


@ -65,7 +65,7 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::BANG, OpType::BatchNorm, DataType::Float32, REGISTER_KERNEL(Device::BANG, OpType::BatchNormalization, DataType::Float32,
BatchNormCnnl, "BatchNorm_cnnl_BANG_Float32"); BatchNormCnnl, "BatchNorm_cnnl_BANG_Float32");
}; // namespace infini }; // namespace infini


@ -83,6 +83,6 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::BANG, OpType::ConvTrans, DataType::Float32, REGISTER_KERNEL(Device::BANG, OpType::ConvTranspose, DataType::Float32,
ConvTransCnnl, "ConvTrans_cnnl_BANG_Float32"); ConvTransCnnl, "ConvTrans_cnnl_BANG_Float32");
}; // namespace infini }; // namespace infini


@ -1,46 +0,0 @@
#include "bang/bang_kernel_without_config.h"
#include "bang/bang_runtime.h"
#include "operators/unary.h"
namespace infini {
class CopyCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
cnnlTensorDescriptor_t aDesc, cDesc;
auto dim = op->getInputs(0)->getDims();
if (dim.size() != 4)
IT_TODO_HALT();
int dim_array[4] = {dim[0], dim[1], dim[2], dim[3]};
// get inputs
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, 4, dim_array));
// get outputs
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, 4, dim_array));
cnnlStatus_t stat =
cnnlCopy(context->cnnlHandle(), aDesc, aData, cDesc, cData);
if (stat != CNNL_STATUS_SUCCESS)
return;
// Destories in BANG does not require sync. But cnnl does not state
// whether sync is required before destories.
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
}
};
REGISTER_KERNEL(Device::BANG, OpType::Copy, DataType::Float32, CopyCnnl,
"Copy_cnnl_BANG_Float32");
}; // namespace infini


@ -593,9 +593,6 @@ class MulCnnl : public ElementWiseCnnl {
class EqualCnnl : public LogicOpCnnl { class EqualCnnl : public LogicOpCnnl {
cnnlLogicOp_t getOpType() const override { return CNNL_LOGIC_OP_EQ; } cnnlLogicOp_t getOpType() const override { return CNNL_LOGIC_OP_EQ; }
}; };
class NotEqualCnnl : public LogicOpCnnl {
cnnlLogicOp_t getOpType() const override { return CNNL_LOGIC_OP_NE; }
};
class GreaterThanCnnl : public LogicOpCnnl { class GreaterThanCnnl : public LogicOpCnnl {
cnnlLogicOp_t getOpType() const override { return CNNL_LOGIC_OP_GT; } cnnlLogicOp_t getOpType() const override { return CNNL_LOGIC_OP_GT; }
}; };
@ -651,13 +648,13 @@ REGISTER_KERNEL(Device::BANG, OpType::Mul, DataType::Float32, MulCnnl,
REGISTER_KERNEL(Device::BANG, OpType::Div, DataType::Float32, DivCnnl, REGISTER_KERNEL(Device::BANG, OpType::Div, DataType::Float32, DivCnnl,
"Div_cnnl_Float32"); "Div_cnnl_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Maximum, DataType::Float32, MaximumCnnl, REGISTER_KERNEL(Device::BANG, OpType::Max, DataType::Float32, MaximumCnnl,
"Maximum_cnnl_BANG_Float32"); "Maximum_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Minimum, DataType::Float32, MinimumCnnl, REGISTER_KERNEL(Device::BANG, OpType::Min, DataType::Float32, MinimumCnnl,
"Minimum_cnnl_BANG_Float32"); "Minimum_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::MSELoss, DataType::Float32, MSELossCnnl, REGISTER_KERNEL(Device::BANG, OpType::MSELoss, DataType::Float32, MSELossCnnl,
"MSELoss_cnnl_BANG_Float32"); "MSELoss_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Power, DataType::Float32, PowerCnnl, REGISTER_KERNEL(Device::BANG, OpType::Pow, DataType::Float32, PowerCnnl,
"Power_cnnl_BANG_Float32"); "Power_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::FloorDiv, DataType::Float32, FloorDivCnnl, REGISTER_KERNEL(Device::BANG, OpType::FloorDiv, DataType::Float32, FloorDivCnnl,
"FloorDiv_cnnl_BANG_Float32"); "FloorDiv_cnnl_BANG_Float32");
@ -667,15 +664,13 @@ REGISTER_KERNEL(Device::BANG, OpType::SquaredDifference, DataType::Float32,
SquaredDifferenceCnnl, "SquaredDifference_cnnl_BANG_Float32"); SquaredDifferenceCnnl, "SquaredDifference_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Equal, DataType::Float32, EqualCnnl, REGISTER_KERNEL(Device::BANG, OpType::Equal, DataType::Float32, EqualCnnl,
"Equal_cnnl_BANG_Float32"); "Equal_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::NotEqual, DataType::Float32, NotEqualCnnl, REGISTER_KERNEL(Device::BANG, OpType::Greater, DataType::Float32,
"NotEqual_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::GreaterThan, DataType::Float32,
GreaterThanCnnl, "GreaterThan_cnnl_BANG_Float32"); GreaterThanCnnl, "GreaterThan_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::GreaterEqual, DataType::Float32, REGISTER_KERNEL(Device::BANG, OpType::GreaterOrEqual, DataType::Float32,
GreaterEqualCnnl, "GreaterEqual_cnnl_BANG_Float32"); GreaterEqualCnnl, "GreaterEqual_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::LessThan, DataType::Float32, LessThanCnnl, REGISTER_KERNEL(Device::BANG, OpType::Less, DataType::Float32, LessThanCnnl,
"LessThan_cnnl_BANG_Float32"); "LessThan_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::LessEqual, DataType::Float32, REGISTER_KERNEL(Device::BANG, OpType::LessOrEqual, DataType::Float32,
LessEqualCnnl, "LessEqual_cnnl_BANG_Float32"); LessEqualCnnl, "LessEqual_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::And, DataType::Float32, AndCnnl, REGISTER_KERNEL(Device::BANG, OpType::And, DataType::Float32, AndCnnl,
"And_cnnl_BANG_Float32"); "And_cnnl_BANG_Float32");
@ -685,13 +680,13 @@ REGISTER_KERNEL(Device::BANG, OpType::Xor, DataType::Float32, XorCnnl,
"Xor_cnnl_BANG_Float32"); "Xor_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Not, DataType::Float32, NotCnnl, REGISTER_KERNEL(Device::BANG, OpType::Not, DataType::Float32, NotCnnl,
"Not_cnnl_BANG_Float32"); "Not_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::BitAnd, DataType::Float32, BitAndCnnl, REGISTER_KERNEL(Device::BANG, OpType::BitwiseAnd, DataType::Float32, BitAndCnnl,
"BitAnd_cnnl_BANG_Float32"); "BitAnd_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::BitOr, DataType::Float32, BitOrCnnl, REGISTER_KERNEL(Device::BANG, OpType::BitwiseOr, DataType::Float32, BitOrCnnl,
"BitOr_cnnl_BANG_Float32"); "BitOr_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::BitXor, DataType::Float32, BitXorCnnl, REGISTER_KERNEL(Device::BANG, OpType::BitwiseXor, DataType::Float32, BitXorCnnl,
"BitXor_cnnl_BANG_Float32"); "BitXor_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::BitNot, DataType::Float32, BitNotCnnl, REGISTER_KERNEL(Device::BANG, OpType::BitwiseNot, DataType::Float32, BitNotCnnl,
"BitNot_cnnl_BANG_Float32"); "BitNot_cnnl_BANG_Float32");
// REGISTER_KERNEL(Device::BANG, OpType::BitLeftShift, DataType::Float32, // REGISTER_KERNEL(Device::BANG, OpType::BitLeftShift, DataType::Float32,
// BitLeftShiftCnnl, // BitLeftShiftCnnl,


@ -79,6 +79,6 @@ class MatmulCnnl : public BangKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::BANG, OpType::Matmul, DataType::Float32, MatmulCnnl, REGISTER_KERNEL(Device::BANG, OpType::MatMul, DataType::Float32, MatmulCnnl,
"Matmul_cnnl_BANG_Float32"); "Matmul_cnnl_BANG_Float32");
}; // namespace infini }; // namespace infini


@ -68,6 +68,6 @@ class avgPoolCnnl : public PoolingCnnl {
REGISTER_KERNEL(Device::BANG, OpType::MaxPool, DataType::Float32, maxPoolCnnl, REGISTER_KERNEL(Device::BANG, OpType::MaxPool, DataType::Float32, maxPoolCnnl,
"MaxPool_cnnl_BANG_Float32"); "MaxPool_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::AvgPool, DataType::Float32, avgPoolCnnl, REGISTER_KERNEL(Device::BANG, OpType::AveragePool, DataType::Float32,
"AvgPool_cnnl_BANG_Float32"); avgPoolCnnl, "AvgPool_cnnl_BANG_Float32");
}; // namespace infini }; // namespace infini


@@ -162,23 +162,23 @@ REGISTER_KERNEL(Device::BANG, OpType::Cos, DataType::Float32, CosCnnl,
                 "Cos_cnnl_BANG_Float32");
 REGISTER_KERNEL(Device::BANG, OpType::Tan, DataType::Float32, TanCnnl,
                 "Tan_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::ASin, DataType::Float32, ASinCnnl,
+REGISTER_KERNEL(Device::BANG, OpType::Asin, DataType::Float32, ASinCnnl,
                 "ASin_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::ACos, DataType::Float32, ACosCnnl,
+REGISTER_KERNEL(Device::BANG, OpType::Acos, DataType::Float32, ACosCnnl,
                 "ACos_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::ATan, DataType::Float32, ATanCnnl,
+REGISTER_KERNEL(Device::BANG, OpType::Atan, DataType::Float32, ATanCnnl,
                 "ATan_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::SinH, DataType::Float32, SinHCnnl,
+REGISTER_KERNEL(Device::BANG, OpType::Sinh, DataType::Float32, SinHCnnl,
                 "SinH_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::CosH, DataType::Float32, CosHCnnl,
+REGISTER_KERNEL(Device::BANG, OpType::Cosh, DataType::Float32, CosHCnnl,
                 "CosH_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::TanH, DataType::Float32, TanHCnnl,
+REGISTER_KERNEL(Device::BANG, OpType::Tanh, DataType::Float32, TanHCnnl,
                 "TanH_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::ASinH, DataType::Float32, ASinHCnnl,
+REGISTER_KERNEL(Device::BANG, OpType::Asinh, DataType::Float32, ASinHCnnl,
                 "ASinH_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::ACosH, DataType::Float32, ACosHCnnl,
+REGISTER_KERNEL(Device::BANG, OpType::Acosh, DataType::Float32, ACosHCnnl,
                 "ACosH_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::ATanH, DataType::Float32, ATanHCnnl,
+REGISTER_KERNEL(Device::BANG, OpType::Atanh, DataType::Float32, ATanHCnnl,
                 "ATanH_cnnl_BANG_Float32");
 }; // namespace infini
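The trigonometric kernels above are re-tagged to the ONNX spellings (Asin, Acos, Atan, Asinh, Acosh, Atanh). One practical payoff, sketched below with a simplified enum that is not the project's real OpType values, is that an ONNX importer can resolve a node's op_type string with a plain table instead of per-operator special cases:

#include <cstdint>
#include <string>
#include <unordered_map>

enum class Op : uint16_t { Unknown, Asin, Acos, Atan, Asinh, Acosh, Atanh };

// Hypothetical helper: resolve an ONNX op_type string to the internal id.
Op fromOnnxName(const std::string &name) {
    static const std::unordered_map<std::string, Op> table{
        {"Asin", Op::Asin},   {"Acos", Op::Acos},   {"Atan", Op::Atan},
        {"Asinh", Op::Asinh}, {"Acosh", Op::Acosh}, {"Atanh", Op::Atanh}};
    const auto it = table.find(name);
    return it == table.end() ? Op::Unknown : it->second;
}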


@ -26,9 +26,9 @@ template <typename T> class NaiveMatmul : public CpuKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::CPU, OpType::Matmul, DataType::UInt32, REGISTER_KERNEL(Device::CPU, OpType::MatMul, DataType::UInt32,
NaiveMatmul<uint32_t>, "MatmulNaive_CPU_uint32"); NaiveMatmul<uint32_t>, "MatmulNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Matmul, DataType::Float32, REGISTER_KERNEL(Device::CPU, OpType::MatMul, DataType::Float32,
NaiveMatmul<float>, "MatmulNaive_CPU_float32"); NaiveMatmul<float>, "MatmulNaive_CPU_float32");
} // namespace infini } // namespace infini


@ -76,6 +76,6 @@ REGISTER_KERNEL(Device::CPU, OpType::MaxPool, DataType::UInt32,
NaiveMaxPool<uint32_t>, "maxPoolNaive_CPU_uint32"); NaiveMaxPool<uint32_t>, "maxPoolNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::MaxPool, DataType::Float32, REGISTER_KERNEL(Device::CPU, OpType::MaxPool, DataType::Float32,
NaiveMaxPool<float>, "maxPoolNaive_CPU_float32"); NaiveMaxPool<float>, "maxPoolNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::AvgPool, DataType::Float32, REGISTER_KERNEL(Device::CPU, OpType::AveragePool, DataType::Float32,
NaiveAvgPool<float>, "AvgPoolNaive_CPU_float32"); NaiveAvgPool<float>, "AvgPoolNaive_CPU_float32");
} // namespace infini } // namespace infini


@ -59,6 +59,6 @@ class BatchNormCudnn : public CudaKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::CUDA, OpType::BatchNorm, DataType::Float32, REGISTER_KERNEL(Device::CUDA, OpType::BatchNormalization, DataType::Float32,
BatchNormCudnn, "BatchNorm_cuDNN_CUDA_Float32"); BatchNormCudnn, "BatchNorm_cuDNN_CUDA_Float32");
} // namespace infini } // namespace infini


@ -300,7 +300,7 @@ class convBackwardDataCudnn : public Kernel {
} }
}; };
REGISTER_KERNEL(Device::CUDA, OpType::ConvTrans, DataType::Float32, REGISTER_KERNEL(Device::CUDA, OpType::ConvTranspose, DataType::Float32,
convBackwardDataCudnn, "ConvTranposed_cuDNN_CUDA_Float32"); convBackwardDataCudnn, "ConvTranposed_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::ConvTransNHWC, DataType::Float32, REGISTER_KERNEL(Device::CUDA, OpType::ConvTransNHWC, DataType::Float32,
convBackwardDataCudnn, "ConvTranposedNHWC_cuDNN_CUDA_Float32"); convBackwardDataCudnn, "ConvTranposedNHWC_cuDNN_CUDA_Float32");


@ -114,7 +114,7 @@ class matmulCublas : public Kernel {
} }
}; };
REGISTER_KERNEL(Device::CUDA, OpType::Matmul, DataType::Float32, matmulCublas, REGISTER_KERNEL(Device::CUDA, OpType::MatMul, DataType::Float32, matmulCublas,
"Matmul_cuBLAS_CUDA_Float32"); "Matmul_cuBLAS_CUDA_Float32");
REGISTER_CONSTRUCTOR(2, MatmulCublasPerfRecordObj::from_json); REGISTER_CONSTRUCTOR(2, MatmulCublasPerfRecordObj::from_json);


@@ -68,6 +68,6 @@ class avgPoolCudnn : public poolingCudnn {
 REGISTER_KERNEL(Device::CUDA, OpType::MaxPool, DataType::Float32, maxPoolCudnn,
                 "MaxPool_cuDNN_CUDA_Float32");
-REGISTER_KERNEL(Device::CUDA, OpType::AvgPool, DataType::Float32, avgPoolCudnn,
-                "AvgPool_cuDNN_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::AveragePool, DataType::Float32,
+                avgPoolCudnn, "AvgPool_cuDNN_CUDA_Float32");
 }; // namespace infini


@ -63,6 +63,6 @@ class MklBatchNorm : public MklKernelWithoutConfig {
{DNNL_ARG_SHIFT, baisMemory}}); {DNNL_ARG_SHIFT, baisMemory}});
} }
}; };
REGISTER_KERNEL(Device::INTELCPU, OpType::BatchNorm, DataType::Float32, REGISTER_KERNEL(Device::INTELCPU, OpType::BatchNormalization, DataType::Float32,
MklBatchNorm, "BatchNorm_Mkl_Float32"); MklBatchNorm, "BatchNorm_Mkl_Float32");
}; // namespace infini }; // namespace infini


@ -244,7 +244,7 @@ class MklConvTranspose : public Kernel {
return make_ref<ConvTransposeMklPerfRecordObj>(ret); return make_ref<ConvTransposeMklPerfRecordObj>(ret);
} }
}; };
REGISTER_KERNEL(Device::INTELCPU, OpType::ConvTrans, DataType::Float32, REGISTER_KERNEL(Device::INTELCPU, OpType::ConvTranspose, DataType::Float32,
MklConvTranspose, "MklConvTrans_CPU_float32"); MklConvTranspose, "MklConvTrans_CPU_float32");
} // namespace infini } // namespace infini


@@ -38,12 +38,12 @@ optional<vector<Shape>> G2BMMObj::inferShape(const TensorVec &inputs) const {
 }
 vector<int> G2BMMObj::getWorkloadVector() const {
-    return {enum_to_underlying(type), b, m, k, width, dilation,
+    return {type.underlying(), b, m, k, width, dilation,
             enum_to_underlying(act)};
 }
 vector<int> G2BMMObj::getOpAttrVector() const {
-    return {enum_to_underlying(type), width, dilation, enum_to_underlying(act)};
+    return {type.underlying(), width, dilation, enum_to_underlying(act)};
 }
 } // namespace infini


@@ -37,11 +37,10 @@ optional<vector<Shape>> GBMMObj::inferShape(const TensorVec &inputs) const {
 }
 vector<int> GBMMObj::getWorkloadVector() const {
-    return {enum_to_underlying(type), b, m, w, n, dilation,
-            enum_to_underlying(act)};
+    return {type.underlying(), b, m, w, n, dilation, enum_to_underlying(act)};
 }
 vector<int> GBMMObj::getOpAttrVector() const {
-    return {enum_to_underlying(type), dilation, enum_to_underlying(act)};
+    return {type.underlying(), dilation, enum_to_underlying(act)};
 }
 } // namespace infini
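Here and in the operator files below, `enum_to_underlying(type)` becomes `type.underlying()`: the operator tag is now a small struct that can report its raw integer value itself. A minimal sketch of that wrapper idea (simplified; the real OpType defines far more values and helpers):

#include <cstdint>
#include <vector>

struct OpTag {                        // simplified stand-in for OpType
    using underlying_t = uint16_t;
    enum : underlying_t { Unknown, G2BMM, GBMM } value;
    constexpr underlying_t underlying() const { return value; }
};

// The first element of a workload key is the raw operator id.
std::vector<int> workloadKey(OpTag t, int b, int m, int k) {
    return {t.underlying(), b, m, k};
}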


@@ -15,7 +15,7 @@ ActivationBackwardObj::inferShape(const TensorVec &inputs) const {
 std::string ActivationBackwardObj::toString() const {
     std::ostringstream os;
-    os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
+    os << type.toString() << "[" << getGuid() << "]";
     os << "(";
     os << vecToString(inputs[0]->getDims()) << ",";
     os << "input=" << inputs[0]->getGuid() << ",";
@@ -24,14 +24,14 @@ std::string ActivationBackwardObj::toString() const {
 }
 vector<int> ActivationBackwardObj::getWorkloadVector() const {
-    vector<int> ret{enum_to_underlying(type)};
+    vector<int> ret{type.underlying()};
     const Shape shape = outputs[0]->getDims();
     ret.insert(ret.end(), shape.begin(), shape.end());
     return ret;
 }
 vector<int> ActivationBackwardObj::getOpAttrVector() const {
-    return {enum_to_underlying(type)};
+    return {type.underlying()};
 }
 }; // namespace infini
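The toString() changes in this and the following files swap `OpRegistry::getOpName(type)` for `type.toString()`, so the op name now comes from the type object itself; the rest of the debug string (guid, dims, tensor guids) is unchanged. A rough standalone sketch of the string being assembled, with assumed helper names:

#include <sstream>
#include <string>
#include <vector>

std::string describeOp(const std::string &opName, int guid,
                       const std::vector<int> &inDims, int inputGuid) {
    std::ostringstream os;
    os << opName << "[" << guid << "](";          // e.g. "Relu[42]("
    os << "[";
    for (size_t i = 0; i < inDims.size(); ++i)
        os << inDims[i] << (i + 1 < inDims.size() ? "," : "");
    os << "],input=" << inputGuid << ")";
    return os.str();
}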


@@ -4,7 +4,8 @@ namespace infini {
 BatchNormObj::BatchNormObj(GraphObj *graph, Tensor input, Tensor output,
                            Tensor mean, Tensor var, Tensor scale, Tensor bias,
                            float momentum, float eps, bool trainingMode)
-    : OperatorObj(OpType::BatchNorm, {input, mean, var, scale, bias}, {output}),
+    : OperatorObj(OpType::BatchNormalization, {input, mean, var, scale, bias},
+                  {output}),
       momentum(momentum), eps(eps), trainingMode(trainingMode) {
     if (trainingMode)
         IT_TODO_HALT();
@@ -38,7 +39,7 @@ vector<DataType> BatchNormObj::inferDataType(const TensorVec &inputs) const {
 std::string BatchNormObj::toString() const {
     std::ostringstream os;
-    os << "BatchNorm[" << getGuid() << "]";
+    os << "batchNormalization[" << getGuid() << "]";
     os << "(";
     os << vecToString(inputs[0]->getDims()) << ",";
     os << "momentum=" << momentum << ",";
@@ -57,13 +58,13 @@ std::string BatchNormObj::toString() const {
 // need eps and momentum?
 vector<int> BatchNormObj::getWorkloadVector() const {
     vector<int> ret = inputs[0]->getDims();
-    ret.emplace(ret.begin(), enum_to_underlying(type));
+    ret.emplace(ret.begin(), type.underlying());
     return ret;
 }
 // need eps and momentum?
 vector<int> BatchNormObj::getOpAttrVector() const {
-    return {enum_to_underlying(type)};
+    return {type.underlying()};
 }
 } // namespace infini
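For reference, the computation this operator describes (and which the kernels run in inference mode, since trainingMode is rejected above) is the standard per-channel normalization y = scale * (x - mean) / sqrt(var + eps) + bias. A self-contained illustration of that formula, not the project's kernel code:

#include <cmath>
#include <cstddef>
#include <vector>

void batchNormInference(std::vector<float> &x, size_t channels, size_t perChannel,
                        const std::vector<float> &mean, const std::vector<float> &var,
                        const std::vector<float> &scale, const std::vector<float> &bias,
                        float eps = 1e-5f) {
    for (size_t c = 0; c < channels; ++c) {
        const float inv = 1.0f / std::sqrt(var[c] + eps);
        for (size_t i = 0; i < perChannel; ++i) {
            float &v = x[c * perChannel + i];
            v = scale[c] * (v - mean[c]) * inv + bias[c];
        }
    }
}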


@ -47,12 +47,12 @@ vector<int> ConcatObj::getWorkloadVector() const {
vector<int> ret = getOutput()->getDims(); vector<int> ret = getOutput()->getDims();
ret.emplace(ret.begin(), (int)inputs.size()); ret.emplace(ret.begin(), (int)inputs.size());
ret.emplace(ret.begin(), dim); ret.emplace(ret.begin(), dim);
ret.emplace(ret.begin(), enum_to_underlying(type)); ret.emplace(ret.begin(), type.underlying());
return ret; return ret;
} }
vector<int> ConcatObj::getOpAttrVector() const { vector<int> ConcatObj::getOpAttrVector() const {
return {enum_to_underlying(type), dim}; return {type.underlying(), dim};
} }
} // namespace infini } // namespace infini


@@ -19,7 +19,7 @@ ConvBaseObj::ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output,
 string ConvBaseObj::toString() const {
     std::ostringstream os;
-    os << OpRegistry::getOpName(getOpType()) << "[" << getGuid() << "]";
+    os << type.toString() << "[" << getGuid() << "]";
     os << "(";
     if (inputs.size() == 2) {
         os << vecToString(inputs[0]->getDims()) << ",";
@@ -36,13 +36,12 @@ string ConvBaseObj::toString() const {
 }
 vector<int> ConvBaseObj::getWorkloadVector() const {
-    return {
-        enum_to_underlying(type), n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw};
+    return {type.underlying(), n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw};
 }
 vector<int> ConvBaseObj::getOpAttrVector() const {
     // IT_TODO_HALT(); // should padding mode / ph+pw be in attrs?
-    return {enum_to_underlying(type), c, f, r, s, ph, pw, sh, sw, dh, dw};
+    return {type.underlying(), c, f, r, s, ph, pw, sh, sw, dh, dw};
 }
 void ConvObj::setAuxilaryAttributes(PaddingMode mode) {
@@ -119,8 +118,8 @@ ConvTransposed2dObj::ConvTransposed2dObj(GraphObj *graph, Tensor input,
                                          int pw, int sh, int sw, int dh, int dw,
                                          int oph, int opw, int group,
                                          Tensor bias, ActType act)
-    : ConvBaseObj(OpType::ConvTrans, {input, weight}, output, ph, pw, sh, sw,
-                  dh, dw, output, weight, act),
+    : ConvBaseObj(OpType::ConvTranspose, {input, weight}, output, ph, pw, sh,
+                  sw, dh, dw, output, weight, act),
       oph(oph), opw(opw), group(group) {
     if (bias)
         IT_TODO_HALT();
@@ -133,8 +132,8 @@ ConvTransposed2dObj::ConvTransposed2dObj(GraphObj *graph, Tensor input,
                                          PaddingMode mode, int sh, int sw,
                                          int dh, int dw, int oph, int opw,
                                          int group, Tensor bias, ActType act)
-    : ConvBaseObj(OpType::ConvTrans, {input, weight}, output, mode, sh, sw, dh,
-                  dw, output, weight, act),
+    : ConvBaseObj(OpType::ConvTranspose, {input, weight}, output, mode, sh, sw,
+                  dh, dw, output, weight, act),
       oph(oph), opw(opw), group(group) {
     if (bias)
         IT_TODO_HALT();
@@ -274,8 +273,8 @@ ConvTransposed2dNHWCObj::ConvTransposed2dNHWCObj(GraphObj *graph, Tensor input,
                                                  int sw, int dh, int dw,
                                                  int oph, int opw, int group,
                                                  Tensor bias, ActType act)
-    : ConvBaseObj(OpType::ConvTrans, {input, weight}, output, mode, sh, sw, dh,
-                  dw, output, weight, act),
+    : ConvBaseObj(OpType::ConvTranspose, {input, weight}, output, mode, sh, sw,
+                  dh, dw, output, weight, act),
       oph(oph), opw(opw), group(group) {
     if (bias)
         IT_TODO_HALT();
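The workload and attribute vectors above act as cache keys: a leading `type.underlying()` entry followed by the geometry means two operators can share a tuned kernel record only when every element matches. A tiny, purely illustrative sketch of that comparison:

#include <vector>

// Two ops map to the same tuning record only if their keys are identical;
// the first element is the operator id, the rest are shape/attribute values.
bool sameTuningKey(const std::vector<int> &a, const std::vector<int> &b) {
    return a == b;
}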


@@ -21,7 +21,7 @@ optional<vector<Shape>> DetObj::inferShape(const TensorVec &inputs) const {
 std::string DetObj::toString() const {
     std::ostringstream os;
-    os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
+    os << type.toString() << "[" << getGuid() << "]";
     os << "(";
     os << vecToString(inputs[0]->getDims()) << ",";
     os << "input=" << inputs[0]->getGuid() << ",";
@@ -30,14 +30,12 @@ std::string DetObj::toString() const {
 }
 vector<int> DetObj::getWorkloadVector() const {
-    vector<int> ret{enum_to_underlying(type)};
+    vector<int> ret{type.underlying()};
     const Shape shape = outputs[0]->getDims();
     ret.insert(ret.end(), shape.begin(), shape.end());
     return ret;
 }
-vector<int> DetObj::getOpAttrVector() const {
-    return {enum_to_underlying(type)};
-}
+vector<int> DetObj::getOpAttrVector() const { return {type.underlying()}; }
 }; // namespace infini


@ -29,12 +29,12 @@ std::string DropoutObj::toString() const {
vector<int> DropoutObj::getWorkloadVector() const { vector<int> DropoutObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims(); vector<int> ret = inputs[0]->getDims();
ret.emplace_back(static_cast<int>(ratio)); ret.emplace_back(static_cast<int>(ratio));
ret.emplace(ret.begin(), enum_to_underlying(type)); ret.emplace(ret.begin(), type.underlying());
return ret; return ret;
} }
vector<int> DropoutObj::getOpAttrVector() const { vector<int> DropoutObj::getOpAttrVector() const {
return {enum_to_underlying(type), static_cast<int>(ratio), false}; return {type.underlying(), static_cast<int>(ratio), false};
} }
} // namespace infini } // namespace infini


@@ -39,7 +39,7 @@ ElementWiseObj::inferShape(const TensorVec &inputs) const {
 std::string ElementWiseObj::toString() const {
     std::ostringstream os;
-    os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
+    os << type.toString() << "[" << getGuid() << "]";
     os << "(";
     os << vecToString(inputs[0]->getDims()) << ",";
     os << vecToString(inputs[1]->getDims()) << ",";
@@ -52,12 +52,12 @@ std::string ElementWiseObj::toString() const {
 // use output dim or inputs dim?
 vector<int> ElementWiseObj::getWorkloadVector() const {
     vector<int> ret = outputs[0]->getDims();
-    ret.emplace(ret.begin(), enum_to_underlying(type));
+    ret.emplace(ret.begin(), type.underlying());
     return ret;
 }
 vector<int> ElementWiseObj::getOpAttrVector() const {
-    return {enum_to_underlying(type)};
+    return {type.underlying()};
 }
 MSELossObj::MSELossObj(GraphObj *graph, Tensor input0, Tensor input1,
@@ -83,7 +83,7 @@ optional<vector<Shape>> MSELossObj::inferShape(const TensorVec &inputs) const {
 std::string MSELossObj::toString() const {
     std::ostringstream os;
-    os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
+    os << type.toString() << "[" << getGuid() << "]";
     os << "(";
     os << vecToString(inputs[0]->getDims()) << ",";
     os << vecToString(inputs[1]->getDims()) << ",";
@@ -96,12 +96,10 @@ std::string MSELossObj::toString() const {
 // use output dim or inputs dim?
 vector<int> MSELossObj::getWorkloadVector() const {
     vector<int> ret = outputs[0]->getDims();
-    ret.emplace(ret.begin(), enum_to_underlying(type));
+    ret.emplace(ret.begin(), type.underlying());
     return ret;
 }
-vector<int> MSELossObj::getOpAttrVector() const {
-    return {enum_to_underlying(type)};
-}
+vector<int> MSELossObj::getOpAttrVector() const { return {type.underlying()}; }
 }; // namespace infini


@ -30,12 +30,12 @@ vector<int> ExtendObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims(); vector<int> ret = inputs[0]->getDims();
ret.emplace_back(dim); ret.emplace_back(dim);
ret.emplace_back(num); ret.emplace_back(num);
ret.emplace(ret.begin(), enum_to_underlying(type)); ret.emplace(ret.begin(), type.underlying());
return ret; return ret;
} }
vector<int> ExtendObj::getOpAttrVector() const { vector<int> ExtendObj::getOpAttrVector() const {
return {enum_to_underlying(type), dim, num}; return {type.underlying(), dim, num};
} }
} // namespace infini } // namespace infini


@ -72,7 +72,7 @@ std::string GatherObj::toString() const {
vector<int> GatherObj::getWorkloadVector() const { vector<int> GatherObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims(); vector<int> ret = inputs[0]->getDims();
ret.emplace(ret.begin(), enum_to_underlying(type)); ret.emplace(ret.begin(), type.underlying());
for (auto it : inputs[1]->getDims()) for (auto it : inputs[1]->getDims())
ret.emplace_back(it); ret.emplace_back(it);
ret.emplace_back(axis); ret.emplace_back(axis);
@ -80,7 +80,7 @@ vector<int> GatherObj::getWorkloadVector() const {
} }
vector<int> GatherObj::getOpAttrVector() const { vector<int> GatherObj::getOpAttrVector() const {
return {enum_to_underlying(type), axis}; return {type.underlying(), axis};
} }
} // namespace infini } // namespace infini


@@ -4,7 +4,7 @@ namespace infini {
 MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
                      bool transB, [[maybe_unused]] Tensor bias, ActType act)
-    : OperatorObj(OpType::Matmul,
+    : OperatorObj(OpType::MatMul,
                   bias ? TensorVec{A, B, bias} : TensorVec{A, B}, {C}),
       transA(transA), transB(transB), act(act), b(1) {
     auto shape_a = A->getDims();
@@ -82,12 +82,12 @@ optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) const {
 }
 vector<int> MatmulObj::getWorkloadVector() const {
-    return {enum_to_underlying(type), b, m, n, k, transA, transB,
+    return {type.underlying(), b, m, n, k, transA, transB,
             enum_to_underlying(act)};
 }
 vector<int> MatmulObj::getOpAttrVector() const {
-    return {enum_to_underlying(type), transA, transB, enum_to_underlying(act)};
+    return {type.underlying(), transA, transB, enum_to_underlying(act)};
 }
 } // namespace infini
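The m, n, k and transA/transB values folded into the keys above come from the usual (possibly transposed) matmul shape rules. A small standalone sketch of that bookkeeping; the real MatmulObj also handles batched and bias inputs, which are omitted here:

#include <array>
#include <cstdio>

struct MatmulDims { int m, n, k; };

MatmulDims matmulDims(std::array<int, 2> a, std::array<int, 2> b,
                      bool transA, bool transB) {
    const int m = transA ? a[1] : a[0];
    const int k = transA ? a[0] : a[1];
    const int n = transB ? b[0] : b[1];
    return {m, n, k};
}

int main() {
    auto d = matmulDims({64, 128}, {256, 128}, /*transA=*/false, /*transB=*/true);
    std::printf("m=%d n=%d k=%d\n", d.m, d.n, d.k); // prints m=64 n=256 k=128
}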


@ -69,7 +69,7 @@ optional<vector<Shape>> MemBoundObj::inferShape(const TensorVec &inputs) const {
} }
vector<int> MemBoundObj::getWorkloadVector() const { vector<int> MemBoundObj::getWorkloadVector() const {
return {enum_to_underlying(type), (int)simplifiedHash}; return {type.underlying(), (int)simplifiedHash};
} }
vector<int> MemBoundObj::getOpAttrVector() const { return getWorkloadVector(); } vector<int> MemBoundObj::getOpAttrVector() const { return getWorkloadVector(); }


@ -50,13 +50,13 @@ std::string PadObj::toString() const {
vector<int> PadObj::getWorkloadVector() const { vector<int> PadObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims(); vector<int> ret = inputs[0]->getDims();
ret.insert(ret.end(), pads.begin(), pads.end()); ret.insert(ret.end(), pads.begin(), pads.end());
ret.emplace(ret.begin(), enum_to_underlying(type)); ret.emplace(ret.begin(), type.underlying());
return ret; return ret;
} }
vector<int> PadObj::getOpAttrVector() const { vector<int> PadObj::getOpAttrVector() const {
vector<int> ret = pads; vector<int> ret = pads;
ret.emplace(ret.begin(), enum_to_underlying(type)); ret.emplace(ret.begin(), type.underlying());
return ret; return ret;
} }


@@ -28,7 +28,7 @@ optional<vector<Shape>> PoolingObj::inferShape(const TensorVec &inputs) const {
 std::string PoolingObj::toString() const {
     std::ostringstream os;
-    os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
+    os << type.toString() << "[" << getGuid() << "]";
     os << "(";
     os << "k=[" << kh << "," << kw << "],";
     os << "p=[" << ph << "," << pw << "],";
@@ -40,12 +40,11 @@ std::string PoolingObj::toString() const {
 }
 vector<int> PoolingObj::getWorkloadVector() const {
-    return {
-        enum_to_underlying(type), n, c, h, w, kh, kw, ph, pw, sh, sw, dh, dw};
+    return {type.underlying(), n, c, h, w, kh, kw, ph, pw, sh, sw, dh, dw};
 }
 vector<int> PoolingObj::getOpAttrVector() const {
-    return {enum_to_underlying(type), kh, kw, ph, pw, sh, sw, dh, dw};
+    return {type.underlying(), kh, kw, ph, pw, sh, sw, dh, dw};
 }
 }; // namespace infini
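The geometry packed into the pooling key (kernel, padding, stride, dilation) determines the output extent by the usual window arithmetic; a one-function sketch:

// Standard output-size formula for one pooled/convolved dimension.
int pooledExtent(int in, int kernel, int pad, int stride, int dilation) {
    const int effectiveKernel = dilation * (kernel - 1) + 1;
    return (in + 2 * pad - effectiveKernel) / stride + 1;
}
// Example: pooledExtent(224, 3, 1, 2, 1) == 112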


@ -69,14 +69,14 @@ std::string ReduceMeanObj::toString() const {
vector<int> ReduceMeanObj::getWorkloadVector() const { vector<int> ReduceMeanObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims(); vector<int> ret = inputs[0]->getDims();
ret.emplace(ret.begin(), enum_to_underlying(type)); ret.emplace(ret.begin(), type.underlying());
ret.emplace_back((int)keepDims); ret.emplace_back((int)keepDims);
ret.insert(ret.end(), axes.begin(), axes.end()); ret.insert(ret.end(), axes.begin(), axes.end());
return ret; return ret;
} }
vector<int> ReduceMeanObj::getOpAttrVector() const { vector<int> ReduceMeanObj::getOpAttrVector() const {
vector<int> ret = {enum_to_underlying(type), (int)keepDims}; vector<int> ret = {type.underlying(), (int)keepDims};
ret.insert(ret.end(), axes.begin(), axes.end()); ret.insert(ret.end(), axes.begin(), axes.end());
return ret; return ret;
} }


@@ -30,12 +30,12 @@ std::string ReshapeObj::toString() const {
 vector<int> ReshapeObj::getWorkloadVector() const {
     vector<int> ret = inputs[0]->getDims();
     ret.insert(ret.end(), dims.begin(), dims.end());
-    ret.emplace(ret.begin(), enum_to_underlying(type));
+    ret.emplace(ret.begin(), type.underlying());
     return ret;
 }
 vector<int> ReshapeObj::getOpAttrVector() const {
     vector<int> ret = dims;
-    ret.emplace(ret.begin(), enum_to_underlying(type));
+    ret.emplace(ret.begin(), type.underlying());
     return ret;
 }
@@ -74,12 +74,12 @@ std::string FlattenObj::toString() const {
 vector<int> FlattenObj::getWorkloadVector() const {
     vector<int> ret = inputs[0]->getDims();
     ret.emplace(ret.begin(), axis);
-    ret.emplace(ret.begin(), enum_to_underlying(type));
+    ret.emplace(ret.begin(), type.underlying());
     return ret;
 }
 vector<int> FlattenObj::getOpAttrVector() const {
-    return {enum_to_underlying(type), axis};
+    return {type.underlying(), axis};
 }
 IdentityObj::IdentityObj(GraphObj *graph, Tensor input, Tensor output)
@@ -103,10 +103,8 @@ std::string IdentityObj::toString() const {
 vector<int> IdentityObj::getWorkloadVector() const {
     vector<int> ret = inputs[0]->getDims();
-    ret.emplace(ret.begin(), enum_to_underlying(type));
+    ret.emplace(ret.begin(), type.underlying());
     return ret;
 }
-vector<int> IdentityObj::getOpAttrVector() const {
-    return {enum_to_underlying(type)};
-}
+vector<int> IdentityObj::getOpAttrVector() const { return {type.underlying()}; }
 } // namespace infini


@ -244,7 +244,7 @@ vector<int> ResizeObj::getWorkloadVector() const {
// here. // here.
ret.emplace_back(enum_to_underlying(coMode)); ret.emplace_back(enum_to_underlying(coMode));
ret.emplace_back(enum_to_underlying(nearestMode)); ret.emplace_back(enum_to_underlying(nearestMode));
ret.emplace(ret.begin(), enum_to_underlying(type)); ret.emplace(ret.begin(), type.underlying());
return ret; return ret;
} }
@ -253,7 +253,7 @@ vector<int> ResizeObj::getOpAttrVector() const {
ret.emplace_back(enum_to_underlying(coMode)); ret.emplace_back(enum_to_underlying(coMode));
ret.emplace_back(enum_to_underlying(nearestMode)); ret.emplace_back(enum_to_underlying(nearestMode));
ret.emplace_back(enum_to_underlying(ratioPolicy)); ret.emplace_back(enum_to_underlying(ratioPolicy));
ret.emplace(ret.begin(), enum_to_underlying(type)); ret.emplace(ret.begin(), type.underlying());
return ret; return ret;
} }


@ -93,7 +93,7 @@ vector<int> SliceObj::getWorkloadVector() const {
} }
vector<int> SliceObj::getOpAttrVector() const { vector<int> SliceObj::getOpAttrVector() const {
vector<int> ans{enum_to_underlying(type)}; vector<int> ans{type.underlying()};
for (const auto &range : axes) { for (const auto &range : axes) {
ans.push_back(range.start); ans.push_back(range.start);
ans.push_back(range.end); ans.push_back(range.end);


@ -15,7 +15,7 @@ SoftmaxObj::SoftmaxObj(GraphObj *graph, Tensor input, Tensor output, int _axis)
std::string SoftmaxObj::toString() const { std::string SoftmaxObj::toString() const {
std::ostringstream os; std::ostringstream os;
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]"; os << type.toString() << "[" << getGuid() << "]";
os << "("; os << "(";
os << vecToString(inputs[0]->getDims()) << ","; os << vecToString(inputs[0]->getDims()) << ",";
os << "input=" << inputs[0]->getGuid() << ","; os << "input=" << inputs[0]->getGuid() << ",";
@ -25,13 +25,13 @@ std::string SoftmaxObj::toString() const {
} }
vector<int> SoftmaxObj::getWorkloadVector() const { vector<int> SoftmaxObj::getWorkloadVector() const {
vector<int> ret{enum_to_underlying(type), axis}; vector<int> ret{type.underlying(), axis};
const Shape shape = outputs[0]->getDims(); const Shape shape = outputs[0]->getDims();
ret.insert(ret.end(), shape.begin(), shape.end()); ret.insert(ret.end(), shape.begin(), shape.end());
return ret; return ret;
} }
vector<int> SoftmaxObj::getOpAttrVector() const { vector<int> SoftmaxObj::getOpAttrVector() const {
return {enum_to_underlying(type), axis}; return {type.underlying(), axis};
} }
} // namespace infini } // namespace infini


@ -56,14 +56,14 @@ optional<vector<Shape>> SplitObj::inferShape(const TensorVec &inputs) const {
vector<int> SplitObj::getWorkloadVector() const { vector<int> SplitObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims(); vector<int> ret = inputs[0]->getDims();
ret.emplace(ret.begin(), enum_to_underlying(type)); ret.emplace(ret.begin(), type.underlying());
ret.emplace_back(dim); ret.emplace_back(dim);
ret.emplace_back(num); ret.emplace_back(num);
return ret; return ret;
} }
vector<int> SplitObj::getOpAttrVector() const { vector<int> SplitObj::getOpAttrVector() const {
return {enum_to_underlying(type), dim, num}; return {type.underlying(), dim, num};
} }
string SplitObj::toString() const { string SplitObj::toString() const {


@ -28,7 +28,7 @@ TransposeObj::inferShape(const TensorVec &inputs) const {
std::string TransposeObj::toString() const { std::string TransposeObj::toString() const {
std::ostringstream os; std::ostringstream os;
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]"; os << type.toString() << "[" << getGuid() << "]";
os << "("; os << "(";
os << vecToString(inputs[0]->getDims()) << ","; os << vecToString(inputs[0]->getDims()) << ",";
os << "input=" << inputs[0]->getGuid() << ","; os << "input=" << inputs[0]->getGuid() << ",";
@ -37,14 +37,14 @@ std::string TransposeObj::toString() const {
} }
vector<int> TransposeObj::getWorkloadVector() const { vector<int> TransposeObj::getWorkloadVector() const {
vector<int> ret{enum_to_underlying(type)}; vector<int> ret{type.underlying()};
const Shape shape = outputs[0]->getDims(); const Shape shape = outputs[0]->getDims();
ret.insert(ret.end(), shape.begin(), shape.end()); ret.insert(ret.end(), shape.begin(), shape.end());
return ret; return ret;
} }
vector<int> TransposeObj::getOpAttrVector() const { vector<int> TransposeObj::getOpAttrVector() const {
return {enum_to_underlying(type)}; return {type.underlying()};
} }
}; // namespace infini }; // namespace infini


@@ -13,7 +13,7 @@ optional<vector<Shape>> UnaryObj::inferShape(const TensorVec &inputs) const {
 std::string UnaryObj::toString() const {
     std::ostringstream os;
-    os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
+    os << type.toString() << "[" << getGuid() << "]";
     os << "(";
     os << vecToString(inputs[0]->getDims()) << ",";
     os << "input=" << inputs[0]->getGuid() << ",";
@@ -22,15 +22,13 @@ std::string UnaryObj::toString() const {
 }
 vector<int> UnaryObj::getWorkloadVector() const {
-    vector<int> ret{enum_to_underlying(type)};
+    vector<int> ret{type.underlying()};
     const Shape shape = outputs[0]->getDims();
     ret.insert(ret.end(), shape.begin(), shape.end());
     return ret;
 }
-vector<int> UnaryObj::getOpAttrVector() const {
-    return {enum_to_underlying(type)};
-}
+vector<int> UnaryObj::getOpAttrVector() const { return {type.underlying()}; }
 ClipObj::ClipObj(GraphObj *graph, Tensor input, Tensor output,
                  std::optional<float> min, std::optional<float> max)
@@ -46,7 +44,7 @@ optional<vector<Shape>> ClipObj::inferShape(const TensorVec &inputs) const {
 std::string ClipObj::toString() const {
     std::ostringstream os;
-    os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
+    os << type.toString() << "[" << getGuid() << "]";
     os << "(";
     os << vecToString(inputs[0]->getDims()) << ",";
     os << "input=" << inputs[0]->getGuid() << ",";
@@ -55,15 +53,13 @@ std::string ClipObj::toString() const {
 }
 vector<int> ClipObj::getWorkloadVector() const {
-    vector<int> ret{enum_to_underlying(type)};
+    vector<int> ret{type.underlying()};
     const Shape shape = outputs[0]->getDims();
     ret.insert(ret.end(), shape.begin(), shape.end());
     return ret;
 }
-vector<int> ClipObj::getOpAttrVector() const {
-    return {enum_to_underlying(type)};
-}
+vector<int> ClipObj::getOpAttrVector() const { return {type.underlying()}; }
 HardtanhObj::HardtanhObj(GraphObj *graph, Tensor input, Tensor output,
                          float min, float max)
@@ -79,7 +75,7 @@ optional<vector<Shape>> HardtanhObj::inferShape(const TensorVec &inputs) const {
 std::string HardtanhObj::toString() const {
     std::ostringstream os;
-    os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
+    os << type.toString() << "[" << getGuid() << "]";
     os << "(";
     os << vecToString(inputs[0]->getDims()) << ",";
     os << "input=" << inputs[0]->getGuid() << ",";
@@ -88,15 +84,13 @@ std::string HardtanhObj::toString() const {
 }
 vector<int> HardtanhObj::getWorkloadVector() const {
-    vector<int> ret{enum_to_underlying(type)};
+    vector<int> ret{type.underlying()};
     const Shape shape = outputs[0]->getDims();
     ret.insert(ret.end(), shape.begin(), shape.end());
     return ret;
 }
-vector<int> HardtanhObj::getOpAttrVector() const {
-    return {enum_to_underlying(type)};
-}
+vector<int> HardtanhObj::getOpAttrVector() const { return {type.underlying()}; }
 FillObj::FillObj(GraphObj *graph, Tensor input, Tensor output, float value)
     : OperatorObj(OpType::Fill, {input}, {output}), setValue(value) {
@@ -110,22 +104,20 @@ optional<vector<Shape>> FillObj::inferShape(const TensorVec &inputs) const {
 std::string FillObj::toString() const {
     std::ostringstream os;
-    os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
+    os << type.toString() << "[" << getGuid() << "]";
     os << "(";
     os << "output=" << outputs[0]->getGuid() << ")";
     return os.str();
 }
 vector<int> FillObj::getWorkloadVector() const {
-    vector<int> ret{enum_to_underlying(type)};
+    vector<int> ret{type.underlying()};
     const Shape shape = outputs[0]->getDims();
     ret.insert(ret.end(), shape.begin(), shape.end());
     return ret;
 }
-vector<int> FillObj::getOpAttrVector() const {
-    return {enum_to_underlying(type)};
-}
+vector<int> FillObj::getOpAttrVector() const { return {type.underlying()}; }
 L2LossObj::L2LossObj(GraphObj *graph, Tensor input, Tensor output)
     : OperatorObj(OpType::L2Loss, {input}, {output}) {
@@ -139,22 +131,20 @@ optional<vector<Shape>> L2LossObj::inferShape(const TensorVec &inputs) const {
 std::string L2LossObj::toString() const {
     std::ostringstream os;
-    os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
+    os << type.toString() << "[" << getGuid() << "]";
     os << "(";
     os << "output=" << outputs[0]->getGuid() << ")";
     return os.str();
 }
 vector<int> L2LossObj::getWorkloadVector() const {
-    vector<int> ret{enum_to_underlying(type)};
+    vector<int> ret{type.underlying()};
     const Shape shape = outputs[0]->getDims();
     ret.insert(ret.end(), shape.begin(), shape.end());
     return ret;
 }
-vector<int> L2LossObj::getOpAttrVector() const {
-    return {enum_to_underlying(type)};
-}
+vector<int> L2LossObj::getOpAttrVector() const { return {type.underlying()}; }
 CastObj::CastObj(GraphObj *graph, Tensor input, Tensor output, CastType type)
     : OperatorObj(OpType::Cast, {input}, {output}), castType(type) {
@@ -176,22 +166,20 @@ optional<vector<Shape>> CastObj::inferShape(const TensorVec &inputs) const {
 std::string CastObj::toString() const {
     std::ostringstream os;
-    os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
+    os << type.toString() << "[" << getGuid() << "]";
     os << "(";
     os << "output=" << outputs[0]->getGuid() << ")";
     return os.str();
 }
 vector<int> CastObj::getWorkloadVector() const {
-    vector<int> ret{enum_to_underlying(type)};
+    vector<int> ret{type.underlying()};
     const Shape shape = outputs[0]->getDims();
     ret.insert(ret.end(), shape.begin(), shape.end());
     return ret;
 }
-vector<int> CastObj::getOpAttrVector() const {
-    return {enum_to_underlying(type)};
-}
+vector<int> CastObj::getOpAttrVector() const { return {type.underlying()}; }
 DataType CastObj::getOutputDataType() const {
     switch (castType) {
@@ -251,7 +239,7 @@ optional<vector<Shape>> ShapeObj::inferShape(const TensorVec &inputs) const {
 std::string ShapeObj::toString() const {
     std::ostringstream os;
-    os << OpRegistry::getOpName(type) << "[" << getGuid() << "]("
+    os << type.toString() << "[" << getGuid() << "]("
        << "output=" << outputs[0]->getGuid() << ")";
     return os.str();
 }
@@ -268,7 +256,7 @@ optional<vector<Shape>> PReluObj::inferShape(const TensorVec &inputs) const {
 std::string PReluObj::toString() const {
     std::ostringstream os;
-    os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
+    os << type.toString() << "[" << getGuid() << "]";
     os << "(";
     os << vecToString(inputs[0]->getDims()) << ",";
     os << "input=" << inputs[0]->getGuid() << ",";
@@ -277,15 +265,13 @@ std::string PReluObj::toString() const {
 }
 vector<int> PReluObj::getWorkloadVector() const {
-    vector<int> ret{enum_to_underlying(type)};
+    vector<int> ret{type.underlying()};
     const Shape shape = outputs[0]->getDims();
     ret.insert(ret.end(), shape.begin(), shape.end());
     return ret;
 }
-vector<int> PReluObj::getOpAttrVector() const {
-    return {enum_to_underlying(type)};
-}
+vector<int> PReluObj::getOpAttrVector() const { return {type.underlying()}; }
 LogObj::LogObj(GraphObj *graph, Tensor input, Tensor output, LogType type)
     : OperatorObj(OpType::Log, {input}, {output}), logType(type) {
@@ -299,21 +285,19 @@ optional<vector<Shape>> LogObj::inferShape(const TensorVec &inputs) const {
 std::string LogObj::toString() const {
     std::ostringstream os;
-    os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
+    os << type.toString() << "[" << getGuid() << "]";
     os << "(";
     os << "output=" << outputs[0]->getGuid() << ")";
     return os.str();
 }
 vector<int> LogObj::getWorkloadVector() const {
-    vector<int> ret{enum_to_underlying(type)};
+    vector<int> ret{type.underlying()};
     const Shape shape = outputs[0]->getDims();
     ret.insert(ret.end(), shape.begin(), shape.end());
     return ret;
 }
-vector<int> LogObj::getOpAttrVector() const {
-    return {enum_to_underlying(type)};
-}
+vector<int> LogObj::getOpAttrVector() const { return {type.underlying()}; }
 }; // namespace infini
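Every operator in this file now follows the same two-key pattern: the workload key is the raw op id followed by the output shape, and the attribute key is just the op id for attribute-free ops. A compact sketch of that composition, with assumed helper names:

#include <cstdint>
#include <vector>

std::vector<int> unaryWorkloadKey(uint16_t opId, const std::vector<int> &outDims) {
    std::vector<int> key{static_cast<int>(opId)};
    key.insert(key.end(), outDims.begin(), outDims.end());
    return key;
}

std::vector<int> unaryAttrKey(uint16_t opId) { return {static_cast<int>(opId)}; }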


@ -1,40 +0,0 @@
#include "bang/bang_runtime.h"
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/unary.h"
#include "test.h"
namespace infini {
template <class T>
void testCopy(const std::function<void(void *, size_t, DataType)> &generator,
const Shape &shape) {
// Runtime
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
auto bangRuntime = make_ref<BangRuntimeObj>();
// Build input data on CPU
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
inputCpu->dataMalloc();
inputCpu->setData(generator);
// GPU
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
auto inputGpu = bangGraph->cloneTensor(inputCpu);
auto gpuOp = bangGraph->addOp<T>(inputGpu, nullptr);
bangGraph->dataMalloc();
bangRuntime->run(bangGraph);
auto outputGpu = gpuOp->getOutput();
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
inputCpu->printData();
outputGpu2Cpu->printData();
EXPECT_TRUE(outputGpu2Cpu->equalData(inputCpu));
}
TEST(cnnl_Copy, run) {
testCopy<CopyObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
}
} // namespace infini


@ -1,49 +0,0 @@
#include "bang/bang_runtime.h"
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/element_wise.h"
#include "test.h"
namespace infini {
template <class T>
void testFloorDiv(
const std::function<void(void *, size_t, DataType)> &generator,
const Shape &shape) {
// Runtime
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
auto bangRuntime = make_ref<BangRuntimeObj>();
// Build input data on CPU
Tensor inputCpu1 =
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
inputCpu1->dataMalloc();
inputCpu1->setData(generator);
Tensor inputCpu2 =
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
inputCpu2->dataMalloc();
inputCpu2->setData(generator);
// GPU
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
auto inputGpu1 = bangGraph->cloneTensor(inputCpu1);
auto inputGpu2 = bangGraph->cloneTensor(inputCpu2);
auto gpuOp = bangGraph->addOp<T>(inputGpu1, inputGpu2, nullptr);
bangGraph->dataMalloc();
bangRuntime->run(bangGraph);
auto outputGpu = gpuOp->getOutput();
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
// Check
inputCpu1->printData();
inputCpu2->printData();
outputGpu2Cpu->printData();
EXPECT_TRUE(1);
}
TEST(cnnl_FloorDiv, run) {
testFloorDiv<FloorDivObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
}
} // namespace infini


@ -1,49 +0,0 @@
#include "bang/bang_runtime.h"
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/element_wise.h"
#include "test.h"
namespace infini {
template <class T>
void testFloorMod(
const std::function<void(void *, size_t, DataType)> &generator,
const Shape &shape) {
// Runtime
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
auto bangRuntime = make_ref<BangRuntimeObj>();
// Build input data on CPU
Tensor inputCpu1 =
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
inputCpu1->dataMalloc();
inputCpu1->setData(generator);
Tensor inputCpu2 =
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
inputCpu2->dataMalloc();
inputCpu2->setData(generator);
// GPU
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
auto inputGpu1 = bangGraph->cloneTensor(inputCpu1);
auto inputGpu2 = bangGraph->cloneTensor(inputCpu2);
auto gpuOp = bangGraph->addOp<T>(inputGpu1, inputGpu2, nullptr);
bangGraph->dataMalloc();
bangRuntime->run(bangGraph);
auto outputGpu = gpuOp->getOutput();
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
// Check
inputCpu1->printData();
inputCpu2->printData();
outputGpu2Cpu->printData();
EXPECT_TRUE(1);
}
TEST(cnnl_FloorMod, run) {
testFloorMod<FloorModObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
}
} // namespace infini


@@ -42,7 +42,6 @@ void testLogicOp(const std::function<void(void *, size_t, DataType)> &generator,
 TEST(cnnl_LogicOp, run) {
     testLogicOp<EqualObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
-    testLogicOp<NotEqualObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
     testLogicOp<GreaterThanObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
     testLogicOp<GreaterEqualObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
     testLogicOp<LessThanObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});


@ -1,40 +0,0 @@
#include "bang/bang_runtime.h"
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/unary.h"
#include "test.h"
namespace infini {
template <class T>
void testRsqrt(const std::function<void(void *, size_t, DataType)> &generator,
const Shape &shape) {
// Runtime
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
auto bangRuntime = make_ref<BangRuntimeObj>();
// Build input data on CPU
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
inputCpu->dataMalloc();
inputCpu->setData(generator);
// GPU
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
auto inputGpu = bangGraph->cloneTensor(inputCpu);
auto gpuOp = bangGraph->addOp<T>(inputGpu, nullptr);
bangGraph->dataMalloc();
bangRuntime->run(bangGraph);
auto outputGpu = gpuOp->getOutput();
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
inputCpu->printData();
outputGpu2Cpu->printData();
EXPECT_TRUE(1);
}
TEST(cnnl_Rsqrt, run) {
testRsqrt<RsqrtObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
}
} // namespace infini


@ -1,40 +0,0 @@
#include "bang/bang_runtime.h"
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/unary.h"
#include "test.h"
namespace infini {
template <class T>
void testSquare(const std::function<void(void *, size_t, DataType)> &generator,
const Shape &shape) {
// Runtime
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
auto bangRuntime = make_ref<BangRuntimeObj>();
// Build input data on CPU
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
inputCpu->dataMalloc();
inputCpu->setData(generator);
// GPU
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
auto inputGpu = bangGraph->cloneTensor(inputCpu);
auto gpuOp = bangGraph->addOp<T>(inputGpu, nullptr);
bangGraph->dataMalloc();
bangRuntime->run(bangGraph);
auto outputGpu = gpuOp->getOutput();
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
inputCpu->printData();
outputGpu2Cpu->printData();
EXPECT_TRUE(1);
}
TEST(cnnl_Square, run) {
testSquare<SquareObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
}
} // namespace infini


@ -1,48 +0,0 @@
#include "bang/bang_runtime.h"
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/element_wise.h"
#include "test.h"
namespace infini {
template <class T>
void testSquaredDifference(
const std::function<void(void *, size_t, DataType)> &generator,
const Shape &shape) {
// Runtime
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
auto bangRuntime = make_ref<BangRuntimeObj>();
// Build input data on CPU
Tensor inputCpu1 =
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
inputCpu1->dataMalloc();
inputCpu1->setData(generator);
Tensor inputCpu2 =
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
inputCpu2->dataMalloc();
inputCpu2->setData(generator);
// GPU
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
auto inputGpu1 = bangGraph->cloneTensor(inputCpu1);
auto inputGpu2 = bangGraph->cloneTensor(inputCpu2);
auto gpuOp = bangGraph->addOp<T>(inputGpu1, inputGpu2, nullptr);
bangGraph->dataMalloc();
bangRuntime->run(bangGraph);
auto outputGpu = gpuOp->getOutput();
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
// Check
outputGpu2Cpu->printData();
EXPECT_TRUE(1);
}
TEST(cnnl_SquaredDifference, run) {
testSquaredDifference<SquaredDifferenceObj>(IncrementalGenerator(),
Shape{1, 2, 2, 3});
}
} // namespace infini


@@ -152,8 +152,8 @@ TEST(cuDNN_ConvTransposed, tune) {
     bool tune = true;
     cuda->run(gCuda, tune);
     // check record
-    auto kernelAttrs =
-        KernelAttrs{Device::CUDA, conv->getOpType(), DataType::Float32};
+    auto kernelAttrs = KernelAttrs{Device::CUDA, conv->getOpType().underlying(),
+                                   DataType::Float32};
     auto perfKey = PerfEngine::Key{kernelAttrs, conv->getOpPerfKey()};
     std::optional<PerfRecord> perfData =
         PerfEngine::getInstance().getPerfData(perfKey);
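The tuning tests here and below now build the perf-record key from the raw `getOpType().underlying()` value instead of the OpType object itself. A rough stand-in for that lookup, assuming a plain map keyed by {device, op id, dtype} (the real PerfEngine differs):

#include <cstdint>
#include <map>
#include <optional>
#include <tuple>

using AttrKey = std::tuple<int, uint16_t, int>; // {device, op id, dtype}, assumed

std::optional<double> lookupTunedTime(const std::map<AttrKey, double> &records,
                                      const AttrKey &key) {
    const auto it = records.find(key);
    if (it == records.end())
        return std::nullopt;   // no tuned record cached for this kernel yet
    return it->second;
}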


@@ -53,8 +53,8 @@ TEST(mkl_Conv, tune) {
     mklRuntime->run(gMkl, tune);
     // check record
-    auto kernelAttrs =
-        KernelAttrs{Device::INTELCPU, conv->getOpType(), DataType::Float32};
+    auto kernelAttrs = KernelAttrs{
+        Device::INTELCPU, conv->getOpType().underlying(), DataType::Float32};
     auto perfKey = PerfEngine::Key{kernelAttrs, conv->getOpPerfKey()};
     std::optional<PerfRecord> perfData =
         PerfEngine::getInstance().getPerfData(perfKey);


@@ -73,8 +73,8 @@ TEST(mkl_ConvTransposed, tune) {
     bool tune = true;
     runtime->run(gMkl, tune);
     // check record
-    auto kernelAttrs =
-        KernelAttrs{Device::INTELCPU, conv->getOpType(), DataType::Float32};
+    auto kernelAttrs = KernelAttrs{
+        Device::INTELCPU, conv->getOpType().underlying(), DataType::Float32};
     auto perfKey = PerfEngine::Key{kernelAttrs, conv->getOpPerfKey()};
     std::optional<PerfRecord> perfData =
         PerfEngine::getInstance().getPerfData(perfKey);


@ -4,7 +4,7 @@
#include "test.h" #include "test.h"
namespace infini { namespace infini {
TEST(BatchNorm, ShapeInference) { TEST(BatchNormalization, ShapeInference) {
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
{ {
Graph g = make_ref<GraphObj>(cpuRuntime); Graph g = make_ref<GraphObj>(cpuRuntime);