refactor(core): 添加新的 `OpType` 定义 (#99)

* feat: 添加新的 OpType 定义

Signed-off-by: YdrMaster <ydrml@hotmail.com>

* refactor: 使用新的 OpType 替换原来的,修改整个项目

Signed-off-by: YdrMaster <ydrml@hotmail.com>

* fix: onnx 导入

Signed-off-by: YdrMaster <ydrml@hotmail.com>

* fix: 修正 cuda 和 bang kernel 的问题

Signed-off-by: YdrMaster <ydrml@hotmail.com>

* fix: 过滤 bang test

Signed-off-by: YdrMaster <ydrml@hotmail.com>

* fix: 过滤 bang test

Signed-off-by: YdrMaster <ydrml@hotmail.com>

* fix bang code.

* fix code on bang

* fmt

Signed-off-by: YdrMaster <ydrml@hotmail.com>

* fix: 删除指定文件

Signed-off-by: YdrMaster <ydrml@hotmail.com>

* fix: 删两个没用的文件,去掉一个不知道为什么的注释

Signed-off-by: YdrMaster <ydrml@hotmail.com>

---------

Signed-off-by: YdrMaster <ydrml@hotmail.com>
Co-authored-by: wanghailu <wanghailu@qiyuanlab.com>
This commit is contained in:
Derui Yang 2023-08-07 11:17:05 +08:00 committed by GitHub
parent 9b10a74788
commit 57ac94d893
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
70 changed files with 776 additions and 907 deletions

View File

@ -27,9 +27,9 @@ class GraphHandlerObj {
int opw);
Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
Tensor bias, ActType act);
Tensor batchNorm(Tensor input, Tensor output, Tensor mean, Tensor var,
Tensor scale, Tensor bias, float momentum, float eps,
bool training);
Tensor batchNormalization(Tensor input, Tensor output, Tensor mean,
Tensor var, Tensor scale, Tensor bias,
float momentum, float eps, bool training);
Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
int ph, int pw, int sh, int sw);

View File

@ -105,8 +105,8 @@ class KernelRegistry {
IT_ASSERT(it != kernels.end(),
"Kernel not found for key {" +
to_string(enum_to_underlying(std::get<0>(kernelAttrs))) +
", " + OpRegistry::getOpName(std::get<1>(kernelAttrs)) +
", " + std::get<2>(kernelAttrs).toString() + "}");
", " + std::to_string(std::get<1>(kernelAttrs)) + ", " +
std::get<2>(kernelAttrs).toString() + "}");
return std::get<0>(it->second);
}
const KernelRecord &getKernelItem(const KernelAttrs &kernelAttrs) const {

253
include/core/op_type.h Normal file
View File

@ -0,0 +1,253 @@
#pragma once
#ifndef OP_TYPE_H
#define OP_TYPE_H
#include <string>
#include <unordered_set>
namespace infini {
struct OpType {
using underlying_t = uint16_t;
// Clang-format is ambiguous in formating of comment alignment.
// In order to disambiguate, it is necessary to comment all enum
// elements.
enum : underlying_t {
Unknown,
Abs, // Unary
Acos, // Unary
Acosh, // Unary
Add, // Binary
And, // Binary
ArgMax, //
Asin, // Binary
Asinh, // Binary
Atan, // Binary
Atanh, // Binary
AveragePool, // Pool
BatchNormalization, //
Bernoulli, //
BitShift, // Binary
BitwiseAnd, // Binary
BitwiseNot, // Binary
BitwiseOr, // Binary
BitwiseXor, // Binary
BlackmanWindow, //
Cast, // Unary
CastLike, //
Ceil, // Unary
Celu, //
CenterCropPad, //
Clip, // Unary
Col2lm,
Compress,
Concat,
ConcatFromSequence,
ConstantOfShape,
Conv, // ComputationIntensive
ConvInteger, // ComputationIntensive
ConvTranspose, // ComputationIntensive
Cos, // Unary
Cosh, // Unary
CumSum,
DFT,
DeformConv, // ComputationIntensive
DepthToSpace,
DequantizeLinear,
Det,
Div, // Binary
Dropout,
DynamicQuantizeLinear,
Einsum,
Elu,
Equal, // Compair
Erf, // Unary
Exp, // Unary
Expand,
EyeLike,
Flatten,
Floor, // Unary
GRU,
Gather,
GatherElements,
GatherND,
Gemm,
GlobalAveragePool, // GlobalPool
GlobalLpPool, // GlobalPool
GlobalMaxPool, // GlobalPool
Greater, // Compair
GreaterOrEqual, // Compair
GridSample,
GroupNormalization,
HammingWindow,
HannWindow,
HardSigmoid,
HardSwish,
Hardmax,
Identity,
If,
InstanceNormalization,
IsInf,
IsNaN,
LRN,
LSTM,
LayerNormalization,
LeakyRelu,
Less, // Compair
LessOrEqual, // Compair
Log, // Unary
LogSoftmax,
Loop,
LpNormalization,
LpPool,
MatMul, // ComputationIntensive
MatMulInteger, // ComputationIntensive
Max,
MaxPool,
MaxRoiPool,
MaxUnpool,
Mean,
MeanVarianceNormalization,
MelWeightMatrix,
Min,
Mish,
Mod, // Binary
Mul, // Binary
Multinomial, //
Neg, // Unary
NegativeLogLikelihoodLoss,
NonMaxSuppression,
NonZero,
Not, // Unary
OneHot,
Optional,
OptionalGetElement,
OptionalHasElement,
Or, // Binary
PRelu, //
Pad, //
Pow, // Binary
QLinearConv, // ComputationIntensive
QLinearMatMul, // ComputationIntensive
QuantizeLinear,
RNN,
RandomNormal,
RandomNormalLike,
RandomUniform,
RandomUniformLike,
Range,
Reciprocal,
ReduceL1, // Reduce
ReduceL2, // Reduce
ReduceLogSum, // Reduce
ReduceLogSumExp, // Reduce
ReduceMax, // Reduce
ReduceMean, // Reduce
ReduceMin, // Reduce
ReduceProd, // Reduce
ReduceSum, // Reduce
ReduceSumSquare, // Reduce
Relu, // Unary
Reshape,
Resize,
ReverseSequence,
RoiAlign,
Round, // Unary
STFT,
Scan,
Scatter,
ScatterElements,
ScatterND,
Selu,
SequenceAt,
SequenceConstruct,
SequenceEmpty,
SequenceErase,
SequenceInsert,
SequenceLength,
SequenceMap,
Shape,
Shrink,
Sigmoid,
Sign,
Sin, // Unary
Sinh, // Unary
Size,
Slice,
Softmax,
SoftmaxCrossEntropyLoss,
Softplus,
Softsign,
SpaceToDepth,
Split,
SplitToSequence,
Sqrt,
Squeeze,
StringNormalizer,
Sub, // Binary
Sum, //
Tan, // Unary
Tanh, // unary
TfIdfVectorizer,
ThresholdedRelu,
Tile,
TopK,
Transpose,
Trilu,
Unique,
Unsqueeze,
Upsample,
Where,
Xor, // Binary
// CUSTOM DEFINED
G2BMM,
GBMM,
MemBound,
// TODO
ConvTransNHWC,
ConvBackwardFilter,
ReluBackward,
SigmoidBackward,
TanhBackward,
Fill,
Extend,
MSELoss,
Hardtanh,
L2Loss,
Rsqrt,
FloorDiv,
FloorMod,
Square,
SquaredDifference,
} type;
constexpr OpType(decltype(type) t) : type(t) {}
constexpr explicit OpType(underlying_t val) : type((decltype(type))val) {}
constexpr underlying_t underlying() const { return type; }
bool operator==(OpType others) const { return type == others.type; }
bool operator!=(OpType others) const { return type != others.type; }
bool operator<(OpType others) const { return type < others.type; }
const char *toString() const;
bool isUnary() const;
bool isBinary() const;
bool isElementWise() const;
bool isCompair() const;
bool isPool() const;
bool isGlobalPool() const;
bool isMatMulOrConv() const;
};
enum class ActType {
None,
Relu,
Sigmoid,
Tanh,
};
} // namespace infini
#endif // OP_TYPE_H

View File

@ -1,231 +1,14 @@
#pragma once
#include "core/op_type.h"
#include "core/tensor.h"
namespace infini {
enum class OpType {
Unknown = 0,
// linear
Conv = 100,
ConvBackwardFilter,
ConvBackwardData,
Matmul,
ConvTrans,
ConvTransNHWC,
G2BMM,
GBMM,
Pad,
Slice,
Concat,
Split,
Transpose,
Extend,
MaxPool,
AvgPool,
Add,
Sub,
Mul,
Div,
Pow,
Gather,
ReduceMean,
Reshape,
Flatten,
Identity,
// element wise
BatchNorm = 200,
Softmax,
Activation,
Relu,
ReluBackward,
PRelu,
Sigmoid,
SigmoidBackward,
Tanh,
TanhBackward,
Abs,
Sin,
Cos,
Tan,
ASin,
ACos,
ATan,
SinH,
CosH,
TanH,
ASinH,
ACosH,
ATanH,
Resize,
Arange,
Shape,
Copy,
Ceil,
Floor,
Clip,
Erf,
Exp,
Fill,
Log,
L2Loss,
Maximum,
Minimum,
MSELoss,
Neg,
Power,
Reciprocal,
Sqrt,
Rsqrt,
Cast,
FloorDiv,
FloorMod,
Det,
Round,
Square,
SquaredDifference,
Hardtanh,
Equal,
NotEqual,
GreaterThan,
GreaterEqual,
LessThan,
LessEqual,
And,
Or,
Xor,
Not,
BitAnd,
BitOr,
BitXor,
BitNot,
BitLeftShift,
BitRightShift,
Dropout,
//
MemBound = 300,
};
using KernelAttrs = std::tuple<Device, OpType, DataType>;
class OpRegistry {
public:
static std::string getOpName(OpType opType) {
#define FOP(op) \
case OpType::op: \
return #op
switch (opType) {
FOP(Unknown);
// linear
FOP(Conv);
FOP(ConvBackwardFilter);
FOP(ConvBackwardData);
FOP(Matmul);
FOP(ConvTrans);
FOP(G2BMM);
FOP(GBMM);
FOP(Pad);
FOP(Slice);
FOP(Concat);
FOP(Split);
FOP(Transpose);
FOP(Extend);
FOP(MaxPool);
FOP(AvgPool);
FOP(Add);
FOP(Sub);
FOP(Mul);
FOP(Div);
FOP(Pow);
FOP(Gather);
FOP(ReduceMean);
FOP(Reshape);
FOP(Identity);
FOP(Shape);
// element wise
FOP(BatchNorm);
FOP(Softmax);
FOP(Activation);
FOP(Relu);
FOP(ReluBackward);
FOP(PRelu);
FOP(Sigmoid);
FOP(SigmoidBackward);
FOP(Tanh);
FOP(TanhBackward);
FOP(Abs);
FOP(Sin);
FOP(Cos);
FOP(Tan);
FOP(ASin);
FOP(ACos);
FOP(ATan);
FOP(SinH);
FOP(CosH);
FOP(TanH);
FOP(ASinH);
FOP(ACosH);
FOP(ATanH);
FOP(Copy);
FOP(Ceil);
FOP(Floor);
FOP(Clip);
FOP(Erf);
FOP(Exp);
FOP(Fill);
FOP(Log);
FOP(L2Loss);
FOP(Maximum);
FOP(Minimum);
FOP(MSELoss);
FOP(Neg);
FOP(Power);
FOP(Reciprocal);
FOP(Sqrt);
FOP(Rsqrt);
FOP(Cast);
FOP(FloorDiv);
FOP(FloorMod);
FOP(Det);
FOP(Round);
FOP(Square);
FOP(SquaredDifference);
FOP(Hardtanh);
FOP(Equal);
FOP(NotEqual);
FOP(GreaterThan);
FOP(GreaterEqual);
FOP(LessThan);
FOP(LessEqual);
FOP(And);
FOP(Or);
FOP(Xor);
FOP(Not);
FOP(BitAnd);
FOP(BitOr);
FOP(BitXor);
FOP(BitNot);
FOP(BitLeftShift);
FOP(BitRightShift);
//
FOP(MemBound);
default:
IT_ASSERT(false);
break;
}
#undef FOP
}
};
enum class ActType {
None,
Relu,
Sigmoid,
Tanh,
};
using KernelAttrs = std::tuple<Device, OpType::underlying_t, DataType>;
struct OpPerfKey {
HashType hash;
OpType opType;
OpType::underlying_t opType;
vector<int> attrs;
public:
@ -233,7 +16,7 @@ struct OpPerfKey {
// https://github.com/nlohmann/json#how-can-i-use-get-for-non-default-constructiblenon-copyable-types
OpPerfKey() = default;
OpPerfKey(HashType hash, OpType opType, vector<int> attrs = {})
: hash(hash), opType(opType), attrs(attrs) {}
: hash(hash), opType(opType.underlying()), attrs(attrs) {}
bool operator==(const OpPerfKey &rhs) const {
if (hash != rhs.hash)
return false;
@ -290,16 +73,7 @@ class OperatorObj : public Object {
*/
HashType hash() const;
public: // check Op type
bool isLinearOp() const;
bool isElementWiseOp() const;
bool isSplitOp() const;
bool isConcatOp() const;
bool isComputeOp() const;
bool isTransposeOp() const;
bool isReshapeOp() const;
bool isMemBoundOp() const;
public:
public: // getter and setter
const TensorVec &getInputs() const { return inputs; }
const TensorVec &getOutputs() const { return outputs; }

View File

@ -1,5 +1,6 @@
#pragma once
#include "core/common.h"
#include "core/op_type.h"
#include "core/ref.h"
#include <memory>
@ -21,7 +22,6 @@ using Graph = Ref<GraphObj>;
using GraphHandler = Ref<GraphHandlerObj>;
using Runtime = Ref<RuntimeObj>;
using Blob = Ref<BlobObj>;
enum class OpType;
using TensorVec = vector<Tensor>;
using OpVec = vector<Operator>;

View File

@ -65,26 +65,24 @@ DEFINE_ELEMENT_WISE_OBJ(Sub, OpType::Sub)
DEFINE_ELEMENT_WISE_OBJ(Mul, OpType::Mul)
DEFINE_ELEMENT_WISE_OBJ(Div, OpType::Div)
DEFINE_ELEMENT_WISE_OBJ(Pow, OpType::Pow)
DEFINE_ELEMENT_WISE_OBJ(Maximum, OpType::Maximum)
DEFINE_ELEMENT_WISE_OBJ(Minimum, OpType::Minimum)
DEFINE_ELEMENT_WISE_OBJ(Power, OpType::Power)
DEFINE_ELEMENT_WISE_OBJ(Maximum, OpType::Max)
DEFINE_ELEMENT_WISE_OBJ(Minimum, OpType::Min)
DEFINE_ELEMENT_WISE_OBJ(Power, OpType::Pow)
DEFINE_ELEMENT_WISE_OBJ(FloorDiv, OpType::FloorDiv)
DEFINE_ELEMENT_WISE_OBJ(FloorMod, OpType::FloorMod)
DEFINE_ELEMENT_WISE_OBJ(SquaredDifference, OpType::SquaredDifference)
DEFINE_ELEMENT_WISE_OBJ(Equal, OpType::Equal)
DEFINE_ELEMENT_WISE_OBJ(NotEqual, OpType::NotEqual)
DEFINE_ELEMENT_WISE_OBJ(GreaterThan, OpType::GreaterThan)
DEFINE_ELEMENT_WISE_OBJ(GreaterEqual, OpType::GreaterEqual)
DEFINE_ELEMENT_WISE_OBJ(LessThan, OpType::LessThan)
DEFINE_ELEMENT_WISE_OBJ(LessEqual, OpType::LessEqual)
DEFINE_ELEMENT_WISE_OBJ(GreaterThan, OpType::Greater)
DEFINE_ELEMENT_WISE_OBJ(GreaterEqual, OpType::GreaterOrEqual)
DEFINE_ELEMENT_WISE_OBJ(LessThan, OpType::Less)
DEFINE_ELEMENT_WISE_OBJ(LessEqual, OpType::LessOrEqual)
DEFINE_ELEMENT_WISE_OBJ(And, OpType::And)
DEFINE_ELEMENT_WISE_OBJ(Or, OpType::Or)
DEFINE_ELEMENT_WISE_OBJ(Xor, OpType::Xor)
DEFINE_ELEMENT_WISE_OBJ(Not, OpType::Not)
DEFINE_ELEMENT_WISE_OBJ(BitAnd, OpType::BitAnd)
DEFINE_ELEMENT_WISE_OBJ(BitOr, OpType::BitOr)
DEFINE_ELEMENT_WISE_OBJ(BitXor, OpType::BitXor)
DEFINE_ELEMENT_WISE_OBJ(BitNot, OpType::BitNot)
DEFINE_ELEMENT_WISE_OBJ(BitLeftShift, OpType::BitLeftShift)
DEFINE_ELEMENT_WISE_OBJ(BitRightShift, OpType::BitRightShift)
DEFINE_ELEMENT_WISE_OBJ(BitAnd, OpType::BitwiseAnd)
DEFINE_ELEMENT_WISE_OBJ(BitOr, OpType::BitwiseOr)
DEFINE_ELEMENT_WISE_OBJ(BitXor, OpType::BitwiseXor)
DEFINE_ELEMENT_WISE_OBJ(BitNot, OpType::BitwiseNot)
DEFINE_ELEMENT_WISE_OBJ(BitLeftShift, OpType::BitShift)
}; // namespace infini

View File

@ -70,7 +70,7 @@ class AvgPoolObj : public PoolingObj {
public:
AvgPoolObj(GraphObj *graph, Tensor input, Tensor output, int kh, int kw,
int dh, int dw, int ph, int pw, int sh, int sw)
: PoolingObj(graph, OpType::AvgPool, input, output, kh, kw, dh, dw, ph,
pw, sh, sw) {}
: PoolingObj(graph, OpType::AveragePool, input, output, kh, kw, dh, dw,
ph, pw, sh, sw) {}
};
}; // namespace infini

View File

@ -197,27 +197,6 @@ class CumsumObj : public OperatorObj {
vector<int> getOpAttrVector() const override;
};
class ArangeObj : public OperatorObj {
public:
ArangeObj(GraphObj *graph, float start, float step, int length,
Tensor output);
OP_CLONE(ArangeObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;
int numInputs() const override { return 0; }
int numOutputs() const override { return 1; }
float getStartValue() { return startValue; }
float getStepValue() { return stepValue; }
int getLength() { return lengthValue; }
private:
float startValue, stepValue;
int lengthValue;
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
class ShapeObj : public OperatorObj {
public:
ShapeObj(GraphObj *graph, Tensor input, Tensor output);
@ -283,17 +262,16 @@ DEFINE_UNARY_OBJ(Abs, OpType::Abs)
DEFINE_UNARY_OBJ(Sin, OpType::Sin)
DEFINE_UNARY_OBJ(Cos, OpType::Cos)
DEFINE_UNARY_OBJ(Tan, OpType::Tan)
DEFINE_UNARY_OBJ(ASin, OpType::ASin)
DEFINE_UNARY_OBJ(ACos, OpType::ACos)
DEFINE_UNARY_OBJ(ATan, OpType::ATan)
DEFINE_UNARY_OBJ(SinH, OpType::SinH)
DEFINE_UNARY_OBJ(CosH, OpType::CosH)
DEFINE_UNARY_OBJ(TanH, OpType::TanH)
DEFINE_UNARY_OBJ(ASinH, OpType::ASinH)
DEFINE_UNARY_OBJ(ACosH, OpType::ACosH)
DEFINE_UNARY_OBJ(ATanH, OpType::ATanH)
DEFINE_UNARY_OBJ(ASin, OpType::Asin)
DEFINE_UNARY_OBJ(ACos, OpType::Acos)
DEFINE_UNARY_OBJ(ATan, OpType::Atan)
DEFINE_UNARY_OBJ(SinH, OpType::Sinh)
DEFINE_UNARY_OBJ(CosH, OpType::Cosh)
DEFINE_UNARY_OBJ(TanH, OpType::Tanh)
DEFINE_UNARY_OBJ(ASinH, OpType::Asinh)
DEFINE_UNARY_OBJ(ACosH, OpType::Acosh)
DEFINE_UNARY_OBJ(ATanH, OpType::Atanh)
DEFINE_UNARY_OBJ(Copy, OpType::Copy)
DEFINE_UNARY_OBJ(Ceil, OpType::Ceil)
DEFINE_UNARY_OBJ(Floor, OpType::Floor)
DEFINE_UNARY_OBJ(Erf, OpType::Erf)
@ -301,7 +279,5 @@ DEFINE_UNARY_OBJ(Exp, OpType::Exp)
DEFINE_UNARY_OBJ(Neg, OpType::Neg)
DEFINE_UNARY_OBJ(Reciprocal, OpType::Reciprocal)
DEFINE_UNARY_OBJ(Sqrt, OpType::Sqrt)
DEFINE_UNARY_OBJ(Rsqrt, OpType::Rsqrt)
DEFINE_UNARY_OBJ(Round, OpType::Round)
DEFINE_UNARY_OBJ(Square, OpType::Square)
}; // namespace infini

View File

@ -196,7 +196,7 @@ class OnnxStub:
attributes[name]
for name in ["momentum", "epsilon", "training_mode"]
)
tensors[node.output[0]] = self.handler.batchNorm(
tensors[node.output[0]] = self.handler.batchNormalization(
input, output, mean, var, scale, bias, momentum, eps, training != 0
)
elif node.op_type == "MaxPool":
@ -551,7 +551,7 @@ class OnnxStub:
# saves object names, including tensors and operators
names: Dict[Union[backend.Tensor, backend.Operator], str] = dict()
# counts the occurrence times of each operator for naming
count_op: Dict[backend.OpType, int] = dict()
count_op: Dict[backend.OpTypeId, int] = dict()
# counts input and output tensors for naming
count_in, count_out = 0, 0
# saves nodes (operators)
@ -563,8 +563,8 @@ class OnnxStub:
# saves global input tensors
initializers: List[TensorProto] = []
def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]:
ty = op.op_type()
def name_op(self, op: backend.Operator) -> Tuple[backend.OpTypeId, str]:
ty = op.op_type().id()
name = "{}{}".format(ty.name, self.count_op.setdefault(ty, 0) + 1)
self.names[op] = name
self.count_op[ty] += 1
@ -647,7 +647,7 @@ class OnnxStub:
ctx.push_output("{}_{}".format(name, i), it)
for (i, it) in enumerate(op.outputs())
]
if ty == backend.OpType.Conv:
if ty == backend.OpTypeId.Conv:
ph, pw, dh, dw, sh, sw = backend.conv_attrs_of(op)
ctx.push_node(
make_node(
@ -661,11 +661,11 @@ class OnnxStub:
group=op.inputs()[0].shape()[1] // op.inputs()[1].shape()[1],
)
)
elif ty == backend.OpType.ConvTrans:
elif ty == backend.OpTypeId.ConvTranspose:
ph, pw, sh, sw, dh, dw, oph, opw = backend.conv_trans_attrs_of(op)
ctx.push_node(
make_node(
"ConvTranspose",
ty.name,
inputs,
outputs,
name,
@ -675,14 +675,14 @@ class OnnxStub:
output_padding=[oph, opw],
)
)
elif ty == backend.OpType.Matmul:
elif ty == backend.OpTypeId.MatMul:
transA, transB = backend.matmul_attrs_of(op)
ctx.push_node(
make_node(
"Gemm", inputs, outputs, name, transA=transA, transB=transB
)
)
elif ty == backend.OpType.BatchNorm:
elif ty == backend.OpTypeId.BatchNormalization:
inputs = [inputs[i] for i in [0, 3, 4, 1, 2]]
momentum, eps, training = backend.batch_norm_attrs_of(op)
ctx.push_node(
@ -696,7 +696,7 @@ class OnnxStub:
training_mode=training,
)
)
elif ty == backend.OpType.MaxPool:
elif ty == backend.OpTypeId.MaxPool:
kh, kw, dh, dw, ph, pw, sh, sw = backend.pool_attrs_of(op)
ctx.push_node(
make_node(
@ -710,7 +710,7 @@ class OnnxStub:
strides=[sh, sw],
)
)
elif ty == backend.OpType.AvgPool:
elif ty == backend.OpTypeId.AveragePool:
kh, kw, dh, dw, ph, pw, sh, sw = backend.pool_attrs_of(op)
ctx.push_node(
make_node(
@ -724,27 +724,27 @@ class OnnxStub:
)
)
elif ty in [
backend.OpType.Add,
backend.OpType.Sub,
backend.OpType.Mul,
backend.OpType.Div,
backend.OpType.Pow,
backend.OpType.Relu,
backend.OpType.Sigmoid,
backend.OpType.Tanh,
backend.OpType.Softmax,
backend.OpType.Abs,
backend.OpType.Identity,
backend.OpType.PRelu,
backend.OpTypeId.Add,
backend.OpTypeId.Sub,
backend.OpTypeId.Mul,
backend.OpTypeId.Div,
backend.OpTypeId.Pow,
backend.OpTypeId.Relu,
backend.OpTypeId.Sigmoid,
backend.OpTypeId.Tanh,
backend.OpTypeId.Softmax,
backend.OpTypeId.Abs,
backend.OpTypeId.Identity,
backend.OpTypeId.PRelu,
]:
ctx.push_node(make_node(ty.name, inputs, outputs, name))
elif ty == backend.OpType.Flatten:
elif ty == backend.OpTypeId.Flatten:
axis = backend.flatten_axis_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
elif ty == backend.OpType.Transpose:
elif ty == backend.OpTypeId.Transpose:
perm = backend.transpose_permute_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, perm=perm))
elif ty == backend.OpType.Reshape:
elif ty == backend.OpTypeId.Reshape:
shape = backend.reshape_shape_of(op)
inputs.append(
ctx.push_data_input(
@ -756,10 +756,10 @@ class OnnxStub:
)
)
ctx.push_node(make_node(ty.name, inputs, outputs, name))
elif ty == backend.OpType.Concat:
elif ty == backend.OpTypeId.Concat:
axis = backend.concat_axis_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
elif ty == backend.OpType.Split:
elif ty == backend.OpTypeId.Split:
axis = backend.split_axis_of(op)
num_outputs = len(outputs)
split = op.inputs()[0].shape()[axis] // num_outputs
@ -781,10 +781,10 @@ class OnnxStub:
axis=axis,
)
)
elif ty == backend.OpType.Gather:
elif ty == backend.OpTypeId.Gather:
axis = backend.gather_axis_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
elif ty == backend.OpType.ReduceMean:
elif ty == backend.OpTypeId.ReduceMean:
axes, keepdims = backend.reduce_mean_attrs_of(op)
inputs.append(
ctx.push_data_input(
@ -794,9 +794,9 @@ class OnnxStub:
ctx.push_node(
make_node(ty.name, inputs, outputs, name, keepdims=keepdims)
)
elif ty == backend.OpType.Slice:
elif ty == backend.OpTypeId.Slice:
raise Exception("TODO")
elif ty == backend.OpType.Pad:
elif ty == backend.OpTypeId.Pad:
pads = backend.pad_pads_of(op)
inputs.append(
ctx.push_data_input(
@ -804,7 +804,7 @@ class OnnxStub:
)
)
ctx.push_node(make_node(ty.name, inputs, outputs, name))
elif ty == backend.OpType.Clip:
elif ty == backend.OpTypeId.Clip:
min, max = backend.clip_attrs_of(op)
if min != None:
inputs.append(

View File

@ -108,7 +108,7 @@ class TestStringMethods(unittest.TestCase):
name="batchNormalization",
)
make_and_import_model(
make_graph([batch_norm], "batchNorm", [x, scale, b, mean, var], [y])
make_graph([batch_norm], "batchNormalzation", [x, scale, b, mean, var], [y])
)
def test_max_pool(self):

View File

@ -13,7 +13,8 @@ void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
std::map<OpType, int> opCnt;
for (auto &op : graph->getOperators()) {
// HACK: set correct data type
auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()};
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying(),
DataType::Float32};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey);

View File

@ -48,7 +48,7 @@ bool DummyMutator::isMultiBranchMergable(const Graph &inGraph) {
if (inGraph->getOperators().size() != 2)
return false;
for (auto op : inGraph->getOperators()) {
if (op->getOpType() != OpType::Matmul)
if (op->getOpType() != OpType::MatMul)
return false;
if (op->getPredecessors().size() > 0)
return false;

View File

@ -116,7 +116,7 @@ bool GraphObj::topo_sort() {
void GraphObj::optimize() {
for (auto &op : ops) {
switch (op->getOpType()) {
switch (op->getOpType().underlying()) {
default:
break;
}
@ -151,7 +151,7 @@ TensorVec GraphObj::addTensor(const TensorVec &tensors) {
OpVec GraphObj::getComputeOps() const {
OpVec opList;
for (auto op : ops)
if (op->isComputeOp())
if (op->getOpType().isMatMulOrConv())
opList.emplace_back(op);
return opList;
}

View File

@ -69,9 +69,11 @@ Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA,
}
}
Tensor GraphHandlerObj::batchNorm(Tensor input, Tensor output, Tensor mean,
Tensor var, Tensor scale, Tensor bias,
float momentum, float eps, bool training) {
Tensor GraphHandlerObj::batchNormalization(Tensor input, Tensor output,
Tensor mean, Tensor var,
Tensor scale, Tensor bias,
float momentum, float eps,
bool training) {
if (output) {
g->addOpWithOutputs<BatchNormObj>(
std::move(input), output, std::move(mean), std::move(var),

278
src/core/op_type.cc Normal file
View File

@ -0,0 +1,278 @@
#include "core/op_type.h"
namespace infini {
const char *OpType::toString() const {
#define CASE(NAME) \
case OpType::NAME: \
return #NAME
switch (type) {
CASE(Unknown);
CASE(Abs);
CASE(Acos);
CASE(Acosh);
CASE(Add);
CASE(And);
CASE(ArgMax);
CASE(Asin);
CASE(Asinh);
CASE(Atan);
CASE(Atanh);
CASE(AveragePool);
CASE(BatchNormalization);
CASE(Bernoulli);
CASE(BitShift);
CASE(BitwiseAnd);
CASE(BitwiseNot);
CASE(BitwiseOr);
CASE(BitwiseXor);
CASE(BlackmanWindow);
CASE(Cast);
CASE(CastLike);
CASE(Ceil);
CASE(Celu);
CASE(CenterCropPad);
CASE(Clip);
CASE(Col2lm);
CASE(Compress);
CASE(Concat);
CASE(ConcatFromSequence);
CASE(ConstantOfShape);
CASE(Conv);
CASE(ConvInteger);
CASE(ConvTranspose);
CASE(Cos);
CASE(Cosh);
CASE(CumSum);
CASE(DFT);
CASE(DeformConv);
CASE(DepthToSpace);
CASE(DequantizeLinear);
CASE(Det);
CASE(Div);
CASE(Dropout);
CASE(DynamicQuantizeLinear);
CASE(Einsum);
CASE(Elu);
CASE(Equal);
CASE(Erf);
CASE(Exp);
CASE(Expand);
CASE(EyeLike);
CASE(Flatten);
CASE(Floor);
CASE(GRU);
CASE(Gather);
CASE(GatherElements);
CASE(GatherND);
CASE(Gemm);
CASE(GlobalAveragePool);
CASE(GlobalLpPool);
CASE(GlobalMaxPool);
CASE(Greater);
CASE(GreaterOrEqual);
CASE(GridSample);
CASE(GroupNormalization);
CASE(HammingWindow);
CASE(HannWindow);
CASE(HardSigmoid);
CASE(HardSwish);
CASE(Hardmax);
CASE(Identity);
CASE(If);
CASE(InstanceNormalization);
CASE(IsInf);
CASE(IsNaN);
CASE(LRN);
CASE(LSTM);
CASE(LayerNormalization);
CASE(LeakyRelu);
CASE(Less);
CASE(LessOrEqual);
CASE(Log);
CASE(LogSoftmax);
CASE(Loop);
CASE(LpNormalization);
CASE(LpPool);
CASE(MatMul);
CASE(MatMulInteger);
CASE(Max);
CASE(MaxPool);
CASE(MaxRoiPool);
CASE(MaxUnpool);
CASE(Mean);
CASE(MeanVarianceNormalization);
CASE(MelWeightMatrix);
CASE(Min);
CASE(Mish);
CASE(Mod);
CASE(Mul);
CASE(Multinomial);
CASE(Neg);
CASE(NegativeLogLikelihoodLoss);
CASE(NonMaxSuppression);
CASE(NonZero);
CASE(Not);
CASE(OneHot);
CASE(Optional);
CASE(OptionalGetElement);
CASE(OptionalHasElement);
CASE(Or);
CASE(PRelu);
CASE(Pad);
CASE(Pow);
CASE(QLinearConv);
CASE(QLinearMatMul);
CASE(QuantizeLinear);
CASE(RNN);
CASE(RandomNormal);
CASE(RandomNormalLike);
CASE(RandomUniform);
CASE(RandomUniformLike);
CASE(Range);
CASE(Reciprocal);
CASE(ReduceL1);
CASE(ReduceL2);
CASE(ReduceLogSum);
CASE(ReduceLogSumExp);
CASE(ReduceMax);
CASE(ReduceMean);
CASE(ReduceMin);
CASE(ReduceProd);
CASE(ReduceSum);
CASE(ReduceSumSquare);
CASE(Relu);
CASE(Reshape);
CASE(Resize);
CASE(ReverseSequence);
CASE(RoiAlign);
CASE(Round);
CASE(STFT);
CASE(Scan);
CASE(Scatter);
CASE(ScatterElements);
CASE(ScatterND);
CASE(Selu);
CASE(SequenceAt);
CASE(SequenceConstruct);
CASE(SequenceEmpty);
CASE(SequenceErase);
CASE(SequenceInsert);
CASE(SequenceLength);
CASE(SequenceMap);
CASE(Shape);
CASE(Shrink);
CASE(Sigmoid);
CASE(Sign);
CASE(Sin);
CASE(Sinh);
CASE(Size);
CASE(Slice);
CASE(Softmax);
CASE(SoftmaxCrossEntropyLoss);
CASE(Softplus);
CASE(Softsign);
CASE(SpaceToDepth);
CASE(Split);
CASE(SplitToSequence);
CASE(Sqrt);
CASE(Squeeze);
CASE(StringNormalizer);
CASE(Sub);
CASE(Sum);
CASE(Tan);
CASE(Tanh);
CASE(TfIdfVectorizer);
CASE(ThresholdedRelu);
CASE(Tile);
CASE(TopK);
CASE(Transpose);
CASE(Trilu);
CASE(Unique);
CASE(Unsqueeze);
CASE(Upsample);
CASE(Where);
CASE(Xor);
// CUSTOM DEFINED
CASE(G2BMM);
CASE(GBMM);
CASE(MemBound);
// TODO
CASE(ConvTransNHWC);
CASE(ConvBackwardFilter);
CASE(ReluBackward);
CASE(SigmoidBackward);
CASE(TanhBackward);
CASE(Fill);
CASE(Extend);
CASE(MSELoss);
CASE(Hardtanh);
CASE(L2Loss);
CASE(Rsqrt);
CASE(FloorDiv);
CASE(FloorMod);
CASE(Square);
CASE(SquaredDifference);
default:
return "Unknown";
}
#undef CASE
}
bool OpType::isUnary() const {
static const std::unordered_set<decltype(type)> set{
Abs, Acos, Acosh, Asin, Asinh, Atan, Atanh, Cast, Ceil,
Clip, Cos, Cosh, Erf, Exp, Floor, Log, Neg, Not,
Relu, Round, Sigmoid, Sin, Sinh, Sqrt, Tan, Tanh,
};
return set.find(type) != set.end();
}
bool OpType::isBinary() const {
static const std::unordered_set<decltype(type)> set{
Add, And, BitShift, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor,
Div, Mod, Mul, Or, Pow, Sub, Xor,
};
return set.find(type) != set.end() || isCompair();
}
bool OpType::isElementWise() const { return isUnary() || isBinary(); }
bool OpType::isCompair() const {
static const std::unordered_set<decltype(type)> set{
Equal, Greater, GreaterOrEqual, Less, LessOrEqual,
};
return set.find(type) != set.end();
}
bool OpType::isPool() const {
static const std::unordered_set<decltype(type)> set{};
return set.find(type) != set.end();
}
bool OpType::isGlobalPool() const {
static const std::unordered_set<decltype(type)> set{
GlobalAveragePool,
GlobalLpPool,
GlobalMaxPool,
};
return set.find(type) != set.end();
}
bool OpType::isMatMulOrConv() const {
static const std::unordered_set<decltype(type)> set{
Conv, ConvInteger, ConvTranspose, DeformConv,
QLinearConv, MatMul, MatMulInteger, QLinearMatMul,
};
return set.find(type) != set.end();
}
} // namespace infini

View File

@ -10,33 +10,6 @@ OperatorObj::OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs)
IT_ASSERT(t);
}
bool OperatorObj::isLinearOp() const {
return enum_to_underlying(type) >= 100 && enum_to_underlying(type) < 200;
}
bool OperatorObj::isElementWiseOp() const {
return enum_to_underlying(type) >= 200 && enum_to_underlying(type) < 300;
}
bool OperatorObj::isSplitOp() const { return type == OpType::Split; }
bool OperatorObj::isConcatOp() const { return type == OpType::Concat; }
bool OperatorObj::isComputeOp() const {
return type == OpType::Conv || type == OpType::Matmul ||
type == OpType::ConvTrans || type == OpType::ConvTransNHWC ||
type == OpType::G2BMM || type == OpType::GBMM;
}
bool OperatorObj::isTransposeOp() const { return type == OpType::Transpose; }
bool OperatorObj::isReshapeOp() const { return type == OpType::Reshape; }
bool OperatorObj::isMemBoundOp() const {
return type == OpType::MemBound || type == OpType::Activation ||
type == OpType::Transpose;
}
void OperatorObj::removePredecessors(const Operator &op) {
for (auto it = predecessors.begin(); it != predecessors.end();) {
if (it->lock() == op)
@ -69,14 +42,14 @@ OpPerfKey OperatorObj::getOpPerfKey() const {
// Operator::hash, which hashes operator attributes and ignores tensor
// shapes.
HashType hash = 0;
hash = hashAppend(hash, enum_to_underlying(type));
hash = hashAppend(hash, type.underlying());
hash = hashAppend(hash, hashVector(workloadVector));
return OpPerfKey(hash, type, workloadVector);
}
HashType OperatorObj::hash() const {
HashType hash = 0;
hash = hashAppend(hash, enum_to_underlying(type));
hash = hashAppend(hash, type.underlying());
hash = hashAppend(hash, hashVector(getOpAttrVector()));
return hash;
}

View File

@ -17,7 +17,8 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
std::map<OpType, int> opCnt;
for (auto &op : graph->getOperators()) {
auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()};
auto kernelAttrs =
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey);
@ -65,7 +66,8 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const {
std::map<OpType, int> opCnt;
for (auto &op : graph->getOperators()) {
auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()};
auto kernelAttrs =
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey);
@ -116,9 +118,8 @@ void RuntimeObj::printProfilingData(double totalTime,
const std::map<OpType, int> &opCnt) const {
printf("%11s %3s %7s %7s %7s\n", "Op", "Cnt", "T_tot", "Percent", "T_mean");
for (const auto &[type, t] : opTime) {
printf("%11s %3d %7.3f %7.1f %7.3f\n",
OpRegistry::getOpName(type).data(), opCnt.at(type), t,
t / totalTime * 100, t / opCnt.at(type));
printf("%11s %3d %7.3f %7.1f %7.3f\n", type.toString(), opCnt.at(type),
t, t / totalTime * 100, t / opCnt.at(type));
}
}

View File

@ -127,7 +127,7 @@ SearchEngine::buildMetaGraphWithGraph(const Graph graph) {
std::vector<Operator> ops;
ops.emplace_back(op);
node.graph = make_ref<GraphObj>(runtimeExec, ops);
node.type = op->isComputeOp();
node.type = op->getOpType().isMatMulOrConv();
node.cnt = op->getPredecessors().size();
opMap.emplace(op->getGuid(), i);
metaGraph->nodes.emplace_back(node);
@ -196,7 +196,7 @@ std::shared_ptr<SearchEngine::MetaGraph> SearchEngine::buildMetaGraphWithPlan(
}
node.graph = make_ref<GraphObj>(runtimeExec, ops);
node.cnt = node.pre.size();
node.type = ops[0]->isComputeOp();
node.type = ops[0]->getOpType().isMatMulOrConv();
resultMetaGraph->nodes.emplace_back(node);
}
}
@ -404,7 +404,7 @@ std::vector<Graph> SearchEngine::partitionGraph(const Graph graph) {
headOps.emplace_back(op);
if (op->getPredecessors().size() + op->getSuccessors().size() >=
(size_t)partitionThreshold &&
!op->isComputeOp()) {
!op->getOpType().isMatMulOrConv()) {
auto preOrderI = preOrder[op->getGuid()];
auto postOrderI = postOrder[op->getGuid()];
for (size_t j = 0; j < i; j++) {

View File

@ -11,7 +11,8 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
auto &perfEngine = PerfEngine::getInstance();
for (auto &op : graph->getOperators()) {
// HACK: set correct data type
auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()};
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying(),
DataType::Float32};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey);
@ -32,7 +33,8 @@ void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const {
std::map<OpType, int> opCnt;
for (auto &op : graph->getOperators()) {
// HACK: set correct data type
auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()};
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying(),
DataType::Float32};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey);

View File

@ -48,6 +48,8 @@ void register_operator_timer(py::module &m) {
#endif
}
decltype(OpType::type) getId(OpType const *const ptr) { return ptr->type; }
void export_values(py::module &m) {
#define VALUE(TYPE, NAME) value(#NAME, TYPE::NAME)
@ -58,13 +60,13 @@ void export_values(py::module &m) {
.VALUE(ActType, Tanh)
.export_values();
py::enum_<OpType>(m, "OpType")
.VALUE(OpType, Unknown)
py::class_<OpType>(m, "OpType")
.def(py::init<decltype(OpType::type)>())
.def("id", getId, policy::automatic);
py::enum_<decltype(OpType::type)>(m, "OpTypeId")
.VALUE(OpType, Conv)
.VALUE(OpType, Matmul)
.VALUE(OpType, ConvTrans)
.VALUE(OpType, G2BMM)
.VALUE(OpType, GBMM)
.VALUE(OpType, MatMul)
.VALUE(OpType, ConvTranspose)
.VALUE(OpType, Pad)
.VALUE(OpType, Clip)
.VALUE(OpType, Slice)
@ -73,7 +75,7 @@ void export_values(py::module &m) {
.VALUE(OpType, Transpose)
.VALUE(OpType, Extend)
.VALUE(OpType, MaxPool)
.VALUE(OpType, AvgPool)
.VALUE(OpType, AveragePool)
.VALUE(OpType, Add)
.VALUE(OpType, Sub)
.VALUE(OpType, Mul)
@ -84,9 +86,8 @@ void export_values(py::module &m) {
.VALUE(OpType, Reshape)
.VALUE(OpType, Flatten)
.VALUE(OpType, Identity)
.VALUE(OpType, BatchNorm)
.VALUE(OpType, BatchNormalization)
.VALUE(OpType, Softmax)
.VALUE(OpType, Activation)
.VALUE(OpType, Relu)
.VALUE(OpType, PRelu)
.VALUE(OpType, Sigmoid)
@ -152,7 +153,7 @@ static std::tuple<int, int, int, int, int, int> conv_attrs_of(Operator op) {
static std::tuple<int, int, int, int, int, int, int, int>
conv_trans_attrs_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::ConvTrans);
IT_ASSERT(op->getOpType() == OpType::ConvTranspose);
auto conv = dynamic_cast<const ConvTransposed2dObj *>(op.get());
auto [oph, opw] = conv->getOutputPadding();
return std::make_tuple(conv->getPh(), conv->getPw(), conv->getDh(),
@ -161,13 +162,13 @@ conv_trans_attrs_of(Operator op) {
}
static std::tuple<bool, bool> matmul_attrs_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::Matmul);
IT_ASSERT(op->getOpType() == OpType::MatMul);
auto matmul = dynamic_cast<const MatmulObj *>(op.get());
return std::make_tuple(matmul->getTransA(), matmul->getTransB());
}
static std::tuple<float, float, bool> batch_norm_attrs_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::BatchNorm);
IT_ASSERT(op->getOpType() == OpType::BatchNormalization);
auto batchnorm = dynamic_cast<const BatchNormObj *>(op.get());
return std::make_tuple(batchnorm->getMomentum(), batchnorm->getEps(),
batchnorm->getTrainingMode());
@ -176,7 +177,7 @@ static std::tuple<float, float, bool> batch_norm_attrs_of(Operator op) {
static std::tuple<int, int, int, int, int, int, int, int>
pool_attrs_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::MaxPool ||
op->getOpType() == OpType::AvgPool);
op->getOpType() == OpType::AveragePool);
auto pool = dynamic_cast<const PoolingObj *>(op.get());
return std::make_tuple(pool->getKh(), pool->getKw(), pool->getDh(),
pool->getDw(), pool->getPh(), pool->getPw(),
@ -319,7 +320,7 @@ void init_graph_builder(py::module &m) {
.def("conv", &Handler::conv, policy::move)
.def("convTransposed2d", &Handler::convTransposed2d, policy::move)
.def("matmul", &Handler::matmul, policy::move)
.def("batchNorm", &Handler::batchNorm, policy::move)
.def("batchNormalization", &Handler::batchNormalization, policy::move)
.def("maxPool", &Handler::maxPool, policy::move)
.def("avgPool", &Handler::avgPool, policy::move)
.def("add", &Handler::add, policy::move)

View File

@ -92,43 +92,6 @@ class RoundCnnl : public BangKernelWithoutConfig {
}
};
class SquareCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
cnnlTensorDescriptor_t aDesc, cDesc;
auto dim = op->getInputs(0)->getDims();
if (dim.size() != 4)
IT_TODO_HALT();
int dim_array[4] = {dim[0], dim[1], dim[2], dim[3]};
// get inputs
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, 4, dim_array));
// get outputs
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, 4, dim_array));
cnnlStatus_t stat =
cnnlSquare(context->cnnlHandle(), aDesc, aData, cDesc, cData);
if (stat != CNNL_STATUS_SUCCESS)
return;
// Destories in BANG does not require sync. But cnnl does not state
// whether sync is required before destories.
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
}
};
class PReluCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
@ -185,24 +148,13 @@ class SigmoidCnnl : public UnaryCnnl {
float getCoef() const override { return 0.0; }
};
class TanhCnnl : public UnaryCnnl {
cnnlActivationMode_t getOpType() const override {
return CNNL_ACTIVATION_TANH;
}
float getCoef() const override { return 0.0; }
};
REGISTER_KERNEL(Device::BANG, OpType::Relu, DataType::Float32, ReluCnnl,
"Relu_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::PRelu, DataType::Float32, PReluCnnl,
"PRelu_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, DataType::Float32, SigmoidCnnl,
"Sigmoid_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Tanh, DataType::Float32, TanhCnnl,
"Tanh_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Round, DataType::Float32, RoundCnnl,
"Round_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Square, DataType::Float32, SquareCnnl,
"Square_cnnl_BANG_Float32");
}; // namespace infini

View File

@ -65,7 +65,7 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::BatchNorm, DataType::Float32,
REGISTER_KERNEL(Device::BANG, OpType::BatchNormalization, DataType::Float32,
BatchNormCnnl, "BatchNorm_cnnl_BANG_Float32");
}; // namespace infini

View File

@ -83,6 +83,6 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::ConvTrans, DataType::Float32,
REGISTER_KERNEL(Device::BANG, OpType::ConvTranspose, DataType::Float32,
ConvTransCnnl, "ConvTrans_cnnl_BANG_Float32");
}; // namespace infini

View File

@ -1,46 +0,0 @@
#include "bang/bang_kernel_without_config.h"
#include "bang/bang_runtime.h"
#include "operators/unary.h"
namespace infini {
class CopyCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
cnnlTensorDescriptor_t aDesc, cDesc;
auto dim = op->getInputs(0)->getDims();
if (dim.size() != 4)
IT_TODO_HALT();
int dim_array[4] = {dim[0], dim[1], dim[2], dim[3]};
// get inputs
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, 4, dim_array));
// get outputs
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, 4, dim_array));
cnnlStatus_t stat =
cnnlCopy(context->cnnlHandle(), aDesc, aData, cDesc, cData);
if (stat != CNNL_STATUS_SUCCESS)
return;
// Destories in BANG does not require sync. But cnnl does not state
// whether sync is required before destories.
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
}
};
REGISTER_KERNEL(Device::BANG, OpType::Copy, DataType::Float32, CopyCnnl,
"Copy_cnnl_BANG_Float32");
}; // namespace infini

View File

@ -593,9 +593,6 @@ class MulCnnl : public ElementWiseCnnl {
class EqualCnnl : public LogicOpCnnl {
cnnlLogicOp_t getOpType() const override { return CNNL_LOGIC_OP_EQ; }
};
class NotEqualCnnl : public LogicOpCnnl {
cnnlLogicOp_t getOpType() const override { return CNNL_LOGIC_OP_NE; }
};
class GreaterThanCnnl : public LogicOpCnnl {
cnnlLogicOp_t getOpType() const override { return CNNL_LOGIC_OP_GT; }
};
@ -651,13 +648,13 @@ REGISTER_KERNEL(Device::BANG, OpType::Mul, DataType::Float32, MulCnnl,
REGISTER_KERNEL(Device::BANG, OpType::Div, DataType::Float32, DivCnnl,
"Div_cnnl_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Maximum, DataType::Float32, MaximumCnnl,
REGISTER_KERNEL(Device::BANG, OpType::Max, DataType::Float32, MaximumCnnl,
"Maximum_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Minimum, DataType::Float32, MinimumCnnl,
REGISTER_KERNEL(Device::BANG, OpType::Min, DataType::Float32, MinimumCnnl,
"Minimum_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::MSELoss, DataType::Float32, MSELossCnnl,
"MSELoss_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Power, DataType::Float32, PowerCnnl,
REGISTER_KERNEL(Device::BANG, OpType::Pow, DataType::Float32, PowerCnnl,
"Power_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::FloorDiv, DataType::Float32, FloorDivCnnl,
"FloorDiv_cnnl_BANG_Float32");
@ -667,15 +664,13 @@ REGISTER_KERNEL(Device::BANG, OpType::SquaredDifference, DataType::Float32,
SquaredDifferenceCnnl, "SquaredDifference_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Equal, DataType::Float32, EqualCnnl,
"Equal_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::NotEqual, DataType::Float32, NotEqualCnnl,
"NotEqual_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::GreaterThan, DataType::Float32,
REGISTER_KERNEL(Device::BANG, OpType::Greater, DataType::Float32,
GreaterThanCnnl, "GreaterThan_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::GreaterEqual, DataType::Float32,
REGISTER_KERNEL(Device::BANG, OpType::GreaterOrEqual, DataType::Float32,
GreaterEqualCnnl, "GreaterEqual_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::LessThan, DataType::Float32, LessThanCnnl,
REGISTER_KERNEL(Device::BANG, OpType::Less, DataType::Float32, LessThanCnnl,
"LessThan_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::LessEqual, DataType::Float32,
REGISTER_KERNEL(Device::BANG, OpType::LessOrEqual, DataType::Float32,
LessEqualCnnl, "LessEqual_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::And, DataType::Float32, AndCnnl,
"And_cnnl_BANG_Float32");
@ -685,13 +680,13 @@ REGISTER_KERNEL(Device::BANG, OpType::Xor, DataType::Float32, XorCnnl,
"Xor_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Not, DataType::Float32, NotCnnl,
"Not_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::BitAnd, DataType::Float32, BitAndCnnl,
REGISTER_KERNEL(Device::BANG, OpType::BitwiseAnd, DataType::Float32, BitAndCnnl,
"BitAnd_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::BitOr, DataType::Float32, BitOrCnnl,
REGISTER_KERNEL(Device::BANG, OpType::BitwiseOr, DataType::Float32, BitOrCnnl,
"BitOr_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::BitXor, DataType::Float32, BitXorCnnl,
REGISTER_KERNEL(Device::BANG, OpType::BitwiseXor, DataType::Float32, BitXorCnnl,
"BitXor_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::BitNot, DataType::Float32, BitNotCnnl,
REGISTER_KERNEL(Device::BANG, OpType::BitwiseNot, DataType::Float32, BitNotCnnl,
"BitNot_cnnl_BANG_Float32");
// REGISTER_KERNEL(Device::BANG, OpType::BitLeftShift, DataType::Float32,
// BitLeftShiftCnnl,

View File

@ -79,6 +79,6 @@ class MatmulCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Matmul, DataType::Float32, MatmulCnnl,
REGISTER_KERNEL(Device::BANG, OpType::MatMul, DataType::Float32, MatmulCnnl,
"Matmul_cnnl_BANG_Float32");
}; // namespace infini

View File

@ -68,6 +68,6 @@ class avgPoolCnnl : public PoolingCnnl {
REGISTER_KERNEL(Device::BANG, OpType::MaxPool, DataType::Float32, maxPoolCnnl,
"MaxPool_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::AvgPool, DataType::Float32, avgPoolCnnl,
"AvgPool_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::AveragePool, DataType::Float32,
avgPoolCnnl, "AvgPool_cnnl_BANG_Float32");
}; // namespace infini

View File

@ -162,23 +162,23 @@ REGISTER_KERNEL(Device::BANG, OpType::Cos, DataType::Float32, CosCnnl,
"Cos_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Tan, DataType::Float32, TanCnnl,
"Tan_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::ASin, DataType::Float32, ASinCnnl,
REGISTER_KERNEL(Device::BANG, OpType::Asin, DataType::Float32, ASinCnnl,
"ASin_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::ACos, DataType::Float32, ACosCnnl,
REGISTER_KERNEL(Device::BANG, OpType::Acos, DataType::Float32, ACosCnnl,
"ACos_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::ATan, DataType::Float32, ATanCnnl,
REGISTER_KERNEL(Device::BANG, OpType::Atan, DataType::Float32, ATanCnnl,
"ATan_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::SinH, DataType::Float32, SinHCnnl,
REGISTER_KERNEL(Device::BANG, OpType::Sinh, DataType::Float32, SinHCnnl,
"SinH_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::CosH, DataType::Float32, CosHCnnl,
REGISTER_KERNEL(Device::BANG, OpType::Cosh, DataType::Float32, CosHCnnl,
"CosH_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::TanH, DataType::Float32, TanHCnnl,
REGISTER_KERNEL(Device::BANG, OpType::Tanh, DataType::Float32, TanHCnnl,
"TanH_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::ASinH, DataType::Float32, ASinHCnnl,
REGISTER_KERNEL(Device::BANG, OpType::Asinh, DataType::Float32, ASinHCnnl,
"ASinH_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::ACosH, DataType::Float32, ACosHCnnl,
REGISTER_KERNEL(Device::BANG, OpType::Acosh, DataType::Float32, ACosHCnnl,
"ACosH_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::ATanH, DataType::Float32, ATanHCnnl,
REGISTER_KERNEL(Device::BANG, OpType::Atanh, DataType::Float32, ATanHCnnl,
"ATanH_cnnl_BANG_Float32");
}; // namespace infini

View File

@ -26,9 +26,9 @@ template <typename T> class NaiveMatmul : public CpuKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::CPU, OpType::Matmul, DataType::UInt32,
REGISTER_KERNEL(Device::CPU, OpType::MatMul, DataType::UInt32,
NaiveMatmul<uint32_t>, "MatmulNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Matmul, DataType::Float32,
REGISTER_KERNEL(Device::CPU, OpType::MatMul, DataType::Float32,
NaiveMatmul<float>, "MatmulNaive_CPU_float32");
} // namespace infini

View File

@ -76,6 +76,6 @@ REGISTER_KERNEL(Device::CPU, OpType::MaxPool, DataType::UInt32,
NaiveMaxPool<uint32_t>, "maxPoolNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::MaxPool, DataType::Float32,
NaiveMaxPool<float>, "maxPoolNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::AvgPool, DataType::Float32,
REGISTER_KERNEL(Device::CPU, OpType::AveragePool, DataType::Float32,
NaiveAvgPool<float>, "AvgPoolNaive_CPU_float32");
} // namespace infini

View File

@ -59,6 +59,6 @@ class BatchNormCudnn : public CudaKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::CUDA, OpType::BatchNorm, DataType::Float32,
REGISTER_KERNEL(Device::CUDA, OpType::BatchNormalization, DataType::Float32,
BatchNormCudnn, "BatchNorm_cuDNN_CUDA_Float32");
} // namespace infini

View File

@ -300,7 +300,7 @@ class convBackwardDataCudnn : public Kernel {
}
};
REGISTER_KERNEL(Device::CUDA, OpType::ConvTrans, DataType::Float32,
REGISTER_KERNEL(Device::CUDA, OpType::ConvTranspose, DataType::Float32,
convBackwardDataCudnn, "ConvTranposed_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::ConvTransNHWC, DataType::Float32,
convBackwardDataCudnn, "ConvTranposedNHWC_cuDNN_CUDA_Float32");

View File

@ -114,7 +114,7 @@ class matmulCublas : public Kernel {
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Matmul, DataType::Float32, matmulCublas,
REGISTER_KERNEL(Device::CUDA, OpType::MatMul, DataType::Float32, matmulCublas,
"Matmul_cuBLAS_CUDA_Float32");
REGISTER_CONSTRUCTOR(2, MatmulCublasPerfRecordObj::from_json);

View File

@ -68,6 +68,6 @@ class avgPoolCudnn : public poolingCudnn {
REGISTER_KERNEL(Device::CUDA, OpType::MaxPool, DataType::Float32, maxPoolCudnn,
"MaxPool_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::AvgPool, DataType::Float32, avgPoolCudnn,
"AvgPool_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::AveragePool, DataType::Float32,
avgPoolCudnn, "AvgPool_cuDNN_CUDA_Float32");
}; // namespace infini

View File

@ -63,6 +63,6 @@ class MklBatchNorm : public MklKernelWithoutConfig {
{DNNL_ARG_SHIFT, baisMemory}});
}
};
REGISTER_KERNEL(Device::INTELCPU, OpType::BatchNorm, DataType::Float32,
REGISTER_KERNEL(Device::INTELCPU, OpType::BatchNormalization, DataType::Float32,
MklBatchNorm, "BatchNorm_Mkl_Float32");
}; // namespace infini

View File

@ -244,7 +244,7 @@ class MklConvTranspose : public Kernel {
return make_ref<ConvTransposeMklPerfRecordObj>(ret);
}
};
REGISTER_KERNEL(Device::INTELCPU, OpType::ConvTrans, DataType::Float32,
REGISTER_KERNEL(Device::INTELCPU, OpType::ConvTranspose, DataType::Float32,
MklConvTranspose, "MklConvTrans_CPU_float32");
} // namespace infini

View File

@ -38,12 +38,12 @@ optional<vector<Shape>> G2BMMObj::inferShape(const TensorVec &inputs) const {
}
vector<int> G2BMMObj::getWorkloadVector() const {
return {enum_to_underlying(type), b, m, k, width, dilation,
return {type.underlying(), b, m, k, width, dilation,
enum_to_underlying(act)};
}
vector<int> G2BMMObj::getOpAttrVector() const {
return {enum_to_underlying(type), width, dilation, enum_to_underlying(act)};
return {type.underlying(), width, dilation, enum_to_underlying(act)};
}
} // namespace infini

View File

@ -37,11 +37,10 @@ optional<vector<Shape>> GBMMObj::inferShape(const TensorVec &inputs) const {
}
vector<int> GBMMObj::getWorkloadVector() const {
return {enum_to_underlying(type), b, m, w, n, dilation,
enum_to_underlying(act)};
return {type.underlying(), b, m, w, n, dilation, enum_to_underlying(act)};
}
vector<int> GBMMObj::getOpAttrVector() const {
return {enum_to_underlying(type), dilation, enum_to_underlying(act)};
return {type.underlying(), dilation, enum_to_underlying(act)};
}
} // namespace infini

View File

@ -15,7 +15,7 @@ ActivationBackwardObj::inferShape(const TensorVec &inputs) const {
std::string ActivationBackwardObj::toString() const {
std::ostringstream os;
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
os << type.toString() << "[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "input=" << inputs[0]->getGuid() << ",";
@ -24,14 +24,14 @@ std::string ActivationBackwardObj::toString() const {
}
vector<int> ActivationBackwardObj::getWorkloadVector() const {
vector<int> ret{enum_to_underlying(type)};
vector<int> ret{type.underlying()};
const Shape shape = outputs[0]->getDims();
ret.insert(ret.end(), shape.begin(), shape.end());
return ret;
}
vector<int> ActivationBackwardObj::getOpAttrVector() const {
return {enum_to_underlying(type)};
return {type.underlying()};
}
}; // namespace infini

View File

@ -4,7 +4,8 @@ namespace infini {
BatchNormObj::BatchNormObj(GraphObj *graph, Tensor input, Tensor output,
Tensor mean, Tensor var, Tensor scale, Tensor bias,
float momentum, float eps, bool trainingMode)
: OperatorObj(OpType::BatchNorm, {input, mean, var, scale, bias}, {output}),
: OperatorObj(OpType::BatchNormalization, {input, mean, var, scale, bias},
{output}),
momentum(momentum), eps(eps), trainingMode(trainingMode) {
if (trainingMode)
IT_TODO_HALT();
@ -38,7 +39,7 @@ vector<DataType> BatchNormObj::inferDataType(const TensorVec &inputs) const {
std::string BatchNormObj::toString() const {
std::ostringstream os;
os << "BatchNorm[" << getGuid() << "]";
os << "batchNormalization[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "momentum=" << momentum << ",";
@ -57,13 +58,13 @@ std::string BatchNormObj::toString() const {
// need eps and momentum?
vector<int> BatchNormObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims();
ret.emplace(ret.begin(), enum_to_underlying(type));
ret.emplace(ret.begin(), type.underlying());
return ret;
}
// need eps and momentum?
vector<int> BatchNormObj::getOpAttrVector() const {
return {enum_to_underlying(type)};
return {type.underlying()};
}
} // namespace infini

View File

@ -47,12 +47,12 @@ vector<int> ConcatObj::getWorkloadVector() const {
vector<int> ret = getOutput()->getDims();
ret.emplace(ret.begin(), (int)inputs.size());
ret.emplace(ret.begin(), dim);
ret.emplace(ret.begin(), enum_to_underlying(type));
ret.emplace(ret.begin(), type.underlying());
return ret;
}
vector<int> ConcatObj::getOpAttrVector() const {
return {enum_to_underlying(type), dim};
return {type.underlying(), dim};
}
} // namespace infini

View File

@ -19,7 +19,7 @@ ConvBaseObj::ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output,
string ConvBaseObj::toString() const {
std::ostringstream os;
os << OpRegistry::getOpName(getOpType()) << "[" << getGuid() << "]";
os << type.toString() << "[" << getGuid() << "]";
os << "(";
if (inputs.size() == 2) {
os << vecToString(inputs[0]->getDims()) << ",";
@ -36,13 +36,12 @@ string ConvBaseObj::toString() const {
}
vector<int> ConvBaseObj::getWorkloadVector() const {
return {
enum_to_underlying(type), n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw};
return {type.underlying(), n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw};
}
vector<int> ConvBaseObj::getOpAttrVector() const {
// IT_TODO_HALT(); // should padding mode / ph+pw be in attrs?
return {enum_to_underlying(type), c, f, r, s, ph, pw, sh, sw, dh, dw};
return {type.underlying(), c, f, r, s, ph, pw, sh, sw, dh, dw};
}
void ConvObj::setAuxilaryAttributes(PaddingMode mode) {
@ -119,8 +118,8 @@ ConvTransposed2dObj::ConvTransposed2dObj(GraphObj *graph, Tensor input,
int pw, int sh, int sw, int dh, int dw,
int oph, int opw, int group,
Tensor bias, ActType act)
: ConvBaseObj(OpType::ConvTrans, {input, weight}, output, ph, pw, sh, sw,
dh, dw, output, weight, act),
: ConvBaseObj(OpType::ConvTranspose, {input, weight}, output, ph, pw, sh,
sw, dh, dw, output, weight, act),
oph(oph), opw(opw), group(group) {
if (bias)
IT_TODO_HALT();
@ -133,8 +132,8 @@ ConvTransposed2dObj::ConvTransposed2dObj(GraphObj *graph, Tensor input,
PaddingMode mode, int sh, int sw,
int dh, int dw, int oph, int opw,
int group, Tensor bias, ActType act)
: ConvBaseObj(OpType::ConvTrans, {input, weight}, output, mode, sh, sw, dh,
dw, output, weight, act),
: ConvBaseObj(OpType::ConvTranspose, {input, weight}, output, mode, sh, sw,
dh, dw, output, weight, act),
oph(oph), opw(opw), group(group) {
if (bias)
IT_TODO_HALT();
@ -274,8 +273,8 @@ ConvTransposed2dNHWCObj::ConvTransposed2dNHWCObj(GraphObj *graph, Tensor input,
int sw, int dh, int dw,
int oph, int opw, int group,
Tensor bias, ActType act)
: ConvBaseObj(OpType::ConvTrans, {input, weight}, output, mode, sh, sw, dh,
dw, output, weight, act),
: ConvBaseObj(OpType::ConvTranspose, {input, weight}, output, mode, sh, sw,
dh, dw, output, weight, act),
oph(oph), opw(opw), group(group) {
if (bias)
IT_TODO_HALT();

View File

@ -21,7 +21,7 @@ optional<vector<Shape>> DetObj::inferShape(const TensorVec &inputs) const {
std::string DetObj::toString() const {
std::ostringstream os;
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
os << type.toString() << "[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "input=" << inputs[0]->getGuid() << ",";
@ -30,14 +30,12 @@ std::string DetObj::toString() const {
}
vector<int> DetObj::getWorkloadVector() const {
vector<int> ret{enum_to_underlying(type)};
vector<int> ret{type.underlying()};
const Shape shape = outputs[0]->getDims();
ret.insert(ret.end(), shape.begin(), shape.end());
return ret;
}
vector<int> DetObj::getOpAttrVector() const {
return {enum_to_underlying(type)};
}
vector<int> DetObj::getOpAttrVector() const { return {type.underlying()}; }
}; // namespace infini

View File

@ -29,12 +29,12 @@ std::string DropoutObj::toString() const {
vector<int> DropoutObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims();
ret.emplace_back(static_cast<int>(ratio));
ret.emplace(ret.begin(), enum_to_underlying(type));
ret.emplace(ret.begin(), type.underlying());
return ret;
}
vector<int> DropoutObj::getOpAttrVector() const {
return {enum_to_underlying(type), static_cast<int>(ratio), false};
return {type.underlying(), static_cast<int>(ratio), false};
}
} // namespace infini

View File

@ -39,7 +39,7 @@ ElementWiseObj::inferShape(const TensorVec &inputs) const {
std::string ElementWiseObj::toString() const {
std::ostringstream os;
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
os << type.toString() << "[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << vecToString(inputs[1]->getDims()) << ",";
@ -52,12 +52,12 @@ std::string ElementWiseObj::toString() const {
// use output dim or inputs dim?
vector<int> ElementWiseObj::getWorkloadVector() const {
vector<int> ret = outputs[0]->getDims();
ret.emplace(ret.begin(), enum_to_underlying(type));
ret.emplace(ret.begin(), type.underlying());
return ret;
}
vector<int> ElementWiseObj::getOpAttrVector() const {
return {enum_to_underlying(type)};
return {type.underlying()};
}
MSELossObj::MSELossObj(GraphObj *graph, Tensor input0, Tensor input1,
@ -83,7 +83,7 @@ optional<vector<Shape>> MSELossObj::inferShape(const TensorVec &inputs) const {
std::string MSELossObj::toString() const {
std::ostringstream os;
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
os << type.toString() << "[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << vecToString(inputs[1]->getDims()) << ",";
@ -96,12 +96,10 @@ std::string MSELossObj::toString() const {
// use output dim or inputs dim?
vector<int> MSELossObj::getWorkloadVector() const {
vector<int> ret = outputs[0]->getDims();
ret.emplace(ret.begin(), enum_to_underlying(type));
ret.emplace(ret.begin(), type.underlying());
return ret;
}
vector<int> MSELossObj::getOpAttrVector() const {
return {enum_to_underlying(type)};
}
vector<int> MSELossObj::getOpAttrVector() const { return {type.underlying()}; }
}; // namespace infini

View File

@ -30,12 +30,12 @@ vector<int> ExtendObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims();
ret.emplace_back(dim);
ret.emplace_back(num);
ret.emplace(ret.begin(), enum_to_underlying(type));
ret.emplace(ret.begin(), type.underlying());
return ret;
}
vector<int> ExtendObj::getOpAttrVector() const {
return {enum_to_underlying(type), dim, num};
return {type.underlying(), dim, num};
}
} // namespace infini

View File

@ -72,7 +72,7 @@ std::string GatherObj::toString() const {
vector<int> GatherObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims();
ret.emplace(ret.begin(), enum_to_underlying(type));
ret.emplace(ret.begin(), type.underlying());
for (auto it : inputs[1]->getDims())
ret.emplace_back(it);
ret.emplace_back(axis);
@ -80,7 +80,7 @@ vector<int> GatherObj::getWorkloadVector() const {
}
vector<int> GatherObj::getOpAttrVector() const {
return {enum_to_underlying(type), axis};
return {type.underlying(), axis};
}
} // namespace infini

View File

@ -4,7 +4,7 @@ namespace infini {
MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
bool transB, [[maybe_unused]] Tensor bias, ActType act)
: OperatorObj(OpType::Matmul,
: OperatorObj(OpType::MatMul,
bias ? TensorVec{A, B, bias} : TensorVec{A, B}, {C}),
transA(transA), transB(transB), act(act), b(1) {
auto shape_a = A->getDims();
@ -82,12 +82,12 @@ optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) const {
}
vector<int> MatmulObj::getWorkloadVector() const {
return {enum_to_underlying(type), b, m, n, k, transA, transB,
return {type.underlying(), b, m, n, k, transA, transB,
enum_to_underlying(act)};
}
vector<int> MatmulObj::getOpAttrVector() const {
return {enum_to_underlying(type), transA, transB, enum_to_underlying(act)};
return {type.underlying(), transA, transB, enum_to_underlying(act)};
}
} // namespace infini

View File

@ -69,7 +69,7 @@ optional<vector<Shape>> MemBoundObj::inferShape(const TensorVec &inputs) const {
}
vector<int> MemBoundObj::getWorkloadVector() const {
return {enum_to_underlying(type), (int)simplifiedHash};
return {type.underlying(), (int)simplifiedHash};
}
vector<int> MemBoundObj::getOpAttrVector() const { return getWorkloadVector(); }

View File

@ -50,13 +50,13 @@ std::string PadObj::toString() const {
vector<int> PadObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims();
ret.insert(ret.end(), pads.begin(), pads.end());
ret.emplace(ret.begin(), enum_to_underlying(type));
ret.emplace(ret.begin(), type.underlying());
return ret;
}
vector<int> PadObj::getOpAttrVector() const {
vector<int> ret = pads;
ret.emplace(ret.begin(), enum_to_underlying(type));
ret.emplace(ret.begin(), type.underlying());
return ret;
}

View File

@ -28,7 +28,7 @@ optional<vector<Shape>> PoolingObj::inferShape(const TensorVec &inputs) const {
std::string PoolingObj::toString() const {
std::ostringstream os;
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
os << type.toString() << "[" << getGuid() << "]";
os << "(";
os << "k=[" << kh << "," << kw << "],";
os << "p=[" << ph << "," << pw << "],";
@ -40,12 +40,11 @@ std::string PoolingObj::toString() const {
}
vector<int> PoolingObj::getWorkloadVector() const {
return {
enum_to_underlying(type), n, c, h, w, kh, kw, ph, pw, sh, sw, dh, dw};
return {type.underlying(), n, c, h, w, kh, kw, ph, pw, sh, sw, dh, dw};
}
vector<int> PoolingObj::getOpAttrVector() const {
return {enum_to_underlying(type), kh, kw, ph, pw, sh, sw, dh, dw};
return {type.underlying(), kh, kw, ph, pw, sh, sw, dh, dw};
}
}; // namespace infini

View File

@ -69,14 +69,14 @@ std::string ReduceMeanObj::toString() const {
vector<int> ReduceMeanObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims();
ret.emplace(ret.begin(), enum_to_underlying(type));
ret.emplace(ret.begin(), type.underlying());
ret.emplace_back((int)keepDims);
ret.insert(ret.end(), axes.begin(), axes.end());
return ret;
}
vector<int> ReduceMeanObj::getOpAttrVector() const {
vector<int> ret = {enum_to_underlying(type), (int)keepDims};
vector<int> ret = {type.underlying(), (int)keepDims};
ret.insert(ret.end(), axes.begin(), axes.end());
return ret;
}

View File

@ -30,12 +30,12 @@ std::string ReshapeObj::toString() const {
vector<int> ReshapeObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims();
ret.insert(ret.end(), dims.begin(), dims.end());
ret.emplace(ret.begin(), enum_to_underlying(type));
ret.emplace(ret.begin(), type.underlying());
return ret;
}
vector<int> ReshapeObj::getOpAttrVector() const {
vector<int> ret = dims;
ret.emplace(ret.begin(), enum_to_underlying(type));
ret.emplace(ret.begin(), type.underlying());
return ret;
}
@ -74,12 +74,12 @@ std::string FlattenObj::toString() const {
vector<int> FlattenObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims();
ret.emplace(ret.begin(), axis);
ret.emplace(ret.begin(), enum_to_underlying(type));
ret.emplace(ret.begin(), type.underlying());
return ret;
}
vector<int> FlattenObj::getOpAttrVector() const {
return {enum_to_underlying(type), axis};
return {type.underlying(), axis};
}
IdentityObj::IdentityObj(GraphObj *graph, Tensor input, Tensor output)
@ -103,10 +103,8 @@ std::string IdentityObj::toString() const {
vector<int> IdentityObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims();
ret.emplace(ret.begin(), enum_to_underlying(type));
ret.emplace(ret.begin(), type.underlying());
return ret;
}
vector<int> IdentityObj::getOpAttrVector() const {
return {enum_to_underlying(type)};
}
vector<int> IdentityObj::getOpAttrVector() const { return {type.underlying()}; }
} // namespace infini

View File

@ -244,7 +244,7 @@ vector<int> ResizeObj::getWorkloadVector() const {
// here.
ret.emplace_back(enum_to_underlying(coMode));
ret.emplace_back(enum_to_underlying(nearestMode));
ret.emplace(ret.begin(), enum_to_underlying(type));
ret.emplace(ret.begin(), type.underlying());
return ret;
}
@ -253,7 +253,7 @@ vector<int> ResizeObj::getOpAttrVector() const {
ret.emplace_back(enum_to_underlying(coMode));
ret.emplace_back(enum_to_underlying(nearestMode));
ret.emplace_back(enum_to_underlying(ratioPolicy));
ret.emplace(ret.begin(), enum_to_underlying(type));
ret.emplace(ret.begin(), type.underlying());
return ret;
}

View File

@ -93,7 +93,7 @@ vector<int> SliceObj::getWorkloadVector() const {
}
vector<int> SliceObj::getOpAttrVector() const {
vector<int> ans{enum_to_underlying(type)};
vector<int> ans{type.underlying()};
for (const auto &range : axes) {
ans.push_back(range.start);
ans.push_back(range.end);

View File

@ -15,7 +15,7 @@ SoftmaxObj::SoftmaxObj(GraphObj *graph, Tensor input, Tensor output, int _axis)
std::string SoftmaxObj::toString() const {
std::ostringstream os;
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
os << type.toString() << "[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "input=" << inputs[0]->getGuid() << ",";
@ -25,13 +25,13 @@ std::string SoftmaxObj::toString() const {
}
vector<int> SoftmaxObj::getWorkloadVector() const {
vector<int> ret{enum_to_underlying(type), axis};
vector<int> ret{type.underlying(), axis};
const Shape shape = outputs[0]->getDims();
ret.insert(ret.end(), shape.begin(), shape.end());
return ret;
}
vector<int> SoftmaxObj::getOpAttrVector() const {
return {enum_to_underlying(type), axis};
return {type.underlying(), axis};
}
} // namespace infini

View File

@ -56,14 +56,14 @@ optional<vector<Shape>> SplitObj::inferShape(const TensorVec &inputs) const {
vector<int> SplitObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims();
ret.emplace(ret.begin(), enum_to_underlying(type));
ret.emplace(ret.begin(), type.underlying());
ret.emplace_back(dim);
ret.emplace_back(num);
return ret;
}
vector<int> SplitObj::getOpAttrVector() const {
return {enum_to_underlying(type), dim, num};
return {type.underlying(), dim, num};
}
string SplitObj::toString() const {

View File

@ -28,7 +28,7 @@ TransposeObj::inferShape(const TensorVec &inputs) const {
std::string TransposeObj::toString() const {
std::ostringstream os;
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
os << type.toString() << "[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "input=" << inputs[0]->getGuid() << ",";
@ -37,14 +37,14 @@ std::string TransposeObj::toString() const {
}
vector<int> TransposeObj::getWorkloadVector() const {
vector<int> ret{enum_to_underlying(type)};
vector<int> ret{type.underlying()};
const Shape shape = outputs[0]->getDims();
ret.insert(ret.end(), shape.begin(), shape.end());
return ret;
}
vector<int> TransposeObj::getOpAttrVector() const {
return {enum_to_underlying(type)};
return {type.underlying()};
}
}; // namespace infini

View File

@ -13,7 +13,7 @@ optional<vector<Shape>> UnaryObj::inferShape(const TensorVec &inputs) const {
std::string UnaryObj::toString() const {
std::ostringstream os;
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
os << type.toString() << "[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "input=" << inputs[0]->getGuid() << ",";
@ -22,15 +22,13 @@ std::string UnaryObj::toString() const {
}
vector<int> UnaryObj::getWorkloadVector() const {
vector<int> ret{enum_to_underlying(type)};
vector<int> ret{type.underlying()};
const Shape shape = outputs[0]->getDims();
ret.insert(ret.end(), shape.begin(), shape.end());
return ret;
}
vector<int> UnaryObj::getOpAttrVector() const {
return {enum_to_underlying(type)};
}
vector<int> UnaryObj::getOpAttrVector() const { return {type.underlying()}; }
ClipObj::ClipObj(GraphObj *graph, Tensor input, Tensor output,
std::optional<float> min, std::optional<float> max)
@ -46,7 +44,7 @@ optional<vector<Shape>> ClipObj::inferShape(const TensorVec &inputs) const {
std::string ClipObj::toString() const {
std::ostringstream os;
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
os << type.toString() << "[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "input=" << inputs[0]->getGuid() << ",";
@ -55,15 +53,13 @@ std::string ClipObj::toString() const {
}
vector<int> ClipObj::getWorkloadVector() const {
vector<int> ret{enum_to_underlying(type)};
vector<int> ret{type.underlying()};
const Shape shape = outputs[0]->getDims();
ret.insert(ret.end(), shape.begin(), shape.end());
return ret;
}
vector<int> ClipObj::getOpAttrVector() const {
return {enum_to_underlying(type)};
}
vector<int> ClipObj::getOpAttrVector() const { return {type.underlying()}; }
HardtanhObj::HardtanhObj(GraphObj *graph, Tensor input, Tensor output,
float min, float max)
@ -79,7 +75,7 @@ optional<vector<Shape>> HardtanhObj::inferShape(const TensorVec &inputs) const {
std::string HardtanhObj::toString() const {
std::ostringstream os;
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
os << type.toString() << "[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "input=" << inputs[0]->getGuid() << ",";
@ -88,15 +84,13 @@ std::string HardtanhObj::toString() const {
}
vector<int> HardtanhObj::getWorkloadVector() const {
vector<int> ret{enum_to_underlying(type)};
vector<int> ret{type.underlying()};
const Shape shape = outputs[0]->getDims();
ret.insert(ret.end(), shape.begin(), shape.end());
return ret;
}
vector<int> HardtanhObj::getOpAttrVector() const {
return {enum_to_underlying(type)};
}
vector<int> HardtanhObj::getOpAttrVector() const { return {type.underlying()}; }
FillObj::FillObj(GraphObj *graph, Tensor input, Tensor output, float value)
: OperatorObj(OpType::Fill, {input}, {output}), setValue(value) {
@ -110,22 +104,20 @@ optional<vector<Shape>> FillObj::inferShape(const TensorVec &inputs) const {
std::string FillObj::toString() const {
std::ostringstream os;
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
os << type.toString() << "[" << getGuid() << "]";
os << "(";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
}
vector<int> FillObj::getWorkloadVector() const {
vector<int> ret{enum_to_underlying(type)};
vector<int> ret{type.underlying()};
const Shape shape = outputs[0]->getDims();
ret.insert(ret.end(), shape.begin(), shape.end());
return ret;
}
vector<int> FillObj::getOpAttrVector() const {
return {enum_to_underlying(type)};
}
vector<int> FillObj::getOpAttrVector() const { return {type.underlying()}; }
L2LossObj::L2LossObj(GraphObj *graph, Tensor input, Tensor output)
: OperatorObj(OpType::L2Loss, {input}, {output}) {
@ -139,22 +131,20 @@ optional<vector<Shape>> L2LossObj::inferShape(const TensorVec &inputs) const {
std::string L2LossObj::toString() const {
std::ostringstream os;
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
os << type.toString() << "[" << getGuid() << "]";
os << "(";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
}
vector<int> L2LossObj::getWorkloadVector() const {
vector<int> ret{enum_to_underlying(type)};
vector<int> ret{type.underlying()};
const Shape shape = outputs[0]->getDims();
ret.insert(ret.end(), shape.begin(), shape.end());
return ret;
}
vector<int> L2LossObj::getOpAttrVector() const {
return {enum_to_underlying(type)};
}
vector<int> L2LossObj::getOpAttrVector() const { return {type.underlying()}; }
CastObj::CastObj(GraphObj *graph, Tensor input, Tensor output, CastType type)
: OperatorObj(OpType::Cast, {input}, {output}), castType(type) {
@ -176,22 +166,20 @@ optional<vector<Shape>> CastObj::inferShape(const TensorVec &inputs) const {
std::string CastObj::toString() const {
std::ostringstream os;
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
os << type.toString() << "[" << getGuid() << "]";
os << "(";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
}
vector<int> CastObj::getWorkloadVector() const {
vector<int> ret{enum_to_underlying(type)};
vector<int> ret{type.underlying()};
const Shape shape = outputs[0]->getDims();
ret.insert(ret.end(), shape.begin(), shape.end());
return ret;
}
vector<int> CastObj::getOpAttrVector() const {
return {enum_to_underlying(type)};
}
vector<int> CastObj::getOpAttrVector() const { return {type.underlying()}; }
DataType CastObj::getOutputDataType() const {
switch (castType) {
@ -251,7 +239,7 @@ optional<vector<Shape>> ShapeObj::inferShape(const TensorVec &inputs) const {
std::string ShapeObj::toString() const {
std::ostringstream os;
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]("
os << type.toString() << "[" << getGuid() << "]("
<< "output=" << outputs[0]->getGuid() << ")";
return os.str();
}
@ -268,7 +256,7 @@ optional<vector<Shape>> PReluObj::inferShape(const TensorVec &inputs) const {
std::string PReluObj::toString() const {
std::ostringstream os;
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
os << type.toString() << "[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "input=" << inputs[0]->getGuid() << ",";
@ -277,15 +265,13 @@ std::string PReluObj::toString() const {
}
vector<int> PReluObj::getWorkloadVector() const {
vector<int> ret{enum_to_underlying(type)};
vector<int> ret{type.underlying()};
const Shape shape = outputs[0]->getDims();
ret.insert(ret.end(), shape.begin(), shape.end());
return ret;
}
vector<int> PReluObj::getOpAttrVector() const {
return {enum_to_underlying(type)};
}
vector<int> PReluObj::getOpAttrVector() const { return {type.underlying()}; }
LogObj::LogObj(GraphObj *graph, Tensor input, Tensor output, LogType type)
: OperatorObj(OpType::Log, {input}, {output}), logType(type) {
@ -299,21 +285,19 @@ optional<vector<Shape>> LogObj::inferShape(const TensorVec &inputs) const {
std::string LogObj::toString() const {
std::ostringstream os;
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
os << type.toString() << "[" << getGuid() << "]";
os << "(";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
}
vector<int> LogObj::getWorkloadVector() const {
vector<int> ret{enum_to_underlying(type)};
vector<int> ret{type.underlying()};
const Shape shape = outputs[0]->getDims();
ret.insert(ret.end(), shape.begin(), shape.end());
return ret;
}
vector<int> LogObj::getOpAttrVector() const {
return {enum_to_underlying(type)};
}
vector<int> LogObj::getOpAttrVector() const { return {type.underlying()}; }
}; // namespace infini

View File

@ -1,40 +0,0 @@
#include "bang/bang_runtime.h"
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/unary.h"
#include "test.h"
namespace infini {
template <class T>
void testCopy(const std::function<void(void *, size_t, DataType)> &generator,
const Shape &shape) {
// Runtime
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
auto bangRuntime = make_ref<BangRuntimeObj>();
// Build input data on CPU
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
inputCpu->dataMalloc();
inputCpu->setData(generator);
// GPU
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
auto inputGpu = bangGraph->cloneTensor(inputCpu);
auto gpuOp = bangGraph->addOp<T>(inputGpu, nullptr);
bangGraph->dataMalloc();
bangRuntime->run(bangGraph);
auto outputGpu = gpuOp->getOutput();
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
inputCpu->printData();
outputGpu2Cpu->printData();
EXPECT_TRUE(outputGpu2Cpu->equalData(inputCpu));
}
TEST(cnnl_Copy, run) {
testCopy<CopyObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
}
} // namespace infini

View File

@ -1,49 +0,0 @@
#include "bang/bang_runtime.h"
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/element_wise.h"
#include "test.h"
namespace infini {
template <class T>
void testFloorDiv(
const std::function<void(void *, size_t, DataType)> &generator,
const Shape &shape) {
// Runtime
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
auto bangRuntime = make_ref<BangRuntimeObj>();
// Build input data on CPU
Tensor inputCpu1 =
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
inputCpu1->dataMalloc();
inputCpu1->setData(generator);
Tensor inputCpu2 =
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
inputCpu2->dataMalloc();
inputCpu2->setData(generator);
// GPU
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
auto inputGpu1 = bangGraph->cloneTensor(inputCpu1);
auto inputGpu2 = bangGraph->cloneTensor(inputCpu2);
auto gpuOp = bangGraph->addOp<T>(inputGpu1, inputGpu2, nullptr);
bangGraph->dataMalloc();
bangRuntime->run(bangGraph);
auto outputGpu = gpuOp->getOutput();
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
// Check
inputCpu1->printData();
inputCpu2->printData();
outputGpu2Cpu->printData();
EXPECT_TRUE(1);
}
TEST(cnnl_FloorDiv, run) {
testFloorDiv<FloorDivObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
}
} // namespace infini

View File

@ -1,49 +0,0 @@
#include "bang/bang_runtime.h"
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/element_wise.h"
#include "test.h"
namespace infini {
template <class T>
void testFloorMod(
const std::function<void(void *, size_t, DataType)> &generator,
const Shape &shape) {
// Runtime
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
auto bangRuntime = make_ref<BangRuntimeObj>();
// Build input data on CPU
Tensor inputCpu1 =
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
inputCpu1->dataMalloc();
inputCpu1->setData(generator);
Tensor inputCpu2 =
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
inputCpu2->dataMalloc();
inputCpu2->setData(generator);
// GPU
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
auto inputGpu1 = bangGraph->cloneTensor(inputCpu1);
auto inputGpu2 = bangGraph->cloneTensor(inputCpu2);
auto gpuOp = bangGraph->addOp<T>(inputGpu1, inputGpu2, nullptr);
bangGraph->dataMalloc();
bangRuntime->run(bangGraph);
auto outputGpu = gpuOp->getOutput();
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
// Check
inputCpu1->printData();
inputCpu2->printData();
outputGpu2Cpu->printData();
EXPECT_TRUE(1);
}
TEST(cnnl_FloorMod, run) {
testFloorMod<FloorModObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
}
} // namespace infini

View File

@ -42,7 +42,6 @@ void testLogicOp(const std::function<void(void *, size_t, DataType)> &generator,
TEST(cnnl_LogicOp, run) {
testLogicOp<EqualObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testLogicOp<NotEqualObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testLogicOp<GreaterThanObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testLogicOp<GreaterEqualObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testLogicOp<LessThanObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});

View File

@ -1,40 +0,0 @@
#include "bang/bang_runtime.h"
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/unary.h"
#include "test.h"
namespace infini {
template <class T>
void testRsqrt(const std::function<void(void *, size_t, DataType)> &generator,
const Shape &shape) {
// Runtime
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
auto bangRuntime = make_ref<BangRuntimeObj>();
// Build input data on CPU
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
inputCpu->dataMalloc();
inputCpu->setData(generator);
// GPU
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
auto inputGpu = bangGraph->cloneTensor(inputCpu);
auto gpuOp = bangGraph->addOp<T>(inputGpu, nullptr);
bangGraph->dataMalloc();
bangRuntime->run(bangGraph);
auto outputGpu = gpuOp->getOutput();
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
inputCpu->printData();
outputGpu2Cpu->printData();
EXPECT_TRUE(1);
}
TEST(cnnl_Rsqrt, run) {
testRsqrt<RsqrtObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
}
} // namespace infini

View File

@ -1,40 +0,0 @@
#include "bang/bang_runtime.h"
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/unary.h"
#include "test.h"
namespace infini {
template <class T>
void testSquare(const std::function<void(void *, size_t, DataType)> &generator,
const Shape &shape) {
// Runtime
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
auto bangRuntime = make_ref<BangRuntimeObj>();
// Build input data on CPU
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
inputCpu->dataMalloc();
inputCpu->setData(generator);
// GPU
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
auto inputGpu = bangGraph->cloneTensor(inputCpu);
auto gpuOp = bangGraph->addOp<T>(inputGpu, nullptr);
bangGraph->dataMalloc();
bangRuntime->run(bangGraph);
auto outputGpu = gpuOp->getOutput();
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
inputCpu->printData();
outputGpu2Cpu->printData();
EXPECT_TRUE(1);
}
TEST(cnnl_Square, run) {
testSquare<SquareObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
}
} // namespace infini

View File

@ -1,48 +0,0 @@
#include "bang/bang_runtime.h"
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/element_wise.h"
#include "test.h"
namespace infini {
template <class T>
void testSquaredDifference(
const std::function<void(void *, size_t, DataType)> &generator,
const Shape &shape) {
// Runtime
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
auto bangRuntime = make_ref<BangRuntimeObj>();
// Build input data on CPU
Tensor inputCpu1 =
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
inputCpu1->dataMalloc();
inputCpu1->setData(generator);
Tensor inputCpu2 =
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
inputCpu2->dataMalloc();
inputCpu2->setData(generator);
// GPU
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
auto inputGpu1 = bangGraph->cloneTensor(inputCpu1);
auto inputGpu2 = bangGraph->cloneTensor(inputCpu2);
auto gpuOp = bangGraph->addOp<T>(inputGpu1, inputGpu2, nullptr);
bangGraph->dataMalloc();
bangRuntime->run(bangGraph);
auto outputGpu = gpuOp->getOutput();
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
// Check
outputGpu2Cpu->printData();
EXPECT_TRUE(1);
}
TEST(cnnl_SquaredDifference, run) {
testSquaredDifference<SquaredDifferenceObj>(IncrementalGenerator(),
Shape{1, 2, 2, 3});
}
} // namespace infini

View File

@ -152,8 +152,8 @@ TEST(cuDNN_ConvTransposed, tune) {
bool tune = true;
cuda->run(gCuda, tune);
// check record
auto kernelAttrs =
KernelAttrs{Device::CUDA, conv->getOpType(), DataType::Float32};
auto kernelAttrs = KernelAttrs{Device::CUDA, conv->getOpType().underlying(),
DataType::Float32};
auto perfKey = PerfEngine::Key{kernelAttrs, conv->getOpPerfKey()};
std::optional<PerfRecord> perfData =
PerfEngine::getInstance().getPerfData(perfKey);

View File

@ -53,8 +53,8 @@ TEST(mkl_Conv, tune) {
mklRuntime->run(gMkl, tune);
// check record
auto kernelAttrs =
KernelAttrs{Device::INTELCPU, conv->getOpType(), DataType::Float32};
auto kernelAttrs = KernelAttrs{
Device::INTELCPU, conv->getOpType().underlying(), DataType::Float32};
auto perfKey = PerfEngine::Key{kernelAttrs, conv->getOpPerfKey()};
std::optional<PerfRecord> perfData =
PerfEngine::getInstance().getPerfData(perfKey);

View File

@ -73,8 +73,8 @@ TEST(mkl_ConvTransposed, tune) {
bool tune = true;
runtime->run(gMkl, tune);
// check record
auto kernelAttrs =
KernelAttrs{Device::INTELCPU, conv->getOpType(), DataType::Float32};
auto kernelAttrs = KernelAttrs{
Device::INTELCPU, conv->getOpType().underlying(), DataType::Float32};
auto perfKey = PerfEngine::Key{kernelAttrs, conv->getOpPerfKey()};
std::optional<PerfRecord> perfData =
PerfEngine::getInstance().getPerfData(perfKey);

View File

@ -4,7 +4,7 @@
#include "test.h"
namespace infini {
TEST(BatchNorm, ShapeInference) {
TEST(BatchNormalization, ShapeInference) {
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
{
Graph g = make_ref<GraphObj>(cpuRuntime);