forked from jiuyuan/InfiniTensor
tensor parallel for transformer (#125)
* add cmake bits about NCCL * move example to examples/NNmodel * impl NCCL communicator * add comm related function to Runtime * export runtime interface * add launch.py * use unique name to distingush the the NCCL ID file * add timeout to communicator init * expose communicator obj from runtime obj, add unit test for nccl communicator * reformat files * Add allReduce operator and cuda nccl allReduce kernel * impl model parallel for resnet * add allGather nccl kernel and operator * Add allreduce allgather operator tests, change allgather kernel to output list of tensor, fix shape infer, handle nullptr output * fix format of onnx.py * use concat following AllGather * get tensor parallel for resnet * fix format of graph_handler.cc * change BUILD_DIST default to OFF * polish code of communicator * update .gitignore * export min/max to python * fix MatMul * modify launch.py to run opt * hack to treat ReduceSum as AllReduceSum * throw exception in cuda error * fix parallel_opt.py * improve the error prompt and cuda error check * fix GatherObj::GatherObj member init * fix size calculation for scalar (rank = 0) tensor * MatMul supports bias * fix add bias for row parallel gemm * add --gen_std to launch.py * fix AllReduceNCCL * update launch.py * less log * update parallel_opt * update launch.py * add __eq__ for Placement sub-classes * less benchmark run * fix placement infer for matmul * fix vacabuary size * fix Exception * Add shard tensor with group to support gpt2 * Add find successor function to find split op at different depth * recover CommunicatorObj * improve error mesasge * optimize parallel_opt.py * optimize launch.py * recover docs for all_reduce and all_gather * Fix API * fix format --------- Co-authored-by: panzezhong <panzezhong@qiyuanlab.com> Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
parent
dda668fd16
commit
4c321c8a91
|
@ -4,8 +4,12 @@ import time
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
from pyinfinitensor.onnx import OnnxStub, backend
|
from pyinfinitensor.onnx import OnnxStub, backend
|
||||||
import onnx
|
import onnx
|
||||||
|
from onnx.external_data_helper import convert_model_to_external_data
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from parallel import parallel_model
|
from parallel_opt import parallel_model
|
||||||
|
|
||||||
|
|
||||||
|
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
|
@ -14,77 +18,126 @@ def parse_args():
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--nproc_per_node", type=int, default=1, help="number of processes per node"
|
"--nproc_per_node", type=int, default=1, help="number of processes per node"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--name", type=str, default="test", help="name of this instance."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model", type=str, required=True, help="path to the ONNX model file."
|
"--model", type=str, required=True, help="path to the ONNX model file."
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
|
||||||
|
parser.add_argument("--length", type=int, default=1, help="sequence length.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--gen_std",
|
||||||
|
action="store_true",
|
||||||
|
help="whether to generate the standard results.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
print("arg setting: ", args)
|
print("arg setting: ", args)
|
||||||
return args.num_nodes, args.nproc_per_node, args.model
|
return (
|
||||||
|
args.num_nodes,
|
||||||
|
args.nproc_per_node,
|
||||||
|
args.name,
|
||||||
|
args.model,
|
||||||
|
args.batch_size,
|
||||||
|
args.length,
|
||||||
|
args.gen_std,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def run_stub(stub: OnnxStub, inputs: np.array, n=100):
|
def run_model(model, runtime, inputs: np.array, n=20):
|
||||||
# warm up
|
stub = OnnxStub(model, runtime)
|
||||||
next(stub.inputs.items().__iter__())[1].copyin_float(inputs.reshape(-1).tolist())
|
next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs)
|
||||||
stub.tune()
|
stub.tune()
|
||||||
for _ in range(20):
|
|
||||||
stub.run()
|
stub.run()
|
||||||
|
# get outputs
|
||||||
outputs = np.array(next(stub.outputs.items().__iter__())[1].copyout_float())
|
outputs = np.array(next(stub.outputs.items().__iter__())[1].copyout_float())
|
||||||
|
|
||||||
# bench
|
# bench
|
||||||
next(stub.inputs.items().__iter__())[1].copyin_float(inputs.reshape(-1).tolist())
|
next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs)
|
||||||
begin = time.time()
|
begin = time.time()
|
||||||
for _ in range(n):
|
for _ in range(n):
|
||||||
stub.run()
|
stub.run()
|
||||||
end = time.time()
|
end = time.time()
|
||||||
outputs = np.array(next(stub.outputs.items().__iter__())[1].copyout_float())
|
|
||||||
print("outputs sum:", outputs.sum())
|
|
||||||
# np.save("results", outputs)
|
|
||||||
results = np.load("results.npy")
|
|
||||||
print("max diff:", abs(outputs - results).max())
|
|
||||||
assert np.allclose(outputs, results, rtol=1e-6, atol=1e-6)
|
|
||||||
avg_time = (end - begin) / n
|
avg_time = (end - begin) / n
|
||||||
return avg_time
|
print(f"average time: {avg_time}")
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
def run_and_compare(name, model, runtime):
|
||||||
|
data = np.load(f"{name}_inputs.npy")
|
||||||
|
results = np.load(f"{name}_results.npy")
|
||||||
|
outputs = run_model(model, runtime, data)
|
||||||
|
print("outputs sum:", outputs.sum())
|
||||||
|
print("max abs diff:", abs(outputs - results).max())
|
||||||
|
print("max rel diff:", abs((outputs - results) / results).max())
|
||||||
|
# assert np.allclose(outputs, results, rtol=1e-3, atol=1e-6)
|
||||||
|
|
||||||
|
|
||||||
def start_worker(
|
def start_worker(
|
||||||
dist_name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
|
name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
|
||||||
):
|
):
|
||||||
print("start worker")
|
dist_name = name + "_dist"
|
||||||
|
model = parallel_model(model, world_size, rank)
|
||||||
|
extern_path = f"./{dist_name}_rank{rank}.pb"
|
||||||
|
if os.path.exists(extern_path):
|
||||||
|
os.remove(extern_path)
|
||||||
|
convert_model_to_external_data(
|
||||||
|
model,
|
||||||
|
all_tensors_to_one_file=True,
|
||||||
|
location=extern_path,
|
||||||
|
size_threshold=1024,
|
||||||
|
convert_attribute=False,
|
||||||
|
)
|
||||||
|
onnx.save(model, f"./{dist_name}_rank{rank}.onnx")
|
||||||
runtime = backend.CudaRuntime(local_rank)
|
runtime = backend.CudaRuntime(local_rank)
|
||||||
print("init comm")
|
# print("init comm")
|
||||||
runtime.init_comm(
|
runtime.init_comm(
|
||||||
dist_name,
|
dist_name,
|
||||||
world_size,
|
world_size,
|
||||||
rank,
|
rank,
|
||||||
)
|
)
|
||||||
model = parallel_model(model, world_size, rank)
|
run_and_compare(name, model, runtime)
|
||||||
onnx.save(model, f"dist_model_rank{rank}.onnx")
|
|
||||||
print("load model")
|
|
||||||
stub = OnnxStub(model, runtime)
|
def start_single(name, model):
|
||||||
data = np.load("inputs.npy")
|
runtime = backend.CudaRuntime(0)
|
||||||
print("run model")
|
run_and_compare(name, model, runtime)
|
||||||
avg_time = run_stub(stub, data)
|
|
||||||
print(f"average time: {avg_time}")
|
|
||||||
|
def gen_standard(name, model, voc_size, bs, len):
|
||||||
|
# generate standard results
|
||||||
|
data = np.random.randint(0, voc_size, (bs, len), dtype=np.int32)
|
||||||
|
np.save(f"{name}_inputs", data)
|
||||||
|
runtime = backend.CudaRuntime(0)
|
||||||
|
outputs = run_model(model, runtime, data, 1)
|
||||||
|
np.save(f"{name}_results", outputs)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
nnodes, nproc_per_node, model_path = parse_args()
|
nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args()
|
||||||
world_size = nnodes * nproc_per_node
|
|
||||||
|
|
||||||
model = onnx.load(model_path)
|
model = onnx.load(model_path)
|
||||||
# generate standard results
|
|
||||||
# runtime = backend.CudaRuntime(0)
|
|
||||||
# stub = OnnxStub(model, runtime)
|
|
||||||
# data = np.random.randn(1, 3, 224, 224)
|
|
||||||
# np.save("inputs", data)
|
|
||||||
# run_stub(stub, data)
|
|
||||||
# del stub
|
|
||||||
|
|
||||||
dist_name = f"dist_{os.getpid()}"
|
# generate standart output
|
||||||
|
if gen_std:
|
||||||
|
print(f"generate standard data for {name}.")
|
||||||
|
# a small vocabulary size to fit all LLM.
|
||||||
|
voc_size = 1000
|
||||||
|
gen_standard(name, model, voc_size, bs, length)
|
||||||
|
return
|
||||||
|
|
||||||
|
# run single process.
|
||||||
|
# use standalone process to isolate cuda.
|
||||||
|
p = mp.Process(target=start_single, args=(name, model))
|
||||||
|
p.start()
|
||||||
|
p.join()
|
||||||
|
|
||||||
|
# run distributed parallel.
|
||||||
|
world_size = nnodes * nproc_per_node
|
||||||
workers = [
|
workers = [
|
||||||
mp.Process(
|
mp.Process(
|
||||||
target=start_worker,
|
target=start_worker,
|
||||||
args=(dist_name, world_size, rank, rank % nproc_per_node, model),
|
args=(name, world_size, rank, rank % nproc_per_node, model),
|
||||||
)
|
)
|
||||||
for rank in range(world_size)
|
for rank in range(world_size)
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,221 @@
|
||||||
|
import onnx
|
||||||
|
from onnx import ModelProto, NodeProto, TensorProto, ValueInfoProto
|
||||||
|
from onnx import helper, numpy_helper
|
||||||
|
from typing import Dict, List
|
||||||
|
from placement import Placement, Replicate, Shard, _Partial
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
|
||||||
|
data = {init.name: init for init in model.graph.initializer}
|
||||||
|
vinfo = {info.name: info for info in model.graph.value_info}
|
||||||
|
vinfo.update({info.name: info for info in model.graph.input})
|
||||||
|
vinfo.update({info.name: info for info in model.graph.output})
|
||||||
|
place: Dict[str, Placement] = {}
|
||||||
|
nodes: List[NodeProto] = []
|
||||||
|
|
||||||
|
def is_sharded(name: str):
|
||||||
|
return place[name].is_shard()
|
||||||
|
|
||||||
|
def shard_tensor(tensor: TensorProto, plc: Shard, groups: int = 1):
|
||||||
|
# print(f"shard {tensor.name} at dim {dim}")
|
||||||
|
assert plc.is_shard(), plc
|
||||||
|
ndim = len(tensor.dims)
|
||||||
|
if plc.dim < 0:
|
||||||
|
plc.dim += ndim
|
||||||
|
if tensor.dims[plc.dim] == 1: # broadcast dim, no need to shard.
|
||||||
|
return tensor
|
||||||
|
array = numpy_helper.to_array(tensor)
|
||||||
|
assert array.shape[plc.dim] % tp_world_size == 0, array.shape[plc.dim]
|
||||||
|
dims = list(tensor.dims)
|
||||||
|
dims.insert(plc.dim, groups)
|
||||||
|
dims[plc.dim + 1] //= groups
|
||||||
|
array = array.reshape(dims)
|
||||||
|
seg = array.shape[plc.dim + 1] // tp_world_size
|
||||||
|
array = array.take(
|
||||||
|
indices=range(tp_rank * seg, (tp_rank + 1) * seg), axis=plc.dim + 1
|
||||||
|
)
|
||||||
|
dims = list(tensor.dims)
|
||||||
|
dims[plc.dim] //= tp_world_size
|
||||||
|
array = array.reshape(dims)
|
||||||
|
tensor = numpy_helper.from_array(array, name=tensor.name)
|
||||||
|
place[tensor.name] = plc
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
def shard_gemm(node: NodeProto, groups: int = 1):
|
||||||
|
# print("gemm", node.name)
|
||||||
|
in_plc = place[node.input[0]]
|
||||||
|
w_plc = Shard(-1) if in_plc.is_replicate() else Shard(0)
|
||||||
|
transB = next((attr.i for attr in node.attribute if attr.name == "transB"), 0)
|
||||||
|
if transB:
|
||||||
|
w_plc.dim = ~w_plc.dim
|
||||||
|
input = node.input[1]
|
||||||
|
data[input] = shard_tensor(data[input], w_plc, groups)
|
||||||
|
|
||||||
|
output = node.output[0]
|
||||||
|
ndim = len(vinfo[output].type.tensor_type.shape.dim)
|
||||||
|
out_plc = Shard(ndim - 1) if in_plc.is_replicate() else _Partial()
|
||||||
|
place[node.output[0]] = out_plc
|
||||||
|
|
||||||
|
def shard_binary(node: NodeProto, groups: int = 1):
|
||||||
|
# print("binary", node.name, node.input[0], place[node.input[0]])
|
||||||
|
a = node.input[0]
|
||||||
|
b = node.input[1]
|
||||||
|
if a in data:
|
||||||
|
a, b = b, a
|
||||||
|
place[node.output[0]] = place[a]
|
||||||
|
if is_sharded(a) and b in data and len(data[b].dims) == 1: # broadcast
|
||||||
|
data[b] = shard_tensor(data[b], Shard(0), groups)
|
||||||
|
|
||||||
|
def shard_reshape(node: NodeProto):
|
||||||
|
# print("reshape", node.name, node.input[0], place[node.input[0]])
|
||||||
|
if not is_sharded(node.input[0]):
|
||||||
|
return
|
||||||
|
in_plc = place[node.input[0]]
|
||||||
|
s_dim = -1
|
||||||
|
in_dims = [d.dim_value for d in vinfo[node.input[0]].type.tensor_type.shape.dim]
|
||||||
|
tensor = data[node.input[1]]
|
||||||
|
out_dims = numpy_helper.to_array(tensor).copy()
|
||||||
|
if len(in_dims) == 3 and len(out_dims) == 4:
|
||||||
|
if in_plc.dim == 0:
|
||||||
|
s_dim = 1
|
||||||
|
elif in_plc.dim == 2:
|
||||||
|
s_dim = 2
|
||||||
|
if len(in_dims) == 4 and len(out_dims) == 3:
|
||||||
|
if in_plc.dim == 1:
|
||||||
|
s_dim = 0
|
||||||
|
elif in_plc.dim == 2:
|
||||||
|
s_dim = 2
|
||||||
|
if len(in_dims) == 2 and len(out_dims) == 3:
|
||||||
|
if in_plc.dim == 1:
|
||||||
|
s_dim = 2
|
||||||
|
if len(in_dims) == 4 and len(out_dims) == 2:
|
||||||
|
if in_plc.dim == 1:
|
||||||
|
s_dim = 0
|
||||||
|
elif in_plc.dim == 2:
|
||||||
|
s_dim = 1
|
||||||
|
if len(in_dims) == 3 and len(out_dims) == 2:
|
||||||
|
if in_plc.dim == 1:
|
||||||
|
s_dim = 0
|
||||||
|
elif in_plc.dim == 2:
|
||||||
|
s_dim = 1
|
||||||
|
|
||||||
|
assert s_dim != -1
|
||||||
|
assert out_dims[s_dim] % tp_world_size == 0, out_dims
|
||||||
|
out_dims[s_dim] //= tp_world_size
|
||||||
|
# if ONNX uses the same tensor for multiple Reshape Nodes, then rename it to distingush from others.
|
||||||
|
# node.input[1] = node.output[0] + "_shape"
|
||||||
|
data[node.input[1]] = numpy_helper.from_array(out_dims, name=node.input[1])
|
||||||
|
place[node.output[0]] = Shard(s_dim)
|
||||||
|
|
||||||
|
def shard_split(node: NodeProto):
|
||||||
|
if not is_sharded(node.input[0]):
|
||||||
|
return
|
||||||
|
in_plc = place[node.input[0]]
|
||||||
|
split_tensor = data[node.input[1]]
|
||||||
|
split = numpy_helper.to_array(split_tensor).copy()
|
||||||
|
split //= tp_world_size
|
||||||
|
data[node.input[1]] = numpy_helper.from_array(split, name=node.input[1])
|
||||||
|
for output in node.output:
|
||||||
|
place[output] = in_plc
|
||||||
|
|
||||||
|
def shard_transpose(node: NodeProto):
|
||||||
|
plc = place[node.input[0]]
|
||||||
|
if plc.is_shard():
|
||||||
|
perm = next(attr.ints for attr in node.attribute if attr.name == "perm")
|
||||||
|
place[node.output[0]] = Shard(list(perm).index(plc.dim))
|
||||||
|
|
||||||
|
def shard_node(node: NodeProto):
|
||||||
|
if node.op_type in ["Relu", "Tanh", "Softmax"]:
|
||||||
|
place[node.output[0]] = place[node.input[0]]
|
||||||
|
elif node.op_type in ["Where"]:
|
||||||
|
place[node.output[0]] = place[node.input[1]]
|
||||||
|
if node.op_type in {"Add", "Mul", "Div", "Max"}:
|
||||||
|
shard_binary(node)
|
||||||
|
elif node.op_type == "Reshape":
|
||||||
|
shard_reshape(node)
|
||||||
|
elif node.op_type == "Transpose":
|
||||||
|
shard_transpose(node)
|
||||||
|
elif node.op_type == "Split":
|
||||||
|
shard_split(node)
|
||||||
|
elif node.op_type == "MatMul":
|
||||||
|
assert (
|
||||||
|
place[node.input[0]] == place[node.input[1]]
|
||||||
|
), f"{place[node.input[0]]} != {place[node.input[1]]}"
|
||||||
|
place[node.output[0]] = place[node.input[0]]
|
||||||
|
|
||||||
|
def find_successor(op_type: str, idx: int, search_limit: int = 1):
|
||||||
|
for node in model.graph.node[idx + 1 : idx + 1 + search_limit]:
|
||||||
|
if node.op_type == op_type:
|
||||||
|
return node
|
||||||
|
return None
|
||||||
|
|
||||||
|
# all tensors are initially replicated.
|
||||||
|
for v in vinfo:
|
||||||
|
place[v] = Replicate()
|
||||||
|
|
||||||
|
for t in data:
|
||||||
|
place[t] = Replicate()
|
||||||
|
|
||||||
|
for index, node in enumerate(model.graph.node):
|
||||||
|
nodes.append(node)
|
||||||
|
# linear
|
||||||
|
if (node.op_type == "MatMul" or node.op_type == "Gemm") and any(
|
||||||
|
input in data for input in node.input
|
||||||
|
):
|
||||||
|
groups = 1
|
||||||
|
# If the Gemm or Matmul is followed by a split, then the inputs are concatinated by groups
|
||||||
|
split_node = find_successor("Split", index, search_limit=2)
|
||||||
|
if split_node is not None:
|
||||||
|
groups = len(split_node.output)
|
||||||
|
shard_gemm(node, groups)
|
||||||
|
plc = place[node.output[0]]
|
||||||
|
if plc.is_partial():
|
||||||
|
new_name = node.output[0] + f":{plc}"
|
||||||
|
place[new_name] = place[node.output[0]]
|
||||||
|
# insert all_reduce
|
||||||
|
nodes.append(
|
||||||
|
helper.make_node(
|
||||||
|
op_type="ReduceSum",
|
||||||
|
inputs=[new_name],
|
||||||
|
outputs=[node.output[0]],
|
||||||
|
name=node.name + "/all_reduce",
|
||||||
|
noop_with_empty_axes=1,
|
||||||
|
communicator=0, # hack to treat ReduceSum as AllReduceSum
|
||||||
|
)
|
||||||
|
)
|
||||||
|
place[node.output[0]] = Replicate()
|
||||||
|
node.output[0] = new_name
|
||||||
|
if len(node.input) > 2: # split bias to add
|
||||||
|
prev = nodes[-1]
|
||||||
|
new_name = prev.output[0] + "_no_bias"
|
||||||
|
place[new_name] = place[node.output[0]]
|
||||||
|
bias = helper.make_node(
|
||||||
|
op_type="Add",
|
||||||
|
inputs=[new_name, node.input[2]],
|
||||||
|
outputs=[prev.output[0]],
|
||||||
|
name=node.name + "/bias",
|
||||||
|
)
|
||||||
|
node.input.pop()
|
||||||
|
prev.output[0] = new_name
|
||||||
|
shard_binary(bias, groups)
|
||||||
|
nodes.append(bias)
|
||||||
|
continue
|
||||||
|
shard_node(node)
|
||||||
|
|
||||||
|
graph = helper.make_graph(
|
||||||
|
nodes,
|
||||||
|
model.graph.name + f"_{tp_rank}",
|
||||||
|
model.graph.input,
|
||||||
|
model.graph.output,
|
||||||
|
data.values(),
|
||||||
|
doc_string=model.graph.doc_string,
|
||||||
|
# value_info=vinfo.values(),
|
||||||
|
)
|
||||||
|
for output in graph.output:
|
||||||
|
tt = output.type.tensor_type
|
||||||
|
if tt.HasField("shape"):
|
||||||
|
tt.ClearField("shape")
|
||||||
|
model = helper.make_model(graph)
|
||||||
|
model = onnx.shape_inference.infer_shapes(model)
|
||||||
|
return model
|
|
@ -0,0 +1,64 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class Placement:
|
||||||
|
# base class Placement type
|
||||||
|
|
||||||
|
# convenient utils to check for placement types
|
||||||
|
def is_shard(self, dim: Optional[int] = None) -> bool:
|
||||||
|
if dim is not None and isinstance(self, Shard):
|
||||||
|
return self.dim == dim
|
||||||
|
else:
|
||||||
|
return isinstance(self, Shard)
|
||||||
|
|
||||||
|
def is_replicate(self) -> bool:
|
||||||
|
return isinstance(self, Replicate)
|
||||||
|
|
||||||
|
def is_partial(self) -> bool:
|
||||||
|
return isinstance(self, _Partial)
|
||||||
|
|
||||||
|
|
||||||
|
class Replicate(Placement):
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
if not isinstance(other, Replicate):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""
|
||||||
|
machine readable representation of the Replicate placement
|
||||||
|
"""
|
||||||
|
return "Replicate()"
|
||||||
|
|
||||||
|
|
||||||
|
class Shard(Placement):
|
||||||
|
# shard placement, shard on a dim
|
||||||
|
def __init__(self, dim):
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
if not isinstance(other, Shard):
|
||||||
|
return False
|
||||||
|
return self.dim == other.dim
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""
|
||||||
|
machine readable representation of the Shard placement
|
||||||
|
"""
|
||||||
|
return f"Shard(dim={self.dim})"
|
||||||
|
|
||||||
|
|
||||||
|
class _Partial(Placement):
|
||||||
|
def __init__(self, reduce_op: str = "sum"):
|
||||||
|
self.reduce_op: str = reduce_op
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
if not isinstance(other, _Partial):
|
||||||
|
return False
|
||||||
|
return self.reduce_op == other.reduce_op
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""
|
||||||
|
machine readable representation of the Partial placement
|
||||||
|
"""
|
||||||
|
return f"_Partial(reduce_op={self.reduce_op})"
|
|
@ -40,12 +40,12 @@ using HashType = uint64_t; // compatible with std::hash
|
||||||
|
|
||||||
// Assert: conditions should have no side effect
|
// Assert: conditions should have no side effect
|
||||||
#define _IT_ASSERT_2(condition, info) \
|
#define _IT_ASSERT_2(condition, info) \
|
||||||
(static_cast<bool>(condition) \
|
static_cast<bool>(condition) \
|
||||||
? void(0) \
|
? void(0) \
|
||||||
: throw ::infini::Exception( \
|
: throw ::infini::Exception( \
|
||||||
std::string("[") + __FILE__ + ":" + std::to_string(__LINE__) + \
|
std::string("[") + __FILE__ + ":" + std::to_string(__LINE__) + \
|
||||||
"] Assertion failed (" + #condition + "): " + info))
|
"] Assertion failed (" + #condition + "): " + info)
|
||||||
#define _IT_ASSERT_1(condition) _IT_ASSERT_2(condition, "");
|
#define _IT_ASSERT_1(condition) _IT_ASSERT_2(condition, "")
|
||||||
#define IT_ASSERT(...) _VA_SELECT(_IT_ASSERT, __VA_ARGS__)
|
#define IT_ASSERT(...) _VA_SELECT(_IT_ASSERT, __VA_ARGS__)
|
||||||
|
|
||||||
#define IT_TODO_HALT() _IT_ASSERT_2(false, "Unimplemented")
|
#define IT_TODO_HALT() _IT_ASSERT_2(false, "Unimplemented")
|
||||||
|
|
|
@ -6,16 +6,11 @@
|
||||||
#include <cudnn.h>
|
#include <cudnn.h>
|
||||||
#include <curand.h>
|
#include <curand.h>
|
||||||
|
|
||||||
// TODO: replace with Exception (IT_ASSERT)
|
|
||||||
#define checkCudaError(call) \
|
#define checkCudaError(call) \
|
||||||
{ \
|
if (auto err = call; err != cudaSuccess) \
|
||||||
auto err = call; \
|
throw ::infini::Exception(std::string("[") + __FILE__ + ":" + \
|
||||||
if (cudaSuccess != err) { \
|
std::to_string(__LINE__) + "] CUDA error (" + \
|
||||||
fprintf(stderr, "Cuda error in %s:%i : %s.\n", __FILE__, __LINE__, \
|
#call + "): " + cudaGetErrorString(err))
|
||||||
cudaGetErrorString(err)); \
|
|
||||||
exit(EXIT_FAILURE); \
|
|
||||||
} \
|
|
||||||
}
|
|
||||||
|
|
||||||
#define checkCUresult(call) \
|
#define checkCUresult(call) \
|
||||||
{ \
|
{ \
|
||||||
|
@ -39,14 +34,10 @@
|
||||||
}
|
}
|
||||||
|
|
||||||
#define checkCudnnError(call) \
|
#define checkCudnnError(call) \
|
||||||
{ \
|
if (auto err = call; err != CUDNN_STATUS_SUCCESS) \
|
||||||
auto err = call; \
|
throw ::infini::Exception(std::string("[") + __FILE__ + ":" + \
|
||||||
if (CUDNN_STATUS_SUCCESS != err) { \
|
std::to_string(__LINE__) + "] cuDNN error (" + \
|
||||||
fprintf(stderr, "cuDNN error in %s:%i : %s.\n", __FILE__, \
|
#call + "): " + cudnnGetErrorString(err))
|
||||||
__LINE__, cudnnGetErrorString(err)); \
|
|
||||||
exit(EXIT_FAILURE); \
|
|
||||||
} \
|
|
||||||
}
|
|
||||||
|
|
||||||
#define checkCurandError(call) \
|
#define checkCurandError(call) \
|
||||||
{ \
|
{ \
|
||||||
|
|
|
@ -5,8 +5,18 @@
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
class Exception : public std::runtime_error {
|
class Exception : public std::runtime_error {
|
||||||
|
protected:
|
||||||
|
std::string info;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Exception(const std::string &msg);
|
Exception(const std::string &msg);
|
||||||
|
|
||||||
|
Exception &operator<<(const std::string &str) {
|
||||||
|
info += str;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char *what() const noexcept override { return info.c_str(); }
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
#pragma once
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
#define SMALL_ARRAY_SIZE 8
|
#define SMALL_ARRAY_SIZE 8
|
||||||
|
|
|
@ -591,6 +591,13 @@ class OnnxStub:
|
||||||
tensors.get(node.output[0]),
|
tensors.get(node.output[0]),
|
||||||
next((attr.i for attr in node.attribute if attr.name == "to")),
|
next((attr.i for attr in node.attribute if attr.name == "to")),
|
||||||
)
|
)
|
||||||
|
elif node.op_type == "ReduceSum":
|
||||||
|
# ReduceSum is only implemented as allReduceSum.
|
||||||
|
assert any(attr.name == "communicator" for attr in node.attribute)
|
||||||
|
tensors[node.output[0]] = self.handler.allReduceSum(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors.get(node.output[0]),
|
||||||
|
)
|
||||||
elif node.op_type == "AllReduceSum":
|
elif node.op_type == "AllReduceSum":
|
||||||
tensors[node.output[0]] = self.handler.allReduceSum(
|
tensors[node.output[0]] = self.handler.allReduceSum(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
|
@ -631,11 +638,7 @@ class OnnxStub:
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors.get(node.output[0]),
|
tensors.get(node.output[0]),
|
||||||
next(
|
next(
|
||||||
(
|
(attr.i for attr in node.attribute if attr.name == "root"),
|
||||||
attr.i
|
|
||||||
for attr in node.attribute
|
|
||||||
if attr.name == "root"
|
|
||||||
),
|
|
||||||
0,
|
0,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
|
@ -382,9 +382,7 @@ class TestStringMethods(unittest.TestCase):
|
||||||
|
|
||||||
def test_split(self):
|
def test_split(self):
|
||||||
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||||
split = make_node(
|
split = make_node("Split", ["input"], ["output"], name="split", axis=0)
|
||||||
"Split", ["input"], ["output"], name="split", axis=0
|
|
||||||
)
|
|
||||||
make_and_import_model(make_graph([split], "split", [input], []))
|
make_and_import_model(make_graph([split], "split", [input], []))
|
||||||
|
|
||||||
def test_allBroadcast(self):
|
def test_allBroadcast(self):
|
||||||
|
@ -461,7 +459,7 @@ class TestStringMethods(unittest.TestCase):
|
||||||
make_and_import_model(make_graph([where], "where", [x, y, con], [output]))
|
make_and_import_model(make_graph([where], "where", [x, y, con], [output]))
|
||||||
|
|
||||||
def test_copyin(self):
|
def test_copyin(self):
|
||||||
dims = [2,3,5,4]
|
dims = [2, 3, 5, 4]
|
||||||
np_array = np.random.random(dims).astype(np.float32)
|
np_array = np.random.random(dims).astype(np.float32)
|
||||||
handler = backend.GraphHandler(backend.cpu_runtime())
|
handler = backend.GraphHandler(backend.cpu_runtime())
|
||||||
tensor1 = handler.tensor(dims, TensorProto.FLOAT)
|
tensor1 = handler.tensor(dims, TensorProto.FLOAT)
|
||||||
|
@ -487,7 +485,7 @@ class TestStringMethods(unittest.TestCase):
|
||||||
self.assertTrue(np.array_equal(np.array(array1).reshape(dims), np_array))
|
self.assertTrue(np.array_equal(np.array(array1).reshape(dims), np_array))
|
||||||
|
|
||||||
def test_to_numpy(self):
|
def test_to_numpy(self):
|
||||||
dims = [2,3,5,4]
|
dims = [2, 3, 5, 4]
|
||||||
np_array = np.random.random(dims).astype(np.float32)
|
np_array = np.random.random(dims).astype(np.float32)
|
||||||
handler = backend.GraphHandler(backend.cpu_runtime())
|
handler = backend.GraphHandler(backend.cpu_runtime())
|
||||||
tensor1 = handler.tensor(dims, TensorProto.FLOAT)
|
tensor1 = handler.tensor(dims, TensorProto.FLOAT)
|
||||||
|
@ -508,5 +506,6 @@ class TestStringMethods(unittest.TestCase):
|
||||||
array1 = np.array(tensor1, copy=False)
|
array1 = np.array(tensor1, copy=False)
|
||||||
self.assertTrue(np.array_equal(array1, np_array))
|
self.assertTrue(np.array_equal(array1, np_array))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -8,7 +8,6 @@
|
||||||
#include "operators/conv.h"
|
#include "operators/conv.h"
|
||||||
#include "operators/matmul.h"
|
#include "operators/matmul.h"
|
||||||
|
|
||||||
#ifdef DEBUG_MODE
|
|
||||||
void CHECK_CUDA_KERNEL_ERROR(infini::Operator op) {
|
void CHECK_CUDA_KERNEL_ERROR(infini::Operator op) {
|
||||||
cudaError_t kernelError = cudaGetLastError();
|
cudaError_t kernelError = cudaGetLastError();
|
||||||
if (kernelError != cudaSuccess) {
|
if (kernelError != cudaSuccess) {
|
||||||
|
@ -18,7 +17,6 @@ void CHECK_CUDA_KERNEL_ERROR(infini::Operator op) {
|
||||||
exit(EXIT_FAILURE);
|
exit(EXIT_FAILURE);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
|
@ -38,10 +36,7 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
|
||||||
} else {
|
} else {
|
||||||
kernel->compute(op, this);
|
kernel->compute(op, this);
|
||||||
}
|
}
|
||||||
|
checkCudaError(cudaGetLastError()) << op->toString();
|
||||||
#ifdef DEBUG_MODE
|
|
||||||
CHECK_CUDA_KERNEL_ERROR(op);
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,9 +73,7 @@ void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const {
|
||||||
opCnt[op->getOpType()]++;
|
opCnt[op->getOpType()]++;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef DEBUG_MODE
|
checkCudaError(cudaGetLastError()) << op->toString();
|
||||||
CHECK_CUDA_KERNEL_ERROR(op);
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,6 +96,7 @@ void CudaRuntimeObj::initComm(const string &name, int worldSize, int rank) {
|
||||||
IT_ASSERT(worldSize > 0);
|
IT_ASSERT(worldSize > 0);
|
||||||
IT_ASSERT(rank >= 0);
|
IT_ASSERT(rank >= 0);
|
||||||
IT_ASSERT(rank < worldSize);
|
IT_ASSERT(rank < worldSize);
|
||||||
|
IT_ASSERT(!comm) << "communicator is already initialized.";
|
||||||
#ifdef INFINI_USE_NCCL
|
#ifdef INFINI_USE_NCCL
|
||||||
comm = std::make_unique<NcclCommunicatorObj>(name, worldSize, rank);
|
comm = std::make_unique<NcclCommunicatorObj>(name, worldSize, rank);
|
||||||
#else
|
#else
|
||||||
|
|
|
@ -421,6 +421,8 @@ void init_graph_builder(py::module &m) {
|
||||||
.def("mul", &Handler::mul, policy::move)
|
.def("mul", &Handler::mul, policy::move)
|
||||||
.def("div", &Handler::div, policy::move)
|
.def("div", &Handler::div, policy::move)
|
||||||
.def("pow", &Handler::pow, policy::move)
|
.def("pow", &Handler::pow, policy::move)
|
||||||
|
.def("min", &Handler::min, policy::move)
|
||||||
|
.def("max", &Handler::max, policy::move)
|
||||||
.def("relu", &Handler::relu, policy::move)
|
.def("relu", &Handler::relu, policy::move)
|
||||||
.def("sigmoid", &Handler::sigmoid, policy::move)
|
.def("sigmoid", &Handler::sigmoid, policy::move)
|
||||||
.def("tanh", &Handler::tanh, policy::move)
|
.def("tanh", &Handler::tanh, policy::move)
|
||||||
|
|
|
@ -14,7 +14,7 @@ class AllReduceNCCL : public CudaKernelWithoutConfig {
|
||||||
void *input = op->getInputs(0)->getRawDataPtr<void *>();
|
void *input = op->getInputs(0)->getRawDataPtr<void *>();
|
||||||
void *output = op->getOutput()->getRawDataPtr<void *>();
|
void *output = op->getOutput()->getRawDataPtr<void *>();
|
||||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||||
size_t count = op->getInputs(0)->getBytes() / op->getDType().getSize();
|
size_t count = op->getInputs(0)->size();
|
||||||
|
|
||||||
ncclComm_t comm =
|
ncclComm_t comm =
|
||||||
dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator())
|
dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator())
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
#include "operators/matmul.h"
|
#include "operators/matmul.h"
|
||||||
#include "core/kernel.h"
|
#include "core/kernel.h"
|
||||||
|
#include "cuda/cuda_expand.h"
|
||||||
#include "cuda/cuda_runtime.h"
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "utils/small_array.h"
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
|
@ -46,7 +48,30 @@ class matmulCublas : public Kernel {
|
||||||
auto opB = op->getTransB() ? CUBLAS_OP_T : CUBLAS_OP_N;
|
auto opB = op->getTransB() ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||||
const int lda = op->getTransA() ? m : k, ldb = op->getTransB() ? k : n,
|
const int lda = op->getTransA() ? m : k, ldb = op->getTransB() ? k : n,
|
||||||
ldc = n;
|
ldc = n;
|
||||||
const float alpha = 1.f, beta = 0.f;
|
float alpha = 1.f, beta = 0.f;
|
||||||
|
if (op->numInputs() == 2) { // no bias
|
||||||
|
beta = 0.f;
|
||||||
|
} else { // broadcast bias to output
|
||||||
|
beta = 1.f;
|
||||||
|
auto inC = op->getInputs(2);
|
||||||
|
auto out = op->getOutput();
|
||||||
|
SmallArray inputShape, outputShape;
|
||||||
|
int nDims = out->getRank();
|
||||||
|
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
|
||||||
|
int outputsize = 1; // the length of the output vector after flatten
|
||||||
|
int offset = nDims - inC->getRank();
|
||||||
|
for (int i = 0; i < offset; ++i)
|
||||||
|
inputShape.data[i] = 1;
|
||||||
|
for (int i = 0; i < nDims; ++i) {
|
||||||
|
outputShape.data[i] = out->getDims()[i];
|
||||||
|
outputsize *= outputShape.data[i];
|
||||||
|
if (i >= offset)
|
||||||
|
inputShape.data[i] = inC->getDims()[i - offset];
|
||||||
|
}
|
||||||
|
expandKernel(inC->getRawDataPtr<float *>(),
|
||||||
|
out->getRawDataPtr<float *>(), nDims, outputsize,
|
||||||
|
inputShape, outputShape);
|
||||||
|
}
|
||||||
// TODO:use compute type
|
// TODO:use compute type
|
||||||
cublasStatus_t stat;
|
cublasStatus_t stat;
|
||||||
if (b > 1) {
|
if (b > 1) {
|
||||||
|
|
|
@ -6,7 +6,7 @@ GatherObj::GatherObj(GraphObj *graph, Tensor input, Tensor indices,
|
||||||
Tensor output, int axis)
|
Tensor output, int axis)
|
||||||
: OperatorObj(OpType::Gather, {input, indices}, {output}), axis(axis) {
|
: OperatorObj(OpType::Gather, {input, indices}, {output}), axis(axis) {
|
||||||
int rank = input->getRank();
|
int rank = input->getRank();
|
||||||
axis = get_real_axis(axis, rank);
|
this->axis = get_real_axis(axis, rank);
|
||||||
IT_ASSERT(checkValid(graph));
|
IT_ASSERT(checkValid(graph));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ optional<vector<Shape>> GatherObj::inferShape(const TensorVec &inputs) const {
|
||||||
vector<DataType> GatherObj::inferDataType(const TensorVec &inputs) const {
|
vector<DataType> GatherObj::inferDataType(const TensorVec &inputs) const {
|
||||||
IT_ASSERT(inputs.size() == 2);
|
IT_ASSERT(inputs.size() == 2);
|
||||||
auto index_dtype = inputs[1]->getDType();
|
auto index_dtype = inputs[1]->getDType();
|
||||||
IT_ASSERT(index_dtype == DataType::Int32 || index_dtype == DataType::Int64)
|
IT_ASSERT(index_dtype == DataType::Int32 || index_dtype == DataType::Int64);
|
||||||
return {inputs[0]->getDType()};
|
return {inputs[0]->getDType()};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,7 +9,8 @@ namespace backward_trace = backward;
|
||||||
backward_trace::SignalHandling sh;
|
backward_trace::SignalHandling sh;
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
Exception::Exception(const std::string &msg) : std::runtime_error(msg) {
|
Exception::Exception(const std::string &msg)
|
||||||
|
: std::runtime_error(msg), info(msg) {
|
||||||
backward_trace::StackTrace st;
|
backward_trace::StackTrace st;
|
||||||
st.load_here(32);
|
st.load_here(32);
|
||||||
backward_trace::Printer p;
|
backward_trace::Printer p;
|
||||||
|
|
Loading…
Reference in New Issue