forked from jiuyuan/InfiniTensor
Add send and recv operators based on NCCL (#182)
* baseline sendrecv, bug * success sendrecv * get rank from comm * set output shape * successful:set output shape equal to input shape * shape as attribute * success:shape as attribute * success send recv, output 0 * add onnx test * split send and recv * success split send and recv * test-onnx bug * success test-onnx * modified onnx.py * solve review
This commit is contained in:
parent
c143eebdf7
commit
a3929c25f8
|
@ -94,6 +94,9 @@ class GraphHandlerObj {
|
||||||
Tensor allReduceAvg(Tensor input, Tensor output);
|
Tensor allReduceAvg(Tensor input, Tensor output);
|
||||||
TensorVec allGather(Tensor input, std::optional<TensorVec> outputs, int n);
|
TensorVec allGather(Tensor input, std::optional<TensorVec> outputs, int n);
|
||||||
Tensor broadcast(Tensor input, Tensor output, int root);
|
Tensor broadcast(Tensor input, Tensor output, int root);
|
||||||
|
Tensor send(Tensor input, int source, int destination, Tensor output);
|
||||||
|
Tensor recv(Tensor output, int source, int destination, Shape dims,
|
||||||
|
int outputType, Tensor input);
|
||||||
Tensor depthToSpace(Tensor input, Tensor output, int blocksize,
|
Tensor depthToSpace(Tensor input, Tensor output, int blocksize,
|
||||||
std::string mode);
|
std::string mode);
|
||||||
|
|
||||||
|
|
|
@ -232,6 +232,8 @@ struct OpType {
|
||||||
AllReduceAvg,
|
AllReduceAvg,
|
||||||
AllGather,
|
AllGather,
|
||||||
Broadcast,
|
Broadcast,
|
||||||
|
Send,
|
||||||
|
Recv,
|
||||||
} type;
|
} type;
|
||||||
|
|
||||||
constexpr OpType(decltype(type) t) : type(t) {}
|
constexpr OpType(decltype(type) t) : type(t) {}
|
||||||
|
|
|
@ -0,0 +1,46 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/operator.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2193/user-guide/docs/index.html
|
||||||
|
*/
|
||||||
|
class RecvObj : public OperatorObj {
|
||||||
|
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief Construct a new SendRecv object
|
||||||
|
*
|
||||||
|
* @param graph The computation graph that this operator belongs to.
|
||||||
|
* @param input default nullptr, because recv does not have input.
|
||||||
|
* @param output recv output
|
||||||
|
* @param source the send rank
|
||||||
|
* @param destination the recv rank
|
||||||
|
* @param dims The shape of the output tensor.
|
||||||
|
*/
|
||||||
|
RecvObj(GraphObj *graph, Tensor output, int source, int destination,
|
||||||
|
Shape dims, int outputType, Tensor input = nullptr);
|
||||||
|
OP_CLONE(RecvObj);
|
||||||
|
|
||||||
|
int numInputs() const override { return inputs.size(); }
|
||||||
|
int numOutputs() const override { return 1; }
|
||||||
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||||
|
std::string toString() const override;
|
||||||
|
DataType getDType() const;
|
||||||
|
int getSourceRank() const { return source; }
|
||||||
|
int getDestinationRank() const { return destination; }
|
||||||
|
inline Shape getShape() const { return dims; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
vector<int> getWorkloadVector() const override;
|
||||||
|
vector<int> getOpAttrVector() const override;
|
||||||
|
vector<DataType> inferDataType(const TensorVec &inputs) const override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
int source;
|
||||||
|
int destination;
|
||||||
|
Shape dims;
|
||||||
|
int outputType;
|
||||||
|
};
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,42 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/operator.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2193/user-guide/docs/index.html
|
||||||
|
*/
|
||||||
|
class SendObj : public OperatorObj {
|
||||||
|
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief Construct a new SendRecv object
|
||||||
|
*
|
||||||
|
* @param graph The computation graph that this operator belongs to.
|
||||||
|
* @param input send input
|
||||||
|
* @param output recv output
|
||||||
|
* @param source the send rank
|
||||||
|
* @param destination the recv rank
|
||||||
|
*/
|
||||||
|
SendObj(GraphObj *graph, Tensor input, int source, int destination,
|
||||||
|
Tensor output = nullptr);
|
||||||
|
OP_CLONE(SendObj);
|
||||||
|
|
||||||
|
int numInputs() const override { return 1; }
|
||||||
|
int numOutputs() const override { return outputs.size(); }
|
||||||
|
std::string toString() const override;
|
||||||
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||||
|
|
||||||
|
int getSourceRank() const { return source; }
|
||||||
|
int getDestinationRank() const { return destination; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
vector<int> getWorkloadVector() const override;
|
||||||
|
vector<int> getOpAttrVector() const override;
|
||||||
|
vector<DataType> inferDataType(const TensorVec &inputs) const override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
int source;
|
||||||
|
int destination;
|
||||||
|
};
|
||||||
|
} // namespace infini
|
|
@ -716,14 +716,17 @@ class OnnxStub:
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
keepdims = next(
|
keepdims = (
|
||||||
|
next(
|
||||||
(
|
(
|
||||||
attr.i
|
attr.i
|
||||||
for attr in node.attribute
|
for attr in node.attribute
|
||||||
if attr.name == "keepdims"
|
if attr.name == "keepdims"
|
||||||
),
|
),
|
||||||
1,
|
1,
|
||||||
) != 0
|
)
|
||||||
|
!= 0
|
||||||
|
)
|
||||||
|
|
||||||
tensors[node.output[0]] = self.handler.reduceSum(
|
tensors[node.output[0]] = self.handler.reduceSum(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
|
@ -775,6 +778,58 @@ class OnnxStub:
|
||||||
0,
|
0,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
elif node.op_type == "Send":
|
||||||
|
source = next(
|
||||||
|
(attr.i for attr in node.attribute if attr.name == "source"),
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
destination = next(
|
||||||
|
(
|
||||||
|
attr.i
|
||||||
|
for attr in node.attribute
|
||||||
|
if attr.name == "destination"
|
||||||
|
),
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.handler.send(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
source,
|
||||||
|
destination,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
elif node.op_type == "Recv":
|
||||||
|
source = next(
|
||||||
|
(attr.i for attr in node.attribute if attr.name == "source"),
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
destination = next(
|
||||||
|
(
|
||||||
|
attr.i
|
||||||
|
for attr in node.attribute
|
||||||
|
if attr.name == "destination"
|
||||||
|
),
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
|
for attr in node.attribute:
|
||||||
|
if attr.name == "shape":
|
||||||
|
shapeBasic = attr.ints
|
||||||
|
shape = []
|
||||||
|
for item in shapeBasic:
|
||||||
|
shape.append(item)
|
||||||
|
|
||||||
|
for attr in node.attribute:
|
||||||
|
if attr.name == "dataType":
|
||||||
|
outputType = attr.i
|
||||||
|
tensors[node.output[0]] = self.handler.recv(
|
||||||
|
tensors.get(node.output[0]),
|
||||||
|
source,
|
||||||
|
destination,
|
||||||
|
shape,
|
||||||
|
outputType,
|
||||||
|
None,
|
||||||
|
)
|
||||||
elif node.op_type == "Expand":
|
elif node.op_type == "Expand":
|
||||||
shape = _parse_data(data[node.input[1]])
|
shape = _parse_data(data[node.input[1]])
|
||||||
tensors[node.output[0]] = self.handler.expand(
|
tensors[node.output[0]] = self.handler.expand(
|
||||||
|
@ -1091,10 +1146,7 @@ class OnnxStub:
|
||||||
elif ty == backend.OpTypeId.Gather:
|
elif ty == backend.OpTypeId.Gather:
|
||||||
axis = backend.gather_axis_of(op)
|
axis = backend.gather_axis_of(op)
|
||||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
||||||
elif ty in [
|
elif ty in [backend.OpTypeId.ReduceMean, backend.OpTypeId.ReduceSum]:
|
||||||
backend.OpTypeId.ReduceMean,
|
|
||||||
backend.OpTypeId.ReduceSum
|
|
||||||
]:
|
|
||||||
axes, keepdims = backend.reduce_attrs_of(op)
|
axes, keepdims = backend.reduce_attrs_of(op)
|
||||||
inputs.append(
|
inputs.append(
|
||||||
ctx.push_data_input(
|
ctx.push_data_input(
|
||||||
|
|
|
@ -508,6 +508,29 @@ class TestStringMethods(unittest.TestCase):
|
||||||
where = make_node("Where", ["x", "y", "con"], ["output"], name="where")
|
where = make_node("Where", ["x", "y", "con"], ["output"], name="where")
|
||||||
make_and_import_model(make_graph([where], "where", [x, y, con], [output]))
|
make_and_import_model(make_graph([where], "where", [x, y, con], [output]))
|
||||||
|
|
||||||
|
def test_send(self):
    # Import-only check: a single Send node with one input and no outputs
    # must be accepted by the ONNX front end on the CPU runtime.
    input_info = make_tensor_value_info(
        "input", TensorProto.FLOAT, [1, 3, 5, 7]
    )
    send_node = make_node(
        "Send",
        ["input"],
        [],
        name="send",
        source=0,
        destination=1,
    )
    send_graph = make_graph([send_node], "send", [input_info], [])
    from_onnx(make_model(send_graph), backend.cpu_runtime())
|
||||||
|
|
||||||
|
def test_recv(self):
    # Import-only check: a single Recv node with no inputs; the output
    # shape and data type are carried as node attributes.
    output_dims = [1, 3, 5, 7]
    output_info = make_tensor_value_info(
        "output", TensorProto.FLOAT, [1, 3, 5, 7]
    )
    recv_attrs = {
        "source": 0,
        "destination": 1,
        "shape": output_dims,
        "dataType": 1,
    }
    recv_node = make_node("Recv", [], ["output"], name="recv", **recv_attrs)
    recv_graph = make_graph([recv_node], "recv", [], [output_info])
    from_onnx(make_model(recv_graph), backend.cpu_runtime())
|
||||||
|
|
||||||
|
|
||||||
class TestDynamicTensor(unittest.TestCase):
|
class TestDynamicTensor(unittest.TestCase):
|
||||||
def test_dynamic_tensor(self):
|
def test_dynamic_tensor(self):
|
||||||
|
@ -517,6 +540,7 @@ class TestDynamicTensor(unittest.TestCase):
|
||||||
for root, dirs, files in os.walk(current_path):
|
for root, dirs, files in os.walk(current_path):
|
||||||
if filename in files:
|
if filename in files:
|
||||||
model_file = os.path.join(root, filename)
|
model_file = os.path.join(root, filename)
|
||||||
|
|
||||||
model = OnnxStub(onnx.load(model_file), backend.cpu_runtime())
|
model = OnnxStub(onnx.load(model_file), backend.cpu_runtime())
|
||||||
output_key = list(model.outputs.keys())[0]
|
output_key = list(model.outputs.keys())[0]
|
||||||
old_output_shape = model.getShape(output_key)
|
old_output_shape = model.getShape(output_key)
|
||||||
|
|
|
@ -11,20 +11,33 @@ GraphObj::GraphObj(Runtime runtime, OpVec ops_in)
|
||||||
map<UidBaseType, Tensor> tensorPool;
|
map<UidBaseType, Tensor> tensorPool;
|
||||||
// Clone tensors
|
// Clone tensors
|
||||||
for (const auto &op : ops_in) {
|
for (const auto &op : ops_in) {
|
||||||
for (const auto &t : op->getInputs())
|
for (const auto &t : op->getInputs()) {
|
||||||
|
if (t) {
|
||||||
if (tensorPool.find(t->getFuid()) == tensorPool.end())
|
if (tensorPool.find(t->getFuid()) == tensorPool.end())
|
||||||
tensorPool[t->getFuid()] = cloneTensor(t);
|
tensorPool[t->getFuid()] = cloneTensor(t);
|
||||||
for (const auto &t : op->getOutputs())
|
}
|
||||||
|
}
|
||||||
|
for (const auto &t : op->getOutputs()) {
|
||||||
|
if (t) {
|
||||||
if (tensorPool.find(t->getFuid()) == tensorPool.end())
|
if (tensorPool.find(t->getFuid()) == tensorPool.end())
|
||||||
tensorPool[t->getFuid()] = cloneTensor(t);
|
tensorPool[t->getFuid()] = cloneTensor(t);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
// Clone operators and add connections
|
// Clone operators and add connections
|
||||||
for (const auto &op : ops_in) {
|
for (const auto &op : ops_in) {
|
||||||
TensorVec inputs, outputs;
|
TensorVec inputs, outputs;
|
||||||
for (const auto &t : op->getInputs())
|
for (const auto &t : op->getInputs()) {
|
||||||
|
if (t) {
|
||||||
inputs.emplace_back(tensorPool.at(t->getFuid()));
|
inputs.emplace_back(tensorPool.at(t->getFuid()));
|
||||||
for (const auto &t : op->getOutputs())
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto &t : op->getOutputs()) {
|
||||||
|
if (t) {
|
||||||
outputs.emplace_back(tensorPool.at(t->getFuid()));
|
outputs.emplace_back(tensorPool.at(t->getFuid()));
|
||||||
|
}
|
||||||
|
}
|
||||||
addOperatorAndConnect(op->clone(inputs, outputs));
|
addOperatorAndConnect(op->clone(inputs, outputs));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -33,13 +46,16 @@ void GraphObj::addOperatorAndConnect(const Operator &op) {
|
||||||
sorted = false;
|
sorted = false;
|
||||||
ops.push_back(op);
|
ops.push_back(op);
|
||||||
for (auto &input : op->getInputs()) {
|
for (auto &input : op->getInputs()) {
|
||||||
|
if (input) {
|
||||||
input->addTarget(op);
|
input->addTarget(op);
|
||||||
if (auto pred = input->getSource()) {
|
if (auto pred = input->getSource()) {
|
||||||
pred->addSuccessors(op);
|
pred->addSuccessors(op);
|
||||||
op->addPredecessors(pred);
|
op->addPredecessors(pred);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
for (auto &output : op->getOutputs()) {
|
for (auto &output : op->getOutputs()) {
|
||||||
|
if (output) {
|
||||||
output->setSource(op);
|
output->setSource(op);
|
||||||
for (auto &succ : output->getTargets()) {
|
for (auto &succ : output->getTargets()) {
|
||||||
succ->addPredecessors(op);
|
succ->addPredecessors(op);
|
||||||
|
@ -47,6 +63,7 @@ void GraphObj::addOperatorAndConnect(const Operator &op) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
string GraphObj::toString() const {
|
string GraphObj::toString() const {
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
|
@ -88,8 +105,9 @@ bool GraphObj::topo_sort() {
|
||||||
const auto is_head = std::all_of(
|
const auto is_head = std::all_of(
|
||||||
this_inputs.begin(), this_inputs.end(), [&](const auto &input) {
|
this_inputs.begin(), this_inputs.end(), [&](const auto &input) {
|
||||||
auto src = input->getSource();
|
auto src = input->getSource();
|
||||||
return src // If the source node is in the waiting list,
|
return src // If the source node is in the waiting
|
||||||
// means that this node is not the head node.
|
// list, means that this node is not the
|
||||||
|
// head node.
|
||||||
? waiting.find(src) == waiting.end()
|
? waiting.find(src) == waiting.end()
|
||||||
// This tensor has no source node,
|
// This tensor has no source node,
|
||||||
// it must be a input tensor.
|
// it must be a input tensor.
|
||||||
|
@ -110,7 +128,6 @@ bool GraphObj::topo_sort() {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Done.
|
// Done.
|
||||||
this->ops = std::move(sorted);
|
this->ops = std::move(sorted);
|
||||||
return this->sorted = true;
|
return this->sorted = true;
|
||||||
|
@ -155,6 +172,7 @@ void GraphObj::shape_infer() {
|
||||||
|
|
||||||
void GraphObj::dataMalloc(bool useNaiveAllocator, size_t memPoolSize) {
|
void GraphObj::dataMalloc(bool useNaiveAllocator, size_t memPoolSize) {
|
||||||
// topological sorting first
|
// topological sorting first
|
||||||
|
|
||||||
IT_ASSERT(topo_sort() == true);
|
IT_ASSERT(topo_sort() == true);
|
||||||
if (useNaiveAllocator) {
|
if (useNaiveAllocator) {
|
||||||
// can not set memory pool when use naive allocator
|
// can not set memory pool when use naive allocator
|
||||||
|
@ -222,13 +240,16 @@ void GraphObj::dataMalloc(bool useNaiveAllocator, size_t memPoolSize) {
|
||||||
// memory should be allocated for the op's output first
|
// memory should be allocated for the op's output first
|
||||||
auto outputs = op->getOutputs();
|
auto outputs = op->getOutputs();
|
||||||
for (auto &tensor : outputs) {
|
for (auto &tensor : outputs) {
|
||||||
|
if (tensor) {
|
||||||
if (tensor->isOthers()) {
|
if (tensor->isOthers()) {
|
||||||
tensorToOffset[tensor.get()] =
|
tensorToOffset[tensor.get()] =
|
||||||
allocator.alloc(tensor->getBytes());
|
allocator.alloc(tensor->getBytes());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
auto inputs = op->getInputs();
|
auto inputs = op->getInputs();
|
||||||
for (auto &tensor : inputs) {
|
for (auto &tensor : inputs) {
|
||||||
|
if (tensor) {
|
||||||
if (tensor->isOthers()) {
|
if (tensor->isOthers()) {
|
||||||
auto tensorIter = tensorToRefCount.find(tensor.get());
|
auto tensorIter = tensorToRefCount.find(tensor.get());
|
||||||
IT_ASSERT(tensorIter != tensorToRefCount.end());
|
IT_ASSERT(tensorIter != tensorToRefCount.end());
|
||||||
|
@ -244,6 +265,7 @@ void GraphObj::dataMalloc(bool useNaiveAllocator, size_t memPoolSize) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// perform actual memory allocation for non-weight tensors
|
// perform actual memory allocation for non-weight tensors
|
||||||
for (auto &tensor : tensors) {
|
for (auto &tensor : tensors) {
|
||||||
|
|
|
@ -13,8 +13,10 @@
|
||||||
#include "operators/matmul.h"
|
#include "operators/matmul.h"
|
||||||
#include "operators/pad.h"
|
#include "operators/pad.h"
|
||||||
#include "operators/pooling.h"
|
#include "operators/pooling.h"
|
||||||
|
#include "operators/recv.h"
|
||||||
#include "operators/reduce.h"
|
#include "operators/reduce.h"
|
||||||
#include "operators/reshape.h"
|
#include "operators/reshape.h"
|
||||||
|
#include "operators/send.h"
|
||||||
#include "operators/slice.h"
|
#include "operators/slice.h"
|
||||||
#include "operators/softmax.h"
|
#include "operators/softmax.h"
|
||||||
#include "operators/split.h"
|
#include "operators/split.h"
|
||||||
|
@ -434,6 +436,39 @@ Tensor GraphHandlerObj::broadcast(Tensor input, Tensor output, int root) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Add a Send operator to the graph.
 *
 * When `output` is null the operator is created through addOp and the tensor
 * reported by the new operator's getOutput() is returned; otherwise the
 * operator is attached to the supplied tensor, which is returned as is.
 */
Tensor GraphHandlerObj::send(Tensor input, int source, int destination,
                             Tensor output) {
    if (!output) {
        auto sendOp = g->addOp<SendObj>(std::move(input), source, destination,
                                        output);
        return sendOp->getOutput();
    }

    g->addOpWithOutputs<SendObj>(std::move(input), source, destination,
                                 output);
    return output;
}
|
||||||
|
|
||||||
|
/**
 * Add a Recv operator to the graph.
 *
 * When `output` is null the operator is created through addOp and the tensor
 * reported by the new operator's getOutput() is returned; otherwise the
 * operator is attached to the supplied tensor, which is returned as is.
 * `dims` and `outputType` are forwarded to RecvObj unchanged.
 */
Tensor GraphHandlerObj::recv(Tensor output, int source, int destination,
                             Shape dims, int outputType, Tensor input) {
    if (!output) {
        auto recvOp =
            g->addOp<RecvObj>(output, source, destination, std::move(dims),
                              outputType, std::move(input));
        return recvOp->getOutput();
    }

    g->addOpWithOutputs<RecvObj>(output, source, destination, std::move(dims),
                                 outputType, std::move(input));
    return output;
}
|
||||||
|
|
||||||
Tensor GraphHandlerObj::cast(Tensor input, Tensor output, int to) {
|
Tensor GraphHandlerObj::cast(Tensor input, Tensor output, int to) {
|
||||||
if (output) {
|
if (output) {
|
||||||
g->addOpWithOutputs<CastObj>(std::move(input), output,
|
g->addOpWithOutputs<CastObj>(std::move(input), output,
|
||||||
|
|
|
@ -6,9 +6,11 @@ namespace infini {
|
||||||
|
|
||||||
OperatorObj::OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs)
|
OperatorObj::OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs)
|
||||||
: type(opType), inputs(inputs), outputs(outputs) {
|
: type(opType), inputs(inputs), outputs(outputs) {
|
||||||
|
if (opType != OpType::Recv) {
|
||||||
for (const auto &t : inputs)
|
for (const auto &t : inputs)
|
||||||
IT_ASSERT(t);
|
IT_ASSERT(t);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void OperatorObj::removePredecessors(const Operator &op) {
|
void OperatorObj::removePredecessors(const Operator &op) {
|
||||||
for (auto it = predecessors.begin(); it != predecessors.end();) {
|
for (auto it = predecessors.begin(); it != predecessors.end();) {
|
||||||
|
|
|
@ -511,6 +511,8 @@ void init_graph_builder(py::module &m) {
|
||||||
.def("allReduceAvg", &Handler::allReduceAvg, policy::move)
|
.def("allReduceAvg", &Handler::allReduceAvg, policy::move)
|
||||||
.def("allGather", &Handler::allGather, policy::move)
|
.def("allGather", &Handler::allGather, policy::move)
|
||||||
.def("broadcast", &Handler::broadcast, policy::move)
|
.def("broadcast", &Handler::broadcast, policy::move)
|
||||||
|
.def("send", &Handler::send, policy::move)
|
||||||
|
.def("recv", &Handler::recv, policy::move)
|
||||||
.def("cast", &Handler::cast, policy::move)
|
.def("cast", &Handler::cast, policy::move)
|
||||||
.def("expand", &Handler::expand, policy::move)
|
.def("expand", &Handler::expand, policy::move)
|
||||||
.def("erf", &Handler::erf, policy::move)
|
.def("erf", &Handler::erf, policy::move)
|
||||||
|
|
|
@ -0,0 +1,47 @@
|
||||||
|
#ifdef INFINI_USE_NCCL
|
||||||
|
#include "operators/recv.h"
|
||||||
|
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/nccl_communicator.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
class RecvNCCL : public CudaKernelWithoutConfig {
|
||||||
|
public:
|
||||||
|
void compute(const Operator &_op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
auto op = as<RecvObj>(_op);
|
||||||
|
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||||
|
|
||||||
|
void *output = op->getOutput(0)->getRawDataPtr<void *>();
|
||||||
|
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||||
|
const auto shape = op->getShape();
|
||||||
|
int nDims = shape.size();
|
||||||
|
int outputCount = 1;
|
||||||
|
for (int i = 0; i < nDims; i++) {
|
||||||
|
outputCount *= shape[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
ncclComm_t comm =
|
||||||
|
dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator())
|
||||||
|
.getNcclComm();
|
||||||
|
// TODO: Using default stream 0 for now.
|
||||||
|
int rank;
|
||||||
|
|
||||||
|
checkNcclError(ncclCommUserRank(comm, &rank));
|
||||||
|
|
||||||
|
int source = op->getSourceRank();
|
||||||
|
int destination = op->getDestinationRank();
|
||||||
|
|
||||||
|
if (rank == destination) {
|
||||||
|
|
||||||
|
checkNcclError(
|
||||||
|
ncclRecv(output, outputCount, ncclFloat, source, comm, 0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::Recv, DataType::Float32, RecvNCCL,
|
||||||
|
"Recv_NCCL_CUDA_Float32");
|
||||||
|
} // namespace infini
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,43 @@
|
||||||
|
#ifdef INFINI_USE_NCCL
|
||||||
|
#include "operators/send.h"
|
||||||
|
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/nccl_communicator.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
class SendNCCL : public CudaKernelWithoutConfig {
|
||||||
|
public:
|
||||||
|
void compute(const Operator &_op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
auto op = as<SendObj>(_op);
|
||||||
|
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||||
|
void *input = op->getInputs(0)->getRawDataPtr<void *>();
|
||||||
|
|
||||||
|
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||||
|
size_t inputCount =
|
||||||
|
op->getInputs(0)->getBytes() / op->getDType().getSize();
|
||||||
|
|
||||||
|
ncclComm_t comm =
|
||||||
|
dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator())
|
||||||
|
.getNcclComm();
|
||||||
|
// TODO: Using default stream 0 for now.
|
||||||
|
int rank;
|
||||||
|
|
||||||
|
checkNcclError(ncclCommUserRank(comm, &rank));
|
||||||
|
|
||||||
|
int source = op->getSourceRank();
|
||||||
|
int destination = op->getDestinationRank();
|
||||||
|
|
||||||
|
if (rank == source) {
|
||||||
|
|
||||||
|
checkNcclError(
|
||||||
|
ncclSend(input, inputCount, ncclFloat, destination, comm, 0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::Send, DataType::Float32, SendNCCL,
|
||||||
|
"Send_NCCL_CUDA_Float32");
|
||||||
|
} // namespace infini
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,49 @@
|
||||||
|
#include "operators/recv.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
RecvObj::RecvObj(GraphObj *graph, Tensor output, int source, int destination,
|
||||||
|
Shape dims, int outputType, [[maybe_unused]] Tensor input)
|
||||||
|
: OperatorObj(OpType::Recv, input ? TensorVec{input} : TensorVec{},
|
||||||
|
TensorVec{output}),
|
||||||
|
source(source), destination(destination), dims(std::move(dims)),
|
||||||
|
outputType(outputType) {
|
||||||
|
|
||||||
|
IT_ASSERT(checkValid(graph));
|
||||||
|
}
|
||||||
|
// The single output always has the shape stored in `dims`; the inputs are
// not consulted.
optional<vector<Shape>> RecvObj::inferShape(const TensorVec &inputs) {
    vector<Shape> outputShapes{dims};
    return outputShapes;
}
|
||||||
|
// The single output's data type is built from the stored integer type code;
// the inputs are not consulted.
vector<DataType> RecvObj::inferDataType(const TensorVec &inputs) const {
    vector<DataType> outputTypes{DataType(outputType)};
    return outputTypes;
}
|
||||||
|
// Data type of the operator, taken from its first output tensor.
DataType RecvObj::getDType() const {
    const auto &out = getOutput(0);
    return out->getDType();
}
|
||||||
|
// Renders "Recv[<guid>](<dims>,output=<output guid>,dims=<dims>)".
std::string RecvObj::toString() const {
    const auto dimsText = vecToString(dims);
    std::ostringstream buf;
    buf << "Recv"
        << "[" << getGuid() << "]"
        << "(" << dimsText << ","
        << "output=" << outputs[0]->getGuid() << ","
        << "dims=" << dimsText << ")";
    return buf.str();
}
|
||||||
|
|
||||||
|
// Layout: [op type, dims..., dims..., source, destination].
// Note that `dims` is emitted twice.
vector<int> RecvObj::getWorkloadVector() const {
    vector<int> workload;
    workload.reserve(2 * dims.size() + 3);

    workload.push_back(type.underlying());
    workload.insert(workload.end(), dims.begin(), dims.end());
    workload.insert(workload.end(), dims.begin(), dims.end());
    workload.push_back(source);
    workload.push_back(destination);

    return workload;
}
|
||||||
|
|
||||||
|
// Layout: [op type, dims..., source, destination].
vector<int> RecvObj::getOpAttrVector() const {
    vector<int> attrs;
    attrs.reserve(dims.size() + 3);

    attrs.push_back(type.underlying());
    attrs.insert(attrs.end(), dims.begin(), dims.end());
    attrs.push_back(source);
    attrs.push_back(destination);

    return attrs;
}
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,46 @@
|
||||||
|
#include "operators/send.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
SendObj::SendObj(GraphObj *graph, Tensor input, int source, int destination,
|
||||||
|
[[maybe_unused]] Tensor output)
|
||||||
|
: OperatorObj(OpType::Send, TensorVec{input},
|
||||||
|
TensorVec{output ? output : nullptr}),
|
||||||
|
source(source), destination(destination) {
|
||||||
|
|
||||||
|
IT_ASSERT(checkValid(graph));
|
||||||
|
}
|
||||||
|
// The single output has the same shape as the first input.
optional<vector<Shape>> SendObj::inferShape(const TensorVec &inputs) {
    vector<Shape> outputShapes{inputs[0]->getDims()};
    return outputShapes;
}
|
||||||
|
// The single output has the same data type as the first input.
vector<DataType> SendObj::inferDataType(const TensorVec &inputs) const {
    vector<DataType> outputTypes{inputs[0]->getDType()};
    return outputTypes;
}
|
||||||
|
|
||||||
|
// Renders "Send[<guid>](<input dims>,input=<input guid>)".
std::string SendObj::toString() const {
    const auto &src = inputs[0];
    std::ostringstream buf;
    buf << "Send"
        << "[" << getGuid() << "]"
        << "(" << vecToString(src->getDims()) << ","
        << "input=" << src->getGuid() << ")";
    return buf.str();
}
|
||||||
|
|
||||||
|
// Layout: [op type, input dims..., source, destination].
vector<int> SendObj::getWorkloadVector() const {
    const auto inputDims = inputs[0]->getDims();

    vector<int> workload;
    workload.reserve(inputDims.size() + 3);
    workload.push_back(type.underlying());
    workload.insert(workload.end(), inputDims.begin(), inputDims.end());
    workload.push_back(source);
    workload.push_back(destination);

    return workload;
}
|
||||||
|
|
||||||
|
// Layout: [op type, input dims..., source, destination].
vector<int> SendObj::getOpAttrVector() const {
    const auto inputDims = inputs[0]->getDims();

    vector<int> attrs;
    attrs.reserve(inputDims.size() + 3);
    attrs.push_back(type.underlying());
    attrs.insert(attrs.end(), inputDims.begin(), inputDims.end());
    attrs.push_back(source);
    attrs.push_back(destination);

    return attrs;
}
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,90 @@
|
||||||
|
#ifdef INFINI_USE_NCCL
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/cuda_utility.h"
|
||||||
|
#include "operators/recv.h"
|
||||||
|
#include "operators/send.h"
|
||||||
|
#include "test.h"
|
||||||
|
#include <nccl.h>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
void sendrecv(const string taskName, int deviceID, vector<float> data,
|
||||||
|
const Shape &dataShape, int WORLD_SIZE, int source,
|
||||||
|
int destination) {
|
||||||
|
// Create Runtimes and initiate communication
|
||||||
|
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
Runtime cudaRuntime = make_ref<CudaRuntimeObj>(deviceID);
|
||||||
|
cudaRuntime->initComm(taskName, WORLD_SIZE, deviceID);
|
||||||
|
|
||||||
|
if (deviceID == source) {
|
||||||
|
Graph gSend = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
auto input = gSend->addTensor(Shape{static_cast<int>(data.size())},
|
||||||
|
DataType::Float32);
|
||||||
|
auto opSend =
|
||||||
|
gSend->addOp<SendObj>(input, source, destination, nullptr);
|
||||||
|
|
||||||
|
// Copy data from CPU to GPU
|
||||||
|
gSend->dataMalloc();
|
||||||
|
input->copyin(data);
|
||||||
|
cudaRuntime->run(gSend);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------
|
||||||
|
|
||||||
|
if (deviceID == destination) {
|
||||||
|
Graph gRecv = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
int outputType = 1;
|
||||||
|
// auto input =
|
||||||
|
// gRecv->addTensor(Shape{static_cast<int>(data.size())},DataType::Float32);
|
||||||
|
auto opRecv = gRecv->addOp<RecvObj>(nullptr, source, destination,
|
||||||
|
dataShape, outputType, nullptr);
|
||||||
|
gRecv->dataMalloc();
|
||||||
|
cudaRuntime->run(gRecv);
|
||||||
|
|
||||||
|
auto result = opRecv->getOutput()->clone(cpuRuntime);
|
||||||
|
EXPECT_TRUE(result->equalData(data));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CUDA_SendRecv1, run) {
    // Four ranks, one thread per GPU. Rank 0 sends the payload and rank 2
    // receives it; the comparison happens inside sendrecv() on the receiver.
    vector<float> payload = {2., 3., 5., 6.};

    int worldSize = 4;
    int senderRank = 0;
    int receiverRank = 2;

    std::vector<std::thread> workers;
    workers.reserve(worldSize);
    for (int rank = 0; rank < worldSize; ++rank) {
        workers.emplace_back(sendrecv, "test_sendrecv", rank, payload,
                             Shape{2, 2}, worldSize, senderRank, receiverRank);
    }

    for (auto &worker : workers) {
        worker.join();
    }
}
|
||||||
|
|
||||||
|
TEST(CUDA_SendRecv2, run) {
    // Three ranks, one thread per GPU. Rank 0 sends the payload and rank 2
    // receives it; the comparison happens inside sendrecv() on the receiver.
    vector<float> payload = {2., 3., 5., 6.};

    int worldSize = 3;
    int senderRank = 0;
    int receiverRank = 2;

    std::vector<std::thread> workers;
    workers.reserve(worldSize);
    for (int rank = 0; rank < worldSize; ++rank) {
        workers.emplace_back(sendrecv, "test_sendrecv", rank, payload,
                             Shape{2, 2}, worldSize, senderRank, receiverRank);
    }

    for (auto &worker : workers) {
        worker.join();
    }
}
|
||||||
|
} // namespace infini
|
||||||
|
#endif
|
|
@ -0,0 +1,38 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "operators/recv.h"
|
||||||
|
#include "operators/send.h"
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
TEST(Send, ShapeTypeInfer) {
    // A Send operator built on the CPU runtime keeps its op type and exposes
    // the input tensor with its original shape and data type.
    auto runtime = NativeCpuRuntimeObj::getInstance();
    const int senderRank = 0;
    const int receiverRank = 1;
    Shape dims = {1, 3, 2, 4};

    Graph graph = make_ref<GraphObj>(runtime);
    Tensor input = graph->addTensor(dims, DataType::Float32);
    auto sendOp =
        graph->addOp<SendObj>(input, senderRank, receiverRank, nullptr);

    EXPECT_EQ(sendOp->getOpType(), OpType::Send);
    EXPECT_EQ(sendOp->getInputs(0)->getDims(), (dims));
    EXPECT_EQ(sendOp->getInputs(0)->getDType(), DataType::Float32);
}
|
||||||
|
TEST(Recv, ShapeTypeInfer) {
    // A Recv operator built with shape `dims` and type code 1 must produce an
    // output tensor of that shape with data type Float32.
    auto runtime = NativeCpuRuntimeObj::getInstance();
    const int senderRank = 0;
    const int receiverRank = 1;
    Shape dims = {1, 3, 2, 4};
    const int outputType = 1;

    Graph graph = make_ref<GraphObj>(runtime);
    // The optional input of Recv is supplied here.
    Tensor input = graph->addTensor(dims, DataType::Float32);
    auto recvOp = graph->addOp<RecvObj>(nullptr, senderRank, receiverRank,
                                        dims, outputType, input);

    EXPECT_EQ(recvOp->getOpType(), OpType::Recv);
    EXPECT_EQ(recvOp->getOutput()->getDims(), (dims));
    EXPECT_EQ(recvOp->getOutput()->getDType(), DataType::Float32);
}
|
||||||
|
} // namespace infini
|
Loading…
Reference in New Issue