From 00e6cc2587874116645cca70b044c742802994ac Mon Sep 17 00:00:00 2001
From: zhangyue <138768300+zhangyue207@users.noreply.github.com>
Date: Thu, 29 Feb 2024 11:48:35 +0800
Subject: [PATCH] XCCL support (#171)

* add reduce_mean and gather
* fix format
* add kunlun allreduce and cmakefile
* add kunlun allreduce and cmakefile
* delete cmake opt
* fix format
* fix makefile
* add DIST option in Makefile
* add xpu allgather
* delete xpu_wait()
* add xpu allgather
* delete specific compiler
* fix format
* fix gather
* add broadcast
* fix format
* fix
* fix xpu, add where operation, fix element-wise operation
* fix softmax
* fix softmax
* log internal input and output
* fix kunlun gather bugs
* update CMakeList.txt and Makefile
* fix some kunlun kernels
* fix Makefile
* fix Makefile
* set cmake version 3.12
* format
* fix where, gather and support gpt2
* fix format
* fix format
* copy onnx.py from master
* use KUNLUN_HOME instead of absolute path
* fix torchvision models
* support torchvision model zoo
* fix format
* format fix, CMakeList fix
* fix review
* fix vecToString return value
* fix format
* delete empty file

---------

Co-authored-by: wanghailu
Co-authored-by: wanghailu
Co-authored-by: Haojie Wang
---
 CMakeLists.txt | 16 +-
 Makefile | 1 +
 cmake/FindXCCL.cmake | 27 +++
 examples/NNmodel | 2 +-
 .../distributed/{launch.py => cuda_launch.py} | 0
 examples/distributed/launch_kunlun.py | 213 ++++++++++++++++++
 include/core/common.h | 32 ++-
 include/core/runtime.h | 2 +
 include/core/tensor.h | 29 ++-
 include/core/workspace.h | 42 ++++
 include/kunlun/kunlun_act_type.h | 23 ++
 include/kunlun/kunlun_common.h | 2 +
 include/kunlun/kunlun_runtime.h | 48 ++--
 include/kunlun/xccl_communicator.h | 60 +++++
 include/operators/unary.h | 1 +
 include/utils/broadcast_shape.h | 14 --
 include/utils/operator_utils.h | 9 +
 include/utils/small_array.h | 8 +
 pyinfinitensor/src/pyinfinitensor/onnx.py | 6 +-
 src/core/graph_handler.cc | 2 +
 src/core/tensor.cc | 30 +++
 src/ffi/ffi_infinitensor.cc | 4 +-
 src/kernels/cpu/unary.cc | 3 +
 src/kernels/cuda/where.cc | 2 +-
 src/kernels/kunlun/all_gather.cc | 43 ++++
 src/kernels/kunlun/all_reduce.cc | 49 ++++
 src/kernels/kunlun/batch_norm.cc | 2 +-
 src/kernels/kunlun/broadcast.cc | 32 +++
 src/kernels/kunlun/cast.cc | 58 ++---
 src/kernels/kunlun/concat.cc | 4 +-
 src/kernels/kunlun/conv.cc | 12 +-
 src/kernels/kunlun/conv_trans.cc | 9 +-
 src/kernels/kunlun/element_wise.cc | 194 ++++++++--------
 src/kernels/kunlun/gather.cc | 18 +-
 src/kernels/kunlun/matmul.cc | 113 +++++++++-
 src/kernels/kunlun/pooling.cc | 33 ++-
 .../kunlun/{reduce_mean.cc => reduce.cc} | 28 ++-
 src/kernels/kunlun/select.cc | 32 ---
 src/kernels/kunlun/slice.cc | 39 ++++
 src/kernels/kunlun/softmax.cc | 6 +-
 src/kernels/kunlun/transpose.cc | 16 +-
 src/kernels/kunlun/unary.cc | 175 ++++++++------
 src/kernels/kunlun/where.cc | 67 ++++++
 src/kunlun/kunlun_runtime.cc | 15 +-
 src/operators/matmul.cc | 2 +-
 src/operators/unary.cc | 2 +
 src/utils/operator_utils.cc | 24 ++
 test/kernels/kunlun/test_kunlun_allgather.cc | 50 ++++
 test/kernels/kunlun/test_kunlun_allreduce.cc | 72 ++++++
 test/kernels/kunlun/test_kunlun_broadcast.cc | 56 +++++
 test/kernels/kunlun/test_kunlun_gather.cc | 144 ++++++++++++
 test/kernels/kunlun/test_kunlun_matmul.cc | 142 ++++++++----
 test/kernels/kunlun/test_kunlun_slice.cc | 39 ++++
 test/kernels/kunlun/test_kunlun_where.cc | 77 +++++++
 test/kunlun/test_kunlun_workspace.cc | 20 ++
 55 files changed, 1780 insertions(+), 369 deletions(-)
 create mode 100644 cmake/FindXCCL.cmake
 rename examples/distributed/{launch.py => cuda_launch.py} (100%)
 create mode 100644 examples/distributed/launch_kunlun.py
 create mode 100644 include/core/workspace.h
 create mode 100644 include/kunlun/kunlun_act_type.h
 create mode 100644 include/kunlun/xccl_communicator.h
 delete mode 100644 include/utils/broadcast_shape.h
 create mode 100644 src/kernels/kunlun/all_gather.cc
 create mode 100644 src/kernels/kunlun/all_reduce.cc
 create mode 100644 src/kernels/kunlun/broadcast.cc
 rename src/kernels/kunlun/{reduce_mean.cc => reduce.cc} (53%)
 delete mode 100644 src/kernels/kunlun/select.cc
 create mode 100644 src/kernels/kunlun/slice.cc
 create mode 100644 src/kernels/kunlun/where.cc
 create mode 100644 test/kernels/kunlun/test_kunlun_allgather.cc
 create mode 100644 test/kernels/kunlun/test_kunlun_allreduce.cc
 create mode 100644 test/kernels/kunlun/test_kunlun_broadcast.cc
 create mode 100644 test/kernels/kunlun/test_kunlun_gather.cc
 create mode 100644 test/kernels/kunlun/test_kunlun_slice.cc
 create mode 100644 test/kernels/kunlun/test_kunlun_where.cc
 create mode 100644 test/kunlun/test_kunlun_workspace.cc

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 19d11183..ccacac1c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -53,11 +53,13 @@ endif()
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_EXTENSIONS OFF) # -std=gnu++11 when on, -std=c++11 when off
 
+add_compile_options(-Wno-error=unused-variable)
 find_package(
   Python
   COMPONENTS Interpreter Development
   REQUIRED)
+
 # OpenMP
 find_package(OpenMP)
 if(OpenMP_C_FOUND)
@@ -282,9 +284,9 @@ if(USE_KUNLUN)
   endif()
   message(STATUS "KUNLUN_HOME: ${KUNLUN_HOME}")
 
-  include_directories("${KUNLUN_HOME}/XTDK/include/")
-  find_library(KUNLUN_RT libxpurt.so "${KUNLUN_HOME}/lib64")
-  find_library(KUNLUN_DNN libxpuapi.so "${KUNLUN_HOME}/XTDK/shlib")
+  include_directories("${KUNLUN_HOME}/include/")
+  find_library(KUNLUN_RT libxpurt.so "${KUNLUN_HOME}/so/")
+  find_library(KUNLUN_DNN libxpuapi.so "${KUNLUN_HOME}/so/")
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror")
 
   if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH}))
@@ -297,6 +299,13 @@ if(USE_KUNLUN)
   endif()
   message(STATUS "TARGET_CPU_ARCH: ${TARGET_CPU_ARCH}")
 
+  if (BUILD_DIST)
+    message(STATUS "Add BUILD_DIST, use XCCL with KUNLUN XPU")
+    list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake)
+    find_package(XCCL REQUIRED)
+    add_compile_definitions(INFINI_USE_XCCL=1)
+    target_link_libraries(InfiniTensor ${XCCL_LIBRARIES})
+  endif()
   target_link_libraries(InfiniTensor ${KUNLUN_RT} ${KUNLUN_DNN} stdc++)
 endif()
 
@@ -335,6 +344,7 @@ if(BUILD_TEST)
   endif()
   if (USE_KUNLUN)
     build_test(test/kernels/kunlun/*.cc)
+    build_test(test/kunlun/*.cc)
   endif()
   if (USE_INTELCPU)
     build_test(test/kernels/intelcpu/*.cc)

diff --git a/Makefile b/Makefile
index ff2ad0a9..c0e0c8b7 100644
--- a/Makefile
+++ b/Makefile
@@ -9,6 +9,7 @@
 BACKTRACE ?= ON
 TEST ?= ON
 NNET ?= OFF
+DIST ?= OFF
 FORMAT_ORIGIN ?=
 # Docker build options
 DOCKER_NAME ?= infinitensor
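The distributed build is switched on in three steps: `DIST=ON` at the `make` level (presumably forwarded to CMake as `BUILD_DIST`), the `BUILD_DIST` branch above, which resolves BKCL through the new `cmake/FindXCCL.cmake` module, and finally the `INFINI_USE_XCCL` compile definition that the C++ sources key on. A minimal sketch of the consuming side (the same guard appears verbatim in the `kunlun_runtime.h` hunk further down):

```cpp
// Only pull in the XCCL communicator when the build ran with BUILD_DIST=ON,
// i.e. when INFINI_USE_XCCL was injected via add_compile_definitions above.
#ifdef INFINI_USE_XCCL
#include "kunlun/xccl_communicator.h"
#endif
```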
diff --git a/cmake/FindXCCL.cmake b/cmake/FindXCCL.cmake
new file mode 100644
index 00000000..fa4b3b00
--- /dev/null
+++ b/cmake/FindXCCL.cmake
@@ -0,0 +1,27 @@
+# Find the xccl libraries
+set(XCCL_INCLUDE_DIR $ENV{KUNLUN_HOME}/include CACHE PATH "Folder contains KUNLUN XCCL headers")
+set(XCCL_LIB_DIR $ENV{KUNLUN_HOME} CACHE PATH "Folder contains KUNLUN XCCL libraries")
+
+list(APPEND CMAKE_PREFIX_PATH $ENV{KUNLUN_HOME})
+
+find_path(XCCL_INCLUDE_DIRS
+  NAMES xpu/bkcl.h
+  HINTS ${XCCL_INCLUDE_DIR})
+
+find_library(XCCL_LIBRARIES
+  NAMES so/libbkcl.so
+  HINTS ${XCCL_LIB_DIR})
+
+message(STATUS "XCCL_INCLUDE_DIRS: ${XCCL_INCLUDE_DIRS}")
+message(STATUS "XCCL_LIBRARIES: ${XCCL_LIBRARIES}")
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(XCCL DEFAULT_MSG XCCL_INCLUDE_DIRS XCCL_LIBRARIES)
+
+if (XCCL_FOUND)
+  set (XCCL_HEADER_FILE "${XCCL_INCLUDE_DIRS}/xpu/bkcl.h")
+  message (STATUS "Determining XCCL version from ${XCCL_HEADER_FILE}...")
+  list (APPEND CMAKE_REQUIRED_INCLUDES ${XCCL_INCLUDE_DIRS})
+  message(STATUS "Found XCCL (include: ${XCCL_INCLUDE_DIRS}, library: ${XCCL_LIBRARIES})")
+  mark_as_advanced(XCCL_INCLUDE_DIRS XCCL_LIBRARIES)
+endif()

diff --git a/examples/NNmodel b/examples/NNmodel
index b896cec2..51d31052 160000
--- a/examples/NNmodel
+++ b/examples/NNmodel
@@ -1 +1 @@
-Subproject commit b896cec2dba5b8522b141ac4f89eb43074ee1b98
+Subproject commit 51d3105277f3774ed31c02ed4cd11fa92925af77

diff --git a/examples/distributed/launch.py b/examples/distributed/cuda_launch.py
similarity index 100%
rename from examples/distributed/launch.py
rename to examples/distributed/cuda_launch.py

diff --git a/examples/distributed/launch_kunlun.py b/examples/distributed/launch_kunlun.py
new file mode 100644
index 00000000..e8c1a0ab
--- /dev/null
+++ b/examples/distributed/launch_kunlun.py
@@ -0,0 +1,213 @@
+import argparse
+import os
+import time
+import multiprocessing as mp
+from pyinfinitensor.onnx import OnnxStub, backend
+import onnx
+from onnx.external_data_helper import convert_model_to_external_data
+from onnx.shape_inference import infer_shapes_path
+import numpy as np
+from parallel_opt import parallel_model
+
+st_input_dir = "standard/inputs/"
+st_output_dir = "standard/outputs/"
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="launch distributed infinitensor")
+    parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
+    parser.add_argument(
+        "--nproc_per_node", type=int, default=2, help="number of processes per node"
+    )
+    parser.add_argument(
+        "--name", type=str, default="test", help="name of this instance."
+    )
+    parser.add_argument(
+        "--model", type=str, default="/data1/shared/panzezhong/llama/fp32/my_llama_fp32.sim.onnx", help="path to the ONNX model file."
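+        # NOTE: the default above points at a site-specific file; pass
+        # --model explicitly when running outside that environment.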
+ ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size.") + parser.add_argument("--length", type=int, default=1, help="sequence length.") + parser.add_argument( + "--gen_std", + default=False, + action="store_true", + help="whether to generate the standard results.", + ) + parser.add_argument( + "--run_single", + default=False, + action="store_true", + help="whether run model with single process with standard inputs" + ) + args = parser.parse_args() + print("arg setting: ", args) + return ( + args.num_nodes, + args.nproc_per_node, + args.name, + args.model, + args.batch_size, + args.length, + args.gen_std, + args.run_single + ) + + +def run_model(model, runtime, world_size=1, rank=0, n=10): + stub = OnnxStub(model, runtime) + load_inputs(stub, world_size, rank) + # stub.tune() + stub.run() + # get outputs + time.sleep(0.01) + outputs = next(stub.outputs.values().__iter__()).copyout_numpy() + + # bench + begin = time.time() + for _ in range(n): + stub.run() + end = time.time() + avg_time = (end - begin) / n + print(f"average time: {avg_time}") + return outputs + + + +def run_and_compare(name, model, runtime, world_size=1, rank = 0): + results = np.load(os.path.join(st_output_dir,f"output.npy")) + outputs = run_model(model, runtime, world_size, rank) + print(outputs[:100]) + if np.isnan(outputs).any(): + print("Nan in output") + print("answer argmax:", np.argmax(results)) + print("output argmax:", np.argmax(outputs)) + #np.testing.assert_allclose(outputs, results, rtol=1e-3, atol=1e-3) + getDiff(results, outputs) + + +def start_worker( + name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto +): + dist_name = name + "_dist" + model = parallel_model(model, world_size, rank) + extern_path = f"./{dist_name}_rank{rank}.pb" + if os.path.exists(extern_path): + os.remove(extern_path) + onnx.save_model( + model, + f"./{dist_name}_rank{rank}.onnx", + save_as_external_data=True, + location=extern_path, + ) + infer_shapes_path(f"./{dist_name}_rank{rank}.onnx") + runtime = backend.KUNLUNRuntime(local_rank) + # print("init comm") + runtime.init_comm( + dist_name, + world_size, + rank, + ) + run_and_compare(name, model, runtime, world_size, rank) + + +def start_single(name, model): + runtime = backend.KUNLUNRuntime(0) + run_and_compare(name, model, runtime) + + +def generate_input_output(model): + runtime = backend.KUNLUNRuntime(0) + stub = OnnxStub(model, runtime) + position_id = 0 + for i, (name, tensor) in enumerate(stub.inputs.items()): + input = tensor.copyout_numpy() + if np.issubdtype(input.dtype, np.integer): + if input.size == 1: + # input = np.array([position_id]) + input = np.random.randint(0,2,size=input.shape, dtype=input.dtype) + else: + input = np.random.randint(0,2,size=input.shape, dtype=input.dtype) + elif input.dtype == np.bool_: + input = np.random.randint(0,2,size=input.shape) > 0 + else: + if i == 0: + input = np.ones(input.shape).astype(input.dtype) + position_id = input.shape[-1] - 1 + else: + input = np.random.rand(*input.shape).astype(input.dtype) + tensor.copyin_numpy(input) + np.save(os.path.join(st_input_dir, f"input_{i}"), input) + stub.run() + # print(stub.outputs) + time.sleep(0.01) + output = next(stub.outputs.values().__iter__()).copyout_numpy() + print(output[:100]) + if np.isnan(output).any(): + print("Nan in output") + np.save(os.path.join(st_output_dir, f"output"), output) + + +def load_inputs(stub, world_size=1, rank=0): + for i, (name, tensor) in enumerate(stub.inputs.items()): + input = 
np.load(os.path.join(st_input_dir, f"input_{i}.npy"))
+        if all(x == y for x, y in zip(input.shape, tensor.shape())):
+            tensor.copyin_numpy(input)
+        else:
+            tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
+
+
+def getDiff(base, test):
+    absolute_diff = np.abs(np.subtract(base, test))
+    max_absolute_diff = np.max(absolute_diff)
+
+    baseCopy = base.astype(np.float64).ravel()
+    testCopy = test.astype(np.float64).ravel()
+    upValue = np.sum(np.abs(baseCopy - testCopy))
+    downValue = np.sum(np.abs(baseCopy)) + np.float64(1e-9)
+    max_relative_diff = upValue / downValue
+    print(f"Max absolute difference: {max_absolute_diff}\nMax relative difference: {max_relative_diff}")
+
+    return max_absolute_diff, max_relative_diff
+
+
+def main():
+    nnodes, nproc_per_node, name, model_path, bs, length, gen_std, run_single = parse_args()
+
+    model = onnx.load(model_path)
+
+    # generate standard output
+    if gen_std:
+        print("Generate inputs and outputs.")
+        p = mp.Process(target=generate_input_output, args=[model])
+        p.start()
+        p.join()
+        return
+
+    # run single process.
+    # use a standalone process to isolate the device context.
+    if run_single:
+        print("run model on a single GPU.")
+        p = mp.Process(target=start_single, args=(name, model))
+        p.start()
+        p.join()
+        return
+
+    # run distributed parallel.
+    world_size = nnodes * nproc_per_node
+    print(f"run model on {world_size} GPUs in parallel.")
+    workers = [
+        mp.Process(
+            target=start_worker,
+            args=(name, world_size, rank, rank % nproc_per_node, model),
+        )
+        for rank in range(world_size)
+    ]
+
+    for w in workers:
+        w.start()
+
+    for w in workers:
+        w.join()
+
+
+if __name__ == "__main__":
+    main()

diff --git a/include/core/common.h b/include/core/common.h
index 81e704f8..3a847783 100644
--- a/include/core/common.h
+++ b/include/core/common.h
@@ -61,16 +61,30 @@ template <typename T> auto enum_to_underlying(T e) {
 }
 
 template <typename T> std::string vecToString(const std::vector<T> &vec) {
-    std::string ret;
-    ret.append("[");
-    for (auto d : vec) {
-        ret.append(std::to_string(d));
-        ret.append(",");
+    std::stringstream ss;
+    ss << "[";
+    for (size_t i = 0; i < vec.size(); ++i) {
+        ss << vec.at(i);
+        if (i < vec.size() - 1) {
+            ss << ",";
+        }
     }
-    if (!vec.empty())
-        ret.pop_back();
-    ret.append("]");
-    return ret;
+    ss << "]";
+    return ss.str();
+}
+
+template <typename T> std::string vecToString(const T *st, size_t length) {
+    std::stringstream ss;
+    ss << "[";
+    size_t i = 0;
+    for (i = 0; i < length; i++) {
+        ss << *(st + i);
+        if (i < length - 1) {
+            ss << ",";
+        }
+    }
+    ss << "]";
+    return ss.str();
 }
 
 double timeit(
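Both `vecToString` overloads now stream arbitrary element types instead of relying on `std::to_string`, and the new pointer overload covers raw buffers that have no `std::vector` wrapper. A minimal usage sketch (hypothetical values; assumes `core/common.h` is on the include path):

```cpp
#include "core/common.h"
#include <iostream>
#include <vector>

int main() {
    std::vector<int> shape = {1, 3, 224, 224};
    std::cout << infini::vecToString(shape) << "\n"; // prints [1,3,224,224]

    float buf[3] = {0.5f, 1.5f, 2.5f};
    // Pointer overload: handy when only a raw array and a length exist.
    std::cout << infini::vecToString(buf, 3) << "\n"; // prints [0.5,1.5,2.5]
    return 0;
}
```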
diff --git a/include/core/runtime.h b/include/core/runtime.h
index 5bc2123e..c5544276 100644
--- a/include/core/runtime.h
+++ b/include/core/runtime.h
@@ -15,6 +15,7 @@ class GraphObj;
 class GraphHandlerObj;
 class RuntimeObj;
 class BlobObj;
+template <typename T> class WorkspaceObj;
 
 using TensorBase = Ref<TensorBaseObj>;
 using Tensor = Ref<TensorObj>;
@@ -23,6 +24,7 @@ using Graph = Ref<GraphObj>;
 using GraphHandler = Ref<GraphHandlerObj>;
 using Runtime = Ref<RuntimeObj>;
 using Blob = Ref<BlobObj>;
+template <typename T> using Workspace = Ref<WorkspaceObj<T>>;
 
 using TensorVec = vector<Tensor>;
 using OpVec = vector<Operator>;

diff --git a/include/core/tensor.h b/include/core/tensor.h
index 63efd0f7..33fa78fc 100644
--- a/include/core/tensor.h
+++ b/include/core/tensor.h
@@ -4,6 +4,7 @@
 #include "utils/data_convert.h"
 #include <cmath>
 #include <cstring>
+#include <fstream>
 
 #if USE_CUDA
 #include "cuda/cuda_runtime.h"
@@ -143,6 +144,7 @@ class TensorObj : public TensorBaseObj {
     }
 
     void printData() const;
+    void dumpData(std::ofstream &ofs) const;
     bool equalData(const Tensor &rhs, double relativeError = 1e-6) const;
 
     template <typename T> bool equalData(const vector<T> &dataVector) {
@@ -198,13 +200,20 @@ class TensorObj : public TensorBaseObj {
             if (a[i] != b[i])
                 return false;
         } else if constexpr (std::is_floating_point_v<T>) {
-            if (fabs(a[i] - b[i]) / std::max(fabs(a[i]), fabs(b[i])) >
-                relativeError) {
+            if (std::min(fabs(a[i]), fabs(b[i])) == 0. &&
+                fabs(a[i] - b[i]) > relativeError) {
+                printf("Error on %lu: %f %f\n", i, a[i], b[i]);
+                return false;
+            } else if (std::min(fabs(a[i]), fabs(b[i])) != 0. &&
+                       fabs(a[i] - b[i]) /
+                               std::max(fabs(a[i]), fabs(b[i])) >
+                           relativeError) {
                 printf("Error on %lu: %f %f\n", i, a[i], b[i]);
                 return false;
             }
-        } else
+        } else {
             static_assert(!sizeof(T), "Unsupported data type");
+        }
     }
     return true;
 }
@@ -239,8 +248,8 @@ class TensorObj : public TensorBaseObj {
     //     // std::cerr << "Init beginned " << std::endl;
     // #pragma omp parallel for
     //     for (size_t i = 0; i < iEnd; ++i)
-    //         data[i] = fastrand(random_seed[omp_get_thread_num() * 16]) %
-    //         10000;
+    //         data[i] = fastrand(random_seed[omp_get_thread_num() *
+    //         16]) % 10000;
     //     // std::cerr << "Init finished" << std::endl;
     //     computed = ComputedFull;
     //     return true;
@@ -285,8 +294,8 @@ class TensorObj : public TensorBaseObj {
     //     auto nDim = dims.size();
     //     auto nBroadcastDim = ds.size() - nDim;
     //     for (size_t i = 0; i < nDim; ++i)
-    //         if (ds[nBroadcastDim + i] < 0 || ds[nBroadcastDim + i] >=
-    //             dims[i])
+    //         if (ds[nBroadcastDim + i] < 0 || ds[nBroadcastDim +
+    //             i] >= dims[i])
    //             return (size_t)-1;
    //     size_t idx = 0;
    //     for (size_t i = 0; i < nDim; ++i)
@@ -345,12 +354,14 @@ class TensorObj : public TensorBaseObj {
    //     return (g_seed >> 16) & 0x7FFF;
    // }
 
-    // std::vector<std::vector<int>> const *getSplittingPoints() const {
+    // std::vector<std::vector<int>> const *getSplittingPoints()
+    // const {
    //     assert(!splittingPoints.empty());
    //     return &splittingPoints;
    // }
 
-    // bool setSplittingPoints(std::vector<std::vector<int>> value) {
+    // bool setSplittingPoints(std::vector<std::vector<int>> value)
+    // {
    //     assert(!value.empty());
    //     splittingPoints = value;
    //     return true;

diff --git a/include/core/workspace.h b/include/core/workspace.h
new file mode 100644
index 00000000..80be67a4
--- /dev/null
+++ b/include/core/workspace.h
@@ -0,0 +1,42 @@
+#pragma once
+#include "core/runtime.h"
+
+namespace infini {
+
+template <class T> class WorkspaceObj {
+  private:
+    T workspace;           // workspace pointer
+    size_t workspaceSize;  // size of the workspace in bytes
+    size_t workspaceAlloc; // bytes of the workspace currently in use
+
+  public:
+    WorkspaceObj(T workspace_, size_t workspaceSize_)
+        : workspace(workspace_), workspaceSize(workspaceSize_) {
+        workspaceAlloc = 0;
+    }
+    virtual ~WorkspaceObj() {
+        // The memory is deallocated by the owning RuntimeObj;
+        // only drop the pointer here
+        workspace = nullptr;
+    }
+    size_t getWorkspaceSize() const { return workspaceSize; }
+
+    T getWorkspace(size_t size) {
+        // Bump-allocate an unused slice of the workspace
+        IT_ASSERT(size + workspaceAlloc <= workspaceSize);
+        auto ret = (T)(static_cast<uint8_t *>(workspace) + workspaceAlloc);
+        workspaceAlloc += size;
+        return ret;
+    }
+    T getWorkspace() {
+        // No-argument overload returns the base pointer so the
+        // runtime can deallocate it
+        return workspace;
+    }
+    void resetWorkspace() {
+        // Reset workspaceAlloc after each kernel finishes
+        workspaceAlloc = 0;
+    }
+    size_t getWorkspaceAlloc() const { return workspaceAlloc; }
+};
+
+} // namespace infini
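`KUNLUNRuntimeObj` below wraps its single 3 GB arena in this class: kernels bump-allocate scratch slices during a launch and the runtime resets the offset afterwards. A minimal sketch of that lifecycle, using host memory as a stand-in for the device pointer and hypothetical sizes:

```cpp
#include "core/workspace.h"
#include <cstdlib>

using namespace infini;

int main() {
    // Stand-in for a device allocation; the real owner is KUNLUNRuntimeObj.
    size_t total = 1024;
    void *arena = std::malloc(total);

    auto ws = make_ref<WorkspaceObj<void *>>(arena, total);
    void *a = ws->getWorkspace(256); // first scratch slice
    void *b = ws->getWorkspace(512); // second slice, 256 bytes further in
    (void)a; (void)b;
    IT_ASSERT(ws->getWorkspaceAlloc() == 768);

    ws->resetWorkspace();          // kernel finished: offset back to zero
    std::free(ws->getWorkspace()); // base pointer, normally via dealloc()
    return 0;
}
```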
diff --git a/include/kunlun/kunlun_act_type.h b/include/kunlun/kunlun_act_type.h
new file mode 100644
index 00000000..cd49808e
--- /dev/null
+++ b/include/kunlun/kunlun_act_type.h
@@ -0,0 +1,23 @@
+#include "core/op_type.h"
+#include "kunlun/kunlun_common.h"
+
+namespace infini {
+using KunlunActType = xdnn::Activation_t;
+KunlunActType parseActType(ActType act) {
+    switch (act) {
+    case ActType::None:
+        return KunlunActType::LINEAR;
+    case ActType::Tanh:
+        return KunlunActType::TANH;
+    case ActType::Sigmoid:
+        return KunlunActType::SIGMOID;
+    case ActType::Relu:
+        return KunlunActType::RELU6;
+    default:
+        fprintf(stderr, "Activation type not supported yet!\n");
+        break;
+    }
+    return KunlunActType::LINEAR;
+}
+
+}; // namespace infini

diff --git a/include/kunlun/kunlun_common.h b/include/kunlun/kunlun_common.h
index 2350cc93..fc390ef7 100644
--- a/include/kunlun/kunlun_common.h
+++ b/include/kunlun/kunlun_common.h
@@ -3,6 +3,8 @@
 #include "xpu/runtime_ex.h"
 #include "xpu/xdnn.h"
 
+namespace xdnn = baidu::xpu::api;
+
 #define checkKUNLUNError(call) \
     { \
         auto err = call; \

diff --git a/include/kunlun/kunlun_runtime.h b/include/kunlun/kunlun_runtime.h
index 6a5be4c9..0c175158 100644
--- a/include/kunlun/kunlun_runtime.h
+++ b/include/kunlun/kunlun_runtime.h
@@ -1,28 +1,35 @@
 #pragma once
 #include "core/runtime.h"
+#include "core/workspace.h"
 #include "kunlun/kunlun_common.h"
-
+#ifdef INFINI_USE_XCCL
+#include "kunlun/xccl_communicator.h"
+#endif
 namespace infini {
 
 class KUNLUNRuntimeObj : public RuntimeObj {
   private:
-    baidu::xpu::api::Context *xdnn;
-    KUNLUNPtr workspace;
-    size_t workspaceSize;
+    xdnn::Context *ctx;
+    std::unique_ptr<CommunicatorObj> comm;
+    // KUNLUNPtr workspace;
+    // size_t workspaceSize;
+    Workspace<KUNLUNPtr> workspace;
 
   public:
-    KUNLUNRuntimeObj() : RuntimeObj(Device::KUNLUN) {
-        xdnn = baidu::xpu::api::create_context();
+    KUNLUNRuntimeObj(int deviceId = 0) : RuntimeObj(Device::KUNLUN) {
+        xpu_set_device(deviceId);
+        ctx = xdnn::create_context();
         // 10GB for Longformer
         // size_t longformerNum = 3lu * (1 << 30);
-        workspaceSize = 3ll << 30; // 3 GB
-        // std::cout<(workspaceSize)<< std::endl;
-        workspace = alloc(workspaceSize);
+        size_t workspaceSize = 3llu << 30; // 3 GB
+        KUNLUNPtr wkspacePtr = alloc(workspaceSize);
+        workspace =
+            make_ref<WorkspaceObj<KUNLUNPtr>>(wkspacePtr, workspaceSize);
     }
     virtual ~KUNLUNRuntimeObj() {
-        dealloc(workspace);
-        baidu::xpu::api::destroy_context(xdnn);
+        KUNLUNPtr wkspacePtr = workspace->getWorkspace();
+        dealloc(wkspacePtr);
+        xdnn::destroy_context(ctx);
     }
     string toString() const override;
@@ -31,6 +38,7 @@ class KUNLUNRuntimeObj : public RuntimeObj {
     // double runEvaluation(const Graph &graph, int nWarmups,
     //                      int nEvaluations) const;
     void sync() const;
+
     KUNLUNPtr alloc(size_t size) override {
         void *ptr;
         checkKUNLUNError(
@@ -38,33 +46,33 @@ class KUNLUNRuntimeObj : public RuntimeObj {
         return ptr;
     }
     void dealloc(void *ptr) override { xpu_free(ptr); }
-    baidu::xpu::api::Context *KUNLUNHandle() const { return xdnn; }
+
+    xdnn::Context *KUNLUNHandle() const { return ctx; }
+
+    // Get a workspace slice of $size bytes
     KUNLUNPtr getWorkspace(size_t size) const {
-        IT_ASSERT(size <= workspaceSize);
-        return workspace;
+        auto ret = workspace->getWorkspace(size);
+        return ret;
     }
+    Workspace<KUNLUNPtr> getWorkspaceObj() const { return workspace; }
+
     void copyBlobFromCPU(void *dst, const void *src,
                          size_t bytes) const override {
         xpu_memcpy(dst, const_cast<void *>(src), bytes,
                    XPUMemcpyKind::XPU_HOST_TO_DEVICE);
     }
-
     void copyBlobToCPU(void *dst, const void *src,
                        size_t bytes) const override {
         xpu_memcpy(dst, const_cast<void *>(src), bytes,
                    XPUMemcpyKind::XPU_DEVICE_TO_HOST);
     }
-
     void copyBlobInsideRuntime(void *dst, const void *src,
                                size_t bytes) const override {
         xpu_memcpy(dst, const_cast<void *>(src), bytes,
                    XPUMemcpyKind::XPU_DEVICE_TO_DEVICE);
     }
+
+    void initComm(const string &name, int worldSize, int rank) final;
 
-    void initComm(const
string &, int, int) override { IT_TODO_HALT(); } - - CommunicatorObj &getCommunicator() const override { IT_TODO_HALT(); } + CommunicatorObj &getCommunicator() const final { return *comm; } private: void runWithoutSync(const Graph &graph, bool tune, bool profiling) const; diff --git a/include/kunlun/xccl_communicator.h b/include/kunlun/xccl_communicator.h new file mode 100644 index 00000000..6e9c31d0 --- /dev/null +++ b/include/kunlun/xccl_communicator.h @@ -0,0 +1,60 @@ +#pragma once +#include "core/communicator.h" +#include "xpu/bkcl.h" +#include +#include +#include +#include + +#define checkXcclError(call) \ + { \ + auto err = call; \ + if (BKCL_SUCCESS != err) { \ + fprintf(stderr, "XCCL error in %s:%i.\n", __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ + } \ + } + +namespace infini { + +class XcclCommunicatorObj final : public CommunicatorObj { + private: + BKCLContext_t comm; + + public: + XcclCommunicatorObj(const string &name, int worldSize, int rank) + : CommunicatorObj(worldSize, rank) { + const std::string filePath("./" + name + "_xccl_id.bin"); + BKCLUniqueId commId; + if (rank == 0) { + checkXcclError(bkcl_get_unique_id(&commId)); + std::ofstream ofs(filePath, std::ios::binary); + ofs.write((char *)&commId, sizeof(BKCLUniqueId)); + } else { + auto begin = std::chrono::steady_clock::now(); + while (!std::filesystem::exists(filePath)) { + auto now = std::chrono::steady_clock::now(); + _IT_ASSERT_2(now < begin + std::chrono::seconds(10), + "time limit (10s) exceeded."); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + std::ifstream ifs(filePath, std::ios::binary); + ifs.read((char *)&commId, sizeof(BKCLUniqueId)); + } + checkXcclError(bkcl_init_rank(&comm, rank, worldSize, &commId)); + if (rank == 0) { + std::filesystem::remove(filePath); + } + } + + BKCLContext_t getXcclComm() { return comm; } + + ~XcclCommunicatorObj() final { checkXcclError(bkcl_destroy_context(comm)); } + virtual string toString() const final { + std::ostringstream oss; + oss << "XCCL communicator"; + return oss.str(); + } +}; + +} // namespace infini diff --git a/include/operators/unary.h b/include/operators/unary.h index 8da375de..bb370259 100644 --- a/include/operators/unary.h +++ b/include/operators/unary.h @@ -159,6 +159,7 @@ enum class CastType { Uint322Int64, Float162Float, BFloat162Float, + Float2Float, }; class CastObj : public OperatorObj { diff --git a/include/utils/broadcast_shape.h b/include/utils/broadcast_shape.h deleted file mode 100644 index e794ff90..00000000 --- a/include/utils/broadcast_shape.h +++ /dev/null @@ -1,14 +0,0 @@ -#pragma once - -namespace infini { -void broadcastShape(const Shape &originShape, SmallArray &modifyShape, - int nDims, int size) { - for (int i = nDims - size - 1; i >= 0; --i) { - modifyShape.data[i] = 1; - } - for (int i = nDims - 1; i >= nDims - size; --i) { - modifyShape.data[i] = originShape[i - nDims + size]; - } -} - -} // namespace infini diff --git a/include/utils/operator_utils.h b/include/utils/operator_utils.h index b0871c0b..c29b8ca5 100644 --- a/include/utils/operator_utils.h +++ b/include/utils/operator_utils.h @@ -5,6 +5,9 @@ #include "core/operator.h" #include "core/tensor.h" +#include "utils/small_array.h" +#include + namespace infini { // Launch a broadcast shape based on the shape of input A and B @@ -20,6 +23,12 @@ size_t delocate_index(const Shape &shapeIndex, const Shape &shape, const Shape &stride); // Convert KernelAttrs to a string representation std::string get_kernel_attrs_str(const KernelAttrs &kernelAttrs); +// 
VectorProd +int shapeProd(std::vector::iterator start, std::vector::iterator end); +void broadcastShape(const Shape &originShape, SmallArray &modifyShape, + int nDims, int size); +void broadcastShape(const Shape &tempShape, Shape &modifyShape); + } // namespace infini #endif diff --git a/include/utils/small_array.h b/include/utils/small_array.h index 3ea93279..5a757b2d 100644 --- a/include/utils/small_array.h +++ b/include/utils/small_array.h @@ -4,6 +4,14 @@ namespace infini { #define SMALL_ARRAY_SIZE 8 struct SmallArray { int data[SMALL_ARRAY_SIZE]; + + int prod(int start, int end) { + int result = 1; + for (int i = start; i < end; ++i) { + result *= data[i]; + } + return result; + } }; } // namespace infini diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 1a2e28a7..a21e0b0a 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -1,4 +1,4 @@ -import backend +import backend from onnx import ( ModelProto, TensorProto, @@ -208,8 +208,8 @@ class OnnxStub: ) elif node.op_type == "MatMul": tensors[node.output[0]] = self.handler.matmul( - tensors[node.input[0]], - tensors[node.input[1]], + tensors[node.input[0]], # input + tensors[node.input[1]], # weight tensors.get(node.output[0]), False, False, diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index 0821121d..c8458454 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -695,6 +695,8 @@ static CastType inferCastType(Tensor input, int to) { return CastType::Float162Float; } else if (iType == DataType::BFloat16 && oType == DataType::Float32) { return CastType::BFloat162Float; + } else if (iType == DataType::Float32 && oType == DataType::Float32) { + return CastType::Float2Float; } else { IT_TODO_HALT_MSG("Unsupported CastType : input_type is " + iType.toString() + " output_type is " + diff --git a/src/core/tensor.cc b/src/core/tensor.cc index 5be8a18d..1287f7ba 100644 --- a/src/core/tensor.cc +++ b/src/core/tensor.cc @@ -66,6 +66,36 @@ void TensorObj::setShape(Shape shape_) { _size = size; } +void TensorObj::dumpData(std::ofstream &ofs) const { + IT_ASSERT(data != nullptr); + if (!runtime->isCpu()) + IT_TODO_HALT(); + +#define TRY_DUMP(N) \ + if (dtype == DataType(N)) \ + ofs << dataToString::t>() << std::endl; + + TRY_DUMP(0) // fmt: new line + else TRY_DUMP(1) // + else TRY_DUMP(2) // + else TRY_DUMP(3) // + else TRY_DUMP(4) // + else TRY_DUMP(5) // + else TRY_DUMP(6) // + else TRY_DUMP(7) // + else TRY_DUMP(8) // + else TRY_DUMP(9) // + else TRY_DUMP(10) // + else TRY_DUMP(11) // + else TRY_DUMP(12) // + else TRY_DUMP(13) // + else TRY_DUMP(16) // + else IT_TODO_HALT(); + ofs.flush(); + +#undef TRY_DUMP +} + void TensorObj::printData() const { IT_ASSERT(data != nullptr); if (!runtime->isCpu()) diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 9dc43510..361a3d50 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -429,7 +429,9 @@ void init_graph_builder(py::module &m) { #endif #ifdef USE_KUNLUN py::class_, RuntimeObj>( - m, "KUNLUNRuntime"); + m, "KUNLUNRuntime") + .def(py::init(), py::arg("device") = 0) + .def("init_comm", &KUNLUNRuntimeObj::initComm); #endif py::class_>(m, "Tensor", py::buffer_protocol()) diff --git a/src/kernels/cpu/unary.cc b/src/kernels/cpu/unary.cc index 9e7cead0..77cbe6b3 100644 --- a/src/kernels/cpu/unary.cc +++ b/src/kernels/cpu/unary.cc @@ -145,6 +145,9 @@ class NativeUnary : public CpuKernelWithoutConfig { case 
OpType::Atanh:
             _doCompute = aTanhCompute;
             break;
+        case OpType::Acosh:
+            _doCompute = aCoshCompute;
+            break;
         default:
             IT_TODO_HALT();
         }

diff --git a/src/kernels/cuda/where.cc b/src/kernels/cuda/where.cc
index da6ac784..b18dec8a 100644
--- a/src/kernels/cuda/where.cc
+++ b/src/kernels/cuda/where.cc
@@ -2,7 +2,7 @@
 #include "cuda/cuda_kernel_wihtout_config.h"
 #include "cuda/cuda_runtime.h"
 #include "cuda/cuda_where.h"
-#include "utils/broadcast_shape.h"
+#include "utils/operator_utils.h"
 
 namespace infini {

diff --git a/src/kernels/kunlun/all_gather.cc b/src/kernels/kunlun/all_gather.cc
new file mode 100644
index 00000000..1b863feb
--- /dev/null
+++ b/src/kernels/kunlun/all_gather.cc
@@ -0,0 +1,43 @@
+#ifdef INFINI_USE_XCCL
+#include "operators/all_gather.h"
+#include "kunlun/kunlun_kernel_without_config.h"
+#include "kunlun/kunlun_runtime.h"
+#include "kunlun/xccl_communicator.h"
+
+namespace infini {
+class AllGatherXCCL : public KUNLUNKernelWithoutConfig {
+  public:
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<AllGatherObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
+        auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
+        int world_size = op->getWorldSize();
+        IT_ASSERT(world_size == context->getCommunicator().getWorldSize());
+        void *input = op->getInputs(0)->getRawDataPtr<void *>();
+        KUNLUNPtr output_temp =
+            context->getWorkspace(op->getInputs(0)->getBytes() * world_size);
+        IT_ASSERT(op->getDType() == DataType::Float32);
+        size_t bytes = op->getInputs(0)->getBytes();
+        size_t count = bytes / op->getDType().getSize();
+
+        BKCLContext_t comm =
+            dynamic_cast<XcclCommunicatorObj &>(context->getCommunicator())
+                .getXcclComm();
+        // TODO: Using the default stream 0
+        checkXcclError(
+            bkcl_all_gather(comm, input, count, output_temp, BKCL_FLOAT, 0));
+
+        // Scatter the gathered buffer: rank i's slice goes to output i
+        for (int i = 0; i < world_size; ++i) {
+            Tensor output = op->getOutput(i);
+            context->copyBlobInsideRuntime(
+                output->getRawDataPtr<float *>(),
+                static_cast<float *>(output_temp) + i * count, bytes);
+        }
+    }
+};
+
+REGISTER_KERNEL(Device::KUNLUN, OpType::AllGather, AllGatherXCCL,
+                "AllGather_XCCL_KUNLUN");
+} // namespace infini
+#endif

diff --git a/src/kernels/kunlun/all_reduce.cc b/src/kernels/kunlun/all_reduce.cc
new file mode 100644
index 00000000..ab01d60d
--- /dev/null
+++ b/src/kernels/kunlun/all_reduce.cc
@@ -0,0 +1,49 @@
+#ifdef INFINI_USE_XCCL
+#include "operators/all_reduce.h"
+#include "kunlun/kunlun_kernel_without_config.h"
+#include "kunlun/kunlun_runtime.h"
+#include "kunlun/xccl_communicator.h"
+
+namespace infini {
+class AllReduceXCCL : public KUNLUNKernelWithoutConfig {
+  public:
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<AllReduceBaseObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
+        auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
+        void *input = op->getInputs(0)->getRawDataPtr<void *>();
+        void *output = op->getOutput(0)->getRawDataPtr<void *>();
+        IT_ASSERT(op->getDType() == DataType::Float32);
+        size_t count = op->getInputs(0)->size();
+
+        BKCLContext_t comm =
+            dynamic_cast<XcclCommunicatorObj &>(context->getCommunicator())
+                .getXcclComm();
+        checkXcclError(bkcl_all_reduce(comm, input, output, count,
+                                       BKCLDataType::BKCL_FLOAT, getRedOp(),
+                                       0));
+    }
+    virtual BKCLOp getRedOp() const = 0;
+};
+
+class AllReduceSumXCCL : public AllReduceXCCL {
+    BKCLOp getRedOp() const override { return BKCLOp::BKCL_ADD; }
+};
+
+class AllReduceMinXCCL : public AllReduceXCCL {
+    BKCLOp getRedOp() const override { return BKCLOp::BKCL_MIN; }
+};
+
+class AllReduceMaxXCCL : public AllReduceXCCL {
+    BKCLOp getRedOp() const override { return
BKCLOp::BKCL_MAX; } +}; + +REGISTER_KERNEL(Device::KUNLUN, OpType::AllReduceSum, AllReduceSumXCCL, + "AllReduce_Sum_XCCL_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::AllReduceMax, AllReduceMaxXCCL, + "AllReduce_Max_XCCL_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::AllReduceMin, AllReduceMinXCCL, + "AllReduce_Min_XCCL_KUNLUN"); +} // namespace infini +#endif diff --git a/src/kernels/kunlun/batch_norm.cc b/src/kernels/kunlun/batch_norm.cc index d0e1c9b2..30c8eee4 100644 --- a/src/kernels/kunlun/batch_norm.cc +++ b/src/kernels/kunlun/batch_norm.cc @@ -26,7 +26,7 @@ class BatchNormXdnn : public KUNLUNKernelWithoutConfig { int h = dims[2]; int c = dims[1]; int n = dims[0]; - auto ret = baidu::xpu::api::batch_norm_infer( + auto ret = xdnn::batch_norm_infer( context->KUNLUNHandle(), (float *)input, (float *)output, n, c, h, w, op->getEps(), (float *)scale, (float *)bias, (float *)mean, (float *)var, true); diff --git a/src/kernels/kunlun/broadcast.cc b/src/kernels/kunlun/broadcast.cc new file mode 100644 index 00000000..9750aec6 --- /dev/null +++ b/src/kernels/kunlun/broadcast.cc @@ -0,0 +1,32 @@ +#ifdef INFINI_USE_XCCL +#include "operators/broadcast.h" +#include "kunlun/kunlun_kernel_without_config.h" +#include "kunlun/kunlun_runtime.h" +#include "kunlun/xccl_communicator.h" + +namespace infini { +class BroadcastXCCL : public KUNLUNKernelWithoutConfig { + public: + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); + auto context = dynamic_cast(_context); + void *input = op->getInputs(0)->getRawDataPtr(); + void *output = op->getOutput()->getRawDataPtr(); + IT_ASSERT(op->getDType() == DataType::Float32); + size_t count = op->getInputs(0)->getBytes() / op->getDType().getSize(); + + BKCLContext_t comm = + dynamic_cast(context->getCommunicator()) + .getXcclComm(); + // TODO: Using default stream 0 for now. 
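+        // Note on semantics (an assumption based on standard BKCL collective
+        // behavior, mirroring NCCL): every rank must call with the same count
+        // and root; the root's input buffer is replicated into each rank's
+        // output buffer on the default stream (0) referenced above.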
+ checkXcclError(bkcl_broadcast(comm, input, output, count, BKCL_FLOAT, + op->getRoot(), 0)); + } +}; + +REGISTER_KERNEL(Device::KUNLUN, OpType::Broadcast, BroadcastXCCL, + "Broadcast_XCCL_KUNLUN"); +} // namespace infini +#endif diff --git a/src/kernels/kunlun/cast.cc b/src/kernels/kunlun/cast.cc index 0bd7e4e8..d9cc890f 100644 --- a/src/kernels/kunlun/cast.cc +++ b/src/kernels/kunlun/cast.cc @@ -17,74 +17,78 @@ class CastXdnn : public KUNLUNKernelWithoutConfig { int ret = 0; switch (type) { case CastType::Float2Float16: - ret = baidu::xpu::api::cast( + ret = xdnn::cast( context->KUNLUNHandle(), (float *)aData, (float16 *)cData, len); break; case CastType::Float2Int64: - ret = baidu::xpu::api::cast( + ret = xdnn::cast( context->KUNLUNHandle(), (float *)aData, (int64_t *)cData, len); break; case CastType::Float2Int32: - ret = baidu::xpu::api::cast( - context->KUNLUNHandle(), (float *)aData, (int *)cData, len); + ret = xdnn::cast(context->KUNLUNHandle(), + (float *)aData, (int *)cData, len); break; case CastType::Float2Int16: - ret = baidu::xpu::api::cast( + ret = xdnn::cast( context->KUNLUNHandle(), (float *)aData, (int16_t *)cData, len); break; case CastType::Float2Int8: - ret = baidu::xpu::api::cast( + ret = xdnn::cast( context->KUNLUNHandle(), (float *)aData, (int8_t *)cData, len); break; case CastType::Int322Float: - ret = baidu::xpu::api::cast( - context->KUNLUNHandle(), (int *)aData, (float *)cData, len); + ret = xdnn::cast(context->KUNLUNHandle(), (int *)aData, + (float *)cData, len); break; case CastType::Int322Int8: - ret = baidu::xpu::api::cast( - context->KUNLUNHandle(), (int *)aData, (int8_t *)cData, len); + ret = xdnn::cast(context->KUNLUNHandle(), (int *)aData, + (int8_t *)cData, len); break; case CastType::Int322Int16: - ret = baidu::xpu::api::cast( - context->KUNLUNHandle(), (int *)aData, (int16_t *)cData, len); + ret = xdnn::cast(context->KUNLUNHandle(), + (int *)aData, (int16_t *)cData, len); break; case CastType::Int162Float: - ret = baidu::xpu::api::cast( + ret = xdnn::cast( context->KUNLUNHandle(), (int16_t *)aData, (float *)cData, len); break; case CastType::Int162Int32: - ret = baidu::xpu::api::cast( - context->KUNLUNHandle(), (int16_t *)aData, (int *)cData, len); + ret = xdnn::cast(context->KUNLUNHandle(), + (int16_t *)aData, (int *)cData, len); break; case CastType::Int82Float: - ret = baidu::xpu::api::cast( + ret = xdnn::cast( context->KUNLUNHandle(), (int8_t *)aData, (float *)cData, len); break; case CastType::Int82Int16: - ret = baidu::xpu::api::cast( - context->KUNLUNHandle(), (int8_t *)aData, (int16_t *)cData, - len); + ret = xdnn::cast(context->KUNLUNHandle(), + (int8_t *)aData, (int16_t *)cData, + len); break; case CastType::Int82Int32: - ret = baidu::xpu::api::cast( - context->KUNLUNHandle(), (int8_t *)aData, (int *)cData, len); + ret = xdnn::cast(context->KUNLUNHandle(), + (int8_t *)aData, (int *)cData, len); break; case CastType::Int322Int64: - ret = baidu::xpu::api::cast( - context->KUNLUNHandle(), (int *)aData, (int64_t *)cData, len); + ret = xdnn::cast(context->KUNLUNHandle(), + (int *)aData, (int64_t *)cData, len); break; case CastType::Int642Int32: - ret = baidu::xpu::api::cast( - context->KUNLUNHandle(), (int64_t *)aData, (int *)cData, len); + ret = xdnn::cast(context->KUNLUNHandle(), + (int64_t *)aData, (int *)cData, len); break; case CastType::Int642Float: - ret = baidu::xpu::api::cast( + ret = xdnn::cast( context->KUNLUNHandle(), (int64_t *)aData, (float *)cData, len); break; case CastType::Float162Float: - ret = baidu::xpu::api::cast( + ret = 
xdnn::cast( context->KUNLUNHandle(), (float16 *)aData, (float *)cData, len); break; + case CastType::Float2Float: + ret = xdnn::copy(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); + break; default: IT_TODO_HALT(); } diff --git a/src/kernels/kunlun/concat.cc b/src/kernels/kunlun/concat.cc index f7ba2a2d..b2fd9fa8 100644 --- a/src/kernels/kunlun/concat.cc +++ b/src/kernels/kunlun/concat.cc @@ -26,8 +26,8 @@ class ConcatXdnn : public KUNLUNKernelWithoutConfig { } dims.push_back(dim); } - auto ret = baidu::xpu::api::concat( - context->KUNLUNHandle(), inputsData, (float *)cData, dims, axis); + auto ret = xdnn::concat(context->KUNLUNHandle(), inputsData, + (float *)cData, dims, axis); assert(ret == 0); return; } diff --git a/src/kernels/kunlun/conv.cc b/src/kernels/kunlun/conv.cc index 45f054b1..271ab133 100644 --- a/src/kernels/kunlun/conv.cc +++ b/src/kernels/kunlun/conv.cc @@ -24,11 +24,17 @@ class ConvXdnn : public KUNLUNKernelWithoutConfig { std::vector stride = {sh, sw}; std::vector dilation = {dh, dw}; - auto ret = baidu::xpu::api::conv2d( + // TODO: Convolution operators still have some accuracy problems + checkKUNLUNError((xdnn::conv2d( context->KUNLUNHandle(), (float *)aData, (float *)bData, (float *)cData, n, c, h, w, f, ksize, stride, pads, dilation, g, - nullptr, nullptr, nullptr, true); - assert(ret == 0); + nullptr, nullptr, nullptr, true))); + + // checkKUNLUNError((xdnn::conv2d_fusion( + // context->KUNLUNHandle(), (float *const)aData, (float + // *const)bData, (float *)cData, n, c, h, w, f, ksize, stride, pads, + // dilation, g, nullptr, nullptr, nullptr, true, nullptr, nullptr, + // xdnn::Activation_t::LINEAR))); return; } }; diff --git a/src/kernels/kunlun/conv_trans.cc b/src/kernels/kunlun/conv_trans.cc index 8219d829..76677e27 100644 --- a/src/kernels/kunlun/conv_trans.cc +++ b/src/kernels/kunlun/conv_trans.cc @@ -37,11 +37,10 @@ class ConvTransXdnn : public KUNLUNKernelWithoutConfig { if (dimOutput.size() != 4) IT_TODO_HALT(); - auto ret = - baidu::xpu::api::conv2d_transpose( - context->KUNLUNHandle(), (float *)aData, (float *)bData, - (float *)cData, n, c, h, w, f, ksize, stride, pads, dilation, g, - nullptr, nullptr, nullptr, isNCHW); + auto ret = xdnn::conv2d_transpose( + context->KUNLUNHandle(), (float *)aData, (float *)bData, + (float *)cData, n, c, h, w, f, ksize, stride, pads, dilation, g, + nullptr, nullptr, nullptr, isNCHW); assert(ret == 0); return; } diff --git a/src/kernels/kunlun/element_wise.cc b/src/kernels/kunlun/element_wise.cc index 5a9754f5..665ea56a 100644 --- a/src/kernels/kunlun/element_wise.cc +++ b/src/kernels/kunlun/element_wise.cc @@ -1,6 +1,7 @@ #include "operators/element_wise.h" #include "kunlun/kunlun_kernel_without_config.h" #include "kunlun/kunlun_runtime.h" +#include "utils/operator_utils.h" namespace infini { class AddXdnn : public KUNLUNKernelWithoutConfig { @@ -22,10 +23,9 @@ class AddXdnn : public KUNLUNKernelWithoutConfig { if (bDim.size() == 0) { bDim.push_back(1); } - auto ret = baidu::xpu::api::broadcast_add( + checkKUNLUNError(xdnn::broadcast_add( context->KUNLUNHandle(), (float *)aData, (float *)bData, - (float *)cData, aDim, bDim); - assert(ret == 0); + (float *)cData, aDim, bDim)); return; } }; @@ -49,10 +49,9 @@ class SubXdnn : public KUNLUNKernelWithoutConfig { if (bDim.size() == 0) { bDim.push_back(1); } - auto ret = baidu::xpu::api::broadcast_sub( + checkKUNLUNError(xdnn::broadcast_sub( context->KUNLUNHandle(), (float *)aData, (float *)bData, - (float *)cData, aDim, bDim); - assert(ret == 0); + (float 
*)cData, aDim, bDim)); return; } }; @@ -76,10 +75,9 @@ class MulXdnn : public KUNLUNKernelWithoutConfig { if (bDim.size() == 0) { bDim.push_back(1); } - auto ret = baidu::xpu::api::broadcast_mul( + checkKUNLUNError(xdnn::broadcast_mul( context->KUNLUNHandle(), (float *)aData, (float *)bData, - (float *)cData, aDim, bDim); - assert(ret == 0); + (float *)cData, aDim, bDim)); return; } }; @@ -95,18 +93,40 @@ class DivXdnn : public KUNLUNKernelWithoutConfig { void *const bData = (op->getInputs(1)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); + auto aSize = op->getInputs(0)->size(); auto aDim = op->getInputs(0)->getDims(); + auto bSize = op->getInputs(1)->size(); auto bDim = op->getInputs(1)->getDims(); - if (aDim.size() == 0) { - aDim.push_back(1); - } + auto dtype = op->getDType(); + if (bDim.size() == 0) { bDim.push_back(1); } - auto ret = baidu::xpu::api::broadcast_div( - context->KUNLUNHandle(), (float *)aData, (float *)bData, - (float *)cData, aDim, bDim); - assert(ret == 0); + + if (aSize == bSize) { + // Do ElementWise Sub with no broadcast + checkKUNLUNError(xdnn::div(context->KUNLUNHandle(), + (float *)aData, (float *)bData, + (float *)cData, aSize)); + } else { + // Do broadcast div + Shape aligned = infer_broadcast(aDim, bDim); + if (aligned == aDim) { + // BData need to be broadcasted + checkKUNLUNError(xdnn::broadcast_div( + context->KUNLUNHandle(), (float *)aData, (float *)bData, + (float *)cData, aDim, bDim)); + } else { + // Use workspace to broadcast aData + KUNLUNPtr wks = context->getWorkspace(bSize * dtype.getSize()); + checkKUNLUNError(xdnn::broadcast( + context->KUNLUNHandle(), (float *)aData, (float *)wks, aDim, + bDim)); + checkKUNLUNError(xdnn::div(context->KUNLUNHandle(), + (float *)wks, (float *)bData, + (float *)cData, bSize)); + } + } return; } }; @@ -131,10 +151,9 @@ class PowXdnn : public KUNLUNKernelWithoutConfig { bDim.push_back(1); } - auto ret = baidu::xpu::api::broadcast_pow( + checkKUNLUNError(xdnn::broadcast_pow( context->KUNLUNHandle(), (float *)aData, (float *)bData, - (float *)cData, aDim, bDim); - assert(ret == 0); + (float *)cData, aDim, bDim)); return; } }; @@ -158,10 +177,9 @@ class MaxXdnn : public KUNLUNKernelWithoutConfig { if (bDim.size() == 0) { bDim.push_back(1); } - auto ret = baidu::xpu::api::broadcast_max( + checkKUNLUNError(xdnn::broadcast_max( context->KUNLUNHandle(), (float *)aData, (float *)bData, - (float *)cData, aDim, bDim); - assert(ret == 0); + (float *)cData, aDim, bDim)); return; } }; @@ -185,10 +203,9 @@ class MinXdnn : public KUNLUNKernelWithoutConfig { if (bDim.size() == 0) { bDim.push_back(1); } - auto ret = baidu::xpu::api::broadcast_min( + checkKUNLUNError(xdnn::broadcast_min( context->KUNLUNHandle(), (float *)aData, (float *)bData, - (float *)cData, aDim, bDim); - assert(ret == 0); + (float *)cData, aDim, bDim)); return; } }; @@ -204,7 +221,9 @@ class EqualXdnn : public KUNLUNKernelWithoutConfig { void *const bData = (op->getInputs(1)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); size_t len = op->getOutput()->size(); - KUNLUNPtr wsData = context->getWorkspace(len); + auto dtype = op->getDType(); + + KUNLUNPtr wsData = context->getWorkspace(len * dtype.getSize()); auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); @@ -214,12 +233,11 @@ class EqualXdnn : public KUNLUNKernelWithoutConfig { if (bDim.size() == 0) { bDim.push_back(1); } - auto ret = baidu::xpu::api::broadcast_equal( + checkKUNLUNError(xdnn::broadcast_equal( 
context->KUNLUNHandle(), (float *)aData, (float *)bData, - (bool *)wsData, aDim, bDim); - ret = baidu::xpu::api::cast( - context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len); - assert(ret == 0); + (bool *)wsData, aDim, bDim)); + checkKUNLUNError((xdnn::cast( + context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len))); return; } }; @@ -235,7 +253,8 @@ class GreaterEqualXdnn : public KUNLUNKernelWithoutConfig { void *const bData = (op->getInputs(1)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); size_t len = op->getOutput()->size(); - KUNLUNPtr wsData = context->getWorkspace(len); + auto dtype = op->getDType(); + KUNLUNPtr wsData = context->getWorkspace(len * dtype.getSize()); auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); @@ -245,12 +264,11 @@ class GreaterEqualXdnn : public KUNLUNKernelWithoutConfig { if (bDim.size() == 0) { bDim.push_back(1); } - auto ret = baidu::xpu::api::broadcast_greater_equal( + checkKUNLUNError(xdnn::broadcast_greater_equal( context->KUNLUNHandle(), (float *)aData, (float *)bData, - (bool *)wsData, aDim, bDim); - ret = baidu::xpu::api::cast( - context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len); - assert(ret == 0); + (bool *)wsData, aDim, bDim)); + checkKUNLUNError((xdnn::cast( + context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len))); return; } }; @@ -266,7 +284,8 @@ class GreaterThanXdnn : public KUNLUNKernelWithoutConfig { void *const bData = (op->getInputs(1)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); size_t len = op->getOutput()->size(); - KUNLUNPtr wsData = context->getWorkspace(len); + KUNLUNPtr wsData = + context->getWorkspace(len * (op->getDType()).getSize()); auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); @@ -276,12 +295,11 @@ class GreaterThanXdnn : public KUNLUNKernelWithoutConfig { if (bDim.size() == 0) { bDim.push_back(1); } - auto ret = baidu::xpu::api::broadcast_greater_than( + checkKUNLUNError(xdnn::broadcast_greater_than( context->KUNLUNHandle(), (float *)aData, (float *)bData, - (bool *)wsData, aDim, bDim); - ret = baidu::xpu::api::cast( - context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len); - assert(ret == 0); + (bool *)wsData, aDim, bDim)); + checkKUNLUNError((xdnn::cast( + context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len))); return; } }; @@ -297,7 +315,8 @@ class LessEqualXdnn : public KUNLUNKernelWithoutConfig { void *const bData = (op->getInputs(1)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); size_t len = op->getOutput()->size(); - KUNLUNPtr wsData = context->getWorkspace(len); + auto dtype = op->getDType(); + KUNLUNPtr wsData = context->getWorkspace(len * dtype.getSize()); auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); @@ -307,12 +326,11 @@ class LessEqualXdnn : public KUNLUNKernelWithoutConfig { if (bDim.size() == 0) { bDim.push_back(1); } - auto ret = baidu::xpu::api::broadcast_less_equal( + checkKUNLUNError(xdnn::broadcast_less_equal( context->KUNLUNHandle(), (float *)aData, (float *)bData, - (bool *)wsData, aDim, bDim); - ret = baidu::xpu::api::cast( - context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len); - assert(ret == 0); + (bool *)wsData, aDim, bDim)); + checkKUNLUNError((xdnn::cast( + context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len))); return; } }; @@ -328,7 +346,8 @@ class LessThanXdnn : public KUNLUNKernelWithoutConfig { void *const bData = 
(op->getInputs(1)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); size_t len = op->getOutput()->size(); - KUNLUNPtr wsData = context->getWorkspace(len); + auto dtype = op->getDType(); + KUNLUNPtr wsData = context->getWorkspace(len * dtype.getSize()); auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); @@ -338,12 +357,11 @@ class LessThanXdnn : public KUNLUNKernelWithoutConfig { if (bDim.size() == 0) { bDim.push_back(1); } - auto ret = baidu::xpu::api::broadcast_less_than( + checkKUNLUNError(xdnn::broadcast_less_than( context->KUNLUNHandle(), (float *)aData, (float *)bData, - (bool *)wsData, aDim, bDim); - ret = baidu::xpu::api::cast( - context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len); - assert(ret == 0); + (bool *)wsData, aDim, bDim)); + checkKUNLUNError((xdnn::cast( + context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len))); return; } }; @@ -367,10 +385,9 @@ class FloorDivXdnn : public KUNLUNKernelWithoutConfig { if (bDim.size() == 0) { bDim.push_back(1); } - auto ret = baidu::xpu::api::broadcast_floordiv( + checkKUNLUNError(xdnn::broadcast_floordiv( context->KUNLUNHandle(), (float *)aData, (float *)bData, - (float *)cData, aDim, bDim); - assert(ret == 0); + (float *)cData, aDim, bDim)); return; } }; @@ -388,10 +405,9 @@ class MSELossXdnn : public KUNLUNKernelWithoutConfig { size_t len = op->getOutput()->size(); auto dim = op->getInputs(0)->getDims(); - auto ret = baidu::xpu::api::mse_loss( - context->KUNLUNHandle(), (float *)aData, (float *)bData, - (float *)cData, len); - assert(ret == 0); + checkKUNLUNError(xdnn::mse_loss(context->KUNLUNHandle(), + (float *)aData, (float *)bData, + (float *)cData, len)); return; } }; @@ -407,7 +423,8 @@ class AndXdnn : public KUNLUNKernelWithoutConfig { void *const bData = (op->getInputs(1)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); size_t len = op->getOutput()->size(); - KUNLUNPtr wsData = context->getWorkspace(len); + auto dtype = op->getDType(); + KUNLUNPtr wsData = context->getWorkspace(len * dtype.getSize()); auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); @@ -417,12 +434,11 @@ class AndXdnn : public KUNLUNKernelWithoutConfig { if (bDim.size() == 0) { bDim.push_back(1); } - auto ret = baidu::xpu::api::logical_and( - context->KUNLUNHandle(), (bool *)aData, (bool *)bData, - (bool *)wsData, len); - ret = baidu::xpu::api::cast( - context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len); - assert(ret == 0); + checkKUNLUNError(xdnn::logical_and(context->KUNLUNHandle(), + (bool *)aData, (bool *)bData, + (bool *)wsData, len)); + checkKUNLUNError((xdnn::cast( + context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len))); return; } }; @@ -438,7 +454,8 @@ class OrXdnn : public KUNLUNKernelWithoutConfig { void *const bData = (op->getInputs(1)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); size_t len = op->getOutput()->size(); - KUNLUNPtr wsData = context->getWorkspace(len); + auto dtype = op->getDType(); + KUNLUNPtr wsData = context->getWorkspace(len * dtype.getSize()); auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); @@ -448,12 +465,11 @@ class OrXdnn : public KUNLUNKernelWithoutConfig { if (bDim.size() == 0) { bDim.push_back(1); } - auto ret = baidu::xpu::api::logical_or( - context->KUNLUNHandle(), (bool *)aData, (bool *)bData, - (bool *)wsData, len); - ret = baidu::xpu::api::cast( - context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len); - 
assert(ret == 0); + checkKUNLUNError(xdnn::logical_or(context->KUNLUNHandle(), + (bool *)aData, (bool *)bData, + (bool *)wsData, len)); + checkKUNLUNError((xdnn::cast( + context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len))); return; } }; @@ -469,7 +485,8 @@ class XorXdnn : public KUNLUNKernelWithoutConfig { void *const bData = (op->getInputs(1)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); size_t len = op->getOutput()->size(); - KUNLUNPtr wsData = context->getWorkspace(len); + auto dtype = op->getDType(); + KUNLUNPtr wsData = context->getWorkspace(len * dtype.getSize()); auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); @@ -479,12 +496,11 @@ class XorXdnn : public KUNLUNKernelWithoutConfig { if (bDim.size() == 0) { bDim.push_back(1); } - auto ret = baidu::xpu::api::logical_xor( - context->KUNLUNHandle(), (bool *)aData, (bool *)bData, - (bool *)wsData, len); - ret = baidu::xpu::api::cast( - context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len); - assert(ret == 0); + checkKUNLUNError(xdnn::logical_xor(context->KUNLUNHandle(), + (bool *)aData, (bool *)bData, + (bool *)wsData, len)); + checkKUNLUNError((xdnn::cast( + context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len))); return; } }; @@ -499,14 +515,14 @@ class NotXdnn : public KUNLUNKernelWithoutConfig { void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); size_t len = op->getOutput()->size(); - KUNLUNPtr wsData = context->getWorkspace(len); + auto dtype = op->getDType(); + KUNLUNPtr wsData = context->getWorkspace(len * dtype.getSize()); auto aDim = op->getInputs(0)->getDims(); - auto ret = baidu::xpu::api::logical_not( - context->KUNLUNHandle(), (bool *)aData, (bool *)wsData, len); - ret = baidu::xpu::api::cast( - context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len); - assert(ret == 0); + checkKUNLUNError(xdnn::logical_not( + context->KUNLUNHandle(), (bool *)aData, (bool *)wsData, len)); + checkKUNLUNError((xdnn::cast( + context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len))); return; } }; diff --git a/src/kernels/kunlun/gather.cc b/src/kernels/kunlun/gather.cc index 75fd2365..ed22c88a 100644 --- a/src/kernels/kunlun/gather.cc +++ b/src/kernels/kunlun/gather.cc @@ -1,4 +1,5 @@ #include "operators/gather.h" +#include "core/common.h" #include "kunlun/kunlun_kernel_without_config.h" #include "kunlun/kunlun_runtime.h" @@ -10,17 +11,18 @@ class GatherXdnn : public KUNLUNKernelWithoutConfig { IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); - void *const aData = (op->getInputs(0)->getRawDataPtr()); - void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const aData = (op->getInputs(0)->getRawDataPtr()); // data + void *const bData = + (op->getInputs(1)->getRawDataPtr()); // indice void *const cData = (op->getOutput()->getRawDataPtr()); - auto shape = op->getInputs(0)->getDims(); - auto index = op->getInputs(1)->getDims(); - auto axis = op->getAxis(); - auto ret = baidu::xpu::api::gather( + Shape aShape = op->getInputs(0)->getDims(); + Tensor bTensor = op->getInputs(1); + int axis = op->getAxis(); + checkKUNLUNError((baidu::xpu::api::gather( context->KUNLUNHandle(), (float *)aData, (int *)bData, - (float *)cData, shape, index.size(), axis); - assert(ret == 0); + (float *)cData, aShape, bTensor->size(), axis))); + return; } }; diff --git a/src/kernels/kunlun/matmul.cc b/src/kernels/kunlun/matmul.cc index f70394a9..bddee46d 100644 --- 
diff --git a/src/kernels/kunlun/matmul.cc b/src/kernels/kunlun/matmul.cc index f70394a9..bddee46d 100644 --- a/src/kernels/kunlun/matmul.cc +++ b/src/kernels/kunlun/matmul.cc @@ -1,30 +1,123 @@ #include "operators/matmul.h" +#include "kunlun/kunlun_act_type.h" +#include "kunlun/kunlun_common.h" #include "kunlun/kunlun_kernel_without_config.h" #include "kunlun/kunlun_runtime.h" +#include "utils/operator_utils.h" namespace infini { class MatmulXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, const RuntimeObj *_context) const override { + // This kernel computes C = act(alpha * x * w + beta * bias) auto op = as(_op); IT_ASSERT(op->getDType() == DataType::Float32); auto context = dynamic_cast(_context); + void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const bData = (op->getInputs(1)->getRawDataPtr()); - void *const cData = (op->getOutput()->getRawDataPtr()); + void *const outData = (op->getOutput()->getRawDataPtr()); + + Shape aDims = op->getInputs(0)->getDims(); + Shape bDims = op->getInputs(1)->getDims(); + Shape cDims = op->getOutput()->getDims(); + + const auto [b, m, n, k] = op->getBMNK(); bool transA = op->getTransA(); bool transB = op->getTransB(); + int rankA = op->getInputs(0)->getRank(); + int rankB = op->getInputs(1)->getRank(); + int rankAligned = std::max(rankA, rankB); + IT_ASSERT(rankAligned <= SMALL_ARRAY_SIZE); - auto b = op->getB(); - auto m = op->getM(); - auto n = op->getN(); - auto k = op->getK(); + float alpha = 1.f, beta = 0.f; + Tensor biasTensor = op->getBias(); + DataType dtype = op->getDType(); - auto ret = baidu::xpu::api::fc_batched( - context->KUNLUNHandle(), b, transA, transB, m, n, k, 1.0, - (float *)aData, m * k, (float *)bData, n * k, 0.0, (float *)cData, - m * n, nullptr, nullptr); - assert(ret == 0); + if (b > 1) { + SmallArray alignedAShape; + SmallArray alignedBShape; + // Pad aShape and bShape with leading 1s in order to align ranks + broadcastShape(aDims, alignedAShape, rankAligned, rankA); + broadcastShape(bDims, alignedBShape, rankAligned, rankB); + // Calculate batch dim + int batchA = alignedAShape.prod(0, rankAligned - 2); + int batchB = alignedBShape.prod(0, rankAligned - 2); + // View aShape and bShape as 3-D + Shape aDimsMatmul = {batchA, aDims[rankA - 2], aDims[rankA - 1]}; + Shape bDimsMatmul = {batchB, bDims[rankB - 2], bDims[rankB - 1]}; + auto numOutput = op->getOutput()->size(); + KUNLUNPtr wkspace = nullptr; + void *AData = nullptr; + void *BData = nullptr; + void *CData = nullptr; + if (batchA != batchB) { + // If batch sizes are unequal, broadcast the size-1 side + IT_ASSERT(batchA == 1 || batchB == 1); + if (batchA == 1) { + // Broadcast aShapeMatmul in batch dimension + Shape aDimsTarget = {b, aDimsMatmul[1], aDimsMatmul[2]}; + auto numInput = + shapeProd(aDimsTarget.begin(), aDimsTarget.end()); + wkspace = context->getWorkspace(numInput * dtype.getSize()); + checkKUNLUNError(xdnn::broadcast( + context->KUNLUNHandle(), (float *)aData, + (float *)wkspace, aDimsMatmul, aDimsTarget)); + AData = wkspace; + BData = bData; + CData = + biasTensor + ? context->getWorkspace(numOutput * dtype.getSize()) + : outData; + } else { + // Broadcast bShapeMatmul in batch dimension + Shape bDimsTarget = {b, bDimsMatmul[1], bDimsMatmul[2]}; + auto numInput = + shapeProd(bDimsTarget.begin(), bDimsTarget.end()); + wkspace = context->getWorkspace(numInput * dtype.getSize()); + checkKUNLUNError(xdnn::broadcast( + context->KUNLUNHandle(), (float *)bData, + (float *)wkspace, bDimsMatmul, bDimsTarget)); + AData = aData; + BData = wkspace; + CData = + biasTensor + ? context->getWorkspace(numOutput * dtype.getSize()) + : outData; + } // endif batchA == 1 + } else { // batchA == batchB, no need to broadcast + AData = aData; + BData = bData; + CData = biasTensor + ? context->getWorkspace(numOutput * dtype.getSize()) + : outData; + } + checkKUNLUNError((xdnn::fc_batched( + context->KUNLUNHandle(), b, transA, transB, m, n, k, alpha, + (float *)AData, m * k, (float *)BData, n * k, beta, + (float *)CData, m * n, nullptr, nullptr))); + // If a bias exists, add it to x*w via broadcast_add + if (biasTensor) { + auto biasShape = biasTensor->getDims(); + broadcastShape(cDims, biasShape); + checkKUNLUNError(baidu::xpu::api::broadcast_add( + context->KUNLUNHandle(), (float *)CData, + biasTensor->getRawDataPtr(), (float *)outData, + cDims, biasShape)); + } + } else { + // Matmul with no batch: call fc_fusion + const int lda = transA ? m : k, ldb = transB ? k : n, ldc = n; + auto kunlunAct = parseActType(std::move(op->getAct())); + checkKUNLUNError( + (baidu::xpu::api::fc_fusion( + context->KUNLUNHandle(), (float *)aData, (float *)bData, + (float *)outData, m, n, k, transA, transB, nullptr, nullptr, + nullptr, lda, ldb, ldc, alpha, 0.f, + biasTensor ? biasTensor->getRawDataPtr() : nullptr, + kunlunAct, nullptr))); + } return; } };
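The batched branch above boils down to simple shape arithmetic; this standalone sketch (hypothetical helper; inputs assumed rank-aligned with rank >= 2, as the kernel guarantees via broadcastShape) mirrors how it decides which operand to broadcast before calling fc_batched:

#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>
using Shape = std::vector<int>;

// Returns 0 if the batch dims already match, 1 if A must be broadcast,
// 2 if B must be broadcast; `batch` receives the common batch size.
int matmulBatchPlan(const Shape &alignedA, const Shape &alignedB, int &batch) {
    auto batchOf = [](const Shape &s) { // product of all dims but the last two
        return std::accumulate(s.begin(), s.end() - 2, 1, std::multiplies<int>());
    };
    int batchA = batchOf(alignedA), batchB = batchOf(alignedB);
    if (batchA == batchB) {
        batch = batchA;
        return 0;
    }
    assert(batchA == 1 || batchB == 1); // same precondition the kernel asserts
    batch = std::max(batchA, batchB);
    return batchA == 1 ? 1 : 2;
}
// e.g. matmulBatchPlan({1, 1, 3, 4}, {2, 5, 4, 6}, batch) returns 1 with
// batch == 10: A is expanded into workspace, then fc_batched runs with b = 10.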
diff --git a/src/kernels/kunlun/pooling.cc b/src/kernels/kunlun/pooling.cc index bc49e31c..0f9580ea 100644 --- a/src/kernels/kunlun/pooling.cc +++ b/src/kernels/kunlun/pooling.cc @@ -14,11 +14,23 @@ class AvgPooling : public KUNLUNKernelWithoutConfig { auto [n, c, h, w, kh, kw] = op->getNCHWRS(); auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); + auto outShape = op->getOutput()->getDims(); std::vector ksize = {kh, kw}; std::vector stride = {sh, sw}; std::vector pad = {ph, pw}; + int yh = outShape[op->getOutput()->getRank() - 2]; + int yw = outShape[op->getOutput()->getRank() - 1]; + + // If AvgPool has ceilMode true, + // we need to change the padding in order to call the xdnn api + if (op->getCeilMode() && yh > (h + 2 * ph - kh) / sh + 1) { + auto padh = yh - ((h + 2 * ph - kh) / sh + 1); + auto padw = yw - ((w + 2 * pw - kw) / sw + 1); + pad = {0, padh, 0, padw}; + } + auto ret = baidu::xpu::api::avg_pool2d( context->KUNLUNHandle(), (float *)aData, (float *)cData, n, c, h, w, ksize, stride, pad, true, true, nullptr, nullptr); @@ -38,21 +50,30 @@ class MaxPooling : public KUNLUNKernelWithoutConfig { auto [n, c, h, w, kh, kw] = op->getNCHWRS(); auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); + auto outShape = op->getOutput()->getDims(); std::vector ksize = {kh, kw}; std::vector stride = {sh, sw}; std::vector pad = {ph, pw}; - int yh = (h + ph * 2 - kh) / sh + 1; - int yw = (w + pw * 2 - kw) / sw + 1; + int yh = outShape[op->getOutput()->getRank() - 2]; + int yw = outShape[op->getOutput()->getRank() - 1]; - KUNLUNPtr indices = context->getWorkspace(yh * yw * 4); + // If MaxPool has ceilMode true, + // we need to change the padding in order to call the xdnn api + if (op->getCeilMode() && yh > (h + 2 * ph - kh) / sh + 1) { + auto padh = yh - ((h + 2 * ph - kh) / sh + 1); + auto padw = yw - ((w + 2 * pw - kw) / sw + 1); + pad = {0, padh, 0, padw}; + } - auto ret = baidu::xpu::api::max_pool2d( + KUNLUNPtr indices = context->getWorkspace(yh * yw * sizeof(int)); + + checkKUNLUNError(baidu::xpu::api::max_pool2d( context->KUNLUNHandle(), (float *)aData, (float *)cData, (int *)indices, n, c, h, w, ksize, stride, pad, true, nullptr, - nullptr, false); - assert(ret == 0); + nullptr, false)); + return; } };
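Both pooling kernels above share the same ceil-mode adjustment; a minimal sketch of the rule (hypothetical helper; the 2- versus 4-element pad convention mirrors what the kernels pass to the xdnn calls):

#include <vector>

// Floor-mode output extent of one spatial dimension.
static int pooledExtent(int x, int k, int p, int s) {
    return (x + 2 * p - k) / s + 1;
}

// When ceil_mode enlarges the output beyond the floor-mode extent, the
// difference is paid for with extra one-sided padding, as in the kernels.
std::vector<int> ceilModePad(int h, int w, int kh, int kw, int ph, int pw,
                             int sh, int sw, int yh, int yw, bool ceilMode) {
    if (ceilMode && yh > pooledExtent(h, kh, ph, sh)) {
        int padh = yh - pooledExtent(h, kh, ph, sh);
        int padw = yw - pooledExtent(w, kw, pw, sw);
        return {0, padh, 0, padw};
    }
    return {ph, pw};
}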
diff --git a/src/kernels/kunlun/reduce_mean.cc b/src/kernels/kunlun/reduce.cc similarity index 53% rename from src/kernels/kunlun/reduce_mean.cc rename to src/kernels/kunlun/reduce.cc index 928d42c8..71f360e1 100644 --- a/src/kernels/kunlun/reduce_mean.cc +++ b/src/kernels/kunlun/reduce.cc @@ -1,8 +1,9 @@ +#include "operators/reduce.h" #include "kunlun/kunlun_kernel_without_config.h" #include "kunlun/kunlun_runtime.h" -#include "operators/reduce.h" namespace infini { + class ReduceMeanXdnn : public KUNLUNKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { @@ -26,6 +27,31 @@ class ReduceMeanXdnn : public KUNLUNKernelWithoutConfig { } }; +class ReduceSumXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto axes_set = op->getAxes(); + std::vector axes; + axes.assign(axes_set.begin(), axes_set.end()); + auto shape = op->getInputs(0)->getDims(); + + auto ret = baidu::xpu::api::reduce_sum( + context->KUNLUNHandle(), (float *)aData, (float *)cData, shape, + axes); + checkKUNLUNError(ret); + return; + } +}; + REGISTER_KERNEL(Device::KUNLUN, OpType::ReduceMean, ReduceMeanXdnn, "ReduceMean_xdnn_KUNLUN"); +REGISTER_KERNEL(Device::KUNLUN, OpType::ReduceSum, ReduceSumXdnn, + "ReduceSum_xdnn_KUNLUN"); }; // namespace infini diff --git a/src/kernels/kunlun/select.cc b/src/kernels/kunlun/select.cc deleted file mode 100644 index 7cdfd8bf..00000000 --- a/src/kernels/kunlun/select.cc +++ /dev/null @@ -1,32 +0,0 @@ -#include "kunlun/kunlun_kernel_without_config.h" -#include "kunlun/kunlun_runtime.h" -#include "operators/where.h" - -namespace infini { -class WhereXdnn : public KUNLUNKernelWithoutConfig { - void compute(const Operator &_op, - const RuntimeObj *_context) const override { - auto op = as(_op); - IT_ASSERT(op->getDType() == DataType::Float32); - auto context = dynamic_cast(_context); - - void *const aData = (op->getInputs(0)->getRawDataPtr()); - void *const bData = (op->getInputs(1)->getRawDataPtr()); - void *const cData = (op->getInputs(2)->getRawDataPtr()); - void *const dData = (op->getOutput()->getRawDataPtr()); - - auto aDim = op->getInputs(0)->getDims(); - auto bDim = op->getInputs(1)->getDims(); - auto cDim = op->getInputs(2)->getDims(); - auto dDim = op->getOutput()->getDims(); - - auto ret = baidu::xpu::api::select( - context->KUNLUNHandle(), (bool *)cData, (float *)aData, - (float *)bData, (float *)dData, cDim, aDim); - assert(ret == 0); - return; - } -}; - -REGISTER_KERNEL(Device::KUNLUN, OpType::Where, WhereXdnn, "Where_xdnn_KUNLUN"); -}; // namespace infini diff --git a/src/kernels/kunlun/slice.cc b/src/kernels/kunlun/slice.cc new file mode 100644 index 00000000..3de02476 --- /dev/null +++ b/src/kernels/kunlun/slice.cc @@ -0,0 +1,39 @@ +#include "operators/slice.h" +#include "kunlun/kunlun_kernel_without_config.h" +#include "kunlun/kunlun_runtime.h" + +namespace infini { +class SliceXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); + auto context = dynamic_cast(_context); + + void *inData = op->getInputs(0)->getRawDataPtr(); + void *outData = op->getOutput()->getRawDataPtr(); + + // Get attributes of Slice OP + Shape starts =
op->getStarts(), ends = op->getEnds(), + steps = op->getSteps(); + Shape inShape = op->getInputs(0)->getDims(); + // If all steps are 1, mark the slice as continuous + bool continuous = + (size_t)std::count(steps.begin(), steps.end(), 1) == steps.size(); + if (continuous) { + // if continuous, call xdnn::slice + checkKUNLUNError( + xdnn::slice(context->KUNLUNHandle(), (float *)inData, + (float *)outData, inShape, starts, ends)); + + } else { + // else call xdnn::strided_slice + checkKUNLUNError(xdnn::strided_slice( + context->KUNLUNHandle(), (float *)inData, (float *)outData, + inShape, starts, ends, steps)); + } + } +}; + +REGISTER_KERNEL(Device::KUNLUN, OpType::Slice, SliceXdnn, "Slice_xdnn_KUNLUN"); +}; // namespace infini diff --git a/src/kernels/kunlun/softmax.cc b/src/kernels/kunlun/softmax.cc index 552b6c21..d5203d05 100644 --- a/src/kernels/kunlun/softmax.cc +++ b/src/kernels/kunlun/softmax.cc @@ -15,9 +15,9 @@ class SoftmaxXdnn : public KUNLUNKernelWithoutConfig { void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); - auto ret = baidu::xpu::api::softmax( - context->KUNLUNHandle(), (float *)aData, (float *)cData, dim, axis); - assert(ret == 0); + checkKUNLUNError(xdnn::softmax(context->KUNLUNHandle(), + (float *)aData, (float *)cData, + dim, axis)); return; } }; diff --git a/src/kernels/kunlun/transpose.cc b/src/kernels/kunlun/transpose.cc index 7a89480e..f887e179 100644 --- a/src/kernels/kunlun/transpose.cc +++ b/src/kernels/kunlun/transpose.cc @@ -16,13 +16,9 @@ class TransposeXdnn : public KUNLUNKernelWithoutConfig { auto dimin = op->getInputs(0)->getDims(); auto permute = op->getPermute(); - if (dimin.size() != 4) { - IT_TODO_HALT(); - } - - auto ret = baidu::xpu::api::transpose( - context->KUNLUNHandle(), (float *)aData, (float *)cData, dimin, - permute); + auto ret = + xdnn::transpose(context->KUNLUNHandle(), (float *)aData, + (float *)cData, dimin, permute); assert(ret == 0); return; } @@ -46,9 +42,9 @@ class DepthToSpaceXdnn : public KUNLUNKernelWithoutConfig { } else { permute = {0, 1, 4, 2, 5, 3}; } - auto ret = baidu::xpu::api::transpose( - context->KUNLUNHandle(), (float *)aData, (float *)cData, reshape, - permute); + auto ret = + xdnn::transpose(context->KUNLUNHandle(), (float *)aData, + (float *)cData, reshape, permute); assert(ret == 0); return; } diff --git a/src/kernels/kunlun/unary.cc b/src/kernels/kunlun/unary.cc index 3b444d3b..5fa2e99b 100644 --- a/src/kernels/kunlun/unary.cc +++ b/src/kernels/kunlun/unary.cc @@ -14,8 +14,8 @@ class ReluXdnn : public KUNLUNKernelWithoutConfig { void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::relu( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::relu(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); assert(ret == 0); return; } @@ -32,8 +32,8 @@ class SigmoidXdnn : public KUNLUNKernelWithoutConfig { void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::sigmoid( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::sigmoid(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); assert(ret == 0); return; } @@ -50,8 +50,45 @@ class TanhXdnn : public KUNLUNKernelWithoutConfig { void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::tanh( - context->KUNLUNHandle(), (float
*)cData, len); + auto ret = xdnn::tanh(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class HardSwishXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + + auto ret = xdnn::hard_swish(context->KUNLUNHandle(), + (float *)aData, (float *)cData, len); + assert(ret == 0); + return; + } +}; + +class HardSigmoidXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + auto len = op->getInputs(0)->size(); + + // Slope is set to 0.2 by default + auto ret = xdnn::hard_sigmoid( + context->KUNLUNHandle(), (float *)aData, (float *)cData, len, 0.2); assert(ret == 0); return; } @@ -68,8 +105,8 @@ class SquareXdnn : public KUNLUNKernelWithoutConfig { void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::square( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::square(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); assert(ret == 0); return; } @@ -86,8 +123,8 @@ class SqrtXdnn : public KUNLUNKernelWithoutConfig { void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::sqrt( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::sqrt(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); assert(ret == 0); return; } @@ -104,8 +141,8 @@ class RsqrtXdnn : public KUNLUNKernelWithoutConfig { void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::rsqrt( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::rsqrt(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); assert(ret == 0); return; } @@ -122,8 +159,8 @@ class ExpXdnn : public KUNLUNKernelWithoutConfig { void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::exp( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::exp(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); assert(ret == 0); return; } @@ -140,8 +177,8 @@ class CeilXdnn : public KUNLUNKernelWithoutConfig { void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::ceil( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::ceil(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); assert(ret == 0); return; } @@ -160,9 +197,8 @@ class ClipXdnn : public KUNLUNKernelWithoutConfig { float min = op->getMin().value(); float max = op->getMax().value(); - auto ret = baidu::xpu::api::clip(context->KUNLUNHandle(), - (float *)aData, (float *)cData, - len, min, max); + auto ret = xdnn::clip(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len, min, max); assert(ret ==
0); return; } @@ -179,8 +215,8 @@ class FloorXdnn : public KUNLUNKernelWithoutConfig { void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::floor( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::floor(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); assert(ret == 0); return; } @@ -197,8 +233,8 @@ class NegXdnn : public KUNLUNKernelWithoutConfig { void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::neg( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::neg(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); assert(ret == 0); return; } @@ -214,8 +250,8 @@ class CopyXdnn : public KUNLUNKernelWithoutConfig { void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::copy( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::copy(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); assert(ret == 0); return; } @@ -232,8 +268,8 @@ class ReciprocalXdnn : public KUNLUNKernelWithoutConfig { void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::reciprocal( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::reciprocal(context->KUNLUNHandle(), + (float *)aData, (float *)cData, len); assert(ret == 0); return; } @@ -250,8 +286,8 @@ class AbsXdnn : public KUNLUNKernelWithoutConfig { void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::abs( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::abs(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); assert(ret == 0); return; } @@ -268,8 +304,8 @@ class ATanXdnn : public KUNLUNKernelWithoutConfig { void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::arctan( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::arctan(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); assert(ret == 0); return; } @@ -288,36 +324,36 @@ class LogXdnn : public KUNLUNKernelWithoutConfig { 1, }; auto len = op->getInputs(0)->size(); + auto dtype = op->getDType(); // get ptr of tempspace - KUNLUNPtr temp = context->getWorkspace(len * sizeof(float)); + KUNLUNPtr temp = context->getWorkspace(len * dtype.getSize()); LogObj::LogType type = op->getType(); // get output of xpu::api::loge(x) - auto ret = baidu::xpu::api::log( - context->KUNLUNHandle(), (float *)aData, (float *)temp, len); + auto ret = xdnn::log(context->KUNLUNHandle(), (float *)aData, + (float *)temp, len); // get ptr of divider - KUNLUNPtr dd = - (float *)(context->getWorkspace((1 + len) * sizeof(float))) + len; + KUNLUNPtr dd = context->getWorkspace(1 * dtype.getSize()); // choose from logE, log2, log10 switch (type) { float constant; case LogObj::LogE: // if use loge, copy from temp to cData - ret = baidu::xpu::api::copy( - context->KUNLUNHandle(), (float *)temp, (float *)cData, len); + ret = xdnn::copy(context->KUNLUNHandle(), (float *)temp, + (float *)cData, len); break; case LogObj::Log2: constant = std::log(2); context->copyBlobFromCPU(dd, &constant, sizeof(float)); - ret = baidu::xpu::api::broadcast_div( - context->KUNLUNHandle(), (float *)temp, (float 
*)dd, - (float *)cData, aDim, divDim); + ret = xdnn::broadcast_div(context->KUNLUNHandle(), + (float *)temp, (float *)dd, + (float *)cData, aDim, divDim); break; case LogObj::Log10: constant = std::log(10); context->copyBlobFromCPU(dd, &constant, sizeof(float)); - ret = baidu::xpu::api::broadcast_div( - context->KUNLUNHandle(), (float *)temp, (float *)dd, - (float *)cData, aDim, divDim); + ret = xdnn::broadcast_div(context->KUNLUNHandle(), + (float *)temp, (float *)dd, + (float *)cData, aDim, divDim); break; default: printf("LogType not supported!"); @@ -337,8 +373,8 @@ class CosXdnn : public KUNLUNKernelWithoutConfig { void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::cos( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::cos(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); assert(ret == 0); return; @@ -354,8 +390,8 @@ class SinXdnn : public KUNLUNKernelWithoutConfig { void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::sin( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::sin(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); assert(ret == 0); return; @@ -371,8 +407,8 @@ class TanXdnn : public KUNLUNKernelWithoutConfig { void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::tan( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::tan(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); assert(ret == 0); return; @@ -388,8 +424,8 @@ class SinhXdnn : public KUNLUNKernelWithoutConfig { void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::sinh( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::sinh(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); assert(ret == 0); return; @@ -405,8 +441,8 @@ class CoshXdnn : public KUNLUNKernelWithoutConfig { void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::cosh( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::cosh(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); assert(ret == 0); return; @@ -422,8 +458,8 @@ class ErfXdnn : public KUNLUNKernelWithoutConfig { void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::erf( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::erf(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); assert(ret == 0); return; @@ -439,8 +475,8 @@ class ACosXdnn : public KUNLUNKernelWithoutConfig { void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::arccos( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::arccos(context->KUNLUNHandle(), (float
*)aData, + (float *)cData, len); assert(ret == 0); return; @@ -456,8 +492,8 @@ class ACoshXdnn : public KUNLUNKernelWithoutConfig { void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::acosh( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::acosh(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); assert(ret == 0); return; @@ -473,8 +509,8 @@ class ASinXdnn : public KUNLUNKernelWithoutConfig { void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::arcsin( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::arcsin(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); assert(ret == 0); return; @@ -490,8 +526,8 @@ class ASinhXdnn : public KUNLUNKernelWithoutConfig { void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::asinh( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::asinh(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); assert(ret == 0); return; @@ -507,8 +543,8 @@ class ATanhXdnn : public KUNLUNKernelWithoutConfig { void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); auto len = op->getInputs(0)->size(); - auto ret = baidu::xpu::api::atanh( - context->KUNLUNHandle(), (float *)aData, (float *)cData, len); + auto ret = xdnn::atanh(context->KUNLUNHandle(), (float *)aData, + (float *)cData, len); assert(ret == 0); return; @@ -546,7 +582,10 @@ REGISTER_KERNEL(Device::KUNLUN, OpType::Erf, ErfXdnn, "Erf_xdnn"); REGISTER_KERNEL(Device::KUNLUN, OpType::Acos, ACosXdnn, "ACos_xdnn"); REGISTER_KERNEL(Device::KUNLUN, OpType::Acosh, ACoshXdnn, "ACosh_xdnn"); REGISTER_KERNEL(Device::KUNLUN, OpType::Asin, ASinXdnn, "ASin_xdnn"); -REGISTER_KERNEL(Device::KUNLUN, OpType::Asinh, ASinhXdnn, - "ASinh_xdnn_Float3 2"); +REGISTER_KERNEL(Device::KUNLUN, OpType::Asinh, ASinhXdnn, "ASinh_xdnn"); REGISTER_KERNEL(Device::KUNLUN, OpType::Atanh, ATanhXdnn, "ATanh_xdnn"); +REGISTER_KERNEL(Device::KUNLUN, OpType::HardSwish, HardSwishXdnn, + "HardSwish_xdnn"); +REGISTER_KERNEL(Device::KUNLUN, OpType::HardSigmoid, HardSigmoidXdnn, + "HardSigmoid_xdnn"); }; // namespace infini diff --git a/src/kernels/kunlun/where.cc b/src/kernels/kunlun/where.cc new file mode 100644 index 00000000..bd950d1e --- /dev/null +++ b/src/kernels/kunlun/where.cc @@ -0,0 +1,67 @@ +#pragma GCC diagnostic ignored "-Wunused-variable" +#include "operators/where.h" +#include "kunlun/kunlun_kernel_without_config.h" +#include "kunlun/kunlun_runtime.h" +#include "utils/operator_utils.h" + +namespace infini { + +class WhereXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + IT_ASSERT(op->getDType() == DataType::Float32); + auto context = dynamic_cast(_context); + + void *const aData = + (op->getInputs(0)->getRawDataPtr()); // inputX + void *const bData = + (op->getInputs(1)->getRawDataPtr()); // inputY + void *const cData = + (op->getInputs(2)->getRawDataPtr()); // condition + void *const dData = + (op->getOutput()->getRawDataPtr()); // output + + auto aDim = op->getInputs(0)->getDims(); 
// dimX + auto bDim = op->getInputs(1)->getDims(); // dimY + auto cDim = op->getInputs(2)->getDims(); // dimCondition + auto dDim = op->getOutput()->getDims(); // dimOutput + + auto dtype = op->getDType(); + + if (aDim != bDim) { + // Infer broadcast for X and Y + Shape XYDim = infer_broadcast(aDim, bDim); + int XYSize = std::accumulate(XYDim.begin(), XYDim.end(), 1, + std::multiplies()); + // Align rank for XYDim and aDim or bDim + broadcastShape(XYDim, aDim); + broadcastShape(XYDim, bDim); + // Get workspace + void *wkspace = context->getWorkspace(XYSize * dtype.getSize()); + // Broadcast X Y + checkKUNLUNError(xdnn::broadcast( + context->KUNLUNHandle(), + (float *)(XYDim == aDim ? bData : aData), (float *)wkspace, + (XYDim == aDim ? bDim : aDim), XYDim)); + // Align Rank + broadcastShape(dDim, XYDim); + broadcastShape(dDim, XYDim); + // Where + void *XData = XYDim == aDim ? aData : wkspace; + void *YData = XYDim == bDim ? bData : wkspace; + checkKUNLUNError(xdnn::select( + context->KUNLUNHandle(), (bool *)cData, (float *)XData, + (float *)YData, (float *)dData, cDim, XYDim)); + } else { + checkKUNLUNError(xdnn::select( + context->KUNLUNHandle(), (bool *)cData, (float *)aData, + (float *)bData, (float *)dData, cDim, aDim)); + } + + return; + } +}; + +REGISTER_KERNEL(Device::KUNLUN, OpType::Where, WhereXdnn, "Where_xdnn_KUNLUN"); +}; // namespace infini
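A concrete restatement of the broadcast plan in WhereXdnn above (a standalone sketch with hypothetical names; the per-axis rule is the same one infer_broadcast implements):

#include <algorithm>
#include <vector>
using Shape = std::vector<int>;

struct WherePlan {
    Shape xyDim;  // common broadcast shape of X and Y
    bool expandX; // true: X is materialized via xdnn::broadcast, Y used as-is
};

WherePlan planWhere(const Shape &aDim, const Shape &bDim) {
    size_t rank = std::max(aDim.size(), bDim.size());
    Shape xy(rank);
    for (size_t i = 0; i < rank; ++i) {
        // right-align both shapes; missing leading axes count as 1
        int a = i < rank - aDim.size() ? 1 : aDim[i - (rank - aDim.size())];
        int b = i < rank - bDim.size() ? 1 : bDim[i - (rank - bDim.size())];
        xy[i] = std::max(a, b); // per-axis max; a 1 broadcasts
    }
    return {xy, /*expandX=*/xy != aDim};
}
// planWhere({2, 1, 3}, {4, 3}) -> xyDim {2, 4, 3}, expandX = true,
// i.e. X goes through the workspace and Y feeds xdnn::select directly.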
diff --git a/src/kunlun/kunlun_runtime.cc b/src/kunlun/kunlun_runtime.cc index b614ac9c..b54c1cd0 100644 --- a/src/kunlun/kunlun_runtime.cc +++ b/src/kunlun/kunlun_runtime.cc @@ -19,6 +19,7 @@ void KUNLUNRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false, auto perfData = perfEngine.getPerfData(perfKey); if (!perfData && !tune) { kernel->compute(op, this); + workspace->resetWorkspace(); continue; } @@ -52,8 +53,20 @@ void KUNLUNRuntimeObj::run(const Graph &graph, bool tune, sync(); } -void KUNLUNRuntimeObj::sync() const { ; } +void KUNLUNRuntimeObj::sync() const { xpu_wait(); } string KUNLUNRuntimeObj::toString() const { return "KUNLUN Runtime"; } +void KUNLUNRuntimeObj::initComm(const string &name, int worldSize, int rank) { + IT_ASSERT(worldSize > 0); + IT_ASSERT(rank >= 0); + IT_ASSERT(rank < worldSize); + IT_ASSERT(!comm) << "communicator is already initialized."; +#ifdef INFINI_USE_XCCL + comm = std::make_unique(name, worldSize, rank); +#else + IT_TODO_HALT_MSG("Not compiled with XCCL"); +#endif +} + } // namespace infini diff --git a/src/operators/matmul.cc b/src/operators/matmul.cc index 60cbb826..db4533a7 100644 --- a/src/operators/matmul.cc +++ b/src/operators/matmul.cc @@ -25,7 +25,7 @@ optional> MatmulObj::inferShape(const TensorVec &inputs) { auto A = inputs[0], B = inputs[1]; auto shapeA = A->getDims(); auto shapeB = B->getDims(); - int rankA = A->getRank(); + int rankA = A->getRank(); // Rank is the number of dimensions of the tensor int rankB = B->getRank(); Shape shapeA1(shapeA.begin(), shapeA.begin() + (rankA - 2)); Shape shapeB1(shapeB.begin(), shapeB.begin() + (rankB - 2)); diff --git a/src/operators/unary.cc b/src/operators/unary.cc index 79d2ab83..2a6a3f4c 100644 --- a/src/operators/unary.cc +++ b/src/operators/unary.cc @@ -231,6 +231,8 @@ DataType CastObj::getOutputDataType() const { return DataType::Float32; case CastType::Float2BFloat16: return DataType::BFloat16; + case CastType::Float2Float: + return DataType::Float32; default: IT_TODO_HALT(); } diff --git a/src/utils/operator_utils.cc b/src/utils/operator_utils.cc index b191fb33..502336a8 100644 --- a/src/utils/operator_utils.cc +++ b/src/utils/operator_utils.cc @@ -114,4 +114,28 @@ std::string get_kernel_attrs_str(const KernelAttrs &kernelAttrs) { std::string opStr = OpType(std::get<1>(kernelAttrs)).toString(); return deviceStr + ", " + opStr; } + +int shapeProd(std::vector::iterator start, + std::vector::iterator end) { + return std::accumulate(start, end, 1, std::multiplies()); +} + +void broadcastShape(const Shape &originShape, SmallArray &modifyShape, + int nDims, int size) { + for (int i = nDims - size - 1; i >= 0; --i) { + modifyShape.data[i] = 1; + } + for (int i = nDims - 1; i >= nDims - size; --i) { + modifyShape.data[i] = originShape[i - nDims + size]; + } +} + +void broadcastShape(const Shape &tempShape, Shape &modifyShape) { + // Align ranks: prepend 1s at the front of modifyShape + IT_ASSERT(tempShape.size() >= modifyShape.size()); + modifyShape.insert(modifyShape.begin(), + tempShape.size() - modifyShape.size(), 1); + return; +} + } // namespace infini
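To make the two overloads above concrete, a small usage sketch against the repo's headers (the shapes are arbitrary; assumes SMALL_ARRAY_SIZE >= 4):

#include "utils/operator_utils.h"
#include "utils/small_array.h"
using namespace infini;

void broadcastShapeDemo() {
    // Overload 1: right-align a rank-2 shape inside a rank-4 SmallArray.
    SmallArray aligned;
    broadcastShape(Shape{3, 5}, aligned, /*nDims=*/4, /*size=*/2);
    // aligned.data is now {1, 1, 3, 5}

    // Overload 2: prepend 1s until modifyShape matches tempShape's rank.
    Shape modify{4, 5};
    broadcastShape(Shape{2, 3, 4, 5}, modify);
    // modify is now {1, 1, 4, 5}
}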
diff --git a/test/kernels/kunlun/test_kunlun_allgather.cc b/test/kernels/kunlun/test_kunlun_allgather.cc new file mode 100644 index 00000000..2e7670ad --- /dev/null +++ b/test/kernels/kunlun/test_kunlun_allgather.cc @@ -0,0 +1,50 @@ +#ifdef INFINI_USE_XCCL +#include "core/graph.h" +#include "core/runtime.h" +#include "kunlun/kunlun_runtime.h" +#include "operators/all_gather.h" +#include "test.h" +#include "xpu/bkcl.h" +#include + +static int WORLD_SIZE = 2; + +namespace infini { + +void allGather(const string taskName, int deviceID, vector data, + vector> ans) { + // Create Runtimes and initiate communication + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + Runtime kunlunRuntime = make_ref(deviceID); + kunlunRuntime->initComm(taskName, WORLD_SIZE, deviceID); + // Create Graph and insert the AllGather operation + Graph g = make_ref(kunlunRuntime); + auto input = + g->addTensor(Shape{static_cast(data.size())}, DataType::Float32); + auto op = g->addOp(input, std::nullopt, WORLD_SIZE); + // Copy data from CPU to GPU + g->dataMalloc(); + input->copyin(data); + // Run operation + kunlunRuntime->run(g); + // Copy output from GPU to CPU + for (int i = 0; i < WORLD_SIZE; ++i) { + auto result = op->getOutputs()[i]->clone(cpuRuntime); + EXPECT_TRUE(result->equalData(ans[i])); + } +} + +TEST(KUNLUN_AllGather, run) { + vector data[2] = {{2., 3.}, {5., 6.}}; + vector> ans = {{2., 3.}, {5., 6.}}; + + std::vector threads; + for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) { + threads.emplace_back(allGather, "test_all_gather", gpu, data[gpu], ans); + } + for (auto &thread : threads) { + thread.join(); + } +} } // namespace infini #endif diff --git a/test/kernels/kunlun/test_kunlun_allreduce.cc b/test/kernels/kunlun/test_kunlun_allreduce.cc new file mode 100644 index 00000000..c08f1a32 --- /dev/null +++ b/test/kernels/kunlun/test_kunlun_allreduce.cc @@ -0,0 +1,72 @@ +#ifdef INFINI_USE_XCCL +#include "core/graph.h" +#include "core/runtime.h" +#include "kunlun/kunlun_runtime.h" +#include "operators/all_reduce.h" +#include "test.h" +#include "xpu/bkcl.h" +#include + +static int WORLD_SIZE = 2; + +using namespace infini; + +template +void allReduce(const string taskName, int deviceID, vector data, + vector ans) { + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + Runtime kunlunRuntime = make_ref(deviceID); + kunlunRuntime->initComm(taskName, WORLD_SIZE, deviceID); + Graph g = make_ref(kunlunRuntime); + auto input = + g->addTensor(Shape{static_cast(data.size())}, DataType::Float32); + auto op = g->addOp(input, nullptr); + g->dataMalloc(); + input->copyin(data); + kunlunRuntime->run(g); + auto result = op->getOutput()->clone(cpuRuntime); + + EXPECT_TRUE(result->equalData(ans)); +} + +TEST(KUNLUN_AllReduce, sum) { + vector data[2] = {{2., 3.}, {5., 6.}}; + vector ans = {7., 9.}; + std::vector threads; + for (int rank = 0; rank < WORLD_SIZE; ++rank) { + threads.emplace_back(allReduce, "test_allreduce_sum", + rank, data[rank], ans); + } + for (auto &thread : threads) { + thread.join(); + } +} + +TEST(KUNLUN_AllReduce, max) { + vector data[2] = {{2., 3.}, {5., 6.}}; + vector ans = {5., 6.}; + + std::vector threads; + for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) { + threads.emplace_back(allReduce, "test_allreduce_max", + gpu, data[gpu], ans); + } + for (auto &thread : threads) { + thread.join(); + } +} + +TEST(KUNLUN_AllReduce, min) { + vector data[2] = {{2., 3.}, {5., 6.}}; + vector ans = {2., 3.}; + + std::vector threads; + for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) { + threads.emplace_back(allReduce, "test_allreduce_min", + gpu, data[gpu], ans); + } + for (auto &thread : threads) { + thread.join(); + } +} #endif diff --git a/test/kernels/kunlun/test_kunlun_broadcast.cc b/test/kernels/kunlun/test_kunlun_broadcast.cc new file mode 100644 index 00000000..b99f139e --- /dev/null +++ b/test/kernels/kunlun/test_kunlun_broadcast.cc @@ -0,0 +1,56 @@ +#ifdef INFINI_USE_XCCL +#include "core/graph.h" +#include "core/runtime.h" +#include "kunlun/kunlun_runtime.h" +#include "operators/broadcast.h" +#include "test.h" +#include +#include + +static int WORLD_SIZE = 2; +static int root = 0; + +namespace infini { + +void broadcast(const string taskName, int deviceID, vector data, + vector ans) { + // Create Runtimes and initiate communication + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + Runtime kunlunRuntime = make_ref(deviceID); + kunlunRuntime->initComm(taskName, WORLD_SIZE, deviceID); + // Create Graph and insert the Broadcast operation + Graph g = make_ref(kunlunRuntime); + auto input = + g->addTensor(Shape{static_cast(data.size())}, DataType::Float32); + auto op = g->addOp(input, nullptr, root); + // Copy data from CPU to GPU + g->dataMalloc(); + // Only rank 0 has the data + if (deviceID == root) { + input->copyin(data); + } + // Run broadcast operation + kunlunRuntime->run(g); + // Copy output from GPU to CPU + auto result = op->getOutput()->clone(cpuRuntime); + + EXPECT_TRUE(result->equalData(ans)); +} + +TEST(KUNLUN_Broadcast, run) { + // Only 1 device gets data. Every rank should have the same data after + // broadcast.
+ vector data = {2., 3., 5., 6.}; + vector ans = {2., 3., 5., 6.}; + + std::vector threads; + for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) { + threads.emplace_back(broadcast, "test_broadcast", gpu, data, ans); + } + for (auto &thread : threads) { + thread.join(); + } +} } // namespace infini + #endif diff --git a/test/kernels/kunlun/test_kunlun_gather.cc b/test/kernels/kunlun/test_kunlun_gather.cc new file mode 100644 index 00000000..707d85aa --- /dev/null +++ b/test/kernels/kunlun/test_kunlun_gather.cc @@ -0,0 +1,144 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "kunlun/kunlun_runtime.h" +#include "operators/gather.h" + +#include "test.h" + +namespace infini { +/* +test1: +input = [ [1, 2], [3, 4], [5, 6], ] indices = [ [0, 1], [1, 2], ] output = [ [ [1, 2], [3, 4], ], [ [3, 4], [5, 6], ], ] axis=0 */ /* test2 input = [ [0, 1, 2], [3, 4, 5], [6, 7, 8], ] indices = [ [0, 2], ] axis = 1, output = [ [[0, 2]], [[3, 5]], [[6, 8]], ] */ /* test3 input=[[[ 0, 1], [ 2, 3], [ 4, 5], [ 6, 7]], [[ 8, 9], [10, 11], [12, 13], [14, 15]]] //(2,4,2) indices=[[0],[3],[1]] //(3,1) axis=1 output=[[[[ 0, 1]], [[ 6, 7]], [[ 2, 3]]], [[[ 8, 9]], [[14, 15]], [[10, 11]]]] //(2,3,1,2) */ +TEST(Gather, KUNLUN) { + { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + auto input = gCpu->addTensor({3, 2}, DataType::Float32); + auto index = gCpu->addTensor({2, 2}, DataType::Int32); + gCpu->dataMalloc(); + input->copyin(vector{1, 2, 3, 4, 5, 6}); + index->copyin(vector{0, 1, 1, 2}); + auto kunlunRuntime = make_ref(); + Graph gCuda = make_ref(kunlunRuntime); + + auto inputCuda = gCuda->cloneTensor(input); + auto indexCuda = gCuda->cloneTensor(index); + auto op = gCuda->addOp(inputCuda, indexCuda, nullptr, 0); + gCuda->dataMalloc(); + inputCuda->copyin(vector{1, 2, 3, 4, 5, 6}); + indexCuda->copyin(vector{0, 1, 1, 2}); + kunlunRuntime->run(gCuda); + + // cudaPrintTensor(op->getOutput()); + // copy output from CUDA to CPU + auto oCpu = gCpu->cloneTensor(op->getOutput()); + EXPECT_TRUE(oCpu->equalData(vector{1, 2, 3, 4, 3, 4, 5, 6})); + } + { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + auto input = gCpu->addTensor({3, 3}, DataType::Float32); + auto index = gCpu->addTensor({1, 2}, DataType::Int32); + gCpu->dataMalloc(); + input->setData(IncrementalGenerator()); + index->copyin(vector{0, 2}); + auto kunlunRuntime = make_ref(); + Graph gCuda = make_ref(kunlunRuntime); + + auto inputCuda = gCuda->cloneTensor(input); + auto indexCuda = gCuda->cloneTensor(index); + auto op = gCuda->addOp(inputCuda, indexCuda, nullptr, 1); + gCuda->dataMalloc(); + inputCuda->setData(IncrementalGenerator()); + indexCuda->copyin(vector{0, 2}); + kunlunRuntime->run(gCuda); + + // cudaPrintTensor(op->getOutput()); + // copy output from CUDA to CPU + auto oCpu = gCpu->cloneTensor(op->getOutput()); + EXPECT_TRUE(oCpu->equalData(vector{0, 2, 3, 5, 6, 8})); + } + { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + auto input = gCpu->addTensor({3, 2}, DataType::Float32); + auto index = gCpu->addTensor({2, 2}, DataType::Int32); + gCpu->dataMalloc(); + input->copyin(std::vector{1.0, 1.2, 2.3, 3.4, 4.5, 5.7}); + index->copyin(std::vector{0, 1, 1, 2}); + auto kunlunRuntime = make_ref(); + Graph gCuda = make_ref(kunlunRuntime); + + auto inputCuda = gCuda->cloneTensor(input); + auto indexCuda = gCuda->cloneTensor(index); + auto op = gCuda->addOp(inputCuda, indexCuda,
nullptr, 0); + gCuda->dataMalloc(); + inputCuda->copyin(std::vector{1.0, 1.2, 2.3, 3.4, 4.5, 5.7}); + indexCuda->copyin(std::vector{0, 1, 1, 2}); + kunlunRuntime->run(gCuda); + + // cudaPrintTensor(op->getOutput()); + // copy output from CUDA to CPU + auto oCpu = gCpu->cloneTensor(op->getOutput()); + EXPECT_TRUE(oCpu->equalData( + std::vector{1.0, 1.2, 2.3, 3.4, 2.3, 3.4, 4.5, 5.7})); + } +} + +} // namespace infini diff --git a/test/kernels/kunlun/test_kunlun_matmul.cc b/test/kernels/kunlun/test_kunlun_matmul.cc index dcd2084f..448d7c06 100644 --- a/test/kernels/kunlun/test_kunlun_matmul.cc +++ b/test/kernels/kunlun/test_kunlun_matmul.cc @@ -7,52 +7,112 @@ #include "test.h" namespace infini { +using ExpectOutput = vector; -template -void testMatmul(const std::function &generatorA, - const std::function &generatorB, - bool transA, bool transB, const Shape &shapeA, - const Shape &shapeB) { - // Runtime - Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); - auto xpuRuntime = make_ref(); +void testMatmulKUNLUNWithBias( + const std::function &generatorA, + const std::function &generatorB, + const std::function &generatorBias, + bool transA, bool transB, const Shape &shapeA, const Shape &shapeB, + const Shape &shapeBias, const ExpectOutput &ansVec) { + auto cpuRuntime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(cpuRuntime); + auto ACpu = gCpu->addTensor(shapeA, DataType::Float32); + auto BCpu = gCpu->addTensor(shapeB, DataType::Float32); + auto BiasCpu = gCpu->addTensor(shapeBias, DataType::Float32); + gCpu->dataMalloc(); + ACpu->setData(generatorA); + BCpu->setData(generatorB); + BiasCpu->setData(generatorBias); - // Build input data on CPU - Tensor inputCpu1 = - make_ref(shapeA, DataType::Float32, cpuRuntime); - Tensor inputCpu2 = - make_ref(shapeB, DataType::Float32, cpuRuntime); + auto kunlunRuntime = make_ref(); + auto gKunlun = make_ref(kunlunRuntime); + auto AKunlun = gKunlun->cloneTensor(ACpu); + auto BKunlun = gKunlun->cloneTensor(BCpu); + auto BiasKunlun = gKunlun->cloneTensor(BiasCpu); + auto matmul = gKunlun->addOp(AKunlun, BKunlun, nullptr, transA, + transB, BiasKunlun); - // MLU - Graph xpuGraph = make_ref(xpuRuntime); - auto inputMlu1 = xpuGraph->cloneTensor(inputCpu1); - auto inputMlu2 = xpuGraph->cloneTensor(inputCpu2); - auto mluOp = xpuGraph->addOp(inputMlu1, inputMlu2, nullptr); - xpuGraph->dataMalloc(); - inputMlu1->setData(generatorA); - inputMlu2->setData(generatorB); - xpuRuntime->run(xpuGraph); - auto outputMlu = mluOp->getOutput(); - auto outputMlu2Cpu = outputMlu->clone(cpuRuntime); - // CPU - Graph cpuGraph = make_ref(cpuRuntime); - auto cpuOp = cpuGraph->addOp(inputCpu1, inputCpu2, nullptr); - cpuGraph->addTensor(inputCpu1); - cpuGraph->addTensor(inputCpu2); - cpuGraph->dataMalloc(); - inputCpu1->setData(generatorA); - inputCpu2->setData(generatorB); - cpuRuntime->run(cpuGraph); - auto outputCpu = cpuOp->getOutput(); - outputCpu->print(); - outputMlu2Cpu->print(); - // Check - EXPECT_TRUE(outputCpu->equalData(outputMlu2Cpu)); + // allocate Kunlun memory + gKunlun->dataMalloc(); + AKunlun->setData(generatorA); + BKunlun->setData(generatorB); + BiasKunlun->setData(generatorBias); + kunlunRuntime->run(gKunlun); + + auto CCpu = gCpu->cloneTensor(matmul->getOutput()); + // CCpu->printData(); + // check results on CPU + EXPECT_TRUE(CCpu->equalData(ansVec)); + // print a tensor/operator/graph by print() + // gKunlun->print(); } -TEST(xpu_Matmul, run) { - testMatmul(IncrementalGenerator(), IncrementalGenerator(), false, - false, Shape{2, 3}, Shape{3, 
4}); +void testMatmulKUNLUN( + const std::function &generatorA, + const std::function &generatorB, + bool transA, bool transB, const Shape &shapeA, const Shape &shapeB, + const ExpectOutput &ansVec) { + auto cpuRuntime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(cpuRuntime); + auto ACpu = gCpu->addTensor(shapeA, DataType::Float32); + auto BCpu = gCpu->addTensor(shapeB, DataType::Float32); + gCpu->dataMalloc(); + ACpu->setData(generatorA); + BCpu->setData(generatorB); + + auto kunlunRuntime = make_ref(); + auto gKunlun = make_ref(kunlunRuntime); + auto AKunlun = gKunlun->cloneTensor(ACpu); + auto BKunlun = gKunlun->cloneTensor(BCpu); + auto matmul = gKunlun->addOp(AKunlun, BKunlun, nullptr, transA, + transB, nullptr); + + // allocate Kunlun memory + gKunlun->dataMalloc(); + AKunlun->setData(generatorA); + BKunlun->setData(generatorB); + kunlunRuntime->run(gKunlun); + + auto CCpu = gCpu->cloneTensor(matmul->getOutput()); + // CCpu->printData(); + // check results on CPU + EXPECT_TRUE(CCpu->equalData(ansVec)); + // print a tensor/operator/graph by print() + // gKunlun->print(); +} + +TEST(XDNN_Matmul, run) { + testMatmulKUNLUN(IncrementalGenerator(), OneGenerator(), false, false, + Shape{1, 3, 5}, Shape{1, 5, 2}, + ExpectOutput{10, 10, 35, 35, 60, 60}); + testMatmulKUNLUN(IncrementalGenerator(), IncrementalGenerator(), true, + false, Shape{2, 3, 4}, Shape{2, 3, 2}, + ExpectOutput{40, 52, 46, 61, 52, 70, 58, 79, 400, 448, 424, + 475, 448, 502, 472, 529}); + testMatmulKUNLUN( + IncrementalGenerator(), IncrementalGenerator(), false, false, + Shape{2, 3, 5}, Shape{5, 2}, + ExpectOutput{60, 70, 160, 195, 260, 320, 360, 445, 460, 570, 560, 695}); + testMatmulKUNLUN(IncrementalGenerator(), IncrementalGenerator(), true, + false, Shape{2, 5, 3}, Shape{5, 2}, + ExpectOutput{180, 210, 200, 235, 220, 260, 480, 585, 500, + 610, 520, 635}); + testMatmulKUNLUN(IncrementalGenerator(), IncrementalGenerator(), false, + false, Shape{3, 5}, Shape{5, 2}, + ExpectOutput{60, 70, 160, 195, 260, 320}); +} + +TEST(XDNN_Matmul_With_Bias, run) { + testMatmulKUNLUNWithBias(IncrementalGenerator(), OneGenerator(), + OneGenerator(), false, false, Shape{1, 3, 5}, + Shape{1, 5, 2}, Shape{2}, + ExpectOutput{11, 11, 36, 36, 61, 61}); + testMatmulKUNLUNWithBias(IncrementalGenerator(), IncrementalGenerator(), + OneGenerator(), true, false, Shape{2, 3, 4}, + Shape{2, 3, 2}, Shape{4, 2}, + ExpectOutput{41, 53, 47, 62, 53, 71, 59, 80, 401, + 449, 425, 476, 449, 503, 473, 530}); } } // namespace infini diff --git a/test/kernels/kunlun/test_kunlun_slice.cc b/test/kernels/kunlun/test_kunlun_slice.cc new file mode 100644 index 00000000..04c8c7d5 --- /dev/null +++ b/test/kernels/kunlun/test_kunlun_slice.cc @@ -0,0 +1,39 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "kunlun/kunlun_runtime.h" +#include "operators/slice.h" +#include "test.h" + +namespace infini { +TEST(KUNLUN_Slice, run) { + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto kunlunRuntime = make_ref(); + + // Build input data on CPU + Tensor icpu = + make_ref(Shape{3, 2, 1, 5}, DataType::Float32, cpuRuntime); + icpu->dataMalloc(); + icpu->setData(IncrementalGenerator()); + + // Build CUDA graph; + Graph g = make_ref(kunlunRuntime); + auto i = g->cloneTensor(icpu); + auto op = + g->addOp(i, nullptr, vector{1, 1}, vector{2, 5}, + vector{0, 3}, std::nullopt); + + // allocate CUDA memory + g->dataMalloc(); + i->setData(IncrementalGenerator()); + + // Execute on CUDA + kunlunRuntime->run(g); + + // clone CUDA output to CPU + auto 
o = op->getOutput(); + auto cpuo = o->clone(cpuRuntime); + // cudaPrintTensor(o); + // check results on CPU + EXPECT_TRUE(cpuo->equalData(vector{11, 12, 13, 14, 16, 17, 18, 19})); +} +} // namespace infini diff --git a/test/kernels/kunlun/test_kunlun_where.cc b/test/kernels/kunlun/test_kunlun_where.cc new file mode 100644 index 00000000..c744ef68 --- /dev/null +++ b/test/kernels/kunlun/test_kunlun_where.cc @@ -0,0 +1,77 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "kunlun/kunlun_runtime.h" +#include "operators/where.h" + +#include "test.h" + +namespace infini { + +void test_where(const Shape &inputXShape, const vector &inputXData, + const Shape &inputYShape, const vector &inputYData, + const Shape &conditionShape, + const vector &conditionData, + const vector &ExpectData) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + auto condition = gCpu->addTensor(conditionShape, DataType::Bool); + auto inputX = gCpu->addTensor(inputXShape, DataType::Float32); + auto inputY = gCpu->addTensor(inputYShape, DataType::Float32); + + gCpu->dataMalloc(); + condition->copyin(conditionData); // + inputX->copyin(inputXData); + inputY->copyin(inputYData); // + + auto kunlunRuntime = make_ref(); + Graph gCuda = make_ref(kunlunRuntime); + + auto conditionGpu = gCuda->cloneTensor(condition); + auto inputXGpu = gCuda->cloneTensor(inputX); + auto inputYGpu = gCuda->cloneTensor(inputY); + + auto op = gCuda->addOp(inputXGpu, inputYGpu, conditionGpu, + nullptr); // WhereObj + gCuda->dataMalloc(); + conditionGpu->copyin(conditionData); + inputXGpu->copyin(inputXData); + inputYGpu->copyin(inputYData); + kunlunRuntime->run(gCuda); + + auto oCpu = gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu + oCpu->printData(); //->printData + EXPECT_TRUE(oCpu->equalData(ExpectData)); +} + +TEST(KUNLUN_Where, run) { + test_where( + Shape{2, 2, 3, 1}, vector{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + Shape{2, 2, 3, 1}, vector{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + Shape{2, 2, 3, 1}, vector{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1}, + vector{0., 1., 2., 0., 0., 0., 6., 7., 0., 9., 10., 11.}); + + test_where(Shape{2, 2, 1, 3}, // inputx + vector{0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5}, + Shape{2, 2, 1, 3}, // inputy + vector{1, 1, 3, 2, 5, 1, 5, 2, 3, 5, 6, 7}, + Shape{2, 2, 1, 3}, // condition + vector{0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0}, + vector{1, 1, 2, 2, 5, 1, 0, 2, 2, 3, 6, 7}); + + test_where(Shape{2, 2, 1, 3}, + vector{0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5}, // inputX + Shape{2, 2, 1, 3}, + vector{1, 1, 3, 2, 5, 1, 5, 2, 3, 5, 6, 7}, // inputY + Shape{2, 1, 1, 3}, vector{1, 1, 0, 1, 1, 1}, // condition + vector{0, 1, 3, 3, 4, 1, 0, 1, 2, 3, 4, 5}); // result + + test_where(Shape{2, 2, 1, 3}, + vector{0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5}, // inputX + Shape{2, 2, 1, 3}, + vector{1, 1, 3, 2, 5, 1, 5, 2, 3, 5, 6, 7}, // inputY + Shape{2, 1, 1, 3}, + vector{1, 1, 0, 1, 1, + 1}, // condition } // python output + vector{0, 1, 3, 3, 4, 1, 0, 1, 2, 3, 4, 5}); // result +} +} // namespace infini diff --git a/test/kunlun/test_kunlun_workspace.cc b/test/kunlun/test_kunlun_workspace.cc new file mode 100644 index 00000000..6feb6823 --- /dev/null +++ b/test/kunlun/test_kunlun_workspace.cc @@ -0,0 +1,20 @@ +#include "core/runtime.h" +#include "core/workspace.h" +#include "kunlun/kunlun_runtime.h" + +#include "test.h" + +namespace infini { +TEST(KunlunWorkspace, test) { + Ref kunlunRuntime = make_ref(); + auto wkspace = kunlunRuntime->getWorkspaceObj(); + KUNLUNPtr space1 = 
kunlunRuntime->getWorkspace(1024 * 1024 * sizeof(float)); + IT_ASSERT(wkspace->getWorkspaceAlloc() == 1024 * 1024 * sizeof(float)); + KUNLUNPtr space2 = kunlunRuntime->getWorkspace(1024 * 1024 * sizeof(float)); + IT_ASSERT(wkspace->getWorkspaceAlloc() == 1024 * 1024 * sizeof(float) * 2); + IT_ASSERT((void *)(static_cast(space1) + + 1024 * 1024 * sizeof(float)) == (void *)space2); + wkspace->resetWorkspace(); + IT_ASSERT(wkspace->getWorkspaceAlloc() == 0); +} +} // namespace infini
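What the assertions above pin down is plain bump allocation; a minimal host-side analogue (illustrative only, since the actual Workspace hands out device memory through the runtime):

#include <cassert>
#include <cstddef>
#include <cstdlib>

class BumpWorkspace {
    void *base;
    size_t capacity;
    size_t allocated = 0;

  public:
    explicit BumpWorkspace(size_t cap)
        : base(std::malloc(cap)), capacity(cap) {}
    ~BumpWorkspace() { std::free(base); }
    void *get(size_t size) { // cf. getWorkspace: consecutive calls are adjacent
        assert(allocated + size <= capacity);
        void *ptr = static_cast<char *>(base) + allocated;
        allocated += size;
        return ptr;
    }
    size_t getAllocated() const { return allocated; } // cf. getWorkspaceAlloc
    void reset() { allocated = 0; }                   // cf. resetWorkspace
};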