forked from jiuyuan/InfiniTensor
Bang cncl (#163)
* MLU CNCL base * add FindCNCL.cmake, not find -lcncl * bangPrintFloat not find * docker:make sucessful, test error * delete net file and onnxtest.py * init * fix cncl * format * fix * format * fix cncl * run dist gpt2 on mlu * format * fix import error on mlu docker * run llama single card * run distributed llama2 * add test for slice/reduce on mlu * fix cncl related test * fix format * format * delete comments * change GPU to MLU * MLU CNCL base * add FindCNCL.cmake, not find -lcncl * bangPrintFloat not find * docker:make sucessful, test error * delete net file and onnxtest.py * init * fix cncl * format * fix * format * fix cncl * run dist gpt2 on mlu * format * fix import error on mlu docker * run llama single card * run distributed llama2 * add test for slice/reduce on mlu * fix cncl related test * fix format * format * delete comments * change GPU to MLU * modify launch script * fix name * fix format * fix gather * format python script --------- Co-authored-by: xgqdut2016 <kenan_gewei@163.com> Co-authored-by: Bolun <chamberlain0w0@gmail.com> Co-authored-by: Bolun Zhang <48948016+Chamberlain0w0@users.noreply.github.com>
This commit is contained in:
parent
83f1de93d0
commit
42032356fb
|
@ -13,7 +13,7 @@ if(USE_CUDA)
|
||||||
message("CMake 3.18 or higher is required for setting CUDAToolkit")
|
message("CMake 3.18 or higher is required for setting CUDAToolkit")
|
||||||
cmake_minimum_required(VERSION 3.18) # FindCUDAToolkit
|
cmake_minimum_required(VERSION 3.18) # FindCUDAToolkit
|
||||||
else()
|
else()
|
||||||
cmake_minimum_required(VERSION 3.12)
|
cmake_minimum_required(VERSION 3.17)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
include(CMakeDependentOption)
|
include(CMakeDependentOption)
|
||||||
|
@ -245,6 +245,7 @@ if(USE_BANG)
|
||||||
find_library(CAMBRICON_CNNL libcnnl.so "${NEUWARE_HOME}/lib64")
|
find_library(CAMBRICON_CNNL libcnnl.so "${NEUWARE_HOME}/lib64")
|
||||||
find_library(CAMBRICON_CNRT libcnrt.so "${NEUWARE_HOME}/lib64")
|
find_library(CAMBRICON_CNRT libcnrt.so "${NEUWARE_HOME}/lib64")
|
||||||
find_library(CAMBRICON_CNDRV libcndrv.so "${NEUWARE_HOME}/lib64")
|
find_library(CAMBRICON_CNDRV libcndrv.so "${NEUWARE_HOME}/lib64")
|
||||||
|
find_library(CAMBRICON_CNCL libcncl.so "${NEUWARE_HOME}/lib64")
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror")
|
||||||
|
|
||||||
if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH}))
|
if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH}))
|
||||||
|
@ -261,7 +262,13 @@ if(USE_BANG)
|
||||||
# BangC Kernels
|
# BangC Kernels
|
||||||
################################################################################
|
################################################################################
|
||||||
|
|
||||||
target_link_libraries(InfiniTensor ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++)
|
target_link_libraries(InfiniTensor ${CAMBRICON_CNCL} ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++)
|
||||||
|
if (BUILD_DIST)
|
||||||
|
message(STATUS "Add BUILD_DIST, use CNCL with BANG")
|
||||||
|
|
||||||
|
add_compile_definitions(INFINI_USE_CNCL=1)
|
||||||
|
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(USE_KUNLUN)
|
if(USE_KUNLUN)
|
||||||
|
@ -324,6 +331,7 @@ if(BUILD_TEST)
|
||||||
endif()
|
endif()
|
||||||
if (USE_BANG)
|
if (USE_BANG)
|
||||||
build_test(test/kernels/bang/*.cc)
|
build_test(test/kernels/bang/*.cc)
|
||||||
|
build_test(test/bang/*.cc)
|
||||||
endif()
|
endif()
|
||||||
if (USE_KUNLUN)
|
if (USE_KUNLUN)
|
||||||
build_test(test/kernels/kunlun/*.cc)
|
build_test(test/kernels/kunlun/*.cc)
|
||||||
|
|
1
Makefile
1
Makefile
|
@ -29,6 +29,7 @@ CMAKE_OPT += -DUSE_BANG=$(BANG)
|
||||||
CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN)
|
CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN)
|
||||||
CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE)
|
CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE)
|
||||||
CMAKE_OPT += -DBUILD_TEST=$(TEST)
|
CMAKE_OPT += -DBUILD_TEST=$(TEST)
|
||||||
|
CMAKE_OPT += -DBUILD_DIST=ON
|
||||||
CMAKE_OPT += -DBUILD_NNET=$(NNET)
|
CMAKE_OPT += -DBUILD_NNET=$(NNET)
|
||||||
|
|
||||||
ifeq ($(INTELCPU), ON)
|
ifeq ($(INTELCPU), ON)
|
||||||
|
|
|
@ -0,0 +1,76 @@
|
||||||
|
SET(CNCL_LIB_SEARCH_PATHS $ENV{NEUWARE_HOME}/lib64)
|
||||||
|
SET(CNCL_INCLUDE_SEARCH_PATHS $ENV{NEUWARE_HOME}/include)
|
||||||
|
|
||||||
|
set(CNCL_INCLUDE_DIR $ENV{NEUWARE_HOME}/include)
|
||||||
|
set(CNCL_LIB_DIR $ENV{NEUWARE_HOME}/lib64)
|
||||||
|
set(CNCL_VERSION $ENV{CNCL_VERSION} CACHE STRING "Version of CNCL to build with")
|
||||||
|
|
||||||
|
if ($ENV{CNCL_ROOT_DIR})
|
||||||
|
message(WARNING "CNCL_ROOT_DIR is deprecated. Please set CNCL_ROOT instead.")
|
||||||
|
endif()
|
||||||
|
list(APPEND CNCL_ROOT $ENV{CNCL_ROOT_DIR} ${MLU_TOOLKIT_ROOT_DIR})
|
||||||
|
# Compatible layer for CMake <3.12. CNCL_ROOT will be accounted in for searching paths and libraries for CMake >=3.12.
|
||||||
|
list(APPEND CMAKE_PREFIX_PATH ${CNCL_ROOT})
|
||||||
|
|
||||||
|
find_path(CNCL_INCLUDE_DIRS
|
||||||
|
NAMES cncl.h
|
||||||
|
HINTS ${CNCL_INCLUDE_DIR})
|
||||||
|
|
||||||
|
if (USE_STATIC_CNCL)
|
||||||
|
MESSAGE(STATUS "USE_STATIC_CNCL is set. Linking with static CNCL library.")
|
||||||
|
SET(CNCL_LIBNAME "CNCL_static")
|
||||||
|
if (CNCL_VERSION) # Prefer the versioned library if a specific CNCL version is specified
|
||||||
|
set(CMAKE_FIND_LIBRARY_SUFFIXES ".a.${CNCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES})
|
||||||
|
endif()
|
||||||
|
else()
|
||||||
|
SET(CNCL_LIBNAME "cncl")
|
||||||
|
if (CNCL_VERSION) # Prefer the versioned library if a specific CNCL version is specified
|
||||||
|
set(CMAKE_FIND_LIBRARY_SUFFIXES ".so.${CNCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES})
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
find_library(CNCL_LIBRARIES
|
||||||
|
NAMES ${CNCL_LIBNAME}
|
||||||
|
HINTS ${CNCL_LIB_DIR})
|
||||||
|
|
||||||
|
include(FindPackageHandleStandardArgs)
|
||||||
|
find_package_handle_standard_args(CNCL DEFAULT_MSG CNCL_INCLUDE_DIRS CNCL_LIBRARIES)
|
||||||
|
|
||||||
|
if(CNCL_FOUND) # obtaining CNCL version and some sanity checks
|
||||||
|
set (CNCL_HEADER_FILE "${CNCL_INCLUDE_DIRS}/cncl.h")
|
||||||
|
message (STATUS "Determining CNCL version from ${CNCL_HEADER_FILE}...")
|
||||||
|
set (OLD_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES})
|
||||||
|
list (APPEND CMAKE_REQUIRED_INCLUDES ${CNCL_INCLUDE_DIRS})
|
||||||
|
include(CheckCXXSymbolExists)
|
||||||
|
check_cxx_symbol_exists(CNCL_VERSION_CODE CNCL.h CNCL_VERSION_DEFINED)
|
||||||
|
|
||||||
|
if (CNCL_VERSION_DEFINED)
|
||||||
|
set(file "${PROJECT_BINARY_DIR}/detect_cncl_version.cc")
|
||||||
|
file(WRITE ${file} "
|
||||||
|
#include <iostream>
|
||||||
|
#include <cncl.h>
|
||||||
|
int main()
|
||||||
|
{
|
||||||
|
std::cout << CNCL_MAJOR << '.' << CNCL_MINOR << '.' << CNCL_PATCH << std::endl;
|
||||||
|
int x;
|
||||||
|
CNCLGetVersion(&x);
|
||||||
|
return x == CNCL_VERSION_CODE;
|
||||||
|
}
|
||||||
|
")
|
||||||
|
try_run(CNCL_VERSION_MATCHED compile_result ${PROJECT_BINARY_DIR} ${file}
|
||||||
|
RUN_OUTPUT_VARIABLE CNCL_VERSION_FROM_HEADER
|
||||||
|
CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CNCL_INCLUDE_DIRS}"
|
||||||
|
LINK_LIBRARIES ${CNCL_LIBRARIES})
|
||||||
|
if (NOT CNCL_VERSION_MATCHED)
|
||||||
|
message(FATAL_ERROR "Found CNCL header version and library version do not match! \
|
||||||
|
(include: ${CNCL_INCLUDE_DIRS}, library: ${CNCL_LIBRARIES}) Please set CNCL_INCLUDE_DIR and CNCL_LIB_DIR manually.")
|
||||||
|
endif()
|
||||||
|
message(STATUS "CNCL version: ${CNCL_VERSION_FROM_HEADER}")
|
||||||
|
else()
|
||||||
|
# message(STATUS "CNCL version < 2.3.5-5")
|
||||||
|
endif ()
|
||||||
|
set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES})
|
||||||
|
|
||||||
|
message(STATUS "Found CNCL (include: ${CNCL_INCLUDE_DIRS}, library: ${CNCL_LIBRARIES})")
|
||||||
|
mark_as_advanced(CNCL_ROOT_DIR CNCL_INCLUDE_DIRS CNCL_LIBRARIES)
|
||||||
|
endif()
|
|
@ -0,0 +1,196 @@
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import multiprocessing as mp
|
||||||
|
from pyinfinitensor.onnx import OnnxStub, backend
|
||||||
|
import onnx
|
||||||
|
from onnx.shape_inference import infer_shapes_path
|
||||||
|
import numpy as np
|
||||||
|
from parallel_opt import parallel_model
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="launch distributed infinitensor")
|
||||||
|
parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
|
||||||
|
parser.add_argument(
|
||||||
|
"--nproc_per_node", type=int, default=2, help="number of processes per node"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--name", type=str, default="test", help="name of this instance."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model", type=str, default="/data/onnx_models/llama2/llama_bs1_seq1024.onnx",
|
||||||
|
help="path to the ONNX model file."
|
||||||
|
)
|
||||||
|
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
|
||||||
|
parser.add_argument("--length", type=int, default=1, help="sequence length.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--gen_std",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="whether to generate the standard results.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
print("arg setting: ", args)
|
||||||
|
return (
|
||||||
|
args.num_nodes,
|
||||||
|
args.nproc_per_node,
|
||||||
|
args.name,
|
||||||
|
args.model,
|
||||||
|
args.batch_size,
|
||||||
|
args.length,
|
||||||
|
args.gen_std,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def run_model(model, runtime, world_size=1, rank=0, n=10):
|
||||||
|
stub = OnnxStub(model, runtime)
|
||||||
|
load_inputs(stub, world_size, rank)
|
||||||
|
# stub.tune()
|
||||||
|
stub.run()
|
||||||
|
# get outputs
|
||||||
|
time.sleep(0.01)
|
||||||
|
outputs = next(stub.outputs.values().__iter__()).copyout_numpy()
|
||||||
|
|
||||||
|
# bench
|
||||||
|
begin = time.time()
|
||||||
|
for _ in range(n):
|
||||||
|
stub.run()
|
||||||
|
end = time.time()
|
||||||
|
avg_time = (end - begin) / n
|
||||||
|
print(f"average time: {avg_time}")
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
def run_and_compare(name, model, runtime, world_size=1, rank = 0):
|
||||||
|
results = np.load(f"./data/output.npy")
|
||||||
|
outputs = run_model(model, runtime, world_size, rank)
|
||||||
|
print("answer argmax:", np.argmax(results))
|
||||||
|
print("output argmax:", np.argmax(outputs))
|
||||||
|
#np.testing.assert_allclose(outputs, results, rtol=1e-3, atol=1e-3)
|
||||||
|
getDiff(results, outputs)
|
||||||
|
|
||||||
|
|
||||||
|
def start_worker(
|
||||||
|
name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
|
||||||
|
):
|
||||||
|
dist_name = name + "_dist"
|
||||||
|
model = parallel_model(model, world_size, rank)
|
||||||
|
extern_path = f"./{dist_name}_rank{rank}.pb"
|
||||||
|
if os.path.exists(extern_path):
|
||||||
|
os.remove(extern_path)
|
||||||
|
onnx.save_model(
|
||||||
|
model,
|
||||||
|
f"./{dist_name}_rank{rank}.onnx",
|
||||||
|
save_as_external_data=True,
|
||||||
|
location=extern_path,
|
||||||
|
)
|
||||||
|
infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
|
||||||
|
runtime = backend.BangRuntime(local_rank)
|
||||||
|
# print("init comm")
|
||||||
|
runtime.init_comm(
|
||||||
|
dist_name,
|
||||||
|
world_size,
|
||||||
|
rank,
|
||||||
|
)
|
||||||
|
run_and_compare(name, model, runtime, world_size, rank)
|
||||||
|
|
||||||
|
|
||||||
|
def start_single(name, model):
|
||||||
|
runtime = backend.BangRuntime(0)
|
||||||
|
run_and_compare(name, model, runtime)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_input_output(model):
|
||||||
|
os.makedirs(os.path.dirname("./data/"), exist_ok=True)
|
||||||
|
runtime = backend.BangRuntime(0)
|
||||||
|
stub = OnnxStub(model, runtime)
|
||||||
|
position_id = 0
|
||||||
|
for i, (name, tensor) in enumerate(stub.inputs.items()):
|
||||||
|
input = tensor.copyout_numpy()
|
||||||
|
if np.issubdtype(input.dtype, np.integer):
|
||||||
|
if input.size == 1:
|
||||||
|
# input = np.array([position_id])
|
||||||
|
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
|
||||||
|
else:
|
||||||
|
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
|
||||||
|
elif input.dtype == np.bool_:
|
||||||
|
input = np.random.randint(0,2,size=input.shape) > 0
|
||||||
|
else:
|
||||||
|
if i == 0:
|
||||||
|
input = np.ones(input.shape).astype(input.dtype)
|
||||||
|
position_id = input.shape[-1] - 1
|
||||||
|
else:
|
||||||
|
input = np.random.rand(*input.shape).astype(input.dtype)
|
||||||
|
tensor.copyin_numpy(input)
|
||||||
|
np.save(f"./data/input_{i}", input)
|
||||||
|
stub.run()
|
||||||
|
time.sleep(0.01)
|
||||||
|
output = next(stub.outputs.values().__iter__()).copyout_numpy()
|
||||||
|
if np.isnan(output).any():
|
||||||
|
print("Nan in output")
|
||||||
|
np.save(f"./data/output", output)
|
||||||
|
|
||||||
|
|
||||||
|
def load_inputs(stub, world_size=1, rank=0):
|
||||||
|
for i, (name, tensor) in enumerate(stub.inputs.items()):
|
||||||
|
input = np.load(f"./data/input_{i}.npy")
|
||||||
|
if all(x == y for x,y in zip(input.shape,tensor.shape())):
|
||||||
|
tensor.copyin_numpy(input)
|
||||||
|
else:
|
||||||
|
tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
|
||||||
|
|
||||||
|
def getDiff(base, test):
|
||||||
|
absolute_diff = np.abs(np.subtract(base, test))
|
||||||
|
max_absolute_diff = np.max(absolute_diff)
|
||||||
|
|
||||||
|
baseCopy = base.astype(np.float64).ravel()
|
||||||
|
testCopy = test.astype(np.float64).ravel()
|
||||||
|
upValue = np.sum(np.abs(baseCopy - testCopy))
|
||||||
|
downValue = np.sum(np.abs(baseCopy)) + np.float64(1e-9)
|
||||||
|
max_relative_diff = upValue / downValue
|
||||||
|
print(f"Max absolute difference: {max_absolute_diff}\n"
|
||||||
|
f"Max relative difference: {max_relative_diff}")
|
||||||
|
return max_absolute_diff, max_relative_diff
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args()
|
||||||
|
|
||||||
|
model = onnx.load(model_path)
|
||||||
|
|
||||||
|
# generate standart output
|
||||||
|
if gen_std:
|
||||||
|
print("Generate inputs and outputs.")
|
||||||
|
p = mp.Process(target=generate_input_output, args=[model])
|
||||||
|
p.start()
|
||||||
|
p.join()
|
||||||
|
return
|
||||||
|
|
||||||
|
# run single process.
|
||||||
|
# use standalone process to isolate cuda.
|
||||||
|
print("run model by single MLU.")
|
||||||
|
p = mp.Process(target=start_single, args=(name, model))
|
||||||
|
p.start()
|
||||||
|
p.join()
|
||||||
|
|
||||||
|
# run distributed parallel.
|
||||||
|
world_size = nnodes * nproc_per_node
|
||||||
|
print(f"run model by {world_size} MLUs in parallel.")
|
||||||
|
workers = [
|
||||||
|
mp.Process(
|
||||||
|
target=start_worker,
|
||||||
|
args=(name, world_size, rank, rank % nproc_per_node, model),
|
||||||
|
)
|
||||||
|
for rank in range(world_size)
|
||||||
|
]
|
||||||
|
|
||||||
|
for w in workers:
|
||||||
|
w.start()
|
||||||
|
|
||||||
|
for w in workers:
|
||||||
|
w.join()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -115,7 +115,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
|
||||||
assert out_dims[s_dim] % tp_world_size == 0, out_dims
|
assert out_dims[s_dim] % tp_world_size == 0, out_dims
|
||||||
out_dims[s_dim] //= tp_world_size
|
out_dims[s_dim] //= tp_world_size
|
||||||
# if ONNX uses the same tensor for multiple Reshape Nodes, then rename it to distingush from others.
|
# if ONNX uses the same tensor for multiple Reshape Nodes, then rename it to distingush from others.
|
||||||
# node.input[1] = node.output[0] + "_shape"
|
node.input[1] = node.output[0] + "_shape"
|
||||||
data[node.input[1]] = numpy_helper.from_array(out_dims, name=node.input[1])
|
data[node.input[1]] = numpy_helper.from_array(out_dims, name=node.input[1])
|
||||||
place[node.output[0]] = Shard(s_dim)
|
place[node.output[0]] = Shard(s_dim)
|
||||||
|
|
||||||
|
|
|
@ -7,17 +7,19 @@ namespace infini {
|
||||||
class BangRuntimeObj : public RuntimeObj {
|
class BangRuntimeObj : public RuntimeObj {
|
||||||
private:
|
private:
|
||||||
cnnlHandle_t cnnl;
|
cnnlHandle_t cnnl;
|
||||||
|
cnrtQueue_t queue;
|
||||||
|
std::unique_ptr<CommunicatorObj> comm;
|
||||||
BangPtr workspace;
|
BangPtr workspace;
|
||||||
size_t workspaceSize;
|
size_t workspaceSize;
|
||||||
mutable size_t cursor;
|
mutable size_t cursor;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
BangRuntimeObj() : RuntimeObj(Device::BANG) {
|
explicit BangRuntimeObj(int deviceId = 0)
|
||||||
|
: RuntimeObj(Device::BANG, deviceId) {
|
||||||
cnInit(0);
|
cnInit(0);
|
||||||
CNdev dev;
|
CNdev dev;
|
||||||
cnDeviceGet(&dev, 0);
|
cnDeviceGet(&dev, deviceId);
|
||||||
checkBangError(cnrtSetDevice(dev));
|
checkBangError(cnrtSetDevice(dev));
|
||||||
cnrtQueue_t queue;
|
|
||||||
checkBangError(cnrtQueueCreate(&queue));
|
checkBangError(cnrtQueueCreate(&queue));
|
||||||
|
|
||||||
checkCnnlError(cnnlCreate(&cnnl));
|
checkCnnlError(cnnlCreate(&cnnl));
|
||||||
|
@ -30,6 +32,7 @@ class BangRuntimeObj : public RuntimeObj {
|
||||||
}
|
}
|
||||||
virtual ~BangRuntimeObj() {
|
virtual ~BangRuntimeObj() {
|
||||||
dealloc(workspace);
|
dealloc(workspace);
|
||||||
|
checkBangError(cnrtQueueDestroy(queue));
|
||||||
checkCnnlError(cnnlDestroy(cnnl));
|
checkCnnlError(cnnlDestroy(cnnl));
|
||||||
}
|
}
|
||||||
string toString() const override;
|
string toString() const override;
|
||||||
|
@ -73,10 +76,9 @@ class BangRuntimeObj : public RuntimeObj {
|
||||||
checkBangError(cnrtMemcpy(dst, const_cast<void *>(src), bytes,
|
checkBangError(cnrtMemcpy(dst, const_cast<void *>(src), bytes,
|
||||||
CNRT_MEM_TRANS_DIR_PEER2PEER));
|
CNRT_MEM_TRANS_DIR_PEER2PEER));
|
||||||
}
|
}
|
||||||
|
void initComm(const string &name, int worldSize, int rank) final;
|
||||||
void initComm(const string &, int, int) override { IT_TODO_HALT(); }
|
CommunicatorObj &getCommunicator() const override { return *comm; }
|
||||||
|
cnrtQueue_t getBangQueue() const { return queue; }
|
||||||
CommunicatorObj &getCommunicator() const override { IT_TODO_HALT(); }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void runWithoutSync(const Graph &graph, bool tune, bool profiling) const;
|
void runWithoutSync(const Graph &graph, bool tune, bool profiling) const;
|
||||||
|
|
|
@ -0,0 +1,79 @@
|
||||||
|
#pragma once
|
||||||
|
#include "bang_common.h"
|
||||||
|
#include "core/communicator.h"
|
||||||
|
#include <chrono>
|
||||||
|
#include <cncl.h>
|
||||||
|
#include <cnrt.h>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <filesystem>
|
||||||
|
#include <fstream>
|
||||||
|
#include <mutex>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
class CnclCommunicatorObj final : public CommunicatorObj {
|
||||||
|
private:
|
||||||
|
cnclComm_t *comms;
|
||||||
|
|
||||||
|
public:
|
||||||
|
CnclCommunicatorObj(const string &name, int worldSize, int rank)
|
||||||
|
: CommunicatorObj(worldSize, rank) {
|
||||||
|
const std::string filePath("./" + name + "_cncl_id.bin");
|
||||||
|
cnclCliqueId clique_id;
|
||||||
|
if (rank == 0) {
|
||||||
|
CNCL_CHECK(cnclGetCliqueId(&clique_id));
|
||||||
|
std::ofstream ofs(filePath, std::ios::binary);
|
||||||
|
ofs.write((char *)&clique_id, sizeof(cnclCliqueId));
|
||||||
|
|
||||||
|
} else {
|
||||||
|
auto begin = std::chrono::steady_clock::now();
|
||||||
|
while (!std::filesystem::exists(filePath)) {
|
||||||
|
auto now = std::chrono::steady_clock::now();
|
||||||
|
_IT_ASSERT_2(now < begin + std::chrono::seconds(10),
|
||||||
|
"time limit (10s) exceeded.");
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||||
|
}
|
||||||
|
std::ifstream ifs(filePath, std::ios::binary);
|
||||||
|
ifs.read((char *)&clique_id, sizeof(cnclCliqueId));
|
||||||
|
}
|
||||||
|
|
||||||
|
int num_comms = 1;
|
||||||
|
int *dev_list = new int[num_comms];
|
||||||
|
int *rank_list = new int[num_comms];
|
||||||
|
comms = new cnclComm_t[num_comms];
|
||||||
|
uint32_t num_dev = 0;
|
||||||
|
checkBangError(cnrtGetDeviceCount(&num_dev));
|
||||||
|
|
||||||
|
for (int i = 0; i < num_comms; i++) {
|
||||||
|
rank_list[i] = rank;
|
||||||
|
dev_list[i] = rank_list[i] % num_dev;
|
||||||
|
}
|
||||||
|
|
||||||
|
CNCL_CHECK(cnclInitComms(comms, num_comms, dev_list, rank_list,
|
||||||
|
worldSize, &clique_id));
|
||||||
|
|
||||||
|
if (rank == 0) {
|
||||||
|
std::filesystem::remove(filePath);
|
||||||
|
}
|
||||||
|
|
||||||
|
delete[] dev_list;
|
||||||
|
delete[] rank_list;
|
||||||
|
}
|
||||||
|
|
||||||
|
~CnclCommunicatorObj() {
|
||||||
|
CNCL_CHECK(cnclDestroyComms(comms, 1));
|
||||||
|
delete[] comms;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the actual cnclComm_t
|
||||||
|
cnclComm_t getCnclComm() { return comms[0]; }
|
||||||
|
|
||||||
|
virtual string toString() const final {
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << "CNCL communicator";
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -8,7 +8,9 @@
|
||||||
#if USE_CUDA
|
#if USE_CUDA
|
||||||
#include "cuda/cuda_runtime.h"
|
#include "cuda/cuda_runtime.h"
|
||||||
#endif
|
#endif
|
||||||
|
#if USE_BANG
|
||||||
|
#include "bang/bang_runtime.h"
|
||||||
|
#endif
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
// TODO: how to deal with this
|
// TODO: how to deal with this
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
#include "bang/bang_runtime.h"
|
#include "bang/bang_runtime.h"
|
||||||
#include "core/kernel.h"
|
#include "core/kernel.h"
|
||||||
#include "core/perf_engine.h"
|
#include "core/perf_engine.h"
|
||||||
|
#ifdef INFINI_USE_CNCL
|
||||||
|
#include "bang/cncl_communicator.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
|
@ -59,4 +62,15 @@ void BangRuntimeObj::sync() const { cnrtSyncDevice(); }
|
||||||
|
|
||||||
string BangRuntimeObj::toString() const { return "BANG Runtime"; }
|
string BangRuntimeObj::toString() const { return "BANG Runtime"; }
|
||||||
|
|
||||||
|
void BangRuntimeObj::initComm(const string &name, int worldSize, int rank) {
|
||||||
|
IT_ASSERT(worldSize > 0);
|
||||||
|
IT_ASSERT(rank >= 0);
|
||||||
|
IT_ASSERT(rank < worldSize);
|
||||||
|
IT_ASSERT(!comm) << "communicator is already initialized.";
|
||||||
|
#ifdef INFINI_USE_CNCL
|
||||||
|
comm = std::make_unique<CnclCommunicatorObj>(name, worldSize, rank);
|
||||||
|
#else
|
||||||
|
IT_TODO_HALT_MSG("Not compiled with CNCL.");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -399,7 +399,9 @@ void init_graph_builder(py::module &m) {
|
||||||
#endif
|
#endif
|
||||||
#ifdef USE_BANG
|
#ifdef USE_BANG
|
||||||
py::class_<BangRuntimeObj, std::shared_ptr<BangRuntimeObj>, RuntimeObj>(
|
py::class_<BangRuntimeObj, std::shared_ptr<BangRuntimeObj>, RuntimeObj>(
|
||||||
m, "BangRuntime");
|
m, "BangRuntime")
|
||||||
|
.def(py::init<int>(), py::arg("device") = 0)
|
||||||
|
.def("init_comm", &BangRuntimeObj::initComm);
|
||||||
#endif
|
#endif
|
||||||
#ifdef USE_KUNLUN
|
#ifdef USE_KUNLUN
|
||||||
py::class_<KUNLUNRuntimeObj, std::shared_ptr<KUNLUNRuntimeObj>, RuntimeObj>(
|
py::class_<KUNLUNRuntimeObj, std::shared_ptr<KUNLUNRuntimeObj>, RuntimeObj>(
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
#ifdef INFINI_USE_CNCL
|
||||||
|
#include "operators/all_gather.h"
|
||||||
|
#include "bang/bang_kernel_without_config.h"
|
||||||
|
#include "bang/bang_runtime.h"
|
||||||
|
#include "bang/cncl_communicator.h"
|
||||||
|
#include <thread>
|
||||||
|
namespace infini {
|
||||||
|
class AllGatherCNCL : public BangKernelWithoutConfig {
|
||||||
|
public:
|
||||||
|
void compute(const Operator &_op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
auto op = as<AllGatherObj>(_op);
|
||||||
|
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||||
|
int world_size = op->getWorldSize();
|
||||||
|
// Check if world size info in operator matches runtime
|
||||||
|
IT_ASSERT(world_size == context->getCommunicator().getWorldSize());
|
||||||
|
|
||||||
|
void *input = op->getInputs(0)->getRawDataPtr<void *>();
|
||||||
|
BangPtr output_temp =
|
||||||
|
context->getWorkspace(op->getInputs(0)->getBytes() * world_size);
|
||||||
|
// void *output = op->getOutput()->getRawDataPtr<void *>();
|
||||||
|
// IT_ASSERT(op->getDType() == DataType::Float32);
|
||||||
|
checkBangError(cnrtMalloc(&output_temp,
|
||||||
|
op->getInputs(0)->getBytes() * world_size));
|
||||||
|
size_t bytes = op->getInputs(0)->getBytes();
|
||||||
|
size_t count = bytes / op->getDType().getSize();
|
||||||
|
|
||||||
|
cnclComm_t comm =
|
||||||
|
dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
|
||||||
|
.getCnclComm();
|
||||||
|
cnrtQueue_t queue = context->getBangQueue();
|
||||||
|
CNCL_CHECK(
|
||||||
|
cnclAllGather(input, output_temp, count, cnclFloat32, comm, queue));
|
||||||
|
checkBangError(cnrtQueueSync(queue));
|
||||||
|
for (int i = 0; i < world_size; ++i) {
|
||||||
|
Tensor output = op->getOutput(i);
|
||||||
|
context->copyBlobInsideRuntime(
|
||||||
|
output->getRawDataPtr<float *>(),
|
||||||
|
static_cast<float *>(output_temp) + i * count, bytes);
|
||||||
|
}
|
||||||
|
checkBangError(cnrtFree(output_temp));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL(Device::BANG, OpType::AllGather, DataType::Float32,
|
||||||
|
AllGatherCNCL, "AllGather_CNCL_BANG_Float32");
|
||||||
|
} // namespace infini
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,53 @@
|
||||||
|
#ifdef INFINI_USE_CNCL
|
||||||
|
#include "operators/all_reduce.h"
|
||||||
|
#include "bang/bang_kernel_without_config.h"
|
||||||
|
#include "bang/bang_runtime.h"
|
||||||
|
#include "bang/cncl_communicator.h"
|
||||||
|
#include <thread>
|
||||||
|
namespace infini {
|
||||||
|
class AllReduceCNCL : public BangKernelWithoutConfig {
|
||||||
|
public:
|
||||||
|
void compute(const Operator &_op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
auto op = as<AllReduceBaseObj>(_op);
|
||||||
|
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||||
|
void *input = op->getInputs(0)->getRawDataPtr<void *>();
|
||||||
|
void *output = op->getOutput()->getRawDataPtr<void *>();
|
||||||
|
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||||
|
size_t count = op->getInputs(0)->size();
|
||||||
|
cnclComm_t comm =
|
||||||
|
dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
|
||||||
|
.getCnclComm();
|
||||||
|
cnrtQueue_t queue = context->getBangQueue();
|
||||||
|
// checkBangError(cnrtQueueSync(queue));
|
||||||
|
CNCL_CHECK(cnclAllReduce(input, output, count, cnclFloat, getRedOp(),
|
||||||
|
comm, queue));
|
||||||
|
checkBangError(cnrtQueueSync(queue));
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual cnclReduceOp_t getRedOp() const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class AllReduceSumCNCL : public AllReduceCNCL {
|
||||||
|
cnclReduceOp_t getRedOp() const override { return cnclSum; }
|
||||||
|
};
|
||||||
|
class AllReduceProdCNCL : public AllReduceCNCL {
|
||||||
|
cnclReduceOp_t getRedOp() const override { return cnclProd; }
|
||||||
|
};
|
||||||
|
class AllReduceMinCNCL : public AllReduceCNCL {
|
||||||
|
cnclReduceOp_t getRedOp() const override { return cnclMin; }
|
||||||
|
};
|
||||||
|
class AllReduceMaxCNCL : public AllReduceCNCL {
|
||||||
|
cnclReduceOp_t getRedOp() const override { return cnclMax; }
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL(Device::BANG, OpType::AllReduceSum, DataType::Float32,
|
||||||
|
AllReduceSumCNCL, "AllReduce_Sum_CNCL_BANG_Float32");
|
||||||
|
REGISTER_KERNEL(Device::BANG, OpType::AllReduceProd, DataType::Float32,
|
||||||
|
AllReduceProdCNCL, "AllReduce_Prod_CNCL_BANG_Float32");
|
||||||
|
REGISTER_KERNEL(Device::BANG, OpType::AllReduceMin, DataType::Float32,
|
||||||
|
AllReduceMinCNCL, "AllReduce_Min_CNCL_BANG_Float32");
|
||||||
|
REGISTER_KERNEL(Device::BANG, OpType::AllReduceMax, DataType::Float32,
|
||||||
|
AllReduceMaxCNCL, "AllReduce_Max_CNCL_BANG_Float32");
|
||||||
|
} // namespace infini
|
||||||
|
#endif
|
|
@ -0,0 +1,34 @@
|
||||||
|
#ifdef INFINI_USE_CNCL
|
||||||
|
#include "operators/broadcast.h"
|
||||||
|
#include "bang/bang_kernel_without_config.h"
|
||||||
|
#include "bang/bang_runtime.h"
|
||||||
|
#include "bang/cncl_communicator.h"
|
||||||
|
#include <thread>
|
||||||
|
namespace infini {
|
||||||
|
class BroadcastCNCL : public BangKernelWithoutConfig {
|
||||||
|
public:
|
||||||
|
void compute(const Operator &_op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
auto op = as<BroadcastObj>(_op);
|
||||||
|
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||||
|
void *input = op->getInputs(0)->getRawDataPtr<void *>();
|
||||||
|
void *output = op->getOutput()->getRawDataPtr<void *>();
|
||||||
|
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||||
|
size_t count = op->getInputs(0)->getBytes() / op->getDType().getSize();
|
||||||
|
|
||||||
|
cnclComm_t comm =
|
||||||
|
dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
|
||||||
|
.getCnclComm();
|
||||||
|
cnrtQueue_t queue = context->getBangQueue();
|
||||||
|
// TODO: Using default stream 0 for now.
|
||||||
|
CNCL_CHECK(cnclBroadcast(input, output, count, cnclFloat32,
|
||||||
|
op->getRoot(), comm, queue));
|
||||||
|
checkBangError(cnrtQueueSync(queue));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL(Device::BANG, OpType::Broadcast, DataType::Float32,
|
||||||
|
BroadcastCNCL, "Broadcast_CNCL_BANG_Float32");
|
||||||
|
} // namespace infini
|
||||||
|
|
||||||
|
#endif
|
|
@ -23,6 +23,8 @@ class GatherCnnl : public BangKernelWithoutConfig {
|
||||||
CNNL_DTYPE_FLOAT, aDim.size(),
|
CNNL_DTYPE_FLOAT, aDim.size(),
|
||||||
aDim.data()));
|
aDim.data()));
|
||||||
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
|
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
|
||||||
|
checkCnnlError(
|
||||||
|
cnnlSetTensorDescriptorPointerMode(bDesc, CNNL_POINTER_MODE_HOST));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_ARRAY,
|
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_ARRAY,
|
||||||
CNNL_DTYPE_INT32, bDim.size(),
|
CNNL_DTYPE_INT32, bDim.size(),
|
||||||
bDim.data()));
|
bDim.data()));
|
||||||
|
|
|
@ -1,12 +1,14 @@
|
||||||
|
#include "operators/reduce.h"
|
||||||
#include "bang/bang_kernel_without_config.h"
|
#include "bang/bang_kernel_without_config.h"
|
||||||
#include "bang/bang_runtime.h"
|
#include "bang/bang_runtime.h"
|
||||||
#include "operators/reduce.h"
|
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
class ReduceMeanCnnl : public BangKernelWithoutConfig {
|
class ReduceCnnlBase : public BangKernelWithoutConfig {
|
||||||
|
virtual cnnlReduceOp_t getReduceOp() const = 0;
|
||||||
|
|
||||||
void compute(const Operator &_op,
|
void compute(const Operator &_op,
|
||||||
const RuntimeObj *_context) const override {
|
const RuntimeObj *_context) const override {
|
||||||
auto op = as<ReduceMeanObj>(_op);
|
auto op = as<ReduceBaseObj>(_op);
|
||||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||||
|
@ -34,7 +36,7 @@ class ReduceMeanCnnl : public BangKernelWithoutConfig {
|
||||||
cnnlReduceDescriptor_t reduceDesc;
|
cnnlReduceDescriptor_t reduceDesc;
|
||||||
checkCnnlError(cnnlCreateReduceDescriptor(&reduceDesc));
|
checkCnnlError(cnnlCreateReduceDescriptor(&reduceDesc));
|
||||||
checkCnnlError(cnnlSetReduceDescriptor_v2(
|
checkCnnlError(cnnlSetReduceDescriptor_v2(
|
||||||
reduceDesc, axes.data(), axes.size(), CNNL_REDUCE_AVG,
|
reduceDesc, axes.data(), axes.size(), getReduceOp(),
|
||||||
CNNL_DTYPE_FLOAT, CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES,
|
CNNL_DTYPE_FLOAT, CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES,
|
||||||
CNNL_32BIT_INDICES, 0.0));
|
CNNL_32BIT_INDICES, 0.0));
|
||||||
|
|
||||||
|
@ -63,7 +65,17 @@ class ReduceMeanCnnl : public BangKernelWithoutConfig {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class ReduceMeanCnnl : public ReduceCnnlBase {
|
||||||
|
cnnlReduceOp_t getReduceOp() const override { return CNNL_REDUCE_AVG; }
|
||||||
|
};
|
||||||
|
|
||||||
|
class ReduceSumCnnl : public ReduceCnnlBase {
|
||||||
|
cnnlReduceOp_t getReduceOp() const override { return CNNL_REDUCE_ADD; }
|
||||||
|
};
|
||||||
|
|
||||||
REGISTER_KERNEL(Device::BANG, OpType::ReduceMean, DataType::Float32,
|
REGISTER_KERNEL(Device::BANG, OpType::ReduceMean, DataType::Float32,
|
||||||
ReduceMeanCnnl, "ReduceMean_cnnl_BANG_Float32");
|
ReduceMeanCnnl, "ReduceMean_cnnl_BANG_Float32");
|
||||||
|
REGISTER_KERNEL(Device::BANG, OpType::ReduceSum, DataType::Float32,
|
||||||
|
ReduceSumCnnl, "ReduceSum_cnnl_BANG_Float32");
|
||||||
|
|
||||||
}; // namespace infini
|
}; // namespace infini
|
|
@ -27,6 +27,8 @@ class CopyBang : public BangKernelWithoutConfig {
|
||||||
// reshape/flatten/identity all act as copying from input to output.
|
// reshape/flatten/identity all act as copying from input to output.
|
||||||
REGISTER_KERNEL(Device::BANG, OpType::Reshape, DataType::Float32, CopyBang,
|
REGISTER_KERNEL(Device::BANG, OpType::Reshape, DataType::Float32, CopyBang,
|
||||||
"Reshape_BANG_Float32");
|
"Reshape_BANG_Float32");
|
||||||
|
REGISTER_KERNEL(Device::BANG, OpType::Reshape, DataType::Int64, CopyBang,
|
||||||
|
"Reshape_BANG_Int64");
|
||||||
REGISTER_KERNEL(Device::BANG, OpType::Flatten, DataType::Float32, CopyBang,
|
REGISTER_KERNEL(Device::BANG, OpType::Flatten, DataType::Float32, CopyBang,
|
||||||
"Flatten_BANG_Float32");
|
"Flatten_BANG_Float32");
|
||||||
REGISTER_KERNEL(Device::BANG, OpType::Identity, DataType::Float32, CopyBang,
|
REGISTER_KERNEL(Device::BANG, OpType::Identity, DataType::Float32, CopyBang,
|
||||||
|
|
|
@ -0,0 +1,64 @@
|
||||||
|
#include "operators/slice.h"
|
||||||
|
#include "bang/bang_kernel_without_config.h"
|
||||||
|
#include "bang/bang_runtime.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
class SliceCnnl : public BangKernelWithoutConfig {
|
||||||
|
void compute(const Operator &_op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
auto op = as<SliceObj>(_op);
|
||||||
|
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||||
|
|
||||||
|
auto starts = op->getStarts();
|
||||||
|
auto ends = op->getEnds();
|
||||||
|
auto steps = op->getSteps();
|
||||||
|
|
||||||
|
int32_t starts_array[starts.size()];
|
||||||
|
int32_t ends_array[ends.size()];
|
||||||
|
int32_t steps_array[steps.size()];
|
||||||
|
|
||||||
|
for (size_t i = 0; i < starts.size(); i++) {
|
||||||
|
starts_array[i] = starts[i];
|
||||||
|
ends_array[i] = ends[i];
|
||||||
|
steps_array[i] = steps[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||||
|
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||||
|
|
||||||
|
auto aDim = op->getInputs(0)->getDims();
|
||||||
|
int aDim_size = aDim.size();
|
||||||
|
int aDim_array[aDim_size];
|
||||||
|
for (int i = 0; i < aDim_size; ++i) {
|
||||||
|
aDim_array[i] = aDim[i];
|
||||||
|
}
|
||||||
|
auto cDim = op->getOutput()->getDims();
|
||||||
|
int cDim_size = cDim.size();
|
||||||
|
int cDim_array[cDim_size];
|
||||||
|
for (int i = 0; i < cDim_size; ++i) {
|
||||||
|
cDim_array[i] = cDim[i];
|
||||||
|
}
|
||||||
|
cnnlTensorDescriptor_t aDesc, cDesc;
|
||||||
|
// input
|
||||||
|
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||||
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
|
aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, aDim_size, aDim_array));
|
||||||
|
// output
|
||||||
|
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
|
||||||
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
|
cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, cDim_size, cDim_array));
|
||||||
|
|
||||||
|
cnnlStatus_t stat =
|
||||||
|
cnnlStridedSlice(context->cnnlHandle(), aDesc, aData, starts_array,
|
||||||
|
ends_array, steps_array, cDesc, cData);
|
||||||
|
if (stat != CNNL_STATUS_SUCCESS)
|
||||||
|
return;
|
||||||
|
|
||||||
|
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
|
||||||
|
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL(Device::BANG, OpType::Slice, DataType::Float32, SliceCnnl,
|
||||||
|
"Slice_cnnl_BANG_Float32");
|
||||||
|
}; // namespace infini
|
|
@ -0,0 +1,58 @@
|
||||||
|
#ifdef INFINI_USE_CNCL
|
||||||
|
#include "bang/bang_runtime.h"
|
||||||
|
#include "bang/cncl_communicator.h"
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
static int WORLD_SIZE = 2;
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
void allReduceSum(float *data, int deviceId) {
|
||||||
|
// Create Runtime and setup communication
|
||||||
|
BangRuntimeObj *bang_runtime = new BangRuntimeObj(deviceId);
|
||||||
|
int rank = deviceId;
|
||||||
|
bang_runtime->initComm("test_cncl_comm", WORLD_SIZE, rank);
|
||||||
|
cnclComm_t comm =
|
||||||
|
dynamic_cast<CnclCommunicatorObj &>(bang_runtime->getCommunicator())
|
||||||
|
.getCnclComm();
|
||||||
|
cnrtQueue_t queue = bang_runtime->getBangQueue();
|
||||||
|
// Copy data
|
||||||
|
float *data_mlu;
|
||||||
|
checkBangError(cnrtMalloc((void **)&data_mlu, sizeof(float)));
|
||||||
|
checkBangError(
|
||||||
|
cnrtMemcpy(data_mlu, data, sizeof(float), cnrtMemcpyHostToDev));
|
||||||
|
// Do AllReduce
|
||||||
|
CNCL_CHECK(
|
||||||
|
cnclAllReduce(data_mlu, data_mlu, 1, cnclFloat, cnclSum, comm, queue));
|
||||||
|
|
||||||
|
checkBangError(cnrtQueueSync(queue));
|
||||||
|
// Copy data back and sync device
|
||||||
|
checkBangError(
|
||||||
|
cnrtMemcpy(data, data_mlu, sizeof(float), cnrtMemcpyDevToHost));
|
||||||
|
ASSERT_EQ(*data, 5.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup communication between 2 threads, each controlling 1 MLU.
|
||||||
|
// Do AllReduce Sum on {1.0, 4.0}. Results should be {5.0, 5.0}.
|
||||||
|
TEST(CNCL, multi_mlu_communication) {
|
||||||
|
float data[] = {1.0, 4.0};
|
||||||
|
|
||||||
|
for (int i = 0; i < WORLD_SIZE; ++i) {
|
||||||
|
pid_t pid = fork();
|
||||||
|
if (pid == 0) {
|
||||||
|
// Child process
|
||||||
|
allReduceSum(&data[i], i);
|
||||||
|
exit(0); // Ensure child process exits to avoid unnecessary
|
||||||
|
// repetition in parent
|
||||||
|
} else if (pid < 0) {
|
||||||
|
std::cerr << "Error creating process" << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Wait for all child processes to finish
|
||||||
|
for (int i = 0; i < WORLD_SIZE; ++i) {
|
||||||
|
wait(NULL);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
||||||
|
#endif
|
|
@ -0,0 +1,60 @@
|
||||||
|
#ifdef INFINI_USE_CNCL
|
||||||
|
#include "bang/bang_runtime.h"
|
||||||
|
#include "bang/cncl_communicator.h"
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "operators/all_gather.h"
|
||||||
|
#include "test.h"
|
||||||
|
#include <cncl.h>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
static int WORLD_SIZE = 2;
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
void allGather(const string taskName, int deviceID, vector<float> data,
|
||||||
|
vector<vector<float>> ans) {
|
||||||
|
// Create Runtimes and initiate communication
|
||||||
|
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
Runtime bangRuntime = make_ref<BangRuntimeObj>(deviceID);
|
||||||
|
bangRuntime->initComm(taskName, WORLD_SIZE, deviceID);
|
||||||
|
// Create Graph and insert allReduce operation
|
||||||
|
Graph g = make_ref<GraphObj>(bangRuntime);
|
||||||
|
auto input =
|
||||||
|
g->addTensor(Shape{static_cast<int>(data.size())}, DataType::Float32);
|
||||||
|
auto op = g->addOp<AllGatherObj>(input, std::nullopt, WORLD_SIZE);
|
||||||
|
// Copy data from CPU to MLU
|
||||||
|
g->dataMalloc();
|
||||||
|
input->copyin(data);
|
||||||
|
// Run operation
|
||||||
|
bangRuntime->run(g);
|
||||||
|
// Copy output from MLU to CPU
|
||||||
|
for (int i = 0; i < WORLD_SIZE; ++i) {
|
||||||
|
auto result = op->getOutputs()[i]->clone(cpuRuntime);
|
||||||
|
EXPECT_TRUE(result->equalData(ans[i]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(BANG_AllGather, run) {
|
||||||
|
vector<float> data[2] = {{2., 3.}, {5., 6.}};
|
||||||
|
vector<vector<float>> ans = {{2., 3.}, {5., 6.}};
|
||||||
|
|
||||||
|
for (int i = 0; i < WORLD_SIZE; ++i) {
|
||||||
|
pid_t pid = fork();
|
||||||
|
if (pid == 0) {
|
||||||
|
// Child process
|
||||||
|
allGather("test_all_gather", i, data[i], ans);
|
||||||
|
exit(0); // Ensure child process exits to avoid unnecessary
|
||||||
|
// repetition in parent
|
||||||
|
} else if (pid < 0) {
|
||||||
|
std::cerr << "Error creating process" << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Wait for all child processes to finish
|
||||||
|
for (int i = 0; i < WORLD_SIZE; ++i) {
|
||||||
|
wait(NULL);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
||||||
|
#endif
|
|
@ -0,0 +1,124 @@
|
||||||
|
#ifdef INFINI_USE_CNCL
|
||||||
|
#include "bang/bang_runtime.h"
|
||||||
|
#include "bang/cncl_communicator.h"
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "operators/all_reduce.h"
|
||||||
|
#include "test.h"
|
||||||
|
#include <cncl.h>
|
||||||
|
#include <future>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
static int WORLD_SIZE = 2;
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
template <typename OperatorObj>
|
||||||
|
void allReduce(const string taskName, int deviceID, vector<float> data,
|
||||||
|
vector<float> ans) {
|
||||||
|
// Create Runtimes and initiate communication
|
||||||
|
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
Runtime bangRuntime = make_ref<BangRuntimeObj>(deviceID);
|
||||||
|
bangRuntime->initComm(taskName, WORLD_SIZE, deviceID);
|
||||||
|
// Create Graph and insert allReduce operation
|
||||||
|
Graph g = make_ref<GraphObj>(bangRuntime);
|
||||||
|
auto input =
|
||||||
|
g->addTensor(Shape{static_cast<int>(data.size())}, DataType::Float32);
|
||||||
|
auto op = g->addOp<OperatorObj>(input, nullptr);
|
||||||
|
// Copy data from CPU to MLU
|
||||||
|
g->dataMalloc();
|
||||||
|
input->copyin(data);
|
||||||
|
// Run operation
|
||||||
|
bangRuntime->run(g);
|
||||||
|
// Copy output from MLU to CPU
|
||||||
|
auto result = op->getOutput()->clone(cpuRuntime);
|
||||||
|
|
||||||
|
EXPECT_TRUE(result->equalData(ans));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(BANG_AllReduce, sum) {
|
||||||
|
vector<float> data[2] = {{2., 3.}, {5., 6.}};
|
||||||
|
vector<float> ans = {7., 9.};
|
||||||
|
|
||||||
|
for (int i = 0; i < WORLD_SIZE; ++i) {
|
||||||
|
pid_t pid = fork();
|
||||||
|
if (pid == 0) {
|
||||||
|
// Child process
|
||||||
|
allReduce<AllReduceSumObj>("test_allreduce_sum", i, data[i], ans);
|
||||||
|
exit(0); // Ensure child process exits to avoid unnecessary
|
||||||
|
// repetition in parent
|
||||||
|
} else if (pid < 0) {
|
||||||
|
std::cerr << "Error creating process" << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Wait for all child processes to finish
|
||||||
|
for (int i = 0; i < WORLD_SIZE; ++i) {
|
||||||
|
wait(NULL);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(BANG_AllReduce, prod) {
|
||||||
|
vector<float> data[2] = {{2., 3.}, {5., 6.}};
|
||||||
|
vector<float> ans = {10., 18.};
|
||||||
|
|
||||||
|
for (int i = 0; i < WORLD_SIZE; ++i) {
|
||||||
|
pid_t pid = fork();
|
||||||
|
if (pid == 0) {
|
||||||
|
// Child process
|
||||||
|
allReduce<AllReduceProdObj>("test_allreduce_prod", i, data[i], ans);
|
||||||
|
exit(0); // Ensure child process exits to avoid unnecessary
|
||||||
|
// repetition in parent
|
||||||
|
} else if (pid < 0) {
|
||||||
|
std::cerr << "Error creating process" << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Wait for all child processes to finish
|
||||||
|
for (int i = 0; i < WORLD_SIZE; ++i) {
|
||||||
|
wait(NULL);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(BANG_AllReduce, min) {
|
||||||
|
vector<float> data[2] = {{2., 3.}, {5., 6.}};
|
||||||
|
vector<float> ans = {2., 3.};
|
||||||
|
|
||||||
|
for (int i = 0; i < WORLD_SIZE; ++i) {
|
||||||
|
pid_t pid = fork();
|
||||||
|
if (pid == 0) {
|
||||||
|
// Child process
|
||||||
|
allReduce<AllReduceMinObj>("test_allreduce_min", i, data[i], ans);
|
||||||
|
exit(0); // Ensure child process exits to avoid unnecessary
|
||||||
|
// repetition in parent
|
||||||
|
} else if (pid < 0) {
|
||||||
|
std::cerr << "Error creating process" << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Wait for all child processes to finish
|
||||||
|
for (int i = 0; i < WORLD_SIZE; ++i) {
|
||||||
|
wait(NULL);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(BANG_AllReduce, max) {
|
||||||
|
vector<float> data[2] = {{2., 3.}, {5., 6.}};
|
||||||
|
vector<float> ans = {5., 6.};
|
||||||
|
|
||||||
|
for (int i = 0; i < WORLD_SIZE; ++i) {
|
||||||
|
pid_t pid = fork();
|
||||||
|
if (pid == 0) {
|
||||||
|
// Child process
|
||||||
|
allReduce<AllReduceMaxObj>("test_allreduce_max", i, data[i], ans);
|
||||||
|
exit(0); // Ensure child process exits to avoid unnecessary
|
||||||
|
// repetition in parent
|
||||||
|
} else if (pid < 0) {
|
||||||
|
std::cerr << "Error creating process" << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Wait for all child processes to finish
|
||||||
|
for (int i = 0; i < WORLD_SIZE; ++i) {
|
||||||
|
wait(NULL);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
||||||
|
#endif
|
|
@ -0,0 +1,65 @@
|
||||||
|
#ifdef INFINI_USE_CNCL
|
||||||
|
#include "bang/bang_runtime.h"
|
||||||
|
#include "bang/cncl_communicator.h"
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "operators/broadcast.h"
|
||||||
|
#include "test.h"
|
||||||
|
#include <cncl.h>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
static int WORLD_SIZE = 2;
|
||||||
|
static int root = 0;
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
void broadcast(const string taskName, int deviceID, vector<float> data,
|
||||||
|
vector<float> ans) {
|
||||||
|
// Create Runtimes and initiate communication
|
||||||
|
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
Runtime bangRuntime = make_ref<BangRuntimeObj>(deviceID);
|
||||||
|
bangRuntime->initComm(taskName, WORLD_SIZE, deviceID);
|
||||||
|
// Create Graph and insert allReduce operation
|
||||||
|
Graph g = make_ref<GraphObj>(bangRuntime);
|
||||||
|
auto input =
|
||||||
|
g->addTensor(Shape{static_cast<int>(data.size())}, DataType::Float32);
|
||||||
|
auto op = g->addOp<BroadcastObj>(input, nullptr, root);
|
||||||
|
// Copy data from CPU to GPU
|
||||||
|
g->dataMalloc();
|
||||||
|
// Only rank 0 has the data
|
||||||
|
if (deviceID == root) {
|
||||||
|
input->copyin(data);
|
||||||
|
}
|
||||||
|
// Run broadcast operation
|
||||||
|
bangRuntime->run(g);
|
||||||
|
// Copy output from GPU to CPU
|
||||||
|
auto result = op->getOutput()->clone(cpuRuntime);
|
||||||
|
|
||||||
|
EXPECT_TRUE(result->equalData(ans));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(BANG_Broadcast, run) {
|
||||||
|
// Only 1 device gets data. Every rank should have the same data after
|
||||||
|
// broadcast.
|
||||||
|
vector<float> data = {2., 3., 5., 6.};
|
||||||
|
vector<float> ans = {2., 3., 5., 6.};
|
||||||
|
|
||||||
|
for (int i = 0; i < WORLD_SIZE; ++i) {
|
||||||
|
pid_t pid = fork();
|
||||||
|
if (pid == 0) {
|
||||||
|
// Child process
|
||||||
|
broadcast("test_broadcast", i, data, ans);
|
||||||
|
exit(0); // Ensure child process exits to avoid unnecessary
|
||||||
|
// repetition in parent
|
||||||
|
} else if (pid < 0) {
|
||||||
|
std::cerr << "Error creating process" << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Wait for all child processes to finish
|
||||||
|
for (int i = 0; i < WORLD_SIZE; ++i) {
|
||||||
|
wait(NULL);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
||||||
|
#endif
|
|
@ -0,0 +1,82 @@
|
||||||
|
#include "bang/bang_runtime.h"
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "operators/reduce.h"
|
||||||
|
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
template <typename ReduceObjT>
|
||||||
|
void test_reduce(const Shape &shape, const vector<float> &data,
|
||||||
|
const optional<const vector<int>> &axis, bool keepDims,
|
||||||
|
const vector<float> &ExpectData) {
|
||||||
|
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
auto bangRuntime = make_ref<BangRuntimeObj>();
|
||||||
|
|
||||||
|
// Build input data on CPU
|
||||||
|
Tensor icpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||||
|
|
||||||
|
// Build BANG graph
|
||||||
|
Graph g = make_ref<GraphObj>(bangRuntime);
|
||||||
|
auto i = g->cloneTensor(icpu);
|
||||||
|
auto op = g->addOp<ReduceObjT>(i, nullptr, axis, keepDims);
|
||||||
|
|
||||||
|
// allocate BANG memory
|
||||||
|
g->dataMalloc();
|
||||||
|
i->copyin(data);
|
||||||
|
|
||||||
|
// Execute on BANG
|
||||||
|
bangRuntime->run(g);
|
||||||
|
|
||||||
|
// clone BANG output to CPU
|
||||||
|
auto o = op->getOutput();
|
||||||
|
auto ocpu = o->clone(cpuRuntime);
|
||||||
|
|
||||||
|
// check results on CPU
|
||||||
|
EXPECT_TRUE(ocpu->equalData(ExpectData));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(BANG_ReduceMean, run) {
|
||||||
|
test_reduce<ReduceMeanObj>(
|
||||||
|
Shape{3, 2, 2}, vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2},
|
||||||
|
std::nullopt, true, vector<float>{18.25});
|
||||||
|
test_reduce<ReduceMeanObj>(
|
||||||
|
Shape{1, 3, 2, 2, 1},
|
||||||
|
vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2}, std::nullopt,
|
||||||
|
false, vector<float>{18.25});
|
||||||
|
|
||||||
|
test_reduce<ReduceMeanObj>(
|
||||||
|
Shape{2, 3, 2, 2},
|
||||||
|
vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||||
|
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
|
||||||
|
vector<int>{1, 2}, false, vector<float>{5, 6, 17, 18});
|
||||||
|
test_reduce<ReduceMeanObj>(
|
||||||
|
Shape{2, 3, 2, 2, 1},
|
||||||
|
vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||||
|
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
|
||||||
|
vector<int>{1, 2}, true, vector<float>{5, 6, 17, 18});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(BANG_ReduceSum, run) {
|
||||||
|
test_reduce<ReduceSumObj>(Shape{3, 2, 2},
|
||||||
|
vector<float>{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||||
|
std::nullopt, true, vector<float>{12});
|
||||||
|
test_reduce<ReduceSumObj>(Shape{1, 3, 2, 2, 1},
|
||||||
|
vector<float>{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||||
|
std::nullopt, false, vector<float>{12});
|
||||||
|
|
||||||
|
test_reduce<ReduceSumObj>(
|
||||||
|
Shape{2, 3, 2, 2},
|
||||||
|
vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||||
|
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
|
||||||
|
vector<int>{1, 2}, false, vector<float>{30, 36, 102, 108});
|
||||||
|
test_reduce<ReduceSumObj>(
|
||||||
|
Shape{2, 3, 2, 2, 1},
|
||||||
|
vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||||
|
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
|
||||||
|
vector<int>{1, 2}, true, vector<float>{30, 36, 102, 108});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,39 @@
|
||||||
|
#include "bang/bang_runtime.h"
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "operators/slice.h"
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
TEST(BANG_Slice, run) {
|
||||||
|
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
auto bangRuntime = make_ref<BangRuntimeObj>();
|
||||||
|
|
||||||
|
// Build input data on CPU
|
||||||
|
Tensor icpu =
|
||||||
|
make_ref<TensorObj>(Shape{3, 2, 1, 5}, DataType::Float32, cpuRuntime);
|
||||||
|
icpu->dataMalloc();
|
||||||
|
icpu->setData(IncrementalGenerator());
|
||||||
|
|
||||||
|
// Build CUDA graph;
|
||||||
|
Graph g = make_ref<GraphObj>(bangRuntime);
|
||||||
|
auto i = g->cloneTensor(icpu);
|
||||||
|
auto op =
|
||||||
|
g->addOp<SliceObj>(i, nullptr, vector<int>{1, 1}, vector<int>{2, 5},
|
||||||
|
vector<int>{0, 3}, std::nullopt);
|
||||||
|
|
||||||
|
// allocate CUDA memory
|
||||||
|
g->dataMalloc();
|
||||||
|
i->setData(IncrementalGenerator());
|
||||||
|
|
||||||
|
// Execute on CUDA
|
||||||
|
bangRuntime->run(g);
|
||||||
|
|
||||||
|
// clone CUDA output to CPU
|
||||||
|
auto o = op->getOutput();
|
||||||
|
auto cpuo = o->clone(cpuRuntime);
|
||||||
|
// bangPrintTensor(o);
|
||||||
|
// check results on CPU
|
||||||
|
EXPECT_TRUE(cpuo->equalData(vector<float>{11, 12, 13, 14, 16, 17, 18, 19}));
|
||||||
|
}
|
||||||
|
} // namespace infini
|
Loading…
Reference in New Issue