Modify kernel registration & support fp16 (#205)

* Remove dataType from the kernel registration

* Support fp16 for conv

* CPU kernels: adapt the new registration mechanism

* Modify all kernel registrations

* Add fp16 support for where

* Add fp16 support for layernorm

* Add fp16 support for split/concat

* Support fp16 for element-wise operators

* feat: support transpose fp16

* feat: support sliceOp fp16

* feat: support unary fp16

* feat: support reduceOp fp16

* feat: support matmulOp/expandOp fp16

* feat: support powOp int8

* add cuda cast & support half-precision for gather

* style: fix style

* feat: support int8 for gather

* style: fix style

* modify test_cuda_conv_transposed

* fix: fix dist code to support fp16

* fix(graph.cc): fix topo_sort

* fix: fix recv and send kernel registration

* feat: add field tensors for stub

* refactor(frontend): sort nodes before building the graph

Signed-off-by: YdrMaster <ydrml@hotmail.com>

* fix: provide a tensor-to-node mapping for intermediate results

* fix(slice): add guard for area out of range

* fix: fix matmul fp16

* fix: fix re-dataMalloc for weight tensor and use of naive allocator

* feat: add dataType filter for cuda kernel

* feat: bang kernel adapt the new registration mechanism

* fix: fix some error on mlu

* feat: intelcpu kernel adapt the new registration mechanism

* feat: modify kernel registration on kunlun

* fix intelcpu compiler bug

* feat: bang reshape support all dataType

* fix: fix bang reduce

* fix(all_reduce.cc): fix as reviewer suggested

* fix: fix style and restore unary test codes

---------

Signed-off-by: YdrMaster <ydrml@hotmail.com>
Co-authored-by: xgqdut2016 <kenan_gewei@163.com>
Co-authored-by: xgqdut2016 <140036308+xgqdut2016@users.noreply.github.com>
Co-authored-by: zhangyunze <z13785159769@163.com>
Co-authored-by: OdinaryWord <sx-hz@163.com>
Co-authored-by: YdrMaster <ydrml@hotmail.com>
Co-authored-by: panzezhong <panzezhong@qiyuanlab.com>
Chenjie Duan 2024-01-15 11:02:13 +08:00 committed by GitHub
parent 58993d4339
commit 51086d2b8d
157 changed files with 3627 additions and 2575 deletions

View File

@@ -137,7 +137,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
         place[node.output[0]] = Shard(list(perm).index(plc.dim))

     def shard_node(node: NodeProto):
-        if node.op_type in ["Relu", "Tanh", "Softmax"]:
+        if node.op_type in ["Relu", "Tanh", "Softmax", "Cast"]:
             place[node.output[0]] = place[node.input[0]]
         elif node.op_type in ["Where"]:
             place[node.output[0]] = place[node.input[1]]
@@ -177,7 +177,14 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
                 input in data for input in node.input
             ):
                 # FIXME(constroy): the last MatMul should not be sharded as TP.
-                if node.output[0] in output:
+                if (
+                    node.output[0] in output
+                    or (
+                        index + 1 < len(model.graph.node)
+                        and model.graph.node[index + 1].output[0]
+                    )
+                    in output
+                ):
                     continue
                 groups = 1
                 # If the Gemm or Matmul is followed by a split, then the inputs are concatinated by groups

View File

@@ -30,7 +30,6 @@ class Kernel {
   public:
     Kernel() {}
    virtual ~Kernel() {}
-
    /**
     * @param op The operator to be executed.
     * @param record The parameters for kernel execution. If extra parameters
@@ -130,15 +129,16 @@ class CpuKernelWithoutConfig : public Kernel {
 } // namespace infini

-#define _REGISTER_KERNEL_1(device, opType, dataType, kernel, name, cnt)       \
+#define _REGISTER_KERNEL_1(device, opType, kernel, name, cnt)                 \
     namespace infini {                                                        \
     static const bool _CAT(_register_kernel_, cnt) =                          \
-        KernelRegistry::getInstance().registerKernel(                         \
-            KernelAttrs{device, opType, dataType}, new kernel(), name);       \
+        KernelRegistry::getInstance().registerKernel(KernelAttrs{device,      \
+                                                                 opType},     \
+                                                     new kernel(), name);     \
     }

-#define REGISTER_KERNEL(device, opType, dataType, kernel, name)               \
-    _REGISTER_KERNEL_1(device, opType, dataType, kernel, name, __COUNTER__)
+#define REGISTER_KERNEL(device, opType, kernel, name)                         \
+    _REGISTER_KERNEL_1(device, opType, kernel, name, __COUNTER__)

 #define _REGISTER_CONSTRUCTOR_1(type, constructor, cnt)                       \
     namespace infini {                                                        \
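For orientation, a registration site under the new macro drops the DataType argument; the kernel reads the operator's dtype at run time and either asserts or dispatches on it. A hedged sketch, not taken from this commit (the CUDA base class and kernel body are illustrative):

// Illustrative only: register once per (device, op); dtype is checked inside
// compute() instead of being part of the registry key.
class ReluCuda : public CudaKernelWithoutConfig {
    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        auto op = as<UnaryObj>(_op);
        // Either assert a supported dtype, as the BANG kernels below do...
        IT_ASSERT(op->getDType() == DataType::Float32 ||
                  op->getDType() == DataType::Float16);
        // ...or branch on it and launch the matching template instantiation.
        unary_kernel(_op);
    }
};
REGISTER_KERNEL(Device::CUDA, OpType::Relu, ReluCuda, "Relu_CUDA");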

View File

@@ -4,7 +4,7 @@
 #include "core/tensor.h"

 namespace infini {
-using KernelAttrs = std::tuple<Device, OpType::underlying_t, DataType>;
+using KernelAttrs = std::tuple<Device, OpType::underlying_t>;

 struct OpPerfKey {
     HashType hash;
@@ -90,6 +90,7 @@ class OperatorObj : public Object {
     OpType getOpType() const { return type; }
     // HACK: set correct data type
     DataType getDType() const { return getInputs(0)->getDType(); }
+    DataType getOutDType() const { return getOutput()->getDType(); }
     virtual int numInputs() const = 0;
     virtual int numOutputs() const = 0;

View File

@@ -44,8 +44,16 @@ class TensorObj : public TensorBaseObj {
     bool isOutput() const { return tensorType == TensorType::output; }
     bool isOthers() const { return tensorType == TensorType::others; }
     void setWeight() { tensorType = TensorType::weight; }
-    void setInput() { tensorType = TensorType::input; }
-    void setOutput() { tensorType = TensorType::output; }
+    void setInput() {
+        if (!this->isWeight()) {
+            tensorType = TensorType::input;
+        }
+    }
+    void setOutput() {
+        if (!this->isWeight()) {
+            tensorType = TensorType::output;
+        }
+    }
     string tensorTypeToString() const {
         switch (tensorType) {
         case TensorType::weight:

View File

@@ -1,13 +1,16 @@
 #pragma once

 namespace infini {
-void div_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
-                int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3);
-void add_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
-                int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3);
-void pow_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
-                int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3);
-void less_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
-                 int b0, int b1, int b2, int b3, int c0, int c1, int c2,
-                 int c3);
+void div_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
+                int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1,
+                int c2, int c3);
+void add_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
+                int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1,
+                int c2, int c3);
+void pow_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
+                int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1,
+                int c2, int c3);
+void less_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
+                 int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1,
+                 int c2, int c3);

 }; // namespace infini

View File

@@ -3,7 +3,8 @@
 #include "operators/unary.h"
 #include "utils/small_array.h"

 namespace infini {
-void expandKernel(float *input, float *output, int nDims, int outputsize,
-                  SmallArray inputShape, SmallArray outputShape);
+void expandKernel(int dType, void *input, void *output, int nDims,
+                  int outputsize, SmallArray inputShape,
+                  SmallArray outputShape);

 }; // namespace infini

View File

@@ -8,4 +8,10 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
 void LaynormKernel(const float *input, const float *scale, const float eps,
                    int size, int scaleSize, const int dimsize, const int stride,
                    float *output);
+void LaynormKernel(const half *input, const half *scale, const half eps,
+                   int size, int scaleSize, const int dimsize, const int stride,
+                   half *output, const half *bias, int biasSize);
+void LaynormKernel(const half *input, const half *scale, const half eps,
+                   int size, int scaleSize, const int dimsize, const int stride,
+                   half *output);
 }; // namespace infini

View File

@@ -3,4 +3,6 @@
 namespace infini {
 void softmax_kernel(int num_blocks, float *input, float *output, int size,
                     int dimsize, int stride);
-}
+void softmax_kernel(int num_blocks, half *input, half *output, int size,
+                    int dimsize, int stride);
+} // namespace infini

View File

@@ -8,8 +8,8 @@ const int DIM_MAX_SIZE = 8;
 // Concat operator acts like element tensors composing to one big tensor,and
 // split operator acts like one big tensor being composed by element
 // tensors.
-struct ElementTensorMetadata {
-    float *data[BATCH_SIZE];
+template <typename T> struct ElementTensorMetadata {
+    T *data[BATCH_SIZE];
     int dimBgNo[BATCH_SIZE]; // the dimention begin no of the element tensor in
                              // the composed tensor.
     int dimSize[BATCH_SIZE]; // the dimention size of the element tensor.
@@ -20,16 +20,17 @@ struct ElementTensorMetadata {
                data[i], dimBgNo[i], dimSize[i], nElements[i]);
     }
 };
-
-struct ComposedTensorMetadata {
+template <typename T> struct ComposedTensorMetadata {
     int dimSize[DIM_MAX_SIZE];
     int stride[DIM_MAX_SIZE];
-    float *data;
+    T *data;
 };

 namespace infini {
-void split_concat_kernel(const ElementTensorMetadata &eleMeta,
-                         const ComposedTensorMetadata &compMeta, int dim,
+void split_concat_kernel(const ElementTensorMetadata<float> &eleMeta,
+                         const ComposedTensorMetadata<float> &compMeta, int dim,
+                         int batchSize, int nDims, bool isSplit);
+void split_concat_kernel(const ElementTensorMetadata<half> &eleMeta,
+                         const ComposedTensorMetadata<half> &compMeta, int dim,
                          int batchSize, int nDims, bool isSplit);
 } // namespace infini

View File

@@ -5,7 +5,7 @@
 namespace infini {

-void transpose_kernel(float *input, float *output, int nDims, int size,
+void transpose_kernel(int dType, void *input, void *output, int nDims, int size,
                       SmallArray strides, SmallArray outputShape);

 }; // namespace infini

View File

@@ -3,48 +3,21 @@
 #include "operators/unary.h"

 namespace infini {
-void softmax_kernel(float *input, float *output, size_t num);
-void relu_kernel(float *input, float *output, size_t num);
-void sigmoid_kernel(float *input, float *output, size_t num);
-void tanh_kernel(float *input, float *output, size_t num);
-void abs_kernel(float *input, float *output, size_t num);
-void sqrt_kernel(float *input, float *output, size_t num);
-void neg_kernel(float *input, float *output, size_t num);
-void gelu_kernel(float *input, float *output, size_t num);
-void erf_kernel(float *input, float *output, size_t num);
-void hard_sigmoid_kernel(float *input, float *output, size_t num);
-void hard_swish_kernel(float *input, float *output, size_t num);
+template <typename T> void softmax_kernel(T *input, T *output, size_t num);
+template <typename T> void relu_kernel(T *input, T *output, size_t num);
+template <typename T> void sigmoid_kernel(T *input, T *output, size_t num);
+template <typename T> void tanh_kernel(T *input, T *output, size_t num);
+template <typename T> void abs_kernel(T *input, T *output, size_t num);
+template <typename T> void sqrt_kernel(T *input, T *output, size_t num);
+template <typename T> void neg_kernel(T *input, T *output, size_t num);
+template <typename T> void gelu_kernel(T *input, T *output, size_t num);
+template <typename T> void erf_kernel(T *input, T *output, size_t num);
+template <typename T> void hard_sigmoid_kernel(T *input, T *output, size_t num);
+template <typename T> void hard_swish_kernel(T *input, T *output, size_t num);

-void unary_kernel(const Operator &_op) {
-    auto op = as<UnaryObj>(_op);
-    float *const inputData = (op->getInputs(0)->getRawDataPtr<float *>());
-    float *const outputData = (op->getOutput()->getRawDataPtr<float *>());
-
-    size_t num = op->getOutput()->size();
-    if (op->getOpType() == OpType::Softmax)
-        softmax_kernel(inputData, outputData, num);
-    else if (op->getOpType() == OpType::Relu)
-        relu_kernel(inputData, outputData, num);
-    else if (op->getOpType() == OpType::Sigmoid)
-        sigmoid_kernel(inputData, outputData, num);
-    else if (op->getOpType() == OpType::HardSigmoid)
-        hard_sigmoid_kernel(inputData, outputData, num);
-    else if (op->getOpType() == OpType::HardSwish)
-        hard_swish_kernel(inputData, outputData, num);
-    else if (op->getOpType() == OpType::Tanh)
-        tanh_kernel(inputData, outputData, num);
-    else if (op->getOpType() == OpType::Abs)
-        abs_kernel(inputData, outputData, num);
-    else if (op->getOpType() == OpType::Sqrt)
-        sqrt_kernel(inputData, outputData, num);
-    else if (op->getOpType() == OpType::Gelu)
-        gelu_kernel(inputData, outputData, num);
-    else if (op->getOpType() == OpType::Neg)
-        neg_kernel(inputData, outputData, num);
-    else if (op->getOpType() == OpType::Erf)
-        erf_kernel(inputData, outputData, num);
-    else
-        IT_TODO_HALT();
-}
+template <typename INPUT, typename OUTPUT>
+void cast_kernel(INPUT *input, OUTPUT *output, size_t num);
+
+void unary_kernel(const Operator &_op);

 }; // namespace infini
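With the declarations templated, the dtype branch presumably moves into the .cu definition of unary_kernel. A standalone, host-side sketch of the shape of that dispatch (illustrative only, not the repository's actual CUDA code):

// Sketch: a templated worker plus a thin dispatcher keyed on the element type.
// relu_impl stands in for the real device-kernel launch; the dispatch is the point.
#include <algorithm>
#include <cstddef>
#include <cstdio>

template <typename T> void relu_impl(const T *input, T *output, size_t num) {
    for (size_t i = 0; i < num; ++i)
        output[i] = std::max<T>(input[i], T(0));
}

enum class DType { Float32, Float16 };

void relu_dispatch(DType dt, void *input, void *output, size_t num) {
    if (dt == DType::Float32)
        relu_impl(static_cast<float *>(input), static_cast<float *>(output), num);
    // a Float16 branch would cast to half and call the same template
}

int main() {
    float in[3] = {-1.f, 0.f, 2.f}, out[3];
    relu_dispatch(DType::Float32, in, out, 3);
    std::printf("%g %g %g\n", out[0], out[1], out[2]);
}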

View File

@@ -1,11 +1,29 @@
+#pragma once
 #include "core/tensor.h"
+#include "cuda/cuda_common.h"

 namespace infini {

 void cudaPrintFloat(float *x, int len);

-void cudaPrintTensor(const Tensor &tensor) {
-    cudaPrintFloat(tensor->getRawDataPtr<float *>(), tensor->size());
-}
-} // namespace infini
+void cudaPrintTensor(const Tensor &tensor);
+
+cudnnDataType_t cudnnDataTypeConvert(DataType dataType);
+
+cudaDataType cublasDataTypeConvert(DataType);
+
+template <int index> struct DT_CUDA {};
+template <> struct DT_CUDA<0> { using t = bool; };
+template <> struct DT_CUDA<1> { using t = float; };
+template <> struct DT_CUDA<2> { using t = unsigned char; };
+template <> struct DT_CUDA<3> { using t = char; };
+template <> struct DT_CUDA<4> { using t = unsigned short; };
+template <> struct DT_CUDA<5> { using t = short; };
+template <> struct DT_CUDA<6> { using t = int; };
+template <> struct DT_CUDA<7> { using t = long long; };
+template <> struct DT_CUDA<9> { using t = bool; };
+template <> struct DT_CUDA<10> { using t = half; };
+template <> struct DT_CUDA<11> { using t = double; };
+template <> struct DT_CUDA<12> { using t = unsigned int; };
+template <> struct DT_CUDA<13> { using t = unsigned long long; };
+template <> struct DT_CUDA<16> { using t = nv_bfloat16; };
+} // namespace infini

View File

@@ -3,10 +3,15 @@
 #include "utils/small_array.h"

 namespace infini {

 void whereKernel(const float *inputX, const float *inputY,
                  const uint8_t *condition, float *output, int nDims,
                  int outputsize, SmallArray inputXShape, SmallArray inputYShape,
                  SmallArray conditionShape, SmallArray outputShape, int xSize,
                  int ySize, int cSize);
+void whereKernel(const half *inputX, const half *inputY,
+                 const uint8_t *condition, half *output, int nDims,
+                 int outputsize, SmallArray inputXShape, SmallArray inputYShape,
+                 SmallArray conditionShape, SmallArray outputShape, int xSize,
+                 int ySize, int cSize);

 }; // namespace infini

View File

@@ -53,7 +53,8 @@ inline void initGatherMetaData(GatherMetaData &metaData,
         metaData.inStride[i] = in->getStride()[i];
     }
 }
-void gather_kernel(float *in, float *out, GatherMetaData metaData, size_t num);
+template <typename T>
+void gather_kernel(T *in, T *out, GatherMetaData metaData, size_t num);

 void gather_elements_kernel(void *in, void *out, GatherMetaData metaData,
                             size_t num);

View File

@@ -91,6 +91,12 @@ template <int val> class ValGenerator : public DataGenerator {
         fill<uint32_t>(data, size);
     }
     void fill(float *data, size_t size) override { fill<float>(data, size); }
+    void fill_fp16(uint16_t *data, size_t size) {
+        for (size_t i = 0; i < size; i++) {
+            float x = 1.0f * val;
+            data[i] = float_to_fp16(x);
+        }
+    }
 };
 typedef ValGenerator<1> OneGenerator;
 typedef ValGenerator<0> ZeroGenerator;

File diff suppressed because it is too large.

View File

@@ -16,8 +16,7 @@ void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
     std::map<OpType, int> opCnt;
     for (auto &op : graph->getOperators()) {
         // HACK: set correct data type
-        auto kernelAttrs =
-            KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
+        auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
         Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
         auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
         auto perfData = perfEngine.getPerfData(perfKey);

View File

@@ -87,48 +87,33 @@ string GraphObj::toString() const {
 }

 bool GraphObj::topo_sort() {
-    if (this->sorted)
+    if (this->sorted) {
         return true;
-
-    // std::unordered_set<Tensor> inputs;
-    std::unordered_set<Operator> waiting(this->ops.begin(), this->ops.end());
+    }
     std::vector<Operator> sorted;
-
-    while (!waiting.empty()) {
+    std::unordered_set<OperatorObj *> flags;
+    sorted.reserve(ops.size());
+    flags.reserve(ops.size());
+    while (sorted.size() < ops.size()) {
         // Any node is move to sorted in this loop.
         auto modified = false;
-        // Find head nodes.
-        for (auto it = waiting.begin(); it != waiting.end();) {
-            const auto &this_inputs = (*it)->getInputs();
-            // If none of the input tensors is in waiting list,
-            // this node is a head node.
-            const auto is_head = std::all_of(
-                this_inputs.begin(), this_inputs.end(), [&](const auto &input) {
-                    auto src = input->getSource();
-                    return src // If the source node is in the waiting
-                               // list, means that this node is not the
-                               // head node.
-                               ? waiting.find(src) == waiting.end()
-                               // This tensor has no source node,
-                               // it must be a input tensor.
-                               : (/*inputs.insert(input),*/ true);
-                });
-            // Moves head node to sorted.
-            if (is_head) {
+        for (auto const &op : ops) {
+            if (auto const &inputs = op->getInputs();
+                flags.find(op.get()) == flags.end() &&
+                std::all_of(inputs.begin(), inputs.end(),
+                            [&flags](auto const &input) {
+                                auto ptr = input->getSource().get();
+                                return !ptr || flags.find(ptr) != flags.end();
+                            })) {
                 modified = true;
-                sorted.emplace_back(std::move(*it));
-                it = waiting.erase(it);
-            } else {
-                ++it;
+                sorted.emplace_back(op);
+                flags.insert(op.get());
             }
         }
-        // Waiting list never modifies during a pass,
-        // sorting fails.
         if (!modified) {
             return false;
         }
     }
-    // Done.
     this->ops = std::move(sorted);
     return this->sorted = true;
 }
@@ -182,7 +167,10 @@ void GraphObj::dataMalloc(bool useNaiveAllocator, size_t memPoolSize) {
         // note: behavior may not match running in non-naive mode, and it may
         // not reproduce the bug
         for (auto &tensor : tensors) {
-            tensor->dataMalloc();
+            if (!tensor->isWeight() ||
+                (tensor->isWeight() && !weightAllocated)) {
+                tensor->dataMalloc();
+            }
         }
         return;
     }
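For intuition, the rewritten topo_sort is a fixed-point scan: sweep all operators, emit any whose producers are already emitted, and fail if a full sweep emits nothing (a cycle). The same idea on a toy predecessor list, as a standalone sketch:

// Toy version of the scan used above: preds[v] lists the nodes that must come
// before v; a sweep that emits nothing means the graph has a cycle.
#include <cstdio>
#include <unordered_set>
#include <vector>

bool topoSort(const std::vector<std::vector<int>> &preds, std::vector<int> &out) {
    std::unordered_set<int> done;
    while (out.size() < preds.size()) {
        bool modified = false;
        for (int v = 0; v < (int)preds.size(); ++v) {
            if (done.count(v))
                continue;
            bool ready = true;
            for (int p : preds[v])
                ready = ready && done.count(p) > 0;
            if (ready) {
                out.push_back(v);
                done.insert(v);
                modified = true;
            }
        }
        if (!modified)
            return false; // cycle: no node became ready in a full sweep
    }
    return true;
}

int main() {
    std::vector<std::vector<int>> preds = {{}, {0}, {0, 1}}; // 0 -> 1 -> 2
    std::vector<int> order;
    if (topoSort(preds, order))
        for (int v : order)
            std::printf("%d ", v); // prints: 0 1 2
}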

View File

@@ -17,8 +17,7 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
     std::map<OpType, int> opCnt;

     for (auto &op : graph->getOperators()) {
-        auto kernelAttrs =
-            KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
+        auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
         Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
         auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
         auto perfData = perfEngine.getPerfData(perfKey);
@@ -66,8 +65,7 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const {
     std::map<OpType, int> opCnt;

     for (auto &op : graph->getOperators()) {
-        auto kernelAttrs =
-            KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
+        auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
         Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
         auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
         auto perfData = perfEngine.getPerfData(perfKey);

View File

@@ -25,8 +25,7 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
     auto &perfEngine = PerfEngine::getInstance();
     for (auto &op : graph->getOperators()) {
         // HACK: set correct data type
-        auto kernelAttrs =
-            KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
+        auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
         Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
         auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
         auto perfData = perfEngine.getPerfData(perfKey);
@@ -48,8 +47,7 @@ void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const {
     std::map<OpType, int> opCnt;
     for (auto &op : graph->getOperators()) {
         // HACK: set correct data type
-        auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying(),
-                                       DataType::Float32};
+        auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
         Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
         auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
         auto perfData = perfEngine.getPerfData(perfKey);

View File

@@ -1,4 +1,6 @@
+#include "core/data_type.h"
 #include "cuda/cuda_common.h"
+#include "cuda/cuda_utility.h"
 #include <cstdio>

 __global__ void cudaPrintFloatImpl(float *x, int len) {
@@ -18,4 +20,55 @@ void cudaPrintFloat(float *x, int len) {
     cudaDeviceSynchronize();
 }

+void cudaPrintTensor(const Tensor &tensor) {
+    cudaPrintFloat(tensor->getRawDataPtr<float *>(), tensor->size());
+}
+
+cudnnDataType_t cudnnDataTypeConvert(DataType dataType) {
+    if (dataType == DataType::Float32) {
+        return CUDNN_DATA_FLOAT;
+    }
+    if (dataType == DataType::Double) {
+        return CUDNN_DATA_DOUBLE;
+    }
+    if (dataType == DataType::Float16) {
+        return CUDNN_DATA_HALF;
+    }
+    if (dataType == DataType::Int8) {
+        return CUDNN_DATA_INT8;
+    }
+    if (dataType == DataType::Int32) {
+        return CUDNN_DATA_INT32;
+    }
+    if (dataType == DataType::UInt8) {
+        return CUDNN_DATA_UINT8;
+    }
+    if (dataType == DataType::BFloat16) {
+        return CUDNN_DATA_BFLOAT16;
+    }
+    if (dataType == DataType::Int64) {
+        return CUDNN_DATA_INT64;
+    }
+    if (dataType == DataType::Bool) {
+        return CUDNN_DATA_BOOLEAN;
+    }
+    IT_ASSERT(false, "Unsupported data type");
+}
+
+cudaDataType cublasDataTypeConvert(DataType dataType) {
+    switch (dataType.getIndex()) {
+    case 1:
+        return CUDA_R_32F;
+    // case 3:
+    //     return CUDA_R_8I;
+    case 10:
+        return CUDA_R_16F;
+    case 11:
+        return CUDA_R_64F;
+    // case 16:
+    //     return CUDA_R_16BF;
+    default:
+        IT_ASSERT(false, "MatMul Unsupported data type");
+    }
+}
 } // namespace infini
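A standalone sketch of how cudnnDataTypeConvert is meant to be used: the converted enum feeds directly into descriptor creation, so a kernel can follow the operator's dtype instead of hard-coding float. The cuDNN calls are standard; the surrounding code is illustrative.

// Create a 4-D descriptor for a half-precision NCHW tensor. Requires cuDNN;
// error handling is reduced to asserts for brevity.
#include <cudnn.h>
#include <cassert>

int main() {
    cudnnTensorDescriptor_t desc;
    assert(cudnnCreateTensorDescriptor(&desc) == CUDNN_STATUS_SUCCESS);
    // In the runtime this enum would come from cudnnDataTypeConvert(op->getDType()).
    assert(cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF,
                                      1, 3, 224, 224) == CUDNN_STATUS_SUCCESS);
    cudnnDestroyTensorDescriptor(desc);
    return 0;
}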

View File

@@ -11,6 +11,7 @@ class UnaryCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<UnaryObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -50,6 +51,7 @@ class RoundCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<UnaryObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -80,6 +82,7 @@ class PReluCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<PReluObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -119,6 +122,7 @@ class SoftmaxCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<SoftmaxObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -215,15 +219,12 @@ class SigmoidCnnl : public UnaryCnnl {
     float getCoef() const override { return 0.0; }
 };

-REGISTER_KERNEL(Device::BANG, OpType::Relu, DataType::Float32, ReluCnnl,
-                "Relu_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::PRelu, DataType::Float32, PReluCnnl,
-                "PRelu_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, DataType::Float32, SigmoidCnnl,
-                "Sigmoid_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::Round, DataType::Float32, RoundCnnl,
-                "Round_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::Softmax, DataType::Float32, SoftmaxCnnl,
-                "Softmax_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::Relu, ReluCnnl, "Relu_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::PRelu, PReluCnnl, "PRelu_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, SigmoidCnnl,
+                "Sigmoid_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::Round, RoundCnnl, "Round_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::Softmax, SoftmaxCnnl,
+                "Softmax_cnnl_BANG");
 }; // namespace infini

View File

@@ -10,6 +10,7 @@ class ActivationBackwardCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<ActivationBackwardObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const yData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -81,11 +82,11 @@ class TanhBackwardCnnl : public ActivationBackwardCnnl {
     float getCoef() const override { return 0.0; }
 };

-REGISTER_KERNEL(Device::BANG, OpType::ReluBackward, DataType::Float32,
-                ReluBackwardCnnl, "ReluBackward_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::SigmoidBackward, DataType::Float32,
-                SigmoidBackwardCnnl, "SigmoidBackward_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::TanhBackward, DataType::Float32,
-                TanhBackwardCnnl, "TanhBackward_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::ReluBackward, ReluBackwardCnnl,
+                "ReluBackward_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::SigmoidBackward, SigmoidBackwardCnnl,
+                "SigmoidBackward_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::TanhBackward, TanhBackwardCnnl,
+                "TanhBackward_cnnl_BANG");
 }; // namespace infini

View File

@@ -7,6 +7,7 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<BatchNormObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const input = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -101,7 +102,7 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
     }
 };

-REGISTER_KERNEL(Device::BANG, OpType::BatchNormalization, DataType::Float32,
-                BatchNormCnnl, "BatchNorm_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::BatchNormalization, BatchNormCnnl,
+                "BatchNorm_cnnl_BANG");
 }; // namespace infini

View File

@@ -212,7 +212,6 @@ class CastCnnl : public BangKernelWithoutConfig {
     }
 };

-REGISTER_KERNEL(Device::BANG, OpType::Cast, DataType::Float32, CastCnnl,
-                "Cast_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::Cast, CastCnnl, "Cast_cnnl_BANG");
 }; // namespace infini

View File

@@ -7,6 +7,7 @@ class CeilCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<UnaryObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -35,7 +36,6 @@ class CeilCnnl : public BangKernelWithoutConfig {
     }
 };

-REGISTER_KERNEL(Device::BANG, OpType::Ceil, DataType::Float32, CeilCnnl,
-                "Ceil_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::Ceil, CeilCnnl, "Ceil_cnnl_BANG");
 }; // namespace infini

View File

@@ -7,6 +7,7 @@ class ClipCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<ClipObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -30,7 +31,6 @@ class ClipCnnl : public BangKernelWithoutConfig {
     }
 };

-REGISTER_KERNEL(Device::BANG, OpType::Clip, DataType::Float32, ClipCnnl,
-                "Clip_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::Clip, ClipCnnl, "Clip_cnnl_BANG");
 }; // namespace infini

View File

@@ -7,6 +7,7 @@ class ConcatCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<ConcatObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);
         int num = op->numInputs();
         int axis = op->getDim();
@@ -50,6 +51,5 @@ class ConcatCnnl : public BangKernelWithoutConfig {
     }
 };

-REGISTER_KERNEL(Device::BANG, OpType::Concat, DataType::Float32, ConcatCnnl,
-                "Concat_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::Concat, ConcatCnnl, "Concat_cnnl_BANG");
 }; // namespace infini

View File

@@ -7,6 +7,7 @@ class ConvCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<ConvObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
@@ -151,6 +152,5 @@ class ConvCnnl : public BangKernelWithoutConfig {
     }
 };

-REGISTER_KERNEL(Device::BANG, OpType::Conv, DataType::Float32, ConvCnnl,
-                "Conv_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::Conv, ConvCnnl, "Conv_cnnl_BANG");
 }; // namespace infini

View File

@@ -7,6 +7,7 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<ConvBaseObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
@@ -76,6 +77,6 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
     }
 };

-REGISTER_KERNEL(Device::BANG, OpType::ConvTranspose, DataType::Float32,
-                ConvTransCnnl, "ConvTrans_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::ConvTranspose, ConvTransCnnl,
+                "ConvTrans_cnnl_BANG");
 }; // namespace infini

View File

@@ -7,6 +7,7 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<ConvBackwardFilterObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
@@ -154,6 +155,6 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
     }
 };

-REGISTER_KERNEL(Device::BANG, OpType::ConvBackwardFilter, DataType::Float32,
-                ConvBackwardFilterCnnl, "ConvBackwardFilter_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::ConvBackwardFilter,
+                ConvBackwardFilterCnnl, "ConvBackwardFilter_cnnl_BANG");
 }; // namespace infini

View File

@@ -7,6 +7,7 @@ class DetCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<DetObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -42,6 +43,5 @@ class DetCnnl : public BangKernelWithoutConfig {
     }
 };

-REGISTER_KERNEL(Device::BANG, OpType::Det, DataType::Float32, DetCnnl,
-                "Det_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::Det, DetCnnl, "Det_cnnl_BANG");
 }; // namespace infini

View File

@@ -11,6 +11,7 @@ class ElementWiseCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<ElementWiseObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -74,6 +75,7 @@ class LogicOpCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<ElementWiseObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -127,6 +129,7 @@ class BitComputeCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<ElementWiseObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -179,6 +182,7 @@ class DivCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<ElementWiseObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -231,6 +235,7 @@ class MaximumCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<ElementWiseObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -282,6 +287,7 @@ class MinimumCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<ElementWiseObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -333,6 +339,7 @@ class MSELossCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<MSELossObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -389,6 +396,7 @@ class PowerCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<ElementWiseObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -442,6 +450,7 @@ class FloorDivCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<ElementWiseObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -494,6 +503,7 @@ class FloorModCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<ElementWiseObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -546,6 +556,7 @@ class SquaredDifferenceCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<ElementWiseObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -658,62 +669,48 @@ class BitNotCnnl : public BitComputeCnnl {
     // CNNL_BLEFT_SHIFT_OP_V2; }
     // };

-REGISTER_KERNEL(Device::BANG, OpType::Add, DataType::Float32, AddCnnl,
-                "Add_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::Sub, DataType::Float32, SubCnnl,
-                "Sub_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::Mul, DataType::Float32, MulCnnl,
-                "Mul_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::Div, DataType::Float32, DivCnnl,
-                "Div_cnnl_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::Max, DataType::Float32, MaximumCnnl,
-                "Maximum_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::Min, DataType::Float32, MinimumCnnl,
-                "Minimum_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::MSELoss, DataType::Float32, MSELossCnnl,
-                "MSELoss_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::Pow, DataType::Float32, PowerCnnl,
-                "Power_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::FloorDiv, DataType::Float32, FloorDivCnnl,
-                "FloorDiv_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::FloorMod, DataType::Float32, FloorModCnnl,
-                "FloorMod_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::SquaredDifference, DataType::Float32,
-                SquaredDifferenceCnnl, "SquaredDifference_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::Equal, DataType::Float32, EqualCnnl,
-                "Equal_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::Greater, DataType::Float32,
-                GreaterThanCnnl, "GreaterThan_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::GreaterOrEqual, DataType::Float32,
-                GreaterEqualCnnl, "GreaterEqual_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::Less, DataType::Float32, LessThanCnnl,
-                "LessThan_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::LessOrEqual, DataType::Float32,
-                LessEqualCnnl, "LessEqual_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::And, DataType::Float32, AndCnnl,
-                "And_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::Or, DataType::Float32, OrCnnl,
-                "Or_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::Xor, DataType::Float32, XorCnnl,
-                "Xor_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::Not, DataType::Float32, NotCnnl,
-                "Not_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::BitwiseAnd, DataType::Float32, BitAndCnnl,
-                "BitAnd_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::BitwiseOr, DataType::Float32, BitOrCnnl,
-                "BitOr_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::BitwiseXor, DataType::Float32, BitXorCnnl,
-                "BitXor_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::BitwiseNot, DataType::Float32, BitNotCnnl,
-                "BitNot_cnnl_BANG_Float32");
-// REGISTER_KERNEL(Device::BANG, OpType::BitLeftShift, DataType::Float32,
-//                 BitLeftShiftCnnl,
-//                 "BitLeftShift_cnnl_BANG_Float32");
-// REGISTER_KERNEL(Device::BANG, OpType::BitRightShift, DataType::Float32,
-//                 BitRightShiftCnnl,
-//                 "BitRightShift_cnnl_BANG_Float32");
-// REGISTER_KERNEL(Device::BANG, OpType::Pow, DataType::Float32,
-//                 ElementWiseBang,
-//                 "Pow_Bang_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::Add, AddCnnl, "Add_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::Sub, SubCnnl, "Sub_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::Mul, MulCnnl, "Mul_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::Div, DivCnnl, "Div_cnnl");
+REGISTER_KERNEL(Device::BANG, OpType::Max, MaximumCnnl, "Maximum_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::Min, MinimumCnnl, "Minimum_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::MSELoss, MSELossCnnl,
+                "MSELoss_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::Pow, PowerCnnl, "Power_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::FloorDiv, FloorDivCnnl,
+                "FloorDiv_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::FloorMod, FloorModCnnl,
+                "FloorMod_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::SquaredDifference, SquaredDifferenceCnnl,
+                "SquaredDifference_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::Equal, EqualCnnl, "Equal_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::Greater, GreaterThanCnnl,
+                "GreaterThan_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::GreaterOrEqual, GreaterEqualCnnl,
+                "GreaterEqual_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::Less, LessThanCnnl, "LessThan_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::LessOrEqual, LessEqualCnnl,
+                "LessEqual_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::And, AndCnnl, "And_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::Or, OrCnnl, "Or_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::Xor, XorCnnl, "Xor_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::Not, NotCnnl, "Not_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::BitwiseAnd, BitAndCnnl,
+                "BitAnd_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::BitwiseOr, BitOrCnnl, "BitOr_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::BitwiseXor, BitXorCnnl,
+                "BitXor_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::BitwiseNot, BitNotCnnl,
+                "BitNot_cnnl_BANG");
+// REGISTER_KERNEL(Device::BANG, OpType::BitLeftShift,
+//                 BitLeftShiftCnnl,
+//                 "BitLeftShift_cnnl_BANG");
+// REGISTER_KERNEL(Device::BANG, OpType::BitRightShift,
+//                 BitRightShiftCnnl,
+//                 "BitRightShift_cnnl_BANG");
+// REGISTER_KERNEL(Device::BANG, OpType::Pow,
+//                 ElementWiseBang,
+//                 "Pow_Bang");
 }; // namespace infini

View File

@@ -7,6 +7,7 @@ class ErfCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<UnaryObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -36,7 +37,6 @@ class ErfCnnl : public BangKernelWithoutConfig {
     }
 };

-REGISTER_KERNEL(Device::BANG, OpType::Erf, DataType::Float32, ErfCnnl,
-                "Erf_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::Erf, ErfCnnl, "Erf_cnnl_BANG");
 }; // namespace infini

View File

@@ -7,6 +7,7 @@ class ExpCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<UnaryObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -36,7 +37,6 @@ class ExpCnnl : public BangKernelWithoutConfig {
     }
 };

-REGISTER_KERNEL(Device::BANG, OpType::Exp, DataType::Float32, ExpCnnl,
-                "Exp_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::Exp, ExpCnnl, "Exp_cnnl_BANG");
 }; // namespace infini

View File

@@ -7,6 +7,7 @@ class FillCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<FillObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
@@ -29,7 +30,6 @@ class FillCnnl : public BangKernelWithoutConfig {
     }
 };

-REGISTER_KERNEL(Device::BANG, OpType::Fill, DataType::Float32, FillCnnl,
-                "Fill_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::Fill, FillCnnl, "Fill_cnnl_BANG");
 }; // namespace infini

View File

@@ -7,6 +7,7 @@ class FloorCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<UnaryObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -35,7 +36,7 @@ class FloorCnnl : public BangKernelWithoutConfig {
     }
 };

-REGISTER_KERNEL(Device::BANG, OpType::Floor, DataType::Float32, FloorCnnl,
+REGISTER_KERNEL(Device::BANG, OpType::Floor, FloorCnnl,
                 "Floor_cnnl_BANG_Float32");
 }; // namespace infini

View File

@@ -7,6 +7,7 @@ class GatherCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<GatherObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -49,7 +50,6 @@ class GatherCnnl : public BangKernelWithoutConfig {
     }
 };

-REGISTER_KERNEL(Device::BANG, OpType::Gather, DataType::Float32, GatherCnnl,
-                "Gather_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::Gather, GatherCnnl, "Gather_cnnl_BANG");
 }; // namespace infini

View File

@@ -7,6 +7,7 @@ class HardtanhCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<HardtanhObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -30,7 +31,7 @@ class HardtanhCnnl : public BangKernelWithoutConfig {
     }
 };

-REGISTER_KERNEL(Device::BANG, OpType::Hardtanh, DataType::Float32, HardtanhCnnl,
-                "Hardtanh_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::Hardtanh, HardtanhCnnl,
+                "Hardtanh_cnnl_BANG");
 }; // namespace infini

View File

@@ -7,6 +7,7 @@ class L2LossCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<L2LossObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -28,7 +29,6 @@ class L2LossCnnl : public BangKernelWithoutConfig {
     }
 };

-REGISTER_KERNEL(Device::BANG, OpType::L2Loss, DataType::Float32, L2LossCnnl,
-                "L2Loss_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::L2Loss, L2LossCnnl, "L2Loss_cnnl_BANG");
 }; // namespace infini

View File

@ -8,6 +8,7 @@ class LayerNormCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<LayerNormObj>(_op); auto op = as<LayerNormObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -58,7 +59,7 @@ class LayerNormCnnl : public BangKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::BANG, OpType::LayerNormalization, DataType::Float32, REGISTER_KERNEL(Device::BANG, OpType::LayerNormalization, LayerNormCnnl,
LayerNormCnnl, "LayerNorm_BANG_Float32"); "LayerNorm_BANG");
}; // namespace infini }; // namespace infini

View File

@ -7,6 +7,7 @@ class LogCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<LogObj>(_op); auto op = as<LogObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -51,7 +52,6 @@ class LogCnnl : public BangKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::BANG, OpType::Log, DataType::Float32, LogCnnl, REGISTER_KERNEL(Device::BANG, OpType::Log, LogCnnl, "Log_cnnl_BANG");
"Log_cnnl_BANG_Float32");
}; // namespace infini }; // namespace infini

View File

@ -7,6 +7,7 @@ class LRNCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<LRNObj>(_op); auto op = as<LRNObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -56,7 +57,6 @@ class LRNCnnl : public BangKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::BANG, OpType::LRN, DataType::Float32, LRNCnnl, REGISTER_KERNEL(Device::BANG, OpType::LRN, LRNCnnl, "LRN_cnnl_BANG");
"LRN_cnnl_BANG_Float32");
}; // namespace infini }; // namespace infini

View File

@ -8,6 +8,7 @@ class MatmulCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<MatmulObj>(_op); auto op = as<MatmulObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
auto input_num = op->numInputs(); auto input_num = op->numInputs();
@ -107,6 +108,5 @@ class MatmulCnnl : public BangKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::BANG, OpType::MatMul, DataType::Float32, MatmulCnnl, REGISTER_KERNEL(Device::BANG, OpType::MatMul, MatmulCnnl, "Matmul_cnnl_BANG");
"Matmul_cnnl_BANG_Float32");
}; // namespace infini }; // namespace infini

View File

@ -7,6 +7,7 @@ class NegTensorCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op); auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -35,7 +36,6 @@ class NegTensorCnnl : public BangKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::BANG, OpType::Neg, DataType::Float32, NegTensorCnnl, REGISTER_KERNEL(Device::BANG, OpType::Neg, NegTensorCnnl, "Neg_cnnl_BANG");
"Neg_cnnl_BANG_Float32");
}; // namespace infini }; // namespace infini

View File

@ -7,6 +7,7 @@ class PadCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<PadObj>(_op); auto op = as<PadObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -57,7 +58,6 @@ class PadCnnl : public BangKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::BANG, OpType::Pad, DataType::Float32, PadCnnl, REGISTER_KERNEL(Device::BANG, OpType::Pad, PadCnnl, "Pad_cnnl_BANG");
"Pad_cnnl_BANG_Float32");
}; // namespace infini }; // namespace infini

View File

@ -8,6 +8,7 @@ class PoolingCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<PoolingObj>(_op); auto op = as<PoolingObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const outData = (op->getOutput()->getRawDataPtr<void *>()); void *const outData = (op->getOutput()->getRawDataPtr<void *>());
@ -68,8 +69,8 @@ class avgPoolCnnl : public PoolingCnnl {
} }
}; };
REGISTER_KERNEL(Device::BANG, OpType::MaxPool, DataType::Float32, maxPoolCnnl, REGISTER_KERNEL(Device::BANG, OpType::MaxPool, maxPoolCnnl,
"MaxPool_cnnl_BANG_Float32"); "MaxPool_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::AveragePool, DataType::Float32, REGISTER_KERNEL(Device::BANG, OpType::AveragePool, avgPoolCnnl,
avgPoolCnnl, "AvgPool_cnnl_BANG_Float32"); "AvgPool_cnnl_BANG");
}; // namespace infini }; // namespace infini

View File

@ -7,6 +7,7 @@ class ReciprocalCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op); auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -35,7 +36,7 @@ class ReciprocalCnnl : public BangKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::BANG, OpType::Reciprocal, DataType::Float32, REGISTER_KERNEL(Device::BANG, OpType::Reciprocal, ReciprocalCnnl,
ReciprocalCnnl, "Reciprocal_cnnl_BANG_Float32"); "Reciprocal_cnnl_BANG");
}; // namespace infini }; // namespace infini

View File

@ -9,6 +9,7 @@ class ReduceCnnlBase : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<ReduceBaseObj>(_op); auto op = as<ReduceBaseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const cData = (op->getOutput()->getRawDataPtr<void *>()); void *const cData = (op->getOutput()->getRawDataPtr<void *>());
@ -73,9 +74,9 @@ class ReduceSumCnnl : public ReduceCnnlBase {
cnnlReduceOp_t getReduceOp() const override { return CNNL_REDUCE_ADD; } cnnlReduceOp_t getReduceOp() const override { return CNNL_REDUCE_ADD; }
}; };
REGISTER_KERNEL(Device::BANG, OpType::ReduceMean, DataType::Float32, REGISTER_KERNEL(Device::BANG, OpType::ReduceMean, ReduceMeanCnnl,
ReduceMeanCnnl, "ReduceMean_cnnl_BANG_Float32"); "ReduceMean_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::ReduceSum, DataType::Float32, REGISTER_KERNEL(Device::BANG, OpType::ReduceSum, ReduceSumCnnl,
ReduceSumCnnl, "ReduceSum_cnnl_BANG_Float32"); "ReduceSum_cnnl_BANG");
}; // namespace infini }; // namespace infini

View File

@ -13,9 +13,9 @@ class CopyBang : public BangKernelWithoutConfig {
auto dim = op->getInputs(0)->getDims(); auto dim = op->getInputs(0)->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, checkCnnlError(cnnlSetTensorDescriptor(
CNNL_DTYPE_FLOAT, dim.size(), aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_INT8,
dim.data())); dim.size() * op->getDType().getSize(), dim.data()));
cnnlStatus_t stat = cnnlStatus_t stat =
cnnlCopy(context->cnnlHandle(), aDesc, inData, aDesc, outData); cnnlCopy(context->cnnlHandle(), aDesc, inData, aDesc, outData);
if (stat != CNNL_STATUS_SUCCESS) if (stat != CNNL_STATUS_SUCCESS)
@ -25,13 +25,8 @@ class CopyBang : public BangKernelWithoutConfig {
} }
}; };
// reshape/flatten/identity all act as copying from input to output. // reshape/flatten/identity all act as copying from input to output.
REGISTER_KERNEL(Device::BANG, OpType::Reshape, DataType::Float32, CopyBang, REGISTER_KERNEL(Device::BANG, OpType::Reshape, CopyBang, "Reshape_BANG");
"Reshape_BANG_Float32"); REGISTER_KERNEL(Device::BANG, OpType::Flatten, CopyBang, "Flatten_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Reshape, DataType::Int64, CopyBang, REGISTER_KERNEL(Device::BANG, OpType::Identity, CopyBang, "Identity_BANG");
"Reshape_BANG_Int64");
REGISTER_KERNEL(Device::BANG, OpType::Flatten, DataType::Float32, CopyBang,
"Flatten_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Identity, DataType::Float32, CopyBang,
"Identity_BANG_Float32");
} // namespace infini } // namespace infini
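
Note: the Reshape/Flatten/Identity kernel above becomes data-type agnostic by describing the buffer as int8 scaled by the element size, so one registration covers every type. A hedged sketch of that idea in plain C++ — std::memcpy stands in for the device-side cnnlCopy, and the two size arguments stand in for the tensor's element count and getDType().getSize():

#include <cstddef>
#include <cstring>

// Reshape/Flatten/Identity never reinterpret element values, so the copy can
// be performed on raw bytes regardless of the logical data type.
void copyAsBytes(void *dst, const void *src, size_t elementCount,
                 size_t elementSize) {
    std::memcpy(dst, src, elementCount * elementSize);
}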

View File

@ -7,6 +7,7 @@ class RsqrtCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op); auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -36,7 +37,6 @@ class RsqrtCnnl : public BangKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::BANG, OpType::Rsqrt, DataType::Float32, RsqrtCnnl, REGISTER_KERNEL(Device::BANG, OpType::Rsqrt, RsqrtCnnl, "Rsqrt_cnnl_BANG");
"Rsqrt_cnnl_BANG_Float32");
}; // namespace infini }; // namespace infini

View File

@ -7,6 +7,7 @@ class SplitCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<SplitObj>(_op); auto op = as<SplitObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
int num = op->numOutputs(); int num = op->numOutputs();
int axis = op->getDim(); int axis = op->getDim();
@ -49,6 +50,5 @@ class SplitCnnl : public BangKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::BANG, OpType::Split, DataType::Float32, SplitCnnl, REGISTER_KERNEL(Device::BANG, OpType::Split, SplitCnnl, "Split_cnnl_BANG");
"Split_cnnl_BANG_Float32");
}; // namespace infini }; // namespace infini

View File

@ -7,6 +7,7 @@ class SqrtCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op); auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -36,7 +37,6 @@ class SqrtCnnl : public BangKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::BANG, OpType::Sqrt, DataType::Float32, SqrtCnnl, REGISTER_KERNEL(Device::BANG, OpType::Sqrt, SqrtCnnl, "Sqrt_cnnl_BANG");
"Sqrt_cnnl_BANG_Float32");
}; // namespace infini }; // namespace infini

View File

@ -7,6 +7,7 @@ class TransposeCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<TransposeObj>(_op); auto op = as<TransposeObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -52,6 +53,7 @@ class DepthToSpaceCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<DepthToSpaceObj>(_op); auto op = as<DepthToSpaceObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -101,9 +103,9 @@ class DepthToSpaceCnnl : public BangKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::BANG, OpType::Transpose, DataType::Float32, REGISTER_KERNEL(Device::BANG, OpType::Transpose, TransposeCnnl,
TransposeCnnl, "Transpose_cnnl_BANG_Float32"); "Transpose_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::DepthToSpace, DataType::Float32, REGISTER_KERNEL(Device::BANG, OpType::DepthToSpace, DepthToSpaceCnnl,
DepthToSpaceCnnl, "DepthToSpace_cnnl_BANG_Float32"); "DepthToSpace_cnnl_BANG");
}; // namespace infini }; // namespace infini

View File

@ -9,6 +9,7 @@ class TrigonCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op); auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -150,29 +151,17 @@ class ATanHCnnl : public TrigonCnnl {
} }
}; };
REGISTER_KERNEL(Device::BANG, OpType::Sin, DataType::Float32, SinCnnl, REGISTER_KERNEL(Device::BANG, OpType::Sin, SinCnnl, "Sin_cnnl_BANG");
"Sin_cnnl_BANG_Float32"); REGISTER_KERNEL(Device::BANG, OpType::Cos, CosCnnl, "Cos_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Cos, DataType::Float32, CosCnnl, REGISTER_KERNEL(Device::BANG, OpType::Tan, TanCnnl, "Tan_cnnl_BANG");
"Cos_cnnl_BANG_Float32"); REGISTER_KERNEL(Device::BANG, OpType::Asin, ASinCnnl, "ASin_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Tan, DataType::Float32, TanCnnl, REGISTER_KERNEL(Device::BANG, OpType::Acos, ACosCnnl, "ACos_cnnl_BANG");
"Tan_cnnl_BANG_Float32"); REGISTER_KERNEL(Device::BANG, OpType::Atan, ATanCnnl, "ATan_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Asin, DataType::Float32, ASinCnnl, REGISTER_KERNEL(Device::BANG, OpType::Sinh, SinHCnnl, "SinH_cnnl_BANG");
"ASin_cnnl_BANG_Float32"); REGISTER_KERNEL(Device::BANG, OpType::Cosh, CosHCnnl, "CosH_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Acos, DataType::Float32, ACosCnnl, REGISTER_KERNEL(Device::BANG, OpType::Tanh, TanHCnnl, "TanH_cnnl_BANG");
"ACos_cnnl_BANG_Float32"); REGISTER_KERNEL(Device::BANG, OpType::Asinh, ASinHCnnl, "ASinH_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Atan, DataType::Float32, ATanCnnl, REGISTER_KERNEL(Device::BANG, OpType::Acosh, ACosHCnnl, "ACosH_cnnl_BANG");
"ATan_cnnl_BANG_Float32"); REGISTER_KERNEL(Device::BANG, OpType::Atanh, ATanHCnnl, "ATanH_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Sinh, DataType::Float32, SinHCnnl,
"SinH_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Cosh, DataType::Float32, CosHCnnl,
"CosH_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Tanh, DataType::Float32, TanHCnnl,
"TanH_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Asinh, DataType::Float32, ASinHCnnl,
"ASinH_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Acosh, DataType::Float32, ACosHCnnl,
"ACosH_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Atanh, DataType::Float32, ATanHCnnl,
"ATanH_cnnl_BANG_Float32");
}; // namespace infini }; // namespace infini

View File

@ -7,6 +7,7 @@ class WhereCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<WhereObj>(_op); auto op = as<WhereObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -67,7 +68,6 @@ class WhereCnnl : public BangKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::BANG, OpType::Where, DataType::Float32, WhereCnnl, REGISTER_KERNEL(Device::BANG, OpType::Where, WhereCnnl, "Where_cnnl_BANG");
"Where_cnnl_BANG_Float32");
}; // namespace infini }; // namespace infini

View File

@ -3,9 +3,9 @@
namespace infini { namespace infini {
template <typename T> class NaiveConcat : public CpuKernelWithoutConfig { class NaiveConcat : public CpuKernelWithoutConfig {
void compute(const Operator &_op, template <typename T>
const RuntimeObj *context) const override { void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<ConcatObj>(_op); auto op = as<ConcatObj>(_op);
auto inputs = op->getInputs(), outputs = op->getOutputs(); auto inputs = op->getInputs(), outputs = op->getOutputs();
auto dim = op->getDim(); auto dim = op->getDim();
@ -41,11 +41,25 @@ template <typename T> class NaiveConcat : public CpuKernelWithoutConfig {
} }
} }
} }
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
}; };
REGISTER_KERNEL(Device::CPU, OpType::Concat, DataType::UInt32, REGISTER_KERNEL(Device::CPU, OpType::Concat, NaiveConcat, "ConcatNaive_CPU");
NaiveConcat<uint32_t>, "ConcatNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Concat, DataType::Float32,
NaiveConcat<float>, "ConcatNaive_CPU_float32");
} // namespace infini } // namespace infini
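
Note: the CPU kernels from here on switch from one template instantiation per registration to a single kernel that dispatches on the runtime data-type index through a DT<N> trait and a CASE macro. A self-contained sketch of that pattern, using the same indices as the diff (1 = Float32, 12 = UInt32); DT below is a local stand-in for the repository's trait, and the sketch adds an #undef that the real files omit because each defines CASE in its own translation unit:

#include <cstddef>
#include <cstdint>
#include <stdexcept>

template <int N> struct DT;                        // index -> C++ type
template <> struct DT<1> { using t = float; };     // DataType::Float32
template <> struct DT<12> { using t = uint32_t; }; // DataType::UInt32

template <typename T> void doCompute(const void *in, void *out, size_t n) {
    auto *src = static_cast<const T *>(in);
    auto *dst = static_cast<T *>(out);
    for (size_t i = 0; i < n; ++i)
        dst[i] = src[i];
}

void compute(int dataTypeIdx, const void *in, void *out, size_t n) {
#define CASE(N)                                                                \
    case N:                                                                    \
        doCompute<DT<N>::t>(in, out, n)
    switch (dataTypeIdx) {
        CASE(1); // DataType::Float32
        break;
        CASE(12); // DataType::UInt32
        break;
    default:
        throw std::runtime_error("unsupported data type");
    }
#undef CASE
}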

View File

@ -3,9 +3,9 @@
namespace infini { namespace infini {
template <typename T> class NaiveConv : public CpuKernelWithoutConfig { class NaiveConv : public CpuKernelWithoutConfig {
void compute(const Operator &_op, template <typename T>
const RuntimeObj *context) const override { void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<ConvObj>(_op); auto op = as<ConvObj>(_op);
T *iptr = op->getInputs(0)->getRawDataPtr<T *>(); T *iptr = op->getInputs(0)->getRawDataPtr<T *>();
T *wptr = op->getInputs(1)->getRawDataPtr<T *>(); T *wptr = op->getInputs(1)->getRawDataPtr<T *>();
@ -50,11 +50,25 @@ template <typename T> class NaiveConv : public CpuKernelWithoutConfig {
} }
} }
} }
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
}; };
REGISTER_KERNEL(Device::CPU, OpType::Conv, DataType::UInt32, REGISTER_KERNEL(Device::CPU, OpType::Conv, NaiveConv, "ConvNaive_CPU");
NaiveConv<uint32_t>, "ConvNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Conv, DataType::Float32, NaiveConv<float>,
"ConvNaive_CPU_float32");
} // namespace infini } // namespace infini

View File

@ -3,10 +3,45 @@
#include "utils/operator_utils.h" #include "utils/operator_utils.h"
namespace infini { namespace infini {
template <typename T> class NativeElementWise : public CpuKernelWithoutConfig { class NativeElementWise : public CpuKernelWithoutConfig {
virtual T doCompute(T val0, T val1) const = 0; template <typename T> static T addCompute(T val0, T val1) {
void compute(const Operator &_op, return val0 + val1;
const RuntimeObj *context) const override { }
template <typename T> static T subCompute(T val0, T val1) {
return val0 - val1;
}
template <typename T> static T mulCompute(T val0, T val1) {
return val0 * val1;
}
template <typename T> static T divCompute(T val0, T val1) {
return (T)(val0 / val1);
}
template <typename T> static T equalCompute(T val0, T val1) {
return (T)(val0 == val1);
}
template <typename T> static T greaterOrEqualCompute(T val0, T val1) {
return (T)(val0 >= val1);
}
template <typename T> static T greaterCompute(T val0, T val1) {
return (T)(val0 > val1);
}
template <typename T> static T lessOrEqualCompute(T val0, T val1) {
return (T)(val0 <= val1);
}
template <typename T> static T lessCompute(T val0, T val1) {
return (T)(val0 < val1);
}
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<ElementWiseObj>(_op); auto op = as<ElementWiseObj>(_op);
T *inptr0 = op->getInputs(0)->getRawDataPtr<T *>(); T *inptr0 = op->getInputs(0)->getRawDataPtr<T *>();
T *inptr1 = op->getInputs(1)->getRawDataPtr<T *>(); T *inptr1 = op->getInputs(1)->getRawDataPtr<T *>();
@ -35,77 +70,77 @@ template <typename T> class NativeElementWise : public CpuKernelWithoutConfig {
Shape strideB = getStride(b); Shape strideB = getStride(b);
auto n = op->getOutput()->size(); auto n = op->getOutput()->size();
T (*_doCompute)(T val0, T val1);
switch (op->getOpType().underlying()) {
case OpType::Add:
_doCompute = addCompute<T>;
break;
case OpType::Sub:
_doCompute = subCompute<T>;
break;
case OpType::Mul:
_doCompute = mulCompute<T>;
break;
case OpType::Div:
_doCompute = divCompute<T>;
break;
case OpType::Equal:
_doCompute = equalCompute<T>;
break;
case OpType::GreaterOrEqual:
_doCompute = greaterOrEqualCompute<T>;
break;
case OpType::Greater:
_doCompute = greaterCompute<T>;
break;
case OpType::LessOrEqual:
_doCompute = lessOrEqualCompute<T>;
break;
case OpType::Less:
_doCompute = lessCompute<T>;
break;
default:
IT_TODO_HALT();
}
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {
auto shapeIndexC = locate_index(i, shapeC); auto shapeIndexC = locate_index(i, shapeC);
auto indexA = delocate_index(shapeIndexC, a, strideA); auto indexA = delocate_index(shapeIndexC, a, strideA);
auto indexB = delocate_index(shapeIndexC, b, strideB); auto indexB = delocate_index(shapeIndexC, b, strideB);
outptr[i] = doCompute(inptr0[indexA], inptr1[indexB]); outptr[i] = _doCompute(inptr0[indexA], inptr1[indexB]);
}
}
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
} }
} }
}; };
template <typename T> class NaiveAdd : public NativeElementWise<T> { REGISTER_KERNEL(Device::CPU, OpType::Add, NativeElementWise, "addNaive_CPU");
T doCompute(T val0, T val1) const override { return val0 + val1; } REGISTER_KERNEL(Device::CPU, OpType::Sub, NativeElementWise, "subNaive_CPU");
}; REGISTER_KERNEL(Device::CPU, OpType::Mul, NativeElementWise, "mulNaive_CPU");
template <typename T> class NaiveSub : public NativeElementWise<T> { REGISTER_KERNEL(Device::CPU, OpType::Div, NativeElementWise, "divNaive_CPU");
T doCompute(T val0, T val1) const override { return val0 - val1; } REGISTER_KERNEL(Device::CPU, OpType::Equal, NativeElementWise,
}; "equalNaive_CPU");
template <typename T> class NaiveMul : public NativeElementWise<T> { REGISTER_KERNEL(Device::CPU, OpType::GreaterOrEqual, NativeElementWise,
T doCompute(T val0, T val1) const override { return val0 * val1; } "greaterEqualNaive_CPU");
}; REGISTER_KERNEL(Device::CPU, OpType::Greater, NativeElementWise,
template <typename T> class NaiveDiv : public NativeElementWise<T> { "greaterThanNaive_CPU");
T doCompute(T val0, T val1) const override { return (T)(val0 / val1); } REGISTER_KERNEL(Device::CPU, OpType::LessOrEqual, NativeElementWise,
}; "lessEqualNaive_CPU");
template <typename T> class NaiveEqual : public NativeElementWise<T> { REGISTER_KERNEL(Device::CPU, OpType::Less, NativeElementWise,
T doCompute(T val0, T val1) const override { return (T)(val0 == val1); } "lessEqualNaive_CPU");
};
template <typename T> class NaiveGreaterEqual : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return (T)(val0 >= val1); }
};
template <typename T> class NaiveGreaterThan : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return (T)(val0 > val1); }
};
template <typename T> class NaiveLessEqual : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return (T)(val0 <= val1); }
};
template <typename T> class NaiveLessThan : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return (T)(val0 < val1); }
};
REGISTER_KERNEL(Device::CPU, OpType::Add, DataType::UInt32, NaiveAdd<uint32_t>,
"addNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Add, DataType::Float32, NaiveAdd<float>,
"addNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Sub, DataType::UInt32, NaiveSub<uint32_t>,
"subNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Sub, DataType::Float32, NaiveSub<float>,
"subNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Mul, DataType::UInt32, NaiveMul<uint32_t>,
"mulNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Mul, DataType::Float32, NaiveMul<float>,
"mulNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Div, DataType::UInt32, NaiveDiv<uint32_t>,
"divNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Div, DataType::Float32, NaiveDiv<float>,
"divNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Equal, DataType::UInt32,
NaiveEqual<uint32_t>, "equalNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Equal, DataType::Float32,
NaiveEqual<float>, "equalNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::GreaterOrEqual, DataType::UInt32,
NaiveGreaterEqual<uint32_t>, "greaterEqualNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::GreaterOrEqual, DataType::Float32,
NaiveGreaterEqual<float>, "greaterEqualNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Greater, DataType::UInt32,
NaiveGreaterThan<uint32_t>, "greaterThanNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Greater, DataType::Float32,
NaiveGreaterThan<float>, "greaterThanNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::LessOrEqual, DataType::UInt32,
NaiveLessEqual<uint32_t>, "lessEqualNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::LessOrEqual, DataType::Float32,
NaiveLessEqual<float>, "lessEqualNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Less, DataType::UInt32,
NaiveLessThan<uint32_t>, "lessEqualNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Less, DataType::Float32,
NaiveLessThan<float>, "lessEqualNaive_CPU_float32");
}; // namespace infini }; // namespace infini
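
Note: the element-wise rewrite above replaces the per-operator subclasses (NaiveAdd, NaiveSub, ...) with static template functions selected through a function pointer at run time. A compact sketch of that selection, with OpKind standing in for OpType; the real kernel also resolves broadcast indices via locate_index/delocate_index, whereas this sketch assumes both inputs already have the output shape:

#include <cstddef>
#include <stdexcept>

enum class OpKind { Add, Sub, Mul, Div };

template <typename T> static T addCompute(T a, T b) { return a + b; }
template <typename T> static T subCompute(T a, T b) { return a - b; }
template <typename T> static T mulCompute(T a, T b) { return a * b; }
template <typename T> static T divCompute(T a, T b) { return a / b; }

template <typename T>
void elementWise(OpKind kind, const T *a, const T *b, T *c, size_t n) {
    T (*fn)(T, T) = nullptr; // picked once, applied per element
    switch (kind) {
    case OpKind::Add: fn = addCompute<T>; break;
    case OpKind::Sub: fn = subCompute<T>; break;
    case OpKind::Mul: fn = mulCompute<T>; break;
    case OpKind::Div: fn = divCompute<T>; break;
    default: throw std::runtime_error("unsupported element-wise op");
    }
    for (size_t i = 0; i < n; ++i)
        c[i] = fn(a[i], b[i]);
}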

View File

@ -3,9 +3,9 @@
namespace infini { namespace infini {
template <typename T> class NaiveMatmul : public CpuKernelWithoutConfig { class NaiveMatmul : public CpuKernelWithoutConfig {
void compute(const Operator &_op, template <typename T>
const RuntimeObj *context) const override { void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<MatmulObj>(_op); auto op = as<MatmulObj>(_op);
IT_ASSERT(op->getInputs().size() == 2, "Bias is not supported yet."); IT_ASSERT(op->getInputs().size() == 2, "Bias is not supported yet.");
T *A = op->getInputs(0)->getRawDataPtr<T *>(); T *A = op->getInputs(0)->getRawDataPtr<T *>();
@ -23,11 +23,25 @@ template <typename T> class NaiveMatmul : public CpuKernelWithoutConfig {
} }
} }
} }
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
}; };
REGISTER_KERNEL(Device::CPU, OpType::MatMul, DataType::UInt32, REGISTER_KERNEL(Device::CPU, OpType::MatMul, NaiveMatmul, "MatmulNaive_CPU");
NaiveMatmul<uint32_t>, "MatmulNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::MatMul, DataType::Float32,
NaiveMatmul<float>, "MatmulNaive_CPU_float32");
} // namespace infini } // namespace infini

View File

@ -80,8 +80,8 @@ class MemboundInterpreter : public Kernel {
} }
}; };
REGISTER_KERNEL(Device::CPU, OpType::MemBound, DataType::UInt32, REGISTER_KERNEL(Device::CPU, OpType::MemBound, MemboundInterpreter,
MemboundInterpreter, "MemboundInterpreter_CPU"); "MemboundInterpreter_CPU");
} // namespace infini } // namespace infini

View File

@ -2,42 +2,10 @@
#include "core/kernel.h" #include "core/kernel.h"
namespace infini { namespace infini {
template <typename T> class NativePooling : public CpuKernelWithoutConfig { class NativePooling : public CpuKernelWithoutConfig {
virtual T getPoolingValue(int kh, int kw, int posh, int posw, int ih, template <typename T>
int iw, T *inptr) const = 0; static T getMaxPoolingValue(int kh, int kw, int posh, int posw, int ih,
void compute(const Operator &_op, int iw, T *inptr) {
const RuntimeObj *context) const override {
auto op = as<PoolingObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
T *outptr = op->getOutput()->getRawDataPtr<T *>();
const auto [n, c, ih, iw, kh, kw] = op->getNCHWRS();
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
if (dh != 1 || dw != 1)
IT_TODO_HALT(); // To support dilated pooling
auto outDim = op->getOutput()->getDims();
int oh = outDim[2], ow = outDim[3];
for (auto i = 0; i < n; i++) {
for (auto j = 0; j < c; j++) {
auto inoffset = i * (c * ih * iw) + j * ih * iw;
for (auto h = 0; h < oh; h++) {
for (auto w = 0; w < ow; w++) {
// TODO: verify ceil mode
T val =
getPoolingValue(kh, kw, h * sh - ph, w * sw - pw,
ih, iw, inptr + inoffset);
auto outoffset =
w + h * ow + j * (oh * ow) + i * (c * oh * ow);
outptr[outoffset] = val;
}
}
}
}
}
};
template <typename T> class NaiveMaxPool : public NativePooling<T> {
T getPoolingValue(int kh, int kw, int posh, int posw, int ih, int iw,
T *inptr) const override {
T maxval = 0; T maxval = 0;
for (auto k = 0; k < kh; k++) { for (auto k = 0; k < kh; k++) {
for (auto l = 0; l < kw; l++) { for (auto l = 0; l < kw; l++) {
@ -53,11 +21,10 @@ template <typename T> class NaiveMaxPool : public NativePooling<T> {
} }
return maxval; return maxval;
} }
};
template <typename T> class NaiveAvgPool : public NativePooling<T> { template <typename T>
T getPoolingValue(int kh, int kw, int posh, int posw, int ih, int iw, static T getAvgPoolingValue(int kh, int kw, int posh, int posw, int ih,
T *inptr) const override { int iw, T *inptr) {
T sum = 0; T sum = 0;
for (auto k = 0; k < kh; k++) { for (auto k = 0; k < kh; k++) {
for (auto l = 0; l < kw; l++) { for (auto l = 0; l < kw; l++) {
@ -71,12 +38,70 @@ template <typename T> class NaiveAvgPool : public NativePooling<T> {
} }
return T(sum / (kh * kw)); return T(sum / (kh * kw));
} }
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<PoolingObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
T *outptr = op->getOutput()->getRawDataPtr<T *>();
const auto [n, c, ih, iw, kh, kw] = op->getNCHWRS();
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
if (dh != 1 || dw != 1)
IT_TODO_HALT(); // To support dilated pooling
auto outDim = op->getOutput()->getDims();
int oh = outDim[2], ow = outDim[3];
T(*_doCompute)
(int kh, int kw, int posh, int posw, int ih, int iw, T *inptr);
switch (op->getOpType().underlying()) {
case OpType::MaxPool:
_doCompute = getMaxPoolingValue<T>;
break;
case OpType::AveragePool:
_doCompute = getAvgPoolingValue<T>;
break;
default:
IT_TODO_HALT();
}
for (auto i = 0; i < n; i++) {
for (auto j = 0; j < c; j++) {
auto inoffset = i * (c * ih * iw) + j * ih * iw;
for (auto h = 0; h < oh; h++) {
for (auto w = 0; w < ow; w++) {
// TODO: verify ceil mode
T val = _doCompute(kh, kw, h * sh - ph, w * sw - pw, ih,
iw, inptr + inoffset);
auto outoffset =
w + h * ow + j * (oh * ow) + i * (c * oh * ow);
outptr[outoffset] = val;
}
}
}
}
}
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
}; };
REGISTER_KERNEL(Device::CPU, OpType::MaxPool, DataType::UInt32, REGISTER_KERNEL(Device::CPU, OpType::MaxPool, NativePooling,
NaiveMaxPool<uint32_t>, "maxPoolNaive_CPU_uint32"); "maxPoolNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::MaxPool, DataType::Float32, REGISTER_KERNEL(Device::CPU, OpType::AveragePool, NativePooling,
NaiveMaxPool<float>, "maxPoolNaive_CPU_float32"); "avgPoolNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::AveragePool, DataType::Float32,
NaiveAvgPool<float>, "AvgPoolNaive_CPU_float32");
} // namespace infini } // namespace infini
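
Note: in the pooling rewrite above, the selector T(*_doCompute)(int kh, int kw, int posh, int posw, int ih, int iw, T *inptr) is a raw function-pointer declaration that the formatter splits awkwardly across two lines. Purely as a readability aside (not something this commit does), the same declaration reads more easily through an alias template:

// Hypothetical readability variant of the pooling-value selector used above.
template <typename T>
using PoolValueFn = T (*)(int kh, int kw, int posh, int posw, int ih, int iw,
                          T *inptr);

// e.g.  PoolValueFn<float> fn = getMaxPoolingValue<float>;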

View File

@ -3,9 +3,9 @@
namespace infini { namespace infini {
template <typename T> class NaiveSplit : public CpuKernelWithoutConfig { class NaiveSplit : public CpuKernelWithoutConfig {
void compute(const Operator &_op, template <typename T>
const RuntimeObj *context) const override { void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<SplitObj>(_op); auto op = as<SplitObj>(_op);
auto inputs = op->getInputs(), outputs = op->getOutputs(); auto inputs = op->getInputs(), outputs = op->getOutputs();
auto dim = op->getDim(); auto dim = op->getDim();
@ -40,11 +40,24 @@ template <typename T> class NaiveSplit : public CpuKernelWithoutConfig {
} }
} }
} }
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
}; };
REGISTER_KERNEL(Device::CPU, OpType::Split, DataType::UInt32, REGISTER_KERNEL(Device::CPU, OpType::Split, NaiveSplit, "SplitNaive_CPU");
NaiveSplit<uint32_t>, "SplitNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Split, DataType::Float32,
NaiveSplit<float>, "SplitNaive_CPU_float32");
} // namespace infini } // namespace infini

View File

@ -14,9 +14,9 @@ inline Shape idx2Pos(const Shape &shape, size_t idx) {
return pos; return pos;
} }
template <typename T> class NaiveTranspose : public CpuKernelWithoutConfig { class NaiveTranspose : public CpuKernelWithoutConfig {
void compute(const Operator &_op, template <typename T>
const RuntimeObj *context) const override { void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<TransposeObj>(_op); auto op = as<TransposeObj>(_op);
auto inputs = op->getInputs(), outputs = op->getOutputs(); auto inputs = op->getInputs(), outputs = op->getOutputs();
const auto &inDim = inputs[0]->getDims(); const auto &inDim = inputs[0]->getDims();
@ -35,11 +35,26 @@ template <typename T> class NaiveTranspose : public CpuKernelWithoutConfig {
outPtr[outIdx] = inPtr[inIdx]; outPtr[outIdx] = inPtr[inIdx];
} }
} }
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
}; };
REGISTER_KERNEL(Device::CPU, OpType::Transpose, DataType::UInt32, REGISTER_KERNEL(Device::CPU, OpType::Transpose, NaiveTranspose,
NaiveTranspose<uint32_t>, "TransposeNaive_CPU_uint32"); "TransposeNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Transpose, DataType::Float32,
NaiveTranspose<float>, "TransposeNaive_CPU_float32");
} // namespace infini } // namespace infini

View File

@ -4,25 +4,170 @@
#include "operators/softmax.h" #include "operators/softmax.h"
namespace infini { namespace infini {
template <typename T> class NativeUnary : public CpuKernelWithoutConfig { class NativeUnary : public CpuKernelWithoutConfig {
virtual T doCompute(T val) const = 0; template <typename T> static T reluCompute(T val) {
void compute(const Operator &_op, return std::max(T(0), val);
const RuntimeObj *context) const override { }
template <typename T> static T sigmoidCompute(T val) {
return 1 / (1 + pow(E_CONSTANT, -val));
}
template <typename T> static T hardSigmoidCompute(T val) {
return std::max(T(0), std::min(T(1), T(0.2) * val + T(0.5)));
}
template <typename T> static T hardSwishCompute(T val) {
return val *
std::max(T(0), std::min(T(1), val * T(1.0 / 6.0) + T(0.5)));
}
template <typename T> static T tanhCompute(T val) {
return (pow(E_CONSTANT, val) - pow(E_CONSTANT, -val)) /
(pow(E_CONSTANT, val) + pow(E_CONSTANT, -val));
}
template <typename T> static T absCompute(T val) {
return val < 0 ? -val : val;
}
template <typename T> static T sqrtCompute(T val) { return std::sqrt(val); }
template <typename T> static T cosCompute(T val) { return std::cos(val); }
template <typename T> static T sinCompute(T val) { return std::sin(val); }
template <typename T> static T tanCompute(T val) { return std::tan(val); }
template <typename T> static T sinhCompute(T val) { return std::sinh(val); }
template <typename T> static T coshCompute(T val) { return std::cosh(val); }
template <typename T> static T geluCompute(T val) {
return 0.5 * val * (1 + std::erf(val / std::sqrt(2)));
}
template <typename T> static T erfCompute(T val) { return std::erf(val); }
template <typename T> static T aCosCompute(T val) { return std::acos(val); }
template <typename T> static T aCoshCompute(T val) {
return std::acosh(val);
}
template <typename T> static T aSinCompute(T val) { return std::asin(val); }
template <typename T> static T aSinhCompute(T val) {
return std::asinh(val);
}
template <typename T> static T aTanCompute(T val) { return std::atan(val); }
template <typename T> static T aTanhCompute(T val) {
return std::atanh(val);
}
template <typename T> static T negCompute(T val) { return -val; }
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<UnaryObj>(_op); auto op = as<UnaryObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>(); T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
T *outptr = op->getOutput()->getRawDataPtr<T *>(); T *outptr = op->getOutput()->getRawDataPtr<T *>();
auto outDim = op->getOutput()->getDims(); auto outDim = op->getOutput()->getDims();
auto n = op->getOutput()->size(); auto n = op->getOutput()->size();
T (*_doCompute)(T val);
switch (op->getOpType().underlying()) {
case OpType::Relu:
_doCompute = reluCompute<T>;
break;
case OpType::Gelu:
_doCompute = geluCompute<T>;
break;
case OpType::Sigmoid:
_doCompute = sigmoidCompute<T>;
break;
case OpType::HardSigmoid:
_doCompute = hardSigmoidCompute<T>;
break;
case OpType::HardSwish:
_doCompute = hardSwishCompute<T>;
break;
case OpType::Tanh:
_doCompute = tanhCompute<T>;
break;
case OpType::Abs:
_doCompute = absCompute<T>;
break;
case OpType::Sqrt:
_doCompute = sqrtCompute<T>;
break;
case OpType::Erf:
_doCompute = erfCompute<T>;
break;
case OpType::Neg:
_doCompute = negCompute<T>;
break;
case OpType::Cos:
_doCompute = cosCompute<T>;
break;
case OpType::Sin:
_doCompute = sinCompute<T>;
break;
case OpType::Tan:
_doCompute = tanCompute<T>;
break;
case OpType::Sinh:
_doCompute = sinhCompute<T>;
break;
case OpType::Cosh:
_doCompute = coshCompute<T>;
break;
case OpType::Acos:
_doCompute = aCosCompute<T>;
break;
case OpType::Asin:
_doCompute = aSinCompute<T>;
break;
case OpType::Asinh:
_doCompute = aSinhCompute<T>;
break;
case OpType::Atan:
_doCompute = aTanCompute<T>;
break;
case OpType::Atanh:
_doCompute = aTanhCompute<T>;
break;
default:
IT_TODO_HALT();
}
for (size_t offset = 0; offset < n; offset++) { for (size_t offset = 0; offset < n; offset++) {
outptr[offset] = doCompute(inptr[offset]); outptr[offset] = _doCompute(inptr[offset]);
}
}
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
} }
} }
}; };
template <typename T> class NaiveSoftmax : public CpuKernelWithoutConfig { class NaiveSoftmax : public CpuKernelWithoutConfig {
void compute(const Operator &_op, template <typename T>
const RuntimeObj *context) const override { void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<SoftmaxObj>(_op); auto op = as<SoftmaxObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>(); T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
T *outptr = op->getOutput()->getRawDataPtr<T *>(); T *outptr = op->getOutput()->getRawDataPtr<T *>();
@ -37,98 +182,28 @@ template <typename T> class NaiveSoftmax : public CpuKernelWithoutConfig {
outptr[offset] = pow(E_CONSTANT, inptr[offset]) / sum; outptr[offset] = pow(E_CONSTANT, inptr[offset]) / sum;
} }
} }
};
template <typename T> class NaiveRelu : public NativeUnary<T> {
T doCompute(T val) const override { return std::max(T(0), val); }
};
template <typename T> class NaiveSigmoid : public NativeUnary<T> {
T doCompute(T val) const override {
return 1 / (1 + pow(E_CONSTANT, -val));
}
};
template <typename T> class NaiveHardSigmoid : public NativeUnary<T> {
T doCompute(T val) const override {
return std::max(T(0), std::min(T(1), T(0.2) * val + T(0.5)));
}
};
template <typename T> class NaiveHardSwish : public NativeUnary<T> {
T doCompute(T val) const override {
return val *
std::max(T(0), std::min(T(1), val * T(1.0 / 6.0) + T(0.5)));
}
};
template <typename T> class NaiveTanh : public NativeUnary<T> {
T doCompute(T val) const override {
return (pow(E_CONSTANT, val) - pow(E_CONSTANT, -val)) /
(pow(E_CONSTANT, val) + pow(E_CONSTANT, -val));
}
};
template <typename T> class NaiveAbs : public NativeUnary<T> {
T doCompute(T val) const override { return val < 0 ? -val : val; }
};
template <typename T> class NaiveSqrt : public NativeUnary<T> {
T doCompute(T val) const override { return std::sqrt(val); }
};
template <typename T> class NaiveCos : public NativeUnary<T> {
T doCompute(T val) const override { return std::cos(val); }
};
template <typename T> class NaiveSin : public NativeUnary<T> {
T doCompute(T val) const override { return std::sin(val); }
};
template <typename T> class NaiveTan : public NativeUnary<T> {
T doCompute(T val) const override { return std::tan(val); }
};
template <typename T> class NaiveSinh : public NativeUnary<T> {
T doCompute(T val) const override { return std::sinh(val); }
};
template <typename T> class NaiveCosh : public NativeUnary<T> {
T doCompute(T val) const override { return std::cosh(val); }
};
template <typename T> class NaiveGelu : public NativeUnary<T> {
T doCompute(T val) const override {
return 0.5 * val * (1 + std::erf(val / std::sqrt(2)));
}
};
template <typename T> class NaiveErf : public NativeUnary<T> {
T doCompute(T val) const override { return std::erf(val); }
};
template <typename T> class NaiveACos : public NativeUnary<T> {
T doCompute(T val) const override { return std::acos(val); }
};
template <typename T> class NaiveACosh : public NativeUnary<T> {
T doCompute(T val) const override { return std::acosh(val); }
};
template <typename T> class NaiveASin : public NativeUnary<T> {
T doCompute(T val) const override { return std::asin(val); }
};
template <typename T> class NaiveASinh : public NativeUnary<T> {
T doCompute(T val) const override { return std::asinh(val); }
};
template <typename T> class NaiveATanh : public NativeUnary<T> {
T doCompute(T val) const override { return std::atanh(val); }
};
template <typename T> class NaiveNeg : public NativeUnary<T> {
T doCompute(T val) const override { return -val; }
};
template <typename T> class Clip : public CpuKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *context) const override { const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
class Clip : public CpuKernelWithoutConfig {
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<ClipObj>(_op); auto op = as<ClipObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>(); T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
T *outptr = op->getOutput()->getRawDataPtr<T *>(); T *outptr = op->getOutput()->getRawDataPtr<T *>();
@ -143,11 +218,28 @@ template <typename T> class Clip : public CpuKernelWithoutConfig {
: val; : val;
} }
} }
};
template <typename T> class Log : public CpuKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *context) const override { const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
class Log : public CpuKernelWithoutConfig {
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<LogObj>(_op); auto op = as<LogObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>(); T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
T *outptr = op->getOutput()->getRawDataPtr<T *>(); T *outptr = op->getOutput()->getRawDataPtr<T *>();
@ -176,70 +268,50 @@ template <typename T> class Log : public CpuKernelWithoutConfig {
} }
} }
} }
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
}; };
template <typename T> class NaiveATan : public NativeUnary<T> { REGISTER_KERNEL(Device::CPU, OpType::Relu, NativeUnary, "reluNaive_CPU");
T doCompute(T val) const override { return std::atan(val); } REGISTER_KERNEL(Device::CPU, OpType::Gelu, NativeUnary, "geluNaive_CPU");
}; REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, NativeUnary, "sigmoidNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::HardSigmoid, NativeUnary,
"hardSigmoidNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::HardSwish, NativeUnary,
"hardSwishNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Tanh, NativeUnary, "tanhNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Abs, NativeUnary, "absNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Sqrt, NativeUnary, "sqrtNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Erf, NativeUnary, "erfNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Neg, NativeUnary, "negNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Cos, NativeUnary, "Cos_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Sin, NativeUnary, "Sin_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Tan, NativeUnary, "Tan_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Sinh, NativeUnary, "Sinh_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Cosh, NativeUnary, "Cosh_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Acos, NativeUnary, "ACos_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Acosh, NativeUnary, "ACosh_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Asin, NativeUnary, "ASin_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Asinh, NativeUnary, "ASinh_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Atan, NativeUnary, "Atan_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Atanh, NativeUnary, "ATanh_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::UInt32, REGISTER_KERNEL(Device::CPU, OpType::Softmax, NaiveSoftmax, "softmaxNaive_CPU");
NaiveRelu<uint32_t>, "reluNaive_CPU_uint32"); REGISTER_KERNEL(Device::CPU, OpType::Clip, Clip, "Clip_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::Float32, NaiveRelu<float>, REGISTER_KERNEL(Device::CPU, OpType::Log, Log, "Log_CPU");
"reluNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Gelu, DataType::UInt32, NaiveGelu<float>,
"geluNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Gelu, DataType::Float32, NaiveGelu<float>,
"geluNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::UInt32,
NaiveSigmoid<uint32_t>, "sigmoidNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::Float32,
NaiveSigmoid<float>, "sigmoidNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::HardSigmoid, DataType::Float32,
NaiveHardSigmoid<float>, "hardSigmoidNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::HardSwish, DataType::Float32,
NaiveHardSwish<float>, "hardSwishNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Tanh, DataType::UInt32,
NaiveTanh<uint32_t>, "tanhNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Tanh, DataType::Float32, NaiveTanh<float>,
"tanhNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::UInt32, NaiveAbs<uint32_t>,
"absNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::Float32, NaiveAbs<float>,
"absNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Sqrt, DataType::Float32, NaiveSqrt<float>,
"sqrtNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Erf, DataType::Float32, NaiveErf<float>,
"erfNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Neg, DataType::Float32, NaiveNeg<float>,
"negNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::UInt32,
NaiveSoftmax<uint32_t>, "softmaxNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::Float32,
NaiveSoftmax<float>, "softmaxNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Clip, DataType::Float32, Clip<float>,
"Clip_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Atan, DataType::Float32, NaiveATan<float>,
"Atan_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Log, DataType::Float32, Log<float>,
"Log_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Cos, DataType::Float32, NaiveCos<float>,
"Cos_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Sin, DataType::Float32, NaiveSin<float>,
"Sin_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Tan, DataType::Float32, NaiveTan<float>,
"Tan_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Sinh, DataType::Float32, NaiveSinh<float>,
"Sinh_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Cosh, DataType::Float32, NaiveCosh<float>,
"Cosh_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Acos, DataType::Float32, NaiveACos<float>,
"ACos_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Acosh, DataType::Float32,
NaiveACosh<float>, "ACosh_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Asin, DataType::Float32, NaiveASin<float>,
"ASin_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Asinh, DataType::Float32,
NaiveASinh<float>, "ASinh_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Atanh, DataType::Float32,
NaiveATanh<float>, "ATanh_CPU_float32");
}; // namespace infini }; // namespace infini

View File

@ -48,13 +48,13 @@ class G2BMMCudnn : public CudaKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<G2BMMObj>(_op); auto op = as<G2BMMObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context); auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
bool success = g2bmmKernel(op, context); bool success = g2bmmKernel(op, context);
IT_ASSERT(success); IT_ASSERT(success);
} }
}; };
REGISTER_KERNEL(Device::CUDA, OpType::G2BMM, DataType::Float32, G2BMMCudnn, REGISTER_KERNEL(Device::CUDA, OpType::G2BMM, G2BMMCudnn, "G2BMM_cuDNN_CUDA");
"G2BMM_cuDNN_CUDA_Float32");
} // namespace infini } // namespace infini

View File

@ -49,13 +49,13 @@ class GBMMCudnn : public CudaKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<GBMMObj>(_op); auto op = as<GBMMObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context); auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
bool success = gbmmKernel(op, context); bool success = gbmmKernel(op, context);
IT_ASSERT(success); IT_ASSERT(success);
} }
}; };
REGISTER_KERNEL(Device::CUDA, OpType::GBMM, DataType::Float32, GBMMCudnn, REGISTER_KERNEL(Device::CUDA, OpType::GBMM, GBMMCudnn, "GBMM_cuDNN_CUDA");
"GBMM_cuDNN_CUDA_Float32");
} // namespace infini } // namespace infini

View File

@ -39,8 +39,8 @@ class AllGatherNCCL : public CudaKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::CUDA, OpType::AllGather, DataType::Float32, REGISTER_KERNEL(Device::CUDA, OpType::AllGather, AllGatherNCCL,
AllGatherNCCL, "AllGather_NCCL_CUDA_Float32"); "AllGather_NCCL_CUDA");
} // namespace infini } // namespace infini
#endif #endif

View File

@ -13,15 +13,24 @@ class AllReduceNCCL : public CudaKernelWithoutConfig {
auto context = dynamic_cast<const CudaRuntimeObj *>(_context); auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
void *input = op->getInputs(0)->getRawDataPtr<void *>(); void *input = op->getInputs(0)->getRawDataPtr<void *>();
void *output = op->getOutput()->getRawDataPtr<void *>(); void *output = op->getOutput()->getRawDataPtr<void *>();
IT_ASSERT(op->getDType() == DataType::Float32); ncclDataType_t ncclType = ncclFloat;
if (op->getDType() == DataType::Float16) {
ncclType = ncclFloat16;
} else if (op->getDType() == DataType::Int8) {
ncclType = ncclInt8;
} else if (op->getDType() == DataType::Float32) {
ncclType = ncclFloat;
} else {
IT_TODO_HALT();
}
size_t count = op->getInputs(0)->size(); size_t count = op->getInputs(0)->size();
ncclComm_t comm = ncclComm_t comm =
dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator()) dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator())
.getNcclComm(); .getNcclComm();
// TODO: Using default stream 0 for now. // TODO: Using default stream 0 for now.
checkNcclError(ncclAllReduce(input, output, count, ncclFloat, checkNcclError(
getRedOp(), comm, 0)); ncclAllReduce(input, output, count, ncclType, getRedOp(), comm, 0));
} }
virtual ncclRedOp_t getRedOp() const = 0; virtual ncclRedOp_t getRedOp() const = 0;
@ -43,16 +52,16 @@ class AllReduceAvgNCCL : public AllReduceNCCL {
ncclRedOp_t getRedOp() const override { return ncclAvg; } ncclRedOp_t getRedOp() const override { return ncclAvg; }
}; };
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceSum, DataType::Float32, REGISTER_KERNEL(Device::CUDA, OpType::AllReduceSum, AllReduceSumNCCL,
AllReduceSumNCCL, "AllReduce_Sum_NCCL_CUDA_Float32"); "AllReduce_Sum_NCCL_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceProd, DataType::Float32, REGISTER_KERNEL(Device::CUDA, OpType::AllReduceProd, AllReduceProdNCCL,
AllReduceProdNCCL, "AllReduce_Prod_NCCL_CUDA_Float32"); "AllReduce_Prod_NCCL_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceMin, DataType::Float32, REGISTER_KERNEL(Device::CUDA, OpType::AllReduceMin, AllReduceMinNCCL,
AllReduceMinNCCL, "AllReduce_Min_NCCL_CUDA_Float32"); "AllReduce_Min_NCCL_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceMax, DataType::Float32, REGISTER_KERNEL(Device::CUDA, OpType::AllReduceMax, AllReduceMaxNCCL,
AllReduceMaxNCCL, "AllReduce_Max_NCCL_CUDA_Float32"); "AllReduce_Max_NCCL_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceAvg, DataType::Float32, REGISTER_KERNEL(Device::CUDA, OpType::AllReduceAvg, AllReduceAvgNCCL,
AllReduceAvgNCCL, "AllReduce_Avg_NCCL_CUDA_Float32"); "AllReduce_Avg_NCCL_CUDA");
} // namespace infini } // namespace infini
#endif #endif
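AllReduce now picks the NCCL element type from the operator's data type at run time instead of assuming ncclFloat. The same mapping would be useful for the other NCCL collectives touched by this patch (AllGather, Broadcast), so it could be factored into a small helper; a sketch under that assumption (toNcclType is hypothetical and not part of this change):

// Hypothetical helper mirroring the branch above: map an InfiniTensor
// DataType onto ncclDataType_t, halting on anything not yet supported.
inline ncclDataType_t toNcclType(const DataType &dtype) {
    if (dtype == DataType::Float32)
        return ncclFloat;
    if (dtype == DataType::Float16)
        return ncclFloat16;
    if (dtype == DataType::Int8)
        return ncclInt8;
    IT_TODO_HALT();
    return ncclFloat; // unreachable
}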

View File

@ -40,6 +40,7 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute,
public CudaKernelWithoutConfig { public CudaKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
IT_ASSERT(_op->getDType() == DataType::Float32);
do_compute(_op->getInputs()[0], _op->getInputs()[1], do_compute(_op->getInputs()[0], _op->getInputs()[1],
_op->getInputs()[2], _op->getInputs()[3], _op->getInputs()[2], _op->getInputs()[3],
_op->getInputs()[4], _op->getInputs()[5], _op->getInputs()[4], _op->getInputs()[5],
@ -47,6 +48,6 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute,
} }
}; };
REGISTER_KERNEL(Device::CUDA, OpType::AttentionKVCache, DataType::Float32, REGISTER_KERNEL(Device::CUDA, OpType::AttentionKVCache, AttentionKVCacheCuda,
AttentionKVCacheCuda, "AttentionKVCache_CUDA_Float32"); "AttentionKVCache_CUDA");
} // namespace infini } // namespace infini
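With the data type gone from the registration, kernels that still implement only fp32 (AttentionKVCache here, and BatchNorm, ConvTranspose and Extend further down) add an explicit run-time guard instead. The pattern these hunks introduce, as a minimal sketch:

void compute(const Operator &_op, const RuntimeObj *_context) const override {
    // Registered for every dtype, so reject anything this kernel cannot
    // actually execute yet.
    IT_ASSERT(_op->getDType() == DataType::Float32);
    // ... fp32-only implementation as before ...
}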

View File

@ -10,6 +10,7 @@ class BatchNormCudnn : public CudaKernelWithoutConfig {
auto op = as<BatchNormObj>(_op); auto op = as<BatchNormObj>(_op);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context); auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
cudnnStatus_t stat; cudnnStatus_t stat;
IT_ASSERT(op->getDType() == DataType::Float32);
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const outData = (op->getOutput()->getRawDataPtr<void *>()); void *const outData = (op->getOutput()->getRawDataPtr<void *>());
void *const meanData = (op->getInputs(1)->getRawDataPtr<void *>()); void *const meanData = (op->getInputs(1)->getRawDataPtr<void *>());
@ -59,6 +60,6 @@ class BatchNormCudnn : public CudaKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::CUDA, OpType::BatchNormalization, DataType::Float32, REGISTER_KERNEL(Device::CUDA, OpType::BatchNormalization, BatchNormCudnn,
BatchNormCudnn, "BatchNorm_cuDNN_CUDA_Float32"); "BatchNorm_cuDNN_CUDA");
} // namespace infini } // namespace infini

View File

@ -25,8 +25,8 @@ class BroadcastNCCL : public CudaKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::CUDA, OpType::Broadcast, DataType::Float32, REGISTER_KERNEL(Device::CUDA, OpType::Broadcast, BroadcastNCCL,
BroadcastNCCL, "Broadcast_NCCL_CUDA_Float32"); "Broadcast_NCCL_CUDA");
} // namespace infini } // namespace infini
#endif #endif

View File

@ -9,7 +9,7 @@ class ClipCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<ClipObj>(_op); auto op = as<ClipObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const outputData = (op->getOutput()->getRawDataPtr<void *>()); void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
auto min = op->getMin(); auto min = op->getMin();
@ -21,7 +21,6 @@ class ClipCuda : public CudaKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::CUDA, OpType::Clip, DataType::Float32, ClipCuda, REGISTER_KERNEL(Device::CUDA, OpType::Clip, ClipCuda, "Clip_CUDA");
"Clip_CUDA_Float32");
}; // namespace infini }; // namespace infini

View File

@ -1,10 +1,12 @@
#include "operators/conv.h" #include "operators/conv.h"
#include "core/kernel.h" #include "core/kernel.h"
#include "cuda/cuda_runtime.h" #include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include <chrono> #include <chrono>
#include <functional> #include <functional>
#include <limits> #include <limits>
#include <tuple> #include <tuple>
namespace infini { namespace infini {
struct ConvCuDnnPerfRecordObj : public PerfRecordObj { struct ConvCuDnnPerfRecordObj : public PerfRecordObj {
@ -56,8 +58,11 @@ class convCudnn : public Kernel {
const ConvCuDnnPerfRecord &record) const { const ConvCuDnnPerfRecord &record) const {
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const knData = (op->getInputs(1)->getRawDataPtr<void *>()); void *const knData = (op->getInputs(1)->getRawDataPtr<void *>());
if (op->getInputs().size() > 2) // Bias is not supported yet // Bias is not supported yet
if (op->getInputs().size() > 2) {
IT_TODO_HALT(); IT_TODO_HALT();
}
auto cudnnDataType = cudnnDataTypeConvert(op->getDType());
// void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>()); // void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>());
void *const outData = (op->getOutput()->getRawDataPtr<void *>()); void *const outData = (op->getOutput()->getRawDataPtr<void *>());
@ -72,27 +77,26 @@ class convCudnn : public Kernel {
cudnnTensorDescriptor_t inDesc; cudnnTensorDescriptor_t inDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&inDesc)); checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
checkCudnnError(cudnnSetTensor4dDescriptor( checkCudnnError(cudnnSetTensor4dDescriptor(
inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, channels, h, w)); inDesc, CUDNN_TENSOR_NCHW, cudnnDataType, n, channels, h, w));
// get kernels // get kernels
cudnnFilterDescriptor_t knDesc; cudnnFilterDescriptor_t knDesc;
checkCudnnError(cudnnCreateFilterDescriptor(&knDesc)); checkCudnnError(cudnnCreateFilterDescriptor(&knDesc));
checkCudnnError(cudnnSetFilter4dDescriptor(knDesc, CUDNN_DATA_FLOAT, checkCudnnError(cudnnSetFilter4dDescriptor(
CUDNN_TENSOR_NCHW, f, knDesc, cudnnDataType, CUDNN_TENSOR_NCHW, f, channelsPerGrp, r, s));
channelsPerGrp, r, s));
// get bias // get bias
cudnnTensorDescriptor_t biasDesc; cudnnTensorDescriptor_t biasDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc)); checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc));
checkCudnnError(cudnnSetTensor4dDescriptor( checkCudnnError(cudnnSetTensor4dDescriptor(biasDesc, CUDNN_TENSOR_NCHW,
biasDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, f, 1, 1)); cudnnDataType, 1, f, 1, 1));
// get convlution descriptor // get convolution descriptor
cudnnConvolutionDescriptor_t convDesc; cudnnConvolutionDescriptor_t convDesc;
checkCudnnError(cudnnCreateConvolutionDescriptor(&convDesc)); checkCudnnError(cudnnCreateConvolutionDescriptor(&convDesc));
// TODO: CUDNN_CONVOLUTION is a tunable argument // TODO: CUDNN_CONVOLUTION is a tunable argument
checkCudnnError(cudnnSetConvolution2dDescriptor( checkCudnnError(cudnnSetConvolution2dDescriptor(
convDesc, ph, pw, sh, sw, dh, dw, MODES[record->mode], convDesc, ph, pw, sh, sw, dh, dw, MODES[record->mode],
CUDNN_DATA_FLOAT)); cudnnDataType));
if (g > 1) { if (g > 1) {
checkCudnnError(cudnnSetConvolutionGroupCount(convDesc, g)); checkCudnnError(cudnnSetConvolutionGroupCount(convDesc, g));
} }
@ -120,14 +124,14 @@ class convCudnn : public Kernel {
assert(false); assert(false);
} }
// get output descriptor
int outn, outc, outh, outw; int outn, outc, outh, outw;
checkCudnnError(cudnnGetConvolution2dForwardOutputDim( checkCudnnError(cudnnGetConvolution2dForwardOutputDim(
convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw)); convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw));
cudnnTensorDescriptor_t outDesc; cudnnTensorDescriptor_t outDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&outDesc)); checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW, checkCudnnError(cudnnSetTensor4dDescriptor(
CUDNN_DATA_FLOAT, outn, outc, outDesc, CUDNN_TENSOR_NCHW, cudnnDataType, outn, outc, outh, outw));
outh, outw));
IT_ASSERT((vector{outn, outc, outh, outw}) == IT_ASSERT((vector{outn, outc, outh, outw}) ==
op->getOutput()->getDims(), op->getOutput()->getDims(),
"cuDNN output shape mismatches with OP output shape"); "cuDNN output shape mismatches with OP output shape");
@ -151,55 +155,9 @@ class convCudnn : public Kernel {
inData, knDesc, knData, convDesc, inData, knDesc, knData, convDesc,
ALGOS[record->algo], wsData, wsSize, ALGOS[record->algo], wsData, wsSize,
&beta, outDesc, outData); &beta, outDesc, outData);
if (stat != CUDNN_STATUS_SUCCESS) if (stat != CUDNN_STATUS_SUCCESS) {
return false; return false;
// TODO: }
// // bias
// if (bias != nullptr) {
// auto sz = op.getOutputs()[0]->size();
// // TODO: element wise
// t += sz * 2 / 400;
// }
// // act
// if (act != None) {
// stat = cudnnActivationForward(cudnnHandle(), actDesc,
// &alpha, inDesc, inData,
// &beta, outDesc, outData);
// checkCudaError(cudaDeviceSynchronize());
// end = ch::high_resolution_clock::now();
// if (stat != CUDNN_STATUS_SUCCESS) {
// durtime = INFINITY;
// break;
// }
// t +=
// ch::duration_cast<ch::duration<double>>(end -
// beg).count() * 1000; // ms
// }
// best = ConvResult{durtime, ALGOS[i], wsSize, false};
// // w/ bias & act
// for (int j = 0; j < rounds + warmupRounds; ++j) {
// cudnnStatus_t stat;
// if (j == warmupRounds) {
// checkCudaError(cudaDeviceSynchronize());
// beg = ch::high_resolution_clock::now();
// }
// stat = cudnnConvolutionBiasActivationForward(
// cudnnHandle(), &alpha, inDesc, inData, knDesc, knData,
// convDesc, ALGOS[i], wsData, wsSize, &beta, outDesc,
// outData, biasDesc, biasData, actDesc, outDesc, outData);
// if (stat != CUDNN_STATUS_SUCCESS) {
// // checkCudnnError(stat);
// // Do not checkCudnnError since not all algorithms are
// // supported
// durtime_fuse = INFINITY;
// break;
// }
// }
// Destories in CUDA does not require sync. But cuDNN does not state
// whether sync is required before destories.
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc)); checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc)); checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc)); checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
@ -238,10 +196,12 @@ class convCudnn : public Kernel {
stat = cudnnGetConvolutionForwardWorkspaceSize( stat = cudnnGetConvolutionForwardWorkspaceSize(
context->cudnnHandle(), inDesc, knDesc, convDesc, outDesc, context->cudnnHandle(), inDesc, knDesc, convDesc, outDesc,
ALGOS[record.algo], &record.workspaceSize); ALGOS[record.algo], &record.workspaceSize);
if (stat != CUDNN_STATUS_SUCCESS) if (stat != CUDNN_STATUS_SUCCESS) {
continue; continue;
if (record.workspaceSize > context->getWorkspaceSize()) }
if (record.workspaceSize > context->getWorkspaceSize()) {
continue; continue;
}
CudaPtr wsData = context->getWorkspace(record.workspaceSize); CudaPtr wsData = context->getWorkspace(record.workspaceSize);
float alpha = 1.f, beta = 0.f; float alpha = 1.f, beta = 0.f;
@ -249,8 +209,9 @@ class convCudnn : public Kernel {
context->cudnnHandle(), &alpha, inDesc, inData, knDesc, context->cudnnHandle(), &alpha, inDesc, inData, knDesc,
knData, convDesc, ALGOS[record.algo], wsData, knData, convDesc, ALGOS[record.algo], wsData,
record.workspaceSize, &beta, outDesc, outData); record.workspaceSize, &beta, outDesc, outData);
if (stat != CUDNN_STATUS_SUCCESS) if (stat != CUDNN_STATUS_SUCCESS) {
continue; continue;
}
record.time = timeit( record.time = timeit(
[&]() { [&]() {
cudnnConvolutionForward(context->cudnnHandle(), &alpha, cudnnConvolutionForward(context->cudnnHandle(), &alpha,
@ -263,8 +224,9 @@ class convCudnn : public Kernel {
// printf("mode:%d algo:%d :%.8lf\n", mode, algo, record.time); // printf("mode:%d algo:%d :%.8lf\n", mode, algo, record.time);
// Update the tune result // Update the tune result
if (ret.time > record.time) if (ret.time > record.time) {
ret = record; ret = record;
}
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc)); checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc)); checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc)); checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
@ -291,8 +253,7 @@ class convCudnn : public Kernel {
} }
}; };
REGISTER_KERNEL(Device::CUDA, OpType::Conv, DataType::Float32, convCudnn, REGISTER_KERNEL(Device::CUDA, OpType::Conv, convCudnn, "Conv_cuDNN_CUDA");
"Conv_cuDNN_CUDA_Float32");
REGISTER_CONSTRUCTOR(1, ConvCuDnnPerfRecordObj::from_json); REGISTER_CONSTRUCTOR(1, ConvCuDnnPerfRecordObj::from_json);
} // namespace infini } // namespace infini
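The convolution kernel derives the cuDNN element type from the operator via cudnnDataTypeConvert (pulled in through cuda/cuda_utility.h) instead of hard-coding CUDNN_DATA_FLOAT, which is what lets the fp32 and fp16 paths share one kernel. The helper's definition is not shown in this diff; a plausible sketch of such a mapping (an assumption, not copied from the header):

// Assumed shape of the helper used above: translate DataType into the
// matching cuDNN enum, halting on types the cuDNN path cannot take.
inline cudnnDataType_t cudnnDataTypeConvert(DataType dtype) {
    if (dtype == DataType::Float32)
        return CUDNN_DATA_FLOAT;
    if (dtype == DataType::Float16)
        return CUDNN_DATA_HALF;
    if (dtype == DataType::Int8)
        return CUDNN_DATA_INT8;
    IT_TODO_HALT();
    return CUDNN_DATA_FLOAT; // unreachable
}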

View File

@ -1,261 +0,0 @@
#include "core/kernel.h"
#include "cuda/cuda_runtime.h"
#include "operators/conv.h"
#include <chrono>
#include <functional>
#include <limits>
#include <tuple>
namespace infini {
struct ConvCuDnnPerfRecordObj : public PerfRecordObj {
int algo = 0; // cudnnConvolutionFwdAlgo_t
int mode = 1;
size_t workspaceSize = 100000;
bool fuseAct = false;
void to_json(json &j) override {
j["type"] = 1;
j["data"] = std::make_tuple(algo, mode, fuseAct, time, workspaceSize);
}
static PerfRecord from_json(const json &j) {
ConvCuDnnPerfRecordObj tmp;
auto [Algo, Mode, FuseAct, Time, WorkspaceSize] =
j["data"].get<tuple<int, int, bool, double, size_t>>();
tmp.algo = Algo;
tmp.mode = Mode;
tmp.fuseAct = FuseAct;
tmp.time = Time;
tmp.workspaceSize = WorkspaceSize;
return make_ref<ConvCuDnnPerfRecordObj>(tmp);
}
};
using ConvCuDnnPerfRecord = Ref<ConvCuDnnPerfRecordObj>;
class convCudnnFP16 : public Kernel {
static constexpr int N_ALGO = 8;
static constexpr int N_MODE = 2;
static constexpr cudnnConvolutionFwdAlgo_t ALGOS[8] = {
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
CUDNN_CONVOLUTION_FWD_ALGO_FFT,
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED};
static constexpr cudnnConvolutionMode_t MODES[2] = {
CUDNN_CONVOLUTION, CUDNN_CROSS_CORRELATION};
std::tuple<void *, void *, void *, cudnnTensorDescriptor_t,
cudnnFilterDescriptor_t, cudnnTensorDescriptor_t,
cudnnConvolutionDescriptor_t, cudnnActivationDescriptor_t,
cudnnTensorDescriptor_t>
createCuDNNDescriptor(const Ref<ConvObj> &op,
const ConvCuDnnPerfRecord &record) const {
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const knData = (op->getInputs(1)->getRawDataPtr<void *>());
// Bias is not supported yet
if (op->getInputs().size() > 2) {
IT_TODO_HALT();
}
// void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>());
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
const auto [n, c, h, w, f, r, s] = op->getNCHWFRS();
const int cpg = op->getChannelPerGroup();
const int g = c / cpg;
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
int channelsPerGrp = cpg, channels = c;
// get inputs
cudnnTensorDescriptor_t inDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(inDesc, CUDNN_TENSOR_NCHW,
CUDNN_DATA_HALF, n, channels,
h, w)); /*fp16 type*/
// get kernels
cudnnFilterDescriptor_t knDesc;
checkCudnnError(cudnnCreateFilterDescriptor(&knDesc));
checkCudnnError(cudnnSetFilter4dDescriptor(
knDesc, CUDNN_DATA_HALF, /*fp16 type*/
CUDNN_TENSOR_NCHW, f, channelsPerGrp, r, s));
// get bias
cudnnTensorDescriptor_t biasDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(biasDesc, CUDNN_TENSOR_NCHW,
CUDNN_DATA_HALF, 1, f, 1,
1)); /*fp16 type*/
// get convolution descriptor
cudnnConvolutionDescriptor_t convDesc;
checkCudnnError(cudnnCreateConvolutionDescriptor(&convDesc));
// TODO: CUDNN_CONVOLUTION is a tunable argument
checkCudnnError(cudnnSetConvolution2dDescriptor(
convDesc, ph, pw, sh, sw, dh, dw, MODES[record->mode],
CUDNN_DATA_HALF)); /*fp16 type*/
if (g > 1) {
checkCudnnError(cudnnSetConvolutionGroupCount(convDesc, g));
}
// get activation descriptor
cudnnActivationDescriptor_t actDesc;
checkCudnnError(cudnnCreateActivationDescriptor(&actDesc));
// NOT_PROPAGATE_NAN is requierd by
// cudnnConvolotionBiasActivationForward
switch (op->getAct()) {
case ActType::Relu:
checkCudnnError(cudnnSetActivationDescriptor(
actDesc, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0));
break;
case ActType::Sigmoid:
checkCudnnError(cudnnSetActivationDescriptor(
actDesc, CUDNN_ACTIVATION_SIGMOID, CUDNN_NOT_PROPAGATE_NAN, 0));
break;
case ActType::None:
checkCudnnError(
cudnnSetActivationDescriptor(actDesc, CUDNN_ACTIVATION_IDENTITY,
CUDNN_NOT_PROPAGATE_NAN, 0));
break;
default:
assert(false);
}
// get output descriptor
int outn, outc, outh, outw;
checkCudnnError(cudnnGetConvolution2dForwardOutputDim(
convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw));
cudnnTensorDescriptor_t outDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW,
CUDNN_DATA_HALF, outn, outc,
outh, outw));
IT_ASSERT((vector{outn, outc, outh, outw}) ==
op->getOutput()->getDims(),
"cuDNN output shape mismatches with OP output shape");
return tuple(inData, knData, outData, inDesc, knDesc, biasDesc,
convDesc, actDesc, outDesc);
}
bool cuDNNUnfused(const Ref<ConvObj> &op, const ConvCuDnnPerfRecord &record,
const CudaRuntimeObj *context) const {
cudnnStatus_t stat;
const auto &[inData, knData, outData, inDesc, knDesc, biasDesc,
convDesc, actDesc, outDesc] =
createCuDNNDescriptor(op, record);
size_t wsSize = record->workspaceSize;
CudaPtr wsData = context->getWorkspace(wsSize);
float alpha = 1.f, beta = 0.f;
stat = cudnnConvolutionForward(context->cudnnHandle(), &alpha, inDesc,
inData, knDesc, knData, convDesc,
ALGOS[record->algo], wsData, wsSize,
&beta, outDesc, outData);
if (stat != CUDNN_STATUS_SUCCESS) {
return false;
}
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc));
checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
return true;
}
void compute(const Operator &op, const RuntimeObj *context) const override {
auto record = make_ref<ConvCuDnnPerfRecordObj>(); // with paramters in
// default ctor
compute(op, record, context);
}
PerfRecord tune(const Operator &_op,
const RuntimeObj *_context) const override {
ConvCuDnnPerfRecordObj ret;
ret.time = std::numeric_limits<double>::max();
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
auto op = as<ConvObj>(_op);
// Both modes have the same performance. Only run cross-correlation.
for (int mode = 1; mode < 2; mode++) {
// Try every possible algorithm of convolution
for (int algo = 0; algo < N_ALGO; algo++) {
auto recordRef = make_ref<ConvCuDnnPerfRecordObj>();
auto &record = *recordRef;
record.mode = mode;
record.algo = algo;
cudnnStatus_t stat;
const auto &[inData, knData, outData, inDesc, knDesc, biasDesc,
convDesc, actDesc, outDesc] =
createCuDNNDescriptor(op, recordRef);
// get workspace
stat = cudnnGetConvolutionForwardWorkspaceSize(
context->cudnnHandle(), inDesc, knDesc, convDesc, outDesc,
ALGOS[record.algo], &record.workspaceSize);
if (stat != CUDNN_STATUS_SUCCESS) {
continue;
}
if (record.workspaceSize > context->getWorkspaceSize()) {
continue;
}
CudaPtr wsData = context->getWorkspace(record.workspaceSize);
float alpha = 1.f, beta = 0.f;
stat = cudnnConvolutionForward(
context->cudnnHandle(), &alpha, inDesc, inData, knDesc,
knData, convDesc, ALGOS[record.algo], wsData,
record.workspaceSize, &beta, outDesc, outData);
if (stat != CUDNN_STATUS_SUCCESS) {
continue;
}
record.time = timeit(
[&]() {
cudnnConvolutionForward(context->cudnnHandle(), &alpha,
inDesc, inData, knDesc, knData,
convDesc, ALGOS[record.algo],
wsData, record.workspaceSize,
&beta, outDesc, outData);
},
[&]() { context->sync(); });
// printf("mode:%d algo:%d :%.8lf\n", mode, algo, record.time);
// Update the tune result
if (ret.time > record.time) {
ret = record;
}
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc));
checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
}
}
// printf("the best algo is %d, the best conv mode is %d\n", ret.algo,
// ret.mode);
IT_ASSERT(ret.time < std::numeric_limits<double>::max(), "No valid "
"algorithm "
"found");
return make_ref<ConvCuDnnPerfRecordObj>(ret);
}
void compute(const Operator &_op, const PerfRecord &_record,
const RuntimeObj *_context) const override {
auto op = as<ConvObj>(_op);
auto record = as<ConvCuDnnPerfRecordObj>(_record);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
bool success = cuDNNUnfused(op, record, context);
IT_ASSERT(success);
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Conv, DataType::Float16, convCudnnFP16,
"Conv_cuDNN_CUDA_Float16");
} // namespace infini

View File

@ -219,6 +219,7 @@ class convBackwardDataCudnn : public Kernel {
void compute(const Operator &op, const RuntimeObj *context) const override { void compute(const Operator &op, const RuntimeObj *context) const override {
// with parameters in default ctor // with parameters in default ctor
auto record = make_ref<ConvTransposedCuDnnPerfRecordObj>(); auto record = make_ref<ConvTransposedCuDnnPerfRecordObj>();
IT_ASSERT(op->getDType() == DataType::Float32);
compute(op, record, context); compute(op, record, context);
} }
@ -300,8 +301,9 @@ class convBackwardDataCudnn : public Kernel {
} }
}; };
REGISTER_KERNEL(Device::CUDA, OpType::ConvTranspose, DataType::Float32, REGISTER_KERNEL(Device::CUDA, OpType::ConvTranspose, convBackwardDataCudnn,
convBackwardDataCudnn, "ConvTranposed_cuDNN_CUDA_Float32"); "ConvTranposed_cuDNN_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::ConvTransNHWC, DataType::Float32, REGISTER_KERNEL(Device::CUDA, OpType::ConvTransNHWC, convBackwardDataCudnn,
convBackwardDataCudnn, "ConvTranposedNHWC_cuDNN_CUDA_Float32"); "ConvTranposedNHWC_cuDNN_CUDA");
} // namespace infini } // namespace infini

View File

@ -2,6 +2,7 @@
#include "cuda/cuda_element_wise.h" #include "cuda/cuda_element_wise.h"
#include "cuda/cuda_kernel_wihtout_config.h" #include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h" #include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
namespace infini { namespace infini {
class ElementWiseCudnn : public CudaKernelWithoutConfig { class ElementWiseCudnn : public CudaKernelWithoutConfig {
@ -44,22 +45,21 @@ class ElementWiseCudnn : public CudaKernelWithoutConfig {
std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size())); std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size()));
std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size())); std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size())); std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size()));
auto cudnnDataType = cudnnDataTypeConvert(op->getDType());
// get inputs // get inputs
checkCudnnError(cudnnCreateTensorDescriptor(&aDesc)); checkCudnnError(cudnnCreateTensorDescriptor(&aDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(aDesc, CUDNN_TENSOR_NCHW, checkCudnnError(cudnnSetTensor4dDescriptor(
CUDNN_DATA_FLOAT, a[0], a[1], aDesc, CUDNN_TENSOR_NCHW, cudnnDataType, a[0], a[1], a[2], a[3]));
a[2], a[3]));
checkCudnnError(cudnnCreateTensorDescriptor(&bDesc)); checkCudnnError(cudnnCreateTensorDescriptor(&bDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(bDesc, CUDNN_TENSOR_NCHW, checkCudnnError(cudnnSetTensor4dDescriptor(
CUDNN_DATA_FLOAT, b[0], b[1], bDesc, CUDNN_TENSOR_NCHW, cudnnDataType, b[0], b[1], b[2], b[3]));
b[2], b[3]));
// get outputs // get outputs
checkCudnnError(cudnnCreateTensorDescriptor(&cDesc)); checkCudnnError(cudnnCreateTensorDescriptor(&cDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(cDesc, CUDNN_TENSOR_NCHW, checkCudnnError(cudnnSetTensor4dDescriptor(
CUDNN_DATA_FLOAT, c[0], c[1], cDesc, CUDNN_TENSOR_NCHW, cudnnDataType, c[0], c[1], c[2], c[3]));
c[2], c[3]));
// get op descriptor // get op descriptor
cudnnOpTensorDescriptor_t opDesc; cudnnOpTensorDescriptor_t opDesc;
@ -127,40 +127,33 @@ class ElementWiseCuda : public CudaKernelWithoutConfig {
std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size())); std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size())); std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size()));
if (op->getOpType() == OpType::Div) const int dType = _op->getDType().getIndex();
div_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1], if (op->getOpType() == OpType::Div) {
b[2], b[3], c[0], c[1], c[2], c[3]); div_kernel(dType, aData, bData, cData, a[0], a[1], a[2], a[3], b[0],
else if (op->getOpType() == OpType::Pow) b[1], b[2], b[3], c[0], c[1], c[2], c[3]);
pow_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1], } else if (op->getOpType() == OpType::Add) {
b[2], b[3], c[0], c[1], c[2], c[3]); add_kernel(dType, aData, bData, cData, a[0], a[1], a[2], a[3], b[0],
else if (op->getOpType() == OpType::Add) { b[1], b[2], b[3], c[0], c[1], c[2], c[3]);
add_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1], } else if (op->getOpType() == OpType::Pow) {
b[2], b[3], c[0], c[1], c[2], c[3]); pow_kernel(dType, aData, bData, cData, a[0], a[1], a[2], a[3], b[0],
b[1], b[2], b[3], c[0], c[1], c[2], c[3]);
} else if (op->getOpType() == OpType::Less) { } else if (op->getOpType() == OpType::Less) {
less_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1], less_kernel(dType, aData, bData, cData, a[0], a[1], a[2], a[3],
b[2], b[3], c[0], c[1], c[2], c[3]); b[0], b[1], b[2], b[3], c[0], c[1], c[2], c[3]);
} else } else {
IT_TODO_HALT(); IT_TODO_HALT();
}
} }
}; };
REGISTER_KERNEL(Device::CUDA, OpType::Add, DataType::Float32, AddCudnn, REGISTER_KERNEL(Device::CUDA, OpType::Add, AddCudnn, "Add_cuDNN_CUDA");
"Add_cuDNN_CUDA_Float32"); REGISTER_KERNEL(Device::CUDA, OpType::Sub, SubCudnn, "Sub_cuDNN_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Sub, DataType::Float32, SubCudnn, REGISTER_KERNEL(Device::CUDA, OpType::Mul, MulCudnn, "Mul_cuDNN_CUDA");
"Sub_cuDNN_CUDA_Float32"); REGISTER_KERNEL(Device::CUDA, OpType::Min, MinCudnn, "Min_cuDNN_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Mul, DataType::Float32, MulCudnn, REGISTER_KERNEL(Device::CUDA, OpType::Max, MaxCudnn, "Max_cuDNN_CUDA");
"Mul_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Min, DataType::Float32, MinCudnn, REGISTER_KERNEL(Device::CUDA, OpType::Div, ElementWiseCuda, "Div_CUDA");
"Min_cuDNN_CUDA_Float32"); REGISTER_KERNEL(Device::CUDA, OpType::Pow, ElementWiseCuda, "Pow_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Max, DataType::Float32, MaxCudnn, REGISTER_KERNEL(Device::CUDA, OpType::Less, ElementWiseCuda, "Less_CUDA");
"Max_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Div, DataType::Float32, ElementWiseCuda,
"Div_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Add, DataType::Int64, ElementWiseCuda,
"Add_CUDA_Int64");
REGISTER_KERNEL(Device::CUDA, OpType::Pow, DataType::Float32, ElementWiseCuda,
"Pow__CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Less, DataType::Int64, ElementWiseCuda,
"Less__CUDA_Int64");
}; // namespace infini }; // namespace infini

View File

@ -1,4 +1,5 @@
#include "cuda/cuda_common.h" #include "cuda/cuda_common.h"
#include "cuda/cuda_utility.h"
#include <math.h> #include <math.h>
constexpr unsigned int num_threads() { return 32 * 4; } constexpr unsigned int num_threads() { return 32 * 4; }
@ -129,44 +130,113 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
} }
} }
#define CASE(OP, T) \
_##OP##_kernel<DT_CUDA<T>::t><<<gridsize, blocksize>>>( \
a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
#define SWITCH_DTYPE(OP, DTYPE) \
switch (DTYPE) { \
case 1: \
CASE(OP, 1) \
break; \
case 2: \
CASE(OP, 2) \
break; \
case 3: \
CASE(OP, 3) \
break; \
case 4: \
CASE(OP, 4) \
break; \
case 5: \
CASE(OP, 5) \
break; \
case 6: \
CASE(OP, 6) \
break; \
case 7: \
CASE(OP, 7) \
break; \
case 10: \
CASE(OP, 10) \
break; \
case 11: \
CASE(OP, 11) \
break; \
case 12: \
CASE(OP, 12) \
break; \
case 13: \
CASE(OP, 13) \
break; \
case 16: \
CASE(OP, 16) \
break; \
default: \
IT_TODO_HALT(); \
}
namespace infini { namespace infini {
void div_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3, void div_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
int b0, int b1, int b2, int b3, int c0, int c1, int c2, int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3) { int c3) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int num = c0 * c1 * c2 * c3; int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_div_kernel<float><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, b1, SWITCH_DTYPE(div, dType)
b2, b3, c0, c1, c2, c3);
} }
void add_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3, void add_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
int b0, int b1, int b2, int b3, int c0, int c1, int c2, int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3) { int c3) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int num = c0 * c1 * c2 * c3; int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_add_kernel<int64_t><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, SWITCH_DTYPE(add, dType)
b1, b2, b3, c0, c1, c2, c3);
} }
void pow_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3, void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
int b0, int b1, int b2, int b3, int c0, int c1, int c2, int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3) { int c3) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int num = c0 * c1 * c2 * c3; int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_pow_kernel<float><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, b1, if (dType == 1) {
b2, b3, c0, c1, c2, c3); _pow_kernel<float><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0,
b1, b2, b3, c0, c1, c2, c3);
} else if (dType == 3) {
_pow_kernel<int8_t><<<gridsize, blocksize>>>(
a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
} else if (dType == 10) {
int a_size = a0 * a1 * a2 * a3;
int b_size = b0 * b1 * b2 * b3;
int c_size = c0 * c1 * c2 * c3;
vector<float> a_float(a_size);
vector<float> b_float(b_size);
vector<float> c_float(c_size);
for (int i = 0; i < a_size; ++i) {
a_float[i] = __half2float(((half *)a)[i]);
}
for (int i = 0; i < b_size; ++i) {
b_float[i] = __half2float(((half *)b)[i]);
}
_pow_kernel<float><<<gridsize, blocksize>>>(
a_float.data(), b_float.data(), c_float.data(), a0, a1, a2, a3, b0,
b1, b2, b3, c0, c1, c2, c3);
for (int i = 0; i < c_size; ++i) {
((half *)c)[i] = __float2half(c_float[i]);
}
} else {
IT_TODO_HALT();
}
} }
void less_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3, void less_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
int b0, int b1, int b2, int b3, int c0, int c1, int c2, int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3) { int c3) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int num = c0 * c1 * c2 * c3; int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_less_kernel<int64_t><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, SWITCH_DTYPE(less, dType)
b1, b2, b3, c0, c1, c2, c3);
} }
}; // namespace infini }; // namespace infini
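One detail worth flagging in the Float16 branch of pow_kernel above: a_float, b_float and c_float are host-side std::vectors, and the conversion loops dereference the raw tensor pointers on the host. Unless the runtime hands out unified/managed memory, both those host-side reads of device pointers and passing a_float.data() to a kernel launch would be invalid. A device-resident variant of the same fallback might look like this (a sketch, assuming cudaMalloc-backed temporaries and illustrative cast kernels that are not part of this patch):

#include <cuda_fp16.h>

// Cast helpers for the fallback: convert half -> float on the device, run
// the float power kernel, then convert the result back to half.
__global__ void castHalfToFloat(const half *src, float *dst, int n) {
    int i = threadIdx.x + blockIdx.x * blockDim.x;
    if (i < n)
        dst[i] = __half2float(src[i]);
}

__global__ void castFloatToHalf(const float *src, half *dst, int n) {
    int i = threadIdx.x + blockIdx.x * blockDim.x;
    if (i < n)
        dst[i] = __float2half(src[i]);
}

// Illustrative usage inside the dType == 10 branch:
//   float *aF, *bF, *cF;
//   cudaMalloc(&aF, a_size * sizeof(float));
//   cudaMalloc(&bF, b_size * sizeof(float));
//   cudaMalloc(&cF, c_size * sizeof(float));
//   castHalfToFloat<<<(a_size + 255) / 256, 256>>>((const half *)a, aF, a_size);
//   castHalfToFloat<<<(b_size + 255) / 256, 256>>>((const half *)b, bF, b_size);
//   _pow_kernel<float><<<gridsize, blocksize>>>(aF, bF, cF, a0, a1, a2, a3,
//                                               b0, b1, b2, b3, c0, c1, c2, c3);
//   castFloatToHalf<<<(c_size + 255) / 256, 256>>>(cF, (half *)c, c_size);
//   cudaFree(aF); cudaFree(bF); cudaFree(cF);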

View File

@ -25,12 +25,12 @@ class ExpandCuda : public CudaKernelWithoutConfig {
inputShape.data[i] = in_Shape[i]; inputShape.data[i] = in_Shape[i];
outputsize *= out_Shape[i]; outputsize *= out_Shape[i];
} }
expandKernel((float *)inputData, (float *)outputData, nDims, outputsize, const int dType = op->getDType().getIndex();
expandKernel(dType, inputData, outputData, nDims, outputsize,
inputShape, outputShape); inputShape, outputShape);
} }
}; };
REGISTER_KERNEL(Device::CUDA, OpType::Expand, DataType::Float32, ExpandCuda, REGISTER_KERNEL(Device::CUDA, OpType::Expand, ExpandCuda, "Expand_CUDA");
"Expand_CUDA_Float32");
}; // namespace infini }; // namespace infini

View File

@ -1,12 +1,14 @@
#include "core/common.h" #include "core/common.h"
#include "cuda/cuda_common.h" #include "cuda/cuda_common.h"
#include "cuda/cuda_utility.h"
#include "utils/small_array.h" #include "utils/small_array.h"
constexpr unsigned int num_threads() { return 32 * 4; } constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; } constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); } constexpr int block_work_size() { return thread_work_size() * num_threads(); }
__global__ void _expandKernel(float *input, float *output, int nDims, template <class T>
__global__ void _expandKernel(void *input, void *output, int nDims,
int outputsize, infini::SmallArray inputShape, int outputsize, infini::SmallArray inputShape,
infini::SmallArray outputShape) { infini::SmallArray outputShape) {
@ -33,17 +35,64 @@ __global__ void _expandKernel(float *input, float *output, int nDims,
temp *= inputShape.data[i]; temp *= inputShape.data[i];
v = v / outputShape.data[i]; v = v / outputShape.data[i];
} }
output[outputIdx] = input[inputIdx]; ((T *)output)[outputIdx] = ((T *)input)[inputIdx];
} }
} }
namespace infini { namespace infini {
void expandKernel(float *input, float *output, int nDims, int outputsize,
SmallArray inputShape, SmallArray outputShape) { #define CASE(T) \
_expandKernel<DT_CUDA<T>::t><<<gridsize, blocksize>>>( \
input, output, nDims, outputsize, inputShape, outputShape);
#define SWITCH_DTYPE(DTYPE) \
switch (DTYPE) { \
case 1: \
CASE(1) \
break; \
case 2: \
CASE(2) \
break; \
case 3: \
CASE(3) \
break; \
case 4: \
CASE(4) \
break; \
case 5: \
CASE(5) \
break; \
case 6: \
CASE(6) \
break; \
case 7: \
CASE(7) \
break; \
case 10: \
CASE(10) \
break; \
case 11: \
CASE(11) \
break; \
case 12: \
CASE(12) \
break; \
case 13: \
CASE(13) \
break; \
case 16: \
CASE(16) \
break; \
default: \
IT_TODO_HALT(); \
}
void expandKernel(int dType, void *input, void *output, int nDims,
int outputsize, SmallArray inputShape,
SmallArray outputShape) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int gridsize = (outputsize + block_work_size() - 1) / block_work_size(); int gridsize = (outputsize + block_work_size() - 1) / block_work_size();
_expandKernel<<<gridsize, blocksize>>>(input, output, nDims, outputsize, SWITCH_DTYPE(dType)
inputShape, outputShape);
} }
} // namespace infini } // namespace infini
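The numeric cases in SWITCH_DTYPE, here and in the element-wise kernels above, are DataType indices, which appear to follow the ONNX TensorProto element-type encoding (1 = float32, 2 = uint8, 3 = int8, ..., 7 = int64, 10 = float16, 16 = bfloat16); DT_CUDA<T>::t, presumably defined in cuda/cuda_utility.h, maps that index to the CUDA storage type the kernel is instantiated with. A sketch of what such a trait looks like for a few of the cases (assumed, not copied from the header):

// Assumed shape of the DT_CUDA trait used by the CASE/SWITCH_DTYPE macros:
// specialise on the DataType index and expose the CUDA element type as ::t.
template <int dtypeIndex> struct DT_CUDA;
template <> struct DT_CUDA<1> { using t = float; };    // DataType::Float32
template <> struct DT_CUDA<3> { using t = int8_t; };   // DataType::Int8
template <> struct DT_CUDA<7> { using t = int64_t; };  // DataType::Int64
template <> struct DT_CUDA<10> { using t = half; };    // DataType::Float16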

View File

@ -8,6 +8,7 @@ class ExtendCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<ExtendObj>(_op); auto op = as<ExtendObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto inData = op->getInputs(0)->getRawDataPtr<float *>(); auto inData = op->getInputs(0)->getRawDataPtr<float *>();
auto outData = op->getOutputs()[0]->getRawDataPtr<float *>(); auto outData = op->getOutputs()[0]->getRawDataPtr<float *>();
int blockSize = 1; int blockSize = 1;
@ -22,6 +23,5 @@ class ExtendCuda : public CudaKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::CUDA, OpType::Extend, DataType::Float32, ExtendCuda, REGISTER_KERNEL(Device::CUDA, OpType::Extend, ExtendCuda, "Extend_CUDA");
"Extend_CUDA_Float32");
} // namespace infini } // namespace infini

View File

@ -15,12 +15,23 @@ class GatherCuda : public CudaKernelWithoutConfig {
GatherMetaData metaData; GatherMetaData metaData;
initGatherMetaData(metaData, op); initGatherMetaData(metaData, op);
auto inData = input->getRawDataPtr<float *>(); void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
auto outData = op->getOutput()->getRawDataPtr<float *>(); void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
gather_kernel(inData, outData, metaData, op->getOutput()->size());
if (op->getDType() == DataType::Float32) {
gather_kernel<float>((float *)inputData, (float *)outputData,
metaData, op->getOutput()->size());
} else if (op->getDType() == DataType::Float16) {
gather_kernel<half>((half *)inputData, (half *)outputData, metaData,
op->getOutput()->size());
} else if (op->getDType() == DataType::Int8) {
gather_kernel<int8_t>((int8_t *)inputData, (int8_t *)outputData,
metaData, op->getOutput()->size());
} else {
IT_ASSERT(false);
}
} }
}; };
REGISTER_KERNEL(Device::CUDA, OpType::Gather, DataType::Float32, GatherCuda, REGISTER_KERNEL(Device::CUDA, OpType::Gather, GatherCuda, "Gather_CUDA");
"Gather_CUDA_Float32");
} // namespace infini } // namespace infini

View File

@ -28,27 +28,32 @@ __device__ T gatheredOffset2Offset(int gOffset,
return offset; return offset;
} }
template <typename T> template <typename dataT, typename T>
__global__ void _gather_kernel(float *in, float *out, __global__ void _gather_kernel(dataT *in, dataT *out,
infini::GatherMetaData metaData, size_t num) { infini::GatherMetaData metaData, size_t num) {
T tid = threadIdx.x + blockIdx.x * blockDim.x; T tid = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x; if (tid < num) {
while (tid < num) {
T offset = gatheredOffset2Offset<T>(tid, metaData); T offset = gatheredOffset2Offset<T>(tid, metaData);
out[tid] = in[offset]; out[tid] = in[offset];
tid += stride;
} }
} }
namespace infini { namespace infini {
void gather_kernel(float *in, float *out, GatherMetaData metaData, size_t num) { template <typename T>
void gather_kernel(T *in, T *out, GatherMetaData metaData, size_t num) {
int blockSize = 32 * 16; int blockSize = 32 * 16;
int gridSize = (num + blockSize - 1) / blockSize; int gridSize = (num + blockSize - 1) / blockSize;
if (metaData.indexType == DataType::Int64) { if (metaData.indexType == DataType::Int64) {
_gather_kernel<int64_t> _gather_kernel<T, int64_t>
<<<gridSize, blockSize>>>(in, out, metaData, num); <<<gridSize, blockSize>>>(in, out, metaData, num);
} else { } else {
_gather_kernel<int><<<gridSize, blockSize>>>(in, out, metaData, num); _gather_kernel<T, int><<<gridSize, blockSize>>>(in, out, metaData, num);
} }
} }
template void gather_kernel<float>(float *in, float *out,
GatherMetaData metaData, size_t num);
template void gather_kernel<half>(half *in, half *out, GatherMetaData metaData,
size_t num);
template void gather_kernel<int8_t>(int8_t *in, int8_t *out,
GatherMetaData metaData, size_t num);
} // namespace infini } // namespace infini
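gather_kernel is now a function template defined in this .cu file, so the explicit instantiations at the end are what the host-side dispatch in gather.cc above links against for float, half and int8_t; adding a new element type means adding both a branch there and an instantiation here. The header then only needs the declaration; a sketch of it (the exact header name is assumed):

// Assumed declaration in the gather header: definitions and the explicit
// instantiations stay in the .cu file above.
namespace infini {
template <typename T>
void gather_kernel(T *in, T *out, GatherMetaData metaData, size_t num);
} // namespace infini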

View File

@ -21,8 +21,7 @@ class GatherElementsCuda : public CudaKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::CUDA, OpType::GatherElements, DataType::Float32, REGISTER_KERNEL(Device::CUDA, OpType::GatherElements, GatherElementsCuda,
GatherElementsCuda, "GatherELements_CUDA_Float32"); "GatherELements_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::GatherElements, DataType::Int32,
GatherElementsCuda, "GatherElements_CUDA_Int32");
} // namespace infini } // namespace infini

View File

@ -24,22 +24,41 @@ class LayerNormCuda : public CudaKernelWithoutConfig {
int dimsize = dims[op->getAxis()]; int dimsize = dims[op->getAxis()];
int size = op->getOutput(0)->size(); int size = op->getOutput(0)->size();
int scaleSize = op->getInputs(1)->size(); int scaleSize = op->getInputs(1)->size();
if (op->numInputs() == 3) { if (op->getDType() == DataType::Float32) {
void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>()); if (op->numInputs() == 3) {
int biasSize = op->getInputs(2)->size(); void *const biasData =
// printf("kernel bias:true:%d\n", 1); (op->getInputs(2)->getRawDataPtr<void *>());
LaynormKernel((float *)inputData, (float *)scaleData, eps, size, int biasSize = op->getInputs(2)->size();
scaleSize, dimsize, stride, (float *)outputData, // printf("kernel bias:true:%d\n", 1);
(float *)biasData, biasSize); LaynormKernel((float *)inputData, (float *)scaleData, eps, size,
scaleSize, dimsize, stride, (float *)outputData,
(float *)biasData, biasSize);
} else {
// printf("kernel bias:false:%d\n", 0);
LaynormKernel((float *)inputData, (float *)scaleData, eps, size,
scaleSize, dimsize, stride, (float *)outputData);
}
} else if (op->getDType() == DataType::Float16) {
if (op->numInputs() == 3) {
void *const biasData =
(op->getInputs(2)->getRawDataPtr<void *>());
int biasSize = op->getInputs(2)->size();
// printf("kernel bias:true:%d\n", 1);
LaynormKernel((half *)inputData, (half *)scaleData, eps, size,
scaleSize, dimsize, stride, (half *)outputData,
(half *)biasData, biasSize);
} else {
// printf("kernel bias:false:%d\n", 0);
LaynormKernel((half *)inputData, (half *)scaleData, eps, size,
scaleSize, dimsize, stride, (half *)outputData);
}
} else { } else {
// printf("kernel bias:false:%d\n", 0); IT_ASSERT(false);
LaynormKernel((float *)inputData, (float *)scaleData, eps, size,
scaleSize, dimsize, stride, (float *)outputData);
} }
} }
}; };
REGISTER_KERNEL(Device::CUDA, OpType::LayerNormalization, DataType::Float32, REGISTER_KERNEL(Device::CUDA, OpType::LayerNormalization, LayerNormCuda,
LayerNormCuda, "LayerNorm_CUDA_Float32"); "LayerNorm_CUDA");
}; // namespace infini }; // namespace infini
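The CUDA kernels in the next file are templated over the element type so that float and half share one implementation; the host side simply calls the float or half overload of LaynormKernel as shown above. Inside the kernels, every division by dimsize or by sqrt(sigma2 + eps) becomes a multiplication by a reciprocal computed in fp32 with __fdividef and then cast to T, so the same expression compiles for both float and half. The pattern in isolation (a sketch of what the kernels below do repeatedly):

#include <cuda_fp16.h>

// Reciprocal-multiply pattern used throughout the templated LayerNorm
// kernels: take the fast single-precision reciprocal, cast it to the
// kernel's element type T (float or half), then multiply.
template <typename T>
__device__ inline T divideBy(T value, float denom) {
    return value * static_cast<T>(__fdividef(1.0F, denom));
}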

View File

@ -1,43 +1,41 @@
#include "cuda/cuda_common.h" #include "cuda/cuda_common.h"
#include <cub/cub.cuh> #include <cub/cub.cuh>
template <int BLOCK_DIM> template <typename T, int BLOCK_DIM>
__launch_bounds__(BLOCK_DIM) __global__ __launch_bounds__(BLOCK_DIM) __global__
void blockLaynormKernel(const float *input, const float *scale, void blockLaynormKernel(const T *input, const T *scale, const int dimsize,
const int dimsize, const int stride, float *output, const int stride, T *output, const T eps,
const float eps, int scaleSize, const float *bias, int scaleSize, const T *bias, int biasSize) {
int biasSize) {
// len(scale) = len(bias) = dimsize // len(scale) = len(bias) = dimsize
int tmp = blockIdx.x % stride; int tmp = blockIdx.x % stride;
int tid = (blockIdx.x - tmp) * dimsize + tmp; int tid = (blockIdx.x - tmp) * dimsize + tmp;
float muPartial = 0.0f; T muPartial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride]; muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride];
} }
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce; typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage; __shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ float mu; __shared__ T mu;
float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum()); T muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
if (threadIdx.x == if (threadIdx.x ==
0) { // must set threadIdx.x = 0 write the output to memory 0) { // must set threadIdx.x = 0 write the output to memory
mu = muBlock / dimsize; mu = muBlock * static_cast<T>(__fdividef(1.0F, dimsize));
} }
__syncthreads(); __syncthreads();
float sigma2Partial = 0.0f; T sigma2Partial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
sigma2Partial += sigma2Partial +=
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) * (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu); (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu);
} }
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce; typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;
__shared__ float sigma2; __shared__ T sigma2;
float sigma2Block = T sigma2Block = BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
if (threadIdx.x == if (threadIdx.x ==
0) { // must set threadIdx.x = 0 write the output to memory 0) { // must set threadIdx.x = 0 write the output to memory
sigma2 = sigma2Block / dimsize; sigma2 = sigma2Block * static_cast<T>(__fdividef(1.0F, dimsize));
} }
__syncthreads(); __syncthreads();
if (biasSize == dimsize) { if (biasSize == dimsize) {
@ -47,8 +45,9 @@ __launch_bounds__(BLOCK_DIM) __global__
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[threadIdx.x + ph * BLOCK_DIM] * scale[threadIdx.x + ph * BLOCK_DIM] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
mu) / mu) *
sqrt(sigma2 + eps) + static_cast<T>(__fdividef(
1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
bias[threadIdx.x + ph * BLOCK_DIM]; bias[threadIdx.x + ph * BLOCK_DIM];
} }
} else { } else {
@ -57,8 +56,9 @@ __launch_bounds__(BLOCK_DIM) __global__
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[0] * scale[0] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
mu) / mu) *
sqrt(sigma2 + eps) + static_cast<T>(__fdividef(
1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
bias[threadIdx.x + ph * BLOCK_DIM]; bias[threadIdx.x + ph * BLOCK_DIM];
} }
} }
@ -69,8 +69,9 @@ __launch_bounds__(BLOCK_DIM) __global__
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[threadIdx.x + ph * BLOCK_DIM] * scale[threadIdx.x + ph * BLOCK_DIM] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
mu) / mu) *
sqrt(sigma2 + eps) + static_cast<T>(__fdividef(
1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
bias[0]; bias[0];
} }
} else { } else {
@ -79,50 +80,50 @@ __launch_bounds__(BLOCK_DIM) __global__
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[0] * scale[0] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
mu) / mu) *
sqrt(sigma2 + eps) + static_cast<T>(__fdividef(
1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
bias[0]; bias[0];
} }
} }
} }
} }
//----------------- //-----------------
template <int BLOCK_DIM> template <typename T, int BLOCK_DIM>
__launch_bounds__(BLOCK_DIM) __global__ __launch_bounds__(BLOCK_DIM) __global__
void blockLaynormKernel(const float *input, const float *scale, void blockLaynormKernel(const T *input, const T *scale, const int dimsize,
const int dimsize, const int stride, float *output, const int stride, T *output, const T eps,
const float eps, int scaleSize) { int scaleSize) {
// len(scale) = len(bias) = dimsize // len(scale) = len(bias) = dimsize
int tmp = blockIdx.x % stride; int tmp = blockIdx.x % stride;
int tid = (blockIdx.x - tmp) * dimsize + tmp; int tid = (blockIdx.x - tmp) * dimsize + tmp;
float muPartial = 0.0f; T muPartial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride]; muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride];
} }
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce; typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage; __shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ float mu; __shared__ T mu;
float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum()); T muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
if (threadIdx.x == if (threadIdx.x ==
0) { // must set threadIdx.x = 0 write the output to memory 0) { // must set threadIdx.x = 0 write the output to memory
mu = muBlock / dimsize; mu = muBlock * static_cast<T>(__fdividef(1.0F, dimsize));
} }
__syncthreads(); __syncthreads();
float sigma2Partial = 0.0f; T sigma2Partial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
sigma2Partial += sigma2Partial +=
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) * (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu); (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu);
} }
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce; typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;
__shared__ float sigma2; __shared__ T sigma2;
float sigma2Block = T sigma2Block = BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
if (threadIdx.x == if (threadIdx.x ==
0) { // must set threadIdx.x = 0 write the output to memory 0) { // must set threadIdx.x = 0 write the output to memory
sigma2 = sigma2Block / dimsize; sigma2 = sigma2Block * static_cast<T>(__fdividef(1.0F, dimsize));
} }
__syncthreads(); __syncthreads();
if (scaleSize == dimsize) { if (scaleSize == dimsize) {
@ -130,16 +131,18 @@ __launch_bounds__(BLOCK_DIM) __global__
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[threadIdx.x + ph * BLOCK_DIM] * scale[threadIdx.x + ph * BLOCK_DIM] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) / (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
sqrt(sigma2 + eps); static_cast<T>(
__fdividef(1.0F, sqrt(static_cast<float>(sigma2 + eps))));
} }
} else { } else {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[0] * scale[0] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) / (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
sqrt(sigma2 + eps); static_cast<T>(
__fdividef(1.0F, sqrt(static_cast<float>(sigma2 + eps))));
} }
} }
} }
@ -158,33 +161,33 @@ __inline__ __device__ T WarpAllReduce(T val) {
} }
return val; return val;
} }
template <int BLOCK_DIM_x, int BLOCK_DIM_y> template <typename T, int BLOCK_DIM_x, int BLOCK_DIM_y>
__global__ void warpLaynormKernel(const float *input, const float *scale, __global__ void warpLaynormKernel(const T *input, const T *scale,
const int dimsize, const int stride, const int dimsize, const int stride,
float *output, const float eps, int scaleSize, T *output, const T eps, int scaleSize,
int otherSize, const float *bias, int otherSize, const T *bias, int biasSize) {
int biasSize) {
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y; int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize; int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
if (otherIdx < otherSize) { if (otherIdx < otherSize) {
__shared__ float muTotal[BLOCK_DIM_y]; __shared__ T muTotal[BLOCK_DIM_y];
__shared__ float sigma2Total[BLOCK_DIM_y]; __shared__ T sigma2Total[BLOCK_DIM_y];
float muPartial = 0.0f; T muPartial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) { for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride]; muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride];
} }
        muPartial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(muPartial);
        if (threadIdx.x == 0)
            muTotal[threadIdx.y] =
                muPartial * static_cast<T>(__fdividef(1.0F, dimsize));
        //--------------------------------------------
        T sigma2Partial = 0.0f;
        for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
            sigma2Partial +=
@@ -194,10 +197,11 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                 muTotal[threadIdx.y]);
        }
        sigma2Partial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(sigma2Partial);
        if (threadIdx.x == 0)
            sigma2Total[threadIdx.y] =
                sigma2Partial * static_cast<T>(__fdividef(1.0F, dimsize));
        //--------------------------------------------
        if (biasSize == dimsize) {
@@ -209,8 +213,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                        scale[threadIdx.x + ph * BLOCK_DIM_x] *
                            (input[tid +
                                   (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
                             muTotal[threadIdx.y]) *
                            static_cast<T>(__fdividef(
                                1.0F, sqrt(static_cast<float>(
                                          sigma2Total[threadIdx.y] + eps)))) +
                        bias[threadIdx.x + ph * BLOCK_DIM_x];
                }
            } else {
@@ -221,8 +227,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                        scale[0] *
                            (input[tid +
                                   (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
                             muTotal[threadIdx.y]) *
                            static_cast<T>(__fdividef(
                                1.0F, sqrt(static_cast<float>(
                                          sigma2Total[threadIdx.y] + eps)))) +
                        bias[threadIdx.x + ph * BLOCK_DIM_x];
                }
            }
@@ -235,8 +243,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                        scale[threadIdx.x + ph * BLOCK_DIM_x] *
                            (input[tid +
                                   (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
                             muTotal[threadIdx.y]) *
                            static_cast<T>(__fdividef(
                                1.0F, sqrt(static_cast<float>(
                                          sigma2Total[threadIdx.y] + eps)))) +
                        bias[0];
                }
            } else {
@@ -247,40 +257,43 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                        scale[0] *
                            (input[tid +
                                   (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
                             muTotal[threadIdx.y]) *
                            static_cast<T>(__fdividef(
                                1.0F, sqrt(static_cast<float>(
                                          sigma2Total[threadIdx.y] + eps)))) +
                        bias[0];
                }
            }
        }
    }
}
template <typename T, int BLOCK_DIM_x, int BLOCK_DIM_y>
__global__ void warpLaynormKernel(const T *input, const T *scale,
                                  const int dimsize, const int stride,
                                  T *output, const T eps, int scaleSize,
                                  int otherSize) {
    int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
    int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
    if (otherIdx < otherSize) {
        __shared__ T muTotal[BLOCK_DIM_y];
        __shared__ T sigma2Total[BLOCK_DIM_y];
        T muPartial = 0.0f;
        for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
            muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride];
        }
        muPartial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(muPartial);
        if (threadIdx.x == 0)
            muTotal[threadIdx.y] =
                muPartial * static_cast<T>(__fdividef(1.0F, dimsize));
        //--------------------------------------------
        T sigma2Partial = 0.0f;
        for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
            sigma2Partial +=
@@ -290,10 +303,11 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                 muTotal[threadIdx.y]);
        }
        sigma2Partial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(sigma2Partial);
        if (threadIdx.x == 0)
            sigma2Total[threadIdx.y] =
                sigma2Partial * static_cast<T>(__fdividef(1.0F, dimsize));
        //--------------------------------------------
        if (scaleSize == dimsize) {
@@ -302,8 +316,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
                    scale[threadIdx.x + ph * BLOCK_DIM_x] *
                    (input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
                     muTotal[threadIdx.y]) *
                    static_cast<T>(
                        __fdividef(1.0F, sqrt(static_cast<float>(
                                             sigma2Total[threadIdx.y] + eps))));
            }
        } else {
            for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
@@ -311,8 +327,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
                    scale[0] *
                    (input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
                     muTotal[threadIdx.y]) *
                    static_cast<T>(
                        __fdividef(1.0F, sqrt(static_cast<float>(
                                             sigma2Total[threadIdx.y] + eps))));
            }
        }
    }
@@ -325,7 +343,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
    if (dimsize > 1024) {
        int BLOCK_DIM = 1024;
        blockLaynormKernel<float, 1024>
            <<<num_block, BLOCK_DIM>>>(input, scale, dimsize, stride, output,
                                       eps, scaleSize, bias, biasSize);
    } else if (dimsize > 31) {
@@ -335,7 +353,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
        dim3 grid_dim(num_block_x, 1, 1);
        warpLaynormKernel<float, 32, 32><<<grid_dim, block_dim>>>(
            input, scale, dimsize, stride, output, eps, scaleSize, num_block,
            bias, biasSize);
    } else if (dimsize > 15) {
@@ -345,7 +363,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
        dim3 grid_dim(num_block_x, 1, 1);
        warpLaynormKernel<float, 16, 64><<<grid_dim, block_dim>>>(
            input, scale, dimsize, stride, output, eps, scaleSize, num_block,
            bias, biasSize);
    } else if (dimsize > 7) {
@@ -355,7 +373,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
        dim3 grid_dim(num_block_x, 1, 1);
        warpLaynormKernel<float, 8, 128><<<grid_dim, block_dim>>>(
            input, scale, dimsize, stride, output, eps, scaleSize, num_block,
            bias, biasSize);
    } else {
@@ -365,7 +383,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
        dim3 grid_dim(num_block_x, 1, 1);
        warpLaynormKernel<float, 4, 256><<<grid_dim, block_dim>>>(
            input, scale, dimsize, stride, output, eps, scaleSize, num_block,
            bias, biasSize);
    }
@@ -378,7 +396,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
    if (dimsize > 1024) {
        int BLOCK_DIM = 1024;
        blockLaynormKernel<float, 1024><<<num_block, BLOCK_DIM>>>(
            input, scale, dimsize, stride, output, eps, scaleSize);
    } else if (dimsize > 31) {
        int BLOCK_DIM_x = 32;
@@ -387,7 +405,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
        dim3 grid_dim(num_block_x, 1, 1);
        warpLaynormKernel<float, 32, 32><<<grid_dim, block_dim>>>(
            input, scale, dimsize, stride, output, eps, scaleSize, num_block);
    } else if (dimsize > 15) {
        int BLOCK_DIM_x = 16;
@@ -396,7 +414,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
        dim3 grid_dim(num_block_x, 1, 1);
        warpLaynormKernel<float, 16, 64><<<grid_dim, block_dim>>>(
            input, scale, dimsize, stride, output, eps, scaleSize, num_block);
    } else if (dimsize > 7) {
        int BLOCK_DIM_x = 8;
@@ -405,7 +423,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
        dim3 grid_dim(num_block_x, 1, 1);
        warpLaynormKernel<float, 8, 128><<<grid_dim, block_dim>>>(
            input, scale, dimsize, stride, output, eps, scaleSize, num_block);
    } else {
        int BLOCK_DIM_x = 4;
@@ -414,7 +432,108 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
        dim3 grid_dim(num_block_x, 1, 1);
        warpLaynormKernel<float, 4, 256><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
}
}
//-----------------
void LaynormKernel(const half *input, const half *scale, const half eps,
int size, int scaleSize, const int dimsize, const int stride,
half *output, const half *bias, int biasSize) {
int num_block = size / dimsize;
if (dimsize > 1024) {
int BLOCK_DIM = 1024;
blockLaynormKernel<half, 1024>
<<<num_block, BLOCK_DIM>>>(input, scale, dimsize, stride, output,
eps, scaleSize, bias, biasSize);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 32, 32><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 16, 64><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 8, 128><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else {
int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 4, 256><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
}
}
void LaynormKernel(const half *input, const half *scale, const half eps,
int size, int scaleSize, const int dimsize, const int stride,
half *output) {
int num_block = size / dimsize;
if (dimsize > 1024) {
int BLOCK_DIM = 1024;
blockLaynormKernel<half, 1024><<<num_block, BLOCK_DIM>>>(
input, scale, dimsize, stride, output, eps, scaleSize);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 32, 32><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 16, 64><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 8, 128><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else {
int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 4, 256><<<grid_dim, block_dim>>>(
            input, scale, dimsize, stride, output, eps, scaleSize, num_block);
    }
}
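
For context, a minimal sketch of how the new half overload of LaynormKernel might be invoked from a layer-norm kernel class; the variable names (op, size, scaleSize, dimsize, stride) and the getEps() accessor are illustrative assumptions, not code taken from this commit:

// Hypothetical dispatch sketch: pick the LaynormKernel overload by dtype.
void *inData = op->getInputs(0)->getRawDataPtr<void *>();
void *scaleData = op->getInputs(1)->getRawDataPtr<void *>();
void *outData = op->getOutput()->getRawDataPtr<void *>();
if (op->getDType() == DataType::Float16) {
    half epsHalf = __float2half(op->getEps()); // eps converted once on the host
    LaynormKernel((const half *)inData, (const half *)scaleData, epsHalf, size,
                  scaleSize, dimsize, stride, (half *)outData);
} else if (op->getDType() == DataType::Float32) {
    LaynormKernel((const float *)inData, (const float *)scaleData,
                  op->getEps(), size, scaleSize, dimsize, stride,
                  (float *)outData);
} else {
    IT_ASSERT(false);
}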
View File
@@ -2,6 +2,7 @@
#include "core/kernel.h"
#include "cuda/cuda_expand.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "utils/small_array.h"
namespace infini {
@@ -48,11 +49,12 @@ class matmulCublas : public Kernel {
        auto opB = op->getTransB() ? CUBLAS_OP_T : CUBLAS_OP_N;
        const int lda = op->getTransA() ? m : k, ldb = op->getTransB() ? k : n,
                  ldc = n;
        float alpha_naive = 1.f, beta_naive = 0.f;
        auto dataType = op->getDType();
        auto cuDataType = cublasDataTypeConvert(dataType);
        IT_ASSERT(cuDataType != CUDA_R_8I, "matmul don't support int8 dtype.");
        if (op->numInputs() == 3) { // have bias
            beta_naive = 1.f;
            auto inC = op->getInputs(2);
            auto out = op->getOutput();
            SmallArray inputShape, outputShape;
@@ -69,8 +71,9 @@ class matmulCublas : public Kernel {
                if (i >= offset)
                    inputShape.data[i] = inC->getDims()[i - offset];
            }
            const int dType = dataType.getIndex();
            expandKernel(dType, inC->getRawDataPtr<void *>(),
                         out->getRawDataPtr<void *>(), nDims, outputsize,
                         inputShape, outputShape);
        }
        // TODO:use compute type
@@ -89,16 +92,38 @@ class matmulCublas : public Kernel {
                    (dimB == 3 && op->getInputs(1)->getDims()[0] == 1))
                        ? 0 // Broadcast the batch dimension if batch size is 1
                        : n * k;
            if (dataType == DataType::Float16) {
                half alpha_half = static_cast<half>(alpha_naive);
                half beta_half = static_cast<half>(beta_naive);
                stat = cublasGemmStridedBatchedEx(
                    context->cublasHandle(), opB, opA, n, m, k, &alpha_half,
                    inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda,
                    strideA, &beta_half, outData, cuDataType, ldc, m * n, b,
                    cuDataType, (cublasGemmAlgo_t)record->algo);
            } else {
                stat = cublasGemmStridedBatchedEx(
                    context->cublasHandle(), opB, opA, n, m, k, &alpha_naive,
                    inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda,
                    strideA, &beta_naive, outData, cuDataType, ldc, m * n, b,
                    cuDataType, (cublasGemmAlgo_t)record->algo);
            }
        } else {
            if (dataType == DataType::Float16) {
                half alpha_half = static_cast<half>(alpha_naive);
                half beta_half = static_cast<half>(beta_naive);
                stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
                                    &alpha_half, inBData, cuDataType, ldb,
                                    inAData, cuDataType, lda, &beta_half,
                                    outData, cuDataType, ldc, cuDataType,
                                    (cublasGemmAlgo_t)record->algo);
            } else {
                stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
                                    &alpha_naive, inBData, cuDataType, ldb,
                                    inAData, cuDataType, lda, &beta_naive,
                                    outData, cuDataType, ldc, cuDataType,
                                    (cublasGemmAlgo_t)record->algo);
            }
        }
        // if (stat != CUBLAS_STATUS_SUCCESS)
        //     cout << cublasGetErrorString(stat);
@@ -140,8 +165,9 @@ class matmulCublas : public Kernel {
    }
};
REGISTER_KERNEL(Device::CUDA, OpType::MatMul, matmulCublas,
                "Matmul_cuBLAS_CUDA");
REGISTER_CONSTRUCTOR(2, MatmulCublasPerfRecordObj::from_json);
}; // namespace infini
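
Worth noting for the fp16 path above: when cublasGemmEx / cublasGemmStridedBatchedEx run with CUDA_R_16F operands and a CUDA_R_16F compute type, the alpha and beta scalars must also point to half values, which is why alpha_naive / beta_naive are converted before the call. A standalone sketch of that calling convention (handle, pointers, and leading dimensions are placeholders, not values from this diff):

// Sketch: C = alpha * A * B + beta * C entirely in half precision.
half alpha = __float2half(1.0f), beta = __float2half(0.0f);
cublasStatus_t stat = cublasGemmEx(
    handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
    &alpha,                 // must be half when computeType is CUDA_R_16F
    A, CUDA_R_16F, lda,
    B, CUDA_R_16F, ldb,
    &beta,
    C, CUDA_R_16F, ldc,
    CUDA_R_16F,             // compute type matches the scalar type
    CUBLAS_GEMM_DEFAULT);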
View File
@@ -229,9 +229,8 @@ class MemboundTVMExtractSource : public Kernel {
    }
};
REGISTER_KERNEL(Device::CUDA, OpType::MemBound, MemboundTVMExtractSource,
                "Memobund_TVM_Ansor_extract_source");
}; // namespace infini
#endif
View File
@@ -216,9 +216,9 @@ class MemboundTVMPackedFunction : public Kernel {
    }
};
REGISTER_KERNEL(Device::CUDA, OpType::MemBound, MemboundTVMPackedFunction,
                "Memobund_TVM_Ansor_packed_funciton");
}; // namespace infini
#endif
View File
@@ -39,10 +39,8 @@ class SliceCuda : private PadSliceCudaCompute, public CudaKernelWithoutConfig {
    }
};
REGISTER_KERNEL(Device::CUDA, OpType::Slice, SliceCuda, "Slice__CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Pad, PadCuda, "Pad__CUDA");
} // namespace infini
View File
@@ -1,6 +1,7 @@
#include "core/data_type.h"
#include "cuda/cuda_common.h"
#include "cuda/cuda_pad_slice.h"
#include "cuda/cuda_utility.h"
__device__ int WholeTensorOffset2PartTensorOffset(int wholeOffset,
                                                  TransMetaData metaData,
@@ -21,39 +22,83 @@ __device__ int WholeTensorOffset2PartTensorOffset(int wholeOffset,
}
template <typename T>
__global__ void _pad_slice_kernel(void *part, void *whole,
                                  TransMetaData metaData, int nDims, int num,
                                  bool isPad) {
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    if (tid >= num) {
        return;
    }
    int stride = blockDim.x * gridDim.x;
    while (tid < num) {
        int offset = WholeTensorOffset2PartTensorOffset(tid, metaData, nDims);
        if (isPad) {
            if (offset < 0) {
                ((T *)whole)[tid] = static_cast<T>(0.f);
            } else {
                ((T *)whole)[tid] = ((T *)part)[offset];
            }
        } else if (offset >= 0) {
            ((T *)part)[offset] = ((T *)whole)[tid];
        }
        tid += stride;
    }
}
namespace infini {
#define CASE(T) \
_pad_slice_kernel<DT_CUDA<T>::t><<<gridSize, blockSize>>>( \
partData, wholeData, metadata, nDims, num, isPad);
#define SWITCH_DTYPE(DTYPE) \
switch (DTYPE) { \
case 1: \
CASE(1) \
break; \
case 2: \
CASE(2) \
break; \
case 3: \
CASE(3) \
break; \
case 4: \
CASE(4) \
break; \
case 5: \
CASE(5) \
break; \
case 6: \
CASE(6) \
break; \
case 7: \
CASE(7) \
break; \
case 10: \
CASE(10) \
break; \
case 11: \
CASE(11) \
break; \
case 12: \
CASE(12) \
break; \
case 13: \
CASE(13) \
break; \
case 16: \
CASE(16) \
break; \
default: \
IT_TODO_HALT(); \
}
void pad_slice_kernel(void *partData, void *wholeData,
                      const TransMetaData &metadata, int nDims, int num,
                      bool isPad) {
    int blockSize = 32 * 16;
    int gridSize = (num + blockSize - 1) / blockSize;
    int dType = metadata.DType;
    SWITCH_DTYPE(dType)
}
} // namespace infini
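
The SWITCH_DTYPE / CASE macros above turn the runtime dtype index stored in metadata.DType into a compile-time template instantiation of _pad_slice_kernel, so a single entry point covers every supported element type. The same idiom in isolation, with a hypothetical index-to-type trait (the concrete index values and the project's DT_CUDA mapping are assumptions here):

// Illustrative only: runtime dtype index -> compile-time element type.
template <int I> struct ElemType;                     // hypothetical trait
template <> struct ElemType<1> { using t = float; };  // assuming 1 = float32
template <> struct ElemType<10> { using t = half; };  // assuming 10 = float16

template <typename T> void launchCopy(void *dst, const void *src, int n);

void dispatchCopy(int dtypeIndex, void *dst, const void *src, int n) {
    switch (dtypeIndex) {
    case 1:
        launchCopy<ElemType<1>::t>(dst, src, n);
        break;
    case 10:
        launchCopy<ElemType<10>::t>(dst, src, n);
        break;
    default:
        break; // unsupported dtype
    }
}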
View File
@@ -8,6 +8,7 @@ class poolingCudnn : public CudaKernelWithoutConfig {
    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        auto op = as<PoolingObj>(_op);
        IT_ASSERT(op->getDType() == DataType::Float32);
        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
        void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
        void *const outData = (op->getOutput()->getRawDataPtr<void *>());
@@ -76,8 +77,9 @@ class avgPoolCudnn : public poolingCudnn {
    }
};
REGISTER_KERNEL(Device::CUDA, OpType::MaxPool, maxPoolCudnn,
                "MaxPool_cuDNN_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::AveragePool, avgPoolCudnn,
                "AvgPool_cuDNN_CUDA");
}; // namespace infini
View File
@@ -40,8 +40,7 @@ class RecvNCCL : public CudaKernelWithoutConfig {
    }
};
REGISTER_KERNEL(Device::CUDA, OpType::Recv, RecvNCCL, "Recv_NCCL_CUDA");
} // namespace infini
#endif
View File
@@ -1,6 +1,7 @@
#include "operators/reduce.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
namespace infini {
class ReduceCudnnBase : public CudaKernelWithoutConfig {
@@ -46,12 +47,12 @@ class ReduceCudnnBase : public CudaKernelWithoutConfig {
        checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
        cudnnTensorDescriptor_t outDesc;
        checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
        auto cudnnDataType = cudnnDataTypeConvert(op->getDType());
        if (nInDims > 3) {
            checkCudnnError(cudnnSetTensorNdDescriptor(
                inDesc, cudnnDataType, nInDims, inDimArray, inStrideArray));
            checkCudnnError(cudnnSetTensorNdDescriptor(
                outDesc, cudnnDataType, nInDims, outDimArray, outStrideArray));
        } else {
            int idims[4] = {1, 1, 1, 1}, odims[4] = {1, 1, 1, 1};
            for (int i = 0; i < nInDims; ++i) {
@@ -62,20 +63,19 @@ class ReduceCudnnBase : public CudaKernelWithoutConfig {
            }
            checkCudnnError(cudnnSetTensor4dDescriptor(
                inDesc, CUDNN_TENSOR_NCHW, cudnnDataType, idims[0], idims[1],
                idims[2], idims[3]));
            checkCudnnError(cudnnSetTensor4dDescriptor(
                outDesc, CUDNN_TENSOR_NCHW, cudnnDataType, odims[0], odims[1],
                odims[2], odims[3]));
        }
        // get reduce descriptor
        cudnnReduceTensorDescriptor_t reduceDesc;
        checkCudnnError(cudnnCreateReduceTensorDescriptor(&reduceDesc));
        checkCudnnError(cudnnSetReduceTensorDescriptor(
            reduceDesc, getReduceOp(), cudnnDataType, CUDNN_NOT_PROPAGATE_NAN,
            CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES));
        // get workspace
        size_t workspaceSize = 0;
@@ -120,8 +120,9 @@ class ReduceSumCudnn : public ReduceCudnnBase {
    }
};
REGISTER_KERNEL(Device::CUDA, OpType::ReduceMean, ReduceMeanCudnn,
                "ReduceMean_cuDNN_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::ReduceSum, ReduceSumCudnn,
                "ReduceSum_cuDNN_CUDA");
}; // namespace infini
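
cudnnDataTypeConvert is the helper that maps the framework DataType onto cudnnDataType_t so the same descriptors serve fp32 and fp16. Its body is not part of this diff; a plausible minimal sketch of such a mapping, declared in the headers included above, would be:

// Assumed behaviour only; not code from this commit.
cudnnDataType_t cudnnDataTypeConvert(DataType dataType) {
    if (dataType == DataType::Float32)
        return CUDNN_DATA_FLOAT;
    if (dataType == DataType::Float16)
        return CUDNN_DATA_HALF;
    if (dataType == DataType::Int8)
        return CUDNN_DATA_INT8;
    if (dataType == DataType::Int32)
        return CUDNN_DATA_INT32;
    IT_ASSERT(false, "unsupported data type for cuDNN");
    return CUDNN_DATA_FLOAT; // unreachable
}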
View File
@@ -11,19 +11,12 @@ class CopyCuda : public CudaKernelWithoutConfig {
    }
};
// reshape/flatten/identity all act as copying from input to output.
REGISTER_KERNEL(Device::CUDA, OpType::Reshape, CopyCuda, "Reshape_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Flatten, CopyCuda, "Flatten_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Identity, CopyCuda, "Identity_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Squeeze, CopyCuda, "Squeeze_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Unsqueeze, CopyCuda, "Unsqueeze_CUDA");
} // namespace infini
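
Because reshape, flatten, squeeze, unsqueeze, and identity only move bytes, dropping the DataType argument from REGISTER_KERNEL lets one CopyCuda registration serve every dtype. A minimal sketch of why no element type is needed (the getBytes() helper and the stream variable are assumptions for illustration, not the actual CopyCuda body):

// Illustrative dtype-agnostic copy: only the byte count matters.
void *src = op->getInputs(0)->getRawDataPtr<void *>();
void *dst = op->getOutput()->getRawDataPtr<void *>();
size_t bytes = op->getInputs(0)->getBytes(); // element count * dtype size
cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice, stream);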
View File
@@ -6,6 +6,7 @@ class ResizeCuda : public CudaKernelWithoutConfig {
    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        auto op = as<ResizeObj>(_op);
        IT_ASSERT(op->getDType() == DataType::Float32);
        auto in = op->getInputs(0);
        auto out = op->getOutputs()[0];
@@ -48,7 +49,6 @@ class ResizeCuda : public CudaKernelWithoutConfig {
    }
};
REGISTER_KERNEL(Device::CUDA, OpType::Resize, ResizeCuda, "Resize_CUDA");
} // namespace infini
View File
@@ -36,8 +36,7 @@ class SendNCCL : public CudaKernelWithoutConfig {
    }
};
REGISTER_KERNEL(Device::CUDA, OpType::Send, SendNCCL, "Send_NCCL_CUDA");
} // namespace infini
#endif
View File
@@ -20,11 +20,17 @@ class SoftmaxCuda : public CudaKernelWithoutConfig {
        int stride = op->getInputs(0)->getStride().at(op->getAxis());
        int num_blocks = size / dimsize;
        if (op->getDType() == DataType::Float32) {
            softmax_kernel(num_blocks, (float *)input, (float *)output, size,
                           dimsize, stride);
        } else if (op->getDType() == DataType::Float16) {
            softmax_kernel(num_blocks, (half *)input, (half *)output, size,
                           dimsize, stride);
        } else {
            IT_ASSERT(false);
        }
    }
};
REGISTER_KERNEL(Device::CUDA, OpType::Softmax, SoftmaxCuda, "Softmax_CUDA");
} // namespace infini
View File
@@ -1,6 +1,5 @@
#include "cuda/cuda_common.h"
#include <cub/cub.cuh>
struct __align__(8) DataMaxSum { // update the global max and sum, store the
                                 // output at max_tmp and sum_tmp
    float max_tmp; // store max
@@ -16,9 +15,9 @@ __device__ __forceinline__ DataMaxSum reduce_dms_op(DataMaxSum a,
    return bigger;
}
template <typename T, int BLOCK_DIM>
__launch_bounds__(BLOCK_DIM) __global__ void _blockSoftmaxKernel(
    T *__restrict input, T *__restrict output, int size, int dimsize,
    int stride) { // if set axis = 1, inputShape=[I,J,K,S]
                  // tid = i(JKS) + j(KS) + k(S) + s
@@ -33,15 +32,33 @@ __launch_bounds__(BLOCK_DIM) __global__ void _blockSoftmaxKernel(
    dms_partial.max_tmp = -__FLT_MAX__;
    dms_partial.sum_tmp = 0.0f;
    DataMaxSum dms_input;
    int remain = dimsize % BLOCK_DIM;
    int step = (dimsize - remain) / BLOCK_DIM + 1; // step <= numPerThread
    if (threadIdx.x < remain) {
        for (int ind = 0; ind < step; ind++) {
            dms_input.max_tmp =
                input[tid + (threadIdx.x * step + ind) * stride];
            dms_input.sum_tmp = 1.0f;
            dms_partial =
                reduce_dms_op(dms_partial,
                              dms_input); // reduce the data to one block
        }
    } else {
        for (int ind = 0; ind < step - 1; ind++) {
            dms_input.max_tmp =
                input[tid + (remain * step +
                             (threadIdx.x - remain) * (step - 1) + ind) *
                                stride];
            dms_input.sum_tmp = 1.0f;
            dms_partial =
                reduce_dms_op(dms_partial,
                              dms_input); // reduce the data to one block
        }
    }
    typedef cub::BlockReduce<DataMaxSum, BLOCK_DIM> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    __shared__ DataMaxSum dms_total;
@@ -53,12 +70,102 @@ __launch_bounds__(BLOCK_DIM) __global__ void _blockSoftmaxKernel(
    }
    __syncthreads();
    //-----------------
    if (threadIdx.x < remain) {
        for (int ind = 0; ind < step; ind++) {
            output[tid + (threadIdx.x * step + ind) * stride] =
                __expf(static_cast<float>(
                           input[tid + (threadIdx.x * step + ind) * stride]) -
                       dms_total.max_tmp) *
                __fdividef(1.0F, dms_total.sum_tmp);
        }
    } else {
        for (int ind = 0; ind < step - 1; ind++) {
            output[tid +
                   (remain * step + (threadIdx.x - remain) * (step - 1) + ind) *
                       stride] =
                __expf(static_cast<float>(
                           input[tid +
                                 (remain * step +
                                  (threadIdx.x - remain) * (step - 1) + ind) *
                                     stride]) -
                       dms_total.max_tmp) *
                __fdividef(1.0F, dms_total.sum_tmp);
        }
    }
}
template <typename T, int BLOCK_DIM, int numPerThread>
__global__ void
_blockSoftmaxKernel(T *__restrict input, T *__restrict output, int size,
int dimsize,
int stride) { // if set axis = 1, inputShape=[I,J,K,S]
// tid = i(JKS) + j(KS) + k(S) + s
// blockDim.x = size/dimsize = IKS
// blockIdx.x = i(KS) + k(S) + s,blockIdx.x%stride = k(S) + s
int tid =
blockIdx.x % stride + (blockIdx.x - blockIdx.x % stride) *
dimsize; // now, tid = i(JKS) + k(S) + s;
int remain = dimsize % BLOCK_DIM;
int step = (dimsize - remain) / BLOCK_DIM + 1; // step <= numPerThread
float dataPerThread[numPerThread];
DataMaxSum dms_partial;
dms_partial.max_tmp = -__FLT_MAX__;
dms_partial.sum_tmp = 0.0f;
DataMaxSum dms_input;
if (threadIdx.x < remain) {
for (int ind = 0; ind < step; ind++) {
dataPerThread[ind] =
input[tid + (threadIdx.x * step + ind) * stride];
dms_input.max_tmp = dataPerThread[ind];
dms_input.sum_tmp = 1.0f;
dms_partial =
reduce_dms_op(dms_partial,
dms_input); // reduce the data to one block
}
} else {
for (int ind = 0; ind < step - 1; ind++) {
dataPerThread[ind] =
input[tid + (remain * step +
(threadIdx.x - remain) * (step - 1) + ind) *
stride];
dms_input.max_tmp = dataPerThread[ind];
dms_input.sum_tmp = 1.0f;
dms_partial =
reduce_dms_op(dms_partial,
dms_input); // reduce the data to one block
}
}
typedef cub::BlockReduce<DataMaxSum, BLOCK_DIM> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ DataMaxSum dms_total;
DataMaxSum dms_block =
BlockReduce(temp_storage).Reduce(dms_partial, reduce_dms_op);
if (threadIdx.x ==
0) { // must set threadIdx.x = 0 write the output to memory
dms_total = dms_block;
}
__syncthreads();
//-----------------
if (threadIdx.x < remain) {
for (int ind = 0; ind < step; ind++) {
output[tid + (threadIdx.x * step + ind) * stride] =
__expf(dataPerThread[ind] - dms_total.max_tmp) *
__fdividef(1.0F, dms_total.sum_tmp);
}
} else {
for (int ind = 0; ind < step - 1; ind++) {
output[tid +
(remain * step + (threadIdx.x - remain) * (step - 1) + ind) *
stride] =
__expf(dataPerThread[ind] - dms_total.max_tmp) *
__fdividef(1.0F, dms_total.sum_tmp);
}
    }
}
@@ -81,14 +188,14 @@ __inline__ __device__ T WarpAllReduce(T val) {
    }
    return val;
}
template <typename T, int BLOCK_DIM_x, int BLOCK_DIM_y, int numPerThreadx>
__global__ void _warpSoftmaxKernel(T *__restrict input, T *__restrict output,
                                   int size, int dimsize, int stride) {
    int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
    int otherSize = size / dimsize;
    int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
    float dataPerThreadx[numPerThreadx];
    if (otherIdx < otherSize) {
        __shared__ float max_total[BLOCK_DIM_y];
@@ -96,9 +203,9 @@ __global__ void _warpSoftmaxKernel(float *__restrict input,
        float max_data = -__FLT_MAX__;
        for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
            dataPerThreadx[ph] =
                input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride];
            max_data = max(max_data, dataPerThreadx[ph]);
        }
        max_data = WarpAllReduce<MaxOp, float, BLOCK_DIM_x>(max_data);
@@ -110,9 +217,9 @@ __global__ void _warpSoftmaxKernel(float *__restrict input,
        float sum_data = 0.0f;
        for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
            dataPerThreadx[ph] =
                __expf(dataPerThreadx[ph] - max_total[threadIdx.y]);
            sum_data += dataPerThreadx[ph];
        }
        sum_data = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(sum_data);
@@ -124,9 +231,7 @@ __global__ void _warpSoftmaxKernel(float *__restrict input,
        for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
            output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
                dataPerThreadx[ph] * __fdividef(1.0F, sum_total[threadIdx.y]);
        }
    }
}
@@ -137,10 +242,35 @@ namespace infini {
void softmax_kernel(int num_blocks, float *input, float *output, int size,
                    int dimsize, int stride) {
    if (dimsize > 1024 * 128) {
        int BLOCK_DIM = 1024;
        _blockSoftmaxKernel<float, 1024>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 64) {
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<float, 1024, 128>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 32) {
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<float, 1024, 64>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 16) {
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<float, 1024, 32>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 4) {
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<float, 1024, 16>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
} else if (dimsize > 1024) {
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<float, 1024, 4>
            <<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
    } else if (dimsize > 31) {
        int BLOCK_DIM_x = 32;
@@ -149,7 +279,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
        dim3 grid_dim(num_block_x, 1, 1);
        _warpSoftmaxKernel<float, 32, 32, 32>
            <<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
    } else if (dimsize > 15) {
        int BLOCK_DIM_x = 16;
@@ -158,7 +288,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
        dim3 grid_dim(num_block_x, 1, 1);
        _warpSoftmaxKernel<float, 16, 64, 2>
            <<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
    } else if (dimsize > 7) {
        int BLOCK_DIM_x = 8;
@@ -167,7 +297,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
        dim3 grid_dim(num_block_x, 1, 1);
        _warpSoftmaxKernel<float, 8, 128, 2>
            <<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
    } else {
        int BLOCK_DIM_x = 4;
@@ -176,7 +306,79 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
        dim3 grid_dim(num_block_x, 1, 1);
        _warpSoftmaxKernel<float, 4, 256, 2>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
}
}
//------------------
void softmax_kernel(int num_blocks, half *input, half *output, int size,
int dimsize, int stride) {
if (dimsize > 1024 * 128) {
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 64) {
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024, 128>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 32) {
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024, 64>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 16) {
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024, 32>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 4) {
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024, 16>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
} else if (dimsize > 1024) {
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024, 4>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<half, 32, 32, 32>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<half, 16, 64, 2>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128;
int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<half, 8, 128, 2>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
} else {
int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256;
int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<half, 4, 256, 2>
            <<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
    }
}
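
The remain / step arithmetic used by the block kernels above splits dimsize elements across BLOCK_DIM threads without a strided loop: the first remain threads each handle step positions and the remaining threads handle step - 1, so every element is visited exactly once and step never exceeds the numPerThread register budget. A worked check of that partition (numbers chosen only as an example for the dimsize > 1024 branch):

// Example partition for _blockSoftmaxKernel<T, 1024, 4>.
constexpr int dimsize = 3000, BLOCK_DIM = 1024; // 1024 < 3000 <= 1024 * 4
constexpr int remain = dimsize % BLOCK_DIM;              // 952
constexpr int step = (dimsize - remain) / BLOCK_DIM + 1; // 3 <= numPerThread (4)
static_assert(remain * step + (BLOCK_DIM - remain) * (step - 1) == dimsize,
              "first `remain` threads take `step` elements, the rest take `step - 1`");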
View File
@@ -7,7 +7,8 @@
namespace infini {
class CudaCompute {
    template <typename T>
    void initComposedTensorMetadata(ComposedTensorMetadata<T> &metadata,
                                    Tensor tensor) const {
        int nDims = tensor->getRank();
        auto strides = tensor->getStride();
@@ -16,10 +17,10 @@ class CudaCompute {
            metadata.dimSize[i] = tensor->getDims().at(i);
            metadata.stride[i] = strides.at(i);
        }
        metadata.data = tensor->getRawDataPtr<T *>();
    }
    template <typename T>
    void initElementTensorMetadata(ElementTensorMetadata<T> &metadata,
                                   TensorVec tensors, int idx, int dim,
                                   int &dimBgIdx, int &batchCounter) const {
        int nTensors = tensors.size();
@@ -27,7 +28,7 @@ class CudaCompute {
             ++batchCounter) {
            auto tensor = tensors.at(idx + batchCounter);
            auto dimSize = tensor->getDims()[dim];
            metadata.data[batchCounter] = tensor->getRawDataPtr<T *>();
            metadata.dimBgNo[batchCounter] = dimBgIdx;
            metadata.dimSize[batchCounter] = dimSize;
            metadata.nElements[batchCounter] = tensor->size();
@@ -36,17 +37,17 @@ class CudaCompute {
    }
  public:
    template <typename T>
    void do_compute(Tensor composedTensor, TensorVec elementsTensor, int dim,
                    int nDims, bool isSplit) const {
        IT_ASSERT(nDims <= DIM_MAX_SIZE);
        ComposedTensorMetadata<T> composedMetadata;
        initComposedTensorMetadata<T>(composedMetadata, composedTensor);
        int dimBgNo = 0;
        int nElemets = elementsTensor.size();
        for (int i = 0; i < nElemets; i += BATCH_SIZE) {
            ElementTensorMetadata<T> elemMetadata;
            int batchCounter = 0;
            initElementTensorMetadata(elemMetadata, elementsTensor, i, dim,
                                      dimBgNo, batchCounter);
@@ -74,23 +75,38 @@ class ConcatCuda : private CudaCompute, public CudaKernelWithoutConfig {
                }
            }
        }
        if (_op->getDType() == DataType::Float32) {
            do_compute<float>(_op->getOutput(), _op->getInputs(),
                              as<ConcatObj>(_op)->getDim(),
                              _op->getOutput()->getRank(), false);
        } else if (_op->getDType() == DataType::Float16) {
            do_compute<half>(_op->getOutput(), _op->getInputs(),
                             as<ConcatObj>(_op)->getDim(),
                             _op->getOutput()->getRank(), false);
        } else {
            IT_ASSERT(false);
        }
    }
};
class SplitCuda : private CudaCompute, public CudaKernelWithoutConfig {
    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        if (_op->getDType() == DataType::Float32) {
            do_compute<float>(_op->getInputs(0), _op->getOutputs(),
                              as<SplitObj>(_op)->getDim(),
                              _op->getInputs(0)->getRank(), true);
        } else if (_op->getDType() == DataType::Float16) {
            do_compute<half>(_op->getInputs(0), _op->getOutputs(),
                             as<SplitObj>(_op)->getDim(),
                             _op->getInputs(0)->getRank(), true);
        } else {
            IT_ASSERT(false);
        }
    }
};
REGISTER_KERNEL(Device::CUDA, OpType::Concat, ConcatCuda, "Concat_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Split, SplitCuda, "Split_CUDA");
} // namespace infini
Some files were not shown because too many files have changed in this diff.