1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_phi_graph.h"
17 
18 #include <queue>
19 
20 namespace xla {
// Returns the value id carried by the node currently registered for
// `value.id()`.
//
// CHECK-fails if `value` has no node in the graph, or if the node it maps to
// has been marked dead.
HloValue::Id PhiGraph::GetOptimizedId(const HloValue& value) {
  // Look the id up with find() instead of operator[]: operator[] would insert
  // a null entry for an unknown id and the dereference below would then crash
  // without a useful diagnostic.
  auto iter = value_id_to_node_.find(value.id());
  CHECK(iter != value_id_to_node_.end());
  const Node* node = iter->second;
  CHECK(!node->mark_as_dead);
  return node->value_id;
}
26 
27 // Returns true if the inputs to a hlo value are the same as `inputs`.
InputsEqualTo(const HloValue & value,absl::Span<const HloValue * const> inputs)28 bool PhiGraph::InputsEqualTo(const HloValue& value,
29                              absl::Span<const HloValue* const> inputs) {
30   auto iter = value_id_to_node_.find(value.id());
31   CHECK(iter != value_id_to_node_.end());
32   absl::flat_hash_set<HloValue::Id> existing_set;
33   for (Node* operand : iter->second->operands) {
34     existing_set.insert(operand->value_id);
35   }
36   absl::flat_hash_set<HloValue::Id> new_set;
37   for (const HloValue* input : inputs) {
38     new_set.insert(input->id());
39   }
40   return existing_set == new_set;
41 }
42 
// Looks up the node registered for `id` and returns the value id that node
// carries. CHECK-fails when `id` is unknown or when its node is marked dead.
HloValue::Id PhiGraph::FindOptimizedValue(const HloValue::Id id) {
  const auto entry = value_id_to_node_.find(id);
  CHECK(entry != value_id_to_node_.end());
  const Node* optimized = entry->second;
  CHECK(!optimized->mark_as_dead);
  return optimized->value_id;
}
49 
CreateOrReuseNode(const HloValue & value)50 PhiGraph::Node* PhiGraph::CreateOrReuseNode(const HloValue& value) {
51   auto iter = value_id_to_node_.find(value.id());
52   if (iter == value_id_to_node_.end()) {
53     node_storage_.emplace_back(absl::make_unique<Node>());
54     Node* node = node_storage_.back().get();
55     node->value_id = value.id();
56     value_id_to_node_[value.id()] = node;
57     node_to_value_id_[node].push_back(value.id());
58     return node;
59   } else {
60     // A node is already registered with this value, check the value_id
61     // is the same as previously registrated.
62     CHECK_NE(iter->second, nullptr);
63     CHECK_EQ(iter->second->value_id, value.id());
64     return iter->second;
65   }
66 }
67 
ReplaceNodeWith(PhiGraph::Node * node,PhiGraph::Node * replace)68 void PhiGraph::ReplaceNodeWith(PhiGraph::Node* node, PhiGraph::Node* replace) {
69   // Update users.
70   CHECK(node->is_phi);
71   if (node->mark_as_dead) {
72     // The node has already been replaced with another.
73     return;
74   }
75   if (replace->mark_as_dead) {
76     // The node we are placing with has already been replaced with another node.
77     auto iter = value_id_to_node_.find(replace->value_id);
78     CHECK(iter != value_id_to_node_.end());
79     return ReplaceNodeWith(node, iter->second);
80   }
81   CHECK(!replace->mark_as_dead);
82   for (Node* user : node->users) {
83     absl::c_replace(user->operands, node, replace);
84   }
85 
86   // Update operand's users
87   for (Node* operand : node->operands) {
88     absl::c_replace(operand->users, node, replace);
89   }
90 
91   for (HloValue::Id value_id : node_to_value_id_[node]) {
92     CHECK(value_id_to_node_.contains(value_id));
93     value_id_to_node_[value_id] = replace;
94   }
95   // Update mappings to HloValue::Id.
96   absl::c_copy(node_to_value_id_[node],
97                std::back_inserter(node_to_value_id_[replace]));
98   node_to_value_id_[node].clear();
99   node->mark_as_dead = true;
100 }
101 
RegisterPhi(const HloValue & value,absl::Span<const HloValue * const> inputs)102 void PhiGraph::RegisterPhi(const HloValue& value,
103                            absl::Span<const HloValue* const> inputs) {
104   Node* node = CreateOrReuseNode(value);
105   CHECK(value.is_phi());
106   node->is_phi = true;
107   node->operands.clear();
108   for (auto input : inputs) {
109     CHECK(input != nullptr);
110     Node* input_node = CreateOrReuseNode(*input);
111     node->operands.push_back(input_node);
112   }
113 }
114 
ToString()115 std::string PhiGraph::ToString() {
116   std::string out = "PhiGraph: \n";
117   for (auto& node : node_storage_) {
118     std::string is_phi = node->is_phi ? ", phi" : "";
119     std::string is_optimized = node->mark_as_dead ? ", dead" : "";
120     absl::StrAppend(&out, node->value_id);
121     absl::StrAppend(&out, is_phi);
122     absl::StrAppend(&out, is_optimized, ":\n");
123     for (Node* input : node->operands) {
124       absl::StrAppend(&out, "  ", input->value_id);
125       absl::StrAppend(&out, "\n");
126     }
127   }
128   return out;
129 }
130 
// Simplifies the phi graph in place by repeatedly replacing redundant phi
// nodes until a full pass makes no change.
//
// Two rewrites are applied to each live phi node:
//   1. If, after dropping self-references, all operands are the same node,
//      the phi is replaced with that node.
//   2. Otherwise the nodes reachable through operand edges are collected; if
//      that search encounters exactly one non-phi node, every phi collected
//      is replaced with that non-phi node.
//
// CHECK-fails if a live phi has no operands, or has only self-references.
//
// NOTE(review): user lists are appended to without being cleared first, so
// calling Optimize() more than once on the same graph looks like it would
// duplicate user entries -- confirm that a single call is the intended usage.
void PhiGraph::Optimize() {
  VLOG(2) << "Optimizing phi graph:";
  XLA_VLOG_LINES(2, ToString());
  // Set up users for each node: every node becomes a user of each of its
  // operands.
  for (auto& node : node_storage_) {
    for (Node* input : node->operands) {
      input->users.push_back(node.get());
    }
  }

  bool changed = true;

  // Run the optimization to a fixed point.
  while (changed) {
    changed = false;
    // Nodes already visited by a closure search during this pass. The set is
    // rebuilt on every pass of the outer loop.
    absl::flat_hash_set<Node*> checked_for_closure;
    for (auto& node : node_storage_) {
      // Only optimize phi node.
      if (!node->is_phi) {
        continue;
      }
      // Skip dead nodes
      if (node->mark_as_dead) {
        continue;
      }

      Node* node_ptr = node.get();

      VLOG(2) << "Optimizing: " << node_ptr->value_id;

      // A phi must have at least one operand.
      CHECK_GE(node_ptr->operands.size(), 1);

      // Remove self-referencing ids from users and operands. Each erase
      // invalidates the iterator, so the search restarts after every removal.
      auto it = absl::c_find(node_ptr->operands, node_ptr);
      while (it != node_ptr->operands.end()) {
        node_ptr->operands.erase(it);
        it = absl::c_find(node_ptr->operands, node_ptr);
      }

      it = absl::c_find(node_ptr->users, node_ptr);
      while (it != node_ptr->users.end()) {
        node_ptr->users.erase(it);
        it = absl::c_find(node_ptr->users, node_ptr);
      }

      // If all inputs to phi (after self referencing ids are removed) are the
      // same value, replace the phi with that value.
      //
      // phi(A, A, ... A) => A
      // phi(A, self) = phi(A) => A
      //
      // The check below also guards the operands[0] accesses that follow: a
      // phi that only referenced itself has no operands left at this point.
      CHECK_GE(node_ptr->operands.size(), 1);
      bool all_inputs_are_same = absl::c_all_of(
          node_ptr->operands,
          [&](Node* elem) { return elem == node_ptr->operands[0]; });

      if (all_inputs_are_same) {
        VLOG(1) << "All inputs to node " << node_ptr->value_id
                << " are the same, replacing it with "
                << node_ptr->operands[0]->value_id;
        ReplaceNodeWith(node_ptr, node_ptr->operands[0]);
        changed = true;
        continue;
      }

      // Find a closure of inter-connected phis and one non-phi node. Replace
      // all phis with that non-phi node.
      //
      // def A = phi(B, C)
      // def B = phi(C, D)
      // def C = phi(A, B)
      // def D = non-phi
      // Replace A, B, and C with D:
      // A = phi(B, C) => D
      // B = phi(C, D) => D
      // C = phi(A, B) => D
      //
      // A node already visited by an earlier search in this pass is not
      // searched again.
      if (checked_for_closure.contains(node_ptr)) {
        continue;
      }
      // Keeps track of nodes in the current closure being tested.
      absl::flat_hash_set<Node*> workset;
      // Breadth-first frontier of nodes still to visit.
      std::queue<Node*> worklist;
      // The single non-phi node seen so far, or nullptr if none has been seen
      // or the search was abandoned.
      Node* non_phi = nullptr;
      worklist.push(node_ptr);
      while (!worklist.empty()) {
        Node* todo = worklist.front();
        worklist.pop();
        // A node can be queued several times; process it only once.
        if (workset.contains(todo)) {
          continue;
        }
        checked_for_closure.insert(todo);
        workset.insert(todo);
        for (Node* operand : todo->operands) {
          worklist.push(operand);
        }
        if (!todo->is_phi) {
          if (non_phi != nullptr && non_phi != todo) {
            // We see distinct non-phi nodes in the closure, can't apply the
            // optimization.
            non_phi = nullptr;
            // Break the while loop non_phi setting to nullptr, signaling that
            // the optimization can't be applied.
            break;
          } else {
            // This is the non_phi node we are seeing so far.
            non_phi = todo;
          }
        }
      }
      if (non_phi != nullptr) {
        // Replace all phi nodes in the closure/workset with the non_phi node.
        // Note: this loop variable shadows the outer `node`.
        for (Node* node : workset) {
          if (!node->is_phi) {
            // The only non-phi collected must be the replacement itself.
            CHECK_EQ(node, non_phi);
            continue;
          }
          VLOG(1) << "Replace node " << node->value_id
                  << " in the closure with node " << non_phi->value_id;
          ReplaceNodeWith(node, non_phi);
          changed = true;
        }
      }
    }
  }
}
256 }  // namespace xla
257