diff --git a/Makefile b/Makefile index 84d01b9e..19f1b353 100644 --- a/Makefile +++ b/Makefile @@ -56,6 +56,10 @@ test-onnx: @echo python3 pyinfinitensor/tests/test_onnx.py +test-api: + @echo + python3 pyinfinitensor/tests/test_api.py + docker-build: docker build -f scripts/dockerfile/$(DOCKER_FILE) -t $(DOCKER_NAME) . diff --git a/examples/distributed/launch.py b/examples/distributed/launch.py index 362bde1c..64930e6e 100644 --- a/examples/distributed/launch.py +++ b/examples/distributed/launch.py @@ -4,8 +4,12 @@ import time import multiprocessing as mp from pyinfinitensor.onnx import OnnxStub, backend import onnx +from onnx.external_data_helper import convert_model_to_external_data import numpy as np -from parallel import parallel_model +from parallel_opt import parallel_model + + +os.environ["NVIDIA_TF32_OVERRIDE"] = "0" def parse_args(): @@ -14,77 +18,126 @@ def parse_args(): parser.add_argument( "--nproc_per_node", type=int, default=1, help="number of processes per node" ) + parser.add_argument( + "--name", type=str, default="test", help="name of this instance." + ) parser.add_argument( "--model", type=str, required=True, help="path to the ONNX model file." ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size.") + parser.add_argument("--length", type=int, default=1, help="sequence length.") + parser.add_argument( + "--gen_std", + action="store_true", + help="whether to generate the standard results.", + ) args = parser.parse_args() print("arg setting: ", args) - return args.num_nodes, args.nproc_per_node, args.model + return ( + args.num_nodes, + args.nproc_per_node, + args.name, + args.model, + args.batch_size, + args.length, + args.gen_std, + ) -def run_stub(stub: OnnxStub, inputs: np.array, n=100): - # warm up - next(stub.inputs.items().__iter__())[1].copyin_float(inputs.reshape(-1).tolist()) +def run_model(model, runtime, inputs: np.array, n=20): + stub = OnnxStub(model, runtime) + next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs) stub.tune() - for _ in range(20): - stub.run() + stub.run() + # get outputs outputs = np.array(next(stub.outputs.items().__iter__())[1].copyout_float()) # bench - next(stub.inputs.items().__iter__())[1].copyin_float(inputs.reshape(-1).tolist()) + next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs) begin = time.time() for _ in range(n): stub.run() end = time.time() - outputs = np.array(next(stub.outputs.items().__iter__())[1].copyout_float()) - print("outputs sum:", outputs.sum()) - # np.save("results", outputs) - results = np.load("results.npy") - print("max diff:", abs(outputs - results).max()) - assert np.allclose(outputs, results, rtol=1e-6, atol=1e-6) avg_time = (end - begin) / n - return avg_time + print(f"average time: {avg_time}") + return outputs + + +def run_and_compare(name, model, runtime): + data = np.load(f"{name}_inputs.npy") + results = np.load(f"{name}_results.npy") + outputs = run_model(model, runtime, data) + print("outputs sum:", outputs.sum()) + print("max abs diff:", abs(outputs - results).max()) + print("max rel diff:", abs((outputs - results) / results).max()) + # assert np.allclose(outputs, results, rtol=1e-3, atol=1e-6) def start_worker( - dist_name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto + name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto ): - print("start worker") + dist_name = name + "_dist" + model = parallel_model(model, world_size, rank) + extern_path = f"./{dist_name}_rank{rank}.pb" + if os.path.exists(extern_path): 
+ os.remove(extern_path) + convert_model_to_external_data( + model, + all_tensors_to_one_file=True, + location=extern_path, + size_threshold=1024, + convert_attribute=False, + ) + onnx.save(model, f"./{dist_name}_rank{rank}.onnx") runtime = backend.CudaRuntime(local_rank) - print("init comm") + # print("init comm") runtime.init_comm( dist_name, world_size, rank, ) - model = parallel_model(model, world_size, rank) - onnx.save(model, f"dist_model_rank{rank}.onnx") - print("load model") - stub = OnnxStub(model, runtime) - data = np.load("inputs.npy") - print("run model") - avg_time = run_stub(stub, data) - print(f"average time: {avg_time}") + run_and_compare(name, model, runtime) + + +def start_single(name, model): + runtime = backend.CudaRuntime(0) + run_and_compare(name, model, runtime) + + +def gen_standard(name, model, voc_size, bs, len): + # generate standard results + data = np.random.randint(0, voc_size, (bs, len), dtype=np.int32) + np.save(f"{name}_inputs", data) + runtime = backend.CudaRuntime(0) + outputs = run_model(model, runtime, data, 1) + np.save(f"{name}_results", outputs) def main(): - nnodes, nproc_per_node, model_path = parse_args() - world_size = nnodes * nproc_per_node + nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args() model = onnx.load(model_path) - # generate standard results - # runtime = backend.CudaRuntime(0) - # stub = OnnxStub(model, runtime) - # data = np.random.randn(1, 3, 224, 224) - # np.save("inputs", data) - # run_stub(stub, data) - # del stub - dist_name = f"dist_{os.getpid()}" + # generate standart output + if gen_std: + print(f"generate standard data for {name}.") + # a small vocabulary size to fit all LLM. + voc_size = 1000 + gen_standard(name, model, voc_size, bs, length) + return + + # run single process. + # use standalone process to isolate cuda. + p = mp.Process(target=start_single, args=(name, model)) + p.start() + p.join() + + # run distributed parallel. + world_size = nnodes * nproc_per_node workers = [ mp.Process( target=start_worker, - args=(dist_name, world_size, rank, rank % nproc_per_node, model), + args=(name, world_size, rank, rank % nproc_per_node, model), ) for rank in range(world_size) ] diff --git a/examples/distributed/launch_kvcache.py b/examples/distributed/launch_kvcache.py new file mode 100644 index 00000000..13f908af --- /dev/null +++ b/examples/distributed/launch_kvcache.py @@ -0,0 +1,245 @@ +import argparse +import os +import time +import multiprocessing as mp +from pyinfinitensor.onnx import OnnxStub, backend +import onnx +from onnx.external_data_helper import convert_model_to_external_data +import numpy as np +from parallel_opt import parallel_model + + +os.environ["NVIDIA_TF32_OVERRIDE"] = "0" + + +def parse_args(): + parser = argparse.ArgumentParser(description="launch distributed infinitensor") + parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes") + parser.add_argument( + "--nproc_per_node", type=int, default=1, help="number of processes per node" + ) + parser.add_argument( + "--name", type=str, default="test", help="name of this instance." + ) + parser.add_argument( + "--model1", type=str, required=True, help="path to the ONNX model file." + ) + parser.add_argument( + "--model2", type=str, required=True, help="path to the ONNX model file." 
+ ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size.") + parser.add_argument("--length", type=int, default=1, help="sequence length.") + parser.add_argument( + "--gen_std", + action="store_true", + help="whether to generate the standard results.", + ) + args = parser.parse_args() + print("arg setting: ", args) + return ( + args.num_nodes, + args.nproc_per_node, + args.name, + args.model1, + args.model2, + args.batch_size, + args.length, + args.gen_std, + ) + + +def run_model(model1, model2, runtime1, runtime2, inputs1: np.array, inputs2: np.array, n=20): + #################################### + # run the first graph without kvcache + #################################### + stub1 = OnnxStub(model1, runtime1) + stub1.inputs['onnx::Reshape_0'].copyin_int32(inputs1.reshape(-1).tolist()) + stub1.tune() + stub1.run() + kvcache_it1 = [] + count = 0 + for output in stub1.outputs.items().__iter__(): + if count == 0: + logits_it1 = np.array(output[1].copyout_float(), dtype=np.float32) + else: + kvcache_it1.append(np.array(output[1].copyout_float(), dtype=np.float32)) + count = count + 1 + + # bench for stub1 + next(stub1.inputs.items().__iter__())[1].copyin_int32(inputs1.reshape(-1).tolist()) + begin = time.time() + for _ in range(n): + stub1.run() + end = time.time() + avg_time = (end - begin) / n + print(f"stub1 average time: {avg_time}") + + #################################### + # run the second graph with kvcache + #################################### + i = 0 + batchsize = 1 + stub2 = OnnxStub(model2, runtime2) + past_kvcache_length = (i+2)*np.ones((batchsize, 1), dtype=np.int32) + # copyin input + stub2.inputs['onnx::Reshape_0'].copyin_int32(inputs2.reshape(-1).tolist()) + stub2.inputs['input.3'].copyin_int32(past_kvcache_length.reshape(-1).tolist()) + count = -1 + for input in stub2.inputs.items().__iter__(): + if count in range(24): + # print(count, input[0]) + # print(np.dtype(kvcache_it1[count][0]), kvcache_it1[count].shape) + input[1].copyin_float(kvcache_it1[count].reshape(-1).tolist()) + count = count + 1 + stub2.tune() + stub2.run() + + # copyout output + count = 0 + kvcache_it2 = [] + for output in stub2.outputs.items().__iter__(): + if count == 0: + logits_it2 = np.array(output[1].copyout_float(), dtype=np.float32) + else: + kvcache_it2.append(np.array(output[1].copyout_float(), dtype=np.float32)) + count = count + 1 + + # bench for stub2 + # copyin input + stub2.inputs['onnx::Reshape_0'].copyin_int32(inputs2.reshape(-1).tolist()) + stub2.inputs['input.3'].copyin_int32(past_kvcache_length.reshape(-1).tolist()) + count = -1 + for input in stub2.inputs.items().__iter__(): + if count in range(24): + input[1].copyin_float(kvcache_it1[count].reshape(-1).tolist()) + count = count + 1 + begin = time.time() + for _ in range(n): + stub2.run() + end = time.time() + avg_time = (end - begin) / n + print(f"stub2 average time: {avg_time}") + return logits_it2 + + +def run_and_compare(name, model1, model2, runtime1, runtime2): + data1 = np.load(f"{name}_inputs1.npy") + data2 = np.load(f"{name}_inputs2.npy") + results = np.load(f"{name}_results.npy") + outputs = run_model(model1, model2, runtime1, runtime2, data1, data2) + print("outputs sum:", outputs.sum()) + print("max abs diff:", abs(outputs - results).max()) + print("max rel diff:", abs((outputs - results) / results).max()) + # assert np.allclose(outputs, results, rtol=1e-3, atol=1e-6) + + +def start_worker( + name: str, world_size: int, rank: int, local_rank: int, model1: onnx.ModelProto, model2: onnx.ModelProto +): 
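    # Per-rank flow (summary of the body below): shard both graphs with
    # parallel_model, save each shard with its weights moved to an
    # external-data file, bind one CudaRuntime to the local rank, set up the
    # communicator, then run and compare against the saved single-process
    # reference.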
+ dist_name = name + "_dist" + #################################### + # shard the first graph + #################################### + model1 = parallel_model(model1, world_size, rank) + extern_path = f"./{dist_name}_stub1_rank{rank}.pb" + if os.path.exists(extern_path): + os.remove(extern_path) + convert_model_to_external_data( + model1, + all_tensors_to_one_file=True, + location=extern_path, + size_threshold=1024, + convert_attribute=False, + ) + onnx.save(model1, f"./{dist_name}_stub1_rank{rank}.onnx") + runtime1 = backend.CudaRuntime(local_rank) + runtime1.init_comm( + dist_name, + world_size, + rank, + ) + + #################################### + # shard the second graph + #################################### + model2 = parallel_model(model2, world_size, rank) + extern_path = f"./{dist_name}_stub2_rank{rank}.pb" + if os.path.exists(extern_path): + os.remove(extern_path) + convert_model_to_external_data( + model2, + all_tensors_to_one_file=True, + location=extern_path, + size_threshold=1024, + convert_attribute=False, + ) + onnx.save(model2, f"./{dist_name}_stub2_rank{rank}.onnx") + runtime2 = backend.CudaRuntime(local_rank) + # print("init comm") + runtime2.init_comm( + dist_name, + world_size, + rank, + ) + + # run the two graphs + run_and_compare(name, model1, model2, runtime1, runtime2) + + +def start_single(name, model1, model2): + runtime1 = backend.CudaRuntime(0) + runtime2 = backend.CudaRuntime(0) + run_and_compare(name, model1, model2, runtime1, runtime2) + + +def gen_standard(name, model1, model2, voc_size, bs, len): + # generate standard results + data1 = np.random.randint(0, voc_size, (bs, len), dtype=np.int32) + data2 = np.random.randint(0, voc_size, (bs, len), dtype=np.int32) + np.save(f"{name}_inputs1", data1) + np.save(f"{name}_inputs2", data2) + runtime1 = backend.CudaRuntime(0) + runtime2 = backend.CudaRuntime(0) + outputs = run_model(model1, model2, runtime1, runtime2, data1, data2, 1) + np.save(f"{name}_results", outputs) + + +def main(): + nnodes, nproc_per_node, name, model1_path, model2_path, bs, length, gen_std = parse_args() + + model1 = onnx.load(model1_path) + model2 = onnx.load(model2_path) + + # generate standart output + if gen_std: + print(f"generate standard data for {name}.") + # a small vocabulary size to fit all LLM. + voc_size = 1000 + gen_standard(name, model1, model2, voc_size, bs, length) + return + + # run single process. + # use standalone process to isolate cuda. + p = mp.Process(target=start_single, args=(name, model1, model2)) + p.start() + p.join() + + # run distributed parallel. 
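    # Typical invocation (a hedged sketch; model paths and process count are
    # illustrative, the flags are the ones parsed above):
    #   python launch_kvcache.py --name gpt --model1 prefill.onnx \
    #       --model2 decode.onnx --gen_std
    #   python launch_kvcache.py --name gpt --model1 prefill.onnx \
    #       --model2 decode.onnx --nproc_per_node 2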
+ world_size = nnodes * nproc_per_node + workers = [ + mp.Process( + target=start_worker, + args=(name, world_size, rank, rank % nproc_per_node, model1, model2), + ) + for rank in range(world_size) + ] + + for w in workers: + w.start() + + for w in workers: + w.join() + + +if __name__ == "__main__": + main() diff --git a/examples/distributed/parallel_opt.py b/examples/distributed/parallel_opt.py new file mode 100644 index 00000000..42465a69 --- /dev/null +++ b/examples/distributed/parallel_opt.py @@ -0,0 +1,237 @@ +import onnx +from onnx import ModelProto, NodeProto, TensorProto, ValueInfoProto +from onnx import helper, numpy_helper +from typing import Dict, List +from placement import Placement, Replicate, Shard, _Partial +import numpy as np + + +def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0): + data = {init.name: init for init in model.graph.initializer} + vinfo = {info.name: info for info in model.graph.value_info} + vinfo.update({info.name: info for info in model.graph.input}) + vinfo.update({info.name: info for info in model.graph.output}) + place: Dict[str, Placement] = {} + nodes: List[NodeProto] = [] + + def is_sharded(name: str): + return place[name].is_shard() + + def shard_tensor(tensor: TensorProto, plc: Shard, groups: int = 1): + # print(f"shard {tensor.name} at dim {dim}") + assert plc.is_shard(), plc + ndim = len(tensor.dims) + if plc.dim < 0: + plc.dim += ndim + if tensor.dims[plc.dim] == 1: # broadcast dim, no need to shard. + return tensor + array = numpy_helper.to_array(tensor) + assert array.shape[plc.dim] % tp_world_size == 0, array.shape[plc.dim] + dims = list(tensor.dims) + dims.insert(plc.dim, groups) + dims[plc.dim + 1] //= groups + array = array.reshape(dims) + seg = array.shape[plc.dim + 1] // tp_world_size + array = array.take( + indices=range(tp_rank * seg, (tp_rank + 1) * seg), axis=plc.dim + 1 + ) + dims = list(tensor.dims) + dims[plc.dim] //= tp_world_size + array = array.reshape(dims) + tensor = numpy_helper.from_array(array, name=tensor.name) + place[tensor.name] = plc + return tensor + + def shard_gemm(node: NodeProto, groups: int = 1): + # print("gemm", node.name) + in_plc = place[node.input[0]] + w_plc = Shard(-1) if in_plc.is_replicate() else Shard(0) + transB = next((attr.i for attr in node.attribute if attr.name == "transB"), 0) + if transB: + w_plc.dim = ~w_plc.dim + input = node.input[1] + data[input] = shard_tensor(data[input], w_plc, groups) + + output = node.output[0] + ndim = len(vinfo[output].type.tensor_type.shape.dim) + out_plc = Shard(ndim - 1) if in_plc.is_replicate() else _Partial() + place[node.output[0]] = out_plc + + def shard_concat(node: NodeProto): + # hack for kvcache + in_plc = place[node.input[1]] + if in_plc.is_shard(): + seq_len_dim = vinfo[node.input[0]].type.tensor_type.shape.dim.pop(1) + seq_len_dim.dim_value //= tp_world_size + vinfo[node.input[0]].type.tensor_type.shape.dim.insert(1, seq_len_dim) + place[node.input[0]] = in_plc + place[node.output[0]] = in_plc + + def shard_binary(node: NodeProto, groups: int = 1): + # print("binary", node.name, node.input[0], place[node.input[0]]) + a = node.input[0] + b = node.input[1] + if a in data: + a, b = b, a + place[node.output[0]] = place[a] + if is_sharded(a) and b in data and len(data[b].dims) == 1: # broadcast + data[b] = shard_tensor(data[b], Shard(0), groups) + + def shard_reshape(node: NodeProto): + # print("reshape", node.name, node.input[0], place[node.input[0]]) + if not is_sharded(node.input[0]): + return + in_plc = place[node.input[0]] 
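        # Map the sharded input dimension to the corresponding output dimension
        # of the Reshape. The rank-pair cases below (3D<->4D, 2D<->3D, ...) are
        # the ones produced by head split/merge style reshapes; any unhandled
        # combination falls through and trips the s_dim assertion.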
+ s_dim = -1 + in_dims = [d.dim_value for d in vinfo[node.input[0]].type.tensor_type.shape.dim] + tensor = data[node.input[1]] + out_dims = numpy_helper.to_array(tensor).copy() + if len(in_dims) == 3 and len(out_dims) == 4: + if in_plc.dim == 0: + s_dim = 1 + elif in_plc.dim == 2: + s_dim = 2 + if len(in_dims) == 4 and len(out_dims) == 3: + if in_plc.dim == 1: + s_dim = 0 + elif in_plc.dim == 2: + s_dim = 2 + if len(in_dims) == 2 and len(out_dims) == 3: + if in_plc.dim == 1: + s_dim = 2 + if len(in_dims) == 4 and len(out_dims) == 2: + if in_plc.dim == 1: + s_dim = 0 + elif in_plc.dim == 2: + s_dim = 1 + if len(in_dims) == 3 and len(out_dims) == 2: + if in_plc.dim == 1: + s_dim = 0 + elif in_plc.dim == 2: + s_dim = 1 + + assert s_dim != -1 + assert out_dims[s_dim] % tp_world_size == 0, out_dims + out_dims[s_dim] //= tp_world_size + # if ONNX uses the same tensor for multiple Reshape Nodes, then rename it to distingush from others. + # node.input[1] = node.output[0] + "_shape" + data[node.input[1]] = numpy_helper.from_array(out_dims, name=node.input[1]) + place[node.output[0]] = Shard(s_dim) + + def shard_split(node: NodeProto): + if not is_sharded(node.input[0]): + return + in_plc = place[node.input[0]] + split_tensor = data[node.input[1]] + split = numpy_helper.to_array(split_tensor).copy() + split //= tp_world_size + data[node.input[1]] = numpy_helper.from_array(split, name=node.input[1]) + for output in node.output: + place[output] = in_plc + + def shard_transpose(node: NodeProto): + plc = place[node.input[0]] + if plc.is_shard(): + perm = next(attr.ints for attr in node.attribute if attr.name == "perm") + place[node.output[0]] = Shard(list(perm).index(plc.dim)) + + def shard_node(node: NodeProto): + if node.op_type in ["Relu", "Tanh", "Softmax"]: + place[node.output[0]] = place[node.input[0]] + elif node.op_type in ["Where"]: + place[node.output[0]] = place[node.input[1]] + if node.op_type in {"Add", "Mul", "Div", "Max"}: + shard_binary(node) + elif node.op_type == "Reshape": + shard_reshape(node) + elif node.op_type == "Transpose": + shard_transpose(node) + elif node.op_type == "Split": + shard_split(node) + elif node.op_type == "MatMul": + assert ( + place[node.input[0]] == place[node.input[1]] + ), f"{place[node.input[0]]} != {place[node.input[1]]}" + place[node.output[0]] = place[node.input[0]] + elif node.op_type == "Concat": + shard_concat(node) + + def find_successor(op_type: str, idx: int, search_limit: int = 1): + for node in model.graph.node[idx + 1 : idx + 1 + search_limit]: + if node.op_type == op_type: + return node + return None + + # all tensors are initially replicated. 
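    # From here the pass walks the nodes in order: Gemm/MatMul weights are
    # sharded (column-wise when the input is replicated, row-wise otherwise),
    # placements are propagated through elementwise / Reshape / Transpose /
    # Split / Concat, and partial outputs are resolved by emitting a ReduceSum
    # node whose "communicator" attribute marks it as an AllReduceSum.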
+ for v in vinfo: + place[v] = Replicate() + + for t in data: + place[t] = Replicate() + + for index, node in enumerate(model.graph.node): + nodes.append(node) + # linear + if (node.op_type == "MatMul" or node.op_type == "Gemm") and any( + input in data for input in node.input + ): + groups = 1 + # If the Gemm or Matmul is followed by a split, then the inputs are concatinated by groups + split_node = find_successor("Split", index, search_limit=2) + if split_node is not None: + groups = len(split_node.output) + shard_gemm(node, groups) + plc = place[node.output[0]] + if plc.is_partial(): + new_name = node.output[0] + f":{plc}" + place[new_name] = place[node.output[0]] + # insert all_reduce + nodes.append( + helper.make_node( + op_type="ReduceSum", + inputs=[new_name], + outputs=[node.output[0]], + name=node.name + "/all_reduce", + noop_with_empty_axes=1, + communicator=0, # hack to treat ReduceSum as AllReduceSum + ) + ) + place[node.output[0]] = Replicate() + node.output[0] = new_name + if len(node.input) > 2: # split bias to add + prev = nodes[-1] + new_name = prev.output[0] + "_no_bias" + place[new_name] = place[node.output[0]] + bias = helper.make_node( + op_type="Add", + inputs=[new_name, node.input[2]], + outputs=[prev.output[0]], + name=node.name + "/bias", + ) + node.input.pop() + prev.output[0] = new_name + shard_binary(bias, groups) + nodes.append(bias) + continue + shard_node(node) + + new_input = [] + for info in model.graph.input: + new_input.append(vinfo[info.name]) + + graph = helper.make_graph( + nodes, + model.graph.name + f"_{tp_rank}", + new_input, + model.graph.output, + data.values(), + doc_string=model.graph.doc_string, + # value_info=vinfo.values(), + ) + for output in graph.output: + tt = output.type.tensor_type + if tt.HasField("shape"): + tt.ClearField("shape") + model = helper.make_model(graph) + model = onnx.shape_inference.infer_shapes(model) + return model diff --git a/examples/distributed/placement.py b/examples/distributed/placement.py new file mode 100644 index 00000000..634d4fe5 --- /dev/null +++ b/examples/distributed/placement.py @@ -0,0 +1,64 @@ +from typing import Optional + + +class Placement: + # base class Placement type + + # convenient utils to check for placement types + def is_shard(self, dim: Optional[int] = None) -> bool: + if dim is not None and isinstance(self, Shard): + return self.dim == dim + else: + return isinstance(self, Shard) + + def is_replicate(self) -> bool: + return isinstance(self, Replicate) + + def is_partial(self) -> bool: + return isinstance(self, _Partial) + + +class Replicate(Placement): + def __eq__(self, other: object) -> bool: + if not isinstance(other, Replicate): + return False + return True + + def __repr__(self) -> str: + """ + machine readable representation of the Replicate placement + """ + return "Replicate()" + + +class Shard(Placement): + # shard placement, shard on a dim + def __init__(self, dim): + self.dim = dim + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Shard): + return False + return self.dim == other.dim + + def __repr__(self) -> str: + """ + machine readable representation of the Shard placement + """ + return f"Shard(dim={self.dim})" + + +class _Partial(Placement): + def __init__(self, reduce_op: str = "sum"): + self.reduce_op: str = reduce_op + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _Partial): + return False + return self.reduce_op == other.reduce_op + + def __repr__(self) -> str: + """ + machine readable representation of the Partial 
placement + """ + return f"_Partial(reduce_op={self.reduce_op})" diff --git a/examples/python/onnx_inference.py b/examples/python/onnx_inference.py new file mode 100644 index 00000000..152fe37a --- /dev/null +++ b/examples/python/onnx_inference.py @@ -0,0 +1,29 @@ +import sys +import onnx +import torch +import numpy as np +from pyinfinitensor.onnx import OnnxStub, backend + +if __name__ == '__main__': + args = sys.argv + if len(sys.argv) != 2: + print("Usage: python onnx_inference.py model_name.onnx") + exit() + model_path = sys.argv[1] + # print(model_path) + + onnx_model = onnx.load(model_path) + onnx_input = onnx_model.graph.input[0] + input_shape = [[d.dim_value for d in _input.type.tensor_type.shape.dim] + for _input in onnx_model.graph.input] + # Assume that there is only one input tensor + input_shape = input_shape[0] + # print(input_shape) + input_data = np.random.random(input_shape).astype(np.float32) + + model = OnnxStub(onnx_model, backend.cuda_runtime()) + next(iter(model.inputs.values())).copyin_numpy(input_data) + model.run() + outputs = next(iter(model.outputs.values())).copyout_numpy() + outputs = torch.tensor(outputs) + print(outputs.shape) diff --git a/examples/python/resnet_inference.py b/examples/python/resnet_inference.py new file mode 100644 index 00000000..4c58c7a6 --- /dev/null +++ b/examples/python/resnet_inference.py @@ -0,0 +1,24 @@ +import sys +import onnx +import torch +import numpy as np +from pyinfinitensor.onnx import OnnxStub, backend +import torchvision.models as models + +if __name__ == '__main__': + model_path = './resnet18.onnx' + tv_model = models.resnet50(weights=None) + input_shape = (1, 3, 224, 224) + param = torch.rand(input_shape) + torch.onnx.export(tv_model, param, model_path, verbose=False) + + onnx_model = onnx.load(model_path) + model = OnnxStub(onnx_model, backend.cuda_runtime()) + images = np.random.random(input_shape).astype(np.float32) + next(iter(model.inputs.values())).copyin_numpy(images) + model.run() + outputs = next(iter(model.outputs.values())).copyout_numpy() + outputs = torch.tensor(outputs) + outputs = torch.reshape(outputs, (1, 1000)) + _, predicted = torch.max(outputs, 1) + print(predicted) diff --git a/include/bang/bang_runtime.h b/include/bang/bang_runtime.h index 6a40ae37..684e238f 100644 --- a/include/bang/bang_runtime.h +++ b/include/bang/bang_runtime.h @@ -67,6 +67,10 @@ class BangRuntimeObj : public RuntimeObj { CNRT_MEM_TRANS_DIR_PEER2PEER)); } + void initComm(const string &, int, int) override { IT_TODO_HALT(); } + + CommunicatorObj &getCommunicator() const override { IT_TODO_HALT(); } + private: void runWithoutSync(const Graph &graph, bool tune, bool profiling) const; }; diff --git a/include/core/common.h b/include/core/common.h index caa61084..749caff2 100644 --- a/include/core/common.h +++ b/include/core/common.h @@ -40,12 +40,12 @@ using HashType = uint64_t; // compatible with std::hash // Assert: conditions should have no side effect #define _IT_ASSERT_2(condition, info) \ - (static_cast(condition) \ - ? void(0) \ - : throw ::infini::Exception( \ - std::string("[") + __FILE__ + ":" + std::to_string(__LINE__) + \ - "] Assertion failed (" + #condition + "): " + info)) -#define _IT_ASSERT_1(condition) _IT_ASSERT_2(condition, ""); + static_cast(condition) \ + ? void(0) \ + : throw ::infini::Exception( \ + std::string("[") + __FILE__ + ":" + std::to_string(__LINE__) + \ + "] Assertion failed (" + #condition + "): " + info) +#define _IT_ASSERT_1(condition) _IT_ASSERT_2(condition, "") #define IT_ASSERT(...) 
_VA_SELECT(_IT_ASSERT, __VA_ARGS__) #define IT_TODO_HALT() _IT_ASSERT_2(false, "Unimplemented") diff --git a/include/core/graph.h b/include/core/graph.h index 3efd893f..8415b15a 100644 --- a/include/core/graph.h +++ b/include/core/graph.h @@ -120,6 +120,11 @@ class GraphObj : public Object { * @brief If the nodes is sorted in topological order. */ bool sorted; + + /** + * @brief If the weight tensors are allocated. + */ + bool weightAllocated = false; }; } // namespace infini diff --git a/include/core/lazy_allocator.h b/include/core/lazy_allocator.h index 228639a3..5f073845 100644 --- a/include/core/lazy_allocator.h +++ b/include/core/lazy_allocator.h @@ -20,14 +20,23 @@ class LazyAllocator { Runtime runtime; - size_t used; + size_t used = 0; - size_t peak; + size_t peak = 0; + + size_t weightPeak = 0; size_t alignment; // pointer to the memory actually allocated - void *ptr; + void *ptr = nullptr; + + // pointer to the weight memory space + void *weightPtr = nullptr; + + // // a cache designed for a batch size that has already occurred + // std::unordered_map> + // batchsizeToTensorOffset; struct freeBlockInfo { size_t addr; @@ -57,12 +66,16 @@ class LazyAllocator { virtual ~LazyAllocator(); + void init(); + // function: simulate memory allocation // arguments: // size: size of memory block to be allocated // return: head address offset of the allocated memory block size_t alloc(size_t size); + size_t allocWeight(size_t size); + // function: simulate memory free // arguments: // addr: head address offset of memory block to be free @@ -73,6 +86,12 @@ class LazyAllocator { // return: pointer to the head address of the allocated memory void *getPtr(); + // void addCache(size_t batchsize, std::unordered_map); + + // std::unordered_map getCache(size_t batchsize); + + void *getWeightPtr(); + void info(); private: diff --git a/include/core/tensor.h b/include/core/tensor.h index 1f4ea57b..48590fd6 100644 --- a/include/core/tensor.h +++ b/include/core/tensor.h @@ -1,5 +1,6 @@ #pragma once #include "core/tensor_base.h" +#include "core/tensor_type.h" #include "utils/data_convert.h" #include #include @@ -19,6 +20,8 @@ class TensorObj : public TensorBaseObj { size_t _size; // Cache of Π(shape). Fuid fuid; // Cloned tensors share the same id. Tensors constructed from // scratch have a new id. 
+ TensorType tensorType = TensorType::others; + public: TensorObj(Shape shape, DataType dtype, Runtime runtime); virtual ~TensorObj() {} @@ -33,6 +36,33 @@ class TensorObj : public TensorBaseObj { size_t getOffset(const vector &ds) const; void dataMalloc(); UidBaseType getFuid() const { return fuid; } + bool isWeight() const { return tensorType == TensorType::weight; } + bool isInput() const { return tensorType == TensorType::input; } + bool isOutput() const { return tensorType == TensorType::output; } + bool isOthers() const { return tensorType == TensorType::others; } + void setWeight() { tensorType = TensorType::weight; } + void setInput() { tensorType = TensorType::input; } + void setOutput() { tensorType = TensorType::output; } + string tensorTypeToString() const { + switch (tensorType) { + case TensorType::weight: + return "weight"; + break; + case TensorType::input: + return "input"; + break; + case TensorType::output: + return "output"; + break; + case TensorType::others: + return "others"; + break; + + default: + return "unknown tensor type"; + break; + } + } void load(std::string file_path); void save(std::string file_path); diff --git a/include/core/tensor_base.h b/include/core/tensor_base.h index 54d65ffd..05a8a727 100644 --- a/include/core/tensor_base.h +++ b/include/core/tensor_base.h @@ -44,6 +44,7 @@ class TensorBaseObj : public Object { } DataType getDType() const { return dtype; } + int getDTypeIndex() const { return dtype.getIndex(); } Runtime getRuntime() const { return runtime; } // std::pair getOutputOfWithIndex(); diff --git a/include/core/tensor_type.h b/include/core/tensor_type.h new file mode 100644 index 00000000..46df6073 --- /dev/null +++ b/include/core/tensor_type.h @@ -0,0 +1,7 @@ +#pragma once + +namespace infini { + +enum class TensorType { weight, input, output, others }; + +} // namespace infini diff --git a/include/cuda/cuda_common.h b/include/cuda/cuda_common.h index dec9a40b..4eb75f27 100644 --- a/include/cuda/cuda_common.h +++ b/include/cuda/cuda_common.h @@ -6,16 +6,11 @@ #include #include -// TODO: replace with Exception (IT_ASSERT) #define checkCudaError(call) \ - { \ - auto err = call; \ - if (cudaSuccess != err) { \ - fprintf(stderr, "Cuda error in %s:%i : %s.\n", __FILE__, __LINE__, \ - cudaGetErrorString(err)); \ - exit(EXIT_FAILURE); \ - } \ - } + if (auto err = call; err != cudaSuccess) \ + throw ::infini::Exception(std::string("[") + __FILE__ + ":" + \ + std::to_string(__LINE__) + "] CUDA error (" + \ + #call + "): " + cudaGetErrorString(err)) #define checkCUresult(call) \ { \ @@ -39,14 +34,10 @@ } #define checkCudnnError(call) \ - { \ - auto err = call; \ - if (CUDNN_STATUS_SUCCESS != err) { \ - fprintf(stderr, "cuDNN error in %s:%i : %s.\n", __FILE__, \ - __LINE__, cudnnGetErrorString(err)); \ - exit(EXIT_FAILURE); \ - } \ - } + if (auto err = call; err != CUDNN_STATUS_SUCCESS) \ + throw ::infini::Exception(std::string("[") + __FILE__ + ":" + \ + std::to_string(__LINE__) + "] cuDNN error (" + \ + #call + "): " + cudnnGetErrorString(err)) #define checkCurandError(call) \ { \ diff --git a/include/cuda/cuda_expand.h b/include/cuda/cuda_expand.h index b53c4ce4..8d4701fd 100644 --- a/include/cuda/cuda_expand.h +++ b/include/cuda/cuda_expand.h @@ -3,7 +3,7 @@ #include "operators/unary.h" #include "utils/small_array.h" namespace infini { -void expand_kernel(float *input, float *output, int nDims, int outputsize, - SmallArray inputShape, SmallArray outputShape); +void expandKernel(float *input, float *output, int nDims, int outputsize, + 
SmallArray inputShape, SmallArray outputShape); }; // namespace infini diff --git a/include/cuda/cuda_where.h b/include/cuda/cuda_where.h index 14d9bc73..15ad29ec 100644 --- a/include/cuda/cuda_where.h +++ b/include/cuda/cuda_where.h @@ -3,11 +3,9 @@ #include "utils/small_array.h" namespace infini { -void where_kernel(const float *inputx, const float *inputy, - const float *condition, float *output, int nDims, - infini::SmallArray inputxShape, - infini::SmallArray inputyShape, - infini::SmallArray conditionShape, - infini::SmallArray outputShape); +void whereKernel(const float *inputX, const float *inputY, + const uint8_t *condition, float *output, int nDims, + SmallArray inputXShape, SmallArray inputYShape, + SmallArray conditionShape, SmallArray outputShape); }; // namespace infini diff --git a/include/utils/broadcast_shape.h b/include/utils/broadcast_shape.h new file mode 100644 index 00000000..1f45ddcc --- /dev/null +++ b/include/utils/broadcast_shape.h @@ -0,0 +1,14 @@ +#pragma once + +namespace infini { +void broadcastShape(const Shape &originShape, SmallArray &modifyShape, + int nDims, int size) { + for (int i = nDims - 1; i >= 0; --i) { + modifyShape.data[i] = 1; + } + for (int i = size - 1; i >= 0; --i) { + modifyShape.data[i + nDims - size] = originShape[i]; + } +} + +} // namespace infini diff --git a/include/utils/exception.h b/include/utils/exception.h index f2a7dc94..d7bb4331 100644 --- a/include/utils/exception.h +++ b/include/utils/exception.h @@ -5,8 +5,18 @@ namespace infini { class Exception : public std::runtime_error { + protected: + std::string info; + public: Exception(const std::string &msg); + + Exception &operator<<(const std::string &str) { + info += str; + return *this; + } + + const char *what() const noexcept override { return info.c_str(); } }; } // namespace infini diff --git a/include/utils/small_array.h b/include/utils/small_array.h index d0e29a09..3ea93279 100644 --- a/include/utils/small_array.h +++ b/include/utils/small_array.h @@ -1,3 +1,4 @@ +#pragma once namespace infini { #define SMALL_ARRAY_SIZE 8 diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index b0a433ba..f8e53b1c 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -32,35 +32,37 @@ class OnnxStub: The Onnx model imported into infinitensor. It can be generated from an Onnx model object. 
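    Example (a minimal sketch; assumes a CUDA build of the backend, an
    illustrative model path, and input_array being a NumPy array that matches
    the first graph input):

        stub = OnnxStub(onnx.load("model.onnx"), backend.cuda_runtime())
        next(iter(stub.inputs.values())).copyin_numpy(input_array)
        stub.run()
        output_array = next(iter(stub.outputs.values())).copyout_numpy()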
""" - - inputs: Dict[str, backend.Tensor] = {} - outputs: Dict[str, backend.Tensor] = {} - initializer: Dict[int, TensorProto] = {} - handler: backend.GraphHandler - def __init__(self, model: ModelProto, runtime): + self.inputs: Dict[str, backend.Tensor] = {} + self.outputs: Dict[str, backend.Tensor] = {} + self.initializer: Dict[int, TensorProto] = {} model = infer_shapes(model) self.handler = backend.GraphHandler(runtime) tensors: Dict[str, backend.Tensor] = dict() data: Dict[str, TensorProto] = dict() + for initializer in model.graph.initializer: + dims = [d for d in initializer.dims] + tensors[initializer.name] = self.handler.tensor(dims, initializer.data_type) + data[initializer.name] = initializer + tensors[initializer.name].set_weight() + for input in model.graph.input: dims = _take_shape_dim(input.type.tensor_type.shape) - tensors[input.name] = self.handler.tensor( - dims, input.type.tensor_type.elem_type - ) + if input.name not in tensors.keys(): + tensors[input.name] = self.handler.tensor( + dims, input.type.tensor_type.elem_type + ) + tensors[input.name].set_input() for output in model.graph.output: dims = _take_shape_dim(output.type.tensor_type.shape) tensors[output.name] = self.handler.tensor( dims, output.type.tensor_type.elem_type ) + tensors[output.name].set_output() - for initializer in model.graph.initializer: - dims = [d for d in initializer.dims] - tensors[initializer.name] = self.handler.tensor(dims, initializer.data_type) - data[initializer.name] = initializer node_name = [] new_node_name = [] @@ -591,6 +593,13 @@ class OnnxStub: tensors.get(node.output[0]), next((attr.i for attr in node.attribute if attr.name == "to")), ) + elif node.op_type == "ReduceSum": + # ReduceSum is only implemented as allReduceSum. + assert any(attr.name == "communicator" for attr in node.attribute) + tensors[node.output[0]] = self.handler.allReduceSum( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) elif node.op_type == "AllReduceSum": tensors[node.output[0]] = self.handler.allReduceSum( tensors[node.input[0]], @@ -631,13 +640,9 @@ class OnnxStub: tensors[node.input[0]], tensors.get(node.output[0]), next( - ( - attr.i - for attr in node.attribute - if attr.name == "root" - ), - 0, - ), + (attr.i for attr in node.attribute if attr.name == "root"), + 0, + ), ) elif node.op_type == "Expand": shape = _parse_data(data[node.input[1]]) @@ -658,6 +663,15 @@ class OnnxStub: tensors[node.input[0]], tensors.get(node.output[0]), ) + elif node.op_type == "Constant": + output_name = node.output[0] + attributes = _parse_attribute(node) + tensor = attributes['value'] + dims = [d for d in tensor.dims] + tensors[output_name] = self.handler.tensor( + dims, tensor.data_type) + data[output_name] = tensor + tensors[output_name].set_weight() else: raise Exception('Unsupported operator "{}"'.format(node.op_type)) new_node_name.append(node.name) @@ -1062,19 +1076,18 @@ def _search_shape(model: ModelProto, name: str) -> List[int]: def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]: for attr in node.attribute: - if attr.name in attrs: - if attr.type == AttributeProto.INT: - attrs[attr.name] = attr.i - elif attr.type == AttributeProto.INTS: - attrs[attr.name] = attr.ints - elif attr.type == AttributeProto.FLOAT: - attrs[attr.name] = attr.f - elif attr.type == AttributeProto.STRING: - attrs[attr.name] = attr.s - elif attr.type == AttributeProto.TENSOR: - attrs[attr.name] = attr.t - else: - assert False, "Unsupported Attribute Type: {}".format(attr.type) + if attr.type 
== AttributeProto.INT: + attrs[attr.name] = attr.i + elif attr.type == AttributeProto.INTS: + attrs[attr.name] = attr.ints + elif attr.type == AttributeProto.FLOAT: + attrs[attr.name] = attr.f + elif attr.type == AttributeProto.STRING: + attrs[attr.name] = attr.s + elif attr.type == AttributeProto.TENSOR: + attrs[attr.name] = attr.t + else: + assert False, "Unsupported Attribute Type: {}".format(attr.type) return attrs diff --git a/pyinfinitensor/tests/test_api.py b/pyinfinitensor/tests/test_api.py new file mode 100644 index 00000000..d0d77b88 --- /dev/null +++ b/pyinfinitensor/tests/test_api.py @@ -0,0 +1,65 @@ +import os, unittest +from onnx import TensorProto +from pyinfinitensor import backend +import numpy as np + + +class TestPythonAPI(unittest.TestCase): + def test_copyin_numpy(self): + dims = [2, 3, 5, 4] + np_array = np.random.random(dims).astype(np.float32) + handler = backend.GraphHandler(backend.cpu_runtime()) + tensor1 = handler.tensor(dims, TensorProto.FLOAT) + tensor2 = handler.tensor(dims, TensorProto.FLOAT) + handler.data_malloc() + tensor1.copyin_numpy(np_array) + tensor2.copyin_float(np_array.flatten().tolist()) + array1 = tensor1.copyout_float() + array2 = tensor2.copyout_float() + self.assertEqual(array1, array2) + self.assertTrue(np.array_equal(np.array(array1).reshape(dims), np_array)) + + np_array = np.random.random(dims).astype(np.int64) + handler = backend.GraphHandler(backend.cpu_runtime()) + tensor1 = handler.tensor(dims, TensorProto.INT64) + tensor2 = handler.tensor(dims, TensorProto.INT64) + handler.data_malloc() + tensor1.copyin_numpy(np_array) + tensor2.copyin_int64(np_array.flatten().tolist()) + array1 = tensor1.copyout_int64() + array2 = tensor2.copyout_int64() + self.assertEqual(array1, array2) + self.assertTrue(np.array_equal(np.array(array1).reshape(dims), np_array)) + + def test_copyout_numpy(self): + dims = [2, 3, 5, 4] + np_array = np.random.random(dims).astype(np.float32) + handler = backend.GraphHandler(backend.cpu_runtime()) + tensor1 = handler.tensor(dims, TensorProto.FLOAT) + tensor2 = handler.tensor(dims, TensorProto.FLOAT) + handler.data_malloc() + tensor1.copyin_float(np_array.flatten().tolist()) + tensor2.copyin_float(np_array.flatten().tolist()) + array1 = np.array(tensor1.copyout_float()).reshape(dims) + array2 = tensor2.copyout_numpy() + self.assertTrue(np.array_equal(array2, np_array)) + self.assertTrue(np.array_equal(array1, array2)) + + np_array = np.random.random(dims).astype(np.float16) + np_array[0, 0, 0, 0] = .1 + handler = backend.GraphHandler(backend.cpu_runtime()) + tensor1 = handler.tensor(dims, TensorProto.FLOAT16) + handler.data_malloc() + tensor1.copyin_numpy(np_array) + array1 = tensor1.copyout_numpy() + # Copy should be the same as original array + self.assertTrue(np.array_equal(array1, np_array)) + # Modify the value so that tensorObj value changes + np_array[0, 0, 0, 0] = 0. 
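        # copyout_numpy returns a detached copy (unlike the removed zero-copy
        # buffer view), so writing the modified array back into the tensor must
        # not be reflected in the previously copied-out array1.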
+ tensor1.copyin_numpy(np_array) + # The copied-out array should not change + self.assertFalse(np.array_equal(array1, np_array)) + + +if __name__ == "__main__": + unittest.main() diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index 884bd874..6d041ed2 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -329,7 +329,7 @@ class TestStringMethods(unittest.TestCase): [pads_data], ) ) - + def test_allReduceSum(self): input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4]) output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4]) @@ -349,7 +349,7 @@ class TestStringMethods(unittest.TestCase): graph = make_graph([allReduceProd], "allReduceProd", [input], [output]) model = make_model(graph) from_onnx(model, backend.cpu_runtime()) - + def test_allReduceMin(self): input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4]) output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4]) @@ -379,14 +379,12 @@ class TestStringMethods(unittest.TestCase): graph = make_graph([allReduceAvg], "allReduceAvg", [input], [output]) model = make_model(graph) from_onnx(model, backend.cpu_runtime()) - + def test_split(self): input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4]) - split = make_node( - "Split", ["input"], ["output"], name="split", axis=0 - ) + split = make_node("Split", ["input"], ["output"], name="split", axis=0) make_and_import_model(make_graph([split], "split", [input], [])) - + def test_allBroadcast(self): input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4]) output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4]) @@ -460,53 +458,6 @@ class TestStringMethods(unittest.TestCase): where = make_node("Where", ["x", "y", "con"], ["output"], name="where") make_and_import_model(make_graph([where], "where", [x, y, con], [output])) - def test_copyin(self): - dims = [2,3,5,4] - np_array = np.random.random(dims).astype(np.float32) - handler = backend.GraphHandler(backend.cpu_runtime()) - tensor1 = handler.tensor(dims, TensorProto.FLOAT) - tensor2 = handler.tensor(dims, TensorProto.FLOAT) - handler.data_malloc() - tensor1.copyin_numpy(np_array) - tensor2.copyin_float(np_array.flatten().tolist()) - array1 = tensor1.copyout_float() - array2 = tensor2.copyout_float() - self.assertEqual(array1, array2) - self.assertTrue(np.array_equal(np.array(array1).reshape(dims), np_array)) - - np_array = np.random.random(dims).astype(np.int64) - handler = backend.GraphHandler(backend.cpu_runtime()) - tensor1 = handler.tensor(dims, TensorProto.INT64) - tensor2 = handler.tensor(dims, TensorProto.INT64) - handler.data_malloc() - tensor1.copyin_numpy(np_array) - tensor2.copyin_int64(np_array.flatten().tolist()) - array1 = tensor1.copyout_int64() - array2 = tensor2.copyout_int64() - self.assertEqual(array1, array2) - self.assertTrue(np.array_equal(np.array(array1).reshape(dims), np_array)) - - def test_to_numpy(self): - dims = [2,3,5,4] - np_array = np.random.random(dims).astype(np.float32) - handler = backend.GraphHandler(backend.cpu_runtime()) - tensor1 = handler.tensor(dims, TensorProto.FLOAT) - tensor2 = handler.tensor(dims, TensorProto.FLOAT) - handler.data_malloc() - tensor1.copyin_float(np_array.flatten().tolist()) - tensor2.copyin_float(np_array.flatten().tolist()) - array1 = np.array(tensor1.copyout_float()).reshape(dims) - array2 = np.array(tensor2) - self.assertTrue(np.array_equal(array2, np_array)) - 
self.assertTrue(np.array_equal(array1, array2)) - - np_array = np.random.random(dims).astype(np.float16) - handler = backend.GraphHandler(backend.cpu_runtime()) - tensor1 = handler.tensor(dims, TensorProto.FLOAT16) - handler.data_malloc() - tensor1.copyin_numpy(np_array) - array1 = np.array(tensor1, copy=False) - self.assertTrue(np.array_equal(array1, np_array)) if __name__ == "__main__": unittest.main() diff --git a/src/core/graph.cc b/src/core/graph.cc index 05f45fae..0f844c34 100644 --- a/src/core/graph.cc +++ b/src/core/graph.cc @@ -131,30 +131,63 @@ void GraphObj::dataMalloc() { // record the memory address offsets of all tensors to be allocated std::unordered_map tensorToOffset; - // record all constant tensors, including weight tensors and input tensors - std::unordered_set constTensor; + // reinit allocator + allocator.init(); + + // record all weight tensors, including weight tensors and kvcache + // tensors + std::unordered_set weightTensors; for (auto &tensor : tensors) { - if (tensor.get()->getSource() == nullptr) { - // allocate memory for all constant tensors first, and this memory + if (tensor->isWeight()) { + // allocate memory for all weight tensors first, and this memory + // will not be freed until the graph is destroyed + weightTensors.insert(tensor.get()); + if (!this->weightAllocated) { + tensorToOffset[tensor.get()] = + allocator.allocWeight(tensor->getBytes()); + } + } else if (tensor->isInput() || tensor->isOutput()) { + // allocate memory for all input and output tensors, and this memory // will not be reused later - constTensor.insert(tensor.get()); tensorToOffset[tensor.get()] = allocator.alloc(tensor->getBytes()); } else { tensorToRefCount[tensor.get()] = tensor->getTargets().size(); + // allocate memory for all user-created tensors + if (tensor.get()->getSource() == nullptr) { + tensorToOffset[tensor.get()] = + allocator.alloc(tensor->getBytes()); + } + } + } + // if memory has not yet been allocated for weight tensors, + // allocate memory now and do not allocate again in the future. 
+ if (!this->weightAllocated) { + this->weightAllocated = true; + // only allocate once for weight tensors + for (auto &tensor : weightTensors) { + IT_ASSERT(tensorToOffset.find(tensor) != tensorToOffset.end()); + tensor->setDataBlob(make_ref( + tensor->runtime, + static_cast(allocator.getWeightPtr()) + + tensorToOffset[tensor])); } } // traverse in topological order and simulate memory allocation for (auto &op : ops) { - // memory should be allocated for the output first + // memory should be allocated for the op's output first auto outputs = op->getOutputs(); for (auto &tensor : outputs) { - tensorToOffset[tensor.get()] = allocator.alloc(tensor->getBytes()); + if (tensor->isOthers()) { + tensorToOffset[tensor.get()] = + allocator.alloc(tensor->getBytes()); + } } auto inputs = op->getInputs(); for (auto &tensor : inputs) { - if (constTensor.find(tensor.get()) == constTensor.end()) { + if (tensor->isOthers()) { auto tensorIter = tensorToRefCount.find(tensor.get()); IT_ASSERT(tensorIter != tensorToRefCount.end()); + IT_ASSERT(tensorToRefCount[tensor.get()] > 0); tensorToRefCount[tensor.get()] -= 1; if (tensorToRefCount[tensor.get()] == 0) { // indicate that this tensor will no longer be used and @@ -167,15 +200,20 @@ void GraphObj::dataMalloc() { } } - // perform actual memory allocation + // perform actual memory allocation for non-weight tensors for (auto &tensor : tensors) { - IT_ASSERT(tensorToOffset.find(tensor.get()) != tensorToOffset.end()); - tensor->setDataBlob(make_ref( - tensor->runtime, static_cast(allocator.getPtr()) + - tensorToOffset[tensor.get()])); + if (!tensor->isWeight()) { + IT_ASSERT(tensorToOffset.find(tensor.get()) != + tensorToOffset.end()); + tensor->setDataBlob(make_ref( + tensor->runtime, static_cast(allocator.getPtr()) + + tensorToOffset[tensor.get()])); + } } +#ifdef DEBUG_MODE allocator.info(); +#endif } Tensor GraphObj::addTensor(Shape dim, DataType dtype) { diff --git a/src/core/lazy_allocator.cc b/src/core/lazy_allocator.cc index bb7f766f..a5014e5c 100644 --- a/src/core/lazy_allocator.cc +++ b/src/core/lazy_allocator.cc @@ -11,9 +11,6 @@ namespace infini { constexpr size_t alignmentInBytesForCUDA = 256; LazyAllocator::LazyAllocator(Runtime runtime) : runtime(runtime) { - used = 0; - peak = 0; - ptr = nullptr; if (runtime->isCuda()) { // TODO: the alignment on cuda might need further discussion alignment = alignmentInBytesForCUDA; @@ -30,10 +27,24 @@ LazyAllocator::~LazyAllocator() { if (this->ptr != nullptr) { runtime->dealloc(this->ptr); } + if (this->weightPtr != nullptr) { + runtime->dealloc(this->weightPtr); + } +} + +void LazyAllocator::init() { + used = 0; + peak = 0; + freeBlocks.clear(); + headAddrToBlockSize.clear(); + tailAddrToBlockSize.clear(); + if (this->ptr != nullptr) { + runtime->dealloc(this->ptr); + } + this->ptr = nullptr; } size_t LazyAllocator::alloc(size_t size) { - IT_ASSERT(this->ptr == nullptr); // pad the size to the multiple of alignment size = this->getAlignedSize(size); auto it = this->freeBlocks.lower_bound(freeBlockInfo{(size_t)0, size}); @@ -83,6 +94,14 @@ size_t LazyAllocator::alloc(size_t size) { return retAddr; } +size_t LazyAllocator::allocWeight(size_t size) { + IT_ASSERT(this->weightPtr == nullptr); + size = this->getAlignedSize(size); + size_t retAddr = this->weightPeak; + this->weightPeak += size; + return retAddr; +} + void LazyAllocator::free(size_t addr, size_t size) { IT_ASSERT(this->ptr == nullptr); size = getAlignedSize(size); @@ -126,18 +145,33 @@ void LazyAllocator::free(size_t addr, size_t size) { void 
*LazyAllocator::getPtr() { if (this->ptr == nullptr) { this->ptr = runtime->alloc(this->peak); - printf("LazyAllocator really alloc: %p %lu bytes\n", this->ptr, peak); +#ifdef DEBUG_MODE + printf("LazyAllocator really alloc non-weight: %p %lu bytes\n", + this->ptr, peak); +#endif } return this->ptr; } +void *LazyAllocator::getWeightPtr() { + if (this->weightPtr == nullptr) { + this->weightPtr = runtime->alloc(this->weightPeak); +#ifdef DEBUG_MODE + printf("LazyAllocator really alloc weight: %p %lu bytes\n", + this->weightPtr, weightPeak); +#endif + } + return this->weightPtr; +} + size_t LazyAllocator::getAlignedSize(size_t size) { return ((size - 1) / this->alignment + 1) * this->alignment; } void LazyAllocator::info() { - std::cout << "Used memory: " << this->used - << ", peak memory: " << this->peak << std::endl; + std::cout << "Used memory: " << this->used + this->weightPeak + << ", peak memory: " << this->peak + this->weightPeak + << std::endl; } } // namespace infini diff --git a/src/core/tensor.cc b/src/core/tensor.cc index d867adfd..e34fb8bc 100644 --- a/src/core/tensor.cc +++ b/src/core/tensor.cc @@ -23,7 +23,7 @@ string TensorObj::toString() const { string ret = "Tensor " + std::to_string(guid) + ", Fuid " + std::to_string(fuid) + ", shape " + vecToString(shape) + ", dtype " + dtype.toString() + ", " + runtime->toString() + - ", " + ss.str() + "\n"; + ", " + ss.str() + ", " + tensorTypeToString() + "\n"; vector targetGuids; for (const auto &op : targets) targetGuids.emplace_back(op.lock()->getGuid()); diff --git a/src/cuda/cuda_runtime.cc b/src/cuda/cuda_runtime.cc index 927b1f0d..0676646a 100644 --- a/src/cuda/cuda_runtime.cc +++ b/src/cuda/cuda_runtime.cc @@ -8,7 +8,6 @@ #include "operators/conv.h" #include "operators/matmul.h" -#ifdef DEBUG_MODE void CHECK_CUDA_KERNEL_ERROR(infini::Operator op) { cudaError_t kernelError = cudaGetLastError(); if (kernelError != cudaSuccess) { @@ -18,7 +17,6 @@ void CHECK_CUDA_KERNEL_ERROR(infini::Operator op) { exit(EXIT_FAILURE); } } -#endif namespace infini { @@ -38,10 +36,7 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const { } else { kernel->compute(op, this); } - -#ifdef DEBUG_MODE - CHECK_CUDA_KERNEL_ERROR(op); -#endif + checkCudaError(cudaGetLastError()) << op->toString(); } } @@ -78,9 +73,7 @@ void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const { opCnt[op->getOpType()]++; } -#ifdef DEBUG_MODE - CHECK_CUDA_KERNEL_ERROR(op); -#endif + checkCudaError(cudaGetLastError()) << op->toString(); } } @@ -103,6 +96,7 @@ void CudaRuntimeObj::initComm(const string &name, int worldSize, int rank) { IT_ASSERT(worldSize > 0); IT_ASSERT(rank >= 0); IT_ASSERT(rank < worldSize); + IT_ASSERT(!comm) << "communicator is already initialized."; #ifdef INFINI_USE_NCCL comm = std::make_unique(name, worldSize, rank); #else diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 4df6df2a..71bfcbae 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -317,6 +317,44 @@ void export_functions(py::module &m) { #undef FUNCTION } +// A helper function that converts DataType to python format string +static std::string getFormat(DataType type) { + std::string format; + if (type == DataType::Float32) { + format = py::format_descriptor::format(); + } else if (type == DataType::Double) { + format = py::format_descriptor::format(); + } else if (type == DataType::Int32) { + format = py::format_descriptor::format(); + } else if (type == DataType::UInt32) { + format = 
py::format_descriptor::format(); + } else if (type == DataType::Int64) { + format = py::format_descriptor::format(); + } else if (type == DataType::UInt64) { + format = py::format_descriptor::format(); + } else if (type == DataType::Int16) { + format = py::format_descriptor::format(); + } else if (type == DataType::UInt16) { + format = py::format_descriptor::format(); + } else if (type == DataType::Int8) { + format = py::format_descriptor::format(); + } else if (type == DataType::UInt8) { + format = py::format_descriptor::format(); + } else if (type == DataType::Bool) { + format = py::format_descriptor::format(); + } else if (type == DataType::Float16 || type == DataType::BFloat16) { + // Python uses "e" for half precision float type code. + // Check the following link for more information. + // https://docs.python.org/3/library/struct.html#format-characters + format = "e"; + } else { + throw std::runtime_error("Error converting TensorObj to " + "Numpy: unsupported datatype.\n"); + } + + return format; +} + void init_graph_builder(py::module &m) { using Handler = GraphHandlerObj; @@ -341,6 +379,10 @@ void init_graph_builder(py::module &m) { py::buffer_protocol()) .def("fuid", &TensorObj::getFuid, policy::automatic) .def("shape", &TensorObj::getDims, policy::move) + .def("set_weight", &TensorObj::setWeight, policy::move) + .def("set_input", &TensorObj::setInput, policy::move) + .def("set_output", &TensorObj::setOutput, policy::move) + .def("dtype", &TensorObj::getDTypeIndex, policy::automatic) .def("copyin_float", &TensorObj::copyin, policy::move) .def("copyin_int32", &TensorObj::copyin, policy::move) .def("copyin_int64", &TensorObj::copyin, policy::move) @@ -367,51 +409,24 @@ void init_graph_builder(py::module &m) { } self.copyin(data_np, self.getBytes()); }) - // A buffer can be used to convert a TensorObj directly to Numpy array - // without copy - .def_buffer([](TensorObj &self) -> py::buffer_info { - vector stride_byte; - for (int s : self.getStride()) { - stride_byte.push_back(s * self.getDType().getSize()); - } + // Return a Numpy array which copies the values of this tensor + .def("copyout_numpy", + [](TensorObj &self) -> py::array { + vector stride_byte; + for (int s : self.getStride()) { + stride_byte.push_back(s * self.getDType().getSize()); + } + std::string format = getFormat(self.getDType()); - std::string format; - if (self.getDType() == DataType::Float32) { - format = py::format_descriptor::format(); - } else if (self.getDType() == DataType::Double) { - format = py::format_descriptor::format(); - } else if (self.getDType() == DataType::Int32) { - format = py::format_descriptor::format(); - } else if (self.getDType() == DataType::UInt32) { - format = py::format_descriptor::format(); - } else if (self.getDType() == DataType::Int64) { - format = py::format_descriptor::format(); - } else if (self.getDType() == DataType::UInt64) { - format = py::format_descriptor::format(); - } else if (self.getDType() == DataType::Int16) { - format = py::format_descriptor::format(); - } else if (self.getDType() == DataType::UInt16) { - format = py::format_descriptor::format(); - } else if (self.getDType() == DataType::Int8) { - format = py::format_descriptor::format(); - } else if (self.getDType() == DataType::UInt8) { - format = py::format_descriptor::format(); - } else if (self.getDType() == DataType::Float16 || - self.getDType() == DataType::BFloat16) { - // Python uses "e" for half precision float type code. - // Check the following link for more information. 
- // https://docs.python.org/3/library/struct.html#format-characters - format = "e"; - } else { - throw std::runtime_error("Error converting TensorObj to " - "Numpy: unsupported datatype.\n"); - } + py::array numpy_array(py::dtype(format), self.getDims(), + nullptr); - return py::buffer_info(self.getRawDataPtr(), - self.getDType().getSize(), format, - self.getRank(), self.getDims(), stride_byte, - true); // Read-only = true - }) + // Copy data to the numpy array + auto ptr = numpy_array.mutable_data(); + self.copyout(ptr, self.getBytes()); + + return numpy_array; + }) .def("has_target", &TensorObj::hasTarget, policy::automatic) .def("src", &TensorObj::getSource, policy::move) .def("printData", &TensorObj::printData, policy::automatic); @@ -436,6 +451,8 @@ void init_graph_builder(py::module &m) { .def("mul", &Handler::mul, policy::move) .def("div", &Handler::div, policy::move) .def("pow", &Handler::pow, policy::move) + .def("min", &Handler::min, policy::move) + .def("max", &Handler::max, policy::move) .def("relu", &Handler::relu, policy::move) .def("sigmoid", &Handler::sigmoid, policy::move) .def("tanh", &Handler::tanh, policy::move) diff --git a/src/kernels/cuda/all_reduce.cc b/src/kernels/cuda/all_reduce.cc index 2728b5e2..ef60b991 100644 --- a/src/kernels/cuda/all_reduce.cc +++ b/src/kernels/cuda/all_reduce.cc @@ -14,7 +14,7 @@ class AllReduceNCCL : public CudaKernelWithoutConfig { void *input = op->getInputs(0)->getRawDataPtr(); void *output = op->getOutput()->getRawDataPtr(); IT_ASSERT(op->getDType() == DataType::Float32); - size_t count = op->getInputs(0)->getBytes() / op->getDType().getSize(); + size_t count = op->getInputs(0)->size(); ncclComm_t comm = dynamic_cast(context->getCommunicator()) diff --git a/src/kernels/cuda/expand.cc b/src/kernels/cuda/expand.cc index b8154d49..acbf5cd2 100644 --- a/src/kernels/cuda/expand.cc +++ b/src/kernels/cuda/expand.cc @@ -25,8 +25,8 @@ class ExpandCuda : public CudaKernelWithoutConfig { inputShape.data[i] = in_Shape[i]; outputsize *= out_Shape[i]; } - expand_kernel((float *)inputData, (float *)outputData, nDims, - outputsize, inputShape, outputShape); + expandKernel((float *)inputData, (float *)outputData, nDims, outputsize, + inputShape, outputShape); } }; diff --git a/src/kernels/cuda/expand.cu b/src/kernels/cuda/expand.cu index e1649b81..09405d09 100644 --- a/src/kernels/cuda/expand.cu +++ b/src/kernels/cuda/expand.cu @@ -6,9 +6,9 @@ constexpr unsigned int num_threads() { return 32 * 4; } constexpr int thread_work_size() { return 4; } constexpr int block_work_size() { return thread_work_size() * num_threads(); } -__global__ void _expand_kernel(float *input, float *output, int nDims, - int outputsize, infini::SmallArray inputShape, - infini::SmallArray outputShape) { +__global__ void _expandKernel(float *input, float *output, int nDims, + int outputsize, infini::SmallArray inputShape, + infini::SmallArray outputShape) { int outputIdx = blockIdx.x * blockDim.x + threadIdx.x; // i(JKS) + j(KS) + k(S) + s @@ -38,12 +38,12 @@ __global__ void _expand_kernel(float *input, float *output, int nDims, } namespace infini { -void expand_kernel(float *input, float *output, int nDims, int outputsize, - SmallArray inputShape, SmallArray outputShape) { +void expandKernel(float *input, float *output, int nDims, int outputsize, + SmallArray inputShape, SmallArray outputShape) { int blocksize = block_work_size(); int gridsize = (outputsize + block_work_size() - 1) / block_work_size(); - _expand_kernel<<>>(input, output, nDims, outputsize, - inputShape, 
outputShape); + _expandKernel<<>>(input, output, nDims, outputsize, + inputShape, outputShape); } } // namespace infini diff --git a/src/kernels/cuda/matmul.cc b/src/kernels/cuda/matmul.cc index a2b55e04..9cd4b0b3 100644 --- a/src/kernels/cuda/matmul.cc +++ b/src/kernels/cuda/matmul.cc @@ -1,6 +1,8 @@ #include "operators/matmul.h" #include "core/kernel.h" +#include "cuda/cuda_expand.h" #include "cuda/cuda_runtime.h" +#include "utils/small_array.h" namespace infini { @@ -46,7 +48,30 @@ class matmulCublas : public Kernel { auto opB = op->getTransB() ? CUBLAS_OP_T : CUBLAS_OP_N; const int lda = op->getTransA() ? m : k, ldb = op->getTransB() ? k : n, ldc = n; - const float alpha = 1.f, beta = 0.f; + float alpha = 1.f, beta = 0.f; + if (op->numInputs() == 2) { // no bias + beta = 0.f; + } else { // broadcast bias to output + beta = 1.f; + auto inC = op->getInputs(2); + auto out = op->getOutput(); + SmallArray inputShape, outputShape; + int nDims = out->getRank(); + IT_ASSERT(nDims <= SMALL_ARRAY_SIZE); + int outputsize = 1; // the length of the output vector after flatten + int offset = nDims - inC->getRank(); + for (int i = 0; i < offset; ++i) + inputShape.data[i] = 1; + for (int i = 0; i < nDims; ++i) { + outputShape.data[i] = out->getDims()[i]; + outputsize *= outputShape.data[i]; + if (i >= offset) + inputShape.data[i] = inC->getDims()[i - offset]; + } + expandKernel(inC->getRawDataPtr(), + out->getRawDataPtr(), nDims, outputsize, + inputShape, outputShape); + } // TODO:use compute type cublasStatus_t stat; if (b > 1) { diff --git a/src/kernels/cuda/split_concat.cu b/src/kernels/cuda/split_concat.cu index 86f5524b..73f29482 100644 --- a/src/kernels/cuda/split_concat.cu +++ b/src/kernels/cuda/split_concat.cu @@ -1,30 +1,29 @@ #include "cuda/cuda_common.h" #include "cuda/cuda_split_concat.h" -int getMultiProcessorCount() { - int cur_device; - checkCudaError(cudaGetDevice(&cur_device)); - - struct cudaDeviceProp prop; - checkCudaError(cudaGetDeviceProperties(&prop, cur_device)); - return prop.multiProcessorCount; -} - __host__ __device__ int elementIdx2ComposedIdx(int elementIndex, int dimBgNo, int dimSize, int dim, int nDim, ComposedTensorMetadata wholeMeta) { int offset = 0; + // COMP(x0,...,xk,...,xn-1) = ELMT[xk / d](x0,...,xk % d,...xn-1) + // where k=dim, n=ndim, d=dimSize is the splited length of + // dimension dim #pragma unroll + // Interate through n-1 to 1 for (int i = nDim - 1; i >= 1; --i) { int size = (i == dim) ? dimSize : wholeMeta.dimSize[i]; int p = elementIndex % size; + // dimBgNo move the pointer to correct location in composed data + // corresponding to current element, with repect to the splitted + // dimension dim int oP = (i == dim) ? (p + dimBgNo) : p; elementIndex = (elementIndex - p) / size; offset += oP * wholeMeta.stride[i]; } - - return offset + elementIndex * wholeMeta.stride[0]; + // Deal with i = 0 + int oP = (dim == 0) ? (elementIndex + dimBgNo) : elementIndex; + return offset + oP * wholeMeta.stride[0]; } __global__ void _split_concat_kernel(ElementTensorMetadata elemMeta, @@ -38,31 +37,29 @@ __global__ void _split_concat_kernel(ElementTensorMetadata elemMeta, auto dimBgNo = elemMeta.dimBgNo[blockIdx.y]; auto dimSize = elemMeta.dimSize[blockIdx.y]; float *elemData = elemMeta.data[blockIdx.y]; - int stride = gridDim.x * blockDim.x; - while (tid < nElements) { - int Offset = - elementIdx2ComposedIdx(tid, dimBgNo, dimSize, dim, nDims, compMeta); - // copy data from input to output - // for split:input is composed tensor;for concat:input is element - // tensors. 
- if (isSplit) - elemData[tid] = compMeta.data[Offset]; - else - compMeta.data[Offset] = elemData[tid]; - tid += stride; - } + int Offset = + elementIdx2ComposedIdx(tid, dimBgNo, dimSize, dim, nDims, compMeta); + // copy data from input to output + // for split:input is composed tensor;for concat:input is element + // tensors. + if (isSplit) + elemData[tid] = compMeta.data[Offset]; + else + compMeta.data[Offset] = elemData[tid]; } namespace infini { +// TODO: when dim=0, the operation can be executed in-place void split_concat_kernel(const ElementTensorMetadata &eleMeta, const ComposedTensorMetadata &compMeta, int dim, int batchSize, int nDims, bool isSplit) { dim3 blockSize = dim3(32 * 16); - - // y dim is number of tensors. - dim3 gridSize(getMultiProcessorCount(), batchSize); + // gridsize =n_elements / blockSize + int gridDimX = (eleMeta.nElements[0] - 1) / (32 * 16) + 1; + // each y is a split among the batch + dim3 gridSize(gridDimX, batchSize); _split_concat_kernel<<>>(eleMeta, compMeta, dim, nDims, isSplit); diff --git a/src/kernels/cuda/where.cc b/src/kernels/cuda/where.cc index 4769fea0..9898ab7d 100644 --- a/src/kernels/cuda/where.cc +++ b/src/kernels/cuda/where.cc @@ -2,6 +2,7 @@ #include "cuda/cuda_kernel_wihtout_config.h" #include "cuda/cuda_runtime.h" #include "cuda/cuda_where.h" +#include "utils/broadcast_shape.h" namespace infini { @@ -10,28 +11,33 @@ class WhereCuda : public CudaKernelWithoutConfig { const RuntimeObj *_context) const override { auto op = as(_op); - void *const inputxData = (op->getInputs(0)->getRawDataPtr()); - void *const inputyData = (op->getInputs(1)->getRawDataPtr()); + void *const inputXData = (op->getInputs(0)->getRawDataPtr()); + void *const inputYData = (op->getInputs(1)->getRawDataPtr()); void *const conditionData = (op->getInputs(2)->getRawDataPtr()); void *const outputData = (op->getOutput()->getRawDataPtr()); - const auto &inputx_Shape = op->getInputs(0)->getDims(); - const auto &inputy_Shape = op->getInputs(1)->getDims(); - const auto &condition_Shape = op->getInputs(2)->getDims(); - const auto &output_Shape = op->getOutput()->getDims(); + const auto &opInputXShape = op->getInputs(0)->getDims(); + const auto &opInputYShape = op->getInputs(1)->getDims(); + const auto &opConditionShape = op->getInputs(2)->getDims(); + const auto &opOutputShape = op->getOutput()->getDims(); - int nDims = op->getInputs(0)->getDims().size(); + const int xSize = op->getInputs(0)->getRank(); + const int ySize = op->getInputs(1)->getRank(); + const int cSize = op->getInputs(2)->getRank(); + int nDims = op->getOutput()->getDims().size(); IT_ASSERT(nDims <= SMALL_ARRAY_SIZE); - SmallArray inputxShape, inputyShape, conditionShape, outputShape; - for (int i = 0; i < nDims; ++i) { - inputxShape.data[i] = inputx_Shape[i]; - inputyShape.data[i] = inputy_Shape[i]; - conditionShape.data[i] = condition_Shape[i]; - outputShape.data[i] = output_Shape[i]; + SmallArray inputXShape, inputYShape, conditionShape, outputShape; + for (int i = nDims - 1; i >= 0; --i) { + outputShape.data[i] = opOutputShape[i]; } - where_kernel((float *)inputxData, (float *)inputyData, - (float *)conditionData, (float *)outputData, nDims, - inputxShape, inputyShape, conditionShape, outputShape); + + broadcastShape(opInputXShape, inputXShape, nDims, xSize); + broadcastShape(opInputYShape, inputYShape, nDims, ySize); + broadcastShape(opConditionShape, conditionShape, nDims, cSize); + + whereKernel((float *)inputXData, (float *)inputYData, + (uint8_t *)conditionData, (float *)outputData, nDims, + 
inputXShape, inputYShape, conditionShape, outputShape); } }; diff --git a/src/kernels/cuda/where.cu b/src/kernels/cuda/where.cu index 7d34098c..ce6579f8 100644 --- a/src/kernels/cuda/where.cu +++ b/src/kernels/cuda/where.cu @@ -1,20 +1,20 @@ #include "cuda/cuda_common.h" #include "utils/small_array.h" -__global__ void _where_kernel(const float *inputx, const float *inputy, - const float *condition, float *output, int nDims, - int outputsize, infini::SmallArray inputxShape, - infini::SmallArray inputyShape, - infini::SmallArray conditionShape, - infini::SmallArray outputShape) { +__global__ void _whereKernel(const float *inputX, const float *inputY, + const uint8_t *condition, float *output, int nDims, + int outputsize, infini::SmallArray inputXShape, + infini::SmallArray inputYShape, + infini::SmallArray conditionShape, + infini::SmallArray outputShape) { int outputIdx = blockIdx.x * blockDim.x + threadIdx.x; if (outputIdx < outputsize) { - int inputxIdx = 0; - int temp_inputx = 1; + int inputXIdx = 0; + int temp_inputX = 1; - int inputyIdx = 0; - int temp_inputy = 1; + int inputYIdx = 0; + int temp_inputY = 1; int conditionIdx = 0; int temp_condition = 1; @@ -27,23 +27,23 @@ __global__ void _where_kernel(const float *inputx, const float *inputy, } else { tmp = v % outputShape.data[i]; // store s,k,j in order } - if (inputxShape.data[i] == 1) { - inputxIdx += 0; + if (inputXShape.data[i] == 1) { + inputXIdx += 0; } else { - inputxIdx += + inputXIdx += tmp * - temp_inputx; // otherwise +i(JKS) or j(KS) or k(S) or s + temp_inputX; // otherwise +i(JKS) or j(KS) or k(S) or s } - temp_inputx *= inputxShape.data[i]; + temp_inputX *= inputXShape.data[i]; //---------------------------- - if (inputyShape.data[i] == 1) { - inputyIdx += 0; + if (inputYShape.data[i] == 1) { + inputYIdx += 0; } else { - inputyIdx += + inputYIdx += tmp * - temp_inputy; // otherwise +i(JKS) or j(KS) or k(S) or s + temp_inputY; // otherwise +i(JKS) or j(KS) or k(S) or s } - temp_inputy *= inputyShape.data[i]; + temp_inputY *= inputYShape.data[i]; //-------------------------- if (conditionShape.data[i] == 1) { conditionIdx += 0; @@ -57,17 +57,15 @@ __global__ void _where_kernel(const float *inputx, const float *inputy, v = v / outputShape.data[i]; } output[outputIdx] = - condition[conditionIdx] ? inputx[inputxIdx] : inputy[inputyIdx]; + condition[conditionIdx] ? 
inputX[inputXIdx] : inputY[inputYIdx]; } } namespace infini { -void where_kernel(const float *inputx, const float *inputy, - const float *condition, float *output, int nDims, - infini::SmallArray inputxShape, - infini::SmallArray inputyShape, - infini::SmallArray conditionShape, - infini::SmallArray outputShape) { +void whereKernel(const float *inputX, const float *inputY, + const uint8_t *condition, float *output, int nDims, + SmallArray inputXShape, SmallArray inputYShape, + SmallArray conditionShape, SmallArray outputShape) { int outputsize = 1; for (int i = 0; i < nDims; i++) { @@ -75,8 +73,8 @@ void where_kernel(const float *inputx, const float *inputy, } int blocksize = 32 * 16; int gridsize = (outputsize + blocksize - 1) / blocksize; - _where_kernel<<>>( - inputx, inputy, condition, output, nDims, outputsize, inputxShape, - inputyShape, conditionShape, outputShape); + _whereKernel<<>>( + inputX, inputY, condition, output, nDims, outputsize, inputXShape, + inputYShape, conditionShape, outputShape); } } // namespace infini diff --git a/src/operators/concat.cc b/src/operators/concat.cc index 78e30dad..de836d58 100644 --- a/src/operators/concat.cc +++ b/src/operators/concat.cc @@ -10,7 +10,6 @@ ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim) } optional> ConcatObj::inferShape(const TensorVec &inputs) const { - IT_ASSERT(inputs.size() > 1); Shape dims = inputs[0]->getDims(); auto rank = inputs[0]->getRank(); ShapeElem n = dims.at(dim); diff --git a/src/operators/gather.cc b/src/operators/gather.cc index f615faf7..96493323 100644 --- a/src/operators/gather.cc +++ b/src/operators/gather.cc @@ -6,7 +6,7 @@ GatherObj::GatherObj(GraphObj *graph, Tensor input, Tensor indices, Tensor output, int axis) : OperatorObj(OpType::Gather, {input, indices}, {output}), axis(axis) { int rank = input->getRank(); - axis = get_real_axis(axis, rank); + this->axis = get_real_axis(axis, rank); IT_ASSERT(checkValid(graph)); } @@ -25,7 +25,7 @@ optional> GatherObj::inferShape(const TensorVec &inputs) const { vector GatherObj::inferDataType(const TensorVec &inputs) const { IT_ASSERT(inputs.size() == 2); auto index_dtype = inputs[1]->getDType(); - IT_ASSERT(index_dtype == DataType::Int32 || index_dtype == DataType::Int64) + IT_ASSERT(index_dtype == DataType::Int32 || index_dtype == DataType::Int64); return {inputs[0]->getDType()}; } diff --git a/src/utils/exception.cc b/src/utils/exception.cc index da69751f..37013836 100644 --- a/src/utils/exception.cc +++ b/src/utils/exception.cc @@ -9,8 +9,9 @@ namespace host_backtrace = backward; host_backtrace::SignalHandling sh; namespace infini { -Exception::Exception(const std::string &msg) : std::runtime_error(msg) { - host_backtrace::StackTrace st; +Exception::Exception(const std::string &msg) + : std::runtime_error(msg), info(msg) { + backward_trace::StackTrace st; st.load_here(32); host_backtrace::Printer p; p.print(st); diff --git a/test/kernels/cuda/test_cuda_concat.cc b/test/kernels/cuda/test_cuda_concat.cc index 4bc7e950..013d25b5 100644 --- a/test/kernels/cuda/test_cuda_concat.cc +++ b/test/kernels/cuda/test_cuda_concat.cc @@ -8,6 +8,7 @@ namespace infini { /* +// Test cuda splitted idx to complosed idx in cpu. Uncomment to run this test. 
int inputOffset2CatOffset(int linearIndex, int dimBgNo, int dimSize, int concatDim, int outputDimSize[4], int outputStride[4], int nDim) { @@ -22,7 +23,8 @@ int inputOffset2CatOffset(int linearIndex, int dimBgNo, int dimSize, offset += oP * outputStride[i]; } - return offset + linearIndex * outputStride[0]; + int oP = (concatDim == 0) ? (linearIndex + dimBgNo) : linearIndex; + return offset + oP * outputStride[0]; } TEST(Concat, OffsetTrans) { @@ -41,8 +43,22 @@ TEST(Concat, OffsetTrans) { 4); EXPECT_EQ(inputOffset2CatOffset(3, 1, 2, catDim, dimSize, strides, nDim), 5); + catDim = 0; + EXPECT_EQ(inputOffset2CatOffset(0, 0, 3, catDim, dimSize, strides, nDim), + 0); + EXPECT_EQ(inputOffset2CatOffset(1, 0, 3, catDim, dimSize, strides, nDim), + 1); + EXPECT_EQ(inputOffset2CatOffset(2, 0, 3, catDim, dimSize, strides, nDim), + 2); + EXPECT_EQ(inputOffset2CatOffset(0, 1, 3, catDim, dimSize, strides, nDim), + 3); + EXPECT_EQ(inputOffset2CatOffset(1, 1, 3, catDim, dimSize, strides, nDim), + 4); + EXPECT_EQ(inputOffset2CatOffset(2, 1, 3, catDim, dimSize, strides, nDim), + 5); } */ + TEST(Concat, Cuda) { Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); @@ -78,4 +94,32 @@ TEST(Concat, Cuda) { 6, 7, 8, 1, 1, 1, 9, 10, 11, 1, 1, 1})); } +TEST(Concat, Cuda_dim0) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto t1 = gCpu->addTensor({1, 3}, DataType::Float32); + auto t2 = gCpu->addTensor({1, 3}, DataType::Float32); + auto t3 = gCpu->addTensor({1, 3}, DataType::Float32); + gCpu->dataMalloc(); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto t1Gpu = gCuda->cloneTensor(t1); + auto t2Gpu = gCuda->cloneTensor(t2); + auto t3Gpu = gCuda->cloneTensor(t3); + + auto op = + gCuda->addOp(TensorVec{t1Gpu, t2Gpu, t3Gpu}, nullptr, 0); + gCuda->dataMalloc(); + t1Gpu->setData(IncrementalGenerator()); // 0 1 2 + t2Gpu->setData(OneGenerator()); // 1 1 1 + t3Gpu->setData(IncrementalGenerator()); // 0 1 2 + cudaRuntime->run(gCuda); + + auto oCpu = gCpu->cloneTensor(op->getOutput()); + EXPECT_TRUE(oCpu->equalData(vector{0, 1, 2, 1, 1, 1, 0, 1, 2})); +} + } // namespace infini diff --git a/test/kernels/cuda/test_cuda_split.cc b/test/kernels/cuda/test_cuda_split.cc index 163bba5c..5a32f27f 100644 --- a/test/kernels/cuda/test_cuda_split.cc +++ b/test/kernels/cuda/test_cuda_split.cc @@ -39,4 +39,30 @@ TEST(Split, Cuda) { 12, 13, 14, 15, 16, 17, 18, 19, 32, 33, 34, 35, 36, 37, 38, 39})); } +TEST(Split, Cuda_dim0) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto input = gCpu->addTensor({2, 3}, DataType::Float32); + gCpu->dataMalloc(); + input->setData(IncrementalGenerator()); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto inputGpu = gCuda->cloneTensor(input); + auto op = gCuda->addOp(inputGpu, std::nullopt, 0, 2); + gCuda->dataMalloc(); + inputGpu->setData(IncrementalGenerator()); + + cudaRuntime->run(gCuda); + + // copy output from CUDA to CPU + EXPECT_EQ(op->getOutputs().size(), (size_t)2); + auto o0Cpu = gCpu->cloneTensor(op->getOutput(0)); + auto o1Cpu = gCpu->cloneTensor(op->getOutput(1)); + EXPECT_TRUE(o0Cpu->equalData(vector{0, 1, 2})); + EXPECT_TRUE(o1Cpu->equalData(vector{3, 4, 5})); +} + } // namespace infini diff --git a/test/kernels/cuda/test_cuda_where.cc b/test/kernels/cuda/test_cuda_where.cc index 61515445..74f114d4 100644 --- a/test/kernels/cuda/test_cuda_where.cc +++ b/test/kernels/cuda/test_cuda_where.cc @@ 
-10,11 +10,12 @@ namespace infini {
 void test_where(const Shape &inputxshape, const vector &inputxdata,
                 const Shape &inputyshape, const vector &inputydata,
-                const Shape &conditionshape, const vector &conditiondata,
+                const Shape &conditionshape,
+                const vector &conditiondata,
                 const vector &ExpectData) {
     Runtime runtime = NativeCpuRuntimeObj::getInstance();
     Graph gCpu = make_ref(runtime);
-    auto condition = gCpu->addTensor(conditionshape, DataType::Int32);
+    auto condition = gCpu->addTensor(conditionshape, DataType::UInt8);
     auto inputx = gCpu->addTensor(inputxshape, DataType::Float32);
     auto inputy = gCpu->addTensor(inputyshape, DataType::Float32);
@@ -47,16 +48,37 @@ TEST(CUDA_Where, run) {
     test_where(
         Shape{2, 2, 3, 1}, vector{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
         Shape{2, 2, 3, 1}, vector{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
-        Shape{2, 2, 3, 1}, vector{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1},
+        Shape{2, 2, 3, 1}, vector{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1},
         vector{0., 1., 2., 0., 0., 0., 6., 7., 0., 9., 10., 11.});
     test_where(Shape{2, 1, 1, 3}, // inputx
                vector{0, 1, 2, 3, 4, 5}, Shape{1, 2, 1, 1}, // inputy
                vector{1, 1}, Shape{2, 1, 3, 1},             // condition
-               vector{0, 1, 1, 0, 0, 0},
+               vector{0, 1, 1, 0, 0, 0},
               vector{1., 1., 1., 0., 1., 2., 0., 1., 2., 1., 1., 1.,
                      0., 1., 2., 0., 1., 2., 1., 1., 1., 1., 1., 1.,
                      1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.});
+    test_where(
+        Shape{
+            3,
+        },
+        vector{0, 1, 2},                                // inputX
+        Shape{2, 3, 1}, vector{0, 1, 2, 3, 4, 5},       // inputY
+        Shape{2, 1, 3, 1}, vector{0, 1, 1, 0, 0, 0},    // condition
+        vector{0., 0., 0., 0., 1., 2., 0., 1., 2., 3., 3., 3.,
+               0., 1., 2., 0., 1., 2., 0., 0., 0., 1., 1., 1.,
+               2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5.});
+    test_where(
+        Shape{
+            3,
+        },
+        vector{0, 1, 2},                                // inputX
+        Shape{2, 3, 1}, vector{0, 1, 2, 3, 4, 5},       // inputY
+        Shape{2, 1, 3, 1},
+        vector{false, true, true, false, false, false}, // condition
+        vector{0., 0., 0., 0., 1., 2., 0., 1., 2., 3., 3., 3.,
+               0., 1., 2., 0., 1., 2., 0., 0., 0., 1., 1., 1.,
+               2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5.});
 }
 // python output
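
Notes on the changes above. The sketches that follow are illustrative only and are not part of the patch; any ToyXxx name is hypothetical. The lazy_allocator.cc change keeps weights in a pool separate from the activation pool: each pool is backed by real memory only when its pointer is first requested, and info() now reports used/peak including weightPeak. A minimal host-side sketch of the same two-pool lazy pattern, using plain malloc where the real code calls runtime->alloc:

#include <cstddef>
#include <cstdio>
#include <cstdlib>

// Two independent, lazily created pools: the weight pool is allocated once
// and kept for the lifetime of the graph, while the activation pool is sized
// by the planner's peak and only materialized on first use.
class ToyLazyAllocator {
    void *ptr = nullptr, *weightPtr = nullptr;
    size_t peak, weightPeak;

  public:
    ToyLazyAllocator(size_t peak, size_t weightPeak)
        : peak(peak), weightPeak(weightPeak) {}
    void *getPtr() {
        if (ptr == nullptr)
            ptr = std::malloc(peak); // runtime->alloc(peak) in the real code
        return ptr;
    }
    void *getWeightPtr() {
        if (weightPtr == nullptr)
            weightPtr = std::malloc(weightPeak);
        return weightPtr;
    }
    ~ToyLazyAllocator() {
        std::free(ptr);
        std::free(weightPtr);
    }
};

int main() {
    ToyLazyAllocator alloc(/*peak=*/1 << 20, /*weightPeak=*/1 << 20);
    std::printf("activations at %p, weights at %p\n", alloc.getPtr(),
                alloc.getWeightPtr());
}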
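
The ffi_infinitensor.cc change replaces the zero-copy buffer protocol with an explicit copyout_numpy binding: a py::array is built from the format string returned by getFormat plus the tensor dims, and the tensor bytes are copied into it. A minimal sketch of that pattern, assuming a hypothetical ToyTensor with a float32 payload (the real binding derives the format from TensorObj::getDType and copies with TensorObj::copyout):

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <cstring>
#include <vector>

namespace py = pybind11;

struct ToyTensor {
    std::vector<ssize_t> dims{2, 3};
    std::vector<float> data{0, 1, 2, 3, 4, 5};
};

PYBIND11_MODULE(toy_tensor, m) {
    py::class_<ToyTensor>(m, "ToyTensor")
        .def(py::init<>())
        .def("copyout_numpy", [](const ToyTensor &self) -> py::array {
            // "f" is the float32 format character, mirroring
            // py::format_descriptor<float>::format() in the patch.
            py::array arr(py::dtype("f"), self.dims);
            // Copy into the freshly allocated NumPy buffer; the returned
            // array owns its data, unlike the removed buffer-protocol view.
            std::memcpy(arr.mutable_data(), self.data.data(),
                        self.data.size() * sizeof(float));
            return arr;
        });
}

Returning an owning copy trades a little bandwidth for safety: the NumPy array stays valid even after the runtime frees or reuses the tensor's device memory.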
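
In matmul.cc, a third input is now treated as a bias that is broadcast into the output buffer with expandKernel before cuBLAS runs with beta = 1. The shape handling right-aligns the bias shape against the output rank and pads the missing leading dimensions with 1. A small sketch of that alignment step, assuming plain std::vector shapes rather than the repository's SmallArray:

#include <cassert>
#include <iostream>
#include <vector>

// Right-align biasShape against outShape: missing leading dimensions are
// treated as 1, trailing dimensions are copied as-is.
std::vector<int> alignBiasShape(const std::vector<int> &outShape,
                                const std::vector<int> &biasShape) {
    int nDims = static_cast<int>(outShape.size());
    int offset = nDims - static_cast<int>(biasShape.size());
    assert(offset >= 0);
    std::vector<int> aligned(nDims, 1);     // pad leading dims with 1
    for (int i = offset; i < nDims; ++i)
        aligned[i] = biasShape[i - offset]; // copy trailing dims
    return aligned;
}

int main() {
    // e.g. output [2, 4, 8] with bias [8] -> bias broadcast as [1, 1, 8]
    for (int d : alignBiasShape({2, 4, 8}, {8}))
        std::cout << d << ' ';
    std::cout << '\n';
}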
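
The split_concat.cu change fixes the i == 0 step of elementIdx2ComposedIdx (the dimBgNo offset was previously dropped when the split/concat dimension was 0) and sizes gridDim.x from nElements with one element per thread instead of looping over a multiprocessor-count grid. A standalone CPU sketch of the corrected index mapping, checked against the layout exercised by the new Cuda_dim0 tests:

#include <iostream>
#include <vector>

// Map an element tensor's flat index to the flat offset in the composed
// tensor, given the element's start position dimBgNo and extent dimSize
// along the split/concat dimension dim.
int elementIdxToComposedIdx(int elementIndex, int dimBgNo, int dimSize,
                            int dim, const std::vector<int> &composedDims,
                            const std::vector<int> &composedStrides) {
    int nDim = static_cast<int>(composedDims.size());
    int offset = 0;
    for (int i = nDim - 1; i >= 1; --i) {
        int size = (i == dim) ? dimSize : composedDims[i];
        int p = elementIndex % size;
        int oP = (i == dim) ? (p + dimBgNo) : p;
        elementIndex = (elementIndex - p) / size;
        offset += oP * composedStrides[i];
    }
    // i == 0: without this branch the dimBgNo shift was lost for dim == 0.
    int oP = (dim == 0) ? (elementIndex + dimBgNo) : elementIndex;
    return offset + oP * composedStrides[0];
}

int main() {
    // Concat two {1, 3} tensors along dim 0 into a {2, 3} composed tensor
    // (row-major strides {3, 1}). Element 2 of the second tensor
    // (dimBgNo = 1, extent 1 along dim 0) lands at composed offset 5.
    std::cout << elementIdxToComposedIdx(2, 1, /*dimSize=*/1, /*dim=*/0,
                                         {2, 3}, {3, 1})
              << '\n'; // prints 5
}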
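
The Where changes broadcast X, Y, and a uint8 condition against the output shape: broadcastShape right-aligns each input's rank with leading 1s, and _whereKernel skips the index contribution of any dimension whose padded extent is 1. A CPU reference sketch of that logic; the shapes and data below are taken from the new test case added to test_cuda_where.cc, so the printed values should correspond to its ExpectData:

#include <cstdint>
#include <iostream>
#include <vector>

// Right-align a shape to nDims by padding leading 1s (mirrors broadcastShape).
static std::vector<int> alignTo(const std::vector<int> &shape, int nDims) {
    std::vector<int> out(nDims, 1);
    int offset = nDims - static_cast<int>(shape.size());
    for (int i = offset; i < nDims; ++i)
        out[i] = shape[i - offset];
    return out;
}

// Map a flat output index to a flat input index under broadcasting:
// dimensions of extent 1 contribute nothing, exactly as in _whereKernel.
static int broadcastIndex(int outIdx, const std::vector<int> &outShape,
                          const std::vector<int> &inShape) {
    int idx = 0, stride = 1;
    for (int i = static_cast<int>(outShape.size()) - 1; i >= 0; --i) {
        int coord = outIdx % outShape[i];
        if (inShape[i] != 1)
            idx += coord * stride;
        stride *= inShape[i];
        outIdx /= outShape[i];
    }
    return idx;
}

int main() {
    // Shapes from the new test case: X{3}, Y{2,3,1}, condition{2,1,3,1},
    // which broadcast to an output of shape {2, 2, 3, 3}.
    std::vector<int> outShape{2, 2, 3, 3};
    auto xS = alignTo({3}, 4), yS = alignTo({2, 3, 1}, 4),
         cS = alignTo({2, 1, 3, 1}, 4);
    std::vector<float> x{0, 1, 2}, y{0, 1, 2, 3, 4, 5};
    std::vector<uint8_t> cond{0, 1, 1, 0, 0, 0};
    for (int i = 0; i < 2 * 2 * 3 * 3; ++i) {
        float v = cond[broadcastIndex(i, outShape, cS)]
                      ? x[broadcastIndex(i, outShape, xS)]
                      : y[broadcastIndex(i, outShape, yS)];
        std::cout << v << ' ';
    }
    std::cout << '\n';
}

Switching the condition tensor to UInt8 (and accepting vector{false, true, ...} in the test) matches ONNX's boolean condition input and avoids reinterpreting int32 data as floats in the kernel.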