# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A test lib that defines some models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import variable_scope
from tensorflow.python.profiler import model_analyzer
from tensorflow.python.training import gradient_descent
from tensorflow.python.util import _pywrap_tfprof as print_mdl
from tensorflow.python.util import compat


def BuildSmallModel():
  """Build a small forward conv model.

  Creates three variables in the current variable scope: a scalar 'ScalarW'
  (created but not used by any op here), and the conv kernels 'DW' and 'DW2'.
  A zero-filled [2, 6, 6, 3] image is passed through two conv2d ops, each with
  strides [1, 2, 2, 1] and 'SAME' padding.

  Returns:
    The output tensor of the second convolution.
  """
  image = array_ops.zeros([2, 6, 6, 3])
  # The scalar variable is deliberately left unconnected to the conv ops; it
  # only needs to exist in the graph.
  _ = variable_scope.get_variable(
      'ScalarW', [],
      dtypes.float32,
      initializer=init_ops.random_normal_initializer(stddev=0.001))
  kernel = variable_scope.get_variable(
      'DW', [3, 3, 3, 6],
      dtypes.float32,
      initializer=init_ops.random_normal_initializer(stddev=0.001))
  x = nn_ops.conv2d(image, kernel, [1, 2, 2, 1], padding='SAME')
  kernel = variable_scope.get_variable(
      'DW2', [2, 2, 6, 12],
      dtypes.float32,
      initializer=init_ops.random_normal_initializer(stddev=0.001))
  x = nn_ops.conv2d(x, kernel, [1, 2, 2, 1], padding='SAME')
  return x


def BuildFullModel():
  """Build the full model with conv,rnn,opt.

  Builds four copies of the small conv model (each under its own 'inp_<i>'
  variable scope), reshapes every output to [2, 1, -1], concatenates them
  along axis 1 and feeds the result to a 16-unit BasicRNNCell via dynamic_rnn.
  The loss is the l2_loss of the mean difference between an all-ones target
  and the RNN output.

  Returns:
    The op returned by GradientDescentOptimizer(1e-2).minimize(loss).
  """
  seq = []
  for i in range(4):
    # A distinct scope per step keeps the variable names of each copy apart.
    with variable_scope.variable_scope('inp_%d' % i):
      seq.append(array_ops.reshape(BuildSmallModel(), [2, 1, -1]))

  cell = rnn_cell.BasicRNNCell(16)
  # dynamic_rnn returns a pair; only its first element (the outputs) is used.
  out = rnn.dynamic_rnn(
      cell, array_ops.concat(seq, axis=1), dtype=dtypes.float32)[0]

  target = array_ops.ones_like(out)
  loss = nn_ops.l2_loss(math_ops.reduce_mean(target - out))
  optimizer = gradient_descent.GradientDescentOptimizer(1e-2)
  return optimizer.minimize(loss)


def BuildSplittableModel():
  """Build a small model that can be run partially in each step.

  Two independent conv2d branches read the same zero-filled [2, 6, 6, 3]
  image, using kernels 'DW' ([3, 3, 3, 6]) and 'DW2' ([2, 3, 3, 6]); a third
  tensor adds the two branch outputs.

  Returns:
    A tuple (r1, r2, r3): the first branch output, the second branch output,
    and their sum.
  """
  image = array_ops.zeros([2, 6, 6, 3])

  kernel1 = variable_scope.get_variable(
      'DW', [3, 3, 3, 6],
      dtypes.float32,
      initializer=init_ops.random_normal_initializer(stddev=0.001))
  r1 = nn_ops.conv2d(image, kernel1, [1, 2, 2, 1], padding='SAME')

  kernel2 = variable_scope.get_variable(
      'DW2', [2, 3, 3, 6],
      dtypes.float32,
      initializer=init_ops.random_normal_initializer(stddev=0.001))
  r2 = nn_ops.conv2d(image, kernel2, [1, 2, 2, 1], padding='SAME')

  r3 = r1 + r2
  return r1, r2, r3


def SearchTFProfNode(node, name):
  """Search a node in the tree.

  Performs a depth-first, pre-order search starting at `node`.

  Args:
    node: Root of the (sub)tree; must expose `name` and iterable `children`.
    name: The node name to look for.

  Returns:
    The first node whose `name` equals `name`, or None if there is none.
  """
  if node.name == name:
    return node
  for child in node.children:
    found = SearchTFProfNode(child, name)
    # Compare against None explicitly so that a matching node is never
    # discarded merely because it evaluates as falsy.
    if found is not None:
      return found
  return None


@contextlib.contextmanager
def ProfilerFromFile(profile_file):
  """Initialize a profiler from profile file.

  Args:
    profile_file: Path of the profile file; converted with compat.as_bytes
      before being handed to the native wrapper.

  Yields:
    A model_analyzer.Profiler instance.
  """
  print_mdl.ProfilerFromFile(compat.as_bytes(profile_file))
  # The instance is created with __new__, so Profiler.__init__ is not run.
  # NOTE(review): presumably because the native profiler state was already
  # set up by print_mdl.ProfilerFromFile above -- confirm.
  profiler = model_analyzer.Profiler.__new__(model_analyzer.Profiler)
  try:
    yield profiler
  finally:
    # Always release the native profiler, even if the with-body raised.
    print_mdl.DeleteProfiler()


def CheckAndRemoveDoc(profile):
  """Check that `profile` has a doc section and strip everything before the body.

  Args:
    profile: The profile output as a string.

  Returns:
    The part of `profile` starting 9 characters after the first occurrence of
    'Profile:' (the 8-character marker plus one more character).
    NOTE(review): the extra character is presumably the newline that follows
    the marker -- confirm against the profiler output format.

  Raises:
    AssertionError: If `profile` does not contain 'Doc:' or 'Profile:'.
  """
  assert 'Doc:' in profile
  start_pos = profile.find('Profile:')
  # str.find returns -1 when the marker is missing, which would otherwise
  # silently yield profile[8:].
  assert start_pos >= 0, "'Profile:' marker not found in profile output"
  return profile[start_pos + 9:]