1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.net.apf;
18 
19 import static android.net.apf.BaseApfGenerator.Rbit.Rbit0;
20 import static android.net.apf.BaseApfGenerator.Rbit.Rbit1;
21 import static android.net.apf.BaseApfGenerator.Register.R0;
22 
23 import android.annotation.NonNull;
24 
25 import com.android.net.module.util.ByteUtils;
26 import com.android.net.module.util.CollectionUtils;
27 import com.android.net.module.util.HexDump;
28 
29 import java.util.ArrayList;
30 import java.util.Arrays;
31 import java.util.HashMap;
32 import java.util.List;
33 import java.util.Objects;
34 
35 /**
36  * The base class for APF assembler/generator.
37  *
38  * @hide
39  */
40 public abstract class BaseApfGenerator {
41 
/**
 * Initialises the generator state that is fixed for the lifetime of the object.
 *
 * @param mVersion value stored in the {@code mVersion} field (presumably the APF version to
 *     generate for -- confirm against the field declaration)
 * @param mDisableCounterRangeCheck when {@code true}, {@code checkPassCounterRange} returns
 *     immediately without validating the counter
 */
public BaseApfGenerator(int mVersion, boolean mDisableCounterRangeCheck) {
    // The parameters shadow the fields of the same name, hence the explicit "this.".
    this.mDisableCounterRangeCheck = mDisableCounterRangeCheck;
    this.mVersion = mVersion;
}
46 
47     /**
48      * This exception is thrown when an attempt is made to generate an illegal instruction.
49      */
50     public static class IllegalInstructionException extends Exception {
IllegalInstructionException(String msg)51         IllegalInstructionException(String msg) {
52             super(msg);
53         }
54     }
/**
 * Primary APF opcodes.
 *
 * <p>{@link #value} is the opcode number. Instruction#generateInstructionByte() shifts it left
 * by three bits to build the first byte of an encoded instruction.
 */
enum Opcodes {
    // Pseudo-opcode for a jump destination: Instruction#size() reports 0 bytes for it and
    // Instruction#generate() emits nothing.
    LABEL(-1),

    // Unconditional pass (R=0) or drop (R=1) of the packet.
    // An optional unsigned immediate encodes a counter number; when it is non-zero the
    // instruction also increments that counter. The counter lives (-4 * counter number) bytes
    // from the end of the data region, is a native-endian U32 and is always incremented by 1.
    // Roughly equivalent to: lddw R0, -N4; add R0,1; stdw R0, -N4; {pass,drop}
    // e.g. "pass", "pass 1", "drop", "drop 1"
    PASSDROP(0),

    // Loads from an immediate offset.
    LDB(1),    // 1 byte, e.g. "ldb R0, [5]"
    LDH(2),    // 2 bytes, e.g. "ldh R0, [5]"
    LDW(3),    // 4 bytes, e.g. "ldw R0, [5]"

    // Loads from an immediate offset plus a register.
    LDBX(4),   // 1 byte, e.g. "ldbx R0, [5]R0"
    LDHX(5),   // 2 bytes, e.g. "ldhx R0, [5]R0"
    LDWX(6),   // 4 bytes, e.g. "ldwx R0, [5]R0"

    // Arithmetic and logic.
    ADD(7),    // Add, e.g. "add R0,5"
    MUL(8),    // Multiply, e.g. "mul R0,5"
    DIV(9),    // Divide, e.g. "div R0,5"
    AND(10),   // And, e.g. "and R0,5"
    OR(11),    // Or, e.g. "or R0,5"
    SH(12),    // Left shift, e.g. "sh R0, 5" or "sh R0, -5" (shifts right)
    LI(13),    // Load immediate, e.g. "li R0,5" (immediate encoded as signed value)

    // Jump, e.g. "jmp label".
    // APFv6 also uses JMP with R=1 to encode the DATA instruction, which is executed as a
    // jump: it states how many bytes of the program region hold data and is followed by those
    // data bytes, e.g. "data 5, abcde".
    JMP(14),

    // Compare and branch.
    JEQ(15),   // Branch if equal, e.g. "jeq R0,5,label"
    JNE(16),   // Branch if not equal, e.g. "jne R0,5,label"
    JGT(17),   // Branch if greater than, e.g. "jgt R0,5,label"
    JLT(18),   // Branch if less than, e.g. "jlt R0,5,label"
    JSET(19),  // Branch if any bits set, e.g. "jset R0,5,label"

    // Compare not equal byte sequence, e.g. "jnebs R0,5,label,0x1122334455"
    // NOTE: Only APFv6+ implements the R=1 'jbseq' version and multi match.
    // imm1 is the jump target, imm2 is (cnt - 1) * 2048 + compare_len, and it is followed by
    // cnt * compare_len bytes to compare against.
    // Warning: do not specify the same byte sequence multiple times.
    JBSMATCH(20),

    // Followed by an immediate that selects one of the ExtendedOpcodes.
    EXT(21),

    // 4-byte accesses to data memory at address (register + immediate).
    LDDW(22),  // Load, e.g. "lddw R0, [5]R1"
    STDW(23),  // Store, e.g. "stdw R0, [5]R1"

    // Writes a 1, 2 or 4 byte immediate to the output buffer and auto-increments the write
    // pointer, e.g. "write 5"
    WRITE(24),

    // Copies bytes from the input packet / APF program / data region to the output buffer and
    // auto-increments the output buffer pointer.
    // The register bit selects the source of the copy:
    //   R=0 copies from the packet,
    //   R=1 copies from the APF program/data region.
    // The copy length is stored in (u8)imm2.
    // e.g. "pktcopy 5, 5" "datacopy 5, 5"
    PKTDATACOPY(25);

    /** Opcode number of this constant. */
    final int value;

    Opcodes(int opcodeNumber) {
        value = opcodeNumber;
    }
}
/**
 * Extended opcodes. The primary opcode of such an instruction is {@link Opcodes#EXT}; the
 * extended opcode {@link #value} is carried in the immediate field.
 */
enum ExtendedOpcodes {
    LDM(0),   // Load from memory, e.g. "ldm R0,5"
    STM(16),  // Store to memory, e.g. "stm R0,5"
    NOT(32),  // Not, e.g. "not R0"
    NEG(33),  // Negate, e.g. "neg R0"
    SWAP(34), // Swap, e.g. "swap R0,R1"
    MOVE(35), // Move, e.g. "move R0,R1"

    // Allocates the writable output buffer.
    // R=0: the length is taken from register R0. R=1: the length is encoded in the u16 imm2.
    // e.g. "allocate R0"
    // e.g. "allocate 123"
    ALLOCATE(36),

    // Transmits and deallocates the buffer (transmission can be delayed until the program
    // terminates). The length of the buffer is the output buffer pointer (0 means discard).
    // R=1 iff udp style L4 checksum
    // u8 imm2 - ip header offset from start of buffer (255 for non-ip packets)
    // u8 imm3 - offset from start of buffer to store L4 checksum (255 for no L4 checksum)
    // u8 imm4 - offset from start of buffer to begin L4 checksum calc (present iff imm3 != 255)
    // u16 imm5 - partial checksum value to include in L4 checksum (present iff imm3 != 255)
    // e.g. "transmit"
    TRANSMIT(37),

    // Write a 1, 2 or 4 byte value from a register to the output buffer and auto-increment
    // the output buffer pointer.
    // e.g. "ewrite1 r0"
    EWRITE1(38),
    EWRITE2(39),
    EWRITE4(40),

    // Copy bytes from the input packet / APF program / data region to the output buffer and
    // auto-increment the output buffer pointer.
    // The register bit selects the source of the copy:
    //   R=0 copies from the packet,
    //   R=1 copies from the APF program/data region.
    // The source offset is stored in R0; the copy length is stored in u8 imm2 or in R1.
    // e.g. "epktcopy r0, 16", "edatacopy r0, 16", "epktcopy r0, r1", "edatacopy r0, r1"
    EPKTDATACOPYIMM(41),
    EPKTDATACOPYR1(42),

    // Jumps if the UDP payload content (starting at R0) does [not] match one of the specified
    // QNAMEs in question records, applying case insensitivity.
    // The SAFE version PASSES corrupt packets, while the other one DROPS.
    // R=0/1 meaning 'does not match'/'matches'
    // R0: Offset to UDP payload content
    // imm1: Extended opcode
    // imm2: Jump label offset
    // imm3(u8): Question type (PTR/SRV/TXT/A/AAAA)
    // imm4(bytes): null terminated list of null terminated LV-encoded QNAMEs
    // e.g.: "jdnsqeq R0,label,0xc,\002aa\005local\0\0",
    //       "jdnsqne R0,label,0xc,\002aa\005local\0\0"
    // NOTE: the numeric values of the four JDNS* constants interleave (43, 45, 44, 46); the
    // declaration order below is kept as is.
    JDNSQMATCH(43),
    JDNSQMATCHSAFE(45),

    // Jumps if the UDP payload content (starting at R0) does [not] match one of the specified
    // NAMEs in answers/authority/additional records, applying case insensitivity.
    // The SAFE version PASSES corrupt packets, while the other one DROPS.
    // R=0/1 meaning 'does not match'/'matches'
    // R0: Offset to UDP payload content
    // imm1: Extended opcode
    // imm2: Jump label offset
    // imm3(bytes): null terminated list of null terminated LV-encoded NAMEs
    // e.g.: "jdnsaeq R0,label,0xc,\002aa\005local\0\0",
    //       "jdnsane R0,label,0xc,\002aa\005local\0\0"
    JDNSAMATCH(44),
    JDNSAMATCHSAFE(46),

    // Jump if the register is [not] one of the listed values.
    // R bit - specifies the register (R0/R1) to test
    // imm1: Extended opcode
    // imm2: Jump label offset
    // imm3(u8): top 5 bits - number of following u8/be16/be32 values - 1
    //        middle 2 bits - 1..4 length of immediates - 1
    //        bottom 1 bit  - =0 jmp if in set, =1 if not in set
    // imm4(imm3 * 1/2/3/4 bytes): the *UNIQUE* values to compare against
    JONEOF(47),

    /* Specifies the length of the exception buffer, which is populated on abnormal program
     * termination.
     * imm1: Extended opcode
     * imm2(u16): Length of exception buffer (located *immediately* after the program itself)
     */
    EXCEPTIONBUFFER(48);

    /** Number written into the immediate field to select this extended opcode. */
    final int value;

    ExtendedOpcodes(int extendedOpcodeNumber) {
        value = extendedOpcodeNumber;
    }
}
/** The two registers an instruction can refer to. */
public enum Register {
    R0,
    R1;

    /**
     * @return the register that is not this one: {@code R1} for {@code R0}, otherwise
     *     {@code R0}.
     */
    Register other() {
        if (this == R0) {
            return R1;
        }
        return R0;
    }
}
210 
/**
 * Value of the register bit (R bit) of an instruction. Instruction#generateInstructionByte()
 * ORs {@link #value} into the lowest bit of the first instruction byte.
 */
public enum Rbit {
    Rbit0(0),
    Rbit1(1);

    /** Numeric value of the bit: 0 or 1. */
    final int value;

    Rbit(int bit) {
        value = bit;
    }
}
221 
222     private enum IntImmediateType {
223         INDETERMINATE_SIZE_SIGNED,
224         INDETERMINATE_SIZE_UNSIGNED,
225         SIGNED_8,
226         UNSIGNED_8,
227         SIGNED_BE16,
228         UNSIGNED_BE16,
229         SIGNED_BE32,
230         UNSIGNED_BE32;
231     }
232 
233     private static class IntImmediate {
234         public final IntImmediateType mImmediateType;
235         public final int mValue;
236 
IntImmediate(int value, IntImmediateType type)237         IntImmediate(int value, IntImmediateType type) {
238             mImmediateType = type;
239             mValue = value;
240         }
241 
calculateIndeterminateSize()242         private int calculateIndeterminateSize() {
243             switch (mImmediateType) {
244                 case INDETERMINATE_SIZE_SIGNED:
245                     return calculateImmSize(mValue, true /* signed */);
246                 case INDETERMINATE_SIZE_UNSIGNED:
247                     return calculateImmSize(mValue, false /* signed */);
248                 default:
249                     // For IMM with determinate size, return 0 to allow Math.max() calculation in
250                     // caller function.
251                     return 0;
252             }
253         }
254 
getEncodingSize(int immFieldSize)255         private int getEncodingSize(int immFieldSize) {
256             switch (mImmediateType) {
257                 case SIGNED_8:
258                 case UNSIGNED_8:
259                     return 1;
260                 case SIGNED_BE16:
261                 case UNSIGNED_BE16:
262                     return 2;
263                 case SIGNED_BE32:
264                 case UNSIGNED_BE32:
265                     return 4;
266                 case INDETERMINATE_SIZE_SIGNED:
267                 case INDETERMINATE_SIZE_UNSIGNED: {
268                     int minSizeRequired = calculateIndeterminateSize();
269                     if (minSizeRequired > immFieldSize) {
270                         throw new IllegalStateException(
271                                 String.format("immFieldSize: %d is too small to encode value %d",
272                                         immFieldSize, mValue));
273                     }
274                     return immFieldSize;
275                 }
276             }
277             throw new IllegalStateException("UnhandledInvalid IntImmediateType: " + mImmediateType);
278         }
279 
writeValue(byte[] bytecode, Integer writingOffset, int immFieldSize)280         private int writeValue(byte[] bytecode, Integer writingOffset, int immFieldSize) {
281             return Instruction.writeValue(mValue, bytecode, writingOffset,
282                     getEncodingSize(immFieldSize));
283         }
284 
newSigned(int imm)285         public static IntImmediate newSigned(int imm) {
286             return new IntImmediate(imm, IntImmediateType.INDETERMINATE_SIZE_SIGNED);
287         }
288 
newUnsigned(long imm)289         public static IntImmediate newUnsigned(long imm) {
290             // upperBound is 2^32 - 1
291             checkRange("Unsigned IMM", imm, 0 /* lowerBound */,
292                     4294967295L /* upperBound */);
293             return new IntImmediate((int) imm, IntImmediateType.INDETERMINATE_SIZE_UNSIGNED);
294         }
295 
newTwosComplementUnsigned(long imm)296         public static IntImmediate newTwosComplementUnsigned(long imm) {
297             checkRange("Unsigned TwosComplement IMM", imm, Integer.MIN_VALUE,
298                     4294967295L /* upperBound */);
299             return new IntImmediate((int) imm, IntImmediateType.INDETERMINATE_SIZE_UNSIGNED);
300         }
301 
newTwosComplementSigned(long imm)302         public static IntImmediate newTwosComplementSigned(long imm) {
303             checkRange("Signed TwosComplement IMM", imm, Integer.MIN_VALUE,
304                     4294967295L /* upperBound */);
305             return new IntImmediate((int) imm, IntImmediateType.INDETERMINATE_SIZE_SIGNED);
306         }
307 
newS8(byte imm)308         public static IntImmediate newS8(byte imm) {
309             checkRange("S8 IMM", imm, Byte.MIN_VALUE, Byte.MAX_VALUE);
310             return new IntImmediate(imm, IntImmediateType.SIGNED_8);
311         }
312 
newU8(int imm)313         public static IntImmediate newU8(int imm) {
314             checkRange("U8 IMM", imm, 0, 255);
315             return new IntImmediate(imm, IntImmediateType.UNSIGNED_8);
316         }
317 
newS16(short imm)318         public static IntImmediate newS16(short imm) {
319             return new IntImmediate(imm, IntImmediateType.SIGNED_BE16);
320         }
321 
newU16(int imm)322         public static IntImmediate newU16(int imm) {
323             checkRange("U16 IMM", imm, 0, 65535);
324             return new IntImmediate(imm, IntImmediateType.UNSIGNED_BE16);
325         }
326 
newS32(int imm)327         public static IntImmediate newS32(int imm) {
328             return new IntImmediate(imm, IntImmediateType.SIGNED_BE32);
329         }
330 
newU32(long imm)331         public static IntImmediate newU32(long imm) {
332             // upperBound is 2^32 - 1
333             checkRange("U32 IMM", imm, 0 /* lowerBound */,
334                     4294967295L /* upperBound */);
335             return new IntImmediate((int) imm, IntImmediateType.UNSIGNED_BE32);
336         }
337 
338         @Override
toString()339         public String toString() {
340             return "IntImmediate{" + "mImmediateType=" + mImmediateType + ", mValue=" + mValue
341                     + '}';
342         }
343     }
344 
345     class Instruction {
346         public final Opcodes mOpcode;
347         private final Rbit mRbit;
348         public final List<IntImmediate> mIntImms = new ArrayList<>();
349         // When mOpcode is a jump:
350         private int mTargetLabelSize;
351         private int mImmSizeOverride = -1;
352         private String mTargetLabel;
353         // When mOpcode == Opcodes.LABEL:
354         private String mLabel;
355         public byte[] mBytesImm;
356         // Offset in bytes from the beginning of this program.
357         // Set by {@link BaseApfGenerator#generate}.
358         int offset;
359 
Instruction(Opcodes opcode, Rbit rbit)360         Instruction(Opcodes opcode, Rbit rbit) {
361             mOpcode = opcode;
362             mRbit = rbit;
363         }
364 
Instruction(Opcodes opcode, Register register)365         Instruction(Opcodes opcode, Register register) {
366             this(opcode, register == R0 ? Rbit0 : Rbit1);
367         }
368 
Instruction(ExtendedOpcodes extendedOpcodes, Rbit rbit)369         Instruction(ExtendedOpcodes extendedOpcodes, Rbit rbit) {
370             this(Opcodes.EXT, rbit);
371             addUnsigned(extendedOpcodes.value);
372         }
373 
Instruction(ExtendedOpcodes extendedOpcodes, Register register)374         Instruction(ExtendedOpcodes extendedOpcodes, Register register) {
375             this(Opcodes.EXT, register);
376             addUnsigned(extendedOpcodes.value);
377         }
378 
Instruction(ExtendedOpcodes extendedOpcodes, int slot, Register register)379         Instruction(ExtendedOpcodes extendedOpcodes, int slot, Register register)
380                 throws IllegalInstructionException {
381             this(Opcodes.EXT, register);
382             if (slot < 0 || slot >= MEMORY_SLOTS) {
383                 throw new IllegalInstructionException("illegal memory slot number: " + slot);
384             }
385             addUnsigned(extendedOpcodes.value + slot);
386         }
387 
Instruction(Opcodes opcode)388         Instruction(Opcodes opcode) {
389             this(opcode, R0);
390         }
391 
Instruction(ExtendedOpcodes extendedOpcodes)392         Instruction(ExtendedOpcodes extendedOpcodes) {
393             this(extendedOpcodes, R0);
394         }
395 
addSigned(int imm)396         Instruction addSigned(int imm) {
397             mIntImms.add(IntImmediate.newSigned(imm));
398             return this;
399         }
400 
addUnsigned(long imm)401         Instruction addUnsigned(long imm) {
402             mIntImms.add(IntImmediate.newUnsigned(imm));
403             return this;
404         }
405 
406         // in practice, 'int' always enough for packet offset
addPacketOffset(int imm)407         Instruction addPacketOffset(int imm) {
408             return addUnsigned(imm);
409         }
410 
411         // in practice, 'int' always enough for data offset
addDataOffset(int imm)412         Instruction addDataOffset(int imm) {
413             return addUnsigned(imm);
414         }
415 
addTwosCompSigned(long imm)416         Instruction addTwosCompSigned(long imm) {
417             mIntImms.add(IntImmediate.newTwosComplementSigned(imm));
418             return this;
419         }
420 
addTwosCompUnsigned(long imm)421         Instruction addTwosCompUnsigned(long imm) {
422             mIntImms.add(IntImmediate.newTwosComplementUnsigned(imm));
423             return this;
424         }
425 
addS8(byte imm)426         Instruction addS8(byte imm) {
427             mIntImms.add(IntImmediate.newS8(imm));
428             return this;
429         }
430 
addU8(int imm)431         Instruction addU8(int imm) {
432             mIntImms.add(IntImmediate.newU8(imm));
433             return this;
434         }
435 
addS16(short imm)436         Instruction addS16(short imm) {
437             mIntImms.add(IntImmediate.newS16(imm));
438             return this;
439         }
440 
addU16(int imm)441         Instruction addU16(int imm) {
442             mIntImms.add(IntImmediate.newU16(imm));
443             return this;
444         }
445 
addS32(int imm)446         Instruction addS32(int imm) {
447             mIntImms.add(IntImmediate.newS32(imm));
448             return this;
449         }
450 
addU32(long imm)451         Instruction addU32(long imm) {
452             mIntImms.add(IntImmediate.newU32(imm));
453             return this;
454         }
455 
setLabel(String label)456         Instruction setLabel(String label) throws IllegalInstructionException {
457             if (mLabels.containsKey(label)) {
458                 throw new IllegalInstructionException("duplicate label " + label);
459             }
460             if (mOpcode != Opcodes.LABEL) {
461                 throw new IllegalStateException("adding label to non-label instruction");
462             }
463             mLabel = label;
464             mLabels.put(label, this);
465             return this;
466         }
467 
setTargetLabel(String label)468         Instruction setTargetLabel(String label) {
469             mTargetLabel = label;
470             mTargetLabelSize = 4; // May shrink later on in generate().
471             return this;
472         }
473 
overrideImmSize(int size)474         Instruction overrideImmSize(int size) {
475             mImmSizeOverride = size;
476             return this;
477         }
478 
setBytesImm(byte[] bytes)479         Instruction setBytesImm(byte[] bytes) {
480             mBytesImm = bytes;
481             return this;
482         }
483 
484         /**
485          * Attempts to match {@code content} with existing data bytes. If not exist, then
486          * append the {@code content} to the data bytes.
487          * Returns the start offset of the content from the beginning of the program.
488          */
maybeUpdateBytesImm(byte[] content)489         int maybeUpdateBytesImm(byte[] content) throws IllegalInstructionException {
490             if (mOpcode != Opcodes.JMP || mBytesImm == null) {
491                 throw new IllegalInstructionException(String.format(
492                         "maybeUpdateBytesImm() is only valid for jump data instruction, mOpcode "
493                                 + ":%s, mBytesImm: %s", Opcodes.JMP,
494                         mBytesImm == null ? "(empty)" : HexDump.toHexString(mBytesImm)));
495             }
496             if (mImmSizeOverride != 2) {
497                 throw new IllegalInstructionException(
498                         "mImmSizeOverride must be 2, mImmSizeOverride: " + mImmSizeOverride);
499             }
500             int offsetInDataBytes = CollectionUtils.indexOfSubArray(mBytesImm, content);
501             if (offsetInDataBytes == -1) {
502                 offsetInDataBytes = mBytesImm.length;
503                 mBytesImm = ByteUtils.concat(mBytesImm, content);
504                 // Update the length immediate (first imm) value. Due to mValue within
505                 // IntImmediate being final, we must remove and re-add the value to apply changes.
506                 mIntImms.remove(0);
507                 addDataOffset(mBytesImm.length);
508             }
509             // Note that the data instruction encoding consumes 1 byte and the data length
510             // encoding consumes 2 bytes.
511             return 1 + mImmSizeOverride + offsetInDataBytes;
512         }
513 
514         /**
515          * Updates exception buffer size.
516          * @param bufSize the new exception buffer size
517          */
updateExceptionBufferSize(int bufSize)518         void updateExceptionBufferSize(int bufSize) throws IllegalInstructionException {
519             if (mOpcode != Opcodes.EXT || mIntImms.get(0).mValue
520                     != ExtendedOpcodes.EXCEPTIONBUFFER.value) {
521                 throw new IllegalInstructionException(
522                         "updateExceptionBuffer() is only valid for EXCEPTIONBUFFER opcode");
523             }
524             // Update the buffer size immediate (second imm) value. Due to mValue within
525             // IntImmediate being final, we must remove and re-add the value to apply changes.
526             mIntImms.remove(1);
527             addU16(bufSize);
528         }
529 
530         /**
531          * @return size of instruction in bytes.
532          */
size()533         int size() {
534             if (mOpcode == Opcodes.LABEL) {
535                 return 0;
536             }
537             int size = 1;
538             int indeterminateSize = calculateRequiredIndeterminateSize();
539             for (IntImmediate imm : mIntImms) {
540                 size += imm.getEncodingSize(indeterminateSize);
541             }
542             if (mTargetLabel != null) {
543                 size += indeterminateSize;
544             }
545             if (mBytesImm != null) {
546                 size += mBytesImm.length;
547             }
548             return size;
549         }
550 
551         /**
552          * Resize immediate value field so that it's only as big as required to
553          * contain the offset of the jump destination.
554          * @return {@code true} if shrunk.
555          */
shrink()556         boolean shrink() throws IllegalInstructionException {
557             if (mTargetLabel == null) {
558                 return false;
559             }
560             int oldTargetLabelSize = mTargetLabelSize;
561             mTargetLabelSize = calculateImmSize(calculateTargetLabelOffset(), false);
562             if (mTargetLabelSize > oldTargetLabelSize) {
563                 throw new IllegalStateException("instruction grew");
564             }
565             return mTargetLabelSize < oldTargetLabelSize;
566         }
567 
568         /**
569          * Assemble value for instruction size field.
570          */
generateImmSizeField()571         private int generateImmSizeField() {
572             int immSize = calculateRequiredIndeterminateSize();
573             // Encode size field to fit in 2 bits: 0->0, 1->1, 2->2, 3->4.
574             return immSize == 4 ? 3 : immSize;
575         }
576 
577         /**
578          * Assemble first byte of generated instruction.
579          */
generateInstructionByte()580         private byte generateInstructionByte() {
581             int sizeField = generateImmSizeField();
582             return (byte) ((mOpcode.value << 3) | (sizeField << 1) | (byte) mRbit.value);
583         }
584 
585         /**
586          * Write {@code value} at offset {@code writingOffset} into {@code bytecode}.
587          * {@code immSize} bytes are written. {@code value} is truncated to
588          * {@code immSize} bytes. {@code value} is treated simply as a
589          * 32-bit value, so unsigned values should be zero extended and the truncation
590          * should simply throw away their zero-ed upper bits, and signed values should
591          * be sign extended and the truncation should simply throw away their signed
592          * upper bits.
593          */
writeValue(int value, byte[] bytecode, int writingOffset, int immSize)594         private static int writeValue(int value, byte[] bytecode, int writingOffset, int immSize) {
595             for (int i = immSize - 1; i >= 0; i--) {
596                 bytecode[writingOffset++] = (byte) ((value >> (i * 8)) & 255);
597             }
598             return writingOffset;
599         }
600 
601         /**
602          * Generate bytecode for this instruction at offset {@link Instruction#offset}.
603          */
generate(byte[] bytecode)604         void generate(byte[] bytecode) throws IllegalInstructionException {
605             if (mOpcode == Opcodes.LABEL) {
606                 return;
607             }
608             int writingOffset = offset;
609             bytecode[writingOffset++] = generateInstructionByte();
610             int indeterminateSize = calculateRequiredIndeterminateSize();
611             int startOffset = 0;
612             if (mOpcode == Opcodes.EXT) {
613                 // For extend opcode, always write the actual opcode first.
614                 writingOffset = mIntImms.get(startOffset++).writeValue(bytecode, writingOffset,
615                         indeterminateSize);
616             }
617             if (mTargetLabel != null) {
618                 writingOffset = writeValue(calculateTargetLabelOffset(), bytecode, writingOffset,
619                         indeterminateSize);
620             }
621             for (int i = startOffset; i < mIntImms.size(); ++i) {
622                 writingOffset = mIntImms.get(i).writeValue(bytecode, writingOffset,
623                         indeterminateSize);
624             }
625             if (mBytesImm != null) {
626                 System.arraycopy(mBytesImm, 0, bytecode, writingOffset, mBytesImm.length);
627                 writingOffset += mBytesImm.length;
628             }
629             if ((writingOffset - offset) != size()) {
630                 throw new IllegalStateException("wrote " + (writingOffset - offset)
631                         + " but should have written " + size());
632             }
633         }
634 
635         /**
636          * Calculates the maximum indeterminate size of all IMMs in this instruction.
637          * <p>
638          * This method finds the largest size needed to encode any indeterminate-sized IMMs in
639          * the instruction. This size will be stored in the immLen field.
640          */
calculateRequiredIndeterminateSize()641         private int calculateRequiredIndeterminateSize() {
642             int maxSize = mTargetLabelSize;
643             for (IntImmediate imm : mIntImms) {
644                 maxSize = Math.max(maxSize, imm.calculateIndeterminateSize());
645             }
646             if (mImmSizeOverride != -1 && maxSize > mImmSizeOverride) {
647                 throw new IllegalStateException(String.format(
648                         "maxSize: %d should not be greater than mImmSizeOverride: %d", maxSize,
649                         mImmSizeOverride));
650             }
651             // If we already know the size the length field, just use it
652             switch (mImmSizeOverride) {
653                 case -1:
654                     return maxSize;
655                 case 1:
656                 case 2:
657                 case 4:
658                     return mImmSizeOverride;
659                 default:
660                     throw new IllegalStateException(
661                             "mImmSizeOverride has invalid value: " + mImmSizeOverride);
662             }
663         }
664 
calculateTargetLabelOffset()665         private int calculateTargetLabelOffset() throws IllegalInstructionException {
666             Instruction targetLabelInstruction;
667             if (mTargetLabel == DROP_LABEL) {
668                 targetLabelInstruction = mDropLabel;
669             } else if (mTargetLabel == PASS_LABEL) {
670                 targetLabelInstruction = mPassLabel;
671             } else {
672                 targetLabelInstruction = mLabels.get(mTargetLabel);
673             }
674             if (targetLabelInstruction == null) {
675                 throw new IllegalInstructionException("label not found: " + mTargetLabel);
676             }
677             // Calculate distance from end of this instruction to instruction.offset.
678             final int targetLabelOffset = targetLabelInstruction.offset - (offset + size());
679             return targetLabelOffset;
680         }
681     }
682 
683     /**
684      * Updates instruction offset fields using latest instruction sizes.
685      * @return current program length in bytes.
686      */
updateInstructionOffsets()687     private int updateInstructionOffsets() {
688         int offset = 0;
689         for (Instruction instruction : mInstructions) {
690             instruction.offset = offset;
691             offset += instruction.size();
692         }
693         return offset;
694     }
695 
696     /**
697      * Calculate the size of the imm.
698      */
calculateImmSize(int imm, boolean signed)699     static int calculateImmSize(int imm, boolean signed) {
700         if (imm == 0) {
701             return 0;
702         }
703         if (signed && (imm >= -128 && imm <= 127) || !signed && (imm >= 0 && imm <= 255)) {
704             return 1;
705         }
706         if (signed && (imm >= -32768 && imm <= 32767) || !signed && (imm >= 0 && imm <= 65535)) {
707             return 2;
708         }
709         return 4;
710     }
711 
checkRange(@onNull String variableName, long value, long lowerBound, long upperBound)712     static void checkRange(@NonNull String variableName, long value, long lowerBound,
713                            long upperBound) {
714         if (value >= lowerBound && value <= upperBound) {
715             return;
716         }
717         throw new IllegalArgumentException(
718                 String.format("%s: %d, must be in range [%d, %d]", variableName, value, lowerBound,
719                         upperBound));
720     }
721 
checkPassCounterRange(ApfCounterTracker.Counter cnt)722     void checkPassCounterRange(ApfCounterTracker.Counter cnt) {
723         if (mDisableCounterRangeCheck) return;
724         if (cnt.value() < ApfCounterTracker.MIN_PASS_COUNTER.value()
725                 || cnt.value() > ApfCounterTracker.MAX_PASS_COUNTER.value()) {
726             throw new IllegalArgumentException(
727                     String.format("Counter %s, is not in range [%s, %s]", cnt,
728                             ApfCounterTracker.MIN_PASS_COUNTER,
729                             ApfCounterTracker.MAX_PASS_COUNTER));
730         }
731     }
732 
checkDropCounterRange(ApfCounterTracker.Counter cnt)733     void checkDropCounterRange(ApfCounterTracker.Counter cnt) {
734         if (mDisableCounterRangeCheck) return;
735         if (cnt.value() < ApfCounterTracker.MIN_DROP_COUNTER.value()
736                 || cnt.value() > ApfCounterTracker.MAX_DROP_COUNTER.value()) {
737             throw new IllegalArgumentException(
738                     String.format("Counter %s, is not in range [%s, %s]", cnt,
739                             ApfCounterTracker.MIN_DROP_COUNTER,
740                             ApfCounterTracker.MAX_DROP_COUNTER));
741         }
742     }
743 
744     /**
745      * Returns an overestimate of the size of the generated program. {@link #generate} may return
746      * a program that is smaller.
747      */
programLengthOverEstimate()748     public int programLengthOverEstimate() {
749         return updateInstructionOffsets();
750     }
751 
    /**
     * Updates the exception buffer size.
     *
     * <p>{@link #generate} calls this once, after the instruction sizes have been settled and
     * before any bytecode is written, passing the final program size.
     *
     * @param programSize the total size of the program in bytes.
     * @throws IllegalInstructionException if the implementation cannot apply the update; the
     *         exact conditions are defined by the subclass.
     */
    abstract void updateExceptionBufferSize(int programSize) throws IllegalInstructionException;
756 
757     /**
758      * Generate the bytecode for the APF program.
759      * @return the bytecode.
760      * @throws IllegalStateException if a label is referenced but not defined.
761      */
generate()762     public byte[] generate() throws IllegalInstructionException {
763         // Enforce that we can only generate once because we cannot unshrink instructions and
764         // PASS/DROP labels may move further away requiring unshrinking if we add further
765         // instructions.
766         if (mGenerated) {
767             throw new IllegalStateException("Can only generate() once!");
768         }
769         mGenerated = true;
770         int total_size;
771         boolean shrunk;
772         // Shrink the immediate value fields of instructions.
773         // As we shrink the instructions some branch offset
774         // fields may shrink also, thereby shrinking the
775         // instructions further. Loop until we've reached the
776         // minimum size. Rarely will this loop more than a few times.
777         // Limit iterations to avoid O(n^2) behavior.
778         int iterations_remaining = 10;
779         do {
780             total_size = updateInstructionOffsets();
781             // Update drop and pass label offsets.
782             mDropLabel.offset = total_size + 1;
783             mPassLabel.offset = total_size;
784             // Limit run-time in aberant circumstances.
785             if (iterations_remaining-- == 0) break;
786             // Attempt to shrink instructions.
787             shrunk = false;
788             for (Instruction instruction : mInstructions) {
789                 if (instruction.shrink()) {
790                     shrunk = true;
791                 }
792             }
793         } while (shrunk);
794         // Generate bytecode for instructions.
795         byte[] bytecode = new byte[total_size];
796         updateExceptionBufferSize(total_size);
797         for (Instruction instruction : mInstructions) {
798             instruction.generate(bytecode);
799         }
800         return bytecode;
801     }
802 
validateBytes(byte[] bytes)803     void validateBytes(byte[] bytes) {
804         Objects.requireNonNull(bytes);
805         if (bytes.length > 2047) {
806             throw new IllegalArgumentException(
807                     "bytes array size must be in less than 2048, current size: " + bytes.length);
808         }
809     }
810 
validateDeduplicateBytesList(List<byte[]> bytesList)811     List<byte[]> validateDeduplicateBytesList(List<byte[]> bytesList) {
812         if (bytesList == null || bytesList.size() == 0) {
813             throw new IllegalArgumentException(
814                     "bytesList size must > 0, current size: "
815                             + (bytesList == null ? "null" : bytesList.size()));
816         }
817         for (byte[] bytes : bytesList) {
818             validateBytes(bytes);
819         }
820         final int elementSize = bytesList.get(0).length;
821         if (elementSize > 2097151) { // 2 ^ 21 - 1
822             throw new IllegalArgumentException("too many elements");
823         }
824         List<byte[]> deduplicatedList = new ArrayList<>();
825         deduplicatedList.add(bytesList.get(0));
826         for (int i = 1; i < bytesList.size(); ++i) {
827             if (elementSize != bytesList.get(i).length) {
828                 throw new IllegalArgumentException("byte arrays in the set have different size");
829             }
830             int j = 0;
831             for (; j < deduplicatedList.size(); ++j) {
832                 if (Arrays.equals(bytesList.get(i), deduplicatedList.get(j))) {
833                     break;
834                 }
835             }
836             if (j == deduplicatedList.size()) {
837                 deduplicatedList.add(bytesList.get(i));
838             }
839         }
840         return deduplicatedList;
841     }
842 
requireApfVersion(int minimumVersion)843     void requireApfVersion(int minimumVersion) throws IllegalInstructionException {
844         if (mVersion < minimumVersion) {
845             throw new IllegalInstructionException("Requires APF >= " + minimumVersion);
846         }
847     }
848 
849     private int mLabelCount = 0;
850 
851     /**
852      * Return a unique label string.
853      */
getUniqueLabel()854     protected String getUniqueLabel() {
855         return "LABEL_" + mLabelCount++;
856     }
857 
    /**
     * Jump to this label to terminate the program and indicate the packet
     * should be dropped.
     *
     * <p>{@link #generate} places this label one byte past the end of the program
     * (program length + 1).
     */
    public static final String DROP_LABEL = "__DROP__";

    /**
     * Jump to this label to terminate the program and indicate the packet
     * should be passed to the AP.
     *
     * <p>{@link #generate} places this label at the end of the program (program length).
     */
    public static final String PASS_LABEL = "__PASS__";

    /**
     * Number of memory slots available for access via APF stores to memory and loads from memory.
     * The memory slots are numbered 0 to {@code MEMORY_SLOTS} - 1. This must be kept in sync with
     * the APF interpreter.
     */
    public static final int MEMORY_SLOTS = 16;
876 
    /**
     * Named memory slots. Each constant carries the numeric index of the slot it refers to;
     * the indices used here range from 0 to 15.
     */
    public enum MemorySlot {
        /**
         * These slots start with value 0 and are unused.
         */
        SLOT_0(0),
        SLOT_1(1),
        SLOT_2(2),
        SLOT_3(3),
        SLOT_4(4),
        SLOT_5(5),
        SLOT_6(6),
        SLOT_7(7),

        /**
         * First memory slot containing prefilled (ie. non-zero) values.
         * Can be used in range comparisons to determine if memory slot index
         * is within prefilled slots.
         * Shares index 8 with {@link #APF_VERSION}.
         */
        FIRST_PREFILLED(8),

        /**
         * Slot #8 is used for the APFv6+ version.
         */
        APF_VERSION(8),

        /**
         * Slot #9 is used for the filter age in 16384ths of a second (APFv6+).
         */
        FILTER_AGE_16384THS(9),

        /**
         * Slot #10 starts at zero, implicitly used as tx buffer output pointer.
         */
        TX_BUFFER_OUTPUT_POINTER(10),

        /**
         * Slot #11 is used for the program byte code size (APFv2+).
         */
        PROGRAM_SIZE(11),

        /**
         * Slot #12 is used for the total RAM length.
         */
        RAM_LEN(12),

        /**
         * Slot #13 is the IPv4 header length (in bytes).
         */
        IPV4_HEADER_SIZE(13),

        /**
         * Slot #14 is the size of the packet being filtered in bytes.
         */
        PACKET_SIZE(14),

        /**
         * Slot #15 is the age of the filter (time since filter was installed
         * till now) in seconds.
         */
        FILTER_AGE_SECONDS(15);

        /** Numeric index of this memory slot. */
        public final int value;

        /**
         * @param value the numeric index of the memory slot.
         */
        MemorySlot(int value) {
            this.value = value;
        }
    }
944 
    // These version numbers sync up with APF_VERSION in hardware/google/apf/apf_interpreter.h.
    // requireApfVersion() compares them numerically against mVersion, so their ordering matters;
    // note that APF_VERSION_6 is 6000, not 6.
    public static final int APF_VERSION_2 = 2;
    public static final int APF_VERSION_3 = 3;
    public static final int APF_VERSION_4 = 4;
    public static final int APF_VERSION_6 = 6000;
950 
951 
952     final ArrayList<Instruction> mInstructions = new ArrayList<Instruction>();
953     private final HashMap<String, Instruction> mLabels = new HashMap<String, Instruction>();
954     private final Instruction mDropLabel = new Instruction(Opcodes.LABEL);
955     private final Instruction mPassLabel = new Instruction(Opcodes.LABEL);
956     public final int mVersion;
957     public boolean mGenerated;
958     private final boolean mDisableCounterRangeCheck;
959 }
960