refactor: 简化编译器的 setInput

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-09-21 03:40:49 +08:00
parent 7edf1983dd
commit e8f820f47b
2 changed files with 14 additions and 4 deletions

@ -1 +1 @@
Subproject commit d78c70dbe9872b6b17ca32525cb1210edfeae0fb
Subproject commit 8c78f9eeeb03f6794189cc7a945a8edbe4252040

View File

@ -43,12 +43,22 @@ class Compiler {
public:
explicit Compiler(Graph g) : _g(std::move(g)) {}
std::unordered_set<Name> fillEdgeInfo() { return _g.fillEdgeInfo(); }
void setInput(size_t index, int dataType, std::vector<DimExpr> shape) {
void setInput(size_t index, int dataType,
std::vector<std::variant<std::string, int64_t>> shape) {
ASSERT(index < _g.internal().topology.globalInputsCount(),
fmt::format("set input {} failed with wrong index", index));
auto dataType_ = static_cast<common::DataType>(dataType);
Shape shape_(shape.size(), DimExpr(1));
std::transform(shape.begin(), shape.end(), shape_.begin(),
[](auto const &d) -> DimExpr {
if (std::holds_alternative<std::string>(d)) {
return DimExpr(std::get<std::string>(d));
} else {
return DimExpr(std::get<int64_t>(d));
}
});
_g.internal().edges[index].tensor =
std::move(Tensor::share(static_cast<common::DataType>(dataType),
Shape(shape.begin(), shape.end())));
Tensor::share(dataType_, std::move(shape_));
}
void substitute(const char *name, int64_t value) {
ASSERT(_g.substitute(name, value),