tls: rx: use async as an in-out argument

Propagating EINPROGRESS thru multiple layers of functions is
error prone. Use darg->async as an in/out argument, like we
use darg->zc today. On input it tells the code if async is
allowed, on output if it took place.

Signed-off-by: Jakub Kicinski <kuba@kernel.org>
Signed-off-by: David S. Miller <davem@davemloft.net>
This commit is contained in:
Jakub Kicinski 2022-04-11 12:19:15 -07:00 committed by David S. Miller
parent f314bfee81
commit 3547a1f9d9
1 changed files with 16 additions and 15 deletions

View File

@ -227,7 +227,7 @@ static int tls_do_decryption(struct sock *sk,
char *iv_recv, char *iv_recv,
size_t data_len, size_t data_len,
struct aead_request *aead_req, struct aead_request *aead_req,
bool async) struct tls_decrypt_arg *darg)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_prot_info *prot = &tls_ctx->prot_info;
@ -240,7 +240,7 @@ static int tls_do_decryption(struct sock *sk,
data_len + prot->tag_size, data_len + prot->tag_size,
(u8 *)iv_recv); (u8 *)iv_recv);
if (async) { if (darg->async) {
/* Using skb->sk to push sk through to crypto async callback /* Using skb->sk to push sk through to crypto async callback
* handler. This allows propagating errors up to the socket * handler. This allows propagating errors up to the socket
* if needed. It _must_ be cleared in the async handler * if needed. It _must_ be cleared in the async handler
@ -260,11 +260,13 @@ static int tls_do_decryption(struct sock *sk,
ret = crypto_aead_decrypt(aead_req); ret = crypto_aead_decrypt(aead_req);
if (ret == -EINPROGRESS) { if (ret == -EINPROGRESS) {
if (async) if (darg->async)
return ret; return 0;
ret = crypto_wait_req(ret, &ctx->async_wait); ret = crypto_wait_req(ret, &ctx->async_wait);
} }
darg->async = false;
if (ret == -EBADMSG) if (ret == -EBADMSG)
TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR); TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
@ -1536,9 +1538,9 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
/* Prepare and submit AEAD request */ /* Prepare and submit AEAD request */
err = tls_do_decryption(sk, skb, sgin, sgout, iv, err = tls_do_decryption(sk, skb, sgin, sgout, iv,
data_len, aead_req, darg->async); data_len, aead_req, darg);
if (err == -EINPROGRESS) if (darg->async)
return err; return 0;
/* Release the pages in case iov was mapped to pages */ /* Release the pages in case iov was mapped to pages */
for (; pages > 0; pages--) for (; pages > 0; pages--)
@ -1575,11 +1577,10 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
} }
err = decrypt_internal(sk, skb, dest, NULL, darg); err = decrypt_internal(sk, skb, dest, NULL, darg);
if (err < 0) { if (err < 0)
if (err == -EINPROGRESS)
tls_advance_record_sn(sk, prot, &tls_ctx->rx);
return err; return err;
} if (darg->async)
goto decrypt_next;
decrypt_done: decrypt_done:
pad = padding_length(prot, skb); pad = padding_length(prot, skb);
@ -1589,8 +1590,9 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
rxm->full_len -= pad; rxm->full_len -= pad;
rxm->offset += prot->prepend_size; rxm->offset += prot->prepend_size;
rxm->full_len -= prot->overhead_size; rxm->full_len -= prot->overhead_size;
tls_advance_record_sn(sk, prot, &tls_ctx->rx);
tlm->decrypted = 1; tlm->decrypted = 1;
decrypt_next:
tls_advance_record_sn(sk, prot, &tls_ctx->rx);
return 0; return 0;
} }
@ -1796,13 +1798,12 @@ int tls_sw_recvmsg(struct sock *sk,
darg.async = false; darg.async = false;
err = decrypt_skb_update(sk, skb, &msg->msg_iter, &darg); err = decrypt_skb_update(sk, skb, &msg->msg_iter, &darg);
if (err < 0 && err != -EINPROGRESS) { if (err < 0) {
tls_err_abort(sk, -EBADMSG); tls_err_abort(sk, -EBADMSG);
goto recv_end; goto recv_end;
} }
if (err == -EINPROGRESS) async |= darg.async;
async = true;
/* If the type of records being processed is not known yet, /* If the type of records being processed is not known yet,
* set it to record type just dequeued. If it is already known, * set it to record type just dequeued. If it is already known,