diff --git a/adb/adb.cpp b/adb/adb.cpp index ee3503bb3..ae8020e81 100644 --- a/adb/adb.cpp +++ b/adb/adb.cpp @@ -105,31 +105,27 @@ void fatal_errno(const char* fmt, ...) { } uint32_t calculate_apacket_checksum(const apacket* p) { - const unsigned char* x = reinterpret_cast(p->data); uint32_t sum = 0; - size_t count = p->msg.data_length; - - while (count-- > 0) { - sum += *x++; + for (size_t i = 0; i < p->msg.data_length; ++i) { + sum += static_cast(p->payload[i]); } - return sum; } apacket* get_apacket(void) { - apacket* p = reinterpret_cast(malloc(sizeof(apacket))); + apacket* p = new apacket(); if (p == nullptr) { fatal("failed to allocate an apacket"); } - memset(p, 0, sizeof(apacket) - MAX_PAYLOAD); + memset(&p->msg, 0, sizeof(p->msg)); return p; } void put_apacket(apacket *p) { - free(p); + delete p; } void handle_online(atransport *t) @@ -155,8 +151,7 @@ void handle_offline(atransport *t) #define DUMPMAX 32 void print_packet(const char *label, apacket *p) { - char *tag; - char *x; + const char* tag; unsigned count; switch(p->msg.command){ @@ -173,15 +168,15 @@ void print_packet(const char *label, apacket *p) fprintf(stderr, "%s: %s %08x %08x %04x \"", label, tag, p->msg.arg0, p->msg.arg1, p->msg.data_length); count = p->msg.data_length; - x = (char*) p->data; - if(count > DUMPMAX) { + const char* x = p->payload.data(); + if (count > DUMPMAX) { count = DUMPMAX; tag = "\n"; } else { tag = "\"\n"; } - while(count-- > 0){ - if((*x >= ' ') && (*x < 127)) { + while (count-- > 0) { + if ((*x >= ' ') && (*x < 127)) { fputc(*x, stderr); } else { fputc('.', stderr); @@ -254,8 +249,8 @@ void send_connect(atransport* t) { << connection_str.length() << ")"; } - memcpy(cp->data, connection_str.c_str(), connection_str.length()); - cp->msg.data_length = connection_str.length(); + cp->payload = std::move(connection_str); + cp->msg.data_length = cp->payload.size(); send_packet(cp, t); } @@ -329,9 +324,7 @@ static void handle_new_connection(atransport* t, apacket* p) { } t->update_version(p->msg.arg0, p->msg.arg1); - std::string banner(reinterpret_cast(p->data), - p->msg.data_length); - parse_banner(banner, t); + parse_banner(p->payload, t); #if ADB_HOST handle_online(t); @@ -354,6 +347,7 @@ void handle_packet(apacket *p, atransport *t) ((char*) (&(p->msg.command)))[2], ((char*) (&(p->msg.command)))[3]); print_packet("recv", p); + CHECK_EQ(p->payload.size(), p->msg.data_length); switch(p->msg.command){ case A_SYNC: @@ -380,11 +374,11 @@ void handle_packet(apacket *p, atransport *t) if (t->GetConnectionState() == kCsOffline) { t->SetConnectionState(kCsUnauthorized); } - send_auth_response(p->data, p->msg.data_length, t); + send_auth_response(p->payload.data(), p->msg.data_length, t); break; #else case ADB_AUTH_SIGNATURE: - if (adbd_auth_verify(t->token, sizeof(t->token), p->data, p->msg.data_length)) { + if (adbd_auth_verify(t->token, sizeof(t->token), p->payload)) { adbd_auth_verified(t); t->failed_auth_attempts = 0; } else { @@ -394,7 +388,7 @@ void handle_packet(apacket *p, atransport *t) break; case ADB_AUTH_RSAPUBLICKEY: - adbd_auth_confirm_key(p->data, p->msg.data_length, t); + adbd_auth_confirm_key(p->payload.data(), p->msg.data_length, t); break; #endif default: @@ -406,9 +400,7 @@ void handle_packet(apacket *p, atransport *t) case A_OPEN: /* OPEN(local-id, 0, "destination") */ if (t->online && p->msg.arg0 != 0 && p->msg.arg1 == 0) { - char *name = (char*) p->data; - name[p->msg.data_length > 0 ? p->msg.data_length - 1 : 0] = 0; - asocket* s = create_local_service_socket(name, t); + asocket* s = create_local_service_socket(p->payload.c_str(), t); if (s == nullptr) { send_close(0, p->msg.arg0, t); } else { @@ -474,11 +466,7 @@ void handle_packet(apacket *p, atransport *t) asocket* s = find_local_socket(p->msg.arg1, p->msg.arg0); if (s) { unsigned rid = p->msg.arg0; - - // TODO: Convert apacket::data to a type that we can move out of. - std::string copy(p->data, p->data + p->msg.data_length); - - if (s->enqueue(s, std::move(copy)) == 0) { + if (s->enqueue(s, std::move(p->payload)) == 0) { D("Enqueue the socket"); send_ready(s->id, rid, t); } diff --git a/adb/adb.h b/adb/adb.h index c9c635a8d..a6d04631d 100644 --- a/adb/adb.h +++ b/adb/adb.h @@ -74,7 +74,7 @@ struct amessage { struct apacket { amessage msg; - char data[MAX_PAYLOAD]; + std::string payload; }; uint32_t calculate_apacket_checksum(const apacket* packet); diff --git a/adb/adb_auth.h b/adb/adb_auth.h index a6f224f00..715e04f2c 100644 --- a/adb/adb_auth.h +++ b/adb/adb_auth.h @@ -49,7 +49,7 @@ void adbd_auth_init(void); void adbd_auth_verified(atransport *t); void adbd_cloexec_auth_socket(); -bool adbd_auth_verify(const char* token, size_t token_size, const char* sig, int sig_len); +bool adbd_auth_verify(const char* token, size_t token_size, const std::string& sig); void adbd_auth_confirm_key(const char* data, size_t len, atransport* t); void send_auth_request(atransport *t); diff --git a/adb/adb_auth_host.cpp b/adb/adb_auth_host.cpp index 365bf77a9..c3aef16d4 100644 --- a/adb/adb_auth_host.cpp +++ b/adb/adb_auth_host.cpp @@ -299,20 +299,25 @@ std::deque> adb_auth_get_private_keys() { return result; } -static int adb_auth_sign(RSA* key, const char* token, size_t token_size, char* sig) { +static std::string adb_auth_sign(RSA* key, const char* token, size_t token_size) { if (token_size != TOKEN_SIZE) { D("Unexpected token size %zd", token_size); return 0; } + std::string result; + result.resize(MAX_PAYLOAD); + unsigned int len; if (!RSA_sign(NID_sha1, reinterpret_cast(token), token_size, - reinterpret_cast(sig), &len, key)) { - return 0; + reinterpret_cast(&result[0]), &len, key)) { + return std::string(); } + result.resize(len); + D("adb_auth_sign len=%d", len); - return (int)len; + return result; } std::string adb_auth_get_userkey() { @@ -446,13 +451,14 @@ static void send_auth_publickey(atransport* t) { } apacket* p = get_apacket(); - memcpy(p->data, key.c_str(), key.size() + 1); - p->msg.command = A_AUTH; p->msg.arg0 = ADB_AUTH_RSAPUBLICKEY; + p->payload = std::move(key); + // adbd expects a null-terminated string. - p->msg.data_length = key.size() + 1; + p->payload.push_back('\0'); + p->msg.data_length = p->payload.size(); send_packet(p, t); } @@ -467,8 +473,8 @@ void send_auth_response(const char* token, size_t token_size, atransport* t) { LOG(INFO) << "Calling send_auth_response"; apacket* p = get_apacket(); - int ret = adb_auth_sign(key.get(), token, token_size, p->data); - if (!ret) { + std::string result = adb_auth_sign(key.get(), token, token_size); + if (result.empty()) { D("Error signing the token"); put_apacket(p); return; @@ -476,6 +482,7 @@ void send_auth_response(const char* token, size_t token_size, atransport* t) { p->msg.command = A_AUTH; p->msg.arg0 = ADB_AUTH_SIGNATURE; - p->msg.data_length = ret; + p->payload = std::move(result); + p->msg.data_length = p->payload.size(); send_packet(p, t); } diff --git a/adb/adbd_auth.cpp b/adb/adbd_auth.cpp index 3488ad194..3fd2b3194 100644 --- a/adb/adbd_auth.cpp +++ b/adb/adbd_auth.cpp @@ -46,7 +46,7 @@ static bool needs_retry = false; bool auth_required = true; -bool adbd_auth_verify(const char* token, size_t token_size, const char* sig, int sig_len) { +bool adbd_auth_verify(const char* token, size_t token_size, const std::string& sig) { static constexpr const char* key_paths[] = { "/adb_keys", "/data/misc/adb/adb_keys", nullptr }; for (const auto& path : key_paths) { @@ -80,7 +80,8 @@ bool adbd_auth_verify(const char* token, size_t token_size, const char* sig, int bool verified = (RSA_verify(NID_sha1, reinterpret_cast(token), token_size, - reinterpret_cast(sig), sig_len, key) == 1); + reinterpret_cast(sig.c_str()), sig.size(), + key) == 1); RSA_free(key); if (verified) return true; } @@ -210,10 +211,10 @@ void send_auth_request(atransport* t) { } apacket* p = get_apacket(); - memcpy(p->data, t->token, sizeof(t->token)); p->msg.command = A_AUTH; p->msg.arg0 = ADB_AUTH_TOKEN; p->msg.data_length = sizeof(t->token); + p->payload.assign(t->token, t->token + sizeof(t->token)); send_packet(p, t); } diff --git a/adb/sockets.cpp b/adb/sockets.cpp index 307cbfe0e..0007fed7b 100644 --- a/adb/sockets.cpp +++ b/adb/sockets.cpp @@ -413,15 +413,15 @@ static int remote_socket_enqueue(asocket* s, std::string data) { p->msg.command = A_WRTE; p->msg.arg0 = s->peer->id; p->msg.arg1 = s->id; - p->msg.data_length = data.size(); - if (data.size() > sizeof(p->data)) { + if (data.size() > MAX_PAYLOAD) { put_apacket(p); return -1; } - // TODO: Convert apacket::data to a type that we can move into. - memcpy(p->data, data.data(), data.size()); + p->payload = std::move(data); + p->msg.data_length = p->payload.size(); + send_packet(p, s->transport); return 1; } @@ -482,17 +482,20 @@ asocket* create_remote_socket(unsigned id, atransport* t) { void connect_to_remote(asocket* s, const char* destination) { D("Connect_to_remote call RS(%d) fd=%d", s->id, s->fd); apacket* p = get_apacket(); - size_t len = strlen(destination) + 1; - - if (len > (s->get_max_payload() - 1)) { - fatal("destination oversized"); - } D("LS(%d): connect('%s')", s->id, destination); p->msg.command = A_OPEN; p->msg.arg0 = s->id; - p->msg.data_length = len; - strcpy((char*)p->data, destination); + + // adbd expects a null-terminated string. + p->payload = destination; + p->payload.push_back('\0'); + p->msg.data_length = p->payload.size(); + + if (p->msg.data_length > s->get_max_payload()) { + fatal("destination oversized"); + } + send_packet(p, s->transport); } diff --git a/adb/transport.cpp b/adb/transport.cpp index 9ae129751..14888ab70 100644 --- a/adb/transport.cpp +++ b/adb/transport.cpp @@ -72,12 +72,14 @@ bool FdConnection::Read(apacket* packet) { return false; } - if (packet->msg.data_length > sizeof(packet->data)) { + if (packet->msg.data_length > MAX_PAYLOAD) { D("remote local: read overflow (data length = %" PRIu32 ")", packet->msg.data_length); return false; } - if (!ReadFdExactly(fd_.get(), &packet->data, packet->msg.data_length)) { + packet->payload.resize(packet->msg.data_length); + + if (!ReadFdExactly(fd_.get(), &packet->payload[0], packet->payload.size())) { D("remote local: terminated (data)"); return false; } @@ -86,13 +88,18 @@ bool FdConnection::Read(apacket* packet) { } bool FdConnection::Write(apacket* packet) { - uint32_t length = packet->msg.data_length; - - if (!WriteFdExactly(fd_.get(), &packet->msg, sizeof(amessage) + length)) { + if (!WriteFdExactly(fd_.get(), &packet->msg, sizeof(packet->msg))) { D("remote local: write terminated"); return false; } + if (packet->msg.data_length) { + if (!WriteFdExactly(fd_.get(), &packet->payload[0], packet->msg.data_length)) { + D("remote local: write terminated"); + return false; + } + } + return true; } @@ -133,7 +140,7 @@ static std::string dump_packet(const char* name, const char* func, apacket* p) { std::string result = android::base::StringPrintf("%s: %s: [%s] arg0=%s arg1=%s (len=%d) ", name, func, cmd, arg0, arg1, len); - result += dump_hex(p->data, len); + result += dump_hex(p->payload.data(), p->payload.size()); return result; } @@ -191,9 +198,10 @@ static void transport_socket_events(int fd, unsigned events, void* _t) { apacket* p = 0; if (read_packet(fd, t->serial, &p)) { D("%s: failed to read packet from transport socket on fd %d", t->serial, fd); - } else { - handle_packet(p, (atransport*)_t); + return; } + + handle_packet(p, (atransport*)_t); } } @@ -243,6 +251,7 @@ static void read_transport_thread(void* _t) { p->msg.arg0 = 1; p->msg.arg1 = ++(t->sync_token); p->msg.magic = A_SYNC ^ 0xffffffff; + D("sending SYNC packet (len = %u, payload.size() = %zu)", p->msg.data_length, p->payload.size()); if (write_packet(t->fd, t->serial, &p)) { put_apacket(p); D("%s: failed to write SYNC packet", t->serial); @@ -336,6 +345,13 @@ static void write_transport_thread(void* _t) { if (active) { D("%s: transport got packet, sending to remote", t->serial); ATRACE_NAME("write_transport write_remote"); + + // Allow sending the payload's implicit null terminator. + if (p->msg.data_length != p->payload.size()) { + LOG(FATAL) << "packet data length doesn't match payload: msg.data_length = " + << p->msg.data_length << ", payload.size() = " << p->payload.size(); + } + if (t->Write(p) != 0) { D("%s: remote write failed for transport", t->serial); put_apacket(p); diff --git a/adb/transport_usb.cpp b/adb/transport_usb.cpp index a1086999d..d7565f63d 100644 --- a/adb/transport_usb.cpp +++ b/adb/transport_usb.cpp @@ -61,13 +61,12 @@ static int UsbReadMessage(usb_handle* h, amessage* msg) { static int UsbReadPayload(usb_handle* h, apacket* p) { D("UsbReadPayload(%d)", p->msg.data_length); - if (p->msg.data_length > sizeof(p->data)) { + if (p->msg.data_length > MAX_PAYLOAD) { return -1; } #if CHECK_PACKET_OVERFLOW size_t usb_packet_size = usb_get_max_packet_size(h); - CHECK_EQ(0ULL, sizeof(p->data) % usb_packet_size); // Round the data length up to the nearest packet size boundary. // The device won't send a zero packet for packet size aligned payloads, @@ -77,10 +76,18 @@ static int UsbReadPayload(usb_handle* h, apacket* p) { if (rem_size) { len += usb_packet_size - rem_size; } - CHECK_LE(len, sizeof(p->data)); - return usb_read(h, &p->data, len); + + p->payload.resize(len); + int rc = usb_read(h, &p->payload[0], p->payload.size()); + if (rc != static_cast(p->msg.data_length)) { + return -1; + } + + p->payload.resize(rc); + return rc; #else - return usb_read(h, &p->data, p->msg.data_length); + p->payload.resize(p->msg.data_length); + return usb_read(h, &p->payload[0], p->payload.size()); #endif } @@ -120,12 +127,13 @@ static int remote_read(apacket* p, usb_handle* usb) { } if (p->msg.data_length) { - if (p->msg.data_length > sizeof(p->data)) { + if (p->msg.data_length > MAX_PAYLOAD) { PLOG(ERROR) << "remote usb: read overflow (data length = " << p->msg.data_length << ")"; return -1; } - if (usb_read(usb, p->data, p->msg.data_length)) { + p->payload.resize(p->msg.data_length); + if (usb_read(usb, &p->payload[0], p->payload.size())) { PLOG(ERROR) << "remote usb: terminated (data)"; return -1; } @@ -152,7 +160,7 @@ bool UsbConnection::Write(apacket* packet) { return false; } - if (packet->msg.data_length != 0 && usb_write(handle_, &packet->data, size) != 0) { + if (packet->msg.data_length != 0 && usb_write(handle_, packet->payload.data(), size) != 0) { PLOG(ERROR) << "remote usb: 2 - write terminated"; return false; }