NNET supports TVM backend and kernels (#78)

* Add: mutator InfoGAN minimum test

* Add: cache and padding (bugs!!)

* Add: expression reader as a cmake target

* Fix: [Intermediate] NMutator::expressionToGraph

To be fix: matmul with implicit broadcast

* Add: matmul broadcast

* Fix: GraphObj ctor should use cloneTensor

* Fix: cuBLAS failure when codegen is enabled

* Add: Exception for checkCuError

* Fix: graph OpList ctor

* Add: expr simplication for TVM

* Add: TVM headers and CMake include paths

* Add: CMake config

* Add: PackedFunc (broken)

* Fix: remove cuCtxCreate which makes TVM fails

* Fix: membound_tvm

* Fix: test_memboundOp

* Add: PRelu Expr and AsTVMVisitor

* Add: Random generator

* Add: support TVM packed function

* Fix: specify runtime

* Add: CMake support of TVM

* Add: detailed output of Matmul

* Add: comments for Matmul

* Chore: format and comments

* Chore: GraphObj::selfCheck without assert control

* Fix: CMAKE_CXX_FLAGS in CMakeLists

* fix merge bug

* update api for mkl batchnorm test

* fix lotus env

* fig header bug

---------

Co-authored-by: Liyan Zheng <liyan-zheng@outlook.com>
Co-authored-by: huangshuhong <huangsh19@mails.tsinghua.edu.cn>
Co-authored-by: whjthu <haojie0429@gmail.com>
This commit is contained in:
zhengly123 2023-04-18 00:26:36 +08:00 committed by GitHub
parent 43d4798323
commit a1974aabcd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
39 changed files with 1158 additions and 334 deletions

View File

@ -16,6 +16,16 @@ cmake_dependent_option(BUILD_TEST_EINNET "Build tests for EINNET" OFF BUILD_TEST
set(DEFAULT_BUILD_TYPE "RelWithDebInfo")
if(EXISTS ${CMAKE_CURRENT_BINARY_DIR}/config.cmake)
message(STATUS "Using config.cmake in CMAKE_CURRENT_BINARY_DIR directory")
include(${CMAKE_CURRENT_BINARY_DIR}/config.cmake)
else()
if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/config.cmake)
message(STATUS "Using config.cmake in CMAKE_CURRENT_SOURCE_DIR directory")
include(${CMAKE_CURRENT_SOURCE_DIR}/config.cmake)
endif()
endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_EXTENSIONS OFF) # -std=gnu++11 when on, -std=c++11 when off
@ -32,19 +42,6 @@ if(OpenMP_CXX_FOUND)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
endif()
if(BUILD_TEST)
set(BUILD_GMOCK
OFF
CACHE BOOL "Do not build gmock" FORCE)
set(INSTALL_GTEST
OFF
CACHE BOOL "Do not install gtest" FORCE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall ")
add_subdirectory(3rd-party/googletest)
include_directories(SYSTEM 3rd-party/googletest/googletest/include)
endif()
#Protobuf
if(USE_PROTOBUF)
add_definitions(-D TENSOR_PROTOBUF)
@ -71,10 +68,35 @@ include_directories(3rd-party/pybind11/include)
add_subdirectory(3rd-party/nlohmann_json_cmake_fetchcontent)
include_directories(3rd-party/nlohmann_json_cmake_fetchcontent/single_include)
# TVM backend
if(BUILD_TEST_EINNET)
if (NOT TVM_INCLUDE_DIR OR NOT DMLC_INCLUDE_DIR OR NOT DLPACK_INCLUDE_DIR OR NOT DLPACK_INCLUDE_DIR)
message(FATAL_ERROR "TVM_INCLUDE_DIR, DMLC_INCLUDE_DIR, and DLPACK_INCLUDE_DIR must be set when BUILD_TEST_EINNET is ON")
endif()
# TVM and DMLC for invoking TVM packed functions
include_directories(${TVM_INCLUDE_DIR})
include_directories(${DMLC_INCLUDE_DIR})
include_directories(${DLPACK_INCLUDE_DIR})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDMLC_USE_LOGGING_LIBRARY=\\\<${TVM_INCLUDE_DIR}/tvm/runtime/logging.h\\\> ")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DINFINI_USE_TVM=1") # Enable TVM codegen kernels
endif()
if(BUILD_TEST)
set(BUILD_GMOCK
OFF
CACHE BOOL "Do not build gmock" FORCE)
set(INSTALL_GTEST
OFF
CACHE BOOL "Do not install gtest" FORCE)
add_subdirectory(3rd-party/googletest)
include_directories(3rd-party/googletest/googletest/include)
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror -Wno-error=deprecated-declarations")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion
# Source files
file(GLOB_RECURSE SRC src/ffi/*.cc src/core/*.cc src/kernels/cpu/*.cc src/nnet/*.cc src/operators/*.cc src/utils/*.cc)
@ -101,6 +123,11 @@ endif()
target_link_libraries(InfiniTensor pybind11::embed)
# TVM backend
if(BUILD_TEST_EINNET)
target_link_libraries(InfiniTensor ${TVM_LIB_DIR}/libtvm.so)
endif()
# Python bindings
file(GLOB_RECURSE FFIS src/ffi/ffi_infinitensor.cc)
pybind11_add_module(backend MODULE ${FFIS})
@ -229,5 +256,9 @@ if(BUILD_TEST)
endif()
if(BUILD_TEST_EINNET)
build_test(test/nnet/test_*.cc)
# Build expression reader
add_executable(nnet_reader test/nnet/readlog.cc)
target_link_libraries(nnet_reader InfiniTensor)
endif()
endif()

View File

@ -16,7 +16,7 @@ endif
build:
mkdir -p build/$(TYPE)
cd build/$(TYPE) && cmake $(CMAKE_OPT) ../.. && make -j22
cd build/$(TYPE) && cmake $(CMAKE_OPT) ../.. && make -j8
clean:
rm -rf build

View File

@ -0,0 +1,13 @@
set(TVM_HOME "/home/zly/Apps/tvm-v0.10.0")
set(TVM_INCLUDE_DIR "${TVM_HOME}/include")
set(TVM_LIB_DIR "${TVM_HOME}/build")
set(DMLC_INCLUDE_DIR "${TVM_HOME}/3rdparty/dmlc-core/include")
set(DLPACK_INCLUDE_DIR "${TVM_HOME}/3rdparty/dlpack/include")
set(USE_CUDA ON)
set(USE_BANG OFF)
set(BUILD_TEST ON)
set(BUILD_TEST_CORE ON)
set(BUILD_TEST_PET OFF)
set(BUILD_TEST_EINNET ON)

View File

@ -19,6 +19,9 @@ class GraphObj : public Object {
Tensor addTensor(Shape dim, DataType dtype = DataType::Float32);
Tensor addTensor(const Tensor &tensor);
TensorVec addTensor(const TensorVec &tensors);
/**
* @brief Clone a tensor and add it to the graph.
*/
Tensor cloneTensor(const Tensor &tensor) {
return addTensor(tensor->clone(runtime));
}

View File

@ -69,6 +69,9 @@ class TensorObj : public TensorBaseObj {
void copyData(const TensorObj *src);
void copyData(const Tensor &src) { copyData(src.get()); }
// FIXME: std::fucntion copies the generator instead of passing it by ref.
// Thus the internal state of generator cannot be updated.
void setData(
const std::function<void(void *, size_t, DataType)> &generator) const;
Tensor clone() const {
@ -92,29 +95,31 @@ class TensorObj : public TensorBaseObj {
}
void printData() const;
bool equalData(const Tensor &rhs) const;
bool equalData(const Tensor &rhs, double relativeError = 1e-6) const;
template <typename T> bool equalData(const vector<T> &dataVector) {
IT_ASSERT(DataType::get<T>() == dtype);
IT_ASSERT(size() == dataVector.size());
return equalDataImpl(getRawDataPtr<T *>(), dataVector.data(), size());
return equalDataImpl(getRawDataPtr<T *>(), dataVector.data(), size(),
1e-6);
}
size_t getOffsetByBroadcastOffset(size_t bcOffset, Shape bcShape) const;
private:
void printDataFloat() const;
void printDataUint32_t() const;
void printDataFloat(float *ptr) const;
void printDataUint32_t(uint32_t *ptr) const;
template <typename T>
bool equalDataImpl(const T *a, const T *b, size_t size) const {
bool equalDataImpl(const T *a, const T *b, size_t size,
double relativeError) const {
for (size_t i = 0; i < size; ++i) {
if constexpr (std::is_integral_v<T>) {
if (a[i] != b[i])
return false;
} else if constexpr (std::is_floating_point_v<T>) {
if (fabs(a[i] - b[i]) / std::max(fabs(a[i]), fabs(b[i])) >
1e-6) {
relativeError) {
printf("Error on %lu: %f %f\n", i, a[i], b[i]);
return false;
}

View File

@ -23,9 +23,8 @@
const char *errName; \
if (CUDA_SUCCESS != err) { \
cuGetErrorString(err, &errName); \
fprintf(stderr, "Cuda error in %s:%i : %s.\n", __FILE__, __LINE__, \
errName); \
exit(EXIT_FAILURE); \
IT_ASSERT(err == CUDA_SUCCESS, \
(string("CU error: ") + string(errName))); \
} \
}

View File

@ -11,18 +11,8 @@ class CudaRuntimeObj : public RuntimeObj {
CudaPtr workspace;
size_t workspaceSize;
public:
CUdevice cuDevice;
CUcontext newContext;
public:
CudaRuntimeObj() : RuntimeObj(Device::CUDA) {
// Prepare for nvrtc. cuCtxCreate should be called befero others.
// Otherwise it will result in strange failure, such as cuBLAS failed on
// certian inputs.
checkCUresult(cuInit(0));
checkCUresult(cuDeviceGet(&cuDevice, 0));
checkCUresult(cuCtxCreate(&newContext, 0, cuDevice));
checkCudnnError(cudnnCreate(&cudnn));
checkCublasError(cublasCreate(&cublas));
@ -32,10 +22,13 @@ class CudaRuntimeObj : public RuntimeObj {
workspace = alloc(workspaceSize);
}
virtual ~CudaRuntimeObj() {
try {
dealloc(workspace);
checkCudnnError(cudnnDestroy(cudnn));
checkCublasError(cublasDestroy(cublas));
checkCUresult(cuCtxDestroy(newContext));
} catch (const std::exception &e) {
std::cerr << "Error in ~CudaRuntimeObj: " << e.what() << std::endl;
}
}
string toString() const override;

View File

@ -22,6 +22,7 @@ class HashVisitor : public Functor<HashType(void)> {
HashType visit_(const Subscript &c) override;
HashType visit_(const Tensor &c) override;
HashType visit_(const Var &c) override;
HashType visit_(const Func &c) override;
public:
HashVisitor(int _verobse = 0) : Functor(_verobse) {}

View File

@ -20,7 +20,13 @@ class MergeMemboundMutator : public Mutator {
*/
MergeMemboundMutator(const VecExpr &kernels)
: Mutator(), kernels(kernels), curDepth(kernels.size() - 1) {}
Expr merge(bool allowEmptyMembound = false);
/// @brief Merged multiple expressions into one with one or several stages.
/// @param allowEmptyMembound
/// @param allowFailure If true, return nullptr when merging fails. If
/// false, assert will fail.
/// @return
Expr merge(bool allowEmptyMembound = false, bool allowFailure = false);
};
} // namespace nnet

View File

@ -66,7 +66,7 @@ static inline HashType genhash(string s) {
}
#define nnet_unimplemented_halt() \
{ assert(!"Unimplemented"); }
{ IT_TODO_HALT(); }
#define nnet_unimplemented_continue() \
{ dbg("Unimplemented"); }

View File

@ -104,10 +104,11 @@ enum class NodeType {
FuncNodeType
};
enum class FuncType { Relu, Tanh };
enum class FuncType { Relu, Tanh, PRelu };
#define DEFINE_GETTYPE(CLASS) \
NodeType getType() const override { return NodeType::CLASS##Type; }
#define DEFINE_GETTYPE(CLASS, isScalar_v) \
NodeType getType() const override { return NodeType::CLASS##Type; } \
bool isScalar() const override { return isScalar_v; }
class ExprNode {
public:
@ -119,6 +120,7 @@ class ExprNode {
friend std::ostream &operator<<(std::ostream &ios, const ExprNode &expr);
virtual NodeType getType() const = 0;
virtual bool isScalar() const = 0;
};
class VarNode : public ExprNode {
@ -127,7 +129,7 @@ class VarNode : public ExprNode {
public:
VarNode(std::string _name) : name(_name){};
virtual ~VarNode() {}
DEFINE_GETTYPE(VarNode);
DEFINE_GETTYPE(VarNode, true);
const std::string &getName() const { return name; }
HashType hash() const override { return genhash(name); };
@ -152,7 +154,7 @@ class TensorNode : public ExprNode {
TensorNode(string _name, vector<int> _shape, vector<int> _paddings = {},
Routine _source = nullptr);
virtual ~TensorNode() {}
DEFINE_GETTYPE(TensorNode);
DEFINE_GETTYPE(TensorNode, false);
bool operator==(const string &rhs) { return name == rhs; }
friend bool operator==(const string &lhs, const TensorNode &rhs) {
@ -174,6 +176,7 @@ class TensorNode : public ExprNode {
const Routine &getSource() const { return source; }
int getData(const Ref<vector<int>> &data, const vector<int> &idx);
size_t getOffset(const vector<int> &idx);
bool hasPadding();
};
enum class OpType { Range, Add, Mul, Div, Mod, Sub };
@ -220,7 +223,7 @@ class RangeOpNode : public OperatorNode {
const vector<int> &paddings)
: OperatorNode(OpType::Range, {_summand}), vars{_loopIters, _sumIters},
paddings(paddings){};
DEFINE_GETTYPE(RangeOpNode);
DEFINE_GETTYPE(RangeOpNode, false);
virtual HashType hash() const override {
nnet_unimplemented_halt();
@ -289,7 +292,7 @@ class BinaryOpNode : public OperatorNode {
BinaryOpNode(OpType _opType, Expr _lhs, Expr _rhs)
: OperatorNode(_opType, {_lhs, _rhs}){};
virtual ~BinaryOpNode() {}
DEFINE_GETTYPE(BinaryOpNode);
DEFINE_GETTYPE(BinaryOpNode, true);
virtual HashType hash() const override {
return genhash((HashType)opType,
@ -314,7 +317,7 @@ class ConstantNode : public ExprNode {
ConstantNode(int _val) : val(_val){};
ConstantNode(const ConstantNode &rhs) : ExprNode(rhs), val(rhs.val){};
virtual ~ConstantNode() {}
DEFINE_GETTYPE(ConstantNode);
DEFINE_GETTYPE(ConstantNode, true);
int getValue() const { return val; }
virtual HashType hash() const override { return genhash(val, 6214587); };
@ -334,7 +337,7 @@ class SubscriptNode : public ExprNode {
SubscriptNode(Expr _indexed, vector<Expr> _subExprs) : subExprs(_subExprs) {
setObject(_indexed);
};
DEFINE_GETTYPE(SubscriptNode);
DEFINE_GETTYPE(SubscriptNode, true);
virtual HashType hash() const override {
nnet_unimplemented_continue();
@ -358,14 +361,15 @@ class SubscriptNode : public ExprNode {
class FuncNode : public ExprNode {
protected:
Subscript object;
Expr object;
FuncType funcType;
public:
FuncNode(Expr object, FuncType funcType) : funcType(funcType) {
setObject(object);
FuncNode(Expr object, FuncType funcType)
: object(object), funcType(funcType) {
nnet_assert(object->isScalar(), "FuncNode operates on a scalar");
}
DEFINE_GETTYPE(FuncNode);
DEFINE_GETTYPE(FuncNode, true);
virtual HashType hash() const override {
nnet_unimplemented_continue();
@ -373,7 +377,7 @@ class FuncNode : public ExprNode {
};
virtual string toReadable() const override;
const Subscript &getObject() const { return object; }
const Expr &getObject() const { return object; }
void setObject(Expr e);
FuncType getFuncType() const { return funcType; }

View File

@ -20,7 +20,7 @@ class NMutator : public Mutator {
public:
NMutator(Mode mode = Mode::Normal);
NMutator(const std::vector<int> &derivationRules);
NMutator(Mode mode, const std::vector<int> &derivationRules);
~NMutator();
vector<Graph> run(const Graph &in_graph) override;

View File

@ -19,9 +19,15 @@ class MatmulObj : public OperatorObj {
public:
/**
* @brief Construct a new Matmul object. This comments show how operators is
* defined in InfiniTensor. The constructor can create output tensors for
* the operator or not, which depends on `graph`.
* @brief Matmul operator with batch broadcast and tensor transpose
* supports. Only one tensor with singe batch can be broadcasted due to the
* BLAS interface restriction. Tranpose indicates whether the last two
* dimensions should be transposed before Matmul and does not affect other
* leading dimensions.
*
* Matmul show how operators are defined in InfiniTensor. The constructor of
* an operator can create output tensors for the operator or not, which
* depends on `graph`.
*
* @param graph The computation graph that this operator belongs to.
* @param A The input tensor.

View File

@ -7,9 +7,10 @@ namespace infini {
class MemBoundObj : public OperatorObj {
private:
std::vector<nnet::Tensor> nnetInputs;
nnet::Expr expr;
nnet::Expr expr, simplifiedExpr;
double exec_time;
std::string hint;
HashType hash, simplifiedHash;
int n, f, h, w;
public:
@ -26,11 +27,15 @@ class MemBoundObj : public OperatorObj {
int numOutputs() const override { return outputs.size(); }
const vector<nnet::Tensor> &getNnetInputs() const { return nnetInputs; }
const nnet::Expr getNnetExpr() const { return expr; }
pair<const nnet::Expr, HashType> getSimplifiedNnetExpr() const {
return {expr, hash};
}
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
HashType getHash() const;
static HashType calcHash(nnet::Expr expr);
static bool checkOOB(nnet::Expr expr);
};
} // namespace infini

View File

@ -1,5 +1,7 @@
#pragma once
#include "core/common.h"
#include "core/tensor_base.h"
#include <random>
namespace infini {
@ -38,6 +40,31 @@ class IncrementalGenerator : public DataGenerator {
void fill(float *data, size_t size) override { fill<float>(data, size); }
};
class RandomGenerator : public DataGenerator {
private:
double l, r;
std::mt19937 e;
std::uniform_int_distribution<int> di;
std::uniform_real_distribution<float> dr;
public:
RandomGenerator(double l = 0, double r = 1, unsigned int seed = 0)
: l(l), r(r), e(seed), di(l, r), dr(l, r) {}
virtual ~RandomGenerator() {}
private:
void fill(uint32_t *data, size_t size) override {
for (size_t i = 0; i < size; i++) {
data[i] = di(e);
}
}
void fill(float *data, size_t size) override {
for (size_t i = 0; i < size; i++) {
data[i] = dr(e);
}
}
};
template <int val> class ValGenerator : public DataGenerator {
public:
virtual ~ValGenerator() {}

View File

@ -1 +1,2 @@
from .gen_ansor_op import gen_ansor_op
from .gen_ansor_so import gen_ansor_so

View File

@ -3,19 +3,50 @@ import re
import numpy as np
import tvm
from tvm import te, tir, auto_scheduler, topi
import os
import json
import logging
USE_CACHE = True
logger = logging.getLogger('InfiniTensor')
logger.setLevel(logging.DEBUG)
def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f, func_name, input_names, output_name):
def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f,
func_name, input_names, output_name, nnet_expression: str,
nnet_simplified_expression: str, hash_code=None):
assert len(input_tensors) == len(input_dtypes)
assert len(input_tensors) == len(input_names)
logging.debug(f'Work on hash {hash_code}')
dir_name = os.path.join(".cache", "generated_kernels", str(hash_code))
func_code_fn = os.path.join(dir_name, "kernel.cu")
invoke_code_fn = os.path.join(dir_name, "invoke.cpp")
config_fn = os.path.join(dir_name, "config.json")
if USE_CACHE and hash_code is not None:
if os.path.exists(dir_name):
print(f"Use cache in {dir_name}")
with open(func_code_fn, "r") as func_code_fin:
func_code = func_code_fin.read()
with open(invoke_code_fn, "r") as invoke_code_fin:
invoke_code = invoke_code_fin.read()
with open(config_fn, "r") as config_fin:
config = json.loads(config_fin.read().strip())
conv_time = config["conv_time"]
invoke_params = config["invoke_params"]
logger.debug(f'Find tuning log for {hash_code}')
return func_code, invoke_code, conv_time, invoke_params
print("Generating Ansor op: ")
print(f)
@auto_scheduler.register_workload(func_name)
def compute():
_locals = locals()
exec(f, {'tvm': tvm, 'te': te, 'tir': tir}, _locals)
exec(f, {'tvm': tvm, 'te': te, 'tir': tir, 'topi': topi}, _locals)
return _locals['ret']
target = tvm.target.Target("cuda")
@ -43,6 +74,28 @@ def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f, fu
# Kill the measurement process
del measure_ctx
def test_mutator():
# test part
tgt_temp = tvm.target.Target(target="llvm", host="llvm")
all_tensors = compute()
sch = te.create_schedule(all_tensors[0].op)
args = all_tensors
C0, K0, A0 = args
func_temp = tvm.build(sch, args, tgt_temp, name="temp")
# print result
n, c, h, w, f, r, s = 1, 1, 2, 2, 1, 4, 4
dev_temp = tvm.device(tgt_temp.kind.name, 0)
A_temp = tvm.nd.array(
np.arange(n*h*w*f).reshape(n, h, w, f).astype(A0.dtype), dev_temp)
K_temp = tvm.nd.array(
np.arange(f*r*s*c).reshape(f, r, s, c).astype(K0.dtype), dev_temp)
C_temp = tvm.nd.array(
np.zeros((1, 4, 4, 1)).astype(C0.dtype), dev_temp)
func_temp(C_temp, K_temp, A_temp)
print("================= Test Result =====================")
print(C_temp)
ir = str(tvm.lower(sch, args, simple_mode=True))
thread_dim = [1, 1, 1]
block_dim = [1, 1, 1]
@ -83,11 +136,27 @@ def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f, fu
print("Func Code")
# Attach TVM code behind func_code
func_code += "\n/* " + f + "*/"
func_code += "\n/* NNET tensor expression \n" + nnet_expression + "\n*/\n"
func_code += "\n/* NNET simplified tensor expression \n" + \
nnet_simplified_expression + "\n*/\n"
func_code += "\n/* TVM compute\n" + f + "\n*/\n"
print(func_code)
print("Invoke Code")
print(invoke_code)
print("Time")
print(conv_time)
if hash_code is not None:
if not os.path.exists(dir_name):
os.makedirs(dir_name)
with open(func_code_fn, "w") as func_code_fout:
func_code_fout.write(func_code)
with open(invoke_code_fn, "w") as invoke_code_fout:
invoke_code_fout.write(invoke_code)
with open(config_fn, "w") as config_fout:
config_fout.write(json.dumps({
"conv_time": conv_time,
"invoke_params": invoke_params
}, ensure_ascii=False, indent=2))
return func_code, invoke_code, conv_time, invoke_params # ms

View File

@ -0,0 +1,106 @@
import re
import numpy as np
import tvm
from tvm import te, tir, auto_scheduler, topi
import os
import json
import logging
USE_CACHE = True
logger = logging.getLogger('InfiniTensor')
logger.setLevel(logging.DEBUG)
def gen_ansor_so(input_tensors, input_dtypes, output_tensor, output_dtype,
tvm_code, func_name, nnet_expression: str,
nnet_simplified_expression: str, hash_code=None):
assert len(input_tensors) == len(input_dtypes)
logging.debug(f'Work on hash {hash_code}')
dir_name = os.path.join(".cache", "generated_kernels", str(hash_code))
if not os.path.exists(dir_name):
os.makedirs(dir_name)
so_fn = os.path.join(dir_name, f"{func_name}.so")
config_fn = os.path.join(dir_name, "config_so.json")
print("Generating Ansor op: ")
print(tvm_code)
print("Input shape: ")
print(input_tensors)
print("Output shape: ")
print(output_tensor)
if USE_CACHE and hash_code is not None:
if os.path.exists(dir_name) and \
os.path.exists(so_fn) and \
os.path.exists(config_fn):
print(f"Use cache in {dir_name}")
with open(config_fn, "r") as config_fin:
config = json.loads(config_fin.read().strip())
conv_time = config["conv_time"]
logger.debug(f'Find tuning log for {hash_code}')
return so_fn, conv_time
@auto_scheduler.register_workload(func_name)
def compute():
_locals = locals()
exec(tvm_code, {'tvm': tvm, 'te': te, 'tir': tir, 'topi': topi}, _locals)
return _locals['ret']
target = tvm.target.Target("cuda")
task = auto_scheduler.SearchTask(func=func_name, args=(), target=target)
# Inspect the computational graph
print("Computational DAG:")
print(task.compute_dag)
log_file = f"ansor_{func_name}_log.json"
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
num_measure_trials=10,
runner=measure_ctx.runner,
measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
verbose=2,
)
# Run auto-tuning (search)
task.tune(tune_option)
# Apply the best schedule
sch, args = task.apply_best(log_file)
# Kill the measurement process
del measure_ctx
func = tvm.build(sch, args, target, name=func_name)
func.export_library(so_fn)
ctx = tvm.cuda(0)
input_a = []
for i, (shape, dtype) in enumerate(zip(input_tensors, input_dtypes)):
a_np = np.random.uniform(size=shape).astype(dtype)
input_a.append(tvm.nd.array(a_np, ctx))
a_out = tvm.nd.array(np.zeros(output_tensor, dtype=output_dtype), ctx)
func(a_out, *input_a)
evaluator = func.time_evaluator(func.entry_name, ctx, number=100)
conv_time = evaluator(a_out, *input_a).mean * 1e3
print("====NNET tensor expression====")
print(nnet_expression+"\n")
print("====NNET simplified tensor expression====")
print(nnet_simplified_expression+"\n")
print("====Time====")
print(conv_time)
if USE_CACHE and hash_code is not None:
with open(config_fn, "w") as config_fout:
config_fout.write(json.dumps({
"conv_time": conv_time,
}, ensure_ascii=False, indent=2))
return so_fn, conv_time

View File

@ -11,13 +11,11 @@ GraphObj::GraphObj(Runtime runtime, OpVec ops_in)
for (const auto &op : ops_in) {
for (const auto &t : op->getInputs())
if (tensorPool.find(t->getFuid()) == tensorPool.end())
tensorPool[t->getFuid()] = t->clone();
tensorPool[t->getFuid()] = cloneTensor(t);
for (const auto &t : op->getOutputs())
if (tensorPool.find(t->getFuid()) == tensorPool.end())
tensorPool[t->getFuid()] = t->clone();
tensorPool[t->getFuid()] = cloneTensor(t);
}
for (const auto &[_, t] : tensorPool)
addTensor(t);
// Clone operators and add connections
for (const auto &op : ops_in) {
TensorVec inputs, outputs;
@ -127,8 +125,12 @@ Tensor GraphObj::addTensor(Shape dim, DataType dtype) {
}
Tensor GraphObj::addTensor(const Tensor &tensor) {
IT_ASSERT(tensor->getRuntime() == runtime, "Tensor runtime mismatch");
return tensors.emplace_back(tensor);
IT_ASSERT(tensor->getRuntime() == runtime,
std::string("Tensor runtime mismatch: cannot add a tenosr in ") +
tensor->getRuntime()->toString() + " to " +
runtime->toString());
tensors.emplace_back(tensor);
return tensor;
}
TensorVec GraphObj::addTensor(const TensorVec &tensors) {
@ -207,6 +209,13 @@ bool GraphObj::checkValid() const {
IT_ASSERT(std::find(ops.begin(), ops.end(), suc) != ops.end());
}
}
std::set<UidBaseType> s;
// check whether two tensors with the same FUID exist
for (auto tensor : tensors) {
int cnt = s.count(tensor->getFuid());
IT_ASSERT(cnt == 0, std::to_string(tensor->getFuid()));
s.insert(tensor->getFuid());
}
return true;
}

View File

@ -24,8 +24,8 @@ bool OperatorObj::isConcatOp() const { return type == OpType::Concat; }
bool OperatorObj::isComputeOp() const {
return type == OpType::Conv || type == OpType::Matmul ||
type == OpType::ConvTrans || type == OpType::G2BMM ||
type == OpType::GBMM;
type == OpType::ConvTrans || type == OpType::ConvTransNHWC ||
type == OpType::G2BMM || type == OpType::GBMM;
}
bool OperatorObj::isTransposeOp() const { return type == OpType::Transpose; }
@ -92,7 +92,7 @@ bool OperatorObj::checkValid(GraphObj *graph) {
if (graph) { // if graph != nullptr, outputs should be created
auto dataTypes = inferDataType();
for (size_t i = 0; i < outputs.size(); i++) {
IT_ASSERT(!outputs[i]);
IT_ASSERT(!outputs[i], "Find empty output while operator creation");
outputs[i] = graph->addTensor(shapes[i], dataTypes[i]);
}
} else { // if outputs have been created, check their shapes

View File

@ -16,9 +16,16 @@ TensorObj::TensorObj(Shape shape_, DataType dtype, Runtime runtime)
[](auto acc, auto x) { return acc * x; })) {}
string TensorObj::toString() const {
// Convert data pointer to string
std::stringstream ss;
if (data != nullptr)
ss << data->getPtr<void *>();
else
ss << "nullptr data";
string ret = "Tensor " + std::to_string(guid) + ", Fuid " +
std::to_string(fuid) + ", shape " + vecToString(shape) +
", dtype " + dtype.toString();
", dtype " + dtype.toString() + ", " + runtime->toString() +
", " + ss.str() + "\n";
vector<UidBaseType> targetGuids;
for (const auto &op : targets)
targetGuids.emplace_back(op.lock()->getGuid());
@ -57,25 +64,36 @@ vector<size_t> TensorObj::getStride() const {
void TensorObj::printData() const {
IT_ASSERT(data != nullptr);
if (!runtime->isCpu())
IT_TODO_HALT();
void *ptr = nullptr;
Blob buffer;
if (!runtime->isCpu()) {
buffer = NativeCpuRuntimeObj::getInstance()->allocBlob(getBytes());
runtime->copyBlobToCPU(buffer->getPtr<void *>(),
getRawDataPtr<void *>(), getBytes());
ptr = buffer->getPtr<void *>();
} else
ptr = data->getPtr<float *>();
if (dtype == DataType::Float32)
printDataFloat();
printDataFloat(static_cast<float *>(ptr));
else if (dtype == DataType::UInt32)
printDataUint32_t();
printDataUint32_t(static_cast<uint32_t *>(ptr));
else
IT_TODO_HALT();
}
void TensorObj::printDataFloat() const {
void TensorObj::printDataFloat(float *ptr) const {
std::cout << "Tensor: " << guid << std::endl;
auto numDims = shape.size();
auto dimSzVec = std::vector<int>(numDims, 1);
auto ptr = data->getPtr<float *>();
dimSzVec[numDims - 1] = shape[numDims - 1];
for (int i = numDims - 1; i != 0; --i)
dimSzVec[i - 1] = dimSzVec[i] * shape[i - 1];
for (size_t i = 0, iEnd = size(); i < iEnd; ++i) {
if (iEnd > 1000 && i > 20 && i < iEnd - 20) {
printf("... , ");
i = iEnd - 20;
continue;
}
for (size_t j = 0; j < numDims; ++j) {
if (i % dimSzVec[j] == 0) {
std::cout << "[";
@ -94,12 +112,11 @@ void TensorObj::printDataFloat() const {
}
}
void TensorObj::printDataUint32_t() const {
void TensorObj::printDataUint32_t(uint32_t *ptr) const {
IT_ASSERT(data != nullptr);
std::cout << "Tensor: " << guid << std::endl;
auto numDims = shape.size();
auto dimSzVec = std::vector<int>(numDims, 1);
auto ptr = data->getPtr<VType *>();
dimSzVec[numDims - 1] = shape[numDims - 1];
for (int i = numDims - 1; i != 0; --i)
dimSzVec[i - 1] = dimSzVec[i] * shape[i - 1];
@ -122,7 +139,7 @@ void TensorObj::printDataUint32_t() const {
}
}
bool TensorObj::equalData(const Tensor &rhs) const {
bool TensorObj::equalData(const Tensor &rhs, double relativeError) const {
IT_ASSERT(data != nullptr);
IT_ASSERT(rhs->data != nullptr);
IT_ASSERT(getDType() == rhs->getDType());
@ -132,10 +149,11 @@ bool TensorObj::equalData(const Tensor &rhs) const {
return false;
if (getDType() == DataType::UInt32)
return equalDataImpl(getRawDataPtr<uint32_t *>(),
rhs->getRawDataPtr<uint32_t *>(), size());
rhs->getRawDataPtr<uint32_t *>(), size(), 0);
else if (getDType() == DataType::Float32)
return equalDataImpl(getRawDataPtr<float *>(),
rhs->getRawDataPtr<float *>(), size());
rhs->getRawDataPtr<float *>(), size(),
relativeError);
else
IT_TODO_HALT();
}

View File

@ -50,10 +50,23 @@ class matmulCublas : public Kernel {
// TODO:use compute type
cublasStatus_t stat;
if (b > 1) {
// Support batch broadcast with zero stride
int dimA = op->getInputs(0)->getDims().size();
int dimB = op->getInputs(1)->getDims().size();
long long strideA =
(dimA == 2 ||
(dimA == 3 && op->getInputs(0)->getDims()[0] == 1))
? 0 // Broadcast the batch dimension if batch size is 1
: m * k;
long long strideB =
(dimB == 2 ||
(dimB == 3 && op->getInputs(1)->getDims()[0] == 1))
? 0 // Broadcast the batch dimension if batch size is 1
: n * k;
stat = cublasGemmStridedBatchedEx(
context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
CUDA_R_32F, ldb, k * n, inAData, CUDA_R_32F, lda, m * k, &beta,
outData, CUDA_R_32F, ldc, m * n, b, CUDA_R_32F,
CUDA_R_32F, ldb, strideB, inAData, CUDA_R_32F, lda, strideA,
&beta, outData, CUDA_R_32F, ldc, m * n, b, CUDA_R_32F,
(cublasGemmAlgo_t)record->algo);
} else {
stat = cublasGemmEx(
@ -61,6 +74,8 @@ class matmulCublas : public Kernel {
CUDA_R_32F, ldb, inAData, CUDA_R_32F, lda, &beta, outData,
CUDA_R_32F, ldc, CUDA_R_32F, (cublasGemmAlgo_t)record->algo);
}
// if (stat != CUBLAS_STATUS_SUCCESS)
// cout << cublasGetErrorString(stat);
return (stat == CUBLAS_STATUS_SUCCESS);
}
@ -79,6 +94,8 @@ class matmulCublas : public Kernel {
const RuntimeObj *_context) const override {
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
auto op = as<MatmulObj>(_op);
IT_ASSERT(context);
IT_ASSERT(op);
auto ret = make_ref<MatmulCublasPerfRecordObj>();
ret->time = std::numeric_limits<double>::max();
for (int i = 0; i < N_ALGO; i++) {
@ -91,9 +108,8 @@ class matmulCublas : public Kernel {
if (rcd->time < ret->time)
ret = rcd;
}
IT_ASSERT(ret->time < std::numeric_limits<double>::max(), "No valid "
"algorithm "
"found");
IT_ASSERT(ret->time < std::numeric_limits<double>::max(),
"No valid algorithm found for " + op->toString());
return ret;
}
};

View File

@ -1,7 +1,11 @@
#ifdef INFINI_USE_TVM
#include "core/kernel.h"
#include "cuda/cuda_runtime.h"
#include "ffi/ffi_embed.h"
#include "nnet/Visitor/AsTVMVisitor.h"
#include "nnet/Visitor/CheckOOBVisitor.h"
#include "nnet/Visitor/HashVisitor.h"
#include "nnet/Visitor/MergeMemboundMutator.h"
#include "nvrtc.h"
#include "operators/membound.h"
#include "operators/pooling.h"
@ -17,11 +21,12 @@ class TVMRecordObj : public PerfRecordObj {
std::string log, ptx;
std::vector<int> invokeParams;
std::string kernelName;
HashType simplifiedExprHash;
};
using TVMRecord = Ref<TVMRecordObj>;
class MemboundTVM : public Kernel {
class MemboundTVMExtractSource : public Kernel {
public:
void compute(const Operator &_op, const PerfRecord &record,
const RuntimeObj *_context) const override {
@ -65,6 +70,11 @@ class MemboundTVM : public Kernel {
return "var_" + std::to_string(t->getGuid());
}
bool checkOOB(nnet::Expr expr) const {
return nnet::CheckOOBVisitor().checkRangeOp(
nnet::as<nnet::RangeOpNode>(expr));
}
// Premise: op is idempotent since it is called multiple times.
PerfRecord tune(const Operator &_op,
const RuntimeObj *_context) const override {
@ -73,10 +83,18 @@ class MemboundTVM : public Kernel {
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
// invoke Ansor to tune a membound kernel
std::string func = "mem_bound_" + std::to_string(op->getGuid());
std::string kernelName = func + "_kernel0";
nnet::AsTVMVisitor visitor;
visitor.dispatch(op->getNnetExpr());
IT_ASSERT(!checkOOB(op->getNnetExpr()));
// fuse stages in nnet expr to reduce kernels generated by TVM
auto expr = op->getNnetExpr();
if (auto mergedExpr =
nnet::MergeMemboundMutator({expr}).merge(false, true))
expr = mergedExpr;
nnet::HashVisitor hashVisitor;
HashType hashCode = hashVisitor.getHash(expr);
visitor.dispatch(expr);
auto &&stmts = visitor.getStmts();
auto &&inShapes = visitor.getInputShapes();
auto &&outShape = visitor.getOutputShape();
@ -85,10 +103,14 @@ class MemboundTVM : public Kernel {
for (auto &&in : op->getInputs()) {
inputs.emplace_back(getVarName(in));
}
std::string output = getVarName(op->getOutput());
const std::string output = getVarName(op->getOutput());
const std::string func = "membound_" + std::to_string(hashCode);
const std::string kernelName = func + "_kernel0";
auto res = getAnsorCode(
inShapes, std::vector<std::string>(inShapes.size(), "float32"),
outShape, "float32", stmts, func, inputs, output);
outShape, "float32", stmts, func, inputs, output, op->toString(),
expr->toReadable(), hashCode);
// compile the kernel
auto funcCode = res.first;
@ -119,6 +141,7 @@ class MemboundTVM : public Kernel {
nvrtcGetPTX(prog, ret->ptx.data());
ret->invokeParams = invokeParams;
ret->kernelName = kernelName;
ret->simplifiedExprHash = hashCode;
// prepare for evaluation
CUmodule module;
@ -151,20 +174,43 @@ class MemboundTVM : public Kernel {
return std::dynamic_pointer_cast<PerfRecordObj>(ret);
}
/// @brief
/// @param inDims
/// @param inDTypes
/// @param outDims
/// @param outDType
/// @param lambda
/// @param funcName Generated function name
/// @param inputNames Input array names in the generated invocation code.
/// @param outputName Output array names in the generated invocation code.
/// @param nnetExpressionString Save expr in string for logging.
/// @param nnetSimplifiedExprString Save simplified expr in string for
/// logging.
/// @param hashCode (optional) Hash code of the input expression for kernel
/// cache.
/// @return
std::pair<std::string, std::vector<int>>
getAnsorCode(const std::vector<std::vector<int>> &inDims,
const std::vector<std::string> &inDTypes,
const std::vector<int> &outDims, const std::string &outDType,
const std::string &lambda, const std::string &funcName,
const std::vector<std::string> &inputNames,
const std::string &outputName) const {
const std::string &outputName,
const std::string &nnetExprString,
const std::string &nnetSimplifiedExprString,
const HashType hashCode) const {
std::string funcCode;
std::vector<int> invokeParams;
try {
start_interpreter();
auto func = py::module::import("cpp_plugin").attr("gen_ansor_op");
py::tuple code = func(inDims, inDTypes, outDims, outDType, lambda,
funcName, inputNames, outputName);
// Use static to avoid re-importing the module. Re-importing results
// in cuBLAS failure, whose root cause is not identified yet.
static auto func =
py::module::import("cpp_plugin").attr("gen_ansor_op");
py::tuple code =
func(inDims, inDTypes, outDims, outDType, lambda, funcName,
inputNames, outputName, nnetExprString,
nnetSimplifiedExprString, std::to_string(hashCode));
funcCode = py::str(code[0]);
auto temp = py::list(code[3]);
for (int i = 0; i < 6; ++i) {
@ -183,6 +229,9 @@ class MemboundTVM : public Kernel {
}
};
REGISTER_KERNEL(Device::CUDA, OpType::MemBound, DataType::Float32, MemboundTVM,
"Memobund_TVM_Ansor");
// REGISTER_KERNEL(Device::CUDA, OpType::MemBound, DataType::Float32,
// MemboundTVMExtractSource,
// "Memobund_TVM_Ansor_extract_source");
}; // namespace infini
#endif

View File

@ -0,0 +1,224 @@
#ifdef INFINI_USE_TVM
#include "core/kernel.h"
#include "cuda/cuda_runtime.h"
#include "dlpack/dlpack.h"
#include "ffi/ffi_embed.h"
#include "nnet/Visitor/AsTVMVisitor.h"
#include "operators/membound.h"
#include "operators/pooling.h"
#include "tvm/runtime/module.h"
#include "tvm/runtime/packed_func.h"
namespace py = pybind11;
namespace infini {
using DLTensorHolder = pair<DLTensor, Ref<vector<int64_t>>>;
class TVMRecordObj : public PerfRecordObj {
public:
std::string kernelName;
HashType simplifiedExprHash;
std::string dllPath;
std::string funcName;
std::vector<int> inputIdx;
};
using TVMRecord = Ref<TVMRecordObj>;
class MemboundTVMPackedFunction : public Kernel {
public:
void compute(const Operator &_op, const PerfRecord &record,
const RuntimeObj *_context) const override {
auto op = as<MemBoundObj>(_op);
// auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
auto tvmRecord = std::dynamic_pointer_cast<TVMRecordObj>(record);
tvm::runtime::PackedFunc packedFunc =
getPackedFunction(tvmRecord->dllPath, tvmRecord->funcName);
IT_ASSERT(packedFunc != nullptr);
// prepare inputs and outputs
vector<DLTensorHolder> inputsHolder;
for (auto idx : tvmRecord->inputIdx) {
inputsHolder.emplace_back(
convertTensorToDLTensor(op->getInputs()[idx]));
}
DLTensorHolder outputHolder = convertTensorToDLTensor(op->getOutput());
// make tvm arg and rv
pair<vector<TVMValue>, vector<int>> preArgs =
convertInOutToTVMArgs(inputsHolder, outputHolder);
tvm::runtime::TVMRetValue rv;
tvm::runtime::TVMArgs args(preArgs.first.data(), preArgs.second.data(),
preArgs.first.size());
packedFunc.CallPacked(args, &rv);
}
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
IT_ASSERT(false, "A TVM record is required for membound kernel.");
}
// Premise: op is idempotent since it is called multiple times.
PerfRecord tune(const Operator &_op,
const RuntimeObj *_context) const override {
TVMRecord ret = std::make_shared<TVMRecordObj>();
auto op = as<MemBoundObj>(_op);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
// invoke Ansor to tune a membound kernel
auto [expr, hash] = op->getSimplifiedNnetExpr();
nnet::AsTVMVisitor visitor;
visitor.dispatch(expr);
auto &&stmts = visitor.getStmts();
auto &&inShapes = visitor.getInputShapes();
auto &&outShape = visitor.getOutputShape();
const std::string func = "membound_" + std::to_string(hash);
const std::string kernelName = func + "_kernel0";
// Set the dllPath directly when debugging
auto dllPath = getAnsorDLL(
inShapes, std::vector<std::string>(inShapes.size(), "float32"),
outShape, "float32", stmts, func, op->toString(),
expr->toReadable(), hash);
// remap input
vector<int> inputIdx;
int numInputs = op->getInputs().size();
for (int i = 0; i < numInputs; ++i) {
string inputName = visitor.getInputs()[i];
int j = 0;
for (; j < numInputs; ++j) {
if (inputName == op->getNnetInputs()[j]->getName())
break;
}
inputIdx.emplace_back(j);
}
tvm::runtime::PackedFunc packedFunc = getPackedFunction(dllPath, func);
IT_ASSERT(packedFunc != nullptr);
// prepare inputs and outputs
vector<DLTensorHolder> inputsHolder;
for (auto idx : inputIdx) {
inputsHolder.emplace_back(
convertTensorToDLTensor(op->getInputs()[idx]));
}
DLTensorHolder outputHolder = convertTensorToDLTensor(op->getOutput());
// make tvm arg and rv
pair<vector<TVMValue>, vector<int>> preArgs =
convertInOutToTVMArgs(inputsHolder, outputHolder);
tvm::runtime::TVMRetValue rv;
tvm::runtime::TVMArgs args(preArgs.first.data(), preArgs.second.data(),
preArgs.first.size());
ret->time = timeit([&]() { packedFunc.CallPacked(args, &rv); },
[&]() { context->sync(); });
ret->kernelName = kernelName;
ret->dllPath = dllPath;
ret->funcName = func;
ret->inputIdx = inputIdx;
return std::dynamic_pointer_cast<PerfRecordObj>(ret);
}
/// @brief
/// @param inDims
/// @param inDTypes
/// @param outDims
/// @param outDType
/// @param lambda
/// @param funcName Generated function name
/// @param nnetExpressionString Save expr in string for logging.
/// @param nnetSimplifiedExprString Save simplified expr in string for
/// logging.
/// @param hashCode (optional) Hash code of the input expression for kernel
/// cache.
/// @return
std::string getAnsorDLL(const std::vector<std::vector<int>> &inDims,
const std::vector<std::string> &inDTypes,
const std::vector<int> &outDims,
const std::string &outDType,
const std::string &lambda,
const std::string &funcName,
const std::string &nnetExprString,
const std::string &nnetSimplifiedExprString,
const HashType hashCode) const {
std::string dllPath;
try {
start_interpreter();
// Use static to avoid re-importing the module. Re-importing results
// in cuBLAS failure, whose root cause is not identified yet.
static auto func =
py::module::import("cpp_plugin").attr("gen_ansor_so");
py::tuple code =
func(inDims, inDTypes, outDims, outDType, lambda, funcName,
nnetExprString, nnetSimplifiedExprString,
std::to_string(hashCode));
dllPath = py::str(code[0]);
} catch (py::error_already_set &e) {
if (e.matches(PyExc_ImportError)) {
std::cerr << "Import Error. Don't forget to set environment "
"variable PYTHONPATH to contain "
"<repo-root>/python"
<< std::endl;
}
throw;
}
return dllPath;
}
tvm::runtime::PackedFunc getPackedFunction(string path,
string functionName) const {
tvm::runtime::Module mod = tvm::runtime::Module::LoadFromFile(path);
return mod.GetFunction(functionName);
}
DLTensorHolder convertTensorToDLTensor(const Tensor &tensor) const {
IT_ASSERT(tensor->getRuntime()->isCuda());
// The lifecycle of shapeInt64 is managed by the caller.
auto shapeInt64 = make_ref<vector<int64_t>>();
for (auto v : tensor->getDims())
shapeInt64->push_back(v);
DLTensor ret{
.data = tensor->getRawDataPtr<void *>(),
.device = DLDevice{.device_type = kDLCUDA, .device_id = 0},
.ndim = (int32_t)shapeInt64->size(),
.dtype =
DLDataType{.code = (uint8_t)kDLFloat, .bits = 32, .lanes = 1},
.shape = static_cast<int64_t *>(shapeInt64->data()),
.strides = nullptr,
.byte_offset = 0,
};
return {ret, shapeInt64};
}
pair<vector<TVMValue>, vector<int>>
convertInOutToTVMArgs(const vector<DLTensorHolder> &inputs,
const DLTensorHolder &output) const {
vector<TVMValue> values;
vector<int> type_codes;
// The order of inputs and outputs is consistant with definition of TVM
// computation in Python, which is determined by AsTVMVisitor.
values.emplace_back(TVMValue{.v_handle = (void *)&output.first});
type_codes.emplace_back(kTVMDLTensorHandle);
for (auto &in : inputs) {
values.emplace_back(TVMValue{.v_handle = (void *)&in.first});
type_codes.emplace_back(kTVMDLTensorHandle);
}
return {values, type_codes};
}
};
REGISTER_KERNEL(Device::CUDA, OpType::MemBound, DataType::Float32,
MemboundTVMPackedFunction,
"Memobund_TVM_Ansor_packed_funciton");
}; // namespace infini
#endif

View File

@ -27,13 +27,16 @@ std::string AsTVMVisitor::visit_(const BinaryOp &c) {
}
}
std::string AsTVMVisitor::visit_(const Func &c) {
string nested = dispatch(c->getObject());
switch (c->getFuncType()) {
case FuncType::Relu:
// TODO: Deduce the dtype
return "te.max(" + dispatch(c->getObject()) +
", tvm.tir.const(0, 'float32'))";
return "te.max(" + nested + ", tvm.tir.const(0, 'float32'))";
case FuncType::Tanh:
return "te.tanh(" + dispatch(c->getObject()) + ")";
return "te.tanh(" + nested + ")";
case FuncType::PRelu:
return "tir.if_then_else(0.0 < " + nested + ", " + nested +
", (0.25 * " + nested + "))";
default:
assert(false);
}
@ -114,6 +117,11 @@ std::string AsTVMVisitor::visit_(const Subscript &c) {
str += " - " +
std::to_string(rangeOp->getLoopVarRanges()[i].second.first -
rangeOp->getPaddings(i));
} else if (c->getObject()->getType() == NodeType::TensorNodeType) {
auto tensor = as<TensorNode>(c->getObject());
if (auto pad_i = tensor->getPadding(i); pad_i > 0) {
str += " + " + std::to_string(pad_i);
}
}
}
str += "]";
@ -138,6 +146,24 @@ std::string AsTVMVisitor::visit_(const Tensor &c) {
}
stmt += "), name='" + c->getName() + "')";
stmts += stmt + "\n";
if (c->hasPadding()) {
std::string name_after_pad = "pad_" + c->getName();
pythonVars.emplace_back(name_after_pad);
// inputs.emplace_back(name_after_pad);
std::string pad_tuple = "(";
for (auto pad : c->getPaddings()) {
pad_tuple += std::to_string(pad) + ", ";
}
pad_tuple += ")";
std::string pad_stmt = name_after_pad + " = " + "topi.nn.pad(" +
c->getName() + ", " + pad_tuple + ", " +
pad_tuple + ", 0.0, \"" + name_after_pad + "\")";
stmts += pad_stmt + "\n";
return name_after_pad;
}
return c->getName();
}
std::string AsTVMVisitor::getStmts() const {

View File

@ -5,7 +5,7 @@
namespace nnet {
Expr MergeMemboundMutator::merge(bool allowEmptyMembound) {
Expr MergeMemboundMutator::merge(bool allowEmptyMembound, bool allowFailure) {
// FIXME: fix empty expression in membound
assert(kernels.size() >= 1);
if (checkEmpty()) {
@ -27,20 +27,31 @@ Expr MergeMemboundMutator::merge(bool allowEmptyMembound) {
assert(CheckOOBVisitor().checkRangeOp(curRangeOp) == false);
auto summand = curRangeOp->getSummand();
if (auto subscriptOp = as<SubscriptNode>(summand)) {
// Try merging the current and next stages
if (auto mergedExpr = rule4StageMerging(*curExpr, true)) {
// dbg(*curExpr, mergedExpr);
*curExpr = mergedExpr;
merged = true;
break;
}
// If merging fails, try the next stage
curExpr = subscriptOp->getObjectPtr();
nnet_assert(*curExpr != nullptr, __LINE__);
} else if (auto funcOp = as<FuncNode>(summand)) {
// Relu({...}[i,j])
curExpr = funcOp->getObject()->getObjectPtr();
} else
// If the object of FuncNode is a subscript, like
// Relu({...}[i,j]), we can further merge it. Otherwise, like
// Relu(A[i]+B[j]), we cannot.
if (auto sub = as<SubscriptNode>(funcOp->getObject()))
curExpr = sub->getObjectPtr();
else
break;
} else {
if (allowFailure)
return nullptr;
else
nnet_unimplemented_halt();
}
}
} while (merged);
return expr;
}

View File

@ -153,4 +153,11 @@ HashType HashVisitor::visit_(const Var &c) {
return varHash[c];
}
HashType HashVisitor::visit_(const Func &c) {
HashType objHash = dispatch(c->getObject());
return hash(binPrefix,
hash((((HashType)c->getFuncType()) + 10086), objHash));
return 0;
}
} // namespace nnet

View File

@ -90,6 +90,14 @@ size_t TensorNode::getOffset(const vector<int> &idx) {
return offset;
}
bool TensorNode::hasPadding() {
for (auto pad : paddings) {
if (pad > 0)
return true;
}
return false;
}
string RangeOpNode::toReadable() const {
string ret;
for (int i = 0; i < IterationType::NumIterationType; ++i) {
@ -264,10 +272,15 @@ string FuncNode::toReadable() const {
ret += "Relu";
else if (funcType == FuncType::Tanh)
ret += "Tanh";
else if (funcType == FuncType::PRelu)
ret += "PRelu";
else
nnet_unimplemented_halt();
ret += "( ... " + serializeVec(object->getIndex()) + ")\n {" +
object->getObject()->toReadable() + "}";
if (auto sub = as<SubscriptNode>(object))
ret += "( ... " + serializeVec(sub->getIndex()) + ")\n {" +
sub->getObject()->toReadable() + "}";
else
ret += "(" + object->toReadable() + ")";
return ret;
}
@ -380,6 +393,7 @@ int64_t TensorNode::getSize() const {
size *= len;
return size;
}
int RangeOpNode::getPaddings(int dim) const {
return dim < (int)paddings.size() ? paddings[dim] : 0;
}
@ -445,8 +459,8 @@ vector<Range> RangeOpNode::getOutputRanges() const {
}
void FuncNode::setObject(Expr e) {
object = as<SubscriptNode>(e);
nnet_assert(object, "Illegal subscripted object");
nnet_assert(e->isScalar(), "FuncNode operates on scalars");
object = e;
}
} // namespace nnet

View File

@ -574,7 +574,8 @@ Expr ConvTransPattern::getExpr(Tensor A, Tensor K, int N, int C, int H, int W,
auto subA = makeSubscript(A, {n, x1 + r - 1, y1 + s - 1, f});
auto subK =
makeSubscript(K, {(R - 2) - 2 * r + x2, (S - 2) - 2 * s + y2, f, c});
// makeSubscript(K, {(R - 2) - 2 * r + x2, (S - 2) - 2 * s + y2, f, c});
makeSubscript(K, {f, (R - 2) - 2 * r + x2, (S - 2) - 2 * s + y2, c});
// x1=(h+1)//2, x2=(h+1)%2, y1=(w+1)//2
auto range1 = makeRangeOperator(

View File

@ -7,15 +7,18 @@
#include "operators/conv.h"
#include "operators/matmul.h"
#include "operators/membound.h"
#include "operators/reshape.h"
namespace infini {
NMutator::NMutator(Mode mode) : Mutator(10), mode{mode} {
IT_ASSERT(mode != Mode::RuleBased, "Use RuleBased in the other ctor.");
IT_ASSERT(mode != Mode::RuleBased, "Specify rules for the RuleBased mode.");
}
NMutator::NMutator(const std::vector<int> &derivationRules)
: Mutator(10), mode{Mode::RuleBased}, derivationRules{derivationRules} {}
NMutator::NMutator(Mode mode, const std::vector<int> &derivationRules)
: Mutator(10), mode{Mode::RuleBased}, derivationRules{derivationRules} {
IT_ASSERT(mode == Mode::RuleBased);
}
NMutator::~NMutator() {}
@ -69,9 +72,10 @@ void NMutator::runSingleOpToNaiveMembound(Graph in_graph,
}
void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
IT_TODO_HALT();
// OpVec computeOps = in_graph->getComputeOps();
// if (infini::Graph g = transformTConv1x1(computeOps[0])) {
OpVec computeOps = in_graph->getComputeOps();
IT_ASSERT(computeOps.size() == 1);
/* if (infini::Graph g = transformTConv1x1(computeOps[0])) {
// out_graphs.emplace_back(g);
// return;
// }
@ -95,39 +99,40 @@ void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
// // out_graphs.emplace_back(graph);
// // return;
// // }
*/
// auto expr = opToExpression(computeOps[0]);
// if (!expr)
// return;
auto expr = opToExpression(computeOps[0]);
if (!expr)
return;
// nnet::Derivator derivator(maxDepth);
// nnet::Formula conv_9x9(expr, 0);
// // const std::vector<int> rules{3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90}; //
// Tconv
// // const std::vector<int> rules{1, 7, 7, 2, 8, 6, 6}; // G2BMM
// if (mode == Mode::Normal) {
// derivator.search(conv_9x9, 0);
// } else if (mode == Mode::RuleBased) {
// dbg(derivationRules);
// derivator.ruleBasedDFS(conv_9x9, 0, derivationRules);
// } else
// nnet_assert(0, "Unknown mode");
// const auto &candidates = derivator.getCandidates();
// dbg(candidates.size());
// // derivator.print();
// for (const auto &candidate : candidates) {
// // dbg(nnet::FullPrinterVisitor().print(candidate.root));
// if (auto g = expressionToGraph(candidate.root, in_graph)) {
// out_graphs.emplace_back(g);
nnet::Derivator derivator(maxDepth);
nnet::Formula conv_9x9(expr, 0);
// const std::vector<int> rules{3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90};
// ConvTraspose
// const std::vector<int> rules{1, 7, 7, 2, 8, 6, 6}; // G2BMM
if (mode == Mode::Normal) {
derivator.search(conv_9x9, 0);
} else if (mode == Mode::RuleBased) {
dbg(derivationRules);
derivator.ruleBasedDFS(conv_9x9, 0, derivationRules);
} else
IT_TODO_HALT_MSG("Unknown NMutator search mode.");
const auto &candidates = derivator.getCandidates();
dbg(candidates.size());
// derivator.print();
for (const auto &candidate : candidates) {
// dbg(nnet::FullPrinterVisitor().print(candidate.root));
if (auto g = expressionToGraph(candidate.root, in_graph)) {
out_graphs.emplace_back(g);
}
// break; // HACK:Debug only for the first subgraph
}
// dbg(out_graphs);
// for (auto graph : out_graphs) {
// graph->print();
// }
// // break; // HACK:Debug only for the first subgraph
// }
// // dbg(out_graphs);
// // for (auto graph : out_graphs) {
// // graph->print();
// // }
// cntStates += derivator.getNumIntermediateStates();
// cntCandidates += derivator.getNumCandidates();
cntStates += derivator.getNumIntermediateStates();
cntCandidates += derivator.getNumCandidates();
}
void NMutator::runMultipleOps(Graph in_graph, std::vector<Graph> &out_graphs) {
@ -245,7 +250,7 @@ nnet::Expr NMutator::opToExpression(Operator op) {
std::vector<int>{0, 0, ph, pw});
const auto K = nnet::makeTensor("K", KT->getDims());
return nnet::ConvPattern::getExpr(A, K, n, c, h, w, f, r, s);
} else if (auto convOp = as<ConvTransposed2dObj>(op)) {
} else if (auto convOp = as<ConvTransposed2dNHWCObj>(op)) {
const auto &AT = convOp->getInputs()[0];
const auto &KT = convOp->getInputs()[1];
inputsNameNToTensorT["A"] = AT;
@ -304,77 +309,97 @@ nnet::Expr NMutator::opToExpression(Operator op) {
}
infini::Graph NMutator::expressionToGraph(nnet::Expr expr, Graph in_graph) {
IT_TODO_HALT();
// auto g = make_ref<GraphObj>();
// nnet::FullPrinterVisitor fullVisitor;
// const auto &tensorQueueN = fullVisitor.traverse(expr);
// // Build tensors: Skip the first one, which is output
// auto nameNToTensorT = inputsNameNToTensorT;
// for (size_t i = 1; i < tensorQueueN.size(); ++i) {
// const auto &[nameN, routineN, tensorN] = tensorQueueN[i];
// // dbg(nameN, routineN, tensorN);
// if (!routineN) {
// // This is an inputs
// assert(nameNToTensorT.count(nameN));
// } else {
// assert(!nameNToTensorT.count(nameN));
// nameNToTensorT[nameN] = g->addTensor(tensorN->getShape());
// }
// }
// const auto &outputsPET = in_graph->getOutputs();
// if (outputsPET.size() != 1) {
// nnet_unimplemented_continue();
// return nullptr;
// }
// nameNToTensorT[std::get<0>(tensorQueueN.at(0))] = outputsPET[0];
// // Build computation graph in PET:
// for (int i = tensorQueueN.size() - 1; i >= 0; --i) {
// const auto &[outputNameN, routineN, tensorN] = tensorQueueN[i];
// if (!routineN)
// continue;
// // dbg(outputNameN, routineN, tensorN, routineN->getType());
// if (auto op = nnet::as<nnet::ConvNode>(routineN)) {
// // g->conv(i8, w9, 2, 2);
// std::vector<nnet::Tensor> inputsN = op->getInputs();
// auto A = nameNToTensorT.at(inputsN[0]->getName());
// auto K = nameNToTensorT.at(inputsN[1]->getName());
// auto output = nameNToTensorT.at(outputNameN);
// const auto &[ph, pw, sh, sw, dh, dw] = op->getArgs();
// g->conv(A, K, output, ph, pw, sh, sw, dh, dw);
// } else if (auto op = nnet::as<nnet::ElementWiseNode>(routineN)) {
// assert(op->getInputs().size() == 1);
// nnet::MatchReshapeVisitor matchReshapeVisitor;
// if (matchReshapeVisitor(op->getExpr())) {
// auto input =
// nameNToTensorT.at(op->getInputs().at(0)->getName());
// auto output = nameNToTensorT.at(outputNameN);
// g->reshape(input, output);
// } else {
// TensorVec inputsPET;
// TensorVec outputsPET = {nameNToTensorT.at(outputNameN)};
// for (const auto &inputN : op->getInputs())
// inputsPET.emplace_back(
// nameNToTensorT.at(inputN->getName()));
// // Re-estimate time here.
// ssize_t cnt = 0;
// for (const auto tensor : inputsPET)
// cnt += tensor->size();
// for (const auto tensor : outputsPET)
// cnt += tensor->size();
// g->membound(inputsPET, outputsPET, op->getInputs(),
// op->getExpr(), memboundTime(cnt));
// }
// } else if (auto op = nnet::as<nnet::MatmulNode>(routineN)) {
// assert(op->getInputs().size() == 2);
// nnet::Tensor AN = op->getInputs()[0];
// nnet::Tensor BN = op->getInputs()[1];
// TensorVec inputsPET = {nameNToTensorT.at(AN->getName()),
// nameNToTensorT.at(BN->getName())};
// TensorVec outputsPET = {nameNToTensorT.at(outputNameN)};
// const auto &[b, m, n, k, transa, transb] = op->getArgs();
// g->matmul(inputsPET[0], inputsPET[1], outputsPET[0], transa,
// transb);
// } else if (auto op = nnet::as<nnet::G2bmmNode>(routineN)) {
auto g = make_ref<GraphObj>(runtime);
nnet::FullPrinterVisitor fullVisitor;
// Get tensors in the reversed topological order
const auto &tensorQueueN = fullVisitor.traverse(expr);
dbg(fullVisitor.print(expr));
// Build a map: name in nnet -> tensors in infini
// Add input tensors to the map
std::map<std::string, Tensor> nameNToTensorT;
for (const auto &[k, v] : inputsNameNToTensorT)
nameNToTensorT[k] = g->cloneTensor(v);
// Add output tensors to the map
const auto &outputsT = in_graph->getOutputs();
if (outputsT.size() != 1) {
nnet_unimplemented_continue();
return nullptr;
}
nameNToTensorT[std::get<0>(tensorQueueN.at(0))] =
g->cloneTensor(outputsT[0]);
// Skip the first tensor, which is output and should be created by clone
for (size_t i = 1; i < tensorQueueN.size(); ++i) {
const auto &[nameN, routineN, tensorN] = tensorQueueN[i];
// dbg(nameN, routineN, tensorN);
if (!routineN) {
// this tensor is an input as it is not contrusted by a routine
IT_ASSERT(nameNToTensorT.count(nameN),
"Missing an input tensor in graph or a rountine for this "
"tensor.");
} else { // this tensor is an intermediate result
IT_ASSERT(!nameNToTensorT.count(nameN),
"An NNET tensor appears twice or it is an input tensor "
"with routine specified.");
nameNToTensorT[nameN] = g->addTensor(tensorN->getShape());
}
}
// Build computation graph in InfiniTensor
for (int i = tensorQueueN.size() - 1; i >= 0; --i) {
const auto &[outputNameN, routineN, tensorN] = tensorQueueN[i];
if (!routineN)
continue;
// dbg(outputNameN, routineN, tensorN, routineN->getType());
if (auto op = nnet::as<nnet::ConvNode>(routineN)) {
std::vector<nnet::Tensor> inputsN = op->getInputs();
auto A = nameNToTensorT.at(inputsN[0]->getName());
auto K = nameNToTensorT.at(inputsN[1]->getName());
auto output = nameNToTensorT.at(outputNameN);
const auto &[ph, pw, sh, sw, dh, dw] = op->getArgs();
g->addOpWithOutputs<ConvObj>(A, K, output, ph, pw, sh, sw, dh, dw);
} else if (auto op = nnet::as<nnet::ElementWiseNode>(routineN)) {
assert(op->getInputs().size() == 1);
nnet::MatchReshapeVisitor matchReshapeVisitor;
// If this routine only change the shape, translate it to a Reshape
if (matchReshapeVisitor(op->getExpr())) {
auto input =
nameNToTensorT.at(op->getInputs().at(0)->getName());
auto output = nameNToTensorT.at(outputNameN);
g->addOpWithOutputs<ReshapeObj>(input, output,
output->getDims());
} else {
TensorVec inputsPET;
TensorVec outputsPET = {nameNToTensorT.at(outputNameN)};
for (const auto &inputN : op->getInputs())
inputsPET.emplace_back(
nameNToTensorT.at(inputN->getName()));
// Re-estimate time here.
ssize_t cnt = 0;
for (const auto &tensor : inputsPET)
cnt += tensor->size();
for (const auto &tensor : outputsPET)
cnt += tensor->size();
dbg(inputsPET, outputsPET, op->getInputs(), op->getExpr(),
memboundTime(cnt));
g->addOpWithOutputs<MemBoundObj>(inputsPET, outputsPET,
op->getInputs(), op->getExpr(),
memboundTime(cnt));
}
} else if (auto op = nnet::as<nnet::MatmulNode>(routineN)) {
assert(op->getInputs().size() == 2);
nnet::Tensor AN = op->getInputs()[0];
nnet::Tensor BN = op->getInputs()[1];
TensorVec inputsPET = {nameNToTensorT.at(AN->getName()),
nameNToTensorT.at(BN->getName())};
TensorVec outputsPET = {nameNToTensorT.at(outputNameN)};
const auto &[b, m, n, k, transa, transb] = op->getArgs();
g->addOpWithOutputs<MatmulObj>(inputsPET[0], inputsPET[1],
outputsPET[0], transa, transb);
}
// TODO
// else if (auto op = nnet::as<nnet::G2bmmNode>(routineN)) {
// assert(op->getInputs().size() == 2);
// nnet::Tensor AN = op->getInputs()[0];
// nnet::Tensor BN = op->getInputs()[1];
@ -393,10 +418,10 @@ infini::Graph NMutator::expressionToGraph(nnet::Expr expr, Graph in_graph) {
// const auto &[b, m, w, n, dilation] = op->getArgs();
// g->gbmml(inputsPET[0], inputsPET[1], outputsPET[0], dilation);
// }
// }
// g->updateConnection();
// Graph graph = new Graph(g->getOperators());
// return graph;
else
IT_TODO_HALT();
}
return g;
}
double NMutator::memboundTime(ssize_t cnt) {

View File

@ -5,22 +5,24 @@ namespace infini {
MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
bool transB, [[maybe_unused]] Tensor bias, ActType act)
: OperatorObj(OpType::Matmul, {A, B}, {C}), transA(transA), transB(transB),
act(act), b(1) {
act(act) {
auto shape_a = A->getDims();
auto shape_b = B->getDims();
IT_ASSERT(shape_a.size() == shape_b.size());
switch (shape_a.size()) {
case 0:
case 1:
IT_ASSERT(false);
case 2:
break;
default:
int dimA = shape_a.size(), dimB = shape_b.size();
IT_ASSERT(dimA >= 2 && dimB >= 2);
b = 1;
if (dimA <= 3 && dimB <= 3) {
int b1 = dimA == 2 ? 1 : A->getDims()[0];
int b2 = dimB == 2 ? 1 : B->getDims()[0];
b = std::max(b1, b2);
} else {
IT_ASSERT_TODO(dimA == dimB);
for (size_t i = 0; i < shape_a.size() - 2; ++i) {
IT_ASSERT(shape_a[i] == shape_b[i]);
IT_ASSERT_TODO(shape_a[i] == shape_b[i]);
b *= shape_a[i];
}
break;
}
m = *(transA ? shape_a.rbegin() : shape_a.rbegin() + 1);
n = *(transB ? shape_b.rbegin() + 1 : shape_b.rbegin());
@ -38,6 +40,11 @@ string MatmulObj::toString() const {
}
optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) const {
auto A = inputs[0], B = inputs[1];
int dimA = A->getDims().size(), dimB = B->getDims().size();
if (dimA > 3 || dimB > 3) {
// no broadcast
auto shape_a = inputs[0]->getDims();
auto it = shape_a.rbegin();
*it++ = n;
@ -45,6 +52,34 @@ optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) const {
return {{std::move(shape_a)}};
}
int b1 = dimA == 2 ? 1 : A->getDims()[0];
int b2 = dimB == 2 ? 1 : B->getDims()[0];
int b = std::max(b1, b2);
int m = transA ? A->getDims()[dimA - 1] : A->getDims()[dimA - 2];
int n = transB ? B->getDims()[dimB - 2] : B->getDims()[dimB - 1];
int kA = transA ? A->getDims()[dimA - 2] : A->getDims()[dimA - 1];
int kB = transB ? B->getDims()[dimB - 1] : B->getDims()[dimB - 2];
if ((dimA != 2 && dimA != 3) || (dimB != 2 && dimB != 3)) {
printf("Bad input dim: dimA = %d, dimB = %d\n", dimA, dimB);
return {};
}
if (b1 != 1 && b2 != 1 && b1 != b2) {
printf("Bad batch size b1 = %d, b2 = %d\n", b1, b2);
return {};
}
if (kA != kB) {
printf("Bad K: kA = %d, kB = %d\n", kA, kB);
return {};
}
if (dimA == 2 && dimB == 2) {
return {{{m, n}}};
} else {
return {{{b, m, n}}};
}
}
vector<int> MatmulObj::getWorkloadVector() const {
return {enum_to_underlying(type), b, m, n, k, transA, transB,
enum_to_underlying(act)};

View File

@ -1,5 +1,7 @@
#include "operators/membound.h"
#include "nnet/Visitor/CheckOOBVisitor.h"
#include "nnet/Visitor/HashVisitor.h"
#include "nnet/Visitor/MergeMemboundMutator.h"
namespace infini {
@ -10,6 +12,19 @@ MemBoundObj::MemBoundObj(GraphObj *graph, const TensorVec &input,
: OperatorObj(OpType::MemBound, input, output), nnetInputs(nnetInputs),
expr(expr), exec_time(exec_time), hint(hint) {
IT_ASSERT(checkValid(graph));
IT_ASSERT(!checkOOB(expr));
hash = calcHash(expr);
// fuse stages in nnet expr to reduce kernels generated by TVM
if (auto mergedExpr =
nnet::MergeMemboundMutator({expr}).merge(false, true)) {
simplifiedExpr = mergedExpr;
IT_ASSERT(!checkOOB(simplifiedExpr));
simplifiedHash = calcHash(simplifiedExpr);
} else {
simplifiedExpr = expr;
simplifiedHash = hash;
}
}
string MemBoundObj::toString() const {
@ -31,8 +46,15 @@ string MemBoundObj::toString() const {
os << "NNet Inputs=[";
for (const auto &tensor : nnetInputs)
os << tensor->toReadable() << ",";
os << "])";
os << "\n" << (expr ? expr->toReadable() : "Empty expression") << "\n";
os << "]";
os << ", ExprHash=" << hash;
os << ", SimplifiedExprHash=" << simplifiedHash;
os << ")\n";
os << ">>> Original expr\n"
<< (expr ? expr->toReadable() : "Empty expression") << "\n";
os << ">>> Simplified expr\n"
<< (simplifiedExpr ? simplifiedExpr->toReadable() : "Empty expression")
<< "\n";
return os.str();
}
@ -47,13 +69,18 @@ optional<vector<Shape>> MemBoundObj::inferShape(const TensorVec &inputs) const {
}
vector<int> MemBoundObj::getWorkloadVector() const {
return {enum_to_underlying(type), (int)getHash()};
return {enum_to_underlying(type), (int)simplifiedHash};
}
vector<int> MemBoundObj::getOpAttrVector() const { return getWorkloadVector(); }
HashType MemBoundObj::getHash() const {
HashType MemBoundObj::calcHash(nnet::Expr expr) {
return nnet::HashVisitor().dispatch(expr);
}
bool MemBoundObj::checkOOB(nnet::Expr expr) {
return nnet::CheckOOBVisitor().checkRangeOp(
nnet::as<nnet::RangeOpNode>(expr));
}
} // namespace infini

View File

@ -8,7 +8,7 @@ namespace infini {
TEST(Hash, OperatorHash) {
OpPerfKey key1(0, OpType::Unknown), key2(0, OpType::Unknown);
{ // build with addOpWithOutputs
Graph g = make_ref<GraphObj>(nullptr);
Graph g = make_ref<GraphObj>(NativeCpuRuntimeObj::getInstance());
Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32);
Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
@ -18,7 +18,7 @@ TEST(Hash, OperatorHash) {
EXPECT_GT(key1.attrs.size(), (size_t)5);
}
{ // build with addOp
Graph g = make_ref<GraphObj>(nullptr);
Graph g = make_ref<GraphObj>(NativeCpuRuntimeObj::getInstance());
Tensor i0 = g->addTensor({2, 2, 3}, DataType::UInt32);
Tensor w0 = g->addTensor({2, 3, 4}, DataType::UInt32);
auto matmul = g->addOp<MatmulObj>(i0, w0, nullptr);

View File

@ -1,4 +1,3 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
@ -51,26 +50,38 @@ TEST(cuBLAS_Matmul, run) {
Shape{2, 3, 4}, Shape{2, 3, 2},
ExpectOutput{40, 52, 46, 61, 52, 70, 58, 79, 400, 448, 424,
475, 448, 502, 472, 529});
testMatmulCuda(
IncrementalGenerator(), IncrementalGenerator(), false, false,
Shape{2, 3, 5}, Shape{5, 2},
ExpectOutput{60, 70, 160, 195, 260, 320, 360, 445, 460, 570, 560, 695});
testMatmulCuda(IncrementalGenerator(), IncrementalGenerator(), true, false,
Shape{2, 5, 3}, Shape{5, 2},
ExpectOutput{180, 210, 200, 235, 220, 260, 480, 585, 500,
610, 520, 635});
testMatmulCuda(IncrementalGenerator(), IncrementalGenerator(), false, false,
Shape{3, 5}, Shape{5, 2},
ExpectOutput{60, 70, 160, 195, 260, 320});
}
TEST(cuBLAS_Matmul, tune) {
auto cpuRuntime = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(cpuRuntime);
auto ACpu = gCpu->addTensor(Shape{1, 3, 5}, DataType::Float32);
auto BCpu = gCpu->addTensor(Shape{1, 5, 2}, DataType::Float32);
gCpu->dataMalloc();
ACpu->setData(IncrementalGenerator());
BCpu->setData(IncrementalGenerator());
// Matmul([A^T,B,act=0],A=597,B=595,C=598,bmnk=[1,4,4096,448])
const int B = 1, M = 4, N = 4096, K = 448;
const bool transA = true, transB = false;
auto cudaRuntime = make_ref<CudaRuntimeObj>();
auto gCuda = make_ref<GraphObj>(cudaRuntime);
auto ACuda = gCuda->cloneTensor(ACpu);
auto BCuda = gCuda->cloneTensor(BCpu);
auto matmul = gCuda->addOp<MatmulObj>(ACuda, BCuda, nullptr);
Graph g = make_ref<GraphObj>(cudaRuntime);
auto a = g->addTensor(transA ? Shape{B, K, M} : Shape{B, M, K});
auto b = g->addTensor(transB ? Shape{B, N, K} : Shape{B, K, N});
// allocate CUDA memory
gCuda->dataMalloc();
cudaRuntime->run(gCuda, true);
g->dataMalloc();
a->setData(IncrementalGenerator());
b->setData(IncrementalGenerator());
auto matmul = g->addOp<MatmulObj>(a, b, nullptr, transA, transB);
matmul->print();
double time = cudaRuntime->getPerfTime(g);
EXPECT_GT(time, 1e-3);
EXPECT_LT(time, 1);
cudaRuntime->run(g, true);
}
}; // namespace infini

View File

@ -13,10 +13,10 @@ TEST(MklBatchNorm, run) {
// Build graph
Graph g = make_ref<GraphObj>(runtime);
auto i = g->addTensor(Shape{1, 3, 2, 2}, DataType::Float32);
auto mean = g->addTensor(Shape{1, 3, 1, 1}, DataType::Float32);
auto var = g->addTensor(Shape{1, 3, 1, 1}, DataType::Float32);
auto scale = g->addTensor(Shape{1, 3, 1, 1}, DataType::Float32);
auto bias = g->addTensor(Shape{1, 3, 1, 1}, DataType::Float32);
auto mean = g->addTensor(Shape{3}, DataType::Float32);
auto var = g->addTensor(Shape{3}, DataType::Float32);
auto scale = g->addTensor(Shape{3}, DataType::Float32);
auto bias = g->addTensor(Shape{3}, DataType::Float32);
auto op =
g->addOp<BatchNormObj>(i, nullptr, mean, var, scale, bias, 0.9, 0);
g->dataMalloc();

View File

@ -7,7 +7,8 @@
#include "nnet/routine.h"
#include "nnet/test.h"
#include "operators/matmul.h"
#include <chrono>
#include "operators/membound.h"
#include "test.h"
using namespace infini;
using namespace std;
@ -18,8 +19,8 @@ TEST(nnet, MemboundOpInterpretation) {
Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
g->dataMalloc();
i0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
w0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
i0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
w0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
NMutator nmutator(NMutator::Mode::ToNaiveMembound);
auto mutations = nmutator.run(g);
@ -36,7 +37,7 @@ TEST(nnet, MemboundOpInterpretation) {
EXPECT_EQ(membound->getOpType(), OpType::MemBound);
auto ans = make_ref<TensorObj>(Shape{1, 2, 4}, DataType::UInt32, runtime);
ans->dataMalloc();
ans->copyData(vector<uint32_t>{38, 44, 50, 56, 83, 98, 113, 128});
ans->copyin(vector<uint32_t>{38, 44, 50, 56, 83, 98, 113, 128});
EXPECT_TRUE(membound->getOutput()->equalData(ans));
}
@ -49,8 +50,8 @@ TEST(nnet, MemboundOp_Ansor_Codegen) {
Tensor w0 = g->addTensor({1, 3, 4}, DataType::Float32);
Tensor o0 = g->addTensor({1, 2, 4}, DataType::Float32);
g->dataMalloc();
i0->copyData(vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
w0->copyData(vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
i0->copyin(vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
w0->copyin(vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
NMutator nmutator(NMutator::Mode::ToNaiveMembound);
auto mutations = nmutator.run(g);
@ -67,7 +68,7 @@ TEST(nnet, MemboundOp_Ansor_Codegen) {
EXPECT_EQ(membound->getOpType(), OpType::MemBound);
auto ans = make_ref<TensorObj>(Shape{1, 2, 4}, DataType::Float32, cpu);
ans->dataMalloc();
ans->copyData(vector<float>{38, 44, 50, 56, 83, 98, 113, 128});
ans->copyin(vector<float>{38, 44, 50, 56, 83, 98, 113, 128});
auto oCpu = gCpu->cloneTensor(membound->getOutput());
oCpu->printData();
@ -77,3 +78,41 @@ TEST(nnet, MemboundOp_Ansor_Codegen) {
// double time = timeit([&]() { runtime->run(gNew, false); }); // tune
// kernels std::cout << "Time (ms):" << time << std::endl;
}
pair<std::vector<nnet::Tensor>, nnet::Expr> getPReluExpr(int size) {
using namespace nnet;
using nnet::make_ref;
DEFINE_VAR(i);
auto A = make_ref<TensorNode>("A", vector{size});
auto B = make_ref<TensorNode>("B", vector{size});
Expr e = make_ref<FuncNode>(makeSubscript(A, {i}) - makeSubscript(B, {i}),
FuncType::PRelu);
Expr ret = makeRangeOperator({{i, {0, size}}}, {}, e);
return {{A, B}, ret};
}
TEST(nnet, PRelu_Ansor_Codegen) {
auto cuda = make_ref<CudaRuntimeObj>();
Runtime cpu = NativeCpuRuntimeObj::getInstance();
Graph g = make_ref<GraphObj>(cuda);
Tensor i0 = g->addTensor(vector{12});
Tensor w0 = g->addTensor(vector{12});
Tensor o0 = g->addTensor(vector{12});
auto [nnetInputs, expr] = getPReluExpr(12);
g->addOpWithOutputs<MemBoundObj>(vector{i0, w0}, vector{o0}, nnetInputs,
expr, -1);
g->dataMalloc();
i0->setData(IncrementalGenerator());
w0->setData(ValGenerator<5>());
cuda->run(g, true); // tune kernels
// check answer
auto ans = make_ref<TensorObj>(Shape{12}, DataType::Float32, cpu);
ans->dataMalloc();
ans->copyin(
vector<float>{-1.25, -1., -0.75, -0.5, -0.25, 0, 1, 2, 3, 4, 5, 6});
Graph gCpu = make_ref<GraphObj>(cpu);
auto oCpu = gCpu->cloneTensor(o0);
EXPECT_TRUE(oCpu->equalData(ans));
}

View File

@ -4,19 +4,16 @@
#include "nnet/Visitor/HashVisitor.h"
#include "nnet/Visitor/MergeMemboundMutator.h"
#include "nnet/expr.h"
#include "nnet/test.h"
#include "gtest/gtest.h"
using namespace nnet;
using namespace std;
#define DEFINE_VAR(name) auto name = make_ref<VarNode>(#name);
TEST(FuseMembound, Relu) {
const int n_heads = 8, seq_len = 10000, feat_len = 512;
// dilation_heads = 2;
const int Batch = n_heads, M = seq_len, K = feat_len, W = 32;
DEFINE_VAR(b);
DEFINE_VAR(m);
DEFINE_VAR(w);
DEFINE_VAR(k);
DEFINE_VAR(b, m, w, k);
auto A = make_ref<TensorNode>("A", vector<int>({Batch, M, K}),
vector<int>{0, 0, 0});
@ -35,10 +32,7 @@ TEST(FuseMembound, MemMemFusion) {
const int n_heads = 8, seq_len = 100, feat_len = 100;
// dilation_heads = 2;
const int Batch = n_heads, M = seq_len, K = feat_len;
DEFINE_VAR(b);
DEFINE_VAR(m);
DEFINE_VAR(w);
DEFINE_VAR(k);
DEFINE_VAR(b, m, w, k);
auto A = make_ref<TensorNode>("A", vector<int>({Batch, M, K}),
vector<int>{0, 0, 0});
auto B = make_ref<TensorNode>("B", vector<int>({Batch, K, M}),
@ -55,3 +49,25 @@ TEST(FuseMembound, MemMemFusion) {
{{k, {0, K}}}, makeSubscript(A, {b, m, k}));
EXPECT_EQ(HashVisitor().getHash(merged), HashVisitor().getHash(ans));
}
TEST(FuseMembound, mergeNestedStagesInRangeOp) {
// Case in ConvTranspose to Matmul
// L<f:0:448><i39:0:4096>Sum ... [i39,f]
// {L<i39:0:4096><f:0:448>Sum ... [f,(i39 / 1024),((i39 / 256) % 4),(i39
// % 256)] {K}}
DEFINE_VAR(f, i);
const int I = 4096, F = 448;
auto K = make_ref<TensorNode>("K", vector<int>({448, 4, 4, 256}));
auto subA = makeSubscript(K, {f, i / 1024, (i / 256) % 4, i % 256});
auto range = makeRangeOperator({{i, {0, I}}, {f, {0, F}}}, {}, subA);
auto outerRange = makeRangeOperator({{f, {0, F}}, {i, {0, I}}}, {},
makeSubscript(range, {i, f}));
auto merged = MergeMemboundMutator({outerRange}).merge();
// Compare the result with answer
RangeOp ans = makeRangeOperator(
{{f, {0, F}}, {i, {0, I}}}, {},
makeSubscript(K, {f, i / 1024, (i / 256) % 4, i % 256}));
EXPECT_EQ(HashVisitor().getHash(merged), HashVisitor().getHash(ans));
}

View File

@ -56,45 +56,72 @@ TEST(Mutator, NaiveConvWithInterpreter) {
// FIXME: failed since implicit transpose for DLT
TEST(Mutator, InfoGAN_TConv_3_correctness) {
// verifyNaiveMembound True: subgraph after transformation
// verifyNaiveMembound False: subgraph of one single membound (eOP)
// const bool verifyNaiveMembound = false;
const bool useMutatorDirectly = true;
Runtime runtime = make_ref<CudaRuntimeObj>();
Graph g = make_ref<GraphObj>(runtime);
Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
Graph gCpu = make_ref<GraphObj>(cpu);
// {n, h, w, f} * {f, r, s, c}
auto i0 = g->addTensor({1, 2, 2, 448});
auto w0 = g->addTensor({448, 4, 4, 256});
const int n = 1, c = 256, h = 2, w = 2, f = 448, r = 4, s = 4;
// // Minimum config for test
// const int n = 1, c = 1, h = 2, w = 2, f = 1, r = 4, s = 4;
// const int n = 1, c = 2, h = 2, w = 2, f = 2, r = 4, s = 4;
auto i0 = g->addTensor({n, h, w, f});
auto w0 = g->addTensor({f, r, s, c});
g->addOp<ConvTransposed2dNHWCObj>(i0, w0, nullptr, 1, 1, 2, 2, 1, 1);
auto mutator = make_ref<NMutator>();
mutator->setToNaiveMembound();
SearchEngine searchEngine(runtime, mutator);
auto bestGraph = searchEngine.run(g);
bestGraph->print();
printf("--- SearchEngine Finished ---\n");
auto mutator =
make_ref<NMutator>(NMutator::Mode::RuleBased,
vector<int>{3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90});
// // Translate OP to membound without derivation
// mutator->setToNaiveMembound();
vector<Graph> bestGraphs;
if (useMutatorDirectly) { // Use mutator results
bestGraphs = mutator->run(g);
} else { // Use search engine results
SearchEngine searchEngine(runtime, mutator);
bestGraphs.emplace_back(searchEngine.run(g));
}
g->dataMalloc();
map<UidBaseType, Tensor> fuidToInputTensor;
for (auto t : g->getInputs()) {
EXPECT_EQ(fuidToInputTensor.count(t->getFuid()), 0);
fuidToInputTensor[t->getFuid()] = t;
}
for (size_t i = 0; i < bestGraphs.size(); i++) {
auto bestGraphCpu = bestGraphs[i];
auto bestGraph =
make_ref<GraphObj>(runtime, bestGraphCpu->getOperators());
auto gen = RandomGenerator(0, 1, i);
bestGraph->dataMalloc();
for (auto t : g->getTensors()) {
if (t->getFuid() <= 2)
t->setData(IncrementalGenerator());
// Initialize inputs with random data
for (auto t : g->getInputs()) {
t->setData(gen);
}
for (auto t : bestGraph->getTensors()) {
if (t->getFuid() <= 2)
t->setData(IncrementalGenerator());
for (auto t : bestGraph->getInputs()) {
t->copyData(fuidToInputTensor[t->getFuid()]);
}
// Initialize outputs with zeros
for (auto t : g->getOutputs()) {
t->setData(ZeroGenerator());
}
for (auto t : bestGraph->getOutputs()) {
t->setData(ZeroGenerator());
}
runtime->run(bestGraph, true); // Tune kernels
runtime->run(g);
runtime->run(bestGraph);
runtime->run(bestGraph, false); // Execute transfomraed graph
auto go0 = gCpu->cloneTensor(g->getOutputs()[0]);
auto bgo0 = gCpu->cloneTensor(bestGraph->getOutputs()[0]);
EXPECT_TRUE(go0->equalData(bgo0));
EXPECT_TRUE(g->getOutputs()[0]->getRawDataPtr<void *>() !=
bestGraph->getOutputs()[0]->getRawDataPtr<void *>());
EXPECT_TRUE(go0->equalData(bgo0, 1e-4));
}
}
// TEST(Mutator, Conv9x9) {

View File

@ -9,7 +9,7 @@ then
elif [ "$1" == "intelcpu" ]
then
echo "Load INTELCPU environment."
spack load intel-oneapi-dnn@2022.1.0 intel-oneapi-mkl@2022.1.0 intel-oneapi-compilers@2022.1.0
spack load gcc@12.1.0 intel-oneapi-dnn@2022.1.0 intel-oneapi-mkl@2022.1.0 intel-oneapi-compilers@2022.1.0
# The default dnnl library is cpu_dpcpp_gpu_dpcpp which requires libsycl.so, after "spack load", and need to change to gomp explicitly.
export LD_LIBRARY_PATH=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-12.1.0/intel-oneapi-dnn-2022.1.0-7rs6ht57zozyxhxx6s2qlrqzmqknhgzx/dnnl/2022.1.0/cpu_gomp/lib/:$LD_LIBRARY_PATH
@ -20,7 +20,7 @@ then
# Preloading the missing libs will work, refered to https://community.intel.com/t5/Intel-oneAPI-Math-Kernel-Library/mkl-fails-to-load/m-p/1155538
export MKLLIB_PATH=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-12.1.0/intel-oneapi-mkl-2022.1.0-mf6te62fo6wxlo33jwwwgg5kljoagc6g/mkl/2022.1.0/
export LD_PRELOAD=$MKLLIB_PATH/lib/intel64/libmkl_def.so.2:$MKLLIB_PATH/lib/intel64/libmkl_avx2.so.2:$MKLLIB_PATH/lib/intel64/libmkl_core.so:$MKLLIB_PATH/lib/intel64/libmkl_intel_lp64.so:$MKLLIB_PATH/lib/intel64/libmkl_intel_thread.so:/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-11.3.0/intel-oneapi-compilers-2022.1.0-qrq4a63scjip455bpxvl5ipgqbllwecj/compiler/2022.1.0/linux/compiler/lib/intel64_lin/libiomp5.so
export LD_PRELOAD=$MKLLIB_PATH/lib/intel64/libmkl_def.so.2:$MKLLIB_PATH/lib/intel64/libmkl_avx2.so.2:$MKLLIB_PATH/lib/intel64/libmkl_core.so:$MKLLIB_PATH/lib/intel64/libmkl_intel_lp64.so:$MKLLIB_PATH/lib/intel64/libmkl_intel_thread.so:/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-12.1.0/intel-oneapi-compilers-2022.1.0-6k6zm3h4qcsni27nihc4b6wuqgtxqxqa/compiler/2022.1.0/linux/compiler/lib/intel64_lin/libiomp5.so
else
echo "Bad option. Please enter 'cuda' or 'intelcpu'. CUDA will be loaded by default if nothing specified."
fi