1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A test lib that defines some models."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import contextlib
21
22from tensorflow.python.framework import dtypes
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import init_ops
25from tensorflow.python.ops import math_ops
26from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
27from tensorflow.python.ops import nn_ops
28from tensorflow.python.ops import rnn
29from tensorflow.python.ops import rnn_cell
30from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
31from tensorflow.python.ops import variable_scope
32from tensorflow.python.profiler import model_analyzer
33from tensorflow.python.training import gradient_descent
34from tensorflow.python.util import _pywrap_tfprof as print_mdl
35from tensorflow.python.util import compat
36
37
def BuildSmallModel():
  """Build a small forward conv model.

  Starts from a constant all-zeros tensor of shape [2, 6, 6, 3] and applies
  two stride-2, 'SAME'-padded convolutions whose kernels are float32
  variables drawn from a normal initializer (stddev 0.001). A scalar
  variable 'ScalarW' is also created but is not used in the computation.

  Returns:
    The output tensor of the second convolution.
  """

  def _make_weight(var_name, var_shape):
    # A fresh initializer per variable, mirroring one-initializer-per-call.
    return variable_scope.get_variable(
        var_name,
        var_shape,
        dtypes.float32,
        initializer=init_ops.random_normal_initializer(stddev=0.001))

  net = array_ops.zeros([2, 6, 6, 3])
  # Created but never consumed; presumably so profiles contain a scalar
  # parameter (TODO confirm against the tests that use this model).
  _make_weight('ScalarW', [])
  for var_name, var_shape in (('DW', [3, 3, 3, 6]), ('DW2', [2, 2, 6, 12])):
    weight = _make_weight(var_name, var_shape)
    net = nn_ops.conv2d(net, weight, [1, 2, 2, 1], padding='SAME')
  return net
56
57
def BuildFullModel():
  """Build the full model with conv,rnn,opt.

  Four independent copies of BuildSmallModel (in variable scopes 'inp_0'
  through 'inp_3') are each reshaped to [2, 1, -1] and concatenated along
  axis 1 to form the input sequence of a 16-unit BasicRNNCell driven by
  dynamic_rnn. The L2 loss of the mean difference between an all-ones target
  and the RNN outputs is minimized with plain gradient descent (lr 1e-2).

  Returns:
    The training op returned by GradientDescentOptimizer.minimize.
  """

  def _conv_branch(branch_idx):
    with variable_scope.variable_scope('inp_%d' % branch_idx):
      return array_ops.reshape(BuildSmallModel(), [2, 1, -1])

  steps = [_conv_branch(branch_idx) for branch_idx in range(4)]

  rnn_cell_16 = rnn_cell.BasicRNNCell(16)
  rnn_inputs = array_ops.concat(steps, axis=1)
  outputs, _ = rnn.dynamic_rnn(rnn_cell_16, rnn_inputs, dtype=dtypes.float32)

  ones_target = array_ops.ones_like(outputs)
  residual = ones_target - outputs
  loss = nn_ops.l2_loss(math_ops.reduce_mean(residual))
  optimizer = gradient_descent.GradientDescentOptimizer(1e-2)
  return optimizer.minimize(loss)
73
74
def BuildSplittableModel():
  """Build a small model that can be run partially in each step.

  Two convolutions ('DW' and 'DW2' kernels, stride 2, 'SAME' padding) are
  applied independently to the same all-zeros [2, 6, 6, 3] input, and their
  sum is formed as a third tensor. Returning all three presumably lets a
  caller fetch only a subset per session run (TODO confirm with callers).

  Returns:
    A tuple (r1, r2, r3) where r3 = r1 + r2.
  """
  image = array_ops.zeros([2, 6, 6, 3])

  def _strided_conv(var_name, kernel_shape):
    weights = variable_scope.get_variable(
        var_name,
        kernel_shape,
        dtypes.float32,
        initializer=init_ops.random_normal_initializer(stddev=0.001))
    return nn_ops.conv2d(image, weights, [1, 2, 2, 1], padding='SAME')

  first = _strided_conv('DW', [3, 3, 3, 6])
  second = _strided_conv('DW2', [2, 3, 3, 6])
  return first, second, first + second
93
94
def SearchTFProfNode(node, name):
  """Search a node in the tree.

  Performs a depth-first, pre-order (parent before children, children left
  to right) search starting at `node`.

  Args:
    node: Root of the tree; each node must expose `.name` and an iterable
      `.children`.
    name: The node name to look for.

  Returns:
    The first node in pre-order whose `.name == name`, or None if there is
    no such node.
  """
  # An explicit stack avoids hitting Python's recursion limit on deep trees.
  stack = [node]
  while stack:
    current = stack.pop()
    if current.name == name:
      # Return the match directly instead of testing its truthiness, so a
      # node object that happens to evaluate falsy is still found.
      return current
    # Push children reversed so the leftmost child is popped first, keeping
    # the same visit order as the recursive formulation.
    stack.extend(reversed(list(current.children)))
  return None
103
104
@contextlib.contextmanager
def ProfilerFromFile(profile_file):
  """Initialize a profiler from profile file.

  Loads the profile into the native profiler via
  `print_mdl.ProfilerFromFile`, then yields a `model_analyzer.Profiler`
  instance created with `__new__` (i.e. without running `Profiler.__init__`;
  presumably so it wraps the already-loaded native profiler rather than
  building a new one from a graph, TODO confirm). The native profiler is
  deleted on exit, including when the `with` body raises.

  Args:
    profile_file: Path of the profile file; converted with `compat.as_bytes`.

  Yields:
    A `model_analyzer.Profiler` instance.
  """
  print_mdl.ProfilerFromFile(compat.as_bytes(profile_file))
  try:
    profiler = model_analyzer.Profiler.__new__(model_analyzer.Profiler)
    yield profiler
  finally:
    # Without try/finally an exception in the with-body would skip this and
    # leave the native profiler alive.
    print_mdl.DeleteProfiler()
112
113
def CheckAndRemoveDoc(profile):
  """Check that profiler text output has a doc section and strip it.

  Args:
    profile: Profiler output string. It must contain 'Doc:' and a
      'Profile:' marker.

  Returns:
    The text following the 'Profile:' marker, with the single character
    right after the marker also dropped (presumably the newline ending the
    marker line).

  Raises:
    AssertionError: If 'Doc:' or 'Profile:' is missing.
  """
  assert 'Doc:' in profile
  marker = 'Profile:'
  start_pos = profile.find(marker)
  # find() returns -1 when the marker is absent, which would otherwise
  # silently slice from index len(marker) and hide the malformed output.
  assert start_pos >= 0, "'Profile:' marker not found in profile output"
  return profile[start_pos + len(marker) + 1:]
118