forked from jiuyuan/InfiniTensor
kunlun distributed
This commit is contained in:
parent
a71cd14963
commit
32a13b7760
|
@ -285,8 +285,8 @@ if(USE_KUNLUN)
|
||||||
message(STATUS "KUNLUN_HOME: ${KUNLUN_HOME}")
|
message(STATUS "KUNLUN_HOME: ${KUNLUN_HOME}")
|
||||||
|
|
||||||
include_directories("${KUNLUN_HOME}/include/")
|
include_directories("${KUNLUN_HOME}/include/")
|
||||||
find_library(KUNLUN_RT libxpurt.so "${KUNLUN_HOME}/so/")
|
find_library(KUNLUN_RT libxpurt.so "${KUNLUN_HOME}/lib64/")
|
||||||
find_library(KUNLUN_DNN libxpuapi.so "${KUNLUN_HOME}/so/")
|
find_library(KUNLUN_DNN libxpuapi.so "${KUNLUN_HOME}/lib64/")
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror")
|
||||||
|
|
||||||
if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH}))
|
if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH}))
|
||||||
|
|
|
@ -9,7 +9,7 @@ find_path(XCCL_INCLUDE_DIRS # ${XCCL_INCLUDE_DIR}
|
||||||
HINTS XCCL_INCLUDE_DIR)
|
HINTS XCCL_INCLUDE_DIR)
|
||||||
|
|
||||||
find_library(XCCL_LIBRARIES # ${XCCL_LIB_DIR}
|
find_library(XCCL_LIBRARIES # ${XCCL_LIB_DIR}
|
||||||
NAMES so/libbkcl.so
|
NAMES lib64/libbkcl.so
|
||||||
HINTS XCCL_LIB_DIR)
|
HINTS XCCL_LIB_DIR)
|
||||||
|
|
||||||
message(STATUS "XCCL_INCLUDE_DIRS: ${XCCL_INCLUDE_DIRS}")
|
message(STATUS "XCCL_INCLUDE_DIRS: ${XCCL_INCLUDE_DIRS}")
|
||||||
|
|
|
@ -0,0 +1,278 @@
|
||||||
|
import sys
|
||||||
|
sys.path.append('../')
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import multiprocessing as mp
|
||||||
|
from pyinfinitensor.onnx import OnnxStub, backend
|
||||||
|
import onnx
|
||||||
|
from onnx.external_data_helper import convert_model_to_external_data
|
||||||
|
from onnx.shape_inference import infer_shapes_path
|
||||||
|
import numpy as np
|
||||||
|
from parallel_opt import parallel_model
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="launch distributed infinitensor")
|
||||||
|
parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
|
||||||
|
parser.add_argument(
|
||||||
|
"--nproc_per_node", type=int, default=2, help="number of processes per node"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--name", type=str, default="test", help="name of this instance."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model", type=str, default="", help="path to the ONNX model file."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--gen_std",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="whether to generate the standard results.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--run_single",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="whether run model with single process with standard inputs"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--input_dir",
|
||||||
|
default="./",
|
||||||
|
help="path to save model input data"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--result_dir",
|
||||||
|
default="./",
|
||||||
|
help="path to save model standard output"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--internal_model_dir",
|
||||||
|
default="./",
|
||||||
|
help="path to save internal onnx model for parallel run"
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# check path, mkdir if not exist
|
||||||
|
check_exists(args.input_dir)
|
||||||
|
check_exists(args.result_dir)
|
||||||
|
check_exists(args.internal_model_dir)
|
||||||
|
|
||||||
|
print("arg setting: ", args)
|
||||||
|
return (
|
||||||
|
args.num_nodes,
|
||||||
|
args.nproc_per_node,
|
||||||
|
args.name,
|
||||||
|
args.model,
|
||||||
|
args.gen_std,
|
||||||
|
args.run_single,
|
||||||
|
args.input_dir,
|
||||||
|
args.result_dir,
|
||||||
|
args.internal_model_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
utils function for this scripts
|
||||||
|
"""
|
||||||
|
def check_exists(path: str):
|
||||||
|
if not os.path.exists(path):
|
||||||
|
os.makedirs(path)
|
||||||
|
|
||||||
|
def np_assert(base, test, rtol=1e-2, atol=1e-1):
|
||||||
|
# np.testing.assert_allclose(test, base, rtol, atol)
|
||||||
|
print("max abs diff:", abs(base - test).max())
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Perf wrapper, run function n times
|
||||||
|
then average
|
||||||
|
"""
|
||||||
|
def perf_it(n):
|
||||||
|
def decorator(func):
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
# warmup
|
||||||
|
for _ in range(n):
|
||||||
|
func(*args, **kwargs)
|
||||||
|
|
||||||
|
t_total = 0
|
||||||
|
for _ in range(n):
|
||||||
|
t0 = time.time()
|
||||||
|
func(*args, **kwargs)
|
||||||
|
t1 = time.time()
|
||||||
|
t_total += t1 - t0
|
||||||
|
avg_time = (t_total) / n
|
||||||
|
print(f"Avg runtime of {n} time is {avg_time:.6f} seconds")
|
||||||
|
return avg_time
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Run InfiniTensor model with Standard input
|
||||||
|
check=True: check with standard output gen by pytorch
|
||||||
|
perf=True: run n times to get avg time
|
||||||
|
"""
|
||||||
|
def run_model(task_name,
|
||||||
|
model,
|
||||||
|
runtime,
|
||||||
|
world_size=1,
|
||||||
|
rank=0,
|
||||||
|
n=10,
|
||||||
|
check=True,
|
||||||
|
perf=True):
|
||||||
|
|
||||||
|
stub = OnnxStub(model, runtime)
|
||||||
|
|
||||||
|
# load in Onnx model inputs
|
||||||
|
def load_inputs(stub: OnnxStub):
|
||||||
|
# check exists
|
||||||
|
inputs = []
|
||||||
|
for i, (name, tensor) in enumerate(stub.inputs.items()):
|
||||||
|
input_path = os.path.join(input_dir, \
|
||||||
|
f"{task_name}_input_{i}.npy")
|
||||||
|
print(input_path)
|
||||||
|
if os.path.exists(input_path):
|
||||||
|
input = np.load(input_path)
|
||||||
|
else :
|
||||||
|
raise KeyError(f"{i} th input of model not exists")
|
||||||
|
# check shape
|
||||||
|
if all(x == y for x,y in zip(input.shape, tensor.shape())):
|
||||||
|
tensor.copyin_numpy(input)
|
||||||
|
else:
|
||||||
|
tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
|
||||||
|
|
||||||
|
load_inputs(stub)
|
||||||
|
# stub.tune()
|
||||||
|
stub.run()
|
||||||
|
time.sleep(0.01)
|
||||||
|
output = next(stub.outputs.values().__iter__()).copyout_numpy()
|
||||||
|
|
||||||
|
# check output results with standard output
|
||||||
|
if check:
|
||||||
|
st_output_path = os.path.join(result_dir, \
|
||||||
|
f"{task_name}_output.npy")
|
||||||
|
assert os.path.exists(st_output_path) , \
|
||||||
|
"standard output not exists"
|
||||||
|
st_output = np.load(st_output_path)
|
||||||
|
if np.isnan(output).any():
|
||||||
|
print("Nan in output")
|
||||||
|
exit()
|
||||||
|
np_assert(st_output, output)
|
||||||
|
|
||||||
|
# perf
|
||||||
|
if perf:
|
||||||
|
@perf_it(n)
|
||||||
|
def perf_infinitensor(stub: OnnxStub):
|
||||||
|
stub.run()
|
||||||
|
perf_infinitensor(stub)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Start a worker in Parallel
|
||||||
|
"""
|
||||||
|
def start_worker(name: str,
|
||||||
|
world_size: int,
|
||||||
|
rank: int,
|
||||||
|
local_rank: int,
|
||||||
|
model: onnx.ModelProto):
|
||||||
|
|
||||||
|
dist_name = name + "_dist"
|
||||||
|
# partial a onnx model to world_size part
|
||||||
|
model = parallel_model(model, world_size, rank)
|
||||||
|
onnx.save(model, os.path.join(internal_model_dir, \
|
||||||
|
f"{dist_name}_rank{rank}.onnx"))
|
||||||
|
runtime = backend.KUNLUNRuntime(local_rank)
|
||||||
|
# print("init comm")
|
||||||
|
runtime.init_comm(
|
||||||
|
dist_name,
|
||||||
|
world_size,
|
||||||
|
rank,
|
||||||
|
)
|
||||||
|
run_model(name, model, runtime, world_size, rank)
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
generate standard input/output with
|
||||||
|
sigle card run
|
||||||
|
"""
|
||||||
|
def gen_stardard(task_name: str, model: onnx.ModelProto):
|
||||||
|
runtime = backend.KUNLUNRuntime(0)
|
||||||
|
stub = OnnxStub(model, runtime)
|
||||||
|
position_id = 0
|
||||||
|
# generate random input for model
|
||||||
|
for i, (name, tensor) in enumerate(stub.inputs.items()):
|
||||||
|
input = tensor.copyout_numpy()
|
||||||
|
if np.issubdtype(input.dtype, np.integer):
|
||||||
|
if input.size == 1:
|
||||||
|
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
|
||||||
|
else:
|
||||||
|
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
|
||||||
|
elif input.dtype == np.bool_:
|
||||||
|
input = np.random.randint(0,2,size=input.shape) > 0
|
||||||
|
else:
|
||||||
|
if i == 0:
|
||||||
|
input = np.ones(input.shape).astype(input.dtype)
|
||||||
|
position_id = input.shape[-1] - 1
|
||||||
|
else:
|
||||||
|
input = np.random.rand(*input.shape).astype(input.dtype)
|
||||||
|
tensor.copyin_numpy(input)
|
||||||
|
np.save(os.path.join(input_dir, \
|
||||||
|
f"{task_name}_input_{i}.npy"), input)
|
||||||
|
stub.run()
|
||||||
|
# print(stub.outputs)
|
||||||
|
output = next(stub.outputs.values().__iter__()).copyout_numpy()
|
||||||
|
if np.isnan(output).any():
|
||||||
|
print("Nan in output")
|
||||||
|
exit()
|
||||||
|
np.save(os.path.join(result_dir, f"{task_name}_output.npy"), output)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
|
||||||
|
global input_dir, result_dir, internal_model_dir
|
||||||
|
|
||||||
|
nnodes, nproc_per_node, task_name, \
|
||||||
|
model_path, gen_std, run_single, \
|
||||||
|
input_dir, result_dir, internal_model_dir = parse_args()
|
||||||
|
|
||||||
|
# load input onnx model
|
||||||
|
model = onnx.load(model_path)
|
||||||
|
|
||||||
|
# generate standart output
|
||||||
|
if gen_std:
|
||||||
|
print("Generate inputs and outputs.")
|
||||||
|
gen_stardard(task_name, model)
|
||||||
|
return
|
||||||
|
|
||||||
|
if run_single:
|
||||||
|
print("Run model by one GPU card.")
|
||||||
|
runtime = backend.KUNLUNRuntime(0)
|
||||||
|
run_model(task_name, model, runtime)
|
||||||
|
return
|
||||||
|
|
||||||
|
# run distributed parallel.
|
||||||
|
world_size = nnodes * nproc_per_node
|
||||||
|
print(f"Run model by {world_size} GPU in parallel.")
|
||||||
|
workers = [
|
||||||
|
mp.Process(
|
||||||
|
target=start_worker,
|
||||||
|
args=(task_name, world_size, rank, rank % nproc_per_node, model),
|
||||||
|
)
|
||||||
|
for rank in range(world_size)
|
||||||
|
]
|
||||||
|
|
||||||
|
for w in workers:
|
||||||
|
w.start()
|
||||||
|
|
||||||
|
for w in workers:
|
||||||
|
w.join()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -0,0 +1,215 @@
|
||||||
|
import sys
|
||||||
|
sys.path.append("../")
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import multiprocessing as mp
|
||||||
|
from pyinfinitensor.onnx import OnnxStub, backend
|
||||||
|
import onnx
|
||||||
|
from onnx.external_data_helper import convert_model_to_external_data
|
||||||
|
from onnx.shape_inference import infer_shapes_path
|
||||||
|
import numpy as np
|
||||||
|
from parallel_opt import parallel_model
|
||||||
|
|
||||||
|
st_input_dir = ".cache/input/"
|
||||||
|
st_output_dir = ".cache/output/"
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="launch distributed infinitensor")
|
||||||
|
parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
|
||||||
|
parser.add_argument(
|
||||||
|
"--nproc_per_node", type=int, default=2, help="number of processes per node"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--name", type=str, default="test", help="name of this instance."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model", type=str, default="/data1/shared/panzezhong/llama/fp32/my_llama_fp32.sim.onnx", help="path to the ONNX model file."
|
||||||
|
)
|
||||||
|
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
|
||||||
|
parser.add_argument("--length", type=int, default=1, help="sequence length.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--gen_std",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="whether to generate the standard results.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--run_single",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="whether run model with single process with standard inputs"
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
print("arg setting: ", args)
|
||||||
|
return (
|
||||||
|
args.num_nodes,
|
||||||
|
args.nproc_per_node,
|
||||||
|
args.name,
|
||||||
|
args.model,
|
||||||
|
args.batch_size,
|
||||||
|
args.length,
|
||||||
|
args.gen_std,
|
||||||
|
args.run_single
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def run_model(model, runtime, world_size=1, rank=0, n=10):
|
||||||
|
stub = OnnxStub(model, runtime)
|
||||||
|
load_inputs(stub, world_size, rank)
|
||||||
|
# stub.tune()
|
||||||
|
stub.run()
|
||||||
|
# get outputs
|
||||||
|
time.sleep(0.01)
|
||||||
|
outputs = next(stub.outputs.values().__iter__()).copyout_numpy()
|
||||||
|
|
||||||
|
# bench
|
||||||
|
begin = time.time()
|
||||||
|
for _ in range(n):
|
||||||
|
stub.run()
|
||||||
|
end = time.time()
|
||||||
|
avg_time = (end - begin) / n
|
||||||
|
print(f"average time: {avg_time}")
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def run_and_compare(name, model, runtime, world_size=1, rank = 0):
|
||||||
|
results = np.load(os.path.join(st_output_dir, "test_output.npy"))
|
||||||
|
outputs = run_model(model, runtime, world_size, rank)
|
||||||
|
print(outputs[:100])
|
||||||
|
if np.isnan(outputs).any():
|
||||||
|
print("Nan in output")
|
||||||
|
print("answer argmax:", np.argmax(results))
|
||||||
|
print("output argmax:", np.argmax(outputs))
|
||||||
|
#np.testing.assert_allclose(outputs, results, rtol=1e-3, atol=1e-3)
|
||||||
|
getDiff(results, outputs)
|
||||||
|
|
||||||
|
|
||||||
|
def start_worker(
|
||||||
|
name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
|
||||||
|
):
|
||||||
|
dist_name = name + "_dist"
|
||||||
|
model = parallel_model(model, world_size, rank)
|
||||||
|
extern_path = f"./{dist_name}_rank{rank}.pb"
|
||||||
|
if os.path.exists(extern_path):
|
||||||
|
os.remove(extern_path)
|
||||||
|
onnx.save_model(
|
||||||
|
model,
|
||||||
|
f"./{dist_name}_rank{rank}.onnx",
|
||||||
|
save_as_external_data=True,
|
||||||
|
location=extern_path,
|
||||||
|
)
|
||||||
|
infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
|
||||||
|
runtime = backend.KUNLUNRuntime(local_rank)
|
||||||
|
# print("init comm")
|
||||||
|
runtime.init_comm(
|
||||||
|
dist_name,
|
||||||
|
world_size,
|
||||||
|
rank,
|
||||||
|
)
|
||||||
|
run_and_compare(name, model, runtime, world_size, rank)
|
||||||
|
|
||||||
|
|
||||||
|
def start_single(name, model):
|
||||||
|
runtime = backend.KUNLUNRuntime(0)
|
||||||
|
run_and_compare(name, model, runtime)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_input_output(model):
|
||||||
|
runtime = backend.KUNLUNRuntime(0)
|
||||||
|
stub = OnnxStub(model, runtime)
|
||||||
|
position_id = 0
|
||||||
|
for i, (name, tensor) in enumerate(stub.inputs.items()):
|
||||||
|
input = tensor.copyout_numpy()
|
||||||
|
if np.issubdtype(input.dtype, np.integer):
|
||||||
|
if input.size == 1:
|
||||||
|
# input = np.array([position_id])
|
||||||
|
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
|
||||||
|
else:
|
||||||
|
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
|
||||||
|
elif input.dtype == np.bool_:
|
||||||
|
input = np.random.randint(0,2,size=input.shape) > 0
|
||||||
|
else:
|
||||||
|
if i == 0:
|
||||||
|
input = np.ones(input.shape).astype(input.dtype)
|
||||||
|
position_id = input.shape[-1] - 1
|
||||||
|
else:
|
||||||
|
input = np.random.rand(*input.shape).astype(input.dtype)
|
||||||
|
tensor.copyin_numpy(input)
|
||||||
|
np.save(os.path.join(st_input_dir, f"input_{i}"), input)
|
||||||
|
stub.run()
|
||||||
|
# print(stub.outputs)
|
||||||
|
time.sleep(0.01)
|
||||||
|
output = next(stub.outputs.values().__iter__()).copyout_numpy()
|
||||||
|
print(output[:100])
|
||||||
|
if np.isnan(output).any():
|
||||||
|
print("Nan in output")
|
||||||
|
np.save(os.path.join(st_output_dir, f"output"), output)
|
||||||
|
|
||||||
|
|
||||||
|
def load_inputs(stub, world_size=1, rank=0):
|
||||||
|
for i, (name, tensor) in enumerate(stub.inputs.items()):
|
||||||
|
input = np.load(os.path.join(st_input_dir, f"test_input_{name}.npy"))
|
||||||
|
if all(x == y for x,y in zip(input.shape,tensor.shape())):
|
||||||
|
tensor.copyin_numpy(input)
|
||||||
|
else:
|
||||||
|
tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
|
||||||
|
|
||||||
|
|
||||||
|
def getDiff(base, test):
|
||||||
|
absolute_diff = np.abs(np.subtract(base, test))
|
||||||
|
max_absolute_diff = np.max(absolute_diff)
|
||||||
|
|
||||||
|
baseCopy = base.astype(np.float64).ravel()
|
||||||
|
testCopy = test.astype(np.float64).ravel()
|
||||||
|
upValue = np.sum(np.abs(baseCopy - testCopy))
|
||||||
|
downValue = np.sum(np.abs(baseCopy)) + np.float64(1e-9)
|
||||||
|
max_relative_diff = upValue / downValue
|
||||||
|
print(f"Max absolute difference: {max_absolute_diff}\nMax relative difference: {max_relative_diff}")
|
||||||
|
|
||||||
|
return max_absolute_diff, max_relative_diff
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
nnodes, nproc_per_node, name, model_path, bs, length, gen_std, run_single = parse_args()
|
||||||
|
|
||||||
|
model = onnx.load(model_path)
|
||||||
|
|
||||||
|
# generate standart output
|
||||||
|
if gen_std:
|
||||||
|
print("Generate inputs and outputs.")
|
||||||
|
p = mp.Process(target=generate_input_output, args=[model])
|
||||||
|
p.start()
|
||||||
|
p.join()
|
||||||
|
return
|
||||||
|
|
||||||
|
# # run single process.
|
||||||
|
# # use standalone process to isolate cuda.
|
||||||
|
if run_single:
|
||||||
|
print("run model by single GPU.")
|
||||||
|
p = mp.Process(target=start_single, args=(name, model))
|
||||||
|
p.start()
|
||||||
|
p.join()
|
||||||
|
return
|
||||||
|
|
||||||
|
# run distributed parallel.
|
||||||
|
world_size = nnodes * nproc_per_node
|
||||||
|
print(f"run model by {world_size} GPU in parallel.")
|
||||||
|
workers = [
|
||||||
|
mp.Process(
|
||||||
|
target=start_worker,
|
||||||
|
args=(name, world_size, rank, rank % nproc_per_node, model),
|
||||||
|
)
|
||||||
|
for rank in range(world_size)
|
||||||
|
]
|
||||||
|
|
||||||
|
for w in workers:
|
||||||
|
w.start()
|
||||||
|
|
||||||
|
for w in workers:
|
||||||
|
w.join()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -0,0 +1,208 @@
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
from transformers import BertModel, BertConfig
|
||||||
|
from transformers import GPT2Model, GPT2Config
|
||||||
|
from transformers import OPTModel, OPTConfig
|
||||||
|
from transformers import LlamaModel, LlamaConfig
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
import onnx
|
||||||
|
import os
|
||||||
|
from onnx.external_data_helper import convert_model_to_external_data
|
||||||
|
from onnxsim import simplify
|
||||||
|
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = False
|
||||||
|
torch.backends.cudnn.allow_tf32 = False
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="Run pytorch gpt2/bert/opt and optionally export onnx.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model", type=str, choices=["gpt2", "bert", "opt", "llama"], required=True, help="model type"
|
||||||
|
)
|
||||||
|
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
|
||||||
|
parser.add_argument("--length", type=int, default=1, help="sequence length.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--export_onnx",
|
||||||
|
type=str,
|
||||||
|
nargs="?",
|
||||||
|
default=None,
|
||||||
|
const="./",
|
||||||
|
help="whether and where to export onnx file",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--input_dir",
|
||||||
|
type=str,
|
||||||
|
default="./",
|
||||||
|
help="path to save pytorch model input data"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--result_dir",
|
||||||
|
type=str,
|
||||||
|
default="./",
|
||||||
|
help="path to save pytorch model output data"
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
print("arg setting: ", args)
|
||||||
|
return (
|
||||||
|
args.model,
|
||||||
|
args.batch_size,
|
||||||
|
args.length,
|
||||||
|
args.export_onnx,
|
||||||
|
args.input_dir,
|
||||||
|
args.result_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(modelname):
|
||||||
|
if modelname == "bert":
|
||||||
|
model = BertModel.from_pretrained("bert-base-uncased", add_pooling_layer=False, hidden_act="gelu_new") # erf is not impl by infini
|
||||||
|
voc_size = BertConfig().vocab_size
|
||||||
|
elif modelname == "gpt2":
|
||||||
|
model = GPT2Model.from_pretrained("gpt2")
|
||||||
|
voc_size = GPT2Config().vocab_size
|
||||||
|
elif modelname == "opt":
|
||||||
|
model = OPTModel.from_pretrained("./opt-125m")
|
||||||
|
voc_size = OPTConfig().vocab_size
|
||||||
|
elif modelname == "llama":
|
||||||
|
model = LlamaModel.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||||
|
voc_size = LlamaConfig().vocab_size
|
||||||
|
else :
|
||||||
|
raise KeyError(modelname)
|
||||||
|
|
||||||
|
model = model.eval()
|
||||||
|
return model, voc_size
|
||||||
|
|
||||||
|
def run_pytorch(torch_model, voc_size, batchsize, len, model_name):
|
||||||
|
data = np.random.randint(0, voc_size, (batchsize, len), dtype=np.int32)
|
||||||
|
np.save(os.path.join(input_dir, f"{model_name}_input_0.npy"), data)
|
||||||
|
inputs = torch.from_numpy(data).to("cuda")
|
||||||
|
torch_model = torch_model.to("cuda")
|
||||||
|
|
||||||
|
n_iter = 20
|
||||||
|
with torch.no_grad():
|
||||||
|
for _ in range(10):
|
||||||
|
outputs = torch_model(inputs)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
begin = time.time()
|
||||||
|
with torch.no_grad():
|
||||||
|
for _ in range(n_iter):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
outputs = torch_model(inputs)
|
||||||
|
#
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
end = time.time()
|
||||||
|
|
||||||
|
avg_time = (end - begin) / n_iter
|
||||||
|
outputs = outputs.last_hidden_state.to("cpu")
|
||||||
|
print("outputs abs mean:", abs(np.array(outputs)).mean())
|
||||||
|
print(f"average time: {avg_time}")
|
||||||
|
torch.cuda.memory.empty_cache()
|
||||||
|
np.save(os.path.join(result_dir, f"{model_name}_output.npy"), \
|
||||||
|
np.array(outputs))
|
||||||
|
print(f"Save input & output as {model_name}_input_0.npy and {model_name}_output.npy")
|
||||||
|
|
||||||
|
|
||||||
|
def export_onnx(model_name, model, data, path, extern=False):
|
||||||
|
torch.onnx.export(model, data, path, verbose=False, do_constant_folding=True)
|
||||||
|
onnx_model = onnx.load(path)
|
||||||
|
# onnx_model, check = simplify(onnx_model,
|
||||||
|
# skip_shape_inference=True,
|
||||||
|
# skipped_optimizers=['eliminate_duplicate_initializer'])
|
||||||
|
if model_name == "gpt2":
|
||||||
|
onnx_model, check = simplify(onnx_model,
|
||||||
|
skip_shape_inference=True,
|
||||||
|
skipped_optimizers=['fuse_qkv', 'eliminate_duplicate_initializer'])
|
||||||
|
else :
|
||||||
|
onnx_model, check = simplify(onnx_model,
|
||||||
|
skipped_optimizers=['fuse_qkv', 'eliminate_duplicate_initializer'])
|
||||||
|
assert check
|
||||||
|
add_value_info_for_constants(onnx_model)
|
||||||
|
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
|
||||||
|
if extern:
|
||||||
|
extern_path = path.replace('.onnx', '.pb')
|
||||||
|
if os.path.exists(extern_path):
|
||||||
|
os.remove(extern_path)
|
||||||
|
convert_model_to_external_data(
|
||||||
|
onnx_model,
|
||||||
|
all_tensors_to_one_file=True,
|
||||||
|
location=extern_path.split("/")[-1],
|
||||||
|
size_threshold=1024,
|
||||||
|
convert_attribute=False,
|
||||||
|
)
|
||||||
|
onnx.save(onnx_model, path)
|
||||||
|
|
||||||
|
def add_value_info_for_constants(model : onnx.ModelProto):
|
||||||
|
"""
|
||||||
|
Currently onnx.shape_inference doesn't use the shape of initializers, so add
|
||||||
|
that info explicitly as ValueInfoProtos.
|
||||||
|
Mutates the model.
|
||||||
|
Args:
|
||||||
|
model: The ModelProto to update.
|
||||||
|
"""
|
||||||
|
# All (top-level) constants will have ValueInfos before IRv4 as they are all inputs
|
||||||
|
if model.ir_version < 4:
|
||||||
|
return
|
||||||
|
|
||||||
|
def add_const_value_infos_to_graph(graph : onnx.GraphProto):
|
||||||
|
inputs = {i.name for i in graph.input}
|
||||||
|
existing_info = {vi.name: vi for vi in graph.value_info}
|
||||||
|
for init in graph.initializer:
|
||||||
|
# Check it really is a constant, not an input
|
||||||
|
if init.name in inputs:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# The details we want to add
|
||||||
|
elem_type = init.data_type
|
||||||
|
shape = init.dims
|
||||||
|
|
||||||
|
# Get existing or create new value info for this constant
|
||||||
|
vi = existing_info.get(init.name)
|
||||||
|
if vi is None:
|
||||||
|
vi = graph.value_info.add()
|
||||||
|
vi.name = init.name
|
||||||
|
|
||||||
|
# Even though it would be weird, we will not overwrite info even if it doesn't match
|
||||||
|
tt = vi.type.tensor_type
|
||||||
|
if tt.elem_type == onnx.TensorProto.UNDEFINED:
|
||||||
|
tt.elem_type = elem_type
|
||||||
|
if not tt.HasField("shape"):
|
||||||
|
# Ensure we set an empty list if the const is scalar (zero dims)
|
||||||
|
tt.shape.dim.extend([])
|
||||||
|
for dim in shape:
|
||||||
|
tt.shape.dim.add().dim_value = dim
|
||||||
|
|
||||||
|
# Handle subgraphs
|
||||||
|
for node in graph.node:
|
||||||
|
for attr in node.attribute:
|
||||||
|
# Ref attrs refer to other attrs, so we don't need to do anything
|
||||||
|
if attr.ref_attr_name != "":
|
||||||
|
continue
|
||||||
|
|
||||||
|
if attr.type == onnx.AttributeProto.GRAPH:
|
||||||
|
add_const_value_infos_to_graph(attr.g)
|
||||||
|
if attr.type == onnx.AttributeProto.GRAPHS:
|
||||||
|
for g in attr.graphs:
|
||||||
|
add_const_value_infos_to_graph(g)
|
||||||
|
|
||||||
|
|
||||||
|
return add_const_value_infos_to_graph(model.graph)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
global input_dir, result_dir
|
||||||
|
|
||||||
|
modelname, batchsize, seqlen, \
|
||||||
|
export_path, input_dir, result_dir = parse_args()
|
||||||
|
|
||||||
|
model, voc_size = get_model(modelname) # pytorch model
|
||||||
|
|
||||||
|
if export_path is not None:
|
||||||
|
filename = "{}_{}_{}.onnx".format(modelname, batchsize, seqlen)
|
||||||
|
path = os.path.join(export_path, filename)
|
||||||
|
param = torch.zeros((batchsize, seqlen), dtype=torch.int)
|
||||||
|
export_onnx(modelname, model, param, path, True) # export pytorch model to onnx model
|
||||||
|
|
||||||
|
run_pytorch(model, voc_size, batchsize, seqlen, modelname)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -0,0 +1,17 @@
|
||||||
|
export HF_ENDPOINT=https://hf-mirror.com
|
||||||
|
|
||||||
|
models=("bert" "gpt2")
|
||||||
|
batch_size=(1 32)
|
||||||
|
seq_len=(100 500)
|
||||||
|
nproc=(1 2 4)
|
||||||
|
|
||||||
|
for model in "${models[@]}"; do
|
||||||
|
for bs in "${batch_size[@]}"; do
|
||||||
|
for len in "${seq_len[@]}"; do
|
||||||
|
python -m xacc run_pytorch.py --model "$model" --batch_size "$bs" --length "$len" --export_onnx ../models/"$model" > results/"$model"_"$bs"_"$len"
|
||||||
|
for n in "${nproc[@]}"; do
|
||||||
|
python kunlun_launch.py --name "$model" --model ../models/"$model"/"$model"_"$bs"_"$len".onnx --nproc_per_node=$n >> results/"$model"_"$bs"_"$len"
|
||||||
|
done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
done
|
|
@ -5,8 +5,8 @@
|
||||||
#include "utils/operator_utils.h"
|
#include "utils/operator_utils.h"
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <nlohmann/json.hpp>
|
#include <nlohmann/json.hpp>
|
||||||
using json = nlohmann::json;
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
using json = nlohmann::json;
|
||||||
|
|
||||||
class RuntimeObj; // Forward declaration for Kernel::compute
|
class RuntimeObj; // Forward declaration for Kernel::compute
|
||||||
|
|
||||||
|
|
|
@ -2,8 +2,8 @@
|
||||||
#include "core/graph.h"
|
#include "core/graph.h"
|
||||||
#include "core/kernel.h"
|
#include "core/kernel.h"
|
||||||
#include <nlohmann/json_fwd.hpp>
|
#include <nlohmann/json_fwd.hpp>
|
||||||
using json = nlohmann::json;
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
using json = nlohmann::json;
|
||||||
|
|
||||||
class PerfEngine {
|
class PerfEngine {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -20,14 +20,15 @@ class AllReduceXCCL : public KUNLUNKernelWithoutConfig {
|
||||||
BKCLContext_t comm =
|
BKCLContext_t comm =
|
||||||
dynamic_cast<XcclCommunicatorObj &>(context->getCommunicator())
|
dynamic_cast<XcclCommunicatorObj &>(context->getCommunicator())
|
||||||
.getXcclComm();
|
.getXcclComm();
|
||||||
double t = timeit(
|
// double t = timeit(
|
||||||
[&]() {
|
// [&]() {
|
||||||
checkXcclError(bkcl_all_reduce(comm, input, output, count,
|
checkXcclError(bkcl_all_reduce(comm, input, output, count,
|
||||||
BKCLDataType::BKCL_FLOAT,
|
BKCLDataType::BKCL_FLOAT, getRedOp(),
|
||||||
getRedOp(), 0));
|
0));
|
||||||
},
|
// },
|
||||||
[&]() { context->sync(); });
|
// [&]() { context->sync(); });
|
||||||
std::cout << "Time consuming for " << op->getInputs(0)->size() << " size is " << t << std::endl;
|
// std::cout << "Time consuming for " << op->getInputs(0)->size() << "
|
||||||
|
// size is " << t << std::endl;
|
||||||
}
|
}
|
||||||
virtual BKCLOp getRedOp() const = 0;
|
virtual BKCLOp getRedOp() const = 0;
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue