1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
16 
17 #include <stddef.h>
18 
19 #include <algorithm>
20 #include <unordered_set>
21 #include <utility>
22 
23 #include "tensorflow/core/common_runtime/device_mgr.h"
24 #include "tensorflow/core/framework/cancellation.h"
25 #include "tensorflow/core/framework/device_attributes.pb.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/lib/gtl/flatmap.h"
30 #include "tensorflow/core/lib/strings/numbers.h"
31 #include "tensorflow/core/lib/strings/str_util.h"
32 #include "tensorflow/core/lib/strings/strcat.h"
33 #include "tensorflow/core/platform/errors.h"
34 #include "tensorflow/core/platform/status.h"
35 #include "tensorflow/core/platform/types.h"
36 #include "tensorflow/core/protobuf/config.pb.h"
37 #include "tensorflow/core/util/device_name_utils.h"
38 
39 namespace tensorflow {
40 
CollectiveParamResolverLocal(const ConfigProto & config,const DeviceMgr * dev_mgr,DeviceResolverInterface * dev_resolver,const string & task_name)41 CollectiveParamResolverLocal::CollectiveParamResolverLocal(
42     const ConfigProto& config, const DeviceMgr* dev_mgr,
43     DeviceResolverInterface* dev_resolver, const string& task_name)
44     : nccl_(config.experimental().collective_nccl()),
45       dev_mgr_(dev_mgr),
46       dev_resolver_(dev_resolver),
47       task_name_(task_name) {}
48 
CompleteGroupAsync(const CompleteGroupRequest * request,CompleteGroupResponse * response,CancellationManager * cancel_mgr,const StatusCallback & done)49 void CollectiveParamResolverLocal::CompleteGroupAsync(
50     const CompleteGroupRequest* request, CompleteGroupResponse* response,
51     CancellationManager* cancel_mgr, const StatusCallback& done) {
52   done(
53       errors::Internal("CompleteGroup is not implemented by "
54                        "CollectiveParamResolverLocal which is "
55                        "intended only for non-distributed deployment."));
56 }
57 
58 namespace {
GetCollectiveName(const CollectiveParams * cp,bool nccl)59 const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) {
60   switch (cp->instance.type) {
61     case BROADCAST_COLLECTIVE:
62       return nccl ? "NcclBroadcast" : "HierarchicalTreeBroadcast";
63 
64     case REDUCTION_COLLECTIVE:
65       return nccl ? "NcclReduce" : "RingReduce";
66 
67     case GATHER_COLLECTIVE:
68       return nccl ? "NcclGather" : "RingGather";
69 
70     case PERMUTE_COLLECTIVE:
71       return "Permute";
72 
73     default:
74       return "undef";
75   }
76 }
77 
// Returns the task portion of a full device name.  CHECK-fails if the name
// cannot be parsed or no task name can be derived from the parsed form.
string TaskNameFromDeviceName(const string& device_name) {
  DeviceNameUtils::ParsedName parsed_device;
  // The CHECKs wrap the calls directly; both calls must always execute.
  CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_device));

  string task_name;
  CHECK(DeviceNameUtils::GetTaskName(parsed_device, &task_name));

  return task_name;
}
85 }  // namespace
86 
CompleteGroupLocal(const DeviceAttributes & device,CollectiveParams * cp,const GroupRecCallback & done,CancellationManager * cancel_mgr)87 void CollectiveParamResolverLocal::CompleteGroupLocal(
88     const DeviceAttributes& device, CollectiveParams* cp,
89     const GroupRecCallback& done, CancellationManager* cancel_mgr) {
90   VLOG(1) << "CompleteGroupLocal device=" << device.name() << " cp: " << cp
91           << ": " << cp->ToString();
92   std::vector<StatusCallback> to_be_called;
93   // Keep a reference to `cp` to avoid racing with deletion due to cancellation.
94   cp->Ref();
95   core::ScopedUnref cp_unref(cp);
96 
97   std::function<void(const Status& s, GroupRec* gr)> done_with_cleanup;
98   if (cancel_mgr != nullptr) {
99     auto cancelled_mu = std::make_shared<mutex>();
100     // Some callers delete `cancel_mgr` as soon as `done` is called once,
101     // meaning we can't rely on it to avoid calling `done` twice if the local op
102     // is cancelled but the group succeeds.
103     auto cancelled = std::make_shared<bool>(false);
104     const CancellationToken token = cancel_mgr->get_cancellation_token();
105     const bool already_cancelled =
106         !cancel_mgr->RegisterCallback(token, [done, cancelled_mu, cancelled]() {
107           {
108             mutex_lock l(*cancelled_mu);
109             *cancelled = true;
110           }
111           done(errors::Cancelled("op cancelled"), nullptr);
112         });
113     if (already_cancelled) {
114       done(errors::Cancelled("op cancelled"), nullptr);
115       return;
116     }
117     done_with_cleanup = [cancel_mgr, done, cancelled_mu, cancelled, token](
118                             const Status& s, GroupRec* gr) {
119       {
120         mutex_lock l(*cancelled_mu);
121         if (*cancelled || !cancel_mgr->TryDeregisterCallback(token)) {
122           return;
123         }
124       }
125       // The operation was never cancelled, so we'll return a normal status.
126       done(s, gr);
127     };
128   } else {
129     done_with_cleanup = done;
130   }
131 
132   GroupRec* gr = nullptr;
133   Status status;
134   {
135     mutex_lock l(group_mu_);
136     auto it = group_table_.find(cp->group.group_key);
137     if (it == group_table_.end()) {
138       gr = new GroupRec;
139       mutex_lock grl(gr->mu);
140       gr->group.group_key = cp->group.group_key;
141       gr->group.group_size = cp->group.group_size;
142       gr->group.device_type = cp->group.device_type;
143       gr->group.gpu_ring_order = cp->group.gpu_ring_order;
144 
145       // Initialize group runtime details.
146       CollectiveImplementationInterface* col_impl;
147       // Try to lookup a NCCL collective kernel.  This will return error status
148       // if `NcclReduce` kernel is not present in the registry, e.g. on an
149       // environment that does not support NCCL.
150       status = CollectiveRegistry::LookupParamResolverInstance("NcclReduce",
151                                                                &col_impl);
152       if (!status.ok()) {
153         // Fallback to non-NCCL collective.
154         status = CollectiveRegistry::LookupParamResolverInstance(
155             GetCollectiveName(cp, /*nccl=*/false), &col_impl);
156       }
157       if (status.ok()) {
158         status = col_impl->InitializeCollectiveGroupRuntimeDetails(
159             &gr->group.runtime_details);
160       }
161 
162       if (!status.ok()) {
163         done_with_cleanup(status, gr);
164         return;
165       }
166 
167       // Store GroupRec in group_table_ which is shared between all devices on
168       // this worker.
169       group_table_[gr->group.group_key].reset(gr);
170       VLOG(2) << "New group_key=" << gr->group.group_key
171               << " group_size=" << gr->group.group_size
172               << " runtime_details=" << gr->group.runtime_details.ToString();
173     } else {
174       gr = it->second.get();
175     }
176   }
177   {
178     mutex_lock l(status_mu_);
179     status = status_;
180   }
181   if (!status.ok()) {
182     done_with_cleanup(status, nullptr);
183     return;
184   }
185   {
186     mutex_lock gr_lock(gr->mu);
187     // If there is ever an error associated with a group key, we store the error
188     // status and invoke all waiting and future callbacks with this error
189     // status.
190     VLOG(2) << "gr device_type=" << gr->group.device_type
191             << " cp device_type=" << cp->group.device_type
192             << " current device=" << device.name();
193     if (gr->status.ok()) {
194       // Check for consistency with existing GroupRec.
195       if (cp->group.device_type != gr->group.device_type) {
196         gr->status = errors::Internal(
197             "Collective Op ", cp->name, " is assigned to device ",
198             device.name(), " with type ", cp->group.device_type.type_string(),
199             " and group_key ", cp->group.group_key, " but that group has type ",
200             gr->group.device_type.type_string());
201       } else if (cp->group.group_size != gr->group.group_size) {
202         gr->status = errors::Internal(
203             "Collective Op ", cp->name, " has group_size ",
204             cp->group.group_size, " and group_key ", cp->group.group_key,
205             " but that group has size ", gr->group.group_size);
206       }
207     }
208     bool new_device = false;
209     if (gr->status.ok()) {
210       // Insert device if not already present.
211       auto it = gr->devices.find(device.name());
212       if (it == gr->devices.end()) {
213         if (gr->devices.size() == gr->group.group_size) {
214           // The group is already full.
215           gr->status = errors::Internal(
216               "Collective Op ", cp->name, " is assigned to device ",
217               device.name(), " and group_key ", cp->group.group_key,
218               " but that group doesn't contain that device.");
219         } else {
220           // This is a new device that has not yet joined the group.
221           gr->devices[device.name()] = device;
222           new_device = true;
223           if (VLOG_IS_ON(1)) {
224             string dev_buf;
225             for (const auto& d : gr->devices) {
226               strings::StrAppend(&dev_buf, ",", d.first);
227             }
228             VLOG(1) << "CompleteGroupLocal group_key=" << gr->group.group_key
229                     << " group_size=" << gr->group.group_size << " (current"
230                     << " devices)=(" << dev_buf << ") (number of"
231                     << " devices pending)="
232                     << (gr->group.group_size - gr->devices.size());
233           }
234         }
235       } else {
236         // If the device already exists, check if the incarnation matches.
237         if (it->second.incarnation() != device.incarnation()) {
238           gr->status = errors::FailedPrecondition(
239               "Device ", device.name(),
240               " current incarnation doesn't match with one in the group. This "
241               "usually means this worker has restarted but the collective "
242               "leader hasn't, or this worker connects to a wrong cluster.");
243         }
244       }
245     }
246 
247     if (gr->status.ok()) {
248       // If the group is not yet complete, queue to wait for it.
249       VLOG(2) << "group_size " << gr->group.group_size << " set size "
250               << gr->devices.size() << " gr " << gr;
251 
252       if (gr->devices.size() < gr->group.group_size) {
253         gr->waiting.push_back(
254             std::bind(done_with_cleanup, std::placeholders::_1, gr));
255         return;
256       }
257       CHECK_EQ(gr->devices.size(), gr->group.group_size);
258       // We get a full group. Fill in remaining fields in gr->group.
259       if (new_device) {
260         FinishGroup(gr);
261       }
262     }
263     // At this point, we either have a full group, or an error status.  Ensure
264     // that all callbacks are invoked with the appropriate status.
265     if (!gr->waiting.empty()) {
266       std::swap(to_be_called, gr->waiting);
267     }
268     status = gr->status;
269   }
270   done_with_cleanup(status, gr);
271   for (int i = 0; i < to_be_called.size(); ++i) {
272     to_be_called[i](status);
273   }
274 }
275 
276 namespace {
277 struct DevRec {
278   string task;
279   string device;
280   int original_rank;
281   int local_rank;
282   int global_rank;
283   const DeviceLocality* locality;
284 };
285 typedef std::unordered_map<string, DevRec> TaskDeviceMap;
286 typedef std::unordered_map<string, TaskDeviceMap> GlobalDeviceMap;
287 
288 // Create a populated GlobalDeviceMap from CollInstanceParams and localities.
BuildDevRecs(const CollGroupParams & gp,const std::vector<DeviceAttributes> & attributes)289 GlobalDeviceMap BuildDevRecs(const CollGroupParams& gp,
290                              const std::vector<DeviceAttributes>& attributes) {
291   GlobalDeviceMap gdm;
292   CHECK_EQ(gp.device_names.size(), gp.task_names.size());
293   CHECK_EQ(gp.device_names.size(), attributes.size());
294   for (int i = 0; i < gp.device_names.size(); ++i) {
295     TaskDeviceMap& tdm = gdm[gp.task_names[i]];
296     DevRec* dr = &tdm[gp.device_names[i]];
297     dr->task = gp.task_names[i];
298     dr->device = gp.device_names[i];
299     dr->original_rank = i;
300     dr->local_rank = 0;   // Will be populated later by OrderTaskDeviceMap.
301     dr->global_rank = 0;  // Will be populated later by EstablishGlobalRank.
302     dr->locality = &attributes[i].locality();
303   }
304   return gdm;
305 }
306 
ParseRingOrder(const string & gpu_ring_order_str,TaskDeviceMap * tdm)307 bool ParseRingOrder(const string& gpu_ring_order_str, TaskDeviceMap* tdm) {
308   std::vector<string> split_gpu_ring_order_str =
309       str_util::Split(gpu_ring_order_str, ',');
310   if (split_gpu_ring_order_str.size() != tdm->size()) return false;
311 
312   // gpu id -> local rank
313   gtl::FlatMap<int32, int32> gpu_ranks;
314   for (int32 rank = 0;
315        rank < static_cast<int32>(split_gpu_ring_order_str.size()); ++rank) {
316     int32 tmp;
317     if (strings::safe_strto32(split_gpu_ring_order_str[rank], &tmp)) {
318       gpu_ranks[tmp] = rank;
319     } else {
320       return false;
321     }
322   }
323 
324   for (auto& tdm_it : *tdm) {
325     DeviceNameUtils::ParsedName parsed_name;
326     DevRec* dr = &tdm_it.second;
327     if (!DeviceNameUtils::ParseFullName(dr->device, &parsed_name)) {
328       return false;
329     }
330     auto rank_it = gpu_ranks.find(parsed_name.id);
331     if (rank_it == gpu_ranks.end()) return false;
332     dr->local_rank = rank_it->second;
333   }
334   VLOG(2) << "Assigned local ranks based on ring order " << gpu_ring_order_str;
335   return true;
336 }
337 
OrderTaskDeviceMap(const string & gpu_ring_order,TaskDeviceMap * tdm)338 void OrderTaskDeviceMap(const string& gpu_ring_order, TaskDeviceMap* tdm) {
339   CHECK_GT(tdm->size(), 0);  // Should never be called with 0 devices
340 
341   // If a valid ring order has been passed in via ConfigProto, use that.
342   if (ParseRingOrder(gpu_ring_order, tdm)) return;
343 
344   // Either no ring order was passed in, or the format was unexpected.
345   // We now assign a ring order based on link strengths.  Note that this
346   // algorithm is not optimal and may not always find the best ring order.
347   int least_rank = -1;
348   string next_device;
349   std::set<string> selected;
350   // Starting device is one with the least initial rank.
351   for (const auto& it : *tdm) {
352     if (least_rank < 0 || it.second.original_rank < least_rank) {
353       least_rank = it.second.original_rank;
354       next_device = it.second.device;
355     }
356   }
357   CHECK_GE(least_rank, 0);
358   DeviceNameUtils::ParsedName parsed_name;
359   CHECK(DeviceNameUtils::ParseFullName(next_device, &parsed_name));
360   // NOTE: InterconnectLink has only a device_id, nothing more, so for
361   // the time being if there's more than one device at a task we
362   // assume they're all GPUs.
363 
364   int next_rank = 0;
365   while (true) {
366     selected.insert(next_device);
367     auto next_dev_it = tdm->find(next_device);
368     CHECK(next_dev_it != tdm->end());
369     DevRec* dr = &next_dev_it->second;
370     dr->local_rank = next_rank;
371     ++next_rank;
372     if (selected.size() == tdm->size()) {
373       break;
374     }
375     // For the present time we assume Locality links only cover GPUs.
376     // For multiple CPUs, just take them in order.
377     const InterconnectLink* best_link = nullptr;
378     if (parsed_name.type == "GPU") {
379       for (const InterconnectLink& il : dr->locality->links().link()) {
380         parsed_name.id = il.device_id();
381         string endpoint_device =
382             DeviceNameUtils::ParsedNameToString(parsed_name);
383         // Skip the device if we've already seen it.
384         if (selected.find(endpoint_device) != selected.end()) {
385           continue;
386         }
387         // Skip the device if it is not participating in this collective
388         // instance.
389         if (tdm->find(endpoint_device) == tdm->end()) {
390           continue;
391         }
392         if (best_link == nullptr || il.strength() > best_link->strength()) {
393           best_link = &il;
394         }
395       }
396     }
397     if (best_link != nullptr) {
398       // Follow the best edge
399       parsed_name.id = best_link->device_id();
400       next_device = DeviceNameUtils::ParsedNameToString(parsed_name);
401     } else {
402       // No good edges, alas. Pick the lowest initial rank among remaining
403       // devices.
404       least_rank = -1;
405       for (const auto& it : *tdm) {
406         if (selected.find(it.second.device) != selected.end()) {
407           continue;
408         }
409         if (least_rank < 0 || it.second.original_rank < least_rank) {
410           least_rank = it.second.original_rank;
411           next_device = it.second.device;
412         }
413       }
414       CHECK_GE(least_rank, 0);
415     }
416   }
417 }
418 
419 // The first time a CollGroupParams is established for a group we compute a good
420 // rank order for all the devices in the group, that is appropriate for a ring
421 // algorithm.
EstablishGlobalRank(const CollGroupParams & gp,const std::vector<DeviceAttributes> & attributes)422 GlobalDeviceMap EstablishGlobalRank(
423     const CollGroupParams& gp,
424     const std::vector<DeviceAttributes>& attributes) {
425   VLOG(1) << "EstablishGlobalRank";
426   GlobalDeviceMap gdm = BuildDevRecs(gp, attributes);
427   for (auto& iter : gdm) {
428     TaskDeviceMap& tdm = iter.second;
429     OrderTaskDeviceMap(gp.gpu_ring_order, &tdm);
430   }
431   // Connect the global rank order by the order in which tasks first appear.
432   std::set<string> ordered_tasks;
433   int next_rank = 0;
434   for (int i = 0; i < gp.task_names.size(); ++i) {
435     const string& task_name = gp.task_names[i];
436     if (ordered_tasks.find(task_name) != ordered_tasks.end()) {
437       continue;
438     }
439     ordered_tasks.insert(task_name);
440     TaskDeviceMap* tdm = &gdm[task_name];
441     for (auto& it : *tdm) {
442       it.second.global_rank = it.second.local_rank + next_rank;
443     }
444     next_rank += tdm->size();
445   }
446   return gdm;
447 }
448 
449 // Count the devices associated with each task and set
450 // gp->same_num_devices_per_task.  Requires gp->task_names
451 // be sorted.
SetDevPerTask(CollGroupParams * gp)452 void SetDevPerTask(CollGroupParams* gp) {
453   gp->num_devices_per_task.clear();
454   const string* last_task_name = &gp->task_names[0];
455   int count = 0;
456   for (const string& task_name : gp->task_names) {
457     if (task_name == *last_task_name) {
458       ++count;
459     } else {
460       gp->num_devices_per_task[*last_task_name] = count;
461       count = 1;
462       last_task_name = &task_name;
463     }
464   }
465   gp->num_devices_per_task[*last_task_name] = count;
466 
467   gp->same_num_devices_per_task = false;
468   int dev_per_task = -1;
469   for (const auto& task_dev : gp->num_devices_per_task) {
470     if (dev_per_task == -1) {
471       dev_per_task = task_dev.second;
472     } else if (dev_per_task != task_dev.second) {
473       return;
474     }
475   }
476   gp->same_num_devices_per_task = true;
477   CHECK_EQ((gp->group_size % gp->num_tasks), 0);
478 }
479 
480 // Sort gp->device_names lexicographically, but do by first
481 // computing a reordering permutation so we can keep gp->task_names
482 // in corresponding order.
SortDevicesAndTasks(CollGroupParams * gp)483 void SortDevicesAndTasks(CollGroupParams* gp) {
484   VLOG(1) << "SortDevicesAndTasks " << gp << " " << gp;
485   CHECK(gp);
486   CHECK_EQ(gp->group_size, gp->device_names.size());
487   CHECK_EQ(gp->group_size, gp->task_names.size());
488   std::vector<int> perm(gp->group_size);
489   // TODO(tucker): substitute std::iota when the windows build supports it.
490   // std::iota(perm.begin(), perm.end(), 0);
491   for (int i = 0; i < perm.size(); ++i) {
492     perm[i] = i;
493   }
494   std::sort(perm.begin(), perm.end(), [gp](int a, int b) {
495     return gp->device_names[a] < gp->device_names[b];
496   });
497   std::vector<string> new_devs;
498   std::vector<string> new_tasks;
499   new_devs.reserve(gp->group_size);
500   new_tasks.reserve(gp->group_size);
501   for (int pi : perm) {
502     new_devs.push_back(gp->device_names[pi]);
503     new_tasks.push_back(gp->task_names[pi]);
504   }
505   gp->device_names = std::move(new_devs);
506   gp->task_names = std::move(new_tasks);
507   VLOG(1) << "Modified device_names on " << gp;
508   SetDevPerTask(gp);
509 }
510 }  // namespace
511 
FinishGroup(GroupRec * gr)512 void CollectiveParamResolverLocal::FinishGroup(GroupRec* gr) {
513   gr->group.device_names.reserve(gr->devices.size());
514   gr->group.task_names.reserve(gr->devices.size());
515   std::vector<DeviceAttributes> attributes;
516   // Unique tasks. It's used to calculate num_tasks.
517   std::unordered_set<string> tasks;
518   attributes.reserve(gr->devices.size());
519   for (const auto& item : gr->devices) {
520     gr->group.device_names.push_back(item.first);
521     string task_name = TaskNameFromDeviceName(item.first);
522     gr->group.task_names.push_back(task_name);
523     tasks.insert(task_name);
524     attributes.push_back(item.second);
525   }
526   gr->group.num_tasks = static_cast<int32>(tasks.size());
527   // Sort device_names lexicographically, keeping task_names in corresponding
528   // order. Also set number of devices per task.
529   SortDevicesAndTasks(&gr->group);
530   // Establish the final order of gp->device_names and gp->task_names by
531   // considering localities of all devices.
532   CompleteDefaultRanking(attributes, &gr->group);
533 }
534 
CompleteTaskIsLocal(const string & task_name,CollectiveParams * cp)535 void CollectiveParamResolverLocal::CompleteTaskIsLocal(const string& task_name,
536                                                        CollectiveParams* cp) {
537   cp->task.is_local.resize(cp->group.group_size, false);
538   for (int i = 0; i < cp->group.group_size; ++i) {
539     cp->task.is_local[i] = (cp->group.task_names[i] == task_name);
540   }
541 }
542 
SetDefaultRank(const string & device,CollectiveParams * cp)543 void CollectiveParamResolverLocal::SetDefaultRank(const string& device,
544                                                   CollectiveParams* cp) {
545   CHECK_EQ(cp->group.group_size, cp->group.device_names.size()) << cp;
546   for (int i = 0; i < cp->group.group_size; ++i) {
547     if (cp->group.device_names[i] == device) {
548       cp->default_rank = i;
549       break;
550     }
551   }
552 }
553 
InitInstanceSharedParams(const GroupRec * gr,const CollectiveParams * cp,InstanceRec * ir)554 void CollectiveParamResolverLocal::InitInstanceSharedParams(
555     const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir) {
556   ir->shared->instance = cp->instance;
557   ir->shared->default_rank = -1;
558 
559   // Set is_local and task_names in *shared prior to invoking
560   // GetDeviceAttributesAsync.  In a distributed context this function can be
561   // called by a derived class, some of the devices may be non-local and
562   // GetDeviceAttributesAsync will use those fields to launch RPCs.
563   CompleteTaskIsLocal(task_name_, ir->shared);
564 }
565 
566 // NOTE(ayushd): The DeviceLocality objects in attributes will have LocalLinks
567 // to all devices that they are physically connected to and visible to the
568 // TensorFlow runtime.  This set of devices may be a superset of the devices
569 // participating in this instance of collectives.
CompleteDefaultRanking(const std::vector<DeviceAttributes> & attributes,CollGroupParams * gp)570 void CollectiveParamResolverLocal::CompleteDefaultRanking(
571     const std::vector<DeviceAttributes>& attributes, CollGroupParams* gp) {
572   // Establish an instance-specific default rank order for devices
573   // based on localities.  This rank order should be a good ring
574   // order, if possible.
575   GlobalDeviceMap gdm = EstablishGlobalRank(*gp, attributes);
576   // Reflect the new global ranking on shared
577   size_t num_devices = gp->group_size;
578   std::vector<string> new_device_names(num_devices, "");
579   std::vector<string> new_task_names(num_devices, "");
580   for (const auto& git : gdm) {
581     const TaskDeviceMap& tdm = git.second;
582     for (const auto& tit : tdm) {
583       const DevRec& dr = tit.second;
584       new_device_names[dr.global_rank] = gp->device_names[dr.original_rank];
585       new_task_names[dr.global_rank] = gp->task_names[dr.original_rank];
586     }
587   }
588 
589   gp->device_names = new_device_names;
590   gp->task_names = new_task_names;
591   if (VLOG_IS_ON(2)) {
592     string buf;
593     for (const auto& d : new_device_names) strings::StrAppend(&buf, "\n", d);
594     VLOG(2) << "Optimized device order for group " << gp->group_key << ": "
595             << buf;
596   }
597 }
598 
599 CollectiveParamResolverLocal::InstanceRec*
GetOrCreateInstanceRec(const GroupRec * gr,CollectiveParams * cp,bool * created)600 CollectiveParamResolverLocal::GetOrCreateInstanceRec(const GroupRec* gr,
601                                                      CollectiveParams* cp,
602                                                      bool* created) {
603   *created = false;
604   InstanceRec* irec = nullptr;
605   {
606     mutex_lock l(instance_mu_);
607     auto group_it = instance_table_.find(gr->group.group_key);
608     if (group_it != instance_table_.end()) {
609       auto instance_it = group_it->second.find(cp->instance.instance_key);
610       if (instance_it != group_it->second.end()) {
611         irec = instance_it->second.get();
612       }
613     }
614     if (irec == nullptr) {
615       // Create new InstanceRec.
616       irec = new InstanceRec;
617       *created = true;
618       {
619         mutex_lock il(irec->mu);
620         irec->known.resize(cp->group.group_size, false);
621       }
622       InitInstanceSharedParams(gr, cp, irec);
623       instance_table_[gr->group.group_key][cp->instance.instance_key].reset(
624           irec);
625     }
626   }
627   Status status;
628   {
629     mutex_lock l(status_mu_);
630     status = status_;
631   }
632   if (!status.ok()) {
633     mutex_lock l(irec->mu);
634     irec->status = status;
635   }
636   return irec;
637 }
638 
CompleteParamsAsync(const DeviceAttributes & device,CollectiveParams * cp,CancellationManager * cancel_mgr,const StatusCallback & done)639 void CollectiveParamResolverLocal::CompleteParamsAsync(
640     const DeviceAttributes& device, CollectiveParams* cp,
641     CancellationManager* cancel_mgr, const StatusCallback& done) {
642   VLOG(1) << "CompleteParams local " << device.name() << " for " << cp << ": "
643           << cp->ToString();
644   CompleteGroupLocal(
645       device, cp,
646       [this, device, cp, done](const Status& s, const GroupRec* gr) {
647         if (s.ok()) {
648           CompleteInstanceLocal(device.name(), gr, cp, cp->is_source, done);
649         } else {
650           done(s);
651         }
652       },
653       cancel_mgr);
654 }
655 
CompleteInstanceAsync(const CompleteInstanceRequest * request,CompleteInstanceResponse * response,CancellationManager * cancel_mgr,const StatusCallback & done)656 void CollectiveParamResolverLocal::CompleteInstanceAsync(
657     const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
658     CancellationManager* cancel_mgr, const StatusCallback& done) {
659   done(
660       errors::Internal("CompleteInstance is not implemented by "
661                        "CollectiveParamResolverLocal which is "
662                        "intended only for non-distributed deployment."));
663 }
664 
665 // TODO(b/111897089): we need a better way to pick the collective
666 // implementation.  The ideal way would depend upon the topology and link
667 // strength before picking a particular implementation.
AssignCollectiveType(CollectiveParams * cp)668 void CollectiveParamResolverLocal::AssignCollectiveType(CollectiveParams* cp) {
669   // We use the NCCL implementation if this is an environment which supports
670   // NCCL, i.e. `LookupParamResolverInstance` for `NcclReduce` returns OK, and
671   // also if indicated either in `ConfigProto` or `communication_hint`.
672   //
673   // After enough testing, we may simplify this logic to use NCCL whenever
674   // available.
675   CollectiveImplementationInterface* col_impl;
676   bool use_nccl =
677       (nccl_ || cp->instance.impl_details.communication_hint == "nccl") &&
678       CollectiveRegistry::LookupParamResolverInstance("NcclReduce", &col_impl)
679           .ok();
680   cp->instance.impl_details.collective_name = GetCollectiveName(cp, use_nccl);
681   VLOG(1) << "AssignCollectiveType "
682           << cp->instance.impl_details.collective_name;
683 }
684 
CompleteInstanceLocal(const string & device,const GroupRec * gr,CollectiveParams * cp,bool is_source,const StatusCallback & done)685 void CollectiveParamResolverLocal::CompleteInstanceLocal(
686     const string& device, const GroupRec* gr, CollectiveParams* cp,
687     bool is_source, const StatusCallback& done) {
688   VLOG(1) << "CompleteInstanceLocal " << device
689           << " instance_key: " << cp->instance.instance_key << " gr " << gr;
690 
691   // Populate the group portion of *cp from *gr.  Most of it should already
692   // match.
693   {
694     mutex_lock l(gr->mu);
695     DCHECK_EQ(cp->group.group_key, gr->group.group_key);
696     DCHECK_EQ(cp->group.group_size, gr->group.group_size);
697     DCHECK_EQ(cp->group.device_type, gr->group.device_type);
698     cp->group = gr->group;
699   }
700 
701   bool created_irec;
702   InstanceRec* ir = GetOrCreateInstanceRec(gr, cp, &created_irec);
703   if (!created_irec) {
704     // Check that the preexisting IRec is consistent with the params passed into
705     // this invocation.
706     if (ir->shared->instance.type != cp->instance.type ||
707         ir->shared->instance.data_type != cp->instance.data_type) {
708       done(errors::Internal("Collective instance ", cp->instance.instance_key,
709                             " expected type ", ir->shared->instance.type,
710                             " and data_type ", ir->shared->instance.data_type,
711                             " but got type ", cp->instance.type,
712                             " and data_type ", cp->instance.data_type));
713       return;
714     }
715   }
716   CompleteInstanceFromInitializedIRec(device, gr, cp, ir, is_source, done);
717 }
718 
CompleteInstanceFromInitializedIRec(const string & device,const GroupRec * gr,CollectiveParams * cp,InstanceRec * ir,bool is_source,const StatusCallback & done)719 void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
720     const string& device, const GroupRec* gr, CollectiveParams* cp,
721     InstanceRec* ir, bool is_source, const StatusCallback& done) {
722   auto expected_shape = cp->instance.shape;
723   Status status;
724   // Populate the fields common across instance.
725   {
726     mutex_lock l(ir->mu);
727     status = ir->status;
728     if (status.ok()) {
729       // custom operator= does a deep copy.
730       cp->instance = ir->shared->instance;
731     }
732   }
733   if (!status.ok()) {
734     done(status);
735     return;
736   }
737   if (expected_shape != cp->instance.shape) {
738     done(errors::InvalidArgument(
739         "Shape mismatch in the collective instance ", cp->instance.instance_key,
740         ". Op at device ", device, " expected shape ",
741         expected_shape.DebugString(), " but another member in the group ",
742         "expected shape ", cp->instance.shape.DebugString(), ". This is likely",
743         " due to different input shapes at different members of the collective",
744         " op."));
745     return;
746   }
747   // Populate the fields common across task.
748   AssignCollectiveType(cp);
749   SetDefaultRank(device, cp);
750   CompleteTaskIsLocal(task_name_, cp);
751 
752   CollectiveImplementationInterface* col_impl;
753   status = CollectiveRegistry::LookupParamResolverInstance(
754       cp->instance.impl_details.collective_name, &col_impl);
755   if (!status.ok()) {
756     done(status);
757     return;
758   }
759 
760   //  We may need to wait for the group, if this is a broadcast, for source
761   //  discovery.
762   if (cp->instance.type == BROADCAST_COLLECTIVE) {
763     WaitForGroup(ir, cp, is_source,
764                  [col_impl, ir, device, cp, done](InstanceRec* irec) {
765                    Status s;
766                    if (ir != irec) {
767                      s = errors::Internal("Expected ir ", ir, " and irec ",
768                                           irec, " to be equal");
769                    } else {
770                      mutex_lock l(irec->mu);
771                      s = irec->status;
772                      cp->source_rank = irec->source_rank;
773                    }
774                    if (s.ok()) {
775                      s = col_impl->InitializeCollectiveParams(cp);
776                    }
777                    done(s);
778                  });
779   } else {
780     done(col_impl->InitializeCollectiveParams(cp));
781   }
782 }
783 
WaitForGroup(InstanceRec * ir,CollectiveParams * cp,bool is_source,const IRConsumer & f)784 void CollectiveParamResolverLocal::WaitForGroup(InstanceRec* ir,
785                                                 CollectiveParams* cp,
786                                                 bool is_source,
787                                                 const IRConsumer& f) {
788   std::vector<IRConsumer> ready_waiters;
789   do {
790     mutex_lock l(ir->mu);
791     if (!ir->status.ok()) {
792       break;
793     }
794     CHECK_EQ(cp->group.group_size, ir->known.size());
795     CHECK_GE(cp->default_rank, 0);
796     if (!ir->known[cp->default_rank]) {
797       ir->known[cp->default_rank] = true;
798       ++ir->known_count;
799       if (is_source) {
800         // Initialize source rank.
801         if (ir->source_rank >= 0) {
802           ir->status = errors::Internal("Instance ", cp->instance.instance_key,
803                                         " already has source ", ir->source_rank,
804                                         ", received second claim from ",
805                                         cp->default_rank);
806         } else {
807           ir->source_rank = cp->default_rank;
808         }
809       }
810     }
811     if (ir->known_count < cp->group.group_size) {
812       ir->known_waiters.push_back(f);
813       return;
814     }
815     CHECK_EQ(ir->known_count, cp->group.group_size);
816     if (ir->source_rank < 0) {
817       // NOTE(ayushd): changing the error message below would also require
818       // updating CompleteParamsBroadcastForgotSend test in
819       // CollectiveParamResolverLocalTest.
820       ir->status =
821           errors::Internal("Instance ", cp->instance.instance_key,
822                            " found no source for broadcast.  This "
823                            "could mean that there were group_size=",
824                            ir->known_count, " BcastRecvs but no BcastSend.");
825     }
826     if (!ir->known_waiters.empty()) {
827       ready_waiters = std::move(ir->known_waiters);
828     }
829   } while (false);
830   f(ir);
831   for (auto& f : ready_waiters) {
832     f(ir);
833   }
834 }
835 
StartAbort(const Status & s)836 void CollectiveParamResolverLocal::StartAbort(const Status& s) {
837   {
838     mutex_lock l(status_mu_);
839     if (!status_.ok()) {
840       VLOG(2) << "CollectiveParamResolverLocal already aborted. Ignoring "
841                  "subsequent abortion with status: "
842               << s;
843       return;
844     }
845     status_ = s;
846   }
847   StartAbortLocal(s);
848 }
849 
StartAbortLocal(const Status & s)850 void CollectiveParamResolverLocal::StartAbortLocal(const Status& s) {
851   {
852     mutex_lock l(group_mu_);
853     for (const auto& item : group_table_) {
854       GroupRec* gr = item.second.get();
855       std::vector<StatusCallback> waiting;
856       {
857         mutex_lock gl(gr->mu);
858         gr->status = s;
859         waiting.swap(gr->waiting);
860       }
861       for (const StatusCallback& done : waiting) {
862         done(s);
863       }
864     }
865   }
866   std::vector<InstanceRec*> instances;
867   {
868     mutex_lock l(instance_mu_);
869     for (const auto& group_entry : instance_table_) {
870       for (const auto& item : group_entry.second) {
871         instances.push_back(item.second.get());
872       }
873     }
874   }
875   for (InstanceRec* ir : instances) {
876     std::vector<IRConsumer> known_waiters;
877     {
878       mutex_lock il(ir->mu);
879       ir->status = s;
880       known_waiters.swap(ir->known_waiters);
881     }
882     for (const IRConsumer& done : known_waiters) {
883       done(ir);
884     }
885   }
886 }
887 
888 }  // namespace tensorflow
889