forked from jiuyuan/InfiniTensor
22 lines
445 B
C++
22 lines
445 B
C++
#pragma once
|
|
#include "nnet/visitor.h"
|
|
|
|
namespace nnet {
|
|
|
|
// Get all tensors in the stage
|
|
class GetTensorsVisitor : public ExprTreeVisitor {
|
|
private:
|
|
unordered_map<string, Tensor> tensors;
|
|
|
|
void visit_(const Tensor &c) override;
|
|
|
|
public:
|
|
GetTensorsVisitor(int _verobse = 0)
|
|
: ExprTreeVisitor(1, 1, 1, 0, _verobse) {}
|
|
auto get(const Expr &c) {
|
|
dispatch(c);
|
|
return tensors;
|
|
}
|
|
};
|
|
|
|
} // namespace nnet
|