/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/ring_reducer.h"

#include <stdlib.h>
#include <atomic>
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/collective_util.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

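// The destructor blocks until group_size_tensor_ready_ has been notified, so
// that any pending CopyCPUTensorToDevice callback issued from
// ContinueAfterInputCopy() has completed before the object is destroyed.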
RingReducer::~RingReducer() { group_size_tensor_ready_.WaitForNotification(); }

Status RingReducer::InitializeCollectiveParams(CollectiveParams* col_params) {
  // TODO(b/113171733): change CHECKs to return errors.
  CHECK_EQ(col_params->instance.type, REDUCTION_COLLECTIVE);
  CHECK_EQ(col_params->instance.impl_details.collective_name, "RingReduce");
  return RingAlg::InitializeCollectiveParams(col_params);
}

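// Runs one device's share of the ring reduction: copies the input tensor into
// the output buffer when the two are distinct, then executes the ring passes
// and eventually invokes `done` with the resulting status.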
void RingReducer::Run(StatusCallback done) {
  CHECK(col_ctx_);
  CHECK(col_params_);
  done_ = std::move(done);
  group_size_ = col_params_->group.group_size;
  num_subdivs_ = static_cast<int>(
      col_params_->instance.impl_details.subdiv_permutations.size());
  CHECK_GT(num_subdivs_, 0);

  if (VLOG_IS_ON(1)) {
    string buf;
    for (int r = 0; r < col_params_->instance.device_names.size(); ++r) {
      strings::StrAppend(&buf, "dev ", r, " : ",
                         col_params_->instance.device_names[r], "\n");
    }
    for (int sd = 0;
         sd < col_params_->instance.impl_details.subdiv_permutations.size();
         ++sd) {
      strings::StrAppend(&buf, "\nsubdiv ", sd, " perm: ");
      for (auto x :
           col_params_->instance.impl_details.subdiv_permutations[sd]) {
        strings::StrAppend(&buf, x, ", ");
      }
    }
    VLOG(1) << "RingReducer::Run for device " << col_ctx_->device_name
            << " default_rank " << col_params_->default_rank << "\n"
            << buf;
  }

  // Start by copying input to output if they're not already the same, i.e. if
  // we're not computing in-place on the input tensor.
  if ((col_ctx_->input != col_ctx_->output) &&
      (DMAHelper::base(col_ctx_->input) != DMAHelper::base(col_ctx_->output))) {
    // We are running in a blockable thread and the callback can't block so
    // just wait here on the copy.
    Notification note;
    Status status;
    CollectiveRemoteAccessLocal::MemCpyAsync(
        col_ctx_->op_ctx->input_device_context(0),
        col_ctx_->op_ctx->op_device_context(), col_ctx_->device,
        col_ctx_->device, col_ctx_->op_ctx->input_alloc_attr(0),
        col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input,
        col_ctx_->output, 0 /*dev_to_dev_stream_index*/,
        [&note, &status](const Status& s) {
          status.Update(s);
          note.Notify();
        });
    note.WaitForNotification();
    if (!status.ok()) {
      done_(status);
      return;
    }
  }
  ContinueAfterInputCopy();
}

// Note that this function is blocking and must not run in any thread
// which cannot be blocked.
void RingReducer::ContinueAfterInputCopy() {
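  // Wrap the output tensor in a CollectiveAdapter that exposes it as
  // group_size_ * num_subdivs_ chunks; the ring passes operate on these
  // chunks independently.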
  AllocatorAttributes attr = col_ctx_->op_ctx->output_alloc_attr(0);
  ca_.reset(MakeCollectiveAdapter(col_ctx_->output, group_size_ * num_subdivs_,
                                  col_ctx_->device->GetAllocator(attr)));

  if (col_params_->final_op) {
    // Create an on-device scalar value from group_size_ that may be needed
    // later.
    // TODO(tucker): Cache and reuse across invocations? Or maybe the scalar
    // can be provided to the kernel in host memory?
    Tensor group_size_val = ca_->Scalar(group_size_);
    if (col_params_->group.device_type != "CPU") {
      group_size_tensor_ = ca_->Scalar(col_ctx_->device->GetAllocator(
          col_ctx_->op_ctx->input_alloc_attr(0)));
      DeviceContext* op_dev_ctx = col_ctx_->op_ctx->op_device_context();
      op_dev_ctx->CopyCPUTensorToDevice(&group_size_val, col_ctx_->device,
                                        &group_size_tensor_,
                                        [this](const Status& s) {
                                          if (!s.ok()) {
                                            StartAbort(s);
                                          }
                                          group_size_tensor_ready_.Notify();
                                        });
    } else {
      group_size_tensor_ = group_size_val;
      group_size_tensor_ready_.Notify();
    }
  } else {
    // Value won't be used, so no need to initialize.
    group_size_tensor_ready_.Notify();
  }
  Finish(RunAsyncParts());
}

void RingReducer::InitRingField(RingField* rf, int chunk_idx, int subdiv_idx,
                                int field_idx) {
  RingAlg::InitRingField(rf, chunk_idx, subdiv_idx, field_idx);
  if (rf->do_recv) {
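    // Allocate a scratch chunk that will receive the neighbor's data; it is
    // later reduced into rf->chunk with merge_op (see RF_RECV in
    // RunAsyncParts).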
    rf->tmp_chunk = ca_->TempChunk(rf->sc_idx);
  }
}

// At the beginning of the algorithm initialize a RingField struct for
// every independent field of the tensor.
bool RingReducer::RunAsyncParts() {
  // This function orchestrates RingReduce actions on behalf of a
  // single device. It is entered by a blockable thread that
  // loops within it until all actions assigned to that device
  // complete. Hence function local variables are accessible only by that
  // one thread and do not require an explicit mutex.
  rfv_.clear();
  rfv_.resize(group_size_ * num_subdivs_);
  PCQueue ready_queue;
  for (int chunk_idx = 0; chunk_idx < group_size_; ++chunk_idx) {
    for (int subdiv_idx = 0; subdiv_idx < num_subdivs_; ++subdiv_idx) {
      int rf_index = (chunk_idx * num_subdivs_) + subdiv_idx;
      InitRingField(&rfv_[rf_index], chunk_idx, subdiv_idx, rf_index);
      ready_queue.Enqueue(&rfv_[rf_index]);
    }
  }
  const DeviceBase::GpuDeviceInfo* gpu_info =
      col_ctx_->device->tensorflow_gpu_device_info();
  if (gpu_info) {
    // Wait for all currently queued events on the CPU compute stream to
    // complete before proceeding.  The previous InitRingField calls allocated
    // temp memory buffers that are not guaranteed to be valid (e.g. for RDMA
    // write) unless we do.
    Notification note;
    Status s = gpu_info->default_context->ThenExecute(
        col_ctx_->device, gpu_info->stream, [&note]() { note.Notify(); });
    if (s.ok()) {
      note.WaitForNotification();
    } else {
      mutex_lock l(status_mu_);
      status_ =
          errors::Internal("Failed to dispatch ThenExecute in RingReducer");
      return false;
    }
  }

  int field_done_count = 0;
  int send_pending_count = 0;
  int recv_pending_count = 0;
  std::atomic<bool> aborted(false);

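  // Each RingField traverses the state machine below twice.  On the first
  // pass a received chunk is merged into the local chunk with merge_op, and
  // the rank holding a chunk's final position applies final_op to it; on the
  // second pass the finalized chunks are simply forwarded around the ring so
  // that every rank ends up with the complete reduced tensor.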
  // Loop until all RingFields have advanced to completion.
  while (field_done_count < rfv_.size()) {
    VLOG(4) << FieldState();
    // Wait for a RingField to appear in the ready_queue.
    RingField* rf = ready_queue.Dequeue();
    // Advance the RingField to its next action and execute, repeating
    // until either an async action has been started or the RingField
    // is done.
    bool dispatched = false;  // true if async action was initiated
    do {
      if (aborted) {
        // Requeue this RingField to be counted off below.
        ready_queue.Enqueue(rf);
        break;
      }
      switch (rf->action) {
        case RF_INIT:
          if (rf->do_recv) {
            rf->action = RF_RECV;
            auto requeue = [this, rf, &ready_queue, &aborted](Status s) {
              if (!s.ok()) {
                aborted = true;
                StartAbort(s);
              }
              ready_queue.Enqueue(rf);
            };
            DispatchRecv(rf, requeue);
            dispatched = true;
            ++recv_pending_count;
          } else {
            rf->action = RF_SEND_READY;
          }
          break;
        case RF_RECV:
          CHECK_GT(recv_pending_count, 0);
          --recv_pending_count;
          if (!rf->second_pass) {
            rf->action = RF_REDUCE;
            Status s = collective_util::ComputeBinOp(
                col_ctx_->op_ctx, col_ctx_->op_params, col_ctx_->device,
                col_params_->merge_op.get(), &rf->chunk, &rf->tmp_chunk);
            if (!s.ok()) {
              aborted = true;
              StartAbort(s);
            }
          } else {
            rf->action = RF_SEND_READY;
          }
          break;
        case RF_REDUCE:
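          // On the first pass, once this chunk has been fully reduced
          // (is_final), apply final_op against the on-device group-size
          // scalar prepared in ContinueAfterInputCopy (e.g. to scale a sum
          // into a mean).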
          if (!rf->second_pass && col_params_->final_op.get() && rf->is_final) {
            rf->action = RF_FINALIZE;
            group_size_tensor_ready_.WaitForNotification();
            Status s = collective_util::ComputeBinOp(
                col_ctx_->op_ctx, col_ctx_->op_params, col_ctx_->device,
                col_params_->final_op.get(), &rf->chunk, &group_size_tensor_);
            if (!s.ok()) {
              aborted = true;
              StartAbort(s);
            }
          } else {
            rf->action = RF_SEND_READY;
          }
          break;
        case RF_FINALIZE:
          rf->action = RF_DONE;
          break;
        case RF_SEND_READY:
          if (rf->do_send) {
            rf->action = RF_SEND;
            auto send_complete = [this, rf, &ready_queue, &aborted](Status s) {
              if (!s.ok()) {
                aborted = true;
                StartAbort(s);
              }
              ready_queue.Enqueue(rf);
            };
            DispatchSend(rf, send_complete);
            dispatched = true;
            ++send_pending_count;
          } else {
            rf->action = RF_DONE;
          }
          break;
        case RF_SEND:
          CHECK_GT(send_pending_count, 0);
          --send_pending_count;
          rf->action = RF_DONE;
          break;
        case RF_DONE:
          break;
      }
      if (rf->action == RF_DONE) {
        if (rf->second_pass) {
          ++field_done_count;
          break;  // from do while(!dispatched)
        } else {
          AdvanceToSecondPass(rf);
        }
      }
    } while (!dispatched);
    if (aborted) break;
  }  // while (field_done_count < number of fields)

  if (aborted) {
    // All of the pending data actions should be aborted; field the
    // callbacks and clear the queue before quitting.
    while ((send_pending_count > 0) || (recv_pending_count > 0)) {
      RingField* rf = ready_queue.Dequeue();
      switch (rf->action) {
        case RF_RECV:
          --recv_pending_count;
          break;
        case RF_SEND:
          --send_pending_count;
          break;
        default: {
        }  // Ignore any other actions
      }
    }
  }

  CHECK_EQ(send_pending_count, 0);
  CHECK_EQ(recv_pending_count, 0);

  VLOG(2) << this << " device=" << col_ctx_->device_name << " finish;"
          << " final value " << TensorDebugString(ca_->Value());
  return !aborted;
}

REGISTER_COLLECTIVE(RingReduce, RingReducer);

}  // namespace tensorflow