forked from jiuyuan/InfiniTensor
fix: 修改 clip 接口
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
3141676b18
commit
1ec2fd09e5
|
@ -36,13 +36,13 @@ class ClipObj : public OperatorObj {
|
|||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
|
||||
std::string toString() const override;
|
||||
std::optional<float> getMin() const { return minValue; };
|
||||
std::optional<float> getMax() const { return maxValue; };
|
||||
float getMin() const { return minValue; };
|
||||
float getMax() const { return maxValue; };
|
||||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
private:
|
||||
std::optional<float> minValue, maxValue;
|
||||
float minValue, maxValue;
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "nnet/dbg.h"
|
||||
#include "operators/batch_norm.h"
|
||||
#include "operators/concat.h"
|
||||
#include "operators/conv.h"
|
||||
|
@ -149,7 +148,6 @@ Graph convertNCHWtoNHWCModel(Graph inG) {
|
|||
g->addOp<ReduceMeanObj>(inputs[0], nullptr, axes_vector,
|
||||
eOp->getKeepDims());
|
||||
} else {
|
||||
dbg(op);
|
||||
for (auto &t : inputs) {
|
||||
if (t->getDims().size() != 4)
|
||||
IT_TODO_HALT();
|
||||
|
|
|
@ -68,9 +68,9 @@ template <typename T> class Clip : public CpuKernelWithoutConfig {
|
|||
auto n = op->getOutput()->size();
|
||||
for (size_t offset = 0; offset < n; offset++) {
|
||||
auto val = *inptr++;
|
||||
*outptr++ = (minValue && val < *minValue) ? *minValue
|
||||
: (maxValue && val > *maxValue) ? *maxValue
|
||||
: val;
|
||||
*outptr++ = val < minValue ? minValue
|
||||
: val > maxValue ? maxValue
|
||||
: val;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
@ -34,8 +34,8 @@ vector<int> UnaryObj::getOpAttrVector() const {
|
|||
|
||||
ClipObj::ClipObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
std::optional<float> min, std::optional<float> max)
|
||||
: OperatorObj(OpType::Clip, {input}, {output}), minValue(min),
|
||||
maxValue(max) {
|
||||
: OperatorObj(OpType::Clip, {input}, {output}),
|
||||
minValue(min ? *min : -INFINITY), maxValue(max ? *max : INFINITY) {
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue