/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/ring_alg.h"

#include <stdlib.h>
#include <atomic>
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/collective_util.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"

// Set to true for greater intelligibility of debug-mode log messages.
#define READABLE_KEYS false
// A ring algorithm exchanges chunks of tensor between devices. The chunk size
// depends on the number of subdivisions specified in the algorithm. If the
// user does not specify the number of subdivisions, we may infer the number
// dynamically so that the resulting chunk size does not exceed
// kMaxChunkSizeBytes, empirically set at 4 MiB.
constexpr size_t kMaxChunkSizeBytes = (4 * 1024 * 1024);
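// For example, a 64 MiB tensor in an 8-device group yields 8 MiB chunks with
// a single subdivision; two subdivisions bring the chunk size down to the
// 4 MiB limit.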
// kMaxSubdivsPerDevice gives an upper bound on the number of subdivisions
// dynamically generated. A reasonable value would be a small multiple of the
// number of NICs adjacent to each device.
constexpr int kMaxSubdivsPerDevice = 2;

namespace tensorflow {
namespace {
// Each CollectiveOp implementation is free to define its own
// BufRendezvous key format. This function produces the key used by
// RingAlg instances. Note that the exec_key already differentiates between
// instances, so we don't need to further differentiate between subclasses
// of RingAlg.
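// In the dense (non-READABLE_KEYS) form a key looks like
// "<exec_key>:<pass>:<section>:<source_rank>".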
string RingAlgBufKey(const string& name, const string& exec_key, int pass,
                     int section, int source_rank) {
  if (READABLE_KEYS) {
    return strings::StrCat(name, "(", exec_key, "):pass(", pass, "):section(",
                           section, "):srcrank(", source_rank, ")");
  } else {
    // TODO(b/78352018): Try out some kind of denser encoding, e.g. 128 bit
    // hash.
    return strings::StrCat(exec_key, ":", pass, ":", section, ":", source_rank);
  }
}

}  // namespace

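// PCQueue is a small blocking producer/consumer queue of RingField pointers:
// Enqueue makes a field available and wakes one waiter, while Dequeue blocks
// until a field arrives.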
void RingAlg::PCQueue::Enqueue(RingField* rf) {
  mutex_lock l(pcq_mu_);
  deque_.push_back(rf);
  if (waiter_count_ > 0) {
    cv_.notify_one();
  }
}

RingAlg::RingField* RingAlg::PCQueue::Dequeue() {
  mutex_lock l(pcq_mu_);
  if (deque_.empty()) {
    ++waiter_count_;
    while (deque_.empty()) {
      cv_.wait(l);
    }
    --waiter_count_;
  }
  RingField* rf = deque_.front();
  deque_.pop_front();
  return rf;
}

RingAlg::RingAlg(CollectiveType type, const string& name)
    : type_(type),
      name_(name),
      col_ctx_(nullptr),
      col_params_(nullptr),
      done_(nullptr),
      group_size_(-1),
      num_subdivs_(-1) {}

namespace {
Status GenerateSubdivsInCollectiveParams(CollectiveParams* col_params) {
  if (col_params->instance.shape.num_elements() == 0) {
    return errors::Internal("shape in CollectiveParams should be non-empty");
  }
  const int kAvgDevPerTask =
      col_params->group.group_size / col_params->group.num_tasks;
  const int kMaxNumSubdivs = kMaxSubdivsPerDevice * kAvgDevPerTask;
  if (kMaxNumSubdivs <= 0) {
    return errors::Internal("Unexpected kMaxNumSubdivs ", kMaxNumSubdivs,
                            " in ",
                            col_params->instance.impl_details.collective_name);
  }
  // NOTE(ayushd): If no subdiv_offsets have been specified, dynamically add
  // as many offsets as needed so that the size of tensor chunks <=
  // kMaxChunkSizeBytes. Empirically, chunks that are too small or too large
  // lead to worse performance.
  int num_subdivs = 0;
  const size_t tensor_size = col_params->instance.shape.num_elements() *
                             DataTypeSize(col_params->instance.data_type);
  size_t chunk_size;
  do {
    ++num_subdivs;
    int num_chunks = col_params->group.group_size * num_subdivs;
    chunk_size = tensor_size / num_chunks;
    VLOG(2) << "num_subdivs " << num_subdivs << " num_chunks " << num_chunks
            << " chunk_size " << chunk_size;
  } while (chunk_size > kMaxChunkSizeBytes && num_subdivs < kMaxNumSubdivs);
  if (num_subdivs <= 0) {
    return errors::Internal("Unexpected num_subdivs ", num_subdivs, " in ",
                            col_params->instance.impl_details.collective_name);
  }

  int subdiv_stride = kAvgDevPerTask / num_subdivs;
  if (subdiv_stride == 0) subdiv_stride = 1;
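  // Offsets alternate in sign. A negative offset is interpreted below (in
  // InitializeCollectiveParams) as "reverse each task's local device order",
  // so consecutive subdivisions traverse each task's devices in opposite
  // directions.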
  col_params->instance.impl_details.subdiv_offsets.reserve(num_subdivs);
  for (int sdi = 0; sdi < num_subdivs; ++sdi) {
    int subdiv_offset = subdiv_stride * sdi;
    if (sdi % 2 == 1) subdiv_offset *= -1;
    col_params->instance.impl_details.subdiv_offsets.push_back(subdiv_offset);
  }

  if (VLOG_IS_ON(2)) {
    string subdiv_buf;
    for (const int subdiv_offset :
         col_params->instance.impl_details.subdiv_offsets) {
      strings::StrAppend(&subdiv_buf, " ", subdiv_offset);
    }
    VLOG(2) << "Dynamically generated " << num_subdivs
            << " subdiv_offsets:" << subdiv_buf << " tensor_size "
            << tensor_size << " chunk_size " << chunk_size;
  }

  return Status::OK();
}
}  // namespace

Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) {
  const string& device_name =
      col_params->instance.device_names[col_params->default_rank];
  // Each subdiv permutation is a ring formed by rotating each
  // single-task subsequence of devices by an offset. This makes most
  // sense when each task has the same number of devices, but we can't
  // depend on that being the case, so we compute something that
  // works in any case.

  // Start by counting the devices in each task.
  // Precondition: device_names must be sorted so that all devices in
  // the same task are adjacent.
  VLOG(2) << "Sorted task names: "
          << str_util::Join(col_params->instance.task_names, ", ");
  std::vector<int> dev_per_task;
  const string* prior_task_name = &col_params->instance.task_names[0];
  int dev_count = 1;
  for (int di = 1; di < col_params->group.group_size; ++di) {
    if (col_params->instance.task_names[di] != *prior_task_name) {
      dev_per_task.push_back(dev_count);
      dev_count = 1;
      prior_task_name = &col_params->instance.task_names[di];
    } else {
      ++dev_count;
    }
  }
  dev_per_task.push_back(dev_count);
  DCHECK_EQ(col_params->group.num_tasks, dev_per_task.size());

  if (col_params->instance.impl_details.subdiv_offsets.empty()) {
    TF_RETURN_IF_ERROR(GenerateSubdivsInCollectiveParams(col_params));
  }

  // Generate a ring permutation for each requested offset.
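  // For example, with two tasks of 4 devices each and offset 1, the resulting
  // permutation is [1 2 3 0 5 6 7 4]: each task's device subsequence is
  // rotated by 1.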
  VLOG(2) << "Setting up perms for col_params " << col_params
          << " subdiv_permutations "
          << &col_params->instance.impl_details.subdiv_permutations;
  col_params->instance.impl_details.subdiv_permutations.resize(
      col_params->instance.impl_details.subdiv_offsets.size());
  col_params->subdiv_rank.resize(
      col_params->instance.impl_details.subdiv_offsets.size(), -1);
  for (int sdi = 0;
       sdi < col_params->instance.impl_details.subdiv_offsets.size(); ++sdi) {
    std::vector<int>& perm =
        col_params->instance.impl_details.subdiv_permutations[sdi];
    DCHECK_EQ(perm.size(), 0);
    int offset = col_params->instance.impl_details.subdiv_offsets[sdi];
    // A negative subdivision offset is interpreted as follows:
    // 1. Reverse the local device ordering.
    // 2. Begin the subdivision at abs(offset) in the reversed ordering.
    bool reverse = false;
    if (offset < 0) {
      offset = abs(offset);
      reverse = true;
    }
    int prior_dev_count = 0;  // sum over prior worker device counts
    for (int ti = 0; ti < col_params->group.num_tasks; ++ti) {
      for (int di = 0; di < dev_per_task[ti]; ++di) {
        int di_offset = (di + offset) % dev_per_task[ti];
        int offset_di =
            reverse ? (dev_per_task[ti] - (di_offset + 1)) : di_offset;
        // Device index in global subdivision permutation.
        int permuted_di = prior_dev_count + offset_di;
        int rank = static_cast<int>(perm.size());
        perm.push_back(permuted_di);
        if (col_params->instance.device_names[permuted_di] == device_name) {
          DCHECK_EQ(permuted_di, col_params->default_rank);
          col_params->subdiv_rank[sdi] = rank;
        }
      }
      prior_dev_count += dev_per_task[ti];
    }
    DCHECK_EQ(col_params->group.group_size, perm.size());
  }

  VLOG(2) << collective_util::SubdivPermDebugString(*col_params);
  return Status::OK();
}

Status RingAlg::InitializeCollectiveContext(CollectiveContext* col_ctx) {
  DCHECK(col_ctx->dev_mgr);
  col_ctx_ = col_ctx;
  col_params_ = &col_ctx->col_params;
  return collective_util::InitializeDeviceAndLocality(
      col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
      &col_ctx->device_locality);
}

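// Produces a short human-readable summary of up to 64 values of the tensor,
// first copying it to host memory if it resides on a GPU.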
string RingAlg::TensorDebugString(const Tensor& tensor) {
  const DeviceBase::GpuDeviceInfo* gpu_device_info =
      col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
  if (gpu_device_info) {
    Tensor cpu_tensor(tensor.dtype(), tensor.shape());
    Notification note;
    gpu_device_info->default_context->CopyDeviceTensorToCPU(
        &tensor, "" /*tensor_name*/, col_ctx_->device, &cpu_tensor,
        [&note](const Status& s) {
          DCHECK(s.ok());
          note.Notify();
        });
    note.WaitForNotification();
    return cpu_tensor.SummarizeValue(64);
  } else {
    return tensor.SummarizeValue(64);
  }
}

void RingAlg::StartAbort(const Status& s) {
  // In abort mode we stop issuing additional ProvideBuf
  // and ConsumeBuf calls, but we need to wait for all of the
  // outstanding callbacks to be invoked before quitting.
  bool abort_started = false;
  {
    mutex_lock l(status_mu_);
    if (status_.ok()) {
      LOG(ERROR) << "Aborting Ring" << name_ << " with " << s;
      abort_started = true;
      status_.Update(s);
    }
  }
  // If this is the initial entry to abort mode then invoke StartAbort
  // on the CollectiveExecutor that invoked us. That should start
  // cancellation on all of the outstanding CollectiveRemoteAccess
  // actions.
  if (abort_started) {
    col_ctx_->col_exec->StartAbort(s);
  }
}

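// Finish hands off the result: on success the final value is recovered from
// the collective adapter (ca_) into the output tensor; in all cases the op's
// done callback is invoked with the current status.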
void RingAlg::Finish(bool ok) {
  if (ok) {
    // Recover the output from the adaptor.
    ca_->ConsumeFinalValue(col_ctx_->output);
  }
  Status s;
  {
    mutex_lock l(status_mu_);
    s = status_;
  }
  rfv_.clear();  // Give up Refs on output tensor.
  done_(s);
}

// At the beginning of the algorithm initialize a RingField struct for
// every independent field of the tensor.
void RingAlg::InitRingField(RingField* rf, int chunk_idx, int subdiv_idx,
                            int field_idx) {
  // Note on field indexing: There are group_size_ devices in the
  // instance, implying the same number of chunks per tensor, where a
  // chunk is the unit of data transferred in a time step. However, if
  // a device can simultaneously send data by 2 or more independent
  // channels we can speed up the transfer by subdividing chunks and
  // processing multiple subdivisions at once. So the actual number
  // of RingFields is group_size_ * num_subdivs_.
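  // For example, with group_size_ = 4 and num_subdivs_ = 2 there are 8
  // RingFields, and field_idx 5 names chunk 2, subdivision 1.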
  DCHECK_EQ(field_idx, (chunk_idx * num_subdivs_) + subdiv_idx);
  rf->chunk_idx = chunk_idx;
  rf->subdiv_idx = subdiv_idx;
  rf->sc_idx = field_idx;
  rf->rank = col_params_->subdiv_rank[subdiv_idx];
  rf->second_pass = false;
  rf->action = RF_INIT;
  // Recv from the device with preceding rank within the subdivision.
  int recv_from_rank = (rf->rank + (group_size_ - 1)) % group_size_;
  int send_to_rank = (rf->rank + 1) % group_size_;
  rf->recv_dev_idx = col_params_->instance.impl_details
                         .subdiv_permutations[subdiv_idx][recv_from_rank];
  int send_dev_idx = col_params_->instance.impl_details
                         .subdiv_permutations[subdiv_idx][send_to_rank];
  rf->recv_is_remote = !col_params_->task.is_local[rf->recv_dev_idx];
  rf->send_is_remote = !col_params_->task.is_local[send_dev_idx];
  if (ca_->ChunkBytes(rf->sc_idx) > 0) {
    // In pass 0 we skip Recv when rank = chunk_idx
    rf->do_recv = (rf->chunk_idx != rf->rank);
    // In pass 0 we skip Send when rank = chunk_idx-1
    rf->do_send =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
  }
  rf->is_final =
      (rf->rank == ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
  if (rf->do_send || rf->do_recv) {
    rf->chunk = ca_->ChunkAlias(rf->sc_idx);
  }
  VLOG(2) << this << " InitRingField " << rf->DebugString() << " chunk "
          << ca_->TBounds(rf->chunk);
}

// When a RingField transitions from the first pass to the second, recompute
// the do_send and do_recv values.
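// After the first pass the rank flagged is_final for a chunk already holds
// that chunk's final value, so in the second pass it skips the Recv, and the
// rank immediately preceding it in the ring skips the Send that would only
// hand the value back to its originator.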
void RingAlg::AdvanceToSecondPass(RingField* rf) {
  VLOG(3) << "IncrRingField old value " << rf->DebugString();
  DCHECK(!rf->second_pass);
  rf->second_pass = true;
  rf->action = RF_INIT;
  if (ca_->ChunkBytes(rf->sc_idx) > 0) {
    // In pass 1 the send/no-send boundary moves down 1 place.
    rf->do_recv =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
    rf->do_send =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 2)) % group_size_));
  }
  rf->is_final =
      (rf->rank == ((rf->chunk_idx + (group_size_ - 2)) % group_size_));
  VLOG(3) << "IncrRingField new value " << rf->DebugString();
}

string RingAlg::RingField::DebugString() const {
  string rv = strings::StrCat("RingField rank=", rank, " chunk_idx=", chunk_idx,
                              " subdiv=", subdiv_idx, " sc_idx=", sc_idx,
                              " action=", action);
  strings::StrAppend(&rv, " pass=", second_pass);
  strings::StrAppend(&rv, " do_send=", do_send, " do_recv=", do_recv,
                     " is_final=", is_final, " recv_is_remote=", recv_is_remote,
                     " recv_dev_idx=", recv_dev_idx, " sc_idx=", sc_idx);
  return rv;
}

void RingAlg::DispatchSend(RingField* rf, const StatusCallback& done) {
  DCHECK(rf->do_send);
  string send_buf_key = RingAlgBufKey(name_, col_ctx_->exec_key,
                                      rf->second_pass, rf->sc_idx, rf->rank);
  VLOG(3) << "DispatchSend rank=" << col_params_->default_rank << " send key "
          << send_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " sc_idx "
          << rf->sc_idx;
  int send_to_rank = (rf->rank + 1) % group_size_;
  int send_to_dev_idx = col_params_->instance.impl_details
                            .subdiv_permutations[rf->subdiv_idx][send_to_rank];
  col_ctx_->col_exec->PostToPeer(
      col_params_->instance.device_names[send_to_dev_idx],
      col_params_->instance.task_names[send_to_dev_idx], send_buf_key,
      col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
      col_ctx_->op_ctx->output_alloc_attr(0), &rf->chunk,
      col_ctx_->device_locality, done);
}

void RingAlg::DispatchRecv(RingField* rf, const StatusCallback& done) {
  DCHECK(rf->do_recv);
  string recv_buf_key =
      RingAlgBufKey(name_, col_ctx_->exec_key, rf->second_pass, rf->sc_idx,
                    (rf->rank + (group_size_ - 1)) % group_size_);
  VLOG(3) << "DispatchRecv rank=" << col_params_->default_rank << " recv key "
          << recv_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " into "
          << ((col_params_->merge_op != nullptr) ? "tmp_chunk" : "chunk");
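  // In the first pass, when there is a merge_op, the received data lands in
  // tmp_chunk so it can be combined with the locally held chunk; otherwise it
  // is written directly into chunk.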
  Tensor* dst_tensor = (!rf->second_pass && (col_params_->merge_op != nullptr))
                           ? &rf->tmp_chunk
                           : &rf->chunk;
  col_ctx_->col_exec->RecvFromPeer(
      col_params_->instance.device_names[rf->recv_dev_idx],
      col_params_->instance.task_names[rf->recv_dev_idx],
      col_params_->task.is_local[rf->recv_dev_idx], recv_buf_key,
      col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
      col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,
      col_ctx_->device_locality, rf->subdiv_idx, done);
}

string RingAlg::FieldState() {
  string s = strings::StrCat(
      "Ring", name_, " ", strings::Hex(reinterpret_cast<uint64>(this)),
      " exec ", col_ctx_->exec_key, " step_id=", col_ctx_->step_id,
      " state of all ", rfv_.size(), " fields:");
  for (int i = 0; i < rfv_.size(); ++i) {
    s.append("\n");
    s.append(rfv_[i].DebugString());
  }
  return s;
}

}  // namespace tensorflow