crypto: arm64/aes-ce-gcm - implement 2-way aggregation

Implement a faster version of the GHASH transform which amortizes
the reduction modulo the characteristic polynomial across two
input blocks at a time.

On a Cortex-A53, the gcm(aes) performance increases 24%, from
3.0 cycles per byte to 2.4 cpb for large input sizes.

Signed-off-by: Ard Biesheuvel <ard.biesheuvel@linaro.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
commit e0bd888dc4 (parent 71e52c278c)
Author: Ard Biesheuvel, 2018-07-30 23:06:41 +02:00
Committer: Herbert Xu
2 changed files with 58 additions and 74 deletions
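In GHASH terms, the 2-way aggregation described in the commit message replaces the per-block update X = (X ^ C0) * H, X = (X ^ C1) * H with the two-block step X = (X ^ C0) * H^2 ^ C1 * H, so the reduction modulo the characteristic polynomial only has to be applied once per pair of blocks. A plain-C model of that identity using the kernel's generic GF(2^128) helpers (the NEON code below does not call these; ghash_update_2way() is an illustrative name, not part of this patch):

#include <crypto/gf128mul.h>

/*
 * Classic GHASH reduces after every multiplication:
 *     X = (X ^ C0) * H
 *     X = (X ^ C1) * H
 * With H^2 precomputed, two blocks can be folded with a single
 * (conceptual) reduction:
 *     X = (X ^ C0) * H^2  ^  C1 * H
 */
static void ghash_update_2way(be128 *x, const be128 *h, const be128 *h2,
			      const be128 *c0, const be128 *c1)
{
	be128 a = *x;
	be128 b = *c1;

	be128_xor(&a, &a, c0);		/* a = X ^ C0         */
	gf128mul_lle(&a, h2);		/* a = (X ^ C0) * H^2 */
	gf128mul_lle(&b, h);		/* b = C1 * H         */
	be128_xor(x, &a, &b);		/* X = a ^ b          */
}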

arch/arm64/crypto/ghash-ce-core.S

@@ -290,6 +290,10 @@ ENDPROC(pmull_ghash_update_p8)
 	KS1 .req v9
 	INP0 .req v10
 	INP1 .req v11
+	HH .req v12
+	XL2 .req v13
+	XM2 .req v14
+	XH2 .req v15
 	.macro load_round_keys, rounds, rk
 	cmp \rounds, #12
@@ -323,6 +327,7 @@ ENDPROC(pmull_ghash_update_p8)
 	.endm
 	.macro pmull_gcm_do_crypt, enc
+	ld1 {HH.2d}, [x4], #16
 	ld1 {SHASH.2d}, [x4]
 	ld1 {XL.2d}, [x1]
 	ldr x8, [x5, #8] // load lower counter
@@ -330,10 +335,11 @@ ENDPROC(pmull_ghash_update_p8)
 	load_round_keys w7, x6
 	movi MASK.16b, #0xe1
-	ext SHASH2.16b, SHASH.16b, SHASH.16b, #8
+	trn1 SHASH2.2d, SHASH.2d, HH.2d
+	trn2 T1.2d, SHASH.2d, HH.2d
 CPU_LE( rev x8, x8 )
 	shl MASK.2d, MASK.2d, #57
-	eor SHASH2.16b, SHASH2.16b, SHASH.16b
+	eor SHASH2.16b, SHASH2.16b, T1.16b
 	.if \enc == 1
 	ldr x10, [sp]
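The trn1/trn2/eor sequence above packs the Karatsuba "middle" factor for both keys into SHASH2: its low lane becomes SHASH.d[0] ^ SHASH.d[1] (the a0 ^ a1 factor for H) and its high lane HH.d[0] ^ HH.d[1] (the same factor for H^2), so the pmull/pmull2 instructions in the next hunk can pick the lane matching the power of H in use. A scalar sketch of the lane shuffling (vec2 and make_shash2 are illustrative, not kernel names):

#include <linux/types.h>

struct vec2 { u64 d[2]; };	/* d[0] = lane 0, d[1] = lane 1 */

/* Model of the SHASH2 setup: trn1/trn2 interleave the lanes of H and H^2,
 * the eor then leaves a0 ^ a1 for H in lane 0 and for H^2 in lane 1. */
static struct vec2 make_shash2(struct vec2 shash, struct vec2 hh)
{
	struct vec2 t1, shash2;

	shash2.d[0] = shash.d[0];	/* trn1 SHASH2.2d, SHASH.2d, HH.2d */
	shash2.d[1] = hh.d[0];
	t1.d[0] = shash.d[1];		/* trn2 T1.2d, SHASH.2d, HH.2d */
	t1.d[1] = hh.d[1];

	shash2.d[0] ^= t1.d[0];		/* eor SHASH2.16b, SHASH2.16b, T1.16b */
	shash2.d[1] ^= t1.d[1];
	return shash2;
}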
@@ -358,116 +364,82 @@ CPU_LE( rev x8, x8 )
 	ins KS0.d[1], x9 // set lower counter
 	ins KS1.d[1], x11
-	rev64 T1.16b, INP0.16b
+	rev64 T1.16b, INP1.16b
 	cmp w7, #12
 	b.ge 2f // AES-192/256?
 1:	enc_round KS0, v21
-	ext T2.16b, XL.16b, XL.16b, #8
 	ext IN1.16b, T1.16b, T1.16b, #8
 	enc_round KS1, v21
-	eor T1.16b, T1.16b, T2.16b
-	eor XL.16b, XL.16b, IN1.16b
+	pmull2 XH2.1q, SHASH.2d, IN1.2d // a1 * b1
 	enc_round KS0, v22
-	pmull2 XH.1q, SHASH.2d, XL.2d // a1 * b1
-	eor T1.16b, T1.16b, XL.16b
+	eor T1.16b, T1.16b, IN1.16b
 	enc_round KS1, v22
-	pmull XL.1q, SHASH.1d, XL.1d // a0 * b0
-	pmull XM.1q, SHASH2.1d, T1.1d // (a1 + a0)(b1 + b0)
+	pmull XL2.1q, SHASH.1d, IN1.1d // a0 * b0
 	enc_round KS0, v23
-	ext T1.16b, XL.16b, XH.16b, #8
-	eor T2.16b, XL.16b, XH.16b
-	eor XM.16b, XM.16b, T1.16b
+	pmull XM2.1q, SHASH2.1d, T1.1d // (a1 + a0)(b1 + b0)
 	enc_round KS1, v23
-	eor XM.16b, XM.16b, T2.16b
-	pmull T2.1q, XL.1d, MASK.1d
+	rev64 T1.16b, INP0.16b
+	ext T2.16b, XL.16b, XL.16b, #8
 	enc_round KS0, v24
-	mov XH.d[0], XM.d[1]
-	mov XM.d[1], XL.d[0]
+	ext IN1.16b, T1.16b, T1.16b, #8
+	eor T1.16b, T1.16b, T2.16b
 	enc_round KS1, v24
-	eor XL.16b, XM.16b, T2.16b
-	enc_round KS0, v25
-	ext T2.16b, XL.16b, XL.16b, #8
-	enc_round KS1, v25
-	pmull XL.1q, XL.1d, MASK.1d
-	eor T2.16b, T2.16b, XH.16b
-	enc_round KS0, v26
-	eor XL.16b, XL.16b, T2.16b
-	rev64 T1.16b, INP1.16b
-	enc_round KS1, v26
-	ext T2.16b, XL.16b, XL.16b, #8
-	ext IN1.16b, T1.16b, T1.16b, #8
-	enc_round KS0, v27
-	eor T1.16b, T1.16b, T2.16b
 	eor XL.16b, XL.16b, IN1.16b
-	enc_round KS1, v27
-	pmull2 XH.1q, SHASH.2d, XL.2d // a1 * b1
+	enc_round KS0, v25
 	eor T1.16b, T1.16b, XL.16b
-	enc_round KS0, v28
-	pmull XL.1q, SHASH.1d, XL.1d // a0 * b0
-	pmull XM.1q, SHASH2.1d, T1.1d // (a1 + a0)(b1 + b0)
-	enc_round KS1, v28
+	enc_round KS1, v25
+	pmull2 XH.1q, HH.2d, XL.2d // a1 * b1
+	enc_round KS0, v26
+	pmull XL.1q, HH.1d, XL.1d // a0 * b0
+	enc_round KS1, v26
+	pmull2 XM.1q, SHASH2.2d, T1.2d // (a1 + a0)(b1 + b0)
+	enc_round KS0, v27
+	eor XL.16b, XL.16b, XL2.16b
+	eor XH.16b, XH.16b, XH2.16b
+	enc_round KS1, v27
+	eor XM.16b, XM.16b, XM2.16b
 	ext T1.16b, XL.16b, XH.16b, #8
+	enc_round KS0, v28
 	eor T2.16b, XL.16b, XH.16b
 	eor XM.16b, XM.16b, T1.16b
-	enc_round KS0, v29
+	enc_round KS1, v28
 	eor XM.16b, XM.16b, T2.16b
+	enc_round KS0, v29
 	pmull T2.1q, XL.1d, MASK.1d
 	enc_round KS1, v29
 	mov XH.d[0], XM.d[1]
 	mov XM.d[1], XL.d[0]
 	aese KS0.16b, v30.16b
 	eor XL.16b, XM.16b, T2.16b
 	aese KS1.16b, v30.16b
 	ext T2.16b, XL.16b, XL.16b, #8
 	eor KS0.16b, KS0.16b, v31.16b
 	pmull XL.1q, XL.1d, MASK.1d
 	eor T2.16b, T2.16b, XH.16b
 	eor KS1.16b, KS1.16b, v31.16b
 	eor XL.16b, XL.16b, T2.16b
 	.if \enc == 0
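The "a1 * b1", "a0 * b0" and "(a1 + a0)(b1 + b0)" comments refer to the Karatsuba split of each 128x128 carry-less multiply into three 64x64 PMULLs. The products of INP1 with H accumulate unreduced in XL2/XM2/XH2 and are XORed into the products of INP0 (already folded with the running digest) with H^2, so the reduction against MASK runs once per pair of blocks. A rough scalar model of the three-multiply split (clmul64() stands in for PMULL and is not a kernel function):

#include <linux/types.h>

struct u128_t { u64 lo, hi; };

/* Bitwise carry-less 64x64 -> 128 multiply, standing in for PMULL. */
static struct u128_t clmul64(u64 a, u64 b)
{
	struct u128_t r = { 0, 0 };
	int i;

	for (i = 0; i < 64; i++) {
		if ((b >> i) & 1) {
			r.lo ^= a << i;
			if (i)
				r.hi ^= a >> (64 - i);
		}
	}
	return r;
}

/* Karatsuba split of (a1:a0) * (b1:b0); '^' is addition in GF(2).
 * Recombining the three partial products and reducing happens later. */
static void karatsuba3(u64 a0, u64 a1, u64 b0, u64 b1,
		       struct u128_t *xl, struct u128_t *xm, struct u128_t *xh)
{
	*xh = clmul64(a1, b1);			/* pmull2: a1 * b1             */
	*xl = clmul64(a0, b0);			/* pmull:  a0 * b0             */
	*xm = clmul64(a1 ^ a0, b1 ^ b0);	/* pmull:  (a1 + a0)(b1 + b0)  */
}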

arch/arm64/crypto/ghash-ce-glue.c

@@ -46,6 +46,7 @@ struct ghash_desc_ctx {
 struct gcm_aes_ctx {
 	struct crypto_aes_ctx aes_key;
+	u64 h2[2];
 	struct ghash_key ghash_key;
 };
@@ -62,12 +63,11 @@ static void (*pmull_ghash_update)(int blocks, u64 dg[], const char *src,
 				  const char *head);
 asmlinkage void pmull_gcm_encrypt(int blocks, u64 dg[], u8 dst[],
-				  const u8 src[], struct ghash_key const *k,
-				  u8 ctr[], u32 const rk[], int rounds,
-				  u8 ks[]);
+				  const u8 src[], u64 const *k, u8 ctr[],
+				  u32 const rk[], int rounds, u8 ks[]);
 asmlinkage void pmull_gcm_decrypt(int blocks, u64 dg[], u8 dst[],
-				  const u8 src[], struct ghash_key const *k,
+				  const u8 src[], u64 const *k,
 				  u8 ctr[], u32 const rk[], int rounds);
 asmlinkage void pmull_gcm_encrypt_block(u8 dst[], u8 const src[],
@@ -232,7 +232,8 @@ static int gcm_setkey(struct crypto_aead *tfm, const u8 *inkey,
 		      unsigned int keylen)
 {
 	struct gcm_aes_ctx *ctx = crypto_aead_ctx(tfm);
-	u8 key[GHASH_BLOCK_SIZE];
+	be128 h1, h2;
+	u8 *key = (u8 *)&h1;
 	int ret;
 	ret = crypto_aes_expand_key(&ctx->aes_key, inkey, keylen);
@@ -244,7 +245,19 @@ static int gcm_setkey(struct crypto_aead *tfm, const u8 *inkey,
 	__aes_arm64_encrypt(ctx->aes_key.key_enc, key, (u8[AES_BLOCK_SIZE]){},
 			    num_rounds(&ctx->aes_key));
-	return __ghash_setkey(&ctx->ghash_key, key, sizeof(key));
+	__ghash_setkey(&ctx->ghash_key, key, sizeof(be128));
+	/* calculate H^2 (used for 2-way aggregation) */
+	h2 = h1;
+	gf128mul_lle(&h2, &h1);
+	ctx->h2[0] = (be64_to_cpu(h2.b) << 1) | (be64_to_cpu(h2.a) >> 63);
+	ctx->h2[1] = (be64_to_cpu(h2.a) << 1) | (be64_to_cpu(h2.b) >> 63);
+	if (be64_to_cpu(h2.a) >> 63)
+		ctx->h2[1] ^= 0xc200000000000000UL;
+	return 0;
 }
 static int gcm_setauthsize(struct crypto_aead *tfm, unsigned int authsize)
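The ctx->h2[] assignments above store H^2 pre-shifted by one bit, with the GHASH reduction constant folded in when the top bit is set, mirroring what __ghash_setkey() already does for H so the PMULL code can use both keys the same way. The same transform written as a standalone helper for clarity (ghash_reflect() is a hypothetical name, not introduced by this patch):

#include <asm/byteorder.h>
#include <crypto/gf128mul.h>

/* Encode a GHASH key element (H, H^2, ...) for the arm64 PMULL routines,
 * exactly as done for ctx->h2[] in gcm_setkey() above. */
static void ghash_reflect(u64 k[2], const be128 *h)
{
	u64 a = be64_to_cpu(h->a);
	u64 b = be64_to_cpu(h->b);

	k[0] = (b << 1) | (a >> 63);
	k[1] = (a << 1) | (b >> 63);

	if (a >> 63)
		k[1] ^= 0xc200000000000000UL;
}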
@@ -378,9 +391,8 @@ static int gcm_encrypt(struct aead_request *req)
 		kernel_neon_begin();
 		pmull_gcm_encrypt(blocks, dg, walk.dst.virt.addr,
-				  walk.src.virt.addr, &ctx->ghash_key,
-				  iv, ctx->aes_key.key_enc, nrounds,
-				  ks);
+				  walk.src.virt.addr, ctx->h2, iv,
+				  ctx->aes_key.key_enc, nrounds, ks);
 		kernel_neon_end();
 		err = skcipher_walk_done(&walk,
@@ -486,8 +498,8 @@ static int gcm_decrypt(struct aead_request *req)
 		kernel_neon_begin();
 		pmull_gcm_decrypt(blocks, dg, walk.dst.virt.addr,
-				  walk.src.virt.addr, &ctx->ghash_key,
-				  iv, ctx->aes_key.key_enc, nrounds);
+				  walk.src.virt.addr, ctx->h2, iv,
+				  ctx->aes_key.key_enc, nrounds);
 		kernel_neon_end();
 		err = skcipher_walk_done(&walk,