1 /* Copyright 2016 The TensorFlow Authors All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Parent class and utilities for tfprof_graph and tfprof_scope.
17 
18 #ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_H_
19 #define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_H_
20 
21 #include <algorithm>
22 #include <string>
23 #include <vector>
24 
25 #include "tensorflow/c/checkpoint_reader.h"
26 #include "tensorflow/core/framework/graph.pb.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/profiler/internal/tfprof_constants.h"
29 #include "tensorflow/core/profiler/internal/tfprof_node.h"
30 #include "tensorflow/core/profiler/internal/tfprof_node_show.h"
31 #include "tensorflow/core/profiler/internal/tfprof_tensor.h"
32 #include "tensorflow/core/profiler/internal/tfprof_timeline.h"
33 #include "tensorflow/core/profiler/internal/tfprof_utils.h"
34 #include "tensorflow/core/profiler/tfprof_options.h"
35 #include "tensorflow/core/profiler/tfprof_output.pb.h"
36 
37 namespace tensorflow {
38 namespace tfprof {
39 class TFShow {
40  public:
TFShow(checkpoint::CheckpointReader * ckpt_reader)41   explicit TFShow(checkpoint::CheckpointReader* ckpt_reader)
42       : ckpt_reader_(ckpt_reader) {}
~TFShow()43   virtual ~TFShow() {}
44   virtual void AddNode(TFGraphNode* node) = 0;
45   virtual void Build() = 0;
46   virtual const GraphNodeProto& Show(const string& prefix,
47                                      const Options& opts) final;
48 
49  protected:
50   virtual const ShowNode* ShowInternal(const Options& opts,
51                                        Timeline* timeline) = 0;
52 
53   bool LookUpCheckPoint(const string& name,
54                         std::unique_ptr<TFProfTensor>* tensor);
55 
56   // Overridden by subclass if extra requirements need to be met.
ShouldShowIfExtra(const ShowNode * node,const Options & opts,int depth)57   virtual bool ShouldShowIfExtra(const ShowNode* node, const Options& opts,
58                                  int depth) const {
59     return true;
60   }
61 
62   bool ShouldShow(const ShowNode* node, const Options& opts, int depth) const;
63 
64   bool ShouldTrim(const ShowNode* node,
65                   const std::vector<string>& regexes) const;
66 
67   bool ReAccount(ShowNode* node, const Options& opts);
68 
69   string FormatNode(ShowNode* node, const Options& opts) const;
70   string FormatNodeMemory(ShowNode* node, int64 bytes, int64 total_bytes) const;
71 
72   string FormatLegend(const Options& opts) const;
73 
74   template <typename T>
SortNodes(const std::vector<T * > & nodes,const Options & opts)75   std::vector<T*> SortNodes(const std::vector<T*>& nodes, const Options& opts) {
76     if (opts.order_by.empty() || nodes.empty()) {
77       return nodes;
78     }
79     std::vector<T*> sorted_nodes = nodes;
80     std::sort(sorted_nodes.begin(), sorted_nodes.end(),
81               [&opts](const T* n1, const T* n2) {
82                 if (n1->name() == kTFProfRoot) return true;
83                 if (n2->name() == kTFProfRoot) return false;
84                 bool name_cmp = n1->name() < n2->name();
85                 if (opts.order_by == kOrderBy[0]) {
86                   return name_cmp;
87                 } else if (opts.order_by == kOrderBy[1]) {
88                   return n1->proto().total_requested_bytes() >
89                          n2->proto().total_requested_bytes();
90                 } else if (opts.order_by == kOrderBy[2]) {
91                   return n1->proto().total_peak_bytes() >
92                          n2->proto().total_peak_bytes();
93                 } else if (opts.order_by == kOrderBy[3]) {
94                   return n1->proto().total_residual_bytes() >
95                          n2->proto().total_residual_bytes();
96                 } else if (opts.order_by == kOrderBy[4]) {
97                   return n1->proto().total_output_bytes() >
98                          n2->proto().total_output_bytes();
99                 } else if (opts.order_by == kOrderBy[5]) {
100                   return n1->proto().total_exec_micros() >
101                          n2->proto().total_exec_micros();
102                 } else if (opts.order_by == kOrderBy[6]) {
103                   return n1->proto().total_accelerator_exec_micros() >
104                          n2->proto().total_accelerator_exec_micros();
105                 } else if (opts.order_by == kOrderBy[7]) {
106                   return n1->proto().total_cpu_exec_micros() >
107                          n2->proto().total_cpu_exec_micros();
108                 } else if (opts.order_by == kOrderBy[8]) {
109                   return n1->proto().total_parameters() >
110                          n2->proto().total_parameters();
111                 } else if (opts.order_by == kOrderBy[9]) {
112                   return n1->proto().total_float_ops() >
113                          n2->proto().total_float_ops();
114                 }
115                 return name_cmp;
116               });
117     return sorted_nodes;
118   }
119 
120   checkpoint::CheckpointReader* ckpt_reader_;
121 };
122 
123 template <typename T>
FormatTotalExecTime(const T * node,const Options & opts)124 string FormatTotalExecTime(const T* node, const Options& opts) {
125   string time = FormatTime(node->proto().total_exec_micros());
126   if (node->account) {
127     time = FormatTime(node->proto().exec_micros()) + "/" + time;
128   } else {
129     time = "--/" + time;
130   }
131   return time;
132 }
133 template <typename T>
FormatCPUExecTime(const T * node,const Options & opts)134 string FormatCPUExecTime(const T* node, const Options& opts) {
135   string time = FormatTime(node->proto().total_cpu_exec_micros());
136   if (node->account) {
137     time = FormatTime(node->proto().cpu_exec_micros()) + "/" + time;
138   } else {
139     time = "--/" + time;
140   }
141   return time;
142 }
143 template <typename T>
FormatAcceleratorExecTime(const T * node,const Options & opts)144 string FormatAcceleratorExecTime(const T* node, const Options& opts) {
145   string time = FormatTime(node->proto().total_accelerator_exec_micros());
146   if (node->account) {
147     time = FormatTime(node->proto().accelerator_exec_micros()) + "/" + time;
148   } else {
149     time = "--/" + time;
150   }
151   return time;
152 }
153 }  // namespace tfprof
154 }  // namespace tensorflow
155 
156 #endif  // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_H_
157