1# Copyright (c) Barefoot Networks, Inc.
2# Licensed under the Apache License, Version 2.0 (the "License")
3
4from p4_hlir.hlir import p4_match_type, p4_field, p4_table, p4_header_instance
5from programSerializer import ProgramSerializer
6from compilationException import *
7import ebpfProgram
8import ebpfInstance
9import ebpfCounter
10import ebpfStructType
11import ebpfAction
12
13
class EbpfTableKeyField(object):
    """One field of a table lookup key.

    Pairs an eBPF header/metadata field with an optional mask, and knows how
    to declare the field inside the key struct and how to emit the C code
    that copies the packet value into a key instance.
    """

    def __init__(self, fieldname, instance, field, mask):
        """
        fieldname: name of this field inside the generated key struct.
        instance:  eBPF instance (simple header/metadata or header stack)
                   the value is read from.
        field:     the EbpfField inside that instance.
        mask:      optional mask applied to the value; None for no mask.
        """
        assert isinstance(instance, ebpfInstance.EbpfInstanceBase)
        assert isinstance(field, ebpfStructType.EbpfField)

        self.keyFieldName = fieldname
        self.instance = instance
        self.field = field
        self.mask = mask

    def serializeType(self, serializer):
        """Emit the declaration of this field inside the key struct."""
        assert isinstance(serializer, ProgramSerializer)
        ftype = self.field.type
        serializer.emitIndent()
        ftype.declare(serializer, self.keyFieldName, False)
        serializer.endOfStatement(True)

    def serializeConstruction(self, keyName, serializer, program):
        """Emit C code that copies the packet value into keyName.<field>.

        Fields of at most 32 bits are assigned (and masked if a mask is
        set); wider fields are copied with memcpy and cannot be masked.

        Raises NotSupportedException when a mask is set on a field wider
        than 32 bits.
        """
        assert isinstance(serializer, ProgramSerializer)
        assert isinstance(keyName, str)
        assert isinstance(program, ebpfProgram.EbpfProgram)

        if self.mask is not None:
            maskExpression = " & {0}".format(self.mask)
        else:
            maskExpression = ""

        if isinstance(self.instance, ebpfInstance.EbpfMetadata):
            base = program.metadataStructName
        else:
            base = program.headerStructName

        if isinstance(self.instance, ebpfInstance.SimpleInstance):
            source = "{0}.{1}.{2}".format(
                base, self.instance.name, self.field.name)
        else:
            assert isinstance(self.instance, ebpfInstance.EbpfHeaderStack)
            source = "{0}.{1}[{2}].{3}".format(
                base, self.instance.name,
                self.instance.hlirInstance.index, self.field.name)
        destination = "{0}.{1}".format(keyName, self.keyFieldName)
        size = self.field.widthInBits()

        serializer.emitIndent()
        if size <= 32:
            serializer.appendFormat("{0} = ({1}){2};",
                                    destination, source, maskExpression)
        else:
            if self.mask is not None:
                raise NotSupportedException(
                    "{0} Mask wider than 32 bits", self.field.hlirType)
            # Round up so widths that are not a multiple of 8 also copy their
            # final partial byte; source and destination share the same type,
            # so this never exceeds either buffer. Floor division keeps the
            # byte count an integer (plain "/" yields a float on Python 3).
            byteCount = (size + 7) // 8
            serializer.appendFormat(
                "memcpy(&{0}, &{1}, {2});", destination, source, byteCount)

        serializer.newline()
69
70
class EbpfTableKey(object):
    """The lookup key of a table: an ordered list of EbpfTableKeyField."""

    def __init__(self, match_fields, program):
        """Build key fields from the HLIR match_fields of a table.

        match_fields: sequence whose entries are indexed as
                      (field, match type, mask).
        program:      the enclosing EbpfProgram, used to resolve instances.

        Raises NotSupportedException for ternary, LPM and range matches.
        """
        assert isinstance(program, ebpfProgram.EbpfProgram)

        self.expressions = []  # kept for compatibility; not populated here
        self.fields = []
        self.masks = []
        self.fieldNamePrefix = "key_field_"
        self.program = program

        for fieldNumber, f in enumerate(match_fields):
            field, matchType, mask = f[0], f[1], f[2]

            if ((matchType is p4_match_type.P4_MATCH_TERNARY) or
                    (matchType is p4_match_type.P4_MATCH_LPM) or
                    (matchType is p4_match_type.P4_MATCH_RANGE)):
                # NotSupportedException takes the format string first
                # (unlike a bug flag), as everywhere else in this file.
                raise NotSupportedException("Match type {0}", matchType)

            if matchType is p4_match_type.P4_MATCH_VALID:
                # we should be really checking the valid field;
                # p4_field is a header instance
                assert isinstance(field, p4_header_instance)
                instance = field
                fieldname = "valid"
            else:
                assert isinstance(field, p4_field)
                instance = field.instance
                fieldname = field.name

            # Resolve the eBPF instance once and take its struct type from it.
            if ebpfProgram.EbpfProgram.isArrayElementInstance(instance):
                eInstance = program.getStackInstance(instance.base_name)
                assert isinstance(eInstance, ebpfInstance.EbpfHeaderStack)
                basetype = eInstance.basetype
            else:
                eInstance = program.getInstance(instance.name)
                assert isinstance(eInstance, ebpfInstance.SimpleInstance)
                basetype = eInstance.type

            ebpfField = basetype.getField(fieldname)
            assert isinstance(ebpfField, ebpfStructType.EbpfField)

            fieldName = self.fieldNamePrefix + str(fieldNumber)
            keyField = EbpfTableKeyField(fieldName, eInstance, ebpfField, mask)

            self.fields.append(keyField)
            self.masks.append(mask)

    @staticmethod
    def fieldRank(field):
        """Sort key for key fields: the alignment of the field's type."""
        assert isinstance(field, EbpfTableKeyField)
        return field.field.type.alignment()

    def serializeType(self, serializer, keyTypeName):
        """Emit the C declaration `struct keyTypeName { ... };`."""
        assert isinstance(serializer, ProgramSerializer)
        serializer.emitIndent()
        serializer.appendFormat("struct {0} ", keyTypeName)
        serializer.blockStart()

        # Sort fields by decreasing alignment so that the struct has no
        # padding. Padding may cause eBPF verification to fail, since
        # padding bytes are not initialized.
        fieldOrder = sorted(
            self.fields, key=EbpfTableKey.fieldRank, reverse=True)
        for f in fieldOrder:
            assert isinstance(f, EbpfTableKeyField)
            f.serializeType(serializer)

        serializer.blockEnd(False)
        serializer.endOfStatement(True)

    def serializeConstruction(self, serializer, keyName, program):
        """Emit C code filling every field of the key variable keyName."""
        serializer.emitIndent()
        serializer.appendLine("/* construct key */")

        for f in self.fields:
            f.serializeConstruction(keyName, serializer, program)
155
156
class EbpfTable(object):
    """Code generator for one P4 match-action table.

    The table is emitted as two maps: a data map (key -> action tag plus
    action arguments) and a one-entry map holding the default action used
    on a miss.
    """

    # noinspection PyUnresolvedReferences
    def __init__(self, hlirtable, program, config):
        """
        hlirtable: the p4_table being compiled.
        program:   the enclosing EbpfProgram (name generation, lookups,
                   warnings).
        config:    target configuration used to emit map declarations.

        Raises NotSupportedException for tables using an action profile.
        """
        assert isinstance(hlirtable, p4_table)
        assert isinstance(program, ebpfProgram.EbpfProgram)

        self.name = hlirtable.name
        self.hlirtable = hlirtable
        self.config = config

        self.defaultActionMapName = (program.reservedPrefix +
                                     self.name + "_miss")
        self.key = EbpfTableKey(hlirtable.match_fields, program)
        self.size = hlirtable.max_size
        if self.size is None:
            program.emitWarning(
                "{0} does not specify a max_size; using 1024", hlirtable)
            self.size = 1024
        self.isHash = True  # TODO: try to guess arrays when possible
        self.dataMapName = self.name
        self.actionEnumName = program.generateNewName(self.name + "_actions")
        self.keyTypeName = program.generateNewName(self.name + "_key")
        self.valueTypeName = program.generateNewName(self.name + "_value")
        self.actions = []

        if hlirtable.action_profile is not None:
            raise NotSupportedException("{0}: action_profile tables",
                                        hlirtable)
        if hlirtable.support_timeout:
            program.emitWarning("{0}: table timeout {1}; ignoring",
                                hlirtable, NotSupportedException.archError)

        self.counters = []
        if hlirtable.attached_counters is not None:
            for c in hlirtable.attached_counters:
                ctr = program.getCounter(c.name)
                assert isinstance(ctr, ebpfCounter.EbpfCounter)
                self.counters.append(ctr)

        if (len(hlirtable.attached_meters) > 0 or
                len(hlirtable.attached_registers) > 0):
            program.emitWarning("{0}: meters/registers {1}; ignored",
                                hlirtable, NotSupportedException.archError)

        for a in hlirtable.actions:
            action = program.getAction(a)
            self.actions.append(action)

    def serializeKeyType(self, serializer):
        """Emit the key struct declaration."""
        assert isinstance(serializer, ProgramSerializer)
        self.key.serializeType(serializer, self.keyTypeName)

    def serializeActionArguments(self, serializer, action):
        """Emit the argument struct of one action (a union member)."""
        assert isinstance(serializer, ProgramSerializer)
        assert isinstance(action, ebpfAction.EbpfActionBase)
        action.serializeArgumentsAsStruct(serializer)

    def serializeValueType(self, serializer):
        """Emit the action enum and the value struct (tag + union of args)."""
        assert isinstance(serializer, ProgramSerializer)
        #  create an enum with tags for all actions
        serializer.emitIndent()
        serializer.appendFormat("enum {0} ", self.actionEnumName)
        serializer.blockStart()

        for a in self.actions:
            name = a.name
            serializer.emitIndent()
            serializer.appendFormat("{0}_{1},", self.name, name)
            serializer.newline()

        serializer.blockEnd(False)
        serializer.endOfStatement(True)

        # a type-safe union: a struct with a tag and an union
        serializer.emitIndent()
        serializer.appendFormat("struct {0} ", self.valueTypeName)
        serializer.blockStart()

        serializer.emitIndent()
        # Temporary workaround for a bcc bug: the tag is declared as a plain
        # 32-bit unsigned integer instead of "enum <actionEnumName> action;".
        serializer.appendFormat("{0}32 action;",
                                self.config.uprefix)
        serializer.newline()

        serializer.emitIndent()
        serializer.append("union ")
        serializer.blockStart()

        for a in self.actions:
            self.serializeActionArguments(serializer, a)

        serializer.blockEnd(False)
        serializer.space()
        serializer.appendLine("u;")
        serializer.blockEnd(False)
        serializer.endOfStatement(True)

    def serialize(self, serializer, program):
        """Emit the key/value types and both map declarations."""
        assert isinstance(serializer, ProgramSerializer)
        assert isinstance(program, ebpfProgram.EbpfProgram)

        self.serializeKeyType(serializer)
        self.serializeValueType(serializer)

        self.config.serializeTableDeclaration(
            serializer, self.dataMapName, self.isHash,
            "struct " + self.keyTypeName,
            "struct " + self.valueTypeName, self.size)
        self.config.serializeTableDeclaration(
            serializer, self.defaultActionMapName, False,
            program.arrayIndexType, "struct " + self.valueTypeName, 1)

    def serializeCode(self, serializer, program, nextNode):
        """Emit the C code that looks up this table and runs the action.

        nextNode: the successor node used whenever the HLIR next_ mapping
                  has a None entry (fall-through).
        """
        assert isinstance(serializer, ProgramSerializer)
        assert isinstance(program, ebpfProgram.EbpfProgram)

        hitVarName = program.reservedPrefix + "hit"
        keyname = "key"
        valueName = "value"

        serializer.newline()
        serializer.emitIndent()
        serializer.appendFormat("{0}:", program.getLabel(self))
        serializer.newline()

        serializer.emitIndent()
        serializer.blockStart()

        serializer.emitIndent()
        serializer.appendFormat("{0}8 {1};", program.config.uprefix, hitVarName)
        serializer.newline()

        serializer.emitIndent()
        serializer.appendFormat("struct {0} {1} = {{}};", self.keyTypeName, keyname)
        serializer.newline()

        serializer.emitIndent()
        serializer.appendFormat(
            "struct {0} *{1};", self.valueTypeName, valueName)
        serializer.newline()

        self.key.serializeConstruction(serializer, keyname, program)

        serializer.emitIndent()
        serializer.appendFormat("{0} = 1;", hitVarName)
        serializer.newline()

        serializer.emitIndent()
        serializer.appendLine("/* perform lookup */")
        serializer.emitIndent()
        program.config.serializeLookup(
            serializer, self.dataMapName, keyname, valueName)
        serializer.newline()

        serializer.emitIndent()
        serializer.appendFormat("if ({0} == NULL) ", valueName)
        serializer.blockStart()

        serializer.emitIndent()
        serializer.appendFormat("{0} = 0;", hitVarName)
        serializer.newline()

        serializer.emitIndent()
        serializer.appendLine("/* miss; find default action */")
        serializer.emitIndent()
        program.config.serializeLookup(
            serializer, self.defaultActionMapName,
            program.zeroKeyName, valueName)
        serializer.newline()
        serializer.blockEnd(True)

        if len(self.counters) > 0:
            serializer.emitIndent()
            serializer.append("else ")
            serializer.blockStart()
            for c in self.counters:
                assert isinstance(c, ebpfCounter.EbpfCounter)
                if c.autoIncrement:
                    serializer.emitIndent()
                    serializer.blockStart()
                    c.serializeCode(keyname, serializer, program)
                    serializer.blockEnd(True)
            serializer.blockEnd(True)

        serializer.emitIndent()
        serializer.appendFormat("if ({0} != NULL) ", valueName)
        serializer.blockStart()
        serializer.emitIndent()
        serializer.appendLine("/* run action */")
        self.runAction(serializer, self.name, valueName, program, nextNode)

        # Keep the HLIR successor mapping under its own name: the nextNode
        # parameter is still needed as the fallback for None entries.
        nextNodes = self.hlirtable.next_
        if "hit" in nextNodes:
            node = nextNodes["hit"]
            if node is None:
                node = nextNode
            label = program.getLabel(node)
            serializer.emitIndent()
            # Test the variable actually declared above, not a bare "hit".
            serializer.appendFormat("if ({0}) goto {1};", hitVarName, label)
            serializer.newline()

            node = nextNodes["miss"]
            if node is None:
                node = nextNode
            label = program.getLabel(node)
            serializer.emitIndent()
            serializer.appendFormat("else goto {0};", label)
            serializer.newline()

        serializer.blockEnd(True)
        if "hit" not in nextNodes:
            # Catch-all
            serializer.emitIndent()
            serializer.appendFormat("goto end;")
            serializer.newline()

        serializer.blockEnd(True)

    def runAction(self, serializer, tableName, valueName, program, nextNode):
        """Emit a switch over valueName->action running each action's body.

        After an action, jump to its HLIR successor (nextNode when the
        successor entry is None); actions without an entry just break.
        """
        serializer.emitIndent()
        serializer.appendFormat("switch ({0}->action) ", valueName)
        serializer.blockStart()

        nextNodes = self.hlirtable.next_
        for a in self.actions:
            assert isinstance(a, ebpfAction.EbpfActionBase)

            serializer.emitIndent()
            serializer.appendFormat("case {0}_{1}: ", tableName, a.name)
            serializer.newline()
            serializer.emitIndent()
            serializer.blockStart()
            a.serializeBody(serializer, valueName, program)
            serializer.blockEnd(True)
            serializer.emitIndent()

            if a.hliraction in nextNodes:
                node = nextNodes[a.hliraction]
                if node is None:
                    node = nextNode
                label = program.getLabel(node)
                serializer.appendFormat("goto {0};", label)
            else:
                serializer.appendFormat("break;")
            serializer.newline()

        serializer.blockEnd(True)
405