diff --git a/.gitignore b/.gitignore index 863f1a48..98e980ad 100644 --- a/.gitignore +++ b/.gitignore @@ -42,3 +42,5 @@ build_debug/ # onnx model *.onnx +*.pb +*.npy diff --git a/.gitmodules b/.gitmodules index 02a80785..3d6bad77 100644 --- a/.gitmodules +++ b/.gitmodules @@ -11,5 +11,5 @@ path = 3rd-party/backward-cpp url = git@github.com:bombela/backward-cpp.git [submodule "example"] - path = example + path = examples/NNmodel url = git@github.com:wanghailu0717/NNmodel.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 7c737313..49d2c5a7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,6 +8,7 @@ option(USE_BANG "Support BANG MLU" OFF) option(USE_INTELCPU "Support INTELCPU" OFF) option(USE_BACKTRACE "Print backtrace on exception and segmentation fault" ON) option(USE_PROTOBUF "Serialize and deserialize tensors" OFF) +option(BUILD_DIST "Build project for distributed running" OFF) option(BUILD_TEST "Build tests" OFF) cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF) @@ -194,6 +195,13 @@ if(USE_CUDA) enable_language(CUDA) find_package(CUDAToolkit) # For nvrtc and cuda driver target_link_libraries(InfiniTensor cudnn CUDA::curand CUDA::cublas CUDA::nvrtc CUDA::cudart CUDA::cuda_driver) + if (BUILD_DIST) + message(STATUS "Add BUILD_DIST, use NCCL with CUDA") + list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake) + find_package(NCCL REQUIRED) + add_compile_definitions(INFINI_USE_NCCL=1) + target_link_libraries(InfiniTensor nccl) + endif() endif() if(USE_BANG) @@ -261,6 +269,7 @@ if(BUILD_TEST) build_test(test/operators/*.cc) if (USE_CUDA) build_test(test/kernels/cuda/*.cc) + build_test(test/cuda/*.cc) endif() if (USE_BANG) build_test(test/kernels/bang/*.cc) diff --git a/cmake/FindNCCL.cmake b/cmake/FindNCCL.cmake new file mode 100644 index 00000000..d2f2f835 --- /dev/null +++ b/cmake/FindNCCL.cmake @@ -0,0 +1,165 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. +# +# From PyTorch: +# +# Copyright (c) 2016- Facebook, Inc (Adam Paszke) +# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +# Copyright (c) 2011-2013 NYU (Clement Farabet) +# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) +# +# From Caffe2: +# +# Copyright (c) 2016-present, Facebook Inc. All rights reserved. +# +# All contributions by Facebook: +# Copyright (c) 2016 Facebook Inc. +# +# All contributions by Google: +# Copyright (c) 2015 Google Inc. +# All rights reserved. +# +# All contributions by Yangqing Jia: +# Copyright (c) 2015 Yangqing Jia +# All rights reserved. +# +# All contributions by Kakao Brain: +# Copyright 2019-2020 Kakao Brain +# +# All contributions from Caffe: +# Copyright(c) 2013, 2014, 2015, the respective contributors +# All rights reserved. +# +# All other contributions: +# Copyright(c) 2015, 2016 the respective contributors +# All rights reserved. +# +# Caffe2 uses a copyright model similar to Caffe: each contributor holds +# copyright over their contributions to Caffe2. The project versioning records +# all such contribution and copyright details. If a contributor wants to further +# mark their specific copyright on a particular contribution, they should +# indicate their copyright solely in the commit message of the change when it is +# committed. +# +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America +# and IDIAP Research Institute nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# Find the nccl libraries +# +# The following variables are optionally searched for defaults +# NCCL_ROOT: Base directory where all NCCL components are foundHong Xu, 1 year ago: • Let CMake handle NCCL detection instead of ou… +# NCCL_INCLUDE_DIR: Directory where NCCL header is foundPieter Noordhuis, 3 years ago: • Bump gloo +# NCCL_LIB_DIR: Directory where NCCL library is found +# +# The following are set after configuration is done: +# NCCL_FOUND +# NCCL_INCLUDE_DIRS +# NCCL_LIBRARIES +# +# The path hints include CUDA_TOOLKIT_ROOT_DIR seeing as some folks +# install NCCL in the same location as the CUDA toolkit. +# See https://github.com/caffe2/caffe2/issues/1601 + +set(NCCL_INCLUDE_DIR $ENV{NCCL_INCLUDE_DIR} CACHE PATH "Folder contains NVIDIA NCCL headers") +set(NCCL_LIB_DIR $ENV{NCCL_LIB_DIR} CACHE PATH "Folder contains NVIDIA NCCL libraries") +set(NCCL_VERSION $ENV{NCCL_VERSION} CACHE STRING "Version of NCCL to build with") + +if ($ENV{NCCL_ROOT_DIR}) + message(WARNING "NCCL_ROOT_DIR is deprecated. Please set NCCL_ROOT instead.") +endif() +list(APPEND NCCL_ROOT $ENV{NCCL_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR}) +# Compatible layer for CMake <3.12. NCCL_ROOT will be accounted in for searching paths and libraries for CMake >=3.12. +list(APPEND CMAKE_PREFIX_PATH ${NCCL_ROOT}) + +find_path(NCCL_INCLUDE_DIRS + NAMES nccl.h + HINTS ${NCCL_INCLUDE_DIR}) + +if (USE_STATIC_NCCL) + MESSAGE(STATUS "USE_STATIC_NCCL is set. Linking with static NCCL library.") + SET(NCCL_LIBNAME "nccl_static") + if (NCCL_VERSION) # Prefer the versioned library if a specific NCCL version is specified + set(CMAKE_FIND_LIBRARY_SUFFIXES ".a.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES}) + endif() +else() + SET(NCCL_LIBNAME "nccl") + if (NCCL_VERSION) # Prefer the versioned library if a specific NCCL version is specified + set(CMAKE_FIND_LIBRARY_SUFFIXES ".so.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES}) + endif() +endif() + +find_library(NCCL_LIBRARIES + NAMES ${NCCL_LIBNAME} + HINTS ${NCCL_LIB_DIR}) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS NCCL_LIBRARIES) + +if(NCCL_FOUND) # obtaining NCCL version and some sanity checks + set (NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h") + message (STATUS "Determining NCCL version from ${NCCL_HEADER_FILE}...") + set (OLD_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES}) + list (APPEND CMAKE_REQUIRED_INCLUDES ${NCCL_INCLUDE_DIRS}) + include(CheckCXXSymbolExists) + check_cxx_symbol_exists(NCCL_VERSION_CODE nccl.h NCCL_VERSION_DEFINED) + + if (NCCL_VERSION_DEFINED) + set(file "${PROJECT_BINARY_DIR}/detect_nccl_version.cc") + file(WRITE ${file} " + #include + #include + int main() + { + std::cout << NCCL_MAJOR << '.' << NCCL_MINOR << '.' << NCCL_PATCH << std::endl; + int x; + ncclGetVersion(&x); + return x == NCCL_VERSION_CODE; + } +") + try_run(NCCL_VERSION_MATCHED compile_result ${PROJECT_BINARY_DIR} ${file} + RUN_OUTPUT_VARIABLE NCCL_VERSION_FROM_HEADER + CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${NCCL_INCLUDE_DIRS}" + LINK_LIBRARIES ${NCCL_LIBRARIES}) + if (NOT NCCL_VERSION_MATCHED) + message(FATAL_ERROR "Found NCCL header version and library version do not match! \ +(include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES}) Please set NCCL_INCLUDE_DIR and NCCL_LIB_DIR manually.") + endif() + message(STATUS "NCCL version: ${NCCL_VERSION_FROM_HEADER}") + else() + # message(STATUS "NCCL version < 2.3.5-5") + endif () + set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES}) + + message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})") + mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES) +endif() diff --git a/example b/examples/NNmodel similarity index 100% rename from example rename to examples/NNmodel diff --git a/examples/distributed/launch.py b/examples/distributed/launch.py new file mode 100644 index 00000000..362bde1c --- /dev/null +++ b/examples/distributed/launch.py @@ -0,0 +1,100 @@ +import argparse +import os +import time +import multiprocessing as mp +from pyinfinitensor.onnx import OnnxStub, backend +import onnx +import numpy as np +from parallel import parallel_model + + +def parse_args(): + parser = argparse.ArgumentParser(description="launch distributed infinitensor") + parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes") + parser.add_argument( + "--nproc_per_node", type=int, default=1, help="number of processes per node" + ) + parser.add_argument( + "--model", type=str, required=True, help="path to the ONNX model file." + ) + args = parser.parse_args() + print("arg setting: ", args) + return args.num_nodes, args.nproc_per_node, args.model + + +def run_stub(stub: OnnxStub, inputs: np.array, n=100): + # warm up + next(stub.inputs.items().__iter__())[1].copyin_float(inputs.reshape(-1).tolist()) + stub.tune() + for _ in range(20): + stub.run() + outputs = np.array(next(stub.outputs.items().__iter__())[1].copyout_float()) + + # bench + next(stub.inputs.items().__iter__())[1].copyin_float(inputs.reshape(-1).tolist()) + begin = time.time() + for _ in range(n): + stub.run() + end = time.time() + outputs = np.array(next(stub.outputs.items().__iter__())[1].copyout_float()) + print("outputs sum:", outputs.sum()) + # np.save("results", outputs) + results = np.load("results.npy") + print("max diff:", abs(outputs - results).max()) + assert np.allclose(outputs, results, rtol=1e-6, atol=1e-6) + avg_time = (end - begin) / n + return avg_time + + +def start_worker( + dist_name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto +): + print("start worker") + runtime = backend.CudaRuntime(local_rank) + print("init comm") + runtime.init_comm( + dist_name, + world_size, + rank, + ) + model = parallel_model(model, world_size, rank) + onnx.save(model, f"dist_model_rank{rank}.onnx") + print("load model") + stub = OnnxStub(model, runtime) + data = np.load("inputs.npy") + print("run model") + avg_time = run_stub(stub, data) + print(f"average time: {avg_time}") + + +def main(): + nnodes, nproc_per_node, model_path = parse_args() + world_size = nnodes * nproc_per_node + + model = onnx.load(model_path) + # generate standard results + # runtime = backend.CudaRuntime(0) + # stub = OnnxStub(model, runtime) + # data = np.random.randn(1, 3, 224, 224) + # np.save("inputs", data) + # run_stub(stub, data) + # del stub + + dist_name = f"dist_{os.getpid()}" + workers = [ + mp.Process( + target=start_worker, + args=(dist_name, world_size, rank, rank % nproc_per_node, model), + ) + for rank in range(world_size) + ] + + for w in workers: + w.start() + + for w in workers: + w.join() + + +if __name__ == "__main__": + main() diff --git a/examples/distributed/parallel.py b/examples/distributed/parallel.py new file mode 100644 index 00000000..7d2b19a4 --- /dev/null +++ b/examples/distributed/parallel.py @@ -0,0 +1,103 @@ +import onnx +from onnx import ( + ModelProto, + TensorProto, + NodeProto, + AttributeProto, +) +from onnx import helper, numpy_helper +from typing import Dict, Any + + +def parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]: + for attr in node.attribute: + if attr.name in attrs: + if attr.type == AttributeProto.INT: + attrs[attr.name] = attr.i + elif attr.type == AttributeProto.INTS: + attrs[attr.name] = attr.ints + elif attr.type == AttributeProto.FLOAT: + attrs[attr.name] = attr.f + elif attr.type == AttributeProto.STRING: + attrs[attr.name] = attr.s + elif attr.type == AttributeProto.TENSOR: + attrs[attr.name] = attr.t + else: + assert False, "Unsupported Attribute Type: {}".format(attr.type) + return attrs + + +def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0): + data = {init.name: init for init in model.graph.initializer} + nodes = list(model.graph.node) + + def shard_tensor(tensor: TensorProto, dim: int): + array = numpy_helper.to_array(tensor) + if dim >= array.ndim: + dim = array.ndim - 1 + assert array.shape[dim] % tp_world_size == 0 + seg = array.shape[dim] // tp_world_size + array = array[tp_rank * seg : (tp_rank + 1) * seg] + return numpy_helper.from_array(array, name=tensor.name + f":sharded({dim})") + + def shard_gemm(node: NodeProto): + attrs = parse_attribute( + node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0} + ) + trans = [attrs["transA"], attrs["transB"]] + dim = 0 + for i, (input, t) in enumerate(zip(node.input, trans)): + if input in data: + dim = i + sharded = shard_tensor(data[input], dim ^ t) + node.input[i] = sharded.name + data[input] = sharded + if len(node.input) > 2: + input = node.input[2] + sharded = shard_tensor(data[input], dim) + node.input[2] = sharded.name + data[input] = sharded + + node.output[0] += f":sharded({dim})" + return dim + + for i, node in enumerate(nodes): + if node.op_type == "Gemm": + output = node.output[0] + dim = shard_gemm(node) + gathered = [node.output[0] + f".{i}" for i in range(tp_world_size)] + # all_gather + nodes.insert( + i + 1, + helper.make_node( + op_type="AllGather", + inputs=[node.output[0]], + outputs=gathered, + name=node.name + "/allgather", + # domain="infini", # shape inference fails for custom domain + ), + ) + # concat + nodes.insert( + i + 2, + helper.make_node( + op_type="Concat", + inputs=gathered, + outputs=[output], + name=node.name + "/concat", + axis=dim, + ), + ) + graph = helper.make_graph( + nodes, + model.graph.name + f"_{tp_rank}", + model.graph.input, + model.graph.output, + data.values(), + doc_string=model.graph.doc_string, + value_info=model.graph.value_info, + ) + model = helper.make_model(graph) + + onnx.shape_inference.infer_shapes(model) + return model diff --git a/include/core/communicator.h b/include/core/communicator.h new file mode 100644 index 00000000..9cc958d7 --- /dev/null +++ b/include/core/communicator.h @@ -0,0 +1,22 @@ +#pragma once +#include "object.h" +#include "ref.h" + +namespace infini { + +// base class +class CommunicatorObj : public Object { + protected: + int worldSize; + int rank; + + public: + CommunicatorObj(int worldSize, int rank) + : worldSize(worldSize), rank(rank) {} + + virtual ~CommunicatorObj() = default; + virtual int getWorldSize() const { return worldSize; } + virtual int getRank() const { return rank; } +}; + +} // namespace infini diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index 49fa2347..7f514ebd 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -74,6 +74,14 @@ class GraphHandlerObj { Tensor expand(Tensor input, Tensor output, Shape dims); Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output); + Tensor allReduceSum(Tensor input, Tensor output); + Tensor allReduceProd(Tensor input, Tensor output); + Tensor allReduceMin(Tensor input, Tensor output); + Tensor allReduceMax(Tensor input, Tensor output); + Tensor allReduceAvg(Tensor input, Tensor output); + TensorVec allGather(Tensor input, std::optional outputs, int n); + Tensor broadcast(Tensor input, Tensor output, int root); + //------ modifiers inline bool topo_sort() { return g->topo_sort(); } diff --git a/include/core/op_type.h b/include/core/op_type.h index a5ea2524..e0146c5f 100644 --- a/include/core/op_type.h +++ b/include/core/op_type.h @@ -221,6 +221,15 @@ struct OpType { FloorMod, Square, SquaredDifference, + + // Communication Ops + AllReduceSum, + AllReduceProd, + AllReduceMin, + AllReduceMax, + AllReduceAvg, + AllGather, + Broadcast, } type; constexpr OpType(decltype(type) t) : type(t) {} diff --git a/include/core/runtime.h b/include/core/runtime.h index 2fe0467c..bd9da89a 100644 --- a/include/core/runtime.h +++ b/include/core/runtime.h @@ -1,5 +1,6 @@ #pragma once #include "core/common.h" +#include "core/communicator.h" #include "core/op_type.h" #include "core/ref.h" #include @@ -35,9 +36,11 @@ enum class Device { CPU = 1, CUDA, BANG, INTELCPU }; class RuntimeObj : public std::enable_shared_from_this { protected: Device device; + int deviceId; public: - RuntimeObj(Device device) : device(device) {} + explicit RuntimeObj(Device device, int deviceId = 0) + : device(device), deviceId(deviceId) {} RuntimeObj(RuntimeObj &other) = delete; RuntimeObj &operator=(RuntimeObj const &) = delete; virtual ~RuntimeObj() {} @@ -77,6 +80,12 @@ class RuntimeObj : public std::enable_shared_from_this { size_t bytes) const = 0; virtual string toString() const = 0; + int getDeviceId() const { return deviceId; } + + virtual void initComm(const string &name, int worldSize, int rank) = 0; + + virtual CommunicatorObj &getCommunicator() const = 0; + protected: void printProfilingData(double totTime, const std::map &opTime, @@ -97,6 +106,9 @@ class CpuRuntimeObj : public RuntimeObj { void copyBlobToCPU(void *dst, const void *src, size_t bytes) const override; void copyBlobInsideRuntime(void *dst, const void *src, size_t bytes) const override; + void initComm(const string &, int, int) override { IT_TODO_HALT(); } + + CommunicatorObj &getCommunicator() const override { IT_TODO_HALT(); } }; class NativeCpuRuntimeObj : public CpuRuntimeObj { diff --git a/include/cuda/cuda_runtime.h b/include/cuda/cuda_runtime.h index b5830454..19fd9fc8 100644 --- a/include/cuda/cuda_runtime.h +++ b/include/cuda/cuda_runtime.h @@ -1,6 +1,9 @@ #pragma once #include "core/runtime.h" #include "cuda/cuda_common.h" +#ifdef INFINI_USE_NCCL +#include "cuda/nccl_communicator.h" +#endif namespace infini { @@ -8,12 +11,15 @@ class CudaRuntimeObj : public RuntimeObj { private: cudnnHandle_t cudnn; cublasHandle_t cublas; + std::unique_ptr comm; CudaPtr workspace; size_t workspaceSize; public: - CudaRuntimeObj() : RuntimeObj(Device::CUDA) { + explicit CudaRuntimeObj(int deviceId = 0) + : RuntimeObj(Device::CUDA, deviceId) { + checkCudaError(cudaSetDevice(deviceId)); checkCudnnError(cudnnCreate(&cudnn)); checkCublasError(cublasCreate(&cublas)); // 10GB for Longformer @@ -69,6 +75,11 @@ class CudaRuntimeObj : public RuntimeObj { void runWithoutSync(const Graph &graph) const; + // init communicator + void initComm(const string &name, int worldSize, int rank) final; + + CommunicatorObj &getCommunicator() const final { return *comm; } + private: void tune(const Graph &graph, bool profiling) const; }; diff --git a/include/cuda/nccl_communicator.h b/include/cuda/nccl_communicator.h new file mode 100644 index 00000000..dcef2830 --- /dev/null +++ b/include/cuda/nccl_communicator.h @@ -0,0 +1,70 @@ +#pragma once +#include "core/communicator.h" +#include +#include +#include +#include +#include +#include + +#define checkNcclError(call) \ + { \ + auto err = call; \ + if (ncclSuccess != err) { \ + fprintf(stderr, "NCCL error in %s:%i : %s.\n", __FILE__, __LINE__, \ + ncclGetErrorString(err)); \ + exit(EXIT_FAILURE); \ + } \ + } + +namespace infini { + +class NcclCommunicatorObj final : public CommunicatorObj { + private: + ncclComm_t comm; + + public: + NcclCommunicatorObj(const string &name, int worldSize, int rank) + : CommunicatorObj(worldSize, rank) { + const std::string filePath("./" + name + "_nccl_id.bin"); + ncclUniqueId commId; + if (rank == 0) { + checkNcclError(ncclGetUniqueId(&commId)); + std::ofstream ofs(filePath, std::ios::binary); + ofs.write((char *)&commId, sizeof(ncclUniqueId)); + + } else { + auto begin = std::chrono::steady_clock::now(); + while (!std::filesystem::exists(filePath)) { + auto now = std::chrono::steady_clock::now(); + _IT_ASSERT_2(now < begin + std::chrono::seconds(10), + "time limit (10s) exceeded."); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + std::ifstream ifs(filePath, std::ios::binary); + ifs.read((char *)&commId, sizeof(ncclUniqueId)); + } + checkNcclError(ncclCommInitRank(&comm, worldSize, commId, rank)); + if (rank == 0) { + std::filesystem::remove(filePath); + } + } + + // Get the actual ncclComm_t + ncclComm_t getNcclComm() { return comm; } + + void finalize() { checkNcclError(ncclCommFinalize(comm)); } + + ~NcclCommunicatorObj() final { + finalize(); + checkNcclError(ncclCommDestroy(comm)); + } + + virtual string toString() const final { + std::ostringstream oss; + oss << "NCCL communicator"; + return oss.str(); + } +}; + +} // namespace infini diff --git a/include/operators/all_gather.h b/include/operators/all_gather.h new file mode 100644 index 00000000..423974f6 --- /dev/null +++ b/include/operators/all_gather.h @@ -0,0 +1,44 @@ +#pragma once +#include "core/operator.h" + +namespace infini { +/** + * @brief The AllGather operation gathers N values from k ranks into + * an output of size k*N, and distributes that result to all ranks. + * The output is ordered by rank index. + * + * For more details: + * https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather + */ +class AllGatherObj : public OperatorObj { + + public: + /** + * @brief Construct a new AllGather object + * + * @param graph The computation graph that this operator belongs to. + * @param input The input tensor from this rank. + * @param outputs A list of output tensors collected from all ranks. + * @param world_size Total number of ranks. + */ + AllGatherObj(GraphObj *graph, Tensor input, std::optional, + int world_size); + OP_CLONE(AllGatherObj); + + int numInputs() const override { return 1; } + int numOutputs() const override { return world_size; } + optional> inferShape(const TensorVec &inputs) const override; + + std::string toString() const override; + + int getWorldSize() const { return world_size; } + + private: + vector getWorkloadVector() const override; + vector getOpAttrVector() const override; + vector inferDataType(const TensorVec &inputs) const override; + + protected: + int world_size; +}; +} // namespace infini diff --git a/include/operators/all_reduce.h b/include/operators/all_reduce.h new file mode 100644 index 00000000..f91b3ad1 --- /dev/null +++ b/include/operators/all_reduce.h @@ -0,0 +1,75 @@ +#pragma once +#include "core/operator.h" + +namespace infini { +/** + * @brief The AllReduce operation is performing reductions on data (sum, min, + * max, avg, or div) across devices and writing the result in the + * receive buffers of every rank. For example, in an allreduce operation between + * k ranks and performing a sum, each rank will provide an array Vk of N values, + * and receive an identical arrays S of N values, where S[i] = + * V0[i]+V1[i]+…+Vk-1[i]. + * + * For more details: + * https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce + */ +class AllReduceBaseObj : public OperatorObj { + + public: + /** + * @brief Construct a new AllReduce base object. Should be called by every + * child class constructor, but not directly. + * + * @param graph The computation graph that this operator belongs to. + * @param opType The operation type. This param is taken care of by child + * classes. + * @param input The input tensor from this rank. + * @param output The output tensor, same size as input. + */ + AllReduceBaseObj(GraphObj *graph, OpType opType, Tensor input, + Tensor output); + OP_CLONE(AllReduceBaseObj); + + int numInputs() const override { return 1; } + int numOutputs() const override { return 1; } + + optional> inferShape(const TensorVec &inputs) const override { + return {{inputs[0]->getDims()}}; + }; + + std::string toString() const override; + + private: + vector getWorkloadVector() const override; + vector getOpAttrVector() const override; + vector inferDataType(const TensorVec &inputs) const override { + return {inputs[0]->getDType()}; + }; +}; + +class AllReduceSumObj : public AllReduceBaseObj { + public: + AllReduceSumObj(GraphObj *graph, Tensor input, Tensor output); +}; + +class AllReduceProdObj : public AllReduceBaseObj { + public: + AllReduceProdObj(GraphObj *graph, Tensor input, Tensor output); +}; + +class AllReduceMinObj : public AllReduceBaseObj { + public: + AllReduceMinObj(GraphObj *graph, Tensor input, Tensor output); +}; + +class AllReduceMaxObj : public AllReduceBaseObj { + public: + AllReduceMaxObj(GraphObj *graph, Tensor input, Tensor output); +}; + +class AllReduceAvgObj : public AllReduceBaseObj { + public: + AllReduceAvgObj(GraphObj *graph, Tensor input, Tensor output); +}; + +} // namespace infini diff --git a/include/operators/broadcast.h b/include/operators/broadcast.h new file mode 100644 index 00000000..1a15b770 --- /dev/null +++ b/include/operators/broadcast.h @@ -0,0 +1,49 @@ +#pragma once +#include "core/operator.h" + +namespace infini { +/** + * @brief The Broadcast operation copies an N-element buffer on the root rank to + * all ranks. + * + * For more details: + * https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#broadcast + */ +class BroadcastObj : public OperatorObj { + public: + /** + * @brief Construct a new Broadcast object. + * + * @param graph The computation graph that this operator belongs to. + * @param input The input tensor. Only root needs to initialize it with + * data. + * @param output The output tensor, same size as input. + * @param root The root rank who performs the broadcast. + */ + BroadcastObj(GraphObj *graph, Tensor input, Tensor output, int root); + OP_CLONE(BroadcastObj); + + int numInputs() const override { return 1; } + int numOutputs() const override { return 1; } + + optional> inferShape(const TensorVec &inputs) const override { + return {{inputs[0]->getDims()}}; + }; + + std::string toString() const override; + + int getRoot() const { return root; } + + private: + vector getWorkloadVector() const override; + vector getOpAttrVector() const override; + vector inferDataType(const TensorVec &inputs) const override { + return {inputs[0]->getDType()}; + }; + + protected: + // The rank who broadcasts data among this communication group + int root; +}; + +} // namespace infini diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 9fba35c4..b0a433ba 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -591,6 +591,54 @@ class OnnxStub: tensors.get(node.output[0]), next((attr.i for attr in node.attribute if attr.name == "to")), ) + elif node.op_type == "AllReduceSum": + tensors[node.output[0]] = self.handler.allReduceSum( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "AllReduceProd": + tensors[node.output[0]] = self.handler.allReduceProd( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "AllReduceMin": + tensors[node.output[0]] = self.handler.allReduceMin( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "AllReduceMax": + tensors[node.output[0]] = self.handler.allReduceMax( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "AllReduceAvg": + tensors[node.output[0]] = self.handler.allReduceAvg( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "AllGather": + for name, tensor in zip( + node.output, + self.handler.allGather( + tensors[node.input[0]], + None, + len(node.output), + ), + ): + tensors[name] = tensor + elif node.op_type == "Broadcast": + tensors[node.output[0]] = self.handler.broadcast( + tensors[node.input[0]], + tensors.get(node.output[0]), + next( + ( + attr.i + for attr in node.attribute + if attr.name == "root" + ), + 0, + ), + ) elif node.op_type == "Expand": shape = _parse_data(data[node.input[1]]) tensors[node.output[0]] = self.handler.expand( diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index 3fdb5f06..884bd874 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -329,6 +329,83 @@ class TestStringMethods(unittest.TestCase): [pads_data], ) ) + + def test_allReduceSum(self): + input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4]) + output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4]) + allReduceSum = make_node( + "AllReduceSum", ["input"], ["output"], name="allReduceSum" + ) + graph = make_graph([allReduceSum], "allReduceSum", [input], [output]) + model = make_model(graph) + from_onnx(model, backend.cpu_runtime()) + + def test_allReduceProd(self): + input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4]) + output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4]) + allReduceProd = make_node( + "AllReduceProd", ["input"], ["output"], name="allReduceProd" + ) + graph = make_graph([allReduceProd], "allReduceProd", [input], [output]) + model = make_model(graph) + from_onnx(model, backend.cpu_runtime()) + + def test_allReduceMin(self): + input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4]) + output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4]) + allReduceMin = make_node( + "AllReduceMin", ["input"], ["output"], name="allReduceMin" + ) + graph = make_graph([allReduceMin], "allReduceMin", [input], [output]) + model = make_model(graph) + from_onnx(model, backend.cpu_runtime()) + + def test_allReduceMax(self): + input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4]) + output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4]) + allReduceMax = make_node( + "AllReduceMax", ["input"], ["output"], name="allReduceMax" + ) + graph = make_graph([allReduceMax], "allReduceMax", [input], [output]) + model = make_model(graph) + from_onnx(model, backend.cpu_runtime()) + + def test_allReduceAvg(self): + input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4]) + output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4]) + allReduceAvg = make_node( + "AllReduceAvg", ["input"], ["output"], name="allReduceAvg" + ) + graph = make_graph([allReduceAvg], "allReduceAvg", [input], [output]) + model = make_model(graph) + from_onnx(model, backend.cpu_runtime()) + + def test_split(self): + input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4]) + split = make_node( + "Split", ["input"], ["output"], name="split", axis=0 + ) + make_and_import_model(make_graph([split], "split", [input], [])) + + def test_allBroadcast(self): + input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4]) + output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4]) + broadcast = make_node( + "Broadcast", ["input"], ["output"], name="broadcast", root=1 + ) + graph = make_graph([broadcast], "broadcast", [input], [output]) + model = make_model(graph) + from_onnx(model, backend.cpu_runtime()) + + def test_allGather(self): + input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4]) + world_size = make_tensor_value_info("world_size", TensorProto.INT32, [1]) + allGather = make_node( + "AllGather", ["input", "world_size"], ["output"], name="allGather" + ) + graph = make_graph([allGather], "allGather", [input, world_size], []) + model = make_model(graph) + from_onnx(model, backend.cpu_runtime()) # see def test_linear(self): diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index 87ed6f46..a804a8c7 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -1,5 +1,8 @@ #include "core/graph_handler.h" +#include "operators/all_gather.h" +#include "operators/all_reduce.h" #include "operators/batch_norm.h" +#include "operators/broadcast.h" #include "operators/concat.h" #include "operators/conv.h" #include "operators/element_wise.h" @@ -300,6 +303,73 @@ Tensor GraphHandlerObj::pad(Tensor input, Tensor output, } } +Tensor GraphHandlerObj::allReduceSum(Tensor input, Tensor output) { + if (output) { + g->addOpWithOutputs(std::move(input), output); + return output; + } else { + return g->addOp(std::move(input), output)->getOutput(); + } +} + +Tensor GraphHandlerObj::allReduceProd(Tensor input, Tensor output) { + if (output) { + g->addOpWithOutputs(std::move(input), output); + return output; + } else { + return g->addOp(std::move(input), output) + ->getOutput(); + } +} + +Tensor GraphHandlerObj::allReduceMin(Tensor input, Tensor output) { + if (output) { + g->addOpWithOutputs(std::move(input), output); + return output; + } else { + return g->addOp(std::move(input), output)->getOutput(); + } +} + +Tensor GraphHandlerObj::allReduceMax(Tensor input, Tensor output) { + if (output) { + g->addOpWithOutputs(std::move(input), output); + return output; + } else { + return g->addOp(std::move(input), output)->getOutput(); + } +} + +Tensor GraphHandlerObj::allReduceAvg(Tensor input, Tensor output) { + if (output) { + g->addOpWithOutputs(std::move(input), output); + return output; + } else { + return g->addOp(std::move(input), output)->getOutput(); + } +} + +TensorVec GraphHandlerObj::allGather(Tensor input, + std::optional outputs, int n) { + if (outputs) { + g->addOpWithOutputs(std::move(input), outputs, n); + return *outputs; + } else { + return g->addOp(std::move(input), outputs, n) + ->getOutputs(); + } +} + +Tensor GraphHandlerObj::broadcast(Tensor input, Tensor output, int root) { + if (output) { + g->addOpWithOutputs(std::move(input), output, root); + return output; + } else { + return g->addOp(std::move(input), output, root) + ->getOutput(); + } +} + Tensor GraphHandlerObj::cast(Tensor input, Tensor output, int to) { if (output) { g->addOpWithOutputs(std::move(input), output, diff --git a/src/core/op_type.cc b/src/core/op_type.cc index f12c4a9b..38122bf9 100644 --- a/src/core/op_type.cc +++ b/src/core/op_type.cc @@ -214,6 +214,15 @@ const char *OpType::toString() const { CASE(FloorMod); CASE(Square); CASE(SquaredDifference); + + // Communcation + CASE(AllReduceSum); + CASE(AllReduceProd); + CASE(AllReduceMin); + CASE(AllReduceMax); + CASE(AllReduceAvg); + CASE(AllGather); + CASE(Broadcast); default: return "Unknown"; } diff --git a/src/cuda/cuda_runtime.cc b/src/cuda/cuda_runtime.cc index a8051a91..927b1f0d 100644 --- a/src/cuda/cuda_runtime.cc +++ b/src/cuda/cuda_runtime.cc @@ -2,6 +2,9 @@ #include "core/kernel.h" #include "core/perf_engine.h" #include "core/runtime.h" +#ifdef INFINI_USE_NCCL +#include "cuda/nccl_communicator.h" +#endif #include "operators/conv.h" #include "operators/matmul.h" @@ -96,4 +99,15 @@ void CudaRuntimeObj::sync() const { checkCudaError(cudaDeviceSynchronize()); } string CudaRuntimeObj::toString() const { return "CUDA Runtime"; } +void CudaRuntimeObj::initComm(const string &name, int worldSize, int rank) { + IT_ASSERT(worldSize > 0); + IT_ASSERT(rank >= 0); + IT_ASSERT(rank < worldSize); +#ifdef INFINI_USE_NCCL + comm = std::make_unique(name, worldSize, rank); +#else + IT_TODO_HALT_MSG("Not compiled with NCCL."); +#endif +} + } // namespace infini diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index efe047da..f6af18ec 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -143,7 +143,10 @@ static int tensor_dtype(Tensor t) { } #ifdef USE_CUDA -static Ref cuda_runtime() { return make_ref(); } +// NOTE(lizhouyang): deprecate this, use CudaRuntime directly. +[[deprecated]] static Ref cuda_runtime() { + return make_ref(0); +} #endif #ifdef USE_BANG @@ -311,7 +314,9 @@ void init_graph_builder(py::module &m) { RuntimeObj>(m, "CpuRuntime"); #ifdef USE_CUDA py::class_, RuntimeObj>( - m, "CudaRuntime"); + m, "CudaRuntime") + .def(py::init(), py::arg("device") = 0) + .def("init_comm", &CudaRuntimeObj::initComm); #endif #ifdef USE_BANG py::class_, RuntimeObj>( @@ -435,6 +440,13 @@ void init_graph_builder(py::module &m) { .def("reduce_mean", &Handler::reduceMean, policy::move) .def("slice", &Handler::slice, policy::move) .def("pad", &Handler::pad, policy::move) + .def("allReduceSum", &Handler::allReduceSum, policy::move) + .def("allReduceProd", &Handler::allReduceProd, policy::move) + .def("allReduceMin", &Handler::allReduceMin, policy::move) + .def("allReduceMax", &Handler::allReduceMax, policy::move) + .def("allReduceAvg", &Handler::allReduceAvg, policy::move) + .def("allGather", &Handler::allGather, policy::move) + .def("broadcast", &Handler::broadcast, policy::move) .def("cast", &Handler::cast, policy::move) .def("expand", &Handler::expand, policy::move) .def("erf", &Handler::erf, policy::move) diff --git a/src/kernels/cuda/all_gather.cc b/src/kernels/cuda/all_gather.cc new file mode 100644 index 00000000..187aea5c --- /dev/null +++ b/src/kernels/cuda/all_gather.cc @@ -0,0 +1,46 @@ +#ifdef INFINI_USE_NCCL +#include "operators/all_gather.h" +#include "cuda/cuda_kernel_wihtout_config.h" +#include "cuda/cuda_runtime.h" +#include "cuda/nccl_communicator.h" + +namespace infini { +class AllGatherNCCL : public CudaKernelWithoutConfig { + public: + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + int world_size = op->getWorldSize(); + // Check if world size info in operator matches runtime + IT_ASSERT(world_size == context->getCommunicator().getWorldSize()); + + void *input = op->getInputs(0)->getRawDataPtr(); + CudaPtr output_temp = + context->getWorkspace(op->getInputs(0)->getBytes() * world_size); + // void *output = op->getOutput()->getRawDataPtr(); + IT_ASSERT(op->getDType() == DataType::Float32); + size_t bytes = op->getInputs(0)->getBytes(); + size_t count = bytes / op->getDType().getSize(); + + ncclComm_t comm = + dynamic_cast(context->getCommunicator()) + .getNcclComm(); + // TODO: Using default stream 0 for now. + checkNcclError( + ncclAllGather(input, output_temp, count, ncclFloat, comm, 0)); + + for (int i = 0; i < world_size; ++i) { + Tensor output = op->getOutput(i); + context->copyBlobInsideRuntime( + output->getRawDataPtr(), + static_cast(output_temp) + i * count, bytes); + } + } +}; + +REGISTER_KERNEL(Device::CUDA, OpType::AllGather, DataType::Float32, + AllGatherNCCL, "AllGather_NCCL_CUDA_Float32"); +} // namespace infini + +#endif diff --git a/src/kernels/cuda/all_reduce.cc b/src/kernels/cuda/all_reduce.cc new file mode 100644 index 00000000..2728b5e2 --- /dev/null +++ b/src/kernels/cuda/all_reduce.cc @@ -0,0 +1,58 @@ +#ifdef INFINI_USE_NCCL +#include "operators/all_reduce.h" +#include "cuda/cuda_kernel_wihtout_config.h" +#include "cuda/cuda_runtime.h" +#include "cuda/nccl_communicator.h" + +namespace infini { +class AllReduceNCCL : public CudaKernelWithoutConfig { + public: + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + void *input = op->getInputs(0)->getRawDataPtr(); + void *output = op->getOutput()->getRawDataPtr(); + IT_ASSERT(op->getDType() == DataType::Float32); + size_t count = op->getInputs(0)->getBytes() / op->getDType().getSize(); + + ncclComm_t comm = + dynamic_cast(context->getCommunicator()) + .getNcclComm(); + // TODO: Using default stream 0 for now. + checkNcclError(ncclAllReduce(input, output, count, ncclFloat, + getRedOp(), comm, 0)); + } + + virtual ncclRedOp_t getRedOp() const = 0; +}; + +class AllReduceSumNCCL : public AllReduceNCCL { + ncclRedOp_t getRedOp() const override { return ncclSum; } +}; +class AllReduceProdNCCL : public AllReduceNCCL { + ncclRedOp_t getRedOp() const override { return ncclProd; } +}; +class AllReduceMinNCCL : public AllReduceNCCL { + ncclRedOp_t getRedOp() const override { return ncclMin; } +}; +class AllReduceMaxNCCL : public AllReduceNCCL { + ncclRedOp_t getRedOp() const override { return ncclMax; } +}; +class AllReduceAvgNCCL : public AllReduceNCCL { + ncclRedOp_t getRedOp() const override { return ncclAvg; } +}; + +REGISTER_KERNEL(Device::CUDA, OpType::AllReduceSum, DataType::Float32, + AllReduceSumNCCL, "AllReduce_Sum_NCCL_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::AllReduceProd, DataType::Float32, + AllReduceProdNCCL, "AllReduce_Prod_NCCL_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::AllReduceMin, DataType::Float32, + AllReduceMinNCCL, "AllReduce_Min_NCCL_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::AllReduceMax, DataType::Float32, + AllReduceMaxNCCL, "AllReduce_Max_NCCL_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::AllReduceAvg, DataType::Float32, + AllReduceAvgNCCL, "AllReduce_Avg_NCCL_CUDA_Float32"); + +} // namespace infini +#endif diff --git a/src/kernels/cuda/broadcast.cc b/src/kernels/cuda/broadcast.cc new file mode 100644 index 00000000..79190491 --- /dev/null +++ b/src/kernels/cuda/broadcast.cc @@ -0,0 +1,32 @@ +#ifdef INFINI_USE_NCCL +#include "operators/broadcast.h" +#include "cuda/cuda_kernel_wihtout_config.h" +#include "cuda/cuda_runtime.h" +#include "cuda/nccl_communicator.h" + +namespace infini { +class BroadcastNCCL : public CudaKernelWithoutConfig { + public: + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + void *input = op->getInputs(0)->getRawDataPtr(); + void *output = op->getOutput()->getRawDataPtr(); + IT_ASSERT(op->getDType() == DataType::Float32); + size_t count = op->getInputs(0)->getBytes() / op->getDType().getSize(); + + ncclComm_t comm = + dynamic_cast(context->getCommunicator()) + .getNcclComm(); + // TODO: Using default stream 0 for now. + checkNcclError(ncclBroadcast(input, output, count, ncclFloat, + op->getRoot(), comm, 0)); + } +}; + +REGISTER_KERNEL(Device::CUDA, OpType::Broadcast, DataType::Float32, + BroadcastNCCL, "Broadcast_NCCL_CUDA_Float32"); +} // namespace infini + +#endif diff --git a/src/operators/all_gather.cc b/src/operators/all_gather.cc new file mode 100644 index 00000000..127c3b8d --- /dev/null +++ b/src/operators/all_gather.cc @@ -0,0 +1,49 @@ +#include "operators/all_gather.h" + +namespace infini { +AllGatherObj::AllGatherObj(GraphObj *graph, Tensor input, + std::optional outputs, int world_size) + : OperatorObj( + OpType::AllGather, {input}, + ((!outputs) ? TensorVec(world_size, nullptr) : std::move(*outputs))), + world_size(world_size) { + IT_ASSERT(checkValid(graph)); +} + +optional> +AllGatherObj::inferShape(const TensorVec &inputs) const { + Shape input_shape = inputs[0]->getDims(); + vector output_shapes(getWorldSize(), input_shape); + return output_shapes; +} + +vector AllGatherObj::inferDataType(const TensorVec &inputs) const { + return vector(world_size, inputs[0]->getDType()); +} + +std::string AllGatherObj::toString() const { + std::ostringstream os; + os << "AllGather" + << "[" << getGuid() << "]"; + os << "("; + os << vecToString(inputs[0]->getDims()) << ","; + os << "input=" << inputs[0]->getGuid() << ","; + os << "output="; + for (auto i = 0; i < world_size; i++) + os << outputs[i]->getGuid() << ","; + os << ")"; + return os.str(); +} + +vector AllGatherObj::getWorkloadVector() const { + vector ret{type.underlying()}; + const Shape shape = inputs[0]->getDims(); + ret.insert(ret.end(), shape.begin(), shape.end()); + ret.emplace_back(world_size); + return ret; +} + +vector AllGatherObj::getOpAttrVector() const { + return {type.underlying(), world_size}; +} +} // namespace infini diff --git a/src/operators/all_reduce.cc b/src/operators/all_reduce.cc new file mode 100644 index 00000000..7b1b6134 --- /dev/null +++ b/src/operators/all_reduce.cc @@ -0,0 +1,45 @@ +#include "operators/all_reduce.h" + +namespace infini { +AllReduceBaseObj::AllReduceBaseObj(GraphObj *graph, OpType opType, Tensor input, + Tensor output) + : OperatorObj(opType, {input}, {output}) { + IT_ASSERT(checkValid(graph)); +} + +std::string AllReduceBaseObj::toString() const { + std::ostringstream os; + os << type.toString() << "[" << getGuid() << "]"; + os << "("; + os << vecToString(inputs[0]->getDims()) << ","; + os << "input=" << inputs[0]->getGuid() << ","; + os << "output=" << outputs[0]->getGuid() << ","; + return os.str(); +} + +vector AllReduceBaseObj::getWorkloadVector() const { + vector ret{type.underlying()}; + const Shape shape = outputs[0]->getDims(); + ret.insert(ret.end(), shape.begin(), shape.end()); + return ret; +} + +vector AllReduceBaseObj::getOpAttrVector() const { + return {type.underlying()}; +} + +AllReduceSumObj::AllReduceSumObj(GraphObj *graph, Tensor input, Tensor output) + : AllReduceBaseObj(graph, OpType::AllReduceSum, input, output) {} + +AllReduceProdObj::AllReduceProdObj(GraphObj *graph, Tensor input, Tensor output) + : AllReduceBaseObj(graph, OpType::AllReduceProd, input, output) {} + +AllReduceMinObj::AllReduceMinObj(GraphObj *graph, Tensor input, Tensor output) + : AllReduceBaseObj(graph, OpType::AllReduceMin, input, output) {} + +AllReduceMaxObj::AllReduceMaxObj(GraphObj *graph, Tensor input, Tensor output) + : AllReduceBaseObj(graph, OpType::AllReduceMax, input, output) {} + +AllReduceAvgObj::AllReduceAvgObj(GraphObj *graph, Tensor input, Tensor output) + : AllReduceBaseObj(graph, OpType::AllReduceAvg, input, output) {} +} // namespace infini diff --git a/src/operators/broadcast.cc b/src/operators/broadcast.cc new file mode 100644 index 00000000..2f9b7e69 --- /dev/null +++ b/src/operators/broadcast.cc @@ -0,0 +1,33 @@ +#include "operators/broadcast.h" + +namespace infini { +BroadcastObj::BroadcastObj(GraphObj *graph, Tensor input, Tensor output, + int root) + : OperatorObj(OpType::Broadcast, {input}, {output}), root(root) { + IT_ASSERT(checkValid(graph)); +} + +vector BroadcastObj::getWorkloadVector() const { + vector ret{type.underlying()}; + const Shape shape = inputs[0]->getDims(); + ret.insert(ret.end(), shape.begin(), shape.end()); + return ret; +} + +vector BroadcastObj::getOpAttrVector() const { + return {type.underlying()}; +} + +std::string BroadcastObj::toString() const { + std::ostringstream os; + os << "Broadcast" + << "[" << getGuid() << "]"; + os << "("; + os << vecToString(inputs[0]->getDims()) << ","; + os << "input=" << inputs[0]->getGuid() << ","; + os << "output=" << outputs[0]->getGuid() << ","; + os << "root=" << root; + os << ")"; + return os.str(); +} +} // namespace infini diff --git a/test/cuda/test_nccl_comm.cc b/test/cuda/test_nccl_comm.cc new file mode 100644 index 00000000..b8a92a14 --- /dev/null +++ b/test/cuda/test_nccl_comm.cc @@ -0,0 +1,55 @@ +#ifdef INFINI_USE_NCCL +#include "cuda/cuda_runtime.h" +#include "cuda/nccl_communicator.h" +#include "test.h" + +static int WORLD_SIZE = 2; + +namespace infini { + +void allReduceSum(float *data, int deviceId) { + // Create Runtime and setup communication + CudaRuntimeObj *cuda_runtime = new CudaRuntimeObj(deviceId); + int rank = deviceId; + cuda_runtime->initComm("test_nccl_comm", WORLD_SIZE, rank); + ncclComm_t comm = + dynamic_cast(cuda_runtime->getCommunicator()) + .getNcclComm(); + + // Copy data + float *data_gpu; + checkCudaError(cudaMalloc(&data_gpu, sizeof(float))); + checkCudaError( + cudaMemcpy(data_gpu, data, sizeof(float), cudaMemcpyHostToDevice)); + + // Do AllReduce + checkNcclError( + ncclAllReduce(data_gpu, data_gpu, 1, ncclFloat, ncclSum, comm, 0)); + + // Copy data back and sync device + checkCudaError( + cudaMemcpy(data, data_gpu, sizeof(float), cudaMemcpyDeviceToHost)); + checkCudaError(cudaDeviceSynchronize()); +} + +// Setup communication between 2 threads, each controlling 1 GPU. +// Do AllReduce Sum on {1.0, 4.0}. Results should be {5.0, 5.0}. +TEST(NCCL, multi_gpu_communication) { + int num_threads = WORLD_SIZE; + float data[] = {1.0, 4.0}; + + std::vector threads; + for (int gpu = 0; gpu < num_threads; ++gpu) { + threads.emplace_back(allReduceSum, &data[gpu], gpu); + } + for (auto &thread : threads) { + thread.join(); + } + + for (int i = 0; i < num_threads; ++i) { + ASSERT_EQ(data[i], 5.0f); + } +} + +} // namespace infini +#endif diff --git a/test/kernels/cuda/test_cuda_all_gather.cc b/test/kernels/cuda/test_cuda_all_gather.cc new file mode 100644 index 00000000..c5c43206 --- /dev/null +++ b/test/kernels/cuda/test_cuda_all_gather.cc @@ -0,0 +1,51 @@ +#ifdef INFINI_USE_NCCL +#include "core/graph.h" +#include "core/runtime.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "operators/all_gather.h" +#include "test.h" +#include +#include + +static int WORLD_SIZE = 2; + +namespace infini { + +void allGather(const string taskName, int deviceID, vector data, + vector> ans) { + // Create Runtimes and initiate communication + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + Runtime cudaRuntime = make_ref(deviceID); + cudaRuntime->initComm(taskName, WORLD_SIZE, deviceID); + // Create Graph and insert allReduce operation + Graph g = make_ref(cudaRuntime); + auto input = + g->addTensor(Shape{static_cast(data.size())}, DataType::Float32); + auto op = g->addOp(input, std::nullopt, WORLD_SIZE); + // Copy data from CPU to GPU + g->dataMalloc(); + input->copyin(data); + // Run operation + cudaRuntime->run(g); + // Copy output from GPU to CPU + for (int i = 0; i < WORLD_SIZE; ++i) { + auto result = op->getOutputs()[i]->clone(cpuRuntime); + EXPECT_TRUE(result->equalData(ans[i])); + } +} + +TEST(CUDA_AllGather, run) { + vector data[2] = {{2., 3.}, {5., 6.}}; + vector> ans = {{2., 3.}, {5., 6.}}; + + std::vector threads; + for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) { + threads.emplace_back(allGather, "test_all_gather", gpu, data[gpu], ans); + } + for (auto &thread : threads) { + thread.join(); + } +} +} // namespace infini +#endif diff --git a/test/kernels/cuda/test_cuda_all_reduce.cc b/test/kernels/cuda/test_cuda_all_reduce.cc new file mode 100644 index 00000000..7140d10b --- /dev/null +++ b/test/kernels/cuda/test_cuda_all_reduce.cc @@ -0,0 +1,109 @@ +#ifdef INFINI_USE_NCCL +#include "core/graph.h" +#include "core/runtime.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "operators/all_reduce.h" +#include "test.h" +#include +#include + +static int WORLD_SIZE = 2; + +namespace infini { + +template +void allReduce(const string taskName, int deviceID, vector data, + vector ans) { + // Create Runtimes and initiate communication + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + Runtime cudaRuntime = make_ref(deviceID); + cudaRuntime->initComm(taskName, WORLD_SIZE, deviceID); + // Create Graph and insert allReduce operation + Graph g = make_ref(cudaRuntime); + auto input = + g->addTensor(Shape{static_cast(data.size())}, DataType::Float32); + auto op = g->addOp(input, nullptr); + // Copy data from CPU to GPU + g->dataMalloc(); + input->copyin(data); + // Run operation + cudaRuntime->run(g); + // Copy output from GPU to CPU + auto result = op->getOutput()->clone(cpuRuntime); + + EXPECT_TRUE(result->equalData(ans)); +} + +TEST(CUDA_AllReduce, sum) { + vector data[2] = {{2., 3.}, {5., 6.}}; + vector ans = {7., 9.}; + + std::vector threads; + for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) { + threads.emplace_back(allReduce, "test_allreduce_sum", + gpu, data[gpu], ans); + } + for (auto &thread : threads) { + thread.join(); + } +} + +TEST(CUDA_AllReduce, prod) { + vector data[2] = {{2., 3.}, {5., 6.}}; + vector ans = {10., 18.}; + + std::vector threads; + for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) { + threads.emplace_back(allReduce, "test_allreduce_prod", + gpu, data[gpu], ans); + } + for (auto &thread : threads) { + thread.join(); + } +} + +TEST(CUDA_AllReduce, min) { + vector data[2] = {{2., 3.}, {5., 6.}}; + vector ans = {2., 3.}; + + std::vector threads; + for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) { + threads.emplace_back(allReduce, "test_allreduce_min", + gpu, data[gpu], ans); + } + for (auto &thread : threads) { + thread.join(); + } +} + +TEST(CUDA_AllReduce, max) { + vector data[2] = {{2., 3.}, {5., 6.}}; + vector ans = {5., 6.}; + + std::vector threads; + for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) { + threads.emplace_back(allReduce, "test_allreduce_max", + gpu, data[gpu], ans); + } + for (auto &thread : threads) { + thread.join(); + } +} + +TEST(CUDA_AllReduce, avg) { + vector data[2] = {{2., 3.}, {5., 6.}}; + vector ans = {3.5, 4.5}; + + std::vector threads; + for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) { + threads.emplace_back(allReduce, "test_allreduce_avg", + gpu, data[gpu], ans); + } + for (auto &thread : threads) { + thread.join(); + } +} + +} // namespace infini +#endif diff --git a/test/kernels/cuda/test_cuda_broadcast.cc b/test/kernels/cuda/test_cuda_broadcast.cc new file mode 100644 index 00000000..c23a9e65 --- /dev/null +++ b/test/kernels/cuda/test_cuda_broadcast.cc @@ -0,0 +1,56 @@ +#ifdef INFINI_USE_NCCL +#include "core/graph.h" +#include "core/runtime.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "operators/broadcast.h" +#include "test.h" +#include +#include + +static int WORLD_SIZE = 2; +static int root = 0; + +namespace infini { + +void broadcast(const string taskName, int deviceID, vector data, + vector ans) { + // Create Runtimes and initiate communication + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + Runtime cudaRuntime = make_ref(deviceID); + cudaRuntime->initComm(taskName, WORLD_SIZE, deviceID); + // Create Graph and insert allReduce operation + Graph g = make_ref(cudaRuntime); + auto input = + g->addTensor(Shape{static_cast(data.size())}, DataType::Float32); + auto op = g->addOp(input, nullptr, root); + // Copy data from CPU to GPU + g->dataMalloc(); + // Only rank 0 has the data + if (deviceID == root) { + input->copyin(data); + } + // Run broadcast operation + cudaRuntime->run(g); + // Copy output from GPU to CPU + auto result = op->getOutput()->clone(cpuRuntime); + + EXPECT_TRUE(result->equalData(ans)); +} + +TEST(CUDA_Broadcast, run) { + // Only 1 device gets data. Every rank should have the same data after + // broadcast. + vector data = {2., 3., 5., 6.}; + vector ans = {2., 3., 5., 6.}; + + std::vector threads; + for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) { + threads.emplace_back(broadcast, "test_broadcast", gpu, data, ans); + } + for (auto &thread : threads) { + thread.join(); + } +} +} // namespace infini +#endif diff --git a/test/operators/test_all_gather.cc b/test/operators/test_all_gather.cc new file mode 100644 index 00000000..6a1e4e4e --- /dev/null +++ b/test/operators/test_all_gather.cc @@ -0,0 +1,23 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "operators/all_gather.h" +#include "test.h" + +namespace infini { +TEST(AllGather, ShapeTypeInfer) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + int world_size = 8; + { + Shape shape = {1, 3, 2, 4}; + Graph g = make_ref(runtime); + Tensor input = g->addTensor(shape, DataType::Float32); + auto op = g->addOp(input, std::nullopt, world_size); + EXPECT_EQ(op->getOpType(), OpType::AllGather); + EXPECT_EQ(op->numOutputs(), world_size); + for (int i = 0; i < world_size; ++i) { + EXPECT_EQ(op->getOutput(i)->getDims(), shape); + EXPECT_EQ(op->getOutput(i)->getDType(), DataType::Float32); + } + } +} +} // namespace infini diff --git a/test/operators/test_all_reduce.cc b/test/operators/test_all_reduce.cc new file mode 100644 index 00000000..18a8efef --- /dev/null +++ b/test/operators/test_all_reduce.cc @@ -0,0 +1,50 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "operators/all_reduce.h" +#include "test.h" + +namespace infini { +TEST(AllReuce, ShapeTypeInfer) { + auto runtime = NativeCpuRuntimeObj::getInstance(); + { + Graph g = make_ref(runtime); + Tensor input = g->addTensor({1, 3, 2, 4}, DataType::Float32); + auto op = g->addOp(input, nullptr); + EXPECT_EQ(op->getOpType(), OpType::AllReduceSum); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 4})); + EXPECT_EQ(op->getOutput()->getDType(), DataType::Float32); + } + { + Graph g = make_ref(runtime); + Tensor input = g->addTensor({1, 3, 2, 4}, DataType::Float32); + auto op = g->addOp(input, nullptr); + EXPECT_EQ(op->getOpType(), OpType::AllReduceProd); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 4})); + EXPECT_EQ(op->getOutput()->getDType(), DataType::Float32); + } + { + Graph g = make_ref(runtime); + Tensor input = g->addTensor({1, 3, 2, 4}, DataType::Float32); + auto op = g->addOp(input, nullptr); + EXPECT_EQ(op->getOpType(), OpType::AllReduceMin); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 4})); + EXPECT_EQ(op->getOutput()->getDType(), DataType::Float32); + } + { + Graph g = make_ref(runtime); + Tensor input = g->addTensor({1, 3, 2, 4}, DataType::Float32); + auto op = g->addOp(input, nullptr); + EXPECT_EQ(op->getOpType(), OpType::AllReduceMax); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 4})); + EXPECT_EQ(op->getOutput()->getDType(), DataType::Float32); + } + { + Graph g = make_ref(runtime); + Tensor input = g->addTensor({1, 3, 2, 4}, DataType::Float32); + auto op = g->addOp(input, nullptr); + EXPECT_EQ(op->getOpType(), OpType::AllReduceAvg); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 4})); + EXPECT_EQ(op->getOutput()->getDType(), DataType::Float32); + } +} +} // namespace infini diff --git a/test/operators/test_broadcast.cc b/test/operators/test_broadcast.cc new file mode 100644 index 00000000..ba2e1f7f --- /dev/null +++ b/test/operators/test_broadcast.cc @@ -0,0 +1,19 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "operators/broadcast.h" +#include "test.h" + +namespace infini { +TEST(Broadcast, ShapeTypeInfer) { + auto runtime = NativeCpuRuntimeObj::getInstance(); + int root = 0; + { + Graph g = make_ref(runtime); + Tensor input = g->addTensor({1, 3, 2, 4}, DataType::Float32); + auto op = g->addOp(input, nullptr, root); + EXPECT_EQ(op->getOpType(), OpType::Broadcast); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 4})); + EXPECT_EQ(op->getOutput()->getDType(), DataType::Float32); + } +} +} // namespace infini