forked from jiuyuan/InfiniTensor
fix: 从模板导出每个类型的 python 接口
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
73e895b8ce
commit
5aeacedab3
|
@ -89,18 +89,6 @@ class GraphHandlerObj {
|
||||||
|
|
||||||
inline void data_malloc() { g->dataMalloc(); }
|
inline void data_malloc() { g->dataMalloc(); }
|
||||||
|
|
||||||
inline void copy_int32(Tensor tensor, std::vector<int32_t> list) {
|
|
||||||
tensor->copyin(list);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline void copy_int64(Tensor tensor, std::vector<int64_t> list) {
|
|
||||||
tensor->copyin(list);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline void copy_float(Tensor tensor, std::vector<float> list) {
|
|
||||||
tensor->copyin(list);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline void run() { g->getRuntime()->run(g); }
|
inline void run() { g->getRuntime()->run(g); }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -57,20 +57,16 @@ class TensorObj : public TensorBaseObj {
|
||||||
return ans;
|
return ans;
|
||||||
}
|
}
|
||||||
// Copy the element at `pos`.
|
// Copy the element at `pos`.
|
||||||
template <typename T> inline auto copyout(const vector<int> &pos) const {
|
template <typename T> inline auto copyOne(const vector<int> &pos) const {
|
||||||
IT_ASSERT(DataType::get<T>() == dtype);
|
IT_ASSERT(DataType::get<T>() == dtype);
|
||||||
auto offset = getOffset(pos);
|
auto offset = getOffset(pos);
|
||||||
auto bytes = dtype.getSize();
|
auto bytes = dtype.getSize();
|
||||||
T ans;
|
T ans;
|
||||||
runtime->copyBlobToCPU(&ans, getRawDataPtr<void *>() + offset * bytes,
|
runtime->copyBlobToCPU(
|
||||||
bytes);
|
&ans, getRawDataPtr<uint8_t *>() + offset * bytes, bytes);
|
||||||
return ans;
|
return ans;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline auto copyoutFloat() const { return copyout<float>(); }
|
|
||||||
inline auto copyoutInt32() const { return copyout<int32_t>(); }
|
|
||||||
inline auto copyoutInt64() const { return copyout<int64_t>(); }
|
|
||||||
|
|
||||||
void copyData(const TensorObj *src);
|
void copyData(const TensorObj *src);
|
||||||
void copyData(const Tensor &src) { copyData(src.get()); }
|
void copyData(const Tensor &src) { copyData(src.get()); }
|
||||||
void setData(
|
void setData(
|
||||||
|
|
|
@ -101,7 +101,7 @@ class OnnxStub:
|
||||||
(alpha, beta, transA, transB) = (
|
(alpha, beta, transA, transB) = (
|
||||||
attributes[name] for name in ["alpha", "beta", "transA", "transB"]
|
attributes[name] for name in ["alpha", "beta", "transA", "transB"]
|
||||||
)
|
)
|
||||||
# TODO 不支持这些参数
|
# FIXME unsupport attributes: `alpha` `beta`
|
||||||
assert alpha == 1.0
|
assert alpha == 1.0
|
||||||
assert beta == 1.0
|
assert beta == 1.0
|
||||||
tensors[node.output[0]] = self.handler.matmul(
|
tensors[node.output[0]] = self.handler.matmul(
|
||||||
|
@ -265,7 +265,7 @@ class OnnxStub:
|
||||||
tensors.get(node.output[0]),
|
tensors.get(node.output[0]),
|
||||||
)
|
)
|
||||||
elif node.op_type == "Flatten":
|
elif node.op_type == "Flatten":
|
||||||
# TODO 后端算子不支持沿任意轴展开
|
# FIXME axis must be 1
|
||||||
axis = next(
|
axis = next(
|
||||||
(attr.i for attr in node.attribute if attr.name == "axis"), None
|
(attr.i for attr in node.attribute if attr.name == "axis"), None
|
||||||
)
|
)
|
||||||
|
@ -315,7 +315,7 @@ class OnnxStub:
|
||||||
next((attr.i for attr in node.attribute if attr.name == "axis")),
|
next((attr.i for attr in node.attribute if attr.name == "axis")),
|
||||||
)
|
)
|
||||||
elif node.op_type == "ReduceMean":
|
elif node.op_type == "ReduceMean":
|
||||||
tensors[node.output[0]] = self.handler.reduceMean(
|
tensors[node.output[0]] = self.handler.reduce_mean(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors.get(node.output[0]),
|
tensors.get(node.output[0]),
|
||||||
tensors[node.input[1]] if len(node.input) > 1 else None,
|
tensors[node.input[1]] if len(node.input) > 1 else None,
|
||||||
|
@ -351,11 +351,11 @@ class OnnxStub:
|
||||||
else:
|
else:
|
||||||
self.initializer[obj.fuid()] = tensor
|
self.initializer[obj.fuid()] = tensor
|
||||||
if tensor.data_type == TensorProto.INT32:
|
if tensor.data_type == TensorProto.INT32:
|
||||||
self.handler.copy_int32(obj, [int(i) for i in tensor.int32_data])
|
obj.copyin_int32([int(i) for i in tensor.int32_data])
|
||||||
elif tensor.data_type == TensorProto.INT64:
|
elif tensor.data_type == TensorProto.INT64:
|
||||||
self.handler.copy_int64(obj, [int(i) for i in tensor.int64_data])
|
obj.copyin_int64([int(i) for i in tensor.int64_data])
|
||||||
elif tensor.data_type == TensorProto.FLOAT:
|
elif tensor.data_type == TensorProto.FLOAT:
|
||||||
self.handler.copy_float(obj, [float(i) for i in tensor.float_data])
|
obj.copyin_float([int(i) for i in tensor.float_data])
|
||||||
else:
|
else:
|
||||||
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
|
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
|
||||||
|
|
||||||
|
|
|
@ -177,9 +177,12 @@ void init_graph_builder(py::module &m) {
|
||||||
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
|
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
|
||||||
.def("fuid", &TensorObj::getFuid, policy::automatic)
|
.def("fuid", &TensorObj::getFuid, policy::automatic)
|
||||||
.def("shape", &TensorObj::getDims, policy::move)
|
.def("shape", &TensorObj::getDims, policy::move)
|
||||||
.def("copyoutFloat", &TensorObj::copyoutFloat, policy::move)
|
.def("copyin_float", &TensorObj::copyin<float>, policy::move)
|
||||||
.def("copyoutInt32", &TensorObj::copyoutInt32, policy::move)
|
.def("copyin_int32", &TensorObj::copyin<int32_t>, policy::move)
|
||||||
.def("copyoutInt64", &TensorObj::copyoutInt64, policy::move)
|
.def("copyin_int64", &TensorObj::copyin<int64_t>, policy::move)
|
||||||
|
.def("copyout_float", &TensorObj::copyout<float>, policy::move)
|
||||||
|
.def("copyout_int32", &TensorObj::copyout<int32_t>, policy::move)
|
||||||
|
.def("copyout_int64", &TensorObj::copyout<int64_t>, policy::move)
|
||||||
.def("has_target", &TensorObj::hasTarget, policy::automatic)
|
.def("has_target", &TensorObj::hasTarget, policy::automatic)
|
||||||
.def("src", &TensorObj::getOutputOf, policy::move);
|
.def("src", &TensorObj::getOutputOf, policy::move);
|
||||||
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
|
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
|
||||||
|
@ -212,15 +215,12 @@ void init_graph_builder(py::module &m) {
|
||||||
.def("reshape", &Handler::reshape, policy::move)
|
.def("reshape", &Handler::reshape, policy::move)
|
||||||
.def("concat", &Handler::concat, policy::move)
|
.def("concat", &Handler::concat, policy::move)
|
||||||
.def("gather", &Handler::gather, policy::move)
|
.def("gather", &Handler::gather, policy::move)
|
||||||
.def("reduceMean", &Handler::reduceMean, policy::move)
|
.def("reduce_mean", &Handler::reduceMean, policy::move)
|
||||||
.def("slice", &Handler::slice, policy::move)
|
.def("slice", &Handler::slice, policy::move)
|
||||||
.def("pad", &Handler::pad, policy::move)
|
.def("pad", &Handler::pad, policy::move)
|
||||||
.def("topo_sort", &Handler::topo_sort, policy::automatic)
|
.def("topo_sort", &Handler::topo_sort, policy::automatic)
|
||||||
.def("operators", &Handler::operators, policy::move)
|
.def("operators", &Handler::operators, policy::move)
|
||||||
.def("data_malloc", &Handler::data_malloc, policy::automatic)
|
.def("data_malloc", &Handler::data_malloc, policy::automatic)
|
||||||
.def("copy_int32", &Handler::copy_int32, policy::automatic)
|
|
||||||
.def("copy_int64", &Handler::copy_int64, policy::automatic)
|
|
||||||
.def("copy_float", &Handler::copy_float, policy::automatic)
|
|
||||||
.def("run", &Handler::run, policy::automatic);
|
.def("run", &Handler::run, policy::automatic);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue