forked from jiuyuan/InfiniTensor
feat: support some ops
This commit is contained in:
parent
4fdcf97e32
commit
fca88cc6ba
|
@ -64,11 +64,6 @@ void addOperatorFromGraphTopo(
|
|||
std::vector<refactor::computation::Edge> edges) {
|
||||
std::string name(nodeInfo->opType.name());
|
||||
auto attr = nodeInfo->attributes;
|
||||
#define ELSE_IF(op) \
|
||||
else if (name == "onnx::op") { \
|
||||
g.addOpWithOutputs<op##Obj>(edgeToTensor[input[0]], \
|
||||
edgeToTensor[output[0]]); \
|
||||
}
|
||||
if (name == "onnx::Conv") {
|
||||
// auto p = attr["pads"].ints();
|
||||
// auto s = attr["strides"].ints();
|
||||
|
@ -247,12 +242,44 @@ void addOperatorFromGraphTopo(
|
|||
}
|
||||
g.addOpWithOutputs<TransposeObj>(edgeToTensor[input[0]],
|
||||
edgeToTensor[output[0]], perm);
|
||||
}
|
||||
ELSE_IF(Relu)
|
||||
ELSE_IF(Sqrt)
|
||||
ELSE_IF(Identity)
|
||||
|
||||
#undef ELSE_IF
|
||||
} else if (name == "onnx::Split") {
|
||||
auto axis = attr.find("axis") != attr.end() ? attr["axis"].int_() : 0;
|
||||
std::vector<Tensor> outputs;
|
||||
for (auto i : output) {
|
||||
outputs.emplace_back(edgeToTensor[i]);
|
||||
}
|
||||
int num = output.size();
|
||||
if (input.size() == 2) {
|
||||
auto ratioValue = reinterpret_cast<int64_t *>(edges[input[1]].tensor->data->ptr);
|
||||
std::vector<int> ratio;
|
||||
auto rank = edgeToTensor[input[1]]->getDims()[0];
|
||||
for (size_t i = 0; i < (size_t)rank; ++i) {
|
||||
ratio.emplace_back(static_cast<int>(*(ratioValue + i)));
|
||||
}
|
||||
g.addOpWithOutputs<SplitObj>(edgeToTensor[input[0]], outputs, axis, ratio);
|
||||
} else {
|
||||
g.addOpWithOutputs<SplitObj>(edgeToTensor[input[0]], outputs, axis, num);
|
||||
}
|
||||
} else if (name == "onnx::Where") {
|
||||
IT_ASSERT(input.size() == 3);
|
||||
g.addOpWithOutputs<WhereObj>(edgeToTensor[input[1]], edgeToTensor[input[2]],
|
||||
edgeToTensor[input[0]], edgeToTensor[output[0]]);
|
||||
} else if (name == "onnx::Softmax") {
|
||||
//auto axis = attr.find("axis") != attr.end() ? attr["axis"].int_() : -1;
|
||||
|
||||
} else if (name == "onnx::Sqrt") {
|
||||
g.addOpWithOutputs<SqrtObj>(edgeToTensor[input[0]],
|
||||
edgeToTensor[output[0]]);
|
||||
} else if (name == "onnx::Relu") {
|
||||
g.addOpWithOutputs<ReluObj>(edgeToTensor[input[0]],
|
||||
edgeToTensor[output[0]]);
|
||||
} else if (name == "onnx::Identity") {
|
||||
g.addOpWithOutputs<IdentityObj>(edgeToTensor[input[0]],
|
||||
edgeToTensor[output[0]]);
|
||||
} else if (name == "onnx::Tanh") {
|
||||
g.addOpWithOutputs<TanhObj>(edgeToTensor[input[0]],
|
||||
edgeToTensor[output[0]]);
|
||||
}
|
||||
}
|
||||
|
||||
void addEdgeToTensor(GraphObj &g, size_t index,
|
||||
|
|
Loading…
Reference in New Issue