1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Utilities to run benchmarks."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numbers
22import os
23import re
24import sys
25import time
26
27import six
28
29from tensorflow.core.protobuf import config_pb2
30from tensorflow.core.protobuf import rewriter_config_pb2
31from tensorflow.core.util import test_log_pb2
32from tensorflow.python.client import timeline
33from tensorflow.python.framework import ops
34from tensorflow.python.platform import app
35from tensorflow.python.platform import gfile
36from tensorflow.python.platform import tf_logging as logging
37from tensorflow.python.util import tf_inspect
38from tensorflow.python.util.tf_export import tf_export
39
40
# Set of concrete Benchmark subclasses. The Benchmark metaclass adds each
# non-abstract subclass here automatically when the class is defined.
GLOBAL_BENCHMARK_REGISTRY = set()

# Name of the environment variable holding the output-file prefix used when
# writing benchmark results; if it is unset, results are printed instead.
# See also TestReporter::kTestReporterEnv in tensorflow/core/util/reporter.h.
TEST_REPORTER_TEST_ENV = "TEST_REPORT_FILE_PREFIX"

# Name of the environment variable that allows the TensorFlow runtime to
# allocate a fresh threadpool for each benchmark.
OVERRIDE_GLOBAL_THREADPOOL = "TF_OVERRIDE_GLOBAL_THREADPOOL"
52
53
def _global_report_benchmark(
    name, iters=None, cpu_time=None, wall_time=None,
    throughput=None, extras=None):
  """Method for recording a benchmark directly.

  Builds a `BenchmarkEntries` proto holding a single entry. If the environment
  variable named by `TEST_REPORTER_TEST_ENV` is unset, the proto is printed to
  stdout. Otherwise it is serialized to the file `<prefix><name>`, where
  `<prefix>` is the variable's value and every "/" in `name` is replaced by
  "__".

  Args:
    name: The BenchmarkEntry name.
    iters: (optional) How many iterations were run.
    cpu_time: (optional) CPU time in seconds.
    wall_time: (optional) Wall time in seconds.
    throughput: (optional) Throughput (in MB/s).
    extras: (optional) Dict mapping string keys to additional benchmark info.
      Real-valued numbers are stored as doubles; any other value (including
      complex numbers, which a double field cannot hold) is stored as its
      `str()`.

  Raises:
    TypeError: if extras is not a dict.
    IOError: if the benchmark output file already exists.
  """
  if extras is not None and not isinstance(extras, dict):
    raise TypeError("extras must be a dict")

  # The two format literals are concatenated, so the trailing space after
  # "cpu_time: %g," is what separates it from "throughput".
  logging.info("Benchmark [%s] iters: %d, wall_time: %g, cpu_time: %g, "
               "throughput: %g %s", name,
               iters if iters is not None else -1,
               wall_time if wall_time is not None else -1,
               cpu_time if cpu_time is not None else -1,
               throughput if throughput is not None else -1,
               str(extras) if extras else "")

  entries = test_log_pb2.BenchmarkEntries()
  entry = entries.entry.add()
  entry.name = name
  if iters is not None:
    entry.iters = iters
  if cpu_time is not None:
    entry.cpu_time = cpu_time
  if wall_time is not None:
    entry.wall_time = wall_time
  if throughput is not None:
    entry.throughput = throughput
  if extras is not None:
    for (k, v) in extras.items():
      # numbers.Real (not numbers.Number) so complex values, which cannot be
      # assigned to a double field, fall back to the string representation.
      if isinstance(v, numbers.Real):
        entry.extras[k].double_value = v
      else:
        entry.extras[k].string_value = str(v)

  test_env = os.environ.get(TEST_REPORTER_TEST_ENV, None)
  if test_env is None:
    # Reporting was not requested, just print the proto.
    print(str(entries))
    return

  serialized_entry = entries.SerializeToString()

  # "/" would otherwise introduce subdirectories into the output path.
  mangled_name = name.replace("/", "__")
  output_path = "%s%s" % (test_env, mangled_name)
  if gfile.Exists(output_path):
    raise IOError("File already exists: %s" % output_path)
  with gfile.GFile(output_path, "wb") as out:
    out.write(serialized_entry)
113
114
class _BenchmarkRegistrar(type):
  """The Benchmark class registrar.  Used by abstract Benchmark class.

  Every class created with this metaclass is added to
  `GLOBAL_BENCHMARK_REGISTRY` unless the new class's `is_abstract()` returns
  True.
  """

  def __new__(mcs, clsname, base, attrs):
    # super() takes (the class whose parent to delegate to, the object being
    # bound). Here that is (_BenchmarkRegistrar, mcs); the reversed order only
    # works when mcs is exactly _BenchmarkRegistrar and raises TypeError for
    # any metaclass derived from it.
    newclass = super(_BenchmarkRegistrar, mcs).__new__(
        mcs, clsname, base, attrs)
    if not newclass.is_abstract():
      GLOBAL_BENCHMARK_REGISTRY.add(newclass)
    return newclass
124
125
class Benchmark(six.with_metaclass(_BenchmarkRegistrar, object)):
  """Abstract class that provides helper functions for running benchmarks.

  Any class subclassing this one is immediately registered in the global
  benchmark registry.

  Only methods whose names start with the word "benchmark" will be run during
  benchmarking.
  """

  @classmethod
  def is_abstract(cls):
    """Returns True for `Benchmark` itself and False for any subclass."""
    # Benchmark is declared with `object` as its only base, so its MRO is
    # (Benchmark, object); every subclass adds at least one more entry.
    return len(cls.mro()) <= 2

  def _get_name(self, overwrite_name=None):
    """Returns full name of class and method calling report_benchmark.

    The result has the form "<ClassName>.<name>". `<name>` is `overwrite_name`
    if given; otherwise it is the function name of the outermost stack frame
    whose `self` is a `Benchmark` instance.

    Args:
      overwrite_name: (optional) Name to use in place of the method name.

    Returns:
      The string "<ClassName>.<name>".

    Raises:
      ValueError: if no frame on the call stack has a `Benchmark` as `self`.
    """
    stack = tf_inspect.stack()
    calling_class = None
    name = None
    try:
      # Walk from the outermost frame inward so that helper methods called by
      # a benchmark method do not replace the top-level benchmark's name.
      for frame in stack[::-1]:
        f_self = frame[0].f_locals.get("self", None)
        if isinstance(f_self, Benchmark):
          calling_class = f_self
          name = frame[3]  # The frame's function (method) name.
          break
    finally:
      # Frame records reference live frame objects; drop them promptly so the
      # frames (and their locals) are not kept alive by a reference cycle.
      stack = None
      frame = None

    if calling_class is None:
      raise ValueError("Unable to determine calling Benchmark class.")

    # Use the method name, or overwrite_name if provided.
    name = overwrite_name or name
    # Prefix the name with the class name.
    class_name = type(calling_class).__name__
    return "%s.%s" % (class_name, name)

  def report_benchmark(
      self,
      iters=None,
      cpu_time=None,
      wall_time=None,
      throughput=None,
      extras=None,
      name=None):
    """Report a benchmark.

    Args:
      iters: (optional) How many iterations were run
      cpu_time: (optional) median or mean cpu time in seconds.
      wall_time: (optional) median or mean wall time in seconds.
      throughput: (optional) Throughput (in MB/s)
      extras: (optional) Dict mapping string keys to additional benchmark info.
        Values may be either floats or values that are convertible to strings.
      name: (optional) Override the BenchmarkEntry name with `name`.
        Otherwise it is inferred from the top-level method name.
    """
    name = self._get_name(overwrite_name=name)
    _global_report_benchmark(
        name=name, iters=iters, cpu_time=cpu_time, wall_time=wall_time,
        throughput=throughput, extras=extras)
189
190
@tf_export("test.benchmark_config")
def benchmark_config():
  """Returns a tf.ConfigProto for disabling the dependency optimizer.

  The returned config has its graph rewrite option `dependency_optimization`
  set to `RewriterConfig.OFF`; all other fields keep their defaults.

  Returns:
    A TensorFlow ConfigProto object.
  """
  optimizer_off = rewriter_config_pb2.RewriterConfig.OFF
  proto = config_pb2.ConfigProto()
  rewrite_options = proto.graph_options.rewrite_options
  rewrite_options.dependency_optimization = optimizer_off
  return proto
202
203
@tf_export("test.Benchmark")
class TensorFlowBenchmark(Benchmark):
  """Abstract class that provides helpers for TensorFlow benchmarks."""

  def __init__(self):
    # Allow TensorFlow runtime to allocate a new threadpool with different
    # number of threads for each new benchmark.
    os.environ[OVERRIDE_GLOBAL_THREADPOOL] = "1"
    super(TensorFlowBenchmark, self).__init__()

  @classmethod
  def is_abstract(cls):
    """Returns True for TensorFlowBenchmark itself, False for subclasses."""
    # TensorFlowBenchmark's MRO is (TensorFlowBenchmark, Benchmark, object);
    # every subclass adds at least one more entry.
    return len(cls.mro()) <= 3

  def run_op_benchmark(self,
                       sess,
                       op_or_tensor,
                       feed_dict=None,
                       burn_iters=2,
                       min_iters=10,
                       store_trace=False,
                       store_memory_usage=True,
                       name=None,
                       extras=None,
                       mbs=0):
    """Run an op or tensor in the given session.  Report the results.

    Args:
      sess: `Session` object to use for timing.
      op_or_tensor: `Operation` or `Tensor` to benchmark.
      feed_dict: A `dict` of values to feed for each op iteration (see the
        `feed_dict` parameter of `Session.run`).
      burn_iters: Number of burn-in iterations to run.
      min_iters: Minimum number of iterations to use for timing.
      store_trace: Boolean, whether to run an extra untimed iteration and
        store the trace of iteration in returned extras.
        The trace will be stored as a string in Google Chrome trace format
        in the extras field "full_trace_chrome_format". Note that trace
        will not be stored in test_log_pb2.TestResults proto.
      store_memory_usage: Boolean, whether to run an extra untimed iteration,
        calculate memory usage, and store that in extras fields.
      name: (optional) Override the BenchmarkEntry name with `name`.
        Otherwise it is inferred from the top-level method name.
      extras: (optional) Dict mapping string keys to additional benchmark info.
        Values may be either floats or values that are convertible to strings.
        The dict passed in is not modified; a copy is reported and returned.
      mbs: (optional) The number of megabytes moved by this op, used to
        calculate the ops throughput.

    Returns:
      A `dict` containing the key-value pairs that were passed to
      `report_benchmark`: "iters", "wall_time" (median seconds per timed
      iteration, or -1 if no iterations were timed), "extras", "name" and
      "throughput" (`mbs / wall_time`, or None when that median is not
      positive). If `store_trace` option is used, then
      "full_trace_chrome_format" will be included in the returned "extras"
      even though it is not passed to `report_benchmark`.
    """
    for _ in range(burn_iters):
      sess.run(op_or_tensor, feed_dict=feed_dict)

    deltas = []
    for _ in range(min_iters):
      start_time = time.time()
      sess.run(op_or_tensor, feed_dict=feed_dict)
      deltas.append(time.time() - start_time)

    # Copy so that the memory stats and trace added below do not leak into
    # the dict the caller passed in.
    extras = dict(extras) if extras is not None else {}
    unreported_extras = {}
    if store_trace or store_memory_usage:
      # One extra, untimed run with full tracing to collect step stats.
      run_options = config_pb2.RunOptions(
          trace_level=config_pb2.RunOptions.FULL_TRACE)
      run_metadata = config_pb2.RunMetadata()
      sess.run(op_or_tensor, feed_dict=feed_dict,
               options=run_options, run_metadata=run_metadata)
      tl = timeline.Timeline(run_metadata.step_stats)

      if store_trace:
        unreported_extras["full_trace_chrome_format"] = (
            tl.generate_chrome_trace_format())

      if store_memory_usage:
        step_stats_analysis = tl.analyze_step_stats(show_memory=True)
        allocator_maximums = step_stats_analysis.allocator_maximums
        for k, v in allocator_maximums.items():
          extras["allocator_maximum_num_bytes_%s" % k] = v.num_bytes

    def _median(x):
      """Returns the median of `x`, or -1 if `x` is empty."""
      if not x:
        return -1
      s = sorted(x)
      n = len(s)
      # For even n this averages the two middle elements; for odd n both
      # indices select the same middle element.
      return (s[n // 2] + s[(n - 1) // 2]) / 2.0

    median_delta = _median(deltas)

    # A zero median (an op faster than the resolution of time.time()) would
    # raise ZeroDivisionError, and the -1 "no samples" sentinel would give a
    # negative throughput; report no throughput in those cases.
    throughput = mbs / median_delta if median_delta > 0 else None

    benchmark_values = {
        "iters": min_iters,
        "wall_time": median_delta,
        "extras": extras,
        "name": name,
        "throughput": throughput,
    }
    self.report_benchmark(**benchmark_values)
    # Added only after reporting so the trace is returned but not reported.
    benchmark_values["extras"].update(unreported_extras)
    return benchmark_values

  def evaluate(self, tensors):
    """Evaluates tensors and returns numpy values.

    Uses the default session if one is registered. Otherwise it falls back to
    `self.cached_session()`, which is not defined on this class and must be
    supplied by a subclass or mixin.

    Args:
      tensors: A Tensor or a nested list/tuple of Tensors.

    Returns:
      tensors numpy values.

    Raises:
      ValueError: if there is no default session and this object has no
        `cached_session` method.
    """
    sess = ops.get_default_session()
    if sess is None:
      cached_session = getattr(self, "cached_session", None)
      if cached_session is None:
        raise ValueError(
            "evaluate() requires a default session to be registered, or a "
            "`cached_session` method provided by a subclass.")
      sess = cached_session()
    return sess.run(tensors)
324
325
def _run_benchmarks(regex):
  """Run benchmarks that match regex `regex`.

  This function goes through the global benchmark registry, and matches
  benchmark class and method names of the form
  `module.name.BenchmarkClass.benchmarkMethod` to the given regex.
  If a method matches, it is run. The special value "all" runs every
  callable attribute whose name starts with "benchmark".

  Each benchmark class is instantiated at most once, and only if at least one
  of its methods matches.

  Args:
    regex: The string regular expression to match Benchmark classes against.
  """
  registry = list(GLOBAL_BENCHMARK_REGISTRY)

  # Match benchmarks in registry against regex.
  for benchmark in registry:
    benchmark_name = "%s.%s" % (benchmark.__module__, benchmark.__name__)
    # Don't instantiate the benchmark class unless necessary.
    benchmark_instance = None

    for attr in dir(benchmark):
      if not attr.startswith("benchmark"):
        continue
      if not callable(getattr(benchmark, attr)):
        continue
      full_benchmark_name = "%s.%s" % (benchmark_name, attr)
      if regex == "all" or re.search(regex, full_benchmark_name):
        # Test against None rather than truthiness: an instance whose class
        # defines __len__/__bool__ may be falsy, and must still be reused
        # instead of being re-created for every matching method.
        if benchmark_instance is None:
          benchmark_instance = benchmark()
        # Call the method bound to the (single) instance.
        getattr(benchmark_instance, attr)()
360
361
def benchmarks_main(true_main, argv=None):
  """Run benchmarks as declared in argv.

  Benchmarks are requested with a `--benchmarks=<regex>` (or
  `-benchmarks=<regex>`) argument. When one is present, the first such
  argument is removed from `argv` in place, and `app.run` runs the benchmarks
  matching `<regex>` with the remaining arguments. Otherwise `true_main` is
  called with no arguments.

  Args:
    true_main: True main function to run if benchmarks are not requested.
    argv: the command line arguments (if None, uses sys.argv).
  """
  if argv is None:
    argv = sys.argv
  found_arg = [arg for arg in argv
               if arg.startswith(("--benchmarks=", "-benchmarks="))]
  if not found_arg:
    true_main()
    return

  benchmarks_arg = found_arg[0]
  # Remove the --benchmarks arg so app.run does not see it.
  argv.remove(benchmarks_arg)
  # Split on the first "=" only, so a regex that itself contains "=" is kept
  # intact.
  regex = benchmarks_arg.partition("=")[2]
  app.run(lambda _: _run_benchmarks(regex), argv=argv)
382