/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_UTIL_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_UTIL_H_
18 
#include <stdint.h>

#include <map>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
39 
// Forward declarations of LLVM types that this header only names in
// declarations, so their full definitions are not needed here.
namespace llvm {
class FastMathFlags;
class TargetOptions;
}  // namespace llvm
44 
45 namespace xla {
46 namespace llvm_ir {
47 
48 // Convert a absl::string_view to a llvm::StringRef. Note: both
49 // absl::string_view and llvm::StringRef are non-owning pointers into a
50 // string in memory. This method is used to feed strings to LLVM
51 // & Clang APIs that expect llvm::StringRef.
AsStringRef(absl::string_view str)52 inline llvm::StringRef AsStringRef(absl::string_view str) {
53   return llvm::StringRef(str.data(), str.size());
54 }
55 
56 template <typename T>
AsArrayRef(const std::vector<T> & vec)57 llvm::ArrayRef<T> AsArrayRef(const std::vector<T>& vec) {
58   return llvm::ArrayRef<T>(vec.data(), vec.size());
59 }
60 
61 template <typename T>
AsArrayRef(const absl::Span<const T> & slice)62 llvm::ArrayRef<T> AsArrayRef(const absl::Span<const T>& slice) {
63   return llvm::ArrayRef<T>(slice.data(), slice.size());
64 }
65 
66 // Dump the given LLVM entity to a string. This works for Types and Values.
67 template <typename T>
DumpToString(const T & entity)68 string DumpToString(const T& entity) {
69   std::string buffer_string;
70   llvm::raw_string_ostream ostream(buffer_string);
71   entity.print(ostream);
72   ostream.flush();
73   return buffer_string;
74 }
75 
76 // Dump the given LLVM module to a string. This requires a function distinct
77 // from DumpToString because the signatures of the print() methods for Values
78 // and Modules are slightly different.
79 string DumpModuleToString(const llvm::Module& module);
80 
81 // Constructs a human-friendly name from the given inputs.  The result is
82 // suitable for use as an llvm::Value's name.
83 //
84 // This is equivalent to
85 //
86 //   - changing the HloInstruction* to its name() (if we called that overload),
87 //   - joining all of the nonempty inputs by '.', and then
88 //   - removing all '%'s.
89 //
90 string IrName(absl::string_view a);
91 string IrName(absl::string_view a, absl::string_view b);
92 string IrName(const HloInstruction* a, absl::string_view b = "");
93 
94 // Removes special characters from a function name.
95 //
96 // Note that this can cause different inputs to map to the same output, so after
97 // sanitizing a function name, you must run it through a uniquer.
98 string SanitizeFunctionName(string function_name);
99 
100 // Emits a call to the specified intrinsic with the given operands. Overloaded
101 // intrinsics (for example, "minnum") must include a type in overloaded_types
102 // for each overloaded type. Typically, overloaded intrinsics have only a single
103 // overloaded type.
104 llvm::CallInst* EmitCallToIntrinsic(
105     llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
106     absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b,
107     absl::string_view name = "");
108 
109 // Emit float max. Emit maxnum intrinsic is fast math is disabled, or
110 // fcmp+select otherwise
111 llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
112                           llvm::IRBuilder<>* b, bool enable_fast_min_max,
113                           absl::string_view name = "");
114 
115 // Emit float min. Emit minnum intrinsic is fast math is disabled, or
116 // fcmp+select otherwise
117 llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
118                           llvm::IRBuilder<>* b, bool enable_fast_min_max,
119                           absl::string_view name = "");
120 
121 // Convenience methods for emitting a GEP instruction that indexes into a buffer
122 // (1-dimensional array), equivalent to array[index]. The type is automatically
123 // determined from the element type of the array.  The int64 index overload
124 // wraps the index in a i64 llvm::Value.
125 llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index,
126                                    llvm::IRBuilder<>* b);
127 llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index,
128                                    llvm::IRBuilder<>* b);
129 
130 // Returns the LLVM type which represents the given XLA primitive type.
131 llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
132                                   llvm::Module* module);
133 
134 // Returns the type size in bits. If "type" is a struct, it must be packed.
135 int GetSizeInBits(llvm::Type* type);
136 
137 // Returns the LLVM type which represents the given XLA shape. For example,
138 // if "shape" is [5 x [10 x f32]], the function returns [5 x [10 x float]].
139 llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module);
140 
141 // Returns a value that represents a pointer to a global string constant that
142 // encodes the shape as a serialized protobuf.
143 StatusOr<llvm::Value*> EncodeSelfDescribingShapeConstant(const Shape& shape,
144                                                          int32* shape_size,
145                                                          llvm::IRBuilder<>* b);
146 
147 // Converts a given literal to an IR Constant. Literals have known constant
148 // values at IR emission time.
149 llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
150                                            llvm::Module* module);
151 
152 // Allocates a tile of shared memory.
153 llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module,
154                                                llvm::Type* tile_type,
155                                                absl::string_view name);
156 
157 // Inserts an allocate of the requested type at the entry point of the
158 // function that the builder is currently building. The insert point
159 // of the builder is set to the same place after calling this function
160 // as before.
161 //
162 // This can be useful to avoid e.g. executing an alloca every time
163 // through a loop.
164 llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
165                                             absl::string_view name,
166                                             llvm::IRBuilder<>* b,
167                                             int alignment = 0);
168 
169 // As EmitAllocaAtFunctionEntry, but allocates element_count entries
170 // instead of a single element.
171 llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type,
172                                                      llvm::Value* element_count,
173                                                      absl::string_view name,
174                                                      llvm::IRBuilder<>* b,
175                                                      int alignment = 0);
176 
177 // Creates a basic block with the same context and function as for the
178 // builder. Inserts at the end of the function if insert_before is
179 // null.
180 llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
181                                    absl::string_view name,
182                                    llvm::IRBuilder<>* b);
183 
184 // Struct with data on a conditional branch in a diamond shape created
185 // via EmitIfThenElse.
186 struct LlvmIfData {
187   // The block that has the conditional branch.
188   llvm::BasicBlock* if_block;
189 
190   // The block that is executed if the condition is true.
191   llvm::BasicBlock* true_block;
192 
193   // The block that is executed if the condition is false.
194   llvm::BasicBlock* false_block;
195 
196   // The block that follows after both the true_block and the
197   // false_block.
198   llvm::BasicBlock* after_block;
199 };
200 
201 // Inserts a diamond-shaped if-then-else construct at the current
202 // insertion point of the builder. This involves splitting the current
203 // block into two blocks, at the insertion point, and introducing a
204 // true-block and a false-block that connect the two split pieces. The
205 // true-block is executed if the condition parameter evaluates to true
206 // and otherwise the false-block is executed. If `emit_else` is false,
207 // it jumps to the after-block rather than the false-block if the
208 // condition is false, and the returned `false_block` is null.
209 //
210 // Currently the insertion point of the builder must be a well-formed
211 // block with a terminator. If you need to use this for a
212 // non-terminated block, just make the function able to do that too.
213 LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
214                           llvm::IRBuilder<>* b, bool emit_else = true);
215 
216 // Emits a compare operation between "lhs" and "rhs" with the given predicate,
217 // and then converts the result to i8 so that it is addressable.
218 llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
219                             llvm::Value* lhs, llvm::Value* rhs,
220                             llvm::IRBuilder<>* b, absl::string_view name = "");
221 
222 // Emits a call that logs the given value with the given tag as a prefix.
223 // The provided tag and value are passed to a runtime logging call that is
224 // embedded in this translation unit when the emitted code is executed.
225 //
226 // This can be very useful for debugging generated programs in short order when
227 // developing new generated routines.
228 //
229 // Precondition: value must be an int64.
230 // Precondition: tag must be a stable pointer for the lifetime of the generated
231 // program (the constant pointer is burned in to the program).
232 void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* b);
233 
234 // Adds alignment metadata to a load instruction using the given alignment.
235 // The alignment refers to the result of the load, not the load itself.
236 void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment);
237 
238 // Adds dereferenceable metadata to a load instruction using the given
239 // the number of dereferenceable bytes.
240 // Dereferenceable refers to the result of the load, not the load itself.
241 void SetDereferenceableMetadataForLoad(llvm::LoadInst* load,
242                                        uint64_t dereferenceable_bytes);
243 
244 // Tells LLVM `inst >= lower && inst < upper`. Returns `inst` for convenience.
245 llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper,
246                                     llvm::Instruction* inst);
247 
248 void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder);
249 
250 void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder);
251 
252 // Create a bitwise rotation of `rotand` by `rotor`.
253 llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
254                        llvm::IRBuilder<>* builder);
255 
256 // Returns the number of bytes within the shape.
257 int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout);
258 
259 // Gets an llvm::FastMathFlags that reflects the settings in the given
260 // module config.
261 llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config);
262 
263 // Computes a conservative union of the metadata in "a" and "b".  For
264 // aliasing-related metadata, this means the result can be applied to
265 // instructions whose aliasing relationship can be described either by "a" *or*
266 // by "b".
267 std::map<int, llvm::MDNode*> MergeMetadata(
268     llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a,
269     const std::map<int, llvm::MDNode*>& b);
270 
271 // Dumps out `llvm_module` to the path specified in DebugOptions, if dumping is
272 // enabled for the given HLO module.
273 //
274 // A sanitized version of `hlo_module_name` is incorporated into the file name.
275 // If `optimized` is true then a suffix of "-with-opt.ll" is used, else a suffix
276 // of "-no-opt.ll" is used.
277 void DumpIrIfEnabled(const HloModule& hlo_module,
278                      const llvm::Module& llvm_module, bool optimized,
279                      absl::string_view filename_suffix = "");
280 
281 llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type,
282                                   llvm::GlobalValue::LinkageTypes linkage,
283                                   const HloModuleConfig& module_config,
284                                   absl::string_view name, llvm::Module* module);
285 
286 // Extracts the xla_backend_extra_options from `config` and passes those that
287 // don't start with xla_ to LLVM.
288 void InitializeLLVMCommandLineOptions(const HloModuleConfig& config);
289 
290 // Zero-extends two 32-bit values to 64 bits, multiplies them, and returns the
291 // result as a pair of (low 32 bits, high 32 bits).
292 std::pair<llvm::Value*, llvm::Value*> UMulLowHigh32(llvm::IRBuilder<>* b,
293                                                     llvm::Value* src0,
294                                                     llvm::Value* src1);
295 // Splits the 64-bit integer value into its high and low 32 bits.
296 std::pair<llvm::Value*, llvm::Value*> SplitInt64ToInt32s(
297     llvm::IRBuilder<>* b, llvm::Value* value_64bits);
298 
299 // Checks whether a global variable is already created to represent the state
300 // of a random number generator. If not, creates such a variable. Returns the
301 // global variable.
302 llvm::GlobalVariable* GetOrCreateVariableRngState(llvm::Module* module,
303                                                   llvm::IRBuilder<>* b);
304 
305 // Adds a delta value to the global state variable and return the old value of
306 // the variable.
307 llvm::Value* RngGetAndUpdateState(uint64 delta, llvm::Module* module,
308                                   llvm::IRBuilder<>* b);
309 
310 // Gets the LLVM address space that should be used for global variables (e.g.
311 // XLA's rng state).
312 unsigned GetGlobalMemoryAddressSpace(const llvm::Module& module);
313 }  // namespace llvm_ir
314 }  // namespace xla
315 
316 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_UTIL_H_
317