forked from jiuyuan/InfiniTensor
fix tensor parallel for llama (#159)
* fix Slice
* change default rounds of timeit to 10 to reduce time
* fix slice with large ends
* Reshape support Int64
* support position_ids as input
* skip last MatMul in Llama
* skip infer_shapes to parse large model
* update launch.py
* fix split_concat_kernel
* print more message in launch.py
* Reshape supports both Int32 and Int64
* try infer_shapes and warn about failure
* fix format

---------

Co-authored-by: whjthu <haojie0429@gmail.com>
parent 7f5188bedd
commit feccd4f318
@@ -5,6 +5,7 @@ import multiprocessing as mp
 from pyinfinitensor.onnx import OnnxStub, backend
 import onnx
 from onnx.external_data_helper import convert_model_to_external_data
+from onnx.shape_inference import infer_shapes_path
 import numpy as np
 from parallel_opt import parallel_model
 
@@ -44,16 +45,18 @@ def parse_args():
     )
 
 
-def run_model(model, runtime, inputs: np.array, n=20):
+def run_model(model, runtime, inputs, n=10):
     stub = OnnxStub(model, runtime)
-    next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs)
-    stub.tune()
+    for tensor, input in zip(stub.inputs.values(), inputs):
+        tensor.copyin_numpy(input)
+    # stub.tune()
     stub.run()
     # get outputs
-    outputs = np.array(next(stub.outputs.items().__iter__())[1].copyout_float())
+    outputs = next(stub.outputs.values().__iter__()).copyout_numpy()
 
     # bench
-    next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs)
+    for tensor, input in zip(stub.inputs.values(), inputs):
+        tensor.copyin_numpy(input)
     begin = time.time()
     for _ in range(n):
         stub.run()
@@ -64,13 +67,12 @@ def run_model(model, runtime, inputs: np.array, n=20):
 
 
 def run_and_compare(name, model, runtime):
-    data = np.load(f"{name}_inputs.npy")
+    input_ids = np.load(f"{name}_inputs.npy")
+    position_ids = np.arange(input_ids.shape[-1])
     results = np.load(f"{name}_results.npy")
-    outputs = run_model(model, runtime, data)
-    print("outputs sum:", outputs.sum())
-    print("max abs diff:", abs(outputs - results).max())
-    print("max rel diff:", abs((outputs - results) / results).max())
-    # assert np.allclose(outputs, results, rtol=1e-3, atol=1e-6)
+    outputs = run_model(model, runtime, (input_ids, position_ids))
+    print("outputs abs mean:", abs(outputs).mean())
+    np.testing.assert_allclose(outputs, results, rtol=1e-6, atol=1e-3)
 
 
 def start_worker(
@@ -81,14 +83,13 @@ def start_worker(
     extern_path = f"./{dist_name}_rank{rank}.pb"
     if os.path.exists(extern_path):
         os.remove(extern_path)
-    convert_model_to_external_data(
+    onnx.save_model(
         model,
-        all_tensors_to_one_file=True,
+        f"./{dist_name}_rank{rank}.onnx",
+        save_as_external_data=True,
         location=extern_path,
-        size_threshold=1024,
-        convert_attribute=False,
     )
-    onnx.save(model, f"./{dist_name}_rank{rank}.onnx")
+    infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
     runtime = backend.CudaRuntime(local_rank)
     # print("init comm")
     runtime.init_comm(
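For reference, a minimal standalone sketch of the save-and-infer pattern used above; the file names (sharded.onnx, sharded.pb) are placeholders, not paths from this commit:

    import onnx
    from onnx.shape_inference import infer_shapes_path

    def save_large_model(model: onnx.ModelProto, path: str = "sharded.onnx",
                         extern: str = "sharded.pb") -> None:
        # onnx.save_model can externalize the weights directly, replacing the
        # older convert_model_to_external_data + onnx.save two-step sequence.
        onnx.save_model(
            model,
            path,
            save_as_external_data=True,  # keep tensors out of the protobuf
            location=extern,             # side file stored next to the .onnx
        )
        # Running shape inference on the saved path also works for models whose
        # in-memory protobuf would exceed the 2 GB limit.
        infer_shapes_path(path)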
@@ -106,10 +107,12 @@ def start_single(name, model):
 
 def gen_standard(name, model, voc_size, bs, len):
     # generate standard results
-    data = np.random.randint(0, voc_size, (bs, len), dtype=np.int32)
-    np.save(f"{name}_inputs", data)
+    input_ids = np.random.randint(0, voc_size, (bs, len))
+    position_ids = np.arange(len)
+    np.save(f"{name}_inputs", input_ids)
     runtime = backend.CudaRuntime(0)
-    outputs = run_model(model, runtime, data, 1)
+    outputs = run_model(model, runtime, (input_ids, position_ids), 1)
+    print("outputs abs mean:", abs(outputs).mean())
     np.save(f"{name}_results", outputs)
 
 
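As an illustration of the new input layout (shapes and vocabulary size below are made up), the model now receives token ids plus a matching position_ids vector instead of a single array:

    import numpy as np

    bs, seq_len, voc_size = 1, 8, 32000        # illustrative sizes only
    input_ids = np.random.randint(0, voc_size, (bs, seq_len))
    position_ids = np.arange(seq_len)          # 0, 1, ..., seq_len - 1

    # run_model copies the arrays into the stub's inputs in declaration order,
    # so both are handed over as one tuple.
    inputs = (input_ids, position_ids)
    print(input_ids.shape, position_ids.shape)  # (1, 8) (8,)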
@@ -128,12 +131,14 @@ def main():
 
     # run single process.
     # use standalone process to isolate cuda.
+    print("run model by single GPU.")
     p = mp.Process(target=start_single, args=(name, model))
     p.start()
     p.join()
 
     # run distributed parallel.
     world_size = nnodes * nproc_per_node
+    print(f"run model by {world_size} GPU in parallel.")
     workers = [
         mp.Process(
             target=start_worker,
@@ -11,6 +11,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
     vinfo = {info.name: info for info in model.graph.value_info}
     vinfo.update({info.name: info for info in model.graph.input})
     vinfo.update({info.name: info for info in model.graph.output})
+    output = {info.name: info for info in model.graph.output}
     place: Dict[str, Placement] = {}
     nodes: List[NodeProto] = []
 
@@ -56,7 +57,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
         ndim = len(vinfo[output].type.tensor_type.shape.dim)
         out_plc = Shard(ndim - 1) if in_plc.is_replicate() else _Partial()
         place[node.output[0]] = out_plc
 
     def shard_concat(node: NodeProto):
         # hack for kvcache
         in_plc = place[node.input[1]]
@@ -154,7 +155,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
             ), f"{place[node.input[0]]} != {place[node.input[1]]}"
             place[node.output[0]] = place[node.input[0]]
         elif node.op_type == "Concat":
-            shard_concat(node)
+            shard_concat(node)
 
     def find_successor(op_type: str, idx: int, search_limit: int = 1):
         for node in model.graph.node[idx + 1 : idx + 1 + search_limit]:
@@ -175,6 +176,9 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
         if (node.op_type == "MatMul" or node.op_type == "Gemm") and any(
             input in data for input in node.input
         ):
+            # FIXME(constroy): the last MatMul should not be sharded as TP.
+            if node.output[0] in output:
+                continue
             groups = 1
             # If the Gemm or Matmul is followed by a split, then the inputs are concatinated by groups
             split_node = find_successor("Split", index, search_limit=2)
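A rough sketch of the guard added above: collect the graph's output names and leave any MatMul/Gemm that produces one of them unsharded (the helper name is illustrative):

    from onnx import ModelProto

    def shardable_linear_nodes(model: ModelProto):
        # Names produced by the whole graph, e.g. the final logits of the lm_head.
        graph_outputs = {info.name for info in model.graph.output}
        for node in model.graph.node:
            if node.op_type not in ("MatMul", "Gemm"):
                continue
            if node.output[0] in graph_outputs:
                # The last MatMul feeds a graph output directly; sharding it
                # column-wise would need an extra all-gather, so skip it.
                continue
            yield node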
@@ -218,7 +222,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
     new_input = []
     for info in model.graph.input:
         new_input.append(vinfo[info.name])
 
     graph = helper.make_graph(
         nodes,
         model.graph.name + f"_{tp_rank}",
@@ -75,7 +75,7 @@ template <typename T> std::string vecToString(const std::vector<T> &vec) {
 
 double timeit(
     const std::function<void()> &func,
-    const std::function<void(void)> &sync = []() {}, int warmupRounds = 200,
-    int timingRounds = 200);
+    const std::function<void(void)> &sync = []() {}, int warmupRounds = 10,
+    int timingRounds = 10);
 
 } // namespace infini
@@ -28,6 +28,7 @@ from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
 from functools import reduce
 from onnxsim import simplify
 import copy
+import warnings
 
 
 class OnnxStub:
@@ -48,7 +49,10 @@ class OnnxStub:
         self.inputs: Dict[str, backend.Tensor] = {}
         self.outputs: Dict[str, backend.Tensor] = {}
         self.initializer: Dict[int, TensorProto] = {}
-        model = infer_shapes(model)
+        try:
+            model = infer_shapes(model)
+        except:
+            warnings.warn("infer_shapes failed.")
         self.handler = backend.GraphHandler(runtime)
 
         tensors: Dict[str, backend.Tensor] = dict()
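The guarded call matters for large checkpoints: onnx.shape_inference.infer_shapes works on the in-memory protobuf and raises once a model crosses the 2 GB protobuf limit (or hits an op it cannot infer), and the graph is still loadable without the inferred shapes. A hedged sketch of the same pattern, catching Exception instead of using a bare except:

    import warnings
    from onnx.shape_inference import infer_shapes

    def try_infer_shapes(model):
        # Best effort only: fall back to the original model on failure.
        try:
            return infer_shapes(model)
        except Exception as e:
            warnings.warn(f"infer_shapes failed: {e}")
            return model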
@@ -603,15 +607,20 @@ class OnnxStub:
                     != 0,
                 )
             elif node.op_type == "Slice":
+
+                def clamp(nums):
+                    MAX_INT = 0x7FFFFFFF
+                    return [min(x, MAX_INT) for x in nums]
+
                 tensors[node.output[0]] = self.handler.slice(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
-                    _parse_data(data[node.input[1]]),
-                    _parse_data(data[node.input[2]]),
-                    _parse_data(data[node.input[3]])
+                    clamp(_parse_data(data[node.input[1]])),
+                    clamp(_parse_data(data[node.input[2]])),
+                    clamp(_parse_data(data[node.input[3]]))
                     if len(node.input) > 3
                     else None,
-                    _parse_data(data[node.input[4]])
+                    clamp(_parse_data(data[node.input[4]]))
                     if len(node.input) > 4
                     else None,
                 )
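Exporters often encode "slice to the end" as a huge 64-bit sentinel in the ends tensor; clamping to INT32_MAX keeps such values representable for a 32-bit slice implementation. A small worked example (values are illustrative):

    MAX_INT = 0x7FFFFFFF  # 2147483647

    def clamp(nums):
        return [min(x, MAX_INT) for x in nums]

    ends = [9223372036854775807, 128]  # INT64_MAX sentinel plus a real bound
    print(clamp(ends))                 # [2147483647, 128]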
@@ -58,6 +58,7 @@ class matmulCublas : public Kernel {
         SmallArray inputShape, outputShape;
         int nDims = out->getRank();
         IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
+        // FIXME(constroy): use size_t for outputsize.
         int outputsize = 1; // the length of the output vector after flatten
         int offset = nDims - inC->getRank();
         for (int i = 0; i < offset; ++i)
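The FIXME above is about overflow: the flattened output length is accumulated in an int, and a large batched matmul output can exceed INT32_MAX elements. A quick check with illustrative dimensions:

    import math

    dims = [8, 32, 4096, 4096]       # hypothetical batched matmul output shape
    outputsize = math.prod(dims)     # 4_294_967_296 elements
    print(outputsize > 2**31 - 1)    # True: this would wrap around in an int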
@@ -13,6 +13,10 @@ class CopyCuda : public CudaKernelWithoutConfig {
 // reshape/flatten/identity all act as copying from input to output.
 REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Float32, CopyCuda,
                 "Reshape_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Int64, CopyCuda,
+                "Reshape_CUDA_Int64");
+REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Int32, CopyCuda,
+                "Reshape_CUDA_Int32");
 REGISTER_KERNEL(Device::CUDA, OpType::Flatten, DataType::Float32, CopyCuda,
                 "Flatten_CUDA_Float32");
 REGISTER_KERNEL(Device::CUDA, OpType::Identity, DataType::Float32, CopyCuda,
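The extra registrations exist because Llama-style graphs reshape integer tensors (token and position ids are INT64 in ONNX), not only floats. A minimal, hypothetical graph that exercises an INT64 Reshape:

    import onnx
    from onnx import TensorProto, helper

    # Reshape an int64 position_ids tensor from (1, 8) to (8,); all names are
    # illustrative.
    new_shape = helper.make_tensor("new_shape", TensorProto.INT64, dims=[1], vals=[8])
    reshape = helper.make_node("Reshape", ["position_ids", "new_shape"], ["flat_ids"])
    graph = helper.make_graph(
        [reshape],
        "reshape_int64_demo",
        inputs=[helper.make_tensor_value_info("position_ids", TensorProto.INT64, [1, 8])],
        outputs=[helper.make_tensor_value_info("flat_ids", TensorProto.INT64, [8])],
        initializer=[new_shape],
    )
    onnx.checker.check_model(helper.make_model(graph))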
@@ -51,13 +51,15 @@ __global__ void _split_concat_kernel(ElementTensorMetadata elemMeta,
 
 namespace infini {
 
-// TODO: when dim=0, the operation can be executed in-place
+// TODO: when dim=0, the operation can be executed in-place
 void split_concat_kernel(const ElementTensorMetadata &eleMeta,
                          const ComposedTensorMetadata &compMeta, int dim,
                          int batchSize, int nDims, bool isSplit) {
     dim3 blockSize = dim3(32 * 16);
-    // gridsize =n_elements / blockSize
-    int gridDimX = (eleMeta.nElements[0] - 1) / (32 * 16) + 1;
+    // gridsize = max_n_elements / blockSize
+    int max_n_elements =
+        *std::max_element(eleMeta.nElements, eleMeta.nElements + batchSize);
+    int gridDimX = (max_n_elements - 1) / (32 * 16) + 1;
     // each y is a split among the batch
     dim3 gridSize(gridDimX, batchSize);
 
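The old grid size only looked at nElements[0], so whenever a later split in the batch was larger the kernel launched too few blocks for it. A quick check of the arithmetic with made-up element counts:

    BLOCK = 32 * 16  # 512 threads per block, matching the kernel above

    n_elements = [4096, 10240, 4096]                 # per-split counts (illustrative)
    old_grid_x = (n_elements[0] - 1) // BLOCK + 1    # 8 blocks: too few for split 1
    new_grid_x = (max(n_elements) - 1) // BLOCK + 1  # 20 blocks: covers every split
    print(old_grid_x, new_grid_x)                    # 8 20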
@@ -2,10 +2,10 @@
 #include "utils/operator_utils.h"
 
 namespace infini {
-ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim)
-    : OperatorObj(OpType::Concat, inputs, {output}), dim(dim) {
+ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int _dim)
+    : OperatorObj(OpType::Concat, inputs, {output}) {
     int rank = inputs[0]->getRank();
-    dim = get_real_axis(dim, rank);
+    dim = get_real_axis(_dim, rank);
     IT_ASSERT(checkValid(graph));
 }
 
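get_real_axis normalizes a possibly negative ONNX axis against the tensor rank before it is stored, so the constructor now routes the raw _dim argument through it instead of keeping the value verbatim. The assumed behaviour, sketched in Python:

    def get_real_axis(axis: int, rank: int) -> int:
        # Assumed contract of the C++ helper: axis may lie in [-rank, rank).
        assert -rank <= axis < rank, f"axis {axis} out of range for rank {rank}"
        return axis + rank if axis < 0 else axis

    print(get_real_axis(-1, 4))  # 3, i.e. concat along the last dimension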
@@ -43,17 +43,22 @@ SliceObj::SliceObj(GraphObj *graph, Tensor input, Tensor output,
 
     auto size = shape.size();
     this->axes.reserve(size);
-    for (size_t i = 0; i < size; ++i)
+    for (size_t i = 0; i < size; ++i) {
+        auto len = shape[i];
         if (auto _i = axes.find(i); _i != axes.end()) {
             auto __i = _i->second;
             auto start = starts[__i];
             auto end = ends[__i];
-            this->axes.push_back({start >= 0 ? start : start + shape[__i],
-                                  end >= 0 ? end : end + shape[__i],
-                                  steps[__i]});
+            if (start > len)
+                start = len;
+            if (end > len)
+                end = len;
+            this->axes.push_back({start >= 0 ? start : start + len,
+                                  end >= 0 ? end : end + len, steps[__i]});
         } else {
-            this->axes.push_back({0, shape[i], 1});
+            this->axes.push_back({0, len, 1});
         }
+    }
     IT_ASSERT(checkValid(graph));
 }
 
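A worked example of the bound fix (numbers are illustrative): with an axis of length 10 and an exporter-style huge end value, the old arithmetic produced an end far past the axis, while clamping to len first gives the intended full-range slice.

    def normalize_slice(start, end, step, length):
        # Clamp first (handles INT32_MAX / INT64_MAX "to the end" sentinels),
        # then resolve negative indices, as the constructor above now does.
        start = min(start, length)
        end = min(end, length)
        start = start if start >= 0 else start + length
        end = end if end >= 0 else end + length
        return start, end, step

    print(normalize_slice(0, 2**31 - 1, 1, 10))  # (0, 10, 1)
    print(normalize_slice(-3, 10, 1, 10))        # (7, 10, 1)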