1"""Code generator for Code Completion Model Inference.
2
3Tool runs on the Decision Forest model defined in {model} directory.
4It generates two files: {output_dir}/{filename}.h and {output_dir}/{filename}.cpp
The generated files define the Example class named {cpp_class} having all the features as class members.
6The generated runtime provides an `Evaluate` function which can be used to score a code completion candidate.
7"""
8
9import argparse
10import json
11import struct
12
13
class CppClass:
    """A possibly namespace-qualified C++ class name, split into its parts.

    Attributes:
        ns: enclosing namespace names, outermost first; empty components
            (e.g. from a leading "::") are dropped.
        name: the unqualified class name.
    """

    def __init__(self, cpp_class):
        """Parses `cpp_class`, a "::"-separated qualified name.

        Raises:
            ValueError: if the component after the last "::" is empty.
        """
        *enclosing, class_name = cpp_class.split("::")
        self.ns = [part for part in enclosing if part]
        self.name = class_name
        if not class_name:
            raise ValueError("Empty class name.")

    def ns_begin(self):
        """Returns one `namespace X {` line per namespace, outermost first."""
        return "\n".join("namespace %s {" % ns for ns in self.ns)

    def ns_end(self):
        """Returns one closing-brace line per namespace, innermost first."""
        return "\n".join("} // namespace %s" % ns for ns in self.ns[::-1])
34
35
def header_guard(filename):
    """Returns the header guard macro name for the generated header.

    The file name is upper-cased, and every character that cannot appear in
    a C preprocessor identifier (anything other than A-Z, 0-9 and '_') is
    replaced by '_'. This keeps the guard valid for names such as
    "my-model" or "model.v2". Names made only of identifier characters
    yield exactly the same guard as a plain upper-casing would.
    """
    allowed = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"
    sanitized = "".join(
        ch if ch in allowed else "_" for ch in filename.upper())
    return "GENERATED_DECISION_FOREST_MODEL_%s_H" % sanitized
39
40
def boost_node(n, label, next_label):
    """Returns the labelled C++ statement for a leaf ("boost") node.

    The statement returns the node's `score` as a float literal.
    `next_label` is accepted only so that all node emitters share one
    signature; it is not used here.
    """
    score = n['score']
    return "{}: return {}f;".format(label, score)
44
45
def if_greater_node(n, label, next_label):
    """Returns the labelled C++ statement for an if_greater node.

    The statement jumps to `next_label` when the Example's NUMBER feature is
    greater than or equal to the node's threshold; otherwise control falls
    through to the statement that follows.

    The comparison is done on integers: the threshold is order-encoded here
    and compared against the value returned by the feature getter. The raw
    threshold is kept in a C comment next to its encoded form for
    readability.
    """
    raw_threshold = n["threshold"]
    feature = n['feature']
    encoded_threshold = order_encode(raw_threshold)
    condition = "E.get%s() >= %s /*%s*/" % (
        feature, encoded_threshold, raw_threshold)
    return "%s: if (%s) goto %s;" % (label, condition, next_label)
55
56
def if_member_node(n, label, next_label):
    """Returns the labelled C++ statement for an if_member node.

    The statement ANDs the Example's ENUM feature with a mask built by
    OR-ing `BIT(<feature>_type::<member>)` for every member in the node's
    `set`. It jumps to `next_label` when the result is non-zero; otherwise
    control falls through to the statement that follows.
    """
    feature = n['feature']
    mask = '|'.join(
        "BIT({}_type::{})".format(feature, member) for member in n["set"])
    return "{}: if (E.get{}() & ({})) goto {};".format(
        label, feature, mask, next_label)
68
69
def node(n, label, next_label):
    """Returns the code snippet for node `n`, chosen by its `operation`.

    Raises:
        KeyError: if `n['operation']` is not 'boost', 'if_greater' or
            'if_member'.
    """
    emitters = {
        'boost': boost_node,
        'if_greater': if_greater_node,
        'if_member': if_member_node,
    }
    emit = emitters[n['operation']]
    return emit(n, label, next_label)
77
78
def tree(t, tree_num, node_num):
    """Returns the statements for evaluating a decision (sub)tree and its size.

    Every node is emitted as one statement labelled `t{tree#}_n{node#}`.
    Nodes are numbered in pre-order, visiting the `else` (false) child
    before the `then` (true) child:

    -   A conditional node jumps to the label of its `then` child when its
        condition holds. Its `else` child is emitted immediately after it,
        so a false condition simply falls through.
    -   A leaf ("boost") node is emitted through `node` with the label of
        the following tree, `t{tree#+1}`, as its next label.

    Args:
        t: JSON dict for the root of the (sub)tree.
        tree_num: index of the tree within the forest.
        node_num: number assigned to the root of this (sub)tree.

    Returns:
        A `(statements, size)` tuple: the list of generated statements and
        the number of nodes in the (sub)tree.
    """
    this_label = "t%d_n%d" % (tree_num, node_num)

    if t["operation"] == "boost":
        leaf = node(t, label=this_label, next_label="t%d" % (tree_num + 1))
        return [leaf], 1

    # The false branch takes the node numbers right after this node.
    else_code, else_size = tree(
        t['else'], tree_num=tree_num, node_num=node_num + 1)

    # The true branch starts once the whole false branch has been numbered.
    then_start = node_num + else_size + 1
    then_code, then_size = tree(
        t['then'], tree_num=tree_num, node_num=then_start)

    branch = node(t, label=this_label,
                  next_label="t%d_n%d" % (tree_num, then_start))
    return [branch] + else_code + then_code, 1 + else_size + then_size
112
113
def gen_header_code(features_json, cpp_class, filename):
    """Returns code for the header declaring the inference runtime.

    Declares the Example class named `cpp_class.name` inside the namespaces
    of `cpp_class`. Every feature becomes a `uint32_t` member with a setter
    and a getter, so the class can represent a code completion candidate.
    Also declares `float Evaluate(const Example&)`, which scores an Example.

    Args:
        features_json: list of feature dicts, each with a "name" and a
            "kind" of either "NUMBER" or "ENUM".
        cpp_class: CppClass giving the class name and enclosing namespaces.
        filename: base name of the generated files, used for the guard.

    Raises:
        ValueError: if a feature has a kind other than NUMBER or ENUM.
    """
    setters = []
    for f in features_json:
        feature = f["name"]
        kind = f["kind"]

        if kind == "NUMBER":
            # Floats are order-encoded to integers for faster comparison.
            setters.append(
                "void set%s(float V) { %s = OrderEncode(V); }" % (
                    feature, feature))
        elif kind == "ENUM":
            # Enums are stored as a one-bit mask. The shift is done on an
            # unsigned 32-bit one: a plain `1 << V` is a signed shift and
            # is undefined behavior in C++ when V == 31.
            setters.append(
                "void set%s(unsigned V) { %s = uint32_t{1} << V; }" % (
                    feature, feature))
        else:
            raise ValueError("Unhandled feature type.", kind)

    # Class members represent all the features of the Example.
    class_members = [
        "uint32_t %s = 0;" % f['name']
        for f in features_json
    ]
    getters = [
        "LLVM_ATTRIBUTE_ALWAYS_INLINE uint32_t get%s() const { return %s; }"
        % (f['name'], f['name'])
        for f in features_json
    ]
    # Separator that keeps each generated line indented inside the class.
    nline = "\n  "
    guard = header_guard(filename)
    return """#ifndef %s
#define %s
#include <cstdint>
#include "llvm/Support/Compiler.h"

%s
class %s {
public:
  // Setters.
  %s

  // Getters.
  %s

private:
  %s

  // Produces an integer that sorts in the same order as F.
  // That is: a < b <==> orderEncode(a) < orderEncode(b).
  static uint32_t OrderEncode(float F);
};

float Evaluate(const %s&);
%s
#endif // %s
""" % (guard, guard, cpp_class.ns_begin(), cpp_class.name,
        nline.join(setters),
        nline.join(getters),
        nline.join(class_members),
        cpp_class.name, cpp_class.ns_end(), guard)
180
181
def order_encode(v):
    """Returns an unsigned 32-bit integer that sorts in the same order as `v`.

    `v` is packed as a 32-bit IEEE 754 float and its bit pattern is
    remapped: negative floats land in the low half of the integer range
    with their order reversed, and non-negative floats land in the high
    half. This matches the `OrderEncode` function in the generated C++.
    """
    sign_bit = 0x80000000
    (bits,) = struct.unpack('<I', struct.pack('<f', v))
    # IEEE 754 floats compare like sign-magnitude integers.
    if bits & sign_bit:
        # Negative float: low half of integers, order reversed.
        return 0x100000000 - bits
    # Non-negative float: top half of integers.
    return bits + sign_bit
189
190
def evaluate_func(forest_json, cpp_class):
    """Returns C++ code that scores an Example against the whole forest.

    One `EvaluateTree{N}` function is emitted per tree inside an anonymous
    namespace, followed by `float Evaluate(const {Example}&)`, which sums
    the results of all the per-tree functions.
    """
    class_name = cpp_class.name
    pieces = ["namespace {\n"]

    # One evaluation function per tree. They are marked NOINLINE because
    # MSAN will timeout if these functions are inlined.
    for index, tree_json in enumerate(forest_json):
        statements, _ = tree(tree_json, tree_num=index, node_num=0)
        pieces.append(
            "LLVM_ATTRIBUTE_NOINLINE float EvaluateTree%d(const %s& E) {\n"
            % (index, class_name))
        pieces.append("  " + "\n  ".join(statements) + "\n")
        pieces.append("}\n\n")
    pieces.append("} // namespace\n\n")

    # Combine the scores of all trees in the final function.
    pieces.append("float Evaluate(const %s& E) {\n" % class_name)
    pieces.append("  float Score = 0;\n")
    pieces.extend(
        "  Score += EvaluateTree%d(E);\n" % index
        for index in range(len(forest_json)))
    pieces.append("  return Score;\n")
    pieces.append("}\n")

    return "".join(pieces)
220
221
def gen_cpp_code(forest_json, features_json, filename, cpp_class):
    """Generates code for the .cpp file.

    The file contains the includes, the `BIT` helper macro, using-decls for
    the ENUM feature types, the definition of `{cpp_class}::OrderEncode`,
    and the evaluation functions for the forest.

    Args:
        forest_json: list of decision trees (JSON dicts).
        features_json: list of feature dicts; ENUM features must also carry
            "header" and "type" entries.
        filename: base name of the generated files, used to include the
            generated header.
        cpp_class: CppClass giving the class name and enclosing namespaces.
    """
    # Standard headers.
    # <limits> is required by OrderEncode(float F).
    angled_include = [
        '#include <%s>' % h
        for h in ["limits"]
    ]

    # Include the generated header and llvm::bit_cast.
    quoted_headers = {filename + '.h', 'llvm/ADT/bit.h'}
    # Headers required by ENUM features used by the model.
    quoted_headers |= {f["header"]
                       for f in features_json if f["kind"] == "ENUM"}
    quoted_include = ['#include "%s"' % h for h in sorted(quoted_headers)]

    # using-decl for ENUM features.
    using_decls = "\n".join("using %s_type = %s;" % (
        feature['name'], feature['type'])
        for feature in features_json
        if feature["kind"] == "ENUM")
    nl = "\n"
    # Two fixes in the generated C++ below:
    # - BIT shifts an unsigned 32-bit one and parenthesizes its argument;
    #   a plain `1 << X` is a signed shift, undefined when X == 31.
    # - OrderEncode reads the float bits with llvm::bit_cast alone; the
    #   extra std::memcpy that followed it only repeated the same copy.
    return """%s

%s

#define BIT(X) (uint32_t{1} << (X))

%s

%s

uint32_t %s::OrderEncode(float F) {
  static_assert(std::numeric_limits<float>::is_iec559, "");
  constexpr uint32_t TopBit = ~(~uint32_t{0} >> 1);

  // Get the bits of the float. Endianness is the same as for integers.
  uint32_t U = llvm::bit_cast<uint32_t>(F);
  // IEEE 754 floats compare like sign-magnitude integers.
  if (U & TopBit)    // Negative float.
    return 0 - U;    // Map onto the low half of integers, order reversed.
  return U + TopBit; // Positive floats map onto the high half of integers.
}

%s
%s
""" % (nl.join(angled_include), nl.join(quoted_include), cpp_class.ns_begin(),
       using_decls, cpp_class.name, evaluate_func(forest_json, cpp_class),
       cpp_class.ns_end())
272
273
def main():
    """Command-line entry point: reads the model and writes the C++ files.

    Reads `{model}/features.json` and `{model}/forest.json`, then writes
    `{output_dir}/{filename}.cpp` and `{output_dir}/{filename}.h`.

    All four options are required; argparse reports a usage error when one
    is missing, rather than the run failing later on a `None` value. Both
    outputs are generated in memory before any output file is opened, so a
    failure during code generation does not leave a truncated file behind.
    """
    parser = argparse.ArgumentParser('DecisionForestCodegen')
    parser.add_argument('--filename', required=True,
                        help='output file name.')
    parser.add_argument('--output_dir', required=True,
                        help='output directory.')
    parser.add_argument('--model', required=True,
                        help='path to model directory.')
    parser.add_argument(
        '--cpp_class',
        required=True,
        help='The name of the class (which may be a namespace-qualified) created in generated header.'
    )
    ns = parser.parse_args()

    output_dir = ns.output_dir
    filename = ns.filename
    header_file = "%s/%s.h" % (output_dir, filename)
    cpp_file = "%s/%s.cpp" % (output_dir, filename)
    cpp_class = CppClass(cpp_class=ns.cpp_class)

    model_file = "%s/forest.json" % ns.model
    features_file = "%s/features.json" % ns.model

    with open(features_file) as f:
        features_json = json.load(f)

    with open(model_file) as m:
        forest_json = json.load(m)

    # Generate both outputs first; only then open (and truncate) the files.
    cpp_code = gen_cpp_code(forest_json=forest_json,
                            features_json=features_json,
                            filename=filename,
                            cpp_class=cpp_class)
    header_code = gen_header_code(features_json=features_json,
                                  cpp_class=cpp_class,
                                  filename=filename)

    with open(cpp_file, 'w') as output_cc:
        output_cc.write(cpp_code)

    with open(header_file, 'w') as output_h:
        output_h.write(header_code)
312
313
# Run the generator only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()
316