fix: correct the export of some operators

fix: the frontend needs to keep a backend graph alive so the space allocated on the hardware keeps existing
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster 2023-09-19 14:17:55 +08:00
parent fca88cc6ba
commit 742e876a96
4 changed files with 56 additions and 54 deletions
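The second message points at an object-lifetime problem: the output tensors handed to the frontend only reference memory owned by the backend graph's allocator, so if the frontend keeps just the outputs and lets the backend graph go, the space reserved on the hardware can be released while the outputs still point at it. A minimal self-contained sketch of that failure mode and of the fix, using illustrative types rather than the project's real GraphObj/Tensor API:

#include <memory>
#include <vector>

// Stand-in for the backend graph: it owns the "device" memory pool that all
// of its tensors point into.
struct BackendGraph {
    std::vector<float> pool;
    float *outputPtr() { return pool.data(); }
};

struct Frontend {
    std::shared_ptr<BackendGraph> lastBackend; // mirrors the new `_lastBackend`
    float *output = nullptr;                   // raw view, owns nothing

    void run() {
        auto g = std::make_shared<BackendGraph>();
        g->pool.assign(16, 1.0f); // pretend the kernels wrote results here
        output = g->outputPtr();
        // Without the next line, `g` would be destroyed on return and
        // `output` would dangle; keeping the backend keeps the pool alive.
        lastBackend = std::move(g);
    }
};

The Handler change further down applies this idea by storing the backend graph in a _lastBackend member instead of keeping only the output tensors.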

@@ -114,9 +114,8 @@ class GraphObj : public Object {
     bool checkValid() const;
-    std::vector<Tensor>
-    transformFromGraphTopo(refactor::computation::Graph &graph,
-                           Runtime runtime);
+    void transformFromGraphTopo(refactor::computation::Graph &graph,
+                                Runtime runtime);
   private:
     /**

@@ -350,9 +350,8 @@ bool GraphObj::checkValid() const {
     return true;
 }
-std::vector<Tensor>
-GraphObj::transformFromGraphTopo(refactor::computation::Graph &graph,
-                                 Runtime runtime) {
+void GraphObj::transformFromGraphTopo(refactor::computation::Graph &graph,
+                                      Runtime runtime) {
     // create ops and tensors
     ops.clear();
     tensors.clear();
@@ -392,11 +391,8 @@ GraphObj::transformFromGraphTopo(refactor::computation::Graph &graph,
         }
     }
-    std::vector<Tensor> outputs;
     for (auto edgeIdx : it.globalOutputs()) {
-        auto ptr = edgeToTensor.at(edgeIdx);
-        ptr->setOutput();
-        outputs.push_back(std::move(ptr));
+        edgeToTensor.at(edgeIdx)->setOutput();
     }
     dataMalloc();
@@ -408,8 +404,6 @@ GraphObj::transformFromGraphTopo(refactor::computation::Graph &graph,
             }
         }
     }
-    return outputs;
 }
 } // namespace infini

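With this change, transformFromGraphTopo no longer returns the output tensors: it returns void, marks the global outputs on the graph with setOutput(), and a caller that keeps the graph fetches them later through getOutputs(), as the handler change below does. A rough standalone model of that interface shape, with illustrative Graph/Tensor stand-ins rather than the real classes:

#include <memory>
#include <vector>

struct Tensor {
    bool output = false;
    void setOutput() { output = true; }
};

struct Graph {
    std::vector<std::shared_ptr<Tensor>> tensors;

    // Plays the role of transformFromGraphTopo: build tensors, flag the
    // global outputs in place, and return nothing to the caller.
    void import() {
        auto t = std::make_shared<Tensor>();
        t->setOutput();
        tensors.push_back(std::move(t));
    }

    // Callers that kept the graph can still enumerate the outputs afterwards.
    std::vector<std::shared_ptr<Tensor>> getOutputs() const {
        std::vector<std::shared_ptr<Tensor>> outs;
        for (auto &t : tensors)
            if (t->output)
                outs.push_back(t);
        return outs;
    }
};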

@@ -28,7 +28,7 @@ using Name = std::string;
 class Handler {
     Graph _g;
-    std::vector<infini::Tensor> _outputs;
+    infini::Graph _lastBackend;
   public:
     explicit Handler(Graph g) : _g(std::move(g)) {}
@@ -47,13 +47,13 @@ class Handler {
         using namespace infini;
 #ifdef USE_CUDA
         auto cudaRuntime = make_ref<CudaRuntimeObj>();
-        auto graph = make_ref<GraphObj>(cudaRuntime);
-        _outputs = graph->transformFromGraphTopo(_g, cudaRuntime);
-        graph->getRuntime()->run(graph);
+        _lastBackend = make_ref<GraphObj>(cudaRuntime);
+        _lastBackend->transformFromGraphTopo(_g, cudaRuntime);
+        _lastBackend->getRuntime()->run(_lastBackend);
 #endif
     }
     template <class T> std::vector<T> copyout(size_t i) {
-        return _outputs.at(i)->copyout<T>();
+        return _lastBackend->getOutputs().at(i)->copyout<T>();
     }
 };

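Taken together, the handler now owns the last backend graph it built: that keeps the hardware allocations alive between run() and copyout(), and copyout() reads results straight from the graph's outputs. A compact stand-in for that shape, again with made-up types instead of the real infini API:

#include <cstring>
#include <memory>
#include <vector>

struct FakeTensor {
    std::vector<char> raw; // bytes already copied back to the host
    template <class T> std::vector<T> copyout() const {
        std::vector<T> host(raw.size() / sizeof(T));
        std::memcpy(host.data(), raw.data(), host.size() * sizeof(T));
        return host;
    }
};

struct FakeBackend {
    std::vector<std::shared_ptr<FakeTensor>> outputs;
    const std::vector<std::shared_ptr<FakeTensor>> &getOutputs() const {
        return outputs;
    }
};

struct Handler {
    std::shared_ptr<FakeBackend> lastBackend; // keeps the backend (and its memory) alive
    template <class T> std::vector<T> copyout(size_t i) {
        return lastBackend->getOutputs().at(i)->copyout<T>();
    }
};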

@@ -243,43 +243,52 @@ void addOperatorFromGraphTopo(
         g.addOpWithOutputs<TransposeObj>(edgeToTensor[input[0]],
                                          edgeToTensor[output[0]], perm);
     } else if (name == "onnx::Split") {
-        auto axis = attr.find("axis") != attr.end() ? attr["axis"].int_() : 0;
-        std::vector<Tensor> outputs;
-        for (auto i : output) {
-            outputs.emplace_back(edgeToTensor[i]);
-        }
-        int num = output.size();
-        if (input.size() == 2) {
-            auto ratioValue = reinterpret_cast<int64_t *>(edges[input[1]].tensor->data->ptr);
-            std::vector<int> ratio;
-            auto rank = edgeToTensor[input[1]]->getDims()[0];
-            for (size_t i = 0; i < (size_t)rank; ++i) {
-                ratio.emplace_back(static_cast<int>(*(ratioValue + i)));
-            }
-            g.addOpWithOutputs<SplitObj>(edgeToTensor[input[0]], outputs, axis, ratio);
-        } else {
-            g.addOpWithOutputs<SplitObj>(edgeToTensor[input[0]], outputs, axis, num);
-        }
-    } else if (name == "onnx::Where") {
-        IT_ASSERT(input.size() == 3);
-        g.addOpWithOutputs<WhereObj>(edgeToTensor[input[1]], edgeToTensor[input[2]],
-                                     edgeToTensor[input[0]], edgeToTensor[output[0]]);
-    } else if (name == "onnx::Softmax") {
-        //auto axis = attr.find("axis") != attr.end() ? attr["axis"].int_() : -1;
-    } else if (name == "onnx::Sqrt") {
-        g.addOpWithOutputs<SqrtObj>(edgeToTensor[input[0]],
-                                    edgeToTensor[output[0]]);
-    } else if (name == "onnx::Relu") {
-        g.addOpWithOutputs<ReluObj>(edgeToTensor[input[0]],
-                                    edgeToTensor[output[0]]);
-    } else if (name == "onnx::Identity") {
-        g.addOpWithOutputs<IdentityObj>(edgeToTensor[input[0]],
-                                        edgeToTensor[output[0]]);
-    } else if (name == "onnx::Tanh") {
-        g.addOpWithOutputs<TanhObj>(edgeToTensor[input[0]],
-                                    edgeToTensor[output[0]]);
-    }
+        auto axis = attr.find("axis") != attr.end() ? attr["axis"].int_() : 0;
+        std::vector<Tensor> outputs;
+        for (auto i : output) {
+            outputs.emplace_back(edgeToTensor[i]);
+        }
+        int num = output.size();
+        if (input.size() == 2) {
+            auto ratioValue =
+                reinterpret_cast<int64_t *>(edges[input[1]].tensor->data->ptr);
+            std::vector<int> ratio;
+            auto rank = edgeToTensor[input[1]]->getDims()[0];
+            for (size_t i = 0; i < (size_t)rank; ++i) {
+                ratio.emplace_back(static_cast<int>(*(ratioValue + i)));
+            }
+            g.addOpWithOutputs<SplitObj>(edgeToTensor[input[0]], outputs, axis,
+                                         ratio);
+        } else {
+            g.addOpWithOutputs<SplitObj>(edgeToTensor[input[0]], outputs, axis,
+                                         num);
+        }
+        // } else if (name == "onnx::Where") {
+        //     IT_ASSERT(input.size() == 3);
+        //     g.addOpWithOutputs<WhereObj>(
+        //         edgeToTensor[input[1]], edgeToTensor[input[2]],
+        //         edgeToTensor[input[0]], edgeToTensor[output[0]]);
+        // } else if (name == "onnx::Softmax") {
+        //     // auto axis = attr.find("axis") != attr.end() ?
+        //     attr["axis"].int_() :
+        //     // -1;
+    } else if (name == "onnx::Sqrt") {
+        g.addOpWithOutputs<SqrtObj>(edgeToTensor[input[0]],
+                                    edgeToTensor[output[0]]);
+    } else if (name == "onnx::Relu") {
+        g.addOpWithOutputs<ReluObj>(edgeToTensor[input[0]],
+                                    edgeToTensor[output[0]]);
+    } else if (name == "onnx::Identity") {
+        g.addOpWithOutputs<IdentityObj>(edgeToTensor[input[0]],
+                                        edgeToTensor[output[0]]);
+    } else if (name == "onnx::Tanh") {
+        g.addOpWithOutputs<TanhObj>(edgeToTensor[input[0]],
+                                    edgeToTensor[output[0]]);
+    } else {
+        std::cerr << "Unknown operator: " << name << std::endl;
+        IT_ASSERT_TODO("");
+    }
 }
 void addEdgeToTensor(GraphObj &g, size_t index,
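For reference, the Split branch above follows the ONNX convention: the optional second input lists the size of each output chunk along axis, and when it is absent the axis is divided evenly across the declared number of outputs. A standalone sketch of that rule, using a hypothetical helper that is not part of the repository:

#include <cstdint>
#include <numeric>
#include <stdexcept>
#include <vector>

std::vector<int> splitSizes(int axisDim, size_t numOutputs,
                            const std::int64_t *splitInput, size_t splitRank) {
    std::vector<int> sizes;
    if (splitInput != nullptr) {
        // The "split" input stores int64 chunk sizes; the importer above copies
        // them into a std::vector<int> in the same way.
        for (size_t i = 0; i < splitRank; ++i)
            sizes.push_back(static_cast<int>(splitInput[i]));
        if (std::accumulate(sizes.begin(), sizes.end(), 0) != axisDim)
            throw std::runtime_error("split sizes must cover the whole axis");
    } else {
        // No explicit sizes: divide the axis evenly across the outputs.
        if (axisDim % static_cast<int>(numOutputs) != 0)
            throw std::runtime_error("axis is not evenly divisible");
        sizes.assign(numOutputs, axisDim / static_cast<int>(numOutputs));
    }
    return sizes;
}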