forked from jiuyuan/InfiniTensor
Fix bang (#198)
* fix bang batchnorm * fix pooling test bang * add test batchnorm * HIGH PRECISION ACTIVATION * fix pooling * fix matmul * fix test * add layernorm * fix softmax * fix * better code * fix * fix worlflow * fix workflow * fix * fix * fxi matmul * add LRN * fix lrn * fix lrn --------- Co-authored-by: wanghailu <wanghailu0717@163.com> Co-authored-by: Baoming Li <1508269885@qq.com> Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
parent
3f34372012
commit
5ac0ab442f
|
@ -99,6 +99,8 @@ class GraphHandlerObj {
|
||||||
int outputType, Tensor input);
|
int outputType, Tensor input);
|
||||||
Tensor depthToSpace(Tensor input, Tensor output, int blocksize,
|
Tensor depthToSpace(Tensor input, Tensor output, int blocksize,
|
||||||
std::string mode);
|
std::string mode);
|
||||||
|
Tensor lrn(Tensor input, Tensor output, float alpha, float beta, float bias,
|
||||||
|
int size);
|
||||||
|
|
||||||
//------ modifiers
|
//------ modifiers
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/operator.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
class LRNObj : public OperatorObj {
|
||||||
|
|
||||||
|
public:
|
||||||
|
LRNObj(GraphObj *graph, Tensor inputX, Tensor inputY, float alpha,
|
||||||
|
float beta, float bias, int size);
|
||||||
|
OP_CLONE(LRNObj);
|
||||||
|
|
||||||
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||||
|
|
||||||
|
std::string toString() const override;
|
||||||
|
int numInputs() const override { return inputs.size(); }
|
||||||
|
int numOutputs() const override { return 1; }
|
||||||
|
auto getAlphaBetaBias() const {
|
||||||
|
return tuple(alpha_value, beta_value, bias_value);
|
||||||
|
}
|
||||||
|
auto getSize() const { return size_value; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
float alpha_value, beta_value, bias_value;
|
||||||
|
int size_value;
|
||||||
|
vector<int> getWorkloadVector() const override;
|
||||||
|
vector<int> getOpAttrVector() const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -857,6 +857,22 @@ class OnnxStub:
|
||||||
tensors[output_name] = self.handler.tensor(dims, tensor.data_type)
|
tensors[output_name] = self.handler.tensor(dims, tensor.data_type)
|
||||||
data[output_name] = tensor
|
data[output_name] = tensor
|
||||||
tensors[output_name].set_weight()
|
tensors[output_name].set_weight()
|
||||||
|
elif node.op_type == "LRN":
|
||||||
|
attributes = _parse_attribute(
|
||||||
|
node, {"alpha": 0.0001, "beta": 0.75, "bias": 1.0, "size": 1}
|
||||||
|
)
|
||||||
|
(alpha, beta, bias, size) = (
|
||||||
|
attributes[name]
|
||||||
|
for name in ["alpha", "beta", "bias", "size"]
|
||||||
|
)
|
||||||
|
tensors[node.output[0]] = self.handler.lrn(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors.get(node.output[0]),
|
||||||
|
alpha,
|
||||||
|
beta,
|
||||||
|
bias,
|
||||||
|
size,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||||
new_node_name.append(node.name)
|
new_node_name.append(node.name)
|
||||||
|
@ -1195,6 +1211,20 @@ class OnnxStub:
|
||||||
elif ty == backend.OpTypeId.Expand:
|
elif ty == backend.OpTypeId.Expand:
|
||||||
shape = backend.expand_shape_of(op)
|
shape = backend.expand_shape_of(op)
|
||||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, shape=shape))
|
ctx.push_node(make_node(ty.name, inputs, outputs, name, shape=shape))
|
||||||
|
elif ty == backend.OpTypeId.LRN:
|
||||||
|
alpha, beta, bias, size = backend.lrn_attrs_of(op)
|
||||||
|
ctx.push_node(
|
||||||
|
make_node(
|
||||||
|
ty.name,
|
||||||
|
inputs,
|
||||||
|
outputs,
|
||||||
|
name,
|
||||||
|
alpha,
|
||||||
|
beta,
|
||||||
|
bias,
|
||||||
|
size,
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception("Unsupported OpType", ty)
|
raise Exception("Unsupported OpType", ty)
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
#include "operators/expand.h"
|
#include "operators/expand.h"
|
||||||
#include "operators/gather.h"
|
#include "operators/gather.h"
|
||||||
#include "operators/layer_norm.h"
|
#include "operators/layer_norm.h"
|
||||||
|
#include "operators/lrn.h"
|
||||||
#include "operators/matmul.h"
|
#include "operators/matmul.h"
|
||||||
#include "operators/pad.h"
|
#include "operators/pad.h"
|
||||||
#include "operators/pooling.h"
|
#include "operators/pooling.h"
|
||||||
|
@ -519,6 +520,19 @@ Tensor GraphHandlerObj::depthToSpace(Tensor input, Tensor output, int blocksize,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Tensor GraphHandlerObj::lrn(Tensor input, Tensor output, float alpha,
|
||||||
|
float beta, float bias, int size) {
|
||||||
|
if (output) {
|
||||||
|
g->addOpWithOutputs<LRNObj>(std::move(input), output, alpha, beta, bias,
|
||||||
|
size);
|
||||||
|
return output;
|
||||||
|
} else {
|
||||||
|
return g
|
||||||
|
->addOp<LRNObj>(std::move(input), output, alpha, beta, bias, size)
|
||||||
|
->getOutput();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static CastType inferCastType(Tensor input, int to) {
|
static CastType inferCastType(Tensor input, int to) {
|
||||||
auto iType = input->getDType();
|
auto iType = input->getDType();
|
||||||
auto oType = DataType(to);
|
auto oType = DataType(to);
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
#include "operators/conv.h"
|
#include "operators/conv.h"
|
||||||
#include "operators/expand.h"
|
#include "operators/expand.h"
|
||||||
#include "operators/gather.h"
|
#include "operators/gather.h"
|
||||||
|
#include "operators/lrn.h"
|
||||||
#include "operators/matmul.h"
|
#include "operators/matmul.h"
|
||||||
#include "operators/pad.h"
|
#include "operators/pad.h"
|
||||||
#include "operators/pooling.h"
|
#include "operators/pooling.h"
|
||||||
|
@ -113,6 +114,7 @@ void export_values(py::module &m) {
|
||||||
.VALUE(OpType, Erf)
|
.VALUE(OpType, Erf)
|
||||||
.VALUE(OpType, Where)
|
.VALUE(OpType, Where)
|
||||||
.VALUE(OpType, DepthToSpace)
|
.VALUE(OpType, DepthToSpace)
|
||||||
|
.VALUE(OpType, LRN)
|
||||||
.export_values();
|
.export_values();
|
||||||
|
|
||||||
#undef VALUE
|
#undef VALUE
|
||||||
|
@ -296,6 +298,14 @@ static std::tuple<int, std::string> depth_to_space_attrs_of(Operator op) {
|
||||||
depth_to_space->getModeString());
|
depth_to_space->getModeString());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static std::tuple<float, float, float, int> lrn_attrs_of(Operator op) {
|
||||||
|
IT_ASSERT(op->getOpType() == OpType::LRN);
|
||||||
|
auto lrn = dynamic_cast<const LRNObj *>(op.get());
|
||||||
|
auto [alpha, beta, bias] = lrn->getAlphaBetaBias();
|
||||||
|
auto size = lrn->getSize();
|
||||||
|
return std::make_tuple(alpha, beta, bias, size);
|
||||||
|
}
|
||||||
|
|
||||||
void export_functions(py::module &m) {
|
void export_functions(py::module &m) {
|
||||||
#define FUNCTION(NAME) def(#NAME, &NAME)
|
#define FUNCTION(NAME) def(#NAME, &NAME)
|
||||||
m.def("cpu_runtime", &NativeCpuRuntimeObj::getInstance)
|
m.def("cpu_runtime", &NativeCpuRuntimeObj::getInstance)
|
||||||
|
@ -332,7 +342,8 @@ void export_functions(py::module &m) {
|
||||||
.FUNCTION(gather_axis_of)
|
.FUNCTION(gather_axis_of)
|
||||||
.FUNCTION(flatten_axis_of)
|
.FUNCTION(flatten_axis_of)
|
||||||
.FUNCTION(cast_to_of)
|
.FUNCTION(cast_to_of)
|
||||||
.FUNCTION(depth_to_space_attrs_of);
|
.FUNCTION(depth_to_space_attrs_of)
|
||||||
|
.FUNCTION(lrn_attrs_of);
|
||||||
#undef FUNCTION
|
#undef FUNCTION
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -517,6 +528,7 @@ void init_graph_builder(py::module &m) {
|
||||||
.def("expand", &Handler::expand, policy::move)
|
.def("expand", &Handler::expand, policy::move)
|
||||||
.def("erf", &Handler::erf, policy::move)
|
.def("erf", &Handler::erf, policy::move)
|
||||||
.def("where", &Handler::where, policy::move)
|
.def("where", &Handler::where, policy::move)
|
||||||
|
.def("lrn", &Handler::lrn, policy::move)
|
||||||
.def("topo_sort", &Handler::topo_sort, policy::automatic)
|
.def("topo_sort", &Handler::topo_sort, policy::automatic)
|
||||||
.def("optimize", &Handler::optimize, policy::automatic)
|
.def("optimize", &Handler::optimize, policy::automatic)
|
||||||
.def("operators", &Handler::operators, policy::move)
|
.def("operators", &Handler::operators, policy::move)
|
||||||
|
|
|
@ -30,8 +30,9 @@ class UnaryCnnl : public BangKernelWithoutConfig {
|
||||||
cDim.data()));
|
cDim.data()));
|
||||||
cnnlActivationDescriptor_t opDesc;
|
cnnlActivationDescriptor_t opDesc;
|
||||||
checkCnnlError(cnnlCreateActivationDescriptor(&opDesc));
|
checkCnnlError(cnnlCreateActivationDescriptor(&opDesc));
|
||||||
checkCnnlError(cnnlSetActivationDescriptor(
|
checkCnnlError(cnnlSetActivationDescriptor_v2(
|
||||||
opDesc, getOpType(), CNNL_NOT_PROPAGATE_NAN, getCoef()));
|
opDesc, getOpType(), CNNL_ACTIVATION_HIGH_PRECISION,
|
||||||
|
CNNL_NOT_PROPAGATE_NAN, getCoef()));
|
||||||
|
|
||||||
auto [alpha, beta] = getAlphBeta();
|
auto [alpha, beta] = getAlphBeta();
|
||||||
cnnlStatus_t stat =
|
cnnlStatus_t stat =
|
||||||
|
@ -131,31 +132,51 @@ class SoftmaxCnnl : public BangKernelWithoutConfig {
|
||||||
std::vector<int> inDim = {1, 1, 1};
|
std::vector<int> inDim = {1, 1, 1};
|
||||||
std::vector<int> outDim = inDim;
|
std::vector<int> outDim = inDim;
|
||||||
|
|
||||||
if (axis == 0) {
|
if (aDim.size() >= 3) {
|
||||||
mode = CNNL_SOFTMAX_MODE_HIGH_DIMENSION;
|
if (axis == 0) {
|
||||||
inDim[0] = aDim[0];
|
mode = CNNL_SOFTMAX_MODE_HIGH_DIMENSION;
|
||||||
inDim[1] = aDim[1];
|
inDim[0] = aDim[0];
|
||||||
for (size_t i = 2; i < aDim.size(); ++i) {
|
inDim[1] = aDim[1];
|
||||||
inDim[2] *= aDim[i];
|
for (size_t i = 2; i < aDim.size(); ++i) {
|
||||||
|
inDim[2] *= aDim[i];
|
||||||
|
}
|
||||||
|
outDim = inDim;
|
||||||
|
} else if (axis == aDim.size() - 1) {
|
||||||
|
mode = CNNL_SOFTMAX_MODE_LOW_DIMENSION;
|
||||||
|
inDim[0] = aDim[0];
|
||||||
|
for (size_t i = 1; i < axis; ++i) {
|
||||||
|
inDim[1] *= aDim[i];
|
||||||
|
}
|
||||||
|
inDim[2] = aDim[axis];
|
||||||
|
outDim = inDim;
|
||||||
|
} else {
|
||||||
|
mode = CNNL_SOFTMAX_MODE_MEDIUM_DIMENSION;
|
||||||
|
for (size_t i = 0; i < axis; ++i) {
|
||||||
|
inDim[0] *= aDim[i];
|
||||||
|
}
|
||||||
|
inDim[1] = aDim[axis];
|
||||||
|
for (size_t i = axis + 1; i < aDim.size(); ++i) {
|
||||||
|
inDim[2] *= aDim[i];
|
||||||
|
}
|
||||||
|
outDim = inDim;
|
||||||
}
|
}
|
||||||
outDim = inDim;
|
} else if (aDim.size() == 2) {
|
||||||
} else if (axis == aDim.size() - 1) {
|
if (axis == 0) {
|
||||||
mode = CNNL_SOFTMAX_MODE_LOW_DIMENSION;
|
mode = CNNL_SOFTMAX_MODE_HIGH_DIMENSION;
|
||||||
inDim[0] = aDim[0];
|
inDim = aDim;
|
||||||
for (size_t i = 1; i < axis; ++i) {
|
inDim.push_back(1);
|
||||||
inDim[1] *= aDim[i];
|
outDim = inDim;
|
||||||
|
} else {
|
||||||
|
mode = CNNL_SOFTMAX_MODE_LOW_DIMENSION;
|
||||||
|
inDim = aDim;
|
||||||
|
inDim.insert(inDim.begin(), 1);
|
||||||
|
outDim = inDim;
|
||||||
}
|
}
|
||||||
inDim[2] = aDim[axis];
|
|
||||||
outDim = inDim;
|
|
||||||
} else {
|
} else {
|
||||||
mode = CNNL_SOFTMAX_MODE_MEDIUM_DIMENSION;
|
mode = CNNL_SOFTMAX_MODE_HIGH_DIMENSION;
|
||||||
for (size_t i = 0; i < axis; ++i) {
|
inDim = aDim;
|
||||||
inDim[0] *= aDim[i];
|
inDim.push_back(1);
|
||||||
}
|
inDim.push_back(1);
|
||||||
inDim[1] = aDim[axis];
|
|
||||||
for (size_t i = axis + 1; i < aDim.size(); ++i) {
|
|
||||||
inDim[2] *= aDim[i];
|
|
||||||
}
|
|
||||||
outDim = inDim;
|
outDim = inDim;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -171,8 +192,8 @@ class SoftmaxCnnl : public BangKernelWithoutConfig {
|
||||||
float beta = 0.0;
|
float beta = 0.0;
|
||||||
cnnlStatus_t stat =
|
cnnlStatus_t stat =
|
||||||
cnnlSoftmaxForward_v2(context->cnnlHandle(), CNNL_SOFTMAX_ACCURATE,
|
cnnlSoftmaxForward_v2(context->cnnlHandle(), CNNL_SOFTMAX_ACCURATE,
|
||||||
mode, CNNL_COMPUTATION_HIGH_PRECISION, &alpha,
|
mode, CNNL_COMPUTATION_ULTRAHIGH_PRECISION,
|
||||||
aDesc, aData, &beta, cDesc, cData);
|
&alpha, aDesc, aData, &beta, cDesc, cData);
|
||||||
if (stat != CNNL_STATUS_SUCCESS)
|
if (stat != CNNL_STATUS_SUCCESS)
|
||||||
return;
|
return;
|
||||||
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
|
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
|
||||||
|
|
|
@ -17,51 +17,87 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
|
||||||
void *const output = (op->getOutput()->getRawDataPtr<void *>());
|
void *const output = (op->getOutput()->getRawDataPtr<void *>());
|
||||||
|
|
||||||
auto dims = op->getInputs(0)->getDims();
|
auto dims = op->getInputs(0)->getDims();
|
||||||
|
auto outDims = op->getOutput()->getDims();
|
||||||
if (dims.size() != 4)
|
if (dims.size() != 4)
|
||||||
IT_TODO_HALT();
|
IT_TODO_HALT();
|
||||||
|
|
||||||
int dimArray[4], strideArray[4], dimPArray[1], stridePArray[1];
|
int dimsTrans[4] = {dims[0], dims[2], dims[3], dims[1]};
|
||||||
|
int dimsOutTrans[4] = {outDims[0], outDims[2], outDims[3], outDims[1]};
|
||||||
|
int permute[4] = {0, 2, 3, 1};
|
||||||
|
int permuteOut[4] = {0, 3, 1, 2};
|
||||||
|
|
||||||
for (size_t i = 0; i < dims.size(); ++i) {
|
|
||||||
dimArray[i] = dims[i];
|
|
||||||
strideArray[i] = op->getInputs(0)->getStride()[i];
|
|
||||||
}
|
|
||||||
int w = dimArray[3];
|
|
||||||
dimArray[3] = dimArray[1];
|
|
||||||
int h = dimArray[2];
|
|
||||||
dimArray[1] = h;
|
|
||||||
dimArray[2] = w;
|
|
||||||
|
|
||||||
dimPArray[0] = op->getInputs(1)->getDims()[0];
|
|
||||||
stridePArray[0] = op->getInputs(1)->getDims()[0];
|
|
||||||
// get inputs
|
// get inputs
|
||||||
cnnlTensorDescriptor_t inDesc;
|
cnnlTensorDescriptor_t inDesc, intransDesc, outDesc, outtransDesc;
|
||||||
checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
|
checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
|
||||||
checkCnnlError(cnnlSetTensorDescriptorEx(inDesc, CNNL_LAYOUT_NHWC,
|
checkCnnlError(cnnlCreateTensorDescriptor(&intransDesc));
|
||||||
CNNL_DTYPE_FLOAT, dims.size(),
|
checkCnnlError(cnnlCreateTensorDescriptor(&outDesc));
|
||||||
dimArray, strideArray));
|
checkCnnlError(cnnlCreateTensorDescriptor(&outtransDesc));
|
||||||
|
checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_NCHW,
|
||||||
|
CNNL_DTYPE_FLOAT, dims.size(),
|
||||||
|
dims.data()));
|
||||||
|
checkCnnlError(cnnlSetTensorDescriptor(intransDesc, CNNL_LAYOUT_NHWC,
|
||||||
|
CNNL_DTYPE_FLOAT, dims.size(),
|
||||||
|
dimsTrans));
|
||||||
|
checkCnnlError(cnnlSetTensorDescriptor(outDesc, CNNL_LAYOUT_NCHW,
|
||||||
|
CNNL_DTYPE_FLOAT, outDims.size(),
|
||||||
|
outDims.data()));
|
||||||
|
checkCnnlError(cnnlSetTensorDescriptor(outtransDesc, CNNL_LAYOUT_NHWC,
|
||||||
|
CNNL_DTYPE_FLOAT, outDims.size(),
|
||||||
|
dimsOutTrans));
|
||||||
|
cnnlTransposeDescriptor_t opDesc;
|
||||||
|
checkCnnlError(cnnlCreateTransposeDescriptor(&opDesc));
|
||||||
|
checkCnnlError(cnnlSetTransposeDescriptor(opDesc, 4, permute));
|
||||||
|
size_t wsSize;
|
||||||
|
cnnlGetTransposeWorkspaceSize(context->cnnlHandle(), inDesc, opDesc,
|
||||||
|
&wsSize);
|
||||||
|
BangPtr wsData = context->getWorkspace(wsSize);
|
||||||
|
BangPtr inputTrans = context->getWorkspace(
|
||||||
|
cnnlGetTensorElementNum(inDesc) * sizeof(float));
|
||||||
|
BangPtr outputTrans = context->getWorkspace(
|
||||||
|
cnnlGetTensorElementNum(inDesc) * sizeof(float));
|
||||||
|
cnnlStatus_t stat =
|
||||||
|
cnnlTranspose_v2(context->cnnlHandle(), opDesc, inDesc, input,
|
||||||
|
intransDesc, inputTrans, wsData, wsSize);
|
||||||
|
if (stat != CNNL_STATUS_SUCCESS)
|
||||||
|
return;
|
||||||
|
|
||||||
// get bnScaleBiasMeanVarDesc
|
// get bnScaleBiasMeanVarDesc
|
||||||
|
auto dimsScaleBiasMeanVar = op->getInputs(1)->getDims();
|
||||||
cnnlTensorDescriptor_t paraDesc;
|
cnnlTensorDescriptor_t paraDesc;
|
||||||
checkCnnlError(cnnlCreateTensorDescriptor(¶Desc));
|
checkCnnlError(cnnlCreateTensorDescriptor(¶Desc));
|
||||||
checkCnnlError(cnnlSetTensorDescriptorEx(paraDesc, CNNL_LAYOUT_ARRAY,
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
CNNL_DTYPE_FLOAT, 1, dimPArray,
|
paraDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT,
|
||||||
stridePArray));
|
dimsScaleBiasMeanVar.size(), dimsScaleBiasMeanVar.data()));
|
||||||
|
|
||||||
float alpha = 1.f, beta = 0.f;
|
float alpha = 1.f, beta = 0.f;
|
||||||
// This mode is intended for use after convolutional layers
|
// This mode is intended for use after convolutional layers
|
||||||
cnnlStatus_t stat = cnnlBatchNormForwardInference(
|
stat = cnnlBatchNormForwardInference(
|
||||||
context->cnnlHandle(), &alpha, &beta, inDesc, input, paraDesc,
|
context->cnnlHandle(), &alpha, &beta, intransDesc, inputTrans,
|
||||||
scale, bias, mean, var, op->getEps(), inDesc, output);
|
paraDesc, scale, bias, mean, var, op->getEps(), outtransDesc,
|
||||||
|
outputTrans);
|
||||||
|
if (stat != CNNL_STATUS_SUCCESS)
|
||||||
|
return;
|
||||||
|
|
||||||
|
cnnlTransposeDescriptor_t op2Desc;
|
||||||
|
checkCnnlError(cnnlCreateTransposeDescriptor(&op2Desc));
|
||||||
|
checkCnnlError(cnnlSetTransposeDescriptor(op2Desc, 4, permuteOut));
|
||||||
|
cnnlGetTransposeWorkspaceSize(context->cnnlHandle(), intransDesc,
|
||||||
|
op2Desc, &wsSize);
|
||||||
|
BangPtr ws2Data = context->getWorkspace(wsSize);
|
||||||
|
stat = cnnlTranspose_v2(context->cnnlHandle(), op2Desc, outtransDesc,
|
||||||
|
outputTrans, outDesc, output, ws2Data, wsSize);
|
||||||
if (stat != CNNL_STATUS_SUCCESS)
|
if (stat != CNNL_STATUS_SUCCESS)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
// Destories in BANG does not require sync. But cnnl does not state
|
// Destories in BANG does not require sync. But cnnl does not state
|
||||||
// whether sync is required before destories.
|
// whether sync is required before destories.
|
||||||
checkCnnlError(cnnlDestroyTensorDescriptor(inDesc));
|
checkCnnlError(cnnlDestroyTensorDescriptor(inDesc));
|
||||||
|
checkCnnlError(cnnlDestroyTensorDescriptor(outDesc));
|
||||||
|
checkCnnlError(cnnlDestroyTensorDescriptor(intransDesc));
|
||||||
|
checkCnnlError(cnnlDestroyTensorDescriptor(outtransDesc));
|
||||||
checkCnnlError(cnnlDestroyTensorDescriptor(paraDesc));
|
checkCnnlError(cnnlDestroyTensorDescriptor(paraDesc));
|
||||||
|
checkCnnlError(cnnlDestroyTransposeDescriptor(opDesc));
|
||||||
|
checkCnnlError(cnnlDestroyTransposeDescriptor(op2Desc));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,64 @@
|
||||||
|
#include "operators/layer_norm.h"
|
||||||
|
#include "bang/bang_kernel_without_config.h"
|
||||||
|
#include "bang/bang_runtime.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
class LayerNormCnnl : public BangKernelWithoutConfig {
|
||||||
|
void compute(const Operator &_op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
auto op = as<LayerNormObj>(_op);
|
||||||
|
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||||
|
|
||||||
|
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||||
|
void *const scaleData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||||
|
void *biasData = NULL;
|
||||||
|
if (op->numInputs() == 3) {
|
||||||
|
biasData = (op->getInputs(2)->getRawDataPtr<void *>());
|
||||||
|
}
|
||||||
|
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
|
||||||
|
|
||||||
|
auto inDims = op->getInputs(0)->getDims();
|
||||||
|
auto outDims = op->getOutput()->getDims();
|
||||||
|
auto fiterDims = op->getOutput(1)->getDims();
|
||||||
|
|
||||||
|
float eps = op->getEps();
|
||||||
|
const int axis = op->getAxis();
|
||||||
|
|
||||||
|
cnnlTensorDescriptor_t inDesc, fiterDesc, outDesc;
|
||||||
|
|
||||||
|
checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
|
||||||
|
checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_ARRAY,
|
||||||
|
CNNL_DTYPE_FLOAT, inDims.size(),
|
||||||
|
inDims.data()));
|
||||||
|
checkCnnlError(cnnlCreateTensorDescriptor(&fiterDesc));
|
||||||
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
|
fiterDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, fiterDims.size(),
|
||||||
|
fiterDims.data()));
|
||||||
|
checkCnnlError(cnnlCreateTensorDescriptor(&outDesc));
|
||||||
|
checkCnnlError(cnnlSetTensorDescriptor(outDesc, CNNL_LAYOUT_ARRAY,
|
||||||
|
CNNL_DTYPE_FLOAT, outDims.size(),
|
||||||
|
outDims.data()));
|
||||||
|
size_t wsSize;
|
||||||
|
cnnlGetLayerNormOpWorkspaceSize(context->cnnlHandle(), axis, inDesc,
|
||||||
|
&wsSize);
|
||||||
|
BangPtr wsData = context->getWorkspace(wsSize);
|
||||||
|
|
||||||
|
cnnlStatus_t stat = cnnlLayerNormForward(
|
||||||
|
context->cnnlHandle(), inDesc, inputData, axis, fiterDesc,
|
||||||
|
scaleData, biasData, eps, wsData, wsSize, outDesc, outputData,
|
||||||
|
inDesc, NULL, NULL);
|
||||||
|
|
||||||
|
if (stat != CNNL_STATUS_SUCCESS)
|
||||||
|
return;
|
||||||
|
|
||||||
|
checkCnnlError(cnnlDestroyTensorDescriptor(inDesc));
|
||||||
|
checkCnnlError(cnnlDestroyTensorDescriptor(fiterDesc));
|
||||||
|
checkCnnlError(cnnlDestroyTensorDescriptor(outDesc));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL(Device::BANG, OpType::LayerNormalization, DataType::Float32,
|
||||||
|
LayerNormCnnl, "LayerNorm_BANG_Float32");
|
||||||
|
|
||||||
|
}; // namespace infini
|
|
@ -0,0 +1,62 @@
|
||||||
|
#include "operators/lrn.h"
|
||||||
|
#include "bang/bang_kernel_without_config.h"
|
||||||
|
#include "bang/bang_runtime.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
class LRNCnnl : public BangKernelWithoutConfig {
|
||||||
|
void compute(const Operator &_op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
auto op = as<LRNObj>(_op);
|
||||||
|
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||||
|
|
||||||
|
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||||
|
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||||
|
|
||||||
|
cnnlTensorDescriptor_t aDesc, cDesc;
|
||||||
|
auto aDim = op->getInputs(0)->getDims();
|
||||||
|
auto cDim = op->getOutput()->getDims();
|
||||||
|
auto [alpha, beta, bias] = op->getAlphaBetaBias();
|
||||||
|
auto size = op->getSize();
|
||||||
|
|
||||||
|
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||||
|
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
|
||||||
|
CNNL_DTYPE_FLOAT, aDim.size(),
|
||||||
|
aDim.data()));
|
||||||
|
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
|
||||||
|
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
|
||||||
|
CNNL_DTYPE_FLOAT, cDim.size(),
|
||||||
|
cDim.data()));
|
||||||
|
|
||||||
|
size_t extra_size;
|
||||||
|
cnnlGetLrnExtraInputSize_v2(context->cnnlHandle(), cDesc,
|
||||||
|
CNNL_LRN_LOCAL_SIZE, size, &extra_size);
|
||||||
|
void *extra_cpu = NULL;
|
||||||
|
extra_cpu = malloc(extra_size);
|
||||||
|
BangPtr extra_mlu = context->getWorkspace(extra_size);
|
||||||
|
cnnlInitLrnExtraInput(context->cnnlHandle(), CNNL_LRN_LOCAL_SIZE, size,
|
||||||
|
(double)alpha, (double)beta, (double)bias, aDesc,
|
||||||
|
cDesc, extra_cpu);
|
||||||
|
cnrtMemcpy(extra_mlu, extra_cpu, extra_size,
|
||||||
|
CNRT_MEM_TRANS_DIR_HOST2DEV);
|
||||||
|
|
||||||
|
size_t wsSize;
|
||||||
|
cnnlGetLrnWorkspaceSize_v2(context->cnnlHandle(), aDesc, cDesc,
|
||||||
|
CNNL_LRN_LOCAL_SIZE, size, &wsSize);
|
||||||
|
BangPtr wsData = context->getWorkspace(wsSize);
|
||||||
|
|
||||||
|
cnnlStatus_t stat = cnnlLrn_v2(
|
||||||
|
context->cnnlHandle(), CNNL_LRN_LOCAL_SIZE, size, (double)alpha,
|
||||||
|
(double)beta, (double)bias, wsData, wsSize, aDesc, aData, extra_mlu,
|
||||||
|
extra_size, cDesc, cData);
|
||||||
|
if (stat != CNNL_STATUS_SUCCESS)
|
||||||
|
return;
|
||||||
|
|
||||||
|
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
|
||||||
|
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL(Device::BANG, OpType::LRN, DataType::Float32, LRNCnnl,
|
||||||
|
"LRN_cnnl_BANG_Float32");
|
||||||
|
|
||||||
|
}; // namespace infini
|
|
@ -10,15 +10,29 @@ class MatmulCnnl : public BangKernelWithoutConfig {
|
||||||
auto op = as<MatmulObj>(_op);
|
auto op = as<MatmulObj>(_op);
|
||||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||||
|
|
||||||
|
auto input_num = op->numInputs();
|
||||||
|
|
||||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||||
|
void *biasData = NULL;
|
||||||
|
if (input_num > 2) {
|
||||||
|
biasData = (op->getInputs(2)->getRawDataPtr<void *>());
|
||||||
|
}
|
||||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||||
|
|
||||||
cnnlTensorDescriptor_t aDesc, bDesc, cDesc;
|
cnnlTensorDescriptor_t aDesc, bDesc, cDesc, biasDesc;
|
||||||
auto dimInputs0 = op->getInputs(0)->getDims();
|
auto dimInputs0 = op->getInputs(0)->getDims();
|
||||||
auto dimInputs1 = op->getInputs(1)->getDims();
|
auto dimInputs1 = op->getInputs(1)->getDims();
|
||||||
|
std::vector<int> dimBias;
|
||||||
|
if (input_num > 2) {
|
||||||
|
dimBias = op->getInputs(2)->getDims();
|
||||||
|
}
|
||||||
|
|
||||||
auto dimOutput = op->getOutput()->getDims();
|
auto dimOutput = op->getOutput()->getDims();
|
||||||
|
|
||||||
|
float alpha = 1.0;
|
||||||
|
float beta = 0.0;
|
||||||
|
|
||||||
int32_t transA = op->getTransA();
|
int32_t transA = op->getTransA();
|
||||||
int32_t transB = op->getTransB();
|
int32_t transB = op->getTransB();
|
||||||
|
|
||||||
|
@ -37,6 +51,13 @@ class MatmulCnnl : public BangKernelWithoutConfig {
|
||||||
cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT,
|
cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT,
|
||||||
dimOutput.size(), dimOutput.data()));
|
dimOutput.size(), dimOutput.data()));
|
||||||
|
|
||||||
|
if (input_num > 2) {
|
||||||
|
checkCnnlError(cnnlCreateTensorDescriptor(&biasDesc));
|
||||||
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
|
biasDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, dimBias.size(),
|
||||||
|
dimBias.data()));
|
||||||
|
}
|
||||||
|
|
||||||
cnnlMatMulDescriptor_t bmm_desc;
|
cnnlMatMulDescriptor_t bmm_desc;
|
||||||
cnnlMatMulDescCreate(&bmm_desc);
|
cnnlMatMulDescCreate(&bmm_desc);
|
||||||
cnnlSetMatMulDescAttr(bmm_desc, CNNL_MATMUL_DESC_TRANSA, &transA,
|
cnnlSetMatMulDescAttr(bmm_desc, CNNL_MATMUL_DESC_TRANSA, &transA,
|
||||||
|
@ -47,8 +68,6 @@ class MatmulCnnl : public BangKernelWithoutConfig {
|
||||||
cnnlMatMulAlgo_t bmm_algo;
|
cnnlMatMulAlgo_t bmm_algo;
|
||||||
cnnlMatMulAlgoCreate(&bmm_algo);
|
cnnlMatMulAlgoCreate(&bmm_algo);
|
||||||
|
|
||||||
float alpha = 1.0;
|
|
||||||
float beta = 0.0;
|
|
||||||
int count = 0;
|
int count = 0;
|
||||||
|
|
||||||
cnnlMatMulHeuristicResult_t desc;
|
cnnlMatMulHeuristicResult_t desc;
|
||||||
|
@ -66,9 +85,22 @@ class MatmulCnnl : public BangKernelWithoutConfig {
|
||||||
if (stat != CNNL_STATUS_SUCCESS)
|
if (stat != CNNL_STATUS_SUCCESS)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
|
wsData = NULL;
|
||||||
|
if (input_num > 2) {
|
||||||
|
cnnlGetBiasAddWorkspaceSize(context->cnnlHandle(), biasDesc, cDesc,
|
||||||
|
&wsSize);
|
||||||
|
stat = cnnlBiasAdd(context->cnnlHandle(), &alpha, biasDesc,
|
||||||
|
biasData, wsData, wsSize, &alpha, cDesc, cData);
|
||||||
|
if (stat != CNNL_STATUS_SUCCESS)
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
|
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
|
||||||
checkCnnlError(cnnlDestroyTensorDescriptor(bDesc));
|
checkCnnlError(cnnlDestroyTensorDescriptor(bDesc));
|
||||||
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
|
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
|
||||||
|
if (input_num > 2) {
|
||||||
|
checkCnnlError(cnnlDestroyTensorDescriptor(biasDesc));
|
||||||
|
}
|
||||||
checkCnnlError(cnnlMatMulDescDestroy(bmm_desc));
|
checkCnnlError(cnnlMatMulDescDestroy(bmm_desc));
|
||||||
checkCnnlError(cnnlMatMulAlgoDestroy(bmm_algo));
|
checkCnnlError(cnnlMatMulAlgoDestroy(bmm_algo));
|
||||||
checkCnnlError(cnnlDestroyMatMulHeuristicResult(desc));
|
checkCnnlError(cnnlDestroyMatMulHeuristicResult(desc));
|
||||||
|
|
|
@ -13,14 +13,14 @@ class PadCnnl : public BangKernelWithoutConfig {
|
||||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||||
|
|
||||||
cnnlTensorDescriptor_t aDesc, cDesc;
|
cnnlTensorDescriptor_t aDesc, cDesc;
|
||||||
auto dim = op->getOutput()->getDims();
|
auto dimIn = op->getInputs(0)->getDims();
|
||||||
int dim_size = dim.size();
|
auto dimOut = op->getOutput()->getDims();
|
||||||
int dim_array[dim_size];
|
|
||||||
for (int i = 0; i < dim_size; ++i) {
|
int dim_size = dimIn.size();
|
||||||
dim_array[i] = dim[i];
|
|
||||||
}
|
|
||||||
int paddings[dim_size * 2];
|
int paddings[dim_size * 2];
|
||||||
|
|
||||||
std::vector<int> pads = op->getPads();
|
std::vector<int> pads = op->getPads();
|
||||||
|
|
||||||
if (pads.size() == 2 && dim_size != 1) {
|
if (pads.size() == 2 && dim_size != 1) {
|
||||||
for (int i = 0; i < dim_size * 2; i += 2) {
|
for (int i = 0; i < dim_size * 2; i += 2) {
|
||||||
paddings[i] = pads[0];
|
paddings[i] = pads[0];
|
||||||
|
@ -32,20 +32,18 @@ class PadCnnl : public BangKernelWithoutConfig {
|
||||||
paddings[i + 1] = pads[i / 2 + dim_size];
|
paddings[i + 1] = pads[i / 2 + dim_size];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
int dimout_array[dim_size];
|
|
||||||
for (int i = 0; i < dim_size; ++i) {
|
|
||||||
dimout_array[i] = dim[i] + paddings[2 * i] + paddings[2 * i + 1];
|
|
||||||
}
|
|
||||||
float paddingValue = 0.0;
|
float paddingValue = 0.0;
|
||||||
// input
|
// input
|
||||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY,
|
||||||
aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, dim_size, dim_array));
|
CNNL_DTYPE_FLOAT, dimIn.size(),
|
||||||
|
dimIn.data()));
|
||||||
// output
|
// output
|
||||||
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
|
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY,
|
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY,
|
||||||
CNNL_DTYPE_FLOAT, dim_size,
|
CNNL_DTYPE_FLOAT, dimOut.size(),
|
||||||
dimout_array));
|
dimOut.data()));
|
||||||
|
|
||||||
cnnlStatus_t stat = cnnlPad(context->cnnlHandle(), aDesc, aData,
|
cnnlStatus_t stat = cnnlPad(context->cnnlHandle(), aDesc, aData,
|
||||||
paddings, &paddingValue, cDesc, cData);
|
paddings, &paddingValue, cDesc, cData);
|
||||||
|
|
|
@ -21,13 +21,14 @@ class PoolingCnnl : public BangKernelWithoutConfig {
|
||||||
checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
|
checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_NCHW,
|
checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_NCHW,
|
||||||
CNNL_DTYPE_FLOAT, 4, inArray));
|
CNNL_DTYPE_FLOAT, 4, inArray));
|
||||||
|
bool mode = op->getCeilMode();
|
||||||
|
|
||||||
// get maxpool descriptor
|
// get maxpool descriptor
|
||||||
cnnlPoolingDescriptor_t poolingDesc;
|
cnnlPoolingDescriptor_t poolingDesc;
|
||||||
checkCnnlError(cnnlCreatePoolingDescriptor(&poolingDesc));
|
checkCnnlError(cnnlCreatePoolingDescriptor(&poolingDesc));
|
||||||
checkCnnlError(cnnlSetPooling2dDescriptor_v2(
|
checkCnnlError(cnnlSetPooling2dDescriptor_v2(
|
||||||
poolingDesc, getPoolingMode(), CNNL_NOT_PROPAGATE_NAN, kh, kw, ph,
|
poolingDesc, getPoolingMode(), CNNL_NOT_PROPAGATE_NAN, kh, kw, ph,
|
||||||
ph, pw, pw, sh, sw, dh, dw, false));
|
ph, pw, pw, sh, sw, dh, dw, mode));
|
||||||
|
|
||||||
// get outputs
|
// get outputs
|
||||||
// TODO: verify ceiling mode
|
// TODO: verify ceiling mode
|
||||||
|
|
|
@ -0,0 +1,36 @@
|
||||||
|
#include "operators/lrn.h"
|
||||||
|
#include "utils/operator_utils.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
LRNObj::LRNObj(GraphObj *graph, Tensor input, Tensor output, float alpha,
|
||||||
|
float beta, float bias, int size)
|
||||||
|
: OperatorObj(OpType::LRN, TensorVec{input}, {output}), alpha_value(alpha),
|
||||||
|
beta_value(beta), bias_value(bias), size_value(size) {
|
||||||
|
IT_ASSERT(checkValid(graph));
|
||||||
|
}
|
||||||
|
|
||||||
|
optional<vector<Shape>> LRNObj::inferShape(const TensorVec &inputs) {
|
||||||
|
const auto A = inputs[0];
|
||||||
|
return {{A->getDims()}};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string LRNObj::toString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "LRN[" << getGuid() << "]";
|
||||||
|
os << "(";
|
||||||
|
os << vecToString(inputs[0]->getDims()) << ",";
|
||||||
|
os << "input=" << inputs[0]->getGuid() << ",";
|
||||||
|
os << "output=" << outputs[0]->getGuid() << ")";
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<int> LRNObj::getWorkloadVector() const {
|
||||||
|
vector<int> ret = getOutput()->getDims();
|
||||||
|
ret.emplace(ret.begin(), type.underlying());
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<int> LRNObj::getOpAttrVector() const { return {type.underlying()}; }
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,57 @@
|
||||||
|
#include "bang/bang_kernel_without_config.h"
|
||||||
|
#include "bang/bang_runtime.h"
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "operators/batch_norm.h"
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
TEST(BANG_BatchNorm, run) {
|
||||||
|
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
auto bangRuntime = make_ref<BangRuntimeObj>();
|
||||||
|
|
||||||
|
// Build cpu graph
|
||||||
|
Graph gCpu = make_ref<GraphObj>(cpuRuntime);
|
||||||
|
auto iCpu = gCpu->addTensor(Shape{1, 3, 2, 2}, DataType::Float32);
|
||||||
|
auto meanCpu = gCpu->addTensor(Shape{3}, DataType::Float32);
|
||||||
|
auto varCpu = gCpu->addTensor(Shape{3}, DataType::Float32);
|
||||||
|
auto scaleCpu = gCpu->addTensor(Shape{3}, DataType::Float32);
|
||||||
|
auto biasCpu = gCpu->addTensor(Shape{3}, DataType::Float32);
|
||||||
|
|
||||||
|
// Build input data on CPU
|
||||||
|
gCpu->dataMalloc();
|
||||||
|
iCpu->setData(IncrementalGenerator());
|
||||||
|
meanCpu->copyin(vector<float>{1, 6, 9});
|
||||||
|
varCpu->copyin(vector<float>{4, 1, 9});
|
||||||
|
scaleCpu->setData(OneGenerator());
|
||||||
|
biasCpu->setData(ZeroGenerator());
|
||||||
|
|
||||||
|
Graph g = make_ref<GraphObj>(bangRuntime);
|
||||||
|
|
||||||
|
auto i = g->cloneTensor(iCpu);
|
||||||
|
auto mean = g->cloneTensor(meanCpu);
|
||||||
|
auto var = g->cloneTensor(varCpu);
|
||||||
|
auto scale = g->cloneTensor(scaleCpu);
|
||||||
|
auto bias = g->cloneTensor(biasCpu);
|
||||||
|
auto op =
|
||||||
|
g->addOp<BatchNormObj>(i, nullptr, mean, var, scale, bias, 0.9, 0);
|
||||||
|
|
||||||
|
g->dataMalloc();
|
||||||
|
i->setData(IncrementalGenerator());
|
||||||
|
mean->copyin(vector<float>{1, 6, 9});
|
||||||
|
var->copyin(vector<float>{4, 1, 9});
|
||||||
|
scale->setData(OneGenerator());
|
||||||
|
bias->setData(ZeroGenerator());
|
||||||
|
|
||||||
|
bangRuntime->run(g);
|
||||||
|
|
||||||
|
auto o = op->getOutput();
|
||||||
|
auto ocpu = o->clone(cpuRuntime);
|
||||||
|
|
||||||
|
// check results on CPU
|
||||||
|
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 2}));
|
||||||
|
EXPECT_TRUE(ocpu->equalData(vector<float>{
|
||||||
|
-0.5, 0, 0.5, 1, -2, -1, 0, 1, -0.333333, 0, 0.3333333, 0.6666667}));
|
||||||
|
}
|
||||||
|
} // namespace infini
|
|
@ -32,6 +32,8 @@ void testConcat(const std::function<void(void *, size_t, DataType)> &generator,
|
||||||
auto gpuOp =
|
auto gpuOp =
|
||||||
bangGraph->addOp<T>(TensorVec{inputGpu1, inputGpu2}, nullptr, 2);
|
bangGraph->addOp<T>(TensorVec{inputGpu1, inputGpu2}, nullptr, 2);
|
||||||
bangGraph->dataMalloc();
|
bangGraph->dataMalloc();
|
||||||
|
inputGpu1->setData(generator);
|
||||||
|
inputGpu2->setData(generator);
|
||||||
bangRuntime->run(bangGraph);
|
bangRuntime->run(bangGraph);
|
||||||
auto outputGpu = gpuOp->getOutput();
|
auto outputGpu = gpuOp->getOutput();
|
||||||
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||||
|
|
|
@ -18,8 +18,14 @@ void testPooling(const std::function<void(void *, size_t, DataType)> &generator,
|
||||||
|
|
||||||
// Build input data on CPU
|
// Build input data on CPU
|
||||||
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||||
inputCpu->dataMalloc();
|
Graph cpuGraph = make_ref<GraphObj>(cpuRuntime);
|
||||||
|
auto cpuOp =
|
||||||
|
cpuGraph->addOp<T>(inputCpu, nullptr, 3, 3, 1, 1, 1, 1, 2, 2, 0);
|
||||||
|
cpuGraph->addTensor(inputCpu);
|
||||||
|
cpuGraph->dataMalloc();
|
||||||
inputCpu->setData(generator);
|
inputCpu->setData(generator);
|
||||||
|
cpuRuntime->run(cpuGraph);
|
||||||
|
auto outputCpu = cpuOp->getOutput();
|
||||||
|
|
||||||
// GPU
|
// GPU
|
||||||
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
|
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
|
||||||
|
@ -27,17 +33,16 @@ void testPooling(const std::function<void(void *, size_t, DataType)> &generator,
|
||||||
auto gpuOp =
|
auto gpuOp =
|
||||||
bangGraph->addOp<T>(inputGpu, nullptr, 3, 3, 1, 1, 1, 1, 2, 2, 0);
|
bangGraph->addOp<T>(inputGpu, nullptr, 3, 3, 1, 1, 1, 1, 2, 2, 0);
|
||||||
bangGraph->dataMalloc();
|
bangGraph->dataMalloc();
|
||||||
|
inputGpu->setData(generator);
|
||||||
bangRuntime->run(bangGraph);
|
bangRuntime->run(bangGraph);
|
||||||
auto outputGpu = gpuOp->getOutput();
|
auto outputGpu = gpuOp->getOutput();
|
||||||
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||||
inputCpu->printData();
|
|
||||||
outputGpu2Cpu->printData();
|
|
||||||
EXPECT_TRUE(1);
|
EXPECT_TRUE(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(cnnl_Pooling, run) {
|
TEST(cnnl_Pooling, run) {
|
||||||
testPooling<MaxPoolObj>(IncrementalGenerator(), Shape{1, 1, 5, 5});
|
testPooling<MaxPoolObj>(IncrementalGenerator(), Shape{1, 3, 5, 5});
|
||||||
testPooling<AvgPoolObj>(IncrementalGenerator(), Shape{1, 1, 5, 5});
|
testPooling<AvgPoolObj>(IncrementalGenerator(), Shape{1, 3, 5, 5});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -0,0 +1,131 @@
|
||||||
|
#include "bang/bang_runtime.h"
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "operators/softmax.h"
|
||||||
|
#include "test.h"
|
||||||
|
#include <cmath>
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
TEST(cuDNN_Softmax, run_axis1) {
|
||||||
|
// Runtime
|
||||||
|
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
auto bangRuntime = make_ref<BangRuntimeObj>();
|
||||||
|
|
||||||
|
// Build input data on CPU
|
||||||
|
Tensor inputCpu =
|
||||||
|
make_ref<TensorObj>(Shape{2, 4}, DataType::Float32, cpuRuntime);
|
||||||
|
|
||||||
|
// GPU
|
||||||
|
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
|
||||||
|
auto inputGpu = bangGraph->cloneTensor(inputCpu);
|
||||||
|
auto gpuOp = bangGraph->addOp<SoftmaxObj>(inputGpu, nullptr, 1);
|
||||||
|
bangGraph->dataMalloc();
|
||||||
|
inputGpu->copyin(vector<float>{0, 1, 2, 3, 10000, 10001, 10002, 10003});
|
||||||
|
bangRuntime->run(bangGraph);
|
||||||
|
auto outputGpu = gpuOp->getOutput();
|
||||||
|
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||||
|
// Check
|
||||||
|
EXPECT_TRUE(outputGpu2Cpu->equalData(
|
||||||
|
vector<float>{0.032058604, 0.08714432, 0.23688284, 0.6439143,
|
||||||
|
0.032058604, 0.08714432, 0.23688284, 0.6439143}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(cuDNN_Softmax, run_axis0) {
|
||||||
|
// Runtime
|
||||||
|
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
auto bangRuntime = make_ref<BangRuntimeObj>();
|
||||||
|
|
||||||
|
// Build input data on CPU
|
||||||
|
Tensor inputCpu =
|
||||||
|
make_ref<TensorObj>(Shape{2, 4}, DataType::Float32, cpuRuntime);
|
||||||
|
|
||||||
|
// GPU
|
||||||
|
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
|
||||||
|
auto inputGpu = bangGraph->cloneTensor(inputCpu);
|
||||||
|
auto gpuOp = bangGraph->addOp<SoftmaxObj>(inputGpu, nullptr, 0);
|
||||||
|
bangGraph->dataMalloc();
|
||||||
|
inputGpu->copyin(vector<float>{0, 1, 2, 3, 10000, 10001, 10002, 10003});
|
||||||
|
bangRuntime->run(bangGraph);
|
||||||
|
auto outputGpu = gpuOp->getOutput();
|
||||||
|
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||||
|
// Check
|
||||||
|
EXPECT_TRUE(
|
||||||
|
outputGpu2Cpu->equalData(vector<float>{0., 0., 0., 0., 1, 1, 1, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(cuDNN_Softmax2, run_axis1) {
|
||||||
|
// Runtime
|
||||||
|
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
auto bangRuntime = make_ref<BangRuntimeObj>();
|
||||||
|
|
||||||
|
// Build input data on CPU
|
||||||
|
Tensor inputCpu =
|
||||||
|
make_ref<TensorObj>(Shape{2, 2, 2, 2}, DataType::Float32, cpuRuntime);
|
||||||
|
|
||||||
|
// GPU
|
||||||
|
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
|
||||||
|
auto inputGpu = bangGraph->cloneTensor(inputCpu);
|
||||||
|
auto gpuOp = bangGraph->addOp<SoftmaxObj>(inputGpu, nullptr, 1);
|
||||||
|
bangGraph->dataMalloc();
|
||||||
|
inputGpu->setData(IncrementalGenerator());
|
||||||
|
bangRuntime->run(bangGraph);
|
||||||
|
auto outputGpu = gpuOp->getOutput();
|
||||||
|
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||||
|
// Check
|
||||||
|
EXPECT_TRUE(outputGpu2Cpu->equalData(vector<float>{
|
||||||
|
0.0179862, 0.0179862, 0.0179862, 0.0179862, 0.9820138, 0.9820138,
|
||||||
|
0.9820138, 0.9820138, 0.0179862, 0.0179862, 0.0179862, 0.0179862,
|
||||||
|
0.9820138, 0.9820138, 0.9820138, 0.9820138}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(cuDNN_Softmax2, run_axis2) {
|
||||||
|
// Runtime
|
||||||
|
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
auto bangRuntime = make_ref<BangRuntimeObj>();
|
||||||
|
|
||||||
|
// Build input data on CPU
|
||||||
|
Tensor inputCpu =
|
||||||
|
make_ref<TensorObj>(Shape{2, 2, 2, 2}, DataType::Float32, cpuRuntime);
|
||||||
|
|
||||||
|
// GPU
|
||||||
|
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
|
||||||
|
auto inputGpu = bangGraph->cloneTensor(inputCpu);
|
||||||
|
auto gpuOp = bangGraph->addOp<SoftmaxObj>(inputGpu, nullptr, 2);
|
||||||
|
bangGraph->dataMalloc();
|
||||||
|
inputGpu->setData(IncrementalGenerator());
|
||||||
|
bangRuntime->run(bangGraph);
|
||||||
|
auto outputGpu = gpuOp->getOutput();
|
||||||
|
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||||
|
// Check
|
||||||
|
EXPECT_TRUE(outputGpu2Cpu->equalData(vector<float>{
|
||||||
|
0.1192029, 0.1192029, 0.8807971, 0.8807971, 0.1192029, 0.1192029,
|
||||||
|
0.8807971, 0.8807971, 0.1192029, 0.1192029, 0.8807971, 0.8807971,
|
||||||
|
0.1192029, 0.1192029, 0.8807971, 0.8807971}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(cuDNN_Softmax2, run_axis3) {
|
||||||
|
// Runtime
|
||||||
|
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
auto bangRuntime = make_ref<BangRuntimeObj>();
|
||||||
|
|
||||||
|
// Build input data on CPU
|
||||||
|
Tensor inputCpu =
|
||||||
|
make_ref<TensorObj>(Shape{2, 2, 2, 2}, DataType::Float32, cpuRuntime);
|
||||||
|
|
||||||
|
// GPU
|
||||||
|
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
|
||||||
|
auto inputGpu = bangGraph->cloneTensor(inputCpu);
|
||||||
|
auto gpuOp = bangGraph->addOp<SoftmaxObj>(inputGpu, nullptr, 3);
|
||||||
|
bangGraph->dataMalloc();
|
||||||
|
inputGpu->setData(IncrementalGenerator());
|
||||||
|
bangRuntime->run(bangGraph);
|
||||||
|
auto outputGpu = gpuOp->getOutput();
|
||||||
|
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||||
|
// Check
|
||||||
|
EXPECT_TRUE(outputGpu2Cpu->equalData(vector<float>{
|
||||||
|
0.2689414, 0.7310586, 0.2689414, 0.7310586, 0.2689414, 0.7310586,
|
||||||
|
0.2689414, 0.7310586, 0.2689414, 0.7310586, 0.2689414, 0.7310586,
|
||||||
|
0.2689414, 0.7310586, 0.2689414, 0.7310586}));
|
||||||
|
}
|
||||||
|
} // namespace infini
|
Loading…
Reference in New Issue