1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_value.h"
17 
18 #include <algorithm>
19 #include <utility>
20 
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_join.h"
25 #include "tensorflow/compiler/xla/map_util.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_module.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/status.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/types.h"
37 
38 namespace xla {
39 
40 using absl::StrAppend;
41 using absl::StrCat;
42 
shape() const43 const Shape& HloPosition::shape() const {
44   return ShapeUtil::GetSubshape(instruction->shape(), index);
45 }
46 
ToString() const47 string HloPosition::ToString() const {
48   string index_str =
49       instruction->shape().IsTuple() ? (" " + index.ToString()) : "";
50   return StrCat(instruction->name(), index_str);
51 }
52 
operator <<(std::ostream & out,const HloPosition & position)53 std::ostream& operator<<(std::ostream& out, const HloPosition& position) {
54   out << position.ToString();
55   return out;
56 }
57 
ToString() const58 string HloUse::ToString() const {
59   string index_str = instruction->operand(operand_number)->shape().IsTuple()
60                          ? (" " + operand_index.ToString())
61                          : "";
62   return StrCat(instruction->name(), ", operand ", operand_number, index_str);
63 }
64 
operator <<(std::ostream & out,const HloUse & use)65 std::ostream& operator<<(std::ostream& out, const HloUse& use) {
66   out << use.ToString();
67   return out;
68 }
69 
HloValue(HloValue::Id id,HloInstruction * instruction,const ShapeIndex & index,bool is_phi)70 HloValue::HloValue(HloValue::Id id, HloInstruction* instruction,
71                    const ShapeIndex& index, bool is_phi)
72     : BufferValue(instruction, index, id), is_phi_(is_phi) {
73   // The defining position is always the first element in the positions_ vector.
74   positions_.push_back(HloPosition{instruction, index});
75 }
76 
operator ==(const HloValue & other) const77 bool HloValue::operator==(const HloValue& other) const {
78   bool equal = defining_instruction() == other.defining_instruction() &&
79                defining_index() == other.defining_index();
80   // If the values are equal they most both be phi (or non phi).
81   CHECK(!(equal && is_phi() != other.is_phi()));
82   return equal;
83 }
84 
operator !=(const HloValue & other) const85 bool HloValue::operator!=(const HloValue& other) const {
86   return !(*this == other);
87 }
88 
ToShortString() const89 string HloValue::ToShortString() const {
90   string index_str = defining_instruction()->shape().IsTuple()
91                          ? defining_index().ToString()
92                          : "";
93   return StrCat(id(), " ", is_phi_ ? "PHI " : "",
94                 defining_instruction()->name(), index_str);
95 }
96 
ToString(int indent) const97 string HloValue::ToString(int indent) const {
98   string indentation(indent, ' ');
99   string out = StrCat(indentation, ToShortString(), ", positions:\n");
100   for (const HloPosition& position : positions()) {
101     StrAppend(&out, indentation, "  ", position.ToString(), "\n");
102   }
103   StrAppend(&out, indentation, " uses:\n");
104   for (const HloUse& use : uses()) {
105     StrAppend(&out, indentation, "  ", use.ToString(), "\n");
106   }
107   return out;
108 }
109 
110 namespace {
111 
112 // Returns true if the instruction 'user' may use the value at the given
113 // ShapeIndex in the given operand. Generally, instruction which pass through
114 // values transparently without reading the value are not considered to use the
115 // value.
MayUseOperandValue(int64 operand_number,const ShapeIndex & index,const HloInstruction * user)116 bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index,
117                         const HloInstruction* user) {
118   switch (user->opcode()) {
119     case HloOpcode::kGetTupleElement:
120     case HloOpcode::kCopy:
121       // These instructions only access the top-level values of their
122       // operand. Non-top-level (nested) values are passed through
123       // transparently.
124       CHECK_EQ(operand_number, 0);
125       return index.empty();
126     case HloOpcode::kTupleSelect:
127       // Select does not use any nested elements of its selected-from operands
128       // (operand 1 and 2)
129       CHECK_GE(operand_number, 0);
130       CHECK_LE(operand_number, 2);
131       return operand_number == 0 || index.empty();
132 
133     case HloOpcode::kDomain:
134     case HloOpcode::kTuple:
135       // These instructions always pass through their operands transparently.
136       return false;
137 
138     case HloOpcode::kCall:
139     case HloOpcode::kWhile:
140       // Although call and while instructions pass through their operands, they
141       // are considered uses.
142       return true;
143 
144     default:
145       return true;
146   }
147 }
148 
149 }  // namespace
150 
SetPositionsAndComputeUses(absl::Span<const HloPosition> positions)151 void HloValue::SetPositionsAndComputeUses(
152     absl::Span<const HloPosition> positions) {
153   CHECK_EQ(positions_.size(), 1) << "SetPositions should only be called once.";
154 
155   // The positions must be unique and should not contain the defining position
156   // as this is added at construction time.
157   for (const HloPosition& position_a : positions) {
158     DCHECK_NE(position_a, defining_position());
159     for (const HloPosition& position_b : positions) {
160       if (&position_a != &position_b) {
161         DCHECK_NE(position_a, position_b);
162       }
163     }
164   }
165 
166   positions_.insert(positions_.end(), positions.begin(), positions.end());
167 
168   // Gather the computation roots at which this value appears.
169   absl::flat_hash_set<HloInstruction*> root_positions;
170   for (const HloPosition& position : positions_) {
171     if (position.instruction ==
172         position.instruction->parent()->root_instruction()) {
173       root_positions.insert(position.instruction);
174     }
175   }
176 
177   // Build vector of HloUses for the value.
178   for (const HloPosition& position : positions_) {
179     for (HloInstruction* user : position.instruction->users()) {
180       for (int64 operand_number : user->OperandIndices(position.instruction)) {
181         // Root instructions of computations are considered to be uses whether
182         // or not the root instruction itself actually uses the value.
183         if (MayUseOperandValue(operand_number, position.index, user) ||
184             ContainsKey(root_positions, user)) {
185           HloUse new_use{user, operand_number, position.index};
186 
187           // The new use must not already exist in uses_.
188           for (const HloUse& use : uses_) {
189             DCHECK_NE(use, new_use);
190           }
191 
192           uses_.push_back(std::move(new_use));
193         }
194       }
195     }
196 
197     // Update liveout status of this HloValue.
198     const HloModule& module = *position.instruction->parent()->parent();
199     if (position.instruction ==
200         module.entry_computation()->root_instruction()) {
201       live_out_of_module_ = true;
202     }
203   }
204 }
205 
operator <<(std::ostream & out,const HloValue & value)206 std::ostream& operator<<(std::ostream& out, const HloValue& value) {
207   out << value.ToShortString();
208   return out;
209 }
210 
SortAndUniquifyValues()211 void HloValueSet::SortAndUniquifyValues() {
212   absl::c_sort(values_, HloValue::IdLessThan);
213   values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual),
214                 values_.end());
215 }
216 
ToString() const217 string HloValueSet::ToString() const {
218   return StrCat(
219       "HloValueSet: ",
220       absl::StrJoin(values_, ", ", [](string* result, const HloValue* value) {
221         result->append(value->ToShortString());
222       }));
223 }
224 
AssignUnionOf(absl::Span<const HloValueSet * const> inputs)225 bool HloValueSet::AssignUnionOf(absl::Span<const HloValueSet* const> inputs) {
226   HloValueSet union_set;
227   for (const HloValueSet* input : inputs) {
228     for (const HloValue* value : input->values()) {
229       union_set.values_.push_back(value);
230     }
231   }
232   union_set.SortAndUniquifyValues();
233   if (*this != union_set) {
234     *this = union_set;
235     return true;
236   }
237   return false;
238 }
239 
AddValue(const HloValue * value)240 bool HloValueSet::AddValue(const HloValue* value) {
241   auto it = std::lower_bound(values_.begin(), values_.end(), value,
242                              HloValue::IdLessThan);
243   if (it == values_.end() || (*it)->id() != value->id()) {
244     values_.insert(it, value);
245     return true;
246   }
247   return false;  // already exists
248 }
249 
operator <<(std::ostream & out,const HloValueSet & value_set)250 std::ostream& operator<<(std::ostream& out, const HloValueSet& value_set) {
251   out << value_set.ToString();
252   return out;
253 }
254 
AssignUnionOf(absl::Span<const InstructionValueSet * const> inputs)255 bool InstructionValueSet::AssignUnionOf(
256     absl::Span<const InstructionValueSet* const> inputs) {
257   CHECK_GT(inputs.size(), 0);
258   for (int i = 1; i < inputs.size(); ++i) {
259     DCHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape()));
260   }
261   bool changed = false;
262   for (auto& pair : *this) {
263     const ShapeIndex& index = pair.first;
264     HloValueSet& value_set = pair.second;
265 
266     std::vector<const HloValueSet*> input_value_sets;
267     for (const InstructionValueSet* input : inputs) {
268       input_value_sets.push_back(&input->element(index));
269     }
270     changed |= value_set.AssignUnionOf(input_value_sets);
271   }
272 
273   return changed;
274 }
275 
operator <<(std::ostream & out,const InstructionValueSet & instruction_value_set)276 std::ostream& operator<<(std::ostream& out,
277                          const InstructionValueSet& instruction_value_set) {
278   out << instruction_value_set.ToString();
279   return out;
280 }
281 
ToString() const282 string InstructionValueSet::ToString() const {
283   string out =
284       StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")\n");
285   ForEachElement([&out](const ShapeIndex& index, const HloValueSet& value_set) {
286     StrAppend(&out, "  ", index.ToString(), " : ", value_set.ToString(), "\n");
287   });
288   return out;
289 }
290 
291 }  // namespace xla
292