1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains helper functions for creating summaries.
16
17This module contains various helper functions for quickly and easily adding
18tensorflow summaries. These allow users to print summary values
19automatically as they are computed and add prefixes to collections of summaries.
20
21Example usage:
22
23  import tensorflow as tf
24  slim = tf.contrib.slim
25
26  slim.summaries.add_histogram_summaries(slim.variables.get_model_variables())
27  slim.summaries.add_scalar_summary(total_loss, 'Total Loss')
28  slim.summaries.add_scalar_summary(learning_rate, 'Learning Rate')
29  slim.summaries.add_histogram_summaries(my_tensors)
30  slim.summaries.add_zero_fraction_summaries(my_tensors)
31"""
32from __future__ import absolute_import
33from __future__ import division
34from __future__ import print_function
35
36from tensorflow.python.framework import ops
37from tensorflow.python.ops import logging_ops
38from tensorflow.python.ops import nn_impl as nn
39from tensorflow.python.summary import summary
40
41
def _get_summary_name(tensor, name=None, prefix=None, postfix=None):
  """Builds the name under which a summary for `tensor` is registered.

  The result has the form `[prefix/]base[/postfix]`, where `base` is `name`
  when one is supplied (and non-empty) and `tensor.op.name` otherwise.

  Args:
    tensor: A variable or op `Tensor`; only consulted when `name` is empty.
    name: The optional name for the summary.
    prefix: An optional prefix for the summary name.
    postfix: An optional postfix for the summary name.

  Returns:
    The summary name, with path components joined by '/'.
  """
  # `tensor.op` is only touched when no explicit name was provided.
  base = name or tensor.op.name
  prefixed = prefix + '/' + base if prefix else base
  return prefixed + '/' + postfix if postfix else prefixed
61
62
def add_histogram_summary(tensor, name=None, prefix=None):
  """Creates a histogram summary of `tensor`.

  Args:
    tensor: A variable or op tensor.
    name: The optional name for the summary; when omitted the name of the
      tensor's op is used.
    prefix: An optional prefix for the summary name.

  Returns:
    A scalar `Tensor` of type `string` whose contents are the serialized
    `Summary` protocol buffer.
  """
  histogram_name = _get_summary_name(tensor, name, prefix)
  return summary.histogram(histogram_name, tensor)
77
78
def add_image_summary(tensor, name=None, prefix=None, print_summary=False):
  """Creates an image summary of `tensor`, optionally printing it.

  Args:
    tensor: a variable or op tensor with shape [batch,height,width,channels]
    name: the optional name for the summary; when omitted the name of the
      tensor's op is used.
    prefix: An optional prefix for the summary name.
    print_summary: If `True`, the summary is printed to stdout when the summary
      is computed.

  Returns:
    A scalar `Tensor` of type `string` whose contents are the serialized
    `Summary` protocol buffer. When `print_summary` is set, this is the
    output of the printing op wrapped around the summary.
  """
  full_name = _get_summary_name(tensor, name, prefix)

  if not print_summary:
    # No printing requested: let `summary.image` handle collections itself.
    return summary.image(name=full_name, tensor=tensor, collections=None)

  # An empty `collections` list keeps the non-printing op out of the summary
  # collection; the printing op below is registered in its place.
  silent_op = summary.image(name=full_name, tensor=tensor, collections=[])
  printing_op = logging_ops.Print(silent_op, [tensor], full_name)
  ops.add_to_collection(ops.GraphKeys.SUMMARIES, printing_op)
  return printing_op
103
104
def add_scalar_summary(tensor, name=None, prefix=None, print_summary=False):
  """Creates a scalar summary of `tensor`, optionally printing it.

  Args:
    tensor: a variable or op tensor.
    name: the optional name for the summary; when omitted the name of the
      tensor's op is used.
    prefix: An optional prefix for the summary name.
    print_summary: If `True`, the summary is printed to stdout when the summary
      is computed.

  Returns:
    A scalar `Tensor` of type `string` whose contents are the serialized
    `Summary` protocol buffer. When `print_summary` is set, this is the
    output of the printing op wrapped around the summary.
  """
  full_name = _get_summary_name(tensor, name, prefix)

  if not print_summary:
    # No printing requested: let `summary.scalar` handle collections itself.
    return summary.scalar(name=full_name, tensor=tensor, collections=None)

  # An empty `collections` list keeps the non-printing op out of the summary
  # collection; the printing op below is registered in its place.
  silent_op = summary.scalar(name=full_name, tensor=tensor, collections=[])
  printing_op = logging_ops.Print(silent_op, [tensor], full_name)
  ops.add_to_collection(ops.GraphKeys.SUMMARIES, printing_op)
  return printing_op
130
131
def add_zero_fraction_summary(tensor, name=None, prefix=None,
                              print_summary=False):
  """Creates a scalar summary of the fraction of zeros in `tensor`.

  The summary name is the usual `[prefix/]name` followed by
  `/Fraction_of_Zero_Values`.

  Args:
    tensor: a variable or op tensor.
    name: the optional name for the summary; when omitted the name of the
      tensor's op is used.
    prefix: An optional prefix for the summary name.
    print_summary: If `True`, the summary is printed to stdout when the summary
      is computed.

  Returns:
    A scalar `Tensor` of type `string` whose contents are the serialized
    `Summary` protocol buffer.
  """
  # Resolve the name from the original tensor before deriving the new op, so a
  # defaulted name refers to the input tensor's op rather than the
  # zero-fraction op.
  summary_name = _get_summary_name(
      tensor, name, prefix, 'Fraction_of_Zero_Values')
  zero_fraction = nn.zero_fraction(tensor)
  # The prefix is already part of `summary_name`, so it is not passed again.
  return add_scalar_summary(
      zero_fraction, summary_name, print_summary=print_summary)
150
151
def add_histogram_summaries(tensors, prefix=None):
  """Creates one histogram summary per tensor in `tensors`.

  Args:
    tensors: A list of variable or op tensors.
    prefix: An optional prefix applied to every summary name.

  Returns:
    A list of scalar `Tensors` of type `string` whose contents are the
    serialized `Summary` protocol buffer, in the same order as `tensors`.
  """
  return [add_histogram_summary(t, prefix=prefix) for t in tensors]
167
168
def add_image_summaries(tensors, prefix=None):
  """Creates one image summary per tensor in `tensors`.

  Args:
    tensors: A list of variable or op tensors.
    prefix: An optional prefix applied to every summary name.

  Returns:
    A list of scalar `Tensors` of type `string` whose contents are the
    serialized `Summary` protocol buffer, in the same order as `tensors`.
  """
  return [add_image_summary(t, prefix=prefix) for t in tensors]
184
185
def add_scalar_summaries(tensors, prefix=None, print_summary=False):
  """Creates one scalar summary per tensor in `tensors`.

  Args:
    tensors: a list of variable or op tensors.
    prefix: An optional prefix applied to every summary name.
    print_summary: If `True`, each summary is printed to stdout when it is
      computed.

  Returns:
    A list of scalar `Tensors` of type `string` whose contents are the
    serialized `Summary` protocol buffer, in the same order as `tensors`.
  """
  return [
      add_scalar_summary(t, prefix=prefix, print_summary=print_summary)
      for t in tensors
  ]
204
205
def add_zero_fraction_summaries(tensors, prefix=None):
  """Creates one zero-fraction scalar summary per tensor in `tensors`.

  Args:
    tensors: a list of variable or op tensors.
    prefix: An optional prefix applied to every summary name.

  Returns:
    A list of scalar `Tensors` of type `string` whose contents are the
    serialized `Summary` protocol buffer, in the same order as `tensors`.
  """
  return [add_zero_fraction_summary(t, prefix=prefix) for t in tensors]
221