/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
17 #include <algorithm>
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/strings/str_format.h"
21 #include "tensorflow/compiler/xla/array2d.h"
22 #include "tensorflow/compiler/xla/map_util.h"
23 #include "tensorflow/compiler/xla/types.h"
24 #include "tensorflow/core/lib/gtl/map_util.h"
25
26 namespace xla {
27 namespace gpu {
28
AddDependenciesOnTransitiveOperands(const Thunk & thunk,const HloInstruction & operand,const absl::flat_hash_map<const HloInstruction *,Thunk * > & hlo_to_thunk)29 void ThunkSchedule::AddDependenciesOnTransitiveOperands(
30 const Thunk& thunk, const HloInstruction& operand,
31 const absl::flat_hash_map<const HloInstruction*, Thunk*>& hlo_to_thunk) {
32 if (hlo_to_thunk.contains(&operand)) {
33 // If `operand` is mapped to a thunk, adds `operand` to `thunk`'s dependency
34 // list if `operand` is assigned to a different stream. As an optimization,
35 // we skip `operand`'s operands because `operand` depends on them already.
36 if (stream_assignment_->StreamNumberForHlo(operand) !=
37 stream_assignment_->StreamNumberForHlo(*thunk.hlo_instruction())) {
38 depends_on_[&thunk].push_back(FindOrDie(hlo_to_thunk, &operand));
39 }
40 } else {
41 // If `operand` doesn't need a thunk (e.g. bitcast), continue with its
42 // operands.
43 for (const auto* operand_of_operand : operand.operands()) {
44 AddDependenciesOnTransitiveOperands(thunk, *operand_of_operand,
45 hlo_to_thunk);
46 }
47 }
48 }
49
ThunkSchedule(std::unique_ptr<ThunkSequence> thunks,std::unique_ptr<StreamAssignment> stream_assignment,const std::vector<HloInstruction * > & hlo_total_order)50 ThunkSchedule::ThunkSchedule(
51 std::unique_ptr<ThunkSequence> thunks,
52 std::unique_ptr<StreamAssignment> stream_assignment,
53 const std::vector<HloInstruction*>& hlo_total_order)
54 : thunks_(std::move(thunks)),
55 stream_assignment_(std::move(stream_assignment)) {
56 absl::flat_hash_map<const HloInstruction*, Thunk*> hlo_to_thunk;
57 for (const auto& thunk : *thunks_) {
58 InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get());
59 }
60
61 for (HloInstruction* hlo : hlo_total_order) {
62 if (Thunk** thunk = tensorflow::gtl::FindOrNull(hlo_to_thunk, hlo)) {
63 thunk_total_order_.push_back(*thunk);
64 }
65 }
66
67 for (const Thunk* thunk : thunk_total_order_) {
68 const auto* dst = thunk->hlo_instruction();
69 CHECK(stream_assignment_->HasStreamAssigned(*dst));
70 for (const auto* src : dst->operands()) {
71 AddDependenciesOnTransitiveOperands(*thunk, *src, hlo_to_thunk);
72 }
73 }
74
75 RemoveRedundantDependencyEdges();
76
77 // Compute `depended_by_`, the inverse of `depends_on_`.
78 for (const auto& dependency : depends_on_) {
79 for (const auto* depended : dependency.second) {
80 depended_by_.insert(depended);
81 }
82 }
83 }
84
RemoveRedundantDependencyEdges()85 void ThunkSchedule::RemoveRedundantDependencyEdges() {
86 std::unordered_map<const Thunk*, int> thunk_to_total_order;
87 for (int i = 0; i < thunk_total_order_.size(); ++i) {
88 InsertOrDie(&thunk_to_total_order, thunk_total_order_[i], i);
89 }
90
91 int stream_count = stream_assignment_->StreamCount();
92 // S1 S2
93 //
94 // T1<----+
95 // |
96 // T3<--+ |
97 // | | depends on
98 // T4 |
99 // |
100 // T2-+
101 //
102 // Suppose thunk T1 and T3 are scheduled on stream S1, and T2 and T4 are on
103 // stream S2. If T2 depends on T1 and T4 depends on T3, and
104 // order(T1)<order(T3)<order(T4)<order(T2), the dependency of T2 on T1 is
105 // redundant.
106 //
107 // To efficiently detect such redundancy, we leverage array `last_dependency`.
108 // last_dependency[S1][S2] indicates the last thunk (with the maximum order
109 // number) on stream S2 that thunks on S1 depends on. Therefore, if a future
110 // S1 thunk depends on a S2 thunk ordered <=last_dependency[S1][S2], that is a
111 // redundant dependency edge.
112 Array2D<int> last_dependency(stream_count, stream_count, -1);
113 for (const Thunk* dst : thunk_total_order_) {
114 if (!depends_on_.contains(dst)) {
115 continue;
116 }
117
118 int dst_stream =
119 stream_assignment_->StreamNumberForHlo(*dst->hlo_instruction());
120 std::list<const Thunk*>& sources = FindOrDie(depends_on_, dst);
121 for (auto iter = sources.begin(); iter != sources.end();) {
122 const Thunk* src = *iter;
123 // `dst` depends on `src`.
124 int src_stream =
125 stream_assignment_->StreamNumberForHlo(*src->hlo_instruction());
126 int src_order = FindOrDie(thunk_to_total_order, src);
127 if (src_order <= last_dependency(dst_stream, src_stream)) {
128 iter = sources.erase(iter);
129 } else {
130 last_dependency(dst_stream, src_stream) = src_order;
131 ++iter;
132 }
133 }
134 if (sources.empty()) {
135 depends_on_.erase(dst);
136 }
137 }
138 }
139
DependsOn(const Thunk * thunk) const140 const std::list<const Thunk*>& ThunkSchedule::DependsOn(
141 const Thunk* thunk) const {
142 if (depends_on_.contains(thunk)) {
143 return FindOrDie(depends_on_, thunk);
144 } else {
145 return empty_thunk_list_;
146 }
147 }
148
ToString() const149 string ThunkSchedule::ToString() const {
150 if (thunk_total_order_.empty()) {
151 return "No thunks.";
152 }
153
154 const Thunk* thunk_with_longest_kind = *absl::c_max_element(
155 thunk_total_order_, [](const Thunk* a, const Thunk* b) {
156 return ThunkKindToString(a->kind()).length() <
157 ThunkKindToString(b->kind()).length();
158 });
159 int64 max_thunk_kind_len =
160 ThunkKindToString(thunk_with_longest_kind->kind()).length();
161
162 string result = "Total order:\n";
163 for (Thunk* thunk : thunk_total_order_) {
164 // Write out the thunk kind, padded out to max_thunk_kind_len.
165 absl::string_view kind_str = ThunkKindToString(thunk->kind());
166 absl::StrAppend(&result, kind_str,
167 string(max_thunk_kind_len - kind_str.length(), ' '), "\t");
168 if (thunk->hlo_instruction() != nullptr) {
169 absl::StrAppend(&result, thunk->hlo_instruction()->ToString());
170 } else {
171 absl::StrAppend(&result, "(no HloInstruction)");
172 }
173 absl::StrAppend(&result, "\n");
174 }
175 absl::StrAppend(&result, "\nDependencies:\n");
176 for (const auto& entry : depends_on_) {
177 const Thunk* dependent = entry.first;
178 for (const Thunk* dependency : entry.second) {
179 absl::StrAppend(&result, "\t", dependent->hlo_instruction()->name(),
180 " depends on ", dependency->hlo_instruction()->name(),
181 "\n");
182 }
183 }
184 return result;
185 }
186
187 } // namespace gpu
188 } // namespace xla
189