/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"

#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace parallel_device {
namespace {

class OpDeleter {
 public:
  void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
};

using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;

class StatusDeleter {
 public:
  void operator()(TF_Status* to_delete) const { TF_DeleteStatus(to_delete); }
};

using StatusPtr = std::unique_ptr<TF_Status, StatusDeleter>;

class ExecutorDeleter {
 public:
  void operator()(TFE_Executor* to_delete) const {
    TFE_DeleteExecutor(to_delete);
  }
};

using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;

}  // namespace

// Allows a single op at a time to be launched without blocking.
//
// DeviceThread itself is thread-safe, in that StartExecute will block if there
// is a pending execution. Since StartExecute is equivalent to grabbing a lock,
// multiple DeviceThreads should always be accessed in the same order to avoid
// deadlocks.
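//
// A minimal usage sketch (illustrative only; `ctx`, `handle_a`, `handle_b`,
// and `attributes` are hypothetical values provided by the caller):
//
//   CancellationManager cancellation_manager;  // must outlive Join
//   DeviceThread thread_a("/job:localhost/replica:0/task:0/device:CPU:0",
//                         /*is_async=*/false);
//   DeviceThread thread_b("/job:localhost/replica:0/task:0/device:CPU:1",
//                         /*is_async=*/false);
//   // Always StartExecute and Join the threads in the same order to avoid
//   // deadlocks.
//   thread_a.StartExecute(ctx, "Identity", {handle_a}, attributes,
//                         /*expected_max_outputs=*/1, cancellation_manager);
//   thread_b.StartExecute(ctx, "Identity", {handle_b}, attributes,
//                         /*expected_max_outputs=*/1, cancellation_manager);
//   StatusPtr status(TF_NewStatus());
//   std::vector<TensorHandlePtr> outputs_a = thread_a.Join(status.get());
//   std::vector<TensorHandlePtr> outputs_b = thread_b.Join(status.get());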
class DeviceThread {
 public:
  // Starts a background thread waiting for `StartExecute`.
  explicit DeviceThread(const std::string& device, const bool is_async)
      : status_(TF_NewStatus()),
        device_(device),
        // If the context's default executor is set to async, re-using that in
        // each thread would cause collectives to deadlock. For consistency we
        // create a new sync executor for every thread.
        //
        // TODO(allenl): We should have an async API that works with the
        // parallel device.
        executor_(TFE_NewExecutor(is_async)),
        op_(nullptr),
        thread_(tensorflow::Env::Default()->StartThread(
            tensorflow::ThreadOptions(), "parallel_device_execute",
            std::bind(&DeviceThread::Run, this))) {}
  ~DeviceThread();

  // Requests that the worker thread execute the specified operation. Blocks
  // until the previously pending operation (a StartExecute without a Join) has
  // finished, if any.
  //
  // `cancellation_manager` must live until after `Join` finishes and pending
  // `is_async` operations finish. In addition to allowing the caller to cancel
  // the operation, its `StartCancel` method will be called if op execution
  // fails on any device in order to cancel the others.
  void StartExecute(TFE_Context* context, const char* operation_name,
                    std::vector<TFE_TensorHandle*> inputs,
                    const TFE_OpAttrs* attributes, int expected_max_outputs,
                    CancellationManager& cancellation_manager);
  // Blocks until the previous `StartExecute` operation has executed. Forwards
  // the status from `TFE_Execute` and returns outputs if the status is OK.
  std::vector<TensorHandlePtr> Join(TF_Status* status);

 private:
  void Run();

  void Execute(TFE_Context* context, const char* operation_name,
               std::vector<TFE_TensorHandle*> inputs,
               const TFE_OpAttrs* attributes, int expected_max_outputs,
               std::vector<TensorHandlePtr>* outputs, TF_Status* status) const
      TF_EXCLUSIVE_LOCKS_REQUIRED(execution_mutex_);

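  // Worker-thread state machine, protected by `execution_mutex_`:
  //   kIdle           -> kReadyToExecute  (StartExecute publishes new work)
  //   kReadyToExecute -> kHasResult       (Run finishes executing the op)
  //   kHasResult      -> kIdle            (Join consumes the result)
  //   any state       -> kShuttingDown    (the destructor stops the thread)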
  enum class ExecutionState {
    kReadyToExecute,
    kHasResult,
    kIdle,
    kShuttingDown,
  };

  tensorflow::mutex execution_mutex_;
  ExecutionState execution_state_ TF_GUARDED_BY(execution_mutex_) =
      ExecutionState::kIdle;
  // Tells the worker thread that there is new work.
  tensorflow::condition_variable start_execute_;
  // The worker thread notifies that work has finished.
  tensorflow::condition_variable finished_execute_;
  // Notifies a StartExecute that the previous Join has finished.
  tensorflow::condition_variable finished_join_;

  // Temporary state between `StartExecute` and `Join`.
  //
  // Inputs; pointers are to objects not owned by the DeviceThread, but which
  // are expected to live at least until `Join` finishes:
  TFE_Context* context_ TF_GUARDED_BY(execution_mutex_);
  const char* operation_name_ TF_GUARDED_BY(execution_mutex_);
  std::vector<TFE_TensorHandle*> op_inputs_ TF_GUARDED_BY(execution_mutex_);
  const TFE_OpAttrs* attributes_ TF_GUARDED_BY(execution_mutex_);
  int expected_max_outputs_ TF_GUARDED_BY(execution_mutex_);
  CancellationManager* cancellation_manager_ TF_GUARDED_BY(execution_mutex_);
  // Outputs:
  std::vector<TensorHandlePtr> op_outputs_ TF_GUARDED_BY(execution_mutex_);
  // TF_Status is an incomplete type and so can't be stack allocated. To avoid
  // unnecessary allocations on each Execute call, we keep one heap-allocated
  // version for the thread.
  StatusPtr status_ TF_GUARDED_BY(execution_mutex_);

  const std::string device_;
  ExecutorPtr executor_ TF_GUARDED_BY(execution_mutex_);
  mutable OpPtr op_ TF_GUARDED_BY(execution_mutex_);
  std::unique_ptr<Thread> thread_;
};

DeviceThread::~DeviceThread() {
  {
    tensorflow::mutex_lock l(execution_mutex_);
    execution_state_ = ExecutionState::kShuttingDown;
  }
  start_execute_.notify_one();
}

void DeviceThread::Run() {
  while (true) {
    {
      tensorflow::mutex_lock l(execution_mutex_);
      while (execution_state_ == ExecutionState::kIdle ||
             execution_state_ == ExecutionState::kHasResult) {
        start_execute_.wait(l);
      }
      if (execution_state_ == ExecutionState::kShuttingDown) {
        return;
      } else if (execution_state_ == ExecutionState::kReadyToExecute) {
        // op_outputs_ may have been std::moved
        op_outputs_ = std::vector<TensorHandlePtr>();
        Execute(context_, operation_name_, std::move(op_inputs_), attributes_,
                expected_max_outputs_, &op_outputs_, status_.get());
        execution_state_ = ExecutionState::kHasResult;
      }
    }
    finished_execute_.notify_one();
  }
}

void DeviceThread::StartExecute(TFE_Context* context,
                                const char* operation_name,
                                std::vector<TFE_TensorHandle*> inputs,
                                const TFE_OpAttrs* attributes,
                                int expected_max_outputs,
                                CancellationManager& cancellation_manager) {
  {
    tensorflow::mutex_lock l(execution_mutex_);
    while (execution_state_ != ExecutionState::kIdle) {
      // If there's already a pending execution, wait until Join finishes
      // before starting on the next operation.
      finished_join_.wait(l);
    }
    context_ = context;
    operation_name_ = operation_name;
    op_inputs_ = inputs;
    attributes_ = attributes;
    expected_max_outputs_ = expected_max_outputs;
    cancellation_manager_ = &cancellation_manager;
    execution_state_ = ExecutionState::kReadyToExecute;
  }
  start_execute_.notify_one();
}

std::vector<TensorHandlePtr> DeviceThread::Join(TF_Status* status) {
  std::vector<TensorHandlePtr> result;
  {
    tensorflow::mutex_lock l(execution_mutex_);
    while (execution_state_ != ExecutionState::kHasResult) {
      finished_execute_.wait(l);
    }
    if (TF_GetCode(status_.get()) != TF_OK) {
      TF_SetStatus(status, TF_GetCode(status_.get()),
                   TF_Message(status_.get()));
      // Reset the member `status_` so future op executions (after recovery
      // from the bad `status`) start with an OK status.
      TF_SetStatus(status_.get(), TF_OK, "");
    }
    cancellation_manager_ = nullptr;
    execution_state_ = ExecutionState::kIdle;
    result = std::move(op_outputs_);
  }
  finished_join_.notify_one();
  return result;
}

void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
                           std::vector<TFE_TensorHandle*> inputs,
                           const TFE_OpAttrs* attributes,
                           int expected_max_outputs,
                           std::vector<TensorHandlePtr>* outputs,
                           TF_Status* status) const {
  if (op_ == nullptr) {
    TFE_ContextSetExecutorForThread(context, executor_.get());
    op_.reset(TFE_NewOp(context, operation_name, status));
    if (TF_GetCode(status) != TF_OK) return;
    TFE_OpSetDevice(op_.get(), device_.c_str(), status);
    if (TF_GetCode(status) != TF_OK) return;
  } else {
    TFE_OpReset(op_.get(), operation_name, device_.c_str(), status);
    if (TF_GetCode(status) != TF_OK) return;
  }
  TFE_OpAddAttrs(op_.get(), attributes);
  for (int input_index = 0; input_index < inputs.size(); ++input_index) {
    TFE_OpAddInput(op_.get(), inputs[input_index], status);
    if (TF_GetCode(status) != TF_OK) return;
  }
  std::vector<TFE_TensorHandle*> unwrapped_results(expected_max_outputs);
  int real_num_outputs = expected_max_outputs;
  TFE_OpSetCancellationManager(op_.get(), wrap(cancellation_manager_), status);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_Execute(op_.get(), unwrapped_results.data(), &real_num_outputs, status);
  if (TF_GetCode(status) != TF_OK) {
    cancellation_manager_->StartCancel();
    return;
  }
  unwrapped_results.resize(real_num_outputs);
  outputs->reserve(real_num_outputs);
  for (TFE_TensorHandle* unwrapped_result : unwrapped_results) {
    outputs->emplace_back(unwrapped_result);
  }
}

ParallelDevice::ParallelDevice(const std::vector<std::string>& devices,
                               const bool is_async)
    : underlying_devices_(devices),
      default_cancellation_manager_(absl::make_unique<CancellationManager>()) {
  device_threads_.reserve(devices.size());
  for (int device_index = 0; device_index < devices.size(); ++device_index) {
    device_threads_.emplace_back(
        new DeviceThread(devices[device_index].c_str(), is_async));
  }
}
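
// A minimal usage sketch for ParallelDevice (illustrative only; `ctx`,
// `handle`, and `attributes` are hypothetical values provided by the caller):
//
//   ParallelDevice device({"/job:localhost/replica:0/task:0/device:CPU:0",
//                          "/job:localhost/replica:0/task:0/device:CPU:1"},
//                         /*is_async=*/false);
//   StatusPtr status(TF_NewStatus());
//   // Broadcast one handle to every underlying device.
//   std::unique_ptr<ParallelTensor> broadcast =
//       device.CopyToParallelDevice(ctx, handle, status.get());
//   // Run one op per underlying device and collect the per-device outputs.
//   absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> outputs =
//       device.Execute(ctx, {broadcast.get()}, "Identity", attributes,
//                      /*expected_max_outputs=*/1, status.get());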

// Necessary for a unique_ptr to a forward-declared type.
ParallelDevice::~ParallelDevice() = default;

std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
    TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
  std::vector<TensorHandlePtr> components;
  components.reserve(underlying_devices_.size());
  for (const std::string& underlying_device_name : underlying_devices_) {
    TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
        tensor, context, underlying_device_name.c_str(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    components.emplace_back(t);
  }
  return ParallelTensor::FromTensorHandles(*this, std::move(components),
                                           status);
}

std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
    TFE_Context* context, TF_Status* status) const {
  std::vector<int32_t> ids;
  ids.reserve(num_underlying_devices());
  for (int i = 0; i < num_underlying_devices(); ++i) {
    ids.push_back(i);
  }
  return ScalarsFromSequence<int32_t>(ids, context, status);
}

absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Execute(TFE_Context* context,
                        const std::vector<ParallelTensor*>& inputs,
                        const char* operation_name,
                        const TFE_OpAttrs* attributes, int expected_max_outputs,
                        TF_Status* status) const {
  std::vector<PartialTensorShape> expected_output_shapes(expected_max_outputs);
  StartExecute(context, inputs, operation_name, attributes,
               expected_max_outputs, *default_cancellation_manager_);
  auto result = Join(expected_output_shapes, status);
  if (TF_GetCode(status) != TF_OK) {
    std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> await_status(
        TF_NewStatus(), TF_DeleteStatus);
    // Wait until all pending nodes have completed since they may have a
    // reference to default_cancellation_manager_. We ignore the status return
    // since we already have a bad status to propagate.
    TFE_ContextAsyncWait(context, await_status.get());
    // Reset the cancellation manager on a bad status. Otherwise we'll cancel
    // all future operations.
    default_cancellation_manager_ = absl::make_unique<CancellationManager>();
  }
  return result;
}

void ParallelDevice::StartExecute(
    TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
    const char* operation_name, const TFE_OpAttrs* attributes,
    int expected_max_outputs, CancellationManager& cancellation_manager) const {
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    DeviceThread* device_thread = device_threads_[device_index].get();
    std::vector<TFE_TensorHandle*> device_inputs;
    device_inputs.reserve(inputs.size());
    for (int input_index = 0; input_index < inputs.size(); ++input_index) {
      // Parallel tensors are divided between operations by device.
      device_inputs.push_back(inputs[input_index]->tensor(device_index));
    }
    device_thread->StartExecute(context, operation_name,
                                std::move(device_inputs), attributes,
                                expected_max_outputs, cancellation_manager);
  }
}
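
// A minimal sketch of the explicit StartExecute/Join pair, which lets the
// caller do other work between launching the per-device ops and collecting
// their results (illustrative only; `device`, `ctx`, `inputs`, and
// `attributes` are hypothetical values provided by the caller, and
// `cancellation_manager` must outlive the Join):
//
//   CancellationManager cancellation_manager;
//   device.StartExecute(ctx, inputs, "Identity", attributes,
//                       /*expected_max_outputs=*/1, cancellation_manager);
//   // ... other work may run here while the device threads execute the op ...
//   std::vector<PartialTensorShape> expected_output_shapes(1);
//   StatusPtr status(TF_NewStatus());
//   auto outputs = device.Join(expected_output_shapes, status.get());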

absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Join(
    const std::vector<PartialTensorShape>& expected_output_shapes,
    TF_Status* status) const {
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
  // Compute per-device per-output tensors
  std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
  per_device_output_tensors.reserve(underlying_devices_.size());
  int first_op_output_count = 0;
  StatusPtr first_bad_status(nullptr);
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    DeviceThread* device_thread = device_threads_[device_index].get();
    per_device_output_tensors.push_back(device_thread->Join(status));
    // We will run every Join even if there are bad statuses in case the user
    // wants to recover and continue running ops on the parallel device (which
    // would otherwise deadlock).
    if (TF_GetCode(status) != TF_OK &&
        (first_bad_status == nullptr
         // Prefer propagating non-cancellation related statuses to avoid
         // shadowing the original failure.
         || TF_GetCode(first_bad_status.get()) == TF_CANCELLED)) {
      first_bad_status.reset(TF_NewStatus());
      TF_SetStatus(first_bad_status.get(), TF_GetCode(status),
                   TF_Message(status));
    }

    if (device_index == 0) {
      first_op_output_count = per_device_output_tensors.rbegin()->size();
    } else {
      if (first_bad_status == nullptr &&
          per_device_output_tensors.rbegin()->size() !=
              first_op_output_count) {
        first_bad_status.reset(TF_NewStatus());
        TF_SetStatus(first_bad_status.get(), TF_INTERNAL,
                     "Parallel ops produced different numbers of tensors.");
      }
    }
  }
  if (first_bad_status != nullptr) {
    TF_SetStatus(status, TF_GetCode(first_bad_status.get()),
                 TF_Message(first_bad_status.get()));
    return result;
  }
  // For each output of the original operation, pack the per-device
  // TensorHandles we've computed into a single parallel TensorHandle.
  std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
  per_device_outputs.reserve(first_op_output_count);
  for (int i = 0; i < first_op_output_count; ++i) {
    std::vector<TensorHandlePtr> components;
    components.reserve(underlying_devices_.size());
    for (int j = 0; j < underlying_devices_.size(); ++j) {
      components.push_back(std::move(per_device_output_tensors[j][i]));
    }
    if (expected_output_shapes[i].IsFullyDefined()) {
      per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
          *this, std::move(components),
          absl::Span<const int64>(expected_output_shapes[i].dim_sizes()),
          status));
    } else {
      per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
          *this, std::move(components), status));
    }
    if (TF_GetCode(status) != TF_OK) return result;
  }
  result.emplace(std::move(per_device_outputs));
  return result;
}

std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
    const ParallelDevice& parallel_device,
    std::vector<TensorHandlePtr> components, absl::Span<const int64> shape,
    TF_Status* status) {
  TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
  // Verify that the combined TensorHandle's dtype matches all of the component
  // dtypes. The caller-provided `shape` is taken as given and not checked
  // against the component shapes.
  for (TensorHandlePtr& component : components) {
    if (TFE_TensorHandleDataType(component.get()) != dtype) {
      TF_SetStatus(status, TF_INTERNAL,
                   "Components of a ParallelTensor must all have "
                   "the same dtype");
      return nullptr;
    }
  }
  return std::unique_ptr<ParallelTensor>(
      new ParallelTensor(parallel_device, std::move(components), shape, dtype));
}
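
// A minimal sketch of building a ParallelTensor directly from per-device
// handles when the shape is known up front (illustrative only; `device`,
// `handle_0`, and `handle_1` are hypothetical values provided by the caller,
// and ownership of the raw handles transfers to the TensorHandlePtrs):
//
//   std::vector<TensorHandlePtr> components;
//   components.emplace_back(handle_0);
//   components.emplace_back(handle_1);
//   StatusPtr status(TF_NewStatus());
//   std::unique_ptr<ParallelTensor> tensor =
//       ParallelTensor::FromTensorHandles(device, std::move(components),
//                                         /*shape=*/{2, 3}, status.get());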

std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
    const ParallelDevice& parallel_device,
    std::vector<TensorHandlePtr> components, TF_Status* status) {
  TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
  // Verify that the combined TensorHandle's dtype matches all of the component
  // dtypes.
  for (TensorHandlePtr& component : components) {
    if (TFE_TensorHandleDataType(component.get()) != dtype) {
      TF_SetStatus(status, TF_INTERNAL,
                   "Components of a ParallelTensor must all have "
                   "the same dtype");
      return nullptr;
    }
  }
  return std::unique_ptr<ParallelTensor>(
      new ParallelTensor(parallel_device, std::move(components), dtype));
}

Status ParallelTensor::Shape(const std::vector<int64_t>** shape) const {
  if (!shape_.has_value()) {
    TF_Status status;
    PartialTensorShape first_shape;
    TF_RETURN_IF_ERROR(unwrap(tensors_[0].get())->Shape(&first_shape));

    // Verify that the TensorHandle's shape matches all of the component
    // shapes.
    for (const TensorHandlePtr& component : tensors_) {
      PartialTensorShape component_shape;
      TF_RETURN_IF_ERROR(unwrap(component.get())->Shape(&component_shape));
      if (!first_shape.IsIdenticalTo(component_shape)) {
        return errors::Unimplemented(absl::StrCat(
            "Computing the shape of a ParallelTensor when the components do "
            "not all have the same shapes is not supported. One tensor had "
            "shape ",
            first_shape.DebugString(), " and another had shape ",
            component_shape.DebugString()));
      }
    }
    auto dim_sizes = first_shape.dim_sizes();
    shape_ = std::vector<int64_t>(dim_sizes.begin(), dim_sizes.end());
  }
  *shape = &*shape_;
  return Status::OK();
}

}  // namespace parallel_device
}  // namespace tensorflow