Merge "Revert "adb: don't close sockets before hitting EOF.""
This commit is contained in:
commit
58d5906be3
|
@ -24,7 +24,6 @@ cc_defaults {
|
|||
"-Wno-missing-field-initializers",
|
||||
"-Wvla",
|
||||
],
|
||||
cpp_std: "gnu++17",
|
||||
rtti: true,
|
||||
|
||||
use_version_lib: true,
|
||||
|
|
155
adb/sockets.cpp
155
adb/sockets.cpp
|
@ -26,14 +26,10 @@
|
|||
#include <unistd.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#include <android-base/thread_annotations.h>
|
||||
|
||||
#if !ADB_HOST
|
||||
#include <android-base/properties.h>
|
||||
#include <log/log_properties.h>
|
||||
|
@ -41,150 +37,9 @@
|
|||
|
||||
#include "adb.h"
|
||||
#include "adb_io.h"
|
||||
#include "adb_utils.h"
|
||||
#include "sysdeps/chrono.h"
|
||||
#include "transport.h"
|
||||
#include "types.h"
|
||||
|
||||
// The standard (RFC 1122 - 4.2.2.13) says that if we call close on a
|
||||
// socket while we have pending data, a TCP RST should be sent to the
|
||||
// other end to notify it that we didn't read all of its data. However,
|
||||
// this can result in data that we've successfully written out to be dropped
|
||||
// on the other end. To avoid this, instead of immediately closing a
|
||||
// socket, call shutdown on it instead, and then read from the file
|
||||
// descriptor until we hit EOF or an error before closing.
|
||||
struct LingeringSocketCloser {
|
||||
LingeringSocketCloser() = default;
|
||||
~LingeringSocketCloser() = delete;
|
||||
|
||||
// Defer thread creation until it's needed, because we need for there to
|
||||
// only be one thread when dropping privileges in adbd.
|
||||
void Start() {
|
||||
CHECK(!thread_.joinable());
|
||||
|
||||
int fds[2];
|
||||
if (adb_socketpair(fds) != 0) {
|
||||
PLOG(FATAL) << "adb_socketpair failed";
|
||||
}
|
||||
|
||||
set_file_block_mode(fds[0], false);
|
||||
set_file_block_mode(fds[1], false);
|
||||
|
||||
notify_fd_read_.reset(fds[0]);
|
||||
notify_fd_write_.reset(fds[1]);
|
||||
|
||||
thread_ = std::thread([this]() { Run(); });
|
||||
}
|
||||
|
||||
void EnqueueSocket(unique_fd socket) {
|
||||
// Shutdown the socket in the outgoing direction only, so that
|
||||
// we don't have the same problem on the opposite end.
|
||||
adb_shutdown(socket.get(), SHUT_WR);
|
||||
set_file_block_mode(socket.get(), false);
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
int fd = socket.get();
|
||||
SocketInfo info = {
|
||||
.fd = std::move(socket),
|
||||
.deadline = std::chrono::steady_clock::now() + 1s,
|
||||
};
|
||||
|
||||
D("LingeringSocketCloser received fd %d", fd);
|
||||
|
||||
fds_.emplace(fd, std::move(info));
|
||||
if (adb_write(notify_fd_write_, "", 1) == -1 && errno != EAGAIN) {
|
||||
PLOG(FATAL) << "failed to write to LingeringSocketCloser notify fd";
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<adb_pollfd> GeneratePollFds() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
std::vector<adb_pollfd> result;
|
||||
result.push_back(adb_pollfd{.fd = notify_fd_read_, .events = POLLIN});
|
||||
for (auto& [fd, _] : fds_) {
|
||||
result.push_back(adb_pollfd{.fd = fd, .events = POLLIN});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void Run() {
|
||||
while (true) {
|
||||
std::vector<adb_pollfd> pfds = GeneratePollFds();
|
||||
int rc = adb_poll(pfds.data(), pfds.size(), 1000);
|
||||
if (rc == -1) {
|
||||
PLOG(FATAL) << "poll failed in LingeringSocketCloser";
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (rc == 0) {
|
||||
// Check deadlines.
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
for (auto it = fds_.begin(); it != fds_.end();) {
|
||||
if (now > it->second.deadline) {
|
||||
D("LingeringSocketCloser closing fd %d due to deadline", it->first);
|
||||
it = fds_.erase(it);
|
||||
} else {
|
||||
D("deadline still not expired for fd %d", it->first);
|
||||
++it;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto& pfd : pfds) {
|
||||
if ((pfd.revents & POLLIN) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Empty the fd.
|
||||
ssize_t rc;
|
||||
char buf[32768];
|
||||
while ((rc = adb_read(pfd.fd, buf, sizeof(buf))) > 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (pfd.fd == notify_fd_read_) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto it = fds_.find(pfd.fd);
|
||||
if (it == fds_.end()) {
|
||||
LOG(FATAL) << "fd is missing";
|
||||
}
|
||||
|
||||
if (rc == -1 && errno == EAGAIN) {
|
||||
if (std::chrono::steady_clock::now() > it->second.deadline) {
|
||||
D("LingeringSocketCloser closing fd %d due to deadline", pfd.fd);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
} else if (rc == -1) {
|
||||
D("LingeringSocketCloser closing fd %d due to error %d", pfd.fd, errno);
|
||||
} else {
|
||||
D("LingeringSocketCloser closing fd %d due to EOF", pfd.fd);
|
||||
}
|
||||
|
||||
fds_.erase(it);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::thread thread_;
|
||||
unique_fd notify_fd_read_;
|
||||
unique_fd notify_fd_write_;
|
||||
|
||||
struct SocketInfo {
|
||||
unique_fd fd;
|
||||
std::chrono::steady_clock::time_point deadline;
|
||||
};
|
||||
|
||||
std::mutex mutex_;
|
||||
std::map<int, SocketInfo> fds_ GUARDED_BY(mutex_);
|
||||
};
|
||||
|
||||
static auto& socket_closer = *new LingeringSocketCloser();
|
||||
|
||||
static std::recursive_mutex& local_socket_list_lock = *new std::recursive_mutex();
|
||||
static unsigned local_socket_next_id = 1;
|
||||
|
||||
|
@ -388,12 +243,10 @@ static void local_socket_destroy(asocket* s) {
|
|||
|
||||
D("LS(%d): destroying fde.fd=%d", s->id, s->fd);
|
||||
|
||||
// Defer thread creation until it's needed, because we need for there to
|
||||
// only be one thread when dropping privileges in adbd.
|
||||
static std::once_flag once;
|
||||
std::call_once(once, []() { socket_closer.Start(); });
|
||||
|
||||
socket_closer.EnqueueSocket(fdevent_release(s->fde));
|
||||
/* IMPORTANT: the remove closes the fd
|
||||
** that belongs to this socket
|
||||
*/
|
||||
fdevent_destroy(s->fde);
|
||||
|
||||
remove_socket(s);
|
||||
delete s;
|
||||
|
|
|
@ -35,8 +35,6 @@ import threading
|
|||
import time
|
||||
import unittest
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import adb
|
||||
|
||||
def requires_root(func):
|
||||
|
@ -1337,63 +1335,6 @@ class DeviceOfflineTest(DeviceTest):
|
|||
self.device.forward_remove("tcp:{}".format(local_port))
|
||||
|
||||
|
||||
class SocketTest(DeviceTest):
|
||||
def test_socket_flush(self):
|
||||
"""Test that we handle socket closure properly.
|
||||
|
||||
If we're done writing to a socket, closing before the other end has
|
||||
closed will send a TCP_RST if we have incoming data queued up, which
|
||||
may result in data that we've written being discarded.
|
||||
|
||||
Bug: http://b/74616284
|
||||
"""
|
||||
s = socket.create_connection(("localhost", 5037))
|
||||
|
||||
def adb_length_prefixed(string):
|
||||
encoded = string.encode("utf8")
|
||||
result = b"%04x%s" % (len(encoded), encoded)
|
||||
return result
|
||||
|
||||
if "ANDROID_SERIAL" in os.environ:
|
||||
transport_string = "host:transport:" + os.environ["ANDROID_SERIAL"]
|
||||
else:
|
||||
transport_string = "host:transport-any"
|
||||
|
||||
s.sendall(adb_length_prefixed(transport_string))
|
||||
response = s.recv(4)
|
||||
self.assertEquals(b"OKAY", response)
|
||||
|
||||
shell_string = "shell:sleep 0.5; dd if=/dev/zero bs=1m count=1 status=none; echo foo"
|
||||
s.sendall(adb_length_prefixed(shell_string))
|
||||
|
||||
response = s.recv(4)
|
||||
self.assertEquals(b"OKAY", response)
|
||||
|
||||
# Spawn a thread that dumps garbage into the socket until failure.
|
||||
def spam():
|
||||
buf = b"\0" * 16384
|
||||
try:
|
||||
while True:
|
||||
s.sendall(buf)
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
|
||||
thread = threading.Thread(target=spam)
|
||||
thread.start()
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
received = b""
|
||||
while True:
|
||||
read = s.recv(512)
|
||||
if len(read) == 0:
|
||||
break
|
||||
received += read
|
||||
|
||||
self.assertEquals(1024 * 1024 + len("foo\n"), len(received))
|
||||
thread.join()
|
||||
|
||||
|
||||
if sys.platform == "win32":
|
||||
# From https://stackoverflow.com/a/38749458
|
||||
import os
|
||||
|
|
Loading…
Reference in New Issue