1# Copyright (c) Barefoot Networks, Inc.
2# Licensed under the Apache License, Version 2.0 (the "License")
3
4from p4_hlir.hlir import p4_counter, P4_DIRECT, P4_COUNTER_BYTES
5from programSerializer import ProgramSerializer
6from compilationException import *
7import ebpfTable
8import ebpfProgram
9
10
class EbpfCounter(object):
    """eBPF code generator for a single P4 ``counter`` object.

    Only direct counters (bound to a table via P4_DIRECT) whose width fits
    in 64 bits are accepted; anything else raises NotSupportedException from
    the constructor.  The counter values live in a map named after the
    counter, with values of type ``<uprefix>64``.
    """

    @staticmethod
    def _supportsWidth(width):
        """Return True if a counter of *width* bits fits in a 64-bit value.

        A width of None (no width constraint on the counter) is accepted.
        This is checked explicitly rather than relying on Python 2's
        ordering of None below every integer, which raises TypeError on
        Python 3.
        """
        return width is None or width <= 64

    # noinspection PyUnresolvedReferences
    def __init__(self, hlircounter, program):
        """Validate *hlircounter* and precompute the names used for codegen.

        :param hlircounter: the p4_hlir ``p4_counter`` being compiled.
        :param program: the enclosing ``ebpfProgram.EbpfProgram``.
        :raises NotSupportedException: if the counter is wider than 64 bits
            or is not a direct counter.
        """
        assert isinstance(hlircounter, p4_counter)
        assert isinstance(program, ebpfProgram.EbpfProgram)

        self.name = hlircounter.name
        self.hlircounter = hlircounter

        width = hlircounter.min_width
        # ebpf counters only work on 64-bits
        if not self._supportsWidth(width):
            raise NotSupportedException(
                "{0}: Counters with {1} bits", hlircounter, width)
        self.valueTypeName = program.config.uprefix + "64"

        self.dataMapName = self.name

        binding = hlircounter.binding
        if binding is None or binding[0] != P4_DIRECT:
            raise NotSupportedException(
                "{0}: counter which is not direct", hlircounter)

        # Always True at this point, since non-direct counters were rejected
        # above.  The flag is kept so the static-counter branches in
        # getSize/serialize remain meaningful if that check is relaxed.
        self.autoIncrement = (binding is not None and
                              binding[0] == P4_DIRECT)

        # Byte counters add the `len` field of the program's packet
        # variable; every other counter type adds 1 per update.
        # NOTE(review): a packets-and-bytes counter type, if p4_hlir
        # produces one, would only count packets here — confirm.
        if hlircounter.type is P4_COUNTER_BYTES:
            self.increment = "{0}->len".format(program.packetName)
        else:
            self.increment = "1"

    def getSize(self, program):
        """Return the number of entries to allocate for the counter map.

        Uses the counter's ``instance_count`` when set.  Otherwise, a direct
        counter uses the size of its bound table.  As a last resort it emits
        a warning through *program* and returns 1024.
        """
        if self.hlircounter.instance_count is not None:
            return self.hlircounter.instance_count
        if self.autoIncrement:
            return self.getTable(program).size
        program.emitWarning(
            "{0} does not specify a max_size; using 1024", self.hlircounter)
        return 1024

    def getTable(self, program):
        """Return the EbpfTable this counter is bound to (``binding[1]``)."""
        table = program.getTable(self.hlircounter.binding[1].name)
        assert isinstance(table, ebpfTable.EbpfTable)
        return table

    def serialize(self, serializer, program):
        """Emit the declaration of the map that stores this counter."""
        assert isinstance(serializer, ProgramSerializer)

        # Direct counters have the same key as the associated table
        # Static counters have integer keys
        if self.autoIncrement:
            keyTypeName = "struct " + self.getTable(program).keyTypeName
        else:
            keyTypeName = program.config.uprefix + "32"
        # NOTE(review): the literal True is passed straight to the target
        # config; presumably it selects the map kind (e.g. hash) — confirm.
        program.config.serializeTableDeclaration(
            serializer, self.dataMapName, True, keyTypeName,
            self.valueTypeName, self.getSize(program))

    def serializeCode(self, keyname, serializer, program):
        """Emit C code that increments this counter for the key *keyname*.

        The generated code declares the locals ``ctrvalue`` and ``init_val``.
        It looks up *keyname* in the counter map.  If an entry exists, it
        adds ``self.increment`` with ``__sync_fetch_and_add``.  Otherwise it
        inserts a new entry whose value is ``self.increment``.

        NOTE(review): the declared local names are fixed, so two expansions
        in the same C scope would clash — presumably callers emit each in its
        own scope.  The lookup-then-insert sequence is also not a single
        atomic step, so concurrent first hits may lose an update — confirm
        whether that is acceptable.
        """
        assert isinstance(serializer, ProgramSerializer)
        assert isinstance(program, ebpfProgram.EbpfProgram)

        serializer.emitIndent()
        serializer.appendFormat("/* Update counter {0} */", self.name)
        serializer.newline()

        valueName = "ctrvalue"
        initValuename = "init_val"

        # Local declarations: a pointer into the map and an initial value.
        serializer.emitIndent()
        serializer.appendFormat("{0} *{1};", self.valueTypeName, valueName)
        serializer.newline()
        serializer.emitIndent()
        serializer.appendFormat("{0} {1};", self.valueTypeName, initValuename)
        serializer.newline()

        serializer.emitIndent()
        serializer.appendLine("/* perform lookup */")
        serializer.emitIndent()
        program.config.serializeLookup(
            serializer, self.dataMapName, keyname, valueName)
        serializer.newline()

        # Existing entry: atomically add the increment in place.
        serializer.emitIndent()
        serializer.appendFormat("if ({0} != NULL) ", valueName)
        serializer.newline()
        serializer.increaseIndent()
        serializer.emitIndent()
        serializer.appendFormat("__sync_fetch_and_add({0}, {1});",
                                valueName, self.increment)
        serializer.newline()
        serializer.decreaseIndent()
        serializer.emitIndent()

        # Missing entry: insert one whose initial value is the increment.
        serializer.append("else ")
        serializer.blockStart()
        serializer.emitIndent()
        serializer.appendFormat("{0} = {1};", initValuename, self.increment)
        serializer.newline()

        serializer.emitIndent()
        program.config.serializeUpdate(
            serializer, self.dataMapName, keyname, initValuename)
        serializer.newline()
        serializer.blockEnd(True)
117