diff --git a/fs/io_uring.c b/fs/io_uring.c index 7813bc7d5b61..bda27b52fd5b 100644 --- a/fs/io_uring.c +++ b/fs/io_uring.c @@ -1806,6 +1806,7 @@ static void io_poll_complete_work(struct io_wq_work **workptr) struct io_poll_iocb *poll = &req->poll; struct poll_table_struct pt = { ._key = poll->events }; struct io_ring_ctx *ctx = req->ctx; + struct io_kiocb *nxt = NULL; __poll_t mask = 0; if (work->flags & IO_WQ_WORK_CANCEL) @@ -1832,7 +1833,10 @@ static void io_poll_complete_work(struct io_wq_work **workptr) spin_unlock_irq(&ctx->completion_lock); io_cqring_ev_posted(ctx); - io_put_req(req, NULL); + + io_put_req(req, &nxt); + if (nxt) + *workptr = &nxt->work; } static int io_poll_wake(struct wait_queue_entry *wait, unsigned mode, int sync, @@ -1886,7 +1890,8 @@ static void io_poll_queue_proc(struct file *file, struct wait_queue_head *head, add_wait_queue(head, &pt->req->poll.wait); } -static int io_poll_add(struct io_kiocb *req, const struct io_uring_sqe *sqe) +static int io_poll_add(struct io_kiocb *req, const struct io_uring_sqe *sqe, + struct io_kiocb **nxt) { struct io_poll_iocb *poll = &req->poll; struct io_ring_ctx *ctx = req->ctx; @@ -1949,7 +1954,7 @@ static int io_poll_add(struct io_kiocb *req, const struct io_uring_sqe *sqe) if (mask) { io_cqring_ev_posted(ctx); - io_put_req(req, NULL); + io_put_req(req, nxt); } return ipt.error; } @@ -2238,7 +2243,7 @@ static int __io_submit_sqe(struct io_ring_ctx *ctx, struct io_kiocb *req, ret = io_fsync(req, s->sqe, nxt, force_nonblock); break; case IORING_OP_POLL_ADD: - ret = io_poll_add(req, s->sqe); + ret = io_poll_add(req, s->sqe, nxt); break; case IORING_OP_POLL_REMOVE: ret = io_poll_remove(req, s->sqe);