1 /* Copyright 2016 The TensorFlow Authors All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/profiler/internal/tfprof_code.h"
17
18 #include <stdio.h>
19
20 #include <utility>
21
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_format.h"
24 #include "tensorflow/c/c_api.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/lib/io/path.h"
27 #include "tensorflow/core/lib/io/zlib_compression_options.h"
28 #include "tensorflow/core/lib/io/zlib_outputbuffer.h"
29 #include "tensorflow/core/platform/regexp.h"
30 #include "tensorflow/core/profiler/internal/tfprof_constants.h"
31
32 namespace tensorflow {
33 namespace tfprof {
34 namespace {
35
// Appended to a forward-op trace name to label the gradient counterpart that
// TFCode::Build() attaches to the same call stack.
const char* const kGradientSuffix = " (gradient)";
37
38 // Convert to Trace proto into a short readable string.
GetTraceString(const CallStack::Trace & trace)39 std::string GetTraceString(const CallStack::Trace& trace) {
40 std::string ntrace =
41 absl::StrCat(io::Basename(trace.file()), ":", trace.lineno());
42 if (trace.function().length() < 20) {
43 absl::StrAppend(&ntrace, ":", trace.function());
44 } else {
45 absl::StrAppend(&ntrace, ":", trace.function().substr(0, 17), "...");
46 }
47 return ntrace;
48 }
49
// Returns true if 'name' names a gradient op generated for a forward op, and
// stores the inferred forward-op name in '*forward_name'. Gradient ops follow
// the pattern: ...gradients/<forward_name>_grad/...
// TODO(xpan): This is hacky.
bool IsGradNode(const std::string& name, std::string* forward_name) {
  auto grad_prefix = name.find("gradients/");
  auto grad_suffix = name.find("_grad/");
  if (grad_prefix == name.npos || grad_suffix == name.npos) {
    return false;
  }
  auto start = grad_prefix + std::string("gradients/").length();
  // find() returns unsigned offsets, so the previous
  // 'len = grad_suffix - start; if (len <= 0)' check wrapped around whenever
  // "_grad/" appeared before "gradients/" and silently accepted a bogus
  // forward name. Compare the offsets directly instead.
  if (grad_suffix <= start) {
    return false;
  }
  *forward_name = name.substr(start, grad_suffix - start);
  return true;
}
67
68 // StringTable maps each string to an id.
69 class StringTable {
70 public:
StringTable()71 StringTable() {
72 // Pprof requires first entry in string_table to be ''.
73 string_id_[""] = 0;
74 all_strings_.push_back("");
75 }
76
77 // Returns the index of a string. If not found, inserts the string and
78 // return the inserted index.
GetIndex(const string & str)79 uint64 GetIndex(const string& str) {
80 auto idx = string_id_.find(str);
81 if (idx != string_id_.end()) {
82 return idx->second;
83 }
84 all_strings_.push_back(str);
85 return string_id_.insert(std::pair<string, int64>(str, string_id_.size()))
86 .first->second;
87 }
88
strings() const89 const std::vector<string>& strings() const { return all_strings_; }
90
91 private:
92 std::map<string, uint64> string_id_;
93 std::vector<string> all_strings_;
94 };
95
96 // FunctionTable maps each function to an id.
97 class FunctionTable {
98 public:
FunctionTable(StringTable * string_table)99 explicit FunctionTable(StringTable* string_table)
100 : string_table_(string_table) {}
101
102 // Returns the index of a function. If not found, adds a function proto
103 // and returns the function index.
GetIndex(const string & file_path,const string & func_name,uint64 func_start_line)104 uint64 GetIndex(const string& file_path, const string& func_name,
105 uint64 func_start_line) {
106 auto key = std::tuple<string, string, uint64>(file_path, func_name,
107 func_start_line);
108 auto idx = function_table_.find(key);
109 if (idx != function_table_.end()) {
110 return idx->second.id();
111 }
112 pprof::Function* func_pb = &function_table_[key];
113 // function index should start from 1.
114 func_pb->set_id(function_table_.size());
115
116 string file_base(io::Basename(file_path));
117 file_base = file_base.substr(0, file_base.find_last_of('.'));
118 func_pb->set_name(
119 string_table_->GetIndex(absl::StrCat(file_base, ":", func_name)));
120 func_pb->set_filename(string_table_->GetIndex(file_path));
121 func_pb->set_start_line(func_start_line);
122 return func_pb->id();
123 }
124
125 const std::map<std::tuple<string, string, uint64>, pprof::Function>&
functions() const126 functions() const {
127 return function_table_;
128 }
129
130 private:
131 StringTable* string_table_;
132 std::map<std::tuple<string, string, uint64>, pprof::Function> function_table_;
133 };
134
135 // LocationTable maps each function call to an id.
136 class LocationTable {
137 public:
LocationTable(FunctionTable * function_table)138 explicit LocationTable(FunctionTable* function_table)
139 : function_table_(function_table) {}
140
141 // Returns the index of a function call location. If not found, adds a
142 // location proto and returns the location index.
GetIndex(const string & file_path,uint64 line_number,const string & called_function_name,const string & called_file_path,uint64 called_func_start_line)143 uint64 GetIndex(const string& file_path, uint64 line_number,
144 const string& called_function_name,
145 const string& called_file_path,
146 uint64 called_func_start_line) {
147 auto key = std::tuple<string, string, uint64>(
148 file_path, called_function_name, line_number);
149
150 auto idx = location_table_.find(key);
151 if (idx != location_table_.end()) {
152 return idx->second.id();
153 }
154 pprof::Location* location_pb = &location_table_[key];
155 location_pb->set_id(location_table_.size());
156 pprof::Line* line_pb = location_pb->add_line();
157 line_pb->set_function_id(function_table_->GetIndex(
158 called_file_path, called_function_name, called_func_start_line));
159 line_pb->set_line(line_number);
160 return location_pb->id();
161 }
162
163 const std::map<std::tuple<string, string, uint64>, pprof::Location>&
locations() const164 locations() const {
165 return location_table_;
166 }
167
168 private:
169 FunctionTable* function_table_;
170 std::map<std::tuple<string, string, uint64>, pprof::Location> location_table_;
171 };
172
// Samples stores samples of all calls. A sample is a single call trace,
// that is, the call path from top caller to the leaf callee.
class Samples {
 public:
  explicit Samples(StringTable* string_table, const Options* opts)
      : string_table_(string_table), opts_(opts) {}

  // 'node' is the leaf of the displayed trace. It includes all graph nodes
  // created by it. 'location_ids' contains
  // the call stack, from callee to caller.
  // This method adds the statistics of graph nodes created by the python
  // call.
  void Add(const CodeNode* node, const std::vector<uint64>& location_ids) {
    // displayed leaf might not be true leaf. Retrieve the true leaves for
    // stats.
    std::vector<const CodeNode*> all_leaf = FetchAllLeaf(node);
    CHECK(!all_leaf.empty()) << node->name();

    // One pprof sample is emitted per graph node under each true leaf.
    for (const CodeNode* cn : all_leaf) {
      for (const auto& gn_it : cn->node->graph_nodes()) {
        const TFGraphNode* gn = gn_it.second;
        string name = gn->name();
        // Generate a new trace name, in case the name is taken.
        while (sample_table_.find(name) != sample_table_.end()) {
          name += '@';
        }
        pprof::Sample* sample_pb = &sample_table_[name];
        for (uint64 id : location_ids) {
          sample_pb->mutable_location_id()->Add(id);
        }
        // Label each sample with the originating graph node's name.
        pprof::Label* label_pb = sample_pb->mutable_label()->Add();
        label_pb->set_key(string_table_->GetIndex("graph node:"));
        label_pb->set_str(string_table_->GetIndex(gn->name()));

        // First value: plain count (always 1 per graph node). Second value:
        // the metric chosen via -select; units/descriptions are registered in
        // PprofProfileImpl::Build().
        sample_pb->mutable_value()->Add(1);
        string type = *opts_->select.begin();
        if (type == kShown[1]) {
          // Total (accelerator + cpu) execution time in microseconds.
          sample_pb->mutable_value()->Add(gn->exec_micros(node->node->step()));
        } else if (type == kShown[9]) {
          // Accelerator execution time.
          sample_pb->mutable_value()->Add(
              gn->accelerator_exec_micros(node->node->step()));
        } else if (type == kShown[10]) {
          // CPU execution time.
          sample_pb->mutable_value()->Add(
              gn->cpu_exec_micros(node->node->step()));
        } else if (type == kShown[0]) {
          // Requested memory bytes.
          sample_pb->mutable_value()->Add(
              gn->requested_bytes(node->node->step()));
        } else if (type == kShown[11]) {
          // Peak memory bytes.
          sample_pb->mutable_value()->Add(gn->peak_bytes(node->node->step()));
        } else if (type == kShown[12]) {
          // Residual memory bytes (still allocated after the op finishes).
          sample_pb->mutable_value()->Add(
              gn->residual_bytes(node->node->step()));
        } else if (type == kShown[13]) {
          // Output memory bytes.
          sample_pb->mutable_value()->Add(gn->output_bytes(node->node->step()));
        } else if (type == kShown[2]) {
          // Model parameter count.
          sample_pb->mutable_value()->Add(gn->parameters());
        } else if (type == kShown[3]) {
          // Model float-operation count.
          sample_pb->mutable_value()->Add(gn->float_ops(node->node->step()));
        } else {
          absl::FPrintF(stderr, "pprof doesn't support -select=%s\n", type);
        }
      }
    }
  }

  const std::map<string, pprof::Sample>& samples() const {
    return sample_table_;
  }

 private:
  // Returns all true leaves (nodes without children) under 'root'; 'root'
  // itself is the only leaf when it has no children.
  std::vector<const CodeNode*> FetchAllLeaf(const CodeNode* root) {
    if (root->children.empty()) {
      return {root};
    }
    std::vector<const CodeNode*> ret;
    for (auto& n : root->children) {
      std::vector<const CodeNode*> nodes = FetchAllLeaf(n);
      ret.insert(ret.end(), nodes.begin(), nodes.end());
    }
    return ret;
  }

  StringTable* string_table_;
  const Options* opts_;
  std::map<string, pprof::Sample> sample_table_;
};
259
260 class PprofProfileImpl : public PprofProfile {
261 public:
PprofProfileImpl(const Options * opts)262 explicit PprofProfileImpl(const Options* opts)
263 : opts_(opts),
264 func_table_(new FunctionTable(&string_table_)),
265 loc_table_(new LocationTable(func_table_.get())),
266 samples_(new Samples(&string_table_, opts)) {}
267
AddLocation(const CodeNode * callee,const CodeNode * caller)268 uint64 AddLocation(const CodeNode* callee, const CodeNode* caller) override {
269 const string& file_path = caller->file();
270 uint64 lineno = caller->lineno();
271 const string& callee_file_path = callee->file();
272 const string& callee_function = callee->function();
273 uint64 callee_func_start_line = callee->func_start_line();
274
275 return loc_table_->GetIndex(file_path, lineno, callee_function,
276 callee_file_path, callee_func_start_line);
277 }
278
AddSample(const CodeNode * leaf,std::vector<uint64> * call_ids)279 void AddSample(const CodeNode* leaf, std::vector<uint64>* call_ids) override {
280 std::vector<uint64> reversed_call_ids;
281 std::reverse_copy(call_ids->begin(), call_ids->end(),
282 std::back_inserter(reversed_call_ids));
283 samples_->Add(leaf, reversed_call_ids);
284 }
285
WritePprofProfile(const string & filename)286 Status WritePprofProfile(const string& filename) override {
287 pprof::Profile profile_pb;
288 Build(&profile_pb);
289
290 std::unique_ptr<WritableFile> file;
291 Status s = Env::Default()->NewWritableFile(filename, &file);
292 if (!s.ok()) return s;
293
294 int32 buf_size = 1024 * 1024;
295 io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer(
296 file.get(), buf_size, buf_size, io::ZlibCompressionOptions::GZIP());
297 s = zlib_output_buffer->Init();
298 if (!s.ok()) {
299 delete zlib_output_buffer;
300 return s;
301 }
302 s = zlib_output_buffer->Append(profile_pb.SerializeAsString());
303 if (!s.ok()) {
304 delete zlib_output_buffer;
305 return s;
306 }
307 s = zlib_output_buffer->Close();
308 if (!s.ok()) {
309 delete zlib_output_buffer;
310 return s;
311 }
312 absl::FPrintF(stdout,
313 "\nRun pprof -png --nodecount=100 --sample_index=1 <%s>\n",
314 filename);
315 delete zlib_output_buffer;
316 return s;
317 }
318
319 private:
Build(pprof::Profile * profile_pb)320 void Build(pprof::Profile* profile_pb) {
321 string sample_type_description = "count";
322 auto sample_type = profile_pb->mutable_sample_type()->Add();
323 sample_type->set_type(string_table_.GetIndex(sample_type_description));
324 sample_type->set_unit(string_table_.GetIndex("count"));
325
326 string type = *opts_->select.begin();
327 sample_type_description = type;
328 sample_type = profile_pb->mutable_sample_type()->Add();
329 sample_type->set_type(string_table_.GetIndex(sample_type_description));
330 if (type == kShown[1] || type == kShown[9] || type == kShown[10]) {
331 sample_type->set_unit(string_table_.GetIndex("microseconds"));
332 if (type == kShown[1]) {
333 profile_pb->mutable_comment()->Add(string_table_.GetIndex(
334 "Sum of accelerator execution time and cpu execution time."));
335 } else if (type == kShown[9]) {
336 profile_pb->mutable_comment()->Add(
337 string_table_.GetIndex("Accelerator execution time."));
338 } else if (type == kShown[10]) {
339 profile_pb->mutable_comment()->Add(
340 string_table_.GetIndex("CPU execution time."));
341 }
342 } else if (type == kShown[0]) {
343 sample_type->set_unit(string_table_.GetIndex("bytes"));
344 profile_pb->mutable_comment()->Add(
345 string_table_.GetIndex("Sum of operation total memory requests, "
346 "excluding deallocations."));
347 } else if (type == kShown[11]) {
348 sample_type->set_unit(string_table_.GetIndex("bytes"));
349 profile_pb->mutable_comment()->Add(
350 string_table_.GetIndex("Sum of operation peak memory usage."));
351 } else if (type == kShown[12]) {
352 sample_type->set_unit(string_table_.GetIndex("bytes"));
353 profile_pb->mutable_comment()->Add(string_table_.GetIndex(
354 "Sum of operation allocated memory after finish."));
355 } else if (type == kShown[13]) {
356 sample_type->set_unit(string_table_.GetIndex("bytes"));
357 profile_pb->mutable_comment()->Add(
358 string_table_.GetIndex("Sum of operation output size."));
359 } else if (type == kShown[2]) {
360 sample_type->set_unit(string_table_.GetIndex("count"));
361 profile_pb->mutable_comment()->Add(
362 string_table_.GetIndex("Model parameters."));
363 } else if (type == kShown[3]) {
364 sample_type->set_unit(string_table_.GetIndex("count"));
365 profile_pb->mutable_comment()->Add(string_table_.GetIndex(
366 "Model float operations (Only available if defined)."));
367 } else {
368 absl::FPrintF(stderr, "pprof doesn't support selecting: %s\n", type);
369 }
370
371 for (const string& str : string_table_.strings()) {
372 *profile_pb->mutable_string_table()->Add() = str;
373 }
374 for (const auto& sample_it : samples_->samples()) {
375 // TODO(xpan): Consider swap.
376 profile_pb->mutable_sample()->Add()->MergeFrom(sample_it.second);
377 }
378 for (const auto& function_it : func_table_->functions()) {
379 profile_pb->mutable_function()->Add()->MergeFrom(function_it.second);
380 }
381 for (const auto& location_it : loc_table_->locations()) {
382 profile_pb->mutable_location()->Add()->MergeFrom(location_it.second);
383 }
384 }
385
386 const Options* opts_;
387 StringTable string_table_;
388 std::unique_ptr<FunctionTable> func_table_;
389 std::unique_ptr<LocationTable> loc_table_;
390 std::unique_ptr<Samples> samples_;
391 };
392 } // namespace
393
AddNode(TFGraphNode * node)394 void TFCode::AddNode(TFGraphNode* node) {
395 if (!node->call_stack() || node->call_stack()->traces().empty()) {
396 return;
397 }
398 // We infer the forward operation name from gradient op name. So, we can
399 // map gradient op traces to forward op traces.
400 // E.g. gradient node of 'inp_1/Conv2D' would be 'gradients/inp_1/Conv2D_grad.
401 string forward_name;
402 if (IsGradNode(node->name(), &forward_name)) {
403 auto grad_nodes_it = grad_nodes_.find(forward_name);
404 if (grad_nodes_it != grad_nodes_.end()) {
405 grad_nodes_it->second.push_back(node);
406 } else {
407 grad_nodes_.insert(
408 std::pair<string, std::vector<TFGraphNode*>>(forward_name, {node}));
409 }
410 return;
411 } else {
412 forward_nodes_[node->name()] = node;
413 }
414
415 if (!root_) {
416 graph_root_.reset(new TFMultiGraphNode(kTFProfRoot));
417 root_.reset(new CodeNode(graph_root_.get(), nullptr, ""));
418 }
419
420 CodeNode* pre_code_node = root_.get();
421 // TODO(xpan): Consider to release CodeDef after TFCode is built. It
422 // takes a lot of memory.
423 std::set<string> traces;
424 for (int i = 0, end = node->call_stack()->traces().size(); i < end; ++i) {
425 // Unlike op name, which is globally unique, trace name is only unique
426 // w.r.t. it's parent.
427 const string& trace = GetTraceString(node->call_stack()->traces().at(i));
428 traces.insert(trace);
429 pre_code_node = pre_code_node->AddChildren(
430 trace, &node->call_stack()->traces().at(i), "");
431 const int64 last_index = node->call_stack()->traces().size() - 1;
432 if (i == last_index) {
433 pre_code_node->node->AddGraphNode(node);
434 }
435 }
436 }
437
// Attaches the gradient nodes deferred by AddNode() to the code tree. Each
// gradient node is filed under its forward op's call stack, with every frame
// name suffixed by kGradientSuffix so gradient traces display separately
// from forward traces. Gradient nodes whose forward op was never seen are
// counted and reported to stderr.
void TFCode::Build() {
  int64 unaccounted_nodes = 0;
  for (const auto& it : grad_nodes_) {
    const string& forward_name = it.first;
    auto forward_it = forward_nodes_.find(forward_name);
    if (forward_it == forward_nodes_.end()) {
      // No forward node matches the inferred name; report at the end.
      unaccounted_nodes += 1;
      continue;
    }
    TFGraphNode* fn = forward_it->second;
    CodeNode* leaf = nullptr;
    CodeNode* pre_code_node = root_.get();
    // Walk/create the gradient-suffixed chain of CodeNodes for the forward
    // op's call stack; 'leaf' ends at the bottom-most frame.
    // NOTE(review): assumes fn->call_stack() is non-null/non-empty — AddNode
    // only registers forward nodes that have traces; confirm if that
    // invariant ever changes, since 'leaf' would stay null below.
    for (int i = 0, end = fn->call_stack()->traces().size(); i < end; ++i) {
      const string& trace =
          GetTraceString(fn->call_stack()->traces().at(i)) + kGradientSuffix;
      pre_code_node = pre_code_node->AddChildren(
          trace, &fn->call_stack()->traces().at(i), kGradientSuffix);
      const int64 last_trace = fn->call_stack()->traces().size() - 1;
      if (i == last_trace) {
        leaf = pre_code_node;
      }
    }
    // All gradient nodes of this forward op accumulate on the same leaf.
    for (TFGraphNode* gn : it.second) {
      leaf->node->AddGraphNode(gn);
    }
  }
  if (unaccounted_nodes > 0) {
    absl::FPrintF(stderr, "%d gradient nodes not accounted\n",
                  unaccounted_nodes);
  }
}
469
// Accounts, filters, sorts and formats the code tree according to 'opts' and
// returns the displayed root. For pprof output (kOutput[3]) it additionally
// writes a compressed pprof profile to the path in opts.output_options.
const ShowMultiNode* TFCode::ShowInternal(const Options& opts,
                                          Timeline* timeline) {
  root_->ResetTotalStats();
  if (opts.output_type == kOutput[3]) {
    // pprof output supports exactly one metric attribute, and only the ones
    // listed below (the same set handled in Samples::Add).
    if (opts.select.size() != 1) {
      absl::FPrintF(stderr, "Can only select 1 attribute for pprof output.\n");
      return root_.get();
    }
    string select = *opts.select.begin();
    if (select != kShown[0] && select != kShown[1] && select != kShown[2] &&
        select != kShown[3] && select != kShown[9] && select != kShown[10] &&
        select != kShown[11] && select != kShown[12] && select != kShown[13]) {
      absl::FPrintF(stderr, "pprof doesn't support -select=%s\n", select);
      return root_.get();
    }
  }
  if (opts.account_displayed_op_only) {
    absl::FPrintF(stderr,
                  "Note: code view ignores account_displayed_op_only\n");
  }

  // Recompute per-node accounting, then aggregate the accounted children
  // into the root's totals.
  std::vector<CodeNode*> roots = Account(root_->children, opts);
  root_->show_children.clear();
  for (CodeNode* n : roots) {
    root_->AggregateTotalStats(n);
  }

  // Narrow the displayed roots when start_name_regexes is restrictive
  // (anything other than the single catch-all ".*").
  if (opts.start_name_regexes.size() != 1 ||
      opts.start_name_regexes[0] != ".*") {
    roots = SearchRoot(roots, opts.start_name_regexes);
  }

  root_->show_children.assign(roots.begin(), roots.end());

  // NOTE(review): indexing [0] assumes PrintScope always returns the root
  // node itself; confirm ShouldShow can never hide the root.
  CodeNode* root = PrintScope({root_.get()}, opts, 1, 0)[0];

  root->formatted_str = FormatLegend(opts) + root->formatted_str;

  if (opts.output_type == kOutput[3]) {
    std::vector<uint64> call_ids;
    // '&opts' outlives pprof_profile_'s use here: the profile is built and
    // written before this function returns.
    pprof_profile_.reset(new PprofProfileImpl(&opts));
    Format(root, root->show_children, opts, &root->formatted_str,
           root->mutable_proto(), &call_ids);
    Status s = pprof_profile_->WritePprofProfile(
        opts.output_options.at(kPprofOpts[0]));
    if (!s.ok()) {
      absl::FPrintF(stderr, "%s\n", s.ToString());
    }
  } else {
    Format(root, root->show_children, opts, &root->formatted_str,
           root->mutable_proto(), nullptr);
    if (timeline) {
      timeline->GenerateCodeTimeline(root);
    }
  }
  return root;
}
527
// Recursively renders 'nodes' (the displayed children of 'root') into
// 'display_str' and 'proto'. For pprof output, 'call_ids' carries the
// location-id path from the displayed root down to the current node; a
// sample is emitted at each displayed leaf that has a trace. 'call_ids' may
// be null for non-pprof output (it is only dereferenced under the
// kOutput[3] checks below).
void TFCode::Format(const CodeNode* root, const std::vector<CodeNode*>& nodes,
                    const Options& opts, string* display_str,
                    MultiGraphNodeProto* proto, std::vector<uint64>* call_ids) {
  if (nodes.empty() && root->has_trace() && opts.output_type == kOutput[3]) {
    // Displayed leaf: record one sample with the accumulated call stack.
    pprof_profile_->AddSample(root, call_ids);
  }

  for (CodeNode* node : nodes) {
    if (root->has_trace() && opts.output_type == kOutput[3]) {
      uint64 loc_id = pprof_profile_->AddLocation(node, root);
      call_ids->push_back(loc_id);
    }
    display_str->append(node->formatted_str);
    MultiGraphNodeProto* child = proto->add_children();
    child->MergeFrom(node->proto());
    Format(node, node->show_children, opts, display_str, child, call_ids);
    if (root->has_trace() && opts.output_type == kOutput[3]) {
      // Undo the push above so siblings see the correct stack.
      call_ids->pop_back();
    }
  }
}
549
SearchRoot(std::vector<CodeNode * > roots,const std::vector<string> & regexes)550 std::vector<CodeNode*> TFCode::SearchRoot(std::vector<CodeNode*> roots,
551 const std::vector<string>& regexes) {
552 std::vector<CodeNode*> res;
553 if (roots.empty()) {
554 return res;
555 }
556 for (CodeNode* root : roots) {
557 bool match_start_node = false;
558 for (const string& regex : regexes) {
559 if (RE2::FullMatch(root->name(), regex)) {
560 res.push_back(root);
561 match_start_node = true;
562 break;
563 }
564 }
565 if (match_start_node) {
566 // Found a start node at this branch, no need to continue.
567 continue;
568 }
569 std::vector<CodeNode*> nroots = SearchRoot(root->show_children, regexes);
570 res.insert(res.end(), nroots.begin(), nroots.end());
571 }
572 return res;
573 }
574
PrintScope(const std::vector<CodeNode * > roots,const Options & opts,int depth,int last_ident)575 std::vector<CodeNode*> TFCode::PrintScope(const std::vector<CodeNode*> roots,
576 const Options& opts, int depth,
577 int last_ident) {
578 std::vector<CodeNode*> show_nodes;
579
580 for (CodeNode* node : roots) {
581 if (ShouldTrim(node, opts.trim_name_regexes) || depth > opts.max_depth) {
582 continue;
583 }
584 int ident = last_ident;
585 bool show = ShouldShow(node, opts, depth);
586 if (show) ident += 2;
587
588 std::vector<CodeNode*> show_cnodes =
589 PrintScope(node->show_children, opts, depth + 1, ident);
590 if (show) {
591 node->show_children.clear();
592
593 show_cnodes = SortNodes(show_cnodes, opts);
594 for (CodeNode* sc : show_cnodes) {
595 node->show_children.push_back(sc);
596 }
597
598 node->formatted_str = FormatNode(node, opts, last_ident);
599
600 if (opts.select.find(kShown[4]) != opts.select.end()) {
601 absl::FPrintF(stderr, "code view has no tensor value to show\n");
602 }
603 show_nodes.push_back(node);
604 } else {
605 show_nodes.insert(show_nodes.end(), show_cnodes.begin(),
606 show_cnodes.end());
607 }
608 }
609 return show_nodes;
610 }
611
Account(const std::vector<CodeNode * > & roots,const Options & opts)612 std::vector<CodeNode*> TFCode::Account(const std::vector<CodeNode*>& roots,
613 const Options& opts) {
614 std::vector<CodeNode*> act_nodes;
615
616 for (CodeNode* node : roots) {
617 node->ResetTotalStats();
618 std::vector<CodeNode*> act_cnodes = Account(node->children, opts);
619 node->account = ReAccount(node, opts);
620 if (node->account || !act_cnodes.empty()) {
621 node->show_children.clear();
622 node->ResetTotalStats();
623 node->AddSelfToTotalStats();
624 for (CodeNode* c : act_cnodes) {
625 node->AggregateTotalStats(c);
626 node->show_children.push_back(c);
627 }
628 act_nodes.push_back(node);
629 }
630 }
631 return act_nodes;
632 }
633
FormatNodeMemory(CodeNode * node,int64 bytes,int64 total_bytes) const634 string TFCode::FormatNodeMemory(CodeNode* node, int64 bytes,
635 int64 total_bytes) const {
636 string memory = FormatMemory(total_bytes);
637 if (node->account) {
638 memory = FormatMemory(bytes) + "/" + memory;
639 } else {
640 memory = "--/" + memory;
641 }
642 return memory;
643 }
644
FormatNode(CodeNode * node,const Options & opts,int64 indent) const645 string TFCode::FormatNode(CodeNode* node, const Options& opts,
646 int64 indent) const {
647 std::vector<string> attrs;
648 if (opts.select.find(kShown[0]) != opts.select.end()) {
649 attrs.push_back(FormatNodeMemory(node, node->proto().requested_bytes(),
650 node->proto().total_requested_bytes()));
651 }
652 if (opts.select.find(kShown[11]) != opts.select.end()) {
653 attrs.push_back(FormatNodeMemory(node, node->proto().peak_bytes(),
654 node->proto().total_peak_bytes()));
655 }
656 if (opts.select.find(kShown[12]) != opts.select.end()) {
657 attrs.push_back(FormatNodeMemory(node, node->proto().residual_bytes(),
658 node->proto().total_residual_bytes()));
659 }
660 if (opts.select.find(kShown[13]) != opts.select.end()) {
661 attrs.push_back(FormatNodeMemory(node, node->proto().output_bytes(),
662 node->proto().total_output_bytes()));
663 }
664
665 std::vector<string> time_attrs = FormatTimes(node, opts);
666 attrs.insert(attrs.end(), time_attrs.begin(), time_attrs.end());
667
668 if (opts.select.find(kShown[2]) != opts.select.end()) {
669 string params = FormatNumber(node->proto().total_parameters()) + " params";
670 if (node->account) {
671 params = FormatNumber(node->proto().parameters()) + "/" + params;
672 } else {
673 params = "--/" + params;
674 }
675 attrs.push_back(params);
676 }
677
678 if (opts.select.find(kShown[3]) != opts.select.end()) {
679 string fops = FormatNumber(node->proto().total_float_ops()) + " flops";
680 if (node->account) {
681 fops = FormatNumber(node->proto().float_ops()) + "/" + fops;
682 } else {
683 fops = "--/" + fops;
684 }
685 attrs.push_back(fops);
686 }
687
688 if (opts.select.find(kShown[5]) != opts.select.end() &&
689 !node->node->devices().empty()) {
690 attrs.push_back(absl::StrJoin(node->node->devices(), "|"));
691 }
692 if (opts.select.find(kShown[6]) != opts.select.end()) {
693 std::set<string> op_types = node->node->op_types();
694 attrs.push_back(absl::StrJoin(op_types, "|"));
695 }
696 if (opts.select.find(kShown[7]) != opts.select.end()) {
697 // TODO(xpan): Make op count available in code view?
698 attrs.push_back(absl::StrFormat("%s N/A in code view", kShown[7]));
699 }
700 if (opts.select.find(kShown[8]) != opts.select.end()) {
701 attrs.push_back(absl::StrFormat("%s N/A in code view", kShown[8]));
702 }
703
704 return absl::StrFormat("%s%s (%s)\n", std::string(indent, ' '), node->name(),
705 absl::StrJoin(attrs, ", "));
706 }
707 } // namespace tfprof
708 } // namespace tensorflow
709