modified format

xgqdut2016 2024-04-28 15:02:14 +08:00
parent 4d078967e0
commit e6b98fd652
9 changed files with 60 additions and 94 deletions

View File

@@ -153,5 +153,3 @@ class GraphHandlerObj {
};
} // namespace infini

View File

@@ -268,5 +268,3 @@ enum class ActType {
} // namespace infini
#endif // OP_TYPE_H

View File

@@ -228,9 +228,11 @@ class PReluObj : public OperatorObj {
vector<int> getOpAttrVector() const override;
};
class LeakyReluObj : public OperatorObj {
float alpha;
float alpha;
public:
LeakyReluObj(GraphObj *graph, Tensor input, Tensor output, float alpha = 0.01);
LeakyReluObj(GraphObj *graph, Tensor input, Tensor output,
float alpha = 0.01);
OP_CLONE(LeakyReluObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
@@ -305,4 +307,3 @@ DEFINE_UNARY_OBJ(Reciprocal, OpType::Reciprocal)
DEFINE_UNARY_OBJ(Sqrt, OpType::Sqrt)
DEFINE_UNARY_OBJ(Round, OpType::Round)
}; // namespace infini
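For reference, the LeakyReluObj operator declared above carries a single alpha attribute (defaulting to 0.01) and computes f(x) = x for x >= 0 and alpha * x otherwise. A minimal NumPy sketch of that semantics (illustration only, not code from this commit):

import numpy as np

def leaky_relu(x: np.ndarray, alpha: float = 0.01) -> np.ndarray:
    # Elementwise: keep non-negative values, scale negative ones by alpha.
    return np.where(x >= 0, x, alpha * x)

x = np.array([-6, -5, -4, -3, -2, -1, 1, 2, 3, 4, 5, 6], dtype=np.float32)
print(leaky_relu(x))  # -0.06, -0.05, ..., 6.0 -- matches the expected data in the ASCEND test below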

View File

@@ -112,7 +112,6 @@ class OnnxStub:
)
tensors[input.name].set_input()
for node_idx in sorted_nodes:
node = model.graph.node[node_idx]
if node.op_type == "Conv":
@@ -202,7 +201,7 @@ class OnnxStub:
p = [0, 0, 0, 0]
else:
adapt = node.input[0]
if len(node.input) > 2:
bias = "{}-bias".format(node.output[0])
reshape = "{}-reshape".format(node.output[0])
@@ -253,8 +252,8 @@ class OnnxStub:
)
elif node.op_type == "MatMul":
tensors[node.output[0]] = self.handler.matmul(
tensors[node.input[0]], # input
tensors[node.input[1]], # weight
tensors[node.input[0]], # input
tensors[node.input[1]], # weight
tensors.get(node.output[0]),
False,
False,
@@ -492,12 +491,8 @@ class OnnxStub:
tensors.get(node.output[0]),
)
elif node.op_type == "LeakyRelu":
attributes = _parse_attribute(
node, {"alpha": 0.01}
)
(alpha) = (
attributes[name] for name in ["alpha"]
)
attributes = _parse_attribute(node, {"alpha": 0.01})
(alpha) = (attributes[name] for name in ["alpha"])
tensors[node.output[0]] = self.handler.leakyrelu(
tensors[node.input[0]],
tensors.get(node.output[0]),
@@ -685,18 +680,10 @@ class OnnxStub:
coordinate_transformation_mode,
)
elif node.op_type == "Squeeze":
axes = (
_parse_data(data[node.input[1]])
if len(node.input) > 1
else None
)
axes = _parse_data(data[node.input[1]]) if len(node.input) > 1 else None
if axes is None:
axes = next(
(
attr.ints
for attr in node.attribute
if attr.name == "axes"
),
(attr.ints for attr in node.attribute if attr.name == "axes"),
[],
)
tensors[node.output[0]] = self.handler.squeeze(
@@ -705,18 +692,10 @@ class OnnxStub:
axes,
)
elif node.op_type == "Unsqueeze":
axes = (
_parse_data(data[node.input[1]])
if len(node.input) > 1
else None
)
axes = _parse_data(data[node.input[1]]) if len(node.input) > 1 else None
if axes is None:
axes = next(
(
attr.ints
for attr in node.attribute
if attr.name == "axes"
)
(attr.ints for attr in node.attribute if attr.name == "axes")
)
tensors[node.output[0]] = self.handler.unsqueeze(
tensors[node.input[0]],
@@ -740,24 +719,18 @@ class OnnxStub:
tensors.get(node.output[0]),
)
elif node.op_type == "RoPE":
tensors[node.output[0]]= self.handler.RoPE(
tensors[node.output[0]] = self.handler.RoPE(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
)
elif node.op_type == "Split":
split = (
_parse_data(data[node.input[1]])
if (len(node.input) > 1)
else None
_parse_data(data[node.input[1]]) if (len(node.input) > 1) else None
)
if split is None:
split = next(
(
attr.ints
for attr in node.attribute
if attr.name == "split"
),
(attr.ints for attr in node.attribute if attr.name == "split"),
None,
)
for name, tensor in zip(
@@ -766,11 +739,7 @@ class OnnxStub:
tensors[node.input[0]],
None,
next(
(
attr.i
for attr in node.attribute
if attr.name == "axis"
),
(attr.i for attr in node.attribute if attr.name == "axis"),
0,
),
split if split is not None else len(node.output),
@@ -998,18 +967,25 @@ class OnnxStub:
tensors.get(node.output[0]),
)
elif node.op_type == "Where":
## If Y is single -inf, treat Where as Add
## If Y is single -inf, treat Where as Add
## TODO: deal with cases where Y is single inf or 0
if node.input[0] in data and node.input[2] in data:
where_condition = to_array(data[node.input[0]])
where_alt = to_array(data[node.input[2]])
where_alt = to_array(data[node.input[2]])
if where_alt.size == 1:
if np.isneginf(where_alt) or np.all(where_alt < -3e38):
node.input[0] = node.input[0] + "_alt"
if node.input[0] not in data:
where_value = np.where(where_condition, 0, -np.inf).astype(where_alt.dtype)
data[node.input[0]] = from_array(where_value, node.input[0])
tensors[node.input[0]] = self.handler.tensor(list(where_value.shape), data[node.input[0]].data_type)
where_value = np.where(
where_condition, 0, -np.inf
).astype(where_alt.dtype)
data[node.input[0]] = from_array(
where_value, node.input[0]
)
tensors[node.input[0]] = self.handler.tensor(
list(where_value.shape),
data[node.input[0]].data_type,
)
tensors[node.input[0]].set_weight()
tensors[node.output[0]] = self.handler.add(
tensors[node.input[1]],
@@ -1036,8 +1012,7 @@ class OnnxStub:
node, {"alpha": 0.0001, "beta": 0.75, "bias": 1.0, "size": 1}
)
(alpha, beta, bias, size) = (
attributes[name]
for name in ["alpha", "beta", "bias", "size"]
attributes[name] for name in ["alpha", "beta", "bias", "size"]
)
tensors[node.output[0]] = self.handler.lrn(
tensors[node.input[0]],
@@ -1513,5 +1488,3 @@ def _parse_data_fp16(tensor: TensorProto):
def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
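The Where branch above rewrites Where(condition, X, Y) as an Add when Y is a scalar -inf: it materializes a mask that is 0 where the condition holds and -inf elsewhere, so masked positions still end up at -inf after the addition (a pattern commonly used for attention masks). A small NumPy sketch of the equivalence the frontend relies on (illustrative, with made-up values):

import numpy as np

condition = np.array([[True, False], [True, True]])
x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)

# Original form: Where(condition, x, -inf)
where_out = np.where(condition, x, -np.inf)

# Rewritten form: Add(x, mask) with mask = where(condition, 0, -inf)
mask = np.where(condition, 0, -np.inf).astype(np.float32)
add_out = x + mask

assert np.array_equal(where_out, add_out)  # holds for finite x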

View File

@@ -124,16 +124,12 @@ Tensor GraphHandlerObj::layerNormalization(Tensor input, Tensor scale,
->getOutput();
}
}
Tensor GraphHandlerObj::leakyrelu(Tensor input,
Tensor output, float alpha) {
Tensor GraphHandlerObj::leakyrelu(Tensor input, Tensor output, float alpha) {
if (output) {
g->addOpWithOutputs<LeakyReluObj>(std::move(input),
output, alpha);
g->addOpWithOutputs<LeakyReluObj>(std::move(input), output, alpha);
return output;
} else {
return g
->addOp<LeakyReluObj>(std::move(input), output,
alpha)
return g->addOp<LeakyReluObj>(std::move(input), output, alpha)
->getOutput();
}
}
@@ -775,5 +771,3 @@ void GraphHandlerObj::change_shape(const vector<int> &shape, int tensorId) {
}
} // namespace infini

View File

@@ -616,5 +616,3 @@ PYBIND11_MODULE(backend, m) {
infini::export_functions(m);
infini::init_graph_builder(m);
}

View File

@@ -122,18 +122,19 @@ class LeakyReluAclnn : public ASCENDKernelWithoutConfig {
uint64_t workspaceSize = 0;
aclOpExecutor *executor;
float negativeSlopeValue = op->getAlpha();
aclScalar* negativeSlope = nullptr;
negativeSlope = aclCreateScalar(&negativeSlopeValue, aclDataType::ACL_FLOAT);
auto ret =
aclnnLeakyReluGetWorkspaceSize(input, negativeSlope, output, &workspaceSize, &executor);
aclScalar *negativeSlope = nullptr;
negativeSlope =
aclCreateScalar(&negativeSlopeValue, aclDataType::ACL_FLOAT);
auto ret = aclnnLeakyReluGetWorkspaceSize(input, negativeSlope, output,
&workspaceSize, &executor);
void *workspaceAddr = nullptr;
if (workspaceSize > 0) {
workspaceAddr = context->getWorkspace(workspaceSize);
}
assert(ret == ACL_SUCCESS);
ret = aclnnLeakyRelu(workspaceAddr, workspaceSize, executor,
context->ASCENDHandle());
context->ASCENDHandle());
assert(ret == ACL_SUCCESS);
// aclDestroyTensor(input);
@@ -156,22 +157,22 @@ class LeakyReluAclnn : public ASCENDKernelWithoutConfig {
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); \
void *const cData = (op->getOutput()->getRawDataPtr<void *>()); \
\
auto a = op->getInputs(0) -> getDims(); \
auto a = op->getInputs(0)->getDims(); \
std::vector<int64_t> aDim(a.size(), 1); \
for (size_t i = 0; i < a.size(); ++i) { \
aDim[i] = int64_t(a[i]); \
} \
auto aS = op->getInputs(0) -> getStride(); \
auto aS = op->getInputs(0)->getStride(); \
std::vector<int64_t> aStride(aS.size(), 1); \
for (size_t i = 0; i < aS.size(); ++i) { \
aStride[i] = int64_t(aS[i]); \
} \
auto c = op->getInputs(0) -> getDims(); \
auto c = op->getInputs(0)->getDims(); \
std::vector<int64_t> cDim(c.size(), 1); \
for (size_t i = 0; i < c.size(); ++i) { \
cDim[i] = int64_t(c[i]); \
} \
auto cS = op->getInputs(0) -> getStride(); \
auto cS = op->getInputs(0)->getStride(); \
std::vector<int64_t> cStride(cS.size(), 1); \
for (size_t i = 0; i < cS.size(); ++i) { \
cStride[i] = int64_t(cS[i]); \
@@ -209,7 +210,6 @@ DEFINE_UNARY_Aclnn(Sigmoid);
DEFINE_UNARY_Aclnn(Hardswish);
DEFINE_UNARY_Aclnn(Gelu);
DEFINE_UNARY_Aclnn(Tanh);
DEFINE_UNARY_Aclnn(Sin);
DEFINE_UNARY_Aclnn(Cos);
@@ -252,5 +252,3 @@ REGISTER_KERNEL(Device::ASCEND, OpType::Round, RoundAclnn,
"round_ASCEND_float");
REGISTER_KERNEL(Device::ASCEND, OpType::Erf, ErfAclnn, "erf_ASCEND_float");
}; // namespace infini

View File

@@ -283,9 +283,10 @@ vector<int> PReluObj::getWorkloadVector() const {
vector<int> PReluObj::getOpAttrVector() const { return {type.underlying()}; }
LeakyReluObj::LeakyReluObj(GraphObj *graph, Tensor input, Tensor output, float _alpha)
LeakyReluObj::LeakyReluObj(GraphObj *graph, Tensor input, Tensor output,
float _alpha)
: OperatorObj(OpType::LeakyRelu, {input}, {output}), alpha(_alpha) {
IT_ASSERT(checkValid(graph));
}
@@ -312,7 +313,9 @@ vector<int> LeakyReluObj::getWorkloadVector() const {
return ret;
}
vector<int> LeakyReluObj::getOpAttrVector() const { return {type.underlying()}; }
vector<int> LeakyReluObj::getOpAttrVector() const {
return {type.underlying()};
}
LogObj::LogObj(GraphObj *graph, Tensor input, Tensor output, LogType type)
: OperatorObj(OpType::Log, {input}, {output}), logType(type) {
IT_ASSERT(checkValid(graph));

View File

@@ -40,8 +40,8 @@ void testUnary(const std::function<void(void *, size_t, DataType)> &generator,
}
void testLeakyRelu(const Shape &shape, const vector<float> &inputData,
const vector<float> &ExpectData, float alpha) {
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
const vector<float> &ExpectData, float alpha) {
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
Runtime runtime = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime);
@@ -53,7 +53,7 @@ void testLeakyRelu(const Shape &shape, const vector<float> &inputData,
auto npuRuntime = make_ref<ASCENDRuntimeObj>();
Graph npuGraph = make_ref<GraphObj>(npuRuntime);
// NPU
auto inputNpu = npuGraph->cloneTensor(input);
auto npuOp = npuGraph->addOp<LeakyReluObj>(inputNpu, nullptr, alpha);
npuGraph->dataMalloc();
@@ -61,16 +61,19 @@ void testLeakyRelu(const Shape &shape, const vector<float> &inputData,
npuRuntime->run(npuGraph);
auto outputNpu = npuOp->getOutput();
auto outputNpu2Cpu = outputNpu->clone(cpuRuntime);
// Check
EXPECT_TRUE(outputNpu2Cpu->equalData(ExpectData));
}
TEST(ascend_Unary, run) {
aclInit(nullptr);
testLeakyRelu(Shape{1, 2, 2, 3}, vector<float>{-6, -5, -4, -3, -2, -1, 1, 2, 3, 4, 5, 6}, vector<float>{-0.0600, -0.0500, -0.0400, -0.0300, -0.0200, -0.0100, 1.0000, 2.0000,
3.0000, 4.0000, 5.0000, 6.0000}, 0.01);
testLeakyRelu(Shape{1, 2, 2, 3},
vector<float>{-6, -5, -4, -3, -2, -1, 1, 2, 3, 4, 5, 6},
vector<float>{-0.0600, -0.0500, -0.0400, -0.0300, -0.0200,
-0.0100, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000,
6.0000},
0.01);
testUnary<ReluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testUnary<AbsObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});