feat: support some ops

This commit is contained in:
zhangyunze 2023-09-19 14:12:32 +08:00
parent 4fdcf97e32
commit fca88cc6ba
1 changed files with 38 additions and 11 deletions

View File

@ -64,11 +64,6 @@ void addOperatorFromGraphTopo(
std::vector<refactor::computation::Edge> edges) {
std::string name(nodeInfo->opType.name());
auto attr = nodeInfo->attributes;
#define ELSE_IF(op) \
else if (name == "onnx::op") { \
g.addOpWithOutputs<op##Obj>(edgeToTensor[input[0]], \
edgeToTensor[output[0]]); \
}
if (name == "onnx::Conv") {
// auto p = attr["pads"].ints();
// auto s = attr["strides"].ints();
@ -247,12 +242,44 @@ void addOperatorFromGraphTopo(
}
g.addOpWithOutputs<TransposeObj>(edgeToTensor[input[0]],
edgeToTensor[output[0]], perm);
}
ELSE_IF(Relu)
ELSE_IF(Sqrt)
ELSE_IF(Identity)
#undef ELSE_IF
} else if (name == "onnx::Split") {
auto axis = attr.find("axis") != attr.end() ? attr["axis"].int_() : 0;
std::vector<Tensor> outputs;
for (auto i : output) {
outputs.emplace_back(edgeToTensor[i]);
}
int num = output.size();
if (input.size() == 2) {
auto ratioValue = reinterpret_cast<int64_t *>(edges[input[1]].tensor->data->ptr);
std::vector<int> ratio;
auto rank = edgeToTensor[input[1]]->getDims()[0];
for (size_t i = 0; i < (size_t)rank; ++i) {
ratio.emplace_back(static_cast<int>(*(ratioValue + i)));
}
g.addOpWithOutputs<SplitObj>(edgeToTensor[input[0]], outputs, axis, ratio);
} else {
g.addOpWithOutputs<SplitObj>(edgeToTensor[input[0]], outputs, axis, num);
}
} else if (name == "onnx::Where") {
IT_ASSERT(input.size() == 3);
g.addOpWithOutputs<WhereObj>(edgeToTensor[input[1]], edgeToTensor[input[2]],
edgeToTensor[input[0]], edgeToTensor[output[0]]);
} else if (name == "onnx::Softmax") {
//auto axis = attr.find("axis") != attr.end() ? attr["axis"].int_() : -1;
} else if (name == "onnx::Sqrt") {
g.addOpWithOutputs<SqrtObj>(edgeToTensor[input[0]],
edgeToTensor[output[0]]);
} else if (name == "onnx::Relu") {
g.addOpWithOutputs<ReluObj>(edgeToTensor[input[0]],
edgeToTensor[output[0]]);
} else if (name == "onnx::Identity") {
g.addOpWithOutputs<IdentityObj>(edgeToTensor[input[0]],
edgeToTensor[output[0]]);
} else if (name == "onnx::Tanh") {
g.addOpWithOutputs<TanhObj>(edgeToTensor[input[0]],
edgeToTensor[output[0]]);
}
}
void addEdgeToTensor(GraphObj &g, size_t index,