kunlun distributed

This commit is contained in:
weijie01 2024-04-02 17:15:08 +08:00
parent a71cd14963
commit 32a13b7760
9 changed files with 732 additions and 13 deletions

View File

@ -285,8 +285,8 @@ if(USE_KUNLUN)
message(STATUS "KUNLUN_HOME: ${KUNLUN_HOME}")
include_directories("${KUNLUN_HOME}/include/")
find_library(KUNLUN_RT libxpurt.so "${KUNLUN_HOME}/so/")
find_library(KUNLUN_DNN libxpuapi.so "${KUNLUN_HOME}/so/")
find_library(KUNLUN_RT libxpurt.so "${KUNLUN_HOME}/lib64/")
find_library(KUNLUN_DNN libxpuapi.so "${KUNLUN_HOME}/lib64/")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror")
if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH}))

View File

@ -9,7 +9,7 @@ find_path(XCCL_INCLUDE_DIRS # ${XCCL_INCLUDE_DIR}
HINTS XCCL_INCLUDE_DIR)
find_library(XCCL_LIBRARIES # ${XCCL_LIB_DIR}
NAMES so/libbkcl.so
NAMES lib64/libbkcl.so
HINTS XCCL_LIB_DIR)
message(STATUS "XCCL_INCLUDE_DIRS: ${XCCL_INCLUDE_DIRS}")

View File

@ -0,0 +1,278 @@
import sys
sys.path.append('../')
import argparse
import os
import time
import multiprocessing as mp
from pyinfinitensor.onnx import OnnxStub, backend
import onnx
from onnx.external_data_helper import convert_model_to_external_data
from onnx.shape_inference import infer_shapes_path
import numpy as np
from parallel_opt import parallel_model
from functools import wraps
def parse_args():
parser = argparse.ArgumentParser(description="launch distributed infinitensor")
parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
parser.add_argument(
"--nproc_per_node", type=int, default=2, help="number of processes per node"
)
parser.add_argument(
"--name", type=str, default="test", help="name of this instance."
)
parser.add_argument(
"--model", type=str, default="", help="path to the ONNX model file."
)
parser.add_argument(
"--gen_std",
default=False,
action="store_true",
help="whether to generate the standard results.",
)
parser.add_argument(
"--run_single",
default=False,
action="store_true",
help="whether run model with single process with standard inputs"
)
parser.add_argument(
"--input_dir",
default="./",
help="path to save model input data"
)
parser.add_argument(
"--result_dir",
default="./",
help="path to save model standard output"
)
parser.add_argument(
"--internal_model_dir",
default="./",
help="path to save internal onnx model for parallel run"
)
args = parser.parse_args()
# check path, mkdir if not exist
check_exists(args.input_dir)
check_exists(args.result_dir)
check_exists(args.internal_model_dir)
print("arg setting: ", args)
return (
args.num_nodes,
args.nproc_per_node,
args.name,
args.model,
args.gen_std,
args.run_single,
args.input_dir,
args.result_dir,
args.internal_model_dir
)
"""
utils function for this scripts
"""
def check_exists(path: str):
if not os.path.exists(path):
os.makedirs(path)
def np_assert(base, test, rtol=1e-2, atol=1e-1):
# np.testing.assert_allclose(test, base, rtol, atol)
print("max abs diff:", abs(base - test).max())
"""
Perf wrapper, run function n times
then average
"""
def perf_it(n):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
# warmup
for _ in range(n):
func(*args, **kwargs)
t_total = 0
for _ in range(n):
t0 = time.time()
func(*args, **kwargs)
t1 = time.time()
t_total += t1 - t0
avg_time = (t_total) / n
print(f"Avg runtime of {n} time is {avg_time:.6f} seconds")
return avg_time
return wrapper
return decorator
"""
Run InfiniTensor model with Standard input
check=True: check with standard output gen by pytorch
perf=True: run n times to get avg time
"""
def run_model(task_name,
model,
runtime,
world_size=1,
rank=0,
n=10,
check=True,
perf=True):
stub = OnnxStub(model, runtime)
# load in Onnx model inputs
def load_inputs(stub: OnnxStub):
# check exists
inputs = []
for i, (name, tensor) in enumerate(stub.inputs.items()):
input_path = os.path.join(input_dir, \
f"{task_name}_input_{i}.npy")
print(input_path)
if os.path.exists(input_path):
input = np.load(input_path)
else :
raise KeyError(f"{i} th input of model not exists")
# check shape
if all(x == y for x,y in zip(input.shape, tensor.shape())):
tensor.copyin_numpy(input)
else:
tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
load_inputs(stub)
# stub.tune()
stub.run()
time.sleep(0.01)
output = next(stub.outputs.values().__iter__()).copyout_numpy()
# check output results with standard output
if check:
st_output_path = os.path.join(result_dir, \
f"{task_name}_output.npy")
assert os.path.exists(st_output_path) , \
"standard output not exists"
st_output = np.load(st_output_path)
if np.isnan(output).any():
print("Nan in output")
exit()
np_assert(st_output, output)
# perf
if perf:
@perf_it(n)
def perf_infinitensor(stub: OnnxStub):
stub.run()
perf_infinitensor(stub)
return output
"""
Start a worker in Parallel
"""
def start_worker(name: str,
world_size: int,
rank: int,
local_rank: int,
model: onnx.ModelProto):
dist_name = name + "_dist"
# partial a onnx model to world_size part
model = parallel_model(model, world_size, rank)
onnx.save(model, os.path.join(internal_model_dir, \
f"{dist_name}_rank{rank}.onnx"))
runtime = backend.KUNLUNRuntime(local_rank)
# print("init comm")
runtime.init_comm(
dist_name,
world_size,
rank,
)
run_model(name, model, runtime, world_size, rank)
"""
generate standard input/output with
sigle card run
"""
def gen_stardard(task_name: str, model: onnx.ModelProto):
runtime = backend.KUNLUNRuntime(0)
stub = OnnxStub(model, runtime)
position_id = 0
# generate random input for model
for i, (name, tensor) in enumerate(stub.inputs.items()):
input = tensor.copyout_numpy()
if np.issubdtype(input.dtype, np.integer):
if input.size == 1:
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
else:
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
elif input.dtype == np.bool_:
input = np.random.randint(0,2,size=input.shape) > 0
else:
if i == 0:
input = np.ones(input.shape).astype(input.dtype)
position_id = input.shape[-1] - 1
else:
input = np.random.rand(*input.shape).astype(input.dtype)
tensor.copyin_numpy(input)
np.save(os.path.join(input_dir, \
f"{task_name}_input_{i}.npy"), input)
stub.run()
# print(stub.outputs)
output = next(stub.outputs.values().__iter__()).copyout_numpy()
if np.isnan(output).any():
print("Nan in output")
exit()
np.save(os.path.join(result_dir, f"{task_name}_output.npy"), output)
def main():
global input_dir, result_dir, internal_model_dir
nnodes, nproc_per_node, task_name, \
model_path, gen_std, run_single, \
input_dir, result_dir, internal_model_dir = parse_args()
# load input onnx model
model = onnx.load(model_path)
# generate standart output
if gen_std:
print("Generate inputs and outputs.")
gen_stardard(task_name, model)
return
if run_single:
print("Run model by one GPU card.")
runtime = backend.KUNLUNRuntime(0)
run_model(task_name, model, runtime)
return
# run distributed parallel.
world_size = nnodes * nproc_per_node
print(f"Run model by {world_size} GPU in parallel.")
workers = [
mp.Process(
target=start_worker,
args=(task_name, world_size, rank, rank % nproc_per_node, model),
)
for rank in range(world_size)
]
for w in workers:
w.start()
for w in workers:
w.join()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,215 @@
import sys
sys.path.append("../")
import argparse
import os
import time
import multiprocessing as mp
from pyinfinitensor.onnx import OnnxStub, backend
import onnx
from onnx.external_data_helper import convert_model_to_external_data
from onnx.shape_inference import infer_shapes_path
import numpy as np
from parallel_opt import parallel_model
st_input_dir = ".cache/input/"
st_output_dir = ".cache/output/"
def parse_args():
parser = argparse.ArgumentParser(description="launch distributed infinitensor")
parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
parser.add_argument(
"--nproc_per_node", type=int, default=2, help="number of processes per node"
)
parser.add_argument(
"--name", type=str, default="test", help="name of this instance."
)
parser.add_argument(
"--model", type=str, default="/data1/shared/panzezhong/llama/fp32/my_llama_fp32.sim.onnx", help="path to the ONNX model file."
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
parser.add_argument("--length", type=int, default=1, help="sequence length.")
parser.add_argument(
"--gen_std",
default=False,
action="store_true",
help="whether to generate the standard results.",
)
parser.add_argument(
"--run_single",
default=False,
action="store_true",
help="whether run model with single process with standard inputs"
)
args = parser.parse_args()
print("arg setting: ", args)
return (
args.num_nodes,
args.nproc_per_node,
args.name,
args.model,
args.batch_size,
args.length,
args.gen_std,
args.run_single
)
def run_model(model, runtime, world_size=1, rank=0, n=10):
stub = OnnxStub(model, runtime)
load_inputs(stub, world_size, rank)
# stub.tune()
stub.run()
# get outputs
time.sleep(0.01)
outputs = next(stub.outputs.values().__iter__()).copyout_numpy()
# bench
begin = time.time()
for _ in range(n):
stub.run()
end = time.time()
avg_time = (end - begin) / n
print(f"average time: {avg_time}")
return outputs
def run_and_compare(name, model, runtime, world_size=1, rank = 0):
results = np.load(os.path.join(st_output_dir, "test_output.npy"))
outputs = run_model(model, runtime, world_size, rank)
print(outputs[:100])
if np.isnan(outputs).any():
print("Nan in output")
print("answer argmax:", np.argmax(results))
print("output argmax:", np.argmax(outputs))
#np.testing.assert_allclose(outputs, results, rtol=1e-3, atol=1e-3)
getDiff(results, outputs)
def start_worker(
name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
):
dist_name = name + "_dist"
model = parallel_model(model, world_size, rank)
extern_path = f"./{dist_name}_rank{rank}.pb"
if os.path.exists(extern_path):
os.remove(extern_path)
onnx.save_model(
model,
f"./{dist_name}_rank{rank}.onnx",
save_as_external_data=True,
location=extern_path,
)
infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
runtime = backend.KUNLUNRuntime(local_rank)
# print("init comm")
runtime.init_comm(
dist_name,
world_size,
rank,
)
run_and_compare(name, model, runtime, world_size, rank)
def start_single(name, model):
runtime = backend.KUNLUNRuntime(0)
run_and_compare(name, model, runtime)
def generate_input_output(model):
runtime = backend.KUNLUNRuntime(0)
stub = OnnxStub(model, runtime)
position_id = 0
for i, (name, tensor) in enumerate(stub.inputs.items()):
input = tensor.copyout_numpy()
if np.issubdtype(input.dtype, np.integer):
if input.size == 1:
# input = np.array([position_id])
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
else:
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
elif input.dtype == np.bool_:
input = np.random.randint(0,2,size=input.shape) > 0
else:
if i == 0:
input = np.ones(input.shape).astype(input.dtype)
position_id = input.shape[-1] - 1
else:
input = np.random.rand(*input.shape).astype(input.dtype)
tensor.copyin_numpy(input)
np.save(os.path.join(st_input_dir, f"input_{i}"), input)
stub.run()
# print(stub.outputs)
time.sleep(0.01)
output = next(stub.outputs.values().__iter__()).copyout_numpy()
print(output[:100])
if np.isnan(output).any():
print("Nan in output")
np.save(os.path.join(st_output_dir, f"output"), output)
def load_inputs(stub, world_size=1, rank=0):
for i, (name, tensor) in enumerate(stub.inputs.items()):
input = np.load(os.path.join(st_input_dir, f"test_input_{name}.npy"))
if all(x == y for x,y in zip(input.shape,tensor.shape())):
tensor.copyin_numpy(input)
else:
tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
def getDiff(base, test):
absolute_diff = np.abs(np.subtract(base, test))
max_absolute_diff = np.max(absolute_diff)
baseCopy = base.astype(np.float64).ravel()
testCopy = test.astype(np.float64).ravel()
upValue = np.sum(np.abs(baseCopy - testCopy))
downValue = np.sum(np.abs(baseCopy)) + np.float64(1e-9)
max_relative_diff = upValue / downValue
print(f"Max absolute difference: {max_absolute_diff}\nMax relative difference: {max_relative_diff}")
return max_absolute_diff, max_relative_diff
def main():
nnodes, nproc_per_node, name, model_path, bs, length, gen_std, run_single = parse_args()
model = onnx.load(model_path)
# generate standart output
if gen_std:
print("Generate inputs and outputs.")
p = mp.Process(target=generate_input_output, args=[model])
p.start()
p.join()
return
# # run single process.
# # use standalone process to isolate cuda.
if run_single:
print("run model by single GPU.")
p = mp.Process(target=start_single, args=(name, model))
p.start()
p.join()
return
# run distributed parallel.
world_size = nnodes * nproc_per_node
print(f"run model by {world_size} GPU in parallel.")
workers = [
mp.Process(
target=start_worker,
args=(name, world_size, rank, rank % nproc_per_node, model),
)
for rank in range(world_size)
]
for w in workers:
w.start()
for w in workers:
w.join()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,208 @@
import argparse
import torch
from transformers import BertModel, BertConfig
from transformers import GPT2Model, GPT2Config
from transformers import OPTModel, OPTConfig
from transformers import LlamaModel, LlamaConfig
import time
import numpy as np
import onnx
import os
from onnx.external_data_helper import convert_model_to_external_data
from onnxsim import simplify
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
def parse_args():
parser = argparse.ArgumentParser(description="Run pytorch gpt2/bert/opt and optionally export onnx.")
parser.add_argument(
"--model", type=str, choices=["gpt2", "bert", "opt", "llama"], required=True, help="model type"
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
parser.add_argument("--length", type=int, default=1, help="sequence length.")
parser.add_argument(
"--export_onnx",
type=str,
nargs="?",
default=None,
const="./",
help="whether and where to export onnx file",
)
parser.add_argument(
"--input_dir",
type=str,
default="./",
help="path to save pytorch model input data"
)
parser.add_argument(
"--result_dir",
type=str,
default="./",
help="path to save pytorch model output data"
)
args = parser.parse_args()
print("arg setting: ", args)
return (
args.model,
args.batch_size,
args.length,
args.export_onnx,
args.input_dir,
args.result_dir
)
def get_model(modelname):
if modelname == "bert":
model = BertModel.from_pretrained("bert-base-uncased", add_pooling_layer=False, hidden_act="gelu_new") # erf is not impl by infini
voc_size = BertConfig().vocab_size
elif modelname == "gpt2":
model = GPT2Model.from_pretrained("gpt2")
voc_size = GPT2Config().vocab_size
elif modelname == "opt":
model = OPTModel.from_pretrained("./opt-125m")
voc_size = OPTConfig().vocab_size
elif modelname == "llama":
model = LlamaModel.from_pretrained("meta-llama/Llama-2-7b-hf")
voc_size = LlamaConfig().vocab_size
else :
raise KeyError(modelname)
model = model.eval()
return model, voc_size
def run_pytorch(torch_model, voc_size, batchsize, len, model_name):
data = np.random.randint(0, voc_size, (batchsize, len), dtype=np.int32)
np.save(os.path.join(input_dir, f"{model_name}_input_0.npy"), data)
inputs = torch.from_numpy(data).to("cuda")
torch_model = torch_model.to("cuda")
n_iter = 20
with torch.no_grad():
for _ in range(10):
outputs = torch_model(inputs)
torch.cuda.synchronize()
begin = time.time()
with torch.no_grad():
for _ in range(n_iter):
torch.cuda.synchronize()
outputs = torch_model(inputs)
#
torch.cuda.synchronize()
torch.cuda.synchronize()
end = time.time()
avg_time = (end - begin) / n_iter
outputs = outputs.last_hidden_state.to("cpu")
print("outputs abs mean:", abs(np.array(outputs)).mean())
print(f"average time: {avg_time}")
torch.cuda.memory.empty_cache()
np.save(os.path.join(result_dir, f"{model_name}_output.npy"), \
np.array(outputs))
print(f"Save input & output as {model_name}_input_0.npy and {model_name}_output.npy")
def export_onnx(model_name, model, data, path, extern=False):
torch.onnx.export(model, data, path, verbose=False, do_constant_folding=True)
onnx_model = onnx.load(path)
# onnx_model, check = simplify(onnx_model,
# skip_shape_inference=True,
# skipped_optimizers=['eliminate_duplicate_initializer'])
if model_name == "gpt2":
onnx_model, check = simplify(onnx_model,
skip_shape_inference=True,
skipped_optimizers=['fuse_qkv', 'eliminate_duplicate_initializer'])
else :
onnx_model, check = simplify(onnx_model,
skipped_optimizers=['fuse_qkv', 'eliminate_duplicate_initializer'])
assert check
add_value_info_for_constants(onnx_model)
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
if extern:
extern_path = path.replace('.onnx', '.pb')
if os.path.exists(extern_path):
os.remove(extern_path)
convert_model_to_external_data(
onnx_model,
all_tensors_to_one_file=True,
location=extern_path.split("/")[-1],
size_threshold=1024,
convert_attribute=False,
)
onnx.save(onnx_model, path)
def add_value_info_for_constants(model : onnx.ModelProto):
"""
Currently onnx.shape_inference doesn't use the shape of initializers, so add
that info explicitly as ValueInfoProtos.
Mutates the model.
Args:
model: The ModelProto to update.
"""
# All (top-level) constants will have ValueInfos before IRv4 as they are all inputs
if model.ir_version < 4:
return
def add_const_value_infos_to_graph(graph : onnx.GraphProto):
inputs = {i.name for i in graph.input}
existing_info = {vi.name: vi for vi in graph.value_info}
for init in graph.initializer:
# Check it really is a constant, not an input
if init.name in inputs:
continue
# The details we want to add
elem_type = init.data_type
shape = init.dims
# Get existing or create new value info for this constant
vi = existing_info.get(init.name)
if vi is None:
vi = graph.value_info.add()
vi.name = init.name
# Even though it would be weird, we will not overwrite info even if it doesn't match
tt = vi.type.tensor_type
if tt.elem_type == onnx.TensorProto.UNDEFINED:
tt.elem_type = elem_type
if not tt.HasField("shape"):
# Ensure we set an empty list if the const is scalar (zero dims)
tt.shape.dim.extend([])
for dim in shape:
tt.shape.dim.add().dim_value = dim
# Handle subgraphs
for node in graph.node:
for attr in node.attribute:
# Ref attrs refer to other attrs, so we don't need to do anything
if attr.ref_attr_name != "":
continue
if attr.type == onnx.AttributeProto.GRAPH:
add_const_value_infos_to_graph(attr.g)
if attr.type == onnx.AttributeProto.GRAPHS:
for g in attr.graphs:
add_const_value_infos_to_graph(g)
return add_const_value_infos_to_graph(model.graph)
def main():
global input_dir, result_dir
modelname, batchsize, seqlen, \
export_path, input_dir, result_dir = parse_args()
model, voc_size = get_model(modelname) # pytorch model
if export_path is not None:
filename = "{}_{}_{}.onnx".format(modelname, batchsize, seqlen)
path = os.path.join(export_path, filename)
param = torch.zeros((batchsize, seqlen), dtype=torch.int)
export_onnx(modelname, model, param, path, True) # export pytorch model to onnx model
run_pytorch(model, voc_size, batchsize, seqlen, modelname)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,17 @@
export HF_ENDPOINT=https://hf-mirror.com
models=("bert" "gpt2")
batch_size=(1 32)
seq_len=(100 500)
nproc=(1 2 4)
for model in "${models[@]}"; do
for bs in "${batch_size[@]}"; do
for len in "${seq_len[@]}"; do
python -m xacc run_pytorch.py --model "$model" --batch_size "$bs" --length "$len" --export_onnx ../models/"$model" > results/"$model"_"$bs"_"$len"
for n in "${nproc[@]}"; do
python kunlun_launch.py --name "$model" --model ../models/"$model"/"$model"_"$bs"_"$len".onnx --nproc_per_node=$n >> results/"$model"_"$bs"_"$len"
done
done
done
done

View File

@ -5,8 +5,8 @@
#include "utils/operator_utils.h"
#include <functional>
#include <nlohmann/json.hpp>
using json = nlohmann::json;
namespace infini {
using json = nlohmann::json;
class RuntimeObj; // Forward declaration for Kernel::compute

View File

@ -2,8 +2,8 @@
#include "core/graph.h"
#include "core/kernel.h"
#include <nlohmann/json_fwd.hpp>
using json = nlohmann::json;
namespace infini {
using json = nlohmann::json;
class PerfEngine {
public:

View File

@ -20,14 +20,15 @@ class AllReduceXCCL : public KUNLUNKernelWithoutConfig {
BKCLContext_t comm =
dynamic_cast<XcclCommunicatorObj &>(context->getCommunicator())
.getXcclComm();
double t = timeit(
[&]() {
checkXcclError(bkcl_all_reduce(comm, input, output, count,
BKCLDataType::BKCL_FLOAT,
getRedOp(), 0));
},
[&]() { context->sync(); });
std::cout << "Time consuming for " << op->getInputs(0)->size() << " size is " << t << std::endl;
// double t = timeit(
// [&]() {
checkXcclError(bkcl_all_reduce(comm, input, output, count,
BKCLDataType::BKCL_FLOAT, getRedOp(),
0));
// },
// [&]() { context->sync(); });
// std::cout << "Time consuming for " << op->getInputs(0)->size() << "
// size is " << t << std::endl;
}
virtual BKCLOp getRedOp() const = 0;
};