fix tensor parallel for llama (#159)

* fix Slice

* change default rounds of timeit to 10 to reduce time

* fix slice with large ends

* Reshape support Int64

* support position_ids as input

* skip last MatMul in Llama

* skip infer_shapes to parse large model

* update launch.py

* fix split_concat_kernel

* print more message in launch.py

* Reshape supports both Int32 and Int64

* try infer_shapes and warn about failure

* fix format

---------

Co-authored-by: whjthu <haojie0429@gmail.com>
Author: constroy Li
Date: 2023-10-30 15:04:16 +08:00 (committed by GitHub)
Commit: feccd4f318
Parent: 7f5188bedd
9 changed files with 70 additions and 40 deletions


@@ -5,6 +5,7 @@ import multiprocessing as mp
 from pyinfinitensor.onnx import OnnxStub, backend
 import onnx
 from onnx.external_data_helper import convert_model_to_external_data
+from onnx.shape_inference import infer_shapes_path
 import numpy as np
 from parallel_opt import parallel_model
@@ -44,16 +45,18 @@ def parse_args():
     )


-def run_model(model, runtime, inputs: np.array, n=20):
+def run_model(model, runtime, inputs, n=10):
     stub = OnnxStub(model, runtime)
-    next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs)
-    stub.tune()
+    for tensor, input in zip(stub.inputs.values(), inputs):
+        tensor.copyin_numpy(input)
+    # stub.tune()
     stub.run()
     # get outputs
-    outputs = np.array(next(stub.outputs.items().__iter__())[1].copyout_float())
+    outputs = next(stub.outputs.values().__iter__()).copyout_numpy()
     # bench
-    next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs)
+    for tensor, input in zip(stub.inputs.values(), inputs):
+        tensor.copyin_numpy(input)
     begin = time.time()
     for _ in range(n):
         stub.run()
@@ -64,13 +67,12 @@ def run_model(model, runtime, inputs: np.array, n=20):


 def run_and_compare(name, model, runtime):
-    data = np.load(f"{name}_inputs.npy")
+    input_ids = np.load(f"{name}_inputs.npy")
+    position_ids = np.arange(input_ids.shape[-1])
     results = np.load(f"{name}_results.npy")
-    outputs = run_model(model, runtime, data)
-    print("outputs sum:", outputs.sum())
-    print("max abs diff:", abs(outputs - results).max())
-    print("max rel diff:", abs((outputs - results) / results).max())
-    # assert np.allclose(outputs, results, rtol=1e-3, atol=1e-6)
+    outputs = run_model(model, runtime, (input_ids, position_ids))
+    print("outputs abs mean:", abs(outputs).mean())
+    np.testing.assert_allclose(outputs, results, rtol=1e-6, atol=1e-3)


 def start_worker(
@@ -81,14 +83,13 @@ def start_worker(
     extern_path = f"./{dist_name}_rank{rank}.pb"
     if os.path.exists(extern_path):
         os.remove(extern_path)
-    convert_model_to_external_data(
+    onnx.save_model(
         model,
-        all_tensors_to_one_file=True,
+        f"./{dist_name}_rank{rank}.onnx",
+        save_as_external_data=True,
         location=extern_path,
-        size_threshold=1024,
-        convert_attribute=False,
     )
-    onnx.save(model, f"./{dist_name}_rank{rank}.onnx")
+    infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
     runtime = backend.CudaRuntime(local_rank)
     # print("init comm")
     runtime.init_comm(
@@ -106,10 +107,12 @@ def start_single(name, model):


 def gen_standard(name, model, voc_size, bs, len):
     # generate standard results
-    data = np.random.randint(0, voc_size, (bs, len), dtype=np.int32)
-    np.save(f"{name}_inputs", data)
+    input_ids = np.random.randint(0, voc_size, (bs, len))
+    position_ids = np.arange(len)
+    np.save(f"{name}_inputs", input_ids)
     runtime = backend.CudaRuntime(0)
-    outputs = run_model(model, runtime, data, 1)
+    outputs = run_model(model, runtime, (input_ids, position_ids), 1)
+    print("outputs abs mean:", abs(outputs).mean())
     np.save(f"{name}_results", outputs)
@@ -128,12 +131,14 @@ def main():

     # run single process.
     # use standalone process to isolate cuda.
+    print("run model by single GPU.")
     p = mp.Process(target=start_single, args=(name, model))
     p.start()
     p.join()

     # run distributed parallel.
     world_size = nnodes * nproc_per_node
+    print(f"run model by {world_size} GPU in parallel.")
     workers = [
         mp.Process(
             target=start_worker,
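With this change `run_model` copies one array into each graph input in order, which is what lets the parallel Llama graph take both `input_ids` and `position_ids`. A minimal usage sketch of that calling convention (the model path, vocabulary size, and shapes below are illustrative, not part of this commit):

```python
import numpy as np
import onnx
from pyinfinitensor.onnx import OnnxStub, backend

model = onnx.load("llama_rank0.onnx")  # hypothetical path to an exported model
batch_size, seq_len = 1, 64
input_ids = np.random.randint(0, 32000, (batch_size, seq_len))
position_ids = np.arange(seq_len)

stub = OnnxStub(model, backend.CudaRuntime(0))
# Inputs are matched to graph inputs by position, mirroring run_model() above.
for tensor, array in zip(stub.inputs.values(), (input_ids, position_ids)):
    tensor.copyin_numpy(array)
stub.run()
outputs = next(stub.outputs.values().__iter__()).copyout_numpy()
print("outputs abs mean:", abs(outputs).mean())
```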


@@ -11,6 +11,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
     vinfo = {info.name: info for info in model.graph.value_info}
     vinfo.update({info.name: info for info in model.graph.input})
     vinfo.update({info.name: info for info in model.graph.output})
+    output = {info.name: info for info in model.graph.output}
     place: Dict[str, Placement] = {}
     nodes: List[NodeProto] = []
@@ -56,7 +57,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
         ndim = len(vinfo[output].type.tensor_type.shape.dim)
         out_plc = Shard(ndim - 1) if in_plc.is_replicate() else _Partial()
         place[node.output[0]] = out_plc

     def shard_concat(node: NodeProto):
         # hack for kvcache
         in_plc = place[node.input[1]]
@@ -154,7 +155,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
            ), f"{place[node.input[0]]} != {place[node.input[1]]}"
            place[node.output[0]] = place[node.input[0]]
        elif node.op_type == "Concat":
            shard_concat(node)

    def find_successor(op_type: str, idx: int, search_limit: int = 1):
        for node in model.graph.node[idx + 1 : idx + 1 + search_limit]:
@ -175,6 +176,9 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
if (node.op_type == "MatMul" or node.op_type == "Gemm") and any( if (node.op_type == "MatMul" or node.op_type == "Gemm") and any(
input in data for input in node.input input in data for input in node.input
): ):
# FIXME(constroy): the last MatMul should not be sharded as TP.
if node.output[0] in output:
continue
groups = 1 groups = 1
# If the Gemm or Matmul is followed by a split, then the inputs are concatinated by groups # If the Gemm or Matmul is followed by a split, then the inputs are concatinated by groups
split_node = find_successor("Split", index, search_limit=2) split_node = find_successor("Split", index, search_limit=2)
@@ -218,7 +222,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
    new_input = []
    for info in model.graph.input:
        new_input.append(vinfo[info.name])

    graph = helper.make_graph(
        nodes,
        model.graph.name + f"_{tp_rank}",
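The `output` dict added above backs the "skip last MatMul" guard: a MatMul/Gemm whose result is a graph output (the lm_head projection in Llama) is left replicated instead of being sharded. A standalone sketch of that check, assuming an arbitrary exported decoder graph (the path is a placeholder):

```python
import onnx

model = onnx.load("llama.onnx")  # hypothetical path
output = {info.name: info for info in model.graph.output}

for node in model.graph.node:
    # Mirror of the guard in parallel_model(): MatMul/Gemm nodes that feed a
    # graph output directly are kept replicated rather than sharded.
    if node.op_type in ("MatMul", "Gemm") and node.output[0] in output:
        print("kept replicated:", node.name or node.output[0])
```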


@@ -75,7 +75,7 @@ template <typename T> std::string vecToString(const std::vector<T> &vec) {

 double timeit(
     const std::function<void()> &func,
-    const std::function<void(void)> &sync = []() {}, int warmupRounds = 200,
-    int timingRounds = 200);
+    const std::function<void(void)> &sync = []() {}, int warmupRounds = 10,
+    int timingRounds = 10);

 } // namespace infini


@@ -28,6 +28,7 @@ from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
 from functools import reduce
 from onnxsim import simplify
 import copy
+import warnings


 class OnnxStub:
@@ -48,7 +49,10 @@ class OnnxStub:
         self.inputs: Dict[str, backend.Tensor] = {}
         self.outputs: Dict[str, backend.Tensor] = {}
         self.initializer: Dict[int, TensorProto] = {}
-        model = infer_shapes(model)
+        try:
+            model = infer_shapes(model)
+        except:
+            warnings.warn("infer_shapes failed.")
         self.handler = backend.GraphHandler(runtime)

         tensors: Dict[str, backend.Tensor] = dict()
@@ -603,15 +607,20 @@ class OnnxStub:
                     != 0,
                 )
             elif node.op_type == "Slice":
+
+                def clamp(nums):
+                    MAX_INT = 0x7FFFFFFF
+                    return [min(x, MAX_INT) for x in nums]
+
                 tensors[node.output[0]] = self.handler.slice(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
-                    _parse_data(data[node.input[1]]),
-                    _parse_data(data[node.input[2]]),
-                    _parse_data(data[node.input[3]])
+                    clamp(_parse_data(data[node.input[1]])),
+                    clamp(_parse_data(data[node.input[2]])),
+                    clamp(_parse_data(data[node.input[3]]))
                     if len(node.input) > 3
                     else None,
-                    _parse_data(data[node.input[4]])
+                    clamp(_parse_data(data[node.input[4]]))
                     if len(node.input) > 4
                     else None,
                 )
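Exporters frequently encode "slice to the end" as INT64_MAX in the Slice `ends` tensor, which overflows the 32-bit integers the backend expects; `clamp` caps such values at `0x7FFFFFFF`, and the Slice operator itself later clamps them to the axis length. A minimal sketch of the behavior:

```python
# INT64_MAX as produced by exporters for "slice to the end of the axis".
INT64_MAX = 9223372036854775807
MAX_INT = 0x7FFFFFFF  # largest signed 32-bit value, as in clamp() above

def clamp(nums):
    return [min(x, MAX_INT) for x in nums]

print(clamp([INT64_MAX, 128]))  # [2147483647, 128]
```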


@@ -58,6 +58,7 @@ class matmulCublas : public Kernel {
         SmallArray inputShape, outputShape;
         int nDims = out->getRank();
         IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
+        // FIXME(constroy): use size_t for outputsize.
         int outputsize = 1; // the length of the output vector after flatten
         int offset = nDims - inC->getRank();
         for (int i = 0; i < offset; ++i)


@@ -13,6 +13,10 @@ class CopyCuda : public CudaKernelWithoutConfig {
 // reshape/flatten/identity all act as copying from input to output.
 REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Float32, CopyCuda,
                 "Reshape_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Int64, CopyCuda,
+                "Reshape_CUDA_Int64");
+REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Int32, CopyCuda,
+                "Reshape_CUDA_Int32");
 REGISTER_KERNEL(Device::CUDA, OpType::Flatten, DataType::Float32, CopyCuda,
                 "Flatten_CUDA_Float32");
 REGISTER_KERNEL(Device::CUDA, OpType::Identity, DataType::Float32, CopyCuda,


@@ -51,13 +51,15 @@ __global__ void _split_concat_kernel(ElementTensorMetadata elemMeta,
 namespace infini {

 // TODO: when dim=0, the operation can be executed in-place
 void split_concat_kernel(const ElementTensorMetadata &eleMeta,
                          const ComposedTensorMetadata &compMeta, int dim,
                          int batchSize, int nDims, bool isSplit) {
     dim3 blockSize = dim3(32 * 16);
-    // gridsize =n_elements / blockSize
-    int gridDimX = (eleMeta.nElements[0] - 1) / (32 * 16) + 1;
+    // gridsize = max_n_elements / blockSize
+    int max_n_elements =
+        *std::max_element(eleMeta.nElements, eleMeta.nElements + batchSize);
+    int gridDimX = (max_n_elements - 1) / (32 * 16) + 1;
     // each y is a split among the batch
     dim3 gridSize(gridDimX, batchSize);
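The old grid size was derived from the first tensor's element count, which launches too few blocks whenever a later tensor in the batch is larger; the fix derives it from the maximum. A worked example with illustrative numbers (Python used here only for the arithmetic):

```python
# Element counts of the tensors in one split/concat batch (illustrative).
n_elements = [4096, 8192, 8192]
block = 32 * 16  # 512 threads per block, as in the kernel

old_grid_x = (n_elements[0] - 1) // block + 1    # 8: too few blocks for 8192 elements
new_grid_x = (max(n_elements) - 1) // block + 1  # 16: covers the largest tensor
print(old_grid_x, new_grid_x)
```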


@@ -2,10 +2,10 @@
 #include "utils/operator_utils.h"

 namespace infini {

-ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim)
-    : OperatorObj(OpType::Concat, inputs, {output}), dim(dim) {
+ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int _dim)
+    : OperatorObj(OpType::Concat, inputs, {output}) {
     int rank = inputs[0]->getRank();
-    dim = get_real_axis(dim, rank);
+    dim = get_real_axis(_dim, rank);
     IT_ASSERT(checkValid(graph));
 }


@@ -43,17 +43,22 @@ SliceObj::SliceObj(GraphObj *graph, Tensor input, Tensor output,
     auto size = shape.size();
     this->axes.reserve(size);
-    for (size_t i = 0; i < size; ++i)
+    for (size_t i = 0; i < size; ++i) {
+        auto len = shape[i];
         if (auto _i = axes.find(i); _i != axes.end()) {
             auto __i = _i->second;
             auto start = starts[__i];
             auto end = ends[__i];
-            this->axes.push_back({start >= 0 ? start : start + shape[__i],
-                                  end >= 0 ? end : end + shape[__i],
-                                  steps[__i]});
+            if (start > len)
+                start = len;
+            if (end > len)
+                end = len;
+            this->axes.push_back({start >= 0 ? start : start + len,
+                                  end >= 0 ? end : end + len, steps[__i]});
         } else {
-            this->axes.push_back({0, shape[i], 1});
+            this->axes.push_back({0, len, 1});
         }
+    }
     IT_ASSERT(checkValid(graph));
 }
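For reference, a Python sketch (not the C++ API) of the per-axis normalization now performed above: start/end are first clamped to the axis length, then negative indices wrap around, which is what makes the huge "to the end" ends coming from ONNX Slice work:

```python
def normalize(start, end, step, length):
    # Clamp out-of-range values to the axis length, then resolve negatives,
    # mirroring the updated SliceObj constructor.
    start = min(start, length)
    end = min(end, length)
    if start < 0:
        start += length
    if end < 0:
        end += length
    return start, end, step

print(normalize(0, 2147483647, 1, 1024))   # (0, 1024, 1): whole axis
print(normalize(-4, 2147483647, 1, 1024))  # (1020, 1024, 1): last four elements
```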