tensor parallel for transformer (#125)

* add cmake bits about NCCL

* move example to examples/NNmodel

* impl NCCL communicator

* add comm related function to Runtime

* export runtime interface

* add launch.py

* use unique name to distingush the the NCCL ID file

* add timeout to communicator init

* expose communicator obj from runtime obj, add unit test for nccl communicator

* reformat files

* Add allReduce operator and cuda nccl allReduce kernel

* impl model parallel for resnet

* add allGather nccl kernel and operator

* Add allreduce allgather operator tests, change allgather kernel to output list of tensor, fix shape infer, handle nullptr output

* fix format of onnx.py

* use concat following AllGather

* get tensor parallel for resnet

* fix format of graph_handler.cc

* change BUILD_DIST default to OFF

* polish code of communicator

* update .gitignore

* export min/max to python

* fix MatMul

* modify launch.py to run opt

* hack to treat ReduceSum as AllReduceSum

* throw exception in cuda error

* fix parallel_opt.py

* improve the error prompt and cuda error check

* fix GatherObj::GatherObj member init

* fix size calculation for scalar (rank = 0) tensor

* MatMul supports bias

* fix add bias for row parallel gemm

* add --gen_std to launch.py

* fix AllReduceNCCL

* update launch.py

* less log

* update parallel_opt

* update launch.py

* add __eq__ for Placement sub-classes

* less benchmark run

* fix placement infer for matmul

* fix vacabuary size

* fix Exception

* Add shard tensor with group to support gpt2

* Add find successor function to find split op at different depth

* recover CommunicatorObj

* improve error mesasge

* optimize parallel_opt.py

* optimize launch.py

* recover docs for all_reduce and all_gather

* Fix API

* fix format

---------

Co-authored-by: panzezhong <panzezhong@qiyuanlab.com>
Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
constroy Li 2023-09-14 14:19:45 +08:00 committed by GitHub
parent dda668fd16
commit 4c321c8a91
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 454 additions and 90 deletions

View File

@ -4,8 +4,12 @@ import time
import multiprocessing as mp
from pyinfinitensor.onnx import OnnxStub, backend
import onnx
from onnx.external_data_helper import convert_model_to_external_data
import numpy as np
from parallel import parallel_model
from parallel_opt import parallel_model
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
def parse_args():
@ -14,77 +18,126 @@ def parse_args():
parser.add_argument(
"--nproc_per_node", type=int, default=1, help="number of processes per node"
)
parser.add_argument(
"--name", type=str, default="test", help="name of this instance."
)
parser.add_argument(
"--model", type=str, required=True, help="path to the ONNX model file."
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
parser.add_argument("--length", type=int, default=1, help="sequence length.")
parser.add_argument(
"--gen_std",
action="store_true",
help="whether to generate the standard results.",
)
args = parser.parse_args()
print("arg setting: ", args)
return args.num_nodes, args.nproc_per_node, args.model
return (
args.num_nodes,
args.nproc_per_node,
args.name,
args.model,
args.batch_size,
args.length,
args.gen_std,
)
def run_stub(stub: OnnxStub, inputs: np.array, n=100):
# warm up
next(stub.inputs.items().__iter__())[1].copyin_float(inputs.reshape(-1).tolist())
def run_model(model, runtime, inputs: np.array, n=20):
stub = OnnxStub(model, runtime)
next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs)
stub.tune()
for _ in range(20):
stub.run()
stub.run()
# get outputs
outputs = np.array(next(stub.outputs.items().__iter__())[1].copyout_float())
# bench
next(stub.inputs.items().__iter__())[1].copyin_float(inputs.reshape(-1).tolist())
next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs)
begin = time.time()
for _ in range(n):
stub.run()
end = time.time()
outputs = np.array(next(stub.outputs.items().__iter__())[1].copyout_float())
print("outputs sum:", outputs.sum())
# np.save("results", outputs)
results = np.load("results.npy")
print("max diff:", abs(outputs - results).max())
assert np.allclose(outputs, results, rtol=1e-6, atol=1e-6)
avg_time = (end - begin) / n
return avg_time
print(f"average time: {avg_time}")
return outputs
def run_and_compare(name, model, runtime):
data = np.load(f"{name}_inputs.npy")
results = np.load(f"{name}_results.npy")
outputs = run_model(model, runtime, data)
print("outputs sum:", outputs.sum())
print("max abs diff:", abs(outputs - results).max())
print("max rel diff:", abs((outputs - results) / results).max())
# assert np.allclose(outputs, results, rtol=1e-3, atol=1e-6)
def start_worker(
dist_name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
):
print("start worker")
dist_name = name + "_dist"
model = parallel_model(model, world_size, rank)
extern_path = f"./{dist_name}_rank{rank}.pb"
if os.path.exists(extern_path):
os.remove(extern_path)
convert_model_to_external_data(
model,
all_tensors_to_one_file=True,
location=extern_path,
size_threshold=1024,
convert_attribute=False,
)
onnx.save(model, f"./{dist_name}_rank{rank}.onnx")
runtime = backend.CudaRuntime(local_rank)
print("init comm")
# print("init comm")
runtime.init_comm(
dist_name,
world_size,
rank,
)
model = parallel_model(model, world_size, rank)
onnx.save(model, f"dist_model_rank{rank}.onnx")
print("load model")
stub = OnnxStub(model, runtime)
data = np.load("inputs.npy")
print("run model")
avg_time = run_stub(stub, data)
print(f"average time: {avg_time}")
run_and_compare(name, model, runtime)
def start_single(name, model):
runtime = backend.CudaRuntime(0)
run_and_compare(name, model, runtime)
def gen_standard(name, model, voc_size, bs, len):
# generate standard results
data = np.random.randint(0, voc_size, (bs, len), dtype=np.int32)
np.save(f"{name}_inputs", data)
runtime = backend.CudaRuntime(0)
outputs = run_model(model, runtime, data, 1)
np.save(f"{name}_results", outputs)
def main():
nnodes, nproc_per_node, model_path = parse_args()
world_size = nnodes * nproc_per_node
nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args()
model = onnx.load(model_path)
# generate standard results
# runtime = backend.CudaRuntime(0)
# stub = OnnxStub(model, runtime)
# data = np.random.randn(1, 3, 224, 224)
# np.save("inputs", data)
# run_stub(stub, data)
# del stub
dist_name = f"dist_{os.getpid()}"
# generate standart output
if gen_std:
print(f"generate standard data for {name}.")
# a small vocabulary size to fit all LLM.
voc_size = 1000
gen_standard(name, model, voc_size, bs, length)
return
# run single process.
# use standalone process to isolate cuda.
p = mp.Process(target=start_single, args=(name, model))
p.start()
p.join()
# run distributed parallel.
world_size = nnodes * nproc_per_node
workers = [
mp.Process(
target=start_worker,
args=(dist_name, world_size, rank, rank % nproc_per_node, model),
args=(name, world_size, rank, rank % nproc_per_node, model),
)
for rank in range(world_size)
]

View File

@ -0,0 +1,221 @@
import onnx
from onnx import ModelProto, NodeProto, TensorProto, ValueInfoProto
from onnx import helper, numpy_helper
from typing import Dict, List
from placement import Placement, Replicate, Shard, _Partial
import numpy as np
def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
data = {init.name: init for init in model.graph.initializer}
vinfo = {info.name: info for info in model.graph.value_info}
vinfo.update({info.name: info for info in model.graph.input})
vinfo.update({info.name: info for info in model.graph.output})
place: Dict[str, Placement] = {}
nodes: List[NodeProto] = []
def is_sharded(name: str):
return place[name].is_shard()
def shard_tensor(tensor: TensorProto, plc: Shard, groups: int = 1):
# print(f"shard {tensor.name} at dim {dim}")
assert plc.is_shard(), plc
ndim = len(tensor.dims)
if plc.dim < 0:
plc.dim += ndim
if tensor.dims[plc.dim] == 1: # broadcast dim, no need to shard.
return tensor
array = numpy_helper.to_array(tensor)
assert array.shape[plc.dim] % tp_world_size == 0, array.shape[plc.dim]
dims = list(tensor.dims)
dims.insert(plc.dim, groups)
dims[plc.dim + 1] //= groups
array = array.reshape(dims)
seg = array.shape[plc.dim + 1] // tp_world_size
array = array.take(
indices=range(tp_rank * seg, (tp_rank + 1) * seg), axis=plc.dim + 1
)
dims = list(tensor.dims)
dims[plc.dim] //= tp_world_size
array = array.reshape(dims)
tensor = numpy_helper.from_array(array, name=tensor.name)
place[tensor.name] = plc
return tensor
def shard_gemm(node: NodeProto, groups: int = 1):
# print("gemm", node.name)
in_plc = place[node.input[0]]
w_plc = Shard(-1) if in_plc.is_replicate() else Shard(0)
transB = next((attr.i for attr in node.attribute if attr.name == "transB"), 0)
if transB:
w_plc.dim = ~w_plc.dim
input = node.input[1]
data[input] = shard_tensor(data[input], w_plc, groups)
output = node.output[0]
ndim = len(vinfo[output].type.tensor_type.shape.dim)
out_plc = Shard(ndim - 1) if in_plc.is_replicate() else _Partial()
place[node.output[0]] = out_plc
def shard_binary(node: NodeProto, groups: int = 1):
# print("binary", node.name, node.input[0], place[node.input[0]])
a = node.input[0]
b = node.input[1]
if a in data:
a, b = b, a
place[node.output[0]] = place[a]
if is_sharded(a) and b in data and len(data[b].dims) == 1: # broadcast
data[b] = shard_tensor(data[b], Shard(0), groups)
def shard_reshape(node: NodeProto):
# print("reshape", node.name, node.input[0], place[node.input[0]])
if not is_sharded(node.input[0]):
return
in_plc = place[node.input[0]]
s_dim = -1
in_dims = [d.dim_value for d in vinfo[node.input[0]].type.tensor_type.shape.dim]
tensor = data[node.input[1]]
out_dims = numpy_helper.to_array(tensor).copy()
if len(in_dims) == 3 and len(out_dims) == 4:
if in_plc.dim == 0:
s_dim = 1
elif in_plc.dim == 2:
s_dim = 2
if len(in_dims) == 4 and len(out_dims) == 3:
if in_plc.dim == 1:
s_dim = 0
elif in_plc.dim == 2:
s_dim = 2
if len(in_dims) == 2 and len(out_dims) == 3:
if in_plc.dim == 1:
s_dim = 2
if len(in_dims) == 4 and len(out_dims) == 2:
if in_plc.dim == 1:
s_dim = 0
elif in_plc.dim == 2:
s_dim = 1
if len(in_dims) == 3 and len(out_dims) == 2:
if in_plc.dim == 1:
s_dim = 0
elif in_plc.dim == 2:
s_dim = 1
assert s_dim != -1
assert out_dims[s_dim] % tp_world_size == 0, out_dims
out_dims[s_dim] //= tp_world_size
# if ONNX uses the same tensor for multiple Reshape Nodes, then rename it to distingush from others.
# node.input[1] = node.output[0] + "_shape"
data[node.input[1]] = numpy_helper.from_array(out_dims, name=node.input[1])
place[node.output[0]] = Shard(s_dim)
def shard_split(node: NodeProto):
if not is_sharded(node.input[0]):
return
in_plc = place[node.input[0]]
split_tensor = data[node.input[1]]
split = numpy_helper.to_array(split_tensor).copy()
split //= tp_world_size
data[node.input[1]] = numpy_helper.from_array(split, name=node.input[1])
for output in node.output:
place[output] = in_plc
def shard_transpose(node: NodeProto):
plc = place[node.input[0]]
if plc.is_shard():
perm = next(attr.ints for attr in node.attribute if attr.name == "perm")
place[node.output[0]] = Shard(list(perm).index(plc.dim))
def shard_node(node: NodeProto):
if node.op_type in ["Relu", "Tanh", "Softmax"]:
place[node.output[0]] = place[node.input[0]]
elif node.op_type in ["Where"]:
place[node.output[0]] = place[node.input[1]]
if node.op_type in {"Add", "Mul", "Div", "Max"}:
shard_binary(node)
elif node.op_type == "Reshape":
shard_reshape(node)
elif node.op_type == "Transpose":
shard_transpose(node)
elif node.op_type == "Split":
shard_split(node)
elif node.op_type == "MatMul":
assert (
place[node.input[0]] == place[node.input[1]]
), f"{place[node.input[0]]} != {place[node.input[1]]}"
place[node.output[0]] = place[node.input[0]]
def find_successor(op_type: str, idx: int, search_limit: int = 1):
for node in model.graph.node[idx + 1 : idx + 1 + search_limit]:
if node.op_type == op_type:
return node
return None
# all tensors are initially replicated.
for v in vinfo:
place[v] = Replicate()
for t in data:
place[t] = Replicate()
for index, node in enumerate(model.graph.node):
nodes.append(node)
# linear
if (node.op_type == "MatMul" or node.op_type == "Gemm") and any(
input in data for input in node.input
):
groups = 1
# If the Gemm or Matmul is followed by a split, then the inputs are concatinated by groups
split_node = find_successor("Split", index, search_limit=2)
if split_node is not None:
groups = len(split_node.output)
shard_gemm(node, groups)
plc = place[node.output[0]]
if plc.is_partial():
new_name = node.output[0] + f":{plc}"
place[new_name] = place[node.output[0]]
# insert all_reduce
nodes.append(
helper.make_node(
op_type="ReduceSum",
inputs=[new_name],
outputs=[node.output[0]],
name=node.name + "/all_reduce",
noop_with_empty_axes=1,
communicator=0, # hack to treat ReduceSum as AllReduceSum
)
)
place[node.output[0]] = Replicate()
node.output[0] = new_name
if len(node.input) > 2: # split bias to add
prev = nodes[-1]
new_name = prev.output[0] + "_no_bias"
place[new_name] = place[node.output[0]]
bias = helper.make_node(
op_type="Add",
inputs=[new_name, node.input[2]],
outputs=[prev.output[0]],
name=node.name + "/bias",
)
node.input.pop()
prev.output[0] = new_name
shard_binary(bias, groups)
nodes.append(bias)
continue
shard_node(node)
graph = helper.make_graph(
nodes,
model.graph.name + f"_{tp_rank}",
model.graph.input,
model.graph.output,
data.values(),
doc_string=model.graph.doc_string,
# value_info=vinfo.values(),
)
for output in graph.output:
tt = output.type.tensor_type
if tt.HasField("shape"):
tt.ClearField("shape")
model = helper.make_model(graph)
model = onnx.shape_inference.infer_shapes(model)
return model

View File

@ -0,0 +1,64 @@
from typing import Optional
class Placement:
# base class Placement type
# convenient utils to check for placement types
def is_shard(self, dim: Optional[int] = None) -> bool:
if dim is not None and isinstance(self, Shard):
return self.dim == dim
else:
return isinstance(self, Shard)
def is_replicate(self) -> bool:
return isinstance(self, Replicate)
def is_partial(self) -> bool:
return isinstance(self, _Partial)
class Replicate(Placement):
def __eq__(self, other: object) -> bool:
if not isinstance(other, Replicate):
return False
return True
def __repr__(self) -> str:
"""
machine readable representation of the Replicate placement
"""
return "Replicate()"
class Shard(Placement):
# shard placement, shard on a dim
def __init__(self, dim):
self.dim = dim
def __eq__(self, other: object) -> bool:
if not isinstance(other, Shard):
return False
return self.dim == other.dim
def __repr__(self) -> str:
"""
machine readable representation of the Shard placement
"""
return f"Shard(dim={self.dim})"
class _Partial(Placement):
def __init__(self, reduce_op: str = "sum"):
self.reduce_op: str = reduce_op
def __eq__(self, other: object) -> bool:
if not isinstance(other, _Partial):
return False
return self.reduce_op == other.reduce_op
def __repr__(self) -> str:
"""
machine readable representation of the Partial placement
"""
return f"_Partial(reduce_op={self.reduce_op})"

View File

@ -40,12 +40,12 @@ using HashType = uint64_t; // compatible with std::hash
// Assert: conditions should have no side effect
#define _IT_ASSERT_2(condition, info) \
(static_cast<bool>(condition) \
? void(0) \
: throw ::infini::Exception( \
std::string("[") + __FILE__ + ":" + std::to_string(__LINE__) + \
"] Assertion failed (" + #condition + "): " + info))
#define _IT_ASSERT_1(condition) _IT_ASSERT_2(condition, "");
static_cast<bool>(condition) \
? void(0) \
: throw ::infini::Exception( \
std::string("[") + __FILE__ + ":" + std::to_string(__LINE__) + \
"] Assertion failed (" + #condition + "): " + info)
#define _IT_ASSERT_1(condition) _IT_ASSERT_2(condition, "")
#define IT_ASSERT(...) _VA_SELECT(_IT_ASSERT, __VA_ARGS__)
#define IT_TODO_HALT() _IT_ASSERT_2(false, "Unimplemented")

View File

@ -6,16 +6,11 @@
#include <cudnn.h>
#include <curand.h>
// TODO: replace with Exception (IT_ASSERT)
#define checkCudaError(call) \
{ \
auto err = call; \
if (cudaSuccess != err) { \
fprintf(stderr, "Cuda error in %s:%i : %s.\n", __FILE__, __LINE__, \
cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \
} \
}
if (auto err = call; err != cudaSuccess) \
throw ::infini::Exception(std::string("[") + __FILE__ + ":" + \
std::to_string(__LINE__) + "] CUDA error (" + \
#call + "): " + cudaGetErrorString(err))
#define checkCUresult(call) \
{ \
@ -39,14 +34,10 @@
}
#define checkCudnnError(call) \
{ \
auto err = call; \
if (CUDNN_STATUS_SUCCESS != err) { \
fprintf(stderr, "cuDNN error in %s:%i : %s.\n", __FILE__, \
__LINE__, cudnnGetErrorString(err)); \
exit(EXIT_FAILURE); \
} \
}
if (auto err = call; err != CUDNN_STATUS_SUCCESS) \
throw ::infini::Exception(std::string("[") + __FILE__ + ":" + \
std::to_string(__LINE__) + "] cuDNN error (" + \
#call + "): " + cudnnGetErrorString(err))
#define checkCurandError(call) \
{ \

View File

@ -5,8 +5,18 @@
namespace infini {
class Exception : public std::runtime_error {
protected:
std::string info;
public:
Exception(const std::string &msg);
Exception &operator<<(const std::string &str) {
info += str;
return *this;
}
const char *what() const noexcept override { return info.c_str(); }
};
} // namespace infini

View File

@ -1,3 +1,4 @@
#pragma once
namespace infini {
#define SMALL_ARRAY_SIZE 8

View File

@ -591,6 +591,13 @@ class OnnxStub:
tensors.get(node.output[0]),
next((attr.i for attr in node.attribute if attr.name == "to")),
)
elif node.op_type == "ReduceSum":
# ReduceSum is only implemented as allReduceSum.
assert any(attr.name == "communicator" for attr in node.attribute)
tensors[node.output[0]] = self.handler.allReduceSum(
tensors[node.input[0]],
tensors.get(node.output[0]),
)
elif node.op_type == "AllReduceSum":
tensors[node.output[0]] = self.handler.allReduceSum(
tensors[node.input[0]],
@ -631,13 +638,9 @@ class OnnxStub:
tensors[node.input[0]],
tensors.get(node.output[0]),
next(
(
attr.i
for attr in node.attribute
if attr.name == "root"
),
0,
),
(attr.i for attr in node.attribute if attr.name == "root"),
0,
),
)
elif node.op_type == "Expand":
shape = _parse_data(data[node.input[1]])

View File

@ -329,7 +329,7 @@ class TestStringMethods(unittest.TestCase):
[pads_data],
)
)
def test_allReduceSum(self):
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
@ -349,7 +349,7 @@ class TestStringMethods(unittest.TestCase):
graph = make_graph([allReduceProd], "allReduceProd", [input], [output])
model = make_model(graph)
from_onnx(model, backend.cpu_runtime())
def test_allReduceMin(self):
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
@ -379,14 +379,12 @@ class TestStringMethods(unittest.TestCase):
graph = make_graph([allReduceAvg], "allReduceAvg", [input], [output])
model = make_model(graph)
from_onnx(model, backend.cpu_runtime())
def test_split(self):
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
split = make_node(
"Split", ["input"], ["output"], name="split", axis=0
)
split = make_node("Split", ["input"], ["output"], name="split", axis=0)
make_and_import_model(make_graph([split], "split", [input], []))
def test_allBroadcast(self):
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
@ -461,7 +459,7 @@ class TestStringMethods(unittest.TestCase):
make_and_import_model(make_graph([where], "where", [x, y, con], [output]))
def test_copyin(self):
dims = [2,3,5,4]
dims = [2, 3, 5, 4]
np_array = np.random.random(dims).astype(np.float32)
handler = backend.GraphHandler(backend.cpu_runtime())
tensor1 = handler.tensor(dims, TensorProto.FLOAT)
@ -487,7 +485,7 @@ class TestStringMethods(unittest.TestCase):
self.assertTrue(np.array_equal(np.array(array1).reshape(dims), np_array))
def test_to_numpy(self):
dims = [2,3,5,4]
dims = [2, 3, 5, 4]
np_array = np.random.random(dims).astype(np.float32)
handler = backend.GraphHandler(backend.cpu_runtime())
tensor1 = handler.tensor(dims, TensorProto.FLOAT)
@ -508,5 +506,6 @@ class TestStringMethods(unittest.TestCase):
array1 = np.array(tensor1, copy=False)
self.assertTrue(np.array_equal(array1, np_array))
if __name__ == "__main__":
unittest.main()

View File

@ -8,7 +8,6 @@
#include "operators/conv.h"
#include "operators/matmul.h"
#ifdef DEBUG_MODE
void CHECK_CUDA_KERNEL_ERROR(infini::Operator op) {
cudaError_t kernelError = cudaGetLastError();
if (kernelError != cudaSuccess) {
@ -18,7 +17,6 @@ void CHECK_CUDA_KERNEL_ERROR(infini::Operator op) {
exit(EXIT_FAILURE);
}
}
#endif
namespace infini {
@ -38,10 +36,7 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
} else {
kernel->compute(op, this);
}
#ifdef DEBUG_MODE
CHECK_CUDA_KERNEL_ERROR(op);
#endif
checkCudaError(cudaGetLastError()) << op->toString();
}
}
@ -78,9 +73,7 @@ void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const {
opCnt[op->getOpType()]++;
}
#ifdef DEBUG_MODE
CHECK_CUDA_KERNEL_ERROR(op);
#endif
checkCudaError(cudaGetLastError()) << op->toString();
}
}
@ -103,6 +96,7 @@ void CudaRuntimeObj::initComm(const string &name, int worldSize, int rank) {
IT_ASSERT(worldSize > 0);
IT_ASSERT(rank >= 0);
IT_ASSERT(rank < worldSize);
IT_ASSERT(!comm) << "communicator is already initialized.";
#ifdef INFINI_USE_NCCL
comm = std::make_unique<NcclCommunicatorObj>(name, worldSize, rank);
#else

View File

@ -421,6 +421,8 @@ void init_graph_builder(py::module &m) {
.def("mul", &Handler::mul, policy::move)
.def("div", &Handler::div, policy::move)
.def("pow", &Handler::pow, policy::move)
.def("min", &Handler::min, policy::move)
.def("max", &Handler::max, policy::move)
.def("relu", &Handler::relu, policy::move)
.def("sigmoid", &Handler::sigmoid, policy::move)
.def("tanh", &Handler::tanh, policy::move)

View File

@ -14,7 +14,7 @@ class AllReduceNCCL : public CudaKernelWithoutConfig {
void *input = op->getInputs(0)->getRawDataPtr<void *>();
void *output = op->getOutput()->getRawDataPtr<void *>();
IT_ASSERT(op->getDType() == DataType::Float32);
size_t count = op->getInputs(0)->getBytes() / op->getDType().getSize();
size_t count = op->getInputs(0)->size();
ncclComm_t comm =
dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator())

View File

@ -1,6 +1,8 @@
#include "operators/matmul.h"
#include "core/kernel.h"
#include "cuda/cuda_expand.h"
#include "cuda/cuda_runtime.h"
#include "utils/small_array.h"
namespace infini {
@ -46,7 +48,30 @@ class matmulCublas : public Kernel {
auto opB = op->getTransB() ? CUBLAS_OP_T : CUBLAS_OP_N;
const int lda = op->getTransA() ? m : k, ldb = op->getTransB() ? k : n,
ldc = n;
const float alpha = 1.f, beta = 0.f;
float alpha = 1.f, beta = 0.f;
if (op->numInputs() == 2) { // no bias
beta = 0.f;
} else { // broadcast bias to output
beta = 1.f;
auto inC = op->getInputs(2);
auto out = op->getOutput();
SmallArray inputShape, outputShape;
int nDims = out->getRank();
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
int outputsize = 1; // the length of the output vector after flatten
int offset = nDims - inC->getRank();
for (int i = 0; i < offset; ++i)
inputShape.data[i] = 1;
for (int i = 0; i < nDims; ++i) {
outputShape.data[i] = out->getDims()[i];
outputsize *= outputShape.data[i];
if (i >= offset)
inputShape.data[i] = inC->getDims()[i - offset];
}
expandKernel(inC->getRawDataPtr<float *>(),
out->getRawDataPtr<float *>(), nDims, outputsize,
inputShape, outputShape);
}
// TODO:use compute type
cublasStatus_t stat;
if (b > 1) {

View File

@ -6,7 +6,7 @@ GatherObj::GatherObj(GraphObj *graph, Tensor input, Tensor indices,
Tensor output, int axis)
: OperatorObj(OpType::Gather, {input, indices}, {output}), axis(axis) {
int rank = input->getRank();
axis = get_real_axis(axis, rank);
this->axis = get_real_axis(axis, rank);
IT_ASSERT(checkValid(graph));
}
@ -25,7 +25,7 @@ optional<vector<Shape>> GatherObj::inferShape(const TensorVec &inputs) const {
vector<DataType> GatherObj::inferDataType(const TensorVec &inputs) const {
IT_ASSERT(inputs.size() == 2);
auto index_dtype = inputs[1]->getDType();
IT_ASSERT(index_dtype == DataType::Int32 || index_dtype == DataType::Int64)
IT_ASSERT(index_dtype == DataType::Int32 || index_dtype == DataType::Int64);
return {inputs[0]->getDType()};
}

View File

@ -9,7 +9,8 @@ namespace backward_trace = backward;
backward_trace::SignalHandling sh;
namespace infini {
Exception::Exception(const std::string &msg) : std::runtime_error(msg) {
Exception::Exception(const std::string &msg)
: std::runtime_error(msg), info(msg) {
backward_trace::StackTrace st;
st.load_here(32);
backward_trace::Printer p;