forked from jiuyuan/InfiniTensor
impl distributed launch with NCCL (#106)
* add cmake bits about NCCL * move example to examples/NNmodel * impl NCCL communicator * add comm related function to Runtime * export runtime interface * add launch.py * use unique name to distingush the the NCCL ID file * add timeout to communicator init * expose communicator obj from runtime obj, add unit test for nccl communicator * reformat files * Add allReduce operator and cuda nccl allReduce kernel * impl model parallel for resnet * add allGather nccl kernel and operator * Add allreduce allgather operator tests, change allgather kernel to output list of tensor, fix shape infer, handle nullptr output * fix format of onnx.py * use concat following AllGather * get tensor parallel for resnet * fix format of graph_handler.cc * change BUILD_DIST default to OFF * polish code of communicator * update .gitignore * Add broadcast operator and cuda kernel * Add comments for operators * remove const of class member * move communicator to CudaRuntimeObj * Add an empty line at EOF. --------- Co-authored-by: panzezhong <panzezhong@qiyuanlab.com> Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
parent
b4eda85e67
commit
f60767a770
|
@ -42,3 +42,5 @@ build_debug/
|
||||||
|
|
||||||
# onnx model
|
# onnx model
|
||||||
*.onnx
|
*.onnx
|
||||||
|
*.pb
|
||||||
|
*.npy
|
||||||
|
|
|
@ -11,5 +11,5 @@
|
||||||
path = 3rd-party/backward-cpp
|
path = 3rd-party/backward-cpp
|
||||||
url = git@github.com:bombela/backward-cpp.git
|
url = git@github.com:bombela/backward-cpp.git
|
||||||
[submodule "example"]
|
[submodule "example"]
|
||||||
path = example
|
path = examples/NNmodel
|
||||||
url = git@github.com:wanghailu0717/NNmodel.git
|
url = git@github.com:wanghailu0717/NNmodel.git
|
||||||
|
|
|
@ -8,6 +8,7 @@ option(USE_BANG "Support BANG MLU" OFF)
|
||||||
option(USE_INTELCPU "Support INTELCPU" OFF)
|
option(USE_INTELCPU "Support INTELCPU" OFF)
|
||||||
option(USE_BACKTRACE "Print backtrace on exception and segmentation fault" ON)
|
option(USE_BACKTRACE "Print backtrace on exception and segmentation fault" ON)
|
||||||
option(USE_PROTOBUF "Serialize and deserialize tensors" OFF)
|
option(USE_PROTOBUF "Serialize and deserialize tensors" OFF)
|
||||||
|
option(BUILD_DIST "Build project for distributed running" OFF)
|
||||||
option(BUILD_TEST "Build tests" OFF)
|
option(BUILD_TEST "Build tests" OFF)
|
||||||
|
|
||||||
cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF)
|
cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF)
|
||||||
|
@ -194,6 +195,13 @@ if(USE_CUDA)
|
||||||
enable_language(CUDA)
|
enable_language(CUDA)
|
||||||
find_package(CUDAToolkit) # For nvrtc and cuda driver
|
find_package(CUDAToolkit) # For nvrtc and cuda driver
|
||||||
target_link_libraries(InfiniTensor cudnn CUDA::curand CUDA::cublas CUDA::nvrtc CUDA::cudart CUDA::cuda_driver)
|
target_link_libraries(InfiniTensor cudnn CUDA::curand CUDA::cublas CUDA::nvrtc CUDA::cudart CUDA::cuda_driver)
|
||||||
|
if (BUILD_DIST)
|
||||||
|
message(STATUS "Add BUILD_DIST, use NCCL with CUDA")
|
||||||
|
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake)
|
||||||
|
find_package(NCCL REQUIRED)
|
||||||
|
add_compile_definitions(INFINI_USE_NCCL=1)
|
||||||
|
target_link_libraries(InfiniTensor nccl)
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(USE_BANG)
|
if(USE_BANG)
|
||||||
|
@ -261,6 +269,7 @@ if(BUILD_TEST)
|
||||||
build_test(test/operators/*.cc)
|
build_test(test/operators/*.cc)
|
||||||
if (USE_CUDA)
|
if (USE_CUDA)
|
||||||
build_test(test/kernels/cuda/*.cc)
|
build_test(test/kernels/cuda/*.cc)
|
||||||
|
build_test(test/cuda/*.cc)
|
||||||
endif()
|
endif()
|
||||||
if (USE_BANG)
|
if (USE_BANG)
|
||||||
build_test(test/kernels/bang/*.cc)
|
build_test(test/kernels/bang/*.cc)
|
||||||
|
|
|
@ -0,0 +1,165 @@
|
||||||
|
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# From PyTorch:
|
||||||
|
#
|
||||||
|
# Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
||||||
|
# Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
||||||
|
# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
||||||
|
# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
||||||
|
# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
||||||
|
# Copyright (c) 2011-2013 NYU (Clement Farabet)
|
||||||
|
# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
||||||
|
# Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
||||||
|
# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
||||||
|
#
|
||||||
|
# From Caffe2:
|
||||||
|
#
|
||||||
|
# Copyright (c) 2016-present, Facebook Inc. All rights reserved.
|
||||||
|
#
|
||||||
|
# All contributions by Facebook:
|
||||||
|
# Copyright (c) 2016 Facebook Inc.
|
||||||
|
#
|
||||||
|
# All contributions by Google:
|
||||||
|
# Copyright (c) 2015 Google Inc.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# All contributions by Yangqing Jia:
|
||||||
|
# Copyright (c) 2015 Yangqing Jia
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# All contributions by Kakao Brain:
|
||||||
|
# Copyright 2019-2020 Kakao Brain
|
||||||
|
#
|
||||||
|
# All contributions from Caffe:
|
||||||
|
# Copyright(c) 2013, 2014, 2015, the respective contributors
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# All other contributions:
|
||||||
|
# Copyright(c) 2015, 2016 the respective contributors
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# Caffe2 uses a copyright model similar to Caffe: each contributor holds
|
||||||
|
# copyright over their contributions to Caffe2. The project versioning records
|
||||||
|
# all such contribution and copyright details. If a contributor wants to further
|
||||||
|
# mark their specific copyright on a particular contribution, they should
|
||||||
|
# indicate their copyright solely in the commit message of the change when it is
|
||||||
|
# committed.
|
||||||
|
#
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# Redistribution and use in source and binary forms, with or without
|
||||||
|
# modification, are permitted provided that the following conditions are met:
|
||||||
|
#
|
||||||
|
# 1. Redistributions of source code must retain the above copyright
|
||||||
|
# notice, this list of conditions and the following disclaimer.
|
||||||
|
#
|
||||||
|
# 2. Redistributions in binary form must reproduce the above copyright
|
||||||
|
# notice, this list of conditions and the following disclaimer in the
|
||||||
|
# documentation and/or other materials provided with the distribution.
|
||||||
|
#
|
||||||
|
# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
|
||||||
|
# and IDIAP Research Institute nor the names of its contributors may be
|
||||||
|
# used to endorse or promote products derived from this software without
|
||||||
|
# specific prior written permission.
|
||||||
|
#
|
||||||
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||||
|
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
|
||||||
|
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||||
|
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||||
|
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||||
|
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||||
|
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||||
|
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||||
|
# POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
#
|
||||||
|
# Find the nccl libraries
|
||||||
|
#
|
||||||
|
# The following variables are optionally searched for defaults
|
||||||
|
# NCCL_ROOT: Base directory where all NCCL components are foundHong Xu, 1 year ago: • Let CMake handle NCCL detection instead of ou…
|
||||||
|
# NCCL_INCLUDE_DIR: Directory where NCCL header is foundPieter Noordhuis, 3 years ago: • Bump gloo
|
||||||
|
# NCCL_LIB_DIR: Directory where NCCL library is found
|
||||||
|
#
|
||||||
|
# The following are set after configuration is done:
|
||||||
|
# NCCL_FOUND
|
||||||
|
# NCCL_INCLUDE_DIRS
|
||||||
|
# NCCL_LIBRARIES
|
||||||
|
#
|
||||||
|
# The path hints include CUDA_TOOLKIT_ROOT_DIR seeing as some folks
|
||||||
|
# install NCCL in the same location as the CUDA toolkit.
|
||||||
|
# See https://github.com/caffe2/caffe2/issues/1601
|
||||||
|
|
||||||
|
set(NCCL_INCLUDE_DIR $ENV{NCCL_INCLUDE_DIR} CACHE PATH "Folder contains NVIDIA NCCL headers")
|
||||||
|
set(NCCL_LIB_DIR $ENV{NCCL_LIB_DIR} CACHE PATH "Folder contains NVIDIA NCCL libraries")
|
||||||
|
set(NCCL_VERSION $ENV{NCCL_VERSION} CACHE STRING "Version of NCCL to build with")
|
||||||
|
|
||||||
|
if ($ENV{NCCL_ROOT_DIR})
|
||||||
|
message(WARNING "NCCL_ROOT_DIR is deprecated. Please set NCCL_ROOT instead.")
|
||||||
|
endif()
|
||||||
|
list(APPEND NCCL_ROOT $ENV{NCCL_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR})
|
||||||
|
# Compatible layer for CMake <3.12. NCCL_ROOT will be accounted in for searching paths and libraries for CMake >=3.12.
|
||||||
|
list(APPEND CMAKE_PREFIX_PATH ${NCCL_ROOT})
|
||||||
|
|
||||||
|
find_path(NCCL_INCLUDE_DIRS
|
||||||
|
NAMES nccl.h
|
||||||
|
HINTS ${NCCL_INCLUDE_DIR})
|
||||||
|
|
||||||
|
if (USE_STATIC_NCCL)
|
||||||
|
MESSAGE(STATUS "USE_STATIC_NCCL is set. Linking with static NCCL library.")
|
||||||
|
SET(NCCL_LIBNAME "nccl_static")
|
||||||
|
if (NCCL_VERSION) # Prefer the versioned library if a specific NCCL version is specified
|
||||||
|
set(CMAKE_FIND_LIBRARY_SUFFIXES ".a.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES})
|
||||||
|
endif()
|
||||||
|
else()
|
||||||
|
SET(NCCL_LIBNAME "nccl")
|
||||||
|
if (NCCL_VERSION) # Prefer the versioned library if a specific NCCL version is specified
|
||||||
|
set(CMAKE_FIND_LIBRARY_SUFFIXES ".so.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES})
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
find_library(NCCL_LIBRARIES
|
||||||
|
NAMES ${NCCL_LIBNAME}
|
||||||
|
HINTS ${NCCL_LIB_DIR})
|
||||||
|
|
||||||
|
include(FindPackageHandleStandardArgs)
|
||||||
|
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
|
||||||
|
|
||||||
|
if(NCCL_FOUND) # obtaining NCCL version and some sanity checks
|
||||||
|
set (NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
|
||||||
|
message (STATUS "Determining NCCL version from ${NCCL_HEADER_FILE}...")
|
||||||
|
set (OLD_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES})
|
||||||
|
list (APPEND CMAKE_REQUIRED_INCLUDES ${NCCL_INCLUDE_DIRS})
|
||||||
|
include(CheckCXXSymbolExists)
|
||||||
|
check_cxx_symbol_exists(NCCL_VERSION_CODE nccl.h NCCL_VERSION_DEFINED)
|
||||||
|
|
||||||
|
if (NCCL_VERSION_DEFINED)
|
||||||
|
set(file "${PROJECT_BINARY_DIR}/detect_nccl_version.cc")
|
||||||
|
file(WRITE ${file} "
|
||||||
|
#include <iostream>
|
||||||
|
#include <nccl.h>
|
||||||
|
int main()
|
||||||
|
{
|
||||||
|
std::cout << NCCL_MAJOR << '.' << NCCL_MINOR << '.' << NCCL_PATCH << std::endl;
|
||||||
|
int x;
|
||||||
|
ncclGetVersion(&x);
|
||||||
|
return x == NCCL_VERSION_CODE;
|
||||||
|
}
|
||||||
|
")
|
||||||
|
try_run(NCCL_VERSION_MATCHED compile_result ${PROJECT_BINARY_DIR} ${file}
|
||||||
|
RUN_OUTPUT_VARIABLE NCCL_VERSION_FROM_HEADER
|
||||||
|
CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${NCCL_INCLUDE_DIRS}"
|
||||||
|
LINK_LIBRARIES ${NCCL_LIBRARIES})
|
||||||
|
if (NOT NCCL_VERSION_MATCHED)
|
||||||
|
message(FATAL_ERROR "Found NCCL header version and library version do not match! \
|
||||||
|
(include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES}) Please set NCCL_INCLUDE_DIR and NCCL_LIB_DIR manually.")
|
||||||
|
endif()
|
||||||
|
message(STATUS "NCCL version: ${NCCL_VERSION_FROM_HEADER}")
|
||||||
|
else()
|
||||||
|
# message(STATUS "NCCL version < 2.3.5-5")
|
||||||
|
endif ()
|
||||||
|
set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES})
|
||||||
|
|
||||||
|
message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
|
||||||
|
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
|
||||||
|
endif()
|
|
@ -0,0 +1,100 @@
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import multiprocessing as mp
|
||||||
|
from pyinfinitensor.onnx import OnnxStub, backend
|
||||||
|
import onnx
|
||||||
|
import numpy as np
|
||||||
|
from parallel import parallel_model
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="launch distributed infinitensor")
|
||||||
|
parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
|
||||||
|
parser.add_argument(
|
||||||
|
"--nproc_per_node", type=int, default=1, help="number of processes per node"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model", type=str, required=True, help="path to the ONNX model file."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
print("arg setting: ", args)
|
||||||
|
return args.num_nodes, args.nproc_per_node, args.model
|
||||||
|
|
||||||
|
|
||||||
|
def run_stub(stub: OnnxStub, inputs: np.array, n=100):
|
||||||
|
# warm up
|
||||||
|
next(stub.inputs.items().__iter__())[1].copyin_float(inputs.reshape(-1).tolist())
|
||||||
|
stub.tune()
|
||||||
|
for _ in range(20):
|
||||||
|
stub.run()
|
||||||
|
outputs = np.array(next(stub.outputs.items().__iter__())[1].copyout_float())
|
||||||
|
|
||||||
|
# bench
|
||||||
|
next(stub.inputs.items().__iter__())[1].copyin_float(inputs.reshape(-1).tolist())
|
||||||
|
begin = time.time()
|
||||||
|
for _ in range(n):
|
||||||
|
stub.run()
|
||||||
|
end = time.time()
|
||||||
|
outputs = np.array(next(stub.outputs.items().__iter__())[1].copyout_float())
|
||||||
|
print("outputs sum:", outputs.sum())
|
||||||
|
# np.save("results", outputs)
|
||||||
|
results = np.load("results.npy")
|
||||||
|
print("max diff:", abs(outputs - results).max())
|
||||||
|
assert np.allclose(outputs, results, rtol=1e-6, atol=1e-6)
|
||||||
|
avg_time = (end - begin) / n
|
||||||
|
return avg_time
|
||||||
|
|
||||||
|
|
||||||
|
def start_worker(
|
||||||
|
dist_name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
|
||||||
|
):
|
||||||
|
print("start worker")
|
||||||
|
runtime = backend.CudaRuntime(local_rank)
|
||||||
|
print("init comm")
|
||||||
|
runtime.init_comm(
|
||||||
|
dist_name,
|
||||||
|
world_size,
|
||||||
|
rank,
|
||||||
|
)
|
||||||
|
model = parallel_model(model, world_size, rank)
|
||||||
|
onnx.save(model, f"dist_model_rank{rank}.onnx")
|
||||||
|
print("load model")
|
||||||
|
stub = OnnxStub(model, runtime)
|
||||||
|
data = np.load("inputs.npy")
|
||||||
|
print("run model")
|
||||||
|
avg_time = run_stub(stub, data)
|
||||||
|
print(f"average time: {avg_time}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
nnodes, nproc_per_node, model_path = parse_args()
|
||||||
|
world_size = nnodes * nproc_per_node
|
||||||
|
|
||||||
|
model = onnx.load(model_path)
|
||||||
|
# generate standard results
|
||||||
|
# runtime = backend.CudaRuntime(0)
|
||||||
|
# stub = OnnxStub(model, runtime)
|
||||||
|
# data = np.random.randn(1, 3, 224, 224)
|
||||||
|
# np.save("inputs", data)
|
||||||
|
# run_stub(stub, data)
|
||||||
|
# del stub
|
||||||
|
|
||||||
|
dist_name = f"dist_{os.getpid()}"
|
||||||
|
workers = [
|
||||||
|
mp.Process(
|
||||||
|
target=start_worker,
|
||||||
|
args=(dist_name, world_size, rank, rank % nproc_per_node, model),
|
||||||
|
)
|
||||||
|
for rank in range(world_size)
|
||||||
|
]
|
||||||
|
|
||||||
|
for w in workers:
|
||||||
|
w.start()
|
||||||
|
|
||||||
|
for w in workers:
|
||||||
|
w.join()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -0,0 +1,103 @@
|
||||||
|
import onnx
|
||||||
|
from onnx import (
|
||||||
|
ModelProto,
|
||||||
|
TensorProto,
|
||||||
|
NodeProto,
|
||||||
|
AttributeProto,
|
||||||
|
)
|
||||||
|
from onnx import helper, numpy_helper
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
|
||||||
|
def parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
|
||||||
|
for attr in node.attribute:
|
||||||
|
if attr.name in attrs:
|
||||||
|
if attr.type == AttributeProto.INT:
|
||||||
|
attrs[attr.name] = attr.i
|
||||||
|
elif attr.type == AttributeProto.INTS:
|
||||||
|
attrs[attr.name] = attr.ints
|
||||||
|
elif attr.type == AttributeProto.FLOAT:
|
||||||
|
attrs[attr.name] = attr.f
|
||||||
|
elif attr.type == AttributeProto.STRING:
|
||||||
|
attrs[attr.name] = attr.s
|
||||||
|
elif attr.type == AttributeProto.TENSOR:
|
||||||
|
attrs[attr.name] = attr.t
|
||||||
|
else:
|
||||||
|
assert False, "Unsupported Attribute Type: {}".format(attr.type)
|
||||||
|
return attrs
|
||||||
|
|
||||||
|
|
||||||
|
def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
|
||||||
|
data = {init.name: init for init in model.graph.initializer}
|
||||||
|
nodes = list(model.graph.node)
|
||||||
|
|
||||||
|
def shard_tensor(tensor: TensorProto, dim: int):
|
||||||
|
array = numpy_helper.to_array(tensor)
|
||||||
|
if dim >= array.ndim:
|
||||||
|
dim = array.ndim - 1
|
||||||
|
assert array.shape[dim] % tp_world_size == 0
|
||||||
|
seg = array.shape[dim] // tp_world_size
|
||||||
|
array = array[tp_rank * seg : (tp_rank + 1) * seg]
|
||||||
|
return numpy_helper.from_array(array, name=tensor.name + f":sharded({dim})")
|
||||||
|
|
||||||
|
def shard_gemm(node: NodeProto):
|
||||||
|
attrs = parse_attribute(
|
||||||
|
node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
|
||||||
|
)
|
||||||
|
trans = [attrs["transA"], attrs["transB"]]
|
||||||
|
dim = 0
|
||||||
|
for i, (input, t) in enumerate(zip(node.input, trans)):
|
||||||
|
if input in data:
|
||||||
|
dim = i
|
||||||
|
sharded = shard_tensor(data[input], dim ^ t)
|
||||||
|
node.input[i] = sharded.name
|
||||||
|
data[input] = sharded
|
||||||
|
if len(node.input) > 2:
|
||||||
|
input = node.input[2]
|
||||||
|
sharded = shard_tensor(data[input], dim)
|
||||||
|
node.input[2] = sharded.name
|
||||||
|
data[input] = sharded
|
||||||
|
|
||||||
|
node.output[0] += f":sharded({dim})"
|
||||||
|
return dim
|
||||||
|
|
||||||
|
for i, node in enumerate(nodes):
|
||||||
|
if node.op_type == "Gemm":
|
||||||
|
output = node.output[0]
|
||||||
|
dim = shard_gemm(node)
|
||||||
|
gathered = [node.output[0] + f".{i}" for i in range(tp_world_size)]
|
||||||
|
# all_gather
|
||||||
|
nodes.insert(
|
||||||
|
i + 1,
|
||||||
|
helper.make_node(
|
||||||
|
op_type="AllGather",
|
||||||
|
inputs=[node.output[0]],
|
||||||
|
outputs=gathered,
|
||||||
|
name=node.name + "/allgather",
|
||||||
|
# domain="infini", # shape inference fails for custom domain
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# concat
|
||||||
|
nodes.insert(
|
||||||
|
i + 2,
|
||||||
|
helper.make_node(
|
||||||
|
op_type="Concat",
|
||||||
|
inputs=gathered,
|
||||||
|
outputs=[output],
|
||||||
|
name=node.name + "/concat",
|
||||||
|
axis=dim,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
graph = helper.make_graph(
|
||||||
|
nodes,
|
||||||
|
model.graph.name + f"_{tp_rank}",
|
||||||
|
model.graph.input,
|
||||||
|
model.graph.output,
|
||||||
|
data.values(),
|
||||||
|
doc_string=model.graph.doc_string,
|
||||||
|
value_info=model.graph.value_info,
|
||||||
|
)
|
||||||
|
model = helper.make_model(graph)
|
||||||
|
|
||||||
|
onnx.shape_inference.infer_shapes(model)
|
||||||
|
return model
|
|
@ -0,0 +1,22 @@
|
||||||
|
#pragma once
|
||||||
|
#include "object.h"
|
||||||
|
#include "ref.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
// base class
|
||||||
|
class CommunicatorObj : public Object {
|
||||||
|
protected:
|
||||||
|
int worldSize;
|
||||||
|
int rank;
|
||||||
|
|
||||||
|
public:
|
||||||
|
CommunicatorObj(int worldSize, int rank)
|
||||||
|
: worldSize(worldSize), rank(rank) {}
|
||||||
|
|
||||||
|
virtual ~CommunicatorObj() = default;
|
||||||
|
virtual int getWorldSize() const { return worldSize; }
|
||||||
|
virtual int getRank() const { return rank; }
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -74,6 +74,14 @@ class GraphHandlerObj {
|
||||||
Tensor expand(Tensor input, Tensor output, Shape dims);
|
Tensor expand(Tensor input, Tensor output, Shape dims);
|
||||||
Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output);
|
Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output);
|
||||||
|
|
||||||
|
Tensor allReduceSum(Tensor input, Tensor output);
|
||||||
|
Tensor allReduceProd(Tensor input, Tensor output);
|
||||||
|
Tensor allReduceMin(Tensor input, Tensor output);
|
||||||
|
Tensor allReduceMax(Tensor input, Tensor output);
|
||||||
|
Tensor allReduceAvg(Tensor input, Tensor output);
|
||||||
|
TensorVec allGather(Tensor input, std::optional<TensorVec> outputs, int n);
|
||||||
|
Tensor broadcast(Tensor input, Tensor output, int root);
|
||||||
|
|
||||||
//------ modifiers
|
//------ modifiers
|
||||||
|
|
||||||
inline bool topo_sort() { return g->topo_sort(); }
|
inline bool topo_sort() { return g->topo_sort(); }
|
||||||
|
|
|
@ -221,6 +221,15 @@ struct OpType {
|
||||||
FloorMod,
|
FloorMod,
|
||||||
Square,
|
Square,
|
||||||
SquaredDifference,
|
SquaredDifference,
|
||||||
|
|
||||||
|
// Communication Ops
|
||||||
|
AllReduceSum,
|
||||||
|
AllReduceProd,
|
||||||
|
AllReduceMin,
|
||||||
|
AllReduceMax,
|
||||||
|
AllReduceAvg,
|
||||||
|
AllGather,
|
||||||
|
Broadcast,
|
||||||
} type;
|
} type;
|
||||||
|
|
||||||
constexpr OpType(decltype(type) t) : type(t) {}
|
constexpr OpType(decltype(type) t) : type(t) {}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
#include "core/common.h"
|
#include "core/common.h"
|
||||||
|
#include "core/communicator.h"
|
||||||
#include "core/op_type.h"
|
#include "core/op_type.h"
|
||||||
#include "core/ref.h"
|
#include "core/ref.h"
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -35,9 +36,11 @@ enum class Device { CPU = 1, CUDA, BANG, INTELCPU };
|
||||||
class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
|
class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
|
||||||
protected:
|
protected:
|
||||||
Device device;
|
Device device;
|
||||||
|
int deviceId;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
RuntimeObj(Device device) : device(device) {}
|
explicit RuntimeObj(Device device, int deviceId = 0)
|
||||||
|
: device(device), deviceId(deviceId) {}
|
||||||
RuntimeObj(RuntimeObj &other) = delete;
|
RuntimeObj(RuntimeObj &other) = delete;
|
||||||
RuntimeObj &operator=(RuntimeObj const &) = delete;
|
RuntimeObj &operator=(RuntimeObj const &) = delete;
|
||||||
virtual ~RuntimeObj() {}
|
virtual ~RuntimeObj() {}
|
||||||
|
@ -77,6 +80,12 @@ class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
|
||||||
size_t bytes) const = 0;
|
size_t bytes) const = 0;
|
||||||
virtual string toString() const = 0;
|
virtual string toString() const = 0;
|
||||||
|
|
||||||
|
int getDeviceId() const { return deviceId; }
|
||||||
|
|
||||||
|
virtual void initComm(const string &name, int worldSize, int rank) = 0;
|
||||||
|
|
||||||
|
virtual CommunicatorObj &getCommunicator() const = 0;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void printProfilingData(double totTime,
|
void printProfilingData(double totTime,
|
||||||
const std::map<OpType, double> &opTime,
|
const std::map<OpType, double> &opTime,
|
||||||
|
@ -97,6 +106,9 @@ class CpuRuntimeObj : public RuntimeObj {
|
||||||
void copyBlobToCPU(void *dst, const void *src, size_t bytes) const override;
|
void copyBlobToCPU(void *dst, const void *src, size_t bytes) const override;
|
||||||
void copyBlobInsideRuntime(void *dst, const void *src,
|
void copyBlobInsideRuntime(void *dst, const void *src,
|
||||||
size_t bytes) const override;
|
size_t bytes) const override;
|
||||||
|
void initComm(const string &, int, int) override { IT_TODO_HALT(); }
|
||||||
|
|
||||||
|
CommunicatorObj &getCommunicator() const override { IT_TODO_HALT(); }
|
||||||
};
|
};
|
||||||
|
|
||||||
class NativeCpuRuntimeObj : public CpuRuntimeObj {
|
class NativeCpuRuntimeObj : public CpuRuntimeObj {
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
#include "core/runtime.h"
|
#include "core/runtime.h"
|
||||||
#include "cuda/cuda_common.h"
|
#include "cuda/cuda_common.h"
|
||||||
|
#ifdef INFINI_USE_NCCL
|
||||||
|
#include "cuda/nccl_communicator.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
|
@ -8,12 +11,15 @@ class CudaRuntimeObj : public RuntimeObj {
|
||||||
private:
|
private:
|
||||||
cudnnHandle_t cudnn;
|
cudnnHandle_t cudnn;
|
||||||
cublasHandle_t cublas;
|
cublasHandle_t cublas;
|
||||||
|
std::unique_ptr<CommunicatorObj> comm;
|
||||||
CudaPtr workspace;
|
CudaPtr workspace;
|
||||||
size_t workspaceSize;
|
size_t workspaceSize;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
CudaRuntimeObj() : RuntimeObj(Device::CUDA) {
|
explicit CudaRuntimeObj(int deviceId = 0)
|
||||||
|
: RuntimeObj(Device::CUDA, deviceId) {
|
||||||
|
|
||||||
|
checkCudaError(cudaSetDevice(deviceId));
|
||||||
checkCudnnError(cudnnCreate(&cudnn));
|
checkCudnnError(cudnnCreate(&cudnn));
|
||||||
checkCublasError(cublasCreate(&cublas));
|
checkCublasError(cublasCreate(&cublas));
|
||||||
// 10GB for Longformer
|
// 10GB for Longformer
|
||||||
|
@ -69,6 +75,11 @@ class CudaRuntimeObj : public RuntimeObj {
|
||||||
|
|
||||||
void runWithoutSync(const Graph &graph) const;
|
void runWithoutSync(const Graph &graph) const;
|
||||||
|
|
||||||
|
// init communicator
|
||||||
|
void initComm(const string &name, int worldSize, int rank) final;
|
||||||
|
|
||||||
|
CommunicatorObj &getCommunicator() const final { return *comm; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void tune(const Graph &graph, bool profiling) const;
|
void tune(const Graph &graph, bool profiling) const;
|
||||||
};
|
};
|
||||||
|
|
|
@ -0,0 +1,70 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/communicator.h"
|
||||||
|
#include <chrono>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <filesystem>
|
||||||
|
#include <fstream>
|
||||||
|
#include <nccl.h>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
#define checkNcclError(call) \
|
||||||
|
{ \
|
||||||
|
auto err = call; \
|
||||||
|
if (ncclSuccess != err) { \
|
||||||
|
fprintf(stderr, "NCCL error in %s:%i : %s.\n", __FILE__, __LINE__, \
|
||||||
|
ncclGetErrorString(err)); \
|
||||||
|
exit(EXIT_FAILURE); \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
class NcclCommunicatorObj final : public CommunicatorObj {
|
||||||
|
private:
|
||||||
|
ncclComm_t comm;
|
||||||
|
|
||||||
|
public:
|
||||||
|
NcclCommunicatorObj(const string &name, int worldSize, int rank)
|
||||||
|
: CommunicatorObj(worldSize, rank) {
|
||||||
|
const std::string filePath("./" + name + "_nccl_id.bin");
|
||||||
|
ncclUniqueId commId;
|
||||||
|
if (rank == 0) {
|
||||||
|
checkNcclError(ncclGetUniqueId(&commId));
|
||||||
|
std::ofstream ofs(filePath, std::ios::binary);
|
||||||
|
ofs.write((char *)&commId, sizeof(ncclUniqueId));
|
||||||
|
|
||||||
|
} else {
|
||||||
|
auto begin = std::chrono::steady_clock::now();
|
||||||
|
while (!std::filesystem::exists(filePath)) {
|
||||||
|
auto now = std::chrono::steady_clock::now();
|
||||||
|
_IT_ASSERT_2(now < begin + std::chrono::seconds(10),
|
||||||
|
"time limit (10s) exceeded.");
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||||
|
}
|
||||||
|
std::ifstream ifs(filePath, std::ios::binary);
|
||||||
|
ifs.read((char *)&commId, sizeof(ncclUniqueId));
|
||||||
|
}
|
||||||
|
checkNcclError(ncclCommInitRank(&comm, worldSize, commId, rank));
|
||||||
|
if (rank == 0) {
|
||||||
|
std::filesystem::remove(filePath);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the actual ncclComm_t
|
||||||
|
ncclComm_t getNcclComm() { return comm; }
|
||||||
|
|
||||||
|
void finalize() { checkNcclError(ncclCommFinalize(comm)); }
|
||||||
|
|
||||||
|
~NcclCommunicatorObj() final {
|
||||||
|
finalize();
|
||||||
|
checkNcclError(ncclCommDestroy(comm));
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual string toString() const final {
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << "NCCL communicator";
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,44 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/operator.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
/**
|
||||||
|
* @brief The AllGather operation gathers N values from k ranks into
|
||||||
|
* an output of size k*N, and distributes that result to all ranks.
|
||||||
|
* The output is ordered by rank index.
|
||||||
|
*
|
||||||
|
* For more details:
|
||||||
|
* https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather
|
||||||
|
*/
|
||||||
|
class AllGatherObj : public OperatorObj {
|
||||||
|
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief Construct a new AllGather object
|
||||||
|
*
|
||||||
|
* @param graph The computation graph that this operator belongs to.
|
||||||
|
* @param input The input tensor from this rank.
|
||||||
|
* @param outputs A list of output tensors collected from all ranks.
|
||||||
|
* @param world_size Total number of ranks.
|
||||||
|
*/
|
||||||
|
AllGatherObj(GraphObj *graph, Tensor input, std::optional<TensorVec>,
|
||||||
|
int world_size);
|
||||||
|
OP_CLONE(AllGatherObj);
|
||||||
|
|
||||||
|
int numInputs() const override { return 1; }
|
||||||
|
int numOutputs() const override { return world_size; }
|
||||||
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||||
|
|
||||||
|
std::string toString() const override;
|
||||||
|
|
||||||
|
int getWorldSize() const { return world_size; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
vector<int> getWorkloadVector() const override;
|
||||||
|
vector<int> getOpAttrVector() const override;
|
||||||
|
vector<DataType> inferDataType(const TensorVec &inputs) const override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
int world_size;
|
||||||
|
};
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,75 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/operator.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
/**
|
||||||
|
* @brief The AllReduce operation is performing reductions on data (sum, min,
|
||||||
|
* max, avg, or div) across devices and writing the result in the
|
||||||
|
* receive buffers of every rank. For example, in an allreduce operation between
|
||||||
|
* k ranks and performing a sum, each rank will provide an array Vk of N values,
|
||||||
|
* and receive an identical arrays S of N values, where S[i] =
|
||||||
|
* V0[i]+V1[i]+…+Vk-1[i].
|
||||||
|
*
|
||||||
|
* For more details:
|
||||||
|
* https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce
|
||||||
|
*/
|
||||||
|
class AllReduceBaseObj : public OperatorObj {
|
||||||
|
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief Construct a new AllReduce base object. Should be called by every
|
||||||
|
* child class constructor, but not directly.
|
||||||
|
*
|
||||||
|
* @param graph The computation graph that this operator belongs to.
|
||||||
|
* @param opType The operation type. This param is taken care of by child
|
||||||
|
* classes.
|
||||||
|
* @param input The input tensor from this rank.
|
||||||
|
* @param output The output tensor, same size as input.
|
||||||
|
*/
|
||||||
|
AllReduceBaseObj(GraphObj *graph, OpType opType, Tensor input,
|
||||||
|
Tensor output);
|
||||||
|
OP_CLONE(AllReduceBaseObj);
|
||||||
|
|
||||||
|
int numInputs() const override { return 1; }
|
||||||
|
int numOutputs() const override { return 1; }
|
||||||
|
|
||||||
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override {
|
||||||
|
return {{inputs[0]->getDims()}};
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string toString() const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
vector<int> getWorkloadVector() const override;
|
||||||
|
vector<int> getOpAttrVector() const override;
|
||||||
|
vector<DataType> inferDataType(const TensorVec &inputs) const override {
|
||||||
|
return {inputs[0]->getDType()};
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
class AllReduceSumObj : public AllReduceBaseObj {
|
||||||
|
public:
|
||||||
|
AllReduceSumObj(GraphObj *graph, Tensor input, Tensor output);
|
||||||
|
};
|
||||||
|
|
||||||
|
class AllReduceProdObj : public AllReduceBaseObj {
|
||||||
|
public:
|
||||||
|
AllReduceProdObj(GraphObj *graph, Tensor input, Tensor output);
|
||||||
|
};
|
||||||
|
|
||||||
|
class AllReduceMinObj : public AllReduceBaseObj {
|
||||||
|
public:
|
||||||
|
AllReduceMinObj(GraphObj *graph, Tensor input, Tensor output);
|
||||||
|
};
|
||||||
|
|
||||||
|
class AllReduceMaxObj : public AllReduceBaseObj {
|
||||||
|
public:
|
||||||
|
AllReduceMaxObj(GraphObj *graph, Tensor input, Tensor output);
|
||||||
|
};
|
||||||
|
|
||||||
|
class AllReduceAvgObj : public AllReduceBaseObj {
|
||||||
|
public:
|
||||||
|
AllReduceAvgObj(GraphObj *graph, Tensor input, Tensor output);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,49 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/operator.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
/**
|
||||||
|
* @brief The Broadcast operation copies an N-element buffer on the root rank to
|
||||||
|
* all ranks.
|
||||||
|
*
|
||||||
|
* For more details:
|
||||||
|
* https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#broadcast
|
||||||
|
*/
|
||||||
|
class BroadcastObj : public OperatorObj {
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief Construct a new Broadcast object.
|
||||||
|
*
|
||||||
|
* @param graph The computation graph that this operator belongs to.
|
||||||
|
* @param input The input tensor. Only root needs to initialize it with
|
||||||
|
* data.
|
||||||
|
* @param output The output tensor, same size as input.
|
||||||
|
* @param root The root rank who performs the broadcast.
|
||||||
|
*/
|
||||||
|
BroadcastObj(GraphObj *graph, Tensor input, Tensor output, int root);
|
||||||
|
OP_CLONE(BroadcastObj);
|
||||||
|
|
||||||
|
int numInputs() const override { return 1; }
|
||||||
|
int numOutputs() const override { return 1; }
|
||||||
|
|
||||||
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override {
|
||||||
|
return {{inputs[0]->getDims()}};
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string toString() const override;
|
||||||
|
|
||||||
|
int getRoot() const { return root; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
vector<int> getWorkloadVector() const override;
|
||||||
|
vector<int> getOpAttrVector() const override;
|
||||||
|
vector<DataType> inferDataType(const TensorVec &inputs) const override {
|
||||||
|
return {inputs[0]->getDType()};
|
||||||
|
};
|
||||||
|
|
||||||
|
protected:
|
||||||
|
// The rank who broadcasts data among this communication group
|
||||||
|
int root;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -591,6 +591,54 @@ class OnnxStub:
|
||||||
tensors.get(node.output[0]),
|
tensors.get(node.output[0]),
|
||||||
next((attr.i for attr in node.attribute if attr.name == "to")),
|
next((attr.i for attr in node.attribute if attr.name == "to")),
|
||||||
)
|
)
|
||||||
|
elif node.op_type == "AllReduceSum":
|
||||||
|
tensors[node.output[0]] = self.handler.allReduceSum(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors.get(node.output[0]),
|
||||||
|
)
|
||||||
|
elif node.op_type == "AllReduceProd":
|
||||||
|
tensors[node.output[0]] = self.handler.allReduceProd(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors.get(node.output[0]),
|
||||||
|
)
|
||||||
|
elif node.op_type == "AllReduceMin":
|
||||||
|
tensors[node.output[0]] = self.handler.allReduceMin(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors.get(node.output[0]),
|
||||||
|
)
|
||||||
|
elif node.op_type == "AllReduceMax":
|
||||||
|
tensors[node.output[0]] = self.handler.allReduceMax(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors.get(node.output[0]),
|
||||||
|
)
|
||||||
|
elif node.op_type == "AllReduceAvg":
|
||||||
|
tensors[node.output[0]] = self.handler.allReduceAvg(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors.get(node.output[0]),
|
||||||
|
)
|
||||||
|
elif node.op_type == "AllGather":
|
||||||
|
for name, tensor in zip(
|
||||||
|
node.output,
|
||||||
|
self.handler.allGather(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
None,
|
||||||
|
len(node.output),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
tensors[name] = tensor
|
||||||
|
elif node.op_type == "Broadcast":
|
||||||
|
tensors[node.output[0]] = self.handler.broadcast(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors.get(node.output[0]),
|
||||||
|
next(
|
||||||
|
(
|
||||||
|
attr.i
|
||||||
|
for attr in node.attribute
|
||||||
|
if attr.name == "root"
|
||||||
|
),
|
||||||
|
0,
|
||||||
|
),
|
||||||
|
)
|
||||||
elif node.op_type == "Expand":
|
elif node.op_type == "Expand":
|
||||||
shape = _parse_data(data[node.input[1]])
|
shape = _parse_data(data[node.input[1]])
|
||||||
tensors[node.output[0]] = self.handler.expand(
|
tensors[node.output[0]] = self.handler.expand(
|
||||||
|
|
|
@ -329,6 +329,83 @@ class TestStringMethods(unittest.TestCase):
|
||||||
[pads_data],
|
[pads_data],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_allReduceSum(self):
|
||||||
|
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||||
|
output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||||
|
allReduceSum = make_node(
|
||||||
|
"AllReduceSum", ["input"], ["output"], name="allReduceSum"
|
||||||
|
)
|
||||||
|
graph = make_graph([allReduceSum], "allReduceSum", [input], [output])
|
||||||
|
model = make_model(graph)
|
||||||
|
from_onnx(model, backend.cpu_runtime())
|
||||||
|
|
||||||
|
def test_allReduceProd(self):
|
||||||
|
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||||
|
output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||||
|
allReduceProd = make_node(
|
||||||
|
"AllReduceProd", ["input"], ["output"], name="allReduceProd"
|
||||||
|
)
|
||||||
|
graph = make_graph([allReduceProd], "allReduceProd", [input], [output])
|
||||||
|
model = make_model(graph)
|
||||||
|
from_onnx(model, backend.cpu_runtime())
|
||||||
|
|
||||||
|
def test_allReduceMin(self):
|
||||||
|
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||||
|
output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||||
|
allReduceMin = make_node(
|
||||||
|
"AllReduceMin", ["input"], ["output"], name="allReduceMin"
|
||||||
|
)
|
||||||
|
graph = make_graph([allReduceMin], "allReduceMin", [input], [output])
|
||||||
|
model = make_model(graph)
|
||||||
|
from_onnx(model, backend.cpu_runtime())
|
||||||
|
|
||||||
|
def test_allReduceMax(self):
|
||||||
|
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||||
|
output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||||
|
allReduceMax = make_node(
|
||||||
|
"AllReduceMax", ["input"], ["output"], name="allReduceMax"
|
||||||
|
)
|
||||||
|
graph = make_graph([allReduceMax], "allReduceMax", [input], [output])
|
||||||
|
model = make_model(graph)
|
||||||
|
from_onnx(model, backend.cpu_runtime())
|
||||||
|
|
||||||
|
def test_allReduceAvg(self):
|
||||||
|
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||||
|
output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||||
|
allReduceAvg = make_node(
|
||||||
|
"AllReduceAvg", ["input"], ["output"], name="allReduceAvg"
|
||||||
|
)
|
||||||
|
graph = make_graph([allReduceAvg], "allReduceAvg", [input], [output])
|
||||||
|
model = make_model(graph)
|
||||||
|
from_onnx(model, backend.cpu_runtime())
|
||||||
|
|
||||||
|
def test_split(self):
|
||||||
|
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||||
|
split = make_node(
|
||||||
|
"Split", ["input"], ["output"], name="split", axis=0
|
||||||
|
)
|
||||||
|
make_and_import_model(make_graph([split], "split", [input], []))
|
||||||
|
|
||||||
|
def test_allBroadcast(self):
|
||||||
|
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||||
|
output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||||
|
broadcast = make_node(
|
||||||
|
"Broadcast", ["input"], ["output"], name="broadcast", root=1
|
||||||
|
)
|
||||||
|
graph = make_graph([broadcast], "broadcast", [input], [output])
|
||||||
|
model = make_model(graph)
|
||||||
|
from_onnx(model, backend.cpu_runtime())
|
||||||
|
|
||||||
|
def test_allGather(self):
|
||||||
|
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||||
|
world_size = make_tensor_value_info("world_size", TensorProto.INT32, [1])
|
||||||
|
allGather = make_node(
|
||||||
|
"AllGather", ["input", "world_size"], ["output"], name="allGather"
|
||||||
|
)
|
||||||
|
graph = make_graph([allGather], "allGather", [input, world_size], [])
|
||||||
|
model = make_model(graph)
|
||||||
|
from_onnx(model, backend.cpu_runtime())
|
||||||
|
|
||||||
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
|
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
|
||||||
def test_linear(self):
|
def test_linear(self):
|
||||||
|
|
|
@ -1,5 +1,8 @@
|
||||||
#include "core/graph_handler.h"
|
#include "core/graph_handler.h"
|
||||||
|
#include "operators/all_gather.h"
|
||||||
|
#include "operators/all_reduce.h"
|
||||||
#include "operators/batch_norm.h"
|
#include "operators/batch_norm.h"
|
||||||
|
#include "operators/broadcast.h"
|
||||||
#include "operators/concat.h"
|
#include "operators/concat.h"
|
||||||
#include "operators/conv.h"
|
#include "operators/conv.h"
|
||||||
#include "operators/element_wise.h"
|
#include "operators/element_wise.h"
|
||||||
|
@ -300,6 +303,73 @@ Tensor GraphHandlerObj::pad(Tensor input, Tensor output,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Tensor GraphHandlerObj::allReduceSum(Tensor input, Tensor output) {
|
||||||
|
if (output) {
|
||||||
|
g->addOpWithOutputs<AllReduceSumObj>(std::move(input), output);
|
||||||
|
return output;
|
||||||
|
} else {
|
||||||
|
return g->addOp<AllReduceSumObj>(std::move(input), output)->getOutput();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor GraphHandlerObj::allReduceProd(Tensor input, Tensor output) {
|
||||||
|
if (output) {
|
||||||
|
g->addOpWithOutputs<AllReduceProdObj>(std::move(input), output);
|
||||||
|
return output;
|
||||||
|
} else {
|
||||||
|
return g->addOp<AllReduceProdObj>(std::move(input), output)
|
||||||
|
->getOutput();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor GraphHandlerObj::allReduceMin(Tensor input, Tensor output) {
|
||||||
|
if (output) {
|
||||||
|
g->addOpWithOutputs<AllReduceMinObj>(std::move(input), output);
|
||||||
|
return output;
|
||||||
|
} else {
|
||||||
|
return g->addOp<AllReduceMinObj>(std::move(input), output)->getOutput();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor GraphHandlerObj::allReduceMax(Tensor input, Tensor output) {
|
||||||
|
if (output) {
|
||||||
|
g->addOpWithOutputs<AllReduceMaxObj>(std::move(input), output);
|
||||||
|
return output;
|
||||||
|
} else {
|
||||||
|
return g->addOp<AllReduceMaxObj>(std::move(input), output)->getOutput();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor GraphHandlerObj::allReduceAvg(Tensor input, Tensor output) {
|
||||||
|
if (output) {
|
||||||
|
g->addOpWithOutputs<AllReduceAvgObj>(std::move(input), output);
|
||||||
|
return output;
|
||||||
|
} else {
|
||||||
|
return g->addOp<AllReduceAvgObj>(std::move(input), output)->getOutput();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorVec GraphHandlerObj::allGather(Tensor input,
|
||||||
|
std::optional<TensorVec> outputs, int n) {
|
||||||
|
if (outputs) {
|
||||||
|
g->addOpWithOutputs<AllGatherObj>(std::move(input), outputs, n);
|
||||||
|
return *outputs;
|
||||||
|
} else {
|
||||||
|
return g->addOp<AllGatherObj>(std::move(input), outputs, n)
|
||||||
|
->getOutputs();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor GraphHandlerObj::broadcast(Tensor input, Tensor output, int root) {
|
||||||
|
if (output) {
|
||||||
|
g->addOpWithOutputs<BroadcastObj>(std::move(input), output, root);
|
||||||
|
return output;
|
||||||
|
} else {
|
||||||
|
return g->addOp<BroadcastObj>(std::move(input), output, root)
|
||||||
|
->getOutput();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Tensor GraphHandlerObj::cast(Tensor input, Tensor output, int to) {
|
Tensor GraphHandlerObj::cast(Tensor input, Tensor output, int to) {
|
||||||
if (output) {
|
if (output) {
|
||||||
g->addOpWithOutputs<CastObj>(std::move(input), output,
|
g->addOpWithOutputs<CastObj>(std::move(input), output,
|
||||||
|
|
|
@ -214,6 +214,15 @@ const char *OpType::toString() const {
|
||||||
CASE(FloorMod);
|
CASE(FloorMod);
|
||||||
CASE(Square);
|
CASE(Square);
|
||||||
CASE(SquaredDifference);
|
CASE(SquaredDifference);
|
||||||
|
|
||||||
|
// Communcation
|
||||||
|
CASE(AllReduceSum);
|
||||||
|
CASE(AllReduceProd);
|
||||||
|
CASE(AllReduceMin);
|
||||||
|
CASE(AllReduceMax);
|
||||||
|
CASE(AllReduceAvg);
|
||||||
|
CASE(AllGather);
|
||||||
|
CASE(Broadcast);
|
||||||
default:
|
default:
|
||||||
return "Unknown";
|
return "Unknown";
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,9 @@
|
||||||
#include "core/kernel.h"
|
#include "core/kernel.h"
|
||||||
#include "core/perf_engine.h"
|
#include "core/perf_engine.h"
|
||||||
#include "core/runtime.h"
|
#include "core/runtime.h"
|
||||||
|
#ifdef INFINI_USE_NCCL
|
||||||
|
#include "cuda/nccl_communicator.h"
|
||||||
|
#endif
|
||||||
#include "operators/conv.h"
|
#include "operators/conv.h"
|
||||||
#include "operators/matmul.h"
|
#include "operators/matmul.h"
|
||||||
|
|
||||||
|
@ -96,4 +99,15 @@ void CudaRuntimeObj::sync() const { checkCudaError(cudaDeviceSynchronize()); }
|
||||||
|
|
||||||
string CudaRuntimeObj::toString() const { return "CUDA Runtime"; }
|
string CudaRuntimeObj::toString() const { return "CUDA Runtime"; }
|
||||||
|
|
||||||
|
void CudaRuntimeObj::initComm(const string &name, int worldSize, int rank) {
|
||||||
|
IT_ASSERT(worldSize > 0);
|
||||||
|
IT_ASSERT(rank >= 0);
|
||||||
|
IT_ASSERT(rank < worldSize);
|
||||||
|
#ifdef INFINI_USE_NCCL
|
||||||
|
comm = std::make_unique<NcclCommunicatorObj>(name, worldSize, rank);
|
||||||
|
#else
|
||||||
|
IT_TODO_HALT_MSG("Not compiled with NCCL.");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -143,7 +143,10 @@ static int tensor_dtype(Tensor t) {
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef USE_CUDA
|
#ifdef USE_CUDA
|
||||||
static Ref<CudaRuntimeObj> cuda_runtime() { return make_ref<CudaRuntimeObj>(); }
|
// NOTE(lizhouyang): deprecate this, use CudaRuntime directly.
|
||||||
|
[[deprecated]] static Ref<CudaRuntimeObj> cuda_runtime() {
|
||||||
|
return make_ref<CudaRuntimeObj>(0);
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef USE_BANG
|
#ifdef USE_BANG
|
||||||
|
@ -311,7 +314,9 @@ void init_graph_builder(py::module &m) {
|
||||||
RuntimeObj>(m, "CpuRuntime");
|
RuntimeObj>(m, "CpuRuntime");
|
||||||
#ifdef USE_CUDA
|
#ifdef USE_CUDA
|
||||||
py::class_<CudaRuntimeObj, std::shared_ptr<CudaRuntimeObj>, RuntimeObj>(
|
py::class_<CudaRuntimeObj, std::shared_ptr<CudaRuntimeObj>, RuntimeObj>(
|
||||||
m, "CudaRuntime");
|
m, "CudaRuntime")
|
||||||
|
.def(py::init<int>(), py::arg("device") = 0)
|
||||||
|
.def("init_comm", &CudaRuntimeObj::initComm);
|
||||||
#endif
|
#endif
|
||||||
#ifdef USE_BANG
|
#ifdef USE_BANG
|
||||||
py::class_<BangRuntimeObj, std::shared_ptr<BangRuntimeObj>, RuntimeObj>(
|
py::class_<BangRuntimeObj, std::shared_ptr<BangRuntimeObj>, RuntimeObj>(
|
||||||
|
@ -435,6 +440,13 @@ void init_graph_builder(py::module &m) {
|
||||||
.def("reduce_mean", &Handler::reduceMean, policy::move)
|
.def("reduce_mean", &Handler::reduceMean, policy::move)
|
||||||
.def("slice", &Handler::slice, policy::move)
|
.def("slice", &Handler::slice, policy::move)
|
||||||
.def("pad", &Handler::pad, policy::move)
|
.def("pad", &Handler::pad, policy::move)
|
||||||
|
.def("allReduceSum", &Handler::allReduceSum, policy::move)
|
||||||
|
.def("allReduceProd", &Handler::allReduceProd, policy::move)
|
||||||
|
.def("allReduceMin", &Handler::allReduceMin, policy::move)
|
||||||
|
.def("allReduceMax", &Handler::allReduceMax, policy::move)
|
||||||
|
.def("allReduceAvg", &Handler::allReduceAvg, policy::move)
|
||||||
|
.def("allGather", &Handler::allGather, policy::move)
|
||||||
|
.def("broadcast", &Handler::broadcast, policy::move)
|
||||||
.def("cast", &Handler::cast, policy::move)
|
.def("cast", &Handler::cast, policy::move)
|
||||||
.def("expand", &Handler::expand, policy::move)
|
.def("expand", &Handler::expand, policy::move)
|
||||||
.def("erf", &Handler::erf, policy::move)
|
.def("erf", &Handler::erf, policy::move)
|
||||||
|
|
|
@ -0,0 +1,46 @@
|
||||||
|
#ifdef INFINI_USE_NCCL
|
||||||
|
#include "operators/all_gather.h"
|
||||||
|
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/nccl_communicator.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
class AllGatherNCCL : public CudaKernelWithoutConfig {
|
||||||
|
public:
|
||||||
|
void compute(const Operator &_op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
auto op = as<AllGatherObj>(_op);
|
||||||
|
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||||
|
int world_size = op->getWorldSize();
|
||||||
|
// Check if world size info in operator matches runtime
|
||||||
|
IT_ASSERT(world_size == context->getCommunicator().getWorldSize());
|
||||||
|
|
||||||
|
void *input = op->getInputs(0)->getRawDataPtr<void *>();
|
||||||
|
CudaPtr output_temp =
|
||||||
|
context->getWorkspace(op->getInputs(0)->getBytes() * world_size);
|
||||||
|
// void *output = op->getOutput()->getRawDataPtr<void *>();
|
||||||
|
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||||
|
size_t bytes = op->getInputs(0)->getBytes();
|
||||||
|
size_t count = bytes / op->getDType().getSize();
|
||||||
|
|
||||||
|
ncclComm_t comm =
|
||||||
|
dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator())
|
||||||
|
.getNcclComm();
|
||||||
|
// TODO: Using default stream 0 for now.
|
||||||
|
checkNcclError(
|
||||||
|
ncclAllGather(input, output_temp, count, ncclFloat, comm, 0));
|
||||||
|
|
||||||
|
for (int i = 0; i < world_size; ++i) {
|
||||||
|
Tensor output = op->getOutput(i);
|
||||||
|
context->copyBlobInsideRuntime(
|
||||||
|
output->getRawDataPtr<float *>(),
|
||||||
|
static_cast<float *>(output_temp) + i * count, bytes);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::AllGather, DataType::Float32,
|
||||||
|
AllGatherNCCL, "AllGather_NCCL_CUDA_Float32");
|
||||||
|
} // namespace infini
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,58 @@
|
||||||
|
#ifdef INFINI_USE_NCCL
|
||||||
|
#include "operators/all_reduce.h"
|
||||||
|
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/nccl_communicator.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
class AllReduceNCCL : public CudaKernelWithoutConfig {
|
||||||
|
public:
|
||||||
|
void compute(const Operator &_op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
auto op = as<AllReduceBaseObj>(_op);
|
||||||
|
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||||
|
void *input = op->getInputs(0)->getRawDataPtr<void *>();
|
||||||
|
void *output = op->getOutput()->getRawDataPtr<void *>();
|
||||||
|
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||||
|
size_t count = op->getInputs(0)->getBytes() / op->getDType().getSize();
|
||||||
|
|
||||||
|
ncclComm_t comm =
|
||||||
|
dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator())
|
||||||
|
.getNcclComm();
|
||||||
|
// TODO: Using default stream 0 for now.
|
||||||
|
checkNcclError(ncclAllReduce(input, output, count, ncclFloat,
|
||||||
|
getRedOp(), comm, 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual ncclRedOp_t getRedOp() const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class AllReduceSumNCCL : public AllReduceNCCL {
|
||||||
|
ncclRedOp_t getRedOp() const override { return ncclSum; }
|
||||||
|
};
|
||||||
|
class AllReduceProdNCCL : public AllReduceNCCL {
|
||||||
|
ncclRedOp_t getRedOp() const override { return ncclProd; }
|
||||||
|
};
|
||||||
|
class AllReduceMinNCCL : public AllReduceNCCL {
|
||||||
|
ncclRedOp_t getRedOp() const override { return ncclMin; }
|
||||||
|
};
|
||||||
|
class AllReduceMaxNCCL : public AllReduceNCCL {
|
||||||
|
ncclRedOp_t getRedOp() const override { return ncclMax; }
|
||||||
|
};
|
||||||
|
class AllReduceAvgNCCL : public AllReduceNCCL {
|
||||||
|
ncclRedOp_t getRedOp() const override { return ncclAvg; }
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceSum, DataType::Float32,
|
||||||
|
AllReduceSumNCCL, "AllReduce_Sum_NCCL_CUDA_Float32");
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceProd, DataType::Float32,
|
||||||
|
AllReduceProdNCCL, "AllReduce_Prod_NCCL_CUDA_Float32");
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceMin, DataType::Float32,
|
||||||
|
AllReduceMinNCCL, "AllReduce_Min_NCCL_CUDA_Float32");
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceMax, DataType::Float32,
|
||||||
|
AllReduceMaxNCCL, "AllReduce_Max_NCCL_CUDA_Float32");
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceAvg, DataType::Float32,
|
||||||
|
AllReduceAvgNCCL, "AllReduce_Avg_NCCL_CUDA_Float32");
|
||||||
|
|
||||||
|
} // namespace infini
|
||||||
|
#endif
|
|
@ -0,0 +1,32 @@
|
||||||
|
#ifdef INFINI_USE_NCCL
|
||||||
|
#include "operators/broadcast.h"
|
||||||
|
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/nccl_communicator.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
class BroadcastNCCL : public CudaKernelWithoutConfig {
|
||||||
|
public:
|
||||||
|
void compute(const Operator &_op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
auto op = as<BroadcastObj>(_op);
|
||||||
|
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||||
|
void *input = op->getInputs(0)->getRawDataPtr<void *>();
|
||||||
|
void *output = op->getOutput()->getRawDataPtr<void *>();
|
||||||
|
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||||
|
size_t count = op->getInputs(0)->getBytes() / op->getDType().getSize();
|
||||||
|
|
||||||
|
ncclComm_t comm =
|
||||||
|
dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator())
|
||||||
|
.getNcclComm();
|
||||||
|
// TODO: Using default stream 0 for now.
|
||||||
|
checkNcclError(ncclBroadcast(input, output, count, ncclFloat,
|
||||||
|
op->getRoot(), comm, 0));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::Broadcast, DataType::Float32,
|
||||||
|
BroadcastNCCL, "Broadcast_NCCL_CUDA_Float32");
|
||||||
|
} // namespace infini
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,49 @@
|
||||||
|
#include "operators/all_gather.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
AllGatherObj::AllGatherObj(GraphObj *graph, Tensor input,
|
||||||
|
std::optional<TensorVec> outputs, int world_size)
|
||||||
|
: OperatorObj(
|
||||||
|
OpType::AllGather, {input},
|
||||||
|
((!outputs) ? TensorVec(world_size, nullptr) : std::move(*outputs))),
|
||||||
|
world_size(world_size) {
|
||||||
|
IT_ASSERT(checkValid(graph));
|
||||||
|
}
|
||||||
|
|
||||||
|
optional<vector<Shape>>
|
||||||
|
AllGatherObj::inferShape(const TensorVec &inputs) const {
|
||||||
|
Shape input_shape = inputs[0]->getDims();
|
||||||
|
vector<Shape> output_shapes(getWorldSize(), input_shape);
|
||||||
|
return output_shapes;
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<DataType> AllGatherObj::inferDataType(const TensorVec &inputs) const {
|
||||||
|
return vector<DataType>(world_size, inputs[0]->getDType());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string AllGatherObj::toString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "AllGather"
|
||||||
|
<< "[" << getGuid() << "]";
|
||||||
|
os << "(";
|
||||||
|
os << vecToString(inputs[0]->getDims()) << ",";
|
||||||
|
os << "input=" << inputs[0]->getGuid() << ",";
|
||||||
|
os << "output=";
|
||||||
|
for (auto i = 0; i < world_size; i++)
|
||||||
|
os << outputs[i]->getGuid() << ",";
|
||||||
|
os << ")";
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<int> AllGatherObj::getWorkloadVector() const {
|
||||||
|
vector<int> ret{type.underlying()};
|
||||||
|
const Shape shape = inputs[0]->getDims();
|
||||||
|
ret.insert(ret.end(), shape.begin(), shape.end());
|
||||||
|
ret.emplace_back(world_size);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<int> AllGatherObj::getOpAttrVector() const {
|
||||||
|
return {type.underlying(), world_size};
|
||||||
|
}
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,45 @@
|
||||||
|
#include "operators/all_reduce.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
AllReduceBaseObj::AllReduceBaseObj(GraphObj *graph, OpType opType, Tensor input,
|
||||||
|
Tensor output)
|
||||||
|
: OperatorObj(opType, {input}, {output}) {
|
||||||
|
IT_ASSERT(checkValid(graph));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string AllReduceBaseObj::toString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << type.toString() << "[" << getGuid() << "]";
|
||||||
|
os << "(";
|
||||||
|
os << vecToString(inputs[0]->getDims()) << ",";
|
||||||
|
os << "input=" << inputs[0]->getGuid() << ",";
|
||||||
|
os << "output=" << outputs[0]->getGuid() << ",";
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<int> AllReduceBaseObj::getWorkloadVector() const {
|
||||||
|
vector<int> ret{type.underlying()};
|
||||||
|
const Shape shape = outputs[0]->getDims();
|
||||||
|
ret.insert(ret.end(), shape.begin(), shape.end());
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<int> AllReduceBaseObj::getOpAttrVector() const {
|
||||||
|
return {type.underlying()};
|
||||||
|
}
|
||||||
|
|
||||||
|
AllReduceSumObj::AllReduceSumObj(GraphObj *graph, Tensor input, Tensor output)
|
||||||
|
: AllReduceBaseObj(graph, OpType::AllReduceSum, input, output) {}
|
||||||
|
|
||||||
|
AllReduceProdObj::AllReduceProdObj(GraphObj *graph, Tensor input, Tensor output)
|
||||||
|
: AllReduceBaseObj(graph, OpType::AllReduceProd, input, output) {}
|
||||||
|
|
||||||
|
AllReduceMinObj::AllReduceMinObj(GraphObj *graph, Tensor input, Tensor output)
|
||||||
|
: AllReduceBaseObj(graph, OpType::AllReduceMin, input, output) {}
|
||||||
|
|
||||||
|
AllReduceMaxObj::AllReduceMaxObj(GraphObj *graph, Tensor input, Tensor output)
|
||||||
|
: AllReduceBaseObj(graph, OpType::AllReduceMax, input, output) {}
|
||||||
|
|
||||||
|
AllReduceAvgObj::AllReduceAvgObj(GraphObj *graph, Tensor input, Tensor output)
|
||||||
|
: AllReduceBaseObj(graph, OpType::AllReduceAvg, input, output) {}
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,33 @@
|
||||||
|
#include "operators/broadcast.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
BroadcastObj::BroadcastObj(GraphObj *graph, Tensor input, Tensor output,
|
||||||
|
int root)
|
||||||
|
: OperatorObj(OpType::Broadcast, {input}, {output}), root(root) {
|
||||||
|
IT_ASSERT(checkValid(graph));
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<int> BroadcastObj::getWorkloadVector() const {
|
||||||
|
vector<int> ret{type.underlying()};
|
||||||
|
const Shape shape = inputs[0]->getDims();
|
||||||
|
ret.insert(ret.end(), shape.begin(), shape.end());
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<int> BroadcastObj::getOpAttrVector() const {
|
||||||
|
return {type.underlying()};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string BroadcastObj::toString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "Broadcast"
|
||||||
|
<< "[" << getGuid() << "]";
|
||||||
|
os << "(";
|
||||||
|
os << vecToString(inputs[0]->getDims()) << ",";
|
||||||
|
os << "input=" << inputs[0]->getGuid() << ",";
|
||||||
|
os << "output=" << outputs[0]->getGuid() << ",";
|
||||||
|
os << "root=" << root;
|
||||||
|
os << ")";
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,55 @@
|
||||||
|
#ifdef INFINI_USE_NCCL
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/nccl_communicator.h"
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
static int WORLD_SIZE = 2;
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
void allReduceSum(float *data, int deviceId) {
|
||||||
|
// Create Runtime and setup communication
|
||||||
|
CudaRuntimeObj *cuda_runtime = new CudaRuntimeObj(deviceId);
|
||||||
|
int rank = deviceId;
|
||||||
|
cuda_runtime->initComm("test_nccl_comm", WORLD_SIZE, rank);
|
||||||
|
ncclComm_t comm =
|
||||||
|
dynamic_cast<NcclCommunicatorObj &>(cuda_runtime->getCommunicator())
|
||||||
|
.getNcclComm();
|
||||||
|
|
||||||
|
// Copy data
|
||||||
|
float *data_gpu;
|
||||||
|
checkCudaError(cudaMalloc(&data_gpu, sizeof(float)));
|
||||||
|
checkCudaError(
|
||||||
|
cudaMemcpy(data_gpu, data, sizeof(float), cudaMemcpyHostToDevice));
|
||||||
|
|
||||||
|
// Do AllReduce
|
||||||
|
checkNcclError(
|
||||||
|
ncclAllReduce(data_gpu, data_gpu, 1, ncclFloat, ncclSum, comm, 0));
|
||||||
|
|
||||||
|
// Copy data back and sync device
|
||||||
|
checkCudaError(
|
||||||
|
cudaMemcpy(data, data_gpu, sizeof(float), cudaMemcpyDeviceToHost));
|
||||||
|
checkCudaError(cudaDeviceSynchronize());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup communication between 2 threads, each controlling 1 GPU.
|
||||||
|
// Do AllReduce Sum on {1.0, 4.0}. Results should be {5.0, 5.0}.
|
||||||
|
TEST(NCCL, multi_gpu_communication) {
|
||||||
|
int num_threads = WORLD_SIZE;
|
||||||
|
float data[] = {1.0, 4.0};
|
||||||
|
|
||||||
|
std::vector<std::thread> threads;
|
||||||
|
for (int gpu = 0; gpu < num_threads; ++gpu) {
|
||||||
|
threads.emplace_back(allReduceSum, &data[gpu], gpu);
|
||||||
|
}
|
||||||
|
for (auto &thread : threads) {
|
||||||
|
thread.join();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < num_threads; ++i) {
|
||||||
|
ASSERT_EQ(data[i], 5.0f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
||||||
|
#endif
|
|
@ -0,0 +1,51 @@
|
||||||
|
#ifdef INFINI_USE_NCCL
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/cuda_utility.h"
|
||||||
|
#include "operators/all_gather.h"
|
||||||
|
#include "test.h"
|
||||||
|
#include <nccl.h>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
static int WORLD_SIZE = 2;
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
void allGather(const string taskName, int deviceID, vector<float> data,
|
||||||
|
vector<vector<float>> ans) {
|
||||||
|
// Create Runtimes and initiate communication
|
||||||
|
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
Runtime cudaRuntime = make_ref<CudaRuntimeObj>(deviceID);
|
||||||
|
cudaRuntime->initComm(taskName, WORLD_SIZE, deviceID);
|
||||||
|
// Create Graph and insert allReduce operation
|
||||||
|
Graph g = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
auto input =
|
||||||
|
g->addTensor(Shape{static_cast<int>(data.size())}, DataType::Float32);
|
||||||
|
auto op = g->addOp<AllGatherObj>(input, std::nullopt, WORLD_SIZE);
|
||||||
|
// Copy data from CPU to GPU
|
||||||
|
g->dataMalloc();
|
||||||
|
input->copyin(data);
|
||||||
|
// Run operation
|
||||||
|
cudaRuntime->run(g);
|
||||||
|
// Copy output from GPU to CPU
|
||||||
|
for (int i = 0; i < WORLD_SIZE; ++i) {
|
||||||
|
auto result = op->getOutputs()[i]->clone(cpuRuntime);
|
||||||
|
EXPECT_TRUE(result->equalData(ans[i]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CUDA_AllGather, run) {
|
||||||
|
vector<float> data[2] = {{2., 3.}, {5., 6.}};
|
||||||
|
vector<vector<float>> ans = {{2., 3.}, {5., 6.}};
|
||||||
|
|
||||||
|
std::vector<std::thread> threads;
|
||||||
|
for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) {
|
||||||
|
threads.emplace_back(allGather, "test_all_gather", gpu, data[gpu], ans);
|
||||||
|
}
|
||||||
|
for (auto &thread : threads) {
|
||||||
|
thread.join();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace infini
|
||||||
|
#endif
|
|
@ -0,0 +1,109 @@
|
||||||
|
#ifdef INFINI_USE_NCCL
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/cuda_utility.h"
|
||||||
|
#include "operators/all_reduce.h"
|
||||||
|
#include "test.h"
|
||||||
|
#include <nccl.h>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
static int WORLD_SIZE = 2;
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
template <typename OperatorObj>
|
||||||
|
void allReduce(const string taskName, int deviceID, vector<float> data,
|
||||||
|
vector<float> ans) {
|
||||||
|
// Create Runtimes and initiate communication
|
||||||
|
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
Runtime cudaRuntime = make_ref<CudaRuntimeObj>(deviceID);
|
||||||
|
cudaRuntime->initComm(taskName, WORLD_SIZE, deviceID);
|
||||||
|
// Create Graph and insert allReduce operation
|
||||||
|
Graph g = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
auto input =
|
||||||
|
g->addTensor(Shape{static_cast<int>(data.size())}, DataType::Float32);
|
||||||
|
auto op = g->addOp<OperatorObj>(input, nullptr);
|
||||||
|
// Copy data from CPU to GPU
|
||||||
|
g->dataMalloc();
|
||||||
|
input->copyin(data);
|
||||||
|
// Run operation
|
||||||
|
cudaRuntime->run(g);
|
||||||
|
// Copy output from GPU to CPU
|
||||||
|
auto result = op->getOutput()->clone(cpuRuntime);
|
||||||
|
|
||||||
|
EXPECT_TRUE(result->equalData(ans));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CUDA_AllReduce, sum) {
|
||||||
|
vector<float> data[2] = {{2., 3.}, {5., 6.}};
|
||||||
|
vector<float> ans = {7., 9.};
|
||||||
|
|
||||||
|
std::vector<std::thread> threads;
|
||||||
|
for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) {
|
||||||
|
threads.emplace_back(allReduce<AllReduceSumObj>, "test_allreduce_sum",
|
||||||
|
gpu, data[gpu], ans);
|
||||||
|
}
|
||||||
|
for (auto &thread : threads) {
|
||||||
|
thread.join();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CUDA_AllReduce, prod) {
|
||||||
|
vector<float> data[2] = {{2., 3.}, {5., 6.}};
|
||||||
|
vector<float> ans = {10., 18.};
|
||||||
|
|
||||||
|
std::vector<std::thread> threads;
|
||||||
|
for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) {
|
||||||
|
threads.emplace_back(allReduce<AllReduceProdObj>, "test_allreduce_prod",
|
||||||
|
gpu, data[gpu], ans);
|
||||||
|
}
|
||||||
|
for (auto &thread : threads) {
|
||||||
|
thread.join();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CUDA_AllReduce, min) {
|
||||||
|
vector<float> data[2] = {{2., 3.}, {5., 6.}};
|
||||||
|
vector<float> ans = {2., 3.};
|
||||||
|
|
||||||
|
std::vector<std::thread> threads;
|
||||||
|
for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) {
|
||||||
|
threads.emplace_back(allReduce<AllReduceMinObj>, "test_allreduce_min",
|
||||||
|
gpu, data[gpu], ans);
|
||||||
|
}
|
||||||
|
for (auto &thread : threads) {
|
||||||
|
thread.join();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CUDA_AllReduce, max) {
|
||||||
|
vector<float> data[2] = {{2., 3.}, {5., 6.}};
|
||||||
|
vector<float> ans = {5., 6.};
|
||||||
|
|
||||||
|
std::vector<std::thread> threads;
|
||||||
|
for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) {
|
||||||
|
threads.emplace_back(allReduce<AllReduceMaxObj>, "test_allreduce_max",
|
||||||
|
gpu, data[gpu], ans);
|
||||||
|
}
|
||||||
|
for (auto &thread : threads) {
|
||||||
|
thread.join();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CUDA_AllReduce, avg) {
|
||||||
|
vector<float> data[2] = {{2., 3.}, {5., 6.}};
|
||||||
|
vector<float> ans = {3.5, 4.5};
|
||||||
|
|
||||||
|
std::vector<std::thread> threads;
|
||||||
|
for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) {
|
||||||
|
threads.emplace_back(allReduce<AllReduceAvgObj>, "test_allreduce_avg",
|
||||||
|
gpu, data[gpu], ans);
|
||||||
|
}
|
||||||
|
for (auto &thread : threads) {
|
||||||
|
thread.join();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
||||||
|
#endif
|
|
@ -0,0 +1,56 @@
|
||||||
|
#ifdef INFINI_USE_NCCL
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/cuda_utility.h"
|
||||||
|
#include "operators/broadcast.h"
|
||||||
|
#include "test.h"
|
||||||
|
#include <nccl.h>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
static int WORLD_SIZE = 2;
|
||||||
|
static int root = 0;
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
void broadcast(const string taskName, int deviceID, vector<float> data,
|
||||||
|
vector<float> ans) {
|
||||||
|
// Create Runtimes and initiate communication
|
||||||
|
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
Runtime cudaRuntime = make_ref<CudaRuntimeObj>(deviceID);
|
||||||
|
cudaRuntime->initComm(taskName, WORLD_SIZE, deviceID);
|
||||||
|
// Create Graph and insert allReduce operation
|
||||||
|
Graph g = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
auto input =
|
||||||
|
g->addTensor(Shape{static_cast<int>(data.size())}, DataType::Float32);
|
||||||
|
auto op = g->addOp<BroadcastObj>(input, nullptr, root);
|
||||||
|
// Copy data from CPU to GPU
|
||||||
|
g->dataMalloc();
|
||||||
|
// Only rank 0 has the data
|
||||||
|
if (deviceID == root) {
|
||||||
|
input->copyin(data);
|
||||||
|
}
|
||||||
|
// Run broadcast operation
|
||||||
|
cudaRuntime->run(g);
|
||||||
|
// Copy output from GPU to CPU
|
||||||
|
auto result = op->getOutput()->clone(cpuRuntime);
|
||||||
|
|
||||||
|
EXPECT_TRUE(result->equalData(ans));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CUDA_Broadcast, run) {
|
||||||
|
// Only 1 device gets data. Every rank should have the same data after
|
||||||
|
// broadcast.
|
||||||
|
vector<float> data = {2., 3., 5., 6.};
|
||||||
|
vector<float> ans = {2., 3., 5., 6.};
|
||||||
|
|
||||||
|
std::vector<std::thread> threads;
|
||||||
|
for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) {
|
||||||
|
threads.emplace_back(broadcast, "test_broadcast", gpu, data, ans);
|
||||||
|
}
|
||||||
|
for (auto &thread : threads) {
|
||||||
|
thread.join();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace infini
|
||||||
|
#endif
|
|
@ -0,0 +1,23 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "operators/all_gather.h"
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
TEST(AllGather, ShapeTypeInfer) {
|
||||||
|
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
int world_size = 8;
|
||||||
|
{
|
||||||
|
Shape shape = {1, 3, 2, 4};
|
||||||
|
Graph g = make_ref<GraphObj>(runtime);
|
||||||
|
Tensor input = g->addTensor(shape, DataType::Float32);
|
||||||
|
auto op = g->addOp<AllGatherObj>(input, std::nullopt, world_size);
|
||||||
|
EXPECT_EQ(op->getOpType(), OpType::AllGather);
|
||||||
|
EXPECT_EQ(op->numOutputs(), world_size);
|
||||||
|
for (int i = 0; i < world_size; ++i) {
|
||||||
|
EXPECT_EQ(op->getOutput(i)->getDims(), shape);
|
||||||
|
EXPECT_EQ(op->getOutput(i)->getDType(), DataType::Float32);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,50 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "operators/all_reduce.h"
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
TEST(AllReuce, ShapeTypeInfer) {
|
||||||
|
auto runtime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
{
|
||||||
|
Graph g = make_ref<GraphObj>(runtime);
|
||||||
|
Tensor input = g->addTensor({1, 3, 2, 4}, DataType::Float32);
|
||||||
|
auto op = g->addOp<AllReduceSumObj>(input, nullptr);
|
||||||
|
EXPECT_EQ(op->getOpType(), OpType::AllReduceSum);
|
||||||
|
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 4}));
|
||||||
|
EXPECT_EQ(op->getOutput()->getDType(), DataType::Float32);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
Graph g = make_ref<GraphObj>(runtime);
|
||||||
|
Tensor input = g->addTensor({1, 3, 2, 4}, DataType::Float32);
|
||||||
|
auto op = g->addOp<AllReduceProdObj>(input, nullptr);
|
||||||
|
EXPECT_EQ(op->getOpType(), OpType::AllReduceProd);
|
||||||
|
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 4}));
|
||||||
|
EXPECT_EQ(op->getOutput()->getDType(), DataType::Float32);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
Graph g = make_ref<GraphObj>(runtime);
|
||||||
|
Tensor input = g->addTensor({1, 3, 2, 4}, DataType::Float32);
|
||||||
|
auto op = g->addOp<AllReduceMinObj>(input, nullptr);
|
||||||
|
EXPECT_EQ(op->getOpType(), OpType::AllReduceMin);
|
||||||
|
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 4}));
|
||||||
|
EXPECT_EQ(op->getOutput()->getDType(), DataType::Float32);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
Graph g = make_ref<GraphObj>(runtime);
|
||||||
|
Tensor input = g->addTensor({1, 3, 2, 4}, DataType::Float32);
|
||||||
|
auto op = g->addOp<AllReduceMaxObj>(input, nullptr);
|
||||||
|
EXPECT_EQ(op->getOpType(), OpType::AllReduceMax);
|
||||||
|
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 4}));
|
||||||
|
EXPECT_EQ(op->getOutput()->getDType(), DataType::Float32);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
Graph g = make_ref<GraphObj>(runtime);
|
||||||
|
Tensor input = g->addTensor({1, 3, 2, 4}, DataType::Float32);
|
||||||
|
auto op = g->addOp<AllReduceAvgObj>(input, nullptr);
|
||||||
|
EXPECT_EQ(op->getOpType(), OpType::AllReduceAvg);
|
||||||
|
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 4}));
|
||||||
|
EXPECT_EQ(op->getOutput()->getDType(), DataType::Float32);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,19 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "operators/broadcast.h"
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
TEST(Broadcast, ShapeTypeInfer) {
|
||||||
|
auto runtime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
int root = 0;
|
||||||
|
{
|
||||||
|
Graph g = make_ref<GraphObj>(runtime);
|
||||||
|
Tensor input = g->addTensor({1, 3, 2, 4}, DataType::Float32);
|
||||||
|
auto op = g->addOp<BroadcastObj>(input, nullptr, root);
|
||||||
|
EXPECT_EQ(op->getOpType(), OpType::Broadcast);
|
||||||
|
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 4}));
|
||||||
|
EXPECT_EQ(op->getOutput()->getDType(), DataType::Float32);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace infini
|
Loading…
Reference in New Issue