diff --git a/fastboot/socket.cpp b/fastboot/socket.cpp index d49f47ff2..14ecd937a 100644 --- a/fastboot/socket.cpp +++ b/fastboot/socket.cpp @@ -48,18 +48,6 @@ int Socket::Close() { return ret; } -bool Socket::SetReceiveTimeout(int timeout_ms) { - if (timeout_ms != receive_timeout_ms_) { - if (socket_set_receive_timeout(sock_, timeout_ms) == 0) { - receive_timeout_ms_ = timeout_ms; - return true; - } - return false; - } - - return true; -} - ssize_t Socket::ReceiveAll(void* data, size_t length, int timeout_ms) { size_t total = 0; @@ -82,6 +70,40 @@ int Socket::GetLocalPort() { return socket_get_local_port(sock_); } +// According to Windows setsockopt() documentation, if a Windows socket times out during send() or +// recv() the state is indeterminate and should not be used. Our UDP protocol relies on being able +// to re-send after a timeout, so we must use select() rather than SO_RCVTIMEO. +// See https://msdn.microsoft.com/en-us/library/windows/desktop/ms740476(v=vs.85).aspx. +bool Socket::WaitForRecv(int timeout_ms) { + receive_timed_out_ = false; + + // In our usage |timeout_ms| <= 0 means block forever, so just return true immediately and let + // the subsequent recv() do the blocking. + if (timeout_ms <= 0) { + return true; + } + + // select() doesn't always check this case and will block for |timeout_ms| if we let it. + if (sock_ == INVALID_SOCKET) { + return false; + } + + fd_set read_set; + FD_ZERO(&read_set); + FD_SET(sock_, &read_set); + + timeval timeout; + timeout.tv_sec = timeout_ms / 1000; + timeout.tv_usec = (timeout_ms % 1000) * 1000; + + int result = TEMP_FAILURE_RETRY(select(sock_ + 1, &read_set, nullptr, nullptr, &timeout)); + + if (result == 0) { + receive_timed_out_ = true; + } + return result == 1; +} + // Implements the Socket interface for UDP. class UdpSocket : public Socket { public: @@ -127,7 +149,7 @@ bool UdpSocket::Send(std::vector buffers) { } ssize_t UdpSocket::Receive(void* data, size_t length, int timeout_ms) { - if (!SetReceiveTimeout(timeout_ms)) { + if (!WaitForRecv(timeout_ms)) { return -1; } @@ -206,7 +228,7 @@ bool TcpSocket::Send(std::vector buffers) { } ssize_t TcpSocket::Receive(void* data, size_t length, int timeout_ms) { - if (!SetReceiveTimeout(timeout_ms)) { + if (!WaitForRecv(timeout_ms)) { return -1; } diff --git a/fastboot/socket.h b/fastboot/socket.h index c0bd7c96c..de543dbab 100644 --- a/fastboot/socket.h +++ b/fastboot/socket.h @@ -81,13 +81,17 @@ class Socket { virtual bool Send(std::vector buffers) = 0; // Waits up to |timeout_ms| to receive up to |length| bytes of data. |timout_ms| of 0 will - // block forever. Returns the number of bytes received or -1 on error/timeout. On timeout - // errno will be set to EAGAIN or EWOULDBLOCK. + // block forever. Returns the number of bytes received or -1 on error/timeout; see + // ReceiveTimedOut() to distinguish between the two. virtual ssize_t Receive(void* data, size_t length, int timeout_ms) = 0; // Calls Receive() until exactly |length| bytes have been received or an error occurs. virtual ssize_t ReceiveAll(void* data, size_t length, int timeout_ms); + // Returns true if the last Receive() call timed out normally and can be retried; fatal errors + // or successful reads will return false. + bool ReceiveTimedOut() { return receive_timed_out_; } + // Closes the socket. Returns 0 on success, -1 on error. virtual int Close(); @@ -102,10 +106,13 @@ class Socket { // Protected constructor to force factory function use. Socket(cutils_socket_t sock); - // Update the socket receive timeout if necessary. - bool SetReceiveTimeout(int timeout_ms); + // Blocks up to |timeout_ms| until a read is possible on |sock_|, and sets |receive_timed_out_| + // as appropriate to help distinguish between normal timeouts and fatal errors. Returns true if + // a subsequent recv() on |sock_| will complete without blocking or if |timeout_ms| <= 0. + bool WaitForRecv(int timeout_ms); cutils_socket_t sock_ = INVALID_SOCKET; + bool receive_timed_out_ = false; // Non-class functions we want to override during tests to verify functionality. Implementation // should call this rather than using socket_send_buffers() directly. @@ -113,8 +120,6 @@ class Socket { socket_send_buffers_function_ = &socket_send_buffers; private: - int receive_timeout_ms_ = 0; - FRIEND_TEST(SocketTest, TestTcpSendBuffers); FRIEND_TEST(SocketTest, TestUdpSendBuffers); diff --git a/fastboot/socket_mock.cpp b/fastboot/socket_mock.cpp index c962f303d..2531b53ad 100644 --- a/fastboot/socket_mock.cpp +++ b/fastboot/socket_mock.cpp @@ -55,7 +55,7 @@ bool SocketMock::Send(const void* data, size_t length) { return false; } - bool return_value = events_.front().return_value; + bool return_value = events_.front().status; events_.pop(); return return_value; } @@ -76,21 +76,28 @@ ssize_t SocketMock::Receive(void* data, size_t length, int /*timeout_ms*/) { return -1; } - if (events_.front().type != EventType::kReceive) { + const Event& event = events_.front(); + if (event.type != EventType::kReceive) { ADD_FAILURE() << "Receive() was called out-of-order"; return -1; } - if (events_.front().return_value > static_cast(length)) { - ADD_FAILURE() << "Receive(): not enough bytes (" << length << ") for " - << events_.front().message; + const std::string& message = event.message; + if (message.length() > length) { + ADD_FAILURE() << "Receive(): not enough bytes (" << length << ") for " << message; return -1; } - ssize_t return_value = events_.front().return_value; - if (return_value > 0) { - memcpy(data, events_.front().message.data(), return_value); + receive_timed_out_ = event.status; + ssize_t return_value = message.length(); + + // Empty message indicates failure. + if (message.empty()) { + return_value = -1; + } else { + memcpy(data, message.data(), message.length()); } + events_.pop(); return return_value; } @@ -124,18 +131,21 @@ void SocketMock::ExpectSendFailure(std::string message) { } void SocketMock::AddReceive(std::string message) { - ssize_t return_value = message.length(); - events_.push(Event(EventType::kReceive, std::move(message), return_value, nullptr)); + events_.push(Event(EventType::kReceive, std::move(message), false, nullptr)); +} + +void SocketMock::AddReceiveTimeout() { + events_.push(Event(EventType::kReceive, "", true, nullptr)); } void SocketMock::AddReceiveFailure() { - events_.push(Event(EventType::kReceive, "", -1, nullptr)); + events_.push(Event(EventType::kReceive, "", false, nullptr)); } void SocketMock::AddAccept(std::unique_ptr sock) { - events_.push(Event(EventType::kAccept, "", 0, std::move(sock))); + events_.push(Event(EventType::kAccept, "", false, std::move(sock))); } -SocketMock::Event::Event(EventType _type, std::string _message, ssize_t _return_value, +SocketMock::Event::Event(EventType _type, std::string _message, ssize_t _status, std::unique_ptr _sock) - : type(_type), message(_message), return_value(_return_value), sock(std::move(_sock)) {} + : type(_type), message(_message), status(_status), sock(std::move(_sock)) {} diff --git a/fastboot/socket_mock.h b/fastboot/socket_mock.h index 41fe06db0..eacd6bb6a 100644 --- a/fastboot/socket_mock.h +++ b/fastboot/socket_mock.h @@ -71,7 +71,10 @@ class SocketMock : public Socket { // Adds data to provide for Receive(). void AddReceive(std::string message); - // Adds a Receive() failure. + // Adds a Receive() timeout after which ReceiveTimedOut() will return true. + void AddReceiveTimeout(); + + // Adds a Receive() failure after which ReceiveTimedOut() will return false. void AddReceiveFailure(); // Adds a Socket to return from Accept(). @@ -81,12 +84,12 @@ class SocketMock : public Socket { enum class EventType { kSend, kReceive, kAccept }; struct Event { - Event(EventType _type, std::string _message, ssize_t _return_value, + Event(EventType _type, std::string _message, ssize_t _status, std::unique_ptr _sock); EventType type; std::string message; - ssize_t return_value; + bool status; // Return value for Send() or timeout status for Receive(). std::unique_ptr sock; }; diff --git a/fastboot/socket_test.cpp b/fastboot/socket_test.cpp index cc7107529..affbdfd88 100644 --- a/fastboot/socket_test.cpp +++ b/fastboot/socket_test.cpp @@ -28,7 +28,8 @@ #include #include -enum { kTestTimeoutMs = 3000 }; +static constexpr int kShortTimeoutMs = 10; +static constexpr int kTestTimeoutMs = 3000; // Creates connected sockets |server| and |client|. Returns true on success. bool MakeConnectedSockets(Socket::Protocol protocol, std::unique_ptr* server, @@ -87,6 +88,50 @@ TEST(SocketTest, TestSendAndReceive) { } } +TEST(SocketTest, TestReceiveTimeout) { + std::unique_ptr server, client; + char buffer[16]; + + for (Socket::Protocol protocol : {Socket::Protocol::kUdp, Socket::Protocol::kTcp}) { + ASSERT_TRUE(MakeConnectedSockets(protocol, &server, &client)); + + EXPECT_EQ(-1, server->Receive(buffer, sizeof(buffer), kShortTimeoutMs)); + EXPECT_TRUE(server->ReceiveTimedOut()); + + EXPECT_EQ(-1, client->Receive(buffer, sizeof(buffer), kShortTimeoutMs)); + EXPECT_TRUE(client->ReceiveTimedOut()); + } + + // UDP will wait for timeout if the other side closes. + ASSERT_TRUE(MakeConnectedSockets(Socket::Protocol::kUdp, &server, &client)); + EXPECT_EQ(0, server->Close()); + EXPECT_EQ(-1, client->Receive(buffer, sizeof(buffer), kShortTimeoutMs)); + EXPECT_TRUE(client->ReceiveTimedOut()); +} + +TEST(SocketTest, TestReceiveFailure) { + std::unique_ptr server, client; + char buffer[16]; + + for (Socket::Protocol protocol : {Socket::Protocol::kUdp, Socket::Protocol::kTcp}) { + ASSERT_TRUE(MakeConnectedSockets(protocol, &server, &client)); + + EXPECT_EQ(0, server->Close()); + EXPECT_EQ(-1, server->Receive(buffer, sizeof(buffer), kTestTimeoutMs)); + EXPECT_FALSE(server->ReceiveTimedOut()); + + EXPECT_EQ(0, client->Close()); + EXPECT_EQ(-1, client->Receive(buffer, sizeof(buffer), kTestTimeoutMs)); + EXPECT_FALSE(client->ReceiveTimedOut()); + } + + // TCP knows right away when the other side closes and returns 0 to indicate EOF. + ASSERT_TRUE(MakeConnectedSockets(Socket::Protocol::kTcp, &server, &client)); + EXPECT_EQ(0, server->Close()); + EXPECT_EQ(0, client->Receive(buffer, sizeof(buffer), kTestTimeoutMs)); + EXPECT_FALSE(client->ReceiveTimedOut()); +} + // Tests sending and receiving large packets. TEST(SocketTest, TestLargePackets) { std::string message(1024, '\0'); @@ -290,6 +335,11 @@ TEST(SocketMockTest, TestReceiveFailure) { mock->AddReceiveFailure(); EXPECT_FALSE(ReceiveString(mock, "foo")); + EXPECT_FALSE(mock->ReceiveTimedOut()); + + mock->AddReceiveTimeout(); + EXPECT_FALSE(ReceiveString(mock, "foo")); + EXPECT_TRUE(mock->ReceiveTimedOut()); mock->AddReceive("foo"); mock->AddReceiveFailure();