1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_FUSED_IR_EMITTER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_FUSED_IR_EMITTER_H_
18 
#include <map>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
33 
34 namespace xla {
35 
// FusedIrEmitter is used to generate code for fusion nodes.
//
// Unlike IrEmitter and its ilk, which directly create LLVM IR in an LLVM
// Module, FusedIrEmitter is better understood as an "IR generator generator".
// FusedIrEmitter recursively creates a generator (a host function) which the
// compiler can invoke at a later time. Invoking the generator emits LLVM IR
// that, when run, produces the value at a particular index of the output.
//
// After building this generator, the compiler creates a loop (or its moral
// equivalent, e.g. a GPU kernel) and calls the generator from within the loop.
// This generates code that produces each element of the output.
//
// This class handles both vanilla fusion and multi-output fusion (MOF). In the
// MOF case, the fusion node ends with a kTuple instruction, and the generator
// created produces an LLVM struct with N elements, one for each of the N arrays
// in the tuple. Because a single index is used to address every array, the
// arrays in the tuple must have the same length.
53 class FusedIrEmitter {
54  public:
55   using IndexedGenerator = llvm_ir::ElementGenerator;
56 
FusedIrEmitter(ElementalIrEmitter * elemental_emitter)57   explicit FusedIrEmitter(ElementalIrEmitter* elemental_emitter)
58       : elemental_emitter_(elemental_emitter),
59         b_(elemental_emitter->b()),
60         module_(elemental_emitter->module()) {}
61 
BindGenerator(const HloInstruction * hlo,llvm_ir::ElementGenerator generator)62   void BindGenerator(const HloInstruction* hlo,
63                      llvm_ir::ElementGenerator generator) {
64     indexed_generators_[hlo] = std::move(generator);
65   }
66 
67   // Returns the generator function for the given instruction.
68   StatusOr<IndexedGenerator> GetGenerator(const HloInstruction* instruction);
69 
70   // Evaluates whether fusing 'producer' into 'consumer' might cause exponential
71   // behavior in FusedIrEmitter. We currently can have exponential time/memory
72   // requirements for emitting certain fusion kernels, in which case we don't
73   // want to fuse.
74   // TODO(b/119692968): Remove this once we have fixed our fusion emitter.
75   static bool IsFusedIrEmitterInefficient(const HloInstruction* consumer,
76                                           const HloInstruction* producer);
77 
78  private:
79   Status DefaultAction(const HloInstruction* hlo);
80 
81   Status HandleConstant(const HloInstruction* constant);
82 
83   Status HandleGetTupleElement(const HloInstruction* get_tuple_element);
84 
85   Status HandleParameter(const HloInstruction* parameter);
86 
87   // Emits the ir value for each element in the tuple.
88   Status HandleTuple(const HloInstruction* tuple);
89 
90   ElementalIrEmitter* elemental_emitter_;
91 
92   // Borrowed
93   llvm::IRBuilder<>* b_;
94   llvm::Module* module_;
95 
96   // Map from instructions to functions that generate code for the output
97   // elements. If an instruction is a GetTupleElement instruction, the
98   // instruction produces non-tuple result.
99   std::unordered_map<const HloInstruction*, IndexedGenerator>
100       indexed_generators_;
101 
102   // Cache of generated values, lest we regenerate an element of a node with
103   // multiple outgoing edges
104   absl::flat_hash_map<
105       const HloInstruction*,
106       absl::flat_hash_map<std::vector<llvm::Value*>, llvm::Value*>>
107       generated_value_cache_;
108 };
109 
110 }  // namespace xla
111 
112 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_FUSED_IR_EMITTER_H_
113