1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Utilities to test summaries.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import functools 23import os 24 25import sqlite3 26 27from tensorflow.core.util import event_pb2 28from tensorflow.python.framework import test_util 29from tensorflow.python.lib.io import tf_record 30from tensorflow.python.ops import summary_ops_v2 as summary_ops 31from tensorflow.python.platform import gfile 32 33 34class SummaryDbTest(test_util.TensorFlowTestCase): 35 """Helper for summary database testing.""" 36 37 def setUp(self): 38 super(SummaryDbTest, self).setUp() 39 self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite') 40 if os.path.exists(self.db_path): 41 os.unlink(self.db_path) 42 self.db = sqlite3.connect(self.db_path) 43 self.create_db_writer = functools.partial( 44 summary_ops.create_db_writer, 45 db_uri=self.db_path, 46 experiment_name='experiment', 47 run_name='run', 48 user_name='user') 49 50 def tearDown(self): 51 self.db.close() 52 super(SummaryDbTest, self).tearDown() 53 54 55def events_from_file(filepath): 56 """Returns all events in a single event file. 57 58 Args: 59 filepath: Path to the event file. 60 61 Returns: 62 A list of all tf.Event protos in the event file. 63 """ 64 records = list(tf_record.tf_record_iterator(filepath)) 65 result = [] 66 for r in records: 67 event = event_pb2.Event() 68 event.ParseFromString(r) 69 result.append(event) 70 return result 71 72 73def events_from_logdir(logdir): 74 """Returns all events in the single eventfile in logdir. 75 76 Args: 77 logdir: The directory in which the single event file is sought. 78 79 Returns: 80 A list of all tf.Event protos from the single event file. 81 82 Raises: 83 AssertionError: If logdir does not contain exactly one file. 84 """ 85 assert gfile.Exists(logdir) 86 files = gfile.ListDirectory(logdir) 87 assert len(files) == 1, 'Found not exactly one file in logdir: %s' % files 88 return events_from_file(os.path.join(logdir, files[0])) 89 90 91def get_one(db, q, *p): 92 return db.execute(q, p).fetchone()[0] 93 94 95def get_all(db, q, *p): 96 return unroll(db.execute(q, p).fetchall()) 97 98 99def unroll(list_of_tuples): 100 return sum(list_of_tuples, ()) 101