diff --git a/CMakeLists.txt b/CMakeLists.txt index 13ce9cd1..f3c06283 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,6 +16,16 @@ cmake_dependent_option(BUILD_TEST_EINNET "Build tests for EINNET" OFF BUILD_TEST set(DEFAULT_BUILD_TYPE "RelWithDebInfo") +if(EXISTS ${CMAKE_CURRENT_BINARY_DIR}/config.cmake) + message(STATUS "Using config.cmake in CMAKE_CURRENT_BINARY_DIR directory") + include(${CMAKE_CURRENT_BINARY_DIR}/config.cmake) +else() + if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/config.cmake) + message(STATUS "Using config.cmake in CMAKE_CURRENT_SOURCE_DIR directory") + include(${CMAKE_CURRENT_SOURCE_DIR}/config.cmake) + endif() +endif() + set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_EXTENSIONS OFF) # -std=gnu++11 when on, -std=c++11 when off @@ -32,19 +42,6 @@ if(OpenMP_CXX_FOUND) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") endif() - -if(BUILD_TEST) - set(BUILD_GMOCK - OFF - CACHE BOOL "Do not build gmock" FORCE) - set(INSTALL_GTEST - OFF - CACHE BOOL "Do not install gtest" FORCE) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall ") - add_subdirectory(3rd-party/googletest) - include_directories(SYSTEM 3rd-party/googletest/googletest/include) -endif() - #Protobuf if(USE_PROTOBUF) add_definitions(-D TENSOR_PROTOBUF) @@ -71,10 +68,35 @@ include_directories(3rd-party/pybind11/include) add_subdirectory(3rd-party/nlohmann_json_cmake_fetchcontent) include_directories(3rd-party/nlohmann_json_cmake_fetchcontent/single_include) +# TVM backend +if(BUILD_TEST_EINNET) + if (NOT TVM_INCLUDE_DIR OR NOT DMLC_INCLUDE_DIR OR NOT DLPACK_INCLUDE_DIR OR NOT DLPACK_INCLUDE_DIR) + message(FATAL_ERROR "TVM_INCLUDE_DIR, DMLC_INCLUDE_DIR, and DLPACK_INCLUDE_DIR must be set when BUILD_TEST_EINNET is ON") + endif() + # TVM and DMLC for invoking TVM packed functions + include_directories(${TVM_INCLUDE_DIR}) + include_directories(${DMLC_INCLUDE_DIR}) + include_directories(${DLPACK_INCLUDE_DIR}) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDMLC_USE_LOGGING_LIBRARY=\\\<${TVM_INCLUDE_DIR}/tvm/runtime/logging.h\\\> ") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DINFINI_USE_TVM=1") # Enable TVM codegen kernels +endif() + +if(BUILD_TEST) + set(BUILD_GMOCK + OFF + CACHE BOOL "Do not build gmock" FORCE) + set(INSTALL_GTEST + OFF + CACHE BOOL "Do not install gtest" FORCE) + add_subdirectory(3rd-party/googletest) + include_directories(3rd-party/googletest/googletest/include) +endif() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror -Wno-error=deprecated-declarations") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion + # Source files file(GLOB_RECURSE SRC src/ffi/*.cc src/core/*.cc src/kernels/cpu/*.cc src/nnet/*.cc src/operators/*.cc src/utils/*.cc) @@ -101,6 +123,11 @@ endif() target_link_libraries(InfiniTensor pybind11::embed) +# TVM backend +if(BUILD_TEST_EINNET) + target_link_libraries(InfiniTensor ${TVM_LIB_DIR}/libtvm.so) +endif() + # Python bindings file(GLOB_RECURSE FFIS src/ffi/ffi_infinitensor.cc) pybind11_add_module(backend MODULE ${FFIS}) @@ -229,5 +256,9 @@ if(BUILD_TEST) endif() if(BUILD_TEST_EINNET) build_test(test/nnet/test_*.cc) + + # Build expression reader + add_executable(nnet_reader test/nnet/readlog.cc) + target_link_libraries(nnet_reader InfiniTensor) endif() endif() diff --git a/Makefile b/Makefile index 3df4c34d..6b5fa090 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ endif build: mkdir -p build/$(TYPE) - cd build/$(TYPE) && cmake $(CMAKE_OPT) ../.. && make -j22 + cd build/$(TYPE) && cmake $(CMAKE_OPT) ../.. && make -j8 clean: rm -rf build diff --git a/cmake/config_lotus_TVM.cmake b/cmake/config_lotus_TVM.cmake new file mode 100644 index 00000000..6a45220f --- /dev/null +++ b/cmake/config_lotus_TVM.cmake @@ -0,0 +1,13 @@ +set(TVM_HOME "/home/zly/Apps/tvm-v0.10.0") +set(TVM_INCLUDE_DIR "${TVM_HOME}/include") +set(TVM_LIB_DIR "${TVM_HOME}/build") +set(DMLC_INCLUDE_DIR "${TVM_HOME}/3rdparty/dmlc-core/include") +set(DLPACK_INCLUDE_DIR "${TVM_HOME}/3rdparty/dlpack/include") + +set(USE_CUDA ON) +set(USE_BANG OFF) + +set(BUILD_TEST ON) +set(BUILD_TEST_CORE ON) +set(BUILD_TEST_PET OFF) +set(BUILD_TEST_EINNET ON) diff --git a/include/core/graph.h b/include/core/graph.h index cbeceac1..3d62be30 100644 --- a/include/core/graph.h +++ b/include/core/graph.h @@ -19,6 +19,9 @@ class GraphObj : public Object { Tensor addTensor(Shape dim, DataType dtype = DataType::Float32); Tensor addTensor(const Tensor &tensor); TensorVec addTensor(const TensorVec &tensors); + /** + * @brief Clone a tensor and add it to the graph. + */ Tensor cloneTensor(const Tensor &tensor) { return addTensor(tensor->clone(runtime)); } diff --git a/include/core/tensor.h b/include/core/tensor.h index 72a3b007..8f8fa356 100644 --- a/include/core/tensor.h +++ b/include/core/tensor.h @@ -69,6 +69,9 @@ class TensorObj : public TensorBaseObj { void copyData(const TensorObj *src); void copyData(const Tensor &src) { copyData(src.get()); } + + // FIXME: std::fucntion copies the generator instead of passing it by ref. + // Thus the internal state of generator cannot be updated. void setData( const std::function &generator) const; Tensor clone() const { @@ -92,29 +95,31 @@ class TensorObj : public TensorBaseObj { } void printData() const; - bool equalData(const Tensor &rhs) const; + bool equalData(const Tensor &rhs, double relativeError = 1e-6) const; template bool equalData(const vector &dataVector) { IT_ASSERT(DataType::get() == dtype); IT_ASSERT(size() == dataVector.size()); - return equalDataImpl(getRawDataPtr(), dataVector.data(), size()); + return equalDataImpl(getRawDataPtr(), dataVector.data(), size(), + 1e-6); } size_t getOffsetByBroadcastOffset(size_t bcOffset, Shape bcShape) const; private: - void printDataFloat() const; - void printDataUint32_t() const; + void printDataFloat(float *ptr) const; + void printDataUint32_t(uint32_t *ptr) const; template - bool equalDataImpl(const T *a, const T *b, size_t size) const { + bool equalDataImpl(const T *a, const T *b, size_t size, + double relativeError) const { for (size_t i = 0; i < size; ++i) { if constexpr (std::is_integral_v) { if (a[i] != b[i]) return false; } else if constexpr (std::is_floating_point_v) { if (fabs(a[i] - b[i]) / std::max(fabs(a[i]), fabs(b[i])) > - 1e-6) { + relativeError) { printf("Error on %lu: %f %f\n", i, a[i], b[i]); return false; } diff --git a/include/cuda/cuda_common.h b/include/cuda/cuda_common.h index 7d3bb65d..dec9a40b 100644 --- a/include/cuda/cuda_common.h +++ b/include/cuda/cuda_common.h @@ -23,9 +23,8 @@ const char *errName; \ if (CUDA_SUCCESS != err) { \ cuGetErrorString(err, &errName); \ - fprintf(stderr, "Cuda error in %s:%i : %s.\n", __FILE__, __LINE__, \ - errName); \ - exit(EXIT_FAILURE); \ + IT_ASSERT(err == CUDA_SUCCESS, \ + (string("CU error: ") + string(errName))); \ } \ } diff --git a/include/cuda/cuda_runtime.h b/include/cuda/cuda_runtime.h index 5f191b35..b5830454 100644 --- a/include/cuda/cuda_runtime.h +++ b/include/cuda/cuda_runtime.h @@ -11,18 +11,8 @@ class CudaRuntimeObj : public RuntimeObj { CudaPtr workspace; size_t workspaceSize; - public: - CUdevice cuDevice; - CUcontext newContext; - public: CudaRuntimeObj() : RuntimeObj(Device::CUDA) { - // Prepare for nvrtc. cuCtxCreate should be called befero others. - // Otherwise it will result in strange failure, such as cuBLAS failed on - // certian inputs. - checkCUresult(cuInit(0)); - checkCUresult(cuDeviceGet(&cuDevice, 0)); - checkCUresult(cuCtxCreate(&newContext, 0, cuDevice)); checkCudnnError(cudnnCreate(&cudnn)); checkCublasError(cublasCreate(&cublas)); @@ -32,10 +22,13 @@ class CudaRuntimeObj : public RuntimeObj { workspace = alloc(workspaceSize); } virtual ~CudaRuntimeObj() { - dealloc(workspace); - checkCudnnError(cudnnDestroy(cudnn)); - checkCublasError(cublasDestroy(cublas)); - checkCUresult(cuCtxDestroy(newContext)); + try { + dealloc(workspace); + checkCudnnError(cudnnDestroy(cudnn)); + checkCublasError(cublasDestroy(cublas)); + } catch (const std::exception &e) { + std::cerr << "Error in ~CudaRuntimeObj: " << e.what() << std::endl; + } } string toString() const override; diff --git a/include/nnet/Visitor/HashVisitor.h b/include/nnet/Visitor/HashVisitor.h index 0d20f49d..bc006eca 100644 --- a/include/nnet/Visitor/HashVisitor.h +++ b/include/nnet/Visitor/HashVisitor.h @@ -22,10 +22,11 @@ class HashVisitor : public Functor { HashType visit_(const Subscript &c) override; HashType visit_(const Tensor &c) override; HashType visit_(const Var &c) override; + HashType visit_(const Func &c) override; public: HashVisitor(int _verobse = 0) : Functor(_verobse) {} HashType getHash(const Expr &c); }; -} // namespace nnet \ No newline at end of file +} // namespace nnet diff --git a/include/nnet/Visitor/MergeMemboundMutator.h b/include/nnet/Visitor/MergeMemboundMutator.h index 0b2673fa..6baebcd1 100644 --- a/include/nnet/Visitor/MergeMemboundMutator.h +++ b/include/nnet/Visitor/MergeMemboundMutator.h @@ -20,7 +20,13 @@ class MergeMemboundMutator : public Mutator { */ MergeMemboundMutator(const VecExpr &kernels) : Mutator(), kernels(kernels), curDepth(kernels.size() - 1) {} - Expr merge(bool allowEmptyMembound = false); + + /// @brief Merged multiple expressions into one with one or several stages. + /// @param allowEmptyMembound + /// @param allowFailure If true, return nullptr when merging fails. If + /// false, assert will fail. + /// @return + Expr merge(bool allowEmptyMembound = false, bool allowFailure = false); }; -} // namespace nnet \ No newline at end of file +} // namespace nnet diff --git a/include/nnet/common.h b/include/nnet/common.h index 567c6a27..616e1dd7 100644 --- a/include/nnet/common.h +++ b/include/nnet/common.h @@ -66,7 +66,7 @@ static inline HashType genhash(string s) { } #define nnet_unimplemented_halt() \ - { assert(!"Unimplemented"); } + { IT_TODO_HALT(); } #define nnet_unimplemented_continue() \ { dbg("Unimplemented"); } diff --git a/include/nnet/expr.h b/include/nnet/expr.h index 751859c9..f424237d 100644 --- a/include/nnet/expr.h +++ b/include/nnet/expr.h @@ -104,10 +104,11 @@ enum class NodeType { FuncNodeType }; -enum class FuncType { Relu, Tanh }; +enum class FuncType { Relu, Tanh, PRelu }; -#define DEFINE_GETTYPE(CLASS) \ - NodeType getType() const override { return NodeType::CLASS##Type; } +#define DEFINE_GETTYPE(CLASS, isScalar_v) \ + NodeType getType() const override { return NodeType::CLASS##Type; } \ + bool isScalar() const override { return isScalar_v; } class ExprNode { public: @@ -119,6 +120,7 @@ class ExprNode { friend std::ostream &operator<<(std::ostream &ios, const ExprNode &expr); virtual NodeType getType() const = 0; + virtual bool isScalar() const = 0; }; class VarNode : public ExprNode { @@ -127,7 +129,7 @@ class VarNode : public ExprNode { public: VarNode(std::string _name) : name(_name){}; virtual ~VarNode() {} - DEFINE_GETTYPE(VarNode); + DEFINE_GETTYPE(VarNode, true); const std::string &getName() const { return name; } HashType hash() const override { return genhash(name); }; @@ -152,7 +154,7 @@ class TensorNode : public ExprNode { TensorNode(string _name, vector _shape, vector _paddings = {}, Routine _source = nullptr); virtual ~TensorNode() {} - DEFINE_GETTYPE(TensorNode); + DEFINE_GETTYPE(TensorNode, false); bool operator==(const string &rhs) { return name == rhs; } friend bool operator==(const string &lhs, const TensorNode &rhs) { @@ -174,6 +176,7 @@ class TensorNode : public ExprNode { const Routine &getSource() const { return source; } int getData(const Ref> &data, const vector &idx); size_t getOffset(const vector &idx); + bool hasPadding(); }; enum class OpType { Range, Add, Mul, Div, Mod, Sub }; @@ -220,7 +223,7 @@ class RangeOpNode : public OperatorNode { const vector &paddings) : OperatorNode(OpType::Range, {_summand}), vars{_loopIters, _sumIters}, paddings(paddings){}; - DEFINE_GETTYPE(RangeOpNode); + DEFINE_GETTYPE(RangeOpNode, false); virtual HashType hash() const override { nnet_unimplemented_halt(); @@ -289,7 +292,7 @@ class BinaryOpNode : public OperatorNode { BinaryOpNode(OpType _opType, Expr _lhs, Expr _rhs) : OperatorNode(_opType, {_lhs, _rhs}){}; virtual ~BinaryOpNode() {} - DEFINE_GETTYPE(BinaryOpNode); + DEFINE_GETTYPE(BinaryOpNode, true); virtual HashType hash() const override { return genhash((HashType)opType, @@ -314,7 +317,7 @@ class ConstantNode : public ExprNode { ConstantNode(int _val) : val(_val){}; ConstantNode(const ConstantNode &rhs) : ExprNode(rhs), val(rhs.val){}; virtual ~ConstantNode() {} - DEFINE_GETTYPE(ConstantNode); + DEFINE_GETTYPE(ConstantNode, true); int getValue() const { return val; } virtual HashType hash() const override { return genhash(val, 6214587); }; @@ -334,7 +337,7 @@ class SubscriptNode : public ExprNode { SubscriptNode(Expr _indexed, vector _subExprs) : subExprs(_subExprs) { setObject(_indexed); }; - DEFINE_GETTYPE(SubscriptNode); + DEFINE_GETTYPE(SubscriptNode, true); virtual HashType hash() const override { nnet_unimplemented_continue(); @@ -358,14 +361,15 @@ class SubscriptNode : public ExprNode { class FuncNode : public ExprNode { protected: - Subscript object; + Expr object; FuncType funcType; public: - FuncNode(Expr object, FuncType funcType) : funcType(funcType) { - setObject(object); + FuncNode(Expr object, FuncType funcType) + : object(object), funcType(funcType) { + nnet_assert(object->isScalar(), "FuncNode operates on a scalar"); } - DEFINE_GETTYPE(FuncNode); + DEFINE_GETTYPE(FuncNode, true); virtual HashType hash() const override { nnet_unimplemented_continue(); @@ -373,7 +377,7 @@ class FuncNode : public ExprNode { }; virtual string toReadable() const override; - const Subscript &getObject() const { return object; } + const Expr &getObject() const { return object; } void setObject(Expr e); FuncType getFuncType() const { return funcType; } diff --git a/include/nnet/nmutator.h b/include/nnet/nmutator.h index 6fab0857..c3009b10 100644 --- a/include/nnet/nmutator.h +++ b/include/nnet/nmutator.h @@ -20,7 +20,7 @@ class NMutator : public Mutator { public: NMutator(Mode mode = Mode::Normal); - NMutator(const std::vector &derivationRules); + NMutator(Mode mode, const std::vector &derivationRules); ~NMutator(); vector run(const Graph &in_graph) override; diff --git a/include/operators/matmul.h b/include/operators/matmul.h index 9c54d332..a1c57cfe 100644 --- a/include/operators/matmul.h +++ b/include/operators/matmul.h @@ -19,9 +19,15 @@ class MatmulObj : public OperatorObj { public: /** - * @brief Construct a new Matmul object. This comments show how operators is - * defined in InfiniTensor. The constructor can create output tensors for - * the operator or not, which depends on `graph`. + * @brief Matmul operator with batch broadcast and tensor transpose + * supports. Only one tensor with singe batch can be broadcasted due to the + * BLAS interface restriction. Tranpose indicates whether the last two + * dimensions should be transposed before Matmul and does not affect other + * leading dimensions. + * + * Matmul show how operators are defined in InfiniTensor. The constructor of + * an operator can create output tensors for the operator or not, which + * depends on `graph`. * * @param graph The computation graph that this operator belongs to. * @param A The input tensor. diff --git a/include/operators/membound.h b/include/operators/membound.h index 1828723e..4a444553 100644 --- a/include/operators/membound.h +++ b/include/operators/membound.h @@ -7,9 +7,10 @@ namespace infini { class MemBoundObj : public OperatorObj { private: std::vector nnetInputs; - nnet::Expr expr; + nnet::Expr expr, simplifiedExpr; double exec_time; std::string hint; + HashType hash, simplifiedHash; int n, f, h, w; public: @@ -26,11 +27,15 @@ class MemBoundObj : public OperatorObj { int numOutputs() const override { return outputs.size(); } const vector &getNnetInputs() const { return nnetInputs; } const nnet::Expr getNnetExpr() const { return expr; } + pair getSimplifiedNnetExpr() const { + return {expr, hash}; + } private: vector getWorkloadVector() const override; vector getOpAttrVector() const override; - HashType getHash() const; + static HashType calcHash(nnet::Expr expr); + static bool checkOOB(nnet::Expr expr); }; } // namespace infini diff --git a/include/utils/data_generator.h b/include/utils/data_generator.h index 6a106d2e..89d8b84c 100644 --- a/include/utils/data_generator.h +++ b/include/utils/data_generator.h @@ -1,5 +1,7 @@ +#pragma once #include "core/common.h" #include "core/tensor_base.h" +#include namespace infini { @@ -38,6 +40,31 @@ class IncrementalGenerator : public DataGenerator { void fill(float *data, size_t size) override { fill(data, size); } }; +class RandomGenerator : public DataGenerator { + private: + double l, r; + std::mt19937 e; + std::uniform_int_distribution di; + std::uniform_real_distribution dr; + + public: + RandomGenerator(double l = 0, double r = 1, unsigned int seed = 0) + : l(l), r(r), e(seed), di(l, r), dr(l, r) {} + virtual ~RandomGenerator() {} + + private: + void fill(uint32_t *data, size_t size) override { + for (size_t i = 0; i < size; i++) { + data[i] = di(e); + } + } + void fill(float *data, size_t size) override { + for (size_t i = 0; i < size; i++) { + data[i] = dr(e); + } + } +}; + template class ValGenerator : public DataGenerator { public: virtual ~ValGenerator() {} diff --git a/python/cpp_plugin/__init__.py b/python/cpp_plugin/__init__.py index 811587de..af4c9b40 100644 --- a/python/cpp_plugin/__init__.py +++ b/python/cpp_plugin/__init__.py @@ -1 +1,2 @@ from .gen_ansor_op import gen_ansor_op +from .gen_ansor_so import gen_ansor_so diff --git a/python/cpp_plugin/gen_ansor_op.py b/python/cpp_plugin/gen_ansor_op.py index 6d9a0964..204816ad 100644 --- a/python/cpp_plugin/gen_ansor_op.py +++ b/python/cpp_plugin/gen_ansor_op.py @@ -3,19 +3,50 @@ import re import numpy as np import tvm from tvm import te, tir, auto_scheduler, topi +import os +import json +import logging + +USE_CACHE = True +logger = logging.getLogger('InfiniTensor') +logger.setLevel(logging.DEBUG) -def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f, func_name, input_names, output_name): +def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f, + func_name, input_names, output_name, nnet_expression: str, + nnet_simplified_expression: str, hash_code=None): assert len(input_tensors) == len(input_dtypes) assert len(input_tensors) == len(input_names) + logging.debug(f'Work on hash {hash_code}') + + dir_name = os.path.join(".cache", "generated_kernels", str(hash_code)) + func_code_fn = os.path.join(dir_name, "kernel.cu") + invoke_code_fn = os.path.join(dir_name, "invoke.cpp") + config_fn = os.path.join(dir_name, "config.json") + + if USE_CACHE and hash_code is not None: + if os.path.exists(dir_name): + print(f"Use cache in {dir_name}") + with open(func_code_fn, "r") as func_code_fin: + func_code = func_code_fin.read() + with open(invoke_code_fn, "r") as invoke_code_fin: + invoke_code = invoke_code_fin.read() + with open(config_fn, "r") as config_fin: + config = json.loads(config_fin.read().strip()) + conv_time = config["conv_time"] + invoke_params = config["invoke_params"] + + logger.debug(f'Find tuning log for {hash_code}') + return func_code, invoke_code, conv_time, invoke_params + print("Generating Ansor op: ") print(f) @auto_scheduler.register_workload(func_name) def compute(): _locals = locals() - exec(f, {'tvm': tvm, 'te': te, 'tir': tir}, _locals) + exec(f, {'tvm': tvm, 'te': te, 'tir': tir, 'topi': topi}, _locals) return _locals['ret'] target = tvm.target.Target("cuda") @@ -43,6 +74,28 @@ def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f, fu # Kill the measurement process del measure_ctx + def test_mutator(): + # test part + tgt_temp = tvm.target.Target(target="llvm", host="llvm") + all_tensors = compute() + sch = te.create_schedule(all_tensors[0].op) + args = all_tensors + C0, K0, A0 = args + func_temp = tvm.build(sch, args, tgt_temp, name="temp") + + # print result + n, c, h, w, f, r, s = 1, 1, 2, 2, 1, 4, 4 + dev_temp = tvm.device(tgt_temp.kind.name, 0) + A_temp = tvm.nd.array( + np.arange(n*h*w*f).reshape(n, h, w, f).astype(A0.dtype), dev_temp) + K_temp = tvm.nd.array( + np.arange(f*r*s*c).reshape(f, r, s, c).astype(K0.dtype), dev_temp) + C_temp = tvm.nd.array( + np.zeros((1, 4, 4, 1)).astype(C0.dtype), dev_temp) + func_temp(C_temp, K_temp, A_temp) + print("================= Test Result =====================") + print(C_temp) + ir = str(tvm.lower(sch, args, simple_mode=True)) thread_dim = [1, 1, 1] block_dim = [1, 1, 1] @@ -83,11 +136,27 @@ def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f, fu print("Func Code") # Attach TVM code behind func_code - func_code += "\n/* " + f + "*/" + func_code += "\n/* NNET tensor expression \n" + nnet_expression + "\n*/\n" + func_code += "\n/* NNET simplified tensor expression \n" + \ + nnet_simplified_expression + "\n*/\n" + func_code += "\n/* TVM compute\n" + f + "\n*/\n" print(func_code) print("Invoke Code") print(invoke_code) print("Time") print(conv_time) + if hash_code is not None: + if not os.path.exists(dir_name): + os.makedirs(dir_name) + with open(func_code_fn, "w") as func_code_fout: + func_code_fout.write(func_code) + with open(invoke_code_fn, "w") as invoke_code_fout: + invoke_code_fout.write(invoke_code) + with open(config_fn, "w") as config_fout: + config_fout.write(json.dumps({ + "conv_time": conv_time, + "invoke_params": invoke_params + }, ensure_ascii=False, indent=2)) + return func_code, invoke_code, conv_time, invoke_params # ms diff --git a/python/cpp_plugin/gen_ansor_so.py b/python/cpp_plugin/gen_ansor_so.py new file mode 100644 index 00000000..d06bc3be --- /dev/null +++ b/python/cpp_plugin/gen_ansor_so.py @@ -0,0 +1,106 @@ +import re + +import numpy as np +import tvm +from tvm import te, tir, auto_scheduler, topi +import os +import json +import logging + +USE_CACHE = True +logger = logging.getLogger('InfiniTensor') +logger.setLevel(logging.DEBUG) + + +def gen_ansor_so(input_tensors, input_dtypes, output_tensor, output_dtype, + tvm_code, func_name, nnet_expression: str, + nnet_simplified_expression: str, hash_code=None): + assert len(input_tensors) == len(input_dtypes) + + logging.debug(f'Work on hash {hash_code}') + dir_name = os.path.join(".cache", "generated_kernels", str(hash_code)) + + if not os.path.exists(dir_name): + os.makedirs(dir_name) + + so_fn = os.path.join(dir_name, f"{func_name}.so") + config_fn = os.path.join(dir_name, "config_so.json") + + print("Generating Ansor op: ") + print(tvm_code) + + print("Input shape: ") + print(input_tensors) + print("Output shape: ") + print(output_tensor) + + if USE_CACHE and hash_code is not None: + if os.path.exists(dir_name) and \ + os.path.exists(so_fn) and \ + os.path.exists(config_fn): + print(f"Use cache in {dir_name}") + with open(config_fn, "r") as config_fin: + config = json.loads(config_fin.read().strip()) + conv_time = config["conv_time"] + + logger.debug(f'Find tuning log for {hash_code}') + return so_fn, conv_time + + @auto_scheduler.register_workload(func_name) + def compute(): + _locals = locals() + exec(tvm_code, {'tvm': tvm, 'te': te, 'tir': tir, 'topi': topi}, _locals) + return _locals['ret'] + + target = tvm.target.Target("cuda") + + task = auto_scheduler.SearchTask(func=func_name, args=(), target=target) + + # Inspect the computational graph + print("Computational DAG:") + print(task.compute_dag) + + log_file = f"ansor_{func_name}_log.json" + measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300) + tune_option = auto_scheduler.TuningOptions( + num_measure_trials=10, + runner=measure_ctx.runner, + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + verbose=2, + ) + + # Run auto-tuning (search) + task.tune(tune_option) + # Apply the best schedule + sch, args = task.apply_best(log_file) + + # Kill the measurement process + del measure_ctx + + func = tvm.build(sch, args, target, name=func_name) + func.export_library(so_fn) + + ctx = tvm.cuda(0) + input_a = [] + for i, (shape, dtype) in enumerate(zip(input_tensors, input_dtypes)): + a_np = np.random.uniform(size=shape).astype(dtype) + input_a.append(tvm.nd.array(a_np, ctx)) + a_out = tvm.nd.array(np.zeros(output_tensor, dtype=output_dtype), ctx) + func(a_out, *input_a) + evaluator = func.time_evaluator(func.entry_name, ctx, number=100) + conv_time = evaluator(a_out, *input_a).mean * 1e3 + + print("====NNET tensor expression====") + print(nnet_expression+"\n") + print("====NNET simplified tensor expression====") + print(nnet_simplified_expression+"\n") + print("====Time====") + print(conv_time) + + if USE_CACHE and hash_code is not None: + with open(config_fn, "w") as config_fout: + config_fout.write(json.dumps({ + "conv_time": conv_time, + }, ensure_ascii=False, indent=2)) + + return so_fn, conv_time diff --git a/src/core/graph.cc b/src/core/graph.cc index 3bc525d4..90147320 100644 --- a/src/core/graph.cc +++ b/src/core/graph.cc @@ -11,13 +11,11 @@ GraphObj::GraphObj(Runtime runtime, OpVec ops_in) for (const auto &op : ops_in) { for (const auto &t : op->getInputs()) if (tensorPool.find(t->getFuid()) == tensorPool.end()) - tensorPool[t->getFuid()] = t->clone(); + tensorPool[t->getFuid()] = cloneTensor(t); for (const auto &t : op->getOutputs()) if (tensorPool.find(t->getFuid()) == tensorPool.end()) - tensorPool[t->getFuid()] = t->clone(); + tensorPool[t->getFuid()] = cloneTensor(t); } - for (const auto &[_, t] : tensorPool) - addTensor(t); // Clone operators and add connections for (const auto &op : ops_in) { TensorVec inputs, outputs; @@ -127,8 +125,12 @@ Tensor GraphObj::addTensor(Shape dim, DataType dtype) { } Tensor GraphObj::addTensor(const Tensor &tensor) { - IT_ASSERT(tensor->getRuntime() == runtime, "Tensor runtime mismatch"); - return tensors.emplace_back(tensor); + IT_ASSERT(tensor->getRuntime() == runtime, + std::string("Tensor runtime mismatch: cannot add a tenosr in ") + + tensor->getRuntime()->toString() + " to " + + runtime->toString()); + tensors.emplace_back(tensor); + return tensor; } TensorVec GraphObj::addTensor(const TensorVec &tensors) { @@ -207,6 +209,13 @@ bool GraphObj::checkValid() const { IT_ASSERT(std::find(ops.begin(), ops.end(), suc) != ops.end()); } } + std::set s; + // check whether two tensors with the same FUID exist + for (auto tensor : tensors) { + int cnt = s.count(tensor->getFuid()); + IT_ASSERT(cnt == 0, std::to_string(tensor->getFuid())); + s.insert(tensor->getFuid()); + } return true; } diff --git a/src/core/operator.cc b/src/core/operator.cc index e707d94e..cea23321 100644 --- a/src/core/operator.cc +++ b/src/core/operator.cc @@ -24,8 +24,8 @@ bool OperatorObj::isConcatOp() const { return type == OpType::Concat; } bool OperatorObj::isComputeOp() const { return type == OpType::Conv || type == OpType::Matmul || - type == OpType::ConvTrans || type == OpType::G2BMM || - type == OpType::GBMM; + type == OpType::ConvTrans || type == OpType::ConvTransNHWC || + type == OpType::G2BMM || type == OpType::GBMM; } bool OperatorObj::isTransposeOp() const { return type == OpType::Transpose; } @@ -92,7 +92,7 @@ bool OperatorObj::checkValid(GraphObj *graph) { if (graph) { // if graph != nullptr, outputs should be created auto dataTypes = inferDataType(); for (size_t i = 0; i < outputs.size(); i++) { - IT_ASSERT(!outputs[i]); + IT_ASSERT(!outputs[i], "Find empty output while operator creation"); outputs[i] = graph->addTensor(shapes[i], dataTypes[i]); } } else { // if outputs have been created, check their shapes diff --git a/src/core/tensor.cc b/src/core/tensor.cc index cdcd9e28..fd5ddde4 100644 --- a/src/core/tensor.cc +++ b/src/core/tensor.cc @@ -16,9 +16,16 @@ TensorObj::TensorObj(Shape shape_, DataType dtype, Runtime runtime) [](auto acc, auto x) { return acc * x; })) {} string TensorObj::toString() const { + // Convert data pointer to string + std::stringstream ss; + if (data != nullptr) + ss << data->getPtr(); + else + ss << "nullptr data"; string ret = "Tensor " + std::to_string(guid) + ", Fuid " + std::to_string(fuid) + ", shape " + vecToString(shape) + - ", dtype " + dtype.toString(); + ", dtype " + dtype.toString() + ", " + runtime->toString() + + ", " + ss.str() + "\n"; vector targetGuids; for (const auto &op : targets) targetGuids.emplace_back(op.lock()->getGuid()); @@ -57,25 +64,36 @@ vector TensorObj::getStride() const { void TensorObj::printData() const { IT_ASSERT(data != nullptr); - if (!runtime->isCpu()) - IT_TODO_HALT(); + void *ptr = nullptr; + Blob buffer; + if (!runtime->isCpu()) { + buffer = NativeCpuRuntimeObj::getInstance()->allocBlob(getBytes()); + runtime->copyBlobToCPU(buffer->getPtr(), + getRawDataPtr(), getBytes()); + ptr = buffer->getPtr(); + } else + ptr = data->getPtr(); if (dtype == DataType::Float32) - printDataFloat(); + printDataFloat(static_cast(ptr)); else if (dtype == DataType::UInt32) - printDataUint32_t(); + printDataUint32_t(static_cast(ptr)); else IT_TODO_HALT(); } -void TensorObj::printDataFloat() const { +void TensorObj::printDataFloat(float *ptr) const { std::cout << "Tensor: " << guid << std::endl; auto numDims = shape.size(); auto dimSzVec = std::vector(numDims, 1); - auto ptr = data->getPtr(); dimSzVec[numDims - 1] = shape[numDims - 1]; for (int i = numDims - 1; i != 0; --i) dimSzVec[i - 1] = dimSzVec[i] * shape[i - 1]; for (size_t i = 0, iEnd = size(); i < iEnd; ++i) { + if (iEnd > 1000 && i > 20 && i < iEnd - 20) { + printf("... , "); + i = iEnd - 20; + continue; + } for (size_t j = 0; j < numDims; ++j) { if (i % dimSzVec[j] == 0) { std::cout << "["; @@ -94,12 +112,11 @@ void TensorObj::printDataFloat() const { } } -void TensorObj::printDataUint32_t() const { +void TensorObj::printDataUint32_t(uint32_t *ptr) const { IT_ASSERT(data != nullptr); std::cout << "Tensor: " << guid << std::endl; auto numDims = shape.size(); auto dimSzVec = std::vector(numDims, 1); - auto ptr = data->getPtr(); dimSzVec[numDims - 1] = shape[numDims - 1]; for (int i = numDims - 1; i != 0; --i) dimSzVec[i - 1] = dimSzVec[i] * shape[i - 1]; @@ -122,7 +139,7 @@ void TensorObj::printDataUint32_t() const { } } -bool TensorObj::equalData(const Tensor &rhs) const { +bool TensorObj::equalData(const Tensor &rhs, double relativeError) const { IT_ASSERT(data != nullptr); IT_ASSERT(rhs->data != nullptr); IT_ASSERT(getDType() == rhs->getDType()); @@ -132,10 +149,11 @@ bool TensorObj::equalData(const Tensor &rhs) const { return false; if (getDType() == DataType::UInt32) return equalDataImpl(getRawDataPtr(), - rhs->getRawDataPtr(), size()); + rhs->getRawDataPtr(), size(), 0); else if (getDType() == DataType::Float32) return equalDataImpl(getRawDataPtr(), - rhs->getRawDataPtr(), size()); + rhs->getRawDataPtr(), size(), + relativeError); else IT_TODO_HALT(); } diff --git a/src/kernels/cuda/matmul.cc b/src/kernels/cuda/matmul.cc index cee15379..0b15e4b6 100644 --- a/src/kernels/cuda/matmul.cc +++ b/src/kernels/cuda/matmul.cc @@ -50,10 +50,23 @@ class matmulCublas : public Kernel { // TODO:use compute type cublasStatus_t stat; if (b > 1) { + // Support batch broadcast with zero stride + int dimA = op->getInputs(0)->getDims().size(); + int dimB = op->getInputs(1)->getDims().size(); + long long strideA = + (dimA == 2 || + (dimA == 3 && op->getInputs(0)->getDims()[0] == 1)) + ? 0 // Broadcast the batch dimension if batch size is 1 + : m * k; + long long strideB = + (dimB == 2 || + (dimB == 3 && op->getInputs(1)->getDims()[0] == 1)) + ? 0 // Broadcast the batch dimension if batch size is 1 + : n * k; stat = cublasGemmStridedBatchedEx( context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData, - CUDA_R_32F, ldb, k * n, inAData, CUDA_R_32F, lda, m * k, &beta, - outData, CUDA_R_32F, ldc, m * n, b, CUDA_R_32F, + CUDA_R_32F, ldb, strideB, inAData, CUDA_R_32F, lda, strideA, + &beta, outData, CUDA_R_32F, ldc, m * n, b, CUDA_R_32F, (cublasGemmAlgo_t)record->algo); } else { stat = cublasGemmEx( @@ -61,6 +74,8 @@ class matmulCublas : public Kernel { CUDA_R_32F, ldb, inAData, CUDA_R_32F, lda, &beta, outData, CUDA_R_32F, ldc, CUDA_R_32F, (cublasGemmAlgo_t)record->algo); } + // if (stat != CUBLAS_STATUS_SUCCESS) + // cout << cublasGetErrorString(stat); return (stat == CUBLAS_STATUS_SUCCESS); } @@ -79,6 +94,8 @@ class matmulCublas : public Kernel { const RuntimeObj *_context) const override { auto context = dynamic_cast(_context); auto op = as(_op); + IT_ASSERT(context); + IT_ASSERT(op); auto ret = make_ref(); ret->time = std::numeric_limits::max(); for (int i = 0; i < N_ALGO; i++) { @@ -91,9 +108,8 @@ class matmulCublas : public Kernel { if (rcd->time < ret->time) ret = rcd; } - IT_ASSERT(ret->time < std::numeric_limits::max(), "No valid " - "algorithm " - "found"); + IT_ASSERT(ret->time < std::numeric_limits::max(), + "No valid algorithm found for " + op->toString()); return ret; } }; diff --git a/src/kernels/cuda/membound_TVM.cc b/src/kernels/cuda/membound_tvm_extract_source.cc similarity index 71% rename from src/kernels/cuda/membound_TVM.cc rename to src/kernels/cuda/membound_tvm_extract_source.cc index 16c44a72..e4b76e60 100644 --- a/src/kernels/cuda/membound_TVM.cc +++ b/src/kernels/cuda/membound_tvm_extract_source.cc @@ -1,7 +1,11 @@ +#ifdef INFINI_USE_TVM #include "core/kernel.h" #include "cuda/cuda_runtime.h" #include "ffi/ffi_embed.h" #include "nnet/Visitor/AsTVMVisitor.h" +#include "nnet/Visitor/CheckOOBVisitor.h" +#include "nnet/Visitor/HashVisitor.h" +#include "nnet/Visitor/MergeMemboundMutator.h" #include "nvrtc.h" #include "operators/membound.h" #include "operators/pooling.h" @@ -17,11 +21,12 @@ class TVMRecordObj : public PerfRecordObj { std::string log, ptx; std::vector invokeParams; std::string kernelName; + HashType simplifiedExprHash; }; using TVMRecord = Ref; -class MemboundTVM : public Kernel { +class MemboundTVMExtractSource : public Kernel { public: void compute(const Operator &_op, const PerfRecord &record, const RuntimeObj *_context) const override { @@ -65,6 +70,11 @@ class MemboundTVM : public Kernel { return "var_" + std::to_string(t->getGuid()); } + bool checkOOB(nnet::Expr expr) const { + return nnet::CheckOOBVisitor().checkRangeOp( + nnet::as(expr)); + } + // Premise: op is idempotent since it is called multiple times. PerfRecord tune(const Operator &_op, const RuntimeObj *_context) const override { @@ -73,10 +83,18 @@ class MemboundTVM : public Kernel { auto context = dynamic_cast(_context); // invoke Ansor to tune a membound kernel - std::string func = "mem_bound_" + std::to_string(op->getGuid()); - std::string kernelName = func + "_kernel0"; nnet::AsTVMVisitor visitor; - visitor.dispatch(op->getNnetExpr()); + IT_ASSERT(!checkOOB(op->getNnetExpr())); + // fuse stages in nnet expr to reduce kernels generated by TVM + auto expr = op->getNnetExpr(); + if (auto mergedExpr = + nnet::MergeMemboundMutator({expr}).merge(false, true)) + expr = mergedExpr; + + nnet::HashVisitor hashVisitor; + HashType hashCode = hashVisitor.getHash(expr); + + visitor.dispatch(expr); auto &&stmts = visitor.getStmts(); auto &&inShapes = visitor.getInputShapes(); auto &&outShape = visitor.getOutputShape(); @@ -85,10 +103,14 @@ class MemboundTVM : public Kernel { for (auto &&in : op->getInputs()) { inputs.emplace_back(getVarName(in)); } - std::string output = getVarName(op->getOutput()); + const std::string output = getVarName(op->getOutput()); + + const std::string func = "membound_" + std::to_string(hashCode); + const std::string kernelName = func + "_kernel0"; auto res = getAnsorCode( inShapes, std::vector(inShapes.size(), "float32"), - outShape, "float32", stmts, func, inputs, output); + outShape, "float32", stmts, func, inputs, output, op->toString(), + expr->toReadable(), hashCode); // compile the kernel auto funcCode = res.first; @@ -119,6 +141,7 @@ class MemboundTVM : public Kernel { nvrtcGetPTX(prog, ret->ptx.data()); ret->invokeParams = invokeParams; ret->kernelName = kernelName; + ret->simplifiedExprHash = hashCode; // prepare for evaluation CUmodule module; @@ -151,20 +174,43 @@ class MemboundTVM : public Kernel { return std::dynamic_pointer_cast(ret); } + /// @brief + /// @param inDims + /// @param inDTypes + /// @param outDims + /// @param outDType + /// @param lambda + /// @param funcName Generated function name + /// @param inputNames Input array names in the generated invocation code. + /// @param outputName Output array names in the generated invocation code. + /// @param nnetExpressionString Save expr in string for logging. + /// @param nnetSimplifiedExprString Save simplified expr in string for + /// logging. + /// @param hashCode (optional) Hash code of the input expression for kernel + /// cache. + /// @return std::pair> getAnsorCode(const std::vector> &inDims, const std::vector &inDTypes, const std::vector &outDims, const std::string &outDType, const std::string &lambda, const std::string &funcName, const std::vector &inputNames, - const std::string &outputName) const { + const std::string &outputName, + const std::string &nnetExprString, + const std::string &nnetSimplifiedExprString, + const HashType hashCode) const { std::string funcCode; std::vector invokeParams; try { start_interpreter(); - auto func = py::module::import("cpp_plugin").attr("gen_ansor_op"); - py::tuple code = func(inDims, inDTypes, outDims, outDType, lambda, - funcName, inputNames, outputName); + // Use static to avoid re-importing the module. Re-importing results + // in cuBLAS failure, whose root cause is not identified yet. + static auto func = + py::module::import("cpp_plugin").attr("gen_ansor_op"); + py::tuple code = + func(inDims, inDTypes, outDims, outDType, lambda, funcName, + inputNames, outputName, nnetExprString, + nnetSimplifiedExprString, std::to_string(hashCode)); funcCode = py::str(code[0]); auto temp = py::list(code[3]); for (int i = 0; i < 6; ++i) { @@ -183,6 +229,9 @@ class MemboundTVM : public Kernel { } }; -REGISTER_KERNEL(Device::CUDA, OpType::MemBound, DataType::Float32, MemboundTVM, - "Memobund_TVM_Ansor"); +// REGISTER_KERNEL(Device::CUDA, OpType::MemBound, DataType::Float32, +// MemboundTVMExtractSource, +// "Memobund_TVM_Ansor_extract_source"); }; // namespace infini + +#endif diff --git a/src/kernels/cuda/membound_tvm_packed_function.cc b/src/kernels/cuda/membound_tvm_packed_function.cc new file mode 100644 index 00000000..8086518d --- /dev/null +++ b/src/kernels/cuda/membound_tvm_packed_function.cc @@ -0,0 +1,224 @@ +#ifdef INFINI_USE_TVM +#include "core/kernel.h" +#include "cuda/cuda_runtime.h" +#include "dlpack/dlpack.h" +#include "ffi/ffi_embed.h" +#include "nnet/Visitor/AsTVMVisitor.h" +#include "operators/membound.h" +#include "operators/pooling.h" +#include "tvm/runtime/module.h" +#include "tvm/runtime/packed_func.h" + +namespace py = pybind11; + +namespace infini { + +using DLTensorHolder = pair>>; + +class TVMRecordObj : public PerfRecordObj { + public: + std::string kernelName; + HashType simplifiedExprHash; + std::string dllPath; + std::string funcName; + std::vector inputIdx; +}; + +using TVMRecord = Ref; + +class MemboundTVMPackedFunction : public Kernel { + public: + void compute(const Operator &_op, const PerfRecord &record, + const RuntimeObj *_context) const override { + auto op = as(_op); + // auto context = dynamic_cast(_context); + auto tvmRecord = std::dynamic_pointer_cast(record); + tvm::runtime::PackedFunc packedFunc = + getPackedFunction(tvmRecord->dllPath, tvmRecord->funcName); + IT_ASSERT(packedFunc != nullptr); + + // prepare inputs and outputs + vector inputsHolder; + for (auto idx : tvmRecord->inputIdx) { + inputsHolder.emplace_back( + convertTensorToDLTensor(op->getInputs()[idx])); + } + DLTensorHolder outputHolder = convertTensorToDLTensor(op->getOutput()); + + // make tvm arg and rv + pair, vector> preArgs = + convertInOutToTVMArgs(inputsHolder, outputHolder); + tvm::runtime::TVMRetValue rv; + tvm::runtime::TVMArgs args(preArgs.first.data(), preArgs.second.data(), + preArgs.first.size()); + + packedFunc.CallPacked(args, &rv); + } + + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + IT_ASSERT(false, "A TVM record is required for membound kernel."); + } + + // Premise: op is idempotent since it is called multiple times. + PerfRecord tune(const Operator &_op, + const RuntimeObj *_context) const override { + TVMRecord ret = std::make_shared(); + auto op = as(_op); + auto context = dynamic_cast(_context); + + // invoke Ansor to tune a membound kernel + auto [expr, hash] = op->getSimplifiedNnetExpr(); + nnet::AsTVMVisitor visitor; + visitor.dispatch(expr); + auto &&stmts = visitor.getStmts(); + auto &&inShapes = visitor.getInputShapes(); + auto &&outShape = visitor.getOutputShape(); + + const std::string func = "membound_" + std::to_string(hash); + const std::string kernelName = func + "_kernel0"; + // Set the dllPath directly when debugging + auto dllPath = getAnsorDLL( + inShapes, std::vector(inShapes.size(), "float32"), + outShape, "float32", stmts, func, op->toString(), + expr->toReadable(), hash); + + // remap input + vector inputIdx; + int numInputs = op->getInputs().size(); + for (int i = 0; i < numInputs; ++i) { + string inputName = visitor.getInputs()[i]; + int j = 0; + for (; j < numInputs; ++j) { + if (inputName == op->getNnetInputs()[j]->getName()) + break; + } + inputIdx.emplace_back(j); + } + + tvm::runtime::PackedFunc packedFunc = getPackedFunction(dllPath, func); + IT_ASSERT(packedFunc != nullptr); + + // prepare inputs and outputs + vector inputsHolder; + for (auto idx : inputIdx) { + inputsHolder.emplace_back( + convertTensorToDLTensor(op->getInputs()[idx])); + } + DLTensorHolder outputHolder = convertTensorToDLTensor(op->getOutput()); + + // make tvm arg and rv + pair, vector> preArgs = + convertInOutToTVMArgs(inputsHolder, outputHolder); + tvm::runtime::TVMRetValue rv; + tvm::runtime::TVMArgs args(preArgs.first.data(), preArgs.second.data(), + preArgs.first.size()); + + ret->time = timeit([&]() { packedFunc.CallPacked(args, &rv); }, + [&]() { context->sync(); }); + ret->kernelName = kernelName; + ret->dllPath = dllPath; + ret->funcName = func; + ret->inputIdx = inputIdx; + + return std::dynamic_pointer_cast(ret); + } + + /// @brief + /// @param inDims + /// @param inDTypes + /// @param outDims + /// @param outDType + /// @param lambda + /// @param funcName Generated function name + /// @param nnetExpressionString Save expr in string for logging. + /// @param nnetSimplifiedExprString Save simplified expr in string for + /// logging. + /// @param hashCode (optional) Hash code of the input expression for kernel + /// cache. + /// @return + std::string getAnsorDLL(const std::vector> &inDims, + const std::vector &inDTypes, + const std::vector &outDims, + const std::string &outDType, + const std::string &lambda, + const std::string &funcName, + const std::string &nnetExprString, + const std::string &nnetSimplifiedExprString, + const HashType hashCode) const { + std::string dllPath; + try { + start_interpreter(); + // Use static to avoid re-importing the module. Re-importing results + // in cuBLAS failure, whose root cause is not identified yet. + static auto func = + py::module::import("cpp_plugin").attr("gen_ansor_so"); + py::tuple code = + func(inDims, inDTypes, outDims, outDType, lambda, funcName, + nnetExprString, nnetSimplifiedExprString, + std::to_string(hashCode)); + dllPath = py::str(code[0]); + } catch (py::error_already_set &e) { + if (e.matches(PyExc_ImportError)) { + std::cerr << "Import Error. Don't forget to set environment " + "variable PYTHONPATH to contain " + "/python" + << std::endl; + } + throw; + } + + return dllPath; + } + + tvm::runtime::PackedFunc getPackedFunction(string path, + string functionName) const { + tvm::runtime::Module mod = tvm::runtime::Module::LoadFromFile(path); + return mod.GetFunction(functionName); + } + + DLTensorHolder convertTensorToDLTensor(const Tensor &tensor) const { + IT_ASSERT(tensor->getRuntime()->isCuda()); + // The lifecycle of shapeInt64 is managed by the caller. + auto shapeInt64 = make_ref>(); + for (auto v : tensor->getDims()) + shapeInt64->push_back(v); + DLTensor ret{ + .data = tensor->getRawDataPtr(), + .device = DLDevice{.device_type = kDLCUDA, .device_id = 0}, + .ndim = (int32_t)shapeInt64->size(), + .dtype = + DLDataType{.code = (uint8_t)kDLFloat, .bits = 32, .lanes = 1}, + .shape = static_cast(shapeInt64->data()), + .strides = nullptr, + .byte_offset = 0, + }; + return {ret, shapeInt64}; + } + + pair, vector> + convertInOutToTVMArgs(const vector &inputs, + const DLTensorHolder &output) const { + vector values; + vector type_codes; + + // The order of inputs and outputs is consistant with definition of TVM + // computation in Python, which is determined by AsTVMVisitor. + values.emplace_back(TVMValue{.v_handle = (void *)&output.first}); + type_codes.emplace_back(kTVMDLTensorHandle); + + for (auto &in : inputs) { + values.emplace_back(TVMValue{.v_handle = (void *)&in.first}); + type_codes.emplace_back(kTVMDLTensorHandle); + } + + return {values, type_codes}; + } +}; + +REGISTER_KERNEL(Device::CUDA, OpType::MemBound, DataType::Float32, + MemboundTVMPackedFunction, + "Memobund_TVM_Ansor_packed_funciton"); +}; // namespace infini + +#endif diff --git a/src/nnet/Visitor/AsTVMVisitor.cc b/src/nnet/Visitor/AsTVMVisitor.cc index 5044eb9e..45456944 100644 --- a/src/nnet/Visitor/AsTVMVisitor.cc +++ b/src/nnet/Visitor/AsTVMVisitor.cc @@ -27,13 +27,16 @@ std::string AsTVMVisitor::visit_(const BinaryOp &c) { } } std::string AsTVMVisitor::visit_(const Func &c) { + string nested = dispatch(c->getObject()); switch (c->getFuncType()) { case FuncType::Relu: // TODO: Deduce the dtype - return "te.max(" + dispatch(c->getObject()) + - ", tvm.tir.const(0, 'float32'))"; + return "te.max(" + nested + ", tvm.tir.const(0, 'float32'))"; case FuncType::Tanh: - return "te.tanh(" + dispatch(c->getObject()) + ")"; + return "te.tanh(" + nested + ")"; + case FuncType::PRelu: + return "tir.if_then_else(0.0 < " + nested + ", " + nested + + ", (0.25 * " + nested + "))"; default: assert(false); } @@ -114,6 +117,11 @@ std::string AsTVMVisitor::visit_(const Subscript &c) { str += " - " + std::to_string(rangeOp->getLoopVarRanges()[i].second.first - rangeOp->getPaddings(i)); + } else if (c->getObject()->getType() == NodeType::TensorNodeType) { + auto tensor = as(c->getObject()); + if (auto pad_i = tensor->getPadding(i); pad_i > 0) { + str += " + " + std::to_string(pad_i); + } } } str += "]"; @@ -138,6 +146,24 @@ std::string AsTVMVisitor::visit_(const Tensor &c) { } stmt += "), name='" + c->getName() + "')"; stmts += stmt + "\n"; + + if (c->hasPadding()) { + std::string name_after_pad = "pad_" + c->getName(); + pythonVars.emplace_back(name_after_pad); + // inputs.emplace_back(name_after_pad); + std::string pad_tuple = "("; + for (auto pad : c->getPaddings()) { + pad_tuple += std::to_string(pad) + ", "; + } + pad_tuple += ")"; + + std::string pad_stmt = name_after_pad + " = " + "topi.nn.pad(" + + c->getName() + ", " + pad_tuple + ", " + + pad_tuple + ", 0.0, \"" + name_after_pad + "\")"; + stmts += pad_stmt + "\n"; + return name_after_pad; + } + return c->getName(); } std::string AsTVMVisitor::getStmts() const { diff --git a/src/nnet/Visitor/MergeMemboundMutator.cc b/src/nnet/Visitor/MergeMemboundMutator.cc index 9fce6928..1b521bde 100644 --- a/src/nnet/Visitor/MergeMemboundMutator.cc +++ b/src/nnet/Visitor/MergeMemboundMutator.cc @@ -5,7 +5,7 @@ namespace nnet { -Expr MergeMemboundMutator::merge(bool allowEmptyMembound) { +Expr MergeMemboundMutator::merge(bool allowEmptyMembound, bool allowFailure) { // FIXME: fix empty expression in membound assert(kernels.size() >= 1); if (checkEmpty()) { @@ -27,19 +27,30 @@ Expr MergeMemboundMutator::merge(bool allowEmptyMembound) { assert(CheckOOBVisitor().checkRangeOp(curRangeOp) == false); auto summand = curRangeOp->getSummand(); if (auto subscriptOp = as(summand)) { + // Try merging the current and next stages if (auto mergedExpr = rule4StageMerging(*curExpr, true)) { // dbg(*curExpr, mergedExpr); *curExpr = mergedExpr; merged = true; break; } + // If merging fails, try the next stage curExpr = subscriptOp->getObjectPtr(); nnet_assert(*curExpr != nullptr, __LINE__); } else if (auto funcOp = as(summand)) { - // Relu({...}[i,j]) - curExpr = funcOp->getObject()->getObjectPtr(); - } else - nnet_unimplemented_halt(); + // If the object of FuncNode is a subscript, like + // Relu({...}[i,j]), we can further merge it. Otherwise, like + // Relu(A[i]+B[j]), we cannot. + if (auto sub = as(funcOp->getObject())) + curExpr = sub->getObjectPtr(); + else + break; + } else { + if (allowFailure) + return nullptr; + else + nnet_unimplemented_halt(); + } } } while (merged); return expr; diff --git a/src/nnet/Visitor/hashVisitor.cc b/src/nnet/Visitor/hashVisitor.cc index 359e2335..7a7314c7 100644 --- a/src/nnet/Visitor/hashVisitor.cc +++ b/src/nnet/Visitor/hashVisitor.cc @@ -153,4 +153,11 @@ HashType HashVisitor::visit_(const Var &c) { return varHash[c]; } +HashType HashVisitor::visit_(const Func &c) { + HashType objHash = dispatch(c->getObject()); + return hash(binPrefix, + hash((((HashType)c->getFuncType()) + 10086), objHash)); + return 0; +} + } // namespace nnet \ No newline at end of file diff --git a/src/nnet/expr.cc b/src/nnet/expr.cc index ea25bd5b..09fbb60f 100644 --- a/src/nnet/expr.cc +++ b/src/nnet/expr.cc @@ -90,6 +90,14 @@ size_t TensorNode::getOffset(const vector &idx) { return offset; } +bool TensorNode::hasPadding() { + for (auto pad : paddings) { + if (pad > 0) + return true; + } + return false; +} + string RangeOpNode::toReadable() const { string ret; for (int i = 0; i < IterationType::NumIterationType; ++i) { @@ -264,10 +272,15 @@ string FuncNode::toReadable() const { ret += "Relu"; else if (funcType == FuncType::Tanh) ret += "Tanh"; + else if (funcType == FuncType::PRelu) + ret += "PRelu"; else nnet_unimplemented_halt(); - ret += "( ... " + serializeVec(object->getIndex()) + ")\n {" + - object->getObject()->toReadable() + "}"; + if (auto sub = as(object)) + ret += "( ... " + serializeVec(sub->getIndex()) + ")\n {" + + sub->getObject()->toReadable() + "}"; + else + ret += "(" + object->toReadable() + ")"; return ret; } @@ -380,6 +393,7 @@ int64_t TensorNode::getSize() const { size *= len; return size; } + int RangeOpNode::getPaddings(int dim) const { return dim < (int)paddings.size() ? paddings[dim] : 0; } @@ -445,8 +459,8 @@ vector RangeOpNode::getOutputRanges() const { } void FuncNode::setObject(Expr e) { - object = as(e); - nnet_assert(object, "Illegal subscripted object"); + nnet_assert(e->isScalar(), "FuncNode operates on scalars"); + object = e; } } // namespace nnet diff --git a/src/nnet/iterator_table.cc b/src/nnet/iterator_table.cc index b89769cd..4934b939 100644 --- a/src/nnet/iterator_table.cc +++ b/src/nnet/iterator_table.cc @@ -574,7 +574,8 @@ Expr ConvTransPattern::getExpr(Tensor A, Tensor K, int N, int C, int H, int W, auto subA = makeSubscript(A, {n, x1 + r - 1, y1 + s - 1, f}); auto subK = - makeSubscript(K, {(R - 2) - 2 * r + x2, (S - 2) - 2 * s + y2, f, c}); + // makeSubscript(K, {(R - 2) - 2 * r + x2, (S - 2) - 2 * s + y2, f, c}); + makeSubscript(K, {f, (R - 2) - 2 * r + x2, (S - 2) - 2 * s + y2, c}); // x1=(h+1)//2, x2=(h+1)%2, y1=(w+1)//2 auto range1 = makeRangeOperator( diff --git a/src/nnet/nmutator.cc b/src/nnet/nmutator.cc index b26806c3..50575082 100644 --- a/src/nnet/nmutator.cc +++ b/src/nnet/nmutator.cc @@ -7,15 +7,18 @@ #include "operators/conv.h" #include "operators/matmul.h" #include "operators/membound.h" +#include "operators/reshape.h" namespace infini { NMutator::NMutator(Mode mode) : Mutator(10), mode{mode} { - IT_ASSERT(mode != Mode::RuleBased, "Use RuleBased in the other ctor."); + IT_ASSERT(mode != Mode::RuleBased, "Specify rules for the RuleBased mode."); } -NMutator::NMutator(const std::vector &derivationRules) - : Mutator(10), mode{Mode::RuleBased}, derivationRules{derivationRules} {} +NMutator::NMutator(Mode mode, const std::vector &derivationRules) + : Mutator(10), mode{Mode::RuleBased}, derivationRules{derivationRules} { + IT_ASSERT(mode == Mode::RuleBased); +} NMutator::~NMutator() {} @@ -69,9 +72,10 @@ void NMutator::runSingleOpToNaiveMembound(Graph in_graph, } void NMutator::runSingleOp(Graph in_graph, std::vector &out_graphs) { - IT_TODO_HALT(); - // OpVec computeOps = in_graph->getComputeOps(); - // if (infini::Graph g = transformTConv1x1(computeOps[0])) { + OpVec computeOps = in_graph->getComputeOps(); + IT_ASSERT(computeOps.size() == 1); + + /* if (infini::Graph g = transformTConv1x1(computeOps[0])) { // out_graphs.emplace_back(g); // return; // } @@ -95,39 +99,40 @@ void NMutator::runSingleOp(Graph in_graph, std::vector &out_graphs) { // // out_graphs.emplace_back(graph); // // return; // // } + */ - // auto expr = opToExpression(computeOps[0]); - // if (!expr) - // return; + auto expr = opToExpression(computeOps[0]); + if (!expr) + return; - // nnet::Derivator derivator(maxDepth); - // nnet::Formula conv_9x9(expr, 0); - // // const std::vector rules{3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90}; // - // Tconv - // // const std::vector rules{1, 7, 7, 2, 8, 6, 6}; // G2BMM - // if (mode == Mode::Normal) { - // derivator.search(conv_9x9, 0); - // } else if (mode == Mode::RuleBased) { - // dbg(derivationRules); - // derivator.ruleBasedDFS(conv_9x9, 0, derivationRules); - // } else - // nnet_assert(0, "Unknown mode"); - // const auto &candidates = derivator.getCandidates(); - // dbg(candidates.size()); - // // derivator.print(); - // for (const auto &candidate : candidates) { - // // dbg(nnet::FullPrinterVisitor().print(candidate.root)); - // if (auto g = expressionToGraph(candidate.root, in_graph)) { - // out_graphs.emplace_back(g); - // } - // // break; // HACK:Debug only for the first subgraph + nnet::Derivator derivator(maxDepth); + nnet::Formula conv_9x9(expr, 0); + // const std::vector rules{3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90}; + // ConvTraspose + // const std::vector rules{1, 7, 7, 2, 8, 6, 6}; // G2BMM + if (mode == Mode::Normal) { + derivator.search(conv_9x9, 0); + } else if (mode == Mode::RuleBased) { + dbg(derivationRules); + derivator.ruleBasedDFS(conv_9x9, 0, derivationRules); + } else + IT_TODO_HALT_MSG("Unknown NMutator search mode."); + const auto &candidates = derivator.getCandidates(); + dbg(candidates.size()); + // derivator.print(); + for (const auto &candidate : candidates) { + // dbg(nnet::FullPrinterVisitor().print(candidate.root)); + if (auto g = expressionToGraph(candidate.root, in_graph)) { + out_graphs.emplace_back(g); + } + // break; // HACK:Debug only for the first subgraph + } + // dbg(out_graphs); + // for (auto graph : out_graphs) { + // graph->print(); // } - // // dbg(out_graphs); - // // for (auto graph : out_graphs) { - // // graph->print(); - // // } - // cntStates += derivator.getNumIntermediateStates(); - // cntCandidates += derivator.getNumCandidates(); + cntStates += derivator.getNumIntermediateStates(); + cntCandidates += derivator.getNumCandidates(); } void NMutator::runMultipleOps(Graph in_graph, std::vector &out_graphs) { @@ -245,7 +250,7 @@ nnet::Expr NMutator::opToExpression(Operator op) { std::vector{0, 0, ph, pw}); const auto K = nnet::makeTensor("K", KT->getDims()); return nnet::ConvPattern::getExpr(A, K, n, c, h, w, f, r, s); - } else if (auto convOp = as(op)) { + } else if (auto convOp = as(op)) { const auto &AT = convOp->getInputs()[0]; const auto &KT = convOp->getInputs()[1]; inputsNameNToTensorT["A"] = AT; @@ -304,99 +309,119 @@ nnet::Expr NMutator::opToExpression(Operator op) { } infini::Graph NMutator::expressionToGraph(nnet::Expr expr, Graph in_graph) { - IT_TODO_HALT(); - // auto g = make_ref(); - // nnet::FullPrinterVisitor fullVisitor; - // const auto &tensorQueueN = fullVisitor.traverse(expr); - // // Build tensors: Skip the first one, which is output - // auto nameNToTensorT = inputsNameNToTensorT; - // for (size_t i = 1; i < tensorQueueN.size(); ++i) { - // const auto &[nameN, routineN, tensorN] = tensorQueueN[i]; - // // dbg(nameN, routineN, tensorN); - // if (!routineN) { - // // This is an inputs - // assert(nameNToTensorT.count(nameN)); - // } else { - // assert(!nameNToTensorT.count(nameN)); - // nameNToTensorT[nameN] = g->addTensor(tensorN->getShape()); - // } - // } - // const auto &outputsPET = in_graph->getOutputs(); - // if (outputsPET.size() != 1) { - // nnet_unimplemented_continue(); - // return nullptr; - // } - // nameNToTensorT[std::get<0>(tensorQueueN.at(0))] = outputsPET[0]; - // // Build computation graph in PET: - // for (int i = tensorQueueN.size() - 1; i >= 0; --i) { - // const auto &[outputNameN, routineN, tensorN] = tensorQueueN[i]; - // if (!routineN) - // continue; - // // dbg(outputNameN, routineN, tensorN, routineN->getType()); - // if (auto op = nnet::as(routineN)) { - // // g->conv(i8, w9, 2, 2); - // std::vector inputsN = op->getInputs(); - // auto A = nameNToTensorT.at(inputsN[0]->getName()); - // auto K = nameNToTensorT.at(inputsN[1]->getName()); - // auto output = nameNToTensorT.at(outputNameN); - // const auto &[ph, pw, sh, sw, dh, dw] = op->getArgs(); - // g->conv(A, K, output, ph, pw, sh, sw, dh, dw); - // } else if (auto op = nnet::as(routineN)) { - // assert(op->getInputs().size() == 1); - // nnet::MatchReshapeVisitor matchReshapeVisitor; - // if (matchReshapeVisitor(op->getExpr())) { - // auto input = - // nameNToTensorT.at(op->getInputs().at(0)->getName()); - // auto output = nameNToTensorT.at(outputNameN); - // g->reshape(input, output); - // } else { - // TensorVec inputsPET; - // TensorVec outputsPET = {nameNToTensorT.at(outputNameN)}; - // for (const auto &inputN : op->getInputs()) - // inputsPET.emplace_back( - // nameNToTensorT.at(inputN->getName())); - // // Re-estimate time here. - // ssize_t cnt = 0; - // for (const auto tensor : inputsPET) - // cnt += tensor->size(); - // for (const auto tensor : outputsPET) - // cnt += tensor->size(); - // g->membound(inputsPET, outputsPET, op->getInputs(), - // op->getExpr(), memboundTime(cnt)); - // } - // } else if (auto op = nnet::as(routineN)) { - // assert(op->getInputs().size() == 2); - // nnet::Tensor AN = op->getInputs()[0]; - // nnet::Tensor BN = op->getInputs()[1]; - // TensorVec inputsPET = {nameNToTensorT.at(AN->getName()), - // nameNToTensorT.at(BN->getName())}; - // TensorVec outputsPET = {nameNToTensorT.at(outputNameN)}; - // const auto &[b, m, n, k, transa, transb] = op->getArgs(); - // g->matmul(inputsPET[0], inputsPET[1], outputsPET[0], transa, - // transb); - // } else if (auto op = nnet::as(routineN)) { - // assert(op->getInputs().size() == 2); - // nnet::Tensor AN = op->getInputs()[0]; - // nnet::Tensor BN = op->getInputs()[1]; - // TensorVec inputsPET = {nameNToTensorT.at(AN->getName()), - // nameNToTensorT.at(BN->getName())}; - // TensorVec outputsPET = {nameNToTensorT.at(outputNameN)}; - // const auto &[b, m, w, k, dilation] = op->getArgs(); - // g->g2bmm(inputsPET[0], inputsPET[1], outputsPET[0], w, dilation); - // } else if (auto op = nnet::as(routineN)) { - // assert(op->getInputs().size() == 2); - // nnet::Tensor AN = op->getInputs()[0]; - // nnet::Tensor BN = op->getInputs()[1]; - // TensorVec inputsPET = {nameNToTensorT.at(AN->getName()), - // nameNToTensorT.at(BN->getName())}; - // TensorVec outputsPET = {nameNToTensorT.at(outputNameN)}; - // const auto &[b, m, w, n, dilation] = op->getArgs(); - // g->gbmml(inputsPET[0], inputsPET[1], outputsPET[0], dilation); - // } - // } - // g->updateConnection(); - // Graph graph = new Graph(g->getOperators()); - // return graph; + auto g = make_ref(runtime); + nnet::FullPrinterVisitor fullVisitor; + // Get tensors in the reversed topological order + const auto &tensorQueueN = fullVisitor.traverse(expr); + dbg(fullVisitor.print(expr)); + + // Build a map: name in nnet -> tensors in infini + // Add input tensors to the map + std::map nameNToTensorT; + for (const auto &[k, v] : inputsNameNToTensorT) + nameNToTensorT[k] = g->cloneTensor(v); + + // Add output tensors to the map + const auto &outputsT = in_graph->getOutputs(); + if (outputsT.size() != 1) { + nnet_unimplemented_continue(); + return nullptr; + } + nameNToTensorT[std::get<0>(tensorQueueN.at(0))] = + g->cloneTensor(outputsT[0]); + // Skip the first tensor, which is output and should be created by clone + for (size_t i = 1; i < tensorQueueN.size(); ++i) { + const auto &[nameN, routineN, tensorN] = tensorQueueN[i]; + // dbg(nameN, routineN, tensorN); + if (!routineN) { + // this tensor is an input as it is not contrusted by a routine + IT_ASSERT(nameNToTensorT.count(nameN), + "Missing an input tensor in graph or a rountine for this " + "tensor."); + } else { // this tensor is an intermediate result + IT_ASSERT(!nameNToTensorT.count(nameN), + "An NNET tensor appears twice or it is an input tensor " + "with routine specified."); + nameNToTensorT[nameN] = g->addTensor(tensorN->getShape()); + } + } + + // Build computation graph in InfiniTensor + for (int i = tensorQueueN.size() - 1; i >= 0; --i) { + const auto &[outputNameN, routineN, tensorN] = tensorQueueN[i]; + if (!routineN) + continue; + // dbg(outputNameN, routineN, tensorN, routineN->getType()); + if (auto op = nnet::as(routineN)) { + std::vector inputsN = op->getInputs(); + auto A = nameNToTensorT.at(inputsN[0]->getName()); + auto K = nameNToTensorT.at(inputsN[1]->getName()); + auto output = nameNToTensorT.at(outputNameN); + const auto &[ph, pw, sh, sw, dh, dw] = op->getArgs(); + g->addOpWithOutputs(A, K, output, ph, pw, sh, sw, dh, dw); + } else if (auto op = nnet::as(routineN)) { + assert(op->getInputs().size() == 1); + nnet::MatchReshapeVisitor matchReshapeVisitor; + // If this routine only change the shape, translate it to a Reshape + if (matchReshapeVisitor(op->getExpr())) { + auto input = + nameNToTensorT.at(op->getInputs().at(0)->getName()); + auto output = nameNToTensorT.at(outputNameN); + g->addOpWithOutputs(input, output, + output->getDims()); + } else { + TensorVec inputsPET; + TensorVec outputsPET = {nameNToTensorT.at(outputNameN)}; + for (const auto &inputN : op->getInputs()) + inputsPET.emplace_back( + nameNToTensorT.at(inputN->getName())); + // Re-estimate time here. + ssize_t cnt = 0; + for (const auto &tensor : inputsPET) + cnt += tensor->size(); + for (const auto &tensor : outputsPET) + cnt += tensor->size(); + dbg(inputsPET, outputsPET, op->getInputs(), op->getExpr(), + memboundTime(cnt)); + g->addOpWithOutputs(inputsPET, outputsPET, + op->getInputs(), op->getExpr(), + memboundTime(cnt)); + } + } else if (auto op = nnet::as(routineN)) { + assert(op->getInputs().size() == 2); + nnet::Tensor AN = op->getInputs()[0]; + nnet::Tensor BN = op->getInputs()[1]; + TensorVec inputsPET = {nameNToTensorT.at(AN->getName()), + nameNToTensorT.at(BN->getName())}; + TensorVec outputsPET = {nameNToTensorT.at(outputNameN)}; + const auto &[b, m, n, k, transa, transb] = op->getArgs(); + g->addOpWithOutputs(inputsPET[0], inputsPET[1], + outputsPET[0], transa, transb); + } + // TODO + // else if (auto op = nnet::as(routineN)) { + // assert(op->getInputs().size() == 2); + // nnet::Tensor AN = op->getInputs()[0]; + // nnet::Tensor BN = op->getInputs()[1]; + // TensorVec inputsPET = {nameNToTensorT.at(AN->getName()), + // nameNToTensorT.at(BN->getName())}; + // TensorVec outputsPET = {nameNToTensorT.at(outputNameN)}; + // const auto &[b, m, w, k, dilation] = op->getArgs(); + // g->g2bmm(inputsPET[0], inputsPET[1], outputsPET[0], w, dilation); + // } else if (auto op = nnet::as(routineN)) { + // assert(op->getInputs().size() == 2); + // nnet::Tensor AN = op->getInputs()[0]; + // nnet::Tensor BN = op->getInputs()[1]; + // TensorVec inputsPET = {nameNToTensorT.at(AN->getName()), + // nameNToTensorT.at(BN->getName())}; + // TensorVec outputsPET = {nameNToTensorT.at(outputNameN)}; + // const auto &[b, m, w, n, dilation] = op->getArgs(); + // g->gbmml(inputsPET[0], inputsPET[1], outputsPET[0], dilation); + // } + else + IT_TODO_HALT(); + } + return g; } double NMutator::memboundTime(ssize_t cnt) { diff --git a/src/operators/matmul.cc b/src/operators/matmul.cc index 07708d07..4587efa8 100644 --- a/src/operators/matmul.cc +++ b/src/operators/matmul.cc @@ -5,22 +5,24 @@ namespace infini { MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA, bool transB, [[maybe_unused]] Tensor bias, ActType act) : OperatorObj(OpType::Matmul, {A, B}, {C}), transA(transA), transB(transB), - act(act), b(1) { + act(act) { auto shape_a = A->getDims(); auto shape_b = B->getDims(); - IT_ASSERT(shape_a.size() == shape_b.size()); - switch (shape_a.size()) { - case 0: - case 1: - IT_ASSERT(false); - case 2: - break; - default: + int dimA = shape_a.size(), dimB = shape_b.size(); + IT_ASSERT(dimA >= 2 && dimB >= 2); + + b = 1; + if (dimA <= 3 && dimB <= 3) { + int b1 = dimA == 2 ? 1 : A->getDims()[0]; + int b2 = dimB == 2 ? 1 : B->getDims()[0]; + + b = std::max(b1, b2); + } else { + IT_ASSERT_TODO(dimA == dimB); for (size_t i = 0; i < shape_a.size() - 2; ++i) { - IT_ASSERT(shape_a[i] == shape_b[i]); + IT_ASSERT_TODO(shape_a[i] == shape_b[i]); b *= shape_a[i]; } - break; } m = *(transA ? shape_a.rbegin() : shape_a.rbegin() + 1); n = *(transB ? shape_b.rbegin() + 1 : shape_b.rbegin()); @@ -38,11 +40,44 @@ string MatmulObj::toString() const { } optional> MatmulObj::inferShape(const TensorVec &inputs) const { - auto shape_a = inputs[0]->getDims(); - auto it = shape_a.rbegin(); - *it++ = n; - *it++ = m; - return {{std::move(shape_a)}}; + auto A = inputs[0], B = inputs[1]; + int dimA = A->getDims().size(), dimB = B->getDims().size(); + + if (dimA > 3 || dimB > 3) { + // no broadcast + auto shape_a = inputs[0]->getDims(); + auto it = shape_a.rbegin(); + *it++ = n; + *it++ = m; + return {{std::move(shape_a)}}; + } + + int b1 = dimA == 2 ? 1 : A->getDims()[0]; + int b2 = dimB == 2 ? 1 : B->getDims()[0]; + + int b = std::max(b1, b2); + int m = transA ? A->getDims()[dimA - 1] : A->getDims()[dimA - 2]; + int n = transB ? B->getDims()[dimB - 2] : B->getDims()[dimB - 1]; + int kA = transA ? A->getDims()[dimA - 2] : A->getDims()[dimA - 1]; + int kB = transB ? B->getDims()[dimB - 1] : B->getDims()[dimB - 2]; + + if ((dimA != 2 && dimA != 3) || (dimB != 2 && dimB != 3)) { + printf("Bad input dim: dimA = %d, dimB = %d\n", dimA, dimB); + return {}; + } + if (b1 != 1 && b2 != 1 && b1 != b2) { + printf("Bad batch size b1 = %d, b2 = %d\n", b1, b2); + return {}; + } + if (kA != kB) { + printf("Bad K: kA = %d, kB = %d\n", kA, kB); + return {}; + } + if (dimA == 2 && dimB == 2) { + return {{{m, n}}}; + } else { + return {{{b, m, n}}}; + } } vector MatmulObj::getWorkloadVector() const { diff --git a/src/operators/membound.cc b/src/operators/membound.cc index dc269742..ba69a5f5 100644 --- a/src/operators/membound.cc +++ b/src/operators/membound.cc @@ -1,5 +1,7 @@ #include "operators/membound.h" +#include "nnet/Visitor/CheckOOBVisitor.h" #include "nnet/Visitor/HashVisitor.h" +#include "nnet/Visitor/MergeMemboundMutator.h" namespace infini { @@ -10,6 +12,19 @@ MemBoundObj::MemBoundObj(GraphObj *graph, const TensorVec &input, : OperatorObj(OpType::MemBound, input, output), nnetInputs(nnetInputs), expr(expr), exec_time(exec_time), hint(hint) { IT_ASSERT(checkValid(graph)); + IT_ASSERT(!checkOOB(expr)); + hash = calcHash(expr); + + // fuse stages in nnet expr to reduce kernels generated by TVM + if (auto mergedExpr = + nnet::MergeMemboundMutator({expr}).merge(false, true)) { + simplifiedExpr = mergedExpr; + IT_ASSERT(!checkOOB(simplifiedExpr)); + simplifiedHash = calcHash(simplifiedExpr); + } else { + simplifiedExpr = expr; + simplifiedHash = hash; + } } string MemBoundObj::toString() const { @@ -31,8 +46,15 @@ string MemBoundObj::toString() const { os << "NNet Inputs=["; for (const auto &tensor : nnetInputs) os << tensor->toReadable() << ","; - os << "])"; - os << "\n" << (expr ? expr->toReadable() : "Empty expression") << "\n"; + os << "]"; + os << ", ExprHash=" << hash; + os << ", SimplifiedExprHash=" << simplifiedHash; + os << ")\n"; + os << ">>> Original expr\n" + << (expr ? expr->toReadable() : "Empty expression") << "\n"; + os << ">>> Simplified expr\n" + << (simplifiedExpr ? simplifiedExpr->toReadable() : "Empty expression") + << "\n"; return os.str(); } @@ -47,13 +69,18 @@ optional> MemBoundObj::inferShape(const TensorVec &inputs) const { } vector MemBoundObj::getWorkloadVector() const { - return {enum_to_underlying(type), (int)getHash()}; + return {enum_to_underlying(type), (int)simplifiedHash}; } vector MemBoundObj::getOpAttrVector() const { return getWorkloadVector(); } -HashType MemBoundObj::getHash() const { +HashType MemBoundObj::calcHash(nnet::Expr expr) { return nnet::HashVisitor().dispatch(expr); } +bool MemBoundObj::checkOOB(nnet::Expr expr) { + return nnet::CheckOOBVisitor().checkRangeOp( + nnet::as(expr)); +} + } // namespace infini diff --git a/test/core/test_hash.cc b/test/core/test_hash.cc index 8c1e659a..c6098aab 100644 --- a/test/core/test_hash.cc +++ b/test/core/test_hash.cc @@ -8,7 +8,7 @@ namespace infini { TEST(Hash, OperatorHash) { OpPerfKey key1(0, OpType::Unknown), key2(0, OpType::Unknown); { // build with addOpWithOutputs - Graph g = make_ref(nullptr); + Graph g = make_ref(NativeCpuRuntimeObj::getInstance()); Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32); Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32); Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32); @@ -18,7 +18,7 @@ TEST(Hash, OperatorHash) { EXPECT_GT(key1.attrs.size(), (size_t)5); } { // build with addOp - Graph g = make_ref(nullptr); + Graph g = make_ref(NativeCpuRuntimeObj::getInstance()); Tensor i0 = g->addTensor({2, 2, 3}, DataType::UInt32); Tensor w0 = g->addTensor({2, 3, 4}, DataType::UInt32); auto matmul = g->addOp(i0, w0, nullptr); diff --git a/test/kernels/cuda/test_cuda_matmul.cc b/test/kernels/cuda/test_cuda_matmul.cc index f52fc2f1..805096c4 100644 --- a/test/kernels/cuda/test_cuda_matmul.cc +++ b/test/kernels/cuda/test_cuda_matmul.cc @@ -1,4 +1,3 @@ - #include "core/graph.h" #include "core/kernel.h" #include "core/runtime.h" @@ -51,26 +50,38 @@ TEST(cuBLAS_Matmul, run) { Shape{2, 3, 4}, Shape{2, 3, 2}, ExpectOutput{40, 52, 46, 61, 52, 70, 58, 79, 400, 448, 424, 475, 448, 502, 472, 529}); + testMatmulCuda( + IncrementalGenerator(), IncrementalGenerator(), false, false, + Shape{2, 3, 5}, Shape{5, 2}, + ExpectOutput{60, 70, 160, 195, 260, 320, 360, 445, 460, 570, 560, 695}); + testMatmulCuda(IncrementalGenerator(), IncrementalGenerator(), true, false, + Shape{2, 5, 3}, Shape{5, 2}, + ExpectOutput{180, 210, 200, 235, 220, 260, 480, 585, 500, + 610, 520, 635}); + testMatmulCuda(IncrementalGenerator(), IncrementalGenerator(), false, false, + Shape{3, 5}, Shape{5, 2}, + ExpectOutput{60, 70, 160, 195, 260, 320}); } TEST(cuBLAS_Matmul, tune) { - auto cpuRuntime = NativeCpuRuntimeObj::getInstance(); - Graph gCpu = make_ref(cpuRuntime); - auto ACpu = gCpu->addTensor(Shape{1, 3, 5}, DataType::Float32); - auto BCpu = gCpu->addTensor(Shape{1, 5, 2}, DataType::Float32); - gCpu->dataMalloc(); - ACpu->setData(IncrementalGenerator()); - BCpu->setData(IncrementalGenerator()); - + // Matmul([A^T,B,act=0],A=597,B=595,C=598,bmnk=[1,4,4096,448]) + const int B = 1, M = 4, N = 4096, K = 448; + const bool transA = true, transB = false; auto cudaRuntime = make_ref(); - auto gCuda = make_ref(cudaRuntime); - auto ACuda = gCuda->cloneTensor(ACpu); - auto BCuda = gCuda->cloneTensor(BCpu); - auto matmul = gCuda->addOp(ACuda, BCuda, nullptr); - + Graph g = make_ref(cudaRuntime); + auto a = g->addTensor(transA ? Shape{B, K, M} : Shape{B, M, K}); + auto b = g->addTensor(transB ? Shape{B, N, K} : Shape{B, K, N}); // allocate CUDA memory - gCuda->dataMalloc(); - cudaRuntime->run(gCuda, true); + g->dataMalloc(); + a->setData(IncrementalGenerator()); + b->setData(IncrementalGenerator()); + + auto matmul = g->addOp(a, b, nullptr, transA, transB); + matmul->print(); + double time = cudaRuntime->getPerfTime(g); + EXPECT_GT(time, 1e-3); + EXPECT_LT(time, 1); + cudaRuntime->run(g, true); } }; // namespace infini diff --git a/test/kernels/intelcpu/test_mkl_batch_norm.cc b/test/kernels/intelcpu/test_mkl_batch_norm.cc index 24c87474..8c620490 100644 --- a/test/kernels/intelcpu/test_mkl_batch_norm.cc +++ b/test/kernels/intelcpu/test_mkl_batch_norm.cc @@ -13,10 +13,10 @@ TEST(MklBatchNorm, run) { // Build graph Graph g = make_ref(runtime); auto i = g->addTensor(Shape{1, 3, 2, 2}, DataType::Float32); - auto mean = g->addTensor(Shape{1, 3, 1, 1}, DataType::Float32); - auto var = g->addTensor(Shape{1, 3, 1, 1}, DataType::Float32); - auto scale = g->addTensor(Shape{1, 3, 1, 1}, DataType::Float32); - auto bias = g->addTensor(Shape{1, 3, 1, 1}, DataType::Float32); + auto mean = g->addTensor(Shape{3}, DataType::Float32); + auto var = g->addTensor(Shape{3}, DataType::Float32); + auto scale = g->addTensor(Shape{3}, DataType::Float32); + auto bias = g->addTensor(Shape{3}, DataType::Float32); auto op = g->addOp(i, nullptr, mean, var, scale, bias, 0.9, 0); g->dataMalloc(); diff --git a/test/nnet/test_memboundOp.cc b/test/nnet/test_memboundOp.cc index 9f1847d6..910344f2 100644 --- a/test/nnet/test_memboundOp.cc +++ b/test/nnet/test_memboundOp.cc @@ -7,7 +7,8 @@ #include "nnet/routine.h" #include "nnet/test.h" #include "operators/matmul.h" -#include +#include "operators/membound.h" +#include "test.h" using namespace infini; using namespace std; @@ -18,8 +19,8 @@ TEST(nnet, MemboundOpInterpretation) { Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32); Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32); g->dataMalloc(); - i0->copyData(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - w0->copyData(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + i0->copyin(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + w0->copyin(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); g->addOpWithOutputs(i0, w0, o0); NMutator nmutator(NMutator::Mode::ToNaiveMembound); auto mutations = nmutator.run(g); @@ -36,7 +37,7 @@ TEST(nnet, MemboundOpInterpretation) { EXPECT_EQ(membound->getOpType(), OpType::MemBound); auto ans = make_ref(Shape{1, 2, 4}, DataType::UInt32, runtime); ans->dataMalloc(); - ans->copyData(vector{38, 44, 50, 56, 83, 98, 113, 128}); + ans->copyin(vector{38, 44, 50, 56, 83, 98, 113, 128}); EXPECT_TRUE(membound->getOutput()->equalData(ans)); } @@ -49,8 +50,8 @@ TEST(nnet, MemboundOp_Ansor_Codegen) { Tensor w0 = g->addTensor({1, 3, 4}, DataType::Float32); Tensor o0 = g->addTensor({1, 2, 4}, DataType::Float32); g->dataMalloc(); - i0->copyData(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - w0->copyData(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + i0->copyin(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + w0->copyin(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); g->addOpWithOutputs(i0, w0, o0); NMutator nmutator(NMutator::Mode::ToNaiveMembound); auto mutations = nmutator.run(g); @@ -67,7 +68,7 @@ TEST(nnet, MemboundOp_Ansor_Codegen) { EXPECT_EQ(membound->getOpType(), OpType::MemBound); auto ans = make_ref(Shape{1, 2, 4}, DataType::Float32, cpu); ans->dataMalloc(); - ans->copyData(vector{38, 44, 50, 56, 83, 98, 113, 128}); + ans->copyin(vector{38, 44, 50, 56, 83, 98, 113, 128}); auto oCpu = gCpu->cloneTensor(membound->getOutput()); oCpu->printData(); @@ -77,3 +78,41 @@ TEST(nnet, MemboundOp_Ansor_Codegen) { // double time = timeit([&]() { runtime->run(gNew, false); }); // tune // kernels std::cout << "Time (ms):" << time << std::endl; } + +pair, nnet::Expr> getPReluExpr(int size) { + using namespace nnet; + using nnet::make_ref; + DEFINE_VAR(i); + auto A = make_ref("A", vector{size}); + auto B = make_ref("B", vector{size}); + Expr e = make_ref(makeSubscript(A, {i}) - makeSubscript(B, {i}), + FuncType::PRelu); + Expr ret = makeRangeOperator({{i, {0, size}}}, {}, e); + return {{A, B}, ret}; +} + +TEST(nnet, PRelu_Ansor_Codegen) { + auto cuda = make_ref(); + Runtime cpu = NativeCpuRuntimeObj::getInstance(); + Graph g = make_ref(cuda); + Tensor i0 = g->addTensor(vector{12}); + Tensor w0 = g->addTensor(vector{12}); + Tensor o0 = g->addTensor(vector{12}); + auto [nnetInputs, expr] = getPReluExpr(12); + g->addOpWithOutputs(vector{i0, w0}, vector{o0}, nnetInputs, + expr, -1); + g->dataMalloc(); + i0->setData(IncrementalGenerator()); + w0->setData(ValGenerator<5>()); + cuda->run(g, true); // tune kernels + + // check answer + auto ans = make_ref(Shape{12}, DataType::Float32, cpu); + ans->dataMalloc(); + ans->copyin( + vector{-1.25, -1., -0.75, -0.5, -0.25, 0, 1, 2, 3, 4, 5, 6}); + + Graph gCpu = make_ref(cpu); + auto oCpu = gCpu->cloneTensor(o0); + EXPECT_TRUE(oCpu->equalData(ans)); +} diff --git a/test/nnet/test_mergeStage.cc b/test/nnet/test_mergeStage.cc index c14c68e3..cf3cac41 100644 --- a/test/nnet/test_mergeStage.cc +++ b/test/nnet/test_mergeStage.cc @@ -4,19 +4,16 @@ #include "nnet/Visitor/HashVisitor.h" #include "nnet/Visitor/MergeMemboundMutator.h" #include "nnet/expr.h" +#include "nnet/test.h" #include "gtest/gtest.h" using namespace nnet; using namespace std; -#define DEFINE_VAR(name) auto name = make_ref(#name); TEST(FuseMembound, Relu) { const int n_heads = 8, seq_len = 10000, feat_len = 512; // dilation_heads = 2; const int Batch = n_heads, M = seq_len, K = feat_len, W = 32; - DEFINE_VAR(b); - DEFINE_VAR(m); - DEFINE_VAR(w); - DEFINE_VAR(k); + DEFINE_VAR(b, m, w, k); auto A = make_ref("A", vector({Batch, M, K}), vector{0, 0, 0}); @@ -35,10 +32,7 @@ TEST(FuseMembound, MemMemFusion) { const int n_heads = 8, seq_len = 100, feat_len = 100; // dilation_heads = 2; const int Batch = n_heads, M = seq_len, K = feat_len; - DEFINE_VAR(b); - DEFINE_VAR(m); - DEFINE_VAR(w); - DEFINE_VAR(k); + DEFINE_VAR(b, m, w, k); auto A = make_ref("A", vector({Batch, M, K}), vector{0, 0, 0}); auto B = make_ref("B", vector({Batch, K, M}), @@ -54,4 +48,26 @@ TEST(FuseMembound, MemMemFusion) { RangeOp ans = makeRangeOperator({{b, {0, Batch}}, {m, {0, M}}}, {{k, {0, K}}}, makeSubscript(A, {b, m, k})); EXPECT_EQ(HashVisitor().getHash(merged), HashVisitor().getHash(ans)); -} \ No newline at end of file +} + +TEST(FuseMembound, mergeNestedStagesInRangeOp) { + // Case in ConvTranspose to Matmul + // LSum ... [i39,f] + // {LSum ... [f,(i39 / 1024),((i39 / 256) % 4),(i39 + // % 256)] {K}} + DEFINE_VAR(f, i); + const int I = 4096, F = 448; + auto K = make_ref("K", vector({448, 4, 4, 256})); + + auto subA = makeSubscript(K, {f, i / 1024, (i / 256) % 4, i % 256}); + auto range = makeRangeOperator({{i, {0, I}}, {f, {0, F}}}, {}, subA); + auto outerRange = makeRangeOperator({{f, {0, F}}, {i, {0, I}}}, {}, + makeSubscript(range, {i, f})); + auto merged = MergeMemboundMutator({outerRange}).merge(); + + // Compare the result with answer + RangeOp ans = makeRangeOperator( + {{f, {0, F}}, {i, {0, I}}}, {}, + makeSubscript(K, {f, i / 1024, (i / 256) % 4, i % 256})); + EXPECT_EQ(HashVisitor().getHash(merged), HashVisitor().getHash(ans)); +} diff --git a/test/nnet/test_mutator.cc b/test/nnet/test_mutator.cc index 0b9411bd..cf4d8ab2 100644 --- a/test/nnet/test_mutator.cc +++ b/test/nnet/test_mutator.cc @@ -56,45 +56,72 @@ TEST(Mutator, NaiveConvWithInterpreter) { // FIXME: failed since implicit transpose for DLT TEST(Mutator, InfoGAN_TConv_3_correctness) { - // verifyNaiveMembound True: subgraph after transformation - // verifyNaiveMembound False: subgraph of one single membound (eOP) - // const bool verifyNaiveMembound = false; + const bool useMutatorDirectly = true; Runtime runtime = make_ref(); Graph g = make_ref(runtime); Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton Graph gCpu = make_ref(cpu); - // {n, h, w, f} * {f, r, s, c} - auto i0 = g->addTensor({1, 2, 2, 448}); - auto w0 = g->addTensor({448, 4, 4, 256}); + const int n = 1, c = 256, h = 2, w = 2, f = 448, r = 4, s = 4; + // // Minimum config for test + // const int n = 1, c = 1, h = 2, w = 2, f = 1, r = 4, s = 4; + // const int n = 1, c = 2, h = 2, w = 2, f = 2, r = 4, s = 4; + + auto i0 = g->addTensor({n, h, w, f}); + auto w0 = g->addTensor({f, r, s, c}); g->addOp(i0, w0, nullptr, 1, 1, 2, 2, 1, 1); - auto mutator = make_ref(); - mutator->setToNaiveMembound(); - SearchEngine searchEngine(runtime, mutator); - auto bestGraph = searchEngine.run(g); - bestGraph->print(); - printf("--- SearchEngine Finished ---\n"); + auto mutator = + make_ref(NMutator::Mode::RuleBased, + vector{3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90}); + // // Translate OP to membound without derivation + // mutator->setToNaiveMembound(); + vector bestGraphs; + if (useMutatorDirectly) { // Use mutator results + bestGraphs = mutator->run(g); + } else { // Use search engine results + SearchEngine searchEngine(runtime, mutator); + bestGraphs.emplace_back(searchEngine.run(g)); + } g->dataMalloc(); - bestGraph->dataMalloc(); - for (auto t : g->getTensors()) { - if (t->getFuid() <= 2) - t->setData(IncrementalGenerator()); + map fuidToInputTensor; + for (auto t : g->getInputs()) { + EXPECT_EQ(fuidToInputTensor.count(t->getFuid()), 0); + fuidToInputTensor[t->getFuid()] = t; } - for (auto t : bestGraph->getTensors()) { - if (t->getFuid() <= 2) - t->setData(IncrementalGenerator()); + + for (size_t i = 0; i < bestGraphs.size(); i++) { + auto bestGraphCpu = bestGraphs[i]; + auto bestGraph = + make_ref(runtime, bestGraphCpu->getOperators()); + + auto gen = RandomGenerator(0, 1, i); + bestGraph->dataMalloc(); + // Initialize inputs with random data + for (auto t : g->getInputs()) { + t->setData(gen); + } + for (auto t : bestGraph->getInputs()) { + t->copyData(fuidToInputTensor[t->getFuid()]); + } + + // Initialize outputs with zeros + for (auto t : g->getOutputs()) { + t->setData(ZeroGenerator()); + } + for (auto t : bestGraph->getOutputs()) { + t->setData(ZeroGenerator()); + } + + runtime->run(bestGraph, true); // Tune kernels + runtime->run(g); + runtime->run(bestGraph, false); // Execute transfomraed graph + + auto go0 = gCpu->cloneTensor(g->getOutputs()[0]); + auto bgo0 = gCpu->cloneTensor(bestGraph->getOutputs()[0]); + EXPECT_TRUE(go0->equalData(bgo0, 1e-4)); } - runtime->run(g); - runtime->run(bestGraph); - - auto go0 = gCpu->cloneTensor(g->getOutputs()[0]); - auto bgo0 = gCpu->cloneTensor(bestGraph->getOutputs()[0]); - - EXPECT_TRUE(go0->equalData(bgo0)); - EXPECT_TRUE(g->getOutputs()[0]->getRawDataPtr() != - bestGraph->getOutputs()[0]->getRawDataPtr()); } // TEST(Mutator, Conv9x9) { diff --git a/test/script/env_lotus.sh b/test/script/env_lotus.sh index 428024f1..9fe82b98 100644 --- a/test/script/env_lotus.sh +++ b/test/script/env_lotus.sh @@ -9,7 +9,7 @@ then elif [ "$1" == "intelcpu" ] then echo "Load INTELCPU environment." - spack load intel-oneapi-dnn@2022.1.0 intel-oneapi-mkl@2022.1.0 intel-oneapi-compilers@2022.1.0 + spack load gcc@12.1.0 intel-oneapi-dnn@2022.1.0 intel-oneapi-mkl@2022.1.0 intel-oneapi-compilers@2022.1.0 # The default dnnl library is cpu_dpcpp_gpu_dpcpp which requires libsycl.so, after "spack load", and need to change to gomp explicitly. export LD_LIBRARY_PATH=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-12.1.0/intel-oneapi-dnn-2022.1.0-7rs6ht57zozyxhxx6s2qlrqzmqknhgzx/dnnl/2022.1.0/cpu_gomp/lib/:$LD_LIBRARY_PATH @@ -20,7 +20,7 @@ then # Preloading the missing libs will work, refered to https://community.intel.com/t5/Intel-oneAPI-Math-Kernel-Library/mkl-fails-to-load/m-p/1155538 export MKLLIB_PATH=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-12.1.0/intel-oneapi-mkl-2022.1.0-mf6te62fo6wxlo33jwwwgg5kljoagc6g/mkl/2022.1.0/ - export LD_PRELOAD=$MKLLIB_PATH/lib/intel64/libmkl_def.so.2:$MKLLIB_PATH/lib/intel64/libmkl_avx2.so.2:$MKLLIB_PATH/lib/intel64/libmkl_core.so:$MKLLIB_PATH/lib/intel64/libmkl_intel_lp64.so:$MKLLIB_PATH/lib/intel64/libmkl_intel_thread.so:/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-11.3.0/intel-oneapi-compilers-2022.1.0-qrq4a63scjip455bpxvl5ipgqbllwecj/compiler/2022.1.0/linux/compiler/lib/intel64_lin/libiomp5.so + export LD_PRELOAD=$MKLLIB_PATH/lib/intel64/libmkl_def.so.2:$MKLLIB_PATH/lib/intel64/libmkl_avx2.so.2:$MKLLIB_PATH/lib/intel64/libmkl_core.so:$MKLLIB_PATH/lib/intel64/libmkl_intel_lp64.so:$MKLLIB_PATH/lib/intel64/libmkl_intel_thread.so:/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-12.1.0/intel-oneapi-compilers-2022.1.0-6k6zm3h4qcsni27nihc4b6wuqgtxqxqa/compiler/2022.1.0/linux/compiler/lib/intel64_lin/libiomp5.so else echo "Bad option. Please enter 'cuda' or 'intelcpu'. CUDA will be loaded by default if nothing specified." fi