Forked from jiuyuan/InfiniTensor

Commit c970c93ba1: Merge branch 'master' into ascend
@@ -14,10 +14,10 @@ env:
  protobuf-version: "3.21.12"
  python-version: "3.10"

  resnet-download: https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet18-v2-7.onnx
  inception-download: https://media.githubusercontent.com/media/onnx/models/main/vision/classification/inception_and_googlenet/inception_v2/model/inception-v2-9.onnx
  densenet-download: https://github.com/onnx/models/raw/main/vision/classification/densenet-121/model/densenet-12.onnx
  efficientnet-download: https://github.com/onnx/models/raw/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx
  resnet-download: https://github.com/InfiniTensor/InfiniTensor/releases/download/test-models/resnet18-v2-7.onnx
  inception-download: https://github.com/InfiniTensor/InfiniTensor/releases/download/test-models/inception-v2-9.onnx
  densenet-download: https://github.com/InfiniTensor/InfiniTensor/releases/download/test-models/densenet-12.onnx
  efficientnet-download: https://github.com/InfiniTensor/InfiniTensor/releases/download/test-models/efficientnet-lite4-11.onnx

jobs:
  build:

@@ -14,7 +14,7 @@ if(USE_CUDA)
  message("CMake 3.18 or higher is required for setting CUDAToolkit")
  cmake_minimum_required(VERSION 3.18) # FindCUDAToolkit
else()
  cmake_minimum_required(VERSION 3.12)
  cmake_minimum_required(VERSION 3.17)
endif()

include(CMakeDependentOption)
@@ -22,7 +22,6 @@ project(InfiniTensor C CXX)

cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF)
cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" OFF BUILD_TEST OFF)
cmake_dependent_option(BUILD_TEST_EINNET "Build tests for EINNET" OFF BUILD_TEST OFF)

set(DEFAULT_BUILD_TYPE "RelWithDebInfo")
# Build Type
@@ -96,16 +95,17 @@ add_subdirectory(3rd-party/nlohmann_json_cmake_fetchcontent)
include_directories(3rd-party/nlohmann_json_cmake_fetchcontent/single_include)

# TVM backend
if(BUILD_TEST_EINNET)
  if (NOT TVM_INCLUDE_DIR OR NOT DMLC_INCLUDE_DIR OR NOT DLPACK_INCLUDE_DIR OR NOT DLPACK_INCLUDE_DIR)
    message(FATAL_ERROR "TVM_INCLUDE_DIR, DMLC_INCLUDE_DIR, and DLPACK_INCLUDE_DIR must be set when BUILD_TEST_EINNET is ON")
  endif()
if(BUILD_NNET AND BUILD_TEST)
  # TVM and DMLC for invoking TVM packed functions
  include_directories(${TVM_INCLUDE_DIR})
  include_directories(${DMLC_INCLUDE_DIR})
  include_directories(${DLPACK_INCLUDE_DIR})
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDMLC_USE_LOGGING_LIBRARY=\\\<${TVM_INCLUDE_DIR}/tvm/runtime/logging.h\\\> ")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DINFINI_USE_TVM=1") # Enable TVM codegen kernels
  if (TVM_INCLUDE_DIR AND DMLC_INCLUDE_DIR AND DLPACK_INCLUDE_DIR AND DLPACK_INCLUDE_DIR)
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDMLC_USE_LOGGING_LIBRARY=\\\<${TVM_INCLUDE_DIR}/tvm/runtime/logging.h\\\> ")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DINFINI_USE_TVM=1") # Enable TVM codegen kernels
  else()
    # message(FATAL_ERROR "TVM_INCLUDE_DIR, DMLC_INCLUDE_DIR, and DLPACK_INCLUDE_DIR must be set when BUILD_NNET AND BUILD_TEST is ON")
  endif()
endif()

if(BUILD_TEST)
@@ -119,7 +119,7 @@ if(BUILD_TEST)
  include_directories(3rd-party/googletest/googletest/include)
endif()

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror -Wno-error=deprecated-declarations")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror -Wno-error=deprecated-declarations -Wno-error=pointer-arith")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion

@@ -131,6 +131,8 @@ if(BUILD_NNET)
  add_compile_definitions(BUILD_NNET=1)
  file(GLOB_RECURSE SRC_NNET src/nnet/*.cc)
  list (APPEND SRC ${SRC_NNET})
  # For locating resource files
  set_source_files_properties(src/nnet/test.cc PROPERTIES COMPILE_OPTIONS "-DINFINI_PROJECT_HOME=${CMAKE_CURRENT_SOURCE_DIR}")
endif()

if(USE_CUDA)
@@ -167,7 +169,7 @@ endif()
target_link_libraries(InfiniTensor pybind11::embed)

# TVM backend
if(BUILD_TEST_EINNET)
if(BUILD_NNET AND BUILD_TEST AND TVM_LIB_DIR)
  target_link_libraries(InfiniTensor ${TVM_LIB_DIR}/libtvm.so)
endif()

@@ -249,6 +251,7 @@ if(USE_BANG)
  find_library(CAMBRICON_CNNL libcnnl.so "${NEUWARE_HOME}/lib64")
  find_library(CAMBRICON_CNRT libcnrt.so "${NEUWARE_HOME}/lib64")
  find_library(CAMBRICON_CNDRV libcndrv.so "${NEUWARE_HOME}/lib64")
  find_library(CAMBRICON_CNCL libcncl.so "${NEUWARE_HOME}/lib64")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror")

  if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH}))
@@ -265,7 +268,13 @@ if(USE_BANG)
  # BangC Kernels
  ################################################################################

  target_link_libraries(InfiniTensor ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++)
  target_link_libraries(InfiniTensor ${CAMBRICON_CNCL} ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++)
  if (BUILD_DIST)
    message(STATUS "Add BUILD_DIST, use CNCL with BANG")
    add_compile_definitions(INFINI_USE_CNCL=1)
  endif()
endif()

if(USE_KUNLUN)
@@ -357,6 +366,7 @@ if(BUILD_TEST)
  endif()
  if (USE_BANG)
    build_test(test/kernels/bang/*.cc)
    build_test(test/bang/*.cc)
  endif()
  if (USE_KUNLUN)
    build_test(test/kernels/kunlun/*.cc)
@@ -371,7 +381,7 @@ if(BUILD_TEST)
  if(BUILD_TEST_PET)
    build_test(test/pet/*.cc)
  endif()
  if(BUILD_TEST_EINNET)
  if(BUILD_NNET AND BUILD_TEST)
    build_test(test/nnet/test_*.cc)

    # Build expression reader

Makefile
@@ -8,12 +8,13 @@ ASCEND ?= OFF
INTELCPU ?= off
BACKTRACE ?= ON
TEST ?= ON
NNET ?= OFF
FORMAT_ORIGIN ?=
# Docker build options
DOCKER_NAME ?= infinitensor
DOCKER_IMAGE_NAME ?= infinitensor
DOCKER_FILE ?= infinitensor_ubuntu_22.04.dockerfile
DOCKER_RUN_OPTION ?=
DOCKER_RUN_OPTION ?=

# CUDA option.
ifeq ($(CUDA), ON)
@@ -23,7 +24,6 @@ ifeq ($(CUDA), ON)
DOCKER_RUN_OPTION += --gpus all -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 -v `pwd`:`pwd` -w `pwd`
endif

CMAKE_OPT = -DCMAKE_BUILD_TYPE=$(TYPE)
CMAKE_OPT += -DUSE_CUDA=$(CUDA)
CMAKE_OPT += -DUSE_BANG=$(BANG)
@@ -31,6 +31,7 @@ CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN)
CMAKE_OPT += -DUSE_ASCEND=$(ASCEND)
CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE)
CMAKE_OPT += -DBUILD_TEST=$(TEST)
CMAKE_OPT += -DBUILD_NNET=$(NNET)

ifeq ($(INTELCPU), ON)
CMAKE_OPT += -DUSE_INTELCPU=ON -DCMAKE_CXX_COMPILER=dpcpp
@@ -62,7 +63,7 @@ test-api:
	@echo
	python3 pyinfinitensor/tests/test_api.py

docker-build:
docker-build:
	docker build -f scripts/dockerfile/$(DOCKER_FILE) -t $(DOCKER_NAME) .

docker-run:

@@ -73,5 +74,3 @@ docker-start:

docker-exec:
	docker exec -it $(DOCKER_IMAGE_NAME) bash

@@ -0,0 +1,76 @@
SET(CNCL_LIB_SEARCH_PATHS $ENV{NEUWARE_HOME}/lib64)
SET(CNCL_INCLUDE_SEARCH_PATHS $ENV{NEUWARE_HOME}/include)

set(CNCL_INCLUDE_DIR $ENV{NEUWARE_HOME}/include)
set(CNCL_LIB_DIR $ENV{NEUWARE_HOME}/lib64)
set(CNCL_VERSION $ENV{CNCL_VERSION} CACHE STRING "Version of CNCL to build with")

if ($ENV{CNCL_ROOT_DIR})
  message(WARNING "CNCL_ROOT_DIR is deprecated. Please set CNCL_ROOT instead.")
endif()
list(APPEND CNCL_ROOT $ENV{CNCL_ROOT_DIR} ${MLU_TOOLKIT_ROOT_DIR})
# Compatibility layer for CMake <3.12. CNCL_ROOT is accounted for in search paths and libraries for CMake >=3.12.
list(APPEND CMAKE_PREFIX_PATH ${CNCL_ROOT})

find_path(CNCL_INCLUDE_DIRS
  NAMES cncl.h
  HINTS ${CNCL_INCLUDE_DIR})

if (USE_STATIC_CNCL)
  MESSAGE(STATUS "USE_STATIC_CNCL is set. Linking with static CNCL library.")
  SET(CNCL_LIBNAME "CNCL_static")
  if (CNCL_VERSION) # Prefer the versioned library if a specific CNCL version is specified
    set(CMAKE_FIND_LIBRARY_SUFFIXES ".a.${CNCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES})
  endif()
else()
  SET(CNCL_LIBNAME "cncl")
  if (CNCL_VERSION) # Prefer the versioned library if a specific CNCL version is specified
    set(CMAKE_FIND_LIBRARY_SUFFIXES ".so.${CNCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES})
  endif()
endif()

find_library(CNCL_LIBRARIES
  NAMES ${CNCL_LIBNAME}
  HINTS ${CNCL_LIB_DIR})

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(CNCL DEFAULT_MSG CNCL_INCLUDE_DIRS CNCL_LIBRARIES)

if(CNCL_FOUND) # obtaining CNCL version and some sanity checks
  set(CNCL_HEADER_FILE "${CNCL_INCLUDE_DIRS}/cncl.h")
  message(STATUS "Determining CNCL version from ${CNCL_HEADER_FILE}...")
  set(OLD_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES})
  list(APPEND CMAKE_REQUIRED_INCLUDES ${CNCL_INCLUDE_DIRS})
  include(CheckCXXSymbolExists)
  check_cxx_symbol_exists(CNCL_VERSION_CODE cncl.h CNCL_VERSION_DEFINED)

  if (CNCL_VERSION_DEFINED)
    set(file "${PROJECT_BINARY_DIR}/detect_cncl_version.cc")
    file(WRITE ${file} "
      #include <iostream>
      #include <cncl.h>
      int main()
      {
        std::cout << CNCL_MAJOR << '.' << CNCL_MINOR << '.' << CNCL_PATCH << std::endl;
        int x;
        CNCLGetVersion(&x);
        return x == CNCL_VERSION_CODE;
      }
    ")
    try_run(CNCL_VERSION_MATCHED compile_result ${PROJECT_BINARY_DIR} ${file}
      RUN_OUTPUT_VARIABLE CNCL_VERSION_FROM_HEADER
      CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CNCL_INCLUDE_DIRS}"
      LINK_LIBRARIES ${CNCL_LIBRARIES})
    if (NOT CNCL_VERSION_MATCHED)
      message(FATAL_ERROR "Found CNCL header version and library version do not match! \
(include: ${CNCL_INCLUDE_DIRS}, library: ${CNCL_LIBRARIES}) Please set CNCL_INCLUDE_DIR and CNCL_LIB_DIR manually.")
    endif()
    message(STATUS "CNCL version: ${CNCL_VERSION_FROM_HEADER}")
  else()
    # message(STATUS "CNCL version < 2.3.5-5")
  endif ()
  set(CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES})

  message(STATUS "Found CNCL (include: ${CNCL_INCLUDE_DIRS}, library: ${CNCL_LIBRARIES})")
  mark_as_advanced(CNCL_ROOT_DIR CNCL_INCLUDE_DIRS CNCL_LIBRARIES)
endif()

@@ -1 +1 @@
Subproject commit 51d3105277f3774ed31c02ed4cd11fa92925af77
Subproject commit b896cec2dba5b8522b141ac4f89eb43074ee1b98
@ -0,0 +1,196 @@
|
|||
import argparse
|
||||
import os
|
||||
import time
|
||||
import multiprocessing as mp
|
||||
from pyinfinitensor.onnx import OnnxStub, backend
|
||||
import onnx
|
||||
from onnx.shape_inference import infer_shapes_path
|
||||
import numpy as np
|
||||
from parallel_opt import parallel_model
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="launch distributed infinitensor")
|
||||
parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
|
||||
parser.add_argument(
|
||||
"--nproc_per_node", type=int, default=2, help="number of processes per node"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--name", type=str, default="test", help="name of this instance."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model", type=str, default="/data/onnx_models/llama2/llama_bs1_seq1024.onnx",
|
||||
help="path to the ONNX model file."
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
|
||||
parser.add_argument("--length", type=int, default=1, help="sequence length.")
|
||||
parser.add_argument(
|
||||
"--gen_std",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="whether to generate the standard results.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print("arg setting: ", args)
|
||||
return (
|
||||
args.num_nodes,
|
||||
args.nproc_per_node,
|
||||
args.name,
|
||||
args.model,
|
||||
args.batch_size,
|
||||
args.length,
|
||||
args.gen_std,
|
||||
)
|
||||
|
||||
|
||||
def run_model(model, runtime, world_size=1, rank=0, n=10):
|
||||
stub = OnnxStub(model, runtime)
|
||||
load_inputs(stub, world_size, rank)
|
||||
# stub.tune()
|
||||
stub.run()
|
||||
# get outputs
|
||||
time.sleep(0.01)
|
||||
outputs = next(stub.outputs.values().__iter__()).copyout_numpy()
|
||||
|
||||
# bench
|
||||
begin = time.time()
|
||||
for _ in range(n):
|
||||
stub.run()
|
||||
end = time.time()
|
||||
avg_time = (end - begin) / n
|
||||
print(f"average time: {avg_time}")
|
||||
return outputs
|
||||
|
||||
|
||||
def run_and_compare(name, model, runtime, world_size=1, rank = 0):
|
||||
results = np.load(f"./data/output.npy")
|
||||
outputs = run_model(model, runtime, world_size, rank)
|
||||
print("answer argmax:", np.argmax(results))
|
||||
print("output argmax:", np.argmax(outputs))
|
||||
#np.testing.assert_allclose(outputs, results, rtol=1e-3, atol=1e-3)
|
||||
getDiff(results, outputs)
|
||||
|
||||
|
||||
def start_worker(
|
||||
name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
|
||||
):
|
||||
dist_name = name + "_dist"
|
||||
model = parallel_model(model, world_size, rank)
|
||||
extern_path = f"./{dist_name}_rank{rank}.pb"
|
||||
if os.path.exists(extern_path):
|
||||
os.remove(extern_path)
|
||||
onnx.save_model(
|
||||
model,
|
||||
f"./{dist_name}_rank{rank}.onnx",
|
||||
save_as_external_data=True,
|
||||
location=extern_path,
|
||||
)
|
||||
infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
|
||||
runtime = backend.BangRuntime(local_rank)
|
||||
# print("init comm")
|
||||
runtime.init_comm(
|
||||
dist_name,
|
||||
world_size,
|
||||
rank,
|
||||
)
|
||||
run_and_compare(name, model, runtime, world_size, rank)
|
||||
|
||||
|
||||
def start_single(name, model):
|
||||
runtime = backend.BangRuntime(0)
|
||||
run_and_compare(name, model, runtime)
|
||||
|
||||
|
||||
def generate_input_output(model):
|
||||
os.makedirs(os.path.dirname("./data/"), exist_ok=True)
|
||||
runtime = backend.BangRuntime(0)
|
||||
stub = OnnxStub(model, runtime)
|
||||
position_id = 0
|
||||
for i, (name, tensor) in enumerate(stub.inputs.items()):
|
||||
input = tensor.copyout_numpy()
|
||||
if np.issubdtype(input.dtype, np.integer):
|
||||
if input.size == 1:
|
||||
# input = np.array([position_id])
|
||||
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
|
||||
else:
|
||||
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
|
||||
elif input.dtype == np.bool_:
|
||||
input = np.random.randint(0,2,size=input.shape) > 0
|
||||
else:
|
||||
if i == 0:
|
||||
input = np.ones(input.shape).astype(input.dtype)
|
||||
position_id = input.shape[-1] - 1
|
||||
else:
|
||||
input = np.random.rand(*input.shape).astype(input.dtype)
|
||||
tensor.copyin_numpy(input)
|
||||
np.save(f"./data/input_{i}", input)
|
||||
stub.run()
|
||||
time.sleep(0.01)
|
||||
output = next(stub.outputs.values().__iter__()).copyout_numpy()
|
||||
if np.isnan(output).any():
|
||||
print("Nan in output")
|
||||
np.save(f"./data/output", output)
|
||||
|
||||
|
||||
def load_inputs(stub, world_size=1, rank=0):
|
||||
for i, (name, tensor) in enumerate(stub.inputs.items()):
|
||||
input = np.load(f"./data/input_{i}.npy")
|
||||
if all(x == y for x,y in zip(input.shape,tensor.shape())):
|
||||
tensor.copyin_numpy(input)
|
||||
else:
|
||||
tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
|
||||
|
||||
def getDiff(base, test):
|
||||
absolute_diff = np.abs(np.subtract(base, test))
|
||||
max_absolute_diff = np.max(absolute_diff)
|
||||
|
||||
baseCopy = base.astype(np.float64).ravel()
|
||||
testCopy = test.astype(np.float64).ravel()
|
||||
upValue = np.sum(np.abs(baseCopy - testCopy))
|
||||
downValue = np.sum(np.abs(baseCopy)) + np.float64(1e-9)
|
||||
max_relative_diff = upValue / downValue
|
||||
print(f"Max absolute difference: {max_absolute_diff}\n"
|
||||
f"Max relative difference: {max_relative_diff}")
|
||||
return max_absolute_diff, max_relative_diff
|
||||
|
||||
|
||||
def main():
|
||||
nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args()
|
||||
|
||||
model = onnx.load(model_path)
|
||||
|
||||
# generate standard output
|
||||
if gen_std:
|
||||
print("Generate inputs and outputs.")
|
||||
p = mp.Process(target=generate_input_output, args=[model])
|
||||
p.start()
|
||||
p.join()
|
||||
return
|
||||
|
||||
# run single process.
|
||||
# use standalone process to isolate cuda.
|
||||
print("run model by single MLU.")
|
||||
p = mp.Process(target=start_single, args=(name, model))
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
# run distributed parallel.
|
||||
world_size = nnodes * nproc_per_node
|
||||
print(f"run model by {world_size} MLUs in parallel.")
|
||||
workers = [
|
||||
mp.Process(
|
||||
target=start_worker,
|
||||
args=(name, world_size, rank, rank % nproc_per_node, model),
|
||||
)
|
||||
for rank in range(world_size)
|
||||
]
|
||||
|
||||
for w in workers:
|
||||
w.start()
|
||||
|
||||
for w in workers:
|
||||
w.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -5,6 +5,7 @@ import multiprocessing as mp
|
|||
from pyinfinitensor.onnx import OnnxStub, backend
|
||||
import onnx
|
||||
from onnx.external_data_helper import convert_model_to_external_data
|
||||
from onnx.shape_inference import infer_shapes_path
|
||||
import numpy as np
|
||||
from parallel_opt import parallel_model
|
||||
|
||||
|
@ -44,16 +45,18 @@ def parse_args():
|
|||
)
|
||||
|
||||
|
||||
def run_model(model, runtime, inputs: np.array, n=20):
|
||||
def run_model(model, runtime, inputs, n=10):
|
||||
stub = OnnxStub(model, runtime)
|
||||
next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs)
|
||||
stub.tune()
|
||||
for tensor, input in zip(stub.inputs.values(), inputs):
|
||||
tensor.copyin_numpy(input)
|
||||
# stub.tune()
|
||||
stub.run()
|
||||
# get outputs
|
||||
outputs = np.array(next(stub.outputs.items().__iter__())[1].copyout_float())
|
||||
outputs = next(stub.outputs.values().__iter__()).copyout_numpy()
|
||||
|
||||
# bench
|
||||
next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs)
|
||||
for tensor, input in zip(stub.inputs.values(), inputs):
|
||||
tensor.copyin_numpy(input)
|
||||
begin = time.time()
|
||||
for _ in range(n):
|
||||
stub.run()
|
||||
|
@ -64,13 +67,12 @@ def run_model(model, runtime, inputs: np.array, n=20):
|
|||
|
||||
|
||||
def run_and_compare(name, model, runtime):
|
||||
data = np.load(f"{name}_inputs.npy")
|
||||
input_ids = np.load(f"{name}_inputs.npy")
|
||||
position_ids = np.arange(input_ids.shape[-1])
|
||||
results = np.load(f"{name}_results.npy")
|
||||
outputs = run_model(model, runtime, data)
|
||||
print("outputs sum:", outputs.sum())
|
||||
print("max abs diff:", abs(outputs - results).max())
|
||||
print("max rel diff:", abs((outputs - results) / results).max())
|
||||
# assert np.allclose(outputs, results, rtol=1e-3, atol=1e-6)
|
||||
outputs = run_model(model, runtime, (input_ids, position_ids))
|
||||
print("outputs abs mean:", abs(outputs).mean())
|
||||
np.testing.assert_allclose(outputs, results, rtol=1e-6, atol=1e-3)
|
||||
|
||||
|
||||
def start_worker(
|
||||
|
@ -81,14 +83,13 @@ def start_worker(
|
|||
extern_path = f"./{dist_name}_rank{rank}.pb"
|
||||
if os.path.exists(extern_path):
|
||||
os.remove(extern_path)
|
||||
convert_model_to_external_data(
|
||||
onnx.save_model(
|
||||
model,
|
||||
all_tensors_to_one_file=True,
|
||||
f"./{dist_name}_rank{rank}.onnx",
|
||||
save_as_external_data=True,
|
||||
location=extern_path,
|
||||
size_threshold=1024,
|
||||
convert_attribute=False,
|
||||
)
|
||||
onnx.save(model, f"./{dist_name}_rank{rank}.onnx")
|
||||
infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
|
||||
runtime = backend.CudaRuntime(local_rank)
|
||||
# print("init comm")
|
||||
runtime.init_comm(
|
||||
|
@ -106,10 +107,12 @@ def start_single(name, model):
|
|||
|
||||
def gen_standard(name, model, voc_size, bs, len):
|
||||
# generate standard results
|
||||
data = np.random.randint(0, voc_size, (bs, len), dtype=np.int32)
|
||||
np.save(f"{name}_inputs", data)
|
||||
input_ids = np.random.randint(0, voc_size, (bs, len))
|
||||
position_ids = np.arange(len)
|
||||
np.save(f"{name}_inputs", input_ids)
|
||||
runtime = backend.CudaRuntime(0)
|
||||
outputs = run_model(model, runtime, data, 1)
|
||||
outputs = run_model(model, runtime, (input_ids, position_ids), 1)
|
||||
print("outputs abs mean:", abs(outputs).mean())
|
||||
np.save(f"{name}_results", outputs)
|
||||
|
||||
|
||||
|
@ -128,12 +131,14 @@ def main():
|
|||
|
||||
# run single process.
|
||||
# use standalone process to isolate cuda.
|
||||
print("run model by single GPU.")
|
||||
p = mp.Process(target=start_single, args=(name, model))
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
# run distributed parallel.
|
||||
world_size = nnodes * nproc_per_node
|
||||
print(f"run model by {world_size} GPU in parallel.")
|
||||
workers = [
|
||||
mp.Process(
|
||||
target=start_worker,
|
||||
|
|
|
@ -11,6 +11,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
|
|||
vinfo = {info.name: info for info in model.graph.value_info}
|
||||
vinfo.update({info.name: info for info in model.graph.input})
|
||||
vinfo.update({info.name: info for info in model.graph.output})
|
||||
output = {info.name: info for info in model.graph.output}
|
||||
place: Dict[str, Placement] = {}
|
||||
nodes: List[NodeProto] = []
|
||||
|
||||
|
@ -56,7 +57,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
|
|||
ndim = len(vinfo[output].type.tensor_type.shape.dim)
|
||||
out_plc = Shard(ndim - 1) if in_plc.is_replicate() else _Partial()
|
||||
place[node.output[0]] = out_plc
|
||||
|
||||
|
||||
def shard_concat(node: NodeProto):
|
||||
# hack for kvcache
|
||||
in_plc = place[node.input[1]]
|
||||
|
@ -114,7 +115,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
|
|||
assert out_dims[s_dim] % tp_world_size == 0, out_dims
|
||||
out_dims[s_dim] //= tp_world_size
|
||||
# if ONNX uses the same tensor for multiple Reshape Nodes, then rename it to distinguish it from others.
|
||||
# node.input[1] = node.output[0] + "_shape"
|
||||
node.input[1] = node.output[0] + "_shape"
|
||||
data[node.input[1]] = numpy_helper.from_array(out_dims, name=node.input[1])
|
||||
place[node.output[0]] = Shard(s_dim)
|
||||
|
||||
|
@ -136,7 +137,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
|
|||
place[node.output[0]] = Shard(list(perm).index(plc.dim))
|
||||
|
||||
def shard_node(node: NodeProto):
|
||||
if node.op_type in ["Relu", "Tanh", "Softmax"]:
|
||||
if node.op_type in ["Relu", "Tanh", "Softmax", "Cast"]:
|
||||
place[node.output[0]] = place[node.input[0]]
|
||||
elif node.op_type in ["Where"]:
|
||||
place[node.output[0]] = place[node.input[1]]
|
||||
|
@ -154,7 +155,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
|
|||
), f"{place[node.input[0]]} != {place[node.input[1]]}"
|
||||
place[node.output[0]] = place[node.input[0]]
|
||||
elif node.op_type == "Concat":
|
||||
shard_concat(node)
|
||||
shard_concat(node)
|
||||
|
||||
def find_successor(op_type: str, idx: int, search_limit: int = 1):
|
||||
for node in model.graph.node[idx + 1 : idx + 1 + search_limit]:
|
||||
|
@ -175,6 +176,16 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
|
|||
if (node.op_type == "MatMul" or node.op_type == "Gemm") and any(
|
||||
input in data for input in node.input
|
||||
):
|
||||
# FIXME(constroy): the last MatMul should not be sharded as TP.
|
||||
if (
|
||||
node.output[0] in output
|
||||
or (
|
||||
index + 1 < len(model.graph.node)
|
||||
and model.graph.node[index + 1].output[0]
|
||||
)
|
||||
in output
|
||||
):
|
||||
continue
|
||||
groups = 1
|
||||
# If the Gemm or Matmul is followed by a split, then the inputs are concatenated by groups
|
||||
split_node = find_successor("Split", index, search_limit=2)
|
||||
|
@ -218,7 +229,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
|
|||
new_input = []
|
||||
for info in model.graph.input:
|
||||
new_input.append(vinfo[info.name])
|
||||
|
||||
|
||||
graph = helper.make_graph(
|
||||
nodes,
|
||||
model.graph.name + f"_{tp_rank}",
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
|
||||
import paddle
|
||||
import paddle.vision.transforms as T
|
||||
from paddle.vision.datasets import Cifar10
|
||||
from pyinfinitensor.onnx import OnnxStub, backend
|
||||
import onnx
|
||||
import itertools
|
||||
|
||||
def run_cifar_train_and_infer():
|
||||
|
||||
paddle.device.set_device("gpu")
|
||||
|
||||
transform = T.Compose(
|
||||
[
|
||||
T.Resize(224),
|
||||
T.ToTensor(),
|
||||
T.Normalize(
|
||||
mean=[0.5, 0.5, 0.5],
|
||||
std=[0.5, 0.5, 0.5],
|
||||
to_rgb=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Download the dataset and initialize the DataSet
train_dataset = paddle.vision.datasets.Cifar10(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.Cifar10(mode='test', transform=transform)

# Build and initialize the network
densenet = paddle.vision.models.DenseNet(num_classes=10)
model = paddle.Model(densenet)

# Prepare the training configuration: loss function, optimizer and evaluation metric
model.prepare(paddle.optimizer.Adam(parameters=model.parameters()),
              paddle.nn.CrossEntropyLoss(),
              paddle.metric.Accuracy())

# Train the model
model.fit(train_dataset, epochs=5, batch_size=64, verbose=1)
# Evaluate the model
model.evaluate(test_dataset, batch_size=64, verbose=1)

# export to ONNX
save_path = 'onnx.save/densenet' # path where the exported model is saved
x_spec = paddle.static.InputSpec([1, 3, 224, 224], 'float32', 'x') # Specify the input shape and dtype for the model; a Tensor or an InputSpec is accepted, and InputSpec supports dynamic shapes.
paddle.onnx.export(densenet, save_path, input_spec=[x_spec], opset_version=11)

# Load the ONNX model into InfiniTensor
model_path = save_path + ".onnx"
onnx_model = onnx.load(model_path)
gofusion_model = OnnxStub(onnx_model, backend.cuda_runtime())
model = gofusion_model
model.init()

# Run inference
cifar10_test = Cifar10(
|
||||
mode="test",
|
||||
transform=transform, # apply transform to every image
|
||||
backend="cv2", # use OpenCV as image transform backend
|
||||
)
|
||||
batch_size = 1
|
||||
total_size = 0
|
||||
total_acc = 0.0
|
||||
for data in itertools.islice(iter(cifar10_test), 10000):
|
||||
images, labels = data
|
||||
next(model.inputs.items().__iter__())[1].copyin_float(images.reshape([3*224*224]).tolist())
|
||||
model.run()
|
||||
outputs = next(model.outputs.items().__iter__())[1].copyout_float()
|
||||
outputs = paddle.to_tensor(outputs)
|
||||
outputs = paddle.reshape(outputs, (1, 10))
|
||||
labels = paddle.to_tensor(labels)
|
||||
labels = paddle.reshape(labels, (1,1))
|
||||
acc = paddle.metric.accuracy(outputs, labels)
|
||||
total_acc += acc
|
||||
total_size += batch_size
|
||||
print("test acc: {}".format(total_acc.numpy() / total_size))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_cifar_train_and_infer()
|
|
@ -0,0 +1,80 @@
|
|||
import paddle
|
||||
import paddle.vision.transforms as T
|
||||
from paddle.vision.datasets import Cifar10
|
||||
from pyinfinitensor.onnx import OnnxStub, backend
|
||||
import onnx
|
||||
import itertools
|
||||
|
||||
def run_cifar_train_and_infer():
|
||||
|
||||
paddle.device.set_device("gpu")
|
||||
|
||||
transform = T.Compose(
|
||||
[
|
||||
T.Resize(224),
|
||||
T.ToTensor(),
|
||||
T.Normalize(
|
||||
mean=[0.5, 0.5, 0.5],
|
||||
std=[0.5, 0.5, 0.5],
|
||||
to_rgb=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Download the dataset and initialize the DataSet
train_dataset = paddle.vision.datasets.Cifar10(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.Cifar10(mode='test', transform=transform)

# Build and initialize the network
inception = paddle.vision.models.InceptionV3(num_classes=10)
model = paddle.Model(inception)

# Prepare the training configuration: loss function, optimizer and evaluation metric
model.prepare(paddle.optimizer.Adam(parameters=model.parameters()),
              paddle.nn.CrossEntropyLoss(),
              paddle.metric.Accuracy())

# Train the model
model.fit(train_dataset, epochs=5, batch_size=64, verbose=1)
# Evaluate the model
model.evaluate(test_dataset, batch_size=64, verbose=1)

# export to ONNX
save_path = 'onnx.save/inception' # path where the exported model is saved
x_spec = paddle.static.InputSpec([1, 3, 224, 224], 'float32', 'x') # Specify the input shape and dtype for the model; a Tensor or an InputSpec is accepted, and InputSpec supports dynamic shapes.
paddle.onnx.export(inception, save_path, input_spec=[x_spec], opset_version=11)

# Load the ONNX model into InfiniTensor
model_path = save_path + ".onnx"
onnx_model = onnx.load(model_path)
gofusion_model = OnnxStub(onnx_model, backend.cuda_runtime())
model = gofusion_model
model.init()

# Run inference
cifar10_test = Cifar10(
|
||||
mode="test",
|
||||
transform=transform, # apply transform to every image
|
||||
backend="cv2", # use OpenCV as image transform backend
|
||||
)
|
||||
batch_size = 1
|
||||
total_size = 0
|
||||
total_acc = 0.0
|
||||
for data in itertools.islice(iter(cifar10_test), 10000):
|
||||
images, labels = data
|
||||
next(model.inputs.items().__iter__())[1].copyin_float(images.reshape([3*224*224]).tolist())
|
||||
model.run()
|
||||
outputs = next(model.outputs.items().__iter__())[1].copyout_float()
|
||||
outputs = paddle.to_tensor(outputs)
|
||||
outputs = paddle.reshape(outputs, (1, 10))
|
||||
labels = paddle.to_tensor(labels)
|
||||
labels = paddle.reshape(labels, (1,1))
|
||||
acc = paddle.metric.accuracy(outputs, labels)
|
||||
total_acc += acc
|
||||
total_size += batch_size
|
||||
print("test acc: {}".format(total_acc.numpy() / total_size))
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_cifar_train_and_infer()
|
|
@@ -0,0 +1,31 @@
## Description

This doc explains how to run the paddle*.py examples on your machine. If your model runs on hardware other than Nvidia GPUs, you may need to make some changes.

## What do we do in paddle*.py files?

1. Train the model and evaluate it on the Cifar10 dataset

2. Export the paddle model to an ONNX model

3. Load the ONNX model, run inference with InfiniTensor and compute the inference accuracy (a minimal sketch of steps 2 and 3 is shown below)
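
The core export-and-infer flow of these scripts looks roughly like the sketch below. It is illustrative only: `net` stands for the trained paddle network, `image` for one preprocessed test image, and the save path is an example.

```
import onnx
import paddle
from pyinfinitensor.onnx import OnnxStub, backend

# Step 2: export the trained paddle model to ONNX (path and input spec are examples).
x_spec = paddle.static.InputSpec([1, 3, 224, 224], "float32", "x")
paddle.onnx.export(net, "onnx.save/model", input_spec=[x_spec], opset_version=11)

# Step 3: load the ONNX model and run it with InfiniTensor.
stub = OnnxStub(onnx.load("onnx.save/model.onnx"), backend.cuda_runtime())
stub.init()
next(iter(stub.inputs.items()))[1].copyin_float(image.reshape([3 * 224 * 224]).tolist())
stub.run()
logits = next(iter(stub.outputs.items()))[1].copyout_float()
```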
## Command

1. Go to the `/examples/python` folder

2. Run the following command

   1. ```
      python paddle_resnet.py
      python paddle_densenet.py
      python paddle_inception.py
      ```

## What should I do if I use another device (MLU, XPU, NPU)?

You need to change this code:

```
paddle.device.set_device("gpu") # Change gpu to mlu, xpu or npu
```
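
If you are not running on CUDA, you will likely also need to create the matching InfiniTensor runtime instead of `backend.cuda_runtime()` when building the `OnnxStub`. As a sketch (the device id 0 is just an example), the Cambricon launch script in this repository constructs its runtime as follows; check the `backend` module for the runtime that matches your hardware.

```
runtime = backend.BangRuntime(0)  # pick the runtime that matches your device
stub = OnnxStub(onnx_model, runtime)
```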
@ -0,0 +1,81 @@
|
|||
|
||||
import paddle
|
||||
import paddle.vision.transforms as T
|
||||
from paddle.vision.datasets import Cifar10
|
||||
from pyinfinitensor.onnx import OnnxStub, backend
|
||||
import onnx
|
||||
import itertools
|
||||
from paddle.vision.models.resnet import BasicBlock
|
||||
|
||||
def run_cifar_train_and_infer():
|
||||
|
||||
paddle.device.set_device("gpu")
|
||||
|
||||
transform = T.Compose(
|
||||
[
|
||||
T.Resize(224),
|
||||
T.ToTensor(),
|
||||
T.Normalize(
|
||||
mean=[0.5, 0.5, 0.5],
|
||||
std=[0.5, 0.5, 0.5],
|
||||
to_rgb=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Download the dataset and initialize the DataSet
train_dataset = paddle.vision.datasets.Cifar10(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.Cifar10(mode='test', transform=transform)

# Build and initialize the network
resnet = paddle.vision.models.ResNet(BasicBlock, depth=18, num_classes=10)
model = paddle.Model(resnet)

# Prepare the training configuration: loss function, optimizer and evaluation metric
model.prepare(paddle.optimizer.Adam(parameters=model.parameters()),
              paddle.nn.CrossEntropyLoss(),
              paddle.metric.Accuracy())

# Train the model
model.fit(train_dataset, epochs=5, batch_size=64, verbose=1)
# Evaluate the model
model.evaluate(test_dataset, batch_size=64, verbose=1)

# export to ONNX
save_path = 'onnx.save/resnet' # path where the exported model is saved
x_spec = paddle.static.InputSpec([1, 3, 224, 224], 'float32', 'x') # Specify the input shape and dtype for the model; a Tensor or an InputSpec is accepted, and InputSpec supports dynamic shapes.
paddle.onnx.export(resnet, save_path, input_spec=[x_spec], opset_version=11)

# Load the ONNX model into InfiniTensor
model_path = save_path + ".onnx"
onnx_model = onnx.load(model_path)
gofusion_model = OnnxStub(onnx_model, backend.cuda_runtime())
model = gofusion_model
model.init()

# Run inference
cifar10_test = Cifar10(
|
||||
mode="test",
|
||||
transform=transform, # apply transform to every image
|
||||
backend="cv2", # use OpenCV as image transform backend
|
||||
)
|
||||
batch_size = 1
|
||||
total_size = 0
|
||||
total_acc = 0.0
|
||||
for data in itertools.islice(iter(cifar10_test), 10000):
|
||||
images, labels = data
|
||||
next(model.inputs.items().__iter__())[1].copyin_float(images.reshape([3*224*224]).tolist())
|
||||
model.run()
|
||||
outputs = next(model.outputs.items().__iter__())[1].copyout_float()
|
||||
outputs = paddle.to_tensor(outputs)
|
||||
outputs = paddle.reshape(outputs, (1, 10))
|
||||
labels = paddle.to_tensor(labels)
|
||||
labels = paddle.reshape(labels, (1,1))
|
||||
acc = paddle.metric.accuracy(outputs, labels)
|
||||
total_acc += acc
|
||||
total_size += batch_size
|
||||
print("test acc: {}".format(total_acc.numpy() / total_size))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_cifar_train_and_infer()
|
|
@ -7,16 +7,19 @@ namespace infini {
|
|||
class BangRuntimeObj : public RuntimeObj {
|
||||
private:
|
||||
cnnlHandle_t cnnl;
|
||||
cnrtQueue_t queue;
|
||||
std::unique_ptr<CommunicatorObj> comm;
|
||||
BangPtr workspace;
|
||||
size_t workspaceSize;
|
||||
mutable size_t cursor;
|
||||
|
||||
public:
|
||||
BangRuntimeObj() : RuntimeObj(Device::BANG) {
|
||||
explicit BangRuntimeObj(int deviceId = 0)
|
||||
: RuntimeObj(Device::BANG, deviceId) {
|
||||
cnInit(0);
|
||||
CNdev dev;
|
||||
cnDeviceGet(&dev, 0);
|
||||
cnDeviceGet(&dev, deviceId);
|
||||
checkBangError(cnrtSetDevice(dev));
|
||||
cnrtQueue_t queue;
|
||||
checkBangError(cnrtQueueCreate(&queue));
|
||||
|
||||
checkCnnlError(cnnlCreate(&cnnl));
|
||||
|
@ -24,10 +27,12 @@ class BangRuntimeObj : public RuntimeObj {
|
|||
// 10GB for Longformer
|
||||
// size_t longformerNum = 3lu * (1 << 30);
|
||||
workspaceSize = 7ll << 30; // 7 GB
|
||||
cursor = 0;
|
||||
workspace = alloc(workspaceSize);
|
||||
}
|
||||
virtual ~BangRuntimeObj() {
|
||||
dealloc(workspace);
|
||||
checkBangError(cnrtQueueDestroy(queue));
|
||||
checkCnnlError(cnnlDestroy(cnnl));
|
||||
}
|
||||
string toString() const override;
|
||||
|
@ -45,10 +50,15 @@ class BangRuntimeObj : public RuntimeObj {
|
|||
void dealloc(void *ptr) override { checkBangError(cnrtFree(ptr)); }
|
||||
cnnlHandle_t cnnlHandle() const { return cnnl; }
|
||||
BangPtr getWorkspace(size_t size) const {
|
||||
IT_ASSERT(size <= workspaceSize);
|
||||
return workspace;
|
||||
IT_ASSERT((cursor + size) <= workspaceSize);
|
||||
cursor += size;
|
||||
void *temp = workspace;
|
||||
temp += (cursor - size);
|
||||
return temp;
|
||||
}
|
||||
|
||||
void resetWorkspace() const { cursor = 0; }
|
||||
|
||||
void copyBlobFromCPU(void *dst, const void *src,
|
||||
size_t bytes) const override {
|
||||
checkBangError(cnrtMemcpy(dst, const_cast<void *>(src), bytes,
|
||||
|
@ -66,10 +76,9 @@ class BangRuntimeObj : public RuntimeObj {
|
|||
checkBangError(cnrtMemcpy(dst, const_cast<void *>(src), bytes,
|
||||
CNRT_MEM_TRANS_DIR_PEER2PEER));
|
||||
}
|
||||
|
||||
void initComm(const string &, int, int) override { IT_TODO_HALT(); }
|
||||
|
||||
CommunicatorObj &getCommunicator() const override { IT_TODO_HALT(); }
|
||||
void initComm(const string &name, int worldSize, int rank) final;
|
||||
CommunicatorObj &getCommunicator() const override { return *comm; }
|
||||
cnrtQueue_t getBangQueue() const { return queue; }
|
||||
|
||||
private:
|
||||
void runWithoutSync(const Graph &graph, bool tune, bool profiling) const;
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
#pragma once
|
||||
#include "bang_common.h"
|
||||
#include "core/communicator.h"
|
||||
#include <chrono>
|
||||
#include <cncl.h>
|
||||
#include <cnrt.h>
|
||||
#include <cstdlib>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
|
||||
namespace infini {
|
||||
|
||||
class CnclCommunicatorObj final : public CommunicatorObj {
|
||||
private:
|
||||
cnclComm_t *comms;
|
||||
|
||||
public:
|
||||
CnclCommunicatorObj(const string &name, int worldSize, int rank)
|
||||
: CommunicatorObj(worldSize, rank) {
|
||||
const std::string filePath("./" + name + "_cncl_id.bin");
|
||||
cnclCliqueId clique_id;
|
||||
if (rank == 0) {
|
||||
CNCL_CHECK(cnclGetCliqueId(&clique_id));
|
||||
std::ofstream ofs(filePath, std::ios::binary);
|
||||
ofs.write((char *)&clique_id, sizeof(cnclCliqueId));
|
||||
|
||||
} else {
|
||||
auto begin = std::chrono::steady_clock::now();
|
||||
while (!std::filesystem::exists(filePath)) {
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
_IT_ASSERT_2(now < begin + std::chrono::seconds(10),
|
||||
"time limit (10s) exceeded.");
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
}
|
||||
std::ifstream ifs(filePath, std::ios::binary);
|
||||
ifs.read((char *)&clique_id, sizeof(cnclCliqueId));
|
||||
}
|
||||
|
||||
int num_comms = 1;
|
||||
int *dev_list = new int[num_comms];
|
||||
int *rank_list = new int[num_comms];
|
||||
comms = new cnclComm_t[num_comms];
|
||||
uint32_t num_dev = 0;
|
||||
checkBangError(cnrtGetDeviceCount(&num_dev));
|
||||
|
||||
for (int i = 0; i < num_comms; i++) {
|
||||
rank_list[i] = rank;
|
||||
dev_list[i] = rank_list[i] % num_dev;
|
||||
}
|
||||
|
||||
CNCL_CHECK(cnclInitComms(comms, num_comms, dev_list, rank_list,
|
||||
worldSize, &clique_id));
|
||||
|
||||
if (rank == 0) {
|
||||
std::filesystem::remove(filePath);
|
||||
}
|
||||
|
||||
delete[] dev_list;
|
||||
delete[] rank_list;
|
||||
}
|
||||
|
||||
~CnclCommunicatorObj() {
|
||||
CNCL_CHECK(cnclDestroyComms(comms, 1));
|
||||
delete[] comms;
|
||||
}
|
||||
|
||||
// Get the actual cnclComm_t
|
||||
cnclComm_t getCnclComm() { return comms[0]; }
|
||||
|
||||
virtual string toString() const final {
|
||||
std::ostringstream oss;
|
||||
oss << "CNCL communicator";
|
||||
return oss.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace infini
|
|
@ -75,7 +75,7 @@ template <typename T> std::string vecToString(const std::vector<T> &vec) {
|
|||
|
||||
double timeit(
|
||||
const std::function<void()> &func,
|
||||
const std::function<void(void)> &sync = []() {}, int warmupRounds = 200,
|
||||
int timingRounds = 200);
|
||||
const std::function<void(void)> &sync = []() {}, int warmupRounds = 10,
|
||||
int timingRounds = 10);
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -53,6 +53,7 @@ class GraphObj : public Object {
|
|||
const TensorVec &getTensors() const { return tensors; }
|
||||
const OpVec &getOperators() const { return ops; }
|
||||
OpVec getComputeOps() const;
|
||||
Tensor getTensor(int) const;
|
||||
|
||||
/**
|
||||
* Sort the nodes in topological order.
|
||||
|
@ -64,7 +65,13 @@ class GraphObj : public Object {
|
|||
|
||||
void optimize();
|
||||
|
||||
void dataMalloc(bool useNaiveAllocator = false);
|
||||
void shape_infer();
|
||||
|
||||
void dataMalloc(bool useNaiveAllocator = false, size_t memPoolSize = 0);
|
||||
|
||||
Tensor cloneKV(Tensor &tensor);
|
||||
|
||||
void freeHeap();
|
||||
|
||||
/**
|
||||
* @brief Add an operator and create its outputs. Output tensor arguments
|
||||
|
|
|
@ -30,6 +30,8 @@ class GraphHandlerObj {
|
|||
Tensor batchNormalization(Tensor input, Tensor output, Tensor mean,
|
||||
Tensor var, Tensor scale, Tensor bias,
|
||||
float momentum, float eps, bool training);
|
||||
Tensor layerNormalization(Tensor input, Tensor scale, Tensor output,
|
||||
Tensor bias, float eps, int axis, int stash_type);
|
||||
|
||||
Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
|
||||
int ph, int pw, int sh, int sw, int ceilMode);
|
||||
|
@ -63,13 +65,26 @@ class GraphHandlerObj {
|
|||
std::optional<float> max);
|
||||
Tensor transpose(Tensor data, Tensor transposed, Shape perm);
|
||||
Tensor reshape(Tensor data, Tensor reshaped, Shape shape);
|
||||
Tensor resize(Tensor input, Tensor output,
|
||||
const std::optional<vector<int>> &axes, Tensor sizes,
|
||||
Tensor scales, Tensor roi, vector<uint32_t> sizes_,
|
||||
vector<float> scales_, vector<float> roi_, string mode,
|
||||
string ratioPolicy, string nearestMode,
|
||||
string coordTransMode);
|
||||
Tensor squeeze(Tensor input, Tensor output, Shape axes);
|
||||
Tensor unsqueeze(Tensor input, Tensor output, Shape axes);
|
||||
Tensor concat(TensorVec inputs, Tensor output, int dim);
|
||||
Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
|
||||
Tensor input_q, Tensor input_k, Tensor input_v,
|
||||
Tensor position_id, Tensor output_matmul);
|
||||
TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
|
||||
int num_outputs);
|
||||
std::variant<int, vector<int>> numOrRatio);
|
||||
Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);
|
||||
Tensor gatherElements(Tensor data, Tensor indices, Tensor output, int axis);
|
||||
Tensor reduceMean(Tensor data, Tensor reduced,
|
||||
const optional<vector<int>> &axes, bool keepdims);
|
||||
Tensor reduceSum(Tensor data, Tensor reduced,
|
||||
const optional<vector<int>> &axes, bool keepdims);
|
||||
Tensor slice(Tensor input, Tensor output, const vector<int> &starts,
|
||||
const vector<int> &ends, const optional<vector<int>> &axes,
|
||||
const optional<vector<int>> &steps);
|
||||
|
@ -78,6 +93,7 @@ class GraphHandlerObj {
|
|||
Tensor cast(Tensor input, Tensor output, int to);
|
||||
Tensor expand(Tensor input, Tensor output, Shape dims);
|
||||
Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output);
|
||||
std::vector<int> getDims(Tensor x) { return x->getDims(); }
|
||||
|
||||
Tensor allReduceSum(Tensor input, Tensor output);
|
||||
Tensor allReduceProd(Tensor input, Tensor output);
|
||||
|
@ -86,6 +102,13 @@ class GraphHandlerObj {
|
|||
Tensor allReduceAvg(Tensor input, Tensor output);
|
||||
TensorVec allGather(Tensor input, std::optional<TensorVec> outputs, int n);
|
||||
Tensor broadcast(Tensor input, Tensor output, int root);
|
||||
Tensor send(Tensor input, int source, int destination, Tensor output);
|
||||
Tensor recv(Tensor output, int source, int destination, Shape dims,
|
||||
int outputType, Tensor input);
|
||||
Tensor depthToSpace(Tensor input, Tensor output, int blocksize,
|
||||
std::string mode);
|
||||
Tensor lrn(Tensor input, Tensor output, float alpha, float beta, float bias,
|
||||
int size);
|
||||
|
||||
//------ modifiers
|
||||
|
||||
|
@ -93,9 +116,19 @@ class GraphHandlerObj {
|
|||
|
||||
inline void optimize() { g->optimize(); }
|
||||
|
||||
inline void shape_infer() { g->shape_infer(); }
|
||||
|
||||
void change_shape(const vector<int> &shape, int tensorId);
|
||||
//------ runtime
|
||||
|
||||
inline void data_malloc() { g->dataMalloc(); }
|
||||
inline void data_malloc(bool useNaiveAllocator = false,
|
||||
size_t memPoolSize = 0) {
|
||||
g->dataMalloc(useNaiveAllocator, memPoolSize);
|
||||
}
|
||||
|
||||
inline Tensor clone_KV(Tensor &tensor) { return g->cloneKV(tensor); }
|
||||
|
||||
inline void free_heap() { g->freeHeap(); }
|
||||
|
||||
inline void tune() { g->getRuntime()->run(g, true); }
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
#include "core/common.h"
|
||||
#include "core/operator.h"
|
||||
#include "core/tensor.h"
|
||||
#include "utils/operator_utils.h"
|
||||
#include <functional>
|
||||
#include <nlohmann/json.hpp>
|
||||
using json = nlohmann::json;
|
||||
|
@ -29,7 +30,6 @@ class Kernel {
|
|||
public:
|
||||
Kernel() {}
|
||||
virtual ~Kernel() {}
|
||||
|
||||
/**
|
||||
* @param op The operator to be executed.
|
||||
* @param record The parameters for kernel execution. If extra parameters
|
||||
|
@ -102,11 +102,9 @@ class KernelRegistry {
|
|||
}
|
||||
Kernel *getKernel(const KernelAttrs &kernelAttrs) const {
|
||||
auto it = kernels.find(kernelAttrs);
|
||||
IT_ASSERT(it != kernels.end(),
|
||||
"Kernel not found for key {" +
|
||||
to_string(enum_to_underlying(std::get<0>(kernelAttrs))) +
|
||||
", " + std::to_string(std::get<1>(kernelAttrs)) + ", " +
|
||||
std::get<2>(kernelAttrs).toString() + "}");
|
||||
IT_ASSERT(it != kernels.end(), "Kernel not found for key {" +
|
||||
get_kernel_attrs_str(kernelAttrs) +
|
||||
"}");
|
||||
return std::get<0>(it->second);
|
||||
}
|
||||
const KernelRecord &getKernelItem(const KernelAttrs &kernelAttrs) const {
|
||||
|
@ -131,15 +129,16 @@ class CpuKernelWithoutConfig : public Kernel {
|
|||
|
||||
} // namespace infini
|
||||
|
||||
#define _REGISTER_KERNEL_1(device, opType, dataType, kernel, name, cnt) \
|
||||
#define _REGISTER_KERNEL_1(device, opType, kernel, name, cnt) \
|
||||
namespace infini { \
|
||||
static const bool _CAT(_register_kernel_, cnt) = \
|
||||
KernelRegistry::getInstance().registerKernel( \
|
||||
KernelAttrs{device, opType, dataType}, new kernel(), name); \
|
||||
KernelRegistry::getInstance().registerKernel(KernelAttrs{device, \
|
||||
opType}, \
|
||||
new kernel(), name); \
|
||||
}
|
||||
|
||||
#define REGISTER_KERNEL(device, opType, dataType, kernel, name) \
|
||||
_REGISTER_KERNEL_1(device, opType, dataType, kernel, name, __COUNTER__)
|
||||
#define REGISTER_KERNEL(device, opType, kernel, name) \
|
||||
_REGISTER_KERNEL_1(device, opType, kernel, name, __COUNTER__)
|
||||
|
||||
#define _REGISTER_CONSTRUCTOR_1(type, constructor, cnt) \
|
||||
namespace infini { \
|
||||
|
|
|
@ -26,14 +26,23 @@ class LazyAllocator {
|
|||
|
||||
size_t weightPeak = 0;
|
||||
|
||||
size_t heapPeak = 0;
|
||||
|
||||
size_t alignment;
|
||||
|
||||
bool hasMemPool = false;
|
||||
|
||||
size_t memPoolSize = 0;
|
||||
|
||||
// pointer to the memory actually allocated
|
||||
void *ptr = nullptr;
|
||||
|
||||
// pointer to the weight memory space
|
||||
void *weightPtr = nullptr;
|
||||
|
||||
// memory pool ptr
|
||||
void *memPoolPtr = nullptr;
|
||||
|
||||
// // a cache designed for a batch size that has already occurred
|
||||
// std::unordered_map<size_t, std::unordered_map<TensorObj *, size_t>>
|
||||
// batchsizeToTensorOffset;
|
||||
|
@ -68,6 +77,10 @@ class LazyAllocator {
|
|||
|
||||
void init();
|
||||
|
||||
void setMemPool(size_t memPoolSize);
|
||||
|
||||
bool getMemPoolStatus();
|
||||
|
||||
// function: simulate memory allocation
|
||||
// arguments:
|
||||
// size: size of memory block to be allocated
|
||||
|
@ -76,6 +89,10 @@ class LazyAllocator {
|
|||
|
||||
size_t allocWeight(size_t size);
|
||||
|
||||
size_t heapAlloc(size_t size);
|
||||
|
||||
void freeHeap();
|
||||
|
||||
// function: simulate memory free
|
||||
// arguments:
|
||||
// addr: head address offset of memory block to be free
|
||||
|
@ -92,6 +109,8 @@ class LazyAllocator {
|
|||
|
||||
void *getWeightPtr();
|
||||
|
||||
void *getHeapPtr();
|
||||
|
||||
void info();
|
||||
|
||||
private:
|
||||
|
|
|
@ -25,6 +25,7 @@ struct OpType {
|
|||
Asinh, // Unary
|
||||
Atan, // Unary
|
||||
Atanh, // Unary
|
||||
AttentionKVCache, // Fusion
|
||||
AveragePool, // Pool
|
||||
BatchNormalization, //
|
||||
Bernoulli, //
|
||||
|
@ -231,6 +232,8 @@ struct OpType {
|
|||
AllReduceAvg,
|
||||
AllGather,
|
||||
Broadcast,
|
||||
Send,
|
||||
Recv,
|
||||
} type;
|
||||
|
||||
constexpr OpType(decltype(type) t) : type(t) {}
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
#include "core/tensor.h"
|
||||
|
||||
namespace infini {
|
||||
using KernelAttrs = std::tuple<Device, OpType::underlying_t, DataType>;
|
||||
using KernelAttrs = std::tuple<Device, OpType::underlying_t>;
|
||||
|
||||
struct OpPerfKey {
|
||||
HashType hash;
|
||||
|
@ -55,8 +55,7 @@ class OperatorObj : public Object {
|
|||
|
||||
public:
|
||||
OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs);
|
||||
virtual optional<vector<Shape>>
|
||||
inferShape(const TensorVec &inputs) const = 0;
|
||||
virtual optional<vector<Shape>> inferShape(const TensorVec &inputs) = 0;
|
||||
virtual vector<DataType> inferDataType(const TensorVec &inputs) const;
|
||||
/**
|
||||
* @brief Constructs outputs (if required) and checks whether the operator is
|
||||
|
@ -91,6 +90,7 @@ class OperatorObj : public Object {
|
|||
OpType getOpType() const { return type; }
|
||||
// HACK: set correct data type
|
||||
DataType getDType() const { return getInputs(0)->getDType(); }
|
||||
DataType getOutDType() const { return getOutput()->getDType(); }
|
||||
virtual int numInputs() const = 0;
|
||||
virtual int numOutputs() const = 0;
|
||||
|
||||
|
@ -105,7 +105,7 @@ class OperatorObj : public Object {
|
|||
const TensorVec &newOutputs) const = 0;
|
||||
|
||||
protected:
|
||||
optional<vector<Shape>> inferShape() const;
|
||||
optional<vector<Shape>> inferShape();
|
||||
vector<DataType> inferDataType() const;
|
||||
|
||||
private:
|
||||
|
|
|
@ -8,7 +8,9 @@
|
|||
#if USE_CUDA
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#endif
|
||||
|
||||
#if USE_BANG
|
||||
#include "bang/bang_runtime.h"
|
||||
#endif
|
||||
namespace infini {
|
||||
|
||||
// TODO: how to deal with this
|
||||
|
@ -31,6 +33,7 @@ class TensorObj : public TensorBaseObj {
|
|||
size_t getBytes() const { return _size * dtype.getSize(); }
|
||||
|
||||
Shape getDims() const { return shape; }
|
||||
void setShape(Shape shape_);
|
||||
size_t getRank() const { return shape.size(); }
|
||||
Shape getStride() const;
|
||||
size_t getOffset(const vector<int> &ds) const;
|
||||
|
@ -41,8 +44,16 @@ class TensorObj : public TensorBaseObj {
|
|||
bool isOutput() const { return tensorType == TensorType::output; }
|
||||
bool isOthers() const { return tensorType == TensorType::others; }
|
||||
void setWeight() { tensorType = TensorType::weight; }
|
||||
void setInput() { tensorType = TensorType::input; }
|
||||
void setOutput() { tensorType = TensorType::output; }
|
||||
void setInput() {
|
||||
if (!this->isWeight()) {
|
||||
tensorType = TensorType::input;
|
||||
}
|
||||
}
|
||||
void setOutput() {
|
||||
if (!this->isWeight()) {
|
||||
tensorType = TensorType::output;
|
||||
}
|
||||
}
|
||||
string tensorTypeToString() const {
|
||||
switch (tensorType) {
|
||||
case TensorType::weight:
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
#pragma once
|
||||
#include <cstdio>
|
||||
|
||||
struct AttentionKVCacheMetadata {
|
||||
int dimSize[4];
|
||||
int stride[4];
|
||||
};
|
||||
|
||||
namespace infini {
|
||||
void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
|
||||
float *input_q, float *input_k, float *input_v,
|
||||
int *position_id, float *output_matmul,
|
||||
const AttentionKVCacheMetadata &compMeta);
|
||||
|
||||
} // namespace infini
|
|
@ -1,8 +1,16 @@
|
|||
#pragma once
|
||||
|
||||
namespace infini {
|
||||
void div_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3,
|
||||
int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3);
|
||||
void pow_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3,
|
||||
int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3);
|
||||
void div_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
|
||||
int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1,
|
||||
int c2, int c3);
|
||||
void add_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
|
||||
int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1,
|
||||
int c2, int c3);
|
||||
void pow_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
|
||||
int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1,
|
||||
int c2, int c3);
|
||||
void less_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
|
||||
int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1,
|
||||
int c2, int c3);
|
||||
}; // namespace infini
|
||||
|
|
|
@ -3,7 +3,8 @@
|
|||
#include "operators/unary.h"
|
||||
#include "utils/small_array.h"
|
||||
namespace infini {
|
||||
void expandKernel(float *input, float *output, int nDims, int outputsize,
|
||||
SmallArray inputShape, SmallArray outputShape);
|
||||
void expandKernel(int dType, void *input, void *output, int nDims,
|
||||
int outputsize, SmallArray inputShape,
|
||||
SmallArray outputShape);
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
#pragma once
|
||||
#include "operators/unary.h"
|
||||
|
||||
namespace infini {
|
||||
void LaynormKernel(const float *input, const float *scale, const float eps,
|
||||
int size, int scaleSize, const int dimsize, const int stride,
|
||||
float *output, const float *bias, int biasSize);
|
||||
void LaynormKernel(const float *input, const float *scale, const float eps,
|
||||
int size, int scaleSize, const int dimsize, const int stride,
|
||||
float *output);
|
||||
void LaynormKernel(const half *input, const half *scale, const half eps,
|
||||
int size, int scaleSize, const int dimsize, const int stride,
|
||||
half *output, const half *bias, int biasSize);
|
||||
void LaynormKernel(const half *input, const half *scale, const half eps,
|
||||
int size, int scaleSize, const int dimsize, const int stride,
|
||||
half *output);
|
||||
}; // namespace infini
|
|
@ -10,10 +10,11 @@ typedef struct {
|
|||
int wholeNDim[MAX_DIM]; // dim size after padding or before slicing
|
||||
int partNDim[MAX_DIM]; // dim size before padding or after slicing
|
||||
int partStride[MAX_DIM]; // stride before padding or after slicing
|
||||
int DType;
|
||||
} TransMetaData;
|
||||
|
||||
namespace infini {
|
||||
void pad_slice_kernel(float *partData, float *wholeData,
|
||||
void pad_slice_kernel(void *partData, void *wholeData,
|
||||
const TransMetaData &metadata, int nDims, int num,
|
||||
bool isPad);
|
||||
} // namespace infini
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
#pragma once
|
||||
#include "utils/small_array.h"
|
||||
namespace infini {
|
||||
void softmax_kernel(int num_blocks, float *input, float *output, int size,
|
||||
int dimsize, int stride);
|
||||
void softmax_kernel(int num_blocks, half *input, half *output, int size,
|
||||
int dimsize, int stride);
|
||||
} // namespace infini
|
|
@ -3,13 +3,13 @@
|
|||
#include <cstdio>
|
||||
|
||||
const int BATCH_SIZE = 32; // parallel tensor number.
|
||||
const int DIM_MAX_SIZE = 4;
|
||||
const int DIM_MAX_SIZE = 8;
|
||||
|
||||
// Concat operator acts like element tensors composing into one big tensor, and
// split operator acts like one big tensor being decomposed into element
// tensors.
|
||||
struct ElementTensorMetadata {
|
||||
float *data[BATCH_SIZE];
|
||||
template <typename T> struct ElementTensorMetadata {
|
||||
T *data[BATCH_SIZE];
|
||||
int dimBgNo[BATCH_SIZE]; // the dimension begin index of the element tensor in
|
||||
// the composed tensor.
|
||||
int dimSize[BATCH_SIZE]; // the dimension size of the element tensor.
|
||||
|
@ -20,16 +20,17 @@ struct ElementTensorMetadata {
|
|||
data[i], dimBgNo[i], dimSize[i], nElements[i]);
|
||||
}
|
||||
};
|
||||
|
||||
struct ComposedTensorMetadata {
|
||||
template <typename T> struct ComposedTensorMetadata {
|
||||
int dimSize[DIM_MAX_SIZE];
|
||||
int stride[DIM_MAX_SIZE];
|
||||
float *data;
|
||||
T *data;
|
||||
};
|
||||
|
||||
namespace infini {
|
||||
void split_concat_kernel(const ElementTensorMetadata &eleMeta,
|
||||
const ComposedTensorMetadata &compMeta, int dim,
|
||||
void split_concat_kernel(const ElementTensorMetadata<float> &eleMeta,
|
||||
const ComposedTensorMetadata<float> &compMeta, int dim,
|
||||
int batchSize, int nDims, bool isSplit);
|
||||
void split_concat_kernel(const ElementTensorMetadata<half> &eleMeta,
|
||||
const ComposedTensorMetadata<half> &compMeta, int dim,
|
||||
int batchSize, int nDims, bool isSplit);
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
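For orientation, here is a minimal standalone sketch of how the templated metadata above could be filled when concatenating two 2x3 tensors along dim 1. It is not repository code: the struct definitions are simplified mirrors of the ones shown, and the split_concat_kernel call is only referenced in a comment.

// Standalone sketch (not repository code): filling concat/split metadata.
#include <cstdio>

constexpr int BATCH_SIZE = 32;
constexpr int DIM_MAX_SIZE = 8;

template <typename T> struct ElementTensorMetadata {
    T *data[BATCH_SIZE];
    int dimBgNo[BATCH_SIZE];   // begin offset along the concat dim
    int dimSize[BATCH_SIZE];   // extent along the concat dim
    int nElements[BATCH_SIZE]; // total element count per element tensor
};

template <typename T> struct ComposedTensorMetadata {
    int dimSize[DIM_MAX_SIZE];
    int stride[DIM_MAX_SIZE];
    T *data;
};

int main() {
    float a[2 * 3] = {}, b[2 * 3] = {}, out[2 * 6] = {};

    ElementTensorMetadata<float> ele{};
    ele.data[0] = a; ele.dimBgNo[0] = 0; ele.dimSize[0] = 3; ele.nElements[0] = 6;
    ele.data[1] = b; ele.dimBgNo[1] = 3; ele.dimSize[1] = 3; ele.nElements[1] = 6;

    ComposedTensorMetadata<float> comp{};
    comp.dimSize[0] = 2; comp.dimSize[1] = 6; // composed shape (2, 6)
    comp.stride[0] = 6;  comp.stride[1] = 1;  // row-major strides
    comp.data = out;

    // A call such as split_concat_kernel(ele, comp, /*dim=*/1, /*batchSize=*/2,
    // /*nDims=*/2, /*isSplit=*/false) would then copy a and b into out.
    std::printf("composed dim1 = %d\n", comp.dimSize[1]);
    return 0;
}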
@ -5,7 +5,7 @@

namespace infini {

void transpose_kernel(float *input, float *output, int nDims, int size,
void transpose_kernel(int dType, void *input, void *output, int nDims, int size,
                      SmallArray strides, SmallArray outputShape);

}; // namespace infini

@ -3,48 +3,21 @@
#include "operators/unary.h"

namespace infini {
void softmax_kernel(float *input, float *output, size_t num);
void relu_kernel(float *input, float *output, size_t num);
void sigmoid_kernel(float *input, float *output, size_t num);
void tanh_kernel(float *input, float *output, size_t num);
void abs_kernel(float *input, float *output, size_t num);
void sqrt_kernel(float *input, float *output, size_t num);
void neg_kernel(float *input, float *output, size_t num);
void gelu_kernel(float *input, float *output, size_t num);
void erf_kernel(float *input, float *output, size_t num);
void hard_sigmoid_kernel(float *input, float *output, size_t num);
void hard_swish_kernel(float *input, float *output, size_t num);
template <typename T> void softmax_kernel(T *input, T *output, size_t num);
template <typename T> void relu_kernel(T *input, T *output, size_t num);
template <typename T> void sigmoid_kernel(T *input, T *output, size_t num);
template <typename T> void tanh_kernel(T *input, T *output, size_t num);
template <typename T> void abs_kernel(T *input, T *output, size_t num);
template <typename T> void sqrt_kernel(T *input, T *output, size_t num);
template <typename T> void neg_kernel(T *input, T *output, size_t num);
template <typename T> void gelu_kernel(T *input, T *output, size_t num);
template <typename T> void erf_kernel(T *input, T *output, size_t num);
template <typename T> void hard_sigmoid_kernel(T *input, T *output, size_t num);
template <typename T> void hard_swish_kernel(T *input, T *output, size_t num);

void unary_kernel(const Operator &_op) {
    auto op = as<UnaryObj>(_op);
    float *const inputData = (op->getInputs(0)->getRawDataPtr<float *>());
    float *const outputData = (op->getOutput()->getRawDataPtr<float *>());
template <typename INPUT, typename OUTPUT>
void cast_kernel(INPUT *input, OUTPUT *output, size_t num);

    size_t num = op->getOutput()->size();
    if (op->getOpType() == OpType::Softmax)
        softmax_kernel(inputData, outputData, num);
    else if (op->getOpType() == OpType::Relu)
        relu_kernel(inputData, outputData, num);
    else if (op->getOpType() == OpType::Sigmoid)
        sigmoid_kernel(inputData, outputData, num);
    else if (op->getOpType() == OpType::HardSigmoid)
        hard_sigmoid_kernel(inputData, outputData, num);
    else if (op->getOpType() == OpType::HardSwish)
        hard_swish_kernel(inputData, outputData, num);
    else if (op->getOpType() == OpType::Tanh)
        tanh_kernel(inputData, outputData, num);
    else if (op->getOpType() == OpType::Abs)
        abs_kernel(inputData, outputData, num);
    else if (op->getOpType() == OpType::Sqrt)
        sqrt_kernel(inputData, outputData, num);
    else if (op->getOpType() == OpType::Gelu)
        gelu_kernel(inputData, outputData, num);
    else if (op->getOpType() == OpType::Neg)
        neg_kernel(inputData, outputData, num);
    else if (op->getOpType() == OpType::Erf)
        erf_kernel(inputData, outputData, num);
    else
        IT_TODO_HALT();
}
void unary_kernel(const Operator &_op);

}; // namespace infini

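The old header defined unary_kernel inline and only for float; the new one declares templated kernels and keeps unary_kernel as a declaration, so the dispatch on both operator type and element type moves into the CUDA source. Below is a host-side sketch of that kind of two-level dispatch; DType, Op, and all function bodies are stand-ins, not the CUDA implementation.

// Standalone CPU sketch of two-level dispatch: element type, then operator.
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

enum class DType { Float32, Float64 }; // stand-in for the real DataType
enum class Op { Relu, Sigmoid };       // stand-in for OpType

template <typename T> void relu_kernel(const T *in, T *out, size_t num) {
    for (size_t i = 0; i < num; ++i) out[i] = in[i] > T(0) ? in[i] : T(0);
}
template <typename T> void sigmoid_kernel(const T *in, T *out, size_t num) {
    for (size_t i = 0; i < num; ++i) out[i] = T(1) / (T(1) + std::exp(-in[i]));
}

template <typename T>
void dispatch_op(Op op, const void *in, void *out, size_t num) {
    auto *i = static_cast<const T *>(in);
    auto *o = static_cast<T *>(out);
    switch (op) {
    case Op::Relu:    relu_kernel(i, o, num); break;
    case Op::Sigmoid: sigmoid_kernel(i, o, num); break;
    }
}

// Analogue of unary_kernel: pick the template instantiation from the dtype.
void unary_kernel_sketch(DType dt, Op op, const void *in, void *out, size_t num) {
    if (dt == DType::Float32) dispatch_op<float>(op, in, out, num);
    else                      dispatch_op<double>(op, in, out, num);
}

int main() {
    std::vector<float> x{-1.f, 2.f}, y(2);
    unary_kernel_sketch(DType::Float32, Op::Relu, x.data(), y.data(), y.size());
    std::printf("%.1f %.1f\n", y[0], y[1]); // prints: 0.0 2.0
    return 0;
}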
@ -1,11 +1,29 @@
#pragma once
#include "core/tensor.h"
#include "cuda/cuda_common.h"

namespace infini {

void cudaPrintFloat(float *x, int len);

void cudaPrintTensor(const Tensor &tensor) {
    cudaPrintFloat(tensor->getRawDataPtr<float *>(), tensor->size());
}
void cudaPrintTensor(const Tensor &tensor);

} // namespace infini
cudnnDataType_t cudnnDataTypeConvert(DataType dataType);
cudaDataType cublasDataTypeConvert(DataType);

template <int index> struct DT_CUDA {};
template <> struct DT_CUDA<0> { using t = bool; };
template <> struct DT_CUDA<1> { using t = float; };
template <> struct DT_CUDA<2> { using t = unsigned char; };
template <> struct DT_CUDA<3> { using t = char; };
template <> struct DT_CUDA<4> { using t = unsigned short; };
template <> struct DT_CUDA<5> { using t = short; };
template <> struct DT_CUDA<6> { using t = int; };
template <> struct DT_CUDA<7> { using t = long long; };
template <> struct DT_CUDA<9> { using t = bool; };
template <> struct DT_CUDA<10> { using t = half; };
template <> struct DT_CUDA<11> { using t = double; };
template <> struct DT_CUDA<12> { using t = unsigned int; };
template <> struct DT_CUDA<13> { using t = unsigned long long; };
template <> struct DT_CUDA<16> { using t = nv_bfloat16; };
} // namespace infini

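DT_CUDA<index> maps a runtime data-type code to a concrete C++/CUDA type at compile time. A host-only sketch of the same trick follows; the indices and the sum helper are illustrative, and half/nv_bfloat16 are omitted because they need CUDA headers.

// Standalone sketch: an index -> type trait driving template instantiation.
#include <cstddef>
#include <cstdint>
#include <cstdio>

template <int index> struct DT {};
template <> struct DT<1>  { using t = float; };
template <> struct DT<6>  { using t = int32_t; };
template <> struct DT<11> { using t = double; };

template <typename T> double sum_as(const void *data, size_t n) {
    auto *p = static_cast<const T *>(data);
    double s = 0;
    for (size_t i = 0; i < n; ++i) s += static_cast<double>(p[i]);
    return s;
}

// Pick the instantiation from a runtime dtype code, mirroring how kernels
// built on DT_CUDA<index>::t would be selected.
double sum(int dtypeIndex, const void *data, size_t n) {
    switch (dtypeIndex) {
    case 1:  return sum_as<DT<1>::t>(data, n);
    case 6:  return sum_as<DT<6>::t>(data, n);
    case 11: return sum_as<DT<11>::t>(data, n);
    default: return 0; // unsupported in this sketch
    }
}

int main() {
    float x[3] = {1.f, 2.f, 3.f};
    std::printf("%.1f\n", sum(1, x, 3)); // prints: 6.0
    return 0;
}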
@ -3,9 +3,15 @@
#include "utils/small_array.h"

namespace infini {

void whereKernel(const float *inputX, const float *inputY,
                 const uint8_t *condition, float *output, int nDims,
                 SmallArray inputXShape, SmallArray inputYShape,
                 SmallArray conditionShape, SmallArray outputShape);
                 int outputsize, SmallArray inputXShape, SmallArray inputYShape,
                 SmallArray conditionShape, SmallArray outputShape, int xSize,
                 int ySize, int cSize);
void whereKernel(const half *inputX, const half *inputY,
                 const uint8_t *condition, half *output, int nDims,
                 int outputsize, SmallArray inputXShape, SmallArray inputYShape,
                 SmallArray conditionShape, SmallArray outputShape, int xSize,
                 int ySize, int cSize);
}; // namespace infini

@ -53,7 +53,8 @@ inline void initGatherMetaData(GatherMetaData &metaData,
        metaData.inStride[i] = in->getStride()[i];
    }
}
void gather_kernel(float *in, float *out, GatherMetaData metaData, size_t num);
template <typename T>
void gather_kernel(T *in, T *out, GatherMetaData metaData, size_t num);

void gather_elements_kernel(void *in, void *out, GatherMetaData metaData,
                            size_t num);

@ -1,6 +0,0 @@
#pragma once

namespace infini {
void softmax_kernel(int max_threadblock_size, int batch_size, float *x,
                    float *y, int dim, int stride);
}
@ -24,7 +24,7 @@
// clang-format on

namespace nnet {
int matchExprResult(Derivator &derivator, string fn);
bool checkExprLogSame(string fnPrefix, int start, int end);
int matchExprResult(Derivator &derivator, string pathRelativeToProjectHome);
bool checkExprLogSame(string pathRelativeToProjectHome, int start, int end);
bool checkExprsEquvivalence(VecExpr exprs);
} // namespace nnet

@ -35,7 +35,7 @@ class G2BMMObj : public OperatorObj {
|
|||
OP_CLONE(G2BMMObj);
|
||||
|
||||
std::string toString() const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
int numInputs() const override { return 2; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
|
|
@ -33,7 +33,7 @@ class GBMMObj : public OperatorObj {
|
|||
OP_CLONE(GBMMObj);
|
||||
|
||||
std::string toString() const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
int numInputs() const override { return 2; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
|
|
@ -7,7 +7,7 @@ class ActivationBackwardObj : public OperatorObj {
|
|||
ActivationBackwardObj(OpType type, GraphObj *graph, Tensor y, Tensor diff_y,
|
||||
Tensor x, Tensor diff_x);
|
||||
OP_CLONE(ActivationBackwardObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 3; }
|
||||
|
|
|
@ -27,7 +27,7 @@ class AllGatherObj : public OperatorObj {
|
|||
|
||||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return world_size; }
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ class AllReduceBaseObj : public OperatorObj {
|
|||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override {
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override {
|
||||
return {{inputs[0]->getDims()}};
|
||||
};
|
||||
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
#pragma once
|
||||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
/**
|
||||
* @brief Fused Attention with KVCache input operator. All the input and output
|
||||
* tensors should have the same rank except for the position_id.
|
||||
*
|
||||
*/
|
||||
class AttentionKVCacheObj : public OperatorObj {
|
||||
int dim;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new AttentionKVCache object.
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input_k_cache The k_cache input tensor.
|
||||
* @param input_v_cache The v_cache input tensor.
|
||||
* @param input_q The query input tensor.
|
||||
* @param input_k The key input tensor.
|
||||
* @param input_v The value input tensor.
|
||||
* @param position_id The positon id of the query,
|
||||
* @param output_matmul The query output tensor.
|
||||
*/
|
||||
AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
|
||||
Tensor input_v_cache, Tensor input_q, Tensor input_k,
|
||||
Tensor input_v, Tensor position_id,
|
||||
Tensor output_matmul);
|
||||
OP_CLONE(AttentionKVCacheObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 6; }
|
||||
int numOutputs() const override { return 1; }
|
||||
int getDim() const { return dim; }
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
} // namespace infini
|
|
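The comment above requires all inputs and outputs to share one rank, with position_id as the exception. A standalone illustration of that kind of rank check is sketched below; it is not the operator's actual inferShape, and the example shapes are made up.

// Illustrative rank-consistency check (not AttentionKVCacheObj::inferShape).
#include <cassert>
#include <vector>

using Shape = std::vector<int>;

bool ranksConsistent(const std::vector<Shape> &shapes, size_t positionIdIdx) {
    size_t rank = 0;
    bool seen = false;
    for (size_t i = 0; i < shapes.size(); ++i) {
        if (i == positionIdIdx) continue; // position_id may have another rank
        if (!seen) { rank = shapes[i].size(); seen = true; }
        else if (shapes[i].size() != rank) return false;
    }
    return true;
}

int main() {
    // k_cache, v_cache, q, k, v are rank-4 here; position_id (index 5) is rank-2.
    std::vector<Shape> shapes = {
        {1, 8, 64, 128}, {1, 8, 64, 128}, {1, 8, 1, 128},
        {1, 8, 1, 128},  {1, 8, 1, 128},  {1, 1}};
    assert(ranksConsistent(shapes, 5));
    return 0;
}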
@ -34,7 +34,7 @@ class BatchNormObj : public OperatorObj {
|
|||
Tensor var, Tensor scale, Tensor bias, float momentum = 0.9,
|
||||
float eps = 1e-5, bool trainingMode = false);
|
||||
OP_CLONE(BatchNormObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
std::string toString() const override;
|
||||
|
||||
// output size will be 3 when training
|
||||
|
|
|
@ -26,7 +26,7 @@ class BroadcastObj : public OperatorObj {
|
|||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override {
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override {
|
||||
return {{inputs[0]->getDims()}};
|
||||
};
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ class ConcatObj : public OperatorObj {
|
|||
ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim);
|
||||
OP_CLONE(ConcatObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return inputs.size(); }
|
||||
|
|
|
@ -142,7 +142,7 @@ class ConvObj : public ConvBaseObj {
|
|||
ActType act = ActType::None);
|
||||
OP_CLONE(ConvObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
int getNumGroups() const override { return c / getChannelPerGroup(); }
|
||||
|
||||
private:
|
||||
|
@ -164,7 +164,7 @@ class ConvBackwardFilterObj : public ConvBaseObj {
|
|||
int sh = 1, int sw = 1, int dh = 1, int dw = 1,
|
||||
Tensor bias = nullptr, ActType act = ActType::None);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
ActType getAct() const { return act; }
|
||||
int getNumGroups() const override { return c / getChannelPerGroup(); }
|
||||
|
||||
|
@ -191,7 +191,7 @@ class ConvTransposed2dObj : public ConvBaseObj {
|
|||
Tensor bias = nullptr, ActType act = ActType::None);
|
||||
OP_CLONE(ConvTransposed2dObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
int getNumGroups() const override { return group; }
|
||||
std::pair<int, int> getOutputPadding() const { return {oph, opw}; }
|
||||
|
||||
|
@ -218,7 +218,7 @@ class ConvTransposed2dNHWCObj : public ConvBaseObj {
|
|||
Tensor bias = nullptr, ActType act = ActType::None);
|
||||
OP_CLONE(ConvTransposed2dNHWCObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
int getNumGroups() const override { return group; }
|
||||
|
||||
private:
|
||||
|
|
|
@ -7,7 +7,7 @@ class DetObj : public OperatorObj {
|
|||
enum Mode { NormalDet = 0, LogDet };
|
||||
DetObj(GraphObj *graph, Tensor input, Tensor output, Mode mode);
|
||||
OP_CLONE(DetObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
|
|
@ -37,7 +37,7 @@ class DropoutObj : public OperatorObj {
|
|||
DropoutObj(GraphObj *graph, Tensor data, Tensor output, Tensor mask,
|
||||
float ratio, bool training_mode);
|
||||
OP_CLONE(DropoutObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
|
|
@ -21,7 +21,7 @@ class ElementWiseObj : public OperatorObj {
|
|||
*/
|
||||
ElementWiseObj(OpType type, GraphObj *graph, Tensor input0, Tensor input1,
|
||||
Tensor output);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 2; }
|
||||
|
@ -38,7 +38,7 @@ class MSELossObj : public OperatorObj {
|
|||
MSELossObj(GraphObj *graph, Tensor input0, Tensor input1,
|
||||
Reduction reduction, Tensor output);
|
||||
OP_CLONE(MSELossObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
Reduction getReduction() const { return reductionMode; }
|
||||
std::string toString() const override;
|
||||
|
|
|
@ -21,7 +21,7 @@ class ExpandObj : public OperatorObj {
|
|||
*/
|
||||
ExpandObj(GraphObj *graph, Tensor input, Tensor output, Shape dims);
|
||||
OP_CLONE(ExpandObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
|
|
@ -23,7 +23,7 @@ class ExtendObj : public OperatorObj {
|
|||
ExtendObj(GraphObj *graph, Tensor input, Tensor output, int dim,
|
||||
int num = 1);
|
||||
OP_CLONE(ExtendObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
|
|
@ -39,7 +39,7 @@ class GatherObj : public GatherBaseObj {
|
|||
int axis);
|
||||
OP_CLONE(GatherObj);
|
||||
std::string toString() const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
vector<DataType> inferDataType(const TensorVec &inputs) const override;
|
||||
|
||||
private:
|
||||
|
@ -69,7 +69,7 @@ class GatherElementsObj : public GatherBaseObj {
|
|||
Tensor output, int axis);
|
||||
OP_CLONE(GatherElementsObj);
|
||||
std::string toString() const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
vector<DataType> inferDataType(const TensorVec &inputs) const override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
#pragma once
|
||||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
class LayerNormObj : public OperatorObj {
|
||||
float eps;
|
||||
int axis, stash_type;
|
||||
|
||||
public:
|
||||
LayerNormObj(GraphObj *graph, Tensor input, Tensor scale, Tensor output,
|
||||
Tensor bias = nullptr, float eps = 1e-5, int axis = -1,
|
||||
int stash_type = 1);
|
||||
OP_CLONE(LayerNormObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
std::string toString() const override;
|
||||
|
||||
Tensor getBias() const { return inputs.size() > 2 ? inputs[2] : nullptr; }
|
||||
int numInputs() const override { return inputs.size(); }
|
||||
int numOutputs() const override { return outputs.size(); }
|
||||
float getEps() const { return eps; }
|
||||
int getAxis() const { return axis; }
|
||||
int getStashType() const { return stash_type; }
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
|
||||
vector<DataType> inferDataType(const TensorVec &inputs) const override;
|
||||
};
|
||||
} // namespace infini
|
|
@ -0,0 +1,29 @@
|
|||
#pragma once
|
||||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
class LRNObj : public OperatorObj {
|
||||
|
||||
public:
|
||||
LRNObj(GraphObj *graph, Tensor inputX, Tensor inputY, float alpha,
|
||||
float beta, float bias, int size);
|
||||
OP_CLONE(LRNObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return inputs.size(); }
|
||||
int numOutputs() const override { return 1; }
|
||||
auto getAlphaBetaBias() const {
|
||||
return tuple(alpha_value, beta_value, bias_value);
|
||||
}
|
||||
auto getSize() const { return size_value; }
|
||||
|
||||
private:
|
||||
float alpha_value, beta_value, bias_value;
|
||||
int size_value;
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
} // namespace infini
|
|
@ -45,7 +45,7 @@ class MatmulObj : public OperatorObj {
|
|||
OP_CLONE(MatmulObj);
|
||||
|
||||
std::string toString() const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
int numInputs() const override { return inputs.size(); }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
|
|
@ -21,7 +21,7 @@ class MemBoundObj : public OperatorObj {
|
|||
OP_CLONE(MemBoundObj);
|
||||
|
||||
std::string toString() const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
int numInputs() const override { return inputs.size(); }
|
||||
int numOutputs() const override { return outputs.size(); }
|
||||
|
|
|
@ -27,7 +27,7 @@ class PadObj : public OperatorObj {
|
|||
const vector<int> &pads, const optional<vector<int>> &axes);
|
||||
OP_CLONE(PadObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
|
|
@ -41,7 +41,7 @@ class PoolingObj : public OperatorObj {
|
|||
int ceilMode);
|
||||
OP_CLONE(PoolingObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
#pragma once
|
||||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
/**
|
||||
*
|
||||
* https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2193/user-guide/docs/index.html
|
||||
*/
|
||||
class RecvObj : public OperatorObj {
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new SendRecv object
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input Defaults to nullptr because recv has no input tensor.
|
||||
* @param output recv output
|
||||
* @param source the send rank
|
||||
* @param destination the recv rank
|
||||
* @param dims The shape of the output tensor.
|
||||
*/
|
||||
RecvObj(GraphObj *graph, Tensor output, int source, int destination,
|
||||
Shape dims, int outputType, Tensor input = nullptr);
|
||||
OP_CLONE(RecvObj);
|
||||
|
||||
int numInputs() const override { return inputs.size(); }
|
||||
int numOutputs() const override { return 1; }
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
std::string toString() const override;
|
||||
DataType getDType() const;
|
||||
int getSourceRank() const { return source; }
|
||||
int getDestinationRank() const { return destination; }
|
||||
inline Shape getShape() const { return dims; }
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
vector<DataType> inferDataType(const TensorVec &inputs) const override;
|
||||
|
||||
protected:
|
||||
int source;
|
||||
int destination;
|
||||
Shape dims;
|
||||
int outputType;
|
||||
};
|
||||
} // namespace infini
|
|
@ -3,27 +3,30 @@
|
|||
|
||||
namespace infini {
|
||||
/**
|
||||
* @brief Compute the mean of input tensor's elements along certain axes.
|
||||
* @brief Compute the reduction of input tensor's elements along certain axes.
|
||||
*
|
||||
*/
|
||||
class ReduceMeanObj : public OperatorObj {
|
||||
class ReduceBaseObj : public OperatorObj {
|
||||
protected:
|
||||
set<int> axes; // axis to reduce
|
||||
bool keepDims;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new ReduceMean object.
|
||||
* @brief Construct a new Reduce object.
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param opType The operation type. Should be a Reduce operation.
|
||||
* @param input The input tensor.
|
||||
* @param output The output tensor.
|
||||
* @param axes Axes to reduce.
|
||||
* @param keepDims Keep the reduced dimensions or not.
|
||||
*/
|
||||
ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
const optional<vector<int>> &axes, bool keepDims = true);
|
||||
OP_CLONE(ReduceMeanObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
ReduceBaseObj(GraphObj *graph, OpType opType, Tensor input, Tensor output,
|
||||
const optional<vector<int>> &axes, bool keepDims);
|
||||
virtual ~ReduceBaseObj() {}
|
||||
OP_CLONE(ReduceBaseObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
@ -38,4 +41,15 @@ class ReduceMeanObj : public OperatorObj {
|
|||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
class ReduceMeanObj : public ReduceBaseObj {
|
||||
public:
|
||||
ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
const optional<vector<int>> &axes, bool keepDims = true);
|
||||
};
|
||||
|
||||
class ReduceSumObj : public ReduceBaseObj {
|
||||
public:
|
||||
ReduceSumObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
const optional<vector<int>> &axes, bool keepDims = true);
|
||||
};
|
||||
} // namespace infini
|
|
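With this refactor, ReduceMeanObj and ReduceSumObj become thin subclasses that only forward the operator type to ReduceBaseObj. A standalone sketch of the same base-class pattern is shown below; it uses a toy 1-D reduction, not the real operator classes or graph machinery.

// Standalone sketch: shared reduce logic in a base class, thin subclasses.
#include <cstdio>
#include <numeric>
#include <set>
#include <utility>
#include <vector>

struct ReduceBase {
    enum class Kind { Mean, Sum };
    Kind kind;
    std::set<int> axes;
    bool keepDims;
    ReduceBase(Kind k, std::set<int> a, bool keep)
        : kind(k), axes(std::move(a)), keepDims(keep) {}

    // Toy 1-D "reduce all" just to show where shared behaviour would live.
    double reduce(const std::vector<double> &x) const {
        double s = std::accumulate(x.begin(), x.end(), 0.0);
        return kind == Kind::Mean ? s / static_cast<double>(x.size()) : s;
    }
};

struct ReduceMean : ReduceBase {
    ReduceMean(std::set<int> a, bool keep = true)
        : ReduceBase(Kind::Mean, std::move(a), keep) {}
};
struct ReduceSum : ReduceBase {
    ReduceSum(std::set<int> a, bool keep = true)
        : ReduceBase(Kind::Sum, std::move(a), keep) {}
};

int main() {
    std::vector<double> x{1, 2, 3, 4};
    std::printf("mean=%.1f sum=%.1f\n",
                ReduceMean({0}).reduce(x), ReduceSum({0}).reduce(x));
    return 0;
}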
@ -9,6 +9,7 @@ namespace infini {
|
|||
*/
|
||||
class ReshapeObj : public OperatorObj {
|
||||
Shape dims;
|
||||
Shape outputShape;
|
||||
|
||||
public:
|
||||
/**
|
||||
|
@ -17,18 +18,20 @@ class ReshapeObj : public OperatorObj {
|
|||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input The input tensor.
|
||||
* @param output The output tensor.
|
||||
* @param dims The shape of the output tensor.
|
||||
* @param dims The shape used to infer the output shape.
|
||||
* @param outputShape The real shape of output tensor.
|
||||
*/
|
||||
ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims);
|
||||
OP_CLONE(ReshapeObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
inline Shape getShape() const { return dims; }
|
||||
inline Shape getShape() const { return outputShape; }
|
||||
inline Shape getDims() const { return dims; }
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
|
@ -55,7 +58,7 @@ class FlattenObj : public OperatorObj {
|
|||
FlattenObj(GraphObj *graph, Tensor input, Tensor output, int axis);
|
||||
OP_CLONE(FlattenObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
@ -85,7 +88,7 @@ class IdentityObj : public OperatorObj {
|
|||
IdentityObj(GraphObj *graph, Tensor input, Tensor output);
|
||||
OP_CLONE(IdentityObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
|
|
@ -27,6 +27,60 @@ class ResizeObj : public OperatorObj {
|
|||
enum class EKeepAspectRatioPolicy { stretch, notLarger, notSmaller, none };
|
||||
enum class ECoeffMode { nearest, linear, cubic };
|
||||
|
||||
static ECoordinateTransMode fromECoordinateTransModeStr(string mode) {
|
||||
if (mode == "half_pixel") {
|
||||
return ECoordinateTransMode::halfPixel;
|
||||
} else if (mode == "asymmetric") {
|
||||
return ECoordinateTransMode::asymmetric;
|
||||
} else if (mode == "align_corners") {
|
||||
return ECoordinateTransMode::alignCorners;
|
||||
} else if (mode == "pytorch_half_pixel") {
|
||||
return ECoordinateTransMode::pytorchHalfPixel;
|
||||
} else if (mode == "tf_crop_and_resize") {
|
||||
return ECoordinateTransMode::tfCropAndResize;
|
||||
} else {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
}
|
||||
|
||||
static ENearestMode fromENearestModeStr(string mode) {
|
||||
if (mode == "round_prefer_floor") {
|
||||
return ENearestMode::roundPreferFloor;
|
||||
} else if (mode == "round_prefer_ceil") {
|
||||
return ENearestMode::roundPreferCeil;
|
||||
} else if (mode == "floor") {
|
||||
return ENearestMode::floor;
|
||||
} else if (mode == "ceil") {
|
||||
return ENearestMode::ceil;
|
||||
} else {
|
||||
return ENearestMode::none;
|
||||
}
|
||||
}
|
||||
|
||||
static EKeepAspectRatioPolicy fromRatioPolicyStr(string ratioPolicyStr) {
|
||||
if (ratioPolicyStr == "stretch") {
|
||||
return EKeepAspectRatioPolicy::stretch;
|
||||
} else if (ratioPolicyStr == "not_larger") {
|
||||
return EKeepAspectRatioPolicy::notLarger;
|
||||
} else if (ratioPolicyStr == "not_smaller") {
|
||||
return EKeepAspectRatioPolicy::notSmaller;
|
||||
} else {
|
||||
return EKeepAspectRatioPolicy::none;
|
||||
}
|
||||
}
|
||||
|
||||
static ECoeffMode fromECoeffModeStr(string mode) {
|
||||
if (mode == "nearest") {
|
||||
return ECoeffMode::nearest;
|
||||
} else if (mode == "linear") {
|
||||
return ECoeffMode::linear;
|
||||
} else if (mode == "cubic") {
|
||||
return ECoeffMode::cubic;
|
||||
} else {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
vector<int> axes;
|
||||
vector<float> scales;
|
||||
|
@ -60,7 +114,7 @@ class ResizeObj : public OperatorObj {
|
|||
|
||||
// Operator clone(TensorVec inputs, TensorVec outputs) override;
|
||||
vector<DataType> inferDataType(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return inputs.size(); }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
#pragma once
|
||||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
/**
|
||||
*
|
||||
* https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2193/user-guide/docs/index.html
|
||||
*/
|
||||
class SendObj : public OperatorObj {
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new SendRecv object
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input send input
|
||||
* @param output recv output
|
||||
* @param source the send rank
|
||||
* @param destination the recv rank
|
||||
*/
|
||||
SendObj(GraphObj *graph, Tensor input, int source, int destination,
|
||||
Tensor output = nullptr);
|
||||
OP_CLONE(SendObj);
|
||||
|
||||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return outputs.size(); }
|
||||
std::string toString() const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
int getSourceRank() const { return source; }
|
||||
int getDestinationRank() const { return destination; }
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
vector<DataType> inferDataType(const TensorVec &inputs) const override;
|
||||
|
||||
protected:
|
||||
int source;
|
||||
int destination;
|
||||
};
|
||||
} // namespace infini
|
|
@ -32,7 +32,7 @@ class SliceObj : public OperatorObj {
|
|||
const optional<vector<int>> &steps);
|
||||
OP_CLONE(SliceObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
std::string toString() const override;
|
||||
inline int numInputs() const override { return 1; }
|
||||
inline int numOutputs() const override { return 1; }
|
||||
|
|
|
@ -10,7 +10,7 @@ class SoftmaxObj : public OperatorObj {
|
|||
|
||||
OP_CLONE(SoftmaxObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override {
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override {
|
||||
return {{inputs[0]->getDims()}};
|
||||
};
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ class SplitObj : public OperatorObj {
|
|||
int dim, const vector<int> &ratio);
|
||||
OP_CLONE(SplitObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
#pragma once
|
||||
|
||||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
/**
|
||||
* @brief Remove single-dimensional entries from the shape of a tensor.
|
||||
*
|
||||
*/
|
||||
class SqueezeObj : public OperatorObj {
|
||||
Shape axes;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new Squeeze object.
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input The input tensor.
|
||||
* @param output The output tensor.
|
||||
* @param axes List of integers indicating the dimensions to squeeze.
|
||||
*/
|
||||
SqueezeObj(GraphObj *graph, Tensor input, Tensor output, Shape axes);
|
||||
OP_CLONE(SqueezeObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
inline Shape getAxes() const { return axes; }
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
} // namespace infini
|
|
@ -7,7 +7,7 @@ class TransposeObj : public OperatorObj {
|
|||
TransposeObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
vector<int> permute);
|
||||
OP_CLONE(TransposeObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
@ -19,4 +19,33 @@ class TransposeObj : public OperatorObj {
|
|||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
class DepthToSpaceObj : public OperatorObj {
|
||||
public:
|
||||
DepthToSpaceObj(GraphObj *graph, Tensor input, Tensor output, int blocksize,
|
||||
std::string mode);
|
||||
OP_CLONE(DepthToSpaceObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
int getBlockSize() const { return blockSize; }
|
||||
int getMode() const { return D2SMode; }
|
||||
auto getModeString() const { return D2SModeString; }
|
||||
auto getReshapeDim() const { return reshapeDim; }
|
||||
auto getTransposeDim() const { return transposeDim; }
|
||||
auto getOutDim() const { return outDim; }
|
||||
|
||||
private:
|
||||
int blockSize;
|
||||
int D2SMode;
|
||||
std::string D2SModeString;
|
||||
mutable std::vector<int> reshapeDim = {1, 1, 1, 1, 1, 1};
|
||||
mutable std::vector<int> transposeDim = {1, 1, 1, 1, 1, 1};
|
||||
mutable std::vector<int> outDim = {1, 1, 1, 1};
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -17,7 +17,7 @@ class UnaryObj : public OperatorObj {
|
|||
* @param output The output tensor.
|
||||
*/
|
||||
UnaryObj(OpType type, GraphObj *graph, Tensor input, Tensor output);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
@ -33,7 +33,7 @@ class ClipObj : public OperatorObj {
|
|||
ClipObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
std::optional<float> min, std::optional<float> max);
|
||||
OP_CLONE(ClipObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
std::optional<float> getMin() const { return minValue; };
|
||||
|
@ -52,7 +52,7 @@ class HardtanhObj : public OperatorObj {
|
|||
HardtanhObj(GraphObj *graph, Tensor input, Tensor output, float min,
|
||||
float max);
|
||||
OP_CLONE(HardtanhObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
float getMin() const { return minValue; };
|
||||
|
@ -70,7 +70,7 @@ class FlipObj : public OperatorObj {
|
|||
public:
|
||||
FlipObj(GraphObj *graph, Tensor input, Tensor output, vector<int> axis);
|
||||
OP_CLONE(FlipObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
vector<int> getAxis() const { return axisValue; };
|
||||
|
@ -87,7 +87,7 @@ class FillObj : public OperatorObj {
|
|||
public:
|
||||
FillObj(GraphObj *graph, Tensor input, Tensor output, float value);
|
||||
OP_CLONE(FillObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
float getValue() const { return setValue; };
|
||||
|
@ -104,7 +104,7 @@ class L2LossObj : public OperatorObj {
|
|||
public:
|
||||
L2LossObj(GraphObj *graph, Tensor input, Tensor output);
|
||||
OP_CLONE(L2LossObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
@ -120,7 +120,7 @@ class TransformObj : public OperatorObj {
|
|||
TransformObj(GraphObj *graph, Tensor input, Tensor output, float alpha,
|
||||
float beta);
|
||||
OP_CLONE(TransformObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
float getAlpha() const { return alphaValue; }
|
||||
|
@ -165,7 +165,7 @@ class CastObj : public OperatorObj {
|
|||
public:
|
||||
CastObj(GraphObj *graph, Tensor input, Tensor output, CastType type);
|
||||
OP_CLONE(CastObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
vector<DataType> inferDataType(const TensorVec &inputs) const override;
|
||||
|
||||
std::string toString() const override;
|
||||
|
@ -185,7 +185,7 @@ class CumsumObj : public OperatorObj {
|
|||
CumsumObj(GraphObj *graph, Tensor input, Tensor output, int axis,
|
||||
bool exclusive, bool reverse);
|
||||
OP_CLONE(CumsumObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int getAxis() const { return axisValue; }
|
||||
|
@ -205,7 +205,7 @@ class ShapeObj : public OperatorObj {
|
|||
public:
|
||||
ShapeObj(GraphObj *graph, Tensor input, Tensor output);
|
||||
OP_CLONE(ShapeObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
@ -216,7 +216,7 @@ class PReluObj : public OperatorObj {
|
|||
public:
|
||||
PReluObj(GraphObj *graph, Tensor input, Tensor alpha, Tensor output);
|
||||
OP_CLONE(PReluObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 2; }
|
||||
|
@ -236,7 +236,7 @@ class LogObj : public OperatorObj {
|
|||
};
|
||||
LogObj(GraphObj *graph, Tensor input, Tensor output, LogType type);
|
||||
OP_CLONE(LogObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
LogType getType() const { return logType; }
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
#pragma once
|
||||
|
||||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
/**
|
||||
* @brief Insert single-dimensional entries into the shape of an input tensor.
|
||||
*
|
||||
*/
|
||||
class UnsqueezeObj : public OperatorObj {
|
||||
Shape axes;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new Unsqueeze object.
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input The input tensor.
|
||||
* @param output The output tensor.
|
||||
* @param axes List of integers indicating the dimensions to be inserted.
|
||||
*/
|
||||
UnsqueezeObj(GraphObj *graph, Tensor input, Tensor output, Shape axes);
|
||||
OP_CLONE(UnsqueezeObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
inline Shape getAxes() const { return axes; }
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
} // namespace infini
|
|
@ -22,7 +22,7 @@ class WhereObj : public OperatorObj {
|
|||
Tensor output);
|
||||
OP_CLONE(WhereObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return inputs.size(); }
|
||||
|
|
|
@ -3,11 +3,11 @@
namespace infini {
void broadcastShape(const Shape &originShape, SmallArray &modifyShape,
                    int nDims, int size) {
    for (int i = nDims - 1; i >= 0; --i) {
    for (int i = nDims - size - 1; i >= 0; --i) {
        modifyShape.data[i] = 1;
    }
    for (int i = size - 1; i >= 0; --i) {
        modifyShape.data[i + nDims - size] = originShape[i];
    for (int i = nDims - 1; i >= nDims - size; --i) {
        modifyShape.data[i] = originShape[i - nDims + size];
    }
}

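The rewritten loops right-align originShape inside an nDims-long shape and pad the leading dimensions with 1. A standalone check of that alignment, with SmallArray replaced by a local stand-in and an assumed maximum length:

// Standalone check of the alignment: {3, 4} into nDims = 4 gives {1, 1, 3, 4}.
#include <cstdio>
#include <vector>

constexpr int SMALL_ARRAY_MAXLEN = 8; // assumed bound, mirroring SmallArray

struct SmallArrayLike { int data[SMALL_ARRAY_MAXLEN]; };

void broadcastShapeSketch(const std::vector<int> &originShape,
                          SmallArrayLike &modifyShape, int nDims, int size) {
    for (int i = nDims - size - 1; i >= 0; --i)
        modifyShape.data[i] = 1;
    for (int i = nDims - 1; i >= nDims - size; --i)
        modifyShape.data[i] = originShape[i - nDims + size];
}

int main() {
    SmallArrayLike s{};
    broadcastShapeSketch({3, 4}, s, /*nDims=*/4, /*size=*/2);
    std::printf("%d %d %d %d\n", s.data[0], s.data[1], s.data[2], s.data[3]);
    return 0; // prints: 1 1 3 4
}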
@ -91,6 +91,12 @@ template <int val> class ValGenerator : public DataGenerator {
        fill<uint32_t>(data, size);
    }
    void fill(float *data, size_t size) override { fill<float>(data, size); }
    void fill_fp16(uint16_t *data, size_t size) {
        for (size_t i = 0; i < size; i++) {
            float x = 1.0f * val;
            data[i] = float_to_fp16(x);
        }
    }
};
typedef ValGenerator<1> OneGenerator;
typedef ValGenerator<0> ZeroGenerator;

@ -2,6 +2,7 @@
#ifndef OPERATOR_UTIL_H
#define OPERATOR_UTIL_H

#include "core/operator.h"
#include "core/tensor.h"

namespace infini {

@ -10,6 +11,15 @@ namespace infini {
Shape infer_broadcast(const Shape &A, const Shape &B);
// Compute the real axis based on the rank and the given axis
int get_real_axis(const int &axis, const int &rank);
// Check whether tensor B is unidirectionally broadcastable to tensor A
bool is_unidirectional_broadcasting(const Shape &A, const Shape &B);
// Locate the index with size from Shape
Shape locate_index(size_t inputN, const Shape &shape);
// Delocate the ShapeIndex from Shape with broadcast
size_t delocate_index(const Shape &shapeIndex, const Shape &shape,
                      const Shape &stride);
// Convert KernelAttrs to a string representation
std::string get_kernel_attrs_str(const KernelAttrs &kernelAttrs);
} // namespace infini

#endif

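Two of the helpers declared above, get_real_axis and infer_broadcast, follow the usual ONNX conventions (negative axes wrap into [0, rank); shapes broadcast right-aligned). The standalone versions below are illustrative sketches under that assumption, not the repository implementations.

// Illustrative standalone versions of two helpers declared above.
#include <algorithm>
#include <cassert>
#include <vector>

using Shape = std::vector<int>;

// Map a possibly negative axis into [0, rank).
int get_real_axis_sketch(int axis, int rank) {
    assert(axis >= -rank && axis < rank);
    return axis < 0 ? axis + rank : axis;
}

// Numpy/ONNX-style broadcast of two shapes, right-aligned.
Shape infer_broadcast_sketch(const Shape &A, const Shape &B) {
    Shape out(std::max(A.size(), B.size()), 1);
    for (size_t i = 0; i < out.size(); ++i) {
        int a = i < out.size() - A.size() ? 1 : A[i - (out.size() - A.size())];
        int b = i < out.size() - B.size() ? 1 : B[i - (out.size() - B.size())];
        assert(a == b || a == 1 || b == 1); // shapes must be compatible
        out[i] = std::max(a, b);
    }
    return out;
}

int main() {
    assert(get_real_axis_sketch(-1, 4) == 3);
    Shape r = infer_broadcast_sketch({2, 3, 1}, {3, 4});
    assert((r == Shape{2, 3, 4}));
    return 0;
}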
File diff suppressed because it is too large
@ -209,6 +209,7 @@ class TestStringMethods(unittest.TestCase):
|
|||
make_and_import_model(make_graph([relu], "relu", [x], [y]))
|
||||
|
||||
"""Gelu operator is not supported by onnx 14.1 currently."""
|
||||
|
||||
def test_gelu(self):
|
||||
pass
|
||||
# x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
|
@ -294,6 +295,36 @@ class TestStringMethods(unittest.TestCase):
|
|||
make_graph([reshape], "reshape", [data, shape], [reshaped], [shape_data])
|
||||
)
|
||||
|
||||
def test_resize(self):
|
||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 128, 40, 40])
|
||||
roi = make_tensor("roi", TensorProto.FLOAT, [0], [])
|
||||
scales = make_tensor("scales", TensorProto.FLOAT, [4], [1, 1, 2, 2])
|
||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 128, 80, 80])
|
||||
reshape = make_node("Resize", ["x", "roi", "scales"], ["y"], name="resize")
|
||||
make_and_import_model(make_graph([reshape], "resize", [x], [y], [roi, scales]))
|
||||
|
||||
def test_squeeze(self):
|
||||
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 1, 5])
|
||||
axes = make_tensor_value_info("axes", TensorProto.INT64, [2])
|
||||
axes_data = make_tensor("axes", TensorProto.INT64, [2], [0, 2])
|
||||
output = make_tensor_value_info("output", TensorProto.FLOAT, [3, 5])
|
||||
squeeze = make_node("Squeeze", ["input", "axes"], ["output"], name="squeeze")
|
||||
make_and_import_model(
|
||||
make_graph([squeeze], "squeeze", [input, axes], [output], [axes_data])
|
||||
)
|
||||
|
||||
def test_unsqueeze(self):
|
||||
input = make_tensor_value_info("input", TensorProto.FLOAT, [2, 3, 4, 5])
|
||||
axes = make_tensor_value_info("axes", TensorProto.INT64, [2])
|
||||
axes_data = make_tensor("axes", TensorProto.INT64, [2], [0, 2])
|
||||
output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 2, 1, 3, 4, 5])
|
||||
unsqueeze = make_node(
|
||||
"Unsqueeze", ["input", "axes"], ["output"], name="unsqueeze"
|
||||
)
|
||||
make_and_import_model(
|
||||
make_graph([unsqueeze], "unsqueeze", [input, axes], [output], [axes_data])
|
||||
)
|
||||
|
||||
def test_concat(self):
|
||||
input1 = make_tensor_value_info("input1", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||
input2 = make_tensor_value_info("input2", TensorProto.FLOAT, [1, 3, 2, 5])
|
||||
|
@ -337,6 +368,14 @@ class TestStringMethods(unittest.TestCase):
|
|||
)
|
||||
make_and_import_model(make_graph([reduceMean], "reduceMean", [data], [reduced]))
|
||||
|
||||
def test_reduce_sum(self):
|
||||
data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 3, 4])
|
||||
reduced = make_tensor_value_info("reduced", TensorProto.FLOAT, [1, 1, 1, 1])
|
||||
reduceSum = make_node(
|
||||
"ReduceSum", ["data"], ["reduced"], keepdims=1, name="reduceSum"
|
||||
)
|
||||
make_and_import_model(make_graph([reduceSum], "reduceSum", [data], [reduced]))
|
||||
|
||||
def test_slice(self):
|
||||
data = make_tensor_value_info("data", TensorProto.UINT32, [10, 64, 162, 162])
|
||||
output = make_tensor_value_info("output", TensorProto.UINT32, [1, 1, 99, 95])
|
||||
|
@ -426,6 +465,12 @@ class TestStringMethods(unittest.TestCase):
|
|||
split = make_node("Split", ["input"], ["output"], name="split", axis=0)
|
||||
make_and_import_model(make_graph([split], "split", [input], []))
|
||||
|
||||
def test_split1(self):
|
||||
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||
splitAttr = make_tensor_value_info("split", TensorProto.INT64, [2, 1])
|
||||
split = make_node("Split", ["input", "split"], ["output"], name="split", axis=1)
|
||||
make_and_import_model(make_graph([split], "split", [input, splitAttr], []))
|
||||
|
||||
def test_allBroadcast(self):
|
||||
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||
output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||
|
@ -499,6 +544,47 @@ class TestStringMethods(unittest.TestCase):
|
|||
where = make_node("Where", ["x", "y", "con"], ["output"], name="where")
|
||||
make_and_import_model(make_graph([where], "where", [x, y, con], [output]))
|
||||
|
||||
def test_send(self):
|
||||
sendInput = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
send = make_node("Send", ["input"], [], name="send", source=0, destination=1)
|
||||
graph = make_graph([send], "send", [sendInput], [])
|
||||
model = make_model(graph)
|
||||
from_onnx(model, backend.cpu_runtime())
|
||||
|
||||
def test_recv(self):
|
||||
recvOutput = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
recv = make_node(
|
||||
"Recv",
|
||||
[],
|
||||
["output"],
|
||||
name="recv",
|
||||
source=0,
|
||||
destination=1,
|
||||
shape=[1, 3, 5, 7],
|
||||
dataType=1,
|
||||
)
|
||||
graph = make_graph([recv], "recv", [], [recvOutput])
|
||||
model = make_model(graph)
|
||||
from_onnx(model, backend.cpu_runtime())
|
||||
|
||||
|
||||
class TestDynamicTensor(unittest.TestCase):
|
||||
def test_dynamic_tensor(self):
|
||||
filename = r"resnet18-v2-7.onnx"
|
||||
current_path = os.getcwd()
|
||||
model_file = ""
|
||||
for root, dirs, files in os.walk(current_path):
|
||||
if filename in files:
|
||||
model_file = os.path.join(root, filename)
|
||||
|
||||
model = OnnxStub(onnx.load(model_file), backend.cpu_runtime())
|
||||
output_key = list(model.outputs.keys())[0]
|
||||
old_output_shape = model.getShape(output_key)
|
||||
self.assertEqual(old_output_shape, ([1, 1000]))
|
||||
model.set_input([[5, 3, 224, 224]])
|
||||
new_output_shape = model.getShape(output_key)
|
||||
self.assertEqual(new_output_shape, ([5, 1000]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
#include "bang/bang_runtime.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/perf_engine.h"
|
||||
#ifdef INFINI_USE_CNCL
|
||||
#include "bang/cncl_communicator.h"
|
||||
#endif
|
||||
|
||||
namespace infini {
|
||||
|
||||
|
@ -13,19 +16,20 @@ void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
|
|||
std::map<OpType, int> opCnt;
|
||||
for (auto &op : graph->getOperators()) {
|
||||
// HACK: set correct data type
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
|
||||
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
if (!perfData && !tune) {
|
||||
kernel->compute(op, this);
|
||||
this->resetWorkspace();
|
||||
continue;
|
||||
}
|
||||
|
||||
PerfRecord record;
|
||||
if (!perfData) {
|
||||
record = kernel->tune(op, this);
|
||||
this->resetWorkspace();
|
||||
perfEngine.setPerfData(perfKey, record);
|
||||
} else
|
||||
record = perfData;
|
||||
|
@ -36,6 +40,7 @@ void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
|
|||
if (profiling) {
|
||||
double t = timeit([&]() { kernel->compute(op, record, this); },
|
||||
[&]() { sync(); }, 1, 1);
|
||||
this->resetWorkspace();
|
||||
op->print();
|
||||
printf(" op_time on bang %lf\n", t);
|
||||
totalTime += t;
|
||||
|
@ -56,4 +61,15 @@ void BangRuntimeObj::sync() const { cnrtSyncDevice(); }
|
|||
|
||||
string BangRuntimeObj::toString() const { return "BANG Runtime"; }
|
||||
|
||||
void BangRuntimeObj::initComm(const string &name, int worldSize, int rank) {
|
||||
IT_ASSERT(worldSize > 0);
|
||||
IT_ASSERT(rank >= 0);
|
||||
IT_ASSERT(rank < worldSize);
|
||||
IT_ASSERT(!comm) << "communicator is already initialized.";
|
||||
#ifdef INFINI_USE_CNCL
|
||||
comm = std::make_unique<CnclCommunicatorObj>(name, worldSize, rank);
|
||||
#else
|
||||
IT_TODO_HALT_MSG("Not compiled with CNCL.");
|
||||
#endif
|
||||
}
|
||||
} // namespace infini
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
#include "core/graph.h"
|
||||
#include "operators/reshape.h"
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <queue>
|
||||
|
||||
namespace infini {
|
||||
|
@ -9,20 +11,33 @@ GraphObj::GraphObj(Runtime runtime, OpVec ops_in)
|
|||
map<UidBaseType, Tensor> tensorPool;
|
||||
// Clone tensors
|
||||
for (const auto &op : ops_in) {
|
||||
for (const auto &t : op->getInputs())
|
||||
if (tensorPool.find(t->getFuid()) == tensorPool.end())
|
||||
tensorPool[t->getFuid()] = cloneTensor(t);
|
||||
for (const auto &t : op->getOutputs())
|
||||
if (tensorPool.find(t->getFuid()) == tensorPool.end())
|
||||
tensorPool[t->getFuid()] = cloneTensor(t);
|
||||
for (const auto &t : op->getInputs()) {
|
||||
if (t) {
|
||||
if (tensorPool.find(t->getFuid()) == tensorPool.end())
|
||||
tensorPool[t->getFuid()] = cloneTensor(t);
|
||||
}
|
||||
}
|
||||
for (const auto &t : op->getOutputs()) {
|
||||
if (t) {
|
||||
if (tensorPool.find(t->getFuid()) == tensorPool.end())
|
||||
tensorPool[t->getFuid()] = cloneTensor(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Clone operators and add connections
|
||||
for (const auto &op : ops_in) {
|
||||
TensorVec inputs, outputs;
|
||||
for (const auto &t : op->getInputs())
|
||||
inputs.emplace_back(tensorPool.at(t->getFuid()));
|
||||
for (const auto &t : op->getOutputs())
|
||||
outputs.emplace_back(tensorPool.at(t->getFuid()));
|
||||
for (const auto &t : op->getInputs()) {
|
||||
if (t) {
|
||||
inputs.emplace_back(tensorPool.at(t->getFuid()));
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto &t : op->getOutputs()) {
|
||||
if (t) {
|
||||
outputs.emplace_back(tensorPool.at(t->getFuid()));
|
||||
}
|
||||
}
|
||||
addOperatorAndConnect(op->clone(inputs, outputs));
|
||||
}
|
||||
}
|
||||
|
@ -31,17 +46,21 @@ void GraphObj::addOperatorAndConnect(const Operator &op) {
|
|||
sorted = false;
|
||||
ops.push_back(op);
|
||||
for (auto &input : op->getInputs()) {
|
||||
input->addTarget(op);
|
||||
if (auto pred = input->getSource()) {
|
||||
pred->addSuccessors(op);
|
||||
op->addPredecessors(pred);
|
||||
if (input) {
|
||||
input->addTarget(op);
|
||||
if (auto pred = input->getSource()) {
|
||||
pred->addSuccessors(op);
|
||||
op->addPredecessors(pred);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto &output : op->getOutputs()) {
|
||||
output->setSource(op);
|
||||
for (auto &succ : output->getTargets()) {
|
||||
succ->addPredecessors(op);
|
||||
op->addSuccessors(succ);
|
||||
if (output) {
|
||||
output->setSource(op);
|
||||
for (auto &succ : output->getTargets()) {
|
||||
succ->addPredecessors(op);
|
||||
op->addSuccessors(succ);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -68,48 +87,33 @@ string GraphObj::toString() const {
|
|||
}
|
||||
|
||||
bool GraphObj::topo_sort() {
|
||||
if (this->sorted)
|
||||
if (this->sorted) {
|
||||
return true;
|
||||
|
||||
// std::unordered_set<Tensor> inputs;
|
||||
std::unordered_set<Operator> waiting(this->ops.begin(), this->ops.end());
|
||||
}
|
||||
std::vector<Operator> sorted;
|
||||
|
||||
while (!waiting.empty()) {
|
||||
std::unordered_set<OperatorObj *> flags;
|
||||
sorted.reserve(ops.size());
|
||||
flags.reserve(ops.size());
|
||||
while (sorted.size() < ops.size()) {
|
||||
// Any node is move to sorted in this loop.
|
||||
auto modified = false;
|
||||
// Find head nodes.
|
||||
for (auto it = waiting.begin(); it != waiting.end();) {
|
||||
const auto &this_inputs = (*it)->getInputs();
|
||||
// If none of the input tensors is in waiting list,
|
||||
// this node is a head node.
|
||||
const auto is_head = std::all_of(
|
||||
this_inputs.begin(), this_inputs.end(), [&](const auto &input) {
|
||||
auto src = input->getSource();
|
||||
return src // If the source node is in the waiting list,
|
||||
// means that this node is not the head node.
|
||||
? waiting.find(src) == waiting.end()
|
||||
// This tensor has no source node,
|
||||
// it must be a input tensor.
|
||||
: (/*inputs.insert(input),*/ true);
|
||||
});
|
||||
// Moves head node to sorted.
|
||||
if (is_head) {
|
||||
for (auto const &op : ops) {
|
||||
if (auto const &inputs = op->getInputs();
|
||||
flags.find(op.get()) == flags.end() &&
|
||||
std::all_of(inputs.begin(), inputs.end(),
|
||||
[&flags](auto const &input) {
|
||||
auto ptr = input->getSource().get();
|
||||
return !ptr || flags.find(ptr) != flags.end();
|
||||
})) {
|
||||
modified = true;
|
||||
sorted.emplace_back(std::move(*it));
|
||||
it = waiting.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
sorted.emplace_back(op);
|
||||
flags.insert(op.get());
|
||||
}
|
||||
}
|
||||
// If the waiting list was not modified during a pass,
|
||||
// sorting fails.
|
||||
if (!modified) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Done.
|
||||
this->ops = std::move(sorted);
|
||||
return this->sorted = true;
|
||||
}
|
||||
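The reworked topo_sort above drops the erase-from-waiting-set loop in favour of a flags set of already-emitted operators and repeated sweeps over ops. A self-contained sketch of the same idea on a hypothetical Node type (not the project's Operator API):

#include <algorithm>
#include <unordered_set>
#include <vector>

// Hypothetical node: it only knows the nodes it depends on.
struct Node {
    std::vector<Node *> deps;
};

// Repeatedly sweep the node list, emitting every node whose dependencies have
// all been emitted already; returns false if a cycle blocks progress.
bool topoSort(std::vector<Node *> &nodes) {
    std::vector<Node *> sorted;
    std::unordered_set<Node *> emitted;
    sorted.reserve(nodes.size());
    emitted.reserve(nodes.size());
    while (sorted.size() < nodes.size()) {
        bool modified = false;
        for (Node *n : nodes) {
            if (emitted.count(n))
                continue; // already placed in the sorted order
            bool ready = std::all_of(
                n->deps.begin(), n->deps.end(),
                [&](Node *d) { return emitted.count(d) > 0; });
            if (ready) {
                sorted.push_back(n);
                emitted.insert(n);
                modified = true;
            }
        }
        if (!modified)
            return false; // no head node found in a full pass: cycle
    }
    nodes = std::move(sorted);
    return true;
}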
|
@ -123,19 +127,56 @@ void GraphObj::optimize() {
|
|||
}
|
||||
}
|
||||
|
||||
void GraphObj::dataMalloc(bool useNaiveAllocator) {
|
||||
Tensor GraphObj::getTensor(int fuid) const {
|
||||
for (auto tensor : tensors) {
|
||||
if (tensor->getFuid() == fuid) {
|
||||
return tensor;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void GraphObj::shape_infer() {
|
||||
for (auto &op : ops) {
|
||||
auto ans = op->inferShape();
|
||||
IT_ASSERT(ans.has_value());
|
||||
auto oldOutputs = op->getOutputs();
|
||||
IT_ASSERT(ans.value().size() == oldOutputs.size());
|
||||
// replace the old output shape and size with the new ones
|
||||
for (int i = 0; i < (int)ans.value().size(); ++i) {
|
||||
auto newShape = ans.value()[i];
|
||||
auto oldShape = oldOutputs[i]->getDims();
|
||||
auto fuid = oldOutputs[i]->getFuid();
|
||||
if (newShape != oldShape) {
|
||||
auto tensor = this->getTensor(fuid);
|
||||
tensor->setShape(newShape);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GraphObj::dataMalloc(bool useNaiveAllocator, size_t memPoolSize) {
|
||||
// topological sorting first
|
||||
|
||||
IT_ASSERT(topo_sort() == true);
|
||||
if (useNaiveAllocator) {
|
||||
// cannot set a memory pool when using the naive allocator
|
||||
IT_ASSERT(memPoolSize == 0);
|
||||
// used for debugging memory out-of-bounds access, tensors will not be
|
||||
// released correctly
|
||||
// note: behavior may not match running in non-naive mode, and it may
|
||||
// not reproduce the bug
|
||||
for (auto &tensor : tensors) {
|
||||
tensor->dataMalloc();
|
||||
if (!tensor->isWeight() ||
|
||||
(tensor->isWeight() && !weightAllocated)) {
|
||||
tensor->dataMalloc();
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (memPoolSize > 0) {
|
||||
allocator.setMemPool(memPoolSize);
|
||||
}
|
||||
// count the number of times all tensors are used
|
||||
std::unordered_map<TensorObj *, size_t> tensorToRefCount;
|
||||
// record the memory address offsets of all tensors to be allocated
|
||||
|
@ -187,24 +228,28 @@ void GraphObj::dataMalloc(bool useNaiveAllocator) {
|
|||
// memory should be allocated for the op's output first
|
||||
auto outputs = op->getOutputs();
|
||||
for (auto &tensor : outputs) {
|
||||
if (tensor->isOthers()) {
|
||||
tensorToOffset[tensor.get()] =
|
||||
allocator.alloc(tensor->getBytes());
|
||||
if (tensor) {
|
||||
if (tensor->isOthers()) {
|
||||
tensorToOffset[tensor.get()] =
|
||||
allocator.alloc(tensor->getBytes());
|
||||
}
|
||||
}
|
||||
}
|
||||
auto inputs = op->getInputs();
|
||||
for (auto &tensor : inputs) {
|
||||
if (tensor->isOthers()) {
|
||||
auto tensorIter = tensorToRefCount.find(tensor.get());
|
||||
IT_ASSERT(tensorIter != tensorToRefCount.end());
|
||||
IT_ASSERT(tensorToRefCount[tensor.get()] > 0);
|
||||
tensorToRefCount[tensor.get()] -= 1;
|
||||
if (tensorToRefCount[tensor.get()] == 0) {
|
||||
// indicate that this tensor will no longer be used and
|
||||
// perform memory free
|
||||
tensorToRefCount.erase(tensor.get());
|
||||
allocator.free(tensorToOffset[tensor.get()],
|
||||
tensor->getBytes());
|
||||
if (tensor) {
|
||||
if (tensor->isOthers()) {
|
||||
auto tensorIter = tensorToRefCount.find(tensor.get());
|
||||
IT_ASSERT(tensorIter != tensorToRefCount.end());
|
||||
IT_ASSERT(tensorToRefCount[tensor.get()] > 0);
|
||||
tensorToRefCount[tensor.get()] -= 1;
|
||||
if (tensorToRefCount[tensor.get()] == 0) {
|
||||
// indicate that this tensor will no longer be used and
|
||||
// perform memory free
|
||||
tensorToRefCount.erase(tensor.get());
|
||||
allocator.free(tensorToOffset[tensor.get()],
|
||||
tensor->getBytes());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
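dataMalloc above frees an activation's slot as soon as its reference count (the number of consuming operators still to run) reaches zero, so later tensors can reuse the space. A hedged sketch of that bookkeeping with stand-in types; Buf and freeBuf are illustrative, not the real allocator interface:

#include <cstddef>
#include <unordered_map>

struct Buf {
    std::size_t offset, bytes; // where the tensor lives inside the plan
};

// Decrement a tensor's remaining-use counter and release its block once the
// last consumer has run.
void releaseUse(void *tensor,
                std::unordered_map<void *, std::size_t> &refCount,
                std::unordered_map<void *, Buf> &buffers,
                void (*freeBuf)(std::size_t offset, std::size_t bytes)) {
    auto it = refCount.find(tensor);
    if (it == refCount.end() || it->second == 0)
        return;              // not tracked (weight/IO) or already released
    if (--it->second == 0) { // last consumer has run
        const Buf &b = buffers.at(tensor);
        freeBuf(b.offset, b.bytes);
        refCount.erase(it);
    }
}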
|
@ -222,6 +267,27 @@ void GraphObj::dataMalloc(bool useNaiveAllocator) {
|
|||
}
|
||||
}
|
||||
|
||||
Tensor GraphObj::cloneKV(Tensor &tensor) {
|
||||
auto obj = tensor->clone();
|
||||
if (allocator.getMemPoolStatus()) {
|
||||
if (tensor->hasData()) {
|
||||
obj->setDataBlob(make_ref<BlobObj>(
|
||||
tensor->runtime,
|
||||
static_cast<uint8_t *>(allocator.getHeapPtr()) +
|
||||
allocator.heapAlloc(tensor->getBytes())));
|
||||
obj->copyData(tensor);
|
||||
}
|
||||
} else {
|
||||
if (tensor->hasData()) {
|
||||
obj->dataMalloc();
|
||||
obj->copyData(tensor);
|
||||
}
|
||||
}
|
||||
return obj;
|
||||
}
|
||||
|
||||
void GraphObj::freeHeap() { this->allocator.freeHeap(); }
|
||||
|
||||
Tensor GraphObj::addTensor(Shape dim, DataType dtype) {
|
||||
return tensors.emplace_back(make_ref<TensorObj>(dim, dtype, runtime));
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#include "core/graph_handler.h"
|
||||
#include "operators/all_gather.h"
|
||||
#include "operators/all_reduce.h"
|
||||
#include "operators/attention_kvcache.h"
|
||||
#include "operators/batch_norm.h"
|
||||
#include "operators/broadcast.h"
|
||||
#include "operators/concat.h"
|
||||
|
@ -8,17 +9,26 @@
|
|||
#include "operators/element_wise.h"
|
||||
#include "operators/expand.h"
|
||||
#include "operators/gather.h"
|
||||
#include "operators/layer_norm.h"
|
||||
#include "operators/lrn.h"
|
||||
#include "operators/matmul.h"
|
||||
#include "operators/pad.h"
|
||||
#include "operators/pooling.h"
|
||||
#include "operators/reduce_mean.h"
|
||||
#include "operators/recv.h"
|
||||
#include "operators/reduce.h"
|
||||
#include "operators/reshape.h"
|
||||
#include "operators/resize.h"
|
||||
#include "operators/send.h"
|
||||
#include "operators/slice.h"
|
||||
#include "operators/softmax.h"
|
||||
#include "operators/split.h"
|
||||
#include "operators/squeeze.h"
|
||||
#include "operators/transpose.h"
|
||||
#include "operators/unary.h"
|
||||
#include "operators/unsqueeze.h"
|
||||
#include "operators/where.h"
|
||||
#include <numeric>
|
||||
#include <variant>
|
||||
|
||||
namespace infini {
|
||||
|
||||
|
@ -94,6 +104,23 @@ Tensor GraphHandlerObj::batchNormalization(Tensor input, Tensor output,
|
|||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::layerNormalization(Tensor input, Tensor scale,
|
||||
Tensor output, Tensor bias,
|
||||
float eps, int axis,
|
||||
int stash_type) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<LayerNormObj>(std::move(input), std::move(scale),
|
||||
output, std::move(bias), eps, axis,
|
||||
stash_type);
|
||||
return output;
|
||||
} else {
|
||||
return g
|
||||
->addOp<LayerNormObj>(std::move(input), std::move(scale), output,
|
||||
std::move(bias), eps, axis, stash_type)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::maxPool(Tensor input, Tensor output, int kh, int kw,
|
||||
int dh, int dw, int ph, int pw, int sh, int sw,
|
||||
int ceilMode) {
|
||||
|
@ -230,6 +257,64 @@ Tensor GraphHandlerObj::reshape(Tensor data, Tensor reshaped, Shape shape) {
|
|||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::resize(Tensor input, Tensor output,
|
||||
const std::optional<vector<int>> &axes,
|
||||
Tensor sizes, Tensor scales, Tensor roi,
|
||||
vector<uint32_t> sizes_, vector<float> scales_,
|
||||
vector<float> roi_, string mode,
|
||||
string ratioPolicy, string nearestMode,
|
||||
string coordTransMode) {
|
||||
if (sizes_.size() > 0) {
|
||||
sizes->dataMalloc();
|
||||
sizes->copyin<uint32_t>(sizes_);
|
||||
}
|
||||
if (scales_.size() > 0) {
|
||||
scales->dataMalloc();
|
||||
scales->copyin<float>(scales_);
|
||||
}
|
||||
if (roi_.size() > 0) {
|
||||
roi->dataMalloc();
|
||||
roi->copyin<float>(roi_);
|
||||
}
|
||||
ResizeObj::EKeepAspectRatioPolicy ratioPolicy_ =
|
||||
ResizeObj::fromRatioPolicyStr(ratioPolicy);
|
||||
ResizeObj::ENearestMode nearestMode_ =
|
||||
ResizeObj::fromENearestModeStr(nearestMode);
|
||||
ResizeObj::ECoordinateTransMode coordTransMode_ =
|
||||
ResizeObj::fromECoordinateTransModeStr(coordTransMode);
|
||||
ResizeObj::ECoeffMode mode_ = ResizeObj::fromECoeffModeStr(mode);
|
||||
if (output) {
|
||||
if (mode == "nearest") {
|
||||
g->addOpWithOutputs<ResizeObj>(
|
||||
std::move(input), output, std::move(axes), std::move(sizes),
|
||||
std::move(scales), std::move(roi), ratioPolicy_, nearestMode_,
|
||||
coordTransMode_);
|
||||
} else {
|
||||
g->addOpWithOutputs<ResizeObj>(
|
||||
std::move(input), output, std::move(axes), std::move(sizes),
|
||||
std::move(scales), std::move(roi), mode_, ratioPolicy_,
|
||||
coordTransMode_);
|
||||
}
|
||||
return output;
|
||||
} else {
|
||||
if (mode == "nearest") {
|
||||
return g
|
||||
->addOp<ResizeObj>(std::move(input), output, std::move(axes),
|
||||
std::move(sizes), std::move(scales),
|
||||
std::move(roi), ratioPolicy_, nearestMode_,
|
||||
coordTransMode_)
|
||||
->getOutput();
|
||||
} else {
|
||||
return g
|
||||
->addOp<ResizeObj>(std::move(input), output, std::move(axes),
|
||||
std::move(sizes), std::move(scales),
|
||||
std::move(roi), mode_, ratioPolicy_,
|
||||
coordTransMode_)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<ConcatObj>(std::move(inputs), output, dim);
|
||||
|
@ -239,15 +324,51 @@ Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) {
|
|||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache,
|
||||
Tensor input_v_cache, Tensor input_q,
|
||||
Tensor input_k, Tensor input_v,
|
||||
Tensor position_id,
|
||||
Tensor output_matmul) {
|
||||
if (output_matmul) {
|
||||
g->addOpWithOutputs<AttentionKVCacheObj>(
|
||||
std::move(input_k_cache), std::move(input_v_cache),
|
||||
std::move(input_q), std::move(input_k), std::move(input_v),
|
||||
std::move(position_id), output_matmul);
|
||||
return {output_matmul};
|
||||
} else {
|
||||
return g
|
||||
->addOp<AttentionKVCacheObj>(
|
||||
std::move(input_k_cache), std::move(input_v_cache),
|
||||
std::move(input_q), std::move(input_k), std::move(input_v),
|
||||
std::move(position_id), output_matmul)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
TensorVec GraphHandlerObj::split(Tensor input, std::optional<TensorVec> outputs,
|
||||
int axis, int num_outputs) {
|
||||
int axis,
|
||||
std::variant<int, vector<int>> numOrRatio) {
|
||||
if (outputs) {
|
||||
g->addOpWithOutputs<SplitObj>(std::move(input), outputs, axis,
|
||||
num_outputs);
|
||||
if (std::holds_alternative<int>(numOrRatio)) {
|
||||
g->addOpWithOutputs<SplitObj>(std::move(input), outputs, axis,
|
||||
std::get<int>(numOrRatio));
|
||||
} else {
|
||||
g->addOpWithOutputs<SplitObj>(std::move(input), outputs, axis,
|
||||
std::get<vector<int>>(numOrRatio));
|
||||
}
|
||||
return *outputs;
|
||||
} else {
|
||||
return g->addOp<SplitObj>(std::move(input), outputs, axis, num_outputs)
|
||||
->getOutputs();
|
||||
if (std::holds_alternative<int>(numOrRatio)) {
|
||||
return g
|
||||
->addOp<SplitObj>(std::move(input), outputs, axis,
|
||||
std::get<int>(numOrRatio))
|
||||
->getOutputs();
|
||||
} else {
|
||||
return g
|
||||
->addOp<SplitObj>(std::move(input), outputs, axis,
|
||||
std::get<vector<int>>(numOrRatio))
|
||||
->getOutputs();
|
||||
}
|
||||
}
|
||||
}
|
||||
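The split handler now takes a std::variant<int, vector<int>>, so the caller can pass either an output count or an explicit split ratio, and the branches above dispatch on which alternative is held. A small self-contained illustration of that dispatch (the function and messages are made up for the example):

#include <iostream>
#include <variant>
#include <vector>

// Branch on whether the caller supplied a count or a ratio.
void split(std::variant<int, std::vector<int>> numOrRatio) {
    if (std::holds_alternative<int>(numOrRatio)) {
        std::cout << "split into " << std::get<int>(numOrRatio)
                  << " equal parts\n";
    } else {
        const auto &ratio = std::get<std::vector<int>>(numOrRatio);
        std::cout << "split by a ratio of " << ratio.size() << " segments\n";
    }
}

int main() {
    split(4);                         // equal split
    split(std::vector<int>{1, 2, 1}); // ratio split
}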
|
||||
|
@ -279,18 +400,23 @@ Tensor GraphHandlerObj::gatherElements(Tensor data, Tensor indices,
|
|||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::reduceMean(Tensor data, Tensor reduced,
|
||||
const optional<vector<int>> &axes,
|
||||
bool keepdims) {
|
||||
if (reduced) {
|
||||
g->addOpWithOutputs<ReduceMeanObj>(std::move(data), reduced, axes,
|
||||
keepdims);
|
||||
return reduced;
|
||||
} else {
|
||||
return g->addOp<ReduceMeanObj>(std::move(data), reduced, axes, keepdims)
|
||||
->getOutput();
|
||||
#define DEFINE_REDUCE_METHOD(name, obj) \
|
||||
Tensor GraphHandlerObj::name(Tensor data, Tensor reduced, \
|
||||
const optional<vector<int>> &axes, \
|
||||
bool keepdims) { \
|
||||
if (reduced) { \
|
||||
g->addOpWithOutputs<_CAT(obj, Obj)>(std::move(data), reduced, \
|
||||
axes, keepdims); \
|
||||
return reduced; \
|
||||
} else { \
|
||||
return g \
|
||||
->addOp<_CAT(obj, Obj)>(std::move(data), reduced, axes, \
|
||||
keepdims) \
|
||||
->getOutput(); \
|
||||
} \
|
||||
}
|
||||
}
|
||||
DEFINE_REDUCE_METHOD(reduceMean, ReduceMean)
|
||||
DEFINE_REDUCE_METHOD(reduceSum, ReduceSum)
|
||||
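DEFINE_REDUCE_METHOD stamps out reduceMean and reduceSum from a single body, so the two handlers differ only in the operator class they instantiate. A compilable toy version of the same macro-stamping pattern; Handler and reduce_impl are invented for the example, not the project's API:

#include <iostream>
#include <string>

struct Handler {
    // Shared implementation; the generated methods only vary the tag.
    std::string reduce_impl(const std::string &op) { return "run " + op; }

#define DEFINE_REDUCE_METHOD(name, tag)                                        \
    std::string name() { return reduce_impl(#tag); }

    DEFINE_REDUCE_METHOD(reduceMean, ReduceMean)
    DEFINE_REDUCE_METHOD(reduceSum, ReduceSum)
#undef DEFINE_REDUCE_METHOD
};

int main() {
    Handler h;
    std::cout << h.reduceMean() << "\n" << h.reduceSum() << "\n";
}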
|
||||
Tensor GraphHandlerObj::slice(Tensor input, Tensor output,
|
||||
const vector<int> &starts,
|
||||
|
@ -388,6 +514,39 @@ Tensor GraphHandlerObj::broadcast(Tensor input, Tensor output, int root) {
|
|||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::send(Tensor input, int source, int destination,
|
||||
Tensor output) {
|
||||
if (output) {
|
||||
|
||||
g->addOpWithOutputs<SendObj>(std::move(input), source, destination,
|
||||
output);
|
||||
|
||||
return output;
|
||||
} else {
|
||||
return g->addOp<SendObj>(std::move(input), source, destination, output)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::recv(Tensor output, int source, int destination,
|
||||
Shape dims, int outputType, Tensor input) {
|
||||
|
||||
if (output) {
|
||||
|
||||
g->addOpWithOutputs<RecvObj>(output, source, destination,
|
||||
std::move(dims), outputType,
|
||||
std::move(input));
|
||||
|
||||
return output;
|
||||
} else {
|
||||
|
||||
return g
|
||||
->addOp<RecvObj>(output, source, destination, std::move(dims),
|
||||
outputType, std::move(input))
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::cast(Tensor input, Tensor output, int to) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<CastObj>(std::move(input), output,
|
||||
|
@ -425,6 +584,54 @@ Tensor GraphHandlerObj::where(Tensor inputX, Tensor inputY, Tensor condition,
|
|||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::depthToSpace(Tensor input, Tensor output, int blocksize,
|
||||
std::string mode) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<DepthToSpaceObj>(std::move(input), output,
|
||||
blocksize, mode);
|
||||
return output;
|
||||
} else {
|
||||
return g
|
||||
->addOp<DepthToSpaceObj>(std::move(input), output, blocksize, mode)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::lrn(Tensor input, Tensor output, float alpha,
|
||||
float beta, float bias, int size) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<LRNObj>(std::move(input), output, alpha, beta, bias,
|
||||
size);
|
||||
return output;
|
||||
} else {
|
||||
return g
|
||||
->addOp<LRNObj>(std::move(input), output, alpha, beta, bias, size)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::squeeze(Tensor input, Tensor output, Shape axes) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<SqueezeObj>(std::move(input), output,
|
||||
std::move(axes));
|
||||
return output;
|
||||
} else {
|
||||
return g->addOp<SqueezeObj>(std::move(input), output, std::move(axes))
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::unsqueeze(Tensor input, Tensor output, Shape axes) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<UnsqueezeObj>(std::move(input), output,
|
||||
std::move(axes));
|
||||
return output;
|
||||
} else {
|
||||
return g->addOp<UnsqueezeObj>(std::move(input), output, std::move(axes))
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
static CastType inferCastType(Tensor input, int to) {
|
||||
auto iType = input->getDType();
|
||||
auto oType = DataType(to);
|
||||
|
@ -520,4 +727,11 @@ static DataType dtype_repr_convert(int dtype) {
|
|||
}
|
||||
}
|
||||
|
||||
void GraphHandlerObj::change_shape(const vector<int> &shape, int tensorId) {
|
||||
auto tensor = g->getTensor(tensorId);
|
||||
IT_ASSERT(tensor != nullptr);
|
||||
IT_ASSERT(shape.size() != 0);
|
||||
tensor->setShape(shape);
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -30,6 +30,9 @@ LazyAllocator::~LazyAllocator() {
|
|||
if (this->weightPtr != nullptr) {
|
||||
runtime->dealloc(this->weightPtr);
|
||||
}
|
||||
if (this->memPoolPtr != nullptr) {
|
||||
runtime->dealloc(this->memPoolPtr);
|
||||
}
|
||||
}
|
||||
|
||||
void LazyAllocator::init() {
|
||||
|
@ -44,6 +47,17 @@ void LazyAllocator::init() {
|
|||
this->ptr = nullptr;
|
||||
}
|
||||
|
||||
void LazyAllocator::setMemPool(size_t memPoolSize) {
|
||||
IT_ASSERT(memPoolSize > 0);
|
||||
if (!this->hasMemPool) {
|
||||
this->hasMemPool = true;
|
||||
this->memPoolSize = memPoolSize;
|
||||
this->memPoolPtr = runtime->alloc(memPoolSize);
|
||||
}
|
||||
}
|
||||
|
||||
bool LazyAllocator::getMemPoolStatus() { return this->hasMemPool; }
|
||||
|
||||
size_t LazyAllocator::alloc(size_t size) {
|
||||
// pad the size to a multiple of the alignment
|
||||
size = this->getAlignedSize(size);
|
||||
|
@ -102,6 +116,17 @@ size_t LazyAllocator::allocWeight(size_t size) {
|
|||
return retAddr;
|
||||
}
|
||||
|
||||
size_t LazyAllocator::heapAlloc(size_t size) {
|
||||
size = this->getAlignedSize(size);
|
||||
this->heapPeak += size;
|
||||
IT_ASSERT(this->memPoolSize >=
|
||||
this->weightPeak + this->peak + this->heapPeak);
|
||||
size_t retAddr = this->memPoolSize - this->heapPeak;
|
||||
return retAddr;
|
||||
}
|
||||
|
||||
void LazyAllocator::freeHeap() { this->heapPeak = 0; }
|
||||
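heapAlloc above carves temporary blocks from the top of a fixed memory pool: weights and activations grow up from offset 0 while the heap peak grows down from memPoolSize, and freeHeap simply resets that peak. A standalone sketch of the arithmetic under assumed names (PoolState and a 64-byte alignment), not the real LazyAllocator:

#include <cassert>
#include <cstddef>

struct PoolState {
    std::size_t poolSize   = 0; // total bytes in the pool
    std::size_t weightPeak = 0; // bytes used by weights (bottom of the pool)
    std::size_t actPeak    = 0; // bytes used by activations (above weights)
    std::size_t heapPeak   = 0; // bytes handed out from the top so far
};

// Round a request up to the allocation granularity.
std::size_t alignUp(std::size_t size, std::size_t alignment) {
    return (size + alignment - 1) / alignment * alignment;
}

// Returns the offset (from the pool base) of a freshly reserved block; each
// new block sits just below the previous one, growing toward the activations.
std::size_t heapAlloc(PoolState &s, std::size_t size) {
    size = alignUp(size, 64);
    s.heapPeak += size;
    assert(s.poolSize >= s.weightPeak + s.actPeak + s.heapPeak);
    return s.poolSize - s.heapPeak;
}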
|
||||
void LazyAllocator::free(size_t addr, size_t size) {
|
||||
IT_ASSERT(this->ptr == nullptr);
|
||||
size = getAlignedSize(size);
|
||||
|
@ -143,25 +168,40 @@ void LazyAllocator::free(size_t addr, size_t size) {
|
|||
}
|
||||
|
||||
void *LazyAllocator::getPtr() {
|
||||
if (this->ptr == nullptr) {
|
||||
this->ptr = runtime->alloc(this->peak);
|
||||
// #ifdef DEBUG_MODE
|
||||
// printf("LazyAllocator really alloc non-weight: %p %lu
|
||||
// bytes\n", this->ptr, peak);
|
||||
// #endif
|
||||
if (!hasMemPool) {
|
||||
if (this->ptr == nullptr) {
|
||||
this->ptr = runtime->alloc(this->peak);
|
||||
// #ifdef DEBUG_MODE
|
||||
// printf("LazyAllocator really alloc non-weight: %p %lu
|
||||
// bytes\n", this->ptr, peak);
|
||||
// #endif
|
||||
}
|
||||
return this->ptr;
|
||||
} else {
|
||||
IT_ASSERT(this->memPoolSize >= this->weightPeak + this->peak);
|
||||
return static_cast<uint8_t *>(this->memPoolPtr) + weightPeak;
|
||||
}
|
||||
return this->ptr;
|
||||
}
|
||||
|
||||
void *LazyAllocator::getWeightPtr() {
|
||||
if (this->weightPtr == nullptr) {
|
||||
this->weightPtr = runtime->alloc(this->weightPeak);
|
||||
// #ifdef DEBUG_MODE
|
||||
// printf("LazyAllocator really alloc weight: %p %lu bytes\n",
|
||||
// this->weightPtr, weightPeak);
|
||||
// #endif
|
||||
if (!hasMemPool) {
|
||||
if (this->weightPtr == nullptr) {
|
||||
this->weightPtr = runtime->alloc(this->weightPeak);
|
||||
// #ifdef DEBUG_MODE
|
||||
// printf("LazyAllocator really alloc weight: %p %lu
|
||||
// bytes\n",
|
||||
// this->weightPtr, weightPeak);
|
||||
// #endif
|
||||
}
|
||||
return this->weightPtr;
|
||||
} else {
|
||||
return this->memPoolPtr;
|
||||
}
|
||||
return this->weightPtr;
|
||||
}
|
||||
|
||||
void *LazyAllocator::getHeapPtr() {
|
||||
IT_ASSERT(hasMemPool);
|
||||
return this->memPoolPtr;
|
||||
}
|
||||
|
||||
size_t LazyAllocator::getAlignedSize(size_t size) {
|
||||
|
|
|
@ -6,8 +6,10 @@ namespace infini {
|
|||
|
||||
OperatorObj::OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs)
|
||||
: type(opType), inputs(inputs), outputs(outputs) {
|
||||
for (const auto &t : inputs)
|
||||
IT_ASSERT(t);
|
||||
if (opType != OpType::Recv) {
|
||||
for (const auto &t : inputs)
|
||||
IT_ASSERT(t);
|
||||
}
|
||||
}
|
||||
|
||||
void OperatorObj::removePredecessors(const Operator &op) {
|
||||
|
@ -77,9 +79,7 @@ bool OperatorObj::checkValid(GraphObj *graph) {
|
|||
return true;
|
||||
}
|
||||
|
||||
optional<vector<Shape>> OperatorObj::inferShape() const {
|
||||
return inferShape(inputs);
|
||||
}
|
||||
optional<vector<Shape>> OperatorObj::inferShape() { return inferShape(inputs); }
|
||||
|
||||
vector<DataType> OperatorObj::inferDataType(const TensorVec &inputs) const {
|
||||
auto dataType = inputs[0]->getDType();
|
||||
|
|
|
@ -17,8 +17,7 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
|
|||
std::map<OpType, int> opCnt;
|
||||
|
||||
for (auto &op : graph->getOperators()) {
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
|
||||
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
|
@ -66,8 +65,7 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const {
|
|||
std::map<OpType, int> opCnt;
|
||||
|
||||
for (auto &op : graph->getOperators()) {
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
|
||||
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
|
|
|
@ -59,6 +59,13 @@ Shape TensorObj::getStride() const {
|
|||
return stride;
|
||||
}
|
||||
|
||||
void TensorObj::setShape(Shape shape_) {
|
||||
shape = shape_;
|
||||
size_t size = std::accumulate(shape.begin(), shape.end(), 1,
|
||||
[](auto acc, auto x) { return acc * x; });
|
||||
_size = size;
|
||||
}
|
||||
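setShape above recomputes the element count as the product of the new dimensions via std::accumulate. The same computation in isolation, using a size_t accumulator to avoid int overflow on large shapes:

#include <cstddef>
#include <numeric>
#include <vector>

// Number of elements implied by a shape: the product of all dimensions.
std::size_t elementCount(const std::vector<int> &shape) {
    return std::accumulate(shape.begin(), shape.end(), std::size_t{1},
                           [](std::size_t acc, int d) { return acc * d; });
}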
|
||||
void TensorObj::printData() const {
|
||||
IT_ASSERT(data != nullptr);
|
||||
if (!runtime->isCpu())
|
||||
|
|
|
@ -25,8 +25,7 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
|
|||
auto &perfEngine = PerfEngine::getInstance();
|
||||
for (auto &op : graph->getOperators()) {
|
||||
// HACK: set correct data type
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
|
||||
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
|
@ -48,8 +47,7 @@ void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const {
|
|||
std::map<OpType, int> opCnt;
|
||||
for (auto &op : graph->getOperators()) {
|
||||
// HACK: set correct data type
|
||||
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying(),
|
||||
DataType::Float32};
|
||||
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
#include "core/data_type.h"
|
||||
#include "cuda/cuda_common.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include <cstdio>
|
||||
|
||||
__global__ void cudaPrintFloatImpl(float *x, int len) {
|
||||
|
@ -18,4 +20,55 @@ void cudaPrintFloat(float *x, int len) {
|
|||
cudaDeviceSynchronize();
|
||||
}
|
||||
|
||||
void cudaPrintTensor(const Tensor &tensor) {
|
||||
cudaPrintFloat(tensor->getRawDataPtr<float *>(), tensor->size());
|
||||
}
|
||||
|
||||
cudnnDataType_t cudnnDataTypeConvert(DataType dataType) {
|
||||
if (dataType == DataType::Float32) {
|
||||
return CUDNN_DATA_FLOAT;
|
||||
}
|
||||
if (dataType == DataType::Double) {
|
||||
return CUDNN_DATA_DOUBLE;
|
||||
}
|
||||
if (dataType == DataType::Float16) {
|
||||
return CUDNN_DATA_HALF;
|
||||
}
|
||||
if (dataType == DataType::Int8) {
|
||||
return CUDNN_DATA_INT8;
|
||||
}
|
||||
if (dataType == DataType::Int32) {
|
||||
return CUDNN_DATA_INT32;
|
||||
}
|
||||
if (dataType == DataType::UInt8) {
|
||||
return CUDNN_DATA_UINT8;
|
||||
}
|
||||
if (dataType == DataType::BFloat16) {
|
||||
return CUDNN_DATA_BFLOAT16;
|
||||
}
|
||||
if (dataType == DataType::Int64) {
|
||||
return CUDNN_DATA_INT64;
|
||||
}
|
||||
if (dataType == DataType::Bool) {
|
||||
return CUDNN_DATA_BOOLEAN;
|
||||
}
|
||||
IT_ASSERT(false, "Unsupported data type");
|
||||
}
|
||||
|
||||
cudaDataType cublasDataTypeConvert(DataType dataType) {
|
||||
switch (dataType.getIndex()) {
|
||||
case 1:
|
||||
return CUDA_R_32F;
|
||||
// case 3:
|
||||
// return CUDA_R_8I;
|
||||
case 10:
|
||||
return CUDA_R_16F;
|
||||
case 11:
|
||||
return CUDA_R_64F;
|
||||
// case 16:
|
||||
// return CUDA_R_16BF;
|
||||
default:
|
||||
IT_ASSERT(false, "MatMul Unsupported data type");
|
||||
}
|
||||
}
|
||||
} // namespace infini
|
||||
|
|
|
@ -5,14 +5,17 @@
|
|||
#include "operators/conv.h"
|
||||
#include "operators/expand.h"
|
||||
#include "operators/gather.h"
|
||||
#include "operators/lrn.h"
|
||||
#include "operators/matmul.h"
|
||||
#include "operators/pad.h"
|
||||
#include "operators/pooling.h"
|
||||
#include "operators/reduce_mean.h"
|
||||
#include "operators/reduce.h"
|
||||
#include "operators/reshape.h"
|
||||
#include "operators/split.h"
|
||||
#include "operators/squeeze.h"
|
||||
#include "operators/transpose.h"
|
||||
#include "operators/unary.h"
|
||||
#include "operators/unsqueeze.h"
|
||||
#include <algorithm>
|
||||
#include <pybind11/numpy.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
|
@ -93,7 +96,10 @@ void export_values(py::module &m) {
|
|||
.VALUE(OpType, Gather)
|
||||
.VALUE(OpType, GatherElements)
|
||||
.VALUE(OpType, ReduceMean)
|
||||
.VALUE(OpType, ReduceSum)
|
||||
.VALUE(OpType, Reshape)
|
||||
.VALUE(OpType, Squeeze)
|
||||
.VALUE(OpType, Unsqueeze)
|
||||
.VALUE(OpType, Flatten)
|
||||
.VALUE(OpType, Identity)
|
||||
.VALUE(OpType, BatchNormalization)
|
||||
|
@ -114,6 +120,8 @@ void export_values(py::module &m) {
|
|||
.VALUE(OpType, Expand)
|
||||
.VALUE(OpType, Erf)
|
||||
.VALUE(OpType, Where)
|
||||
.VALUE(OpType, DepthToSpace)
|
||||
.VALUE(OpType, LRN)
|
||||
.export_values();
|
||||
|
||||
#undef VALUE
|
||||
|
@ -227,12 +235,13 @@ clip_attrs_of(Operator op) {
|
|||
return std::make_tuple(clip->getMin(), clip->getMax());
|
||||
}
|
||||
|
||||
static std::tuple<vector<int>, bool> reduce_mean_attrs_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::ReduceMean);
|
||||
auto reduce_mean = dynamic_cast<const ReduceMeanObj *>(op.get());
|
||||
auto &set = reduce_mean->getAxes();
|
||||
static std::tuple<vector<int>, bool> reduce_attrs_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::ReduceMean ||
|
||||
op->getOpType() == OpType::ReduceSum);
|
||||
auto reduce = dynamic_cast<const ReduceBaseObj *>(op.get());
|
||||
auto &set = reduce->getAxes();
|
||||
return std::make_tuple(vector(set.begin(), set.end()),
|
||||
reduce_mean->getKeepDims());
|
||||
reduce->getKeepDims());
|
||||
}
|
||||
|
||||
static int concat_axis_of(Operator op) {
|
||||
|
@ -260,6 +269,24 @@ static vector<int64_t> reshape_shape_of(Operator op) {
|
|||
return ans;
|
||||
}
|
||||
|
||||
static vector<int64_t> squeeze_axes_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Squeeze);
|
||||
auto axes = dynamic_cast<const SqueezeObj *>(op.get())->getAxes();
|
||||
vector<int64_t> ans(axes.size());
|
||||
std::transform(axes.begin(), axes.end(), ans.begin(),
|
||||
[](auto x) { return static_cast<int64_t>(x); });
|
||||
return ans;
|
||||
}
|
||||
|
||||
static vector<int64_t> unsqueeze_axes_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Unsqueeze);
|
||||
auto axes = dynamic_cast<const UnsqueezeObj *>(op.get())->getAxes();
|
||||
vector<int64_t> ans(axes.size());
|
||||
std::transform(axes.begin(), axes.end(), ans.begin(),
|
||||
[](auto x) { return static_cast<int64_t>(x); });
|
||||
return ans;
|
||||
}
|
||||
|
||||
static vector<int64_t> expand_shape_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Expand);
|
||||
auto shape = dynamic_cast<const ExpandObj *>(op.get())->getShape();
|
||||
|
@ -295,6 +322,21 @@ static int cast_to_of(Operator op) {
|
|||
return castOutputDtype.getIndex();
|
||||
}
|
||||
|
||||
static std::tuple<int, std::string> depth_to_space_attrs_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::DepthToSpace);
|
||||
auto depth_to_space = dynamic_cast<const DepthToSpaceObj *>(op.get());
|
||||
return std::make_tuple(depth_to_space->getBlockSize(),
|
||||
depth_to_space->getModeString());
|
||||
}
|
||||
|
||||
static std::tuple<float, float, float, int> lrn_attrs_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::LRN);
|
||||
auto lrn = dynamic_cast<const LRNObj *>(op.get());
|
||||
auto [alpha, beta, bias] = lrn->getAlphaBetaBias();
|
||||
auto size = lrn->getSize();
|
||||
return std::make_tuple(alpha, beta, bias, size);
|
||||
}
|
||||
|
||||
void export_functions(py::module &m) {
|
||||
#define FUNCTION(NAME) def(#NAME, &NAME)
|
||||
m.def("cpu_runtime", &NativeCpuRuntimeObj::getInstance)
|
||||
|
@ -324,7 +366,7 @@ void export_functions(py::module &m) {
|
|||
.FUNCTION(batch_norm_attrs_of)
|
||||
.FUNCTION(pool_attrs_of)
|
||||
.FUNCTION(clip_attrs_of)
|
||||
.FUNCTION(reduce_mean_attrs_of)
|
||||
.FUNCTION(reduce_attrs_of)
|
||||
.FUNCTION(tensor_dtype)
|
||||
.FUNCTION(reshape_shape_of)
|
||||
.FUNCTION(expand_shape_of)
|
||||
|
@ -334,7 +376,11 @@ void export_functions(py::module &m) {
|
|||
.FUNCTION(split_axis_of)
|
||||
.FUNCTION(gather_axis_of)
|
||||
.FUNCTION(flatten_axis_of)
|
||||
.FUNCTION(cast_to_of);
|
||||
.FUNCTION(cast_to_of)
|
||||
.FUNCTION(depth_to_space_attrs_of)
|
||||
.FUNCTION(squeeze_axes_of)
|
||||
.FUNCTION(unsqueeze_axes_of)
|
||||
.FUNCTION(lrn_attrs_of);
|
||||
#undef FUNCTION
|
||||
}
|
||||
|
||||
|
@ -390,7 +436,9 @@ void init_graph_builder(py::module &m) {
|
|||
#endif
|
||||
#ifdef USE_BANG
|
||||
py::class_<BangRuntimeObj, std::shared_ptr<BangRuntimeObj>, RuntimeObj>(
|
||||
m, "BangRuntime");
|
||||
m, "BangRuntime")
|
||||
.def(py::init<int>(), py::arg("device") = 0)
|
||||
.def("init_comm", &BangRuntimeObj::initComm);
|
||||
#endif
|
||||
#ifdef USE_KUNLUN
|
||||
py::class_<KUNLUNRuntimeObj, std::shared_ptr<KUNLUNRuntimeObj>, RuntimeObj>(
|
||||
|
@ -455,7 +503,10 @@ void init_graph_builder(py::module &m) {
|
|||
})
|
||||
.def("has_target", &TensorObj::hasTarget, policy::automatic)
|
||||
.def("src", &TensorObj::getSource, policy::move)
|
||||
.def("printData", &TensorObj::printData, policy::automatic);
|
||||
.def("printData", &TensorObj::printData, policy::automatic)
|
||||
.def("copy_data",
|
||||
py::overload_cast<const Tensor &>(&TensorObj::copyData),
|
||||
policy::move);
|
||||
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
|
||||
.def("op_type", &OperatorObj::getOpType, policy::automatic)
|
||||
.def("inputs", py::overload_cast<>(&OperatorObj::getInputs, py::const_),
|
||||
|
@ -470,11 +521,13 @@ void init_graph_builder(py::module &m) {
|
|||
.def("convTransposed2d", &Handler::convTransposed2d, policy::move)
|
||||
.def("matmul", &Handler::matmul, policy::move)
|
||||
.def("batchNormalization", &Handler::batchNormalization, policy::move)
|
||||
.def("layerNormalization", &Handler::layerNormalization, policy::move)
|
||||
.def("maxPool", &Handler::maxPool, policy::move)
|
||||
.def("avgPool", &Handler::avgPool, policy::move)
|
||||
.def("add", &Handler::add, policy::move)
|
||||
.def("sub", &Handler::sub, policy::move)
|
||||
.def("mul", &Handler::mul, policy::move)
|
||||
.def("max", &Handler::max, policy::move)
|
||||
.def("div", &Handler::div, policy::move)
|
||||
.def("pow", &Handler::pow, policy::move)
|
||||
.def("min", &Handler::min, policy::move)
|
||||
|
@ -495,12 +548,18 @@ void init_graph_builder(py::module &m) {
|
|||
.def("pRelu", &Handler::pRelu, policy::move)
|
||||
.def("clip", &Handler::clip, policy::move)
|
||||
.def("transpose", &Handler::transpose, policy::move)
|
||||
.def("depthToSpace", &Handler::depthToSpace, policy::move)
|
||||
.def("reshape", &Handler::reshape, policy::move)
|
||||
.def("resize", &Handler::resize, policy::move)
|
||||
.def("squeeze", &Handler::squeeze, policy::move)
|
||||
.def("unsqueeze", &Handler::unsqueeze, policy::move)
|
||||
.def("concat", &Handler::concat, policy::move)
|
||||
.def("attentionKVCache", &Handler::attentionKVCache, policy::move)
|
||||
.def("split", &Handler::split, policy::move)
|
||||
.def("gather", &Handler::gather, policy::move)
|
||||
.def("gatherElements", &Handler::gatherElements, policy::move)
|
||||
.def("reduce_mean", &Handler::reduceMean, policy::move)
|
||||
.def("reduceMean", &Handler::reduceMean, policy::move)
|
||||
.def("reduceSum", &Handler::reduceSum, policy::move)
|
||||
.def("slice", &Handler::slice, policy::move)
|
||||
.def("pad", &Handler::pad, policy::move)
|
||||
.def("allReduceSum", &Handler::allReduceSum, policy::move)
|
||||
|
@ -510,17 +569,27 @@ void init_graph_builder(py::module &m) {
|
|||
.def("allReduceAvg", &Handler::allReduceAvg, policy::move)
|
||||
.def("allGather", &Handler::allGather, policy::move)
|
||||
.def("broadcast", &Handler::broadcast, policy::move)
|
||||
.def("send", &Handler::send, policy::move)
|
||||
.def("recv", &Handler::recv, policy::move)
|
||||
.def("cast", &Handler::cast, policy::move)
|
||||
.def("expand", &Handler::expand, policy::move)
|
||||
.def("erf", &Handler::erf, policy::move)
|
||||
.def("where", &Handler::where, policy::move)
|
||||
.def("lrn", &Handler::lrn, policy::move)
|
||||
.def("topo_sort", &Handler::topo_sort, policy::automatic)
|
||||
.def("optimize", &Handler::optimize, policy::automatic)
|
||||
.def("operators", &Handler::operators, policy::move)
|
||||
.def("data_malloc", &Handler::data_malloc, policy::automatic)
|
||||
.def("data_malloc", &Handler::data_malloc,
|
||||
py::arg("useNaiveAllocator") = false, py::arg("memPoolSize") = 0,
|
||||
policy::automatic)
|
||||
.def("clone_KV", &Handler::clone_KV, policy::move)
|
||||
.def("free_heap", &Handler::free_heap, policy::move)
|
||||
.def("get_perf_time", &Handler::get_perf_time, policy::automatic)
|
||||
.def("tune", &Handler::tune, policy::automatic)
|
||||
.def("run", &Handler::run, policy::automatic)
|
||||
.def("shape_infer", &Handler::shape_infer, policy::automatic)
|
||||
.def("change_shape", &Handler::change_shape, policy::automatic)
|
||||
.def("getDims", &Handler::getDims, policy::automatic)
|
||||
.def("get_perf_time", &Handler::get_perf_time, policy::automatic);
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#include "bang/bang_kernel_without_config.h"
|
||||
#include "bang/bang_runtime.h"
|
||||
#include "operators/softmax.h"
|
||||
#include "operators/unary.h"
|
||||
|
||||
namespace infini {
|
||||
|
@ -10,6 +11,7 @@ class UnaryCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -29,8 +31,9 @@ class UnaryCnnl : public BangKernelWithoutConfig {
|
|||
cDim.data()));
|
||||
cnnlActivationDescriptor_t opDesc;
|
||||
checkCnnlError(cnnlCreateActivationDescriptor(&opDesc));
|
||||
checkCnnlError(cnnlSetActivationDescriptor(
|
||||
opDesc, getOpType(), CNNL_NOT_PROPAGATE_NAN, getCoef()));
|
||||
checkCnnlError(cnnlSetActivationDescriptor_v2(
|
||||
opDesc, getOpType(), CNNL_ACTIVATION_HIGH_PRECISION,
|
||||
CNNL_NOT_PROPAGATE_NAN, getCoef()));
|
||||
|
||||
auto [alpha, beta] = getAlphBeta();
|
||||
cnnlStatus_t stat =
|
||||
|
@ -48,6 +51,7 @@ class RoundCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -78,6 +82,7 @@ class PReluCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<PReluObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -113,6 +118,93 @@ class PReluCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
class SoftmaxCnnl : public BangKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<SoftmaxObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
cnnlTensorDescriptor_t aDesc, cDesc;
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
|
||||
cnnlSoftmaxMode_t mode;
|
||||
size_t axis = op->getAxis();
|
||||
std::vector<int> inDim = {1, 1, 1};
|
||||
std::vector<int> outDim = inDim;
|
||||
|
||||
if (aDim.size() >= 3) {
|
||||
if (axis == 0) {
|
||||
mode = CNNL_SOFTMAX_MODE_HIGH_DIMENSION;
|
||||
inDim[0] = aDim[0];
|
||||
inDim[1] = aDim[1];
|
||||
for (size_t i = 2; i < aDim.size(); ++i) {
|
||||
inDim[2] *= aDim[i];
|
||||
}
|
||||
outDim = inDim;
|
||||
} else if (axis == aDim.size() - 1) {
|
||||
mode = CNNL_SOFTMAX_MODE_LOW_DIMENSION;
|
||||
inDim[0] = aDim[0];
|
||||
for (size_t i = 1; i < axis; ++i) {
|
||||
inDim[1] *= aDim[i];
|
||||
}
|
||||
inDim[2] = aDim[axis];
|
||||
outDim = inDim;
|
||||
} else {
|
||||
mode = CNNL_SOFTMAX_MODE_MEDIUM_DIMENSION;
|
||||
for (size_t i = 0; i < axis; ++i) {
|
||||
inDim[0] *= aDim[i];
|
||||
}
|
||||
inDim[1] = aDim[axis];
|
||||
for (size_t i = axis + 1; i < aDim.size(); ++i) {
|
||||
inDim[2] *= aDim[i];
|
||||
}
|
||||
outDim = inDim;
|
||||
}
|
||||
} else if (aDim.size() == 2) {
|
||||
if (axis == 0) {
|
||||
mode = CNNL_SOFTMAX_MODE_HIGH_DIMENSION;
|
||||
inDim = aDim;
|
||||
inDim.push_back(1);
|
||||
outDim = inDim;
|
||||
} else {
|
||||
mode = CNNL_SOFTMAX_MODE_LOW_DIMENSION;
|
||||
inDim = aDim;
|
||||
inDim.insert(inDim.begin(), 1);
|
||||
outDim = inDim;
|
||||
}
|
||||
} else {
|
||||
mode = CNNL_SOFTMAX_MODE_HIGH_DIMENSION;
|
||||
inDim = aDim;
|
||||
inDim.push_back(1);
|
||||
inDim.push_back(1);
|
||||
outDim = inDim;
|
||||
}
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_FLOAT, inDim.size(),
|
||||
inDim.data()));
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_FLOAT, outDim.size(),
|
||||
outDim.data()));
|
||||
float alpha = 1.0;
|
||||
float beta = 0.0;
|
||||
cnnlStatus_t stat =
|
||||
cnnlSoftmaxForward_v2(context->cnnlHandle(), CNNL_SOFTMAX_ACCURATE,
|
||||
mode, CNNL_COMPUTATION_ULTRAHIGH_PRECISION,
|
||||
&alpha, aDesc, aData, &beta, cDesc, cData);
|
||||
if (stat != CNNL_STATUS_SUCCESS)
|
||||
return;
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
|
||||
}
|
||||
};
|
||||
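SoftmaxCnnl collapses an arbitrary-rank input into a three-dimensional {outer, axis, inner} view so cnnl only ever sees one of its three softmax modes; the branches above cover the axis-first, axis-last, and middle cases. A sketch of the middle-axis collapse on plain vectors (collapseTo3D is illustrative, not part of the kernel):

#include <array>
#include <cstddef>
#include <iostream>
#include <vector>

// View an N-D shape as {outer, axis, inner} around the softmax axis.
std::array<int, 3> collapseTo3D(const std::vector<int> &dims,
                                std::size_t axis) {
    std::array<int, 3> out = {1, 1, 1};
    for (std::size_t i = 0; i < axis; ++i)
        out[0] *= dims[i]; // everything before the softmax axis
    out[1] = dims[axis];   // the softmax axis itself
    for (std::size_t i = axis + 1; i < dims.size(); ++i)
        out[2] *= dims[i]; // everything after the axis
    return out;
}

int main() {
    auto d = collapseTo3D({2, 3, 4, 5}, 2); // softmax over the third axis
    std::cout << d[0] << " " << d[1] << " " << d[2] << "\n"; // prints: 6 4 5
}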
|
||||
class ReluCnnl : public UnaryCnnl {
|
||||
cnnlActivationMode_t getOpType() const override {
|
||||
return CNNL_ACTIVATION_RELU;
|
||||
|
@ -127,13 +219,12 @@ class SigmoidCnnl : public UnaryCnnl {
|
|||
float getCoef() const override { return 0.0; }
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Relu, DataType::Float32, ReluCnnl,
|
||||
"Relu_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::PRelu, DataType::Float32, PReluCnnl,
|
||||
"PRelu_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, DataType::Float32, SigmoidCnnl,
|
||||
"Sigmoid_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Round, DataType::Float32, RoundCnnl,
|
||||
"Round_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Relu, ReluCnnl, "Relu_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::PRelu, PReluCnnl, "PRelu_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, SigmoidCnnl,
|
||||
"Sigmoid_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Round, RoundCnnl, "Round_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Softmax, SoftmaxCnnl,
|
||||
"Softmax_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -10,6 +10,7 @@ class ActivationBackwardCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ActivationBackwardObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const yData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -81,11 +82,11 @@ class TanhBackwardCnnl : public ActivationBackwardCnnl {
|
|||
float getCoef() const override { return 0.0; }
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::ReluBackward, DataType::Float32,
|
||||
ReluBackwardCnnl, "ReluBackward_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::SigmoidBackward, DataType::Float32,
|
||||
SigmoidBackwardCnnl, "SigmoidBackward_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::TanhBackward, DataType::Float32,
|
||||
TanhBackwardCnnl, "TanhBackward_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::ReluBackward, ReluBackwardCnnl,
|
||||
"ReluBackward_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::SigmoidBackward, SigmoidBackwardCnnl,
|
||||
"SigmoidBackward_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::TanhBackward, TanhBackwardCnnl,
|
||||
"TanhBackward_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
#ifdef INFINI_USE_CNCL
|
||||
#include "operators/all_gather.h"
|
||||
#include "bang/bang_kernel_without_config.h"
|
||||
#include "bang/bang_runtime.h"
|
||||
#include "bang/cncl_communicator.h"
|
||||
#include <thread>
|
||||
namespace infini {
|
||||
class AllGatherCNCL : public BangKernelWithoutConfig {
|
||||
public:
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<AllGatherObj>(_op);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
int world_size = op->getWorldSize();
|
||||
// Check if world size info in operator matches runtime
|
||||
IT_ASSERT(world_size == context->getCommunicator().getWorldSize());
|
||||
|
||||
void *input = op->getInputs(0)->getRawDataPtr<void *>();
|
||||
BangPtr output_temp =
|
||||
context->getWorkspace(op->getInputs(0)->getBytes() * world_size);
|
||||
// void *output = op->getOutput()->getRawDataPtr<void *>();
|
||||
// IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
checkBangError(cnrtMalloc(&output_temp,
|
||||
op->getInputs(0)->getBytes() * world_size));
|
||||
size_t bytes = op->getInputs(0)->getBytes();
|
||||
size_t count = bytes / op->getDType().getSize();
|
||||
|
||||
cnclComm_t comm =
|
||||
dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
|
||||
.getCnclComm();
|
||||
cnrtQueue_t queue = context->getBangQueue();
|
||||
CNCL_CHECK(
|
||||
cnclAllGather(input, output_temp, count, cnclFloat32, comm, queue));
|
||||
checkBangError(cnrtQueueSync(queue));
|
||||
for (int i = 0; i < world_size; ++i) {
|
||||
Tensor output = op->getOutput(i);
|
||||
context->copyBlobInsideRuntime(
|
||||
output->getRawDataPtr<float *>(),
|
||||
static_cast<float *>(output_temp) + i * count, bytes);
|
||||
}
|
||||
checkBangError(cnrtFree(output_temp));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::AllGather, DataType::Float32,
|
||||
AllGatherCNCL, "AllGather_CNCL_BANG_Float32");
|
||||
} // namespace infini
|
||||
|
||||
#endif
|
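AllGatherCNCL gathers every rank's input into one contiguous temporary buffer and then copies slice i back into output tensor i. A host-side sketch of that final scatter step, with std::vector standing in for the device buffers:

#include <cstddef>
#include <cstring>
#include <vector>

// The gathered buffer is laid out as [rank0 | rank1 | ...]; copy each
// per-rank slice into its own output tensor.
void scatterGathered(const std::vector<float> &gathered,
                     std::vector<std::vector<float>> &outputs,
                     std::size_t countPerRank) {
    for (std::size_t r = 0; r < outputs.size(); ++r) {
        outputs[r].resize(countPerRank);
        std::memcpy(outputs[r].data(), gathered.data() + r * countPerRank,
                    countPerRank * sizeof(float));
    }
}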
|
@ -0,0 +1,53 @@
|
|||
#ifdef INFINI_USE_CNCL
|
||||
#include "operators/all_reduce.h"
|
||||
#include "bang/bang_kernel_without_config.h"
|
||||
#include "bang/bang_runtime.h"
|
||||
#include "bang/cncl_communicator.h"
|
||||
#include <thread>
|
||||
namespace infini {
|
||||
class AllReduceCNCL : public BangKernelWithoutConfig {
|
||||
public:
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<AllReduceBaseObj>(_op);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
void *input = op->getInputs(0)->getRawDataPtr<void *>();
|
||||
void *output = op->getOutput()->getRawDataPtr<void *>();
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
size_t count = op->getInputs(0)->size();
|
||||
cnclComm_t comm =
|
||||
dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
|
||||
.getCnclComm();
|
||||
cnrtQueue_t queue = context->getBangQueue();
|
||||
// checkBangError(cnrtQueueSync(queue));
|
||||
CNCL_CHECK(cnclAllReduce(input, output, count, cnclFloat, getRedOp(),
|
||||
comm, queue));
|
||||
checkBangError(cnrtQueueSync(queue));
|
||||
}
|
||||
|
||||
virtual cnclReduceOp_t getRedOp() const = 0;
|
||||
};
|
||||
|
||||
class AllReduceSumCNCL : public AllReduceCNCL {
|
||||
cnclReduceOp_t getRedOp() const override { return cnclSum; }
|
||||
};
|
||||
class AllReduceProdCNCL : public AllReduceCNCL {
|
||||
cnclReduceOp_t getRedOp() const override { return cnclProd; }
|
||||
};
|
||||
class AllReduceMinCNCL : public AllReduceCNCL {
|
||||
cnclReduceOp_t getRedOp() const override { return cnclMin; }
|
||||
};
|
||||
class AllReduceMaxCNCL : public AllReduceCNCL {
|
||||
cnclReduceOp_t getRedOp() const override { return cnclMax; }
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::AllReduceSum, DataType::Float32,
|
||||
AllReduceSumCNCL, "AllReduce_Sum_CNCL_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::AllReduceProd, DataType::Float32,
|
||||
AllReduceProdCNCL, "AllReduce_Prod_CNCL_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::AllReduceMin, DataType::Float32,
|
||||
AllReduceMinCNCL, "AllReduce_Min_CNCL_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::AllReduceMax, DataType::Float32,
|
||||
AllReduceMaxCNCL, "AllReduce_Max_CNCL_BANG_Float32");
|
||||
} // namespace infini
|
||||
#endif
|
|
@ -7,6 +7,7 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<BatchNormObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const input = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -17,55 +18,91 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
|
|||
void *const output = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto dims = op->getInputs(0)->getDims();
|
||||
|
||||
auto outDims = op->getOutput()->getDims();
|
||||
if (dims.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
|
||||
int dimArray[4], strideArray[4], dimPArray[1], stridePArray[1];
|
||||
int dimsTrans[4] = {dims[0], dims[2], dims[3], dims[1]};
|
||||
int dimsOutTrans[4] = {outDims[0], outDims[2], outDims[3], outDims[1]};
|
||||
int permute[4] = {0, 2, 3, 1};
|
||||
int permuteOut[4] = {0, 3, 1, 2};
|
||||
|
||||
for (size_t i = 0; i < dims.size(); ++i) {
|
||||
dimArray[i] = dims[i];
|
||||
strideArray[i] = op->getInputs(0)->getStride()[i];
|
||||
}
|
||||
int w = dimArray[3];
|
||||
dimArray[3] = dimArray[1];
|
||||
int h = dimArray[2];
|
||||
dimArray[1] = h;
|
||||
dimArray[2] = w;
|
||||
|
||||
dimPArray[0] = op->getInputs(1)->getDims()[0];
|
||||
stridePArray[0] = op->getInputs(1)->getDims()[0];
|
||||
// get inputs
|
||||
cnnlTensorDescriptor_t inDesc;
|
||||
cnnlTensorDescriptor_t inDesc, intransDesc, outDesc, outtransDesc;
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptorEx(inDesc, CNNL_LAYOUT_NHWC,
|
||||
CNNL_DTYPE_FLOAT, dims.size(),
|
||||
dimArray, strideArray));
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&intransDesc));
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&outDesc));
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&outtransDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_NCHW,
|
||||
CNNL_DTYPE_FLOAT, dims.size(),
|
||||
dims.data()));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(intransDesc, CNNL_LAYOUT_NHWC,
|
||||
CNNL_DTYPE_FLOAT, dims.size(),
|
||||
dimsTrans));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(outDesc, CNNL_LAYOUT_NCHW,
|
||||
CNNL_DTYPE_FLOAT, outDims.size(),
|
||||
outDims.data()));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(outtransDesc, CNNL_LAYOUT_NHWC,
|
||||
CNNL_DTYPE_FLOAT, outDims.size(),
|
||||
dimsOutTrans));
|
||||
cnnlTransposeDescriptor_t opDesc;
|
||||
checkCnnlError(cnnlCreateTransposeDescriptor(&opDesc));
|
||||
checkCnnlError(cnnlSetTransposeDescriptor(opDesc, 4, permute));
|
||||
size_t wsSize;
|
||||
cnnlGetTransposeWorkspaceSize(context->cnnlHandle(), inDesc, opDesc,
|
||||
&wsSize);
|
||||
BangPtr wsData = context->getWorkspace(wsSize);
|
||||
BangPtr inputTrans = context->getWorkspace(
|
||||
cnnlGetTensorElementNum(inDesc) * sizeof(float));
|
||||
BangPtr outputTrans = context->getWorkspace(
|
||||
cnnlGetTensorElementNum(inDesc) * sizeof(float));
|
||||
cnnlStatus_t stat =
|
||||
cnnlTranspose_v2(context->cnnlHandle(), opDesc, inDesc, input,
|
||||
intransDesc, inputTrans, wsData, wsSize);
|
||||
if (stat != CNNL_STATUS_SUCCESS)
|
||||
return;
|
||||
|
||||
// get bnScaleBiasMeanVarDesc
|
||||
auto dimsScaleBiasMeanVar = op->getInputs(1)->getDims();
|
||||
cnnlTensorDescriptor_t paraDesc;
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(¶Desc));
|
||||
checkCnnlError(cnnlSetTensorDescriptorEx(paraDesc, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_FLOAT, 1, dimPArray,
|
||||
stridePArray));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
paraDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT,
|
||||
dimsScaleBiasMeanVar.size(), dimsScaleBiasMeanVar.data()));
|
||||
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
// This mode is intended for use after convolutional layers
|
||||
cnnlStatus_t stat = cnnlBatchNormForwardInference(
|
||||
context->cnnlHandle(), &alpha, &beta, inDesc, input, paraDesc,
|
||||
scale, bias, mean, var, op->getEps(), inDesc, output);
|
||||
stat = cnnlBatchNormForwardInference(
|
||||
context->cnnlHandle(), &alpha, &beta, intransDesc, inputTrans,
|
||||
paraDesc, scale, bias, mean, var, op->getEps(), outtransDesc,
|
||||
outputTrans);
|
||||
if (stat != CNNL_STATUS_SUCCESS)
|
||||
return;
|
||||
|
||||
cnnlTransposeDescriptor_t op2Desc;
|
||||
checkCnnlError(cnnlCreateTransposeDescriptor(&op2Desc));
|
||||
checkCnnlError(cnnlSetTransposeDescriptor(op2Desc, 4, permuteOut));
|
||||
cnnlGetTransposeWorkspaceSize(context->cnnlHandle(), intransDesc,
|
||||
op2Desc, &wsSize);
|
||||
BangPtr ws2Data = context->getWorkspace(wsSize);
|
||||
stat = cnnlTranspose_v2(context->cnnlHandle(), op2Desc, outtransDesc,
|
||||
outputTrans, outDesc, output, ws2Data, wsSize);
|
||||
if (stat != CNNL_STATUS_SUCCESS)
|
||||
return;
|
||||
|
||||
// Destroy calls in BANG do not require a sync, but cnnl does not state
|
||||
// whether a sync is required before destroying descriptors.
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(inDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(outDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(intransDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(outtransDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(paraDesc));
|
||||
checkCnnlError(cnnlDestroyTransposeDescriptor(opDesc));
|
||||
checkCnnlError(cnnlDestroyTransposeDescriptor(op2Desc));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::BatchNormalization, DataType::Float32,
|
||||
BatchNormCnnl, "BatchNorm_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::BatchNormalization, BatchNormCnnl,
|
||||
"BatchNorm_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
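BatchNormCnnl now transposes the NCHW input to NHWC before calling cnnlBatchNormForwardInference and transposes the result back afterwards, apparently because the cnnl batch-norm path wants channels-last data. A plain CPU sketch of the forward layout change; in the kernel this is done on-device via cnnlTranspose_v2:

#include <cstddef>
#include <vector>

// Rearrange an NCHW buffer into NHWC order.
std::vector<float> nchwToNhwc(const std::vector<float> &in, int n, int c,
                              int h, int w) {
    std::vector<float> out(in.size());
    for (int ni = 0; ni < n; ++ni)
        for (int ci = 0; ci < c; ++ci)
            for (int hi = 0; hi < h; ++hi)
                for (int wi = 0; wi < w; ++wi) {
                    std::size_t src =
                        ((std::size_t(ni) * c + ci) * h + hi) * w + wi;
                    std::size_t dst =
                        ((std::size_t(ni) * h + hi) * w + wi) * c + ci;
                    out[dst] = in[src];
                }
    return out;
}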
|
|
|
@ -0,0 +1,34 @@
|
|||
#ifdef INFINI_USE_CNCL
|
||||
#include "operators/broadcast.h"
|
||||
#include "bang/bang_kernel_without_config.h"
|
||||
#include "bang/bang_runtime.h"
|
||||
#include "bang/cncl_communicator.h"
|
||||
#include <thread>
|
||||
namespace infini {
|
||||
class BroadcastCNCL : public BangKernelWithoutConfig {
|
||||
public:
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<BroadcastObj>(_op);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
void *input = op->getInputs(0)->getRawDataPtr<void *>();
|
||||
void *output = op->getOutput()->getRawDataPtr<void *>();
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
size_t count = op->getInputs(0)->getBytes() / op->getDType().getSize();
|
||||
|
||||
cnclComm_t comm =
|
||||
dynamic_cast<CnclCommunicatorObj &>(context->getCommunicator())
|
||||
.getCnclComm();
|
||||
cnrtQueue_t queue = context->getBangQueue();
|
||||
// TODO: Using default stream 0 for now.
|
||||
CNCL_CHECK(cnclBroadcast(input, output, count, cnclFloat32,
|
||||
op->getRoot(), comm, queue));
|
||||
checkBangError(cnrtQueueSync(queue));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Broadcast, DataType::Float32,
|
||||
BroadcastCNCL, "Broadcast_CNCL_BANG_Float32");
|
||||
} // namespace infini
|
||||
|
||||
#endif
|
|
@ -212,7 +212,6 @@ class CastCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Cast, DataType::Float32, CastCnnl,
|
||||
"Cast_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Cast, CastCnnl, "Cast_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class CeilCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -35,7 +36,6 @@ class CeilCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Ceil, DataType::Float32, CeilCnnl,
|
||||
"Ceil_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Ceil, CeilCnnl, "Ceil_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class ClipCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ClipObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -30,7 +31,6 @@ class ClipCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Clip, DataType::Float32, ClipCnnl,
|
||||
"Clip_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Clip, ClipCnnl, "Clip_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@@ -7,6 +7,7 @@ class ConcatCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<ConcatObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);
         int num = op->numInputs();
         int axis = op->getDim();
@@ -50,6 +51,5 @@ class ConcatCnnl : public BangKernelWithoutConfig {
     }
 };
 
-REGISTER_KERNEL(Device::BANG, OpType::Concat, DataType::Float32, ConcatCnnl,
-                "Concat_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::Concat, ConcatCnnl, "Concat_cnnl_BANG");
 }; // namespace infini
@@ -7,6 +7,7 @@ class ConvCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<ConvObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);
 
         const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
@@ -118,8 +119,8 @@ class ConvCnnl : public BangKernelWithoutConfig {
                        cnnlGetTensorElementNum(cInDesc) * sizeof(float));
 
         stat = cnnlConvolutionForward(
-            context->cnnlHandle(), convDesc, algo, NULL, aDesc, aData, bDesc,
-            bData, NULL, NULL, wsData, wsSize, NULL, cInDesc, cDataIn);
+            context->cnnlHandle(), convDesc, algo, NULL, aDesc, aDataOut, bDesc,
+            bDataOut, NULL, NULL, wsData, wsSize, NULL, cInDesc, cDataIn);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
 
@@ -130,10 +131,10 @@ class ConvCnnl : public BangKernelWithoutConfig {
 
         cnnlGetTransposeWorkspaceSize(context->cnnlHandle(), cInDesc, opOutDesc,
                                       &wsSize);
-        wsData = context->getWorkspace(wsSize);
+        BangPtr wsData2 = context->getWorkspace(wsSize);
 
         stat = cnnlTranspose_v2(context->cnnlHandle(), opOutDesc, cInDesc,
-                                cDataIn, cDesc, cData, wsData, wsSize);
+                                cDataIn, cDesc, cData, wsData2, wsSize);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
 
@@ -151,6 +152,5 @@ class ConvCnnl : public BangKernelWithoutConfig {
     }
 };
 
-REGISTER_KERNEL(Device::BANG, OpType::Conv, DataType::Float32, ConvCnnl,
-                "Conv_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::Conv, ConvCnnl, "Conv_cnnl_BANG");
 }; // namespace infini
@@ -7,6 +7,7 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<ConvBaseObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);
 
         const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
@@ -39,24 +40,17 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
         if (dimOutput.size() != 4)
             IT_TODO_HALT();
 
-        int inputs0[4] = {dimInputs0[0], dimInputs0[1], dimInputs0[2],
-                          dimInputs0[3]};
-        int inputs1[4] = {dimInputs1[0], dimInputs1[1], dimInputs1[2],
-                          dimInputs1[3]};
-        int output[4] = {dimOutput[0], dimOutput[1], dimOutput[2],
-                         dimOutput[3]};
-
         // get inputs
         checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
-        checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
-                                               CNNL_DTYPE_FLOAT, 4, inputs0));
+        checkCnnlError(cnnlSetTensorDescriptor(
+            aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimInputs0.data()));
         checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
-        checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW,
-                                               CNNL_DTYPE_FLOAT, 4, inputs1));
+        checkCnnlError(cnnlSetTensorDescriptor(
+            bDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimInputs1.data()));
         // get outputs
         checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
-        checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
-                                               CNNL_DTYPE_FLOAT, 4, output));
+        checkCnnlError(cnnlSetTensorDescriptor(
+            cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dimOutput.data()));
 
         cnnlConvolutionBwdDataAlgo_t algo;
         cnnlGetConvolutionBackwardDataAlgorithm(
@@ -64,12 +58,12 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
             CNNL_CONVOLUTION_BWD_DATA_FASTEST, &algo);
         size_t wsSize;
         cnnlGetConvolutionBackwardDataWorkspaceSize(context->cnnlHandle(),
-                                                    aDesc, bDesc, convDesc,
+                                                    bDesc, aDesc, convDesc,
                                                     cDesc, algo, &wsSize);
         BangPtr wsData = context->getWorkspace(wsSize);
 
         cnnlStatus_t stat = cnnlConvolutionBackwardData(
-            context->cnnlHandle(), NULL, aDesc, aData, bDesc, bData, convDesc,
+            context->cnnlHandle(), NULL, bDesc, bData, aDesc, aData, convDesc,
             algo, wsData, wsSize, NULL, cDesc, cData);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
@@ -83,6 +77,6 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
     }
 };
 
-REGISTER_KERNEL(Device::BANG, OpType::ConvTranspose, DataType::Float32,
-                ConvTransCnnl, "ConvTrans_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::ConvTranspose, ConvTransCnnl,
+                "ConvTrans_cnnl_BANG");
 }; // namespace infini
@@ -7,6 +7,7 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<ConvBackwardFilterObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);
 
         const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
@@ -154,6 +155,6 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
     }
 };
 
-REGISTER_KERNEL(Device::BANG, OpType::ConvBackwardFilter, DataType::Float32,
-                ConvBackwardFilterCnnl, "ConvBackwardFilter_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::ConvBackwardFilter,
+                ConvBackwardFilterCnnl, "ConvBackwardFilter_cnnl_BANG");
 }; // namespace infini
@@ -7,6 +7,7 @@ class DetCnnl : public BangKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<DetObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);
 
         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@@ -42,6 +43,5 @@ class DetCnnl : public BangKernelWithoutConfig {
     }
 };
 
-REGISTER_KERNEL(Device::BANG, OpType::Det, DataType::Float32, DetCnnl,
-                "Det_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::Det, DetCnnl, "Det_cnnl_BANG");
 }; // namespace infini