diff --git a/drivers/net/tap.c b/drivers/net/tap.c index e9489b88407c..0a886fda0129 100644 --- a/drivers/net/tap.c +++ b/drivers/net/tap.c @@ -829,8 +829,11 @@ static ssize_t tap_do_read(struct tap_queue *q, DEFINE_WAIT(wait); ssize_t ret = 0; - if (!iov_iter_count(to)) + if (!iov_iter_count(to)) { + if (skb) + kfree_skb(skb); return 0; + } if (skb) goto put; @@ -1154,11 +1157,14 @@ static int tap_recvmsg(struct socket *sock, struct msghdr *m, size_t total_len, int flags) { struct tap_queue *q = container_of(sock, struct tap_queue, sock); + struct sk_buff *skb = m->msg_control; int ret; - if (flags & ~(MSG_DONTWAIT|MSG_TRUNC)) + if (flags & ~(MSG_DONTWAIT|MSG_TRUNC)) { + if (skb) + kfree_skb(skb); return -EINVAL; - ret = tap_do_read(q, &m->msg_iter, flags & MSG_DONTWAIT, - m->msg_control); + } + ret = tap_do_read(q, &m->msg_iter, flags & MSG_DONTWAIT, skb); if (ret > total_len) { m->msg_flags |= MSG_TRUNC; ret = flags & MSG_TRUNC ? ret : total_len; diff --git a/drivers/net/tun.c b/drivers/net/tun.c index 95749006d687..4f4a842a1c9c 100644 --- a/drivers/net/tun.c +++ b/drivers/net/tun.c @@ -1952,8 +1952,11 @@ static ssize_t tun_do_read(struct tun_struct *tun, struct tun_file *tfile, tun_debug(KERN_INFO, tun, "tun_do_read\n"); - if (!iov_iter_count(to)) + if (!iov_iter_count(to)) { + if (skb) + kfree_skb(skb); return 0; + } if (!skb) { /* Read frames from ring */ @@ -2069,22 +2072,24 @@ static int tun_recvmsg(struct socket *sock, struct msghdr *m, size_t total_len, { struct tun_file *tfile = container_of(sock, struct tun_file, socket); struct tun_struct *tun = tun_get(tfile); + struct sk_buff *skb = m->msg_control; int ret; - if (!tun) - return -EBADFD; + if (!tun) { + ret = -EBADFD; + goto out_free_skb; + } if (flags & ~(MSG_DONTWAIT|MSG_TRUNC|MSG_ERRQUEUE)) { ret = -EINVAL; - goto out; + goto out_put_tun; } if (flags & MSG_ERRQUEUE) { ret = sock_recv_errqueue(sock->sk, m, total_len, SOL_PACKET, TUN_TX_TIMESTAMP); goto out; } - ret = tun_do_read(tun, tfile, &m->msg_iter, flags & MSG_DONTWAIT, - m->msg_control); + ret = tun_do_read(tun, tfile, &m->msg_iter, flags & MSG_DONTWAIT, skb); if (ret > (ssize_t)total_len) { m->msg_flags |= MSG_TRUNC; ret = flags & MSG_TRUNC ? ret : total_len; @@ -2092,6 +2097,13 @@ static int tun_recvmsg(struct socket *sock, struct msghdr *m, size_t total_len, out: tun_put(tun); return ret; + +out_put_tun: + tun_put(tun); +out_free_skb: + if (skb) + kfree_skb(skb); + return ret; } static int tun_peek_len(struct socket *sock) diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c index 8d626d7c2e7e..c7bdeb655646 100644 --- a/drivers/vhost/net.c +++ b/drivers/vhost/net.c @@ -778,16 +778,6 @@ static void handle_rx(struct vhost_net *net) /* On error, stop handling until the next kick. */ if (unlikely(headcount < 0)) goto out; - if (nvq->rx_array) - msg.msg_control = vhost_net_buf_consume(&nvq->rxq); - /* On overrun, truncate and discard */ - if (unlikely(headcount > UIO_MAXIOV)) { - iov_iter_init(&msg.msg_iter, READ, vq->iov, 1, 1); - err = sock->ops->recvmsg(sock, &msg, - 1, MSG_DONTWAIT | MSG_TRUNC); - pr_debug("Discarded rx packet: len %zd\n", sock_len); - continue; - } /* OK, now we need to know about added descriptors. */ if (!headcount) { if (unlikely(vhost_enable_notify(&net->dev, vq))) { @@ -800,6 +790,16 @@ static void handle_rx(struct vhost_net *net) * they refilled. */ goto out; } + if (nvq->rx_array) + msg.msg_control = vhost_net_buf_consume(&nvq->rxq); + /* On overrun, truncate and discard */ + if (unlikely(headcount > UIO_MAXIOV)) { + iov_iter_init(&msg.msg_iter, READ, vq->iov, 1, 1); + err = sock->ops->recvmsg(sock, &msg, + 1, MSG_DONTWAIT | MSG_TRUNC); + pr_debug("Discarded rx packet: len %zd\n", sock_len); + continue; + } /* We don't need to be notified again. */ iov_iter_init(&msg.msg_iter, READ, vq->iov, in, vhost_len); fixup = msg.msg_iter;