fix: 修改 clip 接口

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-06-30 14:16:09 +08:00
parent 3141676b18
commit 1ec2fd09e5
4 changed files with 8 additions and 10 deletions

View File

@ -36,13 +36,13 @@ class ClipObj : public OperatorObj {
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;
std::optional<float> getMin() const { return minValue; };
std::optional<float> getMax() const { return maxValue; };
float getMin() const { return minValue; };
float getMax() const { return maxValue; };
int numInputs() const override { return 1; }
int numOutputs() const override { return 1; }
private:
std::optional<float> minValue, maxValue;
float minValue, maxValue;
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};

View File

@ -1,6 +1,5 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "nnet/dbg.h"
#include "operators/batch_norm.h"
#include "operators/concat.h"
#include "operators/conv.h"
@ -149,7 +148,6 @@ Graph convertNCHWtoNHWCModel(Graph inG) {
g->addOp<ReduceMeanObj>(inputs[0], nullptr, axes_vector,
eOp->getKeepDims());
} else {
dbg(op);
for (auto &t : inputs) {
if (t->getDims().size() != 4)
IT_TODO_HALT();

View File

@ -68,9 +68,9 @@ template <typename T> class Clip : public CpuKernelWithoutConfig {
auto n = op->getOutput()->size();
for (size_t offset = 0; offset < n; offset++) {
auto val = *inptr++;
*outptr++ = (minValue && val < *minValue) ? *minValue
: (maxValue && val > *maxValue) ? *maxValue
: val;
*outptr++ = val < minValue ? minValue
: val > maxValue ? maxValue
: val;
}
}
};

View File

@ -34,8 +34,8 @@ vector<int> UnaryObj::getOpAttrVector() const {
ClipObj::ClipObj(GraphObj *graph, Tensor input, Tensor output,
std::optional<float> min, std::optional<float> max)
: OperatorObj(OpType::Clip, {input}, {output}), minValue(min),
maxValue(max) {
: OperatorObj(OpType::Clip, {input}, {output}),
minValue(min ? *min : -INFINITY), maxValue(max ? *max : INFINITY) {
IT_ASSERT(checkValid(graph));
}