forked from jiuyuan/InfiniTensor
refactor: 简化编译器的 setInput
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
7edf1983dd
commit
e8f820f47b
|
@ -1 +1 @@
|
|||
Subproject commit d78c70dbe9872b6b17ca32525cb1210edfeae0fb
|
||||
Subproject commit 8c78f9eeeb03f6794189cc7a945a8edbe4252040
|
|
@ -43,12 +43,22 @@ class Compiler {
|
|||
public:
|
||||
explicit Compiler(Graph g) : _g(std::move(g)) {}
|
||||
std::unordered_set<Name> fillEdgeInfo() { return _g.fillEdgeInfo(); }
|
||||
void setInput(size_t index, int dataType, std::vector<DimExpr> shape) {
|
||||
void setInput(size_t index, int dataType,
|
||||
std::vector<std::variant<std::string, int64_t>> shape) {
|
||||
ASSERT(index < _g.internal().topology.globalInputsCount(),
|
||||
fmt::format("set input {} failed with wrong index", index));
|
||||
auto dataType_ = static_cast<common::DataType>(dataType);
|
||||
Shape shape_(shape.size(), DimExpr(1));
|
||||
std::transform(shape.begin(), shape.end(), shape_.begin(),
|
||||
[](auto const &d) -> DimExpr {
|
||||
if (std::holds_alternative<std::string>(d)) {
|
||||
return DimExpr(std::get<std::string>(d));
|
||||
} else {
|
||||
return DimExpr(std::get<int64_t>(d));
|
||||
}
|
||||
});
|
||||
_g.internal().edges[index].tensor =
|
||||
std::move(Tensor::share(static_cast<common::DataType>(dataType),
|
||||
Shape(shape.begin(), shape.end())));
|
||||
Tensor::share(dataType_, std::move(shape_));
|
||||
}
|
||||
void substitute(const char *name, int64_t value) {
|
||||
ASSERT(_g.substitute(name, value),
|
||||
|
|
Loading…
Reference in New Issue