/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_EXECUTOR_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_EXECUTOR_H_
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <map>
21 #include <memory>
22 #include <queue>
23 #include <string>
24 #include <thread>
25 #include <vector>
26 
27 #include "tensorflow/core/common_runtime/device_factory.h"
28 #include "tensorflow/core/common_runtime/function.h"
29 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
30 #include "tensorflow/core/framework/rendezvous.h"
31 #include "tensorflow/core/lib/gtl/inlined_vector.h"
32 #include "tensorflow/core/lib/gtl/map_util.h"
33 #include "tensorflow/core/lib/gtl/stl_util.h"
34 #include "tensorflow/core/platform/mutex.h"
35 #include "tensorflow/core/platform/thread_annotations.h"
36 #include "tensorflow/core/public/version.h"
37 
38 namespace tensorflow {
39 
40 // A unit of execution for the EagerExecutor class below. Example subclasses
41 // encapsulate execution of a TFE_Op, or copying a TFE_TensorHandle from one
42 // device to another.
43 class EagerNode {
44  public:
45   explicit EagerNode(uint64 id);
46 
~EagerNode()47   virtual ~EagerNode() {}
48 
49   // Runs the computation corresponding to this node and blocks till the
50   // execution is done.
51   virtual Status Run() = 0;
52 
53   // An id unique to the TFE_Context under which this node is created. Allocated
54   // monotonically.
55   const uint64 id;
56 };
57 
58 // A class for handling async execution (see TFE_ContextSetAsync).
59 // Note that this class is thread-safe.
60 // TODO(agarwal): TFE_OpAddInput may currently block if it tries to access the
61 // device of the input handle. Fix that.
62 // TODO(agarwal): On error, mark all affected handles as corrupted.
63 // TODO(agarwal): Implement support for control dependencies.
64 // TODO(agarwal): Support out-of-order execution and dispatching multiple
65 // EagerNode in parallel.
66 // TODO(agarwal): Implement optimizations over EagerNode traces.
67 class EagerExecutor {
68  public:
69   ~EagerExecutor();
70 
71   // This is called whenever async mode is enabled. Note that it may be called
72   // multiple times as different calling threads may switch async mode on or off
73   // independently.
74   void EnableAsync();
75 
76   // Helper function to create monotonically increasing ids unique to this
77   // object.
78   uint64 NextId();
79 
80   // Schedules `node` for execution.
81   // Note that Add must be called in monotonically increasing order of node->id.
82   void Add(EagerNode* node);
83 
84   // Causes the caller to block till node with id `node_id` has finished
85   // execution.
86   Status WaitFor(uint64 node_id);
87 
88   // Blocks till all currently pending ops are done.
89   Status WaitForAllPendingNodes();
90 
91   // Clears all currently set errors which re-enables async execution.
92   void ClearError();
93 
94   // Returns Status based on any errors that occurred during async execution.
95   Status status();
96 
97  private:
98   // Starts execution of pending EagerNodes. This function loops till
99   // thread_done_ is set to true. If any errors are encontered, these are set
100   // inside `status_`. The loop blocks anytime there are no pending nodes, or if
101   // `status_` is not ok.
102   void Run();
103 
104   Status WaitImpl(bool wait_all, uint64 node_id);
105 
106   mutex node_queue_mutex_;
107 
108   // Used to signal that some EagerNodes are pending execution.
109   condition_variable nodes_pending_ GUARDED_BY(node_queue_mutex_);
110 
111   // Queue of pending EagerNodes.
112   std::queue<EagerNode*> node_queue_ GUARDED_BY(node_queue_mutex_);
113 
114   // `status_` is set based on any errors raised during execution of a
115   // EagerNode.  It remains set until ClearError is called.
116   Status status_ GUARDED_BY(node_queue_mutex_);
117 
118   // Map from id of a EagerNode to condition_variables (not owned by the map).
119   // These condition_variables are notified and removed when that EagerNode is
120   // done executing, or if an error is found in execution of any EagerNode.
121   std::multimap<uint64, condition_variable*> node_done_notifications_
122       GUARDED_BY(node_queue_mutex_);
123 
124   // Thread object that calls the `Run` method. Currently we use only one thread
125   // for executing the EagerNodes one-by-one.
126   std::unique_ptr<Thread> thread_ GUARDED_BY(node_queue_mutex_);
127 
128   // Indicates that `thread_` should stop as soon as it is done executing the
129   // current EagerNode.
130   bool thread_done_ GUARDED_BY(node_queue_mutex_) = false;
131 
132   mutex next_id_mutex_;
133   uint64 next_id_ GUARDED_BY(next_id_mutex_) = 1;
134 };
135 
136 }  // namespace tensorflow
137 
138 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_EXECUTOR_H_
139