diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index 76f6e0c2..f095db81 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -94,6 +94,9 @@ class GraphHandlerObj {
     Tensor allReduceAvg(Tensor input, Tensor output);
     TensorVec allGather(Tensor input, std::optional<TensorVec> outputs,
                         int n);
     Tensor broadcast(Tensor input, Tensor output, int root);
+    Tensor send(Tensor input, int source, int destination, Tensor output);
+    Tensor recv(Tensor output, int source, int destination, Shape dims,
+                int outputType, Tensor input);
     Tensor depthToSpace(Tensor input, Tensor output, int blocksize,
                         std::string mode);
diff --git a/include/core/op_type.h b/include/core/op_type.h
index 91a0b99a..1652a677 100644
--- a/include/core/op_type.h
+++ b/include/core/op_type.h
@@ -232,6 +232,8 @@ struct OpType {
         AllReduceAvg,
         AllGather,
         Broadcast,
+        Send,
+        Recv,
     } type;
 
     constexpr OpType(decltype(type) t) : type(t) {}
diff --git a/include/operators/recv.h b/include/operators/recv.h
new file mode 100644
index 00000000..faed3407
--- /dev/null
+++ b/include/operators/recv.h
@@ -0,0 +1,46 @@
+#pragma once
+#include "core/operator.h"
+
+namespace infini {
+/**
+ *
+ * https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2193/user-guide/docs/index.html
+ */
+class RecvObj : public OperatorObj {
+
+  public:
+    /**
+     * @brief Construct a new SendRecv object
+     *
+     * @param graph The computation graph that this operator belongs to.
+     * @param input default nullptr, because recv does not have input.
+     * @param output recv output
+     * @param source the send rank
+     * @param destination the recv rank
+     * @param dims The shape of the output tensor.
+     */
+    RecvObj(GraphObj *graph, Tensor output, int source, int destination,
+            Shape dims, int outputType, Tensor input = nullptr);
+    OP_CLONE(RecvObj);
+
+    int numInputs() const override { return inputs.size(); }
+    int numOutputs() const override { return 1; }
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
+    std::string toString() const override;
+    DataType getDType() const;
+    int getSourceRank() const { return source; }
+    int getDestinationRank() const { return destination; }
+    inline Shape getShape() const { return dims; }
+
+  private:
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+    vector<DataType> inferDataType(const TensorVec &inputs) const override;
+
+  protected:
+    int source;
+    int destination;
+    Shape dims;
+    int outputType;
+};
+} // namespace infini
diff --git a/include/operators/send.h b/include/operators/send.h
new file mode 100644
index 00000000..07f5e78b
--- /dev/null
+++ b/include/operators/send.h
@@ -0,0 +1,42 @@
+#pragma once
+#include "core/operator.h"
+
+namespace infini {
+/**
+ *
+ * https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2193/user-guide/docs/index.html
+ */
+class SendObj : public OperatorObj {
+
+  public:
+    /**
+     * @brief Construct a new SendRecv object
+     *
+     * @param graph The computation graph that this operator belongs to.
+     * @param input send input
+     * @param output recv output
+     * @param source the send rank
+     * @param destination the recv rank
+     */
+    SendObj(GraphObj *graph, Tensor input, int source, int destination,
+            Tensor output = nullptr);
+    OP_CLONE(SendObj);
+
+    int numInputs() const override { return 1; }
+    int numOutputs() const override { return outputs.size(); }
+    std::string toString() const override;
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
+
+    int getSourceRank() const { return source; }
+    int getDestinationRank() const { return destination; }
+
+  private:
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+    vector<DataType> inferDataType(const TensorVec &inputs) const override;
+
+  protected:
+    int source;
+    int destination;
+};
+} // namespace infini
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index f0326d88..90a3d3ab 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -703,12 +703,12 @@ class OnnxStub:
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                 )
-            else:
+            else:  # NOTE: `axes` is an attribute until opset version 13.
                 if len(node.input) > 1:
                     axis = _parse_data(data[node.input[1]])
                 else:
-                    axis = next(
+                    axis = next(
                         (
                             attr.ints
                             for attr in node.attribute
@@ -716,14 +716,17 @@
                         ),
                         None,
                     )
-                keepdims = next(
-                    (
-                        attr.i
-                        for attr in node.attribute
-                        if attr.name == "keepdims"
-                    ),
-                    1,
-                ) != 0
+                keepdims = (
+                    next(
+                        (
+                            attr.i
+                            for attr in node.attribute
+                            if attr.name == "keepdims"
+                        ),
+                        1,
+                    )
+                    != 0
+                )
 
                 tensors[node.output[0]] = self.handler.reduceSum(
                     tensors[node.input[0]],
@@ -775,6 +778,58 @@
                         0,
                     ),
                 )
+            elif node.op_type == "Send":
+                source = next(
+                    (attr.i for attr in node.attribute if attr.name == "source"),
+                    0,
+                )
+                destination = next(
+                    (
+                        attr.i
+                        for attr in node.attribute
+                        if attr.name == "destination"
+                    ),
+                    0,
+                )
+
+                self.handler.send(
+                    tensors[node.input[0]],
+                    source,
+                    destination,
+                    None,
+                )
+            elif node.op_type == "Recv":
+                source = next(
+                    (attr.i for attr in node.attribute if attr.name == "source"),
+                    0,
+                )
+                destination = next(
+                    (
+                        attr.i
+                        for attr in node.attribute
+                        if attr.name == "destination"
+                    ),
+                    0,
+                )
+
+                for attr in node.attribute:
+                    if attr.name == "shape":
+                        shapeBasic = attr.ints
+                        shape = []
+                        for item in shapeBasic:
+                            shape.append(item)
+
+                for attr in node.attribute:
+                    if attr.name == "dataType":
+                        outputType = attr.i
+                tensors[node.output[0]] = self.handler.recv(
+                    tensors.get(node.output[0]),
+                    source,
+                    destination,
+                    shape,
+                    outputType,
+                    None,
+                )
             elif node.op_type == "Expand":
                 shape = _parse_data(data[node.input[1]])
                 tensors[node.output[0]] = self.handler.expand(
@@ -1091,10 +1146,7 @@
         elif ty == backend.OpTypeId.Gather:
             axis = backend.gather_axis_of(op)
             ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
-        elif ty in [
-            backend.OpTypeId.ReduceMean,
-            backend.OpTypeId.ReduceSum
-        ]:
+        elif ty in [backend.OpTypeId.ReduceMean, backend.OpTypeId.ReduceSum]:
             axes, keepdims = backend.reduce_attrs_of(op)
             inputs.append(
                 ctx.push_data_input(
diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py
index 8e1587b9..ca290d76 100644
--- a/pyinfinitensor/tests/test_onnx.py
+++ b/pyinfinitensor/tests/test_onnx.py
@@ -337,7 +337,7 @@ class TestStringMethods(unittest.TestCase):
             "ReduceMean", ["data"], ["reduced"], keepdims=1, name="reduceMean"
         )
         make_and_import_model(make_graph([reduceMean], "reduceMean", [data], [reduced]))
-
+
     def test_reduce_sum(self):
         data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 3, 4])
         reduced = make_tensor_value_info("reduced", TensorProto.FLOAT, [1, 1, 1, 1])
@@ -508,6 +508,29 @@ class TestStringMethods(unittest.TestCase):
         where = make_node("Where", ["x", "y", "con"], ["output"], name="where")
         make_and_import_model(make_graph([where], "where", [x, y, con], [output]))
 
+    def test_send(self):
+        sendInput = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 5, 7])
+        send = make_node("Send", ["input"], [], name="send", source=0, destination=1)
+        graph = make_graph([send], "send", [sendInput], [])
+        model = make_model(graph)
+        from_onnx(model, backend.cpu_runtime())
+
+    def test_recv(self):
+        recvOutput = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 5, 7])
+        recv = make_node(
+            "Recv",
+            [],
+            ["output"],
+            name="recv",
+            source=0,
+            destination=1,
+            shape=[1, 3, 5, 7],
+            dataType=1,
+        )
+        graph = make_graph([recv], "recv", [], [recvOutput])
+        model = make_model(graph)
+        from_onnx(model, backend.cpu_runtime())
+
 
 class TestDynamicTensor(unittest.TestCase):
     def test_dynamic_tensor(self):
@@ -517,6 +540,7 @@ class TestDynamicTensor(unittest.TestCase):
         for root, dirs, files in os.walk(current_path):
             if filename in files:
                 model_file = os.path.join(root, filename)
+
         model = OnnxStub(onnx.load(model_file), backend.cpu_runtime())
         output_key = list(model.outputs.keys())[0]
         old_output_shape = model.getShape(output_key)
diff --git a/src/core/graph.cc b/src/core/graph.cc
index dd474d11..5eb67402 100644
--- a/src/core/graph.cc
+++ b/src/core/graph.cc
@@ -11,20 +11,33 @@ GraphObj::GraphObj(Runtime runtime, OpVec ops_in)
     map<UidBaseType, Tensor> tensorPool;
     // Clone tensors
     for (const auto &op : ops_in) {
-        for (const auto &t : op->getInputs())
-            if (tensorPool.find(t->getFuid()) == tensorPool.end())
-                tensorPool[t->getFuid()] = cloneTensor(t);
-        for (const auto &t : op->getOutputs())
-            if (tensorPool.find(t->getFuid()) == tensorPool.end())
-                tensorPool[t->getFuid()] = cloneTensor(t);
+        for (const auto &t : op->getInputs()) {
+            if (t) {
+                if (tensorPool.find(t->getFuid()) == tensorPool.end())
+                    tensorPool[t->getFuid()] = cloneTensor(t);
+            }
+        }
+        for (const auto &t : op->getOutputs()) {
+            if (t) {
+                if (tensorPool.find(t->getFuid()) == tensorPool.end())
+                    tensorPool[t->getFuid()] = cloneTensor(t);
+            }
+        }
     }
     // Clone operators and add connections
     for (const auto &op : ops_in) {
         TensorVec inputs, outputs;
-        for (const auto &t : op->getInputs())
-            inputs.emplace_back(tensorPool.at(t->getFuid()));
-        for (const auto &t : op->getOutputs())
-            outputs.emplace_back(tensorPool.at(t->getFuid()));
+        for (const auto &t : op->getInputs()) {
+            if (t) {
+                inputs.emplace_back(tensorPool.at(t->getFuid()));
+            }
+        }
+
+        for (const auto &t : op->getOutputs()) {
+            if (t) {
+                outputs.emplace_back(tensorPool.at(t->getFuid()));
+            }
+        }
         addOperatorAndConnect(op->clone(inputs, outputs));
     }
 }
@@ -33,17 +46,21 @@ void GraphObj::addOperatorAndConnect(const Operator &op) {
     sorted = false;
    ops.push_back(op);
     for (auto &input : op->getInputs()) {
-        input->addTarget(op);
-        if (auto pred = input->getSource()) {
-            pred->addSuccessors(op);
-            op->addPredecessors(pred);
+        if (input) {
+            input->addTarget(op);
+            if (auto pred = input->getSource()) {
+                pred->addSuccessors(op);
+                op->addPredecessors(pred);
+            }
         }
     }
     for (auto &output : op->getOutputs()) {
-        output->setSource(op);
-        for (auto &succ : output->getTargets()) {
-            succ->addPredecessors(op);
-            op->addSuccessors(succ);
+        if (output) {
+            output->setSource(op);
+            for (auto &succ : output->getTargets()) {
+                succ->addPredecessors(op);
+                op->addSuccessors(succ);
+            }
         }
     }
 }
@@ -88,8 +105,9 @@ bool GraphObj::topo_sort() {
         const auto is_head = std::all_of(
             this_inputs.begin(), this_inputs.end(), [&](const auto &input) {
                 auto src = input->getSource();
-                return src // If the source node is in the waiting list,
-                           // means that this node is not the head node.
+                return src // If the source node is in the waiting
+                           // list, means that this node is not the
+                           // head node.
                            ? waiting.find(src) == waiting.end()
                            // This tensor has no source node,
                            // it must be a input tensor.
@@ -110,7 +128,6 @@
             return false;
         }
     }
-
     // Done.
     this->ops = std::move(sorted);
     return this->sorted = true;
@@ -155,6 +172,7 @@ void GraphObj::shape_infer() {
 
 void GraphObj::dataMalloc(bool useNaiveAllocator, size_t memPoolSize) {
     // topological sorting first
+
     IT_ASSERT(topo_sort() == true);
     if (useNaiveAllocator) {
         // can not set memory pool when use naive allocator
@@ -222,24 +240,28 @@
             // memory should be allocated for the op's output first
             auto outputs = op->getOutputs();
             for (auto &tensor : outputs) {
-                if (tensor->isOthers()) {
-                    tensorToOffset[tensor.get()] =
-                        allocator.alloc(tensor->getBytes());
+                if (tensor) {
+                    if (tensor->isOthers()) {
+                        tensorToOffset[tensor.get()] =
+                            allocator.alloc(tensor->getBytes());
+                    }
                 }
             }
             auto inputs = op->getInputs();
             for (auto &tensor : inputs) {
-                if (tensor->isOthers()) {
-                    auto tensorIter = tensorToRefCount.find(tensor.get());
-                    IT_ASSERT(tensorIter != tensorToRefCount.end());
-                    IT_ASSERT(tensorToRefCount[tensor.get()] > 0);
-                    tensorToRefCount[tensor.get()] -= 1;
-                    if (tensorToRefCount[tensor.get()] == 0) {
-                        // indicate that this tensor will no longer be used and
-                        // perform memory free
-                        tensorToRefCount.erase(tensor.get());
-                        allocator.free(tensorToOffset[tensor.get()],
-                                       tensor->getBytes());
+                if (tensor) {
+                    if (tensor->isOthers()) {
+                        auto tensorIter = tensorToRefCount.find(tensor.get());
+                        IT_ASSERT(tensorIter != tensorToRefCount.end());
+                        IT_ASSERT(tensorToRefCount[tensor.get()] > 0);
+                        tensorToRefCount[tensor.get()] -= 1;
+                        if (tensorToRefCount[tensor.get()] == 0) {
+                            // indicate that this tensor will no longer be used and
+                            // perform memory free
+                            tensorToRefCount.erase(tensor.get());
+                            allocator.free(tensorToOffset[tensor.get()],
+                                           tensor->getBytes());
+                        }
                     }
                 }
             }
diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc
index de156c43..1eb73499 100644
--- a/src/core/graph_handler.cc
+++ b/src/core/graph_handler.cc
@@ -13,8 +13,10 @@
 #include "operators/matmul.h"
 #include "operators/pad.h"
 #include "operators/pooling.h"
+#include "operators/recv.h"
 #include "operators/reduce.h"
 #include "operators/reshape.h"
+#include "operators/send.h"
 #include "operators/slice.h"
 #include "operators/softmax.h"
 #include "operators/split.h"
@@ -434,6 +436,39 @@ Tensor GraphHandlerObj::broadcast(Tensor input, Tensor output, int root) {
     }
 }
 
+Tensor GraphHandlerObj::send(Tensor input, int source, int destination,
+                             Tensor output) {
+    if (output) {
+
+        g->addOpWithOutputs<SendObj>(std::move(input), source, destination,
+                                     output);
+
+        return output;
+    } else {
+        return g->addOp<SendObj>(std::move(input), source, destination, output)
+            ->getOutput();
+    }
+}
+
+Tensor GraphHandlerObj::recv(Tensor output, int source, int destination,
+                             Shape dims, int outputType, Tensor input) {
+
+    if (output) {
+
+        g->addOpWithOutputs<RecvObj>(output, source, destination,
+                                     std::move(dims), outputType,
+                                     std::move(input));
+
+        return output;
+    } else {
+
+        return g
+            ->addOp<RecvObj>(output, source, destination, std::move(dims),
+                             outputType, std::move(input))
+            ->getOutput();
+    }
+}
+
 Tensor GraphHandlerObj::cast(Tensor input, Tensor output, int to) {
     if (output) {
         g->addOpWithOutputs<CastObj>(std::move(input), output,
diff --git a/src/core/operator.cc b/src/core/operator.cc
index 6a9ea1b8..4fd4e6de 100644
--- a/src/core/operator.cc
+++ b/src/core/operator.cc
@@ -6,8 +6,10 @@ namespace infini {
 
 OperatorObj::OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs)
     : type(opType), inputs(inputs), outputs(outputs) {
-    for (const auto &t : inputs)
-        IT_ASSERT(t);
+    if (opType != OpType::Recv) {
+        for (const auto &t : inputs)
+            IT_ASSERT(t);
+    }
 }
 
 void OperatorObj::removePredecessors(const Operator &op) {
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index 408d3514..ca99a4c3 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -511,6 +511,8 @@ void init_graph_builder(py::module &m) {
         .def("allReduceAvg", &Handler::allReduceAvg, policy::move)
         .def("allGather", &Handler::allGather, policy::move)
         .def("broadcast", &Handler::broadcast, policy::move)
+        .def("send", &Handler::send, policy::move)
+        .def("recv", &Handler::recv, policy::move)
         .def("cast", &Handler::cast, policy::move)
         .def("expand", &Handler::expand, policy::move)
         .def("erf", &Handler::erf, policy::move)
diff --git a/src/kernels/cuda/recv.cc b/src/kernels/cuda/recv.cc
new file mode 100644
index 00000000..7fd7ee49
--- /dev/null
+++ b/src/kernels/cuda/recv.cc
@@ -0,0 +1,47 @@
+#ifdef INFINI_USE_NCCL
+#include "operators/recv.h"
+#include "cuda/cuda_kernel_wihtout_config.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/nccl_communicator.h"
+
+namespace infini {
+class RecvNCCL : public CudaKernelWithoutConfig {
+  public:
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<RecvObj>(_op);
+        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
+
+        void *output = op->getOutput(0)->getRawDataPtr<void *>();
+        IT_ASSERT(op->getDType() == DataType::Float32);
+        const auto shape = op->getShape();
+        int nDims = shape.size();
+        int outputCount = 1;
+        for (int i = 0; i < nDims; i++) {
+            outputCount *= shape[i];
+        }
+
+        ncclComm_t comm =
+            dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator())
+                .getNcclComm();
+        // TODO: Using default stream 0 for now.
+        int rank;
+
+        checkNcclError(ncclCommUserRank(comm, &rank));
+
+        int source = op->getSourceRank();
+        int destination = op->getDestinationRank();
+
+        if (rank == destination) {
+
+            checkNcclError(
+                ncclRecv(output, outputCount, ncclFloat, source, comm, 0));
+        }
+    }
+};
+
+REGISTER_KERNEL(Device::CUDA, OpType::Recv, DataType::Float32, RecvNCCL,
+                "Recv_NCCL_CUDA_Float32");
+} // namespace infini
+
+#endif
diff --git a/src/kernels/cuda/send.cc b/src/kernels/cuda/send.cc
new file mode 100644
index 00000000..38684062
--- /dev/null
+++ b/src/kernels/cuda/send.cc
@@ -0,0 +1,43 @@
+#ifdef INFINI_USE_NCCL
+#include "operators/send.h"
+#include "cuda/cuda_kernel_wihtout_config.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/nccl_communicator.h"
+
+namespace infini {
+class SendNCCL : public CudaKernelWithoutConfig {
+  public:
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<SendObj>(_op);
+        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
+        void *input = op->getInputs(0)->getRawDataPtr<void *>();
+
+        IT_ASSERT(op->getDType() == DataType::Float32);
+        size_t inputCount =
+            op->getInputs(0)->getBytes() / op->getDType().getSize();
+
+        ncclComm_t comm =
+            dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator())
+                .getNcclComm();
+        // TODO: Using default stream 0 for now.
+        int rank;
+
+        checkNcclError(ncclCommUserRank(comm, &rank));
+
+        int source = op->getSourceRank();
+        int destination = op->getDestinationRank();
+
+        if (rank == source) {
+
+            checkNcclError(
+                ncclSend(input, inputCount, ncclFloat, destination, comm, 0));
+        }
+    }
+};
+
+REGISTER_KERNEL(Device::CUDA, OpType::Send, DataType::Float32, SendNCCL,
+                "Send_NCCL_CUDA_Float32");
+} // namespace infini
+
+#endif
diff --git a/src/operators/recv.cc b/src/operators/recv.cc
new file mode 100644
index 00000000..6883b636
--- /dev/null
+++ b/src/operators/recv.cc
@@ -0,0 +1,49 @@
+#include "operators/recv.h"
+
+namespace infini {
+RecvObj::RecvObj(GraphObj *graph, Tensor output, int source, int destination,
+                 Shape dims, int outputType, [[maybe_unused]] Tensor input)
+    : OperatorObj(OpType::Recv, input ? TensorVec{input} : TensorVec{},
+                  TensorVec{output}),
+      source(source), destination(destination), dims(std::move(dims)),
+      outputType(outputType) {
+
+    IT_ASSERT(checkValid(graph));
+}
+optional<vector<Shape>> RecvObj::inferShape(const TensorVec &inputs) {
+    return {{dims}};
+}
+vector<DataType> RecvObj::inferDataType(const TensorVec &inputs) const {
+    return {{DataType(outputType)}};
+}
+DataType RecvObj::getDType() const { return getOutput(0)->getDType(); }
+std::string RecvObj::toString() const {
+    std::ostringstream os;
+    os << "Recv"
+       << "[" << getGuid() << "]";
+    os << "(";
+    os << vecToString(dims) << ",";
+    os << "output=" << outputs[0]->getGuid() << ",";
+    os << "dims=" << vecToString(dims) << ")";
+    return os.str();
+}
+
+vector<int> RecvObj::getWorkloadVector() const {
+    vector<int> ret = dims;
+    ret.insert(ret.end(), dims.begin(), dims.end());
+    ret.emplace(ret.begin(), type.underlying());
+
+    ret.emplace_back(source);
+    ret.emplace_back(destination);
+
+    return ret;
+}
+
+vector<int> RecvObj::getOpAttrVector() const {
+    vector<int> ret = dims;
+    ret.emplace(ret.begin(), type.underlying());
+    ret.emplace_back(source);
+    ret.emplace_back(destination);
+    return ret;
+}
+} // namespace infini
diff --git a/src/operators/send.cc b/src/operators/send.cc
new file mode 100644
index 00000000..bc349ceb
--- /dev/null
+++ b/src/operators/send.cc
@@ -0,0 +1,46 @@
+#include "operators/send.h"
+
+namespace infini {
+SendObj::SendObj(GraphObj *graph, Tensor input, int source, int destination,
+                 [[maybe_unused]] Tensor output)
+    : OperatorObj(OpType::Send, TensorVec{input},
+                  TensorVec{output ? output : nullptr}),
+      source(source), destination(destination) {
+
+    IT_ASSERT(checkValid(graph));
+}
+optional<vector<Shape>> SendObj::inferShape(const TensorVec &inputs) {
+    return {{inputs[0]->getDims()}};
+}
+vector<DataType> SendObj::inferDataType(const TensorVec &inputs) const {
+    return {{inputs[0]->getDType()}};
+}
+
+std::string SendObj::toString() const {
+    std::ostringstream os;
+    os << "Send"
+       << "[" << getGuid() << "]";
+    os << "(";
+    os << vecToString(inputs[0]->getDims()) << ",";
+    os << "input=" << inputs[0]->getGuid() << ")";
+    return os.str();
+}
+
+vector<int> SendObj::getWorkloadVector() const {
+    vector<int> ret = inputs[0]->getDims();
+
+    ret.emplace(ret.begin(), type.underlying());
+    ret.emplace_back(source);
+    ret.emplace_back(destination);
+
+    return ret;
+}
+
+vector<int> SendObj::getOpAttrVector() const {
+    vector<int> ret = inputs[0]->getDims();
+    ret.emplace(ret.begin(), type.underlying());
+    ret.emplace_back(source);
+    ret.emplace_back(destination);
+    return ret;
+}
+} // namespace infini
diff --git a/test/kernels/cuda/test_cuda_sendrecv.cc b/test/kernels/cuda/test_cuda_sendrecv.cc
new file mode 100644
index 00000000..4be24b52
--- /dev/null
+++ b/test/kernels/cuda/test_cuda_sendrecv.cc
@@ -0,0 +1,90 @@
+#ifdef INFINI_USE_NCCL
+#include "core/graph.h"
+#include "core/runtime.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/cuda_utility.h"
+#include "operators/recv.h"
+#include "operators/send.h"
+#include "test.h"
+#include <nccl.h>
+#include <thread>
+
+namespace infini {
+
+void sendrecv(const string taskName, int deviceID, vector<float> data,
+              const Shape &dataShape, int WORLD_SIZE, int source,
+              int destination) {
+    // Create Runtimes and initiate communication
+    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
+    Runtime cudaRuntime = make_ref<CudaRuntimeObj>(deviceID);
+    cudaRuntime->initComm(taskName, WORLD_SIZE, deviceID);
+
+    if (deviceID == source) {
+        Graph gSend = make_ref<GraphObj>(cudaRuntime);
+        auto input = gSend->addTensor(Shape{static_cast<int>(data.size())},
+                                      DataType::Float32);
+        auto opSend =
+            gSend->addOp<SendObj>(input, source, destination, nullptr);
+
+        // Copy data from CPU to GPU
+        gSend->dataMalloc();
+        input->copyin(data);
+        cudaRuntime->run(gSend);
+    }
+
+    // ----------------
+
+    if (deviceID == destination) {
+        Graph gRecv = make_ref<GraphObj>(cudaRuntime);
+        int outputType = 1;
+        // auto input =
+        // gRecv->addTensor(Shape{static_cast<int>(data.size())},DataType::Float32);
+        auto opRecv = gRecv->addOp<RecvObj>(nullptr, source, destination,
+                                            dataShape, outputType, nullptr);
+        gRecv->dataMalloc();
+        cudaRuntime->run(gRecv);
+
+        auto result = opRecv->getOutput()->clone(cpuRuntime);
+        EXPECT_TRUE(result->equalData(data));
+    }
+}
+
+TEST(CUDA_SendRecv1, run) {
+    // Only 1 device gets data. Every rank should have the same data after
+    // sendrecv.
+    vector<float> data = {2., 3., 5., 6.};
+
+    int WORLD_SIZE = 4;
+    int source = 0;
+    int destination = 2;
+    std::vector<std::thread> threads;
+    for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) {
+        threads.emplace_back(sendrecv, "test_sendrecv", gpu, data, Shape{2, 2},
+                             WORLD_SIZE, source, destination);
+    }
+
+    for (auto &thread : threads) {
+        thread.join();
+    }
+}
+
+TEST(CUDA_SendRecv2, run) {
+    // Only 1 device gets data. Every rank should have the same data after
+    // sendrecv.
+    vector<float> data = {2., 3., 5., 6.};
+
+    int WORLD_SIZE = 3;
+    int source = 0;
+    int destination = 2;
+    std::vector<std::thread> threads;
+    for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) {
+        threads.emplace_back(sendrecv, "test_sendrecv", gpu, data, Shape{2, 2},
+                             WORLD_SIZE, source, destination);
+    }
+
+    for (auto &thread : threads) {
+        thread.join();
+    }
+}
+} // namespace infini
+#endif
diff --git a/test/operators/test_sendrecv.cc b/test/operators/test_sendrecv.cc
new file mode 100644
index 00000000..44cbb141
--- /dev/null
+++ b/test/operators/test_sendrecv.cc
@@ -0,0 +1,38 @@
+#include "core/graph.h"
+#include "core/runtime.h"
+#include "operators/recv.h"
+#include "operators/send.h"
+#include "test.h"
+
+namespace infini {
+TEST(Send, ShapeTypeInfer) {
+    auto runtime = NativeCpuRuntimeObj::getInstance();
+    int source = 0;
+    int destination = 1;
+    Shape dims = {1, 3, 2, 4};
+    {
+        Graph g = make_ref<GraphObj>(runtime);
+        Tensor input = g->addTensor(dims, DataType::Float32);
+        auto op = g->addOp<SendObj>(input, source, destination, nullptr);
+        EXPECT_EQ(op->getOpType(), OpType::Send);
+        EXPECT_EQ(op->getInputs(0)->getDims(), (dims));
+        EXPECT_EQ(op->getInputs(0)->getDType(), DataType::Float32);
+    }
+}
+TEST(Recv, ShapeTypeInfer) {
+    auto runtime = NativeCpuRuntimeObj::getInstance();
+    int source = 0;
+    int destination = 1;
+    Shape dims = {1, 3, 2, 4};
+    int outputType = 1;
+    {
+        Graph g = make_ref<GraphObj>(runtime);
+        Tensor input = g->addTensor(dims, DataType::Float32);
+        auto op = g->addOp<RecvObj>(nullptr, source, destination, dims,
+                                    outputType, input);
+        EXPECT_EQ(op->getOpType(), OpType::Recv);
+        EXPECT_EQ(op->getOutput()->getDims(), (dims));
+        EXPECT_EQ(op->getOutput()->getDType(), DataType::Float32);
+    }
+}
+} // namespace infini