forked from jiuyuan/InfiniTensor
test: 支持编译 einnet 单元测试,但不是所有测试都能通过 (#174)
* test: 支持编译 einnet 单元测试,但不是所有测试都能通过 Signed-off-by: YdrMaster <ydrml@hotmail.com> * Fix: locating resource files and skip codegen - Change the path parameters in `matchExprResult` and `checkExprLogSame` to paths relative to the project home - Skip NNetMemboundOp tests as they require codegen --------- Signed-off-by: YdrMaster <ydrml@hotmail.com> Co-authored-by: Haojie Wang <haojie0429@gmail.com> Co-authored-by: Liyan Zheng <liyan-zheng@outlook.com>
This commit is contained in:
parent
ec3adf6fa7
commit
1a6fccccbe
|
@ -21,7 +21,6 @@ project(InfiniTensor C CXX)
|
||||||
|
|
||||||
cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF)
|
cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF)
|
||||||
cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" OFF BUILD_TEST OFF)
|
cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" OFF BUILD_TEST OFF)
|
||||||
cmake_dependent_option(BUILD_TEST_EINNET "Build tests for EINNET" OFF BUILD_TEST OFF)
|
|
||||||
|
|
||||||
set(DEFAULT_BUILD_TYPE "RelWithDebInfo")
|
set(DEFAULT_BUILD_TYPE "RelWithDebInfo")
|
||||||
# Build Type
|
# Build Type
|
||||||
|
@ -95,16 +94,17 @@ add_subdirectory(3rd-party/nlohmann_json_cmake_fetchcontent)
|
||||||
include_directories(3rd-party/nlohmann_json_cmake_fetchcontent/single_include)
|
include_directories(3rd-party/nlohmann_json_cmake_fetchcontent/single_include)
|
||||||
|
|
||||||
# TVM backend
|
# TVM backend
|
||||||
if(BUILD_TEST_EINNET)
|
if(BUILD_NNET AND BUILD_TEST)
|
||||||
if (NOT TVM_INCLUDE_DIR OR NOT DMLC_INCLUDE_DIR OR NOT DLPACK_INCLUDE_DIR OR NOT DLPACK_INCLUDE_DIR)
|
|
||||||
message(FATAL_ERROR "TVM_INCLUDE_DIR, DMLC_INCLUDE_DIR, and DLPACK_INCLUDE_DIR must be set when BUILD_TEST_EINNET is ON")
|
|
||||||
endif()
|
|
||||||
# TVM and DMLC for invoking TVM packed functions
|
# TVM and DMLC for invoking TVM packed functions
|
||||||
include_directories(${TVM_INCLUDE_DIR})
|
include_directories(${TVM_INCLUDE_DIR})
|
||||||
include_directories(${DMLC_INCLUDE_DIR})
|
include_directories(${DMLC_INCLUDE_DIR})
|
||||||
include_directories(${DLPACK_INCLUDE_DIR})
|
include_directories(${DLPACK_INCLUDE_DIR})
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDMLC_USE_LOGGING_LIBRARY=\\\<${TVM_INCLUDE_DIR}/tvm/runtime/logging.h\\\> ")
|
if (TVM_INCLUDE_DIR AND DMLC_INCLUDE_DIR AND DLPACK_INCLUDE_DIR AND DLPACK_INCLUDE_DIR)
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DINFINI_USE_TVM=1") # Enable TVM codegen kernels
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDMLC_USE_LOGGING_LIBRARY=\\\<${TVM_INCLUDE_DIR}/tvm/runtime/logging.h\\\> ")
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DINFINI_USE_TVM=1") # Enable TVM codegen kernels
|
||||||
|
else()
|
||||||
|
# message(FATAL_ERROR "TVM_INCLUDE_DIR, DMLC_INCLUDE_DIR, and DLPACK_INCLUDE_DIR must be set when BUILD_NNET AND BUILD_TEST is ON")
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(BUILD_TEST)
|
if(BUILD_TEST)
|
||||||
|
@ -130,6 +130,8 @@ if(BUILD_NNET)
|
||||||
add_compile_definitions(BUILD_NNET=1)
|
add_compile_definitions(BUILD_NNET=1)
|
||||||
file(GLOB_RECURSE SRC_NNET src/nnet/*.cc)
|
file(GLOB_RECURSE SRC_NNET src/nnet/*.cc)
|
||||||
list (APPEND SRC ${SRC_NNET})
|
list (APPEND SRC ${SRC_NNET})
|
||||||
|
# For locating resource files
|
||||||
|
set_source_files_properties(src/nnet/test.cc PROPERTIES COMPILE_OPTIONS "-DINFINI_PROJECT_HOME=${CMAKE_CURRENT_SOURCE_DIR}")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(USE_CUDA)
|
if(USE_CUDA)
|
||||||
|
@ -161,7 +163,7 @@ endif()
|
||||||
target_link_libraries(InfiniTensor pybind11::embed)
|
target_link_libraries(InfiniTensor pybind11::embed)
|
||||||
|
|
||||||
# TVM backend
|
# TVM backend
|
||||||
if(BUILD_TEST_EINNET)
|
if(BUILD_NNET AND BUILD_TEST AND TVM_LIB_DIR)
|
||||||
target_link_libraries(InfiniTensor ${TVM_LIB_DIR}/libtvm.so)
|
target_link_libraries(InfiniTensor ${TVM_LIB_DIR}/libtvm.so)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
@ -333,7 +335,7 @@ if(BUILD_TEST)
|
||||||
if(BUILD_TEST_PET)
|
if(BUILD_TEST_PET)
|
||||||
build_test(test/pet/*.cc)
|
build_test(test/pet/*.cc)
|
||||||
endif()
|
endif()
|
||||||
if(BUILD_TEST_EINNET)
|
if(BUILD_NNET AND BUILD_TEST)
|
||||||
build_test(test/nnet/test_*.cc)
|
build_test(test/nnet/test_*.cc)
|
||||||
|
|
||||||
# Build expression reader
|
# Build expression reader
|
||||||
|
|
5
Makefile
5
Makefile
|
@ -7,6 +7,7 @@ KUNLUN ?= OFF
|
||||||
INTELCPU ?= off
|
INTELCPU ?= off
|
||||||
BACKTRACE ?= ON
|
BACKTRACE ?= ON
|
||||||
TEST ?= ON
|
TEST ?= ON
|
||||||
|
NNET ?= OFF
|
||||||
FORMAT_ORIGIN ?=
|
FORMAT_ORIGIN ?=
|
||||||
# Docker build options
|
# Docker build options
|
||||||
DOCKER_NAME ?= infinitensor
|
DOCKER_NAME ?= infinitensor
|
||||||
|
@ -22,13 +23,13 @@ ifeq ($(CUDA), ON)
|
||||||
DOCKER_RUN_OPTION += --gpus all -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 -v `pwd`:`pwd` -w `pwd`
|
DOCKER_RUN_OPTION += --gpus all -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 -v `pwd`:`pwd` -w `pwd`
|
||||||
endif
|
endif
|
||||||
|
|
||||||
|
|
||||||
CMAKE_OPT = -DCMAKE_BUILD_TYPE=$(TYPE)
|
CMAKE_OPT = -DCMAKE_BUILD_TYPE=$(TYPE)
|
||||||
CMAKE_OPT += -DUSE_CUDA=$(CUDA)
|
CMAKE_OPT += -DUSE_CUDA=$(CUDA)
|
||||||
CMAKE_OPT += -DUSE_BANG=$(BANG)
|
CMAKE_OPT += -DUSE_BANG=$(BANG)
|
||||||
CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN)
|
CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN)
|
||||||
CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE)
|
CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE)
|
||||||
CMAKE_OPT += -DBUILD_TEST=$(TEST)
|
CMAKE_OPT += -DBUILD_TEST=$(TEST)
|
||||||
|
CMAKE_OPT += -DBUILD_NNET=$(NNET)
|
||||||
|
|
||||||
ifeq ($(INTELCPU), ON)
|
ifeq ($(INTELCPU), ON)
|
||||||
CMAKE_OPT += -DUSE_INTELCPU=ON -DCMAKE_CXX_COMPILER=dpcpp
|
CMAKE_OPT += -DUSE_INTELCPU=ON -DCMAKE_CXX_COMPILER=dpcpp
|
||||||
|
@ -71,5 +72,3 @@ docker-start:
|
||||||
|
|
||||||
docker-exec:
|
docker-exec:
|
||||||
docker exec -it $(DOCKER_IMAGE_NAME) bash
|
docker exec -it $(DOCKER_IMAGE_NAME) bash
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,7 @@
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
namespace nnet {
|
namespace nnet {
|
||||||
int matchExprResult(Derivator &derivator, string fn);
|
int matchExprResult(Derivator &derivator, string pathRelativeToProjectHome);
|
||||||
bool checkExprLogSame(string fnPrefix, int start, int end);
|
bool checkExprLogSame(string pathRelativeToProjectHome, int start, int end);
|
||||||
bool checkExprsEquvivalence(VecExpr exprs);
|
bool checkExprsEquvivalence(VecExpr exprs);
|
||||||
} // namespace nnet
|
} // namespace nnet
|
||||||
|
|
|
@ -3,10 +3,30 @@
|
||||||
#include "nnet/Visitor/HashVisitor.h"
|
#include "nnet/Visitor/HashVisitor.h"
|
||||||
#include "nnet/Visitor/Interpreter.h"
|
#include "nnet/Visitor/Interpreter.h"
|
||||||
#include "nnet/Visitor/Serializer.h"
|
#include "nnet/Visitor/Serializer.h"
|
||||||
|
#include <filesystem>
|
||||||
namespace nnet {
|
namespace nnet {
|
||||||
|
|
||||||
int matchExprResult(Derivator &derivator, string fn) {
|
std::filesystem::path getProjectHome() {
|
||||||
auto ans = Serializer().deserialize(fn);
|
#ifndef INFINI_PROJECT_HOME
|
||||||
|
#error INFINI_PROJECT_HOME is not defined
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define Q(x) #x
|
||||||
|
#define QUOTE(x) Q(x)
|
||||||
|
#define PROJECT_HOME QUOTE(INFINI_PROJECT_HOME)
|
||||||
|
return std::filesystem::path(PROJECT_HOME);
|
||||||
|
#undef PROJECT_HOME
|
||||||
|
#undef QUOTE
|
||||||
|
#undef Q
|
||||||
|
}
|
||||||
|
|
||||||
|
string getResourceFilePath(string path) {
|
||||||
|
return (getProjectHome() / path).string();
|
||||||
|
}
|
||||||
|
|
||||||
|
int matchExprResult(Derivator &derivator, string pathRelativeToProjectHome) {
|
||||||
|
auto fullPath = getResourceFilePath(pathRelativeToProjectHome);
|
||||||
|
auto ans = Serializer().deserialize(fullPath);
|
||||||
auto hashAns = HashVisitor()(ans);
|
auto hashAns = HashVisitor()(ans);
|
||||||
int match = 0;
|
int match = 0;
|
||||||
for (const auto &candidate : derivator.getCandidates()) {
|
for (const auto &candidate : derivator.getCandidates()) {
|
||||||
|
@ -16,16 +36,17 @@ int matchExprResult(Derivator &derivator, string fn) {
|
||||||
return match;
|
return match;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool checkExprLogSame(string fnPrefix, int start, int end) {
|
bool checkExprLogSame(string pathRelativeToProjectHome, int start, int end) {
|
||||||
|
auto fullPath = getResourceFilePath(pathRelativeToProjectHome);
|
||||||
Serializer serializer;
|
Serializer serializer;
|
||||||
string fn0 = fnPrefix + to_string(start) + ".expr";
|
string fn0 = fullPath + to_string(start) + ".expr";
|
||||||
Expr expr0 = serializer.deserialize(fn0);
|
Expr expr0 = serializer.deserialize(fn0);
|
||||||
RangeOp range0 = as<RangeOpNode>(expr0);
|
RangeOp range0 = as<RangeOpNode>(expr0);
|
||||||
Interpreter interpreter(range0);
|
Interpreter interpreter(range0);
|
||||||
auto ans0 = interpreter.interpretUniformSample(range0);
|
auto ans0 = interpreter.interpretUniformSample(range0);
|
||||||
dbg(expr0, ans0);
|
dbg(expr0, ans0);
|
||||||
for (int i = start + 1; i < end; ++i) {
|
for (int i = start + 1; i < end; ++i) {
|
||||||
string fn1 = fnPrefix + to_string(i) + ".expr";
|
string fn1 = fullPath + to_string(i) + ".expr";
|
||||||
Expr expr1 = serializer.deserialize(fn1);
|
Expr expr1 = serializer.deserialize(fn1);
|
||||||
RangeOp range1 = as<RangeOpNode>(expr1);
|
RangeOp range1 = as<RangeOpNode>(expr1);
|
||||||
dbg(fn1, expr1);
|
dbg(fn1, expr1);
|
||||||
|
|
|
@ -85,7 +85,7 @@ TEST_F(OpSearch, Conv2gemm_NCHW_FCRS_search) {
|
||||||
|
|
||||||
EXPECT_GE(derivator.getNumCandidates(), 1);
|
EXPECT_GE(derivator.getNumCandidates(), 1);
|
||||||
int nMatches = matchExprResult(
|
int nMatches = matchExprResult(
|
||||||
derivator, "../test/nnet/log/conv2gemm/Conv2gemm_NCHW_FCRS_11.expr");
|
derivator, "test/nnet/log/conv2gemm/Conv2gemm_NCHW_FCRS_11.expr");
|
||||||
EXPECT_GE(nMatches, 1);
|
EXPECT_GE(nMatches, 1);
|
||||||
// derivator.print();
|
// derivator.print();
|
||||||
derivator.printStatistics();
|
derivator.printStatistics();
|
||||||
|
@ -160,7 +160,7 @@ TEST_F(OpSearch, TConv2gemm_TConv4x4_NHWF_RSFC_search) {
|
||||||
EXPECT_GE(derivator.getNumCandidates(), 1);
|
EXPECT_GE(derivator.getNumCandidates(), 1);
|
||||||
int nMatches = matchExprResult(
|
int nMatches = matchExprResult(
|
||||||
derivator,
|
derivator,
|
||||||
"../test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_18.expr");
|
"test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_18.expr");
|
||||||
EXPECT_GE(nMatches, 1);
|
EXPECT_GE(nMatches, 1);
|
||||||
derivator.printStatistics();
|
derivator.printStatistics();
|
||||||
}
|
}
|
||||||
|
@ -197,8 +197,7 @@ TEST_F(OpSearch, Conv2conv_5x5_RuleBased_NCHW_FCRS) {
|
||||||
derivator.search(conv_9x9, 0);
|
derivator.search(conv_9x9, 0);
|
||||||
|
|
||||||
int nMatches = matchExprResult(
|
int nMatches = matchExprResult(
|
||||||
derivator,
|
derivator, "test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_15.expr");
|
||||||
"../test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_15.expr");
|
|
||||||
// derivator.print();
|
// derivator.print();
|
||||||
derivator.printStatistics();
|
derivator.printStatistics();
|
||||||
EXPECT_GE(nMatches, 1);
|
EXPECT_GE(nMatches, 1);
|
||||||
|
@ -236,7 +235,7 @@ TEST_F(OpSearch, G2BMM_RuleBased) {
|
||||||
|
|
||||||
EXPECT_GE(derivator.getNumCandidates(), 1);
|
EXPECT_GE(derivator.getNumCandidates(), 1);
|
||||||
int nMatches =
|
int nMatches =
|
||||||
matchExprResult(derivator, "../test/nnet/log/g2bmm/G2BMM_9.expr");
|
matchExprResult(derivator, "test/nnet/log/g2bmm/G2BMM_9.expr");
|
||||||
EXPECT_GE(nMatches, 1);
|
EXPECT_GE(nMatches, 1);
|
||||||
// derivator.print();
|
// derivator.print();
|
||||||
derivator.printStatistics();
|
derivator.printStatistics();
|
||||||
|
|
|
@ -365,14 +365,14 @@ TEST(TConv2gemm, TConv4x4_NHWF_RSFC_search) {
|
||||||
// }
|
// }
|
||||||
int nMatches = matchExprResult(
|
int nMatches = matchExprResult(
|
||||||
derivator,
|
derivator,
|
||||||
"../test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_18.expr");
|
"test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_18.expr");
|
||||||
EXPECT_GE(nMatches, 1);
|
EXPECT_GE(nMatches, 1);
|
||||||
derivator.printStatistics();
|
derivator.printStatistics();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TConv2gemm, TConv4x4_NHWF_FRSC_CheckDerivationCorrectness_log) {
|
TEST(TConv2gemm, TConv4x4_NHWF_FRSC_CheckDerivationCorrectness_log) {
|
||||||
const string fnPrefix =
|
const string fnPrefix =
|
||||||
"../test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_";
|
"test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_";
|
||||||
EXPECT_TRUE(checkExprLogSame(fnPrefix, 0, 11));
|
EXPECT_TRUE(checkExprLogSame(fnPrefix, 0, 11));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -37,10 +37,8 @@ TEST(Conv2conv, 9x9_NCHW_FCRS) {
|
||||||
derivator.ruleBasedDFS(conv_9x9, 0, rules, {}, true);
|
derivator.ruleBasedDFS(conv_9x9, 0, rules, {}, true);
|
||||||
} else
|
} else
|
||||||
derivator.search(conv_9x9, 0);
|
derivator.search(conv_9x9, 0);
|
||||||
|
|
||||||
int nMatches = matchExprResult(
|
int nMatches = matchExprResult(
|
||||||
derivator,
|
derivator, "test/nnet/log/conv2conv/Conv2conv_9x9_NCHW_FCRS_14.expr");
|
||||||
"../test/nnet/log/conv2conv/Conv2conv_9x9_NCHW_FCRS_14.expr");
|
|
||||||
derivator.print();
|
derivator.print();
|
||||||
derivator.printStatistics();
|
derivator.printStatistics();
|
||||||
EXPECT_GE(nMatches, 1);
|
EXPECT_GE(nMatches, 1);
|
||||||
|
@ -81,8 +79,7 @@ TEST(Conv2conv, 6x6_RuleBased_NCHW_FCRS) {
|
||||||
|
|
||||||
ASSERT_GE(derivator.getNumCandidates(), 1);
|
ASSERT_GE(derivator.getNumCandidates(), 1);
|
||||||
int nMatches = matchExprResult(
|
int nMatches = matchExprResult(
|
||||||
derivator,
|
derivator, "test/nnet/log/conv2conv/Conv2conv_6x6_NCHW_FCRS_14.expr");
|
||||||
"../test/nnet/log/conv2conv/Conv2conv_6x6_NCHW_FCRS_14.expr");
|
|
||||||
derivator.print();
|
derivator.print();
|
||||||
derivator.printStatistics();
|
derivator.printStatistics();
|
||||||
EXPECT_GE(nMatches, 1);
|
EXPECT_GE(nMatches, 1);
|
||||||
|
@ -121,8 +118,7 @@ TEST(Conv2conv, 5x5_RuleBased_NCHW_FCRS) {
|
||||||
derivator.search(conv_9x9, 0);
|
derivator.search(conv_9x9, 0);
|
||||||
|
|
||||||
int nMatches = matchExprResult(
|
int nMatches = matchExprResult(
|
||||||
derivator,
|
derivator, "test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_15.expr");
|
||||||
"../test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_15.expr");
|
|
||||||
derivator.print();
|
derivator.print();
|
||||||
derivator.printStatistics();
|
derivator.printStatistics();
|
||||||
EXPECT_GE(nMatches, 1);
|
EXPECT_GE(nMatches, 1);
|
||||||
|
|
|
@ -147,7 +147,7 @@ TEST(Conv2gemm, timing_NHWC_RSFC_search) {
|
||||||
|
|
||||||
// Conv2gemm requires thorough update, this is disabled temporarily
|
// Conv2gemm requires thorough update, this is disabled temporarily
|
||||||
TEST(Conv2gemm, CheckCorrectness) {
|
TEST(Conv2gemm, CheckCorrectness) {
|
||||||
const string fnPrefix = "../test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_";
|
const string fnPrefix = "test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_";
|
||||||
// conv2gemm_7 has T3
|
// conv2gemm_7 has T3
|
||||||
EXPECT_TRUE(checkExprLogSame(fnPrefix, 0, 7));
|
EXPECT_TRUE(checkExprLogSame(fnPrefix, 0, 7));
|
||||||
}
|
}
|
||||||
|
@ -182,7 +182,7 @@ TEST(Conv2gemm, NCHW_RSFC_search) {
|
||||||
|
|
||||||
ASSERT_GE(derivator.getNumCandidates(), 1);
|
ASSERT_GE(derivator.getNumCandidates(), 1);
|
||||||
int nMatches = matchExprResult(
|
int nMatches = matchExprResult(
|
||||||
derivator, "../test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_11.expr");
|
derivator, "test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_11.expr");
|
||||||
EXPECT_GE(nMatches, 1);
|
EXPECT_GE(nMatches, 1);
|
||||||
// derivator.print();
|
// derivator.print();
|
||||||
derivator.printStatistics();
|
derivator.printStatistics();
|
||||||
|
@ -278,6 +278,6 @@ TEST(Conv2gemm1x7, NCHW_FCRS_search) {
|
||||||
ASSERT_GE(derivator.getNumCandidates(), 1);
|
ASSERT_GE(derivator.getNumCandidates(), 1);
|
||||||
int nMatches = matchExprResult(
|
int nMatches = matchExprResult(
|
||||||
derivator,
|
derivator,
|
||||||
"../test/nnet/log/conv2gemm_1x7/Conv2gemm_1x7_NCHW_FCRS_11.expr");
|
"test/nnet/log/conv2gemm_1x7/Conv2gemm_1x7_NCHW_FCRS_11.expr");
|
||||||
EXPECT_GE(nMatches, 1);
|
EXPECT_GE(nMatches, 1);
|
||||||
}
|
}
|
|
@ -36,10 +36,8 @@ TEST(GBMM, RuleBased) {
|
||||||
} else {
|
} else {
|
||||||
derivator.search(dialted_g2bmm, 0);
|
derivator.search(dialted_g2bmm, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
ASSERT_GE(derivator.getNumCandidates(), 1);
|
ASSERT_GE(derivator.getNumCandidates(), 1);
|
||||||
int nMatches =
|
int nMatches = matchExprResult(derivator, "test/nnet/log/gbmm/GBMM_9.expr");
|
||||||
matchExprResult(derivator, "../test/nnet/log/gbmm/GBMM_9.expr");
|
|
||||||
EXPECT_GE(nMatches, 1);
|
EXPECT_GE(nMatches, 1);
|
||||||
derivator.print();
|
derivator.print();
|
||||||
derivator.printStatistics();
|
derivator.printStatistics();
|
||||||
|
@ -78,7 +76,7 @@ TEST(G2BMM, RuleBased) {
|
||||||
|
|
||||||
ASSERT_GE(derivator.getNumCandidates(), 1);
|
ASSERT_GE(derivator.getNumCandidates(), 1);
|
||||||
int nMatches =
|
int nMatches =
|
||||||
matchExprResult(derivator, "../test/nnet/log/g2bmm/G2BMM_9.expr");
|
matchExprResult(derivator, "test/nnet/log/g2bmm/G2BMM_9.expr");
|
||||||
EXPECT_GE(nMatches, 1);
|
EXPECT_GE(nMatches, 1);
|
||||||
derivator.print();
|
derivator.print();
|
||||||
derivator.printStatistics();
|
derivator.printStatistics();
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
#ifdef USE_CUDA
|
||||||
|
|
||||||
#include "core/graph.h"
|
#include "core/graph.h"
|
||||||
#include "core/runtime.h"
|
#include "core/runtime.h"
|
||||||
#include "cuda/cuda_runtime.h"
|
#include "cuda/cuda_runtime.h"
|
||||||
|
@ -12,7 +14,22 @@
|
||||||
using namespace infini;
|
using namespace infini;
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
TEST(nnet, MemboundOpInterpretation) {
|
class NNetMemboundOp : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
void SetUp() override {
|
||||||
|
if (!hasTVMBackend())
|
||||||
|
GTEST_SKIP() << "Skipping test since no TVM backend.\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
static bool hasTVMBackend() {
|
||||||
|
// TODO: as the dispatch mechanism of backend is going to change, this
|
||||||
|
// function is to be implemented.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(NNetMemboundOp, MemboundOpInterpretation) {
|
||||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||||
Graph g = make_ref<GraphObj>(runtime);
|
Graph g = make_ref<GraphObj>(runtime);
|
||||||
Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32);
|
Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||||
|
@ -41,7 +58,7 @@ TEST(nnet, MemboundOpInterpretation) {
|
||||||
EXPECT_TRUE(membound->getOutput()->equalData(ans));
|
EXPECT_TRUE(membound->getOutput()->equalData(ans));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(nnet, MemboundOp_Ansor_Codegen) {
|
TEST_F(NNetMemboundOp, MemboundOp_Ansor_Codegen) {
|
||||||
auto runtime = make_ref<CudaRuntimeObj>();
|
auto runtime = make_ref<CudaRuntimeObj>();
|
||||||
Runtime cpu = NativeCpuRuntimeObj::getInstance();
|
Runtime cpu = NativeCpuRuntimeObj::getInstance();
|
||||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||||
|
@ -91,7 +108,7 @@ pair<std::vector<nnet::Tensor>, nnet::Expr> getPReluExpr(int size) {
|
||||||
return {{A, B}, ret};
|
return {{A, B}, ret};
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(nnet, PRelu_Ansor_Codegen) {
|
TEST_F(NNetMemboundOp, PRelu_Ansor_Codegen) {
|
||||||
auto cuda = make_ref<CudaRuntimeObj>();
|
auto cuda = make_ref<CudaRuntimeObj>();
|
||||||
Runtime cpu = NativeCpuRuntimeObj::getInstance();
|
Runtime cpu = NativeCpuRuntimeObj::getInstance();
|
||||||
Graph g = make_ref<GraphObj>(cuda);
|
Graph g = make_ref<GraphObj>(cuda);
|
||||||
|
@ -116,3 +133,5 @@ TEST(nnet, PRelu_Ansor_Codegen) {
|
||||||
auto oCpu = gCpu->cloneTensor(o0);
|
auto oCpu = gCpu->cloneTensor(o0);
|
||||||
EXPECT_TRUE(oCpu->equalData(ans));
|
EXPECT_TRUE(oCpu->equalData(ans));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
#ifdef USE_CUDA
|
||||||
|
|
||||||
#include "core/blob.h"
|
#include "core/blob.h"
|
||||||
#include "core/dummy_mutator.h"
|
#include "core/dummy_mutator.h"
|
||||||
#include "core/graph.h"
|
#include "core/graph.h"
|
||||||
|
@ -477,3 +479,5 @@ TEST(Mutator, InfoGAN_TConv_3_correctness) {
|
||||||
// EXPECT_TRUE(graph->verification(bestGraph.get(), true));
|
// EXPECT_TRUE(graph->verification(bestGraph.get(), true));
|
||||||
// }
|
// }
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
Loading…
Reference in New Issue