1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_
18 
19 #include <type_traits>
20 #include <vector>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/strings/string_view.h"
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/status.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/platform/macros.h"
32 #include "tensorflow/core/platform/types.h"
33 
34 namespace xla {
35 
36 class HloComputation;
37 class HloInstruction;
38 
39 // A postorder depth-first HloInstruction visitor. When Handle* is called on an
40 // instruction, all its operands were already visited. User code can subclass
41 // this to iterate over an HloInstruction DAG. The Handle* routines have
42 // operands / data unpacked for ease of use in the visitor subclass.
43 //
44 // No instruction will ever be visited twice; however, the root instruction will
45 // be reported again when the traversal is done via a call to FinishVisit.
46 //
// In addition to every pure-virtual Handle* method and FinishVisit, a subclass
// must override
// (either HandleElementwiseUnary or all the Handle methods for unary ops) and
// (either HandleElementwiseBinary or all the Handle methods for binary ops).
// The default Handle methods for unary ops call HandleElementwiseUnary, and
// the default Handle methods for binary ops call HandleElementwiseBinary.
// The default HandleElementwiseUnary and HandleElementwiseBinary return an
// "unimplemented" error status.
54 //
55 // Note: this may change to an iterator in the future for flexibility purposes.
56 //
57 // Users should not use this class directly, but use the type-aliases
58 // DfsHloVisitor/ConstDfsHloVisitor instead.
59 template <typename HloInstructionPtr>
60 class DfsHloVisitorBase {
61   static_assert(
62       std::is_same<HloInstruction*, HloInstructionPtr>::value ||
63           std::is_same<const HloInstruction*, HloInstructionPtr>::value,
64       "Template argument expected to be HloInstruction* or const "
65       "HloInstruction*");
66 
67  public:
DfsHloVisitorBase()68   DfsHloVisitorBase() {}
~DfsHloVisitorBase()69   virtual ~DfsHloVisitorBase() {}
70 
71   // These routines are self-descriptive, see class comment for usage
72   // information.
73 
74   virtual Status HandleElementwiseUnary(HloInstructionPtr hlo);
75   virtual Status HandleElementwiseBinary(HloInstructionPtr hlo);
76 
77   virtual Status HandleClamp(HloInstructionPtr hlo) = 0;
78   virtual Status HandleSelect(HloInstructionPtr hlo) = 0;
79   virtual Status HandleTupleSelect(HloInstructionPtr hlo) = 0;
HandleMaximum(HloInstructionPtr hlo)80   virtual Status HandleMaximum(HloInstructionPtr hlo) {
81     return HandleElementwiseBinary(hlo);
82   }
HandleMinimum(HloInstructionPtr hlo)83   virtual Status HandleMinimum(HloInstructionPtr hlo) {
84     return HandleElementwiseBinary(hlo);
85   }
86   virtual Status HandleConcatenate(HloInstructionPtr hlo) = 0;
HandleConvert(HloInstructionPtr hlo)87   virtual Status HandleConvert(HloInstructionPtr hlo) {
88     return HandleElementwiseUnary(hlo);
89   }
HandleBitcastConvert(HloInstructionPtr hlo)90   virtual Status HandleBitcastConvert(HloInstructionPtr hlo) {
91     return HandleElementwiseUnary(hlo);
92   }
HandleCopy(HloInstructionPtr hlo)93   virtual Status HandleCopy(HloInstructionPtr hlo) {
94     return HandleElementwiseUnary(hlo);
95   }
HandleComplex(HloInstructionPtr hlo)96   virtual Status HandleComplex(HloInstructionPtr hlo) {
97     return HandleElementwiseBinary(hlo);
98   }
HandleMultiply(HloInstructionPtr hlo)99   virtual Status HandleMultiply(HloInstructionPtr hlo) {
100     return HandleElementwiseBinary(hlo);
101   }
102   virtual Status HandleDot(HloInstructionPtr hlo) = 0;
HandlePower(HloInstructionPtr hlo)103   virtual Status HandlePower(HloInstructionPtr hlo) {
104     return HandleElementwiseBinary(hlo);
105   }
HandleSqrt(HloInstructionPtr hlo)106   virtual Status HandleSqrt(HloInstructionPtr hlo) {
107     return HandleElementwiseUnary(hlo);
108   }
HandleRsqrt(HloInstructionPtr hlo)109   virtual Status HandleRsqrt(HloInstructionPtr hlo) {
110     return HandleElementwiseUnary(hlo);
111   }
HandleCbrt(HloInstructionPtr hlo)112   virtual Status HandleCbrt(HloInstructionPtr hlo) {
113     return HandleElementwiseUnary(hlo);
114   }
115   virtual Status HandleConvolution(HloInstructionPtr hlo) = 0;
116   virtual Status HandleFft(HloInstructionPtr fft) = 0;
117   virtual Status HandleTriangularSolve(HloInstructionPtr hlo) = 0;
118   virtual Status HandleCholesky(HloInstructionPtr hlo) = 0;
119   virtual Status HandleAllGather(HloInstructionPtr hlo) = 0;
120   virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0;
121   virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0;
122   virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0;
123   virtual Status HandleCollectivePermuteStart(HloInstructionPtr hlo) = 0;
124   virtual Status HandleCollectivePermuteDone(HloInstructionPtr hlo) = 0;
125   virtual Status HandleReplicaId(HloInstructionPtr hlo) = 0;
126   virtual Status HandlePartitionId(HloInstructionPtr hlo) = 0;
127   virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0;
128   virtual Status HandleSetDimensionSize(HloInstructionPtr hlo) = 0;
HandleCompare(HloInstructionPtr hlo)129   virtual Status HandleCompare(HloInstructionPtr hlo) {
130     return HandleElementwiseBinary(hlo);
131   }
HandleAdd(HloInstructionPtr hlo)132   virtual Status HandleAdd(HloInstructionPtr hlo) {
133     return HandleElementwiseBinary(hlo);
134   }
HandleDivide(HloInstructionPtr hlo)135   virtual Status HandleDivide(HloInstructionPtr hlo) {
136     return HandleElementwiseBinary(hlo);
137   }
HandleRemainder(HloInstructionPtr hlo)138   virtual Status HandleRemainder(HloInstructionPtr hlo) {
139     return HandleElementwiseBinary(hlo);
140   }
HandleSubtract(HloInstructionPtr hlo)141   virtual Status HandleSubtract(HloInstructionPtr hlo) {
142     return HandleElementwiseBinary(hlo);
143   }
HandleAbs(HloInstructionPtr hlo)144   virtual Status HandleAbs(HloInstructionPtr hlo) {
145     return HandleElementwiseUnary(hlo);
146   }
HandleAtan2(HloInstructionPtr hlo)147   virtual Status HandleAtan2(HloInstructionPtr hlo) {
148     return HandleElementwiseBinary(hlo);
149   }
HandleRound(HloInstructionPtr hlo)150   virtual Status HandleRound(HloInstructionPtr hlo) {
151     return HandleElementwiseUnary(hlo);
152   }
HandleLogistic(HloInstructionPtr hlo)153   virtual Status HandleLogistic(HloInstructionPtr hlo) {
154     return HandleElementwiseUnary(hlo);
155   }
HandleSign(HloInstructionPtr hlo)156   virtual Status HandleSign(HloInstructionPtr hlo) {
157     return HandleElementwiseUnary(hlo);
158   }
HandleNegate(HloInstructionPtr hlo)159   virtual Status HandleNegate(HloInstructionPtr hlo) {
160     return HandleElementwiseUnary(hlo);
161   }
HandleExp(HloInstructionPtr hlo)162   virtual Status HandleExp(HloInstructionPtr hlo) {
163     return HandleElementwiseUnary(hlo);
164   }
HandleExpm1(HloInstructionPtr hlo)165   virtual Status HandleExpm1(HloInstructionPtr hlo) {
166     return HandleElementwiseUnary(hlo);
167   }
HandleFloor(HloInstructionPtr hlo)168   virtual Status HandleFloor(HloInstructionPtr hlo) {
169     return HandleElementwiseUnary(hlo);
170   }
HandleCeil(HloInstructionPtr hlo)171   virtual Status HandleCeil(HloInstructionPtr hlo) {
172     return HandleElementwiseUnary(hlo);
173   }
HandleLog(HloInstructionPtr hlo)174   virtual Status HandleLog(HloInstructionPtr hlo) {
175     return HandleElementwiseUnary(hlo);
176   }
HandleClz(HloInstructionPtr hlo)177   virtual Status HandleClz(HloInstructionPtr hlo) {
178     return HandleElementwiseUnary(hlo);
179   }
HandleLog1p(HloInstructionPtr hlo)180   virtual Status HandleLog1p(HloInstructionPtr hlo) {
181     return HandleElementwiseUnary(hlo);
182   }
HandleCos(HloInstructionPtr hlo)183   virtual Status HandleCos(HloInstructionPtr hlo) {
184     return HandleElementwiseUnary(hlo);
185   }
HandleSin(HloInstructionPtr hlo)186   virtual Status HandleSin(HloInstructionPtr hlo) {
187     return HandleElementwiseUnary(hlo);
188   }
HandleTanh(HloInstructionPtr hlo)189   virtual Status HandleTanh(HloInstructionPtr hlo) {
190     return HandleElementwiseUnary(hlo);
191   }
HandleReal(HloInstructionPtr hlo)192   virtual Status HandleReal(HloInstructionPtr hlo) {
193     return HandleElementwiseUnary(hlo);
194   }
HandleImag(HloInstructionPtr hlo)195   virtual Status HandleImag(HloInstructionPtr hlo) {
196     return HandleElementwiseUnary(hlo);
197   }
HandleIsFinite(HloInstructionPtr hlo)198   virtual Status HandleIsFinite(HloInstructionPtr hlo) {
199     return HandleElementwiseUnary(hlo);
200   }
HandleAnd(HloInstructionPtr hlo)201   virtual Status HandleAnd(HloInstructionPtr hlo) {
202     return HandleElementwiseBinary(hlo);
203   }
HandleNot(HloInstructionPtr hlo)204   virtual Status HandleNot(HloInstructionPtr hlo) {
205     return HandleElementwiseUnary(hlo);
206   }
HandleOr(HloInstructionPtr hlo)207   virtual Status HandleOr(HloInstructionPtr hlo) {
208     return HandleElementwiseBinary(hlo);
209   }
HandleXor(HloInstructionPtr hlo)210   virtual Status HandleXor(HloInstructionPtr hlo) {
211     return HandleElementwiseBinary(hlo);
212   }
HandlePopulationCount(HloInstructionPtr hlo)213   virtual Status HandlePopulationCount(HloInstructionPtr hlo) {
214     return HandleElementwiseUnary(hlo);
215   }
HandleShiftLeft(HloInstructionPtr hlo)216   virtual Status HandleShiftLeft(HloInstructionPtr hlo) {
217     return HandleElementwiseBinary(hlo);
218   }
HandleShiftRightArithmetic(HloInstructionPtr hlo)219   virtual Status HandleShiftRightArithmetic(HloInstructionPtr hlo) {
220     return HandleElementwiseBinary(hlo);
221   }
HandleShiftRightLogical(HloInstructionPtr hlo)222   virtual Status HandleShiftRightLogical(HloInstructionPtr hlo) {
223     return HandleElementwiseBinary(hlo);
224   }
225 
HandleReducePrecision(HloInstructionPtr hlo)226   virtual Status HandleReducePrecision(HloInstructionPtr hlo) {
227     return HandleElementwiseUnary(hlo);
228   }
229 
HandleDomain(HloInstructionPtr hlo)230   virtual Status HandleDomain(HloInstructionPtr hlo) {
231     return HandleElementwiseUnary(hlo);
232   }
233 
234   virtual Status HandleInfeed(HloInstructionPtr hlo) = 0;
235   virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0;
236   virtual Status HandleRng(HloInstructionPtr hlo) = 0;
237   virtual Status HandleRngBitGenerator(HloInstructionPtr hlo) = 0;
238   virtual Status HandleRngGetAndUpdateState(HloInstructionPtr hlo) = 0;
239   virtual Status HandleReverse(HloInstructionPtr hlo) = 0;
240   virtual Status HandleSort(HloInstructionPtr hlo) = 0;
241   virtual Status HandleConstant(HloInstructionPtr hlo) = 0;
242   virtual Status HandleIota(HloInstructionPtr hlo) = 0;
243   virtual Status HandleGetTupleElement(HloInstructionPtr hlo) = 0;
244   virtual Status HandleReduce(HloInstructionPtr hlo) = 0;
245   virtual Status HandleBitcast(HloInstructionPtr hlo) = 0;
246   virtual Status HandleBroadcast(HloInstructionPtr hlo) = 0;
247   virtual Status HandleReshape(HloInstructionPtr hlo) = 0;
248   virtual Status HandleDynamicReshape(HloInstructionPtr hlo) = 0;
249   virtual Status HandleTranspose(HloInstructionPtr hlo) = 0;
250   virtual Status HandleParameter(HloInstructionPtr hlo) = 0;
251   virtual Status HandleFusion(HloInstructionPtr hlo) = 0;
252   virtual Status HandleCall(HloInstructionPtr hlo) = 0;
253   virtual Status HandleCustomCall(HloInstructionPtr hlo) = 0;
254   virtual Status HandleSlice(HloInstructionPtr hlo) = 0;
255   virtual Status HandleDynamicSlice(HloInstructionPtr hlo) = 0;
256   virtual Status HandleDynamicUpdateSlice(HloInstructionPtr hlo) = 0;
257   virtual Status HandleTuple(HloInstructionPtr hlo) = 0;
258   virtual Status HandleMap(HloInstructionPtr hlo) = 0;
259   virtual Status HandleReduceWindow(HloInstructionPtr hlo) = 0;
260   virtual Status HandleSelectAndScatter(HloInstructionPtr hlo) = 0;
261   virtual Status HandleWhile(HloInstructionPtr hlo) = 0;
262   virtual Status HandleConditional(HloInstructionPtr hlo) = 0;
263   virtual Status HandleGather(HloInstructionPtr hlo) = 0;
264   virtual Status HandleScatter(HloInstructionPtr hlo) = 0;
265 
266   virtual Status HandlePad(HloInstructionPtr hlo) = 0;
267 
268   virtual Status HandleCopyStart(HloInstructionPtr copy_start) = 0;
269   virtual Status HandleCopyDone(HloInstructionPtr copy_done) = 0;
270 
271   virtual Status HandleSend(HloInstructionPtr send) = 0;
272   virtual Status HandleSendDone(HloInstructionPtr send_done) = 0;
273 
274   virtual Status HandleRecv(HloInstructionPtr recv) = 0;
275   virtual Status HandleRecvDone(HloInstructionPtr recv_done) = 0;
276 
277   virtual Status HandleBatchNormTraining(HloInstructionPtr hlo) = 0;
278 
279   virtual Status HandleBatchNormInference(HloInstructionPtr hlo) = 0;
280 
281   virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0;
282 
283   virtual Status HandleAddDependency(HloInstructionPtr add_dependency) = 0;
284   virtual Status HandleAfterAll(HloInstructionPtr token) = 0;
285 
286   // Invoked to inform the visitor that the traversal has completed, and that
287   // the root was "root".
288   virtual Status FinishVisit(HloInstructionPtr root) = 0;
289 
290   // 3 possible visitation states of HLO instructions. Each instruction's
291   // state only flows one way: kNotVisited -> kVisiting -> kVisited.
292   enum VisitState {
293     kNotVisited = 0,
294     kVisiting = 1,
295     kVisited = 2,
296   };
297 
GetVisitState(int id)298   VisitState GetVisitState(int id) {
299     auto iter = visit_state_.find(id);
300     if (iter == visit_state_.end()) {
301       return VisitState::kNotVisited;
302     }
303     return iter->second;
304   }
305   VisitState GetVisitState(const HloInstruction& instruction);
306 
307   // Resize internal state if necessary to hold state for ids <= num.
308   // This call is purely a performance hint and can be omitted without
309   // affecting correctness.
ReserveVisitStates(int num)310   void ReserveVisitStates(int num) { visit_state_.reserve(num); }
VisitStateCapacity()311   size_t VisitStateCapacity() const { return visit_state_.capacity(); }
312 
313   // Useful when we want to visit the same computation more than once with the
314   // same visitor.
ResetVisitStates()315   void ResetVisitStates() {
316     // Clear the map, but don't resize the capacity across uses -- Calculating
317     // and reserving space could be expensive, and we always use the same
318     // module->instruction_count() as the capacity.
319     visit_state_.erase(visit_state_.begin(), visit_state_.end());
320   }
321 
322   // Useful when we want to free up the memory used by the visit state without
323   // destroying the actual visitor subclass.
DestroyVisitState()324   void DestroyVisitState() {
325     visit_state_ = absl::flat_hash_map<int, VisitState>{};
326   }
327 
SetVisitState(int id,VisitState state)328   void SetVisitState(int id, VisitState state) { visit_state_[id] = state; }
329 
330   // Sets the visitation state of the given instruction as kVisiting.
331   //
332   // Precondition: current state must be kNotVisited.
333   void SetVisiting(const HloInstruction& instruction);
334 
335   // Sets the visitation state of the given instruction as kVisited.
336   //
337   // Precondition: current state must be either kNotVisited or kVisiting.
338   void SetVisited(const HloInstruction& instruction);
339 
340   // Returns whether the state of the given instruction is kVisiting.
IsVisiting(const HloInstruction & instruction)341   bool IsVisiting(const HloInstruction& instruction) {
342     return GetVisitState(instruction) == kVisiting;
343   }
344 
345   // Returns whether the state of the given instruction is kVisited.
DidVisit(const HloInstruction & instruction)346   bool DidVisit(const HloInstruction& instruction) {
347     return GetVisitState(instruction) == kVisited;
348   }
349 
350   // Returns whether the state of the given instruction is kNotVisited.
NotVisited(const HloInstruction & instruction)351   bool NotVisited(const HloInstruction& instruction) {
352     return GetVisitState(instruction) == kNotVisited;
353   }
354 
355   // This method should be overridden by subclasses that wish to run some
356   // operation on an op before its Handle* visitor method is called.
357   //
358   // For any HLO op, the order of calls is:
359   //
360   //   Preprocess(op);
361   //   Handle/OpType/(op);
362   //   Postprocess(op);
363   //
364   // Overriding methods should call DfsHloVisitor::Preprocess before doing their
365   // own preprocessing.
366   virtual Status Preprocess(HloInstructionPtr hlo);
367 
368   // This method should be overridden by subclasses that wish to run some
369   // operation on an op after its Handle* visitor method is called. See
370   // Preprocess for more details.
371   //
372   // Overriding methods should call DfsHloVisitor::Postprocess after doing their
373   // own postprocessing.
374   virtual Status Postprocess(HloInstructionPtr hlo);
375 
376  private:
377   absl::flat_hash_map<int, VisitState> visit_state_;
378 
379   TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorBase);
380 };
381 
382 // Explicit instantiations in dfs_hlo_visitor.cc.
383 extern template class DfsHloVisitorBase<HloInstruction*>;
384 extern template class DfsHloVisitorBase<const HloInstruction*>;
385 
386 // Users should use one of these two type aliases, which are the only two valid
387 // instantiations of DfsHloVisitorBase.
388 using DfsHloVisitor = DfsHloVisitorBase<HloInstruction*>;
389 using ConstDfsHloVisitor = DfsHloVisitorBase<const HloInstruction*>;
390 
391 }  // namespace xla
392 
393 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_
394