1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_
18 
19 #include <stddef.h>
20 
21 #include <string>
22 #include <vector>
23 
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/xla/service/buffer_value.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/shape_tree.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/macros.h"
33 #include "tensorflow/core/platform/types.h"
34 
35 namespace xla {
36 
37 // Abstraction which identifies a specific point in the XLA graph. An
38 // HloPosition specifies a ShapeIndex within the output of a specific
39 // instruction.
40 struct HloPosition {
41   HloInstruction* instruction;
42   ShapeIndex index;
43 
44   // Returns the shape at this position.
45   const Shape& shape() const;
46 
47   string ToString() const;
48 
49   bool operator==(const HloPosition& other) const {
50     return instruction == other.instruction && index == other.index;
51   }
52   bool operator!=(const HloPosition& other) const { return !(*this == other); }
53 
54   // Stable less-than operator using instruction id and index.
55   bool operator<(const HloPosition& other) const {
56     return instruction->unique_id() < other.instruction->unique_id() ||
57            (instruction->unique_id() == other.instruction->unique_id() &&
58             index < other.index);
59   }
60 
61   template <typename H>
AbslHashValueHloPosition62   friend H AbslHashValue(H h, const HloPosition& pos) {
63     return H::combine(std::move(h), pos.instruction->Hash(), pos.index);
64   }
65 };
66 
67 std::ostream& operator<<(std::ostream& out, const HloPosition& position);
68 
69 // Defines a single use of an HLO value.
70 struct HloUse {
71   // Instruction at which the value is used.
72   HloInstruction* instruction;
73 
74   // The operand number in which the value is appears.
75   int64 operand_number;
76 
77   // The shape index within the operand in which the value appears.
78   ShapeIndex operand_index;
79 
80   string ToString() const;
81 
82   bool operator==(const HloUse& other) const {
83     return instruction == other.instruction &&
84            operand_number == other.operand_number &&
85            operand_index == other.operand_index;
86   }
87 
88   bool operator!=(const HloUse& other) const { return !(*this == other); }
89 
90   template <typename H>
AbslHashValueHloUse91   friend H AbslHashValue(H h, const HloUse& use) {
92     return H::combine(std::move(h), use.instruction, use.operand_index,
93                       use.operand_number);
94   }
95 };
96 
97 std::ostream& operator<<(std::ostream& out, const HloUse& use);
98 
99 // HloDataflowAnalysis uses this subclass of BufferValue.
100 class HloValue : public BufferValue {
101  public:
102   // Predicate comparing HloValues by increasing id, useful for std::sort.
IdLessThan(const HloValue * a,const HloValue * b)103   static bool IdLessThan(const HloValue* a, const HloValue* b) {
104     return a->id() < b->id();
105   }
106 
107   // Predicate comparing HloValues by equal id, useful for std::unique.
IdEqual(const HloValue * a,const HloValue * b)108   static bool IdEqual(const HloValue* a, const HloValue* b) {
109     return a->id() == b->id();
110   }
111 
112   // Construct an HloValue defined by 'instruction' at shape index 'index'. If
113   // is_phi is true, then this value is a phi value, for example, at the
114   // parameter of a while body computation. Phi values are only used in the SSA
115   // dataflow analysis (HloDataflowAnalysis::ssa_form_ is true).
116   HloValue(Id id, HloInstruction* instruction, const ShapeIndex& index,
117            bool is_phi = false);
~HloValue()118   ~HloValue() override {}
119 
120   // Sets the positions in the module at which the HloValue appears. Updates
121   // uses. Should be called once and only once. The defining position should not
122   // be included in 'positions' as this is set at construction time.
123   void SetPositionsAndComputeUses(absl::Span<const HloPosition> positions);
124 
125   // Returns whether this value is a phi value.
is_phi()126   bool is_phi() const { return is_phi_; }
127 
128   // Return the position where this value is defined.
defining_position()129   const HloPosition& defining_position() const { return positions_[0]; }
130 
131   // Return the instruction which defines this HloValue.
defining_instruction()132   HloInstruction* defining_instruction() const {
133     return defining_position().instruction;
134   }
135 
instruction()136   HloInstruction* instruction() const override {
137     return defining_instruction();
138   }
139 
140   // Return the shape index at which this HloValue is defined in the output of
141   // its defining instruction.
defining_index()142   const ShapeIndex& defining_index() const { return defining_position().index; }
143 
index()144   const ShapeIndex& index() const override { return defining_index(); }
145 
146   // Return the shape of this HloValue.
shape()147   const Shape& shape() const override { return defining_position().shape(); }
148 
149   // Return all positions of the HloValue in the module.
positions()150   const std::vector<HloPosition>& positions() const { return positions_; }
151 
152   // Return all uses of the HloValue.
uses()153   const std::vector<HloUse>& uses() const { return uses_; }
154 
155   // Get whether this HloValue is live out of the module.
live_out_of_module()156   bool live_out_of_module() const { return live_out_of_module_; }
157 
158   bool operator==(const HloValue& other) const;
159   bool operator!=(const HloValue& other) const;
160 
161   // Return a single-line string representation of the value.
162   string ToShortString() const;
163 
164   string ToString(int indent) const;
165 
ToString()166   string ToString() const override { return ToString(0); }
167 
168  private:
169   // Whether this instruction is a phi value.
170   const bool is_phi_;
171 
172   // The set of positions of this HloValue. The first element is always the
173   // position of the definition.
174   std::vector<HloPosition> positions_;
175 
176   // The set of uses of this HloValue.
177   std::vector<HloUse> uses_;
178 
179   // Whether this value is live out of the HLO module.
180   bool live_out_of_module_ = false;
181 };
182 
183 std::ostream& operator<<(std::ostream& out, const HloValue& hlo_value);
184 
185 // A class representing the possible set of HloValues at a particular point
186 // (shape index in the output of an instruction) in the XLA graph. This set
187 // contains the set of reaching HloValue definitions. For a simple array-shaped
188 // instruction like Add, the HloValueSet of the top-level of the instruction's
189 // output trivially contains only the HloValue defined by the instruction. For
190 // instructions which have non-trivial dataflow such as Tuple or Select, the
191 // HloValueSets of the instruction's output contains one or more HloValues
192 // defined by the instruction's operands or defined further up in the XLA graph.
193 class HloValueSet {
194  public:
195   HloValueSet() = default;
196 
HloValueSet(absl::Span<const HloValue * const> values)197   explicit HloValueSet(absl::Span<const HloValue* const> values)
198       : values_(values.begin(), values.end()) {
199     SortAndUniquifyValues();
200   }
201 
202   // Sets this value set to the union of the given value sets. Returns whether
203   // this value set changed.
204   bool AssignUnionOf(absl::Span<const HloValueSet* const> inputs);
205 
206   // Return the vector of HloValues in the set. Values in the vector are unique
207   // and stably sorted by value id.
values()208   const std::vector<const HloValue*>& values() const { return values_; }
209 
210   // Adds the value to the set.  Returns true iff the value was added and didn't
211   // already exist in the set.
212   bool AddValue(const HloValue* value);
213 
214   // Clear all values from the set.
Clear()215   void Clear() { values_.clear(); }
216 
217   // Return the unique HLO value in the set. CHECKs if the set does not contain
218   // exactly one value.
GetUniqueValue()219   const HloValue& GetUniqueValue() const {
220     CHECK_EQ(values_.size(), 1);
221     return *values_[0];
222   }
223 
224   bool operator==(const HloValueSet& other) const {
225     if (values_.size() != other.values_.size()) return false;
226     for (size_t i = 0; i < values_.size(); ++i) {
227       if (values_[i]->id() != other.values_[i]->id()) {
228         return false;
229       }
230     }
231     return true;
232   }
233   bool operator!=(const HloValueSet& other) const { return !(*this == other); }
234 
235   string ToString() const;
236 
237  private:
238   // Sorts value_ and removes duplicates. This should be called after adding any
239   // elements to values_.
240   void SortAndUniquifyValues();
241 
242   // HloValues sorted by HloValue::Id.
243   std::vector<const HloValue*> values_;
244 };
245 
246 std::ostream& operator<<(std::ostream& out, const HloValueSet& hlo_value);
247 
248 // A class collecting the HloValues which might be contained in the output of
249 // an HLO instruction. For array-shaped instructions, an InstructionValueSet
250 // trivially holds a single HloValueSet. Tuple-shaped InstructionValueSets
251 // hold multiple HloValueSets.
252 class InstructionValueSet : public ShapeTree<HloValueSet> {
253  public:
InstructionValueSet(const Shape & shape)254   explicit InstructionValueSet(const Shape& shape)
255       : ShapeTree<HloValueSet>(shape) {}
256 
257   // Sets this value set to the union of the given value sets. Returns whether
258   // this value set changed.
259   bool AssignUnionOf(absl::Span<const InstructionValueSet* const> inputs);
260 
261   // Returns true if any value sets for any subshape element is not a
262   // singleton.
263   bool IsAmbiguous() const;
264 
265   string ToString() const;
266 };
267 
268 std::ostream& operator<<(std::ostream& out,
269                          const InstructionValueSet& instruction_value_set);
270 
271 }  // namespace xla
272 
273 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_
274