opt: 优化 ReduceMeanObj 构造器实现

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-14 15:14:28 +08:00
parent d11fb0ad5f
commit fb9d84dbb7
2 changed files with 20 additions and 20 deletions

View File

@ -7,7 +7,7 @@ namespace infini {
*
*/
class ReduceMeanObj : public OperatorObj {
set<int> axis; // axis to reduce
set<int> axes; // axis to reduce
bool keepDims;
public:
@ -17,11 +17,11 @@ class ReduceMeanObj : public OperatorObj {
* @param graph The computation graph that this operator belongs to.
* @param input The input tensor.
* @param output The output tensor.
* @param axis Axes to reduce.
* @param axes Axes to reduce.
* @param keepDims Keep the reduced dimensions or not.
*/
ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
const optional<const vector<int>> &axis,
const optional<const vector<int>> &axes,
bool keepDims = true);
OP_CLONE(ReduceMeanObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

View File

@ -2,27 +2,27 @@
namespace infini {
ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
const optional<const vector<int>> &_axis,
const optional<const vector<int>> &_axes,
bool keepDims)
: OperatorObj(OpType::ReduceMean, {input}, {output}), keepDims(keepDims) {
if (_axis != std::nullopt) {
IT_ASSERT((*_axis).size() <= input->getDims().size());
for (size_t j = 0; j < (*_axis).size(); ++j) {
int idx = (*_axis)[j];
const auto size = input->getDims().size();
if (_axes) {
// TODO 不需要这个,但需要处理负数,一对相反数应该不能同时出现。
// IT_ASSERT((*_axes).size() <= input->getDims().size());
for (auto idx : *_axes) {
if (idx < 0)
IT_TODO_HALT();
IT_ASSERT((size_t)idx < input->getDims().size());
axis.emplace(idx);
IT_ASSERT((size_t)idx < size);
axes.emplace(idx);
}
} else
for (size_t i = 0; i < input->getDims().size(); ++i)
axis.emplace(i);
for (size_t i = 0; i < size; ++i)
axes.emplace(i);
IT_ASSERT(checkValid(graph));
}
bool ReduceMeanObj::isReduced(int idx) const {
return axis.find(idx) != axis.end();
return axes.find(idx) != axes.end();
}
optional<vector<Shape>>
@ -31,7 +31,7 @@ ReduceMeanObj::inferShape(const TensorVec &inputs) const {
if (keepDims) {
Shape ret = dims;
for (auto it : axis)
for (auto it : axes)
ret[it] = 1;
return {{ret}};
} else {
@ -55,14 +55,14 @@ std::string ReduceMeanObj::toString() const {
std::string axisstr;
axisstr.append("[");
for (auto d : axis) {
for (auto d : axes) {
axisstr.append(std::to_string(d));
axisstr.append(",");
}
if (!axis.empty())
if (!axes.empty())
axisstr.pop_back();
axisstr.append("]");
os << "axis=" << axisstr << ",";
os << "axes=" << axisstr << ",";
os << "keepDims=" << keepDims << ",";
os << "input=" << inputs[0]->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
@ -73,13 +73,13 @@ vector<int> ReduceMeanObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims();
ret.emplace(ret.begin(), enum_to_underlying(type));
ret.emplace_back((int)keepDims);
ret.insert(ret.end(), axis.begin(), axis.end());
ret.insert(ret.end(), axes.begin(), axes.end());
return ret;
}
vector<int> ReduceMeanObj::getOpAttrVector() const {
vector<int> ret = {enum_to_underlying(type), (int)keepDims};
ret.insert(ret.end(), axis.begin(), axis.end());
ret.insert(ret.end(), axes.begin(), axes.end());
return ret;
}
} // namespace infini