diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h index b6afe01f8592..14d61bba0b79 100644 --- a/include/linux/skmsg.h +++ b/include/linux/skmsg.h @@ -359,17 +359,21 @@ static inline void sk_psock_restore_proto(struct sock *sk, struct sk_psock *psock) { sk->sk_prot->unhash = psock->saved_unhash; - sk->sk_write_space = psock->saved_write_space; if (psock->sk_proto) { struct inet_connection_sock *icsk = inet_csk(sk); bool has_ulp = !!icsk->icsk_ulp_data; - if (has_ulp) - tcp_update_ulp(sk, psock->sk_proto); - else + if (has_ulp) { + tcp_update_ulp(sk, psock->sk_proto, + psock->saved_write_space); + } else { sk->sk_prot = psock->sk_proto; + sk->sk_write_space = psock->saved_write_space; + } psock->sk_proto = NULL; + } else { + sk->sk_write_space = psock->saved_write_space; } } diff --git a/include/net/tcp.h b/include/net/tcp.h index e460ea7f767b..e6f48384dc71 100644 --- a/include/net/tcp.h +++ b/include/net/tcp.h @@ -2147,7 +2147,8 @@ struct tcp_ulp_ops { /* initialize ulp */ int (*init)(struct sock *sk); /* update ulp */ - void (*update)(struct sock *sk, struct proto *p); + void (*update)(struct sock *sk, struct proto *p, + void (*write_space)(struct sock *sk)); /* cleanup ulp */ void (*release)(struct sock *sk); /* diagnostic */ @@ -2162,7 +2163,8 @@ void tcp_unregister_ulp(struct tcp_ulp_ops *type); int tcp_set_ulp(struct sock *sk, const char *name); void tcp_get_available_ulp(char *buf, size_t len); void tcp_cleanup_ulp(struct sock *sk); -void tcp_update_ulp(struct sock *sk, struct proto *p); +void tcp_update_ulp(struct sock *sk, struct proto *p, + void (*write_space)(struct sock *sk)); #define MODULE_ALIAS_TCP_ULP(name) \ __MODULE_INFO(alias, alias_userspace, name); \ diff --git a/net/ipv4/tcp_ulp.c b/net/ipv4/tcp_ulp.c index 12ab5db2b71c..38d3ad141161 100644 --- a/net/ipv4/tcp_ulp.c +++ b/net/ipv4/tcp_ulp.c @@ -99,17 +99,19 @@ void tcp_get_available_ulp(char *buf, size_t maxlen) rcu_read_unlock(); } -void tcp_update_ulp(struct sock *sk, struct proto *proto) +void tcp_update_ulp(struct sock *sk, struct proto *proto, + void (*write_space)(struct sock *sk)) { struct inet_connection_sock *icsk = inet_csk(sk); if (!icsk->icsk_ulp_ops) { + sk->sk_write_space = write_space; sk->sk_prot = proto; return; } if (icsk->icsk_ulp_ops->update) - icsk->icsk_ulp_ops->update(sk, proto); + icsk->icsk_ulp_ops->update(sk, proto, write_space); } void tcp_cleanup_ulp(struct sock *sk) diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index dac24c7aa7d4..94774c0e5ff3 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -732,15 +732,19 @@ static int tls_init(struct sock *sk) return rc; } -static void tls_update(struct sock *sk, struct proto *p) +static void tls_update(struct sock *sk, struct proto *p, + void (*write_space)(struct sock *sk)) { struct tls_context *ctx; ctx = tls_get_ctx(sk); - if (likely(ctx)) + if (likely(ctx)) { + ctx->sk_write_space = write_space; ctx->sk_proto = p; - else + } else { sk->sk_prot = p; + sk->sk_write_space = write_space; + } } static int tls_get_info(const struct sock *sk, struct sk_buff *skb)