diff --git a/CMakeLists.txt b/CMakeLists.txt index 1101a8c2..70508c79 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,7 +13,7 @@ if(USE_CUDA) message("CMake 3.18 or higher is required for setting CUDAToolkit") cmake_minimum_required(VERSION 3.18) # FindCUDAToolkit else() - cmake_minimum_required(VERSION 3.12) + cmake_minimum_required(VERSION 3.17) endif() include(CMakeDependentOption) @@ -245,6 +245,7 @@ if(USE_BANG) find_library(CAMBRICON_CNNL libcnnl.so "${NEUWARE_HOME}/lib64") find_library(CAMBRICON_CNRT libcnrt.so "${NEUWARE_HOME}/lib64") find_library(CAMBRICON_CNDRV libcndrv.so "${NEUWARE_HOME}/lib64") + find_library(CAMBRICON_CNCL libcncl.so "${NEUWARE_HOME}/lib64") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror") if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH})) @@ -261,7 +262,13 @@ if(USE_BANG) # BangC Kernels ################################################################################ - target_link_libraries(InfiniTensor ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++) + target_link_libraries(InfiniTensor ${CAMBRICON_CNCL} ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++) + if (BUILD_DIST) + message(STATUS "Add BUILD_DIST, use CNCL with BANG") + + add_compile_definitions(INFINI_USE_CNCL=1) + + endif() endif() if(USE_KUNLUN) @@ -324,6 +331,7 @@ if(BUILD_TEST) endif() if (USE_BANG) build_test(test/kernels/bang/*.cc) + build_test(test/bang/*.cc) endif() if (USE_KUNLUN) build_test(test/kernels/kunlun/*.cc) diff --git a/Makefile b/Makefile index 302f47b8..d21a406b 100644 --- a/Makefile +++ b/Makefile @@ -29,6 +29,7 @@ CMAKE_OPT += -DUSE_BANG=$(BANG) CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN) CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE) CMAKE_OPT += -DBUILD_TEST=$(TEST) +CMAKE_OPT += -DBUILD_DIST=ON CMAKE_OPT += -DBUILD_NNET=$(NNET) ifeq ($(INTELCPU), ON) diff --git a/cmake/FindCNCL.cmake b/cmake/FindCNCL.cmake new file mode 100644 index 00000000..31351dda --- /dev/null +++ b/cmake/FindCNCL.cmake @@ -0,0 +1,76 @@ +SET(CNCL_LIB_SEARCH_PATHS $ENV{NEUWARE_HOME}/lib64) +SET(CNCL_INCLUDE_SEARCH_PATHS $ENV{NEUWARE_HOME}/include) + +set(CNCL_INCLUDE_DIR $ENV{NEUWARE_HOME}/include) +set(CNCL_LIB_DIR $ENV{NEUWARE_HOME}/lib64) +set(CNCL_VERSION $ENV{CNCL_VERSION} CACHE STRING "Version of CNCL to build with") + +if ($ENV{CNCL_ROOT_DIR}) + message(WARNING "CNCL_ROOT_DIR is deprecated. Please set CNCL_ROOT instead.") +endif() +list(APPEND CNCL_ROOT $ENV{CNCL_ROOT_DIR} ${MLU_TOOLKIT_ROOT_DIR}) +# Compatible layer for CMake <3.12. CNCL_ROOT will be accounted in for searching paths and libraries for CMake >=3.12. +list(APPEND CMAKE_PREFIX_PATH ${CNCL_ROOT}) + +find_path(CNCL_INCLUDE_DIRS + NAMES cncl.h + HINTS ${CNCL_INCLUDE_DIR}) + +if (USE_STATIC_CNCL) + MESSAGE(STATUS "USE_STATIC_CNCL is set. 
Linking with static CNCL library.") + SET(CNCL_LIBNAME "CNCL_static") + if (CNCL_VERSION) # Prefer the versioned library if a specific CNCL version is specified + set(CMAKE_FIND_LIBRARY_SUFFIXES ".a.${CNCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES}) + endif() +else() + SET(CNCL_LIBNAME "cncl") + if (CNCL_VERSION) # Prefer the versioned library if a specific CNCL version is specified + set(CMAKE_FIND_LIBRARY_SUFFIXES ".so.${CNCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES}) + endif() +endif() + +find_library(CNCL_LIBRARIES + NAMES ${CNCL_LIBNAME} + HINTS ${CNCL_LIB_DIR}) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(CNCL DEFAULT_MSG CNCL_INCLUDE_DIRS CNCL_LIBRARIES) + +if(CNCL_FOUND) # obtaining CNCL version and some sanity checks + set (CNCL_HEADER_FILE "${CNCL_INCLUDE_DIRS}/cncl.h") + message (STATUS "Determining CNCL version from ${CNCL_HEADER_FILE}...") + set (OLD_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES}) + list (APPEND CMAKE_REQUIRED_INCLUDES ${CNCL_INCLUDE_DIRS}) + include(CheckCXXSymbolExists) + check_cxx_symbol_exists(CNCL_VERSION_CODE CNCL.h CNCL_VERSION_DEFINED) + + if (CNCL_VERSION_DEFINED) + set(file "${PROJECT_BINARY_DIR}/detect_cncl_version.cc") + file(WRITE ${file} " + #include + #include + int main() + { + std::cout << CNCL_MAJOR << '.' << CNCL_MINOR << '.' << CNCL_PATCH << std::endl; + int x; + CNCLGetVersion(&x); + return x == CNCL_VERSION_CODE; + } +") + try_run(CNCL_VERSION_MATCHED compile_result ${PROJECT_BINARY_DIR} ${file} + RUN_OUTPUT_VARIABLE CNCL_VERSION_FROM_HEADER + CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CNCL_INCLUDE_DIRS}" + LINK_LIBRARIES ${CNCL_LIBRARIES}) + if (NOT CNCL_VERSION_MATCHED) + message(FATAL_ERROR "Found CNCL header version and library version do not match! \ +(include: ${CNCL_INCLUDE_DIRS}, library: ${CNCL_LIBRARIES}) Please set CNCL_INCLUDE_DIR and CNCL_LIB_DIR manually.") + endif() + message(STATUS "CNCL version: ${CNCL_VERSION_FROM_HEADER}") + else() + # message(STATUS "CNCL version < 2.3.5-5") + endif () + set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES}) + + message(STATUS "Found CNCL (include: ${CNCL_INCLUDE_DIRS}, library: ${CNCL_LIBRARIES})") + mark_as_advanced(CNCL_ROOT_DIR CNCL_INCLUDE_DIRS CNCL_LIBRARIES) +endif() diff --git a/examples/NNmodel b/examples/NNmodel index 51d31052..b896cec2 160000 --- a/examples/NNmodel +++ b/examples/NNmodel @@ -1 +1 @@ -Subproject commit 51d3105277f3774ed31c02ed4cd11fa92925af77 +Subproject commit b896cec2dba5b8522b141ac4f89eb43074ee1b98 diff --git a/examples/distributed/bang_launch.py b/examples/distributed/bang_launch.py new file mode 100644 index 00000000..518935b5 --- /dev/null +++ b/examples/distributed/bang_launch.py @@ -0,0 +1,196 @@ +import argparse +import os +import time +import multiprocessing as mp +from pyinfinitensor.onnx import OnnxStub, backend +import onnx +from onnx.shape_inference import infer_shapes_path +import numpy as np +from parallel_opt import parallel_model + + +def parse_args(): + parser = argparse.ArgumentParser(description="launch distributed infinitensor") + parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes") + parser.add_argument( + "--nproc_per_node", type=int, default=2, help="number of processes per node" + ) + parser.add_argument( + "--name", type=str, default="test", help="name of this instance." + ) + parser.add_argument( + "--model", type=str, default="/data/onnx_models/llama2/llama_bs1_seq1024.onnx", + help="path to the ONNX model file." 
+ ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size.") + parser.add_argument("--length", type=int, default=1, help="sequence length.") + parser.add_argument( + "--gen_std", + default=False, + action="store_true", + help="whether to generate the standard results.", + ) + args = parser.parse_args() + print("arg setting: ", args) + return ( + args.num_nodes, + args.nproc_per_node, + args.name, + args.model, + args.batch_size, + args.length, + args.gen_std, + ) + + +def run_model(model, runtime, world_size=1, rank=0, n=10): + stub = OnnxStub(model, runtime) + load_inputs(stub, world_size, rank) + # stub.tune() + stub.run() + # get outputs + time.sleep(0.01) + outputs = next(stub.outputs.values().__iter__()).copyout_numpy() + + # bench + begin = time.time() + for _ in range(n): + stub.run() + end = time.time() + avg_time = (end - begin) / n + print(f"average time: {avg_time}") + return outputs + + +def run_and_compare(name, model, runtime, world_size=1, rank = 0): + results = np.load(f"./data/output.npy") + outputs = run_model(model, runtime, world_size, rank) + print("answer argmax:", np.argmax(results)) + print("output argmax:", np.argmax(outputs)) + #np.testing.assert_allclose(outputs, results, rtol=1e-3, atol=1e-3) + getDiff(results, outputs) + + +def start_worker( + name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto +): + dist_name = name + "_dist" + model = parallel_model(model, world_size, rank) + extern_path = f"./{dist_name}_rank{rank}.pb" + if os.path.exists(extern_path): + os.remove(extern_path) + onnx.save_model( + model, + f"./{dist_name}_rank{rank}.onnx", + save_as_external_data=True, + location=extern_path, + ) + infer_shapes_path(f"./{dist_name}_rank{rank}.onnx") + runtime = backend.BangRuntime(local_rank) + # print("init comm") + runtime.init_comm( + dist_name, + world_size, + rank, + ) + run_and_compare(name, model, runtime, world_size, rank) + + +def start_single(name, model): + runtime = backend.BangRuntime(0) + run_and_compare(name, model, runtime) + + +def generate_input_output(model): + os.makedirs(os.path.dirname("./data/"), exist_ok=True) + runtime = backend.BangRuntime(0) + stub = OnnxStub(model, runtime) + position_id = 0 + for i, (name, tensor) in enumerate(stub.inputs.items()): + input = tensor.copyout_numpy() + if np.issubdtype(input.dtype, np.integer): + if input.size == 1: + # input = np.array([position_id]) + input = np.random.randint(0,2,size=input.shape, dtype=input.dtype) + else: + input = np.random.randint(0,2,size=input.shape, dtype=input.dtype) + elif input.dtype == np.bool_: + input = np.random.randint(0,2,size=input.shape) > 0 + else: + if i == 0: + input = np.ones(input.shape).astype(input.dtype) + position_id = input.shape[-1] - 1 + else: + input = np.random.rand(*input.shape).astype(input.dtype) + tensor.copyin_numpy(input) + np.save(f"./data/input_{i}", input) + stub.run() + time.sleep(0.01) + output = next(stub.outputs.values().__iter__()).copyout_numpy() + if np.isnan(output).any(): + print("Nan in output") + np.save(f"./data/output", output) + + +def load_inputs(stub, world_size=1, rank=0): + for i, (name, tensor) in enumerate(stub.inputs.items()): + input = np.load(f"./data/input_{i}.npy") + if all(x == y for x,y in zip(input.shape,tensor.shape())): + tensor.copyin_numpy(input) + else: + tensor.copyin_numpy(np.hsplit(input, world_size)[rank]) + +def getDiff(base, test): + absolute_diff = np.abs(np.subtract(base, test)) + max_absolute_diff = np.max(absolute_diff) + + baseCopy = 
base.astype(np.float64).ravel() + testCopy = test.astype(np.float64).ravel() + upValue = np.sum(np.abs(baseCopy - testCopy)) + downValue = np.sum(np.abs(baseCopy)) + np.float64(1e-9) + max_relative_diff = upValue / downValue + print(f"Max absolute difference: {max_absolute_diff}\n" + f"Max relative difference: {max_relative_diff}") + return max_absolute_diff, max_relative_diff + + +def main(): + nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args() + + model = onnx.load(model_path) + + # generate standart output + if gen_std: + print("Generate inputs and outputs.") + p = mp.Process(target=generate_input_output, args=[model]) + p.start() + p.join() + return + + # run single process. + # use standalone process to isolate cuda. + print("run model by single MLU.") + p = mp.Process(target=start_single, args=(name, model)) + p.start() + p.join() + + # run distributed parallel. + world_size = nnodes * nproc_per_node + print(f"run model by {world_size} MLUs in parallel.") + workers = [ + mp.Process( + target=start_worker, + args=(name, world_size, rank, rank % nproc_per_node, model), + ) + for rank in range(world_size) + ] + + for w in workers: + w.start() + + for w in workers: + w.join() + + +if __name__ == "__main__": + main() diff --git a/examples/distributed/parallel_opt.py b/examples/distributed/parallel_opt.py index 3ddf2ead..1214b6b3 100644 --- a/examples/distributed/parallel_opt.py +++ b/examples/distributed/parallel_opt.py @@ -115,7 +115,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0): assert out_dims[s_dim] % tp_world_size == 0, out_dims out_dims[s_dim] //= tp_world_size # if ONNX uses the same tensor for multiple Reshape Nodes, then rename it to distingush from others. - # node.input[1] = node.output[0] + "_shape" + node.input[1] = node.output[0] + "_shape" data[node.input[1]] = numpy_helper.from_array(out_dims, name=node.input[1]) place[node.output[0]] = Shard(s_dim) diff --git a/include/bang/bang_runtime.h b/include/bang/bang_runtime.h index 2dde7756..e1ca6b38 100644 --- a/include/bang/bang_runtime.h +++ b/include/bang/bang_runtime.h @@ -7,17 +7,19 @@ namespace infini { class BangRuntimeObj : public RuntimeObj { private: cnnlHandle_t cnnl; + cnrtQueue_t queue; + std::unique_ptr comm; BangPtr workspace; size_t workspaceSize; mutable size_t cursor; public: - BangRuntimeObj() : RuntimeObj(Device::BANG) { + explicit BangRuntimeObj(int deviceId = 0) + : RuntimeObj(Device::BANG, deviceId) { cnInit(0); CNdev dev; - cnDeviceGet(&dev, 0); + cnDeviceGet(&dev, deviceId); checkBangError(cnrtSetDevice(dev)); - cnrtQueue_t queue; checkBangError(cnrtQueueCreate(&queue)); checkCnnlError(cnnlCreate(&cnnl)); @@ -30,6 +32,7 @@ class BangRuntimeObj : public RuntimeObj { } virtual ~BangRuntimeObj() { dealloc(workspace); + checkBangError(cnrtQueueDestroy(queue)); checkCnnlError(cnnlDestroy(cnnl)); } string toString() const override; @@ -73,10 +76,9 @@ class BangRuntimeObj : public RuntimeObj { checkBangError(cnrtMemcpy(dst, const_cast(src), bytes, CNRT_MEM_TRANS_DIR_PEER2PEER)); } - - void initComm(const string &, int, int) override { IT_TODO_HALT(); } - - CommunicatorObj &getCommunicator() const override { IT_TODO_HALT(); } + void initComm(const string &name, int worldSize, int rank) final; + CommunicatorObj &getCommunicator() const override { return *comm; } + cnrtQueue_t getBangQueue() const { return queue; } private: void runWithoutSync(const Graph &graph, bool tune, bool profiling) const; diff --git a/include/bang/cncl_communicator.h 
b/include/bang/cncl_communicator.h new file mode 100644 index 00000000..0999686c --- /dev/null +++ b/include/bang/cncl_communicator.h @@ -0,0 +1,79 @@ +#pragma once +#include "bang_common.h" +#include "core/communicator.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace infini { + +class CnclCommunicatorObj final : public CommunicatorObj { + private: + cnclComm_t *comms; + + public: + CnclCommunicatorObj(const string &name, int worldSize, int rank) + : CommunicatorObj(worldSize, rank) { + const std::string filePath("./" + name + "_cncl_id.bin"); + cnclCliqueId clique_id; + if (rank == 0) { + CNCL_CHECK(cnclGetCliqueId(&clique_id)); + std::ofstream ofs(filePath, std::ios::binary); + ofs.write((char *)&clique_id, sizeof(cnclCliqueId)); + + } else { + auto begin = std::chrono::steady_clock::now(); + while (!std::filesystem::exists(filePath)) { + auto now = std::chrono::steady_clock::now(); + _IT_ASSERT_2(now < begin + std::chrono::seconds(10), + "time limit (10s) exceeded."); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + std::ifstream ifs(filePath, std::ios::binary); + ifs.read((char *)&clique_id, sizeof(cnclCliqueId)); + } + + int num_comms = 1; + int *dev_list = new int[num_comms]; + int *rank_list = new int[num_comms]; + comms = new cnclComm_t[num_comms]; + uint32_t num_dev = 0; + checkBangError(cnrtGetDeviceCount(&num_dev)); + + for (int i = 0; i < num_comms; i++) { + rank_list[i] = rank; + dev_list[i] = rank_list[i] % num_dev; + } + + CNCL_CHECK(cnclInitComms(comms, num_comms, dev_list, rank_list, + worldSize, &clique_id)); + + if (rank == 0) { + std::filesystem::remove(filePath); + } + + delete[] dev_list; + delete[] rank_list; + } + + ~CnclCommunicatorObj() { + CNCL_CHECK(cnclDestroyComms(comms, 1)); + delete[] comms; + } + + // Get the actual cnclComm_t + cnclComm_t getCnclComm() { return comms[0]; } + + virtual string toString() const final { + std::ostringstream oss; + oss << "CNCL communicator"; + return oss.str(); + } +}; + +} // namespace infini diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index f095db81..313a1f79 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -65,12 +65,18 @@ class GraphHandlerObj { std::optional max); Tensor transpose(Tensor data, Tensor transposed, Shape perm); Tensor reshape(Tensor data, Tensor reshaped, Shape shape); + Tensor resize(Tensor input, Tensor output, + const std::optional> &axes, Tensor sizes, + Tensor scales, Tensor roi, vector sizes_, + vector scales_, vector roi_, string mode, + string ratioPolicy, string nearestMode, + string coordTransMode); Tensor concat(TensorVec inputs, Tensor output, int dim); Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q, Tensor input_k, Tensor input_v, Tensor position_id, Tensor output_matmul); TensorVec split(Tensor input, std::optional outputs, int axis, - int num_outputs); + std::variant> numOrRatio); Tensor gather(Tensor data, Tensor indices, Tensor output, int axis); Tensor gatherElements(Tensor data, Tensor indices, Tensor output, int axis); Tensor reduceMean(Tensor data, Tensor reduced, @@ -99,6 +105,8 @@ class GraphHandlerObj { int outputType, Tensor input); Tensor depthToSpace(Tensor input, Tensor output, int blocksize, std::string mode); + Tensor lrn(Tensor input, Tensor output, float alpha, float beta, float bias, + int size); //------ modifiers diff --git a/include/core/kernel.h b/include/core/kernel.h index 3ef0d1b9..a19f3f1a 100644 --- 
a/include/core/kernel.h +++ b/include/core/kernel.h @@ -2,6 +2,7 @@ #include "core/common.h" #include "core/operator.h" #include "core/tensor.h" +#include "utils/operator_utils.h" #include #include using json = nlohmann::json; @@ -102,11 +103,9 @@ class KernelRegistry { } Kernel *getKernel(const KernelAttrs &kernelAttrs) const { auto it = kernels.find(kernelAttrs); - IT_ASSERT(it != kernels.end(), - "Kernel not found for key {" + - to_string(enum_to_underlying(std::get<0>(kernelAttrs))) + - ", " + std::to_string(std::get<1>(kernelAttrs)) + ", " + - std::get<2>(kernelAttrs).toString() + "}"); + IT_ASSERT(it != kernels.end(), "Kernel not found for key {" + + get_kernel_attrs_str(kernelAttrs) + + "}"); return std::get<0>(it->second); } const KernelRecord &getKernelItem(const KernelAttrs &kernelAttrs) const { diff --git a/include/core/tensor.h b/include/core/tensor.h index cb09261a..95229c14 100644 --- a/include/core/tensor.h +++ b/include/core/tensor.h @@ -8,7 +8,9 @@ #if USE_CUDA #include "cuda/cuda_runtime.h" #endif - +#if USE_BANG +#include "bang/bang_runtime.h" +#endif namespace infini { // TODO: how to deal with this diff --git a/include/operators/lrn.h b/include/operators/lrn.h new file mode 100644 index 00000000..e86dbdc4 --- /dev/null +++ b/include/operators/lrn.h @@ -0,0 +1,29 @@ +#pragma once +#include "core/operator.h" + +namespace infini { +class LRNObj : public OperatorObj { + + public: + LRNObj(GraphObj *graph, Tensor inputX, Tensor inputY, float alpha, + float beta, float bias, int size); + OP_CLONE(LRNObj); + + optional> inferShape(const TensorVec &inputs) override; + + std::string toString() const override; + int numInputs() const override { return inputs.size(); } + int numOutputs() const override { return 1; } + auto getAlphaBetaBias() const { + return tuple(alpha_value, beta_value, bias_value); + } + auto getSize() const { return size_value; } + + private: + float alpha_value, beta_value, bias_value; + int size_value; + vector getWorkloadVector() const override; + vector getOpAttrVector() const override; +}; + +} // namespace infini diff --git a/include/operators/resize.h b/include/operators/resize.h index 96283c12..220ef719 100644 --- a/include/operators/resize.h +++ b/include/operators/resize.h @@ -27,6 +27,60 @@ class ResizeObj : public OperatorObj { enum class EKeepAspectRatioPolicy { stretch, notLarger, notSmaller, none }; enum class ECoeffMode { nearest, linear, cubic }; + static ECoordinateTransMode fromECoordinateTransModeStr(string mode) { + if (mode == "half_pixel") { + return ECoordinateTransMode::halfPixel; + } else if (mode == "asymmetric") { + return ECoordinateTransMode::asymmetric; + } else if (mode == "align_corners") { + return ECoordinateTransMode::alignCorners; + } else if (mode == "pytorch_half_pixel") { + return ECoordinateTransMode::pytorchHalfPixel; + } else if (mode == "tf_crop_and_resize") { + return ECoordinateTransMode::tfCropAndResize; + } else { + IT_TODO_HALT(); + } + } + + static ENearestMode fromENearestModeStr(string mode) { + if (mode == "round_prefer_floor") { + return ENearestMode::roundPreferFloor; + } else if (mode == "round_prefer_ceil") { + return ENearestMode::roundPreferCeil; + } else if (mode == "floor") { + return ENearestMode::floor; + } else if (mode == "ceil") { + return ENearestMode::ceil; + } else { + return ENearestMode::none; + } + } + + static EKeepAspectRatioPolicy fromRatioPolicyStr(string ratioPolicyStr) { + if (ratioPolicyStr == "stretch") { + return EKeepAspectRatioPolicy::stretch; + } else if (ratioPolicyStr == 
"not_larger") { + return EKeepAspectRatioPolicy::notLarger; + } else if (ratioPolicyStr == "not_smaller") { + return EKeepAspectRatioPolicy::notSmaller; + } else { + return EKeepAspectRatioPolicy::none; + } + } + + static ECoeffMode fromECoeffModeStr(string mode) { + if (mode == "nearest") { + return ECoeffMode::nearest; + } else if (mode == "linear") { + return ECoeffMode::linear; + } else if (mode == "cubic") { + return ECoeffMode::cubic; + } else { + IT_TODO_HALT(); + } + } + private: vector axes; vector scales; diff --git a/include/utils/operator_utils.h b/include/utils/operator_utils.h index 4f6a6985..b0871c0b 100644 --- a/include/utils/operator_utils.h +++ b/include/utils/operator_utils.h @@ -2,6 +2,7 @@ #ifndef OPERATOR_UTIL_H #define OPERATOR_UTIL_H +#include "core/operator.h" #include "core/tensor.h" namespace infini { @@ -10,8 +11,15 @@ namespace infini { Shape infer_broadcast(const Shape &A, const Shape &B); // Launch the real axis based on rank and current axis int get_real_axis(const int &axis, const int &rank); -// check if tensor B is unidirectional broadcastable to tensor A +// Check if tensor B is unidirectional broadcastable to tensor A bool is_unidirectional_broadcasting(const Shape &A, const Shape &B); +// Locate the index with size from Shape +Shape locate_index(size_t inputN, const Shape &shape); +// Delocate the ShapeIndex from Shape with broadcast +size_t delocate_index(const Shape &shapeIndex, const Shape &shape, + const Shape &stride); +// Convert KernelAttrs to a string representation +std::string get_kernel_attrs_str(const KernelAttrs &kernelAttrs); } // namespace infini #endif diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 90a3d3ab..c63746af 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -535,6 +535,65 @@ class OnnxStub: tensors.get(node.output[0]), shape, ) + elif node.op_type == "Resize": + output = tensors.get(node.output[0]) + attributes = _parse_attribute( + node, + { + "antialias": 0, + "axes": None, + "coordinate_transformation_mode": "half_pixel", + "cubic_coeff_a": -0.75, + "exclude_outside": 0, + "extrapolation_value": 0.0, + "keep_aspect_ratio_policy": "none", + "mode": "nearest", + "nearest_mode": "none", + }, + ) + ( + axes, + keep_aspect_ratio_policy, + coordinate_transformation_mode, + mode, + nearest_mode, + ) = ( + attributes[name] + for name in [ + "axes", + "keep_aspect_ratio_policy", + "coordinate_transformation_mode", + "mode", + "nearest_mode", + ] + ) + if len(node.input) > 1: + roiVal = _parse_data(data[node.input[1]]) + else: + roiVal = [] + if len(node.input) > 2: + scalesVal = _parse_data(data[node.input[2]]) + else: + scalesVal = [] + if len(node.input) > 3: + sizesVal = _parse_data(data[node.input[3]]) + else: + sizesVal = [] + tensors[node.output[0]] = self.handler.resize( + tensors[node.input[0]], + output, + axes, + tensors[node.input[3]] if len(node.input) > 3 else None, + tensors[node.input[2]] if len(node.input) > 2 else None, + tensors[node.input[1]] if len(node.input) > 1 else None, + sizesVal, + scalesVal, + roiVal, + mode, + keep_aspect_ratio_policy, + nearest_mode, + coordinate_transformation_mode, + ) elif node.op_type == "Squeeze": input_shape = _search_shape(model, node.input[0]) axes = set( @@ -585,6 +644,20 @@ class OnnxStub: tensors.get(node.output[0]), ) elif node.op_type == "Split": + split = ( + _parse_data(data[node.input[1]]) + if (len(node.input) > 1) + else None + ) + if split is None: + split 
= next( + ( + attr.ints + for attr in node.attribute + if attr.name == "split" + ), + None, + ) for name, tensor in zip( node.output, self.handler.split( @@ -598,7 +671,7 @@ class OnnxStub: ), 0, ), - len(node.output), + split if split is not None else len(node.output), ), ): tensors[name] = tensor @@ -857,6 +930,22 @@ class OnnxStub: tensors[output_name] = self.handler.tensor(dims, tensor.data_type) data[output_name] = tensor tensors[output_name].set_weight() + elif node.op_type == "LRN": + attributes = _parse_attribute( + node, {"alpha": 0.0001, "beta": 0.75, "bias": 1.0, "size": 1} + ) + (alpha, beta, bias, size) = ( + attributes[name] + for name in ["alpha", "beta", "bias", "size"] + ) + tensors[node.output[0]] = self.handler.lrn( + tensors[node.input[0]], + tensors.get(node.output[0]), + alpha, + beta, + bias, + size, + ) else: raise Exception('Unsupported operator "{}"'.format(node.op_type)) new_node_name.append(node.name) @@ -1195,6 +1284,20 @@ class OnnxStub: elif ty == backend.OpTypeId.Expand: shape = backend.expand_shape_of(op) ctx.push_node(make_node(ty.name, inputs, outputs, name, shape=shape)) + elif ty == backend.OpTypeId.LRN: + alpha, beta, bias, size = backend.lrn_attrs_of(op) + ctx.push_node( + make_node( + ty.name, + inputs, + outputs, + name, + alpha, + beta, + bias, + size, + ) + ) else: raise Exception("Unsupported OpType", ty) diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index ca290d76..f5d5a426 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -295,6 +295,14 @@ class TestStringMethods(unittest.TestCase): make_graph([reshape], "reshape", [data, shape], [reshaped], [shape_data]) ) + def test_resize(self): + x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 128, 40, 40]) + roi = make_tensor("roi", TensorProto.FLOAT, [0], []) + scales = make_tensor("scales", TensorProto.FLOAT, [4], [1, 1, 2, 2]) + y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 128, 80, 80]) + reshape = make_node("Resize", ["x", "roi", "scales"], ["y"], name="resize") + make_and_import_model(make_graph([reshape], "resize", [x], [y], [roi, scales])) + def test_concat(self): input1 = make_tensor_value_info("input1", TensorProto.FLOAT, [1, 3, 2, 4]) input2 = make_tensor_value_info("input2", TensorProto.FLOAT, [1, 3, 2, 5]) @@ -435,6 +443,12 @@ class TestStringMethods(unittest.TestCase): split = make_node("Split", ["input"], ["output"], name="split", axis=0) make_and_import_model(make_graph([split], "split", [input], [])) + def test_split1(self): + input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4]) + splitAttr = make_tensor_value_info("split", TensorProto.INT64, [2, 1]) + split = make_node("Split", ["input", "split"], ["output"], name="split", axis=1) + make_and_import_model(make_graph([split], "split", [input, splitAttr], [])) + def test_allBroadcast(self): input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4]) output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4]) diff --git a/src/bang/bang_runtime.cc b/src/bang/bang_runtime.cc index c9f9a933..2f16b500 100644 --- a/src/bang/bang_runtime.cc +++ b/src/bang/bang_runtime.cc @@ -1,6 +1,9 @@ #include "bang/bang_runtime.h" #include "core/kernel.h" #include "core/perf_engine.h" +#ifdef INFINI_USE_CNCL +#include "bang/cncl_communicator.h" +#endif namespace infini { @@ -59,4 +62,15 @@ void BangRuntimeObj::sync() const { cnrtSyncDevice(); } string BangRuntimeObj::toString() const { return "BANG Runtime"; } +void 
BangRuntimeObj::initComm(const string &name, int worldSize, int rank) { + IT_ASSERT(worldSize > 0); + IT_ASSERT(rank >= 0); + IT_ASSERT(rank < worldSize); + IT_ASSERT(!comm) << "communicator is already initialized."; +#ifdef INFINI_USE_CNCL + comm = std::make_unique(name, worldSize, rank); +#else + IT_TODO_HALT_MSG("Not compiled with CNCL."); +#endif +} } // namespace infini diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index 1eb73499..7fc6f977 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -10,12 +10,14 @@ #include "operators/expand.h" #include "operators/gather.h" #include "operators/layer_norm.h" +#include "operators/lrn.h" #include "operators/matmul.h" #include "operators/pad.h" #include "operators/pooling.h" #include "operators/recv.h" #include "operators/reduce.h" #include "operators/reshape.h" +#include "operators/resize.h" #include "operators/send.h" #include "operators/slice.h" #include "operators/softmax.h" @@ -24,6 +26,7 @@ #include "operators/unary.h" #include "operators/where.h" #include +#include namespace infini { @@ -252,6 +255,64 @@ Tensor GraphHandlerObj::reshape(Tensor data, Tensor reshaped, Shape shape) { } } +Tensor GraphHandlerObj::resize(Tensor input, Tensor output, + const std::optional> &axes, + Tensor sizes, Tensor scales, Tensor roi, + vector sizes_, vector scales_, + vector roi_, string mode, + string ratioPolicy, string nearestMode, + string coordTransMode) { + if (sizes_.size() > 0) { + sizes->dataMalloc(); + sizes->copyin(sizes_); + } + if (scales_.size() > 0) { + scales->dataMalloc(); + scales->copyin(scales_); + } + if (roi_.size() > 0) { + roi->dataMalloc(); + roi->copyin(roi_); + } + ResizeObj::EKeepAspectRatioPolicy ratioPolicy_ = + ResizeObj::fromRatioPolicyStr(ratioPolicy); + ResizeObj::ENearestMode nearestMode_ = + ResizeObj::fromENearestModeStr(nearestMode); + ResizeObj::ECoordinateTransMode coordTransMode_ = + ResizeObj::fromECoordinateTransModeStr(coordTransMode); + ResizeObj::ECoeffMode mode_ = ResizeObj::fromECoeffModeStr(mode); + if (output) { + if (mode == "nearest") { + g->addOpWithOutputs( + std::move(input), output, std::move(axes), std::move(sizes), + std::move(scales), std::move(roi), ratioPolicy_, nearestMode_, + coordTransMode_); + } else { + g->addOpWithOutputs( + std::move(input), output, std::move(axes), std::move(sizes), + std::move(scales), std::move(roi), mode_, ratioPolicy_, + coordTransMode_); + } + return output; + } else { + if (mode == "nearest") { + return g + ->addOp(std::move(input), output, std::move(axes), + std::move(sizes), std::move(scales), + std::move(roi), ratioPolicy_, nearestMode_, + coordTransMode_) + ->getOutput(); + } else { + return g + ->addOp(std::move(input), output, std::move(axes), + std::move(sizes), std::move(scales), + std::move(roi), mode_, ratioPolicy_, + coordTransMode_) + ->getOutput(); + } + } +} + Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) { if (output) { g->addOpWithOutputs(std::move(inputs), output, dim); @@ -283,14 +344,29 @@ Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache, } TensorVec GraphHandlerObj::split(Tensor input, std::optional outputs, - int axis, int num_outputs) { + int axis, + std::variant> numOrRatio) { if (outputs) { - g->addOpWithOutputs(std::move(input), outputs, axis, - num_outputs); + if (std::holds_alternative(numOrRatio)) { + g->addOpWithOutputs(std::move(input), outputs, axis, + std::get(numOrRatio)); + } else { + g->addOpWithOutputs(std::move(input), outputs, axis, + 
std::get>(numOrRatio)); + } return *outputs; } else { - return g->addOp(std::move(input), outputs, axis, num_outputs) - ->getOutputs(); + if (std::holds_alternative(numOrRatio)) { + return g + ->addOp(std::move(input), outputs, axis, + std::get(numOrRatio)) + ->getOutputs(); + } else { + return g + ->addOp(std::move(input), outputs, axis, + std::get>(numOrRatio)) + ->getOutputs(); + } } } @@ -519,6 +595,19 @@ Tensor GraphHandlerObj::depthToSpace(Tensor input, Tensor output, int blocksize, } } +Tensor GraphHandlerObj::lrn(Tensor input, Tensor output, float alpha, + float beta, float bias, int size) { + if (output) { + g->addOpWithOutputs(std::move(input), output, alpha, beta, bias, + size); + return output; + } else { + return g + ->addOp(std::move(input), output, alpha, beta, bias, size) + ->getOutput(); + } +} + static CastType inferCastType(Tensor input, int to) { auto iType = input->getDType(); auto oType = DataType(to); diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index ca99a4c3..eadd4a4e 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -5,6 +5,7 @@ #include "operators/conv.h" #include "operators/expand.h" #include "operators/gather.h" +#include "operators/lrn.h" #include "operators/matmul.h" #include "operators/pad.h" #include "operators/pooling.h" @@ -113,6 +114,7 @@ void export_values(py::module &m) { .VALUE(OpType, Erf) .VALUE(OpType, Where) .VALUE(OpType, DepthToSpace) + .VALUE(OpType, LRN) .export_values(); #undef VALUE @@ -296,6 +298,14 @@ static std::tuple depth_to_space_attrs_of(Operator op) { depth_to_space->getModeString()); } +static std::tuple lrn_attrs_of(Operator op) { + IT_ASSERT(op->getOpType() == OpType::LRN); + auto lrn = dynamic_cast(op.get()); + auto [alpha, beta, bias] = lrn->getAlphaBetaBias(); + auto size = lrn->getSize(); + return std::make_tuple(alpha, beta, bias, size); +} + void export_functions(py::module &m) { #define FUNCTION(NAME) def(#NAME, &NAME) m.def("cpu_runtime", &NativeCpuRuntimeObj::getInstance) @@ -332,7 +342,8 @@ void export_functions(py::module &m) { .FUNCTION(gather_axis_of) .FUNCTION(flatten_axis_of) .FUNCTION(cast_to_of) - .FUNCTION(depth_to_space_attrs_of); + .FUNCTION(depth_to_space_attrs_of) + .FUNCTION(lrn_attrs_of); #undef FUNCTION } @@ -388,7 +399,9 @@ void init_graph_builder(py::module &m) { #endif #ifdef USE_BANG py::class_, RuntimeObj>( - m, "BangRuntime"); + m, "BangRuntime") + .def(py::init(), py::arg("device") = 0) + .def("init_comm", &BangRuntimeObj::initComm); #endif #ifdef USE_KUNLUN py::class_, RuntimeObj>( @@ -495,6 +508,7 @@ void init_graph_builder(py::module &m) { .def("transpose", &Handler::transpose, policy::move) .def("depthToSpace", &Handler::depthToSpace, policy::move) .def("reshape", &Handler::reshape, policy::move) + .def("resize", &Handler::resize, policy::move) .def("concat", &Handler::concat, policy::move) .def("attentionKVCache", &Handler::attentionKVCache, policy::move) .def("split", &Handler::split, policy::move) @@ -517,6 +531,7 @@ void init_graph_builder(py::module &m) { .def("expand", &Handler::expand, policy::move) .def("erf", &Handler::erf, policy::move) .def("where", &Handler::where, policy::move) + .def("lrn", &Handler::lrn, policy::move) .def("topo_sort", &Handler::topo_sort, policy::automatic) .def("optimize", &Handler::optimize, policy::automatic) .def("operators", &Handler::operators, policy::move) diff --git a/src/kernels/bang/activation.cc b/src/kernels/bang/activation.cc index 87b8396f..1d7b0c20 100644 --- a/src/kernels/bang/activation.cc 
+++ b/src/kernels/bang/activation.cc @@ -30,8 +30,9 @@ class UnaryCnnl : public BangKernelWithoutConfig { cDim.data())); cnnlActivationDescriptor_t opDesc; checkCnnlError(cnnlCreateActivationDescriptor(&opDesc)); - checkCnnlError(cnnlSetActivationDescriptor( - opDesc, getOpType(), CNNL_NOT_PROPAGATE_NAN, getCoef())); + checkCnnlError(cnnlSetActivationDescriptor_v2( + opDesc, getOpType(), CNNL_ACTIVATION_HIGH_PRECISION, + CNNL_NOT_PROPAGATE_NAN, getCoef())); auto [alpha, beta] = getAlphBeta(); cnnlStatus_t stat = @@ -131,31 +132,51 @@ class SoftmaxCnnl : public BangKernelWithoutConfig { std::vector inDim = {1, 1, 1}; std::vector outDim = inDim; - if (axis == 0) { - mode = CNNL_SOFTMAX_MODE_HIGH_DIMENSION; - inDim[0] = aDim[0]; - inDim[1] = aDim[1]; - for (size_t i = 2; i < aDim.size(); ++i) { - inDim[2] *= aDim[i]; + if (aDim.size() >= 3) { + if (axis == 0) { + mode = CNNL_SOFTMAX_MODE_HIGH_DIMENSION; + inDim[0] = aDim[0]; + inDim[1] = aDim[1]; + for (size_t i = 2; i < aDim.size(); ++i) { + inDim[2] *= aDim[i]; + } + outDim = inDim; + } else if (axis == aDim.size() - 1) { + mode = CNNL_SOFTMAX_MODE_LOW_DIMENSION; + inDim[0] = aDim[0]; + for (size_t i = 1; i < axis; ++i) { + inDim[1] *= aDim[i]; + } + inDim[2] = aDim[axis]; + outDim = inDim; + } else { + mode = CNNL_SOFTMAX_MODE_MEDIUM_DIMENSION; + for (size_t i = 0; i < axis; ++i) { + inDim[0] *= aDim[i]; + } + inDim[1] = aDim[axis]; + for (size_t i = axis + 1; i < aDim.size(); ++i) { + inDim[2] *= aDim[i]; + } + outDim = inDim; } - outDim = inDim; - } else if (axis == aDim.size() - 1) { - mode = CNNL_SOFTMAX_MODE_LOW_DIMENSION; - inDim[0] = aDim[0]; - for (size_t i = 1; i < axis; ++i) { - inDim[1] *= aDim[i]; + } else if (aDim.size() == 2) { + if (axis == 0) { + mode = CNNL_SOFTMAX_MODE_HIGH_DIMENSION; + inDim = aDim; + inDim.push_back(1); + outDim = inDim; + } else { + mode = CNNL_SOFTMAX_MODE_LOW_DIMENSION; + inDim = aDim; + inDim.insert(inDim.begin(), 1); + outDim = inDim; } - inDim[2] = aDim[axis]; - outDim = inDim; } else { - mode = CNNL_SOFTMAX_MODE_MEDIUM_DIMENSION; - for (size_t i = 0; i < axis; ++i) { - inDim[0] *= aDim[i]; - } - inDim[1] = aDim[axis]; - for (size_t i = axis + 1; i < aDim.size(); ++i) { - inDim[2] *= aDim[i]; - } + mode = CNNL_SOFTMAX_MODE_HIGH_DIMENSION; + inDim = aDim; + inDim.push_back(1); + inDim.push_back(1); outDim = inDim; } @@ -171,8 +192,8 @@ class SoftmaxCnnl : public BangKernelWithoutConfig { float beta = 0.0; cnnlStatus_t stat = cnnlSoftmaxForward_v2(context->cnnlHandle(), CNNL_SOFTMAX_ACCURATE, - mode, CNNL_COMPUTATION_HIGH_PRECISION, &alpha, - aDesc, aData, &beta, cDesc, cData); + mode, CNNL_COMPUTATION_ULTRAHIGH_PRECISION, + &alpha, aDesc, aData, &beta, cDesc, cData); if (stat != CNNL_STATUS_SUCCESS) return; checkCnnlError(cnnlDestroyTensorDescriptor(aDesc)); diff --git a/src/kernels/bang/all_gather.cc b/src/kernels/bang/all_gather.cc new file mode 100644 index 00000000..d44569fe --- /dev/null +++ b/src/kernels/bang/all_gather.cc @@ -0,0 +1,49 @@ +#ifdef INFINI_USE_CNCL +#include "operators/all_gather.h" +#include "bang/bang_kernel_without_config.h" +#include "bang/bang_runtime.h" +#include "bang/cncl_communicator.h" +#include +namespace infini { +class AllGatherCNCL : public BangKernelWithoutConfig { + public: + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + int world_size = op->getWorldSize(); + // Check if world size info in operator matches runtime + IT_ASSERT(world_size == 
context->getCommunicator().getWorldSize()); + + void *input = op->getInputs(0)->getRawDataPtr(); + BangPtr output_temp = + context->getWorkspace(op->getInputs(0)->getBytes() * world_size); + // void *output = op->getOutput()->getRawDataPtr(); + // IT_ASSERT(op->getDType() == DataType::Float32); + checkBangError(cnrtMalloc(&output_temp, + op->getInputs(0)->getBytes() * world_size)); + size_t bytes = op->getInputs(0)->getBytes(); + size_t count = bytes / op->getDType().getSize(); + + cnclComm_t comm = + dynamic_cast(context->getCommunicator()) + .getCnclComm(); + cnrtQueue_t queue = context->getBangQueue(); + CNCL_CHECK( + cnclAllGather(input, output_temp, count, cnclFloat32, comm, queue)); + checkBangError(cnrtQueueSync(queue)); + for (int i = 0; i < world_size; ++i) { + Tensor output = op->getOutput(i); + context->copyBlobInsideRuntime( + output->getRawDataPtr(), + static_cast(output_temp) + i * count, bytes); + } + checkBangError(cnrtFree(output_temp)); + } +}; + +REGISTER_KERNEL(Device::BANG, OpType::AllGather, DataType::Float32, + AllGatherCNCL, "AllGather_CNCL_BANG_Float32"); +} // namespace infini + +#endif diff --git a/src/kernels/bang/all_reduce.cc b/src/kernels/bang/all_reduce.cc new file mode 100644 index 00000000..4e9266fb --- /dev/null +++ b/src/kernels/bang/all_reduce.cc @@ -0,0 +1,53 @@ +#ifdef INFINI_USE_CNCL +#include "operators/all_reduce.h" +#include "bang/bang_kernel_without_config.h" +#include "bang/bang_runtime.h" +#include "bang/cncl_communicator.h" +#include +namespace infini { +class AllReduceCNCL : public BangKernelWithoutConfig { + public: + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + void *input = op->getInputs(0)->getRawDataPtr(); + void *output = op->getOutput()->getRawDataPtr(); + IT_ASSERT(op->getDType() == DataType::Float32); + size_t count = op->getInputs(0)->size(); + cnclComm_t comm = + dynamic_cast(context->getCommunicator()) + .getCnclComm(); + cnrtQueue_t queue = context->getBangQueue(); + // checkBangError(cnrtQueueSync(queue)); + CNCL_CHECK(cnclAllReduce(input, output, count, cnclFloat, getRedOp(), + comm, queue)); + checkBangError(cnrtQueueSync(queue)); + } + + virtual cnclReduceOp_t getRedOp() const = 0; +}; + +class AllReduceSumCNCL : public AllReduceCNCL { + cnclReduceOp_t getRedOp() const override { return cnclSum; } +}; +class AllReduceProdCNCL : public AllReduceCNCL { + cnclReduceOp_t getRedOp() const override { return cnclProd; } +}; +class AllReduceMinCNCL : public AllReduceCNCL { + cnclReduceOp_t getRedOp() const override { return cnclMin; } +}; +class AllReduceMaxCNCL : public AllReduceCNCL { + cnclReduceOp_t getRedOp() const override { return cnclMax; } +}; + +REGISTER_KERNEL(Device::BANG, OpType::AllReduceSum, DataType::Float32, + AllReduceSumCNCL, "AllReduce_Sum_CNCL_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::AllReduceProd, DataType::Float32, + AllReduceProdCNCL, "AllReduce_Prod_CNCL_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::AllReduceMin, DataType::Float32, + AllReduceMinCNCL, "AllReduce_Min_CNCL_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::AllReduceMax, DataType::Float32, + AllReduceMaxCNCL, "AllReduce_Max_CNCL_BANG_Float32"); +} // namespace infini +#endif diff --git a/src/kernels/bang/batchnorm.cc b/src/kernels/bang/batchnorm.cc index d6b9ce53..a1bc81c0 100644 --- a/src/kernels/bang/batchnorm.cc +++ b/src/kernels/bang/batchnorm.cc @@ -17,51 +17,87 @@ class BatchNormCnnl : public 
BangKernelWithoutConfig { void *const output = (op->getOutput()->getRawDataPtr()); auto dims = op->getInputs(0)->getDims(); - + auto outDims = op->getOutput()->getDims(); if (dims.size() != 4) IT_TODO_HALT(); - int dimArray[4], strideArray[4], dimPArray[1], stridePArray[1]; + int dimsTrans[4] = {dims[0], dims[2], dims[3], dims[1]}; + int dimsOutTrans[4] = {outDims[0], outDims[2], outDims[3], outDims[1]}; + int permute[4] = {0, 2, 3, 1}; + int permuteOut[4] = {0, 3, 1, 2}; - for (size_t i = 0; i < dims.size(); ++i) { - dimArray[i] = dims[i]; - strideArray[i] = op->getInputs(0)->getStride()[i]; - } - int w = dimArray[3]; - dimArray[3] = dimArray[1]; - int h = dimArray[2]; - dimArray[1] = h; - dimArray[2] = w; - - dimPArray[0] = op->getInputs(1)->getDims()[0]; - stridePArray[0] = op->getInputs(1)->getDims()[0]; // get inputs - cnnlTensorDescriptor_t inDesc; + cnnlTensorDescriptor_t inDesc, intransDesc, outDesc, outtransDesc; checkCnnlError(cnnlCreateTensorDescriptor(&inDesc)); - checkCnnlError(cnnlSetTensorDescriptorEx(inDesc, CNNL_LAYOUT_NHWC, - CNNL_DTYPE_FLOAT, dims.size(), - dimArray, strideArray)); + checkCnnlError(cnnlCreateTensorDescriptor(&intransDesc)); + checkCnnlError(cnnlCreateTensorDescriptor(&outDesc)); + checkCnnlError(cnnlCreateTensorDescriptor(&outtransDesc)); + checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_NCHW, + CNNL_DTYPE_FLOAT, dims.size(), + dims.data())); + checkCnnlError(cnnlSetTensorDescriptor(intransDesc, CNNL_LAYOUT_NHWC, + CNNL_DTYPE_FLOAT, dims.size(), + dimsTrans)); + checkCnnlError(cnnlSetTensorDescriptor(outDesc, CNNL_LAYOUT_NCHW, + CNNL_DTYPE_FLOAT, outDims.size(), + outDims.data())); + checkCnnlError(cnnlSetTensorDescriptor(outtransDesc, CNNL_LAYOUT_NHWC, + CNNL_DTYPE_FLOAT, outDims.size(), + dimsOutTrans)); + cnnlTransposeDescriptor_t opDesc; + checkCnnlError(cnnlCreateTransposeDescriptor(&opDesc)); + checkCnnlError(cnnlSetTransposeDescriptor(opDesc, 4, permute)); + size_t wsSize; + cnnlGetTransposeWorkspaceSize(context->cnnlHandle(), inDesc, opDesc, + &wsSize); + BangPtr wsData = context->getWorkspace(wsSize); + BangPtr inputTrans = context->getWorkspace( + cnnlGetTensorElementNum(inDesc) * sizeof(float)); + BangPtr outputTrans = context->getWorkspace( + cnnlGetTensorElementNum(inDesc) * sizeof(float)); + cnnlStatus_t stat = + cnnlTranspose_v2(context->cnnlHandle(), opDesc, inDesc, input, + intransDesc, inputTrans, wsData, wsSize); + if (stat != CNNL_STATUS_SUCCESS) + return; // get bnScaleBiasMeanVarDesc + auto dimsScaleBiasMeanVar = op->getInputs(1)->getDims(); cnnlTensorDescriptor_t paraDesc; checkCnnlError(cnnlCreateTensorDescriptor(¶Desc)); - checkCnnlError(cnnlSetTensorDescriptorEx(paraDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, 1, dimPArray, - stridePArray)); + checkCnnlError(cnnlSetTensorDescriptor( + paraDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, + dimsScaleBiasMeanVar.size(), dimsScaleBiasMeanVar.data())); float alpha = 1.f, beta = 0.f; // This mode is intended for use after convolutional layers - cnnlStatus_t stat = cnnlBatchNormForwardInference( - context->cnnlHandle(), &alpha, &beta, inDesc, input, paraDesc, - scale, bias, mean, var, op->getEps(), inDesc, output); + stat = cnnlBatchNormForwardInference( + context->cnnlHandle(), &alpha, &beta, intransDesc, inputTrans, + paraDesc, scale, bias, mean, var, op->getEps(), outtransDesc, + outputTrans); + if (stat != CNNL_STATUS_SUCCESS) + return; + cnnlTransposeDescriptor_t op2Desc; + checkCnnlError(cnnlCreateTransposeDescriptor(&op2Desc)); + 
checkCnnlError(cnnlSetTransposeDescriptor(op2Desc, 4, permuteOut)); + cnnlGetTransposeWorkspaceSize(context->cnnlHandle(), intransDesc, + op2Desc, &wsSize); + BangPtr ws2Data = context->getWorkspace(wsSize); + stat = cnnlTranspose_v2(context->cnnlHandle(), op2Desc, outtransDesc, + outputTrans, outDesc, output, ws2Data, wsSize); if (stat != CNNL_STATUS_SUCCESS) return; // Destories in BANG does not require sync. But cnnl does not state // whether sync is required before destories. checkCnnlError(cnnlDestroyTensorDescriptor(inDesc)); + checkCnnlError(cnnlDestroyTensorDescriptor(outDesc)); + checkCnnlError(cnnlDestroyTensorDescriptor(intransDesc)); + checkCnnlError(cnnlDestroyTensorDescriptor(outtransDesc)); checkCnnlError(cnnlDestroyTensorDescriptor(paraDesc)); + checkCnnlError(cnnlDestroyTransposeDescriptor(opDesc)); + checkCnnlError(cnnlDestroyTransposeDescriptor(op2Desc)); } }; diff --git a/src/kernels/bang/broadcast.cc b/src/kernels/bang/broadcast.cc new file mode 100644 index 00000000..411506c5 --- /dev/null +++ b/src/kernels/bang/broadcast.cc @@ -0,0 +1,34 @@ +#ifdef INFINI_USE_CNCL +#include "operators/broadcast.h" +#include "bang/bang_kernel_without_config.h" +#include "bang/bang_runtime.h" +#include "bang/cncl_communicator.h" +#include +namespace infini { +class BroadcastCNCL : public BangKernelWithoutConfig { + public: + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + void *input = op->getInputs(0)->getRawDataPtr(); + void *output = op->getOutput()->getRawDataPtr(); + IT_ASSERT(op->getDType() == DataType::Float32); + size_t count = op->getInputs(0)->getBytes() / op->getDType().getSize(); + + cnclComm_t comm = + dynamic_cast(context->getCommunicator()) + .getCnclComm(); + cnrtQueue_t queue = context->getBangQueue(); + // TODO: Using default stream 0 for now. 
+ CNCL_CHECK(cnclBroadcast(input, output, count, cnclFloat32, + op->getRoot(), comm, queue)); + checkBangError(cnrtQueueSync(queue)); + } +}; + +REGISTER_KERNEL(Device::BANG, OpType::Broadcast, DataType::Float32, + BroadcastCNCL, "Broadcast_CNCL_BANG_Float32"); +} // namespace infini + +#endif diff --git a/src/kernels/bang/gather.cc b/src/kernels/bang/gather.cc index b5a326fc..dc3ee636 100644 --- a/src/kernels/bang/gather.cc +++ b/src/kernels/bang/gather.cc @@ -23,6 +23,8 @@ class GatherCnnl : public BangKernelWithoutConfig { CNNL_DTYPE_FLOAT, aDim.size(), aDim.data())); checkCnnlError(cnnlCreateTensorDescriptor(&bDesc)); + checkCnnlError( + cnnlSetTensorDescriptorPointerMode(bDesc, CNNL_POINTER_MODE_HOST)); checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_INT32, bDim.size(), bDim.data())); diff --git a/src/kernels/bang/layer_norm.cc b/src/kernels/bang/layer_norm.cc new file mode 100644 index 00000000..231177c5 --- /dev/null +++ b/src/kernels/bang/layer_norm.cc @@ -0,0 +1,64 @@ +#include "operators/layer_norm.h" +#include "bang/bang_kernel_without_config.h" +#include "bang/bang_runtime.h" + +namespace infini { + +class LayerNormCnnl : public BangKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const inputData = (op->getInputs(0)->getRawDataPtr()); + void *const scaleData = (op->getInputs(1)->getRawDataPtr()); + void *biasData = NULL; + if (op->numInputs() == 3) { + biasData = (op->getInputs(2)->getRawDataPtr()); + } + void *const outputData = (op->getOutput()->getRawDataPtr()); + + auto inDims = op->getInputs(0)->getDims(); + auto outDims = op->getOutput()->getDims(); + auto fiterDims = op->getOutput(1)->getDims(); + + float eps = op->getEps(); + const int axis = op->getAxis(); + + cnnlTensorDescriptor_t inDesc, fiterDesc, outDesc; + + checkCnnlError(cnnlCreateTensorDescriptor(&inDesc)); + checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_ARRAY, + CNNL_DTYPE_FLOAT, inDims.size(), + inDims.data())); + checkCnnlError(cnnlCreateTensorDescriptor(&fiterDesc)); + checkCnnlError(cnnlSetTensorDescriptor( + fiterDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, fiterDims.size(), + fiterDims.data())); + checkCnnlError(cnnlCreateTensorDescriptor(&outDesc)); + checkCnnlError(cnnlSetTensorDescriptor(outDesc, CNNL_LAYOUT_ARRAY, + CNNL_DTYPE_FLOAT, outDims.size(), + outDims.data())); + size_t wsSize; + cnnlGetLayerNormOpWorkspaceSize(context->cnnlHandle(), axis, inDesc, + &wsSize); + BangPtr wsData = context->getWorkspace(wsSize); + + cnnlStatus_t stat = cnnlLayerNormForward( + context->cnnlHandle(), inDesc, inputData, axis, fiterDesc, + scaleData, biasData, eps, wsData, wsSize, outDesc, outputData, + inDesc, NULL, NULL); + + if (stat != CNNL_STATUS_SUCCESS) + return; + + checkCnnlError(cnnlDestroyTensorDescriptor(inDesc)); + checkCnnlError(cnnlDestroyTensorDescriptor(fiterDesc)); + checkCnnlError(cnnlDestroyTensorDescriptor(outDesc)); + } +}; + +REGISTER_KERNEL(Device::BANG, OpType::LayerNormalization, DataType::Float32, + LayerNormCnnl, "LayerNorm_BANG_Float32"); + +}; // namespace infini diff --git a/src/kernels/bang/lrn.cc b/src/kernels/bang/lrn.cc new file mode 100644 index 00000000..4183f0fd --- /dev/null +++ b/src/kernels/bang/lrn.cc @@ -0,0 +1,62 @@ +#include "operators/lrn.h" +#include "bang/bang_kernel_without_config.h" +#include "bang/bang_runtime.h" + +namespace infini { +class LRNCnnl : public BangKernelWithoutConfig { + void 
compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + cnnlTensorDescriptor_t aDesc, cDesc; + auto aDim = op->getInputs(0)->getDims(); + auto cDim = op->getOutput()->getDims(); + auto [alpha, beta, bias] = op->getAlphaBetaBias(); + auto size = op->getSize(); + + checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); + checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, + CNNL_DTYPE_FLOAT, aDim.size(), + aDim.data())); + checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); + checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, + CNNL_DTYPE_FLOAT, cDim.size(), + cDim.data())); + + size_t extra_size; + cnnlGetLrnExtraInputSize_v2(context->cnnlHandle(), cDesc, + CNNL_LRN_LOCAL_SIZE, size, &extra_size); + void *extra_cpu = NULL; + extra_cpu = malloc(extra_size); + BangPtr extra_mlu = context->getWorkspace(extra_size); + cnnlInitLrnExtraInput(context->cnnlHandle(), CNNL_LRN_LOCAL_SIZE, size, + (double)alpha, (double)beta, (double)bias, aDesc, + cDesc, extra_cpu); + cnrtMemcpy(extra_mlu, extra_cpu, extra_size, + CNRT_MEM_TRANS_DIR_HOST2DEV); + + size_t wsSize; + cnnlGetLrnWorkspaceSize_v2(context->cnnlHandle(), aDesc, cDesc, + CNNL_LRN_LOCAL_SIZE, size, &wsSize); + BangPtr wsData = context->getWorkspace(wsSize); + + cnnlStatus_t stat = cnnlLrn_v2( + context->cnnlHandle(), CNNL_LRN_LOCAL_SIZE, size, (double)alpha, + (double)beta, (double)bias, wsData, wsSize, aDesc, aData, extra_mlu, + extra_size, cDesc, cData); + if (stat != CNNL_STATUS_SUCCESS) + return; + + checkCnnlError(cnnlDestroyTensorDescriptor(aDesc)); + checkCnnlError(cnnlDestroyTensorDescriptor(cDesc)); + } +}; + +REGISTER_KERNEL(Device::BANG, OpType::LRN, DataType::Float32, LRNCnnl, + "LRN_cnnl_BANG_Float32"); + +}; // namespace infini diff --git a/src/kernels/bang/matmul.cc b/src/kernels/bang/matmul.cc index 39888e71..368d6b1c 100644 --- a/src/kernels/bang/matmul.cc +++ b/src/kernels/bang/matmul.cc @@ -10,15 +10,29 @@ class MatmulCnnl : public BangKernelWithoutConfig { auto op = as(_op); auto context = dynamic_cast(_context); + auto input_num = op->numInputs(); + void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *biasData = NULL; + if (input_num > 2) { + biasData = (op->getInputs(2)->getRawDataPtr()); + } void *const cData = (op->getOutput()->getRawDataPtr()); - cnnlTensorDescriptor_t aDesc, bDesc, cDesc; + cnnlTensorDescriptor_t aDesc, bDesc, cDesc, biasDesc; auto dimInputs0 = op->getInputs(0)->getDims(); auto dimInputs1 = op->getInputs(1)->getDims(); + std::vector dimBias; + if (input_num > 2) { + dimBias = op->getInputs(2)->getDims(); + } + auto dimOutput = op->getOutput()->getDims(); + float alpha = 1.0; + float beta = 0.0; + int32_t transA = op->getTransA(); int32_t transB = op->getTransB(); @@ -37,6 +51,13 @@ class MatmulCnnl : public BangKernelWithoutConfig { cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, dimOutput.size(), dimOutput.data())); + if (input_num > 2) { + checkCnnlError(cnnlCreateTensorDescriptor(&biasDesc)); + checkCnnlError(cnnlSetTensorDescriptor( + biasDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, dimBias.size(), + dimBias.data())); + } + cnnlMatMulDescriptor_t bmm_desc; cnnlMatMulDescCreate(&bmm_desc); cnnlSetMatMulDescAttr(bmm_desc, CNNL_MATMUL_DESC_TRANSA, &transA, @@ -47,8 +68,6 @@ class MatmulCnnl : 
public BangKernelWithoutConfig { cnnlMatMulAlgo_t bmm_algo; cnnlMatMulAlgoCreate(&bmm_algo); - float alpha = 1.0; - float beta = 0.0; int count = 0; cnnlMatMulHeuristicResult_t desc; @@ -66,9 +85,22 @@ class MatmulCnnl : public BangKernelWithoutConfig { if (stat != CNNL_STATUS_SUCCESS) return; + wsData = NULL; + if (input_num > 2) { + cnnlGetBiasAddWorkspaceSize(context->cnnlHandle(), biasDesc, cDesc, + &wsSize); + stat = cnnlBiasAdd(context->cnnlHandle(), &alpha, biasDesc, + biasData, wsData, wsSize, &alpha, cDesc, cData); + if (stat != CNNL_STATUS_SUCCESS) + return; + } + checkCnnlError(cnnlDestroyTensorDescriptor(aDesc)); checkCnnlError(cnnlDestroyTensorDescriptor(bDesc)); checkCnnlError(cnnlDestroyTensorDescriptor(cDesc)); + if (input_num > 2) { + checkCnnlError(cnnlDestroyTensorDescriptor(biasDesc)); + } checkCnnlError(cnnlMatMulDescDestroy(bmm_desc)); checkCnnlError(cnnlMatMulAlgoDestroy(bmm_algo)); checkCnnlError(cnnlDestroyMatMulHeuristicResult(desc)); diff --git a/src/kernels/bang/pad.cc b/src/kernels/bang/pad.cc index e211ee93..c2503ca0 100644 --- a/src/kernels/bang/pad.cc +++ b/src/kernels/bang/pad.cc @@ -13,14 +13,14 @@ class PadCnnl : public BangKernelWithoutConfig { void *const cData = (op->getOutput()->getRawDataPtr()); cnnlTensorDescriptor_t aDesc, cDesc; - auto dim = op->getOutput()->getDims(); - int dim_size = dim.size(); - int dim_array[dim_size]; - for (int i = 0; i < dim_size; ++i) { - dim_array[i] = dim[i]; - } + auto dimIn = op->getInputs(0)->getDims(); + auto dimOut = op->getOutput()->getDims(); + + int dim_size = dimIn.size(); int paddings[dim_size * 2]; + std::vector pads = op->getPads(); + if (pads.size() == 2 && dim_size != 1) { for (int i = 0; i < dim_size * 2; i += 2) { paddings[i] = pads[0]; @@ -32,20 +32,18 @@ class PadCnnl : public BangKernelWithoutConfig { paddings[i + 1] = pads[i / 2 + dim_size]; } } - int dimout_array[dim_size]; - for (int i = 0; i < dim_size; ++i) { - dimout_array[i] = dim[i] + paddings[2 * i] + paddings[2 * i + 1]; - } + float paddingValue = 0.0; // input checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); - checkCnnlError(cnnlSetTensorDescriptor( - aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, dim_size, dim_array)); + checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY, + CNNL_DTYPE_FLOAT, dimIn.size(), + dimIn.data())); // output checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, - CNNL_DTYPE_FLOAT, dim_size, - dimout_array)); + CNNL_DTYPE_FLOAT, dimOut.size(), + dimOut.data())); cnnlStatus_t stat = cnnlPad(context->cnnlHandle(), aDesc, aData, paddings, &paddingValue, cDesc, cData); diff --git a/src/kernels/bang/pooling.cc b/src/kernels/bang/pooling.cc index 8a91b466..f3cf04bc 100644 --- a/src/kernels/bang/pooling.cc +++ b/src/kernels/bang/pooling.cc @@ -21,13 +21,14 @@ class PoolingCnnl : public BangKernelWithoutConfig { checkCnnlError(cnnlCreateTensorDescriptor(&inDesc)); checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, inArray)); + bool mode = op->getCeilMode(); // get maxpool descriptor cnnlPoolingDescriptor_t poolingDesc; checkCnnlError(cnnlCreatePoolingDescriptor(&poolingDesc)); checkCnnlError(cnnlSetPooling2dDescriptor_v2( poolingDesc, getPoolingMode(), CNNL_NOT_PROPAGATE_NAN, kh, kw, ph, - ph, pw, pw, sh, sw, dh, dw, false)); + ph, pw, pw, sh, sw, dh, dw, mode)); // get outputs // TODO: verify ceiling mode diff --git a/src/kernels/bang/reduce_mean.cc b/src/kernels/bang/reduce.cc similarity index 82% rename from 
src/kernels/bang/reduce_mean.cc rename to src/kernels/bang/reduce.cc index 1b55c2ca..88d1e645 100644 --- a/src/kernels/bang/reduce_mean.cc +++ b/src/kernels/bang/reduce.cc @@ -1,12 +1,14 @@ +#include "operators/reduce.h" #include "bang/bang_kernel_without_config.h" #include "bang/bang_runtime.h" -#include "operators/reduce.h" namespace infini { -class ReduceMeanCnnl : public BangKernelWithoutConfig { +class ReduceCnnlBase : public BangKernelWithoutConfig { + virtual cnnlReduceOp_t getReduceOp() const = 0; + void compute(const Operator &_op, const RuntimeObj *_context) const override { - auto op = as(_op); + auto op = as(_op); auto context = dynamic_cast(_context); void *const aData = (op->getInputs(0)->getRawDataPtr()); void *const cData = (op->getOutput()->getRawDataPtr()); @@ -34,7 +36,7 @@ class ReduceMeanCnnl : public BangKernelWithoutConfig { cnnlReduceDescriptor_t reduceDesc; checkCnnlError(cnnlCreateReduceDescriptor(&reduceDesc)); checkCnnlError(cnnlSetReduceDescriptor_v2( - reduceDesc, axes.data(), axes.size(), CNNL_REDUCE_AVG, + reduceDesc, axes.data(), axes.size(), getReduceOp(), CNNL_DTYPE_FLOAT, CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES, 0.0)); @@ -63,7 +65,17 @@ class ReduceMeanCnnl : public BangKernelWithoutConfig { } }; +class ReduceMeanCnnl : public ReduceCnnlBase { + cnnlReduceOp_t getReduceOp() const override { return CNNL_REDUCE_AVG; } +}; + +class ReduceSumCnnl : public ReduceCnnlBase { + cnnlReduceOp_t getReduceOp() const override { return CNNL_REDUCE_ADD; } +}; + REGISTER_KERNEL(Device::BANG, OpType::ReduceMean, DataType::Float32, ReduceMeanCnnl, "ReduceMean_cnnl_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::ReduceSum, DataType::Float32, + ReduceSumCnnl, "ReduceSum_cnnl_BANG_Float32"); }; // namespace infini diff --git a/src/kernels/bang/reshape.cc b/src/kernels/bang/reshape.cc index 564ed1d7..f5628a7b 100644 --- a/src/kernels/bang/reshape.cc +++ b/src/kernels/bang/reshape.cc @@ -27,6 +27,8 @@ class CopyBang : public BangKernelWithoutConfig { // reshape/flatten/identity all act as copying from input to output. 
REGISTER_KERNEL(Device::BANG, OpType::Reshape, DataType::Float32, CopyBang, "Reshape_BANG_Float32"); +REGISTER_KERNEL(Device::BANG, OpType::Reshape, DataType::Int64, CopyBang, + "Reshape_BANG_Int64"); REGISTER_KERNEL(Device::BANG, OpType::Flatten, DataType::Float32, CopyBang, "Flatten_BANG_Float32"); REGISTER_KERNEL(Device::BANG, OpType::Identity, DataType::Float32, CopyBang, diff --git a/src/kernels/bang/slice.cc b/src/kernels/bang/slice.cc new file mode 100644 index 00000000..5cc772aa --- /dev/null +++ b/src/kernels/bang/slice.cc @@ -0,0 +1,64 @@ +#include "operators/slice.h" +#include "bang/bang_kernel_without_config.h" +#include "bang/bang_runtime.h" + +namespace infini { +class SliceCnnl : public BangKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + auto starts = op->getStarts(); + auto ends = op->getEnds(); + auto steps = op->getSteps(); + + int32_t starts_array[starts.size()]; + int32_t ends_array[ends.size()]; + int32_t steps_array[steps.size()]; + + for (size_t i = 0; i < starts.size(); i++) { + starts_array[i] = starts[i]; + ends_array[i] = ends[i]; + steps_array[i] = steps[i]; + } + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto aDim = op->getInputs(0)->getDims(); + int aDim_size = aDim.size(); + int aDim_array[aDim_size]; + for (int i = 0; i < aDim_size; ++i) { + aDim_array[i] = aDim[i]; + } + auto cDim = op->getOutput()->getDims(); + int cDim_size = cDim.size(); + int cDim_array[cDim_size]; + for (int i = 0; i < cDim_size; ++i) { + cDim_array[i] = cDim[i]; + } + cnnlTensorDescriptor_t aDesc, cDesc; + // input + checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); + checkCnnlError(cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, aDim_size, aDim_array)); + // output + checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); + checkCnnlError(cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, cDim_size, cDim_array)); + + cnnlStatus_t stat = + cnnlStridedSlice(context->cnnlHandle(), aDesc, aData, starts_array, + ends_array, steps_array, cDesc, cData); + if (stat != CNNL_STATUS_SUCCESS) + return; + + checkCnnlError(cnnlDestroyTensorDescriptor(aDesc)); + checkCnnlError(cnnlDestroyTensorDescriptor(cDesc)); + } +}; + +REGISTER_KERNEL(Device::BANG, OpType::Slice, DataType::Float32, SliceCnnl, + "Slice_cnnl_BANG_Float32"); +}; // namespace infini diff --git a/src/kernels/cpu/element_wise.cc b/src/kernels/cpu/element_wise.cc index 8d225779..ff03350c 100644 --- a/src/kernels/cpu/element_wise.cc +++ b/src/kernels/cpu/element_wise.cc @@ -1,5 +1,6 @@ #include "operators/element_wise.h" #include "core/kernel.h" +#include "utils/operator_utils.h" namespace infini { template class NativeElementWise : public CpuKernelWithoutConfig { @@ -11,37 +12,34 @@ template class NativeElementWise : public CpuKernelWithoutConfig { T *inptr1 = op->getInputs(1)->getRawDataPtr(); T *outptr = op->getOutput()->getRawDataPtr(); - int a[4] = {1, 1, 1, 1}; - int b[4] = {1, 1, 1, 1}; - int c[4] = {1, 1, 1, 1}; - auto a_input = op->getInputs(0)->getDims(); - auto b_input = op->getInputs(1)->getDims(); - auto c_output = op->getOutput()->getDims(); - std::copy(a_input.begin(), a_input.end(), a + (4 - a_input.size())); - std::copy(b_input.begin(), b_input.end(), b + (4 - b_input.size())); - std::copy(c_output.begin(), c_output.end(), c + (4 - c_output.size())); + auto shapeA = 
op->getInputs(0)->getDims(); + auto shapeB = op->getInputs(1)->getDims(); + auto shapeC = op->getOutput()->getDims(); + auto rank = op->getOutput()->getRank(); + Shape a(rank, 1); + Shape b(rank, 1); + std::copy(shapeA.begin(), shapeA.end(), + a.begin() + (rank - shapeA.size())); + std::copy(shapeB.begin(), shapeB.end(), + b.begin() + (rank - shapeB.size())); + auto getStride = [&](const Shape &shape) { + int p = 1; + Shape stride(rank); + for (auto i = rank; i > 0; --i) { + stride[i - 1] = p; + p = p * shape[i - 1]; + } + return stride; + }; + Shape strideA = getStride(a); + Shape strideB = getStride(b); auto n = op->getOutput()->size(); for (size_t i = 0; i < n; ++i) { - int c0_index = i / (c[1] * c[2] * c[3]); - int c1_index = (i % (c[1] * c[2] * c[3])) / (c[2] * c[3]); - int c2_index = ((i % (c[1] * c[2] * c[3])) % (c[2] * c[3])) / c[3]; - int c3_index = ((i % (c[1] * c[2] * c[3])) % (c[2] * c[3])) % c[3]; - - int a0_index = c0_index % a[0]; - int a1_index = c1_index % a[1]; - int a2_index = c2_index % a[2]; - int a3_index = c3_index % a[3]; - - int b0_index = c0_index % b[0]; - int b1_index = c1_index % b[1]; - int b2_index = c2_index % b[2]; - int b3_index = c3_index % b[3]; - outptr[i] = doCompute( - inptr0[a0_index * a[1] * a[2] * a[3] + a1_index * a[2] * a[3] + - a2_index * a[3] + a3_index], - inptr1[b0_index * b[1] * b[2] * b[3] + b1_index * b[2] * b[3] + - b2_index * b[3] + b3_index]); + auto shapeIndexC = locate_index(i, shapeC); + auto indexA = delocate_index(shapeIndexC, a, strideA); + auto indexB = delocate_index(shapeIndexC, b, strideB); + outptr[i] = doCompute(inptr0[indexA], inptr1[indexB]); } } }; diff --git a/src/kernels/cuda/pad_slice.cu b/src/kernels/cuda/pad_slice.cu index f119bd9c..cd6bc37b 100644 --- a/src/kernels/cuda/pad_slice.cu +++ b/src/kernels/cuda/pad_slice.cu @@ -35,7 +35,7 @@ __global__ void _pad_slice_kernel(T *part, T *whole, TransMetaData metaData, whole[tid] = 0; else whole[tid] = part[offset]; - else + else if (offset >= 0) part[offset] = whole[tid]; tid += stride; } diff --git a/src/operators/lrn.cc b/src/operators/lrn.cc new file mode 100644 index 00000000..5cdc29a6 --- /dev/null +++ b/src/operators/lrn.cc @@ -0,0 +1,36 @@ +#include "operators/lrn.h" +#include "utils/operator_utils.h" + +namespace infini { + +LRNObj::LRNObj(GraphObj *graph, Tensor input, Tensor output, float alpha, + float beta, float bias, int size) + : OperatorObj(OpType::LRN, TensorVec{input}, {output}), alpha_value(alpha), + beta_value(beta), bias_value(bias), size_value(size) { + IT_ASSERT(checkValid(graph)); +} + +optional> LRNObj::inferShape(const TensorVec &inputs) { + const auto A = inputs[0]; + return {{A->getDims()}}; +} + +std::string LRNObj::toString() const { + std::ostringstream os; + os << "LRN[" << getGuid() << "]"; + os << "("; + os << vecToString(inputs[0]->getDims()) << ","; + os << "input=" << inputs[0]->getGuid() << ","; + os << "output=" << outputs[0]->getGuid() << ")"; + return os.str(); +} + +vector LRNObj::getWorkloadVector() const { + vector ret = getOutput()->getDims(); + ret.emplace(ret.begin(), type.underlying()); + return ret; +} + +vector LRNObj::getOpAttrVector() const { return {type.underlying()}; } + +} // namespace infini diff --git a/src/utils/operator_utils.cc b/src/utils/operator_utils.cc index a9b81a5e..6687a8fd 100644 --- a/src/utils/operator_utils.cc +++ b/src/utils/operator_utils.cc @@ -1,4 +1,5 @@ #include "utils/operator_utils.h" +#include "core/runtime.h" namespace infini { @@ -64,4 +65,54 @@ bool is_unidirectional_broadcasting(const 
Shape &A, const Shape &B) {
     }
     return true;
 }
+
+Shape locate_index(size_t inputN, const Shape &shape) {
+    Shape ans(shape.size());
+    auto i = ans.rbegin();
+    auto j = shape.rbegin(), ej = shape.rend();
+    while (j != ej) {
+        auto div = std::div(inputN, *j++);
+        *i++ = div.rem;
+        inputN = div.quot;
+    }
+    return ans;
+}
+
+size_t delocate_index(const Shape &shapeIndex, const Shape &shape,
+                      const Shape &stride) {
+    size_t ans = 0;
+    Shape index(shapeIndex.size());
+    IT_ASSERT(shapeIndex.size() == shape.size());
+    IT_ASSERT(shape.size() == stride.size());
+    for (size_t i = 0; i < shape.size(); ++i) {
+        index[i] = shapeIndex[i] % shape[i];
+        ans += index[i] * stride[i];
+    }
+    return ans;
+}
+
+std::string device_to_str(Device device) {
+    std::string deviceStr;
+    switch (device) {
+    case Device::CPU:
+        return "CPU";
+    case Device::CUDA:
+        return "CUDA";
+    case Device::BANG:
+        return "BANG";
+    case Device::INTELCPU:
+        return "INTELCPU";
+    case Device::KUNLUN:
+        return "KUNLUN";
+    default:
+        IT_TODO_HALT();
+    }
+}
+
+std::string get_kernel_attrs_str(const KernelAttrs &kernelAttrs) {
+    std::string deviceStr = device_to_str(std::get<0>(kernelAttrs));
+    std::string opStr = OpType(std::get<1>(kernelAttrs)).toString();
+    std::string datatypeStr = std::get<2>(kernelAttrs).toString();
+    return deviceStr + ", " + opStr + ", " + datatypeStr;
+}
 } // namespace infini
diff --git a/test/bang/test_cncl_comm.cc b/test/bang/test_cncl_comm.cc
new file mode 100644
index 00000000..50b47434
--- /dev/null
+++ b/test/bang/test_cncl_comm.cc
@@ -0,0 +1,58 @@
+#ifdef INFINI_USE_CNCL
+#include "bang/bang_runtime.h"
+#include "bang/cncl_communicator.h"
+#include "test.h"
+
+static int WORLD_SIZE = 2;
+
+namespace infini {
+
+void allReduceSum(float *data, int deviceId) {
+    // Create Runtime and setup communication
+    BangRuntimeObj *bang_runtime = new BangRuntimeObj(deviceId);
+    int rank = deviceId;
+    bang_runtime->initComm("test_cncl_comm", WORLD_SIZE, rank);
+    cnclComm_t comm =
+        dynamic_cast<CnclCommunicatorObj &>(bang_runtime->getCommunicator())
+            .getCnclComm();
+    cnrtQueue_t queue = bang_runtime->getBangQueue();
+    // Copy data
+    float *data_mlu;
+    checkBangError(cnrtMalloc((void **)&data_mlu, sizeof(float)));
+    checkBangError(
+        cnrtMemcpy(data_mlu, data, sizeof(float), cnrtMemcpyHostToDev));
+    // Do AllReduce
+    CNCL_CHECK(
+        cnclAllReduce(data_mlu, data_mlu, 1, cnclFloat, cnclSum, comm, queue));
+
+    checkBangError(cnrtQueueSync(queue));
+    // Copy data back and sync device
+    checkBangError(
+        cnrtMemcpy(data, data_mlu, sizeof(float), cnrtMemcpyDevToHost));
+    ASSERT_EQ(*data, 5.0f);
+}
+
+// Set up communication between 2 processes, each controlling 1 MLU.
+// Do AllReduce Sum on {1.0, 4.0}. Results should be {5.0, 5.0}.
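The test below is built on a fork-per-device harness that every CNCL test in this patch repeats verbatim (see also the all_gather, all_reduce, and broadcast tests further down): fork one child process per MLU, run the collective inside the child, and reap all children in the parent. A minimal sketch of that harness, hoisted into a reusable helper; the name forkEachDevice and the std::function-based signature are illustrative only, not an API of this repository:

#include <sys/wait.h>
#include <unistd.h>

#include <cstdlib>
#include <functional>
#include <iostream>

// Hypothetical helper, not part of the patch: run body(rank) in one forked
// child per device, then block until every child has exited.
void forkEachDevice(int worldSize, const std::function<void(int)> &body) {
    for (int rank = 0; rank < worldSize; ++rank) {
        pid_t pid = fork();
        if (pid == 0) {
            body(rank);   // child: drives exactly one MLU
            std::exit(0); // never fall through into the parent's loop
        } else if (pid < 0) {
            std::cerr << "Error creating process" << std::endl;
        }
    }
    for (int rank = 0; rank < worldSize; ++rank)
        wait(NULL); // reap every child before the test returns
}

With such a helper the test below would reduce to forkEachDevice(WORLD_SIZE, [&](int i) { allReduceSum(&data[i], i); }). Processes are used rather than threads, presumably because each rank owns a private runtime and CNCL communicator, so a failure on one device cannot corrupt the state of the others.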
+TEST(CNCL, multi_mlu_communication) { + float data[] = {1.0, 4.0}; + + for (int i = 0; i < WORLD_SIZE; ++i) { + pid_t pid = fork(); + if (pid == 0) { + // Child process + allReduceSum(&data[i], i); + exit(0); // Ensure child process exits to avoid unnecessary + // repetition in parent + } else if (pid < 0) { + std::cerr << "Error creating process" << std::endl; + } + } + // Wait for all child processes to finish + for (int i = 0; i < WORLD_SIZE; ++i) { + wait(NULL); + } +} + +} // namespace infini +#endif diff --git a/test/kernels/bang/test_bang_all_gather.cc b/test/kernels/bang/test_bang_all_gather.cc new file mode 100644 index 00000000..038cc7ab --- /dev/null +++ b/test/kernels/bang/test_bang_all_gather.cc @@ -0,0 +1,60 @@ +#ifdef INFINI_USE_CNCL +#include "bang/bang_runtime.h" +#include "bang/cncl_communicator.h" +#include "core/graph.h" +#include "core/runtime.h" +#include "operators/all_gather.h" +#include "test.h" +#include +#include + +static int WORLD_SIZE = 2; + +namespace infini { + +void allGather(const string taskName, int deviceID, vector data, + vector> ans) { + // Create Runtimes and initiate communication + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + Runtime bangRuntime = make_ref(deviceID); + bangRuntime->initComm(taskName, WORLD_SIZE, deviceID); + // Create Graph and insert allReduce operation + Graph g = make_ref(bangRuntime); + auto input = + g->addTensor(Shape{static_cast(data.size())}, DataType::Float32); + auto op = g->addOp(input, std::nullopt, WORLD_SIZE); + // Copy data from CPU to MLU + g->dataMalloc(); + input->copyin(data); + // Run operation + bangRuntime->run(g); + // Copy output from MLU to CPU + for (int i = 0; i < WORLD_SIZE; ++i) { + auto result = op->getOutputs()[i]->clone(cpuRuntime); + EXPECT_TRUE(result->equalData(ans[i])); + } +} + +TEST(BANG_AllGather, run) { + vector data[2] = {{2., 3.}, {5., 6.}}; + vector> ans = {{2., 3.}, {5., 6.}}; + + for (int i = 0; i < WORLD_SIZE; ++i) { + pid_t pid = fork(); + if (pid == 0) { + // Child process + allGather("test_all_gather", i, data[i], ans); + exit(0); // Ensure child process exits to avoid unnecessary + // repetition in parent + } else if (pid < 0) { + std::cerr << "Error creating process" << std::endl; + } + } + // Wait for all child processes to finish + for (int i = 0; i < WORLD_SIZE; ++i) { + wait(NULL); + } +} + +} // namespace infini +#endif diff --git a/test/kernels/bang/test_bang_all_reduce.cc b/test/kernels/bang/test_bang_all_reduce.cc new file mode 100644 index 00000000..a10a9288 --- /dev/null +++ b/test/kernels/bang/test_bang_all_reduce.cc @@ -0,0 +1,124 @@ +#ifdef INFINI_USE_CNCL +#include "bang/bang_runtime.h" +#include "bang/cncl_communicator.h" +#include "core/graph.h" +#include "core/runtime.h" +#include "operators/all_reduce.h" +#include "test.h" +#include +#include +#include + +static int WORLD_SIZE = 2; + +namespace infini { + +template +void allReduce(const string taskName, int deviceID, vector data, + vector ans) { + // Create Runtimes and initiate communication + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + Runtime bangRuntime = make_ref(deviceID); + bangRuntime->initComm(taskName, WORLD_SIZE, deviceID); + // Create Graph and insert allReduce operation + Graph g = make_ref(bangRuntime); + auto input = + g->addTensor(Shape{static_cast(data.size())}, DataType::Float32); + auto op = g->addOp(input, nullptr); + // Copy data from CPU to MLU + g->dataMalloc(); + input->copyin(data); + // Run operation + bangRuntime->run(g); + // Copy output from MLU to CPU + 
auto result = op->getOutput()->clone(cpuRuntime); + + EXPECT_TRUE(result->equalData(ans)); +} + +TEST(BANG_AllReduce, sum) { + vector data[2] = {{2., 3.}, {5., 6.}}; + vector ans = {7., 9.}; + + for (int i = 0; i < WORLD_SIZE; ++i) { + pid_t pid = fork(); + if (pid == 0) { + // Child process + allReduce("test_allreduce_sum", i, data[i], ans); + exit(0); // Ensure child process exits to avoid unnecessary + // repetition in parent + } else if (pid < 0) { + std::cerr << "Error creating process" << std::endl; + } + } + // Wait for all child processes to finish + for (int i = 0; i < WORLD_SIZE; ++i) { + wait(NULL); + } +} + +TEST(BANG_AllReduce, prod) { + vector data[2] = {{2., 3.}, {5., 6.}}; + vector ans = {10., 18.}; + + for (int i = 0; i < WORLD_SIZE; ++i) { + pid_t pid = fork(); + if (pid == 0) { + // Child process + allReduce("test_allreduce_prod", i, data[i], ans); + exit(0); // Ensure child process exits to avoid unnecessary + // repetition in parent + } else if (pid < 0) { + std::cerr << "Error creating process" << std::endl; + } + } + // Wait for all child processes to finish + for (int i = 0; i < WORLD_SIZE; ++i) { + wait(NULL); + } +} + +TEST(BANG_AllReduce, min) { + vector data[2] = {{2., 3.}, {5., 6.}}; + vector ans = {2., 3.}; + + for (int i = 0; i < WORLD_SIZE; ++i) { + pid_t pid = fork(); + if (pid == 0) { + // Child process + allReduce("test_allreduce_min", i, data[i], ans); + exit(0); // Ensure child process exits to avoid unnecessary + // repetition in parent + } else if (pid < 0) { + std::cerr << "Error creating process" << std::endl; + } + } + // Wait for all child processes to finish + for (int i = 0; i < WORLD_SIZE; ++i) { + wait(NULL); + } +} + +TEST(BANG_AllReduce, max) { + vector data[2] = {{2., 3.}, {5., 6.}}; + vector ans = {5., 6.}; + + for (int i = 0; i < WORLD_SIZE; ++i) { + pid_t pid = fork(); + if (pid == 0) { + // Child process + allReduce("test_allreduce_max", i, data[i], ans); + exit(0); // Ensure child process exits to avoid unnecessary + // repetition in parent + } else if (pid < 0) { + std::cerr << "Error creating process" << std::endl; + } + } + // Wait for all child processes to finish + for (int i = 0; i < WORLD_SIZE; ++i) { + wait(NULL); + } +} + +} // namespace infini +#endif diff --git a/test/kernels/bang/test_bang_batch_norm.cc b/test/kernels/bang/test_bang_batch_norm.cc new file mode 100644 index 00000000..cf79ff8d --- /dev/null +++ b/test/kernels/bang/test_bang_batch_norm.cc @@ -0,0 +1,57 @@ +#include "bang/bang_kernel_without_config.h" +#include "bang/bang_runtime.h" +#include "core/graph.h" +#include "core/runtime.h" +#include "operators/batch_norm.h" +#include "test.h" + +namespace infini { + +TEST(BANG_BatchNorm, run) { + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto bangRuntime = make_ref(); + + // Build cpu graph + Graph gCpu = make_ref(cpuRuntime); + auto iCpu = gCpu->addTensor(Shape{1, 3, 2, 2}, DataType::Float32); + auto meanCpu = gCpu->addTensor(Shape{3}, DataType::Float32); + auto varCpu = gCpu->addTensor(Shape{3}, DataType::Float32); + auto scaleCpu = gCpu->addTensor(Shape{3}, DataType::Float32); + auto biasCpu = gCpu->addTensor(Shape{3}, DataType::Float32); + + // Build input data on CPU + gCpu->dataMalloc(); + iCpu->setData(IncrementalGenerator()); + meanCpu->copyin(vector{1, 6, 9}); + varCpu->copyin(vector{4, 1, 9}); + scaleCpu->setData(OneGenerator()); + biasCpu->setData(ZeroGenerator()); + + Graph g = make_ref(bangRuntime); + + auto i = g->cloneTensor(iCpu); + auto mean = g->cloneTensor(meanCpu); + auto var = 
g->cloneTensor(varCpu);
+    auto scale = g->cloneTensor(scaleCpu);
+    auto bias = g->cloneTensor(biasCpu);
+    auto op =
+        g->addOp<BatchNormObj>(i, nullptr, mean, var, scale, bias, 0.9, 0);
+
+    g->dataMalloc();
+    i->setData(IncrementalGenerator());
+    mean->copyin(vector<float>{1, 6, 9});
+    var->copyin(vector<float>{4, 1, 9});
+    scale->setData(OneGenerator());
+    bias->setData(ZeroGenerator());
+
+    bangRuntime->run(g);
+
+    auto o = op->getOutput();
+    auto ocpu = o->clone(cpuRuntime);
+
+    // check results on CPU
+    EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 2}));
+    EXPECT_TRUE(ocpu->equalData(vector<float>{
+        -0.5, 0, 0.5, 1, -2, -1, 0, 1, -0.333333, 0, 0.3333333, 0.6666667}));
+}
+} // namespace infini
diff --git a/test/kernels/bang/test_bang_broadcast.cc b/test/kernels/bang/test_bang_broadcast.cc
new file mode 100644
index 00000000..e05666f7
--- /dev/null
+++ b/test/kernels/bang/test_bang_broadcast.cc
@@ -0,0 +1,65 @@
+#ifdef INFINI_USE_CNCL
+#include "bang/bang_runtime.h"
+#include "bang/cncl_communicator.h"
+#include "core/graph.h"
+#include "core/runtime.h"
+#include "operators/broadcast.h"
+#include "test.h"
+#include <sys/wait.h>
+#include <unistd.h>
+
+static int WORLD_SIZE = 2;
+static int root = 0;
+
+namespace infini {
+
+void broadcast(const string taskName, int deviceID, vector<float> data,
+               vector<float> ans) {
+    // Create Runtimes and initiate communication
+    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
+    Runtime bangRuntime = make_ref<BangRuntimeObj>(deviceID);
+    bangRuntime->initComm(taskName, WORLD_SIZE, deviceID);
+    // Create Graph and insert the broadcast operation
+    Graph g = make_ref<GraphObj>(bangRuntime);
+    auto input =
+        g->addTensor(Shape{static_cast<int>(data.size())}, DataType::Float32);
+    auto op = g->addOp<BroadcastObj>(input, nullptr, root);
+    // Copy data from CPU to MLU
+    g->dataMalloc();
+    // Only rank 0 has the data
+    if (deviceID == root) {
+        input->copyin(data);
+    }
+    // Run broadcast operation
+    bangRuntime->run(g);
+    // Copy output from MLU to CPU
+    auto result = op->getOutput()->clone(cpuRuntime);
+
+    EXPECT_TRUE(result->equalData(ans));
+}
+
+TEST(BANG_Broadcast, run) {
+    // Only 1 device gets data. Every rank should have the same data after
+    // broadcast.
+ vector data = {2., 3., 5., 6.}; + vector ans = {2., 3., 5., 6.}; + + for (int i = 0; i < WORLD_SIZE; ++i) { + pid_t pid = fork(); + if (pid == 0) { + // Child process + broadcast("test_broadcast", i, data, ans); + exit(0); // Ensure child process exits to avoid unnecessary + // repetition in parent + } else if (pid < 0) { + std::cerr << "Error creating process" << std::endl; + } + } + // Wait for all child processes to finish + for (int i = 0; i < WORLD_SIZE; ++i) { + wait(NULL); + } +} + +} // namespace infini +#endif diff --git a/test/kernels/bang/test_bang_concat.cc b/test/kernels/bang/test_bang_concat.cc index 4cf130e3..3e0c2775 100644 --- a/test/kernels/bang/test_bang_concat.cc +++ b/test/kernels/bang/test_bang_concat.cc @@ -32,6 +32,8 @@ void testConcat(const std::function &generator, auto gpuOp = bangGraph->addOp(TensorVec{inputGpu1, inputGpu2}, nullptr, 2); bangGraph->dataMalloc(); + inputGpu1->setData(generator); + inputGpu2->setData(generator); bangRuntime->run(bangGraph); auto outputGpu = gpuOp->getOutput(); auto outputGpu2Cpu = outputGpu->clone(cpuRuntime); diff --git a/test/kernels/bang/test_bang_pooling.cc b/test/kernels/bang/test_bang_pooling.cc index 4bbc8091..8032f213 100644 --- a/test/kernels/bang/test_bang_pooling.cc +++ b/test/kernels/bang/test_bang_pooling.cc @@ -18,8 +18,14 @@ void testPooling(const std::function &generator, // Build input data on CPU Tensor inputCpu = make_ref(shape, DataType::Float32, cpuRuntime); - inputCpu->dataMalloc(); + Graph cpuGraph = make_ref(cpuRuntime); + auto cpuOp = + cpuGraph->addOp(inputCpu, nullptr, 3, 3, 1, 1, 1, 1, 2, 2, 0); + cpuGraph->addTensor(inputCpu); + cpuGraph->dataMalloc(); inputCpu->setData(generator); + cpuRuntime->run(cpuGraph); + auto outputCpu = cpuOp->getOutput(); // GPU Graph bangGraph = make_ref(bangRuntime); @@ -27,17 +33,16 @@ void testPooling(const std::function &generator, auto gpuOp = bangGraph->addOp(inputGpu, nullptr, 3, 3, 1, 1, 1, 1, 2, 2, 0); bangGraph->dataMalloc(); + inputGpu->setData(generator); bangRuntime->run(bangGraph); auto outputGpu = gpuOp->getOutput(); auto outputGpu2Cpu = outputGpu->clone(cpuRuntime); - inputCpu->printData(); - outputGpu2Cpu->printData(); EXPECT_TRUE(1); } TEST(cnnl_Pooling, run) { - testPooling(IncrementalGenerator(), Shape{1, 1, 5, 5}); - testPooling(IncrementalGenerator(), Shape{1, 1, 5, 5}); + testPooling(IncrementalGenerator(), Shape{1, 3, 5, 5}); + testPooling(IncrementalGenerator(), Shape{1, 3, 5, 5}); } } // namespace infini diff --git a/test/kernels/bang/test_bang_reduce.cc b/test/kernels/bang/test_bang_reduce.cc new file mode 100644 index 00000000..485e16f0 --- /dev/null +++ b/test/kernels/bang/test_bang_reduce.cc @@ -0,0 +1,82 @@ +#include "bang/bang_runtime.h" +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/reduce.h" + +#include "test.h" + +namespace infini { + +template +void test_reduce(const Shape &shape, const vector &data, + const optional> &axis, bool keepDims, + const vector &ExpectData) { + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto bangRuntime = make_ref(); + + // Build input data on CPU + Tensor icpu = make_ref(shape, DataType::Float32, cpuRuntime); + + // Build BANG graph + Graph g = make_ref(bangRuntime); + auto i = g->cloneTensor(icpu); + auto op = g->addOp(i, nullptr, axis, keepDims); + + // allocate BANG memory + g->dataMalloc(); + i->copyin(data); + + // Execute on BANG + bangRuntime->run(g); + + // clone BANG output to CPU + auto o = op->getOutput(); + auto ocpu = 
o->clone(cpuRuntime);
+
+    // check results on CPU
+    EXPECT_TRUE(ocpu->equalData(ExpectData));
+}
+
+TEST(BANG_ReduceMean, run) {
+    test_reduce<ReduceMeanObj>(
+        Shape{3, 2, 2}, vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2},
+        std::nullopt, true, vector<float>{18.25});
+    test_reduce<ReduceMeanObj>(
+        Shape{1, 3, 2, 2, 1},
+        vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2}, std::nullopt,
+        false, vector<float>{18.25});
+
+    test_reduce<ReduceMeanObj>(
+        Shape{2, 3, 2, 2},
+        vector<float>{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
+                      12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+        vector<int>{1, 2}, false, vector<float>{5, 6, 17, 18});
+    test_reduce<ReduceMeanObj>(
+        Shape{2, 3, 2, 2, 1},
+        vector<float>{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
+                      12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+        vector<int>{1, 2}, true, vector<float>{5, 6, 17, 18});
+}
+
+TEST(BANG_ReduceSum, run) {
+    test_reduce<ReduceSumObj>(Shape{3, 2, 2},
+                              vector<float>{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+                              std::nullopt, true, vector<float>{12});
+    test_reduce<ReduceSumObj>(Shape{1, 3, 2, 2, 1},
+                              vector<float>{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+                              std::nullopt, false, vector<float>{12});
+
+    test_reduce<ReduceSumObj>(
+        Shape{2, 3, 2, 2},
+        vector<float>{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
+                      12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+        vector<int>{1, 2}, false, vector<float>{30, 36, 102, 108});
+    test_reduce<ReduceSumObj>(
+        Shape{2, 3, 2, 2, 1},
+        vector<float>{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
+                      12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+        vector<int>{1, 2}, true, vector<float>{30, 36, 102, 108});
+}
+
+} // namespace infini
diff --git a/test/kernels/bang/test_bang_slice.cc b/test/kernels/bang/test_bang_slice.cc
new file mode 100644
index 00000000..0f932409
--- /dev/null
+++ b/test/kernels/bang/test_bang_slice.cc
@@ -0,0 +1,39 @@
+#include "bang/bang_runtime.h"
+#include "core/graph.h"
+#include "core/runtime.h"
+#include "operators/slice.h"
+#include "test.h"
+
+namespace infini {
+TEST(BANG_Slice, run) {
+    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
+    auto bangRuntime = make_ref<BangRuntimeObj>();
+
+    // Build input data on CPU
+    Tensor icpu =
+        make_ref<TensorObj>(Shape{3, 2, 1, 5}, DataType::Float32, cpuRuntime);
+    icpu->dataMalloc();
+    icpu->setData(IncrementalGenerator());
+
+    // Build BANG graph
+    Graph g = make_ref<GraphObj>(bangRuntime);
+    auto i = g->cloneTensor(icpu);
+    auto op =
+        g->addOp<SliceObj>(i, nullptr, vector<int>{1, 1}, vector<int>{2, 5},
+                           vector<int>{0, 3}, std::nullopt);
+
+    // allocate BANG memory
+    g->dataMalloc();
+    i->setData(IncrementalGenerator());
+
+    // Execute on BANG
+    bangRuntime->run(g);
+
+    // clone BANG output to CPU
+    auto o = op->getOutput();
+    auto cpuo = o->clone(cpuRuntime);
+    // bangPrintTensor(o);
+    // check results on CPU
+    EXPECT_TRUE(cpuo->equalData(vector<float>{11, 12, 13, 14, 16, 17, 18, 19}));
+}
+} // namespace infini
diff --git a/test/kernels/bang/test_bang_softmax.cc b/test/kernels/bang/test_bang_softmax.cc
new file mode 100644
index 00000000..0ce65776
--- /dev/null
+++ b/test/kernels/bang/test_bang_softmax.cc
@@ -0,0 +1,131 @@
+#include "bang/bang_runtime.h"
+#include "core/graph.h"
+#include "core/kernel.h"
+#include "core/runtime.h"
+#include "operators/softmax.h"
+#include "test.h"
+#include <cmath>
+namespace infini {
+
+TEST(BANG_Softmax, run_axis1) {
+    // Runtime
+    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
+    auto bangRuntime = make_ref<BangRuntimeObj>();
+
+    // Build input data on CPU
+    Tensor inputCpu =
+        make_ref<TensorObj>(Shape{2, 4}, DataType::Float32, cpuRuntime);
+
+    // MLU
+    Graph bangGraph = make_ref<GraphObj>(bangRuntime);
+    auto inputGpu = bangGraph->cloneTensor(inputCpu);
+    auto gpuOp = bangGraph->addOp<SoftmaxObj>(inputGpu, nullptr, 1);
+    bangGraph->dataMalloc();
+    inputGpu->copyin(vector<float>{0, 1, 2, 3, 10000, 10001, 10002, 10003});
+    bangRuntime->run(bangGraph);
+    auto outputGpu = gpuOp->getOutput();
+    auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
+    // Check
+    EXPECT_TRUE(outputGpu2Cpu->equalData(
+        vector<float>{0.032058604, 0.08714432, 0.23688284, 0.6439143,
+                      0.032058604, 0.08714432, 0.23688284, 0.6439143}));
+}
+
+TEST(BANG_Softmax, run_axis0) {
+    // Runtime
+    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
+    auto bangRuntime = make_ref<BangRuntimeObj>();
+
+    // Build input data on CPU
+    Tensor inputCpu =
+        make_ref<TensorObj>(Shape{2, 4}, DataType::Float32, cpuRuntime);
+
+    // MLU
+    Graph bangGraph = make_ref<GraphObj>(bangRuntime);
+    auto inputGpu = bangGraph->cloneTensor(inputCpu);
+    auto gpuOp = bangGraph->addOp<SoftmaxObj>(inputGpu, nullptr, 0);
+    bangGraph->dataMalloc();
+    inputGpu->copyin(vector<float>{0, 1, 2, 3, 10000, 10001, 10002, 10003});
+    bangRuntime->run(bangGraph);
+    auto outputGpu = gpuOp->getOutput();
+    auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
+    // Check
+    EXPECT_TRUE(
+        outputGpu2Cpu->equalData(vector<float>{0., 0., 0., 0., 1, 1, 1, 1}));
+}
+
+TEST(BANG_Softmax2, run_axis1) {
+    // Runtime
+    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
+    auto bangRuntime = make_ref<BangRuntimeObj>();
+
+    // Build input data on CPU
+    Tensor inputCpu =
+        make_ref<TensorObj>(Shape{2, 2, 2, 2}, DataType::Float32, cpuRuntime);
+
+    // MLU
+    Graph bangGraph = make_ref<GraphObj>(bangRuntime);
+    auto inputGpu = bangGraph->cloneTensor(inputCpu);
+    auto gpuOp = bangGraph->addOp<SoftmaxObj>(inputGpu, nullptr, 1);
+    bangGraph->dataMalloc();
+    inputGpu->setData(IncrementalGenerator());
+    bangRuntime->run(bangGraph);
+    auto outputGpu = gpuOp->getOutput();
+    auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
+    // Check
+    EXPECT_TRUE(outputGpu2Cpu->equalData(vector<float>{
+        0.0179862, 0.0179862, 0.0179862, 0.0179862, 0.9820138, 0.9820138,
+        0.9820138, 0.9820138, 0.0179862, 0.0179862, 0.0179862, 0.0179862,
+        0.9820138, 0.9820138, 0.9820138, 0.9820138}));
+}
+
+TEST(BANG_Softmax2, run_axis2) {
+    // Runtime
+    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
+    auto bangRuntime = make_ref<BangRuntimeObj>();
+
+    // Build input data on CPU
+    Tensor inputCpu =
+        make_ref<TensorObj>(Shape{2, 2, 2, 2}, DataType::Float32, cpuRuntime);
+
+    // MLU
+    Graph bangGraph = make_ref<GraphObj>(bangRuntime);
+    auto inputGpu = bangGraph->cloneTensor(inputCpu);
+    auto gpuOp = bangGraph->addOp<SoftmaxObj>(inputGpu, nullptr, 2);
+    bangGraph->dataMalloc();
+    inputGpu->setData(IncrementalGenerator());
+    bangRuntime->run(bangGraph);
+    auto outputGpu = gpuOp->getOutput();
+    auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
+    // Check
+    EXPECT_TRUE(outputGpu2Cpu->equalData(vector<float>{
+        0.1192029, 0.1192029, 0.8807971, 0.8807971, 0.1192029, 0.1192029,
+        0.8807971, 0.8807971, 0.1192029, 0.1192029, 0.8807971, 0.8807971,
+        0.1192029, 0.1192029, 0.8807971, 0.8807971}));
+}
+
+TEST(BANG_Softmax2, run_axis3) {
+    // Runtime
+    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
+    auto bangRuntime = make_ref<BangRuntimeObj>();
+
+    // Build input data on CPU
+    Tensor inputCpu =
+        make_ref<TensorObj>(Shape{2, 2, 2, 2}, DataType::Float32, cpuRuntime);
+
+    // MLU
+    Graph bangGraph = make_ref<GraphObj>(bangRuntime);
+    auto inputGpu = bangGraph->cloneTensor(inputCpu);
+    auto gpuOp = bangGraph->addOp<SoftmaxObj>(inputGpu, nullptr, 3);
+    bangGraph->dataMalloc();
+    inputGpu->setData(IncrementalGenerator());
+    bangRuntime->run(bangGraph);
+    auto outputGpu = gpuOp->getOutput();
+    auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
+    // Check
+    EXPECT_TRUE(outputGpu2Cpu->equalData(vector<float>{
+        0.2689414, 0.7310586, 0.2689414, 0.7310586, 0.2689414, 0.7310586,
+        0.2689414, 0.7310586, 0.2689414, 0.7310586, 0.2689414, 0.7310586,
+        0.2689414, 0.7310586,
0.2689414, 0.7310586}));
+}
+} // namespace infini
diff --git a/test/kernels/cuda/test_cuda_split.cc b/test/kernels/cuda/test_cuda_split.cc
index 2cab944e..43700b77 100644
--- a/test/kernels/cuda/test_cuda_split.cc
+++ b/test/kernels/cuda/test_cuda_split.cc
@@ -73,6 +73,38 @@ TEST(Split, CudaHigh) {
                                44., 45., 46., 47.}));
 }
 
+TEST(Split, SplitWithRatio) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    Graph gCpu = make_ref<GraphObj>(runtime);
+
+    auto input = gCpu->addTensor({2, 6, 2, 1, 2}, DataType::Float32);
+    gCpu->dataMalloc();
+    input->setData(IncrementalGenerator());
+
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cudaRuntime);
+
+    auto inputGpu = gCuda->cloneTensor(input);
+    vector<int> split = {2, 4};
+    auto op = gCuda->addOp<SplitObj>(inputGpu, std::nullopt, 1, split);
+    gCuda->dataMalloc();
+    inputGpu->setData(IncrementalGenerator());
+
+    cudaRuntime->run(gCuda);
+
+    // copy output from CUDA to CPU
+    EXPECT_EQ(op->getOutputs().size(), (size_t)2);
+    auto o0Cpu = gCpu->cloneTensor(op->getOutput(0));
+    auto o1Cpu = gCpu->cloneTensor(op->getOutput(1));
+    EXPECT_TRUE(
+        o0Cpu->equalData(vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 24., 25.,
+                                       26., 27., 28., 29., 30., 31.}));
+    EXPECT_TRUE(o1Cpu->equalData(
+        vector<float>{8.,  9.,  10., 11., 12., 13., 14., 15., 16., 17., 18.,
+                      19., 20., 21., 22., 23., 32., 33., 34., 35., 36., 37.,
+                      38., 39., 40., 41., 42., 43., 44., 45., 46., 47.}));
+}
+
 TEST(Split, Cuda_dim0) {
     Runtime runtime = NativeCpuRuntimeObj::getInstance();
     Graph gCpu = make_ref<GraphObj>(runtime);
diff --git a/test/kernels/nativecpu/test_nativecpu_elementwise.cc b/test/kernels/nativecpu/test_nativecpu_elementwise.cc
new file mode 100644
index 00000000..c6ef1911
--- /dev/null
+++ b/test/kernels/nativecpu/test_nativecpu_elementwise.cc
@@ -0,0 +1,44 @@
+#include "core/graph.h"
+#include "core/runtime.h"
+#include "operators/element_wise.h"
+
+#include "test.h"
+
+namespace infini {
+
+using ExpectOutput = vector<float>;
+template <class T>
+void testElementWiseNativeCpu(
+    const std::function<void(void *, size_t, DataType)> &generator1,
+    const std::function<void(void *, size_t, DataType)> &generator2,
+    const Shape &shape1, const Shape &shape2, const ExpectOutput &ansVec) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    Graph g = make_ref<GraphObj>(runtime);
+    auto t1 = g->addTensor(shape1, DataType::Float32);
+    auto t2 = g->addTensor(shape2, DataType::Float32);
+
+    auto op = g->addOp<T>(t1, t2, nullptr);
+    g->dataMalloc();
+    t1->setData(generator1);
+    t2->setData(generator2);
+
+    runtime->run(g);
+    EXPECT_TRUE(op->getOutput()->equalData(ansVec));
+}
+
+TEST(ElementWise, NativeCpu) {
+    testElementWiseNativeCpu<AddObj>(
+        IncrementalGenerator(), IncrementalGenerator(), Shape{1, 2, 2, 3, 1},
+        Shape{2, 1, 1}, ExpectOutput{0, 1, 2, 4, 5, 6, 6, 7, 8, 10, 11, 12});
+    testElementWiseNativeCpu<MulObj>(
+        IncrementalGenerator(), IncrementalGenerator(), Shape{1, 2, 2, 3, 1},
+        Shape{2, 1, 1}, ExpectOutput{0, 0, 0, 3, 4, 5, 0, 0, 0, 9, 10, 11});
+    testElementWiseNativeCpu<SubObj>(
+        IncrementalGenerator(), IncrementalGenerator(), Shape{1, 2, 2, 3, 1},
+        Shape{2, 1, 1}, ExpectOutput{0, 1, 2, 2, 3, 4, 6, 7, 8, 8, 9, 10});
+    testElementWiseNativeCpu<DivObj>(
+        IncrementalGenerator(), OneGenerator(), Shape{1, 2, 2, 3, 1},
+        Shape{2, 1, 1}, ExpectOutput{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+}
+
+} // namespace infini
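A closing illustration of the broadcast indexing that the reworked CPU element-wise kernel and the test above depend on: locate_index turns a flat output offset into a row-major multi-index, and delocate_index folds that multi-index back into a flat offset of a (possibly broadcast) input, with the modulo pinning size-1 dimensions to index 0. The sketch below is self-contained and slightly simplified: shapes are assumed rank-aligned, plain std::vector<int> stands in for infini::Shape, and both functions are restated here rather than imported from the repository.

#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

using Shape = std::vector<int>;

// Flat offset -> row-major multi-index.
Shape locate_index(size_t flat, const Shape &shape) {
    Shape idx(shape.size());
    for (size_t i = shape.size(); i > 0; --i) {
        idx[i - 1] = flat % shape[i - 1];
        flat /= shape[i - 1];
    }
    return idx;
}

// Output multi-index -> flat offset in a broadcast input. For a size-1
// dimension the modulo yields 0, so its stride never contributes.
size_t delocate_index(const Shape &outIdx, const Shape &shape,
                      const Shape &stride) {
    assert(outIdx.size() == shape.size() && shape.size() == stride.size());
    size_t flat = 0;
    for (size_t i = 0; i < shape.size(); ++i)
        flat += (outIdx[i] % shape[i]) * stride[i];
    return flat;
}

int main() {
    // Input B of shape {2,1,1} broadcast against an output of shape {2,2,3},
    // a rank-aligned version of the first nativecpu element-wise test case.
    Shape b{2, 1, 1}, strideB{1, 1, 1}; // row-major strides of {2,1,1}
    Shape out{2, 2, 3};
    for (size_t i = 0; i < 12; ++i)
        std::cout << delocate_index(locate_index(i, out), b, strideB) << ' ';
    std::cout << '\n'; // prints: 0 0 0 0 0 0 1 1 1 1 1 1
}

Each of B's two elements is read six times, which is exactly the repetition pattern the ExpectOutput vectors in the test encode.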