adb: extract atransport's connection interface.

As step one of refactoring atransport to separate out protocol handling
from its underlying connection, extract atransport's existing
hand-rolled connection vtable out to its own abstract interface.

This should not change behavior except in one case: emulators are
now treated as TCP devices for the purposes of `adb disconnect`.

Test: python test_device.py, with walleye over USB + TCP
Test: manually connecting and disconnecting devices/emulators
Change-Id: I877b8027e567cc6a7461749432b49f6cb2c2f0d7
This commit is contained in:
Josh Gao 2018-01-28 20:32:46 -08:00
parent fb413a2304
commit b800d88b34
8 changed files with 173 additions and 233 deletions

View File

@ -11,6 +11,7 @@ adb_host_sanitize :=
adb_target_sanitize :=
ADB_COMMON_CFLAGS := \
-frtti \
-Wall -Wextra -Werror \
-Wno-unused-parameter \
-Wno-missing-field-initializers \

View File

@ -136,9 +136,6 @@ int launch_server(const std::string& socket_spec);
int adb_server_main(int is_daemon, const std::string& socket_spec, int ack_reply_fd);
/* initialize a transport object's func pointers and state */
#if ADB_HOST
int get_available_local_transport_index();
#endif
int init_socket_transport(atransport* t, int s, int port, int local);
void init_usb_transport(atransport* t, usb_handle* usb);

View File

@ -407,14 +407,6 @@ void connect_emulator(const std::string& port_spec, std::string* response) {
return;
}
// Check if more emulators can be registered. Similar unproblematic
// race condition as above.
int candidate_slot = get_available_local_transport_index();
if (candidate_slot < 0) {
*response = "Cannot accept more emulators";
return;
}
// Preconditions met, try to connect to the emulator.
std::string error;
if (!local_connect_arbitrary_ports(console_port, adb_port, &error)) {

View File

@ -41,6 +41,7 @@
#include "adb.h"
#include "adb_auth.h"
#include "adb_io.h"
#include "adb_trace.h"
#include "adb_utils.h"
#include "diagnose_usb.h"
@ -65,6 +66,36 @@ TransportId NextTransportId() {
return next++;
}
bool FdConnection::Read(apacket* packet) {
if (!ReadFdExactly(fd_.get(), &packet->msg, sizeof(amessage))) {
D("remote local: read terminated (message)");
return false;
}
if (!ReadFdExactly(fd_.get(), &packet->data, packet->msg.data_length)) {
D("remote local: terminated (data)");
return false;
}
return true;
}
bool FdConnection::Write(apacket* packet) {
uint32_t length = packet->msg.data_length;
if (!WriteFdExactly(fd_.get(), &packet->msg, sizeof(amessage) + length)) {
D("remote local: write terminated");
return false;
}
return true;
}
void FdConnection::Close() {
adb_shutdown(fd_.get());
fd_.reset();
}
static std::string dump_packet(const char* name, const char* func, apacket* p) {
unsigned command = p->msg.command;
int len = p->msg.data_length;
@ -220,11 +251,18 @@ static void read_transport_thread(void* _t) {
{
ATRACE_NAME("read_transport read_remote");
if (t->read_from_remote(p, t) != 0) {
if (!t->connection->Read(p)) {
D("%s: remote read failed for transport", t->serial);
put_apacket(p);
break;
}
if (!check_header(p, t)) {
D("%s: remote read: bad header", t->serial);
put_apacket(p);
break;
}
#if ADB_HOST
if (p->msg.command == 0) {
put_apacket(p);
@ -626,7 +664,7 @@ static void transport_unref(atransport* t) {
t->ref_count--;
if (t->ref_count == 0) {
D("transport: %s unref (kicking and closing)", t->serial);
t->close(t);
t->connection->Close();
remove_transport(t);
} else {
D("transport: %s unref (count=%zu)", t->serial, t->ref_count);
@ -754,14 +792,14 @@ atransport* acquire_one_transport(TransportType type, const char* serial, Transp
}
int atransport::Write(apacket* p) {
return write_func_(p, this);
return this->connection->Write(p) ? 0 : -1;
}
void atransport::Kick() {
if (!kicked_) {
D("kicking transport %s", this->serial);
kicked_ = true;
CHECK(kick_func_ != nullptr);
kick_func_(this);
this->connection->Close();
}
}
@ -1083,8 +1121,12 @@ void register_usb_transport(usb_handle* usb, const char* serial, const char* dev
// This should only be used for transports with connection_state == kCsNoPerm.
void unregister_usb_transport(usb_handle* usb) {
std::lock_guard<std::recursive_mutex> lock(transport_lock);
transport_list.remove_if(
[usb](atransport* t) { return t->usb == usb && t->GetConnectionState() == kCsNoPerm; });
transport_list.remove_if([usb](atransport* t) {
if (auto connection = dynamic_cast<UsbConnection*>(t->connection.get())) {
return connection->handle_ == usb && t->GetConnectionState() == kCsNoPerm;
}
return false;
});
}
bool check_header(apacket* p, atransport* t) {

View File

@ -28,10 +28,11 @@
#include <string>
#include <unordered_set>
#include "adb.h"
#include <openssl/rsa.h>
#include "adb.h"
#include "adb_unique_fd.h"
typedef std::unordered_set<std::string> FeatureSet;
const FeatureSet& supported_features();
@ -56,6 +57,50 @@ extern const char* const kFeaturePushSync;
TransportId NextTransportId();
// Abstraction for a blocking packet transport.
struct Connection {
Connection() = default;
Connection(const Connection& copy) = delete;
Connection(Connection&& move) = delete;
// Destroy a Connection. Formerly known as 'Close' in atransport.
virtual ~Connection() = default;
// Read/Write a packet. These functions are concurrently called from a transport's reader/writer
// threads.
virtual bool Read(apacket* packet) = 0;
virtual bool Write(apacket* packet) = 0;
// Terminate a connection.
// This method must be thread-safe, and must cause concurrent Reads/Writes to terminate.
// Formerly known as 'Kick' in atransport.
virtual void Close() = 0;
};
struct FdConnection : public Connection {
explicit FdConnection(unique_fd fd) : fd_(std::move(fd)) {}
bool Read(apacket* packet) override final;
bool Write(apacket* packet) override final;
void Close() override;
private:
unique_fd fd_;
};
struct UsbConnection : public Connection {
explicit UsbConnection(usb_handle* handle) : handle_(handle) {}
~UsbConnection();
bool Read(apacket* packet) override final;
bool Write(apacket* packet) override final;
void Close() override final;
usb_handle* handle_;
};
class atransport {
public:
// TODO(danalbert): We expose waaaaaaay too much stuff because this was
@ -73,12 +118,6 @@ class atransport {
}
virtual ~atransport() {}
int (*read_from_remote)(apacket* p, atransport* t) = nullptr;
void (*close)(atransport* t) = nullptr;
void SetWriteFunction(int (*write_func)(apacket*, atransport*)) { write_func_ = write_func; }
void SetKickFunction(void (*kick_func)(atransport*)) { kick_func_ = kick_func; }
bool IsKicked() { return kicked_; }
int Write(apacket* p);
void Kick();
@ -95,9 +134,7 @@ class atransport {
bool online = false;
TransportType type = kTransportAny;
// USB handle or socket fd as needed.
usb_handle* usb = nullptr;
int sfd = -1;
std::unique_ptr<Connection> connection;
// Used to identify transports for clients.
char* serial = nullptr;
@ -105,22 +142,8 @@ class atransport {
char* model = nullptr;
char* device = nullptr;
char* devpath = nullptr;
void SetLocalPortForEmulator(int port) {
CHECK_EQ(local_port_for_emulator_, -1);
local_port_for_emulator_ = port;
}
bool GetLocalPortForEmulator(int* port) const {
if (type == kTransportLocal && local_port_for_emulator_ != -1) {
*port = local_port_for_emulator_;
return true;
}
return false;
}
bool IsTcpDevice() const {
return type == kTransportLocal && local_port_for_emulator_ == -1;
}
bool IsTcpDevice() const { return type == kTransportLocal; }
#if ADB_HOST
std::shared_ptr<RSA> NextKey();
@ -165,10 +188,7 @@ class atransport {
bool MatchesTarget(const std::string& target) const;
private:
int local_port_for_emulator_ = -1;
bool kicked_ = false;
void (*kick_func_)(atransport*) = nullptr;
int (*write_func_)(apacket*, atransport*) = nullptr;
// A set of features transmitted in the banner with the initial connection.
// This is stored in the banner as 'features=feature0,feature1,etc'.

View File

@ -28,10 +28,12 @@
#include <condition_variable>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <vector>
#include <android-base/parsenetaddress.h>
#include <android-base/stringprintf.h>
#include <android-base/thread_annotations.h>
#include <cutils/sockets.h>
#if !ADB_HOST
@ -40,6 +42,7 @@
#include "adb.h"
#include "adb_io.h"
#include "adb_unique_fd.h"
#include "adb_utils.h"
#include "sysdeps/chrono.h"
@ -53,48 +56,15 @@
static std::mutex& local_transports_lock = *new std::mutex();
/* we keep a list of opened transports. The atransport struct knows to which
* local transport it is connected. The list is used to detect when we're
* trying to connect twice to a given local transport.
*/
static atransport* local_transports[ ADB_LOCAL_TRANSPORT_MAX ];
// We keep a map from emulator port to transport.
// TODO: weak_ptr?
static auto& local_transports GUARDED_BY(local_transports_lock) =
*new std::unordered_map<int, atransport*>();
#endif /* ADB_HOST */
static int remote_read(apacket *p, atransport *t)
{
if (!ReadFdExactly(t->sfd, &p->msg, sizeof(amessage))) {
D("remote local: read terminated (message)");
return -1;
}
if (!check_header(p, t)) {
D("bad header: terminated (data)");
return -1;
}
if (!ReadFdExactly(t->sfd, p->data, p->msg.data_length)) {
D("remote local: terminated (data)");
return -1;
}
return 0;
}
static int remote_write(apacket *p, atransport *t)
{
int length = p->msg.data_length;
if(!WriteFdExactly(t->sfd, &p->msg, sizeof(amessage) + length)) {
D("remote local: write terminated");
return -1;
}
return 0;
}
bool local_connect(int port) {
std::string dummy;
return local_connect_arbitrary_ports(port-1, port, &dummy) == 0;
return local_connect_arbitrary_ports(port - 1, port, &dummy) == 0;
}
void connect_device(const std::string& address, std::string* response) {
@ -423,130 +393,83 @@ void local_init(int port)
std::thread(func, port).detach();
}
static void remote_kick(atransport *t)
{
int fd = t->sfd;
t->sfd = -1;
adb_shutdown(fd);
adb_close(fd);
#if ADB_HOST
int nn;
std::lock_guard<std::mutex> lock(local_transports_lock);
for (nn = 0; nn < ADB_LOCAL_TRANSPORT_MAX; nn++) {
if (local_transports[nn] == t) {
local_transports[nn] = NULL;
break;
}
}
#endif
}
struct EmulatorConnection : public FdConnection {
EmulatorConnection(unique_fd fd, int local_port)
: FdConnection(std::move(fd)), local_port_(local_port) {}
static void remote_close(atransport *t)
{
int fd = t->sfd;
if (fd != -1) {
t->sfd = -1;
adb_close(fd);
}
#if ADB_HOST
int local_port;
if (t->GetLocalPortForEmulator(&local_port)) {
VLOG(TRANSPORT) << "remote_close, local_port = " << local_port;
~EmulatorConnection() {
VLOG(TRANSPORT) << "remote_close, local_port = " << local_port_;
std::unique_lock<std::mutex> lock(retry_ports_lock);
RetryPort port;
port.port = local_port;
port.port = local_port_;
port.retry_count = LOCAL_PORT_RETRY_COUNT;
retry_ports.push_back(port);
retry_ports_cond.notify_one();
}
#endif
}
void Close() override {
std::lock_guard<std::mutex> lock(local_transports_lock);
local_transports.erase(local_port_);
FdConnection::Close();
}
int local_port_;
};
#if ADB_HOST
/* Only call this function if you already hold local_transports_lock. */
static atransport* find_emulator_transport_by_adb_port_locked(int adb_port)
{
int i;
for (i = 0; i < ADB_LOCAL_TRANSPORT_MAX; i++) {
int local_port;
if (local_transports[i] && local_transports[i]->GetLocalPortForEmulator(&local_port)) {
if (local_port == adb_port) {
return local_transports[i];
}
}
REQUIRES(local_transports_lock) {
auto it = local_transports.find(adb_port);
if (it == local_transports.end()) {
return nullptr;
}
return NULL;
return it->second;
}
std::string getEmulatorSerialString(int console_port)
{
std::string getEmulatorSerialString(int console_port) {
return android::base::StringPrintf("emulator-%d", console_port);
}
atransport* find_emulator_transport_by_adb_port(int adb_port)
{
atransport* find_emulator_transport_by_adb_port(int adb_port) {
std::lock_guard<std::mutex> lock(local_transports_lock);
atransport* result = find_emulator_transport_by_adb_port_locked(adb_port);
return result;
return find_emulator_transport_by_adb_port_locked(adb_port);
}
atransport* find_emulator_transport_by_console_port(int console_port)
{
atransport* find_emulator_transport_by_console_port(int console_port) {
return find_transport(getEmulatorSerialString(console_port).c_str());
}
/* Only call this function if you already hold local_transports_lock. */
int get_available_local_transport_index_locked()
{
int i;
for (i = 0; i < ADB_LOCAL_TRANSPORT_MAX; i++) {
if (local_transports[i] == NULL) {
return i;
}
}
return -1;
}
int get_available_local_transport_index()
{
std::lock_guard<std::mutex> lock(local_transports_lock);
int result = get_available_local_transport_index_locked();
return result;
}
#endif
int init_socket_transport(atransport *t, int s, int adb_port, int local)
{
int fail = 0;
int init_socket_transport(atransport* t, int s, int adb_port, int local) {
int fail = 0;
t->SetKickFunction(remote_kick);
t->SetWriteFunction(remote_write);
t->close = remote_close;
t->read_from_remote = remote_read;
t->sfd = s;
unique_fd fd(s);
t->sync_token = 1;
t->type = kTransportLocal;
#if ADB_HOST
// Emulator connection.
if (local) {
t->connection.reset(new EmulatorConnection(std::move(fd), adb_port));
std::lock_guard<std::mutex> lock(local_transports_lock);
t->SetLocalPortForEmulator(adb_port);
atransport* existing_transport = find_emulator_transport_by_adb_port_locked(adb_port);
int index = get_available_local_transport_index_locked();
if (existing_transport != NULL) {
D("local transport for port %d already registered (%p)?", adb_port, existing_transport);
fail = -1;
} else if (index < 0) {
} else if (local_transports.size() >= ADB_LOCAL_TRANSPORT_MAX) {
// Too many emulators.
D("cannot register more emulators. Maximum is %d", ADB_LOCAL_TRANSPORT_MAX);
fail = -1;
} else {
local_transports[index] = t;
local_transports[adb_port] = t;
}
return fail;
}
#endif
// Regular tcp connection.
t->connection.reset(new FdConnection(std::move(fd)));
return fail;
}

View File

@ -20,22 +20,6 @@
#include "adb.h"
TEST(transport, kick_transport) {
atransport t;
static size_t kick_count;
kick_count = 0;
// Mutate some member so we can test that the function is run.
t.SetKickFunction([](atransport* trans) { kick_count++; });
ASSERT_FALSE(t.IsKicked());
t.Kick();
ASSERT_TRUE(t.IsKicked());
ASSERT_EQ(1u, kick_count);
// A transport can only be kicked once.
t.Kick();
ASSERT_TRUE(t.IsKicked());
ASSERT_EQ(1u, kick_count);
}
static void DisconnectFunc(void* arg, atransport*) {
int* count = reinterpret_cast<int*>(arg);
++*count;

View File

@ -80,25 +80,18 @@ static int UsbReadPayload(usb_handle* h, apacket* p) {
#endif
}
static int remote_read(apacket* p, atransport* t) {
int n = UsbReadMessage(t->usb, &p->msg);
static int remote_read(apacket* p, usb_handle* usb) {
int n = UsbReadMessage(usb, &p->msg);
if (n < 0) {
D("remote usb: read terminated (message)");
return -1;
}
if (static_cast<size_t>(n) != sizeof(p->msg) || !check_header(p, t)) {
D("remote usb: check_header failed, skip it");
goto err_msg;
}
if (t->GetConnectionState() == kCsOffline) {
// If we read a wrong msg header declaring a large message payload, don't read its payload.
// Otherwise we may miss true messages from the device.
if (p->msg.command != A_CNXN && p->msg.command != A_AUTH) {
goto err_msg;
}
if (static_cast<size_t>(n) != sizeof(p->msg)) {
D("remote usb: read received unexpected header length %d", n);
return -1;
}
if (p->msg.data_length) {
n = UsbReadPayload(t->usb, p);
n = UsbReadPayload(usb, p);
if (n < 0) {
D("remote usb: terminated (data)");
return -1;
@ -106,34 +99,24 @@ static int remote_read(apacket* p, atransport* t) {
if (static_cast<uint32_t>(n) != p->msg.data_length) {
D("remote usb: read payload failed (need %u bytes, give %d bytes), skip it",
p->msg.data_length, n);
goto err_msg;
return -1;
}
}
return 0;
err_msg:
p->msg.command = 0;
return 0;
}
#else
// On Android devices, we rely on the kernel to provide buffered read.
// So we can recover automatically from EOVERFLOW.
static int remote_read(apacket *p, atransport *t)
{
if (usb_read(t->usb, &p->msg, sizeof(amessage))) {
static int remote_read(apacket* p, usb_handle* usb) {
if (usb_read(usb, &p->msg, sizeof(amessage))) {
PLOG(ERROR) << "remote usb: read terminated (message)";
return -1;
}
if (!check_header(p, t)) {
LOG(ERROR) << "remote usb: check_header failed";
return -1;
}
if (p->msg.data_length) {
if (usb_read(t->usb, p->data, p->msg.data_length)) {
if (usb_read(usb, p->data, p->msg.data_length)) {
PLOG(ERROR) << "remote usb: terminated (data)";
return -1;
}
@ -143,45 +126,43 @@ static int remote_read(apacket *p, atransport *t)
}
#endif
static int remote_write(apacket *p, atransport *t)
{
unsigned size = p->msg.data_length;
UsbConnection::~UsbConnection() {
usb_close(handle_);
}
if (usb_write(t->usb, &p->msg, sizeof(amessage))) {
bool UsbConnection::Read(apacket* packet) {
int rc = remote_read(packet, handle_);
return rc == 0;
}
bool UsbConnection::Write(apacket* packet) {
unsigned size = packet->msg.data_length;
if (usb_write(handle_, &packet->msg, sizeof(packet->msg)) != 0) {
PLOG(ERROR) << "remote usb: 1 - write terminated";
return -1;
return false;
}
if (p->msg.data_length == 0) return 0;
if (usb_write(t->usb, &p->data, size)) {
if (packet->msg.data_length != 0 && usb_write(handle_, &packet->data, size) != 0) {
PLOG(ERROR) << "remote usb: 2 - write terminated";
return -1;
return false;
}
return 0;
return true;
}
static void remote_close(atransport* t) {
usb_close(t->usb);
t->usb = 0;
}
static void remote_kick(atransport* t) {
usb_kick(t->usb);
void UsbConnection::Close() {
usb_kick(handle_);
}
void init_usb_transport(atransport* t, usb_handle* h) {
D("transport: usb");
t->close = remote_close;
t->SetKickFunction(remote_kick);
t->SetWriteFunction(remote_write);
t->read_from_remote = remote_read;
t->connection.reset(new UsbConnection(h));
t->sync_token = 1;
t->type = kTransportUsb;
t->usb = h;
}
int is_adb_interface(int usb_class, int usb_subclass, int usb_protocol)
{
int is_adb_interface(int usb_class, int usb_subclass, int usb_protocol) {
return (usb_class == ADB_CLASS && usb_subclass == ADB_SUBCLASS && usb_protocol == ADB_PROTOCOL);
}