/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/ring_reducer.h"

#include <stdlib.h>
#include <atomic>
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/collective_util.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

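// Wait for group_size_tensor_ready_ before destruction: the notifying
// callback passed to CopyCPUTensorToDevice in ContinueAfterInputCopy may
// still be in flight, and it must not touch a destroyed Notification.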
RingReducer::~RingReducer() { group_size_tensor_ready_.WaitForNotification(); }

Status RingReducer::InitializeCollectiveParams(CollectiveParams* col_params) {
  // TODO(b/113171733): change CHECKs to return errors.
  CHECK_EQ(col_params->instance.type, REDUCTION_COLLECTIVE);
  CHECK_EQ(col_params->instance.impl_details.collective_name, "RingReduce");
  return RingAlg::InitializeCollectiveParams(col_params);
}

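// Run drives one device's participation in the ring reduction:
//   1. If the reduction is not happening in place, copy the input tensor
//      into the output tensor, blocking on the copy.
//   2. ContinueAfterInputCopy wraps the output in a CollectiveAdapter and,
//      when a final_op is configured, stages group_size_ on the device.
//   3. RunAsyncParts runs the per-field state machine to completion; its
//      result is passed to Finish.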
void RingReducer::Run(StatusCallback done) {
  CHECK(col_ctx_);
  CHECK(col_params_);
  done_ = std::move(done);
  group_size_ = col_params_->group.group_size;
  num_subdivs_ = static_cast<int>(
      col_params_->instance.impl_details.subdiv_permutations.size());
  CHECK_GT(num_subdivs_, 0);

  if (VLOG_IS_ON(1)) {
    string buf;
    for (int r = 0; r < col_params_->instance.device_names.size(); ++r) {
      strings::StrAppend(&buf, "dev ", r, " : ",
                         col_params_->instance.device_names[r], "\n");
    }
    for (int sd = 0;
         sd < col_params_->instance.impl_details.subdiv_permutations.size();
         ++sd) {
      strings::StrAppend(&buf, "\nsubdiv ", sd, " perm: ");
      for (auto x :
           col_params_->instance.impl_details.subdiv_permutations[sd]) {
        strings::StrAppend(&buf, x, ", ");
      }
    }
    VLOG(1) << "RingReducer::Run for device " << col_ctx_->device_name
            << " default_rank " << col_params_->default_rank << "\n"
            << buf;
  }

  // Start by copying input to output if they're not already the same, i.e. if
  // we're not computing in-place on the input tensor.
  if ((col_ctx_->input != col_ctx_->output) &&
      (DMAHelper::base(col_ctx_->input) != DMAHelper::base(col_ctx_->output))) {
    // We are running in a blockable thread and the callback can't block so
    // just wait here on the copy.
    Notification note;
    Status status;
    CollectiveRemoteAccessLocal::MemCpyAsync(
        col_ctx_->op_ctx->input_device_context(0),
        col_ctx_->op_ctx->op_device_context(), col_ctx_->device,
        col_ctx_->device, col_ctx_->op_ctx->input_alloc_attr(0),
        col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input,
        col_ctx_->output, 0 /*dev_to_dev_stream_index*/,
        [&note, &status](const Status& s) {
          status.Update(s);
          note.Notify();
        });
    note.WaitForNotification();
    if (!status.ok()) {
      done_(status);
      return;
    }
  }
  ContinueAfterInputCopy();
}

// Note that this function is blocking and must not run in any thread
// which cannot be blocked.
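// It wraps the output tensor in a CollectiveAdapter sized for
// group_size_ * num_subdivs_ fields and, when a final_op is present, stages
// group_size_ as an on-device scalar before entering RunAsyncParts.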
void RingReducer::ContinueAfterInputCopy() {
  AllocatorAttributes attr = col_ctx_->op_ctx->output_alloc_attr(0);
  ca_.reset(MakeCollectiveAdapter(col_ctx_->output, group_size_ * num_subdivs_,
                                  col_ctx_->device->GetAllocator(attr)));

  if (col_params_->final_op) {
    // Create an on-device scalar value from group_size_ that may be needed
    // later.
    // TODO(tucker): Cache and reuse across invocations? Or maybe the scalar
    // can be provided to the kernel in host memory?
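    // The scalar is consumed by final_op in the RF_FINALIZE step of
    // RunAsyncParts, e.g. the division by group size of a Mean reduction.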
    Tensor group_size_val = ca_->Scalar(group_size_);
    if (col_params_->group.device_type != "CPU") {
      group_size_tensor_ = ca_->Scalar(col_ctx_->device->GetAllocator(
          col_ctx_->op_ctx->input_alloc_attr(0)));
      DeviceContext* op_dev_ctx = col_ctx_->op_ctx->op_device_context();
      op_dev_ctx->CopyCPUTensorToDevice(&group_size_val, col_ctx_->device,
                                        &group_size_tensor_,
                                        [this](const Status& s) {
                                          if (!s.ok()) {
                                            StartAbort(s);
                                          }
                                          group_size_tensor_ready_.Notify();
                                        });
    } else {
      group_size_tensor_ = group_size_val;
      group_size_tensor_ready_.Notify();
    }
  } else {
    // Value won't be used, so no need to initialize.
    group_size_tensor_ready_.Notify();
  }
  Finish(RunAsyncParts());
}

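// Extends RingAlg::InitRingField: a field that will receive data also gets a
// temporary chunk to receive into; on the first pass the received data in
// rf->tmp_chunk is merged into rf->chunk by merge_op (see the RF_RECV
// handling in RunAsyncParts).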
void RingReducer::InitRingField(RingField* rf, int chunk_idx, int subdiv_idx,
                                int field_idx) {
  RingAlg::InitRingField(rf, chunk_idx, subdiv_idx, field_idx);
  if (rf->do_recv) {
    rf->tmp_chunk = ca_->TempChunk(rf->sc_idx);
  }
}

// At the beginning of the algorithm initialize a RingField struct for
// every independent field of the tensor.
bool RingReducer::RunAsyncParts() {
  // This function orchestrates RingReduce actions on behalf of a
  // single device. It is entered by a blockable thread that
  // loops within it until all actions assigned to that device
  // complete. Hence function local variables are accessible only by that
  // one thread and do not require an explicit mutex.
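  //
  // Each RingField is driven through the switch below twice: a first pass in
  // which chunks received from the ring neighbor are combined into rf->chunk
  // with merge_op (and, on the final field, final_op is applied against the
  // group-size scalar), and a second pass that forwards the fully reduced
  // chunks around the ring. In effect the first pass is a scatter-reduce and
  // the second an all-gather. AdvanceToSecondPass restarts a finished field
  // for its second pass; a field only counts as done after that pass.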
  rfv_.clear();
  rfv_.resize(group_size_ * num_subdivs_);
  PCQueue ready_queue;
  for (int chunk_idx = 0; chunk_idx < group_size_; ++chunk_idx) {
    for (int subdiv_idx = 0; subdiv_idx < num_subdivs_; ++subdiv_idx) {
      int rf_index = (chunk_idx * num_subdivs_) + subdiv_idx;
      InitRingField(&rfv_[rf_index], chunk_idx, subdiv_idx, rf_index);
      ready_queue.Enqueue(&rfv_[rf_index]);
    }
  }
  const DeviceBase::GpuDeviceInfo* gpu_info =
      col_ctx_->device->tensorflow_gpu_device_info();
  if (gpu_info) {
    // Wait for all currently queued events on the CPU compute stream to
    // complete before proceeding. The previous InitRingField calls allocated
    // temp memory buffers that are not guaranteed to be valid (e.g. for RDMA
    // write) unless we do.
    Notification note;
    Status s = gpu_info->default_context->ThenExecute(
        col_ctx_->device, gpu_info->stream, [&note]() { note.Notify(); });
    if (s.ok()) {
      note.WaitForNotification();
    } else {
      mutex_lock l(status_mu_);
      status_ =
          errors::Internal("Failed to dispatch ThenExecute in RingReducer");
      return false;
    }
  }

  int field_done_count = 0;
  int send_pending_count = 0;
  int recv_pending_count = 0;
  std::atomic<bool> aborted(false);

  // Loop until all RingFields have advanced to completion.
  while (field_done_count < rfv_.size()) {
    VLOG(4) << FieldState();
    // Wait for a RingField to appear in the ready_queue.
    RingField* rf = ready_queue.Dequeue();
    // Advance the RingField to its next action and execute, repeating
    // until either an async action has been started or the RingField
    // is done.
    bool dispatched = false;  // true if async action was initiated
    do {
      if (aborted) {
        // Requeue this RingField to be counted off below.
        ready_queue.Enqueue(rf);
        break;
      }
      switch (rf->action) {
        case RF_INIT:
          if (rf->do_recv) {
            rf->action = RF_RECV;
            auto requeue = [this, rf, &ready_queue, &aborted](Status s) {
              if (!s.ok()) {
                aborted = true;
                StartAbort(s);
              }
              ready_queue.Enqueue(rf);
            };
            DispatchRecv(rf, requeue);
            dispatched = true;
            ++recv_pending_count;
          } else {
            rf->action = RF_SEND_READY;
          }
          break;
        case RF_RECV:
          CHECK_GT(recv_pending_count, 0);
          --recv_pending_count;
          if (!rf->second_pass) {
            rf->action = RF_REDUCE;
            Status s = collective_util::ComputeBinOp(
                col_ctx_->op_ctx, col_ctx_->op_params, col_ctx_->device,
                col_params_->merge_op.get(), &rf->chunk, &rf->tmp_chunk);
            if (!s.ok()) {
              aborted = true;
              StartAbort(s);
            }
          } else {
            rf->action = RF_SEND_READY;
          }
          break;
        case RF_REDUCE:
          if (!rf->second_pass && col_params_->final_op.get() && rf->is_final) {
            rf->action = RF_FINALIZE;
            group_size_tensor_ready_.WaitForNotification();
            Status s = collective_util::ComputeBinOp(
                col_ctx_->op_ctx, col_ctx_->op_params, col_ctx_->device,
                col_params_->final_op.get(), &rf->chunk, &group_size_tensor_);
            if (!s.ok()) {
              aborted = true;
              StartAbort(s);
            }
          } else {
            rf->action = RF_SEND_READY;
          }
          break;
        case RF_FINALIZE:
          rf->action = RF_DONE;
          break;
        case RF_SEND_READY:
          if (rf->do_send) {
            rf->action = RF_SEND;
            auto send_complete = [this, rf, &ready_queue, &aborted](Status s) {
              if (!s.ok()) {
                aborted = true;
                StartAbort(s);
              }
              ready_queue.Enqueue(rf);
            };
            DispatchSend(rf, send_complete);
            dispatched = true;
            ++send_pending_count;
          } else {
            rf->action = RF_DONE;
          }
          break;
        case RF_SEND:
          CHECK_GT(send_pending_count, 0);
          --send_pending_count;
          rf->action = RF_DONE;
          break;
        case RF_DONE:
          break;
      }
      if (rf->action == RF_DONE) {
        if (rf->second_pass) {
          ++field_done_count;
          break;  // from do while(!dispatched)
        } else {
          AdvanceToSecondPass(rf);
        }
      }
    } while (!dispatched);
    if (aborted) break;
  }  // while (field_done_count < number of fields)

  if (aborted) {
    // All of the pending data actions should be aborted; field the
    // callbacks and clear the queue before quitting.
    while ((send_pending_count > 0) || (recv_pending_count > 0)) {
      RingField* rf = ready_queue.Dequeue();
      switch (rf->action) {
        case RF_RECV:
          --recv_pending_count;
          break;
        case RF_SEND:
          --send_pending_count;
          break;
        default: {
        }  // Ignore any other actions
      }
    }
  }

  CHECK_EQ(send_pending_count, 0);
  CHECK_EQ(recv_pending_count, 0);

  VLOG(2) << this << " device=" << col_ctx_->device_name << " finish;"
          << " final value " << TensorDebugString(ca_->Value());
  return !aborted;
}

REGISTER_COLLECTIVE(RingReduce, RingReducer);

}  // namespace tensorflow