diff --git a/arch/arm64/net/bpf_jit_comp.c b/arch/arm64/net/bpf_jit_comp.c index dd428807cb30..d0d51903c7e0 100644 --- a/arch/arm64/net/bpf_jit_comp.c +++ b/arch/arm64/net/bpf_jit_comp.c @@ -31,8 +31,8 @@ int bpf_jit_enable __read_mostly; -#define TMP_REG_1 (MAX_BPF_REG + 0) -#define TMP_REG_2 (MAX_BPF_REG + 1) +#define TMP_REG_1 (MAX_BPF_JIT_REG + 0) +#define TMP_REG_2 (MAX_BPF_JIT_REG + 1) /* Map BPF registers to A64 registers */ static const int bpf2a64[] = { @@ -54,6 +54,8 @@ static const int bpf2a64[] = { /* temporary register for internal BPF JIT */ [TMP_REG_1] = A64_R(23), [TMP_REG_2] = A64_R(24), + /* temporary register for blinding constants */ + [BPF_REG_AX] = A64_R(9), }; struct jit_ctx { @@ -764,26 +766,43 @@ void bpf_jit_compile(struct bpf_prog *prog) struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) { + struct bpf_prog *tmp, *orig_prog = prog; struct bpf_binary_header *header; + bool tmp_blinded = false; struct jit_ctx ctx; int image_size; u8 *image_ptr; if (!bpf_jit_enable) - return prog; + return orig_prog; + + tmp = bpf_jit_blind_constants(prog); + /* If blinding was requested and we failed during blinding, + * we must fall back to the interpreter. + */ + if (IS_ERR(tmp)) + return orig_prog; + if (tmp != prog) { + tmp_blinded = true; + prog = tmp; + } memset(&ctx, 0, sizeof(ctx)); ctx.prog = prog; ctx.offset = kcalloc(prog->len, sizeof(int), GFP_KERNEL); - if (ctx.offset == NULL) - return prog; + if (ctx.offset == NULL) { + prog = orig_prog; + goto out; + } /* 1. Initial fake pass to compute ctx->idx. */ /* Fake pass to fill in ctx->offset and ctx->tmp_used. */ - if (build_body(&ctx)) - goto out; + if (build_body(&ctx)) { + prog = orig_prog; + goto out_off; + } build_prologue(&ctx); @@ -794,8 +813,10 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) image_size = sizeof(u32) * ctx.idx; header = bpf_jit_binary_alloc(image_size, &image_ptr, sizeof(u32), jit_fill_hole); - if (header == NULL) - goto out; + if (header == NULL) { + prog = orig_prog; + goto out_off; + } /* 2. Now, the actual pass. */ @@ -806,7 +827,8 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) if (build_body(&ctx)) { bpf_jit_binary_free(header); - goto out; + prog = orig_prog; + goto out_off; } build_epilogue(&ctx); @@ -814,7 +836,8 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) /* 3. Extra pass to validate JITed code. */ if (validate_code(&ctx)) { bpf_jit_binary_free(header); - goto out; + prog = orig_prog; + goto out_off; } /* And we're done. */ @@ -826,8 +849,13 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) set_memory_ro((unsigned long)header, header->pages); prog->bpf_func = (void *)ctx.image; prog->jited = 1; -out: + +out_off: kfree(ctx.offset); +out: + if (tmp_blinded) + bpf_jit_prog_release_other(prog, prog == orig_prog ? + tmp : orig_prog); return prog; }