1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
18 
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/local_executor_params.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool_interface.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
31 
32 namespace tensorflow {
33 
// Forward declaration; the class itself is defined elsewhere.
// NOTE(review): this exact name is not referenced again in this header
// (Executor::Args uses StepStatsCollectorInterface) -- presumably kept for
// files that include this header; confirm before removing.
class StepStatsCollector;
35 
// Turns on execution-cost sampling for "inexpensive" kernels. Once this has
// been called, the execution cost of such kernels is sampled and a kernel is
// switched to "expensive" when its estimated cost exceeds the expensive-ness
// threshold.
//
// This is a temporary flag for validating the performance impact of this
// feature. For simplicity a global flag is used, and once the flag is turned
// on it cannot be turned off. The flag will be removed once the feature is
// validated.
void EnableAlwaysTrackKernelExecutionCost();
44 
45 // Executor runs a graph computation.
46 // Example:
47 //   Graph* graph = ...;
48 //      ... construct graph ...
49 //   Executor* executor;
50 //   TF_CHECK_OK(NewSimpleExecutor(my_device, graph, &executor));
51 //   Rendezvous* rendezvous = NewNaiveRendezvous();
52 //   TF_CHECK_OK(rendezvous->Send("input", some_input_tensor));
53 //   TF_CHECK_OK(executor->Run({ExecutorOpts, rendezvous, nullptr}));
54 //   TF_CHECK_OK(rendezvous->Recv("output", &output_tensor));
55 //   ... ...
56 //
57 // Multiple threads can call Executor::Run concurrently.
58 class Executor {
59  public:
~Executor()60   virtual ~Executor() {}
61 
62   // RunAsync() executes the graph computation. "done" is run when the
63   // graph computation completes. If any error happens during the
64   // computation, "done" is run and the error is passed to "done".
65   //
66   // RunAsync() is given a few arguments in Args. The caller must
67   // ensure objects passed in Args (rendezvous, stats_collector, etc.)
68   // are alive at least until done is invoked. All pointers to the
69   // argument objects can be nullptr.
70   //
71   // "step_id" is a process-wide unique identifier for the step being
72   // run. Executors on different devices may receive the same step_id
73   // in the case that a step runs Ops on more than one device. The
74   // step_id is used for tracking resource usage of a given step.
75   //
76   // RunAsync() uses the given "rendezvous", if not null, as the
77   // mechanism to communicate inputs and outputs of the underlying
78   // graph computation.
79   //
80   // RunAsync() calls "stats_collector", if not null, to keep track of
81   // stats. This allows us to collect statistics and traces on demand.
82   //
83   // RunAsync() is provided a "call_frame", if the executor is used
84   // for executing a function, is used to pass arguments and return
85   // values between the caller and the callee.
86   //
87   // RunAsync() uses "cancellation_manager", if not nullptr, to
88   // register callbacks that should be called if the graph computation
89   // is canceled. Note that the callbacks merely unblock any
90   // long-running computation, and a canceled step will terminate by
91   // returning/calling the DoneCallback as usual.
92   //
93   // RunAsync() dispatches closures to "runner". Typically, "runner"
94   // is backed up by a bounded threadpool.
95   struct Args {
96     int64 step_id = 0;
97     RendezvousInterface* rendezvous = nullptr;
98     StepStatsCollectorInterface* stats_collector = nullptr;
99     CallFrameInterface* call_frame = nullptr;
100     CancellationManager* cancellation_manager = nullptr;
101     SessionState* session_state = nullptr;
102     // Unique session identifier. Can be empty.
103     string session_handle;
104     TensorStore* tensor_store = nullptr;
105     ScopedStepContainer* step_container = nullptr;
106     CollectiveExecutor* collective_executor = nullptr;
107     thread::ThreadPoolInterface* user_intra_op_threadpool = nullptr;
108 
109     // If true, calls Sync() on the device.
110     bool sync_on_finish = false;
111 
112     typedef std::function<void()> Closure;
113     typedef std::function<void(Closure)> Runner;
114     Runner runner = nullptr;
115 
116     // If true, all kernels will be treated as "inexpensive", and hence executed
117     // on the scheduling thread.
118     bool run_all_kernels_inline = false;
119   };
120   typedef std::function<void(const Status&)> DoneCallback;
121   virtual void RunAsync(const Args& args, DoneCallback done) = 0;
122 
123   // Synchronous wrapper for RunAsync().
Run(const Args & args)124   virtual Status Run(const Args& args) {
125     Status ret;
126     Notification n;
127     RunAsync(args, [&ret, &n](const Status& s) {
128       ret = s;
129       n.Notify();
130     });
131     n.WaitForNotification();
132     return ret;
133   }
134 };
135 
136 // Creates an Executor that computes the given "graph".
137 //
138 // If successful, returns the constructed executor in "*executor". Otherwise,
139 // returns an error status.
140 //
141 // "params" provides a set of context for the executor. We expect that
142 // different context would provide different implementations.
143 ::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params,
144                                       const Graph& graph, Executor** executor);
145 
146 // A class to help run multiple executors in parallel and wait until
147 // all of them are complete.
148 //
149 // ExecutorBarrier deletes itself after the function returned by Get()
150 // is called.
151 class ExecutorBarrier {
152  public:
153   typedef std::function<void(const Status&)> StatusCallback;
154 
155   // Create an ExecutorBarrier for 'num' different executors.
156   //
157   // 'r' is the shared Rendezvous object that is used to communicate
158   // state.  If any of the executors experiences an error, the
159   // rendezvous object will be aborted exactly once.
160   //
161   // 'done' is called after the last executor completes, and
162   // ExecutorBarrier is deleted.
ExecutorBarrier(size_t num,Rendezvous * r,StatusCallback done)163   ExecutorBarrier(size_t num, Rendezvous* r, StatusCallback done)
164       : rendez_(r), done_cb_(done), pending_(num) {}
165 
~ExecutorBarrier()166   ~ExecutorBarrier() {}
167 
168   // Returns a closure that Executors must call when they are done
169   // computing, passing the status of their execution as an argument.
Get()170   StatusCallback Get() {
171     return std::bind(&ExecutorBarrier::WhenDone, this, std::placeholders::_1);
172   }
173 
174  private:
175   Rendezvous* rendez_ = nullptr;
176   StatusCallback done_cb_ = nullptr;
177 
178   mutable mutex mu_;
179   int pending_ TF_GUARDED_BY(mu_) = 0;
180   StatusGroup status_group_ TF_GUARDED_BY(mu_);
181 
WhenDone(const Status & s)182   void WhenDone(const Status& s) {
183     Rendezvous* error_rendez = nullptr;
184     StatusCallback done = nullptr;
185     Status status;
186 
187     {
188       mutex_lock l(mu_);
189 
190       // If we are the first error encountered, trigger an abort of the
191       // Rendezvous object by this thread only.
192       if (status_group_.ok() && !s.ok()) {
193         error_rendez = rendez_;
194         error_rendez->Ref();
195       }
196 
197       if (!s.ok() && !StatusGroup::IsDerived(s) &&
198           !status_group_.HasLogMessages()) {
199         status_group_.AttachLogMessages();
200       }
201 
202       status_group_.Update(s);
203 
204       // If this is the last call to WhenDone, call the final callback
205       // below.
206       if (--pending_ == 0) {
207         CHECK(done_cb_ != nullptr);
208         std::swap(done, done_cb_);
209         status = status_group_.as_summary_status();
210       }
211     }
212 
213     if (error_rendez != nullptr) {
214       error_rendez->StartAbort(
215           errors::Aborted("Stopping remaining executors."));
216       error_rendez->Unref();
217     }
218 
219     if (done != nullptr) {
220       delete this;
221       if (!status.ok()) {
222         VLOG(1) << "ExecutorBarrier finished with bad status: " << status;
223       }
224       done(status);
225     }
226   }
227 
228   TF_DISALLOW_COPY_AND_ASSIGN(ExecutorBarrier);
229 };
230 
231 // A few helpers to facilitate create/delete kernels.
232 
// Creates a kernel based on "props" on device "device". The kernel can
// access the functions in the "flib".
//
// The new kernel is returned in "*kernel" and the caller takes ownership of
// it. The returned Status reports whether creation succeeded.
// NOTE(review): presumably the kernel is meant to be released with
// DeleteNonCachedKernel() -- confirm against the implementation.
Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
                             const std::shared_ptr<const NodeProperties>& props,
                             int graph_def_version, OpKernel** kernel);
239 
// Deletes "kernel" returned by CreateNonCachedKernel.
void DeleteNonCachedKernel(OpKernel* kernel);
242 
243 }  // end namespace tensorflow
244 
245 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
246