Bang cncl (#163)

* MLU CNCL base

* add FindCNCL.cmake, not find -lcncl

* bangPrintFloat not find

* docker:make sucessful, test error

* delete net file and onnxtest.py

* init

* fix cncl

* format

* fix

* format

* fix cncl

* run dist gpt2 on mlu

* format

* fix import error on mlu docker

* run llama single card

* run distributed llama2

* add test for slice/reduce on mlu

* fix cncl related test

* fix format

* format

* delete comments

* change GPU to MLU

* MLU CNCL base

* add FindCNCL.cmake, not find -lcncl

* bangPrintFloat not find

* docker:make sucessful, test error

* delete net file and onnxtest.py

* init

* fix cncl

* format

* fix

* format

* fix cncl

* run dist gpt2 on mlu

* format

* fix import error on mlu docker

* run llama single card

* run distributed llama2

* add test for slice/reduce on mlu

* fix cncl related test

* fix format

* format

* delete comments

* change GPU to MLU

* modify launch script

* fix name

* fix format

* fix gather

* format python script

---------

Co-authored-by: xgqdut2016 <kenan_gewei@163.com>
Co-authored-by: Bolun <chamberlain0w0@gmail.com>
Co-authored-by: Bolun Zhang <48948016+Chamberlain0w0@users.noreply.github.com>
This commit is contained in:
Hardy 2024-01-03 13:28:03 +08:00 committed by GitHub
parent 83f1de93d0
commit 42032356fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 1040 additions and 16 deletions

View File

@ -13,7 +13,7 @@ if(USE_CUDA)
message("CMake 3.18 or higher is required for setting CUDAToolkit")
cmake_minimum_required(VERSION 3.18) # FindCUDAToolkit
else()
cmake_minimum_required(VERSION 3.12)
cmake_minimum_required(VERSION 3.17)
endif()
include(CMakeDependentOption)
@ -245,6 +245,7 @@ if(USE_BANG)
find_library(CAMBRICON_CNNL libcnnl.so "${NEUWARE_HOME}/lib64")
find_library(CAMBRICON_CNRT libcnrt.so "${NEUWARE_HOME}/lib64")
find_library(CAMBRICON_CNDRV libcndrv.so "${NEUWARE_HOME}/lib64")
find_library(CAMBRICON_CNCL libcncl.so "${NEUWARE_HOME}/lib64")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror")
if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH}))
@ -261,7 +262,13 @@ if(USE_BANG)
# BangC Kernels
################################################################################
target_link_libraries(InfiniTensor ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++)
target_link_libraries(InfiniTensor ${CAMBRICON_CNCL} ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++)
if (BUILD_DIST)
message(STATUS "Add BUILD_DIST, use CNCL with BANG")
add_compile_definitions(INFINI_USE_CNCL=1)
endif()
endif()
if(USE_KUNLUN)
@ -324,6 +331,7 @@ if(BUILD_TEST)
endif()
if (USE_BANG)
build_test(test/kernels/bang/*.cc)
build_test(test/bang/*.cc)
endif()
if (USE_KUNLUN)
build_test(test/kernels/kunlun/*.cc)

View File

@ -29,6 +29,7 @@ CMAKE_OPT += -DUSE_BANG=$(BANG)
CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN)
CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE)
CMAKE_OPT += -DBUILD_TEST=$(TEST)
CMAKE_OPT += -DBUILD_DIST=ON
CMAKE_OPT += -DBUILD_NNET=$(NNET)
ifeq ($(INTELCPU), ON)

76
cmake/FindCNCL.cmake Normal file
View File

@ -0,0 +1,76 @@
SET(CNCL_LIB_SEARCH_PATHS $ENV{NEUWARE_HOME}/lib64)
SET(CNCL_INCLUDE_SEARCH_PATHS $ENV{NEUWARE_HOME}/include)
set(CNCL_INCLUDE_DIR $ENV{NEUWARE_HOME}/include)
set(CNCL_LIB_DIR $ENV{NEUWARE_HOME}/lib64)
set(CNCL_VERSION $ENV{CNCL_VERSION} CACHE STRING "Version of CNCL to build with")
if ($ENV{CNCL_ROOT_DIR})
message(WARNING "CNCL_ROOT_DIR is deprecated. Please set CNCL_ROOT instead.")
endif()
list(APPEND CNCL_ROOT $ENV{CNCL_ROOT_DIR} ${MLU_TOOLKIT_ROOT_DIR})
# Compatible layer for CMake <3.12. CNCL_ROOT will be accounted in for searching paths and libraries for CMake >=3.12.
list(APPEND CMAKE_PREFIX_PATH ${CNCL_ROOT})
find_path(CNCL_INCLUDE_DIRS
NAMES cncl.h
HINTS ${CNCL_INCLUDE_DIR})
if (USE_STATIC_CNCL)
MESSAGE(STATUS "USE_STATIC_CNCL is set. Linking with static CNCL library.")
SET(CNCL_LIBNAME "CNCL_static")
if (CNCL_VERSION) # Prefer the versioned library if a specific CNCL version is specified
set(CMAKE_FIND_LIBRARY_SUFFIXES ".a.${CNCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES})
endif()
else()
SET(CNCL_LIBNAME "cncl")
if (CNCL_VERSION) # Prefer the versioned library if a specific CNCL version is specified
set(CMAKE_FIND_LIBRARY_SUFFIXES ".so.${CNCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES})
endif()
endif()
find_library(CNCL_LIBRARIES
NAMES ${CNCL_LIBNAME}
HINTS ${CNCL_LIB_DIR})
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(CNCL DEFAULT_MSG CNCL_INCLUDE_DIRS CNCL_LIBRARIES)
if(CNCL_FOUND) # obtaining CNCL version and some sanity checks
set (CNCL_HEADER_FILE "${CNCL_INCLUDE_DIRS}/cncl.h")
message (STATUS "Determining CNCL version from ${CNCL_HEADER_FILE}...")
set (OLD_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES})
list (APPEND CMAKE_REQUIRED_INCLUDES ${CNCL_INCLUDE_DIRS})
include(CheckCXXSymbolExists)
check_cxx_symbol_exists(CNCL_VERSION_CODE CNCL.h CNCL_VERSION_DEFINED)
if (CNCL_VERSION_DEFINED)
set(file "${PROJECT_BINARY_DIR}/detect_cncl_version.cc")
file(WRITE ${file} "
#include <iostream>
#include <cncl.h>
int main()
{
std::cout << CNCL_MAJOR << '.' << CNCL_MINOR << '.' << CNCL_PATCH << std::endl;
int x;
CNCLGetVersion(&x);
return x == CNCL_VERSION_CODE;
}
")
try_run(CNCL_VERSION_MATCHED compile_result ${PROJECT_BINARY_DIR} ${file}
RUN_OUTPUT_VARIABLE CNCL_VERSION_FROM_HEADER
CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CNCL_INCLUDE_DIRS}"
LINK_LIBRARIES ${CNCL_LIBRARIES})
if (NOT CNCL_VERSION_MATCHED)
message(FATAL_ERROR "Found CNCL header version and library version do not match! \
(include: ${CNCL_INCLUDE_DIRS}, library: ${CNCL_LIBRARIES}) Please set CNCL_INCLUDE_DIR and CNCL_LIB_DIR manually.")
endif()
message(STATUS "CNCL version: ${CNCL_VERSION_FROM_HEADER}")
else()
# message(STATUS "CNCL version < 2.3.5-5")
endif ()
set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES})
message(STATUS "Found CNCL (include: ${CNCL_INCLUDE_DIRS}, library: ${CNCL_LIBRARIES})")
mark_as_advanced(CNCL_ROOT_DIR CNCL_INCLUDE_DIRS CNCL_LIBRARIES)
endif()

View File

@ -0,0 +1,196 @@
import argparse
import os
import time
import multiprocessing as mp
from pyinfinitensor.onnx import OnnxStub, backend
import onnx
from onnx.shape_inference import infer_shapes_path
import numpy as np
from parallel_opt import parallel_model
def parse_args():
parser = argparse.ArgumentParser(description="launch distributed infinitensor")
parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
parser.add_argument(
"--nproc_per_node", type=int, default=2, help="number of processes per node"
)
parser.add_argument(
"--name", type=str, default="test", help="name of this instance."
)
parser.add_argument(
"--model", type=str, default="/data/onnx_models/llama2/llama_bs1_seq1024.onnx",
help="path to the ONNX model file."
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
parser.add_argument("--length", type=int, default=1, help="sequence length.")
parser.add_argument(
"--gen_std",
default=False,
action="store_true",
help="whether to generate the standard results.",
)
args = parser.parse_args()
print("arg setting: ", args)
return (
args.num_nodes,
args.nproc_per_node,
args.name,
args.model,
args.batch_size,
args.length,
args.gen_std,
)
def run_model(model, runtime, world_size=1, rank=0, n=10):
stub = OnnxStub(model, runtime)
load_inputs(stub, world_size, rank)
# stub.tune()
stub.run()
# get outputs
time.sleep(0.01)
outputs = next(stub.outputs.values().__iter__()).copyout_numpy()
# bench
begin = time.time()
for _ in range(n):
stub.run()
end = time.time()
avg_time = (end - begin) / n
print(f"average time: {avg_time}")
return outputs
def run_and_compare(name, model, runtime, world_size=1, rank = 0):
results = np.load(f"./data/output.npy")
outputs = run_model(model, runtime, world_size, rank)
print("answer argmax:", np.argmax(results))
print("output argmax:", np.argmax(outputs))
#np.testing.assert_allclose(outputs, results, rtol=1e-3, atol=1e-3)
getDiff(results, outputs)
def start_worker(
name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
):
dist_name = name + "_dist"
model = parallel_model(model, world_size, rank)
extern_path = f"./{dist_name}_rank{rank}.pb"
if os.path.exists(extern_path):
os.remove(extern_path)
onnx.save_model(
model,
f"./{dist_name}_rank{rank}.onnx",
save_as_external_data=True,
location=extern_path,
)
infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
runtime = backend.BangRuntime(local_rank)
# print("init comm")
runtime.init_comm(
dist_name,
world_size,
rank,
)
run_and_compare(name, model, runtime, world_size, rank)
def start_single(name, model):
runtime = backend.BangRuntime(0)
run_and_compare(name, model, runtime)
def generate_input_output(model):
os.makedirs(os.path.dirname("./data/"), exist_ok=True)
runtime = backend.BangRuntime(0)
stub = OnnxStub(model, runtime)
position_id = 0
for i, (name, tensor) in enumerate(stub.inputs.items()):
input = tensor.copyout_numpy()
if np.issubdtype(input.dtype, np.integer):
if input.size == 1:
# input = np.array([position_id])
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
else:
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
elif input.dtype == np.bool_:
input = np.random.randint(0,2,size=input.shape) > 0
else:
if i == 0:
input = np.ones(input.shape).astype(input.dtype)
position_id = input.shape[-1] - 1
else:
input = np.random.rand(*input.shape).astype(input.dtype)
tensor.copyin_numpy(input)
np.save(f"./data/input_{i}", input)
stub.run()
time.sleep(0.01)
output = next(stub.outputs.values().__iter__()).copyout_numpy()
if np.isnan(output).any():
print("Nan in output")
np.save(f"./data/output", output)
def load_inputs(stub, world_size=1, rank=0):
for i, (name, tensor) in enumerate(stub.inputs.items()):
input = np.load(f"./data/input_{i}.npy")
if all(x == y for x,y in zip(input.shape,tensor.shape())):
tensor.copyin_numpy(input)
else:
tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
def getDiff(base, test):
absolute_diff = np.abs(np.subtract(base, test))
max_absolute_diff = np.max(absolute_diff)
baseCopy = base.astype(np.float64).ravel()
testCopy = test.astype(np.float64).ravel()
upValue = np.sum(np.abs(baseCopy - testCopy))
downValue = np.sum(np.abs(baseCopy)) + np.float64(1e-9)
max_relative_diff = upValue / downValue
print(f"Max absolute difference: {max_absolute_diff}\n"
f"Max relative difference: {max_relative_diff}")
return max_absolute_diff, max_relative_diff
def main():
nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args()
model = onnx.load(model_path)
# generate standart output
if gen_std:
print("Generate inputs and outputs.")
p = mp.Process(target=generate_input_output, args=[model])
p.start()
p.join()
return
# run single process.
# use standalone process to isolate cuda.
print("run model by single MLU.")
p = mp.Process(target=start_single, args=(name, model))
p.start()
p.join()
# run distributed parallel.
world_size = nnodes * nproc_per_node
print(f"run model by {world_size} MLUs in parallel.")
workers = [
mp.Process(
target=start_worker,
args=(name, world_size, rank, rank % nproc_per_node, model),
)
for rank in range(world_size)
]
for w in workers:
w.start()
for w in workers:
w.join()
if __name__ == "__main__":
main()

View File

@ -115,7 +115,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
assert out_dims[s_dim] % tp_world_size == 0, out_dims
out_dims[s_dim] //= tp_world_size
# if ONNX uses the same tensor for multiple Reshape Nodes, then rename it to distingush from others.
# node.input[1] = node.output[0] + "_shape"
node.input[1] = node.output[0] + "_shape"
data[node.input[1]] = numpy_helper.from_array(out_dims, name=node.input[1])
place[node.output[0]] = Shard(s_dim)

View File

@ -7,17 +7,19 @@ namespace infini {
class BangRuntimeObj : public RuntimeObj {
private:
cnnlHandle_t cnnl;
cnrtQueue_t queue;
std::unique_ptr<CommunicatorObj> comm;
BangPtr workspace;
size_t workspaceSize;
mutable size_t cursor;
public:
BangRuntimeObj() : RuntimeObj(Device::BANG) {
explicit BangRuntimeObj(int deviceId = 0)
: RuntimeObj(Device::BANG, deviceId) {
cnInit(0);
CNdev dev;
cnDeviceGet(&dev, 0);
cnDeviceGet(&dev, deviceId);
checkBangError(cnrtSetDevice(dev));
cnrtQueue_t queue;
checkBangError(cnrtQueueCreate(&queue));
checkCnnlError(cnnlCreate(&cnnl));
@ -30,6 +32,7 @@ class BangRuntimeObj : public RuntimeObj {
}
virtual ~BangRuntimeObj() {
dealloc(workspace);
checkBangError(cnrtQueueDestroy(queue));
checkCnnlError(cnnlDestroy(cnnl));
}
string toString() const override;
@ -73,10 +76,9 @@ class BangRuntimeObj : public RuntimeObj {
checkBangError(cnrtMemcpy(dst, const_cast<void *>(src), bytes,
CNRT_MEM_TRANS_DIR_PEER2PEER));
}
void initComm(const string &, int, int) override { IT_TODO_HALT(); }
CommunicatorObj &getCommunicator() const override { IT_TODO_HALT(); }
void initComm(const string &name, int worldSize, int rank) final;
CommunicatorObj &getCommunicator() const override { return *comm; }
cnrtQueue_t getBangQueue() const { return queue; }
private:
void runWithoutSync(const Graph &graph, bool tune, bool profiling) const;

View File

@ -0,0 +1,79 @@
#pragma once
#include "bang_common.h"
#include "core/communicator.h"
#include <chrono>
#include <cncl.h>
#include <cnrt.h>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <mutex>
#include <thread>
namespace infini {
class CnclCommunicatorObj final : public CommunicatorObj {
private:
cnclComm_t *comms;
public:
CnclCommunicatorObj(const string &name, int worldSize, int rank)
: CommunicatorObj(worldSize, rank) {
const std::string filePath("./" + name + "_cncl_id.bin");
cnclCliqueId clique_id;
if (rank == 0) {
CNCL_CHECK(cnclGetCliqueId(&clique_id));
std::ofstream ofs(filePath, std::ios::binary);
ofs.write((char *)&clique_id, sizeof(cnclCliqueId));
} else {
auto begin = std::chrono::steady_clock::now();
while (!std::filesystem::exists(filePath)) {
auto now = std::chrono::steady_clock::now();
_IT_ASSERT_2(now < begin + std::chrono::seconds(10),
"time limit (10s) exceeded.");
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
std::ifstream ifs(filePath, std::ios::binary);
ifs.read((char *)&clique_id, sizeof(cnclCliqueId));
}
int num_comms = 1;
int *dev_list = new int[num_comms];
int *rank_list = new int[num_comms];
comms = new cnclComm_t[num_comms];
uint32_t num_dev = 0;
checkBangError(cnrtGetDeviceCount(&num_dev));
for (int i = 0; i < num_comms; i++) {
rank_list[i] = rank;
dev_list[i] = rank_list[i] % num_dev;
}
CNCL_CHECK(cnclInitComms(comms, num_comms, dev_list, rank_list,
worldSize, &clique_id));
if (rank == 0) {
std::filesystem::remove(filePath);
}
delete[] dev_list;
delete[] rank_list;
}
~CnclCommunicatorObj() {
CNCL_CHECK(cnclDestroyComms(comms, 1));
delete[] comms;
}
// Get the actual cnclComm_t
cnclComm_t getCnclComm() { return comms[0]; }
virtual string toString() const final {
std::ostringstream oss;
oss << "CNCL communicator";
return oss.str();
}
};
} // namespace infini

View File

@ -8,7 +8,9 @@
#if USE_CUDA
#include "cuda/cuda_runtime.h"
#endif
#if USE_BANG
#include "bang/bang_runtime.h"
#endif
namespace infini {
// TODO: how to deal with this

View File

@ -1,6 +1,9 @@
#include "bang/bang_runtime.h"
#include "core/kernel.h"
#include "core/perf_engine.h"
#ifdef INFINI_USE_CNCL
#include "bang/cncl_communicator.h"
#endif
namespace infini {
@ -59,4 +62,15 @@ void BangRuntimeObj::sync() const { cnrtSyncDevice(); }
string BangRuntimeObj::toString() const { return "BANG Runtime"; }
void BangRuntimeObj::initComm(const string &name, int worldSize, int rank) {
IT_ASSERT(worldSize > 0);
IT_ASSERT(rank >= 0);
IT_ASSERT(rank < worldSize);
IT_ASSERT(!comm) << "communicator is already initialized.";
#ifdef INFINI_USE_CNCL
comm = std::make_unique<CnclCommunicatorObj>(name, worldSize, rank);
#else
IT_TODO_HALT_MSG("Not compiled with CNCL.");
#endif
}
} // namespace infini

View File

@ -399,7 +399,9 @@ void init_graph_builder(py::module &m) {
#endif
#ifdef USE_BANG
py::class_<BangRuntimeObj, std::shared_ptr<BangRuntimeObj>, RuntimeObj>(
m, "BangRuntime");
m, "BangRuntime")
.def(py::init<int>(), py::arg("device") = 0)
.def("init_comm", &BangRuntimeObj::initComm);
#endif
#ifdef USE_KUNLUN
py::class_<KUNLUNRuntimeObj, std::shared_ptr<KUNLUNRuntimeObj>, RuntimeObj>(

View File

@ -0,0 +1,49 @@
#ifdef INFINI_USE_CNCL
#include "operators/all_gather.h"
#include "bang/bang_kernel_without_config.h"
#include "bang/bang_runtime.h"
#include "bang/cncl_communicator.h"
#include <thread>
namespace infini {
class AllGatherCNCL : public BangKernelWithoutConfig {
public:
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<AllGatherObj>(_op);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
int world_size = op->getWorldSize();
// Check if world size info in operator matches runtime
IT_ASSERT(world_size == context->getCommunicator().getWorldSize());
void *input = op->getInputs(0)->getRawDataPtr<void *>();
BangPtr output_temp =
context->getWorkspace(op->getInputs(0)->getBytes() * world_size);
// void *output = op->getOutput()->getRawDataPtr<void *>();
// IT_ASSERT(op->getDType() == DataType::Float32);
checkBangError(cnrtMalloc(&output_temp,
op->getInputs(0)->getBytes() * world_size));
size_t bytes = op->getInputs(0)->getBytes();
size_t count = bytes / op->getDType().getSize();
cnclComm_t comm =
dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
.getCnclComm();
cnrtQueue_t queue = context->getBangQueue();
CNCL_CHECK(
cnclAllGather(input, output_temp, count, cnclFloat32, comm, queue));
checkBangError(cnrtQueueSync(queue));
for (int i = 0; i < world_size; ++i) {
Tensor output = op->getOutput(i);
context->copyBlobInsideRuntime(
output->getRawDataPtr<float *>(),
static_cast<float *>(output_temp) + i * count, bytes);
}
checkBangError(cnrtFree(output_temp));
}
};
REGISTER_KERNEL(Device::BANG, OpType::AllGather, DataType::Float32,
AllGatherCNCL, "AllGather_CNCL_BANG_Float32");
} // namespace infini
#endif

View File

@ -0,0 +1,53 @@
#ifdef INFINI_USE_CNCL
#include "operators/all_reduce.h"
#include "bang/bang_kernel_without_config.h"
#include "bang/bang_runtime.h"
#include "bang/cncl_communicator.h"
#include <thread>
namespace infini {
class AllReduceCNCL : public BangKernelWithoutConfig {
public:
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<AllReduceBaseObj>(_op);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *input = op->getInputs(0)->getRawDataPtr<void *>();
void *output = op->getOutput()->getRawDataPtr<void *>();
IT_ASSERT(op->getDType() == DataType::Float32);
size_t count = op->getInputs(0)->size();
cnclComm_t comm =
dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
.getCnclComm();
cnrtQueue_t queue = context->getBangQueue();
// checkBangError(cnrtQueueSync(queue));
CNCL_CHECK(cnclAllReduce(input, output, count, cnclFloat, getRedOp(),
comm, queue));
checkBangError(cnrtQueueSync(queue));
}
virtual cnclReduceOp_t getRedOp() const = 0;
};
class AllReduceSumCNCL : public AllReduceCNCL {
cnclReduceOp_t getRedOp() const override { return cnclSum; }
};
class AllReduceProdCNCL : public AllReduceCNCL {
cnclReduceOp_t getRedOp() const override { return cnclProd; }
};
class AllReduceMinCNCL : public AllReduceCNCL {
cnclReduceOp_t getRedOp() const override { return cnclMin; }
};
class AllReduceMaxCNCL : public AllReduceCNCL {
cnclReduceOp_t getRedOp() const override { return cnclMax; }
};
REGISTER_KERNEL(Device::BANG, OpType::AllReduceSum, DataType::Float32,
AllReduceSumCNCL, "AllReduce_Sum_CNCL_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::AllReduceProd, DataType::Float32,
AllReduceProdCNCL, "AllReduce_Prod_CNCL_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::AllReduceMin, DataType::Float32,
AllReduceMinCNCL, "AllReduce_Min_CNCL_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::AllReduceMax, DataType::Float32,
AllReduceMaxCNCL, "AllReduce_Max_CNCL_BANG_Float32");
} // namespace infini
#endif

View File

@ -0,0 +1,34 @@
#ifdef INFINI_USE_CNCL
#include "operators/broadcast.h"
#include "bang/bang_kernel_without_config.h"
#include "bang/bang_runtime.h"
#include "bang/cncl_communicator.h"
#include <thread>
namespace infini {
class BroadcastCNCL : public BangKernelWithoutConfig {
public:
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<BroadcastObj>(_op);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *input = op->getInputs(0)->getRawDataPtr<void *>();
void *output = op->getOutput()->getRawDataPtr<void *>();
IT_ASSERT(op->getDType() == DataType::Float32);
size_t count = op->getInputs(0)->getBytes() / op->getDType().getSize();
cnclComm_t comm =
dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
.getCnclComm();
cnrtQueue_t queue = context->getBangQueue();
// TODO: Using default stream 0 for now.
CNCL_CHECK(cnclBroadcast(input, output, count, cnclFloat32,
op->getRoot(), comm, queue));
checkBangError(cnrtQueueSync(queue));
}
};
REGISTER_KERNEL(Device::BANG, OpType::Broadcast, DataType::Float32,
BroadcastCNCL, "Broadcast_CNCL_BANG_Float32");
} // namespace infini
#endif

View File

@ -23,6 +23,8 @@ class GatherCnnl : public BangKernelWithoutConfig {
CNNL_DTYPE_FLOAT, aDim.size(),
aDim.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
checkCnnlError(
cnnlSetTensorDescriptorPointerMode(bDesc, CNNL_POINTER_MODE_HOST));
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_INT32, bDim.size(),
bDim.data()));

View File

@ -1,12 +1,14 @@
#include "operators/reduce.h"
#include "bang/bang_kernel_without_config.h"
#include "bang/bang_runtime.h"
#include "operators/reduce.h"
namespace infini {
class ReduceMeanCnnl : public BangKernelWithoutConfig {
class ReduceCnnlBase : public BangKernelWithoutConfig {
virtual cnnlReduceOp_t getReduceOp() const = 0;
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ReduceMeanObj>(_op);
auto op = as<ReduceBaseObj>(_op);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
@ -34,7 +36,7 @@ class ReduceMeanCnnl : public BangKernelWithoutConfig {
cnnlReduceDescriptor_t reduceDesc;
checkCnnlError(cnnlCreateReduceDescriptor(&reduceDesc));
checkCnnlError(cnnlSetReduceDescriptor_v2(
reduceDesc, axes.data(), axes.size(), CNNL_REDUCE_AVG,
reduceDesc, axes.data(), axes.size(), getReduceOp(),
CNNL_DTYPE_FLOAT, CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES,
CNNL_32BIT_INDICES, 0.0));
@ -63,7 +65,17 @@ class ReduceMeanCnnl : public BangKernelWithoutConfig {
}
};
class ReduceMeanCnnl : public ReduceCnnlBase {
cnnlReduceOp_t getReduceOp() const override { return CNNL_REDUCE_AVG; }
};
class ReduceSumCnnl : public ReduceCnnlBase {
cnnlReduceOp_t getReduceOp() const override { return CNNL_REDUCE_ADD; }
};
REGISTER_KERNEL(Device::BANG, OpType::ReduceMean, DataType::Float32,
ReduceMeanCnnl, "ReduceMean_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::ReduceSum, DataType::Float32,
ReduceSumCnnl, "ReduceSum_cnnl_BANG_Float32");
}; // namespace infini

View File

@ -27,6 +27,8 @@ class CopyBang : public BangKernelWithoutConfig {
// reshape/flatten/identity all act as copying from input to output.
REGISTER_KERNEL(Device::BANG, OpType::Reshape, DataType::Float32, CopyBang,
"Reshape_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Reshape, DataType::Int64, CopyBang,
"Reshape_BANG_Int64");
REGISTER_KERNEL(Device::BANG, OpType::Flatten, DataType::Float32, CopyBang,
"Flatten_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Identity, DataType::Float32, CopyBang,

64
src/kernels/bang/slice.cc Normal file
View File

@ -0,0 +1,64 @@
#include "operators/slice.h"
#include "bang/bang_kernel_without_config.h"
#include "bang/bang_runtime.h"
namespace infini {
class SliceCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<SliceObj>(_op);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
auto starts = op->getStarts();
auto ends = op->getEnds();
auto steps = op->getSteps();
int32_t starts_array[starts.size()];
int32_t ends_array[ends.size()];
int32_t steps_array[steps.size()];
for (size_t i = 0; i < starts.size(); i++) {
starts_array[i] = starts[i];
ends_array[i] = ends[i];
steps_array[i] = steps[i];
}
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
auto aDim = op->getInputs(0)->getDims();
int aDim_size = aDim.size();
int aDim_array[aDim_size];
for (int i = 0; i < aDim_size; ++i) {
aDim_array[i] = aDim[i];
}
auto cDim = op->getOutput()->getDims();
int cDim_size = cDim.size();
int cDim_array[cDim_size];
for (int i = 0; i < cDim_size; ++i) {
cDim_array[i] = cDim[i];
}
cnnlTensorDescriptor_t aDesc, cDesc;
// input
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, aDim_size, aDim_array));
// output
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, cDim_size, cDim_array));
cnnlStatus_t stat =
cnnlStridedSlice(context->cnnlHandle(), aDesc, aData, starts_array,
ends_array, steps_array, cDesc, cData);
if (stat != CNNL_STATUS_SUCCESS)
return;
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
}
};
REGISTER_KERNEL(Device::BANG, OpType::Slice, DataType::Float32, SliceCnnl,
"Slice_cnnl_BANG_Float32");
}; // namespace infini

View File

@ -0,0 +1,58 @@
#ifdef INFINI_USE_CNCL
#include "bang/bang_runtime.h"
#include "bang/cncl_communicator.h"
#include "test.h"
static int WORLD_SIZE = 2;
namespace infini {
void allReduceSum(float *data, int deviceId) {
// Create Runtime and setup communication
BangRuntimeObj *bang_runtime = new BangRuntimeObj(deviceId);
int rank = deviceId;
bang_runtime->initComm("test_cncl_comm", WORLD_SIZE, rank);
cnclComm_t comm =
dynamic_cast<CnclCommunicatorObj &>(bang_runtime->getCommunicator())
.getCnclComm();
cnrtQueue_t queue = bang_runtime->getBangQueue();
// Copy data
float *data_mlu;
checkBangError(cnrtMalloc((void **)&data_mlu, sizeof(float)));
checkBangError(
cnrtMemcpy(data_mlu, data, sizeof(float), cnrtMemcpyHostToDev));
// Do AllReduce
CNCL_CHECK(
cnclAllReduce(data_mlu, data_mlu, 1, cnclFloat, cnclSum, comm, queue));
checkBangError(cnrtQueueSync(queue));
// Copy data back and sync device
checkBangError(
cnrtMemcpy(data, data_mlu, sizeof(float), cnrtMemcpyDevToHost));
ASSERT_EQ(*data, 5.0f);
}
// Setup communication between 2 threads, each controlling 1 MLU.
// Do AllReduce Sum on {1.0, 4.0}. Results should be {5.0, 5.0}.
TEST(CNCL, multi_mlu_communication) {
float data[] = {1.0, 4.0};
for (int i = 0; i < WORLD_SIZE; ++i) {
pid_t pid = fork();
if (pid == 0) {
// Child process
allReduceSum(&data[i], i);
exit(0); // Ensure child process exits to avoid unnecessary
// repetition in parent
} else if (pid < 0) {
std::cerr << "Error creating process" << std::endl;
}
}
// Wait for all child processes to finish
for (int i = 0; i < WORLD_SIZE; ++i) {
wait(NULL);
}
}
} // namespace infini
#endif

View File

@ -0,0 +1,60 @@
#ifdef INFINI_USE_CNCL
#include "bang/bang_runtime.h"
#include "bang/cncl_communicator.h"
#include "core/graph.h"
#include "core/runtime.h"
#include "operators/all_gather.h"
#include "test.h"
#include <cncl.h>
#include <thread>
static int WORLD_SIZE = 2;
namespace infini {
void allGather(const string taskName, int deviceID, vector<float> data,
vector<vector<float>> ans) {
// Create Runtimes and initiate communication
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
Runtime bangRuntime = make_ref<BangRuntimeObj>(deviceID);
bangRuntime->initComm(taskName, WORLD_SIZE, deviceID);
// Create Graph and insert allReduce operation
Graph g = make_ref<GraphObj>(bangRuntime);
auto input =
g->addTensor(Shape{static_cast<int>(data.size())}, DataType::Float32);
auto op = g->addOp<AllGatherObj>(input, std::nullopt, WORLD_SIZE);
// Copy data from CPU to MLU
g->dataMalloc();
input->copyin(data);
// Run operation
bangRuntime->run(g);
// Copy output from MLU to CPU
for (int i = 0; i < WORLD_SIZE; ++i) {
auto result = op->getOutputs()[i]->clone(cpuRuntime);
EXPECT_TRUE(result->equalData(ans[i]));
}
}
TEST(BANG_AllGather, run) {
vector<float> data[2] = {{2., 3.}, {5., 6.}};
vector<vector<float>> ans = {{2., 3.}, {5., 6.}};
for (int i = 0; i < WORLD_SIZE; ++i) {
pid_t pid = fork();
if (pid == 0) {
// Child process
allGather("test_all_gather", i, data[i], ans);
exit(0); // Ensure child process exits to avoid unnecessary
// repetition in parent
} else if (pid < 0) {
std::cerr << "Error creating process" << std::endl;
}
}
// Wait for all child processes to finish
for (int i = 0; i < WORLD_SIZE; ++i) {
wait(NULL);
}
}
} // namespace infini
#endif

View File

@ -0,0 +1,124 @@
#ifdef INFINI_USE_CNCL
#include "bang/bang_runtime.h"
#include "bang/cncl_communicator.h"
#include "core/graph.h"
#include "core/runtime.h"
#include "operators/all_reduce.h"
#include "test.h"
#include <cncl.h>
#include <future>
#include <thread>
static int WORLD_SIZE = 2;
namespace infini {
template <typename OperatorObj>
void allReduce(const string taskName, int deviceID, vector<float> data,
vector<float> ans) {
// Create Runtimes and initiate communication
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
Runtime bangRuntime = make_ref<BangRuntimeObj>(deviceID);
bangRuntime->initComm(taskName, WORLD_SIZE, deviceID);
// Create Graph and insert allReduce operation
Graph g = make_ref<GraphObj>(bangRuntime);
auto input =
g->addTensor(Shape{static_cast<int>(data.size())}, DataType::Float32);
auto op = g->addOp<OperatorObj>(input, nullptr);
// Copy data from CPU to MLU
g->dataMalloc();
input->copyin(data);
// Run operation
bangRuntime->run(g);
// Copy output from MLU to CPU
auto result = op->getOutput()->clone(cpuRuntime);
EXPECT_TRUE(result->equalData(ans));
}
TEST(BANG_AllReduce, sum) {
vector<float> data[2] = {{2., 3.}, {5., 6.}};
vector<float> ans = {7., 9.};
for (int i = 0; i < WORLD_SIZE; ++i) {
pid_t pid = fork();
if (pid == 0) {
// Child process
allReduce<AllReduceSumObj>("test_allreduce_sum", i, data[i], ans);
exit(0); // Ensure child process exits to avoid unnecessary
// repetition in parent
} else if (pid < 0) {
std::cerr << "Error creating process" << std::endl;
}
}
// Wait for all child processes to finish
for (int i = 0; i < WORLD_SIZE; ++i) {
wait(NULL);
}
}
TEST(BANG_AllReduce, prod) {
vector<float> data[2] = {{2., 3.}, {5., 6.}};
vector<float> ans = {10., 18.};
for (int i = 0; i < WORLD_SIZE; ++i) {
pid_t pid = fork();
if (pid == 0) {
// Child process
allReduce<AllReduceProdObj>("test_allreduce_prod", i, data[i], ans);
exit(0); // Ensure child process exits to avoid unnecessary
// repetition in parent
} else if (pid < 0) {
std::cerr << "Error creating process" << std::endl;
}
}
// Wait for all child processes to finish
for (int i = 0; i < WORLD_SIZE; ++i) {
wait(NULL);
}
}
TEST(BANG_AllReduce, min) {
vector<float> data[2] = {{2., 3.}, {5., 6.}};
vector<float> ans = {2., 3.};
for (int i = 0; i < WORLD_SIZE; ++i) {
pid_t pid = fork();
if (pid == 0) {
// Child process
allReduce<AllReduceMinObj>("test_allreduce_min", i, data[i], ans);
exit(0); // Ensure child process exits to avoid unnecessary
// repetition in parent
} else if (pid < 0) {
std::cerr << "Error creating process" << std::endl;
}
}
// Wait for all child processes to finish
for (int i = 0; i < WORLD_SIZE; ++i) {
wait(NULL);
}
}
TEST(BANG_AllReduce, max) {
vector<float> data[2] = {{2., 3.}, {5., 6.}};
vector<float> ans = {5., 6.};
for (int i = 0; i < WORLD_SIZE; ++i) {
pid_t pid = fork();
if (pid == 0) {
// Child process
allReduce<AllReduceMaxObj>("test_allreduce_max", i, data[i], ans);
exit(0); // Ensure child process exits to avoid unnecessary
// repetition in parent
} else if (pid < 0) {
std::cerr << "Error creating process" << std::endl;
}
}
// Wait for all child processes to finish
for (int i = 0; i < WORLD_SIZE; ++i) {
wait(NULL);
}
}
} // namespace infini
#endif

View File

@ -0,0 +1,65 @@
#ifdef INFINI_USE_CNCL
#include "bang/bang_runtime.h"
#include "bang/cncl_communicator.h"
#include "core/graph.h"
#include "core/runtime.h"
#include "operators/broadcast.h"
#include "test.h"
#include <cncl.h>
#include <thread>
static int WORLD_SIZE = 2;
static int root = 0;
namespace infini {
void broadcast(const string taskName, int deviceID, vector<float> data,
vector<float> ans) {
// Create Runtimes and initiate communication
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
Runtime bangRuntime = make_ref<BangRuntimeObj>(deviceID);
bangRuntime->initComm(taskName, WORLD_SIZE, deviceID);
// Create Graph and insert allReduce operation
Graph g = make_ref<GraphObj>(bangRuntime);
auto input =
g->addTensor(Shape{static_cast<int>(data.size())}, DataType::Float32);
auto op = g->addOp<BroadcastObj>(input, nullptr, root);
// Copy data from CPU to GPU
g->dataMalloc();
// Only rank 0 has the data
if (deviceID == root) {
input->copyin(data);
}
// Run broadcast operation
bangRuntime->run(g);
// Copy output from GPU to CPU
auto result = op->getOutput()->clone(cpuRuntime);
EXPECT_TRUE(result->equalData(ans));
}
TEST(BANG_Broadcast, run) {
// Only 1 device gets data. Every rank should have the same data after
// broadcast.
vector<float> data = {2., 3., 5., 6.};
vector<float> ans = {2., 3., 5., 6.};
for (int i = 0; i < WORLD_SIZE; ++i) {
pid_t pid = fork();
if (pid == 0) {
// Child process
broadcast("test_broadcast", i, data, ans);
exit(0); // Ensure child process exits to avoid unnecessary
// repetition in parent
} else if (pid < 0) {
std::cerr << "Error creating process" << std::endl;
}
}
// Wait for all child processes to finish
for (int i = 0; i < WORLD_SIZE; ++i) {
wait(NULL);
}
}
} // namespace infini
#endif

View File

@ -0,0 +1,82 @@
#include "bang/bang_runtime.h"
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/reduce.h"
#include "test.h"
namespace infini {
template <typename ReduceObjT>
void test_reduce(const Shape &shape, const vector<float> &data,
const optional<const vector<int>> &axis, bool keepDims,
const vector<float> &ExpectData) {
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
auto bangRuntime = make_ref<BangRuntimeObj>();
// Build input data on CPU
Tensor icpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
// Build BANG graph
Graph g = make_ref<GraphObj>(bangRuntime);
auto i = g->cloneTensor(icpu);
auto op = g->addOp<ReduceObjT>(i, nullptr, axis, keepDims);
// allocate BANG memory
g->dataMalloc();
i->copyin(data);
// Execute on BANG
bangRuntime->run(g);
// clone BANG output to CPU
auto o = op->getOutput();
auto ocpu = o->clone(cpuRuntime);
// check results on CPU
EXPECT_TRUE(ocpu->equalData(ExpectData));
}
TEST(BANG_ReduceMean, run) {
test_reduce<ReduceMeanObj>(
Shape{3, 2, 2}, vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2},
std::nullopt, true, vector<float>{18.25});
test_reduce<ReduceMeanObj>(
Shape{1, 3, 2, 2, 1},
vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2}, std::nullopt,
false, vector<float>{18.25});
test_reduce<ReduceMeanObj>(
Shape{2, 3, 2, 2},
vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
vector<int>{1, 2}, false, vector<float>{5, 6, 17, 18});
test_reduce<ReduceMeanObj>(
Shape{2, 3, 2, 2, 1},
vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
vector<int>{1, 2}, true, vector<float>{5, 6, 17, 18});
}
TEST(BANG_ReduceSum, run) {
test_reduce<ReduceSumObj>(Shape{3, 2, 2},
vector<float>{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
std::nullopt, true, vector<float>{12});
test_reduce<ReduceSumObj>(Shape{1, 3, 2, 2, 1},
vector<float>{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
std::nullopt, false, vector<float>{12});
test_reduce<ReduceSumObj>(
Shape{2, 3, 2, 2},
vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
vector<int>{1, 2}, false, vector<float>{30, 36, 102, 108});
test_reduce<ReduceSumObj>(
Shape{2, 3, 2, 2, 1},
vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
vector<int>{1, 2}, true, vector<float>{30, 36, 102, 108});
}
} // namespace infini

View File

@ -0,0 +1,39 @@
#include "bang/bang_runtime.h"
#include "core/graph.h"
#include "core/runtime.h"
#include "operators/slice.h"
#include "test.h"
namespace infini {
TEST(BANG_Slice, run) {
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
auto bangRuntime = make_ref<BangRuntimeObj>();
// Build input data on CPU
Tensor icpu =
make_ref<TensorObj>(Shape{3, 2, 1, 5}, DataType::Float32, cpuRuntime);
icpu->dataMalloc();
icpu->setData(IncrementalGenerator());
// Build CUDA graph;
Graph g = make_ref<GraphObj>(bangRuntime);
auto i = g->cloneTensor(icpu);
auto op =
g->addOp<SliceObj>(i, nullptr, vector<int>{1, 1}, vector<int>{2, 5},
vector<int>{0, 3}, std::nullopt);
// allocate CUDA memory
g->dataMalloc();
i->setData(IncrementalGenerator());
// Execute on CUDA
bangRuntime->run(g);
// clone CUDA output to CPU
auto o = op->getOutput();
auto cpuo = o->clone(cpuRuntime);
// bangPrintTensor(o);
// check results on CPU
EXPECT_TRUE(cpuo->equalData(vector<float>{11, 12, 13, 14, 16, 17, 18, 19}));
}
} // namespace infini