forked from jiuyuan/InfiniTensor
fix ReduceMean and element_wise (#90)
* feat: export getPerfTime to Python

  Signed-off-by: YdrMaster <ydrml@hotmail.com>

* fix parsing of ReduceMean
* ReduceMean axes defaults to None
* fix ElementWiseCudnn with shape broadcasting
* fix format

---------

Signed-off-by: YdrMaster <ydrml@hotmail.com>
Co-authored-by: YdrMaster <ydrml@hotmail.com>
parent 19d7dc871d
commit 579cdbbb81
@@ -101,7 +101,11 @@ class GraphHandlerObj {
     inline void data_malloc() { g->dataMalloc(); }
 
+    inline void tune() { g->getRuntime()->run(g, true); }
+
     inline void run() { g->getRuntime()->run(g); }
 
+    inline double get_perf_time() { return g->getRuntime()->getPerfTime(g); }
+
 };
 
 } // namespace infini
@@ -476,7 +476,8 @@ class OnnxStub:
                 tensors[node.output[0]] = self.handler.reduce_mean(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
-                    tensors[node.input[1]] if len(node.input) > 1 else None,
+                    # NOTE(constroy): `axes` is an attribute until opset version 13.
+                    next((attr.ints for attr in node.attribute if attr.name == "axes"), None),
                     next((attr.i for attr in node.attribute if attr.name == "keepdims"))
                     != 0,
                 )
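For context (an editorial illustration, not part of the commit): the next(..., None) lookup above reads the optional axes attribute that older ReduceMean opsets carry on the node itself. A minimal standalone sketch, needing only the onnx package:

import onnx.helper

node_with_axes = onnx.helper.make_node(
    "ReduceMean", inputs=["x"], outputs=["y"], axes=[1, 2], keepdims=1
)
node_without_axes = onnx.helper.make_node(
    "ReduceMean", inputs=["x"], outputs=["y"], keepdims=0
)

def read_axes(node):
    # Same lookup as the frontend change above: first matching attribute, else None.
    return next((attr.ints for attr in node.attribute if attr.name == "axes"), None)

print(list(read_axes(node_with_axes)))   # [1, 2]
print(read_axes(node_without_axes))      # None, i.e. reduce over all axes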
@@ -823,9 +824,15 @@ class OnnxStub:
     def optimize(self) -> None:
         self.handler.optimize()
 
+    def tune(self) -> None:
+        self.handler.tune()
+
     def run(self) -> None:
         self.handler.run()
 
+    def get_perf_time(self) -> float:
+        self.handler.get_perf_time()
+
 
 def from_onnx(model: ModelProto, runtime):
     stub = OnnxStub(model, runtime)
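The new stub methods are thin wrappers over the handler bindings added in the next hunk. A hedged usage sketch (the import path and the cpu_runtime factory are assumptions, not taken from this commit):

import onnx
from pyinfinitensor.onnx import OnnxStub, backend  # assumed module layout

stub = OnnxStub(onnx.load("model.onnx"), backend.cpu_runtime())  # placeholder model path
stub.tune()                           # one execution pass with tuning enabled
stub.run()                            # plain execution
print(stub.handler.get_perf_time())   # estimated graph time reported by the runtime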
@@ -325,7 +325,10 @@ void init_graph_builder(py::module &m) {
         .def("optimize", &Handler::optimize, policy::automatic)
         .def("operators", &Handler::operators, policy::move)
         .def("data_malloc", &Handler::data_malloc, policy::automatic)
-        .def("run", &Handler::run, policy::automatic);
+        .def("get_perf_time", &Handler::get_perf_time, policy::automatic)
+        .def("tune", &Handler::tune, policy::automatic)
+        .def("run", &Handler::run, policy::automatic)
+        .def("get_perf_time", &Handler::get_perf_time, policy::automatic);
 }
 
 } // namespace infini
@@ -14,14 +14,25 @@ class ElementWiseCudnn : public CudaKernelWithoutConfig {
         auto op = as<ElementWiseObj>(_op);
         auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
 
-        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
-        void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
-        void *const cData = (op->getOutput()->getRawDataPtr<void *>());
+        auto aTensor = op->getInputs(0);
+        auto bTensor = op->getInputs(1);
+        auto cTensor = op->getOutput();
+
+        // cudnnOpTensor only allows B to be broadcasted.
+        if (aTensor->getDims() != cTensor->getDims()) {
+            swap(aTensor, bTensor);
+        }
+        IT_ASSERT(aTensor->getDims() == cTensor->getDims(),
+                  "Shape does not match.");
+
+        void *const aData = (aTensor->getRawDataPtr<void *>());
+        void *const bData = (bTensor->getRawDataPtr<void *>());
+        void *const cData = (cTensor->getRawDataPtr<void *>());
 
         cudnnTensorDescriptor_t aDesc, bDesc, cDesc;
-        auto a_dim = op->getInputs(0)->getDims();
-        auto b_dim = op->getInputs(1)->getDims();
-        auto c_dim = op->getOutput()->getDims();
+        auto a_dim = aTensor->getDims();
+        auto b_dim = bTensor->getDims();
+        auto c_dim = cTensor->getDims();
 
         if (a_dim.size() > 4 || b_dim.size() > 4 || c_dim.size() > 4)
             IT_TODO_HALT();
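To see the case this swap targets from the model side, here is an illustrative Add whose first operand is the one being broadcast; building it needs only the onnx package, while actually hitting ElementWiseCudnn requires a CUDA build (the cuda_runtime factory named in the trailing comment is an assumption):

import onnx
from onnx import helper, TensorProto

a = helper.make_tensor_value_info("a", TensorProto.FLOAT, [4])        # broadcast operand in slot A
b = helper.make_tensor_value_info("b", TensorProto.FLOAT, [2, 3, 4])
c = helper.make_tensor_value_info("c", TensorProto.FLOAT, [2, 3, 4])
graph = helper.make_graph(
    [helper.make_node("Add", ["a", "b"], ["c"])], "broadcast_add", [a, b], [c]
)
model = helper.make_model(graph)
onnx.checker.check_model(model)

# With a CUDA runtime the kernel swaps the operands so cudnnOpTensor sees the
# broadcast tensor in its B slot (Add is commutative, so the swap is safe):
# from pyinfinitensor.onnx import OnnxStub, backend
# stub = OnnxStub(model, backend.cuda_runtime())
# stub.run()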