1 /* Copyright 2020 The TensorFlow Authors All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/profiler/rpc/client/profiler_client.h"
16
17 #include <limits>
18
19 #include "grpcpp/grpcpp.h"
20 #include "absl/memory/memory.h"
21 #include "absl/time/clock.h"
22 #include "absl/time/time.h"
23 #include "tensorflow/core/platform/errors.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/platform/status.h"
26 #include "tensorflow/core/platform/types.h"
27 #include "tensorflow/core/protobuf/error_codes.pb.h"
28
29 namespace tensorflow {
30 namespace profiler {
31 namespace {
32
FromGrpcStatus(const::grpc::Status & s)33 inline Status FromGrpcStatus(const ::grpc::Status& s) {
34 return s.ok() ? Status::OK()
35 : Status(static_cast<error::Code>(s.error_code()),
36 s.error_message());
37 }
38
39 template <typename T>
CreateStub(const std::string & service_address)40 std::unique_ptr<typename T::Stub> CreateStub(
41 const std::string& service_address) {
42 ::grpc::ChannelArguments channel_args;
43 channel_args.SetMaxReceiveMessageSize(std::numeric_limits<int32>::max());
44 // Default URI prefix is "dns:///" if not provided.
45 auto channel = ::grpc::CreateCustomChannel(
46 service_address, ::grpc::InsecureChannelCredentials(), channel_args);
47 if (!channel) {
48 LOG(ERROR) << "Unable to create channel" << service_address;
49 }
50 return T::NewStub(channel);
51 }
52
53 } // namespace
54
ProfileGrpc(const std::string & service_address,const ProfileRequest & request,ProfileResponse * response)55 Status ProfileGrpc(const std::string& service_address,
56 const ProfileRequest& request, ProfileResponse* response) {
57 ::grpc::ClientContext context;
58 std::unique_ptr<grpc::ProfilerService::Stub> stub =
59 CreateStub<grpc::ProfilerService>(service_address);
60 TF_RETURN_IF_ERROR(
61 FromGrpcStatus(stub->Profile(&context, request, response)));
62 return Status::OK();
63 }
64
NewSessionGrpc(const std::string & service_address,const NewProfileSessionRequest & request,NewProfileSessionResponse * response)65 Status NewSessionGrpc(const std::string& service_address,
66 const NewProfileSessionRequest& request,
67 NewProfileSessionResponse* response) {
68 ::grpc::ClientContext context;
69 std::unique_ptr<grpc::ProfileAnalysis::Stub> stub =
70 CreateStub<grpc::ProfileAnalysis>(service_address);
71 TF_RETURN_IF_ERROR(
72 FromGrpcStatus(stub->NewSession(&context, request, response)));
73 return Status::OK();
74 }
75
MonitorGrpc(const std::string & service_address,const MonitorRequest & request,MonitorResponse * response)76 Status MonitorGrpc(const std::string& service_address,
77 const MonitorRequest& request, MonitorResponse* response) {
78 ::grpc::ClientContext context;
79 std::unique_ptr<grpc::ProfilerService::Stub> stub =
80 CreateStub<grpc::ProfilerService>(service_address);
81 TF_RETURN_IF_ERROR(
82 FromGrpcStatus(stub->Monitor(&context, request, response)));
83 return Status::OK();
84 }
85
Create(const std::string & service_address,absl::Time deadline,const ProfileRequest & profile_request)86 /*static*/ std::unique_ptr<RemoteProfilerSession> RemoteProfilerSession::Create(
87 const std::string& service_address, absl::Time deadline,
88 const ProfileRequest& profile_request) {
89 auto instance = absl::WrapUnique(
90 new RemoteProfilerSession(service_address, deadline, profile_request));
91 instance->ProfileAsync();
92 return instance;
93 }
94
RemoteProfilerSession(const std::string & service_address,absl::Time deadline,const ProfileRequest & profile_request)95 RemoteProfilerSession::RemoteProfilerSession(
96 const std::string& service_address, absl::Time deadline,
97 const ProfileRequest& profile_request)
98 : response_(absl::make_unique<ProfileResponse>()),
99 service_address_(service_address),
100 stub_(CreateStub<grpc::ProfilerService>(service_address_)),
101 deadline_(deadline),
102 profile_request_(profile_request) {
103 response_->set_empty_trace(true);
104 }
105
~RemoteProfilerSession()106 RemoteProfilerSession::~RemoteProfilerSession() {
107 Status dummy;
108 WaitForCompletion(dummy);
109 grpc_context_.TryCancel();
110 }
111
ProfileAsync()112 void RemoteProfilerSession::ProfileAsync() {
113 LOG(INFO) << "Asynchronous gRPC Profile() to " << service_address_;
114 grpc_context_.set_deadline(absl::ToChronoTime(deadline_));
115 VLOG(1) << "Deadline set to " << deadline_;
116 rpc_ = stub_->AsyncProfile(&grpc_context_, profile_request_, &cq_);
117 // Connection failure will create lame channel whereby grpc_status_ will be an
118 // error.
119 rpc_->Finish(response_.get(), &grpc_status_,
120 static_cast<void*>(&status_on_completion_));
121 VLOG(2) << "Asynchronous gRPC Profile() issued." << absl::Now();
122 }
123
WaitForCompletion(Status & out_status)124 std::unique_ptr<ProfileResponse> RemoteProfilerSession::WaitForCompletion(
125 Status& out_status) {
126 if (!response_) {
127 out_status = errors::FailedPrecondition(
128 "WaitForCompletion must only be called once.");
129 return nullptr;
130 }
131 LOG(INFO) << "Waiting for completion.";
132
133 void* got_tag = nullptr;
134 bool ok = false;
135 // Next blocks until there is a response in the completion queue. Expect the
136 // completion queue to have exactly a single response because deadline is set
137 // and completion queue is only drained once at destruction time.
138 bool success = cq_.Next(&got_tag, &ok);
139 if (!success || !ok || got_tag == nullptr) {
140 out_status =
141 errors::Internal("Missing or invalid event from completion queue.");
142 return nullptr;
143 }
144
145 VLOG(1) << "Writing out status.";
146 // For the event read from the completion queue, expect that got_tag points to
147 // the memory location of status_on_completion.
148 DCHECK_EQ(got_tag, &status_on_completion_);
149 // tagged status points to pre-allocated memory which is okay to overwrite.
150 status_on_completion_.Update(FromGrpcStatus(grpc_status_));
151 if (status_on_completion_.code() == error::DEADLINE_EXCEEDED) {
152 LOG(WARNING) << status_on_completion_;
153 } else if (!status_on_completion_.ok()) {
154 LOG(ERROR) << status_on_completion_;
155 }
156
157 out_status = status_on_completion_;
158 return std::move(response_);
159 }
160
161 } // namespace profiler
162 } // namespace tensorflow
163