Constroy/fix reducemean (#91)

* feat: export getPerfTime to python

Signed-off-by: YdrMaster <ydrml@hotmail.com>

* fix parsing of ReduceMean

* ReduceMean axes defaults to None

* fix ElementWiseCudnn with shape broadcasting

* fix format

---------

Signed-off-by: YdrMaster <ydrml@hotmail.com>
Co-authored-by: YdrMaster <ydrml@hotmail.com>
Co-authored-by: constroy <constroy.li@gmail.com>
Haojie Wang 2023-06-29 08:27:00 +08:00 committed by GitHub
parent 4340522720
commit fab08fcfa1
4 changed files with 34 additions and 9 deletions

View File

@@ -109,7 +109,11 @@ class GraphHandlerObj {
     inline void data_malloc() { g->dataMalloc(); }
-    inline void run() { g->getRuntime()->run(g, true); }
+    inline void tune() { g->getRuntime()->run(g, true); }
+    inline void run() { g->getRuntime()->run(g); }
+    inline double get_perf_time() { return g->getRuntime()->getPerfTime(g); }
 };
 } // namespace infini
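Read together, the trio implies a simple contract: tune() runs the graph once with tuning enabled so the runtime can profile kernels, run() then executes plainly, and get_perf_time() reports the profiled total. A toy Python model of that contract, where execute() and profile() are hypothetical stand-ins rather than InfiniTensor APIs:

import time

def execute(op) -> None:
    """Hypothetical stand-in for launching an op's chosen kernel."""
    op["fn"]()

def profile(op) -> float:
    """Hypothetical stand-in for kernel profiling: time one launch in ms."""
    start = time.perf_counter()
    op["fn"]()
    return (time.perf_counter() - start) * 1e3

class RuntimeSketch:
    """Toy model of the run/tune/get_perf_time contract in the hunk above."""

    def __init__(self):
        self.perf_records: dict[str, float] = {}  # op name -> measured ms

    def run(self, graph, tune: bool = False) -> None:
        for op in graph:
            if tune:
                self.perf_records[op["name"]] = profile(op)
            execute(op)

    def get_perf_time(self, graph) -> float:
        # Sum of the per-op times recorded during the last tuning pass.
        return sum(self.perf_records.get(op["name"], 0.0) for op in graph)

graph = [{"name": "matmul", "fn": lambda: sum(range(10_000))}]
rt = RuntimeSketch()
rt.run(graph, tune=True)        # what the new tune() triggers
rt.run(graph)                   # what the new run() triggers
print(rt.get_perf_time(graph))  # what get_perf_time() reports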

View File

@@ -444,7 +444,8 @@ class OnnxStub:
                 tensors[node.output[0]] = self.handler.reduce_mean(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
-                    tensors[node.input[1]] if len(node.input) > 1 else None,
+                    # NOTE(constroy): `axes` is an attribute until opset version 13.
+                    next((attr.ints for attr in node.attribute if attr.name == "axes"), None),
                     next((attr.i for attr in node.attribute if attr.name == "keepdims"), 1)
                     != 0,
                 )
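For context on that note: ONNX moved `axes` from an attribute to an optional input in opset 13 for ReduceSum, but for ReduceMean itself the switch only happens in opset 18, so for the opsets handled here the importer must read `axes` from node.attribute and fall back to None (reduce over all axes) when it is absent. A small sketch of both encodings and of the `next(...)` lookup, using the official onnx helper (names and shapes are made up):

from onnx import helper

# Opset <= 17 style: `axes` rides along as a node attribute.
node_attr = helper.make_node(
    "ReduceMean", inputs=["x"], outputs=["y"], axes=[1], keepdims=1
)
axes = next((a.ints for a in node_attr.attribute if a.name == "axes"), None)
print(list(axes))  # [1]

# Opset >= 18 style: `axes` becomes an optional second *input* instead.
node_input = helper.make_node(
    "ReduceMean", inputs=["x", "axes"], outputs=["y"], keepdims=1
)
print(list(node_input.input))  # ['x', 'axes']

# Omitting `axes` entirely means "reduce over all axes" -> the None default.
node_all = helper.make_node("ReduceMean", inputs=["x"], outputs=["y"], keepdims=1)
print(next((a.ints for a in node_all.attribute if a.name == "axes"), None))  # None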
@@ -794,9 +795,15 @@ class OnnxStub:
     def optimize(self) -> None:
         self.handler.optimize()
 
+    def tune(self) -> None:
+        self.handler.tune()
+
     def run(self) -> None:
         self.handler.run()
 
+    def get_perf_time(self) -> float:
+        return self.handler.get_perf_time()
+
 
 def from_onnx(model: ModelProto, runtime):
     stub = OnnxStub(model, runtime)
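A hypothetical end-to-end use of the new stub methods; the backend.cpu_runtime() constructor and the model path are assumptions for illustration, not part of this diff:

import onnx
from pyinfinitensor.onnx import OnnxStub, backend  # import path assumed

model = onnx.load("model.onnx")                # placeholder model file
stub = OnnxStub(model, backend.cpu_runtime())  # runtime constructor assumed

stub.tune()                  # one profiled pass; picks and caches kernels
stub.run()                   # plain execution
print(stub.get_perf_time())  # profiled graph time, now a float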

View File

@@ -347,7 +347,10 @@ void init_graph_builder(py::module &m) {
         .def("convert_nhwc", &Handler::convert_nhwc, policy::automatic)
         .def("operators", &Handler::operators, policy::move)
         .def("data_malloc", &Handler::data_malloc, policy::automatic)
-        .def("run", &Handler::run, policy::automatic);
+        .def("get_perf_time", &Handler::get_perf_time, policy::automatic)
+        .def("tune", &Handler::tune, policy::automatic)
+        .def("run", &Handler::run, policy::automatic)
+        .def("get_perf_time", &Handler::get_perf_time, policy::automatic);
 }
 
 void load_apps(py::module &m) {

View File

@@ -14,14 +14,25 @@ class ElementWiseCudnn : public CudaKernelWithoutConfig {
         auto op = as<ElementWiseObj>(_op);
         auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
 
-        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
-        void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
-        void *const cData = (op->getOutput()->getRawDataPtr<void *>());
+        auto aTensor = op->getInputs(0);
+        auto bTensor = op->getInputs(1);
+        auto cTensor = op->getOutput();
+
+        // cudnnOpTensor only allows B to be broadcasted.
+        if (aTensor->getDims() != cTensor->getDims()) {
+            swap(aTensor, bTensor);
+        }
+        IT_ASSERT(aTensor->getDims() == cTensor->getDims(),
+                  "Shape does not match.");
+
+        void *const aData = (aTensor->getRawDataPtr<void *>());
+        void *const bData = (bTensor->getRawDataPtr<void *>());
+        void *const cData = (cTensor->getRawDataPtr<void *>());
 
         cudnnTensorDescriptor_t aDesc, bDesc, cDesc;
-        auto a_dim = op->getInputs(0)->getDims();
-        auto b_dim = op->getInputs(1)->getDims();
-        auto c_dim = op->getOutput()->getDims();
+        auto a_dim = aTensor->getDims();
+        auto b_dim = bTensor->getDims();
+        auto c_dim = cTensor->getDims();
 
         if (a_dim.size() > 4 || b_dim.size() > 4 || c_dim.size() > 4)
             IT_TODO_HALT();
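The swap is only sound when the underlying op is commutative; cudnnOpTensor's ADD/MUL/MIN/MAX modes are, which presumably covers what this kernel dispatches (a subtraction lowered as ADD with a negative scaling factor would not survive the swap). A NumPy sketch of the same shape fix-up, using addition as the stand-in op:

import numpy as np

def op_tensor_add(a: np.ndarray, b: np.ndarray, c_shape: tuple) -> np.ndarray:
    # Mirror the kernel's fix-up: cudnnOpTensor can only broadcast B,
    # so if A is not already output-shaped, swap the operands.
    if a.shape != c_shape:
        a, b = b, a
    assert a.shape == c_shape, "Shape does not match."  # the IT_ASSERT analogue
    return a + np.broadcast_to(b, c_shape)

x = np.ones((2, 3, 4, 5))  # full-shape operand
y = np.ones((1, 3, 1, 5))  # operand that needs broadcasting
# Correct no matter which slot the broadcastable tensor arrives in:
assert op_tensor_add(x, y, x.shape).shape == (2, 3, 4, 5)
assert op_tensor_add(y, x, x.shape).shape == (2, 3, 4, 5)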