1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Fake summary writer for unit tests."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.core.framework import summary_pb2
21from tensorflow.python.framework import test_util
22from tensorflow.python.summary.writer import writer
23from tensorflow.python.summary.writer import writer_cache
24
25
26# TODO(ptucker): Replace with mock framework.
class FakeSummaryWriter(object):
  """Fake summary writer.

  Stores everything handed to it in memory instead of writing event files, so
  tests can inspect the recorded data through `summaries` and
  `assert_summaries`. Use `install()` to substitute this class for
  `writer.FileWriter` and `writer_cache.FileWriter`, and `uninstall()` to put
  the original class back.
  """

  # The `FileWriter` class saved by `install()`, or None when not installed.
  _replaced_summary_writer = None

  @classmethod
  def install(cls):
    """Replaces `FileWriter` in `writer` and `writer_cache` with `cls`.

    Raises:
      ValueError: if this class already has a replacement installed.
    """
    if cls._replaced_summary_writer is not None:
      raise ValueError('FakeSummaryWriter already installed.')
    cls._replaced_summary_writer = writer.FileWriter
    # Patch with `cls` rather than a hard-coded class so that calling
    # `install()` on a subclass actually makes that subclass the writer.
    writer.FileWriter = cls
    writer_cache.FileWriter = cls

  @classmethod
  def uninstall(cls):
    """Restores the `FileWriter` class saved by `install()`.

    Raises:
      ValueError: if no replacement is currently installed for this class.
    """
    if cls._replaced_summary_writer is None:
      raise ValueError('FakeSummaryWriter not installed.')
    writer.FileWriter = cls._replaced_summary_writer
    writer_cache.FileWriter = cls._replaced_summary_writer
    cls._replaced_summary_writer = None

  def __init__(self, logdir, graph=None):
    """Creates an empty fake writer.

    Args:
      logdir: Log directory; only recorded, compared by `assert_summaries`.
      graph: Optional graph; only recorded, compared (by identity) by
        `assert_summaries`.
    """
    self._logdir = logdir
    self._graph = graph
    self._summaries = {}  # global step -> list of summaries added at it.
    self._added_graphs = []
    self._added_meta_graphs = []
    self._added_session_logs = []
    self._added_run_metadata = {}  # tag -> run metadata (last one wins).

  @property
  def summaries(self):
    """Dict mapping each global step to the list of summaries added at it."""
    return self._summaries

  def assert_summaries(self,
                       test_case,
                       expected_logdir=None,
                       expected_graph=None,
                       expected_summaries=None,
                       expected_added_graphs=None,
                       expected_added_meta_graphs=None,
                       expected_session_logs=None):
    """Assert expected items have been added to summary writer.

    Each `expected_*` argument left as None is not checked.

    Args:
      test_case: Object providing `assertEqual` and `assertTrue` (e.g. a
        `unittest.TestCase`); used to report failures.
      expected_logdir: Expected `logdir` passed to the constructor.
      expected_graph: Expected constructor `graph`, compared by identity.
      expected_summaries: Dict mapping global step to a dict of
        `{tag: simple_value}` expected across all summaries at that step.
        Tag 'global_step/sec' is ignored. Steps not listed are not checked.
      expected_added_graphs: Expected list of graphs passed to `add_graph`.
      expected_added_meta_graphs: Expected list of meta graphs passed to
        `add_meta_graph`, compared via
        `test_util.assert_meta_graph_protos_equal`.
      expected_session_logs: Expected list of session logs passed to
        `add_session_log`.
    """
    if expected_logdir is not None:
      test_case.assertEqual(expected_logdir, self._logdir)
    if expected_graph is not None:
      test_case.assertTrue(expected_graph is self._graph)
    for step, expected_values in (expected_summaries or {}).items():
      test_case.assertTrue(
          step in self._summaries,
          msg='Missing step %s from %s.' % (step, list(self._summaries)))
      actual_simple_values = {}
      for step_summary in self._summaries[step]:
        for v in step_summary.value:
          # Ignore global_step/sec since it's written by Supervisor in a
          # separate thread, so it's non-deterministic how many get written.
          if v.tag != 'global_step/sec':
            actual_simple_values[v.tag] = v.simple_value
      test_case.assertEqual(expected_values, actual_simple_values)
    if expected_added_graphs is not None:
      test_case.assertEqual(expected_added_graphs, self._added_graphs)
    if expected_added_meta_graphs is not None:
      test_case.assertEqual(len(expected_added_meta_graphs),
                            len(self._added_meta_graphs))
      for expected, actual in zip(expected_added_meta_graphs,
                                  self._added_meta_graphs):
        test_util.assert_meta_graph_protos_equal(test_case, expected, actual)
    if expected_session_logs is not None:
      test_case.assertEqual(expected_session_logs, self._added_session_logs)

  def add_summary(self, summ, current_global_step):
    """Records a summary under `current_global_step`.

    Args:
      summ: A summary object with a `value` list, or serialized `Summary`
        bytes, which are parsed into a `summary_pb2.Summary` before storing.
      current_global_step: Step key the summary is stored under.
    """
    if isinstance(summ, bytes):
      summary_proto = summary_pb2.Summary()
      summary_proto.ParseFromString(summ)
      summ = summary_proto
    self._summaries.setdefault(current_global_step, []).append(summ)

  # NOTE: Ignore global_step since its value is non-deterministic.
  def add_graph(self, graph, global_step=None, graph_def=None):
    """Records `graph`.

    Raises:
      ValueError: if `global_step` is negative, or `graph_def` is given
        (this fake does not support it).
    """
    if (global_step is not None) and (global_step < 0):
      raise ValueError('Invalid global_step %s.' % global_step)
    if graph_def is not None:
      raise ValueError('Unexpected graph_def %s.' % graph_def)
    self._added_graphs.append(graph)

  def add_meta_graph(self, meta_graph_def, global_step=None):
    """Records `meta_graph_def`.

    Raises:
      ValueError: if `global_step` is negative.
    """
    if (global_step is not None) and (global_step < 0):
      raise ValueError('Invalid global_step %s.' % global_step)
    self._added_meta_graphs.append(meta_graph_def)

  # NOTE: Ignore global_step since its value is non-deterministic.
  def add_session_log(self, session_log, global_step=None):
    """Records `session_log`; `global_step` is accepted but not checked."""
    # pylint: disable=unused-argument
    self._added_session_logs.append(session_log)

  def add_run_metadata(self, run_metadata, tag, global_step=None):
    """Records `run_metadata` under `tag`, replacing any earlier entry.

    Raises:
      ValueError: if `global_step` is negative.
    """
    if (global_step is not None) and (global_step < 0):
      raise ValueError('Invalid global_step %s.' % global_step)
    self._added_run_metadata[tag] = run_metadata

  def flush(self):
    """No-op; nothing is buffered."""
    pass

  def reopen(self):
    """No-op; there is no underlying file."""
    pass

  def close(self):
    """No-op; there is no underlying file."""
    pass
144