1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h"
17 
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/container/flat_hash_set.h"
20 #include "absl/strings/str_format.h"
21 #include "absl/strings/str_split.h"
22 #include "absl/strings/string_view.h"
23 #include "tensorflow/core/lib/gtl/map_util.h"
24 #include "tensorflow/core/platform/protobuf.h"
25 #include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h"
26 #include "tensorflow/core/profiler/utils/group_events.h"
27 #include "tensorflow/core/profiler/utils/html_utils.h"
28 #include "tensorflow/core/profiler/utils/tf_op_utils.h"
29 #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
30 #include "tensorflow/core/profiler/utils/timespan.h"
31 #include "tensorflow/core/profiler/utils/xplane_schema.h"
32 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
33 
34 namespace tensorflow {
35 namespace profiler {
36 
37 // 50 us from https://www.tensorflow.org/guide/data_performance_analysis
38 const int64 kSlowCallThresholdPs = 50 * 1000000;
39 
40 namespace {
41 
42 // Returns true if the given iterator event is for a root iterator.
IsRootIteratorEvent(const XEventVisitor & iterator_event)43 bool IsRootIteratorEvent(const XEventVisitor& iterator_event) {
44   std::vector<absl::string_view> split_result =
45       absl::StrSplit(iterator_event.Name(), "::");
46   // The root iterator's name contains only its own name (no parent
47   // information).
48   return split_result.size() == 2;
49 }
50 
51 // Returns true if the given iterator event name is for an async iterator.
IsAsyncIterator(absl::string_view iterator_event_name)52 bool IsAsyncIterator(absl::string_view iterator_event_name) {
53   static auto* kAsyncIterators = new absl::flat_hash_set<absl::string_view>(
54       {"Prefetch", "ParallelInterleave", "ParallelMap", "ParseExample",
55        "MapAndBatch", "DataService", "LegacyParallelInterleave"});
56   return kAsyncIterators->contains(iterator_event_name);
57 }
58 
SetIteratorMetadata(int64 id,const XEventVisitor & event,IteratorMetadata * metadata)59 void SetIteratorMetadata(int64 id, const XEventVisitor& event,
60                          IteratorMetadata* metadata) {
61   metadata->set_id(id);
62   auto parent_id_stat = event.GetStat(StatType::kParentId);
63   if (parent_id_stat.has_value()) {
64     metadata->set_parent_id(parent_id_stat->IntValue());
65   }
66   metadata->set_name(IteratorName(event.Name()));
67   metadata->set_long_name(event.Name().data(), event.Name().size());
68   metadata->set_is_async(IsAsyncIterator(metadata->name()));
69   // TODO(b/161831651): Set params.
70 }
71 
72 // Returns the parent iterator's id if it is a root of a device input
73 // pipeline.
FindDeviceInputPipeline(const XEventVisitor & event)74 absl::optional<int64> FindDeviceInputPipeline(const XEventVisitor& event) {
75   if (event.Type() == HostEventType::kDeviceInputPipelineSecondIterator) {
76     auto parent_id_stat = event.GetStat(StatType::kParentId);
77     if (parent_id_stat.has_value()) return parent_id_stat->IntValue();
78   }
79   return absl::nullopt;
80 }
81 
82 // Processes EventForest to do the following:
83 // (1) set iterator metadata
84 // (2) find root iterator events
85 // (3) find device input pipeline ids
ProcessEventForest(const EventForest & event_forest,absl::flat_hash_set<int64> * device_input_pipeline_ids,absl::flat_hash_map<int64,std::vector<EventNode * >> * root_iterator_event_map,TfDataStats * tf_data_stats)86 void ProcessEventForest(const EventForest& event_forest,
87                         absl::flat_hash_set<int64>* device_input_pipeline_ids,
88                         absl::flat_hash_map<int64, std::vector<EventNode*>>*
89                             root_iterator_event_map,
90                         TfDataStats* tf_data_stats) {
91   const EventNodeMap& event_node_map = event_forest.GetEventNodeMap();
92   auto iterator_event_list =
93       gtl::FindOrNull(event_node_map, HostEventType::kIterator);
94   if (!iterator_event_list) return;
95   for (const auto& iterator_event : *iterator_event_list) {
96     const XEventVisitor& iterator_event_visitor =
97         iterator_event->GetEventVisitor();
98     auto iterator_id_stat = iterator_event_visitor.GetStat(StatType::kStepId);
99     if (!iterator_id_stat.has_value()) continue;
100     int64 iterator_id = iterator_id_stat->IntValue();
101     auto result = tf_data_stats->mutable_iterator_metadata()->insert(
102         {iterator_id, IteratorMetadata()});
103     IteratorMetadata& metadata = result.first->second;
104     if (result.second) {
105       // First time processing this iterator.
106       SetIteratorMetadata(iterator_id, iterator_event_visitor, &metadata);
107     }
108     if (IsRootIteratorEvent(iterator_event_visitor)) {
109       // Record root iterator events.
110       (*root_iterator_event_map)[iterator_id].push_back(iterator_event.get());
111     }
112   }
113   auto device_input_pipeline_second_iterator_events = gtl::FindOrNull(
114       event_node_map, HostEventType::kDeviceInputPipelineSecondIterator);
115   if (!device_input_pipeline_second_iterator_events) return;
116   for (const auto& iterator_event :
117        *device_input_pipeline_second_iterator_events) {
118     const XEventVisitor& iterator_event_visitor =
119         iterator_event->GetEventVisitor();
120     auto iterator_id_stat = iterator_event_visitor.GetStat(StatType::kStepId);
121     if (!iterator_id_stat.has_value()) continue;
122     int64 iterator_id = iterator_id_stat->IntValue();
123     auto result = tf_data_stats->mutable_iterator_metadata()->insert(
124         {iterator_id, IteratorMetadata()});
125     IteratorMetadata& metadata = result.first->second;
126     if (result.second) {
127       // First time processing this iterator.
128       SetIteratorMetadata(iterator_id, iterator_event_visitor, &metadata);
129       // Find and record device input pipeline ids.
130       absl::optional<int64> device_input_pipeline_id =
131           FindDeviceInputPipeline(iterator_event_visitor);
132       if (device_input_pipeline_id.has_value()) {
133         device_input_pipeline_ids->insert(*device_input_pipeline_id);
134       }
135     }
136   }
137 }
138 
SetInputPipelineMetadata(int64 id,int64 name_id,bool is_device_input_pipeline,InputPipelineMetadata * metadata)139 void SetInputPipelineMetadata(int64 id, int64 name_id,
140                               bool is_device_input_pipeline,
141                               InputPipelineMetadata* metadata) {
142   constexpr absl::string_view kHostInputPipelinePrefix = "Host:";
143   constexpr absl::string_view kDeviceInputPipelinePrefix = "Device:";
144   metadata->set_id(id);
145   if (is_device_input_pipeline) {
146     metadata->set_type(InputPipelineMetadata::DEVICE);
147     metadata->set_name(absl::StrCat(kDeviceInputPipelinePrefix, name_id));
148   } else {
149     metadata->set_type(InputPipelineMetadata::HOST);
150     metadata->set_name(absl::StrCat(kHostInputPipelinePrefix, name_id));
151   }
152 }
153 
ProcessIteratorEvent(const EventNode & iterator_event,InputPipelineStat * input_pipeline_stat,bool is_blocking)154 void ProcessIteratorEvent(const EventNode& iterator_event,
155                           InputPipelineStat* input_pipeline_stat,
156                           bool is_blocking) {
157   const XEventVisitor& visitor = iterator_event.GetEventVisitor();
158   auto iterator_id_stat = visitor.GetStat(StatType::kStepId);
159   if (!iterator_id_stat.has_value()) return;
160   int64 iterator_id = iterator_id_stat->IntValue();
161   auto result = input_pipeline_stat->mutable_iterator_stats()->insert(
162       {iterator_id, IteratorStat()});
163   IteratorStat& iterator_stat = result.first->second;
164   if (result.second) {
165     iterator_stat.set_id(iterator_id);
166     iterator_stat.set_start_time_ps(visitor.TimestampPs());
167   }
168   iterator_stat.set_duration_ps(iterator_stat.duration_ps() +
169                                 visitor.DurationPs());
170   int64 self_time_ps = visitor.DurationPs();
171   Timespan self_time_span = visitor.GetTimespan();
172   for (EventNode* child : iterator_event.GetChildren()) {
173     const XEventVisitor& child_visitor = child->GetEventVisitor();
174     if (ParseTfOpFullname(child_visitor.Name()).category == Category::kTfData) {
175       int64 overlap_duration_ps =
176           self_time_span.OverlappedDurationPs(child_visitor.GetTimespan());
177       ProcessIteratorEvent(*child, input_pipeline_stat,
178                            is_blocking && overlap_duration_ps);
179       // Note: Assume no overlap between child events.
180       self_time_ps -= overlap_duration_ps;
181     }
182   }
183   iterator_stat.set_self_time_ps(iterator_stat.self_time_ps() + self_time_ps);
184   iterator_stat.set_is_blocking(iterator_stat.is_blocking() || is_blocking);
185   iterator_stat.set_num_calls(iterator_stat.num_calls() + 1);
186 }
187 
SetBottleneckIteratorId(InputPipelineStat * input_pipeline_stat)188 void SetBottleneckIteratorId(InputPipelineStat* input_pipeline_stat) {
189   int64 bottleneck_iterator_id = 0;
190   int64 max_self_time = 0;
191   for (const auto& pair : input_pipeline_stat->iterator_stats()) {
192     const auto& id = pair.first;
193     const auto& iterator_stat = pair.second;
194     if (iterator_stat.is_blocking() &&
195         iterator_stat.self_time_ps() > max_self_time) {
196       bottleneck_iterator_id = id;
197       max_self_time = iterator_stat.self_time_ps();
198     }
199   }
200   input_pipeline_stat->set_bottleneck_iterator_id(bottleneck_iterator_id);
201   input_pipeline_stat->set_bottleneck_iterator_latency_ps(max_self_time);
202 }
203 
ProcessInputPipelines(const absl::flat_hash_set<int64> & device_input_pipeline_ids,absl::flat_hash_map<int64,std::vector<EventNode * >> * root_iterator_event_map,TfDataStats * tf_data_stats)204 void ProcessInputPipelines(
205     const absl::flat_hash_set<int64>& device_input_pipeline_ids,
206     absl::flat_hash_map<int64, std::vector<EventNode*>>*
207         root_iterator_event_map,
208     TfDataStats* tf_data_stats) {
209   auto* input_pipelines = tf_data_stats->mutable_input_pipelines();
210   int64 num_host_input_pipelines = 0;
211   int64 num_device_input_pipelines = 0;
212   for (auto& id_and_events : *root_iterator_event_map) {
213     auto& root_iterator_id = id_and_events.first;
214     auto& root_iterator_events = id_and_events.second;
215     absl::c_sort(root_iterator_events,
216                  [](const EventNode* lhs, const EventNode* rhs) {
217                    return lhs->GetEventVisitor().DurationPs() >
218                           rhs->GetEventVisitor().DurationPs();
219                  });
220     auto result =
221         input_pipelines->insert({root_iterator_id, InputPipelineStats()});
222     InputPipelineStats& input_pipeline_stats = result.first->second;
223     InputPipelineMetadata* metadata = input_pipeline_stats.mutable_metadata();
224     if (result.second) {
225       bool is_device_input_pipeline =
226           device_input_pipeline_ids.contains(root_iterator_id);
227       int64 name_id = is_device_input_pipeline ? num_device_input_pipelines++
228                                                : num_host_input_pipelines++;
229       SetInputPipelineMetadata(root_iterator_id, name_id,
230                                is_device_input_pipeline, metadata);
231     }
232     int64 sum_latency_ps = 0;
233     int64 min_latency_ps = INT64_MAX;
234     int64 max_latency_ps = 0;
235     int64 num_slow_calls = 0;
236     for (const EventNode* root_iterator_event : root_iterator_events) {
237       InputPipelineStat* stat = input_pipeline_stats.add_stats();
238       ProcessIteratorEvent(*root_iterator_event, stat,
239                            /*is_blocking*/ true);
240       SetBottleneckIteratorId(stat);
241       int64 latency_ps = root_iterator_event->GetEventVisitor().DurationPs();
242       sum_latency_ps += latency_ps;
243       min_latency_ps = std::min(min_latency_ps, latency_ps);
244       max_latency_ps = std::max(max_latency_ps, latency_ps);
245       if (latency_ps > kSlowCallThresholdPs) num_slow_calls++;
246     }
247     input_pipeline_stats.set_avg_latency_ps(sum_latency_ps /
248                                             root_iterator_events.size());
249     input_pipeline_stats.set_min_latency_ps(min_latency_ps);
250     input_pipeline_stats.set_max_latency_ps(max_latency_ps);
251     input_pipeline_stats.set_num_slow_calls(num_slow_calls);
252   }
253 }
254 
SetBottleneckAnalysis(CombinedTfDataStats * combined_tf_data_stats)255 void SetBottleneckAnalysis(CombinedTfDataStats* combined_tf_data_stats) {
256   struct InputPipeline {
257     InputPipeline(absl::string_view host_name,
258                   absl::string_view input_pipeline_name, int64 max_latency_ps,
259                   absl::string_view iterator_name,
260                   absl::string_view iterator_long_name,
261                   int64 iterator_latency_ps)
262         : host_name(host_name),
263           input_pipeline_name(input_pipeline_name),
264           max_latency_ps(max_latency_ps),
265           iterator_name(iterator_name),
266           iterator_long_name(iterator_long_name),
267           iterator_latency_ps(iterator_latency_ps) {}
268     absl::string_view host_name;
269     absl::string_view input_pipeline_name;
270     int64 max_latency_ps;
271     absl::string_view iterator_name;
272     absl::string_view iterator_long_name;
273     int64 iterator_latency_ps;
274 
275     bool operator<(const InputPipeline& rhs) const {
276       return max_latency_ps > rhs.max_latency_ps;
277     }
278   };
279   std::vector<InputPipeline> slow_input_pipelines;
280   for (const auto& host_name_and_tf_data_stats :
281        combined_tf_data_stats->tf_data_stats()) {
282     absl::string_view host_name = host_name_and_tf_data_stats.first;
283     const TfDataStats& tf_data_stats = host_name_and_tf_data_stats.second;
284     for (const auto& id_and_stats : tf_data_stats.input_pipelines()) {
285       const InputPipelineStats& input_pipeline_stats = id_and_stats.second;
286       if (input_pipeline_stats.metadata().type() ==
287           InputPipelineMetadata::DEVICE) {
288         // Ignore device input pipelines.
289         continue;
290       }
291       // Choose the slowest execution trace of the input pipeline.
292       // `input_pipeline_stats.stats` is already sorted so choose the first one.
293       const InputPipelineStat& input_pipeline_stat =
294           input_pipeline_stats.stats(0);
295       const IteratorMetadata& metadata = tf_data_stats.iterator_metadata().at(
296           input_pipeline_stat.bottleneck_iterator_id());
297       slow_input_pipelines.emplace_back(
298           host_name, input_pipeline_stats.metadata().name(),
299           input_pipeline_stats.max_latency_ps(), metadata.name(),
300           metadata.long_name(),
301           input_pipeline_stat.bottleneck_iterator_latency_ps());
302     }
303   }
304   std::sort(slow_input_pipelines.begin(), slow_input_pipelines.end());
305   for (const auto& input_pipeline : slow_input_pipelines) {
306     TfDataBottleneckAnalysis* bottleneck_analysis =
307         combined_tf_data_stats->add_bottleneck_analysis();
308     bottleneck_analysis->set_host(input_pipeline.host_name.data(),
309                                   input_pipeline.host_name.size());
310     bottleneck_analysis->set_input_pipeline(
311         input_pipeline.input_pipeline_name.data(),
312         input_pipeline.input_pipeline_name.size());
313     bottleneck_analysis->set_max_latency_ps(input_pipeline.max_latency_ps);
314     bottleneck_analysis->set_iterator_name(input_pipeline.iterator_name.data(),
315                                            input_pipeline.iterator_name.size());
316     bottleneck_analysis->set_iterator_long_name(
317         input_pipeline.iterator_long_name.data(),
318         input_pipeline.iterator_long_name.size());
319     bottleneck_analysis->set_iterator_latency_ps(
320         input_pipeline.iterator_latency_ps);
321   }
322 }
323 
GetSuggestion(BottleneckType type)324 std::string GetSuggestion(BottleneckType type) {
325   constexpr absl::string_view kPlaybookLink =
326       "https://www.tensorflow.org/guide/data_performance_analysis";
327   constexpr absl::string_view kPlaybookSourceDatasetLink =
328       "https://www.tensorflow.org/guide/"
329       "data_performance_analysis#source_datasets";
330   constexpr absl::string_view kPlaybookCpuUtilizationLink =
331       "https://www.tensorflow.org/guide/"
332       "data_performance_analysis#3_are_you_reaching_high_cpu_utilization";
333   constexpr absl::string_view kPlaybookTransformationLink =
334       "https://www.tensorflow.org/guide/"
335       "data_performance_analysis#transformation_datasets";
336   constexpr absl::string_view kTfGuideParallelDataExtractionLink =
337       "https://www.tensorflow.org/guide/"
338       "data_performance#parallelizing_data_extraction";
339   constexpr absl::string_view kTfGuideParallelTransformationLink =
340       "https://www.tensorflow.org/guide/"
341       "data_performance#parallelizing_data_transformation";
342   constexpr absl::string_view kTfGuideCacheLink =
343       "https://www.tensorflow.org/guide/data_performance#caching";
344   constexpr absl::string_view kTfDataServiceLink =
345       "https://www.tensorflow.org/api_docs/python/tf/data/experimental/"
346       "service?version=nightly";
347   switch (type) {
348     case BottleneckType::kSlowSource:
349       return absl::StrFormat(
350           "1. Check the locality of a host and input data. Ideally, they "
351           "should be in the same cell (or very close, like the same "
352           "region).<br/>"
353           "2. Parallelize reading from this dataset source. See %s and %s for "
354           "more details.<br/>",
355           AnchorElement(kPlaybookSourceDatasetLink, "here"),
356           AnchorElement(kTfGuideParallelDataExtractionLink, "here"));
357     case BottleneckType::kSlowDataService:
358       return absl::StrFormat(
359           "1. Fetching data from tf.data service took a while. Profile the "
360           "tf.data service worker to analyze the issue further.<br/>"
361           "2. See %s for more details on tf.data service.<br/>"
362           "3. See %s for other suggestions.",
363           AnchorElement(kTfDataServiceLink, "this"),
364           AnchorElement(kPlaybookLink, "this"));
365     case BottleneckType::kSlowRemoteSource:
366       return absl::StrFormat(
367           "1. The remote data source is slow. Profile its host to analyze the "
368           "issue further.<br/>"
369           "2. See %s for other suggestions.",
370           AnchorElement(kPlaybookLink, "this"));
371     case BottleneckType::kSlowTransformationWithParallelVersion:
372       return absl::StrFormat(
373           "1. Parallelize this transformation by setting "
374           "<code>num_parallel_calls=tf.data.experimental.AUTOTUNE</code>. See "
375           "%s for more details.<br/>"
376           "2. Consider adding <code>cache</code> after this transformation if "
377           "your data fits into memory and it is appropriate (e.g., there is no "
378           "randomness in upstream transformations like <code>shuffle</code>). "
379           "See %s for more details.<br/>"
380           "3. Find more resources %s.",
381           AnchorElement(kTfGuideParallelTransformationLink, "this"),
382           AnchorElement(kTfGuideCacheLink, "this"),
383           AnchorElement(kPlaybookTransformationLink, "here"));
384     case BottleneckType::kSlowTransformationWithoutParallelVersion:
385       return absl::StrFormat(
386           "1. This transformation is inherently sequential. Add outer "
387           "parallelism by running multiple copies of the input pipeline over "
388           "sharded inputs and combining the results. See %s for more "
389           "details.<br/>"
390           "2. Consider adding <code>cache</code> after this transformation if "
391           "your data fits into memory and it is appropriate (e.g., there is no "
392           "randomness in upstream transformations like <code>shuffle</code>). "
393           "See %s for more details.<br/>"
394           "3. Find more resources %s.",
395           AnchorElement(kPlaybookTransformationLink, "this"),
396           AnchorElement(kTfGuideCacheLink, "this"),
397           AnchorElement(kPlaybookCpuUtilizationLink, "here"));
398     default:
399       return absl::StrFormat("See %s for suggestions.",
400                              AnchorElement(kPlaybookLink, "this"));
401   }
402 }
403 
SetSuggestion(CombinedTfDataStats * combined_tf_data_stats)404 void SetSuggestion(CombinedTfDataStats* combined_tf_data_stats) {
405   for (TfDataBottleneckAnalysis& bottleneck_analysis :
406        *combined_tf_data_stats->mutable_bottleneck_analysis()) {
407     bottleneck_analysis.set_suggestion(
408         GetSuggestion(GetBottleneckType(bottleneck_analysis.iterator_name())));
409   }
410 }
411 
SetSummary(CombinedTfDataStats * combined_tf_data_stats)412 void SetSummary(CombinedTfDataStats* combined_tf_data_stats) {
413   int64 max_latency_ps = 0;
414   if (combined_tf_data_stats->bottleneck_analysis_size()) {
415     max_latency_ps =
416         combined_tf_data_stats->bottleneck_analysis().at(0).max_latency_ps();
417   }
418   if (max_latency_ps > kSlowCallThresholdPs) {
419     combined_tf_data_stats->set_is_input_bound(true);
420     combined_tf_data_stats->set_summary(
421         "Your profile has a tf.data input pipeline slower than 50 us. For each "
422         "slow input pipeline, below shows a bottleneck in the input pipeline "
423         "and a suggestion on how to fix it.");
424   } else if (max_latency_ps > 0) {
425     combined_tf_data_stats->set_is_input_bound(false);
426     combined_tf_data_stats->set_summary(
427         "Your profile does not have any tf.data input pipeline slower than 50 "
428         "us. Your job could be still input bound if this profile didn't "
429         "capture all workers.");
430   } else {
431     combined_tf_data_stats->set_is_input_bound(false);
432     combined_tf_data_stats->set_summary(
433         "No tf.data activitiy captured in your profile. If your job uses "
434         "tf.data, try to capture a longer profile.");
435   }
436 }
437 
438 }  // namespace
439 
GetBottleneckType(absl::string_view bottleneck_iterator_name)440 BottleneckType GetBottleneckType(absl::string_view bottleneck_iterator_name) {
441   static auto* kBottleneckTypeMap = new absl::flat_hash_map<absl::string_view,
442                                                             BottleneckType>(
443       {// Read from storage.
444        {"TFRecord", BottleneckType::kSlowSource},
445        {"SSTable", BottleneckType::kSlowSource},
446        {"RecordIO", BottleneckType::kSlowSource},
447        {"Spanner", BottleneckType::kSlowSource},
448        {"TFColumn", BottleneckType::kSlowSource},
449        {"SleepwalkRemoteDataset", BottleneckType::kSlowSource},
450        {"TextLine", BottleneckType::kSlowSource},
451        {"StitchedTimelineDataset", BottleneckType::kSlowSource},
452        {"DateKeyDataset", BottleneckType::kSlowSource},
453        {"CapacitorProto", BottleneckType::kSlowSource},
454        {"LMDB", BottleneckType::kSlowSource},
455        {"ExternalDataset", BottleneckType::kSlowSource},
456        {"PearModel", BottleneckType::kSlowSource},
457        {"FixedLengthRecordV2", BottleneckType::kSlowSource},
458        // Read from local memory.
459        {"FromTensor", BottleneckType::kSlowSource},
460        {"TensorSlice", BottleneckType::kSlowSource},
461        {"Generator", BottleneckType::kSlowSource},
462        {"SyntheticDatasetOp", BottleneckType::kSlowSource},
463        // tf.data service.
464        {"DataService", BottleneckType::kSlowDataService},
465        // Read from remote memory.
466        {"GuzzlerDataGuzzlerRemoteDataset", BottleneckType::kSlowRemoteSource},
467        {"ReverbDataset", BottleneckType::kSlowRemoteSource},
468        {"DatasetSampleGame", BottleneckType::kSlowRemoteSource},
469        {"Courier", BottleneckType::kSlowRemoteSource},
470        {"ReverbEpisodeDataset", BottleneckType::kSlowRemoteSource},
471        // Transformations with parallel version.
472        {"Map", BottleneckType::kSlowTransformationWithParallelVersion},
473        {"Interleave", BottleneckType::kSlowTransformationWithParallelVersion},
474        // Transformations without parallel version.
475        {"Filter", BottleneckType::kSlowTransformationWithoutParallelVersion},
476        {"Batch", BottleneckType::kSlowTransformationWithoutParallelVersion},
477        {"Unbatch", BottleneckType::kSlowTransformationWithoutParallelVersion}});
478   if (auto type =
479           gtl::FindOrNull(*kBottleneckTypeMap, bottleneck_iterator_name)) {
480     return *type;
481   }
482   return BottleneckType::kOther;
483 }
484 
Add(absl::string_view host_name,XPlane * host_plane)485 void CombinedTfDataStatsBuilder::Add(absl::string_view host_name,
486                                      XPlane* host_plane) {
487   TfDataStats& tf_data_stats =
488       (*combined_tf_data_stats_
489             ->mutable_tf_data_stats())[std::string(host_name)];
490   EventForest event_forest;
491   event_forest.AddPlanes(CreateTfXPlaneVisitor, {host_plane});
492   event_forest.ConnectEvents();
493   event_forest.ConnectTfDataEvents();
494   absl::flat_hash_set<int64> device_input_pipeline_ids;
495   absl::flat_hash_map<int64, std::vector<EventNode*>> root_iterator_event_map;
496   ProcessEventForest(event_forest, &device_input_pipeline_ids,
497                      &root_iterator_event_map, &tf_data_stats);
498   ProcessInputPipelines(device_input_pipeline_ids, &root_iterator_event_map,
499                         &tf_data_stats);
500 }
501 
Finalize()502 void CombinedTfDataStatsBuilder::Finalize() {
503   SetBottleneckAnalysis(combined_tf_data_stats_);
504   if (generate_suggestion_) SetSuggestion(combined_tf_data_stats_);
505   SetSummary(combined_tf_data_stats_);
506 }
507 
508 }  // namespace profiler
509 }  // namespace tensorflow
510