forked from jiuyuan/InfiniTensor
opt: 优化 ReduceMeanObj 构造器实现
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
d11fb0ad5f
commit
fb9d84dbb7
|
@ -7,7 +7,7 @@ namespace infini {
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
class ReduceMeanObj : public OperatorObj {
|
class ReduceMeanObj : public OperatorObj {
|
||||||
set<int> axis; // axis to reduce
|
set<int> axes; // axis to reduce
|
||||||
bool keepDims;
|
bool keepDims;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
@ -17,11 +17,11 @@ class ReduceMeanObj : public OperatorObj {
|
||||||
* @param graph The computation graph that this operator belongs to.
|
* @param graph The computation graph that this operator belongs to.
|
||||||
* @param input The input tensor.
|
* @param input The input tensor.
|
||||||
* @param output The output tensor.
|
* @param output The output tensor.
|
||||||
* @param axis Axes to reduce.
|
* @param axes Axes to reduce.
|
||||||
* @param keepDims Keep the reduced dimensions or not.
|
* @param keepDims Keep the reduced dimensions or not.
|
||||||
*/
|
*/
|
||||||
ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
|
ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
|
||||||
const optional<const vector<int>> &axis,
|
const optional<const vector<int>> &axes,
|
||||||
bool keepDims = true);
|
bool keepDims = true);
|
||||||
OP_CLONE(ReduceMeanObj);
|
OP_CLONE(ReduceMeanObj);
|
||||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||||
|
|
|
@ -2,27 +2,27 @@
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
|
ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
|
||||||
const optional<const vector<int>> &_axis,
|
const optional<const vector<int>> &_axes,
|
||||||
bool keepDims)
|
bool keepDims)
|
||||||
: OperatorObj(OpType::ReduceMean, {input}, {output}), keepDims(keepDims) {
|
: OperatorObj(OpType::ReduceMean, {input}, {output}), keepDims(keepDims) {
|
||||||
|
const auto size = input->getDims().size();
|
||||||
if (_axis != std::nullopt) {
|
if (_axes) {
|
||||||
IT_ASSERT((*_axis).size() <= input->getDims().size());
|
// TODO 不需要这个,但需要处理负数,一对相反数应该不能同时出现。
|
||||||
for (size_t j = 0; j < (*_axis).size(); ++j) {
|
// IT_ASSERT((*_axes).size() <= input->getDims().size());
|
||||||
int idx = (*_axis)[j];
|
for (auto idx : *_axes) {
|
||||||
if (idx < 0)
|
if (idx < 0)
|
||||||
IT_TODO_HALT();
|
IT_TODO_HALT();
|
||||||
IT_ASSERT((size_t)idx < input->getDims().size());
|
IT_ASSERT((size_t)idx < size);
|
||||||
axis.emplace(idx);
|
axes.emplace(idx);
|
||||||
}
|
}
|
||||||
} else
|
} else
|
||||||
for (size_t i = 0; i < input->getDims().size(); ++i)
|
for (size_t i = 0; i < size; ++i)
|
||||||
axis.emplace(i);
|
axes.emplace(i);
|
||||||
IT_ASSERT(checkValid(graph));
|
IT_ASSERT(checkValid(graph));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ReduceMeanObj::isReduced(int idx) const {
|
bool ReduceMeanObj::isReduced(int idx) const {
|
||||||
return axis.find(idx) != axis.end();
|
return axes.find(idx) != axes.end();
|
||||||
}
|
}
|
||||||
|
|
||||||
optional<vector<Shape>>
|
optional<vector<Shape>>
|
||||||
|
@ -31,7 +31,7 @@ ReduceMeanObj::inferShape(const TensorVec &inputs) const {
|
||||||
|
|
||||||
if (keepDims) {
|
if (keepDims) {
|
||||||
Shape ret = dims;
|
Shape ret = dims;
|
||||||
for (auto it : axis)
|
for (auto it : axes)
|
||||||
ret[it] = 1;
|
ret[it] = 1;
|
||||||
return {{ret}};
|
return {{ret}};
|
||||||
} else {
|
} else {
|
||||||
|
@ -55,14 +55,14 @@ std::string ReduceMeanObj::toString() const {
|
||||||
|
|
||||||
std::string axisstr;
|
std::string axisstr;
|
||||||
axisstr.append("[");
|
axisstr.append("[");
|
||||||
for (auto d : axis) {
|
for (auto d : axes) {
|
||||||
axisstr.append(std::to_string(d));
|
axisstr.append(std::to_string(d));
|
||||||
axisstr.append(",");
|
axisstr.append(",");
|
||||||
}
|
}
|
||||||
if (!axis.empty())
|
if (!axes.empty())
|
||||||
axisstr.pop_back();
|
axisstr.pop_back();
|
||||||
axisstr.append("]");
|
axisstr.append("]");
|
||||||
os << "axis=" << axisstr << ",";
|
os << "axes=" << axisstr << ",";
|
||||||
os << "keepDims=" << keepDims << ",";
|
os << "keepDims=" << keepDims << ",";
|
||||||
os << "input=" << inputs[0]->getGuid() << ",";
|
os << "input=" << inputs[0]->getGuid() << ",";
|
||||||
os << "output=" << outputs[0]->getGuid() << ")";
|
os << "output=" << outputs[0]->getGuid() << ")";
|
||||||
|
@ -73,13 +73,13 @@ vector<int> ReduceMeanObj::getWorkloadVector() const {
|
||||||
vector<int> ret = inputs[0]->getDims();
|
vector<int> ret = inputs[0]->getDims();
|
||||||
ret.emplace(ret.begin(), enum_to_underlying(type));
|
ret.emplace(ret.begin(), enum_to_underlying(type));
|
||||||
ret.emplace_back((int)keepDims);
|
ret.emplace_back((int)keepDims);
|
||||||
ret.insert(ret.end(), axis.begin(), axis.end());
|
ret.insert(ret.end(), axes.begin(), axes.end());
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
vector<int> ReduceMeanObj::getOpAttrVector() const {
|
vector<int> ReduceMeanObj::getOpAttrVector() const {
|
||||||
vector<int> ret = {enum_to_underlying(type), (int)keepDims};
|
vector<int> ret = {enum_to_underlying(type), (int)keepDims};
|
||||||
ret.insert(ret.end(), axis.begin(), axis.end());
|
ret.insert(ret.end(), axes.begin(), axes.end());
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
Loading…
Reference in New Issue