1#!/usr/bin/python3
2#
3# Copyright (C) 2022 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""
Generate the Java benchmark sources for the 2238-varhandle-perf test.

Usage: generate_java.py <existing output directory>
19"""
20# TODO: fix constants when converting the test to a Golem benchmark
21
22
23from enum import Enum
24from pathlib import Path
25
26import io
27import sys
28
29
class MemLoc(Enum):
    """Kind of memory location a generated benchmark accesses.

    The member name, passed through to_camel_case(), becomes part of the
    generated Java class name in Benchmark.fullname().
    """

    # Unpacking binds the members in order: FIELD = 0, ARRAY = 1,
    # BYTE_ARRAY_VIEW = 2.
    FIELD, ARRAY, BYTE_ARRAY_VIEW = range(3)
34
35
def to_camel_case(word):
    """Return *word* title-cased with every underscore removed.

    str.title() upper-cases the first letter of each alphabetic run and
    lower-cases the rest, so "LITTLE_ENDIAN" becomes "LittleEndian", while an
    already camel-cased name such as "compareAndSet" becomes "Compareandset".
    """
    titled = word.title()
    return titled.replace('_', '')
38
39
class Benchmark:
    """Description of one generated Java benchmark class.

    Bundles a Java source template with the parameters that select the access
    path being measured. fullname() derives the Java class name and gencode()
    renders the template into Java source.
    """

    def __init__(self, code, static, vartype, flavour, klass, method, memloc,
        byteorder="LITTLE_ENDIAN"):
        # Java source template, expanded with str.format() by gencode().
        self.code = code
        # True when the benchmarked field is static.
        self.static = static
        # Java type of the accessed value ("int", "float", "String", ...).
        self.vartype = vartype
        # Suffix appended to the accessor method name (e.g. "Acquire").
        self.flavour = flavour
        # Access mechanism: "VarHandle", "Reflect" or "Unsafe".
        self.klass = klass
        # Accessor method name before any suffix (e.g. "get", "compareAndSet").
        self.method = method
        # java.nio.ByteOrder constant name; used by byte-array-view templates.
        self.byteorder = byteorder
        # MemLoc member naming the kind of memory location accessed.
        self.memloc = memloc

    def fullname(self):
        """Return the Java class name of this benchmark.

        The name concatenates klass, method, flavour, "Static" (static fields
        only), the memory location, the byte order and the value type, then
        "Benchmark". The method, memory location, byte order and value type
        are passed through to_camel_case() first.
        """
        pieces = [
            self.klass,
            to_camel_case(self.method),
            self.flavour,
            "Static" if self.static else "",
            to_camel_case(self.memloc.name),
            to_camel_case(self.byteorder),
            to_camel_case(self.vartype),
            "Benchmark",
        ]
        return "".join(pieces)

    def gencode(self):
        """Return the Java source of this benchmark.

        Expands self.code with str.format(). The keyword arguments available
        to a template are: name, method (including any type suffix), flavour,
        static_name, static_kwd, this, this_comma, vartype, byteorder,
        value1, value2, value1_byte_array, value2_byte_array, loop (a Java
        "for" header running 100 passes) and iters.
        """
        # klass -> (method suffix used for String, receiver used when the
        # field is static). For primitives, Reflect and Unsafe append the
        # title-cased Java type (e.g. getInt). VarHandle methods carry no
        # suffix and take no receiver for static fields.
        typed_accessors = {
            "Reflect": ("", "null"),
            "Unsafe": ("Object", "this.getClass()"),
        }
        if self.klass in typed_accessors:
            string_suffix, static_receiver = typed_accessors[self.klass]
            if self.vartype == "String":
                method_suffix = string_suffix
            else:
                method_suffix = self.vartype.title()
        else:
            method_suffix = ""
            static_receiver = ""

        receiver = static_receiver if self.static else "this"
        # Receiver followed by ", " so templates can prepend it to other
        # arguments; empty when there is no receiver at all.
        receiver_comma = receiver + ", " if receiver else ""

        return self.code.format(
            name=self.fullname(),
            method=self.method + method_suffix,
            flavour=self.flavour,
            static_name="Static" if self.static else "",
            static_kwd="static " if self.static else "",
            this=receiver,
            this_comma=receiver_comma,
            vartype=self.vartype,
            byteorder=self.byteorder,
            value1=VALUES[self.vartype][0],
            value2=VALUES[self.vartype][1],
            value1_byte_array=VALUES["byte[]"][self.byteorder][0],
            value2_byte_array=VALUES["byte[]"][self.byteorder][1],
            loop="for (int pass = 0; pass < 100; ++pass)",
            iters=ITERATIONS)
91
92
def BenchVHField(code, static, vartype, flavour, method):
    """Return a VarHandle Benchmark on a field (MemLoc.FIELD), static or not."""
    return Benchmark(code=code, static=static, vartype=vartype,
                     flavour=flavour, klass="VarHandle", method=method,
                     memloc=MemLoc.FIELD)
95
96
def BenchVHArray(code, vartype, flavour, method):
    """Return a non-static VarHandle Benchmark on an array element (MemLoc.ARRAY)."""
    return Benchmark(code=code, static=False, vartype=vartype,
                     flavour=flavour, klass="VarHandle", method=method,
                     memloc=MemLoc.ARRAY)
99
100
def BenchVHByteArrayView(code, byteorder, vartype, flavour, method):
    """Return a non-static VarHandle Benchmark on a byte-array view.

    byteorder is the java.nio.ByteOrder constant name the view is created
    with (e.g. "BIG_ENDIAN").
    """
    return Benchmark(code=code, static=False, vartype=vartype,
                     flavour=flavour, klass="VarHandle", method=method,
                     memloc=MemLoc.BYTE_ARRAY_VIEW, byteorder=byteorder)
103
104
def BenchReflect(code, static, vartype, method):
    """Return a java.lang.reflect Benchmark on a field; it has no flavour."""
    return Benchmark(code=code, static=static, vartype=vartype, flavour="",
                     klass="Reflect", method=method, memloc=MemLoc.FIELD)
107
108
def BenchUnsafe(code, static, vartype, method):
    """Return an Unsafe Benchmark on a field; it has no flavour."""
    return Benchmark(code=code, static=static, vartype=vartype, flavour="",
                     klass="Unsafe", method=method, memloc=MemLoc.FIELD)
111
112
# Java source literals substituted into the templates by Benchmark.gencode().
#
# Scalar entries map a Java type to [value1, value2]: value1 is the initial
# value wherever a template declares one, and value2 is a second, different
# value. Which of the two a template stores or compares against depends on
# the template.
VALUES = {
    "int": ["42", "~42"],
    "float": ["3.14f", "2.17f"],
    "String": ["\"qwerty\"", "null"],
    # Java array initialisers for the byte-array-view templates, keyed by
    # java.nio.ByteOrder constant name. They refer to the constant VALUE
    # declared by VH_START_BYTE_ARRAY_VIEW. The first entry is VALUE encoded
    # in that byte order. The second keeps VALUE's low-order byte but puts
    # (byte) (-1 >> n), i.e. 0xFF, in the other three bytes, so (for
    # VALUE = 42) it only equals the first once VALUE is written into it.
    "byte[]": {
        "LITTLE_ENDIAN": [
            "{ (byte) VALUE, (byte) (VALUE >> 8), (byte) (VALUE >> 16), (byte) (VALUE >> 24) }",
            "{ (byte) VALUE, (byte) (-1 >> 8), (byte) (-1 >> 16), (byte) (-1 >> 24) }",
        ],
        "BIG_ENDIAN": [
            "{ (byte) (VALUE >> 24), (byte) (VALUE >> 16), (byte) (VALUE >> 8), (byte) VALUE }",
            "{ (byte) (-1 >> 24), (byte) (-1 >> 16), (byte) (-1 >> 8), (byte) VALUE }",
        ],
    },
}
128
129
# TODO: fix these numbers when converting the test to a Golem benchmark
# Value returned by innerIterations() in every generated benchmark.
ITERATIONS = 1  # 3000 for a real benchmark
# Number of copies of the measured statement emitted inside each loop pass.
REPEAT = 2  # 30 for a real benchmark
# The compare-and-set/exchange templates emit a pair of statements (swap to
# value2, then back to value1) per repetition, so they repeat REPEAT_HALF
# times. Floor division keeps this an int; an odd REPEAT makes those
# templates emit one statement fewer than REPEAT.
REPEAT_HALF = REPEAT // 2
134
135
# First line of every generated Java file.
BANNER = '// This file is generated by util-src/generate_java.py do not directly modify!'


# Java imports shared by all VarHandle benchmark templates.
VH_IMPORTS = """
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
"""


# The templates below are str.format() patterns filled in by
# Benchmark.gencode(): "{{" and "}}" become literal Java braces, and each
# {placeholder} is one of the keyword arguments gencode() passes.
#
# Common prologue for VarHandle benchmarks on a (possibly static) field named
# "field". The field is initialised to value1 (FIELD_VALUE) and accessed
# through a VarHandle looked up in the constructor.
VH_START = BANNER + VH_IMPORTS + """
class {name} extends MicroBenchmark {{
  static final {vartype} FIELD_VALUE = {value1};
  {static_kwd}{vartype} field = FIELD_VALUE;
  VarHandle vh;

  {name}() throws Throwable {{
    vh = MethodHandles.lookup().find{static_name}VarHandle(this.getClass(), "field", {vartype}.class);
  }}
"""


# Common epilogue appended to every benchmark template. It closes the
# "{loop} {{" block and the run() method the template opened, then emits
# innerIterations() returning the "iters" placeholder (ITERATIONS).
END = """
    }}
  }}

  @Override
  public int innerIterations() {{
      return {iters};
  }}
}}"""
166
167
# VarHandle field templates. Each appends setup()/teardown() checks and/or a
# run() method to VH_START. run() repeats the measured statement REPEAT times
# (or a statement pair REPEAT_HALF times) inside "{loop}".

# get*: setup() checks the field holds FIELD_VALUE; run() reads it.
VH_GET = VH_START + """
  @Override
  public void setup() {{
    {vartype} v = ({vartype}) vh.{method}{flavour}({this});
    if (v != FIELD_VALUE) {{
      throw new RuntimeException("field has unexpected value " + v);
    }}
  }}

  @Override
  public void run() {{
    {vartype} x;
    {loop} {{""" + """
      x = ({vartype}) vh.{method}{flavour}({this});""" * REPEAT + END


# set*: run() stores FIELD_VALUE; teardown() checks the field still holds it.
VH_SET = VH_START + """
  @Override
  public void teardown() {{
    if (field != FIELD_VALUE) {{
      throw new RuntimeException("field has unexpected value " + field);
    }}
  }}

  @Override
  public void run() {{
    {vartype} x;
    {loop} {{""" + """
      vh.{method}{flavour}({this_comma}FIELD_VALUE);""" * REPEAT + END


# compareAndSet / weakCompareAndSet*: run() swaps the field to value2 and
# back to value1, passing the field's current value as the expected value.
VH_CAS = VH_START + """
  @Override
  public void run() {{
    boolean success;
    {loop} {{""" + """
      success = vh.{method}{flavour}({this_comma}field, {value2});
      success = vh.{method}{flavour}({this_comma}field, {value1});""" * REPEAT_HALF + END


# compareAndExchange*: the same value2/value1 round trip as VH_CAS, assigning
# the returned value to x.
VH_CAE = VH_START + """
  @Override
  public void run() {{
    {vartype} x;
    {loop} {{""" + """
      x = ({vartype}) vh.{method}{flavour}({this_comma}field, {value2});
      x = ({vartype}) vh.{method}{flavour}({this_comma}field, {value1});""" * REPEAT_HALF + END


# getAndSet*: run() repeatedly calls the method with value2.
VH_GAS = VH_START + """
  @Override
  public void run() {{
    {vartype} x;
    {loop} {{""" + """
      x = ({vartype}) vh.{method}{flavour}({this_comma}{value2});""" * REPEAT + END


# getAndAdd*: textually identical to VH_GAS; only the {method} placeholder
# supplied by the benchmark differs.
VH_GAA = VH_START + """
  @Override
  public void run() {{
    {vartype} x;
    {loop} {{""" + """
      x = ({vartype}) vh.{method}{flavour}({this_comma}{value2});""" * REPEAT + END
231
232
# getAndBitwise{Or,Xor,And}*: run() repeatedly calls the method with value2.
# The local x is declared as {vartype}, matching the cast and the other
# VarHandle templates, so the template also compiles for non-int types.
VH_GAB = VH_START + """
  @Override
  public void run() {{
    {vartype} x;
    {loop} {{""" + """
      x = ({vartype}) vh.{method}{flavour}({this_comma}{value2});""" * REPEAT + END
239
240
# Prologue for VarHandle benchmarks on element 0 of a one-element array whose
# element starts as value1 (ELEMENT_VALUE). The VarHandle comes from
# MethodHandles.arrayElementVarHandle() in the constructor.
VH_START_ARRAY = BANNER + VH_IMPORTS + """
class {name} extends MicroBenchmark {{
  static final {vartype} ELEMENT_VALUE = {value1};
  {vartype}[] array = {{ ELEMENT_VALUE }};
  VarHandle vh;

  {name}() throws Throwable {{
    vh = MethodHandles.arrayElementVarHandle({vartype}[].class);
  }}
"""


# Array-element get: setup() checks element 0 equals ELEMENT_VALUE; run()
# reads element 0.
VH_GET_A = VH_START_ARRAY + """
  @Override
  public void setup() {{
    {vartype} v = ({vartype}) vh.{method}{flavour}(array, 0);
    if (v != ELEMENT_VALUE) {{
      throw new RuntimeException("array element has unexpected value: " + v);
    }}
  }}

  @Override
  public void run() {{
    {vartype}[] a = array;
    {vartype} x;
    {loop} {{""" + """
      x = ({vartype}) vh.{method}{flavour}(a, 0);""" * REPEAT + END


# Array-element set: run() stores value2 into element 0; teardown() checks
# element 0 equals value2.
VH_SET_A = VH_START_ARRAY + """
  @Override
  public void teardown() {{
    if (array[0] != {value2}) {{
      throw new RuntimeException("array element has unexpected value: " + array[0]);
    }}
  }}

  @Override
  public void run() {{
    {vartype}[] a = array;
    {vartype} x;
    {loop} {{""" + """
      vh.{method}{flavour}(a, 0, {value2});""" * REPEAT + END
284
285
# Prologue for byte-array-view VarHandle benchmarks. array1 and array2 are
# initialised from the "byte[]" entries of VALUES for {byteorder}, and the
# VarHandle comes from MethodHandles.byteArrayViewVarHandle() with that
# java.nio.ByteOrder.
VH_START_BYTE_ARRAY_VIEW = BANNER + VH_IMPORTS + """
import java.util.Arrays;
import java.nio.ByteOrder;

class {name} extends MicroBenchmark {{
  static final {vartype} VALUE = {value1};
  byte[] array1 = {value1_byte_array};
  byte[] array2 = {value2_byte_array};
  VarHandle vh;

  {name}() throws Throwable {{
    vh = MethodHandles.byteArrayViewVarHandle({vartype}[].class, ByteOrder.{byteorder});
  }}
"""


# Byte-array-view get: setup() checks array1 decodes to VALUE at offset 0;
# run() reads from array1.
VH_GET_BAV = VH_START_BYTE_ARRAY_VIEW + """
  @Override
  public void setup() {{
    {vartype} v = ({vartype}) vh.{method}{flavour}(array1, 0);
    if (v != VALUE) {{
      throw new RuntimeException("array has unexpected value: " + v);
    }}
  }}

  @Override
  public void run() {{
    byte[] a = array1;
    {vartype} x;
    {loop} {{""" + """
      x = ({vartype}) vh.{method}{flavour}(a, 0);""" * REPEAT + END


# Byte-array-view set: run() writes VALUE at offset 0 of array2; teardown()
# checks array2 then equals array1 and prints its four bytes otherwise.
VH_SET_BAV = VH_START_BYTE_ARRAY_VIEW + """
  @Override
  public void teardown() {{
    if (!Arrays.equals(array2, array1)) {{
      throw new RuntimeException("array has unexpected values: " +
          array2[0] + " " + array2[1] + " " + array2[2] + " " + array2[3]);
    }}
  }}

  @Override
  public void run() {{
    byte[] a = array2;
    {loop} {{""" + """
      vh.{method}{flavour}(a, 0, VALUE);""" * REPEAT + END
333
334
# Prologue for java.lang.reflect.Field benchmarks on a (possibly static)
# field named "value" (no initialiser). The Field is looked up in the
# constructor.
REFLECT_START = BANNER + """
import java.lang.reflect.Field;

class {name} extends MicroBenchmark {{
  Field field;
  {static_kwd}{vartype} value;

  {name}() throws Throwable {{
    field = this.getClass().getDeclaredField("value");
  }}
"""


# Reflective read: run() calls field.{method}, where gencode() appends the
# title-cased type for primitives (e.g. getInt) and nothing for String.
REFLECT_GET = REFLECT_START + """
  @Override
  public void run() throws Throwable {{
    {vartype} x;
    {loop} {{""" + """
      x = ({vartype}) field.{method}({this});""" * REPEAT + END


# Reflective write: run() stores value1 through field.{method}.
REFLECT_SET = REFLECT_START + """
  @Override
  public void run() throws Throwable {{
    {loop} {{""" + """
      field.{method}({this_comma}{value1});""" * REPEAT + END
361
362
# Prologue for Unsafe benchmarks on a (possibly static) field named "value",
# initialised to value1. The field offset is obtained in the constructor.
# NOTE(review): get{static_name}FieldOffset() and theUnsafe are not defined
# here; presumably they come from UnsafeMicroBenchmark — confirm there.
UNSAFE_START = BANNER + """
import java.lang.reflect.Field;
import jdk.internal.misc.Unsafe;

class {name} extends UnsafeMicroBenchmark {{
  long offset;
  {static_kwd}{vartype} value = {value1};

  {name}() throws Throwable {{
    Field field = this.getClass().getDeclaredField("value");
    offset = get{static_name}FieldOffset(field);
  }}
"""


# Unsafe read: run() reads the field at `offset` through theUnsafe.
UNSAFE_GET = UNSAFE_START + """
  @Override
  public void run() throws Throwable {{
    {vartype} x;
    {loop} {{""" + """
      x = ({vartype}) theUnsafe.{method}({this_comma}offset);""" * REPEAT + END


# Unsafe write: run() stores value1 at `offset`.
UNSAFE_PUT = UNSAFE_START + """
  @Override
  public void run() throws Throwable {{
    {loop} {{""" + """
      theUnsafe.{method}({this_comma}offset, {value1});""" * REPEAT + END


# Unsafe compare-and-swap/set: run() swaps value1 -> value2 -> value1, one
# pair of calls per REPEAT_HALF repetition.
UNSAFE_CAS = UNSAFE_START + """
  @Override
  public void run() throws Throwable {{
    {loop} {{""" + """
      theUnsafe.{method}({this_comma}offset, {value1}, {value2});
      theUnsafe.{method}({this_comma}offset, {value2}, {value1});""" * REPEAT_HALF + END
399
400
# Every benchmark to generate, in the order Main.java instantiates and runs
# them. Each comprehension takes the cross product of its "for" clauses, with
# the first clause varying slowest. fullname() must be unique across the
# list because it names both the Java class and its .java file.
ALL_BENCHMARKS = (
    # VarHandle access to a static or instance field; flavour selects the
    # access mode appended to the method name.
    [BenchVHField(VH_GET, static, vartype, flavour, "get")
        for flavour in ["", "Acquire", "Opaque", "Volatile"]
        for static in [True, False]
        for vartype in ["int", "String"]] +
    [BenchVHField(VH_SET, static, vartype, flavour, "set")
        for flavour in ["", "Volatile", "Opaque", "Release"]
        for static in [True, False]
        for vartype in ["int", "String"]] +
    [BenchVHField(VH_CAS, static, vartype, flavour, "compareAndSet")
        for flavour in [""]
        for static in [True, False]
        for vartype in ["int", "String"]] +
    [BenchVHField(VH_CAS, static, vartype, flavour, "weakCompareAndSet")
        for flavour in ["", "Plain", "Acquire", "Release"]
        for static in [True, False]
        for vartype in ["int", "String"]] +
    [BenchVHField(VH_CAE, static, vartype, flavour, "compareAndExchange")
        for flavour in ["", "Acquire", "Release"]
        for static in [True, False]
        for vartype in ["int", "String"]] +
    [BenchVHField(VH_GAS, static, vartype, flavour, "getAndSet")
        for flavour in ["", "Acquire", "Release"]
        for static in [True, False]
        for vartype in ["int", "String"]] +
    [BenchVHField(VH_GAA, static, vartype, flavour, "getAndAdd")
        for flavour in ["", "Acquire", "Release"]
        for static in [True, False]
        for vartype in ["int", "float"]] +
    # The flavour combines the bitwise operator and the access mode, e.g.
    # "OrAcquire" yields getAndBitwiseOrAcquire.
    [BenchVHField(VH_GAB, static, vartype, flavour, "getAndBitwise")
        for flavour in [oper + mode
            for oper in ["Or", "Xor", "And"]
            for mode in ["", "Acquire", "Release"]]
        for static in [True, False]
        for vartype in ["int"]] +
    # VarHandle access to an array element.
    [BenchVHArray(VH_GET_A, vartype, flavour, "get")
        for flavour in [""]
        for vartype in ["int", "String"]] +
    [BenchVHArray(VH_SET_A, vartype, flavour, "set")
        for flavour in [""]
        for vartype in ["int", "String"]] +
    # VarHandle byte-array views in both byte orders.
    [BenchVHByteArrayView(VH_GET_BAV, byteorder, vartype, flavour, "get")
        for flavour in [""]
        for byteorder in ["BIG_ENDIAN", "LITTLE_ENDIAN"]
        for vartype in ["int"]] +
    [BenchVHByteArrayView(VH_SET_BAV, byteorder, vartype, flavour, "set")
        for flavour in [""]
        for byteorder in ["BIG_ENDIAN", "LITTLE_ENDIAN"]
        for vartype in ["int"]] +
    # java.lang.reflect.Field access, for comparison.
    [BenchReflect(REFLECT_GET, static, vartype, "get")
        for static in [True, False]
        for vartype in ["int", "String"]] +
    [BenchReflect(REFLECT_SET, static, vartype, "set")
        for static in [True, False]
        for vartype in ["int", "String"]] +
    # Unsafe access, for comparison.
    [BenchUnsafe(UNSAFE_GET, static, vartype, "get")
        for static in [True, False]
        for vartype in ["int", "String"]] +
    [BenchUnsafe(UNSAFE_PUT, static, vartype, "put")
        for static in [True, False]
        for vartype in ["int", "String"]] +
    [BenchUnsafe(UNSAFE_CAS, static, vartype, method)
        for method in ["compareAndSwap", "compareAndSet"]
        for static in [True, False]
        for vartype in ["int", "String"]])
466
467
# Source of Main.java: instantiates every benchmark in ALL_BENCHMARKS, in
# list order, and calls report() on each one. This text is not passed
# through str.format(), so its Java braces are written singly.
MAIN = (BANNER + """
public class Main {
  static MicroBenchmark[] benchmarks;

  private static void initialize() throws Throwable {
    benchmarks = new MicroBenchmark[] {"""
        + "".join("\n      new " + bench.fullname() + "(),"
                  for bench in ALL_BENCHMARKS)
        + """
    };
  }

  public static void main(String[] args) throws Throwable {
    initialize();
    for (MicroBenchmark benchmark : benchmarks) {
      benchmark.report();
    }
  }
}""")
485
486
def main(argv):
    """Write every generated benchmark and Main.java into a directory.

    Args:
        argv: command-line vector; argv[1] must name an existing directory,
            into which one "<fullname>.java" per benchmark plus "Main.java"
            are written (overwriting existing files).

    Prints a message to stderr and exits with status 1 when the directory
    argument is missing or is not an existing directory.
    """
    if len(argv) < 2:
        # Report a usage error instead of failing with a bare IndexError.
        prog = argv[0] if argv else "generate_java.py"
        print("usage: {} <java output dir>".format(prog), file=sys.stderr)
        sys.exit(1)

    final_java_dir = Path(argv[1])
    # is_dir() is False for a path that does not exist, so it also covers the
    # existence check.
    if not final_java_dir.is_dir():
        print("{} is not a valid java dir".format(final_java_dir), file=sys.stderr)
        sys.exit(1)

    for bench in ALL_BENCHMARKS:
        file_path = final_java_dir / "{}.java".format(bench.fullname())
        with file_path.open("w", encoding="utf-8") as f:
            print(bench.gencode(), file=f)

    file_path = final_java_dir / "Main.java"
    with file_path.open("w", encoding="utf-8") as f:
        print(MAIN, file=f)


if __name__ == '__main__':
    main(sys.argv)
505