diff --git a/adb/adb.cpp b/adb/adb.cpp index 49cf123ce..fa935f61b 100644 --- a/adb/adb.cpp +++ b/adb/adb.cpp @@ -244,11 +244,11 @@ void handle_offline(atransport *t) //Close the associated usb t->online = 0; - // This is necessary to avoid a race condition that occured when a transport closes + // This is necessary to avoid a race condition that occurred when a transport closes // while a client socket is still active. close_all_sockets(t); - run_transport_disconnects(t); + t->RunDisconnects(); } #if DEBUG_PACKETS diff --git a/adb/adb.h b/adb/adb.h index 6855f3b0c..0fb2008f2 100644 --- a/adb/adb.h +++ b/adb/adb.h @@ -157,8 +157,6 @@ struct adisconnect { void (*func)(void* opaque, atransport* t); void* opaque; - adisconnect* next; - adisconnect* prev; }; diff --git a/adb/adb_auth_client.cpp b/adb/adb_auth_client.cpp index be28202a4..c3af02408 100644 --- a/adb/adb_auth_client.cpp +++ b/adb/adb_auth_client.cpp @@ -47,7 +47,7 @@ static fdevent listener_fde; static int framework_fd = -1; static void usb_disconnected(void* unused, atransport* t); -static struct adisconnect usb_disconnect = { usb_disconnected, 0, 0, 0 }; +static struct adisconnect usb_disconnect = { usb_disconnected, nullptr}; static atransport* usb_transport; static bool needs_retry = false; @@ -164,7 +164,6 @@ int adb_auth_verify(uint8_t* token, uint8_t* sig, int siglen) static void usb_disconnected(void* unused, atransport* t) { D("USB disconnect\n"); - remove_transport_disconnect(usb_transport, &usb_disconnect); usb_transport = NULL; needs_retry = false; } @@ -196,7 +195,7 @@ void adb_auth_confirm_key(unsigned char *key, size_t len, atransport *t) if (!usb_transport) { usb_transport = t; - add_transport_disconnect(t, &usb_disconnect); + t->AddDisconnect(&usb_disconnect); } if (framework_fd < 0) { diff --git a/adb/adb_listeners.cpp b/adb/adb_listeners.cpp index 8fb2d19da..d5b1fd558 100644 --- a/adb/adb_listeners.cpp +++ b/adb/adb_listeners.cpp @@ -101,13 +101,15 @@ static void free_listener(alistener* l) free((char*)l->connect_to); if (l->transport) { - remove_transport_disconnect(l->transport, &l->disconnect); + l->transport->RemoveDisconnect(&l->disconnect); } free(l); } -static void listener_disconnect(void* listener, atransport* t) { - free_listener(reinterpret_cast(listener)); +static void listener_disconnect(void* arg, atransport*) { + alistener* listener = reinterpret_cast(arg); + listener->transport = nullptr; + free_listener(listener); } static int local_name_to_fd(const char* name, std::string* error) { @@ -159,7 +161,7 @@ InstallStatus remove_listener(const char *local_name, atransport* transport) { for (l = listener_list.next; l != &listener_list; l = l->next) { if (!strcmp(local_name, l->local_name)) { - listener_disconnect(l, l->transport); + free_listener(l); return INSTALL_STATUS_OK; } } @@ -174,7 +176,7 @@ void remove_all_listeners(void) // Never remove smart sockets. if (l->connect_to[0] == '*') continue; - listener_disconnect(l, l->transport); + free_listener(l); } } @@ -209,9 +211,9 @@ InstallStatus install_listener(const std::string& local_name, free((void*) l->connect_to); l->connect_to = cto; if (l->transport != transport) { - remove_transport_disconnect(l->transport, &l->disconnect); + l->transport->RemoveDisconnect(&l->disconnect); l->transport = transport; - add_transport_disconnect(l->transport, &l->disconnect); + l->transport->AddDisconnect(&l->disconnect); } return INSTALL_STATUS_OK; } @@ -260,7 +262,7 @@ InstallStatus install_listener(const std::string& local_name, if (transport) { listener->disconnect.opaque = listener; listener->disconnect.func = listener_disconnect; - add_transport_disconnect(transport, &listener->disconnect); + transport->AddDisconnect(&listener->disconnect); } return INSTALL_STATUS_OK; diff --git a/adb/transport.cpp b/adb/transport.cpp index 4dc5e4a1e..6ce5d7f51 100644 --- a/adb/transport.cpp +++ b/adb/transport.cpp @@ -42,36 +42,6 @@ static std::list pending_list; ADB_MUTEX_DEFINE( transport_lock ); -// Each atransport contains a list of adisconnects (t->disconnects). -// An adisconnect contains a link to the next/prev adisconnect, a function -// pointer to a disconnect callback which takes a void* piece of user data and -// the atransport, and some user data for the callback (helpfully named -// "opaque"). -// -// The list is circular. New items are added to the entry member of the list -// (t->disconnects) by add_transport_disconnect. -// -// run_transport_disconnects invokes each function in the list. -// -// Gotchas: -// * run_transport_disconnects assumes that t->disconnects is non-null, so -// this can't be run on a zeroed atransport. -// * The callbacks in this list are not removed when called, and this function -// is not guarded against running more than once. As such, ensure that this -// function is not called multiple times on the same atransport. -// TODO(danalbert): Just fix this so that it is guarded once you have tests. -void run_transport_disconnects(atransport* t) -{ - adisconnect* dis = t->disconnects.next; - - D("%s: run_transport_disconnects\n", t->serial); - while (dis != &t->disconnects) { - adisconnect* next = dis->next; - dis->func( dis->opaque, t ); - dis = next; - } -} - static void dump_packet(const char* name, const char* func, apacket* p) { unsigned command = p->msg.command; int len = p->msg.data_length; @@ -588,8 +558,6 @@ static void transport_registration_func(int _fd, unsigned ev, void *data) transport_list.push_front(t); adb_mutex_unlock(&transport_lock); - t->disconnects.next = t->disconnects.prev = &t->disconnects; - update_transports(); } @@ -653,23 +621,6 @@ static void transport_unref(atransport* t) { adb_mutex_unlock(&transport_lock); } -void add_transport_disconnect(atransport* t, adisconnect* dis) -{ - adb_mutex_lock(&transport_lock); - dis->next = &t->disconnects; - dis->prev = dis->next->prev; - dis->prev->next = dis; - dis->next->prev = dis; - adb_mutex_unlock(&transport_lock); -} - -void remove_transport_disconnect(atransport* t, adisconnect* dis) -{ - dis->prev->next = dis->next; - dis->next->prev = dis->prev; - dis->next = dis->prev = dis; -} - static int qual_match(const char *to_test, const char *prefix, const char *qual, bool sanitize_qual) { @@ -844,6 +795,21 @@ bool atransport::CanUseFeature(const std::string& feature) const { return has_feature(feature) && supported_features().count(feature) > 0; } +void atransport::AddDisconnect(adisconnect* disconnect) { + disconnects_.push_back(disconnect); +} + +void atransport::RemoveDisconnect(adisconnect* disconnect) { + disconnects_.remove(disconnect); +} + +void atransport::RunDisconnects() { + for (auto& disconnect : disconnects_) { + disconnect->func(disconnect->opaque, this); + } + disconnects_.clear(); +} + #if ADB_HOST static void append_transport_info(std::string* result, const char* key, diff --git a/adb/transport.h b/adb/transport.h index abb26a7d9..3b56c55e0 100644 --- a/adb/transport.h +++ b/adb/transport.h @@ -19,6 +19,7 @@ #include +#include #include #include @@ -71,9 +72,6 @@ public: int adb_port = -1; // Use for emulators (local transport) bool kicked = false; - // A list of adisconnect callbacks called when the transport is kicked. - adisconnect disconnects = {}; - void* key = nullptr; unsigned char token[TOKEN_SIZE] = {}; fdevent auth_fde; @@ -96,6 +94,10 @@ public: // feature. bool CanUseFeature(const std::string& feature) const; + void AddDisconnect(adisconnect* disconnect); + void RemoveDisconnect(adisconnect* disconnect); + void RunDisconnects(); + private: // A set of features transmitted in the banner with the initial connection. // This is stored in the banner as 'features=feature0,feature1,etc'. @@ -103,6 +105,9 @@ private: int protocol_version; size_t max_payload; + // A list of adisconnect callbacks called when the transport is kicked. + std::list disconnects_; + DISALLOW_COPY_AND_ASSIGN(atransport); }; @@ -114,10 +119,7 @@ private: */ atransport* acquire_one_transport(ConnectionState state, TransportType type, const char* serial, std::string* error_out); -void add_transport_disconnect(atransport* t, adisconnect* dis); -void remove_transport_disconnect(atransport* t, adisconnect* dis); void kick_transport(atransport* t); -void run_transport_disconnects(atransport* t); void update_transports(void); void init_transport_registration(void); diff --git a/adb/transport_test.cpp b/adb/transport_test.cpp index 743d97d32..10872ac3e 100644 --- a/adb/transport_test.cpp +++ b/adb/transport_test.cpp @@ -51,9 +51,6 @@ public: EXPECT_EQ(adb_port, rhs.adb_port); EXPECT_EQ(kicked, rhs.kicked); - EXPECT_EQ( - 0, memcmp(&disconnects, &rhs.disconnects, sizeof(adisconnect))); - EXPECT_EQ(key, rhs.key); EXPECT_EQ(0, memcmp(token, rhs.token, TOKEN_SIZE)); EXPECT_EQ(0, memcmp(&auth_fde, &rhs.auth_fde, sizeof(fdevent))); @@ -118,12 +115,33 @@ TEST(transport, kick_transport_already_kicked) { ASSERT_EQ(expected, t); } -// Disabled because the function currently segfaults for a zeroed atransport. I -// want to make sure I understand how this is working at all before I try fixing -// that. -TEST(transport, DISABLED_run_transport_disconnects_zeroed_atransport) { +static void DisconnectFunc(void* arg, atransport*) { + int* count = reinterpret_cast(arg); + ++*count; +} + +TEST(transport, RunDisconnects) { atransport t; - run_transport_disconnects(&t); + // RunDisconnects() can be called with an empty atransport. + t.RunDisconnects(); + + int count = 0; + adisconnect disconnect; + disconnect.func = DisconnectFunc; + disconnect.opaque = &count; + t.AddDisconnect(&disconnect); + t.RunDisconnects(); + ASSERT_EQ(1, count); + + // disconnect should have been removed automatically. + t.RunDisconnects(); + ASSERT_EQ(1, count); + + count = 0; + t.AddDisconnect(&disconnect); + t.RemoveDisconnect(&disconnect); + t.RunDisconnects(); + ASSERT_EQ(0, count); } TEST(transport, add_feature) {