1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
18 
#include <functional>
#include <utility>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
29 
30 namespace xla {
31 
// Forward declarations: instructions are only named through pointers in this
// header, so the full definition is not needed here.
// NOTE(review): HloComputation is not referenced anywhere in this header.
class HloInstruction;
class HloComputation;
34 
35 // DfsHloVisitor with default action based on the HloInstruction being visited.
36 // Users should not use this class directly, but use the type aliases
37 // DfsHloVisitorWithDefault/ConstDfsHloVisitorWithDefault instead.
38 //
39 // Do *not* add an override to this class if the opcode is covered by
40 // HandleElementwiseUnary/Binary. These opcode handlers dispatch to
41 // HandleElementwiseUnary/Binary in DfsHloVisitorBase. Adding such a handler
42 // here will break passes which rely on the HandleElementwiseUnary/Binary
43 // handling these opcodes.
44 template <typename HloInstructionPtr>
45 class DfsHloVisitorWithDefaultBase
46     : public DfsHloVisitorBase<HloInstructionPtr> {
47  public:
DfsHloVisitorWithDefaultBase()48   DfsHloVisitorWithDefaultBase() {}
~DfsHloVisitorWithDefaultBase()49   ~DfsHloVisitorWithDefaultBase() override {}
50 
51   // Default action performed on HloInstruction.
52   virtual Status DefaultAction(HloInstructionPtr hlo_instruction) = 0;
53 
HandleElementwiseUnary(HloInstructionPtr hlo)54   Status HandleElementwiseUnary(HloInstructionPtr hlo) override {
55     return DefaultAction(hlo);
56   }
HandleElementwiseBinary(HloInstructionPtr hlo)57   Status HandleElementwiseBinary(HloInstructionPtr hlo) override {
58     return DefaultAction(hlo);
59   }
60 
HandleBatchNormTraining(HloInstructionPtr hlo)61   Status HandleBatchNormTraining(HloInstructionPtr hlo) override {
62     return DefaultAction(hlo);
63   }
64 
HandleBatchNormInference(HloInstructionPtr hlo)65   Status HandleBatchNormInference(HloInstructionPtr hlo) override {
66     return DefaultAction(hlo);
67   }
68 
HandleBatchNormGrad(HloInstructionPtr hlo)69   Status HandleBatchNormGrad(HloInstructionPtr hlo) override {
70     return DefaultAction(hlo);
71   }
72 
HandleClamp(HloInstructionPtr clamp)73   Status HandleClamp(HloInstructionPtr clamp) override {
74     return DefaultAction(clamp);
75   }
HandleConcatenate(HloInstructionPtr concatenate)76   Status HandleConcatenate(HloInstructionPtr concatenate) override {
77     return DefaultAction(concatenate);
78   }
HandleSelect(HloInstructionPtr select)79   Status HandleSelect(HloInstructionPtr select) override {
80     return DefaultAction(select);
81   }
HandleTupleSelect(HloInstructionPtr tuple_select)82   Status HandleTupleSelect(HloInstructionPtr tuple_select) override {
83     return DefaultAction(tuple_select);
84   }
HandleDot(HloInstructionPtr dot)85   Status HandleDot(HloInstructionPtr dot) override {
86     return DefaultAction(dot);
87   }
HandleConvolution(HloInstructionPtr convolution)88   Status HandleConvolution(HloInstructionPtr convolution) override {
89     return DefaultAction(convolution);
90   }
HandleFft(HloInstructionPtr fft)91   Status HandleFft(HloInstructionPtr fft) override {
92     return DefaultAction(fft);
93   }
HandleTriangularSolve(HloInstructionPtr hlo)94   Status HandleTriangularSolve(HloInstructionPtr hlo) override {
95     return DefaultAction(hlo);
96   }
HandleCholesky(HloInstructionPtr hlo)97   Status HandleCholesky(HloInstructionPtr hlo) override {
98     return DefaultAction(hlo);
99   }
HandleAllReduce(HloInstructionPtr crs)100   Status HandleAllReduce(HloInstructionPtr crs) override {
101     return DefaultAction(crs);
102   }
HandleAllToAll(HloInstructionPtr hlo)103   Status HandleAllToAll(HloInstructionPtr hlo) override {
104     return DefaultAction(hlo);
105   }
HandleCollectivePermute(HloInstructionPtr hlo)106   Status HandleCollectivePermute(HloInstructionPtr hlo) override {
107     return DefaultAction(hlo);
108   }
HandleReplicaId(HloInstructionPtr hlo)109   Status HandleReplicaId(HloInstructionPtr hlo) override {
110     return DefaultAction(hlo);
111   }
HandleRng(HloInstructionPtr random)112   Status HandleRng(HloInstructionPtr random) override {
113     return DefaultAction(random);
114   }
HandleInfeed(HloInstructionPtr infeed)115   Status HandleInfeed(HloInstructionPtr infeed) override {
116     return DefaultAction(infeed);
117   }
HandleOutfeed(HloInstructionPtr outfeed)118   Status HandleOutfeed(HloInstructionPtr outfeed) override {
119     return DefaultAction(outfeed);
120   }
HandleReverse(HloInstructionPtr reverse)121   Status HandleReverse(HloInstructionPtr reverse) override {
122     return DefaultAction(reverse);
123   }
HandleSort(HloInstructionPtr sort)124   Status HandleSort(HloInstructionPtr sort) override {
125     return DefaultAction(sort);
126   }
HandleConstant(HloInstructionPtr constant)127   Status HandleConstant(HloInstructionPtr constant) override {
128     return DefaultAction(constant);
129   }
HandleIota(HloInstructionPtr iota)130   Status HandleIota(HloInstructionPtr iota) override {
131     return DefaultAction(iota);
132   }
HandleGetTupleElement(HloInstructionPtr get_tuple_element)133   Status HandleGetTupleElement(HloInstructionPtr get_tuple_element) override {
134     return DefaultAction(get_tuple_element);
135   }
HandleParameter(HloInstructionPtr parameter)136   Status HandleParameter(HloInstructionPtr parameter) override {
137     return DefaultAction(parameter);
138   }
HandleFusion(HloInstructionPtr fusion)139   Status HandleFusion(HloInstructionPtr fusion) override {
140     return DefaultAction(fusion);
141   }
HandleCall(HloInstructionPtr call)142   Status HandleCall(HloInstructionPtr call) override {
143     return DefaultAction(call);
144   }
HandleCustomCall(HloInstructionPtr custom_call)145   Status HandleCustomCall(HloInstructionPtr custom_call) override {
146     return DefaultAction(custom_call);
147   }
HandleSlice(HloInstructionPtr slice)148   Status HandleSlice(HloInstructionPtr slice) override {
149     return DefaultAction(slice);
150   }
HandleDynamicSlice(HloInstructionPtr dynamic_slice)151   Status HandleDynamicSlice(HloInstructionPtr dynamic_slice) override {
152     return DefaultAction(dynamic_slice);
153   }
HandleDynamicUpdateSlice(HloInstructionPtr dynamic_update_slice)154   Status HandleDynamicUpdateSlice(
155       HloInstructionPtr dynamic_update_slice) override {
156     return DefaultAction(dynamic_update_slice);
157   }
HandleTuple(HloInstructionPtr tuple)158   Status HandleTuple(HloInstructionPtr tuple) override {
159     return DefaultAction(tuple);
160   }
HandleMap(HloInstructionPtr map)161   Status HandleMap(HloInstructionPtr map) override {
162     return DefaultAction(map);
163   }
HandleReduce(HloInstructionPtr reduce)164   Status HandleReduce(HloInstructionPtr reduce) override {
165     return DefaultAction(reduce);
166   }
HandleReduceWindow(HloInstructionPtr reduce_window)167   Status HandleReduceWindow(HloInstructionPtr reduce_window) override {
168     return DefaultAction(reduce_window);
169   }
HandleSelectAndScatter(HloInstructionPtr select_and_scatter)170   Status HandleSelectAndScatter(HloInstructionPtr select_and_scatter) override {
171     return DefaultAction(select_and_scatter);
172   }
HandleBitcast(HloInstructionPtr bitcast)173   Status HandleBitcast(HloInstructionPtr bitcast) override {
174     return DefaultAction(bitcast);
175   }
HandleBroadcast(HloInstructionPtr broadcast)176   Status HandleBroadcast(HloInstructionPtr broadcast) override {
177     return DefaultAction(broadcast);
178   }
HandlePad(HloInstructionPtr pad)179   Status HandlePad(HloInstructionPtr pad) override {
180     return DefaultAction(pad);
181   }
HandleReshape(HloInstructionPtr reshape)182   Status HandleReshape(HloInstructionPtr reshape) override {
183     return DefaultAction(reshape);
184   }
HandleTranspose(HloInstructionPtr transpose)185   Status HandleTranspose(HloInstructionPtr transpose) override {
186     return DefaultAction(transpose);
187   }
HandleWhile(HloInstructionPtr xla_while)188   Status HandleWhile(HloInstructionPtr xla_while) override {
189     return DefaultAction(xla_while);
190   }
HandleConditional(HloInstructionPtr conditional)191   Status HandleConditional(HloInstructionPtr conditional) override {
192     return DefaultAction(conditional);
193   }
HandleRecv(HloInstructionPtr recv)194   Status HandleRecv(HloInstructionPtr recv) override {
195     return DefaultAction(recv);
196   }
HandleRecvDone(HloInstructionPtr recv_done)197   Status HandleRecvDone(HloInstructionPtr recv_done) override {
198     return DefaultAction(recv_done);
199   }
HandleSend(HloInstructionPtr send)200   Status HandleSend(HloInstructionPtr send) override {
201     return DefaultAction(send);
202   }
HandleSendDone(HloInstructionPtr send_done)203   Status HandleSendDone(HloInstructionPtr send_done) override {
204     return DefaultAction(send_done);
205   }
HandleGather(HloInstructionPtr gather)206   Status HandleGather(HloInstructionPtr gather) override {
207     return DefaultAction(gather);
208   }
HandleScatter(HloInstructionPtr scatter)209   Status HandleScatter(HloInstructionPtr scatter) override {
210     return DefaultAction(scatter);
211   }
HandleAfterAll(HloInstructionPtr token)212   Status HandleAfterAll(HloInstructionPtr token) override {
213     return DefaultAction(token);
214   }
HandleGetDimensionSize(HloInstructionPtr get_size)215   Status HandleGetDimensionSize(HloInstructionPtr get_size) override {
216     return DefaultAction(get_size);
217   }
HandleAddDependency(HloInstructionPtr add_dependency)218   Status HandleAddDependency(HloInstructionPtr add_dependency) override {
219     return DefaultAction(add_dependency);
220   }
221 
222   // Invoked to inform the visitor that the traversal has completed, and that
223   // the root was "root".
FinishVisit(HloInstructionPtr)224   Status FinishVisit(HloInstructionPtr /*root*/) override {
225     return Status::OK();
226   }
227 
228  private:
229   TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorWithDefaultBase);
230 };
231 
232 // Users should use these type aliases which are only two valid instantiations.
233 using DfsHloVisitorWithDefault = DfsHloVisitorWithDefaultBase<HloInstruction*>;
234 using ConstDfsHloVisitorWithDefault =
235     DfsHloVisitorWithDefaultBase<const HloInstruction*>;
236 
237 // (Const)FunctionVisitor lets you transform an
238 // std::function<Status((const) HloInstruction*)> into a (Const)DfsHloVisitor.
239 //
240 // This is useful if you have code that needs to handle visitors in the form of
241 // both std::function and DfsHloVisitor.  You can wrap the function in a
242 // FunctionVisitor and then treat it like any other DfsHloVisitor.
243 template <typename HloInstructionPtr>
244 class FunctionVisitorBase
245     : public DfsHloVisitorWithDefaultBase<HloInstructionPtr> {
246  public:
FunctionVisitorBase(std::function<Status (HloInstructionPtr)> visitor_func)247   explicit FunctionVisitorBase(
248       std::function<Status(HloInstructionPtr)> visitor_func)
249       : visitor_func_(std::move(visitor_func)) {}
250 
DefaultAction(HloInstructionPtr hlo_instruction)251   Status DefaultAction(HloInstructionPtr hlo_instruction) override {
252     return visitor_func_(hlo_instruction);
253   }
254 
255  private:
256   TF_DISALLOW_COPY_AND_ASSIGN(FunctionVisitorBase);
257 
258   std::function<Status(HloInstructionPtr)> visitor_func_;
259 };
260 
261 using FunctionVisitor = FunctionVisitorBase<HloInstruction*>;
262 using ConstFunctionVisitor = FunctionVisitorBase<const HloInstruction*>;
263 
264 }  // namespace xla
265 
266 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
267