1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.contrib.slim.summaries."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import glob
22import os
23
24
25from tensorflow.contrib.slim.python.slim import summaries
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import array_ops
28from tensorflow.python.platform import gfile
29from tensorflow.python.platform import test
30from tensorflow.python.summary import summary
31from tensorflow.python.summary import summary_iterator
32
33
class SummariesTest(test.TestCase):
  """Tests for `summaries.add_scalar_summary`."""

  def safe_create(self, output_dir):
    """Creates `output_dir`, first deleting any existing directory there.

    Args:
      output_dir: Path of the directory to (re)create.
    """
    if gfile.Exists(output_dir):
      gfile.DeleteRecursively(output_dir)
    gfile.MakeDirs(output_dir)

  def assert_scalar_summary(self, output_dir, names_to_values):
    """Asserts that the given output directory contains written summaries.

    Args:
      output_dir: The output directory in which to look for event files.
        Exactly one file is expected to be present.
      names_to_values: A dictionary of summary names (tags) to expected
        scalar values.
    """
    # The events file may have additional entries, e.g. the event version
    # stamp, so have to parse things a bit.
    output_filepaths = glob.glob(os.path.join(output_dir, '*'))
    self.assertEqual(len(output_filepaths), 1)

    # Events without summary values contribute nothing here; if a tag appears
    # more than once, the last value read wins.
    saved_results = {}
    for event in summary_iterator.summary_iterator(output_filepaths[0]):
      for value in event.summary.value:
        saved_results[value.tag] = value.simple_value

    for name, expected_value in names_to_values.items():
      # Report a missing tag as a test failure rather than a KeyError.
      self.assertIn(name, saved_results)
      self.assertAlmostEqual(expected_value, saved_results[name])

  def verify_scalar_summary_is_in_collection(self, print_summary):
    """Checks that the op returned by `add_scalar_summary` is collected.

    Args:
      print_summary: Forwarded to `add_scalar_summary` as `print_summary`.
    """
    tensor = array_ops.ones([]) * 3
    name = 'my_score'
    prefix = 'eval'
    op = summaries.add_scalar_summary(
        tensor, name, prefix, print_summary=print_summary)
    self.assertIn(op, ops.get_collection(ops.GraphKeys.SUMMARIES))

  def testScalarSummaryIsPartOfCollectionWithNoPrint(self):
    self.verify_scalar_summary_is_in_collection(print_summary=False)

  def testScalarSummaryIsPartOfCollectionWithPrint(self):
    self.verify_scalar_summary_is_in_collection(print_summary=True)

  def verify_scalar_summary_is_written(self, print_summary):
    """Checks that a scalar summary ends up in an events file on disk.

    Args:
      print_summary: Forwarded to `add_scalar_summary` as `print_summary`.
    """
    value = 3
    tensor = array_ops.ones([]) * value
    name = 'my_score'
    prefix = 'eval'
    summaries.add_scalar_summary(
        tensor, name, prefix, print_summary=print_summary)

    # Name the directory after the mode actually under test.
    dir_name = ('scalar_summary_print_test' if print_summary
                else 'scalar_summary_no_print_test')
    output_dir = os.path.join(self.get_temp_dir(), dir_name)
    self.safe_create(output_dir)

    summary_op = summary.merge_all()

    summary_writer = summary.FileWriter(output_dir)
    try:
      with self.cached_session() as sess:
        new_summary = sess.run(summary_op)
        summary_writer.add_summary(new_summary, 1)
    finally:
      # close() flushes pending events to disk and releases the writer, so
      # the events file is complete before it is read back below.
      summary_writer.close()

    self.assert_scalar_summary(output_dir, {
        '%s/%s' % (prefix, name): value
    })

  def testScalarSummaryIsWrittenWithNoPrint(self):
    self.verify_scalar_summary_is_written(print_summary=False)

  def testScalarSummaryIsWrittenWithPrint(self):
    self.verify_scalar_summary_is_written(print_summary=True)
105
106
if __name__ == '__main__':
  # Run the test cases defined in this module via TensorFlow's test runner.
  test.main()
109