1 /* Copyright 2017 The TensorFlow Authors All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Usage: capture_tpu_profile --service_addr="localhost:8466" --logdir=/tmp/log
17 //
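// Initiates TPU profiling on the TPUProfiler service at service_addr, then
// receives the profile data and dumps it to a TensorBoard log directory.
// When --monitoring_level is 1 or 2, the tool instead requests monitoring
// metrics from the service num_queries times; --logdir is not required then.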
20 
#include <algorithm>
#include <iostream>
#include <vector>

#include "tensorflow/contrib/tpu/profiler/version.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
#include "tensorflow/core/util/command_line_flags.h"
25 
main(int argc,char ** argv)26 int main(int argc, char** argv) {
27   tensorflow::string FLAGS_service_addr;
28   tensorflow::string FLAGS_logdir;
29   tensorflow::string FLAGS_workers_list;
30   int FLAGS_duration_ms = 0;
31   int FLAGS_num_tracing_attempts = 3;
32   bool FLAGS_include_dataset_ops = true;
33   int FLAGS_monitoring_level = 0;
34   int FLAGS_num_queries = 100;
35   std::vector<tensorflow::Flag> flag_list = {
36       tensorflow::Flag("service_addr", &FLAGS_service_addr,
37                        "Address of TPU profiler service e.g. localhost:8466"),
38       tensorflow::Flag("workers_list", &FLAGS_workers_list,
39                        "The list of worker TPUs that we are about to profile "
40                        "in the current session."),
41       tensorflow::Flag("logdir", &FLAGS_logdir,
42                        "Path of TensorBoard log directory e.g. /tmp/tb_log, "
43                        "gs://tb_bucket"),
44       tensorflow::Flag(
45           "duration_ms", &FLAGS_duration_ms,
46           "Duration of tracing or monitoring in ms. Default is 2000ms for "
47           "tracing and 1000ms for monitoring."),
48       tensorflow::Flag("num_tracing_attempts", &FLAGS_num_tracing_attempts,
49                        "Automatically retry N times when no trace event "
50                        "is collected. Default is 3."),
51       tensorflow::Flag("include_dataset_ops", &FLAGS_include_dataset_ops,
52                        "Set to false to profile longer TPU device traces."),
53       tensorflow::Flag("monitoring_level", &FLAGS_monitoring_level,
54                        "Choose a monitoring level between 1 and 2 to monitor "
55                        "your TPU job continuously. Level 2 is more verbose "
56                        "than level 1 and shows more metrics."),
57       tensorflow::Flag("num_queries", &FLAGS_num_queries,
58                        "This script will run monitoring for num_queries before "
59                        "it stops.")};
60 
61   std::cout << "Welcome to the Cloud TPU Profiler v" << TPU_PROFILER_VERSION
62             << std::endl;
63 
64   tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
65   bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
66   if (!parse_ok || FLAGS_service_addr.empty() ||
67       (FLAGS_logdir.empty() && FLAGS_monitoring_level == 0)) {
68     // Fail if flags are not parsed correctly or service_addr not provided.
69     // Also, fail if neither logdir is provided (required for tracing) nor
70     // monitoring level is provided (required for monitoring).
71     std::cout << usage.c_str() << std::endl;
72     return 2;
73   }
74   if (FLAGS_monitoring_level < 0 || FLAGS_monitoring_level > 2) {
75     // Invalid monitoring level.
76     std::cout << usage.c_str() << std::endl;
77     return 2;
78   }
79   tensorflow::Status status;
80   status =
81       tensorflow::profiler::client::ValidateHostPortPair(FLAGS_service_addr);
82   if (!status.ok()) {
83     std::cout << status.error_message() << std::endl;
84     std::cout << usage.c_str() << std::endl;
85     return 2;
86   }
87   tensorflow::port::InitMain(argv[0], &argc, &argv);
88 
89   // Sets the minimum duration_ms, tracing attempts and num queries.
90   int duration_ms = std::max(FLAGS_duration_ms, 0);
91   if (duration_ms == 0) {
92     // If profiling duration was not set by user or set to a negative value, we
93     // set it to default values of 2000ms for tracing and 1000ms for monitoring.
94     duration_ms = FLAGS_monitoring_level == 0 ? 2000 : 1000;
95   }
96   int num_tracing_attempts = std::max(FLAGS_num_tracing_attempts, 1);
97   int num_queries = std::max(FLAGS_num_queries, 1);
98 
99   if (FLAGS_monitoring_level != 0) {
100     std::cout << "Since monitoring level is provided, profile "
101               << FLAGS_service_addr << " for " << duration_ms
102               << "ms and show metrics for " << num_queries << " time(s)."
103               << std::endl;
104     tensorflow::profiler::client::StartMonitoring(
105         FLAGS_service_addr, duration_ms, FLAGS_monitoring_level, num_queries);
106   } else {
107     status = tensorflow::profiler::client::StartTracing(
108         FLAGS_service_addr, FLAGS_logdir, FLAGS_workers_list,
109         FLAGS_include_dataset_ops, duration_ms, num_tracing_attempts);
110     if (!status.ok() && status.code() != tensorflow::error::Code::UNAVAILABLE) {
111       std::cout << status.error_message() << std::endl;
112       std::cout << usage.c_str() << std::endl;
113       return 2;
114     }
115   }
116   return 0;
117 }
118