forked from jiuyuan/InfiniTensor
XCCL support (#171)
* add reduce_mean and gather
* fix format
* add kunlun allreduce and cmakefile
* add kunlun allreduce and cmakefile
* delete cmake opt
* fix format
* fix makefile
* add DIST option in Makefile
* add xpu allgather
* delete xpu_wait()
* add xpu allgather
* delete specific compiler
* fix format
* fix gather
* add broadcast
* fix format
* fix
* fix xpu, add where operation, fix element-wise operation
* fix softmax
* fix softmax
* log internal input and output
* fix kunlun gather bugs
* update CMakeList.txt and Makefile
* fix some kunlun kernels
* fix Makefile
* fix Makefile
* set cmake version 3.12
* format
* fix where, gather and support gpt2
* fix format
* fix format
* copy onnx.py from master
* use KUNLUN_HOME instead of absolute path
* fix torchvision models
* support torchvision model-zoo
* fix format
* format fix, CMakeList fix
* fix review
* fix vecToString return value
* fix format
* delete empty file

---------

Co-authored-by: wanghailu <wanghailu0717@163.com>
Co-authored-by: wanghailu <wanghailu@qiyuanlab.com>
Co-authored-by: Haojie Wang <haojie0429@gmail.com>
parent  b51ccae3b2
commit  00e6cc2587
@@ -53,11 +53,13 @@ endif()
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_EXTENSIONS OFF) # -std=gnu++11 when on, -std=c++11 when off
+add_compile_options(-Wno-error=unused-variable)
+
 find_package(
   Python
   COMPONENTS Interpreter Development
   REQUIRED)

 # OpenMP
 find_package(OpenMP)
 if(OpenMP_C_FOUND)
@@ -282,9 +284,9 @@ if(USE_KUNLUN)
   endif()
   message(STATUS "KUNLUN_HOME: ${KUNLUN_HOME}")

-  include_directories("${KUNLUN_HOME}/XTDK/include/")
-  find_library(KUNLUN_RT libxpurt.so "${KUNLUN_HOME}/lib64")
-  find_library(KUNLUN_DNN libxpuapi.so "${KUNLUN_HOME}/XTDK/shlib")
+  include_directories("${KUNLUN_HOME}/include/")
+  find_library(KUNLUN_RT libxpurt.so "${KUNLUN_HOME}/so/")
+  find_library(KUNLUN_DNN libxpuapi.so "${KUNLUN_HOME}/so/")
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror")

   if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH}))
@@ -297,6 +299,13 @@ if(USE_KUNLUN)
   endif()
   message(STATUS "TARGET_CPU_ARCH: ${TARGET_CPU_ARCH}")

+  if (BUILD_DIST)
+    message(STATUS "Add BUILD_DIST, use XCCL with KUNLUN XPU")
+    list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake)
+    find_package(XCCL REQUIRED)
+    add_compile_definitions(INFINI_USE_XCCL=1)
+    target_link_libraries(InfiniTensor ${XCCL_LIBRARIES})
+  endif()
   target_link_libraries(InfiniTensor ${KUNLUN_RT} ${KUNLUN_DNN} stdc++)
 endif()

@@ -335,6 +344,7 @@ if(BUILD_TEST)
   endif()
   if (USE_KUNLUN)
     build_test(test/kernels/kunlun/*.cc)
+    build_test(test/kunlun/*.cc)
   endif()
   if (USE_INTELCPU)
     build_test(test/kernels/intelcpu/*.cc)
Makefile (+1)
@@ -9,6 +9,7 @@ BACKTRACE ?= ON
 TEST ?= ON
 DIST ?= OFF
 NNET ?= OFF
+DIST ?= OFF
 FORMAT_ORIGIN ?=
 # Docker build options
 DOCKER_NAME ?= infinitensor
New file — FindXCCL CMake module:
@@ -0,0 +1,27 @@
# Find the xccl libraries
set(XCCL_INCLUDE_DIR $ENV{KUNLUN_HOME}/include CACHE PATH "Folder contains KUNLUN XCCL headers")
set(XCCL_LIB_DIR $ENV{KUNLUN_HOME} CACHE PATH "Folder contains KUNLUN XCCL libraries")

list(APPEND CMAKE_PREFIX_PATH $ENV{KUNLUN_HOME})

find_path(XCCL_INCLUDE_DIRS # ${XCCL_INCLUDE_DIR}
  NAMES xpu/bkcl.h
  HINTS XCCL_INCLUDE_DIR)

find_library(XCCL_LIBRARIES # ${XCCL_LIB_DIR}
  NAMES so/libbkcl.so
  HINTS XCCL_LIB_DIR)

message(STATUS "XCCL_INCLUDE_DIRS: ${XCCL_INCLUDE_DIRS}")
message(STATUS "XCCL_LIBRARIES: ${XCCL_LIBRARIES}")

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(XCCL DEFAULT_MSG XCCL_INCLUDE_DIRS XCCL_LIBRARIES)

if (XCCL_FOUND)
  set (XCCL_HEADER_FILE "${XCCL_INCLUDE_DIRS}/xpu/bkcl.h")
  message (STATUS "Determing XCCL version from ${XCCL_HEADER_FILE}...")
  list (APPEND CMAKE_REQUIRED_INCLUDES ${XCCL_INCLUDE_DIRS})
  message(STATUS "Found XCCL (include: ${XCCL_INCLUDE_DIRS}, library: ${XCCL_LIBRARIES})")
  mark_as_advanced(XCCL_INCLUDE_DIRS XCCL_LIBRARIES)
endif()
Submodule pointer update:
@@ -1 +1 @@
-Subproject commit b896cec2dba5b8522b141ac4f89eb43074ee1b98
+Subproject commit 51d3105277f3774ed31c02ed4cd11fa92925af77
New file — distributed inference launch script (Python):
@@ -0,0 +1,213 @@
import argparse
import os
import time
import multiprocessing as mp
from pyinfinitensor.onnx import OnnxStub, backend
import onnx
from onnx.external_data_helper import convert_model_to_external_data
from onnx.shape_inference import infer_shapes_path
import numpy as np
from parallel_opt import parallel_model

st_input_dir = "standard/inputs/"
st_output_dir = "standard/outputs/"


def parse_args():
    parser = argparse.ArgumentParser(description="launch distributed infinitensor")
    parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
    parser.add_argument(
        "--nproc_per_node", type=int, default=2, help="number of processes per node"
    )
    parser.add_argument(
        "--name", type=str, default="test", help="name of this instance."
    )
    parser.add_argument(
        "--model", type=str, default="/data1/shared/panzezhong/llama/fp32/my_llama_fp32.sim.onnx", help="path to the ONNX model file."
    )
    parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
    parser.add_argument("--length", type=int, default=1, help="sequence length.")
    parser.add_argument(
        "--gen_std",
        default=False,
        action="store_true",
        help="whether to generate the standard results.",
    )
    parser.add_argument(
        "--run_single",
        default=False,
        action="store_true",
        help="whether run model with single process with standard inputs"
    )
    args = parser.parse_args()
    print("arg setting: ", args)
    return (
        args.num_nodes,
        args.nproc_per_node,
        args.name,
        args.model,
        args.batch_size,
        args.length,
        args.gen_std,
        args.run_single
    )


def run_model(model, runtime, world_size=1, rank=0, n=10):
    stub = OnnxStub(model, runtime)
    load_inputs(stub, world_size, rank)
    # stub.tune()
    stub.run()
    # get outputs
    time.sleep(0.01)
    outputs = next(stub.outputs.values().__iter__()).copyout_numpy()

    # bench
    begin = time.time()
    for _ in range(n):
        stub.run()
    end = time.time()
    avg_time = (end - begin) / n
    print(f"average time: {avg_time}")
    return outputs


def run_and_compare(name, model, runtime, world_size=1, rank=0):
    results = np.load(os.path.join(st_output_dir, f"output.npy"))
    outputs = run_model(model, runtime, world_size, rank)
    print(outputs[:100])
    if np.isnan(outputs).any():
        print("Nan in output")
    print("answer argmax:", np.argmax(results))
    print("output argmax:", np.argmax(outputs))
    # np.testing.assert_allclose(outputs, results, rtol=1e-3, atol=1e-3)
    getDiff(results, outputs)


def start_worker(
    name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
):
    dist_name = name + "_dist"
    model = parallel_model(model, world_size, rank)
    extern_path = f"./{dist_name}_rank{rank}.pb"
    if os.path.exists(extern_path):
        os.remove(extern_path)
    onnx.save_model(
        model,
        f"./{dist_name}_rank{rank}.onnx",
        save_as_external_data=True,
        location=extern_path,
    )
    infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
    runtime = backend.KUNLUNRuntime(local_rank)
    # print("init comm")
    runtime.init_comm(
        dist_name,
        world_size,
        rank,
    )
    run_and_compare(name, model, runtime, world_size, rank)


def start_single(name, model):
    runtime = backend.KUNLUNRuntime(0)
    run_and_compare(name, model, runtime)


def generate_input_output(model):
    runtime = backend.KUNLUNRuntime(0)
    stub = OnnxStub(model, runtime)
    position_id = 0
    for i, (name, tensor) in enumerate(stub.inputs.items()):
        input = tensor.copyout_numpy()
        if np.issubdtype(input.dtype, np.integer):
            if input.size == 1:
                # input = np.array([position_id])
                input = np.random.randint(0, 2, size=input.shape, dtype=input.dtype)
            else:
                input = np.random.randint(0, 2, size=input.shape, dtype=input.dtype)
        elif input.dtype == np.bool_:
            input = np.random.randint(0, 2, size=input.shape) > 0
        else:
            if i == 0:
                input = np.ones(input.shape).astype(input.dtype)
                position_id = input.shape[-1] - 1
            else:
                input = np.random.rand(*input.shape).astype(input.dtype)
        tensor.copyin_numpy(input)
        np.save(os.path.join(st_input_dir, f"input_{i}"), input)
    stub.run()
    # print(stub.outputs)
    time.sleep(0.01)
    output = next(stub.outputs.values().__iter__()).copyout_numpy()
    print(output[:100])
    if np.isnan(output).any():
        print("Nan in output")
    np.save(os.path.join(st_output_dir, f"output"), output)


def load_inputs(stub, world_size=1, rank=0):
    for i, (name, tensor) in enumerate(stub.inputs.items()):
        input = np.load(os.path.join(st_input_dir, f"input_{i}.npy"))
        if all(x == y for x, y in zip(input.shape, tensor.shape())):
            tensor.copyin_numpy(input)
        else:
            tensor.copyin_numpy(np.hsplit(input, world_size)[rank])


def getDiff(base, test):
    absolute_diff = np.abs(np.subtract(base, test))
    max_absolute_diff = np.max(absolute_diff)

    baseCopy = base.astype(np.float64).ravel()
    testCopy = test.astype(np.float64).ravel()
    upValue = np.sum(np.abs(baseCopy - testCopy))
    downValue = np.sum(np.abs(baseCopy)) + np.float64(1e-9)
    max_relative_diff = upValue / downValue
    print(f"Max absolute difference: {max_absolute_diff}\nMax relative difference: {max_relative_diff}")

    return max_absolute_diff, max_relative_diff


def main():
    nnodes, nproc_per_node, name, model_path, bs, length, gen_std, run_single = parse_args()

    model = onnx.load(model_path)

    # generate standard output
    if gen_std:
        print("Generate inputs and outputs.")
        p = mp.Process(target=generate_input_output, args=[model])
        p.start()
        p.join()
        return

    # run single process.
    # use standalone process to isolate cuda.
    if run_single:
        print("run model by single GPU.")
        p = mp.Process(target=start_single, args=(name, model))
        p.start()
        p.join()
        return

    # run distributed parallel.
    world_size = nnodes * nproc_per_node
    print(f"run model by {world_size} GPU in parallel.")
    workers = [
        mp.Process(
            target=start_worker,
            args=(name, world_size, rank, rank % nproc_per_node, model),
        )
        for rank in range(world_size)
    ]

    for w in workers:
        w.start()

    for w in workers:
        w.join()


if __name__ == "__main__":
    main()
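getDiff is the accuracy check used by run_and_compare: it compares the saved single-process ("standard") output a against the output under test b and reports two scalars,

    \text{max\_absolute\_diff} = \max_i \lvert a_i - b_i \rvert,
    \qquad
    \text{max\_relative\_diff} = \frac{\sum_i \lvert a_i - b_i \rvert}{\sum_i \lvert a_i \rvert + 10^{-9}}

Note that, despite its name, max_relative_diff is a single sum-normalized ratio rather than a per-element maximum; the 1e-9 term only guards against division by zero when the reference is all zeros.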
@@ -61,16 +61,30 @@ template <typename T> auto enum_to_underlying(T e) {
 }

 template <typename T> std::string vecToString(const std::vector<T> &vec) {
-    std::string ret;
-    ret.append("[");
-    for (auto d : vec) {
-        ret.append(std::to_string(d));
-        ret.append(",");
-    }
-    if (!vec.empty())
-        ret.pop_back();
-    ret.append("]");
-    return ret;
+    std::stringstream ss;
+    ss << "[";
+    for (size_t i = 0; i < vec.size(); ++i) {
+        ss << vec.at(i);
+        if (i < vec.size() - 1) {
+            ss << ",";
+        }
+    }
+    ss << "]";
+    return ss.str();
+}
+
+template <typename T> std::string vecToString(const T *st, size_t length) {
+    std::stringstream ss;
+    ss << "[";
+    size_t i = 0;
+    for (i = 0; i < length; i++) {
+        ss << *(st + i);
+        if (i < length - 1) {
+            ss << ",";
+        }
+    }
+    ss << "]";
+    return ss.str();
 }

 double timeit(
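The stringstream-based rewrite accepts any element type that provides operator<< (the old version required std::to_string), and the new pointer overload covers raw buffers. A minimal usage sketch, assuming the header defining these helpers is included and that they live in namespace infini like the surrounding utilities (the include path below is hypothetical):

    // Sketch: exercising both vecToString overloads added above.
    #include "core/common.h" // assumed location of the helpers; adjust to the real header
    #include <iostream>
    #include <vector>

    int main() {
        std::vector<int> dims = {1, 3, 224, 224};
        float buf[3] = {0.5f, 1.0f, 1.5f};
        std::cout << infini::vecToString(dims) << std::endl;   // [1,3,224,224]
        std::cout << infini::vecToString(buf, 3) << std::endl; // [0.5,1,1.5]
        return 0;
    }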
@@ -15,6 +15,7 @@ class GraphObj;
 class GraphHandlerObj;
 class RuntimeObj;
 class BlobObj;
+template <typename T> class WorkspaceObj;

 using TensorBase = Ref<TensorBaseObj>;
 using Tensor = Ref<TensorObj>;
@@ -23,6 +24,7 @@ using Graph = Ref<GraphObj>;
 using GraphHandler = Ref<GraphHandlerObj>;
 using Runtime = Ref<RuntimeObj>;
 using Blob = Ref<BlobObj>;
+template <typename T> using Workspace = Ref<WorkspaceObj<T>>;

 using TensorVec = vector<Tensor>;
 using OpVec = vector<Operator>;
@@ -4,6 +4,7 @@
 #include "utils/data_convert.h"
 #include <cmath>
 #include <cstring>
+#include <fstream>

 #if USE_CUDA
 #include "cuda/cuda_runtime.h"
@@ -143,6 +144,7 @@ class TensorObj : public TensorBaseObj {
     }

     void printData() const;
+    void dumpData(std::ofstream &ofs) const;
     bool equalData(const Tensor &rhs, double relativeError = 1e-6) const;

     template <typename T> bool equalData(const vector<T> &dataVector) {
@@ -198,13 +200,20 @@ class TensorObj : public TensorBaseObj {
                 if (a[i] != b[i])
                     return false;
             } else if constexpr (std::is_floating_point_v<T>) {
-                if (fabs(a[i] - b[i]) / std::max(fabs(a[i]), fabs(b[i])) >
-                    relativeError) {
+                if (std::min(fabs(a[i]), fabs(b[i])) == 0. &&
+                    fabs(a[i] - b[i]) > relativeError) {
+                    printf("Error on %lu: %f %f\n", i, a[i], b[i]);
+                    return false;
+                } else if (std::min(fabs(a[i]), fabs(b[i])) != 0. &&
+                           fabs(a[i] - b[i]) /
+                                   std::max(fabs(a[i]), fabs(b[i])) >
+                               relativeError) {
                     printf("Error on %lu: %f %f\n", i, a[i], b[i]);
                     return false;
                 }
-            } else
+            } else {
                 static_assert(!sizeof(T), "Unsupported data type");
+            }
         }
         return true;
     }
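The reworked check falls back to an absolute-difference test when either value is exactly zero; the old purely relative formula divides by the larger magnitude, so any nonzero result compared against a zero reference always failed. A small standalone sketch of the same decision rule (not the class method itself, just the comparison logic):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    // Mirrors the tolerance rule above: absolute near zero, relative otherwise.
    static bool closeEnough(double a, double b, double relativeError) {
        if (std::min(std::fabs(a), std::fabs(b)) == 0.)
            return std::fabs(a - b) <= relativeError;
        return std::fabs(a - b) / std::max(std::fabs(a), std::fabs(b)) <= relativeError;
    }

    int main() {
        // 0 vs 1e-7 now passes (absolute check); the old relative test computed
        // 1e-7 / 1e-7 = 1 and rejected it.
        std::printf("%d\n", closeEnough(0.0, 1e-7, 1e-6)); // 1
        std::printf("%d\n", closeEnough(1.0, 1.1, 1e-6));  // 0
        return 0;
    }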
@@ -239,8 +248,8 @@ class TensorObj : public TensorBaseObj {
     //     // std::cerr << "Init beginned " << std::endl;
     // #pragma omp parallel for
     //     for (size_t i = 0; i < iEnd; ++i)
-    //         data[i] = fastrand(random_seed[omp_get_thread_num() * 16]) %
-    //         10000;
+    //         data[i] = fastrand(random_seed[omp_get_thread_num() *
+    //         16]) % 10000;
     //     // std::cerr << "Init finished" << std::endl;
     //     computed = ComputedFull;
     //     return true;
@@ -285,8 +294,8 @@ class TensorObj : public TensorBaseObj {
     //     auto nDim = dims.size();
     //     auto nBroadcastDim = ds.size() - nDim;
     //     for (size_t i = 0; i < nDim; ++i)
-    //         if (ds[nBroadcastDim + i] < 0 || ds[nBroadcastDim + i] >=
-    //         dims[i])
+    //         if (ds[nBroadcastDim + i] < 0 || ds[nBroadcastDim +
+    //         i] >= dims[i])
     //             return (size_t)-1;
     //     size_t idx = 0;
     //     for (size_t i = 0; i < nDim; ++i)
@@ -345,12 +354,14 @@ class TensorObj : public TensorBaseObj {
     //     return (g_seed >> 16) & 0x7FFF;
     // }

-    // std::vector<std::vector<int>> const *getSplittingPoints() const {
+    // std::vector<std::vector<int>> const *getSplittingPoints()
+    // const {
     //     assert(!splittingPoints.empty());
     //     return &splittingPoints;
     // }

-    // bool setSplittingPoints(std::vector<std::vector<int>> value) {
+    // bool setSplittingPoints(std::vector<std::vector<int>> value)
+    // {
     //     assert(!value.empty());
     //     splittingPoints = value;
     //     return true;
New file — workspace object header:
@@ -0,0 +1,42 @@
#pragma once
#include "core/runtime.h"

namespace infini {

template <class T> class WorkspaceObj {
  private:
    T workspace;           // workspace pointer
    size_t workspaceSize;  // Size of workspace
    size_t workspaceAlloc; // currently use workspace size

  public:
    WorkspaceObj(T workspace_, size_t workspaceSize_)
        : workspace(workspace_), workspaceSize(workspaceSize_) {
        workspaceAlloc = 0;
    }
    virtual ~WorkspaceObj() {
        // Dealloc workspace in RuntimeObj
        // Set workspace = nullptr here
        workspace = nullptr;
    }
    size_t getWorkspaceSize() const { return workspaceSize; }

    T getWorkspace(size_t size) {
        // Get unused workspace
        IT_ASSERT(size + workspaceAlloc <= workspaceSize);
        auto ret = (T)(static_cast<uint8_t *>(workspace) + workspaceAlloc);
        workspaceAlloc += size;
        return ret;
    }
    T getWorkspace() {
        // Override getWorkspace in order to dealloc in runtime
        return workspace;
    }
    void resetWorkspace() {
        // Reset workspaceAlloc every time end kernel
        workspaceAlloc = 0;
    }
    size_t getWorkspaceAlloc() const { return workspaceAlloc; }
};

} // namespace infini
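WorkspaceObj is a simple bump allocator over one pre-allocated device buffer: each getWorkspace(size) hands out the next unused slice, and resetWorkspace() rewinds the cursor so the next kernel can reuse the whole buffer. A host-side sketch of the same pattern using a plain byte array (illustrative only; the real object wraps a KUNLUNPtr allocated by the runtime):

    #include <cassert>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        // Stand-in for the device allocation owned by KUNLUNRuntimeObj.
        std::vector<uint8_t> buffer(1 << 20);
        uint8_t *workspace = buffer.data();
        size_t workspaceSize = buffer.size(), workspaceAlloc = 0;

        // Equivalent of getWorkspace(size): slice off the next `size` bytes.
        auto getWorkspace = [&](size_t size) -> uint8_t * {
            assert(size + workspaceAlloc <= workspaceSize);
            uint8_t *ret = workspace + workspaceAlloc;
            workspaceAlloc += size;
            return ret;
        };

        uint8_t *scratchA = getWorkspace(256);  // first kernel's scratch
        uint8_t *scratchB = getWorkspace(1024); // second request, non-overlapping
        std::printf("offsets: %zu %zu\n", size_t(scratchA - workspace),
                    size_t(scratchB - workspace)); // prints: offsets: 0 256
        workspaceAlloc = 0; // resetWorkspace(): reuse the buffer for the next kernel
        return 0;
    }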
New file — Kunlun activation-type mapping:
@@ -0,0 +1,23 @@
#include "core/op_type.h"
#include "kunlun/kunlun_common.h"

namespace infini {
using KunlunActType = xdnn::Activation_t;
KunlunActType parseActType(ActType act) {
    switch (act) {
    case ActType::None:
        return KunlunActType::LINEAR;
    case ActType::Tanh:
        return KunlunActType::TANH;
    case ActType::Sigmoid:
        return KunlunActType::SIGMOID;
    case ActType::Relu:
        return KunlunActType::RELU6;
    default:
        fprintf(stderr, "Activation Type not support yet!\n");
        break;
    }
    return KunlunActType::LINEAR;
}

}; // namespace infini
@@ -3,6 +3,8 @@
 #include "xpu/runtime_ex.h"
 #include "xpu/xdnn.h"

+namespace xdnn = baidu::xpu::api;
+
 #define checkKUNLUNError(call)                                                 \
     {                                                                          \
         auto err = call;                                                       \
@@ -1,28 +1,35 @@
 #pragma once
 #include "core/runtime.h"
+#include "core/workspace.h"
 #include "kunlun/kunlun_common.h"
+#ifdef INFINI_USE_XCCL
+#include "kunlun/xccl_communicator.h"
+#endif
 namespace infini {

 class KUNLUNRuntimeObj : public RuntimeObj {
   private:
-    baidu::xpu::api::Context *xdnn;
-    KUNLUNPtr workspace;
-    size_t workspaceSize;
+    xdnn::Context *ctx;
+    std::unique_ptr<CommunicatorObj> comm;
+    // KUNLUNPtr workspace;
+    // size_t workspaceSize;
+    Workspace<KUNLUNPtr> workspace;

   public:
-    KUNLUNRuntimeObj() : RuntimeObj(Device::KUNLUN) {
-        xdnn = baidu::xpu::api::create_context();
+    KUNLUNRuntimeObj(int deviceId = 0) : RuntimeObj(Device::KUNLUN) {
+        xpu_set_device(deviceId);
+        ctx = xdnn::create_context();
         // 10GB for Longformer
         // size_t longformerNum = 3lu * (1 << 30);
-        workspaceSize = 3ll << 30; // 3 GB
-        // std::cout<<workspaceSize/1024/1024/1024<< std::endl;
-        // std::cout<<std::bitset<64>(workspaceSize)<< std::endl;
-        workspace = alloc(workspaceSize);
+        size_t workspaceSize = 3llu << 30; // 3 GB
+        KUNLUNPtr wkspacePtr = alloc(workspaceSize);
+        workspace =
+            make_ref<WorkspaceObj<KUNLUNPtr>>(wkspacePtr, workspaceSize);
     }
     virtual ~KUNLUNRuntimeObj() {
-        dealloc(workspace);
-        baidu::xpu::api::destroy_context(xdnn);
+        KUNLUNPtr wkspacePtr = workspace->getWorkspace();
+        dealloc(wkspacePtr);
+        xdnn::destroy_context(ctx);
     }
     string toString() const override;

@@ -31,6 +38,7 @@ class KUNLUNRuntimeObj : public RuntimeObj {
     // double runEvaluation(const Graph &graph, int nWarmups,
     //                      int nEvaluations) const;
     void sync() const;

     KUNLUNPtr alloc(size_t size) override {
         void *ptr;
         checkKUNLUNError(
@@ -38,33 +46,33 @@ class KUNLUNRuntimeObj : public RuntimeObj {
         return ptr;
     }
     void dealloc(void *ptr) override { xpu_free(ptr); }
-    baidu::xpu::api::Context *KUNLUNHandle() const { return xdnn; }
+    xdnn::Context *KUNLUNHandle() const { return ctx; }
+    // Get $size workspace by bytes
     KUNLUNPtr getWorkspace(size_t size) const {
-        IT_ASSERT(size <= workspaceSize);
-        return workspace;
+        auto ret = workspace->getWorkspace(size);
+        return ret;
     }
+    Workspace<KUNLUNPtr> getWorkspaceObj() const { return workspace; }

     void copyBlobFromCPU(void *dst, const void *src,
                          size_t bytes) const override {
         xpu_memcpy(dst, const_cast<void *>(src), bytes,
                    XPUMemcpyKind::XPU_HOST_TO_DEVICE);
     }

     void copyBlobToCPU(void *dst, const void *src,
                        size_t bytes) const override {
         xpu_memcpy(dst, const_cast<void *>(src), bytes,
                    XPUMemcpyKind::XPU_DEVICE_TO_HOST);
     }

     void copyBlobInsideRuntime(void *dst, const void *src,
                                size_t bytes) const override {
         xpu_memcpy(dst, const_cast<void *>(src), bytes,
                    XPUMemcpyKind::XPU_DEVICE_TO_DEVICE);
     }
+    void initComm(const string &name, int worldSize, int rank) final;

-    void initComm(const string &, int, int) override { IT_TODO_HALT(); }
+    CommunicatorObj &getCommunicator() const final { return *comm; }

-    CommunicatorObj &getCommunicator() const override { IT_TODO_HALT(); }

   private:
     void runWithoutSync(const Graph &graph, bool tune, bool profiling) const;
New file — XCCL communicator for Kunlun:
@@ -0,0 +1,60 @@
#pragma once
#include "core/communicator.h"
#include "xpu/bkcl.h"
#include <chrono>
#include <filesystem>
#include <fstream>
#include <thread>

#define checkXcclError(call)                                                   \
    {                                                                          \
        auto err = call;                                                       \
        if (BKCL_SUCCESS != err) {                                             \
            fprintf(stderr, "XCCL error in %s:%i.\n", __FILE__, __LINE__);     \
            exit(EXIT_FAILURE);                                                \
        }                                                                      \
    }

namespace infini {

class XcclCommunicatorObj final : public CommunicatorObj {
  private:
    BKCLContext_t comm;

  public:
    XcclCommunicatorObj(const string &name, int worldSize, int rank)
        : CommunicatorObj(worldSize, rank) {
        const std::string filePath("./" + name + "_xccl_id.bin");
        BKCLUniqueId commId;
        if (rank == 0) {
            checkXcclError(bkcl_get_unique_id(&commId));
            std::ofstream ofs(filePath, std::ios::binary);
            ofs.write((char *)&commId, sizeof(BKCLUniqueId));
        } else {
            auto begin = std::chrono::steady_clock::now();
            while (!std::filesystem::exists(filePath)) {
                auto now = std::chrono::steady_clock::now();
                _IT_ASSERT_2(now < begin + std::chrono::seconds(10),
                             "time limit (10s) exceeded.");
                std::this_thread::sleep_for(std::chrono::milliseconds(100));
            }
            std::ifstream ifs(filePath, std::ios::binary);
            ifs.read((char *)&commId, sizeof(BKCLUniqueId));
        }
        checkXcclError(bkcl_init_rank(&comm, rank, worldSize, &commId));
        if (rank == 0) {
            std::filesystem::remove(filePath);
        }
    }

    BKCLContext_t getXcclComm() { return comm; }

    ~XcclCommunicatorObj() final { checkXcclError(bkcl_destroy_context(comm)); }
    virtual string toString() const final {
        std::ostringstream oss;
        oss << "XCCL communicator";
        return oss.str();
    }
};

} // namespace infini
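Rank 0 performs a file-based rendezvous: it writes the BKCLUniqueId to "./<name>_xccl_id.bin" and the other ranks poll for that file (with a 10 s timeout) before calling bkcl_init_rank. The communicator itself is constructed by KUNLUNRuntimeObj::initComm, whose definition is not part of this excerpt; given the members declared in the runtime header above, a plausible implementation looks roughly like the following sketch (an assumption, not the committed code):

    // Sketch only: the actual KUNLUNRuntimeObj::initComm body is not shown in this diff.
    void KUNLUNRuntimeObj::initComm(const string &name, int worldSize, int rank) {
        IT_ASSERT(worldSize > 0);
        IT_ASSERT(rank >= 0 && rank < worldSize);
        IT_ASSERT(!comm); // only initialize once
    #ifdef INFINI_USE_XCCL
        comm = std::make_unique<XcclCommunicatorObj>(name, worldSize, rank);
    #else
        IT_TODO_HALT_MSG("InfiniTensor was not built with XCCL (BUILD_DIST).");
    #endif
    }

On the Python side this is reached through runtime.init_comm(name, world_size, rank), as bound in the pybind changes later in this commit and used by the launch script above.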
@@ -159,6 +159,7 @@ enum class CastType {
     Uint322Int64,
     Float162Float,
     BFloat162Float,
+    Float2Float,
 };

 class CastObj : public OperatorObj {
Deleted file — broadcastShape helper header (the declaration moves into the operator utils header below):
@@ -1,14 +0,0 @@
-#pragma once
-
-namespace infini {
-void broadcastShape(const Shape &originShape, SmallArray &modifyShape,
-                    int nDims, int size) {
-    for (int i = nDims - size - 1; i >= 0; --i) {
-        modifyShape.data[i] = 1;
-    }
-    for (int i = nDims - 1; i >= nDims - size; --i) {
-        modifyShape.data[i] = originShape[i - nDims + size];
-    }
-}
-
-} // namespace infini
@@ -5,6 +5,9 @@
 #include "core/operator.h"
 #include "core/tensor.h"

+#include "utils/small_array.h"
+#include <numeric>
+
 namespace infini {

 // Launch a broadcast shape based on the shape of input A and B
@@ -20,6 +23,12 @@ size_t delocate_index(const Shape &shapeIndex, const Shape &shape,
                       const Shape &stride);
 // Convert KernelAttrs to a string representation
 std::string get_kernel_attrs_str(const KernelAttrs &kernelAttrs);
+// VectorProd
+int shapeProd(std::vector<int>::iterator start, std::vector<int>::iterator end);
+void broadcastShape(const Shape &originShape, SmallArray &modifyShape,
+                    int nDims, int size);
+void broadcastShape(const Shape &tempShape, Shape &modifyShape);
+
 } // namespace infini

 #endif
@@ -4,6 +4,14 @@ namespace infini {
 #define SMALL_ARRAY_SIZE 8
 struct SmallArray {
     int data[SMALL_ARRAY_SIZE];
+
+    int prod(int start, int end) {
+        int result = 1;
+        for (int i = start; i < end; ++i) {
+            result *= data[i];
+        }
+        return result;
+    }
 };

 } // namespace infini
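prod multiplies the entries in the half-open index range [start, end), which kernels can use to collapse a run of dimensions into one element count. A short usage fragment (assuming the header above is included; values are illustrative):

    infini::SmallArray shape;
    shape.data[0] = 2;
    shape.data[1] = 3;
    shape.data[2] = 4;
    shape.data[3] = 5;
    int outer = shape.prod(0, 2); // 2 * 3 = 6
    int inner = shape.prod(2, 4); // 4 * 5 = 20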
@@ -1,4 +1,4 @@
 import backend
 from onnx import (
     ModelProto,
     TensorProto,
@@ -208,8 +208,8 @@ class OnnxStub:
             )
         elif node.op_type == "MatMul":
             tensors[node.output[0]] = self.handler.matmul(
-                tensors[node.input[0]],
-                tensors[node.input[1]],
+                tensors[node.input[0]],  # input
+                tensors[node.input[1]],  # weight
                 tensors.get(node.output[0]),
                 False,
                 False,
@@ -695,6 +695,8 @@ static CastType inferCastType(Tensor input, int to) {
         return CastType::Float162Float;
     } else if (iType == DataType::BFloat16 && oType == DataType::Float32) {
         return CastType::BFloat162Float;
+    } else if (iType == DataType::Float32 && oType == DataType::Float32) {
+        return CastType::Float2Float;
     } else {
         IT_TODO_HALT_MSG("Unsupported CastType : input_type is " +
                          iType.toString() + " output_type is " +
@@ -66,6 +66,36 @@ void TensorObj::setShape(Shape shape_) {
     _size = size;
 }

+void TensorObj::dumpData(std::ofstream &ofs) const {
+    IT_ASSERT(data != nullptr);
+    if (!runtime->isCpu())
+        IT_TODO_HALT();
+
+#define TRY_DUMP(N)                                                            \
+    if (dtype == DataType(N))                                                  \
+        ofs << dataToString<DT<N>::t>() << std::endl;
+
+    TRY_DUMP(0) // fmt: new line
+    else TRY_DUMP(1)  //
+    else TRY_DUMP(2)  //
+    else TRY_DUMP(3)  //
+    else TRY_DUMP(4)  //
+    else TRY_DUMP(5)  //
+    else TRY_DUMP(6)  //
+    else TRY_DUMP(7)  //
+    else TRY_DUMP(8)  //
+    else TRY_DUMP(9)  //
+    else TRY_DUMP(10) //
+    else TRY_DUMP(11) //
+    else TRY_DUMP(12) //
+    else TRY_DUMP(13) //
+    else TRY_DUMP(16) //
+    else IT_TODO_HALT();
+    ofs.flush();
+
+#undef TRY_DUMP
+}
+
 void TensorObj::printData() const {
     IT_ASSERT(data != nullptr);
     if (!runtime->isCpu())
@@ -429,7 +429,9 @@ void init_graph_builder(py::module &m) {
 #endif
 #ifdef USE_KUNLUN
     py::class_<KUNLUNRuntimeObj, std::shared_ptr<KUNLUNRuntimeObj>, RuntimeObj>(
-        m, "KUNLUNRuntime");
+        m, "KUNLUNRuntime")
+        .def(py::init<int>(), py::arg("device") = 0)
+        .def("init_comm", &KUNLUNRuntimeObj::initComm);
 #endif
     py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor",
                                                       py::buffer_protocol())
@@ -145,6 +145,9 @@ class NativeUnary : public CpuKernelWithoutConfig {
         case OpType::Atanh:
             _doCompute = aTanhCompute<T>;
             break;
+        case OpType::Acosh:
+            _doCompute = aCoshCompute<T>;
+            break;
         default:
             IT_TODO_HALT();
         }
@@ -2,7 +2,7 @@
 #include "cuda/cuda_kernel_wihtout_config.h"
 #include "cuda/cuda_runtime.h"
 #include "cuda/cuda_where.h"
-#include "utils/broadcast_shape.h"
+#include "utils/operator_utils.h"

 namespace infini {
New file — Kunlun XCCL AllGather kernel:
@@ -0,0 +1,43 @@
#ifdef INFINI_USE_XCCL
#include "operators/all_gather.h"
#include "kunlun/kunlun_kernel_without_config.h"
#include "kunlun/kunlun_runtime.h"
#include "kunlun/xccl_communicator.h"

namespace infini {
class AllGatherXCCL : public KUNLUNKernelWithoutConfig {
  public:
    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        auto op = as<AllGatherObj>(_op);
        IT_ASSERT(op->getDType() == DataType::Float32);
        auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
        int world_size = op->getWorldSize();
        IT_ASSERT(world_size == context->getCommunicator().getWorldSize());
        void *input = op->getInputs(0)->getRawDataPtr<void *>();
        KUNLUNPtr output_temp =
            context->getWorkspace(op->getInputs(0)->getBytes() * world_size);
        IT_ASSERT(op->getDType() == DataType::Float32);
        size_t bytes = op->getInputs(0)->getBytes();
        size_t count = bytes / op->getDType().getSize();

        BKCLContext_t comm =
            dynamic_cast<XcclCommunicatorObj &>(context->getCommunicator())
                .getXcclComm();
        // TODO: Using the default stream 0
        checkXcclError(
            bkcl_all_gather(comm, input, count, output_temp, BKCL_FLOAT, 0));

        for (int i = 0; i < world_size; ++i) {
            Tensor output = op->getOutput(i);
            context->copyBlobInsideRuntime(
                output->getRawDataPtr<float *>(),
                static_cast<float *>(output_temp) + i * count, bytes);
        }
    }
};

REGISTER_KERNEL(Device::KUNLUN, OpType::AllGather, AllGatherXCCL,
                "AllGatcher_XCCL_KUNLUN");
} // namespace infini
#endif
New file — Kunlun XCCL AllReduce kernels:
@@ -0,0 +1,49 @@
#ifdef INFINI_USE_XCCL
#include "operators/all_reduce.h"
#include "kunlun/kunlun_kernel_without_config.h"
#include "kunlun/kunlun_runtime.h"
#include "kunlun/xccl_communicator.h"

namespace infini {
class AllReduceXCCL : public KUNLUNKernelWithoutConfig {
  public:
    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        auto op = as<AllReduceBaseObj>(_op);
        IT_ASSERT(op->getDType() == DataType::Float32);
        auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
        void *input = op->getInputs(0)->getRawDataPtr<void *>();
        void *output = op->getOutput(0)->getRawDataPtr<void *>();
        IT_ASSERT(op->getDType() == DataType::Float32);
        size_t count = op->getInputs(0)->size();

        BKCLContext_t comm =
            dynamic_cast<XcclCommunicatorObj &>(context->getCommunicator())
                .getXcclComm();
        checkXcclError(bkcl_all_reduce(comm, input, output, count,
                                       BKCLDataType::BKCL_FLOAT, getRedOp(),
                                       0));
    }
    virtual BKCLOp getRedOp() const = 0;
};

class AllReduceSumXCCL : public AllReduceXCCL {
    BKCLOp getRedOp() const override { return BKCLOp::BKCL_ADD; }
};

class AllReduceMinXCCL : public AllReduceXCCL {
    BKCLOp getRedOp() const override { return BKCLOp::BKCL_MIN; }
};

class AllReduceMaxXCCL : public AllReduceXCCL {
    BKCLOp getRedOp() const override { return BKCLOp::BKCL_MAX; }
};

REGISTER_KERNEL(Device::KUNLUN, OpType::AllReduceSum, AllReduceSumXCCL,
                "AllReduce_Sum_XCCL_KUNLUN");
REGISTER_KERNEL(Device::KUNLUN, OpType::AllReduceMax, AllReduceMaxXCCL,
                "AllReduce_Max_XCCL_KUNLUN");
REGISTER_KERNEL(Device::KUNLUN, OpType::AllReduceMin, AllReduceMinXCCL,
                "AllReduce_Min_XCCL_KUNLUN");
} // namespace infini
#endif
|
@ -26,7 +26,7 @@ class BatchNormXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
int h = dims[2];
|
int h = dims[2];
|
||||||
int c = dims[1];
|
int c = dims[1];
|
||||||
int n = dims[0];
|
int n = dims[0];
|
||||||
auto ret = baidu::xpu::api::batch_norm_infer<float>(
|
auto ret = xdnn::batch_norm_infer<float>(
|
||||||
context->KUNLUNHandle(), (float *)input, (float *)output, n, c, h,
|
context->KUNLUNHandle(), (float *)input, (float *)output, n, c, h,
|
||||||
w, op->getEps(), (float *)scale, (float *)bias, (float *)mean,
|
w, op->getEps(), (float *)scale, (float *)bias, (float *)mean,
|
||||||
(float *)var, true);
|
(float *)var, true);
|
||||||
|
|
|
New file — Kunlun XCCL Broadcast kernel:
@@ -0,0 +1,32 @@
#ifdef INFINI_USE_XCCL
#include "operators/broadcast.h"
#include "kunlun/kunlun_kernel_without_config.h"
#include "kunlun/kunlun_runtime.h"
#include "kunlun/xccl_communicator.h"

namespace infini {
class BroadcastXCCL : public KUNLUNKernelWithoutConfig {
  public:
    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        auto op = as<BroadcastObj>(_op);
        IT_ASSERT(op->getDType() == DataType::Float32);
        auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
        void *input = op->getInputs(0)->getRawDataPtr<void *>();
        void *output = op->getOutput()->getRawDataPtr<void *>();
        IT_ASSERT(op->getDType() == DataType::Float32);
        size_t count = op->getInputs(0)->getBytes() / op->getDType().getSize();

        BKCLContext_t comm =
            dynamic_cast<XcclCommunicatorObj &>(context->getCommunicator())
                .getXcclComm();
        // TODO: Using default stream 0 for now.
        checkXcclError(bkcl_broadcast(comm, input, output, count, BKCL_FLOAT,
                                      op->getRoot(), 0));
    }
};

REGISTER_KERNEL(Device::KUNLUN, OpType::Broadcast, BroadcastXCCL,
                "Broadcast_XCCL_KUNLUN");
} // namespace infini
#endif
In the hunk below only the new side is shown: the removed lines were identical except that every call spelled out baidu::xpu::api:: where the new code uses the xdnn alias (with arguments re-wrapped accordingly), and there was no Float2Float case.
@@ -17,74 +17,78 @@ class CastXdnn : public KUNLUNKernelWithoutConfig {
        int ret = 0;
        switch (type) {
        case CastType::Float2Float16:
            ret = xdnn::cast<float, float16>(
                context->KUNLUNHandle(), (float *)aData, (float16 *)cData, len);
            break;
        case CastType::Float2Int64:
            ret = xdnn::cast<float, int64_t>(
                context->KUNLUNHandle(), (float *)aData, (int64_t *)cData, len);
            break;
        case CastType::Float2Int32:
            ret = xdnn::cast<float, int>(context->KUNLUNHandle(),
                                         (float *)aData, (int *)cData, len);
            break;
        case CastType::Float2Int16:
            ret = xdnn::cast<float, int16_t>(
                context->KUNLUNHandle(), (float *)aData, (int16_t *)cData, len);
            break;
        case CastType::Float2Int8:
            ret = xdnn::cast<float, int8_t>(
                context->KUNLUNHandle(), (float *)aData, (int8_t *)cData, len);
            break;
        case CastType::Int322Float:
            ret = xdnn::cast<int, float>(context->KUNLUNHandle(), (int *)aData,
                                         (float *)cData, len);
            break;
        case CastType::Int322Int8:
            ret = xdnn::cast<int, int8_t>(context->KUNLUNHandle(), (int *)aData,
                                          (int8_t *)cData, len);
            break;
        case CastType::Int322Int16:
            ret = xdnn::cast<int, int16_t>(context->KUNLUNHandle(),
                                           (int *)aData, (int16_t *)cData, len);
            break;
        case CastType::Int162Float:
            ret = xdnn::cast<int16_t, float>(
                context->KUNLUNHandle(), (int16_t *)aData, (float *)cData, len);
            break;
        case CastType::Int162Int32:
            ret = xdnn::cast<int16_t, int>(context->KUNLUNHandle(),
                                           (int16_t *)aData, (int *)cData, len);
            break;
        case CastType::Int82Float:
            ret = xdnn::cast<int8_t, float>(
                context->KUNLUNHandle(), (int8_t *)aData, (float *)cData, len);
            break;
        case CastType::Int82Int16:
            ret = xdnn::cast<int8_t, int16_t>(context->KUNLUNHandle(),
                                              (int8_t *)aData, (int16_t *)cData,
                                              len);
            break;
        case CastType::Int82Int32:
            ret = xdnn::cast<int8_t, int>(context->KUNLUNHandle(),
                                          (int8_t *)aData, (int *)cData, len);
            break;
        case CastType::Int322Int64:
            ret = xdnn::cast<int, int64_t>(context->KUNLUNHandle(),
                                           (int *)aData, (int64_t *)cData, len);
            break;
        case CastType::Int642Int32:
            ret = xdnn::cast<int64_t, int>(context->KUNLUNHandle(),
                                           (int64_t *)aData, (int *)cData, len);
            break;
        case CastType::Int642Float:
            ret = xdnn::cast<int64_t, float>(
                context->KUNLUNHandle(), (int64_t *)aData, (float *)cData, len);
            break;
        case CastType::Float162Float:
            ret = xdnn::cast<float16, float>(
                context->KUNLUNHandle(), (float16 *)aData, (float *)cData, len);
            break;
        case CastType::Float2Float:
            ret = xdnn::copy<float>(context->KUNLUNHandle(), (float *)aData,
                                    (float *)cData, len);
            break;
        default:
            IT_TODO_HALT();
        }
@@ -26,8 +26,8 @@ class ConcatXdnn : public KUNLUNKernelWithoutConfig {
             }
             dims.push_back(dim);
         }
-        auto ret = baidu::xpu::api::concat<float>(
-            context->KUNLUNHandle(), inputsData, (float *)cData, dims, axis);
+        auto ret = xdnn::concat<float>(context->KUNLUNHandle(), inputsData,
+                                       (float *)cData, dims, axis);
         assert(ret == 0);
         return;
     }
@@ -24,11 +24,17 @@ class ConvXdnn : public KUNLUNKernelWithoutConfig {
         std::vector<int> stride = {sh, sw};
         std::vector<int> dilation = {dh, dw};

-        auto ret = baidu::xpu::api::conv2d<float, float, float, float>(
+        // TODO: Convolution operators still have some accuracy problems
+        checkKUNLUNError((xdnn::conv2d<float, float, float, float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, n, c, h, w, f, ksize, stride, pads, dilation, g,
-            nullptr, nullptr, nullptr, true);
-        assert(ret == 0);
+            nullptr, nullptr, nullptr, true)));
+
+        // checkKUNLUNError((xdnn::conv2d_fusion<float, float, float, float>(
+        //     context->KUNLUNHandle(), (float *const)aData, (float
+        //     *const)bData, (float *)cData, n, c, h, w, f, ksize, stride, pads,
+        //     dilation, g, nullptr, nullptr, nullptr, true, nullptr, nullptr,
+        //     xdnn::Activation_t::LINEAR)));
         return;
     }
 };
@@ -37,11 +37,10 @@ class ConvTransXdnn : public KUNLUNKernelWithoutConfig {
         if (dimOutput.size() != 4)
             IT_TODO_HALT();

-        auto ret =
-            baidu::xpu::api::conv2d_transpose<float, float, float, float>(
-                context->KUNLUNHandle(), (float *)aData, (float *)bData,
-                (float *)cData, n, c, h, w, f, ksize, stride, pads, dilation, g,
-                nullptr, nullptr, nullptr, isNCHW);
+        auto ret = xdnn::conv2d_transpose<float, float, float, float>(
+            context->KUNLUNHandle(), (float *)aData, (float *)bData,
+            (float *)cData, n, c, h, w, f, ksize, stride, pads, dilation, g,
+            nullptr, nullptr, nullptr, isNCHW);
         assert(ret == 0);
         return;
     }
@ -1,6 +1,7 @@
|
||||||
#include "operators/element_wise.h"
|
#include "operators/element_wise.h"
|
||||||
#include "kunlun/kunlun_kernel_without_config.h"
|
#include "kunlun/kunlun_kernel_without_config.h"
|
||||||
#include "kunlun/kunlun_runtime.h"
|
#include "kunlun/kunlun_runtime.h"
|
||||||
|
#include "utils/operator_utils.h"
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
class AddXdnn : public KUNLUNKernelWithoutConfig {
|
class AddXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
|
@ -22,10 +23,9 @@ class AddXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
if (bDim.size() == 0) {
|
if (bDim.size() == 0) {
|
||||||
bDim.push_back(1);
|
bDim.push_back(1);
|
||||||
}
|
}
|
||||||
auto ret = baidu::xpu::api::broadcast_add<float>(
|
checkKUNLUNError(xdnn::broadcast_add<float>(
|
||||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||||
(float *)cData, aDim, bDim);
|
(float *)cData, aDim, bDim));
|
||||||
assert(ret == 0);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -49,10 +49,9 @@ class SubXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
if (bDim.size() == 0) {
|
if (bDim.size() == 0) {
|
||||||
bDim.push_back(1);
|
bDim.push_back(1);
|
||||||
}
|
}
|
||||||
auto ret = baidu::xpu::api::broadcast_sub<float>(
|
checkKUNLUNError(xdnn::broadcast_sub<float>(
|
||||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||||
(float *)cData, aDim, bDim);
|
(float *)cData, aDim, bDim));
|
||||||
assert(ret == 0);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -76,10 +75,9 @@ class MulXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
if (bDim.size() == 0) {
|
if (bDim.size() == 0) {
|
||||||
bDim.push_back(1);
|
bDim.push_back(1);
|
||||||
}
|
}
|
||||||
auto ret = baidu::xpu::api::broadcast_mul<float>(
|
checkKUNLUNError(xdnn::broadcast_mul<float>(
|
||||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||||
(float *)cData, aDim, bDim);
|
(float *)cData, aDim, bDim));
|
||||||
assert(ret == 0);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -95,18 +93,40 @@ class DivXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||||
|
|
||||||
|
auto aSize = op->getInputs(0)->size();
|
||||||
auto aDim = op->getInputs(0)->getDims();
|
auto aDim = op->getInputs(0)->getDims();
|
||||||
|
auto bSize = op->getInputs(1)->size();
|
||||||
auto bDim = op->getInputs(1)->getDims();
|
auto bDim = op->getInputs(1)->getDims();
|
||||||
if (aDim.size() == 0) {
|
auto dtype = op->getDType();
|
||||||
aDim.push_back(1);
|
|
||||||
}
|
|
||||||
if (bDim.size() == 0) {
|
if (bDim.size() == 0) {
|
||||||
bDim.push_back(1);
|
bDim.push_back(1);
|
||||||
}
|
}
|
||||||
auto ret = baidu::xpu::api::broadcast_div<float>(
|
|
||||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
if (aSize == bSize) {
|
||||||
(float *)cData, aDim, bDim);
|
// Do ElementWise Sub with no broadcast
|
||||||
assert(ret == 0);
|
checkKUNLUNError(xdnn::div<float>(context->KUNLUNHandle(),
|
||||||
|
(float *)aData, (float *)bData,
|
||||||
|
(float *)cData, aSize));
|
||||||
|
} else {
|
||||||
|
// Do broadcast div
|
||||||
|
Shape aligned = infer_broadcast(aDim, bDim);
|
||||||
|
if (aligned == aDim) {
|
||||||
|
// BData need to be broadcasted
|
||||||
|
checkKUNLUNError(xdnn::broadcast_div<float>(
|
||||||
|
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||||
|
(float *)cData, aDim, bDim));
|
||||||
|
} else {
|
||||||
|
// Use workspace to broadcast aData
|
||||||
|
KUNLUNPtr wks = context->getWorkspace(bSize * dtype.getSize());
|
||||||
|
checkKUNLUNError(xdnn::broadcast<float>(
|
||||||
|
context->KUNLUNHandle(), (float *)aData, (float *)wks, aDim,
|
||||||
|
bDim));
|
||||||
|
checkKUNLUNError(xdnn::div<float>(context->KUNLUNHandle(),
|
||||||
|
(float *)wks, (float *)bData,
|
||||||
|
(float *)cData, bSize));
|
||||||
|
}
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@@ -131,10 +151,9 @@ class PowXdnn : public KUNLUNKernelWithoutConfig {
             bDim.push_back(1);
         }

-        auto ret = baidu::xpu::api::broadcast_pow<float>(
+        checkKUNLUNError(xdnn::broadcast_pow<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
-            (float *)cData, aDim, bDim);
-        assert(ret == 0);
+            (float *)cData, aDim, bDim));
         return;
     }
 };

@@ -158,10 +177,9 @@ class MaxXdnn : public KUNLUNKernelWithoutConfig {
         if (bDim.size() == 0) {
             bDim.push_back(1);
         }
-        auto ret = baidu::xpu::api::broadcast_max<float>(
+        checkKUNLUNError(xdnn::broadcast_max<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
-            (float *)cData, aDim, bDim);
-        assert(ret == 0);
+            (float *)cData, aDim, bDim));
         return;
     }
 };

@@ -185,10 +203,9 @@ class MinXdnn : public KUNLUNKernelWithoutConfig {
         if (bDim.size() == 0) {
             bDim.push_back(1);
         }
-        auto ret = baidu::xpu::api::broadcast_min<float>(
+        checkKUNLUNError(xdnn::broadcast_min<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
-            (float *)cData, aDim, bDim);
-        assert(ret == 0);
+            (float *)cData, aDim, bDim));
         return;
     }
 };
@@ -204,7 +221,9 @@ class EqualXdnn : public KUNLUNKernelWithoutConfig {
         void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         size_t len = op->getOutput()->size();
-        KUNLUNPtr wsData = context->getWorkspace(len);
+        auto dtype = op->getDType();
+
+        KUNLUNPtr wsData = context->getWorkspace(len * dtype.getSize());

         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();

@@ -214,12 +233,11 @@ class EqualXdnn : public KUNLUNKernelWithoutConfig {
         if (bDim.size() == 0) {
             bDim.push_back(1);
         }
-        auto ret = baidu::xpu::api::broadcast_equal<float>(
+        checkKUNLUNError(xdnn::broadcast_equal<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
-            (bool *)wsData, aDim, bDim);
-        ret = baidu::xpu::api::cast<bool, float>(
-            context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len);
-        assert(ret == 0);
+            (bool *)wsData, aDim, bDim));
+        checkKUNLUNError((xdnn::cast<bool, float>(
+            context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len)));
         return;
     }
 };

@@ -235,7 +253,8 @@ class GreaterEqualXdnn : public KUNLUNKernelWithoutConfig {
         void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         size_t len = op->getOutput()->size();
-        KUNLUNPtr wsData = context->getWorkspace(len);
+        auto dtype = op->getDType();
+        KUNLUNPtr wsData = context->getWorkspace(len * dtype.getSize());

         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();

@@ -245,12 +264,11 @@ class GreaterEqualXdnn : public KUNLUNKernelWithoutConfig {
         if (bDim.size() == 0) {
             bDim.push_back(1);
         }
-        auto ret = baidu::xpu::api::broadcast_greater_equal<float>(
+        checkKUNLUNError(xdnn::broadcast_greater_equal<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
-            (bool *)wsData, aDim, bDim);
-        ret = baidu::xpu::api::cast<bool, float>(
-            context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len);
-        assert(ret == 0);
+            (bool *)wsData, aDim, bDim));
+        checkKUNLUNError((xdnn::cast<bool, float>(
+            context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len)));
         return;
     }
 };

@@ -266,7 +284,8 @@ class GreaterThanXdnn : public KUNLUNKernelWithoutConfig {
         void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         size_t len = op->getOutput()->size();
-        KUNLUNPtr wsData = context->getWorkspace(len);
+        KUNLUNPtr wsData =
+            context->getWorkspace(len * (op->getDType()).getSize());

         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();

@@ -276,12 +295,11 @@ class GreaterThanXdnn : public KUNLUNKernelWithoutConfig {
         if (bDim.size() == 0) {
             bDim.push_back(1);
         }
-        auto ret = baidu::xpu::api::broadcast_greater_than<float>(
+        checkKUNLUNError(xdnn::broadcast_greater_than<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
-            (bool *)wsData, aDim, bDim);
-        ret = baidu::xpu::api::cast<bool, float>(
-            context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len);
-        assert(ret == 0);
+            (bool *)wsData, aDim, bDim));
+        checkKUNLUNError((xdnn::cast<bool, float>(
+            context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len)));
         return;
     }
 };

@@ -297,7 +315,8 @@ class LessEqualXdnn : public KUNLUNKernelWithoutConfig {
         void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         size_t len = op->getOutput()->size();
-        KUNLUNPtr wsData = context->getWorkspace(len);
+        auto dtype = op->getDType();
+        KUNLUNPtr wsData = context->getWorkspace(len * dtype.getSize());

         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();

@@ -307,12 +326,11 @@ class LessEqualXdnn : public KUNLUNKernelWithoutConfig {
         if (bDim.size() == 0) {
             bDim.push_back(1);
         }
-        auto ret = baidu::xpu::api::broadcast_less_equal<float>(
+        checkKUNLUNError(xdnn::broadcast_less_equal<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
-            (bool *)wsData, aDim, bDim);
-        ret = baidu::xpu::api::cast<bool, float>(
-            context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len);
-        assert(ret == 0);
+            (bool *)wsData, aDim, bDim));
+        checkKUNLUNError((xdnn::cast<bool, float>(
+            context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len)));
         return;
     }
 };

@@ -328,7 +346,8 @@ class LessThanXdnn : public KUNLUNKernelWithoutConfig {
         void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         size_t len = op->getOutput()->size();
-        KUNLUNPtr wsData = context->getWorkspace(len);
+        auto dtype = op->getDType();
+        KUNLUNPtr wsData = context->getWorkspace(len * dtype.getSize());

         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();

@@ -338,12 +357,11 @@ class LessThanXdnn : public KUNLUNKernelWithoutConfig {
         if (bDim.size() == 0) {
             bDim.push_back(1);
         }
-        auto ret = baidu::xpu::api::broadcast_less_than<float>(
+        checkKUNLUNError(xdnn::broadcast_less_than<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
-            (bool *)wsData, aDim, bDim);
-        ret = baidu::xpu::api::cast<bool, float>(
-            context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len);
-        assert(ret == 0);
+            (bool *)wsData, aDim, bDim));
+        checkKUNLUNError((xdnn::cast<bool, float>(
+            context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len)));
         return;
     }
 };
@@ -367,10 +385,9 @@ class FloorDivXdnn : public KUNLUNKernelWithoutConfig {
         if (bDim.size() == 0) {
             bDim.push_back(1);
         }
-        auto ret = baidu::xpu::api::broadcast_floordiv<float>(
+        checkKUNLUNError(xdnn::broadcast_floordiv<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
-            (float *)cData, aDim, bDim);
-        assert(ret == 0);
+            (float *)cData, aDim, bDim));
         return;
     }
 };

@@ -388,10 +405,9 @@ class MSELossXdnn : public KUNLUNKernelWithoutConfig {
         size_t len = op->getOutput()->size();

         auto dim = op->getInputs(0)->getDims();
-        auto ret = baidu::xpu::api::mse_loss<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)bData,
-            (float *)cData, len);
-        assert(ret == 0);
+        checkKUNLUNError(xdnn::mse_loss<float>(context->KUNLUNHandle(),
+                                               (float *)aData, (float *)bData,
+                                               (float *)cData, len));
         return;
     }
 };

@@ -407,7 +423,8 @@ class AndXdnn : public KUNLUNKernelWithoutConfig {
         void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         size_t len = op->getOutput()->size();
-        KUNLUNPtr wsData = context->getWorkspace(len);
+        auto dtype = op->getDType();
+        KUNLUNPtr wsData = context->getWorkspace(len * dtype.getSize());

         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();

@@ -417,12 +434,11 @@ class AndXdnn : public KUNLUNKernelWithoutConfig {
         if (bDim.size() == 0) {
             bDim.push_back(1);
         }
-        auto ret = baidu::xpu::api::logical_and<bool>(
-            context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
-            (bool *)wsData, len);
-        ret = baidu::xpu::api::cast<bool, float>(
-            context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len);
-        assert(ret == 0);
+        checkKUNLUNError(xdnn::logical_and<bool>(context->KUNLUNHandle(),
+                                                 (bool *)aData, (bool *)bData,
+                                                 (bool *)wsData, len));
+        checkKUNLUNError((xdnn::cast<bool, float>(
+            context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len)));
         return;
     }
 };

@@ -438,7 +454,8 @@ class OrXdnn : public KUNLUNKernelWithoutConfig {
         void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         size_t len = op->getOutput()->size();
-        KUNLUNPtr wsData = context->getWorkspace(len);
+        auto dtype = op->getDType();
+        KUNLUNPtr wsData = context->getWorkspace(len * dtype.getSize());

         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();

@@ -448,12 +465,11 @@ class OrXdnn : public KUNLUNKernelWithoutConfig {
         if (bDim.size() == 0) {
             bDim.push_back(1);
         }
-        auto ret = baidu::xpu::api::logical_or<bool>(
-            context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
-            (bool *)wsData, len);
-        ret = baidu::xpu::api::cast<bool, float>(
-            context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len);
-        assert(ret == 0);
+        checkKUNLUNError(xdnn::logical_or<bool>(context->KUNLUNHandle(),
+                                                (bool *)aData, (bool *)bData,
+                                                (bool *)wsData, len));
+        checkKUNLUNError((xdnn::cast<bool, float>(
+            context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len)));
         return;
     }
 };

@@ -469,7 +485,8 @@ class XorXdnn : public KUNLUNKernelWithoutConfig {
         void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         size_t len = op->getOutput()->size();
-        KUNLUNPtr wsData = context->getWorkspace(len);
+        auto dtype = op->getDType();
+        KUNLUNPtr wsData = context->getWorkspace(len * dtype.getSize());

         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();

@@ -479,12 +496,11 @@ class XorXdnn : public KUNLUNKernelWithoutConfig {
         if (bDim.size() == 0) {
             bDim.push_back(1);
         }
-        auto ret = baidu::xpu::api::logical_xor<bool>(
-            context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
-            (bool *)wsData, len);
-        ret = baidu::xpu::api::cast<bool, float>(
-            context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len);
-        assert(ret == 0);
+        checkKUNLUNError(xdnn::logical_xor<bool>(context->KUNLUNHandle(),
+                                                 (bool *)aData, (bool *)bData,
+                                                 (bool *)wsData, len));
+        checkKUNLUNError((xdnn::cast<bool, float>(
+            context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len)));
         return;
     }
 };

@@ -499,14 +515,14 @@ class NotXdnn : public KUNLUNKernelWithoutConfig {
         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         size_t len = op->getOutput()->size();
-        KUNLUNPtr wsData = context->getWorkspace(len);
+        auto dtype = op->getDType();
+        KUNLUNPtr wsData = context->getWorkspace(len * dtype.getSize());

         auto aDim = op->getInputs(0)->getDims();
-        auto ret = baidu::xpu::api::logical_not<bool>(
-            context->KUNLUNHandle(), (bool *)aData, (bool *)wsData, len);
-        ret = baidu::xpu::api::cast<bool, float>(
-            context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len);
-        assert(ret == 0);
+        checkKUNLUNError(xdnn::logical_not<bool>(
+            context->KUNLUNHandle(), (bool *)aData, (bool *)wsData, len));
+        checkKUNLUNError((xdnn::cast<bool, float>(
+            context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len)));
         return;
     }
 };
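The element-wise hunks above all make the same two changes: the return-code-plus-assert pattern becomes a checkKUNLUNError(...) call (so failures are still caught when assert is compiled out in release builds), and workspace sizes are scaled by the element size instead of the raw element count. A minimal sketch of what such a checking macro typically looks like; the actual macro in the repository's KUNLUN headers may differ in detail:

    // Sketch only: assumes xdnn calls return 0 on success.
    #include <cstdio>
    #include <cstdlib>
    #define checkKUNLUNError(call)                                            \
        do {                                                                  \
            auto _err = (call);                                               \
            if (_err != 0) {                                                  \
                std::fprintf(stderr, "KUNLUN error %d at %s:%d\n", (int)_err, \
                             __FILE__, __LINE__);                             \
                std::exit(EXIT_FAILURE);                                      \
            }                                                                 \
        } while (0)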
@@ -1,4 +1,5 @@
 #include "operators/gather.h"
+#include "core/common.h"
 #include "kunlun/kunlun_kernel_without_config.h"
 #include "kunlun/kunlun_runtime.h"

@@ -10,17 +11,18 @@ class GatherXdnn : public KUNLUNKernelWithoutConfig {
         IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);

-        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
-        void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
+        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>()); // data
+        void *const bData =
+            (op->getInputs(1)->getRawDataPtr<void *>()); // indice
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());

-        auto shape = op->getInputs(0)->getDims();
-        auto index = op->getInputs(1)->getDims();
-        auto axis = op->getAxis();
-        auto ret = baidu::xpu::api::gather<float, int>(
+        Shape aShape = op->getInputs(0)->getDims();
+        Tensor bTensor = op->getInputs(1);
+        int axis = op->getAxis();
+        checkKUNLUNError((baidu::xpu::api::gather<float, int>(
             context->KUNLUNHandle(), (float *)aData, (int *)bData,
-            (float *)cData, shape, index.size(), axis);
-        assert(ret == 0);
+            (float *)cData, aShape, bTensor->size(), axis)));
         return;
     }
 };
@@ -1,30 +1,123 @@
 #include "operators/matmul.h"
+#include "kunlun/kunlun_act_type.h"
+#include "kunlun/kunlun_common.h"
 #include "kunlun/kunlun_kernel_without_config.h"
 #include "kunlun/kunlun_runtime.h"
+#include "utils/operator_utils.h"

 namespace infini {
 class MatmulXdnn : public KUNLUNKernelWithoutConfig {

     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
+        // This kernel do C = act(alpha * x * w + beta * bias)
         auto op = as<MatmulObj>(_op);
         IT_ASSERT(op->getDType() == DataType::Float32);
         auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);

         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
-        void *const cData = (op->getOutput()->getRawDataPtr<void *>());
+        void *const outData = (op->getOutput()->getRawDataPtr<void *>());

+        Shape aDims = op->getInputs(0)->getDims();
+        Shape bDims = op->getInputs(1)->getDims();
+        Shape cDims = op->getOutput()->getDims();

+        const auto [b, m, n, k] = op->getBMNK();
         bool transA = op->getTransA();
         bool transB = op->getTransB();
+        int rankA = op->getInputs(0)->getRank();
+        int rankB = op->getInputs(1)->getRank();
+        int rankAligned = std::max(rankA, rankB);
+        IT_ASSERT(rankAligned <= SMALL_ARRAY_SIZE);

-        auto b = op->getB();
-        auto m = op->getM();
-        auto n = op->getN();
-        auto k = op->getK();
+        float alpha = 1.f, beta = 0.f;
+        Tensor biasTensor = op->getBias();
+        DataType dtype = op->getDType();

-        auto ret = baidu::xpu::api::fc_batched<float, float, float, float>(
-            context->KUNLUNHandle(), b, transA, transB, m, n, k, 1.0,
-            (float *)aData, m * k, (float *)bData, n * k, 0.0, (float *)cData,
-            m * n, nullptr, nullptr);
-        assert(ret == 0);
+        if (b > 1) {
+            SmallArray alignedAShape;
+            SmallArray alignedBShape;
+            // Padding 1 in aShape and bShape in order to align rank
+            broadcastShape(aDims, alignedAShape, rankAligned, rankA);
+            broadcastShape(bDims, alignedBShape, rankAligned, rankB);
+            // Calculate batch dim
+            int batchA = alignedAShape.prod(0, rankAligned - 2);
+            int batchB = alignedBShape.prod(0, rankAligned - 2);
+            // View aShape bShape to 3 dim
+            Shape aDimsMatmul = {batchA, aDims[rankA - 2], aDims[rankA - 1]};
+            Shape bDimsMatmul = {batchB, bDims[rankB - 2], bDims[rankB - 1]};
+            auto numOutput = op->getOutput()->size();
+            KUNLUNPtr wkspace = nullptr;
+            void *AData = nullptr;
+            void *BData = nullptr;
+            void *CData = nullptr;
+            if (batchA != batchB) {
+                // If bs not equal, then broadcast
+                IT_ASSERT(batchA == 1 || batchB == 1);
+                if (batchA == 1) {
+                    // Broadcast aShapeMatmul in batch dimension
+                    Shape aDimsTarget = {b, aDimsMatmul[1], aDimsMatmul[2]};
+                    auto numInput =
+                        shapeProd(aDimsTarget.begin(), aDimsTarget.end());
+                    wkspace = context->getWorkspace(numInput * dtype.getSize());
+                    checkKUNLUNError(xdnn::broadcast<float>(
+                        context->KUNLUNHandle(), (float *)aData,
+                        (float *)wkspace, aDimsMatmul, aDimsTarget));
+                    AData = wkspace;
+                    BData = bData;
+                    CData =
+                        biasTensor
+                            ? context->getWorkspace(numOutput * dtype.getSize())
+                            : outData;
+                } else {
+                    // Broadcast bShapeMatmul in batch dimension
+                    Shape bDimsTarget = {b, bDimsMatmul[1], bDimsMatmul[2]};
+                    auto numInput =
+                        shapeProd(bDimsTarget.begin(), bDimsTarget.end());
+                    wkspace = context->getWorkspace(numInput * dtype.getSize());
+                    checkKUNLUNError(xdnn::broadcast<float>(
+                        context->KUNLUNHandle(), (float *)bData,
+                        (float *)wkspace, bDimsMatmul, bDimsTarget));
+                    AData = aData;
+                    BData = wkspace;
+                    CData =
+                        biasTensor
+                            ? context->getWorkspace(numOutput * dtype.getSize())
+                            : outData;
+                } // endif batchA == 1
+            } else { // batchA == batchB, no need to broadcast
+                AData = aData;
+                BData = bData;
+                CData = biasTensor
+                            ? context->getWorkspace(numOutput * dtype.getSize())
+                            : outData;
+            }
+            checkKUNLUNError((xdnn::fc_batched<float, float, float, float>(
+                context->KUNLUNHandle(), b, transA, transB, m, n, k, alpha,
+                (float *)AData, m * k, (float *)BData, n * k, beta,
+                (float *)CData, m * n, nullptr, nullptr)));
+            // Broadcast_add xw and bias if bias exists
+            if (biasTensor) {
+                auto biasShape = biasTensor->getDims();
+                broadcastShape(cDims, biasShape);
+                checkKUNLUNError(baidu::xpu::api::broadcast_add<float>(
+                    context->KUNLUNHandle(), (float *)CData,
+                    biasTensor->getRawDataPtr<float *>(), (float *)outData,
+                    cDims, biasShape));
+            }
+        } else {
+            // Matmul with no batch, call fc_fusion
+            const int lda = transA ? m : k, ldb = transB ? k : n, ldc = n;
+            auto kunlunAct = parseActType(std::move(op->getAct()));
+            checkKUNLUNError(
+                (baidu::xpu::api::fc_fusion<float, float, float, float>(
+                    context->KUNLUNHandle(), (float *)aData, (float *)bData,
+                    (float *)outData, m, n, k, transA, transB, nullptr, nullptr,
+                    nullptr, lda, ldb, ldc, alpha, 0.f,
+                    biasTensor ? biasTensor->getRawDataPtr<float *>() : nullptr,
+                    kunlunAct, nullptr)));
+        }
         return;
     }
 };
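The idea behind the batch handling above: pad the lower-rank operand's shape with leading 1s, take the product of everything except the last two axes as its batch count, and broadcast the smaller batch to the larger one before calling fc_batched. A small illustrative sketch of that shape arithmetic (helper names here are examples, not the repository's exact API):

    #include <numeric>
    #include <functional>
    #include <vector>
    // Product of all axes except the trailing (rows, cols) pair.
    static int batchOf(const std::vector<int> &dims) {
        return std::accumulate(dims.begin(), dims.end() - 2, 1,
                               std::multiplies<int>());
    }
    // Example: A is [2, 3, 4, 5], B is [4, 5]. Pad B to [1, 1, 4, 5];
    // batchOf(A) = 6, batchOf(B) = 1, so B is broadcast to 6 batches first.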
@@ -14,11 +14,23 @@ class AvgPooling : public KUNLUNKernelWithoutConfig {

         auto [n, c, h, w, kh, kw] = op->getNCHWRS();
         auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
+        auto outShape = op->getOutput()->getDims();

         std::vector<int> ksize = {kh, kw};
         std::vector<int> stride = {sh, sw};
         std::vector<int> pad = {ph, pw};

+        int yh = outShape[op->getOutput()->getRank() - 2];
+        int yw = outShape[op->getOutput()->getRank() - 1];
+
+        // If Maxpool with ceilMode true
+        // We need to change padding in order to call xdnn api
+        if (op->getCeilMode() && yh > (h + 2 * ph - kh) / sh + 1) {
+            auto padh = yh - ((h + 2 * ph - kh) / sh + 1);
+            auto padw = yw - ((w + 2 * pw - kw) / sw + 1);
+            pad = {0, padh, 0, padw};
+        }
+
         auto ret = baidu::xpu::api::avg_pool2d<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)cData, n, c, h, w,
             ksize, stride, pad, true, true, nullptr, nullptr);

@@ -38,21 +50,30 @@ class MaxPooling : public KUNLUNKernelWithoutConfig {

         auto [n, c, h, w, kh, kw] = op->getNCHWRS();
         auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
+        auto outShape = op->getOutput()->getDims();

         std::vector<int> ksize = {kh, kw};
         std::vector<int> stride = {sh, sw};
         std::vector<int> pad = {ph, pw};

-        int yh = (h + ph * 2 - kh) / sh + 1;
-        int yw = (w + pw * 2 - kw) / sw + 1;
+        int yh = outShape[op->getOutput()->getRank() - 2];
+        int yw = outShape[op->getOutput()->getRank() - 1];

-        KUNLUNPtr indices = context->getWorkspace(yh * yw * 4);
+        // If Maxpool with ceilMode true
+        // We need to change padding in order to call xdnn api
+        if (op->getCeilMode() && yh > (h + 2 * ph - kh) / sh + 1) {
+            auto padh = yh - ((h + 2 * ph - kh) / sh + 1);
+            auto padw = yw - ((w + 2 * pw - kw) / sw + 1);
+            pad = {0, padh, 0, padw};
+        }

-        auto ret = baidu::xpu::api::max_pool2d<float>(
+        KUNLUNPtr indices = context->getWorkspace(yh * yw * sizeof(int));
+
+        checkKUNLUNError(baidu::xpu::api::max_pool2d<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)cData,
             (int *)indices, n, c, h, w, ksize, stride, pad, true, nullptr,
-            nullptr, false);
-        assert(ret == 0);
+            nullptr, false));
         return;
     }
 };
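A quick worked example of the ceil-mode adjustment added in both pooling kernels: with h = 6, kh = 3, sh = 2 and ph = 0, the floor-mode output height is (6 + 0 - 3) / 2 + 1 = 2, while a ceil-mode output tensor has yh = 3; the kernel then sets padh = 3 - 2 = 1 and passes the asymmetric pad vector {0, 1, 0, 0} so that (6 + 0 + 1 - 3) / 2 + 1 = 3 matches the expected output before handing the call to xdnn.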
@@ -1,8 +1,9 @@
+#include "operators/reduce.h"
 #include "kunlun/kunlun_kernel_without_config.h"
 #include "kunlun/kunlun_runtime.h"
-#include "operators/reduce.h"

 namespace infini {

 class ReduceMeanXdnn : public KUNLUNKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {

@@ -26,6 +27,31 @@ class ReduceMeanXdnn : public KUNLUNKernelWithoutConfig {
     }
 };

+class ReduceSumXdnn : public KUNLUNKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<ReduceSumObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
+        auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
+
+        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const cData = (op->getOutput()->getRawDataPtr<void *>());
+
+        auto axes_set = op->getAxes();
+        std::vector<int> axes;
+        axes.assign(axes_set.begin(), axes_set.end());
+        auto shape = op->getInputs(0)->getDims();
+
+        auto ret = baidu::xpu::api::reduce_sum<float>(
+            context->KUNLUNHandle(), (float *)aData, (float *)cData, shape,
+            axes);
+        assert(ret == 0);
+        return;
+    }
+};
+
 REGISTER_KERNEL(Device::KUNLUN, OpType::ReduceMean, ReduceMeanXdnn,
                 "ReduceMean_xdnn_KUNLUN");
+REGISTER_KERNEL(Device::KUNLUN, OpType::ReduceSum, ReduceSumXdnn,
+                "ReduceSum_xdnn_KUNLUN");
 }; // namespace infini
@@ -1,32 +0,0 @@
-#include "kunlun/kunlun_kernel_without_config.h"
-#include "kunlun/kunlun_runtime.h"
-#include "operators/where.h"
-
-namespace infini {
-class WhereXdnn : public KUNLUNKernelWithoutConfig {
-    void compute(const Operator &_op,
-                 const RuntimeObj *_context) const override {
-        auto op = as<WhereObj>(_op);
-        IT_ASSERT(op->getDType() == DataType::Float32);
-        auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
-
-        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
-        void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
-        void *const cData = (op->getInputs(2)->getRawDataPtr<void *>());
-        void *const dData = (op->getOutput()->getRawDataPtr<void *>());
-
-        auto aDim = op->getInputs(0)->getDims();
-        auto bDim = op->getInputs(1)->getDims();
-        auto cDim = op->getInputs(2)->getDims();
-        auto dDim = op->getOutput()->getDims();
-
-        auto ret = baidu::xpu::api::select<float>(
-            context->KUNLUNHandle(), (bool *)cData, (float *)aData,
-            (float *)bData, (float *)dData, cDim, aDim);
-        assert(ret == 0);
-        return;
-    }
-};
-
-REGISTER_KERNEL(Device::KUNLUN, OpType::Where, WhereXdnn, "Where_xdnn_KUNLUN");
-}; // namespace infini
@@ -0,0 +1,39 @@
+#include "operators/slice.h"
+#include "kunlun/kunlun_kernel_without_config.h"
+#include "kunlun/kunlun_runtime.h"
+
+namespace infini {
+class SliceXdnn : public KUNLUNKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<SliceObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
+        auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
+
+        void *inData = op->getInputs(0)->getRawDataPtr<void *>();
+        void *outData = op->getOutput()->getRawDataPtr<void *>();
+
+        // Get attributes of Slice OP
+        Shape starts = op->getStarts(), ends = op->getEnds(),
+              steps = op->getSteps();
+        Shape inShape = op->getInputs(0)->getDims();
+        // If all steps are 1, set continuous True
+        bool continuous =
+            (size_t)std::count(steps.begin(), steps.end(), 1) == steps.size();
+        if (continuous) {
+            // if continuous, call xdnn::slice
+            checkKUNLUNError(
+                xdnn::slice<float>(context->KUNLUNHandle(), (float *)inData,
+                                   (float *)outData, inShape, starts, ends));
+
+        } else {
+            // else call xdnn::strided_slice
+            checkKUNLUNError(xdnn::strided_slice<float>(
+                context->KUNLUNHandle(), (float *)inData, (float *)outData,
+                inShape, starts, ends, steps));
+        }
+    }
+};
+
+REGISTER_KERNEL(Device::KUNLUN, OpType::Slice, SliceXdnn, "Slice_xdnn_KUNLUN")
+}; // namespace infini
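The dispatch in the new Slice kernel hinges on whether every step equals 1. A tiny illustrative check with the same meaning (not the kernel's exact code):

    #include <algorithm>
    #include <vector>
    // "Continuous" slice: every step is 1, so the plain slice path is enough.
    bool isContinuous(const std::vector<int> &steps) {
        return std::all_of(steps.begin(), steps.end(),
                           [](int s) { return s == 1; });
    }
    // isContinuous({1, 1, 1}) == true  -> xdnn::slice
    // isContinuous({1, 2, 1}) == false -> xdnn::strided_slice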
@@ -15,9 +15,9 @@ class SoftmaxXdnn : public KUNLUNKernelWithoutConfig {
         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());

-        auto ret = baidu::xpu::api::softmax<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, dim, axis);
-        assert(ret == 0);
+        checkKUNLUNError(xdnn::softmax<float>(context->KUNLUNHandle(),
+                                              (float *)aData, (float *)cData,
+                                              dim, axis));
         return;
     }
 };
@@ -16,13 +16,9 @@ class TransposeXdnn : public KUNLUNKernelWithoutConfig {
         auto dimin = op->getInputs(0)->getDims();
         auto permute = op->getPermute();

-        if (dimin.size() != 4) {
-            IT_TODO_HALT();
-        }
-
-        auto ret = baidu::xpu::api::transpose<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, dimin,
-            permute);
+        auto ret =
+            xdnn::transpose<float>(context->KUNLUNHandle(), (float *)aData,
+                                   (float *)cData, dimin, permute);
         assert(ret == 0);
         return;
     }

@@ -46,9 +42,9 @@ class DepthToSpaceXdnn : public KUNLUNKernelWithoutConfig {
         } else {
             permute = {0, 1, 4, 2, 5, 3};
         }
-        auto ret = baidu::xpu::api::transpose<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, reshape,
-            permute);
+        auto ret =
+            xdnn::transpose<float>(context->KUNLUNHandle(), (float *)aData,
+                                   (float *)cData, reshape, permute);
         assert(ret == 0);
         return;
     }
@@ -14,8 +14,8 @@ class ReluXdnn : public KUNLUNKernelWithoutConfig {
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();

-        auto ret = baidu::xpu::api::relu<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::relu<float>(context->KUNLUNHandle(), (float *)aData,
+                                     (float *)cData, len);
         assert(ret == 0);
         return;
     }

@@ -32,8 +32,8 @@ class SigmoidXdnn : public KUNLUNKernelWithoutConfig {
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();

-        auto ret = baidu::xpu::api::sigmoid<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::sigmoid<float>(context->KUNLUNHandle(), (float *)aData,
+                                        (float *)cData, len);
         assert(ret == 0);
         return;
     }

@@ -50,8 +50,45 @@ class TanhXdnn : public KUNLUNKernelWithoutConfig {
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();

-        auto ret = baidu::xpu::api::tanh<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::tanh<float>(context->KUNLUNHandle(), (float *)aData,
+                                     (float *)cData, len);
+        assert(ret == 0);
+        return;
+    }
+};
+
+class HardSwishXdnn : public KUNLUNKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<UnaryObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
+        auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
+
+        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const cData = (op->getOutput()->getRawDataPtr<void *>());
+        auto len = op->getInputs(0)->size();
+
+        auto ret = xdnn::hard_swish<float>(context->KUNLUNHandle(),
+                                           (float *)aData, (float *)cData, len);
+        assert(ret == 0);
+        return;
+    }
+};
+
+class HardSigmoidXdnn : public KUNLUNKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<UnaryObj>(_op);
+        IT_ASSERT(op->getDType() == DataType::Float32);
+        auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
+
+        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const cData = (op->getOutput()->getRawDataPtr<void *>());
+        auto len = op->getInputs(0)->size();
+
+        // Slop set to 0.2 as default
+        auto ret = xdnn::hard_sigmoid<float>(
+            context->KUNLUNHandle(), (float *)aData, (float *)cData, len, 0.2);
         assert(ret == 0);
         return;
     }
@@ -68,8 +105,8 @@ class SquareXdnn : public KUNLUNKernelWithoutConfig {
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();

-        auto ret = baidu::xpu::api::square<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::square<float>(context->KUNLUNHandle(), (float *)aData,
+                                       (float *)cData, len);
         assert(ret == 0);
         return;
     }

@@ -86,8 +123,8 @@ class SqrtXdnn : public KUNLUNKernelWithoutConfig {
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();

-        auto ret = baidu::xpu::api::sqrt<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::sqrt<float>(context->KUNLUNHandle(), (float *)aData,
+                                     (float *)cData, len);
         assert(ret == 0);
         return;
     }

@@ -104,8 +141,8 @@ class RsqrtXdnn : public KUNLUNKernelWithoutConfig {
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();

-        auto ret = baidu::xpu::api::rsqrt<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::rsqrt<float>(context->KUNLUNHandle(), (float *)aData,
+                                      (float *)cData, len);
         assert(ret == 0);
         return;
     }

@@ -122,8 +159,8 @@ class ExpXdnn : public KUNLUNKernelWithoutConfig {
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();

-        auto ret = baidu::xpu::api::exp<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::exp<float>(context->KUNLUNHandle(), (float *)aData,
+                                    (float *)cData, len);
         assert(ret == 0);
         return;
     }

@@ -140,8 +177,8 @@ class CeilXdnn : public KUNLUNKernelWithoutConfig {
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();

-        auto ret = baidu::xpu::api::ceil<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::ceil<float>(context->KUNLUNHandle(), (float *)aData,
+                                     (float *)cData, len);
         assert(ret == 0);
         return;
     }

@@ -160,9 +197,8 @@ class ClipXdnn : public KUNLUNKernelWithoutConfig {
         float min = op->getMin().value();
         float max = op->getMax().value();

-        auto ret = baidu::xpu::api::clip<float>(context->KUNLUNHandle(),
-                                                (float *)aData, (float *)cData,
-                                                len, min, max);
+        auto ret = xdnn::clip<float>(context->KUNLUNHandle(), (float *)aData,
+                                     (float *)cData, len, min, max);
         assert(ret == 0);
         return;
     }

@@ -179,8 +215,8 @@ class FloorXdnn : public KUNLUNKernelWithoutConfig {
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();

-        auto ret = baidu::xpu::api::floor<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::floor<float>(context->KUNLUNHandle(), (float *)aData,
+                                      (float *)cData, len);
         assert(ret == 0);
         return;
     }

@@ -197,8 +233,8 @@ class NegXdnn : public KUNLUNKernelWithoutConfig {
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();

-        auto ret = baidu::xpu::api::neg<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::neg<float>(context->KUNLUNHandle(), (float *)aData,
+                                    (float *)cData, len);
         assert(ret == 0);
         return;
     }

@@ -214,8 +250,8 @@ class CopyXdnn : public KUNLUNKernelWithoutConfig {
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();

-        auto ret = baidu::xpu::api::copy<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::copy<float>(context->KUNLUNHandle(), (float *)aData,
+                                     (float *)cData, len);
         assert(ret == 0);
         return;
     }

@@ -232,8 +268,8 @@ class ReciprocalXdnn : public KUNLUNKernelWithoutConfig {
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();

-        auto ret = baidu::xpu::api::reciprocal<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::reciprocal<float>(context->KUNLUNHandle(),
+                                           (float *)aData, (float *)cData, len);
         assert(ret == 0);
         return;
     }

@@ -250,8 +286,8 @@ class AbsXdnn : public KUNLUNKernelWithoutConfig {
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();

-        auto ret = baidu::xpu::api::abs<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::abs<float>(context->KUNLUNHandle(), (float *)aData,
+                                    (float *)cData, len);
         assert(ret == 0);
         return;
     }

@@ -268,8 +304,8 @@ class ATanXdnn : public KUNLUNKernelWithoutConfig {
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();

-        auto ret = baidu::xpu::api::arctan<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::arctan<float>(context->KUNLUNHandle(), (float *)aData,
+                                       (float *)cData, len);
         assert(ret == 0);
         return;
     }
@@ -288,36 +324,36 @@ class LogXdnn : public KUNLUNKernelWithoutConfig {
             1,
         };
         auto len = op->getInputs(0)->size();
+        auto dtype = op->getDType();
         // get ptr of tempspace
-        KUNLUNPtr temp = context->getWorkspace(len * sizeof(float));
+        KUNLUNPtr temp = context->getWorkspace(len * dtype.getSize());
         LogObj::LogType type = op->getType();
         // get output of xpu::api::loge(x)
-        auto ret = baidu::xpu::api::log<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)temp, len);
+        auto ret = xdnn::log<float>(context->KUNLUNHandle(), (float *)aData,
+                                    (float *)temp, len);
         // get ptr of divider
-        KUNLUNPtr dd =
-            (float *)(context->getWorkspace((1 + len) * sizeof(float))) + len;
+        KUNLUNPtr dd = context->getWorkspace(1 * dtype.getSize());
         // choose from logE, log2, log10
         switch (type) {
             float constant;
         case LogObj::LogE:
             // if use loge, copy from temp to cData
-            ret = baidu::xpu::api::copy<float>(
-                context->KUNLUNHandle(), (float *)temp, (float *)cData, len);
+            ret = xdnn::copy<float>(context->KUNLUNHandle(), (float *)temp,
+                                    (float *)cData, len);
             break;
         case LogObj::Log2:
             constant = std::log(2);
             context->copyBlobFromCPU(dd, &constant, sizeof(float));
-            ret = baidu::xpu::api::broadcast_div<float>(
-                context->KUNLUNHandle(), (float *)temp, (float *)dd,
-                (float *)cData, aDim, divDim);
+            ret = xdnn::broadcast_div<float>(context->KUNLUNHandle(),
+                                             (float *)temp, (float *)dd,
+                                             (float *)cData, aDim, divDim);
             break;
         case LogObj::Log10:
             constant = std::log(10);
             context->copyBlobFromCPU(dd, &constant, sizeof(float));
-            ret = baidu::xpu::api::broadcast_div<float>(
-                context->KUNLUNHandle(), (float *)temp, (float *)dd,
-                (float *)cData, aDim, divDim);
+            ret = xdnn::broadcast_div<float>(context->KUNLUNHandle(),
+                                             (float *)temp, (float *)dd,
+                                             (float *)cData, aDim, divDim);
             break;
         default:
             printf("LogType not support!");
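The Log kernel above relies on the change-of-base identity log_b(x) = ln(x) / ln(b): it first computes ln(x) into the temp workspace, then copies the single divisor ln(2) or ln(10) to the device and does a broadcast divide. For example, log2(8) = ln(8) / ln(2) ≈ 2.0794 / 0.6931 = 3.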
@@ -337,8 +373,8 @@ class CosXdnn : public KUNLUNKernelWithoutConfig {
         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();
-        auto ret = baidu::xpu::api::cos<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::cos<float>(context->KUNLUNHandle(), (float *)aData,
+                                    (float *)cData, len);

         assert(ret == 0);
         return;

@@ -354,8 +390,8 @@ class SinXdnn : public KUNLUNKernelWithoutConfig {
         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();
-        auto ret = baidu::xpu::api::sin<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::sin<float>(context->KUNLUNHandle(), (float *)aData,
+                                    (float *)cData, len);

         assert(ret == 0);
         return;

@@ -371,8 +407,8 @@ class TanXdnn : public KUNLUNKernelWithoutConfig {
         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();
-        auto ret = baidu::xpu::api::tan<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::tan<float>(context->KUNLUNHandle(), (float *)aData,
+                                    (float *)cData, len);

         assert(ret == 0);
         return;

@@ -388,8 +424,8 @@ class SinhXdnn : public KUNLUNKernelWithoutConfig {
         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();
-        auto ret = baidu::xpu::api::sinh<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::sinh<float>(context->KUNLUNHandle(), (float *)aData,
+                                     (float *)cData, len);

         assert(ret == 0);
         return;

@@ -405,8 +441,8 @@ class CoshXdnn : public KUNLUNKernelWithoutConfig {
         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();
-        auto ret = baidu::xpu::api::cosh<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::cosh<float>(context->KUNLUNHandle(), (float *)aData,
+                                     (float *)cData, len);

         assert(ret == 0);
         return;

@@ -422,8 +458,8 @@ class ErfXdnn : public KUNLUNKernelWithoutConfig {
         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();
-        auto ret = baidu::xpu::api::erf<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::erf<float>(context->KUNLUNHandle(), (float *)aData,
+                                    (float *)cData, len);

         assert(ret == 0);
         return;

@@ -439,8 +475,8 @@ class ACosXdnn : public KUNLUNKernelWithoutConfig {
         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();
-        auto ret = baidu::xpu::api::arccos<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::arccos<float>(context->KUNLUNHandle(), (float *)aData,
+                                       (float *)cData, len);

         assert(ret == 0);
         return;

@@ -456,8 +492,8 @@ class ACoshXdnn : public KUNLUNKernelWithoutConfig {
         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();
-        auto ret = baidu::xpu::api::acosh<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::acosh<float>(context->KUNLUNHandle(), (float *)aData,
+                                      (float *)cData, len);

         assert(ret == 0);
         return;

@@ -473,8 +509,8 @@ class ASinXdnn : public KUNLUNKernelWithoutConfig {
         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();
-        auto ret = baidu::xpu::api::arcsin<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::arcsin<float>(context->KUNLUNHandle(), (float *)aData,
+                                       (float *)cData, len);

         assert(ret == 0);
         return;

@@ -490,8 +526,8 @@ class ASinhXdnn : public KUNLUNKernelWithoutConfig {
         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();
-        auto ret = baidu::xpu::api::asinh<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::asinh<float>(context->KUNLUNHandle(), (float *)aData,
+                                      (float *)cData, len);

         assert(ret == 0);
         return;

@@ -507,8 +543,8 @@ class ATanhXdnn : public KUNLUNKernelWithoutConfig {
         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         auto len = op->getInputs(0)->size();
-        auto ret = baidu::xpu::api::atanh<float>(
-            context->KUNLUNHandle(), (float *)aData, (float *)cData, len);
+        auto ret = xdnn::atanh<float>(context->KUNLUNHandle(), (float *)aData,
+                                      (float *)cData, len);

         assert(ret == 0);
         return;
@ -546,7 +582,10 @@ REGISTER_KERNEL(Device::KUNLUN, OpType::Erf, ErfXdnn, "Erf_xdnn");
|
||||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Acos, ACosXdnn, "ACos_xdnn");
|
REGISTER_KERNEL(Device::KUNLUN, OpType::Acos, ACosXdnn, "ACos_xdnn");
|
||||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Acosh, ACoshXdnn, "ACosh_xdnn");
|
REGISTER_KERNEL(Device::KUNLUN, OpType::Acosh, ACoshXdnn, "ACosh_xdnn");
|
||||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Asin, ASinXdnn, "ASin_xdnn");
|
REGISTER_KERNEL(Device::KUNLUN, OpType::Asin, ASinXdnn, "ASin_xdnn");
|
||||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Asinh, ASinhXdnn,
|
REGISTER_KERNEL(Device::KUNLUN, OpType::Asinh, ASinhXdnn, "ASinh_xdnn");
|
||||||
"ASinh_xdnn_Float3 2");
|
|
||||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Atanh, ATanhXdnn, "ATanh_xdnn");
|
REGISTER_KERNEL(Device::KUNLUN, OpType::Atanh, ATanhXdnn, "ATanh_xdnn");
|
||||||
|
REGISTER_KERNEL(Device::KUNLUN, OpType::HardSwish, HardSwishXdnn,
|
||||||
|
"HardSwish_xdnn");
|
||||||
|
REGISTER_KERNEL(Device::KUNLUN, OpType::HardSigmoid, HardSigmoidXdnn,
|
||||||
|
"HardSigmoid_xdnn");
|
||||||
}; // namespace infini
|
}; // namespace infini
|
||||||
|
|
|
@@ -0,0 +1,67 @@
#pragma GCC diagnostic ignored "-Wunused-variable"
#include "operators/where.h"
#include "kunlun/kunlun_kernel_without_config.h"
#include "kunlun/kunlun_runtime.h"
#include "utils/operator_utils.h"

namespace infini {

class WhereXdnn : public KUNLUNKernelWithoutConfig {
    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        auto op = as<WhereObj>(_op);
        IT_ASSERT(op->getDType() == DataType::Float32);
        auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);

        void *const aData =
            (op->getInputs(0)->getRawDataPtr<void *>()); // inputX
        void *const bData =
            (op->getInputs(1)->getRawDataPtr<void *>()); // inputY
        void *const cData =
            (op->getInputs(2)->getRawDataPtr<void *>()); // condition
        void *const dData =
            (op->getOutput()->getRawDataPtr<void *>()); // output

        auto aDim = op->getInputs(0)->getDims(); // dimX
        auto bDim = op->getInputs(1)->getDims(); // dimY
        auto cDim = op->getInputs(2)->getDims(); // dimCondition
        auto dDim = op->getOutput()->getDims();  // dimOutput

        auto dtype = op->getDType();

        if (aDim != bDim) {
            // Infer broadcast for X and Y
            Shape XYDim = infer_broadcast(aDim, bDim);
            int XYSize = std::accumulate(XYDim.begin(), XYDim.end(), 1,
                                         std::multiplies<int>());
            // Align rank for XYDim and aDim or bDim
            broadcastShape(XYDim, aDim);
            broadcastShape(XYDim, bDim);
            // Get workspace
            void *wkspace = context->getWorkspace(XYSize * dtype.getSize());
            // Broadcast X Y
            checkKUNLUNError(xdnn::broadcast<float>(
                context->KUNLUNHandle(),
                (float *)(XYDim == aDim ? bData : aData), (float *)wkspace,
                (XYDim == aDim ? bDim : aDim), XYDim));
            // Align Rank
            broadcastShape(dDim, XYDim);
            broadcastShape(dDim, XYDim);
            // Where
            void *XData = XYDim == aDim ? aData : wkspace;
            void *YData = XYDim == bDim ? bData : wkspace;
            checkKUNLUNError(xdnn::select<float>(
                context->KUNLUNHandle(), (bool *)cData, (float *)XData,
                (float *)YData, (float *)dData, cDim, XYDim));
        } else {
            checkKUNLUNError(xdnn::select<float>(
                context->KUNLUNHandle(), (bool *)cData, (float *)aData,
                (float *)bData, (float *)dData, cDim, aDim));
        }

        return;
    }
};

REGISTER_KERNEL(Device::KUNLUN, OpType::Where, WhereXdnn, "Where_xdnn_KUNLUN");
}; // namespace infini
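
A brief note on the broadcast-then-select flow in the kernel above, with hypothetical shapes (assumed for illustration, not taken from the commit) to make the workspace path concrete:

    // Illustration only; shapes are assumed, not from the test suite.
    // aDim = {2, 1, 3}, bDim = {3}
    //   infer_broadcast(aDim, bDim) -> XYDim = {2, 1, 3}, XYSize = 6
    //   broadcastShape(XYDim, bDim) aligns bDim to {1, 1, 3}
    // XYDim equals aDim, so only bData is expanded into the workspace with
    // xdnn::broadcast; xdnn::select then reads X from aData and Y from the
    // workspace, writing the element-wise result to dData.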
@@ -19,6 +19,7 @@ void KUNLUNRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
         auto perfData = perfEngine.getPerfData(perfKey);
         if (!perfData && !tune) {
             kernel->compute(op, this);
+            workspace->resetWorkspace();
             continue;
         }

@@ -52,8 +53,20 @@ void KUNLUNRuntimeObj::run(const Graph &graph, bool tune,
     sync();
 }
 
-void KUNLUNRuntimeObj::sync() const { ; }
+void KUNLUNRuntimeObj::sync() const { xpu_wait(); }
 
 string KUNLUNRuntimeObj::toString() const { return "KUNLUN Runtime"; }
 
+void KUNLUNRuntimeObj::initComm(const string &name, int worldSize, int rank) {
+    IT_ASSERT(worldSize > 0);
+    IT_ASSERT(rank >= 0);
+    IT_ASSERT(rank < worldSize);
+    IT_ASSERT(!comm) << "communicator is already initialized.";
+#ifdef INFINI_USE_XCCL
+    comm = std::make_unique<XcclCommunicatorObj>(name, worldSize, rank);
+#else
+    IT_TODO_HALT_MSG("Not compiled with XCCL");
+#endif
+}
+
 } // namespace infini
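
A minimal usage sketch of the new initComm path, mirroring the XCCL tests added later in this commit; the task name and world size below are placeholders:

    // One runtime per rank; every rank joins the same named communicator.
    int rank = 0;                                    // placeholder rank
    Runtime kunlunRuntime = make_ref<KUNLUNRuntimeObj>(rank);
    kunlunRuntime->initComm("demo_task", /*worldSize=*/2, rank);
    // Graphs run on this runtime may now contain XCCL collectives
    // (AllReduce / AllGather / Broadcast). Without INFINI_USE_XCCL the call
    // halts with "Not compiled with XCCL".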
@@ -25,7 +25,7 @@ optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) {
     auto A = inputs[0], B = inputs[1];
     auto shapeA = A->getDims();
     auto shapeB = B->getDims();
-    int rankA = A->getRank();
+    int rankA = A->getRank(); // Rank is the Shape of TensorDims
     int rankB = B->getRank();
     Shape shapeA1(shapeA.begin(), shapeA.begin() + (rankA - 2));
     Shape shapeB1(shapeB.begin(), shapeB.begin() + (rankB - 2));
@@ -231,6 +231,8 @@ DataType CastObj::getOutputDataType() const {
         return DataType::Float32;
     case CastType::Float2BFloat16:
         return DataType::BFloat16;
+    case CastType::Float2Float:
+        return DataType::Float32;
     default:
         IT_TODO_HALT();
     }
@@ -114,4 +114,28 @@ std::string get_kernel_attrs_str(const KernelAttrs &kernelAttrs) {
     std::string opStr = OpType(std::get<1>(kernelAttrs)).toString();
     return deviceStr + ", " + opStr;
 }
 
+int shapeProd(std::vector<int>::iterator start,
+              std::vector<int>::iterator end) {
+    return std::accumulate(start, end, 1, std::multiplies<int>());
+}
+
+void broadcastShape(const Shape &originShape, SmallArray &modifyShape,
+                    int nDims, int size) {
+    for (int i = nDims - size - 1; i >= 0; --i) {
+        modifyShape.data[i] = 1;
+    }
+    for (int i = nDims - 1; i >= nDims - size; --i) {
+        modifyShape.data[i] = originShape[i - nDims + size];
+    }
+}
+
+void broadcastShape(const Shape &tempShape, Shape &modifyShape) {
+    // Align Rank, Add 1 in the start of smallShape
+    IT_ASSERT(tempShape.size() >= modifyShape.size());
+    modifyShape.insert(modifyShape.begin(),
+                       tempShape.size() - modifyShape.size(), 1);
+    return;
+}
+
 } // namespace infini
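
A small sketch of how the Shape overload of broadcastShape behaves (the shapes here are assumed, for illustration): it only pads the lower-rank shape with leading 1s, which is how the Where kernel aligns ranks before calling xdnn::broadcast.

    // Illustration with assumed shapes.
    Shape target = {2, 1, 3, 4};
    Shape small = {3, 4};
    broadcastShape(target, small);                  // small becomes {1, 1, 3, 4}
    int n = shapeProd(small.begin(), small.end());  // n == 12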
@@ -0,0 +1,50 @@
#ifdef INFINI_USE_XCCL
#include "core/graph.h"
#include "core/runtime.h"
#include "kunlun/kunlun_runtime.h"
#include "operators/all_gather.h"
#include "test.h"
#include "xpu/bkcl.h"
#include <thread>

static int WORLD_SIZE = 2;

namespace infini {

void allGather(const string taskName, int deviceID, vector<float> data,
               vector<vector<float>> ans) {
    // Create Runtimes and initiate communication
    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
    Runtime kunlunRuntime = make_ref<KUNLUNRuntimeObj>(deviceID);
    kunlunRuntime->initComm(taskName, WORLD_SIZE, deviceID);
    // Create Graph and insert allGather operation
    Graph g = make_ref<GraphObj>(kunlunRuntime);
    auto input =
        g->addTensor(Shape{static_cast<int>(data.size())}, DataType::Float32);
    auto op = g->addOp<AllGatherObj>(input, std::nullopt, WORLD_SIZE);
    // Copy data from CPU to GPU
    g->dataMalloc();
    input->copyin(data);
    // Run operation
    kunlunRuntime->run(g);
    // Copy output from GPU to CPU
    for (int i = 0; i < WORLD_SIZE; ++i) {
        auto result = op->getOutputs()[i]->clone(cpuRuntime);
        EXPECT_TRUE(result->equalData(ans[i]));
    }
}

TEST(KUNLUN_AllGather, run) {
    vector<float> data[2] = {{2., 3.}, {5., 6.}};
    vector<vector<float>> ans = {{2., 3.}, {5., 6.}};

    std::vector<std::thread> threads;
    for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) {
        threads.emplace_back(allGather, "test_all_gather", gpu, data[gpu], ans);
    }
    for (auto &thread : threads) {
        thread.join();
    }
}
} // namespace infini
#endif
@@ -0,0 +1,72 @@
#ifdef INFINI_USE_XCCL
#include "core/graph.h"
#include "core/runtime.h"
#include "kunlun/kunlun_runtime.h"
#include "operators/all_reduce.h"
#include "test.h"
#include "xpu/bkcl.h"
#include <thread>

static int WORLD_SIZE = 2;

using namespace infini;

template <typename OperatorObj>
void allReduce(const string taskName, int deviceID, vector<float> data,
               vector<float> ans) {
    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
    Runtime kunlunRuntime = make_ref<KUNLUNRuntimeObj>(deviceID);
    kunlunRuntime->initComm(taskName, WORLD_SIZE, deviceID);
    Graph g = make_ref<GraphObj>(kunlunRuntime);
    auto input =
        g->addTensor(Shape{static_cast<int>(data.size())}, DataType::Float32);
    auto op = g->addOp<OperatorObj>(input, nullptr);
    g->dataMalloc();
    input->copyin(data);
    kunlunRuntime->run(g);
    auto result = op->getOutput()->clone(cpuRuntime);

    EXPECT_TRUE(result->equalData(ans));
}

TEST(KUNLUN_AllReduce, sum) {
    vector<float> data[2] = {{2., 3.}, {5., 6.}};
    vector<float> ans = {7., 9.};
    std::vector<std::thread> threads;
    for (int rank = 0; rank < WORLD_SIZE; ++rank) {
        threads.emplace_back(allReduce<AllReduceSumObj>, "test_allreduce_sum",
                             rank, data[rank], ans);
    }
    for (auto &thread : threads) {
        thread.join();
    }
}

TEST(KUNLUN_AllReduce, max) {
    vector<float> data[2] = {{2., 3.}, {5., 6.}};
    vector<float> ans = {5., 6.};

    std::vector<std::thread> threads;
    for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) {
        threads.emplace_back(allReduce<AllReduceMaxObj>, "test_allreduce_max",
                             gpu, data[gpu], ans);
    }
    for (auto &thread : threads) {
        thread.join();
    }
}

TEST(KUNLUN_AllReduce, min) {
    vector<float> data[2] = {{2., 3.}, {5., 6.}};
    vector<float> ans = {2., 3.};

    std::vector<std::thread> threads;
    for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) {
        threads.emplace_back(allReduce<AllReduceMinObj>, "test_allreduce_min",
                             gpu, data[gpu], ans);
    }
    for (auto &thread : threads) {
        thread.join();
    }
}
#endif
@@ -0,0 +1,56 @@
#ifdef INFINI_USE_XCCL
#include "core/graph.h"
#include "core/runtime.h"
#include "kunlun/kunlun_runtime.h"
#include "operators/broadcast.h"
#include "test.h"
#include <thread>
#include <xpu/bkcl.h>

static int WORLD_SIZE = 2;
static int root = 0;

namespace infini {

void broadcast(const string taskName, int deviceID, vector<float> data,
               vector<float> ans) {
    // Create Runtimes and initiate communication
    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
    Runtime kunlunRuntime = make_ref<KUNLUNRuntimeObj>(deviceID);
    kunlunRuntime->initComm(taskName, WORLD_SIZE, deviceID);
    // Create Graph and insert broadcast operation
    Graph g = make_ref<GraphObj>(kunlunRuntime);
    auto input =
        g->addTensor(Shape{static_cast<int>(data.size())}, DataType::Float32);
    auto op = g->addOp<BroadcastObj>(input, nullptr, root);
    // Copy data from CPU to GPU
    g->dataMalloc();
    // Only rank 0 has the data
    if (deviceID == root) {
        input->copyin(data);
    }
    // Run broadcast operation
    kunlunRuntime->run(g);
    // Copy output from GPU to CPU
    auto result = op->getOutput()->clone(cpuRuntime);

    EXPECT_TRUE(result->equalData(ans));
}

TEST(KUNLUN_Broadcast, run) {
    // Only 1 device gets data. Every rank should have the same data after
    // broadcast.
    vector<float> data = {2., 3., 5., 6.};
    vector<float> ans = {2., 3., 5., 6.};

    std::vector<std::thread> threads;
    for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) {
        threads.emplace_back(broadcast, "test_broadcast", gpu, data, ans);
    }
    for (auto &thread : threads) {
        thread.join();
    }
}
} // namespace infini

#endif
@@ -0,0 +1,144 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "kunlun/kunlun_runtime.h"
#include "operators/gather.h"

#include "test.h"

namespace infini {
/*
test1:
input = [
    [1, 2],
    [3, 4],
    [5, 6],
]
indices = [
    [0, 1],
    [1, 2],
]
output = [
    [
        [1, 2],
        [3, 4],
    ],
    [
        [3, 4],
        [5, 6],
    ],
]
axis=0
*/

/*
test2
input = [
    [0, 1, 2],
    [3, 4, 5],
    [6, 7, 8],
]
indices = [
    [0, 2],
]
axis = 1,
output = [
    [[0, 2]],
    [[3, 5]],
    [[6, 8]],
]
*/
/*
test3
input=[[[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7]],

       [[ 8,  9],
        [10, 11],
        [12, 13],
        [14, 15]]] //(2,4,2)
indices=[[0],[3],[1]] //(3,1)
axis=1
output=

*/

TEST(Gather, KUNLUN) {
    {
        Runtime runtime = NativeCpuRuntimeObj::getInstance();
        Graph gCpu = make_ref<GraphObj>(runtime);
        auto input = gCpu->addTensor({3, 2}, DataType::Float32);
        auto index = gCpu->addTensor({2, 2}, DataType::Int32);
        gCpu->dataMalloc();
        input->copyin(vector<float>{1, 2, 3, 4, 5, 6});
        index->copyin(vector<int>{0, 1, 1, 2});
        auto kunlunRuntime = make_ref<KUNLUNRuntimeObj>();
        Graph gCuda = make_ref<GraphObj>(kunlunRuntime);

        auto inputCuda = gCuda->cloneTensor(input);
        auto indexCuda = gCuda->cloneTensor(index);
        auto op = gCuda->addOp<GatherObj>(inputCuda, indexCuda, nullptr, 0);
        gCuda->dataMalloc();
        inputCuda->copyin(vector<float>{1, 2, 3, 4, 5, 6});
        indexCuda->copyin(vector<int>{0, 1, 1, 2});
        kunlunRuntime->run(gCuda);

        // cudaPrintTensor(op->getOutput());
        // copy output from CUDA to CPU
        auto oCpu = gCpu->cloneTensor(op->getOutput());
        EXPECT_TRUE(oCpu->equalData(vector<float>{1, 2, 3, 4, 3, 4, 5, 6}));
    }
    {
        Runtime runtime = NativeCpuRuntimeObj::getInstance();
        Graph gCpu = make_ref<GraphObj>(runtime);
        auto input = gCpu->addTensor({3, 3}, DataType::Float32);
        auto index = gCpu->addTensor({1, 2}, DataType::Int32);
        gCpu->dataMalloc();
        input->setData(IncrementalGenerator());
        index->copyin(vector<int>{0, 2});
        auto kunlunRuntime = make_ref<KUNLUNRuntimeObj>();
        Graph gCuda = make_ref<GraphObj>(kunlunRuntime);

        auto inputCuda = gCuda->cloneTensor(input);
        auto indexCuda = gCuda->cloneTensor(index);
        auto op = gCuda->addOp<GatherObj>(inputCuda, indexCuda, nullptr, 1);
        gCuda->dataMalloc();
        inputCuda->setData(IncrementalGenerator());
        indexCuda->copyin(vector<int>{0, 2});
        kunlunRuntime->run(gCuda);

        // cudaPrintTensor(op->getOutput());
        // copy output from CUDA to CPU
        auto oCpu = gCpu->cloneTensor(op->getOutput());
        EXPECT_TRUE(oCpu->equalData(vector<float>{0, 2, 3, 5, 6, 8}));
    }
    {
        Runtime runtime = NativeCpuRuntimeObj::getInstance();
        Graph gCpu = make_ref<GraphObj>(runtime);
        auto input = gCpu->addTensor({3, 2}, DataType::Float32);
        auto index = gCpu->addTensor({2, 2}, DataType::Int32);
        gCpu->dataMalloc();
        input->copyin(std::vector<float>{1.0, 1.2, 2.3, 3.4, 4.5, 5.7});
        index->copyin(std::vector<int>{0, 1, 1, 2});
        auto kunlunRuntime = make_ref<KUNLUNRuntimeObj>();
        Graph gCuda = make_ref<GraphObj>(kunlunRuntime);

        auto inputCuda = gCuda->cloneTensor(input);
        auto indexCuda = gCuda->cloneTensor(index);
        auto op = gCuda->addOp<GatherObj>(inputCuda, indexCuda, nullptr, 0);
        gCuda->dataMalloc();
        inputCuda->copyin(std::vector<float>{1.0, 1.2, 2.3, 3.4, 4.5, 5.7});
        indexCuda->copyin(std::vector<int>{0, 1, 1, 2});
        kunlunRuntime->run(gCuda);

        // cudaPrintTensor(op->getOutput());
        // copy output from CUDA to CPU
        auto oCpu = gCpu->cloneTensor(op->getOutput());
        EXPECT_TRUE(oCpu->equalData(
            std::vector<float>{1.0, 1.2, 2.3, 3.4, 2.3, 3.4, 4.5, 5.7}));
    }
}

} // namespace infini
@@ -7,52 +7,112 @@
 #include "test.h"
 
 namespace infini {
+using ExpectOutput = vector<float>;
 
-template <class T>
-void testMatmul(const std::function<void(void *, size_t, DataType)> &generatorA,
-                const std::function<void(void *, size_t, DataType)> &generatorB,
-                bool transA, bool transB, const Shape &shapeA,
-                const Shape &shapeB) {
-    // Runtime
-    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
-    auto xpuRuntime = make_ref<KUNLUNRuntimeObj>();
+void testMatmulKUNLUNWithBias(
+    const std::function<void(void *, size_t, DataType)> &generatorA,
+    const std::function<void(void *, size_t, DataType)> &generatorB,
+    const std::function<void(void *, size_t, DataType)> &generatorBias,
+    bool transA, bool transB, const Shape &shapeA, const Shape &shapeB,
+    const Shape &shapeBias, const ExpectOutput &ansVec) {
+    auto cpuRuntime = NativeCpuRuntimeObj::getInstance();
+    Graph gCpu = make_ref<GraphObj>(cpuRuntime);
+    auto ACpu = gCpu->addTensor(shapeA, DataType::Float32);
+    auto BCpu = gCpu->addTensor(shapeB, DataType::Float32);
+    auto BiasCpu = gCpu->addTensor(shapeBias, DataType::Float32);
+    gCpu->dataMalloc();
+    ACpu->setData(generatorA);
+    BCpu->setData(generatorB);
+    BiasCpu->setData(generatorBias);
 
-    // Build input data on CPU
-    Tensor inputCpu1 =
-        make_ref<TensorObj>(shapeA, DataType::Float32, cpuRuntime);
-    Tensor inputCpu2 =
-        make_ref<TensorObj>(shapeB, DataType::Float32, cpuRuntime);
+    auto kunlunRuntime = make_ref<KUNLUNRuntimeObj>();
+    auto gKunlun = make_ref<GraphObj>(kunlunRuntime);
+    auto AKunlun = gKunlun->cloneTensor(ACpu);
+    auto BKunlun = gKunlun->cloneTensor(BCpu);
+    auto BiasKunlun = gKunlun->cloneTensor(BiasCpu);
+    auto matmul = gKunlun->addOp<MatmulObj>(AKunlun, BKunlun, nullptr, transA,
+                                            transB, BiasKunlun);
 
-    // MLU
-    Graph xpuGraph = make_ref<GraphObj>(xpuRuntime);
-    auto inputMlu1 = xpuGraph->cloneTensor(inputCpu1);
-    auto inputMlu2 = xpuGraph->cloneTensor(inputCpu2);
-    auto mluOp = xpuGraph->addOp<T>(inputMlu1, inputMlu2, nullptr);
-    xpuGraph->dataMalloc();
-    inputMlu1->setData(generatorA);
-    inputMlu2->setData(generatorB);
-    xpuRuntime->run(xpuGraph);
-    auto outputMlu = mluOp->getOutput();
-    auto outputMlu2Cpu = outputMlu->clone(cpuRuntime);
-    // CPU
-    Graph cpuGraph = make_ref<GraphObj>(cpuRuntime);
-    auto cpuOp = cpuGraph->addOp<T>(inputCpu1, inputCpu2, nullptr);
-    cpuGraph->addTensor(inputCpu1);
-    cpuGraph->addTensor(inputCpu2);
-    cpuGraph->dataMalloc();
-    inputCpu1->setData(generatorA);
-    inputCpu2->setData(generatorB);
-    cpuRuntime->run(cpuGraph);
-    auto outputCpu = cpuOp->getOutput();
-    outputCpu->print();
-    outputMlu2Cpu->print();
-    // Check
-    EXPECT_TRUE(outputCpu->equalData(outputMlu2Cpu));
+    // allocate Kunlun memory
+    gKunlun->dataMalloc();
+    AKunlun->setData(generatorA);
+    BKunlun->setData(generatorB);
+    BiasKunlun->setData(generatorBias);
+    kunlunRuntime->run(gKunlun);
+
+    auto CCpu = gCpu->cloneTensor(matmul->getOutput());
+    // CCpu->printData();
+    // check results on CPU
+    EXPECT_TRUE(CCpu->equalData(ansVec));
+    // print a tensor/operator/graph by print()
+    // gKunlun->print();
 }
 
-TEST(xpu_Matmul, run) {
-    testMatmul<MatmulObj>(IncrementalGenerator(), IncrementalGenerator(), false,
-                          false, Shape{2, 3}, Shape{3, 4});
+void testMatmulKUNLUN(
+    const std::function<void(void *, size_t, DataType)> &generatorA,
+    const std::function<void(void *, size_t, DataType)> &generatorB,
+    bool transA, bool transB, const Shape &shapeA, const Shape &shapeB,
+    const ExpectOutput &ansVec) {
+    auto cpuRuntime = NativeCpuRuntimeObj::getInstance();
+    Graph gCpu = make_ref<GraphObj>(cpuRuntime);
+    auto ACpu = gCpu->addTensor(shapeA, DataType::Float32);
+    auto BCpu = gCpu->addTensor(shapeB, DataType::Float32);
+    gCpu->dataMalloc();
+    ACpu->setData(generatorA);
+    BCpu->setData(generatorB);
+
+    auto kunlunRuntime = make_ref<KUNLUNRuntimeObj>();
+    auto gKunlun = make_ref<GraphObj>(kunlunRuntime);
+    auto AKunlun = gKunlun->cloneTensor(ACpu);
+    auto BKunlun = gKunlun->cloneTensor(BCpu);
+    auto matmul = gKunlun->addOp<MatmulObj>(AKunlun, BKunlun, nullptr, transA,
+                                            transB, nullptr);
+
+    // allocate Kunlun memory
+    gKunlun->dataMalloc();
+    AKunlun->setData(generatorA);
+    BKunlun->setData(generatorB);
+    kunlunRuntime->run(gKunlun);
+
+    auto CCpu = gCpu->cloneTensor(matmul->getOutput());
+    // CCpu->printData();
+    // check results on CPU
+    EXPECT_TRUE(CCpu->equalData(ansVec));
+    // print a tensor/operator/graph by print()
+    // gKunlun->print();
+}
+
+TEST(XDNN_Matmul, run) {
+    testMatmulKUNLUN(IncrementalGenerator(), OneGenerator(), false, false,
+                     Shape{1, 3, 5}, Shape{1, 5, 2},
+                     ExpectOutput{10, 10, 35, 35, 60, 60});
+    testMatmulKUNLUN(IncrementalGenerator(), IncrementalGenerator(), true,
+                     false, Shape{2, 3, 4}, Shape{2, 3, 2},
+                     ExpectOutput{40, 52, 46, 61, 52, 70, 58, 79, 400, 448, 424,
+                                  475, 448, 502, 472, 529});
+    testMatmulKUNLUN(
+        IncrementalGenerator(), IncrementalGenerator(), false, false,
+        Shape{2, 3, 5}, Shape{5, 2},
+        ExpectOutput{60, 70, 160, 195, 260, 320, 360, 445, 460, 570, 560, 695});
+    testMatmulKUNLUN(IncrementalGenerator(), IncrementalGenerator(), true,
+                     false, Shape{2, 5, 3}, Shape{5, 2},
+                     ExpectOutput{180, 210, 200, 235, 220, 260, 480, 585, 500,
+                                  610, 520, 635});
+    testMatmulKUNLUN(IncrementalGenerator(), IncrementalGenerator(), false,
+                     false, Shape{3, 5}, Shape{5, 2},
+                     ExpectOutput{60, 70, 160, 195, 260, 320});
+}
+
+TEST(XDNN_Matmul_With_Bias, run) {
+    testMatmulKUNLUNWithBias(IncrementalGenerator(), OneGenerator(),
+                             OneGenerator(), false, false, Shape{1, 3, 5},
+                             Shape{1, 5, 2}, Shape{2},
+                             ExpectOutput{11, 11, 36, 36, 61, 61});
+    testMatmulKUNLUNWithBias(IncrementalGenerator(), IncrementalGenerator(),
+                             OneGenerator(), true, false, Shape{2, 3, 4},
+                             Shape{2, 3, 2}, Shape{4, 2},
+                             ExpectOutput{41, 53, 47, 62, 53, 71, 59, 80, 401,
+                                          449, 425, 476, 449, 503, 473, 530});
 }
 
 } // namespace infini
@@ -0,0 +1,39 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "kunlun/kunlun_runtime.h"
#include "operators/slice.h"
#include "test.h"

namespace infini {
TEST(KUNLUN_Slice, run) {
    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
    auto kunlunRuntime = make_ref<KUNLUNRuntimeObj>();

    // Build input data on CPU
    Tensor icpu =
        make_ref<TensorObj>(Shape{3, 2, 1, 5}, DataType::Float32, cpuRuntime);
    icpu->dataMalloc();
    icpu->setData(IncrementalGenerator());

    // Build KUNLUN graph
    Graph g = make_ref<GraphObj>(kunlunRuntime);
    auto i = g->cloneTensor(icpu);
    auto op =
        g->addOp<SliceObj>(i, nullptr, vector<int>{1, 1}, vector<int>{2, 5},
                           vector<int>{0, 3}, std::nullopt);

    // allocate KUNLUN memory
    g->dataMalloc();
    i->setData(IncrementalGenerator());

    // Execute on KUNLUN
    kunlunRuntime->run(g);

    // clone KUNLUN output to CPU
    auto o = op->getOutput();
    auto cpuo = o->clone(cpuRuntime);
    // cudaPrintTensor(o);
    // check results on CPU
    EXPECT_TRUE(cpuo->equalData(vector<float>{11, 12, 13, 14, 16, 17, 18, 19}));
}
} // namespace infini
@@ -0,0 +1,77 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "kunlun/kunlun_runtime.h"
#include "operators/where.h"

#include "test.h"

namespace infini {

void test_where(const Shape &inputXShape, const vector<float> &inputXData,
                const Shape &inputYShape, const vector<float> &inputYData,
                const Shape &conditionShape,
                const vector<int8_t> &conditionData,
                const vector<float> &ExpectData) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);
    auto condition = gCpu->addTensor(conditionShape, DataType::Bool);
    auto inputX = gCpu->addTensor(inputXShape, DataType::Float32);
    auto inputY = gCpu->addTensor(inputYShape, DataType::Float32);

    gCpu->dataMalloc();
    condition->copyin(conditionData);
    inputX->copyin(inputXData);
    inputY->copyin(inputYData);

    auto kunlunRuntime = make_ref<KUNLUNRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(kunlunRuntime);

    auto conditionGpu = gCuda->cloneTensor(condition);
    auto inputXGpu = gCuda->cloneTensor(inputX);
    auto inputYGpu = gCuda->cloneTensor(inputY);

    auto op = gCuda->addOp<WhereObj>(inputXGpu, inputYGpu, conditionGpu,
                                     nullptr); // WhereObj
    gCuda->dataMalloc();
    conditionGpu->copyin(conditionData);
    inputXGpu->copyin(inputXData);
    inputYGpu->copyin(inputYData);
    kunlunRuntime->run(gCuda);

    auto oCpu = gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
    oCpu->printData();
    EXPECT_TRUE(oCpu->equalData(ExpectData));
}

TEST(KUNLUN_Where, run) {
    test_where(
        Shape{2, 2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
        Shape{2, 2, 3, 1}, vector<float>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
        Shape{2, 2, 3, 1}, vector<int8_t>{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1},
        vector<float>{0., 1., 2., 0., 0., 0., 6., 7., 0., 9., 10., 11.});

    test_where(Shape{2, 2, 1, 3}, // inputx
               vector<float>{0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5},
               Shape{2, 2, 1, 3}, // inputy
               vector<float>{1, 1, 3, 2, 5, 1, 5, 2, 3, 5, 6, 7},
               Shape{2, 2, 1, 3}, // condition
               vector<int8_t>{0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0},
               vector<float>{1, 1, 2, 2, 5, 1, 0, 2, 2, 3, 6, 7});

    test_where(Shape{2, 2, 1, 3},
               vector<float>{0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5}, // inputX
               Shape{2, 2, 1, 3},
               vector<float>{1, 1, 3, 2, 5, 1, 5, 2, 3, 5, 6, 7}, // inputY
               Shape{2, 1, 1, 3}, vector<int8_t>{1, 1, 0, 1, 1, 1}, // condition
               vector<float>{0, 1, 3, 3, 4, 1, 0, 1, 2, 3, 4, 5}); // result

    test_where(Shape{2, 2, 1, 3},
               vector<float>{0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5}, // inputX
               Shape{2, 2, 1, 3},
               vector<float>{1, 1, 3, 2, 5, 1, 5, 2, 3, 5, 6, 7}, // inputY
               Shape{2, 1, 1, 3},
               vector<int8_t>{1, 1, 0, 1, 1, 1}, // condition
               vector<float>{0, 1, 3, 3, 4, 1, 0, 1, 2, 3, 4, 5}); // result
}
} // namespace infini
@@ -0,0 +1,20 @@
#include "core/runtime.h"
#include "core/workspace.h"
#include "kunlun/kunlun_runtime.h"

#include "test.h"

namespace infini {
TEST(KunlunWorkspace, test) {
    Ref<KUNLUNRuntimeObj> kunlunRuntime = make_ref<KUNLUNRuntimeObj>();
    auto wkspace = kunlunRuntime->getWorkspaceObj();
    KUNLUNPtr space1 = kunlunRuntime->getWorkspace(1024 * 1024 * sizeof(float));
    IT_ASSERT(wkspace->getWorkspaceAlloc() == 1024 * 1024 * sizeof(float));
    KUNLUNPtr space2 = kunlunRuntime->getWorkspace(1024 * 1024 * sizeof(float));
    IT_ASSERT(wkspace->getWorkspaceAlloc() == 1024 * 1024 * sizeof(float) * 2);
    IT_ASSERT((void *)(static_cast<uint8_t *>(space1) +
                       1024 * 1024 * sizeof(float)) == (void *)space2);
    wkspace->resetWorkspace();
    IT_ASSERT(wkspace->getWorkspaceAlloc() == 0);
}
} // namespace infini