/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h"

#include <chrono>  // NOLINT (required by TF interfaces)
#include <map>
#include <memory>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/mutex.h"

// This thunk's implementation is somewhat similar to our implementation of
// AllReduce using NCCL.  One reason it's separate is that, because this doesn't
// use NCCL, it can work even without a CUDA compiler.
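//
// As an illustrative example (the numbers here are chosen for exposition, not
// taken from any particular module): with three replicas and
// source_target_pairs = {{0,1}, {1,2}}, replica 0's input buffer is copied
// into replica 1's output buffer, replica 1's input is copied into replica 2's
// output, and replica 0's output -- which no pair targets -- is zeroed
// (see ExecuteOnStream below).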

namespace xla {
namespace gpu {

namespace {

using tensorflow::BlockingCounter;

// This same function appears in nccl_all_reduce_thunk.  I've copy/pasted it
// here primarily because I want the VLOGs to work.
template <typename DescFn>
void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
                       const DescFn& desc_fn) {
  VLOG(3) << "Begin: " << desc_fn();
  const std::chrono::milliseconds timeout(5000);
  bool ok = counter->WaitFor(timeout);
  if (ok) {
    VLOG(3) << "Finished: " << desc_fn();
    return;
  }
  LOG(ERROR) << "This thread has been waiting for " << timeout.count()
             << "ms and may be stuck: " << desc_fn();
  counter->Wait();
  LOG(ERROR) << "Thread is unstuck!  Warning above was a false-positive.  "
                "Perhaps the timeout is too short: "
             << desc_fn();
}

// Key for looking up a Rendezvous object in our global map.
//
// Morally, the key is just a RunId.  num_participants is in this struct only
// because we use that information when constructing the Rendezvous.
struct RendezvousKey {
  RunId run_id;
  int num_participants;  // int, not int64, to match BlockingCounter's counter.

  string ToString() const {
    return absl::StrFormat("RendezvousKey{run_id=%s, num_participants=%d}",
                           run_id.ToString(), num_participants);
  }

  template <typename H>
  friend H AbslHashValue(H h, const RendezvousKey& k) {
    return H::combine(std::move(h), k.run_id);
  }
  friend bool operator==(const RendezvousKey& a, const RendezvousKey& b) {
    return a.run_id == b.run_id;
  }
  friend bool operator!=(const RendezvousKey& a, const RendezvousKey& b) {
    return !(a == b);
  }
};

// Information about a thread that's participating in a collective-permute
// operation.
struct ParticipantData {
  int64 replica_id;
  se::Stream* stream;

  se::DeviceMemoryBase src;
  se::DeviceMemoryBase dest;

  // ReplicaIds to which we will copy the data in src.
  std::vector<int64> dest_replicas;
};

// The set of threads that want to do a collective permute together all pick the
// same Rendezvous object out of the global cache and call SubmitParticipant.
//
// The Rendezvous instance handles waiting for all threads to join and then
// doing the actual work of the collective permute.
//
// Rendezvous objects can only be used once.
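//
// A rough sketch of the per-thread calling sequence, mirroring
// ExecuteOnStream below:
//
//   std::shared_ptr<Rendezvous> rendezvous =
//       GlobalRendezvousMap().GetOrCreateIfAbsent(key, factory);
//   TF_ASSIGN_OR_RETURN(std::shared_ptr<BlockingCounter> counter,
//                       rendezvous->SubmitParticipant(participant));
//   rendezvous.reset();         // Drop our reference first, ...
//   counter->DecrementCount();  // ... then sync with the other participants
//   counter->Wait();            // so the Rendezvous can be destroyed safely.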
class Rendezvous {
 public:
  explicit Rendezvous(const RendezvousKey& key) : key_(key) {}

  // Runs the collective permute on the given thread.
  //
  // If successful, returns a BlockingCounter initialized to the number of
  // participants, so that the caller can coordinate with the participants one
  // last time if it chooses.  This is useful for coordinating destruction of
  // the Rendezvous.
  StatusOr<std::shared_ptr<BlockingCounter>> SubmitParticipant(
      ParticipantData participant);

 private:
  const RendezvousKey key_;
  BlockingCounter all_arrived_{key_.num_participants};

  // BlockingCounter returned by SubmitParticipant.
  std::shared_ptr<BlockingCounter> returned_blocking_counter_{
      std::make_shared<BlockingCounter>(key_.num_participants)};

  tensorflow::mutex mu_;
  bool initialized_ TF_GUARDED_BY(mu_) = false;

  // We use an std::map so that we can iterate over it below in a guaranteed
  // order.  The order shouldn't actually matter, but why be nondeterministic if
  // we don't have to be?
  std::map<int64, ParticipantData> participants_ TF_GUARDED_BY(mu_);
};

void EnqueueCopy(se::DeviceMemoryBase src, se::Stream* src_stream,
                 se::DeviceMemoryBase dest, se::Stream* dest_stream) {
  CHECK_EQ(src.size(), dest.size());

  // If src_stream == dest_stream, we're doing a plain memcpy from one GPU back
  // to the same GPU.  x->ThenWaitFor(x) is illegal, so this has to be a special
  // case.
  if (src_stream == dest_stream) {
    dest_stream->ThenMemcpyD2D(&dest, src, src.size());
    return;
  }

  // We (arbitrarily) choose to use the dest stream to perform the copy.  This
  // means we need to make the dest stream wait until the src stream is ready
  // before it performs the copy, and then we need to make the src stream wait
  // until the dest stream has completed the copy.
  dest_stream->ThenWaitFor(src_stream).ThenMemcpyD2D(&dest, src, src.size());
  src_stream->ThenWaitFor(dest_stream);
}

StatusOr<std::shared_ptr<BlockingCounter>> Rendezvous::SubmitParticipant(
    ParticipantData participant) {
  bool primary;
  {
    tensorflow::mutex_lock lock(mu_);
    CHECK(participants_.emplace(participant.replica_id, participant).second);

    // The first thread to acquire the lock is designated as the primary.
    primary = !initialized_;

    if (primary) {
      initialized_ = true;
      returned_blocking_counter_ =
          std::make_shared<BlockingCounter>(key_.num_participants);
    }
  }

  // Wait for all participants to arrive.  Even though our copies are async and
  // enqueued by just one thread, this is not optional!  If we didn't wait for
  // everyone, then we wouldn't be able to enqueue the copies at the correct
  // point in their streams.
  all_arrived_.DecrementCount();
  WaitAndLogIfStuck(&all_arrived_, [&] {
    return absl::StrFormat(
        "participant for replica %d (stream %p, device %d) waiting for all "
        "other participants to arrive: %s",
        participant.replica_id, participant.stream,
        participant.stream->parent()->device_ordinal(), key_.ToString());
  });

  // Schedule the copies between the devices.  This is much easier to reason
  // about if we schedule all of the copies from just one thread.  The copies
  // are async anyway, so the number of host threads we use isn't important.
  if (primary) {
    tensorflow::mutex_lock lock(mu_);
    for (const auto& kv : participants_) {
      const ParticipantData& src_participant = kv.second;
      for (int64 dest_replica : src_participant.dest_replicas) {
        const ParticipantData& dest_participant =
            participants_.at(dest_replica);
        EnqueueCopy(src_participant.src, src_participant.stream,
                    dest_participant.dest, dest_participant.stream);
      }
    }
  }

  return returned_blocking_counter_;
}

// Global map of Rendezvous objects.  A thread participating in a collective op
// looks up its Rendezvous in this map to find the other threads that it's
// participating with.
//
// Rendezvous objects are one-time use, so they're removed from this map once
// we're through with them.
RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() {
  static auto& m = *new RefcountingHashMap<RendezvousKey, Rendezvous>();
  return m;
}

}  // anonymous namespace

CollectivePermuteThunk::CollectivePermuteThunk(
    ThunkInfo thunk_info,
    std::vector<std::pair<int64, int64>> source_target_pairs,
    const BufferAllocation::Slice& src, const BufferAllocation::Slice& dest)
    : Thunk(kCollectivePermute, thunk_info),
      source_target_pairs_(std::move(source_target_pairs)),
      src_(src),
      dest_(dest) {}

Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) {
  auto op_profiler =
      params.profiler->MakeScopedInstructionProfiler(profile_index());

  // Rendezvous with the threads for all other devices that are participating in
  // this CollectivePermute.
  RendezvousKey key{params.run_id, params.device_assn->replica_count()};
  auto rendezvous_factory = [](const RendezvousKey& key) {
    return absl::make_unique<Rendezvous>(key);
  };
  std::shared_ptr<Rendezvous> rendezvous =
      GlobalRendezvousMap().GetOrCreateIfAbsent(key, rendezvous_factory);

  TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
                      params.GetGlobalDeviceId());
  TF_ASSIGN_OR_RETURN(int64 replica_id,
                      params.device_assn->ReplicaIdForDevice(global_device_id));

  // Figure out which replicas our data is copied to.
  std::vector<int64> dest_replicas;
  for (const auto& src_dest : source_target_pairs_) {
    if (src_dest.first == replica_id) {
      dest_replicas.push_back(src_dest.second);
    }
  }

  auto src_addr = params.buffer_allocations->GetDeviceAddress(src_);
  auto dest_addr = params.buffer_allocations->GetDeviceAddress(dest_);
  ParticipantData participant{replica_id, params.stream, src_addr, dest_addr,
                              dest_replicas};
  TF_ASSIGN_OR_RETURN(std::shared_ptr<BlockingCounter> final_sync,
                      rendezvous->SubmitParticipant(participant));

  // If no replica writes into us (i.e. we aren't the target of any copies), our
  // contract is that we zero our output.
  if (absl::c_none_of(source_target_pairs_,
                      [&](std::pair<int64, int64> src_dest) {
                        return src_dest.second == replica_id;
                      })) {
    params.stream->ThenMemZero(&dest_addr, dest_addr.size());
  }

  // Drop our reference to the Rendezvous and wait for all other threads to do
  // the same.  If we didn't do this, one of the threads could run past this
  // point, reenter ExecuteOnStream for another collective-permute, and attempt
  // to reuse the Rendezvous!
  //
  // An alternative way of accomplishing this goal would be to implement
  // RefcountingHashMap::erase() and call it during SubmitParticipant.  But
  // erase() is deceptively complex to implement correctly.
  rendezvous.reset();
  final_sync->DecrementCount();
  WaitAndLogIfStuck(final_sync.get(), [&] {
    return absl::StrFormat(
        "participant for replica %d (stream %p, device ordinal %d) waiting for "
        "all threads to drop their reference to the rendezvous: %s",
        participant.replica_id, participant.stream,
        participant.stream->parent()->device_ordinal(), key.ToString());
  });
  return Status::OK();
}

}  // namespace gpu
}  // namespace xla