forked from jiuyuan/InfiniTensor
Compare commits
9 Commits
master
...
dev-leakyr
Author | SHA1 | Date |
---|---|---|
zhangyue | a889527aa5 | |
Zhang Bolun | 2acb680c64 | |
Zhang Bolun | 5862671c0c | |
Zhang Bolun | 917e82e90c | |
zhangyunze | d1799b67a3 | |
weijie01 | 36baae7615 | |
Zhang Bolun | 23b1612192 | |
zhangyunze | 77fd137dcb | |
zhangyunze | c6de91ee82 |
|
@ -53,6 +53,7 @@ class GraphHandlerObj {
|
|||
Tensor max(Tensor a, Tensor b, Tensor c);
|
||||
|
||||
Tensor relu(Tensor x, Tensor y);
|
||||
Tensor leakyRelu(Tensor x, Tensor y, float alpha);
|
||||
Tensor silu(Tensor x, Tensor y);
|
||||
Tensor gelu(Tensor x, Tensor y);
|
||||
Tensor sigmoid(Tensor x, Tensor y);
|
||||
|
|
|
@ -15,6 +15,8 @@ template <typename T> void gelu_kernel(T *input, T *output, size_t num);
|
|||
template <typename T> void erf_kernel(T *input, T *output, size_t num);
|
||||
template <typename T> void hard_sigmoid_kernel(T *input, T *output, size_t num);
|
||||
template <typename T> void hard_swish_kernel(T *input, T *output, size_t num);
|
||||
template <typename T>
|
||||
void leaky_relu_kernel(T *input, T *output, size_t num, float alpha);
|
||||
|
||||
template <typename INPUT, typename OUTPUT>
|
||||
void cast_kernel(INPUT *input, OUTPUT *output, size_t num);
|
||||
|
|
|
@ -228,6 +228,23 @@ class PReluObj : public OperatorObj {
|
|||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
class LeakyReluObj : public OperatorObj {
|
||||
public:
|
||||
LeakyReluObj(GraphObj *graph, Tensor input, Tensor output, float alpha);
|
||||
OP_CLONE(LeakyReluObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
float getAlpha() const { return alphaValue; }
|
||||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
private:
|
||||
float alphaValue;
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
class LogObj : public OperatorObj {
|
||||
public:
|
||||
enum LogType {
|
||||
|
|
|
@ -85,7 +85,8 @@ class OnnxStub:
|
|||
while len(sorted_nodes) < len(model.graph.node):
|
||||
updated = False
|
||||
for i, node in enumerate(model.graph.node):
|
||||
if all(t in known_edge for t in node.input):
|
||||
# TODO:目前只考虑了resize算子输入为空的情况
|
||||
if all(t in known_edge or t == "" for t in node.input):
|
||||
node.name = str(len(sorted_nodes)) + "_" + node.name
|
||||
sorted_nodes.append(i)
|
||||
known_edge.update(node.output)
|
||||
|
@ -112,7 +113,6 @@ class OnnxStub:
|
|||
)
|
||||
tensors[input.name].set_input()
|
||||
|
||||
|
||||
for node_idx in sorted_nodes:
|
||||
node = model.graph.node[node_idx]
|
||||
if node.op_type == "Conv":
|
||||
|
@ -209,8 +209,8 @@ class OnnxStub:
|
|||
)
|
||||
elif node.op_type == "MatMul":
|
||||
tensors[node.output[0]] = self.handler.matmul(
|
||||
tensors[node.input[0]], # input
|
||||
tensors[node.input[1]], # weight
|
||||
tensors[node.input[0]], # input
|
||||
tensors[node.input[1]], # weight
|
||||
tensors.get(node.output[0]),
|
||||
False,
|
||||
False,
|
||||
|
@ -447,6 +447,15 @@ class OnnxStub:
|
|||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "LeakyRelu":
|
||||
tensors[node.output[0]] = self.handler.leakyRelu(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
next(
|
||||
(attr.f for attr in node.attribute if attr.name == "alpha"),
|
||||
0.01,
|
||||
),
|
||||
)
|
||||
elif node.op_type == "Silu":
|
||||
tensors[node.output[0]] = self.handler.silu(
|
||||
tensors[node.input[0]],
|
||||
|
@ -530,12 +539,16 @@ class OnnxStub:
|
|||
tensors[node.output[0]] = self.handler.clip(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
next(_parse_data(data[node.input[1]]).__iter__(), None)
|
||||
if len(node.input) > 1
|
||||
else None,
|
||||
next(_parse_data(data[node.input[2]]).__iter__(), None)
|
||||
if len(node.input) > 2
|
||||
else None,
|
||||
(
|
||||
next(_parse_data(data[node.input[1]]).__iter__(), None)
|
||||
if len(node.input) > 1
|
||||
else None
|
||||
),
|
||||
(
|
||||
next(_parse_data(data[node.input[2]]).__iter__(), None)
|
||||
if len(node.input) > 2
|
||||
else None
|
||||
),
|
||||
)
|
||||
elif node.op_type == "Transpose":
|
||||
perm = next(
|
||||
|
@ -601,15 +614,15 @@ class OnnxStub:
|
|||
"nearest_mode",
|
||||
]
|
||||
)
|
||||
if len(node.input) > 1:
|
||||
if len(node.input) > 1 and node.input[1] in data:
|
||||
roiVal = _parse_data(data[node.input[1]])
|
||||
else:
|
||||
roiVal = []
|
||||
if len(node.input) > 2:
|
||||
if len(node.input) > 2 and node.input[2] in data:
|
||||
scalesVal = _parse_data(data[node.input[2]])
|
||||
else:
|
||||
scalesVal = []
|
||||
if len(node.input) > 3:
|
||||
if len(node.input) > 3 and node.input[3] in data:
|
||||
sizesVal = _parse_data(data[node.input[3]])
|
||||
else:
|
||||
sizesVal = []
|
||||
|
@ -617,9 +630,21 @@ class OnnxStub:
|
|||
tensors[node.input[0]],
|
||||
output,
|
||||
axes,
|
||||
tensors[node.input[3]] if len(node.input) > 3 else None,
|
||||
tensors[node.input[2]] if len(node.input) > 2 else None,
|
||||
tensors[node.input[1]] if len(node.input) > 1 else None,
|
||||
(
|
||||
tensors[node.input[3]]
|
||||
if len(node.input) > 3 and node.input[3] != ""
|
||||
else None
|
||||
),
|
||||
(
|
||||
tensors[node.input[2]]
|
||||
if len(node.input) > 2 and node.input[2] != ""
|
||||
else None
|
||||
),
|
||||
(
|
||||
tensors[node.input[1]]
|
||||
if len(node.input) > 1 and node.input[1] != ""
|
||||
else None
|
||||
),
|
||||
sizesVal,
|
||||
scalesVal,
|
||||
roiVal,
|
||||
|
@ -629,18 +654,10 @@ class OnnxStub:
|
|||
coordinate_transformation_mode,
|
||||
)
|
||||
elif node.op_type == "Squeeze":
|
||||
axes = (
|
||||
_parse_data(data[node.input[1]])
|
||||
if len(node.input) > 1
|
||||
else None
|
||||
)
|
||||
axes = _parse_data(data[node.input[1]]) if len(node.input) > 1 else None
|
||||
if axes is None:
|
||||
axes = next(
|
||||
(
|
||||
attr.ints
|
||||
for attr in node.attribute
|
||||
if attr.name == "axes"
|
||||
),
|
||||
(attr.ints for attr in node.attribute if attr.name == "axes"),
|
||||
[],
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.squeeze(
|
||||
|
@ -649,18 +666,10 @@ class OnnxStub:
|
|||
axes,
|
||||
)
|
||||
elif node.op_type == "Unsqueeze":
|
||||
axes = (
|
||||
_parse_data(data[node.input[1]])
|
||||
if len(node.input) > 1
|
||||
else None
|
||||
)
|
||||
axes = _parse_data(data[node.input[1]]) if len(node.input) > 1 else None
|
||||
if axes is None:
|
||||
axes = next(
|
||||
(
|
||||
attr.ints
|
||||
for attr in node.attribute
|
||||
if attr.name == "axes"
|
||||
)
|
||||
(attr.ints for attr in node.attribute if attr.name == "axes")
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.unsqueeze(
|
||||
tensors[node.input[0]],
|
||||
|
@ -684,24 +693,18 @@ class OnnxStub:
|
|||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "RoPE":
|
||||
tensors[node.output[0]]= self.handler.RoPE(
|
||||
tensors[node.output[0]] = self.handler.RoPE(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Split":
|
||||
split = (
|
||||
_parse_data(data[node.input[1]])
|
||||
if (len(node.input) > 1)
|
||||
else None
|
||||
_parse_data(data[node.input[1]]) if (len(node.input) > 1) else None
|
||||
)
|
||||
if split is None:
|
||||
split = next(
|
||||
(
|
||||
attr.ints
|
||||
for attr in node.attribute
|
||||
if attr.name == "split"
|
||||
),
|
||||
(attr.ints for attr in node.attribute if attr.name == "split"),
|
||||
None,
|
||||
)
|
||||
for name, tensor in zip(
|
||||
|
@ -710,11 +713,7 @@ class OnnxStub:
|
|||
tensors[node.input[0]],
|
||||
None,
|
||||
next(
|
||||
(
|
||||
attr.i
|
||||
for attr in node.attribute
|
||||
if attr.name == "axis"
|
||||
),
|
||||
(attr.i for attr in node.attribute if attr.name == "axis"),
|
||||
0,
|
||||
),
|
||||
split if split is not None else len(node.output),
|
||||
|
@ -767,12 +766,16 @@ class OnnxStub:
|
|||
tensors.get(node.output[0]),
|
||||
clamp(_parse_data(data[node.input[1]])),
|
||||
clamp(_parse_data(data[node.input[2]])),
|
||||
clamp(_parse_data(data[node.input[3]]))
|
||||
if len(node.input) > 3
|
||||
else None,
|
||||
clamp(_parse_data(data[node.input[4]]))
|
||||
if len(node.input) > 4
|
||||
else None,
|
||||
(
|
||||
clamp(_parse_data(data[node.input[3]]))
|
||||
if len(node.input) > 3
|
||||
else None
|
||||
),
|
||||
(
|
||||
clamp(_parse_data(data[node.input[4]]))
|
||||
if len(node.input) > 4
|
||||
else None
|
||||
),
|
||||
)
|
||||
elif node.op_type == "Pad":
|
||||
tensors[node.output[0]] = self.handler.pad(
|
||||
|
@ -788,12 +791,16 @@ class OnnxStub:
|
|||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
tensors.get(node.output[1]) if len(node.output) > 1 else None,
|
||||
_parse_data(data[node.input[1]])[0]
|
||||
if len(node.input) > 1
|
||||
else 0.5,
|
||||
_parse_data(data[node.input[2]])[0]
|
||||
if len(node.input) > 2
|
||||
else False,
|
||||
(
|
||||
_parse_data(data[node.input[1]])[0]
|
||||
if len(node.input) > 1
|
||||
else 0.5
|
||||
),
|
||||
(
|
||||
_parse_data(data[node.input[2]])[0]
|
||||
if len(node.input) > 2
|
||||
else False
|
||||
),
|
||||
),
|
||||
):
|
||||
tensors[name] = tensor
|
||||
|
@ -942,18 +949,25 @@ class OnnxStub:
|
|||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Where":
|
||||
## If Y is single -inf, treat Where as Add
|
||||
## If Y is single -inf, treat Where as Add
|
||||
## TODO: deal with cases where Y is single inf or 0
|
||||
if node.input[0] in data and node.input[2] in data:
|
||||
where_condition = to_array(data[node.input[0]])
|
||||
where_alt = to_array(data[node.input[2]])
|
||||
where_alt = to_array(data[node.input[2]])
|
||||
if where_alt.size == 1:
|
||||
if np.isneginf(where_alt) or np.all(where_alt < -3e38):
|
||||
node.input[0] = node.input[0] + "_alt"
|
||||
if node.input[0] not in data:
|
||||
where_value = np.where(where_condition, 0, -np.inf).astype(where_alt.dtype)
|
||||
data[node.input[0]] = from_array(where_value, node.input[0])
|
||||
tensors[node.input[0]] = self.handler.tensor(list(where_value.shape), data[node.input[0]].data_type)
|
||||
where_value = np.where(
|
||||
where_condition, 0, -np.inf
|
||||
).astype(where_alt.dtype)
|
||||
data[node.input[0]] = from_array(
|
||||
where_value, node.input[0]
|
||||
)
|
||||
tensors[node.input[0]] = self.handler.tensor(
|
||||
list(where_value.shape),
|
||||
data[node.input[0]].data_type,
|
||||
)
|
||||
tensors[node.input[0]].set_weight()
|
||||
tensors[node.output[0]] = self.handler.add(
|
||||
tensors[node.input[1]],
|
||||
|
@ -980,8 +994,7 @@ class OnnxStub:
|
|||
node, {"alpha": 0.0001, "beta": 0.75, "bias": 1.0, "size": 1}
|
||||
)
|
||||
(alpha, beta, bias, size) = (
|
||||
attributes[name]
|
||||
for name in ["alpha", "beta", "bias", "size"]
|
||||
attributes[name] for name in ["alpha", "beta", "bias", "size"]
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.lrn(
|
||||
tensors[node.input[0]],
|
||||
|
|
|
@ -222,6 +222,15 @@ Tensor GraphHandlerObj::pRelu(Tensor x, Tensor slope, Tensor y) {
|
|||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::leakyRelu(Tensor x, Tensor y, float alpha) {
|
||||
if (y) {
|
||||
g->addOpWithOutputs<LeakyReluObj>(std::move(x), y, alpha);
|
||||
return y;
|
||||
} else {
|
||||
return g->addOp<LeakyReluObj>(std::move(x), y, alpha)->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::clip(Tensor x, Tensor y, std::optional<float> min,
|
||||
std::optional<float> max) {
|
||||
if (y) {
|
||||
|
|
|
@ -562,6 +562,7 @@ void init_graph_builder(py::module &m) {
|
|||
.def("expand", &Handler::expand, policy::move)
|
||||
.def("erf", &Handler::erf, policy::move)
|
||||
.def("where", &Handler::where, policy::move)
|
||||
.def("leakyRelu", &Handler::leakyRelu, policy::move)
|
||||
.def("lrn", &Handler::lrn, policy::move)
|
||||
.def("topo_sort", &Handler::topo_sort, policy::automatic)
|
||||
.def("optimize", &Handler::optimize, policy::automatic)
|
||||
|
|
|
@ -241,8 +241,50 @@ class HardSigmoidCnnl : public UnaryCnnl {
|
|||
float getScale() const override { return 0.5f; }
|
||||
};
|
||||
|
||||
class LeakyReluCnnl : public BangKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<LeakyReluObj>(_op);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
cnnlTensorDescriptor_t aDesc, cDesc;
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto cDim = op->getOutput()->getDims();
|
||||
auto coef = op->getAlpha();
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
|
||||
aDim.size(), aDim.data()));
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
|
||||
cDim.size(), cDim.data()));
|
||||
cnnlActivationDescriptor_t opDesc;
|
||||
checkCnnlError(cnnlCreateActivationDescriptor(&opDesc));
|
||||
checkCnnlError(cnnlSetActivationDescriptor_v5(
|
||||
opDesc, CNNL_ACTIVATION_LEAKYRELU, CNNL_ACTIVATION_HIGH_PRECISION,
|
||||
CNNL_NOT_PROPAGATE_NAN, coef, 0.0, 0.0, 0.0, true));
|
||||
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
cnnlStatus_t stat =
|
||||
cnnlActivationForward(context->cnnlHandle(), opDesc, &alpha, aDesc,
|
||||
aData, &beta, cDesc, cData);
|
||||
if (stat != CNNL_STATUS_SUCCESS)
|
||||
return;
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
|
||||
checkCnnlError(cnnlDestroyActivationDescriptor(opDesc));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Relu, ReluCnnl, "Relu_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::PRelu, PReluCnnl, "PRelu_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::LeakyRelu, LeakyReluCnnl,
|
||||
"LeakyRelu_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, SigmoidCnnl,
|
||||
"Sigmoid_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Round, RoundCnnl, "Round_cnnl_BANG");
|
||||
|
|
|
@ -16,10 +16,14 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
|
|||
void *const bias = (op->getInputs(4)->getRawDataPtr<void *>());
|
||||
void *const output = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto dims = op->getInputs(0)->getDims();
|
||||
auto outDims = op->getOutput()->getDims();
|
||||
if (dims.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto padDims = [](Shape shape) {
|
||||
for (size_t i = shape.size(); i < 4; ++i) {
|
||||
shape.push_back(1);
|
||||
}
|
||||
return shape;
|
||||
};
|
||||
auto dims = padDims(op->getInputs(0)->getDims());
|
||||
auto outDims = padDims(op->getOutput()->getDims());
|
||||
|
||||
int dimsTrans[4] = {dims[0], dims[2], dims[3], dims[1]};
|
||||
int dimsOutTrans[4] = {outDims[0], outDims[2], outDims[3], outDims[1]};
|
||||
|
|
|
@ -0,0 +1,144 @@
|
|||
#include "operators/resize.h"
|
||||
#include "bang/bang_kernel_without_config.h"
|
||||
#include "bang/bang_runtime.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace infini {
|
||||
class ResizeCnnl : public BangKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ResizeObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto nDims = op->getInputs(0)->getRank();
|
||||
if (nDims != 4) {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto cDim = op->getOutput()->getDims();
|
||||
std::vector<int> aTransDim = {aDim[0], aDim[2], aDim[3], aDim[1]};
|
||||
std::vector<int> cTransDim = {cDim[0], cDim[2], cDim[3], cDim[1]};
|
||||
|
||||
cnnlTensorDescriptor_t aDesc, cDesc, aTransDesc, cTransDesc;
|
||||
// input
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
|
||||
aDim.size(), aDim.data()));
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aTransDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aTransDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()),
|
||||
aTransDim.size(), aTransDim.data()));
|
||||
// output
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
|
||||
cDim.size(), cDim.data()));
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&cTransDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cTransDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()),
|
||||
cTransDim.size(), cTransDim.data()));
|
||||
|
||||
// transpose
|
||||
BangPtr aTransData = context->getWorkspace(
|
||||
cnnlGetTensorElementNum(aTransDesc) * op->getDType().getSize());
|
||||
BangPtr cTransData = context->getWorkspace(
|
||||
cnnlGetTensorElementNum(cTransDesc) * op->getDType().getSize());
|
||||
|
||||
int permuteIn[4] = {0, 2, 3, 1};
|
||||
cnnlTransposeDescriptor_t inDesc;
|
||||
checkCnnlError(cnnlCreateTransposeDescriptor(&inDesc));
|
||||
checkCnnlError(cnnlSetTransposeDescriptor(inDesc, 4, permuteIn));
|
||||
size_t wsSizeIn;
|
||||
cnnlGetTransposeWorkspaceSize(context->cnnlHandle(), aDesc, inDesc,
|
||||
&wsSizeIn);
|
||||
BangPtr wsDataIn = context->getWorkspace(wsSizeIn);
|
||||
|
||||
checkCnnlError(cnnlTranspose_v2(context->cnnlHandle(), inDesc, aDesc,
|
||||
aData, aTransDesc, aTransData, wsDataIn,
|
||||
wsSizeIn));
|
||||
|
||||
cnnlTensorDescriptor_t boxesDesc, boxesIndexDesc;
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&boxesDesc));
|
||||
auto nBatch = aDim[0];
|
||||
std::vector<int> boxesDim = {nBatch, 4};
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
boxesDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
|
||||
boxesDim.size(), boxesDim.data()));
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&boxesIndexDesc));
|
||||
std::vector<int> boxesIndexDim = {nBatch};
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
boxesIndexDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_INT32,
|
||||
boxesIndexDim.size(), boxesIndexDim.data()));
|
||||
std::vector<int32_t> boxesIndex(nBatch);
|
||||
std::iota(boxesIndex.begin(), boxesIndex.end(), 0);
|
||||
BangPtr boxesIndexData =
|
||||
context->getWorkspace(nBatch * sizeof(int32_t));
|
||||
context->copyBlobFromCPU(boxesIndexData, boxesIndex.data(),
|
||||
nBatch * sizeof(int32_t));
|
||||
|
||||
cnnlCropAndResizeMode_t mode;
|
||||
auto coefMode = op->getMode();
|
||||
if (coefMode == ResizeObj::ECoeffMode::nearest) {
|
||||
// CNNL uses round by default and
|
||||
// does not support other nearest modes
|
||||
mode = CNNL_CROP_AND_RESIZE_NEAREST;
|
||||
} else if (coefMode == ResizeObj::ECoeffMode::linear) {
|
||||
mode = CNNL_CROP_AND_RESIZE_BILINEAR;
|
||||
} else {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
|
||||
std::vector<float> box;
|
||||
auto transMode = op->getCoordinateTransMode();
|
||||
if (transMode ==
|
||||
enum_to_underlying(
|
||||
ResizeObj::ECoordinateTransMode::tfCropAndResize)) {
|
||||
box = {op->getRoi(2), op->getRoi(3), op->getRoi(6), op->getRoi(7)};
|
||||
} else {
|
||||
box = {0, 0, 1.0, 1.0};
|
||||
}
|
||||
|
||||
BangPtr boxesData =
|
||||
context->getWorkspace(nBatch * box.size() * sizeof(float));
|
||||
for (auto i = 0; i < nBatch; i++) {
|
||||
context->copyBlobFromCPU(boxesData + i * box.size() * sizeof(float),
|
||||
box.data(), box.size() * sizeof(float));
|
||||
}
|
||||
|
||||
checkCnnlError(cnnlCropAndResize(
|
||||
context->cnnlHandle(), aTransDesc, aTransData, boxesDesc, boxesData,
|
||||
boxesIndexDesc, boxesIndexData, mode, 0.0, cTransDesc, cTransData));
|
||||
|
||||
// transpose
|
||||
int permuteOut[4] = {0, 3, 1, 2};
|
||||
cnnlTransposeDescriptor_t outDesc;
|
||||
checkCnnlError(cnnlCreateTransposeDescriptor(&outDesc));
|
||||
checkCnnlError(cnnlSetTransposeDescriptor(outDesc, 4, permuteOut));
|
||||
size_t wsSizeOut;
|
||||
cnnlGetTransposeWorkspaceSize(context->cnnlHandle(), cTransDesc,
|
||||
outDesc, &wsSizeOut);
|
||||
BangPtr wsDataOut = context->getWorkspace(wsSizeOut);
|
||||
|
||||
checkCnnlError(cnnlTranspose_v2(context->cnnlHandle(), outDesc,
|
||||
cTransDesc, cTransData, cDesc, cData,
|
||||
wsDataOut, wsSizeOut));
|
||||
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(aTransDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(cTransDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(boxesDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(boxesIndexDesc));
|
||||
checkCnnlError(cnnlDestroyTransposeDescriptor(inDesc));
|
||||
checkCnnlError(cnnlDestroyTransposeDescriptor(outDesc));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Resize, ResizeCnnl, "Resize_cnnl_BANG");
|
||||
}; // namespace infini
|
|
@ -18,9 +18,16 @@ class BatchNormCudnn : public CudaKernelWithoutConfig {
|
|||
void *const scaleData = (op->getInputs(3)->getRawDataPtr<void *>());
|
||||
void *const biasData = (op->getInputs(4)->getRawDataPtr<void *>());
|
||||
|
||||
auto dims = op->getInputs(0)->getDims();
|
||||
// Only 4D and 5D tensors are supported by
|
||||
// cudnnBatchNormalizationForwardInference
|
||||
if (auto dims = op->getInputs(0)->getDims(); dims.size() < 4) {
|
||||
auto dims_t = dims;
|
||||
for (size_t i = dims_t.size(); i < 4; ++i) {
|
||||
dims_t.push_back(1);
|
||||
}
|
||||
op->getInputs(0)->setShape(dims_t);
|
||||
}
|
||||
auto dims = op->getInputs(0)->getDims();
|
||||
IT_ASSERT(dims.size() == 4);
|
||||
|
||||
int dimArray[4], strideArray[4], dimPArray[4], stridePArray[4];
|
||||
|
|
|
@ -173,6 +173,25 @@ class TanhCudnn : public ActivationCudnn {
|
|||
}
|
||||
};
|
||||
|
||||
class LeakyReluCuda : public CudaKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
|
||||
auto op = as<LeakyReluObj>(_op);
|
||||
auto alpha = op->getAlpha();
|
||||
size_t num = op->getOutput()->size();
|
||||
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
if (op->getDType() == DataType::Float32) {
|
||||
leaky_relu_kernel<float>((float *)inputData, (float *)outputData,
|
||||
num, alpha);
|
||||
} else {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Relu, ReluCudnn, "Relu_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Sigmoid, SigmoidCudnn, "Sigmoid_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::HardSigmoid, UnaryCuda,
|
||||
|
@ -185,11 +204,14 @@ REGISTER_KERNEL(Device::CUDA, OpType::Gelu, UnaryCuda, "Gelu_CUDA");
|
|||
REGISTER_KERNEL(Device::CUDA, OpType::Silu, UnaryCuda, "Silu_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Neg, UnaryCuda, "Neg_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Erf, UnaryCuda, "Erf_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::LeakyRelu, LeakyReluCuda,
|
||||
"LeakyRelu_CUDA");
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Cast, CastCuda, "Cast_CUDA");
|
||||
|
||||
// REGISTER_KERNEL(Device::CUDA, OpType::Softmax, UnaryCuda, "Softmax_CUDA");
|
||||
// REGISTER_KERNEL(Device::CUDA, OpType::Relu, UnaryCuda,
|
||||
// REGISTER_KERNEL(Device::CUDA, OpType::Softmax, UnaryCuda,
|
||||
// "Softmax_CUDA"); REGISTER_KERNEL(Device::CUDA, OpType::Relu,
|
||||
// UnaryCuda,
|
||||
// "Relu_CUDA");
|
||||
// REGISTER_KERNEL(Device::CUDA, OpType::Sigmoid, UnaryCuda,
|
||||
// "Sigmoid_CUDA");
|
||||
|
|
|
@ -110,7 +110,8 @@ __global__ void _silu_kernel(T *input, T *output, size_t n) {
|
|||
int stride = blockDim.x * gridDim.x;
|
||||
for (int i = index; i < n; i += stride) {
|
||||
float x = input[i];
|
||||
output[i] = x / (1.0 + expf(-x));;
|
||||
output[i] = x / (1.0 + expf(-x));
|
||||
;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -143,33 +144,40 @@ __global__ void _cast_kernel(INPUT *input, OUTPUT *output, size_t n) {
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void _leaky_relu_kernel(T *input, T *output, size_t n, float alpha) {
|
||||
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
size_t stride = blockDim.x * gridDim.x;
|
||||
for (size_t i = index; i < n; i += stride) {
|
||||
output[i] = input[i] > 0 ? input[i] : alpha * input[i];
|
||||
}
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
template <typename T> void softmax_kernel(T *input, T *output, size_t num) {
|
||||
|
||||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_softmax_kernel1<T>
|
||||
<<<1, 1, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
<<<1, 1, 0, CUDAStream::getCurrentStream()>>>(input, output, num);
|
||||
_softmax_kernel2<T>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
input, output, num);
|
||||
}
|
||||
template <typename T> void relu_kernel(T *input, T *output, size_t num) {
|
||||
|
||||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_relu_kernel<T>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
_relu_kernel<T><<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
input, output, num);
|
||||
}
|
||||
template <typename T> void sigmoid_kernel(T *input, T *output, size_t num) {
|
||||
|
||||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_sigmoid_kernel<T>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
input, output, num);
|
||||
}
|
||||
template <typename T>
|
||||
void hard_sigmoid_kernel(T *input, T *output, size_t num) {
|
||||
|
@ -177,75 +185,78 @@ void hard_sigmoid_kernel(T *input, T *output, size_t num) {
|
|||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_hard_sigmoid_kernel<T>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
input, output, num);
|
||||
}
|
||||
template <typename T> void hard_swish_kernel(T *input, T *output, size_t num) {
|
||||
|
||||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_hard_swish_kernel<T>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
input, output, num);
|
||||
}
|
||||
template <typename T> void tanh_kernel(T *input, T *output, size_t num) {
|
||||
|
||||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_tanh_kernel<T>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
_tanh_kernel<T><<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
input, output, num);
|
||||
}
|
||||
template <typename T> void abs_kernel(T *input, T *output, size_t num) {
|
||||
|
||||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_abs_kernel<T>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
_abs_kernel<T><<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
input, output, num);
|
||||
}
|
||||
template <typename T> void sqrt_kernel(T *input, T *output, size_t num) {
|
||||
|
||||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_sqrt_kernel
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
((T *)input, (T *)output, num);
|
||||
_sqrt_kernel<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
(T *)input, (T *)output, num);
|
||||
}
|
||||
|
||||
template <typename T> void gelu_kernel(T *input, T *output, size_t num) {
|
||||
|
||||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_gelu_kernel<T>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
_gelu_kernel<T><<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
input, output, num);
|
||||
}
|
||||
|
||||
template <typename T> void silu_kernel(T *input, T *output, size_t num) {
|
||||
|
||||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_silu_kernel<T>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
_silu_kernel<T><<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
input, output, num);
|
||||
}
|
||||
|
||||
template <typename T> void erf_kernel(T *input, T *output, size_t num) {
|
||||
|
||||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_erf_kernel<T>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
_erf_kernel<T><<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
input, output, num);
|
||||
}
|
||||
template <typename T> void neg_kernel(T *input, T *output, size_t num) {
|
||||
|
||||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_neg_kernel<T>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
_neg_kernel<T><<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
input, output, num);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void leaky_relu_kernel(T *input, T *output, size_t num, float alpha) {
|
||||
|
||||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_leaky_relu_kernel<T>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
input, output, num, alpha);
|
||||
}
|
||||
|
||||
void unary_kernel(const Operator &_op) {
|
||||
|
@ -315,7 +326,7 @@ void unary_kernel(const Operator &_op) {
|
|||
} else if (op->getOpType() == OpType::Silu) {
|
||||
if (_op->getDType() == DataType::Float32) {
|
||||
silu_kernel<float>((float *)inputData, (float *)outputData, num);
|
||||
} else if (_op->getDType() == DataType::Float16){
|
||||
} else if (_op->getDType() == DataType::Float16) {
|
||||
silu_kernel<half>((half *)inputData, (half *)outputData, num);
|
||||
} else {
|
||||
IT_TODO_HALT();
|
||||
|
@ -346,8 +357,8 @@ void cast_kernel(INPUT *input, OUTPUT *output, size_t num) {
|
|||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_cast_kernel<INPUT, OUTPUT>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
input, output, num);
|
||||
}
|
||||
|
||||
template void cast_kernel<float, half>(float *input, half *output, size_t num);
|
||||
|
@ -359,4 +370,6 @@ template void cast_kernel<float, int8_t>(float *input, int8_t *output,
|
|||
template void cast_kernel<int8_t, float>(int8_t *input, float *output,
|
||||
size_t num);
|
||||
|
||||
template void leaky_relu_kernel<float>(float *input, float *output, size_t num,
|
||||
float alpha);
|
||||
}; // namespace infini
|
||||
|
|
|
@ -19,13 +19,17 @@ class BatchNormXdnn : public KUNLUNKernelWithoutConfig {
|
|||
|
||||
auto dims = op->getInputs(0)->getDims();
|
||||
|
||||
if (dims.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
int n, c, h, w;
|
||||
if (dims.size() != 4) {
|
||||
h = 1;
|
||||
w = 1;
|
||||
}
|
||||
|
||||
w = dims[3];
|
||||
h = dims[2];
|
||||
c = dims[1];
|
||||
n = dims[0];
|
||||
|
||||
int w = dims[3];
|
||||
int h = dims[2];
|
||||
int c = dims[1];
|
||||
int n = dims[0];
|
||||
auto ret = xdnn::batch_norm_infer<float>(
|
||||
context->KUNLUNHandle(), (float *)input, (float *)output, n, c, h,
|
||||
w, op->getEps(), (float *)scale, (float *)bias, (float *)mean,
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
#include "operators/layer_norm.h"
|
||||
#include "kunlun/kunlun_kernel_without_config.h"
|
||||
#include "kunlun/kunlun_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class LayerNormXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<LayerNormObj>(_op);
|
||||
auto context = static_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const scaleData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
float eps = op->getEps();
|
||||
// int axis = op->getAxis();
|
||||
|
||||
const auto &opInputShape = op->getInputs(0)->getDims();
|
||||
const auto &opOutputShape = op->getOutput()->getDims();
|
||||
IT_ASSERT(opInputShape.size() == 2);
|
||||
|
||||
int ret;
|
||||
if (op->numInputs() == 3) {
|
||||
// with bias
|
||||
void *const biasData = op->getInputs(2)->getRawDataPtr<void *>();
|
||||
ret = xdnn::layer_norm<float, float>(
|
||||
context->KUNLUNHandle(), (float const *)inputData,
|
||||
(float *)outputData, opInputShape[0], opInputShape[1], eps,
|
||||
(float *)scaleData, (float *)biasData, nullptr, nullptr);
|
||||
} else {
|
||||
// without bias
|
||||
ret = xdnn::layer_norm<float, float>(
|
||||
context->KUNLUNHandle(), (float const *)inputData,
|
||||
(float *)outputData, opInputShape[0], opInputShape[1], eps,
|
||||
(float *)scaleData, nullptr, nullptr, nullptr);
|
||||
}
|
||||
assert(ret == 0);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::LayerNormalization, LayerNormXdnn,
|
||||
"LayerNorm_xdnn_KUNLUN");
|
||||
|
||||
}; // namespace infini
|
|
@ -21,6 +21,26 @@ class ReluXdnn : public KUNLUNKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
class LeakyReluXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<LeakyReluObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
auto alpha = op->getAlpha();
|
||||
|
||||
auto ret = xdnn::leaky_relu<float>(context->KUNLUNHandle(),
|
||||
(float *const)aData, (float *)cData,
|
||||
len, alpha);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class SigmoidXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
|
@ -552,6 +572,8 @@ class ATanhXdnn : public KUNLUNKernelWithoutConfig {
|
|||
};
|
||||
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Relu, ReluXdnn, "Relu_xdnn_KUNLUN");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::LeakyRelu, LeakyReluXdnn,
|
||||
"LeakyRelu_xdnn_KUNLUN");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Sigmoid, SigmoidXdnn,
|
||||
"Sigmoid_xdnn_KUNLUN");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Tanh, TanhXdnn, "Tanh_xdnn_KUNLUN");
|
||||
|
|
|
@ -283,6 +283,38 @@ vector<int> PReluObj::getWorkloadVector() const {
|
|||
|
||||
vector<int> PReluObj::getOpAttrVector() const { return {type.underlying()}; }
|
||||
|
||||
LeakyReluObj::LeakyReluObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
float alpha)
|
||||
: OperatorObj(OpType::LeakyRelu, {input}, {output}), alphaValue(alpha) {
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> LeakyReluObj::inferShape(const TensorVec &inputs) {
|
||||
const auto A = inputs[0];
|
||||
return {{A->getDims()}};
|
||||
}
|
||||
|
||||
std::string LeakyReluObj::toString() const {
|
||||
std::ostringstream os;
|
||||
os << type.toString() << "[" << getGuid() << "]";
|
||||
os << "(";
|
||||
os << vecToString(inputs[0]->getDims()) << ",";
|
||||
os << "input=" << inputs[0]->getGuid() << ",";
|
||||
os << "output=" << outputs[0]->getGuid() << ")";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
vector<int> LeakyReluObj::getWorkloadVector() const {
|
||||
vector<int> ret{type.underlying()};
|
||||
const Shape shape = outputs[0]->getDims();
|
||||
ret.insert(ret.end(), shape.begin(), shape.end());
|
||||
return ret;
|
||||
}
|
||||
|
||||
vector<int> LeakyReluObj::getOpAttrVector() const {
|
||||
return {type.underlying()};
|
||||
}
|
||||
|
||||
LogObj::LogObj(GraphObj *graph, Tensor input, Tensor output, LogType type)
|
||||
: OperatorObj(OpType::Log, {input}, {output}), logType(type) {
|
||||
IT_ASSERT(checkValid(graph));
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
#include "bang/bang_runtime.h"
|
||||
#include "cmath"
|
||||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/resize.h"
|
||||
#include "test.h"
|
||||
namespace infini {
|
||||
TEST(Resize, Bang_downsample_sizes_nearest) {
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32);
|
||||
auto scales = gCpu->addTensor({4}, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyin(vector<float>{1, 2, 3, 4, 5, 6, 7, 8});
|
||||
scales->copyin(vector<float>{1, 1, 0.6, 0.6});
|
||||
|
||||
auto bangRuntime = make_ref<BangRuntimeObj>();
|
||||
Graph gMlu = make_ref<GraphObj>(bangRuntime);
|
||||
|
||||
auto inputMlu = gMlu->cloneTensor(input);
|
||||
auto scalesMlu = gMlu->cloneTensor(scales);
|
||||
auto op = gMlu->addOp<ResizeObj>(inputMlu, nullptr, std::nullopt, nullptr,
|
||||
scalesMlu, nullptr);
|
||||
gMlu->dataMalloc();
|
||||
inputMlu->copyin(vector<float>{1, 2, 3, 4, 5, 6, 7, 8});
|
||||
scalesMlu->copyin(vector<float>{1, 1, 0.6, 0.6});
|
||||
|
||||
bangRuntime->run(gMlu);
|
||||
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{5, 8}));
|
||||
}
|
||||
|
||||
TEST(Resize, Bang_upsample_sizes_nearest) {
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32);
|
||||
auto scales = gCpu->addTensor({4}, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyin(vector<float>{1, 2, 3, 4});
|
||||
scales->copyin(vector<float>{1, 1, 2, 3});
|
||||
|
||||
auto bangRuntime = make_ref<BangRuntimeObj>();
|
||||
Graph gMlu = make_ref<GraphObj>(bangRuntime);
|
||||
|
||||
auto inputMlu = gMlu->cloneTensor(input);
|
||||
auto scalesMlu = gMlu->cloneTensor(scales);
|
||||
auto op = gMlu->addOp<ResizeObj>(inputMlu, nullptr, std::nullopt, nullptr,
|
||||
scalesMlu, nullptr);
|
||||
gMlu->dataMalloc();
|
||||
inputMlu->copyin(vector<float>{1, 2, 3, 4});
|
||||
scalesMlu->copyin(vector<float>{1, 1, 2, 3});
|
||||
|
||||
bangRuntime->run(gMlu);
|
||||
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(
|
||||
oCpu->equalData(vector<float>{1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2,
|
||||
3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4}));
|
||||
}
|
||||
} // namespace infini
|
Loading…
Reference in New Issue