1 /* Copyright 2016 The TensorFlow Authors All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/profiler/internal/tfprof_code.h"
17
18 #include <stdio.h>
19 #include <utility>
20
21 #include "tensorflow/c/c_api.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/lib/io/path.h"
24 #include "tensorflow/core/lib/io/zlib_compression_options.h"
25 #include "tensorflow/core/lib/io/zlib_outputbuffer.h"
26 #include "tensorflow/core/lib/strings/str_util.h"
27 #include "tensorflow/core/lib/strings/strcat.h"
28 #include "tensorflow/core/lib/strings/stringprintf.h"
29 #include "tensorflow/core/platform/regexp.h"
30 #include "tensorflow/core/profiler/internal/tfprof_constants.h"
31
32 namespace tensorflow {
33 namespace tfprof {
34 namespace {
35
// Appended to the trace names of code nodes that hold gradient ops, so that a
// forward call site and its gradient show up as separate, labeled entries.
constexpr char kGradientSuffix[] = " (gradient)";
37
38 // Convert to Trace proto into a short readable string.
GetTraceString(const CallStack::Trace & trace)39 string GetTraceString(const CallStack::Trace& trace) {
40 string ntrace(io::Basename(trace.file()));
41 ntrace += strings::StrCat(":", trace.lineno());
42 if (trace.function().length() < 20) {
43 ntrace += ":" + trace.function();
44 } else {
45 ntrace += ":" + trace.function().substr(0, 17) + "...";
46 }
47 return ntrace;
48 }
49
// Decides whether `name` names a gradient op and, if so, extracts the name of
// the forward op it belongs to.
//
// For a forward op named `op`, its gradient ops are named
// "...gradients/op_grad/...". On success, `*forward_name` receives the text
// between the first "gradients/" and the first "_grad/" that follows it, and
// true is returned. Otherwise false is returned and `*forward_name` is left
// untouched.
// TODO(xpan): This is hacky.
bool IsGradNode(const string& name, string* forward_name) {
  static const char kGradPrefix[] = "gradients/";
  static const char kGradSuffix[] = "_grad/";

  const size_t prefix_pos = name.find(kGradPrefix);
  if (prefix_pos == string::npos) {
    return false;
  }
  // Start of the forward name, right after "gradients/".
  const size_t start = prefix_pos + sizeof(kGradPrefix) - 1;
  // Only look for the suffix after the prefix: an "_grad/" that appears
  // earlier in the name would give a negative length, which wraps around in
  // unsigned arithmetic and produces a bogus forward name.
  const size_t suffix_pos = name.find(kGradSuffix, start);
  if (suffix_pos == string::npos || suffix_pos == start) {
    return false;
  }
  *forward_name = name.substr(start, suffix_pos - start);
  return true;
}
67
// StringTable interns strings: every distinct string gets a dense id, handed
// out in insertion order, which is also its position in strings().
class StringTable {
 public:
  StringTable() {
    // Pprof requires first entry in string_table to be ''.
    string_id_[""] = 0;
    all_strings_.push_back("");
  }

  // Returns the id of `str`. If `str` has not been seen yet, it is appended to
  // the table and receives the next free id.
  uint64 GetIndex(const string& str) {
    // A single tree search serves both paths: lower_bound yields either the
    // existing entry or the correct insertion hint for a new one.
    auto it = string_id_.lower_bound(str);
    if (it != string_id_.end() && it->first == str) {
      return it->second;
    }
    const uint64 id = all_strings_.size();
    string_id_.emplace_hint(it, str, id);
    all_strings_.push_back(str);
    return id;
  }

  // All interned strings, indexed by id.
  const std::vector<string>& strings() const { return all_strings_; }

 private:
  std::map<string, uint64> string_id_;
  std::vector<string> all_strings_;
};
95
96 // FunctionTable maps each function to an id.
97 class FunctionTable {
98 public:
FunctionTable(StringTable * string_table)99 explicit FunctionTable(StringTable* string_table)
100 : string_table_(string_table) {}
101
102 // Returns the index of a function. If not found, adds a function proto
103 // and returns the function index.
GetIndex(const string & file_path,const string & func_name,uint64 func_start_line)104 uint64 GetIndex(const string& file_path, const string& func_name,
105 uint64 func_start_line) {
106 auto key = std::tuple<string, string, uint64>(file_path, func_name,
107 func_start_line);
108 auto idx = function_table_.find(key);
109 if (idx != function_table_.end()) {
110 return idx->second.id();
111 }
112 pprof::Function* func_pb = &function_table_[key];
113 // function index should start from 1.
114 func_pb->set_id(function_table_.size());
115
116 string file_base(io::Basename(file_path));
117 file_base = file_base.substr(0, file_base.find_last_of("."));
118 func_pb->set_name(
119 string_table_->GetIndex(strings::StrCat(file_base, ":", func_name)));
120 func_pb->set_filename(string_table_->GetIndex(file_path));
121 func_pb->set_start_line(func_start_line);
122 return func_pb->id();
123 }
124
125 const std::map<std::tuple<string, string, uint64>, pprof::Function>&
functions() const126 functions() const {
127 return function_table_;
128 }
129
130 private:
131 StringTable* string_table_;
132 std::map<std::tuple<string, string, uint64>, pprof::Function> function_table_;
133 };
134
135 // LocationTable maps each function call to an id.
136 class LocationTable {
137 public:
LocationTable(FunctionTable * function_table)138 explicit LocationTable(FunctionTable* function_table)
139 : function_table_(function_table) {}
140
141 // Returns the index of a function call localtion. If not found, adds a
142 // location proto and returns the location index.
GetIndex(const string & file_path,uint64 line_number,const string & called_function_name,const string & called_file_path,uint64 called_func_start_line)143 uint64 GetIndex(const string& file_path, uint64 line_number,
144 const string& called_function_name,
145 const string& called_file_path,
146 uint64 called_func_start_line) {
147 auto key = std::tuple<string, string, uint64>(
148 file_path, called_function_name, line_number);
149
150 auto idx = location_table_.find(key);
151 if (idx != location_table_.end()) {
152 return idx->second.id();
153 }
154 pprof::Location* location_pb = &location_table_[key];
155 location_pb->set_id(location_table_.size());
156 pprof::Line* line_pb = location_pb->add_line();
157 line_pb->set_function_id(function_table_->GetIndex(
158 called_file_path, called_function_name, called_func_start_line));
159 line_pb->set_line(line_number);
160 return location_pb->id();
161 }
162
163 const std::map<std::tuple<string, string, uint64>, pprof::Location>&
locations() const164 locations() const {
165 return location_table_;
166 }
167
168 private:
169 FunctionTable* function_table_;
170 std::map<std::tuple<string, string, uint64>, pprof::Location> location_table_;
171 };
172
173 // Samples stores samples of all calls. A sample is a single call trace,
174 // that is, the call path from top caller to the leaf callee.
175 class Samples {
176 public:
Samples(StringTable * string_table,const Options * opts)177 explicit Samples(StringTable* string_table, const Options* opts)
178 : string_table_(string_table), opts_(opts) {}
179
180 // 'node' is the leaf of the displayed trace. It includes all graph nodes
181 // created by it. 'location_ids' contains
182 // the call stack, from callee to caller.
183 // This method adds the statistics of graph nodes created by the python
184 // call.
Add(const CodeNode * node,const std::vector<uint64> & location_ids)185 void Add(const CodeNode* node, const std::vector<uint64>& location_ids) {
186 // displayed leaf might not be true leaf. Retrieve the true leaves for
187 // stats.
188 std::vector<const CodeNode*> all_leaf = FetchAllLeaf(node);
189 CHECK(!all_leaf.empty()) << node->name();
190
191 for (const CodeNode* cn : all_leaf) {
192 for (auto gn_it : cn->node->graph_nodes()) {
193 const TFGraphNode* gn = gn_it.second;
194 string name = gn->name();
195 // Generate a new trace name, in case the name is taken.
196 while (sample_table_.find(name) != sample_table_.end()) {
197 name += '@';
198 }
199 pprof::Sample* sample_pb = &sample_table_[name];
200 for (uint64 id : location_ids) {
201 sample_pb->mutable_location_id()->Add(id);
202 }
203 pprof::Label* label_pb = sample_pb->mutable_label()->Add();
204 label_pb->set_key(string_table_->GetIndex("graph node:"));
205 label_pb->set_str(string_table_->GetIndex(gn->name()));
206
207 sample_pb->mutable_value()->Add(1);
208 string type = *opts_->select.begin();
209 if (type == kShown[1]) {
210 sample_pb->mutable_value()->Add(gn->exec_micros(node->node->step()));
211 } else if (type == kShown[9]) {
212 sample_pb->mutable_value()->Add(
213 gn->accelerator_exec_micros(node->node->step()));
214 } else if (type == kShown[10]) {
215 sample_pb->mutable_value()->Add(
216 gn->cpu_exec_micros(node->node->step()));
217 } else if (type == kShown[0]) {
218 sample_pb->mutable_value()->Add(
219 gn->requested_bytes(node->node->step()));
220 } else if (type == kShown[11]) {
221 sample_pb->mutable_value()->Add(gn->peak_bytes(node->node->step()));
222 } else if (type == kShown[12]) {
223 sample_pb->mutable_value()->Add(
224 gn->residual_bytes(node->node->step()));
225 } else if (type == kShown[13]) {
226 sample_pb->mutable_value()->Add(gn->output_bytes(node->node->step()));
227 } else if (type == kShown[2]) {
228 sample_pb->mutable_value()->Add(gn->parameters());
229 } else if (type == kShown[3]) {
230 sample_pb->mutable_value()->Add(gn->float_ops(node->node->step()));
231 } else {
232 fprintf(stderr, "pprof doesn't support -select=%s\n", type.c_str());
233 }
234 }
235 }
236 }
237
samples() const238 const std::map<string, pprof::Sample>& samples() const {
239 return sample_table_;
240 }
241
242 private:
FetchAllLeaf(const CodeNode * root)243 std::vector<const CodeNode*> FetchAllLeaf(const CodeNode* root) {
244 if (root->children.empty()) {
245 return {root};
246 }
247 std::vector<const CodeNode*> ret;
248 for (auto& n : root->children) {
249 std::vector<const CodeNode*> nodes = FetchAllLeaf(n);
250 ret.insert(ret.end(), nodes.begin(), nodes.end());
251 }
252 return ret;
253 }
254
255 StringTable* string_table_;
256 const Options* opts_;
257 std::map<string, pprof::Sample> sample_table_;
258 };
259
260 class PprofProfileImpl : public PprofProfile {
261 public:
PprofProfileImpl(const Options * opts)262 explicit PprofProfileImpl(const Options* opts)
263 : opts_(opts),
264 func_table_(new FunctionTable(&string_table_)),
265 loc_table_(new LocationTable(func_table_.get())),
266 samples_(new Samples(&string_table_, opts)) {}
267
AddLocation(const CodeNode * callee,const CodeNode * caller)268 uint64 AddLocation(const CodeNode* callee, const CodeNode* caller) override {
269 const string& file_path = caller->file();
270 uint64 lineno = caller->lineno();
271 const string& callee_file_path = callee->file();
272 const string& callee_function = callee->function();
273 uint64 callee_func_start_line = callee->func_start_line();
274
275 return loc_table_->GetIndex(file_path, lineno, callee_function,
276 callee_file_path, callee_func_start_line);
277 }
278
AddSample(const CodeNode * leaf,std::vector<uint64> * call_ids)279 void AddSample(const CodeNode* leaf, std::vector<uint64>* call_ids) override {
280 std::vector<uint64> reversed_call_ids;
281 std::reverse_copy(call_ids->begin(), call_ids->end(),
282 std::back_inserter(reversed_call_ids));
283 samples_->Add(leaf, reversed_call_ids);
284 }
285
WritePprofProfile(const string & filename)286 Status WritePprofProfile(const string& filename) override {
287 pprof::Profile profile_pb;
288 Build(&profile_pb);
289
290 std::unique_ptr<WritableFile> file;
291 Status s = Env::Default()->NewWritableFile(filename, &file);
292 if (!s.ok()) return s;
293
294 int32 buf_size = 1024 * 1024;
295 io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer(
296 file.get(), buf_size, buf_size, io::ZlibCompressionOptions::GZIP());
297 s = zlib_output_buffer->Init();
298 if (!s.ok()) return s;
299 s = zlib_output_buffer->Append(profile_pb.SerializeAsString());
300 if (!s.ok()) return s;
301 s = zlib_output_buffer->Close();
302 if (!s.ok()) return s;
303 fprintf(stdout, "\nRun pprof -png --nodecount=100 --sample_index=1 <%s>\n",
304 filename.c_str());
305 return s;
306 }
307
308 private:
Build(pprof::Profile * profile_pb)309 void Build(pprof::Profile* profile_pb) {
310 string sample_type_description = "count";
311 auto sample_type = profile_pb->mutable_sample_type()->Add();
312 sample_type->set_type(string_table_.GetIndex(sample_type_description));
313 sample_type->set_unit(string_table_.GetIndex("count"));
314
315 string type = *opts_->select.begin();
316 sample_type_description = type;
317 sample_type = profile_pb->mutable_sample_type()->Add();
318 sample_type->set_type(string_table_.GetIndex(sample_type_description));
319 if (type == kShown[1] || type == kShown[9] || type == kShown[10]) {
320 sample_type->set_unit(string_table_.GetIndex("microseconds"));
321 if (type == kShown[1]) {
322 profile_pb->mutable_comment()->Add(string_table_.GetIndex(
323 "Sum of accelerator execution time and cpu execution time."));
324 } else if (type == kShown[9]) {
325 profile_pb->mutable_comment()->Add(
326 string_table_.GetIndex("Accelerator execution time."));
327 } else if (type == kShown[10]) {
328 profile_pb->mutable_comment()->Add(
329 string_table_.GetIndex("CPU execution time."));
330 }
331 } else if (type == kShown[0]) {
332 sample_type->set_unit(string_table_.GetIndex("bytes"));
333 profile_pb->mutable_comment()->Add(
334 string_table_.GetIndex("Sum of operation total memory requests, "
335 "excluding deallocations."));
336 } else if (type == kShown[11]) {
337 sample_type->set_unit(string_table_.GetIndex("bytes"));
338 profile_pb->mutable_comment()->Add(
339 string_table_.GetIndex("Sum of operation peak memory usage."));
340 } else if (type == kShown[12]) {
341 sample_type->set_unit(string_table_.GetIndex("bytes"));
342 profile_pb->mutable_comment()->Add(string_table_.GetIndex(
343 "Sum of operation allocated memory after finish."));
344 } else if (type == kShown[13]) {
345 sample_type->set_unit(string_table_.GetIndex("bytes"));
346 profile_pb->mutable_comment()->Add(
347 string_table_.GetIndex("Sum of operation output size."));
348 } else if (type == kShown[2]) {
349 sample_type->set_unit(string_table_.GetIndex("count"));
350 profile_pb->mutable_comment()->Add(
351 string_table_.GetIndex("Model parameters."));
352 } else if (type == kShown[3]) {
353 sample_type->set_unit(string_table_.GetIndex("count"));
354 profile_pb->mutable_comment()->Add(string_table_.GetIndex(
355 "Model float operations (Only available if defined)."));
356 } else {
357 fprintf(stderr, "pprof doesn't support selecting: %s\n", type.c_str());
358 }
359
360 for (const string& str : string_table_.strings()) {
361 *profile_pb->mutable_string_table()->Add() = str;
362 }
363 for (const auto& sample_it : samples_->samples()) {
364 // TODO(xpan): Consider swap.
365 profile_pb->mutable_sample()->Add()->MergeFrom(sample_it.second);
366 }
367 for (const auto& function_it : func_table_->functions()) {
368 profile_pb->mutable_function()->Add()->MergeFrom(function_it.second);
369 }
370 for (const auto& location_it : loc_table_->locations()) {
371 profile_pb->mutable_location()->Add()->MergeFrom(location_it.second);
372 }
373 }
374
375 const Options* opts_;
376 StringTable string_table_;
377 std::unique_ptr<FunctionTable> func_table_;
378 std::unique_ptr<LocationTable> loc_table_;
379 std::unique_ptr<Samples> samples_;
380 };
381 } // namespace
382
AddNode(TFGraphNode * node)383 void TFCode::AddNode(TFGraphNode* node) {
384 if (!node->call_stack() || node->call_stack()->traces().empty()) {
385 return;
386 }
387 // We infer the forward operation name from gradient op name. So, we can
388 // map gradient op traces to forward op traces.
389 // E.g. gradient node of 'inp_1/Conv2D' would be 'gradients/inp_1/Conv2D_grad.
390 string forward_name;
391 if (IsGradNode(node->name(), &forward_name)) {
392 auto grad_nodes_it = grad_nodes_.find(forward_name);
393 if (grad_nodes_it != grad_nodes_.end()) {
394 grad_nodes_it->second.push_back(node);
395 } else {
396 grad_nodes_.insert(
397 std::pair<string, std::vector<TFGraphNode*>>(forward_name, {node}));
398 }
399 return;
400 } else {
401 forward_nodes_[node->name()] = node;
402 }
403
404 if (!root_) {
405 graph_root_.reset(new TFMultiGraphNode(kTFProfRoot));
406 root_.reset(new CodeNode(graph_root_.get(), nullptr, ""));
407 }
408
409 CodeNode* pre_code_node = root_.get();
410 // TODO(xpan): Consider to release CodeDef after TFCode is built. It
411 // takes a lot of memory.
412 std::set<string> traces;
413 for (int i = 0; i < node->call_stack()->traces().size(); ++i) {
414 // Unlike op name, which is globally unique, trace name is only unique
415 // w.r.t. it's parent.
416 const string& trace = GetTraceString(node->call_stack()->traces().at(i));
417 traces.insert(trace);
418 pre_code_node = pre_code_node->AddChildren(
419 trace, &node->call_stack()->traces().at(i), "");
420 if (i == node->call_stack()->traces().size() - 1) {
421 pre_code_node->node->AddGraphNode(node);
422 }
423 }
424 }
425
Build()426 void TFCode::Build() {
427 int64 unaccounted_nodes = 0;
428 for (auto it : grad_nodes_) {
429 const string& forward_name = it.first;
430 auto forward_it = forward_nodes_.find(forward_name);
431 if (forward_it == forward_nodes_.end()) {
432 unaccounted_nodes += 1;
433 continue;
434 }
435 TFGraphNode* fn = forward_it->second;
436 CodeNode* leaf = nullptr;
437 CodeNode* pre_code_node = root_.get();
438 for (int i = 0; i < fn->call_stack()->traces().size(); ++i) {
439 const string& trace =
440 GetTraceString(fn->call_stack()->traces().at(i)) + kGradientSuffix;
441 pre_code_node = pre_code_node->AddChildren(
442 trace, &fn->call_stack()->traces().at(i), kGradientSuffix);
443 if (i == fn->call_stack()->traces().size() - 1) {
444 leaf = pre_code_node;
445 }
446 }
447 for (TFGraphNode* gn : it.second) {
448 leaf->node->AddGraphNode(gn);
449 }
450 }
451 if (unaccounted_nodes > 0) {
452 fprintf(stderr, "%lld gradient nodes not accounted\n", unaccounted_nodes);
453 }
454 }
455
ShowInternal(const Options & opts,Timeline * timeline)456 const ShowMultiNode* TFCode::ShowInternal(const Options& opts,
457 Timeline* timeline) {
458 root_->ResetTotalStats();
459 if (opts.output_type == kOutput[3]) {
460 if (opts.select.size() != 1) {
461 fprintf(stderr, "Can only select 1 attribute for pprof output.\n");
462 return root_.get();
463 }
464 string select = *opts.select.begin();
465 if (select != kShown[0] && select != kShown[1] && select != kShown[2] &&
466 select != kShown[3] && select != kShown[9] && select != kShown[10] &&
467 select != kShown[11] && select != kShown[12] && select != kShown[13]) {
468 fprintf(stderr, "pprof doesn't support -select=%s\n", select.c_str());
469 return root_.get();
470 }
471 }
472 if (opts.account_displayed_op_only) {
473 fprintf(stderr, "Note: code view ignores account_displayed_op_only\n");
474 }
475
476 std::vector<CodeNode*> roots = Account(root_->children, opts);
477 root_->show_children.clear();
478 for (CodeNode* n : roots) {
479 root_->AggregateTotalStats(n);
480 }
481
482 if (opts.start_name_regexes.size() != 1 ||
483 opts.start_name_regexes[0] != ".*") {
484 roots = SearchRoot(roots, opts.start_name_regexes);
485 }
486
487 root_->show_children.assign(roots.begin(), roots.end());
488
489 CodeNode* root = PrintScope({root_.get()}, opts, 1, 0)[0];
490
491 root->formatted_str = FormatLegend(opts) + root->formatted_str;
492
493 if (opts.output_type == kOutput[3]) {
494 std::vector<uint64> call_ids;
495 pprof_profile_.reset(new PprofProfileImpl(&opts));
496 Format(root, root->show_children, opts, &root->formatted_str,
497 root->mutable_proto(), &call_ids);
498 Status s = pprof_profile_->WritePprofProfile(
499 opts.output_options.at(kPprofOpts[0]));
500 if (!s.ok()) {
501 fprintf(stderr, "%s\n", s.ToString().c_str());
502 }
503 } else {
504 Format(root, root->show_children, opts, &root->formatted_str,
505 root->mutable_proto(), nullptr);
506 if (timeline) {
507 timeline->GenerateCodeTimeline(root);
508 }
509 }
510 return root;
511 }
512
Format(const CodeNode * root,const std::vector<CodeNode * > & nodes,const Options & opts,string * display_str,MultiGraphNodeProto * proto,std::vector<uint64> * call_ids)513 void TFCode::Format(const CodeNode* root, const std::vector<CodeNode*>& nodes,
514 const Options& opts, string* display_str,
515 MultiGraphNodeProto* proto, std::vector<uint64>* call_ids) {
516 if (nodes.empty() && root->has_trace() && opts.output_type == kOutput[3]) {
517 pprof_profile_->AddSample(root, call_ids);
518 }
519
520 for (CodeNode* node : nodes) {
521 if (root->has_trace() && opts.output_type == kOutput[3]) {
522 uint64 loc_id = pprof_profile_->AddLocation(node, root);
523 call_ids->push_back(loc_id);
524 }
525 display_str->append(node->formatted_str);
526 MultiGraphNodeProto* child = proto->add_children();
527 child->MergeFrom(node->proto());
528 Format(node, node->show_children, opts, display_str, child, call_ids);
529 if (root->has_trace() && opts.output_type == kOutput[3]) {
530 call_ids->pop_back();
531 }
532 }
533 }
534
SearchRoot(std::vector<CodeNode * > roots,const std::vector<string> & regexes)535 std::vector<CodeNode*> TFCode::SearchRoot(std::vector<CodeNode*> roots,
536 const std::vector<string>& regexes) {
537 std::vector<CodeNode*> res;
538 if (roots.empty()) {
539 return res;
540 }
541 for (CodeNode* root : roots) {
542 bool match_start_node = false;
543 for (const string& regex : regexes) {
544 if (RE2::FullMatch(root->name(), regex)) {
545 res.push_back(root);
546 match_start_node = true;
547 break;
548 }
549 }
550 if (match_start_node) {
551 // Found a start node at this branch, no need to continue.
552 continue;
553 }
554 std::vector<CodeNode*> nroots = SearchRoot(root->show_children, regexes);
555 res.insert(res.end(), nroots.begin(), nroots.end());
556 }
557 return res;
558 }
559
PrintScope(const std::vector<CodeNode * > roots,const Options & opts,int depth,int last_ident)560 std::vector<CodeNode*> TFCode::PrintScope(const std::vector<CodeNode*> roots,
561 const Options& opts, int depth,
562 int last_ident) {
563 std::vector<CodeNode*> show_nodes;
564
565 for (CodeNode* node : roots) {
566 if (ShouldTrim(node, opts.trim_name_regexes) || depth > opts.max_depth) {
567 continue;
568 }
569 int ident = last_ident;
570 bool show = ShouldShow(node, opts, depth);
571 if (show) ident += 2;
572
573 std::vector<CodeNode*> show_cnodes =
574 PrintScope(node->show_children, opts, depth + 1, ident);
575 if (show) {
576 node->show_children.clear();
577
578 show_cnodes = SortNodes(show_cnodes, opts);
579 for (CodeNode* sc : show_cnodes) {
580 node->show_children.push_back(sc);
581 }
582
583 node->formatted_str = FormatNode(node, opts, last_ident);
584
585 if (opts.select.find(kShown[4]) != opts.select.end()) {
586 fprintf(stderr, "code view has no tensor value to show\n");
587 }
588 show_nodes.push_back(node);
589 } else {
590 show_nodes.insert(show_nodes.end(), show_cnodes.begin(),
591 show_cnodes.end());
592 }
593 }
594 return show_nodes;
595 }
596
Account(const std::vector<CodeNode * > & roots,const Options & opts)597 std::vector<CodeNode*> TFCode::Account(const std::vector<CodeNode*>& roots,
598 const Options& opts) {
599 std::vector<CodeNode*> act_nodes;
600
601 for (CodeNode* node : roots) {
602 node->ResetTotalStats();
603 std::vector<CodeNode*> act_cnodes = Account(node->children, opts);
604 node->account = ReAccount(node, opts);
605 if (node->account || !act_cnodes.empty()) {
606 node->show_children.clear();
607 node->ResetTotalStats();
608 node->AddSelfToTotalStats();
609 for (CodeNode* c : act_cnodes) {
610 node->AggregateTotalStats(c);
611 node->show_children.push_back(c);
612 }
613 act_nodes.push_back(node);
614 }
615 }
616 return act_nodes;
617 }
618
FormatNodeMemory(CodeNode * node,int64 bytes,int64 total_bytes) const619 string TFCode::FormatNodeMemory(CodeNode* node, int64 bytes,
620 int64 total_bytes) const {
621 string memory = FormatMemory(total_bytes);
622 if (node->account) {
623 memory = FormatMemory(bytes) + "/" + memory;
624 } else {
625 memory = "--/" + memory;
626 }
627 return memory;
628 }
629
FormatNode(CodeNode * node,const Options & opts,int64 indent) const630 string TFCode::FormatNode(CodeNode* node, const Options& opts,
631 int64 indent) const {
632 std::vector<string> attrs;
633 if (opts.select.find(kShown[0]) != opts.select.end()) {
634 attrs.push_back(FormatNodeMemory(node, node->proto().requested_bytes(),
635 node->proto().total_requested_bytes()));
636 }
637 if (opts.select.find(kShown[11]) != opts.select.end()) {
638 attrs.push_back(FormatNodeMemory(node, node->proto().peak_bytes(),
639 node->proto().total_peak_bytes()));
640 }
641 if (opts.select.find(kShown[12]) != opts.select.end()) {
642 attrs.push_back(FormatNodeMemory(node, node->proto().residual_bytes(),
643 node->proto().total_residual_bytes()));
644 }
645 if (opts.select.find(kShown[13]) != opts.select.end()) {
646 attrs.push_back(FormatNodeMemory(node, node->proto().output_bytes(),
647 node->proto().total_output_bytes()));
648 }
649
650 std::vector<string> time_attrs = FormatTimes(node, opts);
651 attrs.insert(attrs.end(), time_attrs.begin(), time_attrs.end());
652
653 if (opts.select.find(kShown[2]) != opts.select.end()) {
654 string params = FormatNumber(node->proto().total_parameters()) + " params";
655 if (node->account) {
656 params = FormatNumber(node->proto().parameters()) + "/" + params;
657 } else {
658 params = "--/" + params;
659 }
660 attrs.push_back(params);
661 }
662
663 if (opts.select.find(kShown[3]) != opts.select.end()) {
664 string fops = FormatNumber(node->proto().total_float_ops()) + " flops";
665 if (node->account) {
666 fops = FormatNumber(node->proto().float_ops()) + "/" + fops;
667 } else {
668 fops = "--/" + fops;
669 }
670 attrs.push_back(fops);
671 }
672
673 if (opts.select.find(kShown[5]) != opts.select.end() &&
674 !node->node->devices().empty()) {
675 attrs.push_back(str_util::Join(node->node->devices(), "|"));
676 }
677 if (opts.select.find(kShown[6]) != opts.select.end()) {
678 std::set<string> op_types = node->node->op_types();
679 attrs.push_back(str_util::Join(op_types, "|"));
680 }
681 if (opts.select.find(kShown[7]) != opts.select.end()) {
682 // TODO(xpan): Make op count available in code view?
683 attrs.push_back(strings::Printf("%s N/A in code view", kShown[7]));
684 }
685 if (opts.select.find(kShown[8]) != opts.select.end()) {
686 attrs.push_back(strings::Printf("%s N/A in code view", kShown[8]));
687 }
688
689 return strings::Printf("%s%s (%s)\n", string(indent, ' ').c_str(),
690 node->name().c_str(),
691 str_util::Join(attrs, ", ").c_str());
692 }
693 } // namespace tfprof
694 } // namespace tensorflow
695