forked from jiuyuan/InfiniTensor
add dropout
This commit is contained in:
parent
5ac0ab442f
commit
cada8ec6c8
|
@ -101,6 +101,8 @@ class GraphHandlerObj {
|
|||
std::string mode);
|
||||
Tensor lrn(Tensor input, Tensor output, float alpha, float beta, float bias,
|
||||
int size);
|
||||
Tensor dropout(Tensor input, Tensor output, Tensor mask, float ratio,
|
||||
bool training_mode);
|
||||
|
||||
//------ modifiers
|
||||
|
||||
|
|
|
@ -673,23 +673,19 @@ class OnnxStub:
|
|||
else None,
|
||||
)
|
||||
elif node.op_type == "Dropout":
|
||||
for name, tensor in zip(
|
||||
node.output,
|
||||
self.handler.dropout(
|
||||
attributes = _parse_attribute(
|
||||
node, {"ratio": 0.5, "training_mode: 0"})
|
||||
(ratio, training_mode) = (
|
||||
attribute[name]
|
||||
for name in ["ratio", "training_mode"]
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.dropout(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
tensors.get(node.output[1])
|
||||
if len(node.output) > 1
|
||||
else None,
|
||||
_parse_data(data[node.input[1]])[0]
|
||||
if len(node.input) > 1
|
||||
else 0.5,
|
||||
_parse_data(data[node.input[2]])[0]
|
||||
if len(node.input) > 2
|
||||
else False,
|
||||
),
|
||||
):
|
||||
tensors[name] = tensor
|
||||
tensors.get(node.output[1]),
|
||||
ratio,
|
||||
(bool)training_mode,
|
||||
)
|
||||
elif node.op_type == "Cast":
|
||||
tensors[node.output[0]] = self.handler.cast(
|
||||
tensors[node.input[0]],
|
||||
|
@ -1225,6 +1221,18 @@ class OnnxStub:
|
|||
size,
|
||||
)
|
||||
)
|
||||
elif ty == backend.OpTypeId.Dropout:
|
||||
ratio, training_mode = backend.dropout_attrs_of(op)
|
||||
ctx.push_node(
|
||||
make_node(
|
||||
ty.name,
|
||||
inputs,
|
||||
outputs,
|
||||
name,
|
||||
ratio,
|
||||
training_mode,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise Exception("Unsupported OpType", ty)
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
#include "operators/broadcast.h"
|
||||
#include "operators/concat.h"
|
||||
#include "operators/conv.h"
|
||||
#include "operators/dropout.h"
|
||||
#include "operators/element_wise.h"
|
||||
#include "operators/expand.h"
|
||||
#include "operators/gather.h"
|
||||
|
@ -533,6 +534,20 @@ Tensor GraphHandlerObj::lrn(Tensor input, Tensor output, float alpha,
|
|||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::dropout(Tensor input, Tensor output, Tensor mask,
|
||||
float ratio, bool training_mode) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<DropoutObj>(std::move(input), output, mask, ratio,
|
||||
training_mode);
|
||||
return output;
|
||||
} else {
|
||||
return g
|
||||
->addOp<DropoutObj>(std::move(input), output, mask, ratio,
|
||||
training_mode)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
static CastType inferCastType(Tensor input, int to) {
|
||||
auto iType = input->getDType();
|
||||
auto oType = DataType(to);
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
#include "operators/batch_norm.h"
|
||||
#include "operators/concat.h"
|
||||
#include "operators/conv.h"
|
||||
#include "operators/dropout.h"
|
||||
#include "operators/expand.h"
|
||||
#include "operators/gather.h"
|
||||
#include "operators/lrn.h"
|
||||
|
@ -306,6 +307,14 @@ static std::tuple<float, float, float, int> lrn_attrs_of(Operator op) {
|
|||
return std::make_tuple(alpha, beta, bias, size);
|
||||
}
|
||||
|
||||
static std::tuple<float, bool> dropout_attrs_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Dropout);
|
||||
auto dropout = dynamic_cast<const DropoutObj *>(op.get());
|
||||
auto ratio = dropout->getRatio();
|
||||
auto train = dropout->getTrainingMode();
|
||||
return std::make_tuple(ratio, train);
|
||||
}
|
||||
|
||||
void export_functions(py::module &m) {
|
||||
#define FUNCTION(NAME) def(#NAME, &NAME)
|
||||
m.def("cpu_runtime", &NativeCpuRuntimeObj::getInstance)
|
||||
|
@ -343,7 +352,8 @@ void export_functions(py::module &m) {
|
|||
.FUNCTION(flatten_axis_of)
|
||||
.FUNCTION(cast_to_of)
|
||||
.FUNCTION(depth_to_space_attrs_of)
|
||||
.FUNCTION(lrn_attrs_of);
|
||||
.FUNCTION(lrn_attrs_of)
|
||||
.FUNCTION(dropout_attrs_of);
|
||||
#undef FUNCTION
|
||||
}
|
||||
|
||||
|
@ -529,6 +539,7 @@ void init_graph_builder(py::module &m) {
|
|||
.def("erf", &Handler::erf, policy::move)
|
||||
.def("where", &Handler::where, policy::move)
|
||||
.def("lrn", &Handler::lrn, policy::move)
|
||||
.def("dropout", &Handler::dropout, policy::move)
|
||||
.def("topo_sort", &Handler::topo_sort, policy::automatic)
|
||||
.def("optimize", &Handler::optimize, policy::automatic)
|
||||
.def("operators", &Handler::operators, policy::move)
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
#include "operators/dropout.h"
|
||||
#include "bang/bang_kernel_without_config.h"
|
||||
#include "bang/bang_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class DropoutCnnl : public BangKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<DropoutObj>(_op);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const iData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const oData = (op->getOutput(0)->getRawDataPtr<void *>());
|
||||
void *const mData = (op->getOutput(1)->getRawDataPtr<void *>());
|
||||
|
||||
cnnlRandGenerator_t generator;
|
||||
cnnlRandCreateGenerator(&generator, CNNL_RAND_RNG_FAST);
|
||||
cnnlRandSetPseudoRandomGeneratorSeed(generator, 233);
|
||||
cnnlRandSetMTGP32Period(generator, CNNL_RAND_MTGP32_P11213);
|
||||
|
||||
cnnlTensorDescriptor_t oDesc;
|
||||
auto oDim = op->getOutput(0)->getDims();
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&oDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(oDesc, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_FLOAT, oDim.size(),
|
||||
oDim.data()));
|
||||
|
||||
auto ratio = op->getRatio();
|
||||
// auto train = op->getTrainingMode();
|
||||
|
||||
cnnlStatus_t stat =
|
||||
cnnlFusedDropout_v2(context->cnnlHandle(), generator, oDesc, iData,
|
||||
ratio, NULL, oDesc, mData, oDesc, oData);
|
||||
|
||||
if (stat != CNNL_STATUS_SUCCESS)
|
||||
return;
|
||||
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(oDesc));
|
||||
checkCnnlError(cnnlRandDestroyGenerator(generator));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Dropout, DataType::Float32, DropoutCnnl,
|
||||
"Dropout_cnnl_BANG_Float32");
|
||||
|
||||
}; // namespace infini
|
Loading…
Reference in New Issue