fix ReduceMean and element_wise (#90)

* feat: 导出 getPerfTime 到 python

Signed-off-by: YdrMaster <ydrml@hotmail.com>

* fix parsing of ReduceMean

* ReduceMean axes defaults to None

* fix ElementWiseCudnn with shape broadcasting

* fix format

---------

Signed-off-by: YdrMaster <ydrml@hotmail.com>
Co-authored-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
constroy 2023-06-29 07:15:07 +08:00 committed by GitHub
parent 19d7dc871d
commit 579cdbbb81
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 33 additions and 8 deletions

View File

@ -101,7 +101,11 @@ class GraphHandlerObj {
inline void data_malloc() { g->dataMalloc(); } inline void data_malloc() { g->dataMalloc(); }
inline void tune() { g->getRuntime()->run(g, true); }
inline void run() { g->getRuntime()->run(g); } inline void run() { g->getRuntime()->run(g); }
inline double get_perf_time() { return g->getRuntime()->getPerfTime(g); }
}; };
} // namespace infini } // namespace infini

View File

@ -476,7 +476,8 @@ class OnnxStub:
tensors[node.output[0]] = self.handler.reduce_mean( tensors[node.output[0]] = self.handler.reduce_mean(
tensors[node.input[0]], tensors[node.input[0]],
tensors.get(node.output[0]), tensors.get(node.output[0]),
tensors[node.input[1]] if len(node.input) > 1 else None, # NOTE(constroy): `axes` is an attribute until opset version 13.
next((attr.ints for attr in node.attribute if attr.name == "axes"), None),
next((attr.i for attr in node.attribute if attr.name == "keepdims")) next((attr.i for attr in node.attribute if attr.name == "keepdims"))
!= 0, != 0,
) )
@ -823,9 +824,15 @@ class OnnxStub:
def optimize(self) -> None: def optimize(self) -> None:
self.handler.optimize() self.handler.optimize()
def tune(self) -> None:
self.handler.tune()
def run(self) -> None: def run(self) -> None:
self.handler.run() self.handler.run()
def get_perf_time(self) -> float:
self.handler.get_perf_time()
def from_onnx(model: ModelProto, runtime): def from_onnx(model: ModelProto, runtime):
stub = OnnxStub(model, runtime) stub = OnnxStub(model, runtime)

View File

@ -325,7 +325,10 @@ void init_graph_builder(py::module &m) {
.def("optimize", &Handler::optimize, policy::automatic) .def("optimize", &Handler::optimize, policy::automatic)
.def("operators", &Handler::operators, policy::move) .def("operators", &Handler::operators, policy::move)
.def("data_malloc", &Handler::data_malloc, policy::automatic) .def("data_malloc", &Handler::data_malloc, policy::automatic)
.def("run", &Handler::run, policy::automatic); .def("get_perf_time", &Handler::get_perf_time, policy::automatic)
.def("tune", &Handler::tune, policy::automatic)
.def("run", &Handler::run, policy::automatic)
.def("get_perf_time", &Handler::get_perf_time, policy::automatic);
} }
} // namespace infini } // namespace infini

View File

@ -14,14 +14,25 @@ class ElementWiseCudnn : public CudaKernelWithoutConfig {
auto op = as<ElementWiseObj>(_op); auto op = as<ElementWiseObj>(_op);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context); auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); auto aTensor = op->getInputs(0);
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>()); auto bTensor = op->getInputs(1);
void *const cData = (op->getOutput()->getRawDataPtr<void *>()); auto cTensor = op->getOutput();
// cudnnOpTensor only allows B to be broadcasted.
if (aTensor->getDims() != cTensor->getDims()) {
swap(aTensor, bTensor);
}
IT_ASSERT(aTensor->getDims() == cTensor->getDims(),
"Shape does not match.");
void *const aData = (aTensor->getRawDataPtr<void *>());
void *const bData = (bTensor->getRawDataPtr<void *>());
void *const cData = (cTensor->getRawDataPtr<void *>());
cudnnTensorDescriptor_t aDesc, bDesc, cDesc; cudnnTensorDescriptor_t aDesc, bDesc, cDesc;
auto a_dim = op->getInputs(0)->getDims(); auto a_dim = aTensor->getDims();
auto b_dim = op->getInputs(1)->getDims(); auto b_dim = bTensor->getDims();
auto c_dim = op->getOutput()->getDims(); auto c_dim = cTensor->getDims();
if (a_dim.size() > 4 || b_dim.size() > 4 || c_dim.size() > 4) if (a_dim.size() > 4 || b_dim.size() > 4 || c_dim.size() > 4)
IT_TODO_HALT(); IT_TODO_HALT();