InfiniTensor/include/operators/dropout.h

53 lines
2.3 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#pragma once
#include "core/operator.h"
namespace infini {
/**
 * @brief Dropout operator (ONNX `Dropout`), restricted to inference mode.
 *
 * Has a single tensor input (`data`) and two tensor outputs (`output` and
 * `mask`). The dropout ratio is stored as an operator attribute. The training
 * mode is not stored: getTrainingMode() always reports false.
 */
class DropoutObj : public OperatorObj {
    float ratio;
    // training_mode is intentionally not stored as a member: only inference
    // mode (training_mode == false) is supported, which is why
    // getTrainingMode() always returns false.
    // NOTE(review): presumably the constructor rejects training_mode == true;
    // confirm in the implementation file.

  public:
    /**
     * @brief Dropout takes an input floating-point tensor and produces two
     * tensor outputs, output (floating-point tensor) and mask (bool tensor).
     * Unlike ONNX, where ratio and training_mode are optional operator
     * inputs, here they are passed to the constructor, so the operator has
     * exactly one input tensor.
     *
     * If training_mode is true then the output will be a random dropout.
     * Note that this Dropout scales the masked input data by the following
     * equation, so to convert the trained model into inference mode, the user
     * can simply set training_mode to false:
     *
     *     output = scale * data * mask,  where scale = 1. / (1. - ratio)
     *
     * @param graph The computation graph that this operator belongs to.
     * @param data The input tensor.
     * @param output The output tensor.
     * @param mask The mask tensor.
     * @param ratio The ratio of random dropout, with value in [0, 1). If it is
     * 0, the output would be a simple copy of the input. If it is non-zero,
     * output will be a random dropout of the scaled input, which is typically
     * the case during training.
     * @param training_mode If true, dropout is being used for training. If
     * false, ratio is ignored and the operation mimics inference mode: nothing
     * is dropped from the input data, and the mask output (if requested)
     * contains all ones. This implementation only supports false: the value
     * is not stored and getTrainingMode() always returns false.
     */
    DropoutObj(GraphObj *graph, Tensor data, Tensor output, Tensor mask,
               float ratio, bool training_mode);
    OP_CLONE(DropoutObj);

    optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
    std::string toString() const override;
    /// Only `data` is an input tensor; ratio and training_mode are attributes.
    int numInputs() const override { return 1; }
    /// The two outputs are `output` and `mask`.
    int numOutputs() const override { return 2; }
    /// @return The stored dropout ratio.
    float getRatio() const { return ratio; }
    /// @return Always false: only inference mode is supported.
    bool getTrainingMode() const { return false; }

  private:
    vector<int> getWorkloadVector() const override;
    vector<int> getOpAttrVector() const override;
};
} // namespace infini