forked from jiuyuan/InfiniTensor
Merge branch 'master' into xpu
commit 33cae7fc41

Makefile: 4 changes
@@ -56,6 +56,10 @@ test-onnx:
 	@echo
 	python3 pyinfinitensor/tests/test_onnx.py

+test-api:
+	@echo
+	python3 pyinfinitensor/tests/test_api.py
+
 docker-build:
 	docker build -f scripts/dockerfile/$(DOCKER_FILE) -t $(DOCKER_NAME) .
@@ -4,8 +4,12 @@ import time
 import multiprocessing as mp
 from pyinfinitensor.onnx import OnnxStub, backend
 import onnx
+from onnx.external_data_helper import convert_model_to_external_data
 import numpy as np
-from parallel import parallel_model
+from parallel_opt import parallel_model


+os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
+
+
 def parse_args():

@@ -14,77 +18,126 @@ def parse_args():
     parser.add_argument(
         "--nproc_per_node", type=int, default=1, help="number of processes per node"
     )
+    parser.add_argument(
+        "--name", type=str, default="test", help="name of this instance."
+    )
     parser.add_argument(
         "--model", type=str, required=True, help="path to the ONNX model file."
     )
+    parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
+    parser.add_argument("--length", type=int, default=1, help="sequence length.")
+    parser.add_argument(
+        "--gen_std",
+        action="store_true",
+        help="whether to generate the standard results.",
+    )
     args = parser.parse_args()
     print("arg setting: ", args)
-    return args.num_nodes, args.nproc_per_node, args.model
+    return (
+        args.num_nodes,
+        args.nproc_per_node,
+        args.name,
+        args.model,
+        args.batch_size,
+        args.length,
+        args.gen_std,
+    )


-def run_stub(stub: OnnxStub, inputs: np.array, n=100):
-    # warm up
-    next(stub.inputs.items().__iter__())[1].copyin_float(inputs.reshape(-1).tolist())
+def run_model(model, runtime, inputs: np.array, n=20):
+    stub = OnnxStub(model, runtime)
+    next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs)
     stub.tune()
-    for _ in range(20):
-        stub.run()
+    stub.run()
+    # get outputs
+    outputs = np.array(next(stub.outputs.items().__iter__())[1].copyout_float())

     # bench
-    next(stub.inputs.items().__iter__())[1].copyin_float(inputs.reshape(-1).tolist())
+    next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs)
     begin = time.time()
     for _ in range(n):
         stub.run()
     end = time.time()
-    outputs = np.array(next(stub.outputs.items().__iter__())[1].copyout_float())
-    print("outputs sum:", outputs.sum())
-    # np.save("results", outputs)
-    results = np.load("results.npy")
-    print("max diff:", abs(outputs - results).max())
-    assert np.allclose(outputs, results, rtol=1e-6, atol=1e-6)
     avg_time = (end - begin) / n
-    return avg_time
+    print(f"average time: {avg_time}")
+    return outputs


+def run_and_compare(name, model, runtime):
+    data = np.load(f"{name}_inputs.npy")
+    results = np.load(f"{name}_results.npy")
+    outputs = run_model(model, runtime, data)
+    print("outputs sum:", outputs.sum())
+    print("max abs diff:", abs(outputs - results).max())
+    print("max rel diff:", abs((outputs - results) / results).max())
+    # assert np.allclose(outputs, results, rtol=1e-3, atol=1e-6)


 def start_worker(
-    dist_name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
+    name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
 ):
-    print("start worker")
+    dist_name = name + "_dist"
+    model = parallel_model(model, world_size, rank)
+    extern_path = f"./{dist_name}_rank{rank}.pb"
+    if os.path.exists(extern_path):
+        os.remove(extern_path)
+    convert_model_to_external_data(
+        model,
+        all_tensors_to_one_file=True,
+        location=extern_path,
+        size_threshold=1024,
+        convert_attribute=False,
+    )
+    onnx.save(model, f"./{dist_name}_rank{rank}.onnx")
     runtime = backend.CudaRuntime(local_rank)
-    print("init comm")
+    # print("init comm")
     runtime.init_comm(
         dist_name,
         world_size,
         rank,
     )
-    model = parallel_model(model, world_size, rank)
-    onnx.save(model, f"dist_model_rank{rank}.onnx")
-    print("load model")
-    stub = OnnxStub(model, runtime)
-    data = np.load("inputs.npy")
-    print("run model")
-    avg_time = run_stub(stub, data)
-    print(f"average time: {avg_time}")
+    run_and_compare(name, model, runtime)


+def start_single(name, model):
+    runtime = backend.CudaRuntime(0)
+    run_and_compare(name, model, runtime)


+def gen_standard(name, model, voc_size, bs, len):
+    # generate standard results
+    data = np.random.randint(0, voc_size, (bs, len), dtype=np.int32)
+    np.save(f"{name}_inputs", data)
+    runtime = backend.CudaRuntime(0)
+    outputs = run_model(model, runtime, data, 1)
+    np.save(f"{name}_results", outputs)


 def main():
-    nnodes, nproc_per_node, model_path = parse_args()
-    world_size = nnodes * nproc_per_node
+    nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args()

     model = onnx.load(model_path)
-    # generate standard results
-    # runtime = backend.CudaRuntime(0)
-    # stub = OnnxStub(model, runtime)
-    # data = np.random.randn(1, 3, 224, 224)
-    # np.save("inputs", data)
-    # run_stub(stub, data)
-    # del stub

-    dist_name = f"dist_{os.getpid()}"
+    # generate standard output
+    if gen_std:
+        print(f"generate standard data for {name}.")
+        # a small vocabulary size to fit all LLMs.
+        voc_size = 1000
+        gen_standard(name, model, voc_size, bs, length)
+        return

+    # run single process.
+    # use standalone process to isolate cuda.
+    p = mp.Process(target=start_single, args=(name, model))
+    p.start()
+    p.join()

+    # run distributed parallel.
+    world_size = nnodes * nproc_per_node
     workers = [
         mp.Process(
             target=start_worker,
-            args=(dist_name, world_size, rank, rank % nproc_per_node, model),
+            args=(name, world_size, rank, rank % nproc_per_node, model),
         )
         for rank in range(world_size)
     ]
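A note on the workflow this launcher sets up: the script is first run with --gen_std to record single-GPU reference outputs, then run again without it to check the single-process and sharded runs against those references. A minimal sketch of that two-step invocation, assuming the script is saved as launch.py and a gpt2.onnx model exists (both names are placeholders, not files in this commit):

# Hypothetical driver for the launcher above; "launch.py", "gpt2.onnx" and
# "demo" are assumed names, not artifacts shipped with this commit.
import subprocess

# Step 1: record single-GPU reference outputs
# (writes demo_inputs.npy / demo_results.npy next to the script).
subprocess.run(
    ["python", "launch.py", "--model", "gpt2.onnx", "--name", "demo",
     "--batch_size", "1", "--length", "8", "--gen_std"],
    check=True,
)

# Step 2: compare a single-process run and a 2-way tensor-parallel run
# against the recorded results.
subprocess.run(
    ["python", "launch.py", "--model", "gpt2.onnx", "--name", "demo",
     "--batch_size", "1", "--length", "8", "--nproc_per_node", "2"],
    check=True,
)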
@@ -0,0 +1,245 @@
import argparse
import os
import time
import multiprocessing as mp
from pyinfinitensor.onnx import OnnxStub, backend
import onnx
from onnx.external_data_helper import convert_model_to_external_data
import numpy as np
from parallel_opt import parallel_model


os.environ["NVIDIA_TF32_OVERRIDE"] = "0"


def parse_args():
    parser = argparse.ArgumentParser(description="launch distributed infinitensor")
    parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
    parser.add_argument(
        "--nproc_per_node", type=int, default=1, help="number of processes per node"
    )
    parser.add_argument(
        "--name", type=str, default="test", help="name of this instance."
    )
    parser.add_argument(
        "--model1", type=str, required=True, help="path to the ONNX model file."
    )
    parser.add_argument(
        "--model2", type=str, required=True, help="path to the ONNX model file."
    )
    parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
    parser.add_argument("--length", type=int, default=1, help="sequence length.")
    parser.add_argument(
        "--gen_std",
        action="store_true",
        help="whether to generate the standard results.",
    )
    args = parser.parse_args()
    print("arg setting: ", args)
    return (
        args.num_nodes,
        args.nproc_per_node,
        args.name,
        args.model1,
        args.model2,
        args.batch_size,
        args.length,
        args.gen_std,
    )


def run_model(model1, model2, runtime1, runtime2, inputs1: np.array, inputs2: np.array, n=20):
    ####################################
    # run the first graph without kvcache
    ####################################
    stub1 = OnnxStub(model1, runtime1)
    stub1.inputs['onnx::Reshape_0'].copyin_int32(inputs1.reshape(-1).tolist())
    stub1.tune()
    stub1.run()
    kvcache_it1 = []
    count = 0
    for output in stub1.outputs.items().__iter__():
        if count == 0:
            logits_it1 = np.array(output[1].copyout_float(), dtype=np.float32)
        else:
            kvcache_it1.append(np.array(output[1].copyout_float(), dtype=np.float32))
        count = count + 1

    # bench for stub1
    next(stub1.inputs.items().__iter__())[1].copyin_int32(inputs1.reshape(-1).tolist())
    begin = time.time()
    for _ in range(n):
        stub1.run()
    end = time.time()
    avg_time = (end - begin) / n
    print(f"stub1 average time: {avg_time}")

    ####################################
    # run the second graph with kvcache
    ####################################
    i = 0
    batchsize = 1
    stub2 = OnnxStub(model2, runtime2)
    past_kvcache_length = (i + 2) * np.ones((batchsize, 1), dtype=np.int32)
    # copyin input
    stub2.inputs['onnx::Reshape_0'].copyin_int32(inputs2.reshape(-1).tolist())
    stub2.inputs['input.3'].copyin_int32(past_kvcache_length.reshape(-1).tolist())
    count = -1
    for input in stub2.inputs.items().__iter__():
        if count in range(24):
            # print(count, input[0])
            # print(np.dtype(kvcache_it1[count][0]), kvcache_it1[count].shape)
            input[1].copyin_float(kvcache_it1[count].reshape(-1).tolist())
        count = count + 1
    stub2.tune()
    stub2.run()

    # copyout output
    count = 0
    kvcache_it2 = []
    for output in stub2.outputs.items().__iter__():
        if count == 0:
            logits_it2 = np.array(output[1].copyout_float(), dtype=np.float32)
        else:
            kvcache_it2.append(np.array(output[1].copyout_float(), dtype=np.float32))
        count = count + 1

    # bench for stub2
    # copyin input
    stub2.inputs['onnx::Reshape_0'].copyin_int32(inputs2.reshape(-1).tolist())
    stub2.inputs['input.3'].copyin_int32(past_kvcache_length.reshape(-1).tolist())
    count = -1
    for input in stub2.inputs.items().__iter__():
        if count in range(24):
            input[1].copyin_float(kvcache_it1[count].reshape(-1).tolist())
        count = count + 1
    begin = time.time()
    for _ in range(n):
        stub2.run()
    end = time.time()
    avg_time = (end - begin) / n
    print(f"stub2 average time: {avg_time}")
    return logits_it2


def run_and_compare(name, model1, model2, runtime1, runtime2):
    data1 = np.load(f"{name}_inputs1.npy")
    data2 = np.load(f"{name}_inputs2.npy")
    results = np.load(f"{name}_results.npy")
    outputs = run_model(model1, model2, runtime1, runtime2, data1, data2)
    print("outputs sum:", outputs.sum())
    print("max abs diff:", abs(outputs - results).max())
    print("max rel diff:", abs((outputs - results) / results).max())
    # assert np.allclose(outputs, results, rtol=1e-3, atol=1e-6)


def start_worker(
    name: str, world_size: int, rank: int, local_rank: int, model1: onnx.ModelProto, model2: onnx.ModelProto
):
    dist_name = name + "_dist"
    ####################################
    # shard the first graph
    ####################################
    model1 = parallel_model(model1, world_size, rank)
    extern_path = f"./{dist_name}_stub1_rank{rank}.pb"
    if os.path.exists(extern_path):
        os.remove(extern_path)
    convert_model_to_external_data(
        model1,
        all_tensors_to_one_file=True,
        location=extern_path,
        size_threshold=1024,
        convert_attribute=False,
    )
    onnx.save(model1, f"./{dist_name}_stub1_rank{rank}.onnx")
    runtime1 = backend.CudaRuntime(local_rank)
    runtime1.init_comm(
        dist_name,
        world_size,
        rank,
    )

    ####################################
    # shard the second graph
    ####################################
    model2 = parallel_model(model2, world_size, rank)
    extern_path = f"./{dist_name}_stub2_rank{rank}.pb"
    if os.path.exists(extern_path):
        os.remove(extern_path)
    convert_model_to_external_data(
        model2,
        all_tensors_to_one_file=True,
        location=extern_path,
        size_threshold=1024,
        convert_attribute=False,
    )
    onnx.save(model2, f"./{dist_name}_stub2_rank{rank}.onnx")
    runtime2 = backend.CudaRuntime(local_rank)
    # print("init comm")
    runtime2.init_comm(
        dist_name,
        world_size,
        rank,
    )

    # run the two graphs
    run_and_compare(name, model1, model2, runtime1, runtime2)


def start_single(name, model1, model2):
    runtime1 = backend.CudaRuntime(0)
    runtime2 = backend.CudaRuntime(0)
    run_and_compare(name, model1, model2, runtime1, runtime2)


def gen_standard(name, model1, model2, voc_size, bs, len):
    # generate standard results
    data1 = np.random.randint(0, voc_size, (bs, len), dtype=np.int32)
    data2 = np.random.randint(0, voc_size, (bs, len), dtype=np.int32)
    np.save(f"{name}_inputs1", data1)
    np.save(f"{name}_inputs2", data2)
    runtime1 = backend.CudaRuntime(0)
    runtime2 = backend.CudaRuntime(0)
    outputs = run_model(model1, model2, runtime1, runtime2, data1, data2, 1)
    np.save(f"{name}_results", outputs)


def main():
    nnodes, nproc_per_node, name, model1_path, model2_path, bs, length, gen_std = parse_args()

    model1 = onnx.load(model1_path)
    model2 = onnx.load(model2_path)

    # generate standard output
    if gen_std:
        print(f"generate standard data for {name}.")
        # a small vocabulary size to fit all LLMs.
        voc_size = 1000
        gen_standard(name, model1, model2, voc_size, bs, length)
        return

    # run single process.
    # use standalone process to isolate cuda.
    p = mp.Process(target=start_single, args=(name, model1, model2))
    p.start()
    p.join()

    # run distributed parallel.
    world_size = nnodes * nproc_per_node
    workers = [
        mp.Process(
            target=start_worker,
            args=(name, world_size, rank, rank % nproc_per_node, model1, model2),
        )
        for rank in range(world_size)
    ]

    for w in workers:
        w.start()

    for w in workers:
        w.join()


if __name__ == "__main__":
    main()
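For readers unfamiliar with the two-graph setup above: graph 1 is the prefill pass that produces the initial KV cache, and graph 2 is the decode pass that consumes those caches plus a past-length scalar ("input.3"). A rough numpy-only sketch of that data flow; the 24-layer count mirrors the script, but all shapes and the random stand-in functions are illustrative assumptions, not values read from the models:

import numpy as np

# Toy stand-ins for the two ONNX graphs: "prefill" emits logits plus one
# KV-cache tensor per layer; "decode" consumes those caches and a
# past-length scalar. 24 layers and the shapes below are illustrative.
n_layers, batch, heads, seq, dim = 24, 1, 4, 8, 16

def prefill(tokens):
    logits = np.random.randn(batch, tokens.shape[1], 100).astype(np.float32)
    kvcache = [np.random.randn(batch, heads, seq, dim).astype(np.float32)
               for _ in range(n_layers)]
    return logits, kvcache

def decode(token, past_len, kvcache):
    # a real decode graph would attend over the cached keys/values
    assert len(kvcache) == n_layers and past_len.shape == (batch, 1)
    return np.random.randn(batch, 1, 100).astype(np.float32)

tokens = np.random.randint(0, 100, (batch, seq), dtype=np.int32)
logits1, cache = prefill(tokens)                  # stub1 in the script
past = np.full((batch, 1), seq, dtype=np.int32)   # "input.3" in the script
next_token = np.random.randint(0, 100, (batch, 1), dtype=np.int32)
logits2 = decode(next_token, past, cache)         # stub2 in the script
print(logits2.shape)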
@@ -0,0 +1,237 @@
import onnx
from onnx import ModelProto, NodeProto, TensorProto, ValueInfoProto
from onnx import helper, numpy_helper
from typing import Dict, List
from placement import Placement, Replicate, Shard, _Partial
import numpy as np


def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
    data = {init.name: init for init in model.graph.initializer}
    vinfo = {info.name: info for info in model.graph.value_info}
    vinfo.update({info.name: info for info in model.graph.input})
    vinfo.update({info.name: info for info in model.graph.output})
    place: Dict[str, Placement] = {}
    nodes: List[NodeProto] = []

    def is_sharded(name: str):
        return place[name].is_shard()

    def shard_tensor(tensor: TensorProto, plc: Shard, groups: int = 1):
        # print(f"shard {tensor.name} at dim {dim}")
        assert plc.is_shard(), plc
        ndim = len(tensor.dims)
        if plc.dim < 0:
            plc.dim += ndim
        if tensor.dims[plc.dim] == 1:  # broadcast dim, no need to shard.
            return tensor
        array = numpy_helper.to_array(tensor)
        assert array.shape[plc.dim] % tp_world_size == 0, array.shape[plc.dim]
        dims = list(tensor.dims)
        dims.insert(plc.dim, groups)
        dims[plc.dim + 1] //= groups
        array = array.reshape(dims)
        seg = array.shape[plc.dim + 1] // tp_world_size
        array = array.take(
            indices=range(tp_rank * seg, (tp_rank + 1) * seg), axis=plc.dim + 1
        )
        dims = list(tensor.dims)
        dims[plc.dim] //= tp_world_size
        array = array.reshape(dims)
        tensor = numpy_helper.from_array(array, name=tensor.name)
        place[tensor.name] = plc
        return tensor

    def shard_gemm(node: NodeProto, groups: int = 1):
        # print("gemm", node.name)
        in_plc = place[node.input[0]]
        w_plc = Shard(-1) if in_plc.is_replicate() else Shard(0)
        transB = next((attr.i for attr in node.attribute if attr.name == "transB"), 0)
        if transB:
            w_plc.dim = ~w_plc.dim
        input = node.input[1]
        data[input] = shard_tensor(data[input], w_plc, groups)

        output = node.output[0]
        ndim = len(vinfo[output].type.tensor_type.shape.dim)
        out_plc = Shard(ndim - 1) if in_plc.is_replicate() else _Partial()
        place[node.output[0]] = out_plc

    def shard_concat(node: NodeProto):
        # hack for kvcache
        in_plc = place[node.input[1]]
        if in_plc.is_shard():
            seq_len_dim = vinfo[node.input[0]].type.tensor_type.shape.dim.pop(1)
            seq_len_dim.dim_value //= tp_world_size
            vinfo[node.input[0]].type.tensor_type.shape.dim.insert(1, seq_len_dim)
            place[node.input[0]] = in_plc
        place[node.output[0]] = in_plc

    def shard_binary(node: NodeProto, groups: int = 1):
        # print("binary", node.name, node.input[0], place[node.input[0]])
        a = node.input[0]
        b = node.input[1]
        if a in data:
            a, b = b, a
        place[node.output[0]] = place[a]
        if is_sharded(a) and b in data and len(data[b].dims) == 1:  # broadcast
            data[b] = shard_tensor(data[b], Shard(0), groups)

    def shard_reshape(node: NodeProto):
        # print("reshape", node.name, node.input[0], place[node.input[0]])
        if not is_sharded(node.input[0]):
            return
        in_plc = place[node.input[0]]
        s_dim = -1
        in_dims = [d.dim_value for d in vinfo[node.input[0]].type.tensor_type.shape.dim]
        tensor = data[node.input[1]]
        out_dims = numpy_helper.to_array(tensor).copy()
        if len(in_dims) == 3 and len(out_dims) == 4:
            if in_plc.dim == 0:
                s_dim = 1
            elif in_plc.dim == 2:
                s_dim = 2
        if len(in_dims) == 4 and len(out_dims) == 3:
            if in_plc.dim == 1:
                s_dim = 0
            elif in_plc.dim == 2:
                s_dim = 2
        if len(in_dims) == 2 and len(out_dims) == 3:
            if in_plc.dim == 1:
                s_dim = 2
        if len(in_dims) == 4 and len(out_dims) == 2:
            if in_plc.dim == 1:
                s_dim = 0
            elif in_plc.dim == 2:
                s_dim = 1
        if len(in_dims) == 3 and len(out_dims) == 2:
            if in_plc.dim == 1:
                s_dim = 0
            elif in_plc.dim == 2:
                s_dim = 1

        assert s_dim != -1
        assert out_dims[s_dim] % tp_world_size == 0, out_dims
        out_dims[s_dim] //= tp_world_size
        # if ONNX uses the same tensor for multiple Reshape nodes, then rename it to distinguish it from others.
        # node.input[1] = node.output[0] + "_shape"
        data[node.input[1]] = numpy_helper.from_array(out_dims, name=node.input[1])
        place[node.output[0]] = Shard(s_dim)

    def shard_split(node: NodeProto):
        if not is_sharded(node.input[0]):
            return
        in_plc = place[node.input[0]]
        split_tensor = data[node.input[1]]
        split = numpy_helper.to_array(split_tensor).copy()
        split //= tp_world_size
        data[node.input[1]] = numpy_helper.from_array(split, name=node.input[1])
        for output in node.output:
            place[output] = in_plc

    def shard_transpose(node: NodeProto):
        plc = place[node.input[0]]
        if plc.is_shard():
            perm = next(attr.ints for attr in node.attribute if attr.name == "perm")
            place[node.output[0]] = Shard(list(perm).index(plc.dim))

    def shard_node(node: NodeProto):
        if node.op_type in ["Relu", "Tanh", "Softmax"]:
            place[node.output[0]] = place[node.input[0]]
        elif node.op_type in ["Where"]:
            place[node.output[0]] = place[node.input[1]]
        if node.op_type in {"Add", "Mul", "Div", "Max"}:
            shard_binary(node)
        elif node.op_type == "Reshape":
            shard_reshape(node)
        elif node.op_type == "Transpose":
            shard_transpose(node)
        elif node.op_type == "Split":
            shard_split(node)
        elif node.op_type == "MatMul":
            assert (
                place[node.input[0]] == place[node.input[1]]
            ), f"{place[node.input[0]]} != {place[node.input[1]]}"
            place[node.output[0]] = place[node.input[0]]
        elif node.op_type == "Concat":
            shard_concat(node)

    def find_successor(op_type: str, idx: int, search_limit: int = 1):
        for node in model.graph.node[idx + 1 : idx + 1 + search_limit]:
            if node.op_type == op_type:
                return node
        return None

    # all tensors are initially replicated.
    for v in vinfo:
        place[v] = Replicate()

    for t in data:
        place[t] = Replicate()

    for index, node in enumerate(model.graph.node):
        nodes.append(node)
        # linear
        if (node.op_type == "MatMul" or node.op_type == "Gemm") and any(
            input in data for input in node.input
        ):
            groups = 1
            # if the Gemm or MatMul is followed by a Split, then the inputs are concatenated by groups
            split_node = find_successor("Split", index, search_limit=2)
            if split_node is not None:
                groups = len(split_node.output)
            shard_gemm(node, groups)
            plc = place[node.output[0]]
            if plc.is_partial():
                new_name = node.output[0] + f":{plc}"
                place[new_name] = place[node.output[0]]
                # insert all_reduce
                nodes.append(
                    helper.make_node(
                        op_type="ReduceSum",
                        inputs=[new_name],
                        outputs=[node.output[0]],
                        name=node.name + "/all_reduce",
                        noop_with_empty_axes=1,
                        communicator=0,  # hack to treat ReduceSum as AllReduceSum
                    )
                )
                place[node.output[0]] = Replicate()
                node.output[0] = new_name
            if len(node.input) > 2:  # split bias to add
                prev = nodes[-1]
                new_name = prev.output[0] + "_no_bias"
                place[new_name] = place[node.output[0]]
                bias = helper.make_node(
                    op_type="Add",
                    inputs=[new_name, node.input[2]],
                    outputs=[prev.output[0]],
                    name=node.name + "/bias",
                )
                node.input.pop()
                prev.output[0] = new_name
                shard_binary(bias, groups)
                nodes.append(bias)
            continue
        shard_node(node)

    new_input = []
    for info in model.graph.input:
        new_input.append(vinfo[info.name])

    graph = helper.make_graph(
        nodes,
        model.graph.name + f"_{tp_rank}",
        new_input,
        model.graph.output,
        data.values(),
        doc_string=model.graph.doc_string,
        # value_info=vinfo.values(),
    )
    for output in graph.output:
        tt = output.type.tensor_type
        if tt.HasField("shape"):
            tt.ClearField("shape")
    model = helper.make_model(graph)
    model = onnx.shape_inference.infer_shapes(model)
    return model
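To make the grouped sharding in shard_tensor above concrete, here is a standalone numpy replay of the same insert/take/reshape sequence. The 2x12 "fused QKV" weight is a made-up example; "groups" models a fused projection that a later Split separates, so each rank must receive the matching slice of every group rather than one contiguous chunk:

import numpy as np

# Numpy replay of shard_tensor's grouped split (illustrative, not the
# module itself). Splits dim into (groups, dim/groups), takes this rank's
# segment of each group, then merges the two axes back together.
def shard(array, dim, tp_world_size, tp_rank, groups=1):
    assert array.shape[dim] % (tp_world_size * groups) == 0
    dims = list(array.shape)
    dims[dim : dim + 1] = [groups, dims[dim] // groups]  # split out groups
    array = array.reshape(dims)
    seg = array.shape[dim + 1] // tp_world_size
    array = array.take(range(tp_rank * seg, (tp_rank + 1) * seg), axis=dim + 1)
    out = list(array.shape)
    out[dim : dim + 2] = [out[dim] * out[dim + 1]]       # merge back
    return array.reshape(out)

w = np.arange(24).reshape(2, 12)  # e.g. fused QKV: 3 groups of 4 columns
print(shard(w, dim=1, tp_world_size=2, tp_rank=0, groups=3))
# rank 0 gets columns [0 1 | 4 5 | 8 9]: the first half of each group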
@@ -0,0 +1,64 @@
from typing import Optional


class Placement:
    # base class Placement type

    # convenient utils to check for placement types
    def is_shard(self, dim: Optional[int] = None) -> bool:
        if dim is not None and isinstance(self, Shard):
            return self.dim == dim
        else:
            return isinstance(self, Shard)

    def is_replicate(self) -> bool:
        return isinstance(self, Replicate)

    def is_partial(self) -> bool:
        return isinstance(self, _Partial)


class Replicate(Placement):
    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Replicate):
            return False
        return True

    def __repr__(self) -> str:
        """
        machine readable representation of the Replicate placement
        """
        return "Replicate()"


class Shard(Placement):
    # shard placement, shard on a dim
    def __init__(self, dim):
        self.dim = dim

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Shard):
            return False
        return self.dim == other.dim

    def __repr__(self) -> str:
        """
        machine readable representation of the Shard placement
        """
        return f"Shard(dim={self.dim})"


class _Partial(Placement):
    def __init__(self, reduce_op: str = "sum"):
        self.reduce_op: str = reduce_op

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, _Partial):
            return False
        return self.reduce_op == other.reduce_op

    def __repr__(self) -> str:
        """
        machine readable representation of the Partial placement
        """
        return f"_Partial(reduce_op={self.reduce_op})"
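These placement classes mirror the usual tensor-parallel vocabulary: replicated everywhere, sharded along one dim, or "partial" (holding a per-rank partial sum that still needs a reduce). A small sketch of how parallel_opt.py consults them; the tensor names are invented for illustration:

from placement import Replicate, Shard, _Partial

# Mimics how parallel_opt.py tracks layouts: every tensor starts Replicate;
# a matmul against a row-sharded weight yields a _Partial output that stays
# partial until an inserted all-reduce restores replication.
place = {"act": Replicate(), "w_col": Shard(-1), "w_row": Shard(0)}

place["y"] = Shard(1) if place["act"].is_replicate() else _Partial()
assert place["w_col"].is_shard() and not place["w_col"].is_shard(dim=0)
assert _Partial() == _Partial("sum")  # equality compares the reduce op
print(place["y"], repr(Shard(dim=1)))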
@@ -0,0 +1,29 @@
import sys
import onnx
import torch
import numpy as np
from pyinfinitensor.onnx import OnnxStub, backend

if __name__ == '__main__':
    args = sys.argv
    if len(sys.argv) != 2:
        print("Usage: python onnx_inference.py model_name.onnx")
        exit()
    model_path = sys.argv[1]
    # print(model_path)

    onnx_model = onnx.load(model_path)
    onnx_input = onnx_model.graph.input[0]
    input_shape = [[d.dim_value for d in _input.type.tensor_type.shape.dim]
                   for _input in onnx_model.graph.input]
    # Assume that there is only one input tensor
    input_shape = input_shape[0]
    # print(input_shape)
    input_data = np.random.random(input_shape).astype(np.float32)

    model = OnnxStub(onnx_model, backend.cuda_runtime())
    next(iter(model.inputs.values())).copyin_numpy(input_data)
    model.run()
    outputs = next(iter(model.outputs.values())).copyout_numpy()
    outputs = torch.tensor(outputs)
    print(outputs.shape)
@@ -0,0 +1,24 @@
import sys
import onnx
import torch
import numpy as np
from pyinfinitensor.onnx import OnnxStub, backend
import torchvision.models as models

if __name__ == '__main__':
    model_path = './resnet18.onnx'
    tv_model = models.resnet50(weights=None)
    input_shape = (1, 3, 224, 224)
    param = torch.rand(input_shape)
    torch.onnx.export(tv_model, param, model_path, verbose=False)

    onnx_model = onnx.load(model_path)
    model = OnnxStub(onnx_model, backend.cuda_runtime())
    images = np.random.random(input_shape).astype(np.float32)
    next(iter(model.inputs.values())).copyin_numpy(images)
    model.run()
    outputs = next(iter(model.outputs.values())).copyout_numpy()
    outputs = torch.tensor(outputs)
    outputs = torch.reshape(outputs, (1, 1000))
    _, predicted = torch.max(outputs, 1)
    print(predicted)
@@ -67,6 +67,10 @@ class BangRuntimeObj : public RuntimeObj {
                                   CNRT_MEM_TRANS_DIR_PEER2PEER));
     }

+    void initComm(const string &, int, int) override { IT_TODO_HALT(); }
+
+    CommunicatorObj &getCommunicator() const override { IT_TODO_HALT(); }
+
   private:
     void runWithoutSync(const Graph &graph, bool tune, bool profiling) const;
 };
@@ -40,12 +40,12 @@ using HashType = uint64_t; // compatible with std::hash

 // Assert: conditions should have no side effect
 #define _IT_ASSERT_2(condition, info)                                          \
-    (static_cast<bool>(condition)                                              \
-         ? void(0)                                                             \
-         : throw ::infini::Exception(                                          \
-               std::string("[") + __FILE__ + ":" + std::to_string(__LINE__) +  \
-               "] Assertion failed (" + #condition + "): " + info))
-#define _IT_ASSERT_1(condition) _IT_ASSERT_2(condition, "");
+    static_cast<bool>(condition)                                               \
+        ? void(0)                                                              \
+        : throw ::infini::Exception(                                           \
+              std::string("[") + __FILE__ + ":" + std::to_string(__LINE__) +   \
+              "] Assertion failed (" + #condition + "): " + info)
+#define _IT_ASSERT_1(condition) _IT_ASSERT_2(condition, "")
 #define IT_ASSERT(...) _VA_SELECT(_IT_ASSERT, __VA_ARGS__)

 #define IT_TODO_HALT() _IT_ASSERT_2(false, "Unimplemented")
@@ -120,6 +120,11 @@ class GraphObj : public Object {
      * @brief If the nodes is sorted in topological order.
      */
     bool sorted;
+
+    /**
+     * @brief If the weight tensors are allocated.
+     */
+    bool weightAllocated = false;
 };

 } // namespace infini
@@ -20,14 +20,23 @@ class LazyAllocator {

     Runtime runtime;

-    size_t used;
+    size_t used = 0;

-    size_t peak;
+    size_t peak = 0;

+    size_t weightPeak = 0;
+
     size_t alignment;

     // pointer to the memory actually allocated
-    void *ptr;
+    void *ptr = nullptr;

+    // pointer to the weight memory space
+    void *weightPtr = nullptr;
+
     // // a cache designed for a batch size that has already occurred
     // std::unordered_map<size_t, std::unordered_map<TensorObj *, size_t>>
     //     batchsizeToTensorOffset;

     struct freeBlockInfo {
         size_t addr;

@@ -57,12 +66,16 @@ class LazyAllocator {

     virtual ~LazyAllocator();

+    void init();
+
     // function: simulate memory allocation
     // arguments:
     //     size: size of memory block to be allocated
     // return: head address offset of the allocated memory block
     size_t alloc(size_t size);

+    size_t allocWeight(size_t size);
+
     // function: simulate memory free
     // arguments:
     //     addr: head address offset of memory block to be free

@@ -73,6 +86,12 @@ class LazyAllocator {
     // return: pointer to the head address of the allocated memory
     void *getPtr();

+    // void addCache(size_t batchsize, std::unordered_map<TensorObj *, size_t>);
+
+    // std::unordered_map<TensorObj *, size_t> getCache(size_t batchsize);
+
+    void *getWeightPtr();
+
     void info();

   private:
@@ -1,5 +1,6 @@
 #pragma once
 #include "core/tensor_base.h"
+#include "core/tensor_type.h"
 #include "utils/data_convert.h"
 #include <cmath>
 #include <cstring>

@@ -19,6 +20,8 @@ class TensorObj : public TensorBaseObj {
     size_t _size; // Cache of Π(shape).
     Fuid fuid;    // Cloned tensors share the same id. Tensors constructed from
                   // scratch have a new id.
+    TensorType tensorType = TensorType::others;
+
   public:
     TensorObj(Shape shape, DataType dtype, Runtime runtime);
     virtual ~TensorObj() {}
@@ -33,6 +36,33 @@ class TensorObj : public TensorBaseObj {
     size_t getOffset(const vector<int> &ds) const;
     void dataMalloc();
     UidBaseType getFuid() const { return fuid; }
+    bool isWeight() const { return tensorType == TensorType::weight; }
+    bool isInput() const { return tensorType == TensorType::input; }
+    bool isOutput() const { return tensorType == TensorType::output; }
+    bool isOthers() const { return tensorType == TensorType::others; }
+    void setWeight() { tensorType = TensorType::weight; }
+    void setInput() { tensorType = TensorType::input; }
+    void setOutput() { tensorType = TensorType::output; }
+    string tensorTypeToString() const {
+        switch (tensorType) {
+        case TensorType::weight:
+            return "weight";
+            break;
+        case TensorType::input:
+            return "input";
+            break;
+        case TensorType::output:
+            return "output";
+            break;
+        case TensorType::others:
+            return "others";
+            break;
+
+        default:
+            return "unknown tensor type";
+            break;
+        }
+    }

     void load(std::string file_path);
     void save(std::string file_path);
@@ -44,6 +44,7 @@ class TensorBaseObj : public Object {
     }

     DataType getDType() const { return dtype; }
+    int getDTypeIndex() const { return dtype.getIndex(); }
     Runtime getRuntime() const { return runtime; }

     // std::pair<Operator *, int> getOutputOfWithIndex();
@@ -0,0 +1,7 @@
#pragma once

namespace infini {

enum class TensorType { weight, input, output, others };

} // namespace infini
@@ -6,16 +6,11 @@
 #include <cudnn.h>
 #include <curand.h>

-// TODO: replace with Exception (IT_ASSERT)
 #define checkCudaError(call)                                                   \
-    {                                                                          \
-        auto err = call;                                                       \
-        if (cudaSuccess != err) {                                              \
-            fprintf(stderr, "Cuda error in %s:%i : %s.\n", __FILE__, __LINE__, \
-                    cudaGetErrorString(err));                                  \
-            exit(EXIT_FAILURE);                                                \
-        }                                                                      \
-    }
+    if (auto err = call; err != cudaSuccess)                                   \
+    throw ::infini::Exception(std::string("[") + __FILE__ + ":" +              \
+                              std::to_string(__LINE__) + "] CUDA error (" +    \
+                              #call + "): " + cudaGetErrorString(err))

 #define checkCUresult(call)                                                    \
     {                                                                          \

@@ -39,14 +34,10 @@
     }

 #define checkCudnnError(call)                                                  \
-    {                                                                          \
-        auto err = call;                                                       \
-        if (CUDNN_STATUS_SUCCESS != err) {                                     \
-            fprintf(stderr, "cuDNN error in %s:%i : %s.\n", __FILE__,          \
-                    __LINE__, cudnnGetErrorString(err));                       \
-            exit(EXIT_FAILURE);                                                \
-        }                                                                      \
-    }
+    if (auto err = call; err != CUDNN_STATUS_SUCCESS)                          \
+    throw ::infini::Exception(std::string("[") + __FILE__ + ":" +              \
+                              std::to_string(__LINE__) + "] cuDNN error (" +   \
+                              #call + "): " + cudnnGetErrorString(err))

 #define checkCurandError(call)                                                 \
     {                                                                          \
@@ -3,7 +3,7 @@
 #include "operators/unary.h"
 #include "utils/small_array.h"
 namespace infini {
-void expand_kernel(float *input, float *output, int nDims, int outputsize,
-                   SmallArray inputShape, SmallArray outputShape);
+void expandKernel(float *input, float *output, int nDims, int outputsize,
+                  SmallArray inputShape, SmallArray outputShape);

 }; // namespace infini
@@ -3,11 +3,9 @@
 #include "utils/small_array.h"

 namespace infini {
-void where_kernel(const float *inputx, const float *inputy,
-                  const float *condition, float *output, int nDims,
-                  infini::SmallArray inputxShape,
-                  infini::SmallArray inputyShape,
-                  infini::SmallArray conditionShape,
-                  infini::SmallArray outputShape);
+void whereKernel(const float *inputX, const float *inputY,
+                 const uint8_t *condition, float *output, int nDims,
+                 SmallArray inputXShape, SmallArray inputYShape,
+                 SmallArray conditionShape, SmallArray outputShape);

 }; // namespace infini
@@ -0,0 +1,14 @@
#pragma once

namespace infini {

void broadcastShape(const Shape &originShape, SmallArray &modifyShape,
                    int nDims, int size) {
    for (int i = nDims - 1; i >= 0; --i) {
        modifyShape.data[i] = 1;
    }
    for (int i = size - 1; i >= 0; --i) {
        modifyShape.data[i + nDims - size] = originShape[i];
    }
}

} // namespace infini
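This helper right-aligns a smaller shape within nDims slots and pads the front with 1s, which is exactly standard broadcasting alignment. A quick numpy check of the same rule (pure illustration, not part of the commit):

import numpy as np

# Same alignment rule as broadcastShape above: right-align originShape in an
# nDims-slot list and fill the leading slots with 1.
def broadcast_shape(origin_shape, n_dims):
    return [1] * (n_dims - len(origin_shape)) + list(origin_shape)

assert broadcast_shape([5, 4], 4) == [1, 1, 5, 4]
# numpy applies the identical padding when broadcasting:
assert (np.zeros((2, 3, 5, 4)) + np.zeros((5, 4))).shape == (2, 3, 5, 4)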
@@ -5,8 +5,18 @@
 namespace infini {

 class Exception : public std::runtime_error {
+  protected:
+    std::string info;
+
   public:
     Exception(const std::string &msg);
+
+    Exception &operator<<(const std::string &str) {
+        info += str;
+        return *this;
+    }
+
+    const char *what() const noexcept override { return info.c_str(); }
 };

 } // namespace infini
@@ -1,3 +1,4 @@
 #pragma once
 namespace infini {

+#define SMALL_ARRAY_SIZE 8
@@ -32,35 +32,37 @@ class OnnxStub:
     The Onnx model imported into infinitensor.
     It can be generated from an Onnx model object.
     """

-    inputs: Dict[str, backend.Tensor] = {}
-    outputs: Dict[str, backend.Tensor] = {}
-    initializer: Dict[int, TensorProto] = {}
-    handler: backend.GraphHandler
-
     def __init__(self, model: ModelProto, runtime):
+        self.inputs: Dict[str, backend.Tensor] = {}
+        self.outputs: Dict[str, backend.Tensor] = {}
+        self.initializer: Dict[int, TensorProto] = {}
         model = infer_shapes(model)
         self.handler = backend.GraphHandler(runtime)

         tensors: Dict[str, backend.Tensor] = dict()
         data: Dict[str, TensorProto] = dict()

+        for initializer in model.graph.initializer:
+            dims = [d for d in initializer.dims]
+            tensors[initializer.name] = self.handler.tensor(dims, initializer.data_type)
+            data[initializer.name] = initializer
+            tensors[initializer.name].set_weight()
+
         for input in model.graph.input:
             dims = _take_shape_dim(input.type.tensor_type.shape)
-            tensors[input.name] = self.handler.tensor(
-                dims, input.type.tensor_type.elem_type
-            )
+            if input.name not in tensors.keys():
+                tensors[input.name] = self.handler.tensor(
+                    dims, input.type.tensor_type.elem_type
+                )
+            tensors[input.name].set_input()

         for output in model.graph.output:
             dims = _take_shape_dim(output.type.tensor_type.shape)
             tensors[output.name] = self.handler.tensor(
                 dims, output.type.tensor_type.elem_type
             )
+            tensors[output.name].set_output()

-        for initializer in model.graph.initializer:
-            dims = [d for d in initializer.dims]
-            tensors[initializer.name] = self.handler.tensor(dims, initializer.data_type)
-            data[initializer.name] = initializer
-
         node_name = []
         new_node_name = []

@@ -591,6 +593,13 @@ class OnnxStub:
                     tensors.get(node.output[0]),
                     next((attr.i for attr in node.attribute if attr.name == "to")),
                 )
+            elif node.op_type == "ReduceSum":
+                # ReduceSum is only implemented as allReduceSum.
+                assert any(attr.name == "communicator" for attr in node.attribute)
+                tensors[node.output[0]] = self.handler.allReduceSum(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                )
             elif node.op_type == "AllReduceSum":
                 tensors[node.output[0]] = self.handler.allReduceSum(
                     tensors[node.input[0]],

@@ -631,13 +640,9 @@ class OnnxStub:
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     next(
-                        (
-                            attr.i
-                            for attr in node.attribute
-                            if attr.name == "root"
-                        ),
-                        0,
-                    ),
+                        (attr.i for attr in node.attribute if attr.name == "root"),
+                        0,
+                    ),
                 )
             elif node.op_type == "Expand":
                 shape = _parse_data(data[node.input[1]])

@@ -658,6 +663,15 @@ class OnnxStub:
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                 )
+            elif node.op_type == "Constant":
+                output_name = node.output[0]
+                attributes = _parse_attribute(node)
+                tensor = attributes['value']
+                dims = [d for d in tensor.dims]
+                tensors[output_name] = self.handler.tensor(
+                    dims, tensor.data_type)
+                data[output_name] = tensor
+                tensors[output_name].set_weight()
             else:
                 raise Exception('Unsupported operator "{}"'.format(node.op_type))
             new_node_name.append(node.name)

@@ -1062,19 +1076,18 @@ def _search_shape(model: ModelProto, name: str) -> List[int]:

 def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
     for attr in node.attribute:
-        if attr.name in attrs:
-            if attr.type == AttributeProto.INT:
-                attrs[attr.name] = attr.i
-            elif attr.type == AttributeProto.INTS:
-                attrs[attr.name] = attr.ints
-            elif attr.type == AttributeProto.FLOAT:
-                attrs[attr.name] = attr.f
-            elif attr.type == AttributeProto.STRING:
-                attrs[attr.name] = attr.s
-            elif attr.type == AttributeProto.TENSOR:
-                attrs[attr.name] = attr.t
-            else:
-                assert False, "Unsupported Attribute Type: {}".format(attr.type)
+        if attr.type == AttributeProto.INT:
+            attrs[attr.name] = attr.i
+        elif attr.type == AttributeProto.INTS:
+            attrs[attr.name] = attr.ints
+        elif attr.type == AttributeProto.FLOAT:
+            attrs[attr.name] = attr.f
+        elif attr.type == AttributeProto.STRING:
+            attrs[attr.name] = attr.s
+        elif attr.type == AttributeProto.TENSOR:
+            attrs[attr.name] = attr.t
+        else:
+            assert False, "Unsupported Attribute Type: {}".format(attr.type)
     return attrs
@@ -0,0 +1,65 @@
import os, unittest
from onnx import TensorProto
from pyinfinitensor import backend
import numpy as np


class TestPythonAPI(unittest.TestCase):
    def test_copyin_numpy(self):
        dims = [2, 3, 5, 4]
        np_array = np.random.random(dims).astype(np.float32)
        handler = backend.GraphHandler(backend.cpu_runtime())
        tensor1 = handler.tensor(dims, TensorProto.FLOAT)
        tensor2 = handler.tensor(dims, TensorProto.FLOAT)
        handler.data_malloc()
        tensor1.copyin_numpy(np_array)
        tensor2.copyin_float(np_array.flatten().tolist())
        array1 = tensor1.copyout_float()
        array2 = tensor2.copyout_float()
        self.assertEqual(array1, array2)
        self.assertTrue(np.array_equal(np.array(array1).reshape(dims), np_array))

        np_array = np.random.random(dims).astype(np.int64)
        handler = backend.GraphHandler(backend.cpu_runtime())
        tensor1 = handler.tensor(dims, TensorProto.INT64)
        tensor2 = handler.tensor(dims, TensorProto.INT64)
        handler.data_malloc()
        tensor1.copyin_numpy(np_array)
        tensor2.copyin_int64(np_array.flatten().tolist())
        array1 = tensor1.copyout_int64()
        array2 = tensor2.copyout_int64()
        self.assertEqual(array1, array2)
        self.assertTrue(np.array_equal(np.array(array1).reshape(dims), np_array))

    def test_copyout_numpy(self):
        dims = [2, 3, 5, 4]
        np_array = np.random.random(dims).astype(np.float32)
        handler = backend.GraphHandler(backend.cpu_runtime())
        tensor1 = handler.tensor(dims, TensorProto.FLOAT)
        tensor2 = handler.tensor(dims, TensorProto.FLOAT)
        handler.data_malloc()
        tensor1.copyin_float(np_array.flatten().tolist())
        tensor2.copyin_float(np_array.flatten().tolist())
        array1 = np.array(tensor1.copyout_float()).reshape(dims)
        array2 = tensor2.copyout_numpy()
        self.assertTrue(np.array_equal(array2, np_array))
        self.assertTrue(np.array_equal(array1, array2))

        np_array = np.random.random(dims).astype(np.float16)
        np_array[0, 0, 0, 0] = .1
        handler = backend.GraphHandler(backend.cpu_runtime())
        tensor1 = handler.tensor(dims, TensorProto.FLOAT16)
        handler.data_malloc()
        tensor1.copyin_numpy(np_array)
        array1 = tensor1.copyout_numpy()
        # Copy should be the same as original array
        self.assertTrue(np.array_equal(array1, np_array))
        # Modify the value so that tensorObj value changes
        np_array[0, 0, 0, 0] = 0.
        tensor1.copyin_numpy(np_array)
        # The copied-out array should not change
        self.assertFalse(np.array_equal(array1, np_array))


if __name__ == "__main__":
    unittest.main()
@@ -329,7 +329,7 @@ class TestStringMethods(unittest.TestCase):
                 [pads_data],
             )
         )

     def test_allReduceSum(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
         output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])

@@ -349,7 +349,7 @@ class TestStringMethods(unittest.TestCase):
         graph = make_graph([allReduceProd], "allReduceProd", [input], [output])
         model = make_model(graph)
         from_onnx(model, backend.cpu_runtime())

     def test_allReduceMin(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
         output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])

@@ -379,14 +379,12 @@ class TestStringMethods(unittest.TestCase):
         graph = make_graph([allReduceAvg], "allReduceAvg", [input], [output])
         model = make_model(graph)
         from_onnx(model, backend.cpu_runtime())

     def test_split(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
-        split = make_node(
-            "Split", ["input"], ["output"], name="split", axis=0
-        )
+        split = make_node("Split", ["input"], ["output"], name="split", axis=0)
         make_and_import_model(make_graph([split], "split", [input], []))

     def test_allBroadcast(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
         output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])

@@ -460,53 +458,6 @@ class TestStringMethods(unittest.TestCase):
         where = make_node("Where", ["x", "y", "con"], ["output"], name="where")
         make_and_import_model(make_graph([where], "where", [x, y, con], [output]))

-    def test_copyin(self):
-        dims = [2, 3, 5, 4]
-        np_array = np.random.random(dims).astype(np.float32)
-        handler = backend.GraphHandler(backend.cpu_runtime())
-        tensor1 = handler.tensor(dims, TensorProto.FLOAT)
-        tensor2 = handler.tensor(dims, TensorProto.FLOAT)
-        handler.data_malloc()
-        tensor1.copyin_numpy(np_array)
-        tensor2.copyin_float(np_array.flatten().tolist())
-        array1 = tensor1.copyout_float()
-        array2 = tensor2.copyout_float()
-        self.assertEqual(array1, array2)
-        self.assertTrue(np.array_equal(np.array(array1).reshape(dims), np_array))
-
-        np_array = np.random.random(dims).astype(np.int64)
-        handler = backend.GraphHandler(backend.cpu_runtime())
-        tensor1 = handler.tensor(dims, TensorProto.INT64)
-        tensor2 = handler.tensor(dims, TensorProto.INT64)
-        handler.data_malloc()
-        tensor1.copyin_numpy(np_array)
-        tensor2.copyin_int64(np_array.flatten().tolist())
-        array1 = tensor1.copyout_int64()
-        array2 = tensor2.copyout_int64()
-        self.assertEqual(array1, array2)
-        self.assertTrue(np.array_equal(np.array(array1).reshape(dims), np_array))
-
-    def test_to_numpy(self):
-        dims = [2, 3, 5, 4]
-        np_array = np.random.random(dims).astype(np.float32)
-        handler = backend.GraphHandler(backend.cpu_runtime())
-        tensor1 = handler.tensor(dims, TensorProto.FLOAT)
-        tensor2 = handler.tensor(dims, TensorProto.FLOAT)
-        handler.data_malloc()
-        tensor1.copyin_float(np_array.flatten().tolist())
-        tensor2.copyin_float(np_array.flatten().tolist())
-        array1 = np.array(tensor1.copyout_float()).reshape(dims)
-        array2 = np.array(tensor2)
-        self.assertTrue(np.array_equal(array2, np_array))
-        self.assertTrue(np.array_equal(array1, array2))
-
-        np_array = np.random.random(dims).astype(np.float16)
-        handler = backend.GraphHandler(backend.cpu_runtime())
-        tensor1 = handler.tensor(dims, TensorProto.FLOAT16)
-        handler.data_malloc()
-        tensor1.copyin_numpy(np_array)
-        array1 = np.array(tensor1, copy=False)
-        self.assertTrue(np.array_equal(array1, np_array))

 if __name__ == "__main__":
     unittest.main()
@@ -131,30 +131,63 @@ void GraphObj::dataMalloc() {
     // record the memory address offsets of all tensors to be allocated
     std::unordered_map<TensorObj *, size_t> tensorToOffset;

-    // record all constant tensors, including weight tensors and input tensors
-    std::unordered_set<TensorObj *> constTensor;
+    // reinit allocator
+    allocator.init();
+
+    // record all weight tensors, including weight tensors and kvcache
+    // tensors
+    std::unordered_set<TensorObj *> weightTensors;
     for (auto &tensor : tensors) {
-        if (tensor.get()->getSource() == nullptr) {
-            // allocate memory for all constant tensors first, and this memory
+        if (tensor->isWeight()) {
+            // allocate memory for all weight tensors first, and this memory
             // will not be freed until the graph is destroyed
+            weightTensors.insert(tensor.get());
+            if (!this->weightAllocated) {
+                tensorToOffset[tensor.get()] =
+                    allocator.allocWeight(tensor->getBytes());
+            }
+        } else if (tensor->isInput() || tensor->isOutput()) {
+            // allocate memory for all input and output tensors, and this
+            // memory will not be reused later
-            constTensor.insert(tensor.get());
             tensorToOffset[tensor.get()] = allocator.alloc(tensor->getBytes());
         } else {
             tensorToRefCount[tensor.get()] = tensor->getTargets().size();
+            // allocate memory for all user-created tensors
+            if (tensor.get()->getSource() == nullptr) {
+                tensorToOffset[tensor.get()] =
+                    allocator.alloc(tensor->getBytes());
+            }
         }
     }
+    // if memory has not yet been allocated for weight tensors,
+    // allocate memory now and do not allocate again in the future.
+    if (!this->weightAllocated) {
+        this->weightAllocated = true;
+        // only allocate once for weight tensors
+        for (auto &tensor : weightTensors) {
+            IT_ASSERT(tensorToOffset.find(tensor) != tensorToOffset.end());
+            tensor->setDataBlob(make_ref<BlobObj>(
+                tensor->runtime,
+                static_cast<uint8_t *>(allocator.getWeightPtr()) +
+                    tensorToOffset[tensor]));
+        }
+    }
     // traverse in topological order and simulate memory allocation
     for (auto &op : ops) {
-        // memory should be allocated for the output first
+        // memory should be allocated for the op's output first
         auto outputs = op->getOutputs();
         for (auto &tensor : outputs) {
-            tensorToOffset[tensor.get()] = allocator.alloc(tensor->getBytes());
+            if (tensor->isOthers()) {
+                tensorToOffset[tensor.get()] =
+                    allocator.alloc(tensor->getBytes());
+            }
         }
         auto inputs = op->getInputs();
         for (auto &tensor : inputs) {
-            if (constTensor.find(tensor.get()) == constTensor.end()) {
+            if (tensor->isOthers()) {
                 auto tensorIter = tensorToRefCount.find(tensor.get());
                 IT_ASSERT(tensorIter != tensorToRefCount.end());
                 IT_ASSERT(tensorToRefCount[tensor.get()] > 0);
                 tensorToRefCount[tensor.get()] -= 1;
                 if (tensorToRefCount[tensor.get()] == 0) {
                     // indicate that this tensor will no longer be used and

@@ -167,15 +200,20 @@ void GraphObj::dataMalloc() {
         }
     }

-    // perform actual memory allocation
+    // perform actual memory allocation for non-weight tensors
     for (auto &tensor : tensors) {
-        IT_ASSERT(tensorToOffset.find(tensor.get()) != tensorToOffset.end());
-        tensor->setDataBlob(make_ref<BlobObj>(
-            tensor->runtime, static_cast<uint8_t *>(allocator.getPtr()) +
-                                 tensorToOffset[tensor.get()]));
+        if (!tensor->isWeight()) {
+            IT_ASSERT(tensorToOffset.find(tensor.get()) !=
+                      tensorToOffset.end());
+            tensor->setDataBlob(make_ref<BlobObj>(
+                tensor->runtime, static_cast<uint8_t *>(allocator.getPtr()) +
+                                     tensorToOffset[tensor.get()]));
+        }
     }

+#ifdef DEBUG_MODE
     allocator.info();
+#endif
 }

 Tensor GraphObj::addTensor(Shape dim, DataType dtype) {
@@ -11,9 +11,6 @@ namespace infini {
 constexpr size_t alignmentInBytesForCUDA = 256;

 LazyAllocator::LazyAllocator(Runtime runtime) : runtime(runtime) {
-    used = 0;
-    peak = 0;
-    ptr = nullptr;
     if (runtime->isCuda()) {
         // TODO: the alignment on cuda might need further discussion
         alignment = alignmentInBytesForCUDA;

@@ -30,10 +27,24 @@ LazyAllocator::~LazyAllocator() {
     if (this->ptr != nullptr) {
         runtime->dealloc(this->ptr);
     }
+    if (this->weightPtr != nullptr) {
+        runtime->dealloc(this->weightPtr);
+    }
 }

+void LazyAllocator::init() {
+    used = 0;
+    peak = 0;
+    freeBlocks.clear();
+    headAddrToBlockSize.clear();
+    tailAddrToBlockSize.clear();
+    if (this->ptr != nullptr) {
+        runtime->dealloc(this->ptr);
+    }
+    this->ptr = nullptr;
+}
+
 size_t LazyAllocator::alloc(size_t size) {
     IT_ASSERT(this->ptr == nullptr);
     // pad the size to the multiple of alignment
     size = this->getAlignedSize(size);
     auto it = this->freeBlocks.lower_bound(freeBlockInfo{(size_t)0, size});

@@ -83,6 +94,14 @@ size_t LazyAllocator::alloc(size_t size) {
     return retAddr;
 }

+size_t LazyAllocator::allocWeight(size_t size) {
+    IT_ASSERT(this->weightPtr == nullptr);
+    size = this->getAlignedSize(size);
+    size_t retAddr = this->weightPeak;
+    this->weightPeak += size;
+    return retAddr;
+}
+
 void LazyAllocator::free(size_t addr, size_t size) {
     IT_ASSERT(this->ptr == nullptr);
     size = getAlignedSize(size);

@@ -126,18 +145,33 @@ void LazyAllocator::free(size_t addr, size_t size) {
 void *LazyAllocator::getPtr() {
     if (this->ptr == nullptr) {
         this->ptr = runtime->alloc(this->peak);
-        printf("LazyAllocator really alloc: %p %lu bytes\n", this->ptr, peak);
+#ifdef DEBUG_MODE
+        printf("LazyAllocator really alloc non-weight: %p %lu bytes\n",
+               this->ptr, peak);
+#endif
     }
     return this->ptr;
 }

+void *LazyAllocator::getWeightPtr() {
+    if (this->weightPtr == nullptr) {
+        this->weightPtr = runtime->alloc(this->weightPeak);
+#ifdef DEBUG_MODE
+        printf("LazyAllocator really alloc weight: %p %lu bytes\n",
+               this->weightPtr, weightPeak);
+#endif
+    }
+    return this->weightPtr;
+}
+
 size_t LazyAllocator::getAlignedSize(size_t size) {
     return ((size - 1) / this->alignment + 1) * this->alignment;
 }

 void LazyAllocator::info() {
-    std::cout << "Used memory: " << this->used
-              << ", peak memory: " << this->peak << std::endl;
+    std::cout << "Used memory: " << this->used + this->weightPeak
+              << ", peak memory: " << this->peak + this->weightPeak
+              << std::endl;
 }

 } // namespace infini
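The split this commit introduces is two independent arenas: a bump-pointer weight arena filled once through allocWeight/getWeightPtr and kept across batches, and the reusable activation arena behind alloc/getPtr that init() resets whenever the graph is re-planned. A toy Python model of that offset bookkeeping (greatly simplified: the real allocator also keeps a coalescing free list for the activation arena):

# Toy model of the two-arena scheme above: weight offsets are bump-allocated
# once and survive re-planning; activation offsets are recomputed after
# init(). Simplified illustration, not the C++ allocator.
class TwoArenaAllocator:
    def __init__(self, alignment=256):
        self.alignment = alignment
        self.peak = 0          # activation arena size
        self.weight_peak = 0   # weight arena size, grows monotonically

    def _aligned(self, size):
        return ((size - 1) // self.alignment + 1) * self.alignment

    def alloc(self, size):         # activation offset
        off, self.peak = self.peak, self.peak + self._aligned(size)
        return off

    def alloc_weight(self, size):  # weight offset, never reset
        off = self.weight_peak
        self.weight_peak += self._aligned(size)
        return off

    def init(self):                # re-plan activations, keep weights
        self.peak = 0

a = TwoArenaAllocator()
w0 = a.alloc_weight(1000)  # 0; stays valid across batches
x0 = a.alloc(500)          # 0 in the activation arena
a.init()                   # new batch size: activation plan restarts
x1 = a.alloc(2000)         # 0 again; w0 is untouched
print(w0, x0, x1, a.weight_peak, a.peak)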
@@ -23,7 +23,7 @@ string TensorObj::toString() const {
     string ret = "Tensor " + std::to_string(guid) + ", Fuid " +
                  std::to_string(fuid) + ", shape " + vecToString(shape) +
                  ", dtype " + dtype.toString() + ", " + runtime->toString() +
-                 ", " + ss.str() + "\n";
+                 ", " + ss.str() + ", " + tensorTypeToString() + "\n";
     vector<UidBaseType> targetGuids;
     for (const auto &op : targets)
         targetGuids.emplace_back(op.lock()->getGuid());
@@ -8,7 +8,6 @@
 #include "operators/conv.h"
 #include "operators/matmul.h"

-#ifdef DEBUG_MODE
 void CHECK_CUDA_KERNEL_ERROR(infini::Operator op) {
     cudaError_t kernelError = cudaGetLastError();
     if (kernelError != cudaSuccess) {

@@ -18,7 +17,6 @@ void CHECK_CUDA_KERNEL_ERROR(infini::Operator op) {
         exit(EXIT_FAILURE);
     }
 }
-#endif

 namespace infini {

@@ -38,10 +36,7 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
         } else {
             kernel->compute(op, this);
         }
-
-#ifdef DEBUG_MODE
-        CHECK_CUDA_KERNEL_ERROR(op);
-#endif
+        checkCudaError(cudaGetLastError()) << op->toString();
     }
 }

@@ -78,9 +73,7 @@ void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const {
             opCnt[op->getOpType()]++;
         }
-
-#ifdef DEBUG_MODE
-        CHECK_CUDA_KERNEL_ERROR(op);
-#endif
+        checkCudaError(cudaGetLastError()) << op->toString();
     }
 }

@@ -103,6 +96,7 @@ void CudaRuntimeObj::initComm(const string &name, int worldSize, int rank) {
     IT_ASSERT(worldSize > 0);
     IT_ASSERT(rank >= 0);
     IT_ASSERT(rank < worldSize);
+    IT_ASSERT(!comm) << "communicator is already initialized.";
 #ifdef INFINI_USE_NCCL
     comm = std::make_unique<NcclCommunicatorObj>(name, worldSize, rank);
 #else
@@ -317,6 +317,44 @@ void export_functions(py::module &m) {
#undef FUNCTION
}

// A helper function that converts a DataType to a Python format string
static std::string getFormat(DataType type) {
    std::string format;
    if (type == DataType::Float32) {
        format = py::format_descriptor<float>::format();
    } else if (type == DataType::Double) {
        format = py::format_descriptor<double>::format();
    } else if (type == DataType::Int32) {
        format = py::format_descriptor<int>::format();
    } else if (type == DataType::UInt32) {
        format = py::format_descriptor<uint32_t>::format();
    } else if (type == DataType::Int64) {
        format = py::format_descriptor<int64_t>::format();
    } else if (type == DataType::UInt64) {
        format = py::format_descriptor<uint64_t>::format();
    } else if (type == DataType::Int16) {
        format = py::format_descriptor<int16_t>::format();
    } else if (type == DataType::UInt16) {
        format = py::format_descriptor<uint16_t>::format();
    } else if (type == DataType::Int8) {
        format = py::format_descriptor<int8_t>::format();
    } else if (type == DataType::UInt8) {
        format = py::format_descriptor<uint8_t>::format();
    } else if (type == DataType::Bool) {
        format = py::format_descriptor<bool>::format();
    } else if (type == DataType::Float16 || type == DataType::BFloat16) {
        // Python uses "e" as the type code for half-precision floats. See
        // https://docs.python.org/3/library/struct.html#format-characters
        // for more information.
        format = "e";
    } else {
        throw std::runtime_error("Error converting TensorObj to "
                                 "Numpy: unsupported datatype.\n");
    }

    return format;
}
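
// For reference, pybind11's format descriptors yield the struct format
// characters NumPy expects: e.g. getFormat(DataType::Float32) is "f",
// Double is "d", Int32 is "i", and UInt8 is "B" (letters assume the usual
// pybind11 mapping on a typical platform).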

void init_graph_builder(py::module &m) {
    using Handler = GraphHandlerObj;

@@ -341,6 +379,10 @@ void init_graph_builder(py::module &m) {
            py::buffer_protocol())
        .def("fuid", &TensorObj::getFuid, policy::automatic)
        .def("shape", &TensorObj::getDims, policy::move)
        .def("set_weight", &TensorObj::setWeight, policy::move)
        .def("set_input", &TensorObj::setInput, policy::move)
        .def("set_output", &TensorObj::setOutput, policy::move)
        .def("dtype", &TensorObj::getDTypeIndex, policy::automatic)
        .def("copyin_float", &TensorObj::copyin<float>, policy::move)
        .def("copyin_int32", &TensorObj::copyin<int32_t>, policy::move)
        .def("copyin_int64", &TensorObj::copyin<int64_t>, policy::move)
@@ -367,51 +409,24 @@ void init_graph_builder(py::module &m) {
                 }
                 self.copyin(data_np, self.getBytes());
             })
        // A buffer can be used to convert a TensorObj directly to a Numpy
        // array without copy
        .def_buffer([](TensorObj &self) -> py::buffer_info {
            vector<size_t> stride_byte;
            for (int s : self.getStride()) {
                stride_byte.push_back(s * self.getDType().getSize());
            }
        // Return a Numpy array which copies the values of this tensor
        .def("copyout_numpy",
             [](TensorObj &self) -> py::array {
                 vector<size_t> stride_byte;
                 for (int s : self.getStride()) {
                     stride_byte.push_back(s * self.getDType().getSize());
                 }
                 std::string format = getFormat(self.getDType());

            std::string format;
            if (self.getDType() == DataType::Float32) {
                format = py::format_descriptor<float>::format();
            } else if (self.getDType() == DataType::Double) {
                format = py::format_descriptor<double>::format();
            } else if (self.getDType() == DataType::Int32) {
                format = py::format_descriptor<int>::format();
            } else if (self.getDType() == DataType::UInt32) {
                format = py::format_descriptor<uint32_t>::format();
            } else if (self.getDType() == DataType::Int64) {
                format = py::format_descriptor<int64_t>::format();
            } else if (self.getDType() == DataType::UInt64) {
                format = py::format_descriptor<uint64_t>::format();
            } else if (self.getDType() == DataType::Int16) {
                format = py::format_descriptor<int16_t>::format();
            } else if (self.getDType() == DataType::UInt16) {
                format = py::format_descriptor<uint16_t>::format();
            } else if (self.getDType() == DataType::Int8) {
                format = py::format_descriptor<int8_t>::format();
            } else if (self.getDType() == DataType::UInt8) {
                format = py::format_descriptor<uint8_t>::format();
            } else if (self.getDType() == DataType::Float16 ||
                       self.getDType() == DataType::BFloat16) {
                // Python uses "e" as the type code for half-precision floats.
                // See
                // https://docs.python.org/3/library/struct.html#format-characters
                // for more information.
                format = "e";
            } else {
                throw std::runtime_error("Error converting TensorObj to "
                                         "Numpy: unsupported datatype.\n");
            }
                 py::array numpy_array(py::dtype(format), self.getDims(),
                                       nullptr);

            return py::buffer_info(self.getRawDataPtr<void *>(),
                                   self.getDType().getSize(), format,
                                   self.getRank(), self.getDims(), stride_byte,
                                   true); // Read-only = true
        })
                 // Copy data to the numpy array
                 auto ptr = numpy_array.mutable_data();
                 self.copyout(ptr, self.getBytes());

                 return numpy_array;
             })
        .def("has_target", &TensorObj::hasTarget, policy::automatic)
        .def("src", &TensorObj::getSource, policy::move)
        .def("printData", &TensorObj::printData, policy::automatic);
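
// Worked example for the byte strides built above: a Float32 tensor of
// shape {2, 3} has element strides {3, 1}, so stride_byte becomes
// {12, 4} and NumPy addresses element (i, j) at byte offset 12*i + 4*j.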
@@ -436,6 +451,8 @@ void init_graph_builder(py::module &m) {
        .def("mul", &Handler::mul, policy::move)
        .def("div", &Handler::div, policy::move)
        .def("pow", &Handler::pow, policy::move)
        .def("min", &Handler::min, policy::move)
        .def("max", &Handler::max, policy::move)
        .def("relu", &Handler::relu, policy::move)
        .def("sigmoid", &Handler::sigmoid, policy::move)
        .def("tanh", &Handler::tanh, policy::move)

@@ -14,7 +14,7 @@ class AllReduceNCCL : public CudaKernelWithoutConfig {
        void *input = op->getInputs(0)->getRawDataPtr<void *>();
        void *output = op->getOutput()->getRawDataPtr<void *>();
        IT_ASSERT(op->getDType() == DataType::Float32);
        size_t count = op->getInputs(0)->getBytes() / op->getDType().getSize();
        size_t count = op->getInputs(0)->size();

        ncclComm_t comm =
            dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator())
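
// Assuming size() returns the tensor's element count, the new expression
// is equivalent to the old getBytes() / getDType().getSize(): for a
// Float32 tensor of 1024 elements, both yield 4096 / 4 = 1024.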
@@ -25,8 +25,8 @@ class ExpandCuda : public CudaKernelWithoutConfig {
            inputShape.data[i] = in_Shape[i];
            outputsize *= out_Shape[i];
        }
        expand_kernel((float *)inputData, (float *)outputData, nDims,
                      outputsize, inputShape, outputShape);
        expandKernel((float *)inputData, (float *)outputData, nDims, outputsize,
                     inputShape, outputShape);
    }
};

@@ -6,9 +6,9 @@ constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); }

__global__ void _expand_kernel(float *input, float *output, int nDims,
                               int outputsize, infini::SmallArray inputShape,
                               infini::SmallArray outputShape) {
__global__ void _expandKernel(float *input, float *output, int nDims,
                              int outputsize, infini::SmallArray inputShape,
                              infini::SmallArray outputShape) {

    int outputIdx =
        blockIdx.x * blockDim.x + threadIdx.x; // i(JKS) + j(KS) + k(S) + s

@@ -38,12 +38,12 @@ __global__ void _expand_kernel(float *input, float *output, int nDims,
}

namespace infini {
void expand_kernel(float *input, float *output, int nDims, int outputsize,
                   SmallArray inputShape, SmallArray outputShape) {
void expandKernel(float *input, float *output, int nDims, int outputsize,
                  SmallArray inputShape, SmallArray outputShape) {
    int blocksize = block_work_size();
    int gridsize = (outputsize + block_work_size() - 1) / block_work_size();
    _expand_kernel<<<gridsize, blocksize>>>(input, output, nDims, outputsize,
                                            inputShape, outputShape);
    _expandKernel<<<gridsize, blocksize>>>(input, output, nDims, outputsize,
                                           inputShape, outputShape);
}

} // namespace infini
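
// Launch arithmetic for the wrapper above: num_threads() = 32 * 4 = 128
// and block_work_size() = 4 * 128 = 512, so outputsize = 1000 launches
// gridsize = (1000 + 511) / 512 = 2 blocks, a ceiling division of the
// element count by the per-block work size.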
@@ -1,6 +1,8 @@
#include "operators/matmul.h"
#include "core/kernel.h"
#include "cuda/cuda_expand.h"
#include "cuda/cuda_runtime.h"
#include "utils/small_array.h"

namespace infini {

@@ -46,7 +48,30 @@ class matmulCublas : public Kernel {
        auto opB = op->getTransB() ? CUBLAS_OP_T : CUBLAS_OP_N;
        const int lda = op->getTransA() ? m : k, ldb = op->getTransB() ? k : n,
                  ldc = n;
        const float alpha = 1.f, beta = 0.f;
        float alpha = 1.f, beta = 0.f;
        if (op->numInputs() == 2) { // no bias
            beta = 0.f;
        } else { // broadcast the bias to the output
            beta = 1.f;
            auto inC = op->getInputs(2);
            auto out = op->getOutput();
            SmallArray inputShape, outputShape;
            int nDims = out->getRank();
            IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
            int outputsize = 1; // the length of the output vector after flattening
            int offset = nDims - inC->getRank();
            for (int i = 0; i < offset; ++i)
                inputShape.data[i] = 1;
            for (int i = 0; i < nDims; ++i) {
                outputShape.data[i] = out->getDims()[i];
                outputsize *= outputShape.data[i];
                if (i >= offset)
                    inputShape.data[i] = inC->getDims()[i - offset];
            }
            expandKernel(inC->getRawDataPtr<float *>(),
                         out->getRawDataPtr<float *>(), nDims, outputsize,
                         inputShape, outputShape);
        }
        // TODO: use compute type
        cublasStatus_t stat;
        if (b > 1) {
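
// How the bias path above lines up shapes: for an output of shape
// {2, 3, 4} and a bias C of shape {4}, offset = 3 - 1 = 2, so inputShape
// becomes {1, 1, 4}; expandKernel tiles the bias across the output buffer,
// and cublas then accumulates into it because beta = 1.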
@@ -1,30 +1,29 @@
#include "cuda/cuda_common.h"
#include "cuda/cuda_split_concat.h"

int getMultiProcessorCount() {
    int cur_device;
    checkCudaError(cudaGetDevice(&cur_device));

    struct cudaDeviceProp prop;
    checkCudaError(cudaGetDeviceProperties(&prop, cur_device));
    return prop.multiProcessorCount;
}

__host__ __device__ int
elementIdx2ComposedIdx(int elementIndex, int dimBgNo, int dimSize, int dim,
                       int nDim, ComposedTensorMetadata wholeMeta) {
    int offset = 0;

    // COMP(x0,...,xk,...,xn-1) = ELMT[xk / d](x0,...,xk % d,...xn-1)
    // where k = dim, n = nDim, and d = dimSize is the split length of
    // dimension dim
#pragma unroll
    // Iterate from n-1 down to 1
    for (int i = nDim - 1; i >= 1; --i) {
        int size = (i == dim) ? dimSize : wholeMeta.dimSize[i];
        int p = elementIndex % size;
        // dimBgNo moves the pointer to the correct location in the composed
        // data corresponding to the current element, with respect to the
        // split dimension dim
        int oP = (i == dim) ? (p + dimBgNo) : p;
        elementIndex = (elementIndex - p) / size;
        offset += oP * wholeMeta.stride[i];
    }

    return offset + elementIndex * wholeMeta.stride[0];
    // Deal with i = 0
    int oP = (dim == 0) ? (elementIndex + dimBgNo) : elementIndex;
    return offset + oP * wholeMeta.stride[0];
}
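
// Index-mapping example for the function above (illustrative values):
// concatenating two {2, 3} tensors along dim = 1 gives a composed {2, 6}
// tensor with strides {6, 1}. For the second input (dimBgNo = 3,
// dimSize = 3), element index 4 decomposes to (row 1, col 1); the column
// shifts to 1 + 3 = 4, so the composed offset is 1 * 6 + 4 * 1 = 10.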

__global__ void _split_concat_kernel(ElementTensorMetadata elemMeta,
@@ -38,31 +37,29 @@ __global__ void _split_concat_kernel(ElementTensorMetadata elemMeta,
    auto dimBgNo = elemMeta.dimBgNo[blockIdx.y];
    auto dimSize = elemMeta.dimSize[blockIdx.y];
    float *elemData = elemMeta.data[blockIdx.y];
    int stride = gridDim.x * blockDim.x;

    while (tid < nElements) {
        int Offset =
            elementIdx2ComposedIdx(tid, dimBgNo, dimSize, dim, nDims, compMeta);
        // copy data from input to output
        // for split: the input is the composed tensor; for concat: the
        // inputs are the element tensors.
        if (isSplit)
            elemData[tid] = compMeta.data[Offset];
        else
            compMeta.data[Offset] = elemData[tid];
        tid += stride;
    }
    int Offset =
        elementIdx2ComposedIdx(tid, dimBgNo, dimSize, dim, nDims, compMeta);
    // copy data from input to output
    // for split: the input is the composed tensor; for concat: the inputs
    // are the element tensors.
    if (isSplit)
        elemData[tid] = compMeta.data[Offset];
    else
        compMeta.data[Offset] = elemData[tid];
}

namespace infini {

// TODO: when dim = 0, the operation can be executed in place
void split_concat_kernel(const ElementTensorMetadata &eleMeta,
                         const ComposedTensorMetadata &compMeta, int dim,
                         int batchSize, int nDims, bool isSplit) {
    dim3 blockSize = dim3(32 * 16);

    // the y dim is the number of tensors.
    dim3 gridSize(getMultiProcessorCount(), batchSize);
    // gridDimX = ceil(nElements / blockSize)
    int gridDimX = (eleMeta.nElements[0] - 1) / (32 * 16) + 1;
    // each y is a split among the batch
    dim3 gridSize(gridDimX, batchSize);

    _split_concat_kernel<<<gridSize, blockSize>>>(eleMeta, compMeta, dim, nDims,
                                                  isSplit);
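
// Grid-size check for the wrapper above: with nElements[0] = 1000 and a
// block of 32 * 16 = 512 threads, gridDimX = (1000 - 1) / 512 + 1 = 2,
// a ceiling division that assigns each element exactly one thread.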
@@ -2,6 +2,7 @@
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_where.h"
#include "utils/broadcast_shape.h"

namespace infini {

@@ -10,28 +11,33 @@ class WhereCuda : public CudaKernelWithoutConfig {
                 const RuntimeObj *_context) const override {
        auto op = as<WhereObj>(_op);

        void *const inputxData = (op->getInputs(0)->getRawDataPtr<void *>());
        void *const inputyData = (op->getInputs(1)->getRawDataPtr<void *>());
        void *const inputXData = (op->getInputs(0)->getRawDataPtr<void *>());
        void *const inputYData = (op->getInputs(1)->getRawDataPtr<void *>());
        void *const conditionData = (op->getInputs(2)->getRawDataPtr<void *>());
        void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
        const auto &inputx_Shape = op->getInputs(0)->getDims();
        const auto &inputy_Shape = op->getInputs(1)->getDims();
        const auto &condition_Shape = op->getInputs(2)->getDims();
        const auto &output_Shape = op->getOutput()->getDims();
        const auto &opInputXShape = op->getInputs(0)->getDims();
        const auto &opInputYShape = op->getInputs(1)->getDims();
        const auto &opConditionShape = op->getInputs(2)->getDims();
        const auto &opOutputShape = op->getOutput()->getDims();

        int nDims = op->getInputs(0)->getDims().size();
        const int xSize = op->getInputs(0)->getRank();
        const int ySize = op->getInputs(1)->getRank();
        const int cSize = op->getInputs(2)->getRank();
        int nDims = op->getOutput()->getDims().size();
        IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);

        SmallArray inputxShape, inputyShape, conditionShape, outputShape;
        for (int i = 0; i < nDims; ++i) {
            inputxShape.data[i] = inputx_Shape[i];
            inputyShape.data[i] = inputy_Shape[i];
            conditionShape.data[i] = condition_Shape[i];
            outputShape.data[i] = output_Shape[i];
        SmallArray inputXShape, inputYShape, conditionShape, outputShape;
        for (int i = nDims - 1; i >= 0; --i) {
            outputShape.data[i] = opOutputShape[i];
        }
        where_kernel((float *)inputxData, (float *)inputyData,
                     (float *)conditionData, (float *)outputData, nDims,
                     inputxShape, inputyShape, conditionShape, outputShape);

        broadcastShape(opInputXShape, inputXShape, nDims, xSize);
        broadcastShape(opInputYShape, inputYShape, nDims, ySize);
        broadcastShape(opConditionShape, conditionShape, nDims, cSize);

        whereKernel((float *)inputXData, (float *)inputYData,
                    (uint8_t *)conditionData, (float *)outputData, nDims,
                    inputXShape, inputYShape, conditionShape, outputShape);
    }
};

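
// What broadcastShape provides here (as used above): it right-aligns a
// lower-rank shape against the nDims output dimensions and pads the
// leading entries with 1. E.g. with an output of rank 4 and an input X of
// shape {2, 3, 1} (xSize = 3), inputXShape becomes {1, 2, 3, 1}.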
@@ -1,20 +1,20 @@
#include "cuda/cuda_common.h"
#include "utils/small_array.h"

__global__ void _where_kernel(const float *inputx, const float *inputy,
                              const float *condition, float *output, int nDims,
                              int outputsize, infini::SmallArray inputxShape,
                              infini::SmallArray inputyShape,
                              infini::SmallArray conditionShape,
                              infini::SmallArray outputShape) {
__global__ void _whereKernel(const float *inputX, const float *inputY,
                             const uint8_t *condition, float *output, int nDims,
                             int outputsize, infini::SmallArray inputXShape,
                             infini::SmallArray inputYShape,
                             infini::SmallArray conditionShape,
                             infini::SmallArray outputShape) {

    int outputIdx = blockIdx.x * blockDim.x + threadIdx.x;
    if (outputIdx < outputsize) {
        int inputxIdx = 0;
        int temp_inputx = 1;
        int inputXIdx = 0;
        int temp_inputX = 1;

        int inputyIdx = 0;
        int temp_inputy = 1;
        int inputYIdx = 0;
        int temp_inputY = 1;

        int conditionIdx = 0;
        int temp_condition = 1;
@@ -27,23 +27,23 @@ __global__ void _where_kernel(const float *inputx, const float *inputy,
            } else {
                tmp = v % outputShape.data[i]; // store s,k,j in order
            }
            if (inputxShape.data[i] == 1) {
                inputxIdx += 0;
            if (inputXShape.data[i] == 1) {
                inputXIdx += 0;
            } else {
                inputxIdx +=
                inputXIdx +=
                    tmp *
                    temp_inputx; // otherwise +i(JKS) or j(KS) or k(S) or s
                    temp_inputX; // otherwise +i(JKS) or j(KS) or k(S) or s
            }
            temp_inputx *= inputxShape.data[i];
            temp_inputX *= inputXShape.data[i];
            //----------------------------
            if (inputyShape.data[i] == 1) {
                inputyIdx += 0;
            if (inputYShape.data[i] == 1) {
                inputYIdx += 0;
            } else {
                inputyIdx +=
                inputYIdx +=
                    tmp *
                    temp_inputy; // otherwise +i(JKS) or j(KS) or k(S) or s
                    temp_inputY; // otherwise +i(JKS) or j(KS) or k(S) or s
            }
            temp_inputy *= inputyShape.data[i];
            temp_inputY *= inputYShape.data[i];
            //--------------------------
            if (conditionShape.data[i] == 1) {
                conditionIdx += 0;

@@ -57,17 +57,15 @@ __global__ void _where_kernel(const float *inputx, const float *inputy,
            v = v / outputShape.data[i];
        }
        output[outputIdx] =
            condition[conditionIdx] ? inputx[inputxIdx] : inputy[inputyIdx];
            condition[conditionIdx] ? inputX[inputXIdx] : inputY[inputYIdx];
    }
}

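
// The temp_* accumulators above implement a mixed-radix decomposition of
// outputIdx, innermost dimension first. Example with output shape {2, 3}
// and outputIdx = 4: the loop extracts column 4 % 3 = 1 and row 4 / 3 = 1;
// an input whose extent is 1 in some dimension contributes nothing for
// that dimension, which is exactly how the broadcast is realized.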
namespace infini {
void where_kernel(const float *inputx, const float *inputy,
                  const float *condition, float *output, int nDims,
                  infini::SmallArray inputxShape,
                  infini::SmallArray inputyShape,
                  infini::SmallArray conditionShape,
                  infini::SmallArray outputShape) {
void whereKernel(const float *inputX, const float *inputY,
                 const uint8_t *condition, float *output, int nDims,
                 SmallArray inputXShape, SmallArray inputYShape,
                 SmallArray conditionShape, SmallArray outputShape) {
    int outputsize = 1;

    for (int i = 0; i < nDims; i++) {

@@ -75,8 +73,8 @@ void where_kernel(const float *inputx, const float *inputy,
    }
    int blocksize = 32 * 16;
    int gridsize = (outputsize + blocksize - 1) / blocksize;
    _where_kernel<<<gridsize, blocksize>>>(
        inputx, inputy, condition, output, nDims, outputsize, inputxShape,
        inputyShape, conditionShape, outputShape);
    _whereKernel<<<gridsize, blocksize>>>(
        inputX, inputY, condition, output, nDims, outputsize, inputXShape,
        inputYShape, conditionShape, outputShape);
}
} // namespace infini

@@ -10,7 +10,6 @@ ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim)
}

optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs) const {
    IT_ASSERT(inputs.size() > 1);
    Shape dims = inputs[0]->getDims();
    auto rank = inputs[0]->getRank();
    ShapeElem n = dims.at(dim);

@@ -6,7 +6,7 @@ GatherObj::GatherObj(GraphObj *graph, Tensor input, Tensor indices,
                     Tensor output, int axis)
    : OperatorObj(OpType::Gather, {input, indices}, {output}), axis(axis) {
    int rank = input->getRank();
    axis = get_real_axis(axis, rank);
    this->axis = get_real_axis(axis, rank);
    IT_ASSERT(checkValid(graph));
}

@@ -25,7 +25,7 @@ optional<vector<Shape>> GatherObj::inferShape(const TensorVec &inputs) const {
vector<DataType> GatherObj::inferDataType(const TensorVec &inputs) const {
    IT_ASSERT(inputs.size() == 2);
    auto index_dtype = inputs[1]->getDType();
    IT_ASSERT(index_dtype == DataType::Int32 || index_dtype == DataType::Int64)
    IT_ASSERT(index_dtype == DataType::Int32 || index_dtype == DataType::Int64);
    return {inputs[0]->getDType()};
}

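
// The constructor fix above matters because `axis = get_real_axis(axis,
// rank)` only updated the local parameter while the member kept the raw
// value; writing to this->axis stores the normalized axis, e.g. axis = -1
// with rank 3 becomes 2 (assuming get_real_axis wraps negative axes by
// adding the rank).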
@@ -9,8 +9,9 @@ namespace host_backtrace = backward;
host_backtrace::SignalHandling sh;

namespace infini {
Exception::Exception(const std::string &msg) : std::runtime_error(msg) {
    host_backtrace::StackTrace st;
Exception::Exception(const std::string &msg)
    : std::runtime_error(msg), info(msg) {
    backward_trace::StackTrace st;
    st.load_here(32);
    host_backtrace::Printer p;
    p.print(st);

@@ -8,6 +8,7 @@

namespace infini {
/*
// Test CUDA split idx to composed idx on CPU. Uncomment to run this test.
int inputOffset2CatOffset(int linearIndex, int dimBgNo, int dimSize,
                          int concatDim, int outputDimSize[4],
                          int outputStride[4], int nDim) {

@@ -22,7 +23,8 @@ int inputOffset2CatOffset(int linearIndex, int dimBgNo, int dimSize,
        offset += oP * outputStride[i];
    }

    return offset + linearIndex * outputStride[0];
    int oP = (concatDim == 0) ? (linearIndex + dimBgNo) : linearIndex;
    return offset + oP * outputStride[0];
}

TEST(Concat, OffsetTrans) {

@@ -41,8 +43,22 @@ TEST(Concat, OffsetTrans) {
              4);
    EXPECT_EQ(inputOffset2CatOffset(3, 1, 2, catDim, dimSize, strides, nDim),
              5);
    catDim = 0;
    EXPECT_EQ(inputOffset2CatOffset(0, 0, 3, catDim, dimSize, strides, nDim),
              0);
    EXPECT_EQ(inputOffset2CatOffset(1, 0, 3, catDim, dimSize, strides, nDim),
              1);
    EXPECT_EQ(inputOffset2CatOffset(2, 0, 3, catDim, dimSize, strides, nDim),
              2);
    EXPECT_EQ(inputOffset2CatOffset(0, 1, 3, catDim, dimSize, strides, nDim),
              3);
    EXPECT_EQ(inputOffset2CatOffset(1, 1, 3, catDim, dimSize, strides, nDim),
              4);
    EXPECT_EQ(inputOffset2CatOffset(2, 1, 3, catDim, dimSize, strides, nDim),
              5);
}
*/

TEST(Concat, Cuda) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

@@ -78,4 +94,32 @@ TEST(Concat, Cuda) {
                                    6, 7, 8, 1, 1, 1, 9, 10, 11, 1, 1, 1}));
}

TEST(Concat, Cuda_dim0) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto t1 = gCpu->addTensor({1, 3}, DataType::Float32);
    auto t2 = gCpu->addTensor({1, 3}, DataType::Float32);
    auto t3 = gCpu->addTensor({1, 3}, DataType::Float32);
    gCpu->dataMalloc();

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

    auto t1Gpu = gCuda->cloneTensor(t1);
    auto t2Gpu = gCuda->cloneTensor(t2);
    auto t3Gpu = gCuda->cloneTensor(t3);

    auto op =
        gCuda->addOp<ConcatObj>(TensorVec{t1Gpu, t2Gpu, t3Gpu}, nullptr, 0);
    gCuda->dataMalloc();
    t1Gpu->setData(IncrementalGenerator()); // 0 1 2
    t2Gpu->setData(OneGenerator());         // 1 1 1
    t3Gpu->setData(IncrementalGenerator()); // 0 1 2
    cudaRuntime->run(gCuda);

    auto oCpu = gCpu->cloneTensor(op->getOutput());
    EXPECT_TRUE(oCpu->equalData(vector<float>{0, 1, 2, 1, 1, 1, 0, 1, 2}));
}

} // namespace infini

@@ -39,4 +39,30 @@ TEST(Split, Cuda) {
        12, 13, 14, 15, 16, 17, 18, 19, 32, 33, 34, 35, 36, 37, 38, 39}));
}

TEST(Split, Cuda_dim0) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({2, 3}, DataType::Float32);
    gCpu->dataMalloc();
    input->setData(IncrementalGenerator());

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

    auto inputGpu = gCuda->cloneTensor(input);
    auto op = gCuda->addOp<SplitObj>(inputGpu, std::nullopt, 0, 2);
    gCuda->dataMalloc();
    inputGpu->setData(IncrementalGenerator());

    cudaRuntime->run(gCuda);

    // copy output from CUDA to CPU
    EXPECT_EQ(op->getOutputs().size(), (size_t)2);
    auto o0Cpu = gCpu->cloneTensor(op->getOutput(0));
    auto o1Cpu = gCpu->cloneTensor(op->getOutput(1));
    EXPECT_TRUE(o0Cpu->equalData(vector<float>{0, 1, 2}));
    EXPECT_TRUE(o1Cpu->equalData(vector<float>{3, 4, 5}));
}

} // namespace infini

@@ -10,11 +10,12 @@ namespace infini {

void test_where(const Shape &inputxshape, const vector<float> &inputxdata,
                const Shape &inputyshape, const vector<float> &inputydata,
                const Shape &conditionshape, const vector<int> &conditiondata,
                const Shape &conditionshape,
                const vector<uint8_t> &conditiondata,
                const vector<float> &ExpectData) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);
    auto condition = gCpu->addTensor(conditionshape, DataType::Int32);
    auto condition = gCpu->addTensor(conditionshape, DataType::UInt8);
    auto inputx = gCpu->addTensor(inputxshape, DataType::Float32);
    auto inputy = gCpu->addTensor(inputyshape, DataType::Float32);

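
// The switch from vector<int> to vector<uint8_t> for the condition mirrors
// the kernel change: ONNX's Where takes a boolean condition tensor, and
// uint8_t matches the one-byte storage NumPy uses for bools.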
@@ -47,16 +48,37 @@ TEST(CUDA_Where, run) {
    test_where(
        Shape{2, 2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
        Shape{2, 2, 3, 1}, vector<float>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
        Shape{2, 2, 3, 1}, vector<int>{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1},
        Shape{2, 2, 3, 1}, vector<uint8_t>{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1},
        vector<float>{0., 1., 2., 0., 0., 0., 6., 7., 0., 9., 10., 11.});

    test_where(Shape{2, 1, 1, 3}, // inputx
               vector<float>{0, 1, 2, 3, 4, 5}, Shape{1, 2, 1, 1}, // inputy
               vector<float>{1, 1}, Shape{2, 1, 3, 1}, // condition
               vector<int>{0, 1, 1, 0, 0, 0},
               vector<uint8_t>{0, 1, 1, 0, 0, 0},
               vector<float>{1., 1., 1., 0., 1., 2., 0., 1., 2., 1., 1., 1.,
                             0., 1., 2., 0., 1., 2., 1., 1., 1., 1., 1., 1.,
                             1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.});
    test_where(
        Shape{
            3,
        },
        vector<float>{0, 1, 2}, // inputX
        Shape{2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5}, // inputY
        Shape{2, 1, 3, 1}, vector<uint8_t>{0, 1, 1, 0, 0, 0}, // condition
        vector<float>{0., 0., 0., 0., 1., 2., 0., 1., 2., 3., 3., 3.,
                      0., 1., 2., 0., 1., 2., 0., 0., 0., 1., 1., 1.,
                      2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5.});
    test_where(
        Shape{
            3,
        },
        vector<float>{0, 1, 2}, // inputX
        Shape{2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5}, // inputY
        Shape{2, 1, 3, 1},
        vector<uint8_t>{false, true, true, false, false, false}, // condition
        vector<float>{0., 0., 0., 0., 1., 2., 0., 1., 2., 3., 3., 3.,
                      0., 1., 2., 0., 1., 2., 0., 0., 0., 1., 1., 1.,
                      2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5.});

} // python output