diff --git a/adb/client/usb_dispatch.cpp b/adb/client/usb_dispatch.cpp index bfc8e164b..710a3ce85 100644 --- a/adb/client/usb_dispatch.cpp +++ b/adb/client/usb_dispatch.cpp @@ -48,3 +48,9 @@ void usb_kick(usb_handle* h) { should_use_libusb() ? libusb::usb_kick(reinterpret_cast(h)) : native::usb_kick(reinterpret_cast(h)); } + +size_t usb_get_max_packet_size(usb_handle* h) { + return should_use_libusb() + ? libusb::usb_get_max_packet_size(reinterpret_cast(h)) + : native::usb_get_max_packet_size(reinterpret_cast(h)); +} diff --git a/adb/client/usb_libusb.cpp b/adb/client/usb_libusb.cpp index fec4742b2..d39884ac7 100644 --- a/adb/client/usb_libusb.cpp +++ b/adb/client/usb_libusb.cpp @@ -91,7 +91,7 @@ namespace libusb { struct usb_handle : public ::usb_handle { usb_handle(const std::string& device_address, const std::string& serial, unique_device_handle&& device_handle, uint8_t interface, uint8_t bulk_in, - uint8_t bulk_out, size_t zero_mask) + uint8_t bulk_out, size_t zero_mask, size_t max_packet_size) : device_address(device_address), serial(serial), closing(false), @@ -100,7 +100,8 @@ struct usb_handle : public ::usb_handle { write("write", zero_mask, true), interface(interface), bulk_in(bulk_in), - bulk_out(bulk_out) {} + bulk_out(bulk_out), + max_packet_size(max_packet_size) {} ~usb_handle() { Close(); @@ -143,6 +144,8 @@ struct usb_handle : public ::usb_handle { uint8_t interface; uint8_t bulk_in; uint8_t bulk_out; + + size_t max_packet_size; }; static auto& usb_handles = *new std::unordered_map>(); @@ -206,6 +209,7 @@ static void poll_for_devices() { size_t interface_num; uint16_t zero_mask; uint8_t bulk_in = 0, bulk_out = 0; + size_t packet_size = 0; bool found_adb = false; for (interface_num = 0; interface_num < config->bNumInterfaces; ++interface_num) { @@ -252,6 +256,14 @@ static void poll_for_devices() { found_in = true; bulk_in = endpoint_addr; } + + size_t endpoint_packet_size = endpoint_desc.wMaxPacketSize; + CHECK(endpoint_packet_size != 0); + if (packet_size == 0) { + packet_size = endpoint_packet_size; + } else { + CHECK(packet_size == endpoint_packet_size); + } } if (found_in && found_out) { @@ -280,7 +292,7 @@ static void poll_for_devices() { } libusb_device_handle* handle_raw; - rc = libusb_open(list[i], &handle_raw); + rc = libusb_open(device, &handle_raw); if (rc != 0) { LOG(WARNING) << "failed to open usb device at " << device_address << ": " << libusb_error_name(rc); @@ -324,9 +336,9 @@ static void poll_for_devices() { } } - auto result = - std::make_unique(device_address, device_serial, std::move(handle), - interface_num, bulk_in, bulk_out, zero_mask); + auto result = std::make_unique(device_address, device_serial, + std::move(handle), interface_num, bulk_in, + bulk_out, zero_mask, packet_size); usb_handle* usb_handle_raw = result.get(); { @@ -507,4 +519,10 @@ int usb_close(usb_handle* h) { void usb_kick(usb_handle* h) { h->Close(); } + +size_t usb_get_max_packet_size(usb_handle* h) { + CHECK(h->max_packet_size != 0); + return h->max_packet_size; +} + } // namespace libusb diff --git a/adb/client/usb_linux.cpp b/adb/client/usb_linux.cpp index 6efed274b..f9ba7cbc2 100644 --- a/adb/client/usb_linux.cpp +++ b/adb/client/usb_linux.cpp @@ -65,6 +65,7 @@ struct usb_handle : public ::usb_handle { unsigned char ep_in; unsigned char ep_out; + size_t max_packet_size; unsigned zero_mask; unsigned writeable = 1; @@ -120,9 +121,9 @@ static inline bool contains_non_digit(const char* name) { } static void find_usb_device(const std::string& base, - void (*register_device_callback) - (const char*, const char*, unsigned char, unsigned char, int, int, unsigned)) -{ + void (*register_device_callback)(const char*, const char*, + unsigned char, unsigned char, int, int, + unsigned, size_t)) { std::unique_ptr bus_dir(opendir(base.c_str()), closedir); if (!bus_dir) return; @@ -144,6 +145,7 @@ static void find_usb_device(const std::string& base, struct usb_interface_descriptor* interface; struct usb_endpoint_descriptor *ep1, *ep2; unsigned zero_mask = 0; + size_t max_packet_size = 0; unsigned vid, pid; if (contains_non_digit(de->d_name)) continue; @@ -251,7 +253,8 @@ static void find_usb_device(const std::string& base, continue; } /* aproto 01 needs 0 termination */ - if(interface->bInterfaceProtocol == 0x01) { + if (interface->bInterfaceProtocol == 0x01) { + max_packet_size = ep1->wMaxPacketSize; zero_mask = ep1->wMaxPacketSize - 1; } @@ -281,9 +284,9 @@ static void find_usb_device(const std::string& base, } } - register_device_callback(dev_name.c_str(), devpath, - local_ep_in, local_ep_out, - interface->bInterfaceNumber, device->iSerialNumber, zero_mask); + register_device_callback(dev_name.c_str(), devpath, local_ep_in, + local_ep_out, interface->bInterfaceNumber, + device->iSerialNumber, zero_mask, max_packet_size); break; } } else { @@ -497,10 +500,13 @@ int usb_close(usb_handle* h) { return 0; } -static void register_device(const char* dev_name, const char* dev_path, - unsigned char ep_in, unsigned char ep_out, - int interface, int serial_index, - unsigned zero_mask) { +size_t usb_get_max_packet_size(usb_handle* h) { + return h->max_packet_size; +} + +static void register_device(const char* dev_name, const char* dev_path, unsigned char ep_in, + unsigned char ep_out, int interface, int serial_index, + unsigned zero_mask, size_t max_packet_size) { // Since Linux will not reassign the device ID (and dev_name) as long as the // device is open, we can add to the list here once we open it and remove // from the list when we're finally closed and everything will work out @@ -523,6 +529,7 @@ static void register_device(const char* dev_name, const char* dev_path, usb->ep_in = ep_in; usb->ep_out = ep_out; usb->zero_mask = zero_mask; + usb->max_packet_size = max_packet_size; // Initialize mark so we don't get garbage collected after the device scan. usb->mark = true; diff --git a/adb/client/usb_osx.cpp b/adb/client/usb_osx.cpp index fcd0bc044..e4a543bba 100644 --- a/adb/client/usb_osx.cpp +++ b/adb/client/usb_osx.cpp @@ -51,15 +51,21 @@ struct usb_handle UInt8 bulkOut; IOUSBInterfaceInterface190** interface; unsigned int zero_mask; + size_t max_packet_size; // For garbage collecting disconnected devices. bool mark; std::string devpath; std::atomic dead; - usb_handle() : bulkIn(0), bulkOut(0), interface(nullptr), - zero_mask(0), mark(false), dead(false) { - } + usb_handle() + : bulkIn(0), + bulkOut(0), + interface(nullptr), + zero_mask(0), + max_packet_size(0), + mark(false), + dead(false) {} }; static std::atomic usb_inited_flag; @@ -390,6 +396,7 @@ CheckInterface(IOUSBInterfaceInterface190 **interface, UInt16 vendor, UInt16 pro } handle->zero_mask = maxPacketSize - 1; + handle->max_packet_size = maxPacketSize; } handle->interface = interface; @@ -558,4 +565,9 @@ void usb_kick(usb_handle *handle) { std::lock_guard lock_guard(g_usb_handles_mutex); usb_kick_locked(handle); } + +size_t usb_get_max_packet_size(usb_handle* handle) { + return handle->max_packet_size; +} + } // namespace native diff --git a/adb/client/usb_windows.cpp b/adb/client/usb_windows.cpp index ee7f8024f..ec55b0e2a 100644 --- a/adb/client/usb_windows.cpp +++ b/adb/client/usb_windows.cpp @@ -65,6 +65,9 @@ struct usb_handle { /// Interface name wchar_t* interface_name; + /// Maximum packet size. + unsigned max_packet_size; + /// Mask for determining when to use zero length packets unsigned zero_mask; }; @@ -522,6 +525,10 @@ int usb_close(usb_handle* handle) { return 0; } +size_t usb_get_max_packet_size(usb_handle* handle) { + return handle->max_packet_size; +} + int recognized_device(usb_handle* handle) { if (NULL == handle) return 0; @@ -557,6 +564,7 @@ int recognized_device(usb_handle* handle) { AdbEndpointInformation endpoint_info; // assuming zero is a valid bulk endpoint ID if (AdbGetEndpointInformation(handle->adb_interface, 0, &endpoint_info)) { + handle->max_packet_size = endpoint_info.max_packet_size; handle->zero_mask = endpoint_info.max_packet_size - 1; D("device zero_mask: 0x%x", handle->zero_mask); } else { diff --git a/adb/test_device.py b/adb/test_device.py index a30972e54..e44cc83f0 100644 --- a/adb/test_device.py +++ b/adb/test_device.py @@ -1259,6 +1259,26 @@ class DeviceOfflineTest(DeviceTest): self.assertEqual(self._get_device_state(serialno), 'device') + def test_packet_size_regression(self): + """Test for http://b/37783561 + + Receiving packets of a length divisible by 512 but not 1024 resulted in + the adb client waiting indefinitely for more input. + """ + # The values that trigger things are 507 (512 - 5 bytes from shell protocol) + 1024*n + # Probe some surrounding values as well, for the hell of it. + for length in [506, 507, 508, 1018, 1019, 1020, 1530, 1531, 1532]: + cmd = ['dd', 'if=/dev/zero', 'bs={}'.format(length), 'count=1', '2>/dev/null;' + 'echo', 'foo'] + rc, stdout, _ = self.device.shell_nocheck(cmd) + + self.assertEqual(0, rc) + + # Output should be '\0' * length, followed by "foo\n" + self.assertEqual(length, len(stdout) - 4) + self.assertEqual(stdout, "\0" * length + "foo\n") + + def main(): random.seed(0) if len(adb.get_devices()) > 0: diff --git a/adb/transport_usb.cpp b/adb/transport_usb.cpp index ce419b88d..885d7230e 100644 --- a/adb/transport_usb.cpp +++ b/adb/transport_usb.cpp @@ -27,57 +27,43 @@ #if ADB_HOST -static constexpr size_t MAX_USB_BULK_PACKET_SIZE = 1024u; - -// Call usb_read using a buffer having a multiple of MAX_USB_BULK_PACKET_SIZE bytes +// Call usb_read using a buffer having a multiple of usb_get_max_packet_size() bytes // to avoid overflow. See http://libusb.sourceforge.net/api-1.0/packetoverflow.html. static int UsbReadMessage(usb_handle* h, amessage* msg) { D("UsbReadMessage"); - char buffer[MAX_USB_BULK_PACKET_SIZE]; - int n = usb_read(h, buffer, sizeof(buffer)); - if (n == sizeof(*msg)) { - memcpy(msg, buffer, sizeof(*msg)); + + size_t usb_packet_size = usb_get_max_packet_size(h); + CHECK(usb_packet_size >= sizeof(*msg)); + CHECK(usb_packet_size < 4096); + + char buffer[4096]; + int n = usb_read(h, buffer, usb_packet_size); + if (n != sizeof(*msg)) { + D("usb_read returned unexpected length %d (expected %zu)", n, sizeof(*msg)); + return -1; } + memcpy(msg, buffer, sizeof(*msg)); return n; } -// Call usb_read using a buffer having a multiple of MAX_USB_BULK_PACKET_SIZE bytes +// Call usb_read using a buffer having a multiple of usb_get_max_packet_size() bytes // to avoid overflow. See http://libusb.sourceforge.net/api-1.0/packetoverflow.html. static int UsbReadPayload(usb_handle* h, apacket* p) { - D("UsbReadPayload"); - size_t need_size = p->msg.data_length; - size_t data_pos = 0u; - while (need_size > 0u) { - int n = 0; - if (data_pos + MAX_USB_BULK_PACKET_SIZE <= sizeof(p->data)) { - // Read directly to p->data. - size_t rem_size = need_size % MAX_USB_BULK_PACKET_SIZE; - size_t direct_read_size = need_size - rem_size; - if (rem_size && - data_pos + direct_read_size + MAX_USB_BULK_PACKET_SIZE <= sizeof(p->data)) { - direct_read_size += MAX_USB_BULK_PACKET_SIZE; - } - n = usb_read(h, &p->data[data_pos], direct_read_size); - if (n < 0) { - D("usb_read(size %zu) failed", direct_read_size); - return n; - } - } else { - // Read indirectly using a buffer. - char buffer[MAX_USB_BULK_PACKET_SIZE]; - n = usb_read(h, buffer, sizeof(buffer)); - if (n < 0) { - D("usb_read(size %zu) failed", sizeof(buffer)); - return -1; - } - size_t copy_size = std::min(static_cast(n), need_size); - D("usb read %d bytes, need %zu bytes, copy %zu bytes", n, need_size, copy_size); - memcpy(&p->data[data_pos], buffer, copy_size); - } - data_pos += n; - need_size -= std::min(static_cast(n), need_size); + D("UsbReadPayload(%d)", p->msg.data_length); + + size_t usb_packet_size = usb_get_max_packet_size(h); + CHECK(sizeof(p->data) % usb_packet_size == 0); + + // Round the data length up to the nearest packet size boundary. + // The device won't send a zero packet for packet size aligned payloads, + // so don't read any more packets than needed. + size_t len = p->msg.data_length; + size_t rem_size = len % usb_packet_size; + if (rem_size) { + len += usb_packet_size - rem_size; } - return static_cast(data_pos); + CHECK(len <= sizeof(p->data)); + return usb_read(h, &p->data, len); } static int remote_read(apacket* p, atransport* t) { diff --git a/adb/usb.h b/adb/usb.h index ba70de43e..e867ec8a3 100644 --- a/adb/usb.h +++ b/adb/usb.h @@ -16,6 +16,8 @@ #pragma once +#include + // USB host/client interface. #define ADB_USB_INTERFACE(handle_ref_type) \ @@ -23,7 +25,8 @@ int usb_write(handle_ref_type h, const void* data, int len); \ int usb_read(handle_ref_type h, void* data, int len); \ int usb_close(handle_ref_type h); \ - void usb_kick(handle_ref_type h) + void usb_kick(handle_ref_type h); \ + size_t usb_get_max_packet_size(handle_ref_type) #if defined(_WIN32) || !ADB_HOST // Windows and the daemon have a single implementation.