opt: 优化 ReduceMeanObj 构造器实现

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-14 15:14:28 +08:00
parent d11fb0ad5f
commit fb9d84dbb7
2 changed files with 20 additions and 20 deletions

View File

@ -7,7 +7,7 @@ namespace infini {
* *
*/ */
class ReduceMeanObj : public OperatorObj { class ReduceMeanObj : public OperatorObj {
set<int> axis; // axis to reduce set<int> axes; // axis to reduce
bool keepDims; bool keepDims;
public: public:
@ -17,11 +17,11 @@ class ReduceMeanObj : public OperatorObj {
* @param graph The computation graph that this operator belongs to. * @param graph The computation graph that this operator belongs to.
* @param input The input tensor. * @param input The input tensor.
* @param output The output tensor. * @param output The output tensor.
* @param axis Axes to reduce. * @param axes Axes to reduce.
* @param keepDims Keep the reduced dimensions or not. * @param keepDims Keep the reduced dimensions or not.
*/ */
ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output, ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
const optional<const vector<int>> &axis, const optional<const vector<int>> &axes,
bool keepDims = true); bool keepDims = true);
OP_CLONE(ReduceMeanObj); OP_CLONE(ReduceMeanObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

View File

@ -2,27 +2,27 @@
namespace infini { namespace infini {
ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output, ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
const optional<const vector<int>> &_axis, const optional<const vector<int>> &_axes,
bool keepDims) bool keepDims)
: OperatorObj(OpType::ReduceMean, {input}, {output}), keepDims(keepDims) { : OperatorObj(OpType::ReduceMean, {input}, {output}), keepDims(keepDims) {
const auto size = input->getDims().size();
if (_axis != std::nullopt) { if (_axes) {
IT_ASSERT((*_axis).size() <= input->getDims().size()); // TODO 不需要这个,但需要处理负数,一对相反数应该不能同时出现。
for (size_t j = 0; j < (*_axis).size(); ++j) { // IT_ASSERT((*_axes).size() <= input->getDims().size());
int idx = (*_axis)[j]; for (auto idx : *_axes) {
if (idx < 0) if (idx < 0)
IT_TODO_HALT(); IT_TODO_HALT();
IT_ASSERT((size_t)idx < input->getDims().size()); IT_ASSERT((size_t)idx < size);
axis.emplace(idx); axes.emplace(idx);
} }
} else } else
for (size_t i = 0; i < input->getDims().size(); ++i) for (size_t i = 0; i < size; ++i)
axis.emplace(i); axes.emplace(i);
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
bool ReduceMeanObj::isReduced(int idx) const { bool ReduceMeanObj::isReduced(int idx) const {
return axis.find(idx) != axis.end(); return axes.find(idx) != axes.end();
} }
optional<vector<Shape>> optional<vector<Shape>>
@ -31,7 +31,7 @@ ReduceMeanObj::inferShape(const TensorVec &inputs) const {
if (keepDims) { if (keepDims) {
Shape ret = dims; Shape ret = dims;
for (auto it : axis) for (auto it : axes)
ret[it] = 1; ret[it] = 1;
return {{ret}}; return {{ret}};
} else { } else {
@ -55,14 +55,14 @@ std::string ReduceMeanObj::toString() const {
std::string axisstr; std::string axisstr;
axisstr.append("["); axisstr.append("[");
for (auto d : axis) { for (auto d : axes) {
axisstr.append(std::to_string(d)); axisstr.append(std::to_string(d));
axisstr.append(","); axisstr.append(",");
} }
if (!axis.empty()) if (!axes.empty())
axisstr.pop_back(); axisstr.pop_back();
axisstr.append("]"); axisstr.append("]");
os << "axis=" << axisstr << ","; os << "axes=" << axisstr << ",";
os << "keepDims=" << keepDims << ","; os << "keepDims=" << keepDims << ",";
os << "input=" << inputs[0]->getGuid() << ","; os << "input=" << inputs[0]->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")"; os << "output=" << outputs[0]->getGuid() << ")";
@ -73,13 +73,13 @@ vector<int> ReduceMeanObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims(); vector<int> ret = inputs[0]->getDims();
ret.emplace(ret.begin(), enum_to_underlying(type)); ret.emplace(ret.begin(), enum_to_underlying(type));
ret.emplace_back((int)keepDims); ret.emplace_back((int)keepDims);
ret.insert(ret.end(), axis.begin(), axis.end()); ret.insert(ret.end(), axes.begin(), axes.end());
return ret; return ret;
} }
vector<int> ReduceMeanObj::getOpAttrVector() const { vector<int> ReduceMeanObj::getOpAttrVector() const {
vector<int> ret = {enum_to_underlying(type), (int)keepDims}; vector<int> ret = {enum_to_underlying(type), (int)keepDims};
ret.insert(ret.end(), axis.begin(), axis.end()); ret.insert(ret.end(), axes.begin(), axes.end());
return ret; return ret;
} }
} // namespace infini } // namespace infini