1 /* Copyright 2017 The TensorFlow Authors All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/profiler/rpc/client/capture_profile.h"
16 
17 #include <iostream>
18 #include <limits>
19 #include <memory>
20 #include <vector>
21 
22 #include "absl/strings/str_join.h"
23 #include "absl/strings/str_split.h"
24 #include "absl/time/clock.h"
25 #include "absl/time/time.h"
26 #include "tensorflow/core/platform/errors.h"
27 #include "tensorflow/core/platform/host_info.h"
28 #include "tensorflow/core/platform/status.h"
29 #include "tensorflow/core/platform/types.h"
30 #include "tensorflow/core/profiler/convert/xplane_to_profile_response.h"
31 #include "tensorflow/core/profiler/profiler_analysis.pb.h"
32 #include "tensorflow/core/profiler/profiler_options.pb.h"
33 #include "tensorflow/core/profiler/profiler_service.pb.h"
34 #include "tensorflow/core/profiler/rpc/client/profiler_client.h"
35 #include "tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.h"
36 #include "tensorflow/core/profiler/rpc/client/save_profile.h"
37 
38 namespace tensorflow {
39 namespace profiler {
40 namespace {
41 
42 using ::tensorflow::profiler::RemoteProfilerSessionManager;
43 using Response = ::tensorflow::profiler::RemoteProfilerSessionManager::Response;
44 
45 constexpr uint64 kMaxEvents = 1000000;
46 const absl::string_view kXPlanePb = "xplane.pb";
47 
PopulateMonitorRequest(int duration_ms,int monitoring_level,bool timestamp)48 MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level,
49                                       bool timestamp) {
50   MonitorRequest request;
51   request.set_duration_ms(duration_ms);
52   request.set_monitoring_level(monitoring_level);
53   request.set_timestamp(timestamp);
54   return request;
55 }
56 
PopulateProfileRequest(absl::string_view repository_root,absl::string_view session_id,absl::string_view host_name,const RemoteProfilerSessionManagerOptions & options)57 ProfileRequest PopulateProfileRequest(
58     absl::string_view repository_root, absl::string_view session_id,
59     absl::string_view host_name,
60     const RemoteProfilerSessionManagerOptions& options) {
61   ProfileRequest request;
62   // TODO(b/169976117) Remove duration from request.
63   request.set_duration_ms(options.profiler_options().duration_ms());
64   request.set_max_events(kMaxEvents);
65   request.set_repository_root(repository_root.data(), repository_root.size());
66   request.set_session_id(session_id.data(), session_id.size());
67   request.set_host_name(host_name.data(), host_name.size());
68   // These tools are only used by TPU profiler.
69   request.add_tools("trace_viewer");
70   request.add_tools("op_profile");
71   request.add_tools("input_pipeline");
72   request.add_tools("kernel_stats");
73   request.add_tools("memory_viewer");
74   request.add_tools("memory_profile");
75   request.add_tools("overview_page");
76   request.add_tools("pod_viewer");
77   request.add_tools("tensorflow_stats");
78   // XPlane tool is only used by OSS profiler and safely ignored by TPU
79   // profiler.
80   request.add_tools(kXPlanePb.data(), kXPlanePb.size());
81   *request.mutable_opts() = options.profiler_options();
82   return request;
83 }
84 
PopulateNewProfileSessionRequest(absl::string_view repository_root,absl::string_view session_id,const RemoteProfilerSessionManagerOptions & opts)85 NewProfileSessionRequest PopulateNewProfileSessionRequest(
86     absl::string_view repository_root, absl::string_view session_id,
87     const RemoteProfilerSessionManagerOptions& opts) {
88   NewProfileSessionRequest request;
89   std::vector<absl::string_view> parts =
90       absl::StrSplit(opts.service_addresses(0), ':');
91   DCHECK(!parts.empty());
92 
93   *request.mutable_request() =
94       PopulateProfileRequest(repository_root, session_id, parts[0], opts);
95   request.set_repository_root(repository_root.data(), repository_root.size());
96   request.set_session_id(session_id.data(), session_id.size());
97   for (const auto& hostname : opts.service_addresses()) {
98     request.add_hosts(hostname);
99   }
100   return request;
101 }
102 
ShouldRetryTracing(Status status)103 inline bool ShouldRetryTracing(Status status) {
104   return status.code() == error::Code::UNAVAILABLE ||
105          status.code() == error::Code::ALREADY_EXISTS ||
106          // When auto-reconnecting to a remote TensorFlow worker after it
107          // restarts, gRPC can return an UNKNOWN error code with a "Stream
108          // removed" error message. This should not be treated as an
109          // unrecoverable error.
110          (status.code() == error::Code::UNKNOWN &&
111           status.error_message() == "Stream removed");
112 }
113 
Profile(const std::string & repository_root,const std::string & session_id,const RemoteProfilerSessionManagerOptions & opts)114 Status Profile(const std::string& repository_root,
115                const std::string& session_id,
116                const RemoteProfilerSessionManagerOptions& opts) {
117   Status status;
118   // Host name will be overwritten by RemoteProfilerSessionManager later.
119   ProfileRequest request = PopulateProfileRequest(repository_root, session_id,
120                                                   /*host_name=*/"", opts);
121   auto session = RemoteProfilerSessionManager::Create(opts, request, status);
122   TF_RETURN_IF_ERROR(status);
123   // Expect one or more service addresses.
124   DCHECK_GT(opts.service_addresses_size(), 0);
125   std::vector<Response> responses = session->WaitForCompletion();
126   // Expect responses to have the same size as clients.
127   DCHECK_EQ(responses.size(), opts.service_addresses_size());
128 
129   bool has_trace_data = false;
130   for (const auto& client_response : responses) {
131     ProfileResponse& response = *client_response.profile_response;
132     if (response.empty_trace()) {
133       LOG(WARNING) << "No trace event is collected from "
134                    << client_response.service_address;
135     } else {
136       has_trace_data = true;
137       // If server side returns tool data in the response, saves that into the
138       // repository. This improves backward compatibility by reducing assumption
139       // of what server side does.
140       TF_RETURN_IF_ERROR(SaveProfile(repository_root, session_id,
141                                      client_response.service_address, response,
142                                      &std::cout));
143     }
144     if (!client_response.status.ok()) {
145       LOG(WARNING) << client_response.service_address << " returned "
146                    << client_response.status;
147     }
148   }
149 
150   if (!has_trace_data) {
151     return Status(error::Code::UNAVAILABLE,
152                   "No trace event was collected because there were no responses"
153                   " from clients or the responses did not have trace data.");
154   }
155   return Status::OK();
156 }
157 
158 // Start a new profiling session that include all the hosts included in
159 // hostnames, for the time interval of duration_ms. Possibly save the profiling
160 // result in the directory specified by repository_root and session_id.
NewSession(absl::string_view repository_root,absl::string_view session_id,const RemoteProfilerSessionManagerOptions & opts)161 Status NewSession(absl::string_view repository_root,
162                   absl::string_view session_id,
163                   const RemoteProfilerSessionManagerOptions& opts) {
164   NewProfileSessionRequest request =
165       PopulateNewProfileSessionRequest(repository_root, session_id, opts);
166   NewProfileSessionResponse response;
167   TF_RETURN_IF_ERROR(
168       NewSessionGrpc(opts.service_addresses(0), request, &response));
169 
170   std::cout << "Profile session succeed for host(s):"
171             << absl::StrJoin(opts.service_addresses(), ",") << std::endl;
172   if (response.empty_trace()) {
173     return errors::Unavailable("No trace event is collected");
174   }
175   return Status::OK();
176 }
177 
178 }  // namespace
179 
Trace(const std::string & logdir,int num_tracing_attempts,RemoteProfilerSessionManagerOptions & opts,bool is_cloud_tpu_session)180 Status Trace(const std::string& logdir, int num_tracing_attempts,
181              RemoteProfilerSessionManagerOptions& opts,
182              bool is_cloud_tpu_session) {
183   DCHECK_GT(opts.profiler_options().duration_ms(), 0);
184   DCHECK(!opts.service_addresses().empty());
185 
186   // Use the current timestamp as the run name.
187   std::string session_id = GetCurrentTimeStampAsString();
188   std::string repository_root = GetTensorBoardProfilePluginDir(logdir);
189   auto duration_ms = opts.profiler_options().duration_ms();
190   TF_RETURN_IF_ERROR(MaybeCreateEmptyEventFile(logdir));
191 
192   Status status;
193   int remaining_attempts = num_tracing_attempts;
194   while (true) {
195     auto start_timestamp = absl::Now() + absl::Milliseconds(opts.delay_ms());
196     opts.mutable_profiler_options()->set_start_timestamp_ns(
197         absl::ToUnixNanos(start_timestamp));
198     LOG(INFO) << "Profiler delay_ms was " << opts.delay_ms()
199               << ", start_timestamp_ns set to "
200               << opts.profiler_options().start_timestamp_ns() << " ["
201               << start_timestamp << "]";
202 
203     std::cout << "Starting to trace for " << duration_ms << " ms. "
204               << "Remaining attempt(s): " << --remaining_attempts << std::endl;
205 
206     if (is_cloud_tpu_session) {
207       status = NewSession(repository_root, session_id, opts);
208     } else {
209       status = Profile(repository_root, session_id, opts);
210     }
211     if (remaining_attempts <= 0 || status.ok() || !ShouldRetryTracing(status))
212       break;
213     std::cout << "No trace event is collected. Automatically retrying.\n"
214               << std::endl;
215   }
216 
217   if (ShouldRetryTracing(status)) {
218     std::cout << "No trace event is collected after " << num_tracing_attempts
219               << " attempt(s). "
220               << "Perhaps, you want to try again (with more attempts?).\n"
221               << "Tip: increase number of attempts with --num_tracing_attempts."
222               << std::endl;
223   }
224   return status;
225 }
226 
Monitor(const std::string & service_addr,int duration_ms,int monitoring_level,bool display_timestamp,std::string * result)227 Status Monitor(const std::string& service_addr, int duration_ms,
228                int monitoring_level, bool display_timestamp,
229                std::string* result) {
230   MonitorRequest request =
231       PopulateMonitorRequest(duration_ms, monitoring_level, display_timestamp);
232   MonitorResponse response;
233   TF_RETURN_IF_ERROR(MonitorGrpc(service_addr, request, &response));
234   *result = response.data();
235   return Status::OK();
236 }
237 
ExportToTensorBoard(const XSpace & xspace,const std::string & logdir)238 Status ExportToTensorBoard(const XSpace& xspace, const std::string& logdir) {
239   TF_RETURN_IF_ERROR(MaybeCreateEmptyEventFile(logdir));
240 
241   ProfileResponse response;
242   ProfileRequest request = PopulateProfileRequest(
243       GetTensorBoardProfilePluginDir(logdir), GetCurrentTimeStampAsString(),
244       port::Hostname(), /*options=*/{});
245   TF_RETURN_IF_ERROR(
246       ConvertXSpaceToProfileResponse(xspace, request, &response));
247   std::stringstream ss;  // Record LOG messages.
248   TF_RETURN_IF_ERROR(SaveProfile(request.repository_root(),
249                                  request.session_id(), request.host_name(),
250                                  response, &ss));
251   LOG(INFO) << ss.str();
252   return Status::OK();
253 }
254 
255 }  // namespace profiler
256 }  // namespace tensorflow
257