1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
18 
#include <functional>
#include <memory>
#include <utility>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
24 
25 namespace xla {
26 
27 // Visitor which verifies that the output shape is correctly set. Verifies
28 // against the inferred shape for the instruction.
29 // TODO(b/26024837): Check output shape for all instruction types.
30 class ShapeVerifier : public DfsHloVisitor {
31  public:
ShapeVerifier(bool layout_sensitive,bool allow_mixed_precision,std::function<int64 (const Shape &)> shape_size_function)32   ShapeVerifier(bool layout_sensitive, bool allow_mixed_precision,
33                 std::function<int64(const Shape&)> shape_size_function)
34       : layout_sensitive_(layout_sensitive),
35         allow_mixed_precision_(allow_mixed_precision),
36         shape_size_function_(shape_size_function) {}
37 
38   // Verifies that entry computation layout matches parameters and root shape of
39   // the module's entry computation.
40   virtual Status VerifyEntryComputationLayout(const HloModule& module);
41 
42   Status Preprocess(HloInstruction* hlo) override;
43 
44   Status HandleElementwiseUnary(HloInstruction* hlo) override;
45   Status HandleElementwiseBinary(HloInstruction* hlo) override;
46   Status HandleClamp(HloInstruction* clamp) override;
47   Status HandleSelect(HloInstruction* select) override;
48   Status HandleTupleSelect(HloInstruction* tuple_select) override;
49   Status HandleConcatenate(HloInstruction* concatenate) override;
50   Status HandleIota(HloInstruction* hlo) override;
51   Status HandleConvert(HloInstruction* convert) override;
52   Status HandleBitcastConvert(HloInstruction* convert) override;
53   Status HandleCopy(HloInstruction* copy) override;
54   Status HandleDot(HloInstruction* dot) override;
55   Status HandleConvolution(HloInstruction* convolution) override;
56   Status HandleFft(HloInstruction* fft) override;
57   Status HandleCholesky(HloInstruction* hlo) override;
58   Status HandleTriangularSolve(HloInstruction* hlo) override;
59   Status HandleAllGather(HloInstruction* hlo) override;
60   Status HandleAllReduce(HloInstruction* hlo) override;
61   Status HandleAllToAll(HloInstruction* hlo) override;
62   Status HandleCollectivePermute(HloInstruction* hlo) override;
63   Status HandleCollectivePermuteStart(HloInstruction* hlo) override;
64   Status HandleCollectivePermuteDone(HloInstruction* hlo) override;
65   Status HandlePartitionId(HloInstruction* hlo) override;
66   Status HandleReplicaId(HloInstruction* hlo) override;
67   Status HandleReducePrecision(HloInstruction* reduce_precision) override;
68   Status HandleInfeed(HloInstruction*) override;
69   Status HandleOutfeed(HloInstruction*) override;
70   Status HandleRng(HloInstruction*) override;
71   Status HandleRngBitGenerator(HloInstruction*) override;
72   Status HandleRngGetAndUpdateState(HloInstruction*) override;
73   Status HandleReverse(HloInstruction* reverse) override;
74   Status HandleSort(HloInstruction* sort) override;
75   Status HandleConstant(HloInstruction* constant) override;
76   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
77   Status HandleReduce(HloInstruction* reduce) override;
78   Status HandleBitcast(HloInstruction* bitcast) override;
79   Status HandleBroadcast(HloInstruction* broadcast) override;
80   Status HandleReshape(HloInstruction* reshape) override;
81   Status HandleDynamicReshape(HloInstruction* dynamic_reshape) override;
82   Status HandleTranspose(HloInstruction* transpose) override;
83   Status HandleParameter(HloInstruction*) override;
84   Status HandleFusion(HloInstruction*) override;
85   Status HandleCall(HloInstruction* call) override;
86   Status HandleCustomCall(HloInstruction*) override;
87   Status HandleSlice(HloInstruction* slice) override;
88   Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
89   Status HandleDynamicUpdateSlice(
90       HloInstruction* dynamic_update_slice) override;
91   Status HandleTuple(HloInstruction* tuple) override;
92   Status HandleMap(HloInstruction* map) override;
93   Status HandleReduceWindow(HloInstruction* reduce_window) override;
94   Status HandleSelectAndScatter(HloInstruction* instruction) override;
95   Status HandleWhile(HloInstruction* xla_while) override;
96   Status HandleConditional(HloInstruction* conditional) override;
97   Status HandlePad(HloInstruction* pad) override;
98   Status HandleCopyStart(HloInstruction* copy_start) override;
99   Status HandleCopyDone(HloInstruction* copy_done) override;
100   Status HandleSend(HloInstruction* send) override;
101   Status HandleSendDone(HloInstruction* send_done) override;
102   Status HandleRecv(HloInstruction* recv) override;
103   Status HandleRecvDone(HloInstruction* recv_done) override;
104   Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override;
105   Status HandleBatchNormInference(
106       HloInstruction* batch_norm_inference) override;
107   Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override;
108   Status HandleGather(HloInstruction* gather) override;
109   Status HandleScatter(HloInstruction* scatter) override;
110   Status HandleAfterAll(HloInstruction* token) override;
111   Status HandleGetDimensionSize(HloInstruction* get_size) override;
112   Status HandleSetDimensionSize(HloInstruction* set_size) override;
113   Status HandleAddDependency(HloInstruction* add_dependency) override;
114 
FinishVisit(HloInstruction *)115   Status FinishVisit(HloInstruction*) override { return Status::OK(); }
116 
117  protected:
118   // Check the instruction's shape against the shape given by ShapeInference
119   // and return an appropriate error if there is a mismatch.
120   Status CheckShape(const HloInstruction* instruction,
121                     const Shape& inferred_shape,
122                     bool only_compare_minor_to_major_in_layout = false);
123 
124   // Overload which takes a StatusOr to reduce boilerplate in the caller.
125   Status CheckShape(const HloInstruction* instruction,
126                     const StatusOr<Shape>& inferred_shape_status);
127 
128   // Check a unary (binary, etc) instruction's shape against the inferred shape.
129   Status CheckUnaryShape(const HloInstruction* instruction);
130   Status CheckBinaryShape(const HloInstruction* instruction);
131   Status CheckTernaryShape(const HloInstruction* instruction);
132   Status CheckVariadicShape(const HloInstruction* instruction);
133 
134  private:
135   // Helpers that switch on layout_sensitive_.
136   bool ShapesSame(const Shape& a, const Shape& b,
137                   bool minor_to_major_only = false,
138                   bool ignore_memory_space = false) {
139     if (!layout_sensitive_) {
140       return ShapeUtil::Compatible(a, b);
141     }
142     Shape::Equal equal;
143     if (ignore_memory_space) {
144       equal.IgnoreMemorySpaceInLayout();
145     }
146     if (minor_to_major_only) {
147       equal.MinorToMajorOnlyInLayout();
148     }
149     return equal(a, b);
150   }
151 
152   bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b,
153                                      bool minor_to_major_only = false) {
154     if (!layout_sensitive_) {
155       return ShapeUtil::CompatibleIgnoringFpPrecision(a, b);
156     }
157     Shape::Equal equal;
158     if (minor_to_major_only) {
159       equal.MinorToMajorOnlyInLayout();
160     }
161     equal.IgnoreFpPrecision();
162     return equal(a, b);
163   }
164 
StringifyShape(const Shape & s)165   string StringifyShape(const Shape& s) {
166     return layout_sensitive_ ? ShapeUtil::HumanStringWithLayout(s)
167                              : ShapeUtil::HumanString(s);
168   }
169 
170   // Helpers that switch on allow_mixed_precision_.
SameElementType(const Shape & a,const Shape & b)171   bool SameElementType(const Shape& a, const Shape& b) {
172     return allow_mixed_precision_
173                ? ShapeUtil::SameElementTypeIgnoringFpPrecision(a, b)
174                : ShapeUtil::SameElementType(a, b);
175   }
176 
177   // Checks that the given operand of the given instruction is of type TOKEN.
178   Status CheckIsTokenOperand(const HloInstruction* instruction,
179                              int64 operand_no);
180 
181   // Checks that the shape of the given operand of the given instruction matches
182   // the given parameter of the given computation.
183   Status CheckOperandAndParameter(const HloInstruction* instruction,
184                                   int64 operand_number,
185                                   const HloComputation* computation,
186                                   int64 parameter_number);
187 
188   // Returns true if the shapes of the two operands have the same element type,
189   // and the result shape either has the same element type as the operand shapes
190   // or mixed precision is allowed and the result shape and the operand shapes
191   // have floating point element types.
192   bool HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1,
193                                  const Shape& result_shape);
194 
195   // If the verifier is layout-sensitive, shapes must be equal to what's
196   // expected.  Otherwise, the shapes must simply be compatible.
197   bool layout_sensitive_;
198 
199   // Whether the inputs and output of an instruction can contain both F32s and
200   // BF16s. Tuples that include both F32s and BF16s are allowed regardless of
201   // this flag.
202   bool allow_mixed_precision_;
203 
204   // Returns a target-specific shape size.
205   std::function<int64(const Shape&)> shape_size_function_;
206 };
207 
208 // An interface used to encapsulate target-specific verification quirks.
209 class TargetVerifierMetadata {
210  public:
TargetVerifierMetadata(std::function<int64 (const Shape &)> shape_size_function)211   TargetVerifierMetadata(std::function<int64(const Shape&)> shape_size_function)
212       : shape_size_function_(shape_size_function) {}
213 
214   // Returns a target-specific shape size.
ShapeSize(const Shape & shape)215   int64 ShapeSize(const Shape& shape) const {
216     return shape_size_function_(shape);
217   }
218 
219   virtual std::unique_ptr<ShapeVerifier> GetVerifier() const = 0;
220 
221   virtual bool IsLayoutSensitive() const = 0;
222 
TargetVerifierMetadata()223   TargetVerifierMetadata() {}
~TargetVerifierMetadata()224   virtual ~TargetVerifierMetadata() {}
225 
226   TargetVerifierMetadata(const TargetVerifierMetadata&) = delete;
227   TargetVerifierMetadata& operator=(const TargetVerifierMetadata&) = delete;
228 
229  protected:
230   // Returns a target-specific shape size.
231   std::function<int64(const Shape&)> shape_size_function_;
232 };
233 
234 // The default implementation of TargetVerifierMetadata, used unless the target
235 // needs to override it.
236 class DefaultVerifierMetadata : public TargetVerifierMetadata {
237  public:
DefaultVerifierMetadata(bool layout_sensitive,bool allow_mixed_precision,std::function<int64 (const Shape &)> shape_size_function)238   DefaultVerifierMetadata(
239       bool layout_sensitive, bool allow_mixed_precision,
240       std::function<int64(const Shape&)> shape_size_function)
241       : TargetVerifierMetadata(shape_size_function),
242         layout_sensitive_(layout_sensitive),
243         allow_mixed_precision_(allow_mixed_precision) {}
244 
245   // Creates a ShapeVerifier that checks that shapes match inferred
246   // expectations. This creates a new verifier every time because ShapeVerifier,
247   // being a DfsHloVisitor, is stateful. We want a clean object for each run of
248   // the verifier.
GetVerifier()249   std::unique_ptr<ShapeVerifier> GetVerifier() const override {
250     return absl::make_unique<ShapeVerifier>(
251         layout_sensitive_, allow_mixed_precision_, shape_size_function_);
252   }
253 
IsLayoutSensitive()254   bool IsLayoutSensitive() const override { return layout_sensitive_; }
255 
256  private:
257   bool layout_sensitive_;
258   bool allow_mixed_precision_;
259 };
260 
261 // HLO pass that verifies invariants of HLO instructions for each computation in
262 // the module.
263 class HloVerifier : public HloModulePass {
264  public:
265   explicit HloVerifier(
266       bool layout_sensitive, bool allow_mixed_precision,
267       std::function<bool(const HloInstruction*)>
268           instruction_can_change_layout_func = {},
269       std::function<int64(const Shape&)> shape_size_func =
270           [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); })
target_metadata_(absl::make_unique<DefaultVerifierMetadata> (layout_sensitive,allow_mixed_precision,shape_size_func))271       : target_metadata_(absl::make_unique<DefaultVerifierMetadata>(
272             layout_sensitive, allow_mixed_precision, shape_size_func)),
273         instruction_can_change_layout_func_(
274             std::move(instruction_can_change_layout_func)) {
275     CHECK(instruction_can_change_layout_func_ == nullptr || layout_sensitive);
276   }
277 
278   // Uses custom target metadata
HloVerifier(std::unique_ptr<TargetVerifierMetadata> target_metadata)279   explicit HloVerifier(std::unique_ptr<TargetVerifierMetadata> target_metadata)
280       : target_metadata_(std::move(target_metadata)) {}
281 
282   ~HloVerifier() override = default;
name()283   absl::string_view name() const override { return "verifier"; }
284 
285   // Never returns true; no instructions are ever modified by this pass.
286   StatusOr<bool> Run(HloModule* module) override;
287 
288  private:
289   std::unique_ptr<TargetVerifierMetadata> target_metadata_;
290 
291   // Determines whether an instruction can change layouts.
292   std::function<bool(const HloInstruction*)>
293       instruction_can_change_layout_func_;
294 };
295 
296 }  // namespace xla
297 
298 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
299