diff --git a/drivers/net/virtio_net.c b/drivers/net/virtio_net.c
index 669b65c232a2..cea52e47dc65 100644
--- a/drivers/net/virtio_net.c
+++ b/drivers/net/virtio_net.c
@@ -2410,6 +2410,10 @@ static int virtnet_xdp_set(struct net_device *dev, struct bpf_prog *prog,
 		return -ENOMEM;
 	}
 
+	old_prog = rtnl_dereference(vi->rq[0].xdp_prog);
+	if (!prog && !old_prog)
+		return 0;
+
 	if (prog) {
 		prog = bpf_prog_add(prog, vi->max_queue_pairs - 1);
 		if (IS_ERR(prog))
@@ -2424,21 +2428,30 @@ static int virtnet_xdp_set(struct net_device *dev, struct bpf_prog *prog,
 		}
 	}
 
+	if (!prog) {
+		for (i = 0; i < vi->max_queue_pairs; i++) {
+			rcu_assign_pointer(vi->rq[i].xdp_prog, prog);
+			if (i == 0)
+				virtnet_restore_guest_offloads(vi);
+		}
+		synchronize_net();
+	}
+
 	err = _virtnet_set_queues(vi, curr_qp + xdp_qp);
 	if (err)
 		goto err;
 	netif_set_real_num_rx_queues(dev, curr_qp + xdp_qp);
 	vi->xdp_queue_pairs = xdp_qp;
 
-	for (i = 0; i < vi->max_queue_pairs; i++) {
-		old_prog = rtnl_dereference(vi->rq[i].xdp_prog);
-		rcu_assign_pointer(vi->rq[i].xdp_prog, prog);
-		if (i == 0) {
-			if (!old_prog)
+	if (prog) {
+		for (i = 0; i < vi->max_queue_pairs; i++) {
+			rcu_assign_pointer(vi->rq[i].xdp_prog, prog);
+			if (i == 0 && !old_prog)
 				virtnet_clear_guest_offloads(vi);
-			if (!prog)
-				virtnet_restore_guest_offloads(vi);
 		}
+	}
+
+	for (i = 0; i < vi->max_queue_pairs; i++) {
 		if (old_prog)
 			bpf_prog_put(old_prog);
 		if (netif_running(dev)) {
@@ -2451,6 +2464,12 @@ static int virtnet_xdp_set(struct net_device *dev, struct bpf_prog *prog,
 	return 0;
 
 err:
+	if (!prog) {
+		virtnet_clear_guest_offloads(vi);
+		for (i = 0; i < vi->max_queue_pairs; i++)
+			rcu_assign_pointer(vi->rq[i].xdp_prog, old_prog);
+	}
+
 	if (netif_running(dev)) {
 		for (i = 0; i < vi->max_queue_pairs; i++) {
 			virtnet_napi_enable(vi->rq[i].vq, &vi->rq[i].napi);