forked from jiuyuan/InfiniTensor
Add TVM codegen for MemboundOp (#35)
* Add: interface for membound TVM kernel and test
* Add getAnsorCode
* Add evaluation of the kernel (link initially failed)
* Fix: link libcuda and nvrtc
* Add print
* Add: const for source of copy
* Compile and evaluate the kernel
* Add compute
* Fix gen_ansor_op.py
* Fix membound_TVM
* Format and fix CMakeLists.txt
* Fix memory leak

Co-authored-by: Liyan Zheng <liyan-zheng@outlook.com>
Co-authored-by: huangshuhong <huangsh19@mails.tsinghua.edu.cn>
Parent: ba0b11a499
Commit: 11d5aa1ccc
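Overview of the flow in this commit: AsTVMVisitor lowers a membound op's nnet expression to a TVM TE snippet, the new Python helper cpp_plugin.gen_ansor_op (reached through an embedded interpreter) runs Ansor auto-tuning and returns the generated CUDA source plus launch dimensions, and the C++ kernel compiles that source with NVRTC and launches it through the CUDA driver API. As a rough illustration of the Python entry point only, here is a minimal, hypothetical standalone call; the TE string, tensor names, and shapes are stand-ins for what AsTVMVisitor would emit and are not part of this commit (requires TVM, a CUDA GPU, and <repo-root>/python on PYTHONPATH):

from cpp_plugin import gen_ansor_op

# Hypothetical TE snippet standing in for AsTVMVisitor output; gen_ansor_op
# exec()s it with `te` in scope and reads the tensor list from `ret`. Judging
# from how the helper later calls the built kernel (output buffer first), the
# output tensor is listed first here.
f = (
    "var_i0 = te.placeholder((1, 2, 3), name='var_i0', dtype='float32')\n"
    "var_w0 = te.placeholder((1, 3, 4), name='var_w0', dtype='float32')\n"
    "k = te.reduce_axis((0, 3), name='k')\n"
    "var_o0 = te.compute((1, 2, 4), lambda b, i, j: te.sum("
    "var_i0[b, i, k] * var_w0[b, k, j], axis=k), name='var_o0')\n"
    "ret = [var_o0, var_i0, var_w0]\n"
)

func_code, invoke_code, time_ms, invoke_params = gen_ansor_op(
    input_tensors=[[1, 2, 3], [1, 3, 4]],  # input shapes
    input_dtypes=["float32", "float32"],
    output_tensor=[1, 2, 4],               # output shape
    output_dtype="float32",
    f=f,                                   # TE code, exec'ed by the helper
    func_name="mem_bound_0",
    input_names=["var_i0", "var_w0"],
    output_name="var_o0",
)
print(invoke_code)    # e.g. mem_bound_0_kernel0<<<dim3(...), dim3(...)>>>(var_o0, var_i0, var_w0);
print(invoke_params)  # [grid.x, grid.y, grid.z, block.x, block.y, block.z]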
@@ -1,4 +1,4 @@
-cmake_minimum_required(VERSION 3.10) # Required by CMAKE_CUDA_HOST_COMPILER
+cmake_minimum_required(VERSION 3.17) # FindCUDAToolkit
 include(CMakeDependentOption)
 project(InfiniTensor C CXX)
 
@@ -52,12 +52,11 @@ if(USE_PROTOBUF)
     target_link_libraries(tensor_proto PUBLIC ${PROTOBUF_LIBRARIES})
 endif()
 
-
 include_directories(include)
 
-# # Pybind11
-# add_subdirectory(3rd-party/pybind11)
-# include_directories(3rd-party/pybind11/include)
+# Pybind11
+add_subdirectory(3rd-party/pybind11)
+include_directories(3rd-party/pybind11/include)
 
 # nlohmann_json
 add_subdirectory(3rd-party/nlohmann_json_cmake_fetchcontent)
 
@@ -75,7 +74,7 @@ if(BUILD_TEST)
 endif()
 
 # Source files
-file(GLOB_RECURSE SRC src/core/*.cc src/kernels/cpu/*.cc src/nnet/*.cc src/operators/*.cc src/utils/*.cc)
+file(GLOB_RECURSE SRC src/ffi/*.cc src/core/*.cc src/kernels/cpu/*.cc src/nnet/*.cc src/operators/*.cc src/utils/*.cc)
 
 if(USE_CUDA)
     file(GLOB_RECURSE SRC_CUDA src/cuda/*.cc src/cuda/*.cu src/kernels/cuda/*.cc src/kernels/cuda/*.cu)
@@ -93,6 +92,8 @@ if(USE_PROTOBUF)
     target_link_libraries(InfiniTensor tensor_proto)
 endif()
 
+target_link_libraries(InfiniTensor pybind11::embed)
+
 if(USE_BACKTRACE)
     add_definitions(-D BACKWARD_TRACE)
     add_subdirectory(3rd-party/backward-cpp)
@@ -109,10 +110,15 @@ if(USE_CUDA)
     # CMP0104 requires CUDA_ARCHITECTURES
     set_target_properties(InfiniTensor PROPERTIES CUDA_ARCHITECTURES "70;80")
     enable_language(CUDA)
-    # TODO: find_package seems unnecessary for CMake >= 3.8
-    find_package(CUDA REQUIRED)
-    # message("CUBLAS_LIBRARIES: ${CUDA_LIBRARIES}")
-    target_link_libraries(InfiniTensor cudnn curand cublas ${CUDA_LIBRARIES})
+    find_package(CUDAToolkit) # For nvrtc and cuda driver
+    target_link_libraries(
+        InfiniTensor
+        cudnn
+        CUDA::curand
+        CUDA::cublas
+        CUDA::nvrtc
+        CUDA::cudart
+        CUDA::cuda_driver)
 endif()
 
 if(USE_BANG)
@@ -66,14 +66,17 @@ class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
     bool isCuda() const { return device == Device::CUDA; }
     bool isBang() const { return device == Device::BANG; }
     void copyBlob(const TensorObj *dst, const TensorObj *src) const;
+    // TODO: unify these copy APIs
+    virtual void copyBlobFromCPU(void *dst, const void *src,
+                                 size_t bytes) const = 0;
 
   protected:
     void printProfilingData(double totTime,
                             const std::map<OpType, double> &opTime,
                             const std::map<OpType, int> &opCnt) const;
-    virtual void copyBlobFromCPU(void *dst, void *src, size_t bytes) const = 0;
-    virtual void copyBlobToCPU(void *dst, void *src, size_t bytes) const = 0;
-    virtual void copyBlobInsideRuntime(void *dst, void *src,
+    virtual void copyBlobToCPU(void *dst, const void *src,
+                               size_t bytes) const = 0;
+    virtual void copyBlobInsideRuntime(void *dst, const void *src,
                                        size_t bytes) const = 0;
 };
 
@@ -94,9 +97,10 @@ class CpuRuntimeObj : public RuntimeObj {
                sizeof(uint64_t));
     };
 
-    void copyBlobFromCPU(void *dst, void *src, size_t bytes) const override;
-    void copyBlobToCPU(void *dst, void *src, size_t bytes) const override;
-    void copyBlobInsideRuntime(void *dst, void *src,
+    void copyBlobFromCPU(void *dst, const void *src,
+                         size_t bytes) const override;
+    void copyBlobToCPU(void *dst, const void *src, size_t bytes) const override;
+    void copyBlobInsideRuntime(void *dst, const void *src,
                                size_t bytes) const override;
 };
 
@@ -32,14 +32,7 @@ class TensorObj : public TensorBaseObj {
     template <typename T> void copyData(const T *dptr) {
         IT_ASSERT(DataType::get<T>() == dtype);
         IT_ASSERT(data != nullptr);
-        if (!runtime->isCpu())
-            IT_TODO_HALT();
-        auto ptr = data->getPtr<T *>();
-        size_t sz = size();
-#pragma omp parallel for
-        for (size_t i = 0; i < sz; ++i) {
-            ptr[i] = dptr[i];
-        }
+        runtime->copyBlobFromCPU(getRawDataPtr<void *>(), dptr, getBytes());
     }
 
     template <typename T> void copyData(vector<T> dataVector) {
@@ -17,6 +17,18 @@
         }                                                                      \
     }
 
+#define checkCUresult(call)                                                    \
+    {                                                                          \
+        auto err = call;                                                       \
+        const char *errName;                                                   \
+        if (CUDA_SUCCESS != err) {                                             \
+            cuGetErrorString(err, &errName);                                   \
+            fprintf(stderr, "Cuda error in %s:%i : %s.\n", __FILE__, __LINE__, \
+                    errName);                                                  \
+            exit(EXIT_FAILURE);                                                \
+        }                                                                      \
+    }
+
 #define checkCublasError(call)                                                 \
     {                                                                          \
         auto err = call;                                                       \
@@ -11,6 +11,10 @@ class CudaRuntimeObj : public RuntimeObj {
     CudaPtr workspace;
     size_t workspaceSize;
 
+  public:
+    CUdevice cuDevice;
+    CUcontext newContext;
+
   public:
     CudaRuntimeObj() : RuntimeObj(Device::CUDA) {
         checkCudnnError(cudnnCreate(&cudnn));
|
@ -19,11 +23,16 @@ class CudaRuntimeObj : public RuntimeObj {
|
||||||
// size_t longformerNum = 3lu * (1 << 30);
|
// size_t longformerNum = 3lu * (1 << 30);
|
||||||
workspaceSize = 7ll << 30; // 7 GB
|
workspaceSize = 7ll << 30; // 7 GB
|
||||||
workspace = alloc(workspaceSize);
|
workspace = alloc(workspaceSize);
|
||||||
|
|
||||||
|
checkCUresult(cuInit(0));
|
||||||
|
checkCUresult(cuDeviceGet(&cuDevice, 0));
|
||||||
|
checkCUresult(cuCtxCreate(&newContext, 0, cuDevice));
|
||||||
}
|
}
|
||||||
virtual ~CudaRuntimeObj() {
|
virtual ~CudaRuntimeObj() {
|
||||||
dealloc(workspace);
|
dealloc(workspace);
|
||||||
checkCudnnError(cudnnDestroy(cudnn));
|
checkCudnnError(cudnnDestroy(cudnn));
|
||||||
checkCublasError(cublasDestroy(cublas));
|
checkCublasError(cublasDestroy(cublas));
|
||||||
|
checkCUresult(cuCtxDestroy(newContext));
|
||||||
}
|
}
|
||||||
|
|
||||||
void run(const Graph &graph, bool tune = false,
|
void run(const Graph &graph, bool tune = false,
|
||||||
|
@ -44,15 +53,17 @@ class CudaRuntimeObj : public RuntimeObj {
|
||||||
return workspace;
|
return workspace;
|
||||||
}
|
}
|
||||||
|
|
||||||
void copyBlobFromCPU(void *dst, void *src, size_t bytes) const override {
|
void copyBlobFromCPU(void *dst, const void *src,
|
||||||
|
size_t bytes) const override {
|
||||||
checkCudaError(cudaMemcpy(dst, src, bytes, cudaMemcpyHostToDevice));
|
checkCudaError(cudaMemcpy(dst, src, bytes, cudaMemcpyHostToDevice));
|
||||||
}
|
}
|
||||||
|
|
||||||
void copyBlobToCPU(void *dst, void *src, size_t bytes) const override {
|
void copyBlobToCPU(void *dst, const void *src,
|
||||||
|
size_t bytes) const override {
|
||||||
checkCudaError(cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToHost));
|
checkCudaError(cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToHost));
|
||||||
}
|
}
|
||||||
|
|
||||||
void copyBlobInsideRuntime(void *dst, void *src,
|
void copyBlobInsideRuntime(void *dst, const void *src,
|
||||||
size_t bytes) const override {
|
size_t bytes) const override {
|
||||||
checkCudaError(cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToDevice));
|
checkCudaError(cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToDevice));
|
||||||
}
|
}
|
||||||
|
|
|
@@ -0,0 +1,8 @@
+#pragma once
+#include <pybind11/embed.h>
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+namespace infini {
+void start_interpreter();
+} // namespace infini
@@ -0,0 +1 @@
+from .gen_ansor_op import gen_ansor_op
@@ -0,0 +1,94 @@
+import re
+
+import numpy as np
+import tvm
+from tvm import te, tir, auto_scheduler, topi
+
+
+def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f, func_name, input_names, output_name):
+    assert len(input_tensors) == len(input_dtypes)
+    assert len(input_tensors) == len(input_names)
+
+    print("Generating Ansor op: ")
+    print(f)
+
+    @auto_scheduler.register_workload(func_name)
+    def compute():
+        _locals = locals()
+        exec(f, {'tvm': tvm, 'te': te, 'tir': tir}, _locals)
+        return _locals['ret']
+
+    target = tvm.target.Target("cuda")
+
+    task = auto_scheduler.SearchTask(func=func_name, args=(), target=target)
+
+    # Inspect the computational graph
+    print("Computational DAG:")
+    print(task.compute_dag)
+
+    log_file = f"ansor_{func_name}_log.json"
+    measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
+    tune_option = auto_scheduler.TuningOptions(
+        num_measure_trials=10,
+        runner=measure_ctx.runner,
+        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
+        verbose=2,
+    )
+
+    # Run auto-tuning (search)
+    task.tune(tune_option)
+    # Apply the best schedule
+    sch, args = task.apply_best(log_file)
+
+    # Kill the measurement process
+    del measure_ctx
+
+    ir = str(tvm.lower(sch, args, simple_mode=True))
+    thread_dim = [1, 1, 1]
+    block_dim = [1, 1, 1]
+    p = re.compile('"thread_extent" = (\d+)')
+    for line in ir.splitlines():
+        if "thread_extent" in line:
+            ext = int(p.search(line).group(1))
+            if "threadIdx.x" in line:
+                thread_dim[0] = ext
+            elif "threadIdx.y" in line:
+                thread_dim[1] = ext
+            elif "threadIdx.z" in line:
+                thread_dim[2] = ext
+            elif "blockIdx.x" in line:
+                block_dim[0] = ext
+            elif "blockIdx.y" in line:
+                block_dim[1] = ext
+            elif "blockIdx.z" in line:
+                block_dim[2] = ext
+
+    func = tvm.build(sch, args, target, name=func_name)
+
+    func_code = func.imported_modules[0].get_source()
+    invoke_code = "%s_kernel0<<<dim3(%s), dim3(%s)>>>(%s, %s);" % (
+        func_name, ", ".join(map(str, block_dim)), ", ".join(
+            map(str, thread_dim)),
+        output_name, ", ".join(input_names))
+    invoke_params = block_dim + thread_dim
+
+    ctx = tvm.cuda(0)
+    input_a = []
+    for i, (shape, dtype) in enumerate(zip(input_tensors, input_dtypes)):
+        a_np = np.random.uniform(size=shape).astype(dtype)
+        input_a.append(tvm.nd.array(a_np, ctx))
+    a_out = tvm.nd.array(np.zeros(output_tensor, dtype=output_dtype), ctx)
+    func(a_out, *input_a)
+    evaluator = func.time_evaluator(func.entry_name, ctx, number=100)
+    conv_time = evaluator(a_out, *input_a).mean * 1e3
+
+    print("Func Code")
+    # Attach TVM code behind func_code
+    func_code += "\n/* " + f + "*/"
+    print(func_code)
+    print("Invoke Code")
+    print(invoke_code)
+    print("Time")
+    print(conv_time)
+
+    return func_code, invoke_code, conv_time, invoke_params  # ms
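One detail the C++ consumer below relies on: gen_ansor_op packs the launch configuration as block_dim (grid extents parsed from blockIdx.*) followed by thread_dim (block extents parsed from threadIdx.*), so cuLaunchKernel reads indices 0-2 as grid dimensions and 3-5 as block dimensions. A toy sketch of that contract, with made-up extents and tensor names:

# Toy illustration only; the extents and variable names are hypothetical.
block_dim = [16, 2, 1]    # parsed from blockIdx.x/y/z "thread_extent" attrs
thread_dim = [128, 1, 1]  # parsed from threadIdx.x/y/z "thread_extent" attrs
invoke_params = block_dim + thread_dim
assert invoke_params == [16, 2, 1, 128, 1, 1]
invoke_code = "mem_bound_0_kernel0<<<dim3(%s), dim3(%s)>>>(%s, %s);" % (
    ", ".join(map(str, block_dim)), ", ".join(map(str, thread_dim)),
    "var_o0", ", ".join(["var_i0", "var_w0"]))
print(invoke_code)
# mem_bound_0_kernel0<<<dim3(16, 2, 1), dim3(128, 1, 1)>>>(var_o0, var_i0, var_w0);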
@@ -124,15 +124,17 @@ void RuntimeObj::copyBlob(const TensorObj *dst, const TensorObj *src) const {
         IT_TODO_HALT();
 }
 
-void CpuRuntimeObj::copyBlobFromCPU(void *dst, void *src, size_t bytes) const {
+void CpuRuntimeObj::copyBlobFromCPU(void *dst, const void *src,
+                                    size_t bytes) const {
     copyBlobInsideRuntime(dst, src, bytes);
 }
 
-void CpuRuntimeObj::copyBlobToCPU(void *dst, void *src, size_t bytes) const {
+void CpuRuntimeObj::copyBlobToCPU(void *dst, const void *src,
+                                  size_t bytes) const {
     copyBlobInsideRuntime(dst, src, bytes);
 }
 
-void CpuRuntimeObj::copyBlobInsideRuntime(void *dst, void *src,
+void CpuRuntimeObj::copyBlobInsideRuntime(void *dst, const void *src,
                                           size_t bytes) const {
     memcpy(dst, src, bytes);
 }
@@ -0,0 +1,11 @@
+#include "ffi/ffi_embed.h"
+namespace py = pybind11;
+
+namespace infini {
+void start_interpreter() {
+    static py::scoped_interpreter *interpretor = nullptr;
+    if (interpretor == nullptr) {
+        interpretor = new py::scoped_interpreter(); // guard
+    }
+}
+} // namespace infini
@@ -0,0 +1,191 @@
+#include "core/kernel.h"
+#include "cuda/cuda_runtime.h"
+#include "ffi/ffi_embed.h"
+#include "nnet/Visitor/AsTVMVisitor.h"
+#include "nvrtc.h"
+#include "operators/membound.h"
+#include "operators/pooling.h"
+
+namespace py = pybind11;
+
+namespace infini {
+
+class TVMRecordObj : public PerfRecordObj {
+    // TODO: Add more attrs
+  public:
+    size_t logSize, ptxSize;
+    std::string log, ptx;
+    std::vector<int> invokeParams;
+    std::string kernelName;
+};
+
+using TVMRecord = Ref<TVMRecordObj>;
+
+class MemboundTVM : public Kernel {
+  public:
+    void compute(const Operator &_op, const PerfRecord &record,
+                 const RuntimeObj *_context) const override {
+        auto op = as<MemBoundObj>(_op);
+        // auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
+        auto tvmRecord = std::dynamic_pointer_cast<TVMRecordObj>(record);
+
+        // prepare for evaluation
+        CUmodule module;
+        CUfunction kernel;
+        checkCUresult(cuModuleLoadDataEx(&module, tvmRecord->ptx.data(), 0,
+                                         nullptr, nullptr));
+        checkCUresult(cuModuleGetFunction(&kernel, module,
+                                          tvmRecord->kernelName.c_str()));
+        std::vector<void *> args;
+        for (auto &&in : op->getInputs()) {
+            args.push_back(in->getRawDataPtr<void *>());
+        }
+        args.push_back(op->getOutput()->getRawDataPtr<void *>());
+        std::vector<void *> argsPtr;
+        for (auto &arg : args) {
+            argsPtr.push_back(&arg);
+        }
+        auto invokeParams = tvmRecord->invokeParams;
+
+        // begin evaluation
+        cuLaunchKernel(kernel, invokeParams[0], invokeParams[1],
+                       invokeParams[2], invokeParams[3], invokeParams[4],
+                       invokeParams[5], 0, NULL, argsPtr.data(), 0);
+
+        // free module
+        checkCUresult(cuModuleUnload(module));
+    }
+
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        IT_ASSERT(false, "A TVM record is required for membound kernel.");
+    }
+
+    std::string getVarName(const Tensor &t) const {
+        return "var_" + std::to_string(t->getGuid());
+    }
+
+    // Premise: op is idempotent since it is called multiple times.
+    PerfRecord tune(const Operator &_op,
+                    const RuntimeObj *_context) const override {
+        TVMRecord ret = std::make_shared<TVMRecordObj>();
+        auto op = as<MemBoundObj>(_op);
+        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
+
+        // invoke Ansor to tune a membound kernel
+        std::string func = "mem_bound_" + std::to_string(op->getGuid());
+        std::string kernelName = func + "_kernel0";
+        nnet::AsTVMVisitor visitor;
+        visitor.dispatch(op->getNnetExpr());
+        auto &&stmts = visitor.getStmts();
+        auto &&inShapes = visitor.getInputShapes();
+        auto &&outShape = visitor.getOutputShape();
+
+        std::vector<std::string> inputs;
+        for (auto &&in : op->getInputs()) {
+            inputs.emplace_back(getVarName(in));
+        }
+        std::string output = getVarName(op->getOutput());
+        auto res = getAnsorCode(
+            inShapes, std::vector<std::string>(inShapes.size(), "float32"),
+            outShape, "float32", stmts, func, inputs, output);
+
+        // compile the kernel
+        auto funcCode = res.first;
+        auto invokeParams = res.second;
+        std::string fileName = func + ".cu";
+        nvrtcProgram prog;
+        nvrtcCreateProgram(&prog,            // prog
+                           funcCode.c_str(), // buffer
+                           fileName.c_str(), // name
+                           0,                // numHeaders
+                           NULL,             // headers
+                           NULL);            // includeNames
+        const char *opts[] = {"--gpu-architecture=compute_80", "--fmad=false"};
+        nvrtcCompileProgram(prog,  // prog
+                            2,     // numOptions
+                            opts); // options
+
+        // copy ptx and log to ret
+        size_t logSize;
+        nvrtcGetProgramLogSize(prog, &logSize);
+        size_t ptxSize;
+        nvrtcGetPTXSize(prog, &ptxSize);
+        ret->logSize = logSize;
+        ret->ptxSize = ptxSize;
+        ret->log = std::string(logSize, ' ');
+        ret->ptx = std::string(ptxSize, ' ');
+        nvrtcGetProgramLog(prog, ret->log.data());
+        nvrtcGetPTX(prog, ret->ptx.data());
+        ret->invokeParams = invokeParams;
+        ret->kernelName = kernelName;
+
+        // prepare for evaluation
+        CUmodule module;
+        CUfunction kernel;
+        checkCUresult(
+            cuModuleLoadDataEx(&module, ret->ptx.data(), 0, nullptr, nullptr));
+        checkCUresult(cuModuleGetFunction(&kernel, module, kernelName.c_str()));
+        std::vector<void *> args;
+        for (auto &&in : op->getInputs()) {
+            args.push_back(in->getRawDataPtr<void *>());
+        }
+        args.push_back(op->getOutput()->getRawDataPtr<void *>());
+        std::vector<void *> argsPtr;
+        for (auto &arg : args) {
+            argsPtr.push_back(&arg);
+        }
+
+        // Evaluate the kernel
+        ret->time = timeit(
+            [&]() {
+                // TODO: run the kernel
+                cuLaunchKernel(kernel, invokeParams[0], invokeParams[1],
+                               invokeParams[2], invokeParams[3],
+                               invokeParams[4], invokeParams[5], 0, NULL,
+                               argsPtr.data(), 0);
+            },
+            [&]() { context->sync(); });
+
+        // free module
+        checkCUresult(cuModuleUnload(module));
+        nvrtcDestroyProgram(&prog);
+
+        return std::dynamic_pointer_cast<PerfRecordObj>(ret);
+    }
+
+    std::pair<std::string, std::vector<int>>
+    getAnsorCode(const std::vector<std::vector<int>> &inDims,
+                 const std::vector<std::string> &inDTypes,
+                 const std::vector<int> &outDims, const std::string &outDType,
+                 const std::string &lambda, const std::string &funcName,
+                 const std::vector<std::string> &inputNames,
+                 const std::string &outputName) const {
+        std::string funcCode;
+        std::vector<int> invokeParams;
+        try {
+            start_interpreter();
+            auto func = py::module::import("cpp_plugin").attr("gen_ansor_op");
+            py::tuple code = func(inDims, inDTypes, outDims, outDType, lambda,
+                                  funcName, inputNames, outputName);
+            funcCode = py::str(code[0]);
+            auto temp = py::list(code[3]);
+            for (int i = 0; i < 6; ++i) {
+                invokeParams.push_back(temp[i].cast<int>());
+            }
+        } catch (py::error_already_set &e) {
+            if (e.matches(PyExc_ImportError)) {
+                std::cerr << "Import Error. Don't forget to set environment "
+                             "variable PYTHONPATH to contain "
+                             "<repo-root>/python"
+                          << std::endl;
+            }
+            throw;
+        }
+        return std::make_pair(funcCode, invokeParams);
+    }
+};
+
+REGISTER_KERNEL(Device::CUDA, OpType::MemBound, DataType::Float32, MemboundTVM,
+                "Memobund_TVM_Ansor");
+}; // namespace infini
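getAnsorCode() above imports the cpp_plugin package through the embedded interpreter, so the Python side must be importable before the CUDA membound test below is run. A small, assumed pre-flight check; "<repo-root>" is the placeholder from the error message above and must be replaced with the actual checkout path (or set PYTHONPATH instead):

# Assumed pre-flight check: make sure cpp_plugin resolves the same way the
# embedded interpreter will.
import os
import sys

sys.path.insert(0, os.path.join("<repo-root>", "python"))  # or export PYTHONPATH
from cpp_plugin import gen_ansor_op
print(gen_ansor_op)  # prints the function object if the path is correct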
@@ -1,5 +1,6 @@
 #include "core/graph.h"
 #include "core/runtime.h"
+#include "cuda/cuda_runtime.h"
 #include "nnet/Visitor/MatchReshapeVisitor.h"
 #include "nnet/expr.h"
 #include "nnet/nmutator.h"
@@ -37,4 +38,42 @@ TEST(nnet, MemboundOpInterpretation) {
     ans->dataMalloc();
     ans->copyData(vector<uint32_t>{38, 44, 50, 56, 83, 98, 113, 128});
     EXPECT_TRUE(membound->getOutput()->equalData(ans));
+}
+
+TEST(nnet, MemboundOp_Ansor_Codegen) {
+    auto runtime = make_ref<CudaRuntimeObj>();
+    Runtime cpu = CpuRuntimeObj::getInstance();
+    Graph gCpu = make_ref<GraphObj>(cpu);
+    Graph g = make_ref<GraphObj>(runtime);
+    Tensor i0 = g->addTensor({1, 2, 3}, DataType::Float32);
+    Tensor w0 = g->addTensor({1, 3, 4}, DataType::Float32);
+    Tensor o0 = g->addTensor({1, 2, 4}, DataType::Float32);
+    g->dataMalloc();
+    i0->copyData(vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    w0->copyData(vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
+    NMutator nmutator(NMutator::Mode::ToNaiveMembound);
+    auto mutations = nmutator.run(g);
+    ASSERT_EQ(mutations.size(), 1u);
+    Graph gNew = mutations[0];
+    gNew->print();
+    gNew->dataMalloc();
+    runtime->run(gNew, true); // tune kernels
+
+    // check answer
+    auto ops = gNew->getOperators();
+    EXPECT_EQ(ops.size(), 1u);
+    auto membound = ops[0];
+    EXPECT_EQ(membound->getOpType(), OpType::MemBound);
+    auto ans = make_ref<TensorObj>(Shape{1, 2, 4}, DataType::Float32, cpu);
+    ans->dataMalloc();
+    ans->copyData(vector<float>{38, 44, 50, 56, 83, 98, 113, 128});
+
+    auto oCpu = gCpu->cloneTensor(membound->getOutput());
+    oCpu->printData();
+    EXPECT_TRUE(oCpu->equalData(ans));
+
+    // Timing
+    // double time = timeit([&]() { runtime->run(gNew, false); }); // tune
+    // kernels std::cout << "Time (ms):" << time << std::endl;
 }