1 /* Copyright 2016 The TensorFlow Authors All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <stdio.h>
#include <stdlib.h>

#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_format.h"
#include "absl/strings/str_split.h"
#include "linenoise.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/profiler/internal/advisor/tfprof_advisor.h"
#include "tensorflow/core/profiler/internal/tfprof_stats.h"
#include "tensorflow/core/profiler/internal/tfprof_utils.h"
#include "tensorflow/core/profiler/tfprof_log.pb.h"
#include "tensorflow/core/profiler/tfprof_options.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/command_line_flags.h"
43 
44 namespace tensorflow {
45 namespace tfprof {
completion(const char * buf,linenoiseCompletions * lc)46 void completion(const char* buf, linenoiseCompletions* lc) {
47   string buf_str = buf;
48   if (buf_str.find(' ') == buf_str.npos) {
49     for (const char* opt : kCmds) {
50       if (string(opt).find(buf_str) == 0) {
51         linenoiseAddCompletion(lc, opt);
52       }
53     }
54     return;
55   }
56 
57   string prefix;
58   int last_dash = buf_str.find_last_of(' ');
59   if (last_dash != string::npos) {
60     prefix = buf_str.substr(0, last_dash + 1);
61     buf_str = buf_str.substr(last_dash + 1, kint32max);
62   }
63   for (const char* opt : kOptions) {
64     if (string(opt).find(buf_str) == 0) {
65       linenoiseAddCompletion(lc, (prefix + opt).c_str());
66     }
67   }
68 }
69 
Run(int argc,char ** argv)70 int Run(int argc, char** argv) {
71   string FLAGS_profile_path = "";
72   string FLAGS_graph_path = "";
73   string FLAGS_run_meta_path = "";
74   string FLAGS_op_log_path = "";
75   string FLAGS_checkpoint_path = "";
76   int32 FLAGS_max_depth = 10;
77   int64 FLAGS_min_bytes = 0;
78   int64 FLAGS_min_peak_bytes = 0;
79   int64 FLAGS_min_residual_bytes = 0;
80   int64 FLAGS_min_output_bytes = 0;
81   int64 FLAGS_min_micros = 0;
82   int64 FLAGS_min_accelerator_micros = 0;
83   int64 FLAGS_min_cpu_micros = 0;
84   int64 FLAGS_min_params = 0;
85   int64 FLAGS_min_float_ops = 0;
86   int64 FLAGS_min_occurrence = 0;
87   int64 FLAGS_step = -1;
88   string FLAGS_order_by = "name";
89   string FLAGS_account_type_regexes = ".*";
90   string FLAGS_start_name_regexes = ".*";
91   string FLAGS_trim_name_regexes = "";
92   string FLAGS_show_name_regexes = ".*";
93   string FLAGS_hide_name_regexes;
94   bool FLAGS_account_displayed_op_only = false;
95   string FLAGS_select = "micros";
96   string FLAGS_output = "";
97   for (int i = 0; i < argc; i++) {
98     absl::FPrintF(stderr, "%s\n", argv[i]);
99   }
100 
101   std::vector<Flag> flag_list = {
102       Flag("profile_path", &FLAGS_profile_path, "Profile binary file name."),
103       Flag("graph_path", &FLAGS_graph_path, "GraphDef proto text file name"),
104       Flag("run_meta_path", &FLAGS_run_meta_path,
105            "Comma-separated list of RunMetadata proto binary "
106            "files. Each file is given step number 0,1,2,etc"),
107       Flag("op_log_path", &FLAGS_op_log_path,
108            "tensorflow::tfprof::OpLogProto proto binary file name"),
109       Flag("checkpoint_path", &FLAGS_checkpoint_path,
110            "TensorFlow Checkpoint file name"),
111       Flag("max_depth", &FLAGS_max_depth, "max depth"),
112       Flag("min_bytes", &FLAGS_min_bytes, "min_bytes"),
113       Flag("min_peak_bytes", &FLAGS_min_peak_bytes, "min_peak_bytes"),
114       Flag("min_residual_bytes", &FLAGS_min_residual_bytes,
115            "min_residual_bytes"),
116       Flag("min_output_bytes", &FLAGS_min_output_bytes, "min_output_bytes"),
117       Flag("min_micros", &FLAGS_min_micros, "min micros"),
118       Flag("min_accelerator_micros", &FLAGS_min_accelerator_micros,
119            "min accelerator_micros"),
120       Flag("min_cpu_micros", &FLAGS_min_cpu_micros, "min_cpu_micros"),
121       Flag("min_params", &FLAGS_min_params, "min params"),
122       Flag("min_float_ops", &FLAGS_min_float_ops, "min float ops"),
123       Flag("min_occurrence", &FLAGS_min_occurrence, "min occurrence"),
124       Flag("step", &FLAGS_step,
125            "The stats of which step to use. By default average"),
126       Flag("order_by", &FLAGS_order_by, "order by"),
127       Flag("account_type_regexes", &FLAGS_start_name_regexes,
128            "start name regexes"),
129       Flag("trim_name_regexes", &FLAGS_trim_name_regexes, "trim name regexes"),
130       Flag("show_name_regexes", &FLAGS_show_name_regexes, "show name regexes"),
131       Flag("hide_name_regexes", &FLAGS_hide_name_regexes, "hide name regexes"),
132       Flag("account_displayed_op_only", &FLAGS_account_displayed_op_only,
133            "account displayed op only"),
134       Flag("select", &FLAGS_select, "select"),
135       Flag("output", &FLAGS_output, "output"),
136   };
137   string usage = Flags::Usage(argv[0], flag_list);
138   bool parse_ok = Flags::Parse(&argc, argv, flag_list);
139   if (!parse_ok) {
140     absl::PrintF("%s", usage);
141     return (2);
142   }
143   port::InitMain(argv[0], &argc, &argv);
144 
145   if (!FLAGS_profile_path.empty() &&
146       (!FLAGS_graph_path.empty() || !FLAGS_run_meta_path.empty())) {
147     absl::FPrintF(stderr,
148                   "--profile_path is set, do not set --graph_path or "
149                   "--run_meta_path\n");
150     return 1;
151   }
152 
153   std::vector<string> account_type_regexes =
154       absl::StrSplit(FLAGS_account_type_regexes, ',', absl::SkipEmpty());
155   std::vector<string> start_name_regexes =
156       absl::StrSplit(FLAGS_start_name_regexes, ',', absl::SkipEmpty());
157   std::vector<string> trim_name_regexes =
158       absl::StrSplit(FLAGS_trim_name_regexes, ',', absl::SkipEmpty());
159   std::vector<string> show_name_regexes =
160       absl::StrSplit(FLAGS_show_name_regexes, ',', absl::SkipEmpty());
161   std::vector<string> hide_name_regexes =
162       absl::StrSplit(FLAGS_hide_name_regexes, ',', absl::SkipEmpty());
163   std::vector<string> select =
164       absl::StrSplit(FLAGS_select, ',', absl::SkipEmpty());
165 
166   string output_type;
167   std::map<string, string> output_options;
168   Status s = ParseOutput(FLAGS_output, &output_type, &output_options);
169   CHECK(s.ok()) << s.ToString();
170 
171   string cmd = "";
172   if (argc == 1 && FLAGS_graph_path.empty() && FLAGS_profile_path.empty() &&
173       FLAGS_run_meta_path.empty()) {
174     PrintHelp();
175     return 0;
176   } else if (argc > 1) {
177     if (string(argv[1]) == kCmds[6]) {
178       PrintHelp();
179       return 0;
180     }
181     if (string(argv[1]) == kCmds[0] || string(argv[1]) == kCmds[1] ||
182         string(argv[1]) == kCmds[2] || string(argv[1]) == kCmds[3] ||
183         string(argv[1]) == kCmds[4]) {
184       cmd = argv[1];
185     }
186   }
187 
188   absl::PrintF("Reading Files...\n");
189   std::unique_ptr<checkpoint::CheckpointReader> ckpt_reader;
190   TF_Status* status = TF_NewStatus();
191   if (!FLAGS_checkpoint_path.empty()) {
192     ckpt_reader.reset(
193         new checkpoint::CheckpointReader(FLAGS_checkpoint_path, status));
194     if (TF_GetCode(status) != TF_OK) {
195       absl::FPrintF(stderr, "%s\n", TF_Message(status));
196       TF_DeleteStatus(status);
197       return 1;
198     }
199     TF_DeleteStatus(status);
200   }
201 
202   std::unique_ptr<TFStats> tf_stat;
203   if (!FLAGS_profile_path.empty()) {
204     tf_stat.reset(new TFStats(FLAGS_profile_path, std::move(ckpt_reader)));
205   } else {
206     absl::PrintF(
207         "Try to use a single --profile_path instead of "
208         "graph_path,op_log_path,run_meta_path\n");
209     std::unique_ptr<GraphDef> graph(new GraphDef());
210     if (!FLAGS_graph_path.empty()) {
211       s = ReadProtoFile(Env::Default(), FLAGS_graph_path, graph.get(), false);
212       if (!s.ok()) {
213         absl::FPrintF(stderr, "Failed to read graph_path: %s\n", s.ToString());
214         return 1;
215       }
216     }
217 
218     std::unique_ptr<OpLogProto> op_log(new OpLogProto());
219     if (!FLAGS_op_log_path.empty()) {
220       string op_log_str;
221       s = ReadFileToString(Env::Default(), FLAGS_op_log_path, &op_log_str);
222       if (!s.ok()) {
223         absl::FPrintF(stderr, "Failed to read op_log_path: %s\n", s.ToString());
224         return 1;
225       }
226       if (!ParseProtoUnlimited(op_log.get(), op_log_str)) {
227         absl::FPrintF(stderr, "Failed to parse op_log_path\n");
228         return 1;
229       }
230     }
231     tf_stat.reset(new TFStats(std::move(graph), nullptr, std::move(op_log),
232                               std::move(ckpt_reader)));
233 
234     std::vector<string> run_meta_files =
235         absl::StrSplit(FLAGS_run_meta_path, ',', absl::SkipEmpty());
236     for (int i = 0; i < run_meta_files.size(); ++i) {
237       std::unique_ptr<RunMetadata> run_meta(new RunMetadata());
238       s = ReadProtoFile(Env::Default(), run_meta_files[i], run_meta.get(),
239                         true);
240       if (!s.ok()) {
241         absl::FPrintF(stderr, "Failed to read run_meta_path %s. Status: %s\n",
242                       run_meta_files[i], s.ToString());
243         return 1;
244       }
245       tf_stat->AddRunMeta(i, std::move(run_meta));
246       absl::FPrintF(stdout, "run graph coverage: %.2f\n",
247                     tf_stat->run_coverage());
248     }
249   }
250 
251   if (cmd == kCmds[4]) {
252     tf_stat->BuildAllViews();
253     Advisor(tf_stat.get()).Advise(Advisor::DefaultOptions());
254     return 0;
255   }
256 
257   Options opts(
258       FLAGS_max_depth, FLAGS_min_bytes, FLAGS_min_peak_bytes,
259       FLAGS_min_residual_bytes, FLAGS_min_output_bytes, FLAGS_min_micros,
260       FLAGS_min_accelerator_micros, FLAGS_min_cpu_micros, FLAGS_min_params,
261       FLAGS_min_float_ops, FLAGS_min_occurrence, FLAGS_step, FLAGS_order_by,
262       account_type_regexes, start_name_regexes, trim_name_regexes,
263       show_name_regexes, hide_name_regexes, FLAGS_account_displayed_op_only,
264       select, output_type, output_options);
265 
266   if (cmd == kCmds[2] || cmd == kCmds[3]) {
267     tf_stat->BuildView(cmd);
268     tf_stat->ShowMultiGraphNode(cmd, opts);
269     return 0;
270   } else if (cmd == kCmds[0] || cmd == kCmds[1]) {
271     tf_stat->BuildView(cmd);
272     tf_stat->ShowGraphNode(cmd, opts);
273     return 0;
274   }
275 
276   linenoiseSetCompletionCallback(completion);
277   linenoiseHistoryLoad(".tfprof_history.txt");
278 
279   bool looped = false;
280   while (true) {
281     char* line = linenoise("tfprof> ");
282     if (line == nullptr) {
283       if (!looped) {
284         absl::FPrintF(stderr,
285                       "Cannot start interactive shell, "
286                       "use 'bazel-bin' instead of 'bazel run'.\n");
287       }
288       break;
289     }
290     looped = true;
291     string line_s = line;
292     free(line);
293 
294     if (line_s.empty()) {
295       absl::PrintF("%s", opts.ToString());
296       continue;
297     }
298     linenoiseHistoryAdd(line_s.c_str());
299     linenoiseHistorySave(".tfprof_history.txt");
300 
301     Options new_opts = opts;
302     Status s = ParseCmdLine(line_s, &cmd, &new_opts);
303     if (!s.ok()) {
304       absl::FPrintF(stderr, "E: %s\n", s.ToString());
305       continue;
306     }
307     if (cmd == kCmds[5]) {
308       opts = new_opts;
309     } else if (cmd == kCmds[6]) {
310       PrintHelp();
311     } else if (cmd == kCmds[2] || cmd == kCmds[3]) {
312       tf_stat->BuildView(cmd);
313       tf_stat->ShowMultiGraphNode(cmd, new_opts);
314     } else if (cmd == kCmds[0] || cmd == kCmds[1]) {
315       tf_stat->BuildView(cmd);
316       tf_stat->ShowGraphNode(cmd, new_opts);
317     } else if (cmd == kCmds[4]) {
318       tf_stat->BuildAllViews();
319       Advisor(tf_stat.get()).Advise(Advisor::DefaultOptions());
320     }
321   }
322   return 0;
323 }
324 }  // namespace tfprof
325 }  // namespace tensorflow
326 
main(int argc,char ** argv)327 int main(int argc, char** argv) { return tensorflow::tfprof::Run(argc, argv); }
328