# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Benchmark for split and grad of split."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.core.protobuf import config_pb2
24from tensorflow.python.client import session as session_lib
25from tensorflow.python.framework import ops
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import control_flow_ops
28from tensorflow.python.ops import variables
29from tensorflow.python.platform import benchmark
30from tensorflow.python.platform import test
31from tensorflow.python.platform import tf_logging as logging
32
33
34def build_graph(device, input_shape, output_sizes, axis):
35  """Build a graph containing a sequence of split operations.
36
37  Args:
38    device: string, the device to run on.
39    input_shape: shape of the input tensor.
40    output_sizes: size of each output along axis.
41    axis: axis to be split along.
42
43  Returns:
44    An array of tensors to run()
45  """
46  with ops.device("/%s:0" % device):
47    inp = array_ops.zeros(input_shape)
48
49    outputs = []
50    for _ in range(100):
51      outputs.extend(array_ops.split(inp, output_sizes, axis))
52    return control_flow_ops.group(*outputs)
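

# A minimal illustration of the split semantics exercised above (the shapes
# and sizes here are assumptions chosen for the example, not benchmark
# inputs): splitting a [4, 6] tensor along axis 1 with output_sizes=[2, 2, 2]
# yields three tensors of shape [4, 2] each, e.g.
#
#   pieces = array_ops.split(array_ops.zeros([4, 6]), [2, 2, 2], 1)
#   # len(pieces) == 3; each pieces[i] has shape [4, 2].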


class SplitBenchmark(test.Benchmark):
  """Benchmark split!"""

  def _run_graph(self, device, output_shape, variable, num_outputs, axis):
    """Run the graph and print its execution time.

    Args:
      device: string, the device to run on.
      output_shape: shape of each output tensor.
      variable: whether the output sizes are drawn at random (True) or kept
        fixed (False).
      num_outputs: the number of outputs to split the input into.
      axis: axis to split along.
    """
    graph = ops.Graph()
    with graph.as_default():
      if not variable:
        # Fixed sizes: every output gets the same extent along the split
        # axis, so the input is num_outputs copies of output_shape.
        if axis == 0:
          input_shape = [output_shape[0] * num_outputs, output_shape[1]]
          sizes = [output_shape[0] for _ in range(num_outputs)]
        else:
          input_shape = [output_shape[0], output_shape[1] * num_outputs]
          sizes = [output_shape[1] for _ in range(num_outputs)]
      else:
        # Variable sizes: draw each output's extent from a small window
        # around output_shape[axis].
        sizes = np.random.randint(
            low=max(1, output_shape[axis] - 2),
            high=output_shape[axis] + 2,
            size=num_outputs)
        total_size = np.sum(sizes)
        if axis == 0:
          input_shape = [total_size, output_shape[1]]
        else:
          input_shape = [output_shape[0], total_size]

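      # Illustrative arithmetic under assumed inputs: with output_shape of
      # [2000, 8], axis=1 and variable=True, each size is drawn from
      # np.random.randint(low=6, high=10), i.e. from {6, 7, 8, 9}, so with
      # num_outputs=100 the input's extent along axis 1 lands near 750.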
      outputs = build_graph(device, input_shape, sizes, axis)
    # Turn graph optimizations off (opt level L0) so the benchmarked splits
    # are not constant-folded or pruned away.
    config = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
        optimizer_options=config_pb2.OptimizerOptions(
            opt_level=config_pb2.OptimizerOptions.L0)))
    with session_lib.Session(graph=graph, config=config) as session:
      logging.set_verbosity("info")
      variables.global_variables_initializer().run()
      bench = benchmark.TensorFlowBenchmark()
      bench.run_op_benchmark(
          session,
          outputs,
          # Megabytes moved: float32 elements (4 bytes each) * 2 for read
          # plus write * the 100 split rounds built into the graph.
          mbs=input_shape[0] * input_shape[1] * 4 * 2 * 100 / 1e6,
          extras={
              "input_shape": input_shape,
              "variable": variable,
              "axis": axis
          })
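      # Worked example of the mbs figure (shape chosen for illustration):
      # for input_shape=[2000, 800], the bytes touched per run are
      # 2000 * 800 * 4 * 2 * 100 = 1.28e9, i.e. mbs = 1280.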

  def benchmark_split(self):
    print("Forward split")
    shapes = [[2000, 8], [8, 2000], [100, 18], [1000, 18], [10000, 18],
              [100, 97], [1000, 97], [10000, 1], [1, 10000]]
    axis_ = [1]  # 0 is very fast because it doesn't actually do any copying
    num_outputs = 100
    variable = [False, True]  # fixed or randomly drawn output sizes
    for shape in shapes:
      for axis in axis_:
        for v in variable:
          self._run_graph("gpu", shape, v, num_outputs, axis)
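
  # A hypothetical CPU variant (a sketch, not part of the original suite):
  # _run_graph formats the device as "/%s:0", so the same sweep can target
  # the CPU just by changing the device string, e.g.:
  #
  #   def benchmark_split_cpu(self):
  #     for shape in [[2000, 8], [8, 2000]]:
  #       self._run_graph("cpu", shape, False, 100, 1)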


if __name__ == "__main__":
  test.main()