Add send and recv operators based on NCCL (#182)

* baseline sendrecv, bug

* success sendrecv

* get rank from comm

* set output shape

* successful:set output shape equal to input shape

* shape as attribute

* success:shape as attribute

* success send recv, output 0

* add onnx test

* split send and recv

* success split send and recv

* test-onnx bug

* success test-onnx

* modified onnx.py

* solve review
This commit is contained in:
xgqdut2016 2023-12-14 16:38:03 +08:00 committed by GitHub
parent c143eebdf7
commit a3929c25f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 595 additions and 52 deletions

View File

@ -94,6 +94,9 @@ class GraphHandlerObj {
Tensor allReduceAvg(Tensor input, Tensor output);
TensorVec allGather(Tensor input, std::optional<TensorVec> outputs, int n);
Tensor broadcast(Tensor input, Tensor output, int root);
Tensor send(Tensor input, int source, int destination, Tensor output);
Tensor recv(Tensor output, int source, int destination, Shape dims,
int outputType, Tensor input);
Tensor depthToSpace(Tensor input, Tensor output, int blocksize,
std::string mode);

View File

@ -232,6 +232,8 @@ struct OpType {
AllReduceAvg,
AllGather,
Broadcast,
Send,
Recv,
} type;
constexpr OpType(decltype(type) t) : type(t) {}

46
include/operators/recv.h Normal file
View File

@ -0,0 +1,46 @@
#pragma once
#include "core/operator.h"
namespace infini {
/**
*
* https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2193/user-guide/docs/index.html
*/
class RecvObj : public OperatorObj {
public:
/**
* @brief Construct a new SendRecv object
*
* @param graph The computation graph that this operator belongs to.
* @param input default nullptr, because recv does not have input.
* @param output recv output
* @param source the send rank
* @param destination the recv rank
* @param dims The shape of the output tensor.
*/
RecvObj(GraphObj *graph, Tensor output, int source, int destination,
Shape dims, int outputType, Tensor input = nullptr);
OP_CLONE(RecvObj);
int numInputs() const override { return inputs.size(); }
int numOutputs() const override { return 1; }
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override;
DataType getDType() const;
int getSourceRank() const { return source; }
int getDestinationRank() const { return destination; }
inline Shape getShape() const { return dims; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
vector<DataType> inferDataType(const TensorVec &inputs) const override;
protected:
int source;
int destination;
Shape dims;
int outputType;
};
} // namespace infini

42
include/operators/send.h Normal file
View File

@ -0,0 +1,42 @@
#pragma once
#include "core/operator.h"
namespace infini {
/**
*
* https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2193/user-guide/docs/index.html
*/
class SendObj : public OperatorObj {
public:
/**
* @brief Construct a new SendRecv object
*
* @param graph The computation graph that this operator belongs to.
* @param input send input
* @param output recv output
* @param source the send rank
* @param destination the recv rank
*/
SendObj(GraphObj *graph, Tensor input, int source, int destination,
Tensor output = nullptr);
OP_CLONE(SendObj);
int numInputs() const override { return 1; }
int numOutputs() const override { return outputs.size(); }
std::string toString() const override;
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
int getSourceRank() const { return source; }
int getDestinationRank() const { return destination; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
vector<DataType> inferDataType(const TensorVec &inputs) const override;
protected:
int source;
int destination;
};
} // namespace infini

View File

@ -703,12 +703,12 @@ class OnnxStub:
tensors[node.input[0]],
tensors.get(node.output[0]),
)
else:
else:
# NOTE: `axes` is an attribute until opset version 13.
if len(node.input) > 1:
axis = _parse_data(data[node.input[1]])
else:
axis = next(
axis = next(
(
attr.ints
for attr in node.attribute
@ -716,14 +716,17 @@ class OnnxStub:
),
None,
)
keepdims = next(
(
attr.i
for attr in node.attribute
if attr.name == "keepdims"
),
1,
) != 0
keepdims = (
next(
(
attr.i
for attr in node.attribute
if attr.name == "keepdims"
),
1,
)
!= 0
)
tensors[node.output[0]] = self.handler.reduceSum(
tensors[node.input[0]],
@ -775,6 +778,58 @@ class OnnxStub:
0,
),
)
elif node.op_type == "Send":
source = next(
(attr.i for attr in node.attribute if attr.name == "source"),
0,
)
destination = next(
(
attr.i
for attr in node.attribute
if attr.name == "destination"
),
0,
)
self.handler.send(
tensors[node.input[0]],
source,
destination,
None,
)
elif node.op_type == "Recv":
source = next(
(attr.i for attr in node.attribute if attr.name == "source"),
0,
)
destination = next(
(
attr.i
for attr in node.attribute
if attr.name == "destination"
),
0,
)
for attr in node.attribute:
if attr.name == "shape":
shapeBasic = attr.ints
shape = []
for item in shapeBasic:
shape.append(item)
for attr in node.attribute:
if attr.name == "dataType":
outputType = attr.i
tensors[node.output[0]] = self.handler.recv(
tensors.get(node.output[0]),
source,
destination,
shape,
outputType,
None,
)
elif node.op_type == "Expand":
shape = _parse_data(data[node.input[1]])
tensors[node.output[0]] = self.handler.expand(
@ -1091,10 +1146,7 @@ class OnnxStub:
elif ty == backend.OpTypeId.Gather:
axis = backend.gather_axis_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
elif ty in [
backend.OpTypeId.ReduceMean,
backend.OpTypeId.ReduceSum
]:
elif ty in [backend.OpTypeId.ReduceMean, backend.OpTypeId.ReduceSum]:
axes, keepdims = backend.reduce_attrs_of(op)
inputs.append(
ctx.push_data_input(

View File

@ -337,7 +337,7 @@ class TestStringMethods(unittest.TestCase):
"ReduceMean", ["data"], ["reduced"], keepdims=1, name="reduceMean"
)
make_and_import_model(make_graph([reduceMean], "reduceMean", [data], [reduced]))
def test_reduce_sum(self):
data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 3, 4])
reduced = make_tensor_value_info("reduced", TensorProto.FLOAT, [1, 1, 1, 1])
@ -508,6 +508,29 @@ class TestStringMethods(unittest.TestCase):
where = make_node("Where", ["x", "y", "con"], ["output"], name="where")
make_and_import_model(make_graph([where], "where", [x, y, con], [output]))
def test_send(self):
sendInput = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 5, 7])
send = make_node("Send", ["input"], [], name="send", source=0, destination=1)
graph = make_graph([send], "send", [sendInput], [])
model = make_model(graph)
from_onnx(model, backend.cpu_runtime())
def test_recv(self):
recvOutput = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 5, 7])
recv = make_node(
"Recv",
[],
["output"],
name="recv",
source=0,
destination=1,
shape=[1, 3, 5, 7],
dataType=1,
)
graph = make_graph([recv], "recv", [], [recvOutput])
model = make_model(graph)
from_onnx(model, backend.cpu_runtime())
class TestDynamicTensor(unittest.TestCase):
def test_dynamic_tensor(self):
@ -517,6 +540,7 @@ class TestDynamicTensor(unittest.TestCase):
for root, dirs, files in os.walk(current_path):
if filename in files:
model_file = os.path.join(root, filename)
model = OnnxStub(onnx.load(model_file), backend.cpu_runtime())
output_key = list(model.outputs.keys())[0]
old_output_shape = model.getShape(output_key)

View File

@ -11,20 +11,33 @@ GraphObj::GraphObj(Runtime runtime, OpVec ops_in)
map<UidBaseType, Tensor> tensorPool;
// Clone tensors
for (const auto &op : ops_in) {
for (const auto &t : op->getInputs())
if (tensorPool.find(t->getFuid()) == tensorPool.end())
tensorPool[t->getFuid()] = cloneTensor(t);
for (const auto &t : op->getOutputs())
if (tensorPool.find(t->getFuid()) == tensorPool.end())
tensorPool[t->getFuid()] = cloneTensor(t);
for (const auto &t : op->getInputs()) {
if (t) {
if (tensorPool.find(t->getFuid()) == tensorPool.end())
tensorPool[t->getFuid()] = cloneTensor(t);
}
}
for (const auto &t : op->getOutputs()) {
if (t) {
if (tensorPool.find(t->getFuid()) == tensorPool.end())
tensorPool[t->getFuid()] = cloneTensor(t);
}
}
}
// Clone operators and add connections
for (const auto &op : ops_in) {
TensorVec inputs, outputs;
for (const auto &t : op->getInputs())
inputs.emplace_back(tensorPool.at(t->getFuid()));
for (const auto &t : op->getOutputs())
outputs.emplace_back(tensorPool.at(t->getFuid()));
for (const auto &t : op->getInputs()) {
if (t) {
inputs.emplace_back(tensorPool.at(t->getFuid()));
}
}
for (const auto &t : op->getOutputs()) {
if (t) {
outputs.emplace_back(tensorPool.at(t->getFuid()));
}
}
addOperatorAndConnect(op->clone(inputs, outputs));
}
}
@ -33,17 +46,21 @@ void GraphObj::addOperatorAndConnect(const Operator &op) {
sorted = false;
ops.push_back(op);
for (auto &input : op->getInputs()) {
input->addTarget(op);
if (auto pred = input->getSource()) {
pred->addSuccessors(op);
op->addPredecessors(pred);
if (input) {
input->addTarget(op);
if (auto pred = input->getSource()) {
pred->addSuccessors(op);
op->addPredecessors(pred);
}
}
}
for (auto &output : op->getOutputs()) {
output->setSource(op);
for (auto &succ : output->getTargets()) {
succ->addPredecessors(op);
op->addSuccessors(succ);
if (output) {
output->setSource(op);
for (auto &succ : output->getTargets()) {
succ->addPredecessors(op);
op->addSuccessors(succ);
}
}
}
}
@ -88,8 +105,9 @@ bool GraphObj::topo_sort() {
const auto is_head = std::all_of(
this_inputs.begin(), this_inputs.end(), [&](const auto &input) {
auto src = input->getSource();
return src // If the source node is in the waiting list,
// means that this node is not the head node.
return src // If the source node is in the waiting
// list, means that this node is not the
// head node.
? waiting.find(src) == waiting.end()
// This tensor has no source node,
// it must be a input tensor.
@ -110,7 +128,6 @@ bool GraphObj::topo_sort() {
return false;
}
}
// Done.
this->ops = std::move(sorted);
return this->sorted = true;
@ -155,6 +172,7 @@ void GraphObj::shape_infer() {
void GraphObj::dataMalloc(bool useNaiveAllocator, size_t memPoolSize) {
// topological sorting first
IT_ASSERT(topo_sort() == true);
if (useNaiveAllocator) {
// can not set memory pool when use naive allocator
@ -222,24 +240,28 @@ void GraphObj::dataMalloc(bool useNaiveAllocator, size_t memPoolSize) {
// memory should be allocated for the op's output first
auto outputs = op->getOutputs();
for (auto &tensor : outputs) {
if (tensor->isOthers()) {
tensorToOffset[tensor.get()] =
allocator.alloc(tensor->getBytes());
if (tensor) {
if (tensor->isOthers()) {
tensorToOffset[tensor.get()] =
allocator.alloc(tensor->getBytes());
}
}
}
auto inputs = op->getInputs();
for (auto &tensor : inputs) {
if (tensor->isOthers()) {
auto tensorIter = tensorToRefCount.find(tensor.get());
IT_ASSERT(tensorIter != tensorToRefCount.end());
IT_ASSERT(tensorToRefCount[tensor.get()] > 0);
tensorToRefCount[tensor.get()] -= 1;
if (tensorToRefCount[tensor.get()] == 0) {
// indicate that this tensor will no longer be used and
// perform memory free
tensorToRefCount.erase(tensor.get());
allocator.free(tensorToOffset[tensor.get()],
tensor->getBytes());
if (tensor) {
if (tensor->isOthers()) {
auto tensorIter = tensorToRefCount.find(tensor.get());
IT_ASSERT(tensorIter != tensorToRefCount.end());
IT_ASSERT(tensorToRefCount[tensor.get()] > 0);
tensorToRefCount[tensor.get()] -= 1;
if (tensorToRefCount[tensor.get()] == 0) {
// indicate that this tensor will no longer be used and
// perform memory free
tensorToRefCount.erase(tensor.get());
allocator.free(tensorToOffset[tensor.get()],
tensor->getBytes());
}
}
}
}

View File

@ -13,8 +13,10 @@
#include "operators/matmul.h"
#include "operators/pad.h"
#include "operators/pooling.h"
#include "operators/recv.h"
#include "operators/reduce.h"
#include "operators/reshape.h"
#include "operators/send.h"
#include "operators/slice.h"
#include "operators/softmax.h"
#include "operators/split.h"
@ -434,6 +436,39 @@ Tensor GraphHandlerObj::broadcast(Tensor input, Tensor output, int root) {
}
}
Tensor GraphHandlerObj::send(Tensor input, int source, int destination,
Tensor output) {
if (output) {
g->addOpWithOutputs<SendObj>(std::move(input), source, destination,
output);
return output;
} else {
return g->addOp<SendObj>(std::move(input), source, destination, output)
->getOutput();
}
}
Tensor GraphHandlerObj::recv(Tensor output, int source, int destination,
Shape dims, int outputType, Tensor input) {
if (output) {
g->addOpWithOutputs<RecvObj>(output, source, destination,
std::move(dims), outputType,
std::move(input));
return output;
} else {
return g
->addOp<RecvObj>(output, source, destination, std::move(dims),
outputType, std::move(input))
->getOutput();
}
}
Tensor GraphHandlerObj::cast(Tensor input, Tensor output, int to) {
if (output) {
g->addOpWithOutputs<CastObj>(std::move(input), output,

View File

@ -6,8 +6,10 @@ namespace infini {
OperatorObj::OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs)
: type(opType), inputs(inputs), outputs(outputs) {
for (const auto &t : inputs)
IT_ASSERT(t);
if (opType != OpType::Recv) {
for (const auto &t : inputs)
IT_ASSERT(t);
}
}
void OperatorObj::removePredecessors(const Operator &op) {

View File

@ -511,6 +511,8 @@ void init_graph_builder(py::module &m) {
.def("allReduceAvg", &Handler::allReduceAvg, policy::move)
.def("allGather", &Handler::allGather, policy::move)
.def("broadcast", &Handler::broadcast, policy::move)
.def("send", &Handler::send, policy::move)
.def("recv", &Handler::recv, policy::move)
.def("cast", &Handler::cast, policy::move)
.def("expand", &Handler::expand, policy::move)
.def("erf", &Handler::erf, policy::move)

47
src/kernels/cuda/recv.cc Normal file
View File

@ -0,0 +1,47 @@
#ifdef INFINI_USE_NCCL
#include "operators/recv.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
#include "cuda/nccl_communicator.h"
namespace infini {
class RecvNCCL : public CudaKernelWithoutConfig {
public:
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<RecvObj>(_op);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
void *output = op->getOutput(0)->getRawDataPtr<void *>();
IT_ASSERT(op->getDType() == DataType::Float32);
const auto shape = op->getShape();
int nDims = shape.size();
int outputCount = 1;
for (int i = 0; i < nDims; i++) {
outputCount *= shape[i];
}
ncclComm_t comm =
dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator())
.getNcclComm();
// TODO: Using default stream 0 for now.
int rank;
checkNcclError(ncclCommUserRank(comm, &rank));
int source = op->getSourceRank();
int destination = op->getDestinationRank();
if (rank == destination) {
checkNcclError(
ncclRecv(output, outputCount, ncclFloat, source, comm, 0));
}
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Recv, DataType::Float32, RecvNCCL,
"Recv_NCCL_CUDA_Float32");
} // namespace infini
#endif

43
src/kernels/cuda/send.cc Normal file
View File

@ -0,0 +1,43 @@
#ifdef INFINI_USE_NCCL
#include "operators/send.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
#include "cuda/nccl_communicator.h"
namespace infini {
class SendNCCL : public CudaKernelWithoutConfig {
public:
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<SendObj>(_op);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
void *input = op->getInputs(0)->getRawDataPtr<void *>();
IT_ASSERT(op->getDType() == DataType::Float32);
size_t inputCount =
op->getInputs(0)->getBytes() / op->getDType().getSize();
ncclComm_t comm =
dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator())
.getNcclComm();
// TODO: Using default stream 0 for now.
int rank;
checkNcclError(ncclCommUserRank(comm, &rank));
int source = op->getSourceRank();
int destination = op->getDestinationRank();
if (rank == source) {
checkNcclError(
ncclSend(input, inputCount, ncclFloat, destination, comm, 0));
}
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Send, DataType::Float32, SendNCCL,
"Send_NCCL_CUDA_Float32");
} // namespace infini
#endif

49
src/operators/recv.cc Normal file
View File

@ -0,0 +1,49 @@
#include "operators/recv.h"
namespace infini {
RecvObj::RecvObj(GraphObj *graph, Tensor output, int source, int destination,
Shape dims, int outputType, [[maybe_unused]] Tensor input)
: OperatorObj(OpType::Recv, input ? TensorVec{input} : TensorVec{},
TensorVec{output}),
source(source), destination(destination), dims(std::move(dims)),
outputType(outputType) {
IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>> RecvObj::inferShape(const TensorVec &inputs) {
return {{dims}};
}
vector<DataType> RecvObj::inferDataType(const TensorVec &inputs) const {
return {{DataType(outputType)}};
}
DataType RecvObj::getDType() const { return getOutput(0)->getDType(); }
std::string RecvObj::toString() const {
std::ostringstream os;
os << "Recv"
<< "[" << getGuid() << "]";
os << "(";
os << vecToString(dims) << ",";
os << "output=" << outputs[0]->getGuid() << ",";
os << "dims=" << vecToString(dims) << ")";
return os.str();
}
vector<int> RecvObj::getWorkloadVector() const {
vector<int> ret = dims;
ret.insert(ret.end(), dims.begin(), dims.end());
ret.emplace(ret.begin(), type.underlying());
ret.emplace_back(source);
ret.emplace_back(destination);
return ret;
}
vector<int> RecvObj::getOpAttrVector() const {
vector<int> ret = dims;
ret.emplace(ret.begin(), type.underlying());
ret.emplace_back(source);
ret.emplace_back(destination);
return ret;
}
} // namespace infini

46
src/operators/send.cc Normal file
View File

@ -0,0 +1,46 @@
#include "operators/send.h"
namespace infini {
SendObj::SendObj(GraphObj *graph, Tensor input, int source, int destination,
[[maybe_unused]] Tensor output)
: OperatorObj(OpType::Send, TensorVec{input},
TensorVec{output ? output : nullptr}),
source(source), destination(destination) {
IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>> SendObj::inferShape(const TensorVec &inputs) {
return {{inputs[0]->getDims()}};
}
vector<DataType> SendObj::inferDataType(const TensorVec &inputs) const {
return {{inputs[0]->getDType()}};
}
std::string SendObj::toString() const {
std::ostringstream os;
os << "Send"
<< "[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "input=" << inputs[0]->getGuid() << ")";
return os.str();
}
vector<int> SendObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims();
ret.emplace(ret.begin(), type.underlying());
ret.emplace_back(source);
ret.emplace_back(destination);
return ret;
}
vector<int> SendObj::getOpAttrVector() const {
vector<int> ret = inputs[0]->getDims();
ret.emplace(ret.begin(), type.underlying());
ret.emplace_back(source);
ret.emplace_back(destination);
return ret;
}
} // namespace infini

View File

@ -0,0 +1,90 @@
#ifdef INFINI_USE_NCCL
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/recv.h"
#include "operators/send.h"
#include "test.h"
#include <nccl.h>
#include <thread>
namespace infini {
void sendrecv(const string taskName, int deviceID, vector<float> data,
const Shape &dataShape, int WORLD_SIZE, int source,
int destination) {
// Create Runtimes and initiate communication
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
Runtime cudaRuntime = make_ref<CudaRuntimeObj>(deviceID);
cudaRuntime->initComm(taskName, WORLD_SIZE, deviceID);
if (deviceID == source) {
Graph gSend = make_ref<GraphObj>(cudaRuntime);
auto input = gSend->addTensor(Shape{static_cast<int>(data.size())},
DataType::Float32);
auto opSend =
gSend->addOp<SendObj>(input, source, destination, nullptr);
// Copy data from CPU to GPU
gSend->dataMalloc();
input->copyin(data);
cudaRuntime->run(gSend);
}
// ----------------
if (deviceID == destination) {
Graph gRecv = make_ref<GraphObj>(cudaRuntime);
int outputType = 1;
// auto input =
// gRecv->addTensor(Shape{static_cast<int>(data.size())},DataType::Float32);
auto opRecv = gRecv->addOp<RecvObj>(nullptr, source, destination,
dataShape, outputType, nullptr);
gRecv->dataMalloc();
cudaRuntime->run(gRecv);
auto result = opRecv->getOutput()->clone(cpuRuntime);
EXPECT_TRUE(result->equalData(data));
}
}
TEST(CUDA_SendRecv1, run) {
// Only 1 device gets data. Every rank should have the same data after
// sendrecv.
vector<float> data = {2., 3., 5., 6.};
int WORLD_SIZE = 4;
int source = 0;
int destination = 2;
std::vector<std::thread> threads;
for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) {
threads.emplace_back(sendrecv, "test_sendrecv", gpu, data, Shape{2, 2},
WORLD_SIZE, source, destination);
}
for (auto &thread : threads) {
thread.join();
}
}
TEST(CUDA_SendRecv2, run) {
// Only 1 device gets data. Every rank should have the same data after
// sendrecv.
vector<float> data = {2., 3., 5., 6.};
int WORLD_SIZE = 3;
int source = 0;
int destination = 2;
std::vector<std::thread> threads;
for (int gpu = 0; gpu < WORLD_SIZE; ++gpu) {
threads.emplace_back(sendrecv, "test_sendrecv", gpu, data, Shape{2, 2},
WORLD_SIZE, source, destination);
}
for (auto &thread : threads) {
thread.join();
}
}
} // namespace infini
#endif

View File

@ -0,0 +1,38 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "operators/recv.h"
#include "operators/send.h"
#include "test.h"
namespace infini {
TEST(Send, ShapeTypeInfer) {
auto runtime = NativeCpuRuntimeObj::getInstance();
int source = 0;
int destination = 1;
Shape dims = {1, 3, 2, 4};
{
Graph g = make_ref<GraphObj>(runtime);
Tensor input = g->addTensor(dims, DataType::Float32);
auto op = g->addOp<SendObj>(input, source, destination, nullptr);
EXPECT_EQ(op->getOpType(), OpType::Send);
EXPECT_EQ(op->getInputs(0)->getDims(), (dims));
EXPECT_EQ(op->getInputs(0)->getDType(), DataType::Float32);
}
}
TEST(Recv, ShapeTypeInfer) {
auto runtime = NativeCpuRuntimeObj::getInstance();
int source = 0;
int destination = 1;
Shape dims = {1, 3, 2, 4};
int outputType = 1;
{
Graph g = make_ref<GraphObj>(runtime);
Tensor input = g->addTensor(dims, DataType::Float32);
auto op = g->addOp<RecvObj>(nullptr, source, destination, dims,
outputType, input);
EXPECT_EQ(op->getOpType(), OpType::Recv);
EXPECT_EQ(op->getOutput()->getDims(), (dims));
EXPECT_EQ(op->getOutput()->getDType(), DataType::Float32);
}
}
} // namespace infini