1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Operations to emit summaries."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import getpass
23import os
24import re
25import time
26
27import six
28
29from tensorflow.contrib.summary import gen_summary_ops
30from tensorflow.core.framework import graph_pb2
31from tensorflow.python.eager import context
32from tensorflow.python.framework import constant_op
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import ops
35from tensorflow.python.layers import utils
36from tensorflow.python.ops import array_ops
37from tensorflow.python.ops import control_flow_ops
38from tensorflow.python.ops import math_ops
39from tensorflow.python.ops import resource_variable_ops
40from tensorflow.python.ops import summary_op_util
41from tensorflow.python.platform import tf_logging as logging
42from tensorflow.python.training import training_util
43from tensorflow.python.util import tf_contextlib
44
45
# Collection expected to hold at most one flag (a boolean Tensor or a Python
# bool). When that flag is true, the summary ops in this module record.
_SHOULD_RECORD_SUMMARIES_NAME = "ShouldRecordSummaries"

# Collection that accumulates the initializer op of every summary writer
# created through _make_summary_writer().
_SUMMARY_WRITER_INIT_COLLECTION_NAME = "_SUMMARY_WRITER_V2"

# Validation patterns used by create_db_writer():
# - experiment / run names: no ASCII control characters and no '<' or '>',
#   with a maximum length of 256 and 512 characters respectively;
# - user names: 1-31 characters, starting with a letter, ending with a letter
#   or digit, with '-' allowed in between; matched case-insensitively.
_EXPERIMENT_NAME_PATTERNS = re.compile(r"^[^\x00-\x1F<>]{0,256}$")
_RUN_NAME_PATTERNS = re.compile(r"^[^\x00-\x1F<>]{0,512}$")
_USER_NAME_PATTERNS = re.compile(
    r"^[a-z]([-a-z0-9]{0,29}[a-z0-9])?$", flags=re.IGNORECASE)
55
56
def should_record_summaries():
  """Returns the flag that gates whether summary ops record anything.

  Returns:
    False when the should-record collection is empty. Otherwise returns its
    single entry: a boolean Tensor, or a Python bool, depending on which
    context manager installed it.

  Raises:
    ValueError: If the collection holds more than one entry.
  """
  flags = ops.get_collection(_SHOULD_RECORD_SUMMARIES_NAME)
  if len(flags) > 1:
    raise ValueError(
        "More than one tensor specified for whether summaries "
        "should be recorded: %s" % flags)
  return flags[0] if flags else False
67
68
# TODO(apassos) consider how to handle local step here.
@tf_contextlib.contextmanager
def record_summaries_every_n_global_steps(n, global_step=None):
  """Sets the should_record_summaries Tensor to true if global_step % n == 0.

  The previous contents of the should-record collection are restored when the
  `with` block exits, whether it exits normally or by raising.

  Args:
    n: Summaries are recorded on steps where `global_step % n == 0`.
    global_step: The step to test. Defaults to
      `training_util.get_or_create_global_step()`.

  Yields:
    Nothing; intended for use as a `with` statement.
  """
  if global_step is None:
    global_step = training_util.get_or_create_global_step()
  collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  old = collection_ref[:]
  with ops.device("cpu:0"):
    collection_ref[:] = [math_ops.equal(global_step % n, 0)]
  try:
    yield
  finally:
    # An exception raised in the `with` body is re-raised at the `yield`.
    # Without `finally`, the restore would be skipped, and this predicate
    # would stay installed for every later should_record_summaries() call.
    collection_ref[:] = old
81
82
@tf_contextlib.contextmanager
def always_record_summaries():
  """Sets the should_record_summaries Tensor to always true.

  The previous contents of the should-record collection are restored when the
  `with` block exits, whether it exits normally or by raising.

  Yields:
    Nothing; intended for use as a `with` statement.
  """
  collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  old = collection_ref[:]
  collection_ref[:] = [True]
  try:
    yield
  finally:
    # Restore even if the body raised; otherwise recording would stay forced
    # on after the `with` block ends.
    collection_ref[:] = old
91
92
@tf_contextlib.contextmanager
def never_record_summaries():
  """Sets the should_record_summaries Tensor to always false.

  The previous contents of the should-record collection are restored when the
  `with` block exits, whether it exits normally or by raising.

  Yields:
    Nothing; intended for use as a `with` statement.
  """
  collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  old = collection_ref[:]
  collection_ref[:] = [False]
  try:
    yield
  finally:
    # Restore even if the body raised; otherwise recording would stay
    # silently disabled after the `with` block ends.
    collection_ref[:] = old
101
102
class SummaryWriter(object):
  """Encapsulates a stateful summary writer resource.

  See also:
  - @{tf.contrib.summary.create_file_writer}
  - @{tf.contrib.summary.create_db_writer}
  """

  def __init__(self, resource):
    """Wraps a summary writer resource.

    Args:
      resource: The summary writer resource handle, or None. A None resource
        yields a writer whose `as_default` does nothing (this is what
        `create_file_writer(None)` returns).
    """
    self._resource = resource
    if context.in_eager_mode() and self._resource is not None:
      # In eager mode, attach a deleter for the handle. Presumably this
      # releases the resource when this object is collected; confirm against
      # resource_variable_ops.EagerResourceDeleter.
      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
          handle=self._resource, handle_device="cpu:0")

  def set_as_default(self):
    """Enables this summary writer for the current thread."""
    context.context().summary_writer_resource = self._resource

  @tf_contextlib.contextmanager
  def as_default(self):
    """Enables summary writing within a `with` block.

    On exit, the previously active default writer is reinstated, even if the
    body raised. The flush only runs when the body completes normally.

    Yields:
      This `SummaryWriter`.
    """
    if self._resource is None:
      yield self
    else:
      old = context.context().summary_writer_resource
      try:
        context.context().summary_writer_resource = self._resource
        yield self
        # Flushes the summary writer in eager mode or in graph functions, but
        # not in legacy graph mode (you're on your own there).
        with ops.device("cpu:0"):
          gen_summary_ops.flush_summary_writer(self._resource)
      finally:
        # Without this, an exception in the body would leave this writer
        # installed as the default for all subsequent summary ops.
        context.context().summary_writer_resource = old
135
136
def initialize(
    graph=None,  # pylint: disable=redefined-outer-name
    session=None):
  """Initializes summary writing for graph execution mode.

  This helper method provides a higher-level alternative to using
  @{tf.contrib.summary.summary_writer_initializer_op} and
  @{tf.contrib.summary.graph}. It does nothing when eager execution is
  enabled.

  Most users will also want to call @{tf.train.create_global_step}
  which can happen before or after this function is called.

  Args:
    graph: A @{tf.Graph} or @{tf.GraphDef} to output to the writer.
      This function will not write the default graph by default. When
      writing to an event log file, the associated step will be zero.
    session: So this method can call @{tf.Session.run}. This defaults
      to @{tf.get_default_session}.

  Raises:
    RuntimeError: If the current thread has no default
      @{tf.contrib.summary.SummaryWriter}.
    ValueError: If session wasn't passed and no default session exists.
  """
  if context.in_eager_mode():
    return
  if context.context().summary_writer_resource is None:
    raise RuntimeError("No default tf.contrib.summary.SummaryWriter found")
  sess = session if session is not None else ops.get_default_session()
  if sess is None:
    raise ValueError("session must be passed if no default session exists")
  sess.run(summary_writer_initializer_op())
  if graph is None:
    return
  # Serialize before building any new ops, so that when `graph` is the
  # default graph the placeholder and write op are not captured in it.
  serialized = _serialize_graph(graph)
  graph_placeholder = array_ops.placeholder(dtypes.string)
  sess.run(_graph(graph_placeholder, 0),
           feed_dict={graph_placeholder: serialized})
174
175
def create_file_writer(logdir,
                       max_queue=None,
                       flush_millis=None,
                       filename_suffix=None,
                       name=None):
  """Creates a summary file writer in the current context.

  Args:
    logdir: a string, or None. If a string, creates a summary file writer
     which writes to the directory named by the string. If None, returns
     a mock object which acts like a summary writer but does nothing,
     useful to use as a context manager.
    max_queue: the largest number of summaries to keep in a queue; will
     flush once the queue gets bigger than this. Defaults to 10.
    flush_millis: the largest interval between flushes, in milliseconds.
     Defaults to 2 minutes.
    filename_suffix: optional suffix for the event file name. Defaults to
     ".v2".
    name: Shared name for this SummaryWriter resource stored to default
      Graph.

  Returns:
    Either a summary writer or an empty object which can be used as a
    summary writer.
  """
  if logdir is None:
    return SummaryWriter(None)
  with ops.device("cpu:0"):
    # Keyword arguments are evaluated left to right, so the default constants
    # are created in the same order as before: queue, flush, suffix.
    writer_args = dict(
        logdir=logdir,
        max_queue=(constant_op.constant(10)
                   if max_queue is None else max_queue),
        flush_millis=(constant_op.constant(2 * 60 * 1000)
                      if flush_millis is None else flush_millis),
        filename_suffix=(constant_op.constant(".v2")
                         if filename_suffix is None else filename_suffix))
    return _make_summary_writer(
        name, gen_summary_ops.create_summary_file_writer, **writer_args)
215
216
def create_db_writer(db_uri,
                     experiment_name=None,
                     run_name=None,
                     user_name=None,
                     name=None):
  """Creates a summary database writer in the current context.

  This can be used to write tensors from the execution graph directly
  to a database. Only SQLite is supported right now. This function
  will create the schema if it doesn't exist. Entries in the Users,
  Experiments, and Runs tables will be created automatically if they
  don't already exist.

  Args:
    db_uri: For example "file:/tmp/foo.sqlite".
    experiment_name: Defaults to YYYY-MM-DD in local time if None.
      Empty string means the Run will not be associated with an
      Experiment. Can't contain ASCII control characters or <>. Case
      sensitive.
    run_name: Defaults to HH:MM:SS in local time if None. Empty string
      means a Tag will not be associated with any Run. Can't contain
      ASCII control characters or <>. Case sensitive.
    user_name: Defaults to system username if None. Must be valid as
      both a DNS label and Linux username. NOTE(review): the user-name
      validation pattern requires at least one character, so passing an
      empty string currently raises ValueError instead of leaving the
      Experiment unassociated. Confirm which behavior is intended.
    name: Shared name for this SummaryWriter resource stored to default
      @{tf.Graph}.

  Returns:
    A @{tf.contrib.summary.SummaryWriter} instance.

  Raises:
    ValueError: If `experiment_name`, `run_name` or `user_name` is a Python
      string that fails validation. This includes a default user name from
      `getpass.getuser()` that does not fit the user-name pattern.
  """
  with ops.device("cpu:0"):
    # Read the clock once so the default experiment (date) and run (time)
    # names describe the same instant. Two separate reads could straddle
    # midnight and pair one day's date with the next day's time.
    now = time.localtime()
    if experiment_name is None:
      experiment_name = time.strftime("%Y-%m-%d", now)
    if run_name is None:
      run_name = time.strftime("%H:%M:%S", now)
    if user_name is None:
      user_name = getpass.getuser()
    experiment_name = _cleanse_string(
        "experiment_name", _EXPERIMENT_NAME_PATTERNS, experiment_name)
    run_name = _cleanse_string("run_name", _RUN_NAME_PATTERNS, run_name)
    user_name = _cleanse_string("user_name", _USER_NAME_PATTERNS, user_name)
    return _make_summary_writer(
        name,
        gen_summary_ops.create_summary_db_writer,
        db_uri=db_uri,
        experiment_name=experiment_name,
        run_name=run_name,
        user_name=user_name)
266
267
def _make_summary_writer(name, factory, **kwargs):
  """Creates a writer resource and registers the op built by `factory`.

  Args:
    name: `shared_name` for the underlying summary writer resource.
    factory: A `gen_summary_ops` creation function. It is called with the new
      resource followed by `kwargs`.
    **kwargs: Forwarded unchanged to `factory`.

  Returns:
    A `SummaryWriter` wrapping the new resource.
  """
  resource = gen_summary_ops.summary_writer(shared_name=name)
  init_op = factory(resource, **kwargs)
  # The op is only collected here. Graph-mode callers run it later via
  # summary_writer_initializer_op() / initialize().
  # TODO(apassos): Consider running the node immediately instead, e.g.
  #   if not context.in_eager_mode(): ops.get_default_session().run(node)
  ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME, init_op)
  return SummaryWriter(resource)
277
278
def _cleanse_string(name, pattern, value):
  """Validates a name argument and converts it to a string Tensor.

  Args:
    name: Argument name; used only in the error message.
    pattern: Compiled regex that the whole of `value` must match.
    value: A Python string to validate. Any other value (e.g. a Tensor) is
      passed to `convert_to_tensor` without validation.

  Returns:
    The result of `ops.convert_to_tensor(value, dtypes.string)`.

  Raises:
    ValueError: If `value` is a Python string that `pattern` does not match
      in full.
  """
  if isinstance(value, six.string_types):
    match = pattern.search(value)
    # `$` also matches just before a trailing "\n". Without the span check,
    # a value such as "run\n" would pass even though the patterns forbid
    # control characters. Require the match to cover the entire string.
    if match is None or match.start() != 0 or match.end() != len(value):
      raise ValueError("%s (%s) must match %s" % (name, value, pattern.pattern))
  return ops.convert_to_tensor(value, dtypes.string)
283
284
def _nothing():
  """Returns a constant False Tensor.

  This is the no-record branch handed to smart_cond.
  """
  did_record = False
  return constant_op.constant(did_record)
288
289
def all_summary_ops():
  """Graph-mode only. Returns all summary ops.

  Please note this excludes @{tf.contrib.summary.graph} ops.

  Returns:
    The summary ops collected in the default graph, or None when eager
    execution is enabled.
  """
  if not context.in_eager_mode():
    return ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
  return None
301
302
def summary_writer_initializer_op():
  """Graph-mode only. Returns the list of ops to create all summary writers.

  Returns:
    The initializer ops registered by every writer created so far.

  Raises:
    RuntimeError: If eager execution is enabled.
  """
  if not context.in_eager_mode():
    return ops.get_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME)
  raise RuntimeError(
      "tf.contrib.summary.summary_writer_initializer_op is only "
      "supported in graph mode.")
317
318
def summary_writer_function(name, tensor, function, family=None):
  """Helper function to write summaries.

  Args:
    name: name of the summary
    tensor: main tensor to form the summary
    function: function taking a tag and a scope which writes the summary
    family: optional, the summary's family

  Returns:
    A no-op if there is no default summary writer. Otherwise, the result of a
    smart_cond on should_record_summaries(), built on cpu:0 and added to the
    summary collection.
  """
  if context.context().summary_writer_resource is None:
    return control_flow_ops.no_op()

  def _record():
    with summary_op_util.summary_scope(
        name, family, values=[tensor]) as (tag, scope):
      # Build the write op before entering the control-dependency scope,
      # then return a True constant that depends on it.
      write_op = function(tag, scope)
      with ops.control_dependencies([write_op]):
        return constant_op.constant(True)

  with ops.device("cpu:0"):
    cond_op = utils.smart_cond(
        should_record_summaries(), _record, _nothing, name="")
    ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, cond_op)  # pylint: disable=protected-access
  return cond_op
344
345
def generic(name, tensor, metadata=None, family=None, step=None):
  """Writes a tensor summary if possible.

  Args:
    name: Summary name.
    tensor: The tensor to write.
    metadata: None (empty metadata), a proto with `SerializeToString`, or a
      value passed through unchanged as the serialized metadata.
    family: Optional, the summary's family.
    step: Optional step; see `_choose_step`.

  Returns:
    The result of `summary_writer_function`.
  """

  def _write_tensor(tag, scope):
    if metadata is None:
      metadata_tensor = constant_op.constant("")
    elif not hasattr(metadata, "SerializeToString"):
      # Presumably already serialized (e.g. a string Tensor); used as-is.
      metadata_tensor = metadata
    else:
      metadata_tensor = constant_op.constant(metadata.SerializeToString())
    step_tensor = _choose_step(step)
    # identity() copies `tensor` under the cpu:0 scope that
    # summary_writer_function opens around this callback.
    value = array_ops.identity(tensor)
    return gen_summary_ops.write_summary(
        context.context().summary_writer_resource,
        step_tensor,
        value,
        tag,
        metadata_tensor,
        name=scope)

  return summary_writer_function(name, tensor, _write_tensor, family=family)
365
366
def scalar(name, tensor, family=None, step=None):
  """Writes a scalar summary if possible.

  Unlike @{tf.contrib.summary.generic} this op may change the dtype
  depending on the writer, for both practical and efficiency concerns.

  Args:
    name: An arbitrary name for this summary.
    tensor: A @{tf.Tensor} Must be one of the following types:
      `float32`, `float64`, `int32`, `int64`, `uint8`, `int16`,
      `int8`, `uint16`, `half`, `uint32`, `uint64`.
    family: Optional, the summary's family.
    step: The `int64` monotonic step variable, which defaults
      to @{tf.train.get_global_step}.

  Returns:
    The created @{tf.Operation} or a @{tf.no_op} if summary writing has
    not been enabled for this context.
  """

  def _write_scalar(tag, scope):
    step_tensor = _choose_step(step)
    # identity() copies `tensor` under the surrounding cpu:0 device scope.
    value = array_ops.identity(tensor)
    return gen_summary_ops.write_scalar_summary(
        context.context().summary_writer_resource,
        step_tensor,
        tag,
        value,
        name=scope)

  return summary_writer_function(name, tensor, _write_scalar, family=family)
397
398
def histogram(name, tensor, family=None, step=None):
  """Writes a histogram summary if possible.

  Args:
    name: Summary name.
    tensor: Values to summarize as a histogram.
    family: Optional, the summary's family.
    step: Optional step; see `_choose_step`.

  Returns:
    The result of `summary_writer_function`.
  """

  def _write_histogram(tag, scope):
    step_tensor = _choose_step(step)
    # identity() copies `tensor` under the surrounding cpu:0 device scope.
    values = array_ops.identity(tensor)
    return gen_summary_ops.write_histogram_summary(
        context.context().summary_writer_resource,
        step_tensor,
        tag,
        values,
        name=scope)

  return summary_writer_function(name, tensor, _write_histogram, family=family)
412
413
def image(name, tensor, bad_color=None, max_images=3, family=None, step=None):
  """Writes an image summary if possible.

  Args:
    name: Summary name.
    tensor: The image data to write.
    bad_color: Color passed to the writer op for bad pixels. Defaults to the
      uint8 constant [255, 0, 0, 255].
    max_images: Maximum number of images to write (default 3).
    family: Optional, the summary's family.
    step: Optional step; see `_choose_step`.

  Returns:
    The result of `summary_writer_function`.
  """

  def _write_image(tag, scope):
    if bad_color is None:
      color = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8)
    else:
      color = bad_color
    step_tensor = _choose_step(step)
    # identity() copies `tensor` under the surrounding cpu:0 device scope.
    pixels = array_ops.identity(tensor)
    return gen_summary_ops.write_image_summary(
        context.context().summary_writer_resource,
        step_tensor,
        tag,
        pixels,
        color,
        max_images,
        name=scope)

  return summary_writer_function(name, tensor, _write_image, family=family)
431
432
def audio(name, tensor, sample_rate, max_outputs, family=None, step=None):
  """Writes an audio summary if possible.

  Args:
    name: Summary name.
    tensor: The audio data to write.
    sample_rate: Forwarded to the writer op as `sample_rate`.
    max_outputs: Forwarded to the writer op as `max_outputs`.
    family: Optional, the summary's family.
    step: Optional step; see `_choose_step`.

  Returns:
    The result of `summary_writer_function`.
  """

  def _write_audio(tag, scope):
    step_tensor = _choose_step(step)
    # identity() copies `tensor` under the surrounding cpu:0 device scope.
    samples = array_ops.identity(tensor)
    return gen_summary_ops.write_audio_summary(
        context.context().summary_writer_resource,
        step_tensor,
        tag,
        samples,
        sample_rate=sample_rate,
        max_outputs=max_outputs,
        name=scope)

  return summary_writer_function(name, tensor, _write_audio, family=family)
448
449
def graph(param, step=None, name=None):
  """Writes a TensorFlow graph to the summary interface.

  The graph summary is, strictly speaking, not a summary. Conditions
  like @{tf.contrib.summary.never_record_summaries} do not apply. Only
  a single graph can be associated with a particular run. If multiple
  graphs are written, then only the last one will be considered by
  TensorBoard.

  When not using eager execution mode, the user should consider passing
  the `graph` parameter to @{tf.contrib.summary.initialize} instead of
  calling this function. Otherwise special care needs to be taken when
  using the graph to record the graph.

  Args:
    param: A @{tf.Tensor} containing a serialized graph proto. When
      eager execution is enabled, this function will automatically
      coerce @{tf.Graph}, @{tf.GraphDef}, and string types.
    step: The global step variable. This doesn't have useful semantics
      for graph summaries, but is used anyway, due to the structure of
      event log files. This defaults to the global step.
    name: A name for the operation (optional).

  Returns:
    The created @{tf.Operation} or a @{tf.no_op} if summary writing has
    not been enabled for this context.

  Raises:
    TypeError: If `param` isn't already a @{tf.Tensor} in graph mode.
  """
  in_graph_mode = not context.in_eager_mode()
  if in_graph_mode and not isinstance(param, ops.Tensor):
    raise TypeError("graph() needs a tf.Tensor (e.g. tf.placeholder) in graph "
                    "mode, but was: %s" % type(param))
  writer = context.context().summary_writer_resource
  if writer is None:
    return control_flow_ops.no_op()
  with ops.device("cpu:0"):
    if not isinstance(param, (ops.Graph, graph_pb2.GraphDef)):
      serialized = array_ops.identity(param)
    else:
      serialized = ops.convert_to_tensor(_serialize_graph(param),
                                         dtypes.string)
    return gen_summary_ops.write_graph_summary(
        writer, _choose_step(step), serialized, name=name)


# Alias for functions (such as `initialize`) whose own `graph` parameter
# shadows the function above.
_graph = graph
496
497
def import_event(tensor, name=None):
  """Writes a @{tf.Event} binary proto.

  When using create_db_writer(), this can be used alongside
  @{tf.TFRecordReader} to load event logs into the database. Please
  note that this is lower level than the other summary functions and
  will ignore any conditions set by methods like
  @{tf.contrib.summary.should_record_summaries}.

  Args:
    tensor: A @{tf.Tensor} of type `string` containing a serialized
      @{tf.Event} proto.
    name: A name for the operation (optional).

  Returns:
    The created @{tf.Operation}.
  """
  # Unlike the other writers, there is no None check here; the current
  # default writer resource is passed straight to the op.
  writer = context.context().summary_writer_resource
  return gen_summary_ops.import_event(writer, tensor, name=name)
517
518
def flush(writer=None, name=None):
  """Forces summary writer to send any buffered data to storage.

  This operation blocks until that finishes.

  Args:
    writer: The @{tf.contrib.summary.SummaryWriter}, or its underlying
      resource, to flush. If None, the thread's default writer resource is
      used.
    name: A name for the operation (optional).

  Returns:
    The created @{tf.Operation}. A @{tf.no_op} is returned instead when there
    is no writer to flush: no default writer when `writer` is None, or a
    no-op `SummaryWriter` such as `create_file_writer(None)` returns.
  """
  if writer is None:
    writer = context.context().summary_writer_resource
  elif isinstance(writer, SummaryWriter):
    # The docstring has always advertised SummaryWriter objects, but the op
    # needs the raw resource handle.
    writer = writer._resource  # pylint: disable=protected-access
  if writer is None:
    return control_flow_ops.no_op()
  # Same placement as the flush performed by SummaryWriter.as_default().
  with ops.device("cpu:0"):
    return gen_summary_ops.flush_summary_writer(writer, name=name)
538
539
def eval_dir(model_dir, name=None):
  """Construct a logdir for an eval summary writer.

  Args:
    model_dir: Base model directory.
    name: Optional eval name. When truthy, the subdirectory is
      "eval_<name>"; otherwise it is "eval".

  Returns:
    The path `model_dir/<subdirectory>`.
  """
  subdir = "eval_" + name if name else "eval"
  return os.path.join(model_dir, subdir)
543
544
def create_summary_file_writer(*args, **kwargs):
  """Please use @{tf.contrib.summary.create_file_writer}.

  Deprecated alias. It logs a warning and forwards all arguments unchanged
  to `create_file_writer`.
  """
  message = ("Deprecation Warning: create_summary_file_writer was renamed "
             "to create_file_writer")
  logging.warning(message)
  return create_file_writer(*args, **kwargs)
550
551
def _serialize_graph(arbitrary_graph):
  """Serializes a `tf.Graph` (with shapes) or any object with SerializeToString.

  Args:
    arbitrary_graph: A `tf.Graph`, or a proto such as `GraphDef`.

  Returns:
    The serialized bytes.
  """
  if not isinstance(arbitrary_graph, ops.Graph):
    return arbitrary_graph.SerializeToString()
  graph_def = arbitrary_graph.as_graph_def(add_shapes=True)
  return graph_def.SerializeToString()
557
558
def _choose_step(step):
  """Resolves the step argument used by the summary writers.

  Args:
    step: None, an existing Tensor, or a value convertible to an int64 Tensor.

  Returns:
    The global step when `step` is None, `step` itself when it is already a
    Tensor, and otherwise `step` converted to an int64 Tensor.
  """
  if step is None:
    return training_util.get_or_create_global_step()
  if isinstance(step, ops.Tensor):
    return step
  return ops.convert_to_tensor(step, dtypes.int64)
565