/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/client/session_ref.h"

#include <stdlib.h>
#include <memory>
#include <unordered_map>
#include <utility>

#include "absl/memory/memory.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/named_tensor.pb.h"
#include "tensorflow/core/protobuf/replay_log.pb.h"

namespace tensorflow {

namespace {

// Scope helper to track active calls and manage session lifetime.
// SessionRef blocks closing until all active calls complete or are cancelled.
struct RunCounter {
  std::shared_ptr<Session> session;
  uint64* value;
  mutex* m;
  condition_variable* cv;

  explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m,
                      condition_variable* cv)
      : session(std::move(s)), value(v), m(m), cv(cv) {
    mutex_lock l(*m);
    ++*value;
  }

  ~RunCounter() {
    mutex_lock l(*m);
    if (--*value == 0) {
      cv->notify_all();
    }
  }
};

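// Produces a stable textual handle for a session by formatting its pointer
// value; this handle links replayed requests back to the session that issued
// them.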
std::string SessionToHandle(Session* session) {
  return strings::Printf("%llu", reinterpret_cast<uint64>(session));
}

// The Session interface has many methods of the form:
//
// X(a, b);
// X(RunOptions, a, b);
//
// Not all sessions support the second case (with an empty RunOptions()).
// We use this variable as a sentinel to dispatch to the correct call.
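// Callers compare the RunOptions argument by address against this sentinel
// (e.g. `&run_options == kEmptyRunOptions()`), so it must be a single,
// never-destroyed instance.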
RunOptions* kEmptyRunOptions() {
  static RunOptions* options = new RunOptions();
  return options;
}

}  // namespace

// Run the given session operation, recording start and end timestamps.
// If the operation returns a bad status, return after flushing the current
// log request.  This should be run _after_ all request information has been
// added to the current op.
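// The macro declares a local `Status status` and expects both `op` (the
// current ReplayOp) and `session` to be in scope at the expansion site.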
#define RUN_WITH_TIMESTAMP(OpName, ...)              \
  op.set_start_time_us(Env::Default()->NowMicros()); \
  Status status = session->OpName(__VA_ARGS__);      \
  op.set_end_time_us(Env::Default()->NowMicros());   \
  if (!status.ok()) {                                \
    Flush(op).IgnoreError();                         \
    return status;                                   \
  }

// Records requests (and optionally responses) performed against a session.
// The resulting replay log can be used with the `tf_replay` tool to replicate
// the operations against a simulated environment, without requiring the
// original code or cluster setup.
//
// Session logging is enabled by setting the TF_REPLAY_LOG_FILE environment
// variable.
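//
// Example (illustrative; the script name and log path are placeholders):
// capture a replay log for a Python program that creates sessions through
// this client by exporting the variable before launch:
//
//   TF_REPLAY_LOG_FILE=/tmp/tf_replay.log python train.py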
class SessionLogger {
 public:
  SessionLogger() {
    std::string log_name = getenv("TF_REPLAY_LOG_FILE");
    LOG(INFO) << "Constructing new session logger for " << log_name;
    TF_CHECK_OK(
        Env::Default()->RecursivelyCreateDir(string(io::Dirname(log_name))));
    Env::Default()->DeleteFile(log_name).IgnoreError();

    TF_CHECK_OK(Env::Default()->NewWritableFile(log_name, &log_file_));
    log_writer_ = absl::make_unique<io::RecordWriter>(log_file_.get());
  }

  ~SessionLogger() {
    log_writer_->Close().IgnoreError();
    log_writer_.release();
    log_file_->Close().IgnoreError();
  }

  Status RecordNewSession(Session* session) {
    ReplayOp op;
    NewReplaySession* req = op.mutable_new_replay_session();
    req->set_session_handle(SessionToHandle(session));
    return Flush(op);
  }

  Status RecordRun(Session* session,
                   const std::vector<std::pair<string, Tensor> >& inputs,
                   const std::vector<string>& output_tensor_names,
                   const std::vector<string>& target_node_names,
                   std::vector<Tensor>* outputs) {
    return RecordRun(session, *kEmptyRunOptions(), inputs, output_tensor_names,
                     target_node_names, outputs, nullptr);
  }

  Status RecordRun(Session* session, const RunOptions& run_options,
                   const std::vector<std::pair<string, Tensor> >& inputs,
                   const std::vector<string>& output_tensor_names,
                   const std::vector<string>& target_node_names,
                   std::vector<Tensor>* outputs, RunMetadata* run_metadata) {
    ReplayOp op;
    RunStepRequest* req = op.mutable_run_step();
    RunStepResponse* resp = op.mutable_run_step_response();

    req->set_session_handle(SessionToHandle(session));
    *req->mutable_options() = run_options;

    for (const auto& it : inputs) {
      NamedTensorProto* feed = req->add_feed();
      feed->set_name(it.first);
      it.second.AsProtoField(feed->mutable_tensor());
    }

    // Build an index from fetch tensor name to first index in
    // output_tensor_names.
    std::unordered_map<string, int> output_name_to_offset;
    for (int i = 0; i < output_tensor_names.size(); ++i) {
      const string& name = output_tensor_names[i];
      if (output_name_to_offset.insert(std::make_pair(name, i)).second) {
        req->add_fetch(name);
      }
    }
    for (const string& target : target_node_names) {
      req->add_target(target);
    }

    if (&run_options == kEmptyRunOptions()) {
      RUN_WITH_TIMESTAMP(Run, inputs, output_tensor_names, target_node_names,
                         outputs);
    } else {
      RUN_WITH_TIMESTAMP(Run, run_options, inputs, output_tensor_names,
                         target_node_names, outputs, run_metadata);
    }

    for (size_t i = 0; i < outputs->size(); ++i) {
      const Tensor& tensor = (*outputs)[i];
      NamedTensorProto* tproto = resp->add_tensor();
      tensor.AsProtoField(tproto->mutable_tensor());
      tproto->set_name(output_tensor_names[i]);
    }

    if (run_metadata) {
      *resp->mutable_metadata() = *run_metadata;
    }

    return Flush(op);
  }

  Status RecordCreate(Session* session, const GraphDef& graph) {
    return RecordCreate(session, *kEmptyRunOptions(), graph);
  }

  // N.B. RunOptions is not stored (it has no entry in CreateRequest)
  Status RecordCreate(Session* session, const RunOptions& run_options,
                      const GraphDef& graph) {
    ReplayOp op;
    CreateSessionRequest* req = op.mutable_create_session();
    *req->mutable_graph_def() = graph;

    CreateSessionResponse* resp = op.mutable_create_session_response();
    if (&run_options == kEmptyRunOptions()) {
      RUN_WITH_TIMESTAMP(Create, graph);
    } else {
      RUN_WITH_TIMESTAMP(Create, run_options, graph);
    }
    resp->set_session_handle(SessionToHandle(session));
    return Flush(op);
  }

  Status RecordExtend(Session* session, const GraphDef& graph) {
    return RecordExtend(session, *kEmptyRunOptions(), graph);
  }

  // N.B. RunOptions is not stored (it has no entry in ExtendRequest)
  Status RecordExtend(Session* session, const RunOptions& run_options,
                      const GraphDef& graph) {
    ReplayOp op;
    ExtendSessionRequest* req = op.mutable_extend_session();
    op.mutable_extend_session_response();
    req->set_session_handle(SessionToHandle(session));
    *req->mutable_graph_def() = graph;
    if (&run_options == kEmptyRunOptions()) {
      RUN_WITH_TIMESTAMP(Extend, graph);
    } else {
      RUN_WITH_TIMESTAMP(Extend, run_options, graph);
    }

    return Flush(op);
  }

  Status RecordClose(Session* session) {
    return RecordClose(session, *kEmptyRunOptions());
  }

  // N.B. RunOptions is not stored (it has no entry in CloseRequest)
  Status RecordClose(Session* session, const RunOptions& run_options) {
    ReplayOp op;
    CloseSessionRequest* req = op.mutable_close_session();
    req->set_session_handle(SessionToHandle(session));
    op.mutable_close_session_response();
    if (&run_options == kEmptyRunOptions()) {
      RUN_WITH_TIMESTAMP(Close);
    } else {
      RUN_WITH_TIMESTAMP(Close, run_options);
    }
    return Flush(op);
  }

  Status RecordListDevices(Session* session,
                           std::vector<DeviceAttributes>* response) {
    ReplayOp op;
    ListDevicesRequest* req = op.mutable_list_devices();
    ListDevicesResponse* resp = op.mutable_list_devices_response();
    req->set_session_handle(SessionToHandle(session));
    RUN_WITH_TIMESTAMP(ListDevices, response);

    // TODO(power) -- local vs remote device distinction is lost here!
    *resp->mutable_local_device() = {response->begin(), response->end()};
    return Flush(op);
  }

  Status RecordPRunSetup(Session* session,
                         const std::vector<string>& input_names,
                         const std::vector<string>& output_names,
                         const std::vector<string>& target_nodes,
                         string* handle) {
    ReplayOp op;
    PartialRunSetupRequest* req = op.mutable_partial_run_setup();
    req->set_session_handle(SessionToHandle(session));
    for (auto& input : input_names) {
      req->add_feed(input);
    }
    for (auto& output : output_names) {
      req->add_fetch(output);
    }
    for (auto& target : target_nodes) {
      req->add_target(target);
    }
    RUN_WITH_TIMESTAMP(PRunSetup, input_names, output_names, target_nodes,
                       handle);
    op.mutable_partial_run_setup_response()->set_partial_run_handle(*handle);
    return Flush(op);
  }

  Status RecordPRun(Session* session, const string& handle,
                    const std::vector<std::pair<string, Tensor> >& inputs,
                    const std::vector<string>& output_names,
                    std::vector<Tensor>* outputs) {
    ReplayOp op;
    RunStepRequest* req = op.mutable_run_step();
    RunStepResponse* resp = op.mutable_run_step_response();
    req->set_session_handle(SessionToHandle(session));

    // Mark this step as a partial run for replay.
    req->set_partial_run_handle(handle);
    for (auto& input : inputs) {
      auto* feed = req->add_feed();
      feed->set_name(input.first);
      input.second.AsProtoField(feed->mutable_tensor());
    }

    for (auto& output : output_names) {
      req->add_fetch(output);
    }

    RUN_WITH_TIMESTAMP(PRun, handle, inputs, output_names, outputs);

    for (size_t i = 0; i < outputs->size(); ++i) {
      const Tensor& tensor = (*outputs)[i];
      NamedTensorProto* tproto = resp->add_tensor();
      tensor.AsProtoField(tproto->mutable_tensor());
      tproto->set_name(output_names[i]);
    }

    return Flush(op);
  }

  Status RecordMakeCallable(Session* session,
                            const CallableOptions& callable_options,
                            Session::CallableHandle* handle) {
    ReplayOp op;
    MakeCallableRequest* req = op.mutable_make_callable();
    req->set_session_handle(SessionToHandle(session));
    *req->mutable_options() = callable_options;

    RUN_WITH_TIMESTAMP(MakeCallable, callable_options, handle);

    MakeCallableResponse* resp = op.mutable_make_callable_response();
    resp->set_handle(*handle);

    return Flush(op);
  }

  Status RecordRunCallable(Session* session, Session::CallableHandle handle,
                           const std::vector<Tensor>& feed_tensors,
                           std::vector<Tensor>* fetch_tensors,
                           RunMetadata* run_metadata) {
    ReplayOp op;
    RunCallableRequest* req = op.mutable_run_callable();
    req->set_session_handle(SessionToHandle(session));
    req->set_handle(handle);
    for (auto& tensor : feed_tensors) {
      tensor.AsProtoField(req->add_feed());
    }
    RUN_WITH_TIMESTAMP(RunCallable, handle, feed_tensors, fetch_tensors,
                       run_metadata);

    RunCallableResponse* resp = op.mutable_run_callable_response();
    if (run_metadata) {
      *resp->mutable_metadata() = *run_metadata;
    }
    for (const Tensor& tensor : *fetch_tensors) {
      tensor.AsProtoTensorContent(resp->add_fetch());
    }
    return Flush(op);
  }

  Status RecordReleaseCallable(Session* session,
                               Session::CallableHandle handle) {
    ReplayOp op;
    ReleaseCallableRequest* req = op.mutable_release_callable();
    req->set_session_handle(SessionToHandle(session));
    req->set_handle(handle);
    RUN_WITH_TIMESTAMP(ReleaseCallable, handle);
    return Flush(op);
  }

 private:
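  // Serializes `op` and appends it to the replay log. Writes are guarded by
  // log_mutex_ so records from concurrent session calls do not interleave.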
  Status Flush(const ReplayOp& op) {
    mutex_lock l(log_mutex_);

    string buf;
    op.SerializeToString(&buf);
    TF_RETURN_IF_ERROR(log_writer_->WriteRecord(buf));

    // TODO(b/116624106): Not all file-systems respect calls to `Sync()`
    return log_file_->Sync();
  }

  std::unique_ptr<WritableFile> log_file_;
  std::unique_ptr<io::RecordWriter> log_writer_;
  mutex log_mutex_;
};

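// Process-wide logger shared by every SessionRef. It is constructed on first
// use and never destroyed, so the log stays open for the lifetime of the
// process.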
static SessionLogger* global_session_logger() {
  static SessionLogger* logger = new SessionLogger();
  return logger;
}

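// Replay logging is enabled when TF_REPLAY_LOG_FILE is set in the
// environment; every SessionRef constructed while it is set records through
// the shared logger above.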
SessionRef::SessionRef(Session* session) : session_(session) {
  if (getenv("TF_REPLAY_LOG_FILE") != nullptr) {
    logger_ = global_session_logger();
    logger_->RecordNewSession(this->session_.get()).IgnoreError();
  } else {
    logger_ = nullptr;
  }
}

SessionRef::~SessionRef() = default;

Status SessionRef::CheckNotClosed() {
  mutex_lock l(run_lock_);
  if (session_ == nullptr) return errors::Cancelled("Session has been closed.");
  return ::tensorflow::Status::OK();
}

// If logging is active, log the start and end time of the operation along with
// the request and response.
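// The RunCounter holds a shared_ptr to the session for the duration of the
// call, so Close() blocks until all in-flight operations have finished.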
#define LOG_AND_RUN_OPERATION(OpName, ...)                          \
  TF_RETURN_IF_ERROR(CheckNotClosed());                             \
  RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); \
  if (!logger_) {                                                   \
    return rc.session->OpName(__VA_ARGS__);                         \
  }                                                                 \
  return logger_->Record##OpName(rc.session.get(), __VA_ARGS__);

Status SessionRef::Run(const RunOptions& run_options,
                       const std::vector<std::pair<string, Tensor> >& inputs,
                       const std::vector<string>& output_tensor_names,
                       const std::vector<string>& target_node_names,
                       std::vector<Tensor>* outputs,
                       RunMetadata* run_metadata) {
  LOG_AND_RUN_OPERATION(Run, run_options, inputs, output_tensor_names,
                        target_node_names, outputs, run_metadata);
}

Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs,
                       const std::vector<string>& output_tensor_names,
                       const std::vector<string>& target_node_names,
                       std::vector<Tensor>* outputs) {
  LOG_AND_RUN_OPERATION(Run, inputs, output_tensor_names, target_node_names,
                        outputs);
}

Status SessionRef::Create(const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Create, graph);
}

Status SessionRef::Create(const RunOptions& run_options,
                          const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Create, run_options, graph);
}

Status SessionRef::Extend(const RunOptions& run_options,
                          const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Extend, run_options, graph);
}

Status SessionRef::Extend(const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Extend, graph);
}

Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) {
  LOG_AND_RUN_OPERATION(ListDevices, response);
}

Status SessionRef::PRunSetup(const std::vector<string>& input_names,
                             const std::vector<string>& output_names,
                             const std::vector<string>& target_nodes,
                             string* handle) {
  LOG_AND_RUN_OPERATION(PRunSetup, input_names, output_names, target_nodes,
                        handle);
}

Status SessionRef::PRun(const string& handle,
                        const std::vector<std::pair<string, Tensor> >& inputs,
                        const std::vector<string>& output_names,
                        std::vector<Tensor>* outputs) {
  LOG_AND_RUN_OPERATION(PRun, handle, inputs, output_names, outputs);
}

Status SessionRef::MakeCallable(const CallableOptions& callable_options,
                                CallableHandle* out_handle) {
  LOG_AND_RUN_OPERATION(MakeCallable, callable_options, out_handle);
}

Status SessionRef::RunCallable(CallableHandle handle,
                               const std::vector<Tensor>& feed_tensors,
                               std::vector<Tensor>* fetch_tensors,
                               RunMetadata* run_metadata) {
  LOG_AND_RUN_OPERATION(RunCallable, handle, feed_tensors, fetch_tensors,
                        run_metadata);
}

Status SessionRef::ReleaseCallable(CallableHandle handle) {
  LOG_AND_RUN_OPERATION(ReleaseCallable, handle);
}

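// Close() records or forwards the close, drops the owned session reference,
// and then waits for any outstanding calls (tracked by run_count_) to finish
// before returning.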
Status SessionRef::Close(const RunOptions& run_options) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  mutex_lock l(run_lock_);
  Status status;
  if (logger_) {
    status = logger_->RecordClose(session_.get(), run_options);
  } else {
    status = session_->Close(run_options);
  }
  session_.reset();
  while (run_count_ > 0) {
    run_finished_.wait(l);
  }
  return status;
}

Status SessionRef::Close() {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  mutex_lock l(run_lock_);
  Status status;
  if (logger_) {
    status = logger_->RecordClose(session_.get());
  } else {
    status = session_->Close();
  }
  session_.reset();
  while (run_count_ > 0) {
    run_finished_.wait(l);
  }
  return status;
}

}  // namespace tensorflow