1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
17 
18 #include <algorithm>
19 #include <utility>
20 #include <vector>
21 
22 #include "absl/algorithm/container.h"
23 #include "absl/memory/memory.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_module.h"
26 #include "tensorflow/compiler/xla/service/human_readable_profile_builder.h"
27 #include "tensorflow/compiler/xla/types.h"
28 #include "tensorflow/compiler/xla/util.h"
29 
30 namespace xla {
HloProfileIndexMap(const HloModule & module,absl::Span<const string> extra_metrics)31 HloProfileIndexMap::HloProfileIndexMap(const HloModule& module,
32                                        absl::Span<const string> extra_metrics) {
33   size_t current_profile_index = 0;
34   for (xla::HloComputation* computation : module.MakeComputationPostOrder()) {
35     InsertOrDie(&computation_to_profile_idx_, computation,
36                 current_profile_index++);
37     for (const HloInstruction* instruction : computation->instructions()) {
38       // For simplicity we track all instructions here, but we could skip
39       // non-executing instructions like constants and parameters.
40       InsertOrDie(&instruction_to_profile_idx_, instruction,
41                   current_profile_index++);
42     }
43   }
44   for (const string& key : extra_metrics) {
45     InsertOrDie(&extra_metric_to_profile_idx_, key, current_profile_index++);
46   }
47 }
48 
CreateHloProfilePrinterData(const HloProfileIndexMap & hlo_profile_index_map,const HloCostAnalysis & cost_analysis,const string & entry_computation_name)49 std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData(
50     const HloProfileIndexMap& hlo_profile_index_map,
51     const HloCostAnalysis& cost_analysis,
52     const string& entry_computation_name) {
53   using HloComputationInfo = HloProfilePrinterData::HloComputationInfo;
54   using HloInstructionInfo = HloProfilePrinterData::HloInstructionInfo;
55 
56   size_t profile_counters_size = hlo_profile_index_map.total_count();
57 
58   std::unique_ptr<HloProfilePrinterData> profile_printer_data =
59       absl::make_unique<HloProfilePrinterData>();
60   profile_printer_data->set_profile_counters_size(profile_counters_size);
61   profile_printer_data->mutable_computation_infos()->Reserve(
62       hlo_profile_index_map.computation_count());
63 
64   const auto& computation_to_profile_idx_map =
65       hlo_profile_index_map.computation_to_profile_idx();
66 
67   // computation_to_profile_idx_map's order is not deterministic so create a
68   // deterministic computation_and_profile_idx_list so that we end up with a
69   // deterministic HloProfilePrinterData protobuf.
70 
71   std::vector<std::pair<const HloComputation*, int64>>
72       computation_and_profile_idx_list(computation_to_profile_idx_map.begin(),
73                                        computation_to_profile_idx_map.end());
74 
75   // The profile indices were computed deterministically in
76   // HloProfileIndexMap::HloProfileIndexMap.
77   absl::c_sort(computation_and_profile_idx_list,
78                [](const std::pair<const HloComputation*, int64>& left,
79                   const std::pair<const HloComputation*, int64>& right) {
80                  return left.second < right.second;
81                });
82 
83   for (const auto& pair : computation_and_profile_idx_list) {
84     CHECK_LT(pair.second, profile_counters_size);
85     const HloComputation* computation = pair.first;
86     HloComputationInfo* computation_info =
87         profile_printer_data->add_computation_infos();
88 
89     computation_info->set_name(computation->name());
90     computation_info->set_profile_index(pair.second);
91     computation_info->mutable_instruction_infos()->Reserve(
92         computation->instruction_count());
93 
94     for (const HloInstruction* hlo : computation->instructions()) {
95       HloInstructionInfo* instruction_info =
96           computation_info->add_instruction_infos();
97       instruction_info->set_long_name(hlo->ToString());
98       instruction_info->set_short_name(hlo->ToString(
99           HloPrintOptions().set_compact_operands(true).set_print_operand_names(
100               false)));
101       instruction_info->set_category(hlo->ToCategory());
102       instruction_info->set_flop_count(cost_analysis.flop_count(*hlo));
103       instruction_info->set_transcendental_count(
104           cost_analysis.transcendental_count(*hlo));
105       instruction_info->set_bytes_accessed(cost_analysis.bytes_accessed(*hlo));
106       instruction_info->set_optimal_seconds(
107           cost_analysis.optimal_seconds(*hlo));
108       instruction_info->set_profile_index(
109           hlo_profile_index_map.GetProfileIndexFor(*hlo));
110     }
111   }
112 
113   // Add extra metrics if any.
114   for (const auto& pair : hlo_profile_index_map.extra_metric_to_profile_idx()) {
115     profile_printer_data->mutable_extra_metrics()->insert(
116         {pair.first, pair.second});
117   }
118 
119   profile_printer_data->set_entry_computation(entry_computation_name);
120 
121   return profile_printer_data;
122 }
123 
HloExecutionProfile(const HloProfilePrinterData * hlo_profile_printer_data,const HloProfileIndexMap * hlo_profile_index_map)124 HloExecutionProfile::HloExecutionProfile(
125     const HloProfilePrinterData* hlo_profile_printer_data,
126     const HloProfileIndexMap* hlo_profile_index_map)
127     : hlo_profile_printer_data_(*hlo_profile_printer_data),
128       hlo_profile_index_map_(*hlo_profile_index_map),
129       profile_counters_(
130           /*count=*/hlo_profile_index_map_.total_count(),
131           /*value=*/0) {}
132 
SetCyclesTakenBy(const HloInstruction * hlo,uint64 cycles_taken)133 void HloExecutionProfile::SetCyclesTakenBy(const HloInstruction* hlo,
134                                            uint64 cycles_taken) {
135   profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(*hlo)] =
136       cycles_taken;
137 }
138 
GetCyclesTakenBy(const HloInstruction & hlo) const139 uint64 HloExecutionProfile::GetCyclesTakenBy(const HloInstruction& hlo) const {
140   return profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(hlo)];
141 }
142 
143 }  // namespace xla
144