forked from jiuyuan/InfiniTensor
feat: 支持用 numpy.ndarray 构造张量
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
8e4e392a49
commit
ca149065d4
|
@ -1,10 +1,15 @@
|
|||
import backend
|
||||
from numpy import ndarray
|
||||
import backend
|
||||
from onnx import ModelProto, NodeProto, TensorProto, AttributeProto, numpy_helper
|
||||
from onnx.helper import make_model, make_node, make_graph, make_tensor_value_info
|
||||
from backend import DimExpr, refactor_tensor, refactor_operator, refactor_graph
|
||||
from typing import Any
|
||||
|
||||
|
||||
def build_tensor(array: ndarray) -> backend.Tensor:
|
||||
return _parse_tensor(numpy_helper.from_array(array))
|
||||
|
||||
|
||||
def build_graph(model: ModelProto) -> backend.Graph:
|
||||
edges: dict[str, backend.Tensor] = dict()
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ class Handler {
|
|||
std::unordered_set<Name> fillEdgeInfo() { return _g.fillEdgeInfo(); }
|
||||
void setInput(size_t index, std::shared_ptr<Tensor> tensor) {
|
||||
ASSERT(_g.setInput(index, std::move(tensor)),
|
||||
fmt::format("set input {} failed", index));
|
||||
fmt::format("set input {} failed with wrong shape", index));
|
||||
}
|
||||
void substitute(const char *name, int64_t value) {
|
||||
ASSERT(_g.substitute(name, value),
|
||||
|
@ -53,7 +53,7 @@ class Handler {
|
|||
#endif
|
||||
}
|
||||
template <class T> std::vector<T> copyout(size_t i) {
|
||||
return _outputs[i]->copyout<T>();
|
||||
return _outputs.at(i)->copyout<T>();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in New Issue