]> git.karo-electronics.de Git - karo-tx-linux.git/blob - arch/arm64/net/bpf_jit_comp.c
Merge tag 'armsoc-fixes' of git://git.kernel.org/pub/scm/linux/kernel/git/arm/arm-soc
[karo-tx-linux.git] / arch / arm64 / net / bpf_jit_comp.c
1 /*
2  * BPF JIT compiler for ARM64
3  *
4  * Copyright (C) 2014-2016 Zi Shen Lim <zlim.lnx@gmail.com>
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License version 2 as
8  * published by the Free Software Foundation.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17  */
18
19 #define pr_fmt(fmt) "bpf_jit: " fmt
20
21 #include <linux/bpf.h>
22 #include <linux/filter.h>
23 #include <linux/printk.h>
24 #include <linux/skbuff.h>
25 #include <linux/slab.h>
26
27 #include <asm/byteorder.h>
28 #include <asm/cacheflush.h>
29 #include <asm/debug-monitors.h>
30 #include <asm/set_memory.h>
31
32 #include "bpf_jit.h"
33
34 int bpf_jit_enable __read_mostly;
35
36 #define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
37 #define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
38 #define TCALL_CNT (MAX_BPF_JIT_REG + 2)
39 #define TMP_REG_3 (MAX_BPF_JIT_REG + 3)
40
41 /* Map BPF registers to A64 registers */
42 static const int bpf2a64[] = {
43         /* return value from in-kernel function, and exit value from eBPF */
44         [BPF_REG_0] = A64_R(7),
45         /* arguments from eBPF program to in-kernel function */
46         [BPF_REG_1] = A64_R(0),
47         [BPF_REG_2] = A64_R(1),
48         [BPF_REG_3] = A64_R(2),
49         [BPF_REG_4] = A64_R(3),
50         [BPF_REG_5] = A64_R(4),
51         /* callee saved registers that in-kernel function will preserve */
52         [BPF_REG_6] = A64_R(19),
53         [BPF_REG_7] = A64_R(20),
54         [BPF_REG_8] = A64_R(21),
55         [BPF_REG_9] = A64_R(22),
56         /* read-only frame pointer to access stack */
57         [BPF_REG_FP] = A64_R(25),
58         /* temporary registers for internal BPF JIT */
59         [TMP_REG_1] = A64_R(10),
60         [TMP_REG_2] = A64_R(11),
61         [TMP_REG_3] = A64_R(12),
62         /* tail_call_cnt */
63         [TCALL_CNT] = A64_R(26),
64         /* temporary register for blinding constants */
65         [BPF_REG_AX] = A64_R(9),
66 };
67
68 struct jit_ctx {
69         const struct bpf_prog *prog;
70         int idx;
71         int epilogue_offset;
72         int *offset;
73         u32 *image;
74 };
75
76 static inline void emit(const u32 insn, struct jit_ctx *ctx)
77 {
78         if (ctx->image != NULL)
79                 ctx->image[ctx->idx] = cpu_to_le32(insn);
80
81         ctx->idx++;
82 }
83
84 static inline void emit_a64_mov_i64(const int reg, const u64 val,
85                                     struct jit_ctx *ctx)
86 {
87         u64 tmp = val;
88         int shift = 0;
89
90         emit(A64_MOVZ(1, reg, tmp & 0xffff, shift), ctx);
91         tmp >>= 16;
92         shift += 16;
93         while (tmp) {
94                 if (tmp & 0xffff)
95                         emit(A64_MOVK(1, reg, tmp & 0xffff, shift), ctx);
96                 tmp >>= 16;
97                 shift += 16;
98         }
99 }
100
101 static inline void emit_a64_mov_i(const int is64, const int reg,
102                                   const s32 val, struct jit_ctx *ctx)
103 {
104         u16 hi = val >> 16;
105         u16 lo = val & 0xffff;
106
107         if (hi & 0x8000) {
108                 if (hi == 0xffff) {
109                         emit(A64_MOVN(is64, reg, (u16)~lo, 0), ctx);
110                 } else {
111                         emit(A64_MOVN(is64, reg, (u16)~hi, 16), ctx);
112                         emit(A64_MOVK(is64, reg, lo, 0), ctx);
113                 }
114         } else {
115                 emit(A64_MOVZ(is64, reg, lo, 0), ctx);
116                 if (hi)
117                         emit(A64_MOVK(is64, reg, hi, 16), ctx);
118         }
119 }
120
121 static inline int bpf2a64_offset(int bpf_to, int bpf_from,
122                                  const struct jit_ctx *ctx)
123 {
124         int to = ctx->offset[bpf_to];
125         /* -1 to account for the Branch instruction */
126         int from = ctx->offset[bpf_from] - 1;
127
128         return to - from;
129 }
130
131 static void jit_fill_hole(void *area, unsigned int size)
132 {
133         u32 *ptr;
134         /* We are guaranteed to have aligned memory. */
135         for (ptr = area; size >= sizeof(u32); size -= sizeof(u32))
136                 *ptr++ = cpu_to_le32(AARCH64_BREAK_FAULT);
137 }
138
139 static inline int epilogue_offset(const struct jit_ctx *ctx)
140 {
141         int to = ctx->epilogue_offset;
142         int from = ctx->idx;
143
144         return to - from;
145 }
146
147 /* Stack must be multiples of 16B */
148 #define STACK_ALIGN(sz) (((sz) + 15) & ~15)
149
150 #define _STACK_SIZE \
151         (MAX_BPF_STACK \
152          + 4 /* extra for skb_copy_bits buffer */)
153
154 #define STACK_SIZE STACK_ALIGN(_STACK_SIZE)
155
156 #define PROLOGUE_OFFSET 8
157
158 static int build_prologue(struct jit_ctx *ctx)
159 {
160         const u8 r6 = bpf2a64[BPF_REG_6];
161         const u8 r7 = bpf2a64[BPF_REG_7];
162         const u8 r8 = bpf2a64[BPF_REG_8];
163         const u8 r9 = bpf2a64[BPF_REG_9];
164         const u8 fp = bpf2a64[BPF_REG_FP];
165         const u8 tcc = bpf2a64[TCALL_CNT];
166         const int idx0 = ctx->idx;
167         int cur_offset;
168
169         /*
170          * BPF prog stack layout
171          *
172          *                         high
173          * original A64_SP =>   0:+-----+ BPF prologue
174          *                        |FP/LR|
175          * current A64_FP =>  -16:+-----+
176          *                        | ... | callee saved registers
177          * BPF fp register => -64:+-----+ <= (BPF_FP)
178          *                        |     |
179          *                        | ... | BPF prog stack
180          *                        |     |
181          *                        +-----+ <= (BPF_FP - MAX_BPF_STACK)
182          *                        |RSVD | JIT scratchpad
183          * current A64_SP =>      +-----+ <= (BPF_FP - STACK_SIZE)
184          *                        |     |
185          *                        | ... | Function call stack
186          *                        |     |
187          *                        +-----+
188          *                          low
189          *
190          */
191
192         /* Save FP and LR registers to stay align with ARM64 AAPCS */
193         emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
194         emit(A64_MOV(1, A64_FP, A64_SP), ctx);
195
196         /* Save callee-saved registers */
197         emit(A64_PUSH(r6, r7, A64_SP), ctx);
198         emit(A64_PUSH(r8, r9, A64_SP), ctx);
199         emit(A64_PUSH(fp, tcc, A64_SP), ctx);
200
201         /* Set up BPF prog stack base register */
202         emit(A64_MOV(1, fp, A64_SP), ctx);
203
204         /* Initialize tail_call_cnt */
205         emit(A64_MOVZ(1, tcc, 0, 0), ctx);
206
207         /* Set up function call stack */
208         emit(A64_SUB_I(1, A64_SP, A64_SP, STACK_SIZE), ctx);
209
210         cur_offset = ctx->idx - idx0;
211         if (cur_offset != PROLOGUE_OFFSET) {
212                 pr_err_once("PROLOGUE_OFFSET = %d, expected %d!\n",
213                             cur_offset, PROLOGUE_OFFSET);
214                 return -1;
215         }
216         return 0;
217 }
218
219 static int out_offset = -1; /* initialized on the first pass of build_body() */
220 static int emit_bpf_tail_call(struct jit_ctx *ctx)
221 {
222         /* bpf_tail_call(void *prog_ctx, struct bpf_array *array, u64 index) */
223         const u8 r2 = bpf2a64[BPF_REG_2];
224         const u8 r3 = bpf2a64[BPF_REG_3];
225
226         const u8 tmp = bpf2a64[TMP_REG_1];
227         const u8 prg = bpf2a64[TMP_REG_2];
228         const u8 tcc = bpf2a64[TCALL_CNT];
229         const int idx0 = ctx->idx;
230 #define cur_offset (ctx->idx - idx0)
231 #define jmp_offset (out_offset - (cur_offset))
232         size_t off;
233
234         /* if (index >= array->map.max_entries)
235          *     goto out;
236          */
237         off = offsetof(struct bpf_array, map.max_entries);
238         emit_a64_mov_i64(tmp, off, ctx);
239         emit(A64_LDR32(tmp, r2, tmp), ctx);
240         emit(A64_CMP(0, r3, tmp), ctx);
241         emit(A64_B_(A64_COND_GE, jmp_offset), ctx);
242
243         /* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
244          *     goto out;
245          * tail_call_cnt++;
246          */
247         emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
248         emit(A64_CMP(1, tcc, tmp), ctx);
249         emit(A64_B_(A64_COND_GT, jmp_offset), ctx);
250         emit(A64_ADD_I(1, tcc, tcc, 1), ctx);
251
252         /* prog = array->ptrs[index];
253          * if (prog == NULL)
254          *     goto out;
255          */
256         off = offsetof(struct bpf_array, ptrs);
257         emit_a64_mov_i64(tmp, off, ctx);
258         emit(A64_ADD(1, tmp, r2, tmp), ctx);
259         emit(A64_LSL(1, prg, r3, 3), ctx);
260         emit(A64_LDR64(prg, tmp, prg), ctx);
261         emit(A64_CBZ(1, prg, jmp_offset), ctx);
262
263         /* goto *(prog->bpf_func + prologue_size); */
264         off = offsetof(struct bpf_prog, bpf_func);
265         emit_a64_mov_i64(tmp, off, ctx);
266         emit(A64_LDR64(tmp, prg, tmp), ctx);
267         emit(A64_ADD_I(1, tmp, tmp, sizeof(u32) * PROLOGUE_OFFSET), ctx);
268         emit(A64_BR(tmp), ctx);
269
270         /* out: */
271         if (out_offset == -1)
272                 out_offset = cur_offset;
273         if (cur_offset != out_offset) {
274                 pr_err_once("tail_call out_offset = %d, expected %d!\n",
275                             cur_offset, out_offset);
276                 return -1;
277         }
278         return 0;
279 #undef cur_offset
280 #undef jmp_offset
281 }
282
283 static void build_epilogue(struct jit_ctx *ctx)
284 {
285         const u8 r0 = bpf2a64[BPF_REG_0];
286         const u8 r6 = bpf2a64[BPF_REG_6];
287         const u8 r7 = bpf2a64[BPF_REG_7];
288         const u8 r8 = bpf2a64[BPF_REG_8];
289         const u8 r9 = bpf2a64[BPF_REG_9];
290         const u8 fp = bpf2a64[BPF_REG_FP];
291
292         /* We're done with BPF stack */
293         emit(A64_ADD_I(1, A64_SP, A64_SP, STACK_SIZE), ctx);
294
295         /* Restore fs (x25) and x26 */
296         emit(A64_POP(fp, A64_R(26), A64_SP), ctx);
297
298         /* Restore callee-saved register */
299         emit(A64_POP(r8, r9, A64_SP), ctx);
300         emit(A64_POP(r6, r7, A64_SP), ctx);
301
302         /* Restore FP/LR registers */
303         emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
304
305         /* Set return value */
306         emit(A64_MOV(1, A64_R(0), r0), ctx);
307
308         emit(A64_RET(A64_LR), ctx);
309 }
310
311 /* JITs an eBPF instruction.
312  * Returns:
313  * 0  - successfully JITed an 8-byte eBPF instruction.
314  * >0 - successfully JITed a 16-byte eBPF instruction.
315  * <0 - failed to JIT.
316  */
317 static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx)
318 {
319         const u8 code = insn->code;
320         const u8 dst = bpf2a64[insn->dst_reg];
321         const u8 src = bpf2a64[insn->src_reg];
322         const u8 tmp = bpf2a64[TMP_REG_1];
323         const u8 tmp2 = bpf2a64[TMP_REG_2];
324         const u8 tmp3 = bpf2a64[TMP_REG_3];
325         const s16 off = insn->off;
326         const s32 imm = insn->imm;
327         const int i = insn - ctx->prog->insnsi;
328         const bool is64 = BPF_CLASS(code) == BPF_ALU64;
329         const bool isdw = BPF_SIZE(code) == BPF_DW;
330         u8 jmp_cond;
331         s32 jmp_offset;
332
333 #define check_imm(bits, imm) do {                               \
334         if ((((imm) > 0) && ((imm) >> (bits))) ||               \
335             (((imm) < 0) && (~(imm) >> (bits)))) {              \
336                 pr_info("[%2d] imm=%d(0x%x) out of range\n",    \
337                         i, imm, imm);                           \
338                 return -EINVAL;                                 \
339         }                                                       \
340 } while (0)
341 #define check_imm19(imm) check_imm(19, imm)
342 #define check_imm26(imm) check_imm(26, imm)
343
344         switch (code) {
345         /* dst = src */
346         case BPF_ALU | BPF_MOV | BPF_X:
347         case BPF_ALU64 | BPF_MOV | BPF_X:
348                 emit(A64_MOV(is64, dst, src), ctx);
349                 break;
350         /* dst = dst OP src */
351         case BPF_ALU | BPF_ADD | BPF_X:
352         case BPF_ALU64 | BPF_ADD | BPF_X:
353                 emit(A64_ADD(is64, dst, dst, src), ctx);
354                 break;
355         case BPF_ALU | BPF_SUB | BPF_X:
356         case BPF_ALU64 | BPF_SUB | BPF_X:
357                 emit(A64_SUB(is64, dst, dst, src), ctx);
358                 break;
359         case BPF_ALU | BPF_AND | BPF_X:
360         case BPF_ALU64 | BPF_AND | BPF_X:
361                 emit(A64_AND(is64, dst, dst, src), ctx);
362                 break;
363         case BPF_ALU | BPF_OR | BPF_X:
364         case BPF_ALU64 | BPF_OR | BPF_X:
365                 emit(A64_ORR(is64, dst, dst, src), ctx);
366                 break;
367         case BPF_ALU | BPF_XOR | BPF_X:
368         case BPF_ALU64 | BPF_XOR | BPF_X:
369                 emit(A64_EOR(is64, dst, dst, src), ctx);
370                 break;
371         case BPF_ALU | BPF_MUL | BPF_X:
372         case BPF_ALU64 | BPF_MUL | BPF_X:
373                 emit(A64_MUL(is64, dst, dst, src), ctx);
374                 break;
375         case BPF_ALU | BPF_DIV | BPF_X:
376         case BPF_ALU64 | BPF_DIV | BPF_X:
377         case BPF_ALU | BPF_MOD | BPF_X:
378         case BPF_ALU64 | BPF_MOD | BPF_X:
379         {
380                 const u8 r0 = bpf2a64[BPF_REG_0];
381
382                 /* if (src == 0) return 0 */
383                 jmp_offset = 3; /* skip ahead to else path */
384                 check_imm19(jmp_offset);
385                 emit(A64_CBNZ(is64, src, jmp_offset), ctx);
386                 emit(A64_MOVZ(1, r0, 0, 0), ctx);
387                 jmp_offset = epilogue_offset(ctx);
388                 check_imm26(jmp_offset);
389                 emit(A64_B(jmp_offset), ctx);
390                 /* else */
391                 switch (BPF_OP(code)) {
392                 case BPF_DIV:
393                         emit(A64_UDIV(is64, dst, dst, src), ctx);
394                         break;
395                 case BPF_MOD:
396                         emit(A64_UDIV(is64, tmp, dst, src), ctx);
397                         emit(A64_MUL(is64, tmp, tmp, src), ctx);
398                         emit(A64_SUB(is64, dst, dst, tmp), ctx);
399                         break;
400                 }
401                 break;
402         }
403         case BPF_ALU | BPF_LSH | BPF_X:
404         case BPF_ALU64 | BPF_LSH | BPF_X:
405                 emit(A64_LSLV(is64, dst, dst, src), ctx);
406                 break;
407         case BPF_ALU | BPF_RSH | BPF_X:
408         case BPF_ALU64 | BPF_RSH | BPF_X:
409                 emit(A64_LSRV(is64, dst, dst, src), ctx);
410                 break;
411         case BPF_ALU | BPF_ARSH | BPF_X:
412         case BPF_ALU64 | BPF_ARSH | BPF_X:
413                 emit(A64_ASRV(is64, dst, dst, src), ctx);
414                 break;
415         /* dst = -dst */
416         case BPF_ALU | BPF_NEG:
417         case BPF_ALU64 | BPF_NEG:
418                 emit(A64_NEG(is64, dst, dst), ctx);
419                 break;
420         /* dst = BSWAP##imm(dst) */
421         case BPF_ALU | BPF_END | BPF_FROM_LE:
422         case BPF_ALU | BPF_END | BPF_FROM_BE:
423 #ifdef CONFIG_CPU_BIG_ENDIAN
424                 if (BPF_SRC(code) == BPF_FROM_BE)
425                         goto emit_bswap_uxt;
426 #else /* !CONFIG_CPU_BIG_ENDIAN */
427                 if (BPF_SRC(code) == BPF_FROM_LE)
428                         goto emit_bswap_uxt;
429 #endif
430                 switch (imm) {
431                 case 16:
432                         emit(A64_REV16(is64, dst, dst), ctx);
433                         /* zero-extend 16 bits into 64 bits */
434                         emit(A64_UXTH(is64, dst, dst), ctx);
435                         break;
436                 case 32:
437                         emit(A64_REV32(is64, dst, dst), ctx);
438                         /* upper 32 bits already cleared */
439                         break;
440                 case 64:
441                         emit(A64_REV64(dst, dst), ctx);
442                         break;
443                 }
444                 break;
445 emit_bswap_uxt:
446                 switch (imm) {
447                 case 16:
448                         /* zero-extend 16 bits into 64 bits */
449                         emit(A64_UXTH(is64, dst, dst), ctx);
450                         break;
451                 case 32:
452                         /* zero-extend 32 bits into 64 bits */
453                         emit(A64_UXTW(is64, dst, dst), ctx);
454                         break;
455                 case 64:
456                         /* nop */
457                         break;
458                 }
459                 break;
460         /* dst = imm */
461         case BPF_ALU | BPF_MOV | BPF_K:
462         case BPF_ALU64 | BPF_MOV | BPF_K:
463                 emit_a64_mov_i(is64, dst, imm, ctx);
464                 break;
465         /* dst = dst OP imm */
466         case BPF_ALU | BPF_ADD | BPF_K:
467         case BPF_ALU64 | BPF_ADD | BPF_K:
468                 emit_a64_mov_i(is64, tmp, imm, ctx);
469                 emit(A64_ADD(is64, dst, dst, tmp), ctx);
470                 break;
471         case BPF_ALU | BPF_SUB | BPF_K:
472         case BPF_ALU64 | BPF_SUB | BPF_K:
473                 emit_a64_mov_i(is64, tmp, imm, ctx);
474                 emit(A64_SUB(is64, dst, dst, tmp), ctx);
475                 break;
476         case BPF_ALU | BPF_AND | BPF_K:
477         case BPF_ALU64 | BPF_AND | BPF_K:
478                 emit_a64_mov_i(is64, tmp, imm, ctx);
479                 emit(A64_AND(is64, dst, dst, tmp), ctx);
480                 break;
481         case BPF_ALU | BPF_OR | BPF_K:
482         case BPF_ALU64 | BPF_OR | BPF_K:
483                 emit_a64_mov_i(is64, tmp, imm, ctx);
484                 emit(A64_ORR(is64, dst, dst, tmp), ctx);
485                 break;
486         case BPF_ALU | BPF_XOR | BPF_K:
487         case BPF_ALU64 | BPF_XOR | BPF_K:
488                 emit_a64_mov_i(is64, tmp, imm, ctx);
489                 emit(A64_EOR(is64, dst, dst, tmp), ctx);
490                 break;
491         case BPF_ALU | BPF_MUL | BPF_K:
492         case BPF_ALU64 | BPF_MUL | BPF_K:
493                 emit_a64_mov_i(is64, tmp, imm, ctx);
494                 emit(A64_MUL(is64, dst, dst, tmp), ctx);
495                 break;
496         case BPF_ALU | BPF_DIV | BPF_K:
497         case BPF_ALU64 | BPF_DIV | BPF_K:
498                 emit_a64_mov_i(is64, tmp, imm, ctx);
499                 emit(A64_UDIV(is64, dst, dst, tmp), ctx);
500                 break;
501         case BPF_ALU | BPF_MOD | BPF_K:
502         case BPF_ALU64 | BPF_MOD | BPF_K:
503                 emit_a64_mov_i(is64, tmp2, imm, ctx);
504                 emit(A64_UDIV(is64, tmp, dst, tmp2), ctx);
505                 emit(A64_MUL(is64, tmp, tmp, tmp2), ctx);
506                 emit(A64_SUB(is64, dst, dst, tmp), ctx);
507                 break;
508         case BPF_ALU | BPF_LSH | BPF_K:
509         case BPF_ALU64 | BPF_LSH | BPF_K:
510                 emit(A64_LSL(is64, dst, dst, imm), ctx);
511                 break;
512         case BPF_ALU | BPF_RSH | BPF_K:
513         case BPF_ALU64 | BPF_RSH | BPF_K:
514                 emit(A64_LSR(is64, dst, dst, imm), ctx);
515                 break;
516         case BPF_ALU | BPF_ARSH | BPF_K:
517         case BPF_ALU64 | BPF_ARSH | BPF_K:
518                 emit(A64_ASR(is64, dst, dst, imm), ctx);
519                 break;
520
521         /* JUMP off */
522         case BPF_JMP | BPF_JA:
523                 jmp_offset = bpf2a64_offset(i + off, i, ctx);
524                 check_imm26(jmp_offset);
525                 emit(A64_B(jmp_offset), ctx);
526                 break;
527         /* IF (dst COND src) JUMP off */
528         case BPF_JMP | BPF_JEQ | BPF_X:
529         case BPF_JMP | BPF_JGT | BPF_X:
530         case BPF_JMP | BPF_JGE | BPF_X:
531         case BPF_JMP | BPF_JNE | BPF_X:
532         case BPF_JMP | BPF_JSGT | BPF_X:
533         case BPF_JMP | BPF_JSGE | BPF_X:
534                 emit(A64_CMP(1, dst, src), ctx);
535 emit_cond_jmp:
536                 jmp_offset = bpf2a64_offset(i + off, i, ctx);
537                 check_imm19(jmp_offset);
538                 switch (BPF_OP(code)) {
539                 case BPF_JEQ:
540                         jmp_cond = A64_COND_EQ;
541                         break;
542                 case BPF_JGT:
543                         jmp_cond = A64_COND_HI;
544                         break;
545                 case BPF_JGE:
546                         jmp_cond = A64_COND_CS;
547                         break;
548                 case BPF_JSET:
549                 case BPF_JNE:
550                         jmp_cond = A64_COND_NE;
551                         break;
552                 case BPF_JSGT:
553                         jmp_cond = A64_COND_GT;
554                         break;
555                 case BPF_JSGE:
556                         jmp_cond = A64_COND_GE;
557                         break;
558                 default:
559                         return -EFAULT;
560                 }
561                 emit(A64_B_(jmp_cond, jmp_offset), ctx);
562                 break;
563         case BPF_JMP | BPF_JSET | BPF_X:
564                 emit(A64_TST(1, dst, src), ctx);
565                 goto emit_cond_jmp;
566         /* IF (dst COND imm) JUMP off */
567         case BPF_JMP | BPF_JEQ | BPF_K:
568         case BPF_JMP | BPF_JGT | BPF_K:
569         case BPF_JMP | BPF_JGE | BPF_K:
570         case BPF_JMP | BPF_JNE | BPF_K:
571         case BPF_JMP | BPF_JSGT | BPF_K:
572         case BPF_JMP | BPF_JSGE | BPF_K:
573                 emit_a64_mov_i(1, tmp, imm, ctx);
574                 emit(A64_CMP(1, dst, tmp), ctx);
575                 goto emit_cond_jmp;
576         case BPF_JMP | BPF_JSET | BPF_K:
577                 emit_a64_mov_i(1, tmp, imm, ctx);
578                 emit(A64_TST(1, dst, tmp), ctx);
579                 goto emit_cond_jmp;
580         /* function call */
581         case BPF_JMP | BPF_CALL:
582         {
583                 const u8 r0 = bpf2a64[BPF_REG_0];
584                 const u64 func = (u64)__bpf_call_base + imm;
585
586                 emit_a64_mov_i64(tmp, func, ctx);
587                 emit(A64_BLR(tmp), ctx);
588                 emit(A64_MOV(1, r0, A64_R(0)), ctx);
589                 break;
590         }
591         /* tail call */
592         case BPF_JMP | BPF_CALL | BPF_X:
593                 if (emit_bpf_tail_call(ctx))
594                         return -EFAULT;
595                 break;
596         /* function return */
597         case BPF_JMP | BPF_EXIT:
598                 /* Optimization: when last instruction is EXIT,
599                    simply fallthrough to epilogue. */
600                 if (i == ctx->prog->len - 1)
601                         break;
602                 jmp_offset = epilogue_offset(ctx);
603                 check_imm26(jmp_offset);
604                 emit(A64_B(jmp_offset), ctx);
605                 break;
606
607         /* dst = imm64 */
608         case BPF_LD | BPF_IMM | BPF_DW:
609         {
610                 const struct bpf_insn insn1 = insn[1];
611                 u64 imm64;
612
613                 imm64 = (u64)insn1.imm << 32 | (u32)imm;
614                 emit_a64_mov_i64(dst, imm64, ctx);
615
616                 return 1;
617         }
618
619         /* LDX: dst = *(size *)(src + off) */
620         case BPF_LDX | BPF_MEM | BPF_W:
621         case BPF_LDX | BPF_MEM | BPF_H:
622         case BPF_LDX | BPF_MEM | BPF_B:
623         case BPF_LDX | BPF_MEM | BPF_DW:
624                 emit_a64_mov_i(1, tmp, off, ctx);
625                 switch (BPF_SIZE(code)) {
626                 case BPF_W:
627                         emit(A64_LDR32(dst, src, tmp), ctx);
628                         break;
629                 case BPF_H:
630                         emit(A64_LDRH(dst, src, tmp), ctx);
631                         break;
632                 case BPF_B:
633                         emit(A64_LDRB(dst, src, tmp), ctx);
634                         break;
635                 case BPF_DW:
636                         emit(A64_LDR64(dst, src, tmp), ctx);
637                         break;
638                 }
639                 break;
640
641         /* ST: *(size *)(dst + off) = imm */
642         case BPF_ST | BPF_MEM | BPF_W:
643         case BPF_ST | BPF_MEM | BPF_H:
644         case BPF_ST | BPF_MEM | BPF_B:
645         case BPF_ST | BPF_MEM | BPF_DW:
646                 /* Load imm to a register then store it */
647                 emit_a64_mov_i(1, tmp2, off, ctx);
648                 emit_a64_mov_i(1, tmp, imm, ctx);
649                 switch (BPF_SIZE(code)) {
650                 case BPF_W:
651                         emit(A64_STR32(tmp, dst, tmp2), ctx);
652                         break;
653                 case BPF_H:
654                         emit(A64_STRH(tmp, dst, tmp2), ctx);
655                         break;
656                 case BPF_B:
657                         emit(A64_STRB(tmp, dst, tmp2), ctx);
658                         break;
659                 case BPF_DW:
660                         emit(A64_STR64(tmp, dst, tmp2), ctx);
661                         break;
662                 }
663                 break;
664
665         /* STX: *(size *)(dst + off) = src */
666         case BPF_STX | BPF_MEM | BPF_W:
667         case BPF_STX | BPF_MEM | BPF_H:
668         case BPF_STX | BPF_MEM | BPF_B:
669         case BPF_STX | BPF_MEM | BPF_DW:
670                 emit_a64_mov_i(1, tmp, off, ctx);
671                 switch (BPF_SIZE(code)) {
672                 case BPF_W:
673                         emit(A64_STR32(src, dst, tmp), ctx);
674                         break;
675                 case BPF_H:
676                         emit(A64_STRH(src, dst, tmp), ctx);
677                         break;
678                 case BPF_B:
679                         emit(A64_STRB(src, dst, tmp), ctx);
680                         break;
681                 case BPF_DW:
682                         emit(A64_STR64(src, dst, tmp), ctx);
683                         break;
684                 }
685                 break;
686         /* STX XADD: lock *(u32 *)(dst + off) += src */
687         case BPF_STX | BPF_XADD | BPF_W:
688         /* STX XADD: lock *(u64 *)(dst + off) += src */
689         case BPF_STX | BPF_XADD | BPF_DW:
690                 emit_a64_mov_i(1, tmp, off, ctx);
691                 emit(A64_ADD(1, tmp, tmp, dst), ctx);
692                 emit(A64_PRFM(tmp, PST, L1, STRM), ctx);
693                 emit(A64_LDXR(isdw, tmp2, tmp), ctx);
694                 emit(A64_ADD(isdw, tmp2, tmp2, src), ctx);
695                 emit(A64_STXR(isdw, tmp2, tmp, tmp3), ctx);
696                 jmp_offset = -3;
697                 check_imm19(jmp_offset);
698                 emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
699                 break;
700
701         /* R0 = ntohx(*(size *)(((struct sk_buff *)R6)->data + imm)) */
702         case BPF_LD | BPF_ABS | BPF_W:
703         case BPF_LD | BPF_ABS | BPF_H:
704         case BPF_LD | BPF_ABS | BPF_B:
705         /* R0 = ntohx(*(size *)(((struct sk_buff *)R6)->data + src + imm)) */
706         case BPF_LD | BPF_IND | BPF_W:
707         case BPF_LD | BPF_IND | BPF_H:
708         case BPF_LD | BPF_IND | BPF_B:
709         {
710                 const u8 r0 = bpf2a64[BPF_REG_0]; /* r0 = return value */
711                 const u8 r6 = bpf2a64[BPF_REG_6]; /* r6 = pointer to sk_buff */
712                 const u8 fp = bpf2a64[BPF_REG_FP];
713                 const u8 r1 = bpf2a64[BPF_REG_1]; /* r1: struct sk_buff *skb */
714                 const u8 r2 = bpf2a64[BPF_REG_2]; /* r2: int k */
715                 const u8 r3 = bpf2a64[BPF_REG_3]; /* r3: unsigned int size */
716                 const u8 r4 = bpf2a64[BPF_REG_4]; /* r4: void *buffer */
717                 const u8 r5 = bpf2a64[BPF_REG_5]; /* r5: void *(*func)(...) */
718                 int size;
719
720                 emit(A64_MOV(1, r1, r6), ctx);
721                 emit_a64_mov_i(0, r2, imm, ctx);
722                 if (BPF_MODE(code) == BPF_IND)
723                         emit(A64_ADD(0, r2, r2, src), ctx);
724                 switch (BPF_SIZE(code)) {
725                 case BPF_W:
726                         size = 4;
727                         break;
728                 case BPF_H:
729                         size = 2;
730                         break;
731                 case BPF_B:
732                         size = 1;
733                         break;
734                 default:
735                         return -EINVAL;
736                 }
737                 emit_a64_mov_i64(r3, size, ctx);
738                 emit(A64_SUB_I(1, r4, fp, STACK_SIZE), ctx);
739                 emit_a64_mov_i64(r5, (unsigned long)bpf_load_pointer, ctx);
740                 emit(A64_BLR(r5), ctx);
741                 emit(A64_MOV(1, r0, A64_R(0)), ctx);
742
743                 jmp_offset = epilogue_offset(ctx);
744                 check_imm19(jmp_offset);
745                 emit(A64_CBZ(1, r0, jmp_offset), ctx);
746                 emit(A64_MOV(1, r5, r0), ctx);
747                 switch (BPF_SIZE(code)) {
748                 case BPF_W:
749                         emit(A64_LDR32(r0, r5, A64_ZR), ctx);
750 #ifndef CONFIG_CPU_BIG_ENDIAN
751                         emit(A64_REV32(0, r0, r0), ctx);
752 #endif
753                         break;
754                 case BPF_H:
755                         emit(A64_LDRH(r0, r5, A64_ZR), ctx);
756 #ifndef CONFIG_CPU_BIG_ENDIAN
757                         emit(A64_REV16(0, r0, r0), ctx);
758 #endif
759                         break;
760                 case BPF_B:
761                         emit(A64_LDRB(r0, r5, A64_ZR), ctx);
762                         break;
763                 }
764                 break;
765         }
766         default:
767                 pr_err_once("unknown opcode %02x\n", code);
768                 return -EINVAL;
769         }
770
771         return 0;
772 }
773
774 static int build_body(struct jit_ctx *ctx)
775 {
776         const struct bpf_prog *prog = ctx->prog;
777         int i;
778
779         for (i = 0; i < prog->len; i++) {
780                 const struct bpf_insn *insn = &prog->insnsi[i];
781                 int ret;
782
783                 ret = build_insn(insn, ctx);
784                 if (ret > 0) {
785                         i++;
786                         if (ctx->image == NULL)
787                                 ctx->offset[i] = ctx->idx;
788                         continue;
789                 }
790                 if (ctx->image == NULL)
791                         ctx->offset[i] = ctx->idx;
792                 if (ret)
793                         return ret;
794         }
795
796         return 0;
797 }
798
799 static int validate_code(struct jit_ctx *ctx)
800 {
801         int i;
802
803         for (i = 0; i < ctx->idx; i++) {
804                 u32 a64_insn = le32_to_cpu(ctx->image[i]);
805
806                 if (a64_insn == AARCH64_BREAK_FAULT)
807                         return -1;
808         }
809
810         return 0;
811 }
812
813 static inline void bpf_flush_icache(void *start, void *end)
814 {
815         flush_icache_range((unsigned long)start, (unsigned long)end);
816 }
817
818 struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
819 {
820         struct bpf_prog *tmp, *orig_prog = prog;
821         struct bpf_binary_header *header;
822         bool tmp_blinded = false;
823         struct jit_ctx ctx;
824         int image_size;
825         u8 *image_ptr;
826
827         if (!bpf_jit_enable)
828                 return orig_prog;
829
830         tmp = bpf_jit_blind_constants(prog);
831         /* If blinding was requested and we failed during blinding,
832          * we must fall back to the interpreter.
833          */
834         if (IS_ERR(tmp))
835                 return orig_prog;
836         if (tmp != prog) {
837                 tmp_blinded = true;
838                 prog = tmp;
839         }
840
841         memset(&ctx, 0, sizeof(ctx));
842         ctx.prog = prog;
843
844         ctx.offset = kcalloc(prog->len, sizeof(int), GFP_KERNEL);
845         if (ctx.offset == NULL) {
846                 prog = orig_prog;
847                 goto out;
848         }
849
850         /* 1. Initial fake pass to compute ctx->idx. */
851
852         /* Fake pass to fill in ctx->offset. */
853         if (build_body(&ctx)) {
854                 prog = orig_prog;
855                 goto out_off;
856         }
857
858         if (build_prologue(&ctx)) {
859                 prog = orig_prog;
860                 goto out_off;
861         }
862
863         ctx.epilogue_offset = ctx.idx;
864         build_epilogue(&ctx);
865
866         /* Now we know the actual image size. */
867         image_size = sizeof(u32) * ctx.idx;
868         header = bpf_jit_binary_alloc(image_size, &image_ptr,
869                                       sizeof(u32), jit_fill_hole);
870         if (header == NULL) {
871                 prog = orig_prog;
872                 goto out_off;
873         }
874
875         /* 2. Now, the actual pass. */
876
877         ctx.image = (u32 *)image_ptr;
878         ctx.idx = 0;
879
880         build_prologue(&ctx);
881
882         if (build_body(&ctx)) {
883                 bpf_jit_binary_free(header);
884                 prog = orig_prog;
885                 goto out_off;
886         }
887
888         build_epilogue(&ctx);
889
890         /* 3. Extra pass to validate JITed code. */
891         if (validate_code(&ctx)) {
892                 bpf_jit_binary_free(header);
893                 prog = orig_prog;
894                 goto out_off;
895         }
896
897         /* And we're done. */
898         if (bpf_jit_enable > 1)
899                 bpf_jit_dump(prog->len, image_size, 2, ctx.image);
900
901         bpf_flush_icache(header, ctx.image + ctx.idx);
902
903         bpf_jit_binary_lock_ro(header);
904         prog->bpf_func = (void *)ctx.image;
905         prog->jited = 1;
906
907 out_off:
908         kfree(ctx.offset);
909 out:
910         if (tmp_blinded)
911                 bpf_jit_prog_release_other(prog, prog == orig_prog ?
912                                            tmp : orig_prog);
913         return prog;
914 }