/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"

#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace parallel_device {
namespace {

class OpDeleter {
 public:
  void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
};

using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;

class StatusDeleter {
 public:
  void operator()(TF_Status* to_delete) const { TF_DeleteStatus(to_delete); }
};

using StatusPtr = std::unique_ptr<TF_Status, StatusDeleter>;

class ExecutorDeleter {
 public:
  void operator()(TFE_Executor* to_delete) const {
    TFE_DeleteExecutor(to_delete);
  }
};

using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;

}  // namespace

// Allows a single op at a time to be launched without blocking.
//
// DeviceThread itself is thread-safe, in that StartExecute will block if there
// is a pending execution. Since StartExecute is equivalent to grabbing a lock,
// multiple DeviceThreads should always be accessed in the same order to avoid
// deadlocks.
class DeviceThread {
 public:
  // Starts a background thread waiting for `StartExecute`.
  explicit DeviceThread(const std::string& device, const bool is_async)
      : status_(TF_NewStatus()),
        device_(device),
        // If the context's default executor is set to async, reusing it in
        // each thread would cause collectives to deadlock. For consistency we
        // create a new sync executor for every thread.
        //
        // TODO(allenl): We should have an async API that works with the
        // parallel device.
        executor_(TFE_NewExecutor(is_async)),
        op_(nullptr),
        thread_(tensorflow::Env::Default()->StartThread(
            tensorflow::ThreadOptions(), "parallel_device_execute",
            std::bind(&DeviceThread::Run, this))) {}
  ~DeviceThread();

  // Requests that the worker thread execute the specified operation. Blocks
  // until the previously pending operation (a StartExecute without a Join) has
  // finished, if any.
  //
  // `cancellation_manager` must live until after `Join` finishes and pending
  // `is_async` operations finish. In addition to allowing the caller to cancel
  // the operation, its `StartCancel` method will be called if op execution
  // fails on any device in order to cancel the others.
  void StartExecute(TFE_Context* context, const char* operation_name,
                    std::vector<TFE_TensorHandle*> inputs,
                    const TFE_OpAttrs* attributes, int expected_max_outputs,
                    CancellationManager& cancellation_manager);
  // Blocks until the previous `StartExecute` operation has executed. Forwards
  // the status from `TFE_Execute` and returns outputs if the status is OK.
  std::vector<TensorHandlePtr> Join(TF_Status* status);

 private:
  void Run();

  void Execute(TFE_Context* context, const char* operation_name,
               std::vector<TFE_TensorHandle*> inputs,
               const TFE_OpAttrs* attributes, int expected_max_outputs,
               std::vector<TensorHandlePtr>* outputs, TF_Status* status) const
      TF_EXCLUSIVE_LOCKS_REQUIRED(execution_mutex_);

  enum class ExecutionState {
    kReadyToExecute,
    kHasResult,
    kIdle,
    kShuttingDown,
  };

  tensorflow::mutex execution_mutex_;
  ExecutionState execution_state_ TF_GUARDED_BY(execution_mutex_) =
      ExecutionState::kIdle;
  // Tells the worker thread that there is new work.
  tensorflow::condition_variable start_execute_;
  // The worker thread notifies that work has finished.
  tensorflow::condition_variable finished_execute_;
  // Notifies a StartExecute that the previous Join has finished.
  tensorflow::condition_variable finished_join_;

  // Temporary state between `StartExecute` and `Join`.
  //
  //   Inputs; pointers are to objects not owned by the DeviceThread, but which
  //   are expected to live at least until `Join` finishes:
  TFE_Context* context_ TF_GUARDED_BY(execution_mutex_);
  const char* operation_name_ TF_GUARDED_BY(execution_mutex_);
  std::vector<TFE_TensorHandle*> op_inputs_ TF_GUARDED_BY(execution_mutex_);
  const TFE_OpAttrs* attributes_ TF_GUARDED_BY(execution_mutex_);
  int expected_max_outputs_ TF_GUARDED_BY(execution_mutex_);
  CancellationManager* cancellation_manager_ TF_GUARDED_BY(execution_mutex_);
  //   Outputs:
  std::vector<TensorHandlePtr> op_outputs_ TF_GUARDED_BY(execution_mutex_);
  // TF_Status is an incomplete type and so can't be stack allocated. To avoid
  // an unnecessary allocation on each Execute call, we keep one heap-allocated
  // version for the thread.
  StatusPtr status_ TF_GUARDED_BY(execution_mutex_);

  const std::string device_;
  ExecutorPtr executor_ TF_GUARDED_BY(execution_mutex_);
  mutable OpPtr op_ TF_GUARDED_BY(execution_mutex_);
  std::unique_ptr<Thread> thread_;
};
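
// Illustrative usage sketch of the DeviceThread contract documented above.
// This is exposition only (not code exercised elsewhere in this file); `ctx`,
// `attrs`, `handles`, and `status` are hypothetical placeholders:
//
//   DeviceThread first("/job:localhost/replica:0/task:0/device:CPU:0",
//                      /*is_async=*/false);
//   DeviceThread second("/job:localhost/replica:0/task:0/device:CPU:1",
//                       /*is_async=*/false);
//   CancellationManager cancellation_manager;
//   // Start on every thread in a fixed order before joining any of them,
//   // since StartExecute behaves like acquiring that thread's lock.
//   first.StartExecute(ctx, "Identity", {handles[0]}, attrs,
//                      /*expected_max_outputs=*/1, cancellation_manager);
//   second.StartExecute(ctx, "Identity", {handles[1]}, attrs,
//                       /*expected_max_outputs=*/1, cancellation_manager);
//   std::vector<TensorHandlePtr> first_outputs = first.Join(status);
//   std::vector<TensorHandlePtr> second_outputs = second.Join(status);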

DeviceThread::~DeviceThread() {
  {
    tensorflow::mutex_lock l(execution_mutex_);
    execution_state_ = ExecutionState::kShuttingDown;
  }
  start_execute_.notify_one();
}

void DeviceThread::Run() {
  while (true) {
    {
      tensorflow::mutex_lock l(execution_mutex_);
      while (execution_state_ == ExecutionState::kIdle ||
             execution_state_ == ExecutionState::kHasResult) {
        start_execute_.wait(l);
      }
      if (execution_state_ == ExecutionState::kShuttingDown) {
        return;
      } else if (execution_state_ == ExecutionState::kReadyToExecute) {
        // op_outputs_ may have been moved out by a previous Join; reset it.
        op_outputs_ = std::vector<TensorHandlePtr>();
        Execute(context_, operation_name_, std::move(op_inputs_), attributes_,
                expected_max_outputs_, &op_outputs_, status_.get());
        execution_state_ = ExecutionState::kHasResult;
      }
    }
    finished_execute_.notify_one();
  }
}

void DeviceThread::StartExecute(TFE_Context* context,
                                const char* operation_name,
                                std::vector<TFE_TensorHandle*> inputs,
                                const TFE_OpAttrs* attributes,
                                int expected_max_outputs,
                                CancellationManager& cancellation_manager) {
  {
    tensorflow::mutex_lock l(execution_mutex_);
    while (execution_state_ != ExecutionState::kIdle) {
      // If there's already a pending execution, wait until Join finishes
      // before starting on the next operation.
      finished_join_.wait(l);
    }
    context_ = context;
    operation_name_ = operation_name;
    op_inputs_ = inputs;
    attributes_ = attributes;
    expected_max_outputs_ = expected_max_outputs;
    cancellation_manager_ = &cancellation_manager;
    execution_state_ = ExecutionState::kReadyToExecute;
  }
  start_execute_.notify_one();
}

std::vector<TensorHandlePtr> DeviceThread::Join(TF_Status* status) {
  std::vector<TensorHandlePtr> result;
  {
    tensorflow::mutex_lock l(execution_mutex_);
    while (execution_state_ != ExecutionState::kHasResult) {
      finished_execute_.wait(l);
    }
    if (TF_GetCode(status_.get()) != TF_OK) {
      TF_SetStatus(status, TF_GetCode(status_.get()),
                   TF_Message(status_.get()));
      // Reset the member `status_` so future op executions (after recovery
      // from the bad `status`) start with an OK status.
      TF_SetStatus(status_.get(), TF_OK, "");
    }
    cancellation_manager_ = nullptr;
    execution_state_ = ExecutionState::kIdle;
    result = std::move(op_outputs_);
  }
  finished_join_.notify_one();
  return result;
}

void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
                           std::vector<TFE_TensorHandle*> inputs,
                           const TFE_OpAttrs* attributes,
                           int expected_max_outputs,
                           std::vector<TensorHandlePtr>* outputs,
                           TF_Status* status) const {
  if (op_ == nullptr) {
    TFE_ContextSetExecutorForThread(context, executor_.get());
    op_.reset(TFE_NewOp(context, operation_name, status));
    if (TF_GetCode(status) != TF_OK) return;
    TFE_OpSetDevice(op_.get(), device_.c_str(), status);
    if (TF_GetCode(status) != TF_OK) return;
  } else {
    TFE_OpReset(op_.get(), operation_name, device_.c_str(), status);
    if (TF_GetCode(status) != TF_OK) return;
  }
  TFE_OpAddAttrs(op_.get(), attributes);
  for (int input_index = 0; input_index < inputs.size(); ++input_index) {
    TFE_OpAddInput(op_.get(), inputs[input_index], status);
    if (TF_GetCode(status) != TF_OK) return;
  }
  std::vector<TFE_TensorHandle*> unwrapped_results(expected_max_outputs);
  int real_num_outputs = expected_max_outputs;
  TFE_OpSetCancellationManager(op_.get(), wrap(cancellation_manager_), status);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_Execute(op_.get(), unwrapped_results.data(), &real_num_outputs, status);
  if (TF_GetCode(status) != TF_OK) {
    cancellation_manager_->StartCancel();
    return;
  }
  unwrapped_results.resize(real_num_outputs);
  outputs->reserve(real_num_outputs);
  for (TFE_TensorHandle* unwrapped_result : unwrapped_results) {
    outputs->emplace_back(unwrapped_result);
  }
}

ParallelDevice::ParallelDevice(const std::vector<std::string>& devices,
                               const bool is_async)
    : underlying_devices_(devices),
      default_cancellation_manager_(absl::make_unique<CancellationManager>()) {
  device_threads_.reserve(devices.size());
  for (int device_index = 0; device_index < devices.size(); ++device_index) {
    device_threads_.emplace_back(
        new DeviceThread(devices[device_index].c_str(), is_async));
  }
}

// Necessary for a unique_ptr to a forward-declared type.
ParallelDevice::~ParallelDevice() = default;

std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
    TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
  std::vector<TensorHandlePtr> components;
  components.reserve(underlying_devices_.size());
  for (const std::string& underlying_device_name : underlying_devices_) {
    TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
        tensor, context, underlying_device_name.c_str(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    components.emplace_back(t);
  }
  return ParallelTensor::FromTensorHandles(*this, std::move(components),
                                           status);
}

std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
    TFE_Context* context, TF_Status* status) const {
  std::vector<int32_t> ids;
  ids.reserve(num_underlying_devices());
  for (int i = 0; i < num_underlying_devices(); ++i) {
    ids.push_back(i);
  }
  return ScalarsFromSequence<int32_t>(ids, context, status);
}

absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Execute(TFE_Context* context,
                        const std::vector<ParallelTensor*>& inputs,
                        const char* operation_name,
                        const TFE_OpAttrs* attributes, int expected_max_outputs,
                        TF_Status* status) const {
  std::vector<PartialTensorShape> expected_output_shapes(expected_max_outputs);
  StartExecute(context, inputs, operation_name, attributes,
               expected_max_outputs, *default_cancellation_manager_);
  auto result = Join(expected_output_shapes, status);
  if (TF_GetCode(status) != TF_OK) {
    std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> await_status(
        TF_NewStatus(), TF_DeleteStatus);
    // Wait until all pending nodes have completed since they may have a
    // reference to default_cancellation_manager_. We ignore the status return
    // since we already have a bad status to propagate.
    TFE_ContextAsyncWait(context, await_status.get());
    // Reset the cancellation manager on a bad status. Otherwise we'll cancel
    // all future operations.
    default_cancellation_manager_ = absl::make_unique<CancellationManager>();
  }
  return result;
}
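
// Illustrative caller-side sketch of the Execute path above (exposition only;
// `ctx`, `status`, `attributes`, and `x_handle` are hypothetical
// placeholders):
//
//   ParallelDevice device({"/job:localhost/replica:0/task:0/device:CPU:0",
//                          "/job:localhost/replica:0/task:0/device:CPU:1"},
//                         /*is_async=*/false);
//   std::unique_ptr<ParallelTensor> x =
//       device.CopyToParallelDevice(ctx, x_handle, status);
//   absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> outputs =
//       device.Execute(ctx, {x.get()}, "Abs", attributes,
//                      /*expected_max_outputs=*/1, status);
//   // On success `outputs` holds one ParallelTensor per op output, each with
//   // one component per underlying device.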

void ParallelDevice::StartExecute(
    TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
    const char* operation_name, const TFE_OpAttrs* attributes,
    int expected_max_outputs, CancellationManager& cancellation_manager) const {
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    DeviceThread* device_thread = device_threads_[device_index].get();
    std::vector<TFE_TensorHandle*> device_inputs;
    device_inputs.reserve(inputs.size());
    for (int input_index = 0; input_index < inputs.size(); ++input_index) {
      // Parallel tensors are divided between operations by device.
      device_inputs.push_back(inputs[input_index]->tensor(device_index));
    }
    device_thread->StartExecute(context, operation_name,
                                std::move(device_inputs), attributes,
                                expected_max_outputs, cancellation_manager);
  }
}

absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Join(
    const std::vector<PartialTensorShape>& expected_output_shapes,
    TF_Status* status) const {
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
  // Compute per-device per-output tensors
  std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
  per_device_output_tensors.reserve(underlying_devices_.size());
  int first_op_output_count = 0;
  StatusPtr first_bad_status(nullptr);
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    DeviceThread* device_thread = device_threads_[device_index].get();
    per_device_output_tensors.push_back(device_thread->Join(status));
    // We will run every Join even if there are bad statuses in case the user
    // wants to recover and continue running ops on the parallel device (which
    // would otherwise deadlock).
    if (TF_GetCode(status) != TF_OK &&
        (first_bad_status == nullptr
         // Prefer propagating non-cancellation related statuses to avoid
         // shadowing the original failure.
         || TF_GetCode(first_bad_status.get()) == TF_CANCELLED)) {
      first_bad_status.reset(TF_NewStatus());
      TF_SetStatus(first_bad_status.get(), TF_GetCode(status),
                   TF_Message(status));
    }

    if (device_index == 0) {
      first_op_output_count = per_device_output_tensors.rbegin()->size();
    } else {
      if (first_bad_status == nullptr &&
          per_device_output_tensors.rbegin()->size() != first_op_output_count) {
        first_bad_status.reset(TF_NewStatus());
        TF_SetStatus(first_bad_status.get(), TF_INTERNAL,
                     "Parallel ops produced different numbers of tensors.");
      }
    }
  }
  if (first_bad_status != nullptr) {
    TF_SetStatus(status, TF_GetCode(first_bad_status.get()),
                 TF_Message(first_bad_status.get()));
    return result;
  }
  // For each output of the original operation, pack the per-device
  // TensorHandles we've computed into a single parallel TensorHandle.
  std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
  per_device_outputs.reserve(first_op_output_count);
  for (int i = 0; i < first_op_output_count; ++i) {
    std::vector<TensorHandlePtr> components;
    components.reserve(underlying_devices_.size());
    for (int j = 0; j < underlying_devices_.size(); ++j) {
      components.push_back(std::move(per_device_output_tensors[j][i]));
    }
    if (expected_output_shapes[i].IsFullyDefined()) {
      per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
          *this, std::move(components),
          absl::Span<const int64>(expected_output_shapes[i].dim_sizes()),
          status));
    } else {
      per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
          *this, std::move(components), status));
    }
    if (TF_GetCode(status) != TF_OK) return result;
  }
  result.emplace(std::move(per_device_outputs));
  return result;
}
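
// Note on `expected_output_shapes` (a sketch for exposition; `device`, `ctx`,
// `attrs`, `x`, and `status` are hypothetical placeholders): a fully defined
// shape lets Join build each ParallelTensor with that static shape, while a
// default (unknown-rank) PartialTensorShape defers shape computation to a
// later ParallelTensor::Shape call, which requires all components to have
// identical shapes:
//
//   CancellationManager cancellation_manager;
//   device.StartExecute(ctx, {x.get()}, "Abs", attrs,
//                       /*expected_max_outputs=*/1, cancellation_manager);
//   std::vector<PartialTensorShape> shapes(1);  // output 0: shape unknown
//   auto outputs = device.Join(shapes, status);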

std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
    const ParallelDevice& parallel_device,
    std::vector<TensorHandlePtr> components, absl::Span<const int64> shape,
    TF_Status* status) {
  TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
  // Verify that the TensorHandle's shape and dtype match all of the component
  // shapes and dtypes.
  for (TensorHandlePtr& component : components) {
    if (TFE_TensorHandleDataType(component.get()) != dtype) {
      TF_SetStatus(status, TF_INTERNAL,
                   "Components of a ParallelTensor must all have "
                   "the same dtype");
      return nullptr;
    }
  }
  return std::unique_ptr<ParallelTensor>(
      new ParallelTensor(parallel_device, std::move(components), shape, dtype));
}

std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
    const ParallelDevice& parallel_device,
    std::vector<TensorHandlePtr> components, TF_Status* status) {
  TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
  // Verify that the combined TensorHandle's dtype matches all of the component
  // dtypes.
  for (TensorHandlePtr& component : components) {
    if (TFE_TensorHandleDataType(component.get()) != dtype) {
      TF_SetStatus(status, TF_INTERNAL,
                   "Components of a ParallelTensor must all have "
                   "the same dtype");
      return nullptr;
    }
  }
  return std::unique_ptr<ParallelTensor>(
      new ParallelTensor(parallel_device, std::move(components), dtype));
}

Status ParallelTensor::Shape(const std::vector<int64_t>** shape) const {
  if (!shape_.has_value()) {
    TF_Status status;
    PartialTensorShape first_shape;
    TF_RETURN_IF_ERROR(unwrap(tensors_[0].get())->Shape(&first_shape));

    // Verify that the TensorHandle's shape matches all of the component shapes.
    for (const TensorHandlePtr& component : tensors_) {
      PartialTensorShape component_shape;
      TF_RETURN_IF_ERROR(unwrap(component.get())->Shape(&component_shape));
      if (!first_shape.IsIdenticalTo(component_shape)) {
        return errors::Unimplemented(absl::StrCat(
            "Computing the shape of a ParallelTensor when the components do "
            "not all have the same shapes is not supported. One tensor had "
            "shape ",
            first_shape.DebugString(), " and another had shape ",
            component_shape.DebugString()));
      }
    }
    auto dim_sizes = first_shape.dim_sizes();
    shape_ = std::vector<int64_t>(dim_sizes.begin(), dim_sizes.end());
  }
  *shape = &*shape_;
  return Status::OK();
}

}  // namespace parallel_device
}  // namespace tensorflow