/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"

#include <unordered_map>

#include "tensorflow/core/common_runtime/session_factory.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/local_master.h"
#include "tensorflow/core/distributed_runtime/master_interface.h"
#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/master.pb.h"

namespace tensorflow {

const char* const kSchemePrefix = "grpc://";
const size_t kSchemePrefixLength = strlen(kSchemePrefix);
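
// Illustrative client-side usage (a minimal sketch, not part of this file;
// it assumes a TensorFlow server is listening at the given target, and
// `graph_def` is a hypothetical GraphDef):
//
//   SessionOptions options;
//   options.target = "grpc://localhost:2222";
//   std::unique_ptr<GrpcSession> session;
//   TF_CHECK_OK(GrpcSession::Create(options, &session));
//   TF_CHECK_OK(session->Create(graph_def));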
38

GrpcSession::GrpcSession(const SessionOptions& options)
    : options_(options), current_graph_version_(-1) {}

GrpcSession::~GrpcSession() {}

/* static */
Status GrpcSession::Create(const SessionOptions& options,
                           std::unique_ptr<GrpcSession>* out_session) {
  std::unique_ptr<GrpcSession> session(new GrpcSession(options));
  std::unique_ptr<MasterInterface> master;
  // For testing, we enable the client to disable the use of the local
  // master registry, so that the RPC stack is exercised.
  if (!options.config.rpc_options().use_rpc_for_inprocess_master()) {
    master = LocalMaster::Lookup(options.target);
  }
  if (!master) {
    SharedGrpcChannelPtr master_channel;
    TF_RETURN_IF_ERROR(
        NewHostPortGrpcChannel(options.target.substr(kSchemePrefixLength),
                               &options.config.rpc_options(), &master_channel));
    master.reset(NewGrpcMaster(master_channel));
  }
  session->SetRemoteMaster(std::move(master));
  *out_session = std::move(session);
  return Status::OK();
}

namespace {
// Re-encodes constants represented as repeated fields in a tensor proto into
// tensor_content, which is slightly better (fewer copies and lower peak
// memory usage) when used with RPC subsystems.
void ReEncodeConsts(GraphDef* gdef) {
  for (NodeDef& ndef : *(gdef->mutable_node())) {
    if (ndef.op() == "Const") {
      TensorProto* proto = nullptr;
      for (auto& attr : *ndef.mutable_attr()) {
        if (attr.first == "value") {
          proto = attr.second.mutable_tensor();
        }
      }
      if (proto != nullptr && proto->tensor_content().empty() &&
          proto->ByteSizeLong() > 64) {
        // If the constant is encoded with repeated proto fields and
        // it is moderately large, we re-encode it in tensor_content as
        // a Cord. This is mildly helpful for reducing the peak memory
        // usage on the server side where GraphDef/NodeDef are copied
        // quite often.
        Tensor parsed(proto->dtype());
        if (parsed.FromProto(*proto)) {
          parsed.AsProtoTensorContent(proto);
        }
      }
    }
  }
}
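
// For example (a sketch; field names follow tensor.proto), a Const whose
// value was serialized with repeated fields, such as
//   dtype: DT_FLOAT  tensor_shape { dim { size: 1024 } }  float_val: 1 ...
// is rewritten to carry the same data in a single packed bytes field:
//   dtype: DT_FLOAT  tensor_shape { dim { size: 1024 } }  tensor_content: "..."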
}  // namespace

void GrpcSession::SetHandleAndGraphVersion(string handle,
                                           int64 graph_version) {
  mutex_lock l(mu_);
  handle_ = std::move(handle);
  current_graph_version_ = graph_version;
}

Status GrpcSession::Handle(string* out_handle) {
  mutex_lock l(mu_);
  if (handle_.empty()) {
    return errors::InvalidArgument("A session has not been created yet.");
  }
  *out_handle = handle_;
  return Status::OK();
}

Status GrpcSession::CreateImpl(CallOptions* call_options,
                               const GraphDef& graph) {
  {
    mutex_lock l(mu_);
    if (!handle_.empty()) {
      return errors::InvalidArgument("A session has already been created.");
    }
  }
  CreateSessionRequest req;
  *req.mutable_config() = options_.config;
  *req.mutable_graph_def() = graph;
  req.set_target(options_.target);
  ReEncodeConsts(req.mutable_graph_def());
  CreateSessionResponse resp;
  Status s = master_->CreateSession(call_options, &req, &resp);
  if (s.ok()) {
    SetHandleAndGraphVersion(resp.session_handle(), resp.graph_version());
  }
  return s;
}

Status GrpcSession::Create(const GraphDef& graph) {
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return CreateImpl(&call_options, graph);
}

Status GrpcSession::Create(const RunOptions& run_options,
                           const GraphDef& graph) {
  CallOptions call_options;
  call_options.SetTimeout(run_options.timeout_in_ms());
  return CreateImpl(&call_options, graph);
}
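
// Extend() adds nodes to the graph of an already-created session; if no
// session exists yet, ExtendImpl() simply falls back to Create().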
Status GrpcSession::ExtendImpl(CallOptions* call_options,
                               const GraphDef& graph) {
  bool handle_is_empty;
  {
    mutex_lock l(mu_);
    handle_is_empty = handle_.empty();
  }
  if (handle_is_empty) {
    // Session was uninitialized, so simply initialize the session with
    // 'graph'.
    return Create(graph);
  }
  mutex_lock l(mu_);
  ExtendSessionRequest req;
  req.set_session_handle(handle_);
  *req.mutable_graph_def() = graph;
  req.set_current_graph_version(current_graph_version_);
  ExtendSessionResponse resp;
  Status s = master_->ExtendSession(call_options, &req, &resp);
  if (s.ok()) {
    current_graph_version_ = resp.new_graph_version();
  }
  return s;
}

Status GrpcSession::Extend(const GraphDef& graph) {
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return ExtendImpl(&call_options, graph);
}

Status GrpcSession::Extend(const RunOptions& run_options,
                           const GraphDef& graph) {
  CallOptions call_options;
  call_options.SetTimeout(run_options.timeout_in_ms());
  return ExtendImpl(&call_options, graph);
}
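
// Shared implementation behind Run() and PRun(): converts the feeds,
// fetches, and targets into a RunStepRequest, issues the RPC, and unpacks
// the returned tensors in the order the caller requested them. An empty
// `prun_handle` denotes a regular (non-partial) run.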
Status GrpcSession::RunHelper(
    const RunOptions& run_options,
    const std::vector<std::pair<string, Tensor>>& inputs,
    const std::vector<string>& output_tensor_names,
    const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
    RunMetadata* run_metadata, const string& prun_handle) {
  // Convert to proto
  std::unique_ptr<MutableRunStepRequestWrapper> req(
      master_->CreateRunStepRequest());
  std::unique_ptr<MutableRunStepResponseWrapper> resp(
      master_->CreateRunStepResponse());

  *req->mutable_options() = run_options;

  if (run_options.timeout_in_ms() == 0) {
    req->mutable_options()->set_timeout_in_ms(
        options_.config.operation_timeout_in_ms());
  }

  if (!prun_handle.empty()) {
    req->set_partial_run_handle(prun_handle);
  }

  for (const auto& it : inputs) {
    req->add_feed(it.first, it.second);
  }

  // Support long error messages by storing the error code in the response
  // body.
  req->set_store_errors_in_response_body(true);

  // Build an index from fetch tensor name to first index in
  // output_tensor_names.
  std::unordered_map<string, int> output_name_to_offset;
  for (int i = 0; i < output_tensor_names.size(); ++i) {
    const string& name = output_tensor_names[i];
    if (output_name_to_offset.insert(std::make_pair(name, i)).second) {
      req->add_fetch(name);
    }
  }
  for (const string& target : target_node_names) {
    req->add_target(target);
  }

  CallOptions call_options;
  call_options.SetTimeout(req->options().timeout_in_ms());
  TF_RETURN_IF_ERROR(RunProto(&call_options, req.get(), resp.get()));

  // Look for an extended error returned in the response body.
  if (resp->status_code() != error::Code::OK) {
    return Status(resp->status_code(), resp->status_error_message());
  }

  if (!output_tensor_names.empty()) {
    outputs->resize(output_tensor_names.size());
  }

  // Convert response back to Tensors in the correct order.
  for (size_t i = 0; i < resp->num_tensors(); ++i) {
    auto fetch_it = output_name_to_offset.find(resp->tensor_name(i));
    if (fetch_it == output_name_to_offset.end()) {
      return errors::Internal("Received response for unrequested fetch: ",
                              resp->tensor_name(i));
    }

    Tensor output;
    TF_RETURN_IF_ERROR(resp->TensorValue(i, &output));
    (*outputs)[fetch_it->second] = output;
  }
  // In the unlikely event that output_tensor_names contains duplicates, fill
  // in the duplicate values.
  if (output_name_to_offset.size() != output_tensor_names.size()) {
    for (int i = 0; i < output_tensor_names.size(); ++i) {
      const string& name = output_tensor_names[i];
      int offset = output_name_to_offset[name];
      if (offset != i) {
        (*outputs)[i] = (*outputs)[offset];
      }
    }
  }

  if (run_metadata) {
    run_metadata->Swap(resp->mutable_metadata());
  }

  return Status::OK();
}

Status GrpcSession::Run(const RunOptions& run_options,
                        const std::vector<std::pair<string, Tensor>>& inputs,
                        const std::vector<string>& output_tensor_names,
                        const std::vector<string>& target_node_names,
                        std::vector<Tensor>* outputs,
                        RunMetadata* run_metadata) {
  return RunHelper(run_options, inputs, output_tensor_names, target_node_names,
                   outputs, run_metadata, /* prun_handle */ "");
}

Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
                        const std::vector<string>& output_tensor_names,
                        const std::vector<string>& target_node_names,
                        std::vector<Tensor>* outputs) {
  RunOptions run_options;
  run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms());
  return Run(run_options, inputs, output_tensor_names, target_node_names,
             outputs, nullptr);
}
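
// Sends a single RunStep RPC to the master, attaching the session handle
// that was established by Create().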
Status GrpcSession::RunProto(CallOptions* call_options,
                             MutableRunStepRequestWrapper* req,
                             MutableRunStepResponseWrapper* resp) {
  string handle;
  TF_RETURN_IF_ERROR(Handle(&handle));
  req->set_session_handle(handle);
  return master_->RunStep(call_options, req, resp);
}
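
// Registers the feeds, fetches, and targets for a subsequent sequence of
// PRun() calls, and returns the partial-run handle in *handle.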
Status GrpcSession::PRunSetup(const std::vector<string>& input_names,
                              const std::vector<string>& output_names,
                              const std::vector<string>& target_nodes,
                              string* handle) {
  // Convert to proto
  PartialRunSetupRequest req;
  PartialRunSetupResponse resp;
  CallOptions call_options;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  for (const string& feed : input_names) {
    req.add_feed(feed);
  }
  for (const string& fetch : output_names) {
    req.add_fetch(fetch);
  }
  for (const string& target : target_nodes) {
    req.add_target(target);
  }
  req.set_request_id(GetUniqueRequestId());
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  TF_RETURN_IF_ERROR(master_->PartialRunSetup(&call_options, &req, &resp));
  *handle = resp.partial_run_handle();
  return Status::OK();
}

Status GrpcSession::PRun(const string& handle,
                         const std::vector<std::pair<string, Tensor>>& inputs,
                         const std::vector<string>& output_names,
                         std::vector<Tensor>* outputs) {
  RunOptions run_options;
  run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms());
  return RunHelper(run_options, inputs, output_names, /* targets */ {},
                   outputs, /* run_metadata */ nullptr, handle);
}
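
// Illustrative partial-run flow (a minimal sketch, not part of this file;
// the tensor names "a:0", "b:0", "c:0" and the Tensor values `a`, `b` are
// hypothetical):
//
//   string prun_handle;
//   TF_CHECK_OK(session->PRunSetup({"a:0", "b:0"}, {"c:0"}, {}, &prun_handle));
//   std::vector<Tensor> outputs;
//   TF_CHECK_OK(session->PRun(prun_handle, {{"a:0", a}, {"b:0", b}}, {"c:0"},
//                             &outputs));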
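
// Closing is idempotent from the client's perspective: if this session no
// longer holds a handle, there is nothing to tear down on the master and
// Close() returns OK.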
Status GrpcSession::Close() {
  CloseSessionRequest req;
  {
    mutex_lock l(mu_);
    if (handle_.empty()) {
      return Status::OK();
    }
    req.set_session_handle(handle_);
    handle_.clear();
  }
  CloseSessionResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return master_->CloseSession(&call_options, &req, &resp);
}
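
// ListDevices() requires a live session on the master; if none has been
// created yet, one is created here with an empty graph (see the warning
// logged below).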
Status GrpcSession::ListDevices(std::vector<DeviceAttributes>* response) {
  ListDevicesRequest req;
  {
    mutex_lock l(mu_);
    req.set_session_handle(handle_);
  }
  if (req.session_handle().empty()) {
    LOG(WARNING) << "GrpcSession::ListDevices will initialize the session with "
                    "an empty graph and other defaults because the session has "
                    "not yet been created.";
    GraphDef graph_def;
    TF_RETURN_IF_ERROR(Create(graph_def));
    {
      mutex_lock l(mu_);
      req.set_session_handle(handle_);
    }
  }
  ListDevicesResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  Status s = master_->ListDevices(&call_options, &req, &resp);
  if (!s.ok()) {
    LOG(ERROR) << "Could not list devices: " << s;
    return s;
  }

  response->clear();
  response->reserve(resp.local_device_size() + resp.remote_device_size());
  for (const auto& device_attr : resp.local_device()) {
    response->emplace_back(device_attr);
  }
  for (const auto& device_attr : resp.remote_device()) {
    response->emplace_back(device_attr);
  }
  return Status::OK();
}

void GrpcSession::SetRemoteMaster(std::unique_ptr<MasterInterface> master) {
  master_ = std::move(master);
}

// Static method. Dials the master directly (no session handle is required)
// and resets the given resource containers.
Status GrpcSession::Reset(const SessionOptions& options,
                          const std::vector<string>& containers) {
  SharedGrpcChannelPtr master_channel;
  TF_RETURN_IF_ERROR(
      NewHostPortGrpcChannel(options.target.substr(kSchemePrefixLength),
                             /*rpc_options=*/nullptr, &master_channel));
  auto master = NewGrpcMaster(master_channel);
  ResetRequest req;
  for (const auto& c : containers) req.add_container(c);
  ResetResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options.config.operation_timeout_in_ms());
  Status ret = master->Reset(&call_options, &req, &resp);
  delete master;
  return ret;
}
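
// MakeCallable(), RunCallable(), and ReleaseCallable() mirror the master's
// callable API; see the usage sketch after ReleaseCallable() below.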
Status GrpcSession::MakeCallable(const CallableOptions& callable_options,
                                 CallableHandle* out_handle) {
  MakeCallableRequest req;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  *req.mutable_options() = callable_options;
  req.set_request_id(GetUniqueRequestId());
  MakeCallableResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  TF_RETURN_IF_ERROR(master_->MakeCallable(&call_options, &req, &resp));
  *out_handle = resp.handle();
  return Status::OK();
}

Status GrpcSession::RunCallable(CallableHandle handle,
                                const std::vector<Tensor>& feed_tensors,
                                std::vector<Tensor>* fetch_tensors,
                                RunMetadata* run_metadata) {
  RunCallableRequest req;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  req.set_handle(handle);
  req.set_request_id(GetUniqueRequestId());
  for (const Tensor& feed : feed_tensors) {
    feed.AsProtoTensorContent(req.mutable_feed()->Add());
  }

  RunCallableResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  TF_RETURN_IF_ERROR(master_->RunCallable(&call_options, &req, &resp));
  for (const TensorProto& fetch : resp.fetch()) {
    Tensor fetch_tensor;
    if (!fetch_tensor.FromProto(cpu_allocator(), fetch)) {
      return errors::Internal(
          "Could not parse fetched tensor data in response from master.");
    }
    fetch_tensors->push_back(std::move(fetch_tensor));
  }
  return Status::OK();
}

Status GrpcSession::ReleaseCallable(CallableHandle handle) {
  ReleaseCallableRequest req;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  req.set_handle(handle);
  ReleaseCallableResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return master_->ReleaseCallable(&call_options, &req, &resp);
}
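
// Illustrative callable lifecycle (a minimal sketch, not part of this file;
// the feed/fetch names and the Tensor `x` are hypothetical):
//
//   CallableOptions opts;
//   opts.add_feed("x:0");
//   opts.add_fetch("y:0");
//   Session::CallableHandle callable;
//   TF_CHECK_OK(session->MakeCallable(opts, &callable));
//   std::vector<Tensor> fetches;
//   TF_CHECK_OK(session->RunCallable(callable, {x}, &fetches, nullptr));
//   TF_CHECK_OK(session->ReleaseCallable(callable));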

class GrpcSessionFactory : public SessionFactory {
 public:
  bool AcceptsOptions(const SessionOptions& options) override {
    return str_util::StartsWith(options.target, kSchemePrefix);
  }

  Status NewSession(const SessionOptions& options,
                    Session** out_session) override {
    std::unique_ptr<GrpcSession> session;
    TF_RETURN_IF_ERROR(GrpcSession::Create(options, &session));
    *out_session = session.release();
    return Status::OK();
  }

  // Invokes the session-specific static method to reset containers.
  Status Reset(const SessionOptions& options,
               const std::vector<string>& containers) override {
    return GrpcSession::Reset(options, containers);
  }
};

class GrpcSessionRegistrar {
 public:
  GrpcSessionRegistrar() {
    SessionFactory::Register("GRPC_SESSION", new GrpcSessionFactory());
  }
};
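
// The static registrar below runs at program initialization, so that
// NewSession() (declared in tensorflow/core/public/session.h) can construct
// a GrpcSession for any target beginning with "grpc://".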
static GrpcSessionRegistrar registrar;

}  // namespace tensorflow