Merge "adb: switch apacket over to a std::string payload."

This commit is contained in:
Treehugger Robot 2018-02-24 00:40:35 +00:00 committed by Gerrit Code Review
commit 581a4ceb00
8 changed files with 96 additions and 73 deletions

View File

@ -105,31 +105,27 @@ void fatal_errno(const char* fmt, ...) {
}
uint32_t calculate_apacket_checksum(const apacket* p) {
const unsigned char* x = reinterpret_cast<const unsigned char*>(p->data);
uint32_t sum = 0;
size_t count = p->msg.data_length;
while (count-- > 0) {
sum += *x++;
for (size_t i = 0; i < p->msg.data_length; ++i) {
sum += static_cast<uint8_t>(p->payload[i]);
}
return sum;
}
apacket* get_apacket(void)
{
apacket* p = reinterpret_cast<apacket*>(malloc(sizeof(apacket)));
apacket* p = new apacket();
if (p == nullptr) {
fatal("failed to allocate an apacket");
}
memset(p, 0, sizeof(apacket) - MAX_PAYLOAD);
memset(&p->msg, 0, sizeof(p->msg));
return p;
}
void put_apacket(apacket *p)
{
free(p);
delete p;
}
void handle_online(atransport *t)
@ -155,8 +151,7 @@ void handle_offline(atransport *t)
#define DUMPMAX 32
void print_packet(const char *label, apacket *p)
{
char *tag;
char *x;
const char* tag;
unsigned count;
switch(p->msg.command){
@ -173,15 +168,15 @@ void print_packet(const char *label, apacket *p)
fprintf(stderr, "%s: %s %08x %08x %04x \"",
label, tag, p->msg.arg0, p->msg.arg1, p->msg.data_length);
count = p->msg.data_length;
x = (char*) p->data;
if(count > DUMPMAX) {
const char* x = p->payload.data();
if (count > DUMPMAX) {
count = DUMPMAX;
tag = "\n";
} else {
tag = "\"\n";
}
while(count-- > 0){
if((*x >= ' ') && (*x < 127)) {
while (count-- > 0) {
if ((*x >= ' ') && (*x < 127)) {
fputc(*x, stderr);
} else {
fputc('.', stderr);
@ -254,8 +249,8 @@ void send_connect(atransport* t) {
<< connection_str.length() << ")";
}
memcpy(cp->data, connection_str.c_str(), connection_str.length());
cp->msg.data_length = connection_str.length();
cp->payload = std::move(connection_str);
cp->msg.data_length = cp->payload.size();
send_packet(cp, t);
}
@ -329,9 +324,7 @@ static void handle_new_connection(atransport* t, apacket* p) {
}
t->update_version(p->msg.arg0, p->msg.arg1);
std::string banner(reinterpret_cast<const char*>(p->data),
p->msg.data_length);
parse_banner(banner, t);
parse_banner(p->payload, t);
#if ADB_HOST
handle_online(t);
@ -354,6 +347,7 @@ void handle_packet(apacket *p, atransport *t)
((char*) (&(p->msg.command)))[2],
((char*) (&(p->msg.command)))[3]);
print_packet("recv", p);
CHECK_EQ(p->payload.size(), p->msg.data_length);
switch(p->msg.command){
case A_SYNC:
@ -380,11 +374,11 @@ void handle_packet(apacket *p, atransport *t)
if (t->GetConnectionState() == kCsOffline) {
t->SetConnectionState(kCsUnauthorized);
}
send_auth_response(p->data, p->msg.data_length, t);
send_auth_response(p->payload.data(), p->msg.data_length, t);
break;
#else
case ADB_AUTH_SIGNATURE:
if (adbd_auth_verify(t->token, sizeof(t->token), p->data, p->msg.data_length)) {
if (adbd_auth_verify(t->token, sizeof(t->token), p->payload)) {
adbd_auth_verified(t);
t->failed_auth_attempts = 0;
} else {
@ -394,7 +388,7 @@ void handle_packet(apacket *p, atransport *t)
break;
case ADB_AUTH_RSAPUBLICKEY:
adbd_auth_confirm_key(p->data, p->msg.data_length, t);
adbd_auth_confirm_key(p->payload.data(), p->msg.data_length, t);
break;
#endif
default:
@ -406,9 +400,7 @@ void handle_packet(apacket *p, atransport *t)
case A_OPEN: /* OPEN(local-id, 0, "destination") */
if (t->online && p->msg.arg0 != 0 && p->msg.arg1 == 0) {
char *name = (char*) p->data;
name[p->msg.data_length > 0 ? p->msg.data_length - 1 : 0] = 0;
asocket* s = create_local_service_socket(name, t);
asocket* s = create_local_service_socket(p->payload.c_str(), t);
if (s == nullptr) {
send_close(0, p->msg.arg0, t);
} else {
@ -474,11 +466,7 @@ void handle_packet(apacket *p, atransport *t)
asocket* s = find_local_socket(p->msg.arg1, p->msg.arg0);
if (s) {
unsigned rid = p->msg.arg0;
// TODO: Convert apacket::data to a type that we can move out of.
std::string copy(p->data, p->data + p->msg.data_length);
if (s->enqueue(s, std::move(copy)) == 0) {
if (s->enqueue(s, std::move(p->payload)) == 0) {
D("Enqueue the socket");
send_ready(s->id, rid, t);
}

View File

@ -74,7 +74,7 @@ struct amessage {
struct apacket {
amessage msg;
char data[MAX_PAYLOAD];
std::string payload;
};
uint32_t calculate_apacket_checksum(const apacket* packet);

View File

@ -49,7 +49,7 @@ void adbd_auth_init(void);
void adbd_auth_verified(atransport *t);
void adbd_cloexec_auth_socket();
bool adbd_auth_verify(const char* token, size_t token_size, const char* sig, int sig_len);
bool adbd_auth_verify(const char* token, size_t token_size, const std::string& sig);
void adbd_auth_confirm_key(const char* data, size_t len, atransport* t);
void send_auth_request(atransport *t);

View File

@ -299,20 +299,25 @@ std::deque<std::shared_ptr<RSA>> adb_auth_get_private_keys() {
return result;
}
static int adb_auth_sign(RSA* key, const char* token, size_t token_size, char* sig) {
static std::string adb_auth_sign(RSA* key, const char* token, size_t token_size) {
if (token_size != TOKEN_SIZE) {
D("Unexpected token size %zd", token_size);
return 0;
}
std::string result;
result.resize(MAX_PAYLOAD);
unsigned int len;
if (!RSA_sign(NID_sha1, reinterpret_cast<const uint8_t*>(token), token_size,
reinterpret_cast<uint8_t*>(sig), &len, key)) {
return 0;
reinterpret_cast<uint8_t*>(&result[0]), &len, key)) {
return std::string();
}
result.resize(len);
D("adb_auth_sign len=%d", len);
return (int)len;
return result;
}
std::string adb_auth_get_userkey() {
@ -446,13 +451,14 @@ static void send_auth_publickey(atransport* t) {
}
apacket* p = get_apacket();
memcpy(p->data, key.c_str(), key.size() + 1);
p->msg.command = A_AUTH;
p->msg.arg0 = ADB_AUTH_RSAPUBLICKEY;
p->payload = std::move(key);
// adbd expects a null-terminated string.
p->msg.data_length = key.size() + 1;
p->payload.push_back('\0');
p->msg.data_length = p->payload.size();
send_packet(p, t);
}
@ -467,8 +473,8 @@ void send_auth_response(const char* token, size_t token_size, atransport* t) {
LOG(INFO) << "Calling send_auth_response";
apacket* p = get_apacket();
int ret = adb_auth_sign(key.get(), token, token_size, p->data);
if (!ret) {
std::string result = adb_auth_sign(key.get(), token, token_size);
if (result.empty()) {
D("Error signing the token");
put_apacket(p);
return;
@ -476,6 +482,7 @@ void send_auth_response(const char* token, size_t token_size, atransport* t) {
p->msg.command = A_AUTH;
p->msg.arg0 = ADB_AUTH_SIGNATURE;
p->msg.data_length = ret;
p->payload = std::move(result);
p->msg.data_length = p->payload.size();
send_packet(p, t);
}

View File

@ -46,7 +46,7 @@ static bool needs_retry = false;
bool auth_required = true;
bool adbd_auth_verify(const char* token, size_t token_size, const char* sig, int sig_len) {
bool adbd_auth_verify(const char* token, size_t token_size, const std::string& sig) {
static constexpr const char* key_paths[] = { "/adb_keys", "/data/misc/adb/adb_keys", nullptr };
for (const auto& path : key_paths) {
@ -80,7 +80,8 @@ bool adbd_auth_verify(const char* token, size_t token_size, const char* sig, int
bool verified =
(RSA_verify(NID_sha1, reinterpret_cast<const uint8_t*>(token), token_size,
reinterpret_cast<const uint8_t*>(sig), sig_len, key) == 1);
reinterpret_cast<const uint8_t*>(sig.c_str()), sig.size(),
key) == 1);
RSA_free(key);
if (verified) return true;
}
@ -210,10 +211,10 @@ void send_auth_request(atransport* t) {
}
apacket* p = get_apacket();
memcpy(p->data, t->token, sizeof(t->token));
p->msg.command = A_AUTH;
p->msg.arg0 = ADB_AUTH_TOKEN;
p->msg.data_length = sizeof(t->token);
p->payload.assign(t->token, t->token + sizeof(t->token));
send_packet(p, t);
}

View File

@ -413,15 +413,15 @@ static int remote_socket_enqueue(asocket* s, std::string data) {
p->msg.command = A_WRTE;
p->msg.arg0 = s->peer->id;
p->msg.arg1 = s->id;
p->msg.data_length = data.size();
if (data.size() > sizeof(p->data)) {
if (data.size() > MAX_PAYLOAD) {
put_apacket(p);
return -1;
}
// TODO: Convert apacket::data to a type that we can move into.
memcpy(p->data, data.data(), data.size());
p->payload = std::move(data);
p->msg.data_length = p->payload.size();
send_packet(p, s->transport);
return 1;
}
@ -482,17 +482,20 @@ asocket* create_remote_socket(unsigned id, atransport* t) {
void connect_to_remote(asocket* s, const char* destination) {
D("Connect_to_remote call RS(%d) fd=%d", s->id, s->fd);
apacket* p = get_apacket();
size_t len = strlen(destination) + 1;
if (len > (s->get_max_payload() - 1)) {
fatal("destination oversized");
}
D("LS(%d): connect('%s')", s->id, destination);
p->msg.command = A_OPEN;
p->msg.arg0 = s->id;
p->msg.data_length = len;
strcpy((char*)p->data, destination);
// adbd expects a null-terminated string.
p->payload = destination;
p->payload.push_back('\0');
p->msg.data_length = p->payload.size();
if (p->msg.data_length > s->get_max_payload()) {
fatal("destination oversized");
}
send_packet(p, s->transport);
}

View File

@ -72,12 +72,14 @@ bool FdConnection::Read(apacket* packet) {
return false;
}
if (packet->msg.data_length > sizeof(packet->data)) {
if (packet->msg.data_length > MAX_PAYLOAD) {
D("remote local: read overflow (data length = %" PRIu32 ")", packet->msg.data_length);
return false;
}
if (!ReadFdExactly(fd_.get(), &packet->data, packet->msg.data_length)) {
packet->payload.resize(packet->msg.data_length);
if (!ReadFdExactly(fd_.get(), &packet->payload[0], packet->payload.size())) {
D("remote local: terminated (data)");
return false;
}
@ -86,13 +88,18 @@ bool FdConnection::Read(apacket* packet) {
}
bool FdConnection::Write(apacket* packet) {
uint32_t length = packet->msg.data_length;
if (!WriteFdExactly(fd_.get(), &packet->msg, sizeof(amessage) + length)) {
if (!WriteFdExactly(fd_.get(), &packet->msg, sizeof(packet->msg))) {
D("remote local: write terminated");
return false;
}
if (packet->msg.data_length) {
if (!WriteFdExactly(fd_.get(), &packet->payload[0], packet->msg.data_length)) {
D("remote local: write terminated");
return false;
}
}
return true;
}
@ -133,7 +140,7 @@ static std::string dump_packet(const char* name, const char* func, apacket* p) {
std::string result = android::base::StringPrintf("%s: %s: [%s] arg0=%s arg1=%s (len=%d) ", name,
func, cmd, arg0, arg1, len);
result += dump_hex(p->data, len);
result += dump_hex(p->payload.data(), p->payload.size());
return result;
}
@ -191,9 +198,10 @@ static void transport_socket_events(int fd, unsigned events, void* _t) {
apacket* p = 0;
if (read_packet(fd, t->serial, &p)) {
D("%s: failed to read packet from transport socket on fd %d", t->serial, fd);
} else {
handle_packet(p, (atransport*)_t);
return;
}
handle_packet(p, (atransport*)_t);
}
}
@ -243,6 +251,7 @@ static void read_transport_thread(void* _t) {
p->msg.arg0 = 1;
p->msg.arg1 = ++(t->sync_token);
p->msg.magic = A_SYNC ^ 0xffffffff;
D("sending SYNC packet (len = %u, payload.size() = %zu)", p->msg.data_length, p->payload.size());
if (write_packet(t->fd, t->serial, &p)) {
put_apacket(p);
D("%s: failed to write SYNC packet", t->serial);
@ -336,6 +345,13 @@ static void write_transport_thread(void* _t) {
if (active) {
D("%s: transport got packet, sending to remote", t->serial);
ATRACE_NAME("write_transport write_remote");
// Allow sending the payload's implicit null terminator.
if (p->msg.data_length != p->payload.size()) {
LOG(FATAL) << "packet data length doesn't match payload: msg.data_length = "
<< p->msg.data_length << ", payload.size() = " << p->payload.size();
}
if (t->Write(p) != 0) {
D("%s: remote write failed for transport", t->serial);
put_apacket(p);

View File

@ -61,13 +61,12 @@ static int UsbReadMessage(usb_handle* h, amessage* msg) {
static int UsbReadPayload(usb_handle* h, apacket* p) {
D("UsbReadPayload(%d)", p->msg.data_length);
if (p->msg.data_length > sizeof(p->data)) {
if (p->msg.data_length > MAX_PAYLOAD) {
return -1;
}
#if CHECK_PACKET_OVERFLOW
size_t usb_packet_size = usb_get_max_packet_size(h);
CHECK_EQ(0ULL, sizeof(p->data) % usb_packet_size);
// Round the data length up to the nearest packet size boundary.
// The device won't send a zero packet for packet size aligned payloads,
@ -77,10 +76,18 @@ static int UsbReadPayload(usb_handle* h, apacket* p) {
if (rem_size) {
len += usb_packet_size - rem_size;
}
CHECK_LE(len, sizeof(p->data));
return usb_read(h, &p->data, len);
p->payload.resize(len);
int rc = usb_read(h, &p->payload[0], p->payload.size());
if (rc != static_cast<int>(p->msg.data_length)) {
return -1;
}
p->payload.resize(rc);
return rc;
#else
return usb_read(h, &p->data, p->msg.data_length);
p->payload.resize(p->msg.data_length);
return usb_read(h, &p->payload[0], p->payload.size());
#endif
}
@ -120,12 +127,13 @@ static int remote_read(apacket* p, usb_handle* usb) {
}
if (p->msg.data_length) {
if (p->msg.data_length > sizeof(p->data)) {
if (p->msg.data_length > MAX_PAYLOAD) {
PLOG(ERROR) << "remote usb: read overflow (data length = " << p->msg.data_length << ")";
return -1;
}
if (usb_read(usb, p->data, p->msg.data_length)) {
p->payload.resize(p->msg.data_length);
if (usb_read(usb, &p->payload[0], p->payload.size())) {
PLOG(ERROR) << "remote usb: terminated (data)";
return -1;
}
@ -152,7 +160,7 @@ bool UsbConnection::Write(apacket* packet) {
return false;
}
if (packet->msg.data_length != 0 && usb_write(handle_, &packet->data, size) != 0) {
if (packet->msg.data_length != 0 && usb_write(handle_, packet->payload.data(), size) != 0) {
PLOG(ERROR) << "remote usb: 2 - write terminated";
return false;
}