fix tensor parallel for llama (#159)

* fix Slice

* change default rounds of timeit to 10 to reduce time

* fix slice with large ends

* Reshape support Int64

* support position_ids as input

* skip last MatMul in Llama

* skip infer_shapes to parse large model

* update launch.py

* fix split_concat_kernel

* print more message in launch.py

* Reshape supports both Int32 and Int64

* try infer_shapes and warn about failure

* fix format

---------

Co-authored-by: whjthu <haojie0429@gmail.com>
This commit is contained in:
constroy Li 2023-10-30 15:04:16 +08:00 committed by GitHub
parent 7f5188bedd
commit feccd4f318
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 70 additions and 40 deletions

View File

@ -5,6 +5,7 @@ import multiprocessing as mp
from pyinfinitensor.onnx import OnnxStub, backend
import onnx
from onnx.external_data_helper import convert_model_to_external_data
from onnx.shape_inference import infer_shapes_path
import numpy as np
from parallel_opt import parallel_model
@ -44,16 +45,18 @@ def parse_args():
)
def run_model(model, runtime, inputs: np.array, n=20):
def run_model(model, runtime, inputs, n=10):
stub = OnnxStub(model, runtime)
next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs)
stub.tune()
for tensor, input in zip(stub.inputs.values(), inputs):
tensor.copyin_numpy(input)
# stub.tune()
stub.run()
# get outputs
outputs = np.array(next(stub.outputs.items().__iter__())[1].copyout_float())
outputs = next(stub.outputs.values().__iter__()).copyout_numpy()
# bench
next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs)
for tensor, input in zip(stub.inputs.values(), inputs):
tensor.copyin_numpy(input)
begin = time.time()
for _ in range(n):
stub.run()
@ -64,13 +67,12 @@ def run_model(model, runtime, inputs: np.array, n=20):
def run_and_compare(name, model, runtime):
data = np.load(f"{name}_inputs.npy")
input_ids = np.load(f"{name}_inputs.npy")
position_ids = np.arange(input_ids.shape[-1])
results = np.load(f"{name}_results.npy")
outputs = run_model(model, runtime, data)
print("outputs sum:", outputs.sum())
print("max abs diff:", abs(outputs - results).max())
print("max rel diff:", abs((outputs - results) / results).max())
# assert np.allclose(outputs, results, rtol=1e-3, atol=1e-6)
outputs = run_model(model, runtime, (input_ids, position_ids))
print("outputs abs mean:", abs(outputs).mean())
np.testing.assert_allclose(outputs, results, rtol=1e-6, atol=1e-3)
def start_worker(
@ -81,14 +83,13 @@ def start_worker(
extern_path = f"./{dist_name}_rank{rank}.pb"
if os.path.exists(extern_path):
os.remove(extern_path)
convert_model_to_external_data(
onnx.save_model(
model,
all_tensors_to_one_file=True,
f"./{dist_name}_rank{rank}.onnx",
save_as_external_data=True,
location=extern_path,
size_threshold=1024,
convert_attribute=False,
)
onnx.save(model, f"./{dist_name}_rank{rank}.onnx")
infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
runtime = backend.CudaRuntime(local_rank)
# print("init comm")
runtime.init_comm(
@ -106,10 +107,12 @@ def start_single(name, model):
def gen_standard(name, model, voc_size, bs, len):
# generate standard results
data = np.random.randint(0, voc_size, (bs, len), dtype=np.int32)
np.save(f"{name}_inputs", data)
input_ids = np.random.randint(0, voc_size, (bs, len))
position_ids = np.arange(len)
np.save(f"{name}_inputs", input_ids)
runtime = backend.CudaRuntime(0)
outputs = run_model(model, runtime, data, 1)
outputs = run_model(model, runtime, (input_ids, position_ids), 1)
print("outputs abs mean:", abs(outputs).mean())
np.save(f"{name}_results", outputs)
@ -128,12 +131,14 @@ def main():
# run single process.
# use standalone process to isolate cuda.
print("run model by single GPU.")
p = mp.Process(target=start_single, args=(name, model))
p.start()
p.join()
# run distributed parallel.
world_size = nnodes * nproc_per_node
print(f"run model by {world_size} GPU in parallel.")
workers = [
mp.Process(
target=start_worker,

View File

@ -11,6 +11,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
vinfo = {info.name: info for info in model.graph.value_info}
vinfo.update({info.name: info for info in model.graph.input})
vinfo.update({info.name: info for info in model.graph.output})
output = {info.name: info for info in model.graph.output}
place: Dict[str, Placement] = {}
nodes: List[NodeProto] = []
@ -56,7 +57,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
ndim = len(vinfo[output].type.tensor_type.shape.dim)
out_plc = Shard(ndim - 1) if in_plc.is_replicate() else _Partial()
place[node.output[0]] = out_plc
def shard_concat(node: NodeProto):
# hack for kvcache
in_plc = place[node.input[1]]
@ -154,7 +155,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
), f"{place[node.input[0]]} != {place[node.input[1]]}"
place[node.output[0]] = place[node.input[0]]
elif node.op_type == "Concat":
shard_concat(node)
shard_concat(node)
def find_successor(op_type: str, idx: int, search_limit: int = 1):
for node in model.graph.node[idx + 1 : idx + 1 + search_limit]:
@ -175,6 +176,9 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
if (node.op_type == "MatMul" or node.op_type == "Gemm") and any(
input in data for input in node.input
):
# FIXME(constroy): the last MatMul should not be sharded as TP.
if node.output[0] in output:
continue
groups = 1
# If the Gemm or Matmul is followed by a split, then the inputs are concatinated by groups
split_node = find_successor("Split", index, search_limit=2)
@ -218,7 +222,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
new_input = []
for info in model.graph.input:
new_input.append(vinfo[info.name])
graph = helper.make_graph(
nodes,
model.graph.name + f"_{tp_rank}",

View File

@ -75,7 +75,7 @@ template <typename T> std::string vecToString(const std::vector<T> &vec) {
double timeit(
const std::function<void()> &func,
const std::function<void(void)> &sync = []() {}, int warmupRounds = 200,
int timingRounds = 200);
const std::function<void(void)> &sync = []() {}, int warmupRounds = 10,
int timingRounds = 10);
} // namespace infini

View File

@ -28,6 +28,7 @@ from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
from functools import reduce
from onnxsim import simplify
import copy
import warnings
class OnnxStub:
@ -48,7 +49,10 @@ class OnnxStub:
self.inputs: Dict[str, backend.Tensor] = {}
self.outputs: Dict[str, backend.Tensor] = {}
self.initializer: Dict[int, TensorProto] = {}
model = infer_shapes(model)
try:
model = infer_shapes(model)
except:
warnings.warn("infer_shapes failed.")
self.handler = backend.GraphHandler(runtime)
tensors: Dict[str, backend.Tensor] = dict()
@ -603,15 +607,20 @@ class OnnxStub:
!= 0,
)
elif node.op_type == "Slice":
def clamp(nums):
MAX_INT = 0x7FFFFFFF
return [min(x, MAX_INT) for x in nums]
tensors[node.output[0]] = self.handler.slice(
tensors[node.input[0]],
tensors.get(node.output[0]),
_parse_data(data[node.input[1]]),
_parse_data(data[node.input[2]]),
_parse_data(data[node.input[3]])
clamp(_parse_data(data[node.input[1]])),
clamp(_parse_data(data[node.input[2]])),
clamp(_parse_data(data[node.input[3]]))
if len(node.input) > 3
else None,
_parse_data(data[node.input[4]])
clamp(_parse_data(data[node.input[4]]))
if len(node.input) > 4
else None,
)

View File

@ -58,6 +58,7 @@ class matmulCublas : public Kernel {
SmallArray inputShape, outputShape;
int nDims = out->getRank();
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
// FIXME(constroy): use size_t for outputsize.
int outputsize = 1; // the length of the output vector after flatten
int offset = nDims - inC->getRank();
for (int i = 0; i < offset; ++i)

View File

@ -13,6 +13,10 @@ class CopyCuda : public CudaKernelWithoutConfig {
// reshape/flatten/identity all act as copying from input to output.
REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Float32, CopyCuda,
"Reshape_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Int64, CopyCuda,
"Reshape_CUDA_Int64");
REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Int32, CopyCuda,
"Reshape_CUDA_Int32");
REGISTER_KERNEL(Device::CUDA, OpType::Flatten, DataType::Float32, CopyCuda,
"Flatten_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Identity, DataType::Float32, CopyCuda,

View File

@ -51,13 +51,15 @@ __global__ void _split_concat_kernel(ElementTensorMetadata elemMeta,
namespace infini {
// TODO: when dim=0, the operation can be executed in-place
// TODO: when dim=0, the operation can be executed in-place
void split_concat_kernel(const ElementTensorMetadata &eleMeta,
const ComposedTensorMetadata &compMeta, int dim,
int batchSize, int nDims, bool isSplit) {
dim3 blockSize = dim3(32 * 16);
// gridsize =n_elements / blockSize
int gridDimX = (eleMeta.nElements[0] - 1) / (32 * 16) + 1;
// gridsize = max_n_elements / blockSize
int max_n_elements =
*std::max_element(eleMeta.nElements, eleMeta.nElements + batchSize);
int gridDimX = (max_n_elements - 1) / (32 * 16) + 1;
// each y is a split among the batch
dim3 gridSize(gridDimX, batchSize);

View File

@ -2,10 +2,10 @@
#include "utils/operator_utils.h"
namespace infini {
ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim)
: OperatorObj(OpType::Concat, inputs, {output}), dim(dim) {
ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int _dim)
: OperatorObj(OpType::Concat, inputs, {output}) {
int rank = inputs[0]->getRank();
dim = get_real_axis(dim, rank);
dim = get_real_axis(_dim, rank);
IT_ASSERT(checkValid(graph));
}

View File

@ -43,17 +43,22 @@ SliceObj::SliceObj(GraphObj *graph, Tensor input, Tensor output,
auto size = shape.size();
this->axes.reserve(size);
for (size_t i = 0; i < size; ++i)
for (size_t i = 0; i < size; ++i) {
auto len = shape[i];
if (auto _i = axes.find(i); _i != axes.end()) {
auto __i = _i->second;
auto start = starts[__i];
auto end = ends[__i];
this->axes.push_back({start >= 0 ? start : start + shape[__i],
end >= 0 ? end : end + shape[__i],
steps[__i]});
if (start > len)
start = len;
if (end > len)
end = len;
this->axes.push_back({start >= 0 ? start : start + len,
end >= 0 ? end : end + len, steps[__i]});
} else {
this->axes.push_back({0, shape[i], 1});
this->axes.push_back({0, len, 1});
}
}
IT_ASSERT(checkValid(graph));
}