diff --git a/adb/client/transport_mdns.cpp b/adb/client/transport_mdns.cpp index 961121202..a0fc9caf4 100644 --- a/adb/client/transport_mdns.cpp +++ b/adb/client/transport_mdns.cpp @@ -249,6 +249,9 @@ class ResolvedService : public AsyncServiceRef { return false; } + // Remove any services with the same instance name, as it may be a stale registration. + removeDNSService(regType_.c_str(), serviceName_.c_str()); + // Add to the service registry before trying to auto-connect, since socket_spec_connect will // check these registries for the ip address when connecting via mdns instance name. int adbSecureServiceType = serviceIndex(); @@ -268,13 +271,6 @@ class ResolvedService : public AsyncServiceRef { return false; } - if (!services->empty()) { - // Remove the previous resolved service, if any. - services->erase(std::remove_if(services->begin(), services->end(), - [&](std::unique_ptr& service) { - return (serviceName_ == service->serviceName()); - })); - } services->push_back(std::unique_ptr(this)); if (adb_DNSServiceShouldAutoConnect(regType_.c_str(), serviceName_.c_str())) { @@ -327,6 +323,8 @@ class ResolvedService : public AsyncServiceRef { static bool connectByServiceName(const ServiceRegistry& services, const std::string& service_name); + static void removeDNSService(const char* regType, const char* serviceName); + private: int clientVersion_ = ADB_SECURE_CLIENT_VERSION; std::string addr_format_; @@ -396,6 +394,37 @@ bool ResolvedService::connectByServiceName(const ServiceRegistry& services, return false; } +// static +void ResolvedService::removeDNSService(const char* regType, const char* serviceName) { + D("%s: regType=[%s] serviceName=[%s]", __func__, regType, serviceName); + int index = adb_DNSServiceIndexByName(regType); + ServiceRegistry* services; + switch (index) { + case kADBTransportServiceRefIndex: + services = sAdbTransportServices; + break; + case kADBSecurePairingServiceRefIndex: + services = sAdbSecurePairingServices; + break; + case kADBSecureConnectServiceRefIndex: + services = sAdbSecureConnectServices; + break; + default: + return; + } + + if (services->empty()) { + return; + } + + std::string sName(serviceName); + services->erase(std::remove_if(services->begin(), services->end(), + [&sName](std::unique_ptr& service) { + return (sName == service->serviceName()); + }), + services->end()); +} + void adb_secure_foreach_pairing_service(const char* service_name, adb_secure_foreach_service_callback cb) { ResolvedService::forEachService(*ResolvedService::sAdbSecurePairingServices, service_name, cb); @@ -481,35 +510,6 @@ class DiscoveredService : public AsyncServiceRef { std::string regType_; }; -static void adb_RemoveDNSService(const char* regType, const char* serviceName) { - D("%s: regType=[%s] serviceName=[%s]", __func__, regType, serviceName); - int index = adb_DNSServiceIndexByName(regType); - ResolvedService::ServiceRegistry* services; - switch (index) { - case kADBTransportServiceRefIndex: - services = ResolvedService::sAdbTransportServices; - break; - case kADBSecurePairingServiceRefIndex: - services = ResolvedService::sAdbSecurePairingServices; - break; - case kADBSecureConnectServiceRefIndex: - services = ResolvedService::sAdbSecureConnectServices; - break; - default: - return; - } - - if (services->empty()) { - return; - } - - std::string sName(serviceName); - services->erase(std::remove_if(services->begin(), services->end(), - [&sName](std::unique_ptr& service) { - return (sName == service->serviceName()); - })); -} - // Returns the version the device wanted to advertise, // or -1 if parsing fails. static int parse_version_from_txt_record(uint16_t txtLen, const unsigned char* txtRecord) { @@ -612,7 +612,7 @@ static void DNSSD_API on_service_browsed(DNSServiceRef sdRef, DNSServiceFlags fl } else { D("%s: Discover lost serviceName=[%s] regtype=[%s] domain=[%s]", __func__, serviceName, regtype, domain); - adb_RemoveDNSService(regtype, serviceName); + ResolvedService::removeDNSService(regtype, serviceName); } } diff --git a/adb/test_adb.py b/adb/test_adb.py index b9f0d5487..a32d8757d 100755 --- a/adb/test_adb.py +++ b/adb/test_adb.py @@ -618,21 +618,37 @@ def zeroconf_register_service(zeroconf_ctx, info): finally: zeroconf_ctx.unregister_service(info) +@contextlib.contextmanager +def zeroconf_register_services(zeroconf_ctx, infos): + """Context manager for multiple zeroconf services + + Registers all services given and unregisters all on cleanup. Returns the ServiceInfo + list supplied. + """ + + try: + for info in infos: + zeroconf_ctx.register_service(info) + yield infos + finally: + for info in infos: + zeroconf_ctx.unregister_service(info) + """Should match the service names listed in adb_mdns.h""" class MdnsTest: """Tests for adb mdns.""" + @staticmethod + def _mdns_services(port): + output = subprocess.check_output(["adb", "-P", str(port), "mdns", "services"]) + return [x.split("\t") for x in output.decode("utf8").strip().splitlines()[1:]] + + @staticmethod + def _devices(port): + output = subprocess.check_output(["adb", "-P", str(port), "devices"]) + return [x.split("\t") for x in output.decode("utf8").strip().splitlines()[1:]] + class Base(unittest.TestCase): - @staticmethod - def _mdns_services(port): - output = subprocess.check_output(["adb", "-P", str(port), "mdns", "services"]) - return [x.split("\t") for x in output.decode("utf8").strip().splitlines()[1:]] - - @staticmethod - def _devices(port): - output = subprocess.check_output(["adb", "-P", str(port), "devices"]) - return [x.split("\t") for x in output.decode("utf8").strip().splitlines()[1:]] - @contextlib.contextmanager def _adb_mdns_connect(self, server_port, mdns_instance, serial, should_connect): """Context manager for an ADB connection. @@ -690,6 +706,50 @@ class MdnsTest: self.assertFalse(any((serv_instance in line and serv_type in line) for line in MdnsTest._mdns_services(server_port))) + @unittest.skipIf(not is_zeroconf_installed(), "zeroconf library not installed") + def test_mdns_services_register_unregister_multiple(self): + """Ensure that `adb mdns services` correctly adds and removes multiple services + """ + from zeroconf import IPVersion, ServiceInfo + + with adb_server() as server_port: + output = subprocess.check_output(["adb", "-P", str(server_port), + "mdns", "services"]).strip() + self.assertTrue(output.startswith(b"List of discovered mdns services")) + + """TODO(joshuaduong): Add ipv6 tests once we have it working in adb""" + """Register/Unregister a service""" + with zeroconf_context(IPVersion.V4Only) as zc: + srvs = { + 'mdns_name': ["testservice0", "testservice1", "testservice2"], + 'mdns_type': "_" + self.service_name + "._tcp.", + 'ipaddr': [ + socket.inet_aton("192.168.0.1"), + socket.inet_aton("10.0.0.255"), + socket.inet_aton("172.16.1.100")], + 'port': [10000, 20000, 65535]} + srv_infos = [] + for i in range(len(srvs['mdns_name'])): + srv_infos.append(ServiceInfo( + srvs['mdns_type'] + "local.", + name=srvs['mdns_name'][i] + "." + srvs['mdns_type'] + "local.", + addresses=[srvs['ipaddr'][i]], + port=srvs['port'][i])) + + """ Register all devices, then unregister""" + with zeroconf_register_services(zc, srv_infos) as infos: + """Give adb some time to register the service""" + time.sleep(1) + for i in range(len(srvs['mdns_name'])): + self.assertTrue(any((srvs['mdns_name'][i] in line and srvs['mdns_type'] in line) + for line in MdnsTest._mdns_services(server_port))) + + """Give adb some time to unregister the service""" + time.sleep(1) + for i in range(len(srvs['mdns_name'])): + self.assertFalse(any((srvs['mdns_name'][i] in line and srvs['mdns_type'] in line) + for line in MdnsTest._mdns_services(server_port))) + @unittest.skipIf(not is_zeroconf_installed(), "zeroconf library not installed") def test_mdns_connect(self): """Ensure that `adb connect` by mdns instance name works (for non-pairing services)