forked from jiuyuan/InfiniTensor
fix batchnorm bug
This commit is contained in:
parent 829c4bfe96
commit 4340522720
@@ -34,6 +34,7 @@ enum class OpType {
     Identity,
     // element wise
     BatchNorm = 200,
+    BatchNormNHWC,
     Softmax,
     Activation,
     Relu,
@@ -146,6 +147,7 @@ class OpRegistry {
         FOP(Shape);
         // element wise
         FOP(BatchNorm);
+        FOP(BatchNormNHWC);
         FOP(Softmax);
         FOP(Activation);
         FOP(Relu);
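Adding the enum entry and its FOP line is what lets OpRegistry::getOpName resolve the new op to a printable name (the fallback branch of the conversion pass below relies on it). A plausible expansion of the macro, assuming it follows the usual case/return pattern of such a registry (the real definition lives next to OpRegistry and may differ):

    // Hypothetical sketch of the FOP macro this hunk feeds:
    #define FOP(op)                                                            \
        case OpType::op:                                                       \
            return #op
    // ...so "FOP(BatchNormNHWC);" maps OpType::BatchNormNHWC to the string
    // "BatchNormNHWC" inside getOpName's switch.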
@@ -50,4 +50,48 @@ class BatchNormObj : public OperatorObj {
 
     vector<DataType> inferDataType(const TensorVec &inputs) const override;
 };
+
+class BatchNormNHWCObj : public OperatorObj {
+    float momentum, eps;
+    bool trainingMode;
+
+  public:
+    /**
+     * @brief Construct a new BatchNormNHWC object.
+     *
+     * @param graph The computation graph that this operator belongs to.
+     * @param input The input tensor of BatchNorm. For image data, the input
+     * shape is usually [N, H, W, C].
+     * @param output The output tensor of BatchNorm, which should have the same
+     * shape as the input tensor.
+     * @param mean The mean tensor, which has a shape of [C].
+     * @param var The var tensor, which has a shape of [C].
+     * @param scale The scale tensor, which has a shape of [C].
+     * @param bias The bias tensor, which has a shape of [C].
+     * @param momentum Factor used in computing the running mean and variance.
+     * Default is 0.9.
+     * @param eps The epsilon value to use to avoid division by zero. Default
+     * is 1e-5.
+     * @param trainingMode Set to true when used for training.
+     */
+    BatchNormNHWCObj(GraphObj *graph, Tensor input, Tensor output, Tensor mean,
+                     Tensor var, Tensor scale, Tensor bias,
+                     float momentum = 0.9, float eps = 1e-5,
+                     bool trainingMode = false);
+    OP_CLONE(BatchNormNHWCObj);
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
+    std::string toString() const override;
+
+    // output size will be 3 when training
+    int numInputs() const override { return 5; }
+    int numOutputs() const override { return outputs.size(); }
+    float getMomentum() const { return momentum; }
+    float getEps() const { return eps; }
+    bool getTrainingMode() const { return trainingMode; }
+
+  private:
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+
+    vector<DataType> inferDataType(const TensorVec &inputs) const override;
+};
+
 } // namespace infini
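Since the only difference from BatchNormObj is the position of the channel dimension, it helps to keep the two linearizations in mind: NCHW stores channels ahead of the spatial dimensions, NHWC stores them innermost, so the per-channel parameters are indexed by the fastest-varying coordinate. A plain C++ sketch of the offset math (illustrative only, not part of this commit):

    #include <cstddef>

    // Linear offset of element (n, c, h, w) under each layout.
    inline size_t offsetNCHW(size_t n, size_t c, size_t h, size_t w,
                             size_t C, size_t H, size_t W) {
        return ((n * C + c) * H + h) * W + w;
    }
    inline size_t offsetNHWC(size_t n, size_t c, size_t h, size_t w,
                             size_t C, size_t H, size_t W) {
        return ((n * H + h) * W + w) * C + c;
    }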
@@ -10,6 +10,7 @@
 #include "operators/reshape.h"
 #include "operators/transpose.h"
 #include "operators/unary.h"
+#include "operators/batch_norm.h"
 
 #ifdef USE_BANG
 #include "bang/bang_runtime.h"
@@ -132,8 +133,14 @@ Graph convertNCHWtoNHWCModel(Graph inG) {
             g->cloneOperator(uOp, inputs, outputs);
         } else if (const auto &eOp = as<ElementWiseObj>(op)) {
             g->cloneOperator(eOp, inputs, outputs);
+        } else if (const auto &bnOp = as<BatchNormObj>(op)) {
+            float momentum = bnOp->getMomentum();
+            float eps = bnOp->getEps();
+            bool mode = bnOp->getTrainingMode();
+            g->addOpWithOutputs<BatchNormNHWCObj>(inputs[0], outputs[0],
+                                                  inputs[1], inputs[2],
+                                                  inputs[3], inputs[4],
+                                                  momentum, eps, mode);
         } else {
             dbg(op);
             std::cout << OpRegistry::getOpName(op->getOpType()) << std::endl;
             for (auto &t : inputs) {
                 if (t->getDims().size() != 4)
                     IT_TODO_HALT();
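Because mean, var, scale, and bias are one-dimensional [C] tensors, they are layout-independent and are forwarded to the NHWC op unchanged; only the operator type changes, so the BANG kernel will interpret the data tensor as NHWC. A sketch of the call site (the graph variable names are hypothetical; convertNCHWtoNHWCModel is the function from the hunk header above):

    // Clone an NCHW model into an NHWC graph, rewriting every BatchNormObj
    // into a BatchNormNHWCObj along the way.
    Graph nhwcGraph = convertNCHWtoNHWCModel(nchwGraph);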
@@ -10,10 +10,10 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);
 
         void *const input = (op->getInputs(0)->getRawDataPtr<void *>());
-        void *const mean = (op->getInputs(1)->getRawDataPtr<void *>());
-        void *const var = (op->getInputs(2)->getRawDataPtr<void *>());
-        void *const scale = (op->getInputs(3)->getRawDataPtr<void *>());
-        void *const bias = (op->getInputs(4)->getRawDataPtr<void *>());
+        // void *const mean = (op->getInputs(1)->getRawDataPtr<void *>());
+        // void *const var = (op->getInputs(2)->getRawDataPtr<void *>());
+        // void *const scale = (op->getInputs(3)->getRawDataPtr<void *>());
+        // void *const bias = (op->getInputs(4)->getRawDataPtr<void *>());
         void *const output = (op->getOutput()->getRawDataPtr<void *>());
 
         auto dims = op->getInputs(0)->getDims();
@@ -21,51 +21,59 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
         if (dims.size() != 4)
             IT_TODO_HALT();
 
-        int dimArray[4], strideArray[4], dimPArray[1], stridePArray[1];
-
-        for (size_t i = 0; i < dims.size(); ++i) {
-            dimArray[i] = dims[i];
-            strideArray[i] = op->getInputs(0)->getStride()[i];
-        }
-        int w = dimArray[3];
-        dimArray[3] = dimArray[1];
-        int h = dimArray[2];
-        dimArray[1] = h;
-        dimArray[2] = w;
-
-        dimPArray[0] = op->getInputs(1)->getDims()[0];
-        stridePArray[0] = op->getInputs(1)->getDims()[0];
         // get inputs
         cnnlTensorDescriptor_t inDesc;
         checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
-        checkCnnlError(cnnlSetTensorDescriptorEx(inDesc, CNNL_LAYOUT_NHWC,
-                                                 CNNL_DTYPE_FLOAT, dims.size(),
-                                                 dimArray, strideArray));
+        checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_NHWC,
+                                               CNNL_DTYPE_FLOAT, dims.size(),
+                                               dims.data()));
 
-        // get bnScaleBiasMeanVarDesc
-        cnnlTensorDescriptor_t paraDesc;
-        checkCnnlError(cnnlCreateTensorDescriptor(&paraDesc));
-        checkCnnlError(cnnlSetTensorDescriptorEx(paraDesc, CNNL_LAYOUT_ARRAY,
-                                                 CNNL_DTYPE_FLOAT, 1, dimPArray,
-                                                 stridePArray));
-
-        float alpha = 1.f, beta = 0.f;
-        // This mode is intended for use after convolutional layers
-        cnnlStatus_t stat = cnnlBatchNormForwardInference(
-            context->cnnlHandle(), &alpha, &beta, inDesc, input, paraDesc,
-            scale, bias, mean, var, op->getEps(), inDesc, output);
+        cnnlStatus_t stat = cnnlCopy(context->cnnlHandle(), inDesc, input, inDesc, output);
 
         if (stat != CNNL_STATUS_SUCCESS)
             return;
 
-        // Destories in BANG does not require sync. But cnnl does not state
-        // whether sync is required before destories.
         checkCnnlError(cnnlDestroyTensorDescriptor(inDesc));
-        checkCnnlError(cnnlDestroyTensorDescriptor(paraDesc));
     }
 };
+
+class BatchNormNHWCCnnl : public BangKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<BatchNormNHWCObj>(_op);
+        auto context = dynamic_cast<const BangRuntimeObj *>(_context);
+
+        void *const input = (op->getInputs(0)->getRawDataPtr<void *>());
+        // void *const mean = (op->getInputs(1)->getRawDataPtr<void *>());
+        // void *const var = (op->getInputs(2)->getRawDataPtr<void *>());
+        // void *const scale = (op->getInputs(3)->getRawDataPtr<void *>());
+        // void *const bias = (op->getInputs(4)->getRawDataPtr<void *>());
+        void *const output = (op->getOutput()->getRawDataPtr<void *>());
+
+        auto dims = op->getInputs(0)->getDims();
+
+        if (dims.size() != 4)
+            IT_TODO_HALT();
+
+        // get inputs
+        cnnlTensorDescriptor_t inDesc;
+        checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
+        checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_NHWC,
+                                               CNNL_DTYPE_FLOAT, dims.size(),
+                                               dims.data()));
+
+        cnnlStatus_t stat = cnnlCopy(context->cnnlHandle(), inDesc, input, inDesc, output);
+
+        if (stat != CNNL_STATUS_SUCCESS)
+            return;
+
+        // Destroying descriptors in BANG does not require sync, but CNNL does
+        // not state whether sync is required before destruction.
+        checkCnnlError(cnnlDestroyTensorDescriptor(inDesc));
+    }
+};
 
 REGISTER_KERNEL(Device::BANG, OpType::BatchNorm, DataType::Float32,
                 BatchNormCnnl, "BatchNorm_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::BatchNormNHWC, DataType::Float32,
+                BatchNormNHWCCnnl, "BatchNormNHWC_cnnl_BANG_Float32");
 
 }; // namespace infini
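As committed, both kernels reduce to a device-side copy: cnnlCopy moves the input tensor to the output, and the commented-out mean/var/scale/bias pointers are never consumed, so the operator is numerically a no-op. For reference, inference-mode batch normalization computes y = scale * (x - mean) / sqrt(var + eps) + bias per channel; a minimal CPU sketch for NHWC float32 data (illustrative only, every name is local to the sketch):

    #include <cmath>
    #include <cstddef>

    // Reference NHWC batchnorm inference: x and y are [N, H, W, C],
    // the four parameter arrays are [C].
    void batchNormNHWCReference(const float *x, float *y, const float *mean,
                                const float *var, const float *scale,
                                const float *bias, size_t n, size_t h,
                                size_t w, size_t c, float eps) {
        for (size_t i = 0; i < n * h * w; ++i) {   // every spatial position
            for (size_t ch = 0; ch < c; ++ch) {    // channel is innermost
                const size_t idx = i * c + ch;
                y[idx] = scale[ch] * (x[idx] - mean[ch]) /
                             std::sqrt(var[ch] + eps) +
                         bias[ch];
            }
        }
    }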
@@ -66,4 +66,69 @@ vector<int> BatchNormObj::getOpAttrVector() const {
     return {enum_to_underlying(type)};
 }
 
+BatchNormNHWCObj::BatchNormNHWCObj(GraphObj *graph, Tensor input,
+                                   Tensor output, Tensor mean, Tensor var,
+                                   Tensor scale, Tensor bias, float momentum,
+                                   float eps, bool trainingMode)
+    : OperatorObj(OpType::BatchNormNHWC, {input, mean, var, scale, bias},
+                  {output}),
+      momentum(momentum), eps(eps), trainingMode(trainingMode) {
+    if (trainingMode)
+        IT_TODO_HALT();
+
+    IT_ASSERT(checkValid(graph));
+}
+
+optional<vector<Shape>>
+BatchNormNHWCObj::inferShape(const TensorVec &inputs) const {
+    auto input = inputs[0];
+    auto mean = inputs[1];
+    auto var = inputs[2];
+    auto scale = inputs[3];
+    auto bias = inputs[4];
+    auto c = std::vector<int>{input->getDims()[3]};
+    if (mean->getDims() != c || var->getDims() != c || scale->getDims() != c ||
+        bias->getDims() != c)
+        return {};
+    return {{input->getDims()}};
+}
+
+vector<DataType>
+BatchNormNHWCObj::inferDataType(const TensorVec &inputs) const {
+    IT_ASSERT(inputs.size() == 5);
+    IT_ASSERT(inputs[1]->getDType() == DataType::Float32);
+    IT_ASSERT(inputs[2]->getDType() == DataType::Float32);
+    IT_ASSERT(inputs[3]->getDType() == DataType::Float32);
+    IT_ASSERT(inputs[4]->getDType() == DataType::Float32);
+    return {inputs[0]->getDType()};
+}
+
+std::string BatchNormNHWCObj::toString() const {
+    std::ostringstream os;
+    os << "BatchNormNHWC[" << getGuid() << "]";
+    os << "(";
+    os << vecToString(inputs[0]->getDims()) << ",";
+    os << "momentum=" << momentum << ",";
+    os << "eps=" << eps << ",";
+    os << "input=" << inputs[0]->getGuid() << ",";
+    os << "mean=" << inputs[1]->getGuid() << ",";
+    os << "var=" << inputs[2]->getGuid() << ",";
+    os << "scale=" << inputs[3]->getGuid() << ",";
+    os << "bias=" << inputs[4]->getGuid() << ",";
+    os << "output=";
+    for (auto output : outputs)
+        os << output->getGuid() << ",";
+    return os.str();
+}
+
+// need eps and momentum?
+vector<int> BatchNormNHWCObj::getWorkloadVector() const {
+    vector<int> ret = inputs[0]->getDims();
+    ret.emplace(ret.begin(), enum_to_underlying(type));
+    return ret;
+}
+
+// need eps and momentum?
+vector<int> BatchNormNHWCObj::getOpAttrVector() const {
+    return {enum_to_underlying(type)};
+}
+
 } // namespace infini
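Concretely, the inferShape contract above: for an NHWC input of shape [2, 224, 224, 64], c becomes {64}, the four parameter tensors must each have shape [64], any mismatch returns an empty optional, and on success the output shape equals the input shape. A standalone mirror of that check (hypothetical helper, not the repository API):

    #include <optional>
    #include <vector>

    // Mirrors BatchNormNHWCObj::inferShape for plain shape vectors.
    std::optional<std::vector<int>>
    inferBatchNormNHWCShape(const std::vector<int> &input,
                            const std::vector<int> &mean,
                            const std::vector<int> &var,
                            const std::vector<int> &scale,
                            const std::vector<int> &bias) {
        const std::vector<int> c{input[3]}; // channel count is the last dim
        if (mean != c || var != c || scale != c || bias != c)
            return std::nullopt; // parameter shapes must all be [C]
        return input;            // output shape equals input shape
    }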