forked from jiuyuan/InfiniTensor
kunlun dist inference fix
This commit is contained in:
parent
54a35772fb
commit
a71cd14963
2
Makefile
2
Makefile
|
@ -40,7 +40,7 @@ endif
|
|||
|
||||
build:
|
||||
mkdir -p build/$(TYPE)
|
||||
cd build/$(TYPE) && cmake $(CMAKE_OPT) ../.. && make -j8
|
||||
cd build/$(TYPE) && cmake $(CMAKE_OPT) ../.. && make -j64
|
||||
|
||||
clean:
|
||||
rm -rf build
|
||||
|
|
|
@ -1,213 +0,0 @@
|
|||
import argparse
|
||||
import os
|
||||
import time
|
||||
import multiprocessing as mp
|
||||
from pyinfinitensor.onnx import OnnxStub, backend
|
||||
import onnx
|
||||
from onnx.external_data_helper import convert_model_to_external_data
|
||||
from onnx.shape_inference import infer_shapes_path
|
||||
import numpy as np
|
||||
from parallel_opt import parallel_model
|
||||
|
||||
st_input_dir = "standard/inputs/"
|
||||
st_output_dir = "standard/outputs/"
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="launch distributed infinitensor")
|
||||
parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
|
||||
parser.add_argument(
|
||||
"--nproc_per_node", type=int, default=2, help="number of processes per node"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--name", type=str, default="test", help="name of this instance."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model", type=str, default="/data1/shared/panzezhong/llama/fp32/my_llama_fp32.sim.onnx", help="path to the ONNX model file."
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
|
||||
parser.add_argument("--length", type=int, default=1, help="sequence length.")
|
||||
parser.add_argument(
|
||||
"--gen_std",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="whether to generate the standard results.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run_single",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="whether run model with single process with standard inputs"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print("arg setting: ", args)
|
||||
return (
|
||||
args.num_nodes,
|
||||
args.nproc_per_node,
|
||||
args.name,
|
||||
args.model,
|
||||
args.batch_size,
|
||||
args.length,
|
||||
args.gen_std,
|
||||
args.run_single
|
||||
)
|
||||
|
||||
|
||||
def run_model(model, runtime, world_size=1, rank=0, n=10):
|
||||
stub = OnnxStub(model, runtime)
|
||||
load_inputs(stub, world_size, rank)
|
||||
# stub.tune()
|
||||
stub.run()
|
||||
# get outputs
|
||||
time.sleep(0.01)
|
||||
outputs = next(stub.outputs.values().__iter__()).copyout_numpy()
|
||||
|
||||
# bench
|
||||
begin = time.time()
|
||||
for _ in range(n):
|
||||
stub.run()
|
||||
end = time.time()
|
||||
avg_time = (end - begin) / n
|
||||
print(f"average time: {avg_time}")
|
||||
return outputs
|
||||
|
||||
|
||||
|
||||
def run_and_compare(name, model, runtime, world_size=1, rank = 0):
|
||||
results = np.load(os.path.join(st_output_dir,f"output.npy"))
|
||||
outputs = run_model(model, runtime, world_size, rank)
|
||||
print(outputs[:100])
|
||||
if np.isnan(outputs).any():
|
||||
print("Nan in output")
|
||||
print("answer argmax:", np.argmax(results))
|
||||
print("output argmax:", np.argmax(outputs))
|
||||
#np.testing.assert_allclose(outputs, results, rtol=1e-3, atol=1e-3)
|
||||
getDiff(results, outputs)
|
||||
|
||||
|
||||
def start_worker(
|
||||
name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
|
||||
):
|
||||
dist_name = name + "_dist"
|
||||
model = parallel_model(model, world_size, rank)
|
||||
extern_path = f"./{dist_name}_rank{rank}.pb"
|
||||
if os.path.exists(extern_path):
|
||||
os.remove(extern_path)
|
||||
onnx.save_model(
|
||||
model,
|
||||
f"./{dist_name}_rank{rank}.onnx",
|
||||
save_as_external_data=True,
|
||||
location=extern_path,
|
||||
)
|
||||
infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
|
||||
runtime = backend.KUNLUNRuntime(local_rank)
|
||||
# print("init comm")
|
||||
runtime.init_comm(
|
||||
dist_name,
|
||||
world_size,
|
||||
rank,
|
||||
)
|
||||
run_and_compare(name, model, runtime, world_size, rank)
|
||||
|
||||
|
||||
def start_single(name, model):
|
||||
runtime = backend.KUNLUNRuntime(0)
|
||||
run_and_compare(name, model, runtime)
|
||||
|
||||
|
||||
def generate_input_output(model):
|
||||
runtime = backend.KUNLUNRuntime(0)
|
||||
stub = OnnxStub(model, runtime)
|
||||
position_id = 0
|
||||
for i, (name, tensor) in enumerate(stub.inputs.items()):
|
||||
input = tensor.copyout_numpy()
|
||||
if np.issubdtype(input.dtype, np.integer):
|
||||
if input.size == 1:
|
||||
# input = np.array([position_id])
|
||||
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
|
||||
else:
|
||||
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
|
||||
elif input.dtype == np.bool_:
|
||||
input = np.random.randint(0,2,size=input.shape) > 0
|
||||
else:
|
||||
if i == 0:
|
||||
input = np.ones(input.shape).astype(input.dtype)
|
||||
position_id = input.shape[-1] - 1
|
||||
else:
|
||||
input = np.random.rand(*input.shape).astype(input.dtype)
|
||||
tensor.copyin_numpy(input)
|
||||
np.save(os.path.join(st_input_dir, f"input_{i}"), input)
|
||||
stub.run()
|
||||
# print(stub.outputs)
|
||||
time.sleep(0.01)
|
||||
output = next(stub.outputs.values().__iter__()).copyout_numpy()
|
||||
print(output[:100])
|
||||
if np.isnan(output).any():
|
||||
print("Nan in output")
|
||||
np.save(os.path.join(st_output_dir, f"output"), output)
|
||||
|
||||
|
||||
def load_inputs(stub, world_size=1, rank=0):
|
||||
for i, (name, tensor) in enumerate(stub.inputs.items()):
|
||||
input = np.load(os.path.join(st_input_dir, f"input_{i}.npy"))
|
||||
if all(x == y for x,y in zip(input.shape,tensor.shape())):
|
||||
tensor.copyin_numpy(input)
|
||||
else:
|
||||
tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
|
||||
|
||||
|
||||
def getDiff(base, test):
|
||||
absolute_diff = np.abs(np.subtract(base, test))
|
||||
max_absolute_diff = np.max(absolute_diff)
|
||||
|
||||
baseCopy = base.astype(np.float64).ravel()
|
||||
testCopy = test.astype(np.float64).ravel()
|
||||
upValue = np.sum(np.abs(baseCopy - testCopy))
|
||||
downValue = np.sum(np.abs(baseCopy)) + np.float64(1e-9)
|
||||
max_relative_diff = upValue / downValue
|
||||
print(f"Max absolute difference: {max_absolute_diff}\nMax relative difference: {max_relative_diff}")
|
||||
|
||||
return max_absolute_diff, max_relative_diff
|
||||
|
||||
|
||||
def main():
|
||||
nnodes, nproc_per_node, name, model_path, bs, length, gen_std, run_single = parse_args()
|
||||
|
||||
model = onnx.load(model_path)
|
||||
|
||||
# generate standart output
|
||||
if gen_std:
|
||||
print("Generate inputs and outputs.")
|
||||
p = mp.Process(target=generate_input_output, args=[model])
|
||||
p.start()
|
||||
p.join()
|
||||
return
|
||||
|
||||
# # run single process.
|
||||
# # use standalone process to isolate cuda.
|
||||
if run_single:
|
||||
print("run model by single GPU.")
|
||||
p = mp.Process(target=start_single, args=(name, model))
|
||||
p.start()
|
||||
p.join()
|
||||
return
|
||||
|
||||
# run distributed parallel.
|
||||
world_size = nnodes * nproc_per_node
|
||||
print(f"run model by {world_size} GPU in parallel.")
|
||||
workers = [
|
||||
mp.Process(
|
||||
target=start_worker,
|
||||
args=(name, world_size, rank, rank % nproc_per_node, model),
|
||||
)
|
||||
for rank in range(world_size)
|
||||
]
|
||||
|
||||
for w in workers:
|
||||
w.start()
|
||||
|
||||
for w in workers:
|
||||
w.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -80,6 +80,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
|
|||
|
||||
def shard_reshape(node: NodeProto):
|
||||
# print("reshape", node.name, node.input[0], place[node.input[0]])
|
||||
# import pdb; pdb.set_trace()
|
||||
if not is_sharded(node.input[0]):
|
||||
return
|
||||
in_plc = place[node.input[0]]
|
||||
|
@ -110,7 +111,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
|
|||
s_dim = 0
|
||||
elif in_plc.dim == 2:
|
||||
s_dim = 1
|
||||
|
||||
# import pdb; pdb.set_trace()
|
||||
assert s_dim != -1
|
||||
assert out_dims[s_dim] % tp_world_size == 0, out_dims
|
||||
out_dims[s_dim] //= tp_world_size
|
||||
|
@ -246,3 +247,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
|
|||
model = helper.make_model(graph)
|
||||
model = onnx.shape_inference.infer_shapes(model)
|
||||
return model
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = onnx.load("./models/gpt2/gpt2_1_100.onnx")
|
||||
models = parallel_model(model, 2, 0)
|
|
@ -947,7 +947,7 @@ class OnnxStub:
|
|||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Constant":
|
||||
elif node.op_type in ["Constant", "ConstantOfShape"]:
|
||||
output_name = node.output[0]
|
||||
attributes = _parse_attribute(node)
|
||||
tensor = attributes["value"]
|
||||
|
|
|
@ -20,9 +20,14 @@ class AllReduceXCCL : public KUNLUNKernelWithoutConfig {
|
|||
BKCLContext_t comm =
|
||||
dynamic_cast<XcclCommunicatorObj &>(context->getCommunicator())
|
||||
.getXcclComm();
|
||||
checkXcclError(bkcl_all_reduce(comm, input, output, count,
|
||||
BKCLDataType::BKCL_FLOAT, getRedOp(),
|
||||
0));
|
||||
double t = timeit(
|
||||
[&]() {
|
||||
checkXcclError(bkcl_all_reduce(comm, input, output, count,
|
||||
BKCLDataType::BKCL_FLOAT,
|
||||
getRedOp(), 0));
|
||||
},
|
||||
[&]() { context->sync(); });
|
||||
std::cout << "Time consuming for " << op->getInputs(0)->size() << " size is " << t << std::endl;
|
||||
}
|
||||
virtual BKCLOp getRedOp() const = 0;
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue