diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index 7af4bbe1..2c9b25b4 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -101,7 +101,11 @@ class GraphHandlerObj {
 
     inline void data_malloc() { g->dataMalloc(); }
 
+    inline void tune() { g->getRuntime()->run(g, true); }
+
     inline void run() { g->getRuntime()->run(g); }
+
+    inline double get_perf_time() { return g->getRuntime()->getPerfTime(g); }
 };
 
 } // namespace infini
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index a450b281..b57390db 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -476,7 +476,8 @@ class OnnxStub:
                 tensors[node.output[0]] = self.handler.reduce_mean(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
-                    tensors[node.input[1]] if len(node.input) > 1 else None,
+                    # NOTE(constroy): `axes` is an attribute until opset version 13.
+                    next((attr.ints for attr in node.attribute if attr.name == "axes"), None),
                     next((attr.i for attr in node.attribute if attr.name == "keepdims")) != 0,
                 )
 
@@ -823,9 +824,16 @@ class OnnxStub:
     def optimize(self) -> None:
         self.handler.optimize()
 
+    def tune(self) -> None:
+        self.handler.tune()
+
     def run(self) -> None:
         self.handler.run()
 
+    def get_perf_time(self) -> float:
+        # Must propagate the handler's result; without `return` this
+        # would silently yield None despite the `-> float` annotation.
+        return self.handler.get_perf_time()
+
 
 def from_onnx(model: ModelProto, runtime):
     stub = OnnxStub(model, runtime)
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index 6bb77256..1da7203a 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -325,7 +325,9 @@ void init_graph_builder(py::module &m) {
         .def("optimize", &Handler::optimize, policy::automatic)
         .def("operators", &Handler::operators, policy::move)
         .def("data_malloc", &Handler::data_malloc, policy::automatic)
-        .def("run", &Handler::run, policy::automatic);
+        .def("tune", &Handler::tune, policy::automatic)
+        .def("run", &Handler::run, policy::automatic)
+        .def("get_perf_time", &Handler::get_perf_time, policy::automatic);
 }
 
 } // namespace infini
diff --git a/src/kernels/cuda/element_wise.cc b/src/kernels/cuda/element_wise.cc
index a74d1bdf..dc57e8d1 100644
--- a/src/kernels/cuda/element_wise.cc
+++ b/src/kernels/cuda/element_wise.cc
@@ -14,14 +14,25 @@ class ElementWiseCudnn : public CudaKernelWithoutConfig {
         auto op = as<ElementWiseObj>(_op);
         auto context = dynamic_cast<CudaRuntimeObj *>(_context);
 
-        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
-        void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
-        void *const cData = (op->getOutput()->getRawDataPtr<void *>());
+        auto aTensor = op->getInputs(0);
+        auto bTensor = op->getInputs(1);
+        auto cTensor = op->getOutput();
+
+        // cudnnOpTensor only allows B to be broadcasted.
+        if (aTensor->getDims() != cTensor->getDims()) {
+            swap(aTensor, bTensor);
+        }
+        IT_ASSERT(aTensor->getDims() == cTensor->getDims(),
+                  "Shape does not match.");
+
+        void *const aData = (aTensor->getRawDataPtr<void *>());
+        void *const bData = (bTensor->getRawDataPtr<void *>());
+        void *const cData = (cTensor->getRawDataPtr<void *>());
 
         cudnnTensorDescriptor_t aDesc, bDesc, cDesc;
-        auto a_dim = op->getInputs(0)->getDims();
-        auto b_dim = op->getInputs(1)->getDims();
-        auto c_dim = op->getOutput()->getDims();
+        auto a_dim = aTensor->getDims();
+        auto b_dim = bTensor->getDims();
+        auto c_dim = cTensor->getDims();
 
         if (a_dim.size() > 4 || b_dim.size() > 4 || c_dim.size() > 4)
             IT_TODO_HALT();