From 572bce29088521caf7f90c9fa66a8237a7674435 Mon Sep 17 00:00:00 2001 From: David Pursell Date: Fri, 15 Jan 2016 14:19:56 -0800 Subject: [PATCH] fastboot: use cutils socket functions. Now that cutils has cross-platform socket functionality, we can restructure fastboot to remove platform-dependent networking code. This CL adds socket_set_receive_timeout() to libcutils and combines the fastboot socket code into a single implementation. It also adds TCP functionality to fastboot sockets, but nothing uses it yet except for the unit tests. A future CL will add the TCP protocol which will use this TCP socket implementation. Bug: http://b/26558551 Change-Id: If613fb348f9332b31fa2c88d67fb1e839923768a --- fastboot/.clang-format | 5 +- fastboot/Android.mk | 22 ++-- fastboot/socket.cpp | 212 +++++++++++++++++++++++++++++++ fastboot/socket.h | 50 +++++--- fastboot/socket_test.cpp | 209 ++++++++++-------------------- fastboot/socket_unix.cpp | 131 ------------------- fastboot/socket_windows.cpp | 134 ------------------- include/cutils/sockets.h | 8 ++ libcutils/sockets_unix.c | 7 + libcutils/sockets_windows.c | 5 + libcutils/tests/sockets_test.cpp | 47 +++++++ 11 files changed, 400 insertions(+), 430 deletions(-) create mode 100644 fastboot/socket.cpp delete mode 100644 fastboot/socket_unix.cpp delete mode 100644 fastboot/socket_windows.cpp diff --git a/fastboot/.clang-format b/fastboot/.clang-format index 673753525..bcb8d8ac0 100644 --- a/fastboot/.clang-format +++ b/fastboot/.clang-format @@ -1,11 +1,14 @@ BasedOnStyle: Google AllowShortBlocksOnASingleLine: false -AllowShortFunctionsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Inline ColumnLimit: 100 CommentPragmas: NOLINT:.* DerivePointerAlignment: false IndentWidth: 4 +ContinuationIndentWidth: 8 +ConstructorInitializerIndentWidth: 8 +AccessModifierOffset: -2 PointerAlignment: Left TabWidth: 4 UseTab: Never diff --git a/fastboot/Android.mk b/fastboot/Android.mk index bb28afab5..fcec5b104 100644 --- a/fastboot/Android.mk +++ b/fastboot/Android.mk @@ -24,7 +24,15 @@ LOCAL_C_INCLUDES := \ $(LOCAL_PATH)/../../extras/ext4_utils \ $(LOCAL_PATH)/../../extras/f2fs_utils \ -LOCAL_SRC_FILES := protocol.cpp engine.cpp bootimg_utils.cpp fastboot.cpp util.cpp fs.cpp +LOCAL_SRC_FILES := \ + bootimg_utils.cpp \ + engine.cpp \ + fastboot.cpp \ + fs.cpp\ + protocol.cpp \ + socket.cpp \ + util.cpp \ + LOCAL_MODULE := fastboot LOCAL_MODULE_TAGS := debug LOCAL_MODULE_HOST_OS := darwin linux windows @@ -33,15 +41,15 @@ LOCAL_CFLAGS += -Wall -Wextra -Werror -Wunreachable-code LOCAL_CFLAGS += -DFASTBOOT_REVISION='"$(fastboot_version)"' -LOCAL_SRC_FILES_linux := socket_unix.cpp usb_linux.cpp util_linux.cpp +LOCAL_SRC_FILES_linux := usb_linux.cpp util_linux.cpp LOCAL_STATIC_LIBRARIES_linux := libselinux -LOCAL_SRC_FILES_darwin := socket_unix.cpp usb_osx.cpp util_osx.cpp +LOCAL_SRC_FILES_darwin := usb_osx.cpp util_osx.cpp LOCAL_STATIC_LIBRARIES_darwin := libselinux LOCAL_LDLIBS_darwin := -lpthread -framework CoreFoundation -framework IOKit -framework Carbon LOCAL_CFLAGS_darwin := -Wno-unused-parameter -LOCAL_SRC_FILES_windows := socket_windows.cpp usb_windows.cpp util_windows.cpp +LOCAL_SRC_FILES_windows := usb_windows.cpp util_windows.cpp LOCAL_STATIC_LIBRARIES_windows := AdbWinApi LOCAL_REQUIRED_MODULES_windows := AdbWinApi LOCAL_LDLIBS_windows := -lws2_32 @@ -98,18 +106,14 @@ include $(CLEAR_VARS) LOCAL_MODULE := fastboot_test LOCAL_MODULE_HOST_OS := darwin linux windows -LOCAL_SRC_FILES := socket_test.cpp +LOCAL_SRC_FILES := socket.cpp socket_test.cpp LOCAL_STATIC_LIBRARIES := libbase libcutils LOCAL_CFLAGS += -Wall -Wextra -Werror -Wunreachable-code -LOCAL_SRC_FILES_linux := socket_unix.cpp - -LOCAL_SRC_FILES_darwin := socket_unix.cpp LOCAL_LDLIBS_darwin := -lpthread -framework CoreFoundation -framework IOKit -framework Carbon LOCAL_CFLAGS_darwin := -Wno-unused-parameter -LOCAL_SRC_FILES_windows := socket_windows.cpp LOCAL_LDLIBS_windows := -lws2_32 include $(BUILD_HOST_NATIVE_TEST) diff --git a/fastboot/socket.cpp b/fastboot/socket.cpp new file mode 100644 index 000000000..d41f1fe6f --- /dev/null +++ b/fastboot/socket.cpp @@ -0,0 +1,212 @@ +/* + * Copyright (C) 2015 The Android Open Source Project + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE + * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, + * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS + * OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED + * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT + * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + */ + +#include "socket.h" + +#include + +Socket::Socket(cutils_socket_t sock) : sock_(sock) {} + +Socket::~Socket() { + Close(); +} + +int Socket::Close() { + int ret = 0; + + if (sock_ != INVALID_SOCKET) { + ret = socket_close(sock_); + sock_ = INVALID_SOCKET; + } + + return ret; +} + +bool Socket::SetReceiveTimeout(int timeout_ms) { + if (timeout_ms != receive_timeout_ms_) { + if (socket_set_receive_timeout(sock_, timeout_ms) == 0) { + receive_timeout_ms_ = timeout_ms; + return true; + } + return false; + } + + return true; +} + +ssize_t Socket::ReceiveAll(void* data, size_t length, int timeout_ms) { + size_t total = 0; + + while (total < length) { + ssize_t bytes = Receive(reinterpret_cast(data) + total, length - total, timeout_ms); + + if (bytes == -1) { + if (total == 0) { + return -1; + } + break; + } + total += bytes; + } + + return total; +} + +// Implements the Socket interface for UDP. +class UdpSocket : public Socket { + public: + enum class Type { kClient, kServer }; + + UdpSocket(Type type, cutils_socket_t sock); + + ssize_t Send(const void* data, size_t length) override; + ssize_t Receive(void* data, size_t length, int timeout_ms) override; + + private: + std::unique_ptr addr_; + socklen_t addr_size_ = 0; + + DISALLOW_COPY_AND_ASSIGN(UdpSocket); +}; + +UdpSocket::UdpSocket(Type type, cutils_socket_t sock) : Socket(sock) { + // Only servers need to remember addresses; clients are connected to a server in NewClient() + // so will send to that server without needing to specify the address again. + if (type == Type::kServer) { + addr_.reset(new sockaddr_storage); + addr_size_ = sizeof(*addr_); + memset(addr_.get(), 0, addr_size_); + } +} + +ssize_t UdpSocket::Send(const void* data, size_t length) { + return TEMP_FAILURE_RETRY(sendto(sock_, reinterpret_cast(data), length, 0, + reinterpret_cast(addr_.get()), addr_size_)); +} + +ssize_t UdpSocket::Receive(void* data, size_t length, int timeout_ms) { + if (!SetReceiveTimeout(timeout_ms)) { + return -1; + } + + socklen_t* addr_size_ptr = nullptr; + if (addr_ != nullptr) { + // Reset addr_size as it may have been modified by previous recvfrom() calls. + addr_size_ = sizeof(*addr_); + addr_size_ptr = &addr_size_; + } + + return TEMP_FAILURE_RETRY(recvfrom(sock_, reinterpret_cast(data), length, 0, + reinterpret_cast(addr_.get()), addr_size_ptr)); +} + +// Implements the Socket interface for TCP. +class TcpSocket : public Socket { + public: + TcpSocket(cutils_socket_t sock) : Socket(sock) {} + + ssize_t Send(const void* data, size_t length) override; + ssize_t Receive(void* data, size_t length, int timeout_ms) override; + + std::unique_ptr Accept() override; + + private: + DISALLOW_COPY_AND_ASSIGN(TcpSocket); +}; + +ssize_t TcpSocket::Send(const void* data, size_t length) { + size_t total = 0; + + while (total < length) { + ssize_t bytes = TEMP_FAILURE_RETRY( + send(sock_, reinterpret_cast(data) + total, length - total, 0)); + + if (bytes == -1) { + if (total == 0) { + return -1; + } + break; + } + total += bytes; + } + + return total; +} + +ssize_t TcpSocket::Receive(void* data, size_t length, int timeout_ms) { + if (!SetReceiveTimeout(timeout_ms)) { + return -1; + } + + return TEMP_FAILURE_RETRY(recv(sock_, reinterpret_cast(data), length, 0)); +} + +std::unique_ptr TcpSocket::Accept() { + cutils_socket_t handler = accept(sock_, nullptr, nullptr); + if (handler == INVALID_SOCKET) { + return nullptr; + } + return std::unique_ptr(new TcpSocket(handler)); +} + +std::unique_ptr Socket::NewClient(Protocol protocol, const std::string& host, int port, + std::string* error) { + if (protocol == Protocol::kUdp) { + cutils_socket_t sock = socket_network_client(host.c_str(), port, SOCK_DGRAM); + if (sock != INVALID_SOCKET) { + return std::unique_ptr(new UdpSocket(UdpSocket::Type::kClient, sock)); + } + } else { + cutils_socket_t sock = socket_network_client(host.c_str(), port, SOCK_STREAM); + if (sock != INVALID_SOCKET) { + return std::unique_ptr(new TcpSocket(sock)); + } + } + + if (error) { + *error = android::base::StringPrintf("Failed to connect to %s:%d", host.c_str(), port); + } + return nullptr; +} + +// This functionality is currently only used by tests so we don't need any error messages. +std::unique_ptr Socket::NewServer(Protocol protocol, int port) { + if (protocol == Protocol::kUdp) { + cutils_socket_t sock = socket_inaddr_any_server(port, SOCK_DGRAM); + if (sock != INVALID_SOCKET) { + return std::unique_ptr(new UdpSocket(UdpSocket::Type::kServer, sock)); + } + } else { + cutils_socket_t sock = socket_inaddr_any_server(port, SOCK_STREAM); + if (sock != INVALID_SOCKET) { + return std::unique_ptr(new TcpSocket(sock)); + } + } + + return nullptr; +} diff --git a/fastboot/socket.h b/fastboot/socket.h index 888b53077..3e66c274f 100644 --- a/fastboot/socket.h +++ b/fastboot/socket.h @@ -26,36 +26,41 @@ * SUCH DAMAGE. */ -// This file provides a class interface for cross-platform UDP functionality. The main fastboot +// This file provides a class interface for cross-platform socket functionality. The main fastboot // engine should not be using this interface directly, but instead should use a higher-level -// interface that enforces the fastboot UDP protocol. +// interface that enforces the fastboot protocol. #ifndef SOCKET_H_ #define SOCKET_H_ -#include "android-base/macros.h" - #include #include -// UdpSocket interface to be implemented for each platform. -class UdpSocket { +#include +#include + +// Socket interface to be implemented for each platform. +class Socket { public: + enum class Protocol { kTcp, kUdp }; + // Creates a new client connection. Clients are connected to a specific hostname/port and can // only send to that destination. // On failure, |error| is filled (if non-null) and nullptr is returned. - static std::unique_ptr NewUdpClient(const std::string& hostname, int port, - std::string* error); + static std::unique_ptr NewClient(Protocol protocol, const std::string& hostname, + int port, std::string* error); // Creates a new server bound to local |port|. This is only meant for testing, during normal // fastboot operation the device acts as the server. - // The server saves sender addresses in Receive(), and uses the most recent address during + // A UDP server saves sender addresses in Receive(), and uses the most recent address during // calls to Send(). - static std::unique_ptr NewUdpServer(int port); + static std::unique_ptr NewServer(Protocol protocol, int port); - virtual ~UdpSocket() = default; + // Destructor closes the socket if it's open. + virtual ~Socket(); - // Sends |length| bytes of |data|. Returns the number of bytes actually sent or -1 on error. + // Sends |length| bytes of |data|. For TCP sockets this will continue trying to send until all + // bytes are transmitted. Returns the number of bytes actually sent or -1 on error. virtual ssize_t Send(const void* data, size_t length) = 0; // Waits up to |timeout_ms| to receive up to |length| bytes of data. |timout_ms| of 0 will @@ -63,14 +68,29 @@ class UdpSocket { // errno will be set to EAGAIN or EWOULDBLOCK. virtual ssize_t Receive(void* data, size_t length, int timeout_ms) = 0; + // Calls Receive() until exactly |length| bytes have been received or an error occurs. + virtual ssize_t ReceiveAll(void* data, size_t length, int timeout_ms); + // Closes the socket. Returns 0 on success, -1 on error. - virtual int Close() = 0; + virtual int Close(); + + // Accepts an incoming TCP connection. No effect for UDP sockets. Returns a new Socket + // connected to the client on success, nullptr on failure. + virtual std::unique_ptr Accept() { return nullptr; } protected: // Protected constructor to force factory function use. - UdpSocket() = default; + Socket(cutils_socket_t sock); - DISALLOW_COPY_AND_ASSIGN(UdpSocket); + // Update the socket receive timeout if necessary. + bool SetReceiveTimeout(int timeout_ms); + + cutils_socket_t sock_ = INVALID_SOCKET; + + private: + int receive_timeout_ms_ = 0; + + DISALLOW_COPY_AND_ASSIGN(Socket); }; #endif // SOCKET_H_ diff --git a/fastboot/socket_test.cpp b/fastboot/socket_test.cpp index 6ada964ad..1fd9d7c22 100644 --- a/fastboot/socket_test.cpp +++ b/fastboot/socket_test.cpp @@ -14,184 +14,113 @@ * limitations under the License. */ -// Tests UDP functionality using loopback connections. Requires that kDefaultPort is available +// Tests UDP functionality using loopback connections. Requires that kTestPort is available // for loopback communication on the host. These tests also assume that no UDP packets are lost, // which should be the case for loopback communication, but is not guaranteed. #include "socket.h" -#include -#include - -#include -#include -#include - #include enum { // This port must be available for loopback communication. - kDefaultPort = 54321, + kTestPort = 54321, // Don't wait forever in a unit test. - kDefaultTimeoutMs = 3000, + kTestTimeoutMs = 3000, }; -static const char kReceiveStringError[] = "Error receiving string"; - -// Test fixture to provide some helper functions. Makes each test a little simpler since we can -// just check a bool for socket creation and don't have to pass hostname or port information. -class SocketTest : public ::testing::Test { - protected: - bool StartServer(int port = kDefaultPort) { - server_ = UdpSocket::NewUdpServer(port); - return server_ != nullptr; +// Creates connected sockets |server| and |client|. Returns true on success. +bool MakeConnectedSockets(Socket::Protocol protocol, std::unique_ptr* server, + std::unique_ptr* client, const std::string hostname = "localhost", + int port = kTestPort) { + *server = Socket::NewServer(protocol, port); + if (*server == nullptr) { + ADD_FAILURE() << "Failed to create server."; + return false; } - bool StartClient(const std::string hostname = "localhost", int port = kDefaultPort) { - client_ = UdpSocket::NewUdpClient(hostname, port, nullptr); - return client_ != nullptr; + *client = Socket::NewClient(protocol, hostname, port, nullptr); + if (*client == nullptr) { + ADD_FAILURE() << "Failed to create client."; + return false; } - bool StartClient2(const std::string hostname = "localhost", int port = kDefaultPort) { - client2_ = UdpSocket::NewUdpClient(hostname, port, nullptr); - return client2_ != nullptr; + // TCP passes the client off to a new socket. + if (protocol == Socket::Protocol::kTcp) { + *server = (*server)->Accept(); + if (*server == nullptr) { + ADD_FAILURE() << "Failed to accept client connection."; + return false; + } } - std::unique_ptr server_, client_, client2_; -}; + return true; +} -// Sends a string over a UdpSocket. Returns true if the full string (without terminating char) +// Sends a string over a Socket. Returns true if the full string (without terminating char) // was sent. -static bool SendString(UdpSocket* udp, const std::string& message) { - return udp->Send(message.c_str(), message.length()) == static_cast(message.length()); +static bool SendString(Socket* sock, const std::string& message) { + return sock->Send(message.c_str(), message.length()) == static_cast(message.length()); } -// Receives a string from a UdpSocket. Returns the string, or kReceiveStringError on failure. -static std::string ReceiveString(UdpSocket* udp, size_t receive_size = 128) { - std::vector buffer(receive_size); - - ssize_t result = udp->Receive(buffer.data(), buffer.size(), kDefaultTimeoutMs); - if (result >= 0) { - return std::string(buffer.data(), result); - } - return kReceiveStringError; -} - -// Calls Receive() on the UdpSocket with the given timeout. Returns true if the call timed out. -static bool ReceiveTimeout(UdpSocket* udp, int timeout_ms) { - char buffer[1]; - - errno = 0; - return udp->Receive(buffer, 1, timeout_ms) == -1 && (errno == EAGAIN || errno == EWOULDBLOCK); +// Receives a string from a Socket. Returns true if the full string (without terminating char) +// was received. +static bool ReceiveString(Socket* sock, const std::string& message) { + std::string received(message.length(), '\0'); + ssize_t bytes = sock->ReceiveAll(&received[0], received.length(), kTestTimeoutMs); + return static_cast(bytes) == received.length() && received == message; } // Tests sending packets client -> server, then server -> client. -TEST_F(SocketTest, SendAndReceive) { - ASSERT_TRUE(StartServer()); - ASSERT_TRUE(StartClient()); +TEST(SocketTest, TestSendAndReceive) { + std::unique_ptr server, client; - EXPECT_TRUE(SendString(client_.get(), "foo")); - EXPECT_EQ("foo", ReceiveString(server_.get())); + for (Socket::Protocol protocol : {Socket::Protocol::kUdp, Socket::Protocol::kTcp}) { + ASSERT_TRUE(MakeConnectedSockets(protocol, &server, &client)); - EXPECT_TRUE(SendString(server_.get(), "bar baz")); - EXPECT_EQ("bar baz", ReceiveString(client_.get())); + EXPECT_TRUE(SendString(client.get(), "foo")); + EXPECT_TRUE(ReceiveString(server.get(), "foo")); + + EXPECT_TRUE(SendString(server.get(), "bar baz")); + EXPECT_TRUE(ReceiveString(client.get(), "bar baz")); + } } // Tests sending and receiving large packets. -TEST_F(SocketTest, LargePackets) { - std::string message(512, '\0'); +TEST(SocketTest, TestLargePackets) { + std::string message(1024, '\0'); + std::unique_ptr server, client; - ASSERT_TRUE(StartServer()); - ASSERT_TRUE(StartClient()); + for (Socket::Protocol protocol : {Socket::Protocol::kUdp, Socket::Protocol::kTcp}) { + ASSERT_TRUE(MakeConnectedSockets(protocol, &server, &client)); - // Run through the test a few times. - for (int i = 0; i < 10; ++i) { - // Use a different message each iteration to prevent false positives. - for (size_t j = 0; j < message.length(); ++j) { - message[j] = static_cast(i + j); + // Run through the test a few times. + for (int i = 0; i < 10; ++i) { + // Use a different message each iteration to prevent false positives. + for (size_t j = 0; j < message.length(); ++j) { + message[j] = static_cast(i + j); + } + + EXPECT_TRUE(SendString(client.get(), message)); + EXPECT_TRUE(ReceiveString(server.get(), message)); } - - EXPECT_TRUE(SendString(client_.get(), message)); - EXPECT_EQ(message, ReceiveString(server_.get(), message.length())); } } -// Tests IPv4 client/server. -TEST_F(SocketTest, IPv4) { - ASSERT_TRUE(StartServer()); - ASSERT_TRUE(StartClient("127.0.0.1")); +// Tests UDP receive overflow when the UDP packet is larger than the receive buffer. +TEST(SocketTest, TestUdpReceiveOverflow) { + std::unique_ptr server, client; + ASSERT_TRUE(MakeConnectedSockets(Socket::Protocol::kUdp, &server, &client)); - EXPECT_TRUE(SendString(client_.get(), "foo")); - EXPECT_EQ("foo", ReceiveString(server_.get())); + EXPECT_TRUE(SendString(client.get(), "1234567890")); - EXPECT_TRUE(SendString(server_.get(), "bar")); - EXPECT_EQ("bar", ReceiveString(client_.get())); -} - -// Tests IPv6 client/server. -TEST_F(SocketTest, IPv6) { - ASSERT_TRUE(StartServer()); - ASSERT_TRUE(StartClient("::1")); - - EXPECT_TRUE(SendString(client_.get(), "foo")); - EXPECT_EQ("foo", ReceiveString(server_.get())); - - EXPECT_TRUE(SendString(server_.get(), "bar")); - EXPECT_EQ("bar", ReceiveString(client_.get())); -} - -// Tests receive timeout. The timing verification logic must be very coarse to make sure different -// systems running different loads can all pass these tests. -TEST_F(SocketTest, ReceiveTimeout) { - time_t start_time; - - ASSERT_TRUE(StartServer()); - - // Make sure a 20ms timeout completes in 1 second or less. - start_time = time(nullptr); - EXPECT_TRUE(ReceiveTimeout(server_.get(), 20)); - EXPECT_LE(difftime(time(nullptr), start_time), 1.0); - - // Make sure a 1250ms timeout takes 1 second or more. - start_time = time(nullptr); - EXPECT_TRUE(ReceiveTimeout(server_.get(), 1250)); - EXPECT_LE(1.0, difftime(time(nullptr), start_time)); -} - -// Tests receive overflow (the UDP packet is larger than the receive buffer). -TEST_F(SocketTest, ReceiveOverflow) { - ASSERT_TRUE(StartServer()); - ASSERT_TRUE(StartClient()); - - EXPECT_TRUE(SendString(client_.get(), "1234567890")); - - // This behaves differently on different systems; some give us a truncated UDP packet, others - // will error out and not return anything at all. - std::string rx_string = ReceiveString(server_.get(), 5); - - // If we didn't get an error then the packet should have been truncated. - if (rx_string != kReceiveStringError) { - EXPECT_EQ("12345", rx_string); + // This behaves differently on different systems, either truncating the packet or returning -1. + char buffer[5]; + ssize_t bytes = server->Receive(buffer, 5, kTestTimeoutMs); + if (bytes == 5) { + EXPECT_EQ(0, memcmp(buffer, "12345", 5)); + } else { + EXPECT_EQ(-1, bytes); } } - -// Tests multiple clients sending to the same server. -TEST_F(SocketTest, MultipleClients) { - ASSERT_TRUE(StartServer()); - ASSERT_TRUE(StartClient()); - ASSERT_TRUE(StartClient2()); - - EXPECT_TRUE(SendString(client_.get(), "client")); - EXPECT_TRUE(SendString(client2_.get(), "client2")); - - // Receive the packets and send a response for each (note that packets may be received - // out-of-order). - for (int i = 0; i < 2; ++i) { - std::string received = ReceiveString(server_.get()); - EXPECT_TRUE(SendString(server_.get(), received + " response")); - } - - EXPECT_EQ("client response", ReceiveString(client_.get())); - EXPECT_EQ("client2 response", ReceiveString(client2_.get())); -} diff --git a/fastboot/socket_unix.cpp b/fastboot/socket_unix.cpp deleted file mode 100644 index 462256a82..000000000 --- a/fastboot/socket_unix.cpp +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Copyright (C) 2015 The Android Open Source Project - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions - * are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in - * the documentation and/or other materials provided with the - * distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS - * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE - * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, - * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS - * OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED - * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT - * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF - * SUCH DAMAGE. - */ - -#include "socket.h" - -#include -#include - -#include -#include - -class UnixUdpSocket : public UdpSocket { - public: - enum class Type { kClient, kServer }; - - UnixUdpSocket(int fd, Type type); - ~UnixUdpSocket() override; - - ssize_t Send(const void* data, size_t length) override; - ssize_t Receive(void* data, size_t length, int timeout_ms) override; - int Close() override; - - private: - int fd_; - int receive_timeout_ms_ = 0; - std::unique_ptr addr_; - socklen_t addr_size_ = 0; - - DISALLOW_COPY_AND_ASSIGN(UnixUdpSocket); -}; - -UnixUdpSocket::UnixUdpSocket(int fd, Type type) : fd_(fd) { - // Only servers need to remember addresses; clients are connected to a server in NewUdpClient() - // so will send to that server without needing to specify the address again. - if (type == Type::kServer) { - addr_.reset(new sockaddr_storage); - addr_size_ = sizeof(*addr_); - memset(addr_.get(), 0, addr_size_); - } -} - -UnixUdpSocket::~UnixUdpSocket() { - Close(); -} - -ssize_t UnixUdpSocket::Send(const void* data, size_t length) { - return TEMP_FAILURE_RETRY( - sendto(fd_, data, length, 0, reinterpret_cast(addr_.get()), addr_size_)); -} - -ssize_t UnixUdpSocket::Receive(void* data, size_t length, int timeout_ms) { - // Only set socket timeout if it's changed. - if (receive_timeout_ms_ != timeout_ms) { - timeval tv; - tv.tv_sec = timeout_ms / 1000; - tv.tv_usec = (timeout_ms % 1000) * 1000; - if (setsockopt(fd_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) < 0) { - return -1; - } - receive_timeout_ms_ = timeout_ms; - } - - socklen_t* addr_size_ptr = nullptr; - if (addr_ != nullptr) { - // Reset addr_size as it may have been modified by previous recvfrom() calls. - addr_size_ = sizeof(*addr_); - addr_size_ptr = &addr_size_; - } - return TEMP_FAILURE_RETRY(recvfrom(fd_, data, length, 0, - reinterpret_cast(addr_.get()), addr_size_ptr)); -} - -int UnixUdpSocket::Close() { - int result = 0; - if (fd_ != -1) { - result = close(fd_); - fd_ = -1; - } - return result; -} - -std::unique_ptr UdpSocket::NewUdpClient(const std::string& host, int port, - std::string* error) { - int getaddrinfo_error = 0; - int fd = socket_network_client_timeout(host.c_str(), port, SOCK_DGRAM, 0, &getaddrinfo_error); - if (fd == -1) { - if (error) { - *error = android::base::StringPrintf( - "Failed to connect to %s:%d: %s", host.c_str(), port, - getaddrinfo_error ? gai_strerror(getaddrinfo_error) : strerror(errno)); - } - return nullptr; - } - - return std::unique_ptr(new UnixUdpSocket(fd, UnixUdpSocket::Type::kClient)); -} - -std::unique_ptr UdpSocket::NewUdpServer(int port) { - int fd = socket_inaddr_any_server(port, SOCK_DGRAM); - if (fd == -1) { - // This is just used in testing, no need for an error message. - return nullptr; - } - - return std::unique_ptr(new UnixUdpSocket(fd, UnixUdpSocket::Type::kServer)); -} diff --git a/fastboot/socket_windows.cpp b/fastboot/socket_windows.cpp deleted file mode 100644 index f86bb69d9..000000000 --- a/fastboot/socket_windows.cpp +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Copyright (C) 2015 The Android Open Source Project - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions - * are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in - * the documentation and/or other materials provided with the - * distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS - * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE - * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, - * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS - * OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED - * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT - * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF - * SUCH DAMAGE. - */ - -#include "socket.h" - -#include -#include - -#include - -#include -#include - -// Windows UDP socket functionality. -class WindowsUdpSocket : public UdpSocket { - public: - enum class Type { kClient, kServer }; - - WindowsUdpSocket(SOCKET sock, Type type); - ~WindowsUdpSocket() override; - - ssize_t Send(const void* data, size_t len) override; - ssize_t Receive(void* data, size_t len, int timeout_ms) override; - int Close() override; - - private: - SOCKET sock_; - int receive_timeout_ms_ = 0; - std::unique_ptr addr_; - int addr_size_ = 0; - - DISALLOW_COPY_AND_ASSIGN(WindowsUdpSocket); -}; - -WindowsUdpSocket::WindowsUdpSocket(SOCKET sock, Type type) : sock_(sock) { - // Only servers need to remember addresses; clients are connected to a server in NewUdpClient() - // so will send to that server without needing to specify the address again. - if (type == Type::kServer) { - addr_.reset(new sockaddr_storage); - addr_size_ = sizeof(*addr_); - memset(addr_.get(), 0, addr_size_); - } -} - -WindowsUdpSocket::~WindowsUdpSocket() { - Close(); -} - -ssize_t WindowsUdpSocket::Send(const void* data, size_t len) { - return sendto(sock_, reinterpret_cast(data), len, 0, - reinterpret_cast(addr_.get()), addr_size_); -} - -ssize_t WindowsUdpSocket::Receive(void* data, size_t len, int timeout_ms) { - // Only set socket timeout if it's changed. - if (receive_timeout_ms_ != timeout_ms) { - if (setsockopt(sock_, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&timeout_ms), - sizeof(timeout_ms)) < 0) { - return -1; - } - receive_timeout_ms_ = timeout_ms; - } - - int* addr_size_ptr = nullptr; - if (addr_ != nullptr) { - // Reset addr_size as it may have been modified by previous recvfrom() calls. - addr_size_ = sizeof(*addr_); - addr_size_ptr = &addr_size_; - } - int result = recvfrom(sock_, reinterpret_cast(data), len, 0, - reinterpret_cast(addr_.get()), addr_size_ptr); - if (result < 0 && WSAGetLastError() == WSAETIMEDOUT) { - errno = EAGAIN; - } - return result; -} - -int WindowsUdpSocket::Close() { - int result = 0; - if (sock_ != INVALID_SOCKET) { - result = closesocket(sock_); - sock_ = INVALID_SOCKET; - } - return result; -} - -std::unique_ptr UdpSocket::NewUdpClient(const std::string& host, int port, - std::string* error) { - SOCKET sock = socket_network_client(host.c_str(), port, SOCK_DGRAM); - if (sock == INVALID_SOCKET) { - if (error) { - *error = android::base::StringPrintf("Failed to connect to %s:%d (error %d)", - host.c_str(), port, WSAGetLastError()); - } - return nullptr; - } - - return std::unique_ptr(new WindowsUdpSocket(sock, WindowsUdpSocket::Type::kClient)); -} - -// This functionality is currently only used by tests so we don't need any error messages. -std::unique_ptr UdpSocket::NewUdpServer(int port) { - SOCKET sock = socket_inaddr_any_server(port, SOCK_DGRAM); - if (sock == INVALID_SOCKET) { - return nullptr; - } - - return std::unique_ptr(new WindowsUdpSocket(sock, WindowsUdpSocket::Type::kServer)); -} diff --git a/include/cutils/sockets.h b/include/cutils/sockets.h index c99d0aa2e..e25c555da 100644 --- a/include/cutils/sockets.h +++ b/include/cutils/sockets.h @@ -123,6 +123,14 @@ cutils_socket_t socket_inaddr_any_server(int port, int type); */ int socket_close(cutils_socket_t sock); +/* + * Sets socket receive timeout using SO_RCVTIMEO. Setting |timeout_ms| to 0 + * disables receive timeouts. + * + * Return 0 on success. + */ +int socket_set_receive_timeout(cutils_socket_t sock, int timeout_ms); + /* * socket_peer_is_trusted - Takes a socket which is presumed to be a * connected local socket (e.g. AF_LOCAL) and returns whether the peer diff --git a/libcutils/sockets_unix.c b/libcutils/sockets_unix.c index ca3f67ef2..5eddc4be1 100644 --- a/libcutils/sockets_unix.c +++ b/libcutils/sockets_unix.c @@ -49,3 +49,10 @@ bool socket_peer_is_trusted(int fd __android_unused) int socket_close(int sock) { return close(sock); } + +int socket_set_receive_timeout(cutils_socket_t sock, int timeout_ms) { + struct timeval tv; + tv.tv_sec = timeout_ms / 1000; + tv.tv_usec = (timeout_ms % 1000) * 1000; + return setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); +} diff --git a/libcutils/sockets_windows.c b/libcutils/sockets_windows.c index 92ccf88e0..1bf2933bd 100644 --- a/libcutils/sockets_windows.c +++ b/libcutils/sockets_windows.c @@ -53,3 +53,8 @@ bool initialize_windows_sockets() { int socket_close(cutils_socket_t sock) { return closesocket(sock); } + +int socket_set_receive_timeout(cutils_socket_t sock, int timeout_ms) { + return setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, (char*)&timeout_ms, + sizeof(timeout_ms)); +} diff --git a/libcutils/tests/sockets_test.cpp b/libcutils/tests/sockets_test.cpp index 975c3059c..966dfe77f 100644 --- a/libcutils/tests/sockets_test.cpp +++ b/libcutils/tests/sockets_test.cpp @@ -21,6 +21,8 @@ #include +#include + #include enum { @@ -66,6 +68,25 @@ static void TestConnectedSockets(cutils_socket_t server, cutils_socket_t client, EXPECT_EQ(0, socket_close(client)); } +// Tests receive timeout. The timing verification logic must be very coarse to +// make sure different systems can all pass these tests. +void TestReceiveTimeout(cutils_socket_t sock) { + time_t start_time; + char buffer[32]; + + // Make sure a 20ms timeout completes in 1 second or less. + EXPECT_EQ(0, socket_set_receive_timeout(sock, 20)); + start_time = time(nullptr); + EXPECT_EQ(-1, recv(sock, buffer, sizeof(buffer), 0)); + EXPECT_LE(difftime(time(nullptr), start_time), 1.0); + + // Make sure a 1250ms timeout takes 1 second or more. + EXPECT_EQ(0, socket_set_receive_timeout(sock, 1250)); + start_time = time(nullptr); + EXPECT_EQ(-1, recv(sock, buffer, sizeof(buffer), 0)); + EXPECT_LE(1.0, difftime(time(nullptr), start_time)); +} + // Tests socket_inaddr_any_server() and socket_network_client() for IPv4 UDP. TEST(SocketsTest, TestIpv4UdpLoopback) { cutils_socket_t server = socket_inaddr_any_server(kTestPort, SOCK_DGRAM); @@ -109,3 +130,29 @@ TEST(SocketsTest, TestIpv6TcpLoopback) { TestConnectedSockets(handler, client, SOCK_STREAM); } + +// Tests setting a receive timeout for UDP sockets. +TEST(SocketsTest, TestUdpReceiveTimeout) { + cutils_socket_t sock = socket_inaddr_any_server(kTestPort, SOCK_DGRAM); + ASSERT_NE(INVALID_SOCKET, sock); + + TestReceiveTimeout(sock); + + EXPECT_EQ(0, socket_close(sock)); +} + +// Tests setting a receive timeout for TCP sockets. +TEST(SocketsTest, TestTcpReceiveTimeout) { + cutils_socket_t server = socket_inaddr_any_server(kTestPort, SOCK_STREAM); + ASSERT_NE(INVALID_SOCKET, server); + + cutils_socket_t client = socket_network_client("localhost", kTestPort, + SOCK_STREAM); + cutils_socket_t handler = accept(server, nullptr, nullptr); + EXPECT_EQ(0, socket_close(server)); + + TestReceiveTimeout(handler); + + EXPECT_EQ(0, socket_close(client)); + EXPECT_EQ(0, socket_close(handler)); +}