1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_ 18 19 #include <functional> 20 21 #include "tensorflow/core/distributed_runtime/call_options.h" 22 #include "tensorflow/core/distributed_runtime/message_wrappers.h" 23 #include "tensorflow/core/lib/core/notification.h" 24 #include "tensorflow/core/lib/core/status.h" 25 #include "tensorflow/core/platform/types.h" 26 #include "tensorflow/core/protobuf/worker.pb.h" 27 28 namespace tensorflow { 29 30 // Status callback. 31 typedef std::function<void(const Status&)> StatusCallback; 32 33 // Custom decoder for a response to RecvTensorAsync. 34 class TensorResponse; 35 36 // Interface for talking with the TensorFlow Worker service. 37 class WorkerInterface { 38 public: 39 virtual void GetStatusAsync(CallOptions* opts, 40 const GetStatusRequest* request, 41 GetStatusResponse* response, bool fail_fast, 42 StatusCallback done) = 0; 43 44 virtual void CreateWorkerSessionAsync( 45 const CreateWorkerSessionRequest* request, 46 CreateWorkerSessionResponse* response, StatusCallback done) = 0; 47 48 virtual void DeleteWorkerSessionAsync( 49 CallOptions* opts, const DeleteWorkerSessionRequest* request, 50 DeleteWorkerSessionResponse* response, StatusCallback done) = 0; 51 52 virtual void RegisterGraphAsync(const RegisterGraphRequest* request, 53 RegisterGraphResponse* response, 54 StatusCallback done) = 0; 55 56 virtual void DeregisterGraphAsync(const DeregisterGraphRequest* request, 57 DeregisterGraphResponse* response, 58 StatusCallback done) = 0; 59 60 virtual void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request, 61 MutableRunGraphResponseWrapper* response, 62 StatusCallback done) = 0; 63 RunGraphAsync(CallOptions * opts,const RunGraphRequest * request,RunGraphResponse * response,StatusCallback done)64 virtual void RunGraphAsync(CallOptions* opts, const RunGraphRequest* request, 65 RunGraphResponse* response, StatusCallback done) { 66 RunGraphRequestWrapper* wrapped_request = new ProtoRunGraphRequest(request); 67 MutableRunGraphResponseWrapper* wrapped_response = 68 new NonOwnedProtoRunGraphResponse(response); 69 RunGraphAsync(opts, wrapped_request, wrapped_response, 70 [wrapped_request, wrapped_response, 71 done = std::move(done)](const Status& s) { 72 done(s); 73 delete wrapped_request; 74 delete wrapped_response; 75 }); 76 } 77 78 // Returns a request object for use in calls to 79 // `RunGraphAsync()`. Ownership is transferred to the caller. 80 // 81 // The message returned from this method must only be used in a 82 // `RunGraph()` call on the same `WorkerInterface` instance. CreateRunGraphRequest()83 virtual MutableRunGraphRequestWrapper* CreateRunGraphRequest() { 84 return new MutableProtoRunGraphRequest; 85 } 86 87 // Returns a response object for use in calls to 88 // `RunGraphAsync()`. Ownership is transferred to the caller. 89 // 90 // The message returned from this method must only be used in a 91 // `RunGraph()` call on the same `WorkerInterface` instance. CreateRunGraphResponse()92 virtual MutableRunGraphResponseWrapper* CreateRunGraphResponse() { 93 return new OwnedProtoRunGraphResponse; 94 } 95 96 virtual void CleanupGraphAsync(const CleanupGraphRequest* request, 97 CleanupGraphResponse* response, 98 StatusCallback done) = 0; 99 100 virtual void CleanupAllAsync(const CleanupAllRequest* request, 101 CleanupAllResponse* response, 102 StatusCallback done) = 0; 103 104 virtual void RecvTensorAsync(CallOptions* opts, 105 const RecvTensorRequest* request, 106 TensorResponse* response, 107 StatusCallback done) = 0; 108 109 virtual void LoggingAsync(const LoggingRequest* request, 110 LoggingResponse* response, StatusCallback done) = 0; 111 112 virtual void TracingAsync(const TracingRequest* request, 113 TracingResponse* response, StatusCallback done) = 0; 114 115 virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, 116 RecvBufResponse* response, StatusCallback done) = 0; 117 118 virtual void CompleteGroupAsync(CallOptions* opts, 119 const CompleteGroupRequest* request, 120 CompleteGroupResponse* response, 121 StatusCallback done) = 0; 122 123 virtual void CompleteInstanceAsync(CallOptions* ops, 124 const CompleteInstanceRequest* request, 125 CompleteInstanceResponse* response, 126 StatusCallback done) = 0; 127 128 virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request, 129 GetStepSequenceResponse* response, 130 StatusCallback done) = 0; 131 GetStatus(const GetStatusRequest * request,GetStatusResponse * response)132 Status GetStatus(const GetStatusRequest* request, 133 GetStatusResponse* response) { 134 Status ret; 135 Notification n; 136 GetStatusAsync(/*opts=*/nullptr, request, response, /*fail_fast=*/true, 137 [&ret, &n](const Status& s) { 138 ret = s; 139 n.Notify(); 140 }); 141 n.WaitForNotification(); 142 return ret; 143 } 144 CreateWorkerSession(const CreateWorkerSessionRequest * request,CreateWorkerSessionResponse * response)145 Status CreateWorkerSession(const CreateWorkerSessionRequest* request, 146 CreateWorkerSessionResponse* response) { 147 return CallAndWait(&ME::CreateWorkerSessionAsync, request, response); 148 } 149 DeleteWorkerSession(const DeleteWorkerSessionRequest * request,DeleteWorkerSessionResponse * response)150 Status DeleteWorkerSession(const DeleteWorkerSessionRequest* request, 151 DeleteWorkerSessionResponse* response) { 152 return CallAndWaitWithOptions(&ME::DeleteWorkerSessionAsync, request, 153 response); 154 } 155 RegisterGraph(const RegisterGraphRequest * request,RegisterGraphResponse * response)156 Status RegisterGraph(const RegisterGraphRequest* request, 157 RegisterGraphResponse* response) { 158 return CallAndWait(&ME::RegisterGraphAsync, request, response); 159 } 160 DeregisterGraph(const DeregisterGraphRequest * request,DeregisterGraphResponse * response)161 Status DeregisterGraph(const DeregisterGraphRequest* request, 162 DeregisterGraphResponse* response) { 163 return CallAndWait(&ME::DeregisterGraphAsync, request, response); 164 } 165 CleanupGraph(const CleanupGraphRequest * request,CleanupGraphResponse * response)166 Status CleanupGraph(const CleanupGraphRequest* request, 167 CleanupGraphResponse* response) { 168 return CallAndWait(&ME::CleanupGraphAsync, request, response); 169 } 170 CleanupAll(const CleanupAllRequest * request,CleanupAllResponse * response)171 Status CleanupAll(const CleanupAllRequest* request, 172 CleanupAllResponse* response) { 173 return CallAndWait(&ME::CleanupAllAsync, request, response); 174 } 175 Logging(const LoggingRequest * request,LoggingResponse * response)176 Status Logging(const LoggingRequest* request, LoggingResponse* response) { 177 return CallAndWait(&ME::LoggingAsync, request, response); 178 } 179 Tracing(const TracingRequest * request,TracingResponse * response)180 Status Tracing(const TracingRequest* request, TracingResponse* response) { 181 return CallAndWait(&ME::TracingAsync, request, response); 182 } 183 GetStepSequence(const GetStepSequenceRequest * request,GetStepSequenceResponse * response)184 Status GetStepSequence(const GetStepSequenceRequest* request, 185 GetStepSequenceResponse* response) { 186 return CallAndWait(&ME::GetStepSequenceAsync, request, response); 187 } 188 189 protected: 190 // Instances of WorkerInterface must be deleted by a call to 191 // WorkerCacheInterface::ReleaseWorker(). ~WorkerInterface()192 virtual ~WorkerInterface() {} 193 friend class WorkerCacheInterface; 194 195 // NOTE: This should only be called by implementations of this 196 // interface whose CreateRunGraphResponse() method returns a 197 // proto-based wrappers for the RunGraphResponse message. get_proto_from_wrapper(MutableRunGraphResponseWrapper * wrapper)198 RunGraphResponse* get_proto_from_wrapper( 199 MutableRunGraphResponseWrapper* wrapper) { 200 return wrapper->get_proto(); 201 } 202 203 private: 204 typedef WorkerInterface ME; 205 206 template <typename Method, typename Req, typename Resp> CallAndWait(Method func,const Req * req,Resp * resp)207 Status CallAndWait(Method func, const Req* req, Resp* resp) { 208 Status ret; 209 Notification n; 210 (this->*func)(req, resp, [&ret, &n](const Status& s) { 211 ret = s; 212 n.Notify(); 213 }); 214 n.WaitForNotification(); 215 return ret; 216 } 217 218 template <typename Method, typename Req, typename Resp> CallAndWaitWithOptions(Method func,const Req * req,Resp * resp)219 Status CallAndWaitWithOptions(Method func, const Req* req, Resp* resp) { 220 CallOptions call_opts; 221 Status ret; 222 Notification n; 223 (this->*func)(&call_opts, req, resp, [&ret, &n](const Status& s) { 224 ret = s; 225 n.Notify(); 226 }); 227 n.WaitForNotification(); 228 return ret; 229 } 230 }; 231 232 } // namespace tensorflow 233 234 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_ 235