From e8f820f47bf22a0010c6f7e866690f7c7b88c217 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Thu, 21 Sep 2023 03:40:49 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E7=AE=80=E5=8C=96=E7=BC=96?= =?UTF-8?q?=E8=AF=91=E5=99=A8=E7=9A=84=20setInput?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- RefactorGraph | 2 +- src/ffi/ffi_infinitensor.cc | 16 +++++++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/RefactorGraph b/RefactorGraph index d78c70db..8c78f9ee 160000 --- a/RefactorGraph +++ b/RefactorGraph @@ -1 +1 @@ -Subproject commit d78c70dbe9872b6b17ca32525cb1210edfeae0fb +Subproject commit 8c78f9eeeb03f6794189cc7a945a8edbe4252040 diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index d6d850fe..e9e879e1 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -43,12 +43,22 @@ class Compiler { public: explicit Compiler(Graph g) : _g(std::move(g)) {} std::unordered_set fillEdgeInfo() { return _g.fillEdgeInfo(); } - void setInput(size_t index, int dataType, std::vector shape) { + void setInput(size_t index, int dataType, + std::vector> shape) { ASSERT(index < _g.internal().topology.globalInputsCount(), fmt::format("set input {} failed with wrong index", index)); + auto dataType_ = static_cast(dataType); + Shape shape_(shape.size(), DimExpr(1)); + std::transform(shape.begin(), shape.end(), shape_.begin(), + [](auto const &d) -> DimExpr { + if (std::holds_alternative(d)) { + return DimExpr(std::get(d)); + } else { + return DimExpr(std::get(d)); + } + }); _g.internal().edges[index].tensor = - std::move(Tensor::share(static_cast(dataType), - Shape(shape.begin(), shape.end()))); + Tensor::share(dataType_, std::move(shape_)); } void substitute(const char *name, int64_t value) { ASSERT(_g.substitute(name, value),