1 /* Copyright 2017 The TensorFlow Authors All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/profiler/rpc/client/capture_profile.h"
16
#include <iostream>
#include <limits>
#include <memory>
#include <sstream>
#include <vector>
21
22 #include "absl/strings/str_join.h"
23 #include "absl/strings/str_split.h"
24 #include "absl/time/clock.h"
25 #include "absl/time/time.h"
26 #include "tensorflow/core/platform/errors.h"
27 #include "tensorflow/core/platform/host_info.h"
28 #include "tensorflow/core/platform/status.h"
29 #include "tensorflow/core/platform/types.h"
30 #include "tensorflow/core/profiler/convert/xplane_to_profile_response.h"
31 #include "tensorflow/core/profiler/profiler_analysis.pb.h"
32 #include "tensorflow/core/profiler/profiler_options.pb.h"
33 #include "tensorflow/core/profiler/profiler_service.pb.h"
34 #include "tensorflow/core/profiler/rpc/client/profiler_client.h"
35 #include "tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.h"
36 #include "tensorflow/core/profiler/rpc/client/save_profile.h"
37
38 namespace tensorflow {
39 namespace profiler {
40 namespace {
41
42 using ::tensorflow::profiler::RemoteProfilerSessionManager;
43 using Response = ::tensorflow::profiler::RemoteProfilerSessionManager::Response;
44
45 constexpr uint64 kMaxEvents = 1000000;
46 const absl::string_view kXPlanePb = "xplane.pb";
47
PopulateMonitorRequest(int duration_ms,int monitoring_level,bool timestamp)48 MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level,
49 bool timestamp) {
50 MonitorRequest request;
51 request.set_duration_ms(duration_ms);
52 request.set_monitoring_level(monitoring_level);
53 request.set_timestamp(timestamp);
54 return request;
55 }
56
PopulateProfileRequest(absl::string_view repository_root,absl::string_view session_id,absl::string_view host_name,const RemoteProfilerSessionManagerOptions & options)57 ProfileRequest PopulateProfileRequest(
58 absl::string_view repository_root, absl::string_view session_id,
59 absl::string_view host_name,
60 const RemoteProfilerSessionManagerOptions& options) {
61 ProfileRequest request;
62 // TODO(b/169976117) Remove duration from request.
63 request.set_duration_ms(options.profiler_options().duration_ms());
64 request.set_max_events(kMaxEvents);
65 request.set_repository_root(repository_root.data(), repository_root.size());
66 request.set_session_id(session_id.data(), session_id.size());
67 request.set_host_name(host_name.data(), host_name.size());
68 // These tools are only used by TPU profiler.
69 request.add_tools("trace_viewer");
70 request.add_tools("op_profile");
71 request.add_tools("input_pipeline");
72 request.add_tools("kernel_stats");
73 request.add_tools("memory_viewer");
74 request.add_tools("memory_profile");
75 request.add_tools("overview_page");
76 request.add_tools("pod_viewer");
77 request.add_tools("tensorflow_stats");
78 // XPlane tool is only used by OSS profiler and safely ignored by TPU
79 // profiler.
80 request.add_tools(kXPlanePb.data(), kXPlanePb.size());
81 *request.mutable_opts() = options.profiler_options();
82 return request;
83 }
84
PopulateNewProfileSessionRequest(absl::string_view repository_root,absl::string_view session_id,const RemoteProfilerSessionManagerOptions & opts)85 NewProfileSessionRequest PopulateNewProfileSessionRequest(
86 absl::string_view repository_root, absl::string_view session_id,
87 const RemoteProfilerSessionManagerOptions& opts) {
88 NewProfileSessionRequest request;
89 std::vector<absl::string_view> parts =
90 absl::StrSplit(opts.service_addresses(0), ':');
91 DCHECK(!parts.empty());
92
93 *request.mutable_request() =
94 PopulateProfileRequest(repository_root, session_id, parts[0], opts);
95 request.set_repository_root(repository_root.data(), repository_root.size());
96 request.set_session_id(session_id.data(), session_id.size());
97 for (const auto& hostname : opts.service_addresses()) {
98 request.add_hosts(hostname);
99 }
100 return request;
101 }
102
ShouldRetryTracing(Status status)103 inline bool ShouldRetryTracing(Status status) {
104 return status.code() == error::Code::UNAVAILABLE ||
105 status.code() == error::Code::ALREADY_EXISTS ||
106 // When auto-reconnecting to a remote TensorFlow worker after it
107 // restarts, gRPC can return an UNKNOWN error code with a "Stream
108 // removed" error message. This should not be treated as an
109 // unrecoverable error.
110 (status.code() == error::Code::UNKNOWN &&
111 status.error_message() == "Stream removed");
112 }
113
Profile(const std::string & repository_root,const std::string & session_id,const RemoteProfilerSessionManagerOptions & opts)114 Status Profile(const std::string& repository_root,
115 const std::string& session_id,
116 const RemoteProfilerSessionManagerOptions& opts) {
117 Status status;
118 // Host name will be overwritten by RemoteProfilerSessionManager later.
119 ProfileRequest request = PopulateProfileRequest(repository_root, session_id,
120 /*host_name=*/"", opts);
121 auto session = RemoteProfilerSessionManager::Create(opts, request, status);
122 TF_RETURN_IF_ERROR(status);
123 // Expect one or more service addresses.
124 DCHECK_GT(opts.service_addresses_size(), 0);
125 std::vector<Response> responses = session->WaitForCompletion();
126 // Expect responses to have the same size as clients.
127 DCHECK_EQ(responses.size(), opts.service_addresses_size());
128
129 bool has_trace_data = false;
130 for (const auto& client_response : responses) {
131 ProfileResponse& response = *client_response.profile_response;
132 if (response.empty_trace()) {
133 LOG(WARNING) << "No trace event is collected from "
134 << client_response.service_address;
135 } else {
136 has_trace_data = true;
137 // If server side returns tool data in the response, saves that into the
138 // repository. This improves backward compatibility by reducing assumption
139 // of what server side does.
140 TF_RETURN_IF_ERROR(SaveProfile(repository_root, session_id,
141 client_response.service_address, response,
142 &std::cout));
143 }
144 if (!client_response.status.ok()) {
145 LOG(WARNING) << client_response.service_address << " returned "
146 << client_response.status;
147 }
148 }
149
150 if (!has_trace_data) {
151 return Status(error::Code::UNAVAILABLE,
152 "No trace event was collected because there were no responses"
153 " from clients or the responses did not have trace data.");
154 }
155 return Status::OK();
156 }
157
158 // Start a new profiling session that include all the hosts included in
159 // hostnames, for the time interval of duration_ms. Possibly save the profiling
160 // result in the directory specified by repository_root and session_id.
NewSession(absl::string_view repository_root,absl::string_view session_id,const RemoteProfilerSessionManagerOptions & opts)161 Status NewSession(absl::string_view repository_root,
162 absl::string_view session_id,
163 const RemoteProfilerSessionManagerOptions& opts) {
164 NewProfileSessionRequest request =
165 PopulateNewProfileSessionRequest(repository_root, session_id, opts);
166 NewProfileSessionResponse response;
167 TF_RETURN_IF_ERROR(
168 NewSessionGrpc(opts.service_addresses(0), request, &response));
169
170 std::cout << "Profile session succeed for host(s):"
171 << absl::StrJoin(opts.service_addresses(), ",") << std::endl;
172 if (response.empty_trace()) {
173 return errors::Unavailable("No trace event is collected");
174 }
175 return Status::OK();
176 }
177
178 } // namespace
179
Trace(const std::string & logdir,int num_tracing_attempts,RemoteProfilerSessionManagerOptions & opts,bool is_cloud_tpu_session)180 Status Trace(const std::string& logdir, int num_tracing_attempts,
181 RemoteProfilerSessionManagerOptions& opts,
182 bool is_cloud_tpu_session) {
183 DCHECK_GT(opts.profiler_options().duration_ms(), 0);
184 DCHECK(!opts.service_addresses().empty());
185
186 // Use the current timestamp as the run name.
187 std::string session_id = GetCurrentTimeStampAsString();
188 std::string repository_root = GetTensorBoardProfilePluginDir(logdir);
189 auto duration_ms = opts.profiler_options().duration_ms();
190 TF_RETURN_IF_ERROR(MaybeCreateEmptyEventFile(logdir));
191
192 Status status;
193 int remaining_attempts = num_tracing_attempts;
194 while (true) {
195 auto start_timestamp = absl::Now() + absl::Milliseconds(opts.delay_ms());
196 opts.mutable_profiler_options()->set_start_timestamp_ns(
197 absl::ToUnixNanos(start_timestamp));
198 LOG(INFO) << "Profiler delay_ms was " << opts.delay_ms()
199 << ", start_timestamp_ns set to "
200 << opts.profiler_options().start_timestamp_ns() << " ["
201 << start_timestamp << "]";
202
203 std::cout << "Starting to trace for " << duration_ms << " ms. "
204 << "Remaining attempt(s): " << --remaining_attempts << std::endl;
205
206 if (is_cloud_tpu_session) {
207 status = NewSession(repository_root, session_id, opts);
208 } else {
209 status = Profile(repository_root, session_id, opts);
210 }
211 if (remaining_attempts <= 0 || status.ok() || !ShouldRetryTracing(status))
212 break;
213 std::cout << "No trace event is collected. Automatically retrying.\n"
214 << std::endl;
215 }
216
217 if (ShouldRetryTracing(status)) {
218 std::cout << "No trace event is collected after " << num_tracing_attempts
219 << " attempt(s). "
220 << "Perhaps, you want to try again (with more attempts?).\n"
221 << "Tip: increase number of attempts with --num_tracing_attempts."
222 << std::endl;
223 }
224 return status;
225 }
226
Monitor(const std::string & service_addr,int duration_ms,int monitoring_level,bool display_timestamp,std::string * result)227 Status Monitor(const std::string& service_addr, int duration_ms,
228 int monitoring_level, bool display_timestamp,
229 std::string* result) {
230 MonitorRequest request =
231 PopulateMonitorRequest(duration_ms, monitoring_level, display_timestamp);
232 MonitorResponse response;
233 TF_RETURN_IF_ERROR(MonitorGrpc(service_addr, request, &response));
234 *result = response.data();
235 return Status::OK();
236 }
237
ExportToTensorBoard(const XSpace & xspace,const std::string & logdir)238 Status ExportToTensorBoard(const XSpace& xspace, const std::string& logdir) {
239 TF_RETURN_IF_ERROR(MaybeCreateEmptyEventFile(logdir));
240
241 ProfileResponse response;
242 ProfileRequest request = PopulateProfileRequest(
243 GetTensorBoardProfilePluginDir(logdir), GetCurrentTimeStampAsString(),
244 port::Hostname(), /*options=*/{});
245 TF_RETURN_IF_ERROR(
246 ConvertXSpaceToProfileResponse(xspace, request, &response));
247 std::stringstream ss; // Record LOG messages.
248 TF_RETURN_IF_ERROR(SaveProfile(request.repository_root(),
249 request.session_id(), request.host_name(),
250 response, &ss));
251 LOG(INFO) << ss.str();
252 return Status::OK();
253 }
254
255 } // namespace profiler
256 } // namespace tensorflow
257