1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
18 
19 #define _USE_MATH_DEFINES
20 
#include <complex>
#include <functional>
#include <memory>
#include <random>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/node_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
39 
40 namespace xla {
41 
42 // Responsible for evaluating HLO and obtain literal as the evaluation results.
43 //
44 // This class is not thread-safe.
45 class HloEvaluator : public DfsHloVisitorWithDefault {
46  public:
47   // Only evaluate up to max_loop_iterations per while-loop execution if
48   // specified.
49   explicit HloEvaluator(int64 max_loop_iterations = -1);
50 
51   // Evaluates an HLO module and an array of pointers to literals.  Returns the
52   // evaluated result as a literal if successful.
53   //
54   // Precondition: The indices of arg_literals correspond to the parameter
55   // numbers of the HLO parameters in the computation. See comment below for an
56   // example.
57   //
58   // (Dummy template arg is to reduce the overloading priority of one overload
59   // so that Evaluate(module, {}) resolves unambiguously.)
Evaluate(const HloModule & module,absl::Span<const Literal * const> arg_literals)60   StatusOr<Literal> Evaluate(const HloModule& module,
61                              absl::Span<const Literal* const> arg_literals) {
62     return Evaluate(*module.entry_computation(), arg_literals);
63   }
64   template <typename Dummy = void>
Evaluate(const HloModule & module,absl::Span<const Literal> arg_literals)65   StatusOr<Literal> Evaluate(const HloModule& module,
66                              absl::Span<const Literal> arg_literals) {
67     return Evaluate(*module.entry_computation(), arg_literals);
68   }
69 
70   // Evaluates an HLO computation and an array of pointers to literals.
71   // Returns the evaluated result as a literal if successful.
72   // Precondition: The indices of arg_literals correspond to the parameter
73   // numbers of the HLO parameters in the computation. For e.g., consider the
74   // following graph:
75   //
76   //                *
77   //            /       \
78   //            +     Parameter1
79   //        /      \
80   //       /        \
81   //    Parameter0  Constant
82   //
83   // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number
84   // 1 in this computation. The input literals array will then have its first
85   // literal map to Parameter0 and the second map to Parameter1.
86   //
87   // (Dummy template arg is to reduce the overloading priority of one overload
88   // so that Evaluate(module, {}) resolves unambiguously.)
89   StatusOr<Literal> Evaluate(const HloComputation& computation,
90                              absl::Span<const Literal* const> arg_literals);
91   template <typename Dummy = void>
Evaluate(const HloComputation & computation,absl::Span<const Literal> arg_literals)92   StatusOr<Literal> Evaluate(const HloComputation& computation,
93                              absl::Span<const Literal> arg_literals) {
94     std::vector<const Literal*> arg_literal_ptrs;
95     for (const auto& l : arg_literals) {
96       arg_literal_ptrs.push_back(&l);
97     }
98     return Evaluate(computation, arg_literal_ptrs);
99   }
100 
101   // Gets the value of running a single HLO instruction.
102   //
103   // All of the operands to this instruction must be constants.
104   StatusOr<Literal> Evaluate(HloInstruction* instruction);
105 
106   // Same as Evaluate, except returning false on error and accepts an output
107   // pointer.
108   bool TryEvaluate(HloInstruction* instruction, Literal* result);
109 
110   // Evaluates a single HLO instruction, substituting the given literals for
111   // some of the instruction's operands.
112   //
113   // For example, given instruction = op(A, B, C) and the map
114   // {A = x, C = y}, this evaluates op(x, B, y).
115   StatusOr<Literal> EvaluateWithSubstitutions(
116       const HloInstruction* instruction,
117       const std::unordered_map<const HloInstruction*, const Literal*>&
118           substitutions);
119 
120   StatusOr<Literal> EvaluateElementwiseBinaryOp(HloOpcode opcode,
121                                                 const Literal& lhs,
122                                                 const Literal& rhs);
123 
124   StatusOr<Literal> EvaluateElementwiseUnaryOp(HloOpcode opcode,
125                                                const Literal& operand);
126 
127   StatusOr<Literal> EvaluateDotOp(const DotDimensionNumbers& dim_numbers,
128                                   const PrecisionConfig& precision_config,
129                                   const Literal& lhs, const Literal& rhs);
130 
set_dynamic_dimension_inference(DynamicDimensionInference * dynamic_dimension_inference)131   void set_dynamic_dimension_inference(
132       DynamicDimensionInference* dynamic_dimension_inference) {
133     dynamic_dimension_inference_ = dynamic_dimension_inference;
134   }
135 
dynamic_dimension_inference()136   DynamicDimensionInference* dynamic_dimension_inference() {
137     return dynamic_dimension_inference_;
138   }
139 
140   // Enable the fast path for certain operations like dot or convolution.
set_use_fast_path(bool value)141   void set_use_fast_path(bool value) { use_fast_path_ = value; }
142 
143   // Handles evaluation of a custom-call op.
144   // Operand literals are provided in |operands| and implementations must
145   // populate |output| before returning.
146   using CustomCallHandler = std::function<StatusOr<Literal>(
147       HloInstruction* custom_call, absl::Span<const Literal*> operands)>;
148 
149   // Sets a handler that is called during evaluation for custom-call ops.
150   // If no handler is defined the default error behavior will occur. The handler
151   // will be provided evaluated literals for all operands and is expected to
152   // return an output literal of the appropriate shape.
set_custom_call_handler(std::function<StatusOr<Literal> (HloInstruction * custom_call,absl::Span<const Literal * > operands)> handler)153   void set_custom_call_handler(
154       std::function<StatusOr<Literal>(HloInstruction* custom_call,
155                                       absl::Span<const Literal*> operands)>
156           handler) {
157     custom_call_handler_ = std::move(handler);
158   }
159 
160   // Returns the result of a matrix multiply `lhs x rhs`.
161   static std::unique_ptr<Array2D<Eigen::half>> MatmulArray2D(
162       const Array2D<Eigen::half>& lhs, const Array2D<Eigen::half>& rhs);
163   static std::unique_ptr<Array2D<float>> MatmulArray2D(
164       const Array2D<float>& lhs, const Array2D<float>& rhs);
165   static std::unique_ptr<Array2D<double>> MatmulArray2D(
166       const Array2D<double>& lhs, const Array2D<double>& rhs);
167   static std::unique_ptr<Array2D<std::complex<float>>> MatmulArray2D(
168       const Array2D<std::complex<float>>& lhs,
169       const Array2D<std::complex<float>>& rhs);
170   static std::unique_ptr<Array2D<std::complex<double>>> MatmulArray2D(
171       const Array2D<std::complex<double>>& lhs,
172       const Array2D<std::complex<double>>& rhs);
173   static std::unique_ptr<Array2D<int32>> MatmulArray2D(
174       const Array2D<int32>& lhs, const Array2D<int32>& rhs);
175 
176  protected:
177   // Make HloEvaluatorTypedVisitor a friend because it is logically part of this
178   // class.
179   //
180   // A straightforward implementation would be to make it a nested class
181   // declared and defined in hlo_evaluator.cc.  Instead HloEvaluatorTypedVisitor
182   // lives as a separate class with its own header because its template gets
183   // instantiated many times and we want to use extern templates to shard out
184   // the compilation of those instantiations across multiple cc files.
185   template <typename ReturnT, typename ElementwiseT>
186   friend class HloEvaluatorTypedVisitor;
187 
188   // Wraps around instruction handling to infer types before dispatching to
189   // the corresponding typed Visitor.
DefaultAction(HloInstruction * hlo)190   Status DefaultAction(HloInstruction* hlo) override {
191     return hlo->Visit(typed_visitors_[hlo->shape().element_type()].get());
192   }
193 
194   Status Preprocess(HloInstruction* hlo) override;
195 
196   Status Postprocess(HloInstruction* hlo) override;
197 
198   // Operations that are type-agnostic or always return a specific type, such as
199   // HandleIsFinite where boolean is always returned.
200   //
201   Status HandleBitcast(HloInstruction* bitcast) override;
202 
203   Status HandleGetDimensionSize(HloInstruction* get_dimension_size) override;
204 
205   Status HandleSetDimensionSize(HloInstruction* set_dimension_size) override;
206 
207   Status HandleParameter(HloInstruction* parameter) override;
208 
209   Status HandleConstant(HloInstruction* constant) override;
210 
211   Status HandleConcatenate(HloInstruction* concatenate) override;
212 
213   Status HandleReshape(HloInstruction* reshape) override;
214 
215   Status HandleTranspose(HloInstruction* transpose) override;
216 
217   Status HandleIsFinite(HloInstruction* is_finite) override;
218 
219   Status HandleCompare(HloInstruction* compare) override;
220 
221   Status HandleTuple(HloInstruction* tuple) override;
222 
223   Status HandleFft(HloInstruction* fft) override;
224 
225   Status HandleGather(HloInstruction* gather) override;
226 
227   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
228 
229   Status HandleCopy(HloInstruction* copy) override;
230 
231   Status HandleCopyStart(HloInstruction* copy_start) override;
232 
233   Status HandleCopyDone(HloInstruction* copy_done) override;
234 
235   Status HandleConditional(HloInstruction* conditional) override;
236 
237   Status HandleCall(HloInstruction* call) override;
238 
239   Status HandleFusion(HloInstruction* fusion) override;
240 
241   Status HandleWhile(HloInstruction* while_hlo) override;
242 
243   Status HandleSelect(HloInstruction* select) override;
244 
245   Status HandleTupleSelect(HloInstruction* tuple_select) override;
246 
247   Status HandleBroadcast(HloInstruction* broadcast) override;
248 
249   Status HandleAfterAll(HloInstruction* after_all) override;
250 
251   Status HandleAddDependency(HloInstruction* add_dependency) override;
252 
253   Status HandleSort(HloInstruction* sort) override;
254 
255   Status HandleReal(HloInstruction* real) override;
256 
257   Status HandleImag(HloInstruction* imag) override;
258 
259   Status HandleComplex(HloInstruction* complex) override;
260 
261   Status HandleReduce(HloInstruction* reduce) override;
262 
263   Status HandleReduceWindow(HloInstruction* hlo) override;
264 
265   Status HandleCustomCall(HloInstruction* custom_call) override;
266 
267   // Unsupported HLOs, note some of them (such as BatchNorm*) are typically
268   // expanded in a semantic-preserving way into other HLOs by adding expansion
269   // HLO pass to the HLO optimization pass during compilation, which can then be
270   // handled by the evaluator.
HandleBatchNormGrad(HloInstruction * batch_norm_grad)271   Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override {
272     return Unimplemented("BatchNormGrad HLO is unsupported by the evaluator.");
273   };
HandleBatchNormInference(HloInstruction * batch_norm_inference)274   Status HandleBatchNormInference(
275       HloInstruction* batch_norm_inference) override {
276     return Unimplemented(
277         "BatchNormInference HLO is unsupported by the evaluator.");
278   };
HandleBatchNormTraining(HloInstruction * batch_norm_training)279   Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override {
280     return Unimplemented(
281         "BatchNormTraining HLO is unsupported by the evaluator.");
282   };
HandleInfeed(HloInstruction * infeed)283   Status HandleInfeed(HloInstruction* infeed) override {
284     return Unimplemented("Infeed HLO is unsupported by the evaluator.");
285   };
HandleOutfeed(HloInstruction * outfeed)286   Status HandleOutfeed(HloInstruction* outfeed) override {
287     return Unimplemented("Outfeed HLO is unsupported by the evaluator.");
288   };
289 
290   // Returns the already-evaluated literal result for the instruction.
291   //
292   // A Constant instruction is considered evaluated and its literal will be
293   // returned directly without looking up the cache.
294   //
295   // Similarly, a Parameter instruction is considered evaluated and its literal
296   // is looked up in arg_literals.
297   //
298   // Crash with log if the given instruction has not been evaluated previously.
GetEvaluatedLiteralFor(const HloInstruction * hlo)299   const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) {
300     if (hlo->IsConstant()) {
301       return hlo->literal();
302     }
303     if (hlo->opcode() == HloOpcode::kParameter) {
304       return *arg_literals_.at(hlo->parameter_number());
305     }
306     auto it = evaluated_.find(hlo);
307     CHECK(it != evaluated_.end())
308         << "could not find evaluated value for: " << hlo->ToString();
309     return it->second;
310   }
311 
312   // Tracks the HLO instruction and its evaluated literal result.
313   //
314   // Parameters and constants aren't stored here, see implementation of
315   // GetEvaluatedLiteralFor.
316   //
317   // TODO(b/35950897): have better memory management here to free instructions
318   // that are no longer a parent for any other subsequent instruction in
319   // post-ordering.
320   //
321   // Must be cleared for each evaluation.
322   //
323   // Storing Literal in place requires the container to have pointer stability
324   // so we cannot use flat_hash_map any more.
325   absl::node_hash_map<const HloInstruction*, Literal> evaluated_;
326 
327   // Use fast path that uses eigen in the evaluator.
328   bool use_fast_path_ = false;
329 
330  private:
331   template <typename ReturnT, typename NativeT>
ElementWiseUnaryOpImpl(HloInstruction * instruction,const std::function<ReturnT (NativeT)> & unary_op,const Literal & operand_literal)332   static StatusOr<Literal> ElementWiseUnaryOpImpl(
333       HloInstruction* instruction,
334       const std::function<ReturnT(NativeT)>& unary_op,
335       const Literal& operand_literal) {
336     const auto shape = instruction->shape();
337     const auto* operand = instruction->operand(0);
338     TF_RET_CHECK(ShapeUtil::SameDimensions(shape, operand->shape()));
339 
340     Literal result(shape);
341     TF_RETURN_IF_ERROR(
342         result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
343           return unary_op(operand_literal.Get<NativeT>(multi_index));
344         }));
345     return std::move(result);
346   }
347 
348   // Map from a primitive type to its associated (templated) DfsHloVisitor.
349   std::unique_ptr<DfsHloVisitor> typed_visitors_[PrimitiveType_ARRAYSIZE];
350 
351   // Caches pointers to input literals, assuming they are in post-order.
352   // Literals are not owned by this class, and they must outlive the lifetime of
353   // each invocation to the Evaluate* method.
354   // Must be cleared for each evaluation.
355   std::vector<const Literal*> arg_literals_;
356 
357   // Max loop iterations to execute with no maximum if negative.
358   int64 max_loop_iterations_ = 0;
359 
360   // Module-level seed handle.
361   uint64 seed_ = 0;
362   // RNG engine.
363   std::minstd_rand0 engine_;
364 
365   // DynamicDimensionInference is used to evaluate GetDimensionSize, which
366   // returns the dynamic dimension size of its operand.
367   DynamicDimensionInference* dynamic_dimension_inference_ = nullptr;
368 
369   // Optional handler for custom_call ops.
370   std::function<StatusOr<Literal>(HloInstruction* custom_call,
371                                   absl::Span<const Literal*> operands)>
372       custom_call_handler_;
373 
374   TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator);
375 };
376 
377 std::unique_ptr<Array2D<float>> MatmulArray2D(const Array2D<float>& lhs,
378                                               const Array2D<float>& rhs);
379 }  // namespace xla
380 
381 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
382