1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
17 #include <algorithm>
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/strings/str_format.h"
21 #include "tensorflow/compiler/xla/array2d.h"
22 #include "tensorflow/compiler/xla/map_util.h"
23 #include "tensorflow/compiler/xla/types.h"
24 #include "tensorflow/core/lib/gtl/map_util.h"
25 
26 namespace xla {
27 namespace gpu {
28 
AddDependenciesOnTransitiveOperands(const Thunk & thunk,const HloInstruction & operand,const absl::flat_hash_map<const HloInstruction *,Thunk * > & hlo_to_thunk)29 void ThunkSchedule::AddDependenciesOnTransitiveOperands(
30     const Thunk& thunk, const HloInstruction& operand,
31     const absl::flat_hash_map<const HloInstruction*, Thunk*>& hlo_to_thunk) {
32   if (hlo_to_thunk.contains(&operand)) {
33     // If `operand` is mapped to a thunk, adds `operand` to `thunk`'s dependency
34     // list if `operand` is assigned to a different stream. As an optimization,
35     // we skip `operand`'s operands because `operand` depends on them already.
36     if (stream_assignment_->StreamNumberForHlo(operand) !=
37         stream_assignment_->StreamNumberForHlo(*thunk.hlo_instruction())) {
38       depends_on_[&thunk].push_back(FindOrDie(hlo_to_thunk, &operand));
39     }
40   } else {
41     // If `operand` doesn't need a thunk (e.g. bitcast), continue with its
42     // operands.
43     for (const auto* operand_of_operand : operand.operands()) {
44       AddDependenciesOnTransitiveOperands(thunk, *operand_of_operand,
45                                           hlo_to_thunk);
46     }
47   }
48 }
49 
ThunkSchedule(std::unique_ptr<ThunkSequence> thunks,std::unique_ptr<StreamAssignment> stream_assignment,const std::vector<HloInstruction * > & hlo_total_order)50 ThunkSchedule::ThunkSchedule(
51     std::unique_ptr<ThunkSequence> thunks,
52     std::unique_ptr<StreamAssignment> stream_assignment,
53     const std::vector<HloInstruction*>& hlo_total_order)
54     : thunks_(std::move(thunks)),
55       stream_assignment_(std::move(stream_assignment)) {
56   absl::flat_hash_map<const HloInstruction*, Thunk*> hlo_to_thunk;
57   for (const auto& thunk : *thunks_) {
58     InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get());
59   }
60 
61   for (HloInstruction* hlo : hlo_total_order) {
62     if (Thunk** thunk = tensorflow::gtl::FindOrNull(hlo_to_thunk, hlo)) {
63       thunk_total_order_.push_back(*thunk);
64     }
65   }
66 
67   for (const Thunk* thunk : thunk_total_order_) {
68     const auto* dst = thunk->hlo_instruction();
69     CHECK(stream_assignment_->HasStreamAssigned(*dst));
70     for (const auto* src : dst->operands()) {
71       AddDependenciesOnTransitiveOperands(*thunk, *src, hlo_to_thunk);
72     }
73   }
74 
75   RemoveRedundantDependencyEdges();
76 
77   // Compute `depended_by_`, the inverse of `depends_on_`.
78   for (const auto& dependency : depends_on_) {
79     for (const auto* depended : dependency.second) {
80       depended_by_.insert(depended);
81     }
82   }
83 }
84 
RemoveRedundantDependencyEdges()85 void ThunkSchedule::RemoveRedundantDependencyEdges() {
86   std::unordered_map<const Thunk*, int> thunk_to_total_order;
87   for (int i = 0; i < thunk_total_order_.size(); ++i) {
88     InsertOrDie(&thunk_to_total_order, thunk_total_order_[i], i);
89   }
90 
91   int stream_count = stream_assignment_->StreamCount();
92   // S1  S2
93   //
94   // T1<----+
95   //        |
96   // T3<--+ |
97   //      | | depends on
98   //     T4 |
99   //        |
100   //     T2-+
101   //
102   // Suppose thunk T1 and T3 are scheduled on stream S1, and T2 and T4 are on
103   // stream S2. If T2 depends on T1 and T4 depends on T3, and
104   // order(T1)<order(T3)<order(T4)<order(T2), the dependency of T2 on T1 is
105   // redundant.
106   //
107   // To efficiently detect such redundancy, we leverage array `last_dependency`.
108   // last_dependency[S1][S2] indicates the last thunk (with the maximum order
109   // number) on stream S2 that thunks on S1 depends on. Therefore, if a future
110   // S1 thunk depends on a S2 thunk ordered <=last_dependency[S1][S2], that is a
111   // redundant dependency edge.
112   Array2D<int> last_dependency(stream_count, stream_count, -1);
113   for (const Thunk* dst : thunk_total_order_) {
114     if (!depends_on_.contains(dst)) {
115       continue;
116     }
117 
118     int dst_stream =
119         stream_assignment_->StreamNumberForHlo(*dst->hlo_instruction());
120     std::list<const Thunk*>& sources = FindOrDie(depends_on_, dst);
121     for (auto iter = sources.begin(); iter != sources.end();) {
122       const Thunk* src = *iter;
123       // `dst` depends on `src`.
124       int src_stream =
125           stream_assignment_->StreamNumberForHlo(*src->hlo_instruction());
126       int src_order = FindOrDie(thunk_to_total_order, src);
127       if (src_order <= last_dependency(dst_stream, src_stream)) {
128         iter = sources.erase(iter);
129       } else {
130         last_dependency(dst_stream, src_stream) = src_order;
131         ++iter;
132       }
133     }
134     if (sources.empty()) {
135       depends_on_.erase(dst);
136     }
137   }
138 }
139 
DependsOn(const Thunk * thunk) const140 const std::list<const Thunk*>& ThunkSchedule::DependsOn(
141     const Thunk* thunk) const {
142   if (depends_on_.contains(thunk)) {
143     return FindOrDie(depends_on_, thunk);
144   } else {
145     return empty_thunk_list_;
146   }
147 }
148 
ToString() const149 string ThunkSchedule::ToString() const {
150   if (thunk_total_order_.empty()) {
151     return "No thunks.";
152   }
153 
154   const Thunk* thunk_with_longest_kind = *absl::c_max_element(
155       thunk_total_order_, [](const Thunk* a, const Thunk* b) {
156         return ThunkKindToString(a->kind()).length() <
157                ThunkKindToString(b->kind()).length();
158       });
159   int64 max_thunk_kind_len =
160       ThunkKindToString(thunk_with_longest_kind->kind()).length();
161 
162   string result = "Total order:\n";
163   for (Thunk* thunk : thunk_total_order_) {
164     // Write out the thunk kind, padded out to max_thunk_kind_len.
165     absl::string_view kind_str = ThunkKindToString(thunk->kind());
166     absl::StrAppend(&result, kind_str,
167                     string(max_thunk_kind_len - kind_str.length(), ' '), "\t");
168     if (thunk->hlo_instruction() != nullptr) {
169       absl::StrAppend(&result, thunk->hlo_instruction()->ToString());
170     } else {
171       absl::StrAppend(&result, "(no HloInstruction)");
172     }
173     absl::StrAppend(&result, "\n");
174   }
175   absl::StrAppend(&result, "\nDependencies:\n");
176   for (const auto& entry : depends_on_) {
177     const Thunk* dependent = entry.first;
178     for (const Thunk* dependency : entry.second) {
179       absl::StrAppend(&result, "\t", dependent->hlo_instruction()->name(),
180                       " depends on ", dependency->hlo_instruction()->name(),
181                       "\n");
182     }
183   }
184   return result;
185 }
186 
187 }  // namespace gpu
188 }  // namespace xla
189