diff --git a/src/kernels/bang/dropout.cc b/src/kernels/bang/dropout.cc index 5b52f5ce..bedf4d11 100644 --- a/src/kernels/bang/dropout.cc +++ b/src/kernels/bang/dropout.cc @@ -7,6 +7,7 @@ class DropoutCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); void *const iData = (op->getInputs(0)->getRawDataPtr()); @@ -46,7 +47,6 @@ class DropoutCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Dropout, DataType::Float32, DropoutCnnl, - "Dropout_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Dropout, DropoutCnnl, "Dropout_cnnl_BANG_Float32"); }; // namespace infini diff --git a/src/kernels/bang/slice.cc b/src/kernels/bang/slice.cc index 5cc772aa..b149e0a6 100644 --- a/src/kernels/bang/slice.cc +++ b/src/kernels/bang/slice.cc @@ -7,6 +7,7 @@ class SliceCnnl : public BangKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); auto starts = op->getStarts(); @@ -59,6 +60,6 @@ class SliceCnnl : public BangKernelWithoutConfig { } }; -REGISTER_KERNEL(Device::BANG, OpType::Slice, DataType::Float32, SliceCnnl, +REGISTER_KERNEL(Device::BANG, OpType::Slice, SliceCnnl, "Slice_cnnl_BANG_Float32"); }; // namespace infini