/* Copyright 2016 The TensorFlow Authors All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/profiler/internal/tfprof_code.h"

#include <stdio.h>

#include <utility>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/profiler/internal/tfprof_constants.h"

namespace tensorflow {
namespace tfprof {
namespace {

const char* const kGradientSuffix = " (gradient)";

// Converts a Trace proto into a short readable string.
std::string GetTraceString(const CallStack::Trace& trace) {
  std::string ntrace =
      absl::StrCat(io::Basename(trace.file()), ":", trace.lineno());
  if (trace.function().length() < 20) {
    absl::StrAppend(&ntrace, ":", trace.function());
  } else {
    absl::StrAppend(&ntrace, ":", trace.function().substr(0, 17), "...");
  }
  return ntrace;
}

bool IsGradNode(const string& name, string* forward_name) {
  // Given a forward operation with name op, its gradient op has the following
  // name: ...gradients/op_grad/...
  // TODO(xpan): This is hacky.
  auto grad_prefix = name.find("gradients/");
  auto grad_suffix = name.find("_grad/");
  if (grad_prefix == name.npos || grad_suffix == name.npos) {
    return false;
  }
  auto start = grad_prefix + string("gradients/").length();
  auto len = grad_suffix - start;
  if (len <= 0) {
    return false;
  }
  *forward_name = name.substr(start, len);
  return true;
}

// StringTable maps each string to an id.
class StringTable {
 public:
  StringTable() {
    // Pprof requires first entry in string_table to be ''.
    string_id_[""] = 0;
    all_strings_.push_back("");
  }

  // Returns the index of a string. If not found, inserts the string and
  // returns the inserted index.
  uint64 GetIndex(const string& str) {
    auto idx = string_id_.find(str);
    if (idx != string_id_.end()) {
      return idx->second;
    }
    all_strings_.push_back(str);
    return string_id_.insert(std::pair<string, int64>(str, string_id_.size()))
        .first->second;
  }

  const std::vector<string>& strings() const { return all_strings_; }

 private:
  std::map<string, uint64> string_id_;
  std::vector<string> all_strings_;
};

// FunctionTable maps each function to an id.
class FunctionTable {
 public:
  explicit FunctionTable(StringTable* string_table)
      : string_table_(string_table) {}

  // Returns the index of a function. If not found, adds a function proto
  // and returns the function index.
  uint64 GetIndex(const string& file_path, const string& func_name,
                  uint64 func_start_line) {
    auto key = std::tuple<string, string, uint64>(file_path, func_name,
                                                  func_start_line);
    auto idx = function_table_.find(key);
    if (idx != function_table_.end()) {
      return idx->second.id();
    }
    pprof::Function* func_pb = &function_table_[key];
    // function index should start from 1.
    func_pb->set_id(function_table_.size());

    string file_base(io::Basename(file_path));
    file_base = file_base.substr(0, file_base.find_last_of('.'));
    func_pb->set_name(
        string_table_->GetIndex(absl::StrCat(file_base, ":", func_name)));
    func_pb->set_filename(string_table_->GetIndex(file_path));
    func_pb->set_start_line(func_start_line);
    return func_pb->id();
  }

  const std::map<std::tuple<string, string, uint64>, pprof::Function>&
  functions() const {
    return function_table_;
  }

 private:
  StringTable* string_table_;
  std::map<std::tuple<string, string, uint64>, pprof::Function> function_table_;
};

// LocationTable maps each function call to an id.
class LocationTable {
 public:
  explicit LocationTable(FunctionTable* function_table)
      : function_table_(function_table) {}

  // Returns the index of a function call location. If not found, adds a
  // location proto and returns the location index.
  uint64 GetIndex(const string& file_path, uint64 line_number,
                  const string& called_function_name,
                  const string& called_file_path,
                  uint64 called_func_start_line) {
    auto key = std::tuple<string, string, uint64>(
        file_path, called_function_name, line_number);

    auto idx = location_table_.find(key);
    if (idx != location_table_.end()) {
      return idx->second.id();
    }
    pprof::Location* location_pb = &location_table_[key];
    location_pb->set_id(location_table_.size());
    pprof::Line* line_pb = location_pb->add_line();
    line_pb->set_function_id(function_table_->GetIndex(
        called_file_path, called_function_name, called_func_start_line));
    line_pb->set_line(line_number);
    return location_pb->id();
  }

  const std::map<std::tuple<string, string, uint64>, pprof::Location>&
  locations() const {
    return location_table_;
  }

 private:
  FunctionTable* function_table_;
  std::map<std::tuple<string, string, uint64>, pprof::Location> location_table_;
};

// Samples stores samples of all calls. A sample is a single call trace,
// that is, the call path from top caller to the leaf callee.
class Samples {
 public:
  explicit Samples(StringTable* string_table, const Options* opts)
      : string_table_(string_table), opts_(opts) {}

  // 'node' is the leaf of the displayed trace. It includes all graph nodes
  // created by it. 'location_ids' contains the call stack, from callee to
  // caller. This method adds the statistics of graph nodes created by the
  // Python call.
  void Add(const CodeNode* node, const std::vector<uint64>& location_ids) {
    // The displayed leaf might not be a true leaf. Retrieve the true leaves
    // for stats.
    std::vector<const CodeNode*> all_leaf = FetchAllLeaf(node);
    CHECK(!all_leaf.empty()) << node->name();

    for (const CodeNode* cn : all_leaf) {
      for (const auto& gn_it : cn->node->graph_nodes()) {
        const TFGraphNode* gn = gn_it.second;
        string name = gn->name();
        // Generate a new trace name, in case the name is taken.
        while (sample_table_.find(name) != sample_table_.end()) {
          name += '@';
        }
        pprof::Sample* sample_pb = &sample_table_[name];
        for (uint64 id : location_ids) {
          sample_pb->mutable_location_id()->Add(id);
        }
        pprof::Label* label_pb = sample_pb->mutable_label()->Add();
        label_pb->set_key(string_table_->GetIndex("graph node:"));
        label_pb->set_str(string_table_->GetIndex(gn->name()));

        sample_pb->mutable_value()->Add(1);
        string type = *opts_->select.begin();
        if (type == kShown[1]) {
          sample_pb->mutable_value()->Add(gn->exec_micros(node->node->step()));
        } else if (type == kShown[9]) {
          sample_pb->mutable_value()->Add(
              gn->accelerator_exec_micros(node->node->step()));
        } else if (type == kShown[10]) {
          sample_pb->mutable_value()->Add(
              gn->cpu_exec_micros(node->node->step()));
        } else if (type == kShown[0]) {
          sample_pb->mutable_value()->Add(
              gn->requested_bytes(node->node->step()));
        } else if (type == kShown[11]) {
          sample_pb->mutable_value()->Add(gn->peak_bytes(node->node->step()));
        } else if (type == kShown[12]) {
          sample_pb->mutable_value()->Add(
              gn->residual_bytes(node->node->step()));
        } else if (type == kShown[13]) {
          sample_pb->mutable_value()->Add(gn->output_bytes(node->node->step()));
        } else if (type == kShown[2]) {
          sample_pb->mutable_value()->Add(gn->parameters());
        } else if (type == kShown[3]) {
          sample_pb->mutable_value()->Add(gn->float_ops(node->node->step()));
        } else {
          absl::FPrintF(stderr, "pprof doesn't support -select=%s\n", type);
        }
      }
    }
  }

  const std::map<string, pprof::Sample>& samples() const {
    return sample_table_;
  }

 private:
  std::vector<const CodeNode*> FetchAllLeaf(const CodeNode* root) {
    if (root->children.empty()) {
      return {root};
    }
    std::vector<const CodeNode*> ret;
    for (auto& n : root->children) {
      std::vector<const CodeNode*> nodes = FetchAllLeaf(n);
      ret.insert(ret.end(), nodes.begin(), nodes.end());
    }
    return ret;
  }

  StringTable* string_table_;
  const Options* opts_;
  std::map<string, pprof::Sample> sample_table_;
};

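// PprofProfileImpl builds a pprof Profile proto from the code-view tree and
// writes it, gzip-compressed, to the requested output file.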
class PprofProfileImpl : public PprofProfile {
 public:
  explicit PprofProfileImpl(const Options* opts)
      : opts_(opts),
        func_table_(new FunctionTable(&string_table_)),
        loc_table_(new LocationTable(func_table_.get())),
        samples_(new Samples(&string_table_, opts)) {}

  uint64 AddLocation(const CodeNode* callee, const CodeNode* caller) override {
    const string& file_path = caller->file();
    uint64 lineno = caller->lineno();
    const string& callee_file_path = callee->file();
    const string& callee_function = callee->function();
    uint64 callee_func_start_line = callee->func_start_line();

    return loc_table_->GetIndex(file_path, lineno, callee_function,
                                callee_file_path, callee_func_start_line);
  }

  void AddSample(const CodeNode* leaf, std::vector<uint64>* call_ids) override {
    std::vector<uint64> reversed_call_ids;
    std::reverse_copy(call_ids->begin(), call_ids->end(),
                      std::back_inserter(reversed_call_ids));
    samples_->Add(leaf, reversed_call_ids);
  }

  Status WritePprofProfile(const string& filename) override {
    pprof::Profile profile_pb;
    Build(&profile_pb);

    std::unique_ptr<WritableFile> file;
    Status s = Env::Default()->NewWritableFile(filename, &file);
    if (!s.ok()) return s;

    int32 buf_size = 1024 * 1024;
    io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer(
        file.get(), buf_size, buf_size, io::ZlibCompressionOptions::GZIP());
    s = zlib_output_buffer->Init();
    if (!s.ok()) {
      delete zlib_output_buffer;
      return s;
    }
    s = zlib_output_buffer->Append(profile_pb.SerializeAsString());
    if (!s.ok()) {
      delete zlib_output_buffer;
      return s;
    }
    s = zlib_output_buffer->Close();
    if (!s.ok()) {
      delete zlib_output_buffer;
      return s;
    }
    absl::FPrintF(stdout,
                  "\nRun pprof -png --nodecount=100 --sample_index=1 <%s>\n",
                  filename);
    delete zlib_output_buffer;
    return s;
  }

 private:
  void Build(pprof::Profile* profile_pb) {
    string sample_type_description = "count";
    auto sample_type = profile_pb->mutable_sample_type()->Add();
    sample_type->set_type(string_table_.GetIndex(sample_type_description));
    sample_type->set_unit(string_table_.GetIndex("count"));

    string type = *opts_->select.begin();
    sample_type_description = type;
    sample_type = profile_pb->mutable_sample_type()->Add();
    sample_type->set_type(string_table_.GetIndex(sample_type_description));
    if (type == kShown[1] || type == kShown[9] || type == kShown[10]) {
      sample_type->set_unit(string_table_.GetIndex("microseconds"));
      if (type == kShown[1]) {
        profile_pb->mutable_comment()->Add(string_table_.GetIndex(
            "Sum of accelerator execution time and cpu execution time."));
      } else if (type == kShown[9]) {
        profile_pb->mutable_comment()->Add(
            string_table_.GetIndex("Accelerator execution time."));
      } else if (type == kShown[10]) {
        profile_pb->mutable_comment()->Add(
            string_table_.GetIndex("CPU execution time."));
      }
    } else if (type == kShown[0]) {
      sample_type->set_unit(string_table_.GetIndex("bytes"));
      profile_pb->mutable_comment()->Add(
          string_table_.GetIndex("Sum of operation total memory requests, "
                                 "excluding deallocations."));
    } else if (type == kShown[11]) {
      sample_type->set_unit(string_table_.GetIndex("bytes"));
      profile_pb->mutable_comment()->Add(
          string_table_.GetIndex("Sum of operation peak memory usage."));
    } else if (type == kShown[12]) {
      sample_type->set_unit(string_table_.GetIndex("bytes"));
      profile_pb->mutable_comment()->Add(string_table_.GetIndex(
          "Sum of operation allocated memory after finish."));
    } else if (type == kShown[13]) {
      sample_type->set_unit(string_table_.GetIndex("bytes"));
      profile_pb->mutable_comment()->Add(
          string_table_.GetIndex("Sum of operation output size."));
    } else if (type == kShown[2]) {
      sample_type->set_unit(string_table_.GetIndex("count"));
      profile_pb->mutable_comment()->Add(
          string_table_.GetIndex("Model parameters."));
    } else if (type == kShown[3]) {
      sample_type->set_unit(string_table_.GetIndex("count"));
      profile_pb->mutable_comment()->Add(string_table_.GetIndex(
          "Model float operations (Only available if defined)."));
    } else {
      absl::FPrintF(stderr, "pprof doesn't support selecting: %s\n", type);
    }

    for (const string& str : string_table_.strings()) {
      *profile_pb->mutable_string_table()->Add() = str;
    }
    for (const auto& sample_it : samples_->samples()) {
      // TODO(xpan): Consider swap.
      profile_pb->mutable_sample()->Add()->MergeFrom(sample_it.second);
    }
    for (const auto& function_it : func_table_->functions()) {
      profile_pb->mutable_function()->Add()->MergeFrom(function_it.second);
    }
    for (const auto& location_it : loc_table_->locations()) {
      profile_pb->mutable_location()->Add()->MergeFrom(location_it.second);
    }
  }

  const Options* opts_;
  StringTable string_table_;
  std::unique_ptr<FunctionTable> func_table_;
  std::unique_ptr<LocationTable> loc_table_;
  std::unique_ptr<Samples> samples_;
};
}  // namespace

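// Indexes a graph node under its Python call stack in the code-view tree.
// Gradient nodes are set aside in grad_nodes_ and attached to their forward
// op's trace later, in TFCode::Build().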
void TFCode::AddNode(TFGraphNode* node) {
  if (!node->call_stack() || node->call_stack()->traces().empty()) {
    return;
  }
  // We infer the forward operation name from the gradient op name, so that we
  // can map gradient op traces to forward op traces.
  // E.g. the gradient node of 'inp_1/Conv2D' would be
  // 'gradients/inp_1/Conv2D_grad'.
  string forward_name;
  if (IsGradNode(node->name(), &forward_name)) {
    auto grad_nodes_it = grad_nodes_.find(forward_name);
    if (grad_nodes_it != grad_nodes_.end()) {
      grad_nodes_it->second.push_back(node);
    } else {
      grad_nodes_.insert(
          std::pair<string, std::vector<TFGraphNode*>>(forward_name, {node}));
    }
    return;
  } else {
    forward_nodes_[node->name()] = node;
  }

  if (!root_) {
    graph_root_.reset(new TFMultiGraphNode(kTFProfRoot));
    root_.reset(new CodeNode(graph_root_.get(), nullptr, ""));
  }

  CodeNode* pre_code_node = root_.get();
  // TODO(xpan): Consider releasing CodeDef after TFCode is built. It
  // takes a lot of memory.
  std::set<string> traces;
  for (int i = 0, end = node->call_stack()->traces().size(); i < end; ++i) {
    // Unlike op name, which is globally unique, trace name is only unique
    // w.r.t. its parent.
    const string& trace = GetTraceString(node->call_stack()->traces().at(i));
    traces.insert(trace);
    pre_code_node = pre_code_node->AddChildren(
        trace, &node->call_stack()->traces().at(i), "");
    const int64 last_index = node->call_stack()->traces().size() - 1;
    if (i == last_index) {
      pre_code_node->node->AddGraphNode(node);
    }
  }
}

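// Attaches the gradient nodes collected in AddNode() to the call-stack traces
// of their forward ops, marking those traces with kGradientSuffix.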
void TFCode::Build() {
  int64 unaccounted_nodes = 0;
  for (const auto& it : grad_nodes_) {
    const string& forward_name = it.first;
    auto forward_it = forward_nodes_.find(forward_name);
    if (forward_it == forward_nodes_.end()) {
      unaccounted_nodes += 1;
      continue;
    }
    TFGraphNode* fn = forward_it->second;
    CodeNode* leaf = nullptr;
    CodeNode* pre_code_node = root_.get();
    for (int i = 0, end = fn->call_stack()->traces().size(); i < end; ++i) {
      const string& trace =
          GetTraceString(fn->call_stack()->traces().at(i)) + kGradientSuffix;
      pre_code_node = pre_code_node->AddChildren(
          trace, &fn->call_stack()->traces().at(i), kGradientSuffix);
      const int64 last_trace = fn->call_stack()->traces().size() - 1;
      if (i == last_trace) {
        leaf = pre_code_node;
      }
    }
    for (TFGraphNode* gn : it.second) {
      leaf->node->AddGraphNode(gn);
    }
  }
  if (unaccounted_nodes > 0) {
    absl::FPrintF(stderr, "%d gradient nodes not accounted\n",
                  unaccounted_nodes);
  }
}

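// Builds the displayed code-view tree for the given options and, when the
// output type is pprof, also writes the pprof profile to disk.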
const ShowMultiNode* TFCode::ShowInternal(const Options& opts,
                                          Timeline* timeline) {
  root_->ResetTotalStats();
  if (opts.output_type == kOutput[3]) {
    if (opts.select.size() != 1) {
      absl::FPrintF(stderr, "Can only select 1 attribute for pprof output.\n");
      return root_.get();
    }
    string select = *opts.select.begin();
    if (select != kShown[0] && select != kShown[1] && select != kShown[2] &&
        select != kShown[3] && select != kShown[9] && select != kShown[10] &&
        select != kShown[11] && select != kShown[12] && select != kShown[13]) {
      absl::FPrintF(stderr, "pprof doesn't support -select=%s\n", select);
      return root_.get();
    }
  }
  if (opts.account_displayed_op_only) {
    absl::FPrintF(stderr,
                  "Note: code view ignores account_displayed_op_only\n");
  }

  std::vector<CodeNode*> roots = Account(root_->children, opts);
  root_->show_children.clear();
  for (CodeNode* n : roots) {
    root_->AggregateTotalStats(n);
  }

  if (opts.start_name_regexes.size() != 1 ||
      opts.start_name_regexes[0] != ".*") {
    roots = SearchRoot(roots, opts.start_name_regexes);
  }

  root_->show_children.assign(roots.begin(), roots.end());

  CodeNode* root = PrintScope({root_.get()}, opts, 1, 0)[0];

  root->formatted_str = FormatLegend(opts) + root->formatted_str;

  if (opts.output_type == kOutput[3]) {
    std::vector<uint64> call_ids;
    pprof_profile_.reset(new PprofProfileImpl(&opts));
    Format(root, root->show_children, opts, &root->formatted_str,
           root->mutable_proto(), &call_ids);
    Status s = pprof_profile_->WritePprofProfile(
        opts.output_options.at(kPprofOpts[0]));
    if (!s.ok()) {
      absl::FPrintF(stderr, "%s\n", s.ToString());
    }
  } else {
    Format(root, root->show_children, opts, &root->formatted_str,
           root->mutable_proto(), nullptr);
    if (timeline) {
      timeline->GenerateCodeTimeline(root);
    }
  }
  return root;
}

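// Recursively appends the formatted string and proto of each displayed node.
// For pprof output, it also records a location id per call edge and emits a
// sample whenever it reaches a displayed leaf that has a trace.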
void TFCode::Format(const CodeNode* root, const std::vector<CodeNode*>& nodes,
                    const Options& opts, string* display_str,
                    MultiGraphNodeProto* proto, std::vector<uint64>* call_ids) {
  if (nodes.empty() && root->has_trace() && opts.output_type == kOutput[3]) {
    pprof_profile_->AddSample(root, call_ids);
  }

  for (CodeNode* node : nodes) {
    if (root->has_trace() && opts.output_type == kOutput[3]) {
      uint64 loc_id = pprof_profile_->AddLocation(node, root);
      call_ids->push_back(loc_id);
    }
    display_str->append(node->formatted_str);
    MultiGraphNodeProto* child = proto->add_children();
    child->MergeFrom(node->proto());
    Format(node, node->show_children, opts, display_str, child, call_ids);
    if (root->has_trace() && opts.output_type == kOutput[3]) {
      call_ids->pop_back();
    }
  }
}

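// Returns, per branch, the shallowest nodes whose names match any of the
// start_name_regexes; the search does not descend past a match.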
std::vector<CodeNode*> TFCode::SearchRoot(std::vector<CodeNode*> roots,
                                          const std::vector<string>& regexes) {
  std::vector<CodeNode*> res;
  if (roots.empty()) {
    return res;
  }
  for (CodeNode* root : roots) {
    bool match_start_node = false;
    for (const string& regex : regexes) {
      if (RE2::FullMatch(root->name(), regex)) {
        res.push_back(root);
        match_start_node = true;
        break;
      }
    }
    if (match_start_node) {
      // Found a start node at this branch, no need to continue.
      continue;
    }
    std::vector<CodeNode*> nroots = SearchRoot(root->show_children, regexes);
    res.insert(res.end(), nroots.begin(), nroots.end());
  }
  return res;
}

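// Recursively formats the nodes to display, up to opts.max_depth, and returns
// the nodes shown at this level. Children of hidden nodes are promoted to the
// nearest shown ancestor.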
std::vector<CodeNode*> TFCode::PrintScope(const std::vector<CodeNode*> roots,
                                          const Options& opts, int depth,
                                          int last_ident) {
  std::vector<CodeNode*> show_nodes;

  for (CodeNode* node : roots) {
    if (ShouldTrim(node, opts.trim_name_regexes) || depth > opts.max_depth) {
      continue;
    }
    int ident = last_ident;
    bool show = ShouldShow(node, opts, depth);
    if (show) ident += 2;

    std::vector<CodeNode*> show_cnodes =
        PrintScope(node->show_children, opts, depth + 1, ident);
    if (show) {
      node->show_children.clear();

      show_cnodes = SortNodes(show_cnodes, opts);
      for (CodeNode* sc : show_cnodes) {
        node->show_children.push_back(sc);
      }

      node->formatted_str = FormatNode(node, opts, last_ident);

      if (opts.select.find(kShown[4]) != opts.select.end()) {
        absl::FPrintF(stderr, "code view has no tensor value to show\n");
      }
      show_nodes.push_back(node);
    } else {
      show_nodes.insert(show_nodes.end(), show_cnodes.begin(),
                        show_cnodes.end());
    }
  }
  return show_nodes;
}

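// Recomputes which nodes are accounted under the current options and
// aggregates children's stats into their parents. Returns the nodes that are
// accounted themselves or have accounted descendants.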
std::vector<CodeNode*> TFCode::Account(const std::vector<CodeNode*>& roots,
                                       const Options& opts) {
  std::vector<CodeNode*> act_nodes;

  for (CodeNode* node : roots) {
    node->ResetTotalStats();
    std::vector<CodeNode*> act_cnodes = Account(node->children, opts);
    node->account = ReAccount(node, opts);
    if (node->account || !act_cnodes.empty()) {
      node->show_children.clear();
      node->ResetTotalStats();
      node->AddSelfToTotalStats();
      for (CodeNode* c : act_cnodes) {
        node->AggregateTotalStats(c);
        node->show_children.push_back(c);
      }
      act_nodes.push_back(node);
    }
  }
  return act_nodes;
}

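// Formats a "self/total" memory string; the self part is shown as "--" when
// the node is not accounted.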
string TFCode::FormatNodeMemory(CodeNode* node, int64 bytes,
                                int64 total_bytes) const {
  string memory = FormatMemory(total_bytes);
  if (node->account) {
    memory = FormatMemory(bytes) + "/" + memory;
  } else {
    memory = "--/" + memory;
  }
  return memory;
}

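// Formats one display line for a node: its name followed by the attributes
// requested via -select (memory, time, parameters, float ops, devices, etc.).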
string TFCode::FormatNode(CodeNode* node, const Options& opts,
                          int64 indent) const {
  std::vector<string> attrs;
  if (opts.select.find(kShown[0]) != opts.select.end()) {
    attrs.push_back(FormatNodeMemory(node, node->proto().requested_bytes(),
                                     node->proto().total_requested_bytes()));
  }
  if (opts.select.find(kShown[11]) != opts.select.end()) {
    attrs.push_back(FormatNodeMemory(node, node->proto().peak_bytes(),
                                     node->proto().total_peak_bytes()));
  }
  if (opts.select.find(kShown[12]) != opts.select.end()) {
    attrs.push_back(FormatNodeMemory(node, node->proto().residual_bytes(),
                                     node->proto().total_residual_bytes()));
  }
  if (opts.select.find(kShown[13]) != opts.select.end()) {
    attrs.push_back(FormatNodeMemory(node, node->proto().output_bytes(),
                                     node->proto().total_output_bytes()));
  }

  std::vector<string> time_attrs = FormatTimes(node, opts);
  attrs.insert(attrs.end(), time_attrs.begin(), time_attrs.end());

  if (opts.select.find(kShown[2]) != opts.select.end()) {
    string params = FormatNumber(node->proto().total_parameters()) + " params";
    if (node->account) {
      params = FormatNumber(node->proto().parameters()) + "/" + params;
    } else {
      params = "--/" + params;
    }
    attrs.push_back(params);
  }

  if (opts.select.find(kShown[3]) != opts.select.end()) {
    string fops = FormatNumber(node->proto().total_float_ops()) + " flops";
    if (node->account) {
      fops = FormatNumber(node->proto().float_ops()) + "/" + fops;
    } else {
      fops = "--/" + fops;
    }
    attrs.push_back(fops);
  }

  if (opts.select.find(kShown[5]) != opts.select.end() &&
      !node->node->devices().empty()) {
    attrs.push_back(absl::StrJoin(node->node->devices(), "|"));
  }
  if (opts.select.find(kShown[6]) != opts.select.end()) {
    std::set<string> op_types = node->node->op_types();
    attrs.push_back(absl::StrJoin(op_types, "|"));
  }
  if (opts.select.find(kShown[7]) != opts.select.end()) {
    // TODO(xpan): Make op count available in code view?
    attrs.push_back(absl::StrFormat("%s N/A in code view", kShown[7]));
  }
  if (opts.select.find(kShown[8]) != opts.select.end()) {
    attrs.push_back(absl::StrFormat("%s N/A in code view", kShown[8]));
  }

  return absl::StrFormat("%s%s (%s)\n", std::string(indent, ' '), node->name(),
                         absl::StrJoin(attrs, ", "));
}
}  // namespace tfprof
}  // namespace tensorflow