1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Utilities to test summaries."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import functools
23import os
24
25import sqlite3
26
27from tensorflow.core.util import event_pb2
28from tensorflow.python.framework import test_util
29from tensorflow.python.lib.io import tf_record
30from tensorflow.python.ops import summary_ops_v2 as summary_ops
31from tensorflow.python.platform import gfile
32
33
class SummaryDbTest(test_util.TensorFlowTestCase):
  """Test case base that gives every test its own SQLite summary database.

  setUp populates:
    db_path: Path of 'DbTest.sqlite' inside the test's temp directory.
    db: An open sqlite3 connection to db_path.
    create_db_writer: summary_ops.create_db_writer with db_uri bound to
      db_path and experiment/run/user names fixed to 'experiment', 'run'
      and 'user'.
  """

  def setUp(self):
    super(SummaryDbTest, self).setUp()
    db_file = os.path.join(self.get_temp_dir(), 'DbTest.sqlite')
    self.db_path = db_file
    if os.path.exists(db_file):
      # Delete any file already at this path so the test opens a new database.
      os.unlink(db_file)
    self.db = sqlite3.connect(db_file)
    self.create_db_writer = functools.partial(
        summary_ops.create_db_writer, db_uri=db_file,
        experiment_name='experiment', run_name='run', user_name='user')

  def tearDown(self):
    # Release the connection before the base class performs its teardown.
    self.db.close()
    super(SummaryDbTest, self).tearDown()
53
54
def events_from_file(filepath):
  """Returns all events in a single event file.

  Records are parsed as they are yielded by the record iterator instead of
  first collecting every serialized record into a list, so the raw bytes and
  the parsed protos are not all held in memory at the same time.

  Args:
    filepath: Path to the event file.

  Returns:
    A list of all tf.Event protos in the event file, in the order the
    records were read.
  """
  result = []
  for record in tf_record.tf_record_iterator(filepath):
    event = event_pb2.Event()
    event.ParseFromString(record)
    result.append(event)
  return result
71
72
def events_from_logdir(logdir):
  """Returns all events in the single eventfile in logdir.

  Args:
    logdir: The directory in which the single event file is sought.

  Returns:
    A list of all tf.Event protos from the single event file.

  Raises:
    AssertionError: If logdir does not exist or does not contain exactly one
      file.
  """
  # Raise explicitly instead of using `assert`: assert statements are removed
  # when Python runs with -O, which would silently disable these checks.
  if not gfile.Exists(logdir):
    raise AssertionError('logdir does not exist: %s' % (logdir,))
  files = gfile.ListDirectory(logdir)
  if len(files) != 1:
    # Format as a 1-tuple so the message works whatever sequence type
    # ListDirectory returns (a bare tuple would be splatted into `%`).
    raise AssertionError(
        'Found not exactly one file in logdir: %s' % (files,))
  return events_from_file(os.path.join(logdir, files[0]))
89
90
def get_one(db, q, *p):
  """Runs query `q` with positional parameters `p` and returns one value.

  Args:
    db: A sqlite3-style connection exposing `execute`.
    q: SQL query string, using `?` placeholders for the parameters.
    *p: Values bound to the query placeholders, in order.

  Returns:
    The first column of the first result row.
  """
  cursor = db.execute(q, p)
  first_row = cursor.fetchone()
  return first_row[0]
93
94
def get_all(db, q, *p):
  """Runs query `q` with positional parameters `p` and flattens the result.

  Args:
    db: A sqlite3-style connection exposing `execute`.
    q: SQL query string, using `?` placeholders for the parameters.
    *p: Values bound to the query placeholders, in order.

  Returns:
    The result of `unroll` applied to every fetched row, i.e. the column
    values of all rows concatenated in row order.
  """
  rows = db.execute(q, p).fetchall()
  return unroll(rows)
97
98
def unroll(list_of_tuples):
  """Flattens a sequence of tuples into a single tuple.

  Uses a single linear pass instead of `sum(list_of_tuples, ())`, which
  builds a new tuple on every addition and is quadratic in the total size.

  Args:
    list_of_tuples: An iterable of tuples (e.g. rows from `fetchall()`).

  Returns:
    A tuple with the elements of every input tuple, in order; `()` when the
    input is empty.
  """
  return tuple(item for row in list_of_tuples for item in row)
101