diff --git a/CMakeLists.txt b/CMakeLists.txt
index f5fd43ce..524c0aca 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,14 +1,15 @@
-cmake_minimum_required(VERSION 3.17) # FindCUDAToolkit
+cmake_minimum_required(VERSION 3.12) # FindCUDAToolkit
 include(CMakeDependentOption)
 project(InfiniTensor C CXX)
 
 # Do not change these options in this file. Use cmake.config, cmake -DOPTION=VALUE, or ccmake to specify them.
 option(USE_CUDA "Support CUDA GPU" OFF)
 option(USE_BANG "Support BANG MLU" OFF)
+option(USE_XPU "Support KUNLUN XPU" ON)
 option(USE_INTELCPU "Support INTELCPU" OFF)
-option(USE_BACKTRACE "Print backtrace on exception and segmentation fault" ON)
+option(USE_BACKTRACE "Print backtrace on exception and segmentation fault" OFF)
 option(USE_PROTOBUF "Serialize and deserialize tensors" OFF)
-option(BUILD_TEST "Build tests" OFF)
+option(BUILD_TEST "Build tests" ON)
 
 cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF)
 cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" OFF BUILD_TEST OFF)
@@ -110,6 +111,11 @@ if(USE_BANG)
     list (APPEND SRC ${SRC_BANG})
 endif()
 
+if(USE_XPU)
+    file(GLOB_RECURSE SRC_XPU src/xpu/*.cc src/kernels/xpu/*.cc)
+    list (APPEND SRC ${SRC_XPU})
+endif()
+
 if(USE_INTELCPU)
     file(GLOB_RECURSE SRC_INTELCPU src/intelcpu/*.cc src/kernels/intelcpu/*.cc)
     list (APPEND SRC ${SRC_INTELCPU})
@@ -218,6 +224,37 @@ if(USE_BANG)
     target_link_libraries(InfiniTensor ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++)
 endif()
 
+if(USE_XPU)
+    add_compile_definitions(USE_XPU=1)
+    include_directories(src/kernels/xpu/include)
+    if ((NOT DEFINED XPU_HOME) AND (NOT DEFINED ENV{XPU_HOME}))
+        message(FATAL_ERROR "XPU_HOME is not defined from cmake or env")
+    elseif (DEFINED XPU_HOME)
+        set(XPU_HOME ${XPU_HOME} CACHE STRING "XPU_HOME directory for Kunlun development")
+    else()
+        set(XPU_HOME $ENV{XPU_HOME} CACHE STRING "XPU_HOME directory for Kunlun development")
+    endif()
+    message(STATUS "XPU_HOME: ${XPU_HOME}")
+
+    include_directories("${XPU_HOME}/include/")
+    include_directories("${XPU_HOME}/XTDK/include/")
+    find_library(XPU_RT libxpurt.so "${XPU_HOME}/lib64")
+    find_library(XPU_DNN libxpuapi.so "${XPU_HOME}/XTDK/shlib")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror")
+
+    if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH}))
+        execute_process(COMMAND uname -m OUTPUT_VARIABLE _uname_m OUTPUT_STRIP_TRAILING_WHITESPACE)
+        set(TARGET_CPU_ARCH "${_uname_m}" CACHE STRING "Target CPU ARCH")
+    elseif(DEFINED TARGET_CPU_ARCH)
+        set(TARGET_CPU_ARCH ${TARGET_CPU_ARCH} CACHE STRING "Target CPU ARCH")
+    else()
+        set(TARGET_CPU_ARCH $ENV{TARGET_CPU_ARCH} CACHE STRING "Target CPU ARCH")
+    endif()
+    message(STATUS "TARGET_CPU_ARCH: ${TARGET_CPU_ARCH}")
+
+    target_link_libraries(InfiniTensor ${XPU_RT} ${XPU_DNN} stdc++)
+endif()
+
 # # Python bindings
 # pybind11_add_module(infini MODULE ${FFI})
 # target_link_libraries(infini PRIVATE infini_cpp)
@@ -247,6 +284,9 @@ if(BUILD_TEST)
     if (USE_BANG)
         build_test(test/kernels/bang/*.cc)
     endif()
+    if (USE_XPU)
+        build_test(test/kernels/xpu/*.cc)
+    endif()
     if (USE_INTELCPU)
         build_test(test/kernels/intelcpu/*.cc)
     endif()
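Note: with USE_XPU enabled, the configure step takes the Kunlun toolkit location from the XPU_HOME CMake variable or, failing that, the XPU_HOME environment variable, and aborts with a fatal error if neither is set. A typical invocation (the install path here is illustrative) is: cmake -DUSE_XPU=ON -DXPU_HOME=/usr/local/xpu ..
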
diff --git a/include/core/runtime.h b/include/core/runtime.h
index 53920fdb..a0158ab6 100644
--- a/include/core/runtime.h
+++ b/include/core/runtime.h
@@ -29,7 +29,7 @@ using OpLists = list<Operator>;
 
 using VType = uint32_t;
 
-enum class Device { CPU = 1, CUDA, BANG, INTELCPU };
+enum class Device { CPU = 1, CUDA, BANG, INTELCPU, XPU };
 /***************** Forward declaration end *****************/
 
 class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
@@ -69,6 +69,7 @@ class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
     }
     bool isCuda() const { return device == Device::CUDA; }
     bool isBang() const { return device == Device::BANG; }
+    bool isXPU() const { return device == Device::XPU; }
     void copyBlob(const TensorObj *dst, const TensorObj *src) const;
     // TODO: unify these copy APIs
     virtual void copyBlobFromCPU(void *dst, const void *src,
diff --git a/include/xpu/operator_timer.h b/include/xpu/operator_timer.h
new file mode 100644
index 00000000..15ebce94
--- /dev/null
+++ b/include/xpu/operator_timer.h
@@ -0,0 +1,10 @@
+#pragma once
+namespace infini {
+namespace opTimer {
+double getPerfConvXdnn(int n, int c, int h, int w, int f, int r, int s,
+                       int padh, int padw, int strideh, int stridew,
+                       int dilationh, int dilationw, int group,
+                       const char *name);
+double getPerfMatmulXdnn(int b, int m, int n, int k, const char *name);
+} // namespace opTimer
+} // namespace infini
diff --git a/include/xpu/xpu_common.h b/include/xpu/xpu_common.h
new file mode 100644
index 00000000..743767d7
--- /dev/null
+++ b/include/xpu/xpu_common.h
@@ -0,0 +1,20 @@
+#pragma once
+#include "xpu/runtime_ex.h"
+#include "xpu/xdnn.h"
+#include "core/common.h"
+
+#define checkXPUError(call)                                                   \
+    {                                                                         \
+        auto err = call;                                                      \
+        if (XPU_SUCCESS != err) {                                             \
+            fprintf(stderr, "XPU error in %s:%i : %s.\n", __FILE__, __LINE__, \
+                    xpu_strerror(err));                                       \
+            exit(EXIT_FAILURE);                                               \
+        }                                                                     \
+    }
+
+namespace infini {
+
+using XPUPtr = void *;
+
+} // namespace infini
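A minimal usage sketch of checkXPUError wrapping raw runtime calls; xpu_malloc, xpu_free, XPU_SUCCESS, and xpu_strerror come from the Kunlun runtime headers included above, and the buffer size is arbitrary:

    #include "xpu/xpu_common.h"

    void checkedAllocDemo() {
        void *buf = nullptr;
        // On a non-XPU_SUCCESS return code, the macro prints file, line and
        // xpu_strerror(err), then exits.
        checkXPUError(xpu_malloc(&buf, 1024));
        checkXPUError(xpu_free(buf));
    }
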
diff --git a/include/xpu/xpu_kernel_without_config.h b/include/xpu/xpu_kernel_without_config.h
new file mode 100644
index 00000000..5966b4b1
--- /dev/null
+++ b/include/xpu/xpu_kernel_without_config.h
@@ -0,0 +1,24 @@
+#pragma once
+#include "xpu/xpu_runtime.h"
+#include "core/kernel.h"
+
+namespace infini {
+
+class XPUKernelWithoutConfig : public Kernel {
+  public:
+    virtual void compute(const Operator &op, const PerfRecord &record,
+                         const RuntimeObj *context) const {
+        compute(op, context);
+    }
+    virtual void compute(const Operator &op,
+                         const RuntimeObj *context) const = 0;
+    // Premise: op is idempotent since it is called multiple times.
+    virtual PerfRecord tune(const Operator &op,
+                            const RuntimeObj *_context) const {
+        auto context = dynamic_cast<const XPURuntimeObj *>(_context);
+        return make_ref<PerfRecordObj>(timeit([&]() { compute(op, _context); },
+                                              [&]() { context->sync(); }));
+    }
+};
+
+} // namespace infini
diff --git a/include/xpu/xpu_runtime.h b/include/xpu/xpu_runtime.h
new file mode 100644
index 00000000..ce1c58b5
--- /dev/null
+++ b/include/xpu/xpu_runtime.h
@@ -0,0 +1,63 @@
+#pragma once
+#include "xpu/xpu_common.h"
+#include "core/runtime.h"
+
+namespace infini {
+
+class XPURuntimeObj : public RuntimeObj {
+  private:
+    baidu::xpu::api::Context *xdnn;
+    XPUPtr workspace;
+    size_t workspaceSize;
+
+  public:
+    XPURuntimeObj() : RuntimeObj(Device::XPU) {
+        xdnn = baidu::xpu::api::create_context();
+        // 10GB for Longformer
+        // size_t longformerNum = 3lu * (1 << 30);
+        workspaceSize = 7ll << 30; // 7 GB
+        workspace = alloc(workspaceSize);
+    }
+    virtual ~XPURuntimeObj() {
+        dealloc(workspace);
+        baidu::xpu::api::destroy_context(xdnn);
+    }
+    string toString() const override;
+
+    void run(const Graph &graph, bool tune = false,
+             bool profiling = false) const;
+    // double runEvaluation(const Graph &graph, int nWarmups,
+    //                      int nEvaluations) const;
+    void sync() const;
+    XPUPtr alloc(size_t size) override {
+        void *ptr;
+        // Abort with a readable message if the device allocation fails.
+        checkXPUError(xpu_malloc(&ptr, size));
+        return ptr;
+    }
+    void dealloc(void *ptr) override { xpu_free(ptr); }
+    baidu::xpu::api::Context *XPUHandle() const { return xdnn; }
+    XPUPtr getWorkspace(size_t size) const {
+        IT_ASSERT(size <= workspaceSize);
+        return workspace;
+    }
+
+    void copyBlobFromCPU(void *dst, const void *src,
+                         size_t bytes) const override {
+        xpu_memcpy(dst, const_cast<void *>(src), bytes,
+                   XPUMemcpyKind::XPU_HOST_TO_DEVICE);
+    }
+
+    void copyBlobToCPU(void *dst, const void *src,
+                       size_t bytes) const override {
+        xpu_memcpy(dst, const_cast<void *>(src), bytes,
+                   XPUMemcpyKind::XPU_DEVICE_TO_HOST);
+    }
+
+    void copyBlobInsideRuntime(void *dst, const void *src,
+                               size_t bytes) const override {
+        xpu_memcpy(dst, const_cast<void *>(src), bytes,
+                   XPUMemcpyKind::XPU_DEVICE_TO_DEVICE);
+    }
+
+  private:
+    void runWithoutSync(const Graph &graph, bool tune, bool profiling) const;
+};
+
+} // namespace infini
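As a usage sketch of the runtime interface declared above (buffer contents and sizes are arbitrary), a host buffer can be round-tripped through device memory like this:

    #include "xpu/xpu_runtime.h"

    void roundTripDemo() {
        auto rt = infini::make_ref<infini::XPURuntimeObj>();
        float host[4] = {0.f, 1.f, 2.f, 3.f}, back[4] = {};
        infini::XPUPtr dev = rt->alloc(sizeof(host));
        rt->copyBlobFromCPU(dev, host, sizeof(host)); // host -> device
        rt->copyBlobToCPU(back, dev, sizeof(host));   // device -> host
        rt->dealloc(dev);
    }
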
+#include "operators/element_wise.h" +#include "xpu/xpu_kernel_without_config.h" +#include "xpu/xpu_runtime.h" + +namespace infini { +class AddXdnn : public XPUKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto aDim = op->getInputs(0)->getDims(); + auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() != 4 || bDim.size() != 4) + IT_TODO_HALT(); + auto ret = baidu::xpu::api::broadcast_add(context->XPUHandle(), (float*)aData, (float*)bData, (float*)cData, aDim, bDim); + assert(ret == 0); + return; + + } +}; + +REGISTER_KERNEL(Device::XPU, OpType::Add, DataType::Float32, AddXdnn, + "Add_xdnn_XPU_Float32"); +}; // namespace infini diff --git a/src/xpu/operator_timer.cc b/src/xpu/operator_timer.cc new file mode 100644 index 00000000..f08e71e1 --- /dev/null +++ b/src/xpu/operator_timer.cc @@ -0,0 +1,71 @@ +#include "xpu/operator_timer.h" +#include "xpu/xpu_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/conv.h" +#include "operators/matmul.h" +#include "utils/data_generator.h" + +namespace infini { +namespace opTimer { + +double getPerfConvXPU(int n, int c, int h, int w, int f, int r, int s, + int padh, int padw, int strideh, int stridew, + int dilationh, int dilationw, int group, + const char *name) { + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton + Graph gCpu = make_ref(cpu); + Runtime xpu = make_ref(); + Graph gXpu = make_ref(xpu); + // Set input data on CPU in a CPU Graph + IT_ASSERT(c % group == 0); + Tensor i0Cpu = gCpu->addTensor({n, h, w, c}, DataType::Float32); + Tensor w0Cpu = gCpu->addTensor({f, r, s, c / group}, DataType::Float32); + // Malloc data for all tensors in a graph. Do we need implicit allocation? + gCpu->dataMalloc(); + i0Cpu->setData(IncrementalGenerator()); + w0Cpu->setData(IncrementalGenerator()); + + // Copy input tensors from CPU to Xpu + Tensor i0XPU = gXpu->cloneTensor(i0Cpu); + Tensor w0XPU = gXpu->cloneTensor(w0Cpu); + // Build Xpu graph + auto conv = gXpu->addOp(i0XPU, w0XPU, nullptr, padh, padw, + strideh, stridew, dilationh, dilationw); + // allocate Xpu memory + gXpu->dataMalloc(); + // Execute on Xpu + bool tune = true; + xpu->run(gXpu, tune); + return xpu->getPerfTime(gXpu); +} + +double getPerfMatmulXPU(int b, int m, int n, int k, const char *name) { + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton + Graph gCpu = make_ref(cpu); + Runtime xpu = make_ref(); + Graph gXpu = make_ref(xpu); + // Set input data on CPU in a CPU Graph + Tensor i0Cpu = gCpu->addTensor({b, m, k}, DataType::Float32); + Tensor w0Cpu = gCpu->addTensor({b, k, n}, DataType::Float32); + // Malloc data for all tensors in a graph. Do we need implicit allocation? 
diff --git a/src/xpu/operator_timer.cc b/src/xpu/operator_timer.cc
new file mode 100644
index 00000000..f08e71e1
--- /dev/null
+++ b/src/xpu/operator_timer.cc
@@ -0,0 +1,71 @@
+#include "xpu/operator_timer.h"
+#include "xpu/xpu_runtime.h"
+#include "core/graph.h"
+#include "core/kernel.h"
+#include "core/runtime.h"
+#include "operators/conv.h"
+#include "operators/matmul.h"
+#include "utils/data_generator.h"
+
+namespace infini {
+namespace opTimer {
+
+double getPerfConvXdnn(int n, int c, int h, int w, int f, int r, int s,
+                       int padh, int padw, int strideh, int stridew,
+                       int dilationh, int dilationw, int group,
+                       const char *name) {
+    // The CPU runtime is a singleton.
+    Runtime cpu = NativeCpuRuntimeObj::getInstance();
+    Graph gCpu = make_ref<GraphObj>(cpu);
+    Runtime xpu = make_ref<XPURuntimeObj>();
+    Graph gXpu = make_ref<GraphObj>(xpu);
+    // Set input data on CPU in a CPU graph
+    IT_ASSERT(c % group == 0);
+    Tensor i0Cpu = gCpu->addTensor({n, h, w, c}, DataType::Float32);
+    Tensor w0Cpu = gCpu->addTensor({f, r, s, c / group}, DataType::Float32);
+    // Malloc data for all tensors in a graph. Do we need implicit allocation?
+    gCpu->dataMalloc();
+    i0Cpu->setData(IncrementalGenerator());
+    w0Cpu->setData(IncrementalGenerator());
+
+    // Copy input tensors from CPU to XPU
+    Tensor i0XPU = gXpu->cloneTensor(i0Cpu);
+    Tensor w0XPU = gXpu->cloneTensor(w0Cpu);
+    // Build XPU graph
+    auto conv = gXpu->addOp<ConvObj>(i0XPU, w0XPU, nullptr, padh, padw,
+                                     strideh, stridew, dilationh, dilationw);
+    // Allocate XPU memory
+    gXpu->dataMalloc();
+    // Execute on XPU
+    bool tune = true;
+    xpu->run(gXpu, tune);
+    return xpu->getPerfTime(gXpu);
+}
+
+double getPerfMatmulXdnn(int b, int m, int n, int k, const char *name) {
+    // The CPU runtime is a singleton.
+    Runtime cpu = NativeCpuRuntimeObj::getInstance();
+    Graph gCpu = make_ref<GraphObj>(cpu);
+    Runtime xpu = make_ref<XPURuntimeObj>();
+    Graph gXpu = make_ref<GraphObj>(xpu);
+    // Set input data on CPU in a CPU graph
+    Tensor i0Cpu = gCpu->addTensor({b, m, k}, DataType::Float32);
+    Tensor w0Cpu = gCpu->addTensor({b, k, n}, DataType::Float32);
+    // Malloc data for all tensors in a graph. Do we need implicit allocation?
+    gCpu->dataMalloc();
+    i0Cpu->setData(IncrementalGenerator());
+    w0Cpu->setData(IncrementalGenerator());
+
+    // Copy input tensors from CPU to XPU
+    Tensor i0XPU = gXpu->cloneTensor(i0Cpu);
+    Tensor w0XPU = gXpu->cloneTensor(w0Cpu);
+    // Build XPU graph
+    auto matmul = gXpu->addOp<MatmulObj>(i0XPU, w0XPU, nullptr);
+    // Allocate XPU memory
+    gXpu->dataMalloc();
+    // Execute on XPU
+    bool tune = true;
+    xpu->run(gXpu, tune);
+    return xpu->getPerfTime(gXpu);
+}
+
+} // namespace opTimer
+} // namespace infini
diff --git a/src/xpu/xpu_runtime.cc b/src/xpu/xpu_runtime.cc
new file mode 100644
index 00000000..3f973723
--- /dev/null
+++ b/src/xpu/xpu_runtime.cc
@@ -0,0 +1,59 @@
+#include "xpu/xpu_runtime.h"
+#include "core/kernel.h"
+#include "core/perf_engine.h"
+
+namespace infini {
+
+void XPURuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
+                                   bool profiling = false) const {
+    const auto &kernelRegistry = KernelRegistry::getInstance();
+    auto &perfEngine = PerfEngine::getInstance();
+    double totalTime = 0;
+    std::map<OpType, double> opTime;
+    std::map<OpType, int> opCnt;
+    for (auto &op : graph->getOperators()) {
+        // HACK: set correct data type
+        auto kernelAttrs =
+            KernelAttrs{device, op->getOpType(), DataType::Float32};
+        Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
+        auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
+        auto perfData = perfEngine.getPerfData(perfKey);
+        if (!perfData && !tune) {
+            kernel->compute(op, this);
+            continue;
+        }
+
+        PerfRecord record;
+        if (!perfData) {
+            record = kernel->tune(op, this);
+            perfEngine.setPerfData(perfKey, record);
+        } else
+            record = perfData;
+
+        double t = record->time;
+        totalTime += t;
+
+        if (profiling) {
+            double t = timeit([&]() { kernel->compute(op, record, this); },
+                              [&]() { sync(); }, 1, 1);
+            op->print();
+            printf(" op_time on xpu %lf\n", t);
+            totalTime += t;
+            opTime[op->getOpType()] += t;
+            opCnt[op->getOpType()]++;
+        }
+    }
+}
+
+void XPURuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
+    if (profiling)
+        IT_TODO_HALT();
+    runWithoutSync(graph, tune, profiling);
+    sync();
+}
+
+// Block until all queued XPU work has finished, so timings are meaningful.
+void XPURuntimeObj::sync() const { xpu_wait(); }
+
+string XPURuntimeObj::toString() const { return "XPU Runtime"; }
+
+} // namespace infini
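A usage sketch for the timer entry points defined above; the shapes are arbitrary, and each call builds its own graph, tunes the kernels, and returns the accumulated kernel time from getPerfTime:

    #include "xpu/operator_timer.h"
    #include <cstdio>

    int main() {
        using namespace infini::opTimer;
        // 3x3 conv: n=1, c=f=64, 56x56 input, stride/dilation 1, group 1.
        double convTime = getPerfConvXdnn(1, 64, 56, 56, 64, 3, 3, 1, 1, 1, 1,
                                          1, 1, 1, "conv3x3");
        double matmulTime = getPerfMatmulXdnn(8, 512, 512, 512, "matmul512");
        std::printf("conv: %lf, matmul: %lf\n", convTime, matmulTime);
    }
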
diff --git a/test/kernels/xpu/test_xpu_add.cc b/test/kernels/xpu/test_xpu_add.cc
new file mode 100644
index 00000000..a658dcdf
--- /dev/null
+++ b/test/kernels/xpu/test_xpu_add.cc
@@ -0,0 +1,54 @@
+#include "xpu/xpu_runtime.h"
+#include "core/graph.h"
+#include "core/kernel.h"
+#include "core/runtime.h"
+#include "operators/element_wise.h"
+
+#include "test.h"
+
+namespace infini {
+
+template <class T>
+void testAdd(const std::function<void(void *, size_t, DataType)> &generator,
+             const Shape &shape) {
+    // Runtime
+    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
+    auto xpuRuntime = make_ref<XPURuntimeObj>();
+
+    // Build input data on CPU
+    Tensor inputCpu1 =
+        make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
+    inputCpu1->dataMalloc();
+    inputCpu1->setData(generator);
+    Tensor inputCpu2 =
+        make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
+    inputCpu2->dataMalloc();
+    inputCpu2->setData(generator);
+
+    // XPU
+    Graph xpuGraph = make_ref<GraphObj>(xpuRuntime);
+    auto inputGpu1 = xpuGraph->cloneTensor(inputCpu1);
+    auto inputGpu2 = xpuGraph->cloneTensor(inputCpu2);
+    auto gpuOp = xpuGraph->addOp<T>(inputGpu1, inputGpu2, nullptr);
+    xpuGraph->dataMalloc();
+    xpuRuntime->run(xpuGraph);
+    auto outputGpu = gpuOp->getOutput();
+    auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
+    // CPU
+    Graph cpuGraph = make_ref<GraphObj>(cpuRuntime);
+    auto cpuOp = cpuGraph->addOp<T>(inputCpu1, inputCpu2, nullptr);
+    cpuGraph->dataMalloc();
+    cpuRuntime->run(cpuGraph);
+    auto outputCpu = cpuOp->getOutput();
+    // Check
+    // outputCpu->printData();
+    // outputGpu2Cpu->printData();
+    EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu));
+}
+
+TEST(xpu_add, run) {
+    testAdd<AddObj>(IncrementalGenerator(), Shape{10, 256, 256, 3});
+}
+
+} // namespace infini
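For completeness, a standalone driver that stitches the pieces together without gtest; it mirrors the test above, and the shapes and generators are arbitrary:

    #include "core/graph.h"
    #include "core/runtime.h"
    #include "operators/element_wise.h"
    #include "utils/data_generator.h"
    #include "xpu/xpu_runtime.h"

    int main() {
        using namespace infini;
        // Prepare inputs on the host.
        Runtime cpu = NativeCpuRuntimeObj::getInstance();
        Graph gCpu = make_ref<GraphObj>(cpu);
        Tensor a = gCpu->addTensor({2, 4, 4, 3}, DataType::Float32);
        Tensor b = gCpu->addTensor({2, 4, 4, 3}, DataType::Float32);
        gCpu->dataMalloc();
        a->setData(IncrementalGenerator());
        b->setData(OneGenerator());

        // Clone onto the XPU, run, and copy the result back for inspection.
        auto xpu = make_ref<XPURuntimeObj>();
        Graph gXpu = make_ref<GraphObj>(xpu);
        auto add = gXpu->addOp<AddObj>(gXpu->cloneTensor(a),
                                       gXpu->cloneTensor(b), nullptr);
        gXpu->dataMalloc();
        xpu->run(gXpu);
        auto out = add->getOutput()->clone(cpu);
        out->printData();
        return 0;
    }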