InfiniTensor/include/nnet/Visitor/InputVisitor.h

23 lines
491 B
C++

#pragma once
#include "nnet/visitor.h"
namespace nnet {
/**
 * @brief Expression-tree visitor that returns a list of Tensor nodes for a
 * given RangeOp.
 *
 * Only visit_(const Tensor &) is overridden in this class; its definition is
 * not part of this header. Presumably it records each visited tensor in
 * `inputs` -- confirm against the implementation file.
 */
class InputVisitor : public ExprTreeVisitor {
  private:
    // Buffer handed back (by copy) from getInputs().
    vector<Tensor> inputs;

  public:
    // Public counter starting at zero; nothing in this header modifies it.
    int nInputs{0};

    /**
     * @brief Construct the visitor.
     *
     * The first four base-class arguments are fixed at (1, 1, 1, 0); their
     * meaning is defined by ExprTreeVisitor and is not visible from here.
     *
     * @param _verobse Forwarded unchanged as the fifth ExprTreeVisitor
     * argument (presumably a verbosity level -- confirm in nnet/visitor.h).
     */
    InputVisitor(int _verobse = 0) : ExprTreeVisitor(1, 1, 1, 0, _verobse) {}

    // Tensor-node hook; declared here, defined outside this header.
    void visit_(const Tensor &c) override;

    /**
     * @brief Get all inputs in the nested stages.
     *
     * Dispatches this visitor on `_rangeOp`, then returns a copy of the
     * internal `inputs` vector.
     *
     * NOTE(review): `inputs` is not cleared before dispatching, so anything
     * left in it by an earlier call on the same object is returned as well.
     *
     * @param _rangeOp Root expression to traverse.
     * @return Copy of the tensors held in `inputs` after the traversal.
     */
    vector<Tensor> getInputs(const RangeOp &_rangeOp) {
        dispatch(_rangeOp);
        return inputs;
    }
};
} // namespace nnet