/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/base_collective_executor.h"

#include <algorithm>
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"

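// When VALUE_IN_DEBUG_STRING is true, CollectiveAdapterImpl::DebugString()
// below also includes a bounded summary of the tensor's values; it is off by
// default to keep debug logs small.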
#define VALUE_IN_DEBUG_STRING false

namespace tensorflow {

namespace {
bool IsCancelled(CancellationManager* cancel_mgr) {
  return cancel_mgr != nullptr &&
         (cancel_mgr->IsCancelled() || cancel_mgr->IsCancelling());
}
}  // namespace

/*static*/
int64 CollectiveAdapter::AlignedChunkElts(int64 elt_bytes, int64 total_elts,
                                          int64 num_chunks) {
  DCHECK_GT(num_chunks, 0);
  int64 base_chunk_elts = (total_elts + (num_chunks - 1)) / num_chunks;
  if (EIGEN_MAX_ALIGN_BYTES == 0) return base_chunk_elts;
  if (EIGEN_MAX_ALIGN_BYTES <= elt_bytes) {
    // Tolerate weird small values of EIGEN_MAX_ALIGN_BYTES.
    DCHECK_EQ(0, elt_bytes % EIGEN_MAX_ALIGN_BYTES);
    return base_chunk_elts;
  }
  // elt_bytes < EIGEN_MAX_ALIGN_BYTES, which must be a common multiple of the
  // various atomic data types.
  DCHECK_EQ(0, EIGEN_MAX_ALIGN_BYTES % elt_bytes)
      << "total_elts=" << total_elts << " num_chunks=" << num_chunks
      << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES
      << " elt_bytes=" << elt_bytes;
  // Round bytes per chunk up to the next multiple of EIGEN_MAX_ALIGN_BYTES.
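  // For example (assuming, for illustration only, EIGEN_MAX_ALIGN_BYTES == 64
  // and 4-byte elements): base_chunk_elts == 4 gives chunk_bytes == 16, so
  // diff == 48 and base_chunk_elts is bumped to 16, i.e. 64 bytes per chunk.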
  int64 chunk_bytes = base_chunk_elts * elt_bytes;
  int64 diff =
      (chunk_bytes < EIGEN_MAX_ALIGN_BYTES)
          ? (EIGEN_MAX_ALIGN_BYTES - chunk_bytes)
          : (EIGEN_MAX_ALIGN_BYTES - (chunk_bytes % EIGEN_MAX_ALIGN_BYTES));
  DCHECK_EQ(0, diff % elt_bytes);
  base_chunk_elts += (diff / elt_bytes);
  DCHECK_EQ(0, ((base_chunk_elts * elt_bytes) % EIGEN_MAX_ALIGN_BYTES))
      << "total_elts=" << total_elts << " num_chunks=" << num_chunks
      << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES
      << " base_chunk_elts=" << base_chunk_elts << " elt_bytes=" << elt_bytes;
  return base_chunk_elts;
}

namespace {
template <typename T>
class CollectiveAdapterImpl : public CollectiveAdapter {
 public:
  // Takes ownership of output and prepares to properly alias its chunks.
  // Ownership is taken because the shape may temporarily change.
  CollectiveAdapterImpl(Tensor* output, int64 num_chunks, Allocator* allocator,
                        bool align_chunks)
      : output_(std::move(*output)),
        dt_(output_.dtype()),
        old_shape_(output_.shape()),
        num_chunks_(num_chunks),
        allocator_(allocator),
        total_elts_(output_.NumElements()),
        chunk_elts_(align_chunks
                        ? AlignedChunkElts(sizeof(T), total_elts_, num_chunks_)
                        : total_elts_ / num_chunks_),
        data_start_(reinterpret_cast<T*>(DMAHelper::base(&output_))),
        data_end_(data_start_ + total_elts_) {
    if (!align_chunks) {
      DCHECK_EQ(total_elts_, num_chunks_ * chunk_elts_);
    }
    DCHECK_GT(chunk_elts_, 0);
    Flatten();
  }

  ~CollectiveAdapterImpl() override {}

  const Tensor& Value() const override { return output_; }

  // If necessary, flatten output.
  void Flatten() {
    if (old_shape_.dims() != 1) {
      TensorShape new_shape = TensorShape({old_shape_.num_elements()});
      DMAHelper::UnsafeSetShape(&output_, new_shape);
    }
  }

  void ConsumeFinalValue(Tensor* output) override {
    if (old_shape_ != output_.shape()) {
      DMAHelper::UnsafeSetShape(&output_, old_shape_);
    }
    *output = std::move(output_);
  }

  // Number of T elements in a particular chunk.
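  // For example, with total_elts_ == 7 and chunk_elts_ == 4: chunk 0 has 4
  // elements, chunk 1 has the remaining 3, and any further chunk is empty.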
  inline int64 ChunkElts(int i) const {
    DCHECK_LT(i, num_chunks_);
    const T* chunk_start = std::min(data_end_, data_start_ + i * chunk_elts_);
    const T* chunk_end = std::min(data_end_, chunk_start + chunk_elts_);
    return chunk_end - chunk_start;
  }

  int64 ChunkBytes(int i) const override { return sizeof(T) * ChunkElts(i); }

  // Returns a new Tensor that aliases the required chunk.
  Tensor ChunkAlias(int i) override {
    int64 start = chunk_elts_ * i;
    int64 num_elts = ChunkElts(i);
    // If this chunk is empty the prior chunk might also be short, so `start`
    // may lie past the end of the tensor; always take an empty slice from the
    // front of the tensor instead, to avoid failing an illegal-offset check.
    return (num_elts > 0) ? output_.Slice(start, start + num_elts)
                          : output_.Slice(0, 0);
  }

  Tensor TempChunk(int i) const override {
    AllocationAttributes empty;
    ScopedMemoryDebugAnnotation op_annotation(
        "CollectiveAdapterImpl::TempChunk");
    return Tensor(allocator_, dt_, {ChunkElts(i)}, empty);
  }

  string DebugString() const override {
    return strings::StrCat(
        "base addr ", reinterpret_cast<int64>(DMAHelper::base(&output_)),
        " num_chunks ", num_chunks_, " total_elts ", total_elts_,
        " chunk_elts ", chunk_elts_, " value ",
        VALUE_IN_DEBUG_STRING ? output_.SummarizeValue(1024) : "<hidden>");
  }

  string TBounds(const Tensor& t) const override {
    int64 base_addr = reinterpret_cast<int64>(DMAHelper::base(&t));
    return strings::StrCat("(", base_addr, ", ", (base_addr + t.TotalBytes()),
                           ")");
  }

  Tensor Scalar(int v) const override { return Tensor(static_cast<T>(v)); }

  Tensor Scalar(Allocator* a,
                const AllocationAttributes& attr) const override {
    Tensor t(a, dt_, TensorShape({}), attr);
    return t;
  }

  Tensor output_;
  const DataType dt_;
  const TensorShape old_shape_;
  const int64 num_chunks_;
  Allocator* allocator_;
  const int64 total_elts_;
  const int64 chunk_elts_;
  const T* data_start_;
  const T* data_end_;
};

}  // namespace

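// A minimal usage sketch (illustrative only; `allocator` here stands for any
// valid Allocator* and is not defined in this file):
//
//   Tensor output(DT_FLOAT, TensorShape({1024}));
//   std::unique_ptr<CollectiveAdapter> ca(MakeCollectiveAdapter(
//       &output, /*num_chunks=*/4, allocator, /*align_chunks=*/true));
//   Tensor chunk0 = ca->ChunkAlias(0);   // aliases the first chunk of output
//   ca->ConsumeFinalValue(&output);      // restores output with its old shape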
CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks,
                                         Allocator* allocator,
                                         bool align_chunks) {
  switch (output->dtype()) {
    case DT_HALF:
      return new CollectiveAdapterImpl<Eigen::half>(output, num_chunks,
                                                    allocator, align_chunks);
      break;
    case DT_FLOAT:
      return new CollectiveAdapterImpl<float>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    case DT_DOUBLE:
      return new CollectiveAdapterImpl<double>(output, num_chunks, allocator,
                                               align_chunks);
      break;
    case DT_INT32:
      return new CollectiveAdapterImpl<int32>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    case DT_INT64:
      return new CollectiveAdapterImpl<int64>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    default:
      LOG(FATAL) << "Unsupported type " << DataTypeString(output->dtype())
                 << " to MakeCollectiveAdapter";
      return nullptr;
  }
}

BaseCollectiveExecutor::~BaseCollectiveExecutor() {}

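// Aborts pending and subsequent collectives on this executor by propagating
// the abort status to the param resolver, the remote access object and (if
// present) the NCCL communicator. Only the first abort status is kept; later
// calls are ignored (see the status_ check below).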
void BaseCollectiveExecutor::StartAbort(const Status& s) {
  Status status;
  {
    mutex_lock l(status_mu_);
    if (!status_.ok()) {
      VLOG(2) << "BaseCollectiveExecutor already aborted, ignoring StartAbort: "
              << s;
      return;
    }
    status_ = StatusGroup::MakeDerived(Status(
        s.code(),
        absl::StrCat(
            "Collective ops are aborted by: ", s.error_message(),
            "\nThe error could be from a previous operation. Restart your "
            "program to reset.")));
    status = status_;
  }
  LOG(ERROR) << "BaseCollectiveExecutor::StartAbort " << s;
  cem_->GetParamResolver()->StartAbort(status);
  remote_access_->StartAbort(status);
  if (cem_->GetNcclCommunicator() != nullptr) {
    cem_->GetNcclCommunicator()->StartAbort(status);
  }
}

Status BaseCollectiveExecutor::GetStatus(const Status& s) {
  if (s.ok()) return s;
  mutex_lock l(status_mu_);
  // If the collective executor is already aborted, use the aborted status,
  // which is more likely the actual error than an artifact of the abort.
  if (!status_.ok()) {
    VLOG(2) << "Overriding status with collective ops executor status. "
               "Original status: "
            << s;
    return status_;
  }
  return s;
}

void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
                                          const CollectiveParams* col_params,
                                          const string& exec_key,
                                          StatusCallback done) {
  // See CompleteParamsAsync() for how done() and the timeout callback
  // interact.
  const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
  auto done_safe = [this, done, ctx, is_callback_called](const Status& s) {
    bool called = is_callback_called->exchange(true);
    if (!called) {
      if (!s.ok() && !IsCancelled(ctx->cancellation_manager())) {
        // This is a collective error. Abort CollectiveExecutor so that this
        // error can propagate to other workers.
        StartAbort(s);
      }
      done(GetStatus(s));
    }
  };
  auto timeout_microseconds = static_cast<int64>(
      col_params->instance.impl_details.timeout_seconds * 1'000'000);
  if (timeout_microseconds > 0) {
    // TODO(xldrx): Share the timeout watchdog thread among collectives.
    SchedNonBlockingClosureAfter(
        timeout_microseconds, [this, is_callback_called, done] {
          bool called = is_callback_called->exchange(true);
          if (!called) {
            Status status(error::DEADLINE_EXCEEDED,
                          "Collective has timed out during execution.");
            StartAbort(status);
            done(status);
          }
        });
  }

  Tensor* output = ctx->mutable_output(0);
  const Tensor* input = (col_params->instance.type == REDUCTION_COLLECTIVE ||
                         col_params->instance.type == GATHER_COLLECTIVE ||
                         col_params->instance.type == PERMUTE_COLLECTIVE ||
                         (col_params->instance.type == BROADCAST_COLLECTIVE &&
                          col_params->is_source))
                            ? &ctx->input(0)
                            : nullptr;
  CollectiveImplementationInterface* col_impl = nullptr;
  Status status = CreateCollective(*col_params, &col_impl);
  if (!status.ok()) {
    done_safe(status);
    DCHECK_EQ(nullptr, col_impl);
    return;
  }
  core::ScopedUnref unref(col_impl);
  auto col_ctx = std::make_shared<CollectiveContext>(
      this, cem_->GetNcclCommunicator(), dev_mgr_, ctx, CtxParams(ctx),
      col_params, exec_key, step_id_, input, output);
  status = col_impl->InitializeCollectiveContext(col_ctx);
  if (!status.ok()) {
    done_safe(status);
    return;
  }
  // Run on an unbounded work queue that can handle blocking work so as to not
  // starve executor threads.
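  // Note: col_impl is reference-counted. The Ref() below, paired with the
  // ScopedUnref inside the closure, keeps it alive until the closure runs; the
  // second Ref()/ScopedUnref pair inside the closure keeps it alive until the
  // Run() callback fires.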
  col_impl->Ref();
  profiler::TraceMeProducer producer("BaseCollectiveExecutor::ExecuteAsync");
  RunClosure([col_impl, col_ctx, done_safe, ctx,
              context_id = producer.GetContextId()]() {
    core::ScopedUnref unref(col_impl);
    profiler::TraceMeConsumer consumer(
        [ctx] {
          string op = profiler::TraceMeOp(ctx->op_kernel().name_view(),
                                          ctx->op_kernel().type_string_view());
          return profiler::TraceMeEncode(std::move(op),
                                         {{"id", ctx->step_id()}});
        },
        context_id);
    col_impl->Ref();
    col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) {
      core::ScopedUnref unref(col_impl);
      done_safe(s);
    });
  });
}

void BaseCollectiveExecutor::CompleteParamsAsync(
    const DeviceAttributes& device, CollectiveParams* cp,
    CancellationManager* cancel_mgr, StatusCallback done) {
  cp->group.gpu_ring_order = *gpu_ring_order_;
  // We need to make sure that when the timeout callback executes,
  // CollectiveExecutor and CollectiveExecutorMgr are both alive. After done()
  // is called, CollectiveExecutorMgr may be destructed and we don't have a way
  // to keep it alive without making the ownership more complicated. Therefore,
  // if the timeout callback executes, done_safe becomes a no-op and the
  // timeout callback is responsible for invoking done() at the end.
  const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
  auto trace_id =
      profiler::TraceMe::ActivityStart("CollectiveExecutor::CompleteParams");
  auto done_safe = [this, is_callback_called, cancel_mgr, trace_id,
                    done](const Status& s) {
    profiler::TraceMe::ActivityEnd(trace_id);
    bool called = is_callback_called->exchange(true);
    if (!called) {
      if (!s.ok() && !IsCancelled(cancel_mgr)) {
        // This is a collective error. Abort CollectiveExecutor so that this
        // error can propagate to other workers.
        StartAbort(s);
      }
      done(GetStatus(s));
    }
  };
  auto timeout_microseconds =
      static_cast<int64>(cp->instance.impl_details.timeout_seconds * 1'000'000);
  if (timeout_microseconds > 0) {
    // TODO(xldrx): Share the timeout watchdog thread among collectives.
    SchedNonBlockingClosureAfter(
        timeout_microseconds, [this, is_callback_called, done]() {
          bool called = is_callback_called->exchange(true);
          if (!called) {
            Status status(
                error::DEADLINE_EXCEEDED,
                "Collective has timed out waiting for other workers.");
            StartAbort(status);
            done(status);
          }
        });
  }
  cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr,
                                                done_safe);
}

Status BaseCollectiveExecutor::CreateCollective(
    const CollectiveParams& col_params,
    CollectiveImplementationInterface** col_impl) {
  VLOG(2) << "CreateCollective type "
          << DataTypeString(col_params.instance.data_type) << " name "
          << col_params.instance.impl_details.collective_name;
  *col_impl = nullptr;
  switch (col_params.instance.data_type) {
    case DT_BOOL:
      if (col_params.instance.type == BROADCAST_COLLECTIVE) {
        return CollectiveRegistry::Lookup(
            col_params.instance.impl_details.collective_name, col_impl);
      } else {
        return errors::Internal(
            "No collective other than broadcast supports DT_BOOL");
      }
    case DT_INT32:
      if (col_params.group.device_type == DEVICE_GPU &&
          col_params.instance.type == REDUCTION_COLLECTIVE) {
        // TODO(b/139421603): enable int32 all-reduce on GPU.
        return errors::Internal(
            "Collective all-reduce does not support datatype DT_INT32 on "
            "DEVICE_GPU");
      } else {
        return CollectiveRegistry::Lookup(
            col_params.instance.impl_details.collective_name, col_impl);
      }
    case DT_HALF:
    case DT_FLOAT:
    case DT_DOUBLE:
    case DT_INT64: {
      return CollectiveRegistry::Lookup(
          col_params.instance.impl_details.collective_name, col_impl);
    }
    default:
      return errors::Internal(
          "CollectiveImplementation does not support datatype ",
          DataTypeString(col_params.instance.data_type));
  }
}

bool BaseCollectiveExecutor::CheckDependencies(
    const CollectiveParams& col_params) {
  for (int32 instance : col_params.instance.impl_details.dependencies) {
    auto find_iter = launched_.find(instance);
    if (find_iter == launched_.end() || find_iter->second != 0) {
      VLOG(1) << "Collective " << col_params.ToString()
              << " blocked by instance " << instance;
      return false;
    }
  }
  return true;
}

void BaseCollectiveExecutor::WaitForDependencies(
    const CollectiveParams& col_params) {
  mutex_lock l(launch_mu_);
  while (!CheckDependencies(col_params)) {
    launch_cv_.wait(l);
  }
  VLOG(1) << "Unblocking collective " << col_params.ToString();
}

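// Bookkeeping note: launched_[instance_key] starts at the number of devices
// this task contributes to the instance and is decremented on each
// UnblockDependencies() call; collectives that list this instance as a
// dependency (see CheckDependencies above) are unblocked once the count
// reaches zero.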
void BaseCollectiveExecutor::UnblockDependencies(
    const CollectiveParams& col_params) {
  mutex_lock l(launch_mu_);
  if (launched_.find(col_params.instance.instance_key) == launched_.end()) {
    const string& task_name =
        col_params.group.task_names[col_params.default_rank];
    const int32 num_devices =
        col_params.group.num_devices_per_task.at(task_name);
    launched_[col_params.instance.instance_key] = num_devices;
  }
  if (--launched_[col_params.instance.instance_key] == 0) {
    VLOG(1) << "Unblocking dependencies for collective instance "
            << col_params.instance.instance_key;
    launch_cv_.notify_all();
  }
}

}  // namespace tensorflow