fix: 修改 clip 接口

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-06-30 14:16:09 +08:00
parent 3141676b18
commit 1ec2fd09e5
4 changed files with 8 additions and 10 deletions

View File

@ -36,13 +36,13 @@ class ClipObj : public OperatorObj {
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override; std::string toString() const override;
std::optional<float> getMin() const { return minValue; }; float getMin() const { return minValue; };
std::optional<float> getMax() const { return maxValue; }; float getMax() const { return maxValue; };
int numInputs() const override { return 1; } int numInputs() const override { return 1; }
int numOutputs() const override { return 1; } int numOutputs() const override { return 1; }
private: private:
std::optional<float> minValue, maxValue; float minValue, maxValue;
vector<int> getWorkloadVector() const override; vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override; vector<int> getOpAttrVector() const override;
}; };

View File

@ -1,6 +1,5 @@
#include "core/graph.h" #include "core/graph.h"
#include "core/runtime.h" #include "core/runtime.h"
#include "nnet/dbg.h"
#include "operators/batch_norm.h" #include "operators/batch_norm.h"
#include "operators/concat.h" #include "operators/concat.h"
#include "operators/conv.h" #include "operators/conv.h"
@ -149,7 +148,6 @@ Graph convertNCHWtoNHWCModel(Graph inG) {
g->addOp<ReduceMeanObj>(inputs[0], nullptr, axes_vector, g->addOp<ReduceMeanObj>(inputs[0], nullptr, axes_vector,
eOp->getKeepDims()); eOp->getKeepDims());
} else { } else {
dbg(op);
for (auto &t : inputs) { for (auto &t : inputs) {
if (t->getDims().size() != 4) if (t->getDims().size() != 4)
IT_TODO_HALT(); IT_TODO_HALT();

View File

@ -68,9 +68,9 @@ template <typename T> class Clip : public CpuKernelWithoutConfig {
auto n = op->getOutput()->size(); auto n = op->getOutput()->size();
for (size_t offset = 0; offset < n; offset++) { for (size_t offset = 0; offset < n; offset++) {
auto val = *inptr++; auto val = *inptr++;
*outptr++ = (minValue && val < *minValue) ? *minValue *outptr++ = val < minValue ? minValue
: (maxValue && val > *maxValue) ? *maxValue : val > maxValue ? maxValue
: val; : val;
} }
} }
}; };

View File

@ -34,8 +34,8 @@ vector<int> UnaryObj::getOpAttrVector() const {
ClipObj::ClipObj(GraphObj *graph, Tensor input, Tensor output, ClipObj::ClipObj(GraphObj *graph, Tensor input, Tensor output,
std::optional<float> min, std::optional<float> max) std::optional<float> min, std::optional<float> max)
: OperatorObj(OpType::Clip, {input}, {output}), minValue(min), : OperatorObj(OpType::Clip, {input}, {output}),
maxValue(max) { minValue(min ? *min : -INFINITY), maxValue(max ? *max : INFINITY) {
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }