1# Copyright (c) Barefoot Networks, Inc.
2# Licensed under the Apache License, Version 2.0 (the "License")
3
4from p4_hlir.hlir import parse_call, p4_field, p4_parse_value_set, \
5    P4_DEFAULT, p4_parse_state, p4_table, \
6    p4_conditional_node, p4_parser_exception, \
7    p4_header_instance, P4_NEXT
8import ebpfProgram
9import ebpfStructType
10import ebpfInstance
11import programSerializer
12from compilationException import *
13
14
class EbpfParser(object):
    """Emits the eBPF C code for one P4 parser state.

    The state becomes a labeled C block containing the code for each
    operation of its call sequence, followed by the jump to the next node
    chosen by the state's select/branch.
    """

    def __init__(self, hlirParser):
        # hlirParser is a P4 (HLIR) parser state; only its name,
        # call_sequence, branch_on and branch_to are used here.
        self.parser = hlirParser
        self.name = hlirParser.name

    def serialize(self, serializer, program):
        """Emit this parser state as a labeled C block."""
        assert isinstance(serializer, programSerializer.ProgramSerializer)
        assert isinstance(program, ebpfProgram.EbpfProgram)

        serializer.emitIndent()
        serializer.appendFormat("{0}: ", self.name)
        serializer.blockStart()
        for op in self.parser.call_sequence:
            self.serializeOperation(serializer, op, program)

        self.serializeBranch(serializer, self.parser.branch_on,
                             self.parser.branch_to, program)

        serializer.blockEnd(True)

    def _arrayIndexSuffix(self, instance, ebpfStack):
        """Return the C subscript ("[N]" or "[indexVar]") for a stack element.

        Raises CompilationException for an index that is neither an int nor
        P4_NEXT.
        """
        if isinstance(instance.index, int):
            return "[" + str(instance.index) + "]"
        if instance.index is P4_NEXT:
            return "[" + ebpfStack.indexVar + "]"
        raise CompilationException(
            True, "Unexpected index for array {0}", instance.index)

    def _fieldReference(self, field, program):
        """Return (C expression, width in bits) for header field `field`."""
        instance = field.instance
        assert isinstance(instance, p4_header_instance)

        if ebpfProgram.EbpfProgram.isArrayElementInstance(instance):
            ebpfStack = program.getStackInstance(instance.base_name)
            assert isinstance(ebpfStack, ebpfInstance.EbpfHeaderStack)
            index = self._arrayIndexSuffix(instance, ebpfStack)
            basetype = ebpfStack.basetype
            name = ebpfStack.name
        else:
            ebpfHeader = program.getInstance(instance.name)
            assert isinstance(ebpfHeader, ebpfInstance.EbpfHeader)
            index = ""
            basetype = ebpfHeader.type
            name = ebpfHeader.name

        ebpfField = basetype.getField(field.name)
        assert isinstance(ebpfField, ebpfStructType.EbpfField)

        reference = (program.headerStructName + "." + name +
                     index + "." + ebpfField.name)
        return reference, ebpfField.widthInBits()

    def serializeSelect(self, selectVarName, serializer, branch_on, program):
        """Emit `<uprefix>32 selectVarName = <expr>;` for a select.

        Every element of branch_on (a header field or a (position, width)
        `current` reference) is concatenated into one value, earlier
        elements occupying the more significant bits.

        Raises NotSupportedException if the total width exceeds 32 bits and
        CompilationException for an unexpected element.
        """
        assert isinstance(selectVarName, str)
        assert isinstance(serializer, programSerializer.ProgramSerializer)
        assert isinstance(program, ebpfProgram.EbpfProgram)

        totalWidth = 0
        switchValue = ""
        for e in branch_on:
            if isinstance(e, p4_field):
                reference, width = self._fieldReference(e, program)
            elif isinstance(e, tuple):
                # currentReferenceAsString validates the (position, width)
                # shape before e[1] is used.
                reference = self.currentReferenceAsString(e, program)
                width = e[1]
            else:
                raise CompilationException(
                    True, "Unexpected element in match {0}", e)

            totalWidth += width
            if switchValue == "":
                switchValue = reference
            else:
                # Parenthesize the accumulated value (C's << binds tighter
                # than |) and widen it to unsigned 32 bits before shifting,
                # so promotion to signed int cannot overflow.
                switchValue = "(({0}32)({1}) << {2}) | {3}".format(
                    program.config.uprefix, switchValue, width, reference)

        if totalWidth > 32:
            raise NotSupportedException("{0}: Matching on {1}-bit value",
                                        branch_on, totalWidth)
        serializer.emitIndent()
        serializer.appendFormat("{0}32 {1} = {2};",
                                program.config.uprefix,
                                selectVarName, switchValue)
        serializer.newline()

    def generatePacketLoad(self, startBit, width, alignment, program):
        """Return a C expression reading `width` packet bits.

        The bits start `startBit` bits after the current packet offset. The
        expression is the smallest load_byte/load_half/load_word/load_dword
        covering those bits, shifted right so the wanted bits end up in the
        least-significant positions, and masked when narrower than the load.

        alignment -- the current packet offset modulo 8 (0 <= alignment < 8).
        Raises CompilationException if the bits span more than 8 bytes.
        """
        assert width > 0
        assert alignment < 8
        assert isinstance(startBit, int)
        assert isinstance(width, int)
        assert isinstance(alignment, int)

        firstBitIndex = startBit + alignment
        lastBitIndex = firstBitIndex + width - 1
        # Byte indices relative to the byte holding the current offset;
        # floor division keeps this correct on both Python 2 and 3.
        firstByteIndex = firstBitIndex // 8
        lastByteIndex = lastBitIndex // 8

        bytesToRead = lastByteIndex - firstByteIndex + 1
        if bytesToRead == 1:
            load = "load_byte"
            loadSize = 8
        elif bytesToRead == 2:
            load = "load_half"
            loadSize = 16
        elif bytesToRead <= 4:
            load = "load_word"
            loadSize = 32
        elif bytesToRead <= 8:
            load = "load_dword"
            loadSize = 64
        else:
            raise CompilationException(
                True, "Attempt to load {0} bits spanning more than 8 bytes",
                width)

        readtype = program.config.uprefix + str(loadSize)
        loadInstruction = "{0}({1}, ({2} + {3}) / 8)".format(
            load, program.packetName, program.offsetVariableName, startBit)
        # The load begins at the byte containing firstBitIndex, so the wanted
        # bits start (firstBitIndex % 8) bits into the loaded value; using
        # only `alignment` here would ignore a non-byte-aligned startBit.
        shift = loadSize - (firstBitIndex % 8) - width
        expression = "(({0}) >> ({1}))".format(loadInstruction, shift)
        if width != loadSize:
            expression += " & EBPF_MASK({0}, {1})".format(readtype, width)
        return expression

    def currentReferenceAsString(self, tpl, program):
        """Return the C expression for `current(position, width)`.

        tpl is the (position, width) pair. The packet cursor is assumed to
        be byte aligned here (alignment 0), because headers are expected to
        be a whole number of bytes long.
        """
        assert isinstance(tpl, tuple)
        if len(tpl) != 2:
            raise CompilationException(
                True, "{0} Expected a tuple with 2 elements", tpl)

        position, width = tpl
        return self.generatePacketLoad(position, width, 0, program)

    def _serializeGuardedGoto(self, serializer, guard, target, program):
        """Emit a jump to the label of `target`.

        guard is a C line such as "if (...)" or "else" written before an
        indented goto; None emits a bare goto at the current indentation.
        Raises CompilationException for unsupported targets.
        """
        label = program.getLabel(target)
        if not isinstance(target, (p4_parse_state, p4_table,
                                   p4_conditional_node)):
            if isinstance(target, p4_parser_exception):
                raise CompilationException(True, "Not yet implemented")
            raise CompilationException(
                True, "Unexpected element in match case {0}", target)

        serializer.emitIndent()
        if guard is not None:
            serializer.append(guard)
            serializer.newline()
            serializer.increaseIndent()
            serializer.emitIndent()
            serializer.appendFormat("goto {0};", label)
            serializer.decreaseIndent()
        else:
            serializer.appendFormat("goto {0};", label)
        serializer.newline()

    def serializeCases(self, selectVarName, serializer, branch_to, program):
        """Emit an if-chain jumping to the target of the matching case.

        Non-default cases are emitted first and the default (if any) last,
        independently of branch_to's iteration order. Without a default,
        the program's error variable is set to p4_pe_unhandled_select and
        control jumps to the `end` label.
        """
        assert isinstance(selectVarName, str)
        assert isinstance(serializer, programSerializer.ProgramSerializer)
        assert isinstance(program, ebpfProgram.EbpfProgram)

        cases = []
        hasDefault = False
        defaultTarget = None
        for key, target in branch_to.items():
            if isinstance(key, (int, tuple)):
                cases.append((key, target))
            elif isinstance(key, p4_parse_value_set):
                raise NotSupportedException("{0}: Parser value sets", key)
            elif key is P4_DEFAULT:
                hasDefault = True
                defaultTarget = target
            else:
                raise CompilationException(
                    True, "Unexpected element in match case {0}", key)

        for key, target in cases:
            if isinstance(key, int):
                guard = "if ({0} == {1})".format(selectVarName, key)
            else:
                # NOTE(review): treats key as (mask, value); confirm against
                # how p4-hlir orders masked select cases.
                guard = "if (({0} & {1}) == {2})".format(
                    selectVarName, key[0], key[1])
            self._serializeGuardedGoto(serializer, guard, target, program)

        if hasDefault:
            # The default is an unconditional jump, so it must come after
            # every other case or those cases would be unreachable.
            guard = "else" if cases else None
            self._serializeGuardedGoto(serializer, guard, defaultTarget,
                                       program)
        else:
            # Must create default if it is missing.  This code is not inside
            # a C switch, so a plain goto is used rather than a `default:`
            # label.
            serializer.emitIndent()
            serializer.appendFormat(
                "{0} = p4_pe_unhandled_select;", program.errorName)
            serializer.newline()
            serializer.emitIndent()
            serializer.appendLine("goto end;")

    def serializeBranch(self, serializer, branch_on, branch_to, program):
        """Emit the transition out of this parser state.

        An empty branch_on means an unconditional jump to the single target
        in branch_to; otherwise a select variable and case chain are emitted.
        """
        assert isinstance(serializer, programSerializer.ProgramSerializer)
        assert isinstance(program, ebpfProgram.EbpfProgram)

        if branch_on == []:
            # list() keeps this working when values() is a view (Python 3).
            dest = list(branch_to.values())[0]
            serializer.emitIndent()
            name = program.getLabel(dest)
            serializer.appendFormat("goto {0};", name)
            serializer.newline()
        elif isinstance(branch_on, list):
            tmpvar = program.generateNewName("tmp")
            self.serializeSelect(tmpvar, serializer, branch_on, program)
            self.serializeCases(tmpvar, serializer, branch_to, program)
        else:
            raise CompilationException(
                True, "Unexpected branch_on {0}", branch_on)

    def serializeOperation(self, serializer, op, program):
        """Emit one parser call-sequence operation (extract or set)."""
        assert isinstance(serializer, programSerializer.ProgramSerializer)
        assert isinstance(program, ebpfProgram.EbpfProgram)

        operation = op[0]
        if operation is parse_call.extract:
            self.serializeExtract(serializer, op[1], program)
        elif operation is parse_call.set:
            self.serializeMetadataSet(serializer, op[1], op[2], program)
        else:
            raise CompilationException(
                True, "Unexpected operation in parser {0}", op)

    def serializeFieldExtract(self, serializer, headerInstanceName,
                              index, field, alignment, program):
        """Emit C code extracting one header field from the packet.

        headerInstanceName -- C name of the header instance.
        index -- "" or a stack subscript such as "[2]".
        field -- the EbpfField to extract.
        alignment -- the current packet offset modulo 8.

        The "valid" field is simply set to 1. Any other field is guarded by
        a packet-length check (setting p4_pe_header_too_short and jumping to
        `end` on failure), loaded, and the packet offset is then advanced by
        the field width.
        """
        assert isinstance(index, str)
        assert isinstance(headerInstanceName, str)
        assert isinstance(field, ebpfStructType.EbpfField)
        assert isinstance(serializer, programSerializer.ProgramSerializer)
        assert isinstance(alignment, int)
        assert isinstance(program, ebpfProgram.EbpfProgram)

        fieldToExtractTo = headerInstanceName + index + "." + field.name

        serializer.emitIndent()
        width = field.widthInBits()
        if field.name == "valid":
            serializer.appendFormat(
                "{0}.{1} = 1;", program.headerStructName, fieldToExtractTo)
            serializer.newline()
            return

        serializer.appendFormat("if ({0}->len < BYTES({1} + {2})) ",
                                program.packetName,
                                program.offsetVariableName, width)
        serializer.blockStart()
        serializer.emitIndent()
        serializer.appendFormat("{0} = p4_pe_header_too_short;",
                                program.errorName)
        serializer.newline()
        serializer.emitIndent()
        serializer.appendLine("goto end;")
        # TODO: jump to correct exception handler
        serializer.blockEnd(True)

        if width <= 32:
            serializer.emitIndent()
            load = self.generatePacketLoad(0, width, alignment, program)

            serializer.appendFormat("{0}.{1} = {2};",
                                    program.headerStructName,
                                    fieldToExtractTo, load)
            serializer.newline()
        else:
            # Destination is wider than 4 bytes and represented as a byte
            # array, filled one byte at a time. With a non-zero alignment
            # each destination byte straddles two packet bytes, so 16 bits
            # are loaded and shifted down.
            shift = 0 if alignment == 0 else 8 - alignment
            assert shift >= 0
            method = "load_byte" if shift == 0 else "load_half"
            byteCount = (width + 7) // 8
            for i in range(byteCount):
                serializer.emitIndent()
                serializer.appendFormat("{0}.{1}[{2}] = ({3}8)",
                                        program.headerStructName,
                                        fieldToExtractTo, i,
                                        program.config.uprefix)
                serializer.appendFormat("(({0}({1}, ({2} / 8) + {3}) >> {4})",
                                        method, program.packetName,
                                        program.offsetVariableName, i, shift)
                if i == byteCount - 1 and width % 8 != 0:
                    serializer.appendFormat(" & EBPF_MASK({0}8, {1})",
                                            program.config.uprefix, width % 8)
                serializer.append(")")
                serializer.endOfStatement(True)

        serializer.emitIndent()
        serializer.appendFormat("{0} += {1};",
                                program.offsetVariableName, width)
        serializer.newline()

    def serializeExtract(self, serializer, headerInstance, program):
        """Emit C code extracting every field of `headerInstance`.

        For header-stack elements, a check that the stack index variable is
        below the stack size is emitted first (setting
        p4_pe_index_out_of_bounds and jumping to `end` otherwise), and the
        index variable is incremented afterwards.
        """
        assert isinstance(serializer, programSerializer.ProgramSerializer)
        assert isinstance(headerInstance, p4_header_instance)
        assert isinstance(program, ebpfProgram.EbpfProgram)

        isArrayElement = ebpfProgram.EbpfProgram.isArrayElementInstance(
            headerInstance)
        if isArrayElement:
            ebpfStack = program.getStackInstance(headerInstance.base_name)
            assert isinstance(ebpfStack, ebpfInstance.EbpfHeaderStack)

            # NOTE(review): the check (and the increment below) use the
            # stack index variable even for a constant index — confirm.
            serializer.emitIndent()
            serializer.appendFormat("if ({0} >= {1}) ",
                                    ebpfStack.indexVar, ebpfStack.arraySize)
            serializer.blockStart()
            serializer.emitIndent()
            serializer.appendFormat("{0} = p4_pe_index_out_of_bounds;",
                                    program.errorName)
            serializer.newline()
            serializer.emitIndent()
            serializer.appendLine("goto end;")
            serializer.blockEnd(True)

            index = self._arrayIndexSuffix(headerInstance, ebpfStack)
            basetype = ebpfStack.basetype
        else:
            ebpfHeader = program.getHeaderInstance(headerInstance.name)
            basetype = ebpfHeader.type
            index = ""

        # Extract all fields, tracking the bit offset within the current byte.
        alignment = 0
        for field in basetype.fields:
            assert isinstance(field, ebpfStructType.EbpfField)

            self.serializeFieldExtract(serializer, headerInstance.base_name,
                                       index, field, alignment, program)
            alignment = (alignment + field.widthInBits()) % 8

        if isArrayElement:
            # Increment the stack index past the element just extracted.
            serializer.emitIndent()
            serializer.appendFormat("{0}++;", ebpfStack.indexVar)
            serializer.newline()

    def serializeMetadataSet(self, serializer, field, value, program):
        """Emit C code for a parse_call.set operation: field = value.

        value may be an int constant, a (position, width) `current`
        reference, or another field. Destinations wider than 32 bits are
        copied with memcpy, and only whole-byte widths are supported.
        """
        assert isinstance(serializer, programSerializer.ProgramSerializer)
        assert isinstance(program, ebpfProgram.EbpfProgram)
        assert isinstance(field, p4_field)

        dest = program.getInstance(field.instance.name)
        assert isinstance(dest, ebpfInstance.SimpleInstance)
        destType = dest.type
        assert isinstance(destType, ebpfStructType.EbpfStructType)
        destField = destType.getField(field.name)

        destWidth = destField.widthInBits()
        useMemcpy = destWidth > 32
        if useMemcpy and destWidth % 8 != 0:
            raise CompilationException(
                True,
                "{0}: Not implemented: wide field w. sz not multiple of 8",
                field)
        bytesToCopy = destWidth // 8  # only used with memcpy

        serializer.emitIndent()
        destination = "{0}.{1}.{2}".format(
            program.metadataStructName, dest.name, destField.name)
        if isinstance(value, int):
            if useMemcpy:
                raise CompilationException(
                    True,
                    "{0}: Not implemented: copying from wide constant",
                    value)
            source = str(value)
        elif isinstance(value, tuple):
            # NOTE(review): with memcpy this yields "&(<load expression>)",
            # which is not addressable C — confirm wide sources never occur.
            source = self.currentReferenceAsString(value, program)
        elif isinstance(value, p4_field):
            sourceInstance = program.getInstance(value.instance.name)
            if isinstance(sourceInstance, ebpfInstance.EbpfMetadata):
                sourceStruct = program.metadataStructName
            else:
                sourceStruct = program.headerStructName
            source = "{0}.{1}.{2}".format(
                sourceStruct, sourceInstance.name, value.name)
        else:
            raise CompilationException(
                True, "Unexpected type for parse_call.set {0}", value)

        if useMemcpy:
            serializer.appendFormat("memcpy(&{0}, &{1}, {2})",
                                    destination, source, bytesToCopy)
        else:
            serializer.appendFormat("{0} = {1}", destination, source)

        serializer.endOfStatement(True)
428