1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.contrib.slim.summaries.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import glob 22import os 23 24 25from tensorflow.contrib.slim.python.slim import summaries 26from tensorflow.python.framework import ops 27from tensorflow.python.ops import array_ops 28from tensorflow.python.platform import gfile 29from tensorflow.python.platform import test 30from tensorflow.python.summary import summary 31from tensorflow.python.summary import summary_iterator 32 33 34class SummariesTest(test.TestCase): 35 36 def safe_create(self, output_dir): 37 if gfile.Exists(output_dir): 38 gfile.DeleteRecursively(output_dir) 39 gfile.MakeDirs(output_dir) 40 41 def assert_scalar_summary(self, output_dir, names_to_values): 42 """Asserts that the given output directory contains written summaries. 43 44 Args: 45 output_dir: The output directory in which to look for even tfiles. 46 names_to_values: A dictionary of summary names to values. 47 """ 48 # The events file may have additional entries, e.g. the event version 49 # stamp, so have to parse things a bit. 
50 output_filepath = glob.glob(os.path.join(output_dir, '*')) 51 self.assertEqual(len(output_filepath), 1) 52 53 events = summary_iterator.summary_iterator(output_filepath[0]) 54 summaries_list = [e.summary for e in events if e.summary.value] 55 values = [] 56 for item in summaries_list: 57 for value in item.value: 58 values.append(value) 59 saved_results = {v.tag: v.simple_value for v in values} 60 for name in names_to_values: 61 self.assertAlmostEqual(names_to_values[name], saved_results[name]) 62 63 def testScalarSummaryIsPartOfCollectionWithNoPrint(self): 64 tensor = array_ops.ones([]) * 3 65 name = 'my_score' 66 prefix = 'eval' 67 op = summaries.add_scalar_summary(tensor, name, prefix, print_summary=False) 68 self.assertTrue(op in ops.get_collection(ops.GraphKeys.SUMMARIES)) 69 70 def testScalarSummaryIsPartOfCollectionWithPrint(self): 71 tensor = array_ops.ones([]) * 3 72 name = 'my_score' 73 prefix = 'eval' 74 op = summaries.add_scalar_summary(tensor, name, prefix, print_summary=True) 75 self.assertTrue(op in ops.get_collection(ops.GraphKeys.SUMMARIES)) 76 77 def verify_scalar_summary_is_written(self, print_summary): 78 value = 3 79 tensor = array_ops.ones([]) * value 80 name = 'my_score' 81 prefix = 'eval' 82 summaries.add_scalar_summary(tensor, name, prefix, print_summary) 83 84 output_dir = os.path.join(self.get_temp_dir(), 85 'scalar_summary_no_print_test') 86 self.safe_create(output_dir) 87 88 summary_op = summary.merge_all() 89 90 summary_writer = summary.FileWriter(output_dir) 91 with self.cached_session() as sess: 92 new_summary = sess.run(summary_op) 93 summary_writer.add_summary(new_summary, 1) 94 summary_writer.flush() 95 96 self.assert_scalar_summary(output_dir, { 97 '%s/%s' % (prefix, name): value 98 }) 99 100 def testScalarSummaryIsWrittenWithNoPrint(self): 101 self.verify_scalar_summary_is_written(print_summary=False) 102 103 def testScalarSummaryIsWrittenWithPrint(self): 104 self.verify_scalar_summary_is_written(print_summary=True) 105 106 
def _main():
  """Runs every test case in this module through the TF test runner."""
  test.main()


if __name__ == '__main__':
  _main()