1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.framework import constant_op
21from tensorflow.python.framework import dtypes
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import math_ops
24from tensorflow.python.platform import test
25
26
class TFProfLoggerTest(test.TestCase):
  """Tests for the tfprof logger.

  NOTE(review): every test method of this class is currently disabled (kept
  below as comments), so running this module executes no test cases.
  """

  def _BuildSmallPlaceholderModel(self):
    """Builds a small matmul graph fed by two placeholders.

    Returns:
      A tuple `(a, b, y)` where `a` and `b` are int32 placeholders of shape
      `[2, 2]` and `y` is `math_ops.matmul(a, b)`.
    """
    a = array_ops.placeholder(dtypes.int32, [2, 2])
    b = array_ops.placeholder(dtypes.int32, [2, 2])
    y = math_ops.matmul(a, b)
    return a, b, y

  # Backward-compatible alias for the original, misspelled helper name.
  _BuildSmallPlaceholderlModel = _BuildSmallPlaceholderModel

  def _BuildSmallModel(self):
    """Builds a matmul of two constant matrices.

    Returns:
      The result of `math_ops.matmul(a, b)`, where `a` and `b` are both the
      constant `[[1, 2], [3, 4]]`.
    """
    a = constant_op.constant([[1, 2], [3, 4]])
    b = constant_op.constant([[1, 2], [3, 4]])
    return math_ops.matmul(a, b)

  # The tests below are disabled. They use names this file does not import
  # (config_pb2, session, ops, copy_elements, tfprof_logger); copy_elements
  # presumably comes from contrib, hence the TODO.
  # TODO(xpan): Move this out of core so it doesn't depend on contrib.
  #
  # def testFillMissingShape(self):
  #   a, b, y = self._BuildSmallPlaceholderModel()
  #   run_options = config_pb2.RunOptions(
  #       trace_level=config_pb2.RunOptions.FULL_TRACE)
  #   run_metadata = config_pb2.RunMetadata()
  #   sess = session.Session()
  #   sess.run(y,
  #            options=run_options,
  #            run_metadata=run_metadata,
  #            feed_dict={a: [[1, 2], [2, 3]],
  #                       b: [[1, 2], [2, 3]]})
  #
  #   graph2 = ops.Graph()
  #   # Use copy_op_to_graph to remove shape information.
  #   y2 = copy_elements.copy_op_to_graph(y, graph2, [])
  #   self.assertEqual('<unknown>', str(y2.get_shape()))
  #
  #   tfprof_logger._fill_missing_graph_shape(graph2, run_metadata)
  #   self.assertEqual('(2, 2)', str(y2.get_shape()))
  #
  # def testFailedFillMissingShape(self):
  #   y = self._BuildSmallModel()
  #   run_options = config_pb2.RunOptions(
  #       trace_level=config_pb2.RunOptions.FULL_TRACE)
  #   run_metadata = config_pb2.RunMetadata()
  #   sess = session.Session()
  #   sess.run(y, options=run_options, run_metadata=run_metadata)
  #
  #   graph2 = ops.Graph()
  #   y2 = copy_elements.copy_op_to_graph(y, graph2, [])
  #   self.assertEqual('<unknown>', str(y2.get_shape()))
  #   # run_metadata has special name for MatMul, hence failed to fill shape.
  #   tfprof_logger._fill_missing_graph_shape(graph2, run_metadata)
  #   self.assertEqual('<unknown>', str(y2.get_shape()))
77
78
if __name__ == '__main__':
  # Entry point: hand control to TensorFlow's test runner for this module.
  test.main()
81