1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
17 
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/memory/memory.h"
20 #include "tensorflow/compiler/xla/map_util.h"
21 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
24 #include "tensorflow/core/platform/random.h"
25 
26 namespace xla {
27 namespace gpu {
28 
HasStreamAssigned(const HloInstruction & hlo) const29 bool StreamAssignment::HasStreamAssigned(const HloInstruction& hlo) const {
30   return hlo_to_stream_number_.contains(&hlo);
31 }
32 
StreamNumberForHlo(const HloInstruction & hlo) const33 int StreamAssignment::StreamNumberForHlo(const HloInstruction& hlo) const {
34   return FindOrDie(hlo_to_stream_number_, &hlo);
35 }
36 
AssignStreamToHlo(const HloInstruction * hlo,int stream_num)37 void StreamAssignment::AssignStreamToHlo(const HloInstruction* hlo,
38                                          int stream_num) {
39   CHECK_GE(stream_num, 0);
40   if (stream_num >= stream_count_) {
41     stream_count_ = stream_num + 1;
42   }
43   InsertOrDie(&hlo_to_stream_number_, hlo, stream_num);
44   VLOG(2) << "Assign stream #" << stream_num << " to " << hlo->ToString();
45 }
46 
47 namespace {
48 
49 // Returns whether the two HLOs can run concurrently, i.e., neither is a
50 // transitive consumer of the other.
CanRunConcurrently(const HloInstruction & a,const HloInstruction & b,const HloReachabilityMap & reachability)51 bool CanRunConcurrently(const HloInstruction& a, const HloInstruction& b,
52                         const HloReachabilityMap& reachability) {
53   return !reachability.IsConnected(&a, &b);
54 }
55 
// Sentinel stream number meaning "no stream assigned / no thunk needed".
constexpr int kInvalidStreamNum = -1;
// Returns true iff `stream_num` is a valid (non-sentinel) stream number.
inline bool IsStreamNumValid(int stream_num) {
  return stream_num != kInvalidStreamNum;
}
61 
62 // Returns which existing stream to assign to `hlo`, or -1 if a stream is not
63 // needed. `stream_assignment` is the existing stream assignment for all
64 // instructions topologically before `hlo`. `seen_gemms` contains all GEMMs that
65 // are topologically before `hlo`.
ComputeStreamToAssign(const HloInstruction & hlo,const StreamAssignment & stream_assignment,const HloReachabilityMap & reachability,const std::vector<const HloInstruction * > & seen_gemms)66 int ComputeStreamToAssign(
67     const HloInstruction& hlo, const StreamAssignment& stream_assignment,
68     const HloReachabilityMap& reachability,
69     const std::vector<const HloInstruction*>& seen_gemms) {
70   if (hlo.opcode() == HloOpcode::kParameter ||
71       hlo.opcode() == HloOpcode::kConstant) {
72     // kParameter and kConstant do not need a thunk.
73     return kInvalidStreamNum;
74   }
75 
76   const auto& debug_options = hlo.GetModule()->config().debug_options();
77   if (debug_options.xla_gpu_disable_multi_streaming()) {
78     return 0;
79   }
80 
81   if (debug_options.xla_gpu_use_random_streams()) {
82     // Debug feature: make random stream assignments to try to uncover
83     // concurrency bugs.
84     return tensorflow::random::New64() % 100;
85   }
86 
87   if (!(IsCublasGemm(hlo) || IsMatrixMultiplication(hlo))) {
88     // If `hlo` is not implemented as a GEMM, keep it close to its operands to
89     // avoid excessive synchronization.
90     int stream_num = -1;
91     for (const auto* operand : hlo.operands()) {
92       if (stream_assignment.HasStreamAssigned(*operand)) {
93         stream_num = std::max(stream_num,
94                               stream_assignment.StreamNumberForHlo(*operand));
95       }
96     }
97     if (!IsStreamNumValid(stream_num)) {
98       stream_num = 0;
99     }
100     return stream_num;
101   }
102 
103   // Assign different streams to concurrent GEMMs. The code below uses a
104   // greedy approach. First, we compute as forbidden_stream_numbers the
105   // streams assigned to GEMMs that are concurrent with `hlo`. Then, we assign
106   // `hlo` a different stream.
107   absl::flat_hash_set<int> forbidden_stream_numbers;
108   for (const auto* seen_gemm : seen_gemms) {
109     int stream_num = stream_assignment.StreamNumberForHlo(*seen_gemm);
110     if (!forbidden_stream_numbers.contains(stream_num) &&
111         CanRunConcurrently(*seen_gemm, hlo, reachability)) {
112       forbidden_stream_numbers.insert(stream_num);
113     }
114   }
115 
116   for (int stream_num = 0; stream_num < stream_assignment.StreamCount();
117        ++stream_num) {
118     if (!forbidden_stream_numbers.contains(stream_num)) {
119       return stream_num;
120     }
121   }
122   return stream_assignment.StreamCount();
123 }
124 
125 }  // namespace
126 
AssignStreams(const HloModule & module)127 std::unique_ptr<StreamAssignment> AssignStreams(const HloModule& module) {
128   auto stream_assignment = absl::make_unique<StreamAssignment>();
129   const HloComputation& computation = *module.entry_computation();
130   std::unique_ptr<HloReachabilityMap> reachability =
131       HloReachabilityMap::Build(&computation);
132   std::vector<const HloInstruction*> seen_gemms;
133   // The execution of different RNG Hlo instructions in the same module updates
134   // a common global variable. To avoid a race condition, we simply assign all
135   // RNG kernels to the same stream to make them run sequentially.
136   //
137   // TODO(b/111791052): If we remove such a common variable, we will need to
138   // clean up the code here.
139   int stream_num_for_rng = kInvalidStreamNum;
140   for (const auto* hlo : computation.MakeInstructionPostOrder()) {
141     // If we ever enable fusion of RNG instructions, we will need to extend this
142     // code to look inside a fused instruction.
143     int stream_num = (hlo->opcode() == HloOpcode::kRng &&
144                       IsStreamNumValid(stream_num_for_rng))
145                          ? stream_num_for_rng
146                          : ComputeStreamToAssign(*hlo, *stream_assignment,
147                                                  *reachability, seen_gemms);
148     if (IsStreamNumValid(stream_num)) {
149       stream_assignment->AssignStreamToHlo(hlo, stream_num);
150       if (hlo->opcode() == HloOpcode::kRng &&
151           !IsStreamNumValid(stream_num_for_rng)) {
152         stream_num_for_rng = stream_num;
153       }
154     }
155     if (IsCublasGemm(*hlo) || IsMatrixMultiplication(*hlo)) {
156       seen_gemms.push_back(hlo);
157     }
158   }
159   return stream_assignment;
160 }
161 
162 }  // namespace gpu
163 }  // namespace xla
164