diff --git a/include/uapi/linux/tcp.h b/include/uapi/linux/tcp.h index cfcb10b75483..62db78b9c1a0 100644 --- a/include/uapi/linux/tcp.h +++ b/include/uapi/linux/tcp.h @@ -349,5 +349,7 @@ struct tcp_zerocopy_receive { __u32 recv_skip_hint; /* out: amount of bytes to skip */ __u32 inq; /* out: amount of bytes in read queue */ __s32 err; /* out: socket error */ + __u64 copybuf_address; /* in: copybuf address (small reads) */ + __s32 copybuf_len; /* in/out: copybuf bytes avail/used or error */ }; #endif /* _UAPI_LINUX_TCP_H */ diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index 75a28b8f4470..0ad70097da59 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -1758,6 +1758,52 @@ int tcp_mmap(struct file *file, struct socket *sock, } EXPORT_SYMBOL(tcp_mmap); +static int tcp_copy_straggler_data(struct tcp_zerocopy_receive *zc, + struct sk_buff *skb, u32 copylen, + u32 *offset, u32 *seq) +{ + unsigned long copy_address = (unsigned long)zc->copybuf_address; + struct msghdr msg = {}; + struct iovec iov; + int err; + + if (copy_address != zc->copybuf_address) + return -EINVAL; + + err = import_single_range(READ, (void __user *)copy_address, + copylen, &iov, &msg.msg_iter); + if (err) + return err; + err = skb_copy_datagram_msg(skb, *offset, &msg, copylen); + if (err) + return err; + zc->recv_skip_hint -= copylen; + *offset += copylen; + *seq += copylen; + return (__s32)copylen; +} + +static int tcp_zerocopy_handle_leftover_data(struct tcp_zerocopy_receive *zc, + struct sock *sk, + struct sk_buff *skb, + u32 *seq, + s32 copybuf_len) +{ + u32 offset, copylen = min_t(u32, copybuf_len, zc->recv_skip_hint); + + if (!copylen) + return 0; + /* skb is null if inq < PAGE_SIZE. */ + if (skb) + offset = *seq - TCP_SKB_CB(skb)->seq; + else + skb = tcp_recv_skb(sk, *seq, &offset); + + zc->copybuf_len = tcp_copy_straggler_data(zc, skb, copylen, &offset, + seq); + return zc->copybuf_len < 0 ? 0 : copylen; +} + static int tcp_zerocopy_vm_insert_batch(struct vm_area_struct *vma, struct page **pages, unsigned long pages_to_map, @@ -1791,8 +1837,10 @@ static int tcp_zerocopy_vm_insert_batch(struct vm_area_struct *vma, static int tcp_zerocopy_receive(struct sock *sk, struct tcp_zerocopy_receive *zc) { + u32 length = 0, offset, vma_len, avail_len, aligned_len, copylen = 0; unsigned long address = (unsigned long)zc->address; - u32 length = 0, seq, offset, zap_len; + s32 copybuf_len = zc->copybuf_len; + struct tcp_sock *tp = tcp_sk(sk); #define PAGE_BATCH_SIZE 8 struct page *pages[PAGE_BATCH_SIZE]; const skb_frag_t *frags = NULL; @@ -1800,10 +1848,12 @@ static int tcp_zerocopy_receive(struct sock *sk, struct sk_buff *skb = NULL; unsigned long pg_idx = 0; unsigned long curr_addr; - struct tcp_sock *tp; - int inq; + u32 seq = tp->copied_seq; + int inq = tcp_inq(sk); int ret; + zc->copybuf_len = 0; + if (address & (PAGE_SIZE - 1) || address != zc->address) return -EINVAL; @@ -1812,8 +1862,6 @@ static int tcp_zerocopy_receive(struct sock *sk, sock_rps_record_flow(sk); - tp = tcp_sk(sk); - mmap_read_lock(current->mm); vma = find_vma(current->mm, address); @@ -1821,17 +1869,16 @@ static int tcp_zerocopy_receive(struct sock *sk, mmap_read_unlock(current->mm); return -EINVAL; } - zc->length = min_t(unsigned long, zc->length, vma->vm_end - address); - - seq = tp->copied_seq; - inq = tcp_inq(sk); - zc->length = min_t(u32, zc->length, inq); - zap_len = zc->length & ~(PAGE_SIZE - 1); - if (zap_len) { - zap_page_range(vma, address, zap_len); + vma_len = min_t(unsigned long, zc->length, vma->vm_end - address); + avail_len = min_t(u32, vma_len, inq); + aligned_len = avail_len & ~(PAGE_SIZE - 1); + if (aligned_len) { + zap_page_range(vma, address, aligned_len); + zc->length = aligned_len; zc->recv_skip_hint = 0; } else { - zc->recv_skip_hint = zc->length; + zc->length = avail_len; + zc->recv_skip_hint = avail_len; } ret = 0; curr_addr = address; @@ -1900,13 +1947,18 @@ static int tcp_zerocopy_receive(struct sock *sk, } out: mmap_read_unlock(current->mm); - if (length) { + /* Try to copy straggler data. */ + if (!ret) + copylen = tcp_zerocopy_handle_leftover_data(zc, sk, skb, &seq, + copybuf_len); + + if (length + copylen) { WRITE_ONCE(tp->copied_seq, seq); tcp_rcv_space_adjust(sk); /* Clean up data we have read: This will do ACK frames. */ tcp_recv_skb(sk, seq, &offset); - tcp_cleanup_rbuf(sk, length); + tcp_cleanup_rbuf(sk, length + copylen); ret = 0; if (length == zc->length) zc->recv_skip_hint = 0;