Merge "adb: switch apacket over to a std::string payload."
This commit is contained in:
commit
581a4ceb00
50
adb/adb.cpp
50
adb/adb.cpp
|
@ -105,31 +105,27 @@ void fatal_errno(const char* fmt, ...) {
|
|||
}
|
||||
|
||||
uint32_t calculate_apacket_checksum(const apacket* p) {
|
||||
const unsigned char* x = reinterpret_cast<const unsigned char*>(p->data);
|
||||
uint32_t sum = 0;
|
||||
size_t count = p->msg.data_length;
|
||||
|
||||
while (count-- > 0) {
|
||||
sum += *x++;
|
||||
for (size_t i = 0; i < p->msg.data_length; ++i) {
|
||||
sum += static_cast<uint8_t>(p->payload[i]);
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
apacket* get_apacket(void)
|
||||
{
|
||||
apacket* p = reinterpret_cast<apacket*>(malloc(sizeof(apacket)));
|
||||
apacket* p = new apacket();
|
||||
if (p == nullptr) {
|
||||
fatal("failed to allocate an apacket");
|
||||
}
|
||||
|
||||
memset(p, 0, sizeof(apacket) - MAX_PAYLOAD);
|
||||
memset(&p->msg, 0, sizeof(p->msg));
|
||||
return p;
|
||||
}
|
||||
|
||||
void put_apacket(apacket *p)
|
||||
{
|
||||
free(p);
|
||||
delete p;
|
||||
}
|
||||
|
||||
void handle_online(atransport *t)
|
||||
|
@ -155,8 +151,7 @@ void handle_offline(atransport *t)
|
|||
#define DUMPMAX 32
|
||||
void print_packet(const char *label, apacket *p)
|
||||
{
|
||||
char *tag;
|
||||
char *x;
|
||||
const char* tag;
|
||||
unsigned count;
|
||||
|
||||
switch(p->msg.command){
|
||||
|
@ -173,15 +168,15 @@ void print_packet(const char *label, apacket *p)
|
|||
fprintf(stderr, "%s: %s %08x %08x %04x \"",
|
||||
label, tag, p->msg.arg0, p->msg.arg1, p->msg.data_length);
|
||||
count = p->msg.data_length;
|
||||
x = (char*) p->data;
|
||||
if(count > DUMPMAX) {
|
||||
const char* x = p->payload.data();
|
||||
if (count > DUMPMAX) {
|
||||
count = DUMPMAX;
|
||||
tag = "\n";
|
||||
} else {
|
||||
tag = "\"\n";
|
||||
}
|
||||
while(count-- > 0){
|
||||
if((*x >= ' ') && (*x < 127)) {
|
||||
while (count-- > 0) {
|
||||
if ((*x >= ' ') && (*x < 127)) {
|
||||
fputc(*x, stderr);
|
||||
} else {
|
||||
fputc('.', stderr);
|
||||
|
@ -254,8 +249,8 @@ void send_connect(atransport* t) {
|
|||
<< connection_str.length() << ")";
|
||||
}
|
||||
|
||||
memcpy(cp->data, connection_str.c_str(), connection_str.length());
|
||||
cp->msg.data_length = connection_str.length();
|
||||
cp->payload = std::move(connection_str);
|
||||
cp->msg.data_length = cp->payload.size();
|
||||
|
||||
send_packet(cp, t);
|
||||
}
|
||||
|
@ -329,9 +324,7 @@ static void handle_new_connection(atransport* t, apacket* p) {
|
|||
}
|
||||
|
||||
t->update_version(p->msg.arg0, p->msg.arg1);
|
||||
std::string banner(reinterpret_cast<const char*>(p->data),
|
||||
p->msg.data_length);
|
||||
parse_banner(banner, t);
|
||||
parse_banner(p->payload, t);
|
||||
|
||||
#if ADB_HOST
|
||||
handle_online(t);
|
||||
|
@ -354,6 +347,7 @@ void handle_packet(apacket *p, atransport *t)
|
|||
((char*) (&(p->msg.command)))[2],
|
||||
((char*) (&(p->msg.command)))[3]);
|
||||
print_packet("recv", p);
|
||||
CHECK_EQ(p->payload.size(), p->msg.data_length);
|
||||
|
||||
switch(p->msg.command){
|
||||
case A_SYNC:
|
||||
|
@ -380,11 +374,11 @@ void handle_packet(apacket *p, atransport *t)
|
|||
if (t->GetConnectionState() == kCsOffline) {
|
||||
t->SetConnectionState(kCsUnauthorized);
|
||||
}
|
||||
send_auth_response(p->data, p->msg.data_length, t);
|
||||
send_auth_response(p->payload.data(), p->msg.data_length, t);
|
||||
break;
|
||||
#else
|
||||
case ADB_AUTH_SIGNATURE:
|
||||
if (adbd_auth_verify(t->token, sizeof(t->token), p->data, p->msg.data_length)) {
|
||||
if (adbd_auth_verify(t->token, sizeof(t->token), p->payload)) {
|
||||
adbd_auth_verified(t);
|
||||
t->failed_auth_attempts = 0;
|
||||
} else {
|
||||
|
@ -394,7 +388,7 @@ void handle_packet(apacket *p, atransport *t)
|
|||
break;
|
||||
|
||||
case ADB_AUTH_RSAPUBLICKEY:
|
||||
adbd_auth_confirm_key(p->data, p->msg.data_length, t);
|
||||
adbd_auth_confirm_key(p->payload.data(), p->msg.data_length, t);
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
|
@ -406,9 +400,7 @@ void handle_packet(apacket *p, atransport *t)
|
|||
|
||||
case A_OPEN: /* OPEN(local-id, 0, "destination") */
|
||||
if (t->online && p->msg.arg0 != 0 && p->msg.arg1 == 0) {
|
||||
char *name = (char*) p->data;
|
||||
name[p->msg.data_length > 0 ? p->msg.data_length - 1 : 0] = 0;
|
||||
asocket* s = create_local_service_socket(name, t);
|
||||
asocket* s = create_local_service_socket(p->payload.c_str(), t);
|
||||
if (s == nullptr) {
|
||||
send_close(0, p->msg.arg0, t);
|
||||
} else {
|
||||
|
@ -474,11 +466,7 @@ void handle_packet(apacket *p, atransport *t)
|
|||
asocket* s = find_local_socket(p->msg.arg1, p->msg.arg0);
|
||||
if (s) {
|
||||
unsigned rid = p->msg.arg0;
|
||||
|
||||
// TODO: Convert apacket::data to a type that we can move out of.
|
||||
std::string copy(p->data, p->data + p->msg.data_length);
|
||||
|
||||
if (s->enqueue(s, std::move(copy)) == 0) {
|
||||
if (s->enqueue(s, std::move(p->payload)) == 0) {
|
||||
D("Enqueue the socket");
|
||||
send_ready(s->id, rid, t);
|
||||
}
|
||||
|
|
|
@ -74,7 +74,7 @@ struct amessage {
|
|||
|
||||
struct apacket {
|
||||
amessage msg;
|
||||
char data[MAX_PAYLOAD];
|
||||
std::string payload;
|
||||
};
|
||||
|
||||
uint32_t calculate_apacket_checksum(const apacket* packet);
|
||||
|
|
|
@ -49,7 +49,7 @@ void adbd_auth_init(void);
|
|||
void adbd_auth_verified(atransport *t);
|
||||
|
||||
void adbd_cloexec_auth_socket();
|
||||
bool adbd_auth_verify(const char* token, size_t token_size, const char* sig, int sig_len);
|
||||
bool adbd_auth_verify(const char* token, size_t token_size, const std::string& sig);
|
||||
void adbd_auth_confirm_key(const char* data, size_t len, atransport* t);
|
||||
|
||||
void send_auth_request(atransport *t);
|
||||
|
|
|
@ -299,20 +299,25 @@ std::deque<std::shared_ptr<RSA>> adb_auth_get_private_keys() {
|
|||
return result;
|
||||
}
|
||||
|
||||
static int adb_auth_sign(RSA* key, const char* token, size_t token_size, char* sig) {
|
||||
static std::string adb_auth_sign(RSA* key, const char* token, size_t token_size) {
|
||||
if (token_size != TOKEN_SIZE) {
|
||||
D("Unexpected token size %zd", token_size);
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::string result;
|
||||
result.resize(MAX_PAYLOAD);
|
||||
|
||||
unsigned int len;
|
||||
if (!RSA_sign(NID_sha1, reinterpret_cast<const uint8_t*>(token), token_size,
|
||||
reinterpret_cast<uint8_t*>(sig), &len, key)) {
|
||||
return 0;
|
||||
reinterpret_cast<uint8_t*>(&result[0]), &len, key)) {
|
||||
return std::string();
|
||||
}
|
||||
|
||||
result.resize(len);
|
||||
|
||||
D("adb_auth_sign len=%d", len);
|
||||
return (int)len;
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string adb_auth_get_userkey() {
|
||||
|
@ -446,13 +451,14 @@ static void send_auth_publickey(atransport* t) {
|
|||
}
|
||||
|
||||
apacket* p = get_apacket();
|
||||
memcpy(p->data, key.c_str(), key.size() + 1);
|
||||
|
||||
p->msg.command = A_AUTH;
|
||||
p->msg.arg0 = ADB_AUTH_RSAPUBLICKEY;
|
||||
|
||||
p->payload = std::move(key);
|
||||
|
||||
// adbd expects a null-terminated string.
|
||||
p->msg.data_length = key.size() + 1;
|
||||
p->payload.push_back('\0');
|
||||
p->msg.data_length = p->payload.size();
|
||||
send_packet(p, t);
|
||||
}
|
||||
|
||||
|
@ -467,8 +473,8 @@ void send_auth_response(const char* token, size_t token_size, atransport* t) {
|
|||
LOG(INFO) << "Calling send_auth_response";
|
||||
apacket* p = get_apacket();
|
||||
|
||||
int ret = adb_auth_sign(key.get(), token, token_size, p->data);
|
||||
if (!ret) {
|
||||
std::string result = adb_auth_sign(key.get(), token, token_size);
|
||||
if (result.empty()) {
|
||||
D("Error signing the token");
|
||||
put_apacket(p);
|
||||
return;
|
||||
|
@ -476,6 +482,7 @@ void send_auth_response(const char* token, size_t token_size, atransport* t) {
|
|||
|
||||
p->msg.command = A_AUTH;
|
||||
p->msg.arg0 = ADB_AUTH_SIGNATURE;
|
||||
p->msg.data_length = ret;
|
||||
p->payload = std::move(result);
|
||||
p->msg.data_length = p->payload.size();
|
||||
send_packet(p, t);
|
||||
}
|
||||
|
|
|
@ -46,7 +46,7 @@ static bool needs_retry = false;
|
|||
|
||||
bool auth_required = true;
|
||||
|
||||
bool adbd_auth_verify(const char* token, size_t token_size, const char* sig, int sig_len) {
|
||||
bool adbd_auth_verify(const char* token, size_t token_size, const std::string& sig) {
|
||||
static constexpr const char* key_paths[] = { "/adb_keys", "/data/misc/adb/adb_keys", nullptr };
|
||||
|
||||
for (const auto& path : key_paths) {
|
||||
|
@ -80,7 +80,8 @@ bool adbd_auth_verify(const char* token, size_t token_size, const char* sig, int
|
|||
|
||||
bool verified =
|
||||
(RSA_verify(NID_sha1, reinterpret_cast<const uint8_t*>(token), token_size,
|
||||
reinterpret_cast<const uint8_t*>(sig), sig_len, key) == 1);
|
||||
reinterpret_cast<const uint8_t*>(sig.c_str()), sig.size(),
|
||||
key) == 1);
|
||||
RSA_free(key);
|
||||
if (verified) return true;
|
||||
}
|
||||
|
@ -210,10 +211,10 @@ void send_auth_request(atransport* t) {
|
|||
}
|
||||
|
||||
apacket* p = get_apacket();
|
||||
memcpy(p->data, t->token, sizeof(t->token));
|
||||
p->msg.command = A_AUTH;
|
||||
p->msg.arg0 = ADB_AUTH_TOKEN;
|
||||
p->msg.data_length = sizeof(t->token);
|
||||
p->payload.assign(t->token, t->token + sizeof(t->token));
|
||||
send_packet(p, t);
|
||||
}
|
||||
|
||||
|
|
|
@ -413,15 +413,15 @@ static int remote_socket_enqueue(asocket* s, std::string data) {
|
|||
p->msg.command = A_WRTE;
|
||||
p->msg.arg0 = s->peer->id;
|
||||
p->msg.arg1 = s->id;
|
||||
p->msg.data_length = data.size();
|
||||
|
||||
if (data.size() > sizeof(p->data)) {
|
||||
if (data.size() > MAX_PAYLOAD) {
|
||||
put_apacket(p);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// TODO: Convert apacket::data to a type that we can move into.
|
||||
memcpy(p->data, data.data(), data.size());
|
||||
p->payload = std::move(data);
|
||||
p->msg.data_length = p->payload.size();
|
||||
|
||||
send_packet(p, s->transport);
|
||||
return 1;
|
||||
}
|
||||
|
@ -482,17 +482,20 @@ asocket* create_remote_socket(unsigned id, atransport* t) {
|
|||
void connect_to_remote(asocket* s, const char* destination) {
|
||||
D("Connect_to_remote call RS(%d) fd=%d", s->id, s->fd);
|
||||
apacket* p = get_apacket();
|
||||
size_t len = strlen(destination) + 1;
|
||||
|
||||
if (len > (s->get_max_payload() - 1)) {
|
||||
fatal("destination oversized");
|
||||
}
|
||||
|
||||
D("LS(%d): connect('%s')", s->id, destination);
|
||||
p->msg.command = A_OPEN;
|
||||
p->msg.arg0 = s->id;
|
||||
p->msg.data_length = len;
|
||||
strcpy((char*)p->data, destination);
|
||||
|
||||
// adbd expects a null-terminated string.
|
||||
p->payload = destination;
|
||||
p->payload.push_back('\0');
|
||||
p->msg.data_length = p->payload.size();
|
||||
|
||||
if (p->msg.data_length > s->get_max_payload()) {
|
||||
fatal("destination oversized");
|
||||
}
|
||||
|
||||
send_packet(p, s->transport);
|
||||
}
|
||||
|
||||
|
|
|
@ -72,12 +72,14 @@ bool FdConnection::Read(apacket* packet) {
|
|||
return false;
|
||||
}
|
||||
|
||||
if (packet->msg.data_length > sizeof(packet->data)) {
|
||||
if (packet->msg.data_length > MAX_PAYLOAD) {
|
||||
D("remote local: read overflow (data length = %" PRIu32 ")", packet->msg.data_length);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ReadFdExactly(fd_.get(), &packet->data, packet->msg.data_length)) {
|
||||
packet->payload.resize(packet->msg.data_length);
|
||||
|
||||
if (!ReadFdExactly(fd_.get(), &packet->payload[0], packet->payload.size())) {
|
||||
D("remote local: terminated (data)");
|
||||
return false;
|
||||
}
|
||||
|
@ -86,13 +88,18 @@ bool FdConnection::Read(apacket* packet) {
|
|||
}
|
||||
|
||||
bool FdConnection::Write(apacket* packet) {
|
||||
uint32_t length = packet->msg.data_length;
|
||||
|
||||
if (!WriteFdExactly(fd_.get(), &packet->msg, sizeof(amessage) + length)) {
|
||||
if (!WriteFdExactly(fd_.get(), &packet->msg, sizeof(packet->msg))) {
|
||||
D("remote local: write terminated");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (packet->msg.data_length) {
|
||||
if (!WriteFdExactly(fd_.get(), &packet->payload[0], packet->msg.data_length)) {
|
||||
D("remote local: write terminated");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -133,7 +140,7 @@ static std::string dump_packet(const char* name, const char* func, apacket* p) {
|
|||
|
||||
std::string result = android::base::StringPrintf("%s: %s: [%s] arg0=%s arg1=%s (len=%d) ", name,
|
||||
func, cmd, arg0, arg1, len);
|
||||
result += dump_hex(p->data, len);
|
||||
result += dump_hex(p->payload.data(), p->payload.size());
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -191,9 +198,10 @@ static void transport_socket_events(int fd, unsigned events, void* _t) {
|
|||
apacket* p = 0;
|
||||
if (read_packet(fd, t->serial, &p)) {
|
||||
D("%s: failed to read packet from transport socket on fd %d", t->serial, fd);
|
||||
} else {
|
||||
handle_packet(p, (atransport*)_t);
|
||||
return;
|
||||
}
|
||||
|
||||
handle_packet(p, (atransport*)_t);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -243,6 +251,7 @@ static void read_transport_thread(void* _t) {
|
|||
p->msg.arg0 = 1;
|
||||
p->msg.arg1 = ++(t->sync_token);
|
||||
p->msg.magic = A_SYNC ^ 0xffffffff;
|
||||
D("sending SYNC packet (len = %u, payload.size() = %zu)", p->msg.data_length, p->payload.size());
|
||||
if (write_packet(t->fd, t->serial, &p)) {
|
||||
put_apacket(p);
|
||||
D("%s: failed to write SYNC packet", t->serial);
|
||||
|
@ -336,6 +345,13 @@ static void write_transport_thread(void* _t) {
|
|||
if (active) {
|
||||
D("%s: transport got packet, sending to remote", t->serial);
|
||||
ATRACE_NAME("write_transport write_remote");
|
||||
|
||||
// Allow sending the payload's implicit null terminator.
|
||||
if (p->msg.data_length != p->payload.size()) {
|
||||
LOG(FATAL) << "packet data length doesn't match payload: msg.data_length = "
|
||||
<< p->msg.data_length << ", payload.size() = " << p->payload.size();
|
||||
}
|
||||
|
||||
if (t->Write(p) != 0) {
|
||||
D("%s: remote write failed for transport", t->serial);
|
||||
put_apacket(p);
|
||||
|
|
|
@ -61,13 +61,12 @@ static int UsbReadMessage(usb_handle* h, amessage* msg) {
|
|||
static int UsbReadPayload(usb_handle* h, apacket* p) {
|
||||
D("UsbReadPayload(%d)", p->msg.data_length);
|
||||
|
||||
if (p->msg.data_length > sizeof(p->data)) {
|
||||
if (p->msg.data_length > MAX_PAYLOAD) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
#if CHECK_PACKET_OVERFLOW
|
||||
size_t usb_packet_size = usb_get_max_packet_size(h);
|
||||
CHECK_EQ(0ULL, sizeof(p->data) % usb_packet_size);
|
||||
|
||||
// Round the data length up to the nearest packet size boundary.
|
||||
// The device won't send a zero packet for packet size aligned payloads,
|
||||
|
@ -77,10 +76,18 @@ static int UsbReadPayload(usb_handle* h, apacket* p) {
|
|||
if (rem_size) {
|
||||
len += usb_packet_size - rem_size;
|
||||
}
|
||||
CHECK_LE(len, sizeof(p->data));
|
||||
return usb_read(h, &p->data, len);
|
||||
|
||||
p->payload.resize(len);
|
||||
int rc = usb_read(h, &p->payload[0], p->payload.size());
|
||||
if (rc != static_cast<int>(p->msg.data_length)) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
p->payload.resize(rc);
|
||||
return rc;
|
||||
#else
|
||||
return usb_read(h, &p->data, p->msg.data_length);
|
||||
p->payload.resize(p->msg.data_length);
|
||||
return usb_read(h, &p->payload[0], p->payload.size());
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -120,12 +127,13 @@ static int remote_read(apacket* p, usb_handle* usb) {
|
|||
}
|
||||
|
||||
if (p->msg.data_length) {
|
||||
if (p->msg.data_length > sizeof(p->data)) {
|
||||
if (p->msg.data_length > MAX_PAYLOAD) {
|
||||
PLOG(ERROR) << "remote usb: read overflow (data length = " << p->msg.data_length << ")";
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (usb_read(usb, p->data, p->msg.data_length)) {
|
||||
p->payload.resize(p->msg.data_length);
|
||||
if (usb_read(usb, &p->payload[0], p->payload.size())) {
|
||||
PLOG(ERROR) << "remote usb: terminated (data)";
|
||||
return -1;
|
||||
}
|
||||
|
@ -152,7 +160,7 @@ bool UsbConnection::Write(apacket* packet) {
|
|||
return false;
|
||||
}
|
||||
|
||||
if (packet->msg.data_length != 0 && usb_write(handle_, &packet->data, size) != 0) {
|
||||
if (packet->msg.data_length != 0 && usb_write(handle_, packet->payload.data(), size) != 0) {
|
||||
PLOG(ERROR) << "remote usb: 2 - write terminated";
|
||||
return false;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue