1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h"

#include <algorithm>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
21
22 namespace tensorflow {
23 namespace profiler {
24 namespace {
25
26 using OperationType = OpMetrics::MemoryAccessed::OperationType;
27
CombinePrecisionStats(const PrecisionStats & src,PrecisionStats * dst)28 void CombinePrecisionStats(const PrecisionStats& src, PrecisionStats* dst) {
29 dst->set_compute_16bit_ps(src.compute_16bit_ps() + dst->compute_16bit_ps());
30 dst->set_compute_32bit_ps(src.compute_32bit_ps() + dst->compute_32bit_ps());
31 }
32
33 } // namespace
34
CopyOpMetricsMetadata(const OpMetrics & src,OpMetrics * dst)35 void CopyOpMetricsMetadata(const OpMetrics& src, OpMetrics* dst) {
36 DCHECK(dst != nullptr);
37 DCHECK_EQ(src.hlo_module_id(), dst->hlo_module_id());
38 DCHECK_EQ(src.name(), dst->name());
39 if (dst->long_name().empty()) {
40 dst->set_long_name(src.long_name());
41 }
42 if (dst->category().empty()) {
43 dst->set_category(src.category());
44 }
45 if (dst->provenance().empty()) {
46 dst->set_provenance(src.provenance());
47 }
48 if (dst->deduplicated_name().empty()) {
49 dst->set_deduplicated_name(src.deduplicated_name());
50 }
51 if (!dst->has_layout() && src.has_layout()) {
52 *dst->mutable_layout() = src.layout();
53 }
54 if (!dst->has_children() && src.has_children()) {
55 *dst->mutable_children() = src.children();
56 }
57 }
58
CombineOpMetrics(const OpMetrics & src,OpMetrics * dst)59 void CombineOpMetrics(const OpMetrics& src, OpMetrics* dst) {
60 DCHECK(dst != nullptr);
61 if (dst->occurrences() == 0) {
62 dst->set_min_time_ps(src.min_time_ps());
63 } else {
64 dst->set_min_time_ps(std::min(src.min_time_ps(), dst->min_time_ps()));
65 }
66 dst->set_is_eager(dst->is_eager() || src.is_eager());
67 dst->set_occurrences(src.occurrences() + dst->occurrences());
68 dst->set_time_ps(src.time_ps() + dst->time_ps());
69 dst->set_self_time_ps(src.self_time_ps() + dst->self_time_ps());
70 dst->set_flops(src.flops() + dst->flops());
71 dst->set_bytes_accessed(src.bytes_accessed() + dst->bytes_accessed());
72 CombineMemoryAccessedBreakdown(src.memory_accessed_breakdown(),
73 dst->mutable_memory_accessed_breakdown());
74 dst->set_dma_stall_ps(src.dma_stall_ps() + dst->dma_stall_ps());
75 }
76
CombineMemoryAccessedBreakdown(const protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed> & src,protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed> * dst)77 void CombineMemoryAccessedBreakdown(
78 const protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed>& src,
79 protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed>* dst) {
80 if (src.empty()) return;
81 absl::flat_hash_map<std::pair<uint64 /*memory_space*/, OperationType>,
82 OpMetrics_MemoryAccessed*>
83 dst_memory_accessed_map;
84 for (auto& dst_memory_accessed : *dst) {
85 dst_memory_accessed_map[{dst_memory_accessed.memory_space(),
86 dst_memory_accessed.operation_type()}] =
87 &dst_memory_accessed;
88 }
89 for (const auto& src_memory_accessed : src) {
90 uint64 memory_space = src_memory_accessed.memory_space();
91 OperationType operation_type = src_memory_accessed.operation_type();
92 auto*& dst_memory_accessed =
93 dst_memory_accessed_map[{memory_space, operation_type}];
94 if (dst_memory_accessed == nullptr) {
95 dst_memory_accessed = dst->Add();
96 dst_memory_accessed->set_memory_space(memory_space);
97 dst_memory_accessed->set_operation_type(operation_type);
98 }
99 dst_memory_accessed->set_bytes_accessed(
100 src_memory_accessed.bytes_accessed() +
101 dst_memory_accessed->bytes_accessed());
102 }
103 }
104
Combine(const OpMetricsDb & src)105 void OpMetricsDbCombiner::Combine(const OpMetricsDb& src) {
106 OpMetricsDb* dst = db();
107 dst->set_total_host_infeed_enq_duration_ps(
108 src.total_host_infeed_enq_duration_ps() +
109 dst->total_host_infeed_enq_duration_ps());
110 dst->set_total_host_infeed_enq_start_timestamp_ps_diff(
111 src.total_host_infeed_enq_start_timestamp_ps_diff() +
112 dst->total_host_infeed_enq_start_timestamp_ps_diff());
113 dst->set_total_time_ps(src.total_time_ps() + dst->total_time_ps());
114 dst->set_total_op_time_ps(src.total_op_time_ps() + dst->total_op_time_ps());
115 CombinePrecisionStats(src.precision_stats(), dst->mutable_precision_stats());
116
117 for (const auto& src_metrics : src.metrics_db()) {
118 auto* dst_metrics = LookupOrInsertNewOpMetrics(src_metrics.hlo_module_id(),
119 src_metrics.name());
120 CopyOpMetricsMetadata(src_metrics, dst_metrics);
121 CombineOpMetrics(src_metrics, dst_metrics);
122 }
123 }
124
125 } // namespace profiler
126 } // namespace tensorflow
127