feat: 支持用 numpy.ndarray 构造张量

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-09-18 18:13:01 +08:00
parent 8e4e392a49
commit ca149065d4
2 changed files with 8 additions and 3 deletions

View File

@ -1,10 +1,15 @@
import backend
from numpy import ndarray
import backend
from onnx import ModelProto, NodeProto, TensorProto, AttributeProto, numpy_helper
from onnx.helper import make_model, make_node, make_graph, make_tensor_value_info
from backend import DimExpr, refactor_tensor, refactor_operator, refactor_graph
from typing import Any
def build_tensor(array: ndarray) -> backend.Tensor:
return _parse_tensor(numpy_helper.from_array(array))
def build_graph(model: ModelProto) -> backend.Graph:
edges: dict[str, backend.Tensor] = dict()

View File

@ -35,7 +35,7 @@ class Handler {
std::unordered_set<Name> fillEdgeInfo() { return _g.fillEdgeInfo(); }
void setInput(size_t index, std::shared_ptr<Tensor> tensor) {
ASSERT(_g.setInput(index, std::move(tensor)),
fmt::format("set input {} failed", index));
fmt::format("set input {} failed with wrong shape", index));
}
void substitute(const char *name, int64_t value) {
ASSERT(_g.substitute(name, value),
@ -53,7 +53,7 @@ class Handler {
#endif
}
template <class T> std::vector<T> copyout(size_t i) {
return _outputs[i]->copyout<T>();
return _outputs.at(i)->copyout<T>();
}
};