1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_value.h"
17 
18 #include <algorithm>
19 #include <utility>
20 
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_format.h"
25 #include "absl/strings/str_join.h"
26 #include "tensorflow/compiler/xla/map_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_module.h"
30 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/status.h"
33 #include "tensorflow/compiler/xla/types.h"
34 #include "tensorflow/compiler/xla/util.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/types.h"
38 
39 namespace xla {
40 
41 using absl::StrAppend;
42 using absl::StrCat;
43 
shape() const44 const Shape& HloPosition::shape() const {
45   return ShapeUtil::GetSubshape(instruction->shape(), index);
46 }
47 
ToString() const48 string HloPosition::ToString() const {
49   string index_str =
50       instruction->shape().IsTuple() ? (" " + index.ToString()) : "";
51   return StrCat(instruction->name(), index_str);
52 }
53 
operator <<(std::ostream & out,const HloPosition & position)54 std::ostream& operator<<(std::ostream& out, const HloPosition& position) {
55   out << position.ToString();
56   return out;
57 }
58 
ToString() const59 string HloUse::ToString() const {
60   string index_str = instruction->operand(operand_number)->shape().IsTuple()
61                          ? (" " + operand_index.ToString())
62                          : "";
63   return StrCat(instruction->name(), ", operand ", operand_number, index_str);
64 }
65 
operator <<(std::ostream & out,const HloUse & use)66 std::ostream& operator<<(std::ostream& out, const HloUse& use) {
67   out << use.ToString();
68   return out;
69 }
70 
HloValue(HloValue::Id id,HloInstruction * instruction,const ShapeIndex & index,bool is_phi)71 HloValue::HloValue(HloValue::Id id, HloInstruction* instruction,
72                    const ShapeIndex& index, bool is_phi)
73     : BufferValue(instruction, index, id), is_phi_(is_phi) {
74   // The defining position is always the first element in the positions_ vector.
75   positions_.push_back(HloPosition{instruction, index});
76 }
77 
operator ==(const HloValue & other) const78 bool HloValue::operator==(const HloValue& other) const {
79   bool equal = defining_instruction() == other.defining_instruction() &&
80                defining_index() == other.defining_index();
81   // If the values are equal they must both be phi (or non phi).
82   CHECK(!(equal && is_phi() != other.is_phi()));
83   return equal;
84 }
85 
operator !=(const HloValue & other) const86 bool HloValue::operator!=(const HloValue& other) const {
87   return !(*this == other);
88 }
89 
ToShortString() const90 string HloValue::ToShortString() const {
91   return absl::StrFormat(
92       "<%d %s%s%s%s>", id(), instruction()->name(),
93       instruction()->shape().IsTuple() ? index().ToString() : "",
94       is_phi() ? " (phi)" : "", has_color() ? StrCat(" @", color()) : "");
95 }
96 
ToString(int indent) const97 string HloValue::ToString(int indent) const {
98   string indentation(indent, ' ');
99   string out =
100       StrCat(indentation, ToShortString(), "\n", indentation, " positions:\n");
101   for (const HloPosition& position : positions()) {
102     StrAppend(&out, indentation, "  ", position.ToString(), "\n");
103   }
104   StrAppend(&out, indentation, " uses:\n");
105   for (const HloUse& use : uses()) {
106     StrAppend(&out, indentation, "  ", use.ToString(), "\n");
107   }
108   return out;
109 }
110 
111 namespace {
112 
113 // Returns true if the instruction 'user' may use the value at the given
114 // ShapeIndex in the given operand. Generally, instruction which pass through
115 // values transparently without reading the value are not considered to use the
116 // value.
MayUseOperandValue(int64 operand_number,const ShapeIndex & index,const HloInstruction * user)117 bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index,
118                         const HloInstruction* user) {
119   switch (user->opcode()) {
120     case HloOpcode::kGetTupleElement:
121     case HloOpcode::kCopy:
122       // These instructions only access the top-level values of their
123       // operand. Non-top-level (nested) values are passed through
124       // transparently.
125       CHECK_EQ(operand_number, 0);
126       return index.empty();
127     case HloOpcode::kTupleSelect:
128       // Select does not use any nested elements of its selected-from operands
129       // (operand 1 and 2)
130       CHECK_GE(operand_number, 0);
131       CHECK_LE(operand_number, 2);
132       return operand_number == 0 || index.empty();
133 
134     case HloOpcode::kDomain:
135     case HloOpcode::kTuple:
136       // These instructions always pass through their operands transparently.
137       return false;
138 
139     case HloOpcode::kCall:
140     case HloOpcode::kWhile:
141       // Although call and while instructions pass through their operands, they
142       // are considered uses.
143       return true;
144 
145     default:
146       return true;
147   }
148 }
149 
150 }  // namespace
151 
SetPositionsAndComputeUses(absl::Span<const HloPosition> positions)152 void HloValue::SetPositionsAndComputeUses(
153     absl::Span<const HloPosition> positions) {
154   CHECK_EQ(positions_.size(), 1) << "SetPositions should only be called once.";
155 
156   // The positions must be unique and should not contain the defining position
157   // as this is added at construction time.
158   for (const HloPosition& position_a : positions) {
159     DCHECK_NE(position_a, defining_position());
160     for (const HloPosition& position_b : positions) {
161       if (&position_a != &position_b) {
162         DCHECK_NE(position_a, position_b);
163       }
164     }
165   }
166 
167   positions_.insert(positions_.end(), positions.begin(), positions.end());
168 
169   // Gather the computation roots at which this value appears.
170   absl::flat_hash_set<HloInstruction*> root_positions;
171   for (const HloPosition& position : positions_) {
172     if (position.instruction ==
173         position.instruction->parent()->root_instruction()) {
174       root_positions.insert(position.instruction);
175     }
176   }
177 
178   // Build vector of HloUses for the value.
179   for (const HloPosition& position : positions_) {
180     for (HloInstruction* user : position.instruction->users()) {
181       for (int64 i = 0; i < user->operand_count(); ++i) {
182         if (user->operand(i) != position.instruction) {
183           continue;
184         }
185 
186         // Root instructions of computations are considered to be uses whether
187         // or not the root instruction itself actually uses the value.
188         if (MayUseOperandValue(i, position.index, user) ||
189             ContainsKey(root_positions, user)) {
190           HloUse new_use{user, i, position.index};
191 
192           // The new use must not already exist in uses_.
193           for (const HloUse& use : uses_) {
194             DCHECK_NE(use, new_use);
195           }
196 
197           uses_.push_back(std::move(new_use));
198         }
199       }
200     }
201 
202     // Update liveout status of this HloValue.
203     const HloModule& module = *position.instruction->parent()->parent();
204     if (position.instruction ==
205         module.entry_computation()->root_instruction()) {
206       live_out_of_module_ = true;
207     }
208   }
209 }
210 
operator <<(std::ostream & out,const HloValue & value)211 std::ostream& operator<<(std::ostream& out, const HloValue& value) {
212   out << value.ToShortString();
213   return out;
214 }
215 
SortAndUniquifyValues()216 void HloValueSet::SortAndUniquifyValues() {
217   absl::c_sort(values_, HloValue::IdLessThan);
218   values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual),
219                 values_.end());
220 }
221 
ToString() const222 string HloValueSet::ToString() const {
223   return StrCat(
224       "HloValueSet: ",
225       absl::StrJoin(values_, ", ", [](string* result, const HloValue* value) {
226         result->append(value->ToShortString());
227       }));
228 }
229 
AssignUnionOf(absl::Span<const HloValueSet * const> inputs)230 bool HloValueSet::AssignUnionOf(absl::Span<const HloValueSet* const> inputs) {
231   HloValueSet union_set;
232   for (const HloValueSet* input : inputs) {
233     for (const HloValue* value : input->values()) {
234       union_set.values_.push_back(value);
235     }
236   }
237   union_set.SortAndUniquifyValues();
238   if (*this != union_set) {
239     *this = union_set;
240     return true;
241   }
242   return false;
243 }
244 
AddValue(const HloValue * value)245 bool HloValueSet::AddValue(const HloValue* value) {
246   auto it = std::lower_bound(values_.begin(), values_.end(), value,
247                              HloValue::IdLessThan);
248   if (it == values_.end() || (*it)->id() != value->id()) {
249     values_.insert(it, value);
250     return true;
251   }
252   return false;  // already exists
253 }
254 
operator <<(std::ostream & out,const HloValueSet & value_set)255 std::ostream& operator<<(std::ostream& out, const HloValueSet& value_set) {
256   out << value_set.ToString();
257   return out;
258 }
259 
IsAmbiguous() const260 bool InstructionValueSet::IsAmbiguous() const {
261   bool ambiguous = false;
262   for (auto& iter : *this) {
263     ambiguous |= iter.second.values().size() > 1;
264   }
265   return ambiguous;
266 }
267 
AssignUnionOf(absl::Span<const InstructionValueSet * const> inputs)268 bool InstructionValueSet::AssignUnionOf(
269     absl::Span<const InstructionValueSet* const> inputs) {
270   CHECK_GT(inputs.size(), 0);
271   for (int i = 1; i < inputs.size(); ++i) {
272     DCHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape()));
273   }
274   bool changed = false;
275   for (auto& pair : *this) {
276     const ShapeIndex& index = pair.first;
277     HloValueSet& value_set = pair.second;
278 
279     std::vector<const HloValueSet*> input_value_sets;
280     for (const InstructionValueSet* input : inputs) {
281       input_value_sets.push_back(&input->element(index));
282     }
283     changed |= value_set.AssignUnionOf(input_value_sets);
284   }
285 
286   return changed;
287 }
288 
operator <<(std::ostream & out,const InstructionValueSet & instruction_value_set)289 std::ostream& operator<<(std::ostream& out,
290                          const InstructionValueSet& instruction_value_set) {
291   out << instruction_value_set.ToString();
292   return out;
293 }
294 
ToString() const295 string InstructionValueSet::ToString() const {
296   string out =
297       StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")\n");
298   ForEachElement([&out](const ShapeIndex& index, const HloValueSet& value_set) {
299     StrAppend(&out, "  ", index.ToString(), " : ", value_set.ToString(), "\n");
300   });
301   return out;
302 }
303 
304 }  // namespace xla
305