diff --git a/CMakeLists.txt b/CMakeLists.txt index 1101a8c2..70508c79 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,7 +13,7 @@ if(USE_CUDA) message("CMake 3.18 or higher is required for setting CUDAToolkit") cmake_minimum_required(VERSION 3.18) # FindCUDAToolkit else() - cmake_minimum_required(VERSION 3.12) + cmake_minimum_required(VERSION 3.17) endif() include(CMakeDependentOption) @@ -245,6 +245,7 @@ if(USE_BANG) find_library(CAMBRICON_CNNL libcnnl.so "${NEUWARE_HOME}/lib64") find_library(CAMBRICON_CNRT libcnrt.so "${NEUWARE_HOME}/lib64") find_library(CAMBRICON_CNDRV libcndrv.so "${NEUWARE_HOME}/lib64") + find_library(CAMBRICON_CNCL libcncl.so "${NEUWARE_HOME}/lib64") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror") if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH})) @@ -261,7 +262,13 @@ if(USE_BANG) # BangC Kernels ################################################################################ - target_link_libraries(InfiniTensor ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++) + target_link_libraries(InfiniTensor ${CAMBRICON_CNCL} ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++) + if (BUILD_DIST) + message(STATUS "Add BUILD_DIST, use CNCL with BANG") + + add_compile_definitions(INFINI_USE_CNCL=1) + + endif() endif() if(USE_KUNLUN) @@ -324,6 +331,7 @@ if(BUILD_TEST) endif() if (USE_BANG) build_test(test/kernels/bang/*.cc) + build_test(test/bang/*.cc) endif() if (USE_KUNLUN) build_test(test/kernels/kunlun/*.cc) diff --git a/Makefile b/Makefile index 302f47b8..d21a406b 100644 --- a/Makefile +++ b/Makefile @@ -29,6 +29,7 @@ CMAKE_OPT += -DUSE_BANG=$(BANG) CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN) CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE) CMAKE_OPT += -DBUILD_TEST=$(TEST) +CMAKE_OPT += -DBUILD_DIST=ON CMAKE_OPT += -DBUILD_NNET=$(NNET) ifeq ($(INTELCPU), ON) diff --git a/cmake/FindCNCL.cmake b/cmake/FindCNCL.cmake new file mode 100644 index 00000000..31351dda --- /dev/null +++ b/cmake/FindCNCL.cmake @@ -0,0 +1,76 @@ +SET(CNCL_LIB_SEARCH_PATHS $ENV{NEUWARE_HOME}/lib64) +SET(CNCL_INCLUDE_SEARCH_PATHS $ENV{NEUWARE_HOME}/include) + +set(CNCL_INCLUDE_DIR $ENV{NEUWARE_HOME}/include) +set(CNCL_LIB_DIR $ENV{NEUWARE_HOME}/lib64) +set(CNCL_VERSION $ENV{CNCL_VERSION} CACHE STRING "Version of CNCL to build with") + +if ($ENV{CNCL_ROOT_DIR}) + message(WARNING "CNCL_ROOT_DIR is deprecated. Please set CNCL_ROOT instead.") +endif() +list(APPEND CNCL_ROOT $ENV{CNCL_ROOT_DIR} ${MLU_TOOLKIT_ROOT_DIR}) +# Compatible layer for CMake <3.12. CNCL_ROOT will be accounted in for searching paths and libraries for CMake >=3.12. +list(APPEND CMAKE_PREFIX_PATH ${CNCL_ROOT}) + +find_path(CNCL_INCLUDE_DIRS + NAMES cncl.h + HINTS ${CNCL_INCLUDE_DIR}) + +if (USE_STATIC_CNCL) + MESSAGE(STATUS "USE_STATIC_CNCL is set. Linking with static CNCL library.") + SET(CNCL_LIBNAME "CNCL_static") + if (CNCL_VERSION) # Prefer the versioned library if a specific CNCL version is specified + set(CMAKE_FIND_LIBRARY_SUFFIXES ".a.${CNCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES}) + endif() +else() + SET(CNCL_LIBNAME "cncl") + if (CNCL_VERSION) # Prefer the versioned library if a specific CNCL version is specified + set(CMAKE_FIND_LIBRARY_SUFFIXES ".so.${CNCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES}) + endif() +endif() + +find_library(CNCL_LIBRARIES + NAMES ${CNCL_LIBNAME} + HINTS ${CNCL_LIB_DIR}) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(CNCL DEFAULT_MSG CNCL_INCLUDE_DIRS CNCL_LIBRARIES) + +if(CNCL_FOUND) # obtaining CNCL version and some sanity checks + set (CNCL_HEADER_FILE "${CNCL_INCLUDE_DIRS}/cncl.h") + message (STATUS "Determining CNCL version from ${CNCL_HEADER_FILE}...") + set (OLD_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES}) + list (APPEND CMAKE_REQUIRED_INCLUDES ${CNCL_INCLUDE_DIRS}) + include(CheckCXXSymbolExists) + check_cxx_symbol_exists(CNCL_VERSION_CODE CNCL.h CNCL_VERSION_DEFINED) + + if (CNCL_VERSION_DEFINED) + set(file "${PROJECT_BINARY_DIR}/detect_cncl_version.cc") + file(WRITE ${file} " + #include + #include + int main() + { + std::cout << CNCL_MAJOR << '.' << CNCL_MINOR << '.' << CNCL_PATCH << std::endl; + int x; + CNCLGetVersion(&x); + return x == CNCL_VERSION_CODE; + } +") + try_run(CNCL_VERSION_MATCHED compile_result ${PROJECT_BINARY_DIR} ${file} + RUN_OUTPUT_VARIABLE CNCL_VERSION_FROM_HEADER + CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CNCL_INCLUDE_DIRS}" + LINK_LIBRARIES ${CNCL_LIBRARIES}) + if (NOT CNCL_VERSION_MATCHED) + message(FATAL_ERROR "Found CNCL header version and library version do not match! \ +(include: ${CNCL_INCLUDE_DIRS}, library: ${CNCL_LIBRARIES}) Please set CNCL_INCLUDE_DIR and CNCL_LIB_DIR manually.") + endif() + message(STATUS "CNCL version: ${CNCL_VERSION_FROM_HEADER}") + else() + # message(STATUS "CNCL version < 2.3.5-5") + endif () + set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES}) + + message(STATUS "Found CNCL (include: ${CNCL_INCLUDE_DIRS}, library: ${CNCL_LIBRARIES})") + mark_as_advanced(CNCL_ROOT_DIR CNCL_INCLUDE_DIRS CNCL_LIBRARIES) +endif() diff --git a/examples/distributed/bang_launch.py b/examples/distributed/bang_launch.py new file mode 100644 index 00000000..518935b5 --- /dev/null +++ b/examples/distributed/bang_launch.py @@ -0,0 +1,196 @@ +import argparse +import os +import time +import multiprocessing as mp +from pyinfinitensor.onnx import OnnxStub, backend +import onnx +from onnx.shape_inference import infer_shapes_path +import numpy as np +from parallel_opt import parallel_model + + +def parse_args(): + parser = argparse.ArgumentParser(description="launch distributed infinitensor") + parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes") + parser.add_argument( + "--nproc_per_node", type=int, default=2, help="number of processes per node" + ) + parser.add_argument( + "--name", type=str, default="test", help="name of this instance." + ) + parser.add_argument( + "--model", type=str, default="/data/onnx_models/llama2/llama_bs1_seq1024.onnx", + help="path to the ONNX model file." + ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size.") + parser.add_argument("--length", type=int, default=1, help="sequence length.") + parser.add_argument( + "--gen_std", + default=False, + action="store_true", + help="whether to generate the standard results.", + ) + args = parser.parse_args() + print("arg setting: ", args) + return ( + args.num_nodes, + args.nproc_per_node, + args.name, + args.model, + args.batch_size, + args.length, + args.gen_std, + ) + + +def run_model(model, runtime, world_size=1, rank=0, n=10): + stub = OnnxStub(model, runtime) + load_inputs(stub, world_size, rank) + # stub.tune() + stub.run() + # get outputs + time.sleep(0.01) + outputs = next(stub.outputs.values().__iter__()).copyout_numpy() + + # bench + begin = time.time() + for _ in range(n): + stub.run() + end = time.time() + avg_time = (end - begin) / n + print(f"average time: {avg_time}") + return outputs + + +def run_and_compare(name, model, runtime, world_size=1, rank = 0): + results = np.load(f"./data/output.npy") + outputs = run_model(model, runtime, world_size, rank) + print("answer argmax:", np.argmax(results)) + print("output argmax:", np.argmax(outputs)) + #np.testing.assert_allclose(outputs, results, rtol=1e-3, atol=1e-3) + getDiff(results, outputs) + + +def start_worker( + name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto +): + dist_name = name + "_dist" + model = parallel_model(model, world_size, rank) + extern_path = f"./{dist_name}_rank{rank}.pb" + if os.path.exists(extern_path): + os.remove(extern_path) + onnx.save_model( + model, + f"./{dist_name}_rank{rank}.onnx", + save_as_external_data=True, + location=extern_path, + ) + infer_shapes_path(f"./{dist_name}_rank{rank}.onnx") + runtime = backend.BangRuntime(local_rank) + # print("init comm") + runtime.init_comm( + dist_name, + world_size, + rank, + ) + run_and_compare(name, model, runtime, world_size, rank) + + +def start_single(name, model): + runtime = backend.BangRuntime(0) + run_and_compare(name, model, runtime) + + +def generate_input_output(model): + os.makedirs(os.path.dirname("./data/"), exist_ok=True) + runtime = backend.BangRuntime(0) + stub = OnnxStub(model, runtime) + position_id = 0 + for i, (name, tensor) in enumerate(stub.inputs.items()): + input = tensor.copyout_numpy() + if np.issubdtype(input.dtype, np.integer): + if input.size == 1: + # input = np.array([position_id]) + input = np.random.randint(0,2,size=input.shape, dtype=input.dtype) + else: + input = np.random.randint(0,2,size=input.shape, dtype=input.dtype) + elif input.dtype == np.bool_: + input = np.random.randint(0,2,size=input.shape) > 0 + else: + if i == 0: + input = np.ones(input.shape).astype(input.dtype) + position_id = input.shape[-1] - 1 + else: + input = np.random.rand(*input.shape).astype(input.dtype) + tensor.copyin_numpy(input) + np.save(f"./data/input_{i}", input) + stub.run() + time.sleep(0.01) + output = next(stub.outputs.values().__iter__()).copyout_numpy() + if np.isnan(output).any(): + print("Nan in output") + np.save(f"./data/output", output) + + +def load_inputs(stub, world_size=1, rank=0): + for i, (name, tensor) in enumerate(stub.inputs.items()): + input = np.load(f"./data/input_{i}.npy") + if all(x == y for x,y in zip(input.shape,tensor.shape())): + tensor.copyin_numpy(input) + else: + tensor.copyin_numpy(np.hsplit(input, world_size)[rank]) + +def getDiff(base, test): + absolute_diff = np.abs(np.subtract(base, test)) + max_absolute_diff = np.max(absolute_diff) + + baseCopy = base.astype(np.float64).ravel() + testCopy = test.astype(np.float64).ravel() + upValue = np.sum(np.abs(baseCopy - testCopy)) + downValue = np.sum(np.abs(baseCopy)) + np.float64(1e-9) + max_relative_diff = upValue / downValue + print(f"Max absolute difference: {max_absolute_diff}\n" + f"Max relative difference: {max_relative_diff}") + return max_absolute_diff, max_relative_diff + + +def main(): + nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args() + + model = onnx.load(model_path) + + # generate standart output + if gen_std: + print("Generate inputs and outputs.") + p = mp.Process(target=generate_input_output, args=[model]) + p.start() + p.join() + return + + # run single process. + # use standalone process to isolate cuda. + print("run model by single MLU.") + p = mp.Process(target=start_single, args=(name, model)) + p.start() + p.join() + + # run distributed parallel. + world_size = nnodes * nproc_per_node + print(f"run model by {world_size} MLUs in parallel.") + workers = [ + mp.Process( + target=start_worker, + args=(name, world_size, rank, rank % nproc_per_node, model), + ) + for rank in range(world_size) + ] + + for w in workers: + w.start() + + for w in workers: + w.join() + + +if __name__ == "__main__": + main() diff --git a/examples/distributed/parallel_opt.py b/examples/distributed/parallel_opt.py index 3ddf2ead..1214b6b3 100644 --- a/examples/distributed/parallel_opt.py +++ b/examples/distributed/parallel_opt.py @@ -115,7 +115,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0): assert out_dims[s_dim] % tp_world_size == 0, out_dims out_dims[s_dim] //= tp_world_size # if ONNX uses the same tensor for multiple Reshape Nodes, then rename it to distingush from others. - # node.input[1] = node.output[0] + "_shape" + node.input[1] = node.output[0] + "_shape" data[node.input[1]] = numpy_helper.from_array(out_dims, name=node.input[1]) place[node.output[0]] = Shard(s_dim) diff --git a/include/bang/bang_runtime.h b/include/bang/bang_runtime.h index 2dde7756..e1ca6b38 100644 --- a/include/bang/bang_runtime.h +++ b/include/bang/bang_runtime.h @@ -7,17 +7,19 @@ namespace infini { class BangRuntimeObj : public RuntimeObj { private: cnnlHandle_t cnnl; + cnrtQueue_t queue; + std::unique_ptr comm; BangPtr workspace; size_t workspaceSize; mutable size_t cursor; public: - BangRuntimeObj() : RuntimeObj(Device::BANG) { + explicit BangRuntimeObj(int deviceId = 0) + : RuntimeObj(Device::BANG, deviceId) { cnInit(0); CNdev dev; - cnDeviceGet(&dev, 0); + cnDeviceGet(&dev, deviceId); checkBangError(cnrtSetDevice(dev)); - cnrtQueue_t queue; checkBangError(cnrtQueueCreate(&queue)); checkCnnlError(cnnlCreate(&cnnl)); @@ -30,6 +32,7 @@ class BangRuntimeObj : public RuntimeObj { } virtual ~BangRuntimeObj() { dealloc(workspace); + checkBangError(cnrtQueueDestroy(queue)); checkCnnlError(cnnlDestroy(cnnl)); } string toString() const override; @@ -73,10 +76,9 @@ class BangRuntimeObj : public RuntimeObj { checkBangError(cnrtMemcpy(dst, const_cast(src), bytes, CNRT_MEM_TRANS_DIR_PEER2PEER)); } - - void initComm(const string &, int, int) override { IT_TODO_HALT(); } - - CommunicatorObj &getCommunicator() const override { IT_TODO_HALT(); } + void initComm(const string &name, int worldSize, int rank) final; + CommunicatorObj &getCommunicator() const override { return *comm; } + cnrtQueue_t getBangQueue() const { return queue; } private: void runWithoutSync(const Graph &graph, bool tune, bool profiling) const; diff --git a/include/bang/cncl_communicator.h b/include/bang/cncl_communicator.h new file mode 100644 index 00000000..0999686c --- /dev/null +++ b/include/bang/cncl_communicator.h @@ -0,0 +1,79 @@ +#pragma once +#include "bang_common.h" +#include "core/communicator.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace infini { + +class CnclCommunicatorObj final : public CommunicatorObj { + private: + cnclComm_t *comms; + + public: + CnclCommunicatorObj(const string &name, int worldSize, int rank) + : CommunicatorObj(worldSize, rank) { + const std::string filePath("./" + name + "_cncl_id.bin"); + cnclCliqueId clique_id; + if (rank == 0) { + CNCL_CHECK(cnclGetCliqueId(&clique_id)); + std::ofstream ofs(filePath, std::ios::binary); + ofs.write((char *)&clique_id, sizeof(cnclCliqueId)); + + } else { + auto begin = std::chrono::steady_clock::now(); + while (!std::filesystem::exists(filePath)) { + auto now = std::chrono::steady_clock::now(); + _IT_ASSERT_2(now < begin + std::chrono::seconds(10), + "time limit (10s) exceeded."); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + std::ifstream ifs(filePath, std::ios::binary); + ifs.read((char *)&clique_id, sizeof(cnclCliqueId)); + } + + int num_comms = 1; + int *dev_list = new int[num_comms]; + int *rank_list = new int[num_comms]; + comms = new cnclComm_t[num_comms]; + uint32_t num_dev = 0; + checkBangError(cnrtGetDeviceCount(&num_dev)); + + for (int i = 0; i < num_comms; i++) { + rank_list[i] = rank; + dev_list[i] = rank_list[i] % num_dev; + } + + CNCL_CHECK(cnclInitComms(comms, num_comms, dev_list, rank_list, + worldSize, &clique_id)); + + if (rank == 0) { + std::filesystem::remove(filePath); + } + + delete[] dev_list; + delete[] rank_list; + } + + ~CnclCommunicatorObj() { + CNCL_CHECK(cnclDestroyComms(comms, 1)); + delete[] comms; + } + + // Get the actual cnclComm_t + cnclComm_t getCnclComm() { return comms[0]; } + + virtual string toString() const final { + std::ostringstream oss; + oss << "CNCL communicator"; + return oss.str(); + } +}; + +} // namespace infini diff --git a/include/core/tensor.h b/include/core/tensor.h index cb09261a..95229c14 100644 --- a/include/core/tensor.h +++ b/include/core/tensor.h @@ -8,7 +8,9 @@ #if USE_CUDA #include "cuda/cuda_runtime.h" #endif - +#if USE_BANG +#include "bang/bang_runtime.h" +#endif namespace infini { // TODO: how to deal with this diff --git a/src/bang/bang_runtime.cc b/src/bang/bang_runtime.cc index c9f9a933..2f16b500 100644 --- a/src/bang/bang_runtime.cc +++ b/src/bang/bang_runtime.cc @@ -1,6 +1,9 @@ #include "bang/bang_runtime.h" #include "core/kernel.h" #include "core/perf_engine.h" +#ifdef INFINI_USE_CNCL +#include "bang/cncl_communicator.h" +#endif namespace infini { @@ -59,4 +62,15 @@ void BangRuntimeObj::sync() const { cnrtSyncDevice(); } string BangRuntimeObj::toString() const { return "BANG Runtime"; } +void BangRuntimeObj::initComm(const string &name, int worldSize, int rank) { + IT_ASSERT(worldSize > 0); + IT_ASSERT(rank >= 0); + IT_ASSERT(rank < worldSize); + IT_ASSERT(!comm) << "communicator is already initialized."; +#ifdef INFINI_USE_CNCL + comm = std::make_unique(name, worldSize, rank); +#else + IT_TODO_HALT_MSG("Not compiled with CNCL."); +#endif +} } // namespace infini diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index c23009b5..eadd4a4e 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -399,7 +399,9 @@ void init_graph_builder(py::module &m) { #endif #ifdef USE_BANG py::class_, RuntimeObj>( - m, "BangRuntime"); + m, "BangRuntime") + .def(py::init(), py::arg("device") = 0) + .def("init_comm", &BangRuntimeObj::initComm); #endif #ifdef USE_KUNLUN py::class_, RuntimeObj>( diff --git a/src/kernels/bang/all_gather.cc b/src/kernels/bang/all_gather.cc new file mode 100644 index 00000000..d44569fe --- /dev/null +++ b/src/kernels/bang/all_gather.cc @@ -0,0 +1,49 @@ +#ifdef INFINI_USE_CNCL +#include "operators/all_gather.h" +#include "bang/bang_kernel_without_config.h" +#include "bang/bang_runtime.h" +#include "bang/cncl_communicator.h" +#include +namespace infini { +class AllGatherCNCL : public BangKernelWithoutConfig { + public: + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + int world_size = op->getWorldSize(); + // Check if world size info in operator matches runtime + IT_ASSERT(world_size == context->getCommunicator().getWorldSize()); + + void *input = op->getInputs(0)->getRawDataPtr(); + BangPtr output_temp = + context->getWorkspace(op->getInputs(0)->getBytes() * world_size); + // void *output = op->getOutput()->getRawDataPtr(); + // IT_ASSERT(op->getDType() == DataType::Float32); + checkBangError(cnrtMalloc(&output_temp, + op->getInputs(0)->getBytes() * world_size)); + size_t bytes = op->getInputs(0)->getBytes(); + size_t count = bytes / op->getDType().getSize(); + + cnclComm_t comm = + dynamic_cast(context->getCommunicator()) + .getCnclComm(); + cnrtQueue_t queue = context->getBangQueue(); + CNCL_CHECK( + cnclAllGather(input, output_temp, count, cnclFloat32, comm, queue)); + checkBangError(cnrtQueueSync(queue)); + for (int i = 0; i < world_size; ++i) { + Tensor output = op->getOutput(i); + context->copyBlobInsideRuntime( + output->getRawDataPtr(), + static_cast(output_temp) + i * count, bytes); + } + checkBangError(cnrtFree(output_temp)); + } +}; + +REGISTER_KERNEL(Device::BANG, OpType::AllGather, DataType::Float32, + AllGatherCNCL, "AllGather_CNCL_BANG_Float32"); +} // namespace infini + +#endif diff --git a/src/kernels/bang/all_reduce.cc b/src/kernels/bang/all_reduce.cc new file mode 100644 index 00000000..4e9266fb --- /dev/null +++ b/src/kernels/bang/all_reduce.cc @@ -0,0 +1,53 @@ +#ifdef INFINI_USE_CNCL +#include "operators/all_reduce.h" +#include "bang/bang_kernel_without_config.h" +#include "bang/bang_runtime.h" +#include "bang/cncl_communicator.h" +#include +namespace infini { +class AllReduceCNCL : public BangKernelWithoutConfig { + public: + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + void *input = op->getInputs(0)->getRawDataPtr(); + void *output = op->getOutput()->getRawDataPtr(); + IT_ASSERT(op->getDType() == DataType::Float32); + size_t count = op->getInputs(0)->size(); + cnclComm_t comm = + dynamic_cast(context->getCommunicator()) + .getCnclComm(); + cnrtQueue_t queue = context->getBangQueue(); + // checkBangError(cnrtQueueSync(queue)); + CNCL_CHECK(cnclAllReduce(input, output, count, cnclFloat, getRedOp(), + comm, queue)); + checkBangError(cnrtQueueSync(queue)); + } + + virtual cnclReduceOp_t getRedOp() const = 0; +}; + +class AllReduceSumCNCL : public AllReduceCNCL { + cnclReduceOp_t getRedOp() const override { return cnclSum; } +}; +class AllReduceProdCNCL : public AllReduceCNCL { + cnclReduceOp_t getRedOp() const override { return cnclProd; } +}; +class AllReduceMinCNCL : public AllReduceCNCL { + cnclReduceOp_t getRedOp() const override { return cnclMin; } +}; +class AllReduceMaxCNCL : public AllReduceCNCL { + cnclReduceOp_t getRedOp() const override { return cnclMax; } +}; + +REGISTER_KERNEL(Device::BANG, OpType::AllReduceSum, DataType::Float32, + AllReduceSumCNCL, "AllReduce_Sum_CNCL_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::AllReduceProd, DataType::Float32, + AllReduceProdCNCL, "AllReduce_Prod_CNCL_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::AllReduceMin, DataType::Float32, + AllReduceMinCNCL, "AllReduce_Min_CNCL_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::AllReduceMax, DataType::Float32, + AllReduceMaxCNCL, "AllReduce_Max_CNCL_BANG_Float32"); +} // namespace infini +#endif diff --git a/src/kernels/bang/broadcast.cc b/src/kernels/bang/broadcast.cc new file mode 100644 index 00000000..411506c5 --- /dev/null +++ b/src/kernels/bang/broadcast.cc @@ -0,0 +1,34 @@ +#ifdef INFINI_USE_CNCL +#include "operators/broadcast.h" +#include "bang/bang_kernel_without_config.h" +#include "bang/bang_runtime.h" +#include "bang/cncl_communicator.h" +#include +namespace infini { +class BroadcastCNCL : public BangKernelWithoutConfig { + public: + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + void *input = op->getInputs(0)->getRawDataPtr(); + void *output = op->getOutput()->getRawDataPtr(); + IT_ASSERT(op->getDType() == DataType::Float32); + size_t count = op->getInputs(0)->getBytes() / op->getDType().getSize(); + + cnclComm_t comm = + dynamic_cast(context->getCommunicator()) + .getCnclComm(); + cnrtQueue_t queue = context->getBangQueue(); + // TODO: Using default stream 0 for now. + CNCL_CHECK(cnclBroadcast(input, output, count, cnclFloat32, + op->getRoot(), comm, queue)); + checkBangError(cnrtQueueSync(queue)); + } +}; + +REGISTER_KERNEL(Device::BANG, OpType::Broadcast, DataType::Float32, + BroadcastCNCL, "Broadcast_CNCL_BANG_Float32"); +} // namespace infini + +#endif diff --git a/src/kernels/bang/gather.cc b/src/kernels/bang/gather.cc index b5a326fc..dc3ee636 100644 --- a/src/kernels/bang/gather.cc +++ b/src/kernels/bang/gather.cc @@ -23,6 +23,8 @@ class GatherCnnl : public BangKernelWithoutConfig { CNNL_DTYPE_FLOAT, aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); + checkCnnlError( + cnnlSetTensorDescriptorPointerMode(bDesc, CNNL_POINTER_MODE_HOST)); checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_INT32, bDim.size(), bDim.data())); diff --git a/src/kernels/bang/reduce_mean.cc b/src/kernels/bang/reduce.cc similarity index 82% rename from src/kernels/bang/reduce_mean.cc rename to src/kernels/bang/reduce.cc index 1b55c2ca..88d1e645 100644 --- a/src/kernels/bang/reduce_mean.cc +++ b/src/kernels/bang/reduce.cc @@ -1,12 +1,14 @@ +#include "operators/reduce.h" #include "bang/bang_kernel_without_config.h" #include "bang/bang_runtime.h" -#include "operators/reduce.h" namespace infini { -class ReduceMeanCnnl : public BangKernelWithoutConfig { +class ReduceCnnlBase : public BangKernelWithoutConfig { + virtual cnnlReduceOp_t getReduceOp() const = 0; + void compute(const Operator &_op, const RuntimeObj *_context) const override { - auto op = as(_op); + auto op = as(_op); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); @@ -34,7 +36,7 @@ class ReduceMeanCnnl : public BangKernelWithoutConfig { cnnlReduceDescriptor_t reduceDesc; checkCnnlError(cnnlCreateReduceDescriptor(&reduceDesc)); checkCnnlError(cnnlSetReduceDescriptor_v2( - reduceDesc, axes.data(), axes.size(), CNNL_REDUCE_AVG, + reduceDesc, axes.data(), axes.size(), getReduceOp(), CNNL_DTYPE_FLOAT, CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES, 0.0)); @@ -63,7 +65,17 @@ class ReduceMeanCnnl : public BangKernelWithoutConfig { } }; +class ReduceMeanCnnl : public ReduceCnnlBase { + cnnlReduceOp_t getReduceOp() const override { return CNNL_REDUCE_AVG; } +}; + +class ReduceSumCnnl : public ReduceCnnlBase { + cnnlReduceOp_t getReduceOp() const override { return CNNL_REDUCE_ADD; } +}; + REGISTER_KERNEL(Device::BANG, OpType::ReduceMean, DataType::Float32, ReduceMeanCnnl, "ReduceMean_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::ReduceSum, DataType::Float32, + ReduceSumCnnl, "ReduceSum_cnnl_BANG_Float32"); }; // namespace infini diff --git a/src/kernels/bang/reshape.cc b/src/kernels/bang/reshape.cc index 564ed1d7..f5628a7b 100644 --- a/src/kernels/bang/reshape.cc +++ b/src/kernels/bang/reshape.cc @@ -27,6 +27,8 @@ class CopyBang : public BangKernelWithoutConfig { // reshape/flatten/identity all act as copying from input to output. REGISTER_KERNEL(Device::BANG, OpType::Reshape, DataType::Float32, CopyBang, "Reshape_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Reshape, DataType::Int64, CopyBang, + "Reshape_BANG_Int64"); REGISTER_KERNEL(Device::BANG, OpType::Flatten, DataType::Float32, CopyBang, "Flatten_BANG_Float32"); REGISTER_KERNEL(Device::BANG, OpType::Identity, DataType::Float32, CopyBang, diff --git a/src/kernels/bang/slice.cc b/src/kernels/bang/slice.cc new file mode 100644 index 00000000..5cc772aa --- /dev/null +++ b/src/kernels/bang/slice.cc @@ -0,0 +1,64 @@ +#include "operators/slice.h" +#include "bang/bang_kernel_without_config.h" +#include "bang/bang_runtime.h" + +namespace infini { +class SliceCnnl : public BangKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + auto starts = op->getStarts(); + auto ends = op->getEnds(); + auto steps = op->getSteps(); + + int32_t starts_array[starts.size()]; + int32_t ends_array[ends.size()]; + int32_t steps_array[steps.size()]; + + for (size_t i = 0; i < starts.size(); i++) { + starts_array[i] = starts[i]; + ends_array[i] = ends[i]; + steps_array[i] = steps[i]; + } + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto aDim = op->getInputs(0)->getDims(); + int aDim_size = aDim.size(); + int aDim_array[aDim_size]; + for (int i = 0; i < aDim_size; ++i) { + aDim_array[i] = aDim[i]; + } + auto cDim = op->getOutput()->getDims(); + int cDim_size = cDim.size(); + int cDim_array[cDim_size]; + for (int i = 0; i < cDim_size; ++i) { + cDim_array[i] = cDim[i]; + } + cnnlTensorDescriptor_t aDesc, cDesc; + // input + checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, aDim_size, aDim_array)); + // output + checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, cDim_size, cDim_array)); + + cnnlStatus_t stat = + cnnlStridedSlice(context->cnnlHandle(), aDesc, aData, starts_array, + ends_array, steps_array, cDesc, cData); + if (stat != CNNL_STATUS_SUCCESS) + return; + + checkCnnlError(cnnlDestroyTensorDescriptor(aDesc)); + checkCnnlError(cnnlDestroyTensorDescriptor(cDesc)); + } +}; + +REGISTER_KERNEL(Device::BANG, OpType::Slice, DataType::Float32, SliceCnnl, + "Slice_cnnl_BANG_Float32"); +}; // namespace infini diff --git a/test/bang/test_cncl_comm.cc b/test/bang/test_cncl_comm.cc new file mode 100644 index 00000000..50b47434 --- /dev/null +++ b/test/bang/test_cncl_comm.cc @@ -0,0 +1,58 @@ +#ifdef INFINI_USE_CNCL +#include "bang/bang_runtime.h" +#include "bang/cncl_communicator.h" +#include "test.h" + +static int WORLD_SIZE = 2; + +namespace infini { + +void allReduceSum(float *data, int deviceId) { + // Create Runtime and setup communication + BangRuntimeObj *bang_runtime = new BangRuntimeObj(deviceId); + int rank = deviceId; + bang_runtime->initComm("test_cncl_comm", WORLD_SIZE, rank); + cnclComm_t comm = + dynamic_cast(bang_runtime->getCommunicator()) + .getCnclComm(); + cnrtQueue_t queue = bang_runtime->getBangQueue(); + // Copy data + float *data_mlu; + checkBangError(cnrtMalloc((void **)&data_mlu, sizeof(float))); + checkBangError( + cnrtMemcpy(data_mlu, data, sizeof(float), cnrtMemcpyHostToDev)); + // Do AllReduce + CNCL_CHECK( + cnclAllReduce(data_mlu, data_mlu, 1, cnclFloat, cnclSum, comm, queue)); + + checkBangError(cnrtQueueSync(queue)); + // Copy data back and sync device + checkBangError( + cnrtMemcpy(data, data_mlu, sizeof(float), cnrtMemcpyDevToHost)); + ASSERT_EQ(*data, 5.0f); +} + +// Setup communication between 2 threads, each controlling 1 MLU. +// Do AllReduce Sum on {1.0, 4.0}. Results should be {5.0, 5.0}. +TEST(CNCL, multi_mlu_communication) { + float data[] = {1.0, 4.0}; + + for (int i = 0; i < WORLD_SIZE; ++i) { + pid_t pid = fork(); + if (pid == 0) { + // Child process + allReduceSum(&data[i], i); + exit(0); // Ensure child process exits to avoid unnecessary + // repetition in parent + } else if (pid < 0) { + std::cerr << "Error creating process" << std::endl; + } + } + // Wait for all child processes to finish + for (int i = 0; i < WORLD_SIZE; ++i) { + wait(NULL); + } +} + +} // namespace infini +#endif diff --git a/test/kernels/bang/test_bang_all_gather.cc b/test/kernels/bang/test_bang_all_gather.cc new file mode 100644 index 00000000..038cc7ab --- /dev/null +++ b/test/kernels/bang/test_bang_all_gather.cc @@ -0,0 +1,60 @@ +#ifdef INFINI_USE_CNCL +#include "bang/bang_runtime.h" +#include "bang/cncl_communicator.h" +#include "core/graph.h" +#include "core/runtime.h" +#include "operators/all_gather.h" +#include "test.h" +#include +#include + +static int WORLD_SIZE = 2; + +namespace infini { + +void allGather(const string taskName, int deviceID, vector data, + vector> ans) { + // Create Runtimes and initiate communication + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + Runtime bangRuntime = make_ref(deviceID); + bangRuntime->initComm(taskName, WORLD_SIZE, deviceID); + // Create Graph and insert allReduce operation + Graph g = make_ref(bangRuntime); + auto input = + g->addTensor(Shape{static_cast(data.size())}, DataType::Float32); + auto op = g->addOp(input, std::nullopt, WORLD_SIZE); + // Copy data from CPU to MLU + g->dataMalloc(); + input->copyin(data); + // Run operation + bangRuntime->run(g); + // Copy output from MLU to CPU + for (int i = 0; i < WORLD_SIZE; ++i) { + auto result = op->getOutputs()[i]->clone(cpuRuntime); + EXPECT_TRUE(result->equalData(ans[i])); + } +} + +TEST(BANG_AllGather, run) { + vector data[2] = {{2., 3.}, {5., 6.}}; + vector> ans = {{2., 3.}, {5., 6.}}; + + for (int i = 0; i < WORLD_SIZE; ++i) { + pid_t pid = fork(); + if (pid == 0) { + // Child process + allGather("test_all_gather", i, data[i], ans); + exit(0); // Ensure child process exits to avoid unnecessary + // repetition in parent + } else if (pid < 0) { + std::cerr << "Error creating process" << std::endl; + } + } + // Wait for all child processes to finish + for (int i = 0; i < WORLD_SIZE; ++i) { + wait(NULL); + } +} + +} // namespace infini +#endif diff --git a/test/kernels/bang/test_bang_all_reduce.cc b/test/kernels/bang/test_bang_all_reduce.cc new file mode 100644 index 00000000..a10a9288 --- /dev/null +++ b/test/kernels/bang/test_bang_all_reduce.cc @@ -0,0 +1,124 @@ +#ifdef INFINI_USE_CNCL +#include "bang/bang_runtime.h" +#include "bang/cncl_communicator.h" +#include "core/graph.h" +#include "core/runtime.h" +#include "operators/all_reduce.h" +#include "test.h" +#include +#include +#include + +static int WORLD_SIZE = 2; + +namespace infini { + +template +void allReduce(const string taskName, int deviceID, vector data, + vector ans) { + // Create Runtimes and initiate communication + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + Runtime bangRuntime = make_ref(deviceID); + bangRuntime->initComm(taskName, WORLD_SIZE, deviceID); + // Create Graph and insert allReduce operation + Graph g = make_ref(bangRuntime); + auto input = + g->addTensor(Shape{static_cast(data.size())}, DataType::Float32); + auto op = g->addOp(input, nullptr); + // Copy data from CPU to MLU + g->dataMalloc(); + input->copyin(data); + // Run operation + bangRuntime->run(g); + // Copy output from MLU to CPU + auto result = op->getOutput()->clone(cpuRuntime); + + EXPECT_TRUE(result->equalData(ans)); +} + +TEST(BANG_AllReduce, sum) { + vector data[2] = {{2., 3.}, {5., 6.}}; + vector ans = {7., 9.}; + + for (int i = 0; i < WORLD_SIZE; ++i) { + pid_t pid = fork(); + if (pid == 0) { + // Child process + allReduce("test_allreduce_sum", i, data[i], ans); + exit(0); // Ensure child process exits to avoid unnecessary + // repetition in parent + } else if (pid < 0) { + std::cerr << "Error creating process" << std::endl; + } + } + // Wait for all child processes to finish + for (int i = 0; i < WORLD_SIZE; ++i) { + wait(NULL); + } +} + +TEST(BANG_AllReduce, prod) { + vector data[2] = {{2., 3.}, {5., 6.}}; + vector ans = {10., 18.}; + + for (int i = 0; i < WORLD_SIZE; ++i) { + pid_t pid = fork(); + if (pid == 0) { + // Child process + allReduce("test_allreduce_prod", i, data[i], ans); + exit(0); // Ensure child process exits to avoid unnecessary + // repetition in parent + } else if (pid < 0) { + std::cerr << "Error creating process" << std::endl; + } + } + // Wait for all child processes to finish + for (int i = 0; i < WORLD_SIZE; ++i) { + wait(NULL); + } +} + +TEST(BANG_AllReduce, min) { + vector data[2] = {{2., 3.}, {5., 6.}}; + vector ans = {2., 3.}; + + for (int i = 0; i < WORLD_SIZE; ++i) { + pid_t pid = fork(); + if (pid == 0) { + // Child process + allReduce("test_allreduce_min", i, data[i], ans); + exit(0); // Ensure child process exits to avoid unnecessary + // repetition in parent + } else if (pid < 0) { + std::cerr << "Error creating process" << std::endl; + } + } + // Wait for all child processes to finish + for (int i = 0; i < WORLD_SIZE; ++i) { + wait(NULL); + } +} + +TEST(BANG_AllReduce, max) { + vector data[2] = {{2., 3.}, {5., 6.}}; + vector ans = {5., 6.}; + + for (int i = 0; i < WORLD_SIZE; ++i) { + pid_t pid = fork(); + if (pid == 0) { + // Child process + allReduce("test_allreduce_max", i, data[i], ans); + exit(0); // Ensure child process exits to avoid unnecessary + // repetition in parent + } else if (pid < 0) { + std::cerr << "Error creating process" << std::endl; + } + } + // Wait for all child processes to finish + for (int i = 0; i < WORLD_SIZE; ++i) { + wait(NULL); + } +} + +} // namespace infini +#endif diff --git a/test/kernels/bang/test_bang_broadcast.cc b/test/kernels/bang/test_bang_broadcast.cc new file mode 100644 index 00000000..e05666f7 --- /dev/null +++ b/test/kernels/bang/test_bang_broadcast.cc @@ -0,0 +1,65 @@ +#ifdef INFINI_USE_CNCL +#include "bang/bang_runtime.h" +#include "bang/cncl_communicator.h" +#include "core/graph.h" +#include "core/runtime.h" +#include "operators/broadcast.h" +#include "test.h" +#include +#include + +static int WORLD_SIZE = 2; +static int root = 0; + +namespace infini { + +void broadcast(const string taskName, int deviceID, vector data, + vector ans) { + // Create Runtimes and initiate communication + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + Runtime bangRuntime = make_ref(deviceID); + bangRuntime->initComm(taskName, WORLD_SIZE, deviceID); + // Create Graph and insert allReduce operation + Graph g = make_ref(bangRuntime); + auto input = + g->addTensor(Shape{static_cast(data.size())}, DataType::Float32); + auto op = g->addOp(input, nullptr, root); + // Copy data from CPU to GPU + g->dataMalloc(); + // Only rank 0 has the data + if (deviceID == root) { + input->copyin(data); + } + // Run broadcast operation + bangRuntime->run(g); + // Copy output from GPU to CPU + auto result = op->getOutput()->clone(cpuRuntime); + + EXPECT_TRUE(result->equalData(ans)); +} + +TEST(BANG_Broadcast, run) { + // Only 1 device gets data. Every rank should have the same data after + // broadcast. + vector data = {2., 3., 5., 6.}; + vector ans = {2., 3., 5., 6.}; + + for (int i = 0; i < WORLD_SIZE; ++i) { + pid_t pid = fork(); + if (pid == 0) { + // Child process + broadcast("test_broadcast", i, data, ans); + exit(0); // Ensure child process exits to avoid unnecessary + // repetition in parent + } else if (pid < 0) { + std::cerr << "Error creating process" << std::endl; + } + } + // Wait for all child processes to finish + for (int i = 0; i < WORLD_SIZE; ++i) { + wait(NULL); + } +} + +} // namespace infini +#endif diff --git a/test/kernels/bang/test_bang_reduce.cc b/test/kernels/bang/test_bang_reduce.cc new file mode 100644 index 00000000..485e16f0 --- /dev/null +++ b/test/kernels/bang/test_bang_reduce.cc @@ -0,0 +1,82 @@ +#include "bang/bang_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/reduce.h" + +#include "test.h" + +namespace infini { + +template +void test_reduce(const Shape &shape, const vector &data, + const optional> &axis, bool keepDims, + const vector &ExpectData) { + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto bangRuntime = make_ref(); + + // Build input data on CPU + Tensor icpu = make_ref(shape, DataType::Float32, cpuRuntime); + + // Build BANG graph + Graph g = make_ref(bangRuntime); + auto i = g->cloneTensor(icpu); + auto op = g->addOp(i, nullptr, axis, keepDims); + + // allocate BANG memory + g->dataMalloc(); + i->copyin(data); + + // Execute on BANG + bangRuntime->run(g); + + // clone BANG output to CPU + auto o = op->getOutput(); + auto ocpu = o->clone(cpuRuntime); + + // check results on CPU + EXPECT_TRUE(ocpu->equalData(ExpectData)); +} + +TEST(BANG_ReduceMean, run) { + test_reduce( + Shape{3, 2, 2}, vector{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2}, + std::nullopt, true, vector{18.25}); + test_reduce( + Shape{1, 3, 2, 2, 1}, + vector{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2}, std::nullopt, + false, vector{18.25}); + + test_reduce( + Shape{2, 3, 2, 2}, + vector{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, + vector{1, 2}, false, vector{5, 6, 17, 18}); + test_reduce( + Shape{2, 3, 2, 2, 1}, + vector{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, + vector{1, 2}, true, vector{5, 6, 17, 18}); +} + +TEST(BANG_ReduceSum, run) { + test_reduce(Shape{3, 2, 2}, + vector{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + std::nullopt, true, vector{12}); + test_reduce(Shape{1, 3, 2, 2, 1}, + vector{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + std::nullopt, false, vector{12}); + + test_reduce( + Shape{2, 3, 2, 2}, + vector{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, + vector{1, 2}, false, vector{30, 36, 102, 108}); + test_reduce( + Shape{2, 3, 2, 2, 1}, + vector{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, + vector{1, 2}, true, vector{30, 36, 102, 108}); +} + +} // namespace infini diff --git a/test/kernels/bang/test_bang_slice.cc b/test/kernels/bang/test_bang_slice.cc new file mode 100644 index 00000000..0f932409 --- /dev/null +++ b/test/kernels/bang/test_bang_slice.cc @@ -0,0 +1,39 @@ +#include "bang/bang_runtime.h" +#include "core/graph.h" +#include "core/runtime.h" +#include "operators/slice.h" +#include "test.h" + +namespace infini { +TEST(BANG_Slice, run) { + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto bangRuntime = make_ref(); + + // Build input data on CPU + Tensor icpu = + make_ref(Shape{3, 2, 1, 5}, DataType::Float32, cpuRuntime); + icpu->dataMalloc(); + icpu->setData(IncrementalGenerator()); + + // Build CUDA graph; + Graph g = make_ref(bangRuntime); + auto i = g->cloneTensor(icpu); + auto op = + g->addOp(i, nullptr, vector{1, 1}, vector{2, 5}, + vector{0, 3}, std::nullopt); + + // allocate CUDA memory + g->dataMalloc(); + i->setData(IncrementalGenerator()); + + // Execute on CUDA + bangRuntime->run(g); + + // clone CUDA output to CPU + auto o = op->getOutput(); + auto cpuo = o->clone(cpuRuntime); + // bangPrintTensor(o); + // check results on CPU + EXPECT_TRUE(cpuo->equalData(vector{11, 12, 13, 14, 16, 17, 18, 19})); +} +} // namespace infini