1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test utilities for tf.data benchmarking functionality."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import time
21
22import numpy as np
23
24from tensorflow.python.eager import context
25from tensorflow.python.client import session
26from tensorflow.python.data.ops import dataset_ops
27from tensorflow.python.data.util import nest
28from tensorflow.python.platform import test
29
30
class DatasetBenchmarkBase(test.Benchmark):
  """Base class for dataset benchmarks."""

  def _run_eager_benchmark(self, iterable, iters, warmup):
    """Benchmarks the iterable in eager mode.

    Runs the iterable `iters` times. In each iteration, the benchmark measures
    the time it takes to pull the first element from a freshly created
    iterator over the iterable.

    Args:
      iterable: The object to benchmark (e.g. a tf.data Dataset). `iter()` is
        called on it and one element is pulled with `next()`.
      iters: Number of times to repeat the timing.
      warmup: If true, performs an untimed run before each timed run.

    Returns:
      A float, representing the median (over the `iters` repetitions) of the
      wall time in seconds of a single execution of the iterable.

    Raises:
      RuntimeError: If not executing eagerly.
    """
    if not context.executing_eagerly():
      raise RuntimeError(
          "Eager mode benchmarking is not supported in graph mode.")

    deltas = []
    for _ in range(iters):
      if warmup:
        # Untimed run so one-off costs are excluded from the measurement.
        iterator = iter(iterable)
        next(iterator)

      # Iterator creation is deliberately outside the timed region; only the
      # `next()` call is measured.
      iterator = iter(iterable)
      start = time.time()
      next(iterator)
      end = time.time()
      deltas.append(end - start)
    return np.median(deltas)

  def _run_graph_benchmark(self,
                           iterable,
                           iters,
                           warmup,
                           session_config,
                           initializer=None):
    """Benchmarks the iterable in graph mode.

    Runs the iterable `iters` times. In each iteration, a new session is
    created and the benchmark measures the time of a single `sess.run` of the
    iterable.

    Args:
      iterable: The tf op (or other fetchable) to benchmark via `sess.run`.
      iters: Number of times to repeat the timing.
      warmup: If true, warms up the session caches by running an untimed run
        in each new session before the timed run.
      session_config: A ConfigProto protocol buffer with configuration options
        for the session. Applicable only for benchmarking in graph mode.
      initializer: Optional op that is run before every (warmup and timed)
        execution, e.g. an iterator initializer.

    Returns:
      A float, representing the median (over the `iters` repetitions) of the
      wall time in seconds of a single execution of the iterable.

    Raises:
      RuntimeError: If executing eagerly.
    """
    if context.executing_eagerly():
      raise RuntimeError(
          "Graph mode benchmarking is not supported in eager mode.")

    deltas = []
    for _ in range(iters):
      # A new session is created for every repetition.
      with session.Session(config=session_config) as sess:
        if warmup:
          # Run once to warm up the session caches.
          if initializer is not None:
            sess.run(initializer)
          sess.run(iterable)

        # Re-run the initializer so the timed run starts from a freshly
        # initialized state, not from where the warmup run left off.
        if initializer is not None:
          sess.run(initializer)
        start = time.time()
        sess.run(iterable)
        end = time.time()
      deltas.append(end - start)
    return np.median(deltas)

  def run_op_benchmark(self, op, iters=1, warmup=True, session_config=None):
    """Benchmarks the op.

    Runs the op `iters` times. In each iteration, the benchmark measures
    the time it takes to execute the op once.

    Args:
      op: The tf op to benchmark.
      iters: Number of times to repeat the timing.
      warmup: If true, performs an untimed run before each timed run.
      session_config: A ConfigProto protocol buffer with configuration options
        for the session. Applicable only for benchmarking in graph mode.

    Returns:
      A float, representing the per-execution wall time of the op in seconds.
      This is the median (over the `iters` repetitions) of the time taken by
      a single execution of the op.
    """
    if context.executing_eagerly():
      return self._run_eager_benchmark(iterable=op, iters=iters, warmup=warmup)

    return self._run_graph_benchmark(
        iterable=op, iters=iters, warmup=warmup, session_config=session_config)

  def run_benchmark(self,
                    dataset,
                    num_elements,
                    iters=1,
                    warmup=True,
                    apply_default_optimizations=False,
                    session_config=None):
    """Benchmarks the dataset.

    Runs the dataset `iters` times. In each iteration, the benchmark measures
    the time it takes to go through `num_elements` elements of the dataset.

    Args:
      dataset: Dataset to benchmark.
      num_elements: Number of dataset elements to iterate through each benchmark
        iteration. Must be at least 1.
      iters: Number of times to repeat the timing.
      warmup: If true, performs an untimed run before each timed run.
      apply_default_optimizations: Determines whether default optimizations
        should be applied.
      session_config: A ConfigProto protocol buffer with configuration options
        for the session. Applicable only for benchmarking in graph mode.

    Returns:
      A float, representing the per-element wall time of the dataset in seconds.
      This is the median time (with respect to `iters`) it takes for the dataset
      to go through `num_elements` elements, divided by `num_elements`.

    Raises:
      ValueError: If `num_elements` is less than 1.
    """
    if num_elements < 1:
      # `skip(num_elements - 1)` below would otherwise yield an empty dataset
      # (the first `next()` would fail) and the final division would be by a
      # non-positive count.
      raise ValueError(
          "num_elements must be at least 1, got {}.".format(num_elements))

    # The options that have been applied to the dataset are preserved so that
    # they are not overwritten while benchmarking.
    options = dataset.options()
    options.experimental_optimization.apply_default_optimizations = (
        apply_default_optimizations)
    dataset = dataset.with_options(options)

    # NOTE: We use `dataset.skip()` to perform the iterations in C++, avoiding
    # the overhead of having to execute a TensorFlow op for each step of the
    # input pipeline. Skipping `num_elements - 1` elements and then fetching
    # one element makes each timed run produce `num_elements` elements. Note
    # that this relies on the underlying implementation of `skip` to execute
    # upstream computation. If it is optimized in the future, we will have to
    # change this code.
    dataset = dataset.skip(num_elements - 1)

    if context.executing_eagerly():
      median_duration = self._run_eager_benchmark(
          iterable=dataset, iters=iters, warmup=warmup)
      return median_duration / float(num_elements)

    iterator = dataset_ops.make_initializable_iterator(dataset)
    next_element = iterator.get_next()
    # Run the op that produces the next element rather than fetching its
    # output tensors.
    op = nest.flatten(next_element)[0].op
    median_duration = self._run_graph_benchmark(
        iterable=op,
        iters=iters,
        warmup=warmup,
        session_config=session_config,
        initializer=iterator.initializer)
    return median_duration / float(num_elements)

  def run_and_report_benchmark(self,
                               dataset,
                               num_elements,
                               name,
                               iters=5,
                               extras=None,
                               warmup=True,
                               apply_default_optimizations=False,
                               session_config=None):
    """Benchmarks the dataset and reports the stats.

    Runs the dataset `iters` times. In each iteration, the benchmark measures
    the time it takes to go through `num_elements` elements of the dataset.
    This is followed by logging/printing the benchmark stats.

    Args:
      dataset: Dataset to benchmark.
      num_elements: Number of dataset elements to iterate through each benchmark
        iteration. Must be at least 1.
      name: Name of the benchmark. A ".eager" or ".graph" suffix is appended
        depending on the execution mode.
      iters: Number of times to repeat the timing.
      extras: A dict which maps string keys to additional benchmark info. It is
        not modified; a copy with an added "num_elements" entry is reported.
      warmup: If true, performs an untimed run before each timed run.
      apply_default_optimizations: Determines whether default optimizations
        should be applied.
      session_config: A ConfigProto protocol buffer with configuration options
        for the session. Applicable only for benchmarking in graph mode.

    Returns:
      A float, representing the per-element wall time of the dataset in seconds.
      This is the median time (with respect to `iters`) it takes for the dataset
      to go through `num_elements` elements, divided by `num_elements`.

    Raises:
      ValueError: If `num_elements` is less than 1.
    """
    wall_time = self.run_benchmark(
        dataset=dataset,
        num_elements=num_elements,
        iters=iters,
        warmup=warmup,
        apply_default_optimizations=apply_default_optimizations,
        session_config=session_config)
    if context.executing_eagerly():
      name = "{}.eager".format(name)
    else:
      name = "{}.graph".format(name)
    # Copy so the caller's dict is not mutated by the entry added below.
    extras = {} if extras is None else dict(extras)
    extras["num_elements"] = num_elements
    self.report_benchmark(
        wall_time=wall_time, iters=iters, name=name, extras=extras)
    return wall_time
246