forked from jiuyuan/InfiniTensor
opt: optimize the ReduceMeanObj constructor implementation
Signed-off-by: YdrMaster <ydrml@hotmail.com>
parent d11fb0ad5f
commit fb9d84dbb7
@@ -7,7 +7,7 @@ namespace infini {
  *
  */
 class ReduceMeanObj : public OperatorObj {
-    set<int> axis; // axis to reduce
+    set<int> axes; // axis to reduce
     bool keepDims;

   public:
@@ -17,11 +17,11 @@ class ReduceMeanObj : public OperatorObj {
      * @param graph The computation graph that this operator belongs to.
      * @param input The input tensor.
      * @param output The output tensor.
-     * @param axis Axes to reduce.
+     * @param axes Axes to reduce.
      * @param keepDims Keep the reduced dimensions or not.
      */
     ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
-                  const optional<const vector<int>> &axis,
+                  const optional<const vector<int>> &axes,
                   bool keepDims = true);
     OP_CLONE(ReduceMeanObj);
     optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
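The constructor declared above takes the axis list as an optional<const vector<int>>; as the implementation below shows, passing std::nullopt means "reduce over every dimension". A minimal standalone sketch of that convention using plain std::optional (resolveAxes is a hypothetical name, not InfiniTensor API):

#include <cstddef>
#include <iostream>
#include <optional>
#include <set>
#include <vector>

// nullopt: the caller named no axes, so every dimension is reduced.
std::set<int> resolveAxes(const std::optional<std::vector<int>> &axes,
                          size_t rank) {
    std::set<int> result;
    if (axes) {
        for (int idx : *axes) // explicit axis list
            result.emplace(idx);
    } else {
        for (size_t i = 0; i < rank; ++i) // default: all dimensions
            result.emplace(static_cast<int>(i));
    }
    return result;
}

int main() {
    for (int a : resolveAxes(std::nullopt, 4))
        std::cout << a << ' '; // prints: 0 1 2 3
    std::cout << '\n';
}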
@@ -2,27 +2,27 @@

 namespace infini {
 ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
-                             const optional<const vector<int>> &_axis,
+                             const optional<const vector<int>> &_axes,
                              bool keepDims)
     : OperatorObj(OpType::ReduceMean, {input}, {output}), keepDims(keepDims) {
-
-    if (_axis != std::nullopt) {
-        IT_ASSERT((*_axis).size() <= input->getDims().size());
-        for (size_t j = 0; j < (*_axis).size(); ++j) {
-            int idx = (*_axis)[j];
+    const auto size = input->getDims().size();
+    if (_axes) {
+        // TODO This check is not needed, but negative values must be handled; an axis and its negation should not both appear.
+        // IT_ASSERT((*_axes).size() <= input->getDims().size());
+        for (auto idx : *_axes) {
             if (idx < 0)
                 IT_TODO_HALT();
-            IT_ASSERT((size_t)idx < input->getDims().size());
-            axis.emplace(idx);
+            IT_ASSERT((size_t)idx < size);
+            axes.emplace(idx);
         }
     } else
-        for (size_t i = 0; i < input->getDims().size(); ++i)
-            axis.emplace(i);
+        for (size_t i = 0; i < size; ++i)
+            axes.emplace(i);
     IT_ASSERT(checkValid(graph));
 }

 bool ReduceMeanObj::isReduced(int idx) const {
-    return axis.find(idx) != axis.end();
+    return axes.find(idx) != axes.end();
 }

 optional<vector<Shape>>
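The TODO in the hunk above notes two pieces of missing axis handling: negative indices (presumably counted from the last dimension, as in ONNX's ReduceMean) and the requirement that an axis and its negative counterpart not be given together. A self-contained sketch of that normalization follows; it is not part of this commit, normalizeAxes is a hypothetical helper, and plain assert stands in for the project's IT_ASSERT:

#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

// Maps each axis into [0, rank) and rejects pairs such as {1, -2} on a
// rank-3 tensor, which name the same axis twice.
std::set<int> normalizeAxes(const std::vector<int> &axes, size_t rank) {
    std::set<int> result;
    for (int idx : axes) {
        int normalized = idx < 0 ? idx + static_cast<int>(rank) : idx;
        assert(normalized >= 0 && static_cast<size_t>(normalized) < rank);
        bool inserted = result.insert(normalized).second;
        assert(inserted && "an axis and its negative counterpart were both given");
        (void)inserted; // silence unused-variable warning when asserts are disabled
    }
    return result;
}

int main() {
    auto axes = normalizeAxes({-1, 0}, /*rank=*/3); // -1 maps to axis 2
    assert((axes == std::set<int>{0, 2}));
    return 0;
}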
@@ -31,7 +31,7 @@ ReduceMeanObj::inferShape(const TensorVec &inputs) const {

     if (keepDims) {
         Shape ret = dims;
-        for (auto it : axis)
+        for (auto it : axes)
             ret[it] = 1;
         return {{ret}};
     } else {
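For context, the keepDims branch shown in this hunk sets each reduced dimension to 1; the else branch (not shown here) presumably drops reduced dimensions instead. A standalone sketch of that shape rule, assuming only that Shape behaves like a vector of ints and ignoring edge cases such as reducing every axis:

#include <cassert>
#include <set>
#include <vector>

using Shape = std::vector<int>; // assumption: Shape is vector-like

Shape reducedShape(const Shape &dims, const std::set<int> &axes, bool keepDims) {
    Shape ret;
    for (int i = 0; i < static_cast<int>(dims.size()); ++i) {
        if (axes.count(i) == 0)
            ret.push_back(dims[i]); // dimension not reduced
        else if (keepDims)
            ret.push_back(1);       // reduced but kept as size 1
        // else: reduced dimension is dropped
    }
    return ret;
}

int main() {
    assert((reducedShape({2, 3, 4}, {1}, true) == Shape{2, 1, 4}));
    assert((reducedShape({2, 3, 4}, {1}, false) == Shape{2, 4}));
    return 0;
}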
@@ -55,14 +55,14 @@ std::string ReduceMeanObj::toString() const {

     std::string axisstr;
     axisstr.append("[");
-    for (auto d : axis) {
+    for (auto d : axes) {
         axisstr.append(std::to_string(d));
         axisstr.append(",");
     }
-    if (!axis.empty())
+    if (!axes.empty())
         axisstr.pop_back();
     axisstr.append("]");
-    os << "axis=" << axisstr << ",";
+    os << "axes=" << axisstr << ",";
     os << "keepDims=" << keepDims << ",";
     os << "input=" << inputs[0]->getGuid() << ",";
     os << "output=" << outputs[0]->getGuid() << ")";
@@ -73,13 +73,13 @@ vector<int> ReduceMeanObj::getWorkloadVector() const {
     vector<int> ret = inputs[0]->getDims();
     ret.emplace(ret.begin(), enum_to_underlying(type));
     ret.emplace_back((int)keepDims);
-    ret.insert(ret.end(), axis.begin(), axis.end());
+    ret.insert(ret.end(), axes.begin(), axes.end());
     return ret;
 }

 vector<int> ReduceMeanObj::getOpAttrVector() const {
     vector<int> ret = {enum_to_underlying(type), (int)keepDims};
-    ret.insert(ret.end(), axis.begin(), axis.end());
+    ret.insert(ret.end(), axes.begin(), axes.end());
     return ret;
 }
 } // namespace infini