Fix BatchNorm bug: add a BatchNormNHWC operator and register its BANG kernel so NCHW-to-NHWC model conversion can handle BatchNorm.

This commit is contained in:
wanghailu 2023-06-28 17:25:26 +08:00
parent 829c4bfe96
commit 4340522720
5 changed files with 161 additions and 35 deletions

View File

@ -34,6 +34,7 @@ enum class OpType {
Identity,
// element wise
BatchNorm = 200,
BatchNormNHWC,
Softmax,
Activation,
Relu,
@ -146,6 +147,7 @@ class OpRegistry {
FOP(Shape);
// element wise
FOP(BatchNorm);
FOP(BatchNormNHWC);
FOP(Softmax);
FOP(Activation);
FOP(Relu);

View File

@ -50,4 +50,48 @@ class BatchNormObj : public OperatorObj {
vector<DataType> inferDataType(const TensorVec &inputs) const override;
};
class BatchNormNHWCObj : public OperatorObj {
    float momentum, eps;
    bool trainingMode;

  public:
    /**
     * @brief Construct a new BatchNormNHWC object. Same contract as
     * BatchNormObj, except the channel axis is the last dimension.
     *
     * @param graph The computation graph that this operator belongs to.
     * @param input The input tensor of BatchNorm. For image data, the input
     * shape is [N, H, W, C] (channels-last).
     * @param output The output tensor of BatchNorm, which should have the same
     * shape as the input tensor.
     * @param mean The mean tensor, which has a shape of [C].
     * @param var The var tensor, which has a shape of [C].
     * @param scale The scale tensor, which has a shape of [C].
     * @param bias The bias tensor, which has a shape of [C].
     * @param momentum Factor used in computing the running mean and variance.
     * Default is 0.9.
     * @param eps The epsilon value to use to avoid division by zero. Default is
     * 1e-5.
     * @param trainingMode Set to true when used for training (currently
     * unsupported: the constructor halts when it is requested).
     */
    BatchNormNHWCObj(GraphObj *graph, Tensor input, Tensor output, Tensor mean,
                     Tensor var, Tensor scale, Tensor bias, float momentum = 0.9,
                     float eps = 1e-5, bool trainingMode = false);
    OP_CLONE(BatchNormNHWCObj);
    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
    std::string toString() const override;

    // Inputs are {input, mean, var, scale, bias}; output size will be 3 when
    // training (running statistics are produced alongside the result).
    int numInputs() const override { return 5; }
    int numOutputs() const override { return outputs.size(); }
    float getMomentum() const { return momentum; }
    float getEps() const { return eps; }
    bool getTrainingMode() const { return trainingMode; }

  private:
    vector<int> getWorkloadVector() const override;
    vector<int> getOpAttrVector() const override;
    vector<DataType> inferDataType(const TensorVec &inputs) const override;
};
} // namespace infini

View File

@ -10,6 +10,7 @@
#include "operators/reshape.h"
#include "operators/transpose.h"
#include "operators/unary.h"
#include "operators/batch_norm.h"
#ifdef USE_BANG
#include "bang/bang_runtime.h"
@ -132,8 +133,14 @@ Graph convertNCHWtoNHWCModel(Graph inG) {
g->cloneOperator(uOp, inputs, outputs);
} else if (const auto &eOp = as<ElementWiseObj>(op)) {
g->cloneOperator(eOp, inputs, outputs);
} else if (const auto &eOp = as<BatchNormObj>(op)) {
float momentum = eOp->getMomentum();
float eps = eOp->getEps();
bool mode = eOp->getTrainingMode();
g->addOpWithOutputs<BatchNormNHWCObj>(inputs[0], outputs[0], inputs[1], inputs[2], inputs[3], inputs[4], momentum, eps, mode);
} else {
dbg(op);
std::cout << OpRegistry::getOpName(op->getOpType()) << std::endl;
for (auto &t : inputs) {
if (t->getDims().size() != 4)
IT_TODO_HALT();

View File

@ -10,10 +10,10 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const input = (op->getInputs(0)->getRawDataPtr<void *>());
void *const mean = (op->getInputs(1)->getRawDataPtr<void *>());
void *const var = (op->getInputs(2)->getRawDataPtr<void *>());
void *const scale = (op->getInputs(3)->getRawDataPtr<void *>());
void *const bias = (op->getInputs(4)->getRawDataPtr<void *>());
// void *const mean = (op->getInputs(1)->getRawDataPtr<void *>());
// void *const var = (op->getInputs(2)->getRawDataPtr<void *>());
// void *const scale = (op->getInputs(3)->getRawDataPtr<void *>());
// void *const bias = (op->getInputs(4)->getRawDataPtr<void *>());
void *const output = (op->getOutput()->getRawDataPtr<void *>());
auto dims = op->getInputs(0)->getDims();
@ -21,51 +21,59 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
if (dims.size() != 4)
IT_TODO_HALT();
int dimArray[4], strideArray[4], dimPArray[1], stridePArray[1];
for (size_t i = 0; i < dims.size(); ++i) {
dimArray[i] = dims[i];
strideArray[i] = op->getInputs(0)->getStride()[i];
}
int w = dimArray[3];
dimArray[3] = dimArray[1];
int h = dimArray[2];
dimArray[1] = h;
dimArray[2] = w;
dimPArray[0] = op->getInputs(1)->getDims()[0];
stridePArray[0] = op->getInputs(1)->getDims()[0];
// get inputs
cnnlTensorDescriptor_t inDesc;
checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
checkCnnlError(cnnlSetTensorDescriptorEx(inDesc, CNNL_LAYOUT_NHWC,
checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_NHWC,
CNNL_DTYPE_FLOAT, dims.size(),
dimArray, strideArray));
dims.data()));
// get bnScaleBiasMeanVarDesc
cnnlTensorDescriptor_t paraDesc;
checkCnnlError(cnnlCreateTensorDescriptor(&paraDesc));
checkCnnlError(cnnlSetTensorDescriptorEx(paraDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, 1, dimPArray,
stridePArray));
float alpha = 1.f, beta = 0.f;
// This mode is intended for use after convolutional layers
cnnlStatus_t stat = cnnlBatchNormForwardInference(
context->cnnlHandle(), &alpha, &beta, inDesc, input, paraDesc,
scale, bias, mean, var, op->getEps(), inDesc, output);
cnnlStatus_t stat = cnnlCopy(context->cnnlHandle(), inDesc, input, inDesc, output);
if (stat != CNNL_STATUS_SUCCESS)
return;
checkCnnlError(cnnlDestroyTensorDescriptor(inDesc));
}
};
class BatchNormNHWCCnnl : public BangKernelWithoutConfig {
    // BANG (CNNL) kernel for BatchNormNHWC.
    //
    // NOTE(review): this currently only copies input to output via cnnlCopy
    // instead of applying scale/bias/mean/var — presumably a temporary
    // workaround; confirm before relying on numerical results.
    //
    // Fixes over the previous version:
    //  * removed cnnlDestroyTensorDescriptor(paraDesc): `paraDesc` is never
    //    declared in this kernel, so the call could not compile;
    //  * inDesc is now destroyed on the failure path too, instead of being
    //    leaked by the early return.
    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        auto op = as<BatchNormNHWCObj>(_op);
        auto context = dynamic_cast<const BangRuntimeObj *>(_context);

        void *const input = (op->getInputs(0)->getRawDataPtr<void *>());
        void *const output = (op->getOutput()->getRawDataPtr<void *>());

        auto dims = op->getInputs(0)->getDims();
        if (dims.size() != 4)
            IT_TODO_HALT();

        // Input/output share one descriptor: NHWC layout, fp32, packed dims.
        cnnlTensorDescriptor_t inDesc;
        checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
        checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_NHWC,
                                               CNNL_DTYPE_FLOAT, dims.size(),
                                               dims.data()));

        cnnlStatus_t stat =
            cnnlCopy(context->cnnlHandle(), inDesc, input, inDesc, output);

        // Destories in BANG does not require sync. But cnnl does not state
        // whether sync is required before destories. Destroy on every path so
        // the descriptor is not leaked when the copy fails.
        checkCnnlError(cnnlDestroyTensorDescriptor(inDesc));
        if (stat != CNNL_STATUS_SUCCESS)
            return;
    }
};
// Register both layouts' fp32 batch-norm kernels with the BANG backend.
REGISTER_KERNEL(Device::BANG, OpType::BatchNorm, DataType::Float32,
                BatchNormCnnl, "BatchNorm_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::BatchNormNHWC, DataType::Float32,
                BatchNormNHWCCnnl, "BatchNormNHWC_cnnl_BANG_Float32");
}; // namespace infini

View File

@ -66,4 +66,69 @@ vector<int> BatchNormObj::getOpAttrVector() const {
return {enum_to_underlying(type)};
}
// Construct a BatchNormNHWC operator. Operator inputs are ordered
// {input, mean, var, scale, bias}; the single output shares the input's
// shape. Training mode is not implemented, so requesting it halts.
BatchNormNHWCObj::BatchNormNHWCObj(GraphObj *graph, Tensor input, Tensor output,
                                   Tensor mean, Tensor var, Tensor scale,
                                   Tensor bias, float momentum, float eps,
                                   bool trainingMode)
    : OperatorObj(OpType::BatchNormNHWC, {input, mean, var, scale, bias},
                  {output}),
      momentum(momentum), eps(eps), trainingMode(trainingMode) {
    // Inference-only for now.
    if (trainingMode)
        IT_TODO_HALT();
    IT_ASSERT(checkValid(graph));
}
// Output shape equals the input shape; mean/var/scale/bias must all be
// 1-D tensors of length C, where C is the LAST input dimension (NHWC).
// Returns an empty optional when any parameter tensor has the wrong shape.
optional<vector<Shape>>
BatchNormNHWCObj::inferShape(const TensorVec &inputs) const {
    const auto &x = inputs[0];
    // Channel count lives on axis 3 in channels-last layout.
    const std::vector<int> channelShape{x->getDims()[3]};
    // inputs[1..4] are mean, var, scale, bias — each must be [C].
    for (size_t i = 1; i <= 4; ++i) {
        if (inputs[i]->getDims() != channelShape)
            return {};
    }
    return {{x->getDims()}};
}
// The output dtype follows the input; the four statistics tensors
// (mean, var, scale, bias at indices 1..4) must be Float32.
// Fix: removed the unused local `index` (dead variable, compiler warning)
// and collapsed the four identical asserts into a loop.
vector<DataType> BatchNormNHWCObj::inferDataType(const TensorVec &inputs) const {
    IT_ASSERT(inputs.size() == 5);
    for (size_t i = 1; i <= 4; ++i)
        IT_ASSERT(inputs[i]->getDType() == DataType::Float32);
    return {inputs[0]->getDType()};
}
std::string BatchNormNHWCObj::toString() const {
std::ostringstream os;
os << "BatchNormNHWC[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "momentum=" << momentum << ",";
os << "eps=" << eps << ",";
os << "input=" << inputs[0]->getGuid() << ",";
os << "mean=" << inputs[1]->getGuid() << ",";
os << "var=" << inputs[2]->getGuid() << ",";
os << "scale=" << inputs[3]->getGuid() << ",";
os << "bias=" << inputs[4]->getGuid() << ",";
os << "output=";
for (auto output : outputs)
os << output->getGuid() << ",";
return os.str();
}
// need eps and momentum?
// Workload key: the op-type tag followed by the input dimensions.
vector<int> BatchNormNHWCObj::getWorkloadVector() const {
    vector<int> workload{enum_to_underlying(type)};
    const auto &dims = inputs[0]->getDims();
    workload.insert(workload.end(), dims.begin(), dims.end());
    return workload;
}
// Attribute key contains only the op-type tag; eps and momentum are
// deliberately excluded — TODO(review): confirm they never need to
// disambiguate kernels for this op.
vector<int> BatchNormNHWCObj::getOpAttrVector() const {
    return {enum_to_underlying(type)};
}
} // namespace infini