1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTION_TRACKER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTION_TRACKER_H_
18 
#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
35 
36 namespace xla {
37 
// Represents an asynchronously launched execution. Owns the streams (passed to
// the constructor as StreamPool::Ptr values) on which the execution is launched
// and releases them when destructed.
41 class AsyncExecution {
42  public:
43   AsyncExecution(Backend* backend, std::vector<StreamPool::Ptr> streams,
44                  const ExecutionProfile& profile, GlobalDataHandle result);
45 
46   Status BlockUntilDone() const;
47 
result()48   const GlobalDataHandle& result() const { return result_; }
49 
profile()50   const ExecutionProfile& profile() const { return profile_; }
51 
52  private:
53   // Backend to execute the computation on.
54   Backend* backend_;
55 
56   // Stream on which the execution is launched.
57   std::vector<StreamPool::Ptr> streams_;
58 
59   // Profile object of the execution to be returned to the user.
60   ExecutionProfile profile_;
61 
62   // Data handle to the result of the execution. Data represented by this handle
63   // is valid only after BlockUntilDone() is called.
64   GlobalDataHandle result_;
65 };
66 
67 // Tracks asynchronously launched executions for the XLA service.
68 class ExecutionTracker {
69  public:
70   ExecutionTracker();
71 
72   // Registers an execution with its backend, streams, and data handle to the
73   // execution result. Returns a handle for the registered execution.
74   ExecutionHandle Register(Backend* backend,
75                            std::vector<StreamPool::Ptr> stream,
76                            const ExecutionProfile& profile,
77                            GlobalDataHandle data);
78 
79   // Unregisters the execution for the given handle.
80   Status Unregister(const ExecutionHandle& handle);
81 
82   // Resolves the given ExecutionHandle to an AsyncExecution. Returns an
83   // error status if the given handle is not found, which means that the
84   // execution is not yet registered or already unregistered.
85   StatusOr<const AsyncExecution*> Resolve(const ExecutionHandle& handle);
86 
87  private:
88   // The next handle to assign to an execution.
89   int64 next_handle_ GUARDED_BY(execution_mutex_);
90 
91   // Mapping from ExecutionHandle handle to the corresponding registered
92   // AsyncExecution object.
93   std::map<int64, std::unique_ptr<AsyncExecution>> handle_to_execution_
94       GUARDED_BY(execution_mutex_);
95 
96   tensorflow::mutex execution_mutex_;  // Guards the execution mapping.
97 
98   TF_DISALLOW_COPY_AND_ASSIGN(ExecutionTracker);
99 };
100 
101 }  // namespace xla
102 
103 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTION_TRACKER_H_
104