1 /* Copyright 2020 The TensorFlow Authors All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/profiler/rpc/client/profiler_client.h"
16 
17 #include <limits>
18 
19 #include "grpcpp/grpcpp.h"
20 #include "absl/memory/memory.h"
21 #include "absl/time/clock.h"
22 #include "absl/time/time.h"
23 #include "tensorflow/core/platform/errors.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/platform/status.h"
26 #include "tensorflow/core/platform/types.h"
27 #include "tensorflow/core/protobuf/error_codes.pb.h"
28 
29 namespace tensorflow {
30 namespace profiler {
31 namespace {
32 
FromGrpcStatus(const::grpc::Status & s)33 inline Status FromGrpcStatus(const ::grpc::Status& s) {
34   return s.ok() ? Status::OK()
35                 : Status(static_cast<error::Code>(s.error_code()),
36                          s.error_message());
37 }
38 
39 template <typename T>
CreateStub(const std::string & service_address)40 std::unique_ptr<typename T::Stub> CreateStub(
41     const std::string& service_address) {
42   ::grpc::ChannelArguments channel_args;
43   channel_args.SetMaxReceiveMessageSize(std::numeric_limits<int32>::max());
44   // Default URI prefix is "dns:///" if not provided.
45   auto channel = ::grpc::CreateCustomChannel(
46       service_address, ::grpc::InsecureChannelCredentials(), channel_args);
47   if (!channel) {
48     LOG(ERROR) << "Unable to create channel" << service_address;
49   }
50   return T::NewStub(channel);
51 }
52 
53 }  // namespace
54 
ProfileGrpc(const std::string & service_address,const ProfileRequest & request,ProfileResponse * response)55 Status ProfileGrpc(const std::string& service_address,
56                    const ProfileRequest& request, ProfileResponse* response) {
57   ::grpc::ClientContext context;
58   std::unique_ptr<grpc::ProfilerService::Stub> stub =
59       CreateStub<grpc::ProfilerService>(service_address);
60   TF_RETURN_IF_ERROR(
61       FromGrpcStatus(stub->Profile(&context, request, response)));
62   return Status::OK();
63 }
64 
NewSessionGrpc(const std::string & service_address,const NewProfileSessionRequest & request,NewProfileSessionResponse * response)65 Status NewSessionGrpc(const std::string& service_address,
66                       const NewProfileSessionRequest& request,
67                       NewProfileSessionResponse* response) {
68   ::grpc::ClientContext context;
69   std::unique_ptr<grpc::ProfileAnalysis::Stub> stub =
70       CreateStub<grpc::ProfileAnalysis>(service_address);
71   TF_RETURN_IF_ERROR(
72       FromGrpcStatus(stub->NewSession(&context, request, response)));
73   return Status::OK();
74 }
75 
MonitorGrpc(const std::string & service_address,const MonitorRequest & request,MonitorResponse * response)76 Status MonitorGrpc(const std::string& service_address,
77                    const MonitorRequest& request, MonitorResponse* response) {
78   ::grpc::ClientContext context;
79   std::unique_ptr<grpc::ProfilerService::Stub> stub =
80       CreateStub<grpc::ProfilerService>(service_address);
81   TF_RETURN_IF_ERROR(
82       FromGrpcStatus(stub->Monitor(&context, request, response)));
83   return Status::OK();
84 }
85 
Create(const std::string & service_address,absl::Time deadline,const ProfileRequest & profile_request)86 /*static*/ std::unique_ptr<RemoteProfilerSession> RemoteProfilerSession::Create(
87     const std::string& service_address, absl::Time deadline,
88     const ProfileRequest& profile_request) {
89   auto instance = absl::WrapUnique(
90       new RemoteProfilerSession(service_address, deadline, profile_request));
91   instance->ProfileAsync();
92   return instance;
93 }
94 
RemoteProfilerSession(const std::string & service_address,absl::Time deadline,const ProfileRequest & profile_request)95 RemoteProfilerSession::RemoteProfilerSession(
96     const std::string& service_address, absl::Time deadline,
97     const ProfileRequest& profile_request)
98     : response_(absl::make_unique<ProfileResponse>()),
99       service_address_(service_address),
100       stub_(CreateStub<grpc::ProfilerService>(service_address_)),
101       deadline_(deadline),
102       profile_request_(profile_request) {
103   response_->set_empty_trace(true);
104 }
105 
~RemoteProfilerSession()106 RemoteProfilerSession::~RemoteProfilerSession() {
107   Status dummy;
108   WaitForCompletion(dummy);
109   grpc_context_.TryCancel();
110 }
111 
ProfileAsync()112 void RemoteProfilerSession::ProfileAsync() {
113   LOG(INFO) << "Asynchronous gRPC Profile() to " << service_address_;
114   grpc_context_.set_deadline(absl::ToChronoTime(deadline_));
115   VLOG(1) << "Deadline set to " << deadline_;
116   rpc_ = stub_->AsyncProfile(&grpc_context_, profile_request_, &cq_);
117   // Connection failure will create lame channel whereby grpc_status_ will be an
118   // error.
119   rpc_->Finish(response_.get(), &grpc_status_,
120                static_cast<void*>(&status_on_completion_));
121   VLOG(2) << "Asynchronous gRPC Profile() issued." << absl::Now();
122 }
123 
WaitForCompletion(Status & out_status)124 std::unique_ptr<ProfileResponse> RemoteProfilerSession::WaitForCompletion(
125     Status& out_status) {
126   if (!response_) {
127     out_status = errors::FailedPrecondition(
128         "WaitForCompletion must only be called once.");
129     return nullptr;
130   }
131   LOG(INFO) << "Waiting for completion.";
132 
133   void* got_tag = nullptr;
134   bool ok = false;
135   // Next blocks until there is a response in the completion queue. Expect the
136   // completion queue to have exactly a single response because deadline is set
137   // and completion queue is only drained once at destruction time.
138   bool success = cq_.Next(&got_tag, &ok);
139   if (!success || !ok || got_tag == nullptr) {
140     out_status =
141         errors::Internal("Missing or invalid event from completion queue.");
142     return nullptr;
143   }
144 
145   VLOG(1) << "Writing out status.";
146   // For the event read from the completion queue, expect that got_tag points to
147   // the memory location of status_on_completion.
148   DCHECK_EQ(got_tag, &status_on_completion_);
149   // tagged status points to pre-allocated memory which is okay to overwrite.
150   status_on_completion_.Update(FromGrpcStatus(grpc_status_));
151   if (status_on_completion_.code() == error::DEADLINE_EXCEEDED) {
152     LOG(WARNING) << status_on_completion_;
153   } else if (!status_on_completion_.ok()) {
154     LOG(ERROR) << status_on_completion_;
155   }
156 
157   out_status = status_on_completion_;
158   return std::move(response_);
159 }
160 
161 }  // namespace profiler
162 }  // namespace tensorflow
163