s390/bpf: Fix multiple tail calls

In order to branch around tail calls (due to out-of-bounds index,
exceeding tail call count or missing tail call target), JIT uses
label[0] field, which contains the address of the instruction following
the tail call. When there are multiple tail calls, label[0] value comes
from handling of a previous tail call, which is incorrect.

Fix by getting rid of label array and resolving the label address
locally: for all 3 branches that jump to it, emit 0 offsets at the
beginning, and then backpatch them with the correct value.

Also, do not use the long jump infrastructure: the tail call sequence
is known to be short, so make all 3 jumps short.

Fixes: 6651ee070b ("s390/bpf: implement bpf_tail_call() helper")
Signed-off-by: Ilya Leoshkevich <iii@linux.ibm.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Link: https://lore.kernel.org/bpf/20200909232141.3099367-1-iii@linux.ibm.com
This commit is contained in:
Ilya Leoshkevich 2020-09-10 01:21:41 +02:00 committed by Alexei Starovoitov
parent 2bab48c5be
commit d72714c1da
1 changed files with 27 additions and 34 deletions

View File

@ -50,7 +50,6 @@ struct bpf_jit {
int r14_thunk_ip; /* Address of expoline thunk for 'br %r14' */ int r14_thunk_ip; /* Address of expoline thunk for 'br %r14' */
int tail_call_start; /* Tail call start offset */ int tail_call_start; /* Tail call start offset */
int excnt; /* Number of exception table entries */ int excnt; /* Number of exception table entries */
int labels[1]; /* Labels for local jumps */
}; };
#define SEEN_MEM BIT(0) /* use mem[] for temporary storage */ #define SEEN_MEM BIT(0) /* use mem[] for temporary storage */
@ -229,18 +228,18 @@ static inline void reg_set_seen(struct bpf_jit *jit, u32 b1)
REG_SET_SEEN(b3); \ REG_SET_SEEN(b3); \
}) })
#define EMIT6_PCREL_LABEL(op1, op2, b1, b2, label, mask) \ #define EMIT6_PCREL_RIEB(op1, op2, b1, b2, mask, target) \
({ \ ({ \
int rel = (jit->labels[label] - jit->prg) >> 1; \ unsigned int rel = (int)((target) - jit->prg) / 2; \
_EMIT6((op1) | reg(b1, b2) << 16 | (rel & 0xffff), \ _EMIT6((op1) | reg(b1, b2) << 16 | (rel & 0xffff), \
(op2) | (mask) << 12); \ (op2) | (mask) << 12); \
REG_SET_SEEN(b1); \ REG_SET_SEEN(b1); \
REG_SET_SEEN(b2); \ REG_SET_SEEN(b2); \
}) })
#define EMIT6_PCREL_IMM_LABEL(op1, op2, b1, imm, label, mask) \ #define EMIT6_PCREL_RIEC(op1, op2, b1, imm, mask, target) \
({ \ ({ \
int rel = (jit->labels[label] - jit->prg) >> 1; \ unsigned int rel = (int)((target) - jit->prg) / 2; \
_EMIT6((op1) | (reg_high(b1) | (mask)) << 16 | \ _EMIT6((op1) | (reg_high(b1) | (mask)) << 16 | \
(rel & 0xffff), (op2) | ((imm) & 0xff) << 8); \ (rel & 0xffff), (op2) | ((imm) & 0xff) << 8); \
REG_SET_SEEN(b1); \ REG_SET_SEEN(b1); \
@ -1282,7 +1281,9 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp,
EMIT4(0xb9040000, BPF_REG_0, REG_2); EMIT4(0xb9040000, BPF_REG_0, REG_2);
break; break;
} }
case BPF_JMP | BPF_TAIL_CALL: case BPF_JMP | BPF_TAIL_CALL: {
int patch_1_clrj, patch_2_clij, patch_3_brc;
/* /*
* Implicit input: * Implicit input:
* B1: pointer to ctx * B1: pointer to ctx
@ -1300,16 +1301,10 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp,
EMIT6_DISP_LH(0xe3000000, 0x0016, REG_W1, REG_0, BPF_REG_2, EMIT6_DISP_LH(0xe3000000, 0x0016, REG_W1, REG_0, BPF_REG_2,
offsetof(struct bpf_array, map.max_entries)); offsetof(struct bpf_array, map.max_entries));
/* if ((u32)%b3 >= (u32)%w1) goto out; */ /* if ((u32)%b3 >= (u32)%w1) goto out; */
if (!is_first_pass(jit) && can_use_rel(jit, jit->labels[0])) { /* clrj %b3,%w1,0xa,out */
/* clrj %b3,%w1,0xa,label0 */ patch_1_clrj = jit->prg;
EMIT6_PCREL_LABEL(0xec000000, 0x0077, BPF_REG_3, EMIT6_PCREL_RIEB(0xec000000, 0x0077, BPF_REG_3, REG_W1, 0xa,
REG_W1, 0, 0xa); jit->prg);
} else {
/* clr %b3,%w1 */
EMIT2(0x1500, BPF_REG_3, REG_W1);
/* brcl 0xa,label0 */
EMIT6_PCREL_RILC(0xc0040000, 0xa, jit->labels[0]);
}
/* /*
* if (tail_call_cnt++ > MAX_TAIL_CALL_CNT) * if (tail_call_cnt++ > MAX_TAIL_CALL_CNT)
@ -1324,16 +1319,10 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp,
EMIT4_IMM(0xa7080000, REG_W0, 1); EMIT4_IMM(0xa7080000, REG_W0, 1);
/* laal %w1,%w0,off(%r15) */ /* laal %w1,%w0,off(%r15) */
EMIT6_DISP_LH(0xeb000000, 0x00fa, REG_W1, REG_W0, REG_15, off); EMIT6_DISP_LH(0xeb000000, 0x00fa, REG_W1, REG_W0, REG_15, off);
if (!is_first_pass(jit) && can_use_rel(jit, jit->labels[0])) { /* clij %w1,MAX_TAIL_CALL_CNT,0x2,out */
/* clij %w1,MAX_TAIL_CALL_CNT,0x2,label0 */ patch_2_clij = jit->prg;
EMIT6_PCREL_IMM_LABEL(0xec000000, 0x007f, REG_W1, EMIT6_PCREL_RIEC(0xec000000, 0x007f, REG_W1, MAX_TAIL_CALL_CNT,
MAX_TAIL_CALL_CNT, 0, 0x2); 2, jit->prg);
} else {
/* clfi %w1,MAX_TAIL_CALL_CNT */
EMIT6_IMM(0xc20f0000, REG_W1, MAX_TAIL_CALL_CNT);
/* brcl 0x2,label0 */
EMIT6_PCREL_RILC(0xc0040000, 0x2, jit->labels[0]);
}
/* /*
* prog = array->ptrs[index]; * prog = array->ptrs[index];
@ -1348,13 +1337,9 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp,
/* ltg %r1,prog(%b2,%r1) */ /* ltg %r1,prog(%b2,%r1) */
EMIT6_DISP_LH(0xe3000000, 0x0002, REG_1, BPF_REG_2, EMIT6_DISP_LH(0xe3000000, 0x0002, REG_1, BPF_REG_2,
REG_1, offsetof(struct bpf_array, ptrs)); REG_1, offsetof(struct bpf_array, ptrs));
if (!is_first_pass(jit) && can_use_rel(jit, jit->labels[0])) { /* brc 0x8,out */
/* brc 0x8,label0 */ patch_3_brc = jit->prg;
EMIT4_PCREL_RIC(0xa7040000, 0x8, jit->labels[0]); EMIT4_PCREL_RIC(0xa7040000, 8, jit->prg);
} else {
/* brcl 0x8,label0 */
EMIT6_PCREL_RILC(0xc0040000, 0x8, jit->labels[0]);
}
/* /*
* Restore registers before calling function * Restore registers before calling function
@ -1371,8 +1356,16 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp,
/* bc 0xf,tail_call_start(%r1) */ /* bc 0xf,tail_call_start(%r1) */
_EMIT4(0x47f01000 + jit->tail_call_start); _EMIT4(0x47f01000 + jit->tail_call_start);
/* out: */ /* out: */
jit->labels[0] = jit->prg; if (jit->prg_buf) {
*(u16 *)(jit->prg_buf + patch_1_clrj + 2) =
(jit->prg - patch_1_clrj) >> 1;
*(u16 *)(jit->prg_buf + patch_2_clij + 2) =
(jit->prg - patch_2_clij) >> 1;
*(u16 *)(jit->prg_buf + patch_3_brc + 2) =
(jit->prg - patch_3_brc) >> 1;
}
break; break;
}
case BPF_JMP | BPF_EXIT: /* return b0 */ case BPF_JMP | BPF_EXIT: /* return b0 */
last = (i == fp->len - 1) ? 1 : 0; last = (i == fp->len - 1) ? 1 : 0;
if (last) if (last)