forked from jiuyuan/InfiniTensor
add InstanceNorm; implement it on Ascend by reusing the LayerNorm kernel (known error)
parent 907239cf34
commit 0fcaf001c4
@@ -1 +1 @@
-Subproject commit b896cec2dba5b8522b141ac4f89eb43074ee1b98
+Subproject commit 51d3105277f3774ed31c02ed4cd11fa92925af77
@@ -37,6 +37,8 @@ class GraphHandlerObj {
                             float momentum, float eps, bool training);
     Tensor layerNormalization(Tensor input, Tensor scale, Tensor output,
                               Tensor bias, float eps, int axis, int stash_type);
+    Tensor instanceNormalization(Tensor input, Tensor output, Tensor scale,
+                                 Tensor bias, float eps);
     Tensor rmsNorm(Tensor input, Tensor weight, Tensor output);
 
     Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
@@ -0,0 +1,26 @@
+#pragma once
+#include "core/operator.h"
+
+namespace infini {
+class InstanceNormObj : public OperatorObj {
+    float eps;
+
+  public:
+    InstanceNormObj(GraphObj *graph, Tensor input, Tensor output, Tensor scale,
+                    Tensor bias, float eps = 1e-5);
+    OP_CLONE(InstanceNormObj);
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
+    std::string toString() const override;
+
+    int numInputs() const override { return inputs.size(); }
+    int numOutputs() const override { return outputs.size(); }
+    float getEps() const { return eps; }
+
+  private:
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+
+    vector<DataType> inferDataType(const TensorVec &inputs) const override;
+};
+} // namespace infini
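Note: for reference, ONNX InstanceNormalization computes y = scale * (x - mean) / sqrt(var + eps) + bias, with mean and var taken per (sample, channel) slice over the spatial dims, and scale/bias of shape {C}. A minimal NumPy sketch of those semantics (the function name and layout handling are illustrative, not part of this repo):

import numpy as np

def instance_norm(x, scale, bias, eps=1e-5):
    # x: (N, C, ...spatial); scale, bias: (C,)
    axes = tuple(range(2, x.ndim))            # spatial axes, e.g. (2, 3) for NCHW
    mean = x.mean(axis=axes, keepdims=True)   # one mean per (n, c) slice
    var = x.var(axis=axes, keepdims=True)
    shape = (1, -1) + (1,) * (x.ndim - 2)     # broadcast (C,) over N and spatial dims
    return (x - mean) / np.sqrt(var + eps) * scale.reshape(shape) + bias.reshape(shape)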
@@ -323,6 +323,22 @@ class OnnxStub:
                     axis,
                     stash_type,
                 )
+            elif node.op_type == "InstanceNormalization":
+                (input, scale, bias) = (tensors[node.input[i]] for i in [0, 1, 2])
+
+                output = tensors.get(node.output[0])
+
+                tensors[node.output[0]] = self.handler.instanceNormalization(
+                    input,
+                    output,
+                    scale,
+                    bias,
+                    next(
+                        (attr.f for attr in node.attribute if attr.name == "epsilon"),
+                        1e-5,
+                    ),
+                )
             elif node.op_type == "RMSNorm":
                 tensors[node.output[0]] = self.handler.RMSNorm(
                     tensors[node.input[0]],
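Note: to exercise this import path end to end, a single-node ONNX model can be built with onnx.helper; everything below (names, shapes, values) is an arbitrary example, not code from this commit:

import onnx
from onnx import helper, TensorProto

node = helper.make_node(
    "InstanceNormalization", ["x", "scale", "bias"], ["y"], epsilon=1e-5)
graph = helper.make_graph(
    [node], "instancenorm_smoke",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3, 2, 3])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 3, 2, 3])],
    initializer=[
        helper.make_tensor("scale", TensorProto.FLOAT, [3], [0.3, 0.2, 0.5]),
        helper.make_tensor("bias", TensorProto.FLOAT, [3], [0.0, 0.0, 0.0]),
    ],
)
model = helper.make_model(graph)
onnx.checker.check_model(model)  # feed this model to OnnxStub to hit the new branch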
@@ -10,6 +10,7 @@
 #include "operators/expand.h"
 #include "operators/gather.h"
+#include "operators/instance_norm.h"
 #include "operators/layer_norm.h"
 #include "operators/lrn.h"
 #include "operators/matmul.h"
 #include "operators/pad.h"
@@ -124,6 +125,20 @@ Tensor GraphHandlerObj::layerNormalization(Tensor input, Tensor scale,
             ->getOutput();
     }
 }
+Tensor GraphHandlerObj::instanceNormalization(Tensor input, Tensor output,
+                                              Tensor scale, Tensor bias,
+                                              float eps) {
+    if (output) {
+        g->addOpWithOutputs<InstanceNormObj>(
+            std::move(input), output, std::move(scale), std::move(bias), eps);
+        return output;
+    } else {
+        return g
+            ->addOp<InstanceNormObj>(std::move(input), output,
+                                     std::move(scale), std::move(bias), eps)
+            ->getOutput();
+    }
+}
 Tensor GraphHandlerObj::leakyrelu(Tensor input, Tensor output, float alpha) {
     if (output) {
         g->addOpWithOutputs<LeakyReluObj>(std::move(input), output, alpha);
@@ -529,6 +529,7 @@ void init_graph_builder(py::module &m) {
         .def("matmul", &Handler::matmul, policy::move)
         .def("batchNormalization", &Handler::batchNormalization, policy::move)
         .def("layerNormalization", &Handler::layerNormalization, policy::move)
+        .def("instanceNormalization", &Handler::instanceNormalization, policy::move)
         .def("RMSNorm", &Handler::rmsNorm, policy::move)
         .def("maxPool", &Handler::maxPool, policy::move)
         .def("avgPool", &Handler::avgPool, policy::move)
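Note: with the binding in place, the handler should be callable from Python roughly as below. This is a hedged sketch: the module path, runtime/handler construction, and the dtype argument are assumptions about pyinfinitensor's API, not verified against this commit:

from pyinfinitensor import backend  # assumption: pybind module name

runtime = backend.cpu_runtime()          # assumption: runtime factory
handler = backend.GraphHandler(runtime)  # assumption: handler constructor
x = handler.tensor([2, 3, 2, 3], 1)      # assumption: 1 maps to Float32
scale = handler.tensor([3], 1)
bias = handler.tensor([3], 1)
y = handler.instanceNormalization(x, None, scale, bias, 1e-5)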
@@ -0,0 +1,109 @@
+#include "operators/instance_norm.h"
+#include "aclnnop/level2/aclnn_layer_norm.h"
+#include "ascend/ascend_kernel_without_config.h"
+#include "ascend/ascend_runtime.h"
+
+namespace infini {
+
+// Lowers InstanceNormalization onto aclnnLayerNorm. As the commit message
+// notes, this is known to be erroneous: LayerNorm's elementwise weight and
+// bias must have the normalized (trailing) shape, whereas InstanceNorm's
+// scale/bias are per-channel {C}. The shapes only happen to be compatible
+// when C equals the trailing dim, as in the unit test of this commit.
+class InstanceNormAclnn : public ASCENDKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<InstanceNormObj>(_op);
+        auto context = dynamic_cast<const ASCENDRuntimeObj *>(_context);
+
+        void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const weightData = (op->getInputs(1)->getRawDataPtr<void *>());
+        void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
+
+        auto inputD = op->getInputs(0)->getDims();
+        auto inputS = op->getInputs(0)->getStride();
+        auto weightD = op->getInputs(1)->getDims();
+        auto weightS = op->getInputs(1)->getStride();
+        auto outD = op->getOutput()->getDims();
+        auto outS = op->getOutput()->getStride();
+
+        double eps = static_cast<double>(op->getEps());
+
+        std::vector<int64_t> inputDim = castTo64(inputD);
+        std::vector<int64_t> inputStride = castTo64(inputS);
+        std::vector<int64_t> weightDim = castTo64(weightD);
+        std::vector<int64_t> weightStride = castTo64(weightS);
+        std::vector<int64_t> outputDim = castTo64(outD);
+        std::vector<int64_t> outputStride = castTo64(outS);
+
+        // Hardcoded for NCHW: normalize from axis 3 on, i.e. over W only.
+        // True InstanceNorm statistics cover all spatial dims (H and W,
+        // axis 2 for NCHW).
+        auto axis = 3;
+
+        // normalizedShape = inputDim[axis:], filled back to front.
+        auto rank = static_cast<int>(inputDim.size());
+        std::vector<int64_t> normalizedShape(rank - axis, 0);
+        for (auto i = rank; i > axis; --i) {
+            normalizedShape[i - 1 - axis] = inputDim[i - 1];
+        }
+
+        auto inputTensor =
+            aclCreateTensor(inputDim.data(), inputDim.size(), ACL_FLOAT,
+                            inputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW,
+                            inputDim.data(), inputDim.size(), inputData);
+        // BUG: weightDim is the per-channel {C} shape, but aclnnLayerNorm
+        // expects a weight of normalizedShape.
+        auto weightTensor =
+            aclCreateTensor(weightDim.data(), weightDim.size(), ACL_FLOAT,
+                            weightStride.data(), 0, aclFormat::ACL_FORMAT_NCHW,
+                            weightDim.data(), weightDim.size(), weightData);
+        auto outputTensor =
+            aclCreateTensor(outputDim.data(), outputDim.size(), ACL_FLOAT,
+                            outputStride.data(), 0, aclFormat::ACL_FORMAT_NCHW,
+                            outputDim.data(), outputDim.size(), outputData);
+
+        auto *normArray =
+            aclCreateIntArray(normalizedShape.data(), normalizedShape.size());
+
+        aclTensor *biasTensor = nullptr;
+        if (op->numInputs() == 3) {
+            void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>());
+
+            auto biasD = op->getInputs(2)->getDims();
+            auto biasS = op->getInputs(2)->getStride();
+            std::vector<int64_t> biasDim = castTo64(biasD);
+            std::vector<int64_t> biasStride = castTo64(biasS);
+
+            biasTensor = aclCreateTensor(
+                biasDim.data(), biasDim.size(), ACL_FLOAT, biasStride.data(), 0,
+                aclFormat::ACL_FORMAT_NCHW, biasDim.data(), biasDim.size(),
+                biasData);
+        }
+
+        uint64_t workspaceSize = 0;
+        aclOpExecutor *executor;
+
+        auto ret = aclnnLayerNormGetWorkspaceSize(
+            inputTensor, normArray, weightTensor, biasTensor, eps, outputTensor,
+            NULL, NULL, &workspaceSize, &executor);
+        CHECK_RET(
+            ret == ACL_SUCCESS,
+            LOG_PRINT("aclnnLayerNormGetWorkspaceSize failed. ERROR: %d\n",
+                      ret));
+        void *workspaceAddr = nullptr;
+        if (workspaceSize > 0) {
+            workspaceAddr = context->getWorkspace(workspaceSize);
+        }
+        auto tmp_err_msg = aclGetRecentErrMsg();
+        if (tmp_err_msg != NULL) {
+            printf(" ERROR Message : %s \n ", tmp_err_msg);
+        }
+        ret = aclnnLayerNorm(workspaceAddr, workspaceSize, executor,
+                             context->ASCENDHandle());
+        CHECK_RET(ret == ACL_SUCCESS,
+                  LOG_PRINT("aclnnLayerNorm failed. ERROR: %d\n", ret));
+
+        ret = aclrtSynchronizeStream(context->ASCENDHandle());
+        CHECK_RET(ret == ACL_SUCCESS,
+                  LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret));
+
+        return;
+    }
+};
+
+REGISTER_KERNEL(Device::ASCEND, OpType::InstanceNormalization,
+                InstanceNormAclnn, "InstanceNorm_ASCEND");
+
+} // namespace infini
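Note: the mismatch is easy to demonstrate numerically. The NumPy sketch below compares true per-channel InstanceNorm with what this kernel computes (LayerNorm over the last axis, with the per-channel scale reused as the elementwise weight); the shapes only happen to be compatible here because C == W == 3, and the results still differ:

import numpy as np

x = np.arange(36, dtype=np.float32).reshape(2, 3, 2, 3)
scale = np.array([0.3, 0.2, 0.5], np.float32)
eps = 1e-5

# True InstanceNorm: statistics per (n, c) slice, per-channel affine.
mu = x.mean(axis=(2, 3), keepdims=True)
var = x.var(axis=(2, 3), keepdims=True)
inorm = (x - mu) / np.sqrt(var + eps) * scale.reshape(1, 3, 1, 1)

# What the kernel does: LayerNorm over W with a length-W weight.
mu = x.mean(axis=3, keepdims=True)
var = x.var(axis=3, keepdims=True)
lnorm = (x - mu) / np.sqrt(var + eps) * scale  # (3,) broadcasts over W

print(np.allclose(inorm, lnorm))  # False: the lowering changes the math
print(lnorm.reshape(-1)[:3])      # matches the test's ExpectData pattern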
@@ -0,0 +1,50 @@
+#include "operators/instance_norm.h"
+#include "utils/operator_utils.h"
+
+namespace infini {
+InstanceNormObj::InstanceNormObj(GraphObj *graph, Tensor input, Tensor output,
+                                 Tensor scale, Tensor bias, float eps)
+    : OperatorObj(OpType::InstanceNormalization,
+                  TensorVec{input, scale, bias}, {output}),
+      eps(eps) {
+    IT_ASSERT(checkValid(graph));
+}
+
+optional<vector<Shape>> InstanceNormObj::inferShape(const TensorVec &inputs) {
+    return {{inputs[0]->getDims()}};
+}
+
+vector<DataType>
+InstanceNormObj::inferDataType(const TensorVec &inputs) const {
+    return {inputs[0]->getDType()};
+}
+
+std::string InstanceNormObj::toString() const {
+    std::ostringstream os;
+    os << "InstanceNormalization[" << getGuid() << "]";
+    os << "(";
+    os << vecToString(inputs[0]->getDims()) << ",";
+    os << "eps=" << eps << ",";
+    os << "input=" << inputs[0]->getGuid() << ",";
+    os << "scale=" << inputs[1]->getGuid() << ",";
+    os << "bias=" << inputs[2]->getGuid() << ",";
+    os << "output=";
+    for (auto output : outputs)
+        os << output->getGuid() << ",";
+    return os.str();
+}
+
+vector<int> InstanceNormObj::getWorkloadVector() const {
+    vector<int> ret = inputs[0]->getDims();
+    ret.emplace(ret.begin(), type.underlying());
+    return ret;
+}
+
+vector<int> InstanceNormObj::getOpAttrVector() const {
+    return {type.underlying()};
+}
+
+} // namespace infini
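Note: inferShape returns the input shape unchanged and inferDataType propagates the input dtype, but neither validates that scale and bias are per-channel. A small sketch of the extra check a later revision might add (illustrative Python, not repo code):

def check_instance_norm_inputs(x_shape, scale_shape, bias_shape):
    # Output mirrors the input; scale/bias must be (C,) where C = x_shape[1].
    c = x_shape[1]
    assert tuple(scale_shape) == (c,), "scale must be per-channel"
    assert tuple(bias_shape) == (c,), "bias must be per-channel"
    return x_shape

print(check_instance_norm_inputs((2, 3, 2, 3), (3,), (3,)))  # (2, 3, 2, 3)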
@@ -0,0 +1,75 @@
+#include "ascend/ascend_runtime.h"
+#include "core/graph.h"
+#include "core/kernel.h"
+#include "core/runtime.h"
+#include "operators/instance_norm.h"
+
+#include "test.h"
+
+namespace infini {
+
+void test_instancenormFp32(const Shape &inputShape,
+                           const vector<float> &inputData,
+                           const Shape &scaleShape,
+                           const vector<float> &scaleData, float eps,
+                           const vector<float> &ExpectData,
+                           const Shape &biasShape,
+                           const vector<float> &biasData) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    Graph gCpu = make_ref<GraphObj>(runtime);
+
+    auto bias = gCpu->addTensor(biasShape, DataType::Float32);
+    auto input = gCpu->addTensor(inputShape, DataType::Float32);
+    auto scale = gCpu->addTensor(scaleShape, DataType::Float32);
+    gCpu->dataMalloc();
+    bias->copyin(biasData);
+    input->copyin(inputData);
+    scale->copyin(scaleData);
+
+    auto ascendRuntime = make_ref<ASCENDRuntimeObj>();
+    Graph gAscend = make_ref<GraphObj>(ascendRuntime);
+    auto biasNpu = gAscend->cloneTensor(bias);
+    auto inputNpu = gAscend->cloneTensor(input);
+    auto scaleNpu = gAscend->cloneTensor(scale);
+    auto op = gAscend->addOp<InstanceNormObj>(inputNpu, nullptr, scaleNpu,
+                                              biasNpu, eps);
+    gAscend->dataMalloc();
+    biasNpu->copyin(biasData);
+    inputNpu->copyin(inputData);
+    scaleNpu->copyin(scaleData);
+    ascendRuntime->run(gAscend);
+
+    // Move data from NPU back to CPU and compare.
+    auto oCpu = gCpu->cloneTensor(op->getOutput());
+    oCpu->printData();
+    EXPECT_TRUE(oCpu->equalData(ExpectData));
+}
+
+TEST(ascend_InstancenormFp32, run) {
+    aclInit(nullptr);
+    // ExpectData was generated by a Python reference; it matches the
+    // LayerNorm-based kernel above, not true per-channel InstanceNorm.
+    test_instancenormFp32(
+        Shape{2, 3, 2, 3},
+        vector<float>{0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,
+                      9.,  10., 11., 12., 13., 14., 15., 16., 17.,
+                      18., 19., 20., 21., 22., 23., 24., 25., 26.,
+                      27., 28., 29., 30., 31., 32., 33., 34., 35.},
+        Shape{3}, vector<float>{0.3, 0.2, 0.5}, 1e-5,
+        vector<float>{
+            -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
+            -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
+            -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
+            -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
+            -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
+            -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678},
+        Shape{3}, vector<float>{0, 0, 0});
+
+    aclFinalize();
+}
+
+} // namespace infini
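Note: per the test's comment, ExpectData came from a Python reference that mirrors the LayerNorm-based kernel. If the kernel is later fixed to true InstanceNorm, the reference values can be regenerated with this sketch rather than hand-edited (values are computed, not taken from the commit):

import numpy as np

x = np.arange(36, dtype=np.float32).reshape(2, 3, 2, 3)
scale = np.array([0.3, 0.2, 0.5], np.float32).reshape(1, 3, 1, 1)
eps = 1e-5

mu = x.mean(axis=(2, 3), keepdims=True)    # per (n, c) statistics
var = x.var(axis=(2, 3), keepdims=True)
y = (x - mu) / np.sqrt(var + eps) * scale  # bias is zero in the test
print(np.round(y.reshape(-1), 7).tolist()) # candidate ExpectData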