From b7e2096a266e7fc6829e72d7af37a4d6c23fcfec Mon Sep 17 00:00:00 2001 From: Liyan Zheng Date: Mon, 8 Aug 2022 16:02:07 +0800 Subject: [PATCH] Add: nnet code --- CMakeLists.txt | 47 +- include/nnet/Pass/MatchComputationKernel.h | 15 + include/nnet/Pass/MatchMemBoundKernel.h | 15 + include/nnet/Pass/Pass.h | 41 + include/nnet/Pass/Rule1VariableSplit.h | 18 + include/nnet/Pass/Rule2VariableMerging.h | 29 + include/nnet/Pass/Rule3StageSplit.h | 19 + include/nnet/Pass/Rule4StageMerging.h | 22 + include/nnet/Pass/Rule5RangeRelaxation.h | 16 + include/nnet/Pass/Rule6KenerlMatching.h | 17 + include/nnet/Pass/Rule7DLT.h | 16 + include/nnet/Pass/Rule8GuidedDLT.h | 48 + include/nnet/Pass/Rule90TwoStageElementWise.h | 16 + include/nnet/Pass/Rule91MergeStagesWithSum.h | 15 + include/nnet/Pass/Rule9RangeMagnify.h | 15 + include/nnet/ReplaceKit.h | 46 + include/nnet/Visitor/AsTVMVisitor.h | 38 + include/nnet/Visitor/CheckOOBVisitor.h | 22 + include/nnet/Visitor/CloneMutator.h | 16 + .../Visitor/CompareMultiFormulasVisitor.h | 15 + include/nnet/Visitor/CountRoutineVisitor.h | 18 + include/nnet/Visitor/FullPrinterVisitor.h | 25 + include/nnet/Visitor/GetTensorsVisitor.h | 22 + include/nnet/Visitor/HashVisitor.h | 31 + include/nnet/Visitor/InputVisitor.h | 23 + include/nnet/Visitor/Interpreter.h | 55 + include/nnet/Visitor/MatchReshapeVisitor.h | 14 + include/nnet/Visitor/MatchTableVisitor.h | 60 + include/nnet/Visitor/MatmulTransposeMutator.h | 18 + include/nnet/Visitor/MergeMemboundMutator.h | 20 + include/nnet/Visitor/PatternMatcher.h | 43 + include/nnet/Visitor/RangeMagnifyVisitor.h | 25 + include/nnet/Visitor/RangeRelaxFunctor.h | 18 + include/nnet/Visitor/ReplaceNodeMutator.h | 20 + include/nnet/Visitor/ReplaceVariable.h | 33 + include/nnet/Visitor/Serializer.h | 52 + include/nnet/Visitor/SimplifyExprVisitor.h | 38 + include/nnet/Visitor/SimplifyFormulaMutator.h | 18 + include/nnet/Visitor/StrideVisitor.h | 38 + include/nnet/common.h | 77 + include/nnet/dbg.h | 893 ++++++++ include/nnet/derivator.h | 156 ++ include/nnet/dlt.h | 53 + include/nnet/expr.h | 416 ++++ include/nnet/iterator_table.h | 234 +++ include/nnet/nmutator.h | 57 + include/nnet/permutation.h | 38 + include/nnet/ref.h | 200 ++ include/nnet/routine.h | 158 ++ include/nnet/test.h | 28 + include/nnet/visitor.h | 128 ++ src/nnet/Pass/MatchComputationKernel.cc | 25 + src/nnet/Pass/MatchMemBoundKernel.cc | 23 + src/nnet/Pass/Pass.cc | 58 + src/nnet/Pass/Rule1VariableSplit.cc | 134 ++ src/nnet/Pass/Rule2VariableMerging.cc | 186 ++ src/nnet/Pass/Rule3StageSplit.cc | 82 + src/nnet/Pass/Rule4StageMerging.cc | 91 + src/nnet/Pass/Rule5RangeRelaxation.cc | 72 + src/nnet/Pass/Rule6KenerlMatching.cc | 57 + src/nnet/Pass/Rule7DLT.cc | 78 + src/nnet/Pass/Rule8GuidedDLT.cc | 317 +++ src/nnet/Pass/Rule90TwoStageElementWise.cc | 54 + src/nnet/Pass/Rule91MergeStagesWithSum.cc | 11 + src/nnet/Pass/Rule9RangeMagnify.cc | 45 + src/nnet/Visitor/AsTVMVisitor.cc | 165 ++ src/nnet/Visitor/CheckOOBVisitor.cc | 35 + src/nnet/Visitor/CloneMutator.cc | 9 + .../Visitor/CompareMultiFormulasVisitor.cc | 34 + src/nnet/Visitor/CountRoutineVisitor.cc | 38 + src/nnet/Visitor/FullPrinterVisitor.cc | 58 + src/nnet/Visitor/GetTensorsVisitor.cc | 9 + src/nnet/Visitor/InputVisitor.cc | 7 + src/nnet/Visitor/Interpreter.cc | 174 ++ src/nnet/Visitor/MatchReshapeVisitor.cc | 46 + src/nnet/Visitor/MatchTableVisitor.cc | 56 + src/nnet/Visitor/MatmulTransposeMutator.cc | 100 + src/nnet/Visitor/MergeMemboundMutator.cc | 135 ++ src/nnet/Visitor/PatternMatcher.cc | 124 ++ src/nnet/Visitor/RangeMagnifyVisitor.cc | 56 + src/nnet/Visitor/RangeRelaxFunctor.cc | 133 ++ src/nnet/Visitor/ReplaceNodeMutator.cc | 23 + src/nnet/Visitor/ReplaceVariable.cc | 50 + src/nnet/Visitor/Serializer.cc | 245 +++ src/nnet/Visitor/SimplifyExprVisitor.cc | 160 ++ src/nnet/Visitor/SimplifyFormulaMutator.cc | 30 + src/nnet/Visitor/StrideVisitor.cc | 59 + src/nnet/Visitor/hashVisitor.cc | 156 ++ src/nnet/derivator.cc | 529 +++++ src/nnet/dlt.cc | 131 ++ src/nnet/expr.cc | 452 ++++ src/nnet/iterator_table.cc | 653 ++++++ src/nnet/nmutator.cc | 722 +++++++ src/nnet/permutation.cc | 35 + src/nnet/replace_kit.cc | 92 + src/nnet/routine.cc | 136 ++ src/nnet/test.cc | 70 + src/nnet/utils.cc | 12 + src/nnet/visitor.cc | 97 + test/nnet/failed/test_activation.cc | 49 + test/nnet/failed/test_csrnet.cc | 81 + test/nnet/failed/test_longformer.cc | 351 ++++ test/nnet/failed/test_multiple_op.cc | 30 + test/nnet/failed/test_mutator.cc | 407 ++++ test/nnet/failed/test_operatorCompute.cc | 34 + test/nnet/failed/test_transposeOpToExpr.cc | 26 + .../TConv4x4_NHWF_RSFC_0.expr | 395 ++++ .../TConv4x4_NHWF_RSFC_1.expr | 492 +++++ .../TConv4x4_NHWF_RSFC_10.expr | 472 +++++ .../TConv4x4_NHWF_RSFC_11.expr | 683 ++++++ .../TConv4x4_NHWF_RSFC_12.expr | 874 ++++++++ .../TConv4x4_NHWF_RSFC_13.expr | 819 ++++++++ .../TConv4x4_NHWF_RSFC_14.expr | 854 ++++++++ .../TConv4x4_NHWF_RSFC_15.expr | 819 ++++++++ .../TConv4x4_NHWF_RSFC_16.expr | 819 ++++++++ .../TConv4x4_NHWF_RSFC_17.expr | 1134 ++++++++++ .../TConv4x4_NHWF_RSFC_18.expr | 1829 +++++++++++++++++ .../TConv4x4_NHWF_RSFC_2.expr | 587 ++++++ .../TConv4x4_NHWF_RSFC_3.expr | 492 +++++ .../TConv4x4_NHWF_RSFC_4.expr | 587 ++++++ .../TConv4x4_NHWF_RSFC_5.expr | 492 +++++ .../TConv4x4_NHWF_RSFC_6.expr | 577 ++++++ .../TConv4x4_NHWF_RSFC_7.expr | 482 +++++ .../TConv4x4_NHWF_RSFC_8.expr | 557 +++++ .../TConv4x4_NHWF_RSFC_9.expr | 472 +++++ .../conv2conv/Conv2conv_5x5_NCHW_FCRS_0.expr | 180 ++ .../conv2conv/Conv2conv_5x5_NCHW_FCRS_1.expr | 180 ++ .../conv2conv/Conv2conv_5x5_NCHW_FCRS_10.expr | 572 ++++++ .../conv2conv/Conv2conv_5x5_NCHW_FCRS_11.expr | 627 ++++++ .../conv2conv/Conv2conv_5x5_NCHW_FCRS_12.expr | 572 ++++++ .../conv2conv/Conv2conv_5x5_NCHW_FCRS_13.expr | 811 ++++++++ .../conv2conv/Conv2conv_5x5_NCHW_FCRS_14.expr | 736 +++++++ .../conv2conv/Conv2conv_5x5_NCHW_FCRS_15.expr | 1353 ++++++++++++ .../conv2conv/Conv2conv_5x5_NCHW_FCRS_2.expr | 224 ++ .../conv2conv/Conv2conv_5x5_NCHW_FCRS_3.expr | 268 +++ .../conv2conv/Conv2conv_5x5_NCHW_FCRS_4.expr | 345 ++++ .../conv2conv/Conv2conv_5x5_NCHW_FCRS_5.expr | 420 ++++ .../conv2conv/Conv2conv_5x5_NCHW_FCRS_6.expr | 345 ++++ .../conv2conv/Conv2conv_5x5_NCHW_FCRS_7.expr | 420 ++++ .../conv2conv/Conv2conv_5x5_NCHW_FCRS_8.expr | 345 ++++ .../conv2conv/Conv2conv_5x5_NCHW_FCRS_9.expr | 345 ++++ .../conv2conv/Conv2conv_6x6_NCHW_FCRS_0.expr | 179 ++ .../conv2conv/Conv2conv_6x6_NCHW_FCRS_1.expr | 223 ++ .../conv2conv/Conv2conv_6x6_NCHW_FCRS_10.expr | 636 ++++++ .../conv2conv/Conv2conv_6x6_NCHW_FCRS_11.expr | 581 ++++++ .../conv2conv/Conv2conv_6x6_NCHW_FCRS_12.expr | 830 ++++++++ .../conv2conv/Conv2conv_6x6_NCHW_FCRS_13.expr | 755 +++++++ .../conv2conv/Conv2conv_6x6_NCHW_FCRS_14.expr | 1392 +++++++++++++ .../conv2conv/Conv2conv_6x6_NCHW_FCRS_2.expr | 267 +++ .../conv2conv/Conv2conv_6x6_NCHW_FCRS_3.expr | 344 ++++ .../conv2conv/Conv2conv_6x6_NCHW_FCRS_4.expr | 419 ++++ .../conv2conv/Conv2conv_6x6_NCHW_FCRS_5.expr | 344 ++++ .../conv2conv/Conv2conv_6x6_NCHW_FCRS_6.expr | 419 ++++ .../conv2conv/Conv2conv_6x6_NCHW_FCRS_7.expr | 344 ++++ .../conv2conv/Conv2conv_6x6_NCHW_FCRS_8.expr | 344 ++++ .../conv2conv/Conv2conv_6x6_NCHW_FCRS_9.expr | 581 ++++++ .../conv2conv/Conv2conv_9x9_NCHW_FCRS_0.expr | 180 ++ .../conv2conv/Conv2conv_9x9_NCHW_FCRS_1.expr | 224 ++ .../conv2conv/Conv2conv_9x9_NCHW_FCRS_10.expr | 637 ++++++ .../conv2conv/Conv2conv_9x9_NCHW_FCRS_11.expr | 582 ++++++ .../conv2conv/Conv2conv_9x9_NCHW_FCRS_12.expr | 906 ++++++++ .../conv2conv/Conv2conv_9x9_NCHW_FCRS_13.expr | 851 ++++++++ .../conv2conv/Conv2conv_9x9_NCHW_FCRS_14.expr | 1488 ++++++++++++++ .../conv2conv/Conv2conv_9x9_NCHW_FCRS_2.expr | 268 +++ .../conv2conv/Conv2conv_9x9_NCHW_FCRS_3.expr | 345 ++++ .../conv2conv/Conv2conv_9x9_NCHW_FCRS_4.expr | 420 ++++ .../conv2conv/Conv2conv_9x9_NCHW_FCRS_5.expr | 345 ++++ .../conv2conv/Conv2conv_9x9_NCHW_FCRS_6.expr | 420 ++++ .../conv2conv/Conv2conv_9x9_NCHW_FCRS_7.expr | 345 ++++ .../conv2conv/Conv2conv_9x9_NCHW_FCRS_8.expr | 345 ++++ .../conv2conv/Conv2conv_9x9_NCHW_FCRS_9.expr | 582 ++++++ .../log/conv2gemm/Conv2gemm_NCHW_RSFC_0.expr | 179 ++ .../log/conv2gemm/Conv2gemm_NCHW_RSFC_1.expr | 256 +++ .../log/conv2gemm/Conv2gemm_NCHW_RSFC_10.expr | 898 ++++++++ .../log/conv2gemm/Conv2gemm_NCHW_RSFC_11.expr | 1573 ++++++++++++++ .../log/conv2gemm/Conv2gemm_NCHW_RSFC_2.expr | 331 +++ .../log/conv2gemm/Conv2gemm_NCHW_RSFC_3.expr | 256 +++ .../log/conv2gemm/Conv2gemm_NCHW_RSFC_4.expr | 331 +++ .../log/conv2gemm/Conv2gemm_NCHW_RSFC_5.expr | 256 +++ .../log/conv2gemm/Conv2gemm_NCHW_RSFC_6.expr | 256 +++ .../log/conv2gemm/Conv2gemm_NCHW_RSFC_7.expr | 467 +++++ .../log/conv2gemm/Conv2gemm_NCHW_RSFC_8.expr | 648 ++++++ .../log/conv2gemm/Conv2gemm_NCHW_RSFC_9.expr | 593 ++++++ .../Conv2gemm_1x7_NCHW_FCRS_0.expr | 170 ++ .../Conv2gemm_1x7_NCHW_FCRS_1.expr | 247 +++ .../Conv2gemm_1x7_NCHW_FCRS_10.expr | 889 ++++++++ .../Conv2gemm_1x7_NCHW_FCRS_11.expr | 1564 ++++++++++++++ .../Conv2gemm_1x7_NCHW_FCRS_2.expr | 322 +++ .../Conv2gemm_1x7_NCHW_FCRS_3.expr | 247 +++ .../Conv2gemm_1x7_NCHW_FCRS_4.expr | 322 +++ .../Conv2gemm_1x7_NCHW_FCRS_5.expr | 247 +++ .../Conv2gemm_1x7_NCHW_FCRS_6.expr | 247 +++ .../Conv2gemm_1x7_NCHW_FCRS_7.expr | 458 +++++ .../Conv2gemm_1x7_NCHW_FCRS_8.expr | 639 ++++++ .../Conv2gemm_1x7_NCHW_FCRS_9.expr | 584 ++++++ test/nnet/log/g2bmm/G2BMM_0.expr | 143 ++ test/nnet/log/g2bmm/G2BMM_1.expr | 258 +++ test/nnet/log/g2bmm/G2BMM_2.expr | 292 +++ test/nnet/log/g2bmm/G2BMM_3.expr | 306 +++ test/nnet/log/g2bmm/G2BMM_4.expr | 331 +++ test/nnet/log/g2bmm/G2BMM_5.expr | 276 +++ test/nnet/log/g2bmm/G2BMM_6.expr | 321 +++ test/nnet/log/g2bmm/G2BMM_7.expr | 276 +++ test/nnet/log/g2bmm/G2BMM_8.expr | 404 ++++ test/nnet/log/g2bmm/G2BMM_9.expr | 746 +++++++ test/nnet/log/gbmm/GBMM_0.expr | 143 ++ test/nnet/log/gbmm/GBMM_1.expr | 258 +++ test/nnet/log/gbmm/GBMM_2.expr | 292 +++ test/nnet/log/gbmm/GBMM_3.expr | 306 +++ test/nnet/log/gbmm/GBMM_4.expr | 331 +++ test/nnet/log/gbmm/GBMM_5.expr | 276 +++ test/nnet/log/gbmm/GBMM_6.expr | 321 +++ test/nnet/log/gbmm/GBMM_7.expr | 276 +++ test/nnet/log/gbmm/GBMM_8.expr | 404 ++++ test/nnet/log/gbmm/GBMM_9.expr | 746 +++++++ test/nnet/readlog.cc | 17 + test/nnet/test_Interpreter.cc | 171 ++ test/nnet/test_OpSearch.cc | 240 +++ test/nnet/test_Rule2VariableMerging.cc | 131 ++ test/nnet/test_TConv2gemm.cc | 390 ++++ test/nnet/test_as_tvm.cc | 40 + test/nnet/test_compareFormulas.cc | 45 + test/nnet/test_conv2conv.cc | 126 ++ test/nnet/test_conv2gemm.cc | 282 +++ test/nnet/test_dlt.cc | 78 + test/nnet/test_g2bmm.cc | 83 + test/nnet/test_guidedDLT.cc | 410 ++++ test/nnet/test_hash.cc | 65 + test/nnet/test_matchConv.cc | 144 ++ test/nnet/test_matchElementWise.cc | 102 + test/nnet/test_matchMatmul.cc | 338 +++ test/nnet/test_matchReshape.cc | 150 ++ test/nnet/test_mergeStage.cc | 61 + test/nnet/test_oobChecker.cc | 54 + test/nnet/test_rangeMagnify.cc | 45 + test/nnet/test_relaxation.cc | 91 + test/nnet/test_serializer.cc | 100 + test/nnet/test_simplify.cc | 221 ++ test/nnet/test_subset.cc | 19 + 239 files changed, 69388 insertions(+), 13 deletions(-) create mode 100644 include/nnet/Pass/MatchComputationKernel.h create mode 100644 include/nnet/Pass/MatchMemBoundKernel.h create mode 100644 include/nnet/Pass/Pass.h create mode 100644 include/nnet/Pass/Rule1VariableSplit.h create mode 100644 include/nnet/Pass/Rule2VariableMerging.h create mode 100644 include/nnet/Pass/Rule3StageSplit.h create mode 100644 include/nnet/Pass/Rule4StageMerging.h create mode 100644 include/nnet/Pass/Rule5RangeRelaxation.h create mode 100644 include/nnet/Pass/Rule6KenerlMatching.h create mode 100644 include/nnet/Pass/Rule7DLT.h create mode 100644 include/nnet/Pass/Rule8GuidedDLT.h create mode 100644 include/nnet/Pass/Rule90TwoStageElementWise.h create mode 100644 include/nnet/Pass/Rule91MergeStagesWithSum.h create mode 100644 include/nnet/Pass/Rule9RangeMagnify.h create mode 100644 include/nnet/ReplaceKit.h create mode 100644 include/nnet/Visitor/AsTVMVisitor.h create mode 100644 include/nnet/Visitor/CheckOOBVisitor.h create mode 100644 include/nnet/Visitor/CloneMutator.h create mode 100644 include/nnet/Visitor/CompareMultiFormulasVisitor.h create mode 100644 include/nnet/Visitor/CountRoutineVisitor.h create mode 100644 include/nnet/Visitor/FullPrinterVisitor.h create mode 100644 include/nnet/Visitor/GetTensorsVisitor.h create mode 100644 include/nnet/Visitor/HashVisitor.h create mode 100644 include/nnet/Visitor/InputVisitor.h create mode 100644 include/nnet/Visitor/Interpreter.h create mode 100644 include/nnet/Visitor/MatchReshapeVisitor.h create mode 100644 include/nnet/Visitor/MatchTableVisitor.h create mode 100644 include/nnet/Visitor/MatmulTransposeMutator.h create mode 100644 include/nnet/Visitor/MergeMemboundMutator.h create mode 100644 include/nnet/Visitor/PatternMatcher.h create mode 100644 include/nnet/Visitor/RangeMagnifyVisitor.h create mode 100644 include/nnet/Visitor/RangeRelaxFunctor.h create mode 100644 include/nnet/Visitor/ReplaceNodeMutator.h create mode 100644 include/nnet/Visitor/ReplaceVariable.h create mode 100644 include/nnet/Visitor/Serializer.h create mode 100644 include/nnet/Visitor/SimplifyExprVisitor.h create mode 100644 include/nnet/Visitor/SimplifyFormulaMutator.h create mode 100644 include/nnet/Visitor/StrideVisitor.h create mode 100644 include/nnet/common.h create mode 100644 include/nnet/dbg.h create mode 100644 include/nnet/derivator.h create mode 100644 include/nnet/dlt.h create mode 100644 include/nnet/expr.h create mode 100644 include/nnet/iterator_table.h create mode 100644 include/nnet/nmutator.h create mode 100644 include/nnet/permutation.h create mode 100644 include/nnet/ref.h create mode 100644 include/nnet/routine.h create mode 100644 include/nnet/test.h create mode 100644 include/nnet/visitor.h create mode 100644 src/nnet/Pass/MatchComputationKernel.cc create mode 100644 src/nnet/Pass/MatchMemBoundKernel.cc create mode 100644 src/nnet/Pass/Pass.cc create mode 100644 src/nnet/Pass/Rule1VariableSplit.cc create mode 100644 src/nnet/Pass/Rule2VariableMerging.cc create mode 100644 src/nnet/Pass/Rule3StageSplit.cc create mode 100644 src/nnet/Pass/Rule4StageMerging.cc create mode 100644 src/nnet/Pass/Rule5RangeRelaxation.cc create mode 100644 src/nnet/Pass/Rule6KenerlMatching.cc create mode 100644 src/nnet/Pass/Rule7DLT.cc create mode 100644 src/nnet/Pass/Rule8GuidedDLT.cc create mode 100644 src/nnet/Pass/Rule90TwoStageElementWise.cc create mode 100644 src/nnet/Pass/Rule91MergeStagesWithSum.cc create mode 100644 src/nnet/Pass/Rule9RangeMagnify.cc create mode 100644 src/nnet/Visitor/AsTVMVisitor.cc create mode 100644 src/nnet/Visitor/CheckOOBVisitor.cc create mode 100644 src/nnet/Visitor/CloneMutator.cc create mode 100644 src/nnet/Visitor/CompareMultiFormulasVisitor.cc create mode 100644 src/nnet/Visitor/CountRoutineVisitor.cc create mode 100644 src/nnet/Visitor/FullPrinterVisitor.cc create mode 100644 src/nnet/Visitor/GetTensorsVisitor.cc create mode 100644 src/nnet/Visitor/InputVisitor.cc create mode 100644 src/nnet/Visitor/Interpreter.cc create mode 100644 src/nnet/Visitor/MatchReshapeVisitor.cc create mode 100644 src/nnet/Visitor/MatchTableVisitor.cc create mode 100644 src/nnet/Visitor/MatmulTransposeMutator.cc create mode 100644 src/nnet/Visitor/MergeMemboundMutator.cc create mode 100644 src/nnet/Visitor/PatternMatcher.cc create mode 100644 src/nnet/Visitor/RangeMagnifyVisitor.cc create mode 100644 src/nnet/Visitor/RangeRelaxFunctor.cc create mode 100644 src/nnet/Visitor/ReplaceNodeMutator.cc create mode 100644 src/nnet/Visitor/ReplaceVariable.cc create mode 100644 src/nnet/Visitor/Serializer.cc create mode 100644 src/nnet/Visitor/SimplifyExprVisitor.cc create mode 100644 src/nnet/Visitor/SimplifyFormulaMutator.cc create mode 100644 src/nnet/Visitor/StrideVisitor.cc create mode 100644 src/nnet/Visitor/hashVisitor.cc create mode 100644 src/nnet/derivator.cc create mode 100644 src/nnet/dlt.cc create mode 100644 src/nnet/expr.cc create mode 100644 src/nnet/iterator_table.cc create mode 100644 src/nnet/nmutator.cc create mode 100644 src/nnet/permutation.cc create mode 100644 src/nnet/replace_kit.cc create mode 100644 src/nnet/routine.cc create mode 100644 src/nnet/test.cc create mode 100644 src/nnet/utils.cc create mode 100644 src/nnet/visitor.cc create mode 100644 test/nnet/failed/test_activation.cc create mode 100644 test/nnet/failed/test_csrnet.cc create mode 100644 test/nnet/failed/test_longformer.cc create mode 100644 test/nnet/failed/test_multiple_op.cc create mode 100644 test/nnet/failed/test_mutator.cc create mode 100644 test/nnet/failed/test_operatorCompute.cc create mode 100644 test/nnet/failed/test_transposeOpToExpr.cc create mode 100644 test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_0.expr create mode 100644 test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_1.expr create mode 100644 test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_10.expr create mode 100644 test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_11.expr create mode 100644 test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_12.expr create mode 100644 test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_13.expr create mode 100644 test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_14.expr create mode 100644 test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_15.expr create mode 100644 test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_16.expr create mode 100644 test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_17.expr create mode 100644 test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_18.expr create mode 100644 test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_2.expr create mode 100644 test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_3.expr create mode 100644 test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_4.expr create mode 100644 test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_5.expr create mode 100644 test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_6.expr create mode 100644 test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_7.expr create mode 100644 test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_8.expr create mode 100644 test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_9.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_0.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_1.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_10.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_11.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_12.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_13.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_14.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_15.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_2.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_3.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_4.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_5.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_6.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_7.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_8.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_9.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_6x6_NCHW_FCRS_0.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_6x6_NCHW_FCRS_1.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_6x6_NCHW_FCRS_10.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_6x6_NCHW_FCRS_11.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_6x6_NCHW_FCRS_12.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_6x6_NCHW_FCRS_13.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_6x6_NCHW_FCRS_14.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_6x6_NCHW_FCRS_2.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_6x6_NCHW_FCRS_3.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_6x6_NCHW_FCRS_4.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_6x6_NCHW_FCRS_5.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_6x6_NCHW_FCRS_6.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_6x6_NCHW_FCRS_7.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_6x6_NCHW_FCRS_8.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_6x6_NCHW_FCRS_9.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_9x9_NCHW_FCRS_0.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_9x9_NCHW_FCRS_1.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_9x9_NCHW_FCRS_10.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_9x9_NCHW_FCRS_11.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_9x9_NCHW_FCRS_12.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_9x9_NCHW_FCRS_13.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_9x9_NCHW_FCRS_14.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_9x9_NCHW_FCRS_2.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_9x9_NCHW_FCRS_3.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_9x9_NCHW_FCRS_4.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_9x9_NCHW_FCRS_5.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_9x9_NCHW_FCRS_6.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_9x9_NCHW_FCRS_7.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_9x9_NCHW_FCRS_8.expr create mode 100644 test/nnet/log/conv2conv/Conv2conv_9x9_NCHW_FCRS_9.expr create mode 100644 test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_0.expr create mode 100644 test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_1.expr create mode 100644 test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_10.expr create mode 100644 test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_11.expr create mode 100644 test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_2.expr create mode 100644 test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_3.expr create mode 100644 test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_4.expr create mode 100644 test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_5.expr create mode 100644 test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_6.expr create mode 100644 test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_7.expr create mode 100644 test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_8.expr create mode 100644 test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_9.expr create mode 100644 test/nnet/log/conv2gemm_1x7/Conv2gemm_1x7_NCHW_FCRS_0.expr create mode 100644 test/nnet/log/conv2gemm_1x7/Conv2gemm_1x7_NCHW_FCRS_1.expr create mode 100644 test/nnet/log/conv2gemm_1x7/Conv2gemm_1x7_NCHW_FCRS_10.expr create mode 100644 test/nnet/log/conv2gemm_1x7/Conv2gemm_1x7_NCHW_FCRS_11.expr create mode 100644 test/nnet/log/conv2gemm_1x7/Conv2gemm_1x7_NCHW_FCRS_2.expr create mode 100644 test/nnet/log/conv2gemm_1x7/Conv2gemm_1x7_NCHW_FCRS_3.expr create mode 100644 test/nnet/log/conv2gemm_1x7/Conv2gemm_1x7_NCHW_FCRS_4.expr create mode 100644 test/nnet/log/conv2gemm_1x7/Conv2gemm_1x7_NCHW_FCRS_5.expr create mode 100644 test/nnet/log/conv2gemm_1x7/Conv2gemm_1x7_NCHW_FCRS_6.expr create mode 100644 test/nnet/log/conv2gemm_1x7/Conv2gemm_1x7_NCHW_FCRS_7.expr create mode 100644 test/nnet/log/conv2gemm_1x7/Conv2gemm_1x7_NCHW_FCRS_8.expr create mode 100644 test/nnet/log/conv2gemm_1x7/Conv2gemm_1x7_NCHW_FCRS_9.expr create mode 100644 test/nnet/log/g2bmm/G2BMM_0.expr create mode 100644 test/nnet/log/g2bmm/G2BMM_1.expr create mode 100644 test/nnet/log/g2bmm/G2BMM_2.expr create mode 100644 test/nnet/log/g2bmm/G2BMM_3.expr create mode 100644 test/nnet/log/g2bmm/G2BMM_4.expr create mode 100644 test/nnet/log/g2bmm/G2BMM_5.expr create mode 100644 test/nnet/log/g2bmm/G2BMM_6.expr create mode 100644 test/nnet/log/g2bmm/G2BMM_7.expr create mode 100644 test/nnet/log/g2bmm/G2BMM_8.expr create mode 100644 test/nnet/log/g2bmm/G2BMM_9.expr create mode 100644 test/nnet/log/gbmm/GBMM_0.expr create mode 100644 test/nnet/log/gbmm/GBMM_1.expr create mode 100644 test/nnet/log/gbmm/GBMM_2.expr create mode 100644 test/nnet/log/gbmm/GBMM_3.expr create mode 100644 test/nnet/log/gbmm/GBMM_4.expr create mode 100644 test/nnet/log/gbmm/GBMM_5.expr create mode 100644 test/nnet/log/gbmm/GBMM_6.expr create mode 100644 test/nnet/log/gbmm/GBMM_7.expr create mode 100644 test/nnet/log/gbmm/GBMM_8.expr create mode 100644 test/nnet/log/gbmm/GBMM_9.expr create mode 100644 test/nnet/readlog.cc create mode 100644 test/nnet/test_Interpreter.cc create mode 100644 test/nnet/test_OpSearch.cc create mode 100644 test/nnet/test_Rule2VariableMerging.cc create mode 100644 test/nnet/test_TConv2gemm.cc create mode 100644 test/nnet/test_as_tvm.cc create mode 100644 test/nnet/test_compareFormulas.cc create mode 100644 test/nnet/test_conv2conv.cc create mode 100644 test/nnet/test_conv2gemm.cc create mode 100644 test/nnet/test_dlt.cc create mode 100644 test/nnet/test_g2bmm.cc create mode 100644 test/nnet/test_guidedDLT.cc create mode 100644 test/nnet/test_hash.cc create mode 100644 test/nnet/test_matchConv.cc create mode 100644 test/nnet/test_matchElementWise.cc create mode 100644 test/nnet/test_matchMatmul.cc create mode 100644 test/nnet/test_matchReshape.cc create mode 100644 test/nnet/test_mergeStage.cc create mode 100644 test/nnet/test_oobChecker.cc create mode 100644 test/nnet/test_rangeMagnify.cc create mode 100644 test/nnet/test_relaxation.cc create mode 100644 test/nnet/test_serializer.cc create mode 100644 test/nnet/test_simplify.cc create mode 100644 test/nnet/test_subset.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 3fa588d0..57dec16a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,10 +4,10 @@ include(CMakeDependentOption) project(InfiniTensor C CXX) # Do not change these options in this file. Use cmake.config, cmake -DOPTION=VALUE, or ccmake to specify them. -option(BUILD_TEST "Build tests" OFF) +option(BUILD_TEST "Build tests" ON) cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF) -cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" ON BUILD_TEST OFF) -cmake_dependent_option(BUILD_TEST_EINNET "Build tests for EINNET" ON BUILD_TEST OFF) +cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" OFF BUILD_TEST OFF) +cmake_dependent_option(BUILD_TEST_EINNET "Build tests for EINNET" OFF BUILD_TEST OFF) set(DEFAULT_BUILD_TYPE "RelWithDebInfo") @@ -31,16 +31,24 @@ if(OpenMP_CXX_FOUND) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") endif() -# add_subdirectory(3rd-party/pybind11) -# add_subdirectory(3rd-party/nlohmann_json_cmake_fetchcontent) - include_directories(include) + +# # Pybind11 +# add_subdirectory(3rd-party/pybind11) # include_directories(3rd-party/pybind11/include) -# include_directories(3rd-party/nlohmann_json_cmake_fetchcontent/single_include) + +# nlohmann_json +add_subdirectory(3rd-party/nlohmann_json_cmake_fetchcontent) +include_directories(3rd-party/nlohmann_json_cmake_fetchcontent/single_include) if(BUILD_TEST) - set(BUILD_GMOCK off) - set(INSTALL_GTEST off) + # TODO: chekc set + set(BUILD_GMOCK + OFF + CACHE BOOL "Do not build gmock" FORCE) + set(INSTALL_GTEST + OFF + CACHE BOOL "Do not install gtest" FORCE) add_subdirectory(3rd-party/googletest) include_directories(3rd-party/googletest/googletest/include) endif() @@ -59,13 +67,26 @@ add_library(InfiniTensor SHARED ${SRC}) # pybind11_add_module(infini MODULE ${FFI}) # target_link_libraries(infini PRIVATE infini_cpp) -if(BUILD_TEST) - enable_testing() - file(GLOB_RECURSE TEST_SOURCES test/*.cc) +function(build_test files) + # Non-recursive glob for skip failed tests + file(GLOB TEST_SOURCES ${files}) foreach(testsourcefile ${TEST_SOURCES}) get_filename_component(testname ${testsourcefile} NAME_WE) add_executable(${testname} ${testsourcefile}) - target_link_libraries(${testname} InfiniTensor gtest_main gtest) + target_link_libraries(${testname} InfiniTensor GTest::gtest_main) add_test(NAME ${testname} COMMAND ${testname}) endforeach(testsourcefile ${TEST_SOURCES}) +endfunction() + +if(BUILD_TEST) + enable_testing() + if(BUILD_TEST_CORE) + build_test(test/core/*.cc) + endif() + if(BUILD_TEST_PET) + build_test(test/pet/*.cc) + endif() + if(BUILD_TEST_EINNET) + build_test(test/nnet/*.cc) + endif() endif() diff --git a/include/nnet/Pass/MatchComputationKernel.h b/include/nnet/Pass/MatchComputationKernel.h new file mode 100644 index 00000000..4b182a89 --- /dev/null +++ b/include/nnet/Pass/MatchComputationKernel.h @@ -0,0 +1,15 @@ +#pragma once +#include "nnet/Pass/Pass.h" + +namespace nnet { + +class MatchComputationKernel : public Pass { + public: + MatchComputationKernel(Derivator &derivator) + : Pass(derivator, "MatchComputationKernel") {} + + private: + virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override; +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/MatchMemBoundKernel.h b/include/nnet/Pass/MatchMemBoundKernel.h new file mode 100644 index 00000000..6b0a4bec --- /dev/null +++ b/include/nnet/Pass/MatchMemBoundKernel.h @@ -0,0 +1,15 @@ +#pragma once +#include "nnet/Pass/Pass.h" + +namespace nnet { + +class MatchMemBoundKernel : public Pass { + public: + MatchMemBoundKernel(Derivator &derivator) + : Pass(derivator, "MatchMemBoundKernel") {} + + private: + virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override; +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Pass.h b/include/nnet/Pass/Pass.h new file mode 100644 index 00000000..a8d517fa --- /dev/null +++ b/include/nnet/Pass/Pass.h @@ -0,0 +1,41 @@ +#pragma once +#include "nnet/derivator.h" + +namespace nnet { + +class Pass { + private: + VecExpr transformations; + + protected: + Derivator &derivator; + string passName; + /** + * @brief // False if does not add log in Derivator. It should be false for + * single Pass test to avoid mismatch of passInfos and passMsgs due to + * different number of "run" and "nextStep". + */ + bool enableLogging, enableDebug; + + virtual void transform(Formula &origin, int depth, Expr &rCur) = 0; + void nextStep(Formula &origin, int depth, Expr &rCur, Expr newCur, + const string &ruleInfo = ""); + + Var getNewVar(); + string newTensorName(); + + private: + void initialize(Formula &origin, const Expr &rCur); + void finalize(); + + public: + Pass(Derivator &derivator, const string &passName); + virtual ~Pass(); + + void run(Formula &origin, int dfsDepth, Expr &rCur); + void setEnableLogging(bool value); + void setEnableDebug(bool value); + const VecExpr &getTransformations(); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Rule1VariableSplit.h b/include/nnet/Pass/Rule1VariableSplit.h new file mode 100644 index 00000000..06f2bb4e --- /dev/null +++ b/include/nnet/Pass/Rule1VariableSplit.h @@ -0,0 +1,18 @@ +#pragma once +#include "nnet/Pass/Pass.h" +#include "nnet/ReplaceKit.h" + +namespace nnet { + +class Rule1VariableSplit : public Pass { + public: + Rule1VariableSplit(Derivator &derivator) + : Pass(derivator, "Rule1VariableSplit") {} + + private: + virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override; + vector getSplitableVar(const RangeOp &rangeOp); + Expr replaceIters(Expr cur, const Replace &replace); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Rule2VariableMerging.h b/include/nnet/Pass/Rule2VariableMerging.h new file mode 100644 index 00000000..1f277f46 --- /dev/null +++ b/include/nnet/Pass/Rule2VariableMerging.h @@ -0,0 +1,29 @@ +#pragma once +#include "nnet/Pass/Pass.h" +#include "nnet/ReplaceKit.h" + +namespace nnet { + +class Rule2VariableMerging : public Pass { + private: + map> substituteRules; + + public: + Rule2VariableMerging(Derivator &derivator) + : Pass(derivator, "Rule2VariableMerging") {} + + private: + virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override; + + vector getMergableReplaces(RangeOp rangeOp, int depth); + optional getReplaceMergingTwoLoopIters(const RangeOp &rangeOp, + pair pairA, + pair pairB, + const IteratorTable &exprIT, + int tensorID); + optional getReplaceMappingTwoLoopIters(const RangeOp &rangeOp, + pair pa, + pair pb); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Rule3StageSplit.h b/include/nnet/Pass/Rule3StageSplit.h new file mode 100644 index 00000000..99e172cf --- /dev/null +++ b/include/nnet/Pass/Rule3StageSplit.h @@ -0,0 +1,19 @@ +#pragma once +#include "nnet/Pass/Pass.h" + +namespace nnet { + +class Rule3StageSplit : public Pass { + private: + map> substituteRules; + + public: + Rule3StageSplit(Derivator &derivator) + : Pass(derivator, "Rule3StageSplit") {} + + private: + virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override; + vector> getSplitSummationIters(RangeOp rangeOp); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Rule4StageMerging.h b/include/nnet/Pass/Rule4StageMerging.h new file mode 100644 index 00000000..13f11074 --- /dev/null +++ b/include/nnet/Pass/Rule4StageMerging.h @@ -0,0 +1,22 @@ +#pragma once +#include "nnet/Pass/Pass.h" + +namespace nnet { + +class Rule4StageMerging : public Pass { + bool success, mergeStageWithCalc; + + public: + Rule4StageMerging(Derivator &derivator) + : Pass(derivator, "Rule4StageMerging"), success(false), + mergeStageWithCalc(false) {} + bool rule4StageMerging(Formula &origin, int depth, Expr &rCur, + bool mergeStageWithCalc = false); + bool isSuccessful(); + void setMergeStageWithCalc(bool value); + + private: + virtual void transform(Formula &origin, int depth, Expr &rCur) override; +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Rule5RangeRelaxation.h b/include/nnet/Pass/Rule5RangeRelaxation.h new file mode 100644 index 00000000..c1719dfc --- /dev/null +++ b/include/nnet/Pass/Rule5RangeRelaxation.h @@ -0,0 +1,16 @@ +#pragma once +#include "nnet/Pass/Pass.h" + +namespace nnet { + +class Rule5RangeRelaxation : public Pass { + public: + Rule5RangeRelaxation(Derivator &derivator) + : Pass(derivator, "Rule5RangeRelaxation") {} + Expr rule5RangeRelaxation(Formula &origin, int depth, Expr &rCur); + + private: + virtual void transform(Formula &origin, int depth, Expr &rCur) override; +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Rule6KenerlMatching.h b/include/nnet/Pass/Rule6KenerlMatching.h new file mode 100644 index 00000000..19648eaf --- /dev/null +++ b/include/nnet/Pass/Rule6KenerlMatching.h @@ -0,0 +1,17 @@ +#pragma once +#include "nnet/Pass/Pass.h" + +namespace nnet { + +class Rule6KenerlMatching : public Pass { + public: + Rule6KenerlMatching(Derivator &derivator) + : Pass(derivator, "Rule6KenerlMatching") {} + + private: + virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override; + // RE: seperating this func is a choice. + VecExpr matchElementWise(const RangeOp &rangeOp); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Rule7DLT.h b/include/nnet/Pass/Rule7DLT.h new file mode 100644 index 00000000..d2ce0ecc --- /dev/null +++ b/include/nnet/Pass/Rule7DLT.h @@ -0,0 +1,16 @@ +#pragma once +#include "nnet/Pass/Pass.h" + +namespace nnet { + +class Rule7DLT : public Pass { + public: + Rule7DLT(Derivator &derivator) : Pass(derivator, "Rule7DLT") {} + + private: + virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override; + Expr buildDLTSingleRangeOp(const RangeOp &original, const Expr &newSummand); + vector getFactors(); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Rule8GuidedDLT.h b/include/nnet/Pass/Rule8GuidedDLT.h new file mode 100644 index 00000000..e6536df4 --- /dev/null +++ b/include/nnet/Pass/Rule8GuidedDLT.h @@ -0,0 +1,48 @@ +#pragma once +#include "nnet/Pass/Pass.h" +#include "nnet/ReplaceKit.h" + +namespace nnet { + +class Rule8GuidedDLT : public Pass { + public: + Rule8GuidedDLT(Derivator &derivator) : Pass(derivator, "Rule8GuidedDLT") {} + VecExpr guidedDLT(Formula &origin, int depth, Expr &rCur, + bool debug = false); + + private: + virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override; + /** + * @brief If only one row miss match (more iterators mismatch), directly do + * data layout construction according to the IT. + * + * @return Expr Return nullptr if failed. + */ + Expr guidedDLTMoreVar2(const RangeOp &cur, const Mismatch &mismatch, + const IteratorTable &exprIT, const Pattern &pattern); + /** + * @brief Check whether two iterators overlap each other. If overlapping, we + * cannot simply reconstruct the tensor into a new one by seperate all + * iterators into different dimensions. + */ + bool checkElementsHaveOnlyOneAccessIteratorSet(const IteratorTable &exprIT, + int tensorID); + /** + * @brief Only product of two tensors can be guided DLTed. + * + * @param cur + * @return true + * @return false + */ + bool statisfyGuidedDLT(RangeOp cur) const; + /** + * @brief Deal with output DLT mismatch only. + */ + Expr guidedDLTDLMismatch(const RangeOp &cur, const Mismatch &mismatch, + const IteratorTable &exprIT, + const Pattern &pattern); + Expr buildGuidedDLTSource(const Subscript &originalSub, Replace replace, + vector tensorDimAxes, vector newShape); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Rule90TwoStageElementWise.h b/include/nnet/Pass/Rule90TwoStageElementWise.h new file mode 100644 index 00000000..ab37cf99 --- /dev/null +++ b/include/nnet/Pass/Rule90TwoStageElementWise.h @@ -0,0 +1,16 @@ +#pragma once +#include "nnet/Pass/Pass.h" + +namespace nnet { + +class Rule90TwoStageElementWise : public Pass { + public: + Rule90TwoStageElementWise(Derivator &derivator) + : Pass(derivator, "Rule90TwoStageElementWise") {} + + private: + virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override; + VecExpr matchTwoStageElementWise(const RangeOp &rangeOp); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Rule91MergeStagesWithSum.h b/include/nnet/Pass/Rule91MergeStagesWithSum.h new file mode 100644 index 00000000..221e1772 --- /dev/null +++ b/include/nnet/Pass/Rule91MergeStagesWithSum.h @@ -0,0 +1,15 @@ +#pragma once +#include "nnet/Pass/Pass.h" + +namespace nnet { + +class Rule91MergeStagesWithSum : public Pass { + public: + Rule91MergeStagesWithSum(Derivator &derivator) + : Pass(derivator, "Rule91MergeStagesWithSum") {} + + private: + virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override; +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Rule9RangeMagnify.h b/include/nnet/Pass/Rule9RangeMagnify.h new file mode 100644 index 00000000..613de158 --- /dev/null +++ b/include/nnet/Pass/Rule9RangeMagnify.h @@ -0,0 +1,15 @@ +#pragma once +#include "nnet/Pass/Pass.h" + +namespace nnet { + +class Rule9RangeMagnify : public Pass { + public: + Rule9RangeMagnify(Derivator &derivator) + : Pass(derivator, "Rule9RangeMagnify") {} + + private: + virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override; +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/ReplaceKit.h b/include/nnet/ReplaceKit.h new file mode 100644 index 00000000..19f2ec0e --- /dev/null +++ b/include/nnet/ReplaceKit.h @@ -0,0 +1,46 @@ +#pragma once +#include "nnet/expr.h" + +namespace nnet { + +struct Replace { + int iteratorType; + vector oldIters; // i_1, ... + vector newIters; // j_1, ... + VecExpr phis; // j_1=\phi_1(i_1, ...), not necessary for Sum iter + VecExpr psis; // i_1=\psi_1(j_1, ...) + vector newVarRanges; + + bool isReplaced(Var var) const { + for (const auto &iter : oldIters) + if (iter->equal(var)) + return true; + return false; + } + + string toReadable() const { + string ret = "Old iters: " + serializeVec(oldIters) + + ", new iters: " + serializeVec(newIters); + ret += " phis: " + serializeVec(phis) + " psis: " + serializeVec(psis); + return ret; + } +}; + +class ReplaceKit { + public: + static RangeOp replaceRangeOpIterator(const RangeOp &rangeOp, + const Replace &replace, + const Expr &replacedSummand); + static Subscript buildSubscirptForLoopVarReplace(const RangeOp &inner, + const Replace &replace); + static RangeOp buildDLTOuterRangeOp(const RangeOp &original, + const Subscript &subscriptedNewRangeOp); + static Expr replaceMultipleExprs(const Expr &cur, + const vector &patterns, + const VecExpr &replacements, + bool simplify = true); + static Expr replaceExpr(const Expr &cur, const Expr &pattern, + const Expr &replacement); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/AsTVMVisitor.h b/include/nnet/Visitor/AsTVMVisitor.h new file mode 100644 index 00000000..f193231a --- /dev/null +++ b/include/nnet/Visitor/AsTVMVisitor.h @@ -0,0 +1,38 @@ +#pragma once +#include "nnet/Visitor/StrideVisitor.h" +#include "nnet/visitor.h" + +namespace nnet { + +class AsTVMVisitor : public Functor { + private: + int nStage = 0, curStage = -1; + std::unordered_map offset; + std::vector inputs; + std::string output; + std::vector pythonVars; + std::vector> inputShapes; + std::vector outputShape; + std::string stmts; + + public: + std::string getStmts() const; + + const std::vector &getInputs() const { return inputs; } + const std::string &getOutput() const { return output; } + + const std::vector> &getInputShapes() const { + return inputShapes; + } + const std::vector &getOutputShape() const { return outputShape; } + + std::string visit_(const Constant &c) override; + std::string visit_(const BinaryOp &c) override; + std::string visit_(const Func &c) override; + std::string visit_(const RangeOp &c) override; + std::string visit_(const Subscript &c) override; + std::string visit_(const Var &c) override; + std::string visit_(const Tensor &c) override; +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/CheckOOBVisitor.h b/include/nnet/Visitor/CheckOOBVisitor.h new file mode 100644 index 00000000..017a7ae9 --- /dev/null +++ b/include/nnet/Visitor/CheckOOBVisitor.h @@ -0,0 +1,22 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +class CheckOOBVisitor : public ExprTreeVisitor { + RangeOp rangeOp; + bool detect = false; + + public: + CheckOOBVisitor(int _verobse = 0) : ExprTreeVisitor(1, 1, 0, 0, _verobse) {} + void visit_(const Subscript &c) override; + + /** + * @brief + * @return true If there is OOB + * @return false If there is no OOB + */ + bool checkRangeOp(const RangeOp &_rangeOp); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/CloneMutator.h b/include/nnet/Visitor/CloneMutator.h new file mode 100644 index 00000000..6812cb80 --- /dev/null +++ b/include/nnet/Visitor/CloneMutator.h @@ -0,0 +1,16 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +// Clone ExprNodes in a stage except Tensor, Var, and Constant nodes. +class CloneMutator : public Mutator { + public: + CloneMutator() : Mutator(false) {} + Expr visit_(const Constant &c) override; + Expr visit_(const Var &c) override; + Expr visit_(const Tensor &c) override; + Expr clone(const Expr &c) { return dispatch(c); } +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/CompareMultiFormulasVisitor.h b/include/nnet/Visitor/CompareMultiFormulasVisitor.h new file mode 100644 index 00000000..a120d48e --- /dev/null +++ b/include/nnet/Visitor/CompareMultiFormulasVisitor.h @@ -0,0 +1,15 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +class CompareMultiFormulasVisitor : public ExprTreeVisitor { + vector newSumVarRanges; + RangeOp newRangeOp; + + public: + CompareMultiFormulasVisitor() : ExprTreeVisitor() {} + bool compare(const VecExpr &roots); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/CountRoutineVisitor.h b/include/nnet/Visitor/CountRoutineVisitor.h new file mode 100644 index 00000000..d9536495 --- /dev/null +++ b/include/nnet/Visitor/CountRoutineVisitor.h @@ -0,0 +1,18 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +class CountRoutineVisitor : public ExprTreeVisitor { + private: + vector cnts; + + public: + CountRoutineVisitor(int _verobse = 0) + : ExprTreeVisitor(1, 1, 1, 1, _verobse) {} + void visit_(const Tensor &c) override; + vector count(const Expr &root); + bool match(const Expr &root, int nMatmul = 0, int nConv = 0, + int nElement = 0, int nSg2bmm = 0, int nLongformerGBMM = 0); +}; +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/FullPrinterVisitor.h b/include/nnet/Visitor/FullPrinterVisitor.h new file mode 100644 index 00000000..a9caeccf --- /dev/null +++ b/include/nnet/Visitor/FullPrinterVisitor.h @@ -0,0 +1,25 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +class FullPrinterVisitor : public ExprTreeVisitor { + private: + vector> q; + + public: + FullPrinterVisitor(int _verobse = 0) + : ExprTreeVisitor(1, 1, 1, 0, _verobse) {} + void visit_(const Tensor &c) override; + + string print(const Expr &root); + /** + * @brief Get all tensors & OPs in a reversed order + * + * @param root + * @return vector<> + */ + const vector> &traverse(const Expr &root); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/GetTensorsVisitor.h b/include/nnet/Visitor/GetTensorsVisitor.h new file mode 100644 index 00000000..66a8d5dc --- /dev/null +++ b/include/nnet/Visitor/GetTensorsVisitor.h @@ -0,0 +1,22 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +// Get all tensors in the stage +class GetTensorsVisitor : public ExprTreeVisitor { + private: + unordered_map tensors; + + void visit_(const Tensor &c) override; + + public: + GetTensorsVisitor(int _verobse = 0) + : ExprTreeVisitor(1, 1, 1, 0, _verobse) {} + auto get(const Expr &c) { + dispatch(c); + return tensors; + } +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/HashVisitor.h b/include/nnet/Visitor/HashVisitor.h new file mode 100644 index 00000000..0d20f49d --- /dev/null +++ b/include/nnet/Visitor/HashVisitor.h @@ -0,0 +1,31 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +// Calculate hash for a normal form, starting at a RangeOp +class HashVisitor : public Functor { + inline const static HashType BKDR_SEED[] = {131, 313, 10007, 65599}; + + PtrUmap varHash; + int nLoopVars = 0; + PtrUmap name2id; + vector rootId; + vector haveAlias; + int nVars = 0; + vector power; + + private: + HashType visit_(const Constant &c) override; + HashType visit_(const BinaryOp &c) override; + HashType visit_(const RangeOp &c) override; + HashType visit_(const Subscript &c) override; + HashType visit_(const Tensor &c) override; + HashType visit_(const Var &c) override; + + public: + HashVisitor(int _verobse = 0) : Functor(_verobse) {} + HashType getHash(const Expr &c); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/InputVisitor.h b/include/nnet/Visitor/InputVisitor.h new file mode 100644 index 00000000..1bc3c89a --- /dev/null +++ b/include/nnet/Visitor/InputVisitor.h @@ -0,0 +1,23 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +class InputVisitor : public ExprTreeVisitor { + vector inputs; + + public: + int nInputs = 0; + InputVisitor(int _verobse = 0) : ExprTreeVisitor(1, 1, 1, 0, _verobse) {} + void visit_(const Tensor &c) override; + + /** + * @brief Get the all inputs in the netsed stages + */ + vector getInputs(const RangeOp &_rangeOp) { + dispatch(_rangeOp); + return inputs; + } +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/Interpreter.h b/include/nnet/Visitor/Interpreter.h new file mode 100644 index 00000000..b4479b79 --- /dev/null +++ b/include/nnet/Visitor/Interpreter.h @@ -0,0 +1,55 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +class Interpreter : public Functor { + public: + using ttype = int; // Test data type + using rtype = int; // Return data type + using Position = vector; + using Inputs = unordered_map>>; + using Iteration = PtrUmap; + + private: + // cache the input value + Inputs inputs; + vector iterations; + vector positions; + + rtype visit_(const Constant &c) override; + rtype visit_(const BinaryOp &c) override; + rtype visit_(const RangeOp &c) override; + rtype visit_(const Subscript &c) override; + rtype visit_(const Var &c) override; + rtype visit_(const Tensor &c) override; + // int visit_(const Func &c); // Future work + + static Inputs genInputStartingFromZero(const RangeOp &range); + + public: + Interpreter(Inputs _inputs, int _verbose = 0) + : Functor(_verbose), inputs(_inputs) {} + Interpreter(RangeOp range, int _verbose = 0); + + /** + * @brief Calculate the output at specified poistions + * + * @param expr The expression to be calculated. + * @param poses Positions of output. + * @return vector Value of output. + */ + vector interpret(const Expr &expr, const vector &poses); + /** + * @brief Calculate the output at equally spaced positions + * + * @param expr The expression to be calculated. + * @param nPoses The number of calculated output positions. + * @return vector Value of output. + */ + vector interpretUniformSample(const RangeOp &range, + int nPoses = 100); + vector interpretAllOutput(const RangeOp &range); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/MatchReshapeVisitor.h b/include/nnet/Visitor/MatchReshapeVisitor.h new file mode 100644 index 00000000..caf0130e --- /dev/null +++ b/include/nnet/Visitor/MatchReshapeVisitor.h @@ -0,0 +1,14 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +class MatchReshapeVisitor : public Functor { + private: + PtrMap _coefficient; + + public: + bool visit_(const RangeOp &c) override; +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/MatchTableVisitor.h b/include/nnet/Visitor/MatchTableVisitor.h new file mode 100644 index 00000000..53836883 --- /dev/null +++ b/include/nnet/Visitor/MatchTableVisitor.h @@ -0,0 +1,60 @@ +#pragma once +#include "nnet/Visitor/StrideVisitor.h" +#include "nnet/visitor.h" + +namespace nnet { + +class MatchTableVisitor + : public Functor stride)> { + private: + // Var -> {(tensor, dim)} + Appearance appearance; + vector tensors; + vector subscripts; + StrideTable strideTable; + PtrMap>> + strideInDim; // [Iterator][tensorID][dim]=stride + + // Intermediate variable + // product of a sub-exprtree: Stride has to be done in two DFS + SubexprSride subexprStride; + bool hasUnsupportedOp = false; + + public: + MatchTableVisitor(int _verobse = 0) : Functor(_verobse) {} + void visit_(const BinaryOp &c, const Tensor &tensor, int dim, + optional stride) override; + void visit_(const Subscript &c, const Tensor &tensor, int dim, + optional stride) override; + void visit_(const Var &c, const Tensor &tensor, int dim, + optional stride) override; + void visit_(const Constant &c, const Tensor &tensor, int dim, + optional stride) override; + // void visit_(const Tensor &c, const Tensor &tensor) override; + + [[nodiscard]] bool operator()(const RangeOp &e) { + hasUnsupportedOp = false; + // get the location and stride of each iterator + auto mulOp = as(e->getSummand()); + // TODO [feature]: support complex index exprs + if (!mulOp || mulOp->getOpType() != OpType::Mul) { + nnet_unimplemented_continue(); + return false; + } + StrideVisitor strideVisitor(0); + subexprStride = strideVisitor.getFormulaStride(e); + dispatch(mulOp->getLhs(), nullptr, 0, 0); + dispatch(mulOp->getRhs(), nullptr, 0, 0); + subscripts.emplace_back(as(mulOp->getLhs())); + subscripts.emplace_back(as(mulOp->getRhs())); + assert(tensors.size() == subscripts.size()); + assert(tensors.size() < 5); + return !hasUnsupportedOp; + } + + auto getResult() const { + return tuple(appearance, tensors, strideTable, subscripts); + } +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/MatmulTransposeMutator.h b/include/nnet/Visitor/MatmulTransposeMutator.h new file mode 100644 index 00000000..000467d2 --- /dev/null +++ b/include/nnet/Visitor/MatmulTransposeMutator.h @@ -0,0 +1,18 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +class MatmulTransposeMutator : public Mutator { + Derivator &derivator; + + public: + MatmulTransposeMutator(Derivator &derivator) + : Mutator(1), derivator(derivator) {} + VecExpr transpose(const Tensor &tensor); + + private: + Tensor transposeInput(const Tensor &tensor); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/MergeMemboundMutator.h b/include/nnet/Visitor/MergeMemboundMutator.h new file mode 100644 index 00000000..786103e1 --- /dev/null +++ b/include/nnet/Visitor/MergeMemboundMutator.h @@ -0,0 +1,20 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +class MergeMemboundMutator : public Mutator { + VecExpr kernels; + int curDepth; // from the last one to the first one + Expr visit_(const Tensor &c) override; + // FIXME: duplicate code + Expr rule4StageMerging(Expr &rCur, bool mergeStageWithCalc); + bool checkEmpty(); + + public: + MergeMemboundMutator(const VecExpr &kernels) + : Mutator(), kernels(kernels), curDepth(kernels.size() - 1) {} + Expr merge(bool allowEmptyMembound = false); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/PatternMatcher.h b/include/nnet/Visitor/PatternMatcher.h new file mode 100644 index 00000000..6765a4e0 --- /dev/null +++ b/include/nnet/Visitor/PatternMatcher.h @@ -0,0 +1,43 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +/** + * @brief Since the output positions of operators always start from 0, we have + * to offset them if the the boundary expression of is not 0. + */ +class PatternMatcher : public Functor { + private: + Derivator &derivator; + bool hasNonZeroRange; + const RangeOp originalCur; + + public: + PatternMatcher(Derivator &derivator, const RangeOp &cur); + /** + * @brief Get the Cur whose loop vars are all offset to [0, x). Since + * operator outputs start from 0, RangeOp has to be aligned. + */ + RangeOp getOffsetCur(); + /** + * @brief Add outer RangeOp to map the original positions to the new + * positions staring from 0. + * + * @param exprs Tensors from matched exprs + */ + VecExpr applyWrapper(const VecExpr &exprs); + + VecExpr matchWithPattern(const RangeOp &rangeOp, const Pattern &pattern); + + private: + VecExpr matchKernel(const Pattern &pattern, const RangeOp &rangeOp, + IteratorTable &exprIT); + // get reverse tensor and iterator map ([pattern tensor/iter ID] -> + // real) + Expr matchKernelWithTensorMap(const Pattern &pattern, + const RangeOp &rangeOp, + IteratorTable &exprIT); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/RangeMagnifyVisitor.h b/include/nnet/Visitor/RangeMagnifyVisitor.h new file mode 100644 index 00000000..f6fa03c1 --- /dev/null +++ b/include/nnet/Visitor/RangeMagnifyVisitor.h @@ -0,0 +1,25 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +class RangeMagnifyVisitor : public Mutator { + vector newSumVarRanges; + RangeOp newRangeOp; + + public: + RangeMagnifyVisitor() : Mutator(0) {} + Expr visit_(const RangeOp &c) override; + Expr visit_(const Subscript &c) override; + /** + * @brief + * + * @param root + * @param _newSumVarRanges + * @return RangeOp nullptr if failed to magnify + */ + RangeOp magnify(const RangeOp &root, + const vector &_newSumVarRanges); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/RangeRelaxFunctor.h b/include/nnet/Visitor/RangeRelaxFunctor.h new file mode 100644 index 00000000..2dd80c2f --- /dev/null +++ b/include/nnet/Visitor/RangeRelaxFunctor.h @@ -0,0 +1,18 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +using RangeMap = PtrMap; +class RangeRelaxFunctor : public Functor { + RangeOp rangeOp; + + public: + RangeRelaxFunctor(RangeOp _rangeOp) : Functor(false), rangeOp(_rangeOp) {} + RangeMap visit_(const BinaryOp &c) override; + RangeMap visit_(const RangeOp &c) override; + RangeMap visit_(const Subscript &c) override; + RangeMap intersectRangeMaps(const RangeMap &a, const RangeMap &b); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/ReplaceNodeMutator.h b/include/nnet/Visitor/ReplaceNodeMutator.h new file mode 100644 index 00000000..bdb7fb63 --- /dev/null +++ b/include/nnet/Visitor/ReplaceNodeMutator.h @@ -0,0 +1,20 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +// Replace node according to its address the summand +// Only subscript and tensor are supported now. +class ReplaceNodeMutator : public Mutator { + int nSubscripts = 0; + ExprNode *target; + Expr replacement; + + public: + ReplaceNodeMutator() : Mutator(0) {} + Expr visit_(const Subscript &c) override; + Expr visit_(const Tensor &c) override; + Expr replace(const Expr &root, ExprNode *_target, const Expr &_replace); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/ReplaceVariable.h b/include/nnet/Visitor/ReplaceVariable.h new file mode 100644 index 00000000..20efbc4f --- /dev/null +++ b/include/nnet/Visitor/ReplaceVariable.h @@ -0,0 +1,33 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +class ReplaceVariable : public Mutator { + VecExpr patterns, replacements; + map patternHash; + + public: + ReplaceVariable(Expr _pattern, Expr _replacement) : Mutator(false) { + set({_pattern}, {_replacement}); + } + ReplaceVariable(const map> &mapping) + : Mutator(false) { + VecExpr _patterns, _replacements; + for (const auto &[_, v] : mapping) { + _patterns.emplace_back(v.first); + _replacements.emplace_back(v.second); + } + set(_patterns, _replacements); + } + Expr visit_(const BinaryOp &c) override; + // NOT recur to the next stage + Expr visit_(const RangeOp &c) override; + Expr visit_(const Var &c) override; + + private: + void set(VecExpr _pattern, VecExpr _replacement); + Expr match(const Expr &c); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/Serializer.h b/include/nnet/Visitor/Serializer.h new file mode 100644 index 00000000..3568b152 --- /dev/null +++ b/include/nnet/Visitor/Serializer.h @@ -0,0 +1,52 @@ +#pragma once +#include "nlohmann/json_fwd.hpp" +#include "nnet/visitor.h" +#include + +namespace nnet { + +class Serializer : public Functor { + using json = nlohmann::ordered_json; + + private: + static constexpr int VERSION{1}; + std::unique_ptr jPtr; + json &j; + static int id; + + string visit_(const Constant &c) override; + string visit_(const BinaryOp &c) override; + string visit_(const RangeOp &c) override; + string visit_(const Subscript &c) override; + string visit_(const Var &c) override; + string visit_(const Tensor &c) override; + string dispatchRoutine(const Routine &c); + + Expr buildExprTree(string key); + Routine buildRoutine(string key); + + public: + Serializer(int _verobse = 0); + virtual ~Serializer(); + + /** + * @brief Serialize the given expression to json file + * + * @param expr The expression to be serialized + * @param filePath The path of json file to be output + * @param msg Message of derivation + * @return bool Whether the serialization succeed + */ + bool serialize(const Expr &expr, const string &filePath, + const string &msg = ""); + + /** + * @brief Deserialize the given json file to expression + * + * @param filePath The path to file to be deserialized + * @return Expression deserialized from the given json file + */ + Expr deserialize(const string &filePath); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/SimplifyExprVisitor.h b/include/nnet/Visitor/SimplifyExprVisitor.h new file mode 100644 index 00000000..e8ed459c --- /dev/null +++ b/include/nnet/Visitor/SimplifyExprVisitor.h @@ -0,0 +1,38 @@ +#pragma once +#include "nnet/Visitor/StrideVisitor.h" +#include "nnet/visitor.h" + +namespace nnet { + +// Simplify a index expression tree +class SimplifyExprVisitor : public Functor stride)> { + private: + SubexprSride subexprStride; + int constant; + PtrMap strides; // [var]=strides + + map, int, RefValueLess>> divStrides, + modStrides; // 3*(i%8): []=3 + + // For divde and modulo with expr as dividend: 3*((i+1)%8): []=3 + map, int, RefAddrLess>> divExprStrides, + modExprStrides; + + public: + SimplifyExprVisitor() : Functor(0) {} + void visit_(const BinaryOp &c, optional stride) override; + void visit_(const Var &c, optional stride) override; + void visit_(const Constant &c, optional stride) override; + PtrMap getStrides(const Expr &expr); + // TODO [refactor]: move this to SimplifyFormulaMutator as a member func + // this class should be get coefficients in a expr + Expr simplify(const Expr &expr); + int getConstant(const Expr &expr); + pair, int> getStridesConstant(const Expr &expr); + optional getExprRange(const Expr &expr, const RangeOp &rangeOp); + PtrMap getStrides() { return strides; } + const auto &getDivStrides() { return divStrides; } + const auto &getModStrides() { return modStrides; } +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/SimplifyFormulaMutator.h b/include/nnet/Visitor/SimplifyFormulaMutator.h new file mode 100644 index 00000000..48f92870 --- /dev/null +++ b/include/nnet/Visitor/SimplifyFormulaMutator.h @@ -0,0 +1,18 @@ +#pragma once +#include "nnet/Visitor/StrideVisitor.h" +#include "nnet/visitor.h" + +namespace nnet { + +// Simplify all indexes in subscripts in an expression tree +class SimplifyFormulaMutator : public Mutator { + int nSubscripts = 0; + + public: + SimplifyFormulaMutator() : Mutator(0) {} + Expr visit_(const Subscript &c) override; + // Expr visit_(const BinaryOp &c) override; + Expr simplify(const Expr &expr); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/StrideVisitor.h b/include/nnet/Visitor/StrideVisitor.h new file mode 100644 index 00000000..dca17569 --- /dev/null +++ b/include/nnet/Visitor/StrideVisitor.h @@ -0,0 +1,38 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +using SubexprSride = map>; +class StrideVisitor : public Functor(void)> { + private: + SubexprSride subexprStride; + + public: + StrideVisitor(int _verobse = 0) : Functor(_verobse) {} + optional visit_(const BinaryOp &c) override; + optional visit_(const Subscript &c) override; + optional visit_(const Var &c) override; + optional visit_(const Constant &c) override; + // void visit_(const Tensor &c, const Tensor &tensor) override; + + auto getFormulaStride(const RangeOp &e) { + subexprStride.clear(); + // get the location and stride of each iterator + auto mulOp = as(e->getSummand()); + // TODO [feature]: support complex index exprs + if (!mulOp || mulOp->getOpType() != OpType::Mul) + nnet_unimplemented_continue(); + dispatch(mulOp->getLhs()); + dispatch(mulOp->getRhs()); + return subexprStride; + } + + [[nodiscard]] auto getExprStride(const Expr &e) { + subexprStride.clear(); + dispatch(e); + return subexprStride; + } +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/common.h b/include/nnet/common.h new file mode 100644 index 00000000..567c6a27 --- /dev/null +++ b/include/nnet/common.h @@ -0,0 +1,77 @@ +#pragma once +#include "dbg.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nnet { +using std::dynamic_pointer_cast; +using std::endl; +using std::list; +using std::make_pair; +using std::make_shared; +using std::make_tuple; +using std::map; +using std::max; +using std::min; +using std::nullopt; +using std::optional; +using std::pair; +using std::set; +using std::shared_ptr; +using std::string; +using std::tie; +using std::to_string; +using std::tuple; +using std::unique_ptr; +using std::unordered_map; +template using uset = std::unordered_set; +using std::vector; +using std::weak_ptr; + +// Aliases +using dtype = float; +using HashType = int; + +template struct ptr_less { + bool operator()(const T &lhs, const T &rhs) const { return *lhs < *rhs; } +}; + +template struct ptr_hash { + size_t operator()(const T &lhs) const { + return std::hash()(*lhs); + } +}; + +template struct ptr_equal { + bool operator()(const T &lhs, const T &rhs) const { return *lhs == *rhs; } +}; + +static inline HashType genhash(HashType a, HashType b) { + return (a * 10007 + b + 12345) % 1000000007; +} + +static inline HashType genhash(string s) { + HashType ret = 0; + for (auto c : s) + ret = genhash(ret, c); + return ret; +} + +#define nnet_unimplemented_halt() \ + { assert(!"Unimplemented"); } + +#define nnet_unimplemented_continue() \ + { dbg("Unimplemented"); } + +#define nnet_assert(expr, msg) assert(((void)(msg), (expr))) + +std::string pointer_to_hex(void *i); +} // namespace nnet diff --git a/include/nnet/dbg.h b/include/nnet/dbg.h new file mode 100644 index 00000000..f5deb180 --- /dev/null +++ b/include/nnet/dbg.h @@ -0,0 +1,893 @@ +/***************************************************************************** + + dbg(...) macro + +License (MIT): + + Copyright (c) 2019 David Peter + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to + deal in the Software without restriction, including without limitation the + rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + sell copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + +*****************************************************************************/ + +#ifndef DBG_MACRO_DBG_H +#define DBG_MACRO_DBG_H + +#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) +#define DBG_MACRO_UNIX +#elif defined(_MSC_VER) +#define DBG_MACRO_WINDOWS +#endif + +// #ifndef DBG_MACRO_NO_WARNING +// #pragma message("WARNING: the 'dbg.h' header is included in your code base") +// #endif // DBG_MACRO_NO_WARNING + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef DBG_MACRO_UNIX +#include +#endif + +#if __cplusplus >= 201703L +#define DBG_MACRO_CXX_STANDARD 17 +#elif __cplusplus >= 201402L +#define DBG_MACRO_CXX_STANDARD 14 +#else +#define DBG_MACRO_CXX_STANDARD 11 +#endif + +#if DBG_MACRO_CXX_STANDARD >= 17 +#include +#include +#endif + +namespace dbg { + +#ifdef DBG_MACRO_UNIX +inline bool isColorizedOutputEnabled() { + return isatty(fileno(stderr)); +} +#else +inline bool isColorizedOutputEnabled() { + return true; +} +#endif + +struct time {}; + +namespace pretty_function { + +// Compiler-agnostic version of __PRETTY_FUNCTION__ and constants to +// extract the template argument in `type_name_impl` + +#if defined(__clang__) +#define DBG_MACRO_PRETTY_FUNCTION __PRETTY_FUNCTION__ +static constexpr size_t PREFIX_LENGTH = + sizeof("const char *dbg::type_name_impl() [T = ") - 1; +static constexpr size_t SUFFIX_LENGTH = sizeof("]") - 1; +#elif defined(__GNUC__) && !defined(__clang__) +#define DBG_MACRO_PRETTY_FUNCTION __PRETTY_FUNCTION__ +static constexpr size_t PREFIX_LENGTH = + sizeof("const char* dbg::type_name_impl() [with T = ") - 1; +static constexpr size_t SUFFIX_LENGTH = sizeof("]") - 1; +#elif defined(_MSC_VER) +#define DBG_MACRO_PRETTY_FUNCTION __FUNCSIG__ +static constexpr size_t PREFIX_LENGTH = + sizeof("const char *__cdecl dbg::type_name_impl<") - 1; +static constexpr size_t SUFFIX_LENGTH = sizeof(">(void)") - 1; +#else +#error "This compiler is currently not supported by dbg_macro." +#endif + +} // namespace pretty_function + +// Formatting helpers + +template +struct print_formatted { + static_assert(std::is_integral::value, + "Only integral types are supported."); + + print_formatted(T value, int numeric_base) + : inner(value), base(numeric_base) {} + + operator T() const { return inner; } + + const char* prefix() const { + switch (base) { + case 8: + return "0o"; + case 16: + return "0x"; + case 2: + return "0b"; + default: + return ""; + } + } + + T inner; + int base; +}; + +template +print_formatted hex(T value) { + return print_formatted{value, 16}; +} + +template +print_formatted oct(T value) { + return print_formatted{value, 8}; +} + +template +print_formatted bin(T value) { + return print_formatted{value, 2}; +} + +// Implementation of 'type_name()' + +template +const char* type_name_impl() { + return DBG_MACRO_PRETTY_FUNCTION; +} + +template +struct type_tag {}; + +template +std::string get_type_name(type_tag) { + namespace pf = pretty_function; + + std::string type = type_name_impl(); + return type.substr(pf::PREFIX_LENGTH, + type.size() - pf::PREFIX_LENGTH - pf::SUFFIX_LENGTH); +} + +template +std::string type_name() { + if (std::is_volatile::value) { + if (std::is_pointer::value) { + return type_name::type>() + " volatile"; + } else { + return "volatile " + type_name::type>(); + } + } + if (std::is_const::value) { + if (std::is_pointer::value) { + return type_name::type>() + " const"; + } else { + return "const " + type_name::type>(); + } + } + if (std::is_pointer::value) { + return type_name::type>() + "*"; + } + if (std::is_lvalue_reference::value) { + return type_name::type>() + "&"; + } + if (std::is_rvalue_reference::value) { + return type_name::type>() + "&&"; + } + return get_type_name(type_tag{}); +} + +inline std::string get_type_name(type_tag) { + return "short"; +} + +inline std::string get_type_name(type_tag) { + return "unsigned short"; +} + +inline std::string get_type_name(type_tag) { + return "long"; +} + +inline std::string get_type_name(type_tag) { + return "unsigned long"; +} + +inline std::string get_type_name(type_tag) { + return "std::string"; +} + +template +std::string get_type_name(type_tag>>) { + return "std::vector<" + type_name() + ">"; +} + +template +std::string get_type_name(type_tag>) { + return "std::pair<" + type_name() + ", " + type_name() + ">"; +} + +template +std::string type_list_to_string() { + std::string result; + auto unused = {(result += type_name() + ", ", 0)..., 0}; + static_cast(unused); + +#if DBG_MACRO_CXX_STANDARD >= 17 + if constexpr (sizeof...(T) > 0) { +#else + if (sizeof...(T) > 0) { +#endif + result.pop_back(); + result.pop_back(); + } + return result; +} + +template +std::string get_type_name(type_tag>) { + return "std::tuple<" + type_list_to_string() + ">"; +} + +template +inline std::string get_type_name(type_tag>) { + return type_name(); +} + +// Implementation of 'is_detected' to specialize for container-like types + +namespace detail_detector { + +struct nonesuch { + nonesuch() = delete; + ~nonesuch() = delete; + nonesuch(nonesuch const&) = delete; + void operator=(nonesuch const&) = delete; +}; + +template +using void_t = void; + +template + class Op, + class... Args> +struct detector { + using value_t = std::false_type; + using type = Default; +}; + +template class Op, class... Args> +struct detector>, Op, Args...> { + using value_t = std::true_type; + using type = Op; +}; + +} // namespace detail_detector + +template