/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/base_collective_executor.h"

#include <algorithm>
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"

#define VALUE_IN_DEBUG_STRING false

namespace tensorflow {

namespace {
bool IsCancelled(CancellationManager* cancel_mgr) {
  return cancel_mgr != nullptr &&
         (cancel_mgr->IsCancelled() || cancel_mgr->IsCancelling());
}
}  // namespace

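// AlignedChunkElts rounds the per-chunk element count up so that chunk
// boundaries fall on EIGEN_MAX_ALIGN_BYTES multiples relative to the tensor
// base. As a worked example (with a hypothetical EIGEN_MAX_ALIGN_BYTES of
// 64): for elt_bytes=4 (float), total_elts=10 and num_chunks=3,
// base_chunk_elts starts at ceil(10/3)=4, i.e. 16 bytes per chunk; that is
// padded by 48 bytes (12 elements) to reach the 64-byte boundary, so each
// aligned chunk holds 16 elements.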
/*static*/
int64 CollectiveAdapter::AlignedChunkElts(int64 elt_bytes, int64 total_elts,
                                          int64 num_chunks) {
  DCHECK_GT(num_chunks, 0);
  int64 base_chunk_elts = (total_elts + (num_chunks - 1)) / num_chunks;
  if (EIGEN_MAX_ALIGN_BYTES == 0) return base_chunk_elts;
  if (EIGEN_MAX_ALIGN_BYTES <= elt_bytes) {
    // Tolerate weird small values of EIGEN_MAX_ALIGN_BYTES.
    DCHECK_EQ(0, elt_bytes % EIGEN_MAX_ALIGN_BYTES);
    return base_chunk_elts;
  }
  // Here elt_bytes < EIGEN_MAX_ALIGN_BYTES, and EIGEN_MAX_ALIGN_BYTES must be
  // a common multiple of the sizes of the supported element types.
  DCHECK_EQ(0, EIGEN_MAX_ALIGN_BYTES % elt_bytes)
      << "total_elts=" << total_elts << " num_chunks=" << num_chunks
      << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES
      << " elt_bytes=" << elt_bytes;
  // Round bytes per chunk up to the next multiple of EIGEN_MAX_ALIGN_BYTES.
  int64 chunk_bytes = base_chunk_elts * elt_bytes;
  int64 diff =
      (chunk_bytes < EIGEN_MAX_ALIGN_BYTES)
          ? (EIGEN_MAX_ALIGN_BYTES - chunk_bytes)
          : (EIGEN_MAX_ALIGN_BYTES - (chunk_bytes % EIGEN_MAX_ALIGN_BYTES));
  DCHECK_EQ(0, diff % elt_bytes);
  base_chunk_elts += (diff / elt_bytes);
  DCHECK_EQ(0, ((base_chunk_elts * elt_bytes) % EIGEN_MAX_ALIGN_BYTES))
      << "total_elts=" << total_elts << " num_chunks=" << num_chunks
      << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES
      << " base_chunk_elts=" << base_chunk_elts << " elt_bytes=" << elt_bytes;
  return base_chunk_elts;
}

namespace {
template <typename T>
class CollectiveAdapterImpl : public CollectiveAdapter {
 public:
  // Takes ownership of output and prepares to properly alias its chunks.
  // Ownership is taken because the shape may temporarily change.
  CollectiveAdapterImpl(Tensor* output, int64 num_chunks, Allocator* allocator,
                        bool align_chunks)
      : output_(std::move(*output)),
        dt_(output_.dtype()),
        old_shape_(output_.shape()),
        num_chunks_(num_chunks),
        allocator_(allocator),
        total_elts_(output_.NumElements()),
        chunk_elts_(align_chunks
                        ? AlignedChunkElts(sizeof(T), total_elts_, num_chunks_)
                        : total_elts_ / num_chunks_),
        data_start_(reinterpret_cast<T*>(DMAHelper::base(&output_))),
        data_end_(data_start_ + total_elts_) {
    if (!align_chunks) {
      DCHECK_EQ(total_elts_, num_chunks_ * chunk_elts_);
    }
    DCHECK_GT(chunk_elts_, 0);
    Flatten();
  }

  ~CollectiveAdapterImpl() override {}

  const Tensor& Value() const override { return output_; }

  // If necessary, flatten output.
  void Flatten() {
    if (old_shape_.dims() != 1) {
      TensorShape new_shape = TensorShape({old_shape_.num_elements()});
      DMAHelper::UnsafeSetShape(&output_, new_shape);
    }
  }

  void ConsumeFinalValue(Tensor* output) override {
    if (old_shape_ != output_.shape()) {
      DMAHelper::UnsafeSetShape(&output_, old_shape_);
    }
    *output = std::move(output_);
  }

  // Number of T elements in a particular chunk.
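  // Illustration of the chunking below (values are hypothetical): with
  // chunk_elts_ == 4 and total_elts_ == 10, ChunkElts returns 4, 4 and 2 for
  // chunks 0, 1 and 2, and 0 for any chunk whose start would fall past
  // data_end_.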
  inline int64 ChunkElts(int i) const {
    DCHECK_LT(i, num_chunks_);
    const T* chunk_start = std::min(data_end_, data_start_ + i * chunk_elts_);
    const T* chunk_end = std::min(data_end_, chunk_start + chunk_elts_);
    return chunk_end - chunk_start;
  }

  int64 ChunkBytes(int i) const override { return sizeof(T) * ChunkElts(i); }

  // Returns a new Tensor that aliases the required chunk.
  Tensor ChunkAlias(int i) override {
    int64 start = chunk_elts_ * i;
    int64 num_elts = ChunkElts(i);
    // If this chunk is empty, the preceding chunk may also have been short,
    // so the start offset could already lie past the end of the tensor. Take
    // an empty slice from the front of the tensor instead to avoid failing
    // Slice()'s offset validity check.
    return (num_elts > 0) ? output_.Slice(start, start + num_elts)
                          : output_.Slice(0, 0);
  }

  Tensor TempChunk(int i) const override {
    AllocationAttributes empty;
    ScopedMemoryDebugAnnotation op_annotation(
        "CollectiveAdapterImpl::TempChunk");
    return Tensor(allocator_, dt_, {ChunkElts(i)}, empty);
  }

  string DebugString() const override {
    return strings::StrCat(
        "base addr ", reinterpret_cast<int64>(DMAHelper::base(&output_)),
        " num_chunks ", num_chunks_, " total_elts ", total_elts_,
        " chunk_elts ", chunk_elts_, " value ",
        VALUE_IN_DEBUG_STRING ? output_.SummarizeValue(1024) : "<hidden>");
  }

  string TBounds(const Tensor& t) const override {
    int64 base_addr = reinterpret_cast<int64>(DMAHelper::base(&t));
    return strings::StrCat("(", base_addr, ", ", (base_addr + t.TotalBytes()),
                           ")");
  }

  Tensor Scalar(int v) const override { return Tensor(static_cast<T>(v)); }

  Tensor Scalar(Allocator* a,
                const AllocationAttributes& attr) const override {
    Tensor t(a, dt_, TensorShape({}), attr);
    return t;
  }

  Tensor output_;
  const DataType dt_;
  const TensorShape old_shape_;
  const int64 num_chunks_;
  Allocator* allocator_;
  const int64 total_elts_;
  const int64 chunk_elts_;
  const T* data_start_;
  const T* data_end_;
};

}  // namespace

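// MakeCollectiveAdapter wraps a kernel's output tensor so that a collective
// implementation can address it as a sequence of aligned chunks. A rough
// usage sketch (illustrative only; the variable names here are hypothetical,
// not a specific call site):
//
//   std::unique_ptr<CollectiveAdapter> adapter(MakeCollectiveAdapter(
//       output_tensor, /*num_chunks=*/group_size, device_allocator,
//       /*align_chunks=*/true));
//   Tensor chunk = adapter->ChunkAlias(my_rank);  // aliases chunk `my_rank`
//   ...
//   adapter->ConsumeFinalValue(output_tensor);    // restores original shape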
CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks,
                                         Allocator* allocator,
                                         bool align_chunks) {
  switch (output->dtype()) {
    case DT_HALF:
      return new CollectiveAdapterImpl<Eigen::half>(output, num_chunks,
                                                    allocator, align_chunks);
      break;
    case DT_FLOAT:
      return new CollectiveAdapterImpl<float>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    case DT_DOUBLE:
      return new CollectiveAdapterImpl<double>(output, num_chunks, allocator,
                                               align_chunks);
      break;
    case DT_INT32:
      return new CollectiveAdapterImpl<int32>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    case DT_INT64:
      return new CollectiveAdapterImpl<int64>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    default:
      LOG(FATAL) << "Unsupported type " << DataTypeString(output->dtype())
                 << " to MakeCollectiveAdapter";
      return nullptr;
  }
}

BaseCollectiveExecutor::~BaseCollectiveExecutor() {}

void BaseCollectiveExecutor::StartAbort(const Status& s) {
  Status status;
  {
    mutex_lock l(status_mu_);
    if (!status_.ok()) {
      VLOG(2) << "BaseCollectiveExecutor already aborted, ignoring StartAbort: "
              << s;
      return;
    }
    status_ = StatusGroup::MakeDerived(Status(
        s.code(),
        absl::StrCat(
            "Collective ops is aborted by: ", s.error_message(),
            "\nThe error could be from a previous operation. Restart your "
            "program to reset.")));
    status = status_;
  }
  LOG(ERROR) << "BaseCollectiveExecutor::StartAbort " << s;
  cem_->GetParamResolver()->StartAbort(status);
  remote_access_->StartAbort(status);
  if (cem_->GetNcclCommunicator() != nullptr) {
    cem_->GetNcclCommunicator()->StartAbort(status);
  }
}

Status BaseCollectiveExecutor::GetStatus(const Status& s) {
  if (s.ok()) return s;
  mutex_lock l(status_mu_);
  // If the collective executor is already aborted, use the aborted status,
  // which more likely reflects the actual error than an artifact of the
  // abort.
  if (!status_.ok()) {
    VLOG(2) << "Overriding status with collective ops executor status. "
               "Original status: "
            << s;
    return status_;
  }
  return s;
}

void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
                                          const CollectiveParams* col_params,
                                          const string& exec_key,
                                          StatusCallback done) {
  // See CompleteParamsAsync() for how done() and the timeout callback
  // interact.
  const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
  auto done_safe = [this, done, ctx, is_callback_called](const Status& s) {
    bool called = is_callback_called->exchange(true);
    if (!called) {
      if (!s.ok() && !IsCancelled(ctx->cancellation_manager())) {
        // This is a collective error. Abort CollectiveExecutor so that this
        // error can propagate to other workers.
        StartAbort(s);
      }
      done(GetStatus(s));
    }
  };
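  // Note: the atomic exchange above (and in the timeout closure below)
  // ensures that done() is invoked by at most one of the two paths: whichever
  // of done_safe or the timeout callback flips is_callback_called first runs
  // done(), and the other becomes a no-op.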
  auto timeout_microseconds = static_cast<int64>(
      col_params->instance.impl_details.timeout_seconds * 1'000'000);
  if (timeout_microseconds > 0) {
    // TODO(xldrx): Share the timeout watchdog thread among collectives.
    SchedNonBlockingClosureAfter(
        timeout_microseconds, [this, is_callback_called, done] {
          bool called = is_callback_called->exchange(true);
          if (!called) {
            Status status(error::DEADLINE_EXCEEDED,
                          "Collective has timed out during execution.");
            StartAbort(status);
            done(status);
          }
        });
  }

  Tensor* output = ctx->mutable_output(0);
  const Tensor* input = (col_params->instance.type == REDUCTION_COLLECTIVE ||
                         col_params->instance.type == GATHER_COLLECTIVE ||
                         col_params->instance.type == PERMUTE_COLLECTIVE ||
                         (col_params->instance.type == BROADCAST_COLLECTIVE &&
                          col_params->is_source))
                            ? &ctx->input(0)
                            : nullptr;
  CollectiveImplementationInterface* col_impl = nullptr;
  Status status = CreateCollective(*col_params, &col_impl);
  if (!status.ok()) {
    done_safe(status);
    DCHECK_EQ(nullptr, col_impl);
    return;
  }
  core::ScopedUnref unref(col_impl);
  auto col_ctx = std::make_shared<CollectiveContext>(
      this, cem_->GetNcclCommunicator(), dev_mgr_, ctx, CtxParams(ctx),
      col_params, exec_key, step_id_, input, output);
  status = col_impl->InitializeCollectiveContext(col_ctx);
  if (!status.ok()) {
    done_safe(status);
    return;
  }
  // Run on an unbounded work queue that can handle blocking work so as to not
  // starve executor threads.
  // Take an extra reference on col_impl so it outlives this scope's
  // ScopedUnref until the closure below has run.
  col_impl->Ref();
  profiler::TraceMeProducer producer("BaseCollectiveExecutor::ExecuteAsync");
  RunClosure([col_impl, col_ctx, done_safe, ctx,
              context_id = producer.GetContextId()]() {
    core::ScopedUnref unref(col_impl);
    profiler::TraceMeConsumer consumer(
        [ctx] {
          string op = profiler::TraceMeOp(ctx->op_kernel().name_view(),
                                          ctx->op_kernel().type_string_view());
          return profiler::TraceMeEncode(std::move(op),
                                         {{"id", ctx->step_id()}});
        },
        context_id);
    // Take another reference for the duration of the asynchronous Run();
    // it is released by the ScopedUnref in the done callback below.
    col_impl->Ref();
    col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) {
      core::ScopedUnref unref(col_impl);
      done_safe(s);
    });
  });
}

void BaseCollectiveExecutor::CompleteParamsAsync(
    const DeviceAttributes& device, CollectiveParams* cp,
    CancellationManager* cancel_mgr, StatusCallback done) {
  cp->group.gpu_ring_order = *gpu_ring_order_;
  // We need to make sure that when the timeout callback executes,
  // CollectiveExecutor and CollectiveExecutorMgr are both alive. After done()
  // is called, CollectiveExecutorMgr may be destructed and we don't have a
  // way to keep it alive without making the ownership more complicated.
  // Therefore, if the timeout callback executes, done_safe becomes a no-op
  // and the timeout callback is responsible for invoking done() at the end.
  const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
  auto trace_id =
      profiler::TraceMe::ActivityStart("CollectiveExecutor::CompleteParams");
  auto done_safe = [this, is_callback_called, cancel_mgr, trace_id,
                    done](const Status& s) {
    profiler::TraceMe::ActivityEnd(trace_id);
    bool called = is_callback_called->exchange(true);
    if (!called) {
      if (!s.ok() && !IsCancelled(cancel_mgr)) {
        // This is a collective error. Abort CollectiveExecutor so that this
        // error can propagate to other workers.
        StartAbort(s);
      }
      done(GetStatus(s));
    }
  };
  auto timeout_microseconds =
      static_cast<int64>(cp->instance.impl_details.timeout_seconds * 1'000'000);
  if (timeout_microseconds > 0) {
    // TODO(xldrx): Share the timeout watchdog thread among collectives.
    SchedNonBlockingClosureAfter(
        timeout_microseconds, [this, is_callback_called, done]() {
          bool called = is_callback_called->exchange(true);
          if (!called) {
            Status status(
                error::DEADLINE_EXCEEDED,
                "Collective has timed out waiting for other workers.");
            StartAbort(status);
            done(status);
          }
        });
  }
  cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr,
                                                done_safe);
}

Status BaseCollectiveExecutor::CreateCollective(
    const CollectiveParams& col_params,
    CollectiveImplementationInterface** col_impl) {
  VLOG(2) << "CreateCollective type "
          << DataTypeString(col_params.instance.data_type) << " name "
          << col_params.instance.impl_details.collective_name;
  *col_impl = nullptr;
  switch (col_params.instance.data_type) {
    case DT_BOOL:
      if (col_params.instance.type == BROADCAST_COLLECTIVE) {
        return CollectiveRegistry::Lookup(
            col_params.instance.impl_details.collective_name, col_impl);
      } else {
        return errors::Internal(
            "No collective other than broadcast supports DT_BOOL");
      }
    case DT_INT32:
      if (col_params.group.device_type == DEVICE_GPU &&
          col_params.instance.type == REDUCTION_COLLECTIVE) {
        // TODO(b/139421603): enable int32 all-reduce on GPU.
        return errors::Internal(
            "Collective all-reduce does not support datatype DT_INT32 on "
            "DEVICE_GPU");
      } else {
        return CollectiveRegistry::Lookup(
            col_params.instance.impl_details.collective_name, col_impl);
      }
    case DT_HALF:
    case DT_FLOAT:
    case DT_DOUBLE:
    case DT_INT64: {
      return CollectiveRegistry::Lookup(
          col_params.instance.impl_details.collective_name, col_impl);
    }
    default:
      return errors::Internal(
          "CollectiveImplementation does not support datatype ",
          DataTypeString(col_params.instance.data_type));
  }
}

bool BaseCollectiveExecutor::CheckDependencies(
    const CollectiveParams& col_params) {
  for (int32 instance : col_params.instance.impl_details.dependencies) {
    auto find_iter = launched_.find(instance);
    if (find_iter == launched_.end() || find_iter->second != 0) {
      VLOG(1) << "Collective " << col_params.ToString()
              << " blocked by instance " << instance;
      return false;
    }
  }
  return true;
}

void BaseCollectiveExecutor::WaitForDependencies(
    const CollectiveParams& col_params) {
  mutex_lock l(launch_mu_);
  while (!CheckDependencies(col_params)) {
    launch_cv_.wait(l);
  }
  VLOG(1) << "Unblocking collective " << col_params.ToString();
}

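// UnblockDependencies counts down, per instance_key, the number of local
// devices in this task that still have to launch that instance. As an
// illustration (with hypothetical numbers): on a task with 4 local devices,
// the first call seeds launched_[key] = 4 and the fourth call drives it to 0,
// at which point CheckDependencies() treats the instance as satisfied and any
// collective waiting on it in WaitForDependencies() is woken up.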
void BaseCollectiveExecutor::UnblockDependencies(
    const CollectiveParams& col_params) {
  mutex_lock l(launch_mu_);
  if (launched_.find(col_params.instance.instance_key) == launched_.end()) {
    const string& task_name =
        col_params.group.task_names[col_params.default_rank];
    const int32 num_devices =
        col_params.group.num_devices_per_task.at(task_name);
    launched_[col_params.instance.instance_key] = num_devices;
  }
  if (--launched_[col_params.instance.instance_key] == 0) {
    VLOG(1) << "Unblocking dependencies for collective instance "
            << col_params.instance.instance_key;
    launch_cv_.notify_all();
  }
}

}  // namespace tensorflow