1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Benchmark for split and grad of split.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.core.protobuf import config_pb2 24from tensorflow.python.client import session as session_lib 25from tensorflow.python.framework import ops 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import control_flow_ops 28from tensorflow.python.ops import variables 29from tensorflow.python.platform import benchmark 30from tensorflow.python.platform import test 31from tensorflow.python.platform import tf_logging as logging 32 33 34def build_graph(device, input_shape, output_sizes, axis): 35 """Build a graph containing a sequence of split operations. 36 37 Args: 38 device: string, the device to run on. 39 input_shape: shape of the input tensor. 40 output_sizes: size of each output along axis. 41 axis: axis to be split along. 42 43 Returns: 44 An array of tensors to run() 45 """ 46 with ops.device("/%s:0" % device): 47 inp = array_ops.zeros(input_shape) 48 49 outputs = [] 50 for _ in range(100): 51 outputs.extend(array_ops.split(inp, output_sizes, axis)) 52 return control_flow_ops.group(*outputs) 53 54 55class SplitBenchmark(test.Benchmark): 56 """Benchmark split!""" 57 58 def _run_graph(self, device, output_shape, variable, num_outputs, axis): 59 """Run the graph and print its execution time. 60 61 Args: 62 device: string, the device to run on. 63 output_shape: shape of each output tensors. 64 variable: whether or not the output shape should be fixed 65 num_outputs: the number of outputs to split the input into 66 axis: axis to be split 67 68 Returns: 69 The duration of the run in seconds. 70 """ 71 graph = ops.Graph() 72 with graph.as_default(): 73 if not variable: 74 if axis == 0: 75 input_shape = [output_shape[0] * num_outputs, output_shape[1]] 76 sizes = [output_shape[0] for _ in range(num_outputs)] 77 else: 78 input_shape = [output_shape[0], output_shape[1] * num_outputs] 79 sizes = [output_shape[1] for _ in range(num_outputs)] 80 else: 81 sizes = np.random.randint( 82 low=max(1, output_shape[axis] - 2), 83 high=output_shape[axis] + 2, 84 size=num_outputs) 85 total_size = np.sum(sizes) 86 if axis == 0: 87 input_shape = [total_size, output_shape[1]] 88 else: 89 input_shape = [output_shape[0], total_size] 90 91 outputs = build_graph(device, input_shape, sizes, axis) 92 config = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions( 93 optimizer_options=config_pb2.OptimizerOptions( 94 opt_level=config_pb2.OptimizerOptions.L0))) 95 with session_lib.Session(graph=graph, config=config) as session: 96 logging.set_verbosity("info") 97 variables.global_variables_initializer().run() 98 bench = benchmark.TensorFlowBenchmark() 99 bench.run_op_benchmark( 100 session, 101 outputs, 102 mbs=input_shape[0] * input_shape[1] * 4 * 2 * 100 / 1e6, 103 extras={ 104 "input_shape": input_shape, 105 "variable": variable, 106 "axis": axis 107 }) 108 109 def benchmark_split(self): 110 print("Forward vs backward concat") 111 shapes = [[2000, 8], [8, 2000], [100, 18], [1000, 18], [10000, 18], 112 [100, 97], [1000, 97], [10000, 1], [1, 10000]] 113 axis_ = [1] # 0 is very fast because it doesn't actually do any copying 114 num_outputs = 100 115 variable = [False, True] # fixed input size or not 116 for shape in shapes: 117 for axis in axis_: 118 for v in variable: 119 self._run_graph("gpu", shape, v, num_outputs, axis) 120 121 122if __name__ == "__main__": 123 test.main() 124