# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""StatsAggregator for aggregating statistics from `tf.data` pipelines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tempfile

from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


# Maximum number of pending summary events the file writer created in
# `StatsAggregatorV2.__init__` buffers before flushing them to disk.
_DEFAULT_MAX_QUEUE = 10


@tf_export("data.experimental.StatsAggregator", v1=[])
@deprecation.deprecated_endpoints("data.experimental.StatsAggregator")
class StatsAggregatorV2(object):
  """A stateful resource that aggregates statistics from one or more iterators.

  To record statistics, use one of the custom transformation functions defined
  in this module when defining your `tf.data.Dataset`. All statistics will be
  aggregated by the `StatsAggregator` that is associated with a particular
  iterator (see below). For example, to record the latency of producing each
  element by iterating over a dataset:

  ```python
  dataset = ...
  dataset = dataset.apply(tf.data.experimental.latency_stats("record_latency"))
  ```

  To associate a `StatsAggregator` with a `tf.data.Dataset` object, use
  the following pattern:

  ```python
  aggregator = tf.data.experimental.StatsAggregator()
  dataset = ...

  # Apply `StatsOptions` to associate `dataset` with `aggregator`.
  options = tf.data.Options()
  options.experimental_stats.aggregator = aggregator
  dataset = dataset.with_options(options)
  ```

  Note: This interface is experimental and expected to change. In particular,
  we expect to add other implementations of `StatsAggregator` that provide
  different ways of exporting statistics, and add more types of statistics.
  """

  # This deprecation warning on __init__ is necessary to print deprecation
  # messages.
  @deprecation.deprecated(
      None,
      "Use TF Profiler to analyze performance instead."
  )
  def __init__(self):
    """Creates a `StatsAggregator` and wires it to a fresh summary writer."""
    self._resource = ged_ops.stats_aggregator_handle_v2()
    # There could be a conflict between multiple file writers in the same
    # logdir (b/37351340). Possible workarounds until that bug is resolved are
    # a) keeping dataset-stats-specific files inside log_dir, or b) using the
    # default summary writer. Option b) does not fully solve the problem,
    # because there might be summary writers in the log dir that are not set
    # as the default, e.g. in a Keras callback.
    # Creating a summary_writer here could potentially be replaced with getting
    # the default summary_writer if any, creating it otherwise, or a public
    # method to associate a summary writer.
    # NOTE(review): the temporary directory created here is never removed —
    # presumably acceptable for this experimental, deprecated API; confirm if
    # that ever changes.
    self._logdir = tempfile.mkdtemp()
    self._summary_writer = summary_ops_v2.create_file_writer_v2(
        self._logdir, max_queue=_DEFAULT_MAX_QUEUE)
    # Route the aggregated statistics through the writer so they are exported
    # as TF-2 summaries.
    ged_ops.stats_aggregator_set_summary_writer(self._resource,
                                                self._summary_writer._resource)  # pylint: disable=protected-access


@tf_export(v1=["data.experimental.StatsAggregator"])
@deprecation.deprecated_endpoints("data.experimental.StatsAggregator")
class StatsAggregatorV1(object):
  """A stateful resource that aggregates statistics from one or more iterators.

  To record statistics, use one of the custom transformation functions defined
  in this module when defining your `tf.data.Dataset`. All statistics will be
  aggregated by the `StatsAggregator` that is associated with a particular
  iterator (see below). For example, to record the latency of producing each
  element by iterating over a dataset:

  ```python
  dataset = ...
  dataset = dataset.apply(tf.data.experimental.latency_stats("record_latency"))
  ```

  To associate a `StatsAggregator` with a `tf.data.Dataset` object, use
  the following pattern:

  ```python
  aggregator = tf.data.experimental.StatsAggregator()
  dataset = ...

  # Apply `StatsOptions` to associate `dataset` with `aggregator`.
  options = tf.data.Options()
  options.experimental_stats.aggregator = aggregator
  dataset = dataset.with_options(options)
  ```

  To get a protocol buffer summary of the currently aggregated statistics,
  use the `StatsAggregator.get_summary()` tensor. The easiest way to do this
  is to add the returned tensor to the `tf.GraphKeys.SUMMARIES` collection,
  so that the summaries will be included with any existing summaries.

  ```python
  aggregator = tf.data.experimental.StatsAggregator()
  # ...
  stats_summary = aggregator.get_summary()
  tf.compat.v1.add_to_collection(tf.GraphKeys.SUMMARIES, stats_summary)
  ```

  Note: This interface is experimental and expected to change. In particular,
  we expect to add other implementations of `StatsAggregator` that provide
  different ways of exporting statistics, and add more types of statistics.
  """

  # This deprecation warning on __init__ is necessary to print deprecation
  # messages.
  @deprecation.deprecated(
      None,
      "Use TF Profiler to analyze performance instead."
  )
  def __init__(self):
    """Creates a `StatsAggregator`."""
    self._resource = ged_ops.stats_aggregator_handle()

  def get_summary(self):
    """Returns a string `tf.Tensor` that summarizes the aggregated statistics.

    The returned tensor will contain a serialized
    `tf.compat.v1.summary.Summary` protocol buffer, which can be used with the
    standard TensorBoard logging facilities.

    Returns:
      A scalar string `tf.Tensor` that summarizes the aggregated statistics.
    """
    return ged_ops.stats_aggregator_summary(self._resource)


# TODO(b/116314787): Change this to StatsAggregatorV2 when we have stable
# SummaryWriterInterface, and do not break any users.
StatsAggregator = StatsAggregatorV1