- cpu kernel: adapt the new registration mechanism

This commit is contained in:
kilinchange 2023-12-05 10:49:28 +08:00
parent c19256bca6
commit c587901586
9 changed files with 516 additions and 314 deletions

View File

@ -3,9 +3,9 @@
namespace infini {
template <typename T> class NaiveConcat : public CpuKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *context) const override {
class NaiveConcat : public CpuKernelWithoutConfig {
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<ConcatObj>(_op);
auto inputs = op->getInputs(), outputs = op->getOutputs();
auto dim = op->getDim();
@ -41,11 +41,25 @@ template <typename T> class NaiveConcat : public CpuKernelWithoutConfig {
}
}
}
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
REGISTER_KERNEL(Device::CPU, OpType::Concat, DataType::UInt32,
NaiveConcat<uint32_t>, "ConcatNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Concat, DataType::Float32,
NaiveConcat<float>, "ConcatNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Concat, NaiveConcat, "ConcatNaive_CPU");
} // namespace infini

View File

@ -3,9 +3,9 @@
namespace infini {
template <typename T> class NaiveConv : public CpuKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *context) const override {
class NaiveConv : public CpuKernelWithoutConfig {
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<ConvObj>(_op);
T *iptr = op->getInputs(0)->getRawDataPtr<T *>();
T *wptr = op->getInputs(1)->getRawDataPtr<T *>();
@ -50,11 +50,25 @@ template <typename T> class NaiveConv : public CpuKernelWithoutConfig {
}
}
}
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
REGISTER_KERNEL(Device::CPU, OpType::Conv, DataType::UInt32,
NaiveConv<uint32_t>, "ConvNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Conv, DataType::Float32, NaiveConv<float>,
"ConvNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Conv, NaiveConv, "ConvNaive_CPU");
} // namespace infini

View File

@ -2,10 +2,45 @@
#include "core/kernel.h"
namespace infini {
template <typename T> class NativeElementWise : public CpuKernelWithoutConfig {
virtual T doCompute(T val0, T val1) const = 0;
void compute(const Operator &_op,
const RuntimeObj *context) const override {
class NativeElementWise : public CpuKernelWithoutConfig {
template <typename T> static T addCompute(T val0, T val1) {
return val0 + val1;
}
template <typename T> static T subCompute(T val0, T val1) {
return val0 - val1;
}
template <typename T> static T mulCompute(T val0, T val1) {
return val0 * val1;
}
template <typename T> static T divCompute(T val0, T val1) {
return (T)(val0 / val1);
}
template <typename T> static T equalCompute(T val0, T val1) {
return (T)(val0 == val1);
}
template <typename T> static T greaterOrEqualCompute(T val0, T val1) {
return (T)(val0 >= val1);
}
template <typename T> static T greaterCompute(T val0, T val1) {
return (T)(val0 > val1);
}
template <typename T> static T lessOrEqualCompute(T val0, T val1) {
return (T)(val0 <= val1);
}
template <typename T> static T lessCompute(T val0, T val1) {
return (T)(val0 < val1);
}
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<ElementWiseObj>(_op);
T *inptr0 = op->getInputs(0)->getRawDataPtr<T *>();
T *inptr1 = op->getInputs(1)->getRawDataPtr<T *>();
@ -22,6 +57,39 @@ template <typename T> class NativeElementWise : public CpuKernelWithoutConfig {
std::copy(c_output.begin(), c_output.end(), c + (4 - c_output.size()));
auto n = op->getOutput()->size();
T (*_doCompute)(T val0, T val1);
switch (op->getOpType().underlying()) {
case OpType::Add:
_doCompute = addCompute<T>;
break;
case OpType::Sub:
_doCompute = subCompute<T>;
break;
case OpType::Mul:
_doCompute = mulCompute<T>;
break;
case OpType::Div:
_doCompute = divCompute<T>;
break;
case OpType::Equal:
_doCompute = equalCompute<T>;
break;
case OpType::GreaterOrEqual:
_doCompute = greaterOrEqualCompute<T>;
break;
case OpType::Greater:
_doCompute = greaterCompute<T>;
break;
case OpType::LessOrEqual:
_doCompute = lessOrEqualCompute<T>;
break;
case OpType::Less:
_doCompute = lessCompute<T>;
break;
default:
IT_TODO_HALT();
}
for (size_t i = 0; i < n; ++i) {
int c0_index = i / (c[1] * c[2] * c[3]);
int c1_index = (i % (c[1] * c[2] * c[3])) / (c[2] * c[3]);
@ -37,77 +105,44 @@ template <typename T> class NativeElementWise : public CpuKernelWithoutConfig {
int b1_index = c1_index % b[1];
int b2_index = c2_index % b[2];
int b3_index = c3_index % b[3];
outptr[i] = doCompute(
outptr[i] = _doCompute(
inptr0[a0_index * a[1] * a[2] * a[3] + a1_index * a[2] * a[3] +
a2_index * a[3] + a3_index],
inptr1[b0_index * b[1] * b[2] * b[3] + b1_index * b[2] * b[3] +
b2_index * b[3] + b3_index]);
}
}
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
template <typename T> class NaiveAdd : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return val0 + val1; }
};
template <typename T> class NaiveSub : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return val0 - val1; }
};
template <typename T> class NaiveMul : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return val0 * val1; }
};
template <typename T> class NaiveDiv : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return (T)(val0 / val1); }
};
template <typename T> class NaiveEqual : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return (T)(val0 == val1); }
};
template <typename T> class NaiveGreaterEqual : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return (T)(val0 >= val1); }
};
template <typename T> class NaiveGreaterThan : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return (T)(val0 > val1); }
};
template <typename T> class NaiveLessEqual : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return (T)(val0 <= val1); }
};
template <typename T> class NaiveLessThan : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return (T)(val0 < val1); }
};
REGISTER_KERNEL(Device::CPU, OpType::Add, DataType::UInt32, NaiveAdd<uint32_t>,
"addNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Add, DataType::Float32, NaiveAdd<float>,
"addNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Sub, DataType::UInt32, NaiveSub<uint32_t>,
"subNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Sub, DataType::Float32, NaiveSub<float>,
"subNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Mul, DataType::UInt32, NaiveMul<uint32_t>,
"mulNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Mul, DataType::Float32, NaiveMul<float>,
"mulNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Div, DataType::UInt32, NaiveDiv<uint32_t>,
"divNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Div, DataType::Float32, NaiveDiv<float>,
"divNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Equal, DataType::UInt32,
NaiveEqual<uint32_t>, "equalNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Equal, DataType::Float32,
NaiveEqual<float>, "equalNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::GreaterOrEqual, DataType::UInt32,
NaiveGreaterEqual<uint32_t>, "greaterEqualNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::GreaterOrEqual, DataType::Float32,
NaiveGreaterEqual<float>, "greaterEqualNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Greater, DataType::UInt32,
NaiveGreaterThan<uint32_t>, "greaterThanNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Greater, DataType::Float32,
NaiveGreaterThan<float>, "greaterThanNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::LessOrEqual, DataType::UInt32,
NaiveLessEqual<uint32_t>, "lessEqualNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::LessOrEqual, DataType::Float32,
NaiveLessEqual<float>, "lessEqualNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Less, DataType::UInt32,
NaiveLessThan<uint32_t>, "lessEqualNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Less, DataType::Float32,
NaiveLessThan<float>, "lessEqualNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Add, NativeElementWise, "addNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Sub, NativeElementWise, "subNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Mul, NativeElementWise, "mulNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Div, NativeElementWise, "divNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Equal, NativeElementWise,
"equalNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::GreaterOrEqual, NativeElementWise,
"greaterEqualNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Greater, NativeElementWise,
"greaterThanNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::LessOrEqual, NativeElementWise,
"lessEqualNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Less, NativeElementWise,
"lessEqualNaive_CPU");
}; // namespace infini

View File

@ -3,9 +3,9 @@
namespace infini {
template <typename T> class NaiveMatmul : public CpuKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *context) const override {
class NaiveMatmul : public CpuKernelWithoutConfig {
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<MatmulObj>(_op);
IT_ASSERT(op->getInputs().size() == 2, "Bias is not supported yet.");
T *A = op->getInputs(0)->getRawDataPtr<T *>();
@ -23,11 +23,25 @@ template <typename T> class NaiveMatmul : public CpuKernelWithoutConfig {
}
}
}
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
REGISTER_KERNEL(Device::CPU, OpType::MatMul, DataType::UInt32,
NaiveMatmul<uint32_t>, "MatmulNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::MatMul, DataType::Float32,
NaiveMatmul<float>, "MatmulNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::MatMul, NaiveMatmul, "MatmulNaive_CPU");
} // namespace infini

View File

@ -80,8 +80,8 @@ class MemboundInterpreter : public Kernel {
}
};
REGISTER_KERNEL(Device::CPU, OpType::MemBound, DataType::UInt32,
MemboundInterpreter, "MemboundInterpreter_CPU");
REGISTER_KERNEL(Device::CPU, OpType::MemBound, MemboundInterpreter,
"MemboundInterpreter_CPU");
} // namespace infini

View File

@ -2,42 +2,10 @@
#include "core/kernel.h"
namespace infini {
template <typename T> class NativePooling : public CpuKernelWithoutConfig {
virtual T getPoolingValue(int kh, int kw, int posh, int posw, int ih,
int iw, T *inptr) const = 0;
void compute(const Operator &_op,
const RuntimeObj *context) const override {
auto op = as<PoolingObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
T *outptr = op->getOutput()->getRawDataPtr<T *>();
const auto [n, c, ih, iw, kh, kw] = op->getNCHWRS();
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
if (dh != 1 || dw != 1)
IT_TODO_HALT(); // To support dailated pooling
auto outDim = op->getOutput()->getDims();
int oh = outDim[2], ow = outDim[3];
for (auto i = 0; i < n; i++) {
for (auto j = 0; j < c; j++) {
auto inoffset = i * (c * ih * iw) + j * ih * iw;
for (auto h = 0; h < oh; h++) {
for (auto w = 0; w < ow; w++) {
// TODO: verify ceil mode
T val =
getPoolingValue(kh, kw, h * sh - ph, w * sw - pw,
ih, iw, inptr + inoffset);
auto outoffset =
w + h * ow + j * (oh * ow) + i * (c * oh * ow);
outptr[outoffset] = val;
}
}
}
}
}
};
template <typename T> class NaiveMaxPool : public NativePooling<T> {
T getPoolingValue(int kh, int kw, int posh, int posw, int ih, int iw,
T *inptr) const override {
class NativePooling : public CpuKernelWithoutConfig {
template <typename T>
static T getMaxPoolingValue(int kh, int kw, int posh, int posw, int ih,
int iw, T *inptr) {
T maxval = 0;
for (auto k = 0; k < kh; k++) {
for (auto l = 0; l < kw; l++) {
@ -53,11 +21,10 @@ template <typename T> class NaiveMaxPool : public NativePooling<T> {
}
return maxval;
}
};
template <typename T> class NaiveAvgPool : public NativePooling<T> {
T getPoolingValue(int kh, int kw, int posh, int posw, int ih, int iw,
T *inptr) const override {
template <typename T>
static T getAvgPoolingValue(int kh, int kw, int posh, int posw, int ih,
int iw, T *inptr) {
T sum = 0;
for (auto k = 0; k < kh; k++) {
for (auto l = 0; l < kw; l++) {
@ -71,12 +38,70 @@ template <typename T> class NaiveAvgPool : public NativePooling<T> {
}
return T(sum / (kh * kw));
}
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<PoolingObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
T *outptr = op->getOutput()->getRawDataPtr<T *>();
const auto [n, c, ih, iw, kh, kw] = op->getNCHWRS();
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
if (dh != 1 || dw != 1)
IT_TODO_HALT(); // To support dailated pooling
auto outDim = op->getOutput()->getDims();
int oh = outDim[2], ow = outDim[3];
T(*_doCompute)
(int kh, int kw, int posh, int posw, int ih, int iw, T *inptr);
switch (op->getOpType().underlying()) {
case OpType::MaxPool:
_doCompute = getMaxPoolingValue<T>;
break;
case OpType::AveragePool:
_doCompute = getAvgPoolingValue<T>;
break;
default:
IT_TODO_HALT();
}
for (auto i = 0; i < n; i++) {
for (auto j = 0; j < c; j++) {
auto inoffset = i * (c * ih * iw) + j * ih * iw;
for (auto h = 0; h < oh; h++) {
for (auto w = 0; w < ow; w++) {
// TODO: verify ceil mode
T val = _doCompute(kh, kw, h * sh - ph, w * sw - pw, ih,
iw, inptr + inoffset);
auto outoffset =
w + h * ow + j * (oh * ow) + i * (c * oh * ow);
outptr[outoffset] = val;
}
}
}
}
}
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
REGISTER_KERNEL(Device::CPU, OpType::MaxPool, DataType::UInt32,
NaiveMaxPool<uint32_t>, "maxPoolNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::MaxPool, DataType::Float32,
NaiveMaxPool<float>, "maxPoolNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::AveragePool, DataType::Float32,
NaiveAvgPool<float>, "AvgPoolNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::MaxPool, NativePooling,
"maxPoolNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::AveragePool, NativePooling,
"avgPoolNaive_CPU");
} // namespace infini

View File

@ -3,9 +3,9 @@
namespace infini {
template <typename T> class NaiveSplit : public CpuKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *context) const override {
class NaiveSplit : public CpuKernelWithoutConfig {
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<SplitObj>(_op);
auto inputs = op->getInputs(), outputs = op->getOutputs();
auto dim = op->getDim();
@ -40,11 +40,24 @@ template <typename T> class NaiveSplit : public CpuKernelWithoutConfig {
}
}
}
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
REGISTER_KERNEL(Device::CPU, OpType::Split, DataType::UInt32,
NaiveSplit<uint32_t>, "SplitNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Split, DataType::Float32,
NaiveSplit<float>, "SplitNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Split, NaiveSplit, "SplitNaive_CPU");
} // namespace infini

View File

@ -14,9 +14,9 @@ inline Shape idx2Pos(const Shape &shape, size_t idx) {
return pos;
}
template <typename T> class NaiveTranspose : public CpuKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *context) const override {
class NaiveTranspose : public CpuKernelWithoutConfig {
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<TransposeObj>(_op);
auto inputs = op->getInputs(), outputs = op->getOutputs();
const auto &inDim = inputs[0]->getDims();
@ -35,11 +35,26 @@ template <typename T> class NaiveTranspose : public CpuKernelWithoutConfig {
outPtr[outIdx] = inPtr[inIdx];
}
}
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
REGISTER_KERNEL(Device::CPU, OpType::Transpose, DataType::UInt32,
NaiveTranspose<uint32_t>, "TransposeNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Transpose, DataType::Float32,
NaiveTranspose<float>, "TransposeNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Transpose, NaiveTranspose,
"TransposeNaive_CPU");
} // namespace infini

View File

@ -4,25 +4,170 @@
#include "operators/softmax.h"
namespace infini {
template <typename T> class NativeUnary : public CpuKernelWithoutConfig {
virtual T doCompute(T val) const = 0;
void compute(const Operator &_op,
const RuntimeObj *context) const override {
class NativeUnary : public CpuKernelWithoutConfig {
template <typename T> static T reluCompute(T val) {
return std::max(T(0), val);
}
template <typename T> static T sigmoidCompute(T val) {
return 1 / (1 + pow(E_CONSTANT, -val));
}
template <typename T> static T hardSigmoidCompute(T val) {
return std::max(T(0), std::min(T(1), T(0.2) * val + T(0.5)));
}
template <typename T> static T hardSwishCompute(T val) {
return val *
std::max(T(0), std::min(T(1), val * T(1.0 / 6.0) + T(0.5)));
}
template <typename T> static T tanhCompute(T val) {
return (pow(E_CONSTANT, val) - pow(E_CONSTANT, -val)) /
(pow(E_CONSTANT, val) + pow(E_CONSTANT, -val));
}
template <typename T> static T absCompute(T val) {
return val < 0 ? -val : val;
}
template <typename T> static T sqrtCompute(T val) { return std::sqrt(val); }
template <typename T> static T cosCompute(T val) { return std::cos(val); }
template <typename T> static T sinCompute(T val) { return std::sin(val); }
template <typename T> static T tanCompute(T val) { return std::tan(val); }
template <typename T> static T sinhCompute(T val) { return std::sinh(val); }
template <typename T> static T coshCompute(T val) { return std::cosh(val); }
template <typename T> static T geluCompute(T val) {
return 0.5 * val * (1 + std::erf(val / std::sqrt(2)));
}
template <typename T> static T erfCompute(T val) { return std::erf(val); }
template <typename T> static T aCosCompute(T val) { return std::acos(val); }
template <typename T> static T aCoshCompute(T val) {
return std::acosh(val);
}
template <typename T> static T aSinCompute(T val) { return std::asin(val); }
template <typename T> static T aSinhCompute(T val) {
return std::asinh(val);
}
template <typename T> static T aTanCompute(T val) { return std::atan(val); }
template <typename T> static T aTanhCompute(T val) {
return std::atanh(val);
}
template <typename T> static T negCompute(T val) { return -val; }
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<UnaryObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
T *outptr = op->getOutput()->getRawDataPtr<T *>();
auto outDim = op->getOutput()->getDims();
auto n = op->getOutput()->size();
T (*_doCompute)(T val);
switch (op->getOpType().underlying()) {
case OpType::Relu:
_doCompute = reluCompute<T>;
break;
case OpType::Gelu:
_doCompute = geluCompute<T>;
break;
case OpType::Sigmoid:
_doCompute = sigmoidCompute<T>;
break;
case OpType::HardSigmoid:
_doCompute = hardSigmoidCompute<T>;
break;
case OpType::HardSwish:
_doCompute = hardSwishCompute<T>;
break;
case OpType::Tanh:
_doCompute = tanhCompute<T>;
break;
case OpType::Abs:
_doCompute = absCompute<T>;
break;
case OpType::Sqrt:
_doCompute = sqrtCompute<T>;
break;
case OpType::Erf:
_doCompute = erfCompute<T>;
break;
case OpType::Neg:
_doCompute = negCompute<T>;
break;
case OpType::Cos:
_doCompute = cosCompute<T>;
break;
case OpType::Sin:
_doCompute = sinCompute<T>;
break;
case OpType::Tan:
_doCompute = tanCompute<T>;
break;
case OpType::Sinh:
_doCompute = sinhCompute<T>;
break;
case OpType::Cosh:
_doCompute = coshCompute<T>;
break;
case OpType::Acos:
_doCompute = aCosCompute<T>;
break;
case OpType::Asin:
_doCompute = aSinCompute<T>;
break;
case OpType::Asinh:
_doCompute = aSinhCompute<T>;
break;
case OpType::Atan:
_doCompute = aTanCompute<T>;
break;
case OpType::Atanh:
_doCompute = aTanhCompute<T>;
break;
default:
IT_TODO_HALT();
}
for (size_t offset = 0; offset < n; offset++) {
outptr[offset] = doCompute(inptr[offset]);
outptr[offset] = _doCompute(inptr[offset]);
}
}
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
template <typename T> class NaiveSoftmax : public CpuKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *context) const override {
class NaiveSoftmax : public CpuKernelWithoutConfig {
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<SoftmaxObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
T *outptr = op->getOutput()->getRawDataPtr<T *>();
@ -37,98 +182,28 @@ template <typename T> class NaiveSoftmax : public CpuKernelWithoutConfig {
outptr[offset] = pow(E_CONSTANT, inptr[offset]) / sum;
}
}
};
template <typename T> class NaiveRelu : public NativeUnary<T> {
T doCompute(T val) const override { return std::max(T(0), val); }
};
template <typename T> class NaiveSigmoid : public NativeUnary<T> {
T doCompute(T val) const override {
return 1 / (1 + pow(E_CONSTANT, -val));
}
};
template <typename T> class NaiveHardSigmoid : public NativeUnary<T> {
T doCompute(T val) const override {
return std::max(T(0), std::min(T(1), T(0.2) * val + T(0.5)));
}
};
template <typename T> class NaiveHardSwish : public NativeUnary<T> {
T doCompute(T val) const override {
return val *
std::max(T(0), std::min(T(1), val * T(1.0 / 6.0) + T(0.5)));
}
};
template <typename T> class NaiveTanh : public NativeUnary<T> {
T doCompute(T val) const override {
return (pow(E_CONSTANT, val) - pow(E_CONSTANT, -val)) /
(pow(E_CONSTANT, val) + pow(E_CONSTANT, -val));
}
};
template <typename T> class NaiveAbs : public NativeUnary<T> {
T doCompute(T val) const override { return val < 0 ? -val : val; }
};
template <typename T> class NaiveSqrt : public NativeUnary<T> {
T doCompute(T val) const override { return std::sqrt(val); }
};
template <typename T> class NaiveCos : public NativeUnary<T> {
T doCompute(T val) const override { return std::cos(val); }
};
template <typename T> class NaiveSin : public NativeUnary<T> {
T doCompute(T val) const override { return std::sin(val); }
};
template <typename T> class NaiveTan : public NativeUnary<T> {
T doCompute(T val) const override { return std::tan(val); }
};
template <typename T> class NaiveSinh : public NativeUnary<T> {
T doCompute(T val) const override { return std::sinh(val); }
};
template <typename T> class NaiveCosh : public NativeUnary<T> {
T doCompute(T val) const override { return std::cosh(val); }
};
template <typename T> class NaiveGelu : public NativeUnary<T> {
T doCompute(T val) const override {
return 0.5 * val * (1 + std::erf(val / std::sqrt(2)));
}
};
template <typename T> class NaiveErf : public NativeUnary<T> {
T doCompute(T val) const override { return std::erf(val); }
};
template <typename T> class NaiveACos : public NativeUnary<T> {
T doCompute(T val) const override { return std::acos(val); }
};
template <typename T> class NaiveACosh : public NativeUnary<T> {
T doCompute(T val) const override { return std::acosh(val); }
};
template <typename T> class NaiveASin : public NativeUnary<T> {
T doCompute(T val) const override { return std::asin(val); }
};
template <typename T> class NaiveASinh : public NativeUnary<T> {
T doCompute(T val) const override { return std::asinh(val); }
};
template <typename T> class NaiveATanh : public NativeUnary<T> {
T doCompute(T val) const override { return std::atanh(val); }
};
template <typename T> class NaiveNeg : public NativeUnary<T> {
T doCompute(T val) const override { return -val; }
};
template <typename T> class Clip : public CpuKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
class Clip : public CpuKernelWithoutConfig {
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<ClipObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
T *outptr = op->getOutput()->getRawDataPtr<T *>();
@ -143,11 +218,28 @@ template <typename T> class Clip : public CpuKernelWithoutConfig {
: val;
}
}
};
template <typename T> class Log : public CpuKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
class Log : public CpuKernelWithoutConfig {
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<LogObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
T *outptr = op->getOutput()->getRawDataPtr<T *>();
@ -176,70 +268,50 @@ template <typename T> class Log : public CpuKernelWithoutConfig {
}
}
}
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
template <typename T> class NaiveATan : public NativeUnary<T> {
T doCompute(T val) const override { return std::atan(val); }
};
REGISTER_KERNEL(Device::CPU, OpType::Relu, NativeUnary, "reluNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Gelu, NativeUnary, "geluNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, NativeUnary, "sigmoidNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::HardSigmoid, NativeUnary,
"hardSigmoidNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::HardSwish, NativeUnary,
"hardSwishNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Tanh, NativeUnary, "tanhNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Abs, NativeUnary, "absNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Sqrt, NativeUnary, "sqrtNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Erf, NativeUnary, "erfNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Neg, NativeUnary, "negNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Cos, NativeUnary, "Cos_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Sin, NativeUnary, "Sin_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Tan, NativeUnary, "Tan_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Sinh, NativeUnary, "Sinh_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Cosh, NativeUnary, "Cosh_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Acos, NativeUnary, "ACos_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Acosh, NativeUnary, "ACosh_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Asin, NativeUnary, "ASin_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Asinh, NativeUnary, "ASinh_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Atan, NativeUnary, "Atan_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Atanh, NativeUnary, "ATanh_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::UInt32,
NaiveRelu<uint32_t>, "reluNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::Float32, NaiveRelu<float>,
"reluNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Gelu, DataType::UInt32, NaiveGelu<float>,
"geluNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Gelu, DataType::Float32, NaiveGelu<float>,
"geluNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::UInt32,
NaiveSigmoid<uint32_t>, "sigmoidNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::Float32,
NaiveSigmoid<float>, "sigmoidNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::HardSigmoid, DataType::Float32,
NaiveHardSigmoid<float>, "hardSigmoidNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::HardSwish, DataType::Float32,
NaiveHardSwish<float>, "hardSwishNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Tanh, DataType::UInt32,
NaiveTanh<uint32_t>, "tanhNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Tanh, DataType::Float32, NaiveTanh<float>,
"tanhNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::UInt32, NaiveAbs<uint32_t>,
"absNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::Float32, NaiveAbs<float>,
"absNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Sqrt, DataType::Float32, NaiveSqrt<float>,
"sqrtNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Erf, DataType::Float32, NaiveErf<float>,
"erfNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Neg, DataType::Float32, NaiveNeg<float>,
"negNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::UInt32,
NaiveSoftmax<uint32_t>, "softmaxNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::Float32,
NaiveSoftmax<float>, "softmaxNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Clip, DataType::Float32, Clip<float>,
"Clip_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Atan, DataType::Float32, NaiveATan<float>,
"Atan_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Log, DataType::Float32, Log<float>,
"Log_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Cos, DataType::Float32, NaiveCos<float>,
"Cos_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Sin, DataType::Float32, NaiveSin<float>,
"Sin_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Tan, DataType::Float32, NaiveTan<float>,
"Tan_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Sinh, DataType::Float32, NaiveSinh<float>,
"Sinh_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Cosh, DataType::Float32, NaiveCosh<float>,
"Cosh_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Acos, DataType::Float32, NaiveACos<float>,
"ACos_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Acosh, DataType::Float32,
NaiveACosh<float>, "ACosh_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Asin, DataType::Float32, NaiveASin<float>,
"ASin_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Asinh, DataType::Float32,
NaiveASinh<float>, "ASinh_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Atanh, DataType::Float32,
NaiveATanh<float>, "ATanh_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Softmax, NaiveSoftmax, "softmaxNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Clip, Clip, "Clip_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Log, Log, "Log_CPU");
}; // namespace infini