1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/python/client/session_ref.h"
16
17 #include <stdlib.h>
18 #include <memory>
19 #include <utility>
20
21 #include "tensorflow/core/lib/io/path.h"
22 #include "tensorflow/core/lib/io/record_writer.h"
23 #include "tensorflow/core/lib/strings/stringprintf.h"
24 #include "tensorflow/core/protobuf/master.pb.h"
25 #include "tensorflow/core/protobuf/named_tensor.pb.h"
26 #include "tensorflow/core/protobuf/replay_log.pb.h"
27
28 namespace tensorflow {
29
30 namespace {
31
// Scope helper to track active calls and manage session lifetime.
// SessionRef blocks closing until all active calls complete or are cancelled.
//
// The constructor registers one in-flight call by incrementing *value under
// *m; the destructor decrements it and notifies *cv when the count reaches
// zero, waking any thread blocked in SessionRef::Close().
struct RunCounter {
  std::shared_ptr<Session> session;  // shares ownership so the Session
                                     // outlives this call even if Close()
                                     // runs concurrently
  uint64* value;                     // active-call counter, guarded by *m
  mutex* m;                          // guards *value
  condition_variable* cv;            // signalled when *value drops to zero

  // Registers one active call against the session.
  explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m,
                      condition_variable* cv)
      : session(std::move(s)), value(v), m(m), cv(cv) {
    mutex_lock l(*m);
    ++*value;
  }

  // Deregisters the call; wakes all waiters once the last active call
  // finishes so a pending Close() can proceed.
  ~RunCounter() {
    mutex_lock l(*m);
    if (--*value == 0) {
      cv->notify_all();
    }
  }
};
54
SessionToHandle(Session * session)55 std::string SessionToHandle(Session* session) {
56 return strings::Printf("%llu", reinterpret_cast<uint64>(session));
57 }
58
59 // The Session interface has many methods of the form:
60 //
61 // X(a, b);
62 // X(RunOptions, a, b);
63 //
64 // Not all sessions support the second case (with an empty RunOptions()).
65 // We use this variable as a sentinel to dispatch to the correct call.
kEmptyRunOptions()66 RunOptions* kEmptyRunOptions() {
67 static RunOptions* options = new RunOptions();
68 return options;
69 }
70
71 } // namespace
72
// Run the given session operation, recording start and end timestamps.
// If the operation returns a bad status, return after flushing the current
// log request. This should be run _after_ all request information has been
// added to the current op.
//
// NOTE: expands in SessionLogger::Record* methods; it assumes a ReplayOp
// named `op` and a Session* named `session` are in scope, declares a local
// `Status status`, and on error returns from the enclosing function.
#define RUN_WITH_TIMESTAMP(OpName, ...)              \
  op.set_start_time_us(Env::Default()->NowMicros()); \
  Status status = session->OpName(__VA_ARGS__);      \
  op.set_end_time_us(Env::Default()->NowMicros());   \
  if (!status.ok()) {                                \
    Flush(op).IgnoreError();                         \
    return status;                                   \
  }
85
86 // Records requests (and optionally responses) performed against a session.
87 // The resulting replay log can be used with the `tf_replay` tool to replicate
88 // the operations against a simulated environment, without requiring the
89 // original code or cluster setup.
90 //
91 // Session logging by setting the TF_REPLAY_LOG_FILE environment variable.
92 class SessionLogger {
93 public:
SessionLogger()94 SessionLogger() {
95 std::string log_name = getenv("TF_REPLAY_LOG_FILE");
96 LOG(INFO) << "Constructing new session logger for " << log_name;
97 TF_CHECK_OK(
98 Env::Default()->RecursivelyCreateDir(string(io::Dirname(log_name))));
99 Env::Default()->DeleteFile(log_name).IgnoreError();
100
101 TF_CHECK_OK(Env::Default()->NewWritableFile(log_name, &log_file_));
102 log_writer_ = absl::make_unique<io::RecordWriter>(log_file_.get());
103 }
104
~SessionLogger()105 ~SessionLogger() {
106 log_writer_->Close().IgnoreError();
107 log_writer_.release();
108 log_file_->Close().IgnoreError();
109 }
110
RecordNewSession(Session * session)111 Status RecordNewSession(Session* session) {
112 ReplayOp op;
113 NewReplaySession* req = op.mutable_new_replay_session();
114 req->set_session_handle(SessionToHandle(session));
115 return Flush(op);
116 }
117
RecordRun(Session * session,const std::vector<std::pair<string,Tensor>> & inputs,const std::vector<string> & output_tensor_names,const std::vector<string> & target_node_names,std::vector<Tensor> * outputs)118 Status RecordRun(Session* session,
119 const std::vector<std::pair<string, Tensor> >& inputs,
120 const std::vector<string>& output_tensor_names,
121 const std::vector<string>& target_node_names,
122 std::vector<Tensor>* outputs) {
123 return RecordRun(session, *kEmptyRunOptions(), inputs, output_tensor_names,
124 target_node_names, outputs, nullptr);
125 }
126
RecordRun(Session * session,const RunOptions & run_options,const std::vector<std::pair<string,Tensor>> & inputs,const std::vector<string> & output_tensor_names,const std::vector<string> & target_node_names,std::vector<Tensor> * outputs,RunMetadata * run_metadata)127 Status RecordRun(Session* session, const RunOptions& run_options,
128 const std::vector<std::pair<string, Tensor> >& inputs,
129 const std::vector<string>& output_tensor_names,
130 const std::vector<string>& target_node_names,
131 std::vector<Tensor>* outputs, RunMetadata* run_metadata) {
132 ReplayOp op;
133 RunStepRequest* req = op.mutable_run_step();
134 RunStepResponse* resp = op.mutable_run_step_response();
135
136 req->set_session_handle(SessionToHandle(session));
137 *req->mutable_options() = run_options;
138
139 for (const auto& it : inputs) {
140 NamedTensorProto* feed = req->add_feed();
141 feed->set_name(it.first);
142 it.second.AsProtoField(feed->mutable_tensor());
143 }
144
145 // Build an index from fetch tensor name to first index in
146 // output_tensor_names.
147 std::unordered_map<string, int> output_name_to_offset;
148 for (int i = 0; i < output_tensor_names.size(); ++i) {
149 const string& name = output_tensor_names[i];
150 if (output_name_to_offset.insert(std::make_pair(name, i)).second) {
151 req->add_fetch(name);
152 }
153 }
154 for (const string& target : target_node_names) {
155 req->add_target(target);
156 }
157
158 if (&run_options == kEmptyRunOptions()) {
159 RUN_WITH_TIMESTAMP(Run, inputs, output_tensor_names, target_node_names,
160 outputs);
161 } else {
162 RUN_WITH_TIMESTAMP(Run, run_options, inputs, output_tensor_names,
163 target_node_names, outputs, run_metadata);
164 }
165
166 for (size_t i = 0; i < outputs->size(); ++i) {
167 const Tensor& tensor = (*outputs)[i];
168 NamedTensorProto* tproto = resp->add_tensor();
169 tensor.AsProtoField(tproto->mutable_tensor());
170 tproto->set_name(output_tensor_names[i]);
171 }
172
173 if (run_metadata) {
174 *resp->mutable_metadata() = *run_metadata;
175 }
176
177 return Flush(op);
178 }
179
RecordCreate(Session * session,const GraphDef & graph)180 Status RecordCreate(Session* session, const GraphDef& graph) {
181 return RecordCreate(session, *kEmptyRunOptions(), graph);
182 }
183
184 // N.B. RunOptions is not stored (it has no entry in CreateRequest)
RecordCreate(Session * session,const RunOptions & run_options,const GraphDef & graph)185 Status RecordCreate(Session* session, const RunOptions& run_options,
186 const GraphDef& graph) {
187 ReplayOp op;
188 CreateSessionRequest* req = op.mutable_create_session();
189 *req->mutable_graph_def() = graph;
190
191 CreateSessionResponse* resp = op.mutable_create_session_response();
192 if (&run_options == kEmptyRunOptions()) {
193 RUN_WITH_TIMESTAMP(Create, graph);
194 } else {
195 RUN_WITH_TIMESTAMP(Create, run_options, graph);
196 }
197 resp->set_session_handle(SessionToHandle(session));
198 return Flush(op);
199 }
200
RecordExtend(Session * session,const GraphDef & graph)201 Status RecordExtend(Session* session, const GraphDef& graph) {
202 return RecordExtend(session, *kEmptyRunOptions(), graph);
203 }
204
205 // N.B. RunOptions is not stored (it has no entry in ExtendRequest)
RecordExtend(Session * session,const RunOptions & run_options,const GraphDef & graph)206 Status RecordExtend(Session* session, const RunOptions& run_options,
207 const GraphDef& graph) {
208 ReplayOp op;
209 ExtendSessionRequest* req = op.mutable_extend_session();
210 op.mutable_extend_session_response();
211 req->set_session_handle(SessionToHandle(session));
212 *req->mutable_graph_def() = graph;
213 if (&run_options == kEmptyRunOptions()) {
214 RUN_WITH_TIMESTAMP(Extend, graph);
215 } else {
216 RUN_WITH_TIMESTAMP(Extend, run_options, graph);
217 }
218
219 return Flush(op);
220 }
221
RecordClose(Session * session)222 Status RecordClose(Session* session) {
223 return RecordClose(session, *kEmptyRunOptions());
224 }
225
226 // N.B. RunOptions is not stored (it has no entry in CloseRequest)
RecordClose(Session * session,const RunOptions & run_options)227 Status RecordClose(Session* session, const RunOptions& run_options) {
228 ReplayOp op;
229 CloseSessionRequest* req = op.mutable_close_session();
230 req->set_session_handle(SessionToHandle(session));
231 op.mutable_close_session_response();
232 if (&run_options == kEmptyRunOptions()) {
233 RUN_WITH_TIMESTAMP(Close);
234 } else {
235 RUN_WITH_TIMESTAMP(Close, run_options);
236 }
237 return Flush(op);
238 }
239
RecordListDevices(Session * session,std::vector<DeviceAttributes> * response)240 Status RecordListDevices(Session* session,
241 std::vector<DeviceAttributes>* response) {
242 ReplayOp op;
243 ListDevicesRequest* req = op.mutable_list_devices();
244 ListDevicesResponse* resp = op.mutable_list_devices_response();
245 req->set_session_handle(SessionToHandle(session));
246 RUN_WITH_TIMESTAMP(ListDevices, response);
247
248 // TODO(power) -- local vs remote device distinction is lost here!
249 *resp->mutable_local_device() = {response->begin(), response->end()};
250 return Flush(op);
251 }
252
RecordPRunSetup(Session * session,const std::vector<string> & input_names,const std::vector<string> & output_names,const std::vector<string> & target_nodes,string * handle)253 Status RecordPRunSetup(Session* session,
254 const std::vector<string>& input_names,
255 const std::vector<string>& output_names,
256 const std::vector<string>& target_nodes,
257 string* handle) {
258 ReplayOp op;
259 PartialRunSetupRequest* req = op.mutable_partial_run_setup();
260 req->set_session_handle(SessionToHandle(session));
261 for (auto& input : input_names) {
262 req->add_feed(input);
263 }
264 for (auto& output : output_names) {
265 req->add_fetch(output);
266 }
267 for (auto& target : target_nodes) {
268 req->add_target(target);
269 }
270 RUN_WITH_TIMESTAMP(PRunSetup, input_names, output_names, target_nodes,
271 handle);
272 op.mutable_partial_run_setup_response()->set_partial_run_handle(*handle);
273 return Flush(op);
274 }
275
RecordPRun(Session * session,const string & handle,const std::vector<std::pair<string,Tensor>> & inputs,const std::vector<string> & output_names,std::vector<Tensor> * outputs)276 Status RecordPRun(Session* session, const string& handle,
277 const std::vector<std::pair<string, Tensor> >& inputs,
278 const std::vector<string>& output_names,
279 std::vector<Tensor>* outputs) {
280 ReplayOp op;
281 RunStepRequest* req = op.mutable_run_step();
282 RunStepResponse* resp = op.mutable_run_step_response();
283 req->set_session_handle(SessionToHandle(session));
284
285 // Mark this step as a partial run for replay.
286 req->set_partial_run_handle(handle);
287 for (auto& input : inputs) {
288 auto* feed = req->add_feed();
289 feed->set_name(input.first);
290 input.second.AsProtoField(feed->mutable_tensor());
291 }
292
293 for (auto& output : output_names) {
294 req->add_fetch(output);
295 }
296
297 RUN_WITH_TIMESTAMP(PRun, handle, inputs, output_names, outputs);
298
299 for (size_t i = 0; i < outputs->size(); ++i) {
300 const Tensor& tensor = (*outputs)[i];
301 NamedTensorProto* tproto = resp->add_tensor();
302 tensor.AsProtoField(tproto->mutable_tensor());
303 tproto->set_name(output_names[i]);
304 }
305
306 return Flush(op);
307 }
308
RecordMakeCallable(Session * session,const CallableOptions & callable_options,Session::CallableHandle * handle)309 Status RecordMakeCallable(Session* session,
310 const CallableOptions& callable_options,
311 Session::CallableHandle* handle) {
312 ReplayOp op;
313 MakeCallableRequest* req = op.mutable_make_callable();
314 req->set_session_handle(SessionToHandle(session));
315 *req->mutable_options() = callable_options;
316
317 RUN_WITH_TIMESTAMP(MakeCallable, callable_options, handle);
318
319 MakeCallableResponse* resp = op.mutable_make_callable_response();
320 resp->set_handle(*handle);
321
322 return Flush(op);
323 }
324
RecordRunCallable(Session * session,Session::CallableHandle handle,const std::vector<Tensor> & feed_tensors,std::vector<Tensor> * fetch_tensors,RunMetadata * run_metadata)325 Status RecordRunCallable(Session* session, Session::CallableHandle handle,
326 const std::vector<Tensor>& feed_tensors,
327 std::vector<Tensor>* fetch_tensors,
328 RunMetadata* run_metadata) {
329 ReplayOp op;
330 RunCallableRequest* req = op.mutable_run_callable();
331 req->set_session_handle(SessionToHandle(session));
332 req->set_handle(handle);
333 for (auto& tensor : feed_tensors) {
334 tensor.AsProtoField(req->add_feed());
335 }
336 RUN_WITH_TIMESTAMP(RunCallable, handle, feed_tensors, fetch_tensors,
337 run_metadata);
338
339 RunCallableResponse* resp = op.mutable_run_callable_response();
340 if (run_metadata) {
341 *resp->mutable_metadata() = *run_metadata;
342 }
343 for (const Tensor& tensor : *fetch_tensors) {
344 tensor.AsProtoTensorContent(resp->add_fetch());
345 }
346 return Flush(op);
347 }
348
RecordReleaseCallable(Session * session,Session::CallableHandle handle)349 Status RecordReleaseCallable(Session* session,
350 Session::CallableHandle handle) {
351 ReplayOp op;
352 ReleaseCallableRequest* req = op.mutable_release_callable();
353 req->set_session_handle(SessionToHandle(session));
354 req->set_handle(handle);
355 RUN_WITH_TIMESTAMP(ReleaseCallable, handle);
356 return Flush(op);
357 }
358
359 private:
Flush(const ReplayOp & op)360 Status Flush(const ReplayOp& op) {
361 mutex_lock l(log_mutex_);
362
363 string buf;
364 op.SerializeToString(&buf);
365 TF_RETURN_IF_ERROR(log_writer_->WriteRecord(buf));
366
367 // TODO(b/116624106): Not all file-systems respect calls to `Sync()`
368 return log_file_->Sync();
369 }
370
371 std::unique_ptr<WritableFile> log_file_;
372 std::unique_ptr<io::RecordWriter> log_writer_;
373 mutex log_mutex_;
374 };
375
global_session_logger()376 static SessionLogger* global_session_logger() {
377 static SessionLogger* logger = new SessionLogger();
378 return logger;
379 }
380
SessionRef(Session * session)381 SessionRef::SessionRef(Session* session) : session_(session) {
382 if (getenv("TF_REPLAY_LOG_FILE") != nullptr) {
383 logger_ = global_session_logger();
384 logger_->RecordNewSession(this->session_.get()).IgnoreError();
385 } else {
386 logger_ = nullptr;
387 }
388 }
389
// Defaulted: session_ (a shared_ptr) releases the underlying Session; the
// global logger, if attached, is never destroyed (see global_session_logger).
SessionRef::~SessionRef() = default;
391
CheckNotClosed()392 Status SessionRef::CheckNotClosed() {
393 mutex_lock l(run_lock_);
394 if (session_ == nullptr) return errors::Cancelled("Session has been closed.");
395 return ::tensorflow::Status::OK();
396 }
397
// If logging is active, log the start and end time of the operation along with
// the request and response.
//
// Fails fast if the session has been closed; otherwise a RunCounter keeps the
// session alive (and blocks Close()) for the duration of the call.  Without a
// logger the call is forwarded directly to the underlying session; with one,
// it dispatches to the matching SessionLogger::Record<OpName> method.
#define LOG_AND_RUN_OPERATION(OpName, ...)                          \
  TF_RETURN_IF_ERROR(CheckNotClosed());                             \
  RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); \
  if (!logger_) {                                                   \
    return rc.session->OpName(__VA_ARGS__);                         \
  }                                                                 \
  return logger_->Record##OpName(rc.session.get(), __VA_ARGS__);
407
// Runs one step with explicit RunOptions; logged via RecordRun when a replay
// logger is attached.
Status SessionRef::Run(const RunOptions& run_options,
                       const std::vector<std::pair<string, Tensor> >& inputs,
                       const std::vector<string>& output_tensor_names,
                       const std::vector<string>& target_node_names,
                       std::vector<Tensor>* outputs,
                       RunMetadata* run_metadata) {
  LOG_AND_RUN_OPERATION(Run, run_options, inputs, output_tensor_names,
                        target_node_names, outputs, run_metadata);
}
417
// Runs one step without options/metadata; logged via RecordRun when a replay
// logger is attached.
Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs,
                       const std::vector<string>& output_tensor_names,
                       const std::vector<string>& target_node_names,
                       std::vector<Tensor>* outputs) {
  LOG_AND_RUN_OPERATION(Run, inputs, output_tensor_names, target_node_names,
                        outputs);
}
425
// Creates the session graph; logged via RecordCreate when enabled.
Status SessionRef::Create(const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Create, graph);
}
429
// Creates the session graph with explicit RunOptions; logged when enabled.
Status SessionRef::Create(const RunOptions& run_options,
                          const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Create, run_options, graph);
}
434
// Extends the session graph with explicit RunOptions; logged when enabled.
Status SessionRef::Extend(const RunOptions& run_options,
                          const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Extend, run_options, graph);
}
439
// Extends the session graph; logged via RecordExtend when enabled.
Status SessionRef::Extend(const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Extend, graph);
}
443
// Lists devices visible to the session; logged via RecordListDevices.
Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) {
  LOG_AND_RUN_OPERATION(ListDevices, response);
}
447
// Sets up a partial run; logged via RecordPRunSetup when enabled.
Status SessionRef::PRunSetup(const std::vector<string>& input_names,
                             const std::vector<string>& output_names,
                             const std::vector<string>& target_nodes,
                             string* handle) {
  LOG_AND_RUN_OPERATION(PRunSetup, input_names, output_names, target_nodes,
                        handle);
}
455
// Runs one partial-run step; logged via RecordPRun when enabled.
Status SessionRef::PRun(const string& handle,
                        const std::vector<std::pair<string, Tensor> >& inputs,
                        const std::vector<string>& output_names,
                        std::vector<Tensor>* outputs) {
  LOG_AND_RUN_OPERATION(PRun, handle, inputs, output_names, outputs);
}
462
// Creates a callable; logged via RecordMakeCallable when enabled.
Status SessionRef::MakeCallable(const CallableOptions& callable_options,
                                CallableHandle* out_handle) {
  LOG_AND_RUN_OPERATION(MakeCallable, callable_options, out_handle);
}
467
// Invokes a previously-created callable; logged via RecordRunCallable.
Status SessionRef::RunCallable(CallableHandle handle,
                               const std::vector<Tensor>& feed_tensors,
                               std::vector<Tensor>* fetch_tensors,
                               RunMetadata* run_metadata) {
  LOG_AND_RUN_OPERATION(RunCallable, handle, feed_tensors, fetch_tensors,
                        run_metadata);
}
475
// Releases a callable handle; logged via RecordReleaseCallable when enabled.
Status SessionRef::ReleaseCallable(CallableHandle handle) {
  LOG_AND_RUN_OPERATION(ReleaseCallable, handle);
}
479
Close(const RunOptions & run_options)480 Status SessionRef::Close(const RunOptions& run_options) {
481 TF_RETURN_IF_ERROR(CheckNotClosed());
482 mutex_lock l(run_lock_);
483 Status status;
484 if (logger_) {
485 status = logger_->RecordClose(session_.get(), run_options);
486 } else {
487 status = session_->Close(run_options);
488 }
489 session_.reset();
490 while (run_count_ > 0) {
491 run_finished_.wait(l);
492 }
493 return status;
494 }
495
Close()496 Status SessionRef::Close() {
497 TF_RETURN_IF_ERROR(CheckNotClosed());
498 mutex_lock l(run_lock_);
499 Status status;
500 if (logger_) {
501 status = logger_->RecordClose(session_.get());
502 } else {
503 status = session_->Close();
504 }
505 session_.reset();
506 while (run_count_ > 0) {
507 run_finished_.wait(l);
508 }
509 return status;
510 }
511
512 } // namespace tensorflow
513