diff --git a/include/net/tls.h b/include/net/tls.h index bf9eb4823933..18cd4f418464 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -135,6 +135,8 @@ struct tls_sw_context_tx { struct tls_rec *open_rec; struct list_head tx_list; atomic_t encrypt_pending; + /* protect crypto_wait with encrypt_pending */ + spinlock_t encrypt_compl_lock; int async_notify; u8 async_capable:1; @@ -155,6 +157,8 @@ struct tls_sw_context_rx { u8 async_capable:1; u8 decrypted:1; atomic_t decrypt_pending; + /* protect crypto_wait with decrypt_pending*/ + spinlock_t decrypt_compl_lock; bool async_notify; }; diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 2d399b6c4075..8c2763eb6aae 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -206,10 +206,12 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err) kfree(aead_req); + spin_lock_bh(&ctx->decrypt_compl_lock); pending = atomic_dec_return(&ctx->decrypt_pending); - if (!pending && READ_ONCE(ctx->async_notify)) + if (!pending && ctx->async_notify) complete(&ctx->async_wait.completion); + spin_unlock_bh(&ctx->decrypt_compl_lock); } static int tls_do_decryption(struct sock *sk, @@ -467,10 +469,12 @@ static void tls_encrypt_done(struct crypto_async_request *req, int err) ready = true; } + spin_lock_bh(&ctx->encrypt_compl_lock); pending = atomic_dec_return(&ctx->encrypt_pending); - if (!pending && READ_ONCE(ctx->async_notify)) + if (!pending && ctx->async_notify) complete(&ctx->async_wait.completion); + spin_unlock_bh(&ctx->encrypt_compl_lock); if (!ready) return; @@ -929,6 +933,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) int num_zc = 0; int orig_size; int ret = 0; + int pending; if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL)) return -EOPNOTSUPP; @@ -1095,13 +1100,19 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) goto send_end; } else if (num_zc) { /* Wait for pending encryptions to get completed */ - smp_store_mb(ctx->async_notify, true); + spin_lock_bh(&ctx->encrypt_compl_lock); + ctx->async_notify = true; - if (atomic_read(&ctx->encrypt_pending)) + pending = atomic_read(&ctx->encrypt_pending); + spin_unlock_bh(&ctx->encrypt_compl_lock); + if (pending) crypto_wait_req(-EINPROGRESS, &ctx->async_wait); else reinit_completion(&ctx->async_wait.completion); + /* There can be no concurrent accesses, since we have no + * pending encrypt operations + */ WRITE_ONCE(ctx->async_notify, false); if (ctx->async_wait.err) { @@ -1732,6 +1743,7 @@ int tls_sw_recvmsg(struct sock *sk, bool is_kvec = iov_iter_is_kvec(&msg->msg_iter); bool is_peek = flags & MSG_PEEK; int num_async = 0; + int pending; flags |= nonblock; @@ -1894,8 +1906,11 @@ int tls_sw_recvmsg(struct sock *sk, recv_end: if (num_async) { /* Wait for all previously submitted records to be decrypted */ - smp_store_mb(ctx->async_notify, true); - if (atomic_read(&ctx->decrypt_pending)) { + spin_lock_bh(&ctx->decrypt_compl_lock); + ctx->async_notify = true; + pending = atomic_read(&ctx->decrypt_pending); + spin_unlock_bh(&ctx->decrypt_compl_lock); + if (pending) { err = crypto_wait_req(-EINPROGRESS, &ctx->async_wait); if (err) { /* one of async decrypt failed */ @@ -1907,6 +1922,10 @@ int tls_sw_recvmsg(struct sock *sk, } else { reinit_completion(&ctx->async_wait.completion); } + + /* There can be no concurrent accesses, since we have no + * pending decrypt operations + */ WRITE_ONCE(ctx->async_notify, false); /* Drain records from the rx_list & copy if required */ @@ -2293,6 +2312,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) if (tx) { crypto_init_wait(&sw_ctx_tx->async_wait); + spin_lock_init(&sw_ctx_tx->encrypt_compl_lock); crypto_info = &ctx->crypto_send.info; cctx = &ctx->tx; aead = &sw_ctx_tx->aead_send; @@ -2301,6 +2321,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) sw_ctx_tx->tx_work.sk = sk; } else { crypto_init_wait(&sw_ctx_rx->async_wait); + spin_lock_init(&sw_ctx_rx->decrypt_compl_lock); crypto_info = &ctx->crypto_recv.info; cctx = &ctx->rx; skb_queue_head_init(&sw_ctx_rx->rx_list);