fix: export a Python interface for each type from templates

Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed 2023-03-22 09:46:40 +08:00
commit 5aeacedab3 (parent 73e895b8ce)
4 changed files with 16 additions and 32 deletions

@@ -89,18 +89,6 @@ class GraphHandlerObj {
     inline void data_malloc() { g->dataMalloc(); }

-    inline void copy_int32(Tensor tensor, std::vector<int32_t> list) {
-        tensor->copyin(list);
-    }
-    inline void copy_int64(Tensor tensor, std::vector<int64_t> list) {
-        tensor->copyin(list);
-    }
-    inline void copy_float(Tensor tensor, std::vector<float> list) {
-        tensor->copyin(list);
-    }
-
     inline void run() { g->getRuntime()->run(g); }
 };
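
Note: each wrapper deleted here only forwarded to the tensor's templated `copyin`, which the binding layer below now exports directly per type. A minimal sketch of that member, assuming a `copyBlobFromCPU` counterpart to the `copyBlobToCPU` call visible in the next hunk; the body is inferred from context, not the project's verbatim code:

    // Sketch only: `runtime`, `dtype`, `size()` and `copyBlobFromCPU` are
    // assumed from surrounding code. One template covers every element type.
    template <typename T> inline void copyin(const std::vector<T> &data) {
        IT_ASSERT(DataType::get<T>() == dtype); // element type must match dtype
        IT_ASSERT(data.size() == size());       // one value per tensor element
        runtime->copyBlobFromCPU(getRawDataPtr<void *>(), data.data(),
                                 data.size() * sizeof(T));
    }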

@@ -57,20 +57,16 @@ class TensorObj : public TensorBaseObj {
         return ans;
     }
     // Copy the element at `pos`.
-    template <typename T> inline auto copyout(const vector<int> &pos) const {
+    template <typename T> inline auto copyOne(const vector<int> &pos) const {
         IT_ASSERT(DataType::get<T>() == dtype);
         auto offset = getOffset(pos);
         auto bytes = dtype.getSize();
         T ans;
-        runtime->copyBlobToCPU(&ans, getRawDataPtr<void *>() + offset * bytes,
-                               bytes);
+        runtime->copyBlobToCPU(
+            &ans, getRawDataPtr<uint8_t *>() + offset * bytes, bytes);
         return ans;
     }
-    inline auto copyoutFloat() const { return copyout<float>(); }
-    inline auto copyoutInt32() const { return copyout<int32_t>(); }
-    inline auto copyoutInt64() const { return copyout<int64_t>(); }

     void copyData(const TensorObj *src);
     void copyData(const Tensor &src) { copyData(src.get()); }
     void setData(
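
Note on the pointer change above: arithmetic on `void *` is a GNU extension rather than standard C++, so byte offsets into raw storage belong on a byte-sized pointer type. A self-contained illustration (standard C++; the function name is hypothetical):

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    // Copy one `bytes`-sized element at element index `offset` out of `base`.
    // Arithmetic is done on uint8_t* so each +1 advances exactly one byte;
    // the same expression on void* only compiles as a compiler extension.
    void copyElementAt(const void *base, std::size_t offset, std::size_t bytes,
                       void *out) {
        const std::uint8_t *src =
            static_cast<const std::uint8_t *>(base) + offset * bytes;
        std::memcpy(out, src, bytes);
    }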

@@ -101,7 +101,7 @@ class OnnxStub:
                 (alpha, beta, transA, transB) = (
                     attributes[name] for name in ["alpha", "beta", "transA", "transB"]
                 )
-                # TODO these attributes are not supported yet
+                # FIXME unsupported attributes: `alpha`, `beta`
                 assert alpha == 1.0
                 assert beta == 1.0
                 tensors[node.output[0]] = self.handler.matmul(
@@ -265,7 +265,7 @@ class OnnxStub:
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Flatten":
-                # TODO the backend operator cannot flatten along an arbitrary axis
+                # FIXME axis must be 1
                 axis = next(
                     (attr.i for attr in node.attribute if attr.name == "axis"), None
                 )
@@ -315,7 +315,7 @@ class OnnxStub:
                     next((attr.i for attr in node.attribute if attr.name == "axis")),
                 )
             elif node.op_type == "ReduceMean":
-                tensors[node.output[0]] = self.handler.reduceMean(
+                tensors[node.output[0]] = self.handler.reduce_mean(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     tensors[node.input[1]] if len(node.input) > 1 else None,
@@ -351,11 +351,11 @@ class OnnxStub:
             else:
                 self.initializer[obj.fuid()] = tensor
                 if tensor.data_type == TensorProto.INT32:
-                    self.handler.copy_int32(obj, [int(i) for i in tensor.int32_data])
+                    obj.copyin_int32([int(i) for i in tensor.int32_data])
                 elif tensor.data_type == TensorProto.INT64:
-                    self.handler.copy_int64(obj, [int(i) for i in tensor.int64_data])
+                    obj.copyin_int64([int(i) for i in tensor.int64_data])
                 elif tensor.data_type == TensorProto.FLOAT:
-                    self.handler.copy_float(obj, [float(i) for i in tensor.float_data])
+                    obj.copyin_float([float(i) for i in tensor.float_data])
                 else:
                     assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)

@@ -177,9 +177,12 @@ void init_graph_builder(py::module &m) {
     py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
         .def("fuid", &TensorObj::getFuid, policy::automatic)
         .def("shape", &TensorObj::getDims, policy::move)
-        .def("copyoutFloat", &TensorObj::copyoutFloat, policy::move)
-        .def("copyoutInt32", &TensorObj::copyoutInt32, policy::move)
-        .def("copyoutInt64", &TensorObj::copyoutInt64, policy::move)
+        .def("copyin_float", &TensorObj::copyin<float>, policy::move)
+        .def("copyin_int32", &TensorObj::copyin<int32_t>, policy::move)
+        .def("copyin_int64", &TensorObj::copyin<int64_t>, policy::move)
+        .def("copyout_float", &TensorObj::copyout<float>, policy::move)
+        .def("copyout_int32", &TensorObj::copyout<int32_t>, policy::move)
+        .def("copyout_int64", &TensorObj::copyout<int64_t>, policy::move)
         .def("has_target", &TensorObj::hasTarget, policy::automatic)
         .def("src", &TensorObj::getOutputOf, policy::move);
     py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
@@ -212,15 +215,12 @@ void init_graph_builder(py::module &m) {
         .def("reshape", &Handler::reshape, policy::move)
         .def("concat", &Handler::concat, policy::move)
         .def("gather", &Handler::gather, policy::move)
-        .def("reduceMean", &Handler::reduceMean, policy::move)
+        .def("reduce_mean", &Handler::reduceMean, policy::move)
         .def("slice", &Handler::slice, policy::move)
         .def("pad", &Handler::pad, policy::move)
         .def("topo_sort", &Handler::topo_sort, policy::automatic)
         .def("operators", &Handler::operators, policy::move)
         .def("data_malloc", &Handler::data_malloc, policy::automatic)
-        .def("copy_int32", &Handler::copy_int32, policy::automatic)
-        .def("copy_int64", &Handler::copy_int64, policy::automatic)
-        .def("copy_float", &Handler::copy_float, policy::automatic)
         .def("run", &Handler::run, policy::automatic);
 }
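
The two hunks above implement what the commit title describes: rather than one hand-written wrapper per element type, each Python method pins a single explicit instantiation of a C++ template under a snake_case name. A minimal, self-contained sketch of the same pybind11 pattern (the `Buffer` class is hypothetical, not project code):

    #include <pybind11/pybind11.h>
    #include <pybind11/stl.h>
    #include <cstdint>
    #include <vector>

    namespace py = pybind11;

    struct Buffer {
        std::vector<double> storage; // toy backing store
        // One template implements the logic for every element type.
        template <typename T> void copyin(const std::vector<T> &xs) {
            storage.assign(xs.begin(), xs.end());
        }
        template <typename T> std::vector<T> copyout() const {
            return std::vector<T>(storage.begin(), storage.end());
        }
    };

    PYBIND11_MODULE(demo, m) {
        py::class_<Buffer>(m, "Buffer")
            .def(py::init<>())
            // Each exported name pins exactly one instantiation.
            .def("copyin_float", &Buffer::copyin<float>)
            .def("copyin_int32", &Buffer::copyin<int32_t>)
            .def("copyout_float", &Buffer::copyout<float>)
            .def("copyout_int32", &Buffer::copyout<int32_t>);
    }

Built as a module named `demo`, this surfaces in Python just like the bindings above, e.g. `b = demo.Buffer(); b.copyin_float([1.0, 2.0]); b.copyout_int32()`.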