forked from jiuyuan/InfiniTensor
fix: 修改 clip 接口
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
3141676b18
commit
1ec2fd09e5
|
@ -36,13 +36,13 @@ class ClipObj : public OperatorObj {
|
||||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||||
|
|
||||||
std::string toString() const override;
|
std::string toString() const override;
|
||||||
std::optional<float> getMin() const { return minValue; };
|
float getMin() const { return minValue; };
|
||||||
std::optional<float> getMax() const { return maxValue; };
|
float getMax() const { return maxValue; };
|
||||||
int numInputs() const override { return 1; }
|
int numInputs() const override { return 1; }
|
||||||
int numOutputs() const override { return 1; }
|
int numOutputs() const override { return 1; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::optional<float> minValue, maxValue;
|
float minValue, maxValue;
|
||||||
vector<int> getWorkloadVector() const override;
|
vector<int> getWorkloadVector() const override;
|
||||||
vector<int> getOpAttrVector() const override;
|
vector<int> getOpAttrVector() const override;
|
||||||
};
|
};
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
#include "core/graph.h"
|
#include "core/graph.h"
|
||||||
#include "core/runtime.h"
|
#include "core/runtime.h"
|
||||||
#include "nnet/dbg.h"
|
|
||||||
#include "operators/batch_norm.h"
|
#include "operators/batch_norm.h"
|
||||||
#include "operators/concat.h"
|
#include "operators/concat.h"
|
||||||
#include "operators/conv.h"
|
#include "operators/conv.h"
|
||||||
|
@ -149,7 +148,6 @@ Graph convertNCHWtoNHWCModel(Graph inG) {
|
||||||
g->addOp<ReduceMeanObj>(inputs[0], nullptr, axes_vector,
|
g->addOp<ReduceMeanObj>(inputs[0], nullptr, axes_vector,
|
||||||
eOp->getKeepDims());
|
eOp->getKeepDims());
|
||||||
} else {
|
} else {
|
||||||
dbg(op);
|
|
||||||
for (auto &t : inputs) {
|
for (auto &t : inputs) {
|
||||||
if (t->getDims().size() != 4)
|
if (t->getDims().size() != 4)
|
||||||
IT_TODO_HALT();
|
IT_TODO_HALT();
|
||||||
|
|
|
@ -68,9 +68,9 @@ template <typename T> class Clip : public CpuKernelWithoutConfig {
|
||||||
auto n = op->getOutput()->size();
|
auto n = op->getOutput()->size();
|
||||||
for (size_t offset = 0; offset < n; offset++) {
|
for (size_t offset = 0; offset < n; offset++) {
|
||||||
auto val = *inptr++;
|
auto val = *inptr++;
|
||||||
*outptr++ = (minValue && val < *minValue) ? *minValue
|
*outptr++ = val < minValue ? minValue
|
||||||
: (maxValue && val > *maxValue) ? *maxValue
|
: val > maxValue ? maxValue
|
||||||
: val;
|
: val;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -34,8 +34,8 @@ vector<int> UnaryObj::getOpAttrVector() const {
|
||||||
|
|
||||||
ClipObj::ClipObj(GraphObj *graph, Tensor input, Tensor output,
|
ClipObj::ClipObj(GraphObj *graph, Tensor input, Tensor output,
|
||||||
std::optional<float> min, std::optional<float> max)
|
std::optional<float> min, std::optional<float> max)
|
||||||
: OperatorObj(OpType::Clip, {input}, {output}), minValue(min),
|
: OperatorObj(OpType::Clip, {input}, {output}),
|
||||||
maxValue(max) {
|
minValue(min ? *min : -INFINITY), maxValue(max ? *max : INFINITY) {
|
||||||
IT_ASSERT(checkValid(graph));
|
IT_ASSERT(checkValid(graph));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue