This commit is contained in:
wanghailu 2024-01-15 05:21:38 +00:00
parent 19d3e831f9
commit 8baa34a1d2
2 changed files with 4 additions and 3 deletions

View File

@ -7,6 +7,7 @@ class DropoutCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<DropoutObj>(_op); auto op = as<DropoutObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const iData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const iData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -46,7 +47,6 @@ class DropoutCnnl : public BangKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::BANG, OpType::Dropout, DataType::Float32, DropoutCnnl, REGISTER_KERNEL(Device::BANG, OpType::Dropout, DropoutCnnl, "Dropout_cnnl_BANG_Float32");
"Dropout_cnnl_BANG_Float32");
}; // namespace infini }; // namespace infini

View File

@ -7,6 +7,7 @@ class SliceCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<SliceObj>(_op); auto op = as<SliceObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context); auto context = dynamic_cast<const BangRuntimeObj *>(_context);
auto starts = op->getStarts(); auto starts = op->getStarts();
@ -59,6 +60,6 @@ class SliceCnnl : public BangKernelWithoutConfig {
} }
}; };
REGISTER_KERNEL(Device::BANG, OpType::Slice, DataType::Float32, SliceCnnl, REGISTER_KERNEL(Device::BANG, OpType::Slice, SliceCnnl,
"Slice_cnnl_BANG_Float32"); "Slice_cnnl_BANG_Float32");
}; // namespace infini }; // namespace infini