/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
// Analysis for determining the possible set of values for all positions
// (instructions and ShapeIndexes) in the HLO module. The analysis is
// module-scoped: it tracks values across computation boundaries.
19 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_

#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_phi_graph.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
43 
44 namespace xla {
45 
46 // Analysis which identifies all HLO values and their uses in an HLO module.
47 class HloDataflowAnalysis {
48  public:
49   // Infrastructure for passing may-alias hints: HLO passes can populate the
50   // may-alias table. If an empty optional is returned, default rules are used.
51   //
52   // Must-alias rules (as defined by GetInPlaceInputOutputPairs) cannot be
53   // overriden using backend-specific overrides.
54   //
55   // The first parameter of the function should be the instruction, the
56   // second parameter should be an operand of the instruction. The third
57   // parameter should be the output index of the instruction.
58   using CanShareBuffer = std::function<absl::optional<bool>(
59       const HloInstruction* instr, const HloInstruction* operand,
60       const ShapeIndex& user_index)>;
61 
62   // Runs dataflow analysis on the given module. Parameters:
63   //
64   //   ssa_form : If true then new values are defined at the merge points of
65   //     kWhile instructions. Abusing nomenclature somewhat, we call these "phi
66   //     values".  The merge is formed by the init value and loop backedge. The
67   //     SSA form is minimal in that a new phi value is defined only if the
68   //     merge point is reachable by multiple different values. The SSA form is
69   //     also in loop-closed form in that no values defined inside of a loop
70   //     (while body) is used outside of the loop. Example use of this ssa_form
71   //     mode is to reason about live range interference of buffers.
72   //
73   //     If ssa_form is false, then merge points do not define new
74   //     values. Rather, the HloValueSet for the merge point contains the union
75   //     of the merged HloValues.
76   //
77   //   bitcast_defines_value : If true then the Bitcast HLO instruction defines
78   //     a new HLO value in the analysis. If false then Bitcast forwards the
79   //     value of its operand.
80   static StatusOr<std::unique_ptr<HloDataflowAnalysis>> Run(
81       const HloModule& module, bool ssa_form = false,
82       bool bitcast_defines_value = false,
83       const CanShareBuffer& can_share_buffer = nullptr);
84 
85   static bool AreTransitiveUsesElementwiseOrTuple(const HloInstruction* inst);
86 
87   // Returns true if 'instruction' defines an HLO value at the given shape index
88   // of its output.
89   bool ValueIsDefinedAt(const HloInstruction* instruction,
90                         const ShapeIndex& index = {}) const;
91 
92   // Returns the HloValue defined by 'instruction' at the given shape index of
93   // its output.
94   //
95   // Precondition: ValueIsDefinedAt is true for this instruction and index.
96   const HloValue& GetValueDefinedAt(const HloInstruction* instruction,
97                                     const ShapeIndex& index = {}) const;
98   HloValue& GetValueDefinedAt(const HloInstruction* instruction,
99                               const ShapeIndex& index = {});
100 
101   // Returns the InstructionValueSet for the given instruction.
102   const InstructionValueSet& GetInstructionValueSet(
103       const HloInstruction* instruction) const;
104   InstructionValueSet& GetInstructionValueSet(
105       const HloInstruction* instruction);
106 
107   // Returns all values that are contained in the output of this instruction in
108   // a flattened set.
109   HloValueSet GetFlattenedValueSet(const HloInstruction* instruction) const;
110 
111   // Returns the HloValueSet for the given instruction at the given index or the
112   // given position.
113   const HloValueSet& GetValueSet(const HloInstruction* instruction,
114                                  const ShapeIndex& index = {}) const;
115   const HloValueSet& GetValueSet(const HloPosition& position) const;
116   HloValueSet& GetValueSet(const HloPosition& position);
117   HloValueSet& GetValueSet(const HloInstruction* instruction,
118                            const ShapeIndex& index = {});
119 
120   // Returns the unique value in the HloValueSet at the given instruction and
121   // shape index. CHECKs if the value set does not contain a exactly one value.
122   const HloValue& GetUniqueValueAt(const HloInstruction* instruction,
123                                    const ShapeIndex& index = {}) const {
124     return GetValueSet(instruction, index).GetUniqueValue();
125   }
126   HloValue& GetUniqueValueAt(const HloInstruction* instruction,
127                              const ShapeIndex& index = {}) {
128     return GetValue(GetValueSet(instruction, index).GetUniqueValue().id());
129   }
130 
131   // Returns the HloValue with the given Id.
132   const HloValue& GetValue(HloValue::Id value_id) const;
133   HloValue& GetValue(HloValue::Id value_id);
134 
135   // Returns the total number of HloValues.
value_count()136   int64 value_count() const { return values_.size(); }
137 
138   // Returns a vector of all HloValues stabily sorted by HloValue::Id.
values()139   const std::vector<HloValue*>& values() const { return values_vector_; }
140 
141   // Returns the call graph used for computing the dataflow.
call_graph()142   const CallGraph& call_graph() const { return *call_graph_; }
143 
144   string ToString() const;
145 
146   // Returns true if 'user' cannot possibly use the buffer at 'index' in
147   // 'operand'. Returns false otherwise.
148   //
149   // 'operand' does not have to be an operand of 'user'. This can be the
150   // case with indirect uses.
151   bool DoesNotUseOperandBuffer(const HloInstruction* operand,
152                                const ShapeIndex& index,
153                                const HloInstruction* user) const;
154 
155   // Returns true if 'user' (at 'user_index') can share a buffer with its
156   // operand 'operand' (at 'operand_index'). Returns false otherwise.
157   //
158   // REQUIRES: 'operand' is an operand of 'user'.
159   bool CanShareOperandBufferWithUser(HloInstruction* operand,
160                                      const ShapeIndex& operand_index,
161                                      HloInstruction* user,
162                                      const ShapeIndex& user_index) const;
163 
module()164   const HloModule& module() const { return module_; }
165 
166   // Returns true if the operation is an in-place operation and its operand 0
167   // must alias with the output.
168   static bool IsInPlaceOperation(HloOpcode opcode);
169 
170   // Returns a vector consisting of the HloUse (operand number and shape index)
171   // and output shape index of the in-place operations within this HLO.
172   static std::vector<std::pair<HloUse, ShapeIndex>> GetInPlaceInputOutputPairs(
173       HloInstruction* instruction);
174 
175  protected:
176   HloDataflowAnalysis(const HloModule& module, bool ssa_form,
177                       bool bitcast_defines_value = false,
178                       const CanShareBuffer& can_share_buffer = nullptr);
179 
180   // 1. During value propagation (Propagate function), always create phi
181   // values once it see multiple inputs merging at the same point. It then
182   // records those phi values as well as their inputs in a phi graph.
183   //
184   // 2. Post value propagation, Dataflow analysis can then do certain
185   // optimization(OptimizePhiValues) on the phi graph to prune uncessary phi
186   // nodes.
187   //
188   // Note that this applies in SSA form, and Both of the functions are
189   // guaranteed to exit.
190   //
191   void OptimizePhiValues();
192 
193   // Returns a new HloValue defined at the given instruction and shape index.
194   HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index,
195                         bool is_phi);
196 
197   // Marks the HloValue with the given ID for deletion.
198   void MarkValueForDeletion(HloValue::Id value_id);
199 
200   // Deletes all HloValues marked for deletion. Should be called after
201   // propagation is complete.
202   void DeleteMarkedValues();
203 
204   // Constructs and initializes the InstructionValueSets of all instructions to
205   // contain exactly the HloValues defined by each instruction. These values can
206   // then propagated throughout the HLO graph by calling Propagate.
207   Status InitializeInstructionValueSets();
208 
209   // Updates the value set of the given instruction based on the values flowing
210   // into the instruction (operands and cross-computation dataflow).
211   bool UpdateInstructionValueSet(HloInstruction* instruction);
212 
213   // Updates the value set for a particular instruction type. Returns whether
214   // the instruction value set changed.
215   bool UpdateBitcastValueSet(HloInstruction* bitcast);
216   bool UpdateCallValueSet(HloInstruction* call);
217   bool UpdateConditionalValueSet(HloInstruction* conditional);
218   bool UpdateCopyValueSet(HloInstruction* copy);
219   bool UpdateCustomCallValueSet(HloInstruction* custom_call);
220   bool UpdateDomainValueSet(HloInstruction* domain);
221   bool UpdateGetTupleElementValueSet(HloInstruction* gte);
222   bool UpdateParameterValueSet(HloInstruction* parameter);
223   bool UpdateCopyStartValueSet(HloInstruction* copy_start);
224   bool UpdateCopyDoneValueSet(HloInstruction* copy_done);
225   bool UpdateRecvDoneValueSet(HloInstruction* recv_done);
226   bool UpdateTupleSelectValueSet(HloInstruction* select);
227   bool UpdateSendValueSet(HloInstruction* send);
228   bool UpdateSetDimensionSizeValueSet(HloInstruction* set_dimension_size);
229   bool UpdateTupleValueSet(HloInstruction* tuple);
230   bool UpdateWhileValueSet(HloInstruction* xla_while);
231   bool UpdateAddDependencyValueSet(HloInstruction* add_dependency);
232   bool UpdateCollectivePermuteStartValueSet(
233       HloInstruction* collective_permute_start);
234   bool UpdateCollectivePermuteDoneValueSet(
235       HloInstruction* collective_permute_done);
236 
237   // Propagates the dataflow through the module. In particular, it propagates
238   // the HloValueSet from its defining instruction to the users of the
239   // instructions.
240   void Propagate();
241 
242   // Returns the result of the SSA Phi function applied to the given inputs at
243   // the given instruction.
244   bool Phi(HloInstruction* instruction,
245            absl::Span<const InstructionValueSet* const> inputs);
246 
247   // Updates the positions of the HloValues in the output of the given
248   // instruction. This should be called after the instruction value set of
249   // 'instruction' has been changed. 'prev_value_set' must point to the previous
250   // state of the value set prior to the change. 'prev_value_set' may be null if
251   // this is the first time positions are being computed. The previous state is
252   // necessary to efficiently remove positions which have been eliminated due to
253   // changes in the instructions' InstructionValueSet.
254   void UpdatePositionsOfValuesAt(
255       HloInstruction* instruction, const InstructionValueSet& new_value_set,
256       const InstructionValueSet* prev_value_set = nullptr);
257 
258   // Verifies various invariants of the dataflow analysis.
259   Status Verify() const;
260 
261   const HloModule& module_;
262   const bool ssa_form_;
263   const bool bitcast_defines_value_;
264 
265   std::unique_ptr<CallGraph> call_graph_;
266 
267   // The map of all HloValues in the module. We pass around pointers to the
268   // mapped HloValues, so the underlying container must keep them valid despite
269   // mutations touching other map entries.
270   std::unordered_map<HloValue::Id, HloValue> values_;
271 
272   // A map from instruction to InstructionValueSet.
273   std::unordered_map<const HloInstruction*, InstructionValueSet> value_sets_;
274 
275   // Values marked for deletion during construction. We don't delete them
276   // immediately because references to them may remain in ValueSets temporarily
277   // during propagation. After construction, these values are deleted.
278   std::vector<HloValue::Id> value_ids_to_delete_;
279 
280   // A vector containing all HloValues sorted by HloValue::Id.
281   std::vector<HloValue*> values_vector_;
282 
283   // The Id to use for the next HloValue.
284   HloValue::Id next_value_id_ = 0;
285 
286   // An explicit graph holding phi values and edges.
287   PhiGraph phi_graph_;
288 
289   // Backend specific function that decides whether an instruction can share
290   // a buffer with its operand.
291   CanShareBuffer can_share_buffer_ = nullptr;
292 };
293 
294 }  // namespace xla
295 
#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_