diff --git a/CMakeLists.txt b/CMakeLists.txt index d942fcd4..b08d9f85 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -21,7 +21,6 @@ project(InfiniTensor C CXX) cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF) cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" OFF BUILD_TEST OFF) -cmake_dependent_option(BUILD_TEST_EINNET "Build tests for EINNET" OFF BUILD_TEST OFF) set(DEFAULT_BUILD_TYPE "RelWithDebInfo") # Build Type @@ -95,16 +94,17 @@ add_subdirectory(3rd-party/nlohmann_json_cmake_fetchcontent) include_directories(3rd-party/nlohmann_json_cmake_fetchcontent/single_include) # TVM backend -if(BUILD_TEST_EINNET) - if (NOT TVM_INCLUDE_DIR OR NOT DMLC_INCLUDE_DIR OR NOT DLPACK_INCLUDE_DIR OR NOT DLPACK_INCLUDE_DIR) - message(FATAL_ERROR "TVM_INCLUDE_DIR, DMLC_INCLUDE_DIR, and DLPACK_INCLUDE_DIR must be set when BUILD_TEST_EINNET is ON") - endif() +if(BUILD_NNET AND BUILD_TEST) # TVM and DMLC for invoking TVM packed functions include_directories(${TVM_INCLUDE_DIR}) include_directories(${DMLC_INCLUDE_DIR}) include_directories(${DLPACK_INCLUDE_DIR}) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDMLC_USE_LOGGING_LIBRARY=\\\<${TVM_INCLUDE_DIR}/tvm/runtime/logging.h\\\> ") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DINFINI_USE_TVM=1") # Enable TVM codegen kernels + if (TVM_INCLUDE_DIR AND DMLC_INCLUDE_DIR AND DLPACK_INCLUDE_DIR AND DLPACK_INCLUDE_DIR) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDMLC_USE_LOGGING_LIBRARY=\\\<${TVM_INCLUDE_DIR}/tvm/runtime/logging.h\\\> ") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DINFINI_USE_TVM=1") # Enable TVM codegen kernels + else() + # message(FATAL_ERROR "TVM_INCLUDE_DIR, DMLC_INCLUDE_DIR, and DLPACK_INCLUDE_DIR must be set when BUILD_NNET AND BUILD_TEST is ON") + endif() endif() if(BUILD_TEST) @@ -130,6 +130,8 @@ if(BUILD_NNET) add_compile_definitions(BUILD_NNET=1) file(GLOB_RECURSE SRC_NNET src/nnet/*.cc) list (APPEND SRC ${SRC_NNET}) + # For locating resource files + set_source_files_properties(src/nnet/test.cc PROPERTIES COMPILE_OPTIONS "-DINFINI_PROJECT_HOME=${CMAKE_CURRENT_SOURCE_DIR}") endif() if(USE_CUDA) @@ -161,7 +163,7 @@ endif() target_link_libraries(InfiniTensor pybind11::embed) # TVM backend -if(BUILD_TEST_EINNET) +if(BUILD_NNET AND BUILD_TEST AND TVM_LIB_DIR) target_link_libraries(InfiniTensor ${TVM_LIB_DIR}/libtvm.so) endif() @@ -333,7 +335,7 @@ if(BUILD_TEST) if(BUILD_TEST_PET) build_test(test/pet/*.cc) endif() - if(BUILD_TEST_EINNET) + if(BUILD_NNET AND BUILD_TEST) build_test(test/nnet/test_*.cc) # Build expression reader diff --git a/Makefile b/Makefile index 19f1b353..302f47b8 100644 --- a/Makefile +++ b/Makefile @@ -7,12 +7,13 @@ KUNLUN ?= OFF INTELCPU ?= off BACKTRACE ?= ON TEST ?= ON +NNET ?= OFF FORMAT_ORIGIN ?= # Docker build options DOCKER_NAME ?= infinitensor DOCKER_IMAGE_NAME ?= infinitensor DOCKER_FILE ?= infinitensor_ubuntu_22.04.dockerfile -DOCKER_RUN_OPTION ?= +DOCKER_RUN_OPTION ?= # CUDA option. ifeq ($(CUDA), ON) @@ -22,13 +23,13 @@ ifeq ($(CUDA), ON) DOCKER_RUN_OPTION += --gpus all -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 -v `pwd`:`pwd` -w `pwd` endif - CMAKE_OPT = -DCMAKE_BUILD_TYPE=$(TYPE) CMAKE_OPT += -DUSE_CUDA=$(CUDA) CMAKE_OPT += -DUSE_BANG=$(BANG) CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN) CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE) CMAKE_OPT += -DBUILD_TEST=$(TEST) +CMAKE_OPT += -DBUILD_NNET=$(NNET) ifeq ($(INTELCPU), ON) CMAKE_OPT += -DUSE_INTELCPU=ON -DCMAKE_CXX_COMPILER=dpcpp @@ -60,7 +61,7 @@ test-api: @echo python3 pyinfinitensor/tests/test_api.py -docker-build: +docker-build: docker build -f scripts/dockerfile/$(DOCKER_FILE) -t $(DOCKER_NAME) . docker-run: @@ -71,5 +72,3 @@ docker-start: docker-exec: docker exec -it $(DOCKER_IMAGE_NAME) bash - - diff --git a/include/nnet/test.h b/include/nnet/test.h index 6cf75873..1d24bbd4 100644 --- a/include/nnet/test.h +++ b/include/nnet/test.h @@ -24,7 +24,7 @@ // clang-format on namespace nnet { -int matchExprResult(Derivator &derivator, string fn); -bool checkExprLogSame(string fnPrefix, int start, int end); +int matchExprResult(Derivator &derivator, string pathRelativeToProjectHome); +bool checkExprLogSame(string pathRelativeToProjectHome, int start, int end); bool checkExprsEquvivalence(VecExpr exprs); } // namespace nnet diff --git a/src/nnet/test.cc b/src/nnet/test.cc index 6c9738f2..fbf41e20 100644 --- a/src/nnet/test.cc +++ b/src/nnet/test.cc @@ -3,10 +3,30 @@ #include "nnet/Visitor/HashVisitor.h" #include "nnet/Visitor/Interpreter.h" #include "nnet/Visitor/Serializer.h" +#include namespace nnet { -int matchExprResult(Derivator &derivator, string fn) { - auto ans = Serializer().deserialize(fn); +std::filesystem::path getProjectHome() { +#ifndef INFINI_PROJECT_HOME +#error INFINI_PROJECT_HOME is not defined +#endif + +#define Q(x) #x +#define QUOTE(x) Q(x) +#define PROJECT_HOME QUOTE(INFINI_PROJECT_HOME) + return std::filesystem::path(PROJECT_HOME); +#undef PROJECT_HOME +#undef QUOTE +#undef Q +} + +string getResourceFilePath(string path) { + return (getProjectHome() / path).string(); +} + +int matchExprResult(Derivator &derivator, string pathRelativeToProjectHome) { + auto fullPath = getResourceFilePath(pathRelativeToProjectHome); + auto ans = Serializer().deserialize(fullPath); auto hashAns = HashVisitor()(ans); int match = 0; for (const auto &candidate : derivator.getCandidates()) { @@ -16,16 +36,17 @@ int matchExprResult(Derivator &derivator, string fn) { return match; } -bool checkExprLogSame(string fnPrefix, int start, int end) { +bool checkExprLogSame(string pathRelativeToProjectHome, int start, int end) { + auto fullPath = getResourceFilePath(pathRelativeToProjectHome); Serializer serializer; - string fn0 = fnPrefix + to_string(start) + ".expr"; + string fn0 = fullPath + to_string(start) + ".expr"; Expr expr0 = serializer.deserialize(fn0); RangeOp range0 = as(expr0); Interpreter interpreter(range0); auto ans0 = interpreter.interpretUniformSample(range0); dbg(expr0, ans0); for (int i = start + 1; i < end; ++i) { - string fn1 = fnPrefix + to_string(i) + ".expr"; + string fn1 = fullPath + to_string(i) + ".expr"; Expr expr1 = serializer.deserialize(fn1); RangeOp range1 = as(expr1); dbg(fn1, expr1); @@ -67,4 +88,4 @@ bool checkExprsEquvivalence(VecExpr exprs) { return true; } -} // namespace nnet \ No newline at end of file +} // namespace nnet diff --git a/test/nnet/test_OpSearch.cc b/test/nnet/test_OpSearch.cc index f969ccf7..14a6c737 100644 --- a/test/nnet/test_OpSearch.cc +++ b/test/nnet/test_OpSearch.cc @@ -85,7 +85,7 @@ TEST_F(OpSearch, Conv2gemm_NCHW_FCRS_search) { EXPECT_GE(derivator.getNumCandidates(), 1); int nMatches = matchExprResult( - derivator, "../test/nnet/log/conv2gemm/Conv2gemm_NCHW_FCRS_11.expr"); + derivator, "test/nnet/log/conv2gemm/Conv2gemm_NCHW_FCRS_11.expr"); EXPECT_GE(nMatches, 1); // derivator.print(); derivator.printStatistics(); @@ -160,7 +160,7 @@ TEST_F(OpSearch, TConv2gemm_TConv4x4_NHWF_RSFC_search) { EXPECT_GE(derivator.getNumCandidates(), 1); int nMatches = matchExprResult( derivator, - "../test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_18.expr"); + "test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_18.expr"); EXPECT_GE(nMatches, 1); derivator.printStatistics(); } @@ -197,8 +197,7 @@ TEST_F(OpSearch, Conv2conv_5x5_RuleBased_NCHW_FCRS) { derivator.search(conv_9x9, 0); int nMatches = matchExprResult( - derivator, - "../test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_15.expr"); + derivator, "test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_15.expr"); // derivator.print(); derivator.printStatistics(); EXPECT_GE(nMatches, 1); @@ -236,8 +235,8 @@ TEST_F(OpSearch, G2BMM_RuleBased) { EXPECT_GE(derivator.getNumCandidates(), 1); int nMatches = - matchExprResult(derivator, "../test/nnet/log/g2bmm/G2BMM_9.expr"); + matchExprResult(derivator, "test/nnet/log/g2bmm/G2BMM_9.expr"); EXPECT_GE(nMatches, 1); // derivator.print(); derivator.printStatistics(); -} \ No newline at end of file +} diff --git a/test/nnet/test_TConv2gemm.cc b/test/nnet/test_TConv2gemm.cc index 73f32088..3b84b843 100644 --- a/test/nnet/test_TConv2gemm.cc +++ b/test/nnet/test_TConv2gemm.cc @@ -365,14 +365,14 @@ TEST(TConv2gemm, TConv4x4_NHWF_RSFC_search) { // } int nMatches = matchExprResult( derivator, - "../test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_18.expr"); + "test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_18.expr"); EXPECT_GE(nMatches, 1); derivator.printStatistics(); } TEST(TConv2gemm, TConv4x4_NHWF_FRSC_CheckDerivationCorrectness_log) { const string fnPrefix = - "../test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_"; + "test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_"; EXPECT_TRUE(checkExprLogSame(fnPrefix, 0, 11)); } @@ -388,4 +388,4 @@ TEST(Conv2conv, InfoGAN_ConvTranspose_3_OOB_Test) { dbg(expr); Derivator derivator; derivator.checkOOB(as(expr)); -} \ No newline at end of file +} diff --git a/test/nnet/test_conv2conv.cc b/test/nnet/test_conv2conv.cc index 8e961e95..a77ab39a 100644 --- a/test/nnet/test_conv2conv.cc +++ b/test/nnet/test_conv2conv.cc @@ -37,10 +37,8 @@ TEST(Conv2conv, 9x9_NCHW_FCRS) { derivator.ruleBasedDFS(conv_9x9, 0, rules, {}, true); } else derivator.search(conv_9x9, 0); - int nMatches = matchExprResult( - derivator, - "../test/nnet/log/conv2conv/Conv2conv_9x9_NCHW_FCRS_14.expr"); + derivator, "test/nnet/log/conv2conv/Conv2conv_9x9_NCHW_FCRS_14.expr"); derivator.print(); derivator.printStatistics(); EXPECT_GE(nMatches, 1); @@ -81,8 +79,7 @@ TEST(Conv2conv, 6x6_RuleBased_NCHW_FCRS) { ASSERT_GE(derivator.getNumCandidates(), 1); int nMatches = matchExprResult( - derivator, - "../test/nnet/log/conv2conv/Conv2conv_6x6_NCHW_FCRS_14.expr"); + derivator, "test/nnet/log/conv2conv/Conv2conv_6x6_NCHW_FCRS_14.expr"); derivator.print(); derivator.printStatistics(); EXPECT_GE(nMatches, 1); @@ -121,9 +118,8 @@ TEST(Conv2conv, 5x5_RuleBased_NCHW_FCRS) { derivator.search(conv_9x9, 0); int nMatches = matchExprResult( - derivator, - "../test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_15.expr"); + derivator, "test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_15.expr"); derivator.print(); derivator.printStatistics(); EXPECT_GE(nMatches, 1); -} \ No newline at end of file +} diff --git a/test/nnet/test_conv2gemm.cc b/test/nnet/test_conv2gemm.cc index 9827497b..3f15d328 100644 --- a/test/nnet/test_conv2gemm.cc +++ b/test/nnet/test_conv2gemm.cc @@ -147,7 +147,7 @@ TEST(Conv2gemm, timing_NHWC_RSFC_search) { // Conv2gemm requires thorough update, this is disabled temporarily TEST(Conv2gemm, CheckCorrectness) { - const string fnPrefix = "../test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_"; + const string fnPrefix = "test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_"; // conv2gemm_7 has T3 EXPECT_TRUE(checkExprLogSame(fnPrefix, 0, 7)); } @@ -182,7 +182,7 @@ TEST(Conv2gemm, NCHW_RSFC_search) { ASSERT_GE(derivator.getNumCandidates(), 1); int nMatches = matchExprResult( - derivator, "../test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_11.expr"); + derivator, "test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_11.expr"); EXPECT_GE(nMatches, 1); // derivator.print(); derivator.printStatistics(); @@ -278,6 +278,6 @@ TEST(Conv2gemm1x7, NCHW_FCRS_search) { ASSERT_GE(derivator.getNumCandidates(), 1); int nMatches = matchExprResult( derivator, - "../test/nnet/log/conv2gemm_1x7/Conv2gemm_1x7_NCHW_FCRS_11.expr"); + "test/nnet/log/conv2gemm_1x7/Conv2gemm_1x7_NCHW_FCRS_11.expr"); EXPECT_GE(nMatches, 1); -} \ No newline at end of file +} diff --git a/test/nnet/test_g2bmm.cc b/test/nnet/test_g2bmm.cc index e285fb98..6ae1150e 100644 --- a/test/nnet/test_g2bmm.cc +++ b/test/nnet/test_g2bmm.cc @@ -36,10 +36,8 @@ TEST(GBMM, RuleBased) { } else { derivator.search(dialted_g2bmm, 0); } - ASSERT_GE(derivator.getNumCandidates(), 1); - int nMatches = - matchExprResult(derivator, "../test/nnet/log/gbmm/GBMM_9.expr"); + int nMatches = matchExprResult(derivator, "test/nnet/log/gbmm/GBMM_9.expr"); EXPECT_GE(nMatches, 1); derivator.print(); derivator.printStatistics(); @@ -78,7 +76,7 @@ TEST(G2BMM, RuleBased) { ASSERT_GE(derivator.getNumCandidates(), 1); int nMatches = - matchExprResult(derivator, "../test/nnet/log/g2bmm/G2BMM_9.expr"); + matchExprResult(derivator, "test/nnet/log/g2bmm/G2BMM_9.expr"); EXPECT_GE(nMatches, 1); derivator.print(); derivator.printStatistics(); diff --git a/test/nnet/test_memboundOp.cc b/test/nnet/test_memboundOp.cc index 910344f2..49716161 100644 --- a/test/nnet/test_memboundOp.cc +++ b/test/nnet/test_memboundOp.cc @@ -1,3 +1,5 @@ +#ifdef USE_CUDA + #include "core/graph.h" #include "core/runtime.h" #include "cuda/cuda_runtime.h" @@ -12,7 +14,22 @@ using namespace infini; using namespace std; -TEST(nnet, MemboundOpInterpretation) { +class NNetMemboundOp : public ::testing::Test { + protected: + void SetUp() override { + if (!hasTVMBackend()) + GTEST_SKIP() << "Skipping test since no TVM backend.\n"; + } + + private: + static bool hasTVMBackend() { + // TODO: as the dispatch mechanism of backend is going to change, this + // function is to be implemented. + return false; + } +}; + +TEST_F(NNetMemboundOp, MemboundOpInterpretation) { Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph g = make_ref(runtime); Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32); @@ -41,7 +58,7 @@ TEST(nnet, MemboundOpInterpretation) { EXPECT_TRUE(membound->getOutput()->equalData(ans)); } -TEST(nnet, MemboundOp_Ansor_Codegen) { +TEST_F(NNetMemboundOp, MemboundOp_Ansor_Codegen) { auto runtime = make_ref(); Runtime cpu = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(cpu); @@ -91,7 +108,7 @@ pair, nnet::Expr> getPReluExpr(int size) { return {{A, B}, ret}; } -TEST(nnet, PRelu_Ansor_Codegen) { +TEST_F(NNetMemboundOp, PRelu_Ansor_Codegen) { auto cuda = make_ref(); Runtime cpu = NativeCpuRuntimeObj::getInstance(); Graph g = make_ref(cuda); @@ -116,3 +133,5 @@ TEST(nnet, PRelu_Ansor_Codegen) { auto oCpu = gCpu->cloneTensor(o0); EXPECT_TRUE(oCpu->equalData(ans)); } + +#endif diff --git a/test/nnet/test_mutator.cc b/test/nnet/test_mutator.cc index cf4d8ab2..abc3f604 100644 --- a/test/nnet/test_mutator.cc +++ b/test/nnet/test_mutator.cc @@ -1,3 +1,5 @@ +#ifdef USE_CUDA + #include "core/blob.h" #include "core/dummy_mutator.h" #include "core/graph.h" @@ -477,3 +479,5 @@ TEST(Mutator, InfoGAN_TConv_3_correctness) { // EXPECT_TRUE(graph->verification(bestGraph.get(), true)); // } } // namespace infini + +#endif