1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/python/grappler/cost_analyzer.h"
17 
18 #include <iomanip>
19 #include "tensorflow/core/grappler/costs/utils.h"
20 #include "tensorflow/core/grappler/grappler_item.h"
21 #include "tensorflow/core/lib/core/status.h"
22 
23 namespace tensorflow {
24 namespace grappler {
25 
CostAnalyzer(const GrapplerItem & item,Cluster * cluster,const string & suffix)26 CostAnalyzer::CostAnalyzer(const GrapplerItem& item, Cluster* cluster,
27                            const string& suffix)
28     : item_(&item),
29       measure_estimator_(cluster, 10, 0),
30       analytical_estimator_(cluster, /*use_static_shapes=*/false,
31                             /*use_aggressive_shape_inference=*/true),
32       suffix_(suffix) {}
33 
GenerateReport(std::ostream & os,bool per_node_report,bool verbose)34 Status CostAnalyzer::GenerateReport(std::ostream& os, bool per_node_report,
35                                     bool verbose) {
36   GatherCosts();
37   PreprocessCosts();
38   AnalyzeCosts();
39   PrintAnalysis(os, per_node_report, verbose);
40   return Status::OK();
41 }
42 
PredictCosts(CostEstimator * cost_estimator,CostGraphDef * cost_graph,int64 * total_time)43 void CostAnalyzer::PredictCosts(CostEstimator* cost_estimator,
44                                 CostGraphDef* cost_graph, int64* total_time) {
45   TF_CHECK_OK(cost_estimator->Initialize(*item_));
46   RunMetadata run_metadata;
47   Costs costs;
48   const Status status =
49       cost_estimator->PredictCosts(item_->graph, &run_metadata, &costs);
50   if (cost_graph) {
51     cost_graph->Swap(run_metadata.mutable_cost_graph());
52   }
53   *total_time = costs.execution_time.count();
54   if (!status.ok()) {
55     LOG(ERROR) << "Could not estimate the cost for item " << item_->id << ": "
56                << status.error_message();
57     return;
58   }
59 }
60 
GatherCosts()61 void CostAnalyzer::GatherCosts() {
62   CostGraphDef cost_graph_measured;
63   PredictCosts(&measure_estimator_, &cost_graph_measured,
64                &total_time_measured_);
65   VLOG(1) << "Graph size: " << item_->graph.node_size();
66   VLOG(1) << "cost_graph_measured size: " << cost_graph_measured.node_size();
67 
68   CostGraphDef cost_graph_analytical;
69   PredictCosts(&analytical_estimator_, &cost_graph_analytical,
70                &total_time_analytical_);
71   VLOG(1) << "cost_graph_analytical size: "
72           << cost_graph_analytical.node_size();
73 
74   CostGraphDef cost_graph_analytical_filtered;
75   CostGraphDef cost_graph_measured_filtered;
76   std::map<string, const CostGraphDef_Node*> measured_nodes;
77   for (const auto& node : cost_graph_measured.node()) {
78     measured_nodes[node.name()] = &node;
79   }
80   for (const auto& node : cost_graph_analytical.node()) {
81     auto it = measured_nodes.find(node.name());
82     // Filter the nodes that are not the cost nodes returned by
83     // MeasuringCostEstimator.
84     if (it == measured_nodes.end()) {
85       continue;
86     }
87     auto added_node_analytical = cost_graph_analytical_filtered.add_node();
88     auto added_node_measured = cost_graph_measured_filtered.add_node();
89     *added_node_analytical = node;
90     *added_node_measured = *(it->second);
91   }
92   VLOG(1) << "cost_graph_analytical_filtered size: "
93           << cost_graph_analytical_filtered.node_size();
94 
95   // TODO(yaozhang): add a test to make sure that op_perf_analytical_ and
96   // op_perf_ cover the same set of nodes.
97   op_perf_analytical_ = CostGraphToOpPerformanceData(
98       cost_graph_analytical_filtered, item_->graph);
99   op_perf_ =
100       CostGraphToOpPerformanceData(cost_graph_measured_filtered, item_->graph);
101 }
102 
PreprocessCosts()103 void CostAnalyzer::PreprocessCosts() {
104   for (int i = 0; i < op_perf_.op_performance_size(); i++) {
105     OpPerformance* perf = op_perf_.mutable_op_performance(i);
106     const OpPerformance& analytical = op_perf_analytical_.op_performance(i);
107     perf->set_compute_time(analytical.compute_time());
108     perf->set_memory_time(analytical.memory_time());
109     double measured_cost = perf->compute_cost();
110 
111     double analytical_compute_cost = analytical.compute_time();
112     if (analytical_compute_cost == 0) {
113       // Negative infinity indidates unavailable data.
114       perf->set_compute_efficiency(-INFINITY);
115     } else {
116       perf->set_compute_efficiency(analytical_compute_cost / measured_cost);
117     }
118 
119     double analytical_memory_cost = analytical.memory_time();
120     if (analytical_memory_cost == 0) {
121       // Negative infinity indidates unavailable data.
122       perf->set_memory_efficiency(-INFINITY);
123     } else {
124       perf->set_memory_efficiency(analytical_memory_cost / measured_cost);
125     }
126   }
127 }
128 
SortOpsByTime(std::map<string,OpPerfSummary> ops)129 void CostAnalyzer::SortOpsByTime(std::map<string, OpPerfSummary> ops) {
130   for (const auto& op : ops) {
131     ops_.push_back(op.second);
132   }
133   struct CompareByTime {
134     bool operator()(const OpPerfSummary& a, const OpPerfSummary& b) const {
135       return a.time > b.time;
136     }
137   };
138   std::stable_sort(ops_.begin(), ops_.end(), CompareByTime());
139 }
140 
AnalyzeCosts()141 void CostAnalyzer::AnalyzeCosts() {
142   std::map<string, OpPerfSummary> ops;
143   for (const auto& op_perf : op_perf_.op_performance()) {
144     string op_name = op_perf.op().op();
145     ops[op_name].count++;
146     ops[op_name].time += op_perf.compute_cost();
147     ops[op_name].compute_time += op_perf.compute_time();
148     ops[op_name].memory_time += op_perf.memory_time();
149     ops[op_name].time_upper += op_perf.compute_time() + op_perf.memory_time();
150     ops[op_name].time_lower +=
151         std::max(op_perf.compute_time(), op_perf.memory_time());
152     ops[op_name].name = op_name;
153   }
154   SortOpsByTime(ops);
155 
156   total_time_measured_serialized_ = 0;
157   total_time_analytical_upper_ = 0;
158   total_time_analytical_lower_ = 0;
159   for (const auto& op : ops_) {
160     total_time_measured_serialized_ += op.time;
161     total_time_analytical_upper_ += op.time_upper;
162     total_time_analytical_lower_ += op.time_lower;
163   }
164 }
165 
PrintAnalysis(std::ostream & os,bool per_node_report,bool verbose) const166 void CostAnalyzer::PrintAnalysis(std::ostream& os, bool per_node_report,
167                                  bool verbose) const {
168   os << std::endl;
169   os << std::left << std::setw(50)
170      << "Total time measured in ns (serialized): " << std::right
171      << std::setw(20) << total_time_measured_serialized_ << std::endl;
172   os << std::left << std::setw(50)
173      << "Total time measured in ns (actual): " << std::right << std::setw(20)
174      << total_time_measured_ << std::endl;
175   os << std::left << std::setw(50)
176      << "Total time analytical in ns (upper bound): " << std::right
177      << std::setw(20) << total_time_analytical_upper_ << std::endl;
178   os << std::left << std::setw(50)
179      << "Total time analytical in ns (lower bound): " << std::right
180      << std::setw(20) << total_time_analytical_lower_ << std::endl;
181   double efficiency_upper = static_cast<double>(total_time_analytical_upper_) /
182                             static_cast<double>(total_time_measured_);
183   os << std::left << std::setw(50)
184      << "Overall efficiency (analytical upper/actual): " << std::right
185      << std::setw(20) << efficiency_upper << std::endl;
186   double efficiency_lower = static_cast<double>(total_time_analytical_lower_) /
187                             static_cast<double>(total_time_measured_);
188   os << std::left << std::setw(50)
189      << "Overall efficiency (analytical lower/actual): " << std::right
190      << std::setw(20) << efficiency_lower << std::endl;
191   os << std::endl;
192 
193   int width = 35;
194   int width_narrow = 15;
195   int width_wide = 20;
196   os << std::setw(width + 1) << "Op,";
197   os << std::setw(width_narrow + 1) << "Count,";
198   os << std::setw(width_wide + 1) << "Measured time (ns),";
199   os << std::setw(width_narrow + 2) << "Time percent,";
200   os << std::setw(width_narrow + 2) << "Acc percent,";
201   os << std::setw(width_wide + 1) << "Analytical upper,";
202   os << std::setw(width_wide + 1) << "Analytical lower,";
203   os << std::setw(width_narrow + 2) << "Overall eff";
204   os << std::setw(width_narrow + 2) << "Compute eff";
205   os << std::setw(width_narrow + 2) << "Memory eff" << std::endl;
206   float acc_percent = 0;
207   for (const auto& op : ops_) {
208     double percent = static_cast<double>(op.time) /
209                      static_cast<double>(total_time_measured_serialized_);
210     double eff =
211         static_cast<double>(op.time_upper) / static_cast<double>(op.time);
212     double compute_eff =
213         static_cast<double>(op.compute_time) / static_cast<double>(op.time);
214     double memory_eff =
215         static_cast<double>(op.memory_time) / static_cast<double>(op.time);
216     os << std::setw(width) << op.name << ",";
217     os << std::setw(width_narrow) << op.count << ",";
218     os << std::setw(width_wide) << op.time << ",";
219     os << std::setw(width_narrow) << std::setprecision(2) << percent * 100
220        << "%,";
221     acc_percent += percent;
222     os << std::setw(width_narrow) << std::setprecision(2) << acc_percent * 100
223        << "%,";
224     os << std::setw(width_wide) << op.time_upper << ",";
225     os << std::setw(width_wide) << op.time_lower << ",";
226     os << std::setw(width_narrow) << std::setprecision(2) << eff * 100 << "%,";
227     os << std::setw(width_narrow) << std::setprecision(2) << compute_eff * 100
228        << "%,";
229     os << std::setw(width_narrow) << std::setprecision(2) << memory_eff * 100
230        << "%,";
231     os << std::endl;
232   }
233   os << std::endl;
234 
235   if (per_node_report) {
236     if (verbose) {
237       os << "Below is the full per-node report:" << std::endl;
238       os << op_perf_.DebugString();
239     } else {
240       os << "Below is the per-node report summary:" << std::endl;
241       int width = 35;
242       int width_narrow = 15;
243       int width_wide = 20;
244       os << std::setw(width + 1) << "Op,";
245       os << std::setw(width_wide + 1) << "Measured time (ns),";
246       os << std::setw(width_wide + 1) << "Compute time (ns),";
247       os << std::setw(width_wide + 1) << "Memory time (ns),";
248       os << std::setw(width_narrow + 2) << "Compute eff,";
249       os << std::setw(width_narrow + 2) << "Memory eff,";
250       os << "    Inputs" << std::endl;
251       for (int i = 0; i < op_perf_.op_performance_size(); i++) {
252         const auto& perf = op_perf_.op_performance(i);
253         string op_name = perf.op().op();
254         os << std::setw(width) << op_name << ",";
255         os << std::setw(width_wide) << perf.compute_cost() << ",";
256         os << std::setw(width_wide) << perf.compute_time() << ",";
257         os << std::setw(width_wide) << perf.memory_time() << ",";
258         os << std::setw(width_narrow) << std::setprecision(2)
259            << perf.compute_efficiency() * 100 << "%,";
260         os << std::setw(width_narrow) << std::setprecision(2)
261            << perf.memory_efficiency() * 100 << "%,";
262         os << "    [";
263         for (int j = 0; j < perf.op().inputs_size(); j++) {
264           const auto& shape = perf.op().inputs(j).shape();
265           if (shape.dim_size() > 0) {
266             os << "(";
267             std::vector<int> dims;
268             for (int k = 0; k < shape.dim_size(); k++) {
269               os << shape.dim(k).size();
270               if (k < shape.dim_size() - 1) {
271                 os << ", ";
272               }
273             }
274             os << ")";
275             if (j < perf.op().inputs_size() - 1) {
276               os << ", ";
277             }
278           }
279         }
280         os << "]" << std::endl;
281       }
282       os << std::endl;
283     }
284   }
285 }
286 }  // end namespace grappler
287 }  // end namespace tensorflow
288