1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Utilities to run benchmarks."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import math
22import numbers
23import os
24import re
25import sys
26import time
27import types
28
29import six
30
31from tensorflow.core.protobuf import config_pb2
32from tensorflow.core.protobuf import rewriter_config_pb2
33from tensorflow.core.util import test_log_pb2
34from tensorflow.python.client import timeline
35from tensorflow.python.framework import ops
36from tensorflow.python.platform import app
37from tensorflow.python.platform import gfile
38from tensorflow.python.platform import tf_logging as logging
39from tensorflow.python.util import tf_inspect
40from tensorflow.python.util.tf_export import tf_export
41
42
# When a subclass of the Benchmark class is created, it is added to
# the registry automatically by the _BenchmarkRegistrar metaclass (abstract
# classes are skipped). _run_benchmarks iterates over this set of classes.
GLOBAL_BENCHMARK_REGISTRY = set()

# Environment variable that determines whether benchmarks are written.
# See also tensorflow/core/util/reporter.h TestReporter::kTestReporterEnv.
# When set, its value is used as a path prefix for the serialized
# BenchmarkEntries output file of each benchmark; when unset, the entries are
# printed to stdout instead.
TEST_REPORTER_TEST_ENV = "TEST_REPORT_FILE_PREFIX"

# Environment variable that lets the TensorFlow runtime allocate a new
# threadpool for each benchmark. TensorFlowBenchmark.__init__ sets it to "1".
OVERRIDE_GLOBAL_THREADPOOL = "TF_OVERRIDE_GLOBAL_THREADPOOL"
54
55
56def _rename_function(f, arg_num, name):
57  """Rename the given function's name appears in the stack trace."""
58  func_code = six.get_function_code(f)
59  if six.PY2:
60    new_code = types.CodeType(arg_num, func_code.co_nlocals,
61                              func_code.co_stacksize, func_code.co_flags,
62                              func_code.co_code, func_code.co_consts,
63                              func_code.co_names, func_code.co_varnames,
64                              func_code.co_filename, name,
65                              func_code.co_firstlineno, func_code.co_lnotab,
66                              func_code.co_freevars, func_code.co_cellvars)
67  else:
68    if sys.version_info > (3, 8, 0, "alpha", 3):
69      # Python3.8 / PEP570 added co_posonlyargcount argument to CodeType.
70      new_code = types.CodeType(arg_num, func_code.co_posonlyargcount,
71                                0, func_code.co_nlocals,
72                                func_code.co_stacksize, func_code.co_flags,
73                                func_code.co_code, func_code.co_consts,
74                                func_code.co_names, func_code.co_varnames,
75                                func_code.co_filename, name,
76                                func_code.co_firstlineno, func_code.co_lnotab,
77                                func_code.co_freevars, func_code.co_cellvars)
78    else:
79      new_code = types.CodeType(arg_num, 0, func_code.co_nlocals,
80                                func_code.co_stacksize, func_code.co_flags,
81                                func_code.co_code, func_code.co_consts,
82                                func_code.co_names, func_code.co_varnames,
83                                func_code.co_filename, name,
84                                func_code.co_firstlineno, func_code.co_lnotab,
85                                func_code.co_freevars, func_code.co_cellvars)
86
87  return types.FunctionType(new_code, f.__globals__, name, f.__defaults__,
88                            f.__closure__)
89
90
def _global_report_benchmark(
    name, iters=None, cpu_time=None, wall_time=None,
    throughput=None, extras=None, metrics=None):
  """Method for recording a benchmark directly.

  Builds a `BenchmarkEntries` proto holding a single entry. If the
  `TEST_REPORT_FILE_PREFIX` environment variable is unset, the proto is
  printed to stdout. Otherwise it is serialized to the file
  `<prefix><name>`, where every "/" in `name` is replaced by "__".

  Args:
    name: The BenchmarkEntry name.
    iters: (optional) How many iterations were run
    cpu_time: (optional) Total cpu time in seconds
    wall_time: (optional) Total wall time in seconds
    throughput: (optional) Throughput (in MB/s)
    extras: (optional) Dict mapping string keys to additional benchmark info.
      Numeric values are stored as doubles; anything else is stored as its
      `str()`.
    metrics: (optional) A list of dict representing metrics generated by the
      benchmark. Each dict should contain keys 'name' and 'value'. A dict
      can optionally contain keys 'min_value' and 'max_value'.

  Raises:
    TypeError: if extras is not a dict, if metrics is not a list, or if a
      metric lacks a 'name' or 'value' key.
    IOError: if the benchmark output file already exists.
  """
  # Missing numeric fields are logged as -1 so the %d/%g placeholders always
  # receive a number.
  logging.info("Benchmark [%s] iters: %d, wall_time: %g, cpu_time: %g, "
               "throughput: %g, extras: %s, metrics: %s", name,
               iters if iters is not None else -1,
               wall_time if wall_time is not None else -1,
               cpu_time if cpu_time is not None else -1,
               throughput if throughput is not None else -1,
               str(extras) if extras else "None",
               str(metrics) if metrics else "None")

  entries = test_log_pb2.BenchmarkEntries()
  entry = entries.entry.add()
  entry.name = name
  if iters is not None:
    entry.iters = iters
  if cpu_time is not None:
    entry.cpu_time = cpu_time
  if wall_time is not None:
    entry.wall_time = wall_time
  if throughput is not None:
    entry.throughput = throughput
  if extras is not None:
    if not isinstance(extras, dict):
      raise TypeError("extras must be a dict")
    for (k, v) in extras.items():
      if isinstance(v, numbers.Number):
        entry.extras[k].double_value = v
      else:
        entry.extras[k].string_value = str(v)
  if metrics is not None:
    if not isinstance(metrics, list):
      raise TypeError("metrics must be a list")
    for metric in metrics:
      if "name" not in metric:
        raise TypeError("metric must has a 'name' field")
      if "value" not in metric:
        raise TypeError("metric must has a 'value' field")

      metric_entry = entry.metrics.add()
      metric_entry.name = metric["name"]
      metric_entry.value = metric["value"]
      if "min_value" in metric:
        metric_entry.min_value.value = metric["min_value"]
      if "max_value" in metric:
        metric_entry.max_value.value = metric["max_value"]

  test_env = os.environ.get(TEST_REPORTER_TEST_ENV, None)
  if test_env is None:
    # Reporting was not requested, just print the proto
    print(str(entries))
    return

  serialized_entry = entries.SerializeToString()

  # "/" would create subdirectories under the prefix; flatten it instead.
  mangled_name = name.replace("/", "__")
  output_path = "%s%s" % (test_env, mangled_name)
  if gfile.Exists(output_path):
    raise IOError("File already exists: %s" % output_path)
  with gfile.GFile(output_path, "wb") as out:
    out.write(serialized_entry)
170
171
172class _BenchmarkRegistrar(type):
173  """The Benchmark class registrar.  Used by abstract Benchmark class."""
174
175  def __new__(mcs, clsname, base, attrs):
176    newclass = type.__new__(mcs, clsname, base, attrs)
177    if not newclass.is_abstract():
178      GLOBAL_BENCHMARK_REGISTRY.add(newclass)
179    return newclass
180
181
@tf_export("__internal__.test.ParameterizedBenchmark", v1=[])
class ParameterizedBenchmark(_BenchmarkRegistrar):
  """Metaclass to generate parameterized benchmarks.

  Use this class as a metaclass and define `_benchmark_parameters` in the
  class body to generate multiple benchmark test cases. For example:

  class FooBenchmark(tf.test.Benchmark,
                     metaclass=tf.__internal__.test.ParameterizedBenchmark):
    # The `_benchmark_parameters` is expected to be a list with test cases.
    # Each test case is a tuple whose first item is the test case name,
    # followed by any number of the parameters needed for the test case.
    _benchmark_parameters = [
      ('case_1', Foo, 1, 'one'),
      ('case_2', Bar, 2, 'two'),
    ]

    def benchmark_test(self, target_class, int_param, string_param):
      # benchmark test body

  The example above will generate two benchmark test cases:
  "benchmark_test__case_1" and "benchmark_test__case_2".

  Every class attribute whose name starts with "benchmark" is removed and
  replaced by one generated method per entry of `_benchmark_parameters`.
  """

  def __new__(mcs, clsname, base, attrs):
    # Raises KeyError if the class body itself does not define
    # `_benchmark_parameters`; values inherited from bases are not consulted.
    param_config_list = attrs["_benchmark_parameters"]

    def create_benchmark_function(original_benchmark, params):
      # A separate factory binds `original_benchmark` and `params` per call,
      # avoiding the late-binding pitfall of defining the lambda in the loop.
      return lambda self: original_benchmark(self, *params)

    # Iterate over a snapshot of the names: `attrs` is modified in the loop.
    for name in list(attrs):
      if not name.startswith("benchmark"):
        continue

      original_benchmark = attrs[name]
      del attrs[name]

      for param_config in param_config_list:
        test_name_suffix = param_config[0]
        params = param_config[1:]
        benchmark_name = name + "__" + test_name_suffix
        if benchmark_name in attrs:
          raise Exception(
              "Benchmark named {} already defined.".format(benchmark_name))

        benchmark = create_benchmark_function(original_benchmark, params)
        # Renaming is important because `report_benchmark` function looks up the
        # function name in the stack trace.
        attrs[benchmark_name] = _rename_function(benchmark, 1, benchmark_name)

    # super() takes (this class, the metaclass being instantiated). The
    # reversed order raised TypeError for subclasses of this metaclass.
    return super(ParameterizedBenchmark, mcs).__new__(
        mcs, clsname, base, attrs)
233
234
class Benchmark(six.with_metaclass(_BenchmarkRegistrar, object)):
  """Abstract class that provides helper functions for running benchmarks.

  Any class subclassing this one is immediately registered in the global
  benchmark registry.

  Only methods whose names start with the word "benchmark" will be run during
  benchmarking.
  """

  @classmethod
  def is_abstract(cls):
    # _BenchmarkRegistrar is the metaclass, not a base class, so it is not in
    # the MRO. An MRO of (Benchmark, object) means this is Benchmark itself;
    # every subclass has a longer MRO.
    return len(cls.mro()) <= 2

  def _get_name(self, overwrite_name=None):
    """Returns full name of class and method calling report_benchmark.

    The method name comes from the outermost stack frame whose `self` is a
    `Benchmark` instance. The class name is that instance's type name.

    Args:
      overwrite_name: (optional) Name to use instead of the discovered method
        name. The class-name prefix is still applied.

    Returns:
      A string of the form "<ClassName>.<method name or overwrite_name>".

    Raises:
      ValueError: if no frame on the stack has a `Benchmark` as `self`.
    """
    calling_class = None
    name = None
    stack = tf_inspect.stack()
    try:
      # Walk from the outermost frame inwards, so the first match is the
      # top-level benchmark method rather than a helper it called.
      for frame_info in stack[::-1]:
        f_self = frame_info[0].f_locals.get("self", None)
        if isinstance(f_self, Benchmark):
          calling_class = f_self
          name = frame_info[3]  # The frame's function name.
          break
    finally:
      # Frame records can reference this very frame, whose locals include
      # `stack`. Delete it to break that reference cycle promptly, as the
      # `inspect` module documentation recommends.
      del stack
    if calling_class is None:
      raise ValueError("Unable to determine calling Benchmark class.")

    # Use the method name, or overwrite_name if provided.
    name = overwrite_name or name
    # Prefix the name with the class name.
    class_name = type(calling_class).__name__
    return "%s.%s" % (class_name, name)

  def report_benchmark(
      self,
      iters=None,
      cpu_time=None,
      wall_time=None,
      throughput=None,
      extras=None,
      name=None,
      metrics=None):
    """Report a benchmark.

    Args:
      iters: (optional) How many iterations were run
      cpu_time: (optional) Median or mean cpu time in seconds.
      wall_time: (optional) Median or mean wall time in seconds.
      throughput: (optional) Throughput (in MB/s)
      extras: (optional) Dict mapping string keys to additional benchmark info.
        Values may be either floats or values that are convertible to strings.
      name: (optional) Override the BenchmarkEntry name with `name`.
        Otherwise it is inferred from the top-level method name.
      metrics: (optional) A list of dict, where each dict has the keys below
        name (required), string, metric name
        value (required), double, metric value
        min_value (optional), double, minimum acceptable metric value
        max_value (optional), double, maximum acceptable metric value
    """
    name = self._get_name(overwrite_name=name)
    _global_report_benchmark(
        name=name, iters=iters, cpu_time=cpu_time, wall_time=wall_time,
        throughput=throughput, extras=extras, metrics=metrics)
304
305
@tf_export("test.benchmark_config")
def benchmark_config():
  """Returns a tf.compat.v1.ConfigProto for disabling the dependency optimizer.

  Only the grappler dependency optimization pass is switched off; every other
  field keeps its default value.

  Returns:
    A TensorFlow ConfigProto object.
  """
  config = config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  rewrite_options.dependency_optimization = (
      rewriter_config_pb2.RewriterConfig.OFF)
  return config
317
318
@tf_export("test.Benchmark")
class TensorFlowBenchmark(Benchmark):
  """Abstract class that provides helpers for TensorFlow benchmarks."""

  def __init__(self):
    # Allow TensorFlow runtime to allocate a new threadpool with different
    # number of threads for each new benchmark.
    os.environ[OVERRIDE_GLOBAL_THREADPOOL] = "1"
    super(TensorFlowBenchmark, self).__init__()

  @classmethod
  def is_abstract(cls):
    # The metaclass is not part of the MRO. An MRO of
    # (TensorFlowBenchmark, Benchmark, object) means this is
    # TensorFlowBenchmark itself; every subclass has a longer MRO.
    return len(cls.mro()) <= 3

  def run_op_benchmark(self,
                       sess,
                       op_or_tensor,
                       feed_dict=None,
                       burn_iters=2,
                       min_iters=10,
                       store_trace=False,
                       store_memory_usage=True,
                       name=None,
                       extras=None,
                       mbs=0):
    """Run an op or tensor in the given session.  Report the results.

    Args:
      sess: `Session` object to use for timing.
      op_or_tensor: `Operation` or `Tensor` to benchmark.
      feed_dict: A `dict` of values to feed for each op iteration (see the
        `feed_dict` parameter of `Session.run`).
      burn_iters: Number of burn-in iterations to run.
      min_iters: Minimum number of iterations to use for timing.
      store_trace: Boolean, whether to run an extra untimed iteration and
        store the trace of iteration in returned extras.
        The trace will be stored as a string in Google Chrome trace format
        in the extras field "full_trace_chrome_format". Note that trace
        will not be stored in test_log_pb2.TestResults proto.
      store_memory_usage: Boolean, whether to run an extra untimed iteration,
        calculate memory usage, and store that in extras fields.
      name: (optional) Override the BenchmarkEntry name with `name`.
        Otherwise it is inferred from the top-level method name.
      extras: (optional) Dict mapping string keys to additional benchmark info.
        Values may be either floats or values that are convertible to strings.
        The dict is copied; the caller's dict is never modified.
      mbs: (optional) The number of megabytes moved by this op, used to
        calculate the ops throughput.

    Returns:
      A `dict` containing the key-value pairs that were passed to
      `report_benchmark`. Its "extras" entry additionally holds
      "wall_time_mean" and "wall_time_stdev" and, if the `store_trace` option
      is used, "full_trace_chrome_format". None of these three are passed to
      `report_benchmark`. "throughput" is `mbs` divided by the median wall
      time, or -1 when the median wall time is not positive (no iterations,
      or too short for the timer to resolve).

    Raises:
      TypeError: if `extras` is neither None nor a dict.
    """
    # Work on a private copy: the memory and timing stats added below must
    # not leak into a dict the caller may reuse for other benchmarks.
    if extras is None:
      extras = {}
    elif isinstance(extras, dict):
      extras = dict(extras)
    else:
      raise TypeError("extras must be a dict")

    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which follows system clock adjustments. Fall back to time.time where
    # perf_counter does not exist (Python 2).
    timer = getattr(time, "perf_counter", time.time)

    for _ in range(burn_iters):
      sess.run(op_or_tensor, feed_dict=feed_dict)

    deltas = []
    for _ in range(min_iters):
      start_time = timer()
      sess.run(op_or_tensor, feed_dict=feed_dict)
      deltas.append(timer() - start_time)

    unreported_extras = {}
    if store_trace or store_memory_usage:
      run_options = config_pb2.RunOptions(
          trace_level=config_pb2.RunOptions.FULL_TRACE)
      run_metadata = config_pb2.RunMetadata()
      sess.run(op_or_tensor, feed_dict=feed_dict,
               options=run_options, run_metadata=run_metadata)
      tl = timeline.Timeline(run_metadata.step_stats)

      if store_trace:
        unreported_extras["full_trace_chrome_format"] = (
            tl.generate_chrome_trace_format())

      if store_memory_usage:
        step_stats_analysis = tl.analyze_step_stats(show_memory=True)
        allocator_maximums = step_stats_analysis.allocator_maximums
        for k, v in allocator_maximums.items():
          extras["allocator_maximum_num_bytes_%s" % k] = v.num_bytes

    def _median(x):
      # -1 signals "no samples".
      if not x:
        return -1
      s = sorted(x)
      l = len(x)
      lm1 = l - 1
      return (s[l//2] + s[lm1//2]) / 2.0

    def _mean_and_stdev(x):
      # Sample standard deviation; -1 signals "not enough samples".
      if not x:
        return -1, -1
      l = len(x)
      mean = sum(x) / l
      if l == 1:
        return mean, -1
      variance = sum([(e - mean) * (e - mean) for e in x]) / (l - 1)
      return mean, math.sqrt(variance)

    median_delta = _median(deltas)
    # Avoid ZeroDivisionError when the timer cannot resolve the op's duration;
    # a non-positive median (including the -1 "no samples" value) means the
    # throughput is unknown.
    throughput = mbs / median_delta if median_delta > 0 else -1

    benchmark_values = {
        "iters": min_iters,
        "wall_time": median_delta,
        "extras": extras,
        "name": name,
        "throughput": throughput,
    }
    self.report_benchmark(**benchmark_values)

    mean_delta, stdev_delta = _mean_and_stdev(deltas)
    unreported_extras["wall_time_mean"] = mean_delta
    unreported_extras["wall_time_stdev"] = stdev_delta
    benchmark_values["extras"].update(unreported_extras)
    return benchmark_values

  def evaluate(self, tensors):
    """Evaluates tensors and returns numpy values.

    Uses the default session when one is active, otherwise falls back to
    `self.cached_session()`.

    Args:
      tensors: A Tensor or a nested list/tuple of Tensors.

    Returns:
      tensors numpy values.
    """
    # NOTE(review): neither this class nor Benchmark defines `cached_session`;
    # presumably a subclass mixes in a test-case class that provides it.
    # Confirm before relying on the fallback.
    sess = ops.get_default_session() or self.cached_session()
    return sess.run(tensors)
453
454
def _run_benchmarks(regex):
  """Run benchmarks that match regex `regex`.

  This function goes through the global benchmark registry, and matches
  benchmark class and method names of the form
  `module.name.BenchmarkClass.benchmarkMethod` to the given regex.
  If a method matches, it is run. The special value "all" selects every
  benchmark.

  Classes are visited in sorted (module, class name) order. Each class is
  instantiated at most once, and only if one of its methods matches.

  Args:
    regex: The string regular expression to match Benchmark classes against.

  Raises:
    ValueError: If no benchmarks were selected by the input regex.
  """
  # The registry is a set of classes. Its iteration order depends on object
  # ids and so differs from run to run; sort it for a reproducible order.
  registry = sorted(GLOBAL_BENCHMARK_REGISTRY,
                    key=lambda cls: (str(cls.__module__), cls.__name__))

  selected_benchmarks = []
  # Match benchmarks in registry against regex
  for benchmark in registry:
    benchmark_name = "%s.%s" % (benchmark.__module__, benchmark.__name__)
    # Don't instantiate the benchmark class unless necessary
    benchmark_instance = None

    for attr in dir(benchmark):
      if not attr.startswith("benchmark"):
        continue
      if not callable(getattr(benchmark, attr)):
        continue
      full_benchmark_name = "%s.%s" % (benchmark_name, attr)
      if regex == "all" or re.search(regex, full_benchmark_name):
        selected_benchmarks.append(full_benchmark_name)
        # Instantiate lazily, once per class. Compare against None so that an
        # instance which happens to be falsy is not re-created per method.
        if benchmark_instance is None:
          benchmark_instance = benchmark()
        # Look the method up on the instance so it is bound, then run it.
        getattr(benchmark_instance, attr)()

  if not selected_benchmarks:
    raise ValueError("No benchmarks matched the pattern: '{}'".format(regex))
497
498
def benchmarks_main(true_main, argv=None):
  """Run benchmarks as declared in argv.

  If argv contains `--benchmarks=<regex>` (or `-benchmarks=<regex>`), the
  first such argument is removed from argv. The benchmarks matching `<regex>`
  are then run via `app.run`. Otherwise `true_main()` is called.

  Args:
    true_main: True main function to run if benchmarks are not requested.
    argv: the command line arguments (if None, uses sys.argv). The list is
      modified in place when a benchmarks argument is found.
  """
  if argv is None:
    argv = sys.argv
  found_arg = [arg for arg in argv
               if arg.startswith(("--benchmarks=", "-benchmarks="))]
  if found_arg:
    # Remove the consumed --benchmarks arg before handing argv to app.run.
    argv.remove(found_arg[0])

    # Split on the first "=" only, so a regex containing "=" survives intact.
    regex = found_arg[0].split("=", 1)[1]
    app.run(lambda _: _run_benchmarks(regex), argv=argv)
  else:
    true_main()
519