diff --git a/adb/transport.cpp b/adb/transport.cpp index aaab21daa..1cf4c49a9 100644 --- a/adb/transport.cpp +++ b/adb/transport.cpp @@ -295,20 +295,12 @@ static void write_transport_thread(void* _t) { transport_unref(t); } -static void kick_transport_locked(atransport* t) { - CHECK(t != nullptr); - if (!t->kicked) { - t->kicked = true; - t->kick(t); - } -} - void kick_transport(atransport* t) { adb_mutex_lock(&transport_lock); // As kick_transport() can be called from threads without guarantee that t is valid, // check if the transport is in transport_list first. if (std::find(transport_list.begin(), transport_list.end(), t) != transport_list.end()) { - kick_transport_locked(t); + t->Kick(); } adb_mutex_unlock(&transport_lock); } @@ -621,7 +613,7 @@ static void transport_unref(atransport* t) { t->ref_count--; if (t->ref_count == 0) { D("transport: %s unref (kicking and closing)", t->serial); - kick_transport_locked(t); + t->Kick(); t->close(t); remove_transport(t); } else { @@ -748,6 +740,14 @@ atransport* acquire_one_transport(TransportType type, const char* serial, return result; } +void atransport::Kick() { + if (!kicked_) { + kicked_ = true; + CHECK(kick_func_ != nullptr); + kick_func_(this); + } +} + const std::string atransport::connection_state_name() const { switch (connection_state) { case kCsOffline: return "offline"; @@ -928,10 +928,7 @@ std::string list_transports(bool long_listing) { void close_usb_devices() { adb_mutex_lock(&transport_lock); for (const auto& t : transport_list) { - if (!t->kicked) { - t->kicked = 1; - t->kick(t); - } + t->Kick(); } adb_mutex_unlock(&transport_lock); } @@ -1002,7 +999,7 @@ void kick_all_tcp_devices() { // the read_transport thread will notify the main thread to make this transport // offline. Then the main thread will notify the write_transport thread to exit. // Finally, this transport will be closed and freed in the main thread. - kick_transport_locked(t); + t->Kick(); } } adb_mutex_unlock(&transport_lock); diff --git a/adb/transport.h b/adb/transport.h index 5857249db..35d7b505d 100644 --- a/adb/transport.h +++ b/adb/transport.h @@ -60,7 +60,13 @@ public: int (*read_from_remote)(apacket* p, atransport* t) = nullptr; int (*write_to_remote)(apacket* p, atransport* t) = nullptr; void (*close)(atransport* t) = nullptr; - void (*kick)(atransport* t) = nullptr; + void SetKickFunction(void (*kick_func)(atransport*)) { + kick_func_ = kick_func; + } + bool IsKicked() { + return kicked_; + } + void Kick(); int fd = -1; int transport_socket = -1; @@ -82,7 +88,6 @@ public: char* device = nullptr; char* devpath = nullptr; int adb_port = -1; // Use for emulators (local transport) - bool kicked = false; void* key = nullptr; unsigned char token[TOKEN_SIZE] = {}; @@ -123,6 +128,9 @@ public: bool MatchesTarget(const std::string& target) const; private: + bool kicked_ = false; + void (*kick_func_)(atransport*) = nullptr; + // A set of features transmitted in the banner with the initial connection. // This is stored in the banner as 'features=feature0,feature1,etc'. FeatureSet features_; diff --git a/adb/transport_local.cpp b/adb/transport_local.cpp index f6c9df4e8..4121f472c 100644 --- a/adb/transport_local.cpp +++ b/adb/transport_local.cpp @@ -388,7 +388,7 @@ int init_socket_transport(atransport *t, int s, int adb_port, int local) { int fail = 0; - t->kick = remote_kick; + t->SetKickFunction(remote_kick); t->close = remote_close; t->read_from_remote = remote_read; t->write_to_remote = remote_write; diff --git a/adb/transport_test.cpp b/adb/transport_test.cpp index 2028eccbb..a6db07acf 100644 --- a/adb/transport_test.cpp +++ b/adb/transport_test.cpp @@ -20,47 +20,6 @@ #include "adb.h" -class TestTransport : public atransport { -public: - bool operator==(const atransport& rhs) const { - EXPECT_EQ(read_from_remote, rhs.read_from_remote); - EXPECT_EQ(write_to_remote, rhs.write_to_remote); - EXPECT_EQ(close, rhs.close); - EXPECT_EQ(kick, rhs.kick); - - EXPECT_EQ(fd, rhs.fd); - EXPECT_EQ(transport_socket, rhs.transport_socket); - - EXPECT_EQ( - 0, memcmp(&transport_fde, &rhs.transport_fde, sizeof(fdevent))); - - EXPECT_EQ(ref_count, rhs.ref_count); - EXPECT_EQ(sync_token, rhs.sync_token); - EXPECT_EQ(connection_state, rhs.connection_state); - EXPECT_EQ(online, rhs.online); - EXPECT_EQ(type, rhs.type); - - EXPECT_EQ(usb, rhs.usb); - EXPECT_EQ(sfd, rhs.sfd); - - EXPECT_EQ(serial, rhs.serial); - EXPECT_EQ(product, rhs.product); - EXPECT_EQ(model, rhs.model); - EXPECT_EQ(device, rhs.device); - EXPECT_EQ(devpath, rhs.devpath); - EXPECT_EQ(adb_port, rhs.adb_port); - EXPECT_EQ(kicked, rhs.kicked); - - EXPECT_EQ(key, rhs.key); - EXPECT_EQ(0, memcmp(token, rhs.token, TOKEN_SIZE)); - EXPECT_EQ(failed_auth_attempts, rhs.failed_auth_attempts); - - EXPECT_EQ(features(), rhs.features()); - - return true; - } -}; - class TransportSetup { public: TransportSetup() { @@ -83,35 +42,19 @@ public: static TransportSetup g_TransportSetup; TEST(transport, kick_transport) { - TestTransport t; - + atransport t; + static size_t kick_count; + kick_count = 0; // Mutate some member so we can test that the function is run. - t.kick = [](atransport* trans) { trans->fd = 42; }; - - TestTransport expected; - expected.kick = t.kick; - expected.fd = 42; - expected.kicked = 1; - - kick_transport(&t); - ASSERT_EQ(42, t.fd); - ASSERT_EQ(1, t.kicked); - ASSERT_EQ(expected, t); -} - -TEST(transport, kick_transport_already_kicked) { - // Ensure that the transport is not modified if the transport has already been - // kicked. - TestTransport t; - t.kicked = 1; - t.kick = [](atransport*) { FAIL() << "Kick should not have been called"; }; - - TestTransport expected; - expected.kicked = 1; - expected.kick = t.kick; - - kick_transport(&t); - ASSERT_EQ(expected, t); + t.SetKickFunction([](atransport* trans) { kick_count++; }); + ASSERT_FALSE(t.IsKicked()); + t.Kick(); + ASSERT_TRUE(t.IsKicked()); + ASSERT_EQ(1u, kick_count); + // A transport can only be kicked once. + t.Kick(); + ASSERT_TRUE(t.IsKicked()); + ASSERT_EQ(1u, kick_count); } static void DisconnectFunc(void* arg, atransport*) { diff --git a/adb/transport_usb.cpp b/adb/transport_usb.cpp index 263f9e70b..d05d9285c 100644 --- a/adb/transport_usb.cpp +++ b/adb/transport_usb.cpp @@ -84,7 +84,7 @@ void init_usb_transport(atransport *t, usb_handle *h, ConnectionState state) { D("transport: usb"); t->close = remote_close; - t->kick = remote_kick; + t->SetKickFunction(remote_kick); t->read_from_remote = remote_read; t->write_to_remote = remote_write; t->sync_token = 1;