diff --git a/tools/testing/selftests/net/msg_zerocopy.c b/tools/testing/selftests/net/msg_zerocopy.c index 7a5b35362f84..5cc2a53bb71c 100644 --- a/tools/testing/selftests/net/msg_zerocopy.c +++ b/tools/testing/selftests/net/msg_zerocopy.c @@ -168,17 +168,39 @@ static int do_accept(int fd) return fd; } -static bool do_sendmsg(int fd, struct msghdr *msg, bool do_zerocopy) +static void add_zcopy_cookie(struct msghdr *msg, uint32_t cookie) +{ + struct cmsghdr *cm; + + if (!msg->msg_control) + error(1, errno, "NULL cookie"); + cm = (void *)msg->msg_control; + cm->cmsg_len = CMSG_LEN(sizeof(cookie)); + cm->cmsg_level = SOL_RDS; + cm->cmsg_type = RDS_CMSG_ZCOPY_COOKIE; + memcpy(CMSG_DATA(cm), &cookie, sizeof(cookie)); +} + +static bool do_sendmsg(int fd, struct msghdr *msg, bool do_zerocopy, int domain) { int ret, len, i, flags; + static uint32_t cookie; + char ckbuf[CMSG_SPACE(sizeof(cookie))]; len = 0; for (i = 0; i < msg->msg_iovlen; i++) len += msg->msg_iov[i].iov_len; flags = MSG_DONTWAIT; - if (do_zerocopy) + if (do_zerocopy) { flags |= MSG_ZEROCOPY; + if (domain == PF_RDS) { + memset(&msg->msg_control, 0, sizeof(msg->msg_control)); + msg->msg_controllen = CMSG_SPACE(sizeof(cookie)); + msg->msg_control = (struct cmsghdr *)ckbuf; + add_zcopy_cookie(msg, ++cookie); + } + } ret = sendmsg(fd, msg, flags); if (ret == -1 && errno == EAGAIN) @@ -194,6 +216,10 @@ static bool do_sendmsg(int fd, struct msghdr *msg, bool do_zerocopy) if (do_zerocopy && ret) expected_completions++; } + if (do_zerocopy && domain == PF_RDS) { + msg->msg_control = NULL; + msg->msg_controllen = 0; + } return true; } @@ -220,7 +246,9 @@ static void do_sendmsg_corked(int fd, struct msghdr *msg) msg->msg_iov[0].iov_len = payload_len + extra_len; extra_len = 0; - do_sendmsg(fd, msg, do_zerocopy); + do_sendmsg(fd, msg, do_zerocopy, + (cfg_dst_addr.ss_family == AF_INET ? + PF_INET : PF_INET6)); } do_setsockopt(fd, IPPROTO_UDP, UDP_CORK, 0); @@ -316,6 +344,26 @@ static int do_setup_tx(int domain, int type, int protocol) return fd; } +static int do_process_zerocopy_cookies(struct sock_extended_err *serr, + uint32_t *ckbuf, size_t nbytes) +{ + int ncookies, i; + + if (serr->ee_errno != 0) + error(1, 0, "serr: wrong error code: %u", serr->ee_errno); + ncookies = serr->ee_data; + if (ncookies > SO_EE_ORIGIN_MAX_ZCOOKIES) + error(1, 0, "Returned %d cookies, max expected %d\n", + ncookies, SO_EE_ORIGIN_MAX_ZCOOKIES); + if (nbytes != ncookies * sizeof(uint32_t)) + error(1, 0, "Expected %d cookies, got %ld\n", + ncookies, nbytes/sizeof(uint32_t)); + for (i = 0; i < ncookies; i++) + if (cfg_verbose >= 2) + fprintf(stderr, "%d\n", ckbuf[i]); + return ncookies; +} + static bool do_recv_completion(int fd) { struct sock_extended_err *serr; @@ -324,10 +372,17 @@ static bool do_recv_completion(int fd) uint32_t hi, lo, range; int ret, zerocopy; char control[100]; + uint32_t ckbuf[SO_EE_ORIGIN_MAX_ZCOOKIES]; + struct iovec iov; msg.msg_control = control; msg.msg_controllen = sizeof(control); + iov.iov_base = ckbuf; + iov.iov_len = (SO_EE_ORIGIN_MAX_ZCOOKIES * sizeof(ckbuf[0])); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + ret = recvmsg(fd, &msg, MSG_ERRQUEUE); if (ret == -1 && errno == EAGAIN) return false; @@ -346,6 +401,11 @@ static bool do_recv_completion(int fd) cm->cmsg_level, cm->cmsg_type); serr = (void *) CMSG_DATA(cm); + + if (serr->ee_origin == SO_EE_ORIGIN_ZCOOKIE) { + completions += do_process_zerocopy_cookies(serr, ckbuf, ret); + return true; + } if (serr->ee_origin != SO_EE_ORIGIN_ZEROCOPY) error(1, 0, "serr: wrong origin: %u", serr->ee_origin); if (serr->ee_errno != 0) @@ -470,7 +530,7 @@ static void do_tx(int domain, int type, int protocol) if (cfg_cork) do_sendmsg_corked(fd, &msg); else - do_sendmsg(fd, &msg, cfg_zerocopy); + do_sendmsg(fd, &msg, cfg_zerocopy, domain); while (!do_poll(fd, POLLOUT)) { if (cfg_zerocopy)