forked from jiuyuan/InfiniTensor
fix tensor parallel for llama (#159)
* fix Slice
* change default rounds of timeit to 10 to reduce time
* fix slice with large ends
* Reshape support Int64
* support position_ids as input
* skip last MatMul in Llama
* skip infer_shapes to parse large model
* update launch.py
* fix split_concat_kernel
* print more message in launch.py
* Reshape supports both Int32 and Int64
* try infer_shapes and warn about failure
* fix format

---------

Co-authored-by: whjthu <haojie0429@gmail.com>
parent 7f5188bedd
commit feccd4f318
@@ -5,6 +5,7 @@ import multiprocessing as mp
 from pyinfinitensor.onnx import OnnxStub, backend
 import onnx
 from onnx.external_data_helper import convert_model_to_external_data
+from onnx.shape_inference import infer_shapes_path
 import numpy as np
 from parallel_opt import parallel_model

@@ -44,16 +45,18 @@ def parse_args():
     )


-def run_model(model, runtime, inputs: np.array, n=20):
+def run_model(model, runtime, inputs, n=10):
     stub = OnnxStub(model, runtime)
-    next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs)
-    stub.tune()
+    for tensor, input in zip(stub.inputs.values(), inputs):
+        tensor.copyin_numpy(input)
+    # stub.tune()
     stub.run()
     # get outputs
-    outputs = np.array(next(stub.outputs.items().__iter__())[1].copyout_float())
+    outputs = next(stub.outputs.values().__iter__()).copyout_numpy()

     # bench
-    next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs)
+    for tensor, input in zip(stub.inputs.values(), inputs):
+        tensor.copyin_numpy(input)
     begin = time.time()
     for _ in range(n):
         stub.run()
@@ -64,13 +67,12 @@ def run_model(model, runtime, inputs: np.array, n=20):


 def run_and_compare(name, model, runtime):
-    data = np.load(f"{name}_inputs.npy")
+    input_ids = np.load(f"{name}_inputs.npy")
+    position_ids = np.arange(input_ids.shape[-1])
     results = np.load(f"{name}_results.npy")
-    outputs = run_model(model, runtime, data)
-    print("outputs sum:", outputs.sum())
-    print("max abs diff:", abs(outputs - results).max())
-    print("max rel diff:", abs((outputs - results) / results).max())
-    # assert np.allclose(outputs, results, rtol=1e-3, atol=1e-6)
+    outputs = run_model(model, runtime, (input_ids, position_ids))
+    print("outputs abs mean:", abs(outputs).mean())
+    np.testing.assert_allclose(outputs, results, rtol=1e-6, atol=1e-3)


 def start_worker(
@@ -81,14 +83,13 @@ def start_worker(
     extern_path = f"./{dist_name}_rank{rank}.pb"
     if os.path.exists(extern_path):
         os.remove(extern_path)
-    convert_model_to_external_data(
+    onnx.save_model(
         model,
-        all_tensors_to_one_file=True,
+        f"./{dist_name}_rank{rank}.onnx",
+        save_as_external_data=True,
         location=extern_path,
-        size_threshold=1024,
-        convert_attribute=False,
     )
-    onnx.save(model, f"./{dist_name}_rank{rank}.onnx")
+    infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
     runtime = backend.CudaRuntime(local_rank)
     # print("init comm")
     runtime.init_comm(
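For reference, a minimal standalone sketch of the save pattern in this hunk, assuming only the onnx package; the helper name and file names below are illustrative, not part of the repository:

    import onnx
    from onnx.shape_inference import infer_shapes_path

    def save_with_external_data(model: onnx.ModelProto, onnx_path: str, weights_file: str) -> None:
        # Keep raw tensor data in a separate file so the .onnx protobuf stays
        # below the 2 GB limit, then run shape inference on the saved file.
        onnx.save_model(
            model,
            onnx_path,
            save_as_external_data=True,
            location=weights_file,  # stored relative to onnx_path
        )
        infer_shapes_path(onnx_path)  # writes inferred value_info back into the file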
@@ -106,10 +107,12 @@ def start_single(name, model):

 def gen_standard(name, model, voc_size, bs, len):
     # generate standard results
-    data = np.random.randint(0, voc_size, (bs, len), dtype=np.int32)
-    np.save(f"{name}_inputs", data)
+    input_ids = np.random.randint(0, voc_size, (bs, len))
+    position_ids = np.arange(len)
+    np.save(f"{name}_inputs", input_ids)
     runtime = backend.CudaRuntime(0)
-    outputs = run_model(model, runtime, data, 1)
+    outputs = run_model(model, runtime, (input_ids, position_ids), 1)
+    print("outputs abs mean:", abs(outputs).mean())
     np.save(f"{name}_results", outputs)


@@ -128,12 +131,14 @@ def main():

     # run single process.
     # use standalone process to isolate cuda.
+    print("run model by single GPU.")
     p = mp.Process(target=start_single, args=(name, model))
     p.start()
     p.join()

     # run distributed parallel.
     world_size = nnodes * nproc_per_node
+    print(f"run model by {world_size} GPU in parallel.")
     workers = [
         mp.Process(
             target=start_worker,
@@ -11,6 +11,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
     vinfo = {info.name: info for info in model.graph.value_info}
     vinfo.update({info.name: info for info in model.graph.input})
     vinfo.update({info.name: info for info in model.graph.output})
+    output = {info.name: info for info in model.graph.output}
     place: Dict[str, Placement] = {}
     nodes: List[NodeProto] = []

@@ -56,7 +57,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
         ndim = len(vinfo[output].type.tensor_type.shape.dim)
         out_plc = Shard(ndim - 1) if in_plc.is_replicate() else _Partial()
         place[node.output[0]] = out_plc

     def shard_concat(node: NodeProto):
         # hack for kvcache
         in_plc = place[node.input[1]]
@@ -154,7 +155,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
             ), f"{place[node.input[0]]} != {place[node.input[1]]}"
             place[node.output[0]] = place[node.input[0]]
         elif node.op_type == "Concat":
             shard_concat(node)

     def find_successor(op_type: str, idx: int, search_limit: int = 1):
         for node in model.graph.node[idx + 1 : idx + 1 + search_limit]:
@@ -175,6 +176,9 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
         if (node.op_type == "MatMul" or node.op_type == "Gemm") and any(
             input in data for input in node.input
         ):
+            # FIXME(constroy): the last MatMul should not be sharded as TP.
+            if node.output[0] in output:
+                continue
             groups = 1
             # If the Gemm or Matmul is followed by a split, then the inputs are concatinated by groups
             split_node = find_successor("Split", index, search_limit=2)
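A standalone sketch (names illustrative) of the check introduced here: a MatMul or Gemm whose result is also a graph output is, in this Llama export, typically the final lm_head projection, and it is left unsharded so the full-vocabulary logits stay replicated on every rank:

    import onnx

    def produces_graph_output(node: onnx.NodeProto, model: onnx.ModelProto) -> bool:
        # The tensor-parallel pass skips such nodes instead of sharding them.
        graph_outputs = {info.name for info in model.graph.output}
        return node.op_type in ("MatMul", "Gemm") and node.output[0] in graph_outputs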
@@ -218,7 +222,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
     new_input = []
     for info in model.graph.input:
         new_input.append(vinfo[info.name])

     graph = helper.make_graph(
         nodes,
         model.graph.name + f"_{tp_rank}",
@@ -75,7 +75,7 @@ template <typename T> std::string vecToString(const std::vector<T> &vec) {

 double timeit(
     const std::function<void()> &func,
-    const std::function<void(void)> &sync = []() {}, int warmupRounds = 200,
-    int timingRounds = 200);
+    const std::function<void(void)> &sync = []() {}, int warmupRounds = 10,
+    int timingRounds = 10);

 } // namespace infini
@@ -28,6 +28,7 @@ from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
 from functools import reduce
 from onnxsim import simplify
 import copy
+import warnings


 class OnnxStub:
@@ -48,7 +49,10 @@ class OnnxStub:
         self.inputs: Dict[str, backend.Tensor] = {}
         self.outputs: Dict[str, backend.Tensor] = {}
         self.initializer: Dict[int, TensorProto] = {}
-        model = infer_shapes(model)
+        try:
+            model = infer_shapes(model)
+        except:
+            warnings.warn("infer_shapes failed.")
         self.handler = backend.GraphHandler(runtime)

         tensors: Dict[str, backend.Tensor] = dict()
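A standalone sketch of the guard added here, under the assumption that shape inference can raise (for example on models whose protobuf exceeds the 2 GB limit); the wrapper name is illustrative:

    import warnings
    import onnx
    from onnx.shape_inference import infer_shapes

    def try_infer_shapes(model: onnx.ModelProto) -> onnx.ModelProto:
        # Fall back to the original model rather than failing the whole import.
        try:
            return infer_shapes(model)
        except Exception:
            warnings.warn("infer_shapes failed.")
            return model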
@@ -603,15 +607,20 @@ class OnnxStub:
                     != 0,
                 )
             elif node.op_type == "Slice":
+
+                def clamp(nums):
+                    MAX_INT = 0x7FFFFFFF
+                    return [min(x, MAX_INT) for x in nums]
+
                 tensors[node.output[0]] = self.handler.slice(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
-                    _parse_data(data[node.input[1]]),
-                    _parse_data(data[node.input[2]]),
-                    _parse_data(data[node.input[3]])
+                    clamp(_parse_data(data[node.input[1]])),
+                    clamp(_parse_data(data[node.input[2]])),
+                    clamp(_parse_data(data[node.input[3]]))
                     if len(node.input) > 3
                     else None,
-                    _parse_data(data[node.input[4]])
+                    clamp(_parse_data(data[node.input[4]]))
                     if len(node.input) > 4
                     else None,
                 )
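A short, self-contained illustration of why the starts/ends are clamped: ONNX exporters commonly encode "slice to the end" as INT64_MAX, which would overflow the backend's 32-bit slice parameters unless capped first (values below are illustrative):

    INT32_MAX = 0x7FFFFFFF  # 2147483647

    def clamp(nums):
        return [min(int(x), INT32_MAX) for x in nums]

    ends = [9223372036854775807, 512]  # INT64_MAX means "up to the end of the axis"
    print(clamp(ends))                 # [2147483647, 512]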
@@ -58,6 +58,7 @@ class matmulCublas : public Kernel {
         SmallArray inputShape, outputShape;
         int nDims = out->getRank();
         IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
+        // FIXME(constroy): use size_t for outputsize.
         int outputsize = 1; // the length of the output vector after flatten
         int offset = nDims - inC->getRank();
         for (int i = 0; i < offset; ++i)
@@ -13,6 +13,10 @@ class CopyCuda : public CudaKernelWithoutConfig {
 // reshape/flatten/identity all act as copying from input to output.
 REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Float32, CopyCuda,
                 "Reshape_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Int64, CopyCuda,
+                "Reshape_CUDA_Int64");
+REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Int32, CopyCuda,
+                "Reshape_CUDA_Int32");
 REGISTER_KERNEL(Device::CUDA, OpType::Flatten, DataType::Float32, CopyCuda,
                 "Flatten_CUDA_Float32");
 REGISTER_KERNEL(Device::CUDA, OpType::Identity, DataType::Float32, CopyCuda,
@@ -51,13 +51,15 @@ __global__ void _split_concat_kernel(ElementTensorMetadata elemMeta,

 namespace infini {

 // TODO: when dim=0, the operation can be executed in-place
 void split_concat_kernel(const ElementTensorMetadata &eleMeta,
                          const ComposedTensorMetadata &compMeta, int dim,
                          int batchSize, int nDims, bool isSplit) {
     dim3 blockSize = dim3(32 * 16);
-    // gridsize =n_elements / blockSize
-    int gridDimX = (eleMeta.nElements[0] - 1) / (32 * 16) + 1;
+    // gridsize = max_n_elements / blockSize
+    int max_n_elements =
+        *std::max_element(eleMeta.nElements, eleMeta.nElements + batchSize);
+    int gridDimX = (max_n_elements - 1) / (32 * 16) + 1;
     // each y is a split among the batch
     dim3 gridSize(gridDimX, batchSize);

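A standalone sketch (plain Python, illustrative numbers) of the grid-sizing fix: the x-dimension of the grid must cover the largest segment in the batch, not just segment 0, otherwise larger segments are only partially processed:

    BLOCK = 32 * 16  # threads per block, as in the kernel

    def grid_dim_x(n_elements):
        max_n = max(n_elements)          # largest segment across the batch
        return (max_n - 1) // BLOCK + 1  # ceiling division

    print(grid_dim_x([1024, 4096, 2048]))  # 8; sizing from n_elements[0] alone would give 2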
@@ -2,10 +2,10 @@
 #include "utils/operator_utils.h"

 namespace infini {
-ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim)
-    : OperatorObj(OpType::Concat, inputs, {output}), dim(dim) {
+ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int _dim)
+    : OperatorObj(OpType::Concat, inputs, {output}) {
     int rank = inputs[0]->getRank();
-    dim = get_real_axis(dim, rank);
+    dim = get_real_axis(_dim, rank);
     IT_ASSERT(checkValid(graph));
 }

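The rename avoids the constructor parameter shadowing the `dim` member, so the member now receives the normalized axis instead of the raw one. A minimal Python sketch of the normalization that `get_real_axis` is assumed to perform (semantics inferred, not taken from the repository):

    def get_real_axis(axis: int, rank: int) -> int:
        # A negative ONNX axis counts back from the last dimension.
        assert -rank <= axis < rank
        return axis + rank if axis < 0 else axis

    print(get_real_axis(-1, 4))  # 3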
@@ -43,17 +43,22 @@ SliceObj::SliceObj(GraphObj *graph, Tensor input, Tensor output,

     auto size = shape.size();
     this->axes.reserve(size);
-    for (size_t i = 0; i < size; ++i)
+    for (size_t i = 0; i < size; ++i) {
+        auto len = shape[i];
         if (auto _i = axes.find(i); _i != axes.end()) {
             auto __i = _i->second;
             auto start = starts[__i];
             auto end = ends[__i];
-            this->axes.push_back({start >= 0 ? start : start + shape[__i],
-                                  end >= 0 ? end : end + shape[__i],
-                                  steps[__i]});
+            if (start > len)
+                start = len;
+            if (end > len)
+                end = len;
+            this->axes.push_back({start >= 0 ? start : start + len,
+                                  end >= 0 ? end : end + len, steps[__i]});
         } else {
-            this->axes.push_back({0, shape[i], 1});
+            this->axes.push_back({0, len, 1});
         }
+    }
     IT_ASSERT(checkValid(graph));
 }

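A standalone sketch mirroring the constructor logic above: out-of-range starts and ends are first clamped to the axis length, then negative values wrap around, which is what lets an oversized end simply mean "to the end of the axis" (function name illustrative):

    def slice_range(start: int, end: int, dim_len: int, step: int = 1):
        start = min(start, dim_len)
        end = min(end, dim_len)
        if start < 0:
            start += dim_len
        if end < 0:
            end += dim_len
        return start, end, step

    print(slice_range(0, 0x7FFFFFFF, 1024))  # (0, 1024, 1)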