• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors All Rights Reserved.
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
7     http://www.apache.org/licenses/LICENSE-2.0
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
16 #include "tensorflow/core/profiler/internal/tfprof_code.h"
18 #include <stdio.h>
19 #include <utility>
21 #include "tensorflow/c/c_api.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/lib/io/path.h"
24 #include "tensorflow/core/lib/io/zlib_compression_options.h"
25 #include "tensorflow/core/lib/io/zlib_outputbuffer.h"
26 #include "tensorflow/core/lib/strings/str_util.h"
27 #include "tensorflow/core/lib/strings/strcat.h"
28 #include "tensorflow/core/lib/strings/stringprintf.h"
29 #include "tensorflow/core/platform/regexp.h"
30 #include "tensorflow/core/profiler/internal/tfprof_constants.h"
32 namespace tensorflow {
33 namespace tfprof {
34 namespace {
// Suffix that TFCode::Build appends to a trace string (and passes to
// AddChildren) when it creates the code nodes that hold gradient ops.
const char* const kGradientSuffix = " (gradient)";
38 // Convert to Trace proto into a short readable string.
GetTraceString(const CallStack::Trace & trace)39 string GetTraceString(const CallStack::Trace& trace) {
40   string ntrace(io::Basename(trace.file()));
41   ntrace += strings::StrCat(":", trace.lineno());
42   if (trace.function().length() < 20) {
43     ntrace += ":" + trace.function();
44   } else {
45     ntrace += ":" + trace.function().substr(0, 17) + "...";
46   }
47   return ntrace;
48 }
// Extracts the forward op name from a gradient op name.
//
// The gradient op of a forward op `op` is expected to be named
// "...gradients/<op>_grad/...". On a match, stores <op> in `*forward_name`
// and returns true. Otherwise returns false and leaves `*forward_name`
// untouched.
// TODO(xpan): This is hacky.
bool IsGradNode(const string& name, string* forward_name) {
  static const char kPrefix[] = "gradients/";
  static const char kSuffix[] = "_grad/";

  const auto prefix_pos = name.find(kPrefix);
  if (prefix_pos == string::npos) {
    return false;
  }
  // sizeof() counts the trailing NUL, hence the -1.
  const auto start = prefix_pos + (sizeof(kPrefix) - 1);

  // Only look for the suffix after the prefix. Searching from the beginning
  // could find a "_grad/" that precedes "gradients/"; the unsigned
  // subtraction below would then wrap around and yield a bogus forward name.
  const auto suffix_pos = name.find(kSuffix, start);
  if (suffix_pos == string::npos || suffix_pos == start) {
    // No suffix, or an empty forward name between prefix and suffix.
    return false;
  }
  *forward_name = name.substr(start, suffix_pos - start);
  return true;
}
// StringTable interns strings: every distinct string gets a stable id, which
// is also its position in strings().
class StringTable {
 public:
  StringTable() {
    // Pprof requires first entry in string_table to be ''.
    string_id_[""] = 0;
    all_strings_.push_back("");
  }

  // Returns the id of `str`, interning it first if it has not been seen yet.
  // Ids are handed out in insertion order, so a new string's id equals the
  // number of strings interned before it.
  uint64 GetIndex(const string& str) {
    // One emplace performs the lookup and, if the key is absent, the insert.
    // all_strings_ and string_id_ always hold the same number of entries, so
    // the current size is the next free id.
    const auto result = string_id_.emplace(str, all_strings_.size());
    if (result.second) {
      all_strings_.push_back(str);
    }
    return result.first->second;
  }

  // All interned strings, ordered by id.
  const std::vector<string>& strings() const { return all_strings_; }

 private:
  std::map<string, uint64> string_id_;
  std::vector<string> all_strings_;
};
96 // FunctionTable maps each function to an id.
97 class FunctionTable {
98  public:
FunctionTable(StringTable * string_table)99   explicit FunctionTable(StringTable* string_table)
100       : string_table_(string_table) {}
102   // Returns the index of a function. If not found, adds a function proto
103   // and returns the function index.
GetIndex(const string & file_path,const string & func_name,uint64 func_start_line)104   uint64 GetIndex(const string& file_path, const string& func_name,
105                   uint64 func_start_line) {
106     auto key = std::tuple<string, string, uint64>(file_path, func_name,
107                                                   func_start_line);
108     auto idx = function_table_.find(key);
109     if (idx != function_table_.end()) {
110       return idx->second.id();
111     }
112     pprof::Function* func_pb = &function_table_[key];
113     // function index should start from 1.
114     func_pb->set_id(function_table_.size());
116     string file_base(io::Basename(file_path));
117     file_base = file_base.substr(0, file_base.find_last_of("."));
118     func_pb->set_name(
119         string_table_->GetIndex(strings::StrCat(file_base, ":", func_name)));
120     func_pb->set_filename(string_table_->GetIndex(file_path));
121     func_pb->set_start_line(func_start_line);
122     return func_pb->id();
123   }
125   const std::map<std::tuple<string, string, uint64>, pprof::Function>&
functions() const126   functions() const {
127     return function_table_;
128   }
130  private:
131   StringTable* string_table_;
132   std::map<std::tuple<string, string, uint64>, pprof::Function> function_table_;
133 };
135 // LocationTable maps each function call to an id.
136 class LocationTable {
137  public:
LocationTable(FunctionTable * function_table)138   explicit LocationTable(FunctionTable* function_table)
139       : function_table_(function_table) {}
141   // Returns the index of a function call localtion. If not found, adds a
142   // location proto and returns the location index.
GetIndex(const string & file_path,uint64 line_number,const string & called_function_name,const string & called_file_path,uint64 called_func_start_line)143   uint64 GetIndex(const string& file_path, uint64 line_number,
144                   const string& called_function_name,
145                   const string& called_file_path,
146                   uint64 called_func_start_line) {
147     auto key = std::tuple<string, string, uint64>(
148         file_path, called_function_name, line_number);
150     auto idx = location_table_.find(key);
151     if (idx != location_table_.end()) {
152       return idx->second.id();
153     }
154     pprof::Location* location_pb = &location_table_[key];
155     location_pb->set_id(location_table_.size());
156     pprof::Line* line_pb = location_pb->add_line();
157     line_pb->set_function_id(function_table_->GetIndex(
158         called_file_path, called_function_name, called_func_start_line));
159     line_pb->set_line(line_number);
160     return location_pb->id();
161   }
163   const std::map<std::tuple<string, string, uint64>, pprof::Location>&
locations() const164   locations() const {
165     return location_table_;
166   }
168  private:
169   FunctionTable* function_table_;
170   std::map<std::tuple<string, string, uint64>, pprof::Location> location_table_;
171 };
173 // Samples stores samples of all calls. A sample is a single call trace,
174 // that is, the call path from top caller to the leaf callee.
175 class Samples {
176  public:
Samples(StringTable * string_table,const Options * opts)177   explicit Samples(StringTable* string_table, const Options* opts)
178       : string_table_(string_table), opts_(opts) {}
180   // 'node' is the leaf of the displayed trace. It includes all graph nodes
181   // created by it. 'location_ids' contains
182   // the call stack, from callee to caller.
183   // This method adds the statistics of graph nodes created by the python
184   // call.
Add(const CodeNode * node,const std::vector<uint64> & location_ids)185   void Add(const CodeNode* node, const std::vector<uint64>& location_ids) {
186     // displayed leaf might not be true leaf. Retrieve the true leaves for
187     // stats.
188     std::vector<const CodeNode*> all_leaf = FetchAllLeaf(node);
189     CHECK(!all_leaf.empty()) << node->name();
191     for (const CodeNode* cn : all_leaf) {
192       for (auto gn_it : cn->node->graph_nodes()) {
193         const TFGraphNode* gn = gn_it.second;
194         string name = gn->name();
195         // Generate a new trace name, in case the name is taken.
196         while (sample_table_.find(name) != sample_table_.end()) {
197           name += '@';
198         }
199         pprof::Sample* sample_pb = &sample_table_[name];
200         for (uint64 id : location_ids) {
201           sample_pb->mutable_location_id()->Add(id);
202         }
203         pprof::Label* label_pb = sample_pb->mutable_label()->Add();
204         label_pb->set_key(string_table_->GetIndex("graph node:"));
205         label_pb->set_str(string_table_->GetIndex(gn->name()));
207         sample_pb->mutable_value()->Add(1);
208         string type = *opts_->select.begin();
209         if (type == kShown[1]) {
210           sample_pb->mutable_value()->Add(gn->exec_micros(node->node->step()));
211         } else if (type == kShown[9]) {
212           sample_pb->mutable_value()->Add(
213               gn->accelerator_exec_micros(node->node->step()));
214         } else if (type == kShown[10]) {
215           sample_pb->mutable_value()->Add(
216               gn->cpu_exec_micros(node->node->step()));
217         } else if (type == kShown[0]) {
218           sample_pb->mutable_value()->Add(
219               gn->requested_bytes(node->node->step()));
220         } else if (type == kShown[11]) {
221           sample_pb->mutable_value()->Add(gn->peak_bytes(node->node->step()));
222         } else if (type == kShown[12]) {
223           sample_pb->mutable_value()->Add(
224               gn->residual_bytes(node->node->step()));
225         } else if (type == kShown[13]) {
226           sample_pb->mutable_value()->Add(gn->output_bytes(node->node->step()));
227         } else if (type == kShown[2]) {
228           sample_pb->mutable_value()->Add(gn->parameters());
229         } else if (type == kShown[3]) {
230           sample_pb->mutable_value()->Add(gn->float_ops(node->node->step()));
231         } else {
232           fprintf(stderr, "pprof doesn't support -select=%s\n", type.c_str());
233         }
234       }
235     }
236   }
samples() const238   const std::map<string, pprof::Sample>& samples() const {
239     return sample_table_;
240   }
242  private:
FetchAllLeaf(const CodeNode * root)243   std::vector<const CodeNode*> FetchAllLeaf(const CodeNode* root) {
244     if (root->children.empty()) {
245       return {root};
246     }
247     std::vector<const CodeNode*> ret;
248     for (auto& n : root->children) {
249       std::vector<const CodeNode*> nodes = FetchAllLeaf(n);
250       ret.insert(ret.end(), nodes.begin(), nodes.end());
251     }
252     return ret;
253   }
255   StringTable* string_table_;
256   const Options* opts_;
257   std::map<string, pprof::Sample> sample_table_;
258 };
260 class PprofProfileImpl : public PprofProfile {
261  public:
PprofProfileImpl(const Options * opts)262   explicit PprofProfileImpl(const Options* opts)
263       : opts_(opts),
264         func_table_(new FunctionTable(&string_table_)),
265         loc_table_(new LocationTable(func_table_.get())),
266         samples_(new Samples(&string_table_, opts)) {}
AddLocation(const CodeNode * callee,const CodeNode * caller)268   uint64 AddLocation(const CodeNode* callee, const CodeNode* caller) override {
269     const string& file_path = caller->file();
270     uint64 lineno = caller->lineno();
271     const string& callee_file_path = callee->file();
272     const string& callee_function = callee->function();
273     uint64 callee_func_start_line = callee->func_start_line();
275     return loc_table_->GetIndex(file_path, lineno, callee_function,
276                                 callee_file_path, callee_func_start_line);
277   }
AddSample(const CodeNode * leaf,std::vector<uint64> * call_ids)279   void AddSample(const CodeNode* leaf, std::vector<uint64>* call_ids) override {
280     std::vector<uint64> reversed_call_ids;
281     std::reverse_copy(call_ids->begin(), call_ids->end(),
282                       std::back_inserter(reversed_call_ids));
283     samples_->Add(leaf, reversed_call_ids);
284   }
WritePprofProfile(const string & filename)286   Status WritePprofProfile(const string& filename) override {
287     pprof::Profile profile_pb;
288     Build(&profile_pb);
290     std::unique_ptr<WritableFile> file;
291     Status s = Env::Default()->NewWritableFile(filename, &file);
292     if (!s.ok()) return s;
294     int32 buf_size = 1024 * 1024;
295     io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer(
296         file.get(), buf_size, buf_size, io::ZlibCompressionOptions::GZIP());
297     s = zlib_output_buffer->Init();
298     if (!s.ok()) return s;
299     s = zlib_output_buffer->Append(profile_pb.SerializeAsString());
300     if (!s.ok()) return s;
301     s = zlib_output_buffer->Close();
302     if (!s.ok()) return s;
303     fprintf(stdout, "\nRun pprof -png --nodecount=100 --sample_index=1 <%s>\n",
304             filename.c_str());
305     return s;
306   }
308  private:
Build(pprof::Profile * profile_pb)309   void Build(pprof::Profile* profile_pb) {
310     string sample_type_description = "count";
311     auto sample_type = profile_pb->mutable_sample_type()->Add();
312     sample_type->set_type(string_table_.GetIndex(sample_type_description));
313     sample_type->set_unit(string_table_.GetIndex("count"));
315     string type = *opts_->select.begin();
316     sample_type_description = type;
317     sample_type = profile_pb->mutable_sample_type()->Add();
318     sample_type->set_type(string_table_.GetIndex(sample_type_description));
319     if (type == kShown[1] || type == kShown[9] || type == kShown[10]) {
320       sample_type->set_unit(string_table_.GetIndex("microseconds"));
321       if (type == kShown[1]) {
322         profile_pb->mutable_comment()->Add(string_table_.GetIndex(
323             "Sum of accelerator execution time and cpu execution time."));
324       } else if (type == kShown[9]) {
325         profile_pb->mutable_comment()->Add(
326             string_table_.GetIndex("Accelerator execution time."));
327       } else if (type == kShown[10]) {
328         profile_pb->mutable_comment()->Add(
329             string_table_.GetIndex("CPU execution time."));
330       }
331     } else if (type == kShown[0]) {
332       sample_type->set_unit(string_table_.GetIndex("bytes"));
333       profile_pb->mutable_comment()->Add(
334           string_table_.GetIndex("Sum of operation total memory requests, "
335                                  "excluding deallocations."));
336     } else if (type == kShown[11]) {
337       sample_type->set_unit(string_table_.GetIndex("bytes"));
338       profile_pb->mutable_comment()->Add(
339           string_table_.GetIndex("Sum of operation peak memory usage."));
340     } else if (type == kShown[12]) {
341       sample_type->set_unit(string_table_.GetIndex("bytes"));
342       profile_pb->mutable_comment()->Add(string_table_.GetIndex(
343           "Sum of operation allocated memory after finish."));
344     } else if (type == kShown[13]) {
345       sample_type->set_unit(string_table_.GetIndex("bytes"));
346       profile_pb->mutable_comment()->Add(
347           string_table_.GetIndex("Sum of operation output size."));
348     } else if (type == kShown[2]) {
349       sample_type->set_unit(string_table_.GetIndex("count"));
350       profile_pb->mutable_comment()->Add(
351           string_table_.GetIndex("Model parameters."));
352     } else if (type == kShown[3]) {
353       sample_type->set_unit(string_table_.GetIndex("count"));
354       profile_pb->mutable_comment()->Add(string_table_.GetIndex(
355           "Model float operations (Only available if defined)."));
356     } else {
357       fprintf(stderr, "pprof doesn't support selecting: %s\n", type.c_str());
358     }
360     for (const string& str : string_table_.strings()) {
361       *profile_pb->mutable_string_table()->Add() = str;
362     }
363     for (const auto& sample_it : samples_->samples()) {
364       // TODO(xpan): Consider swap.
365       profile_pb->mutable_sample()->Add()->MergeFrom(sample_it.second);
366     }
367     for (const auto& function_it : func_table_->functions()) {
368       profile_pb->mutable_function()->Add()->MergeFrom(function_it.second);
369     }
370     for (const auto& location_it : loc_table_->locations()) {
371       profile_pb->mutable_location()->Add()->MergeFrom(location_it.second);
372     }
373   }
375   const Options* opts_;
376   StringTable string_table_;
377   std::unique_ptr<FunctionTable> func_table_;
378   std::unique_ptr<LocationTable> loc_table_;
379   std::unique_ptr<Samples> samples_;
380 };
381 }  // namespace
AddNode(TFGraphNode * node)383 void TFCode::AddNode(TFGraphNode* node) {
384   if (!node->call_stack() || node->call_stack()->traces().empty()) {
385     return;
386   }
387   // We infer the forward operation name from gradient op name. So, we can
388   // map gradient op traces to forward op traces.
389   // E.g. gradient node of 'inp_1/Conv2D' would be 'gradients/inp_1/Conv2D_grad.
390   string forward_name;
391   if (IsGradNode(node->name(), &forward_name)) {
392     auto grad_nodes_it = grad_nodes_.find(forward_name);
393     if (grad_nodes_it != grad_nodes_.end()) {
394       grad_nodes_it->second.push_back(node);
395     } else {
396       grad_nodes_.insert(
397           std::pair<string, std::vector<TFGraphNode*>>(forward_name, {node}));
398     }
399     return;
400   } else {
401     forward_nodes_[node->name()] = node;
402   }
404   if (!root_) {
405     graph_root_.reset(new TFMultiGraphNode(kTFProfRoot));
406     root_.reset(new CodeNode(graph_root_.get(), nullptr, ""));
407   }
409   CodeNode* pre_code_node = root_.get();
410   // TODO(xpan): Consider to release CodeDef after TFCode is built. It
411   // takes a lot of memory.
412   std::set<string> traces;
413   for (int i = 0; i < node->call_stack()->traces().size(); ++i) {
414     // Unlike op name, which is globally unique, trace name is only unique
415     // w.r.t. it's parent.
416     const string& trace = GetTraceString(node->call_stack()->traces().at(i));
417     traces.insert(trace);
418     pre_code_node = pre_code_node->AddChildren(
419         trace, &node->call_stack()->traces().at(i), "");
420     if (i == node->call_stack()->traces().size() - 1) {
421       pre_code_node->node->AddGraphNode(node);
422     }
423   }
424 }
Build()426 void TFCode::Build() {
427   int64 unaccounted_nodes = 0;
428   for (auto it : grad_nodes_) {
429     const string& forward_name = it.first;
430     auto forward_it = forward_nodes_.find(forward_name);
431     if (forward_it == forward_nodes_.end()) {
432       unaccounted_nodes += 1;
433       continue;
434     }
435     TFGraphNode* fn = forward_it->second;
436     CodeNode* leaf = nullptr;
437     CodeNode* pre_code_node = root_.get();
438     for (int i = 0; i < fn->call_stack()->traces().size(); ++i) {
439       const string& trace =
440           GetTraceString(fn->call_stack()->traces().at(i)) + kGradientSuffix;
441       pre_code_node = pre_code_node->AddChildren(
442           trace, &fn->call_stack()->traces().at(i), kGradientSuffix);
443       if (i == fn->call_stack()->traces().size() - 1) {
444         leaf = pre_code_node;
445       }
446     }
447     for (TFGraphNode* gn : it.second) {
448       leaf->node->AddGraphNode(gn);
449     }
450   }
451   if (unaccounted_nodes > 0) {
452     fprintf(stderr, "%lld gradient nodes not accounted\n", unaccounted_nodes);
453   }
454 }
ShowInternal(const Options & opts,Timeline * timeline)456 const ShowMultiNode* TFCode::ShowInternal(const Options& opts,
457                                           Timeline* timeline) {
458   root_->ResetTotalStats();
459   if (opts.output_type == kOutput[3]) {
460     if (opts.select.size() != 1) {
461       fprintf(stderr, "Can only select 1 attribute for pprof output.\n");
462       return root_.get();
463     }
464     string select = *opts.select.begin();
465     if (select != kShown[0] && select != kShown[1] && select != kShown[2] &&
466         select != kShown[3] && select != kShown[9] && select != kShown[10] &&
467         select != kShown[11] && select != kShown[12] && select != kShown[13]) {
468       fprintf(stderr, "pprof doesn't support -select=%s\n", select.c_str());
469       return root_.get();
470     }
471   }
472   if (opts.account_displayed_op_only) {
473     fprintf(stderr, "Note: code view ignores account_displayed_op_only\n");
474   }
476   std::vector<CodeNode*> roots = Account(root_->children, opts);
477   root_->show_children.clear();
478   for (CodeNode* n : roots) {
479     root_->AggregateTotalStats(n);
480   }
482   if (opts.start_name_regexes.size() != 1 ||
483       opts.start_name_regexes[0] != ".*") {
484     roots = SearchRoot(roots, opts.start_name_regexes);
485   }
487   root_->show_children.assign(roots.begin(), roots.end());
489   CodeNode* root = PrintScope({root_.get()}, opts, 1, 0)[0];
491   root->formatted_str = FormatLegend(opts) + root->formatted_str;
493   if (opts.output_type == kOutput[3]) {
494     std::vector<uint64> call_ids;
495     pprof_profile_.reset(new PprofProfileImpl(&opts));
496     Format(root, root->show_children, opts, &root->formatted_str,
497            root->mutable_proto(), &call_ids);
498     Status s = pprof_profile_->WritePprofProfile(
499         opts.output_options.at(kPprofOpts[0]));
500     if (!s.ok()) {
501       fprintf(stderr, "%s\n", s.ToString().c_str());
502     }
503   } else {
504     Format(root, root->show_children, opts, &root->formatted_str,
505            root->mutable_proto(), nullptr);
506     if (timeline) {
507       timeline->GenerateCodeTimeline(root);
508     }
509   }
510   return root;
511 }
Format(const CodeNode * root,const std::vector<CodeNode * > & nodes,const Options & opts,string * display_str,MultiGraphNodeProto * proto,std::vector<uint64> * call_ids)513 void TFCode::Format(const CodeNode* root, const std::vector<CodeNode*>& nodes,
514                     const Options& opts, string* display_str,
515                     MultiGraphNodeProto* proto, std::vector<uint64>* call_ids) {
516   if (nodes.empty() && root->has_trace() && opts.output_type == kOutput[3]) {
517     pprof_profile_->AddSample(root, call_ids);
518   }
520   for (CodeNode* node : nodes) {
521     if (root->has_trace() && opts.output_type == kOutput[3]) {
522       uint64 loc_id = pprof_profile_->AddLocation(node, root);
523       call_ids->push_back(loc_id);
524     }
525     display_str->append(node->formatted_str);
526     MultiGraphNodeProto* child = proto->add_children();
527     child->MergeFrom(node->proto());
528     Format(node, node->show_children, opts, display_str, child, call_ids);
529     if (root->has_trace() && opts.output_type == kOutput[3]) {
530       call_ids->pop_back();
531     }
532   }
533 }
SearchRoot(std::vector<CodeNode * > roots,const std::vector<string> & regexes)535 std::vector<CodeNode*> TFCode::SearchRoot(std::vector<CodeNode*> roots,
536                                           const std::vector<string>& regexes) {
537   std::vector<CodeNode*> res;
538   if (roots.empty()) {
539     return res;
540   }
541   for (CodeNode* root : roots) {
542     bool match_start_node = false;
543     for (const string& regex : regexes) {
544       if (RE2::FullMatch(root->name(), regex)) {
545         res.push_back(root);
546         match_start_node = true;
547         break;
548       }
549     }
550     if (match_start_node) {
551       // Found a start node at this branch, no need to continue.
552       continue;
553     }
554     std::vector<CodeNode*> nroots = SearchRoot(root->show_children, regexes);
555     res.insert(res.end(), nroots.begin(), nroots.end());
556   }
557   return res;
558 }
PrintScope(const std::vector<CodeNode * > roots,const Options & opts,int depth,int last_ident)560 std::vector<CodeNode*> TFCode::PrintScope(const std::vector<CodeNode*> roots,
561                                           const Options& opts, int depth,
562                                           int last_ident) {
563   std::vector<CodeNode*> show_nodes;
565   for (CodeNode* node : roots) {
566     if (ShouldTrim(node, opts.trim_name_regexes) || depth > opts.max_depth) {
567       continue;
568     }
569     int ident = last_ident;
570     bool show = ShouldShow(node, opts, depth);
571     if (show) ident += 2;
573     std::vector<CodeNode*> show_cnodes =
574         PrintScope(node->show_children, opts, depth + 1, ident);
575     if (show) {
576       node->show_children.clear();
578       show_cnodes = SortNodes(show_cnodes, opts);
579       for (CodeNode* sc : show_cnodes) {
580         node->show_children.push_back(sc);
581       }
583       node->formatted_str = FormatNode(node, opts, last_ident);
585       if (opts.select.find(kShown[4]) != opts.select.end()) {
586         fprintf(stderr, "code view has no tensor value to show\n");
587       }
588       show_nodes.push_back(node);
589     } else {
590       show_nodes.insert(show_nodes.end(), show_cnodes.begin(),
591                         show_cnodes.end());
592     }
593   }
594   return show_nodes;
595 }
Account(const std::vector<CodeNode * > & roots,const Options & opts)597 std::vector<CodeNode*> TFCode::Account(const std::vector<CodeNode*>& roots,
598                                        const Options& opts) {
599   std::vector<CodeNode*> act_nodes;
601   for (CodeNode* node : roots) {
602     node->ResetTotalStats();
603     std::vector<CodeNode*> act_cnodes = Account(node->children, opts);
604     node->account = ReAccount(node, opts);
605     if (node->account || !act_cnodes.empty()) {
606       node->show_children.clear();
607       node->ResetTotalStats();
608       node->AddSelfToTotalStats();
609       for (CodeNode* c : act_cnodes) {
610         node->AggregateTotalStats(c);
611         node->show_children.push_back(c);
612       }
613       act_nodes.push_back(node);
614     }
615   }
616   return act_nodes;
617 }
FormatNodeMemory(CodeNode * node,int64 bytes,int64 total_bytes) const619 string TFCode::FormatNodeMemory(CodeNode* node, int64 bytes,
620                                 int64 total_bytes) const {
621   string memory = FormatMemory(total_bytes);
622   if (node->account) {
623     memory = FormatMemory(bytes) + "/" + memory;
624   } else {
625     memory = "--/" + memory;
626   }
627   return memory;
628 }
FormatNode(CodeNode * node,const Options & opts,int64 indent) const630 string TFCode::FormatNode(CodeNode* node, const Options& opts,
631                           int64 indent) const {
632   std::vector<string> attrs;
633   if (opts.select.find(kShown[0]) != opts.select.end()) {
634     attrs.push_back(FormatNodeMemory(node, node->proto().requested_bytes(),
635                                      node->proto().total_requested_bytes()));
636   }
637   if (opts.select.find(kShown[11]) != opts.select.end()) {
638     attrs.push_back(FormatNodeMemory(node, node->proto().peak_bytes(),
639                                      node->proto().total_peak_bytes()));
640   }
641   if (opts.select.find(kShown[12]) != opts.select.end()) {
642     attrs.push_back(FormatNodeMemory(node, node->proto().residual_bytes(),
643                                      node->proto().total_residual_bytes()));
644   }
645   if (opts.select.find(kShown[13]) != opts.select.end()) {
646     attrs.push_back(FormatNodeMemory(node, node->proto().output_bytes(),
647                                      node->proto().total_output_bytes()));
648   }
650   std::vector<string> time_attrs = FormatTimes(node, opts);
651   attrs.insert(attrs.end(), time_attrs.begin(), time_attrs.end());
653   if (opts.select.find(kShown[2]) != opts.select.end()) {
654     string params = FormatNumber(node->proto().total_parameters()) + " params";
655     if (node->account) {
656       params = FormatNumber(node->proto().parameters()) + "/" + params;
657     } else {
658       params = "--/" + params;
659     }
660     attrs.push_back(params);
661   }
663   if (opts.select.find(kShown[3]) != opts.select.end()) {
664     string fops = FormatNumber(node->proto().total_float_ops()) + " flops";
665     if (node->account) {
666       fops = FormatNumber(node->proto().float_ops()) + "/" + fops;
667     } else {
668       fops = "--/" + fops;
669     }
670     attrs.push_back(fops);
671   }
673   if (opts.select.find(kShown[5]) != opts.select.end() &&
674       !node->node->devices().empty()) {
675     attrs.push_back(str_util::Join(node->node->devices(), "|"));
676   }
677   if (opts.select.find(kShown[6]) != opts.select.end()) {
678     std::set<string> op_types = node->node->op_types();
679     attrs.push_back(str_util::Join(op_types, "|"));
680   }
681   if (opts.select.find(kShown[7]) != opts.select.end()) {
682     // TODO(xpan): Make op count available in code view?
683     attrs.push_back(strings::Printf("%s N/A in code view", kShown[7]));
684   }
685   if (opts.select.find(kShown[8]) != opts.select.end()) {
686     attrs.push_back(strings::Printf("%s N/A in code view", kShown[8]));
687   }
689   return strings::Printf("%s%s (%s)\n", string(indent, ' ').c_str(),
690                          node->name().c_str(),
691                          str_util::Join(attrs, ", ").c_str());
692 }
693 }  // namespace tfprof
694 }  // namespace tensorflow