/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"

#include <unordered_map>

#include "tensorflow/core/common_runtime/session_factory.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/local_master.h"
#include "tensorflow/core/distributed_runtime/master_interface.h"
#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/master.pb.h"

namespace tensorflow {

const char* const kSchemePrefix = "grpc://";
const size_t kSchemePrefixLength = strlen(kSchemePrefix);
GrpcSession::GrpcSession(const SessionOptions& options)
    : options_(options), current_graph_version_(-1) {}

GrpcSession::~GrpcSession() {}

/* static */
Status GrpcSession::Create(const SessionOptions& options,
                           std::unique_ptr<GrpcSession>* out_session) {
  std::unique_ptr<GrpcSession> session(new GrpcSession(options));
  std::unique_ptr<MasterInterface> master;
  // For testing, we enable the client to disable the use of the local
  // master registry, so that the RPC stack is exercised.
  if (!options.config.rpc_options().use_rpc_for_inprocess_master()) {
    master = LocalMaster::Lookup(options.target);
  }
  if (!master) {
    SharedGrpcChannelPtr master_channel;
    TF_RETURN_IF_ERROR(
        NewHostPortGrpcChannel(options.target.substr(kSchemePrefixLength),
                               &options.config.rpc_options(), &master_channel));
    master.reset(NewGrpcMaster(master_channel));
  }
  session->SetRemoteMaster(std::move(master));
  *out_session = std::move(session);
  return Status::OK();
}

namespace {
// Re-encodes constants represented in the tensor proto into
// tensor_content, which is slightly better (fewer copies and lower peak
// memory usage) when used with RPC subsystems.
void ReEncodeConsts(GraphDef* gdef) {
  for (NodeDef& ndef : *(gdef->mutable_node())) {
    if (ndef.op() == "Const") {
      TensorProto* proto = nullptr;
      for (auto& attr : *ndef.mutable_attr()) {
        if (attr.first == "value") {
          proto = attr.second.mutable_tensor();
        }
      }
      if (proto != nullptr && proto->tensor_content().empty() &&
          proto->ByteSizeLong() > 64) {
        // If the constant is encoded with repeated proto fields and
        // it is moderately large, we re-encode it in tensor_content as
        // a Cord. This is mildly helpful for reducing the peak memory
        // usage on the server side, where GraphDef/NodeDef are copied
        // quite often.
        Tensor parsed(proto->dtype());
        if (parsed.FromProto(*proto)) {
          parsed.AsProtoTensorContent(proto);
        }
      }
    }
  }
}
}  // namespace

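// Records the session handle and graph version returned by the master.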
void GrpcSession::SetHandleAndGraphVersion(string handle,
                                           int64 graph_version) {
  mutex_lock l(mu_);
  handle_ = std::move(handle);
  current_graph_version_ = graph_version;
}

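// Copies the handle of the session created on the master into '*out_handle',
// or returns an error if the session has not been created yet.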
Status GrpcSession::Handle(string* out_handle) {
  mutex_lock l(mu_);
  if (handle_.empty()) {
    return errors::InvalidArgument("A session is not created yet....");
  }
  *out_handle = handle_;
  return Status::OK();
}

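// Creates a session on the master with this session's ConfigProto and the
// given graph, then records the returned session handle and graph version.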
Status GrpcSession::CreateImpl(CallOptions* call_options,
                               const GraphDef& graph) {
  {
    mutex_lock l(mu_);
    if (!handle_.empty()) {
      return errors::InvalidArgument("A session is alive.");
    }
  }
  CreateSessionRequest req;
  *req.mutable_config() = options_.config;
  *req.mutable_graph_def() = graph;
  req.set_target(options_.target);
  ReEncodeConsts(req.mutable_graph_def());
  CreateSessionResponse resp;
  Status s = master_->CreateSession(call_options, &req, &resp);
  if (s.ok()) {
    SetHandleAndGraphVersion(resp.session_handle(), resp.graph_version());
  }
  return s;
}

Status GrpcSession::Create(const GraphDef& graph) {
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return CreateImpl(&call_options, graph);
}

Status GrpcSession::Create(const RunOptions& run_options,
                           const GraphDef& graph) {
  CallOptions call_options;
  call_options.SetTimeout(run_options.timeout_in_ms());
  return CreateImpl(&call_options, graph);
}

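// Extends the graph registered with the master by 'graph'. If the session has
// not been created yet, this falls back to creating it with 'graph'.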
Status GrpcSession::ExtendImpl(CallOptions* call_options,
                               const GraphDef& graph) {
  bool handle_is_empty;
  {
    mutex_lock l(mu_);
    handle_is_empty = handle_.empty();
  }
  if (handle_is_empty) {
    // Session was uninitialized, so simply initialize the session with 'graph'.
    return Create(graph);
  }
  mutex_lock l(mu_);
  ExtendSessionRequest req;
  req.set_session_handle(handle_);
  *req.mutable_graph_def() = graph;
  req.set_current_graph_version(current_graph_version_);
  ExtendSessionResponse resp;
  Status s = master_->ExtendSession(call_options, &req, &resp);
  if (s.ok()) {
    current_graph_version_ = resp.new_graph_version();
  }
  return s;
}

Status GrpcSession::Extend(const GraphDef& graph) {
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return ExtendImpl(&call_options, graph);
}

Status GrpcSession::Extend(const RunOptions& run_options,
                           const GraphDef& graph) {
  CallOptions call_options;
  call_options.SetTimeout(run_options.timeout_in_ms());
  return ExtendImpl(&call_options, graph);
}

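// Shared implementation of Run() and PRun(): builds a RunStep request from the
// feeds, fetches and targets, issues it to the master, and unpacks the
// returned tensors into 'outputs' in the order requested by the caller.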
Status GrpcSession::RunHelper(
    const RunOptions& run_options,
    const std::vector<std::pair<string, Tensor>>& inputs,
    const std::vector<string>& output_tensor_names,
    const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
    RunMetadata* run_metadata, const string& prun_handle) {
  // Convert to proto
  std::unique_ptr<MutableRunStepRequestWrapper> req(
      master_->CreateRunStepRequest());
  std::unique_ptr<MutableRunStepResponseWrapper> resp(
      master_->CreateRunStepResponse());

  *req->mutable_options() = run_options;

  if (run_options.timeout_in_ms() == 0) {
    req->mutable_options()->set_timeout_in_ms(
        options_.config.operation_timeout_in_ms());
  }

  if (!prun_handle.empty()) {
    req->set_partial_run_handle(prun_handle);
  }

  for (const auto& it : inputs) {
    req->add_feed(it.first, it.second);
  }

  // Support long error messages by storing the error code in the response body.
  req->set_store_errors_in_response_body(true);

  // Build an index from fetch tensor name to first index in
  // output_tensor_names.
  std::unordered_map<string, int> output_name_to_offset;
  for (int i = 0; i < output_tensor_names.size(); ++i) {
    const string& name = output_tensor_names[i];
    if (output_name_to_offset.insert(std::make_pair(name, i)).second) {
      req->add_fetch(name);
    }
  }
  for (const string& target : target_node_names) {
    req->add_target(target);
  }

  CallOptions call_options;
  call_options.SetTimeout(req->options().timeout_in_ms());
  TF_RETURN_IF_ERROR(RunProto(&call_options, req.get(), resp.get()));

  // Look for an extended error returned in the response body.
  if (resp->status_code() != error::Code::OK) {
    return Status(resp->status_code(), resp->status_error_message());
  }

  if (!output_tensor_names.empty()) {
    outputs->resize(output_tensor_names.size());
  }

  // Convert response back to Tensors in the correct order.
  for (size_t i = 0; i < resp->num_tensors(); ++i) {
    auto fetch_it = output_name_to_offset.find(resp->tensor_name(i));
    if (fetch_it == output_name_to_offset.end()) {
      return errors::Internal("Received response for unrequested fetch: ",
                              resp->tensor_name(i));
    }

    Tensor output;
    TF_RETURN_IF_ERROR(resp->TensorValue(i, &output));
    (*outputs)[fetch_it->second] = output;
  }
  // In the unlikely event that output_tensor_names contains duplicates, fill in
  // the duplicate values.
  if (output_name_to_offset.size() != output_tensor_names.size()) {
    for (int i = 0; i < output_tensor_names.size(); ++i) {
      const string& name = output_tensor_names[i];
      int offset = output_name_to_offset[name];
      if (offset != i) {
        (*outputs)[i] = (*outputs)[offset];
      }
    }
  }

  if (run_metadata) {
    run_metadata->Swap(resp->mutable_metadata());
  }

  return Status::OK();
}

Status GrpcSession::Run(const RunOptions& run_options,
                        const std::vector<std::pair<string, Tensor>>& inputs,
                        const std::vector<string>& output_tensor_names,
                        const std::vector<string>& target_node_names,
                        std::vector<Tensor>* outputs,
                        RunMetadata* run_metadata) {
  return RunHelper(run_options, inputs, output_tensor_names, target_node_names,
                   outputs, run_metadata, /* prun_handle */ "");
}

Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
                        const std::vector<string>& output_tensor_names,
                        const std::vector<string>& target_node_names,
                        std::vector<Tensor>* outputs) {
  RunOptions run_options;
  run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms());
  return Run(run_options, inputs, output_tensor_names, target_node_names,
             outputs, nullptr);
}

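// Attaches the session handle to 'req' and issues the RunStep call to the
// master.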
Status GrpcSession::RunProto(CallOptions* call_options,
                             MutableRunStepRequestWrapper* req,
                             MutableRunStepResponseWrapper* resp) {
  string handle;
  TF_RETURN_IF_ERROR(Handle(&handle));
  req->set_session_handle(handle);
  return master_->RunStep(call_options, req, resp);
}

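// Sets up a partial run on the master with the given feeds, fetches and
// targets, and returns the partial run handle in '*handle'.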
Status GrpcSession::PRunSetup(const std::vector<string>& input_names,
                              const std::vector<string>& output_names,
                              const std::vector<string>& target_nodes,
                              string* handle) {
  // Convert to proto
  PartialRunSetupRequest req;
  PartialRunSetupResponse resp;
  CallOptions call_options;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  for (const string& feed : input_names) {
    req.add_feed(feed);
  }
  for (const string& fetch : output_names) {
    req.add_fetch(fetch);
  }
  for (const string& target : target_nodes) {
    req.add_target(target);
  }
  req.set_request_id(GetUniqueRequestId());
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  TF_RETURN_IF_ERROR(master_->PartialRunSetup(&call_options, &req, &resp));
  *handle = resp.partial_run_handle();
  return Status::OK();
}

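// Continues a partial run previously set up with PRunSetup(); 'handle' is the
// partial run handle returned by that call.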
Status GrpcSession::PRun(const string& handle,
                         const std::vector<std::pair<string, Tensor>>& inputs,
                         const std::vector<string>& output_names,
                         std::vector<Tensor>* outputs) {
  RunOptions run_options;
  run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms());
  return RunHelper(run_options, inputs, output_names, /* targets */ {}, outputs,
                   /* run_metadata */ nullptr, handle);
}

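// Closes the session on the master and clears the local handle. This is a
// no-op if the session was never created.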
Status GrpcSession::Close() {
  CloseSessionRequest req;
  {
    mutex_lock l(mu_);
    if (handle_.empty()) {
      return Status::OK();
    }
    req.set_session_handle(handle_);
    handle_.clear();
  }
  CloseSessionResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return master_->CloseSession(&call_options, &req, &resp);
}

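// Lists the local and remote devices known to the master. If the session has
// not been created yet, it is first created with an empty graph.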
Status GrpcSession::ListDevices(std::vector<DeviceAttributes>* response) {
  ListDevicesRequest req;
  {
    mutex_lock l(mu_);
    req.set_session_handle(handle_);
  }
  if (req.session_handle().empty()) {
    LOG(WARNING) << "GrpcSession::ListDevices will initialize the session with "
                    "an empty graph and other defaults because the session has "
                    "not yet been created.";
    GraphDef graph_def;
    TF_RETURN_IF_ERROR(Create(graph_def));
    {
      mutex_lock l(mu_);
      req.set_session_handle(handle_);
    }
  }
  ListDevicesResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  Status s = master_->ListDevices(&call_options, &req, &resp);
  if (!s.ok()) {
    LOG(ERROR) << "Could not list devices: " << s;
    return s;
  }

  response->clear();
  response->reserve(resp.local_device_size() + resp.remote_device_size());
  for (const auto& device_attr : resp.local_device()) {
    response->emplace_back(device_attr);
  }
  for (const auto& device_attr : resp.remote_device()) {
    response->emplace_back(device_attr);
  }
  return Status::OK();
}

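// Installs the master interface used by this session: either an in-process
// LocalMaster or a gRPC stub to a remote master.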
void GrpcSession::SetRemoteMaster(std::unique_ptr<MasterInterface> master) {
  master_ = std::move(master);
}

// Static method.
Status GrpcSession::Reset(const SessionOptions& options,
                          const std::vector<string>& containers) {
  SharedGrpcChannelPtr master_channel;
  TF_RETURN_IF_ERROR(
      NewHostPortGrpcChannel(options.target.substr(kSchemePrefixLength),
                             /*rpc_options=*/nullptr, &master_channel));
  auto master = NewGrpcMaster(master_channel);
  ResetRequest req;
  for (const auto& c : containers) req.add_container(c);
  ResetResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options.config.operation_timeout_in_ms());
  Status ret = master->Reset(&call_options, &req, &resp);
  delete master;
  return ret;
}

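// Creates a callable on the master from 'callable_options' and returns its
// handle in '*out_handle'.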
Status GrpcSession::MakeCallable(const CallableOptions& callable_options,
                                 CallableHandle* out_handle) {
  MakeCallableRequest req;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  *req.mutable_options() = callable_options;
  req.set_request_id(GetUniqueRequestId());
  MakeCallableResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  TF_RETURN_IF_ERROR(master_->MakeCallable(&call_options, &req, &resp));
  *out_handle = resp.handle();
  return Status::OK();
}

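// Runs the callable identified by 'handle' with 'feed_tensors' and converts
// the fetched TensorProtos in the response back into Tensors.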
Status GrpcSession::RunCallable(CallableHandle handle,
                                const std::vector<Tensor>& feed_tensors,
                                std::vector<Tensor>* fetch_tensors,
                                RunMetadata* run_metadata) {
  RunCallableRequest req;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  req.set_handle(handle);
  req.set_request_id(GetUniqueRequestId());
  for (const Tensor& feed : feed_tensors) {
    feed.AsProtoTensorContent(req.mutable_feed()->Add());
  }

  RunCallableResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  TF_RETURN_IF_ERROR(master_->RunCallable(&call_options, &req, &resp));
  for (const TensorProto& fetch : resp.fetch()) {
    Tensor fetch_tensor;
    if (!fetch_tensor.FromProto(cpu_allocator(), fetch)) {
      return errors::Internal(
          "Could not parse fetched tensor data in response from master.");
    }
    fetch_tensors->push_back(std::move(fetch_tensor));
  }
  return Status::OK();
}

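// Releases the callable identified by 'handle' on the master.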
Status GrpcSession::ReleaseCallable(CallableHandle handle) {
  ReleaseCallableRequest req;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  req.set_handle(handle);
  ReleaseCallableResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return master_->ReleaseCallable(&call_options, &req, &resp);
}

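// Session factory that creates GrpcSession objects for targets that use the
// "grpc://" scheme.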
class GrpcSessionFactory : public SessionFactory {
 public:
  bool AcceptsOptions(const SessionOptions& options) override {
    return str_util::StartsWith(options.target, kSchemePrefix);
  }

  Status NewSession(const SessionOptions& options,
                    Session** out_session) override {
    std::unique_ptr<GrpcSession> session;
    TF_RETURN_IF_ERROR(GrpcSession::Create(options, &session));
    *out_session = session.release();
    return Status::OK();
  }

  // Invokes the session-specific static method to reset containers.
  Status Reset(const SessionOptions& options,
               const std::vector<string>& containers) override {
    return GrpcSession::Reset(options, containers);
  }
};

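// Registers GrpcSessionFactory under the name "GRPC_SESSION" at static
// initialization time.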
class GrpcSessionRegistrar {
 public:
  GrpcSessionRegistrar() {
    SessionFactory::Register("GRPC_SESSION", new GrpcSessionFactory());
  }
};
static GrpcSessionRegistrar registrar;

}  // namespace tensorflow