1# Copyright (c) Barefoot Networks, Inc.
2# Licensed under the Apache License, Version 2.0 (the "License")
3
4from collections import defaultdict, OrderedDict
5from p4_hlir.hlir import parse_call, p4_field, p4_parse_value_set, \
6    P4_DEFAULT, p4_parse_state, p4_table, \
7    p4_conditional_node, p4_parser_exception, \
8    p4_header_instance, P4_NEXT
9
10import ebpfProgram
11import ebpfInstance
12import ebpfType
13import ebpfStructType
14from topoSorting import Graph
15from programSerializer import ProgramSerializer
16
def produce_parser_topo_sorting(hlir):
    """Return the parser-extracted header instances in topological order.

    Every path through the parse graph is followed, starting from the
    "start" parse state. Each extracted header instance becomes a node of a
    ``Graph``. Two consecutive extractions on the same path become an edge,
    and the first extraction seen becomes the graph root.

    Header-stack ("virtual") extractions are resolved to a concrete element
    ``base_name[i]`` using a per-path counter. A path is abandoned as soon
    as a stack's counter exceeds that stack's ``max_index``.

    This function is copied from the P4 behavioral model implementation.

    :param hlir: the P4 HLIR; its ``p4_parse_states`` and
        ``p4_header_instances`` mappings are read.
    :return: whatever ``Graph.produce_topo_sorting()`` yields (iterated by
        callers as header instances exposing ``.name``).
    """
    graph = Graph()

    def visit(state, previous, stack_counters):
        # ``previous`` is the graph node of the last header extracted on the
        # current path (None before the first extraction).
        assert isinstance(state, p4_parse_state)

        for call in state.call_sequence:
            if call[0] != parse_call.extract:
                continue

            instance = call[1]
            if instance.virtual:
                stack = instance.base_name
                position = stack_counters[stack]
                if position > instance.max_index:
                    # Stack exhausted along this path: stop exploring it.
                    return
                stack_counters[stack] = position + 1
                instance = hlir.p4_header_instances[stack + "[%d]" % position]

            if instance not in graph:
                graph.add_node(instance)
            node = graph.get_node(instance)

            if previous:
                previous.add_edge_to(node)
            else:
                graph.root = instance
            previous = node

        for _case, successor in state.branch_to.items():
            # Only real parse states are followed. Each branch gets its own
            # counter copy so sibling paths do not share stack positions.
            if successor and isinstance(successor, p4_parse_state):
                visit(successor, previous, stack_counters.copy())

    visit(hlir.p4_parse_states["start"], None, defaultdict(int))
    return graph.produce_topo_sorting()
60
class EbpfDeparser(object):
    """Generates the C code that writes parsed headers back into the packet.

    Headers are emitted in the order returned by
    ``produce_parser_topo_sorting``. Each valid header's fields are written
    with ``bpf_dins_pkt`` at the running bit offset stored in the C variable
    named by ``program.offsetVariableName``.
    """

    def __init__(self, hlir):
        # Emit headers in the same order in which the parser extracts them.
        header_topo_sorting = produce_parser_topo_sorting(hlir)
        self.headerOrder = [hdr.name for hdr in header_topo_sorting]

    def serialize(self, serializer, program):
        """Emit the deparser block for ``program`` into ``serializer``.

        The generated block resets the packet bit offset to 0. It then emits
        a store sequence for every header named in ``self.headerOrder``.
        """
        assert isinstance(serializer, ProgramSerializer)
        assert isinstance(program, ebpfProgram.EbpfProgram)

        serializer.emitIndent()
        serializer.blockStart()
        serializer.emitIndent()
        serializer.appendLine("/* Deparser */")
        serializer.emitIndent()
        serializer.appendFormat("{0} = 0;", program.offsetVariableName)
        serializer.newline()
        for h in self.headerOrder:
            header = program.getHeaderInstance(h)
            self.serializeHeaderEmit(header, serializer, program)
        serializer.blockEnd(True)

    def _arrayIndexExpression(self, p4header, ebpfStack):
        """Return the C subscript (e.g. ``"[2]"``) selecting a stack element.

        :param p4header: HLIR header instance that is a stack element; its
            ``index`` is either an int or ``P4_NEXT``.
        :param ebpfStack: the EbpfHeaderStack owning the element; its
            ``indexVar`` is used for ``P4_NEXT``.
        :raises CompilationException: for any other index kind.
        """
        if isinstance(p4header.index, int):
            return "[" + str(p4header.index) + "]"
        if p4header.index is P4_NEXT:
            return "[" + ebpfStack.indexVar + "]"
        # NOTE(review): CompilationException is not imported by this file's
        # visible import block. Confirm it is in scope; otherwise this raise
        # surfaces as a NameError instead.
        raise CompilationException(
            True, "Unexpected index for array {0}",
            p4header.index)

    def serializeHeaderEmit(self, header, serializer, program):
        """Emit code that writes one header instance's fields into the packet.

        The stores are wrapped in ``if (<headers>.<name>.valid) { ... }``, so
        headers that are not valid are skipped at runtime.

        :param header: the EbpfHeader to emit.
        :param serializer: ProgramSerializer receiving the generated code.
        :param program: the EbpfProgram being compiled.
        """
        assert isinstance(header, ebpfInstance.EbpfHeader)
        assert isinstance(serializer, ProgramSerializer)
        assert isinstance(program, ebpfProgram.EbpfProgram)
        p4header = header.hlirInstance
        assert isinstance(p4header, p4_header_instance)

        serializer.emitIndent()
        serializer.appendFormat("if ({0}.{1}.valid) ",
                                program.headerStructName, header.name)
        serializer.blockStart()

        if ebpfProgram.EbpfProgram.isArrayElementInstance(p4header):
            ebpfStack = program.getStackInstance(p4header.base_name)
            assert isinstance(ebpfStack, ebpfInstance.EbpfHeaderStack)
            index = self._arrayIndexExpression(p4header, ebpfStack)
            basetype = ebpfStack.basetype
        else:
            ebpfHeader = program.getHeaderInstance(p4header.name)
            basetype = ebpfHeader.type
            index = ""

        # Bit position inside the current byte. It is reset for every header,
        # i.e. each header is assumed to start on a byte boundary.
        alignment = 0
        for field in basetype.fields:
            assert isinstance(field, ebpfStructType.EbpfField)

            self.serializeFieldEmit(serializer, p4header.base_name,
                                    index, field, alignment, program)
            alignment += field.widthInBits()
            alignment = alignment % 8
        serializer.blockEnd(True)

    def serializeFieldEmit(self, serializer, name, index,
                           field, alignment, program):
        """Emit the store(s) for one header field and advance the bit offset.

        Fields of at most 32 bits are written with a single store. Wider
        fields are represented as byte arrays and written one byte at a time.
        The synthetic ``valid`` field is skipped.

        :param name: header (or header-stack base) name in the header struct.
        :param index: C subscript for stack elements, or ``""``.
        :param alignment: bit position (0-7) inside the current byte.
        """
        assert isinstance(index, str)
        assert isinstance(name, str)
        assert isinstance(field, ebpfStructType.EbpfField)
        assert isinstance(serializer, ProgramSerializer)
        assert isinstance(alignment, int)
        assert isinstance(program, ebpfProgram.EbpfProgram)

        if field.name == "valid":
            return

        fieldToEmit = (program.headerStructName + "." + name +
                       index + "." + field.name)
        width = field.widthInBits()
        if width <= 32:
            store = self.generatePacketStore(fieldToEmit, 0, alignment,
                                             width, program)
            serializer.emitIndent()
            serializer.appendLine(store)
        else:
            # Destination is bigger than 4 bytes and
            # represented as a byte array.
            # Floor division keeps this an int on Python 3 (range() needs it).
            # NOTE(review): the last byte is written with 8 bits even when
            # width % 8 != 0; confirm this cannot clobber following bits.
            b = (width + 7) // 8
            for i in range(0, b):
                serializer.emitIndent()
                store = self.generatePacketStore(fieldToEmit + "[" + str(i) + "]",
                                                 i,
                                                 alignment,
                                                 8, program)
                serializer.appendLine(store)

        # The offset variable counts bits, so advance by the field width.
        serializer.emitIndent()
        serializer.appendFormat("{0} += {1};",
                                program.offsetVariableName, width)
        serializer.newline()

    def generatePacketStore(self, value, offset, alignment, width, program):
        """Return a ``bpf_dins_pkt`` C statement storing ``value``.

        The target byte is ``<offsetVariable> / 8 + offset``. ``alignment``
        is the starting bit inside that byte and ``width`` is the number of
        bits written.
        """
        assert width > 0
        assert alignment < 8
        assert isinstance(width, int)
        assert isinstance(alignment, int)

        return "bpf_dins_pkt({0}, {1} / 8 + {2}, {3}, {4}, {5});".format(
            program.packetName,
            program.offsetVariableName,
            offset,
            alignment,
            width,
            value
        )
173