1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_
18 
19 #include <type_traits>
20 #include <vector>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/strings/string_view.h"
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/status.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/platform/macros.h"
32 #include "tensorflow/core/platform/types.h"
33 
34 namespace xla {
35 
36 class HloComputation;
37 class HloInstruction;
38 
39 // A postorder depth-first HloInstruction visitor. When Handle* is called on an
40 // instruction, all its operands were already visited. User code can subclass
41 // this to iterate over an HloInstruction DAG. The Handle* routines have
42 // operands / data unpacked for ease of use in the visitor subclass.
43 //
44 // No instruction will ever be visited twice; however, the root instruction will
45 // be reported again when the traversal is done via a call to FinishVisit.
46 //
47 // A subclass must override at least
48 // (either HandleElementwiseUnary or all the Handle methods for unary ops) and
49 // (either HandleElementwiseBinary or all the Handle methods for binary ops)).
50 // The default Handle methods for (unary, binary) ops call
51 // (HandleElementwiseUnary, HandleElementwiseBinary).
52 // The default (HandleElementwiseUnary, HandleElementwiseBinary) return an
53 // "unimplemented" error status.
54 //
55 // Note: this may change to an iterator in the future for flexibility purposes.
56 //
57 // Users should not use this class directly, but use the type-aliases
58 // DfsHloVisitor/ConstDfsHloVisitor instead.
59 template <typename HloInstructionPtr>
60 class DfsHloVisitorBase {
61   static_assert(
62       std::is_same<HloInstruction*, HloInstructionPtr>::value ||
63           std::is_same<const HloInstruction*, HloInstructionPtr>::value,
64       "Template argument expected to be HloInstruction* or const "
65       "HloInstruction*");
66 
67  public:
DfsHloVisitorBase()68   DfsHloVisitorBase() {}
~DfsHloVisitorBase()69   virtual ~DfsHloVisitorBase() {}
70 
71   // These routines are self-descriptive, see class comment for usage
72   // information.
73 
74   virtual Status HandleElementwiseUnary(HloInstructionPtr hlo);
75   virtual Status HandleElementwiseBinary(HloInstructionPtr hlo);
76 
77   virtual Status HandleClamp(HloInstructionPtr hlo) = 0;
78   virtual Status HandleSelect(HloInstructionPtr hlo) = 0;
79   virtual Status HandleTupleSelect(HloInstructionPtr hlo) = 0;
HandleMaximum(HloInstructionPtr hlo)80   virtual Status HandleMaximum(HloInstructionPtr hlo) {
81     return HandleElementwiseBinary(hlo);
82   }
HandleMinimum(HloInstructionPtr hlo)83   virtual Status HandleMinimum(HloInstructionPtr hlo) {
84     return HandleElementwiseBinary(hlo);
85   }
86   virtual Status HandleConcatenate(HloInstructionPtr hlo) = 0;
HandleConvert(HloInstructionPtr hlo)87   virtual Status HandleConvert(HloInstructionPtr hlo) {
88     return HandleElementwiseUnary(hlo);
89   }
HandleBitcastConvert(HloInstructionPtr hlo)90   virtual Status HandleBitcastConvert(HloInstructionPtr hlo) {
91     return HandleElementwiseUnary(hlo);
92   }
HandleCopy(HloInstructionPtr hlo)93   virtual Status HandleCopy(HloInstructionPtr hlo) {
94     return HandleElementwiseUnary(hlo);
95   }
HandleComplex(HloInstructionPtr hlo)96   virtual Status HandleComplex(HloInstructionPtr hlo) {
97     return HandleElementwiseBinary(hlo);
98   }
HandleMultiply(HloInstructionPtr hlo)99   virtual Status HandleMultiply(HloInstructionPtr hlo) {
100     return HandleElementwiseBinary(hlo);
101   }
102   virtual Status HandleDot(HloInstructionPtr hlo) = 0;
HandlePower(HloInstructionPtr hlo)103   virtual Status HandlePower(HloInstructionPtr hlo) {
104     return HandleElementwiseBinary(hlo);
105   }
HandleSqrt(HloInstructionPtr hlo)106   virtual Status HandleSqrt(HloInstructionPtr hlo) {
107     return HandleElementwiseUnary(hlo);
108   }
HandleRsqrt(HloInstructionPtr hlo)109   virtual Status HandleRsqrt(HloInstructionPtr hlo) {
110     return HandleElementwiseUnary(hlo);
111   }
112   virtual Status HandleConvolution(HloInstructionPtr hlo) = 0;
113   virtual Status HandleFft(HloInstructionPtr fft) = 0;
114   virtual Status HandleTriangularSolve(HloInstructionPtr hlo) = 0;
115   virtual Status HandleCholesky(HloInstructionPtr hlo) = 0;
116   virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0;
117   virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0;
118   virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0;
119   virtual Status HandleReplicaId(HloInstructionPtr hlo) = 0;
120   virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0;
HandleCompare(HloInstructionPtr hlo)121   virtual Status HandleCompare(HloInstructionPtr hlo) {
122     return HandleElementwiseBinary(hlo);
123   }
HandleAdd(HloInstructionPtr hlo)124   virtual Status HandleAdd(HloInstructionPtr hlo) {
125     return HandleElementwiseBinary(hlo);
126   }
HandleDivide(HloInstructionPtr hlo)127   virtual Status HandleDivide(HloInstructionPtr hlo) {
128     return HandleElementwiseBinary(hlo);
129   }
HandleRemainder(HloInstructionPtr hlo)130   virtual Status HandleRemainder(HloInstructionPtr hlo) {
131     return HandleElementwiseBinary(hlo);
132   }
HandleSubtract(HloInstructionPtr hlo)133   virtual Status HandleSubtract(HloInstructionPtr hlo) {
134     return HandleElementwiseBinary(hlo);
135   }
HandleAbs(HloInstructionPtr hlo)136   virtual Status HandleAbs(HloInstructionPtr hlo) {
137     return HandleElementwiseUnary(hlo);
138   }
HandleAtan2(HloInstructionPtr hlo)139   virtual Status HandleAtan2(HloInstructionPtr hlo) {
140     return HandleElementwiseBinary(hlo);
141   }
HandleRound(HloInstructionPtr hlo)142   virtual Status HandleRound(HloInstructionPtr hlo) {
143     return HandleElementwiseUnary(hlo);
144   }
HandleSign(HloInstructionPtr hlo)145   virtual Status HandleSign(HloInstructionPtr hlo) {
146     return HandleElementwiseUnary(hlo);
147   }
HandleNegate(HloInstructionPtr hlo)148   virtual Status HandleNegate(HloInstructionPtr hlo) {
149     return HandleElementwiseUnary(hlo);
150   }
HandleExp(HloInstructionPtr hlo)151   virtual Status HandleExp(HloInstructionPtr hlo) {
152     return HandleElementwiseUnary(hlo);
153   }
HandleExpm1(HloInstructionPtr hlo)154   virtual Status HandleExpm1(HloInstructionPtr hlo) {
155     return HandleElementwiseUnary(hlo);
156   }
HandleFloor(HloInstructionPtr hlo)157   virtual Status HandleFloor(HloInstructionPtr hlo) {
158     return HandleElementwiseUnary(hlo);
159   }
HandleCeil(HloInstructionPtr hlo)160   virtual Status HandleCeil(HloInstructionPtr hlo) {
161     return HandleElementwiseUnary(hlo);
162   }
HandleLog(HloInstructionPtr hlo)163   virtual Status HandleLog(HloInstructionPtr hlo) {
164     return HandleElementwiseUnary(hlo);
165   }
HandleClz(HloInstructionPtr hlo)166   virtual Status HandleClz(HloInstructionPtr hlo) {
167     return HandleElementwiseUnary(hlo);
168   }
HandleLog1p(HloInstructionPtr hlo)169   virtual Status HandleLog1p(HloInstructionPtr hlo) {
170     return HandleElementwiseUnary(hlo);
171   }
HandleCos(HloInstructionPtr hlo)172   virtual Status HandleCos(HloInstructionPtr hlo) {
173     return HandleElementwiseUnary(hlo);
174   }
HandleSin(HloInstructionPtr hlo)175   virtual Status HandleSin(HloInstructionPtr hlo) {
176     return HandleElementwiseUnary(hlo);
177   }
HandleTanh(HloInstructionPtr hlo)178   virtual Status HandleTanh(HloInstructionPtr hlo) {
179     return HandleElementwiseUnary(hlo);
180   }
HandleReal(HloInstructionPtr hlo)181   virtual Status HandleReal(HloInstructionPtr hlo) {
182     return HandleElementwiseUnary(hlo);
183   }
HandleImag(HloInstructionPtr hlo)184   virtual Status HandleImag(HloInstructionPtr hlo) {
185     return HandleElementwiseUnary(hlo);
186   }
HandleIsFinite(HloInstructionPtr hlo)187   virtual Status HandleIsFinite(HloInstructionPtr hlo) {
188     return HandleElementwiseUnary(hlo);
189   }
HandleAnd(HloInstructionPtr hlo)190   virtual Status HandleAnd(HloInstructionPtr hlo) {
191     return HandleElementwiseBinary(hlo);
192   }
HandleNot(HloInstructionPtr hlo)193   virtual Status HandleNot(HloInstructionPtr hlo) {
194     return HandleElementwiseUnary(hlo);
195   }
HandleOr(HloInstructionPtr hlo)196   virtual Status HandleOr(HloInstructionPtr hlo) {
197     return HandleElementwiseBinary(hlo);
198   }
HandleXor(HloInstructionPtr hlo)199   virtual Status HandleXor(HloInstructionPtr hlo) {
200     return HandleElementwiseBinary(hlo);
201   }
HandleShiftLeft(HloInstructionPtr hlo)202   virtual Status HandleShiftLeft(HloInstructionPtr hlo) {
203     return HandleElementwiseBinary(hlo);
204   }
HandleShiftRightArithmetic(HloInstructionPtr hlo)205   virtual Status HandleShiftRightArithmetic(HloInstructionPtr hlo) {
206     return HandleElementwiseBinary(hlo);
207   }
HandleShiftRightLogical(HloInstructionPtr hlo)208   virtual Status HandleShiftRightLogical(HloInstructionPtr hlo) {
209     return HandleElementwiseBinary(hlo);
210   }
211 
HandleReducePrecision(HloInstructionPtr hlo)212   virtual Status HandleReducePrecision(HloInstructionPtr hlo) {
213     return HandleElementwiseUnary(hlo);
214   }
215 
HandleDomain(HloInstructionPtr hlo)216   virtual Status HandleDomain(HloInstructionPtr hlo) {
217     return HandleElementwiseUnary(hlo);
218   }
219 
220   virtual Status HandleInfeed(HloInstructionPtr hlo) = 0;
221   virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0;
222   virtual Status HandleRng(HloInstructionPtr hlo) = 0;
223   virtual Status HandleReverse(HloInstructionPtr hlo) = 0;
224   virtual Status HandleSort(HloInstructionPtr hlo) = 0;
225   virtual Status HandleConstant(HloInstructionPtr hlo) = 0;
226   virtual Status HandleIota(HloInstructionPtr hlo) = 0;
227   virtual Status HandleGetTupleElement(HloInstructionPtr hlo) = 0;
228   virtual Status HandleReduce(HloInstructionPtr hlo) = 0;
229   virtual Status HandleBitcast(HloInstructionPtr hlo) = 0;
230   virtual Status HandleBroadcast(HloInstructionPtr hlo) = 0;
231   virtual Status HandleReshape(HloInstructionPtr hlo) = 0;
232   virtual Status HandleTranspose(HloInstructionPtr hlo) = 0;
233   virtual Status HandleParameter(HloInstructionPtr hlo) = 0;
234   virtual Status HandleFusion(HloInstructionPtr hlo) = 0;
235   virtual Status HandleCall(HloInstructionPtr hlo) = 0;
236   virtual Status HandleCustomCall(HloInstructionPtr hlo) = 0;
237   virtual Status HandleSlice(HloInstructionPtr hlo) = 0;
238   virtual Status HandleDynamicSlice(HloInstructionPtr hlo) = 0;
239   virtual Status HandleDynamicUpdateSlice(HloInstructionPtr hlo) = 0;
240   virtual Status HandleTuple(HloInstructionPtr hlo) = 0;
241   virtual Status HandleMap(HloInstructionPtr hlo) = 0;
242   virtual Status HandleReduceWindow(HloInstructionPtr hlo) = 0;
243   virtual Status HandleSelectAndScatter(HloInstructionPtr hlo) = 0;
244   virtual Status HandleWhile(HloInstructionPtr hlo) = 0;
245   virtual Status HandleConditional(HloInstructionPtr hlo) = 0;
246   virtual Status HandleGather(HloInstructionPtr hlo) = 0;
247   virtual Status HandleScatter(HloInstructionPtr hlo) = 0;
248 
249   virtual Status HandlePad(HloInstructionPtr hlo) = 0;
250 
251   virtual Status HandleSend(HloInstructionPtr send) = 0;
252   virtual Status HandleSendDone(HloInstructionPtr send_done) = 0;
253 
254   virtual Status HandleRecv(HloInstructionPtr recv) = 0;
255   virtual Status HandleRecvDone(HloInstructionPtr recv_done) = 0;
256 
257   virtual Status HandleBatchNormTraining(HloInstructionPtr hlo) = 0;
258 
259   virtual Status HandleBatchNormInference(HloInstructionPtr hlo) = 0;
260 
261   virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0;
262 
263   virtual Status HandleAddDependency(HloInstructionPtr add_dependency) = 0;
264   virtual Status HandleAfterAll(HloInstructionPtr token) = 0;
265 
266   // Invoked to inform the visitor that the traversal has completed, and that
267   // the root was "root".
268   virtual Status FinishVisit(HloInstructionPtr root) = 0;
269 
270   // 3 possible visitation states of HLO instructions. Each instruction's
271   // state only flows one way: kNotVisited -> kVisiting -> kVisited.
272   enum VisitState {
273     kNotVisited = 0,
274     kVisiting = 1,
275     kVisited = 2,
276   };
277 
GetVisitState(int id)278   VisitState GetVisitState(int id) {
279     auto iter = visit_state_.find(id);
280     if (iter == visit_state_.end()) {
281       return VisitState::kNotVisited;
282     }
283     return iter->second;
284   }
285   VisitState GetVisitState(const HloInstruction& instruction);
286 
287   // Resize internal state if necessary to hold state for ids <= num.
288   // This call is purely a performance hint and can be omitted without
289   // affecting correctness.
ReserveVisitStates(int num)290   void ReserveVisitStates(int num) { visit_state_.reserve(num); }
291 
292   // Useful when we want to visit the same computation more than once with the
293   // same visitor.
ResetVisitStates()294   void ResetVisitStates() { visit_state_.clear(); }
295 
SetVisitState(int id,VisitState state)296   void SetVisitState(int id, VisitState state) { visit_state_[id] = state; }
297 
298   // Sets the visitation state of the given instruction as kVisiting.
299   //
300   // Precondition: current state must be kNotVisited.
301   void SetVisiting(const HloInstruction& instruction);
302 
303   // Sets the visitation state of the given instruction as kVisited.
304   //
305   // Precondition: current state must be either kNotVisited or kVisiting.
306   void SetVisited(const HloInstruction& instruction);
307 
308   // Returns whether the state of the given instruction is kVisiting.
IsVisiting(const HloInstruction & instruction)309   bool IsVisiting(const HloInstruction& instruction) {
310     return GetVisitState(instruction) == kVisiting;
311   }
312 
313   // Returns whether the state of the given instruction is kVisited.
DidVisit(const HloInstruction & instruction)314   bool DidVisit(const HloInstruction& instruction) {
315     return GetVisitState(instruction) == kVisited;
316   }
317 
318   // Returns whether the state of the given instruction is kNotVisited.
NotVisited(const HloInstruction & instruction)319   bool NotVisited(const HloInstruction& instruction) {
320     return GetVisitState(instruction) == kNotVisited;
321   }
322 
323   // This method should be overridden by subclasses that wish to run some
324   // operation on an op before its Handle* visitor method is called.
325   //
326   // For any HLO op, the order of calls is:
327   //
328   //   Preprocess(op);
329   //   Handle/OpType/(op);
330   //   Postprocess(op);
331   //
332   // Overriding methods should call DfsHloVisitor::Preprocess before doing their
333   // own preprocessing.
334   virtual Status Preprocess(HloInstructionPtr hlo);
335 
336   // This method should be overridden by subclasses that wish to run some
337   // operation on an op after its Handle* visitor method is called. See
338   // Preprocess for more details.
339   //
340   // Overriding methods should call DfsHloVisitor::Postprocess after doing their
341   // own postprocessing.
342   virtual Status Postprocess(HloInstructionPtr hlo);
343 
344  private:
345   absl::flat_hash_map<int, VisitState> visit_state_;
346 
347   TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorBase);
348 };
349 
350 // Users should use one of these two type aliases, which are the only two valid
351 // instantiations of DfsHloVisitorBase.
352 using DfsHloVisitor = DfsHloVisitorBase<HloInstruction*>;
353 using ConstDfsHloVisitor = DfsHloVisitorBase<const HloInstruction*>;
354 
355 }  // namespace xla
356 
357 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_
358