InfiniTensor/include/nnet/Visitor/GetTensorsVisitor.h

22 lines
445 B
C
Raw Normal View History

2022-08-08 16:02:07 +08:00
#pragma once
#include "nnet/visitor.h"
namespace nnet {
/// Expression-tree visitor that gathers Tensor nodes into a string-keyed map.
///
/// The gathering is done by the visit_(const Tensor &) override, which is only
/// declared in this header; its definition lives in another translation unit.
class GetTensorsVisitor : public ExprTreeVisitor {
  private:
    /// Tensors gathered so far.
    /// NOTE(review): presumably keyed by tensor name -- confirm against the
    /// out-of-line definition of visit_.
    unordered_map<string, Tensor> tensors;

    /// Override of the ExprTreeVisitor hook for Tensor nodes.
    void visit_(const Tensor &c) override;

  public:
    /// @param _verbose Forwarded unchanged as the last argument of the
    ///        ExprTreeVisitor constructor.
    ///
    /// The leading arguments (1, 1, 1, 0) are fixed here; their meaning is
    /// defined by ExprTreeVisitor (see nnet/visitor.h).
    GetTensorsVisitor(int _verbose = 0)
        : ExprTreeVisitor(1, 1, 1, 0, _verbose) {}

    /// Dispatches the visitor on `c`, then returns the tensor map.
    ///
    /// The return type is deduced from `tensors`, so the caller receives a
    /// copy of the map, not a reference into this visitor.
    /// This function does not reset `tensors` before dispatching.
    /// NOTE(review): repeated calls on one instance therefore presumably
    /// accumulate results -- confirm whether that is intended.
    auto get(const Expr &c) {
        dispatch(c);
        return tensors;
    }
};
} // namespace nnet