diff --git a/adb/socket.h b/adb/socket.h index 4083036fa..9eb1b1955 100644 --- a/adb/socket.h +++ b/adb/socket.h @@ -114,4 +114,13 @@ asocket *create_remote_socket(unsigned id, atransport *t); void connect_to_remote(asocket *s, const char *destination); void connect_to_smartsocket(asocket *s); +// Internal functions that are only made available here for testing purposes. +namespace internal { + +#if ADB_HOST +char* skip_host_serial(const char* service); +#endif + +} // namespace internal + #endif // __ADB_SOCKET_H diff --git a/adb/socket_test.cpp b/adb/socket_test.cpp index 471ca09e4..5cbef6dcf 100644 --- a/adb/socket_test.cpp +++ b/adb/socket_test.cpp @@ -270,3 +270,49 @@ TEST_F(LocalSocketTest, close_socket_in_CLOSE_WAIT_state) { } #endif // defined(__linux__) + +#if ADB_HOST + +// Checks that skip_host_serial(serial) returns a pointer to the part of |serial| which matches +// |expected|, otherwise logs the failure to gtest. +void VerifySkipHostSerial(const std::string& serial, const char* expected) { + const char* result = internal::skip_host_serial(serial.c_str()); + if (expected == nullptr) { + EXPECT_EQ(nullptr, result); + } else { + EXPECT_STREQ(expected, result); + } +} + +// Check [tcp:|udp:][:]: format. +TEST(socket_test, test_skip_host_serial) { + for (const std::string& protocol : {"", "tcp:", "udp:"}) { + VerifySkipHostSerial(protocol, nullptr); + VerifySkipHostSerial(protocol + "foo", nullptr); + + VerifySkipHostSerial(protocol + "foo:bar", ":bar"); + VerifySkipHostSerial(protocol + "foo:bar:baz", ":bar:baz"); + + VerifySkipHostSerial(protocol + "foo:123:bar", ":bar"); + VerifySkipHostSerial(protocol + "foo:123:456", ":456"); + VerifySkipHostSerial(protocol + "foo:123:bar:baz", ":bar:baz"); + + // Don't register a port unless it's all numbers and ends with ':'. + VerifySkipHostSerial(protocol + "foo:123", ":123"); + VerifySkipHostSerial(protocol + "foo:123bar:baz", ":123bar:baz"); + } +} + +// Check :: format. +TEST(socket_test, test_skip_host_serial_prefix) { + for (const std::string& prefix : {"usb:", "product:", "model:", "device:"}) { + VerifySkipHostSerial(prefix, nullptr); + VerifySkipHostSerial(prefix + "foo", nullptr); + + VerifySkipHostSerial(prefix + "foo:bar", ":bar"); + VerifySkipHostSerial(prefix + "foo:bar:baz", ":bar:baz"); + VerifySkipHostSerial(prefix + "foo:123:bar", ":123:bar"); + } +} + +#endif // ADB_HOST diff --git a/adb/sockets.cpp b/adb/sockets.cpp index d8e4e9368..c083ee144 100644 --- a/adb/sockets.cpp +++ b/adb/sockets.cpp @@ -26,6 +26,8 @@ #include #include +#include +#include #if !ADB_HOST #include "cutils/properties.h" @@ -623,43 +625,43 @@ static unsigned unhex(unsigned char *s, int len) #if ADB_HOST -#define PREFIX(str) { str, sizeof(str) - 1 } -static const struct prefix_struct { - const char *str; - const size_t len; -} prefixes[] = { - PREFIX("usb:"), - PREFIX("product:"), - PREFIX("model:"), - PREFIX("device:"), -}; -static const int num_prefixes = (sizeof(prefixes) / sizeof(prefixes[0])); +namespace internal { -/* skip_host_serial return the position in a string - skipping over the 'serial' parameter in the ADB protocol, - where parameter string may be a host:port string containing - the protocol delimiter (colon). */ -static char *skip_host_serial(char *service) { - char *first_colon, *serial_end; - int i; +// Returns the position in |service| following the target serial parameter. Serial format can be +// any of: +// * [tcp:|udp:][:]: +// * :: +// Where must be a base-10 number and may be any of {usb,product,model,device}. +// +// The returned pointer will point to the ':' just before , or nullptr if not found. +char* skip_host_serial(const char* service) { + static const std::vector& prefixes = + *(new std::vector{"usb:", "product:", "model:", "device:"}); - for (i = 0; i < num_prefixes; i++) { - if (!strncmp(service, prefixes[i].str, prefixes[i].len)) - return strchr(service + prefixes[i].len, ':'); + for (const std::string& prefix : prefixes) { + if (!strncmp(service, prefix.c_str(), prefix.length())) { + return strchr(service + prefix.length(), ':'); + } } - first_colon = strchr(service, ':'); + // For fastboot compatibility, ignore protocol prefixes. + if (!strncmp(service, "tcp:", 4) || !strncmp(service, "udp:", 4)) { + service += 4; + } + + char* first_colon = strchr(service, ':'); if (!first_colon) { - /* No colon in service string. */ - return NULL; + // No colon in service string. + return nullptr; } - serial_end = first_colon; + + char* serial_end = first_colon; if (isdigit(serial_end[1])) { serial_end++; - while ((*serial_end) && isdigit(*serial_end)) { + while (*serial_end && isdigit(*serial_end)) { serial_end++; } - if ((*serial_end) != ':') { + if (*serial_end != ':') { // Something other than numbers was found, reset the end. serial_end = first_colon; } @@ -667,6 +669,8 @@ static char *skip_host_serial(char *service) { return serial_end; } +} // namespace internal + #endif // ADB_HOST static int smart_socket_enqueue(asocket *s, apacket *p) @@ -725,7 +729,7 @@ static int smart_socket_enqueue(asocket *s, apacket *p) service += strlen("host-serial:"); // serial number should follow "host:" and could be a host:port string. - serial_end = skip_host_serial(service); + serial_end = internal::skip_host_serial(service); if (serial_end) { *serial_end = 0; // terminate string serial = service; diff --git a/adb/transport.cpp b/adb/transport.cpp index d9180bc45..e3340afb3 100644 --- a/adb/transport.cpp +++ b/adb/transport.cpp @@ -30,6 +30,7 @@ #include #include +#include #include #include @@ -679,11 +680,7 @@ atransport* acquire_one_transport(TransportType type, const char* serial, // Check for matching serial number. if (serial) { - if ((t->serial && !strcmp(serial, t->serial)) || - (t->devpath && !strcmp(serial, t->devpath)) || - qual_match(serial, "product:", t->product, false) || - qual_match(serial, "model:", t->model, true) || - qual_match(serial, "device:", t->device, false)) { + if (t->MatchesTarget(serial)) { if (result) { *error_out = "more than one device"; if (is_ambiguous) *is_ambiguous = true; @@ -837,6 +834,43 @@ void atransport::RunDisconnects() { disconnects_.clear(); } +bool atransport::MatchesTarget(const std::string& target) const { + if (serial) { + if (target == serial) { + return true; + } else if (type == kTransportLocal) { + // Local transports can match [tcp:|udp:][:port]. + const char* local_target_ptr = target.c_str(); + + // For fastboot compatibility, ignore protocol prefixes. + if (android::base::StartsWith(target, "tcp:") || + android::base::StartsWith(target, "udp:")) { + local_target_ptr += 4; + } + + // Parse our |serial| and the given |target| to check if the hostnames and ports match. + std::string serial_host, error; + int serial_port = -1; + if (android::base::ParseNetAddress(serial, &serial_host, &serial_port, nullptr, + &error)) { + // |target| may omit the port to default to ours. + std::string target_host; + int target_port = serial_port; + if (android::base::ParseNetAddress(local_target_ptr, &target_host, &target_port, + nullptr, &error) && + serial_host == target_host && serial_port == target_port) { + return true; + } + } + } + } + + return (devpath && target == devpath) || + qual_match(target.c_str(), "product:", product, false) || + qual_match(target.c_str(), "model:", model, true) || + qual_match(target.c_str(), "device:", device, false); +} + #if ADB_HOST static void append_transport_info(std::string* result, const char* key, diff --git a/adb/transport.h b/adb/transport.h index 4c0c00887..5857249db 100644 --- a/adb/transport.h +++ b/adb/transport.h @@ -107,6 +107,21 @@ public: void RemoveDisconnect(adisconnect* disconnect); void RunDisconnects(); + // Returns true if |target| matches this transport. A matching |target| can be any of: + // * + // * + // * product: + // * model: + // * device: + // + // If this is a local transport, serial will also match [tcp:|udp:][:port] targets. + // For example, serial "100.100.100.100:5555" would match any of: + // * 100.100.100.100 + // * tcp:100.100.100.100 + // * udp:100.100.100.100:5555 + // This is to make it easier to use the same network target for both fastboot and adb. + bool MatchesTarget(const std::string& target) const; + private: // A set of features transmitted in the banner with the initial connection. // This is stored in the banner as 'features=feature0,feature1,etc'. diff --git a/adb/transport_test.cpp b/adb/transport_test.cpp index 1bdea2a57..2028eccbb 100644 --- a/adb/transport_test.cpp +++ b/adb/transport_test.cpp @@ -218,3 +218,60 @@ TEST(transport, parse_banner_features) { ASSERT_EQ(std::string("bar"), t.model); ASSERT_EQ(std::string("baz"), t.device); } + +TEST(transport, test_matches_target) { + std::string serial = "foo"; + std::string devpath = "/path/to/bar"; + std::string product = "test_product"; + std::string model = "test_model"; + std::string device = "test_device"; + + atransport t; + t.serial = &serial[0]; + t.devpath = &devpath[0]; + t.product = &product[0]; + t.model = &model[0]; + t.device = &device[0]; + + // These tests should not be affected by the transport type. + for (TransportType type : {kTransportAny, kTransportLocal}) { + t.type = type; + + EXPECT_TRUE(t.MatchesTarget(serial)); + EXPECT_TRUE(t.MatchesTarget(devpath)); + EXPECT_TRUE(t.MatchesTarget("product:" + product)); + EXPECT_TRUE(t.MatchesTarget("model:" + model)); + EXPECT_TRUE(t.MatchesTarget("device:" + device)); + + // Product, model, and device don't match without the prefix. + EXPECT_FALSE(t.MatchesTarget(product)); + EXPECT_FALSE(t.MatchesTarget(model)); + EXPECT_FALSE(t.MatchesTarget(device)); + } +} + +TEST(transport, test_matches_target_local) { + std::string serial = "100.100.100.100:5555"; + + atransport t; + t.serial = &serial[0]; + + // Network address matching should only be used for local transports. + for (TransportType type : {kTransportAny, kTransportLocal}) { + t.type = type; + bool should_match = (type == kTransportLocal); + + EXPECT_EQ(should_match, t.MatchesTarget("100.100.100.100")); + EXPECT_EQ(should_match, t.MatchesTarget("tcp:100.100.100.100")); + EXPECT_EQ(should_match, t.MatchesTarget("tcp:100.100.100.100:5555")); + EXPECT_EQ(should_match, t.MatchesTarget("udp:100.100.100.100")); + EXPECT_EQ(should_match, t.MatchesTarget("udp:100.100.100.100:5555")); + + // Wrong protocol, hostname, or port should never match. + EXPECT_FALSE(t.MatchesTarget("100.100.100")); + EXPECT_FALSE(t.MatchesTarget("100.100.100.100:")); + EXPECT_FALSE(t.MatchesTarget("100.100.100.100:-1")); + EXPECT_FALSE(t.MatchesTarget("100.100.100.100:5554")); + EXPECT_FALSE(t.MatchesTarget("abc:100.100.100.100")); + } +}