1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""StatsAggregator for aggregating statistics from `tf.data` pipelines."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import tempfile
21
22from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
23from tensorflow.python.ops import summary_ops_v2
24from tensorflow.python.util import deprecation
25from tensorflow.python.util.tf_export import tf_export
26
27
# Maximum number of pending summary events the file writer buffers before
# flushing to disk (forwarded to `summary_ops_v2.create_file_writer_v2`).
_DEFAULT_MAX_QUEUE = 10
29
30
@tf_export("data.experimental.StatsAggregator", v1=[])
@deprecation.deprecated_endpoints("data.experimental.StatsAggregator")
class StatsAggregatorV2(object):
  """A stateful resource that aggregates statistics from one or more iterators.

  Statistics are recorded by applying one of the custom transformation
  functions from this module to a `tf.data.Dataset`; every statistic is then
  collected by the `StatsAggregator` associated with that dataset's iterator
  (see below). For example, the latency of producing each element can be
  recorded with:

  ```python
  dataset = ...
  dataset = dataset.apply(tf.data.experimental.latency_stats("total_bytes"))
  ```

  A `StatsAggregator` is associated with a `tf.data.Dataset` through
  `tf.data.Options`:

  ```python
  aggregator = tf.data.experimental.StatsAggregator()
  dataset = ...

  # Apply `StatsOptions` to associate `dataset` with `aggregator`.
  options = tf.data.Options()
  options.experimental_stats.aggregator = aggregator
  dataset = dataset.with_options(options)
  ```

  Note: This interface is experimental and expected to change. In particular,
  we expect to add other implementations of `StatsAggregator` that provide
  different ways of exporting statistics, and add more types of statistics.
  """

  # The deprecation decorator is attached to __init__ (rather than only the
  # class) so that instantiating the class actually prints the warning.
  @deprecation.deprecated(
      None,
      "Use TF Profiler to analyze performance instead."
  )
  def __init__(self):
    """Creates the aggregator resource and wires it to a summary writer."""
    self._resource = ged_ops.stats_aggregator_handle_v2()
    # Multiple file writers sharing one logdir can conflict (b/37351340).
    # Until that bug is resolved we write the stats summaries into a private
    # temporary directory. Workarounds that were considered: (a) a dedicated
    # stats-specific file inside the user's log_dir, or (b) reusing the
    # default summary writer -- (b) is insufficient because writers that are
    # not installed as the default (e.g. one created by a Keras callback)
    # would still collide.
    # Creating the summary writer here could later be replaced by looking up
    # an existing default writer, or by a public method that lets callers
    # associate their own writer.
    self._logdir = tempfile.mkdtemp()
    self._summary_writer = summary_ops_v2.create_file_writer_v2(
        self._logdir, max_queue=_DEFAULT_MAX_QUEUE)
    ged_ops.stats_aggregator_set_summary_writer(
        self._resource,
        self._summary_writer._resource)  # pylint: disable=protected-access
87
88
@tf_export(v1=["data.experimental.StatsAggregator"])
@deprecation.deprecated_endpoints("data.experimental.StatsAggregator")
class StatsAggregatorV1(object):
  """A stateful resource that aggregates statistics from one or more iterators.

  Statistics are recorded by applying one of the custom transformation
  functions from this module to a `tf.data.Dataset`; every statistic is then
  collected by the `StatsAggregator` associated with that dataset's iterator
  (see below). For example, the latency of producing each element can be
  recorded with:

  ```python
  dataset = ...
  dataset = dataset.apply(tf.data.experimental.latency_stats("total_bytes"))
  ```

  A `StatsAggregator` is associated with a `tf.data.Dataset` through
  `tf.data.Options`:

  ```python
  aggregator = tf.data.experimental.StatsAggregator()
  dataset = ...

  # Apply `StatsOptions` to associate `dataset` with `aggregator`.
  options = tf.data.Options()
  options.experimental_stats.aggregator = aggregator
  dataset = dataset.with_options(options)
  ```

  A protocol buffer summary of the statistics aggregated so far is available
  from the `StatsAggregator.get_summary()` tensor. The easiest way to consume
  it is to add the returned tensor to the `tf.GraphKeys.SUMMARIES` collection,
  so that it is exported alongside any existing summaries:

  ```python
  aggregator = tf.data.experimental.StatsAggregator()
  # ...
  stats_summary = aggregator.get_summary()
  tf.compat.v1.add_to_collection(tf.GraphKeys.SUMMARIES, stats_summary)
  ```

  Note: This interface is experimental and expected to change. In particular,
  we expect to add other implementations of `StatsAggregator` that provide
  different ways of exporting statistics, and add more types of statistics.
  """

  # The deprecation decorator is attached to __init__ (rather than only the
  # class) so that instantiating the class actually prints the warning.
  @deprecation.deprecated(
      None,
      "Use TF Profiler to analyze performance instead."
  )
  def __init__(self):
    """Creates a `StatsAggregator`."""
    self._resource = ged_ops.stats_aggregator_handle()

  def get_summary(self):
    """Returns a string `tf.Tensor` that summarizes the aggregated statistics.

    The returned tensor contains a serialized `tf.compat.v1.summary.Summary`
    protocol buffer, suitable for use with the standard TensorBoard logging
    facilities.

    Returns:
      A scalar string `tf.Tensor` that summarizes the aggregated statistics.
    """
    summary = ged_ops.stats_aggregator_summary(self._resource)
    return summary
156
157
# Public alias: the V1 implementation is still the default export.
# TODO(b/116314787): Change this to StatsAggregatorV2 when we have stable
# SummaryWriterInterface, and do not break any users.
StatsAggregator = StatsAggregatorV1
161