/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h"

#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h"
#include "tensorflow/core/profiler/utils/group_events.h"
#include "tensorflow/core/profiler/utils/html_utils.h"
#include "tensorflow/core/profiler/utils/tf_op_utils.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/timespan.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"

namespace tensorflow {
namespace profiler {

// Threshold for a slow input pipeline call: 50 us (in picoseconds), from
// https://www.tensorflow.org/guide/data_performance_analysis.
const int64 kSlowCallThresholdPs = 50 * 1000000;

namespace {

// Returns true if the given iterator event is for a root iterator.
bool IsRootIteratorEvent(const XEventVisitor& iterator_event) {
  std::vector<absl::string_view> split_result =
      absl::StrSplit(iterator_event.Name(), "::");
  // The root iterator's name contains only a common prefix and its own name
  // (no parent information), so splitting on "::" yields exactly two parts.
  return split_result.size() == 2;
}

// Returns true if the given iterator event name is for an async iterator.
bool IsAsyncIterator(absl::string_view iterator_event_name) {
  static auto* kAsyncIterators = new absl::flat_hash_set<absl::string_view>(
      {"Prefetch", "ParallelInterleave", "ParallelMap", "ParseExample",
       "MapAndBatch", "DataService", "LegacyParallelInterleave"});
  return kAsyncIterators->contains(iterator_event_name);
}

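// Populates `metadata` for the iterator with the given `id` from `event`:
// parent id (if any), name, long name, and whether it is asynchronous.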
void SetIteratorMetadata(int64 id, const XEventVisitor& event,
                         IteratorMetadata* metadata) {
  metadata->set_id(id);
  auto parent_id_stat = event.GetStat(StatType::kParentId);
  if (parent_id_stat.has_value()) {
    metadata->set_parent_id(parent_id_stat->IntValue());
  }
  metadata->set_name(IteratorName(event.Name()));
  metadata->set_long_name(event.Name().data(), event.Name().size());
  metadata->set_is_async(IsAsyncIterator(metadata->name()));
  // TODO(b/161831651): Set params.
}

// Returns the id of the parent iterator if the parent is the root of a device
// input pipeline.
absl::optional<int64> FindDeviceInputPipeline(const XEventVisitor& event) {
  if (event.Type() == HostEventType::kDeviceInputPipelineSecondIterator) {
    auto parent_id_stat = event.GetStat(StatType::kParentId);
    if (parent_id_stat.has_value()) return parent_id_stat->IntValue();
  }
  return absl::nullopt;
}

// Processes EventForest to do the following:
// (1) set iterator metadata
// (2) find root iterator events
// (3) find device input pipeline ids
void ProcessEventForest(const EventForest& event_forest,
                        absl::flat_hash_set<int64>* device_input_pipeline_ids,
                        absl::flat_hash_map<int64, std::vector<EventNode*>>*
                            root_iterator_event_map,
                        TfDataStats* tf_data_stats) {
  const EventNodeMap& event_node_map = event_forest.GetEventNodeMap();
  auto iterator_event_list =
      gtl::FindOrNull(event_node_map, HostEventType::kIterator);
  if (!iterator_event_list) return;
  for (const auto& iterator_event : *iterator_event_list) {
    const XEventVisitor& iterator_event_visitor =
        iterator_event->GetEventVisitor();
    auto iterator_id_stat = iterator_event_visitor.GetStat(StatType::kStepId);
    if (!iterator_id_stat.has_value()) continue;
    int64 iterator_id = iterator_id_stat->IntValue();
    auto result = tf_data_stats->mutable_iterator_metadata()->insert(
        {iterator_id, IteratorMetadata()});
    IteratorMetadata& metadata = result.first->second;
    if (result.second) {
      // First time processing this iterator.
      SetIteratorMetadata(iterator_id, iterator_event_visitor, &metadata);
    }
    if (IsRootIteratorEvent(iterator_event_visitor)) {
      // Record root iterator events.
      (*root_iterator_event_map)[iterator_id].push_back(iterator_event.get());
    }
  }
  auto device_input_pipeline_second_iterator_events = gtl::FindOrNull(
      event_node_map, HostEventType::kDeviceInputPipelineSecondIterator);
  if (!device_input_pipeline_second_iterator_events) return;
  for (const auto& iterator_event :
       *device_input_pipeline_second_iterator_events) {
    const XEventVisitor& iterator_event_visitor =
        iterator_event->GetEventVisitor();
    auto iterator_id_stat = iterator_event_visitor.GetStat(StatType::kStepId);
    if (!iterator_id_stat.has_value()) continue;
    int64 iterator_id = iterator_id_stat->IntValue();
    auto result = tf_data_stats->mutable_iterator_metadata()->insert(
        {iterator_id, IteratorMetadata()});
    IteratorMetadata& metadata = result.first->second;
    if (result.second) {
      // First time processing this iterator.
      SetIteratorMetadata(iterator_id, iterator_event_visitor, &metadata);
      // Find and record device input pipeline ids.
      absl::optional<int64> device_input_pipeline_id =
          FindDeviceInputPipeline(iterator_event_visitor);
      if (device_input_pipeline_id.has_value()) {
        device_input_pipeline_ids->insert(*device_input_pipeline_id);
      }
    }
  }
}

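// Fills the id, type (host or device), and display name of an input pipeline.
// `name_id` is a per-type counter used to build the display name, e.g.
// "Host:0" or "Device:1".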
void SetInputPipelineMetadata(int64 id, int64 name_id,
                              bool is_device_input_pipeline,
                              InputPipelineMetadata* metadata) {
  constexpr absl::string_view kHostInputPipelinePrefix = "Host:";
  constexpr absl::string_view kDeviceInputPipelinePrefix = "Device:";
  metadata->set_id(id);
  if (is_device_input_pipeline) {
    metadata->set_type(InputPipelineMetadata::DEVICE);
    metadata->set_name(absl::StrCat(kDeviceInputPipelinePrefix, name_id));
  } else {
    metadata->set_type(InputPipelineMetadata::HOST);
    metadata->set_name(absl::StrCat(kHostInputPipelinePrefix, name_id));
  }
}

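// Recursively accumulates statistics for `iterator_event` and its tf.data
// children into `input_pipeline_stat`: total duration, self time, number of
// calls, and whether the iterator blocks the root iterator call.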
void ProcessIteratorEvent(const EventNode& iterator_event,
                          InputPipelineStat* input_pipeline_stat,
                          bool is_blocking) {
  const XEventVisitor& visitor = iterator_event.GetEventVisitor();
  auto iterator_id_stat = visitor.GetStat(StatType::kStepId);
  if (!iterator_id_stat.has_value()) return;
  int64 iterator_id = iterator_id_stat->IntValue();
  auto result = input_pipeline_stat->mutable_iterator_stats()->insert(
      {iterator_id, IteratorStat()});
  IteratorStat& iterator_stat = result.first->second;
  if (result.second) {
    iterator_stat.set_id(iterator_id);
    iterator_stat.set_start_time_ps(visitor.TimestampPs());
  }
  iterator_stat.set_duration_ps(iterator_stat.duration_ps() +
                                visitor.DurationPs());
  int64 self_time_ps = visitor.DurationPs();
  Timespan self_time_span = visitor.GetTimespan();
  for (EventNode* child : iterator_event.GetChildren()) {
    const XEventVisitor& child_visitor = child->GetEventVisitor();
    if (ParseTfOpFullname(child_visitor.Name()).category == Category::kTfData) {
      int64 overlap_duration_ps =
          self_time_span.OverlappedDurationPs(child_visitor.GetTimespan());
      ProcessIteratorEvent(*child, input_pipeline_stat,
                           is_blocking && overlap_duration_ps);
      // Note: Assume no overlap between child events.
      self_time_ps -= overlap_duration_ps;
    }
  }
  iterator_stat.set_self_time_ps(iterator_stat.self_time_ps() + self_time_ps);
  iterator_stat.set_is_blocking(iterator_stat.is_blocking() || is_blocking);
  iterator_stat.set_num_calls(iterator_stat.num_calls() + 1);
}

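// Records the blocking iterator with the largest self time as the bottleneck
// of `input_pipeline_stat`.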
void SetBottleneckIteratorId(InputPipelineStat* input_pipeline_stat) {
  int64 bottleneck_iterator_id = 0;
  int64 max_self_time = 0;
  for (const auto& pair : input_pipeline_stat->iterator_stats()) {
    const auto& id = pair.first;
    const auto& iterator_stat = pair.second;
    if (iterator_stat.is_blocking() &&
        iterator_stat.self_time_ps() > max_self_time) {
      bottleneck_iterator_id = id;
      max_self_time = iterator_stat.self_time_ps();
    }
  }
  input_pipeline_stat->set_bottleneck_iterator_id(bottleneck_iterator_id);
  input_pipeline_stat->set_bottleneck_iterator_latency_ps(max_self_time);
}

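// Builds per-input-pipeline statistics from the root iterator events: for each
// root iterator, sorts its calls by duration (slowest first), processes every
// call, finds the bottleneck of each call, and computes latency summary
// statistics.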
void ProcessInputPipelines(
    const absl::flat_hash_set<int64>& device_input_pipeline_ids,
    absl::flat_hash_map<int64, std::vector<EventNode*>>*
        root_iterator_event_map,
    TfDataStats* tf_data_stats) {
  auto* input_pipelines = tf_data_stats->mutable_input_pipelines();
  int64 num_host_input_pipelines = 0;
  int64 num_device_input_pipelines = 0;
  for (auto& id_and_events : *root_iterator_event_map) {
    auto& root_iterator_id = id_and_events.first;
    auto& root_iterator_events = id_and_events.second;
    absl::c_sort(root_iterator_events,
                 [](const EventNode* lhs, const EventNode* rhs) {
                   return lhs->GetEventVisitor().DurationPs() >
                          rhs->GetEventVisitor().DurationPs();
                 });
    auto result =
        input_pipelines->insert({root_iterator_id, InputPipelineStats()});
    InputPipelineStats& input_pipeline_stats = result.first->second;
    InputPipelineMetadata* metadata = input_pipeline_stats.mutable_metadata();
    if (result.second) {
      bool is_device_input_pipeline =
          device_input_pipeline_ids.contains(root_iterator_id);
      int64 name_id = is_device_input_pipeline ? num_device_input_pipelines++
                                               : num_host_input_pipelines++;
      SetInputPipelineMetadata(root_iterator_id, name_id,
                               is_device_input_pipeline, metadata);
    }
    int64 sum_latency_ps = 0;
    int64 min_latency_ps = INT64_MAX;
    int64 max_latency_ps = 0;
    int64 num_slow_calls = 0;
    for (const EventNode* root_iterator_event : root_iterator_events) {
      InputPipelineStat* stat = input_pipeline_stats.add_stats();
      ProcessIteratorEvent(*root_iterator_event, stat, /*is_blocking=*/true);
      SetBottleneckIteratorId(stat);
      int64 latency_ps = root_iterator_event->GetEventVisitor().DurationPs();
      sum_latency_ps += latency_ps;
      min_latency_ps = std::min(min_latency_ps, latency_ps);
      max_latency_ps = std::max(max_latency_ps, latency_ps);
      if (latency_ps > kSlowCallThresholdPs) num_slow_calls++;
    }
    input_pipeline_stats.set_avg_latency_ps(sum_latency_ps /
                                            root_iterator_events.size());
    input_pipeline_stats.set_min_latency_ps(min_latency_ps);
    input_pipeline_stats.set_max_latency_ps(max_latency_ps);
    input_pipeline_stats.set_num_slow_calls(num_slow_calls);
  }
}

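// Collects the slowest call of every host input pipeline across all hosts,
// together with its bottleneck iterator, and records one bottleneck analysis
// entry per pipeline, sorted by max latency in decreasing order.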
void SetBottleneckAnalysis(CombinedTfDataStats* combined_tf_data_stats) {
  struct InputPipeline {
    InputPipeline(absl::string_view host_name,
                  absl::string_view input_pipeline_name, int64 max_latency_ps,
                  absl::string_view iterator_name,
                  absl::string_view iterator_long_name,
                  int64 iterator_latency_ps)
        : host_name(host_name),
          input_pipeline_name(input_pipeline_name),
          max_latency_ps(max_latency_ps),
          iterator_name(iterator_name),
          iterator_long_name(iterator_long_name),
          iterator_latency_ps(iterator_latency_ps) {}
    absl::string_view host_name;
    absl::string_view input_pipeline_name;
    int64 max_latency_ps;
    absl::string_view iterator_name;
    absl::string_view iterator_long_name;
    int64 iterator_latency_ps;

    bool operator<(const InputPipeline& rhs) const {
      return max_latency_ps > rhs.max_latency_ps;
    }
  };
  std::vector<InputPipeline> slow_input_pipelines;
  for (const auto& host_name_and_tf_data_stats :
       combined_tf_data_stats->tf_data_stats()) {
    absl::string_view host_name = host_name_and_tf_data_stats.first;
    const TfDataStats& tf_data_stats = host_name_and_tf_data_stats.second;
    for (const auto& id_and_stats : tf_data_stats.input_pipelines()) {
      const InputPipelineStats& input_pipeline_stats = id_and_stats.second;
      if (input_pipeline_stats.metadata().type() ==
          InputPipelineMetadata::DEVICE) {
        // Ignore device input pipelines.
        continue;
      }
      // Choose the slowest execution trace of the input pipeline.
      // `input_pipeline_stats.stats` is already sorted, so choose the first
      // one.
      const InputPipelineStat& input_pipeline_stat =
          input_pipeline_stats.stats(0);
      const IteratorMetadata& metadata = tf_data_stats.iterator_metadata().at(
          input_pipeline_stat.bottleneck_iterator_id());
      slow_input_pipelines.emplace_back(
          host_name, input_pipeline_stats.metadata().name(),
          input_pipeline_stats.max_latency_ps(), metadata.name(),
          metadata.long_name(),
          input_pipeline_stat.bottleneck_iterator_latency_ps());
    }
  }
  std::sort(slow_input_pipelines.begin(), slow_input_pipelines.end());
  for (const auto& input_pipeline : slow_input_pipelines) {
    TfDataBottleneckAnalysis* bottleneck_analysis =
        combined_tf_data_stats->add_bottleneck_analysis();
    bottleneck_analysis->set_host(input_pipeline.host_name.data(),
                                  input_pipeline.host_name.size());
    bottleneck_analysis->set_input_pipeline(
        input_pipeline.input_pipeline_name.data(),
        input_pipeline.input_pipeline_name.size());
    bottleneck_analysis->set_max_latency_ps(input_pipeline.max_latency_ps);
    bottleneck_analysis->set_iterator_name(input_pipeline.iterator_name.data(),
                                           input_pipeline.iterator_name.size());
    bottleneck_analysis->set_iterator_long_name(
        input_pipeline.iterator_long_name.data(),
        input_pipeline.iterator_long_name.size());
    bottleneck_analysis->set_iterator_latency_ps(
        input_pipeline.iterator_latency_ps);
  }
}

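// Returns an HTML-formatted suggestion string for the given bottleneck type.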
std::string GetSuggestion(BottleneckType type) {
  constexpr absl::string_view kPlaybookLink =
      "https://www.tensorflow.org/guide/data_performance_analysis";
  constexpr absl::string_view kPlaybookSourceDatasetLink =
      "https://www.tensorflow.org/guide/"
      "data_performance_analysis#source_datasets";
  constexpr absl::string_view kPlaybookCpuUtilizationLink =
      "https://www.tensorflow.org/guide/"
      "data_performance_analysis#3_are_you_reaching_high_cpu_utilization";
  constexpr absl::string_view kPlaybookTransformationLink =
      "https://www.tensorflow.org/guide/"
      "data_performance_analysis#transformation_datasets";
  constexpr absl::string_view kTfGuideParallelDataExtractionLink =
      "https://www.tensorflow.org/guide/"
      "data_performance#parallelizing_data_extraction";
  constexpr absl::string_view kTfGuideParallelTransformationLink =
      "https://www.tensorflow.org/guide/"
      "data_performance#parallelizing_data_transformation";
  constexpr absl::string_view kTfGuideCacheLink =
      "https://www.tensorflow.org/guide/data_performance#caching";
  constexpr absl::string_view kTfDataServiceLink =
      "https://www.tensorflow.org/api_docs/python/tf/data/experimental/"
      "service?version=nightly";
  switch (type) {
    case BottleneckType::kSlowSource:
      return absl::StrFormat(
          "1. Check the locality of the host and the input data. Ideally, "
          "they should be in the same cell (or at least very close, e.g., in "
          "the same region).<br/>"
          "2. Parallelize reading from this dataset source. See %s and %s for "
          "more details.<br/>",
          AnchorElement(kPlaybookSourceDatasetLink, "here"),
          AnchorElement(kTfGuideParallelDataExtractionLink, "here"));
    case BottleneckType::kSlowDataService:
      return absl::StrFormat(
          "1. Fetching data from the tf.data service is slow. Profile the "
          "tf.data service worker to analyze the issue further.<br/>"
          "2. See %s for more details on the tf.data service.<br/>"
          "3. See %s for other suggestions.",
          AnchorElement(kTfDataServiceLink, "this"),
          AnchorElement(kPlaybookLink, "this"));
    case BottleneckType::kSlowRemoteSource:
      return absl::StrFormat(
          "1. The remote data source is slow. Profile its host to analyze the "
          "issue further.<br/>"
          "2. See %s for other suggestions.",
          AnchorElement(kPlaybookLink, "this"));
    case BottleneckType::kSlowTransformationWithParallelVersion:
      return absl::StrFormat(
          "1. Parallelize this transformation by setting "
          "<code>num_parallel_calls=tf.data.experimental.AUTOTUNE</code>. See "
          "%s for more details.<br/>"
          "2. Consider adding <code>cache</code> after this transformation if "
          "your data fits into memory and it is appropriate (e.g., there is no "
          "randomness in upstream transformations like <code>shuffle</code>). "
          "See %s for more details.<br/>"
          "3. Find more resources %s.",
          AnchorElement(kTfGuideParallelTransformationLink, "this"),
          AnchorElement(kTfGuideCacheLink, "this"),
          AnchorElement(kPlaybookTransformationLink, "here"));
    case BottleneckType::kSlowTransformationWithoutParallelVersion:
      return absl::StrFormat(
          "1. This transformation is inherently sequential. Add outer "
          "parallelism by running multiple copies of the input pipeline over "
          "sharded inputs and combining the results. See %s for more "
          "details.<br/>"
          "2. Consider adding <code>cache</code> after this transformation if "
          "your data fits into memory and it is appropriate (e.g., there is no "
          "randomness in upstream transformations like <code>shuffle</code>). "
          "See %s for more details.<br/>"
          "3. Find more resources %s.",
          AnchorElement(kPlaybookTransformationLink, "this"),
          AnchorElement(kTfGuideCacheLink, "this"),
          AnchorElement(kPlaybookCpuUtilizationLink, "here"));
    default:
      return absl::StrFormat("See %s for suggestions.",
                             AnchorElement(kPlaybookLink, "this"));
  }
}

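// Attaches a suggestion to each bottleneck analysis entry based on the type of
// its bottleneck iterator.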
void SetSuggestion(CombinedTfDataStats* combined_tf_data_stats) {
  for (TfDataBottleneckAnalysis& bottleneck_analysis :
       *combined_tf_data_stats->mutable_bottleneck_analysis()) {
    bottleneck_analysis.set_suggestion(
        GetSuggestion(GetBottleneckType(bottleneck_analysis.iterator_name())));
  }
}

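// Sets the overall summary: whether the job looks input bound, based on the
// max latency of the slowest analyzed input pipeline.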
void SetSummary(CombinedTfDataStats* combined_tf_data_stats) {
  int64 max_latency_ps = 0;
  if (combined_tf_data_stats->bottleneck_analysis_size()) {
    max_latency_ps =
        combined_tf_data_stats->bottleneck_analysis().at(0).max_latency_ps();
  }
  if (max_latency_ps > kSlowCallThresholdPs) {
    combined_tf_data_stats->set_is_input_bound(true);
    combined_tf_data_stats->set_summary(
        "Your profile has a tf.data input pipeline slower than 50 us. For "
        "each slow input pipeline, the analysis below shows its bottleneck "
        "and a suggestion on how to fix it.");
  } else if (max_latency_ps > 0) {
    combined_tf_data_stats->set_is_input_bound(false);
    combined_tf_data_stats->set_summary(
        "Your profile does not have any tf.data input pipeline slower than "
        "50 us. Your job could still be input bound if this profile didn't "
        "capture all workers.");
  } else {
    combined_tf_data_stats->set_is_input_bound(false);
    combined_tf_data_stats->set_summary(
        "No tf.data activity captured in your profile. If your job uses "
        "tf.data, try to capture a longer profile.");
  }
}

}  // namespace

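// Maps a bottleneck iterator name to its BottleneckType; returns kOther for
// names that are not in the lookup table.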
BottleneckType GetBottleneckType(absl::string_view bottleneck_iterator_name) {
  static auto* kBottleneckTypeMap = new absl::flat_hash_map<absl::string_view,
                                                            BottleneckType>(
      {// Read from storage.
       {"TFRecord", BottleneckType::kSlowSource},
       {"SSTable", BottleneckType::kSlowSource},
       {"RecordIO", BottleneckType::kSlowSource},
       {"Spanner", BottleneckType::kSlowSource},
       {"TFColumn", BottleneckType::kSlowSource},
       {"SleepwalkRemoteDataset", BottleneckType::kSlowSource},
       {"TextLine", BottleneckType::kSlowSource},
       {"StitchedTimelineDataset", BottleneckType::kSlowSource},
       {"DateKeyDataset", BottleneckType::kSlowSource},
       {"CapacitorProto", BottleneckType::kSlowSource},
       {"LMDB", BottleneckType::kSlowSource},
       {"ExternalDataset", BottleneckType::kSlowSource},
       {"PearModel", BottleneckType::kSlowSource},
       {"FixedLengthRecordV2", BottleneckType::kSlowSource},
       // Read from local memory.
       {"FromTensor", BottleneckType::kSlowSource},
       {"TensorSlice", BottleneckType::kSlowSource},
       {"Generator", BottleneckType::kSlowSource},
       {"SyntheticDatasetOp", BottleneckType::kSlowSource},
       // tf.data service.
       {"DataService", BottleneckType::kSlowDataService},
       // Read from remote memory.
       {"GuzzlerDataGuzzlerRemoteDataset", BottleneckType::kSlowRemoteSource},
       {"ReverbDataset", BottleneckType::kSlowRemoteSource},
       {"DatasetSampleGame", BottleneckType::kSlowRemoteSource},
       {"Courier", BottleneckType::kSlowRemoteSource},
       {"ReverbEpisodeDataset", BottleneckType::kSlowRemoteSource},
       // Transformations with parallel version.
       {"Map", BottleneckType::kSlowTransformationWithParallelVersion},
       {"Interleave", BottleneckType::kSlowTransformationWithParallelVersion},
       // Transformations without parallel version.
       {"Filter", BottleneckType::kSlowTransformationWithoutParallelVersion},
       {"Batch", BottleneckType::kSlowTransformationWithoutParallelVersion},
       {"Unbatch", BottleneckType::kSlowTransformationWithoutParallelVersion}});
  if (auto type =
          gtl::FindOrNull(*kBottleneckTypeMap, bottleneck_iterator_name)) {
    return *type;
  }
  return BottleneckType::kOther;
}

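// Builds an EventForest from `host_plane`, connects its events (including the
// tf.data-specific connections), and derives this host's TfDataStats.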
void CombinedTfDataStatsBuilder::Add(absl::string_view host_name,
                                     XPlane* host_plane) {
  TfDataStats& tf_data_stats =
      (*combined_tf_data_stats_
            ->mutable_tf_data_stats())[std::string(host_name)];
  EventForest event_forest;
  event_forest.AddPlanes(CreateTfXPlaneVisitor, {host_plane});
  event_forest.ConnectEvents();
  event_forest.ConnectTfDataEvents();
  absl::flat_hash_set<int64> device_input_pipeline_ids;
  absl::flat_hash_map<int64, std::vector<EventNode*>> root_iterator_event_map;
  ProcessEventForest(event_forest, &device_input_pipeline_ids,
                     &root_iterator_event_map, &tf_data_stats);
  ProcessInputPipelines(device_input_pipeline_ids, &root_iterator_event_map,
                        &tf_data_stats);
}

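// Runs the cross-host analyses (bottleneck detection, suggestions, and the
// summary) once all hosts have been added.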
void CombinedTfDataStatsBuilder::Finalize() {
  SetBottleneckAnalysis(combined_tf_data_stats_);
  if (generate_suggestion_) SetSuggestion(combined_tf_data_stats_);
  SetSummary(combined_tf_data_stats_);
}

}  // namespace profiler
}  // namespace tensorflow