forked from jiuyuan/InfiniTensor
Compare commits
9 Commits
master ... dev-leakyr

Author | SHA1 | Date
---|---|---
zhangyue | a889527aa5 |
Zhang Bolun | 2acb680c64 |
Zhang Bolun | 5862671c0c |
Zhang Bolun | 917e82e90c |
zhangyunze | d1799b67a3 |
weijie01 | 36baae7615 |
Zhang Bolun | 23b1612192 |
zhangyunze | 77fd137dcb |
zhangyunze | c6de91ee82 |
@@ -53,6 +53,7 @@ class GraphHandlerObj {
     Tensor max(Tensor a, Tensor b, Tensor c);

     Tensor relu(Tensor x, Tensor y);
+    Tensor leakyRelu(Tensor x, Tensor y, float alpha);
     Tensor silu(Tensor x, Tensor y);
     Tensor gelu(Tensor x, Tensor y);
     Tensor sigmoid(Tensor x, Tensor y);
@@ -15,6 +15,8 @@ template <typename T> void gelu_kernel(T *input, T *output, size_t num);
 template <typename T> void erf_kernel(T *input, T *output, size_t num);
 template <typename T> void hard_sigmoid_kernel(T *input, T *output, size_t num);
 template <typename T> void hard_swish_kernel(T *input, T *output, size_t num);
+template <typename T>
+void leaky_relu_kernel(T *input, T *output, size_t num, float alpha);

 template <typename INPUT, typename OUTPUT>
 void cast_kernel(INPUT *input, OUTPUT *output, size_t num);
@@ -228,6 +228,23 @@ class PReluObj : public OperatorObj {
     vector<int> getOpAttrVector() const override;
 };

+class LeakyReluObj : public OperatorObj {
+  public:
+    LeakyReluObj(GraphObj *graph, Tensor input, Tensor output, float alpha);
+    OP_CLONE(LeakyReluObj);
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
+
+    std::string toString() const override;
+    float getAlpha() const { return alphaValue; }
+    int numInputs() const override { return 1; }
+    int numOutputs() const override { return 1; }
+
+  private:
+    float alphaValue;
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+};
+
 class LogObj : public OperatorObj {
   public:
     enum LogType {
@@ -85,7 +85,8 @@ class OnnxStub:
         while len(sorted_nodes) < len(model.graph.node):
             updated = False
             for i, node in enumerate(model.graph.node):
-                if all(t in known_edge for t in node.input):
+                # TODO: currently only handles the case where a Resize input is empty
+                if all(t in known_edge or t == "" for t in node.input):
                     node.name = str(len(sorted_nodes)) + "_" + node.name
                     sorted_nodes.append(i)
                     known_edge.update(node.output)
@@ -112,7 +113,6 @@ class OnnxStub:
             )
             tensors[input.name].set_input()

-
         for node_idx in sorted_nodes:
             node = model.graph.node[node_idx]
             if node.op_type == "Conv":
@@ -209,8 +209,8 @@ class OnnxStub:
                 )
             elif node.op_type == "MatMul":
                 tensors[node.output[0]] = self.handler.matmul(
                     tensors[node.input[0]],  # input
                     tensors[node.input[1]],  # weight
                     tensors.get(node.output[0]),
                     False,
                     False,
@@ -447,6 +447,15 @@ class OnnxStub:
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                 )
+            elif node.op_type == "LeakyRelu":
+                tensors[node.output[0]] = self.handler.leakyRelu(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                    next(
+                        (attr.f for attr in node.attribute if attr.name == "alpha"),
+                        0.01,
+                    ),
+                )
             elif node.op_type == "Silu":
                 tensors[node.output[0]] = self.handler.silu(
                     tensors[node.input[0]],
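The new "LeakyRelu" branch mirrors the other unary imports: the `alpha` attribute is read with a 0.01 default, as in the ONNX spec. A minimal sketch of exercising this import path, assuming the usual pyinfinitensor entry points (`OnnxStub`, `backend`) and a runtime that got a LeakyRelu kernel in this change (CUDA, BANG, or KUNLUN):

```python
# Illustrative only: build a one-node ONNX model with LeakyRelu and import it.
# The pyinfinitensor import path and runtime constructor are assumptions here.
import onnx
from onnx import TensorProto, helper

x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 4, 4])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 4, 4])
node = helper.make_node("LeakyRelu", ["x"], ["y"], alpha=0.02)
model = helper.make_model(helper.make_graph([node], "leaky", [x], [y]))
onnx.checker.check_model(model)

from pyinfinitensor.onnx import OnnxStub, backend  # assumed entry points

stub = OnnxStub(model, backend.cuda_runtime())  # routes through handler.leakyRelu
```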
@@ -530,12 +539,16 @@ class OnnxStub:
                 tensors[node.output[0]] = self.handler.clip(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
-                    next(_parse_data(data[node.input[1]]).__iter__(), None)
-                    if len(node.input) > 1
-                    else None,
-                    next(_parse_data(data[node.input[2]]).__iter__(), None)
-                    if len(node.input) > 2
-                    else None,
+                    (
+                        next(_parse_data(data[node.input[1]]).__iter__(), None)
+                        if len(node.input) > 1
+                        else None
+                    ),
+                    (
+                        next(_parse_data(data[node.input[2]]).__iter__(), None)
+                        if len(node.input) > 2
+                        else None
+                    ),
                 )
             elif node.op_type == "Transpose":
                 perm = next(
@@ -601,15 +614,15 @@ class OnnxStub:
                         "nearest_mode",
                     ]
                 )
-                if len(node.input) > 1:
+                if len(node.input) > 1 and node.input[1] in data:
                     roiVal = _parse_data(data[node.input[1]])
                 else:
                     roiVal = []
-                if len(node.input) > 2:
+                if len(node.input) > 2 and node.input[2] in data:
                     scalesVal = _parse_data(data[node.input[2]])
                 else:
                     scalesVal = []
-                if len(node.input) > 3:
+                if len(node.input) > 3 and node.input[3] in data:
                     sizesVal = _parse_data(data[node.input[3]])
                 else:
                     sizesVal = []
@@ -617,9 +630,21 @@ class OnnxStub:
                     tensors[node.input[0]],
                     output,
                     axes,
-                    tensors[node.input[3]] if len(node.input) > 3 else None,
-                    tensors[node.input[2]] if len(node.input) > 2 else None,
-                    tensors[node.input[1]] if len(node.input) > 1 else None,
+                    (
+                        tensors[node.input[3]]
+                        if len(node.input) > 3 and node.input[3] != ""
+                        else None
+                    ),
+                    (
+                        tensors[node.input[2]]
+                        if len(node.input) > 2 and node.input[2] != ""
+                        else None
+                    ),
+                    (
+                        tensors[node.input[1]]
+                        if len(node.input) > 1 and node.input[1] != ""
+                        else None
+                    ),
                     sizesVal,
                     scalesVal,
                     roiVal,
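For context, ONNX marks an omitted optional input with an empty name, so it is neither a known edge nor a key in `data`; the `t == ""`, `in data`, and `!= ""` checks added above are what let such Resize nodes through the importer. A hypothetical node of that shape, just to illustrate the pattern being guarded against:

```python
# Sketch only: Resize with the optional roi input skipped via an empty name,
# so it never appears in the initializer map the importer builds.
from onnx import helper

resize = helper.make_node(
    "Resize",
    inputs=["x", "", "scales"],  # roi omitted; sizes not given at all
    outputs=["y"],
    mode="nearest",
)
```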
@@ -629,18 +654,10 @@ class OnnxStub:
                     coordinate_transformation_mode,
                 )
             elif node.op_type == "Squeeze":
-                axes = (
-                    _parse_data(data[node.input[1]])
-                    if len(node.input) > 1
-                    else None
-                )
+                axes = _parse_data(data[node.input[1]]) if len(node.input) > 1 else None
                 if axes is None:
                     axes = next(
-                        (
-                            attr.ints
-                            for attr in node.attribute
-                            if attr.name == "axes"
-                        ),
+                        (attr.ints for attr in node.attribute if attr.name == "axes"),
                         [],
                     )
                 tensors[node.output[0]] = self.handler.squeeze(
@@ -649,18 +666,10 @@ class OnnxStub:
                     axes,
                 )
             elif node.op_type == "Unsqueeze":
-                axes = (
-                    _parse_data(data[node.input[1]])
-                    if len(node.input) > 1
-                    else None
-                )
+                axes = _parse_data(data[node.input[1]]) if len(node.input) > 1 else None
                 if axes is None:
                     axes = next(
-                        (
-                            attr.ints
-                            for attr in node.attribute
-                            if attr.name == "axes"
-                        )
+                        (attr.ints for attr in node.attribute if attr.name == "axes")
                     )
                 tensors[node.output[0]] = self.handler.unsqueeze(
                     tensors[node.input[0]],
@@ -684,24 +693,18 @@ class OnnxStub:
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "RoPE":
-                tensors[node.output[0]]= self.handler.RoPE(
+                tensors[node.output[0]] = self.handler.RoPE(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Split":
                 split = (
-                    _parse_data(data[node.input[1]])
-                    if (len(node.input) > 1)
-                    else None
+                    _parse_data(data[node.input[1]]) if (len(node.input) > 1) else None
                 )
                 if split is None:
                     split = next(
-                        (
-                            attr.ints
-                            for attr in node.attribute
-                            if attr.name == "split"
-                        ),
+                        (attr.ints for attr in node.attribute if attr.name == "split"),
                         None,
                     )
                 for name, tensor in zip(
@@ -710,11 +713,7 @@ class OnnxStub:
                     tensors[node.input[0]],
                     None,
                     next(
-                        (
-                            attr.i
-                            for attr in node.attribute
-                            if attr.name == "axis"
-                        ),
+                        (attr.i for attr in node.attribute if attr.name == "axis"),
                         0,
                     ),
                     split if split is not None else len(node.output),
@@ -767,12 +766,16 @@ class OnnxStub:
                     tensors.get(node.output[0]),
                     clamp(_parse_data(data[node.input[1]])),
                     clamp(_parse_data(data[node.input[2]])),
-                    clamp(_parse_data(data[node.input[3]]))
-                    if len(node.input) > 3
-                    else None,
-                    clamp(_parse_data(data[node.input[4]]))
-                    if len(node.input) > 4
-                    else None,
+                    (
+                        clamp(_parse_data(data[node.input[3]]))
+                        if len(node.input) > 3
+                        else None
+                    ),
+                    (
+                        clamp(_parse_data(data[node.input[4]]))
+                        if len(node.input) > 4
+                        else None
+                    ),
                 )
             elif node.op_type == "Pad":
                 tensors[node.output[0]] = self.handler.pad(
@@ -788,12 +791,16 @@ class OnnxStub:
                         tensors[node.input[0]],
                         tensors.get(node.output[0]),
                         tensors.get(node.output[1]) if len(node.output) > 1 else None,
-                        _parse_data(data[node.input[1]])[0]
-                        if len(node.input) > 1
-                        else 0.5,
-                        _parse_data(data[node.input[2]])[0]
-                        if len(node.input) > 2
-                        else False,
+                        (
+                            _parse_data(data[node.input[1]])[0]
+                            if len(node.input) > 1
+                            else 0.5
+                        ),
+                        (
+                            _parse_data(data[node.input[2]])[0]
+                            if len(node.input) > 2
+                            else False
+                        ),
                     ),
                 ):
                     tensors[name] = tensor
@@ -942,18 +949,25 @@ class OnnxStub:
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Where":
                 ## If Y is single -inf, treat Where as Add
                 ## TODO: deal with cases where Y is single inf or 0
                 if node.input[0] in data and node.input[2] in data:
                     where_condition = to_array(data[node.input[0]])
                     where_alt = to_array(data[node.input[2]])
                     if where_alt.size == 1:
                         if np.isneginf(where_alt) or np.all(where_alt < -3e38):
                             node.input[0] = node.input[0] + "_alt"
                             if node.input[0] not in data:
-                                where_value = np.where(where_condition, 0, -np.inf).astype(where_alt.dtype)
-                                data[node.input[0]] = from_array(where_value, node.input[0])
-                                tensors[node.input[0]] = self.handler.tensor(list(where_value.shape), data[node.input[0]].data_type)
+                                where_value = np.where(
+                                    where_condition, 0, -np.inf
+                                ).astype(where_alt.dtype)
+                                data[node.input[0]] = from_array(
+                                    where_value, node.input[0]
+                                )
+                                tensors[node.input[0]] = self.handler.tensor(
+                                    list(where_value.shape),
+                                    data[node.input[0]].data_type,
+                                )
                                 tensors[node.input[0]].set_weight()
                 tensors[node.output[0]] = self.handler.add(
                     tensors[node.input[1]],
@@ -980,8 +994,7 @@ class OnnxStub:
                     node, {"alpha": 0.0001, "beta": 0.75, "bias": 1.0, "size": 1}
                 )
                 (alpha, beta, bias, size) = (
-                    attributes[name]
-                    for name in ["alpha", "beta", "bias", "size"]
+                    attributes[name] for name in ["alpha", "beta", "bias", "size"]
                 )
                 tensors[node.output[0]] = self.handler.lrn(
                     tensors[node.input[0]],
@@ -222,6 +222,15 @@ Tensor GraphHandlerObj::pRelu(Tensor x, Tensor slope, Tensor y) {
     }
 }

+Tensor GraphHandlerObj::leakyRelu(Tensor x, Tensor y, float alpha) {
+    if (y) {
+        g->addOpWithOutputs<LeakyReluObj>(std::move(x), y, alpha);
+        return y;
+    } else {
+        return g->addOp<LeakyReluObj>(std::move(x), y, alpha)->getOutput();
+    }
+}
+
 Tensor GraphHandlerObj::clip(Tensor x, Tensor y, std::optional<float> min,
                              std::optional<float> max) {
     if (y) {
@@ -562,6 +562,7 @@ void init_graph_builder(py::module &m) {
         .def("expand", &Handler::expand, policy::move)
         .def("erf", &Handler::erf, policy::move)
         .def("where", &Handler::where, policy::move)
+        .def("leakyRelu", &Handler::leakyRelu, policy::move)
         .def("lrn", &Handler::lrn, policy::move)
         .def("topo_sort", &Handler::topo_sort, policy::automatic)
         .def("optimize", &Handler::optimize, policy::automatic)
@@ -241,8 +241,50 @@ class HardSigmoidCnnl : public UnaryCnnl {
     float getScale() const override { return 0.5f; }
 };

+class LeakyReluCnnl : public BangKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<LeakyReluObj>(_op);
+        auto context = dynamic_cast<const BangRuntimeObj *>(_context);
+
+        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const cData = (op->getOutput()->getRawDataPtr<void *>());
+
+        cnnlTensorDescriptor_t aDesc, cDesc;
+        auto aDim = op->getInputs(0)->getDims();
+        auto cDim = op->getOutput()->getDims();
+        auto coef = op->getAlpha();
+
+        checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
+        checkCnnlError(cnnlSetTensorDescriptor(
+            aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
+            aDim.size(), aDim.data()));
+        checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
+        checkCnnlError(cnnlSetTensorDescriptor(
+            cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
+            cDim.size(), cDim.data()));
+        cnnlActivationDescriptor_t opDesc;
+        checkCnnlError(cnnlCreateActivationDescriptor(&opDesc));
+        checkCnnlError(cnnlSetActivationDescriptor_v5(
+            opDesc, CNNL_ACTIVATION_LEAKYRELU, CNNL_ACTIVATION_HIGH_PRECISION,
+            CNNL_NOT_PROPAGATE_NAN, coef, 0.0, 0.0, 0.0, true));
+
+        float alpha = 1.f, beta = 0.f;
+        cnnlStatus_t stat =
+            cnnlActivationForward(context->cnnlHandle(), opDesc, &alpha, aDesc,
+                                  aData, &beta, cDesc, cData);
+        if (stat != CNNL_STATUS_SUCCESS)
+            return;
+        checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
+        checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
+        checkCnnlError(cnnlDestroyActivationDescriptor(opDesc));
+    }
+};
+
 REGISTER_KERNEL(Device::BANG, OpType::Relu, ReluCnnl, "Relu_cnnl_BANG");
 REGISTER_KERNEL(Device::BANG, OpType::PRelu, PReluCnnl, "PRelu_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::LeakyRelu, LeakyReluCnnl,
+                "LeakyRelu_cnnl_BANG");
 REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, SigmoidCnnl,
                 "Sigmoid_cnnl_BANG");
 REGISTER_KERNEL(Device::BANG, OpType::Round, RoundCnnl, "Round_cnnl_BANG");
@@ -16,10 +16,14 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
         void *const bias = (op->getInputs(4)->getRawDataPtr<void *>());
         void *const output = (op->getOutput()->getRawDataPtr<void *>());

-        auto dims = op->getInputs(0)->getDims();
-        auto outDims = op->getOutput()->getDims();
-        if (dims.size() != 4)
-            IT_TODO_HALT();
+        auto padDims = [](Shape shape) {
+            for (size_t i = shape.size(); i < 4; ++i) {
+                shape.push_back(1);
+            }
+            return shape;
+        };
+        auto dims = padDims(op->getInputs(0)->getDims());
+        auto outDims = padDims(op->getOutput()->getDims());

         int dimsTrans[4] = {dims[0], dims[2], dims[3], dims[1]};
         int dimsOutTrans[4] = {outDims[0], outDims[2], outDims[3], outDims[1]};
@@ -0,0 +1,144 @@
+#include "operators/resize.h"
+#include "bang/bang_kernel_without_config.h"
+#include "bang/bang_runtime.h"
+#include <iostream>
+
+namespace infini {
+class ResizeCnnl : public BangKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<ResizeObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
+        auto context = dynamic_cast<const BangRuntimeObj *>(_context);
+
+        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const cData = (op->getOutput()->getRawDataPtr<void *>());
+
+        auto nDims = op->getInputs(0)->getRank();
+        if (nDims != 4) {
+            IT_TODO_HALT();
+        }
+        auto aDim = op->getInputs(0)->getDims();
+        auto cDim = op->getOutput()->getDims();
+        std::vector<int> aTransDim = {aDim[0], aDim[2], aDim[3], aDim[1]};
+        std::vector<int> cTransDim = {cDim[0], cDim[2], cDim[3], cDim[1]};
+
+        cnnlTensorDescriptor_t aDesc, cDesc, aTransDesc, cTransDesc;
+        // input
+        checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
+        checkCnnlError(cnnlSetTensorDescriptor(
+            aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
+            aDim.size(), aDim.data()));
+        checkCnnlError(cnnlCreateTensorDescriptor(&aTransDesc));
+        checkCnnlError(cnnlSetTensorDescriptor(
+            aTransDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()),
+            aTransDim.size(), aTransDim.data()));
+        // output
+        checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
+        checkCnnlError(cnnlSetTensorDescriptor(
+            cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
+            cDim.size(), cDim.data()));
+        checkCnnlError(cnnlCreateTensorDescriptor(&cTransDesc));
+        checkCnnlError(cnnlSetTensorDescriptor(
+            cTransDesc, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(op->getDType()),
+            cTransDim.size(), cTransDim.data()));
+
+        // transpose
+        BangPtr aTransData = context->getWorkspace(
+            cnnlGetTensorElementNum(aTransDesc) * op->getDType().getSize());
+        BangPtr cTransData = context->getWorkspace(
+            cnnlGetTensorElementNum(cTransDesc) * op->getDType().getSize());
+
+        int permuteIn[4] = {0, 2, 3, 1};
+        cnnlTransposeDescriptor_t inDesc;
+        checkCnnlError(cnnlCreateTransposeDescriptor(&inDesc));
+        checkCnnlError(cnnlSetTransposeDescriptor(inDesc, 4, permuteIn));
+        size_t wsSizeIn;
+        cnnlGetTransposeWorkspaceSize(context->cnnlHandle(), aDesc, inDesc,
+                                      &wsSizeIn);
+        BangPtr wsDataIn = context->getWorkspace(wsSizeIn);
+
+        checkCnnlError(cnnlTranspose_v2(context->cnnlHandle(), inDesc, aDesc,
+                                        aData, aTransDesc, aTransData, wsDataIn,
+                                        wsSizeIn));
+
+        cnnlTensorDescriptor_t boxesDesc, boxesIndexDesc;
+        checkCnnlError(cnnlCreateTensorDescriptor(&boxesDesc));
+        auto nBatch = aDim[0];
+        std::vector<int> boxesDim = {nBatch, 4};
+        checkCnnlError(cnnlSetTensorDescriptor(
+            boxesDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
+            boxesDim.size(), boxesDim.data()));
+
+        checkCnnlError(cnnlCreateTensorDescriptor(&boxesIndexDesc));
+        std::vector<int> boxesIndexDim = {nBatch};
+        checkCnnlError(cnnlSetTensorDescriptor(
+            boxesIndexDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_INT32,
+            boxesIndexDim.size(), boxesIndexDim.data()));
+        std::vector<int32_t> boxesIndex(nBatch);
+        std::iota(boxesIndex.begin(), boxesIndex.end(), 0);
+        BangPtr boxesIndexData =
+            context->getWorkspace(nBatch * sizeof(int32_t));
+        context->copyBlobFromCPU(boxesIndexData, boxesIndex.data(),
+                                 nBatch * sizeof(int32_t));
+
+        cnnlCropAndResizeMode_t mode;
+        auto coefMode = op->getMode();
+        if (coefMode == ResizeObj::ECoeffMode::nearest) {
+            // CNNL uses round by default and
+            // does not support other nearest modes
+            mode = CNNL_CROP_AND_RESIZE_NEAREST;
+        } else if (coefMode == ResizeObj::ECoeffMode::linear) {
+            mode = CNNL_CROP_AND_RESIZE_BILINEAR;
+        } else {
+            IT_TODO_HALT();
+        }
+
+        std::vector<float> box;
+        auto transMode = op->getCoordinateTransMode();
+        if (transMode ==
+            enum_to_underlying(
+                ResizeObj::ECoordinateTransMode::tfCropAndResize)) {
+            box = {op->getRoi(2), op->getRoi(3), op->getRoi(6), op->getRoi(7)};
+        } else {
+            box = {0, 0, 1.0, 1.0};
+        }
+
+        BangPtr boxesData =
+            context->getWorkspace(nBatch * box.size() * sizeof(float));
+        for (auto i = 0; i < nBatch; i++) {
+            context->copyBlobFromCPU(boxesData + i * box.size() * sizeof(float),
+                                     box.data(), box.size() * sizeof(float));
+        }
+
+        checkCnnlError(cnnlCropAndResize(
+            context->cnnlHandle(), aTransDesc, aTransData, boxesDesc, boxesData,
+            boxesIndexDesc, boxesIndexData, mode, 0.0, cTransDesc, cTransData));
+
+        // transpose
+        int permuteOut[4] = {0, 3, 1, 2};
+        cnnlTransposeDescriptor_t outDesc;
+        checkCnnlError(cnnlCreateTransposeDescriptor(&outDesc));
+        checkCnnlError(cnnlSetTransposeDescriptor(outDesc, 4, permuteOut));
+        size_t wsSizeOut;
+        cnnlGetTransposeWorkspaceSize(context->cnnlHandle(), cTransDesc,
+                                      outDesc, &wsSizeOut);
+        BangPtr wsDataOut = context->getWorkspace(wsSizeOut);
+
+        checkCnnlError(cnnlTranspose_v2(context->cnnlHandle(), outDesc,
+                                        cTransDesc, cTransData, cDesc, cData,
+                                        wsDataOut, wsSizeOut));
+
+        checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
+        checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
+        checkCnnlError(cnnlDestroyTensorDescriptor(aTransDesc));
+        checkCnnlError(cnnlDestroyTensorDescriptor(cTransDesc));
+        checkCnnlError(cnnlDestroyTensorDescriptor(boxesDesc));
+        checkCnnlError(cnnlDestroyTensorDescriptor(boxesIndexDesc));
+        checkCnnlError(cnnlDestroyTransposeDescriptor(inDesc));
+        checkCnnlError(cnnlDestroyTransposeDescriptor(outDesc));
+    }
+};
+
+REGISTER_KERNEL(Device::BANG, OpType::Resize, ResizeCnnl, "Resize_cnnl_BANG");
+}; // namespace infini
@@ -18,9 +18,16 @@ class BatchNormCudnn : public CudaKernelWithoutConfig {
         void *const scaleData = (op->getInputs(3)->getRawDataPtr<void *>());
         void *const biasData = (op->getInputs(4)->getRawDataPtr<void *>());

-        auto dims = op->getInputs(0)->getDims();
         // Only 4D and 5D tensors are supported by
         // cudnnBatchNormalizationForwardInference
+        if (auto dims = op->getInputs(0)->getDims(); dims.size() < 4) {
+            auto dims_t = dims;
+            for (size_t i = dims_t.size(); i < 4; ++i) {
+                dims_t.push_back(1);
+            }
+            op->getInputs(0)->setShape(dims_t);
+        }
+        auto dims = op->getInputs(0)->getDims();
         IT_ASSERT(dims.size() == 4);

         int dimArray[4], strideArray[4], dimPArray[4], stridePArray[4];
@@ -173,6 +173,25 @@ class TanhCudnn : public ActivationCudnn {
     }
 };

+class LeakyReluCuda : public CudaKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+
+        auto op = as<LeakyReluObj>(_op);
+        auto alpha = op->getAlpha();
+        size_t num = op->getOutput()->size();
+        void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
+
+        if (op->getDType() == DataType::Float32) {
+            leaky_relu_kernel<float>((float *)inputData, (float *)outputData,
+                                     num, alpha);
+        } else {
+            IT_TODO_HALT();
+        }
+    }
+};
+
 REGISTER_KERNEL(Device::CUDA, OpType::Relu, ReluCudnn, "Relu_CUDA");
 REGISTER_KERNEL(Device::CUDA, OpType::Sigmoid, SigmoidCudnn, "Sigmoid_CUDA");
 REGISTER_KERNEL(Device::CUDA, OpType::HardSigmoid, UnaryCuda,
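The CUDA kernel above, the BANG kernel earlier, and the KUNLUN kernel further down all implement the same element-wise rule: y = x when x > 0, and y = alpha·x otherwise. A small NumPy oracle along these lines can be handy when eyeballing backend output in tests (illustrative only, not part of the change):

```python
import numpy as np

def leaky_relu_ref(x: np.ndarray, alpha: float = 0.01) -> np.ndarray:
    # y = x where x > 0, alpha * x elsewhere
    return np.where(x > 0, x, alpha * x)

x = np.array([-2.0, -0.5, 0.0, 1.5], dtype=np.float32)
print(leaky_relu_ref(x, alpha=0.1))  # [-0.2  -0.05  0.    1.5 ]
```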
@@ -185,11 +204,14 @@ REGISTER_KERNEL(Device::CUDA, OpType::Gelu, UnaryCuda, "Gelu_CUDA");
 REGISTER_KERNEL(Device::CUDA, OpType::Silu, UnaryCuda, "Silu_CUDA");
 REGISTER_KERNEL(Device::CUDA, OpType::Neg, UnaryCuda, "Neg_CUDA");
 REGISTER_KERNEL(Device::CUDA, OpType::Erf, UnaryCuda, "Erf_CUDA");
+REGISTER_KERNEL(Device::CUDA, OpType::LeakyRelu, LeakyReluCuda,
+                "LeakyRelu_CUDA");
+
 REGISTER_KERNEL(Device::CUDA, OpType::Cast, CastCuda, "Cast_CUDA");

-// REGISTER_KERNEL(Device::CUDA, OpType::Softmax, UnaryCuda, "Softmax_CUDA");
-// REGISTER_KERNEL(Device::CUDA, OpType::Relu, UnaryCuda,
+// REGISTER_KERNEL(Device::CUDA, OpType::Softmax, UnaryCuda,
+// "Softmax_CUDA"); REGISTER_KERNEL(Device::CUDA, OpType::Relu,
+// UnaryCuda,
 //                 "Relu_CUDA");
 // REGISTER_KERNEL(Device::CUDA, OpType::Sigmoid, UnaryCuda,
 //                 "Sigmoid_CUDA");
@@ -110,7 +110,8 @@ __global__ void _silu_kernel(T *input, T *output, size_t n) {
     int stride = blockDim.x * gridDim.x;
     for (int i = index; i < n; i += stride) {
         float x = input[i];
-        output[i] = x / (1.0 + expf(-x));;
+        output[i] = x / (1.0 + expf(-x));
+        ;
     }
 }
@@ -143,33 +144,40 @@ __global__ void _cast_kernel(INPUT *input, OUTPUT *output, size_t n) {
     }
 }

+template <typename T>
+__global__ void _leaky_relu_kernel(T *input, T *output, size_t n, float alpha) {
+    size_t index = threadIdx.x + blockIdx.x * blockDim.x;
+    size_t stride = blockDim.x * gridDim.x;
+    for (size_t i = index; i < n; i += stride) {
+        output[i] = input[i] > 0 ? input[i] : alpha * input[i];
+    }
+}
+
 namespace infini {
 template <typename T> void softmax_kernel(T *input, T *output, size_t num) {

     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _softmax_kernel1<T>
-        <<<1, 1, 0, CUDAStream::getCurrentStream()>>>
-        (input, output, num);
+        <<<1, 1, 0, CUDAStream::getCurrentStream()>>>(input, output, num);
     _softmax_kernel2<T>
-        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
-        (input, output, num);
+        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
+            input, output, num);
 }
 template <typename T> void relu_kernel(T *input, T *output, size_t num) {

     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
-    _relu_kernel<T>
-        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
-        (input, output, num);
+    _relu_kernel<T><<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
+        input, output, num);
 }
 template <typename T> void sigmoid_kernel(T *input, T *output, size_t num) {

     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _sigmoid_kernel<T>
-        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
-        (input, output, num);
+        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
+            input, output, num);
 }
 template <typename T>
 void hard_sigmoid_kernel(T *input, T *output, size_t num) {
@@ -177,75 +185,78 @@ void hard_sigmoid_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _hard_sigmoid_kernel<T>
-        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
-        (input, output, num);
+        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
+            input, output, num);
 }
 template <typename T> void hard_swish_kernel(T *input, T *output, size_t num) {

     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _hard_swish_kernel<T>
-        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
-        (input, output, num);
+        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
+            input, output, num);
 }
 template <typename T> void tanh_kernel(T *input, T *output, size_t num) {

     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
-    _tanh_kernel<T>
-        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
-        (input, output, num);
+    _tanh_kernel<T><<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
+        input, output, num);
 }
 template <typename T> void abs_kernel(T *input, T *output, size_t num) {

     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
-    _abs_kernel<T>
-        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
-        (input, output, num);
+    _abs_kernel<T><<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
+        input, output, num);
 }
 template <typename T> void sqrt_kernel(T *input, T *output, size_t num) {

     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
-    _sqrt_kernel
-        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
-        ((T *)input, (T *)output, num);
+    _sqrt_kernel<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
+        (T *)input, (T *)output, num);
 }

 template <typename T> void gelu_kernel(T *input, T *output, size_t num) {

     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
-    _gelu_kernel<T>
-        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
-        (input, output, num);
+    _gelu_kernel<T><<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
+        input, output, num);
 }

 template <typename T> void silu_kernel(T *input, T *output, size_t num) {

     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
-    _silu_kernel<T>
-        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
-        (input, output, num);
+    _silu_kernel<T><<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
+        input, output, num);
 }

 template <typename T> void erf_kernel(T *input, T *output, size_t num) {

     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
-    _erf_kernel<T>
-        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
-        (input, output, num);
+    _erf_kernel<T><<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
+        input, output, num);
 }
 template <typename T> void neg_kernel(T *input, T *output, size_t num) {

     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
-    _neg_kernel<T>
-        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
-        (input, output, num);
+    _neg_kernel<T><<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
+        input, output, num);
 }
+
+template <typename T>
+void leaky_relu_kernel(T *input, T *output, size_t num, float alpha) {
+
+    int blocksize = block_work_size();
+    int gridsize = (num + block_work_size() - 1) / block_work_size();
+    _leaky_relu_kernel<T>
+        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
+            input, output, num, alpha);
+}

 void unary_kernel(const Operator &_op) {
@@ -315,7 +326,7 @@ void unary_kernel(const Operator &_op) {
     } else if (op->getOpType() == OpType::Silu) {
         if (_op->getDType() == DataType::Float32) {
             silu_kernel<float>((float *)inputData, (float *)outputData, num);
-        } else if (_op->getDType() == DataType::Float16){
+        } else if (_op->getDType() == DataType::Float16) {
             silu_kernel<half>((half *)inputData, (half *)outputData, num);
         } else {
             IT_TODO_HALT();
@@ -346,8 +357,8 @@ void cast_kernel(INPUT *input, OUTPUT *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _cast_kernel<INPUT, OUTPUT>
-        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
-        (input, output, num);
+        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
+            input, output, num);
 }

 template void cast_kernel<float, half>(float *input, half *output, size_t num);
@@ -359,4 +370,6 @@ template void cast_kernel<float, int8_t>(float *input, int8_t *output,
 template void cast_kernel<int8_t, float>(int8_t *input, float *output,
                                          size_t num);

+template void leaky_relu_kernel<float>(float *input, float *output, size_t num,
+                                       float alpha);
 }; // namespace infini
@@ -19,13 +19,17 @@ class BatchNormXdnn : public KUNLUNKernelWithoutConfig {

         auto dims = op->getInputs(0)->getDims();

-        if (dims.size() != 4)
-            IT_TODO_HALT();
-
-        int w = dims[3];
-        int h = dims[2];
-        int c = dims[1];
-        int n = dims[0];
+        int n, c, h, w;
+        if (dims.size() != 4) {
+            h = 1;
+            w = 1;
+        }
+
+        w = dims[3];
+        h = dims[2];
+        c = dims[1];
+        n = dims[0];
+
         auto ret = xdnn::batch_norm_infer<float>(
             context->KUNLUNHandle(), (float *)input, (float *)output, n, c, h,
             w, op->getEps(), (float *)scale, (float *)bias, (float *)mean,
@@ -0,0 +1,45 @@
+#include "operators/layer_norm.h"
+#include "kunlun/kunlun_kernel_without_config.h"
+#include "kunlun/kunlun_runtime.h"
+
+namespace infini {
+class LayerNormXdnn : public KUNLUNKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<LayerNormObj>(_op);
+        auto context = static_cast<const KUNLUNRuntimeObj *>(_context);
+
+        void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const scaleData = (op->getInputs(1)->getRawDataPtr<void *>());
+        void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
+
+        float eps = op->getEps();
+        // int axis = op->getAxis();
+
+        const auto &opInputShape = op->getInputs(0)->getDims();
+        const auto &opOutputShape = op->getOutput()->getDims();
+        IT_ASSERT(opInputShape.size() == 2);
+
+        int ret;
+        if (op->numInputs() == 3) {
+            // with bias
+            void *const biasData = op->getInputs(2)->getRawDataPtr<void *>();
+            ret = xdnn::layer_norm<float, float>(
+                context->KUNLUNHandle(), (float const *)inputData,
+                (float *)outputData, opInputShape[0], opInputShape[1], eps,
+                (float *)scaleData, (float *)biasData, nullptr, nullptr);
+        } else {
+            // without bias
+            ret = xdnn::layer_norm<float, float>(
+                context->KUNLUNHandle(), (float const *)inputData,
+                (float *)outputData, opInputShape[0], opInputShape[1], eps,
+                (float *)scaleData, nullptr, nullptr, nullptr);
+        }
+        assert(ret == 0);
+    }
+};
+
+REGISTER_KERNEL(Device::KUNLUN, OpType::LayerNormalization, LayerNormXdnn,
+                "LayerNorm_xdnn_KUNLUN");
+
+}; // namespace infini
@@ -21,6 +21,26 @@ class ReluXdnn : public KUNLUNKernelWithoutConfig {
     }
 };

+class LeakyReluXdnn : public KUNLUNKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<LeakyReluObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
+        auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
+
+        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const cData = (op->getOutput()->getRawDataPtr<void *>());
+        auto len = op->getInputs(0)->size();
+        auto alpha = op->getAlpha();
+
+        auto ret = xdnn::leaky_relu<float>(context->KUNLUNHandle(),
+                                           (float *const)aData, (float *)cData,
+                                           len, alpha);
+        assert(ret == 0);
+        return;
+    }
+};
+
 class SigmoidXdnn : public KUNLUNKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
@@ -552,6 +572,8 @@ class ATanhXdnn : public KUNLUNKernelWithoutConfig {
 };

 REGISTER_KERNEL(Device::KUNLUN, OpType::Relu, ReluXdnn, "Relu_xdnn_KUNLUN");
+REGISTER_KERNEL(Device::KUNLUN, OpType::LeakyRelu, LeakyReluXdnn,
+                "LeakyRelu_xdnn_KUNLUN");
 REGISTER_KERNEL(Device::KUNLUN, OpType::Sigmoid, SigmoidXdnn,
                 "Sigmoid_xdnn_KUNLUN");
 REGISTER_KERNEL(Device::KUNLUN, OpType::Tanh, TanhXdnn, "Tanh_xdnn_KUNLUN");
@@ -283,6 +283,38 @@ vector<int> PReluObj::getWorkloadVector() const {

 vector<int> PReluObj::getOpAttrVector() const { return {type.underlying()}; }

+LeakyReluObj::LeakyReluObj(GraphObj *graph, Tensor input, Tensor output,
+                           float alpha)
+    : OperatorObj(OpType::LeakyRelu, {input}, {output}), alphaValue(alpha) {
+    IT_ASSERT(checkValid(graph));
+}
+
+optional<vector<Shape>> LeakyReluObj::inferShape(const TensorVec &inputs) {
+    const auto A = inputs[0];
+    return {{A->getDims()}};
+}
+
+std::string LeakyReluObj::toString() const {
+    std::ostringstream os;
+    os << type.toString() << "[" << getGuid() << "]";
+    os << "(";
+    os << vecToString(inputs[0]->getDims()) << ",";
+    os << "input=" << inputs[0]->getGuid() << ",";
+    os << "output=" << outputs[0]->getGuid() << ")";
+    return os.str();
+}
+
+vector<int> LeakyReluObj::getWorkloadVector() const {
+    vector<int> ret{type.underlying()};
+    const Shape shape = outputs[0]->getDims();
+    ret.insert(ret.end(), shape.begin(), shape.end());
+    return ret;
+}
+
+vector<int> LeakyReluObj::getOpAttrVector() const {
+    return {type.underlying()};
+}
+
 LogObj::LogObj(GraphObj *graph, Tensor input, Tensor output, LogType type)
     : OperatorObj(OpType::Log, {input}, {output}), logType(type) {
     IT_ASSERT(checkValid(graph));
@@ -0,0 +1,65 @@
+#include "bang/bang_runtime.h"
+#include "cmath"
+#include "core/graph.h"
+#include "core/runtime.h"
+#include "operators/resize.h"
+#include "test.h"
+namespace infini {
+TEST(Resize, Bang_downsample_sizes_nearest) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    Graph gCpu = make_ref<GraphObj>(runtime);
+
+    auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32);
+    auto scales = gCpu->addTensor({4}, DataType::Float32);
+    gCpu->dataMalloc();
+    input->copyin(vector<float>{1, 2, 3, 4, 5, 6, 7, 8});
+    scales->copyin(vector<float>{1, 1, 0.6, 0.6});
+
+    auto bangRuntime = make_ref<BangRuntimeObj>();
+    Graph gMlu = make_ref<GraphObj>(bangRuntime);
+
+    auto inputMlu = gMlu->cloneTensor(input);
+    auto scalesMlu = gMlu->cloneTensor(scales);
+    auto op = gMlu->addOp<ResizeObj>(inputMlu, nullptr, std::nullopt, nullptr,
+                                     scalesMlu, nullptr);
+    gMlu->dataMalloc();
+    inputMlu->copyin(vector<float>{1, 2, 3, 4, 5, 6, 7, 8});
+    scalesMlu->copyin(vector<float>{1, 1, 0.6, 0.6});
+
+    bangRuntime->run(gMlu);
+
+    // copy output from CUDA to CPU
+    auto oCpu = gCpu->cloneTensor(op->getOutput(0));
+    EXPECT_TRUE(oCpu->equalData(vector<float>{5, 8}));
+}
+
+TEST(Resize, Bang_upsample_sizes_nearest) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    Graph gCpu = make_ref<GraphObj>(runtime);
+
+    auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32);
+    auto scales = gCpu->addTensor({4}, DataType::Float32);
+    gCpu->dataMalloc();
+    input->copyin(vector<float>{1, 2, 3, 4});
+    scales->copyin(vector<float>{1, 1, 2, 3});
+
+    auto bangRuntime = make_ref<BangRuntimeObj>();
+    Graph gMlu = make_ref<GraphObj>(bangRuntime);
+
+    auto inputMlu = gMlu->cloneTensor(input);
+    auto scalesMlu = gMlu->cloneTensor(scales);
+    auto op = gMlu->addOp<ResizeObj>(inputMlu, nullptr, std::nullopt, nullptr,
+                                     scalesMlu, nullptr);
+    gMlu->dataMalloc();
+    inputMlu->copyin(vector<float>{1, 2, 3, 4});
+    scalesMlu->copyin(vector<float>{1, 1, 2, 3});
+
+    bangRuntime->run(gMlu);
+
+    // copy output from CUDA to CPU
+    auto oCpu = gCpu->cloneTensor(op->getOutput(0));
+    EXPECT_TRUE(
+        oCpu->equalData(vector<float>{1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2,
+                                      3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4}));
+}
+} // namespace infini