diff --git a/include/operators/reduce_mean.h b/include/operators/reduce_mean.h index cfa4eccc..23ea1432 100644 --- a/include/operators/reduce_mean.h +++ b/include/operators/reduce_mean.h @@ -7,7 +7,7 @@ namespace infini { * */ class ReduceMeanObj : public OperatorObj { - set axis; // axis to reduce + set axes; // axis to reduce bool keepDims; public: @@ -17,11 +17,11 @@ class ReduceMeanObj : public OperatorObj { * @param graph The computation graph that this operator belongs to. * @param input The input tensor. * @param output The output tensor. - * @param axis Axes to reduce. + * @param axes Axes to reduce. * @param keepDims Keep the reduced dimensions or not. */ ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output, - const optional> &axis, + const optional> &axes, bool keepDims = true); OP_CLONE(ReduceMeanObj); optional> inferShape(const TensorVec &inputs) const override; diff --git a/src/operators/reduce_mean.cc b/src/operators/reduce_mean.cc index 3e627102..b59cc828 100644 --- a/src/operators/reduce_mean.cc +++ b/src/operators/reduce_mean.cc @@ -2,27 +2,27 @@ namespace infini { ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output, - const optional> &_axis, + const optional> &_axes, bool keepDims) : OperatorObj(OpType::ReduceMean, {input}, {output}), keepDims(keepDims) { - - if (_axis != std::nullopt) { - IT_ASSERT((*_axis).size() <= input->getDims().size()); - for (size_t j = 0; j < (*_axis).size(); ++j) { - int idx = (*_axis)[j]; + const auto size = input->getDims().size(); + if (_axes) { + // TODO 不需要这个,但需要处理负数,一对相反数应该不能同时出现。 + // IT_ASSERT((*_axes).size() <= input->getDims().size()); + for (auto idx : *_axes) { if (idx < 0) IT_TODO_HALT(); - IT_ASSERT((size_t)idx < input->getDims().size()); - axis.emplace(idx); + IT_ASSERT((size_t)idx < size); + axes.emplace(idx); } } else - for (size_t i = 0; i < input->getDims().size(); ++i) - axis.emplace(i); + for (size_t i = 0; i < size; ++i) + axes.emplace(i); IT_ASSERT(checkValid(graph)); } bool ReduceMeanObj::isReduced(int idx) const { - return axis.find(idx) != axis.end(); + return axes.find(idx) != axes.end(); } optional> @@ -31,7 +31,7 @@ ReduceMeanObj::inferShape(const TensorVec &inputs) const { if (keepDims) { Shape ret = dims; - for (auto it : axis) + for (auto it : axes) ret[it] = 1; return {{ret}}; } else { @@ -55,14 +55,14 @@ std::string ReduceMeanObj::toString() const { std::string axisstr; axisstr.append("["); - for (auto d : axis) { + for (auto d : axes) { axisstr.append(std::to_string(d)); axisstr.append(","); } - if (!axis.empty()) + if (!axes.empty()) axisstr.pop_back(); axisstr.append("]"); - os << "axis=" << axisstr << ","; + os << "axes=" << axisstr << ","; os << "keepDims=" << keepDims << ","; os << "input=" << inputs[0]->getGuid() << ","; os << "output=" << outputs[0]->getGuid() << ")"; @@ -73,13 +73,13 @@ vector ReduceMeanObj::getWorkloadVector() const { vector ret = inputs[0]->getDims(); ret.emplace(ret.begin(), enum_to_underlying(type)); ret.emplace_back((int)keepDims); - ret.insert(ret.end(), axis.begin(), axis.end()); + ret.insert(ret.end(), axes.begin(), axes.end()); return ret; } vector ReduceMeanObj::getOpAttrVector() const { vector ret = {enum_to_underlying(type), (int)keepDims}; - ret.insert(ret.end(), axis.begin(), axis.end()); + ret.insert(ret.end(), axes.begin(), axes.end()); return ret; } } // namespace infini