/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"

#include <stddef.h>
#include <unordered_map>
#include <vector>

#include "llvm/IR/DerivedTypes.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
// IWYU pragma: no_include "llvm/IR/Attributes.gen.inc"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace gpu {

using absl::StrAppend;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;

namespace {
// Returns whether operand is a floating-point literal (or a broadcast of one)
// with the given value.
bool IsFPLiteralWithValue(const HloInstruction* operand, float value) {
  if (operand->opcode() == HloOpcode::kConstant &&
      operand->literal().IsAllFloat(value)) {
    return true;
  }
  return operand->opcode() == HloOpcode::kBroadcast &&
         IsFPLiteralWithValue(operand->operand(0), value);
}
}  // namespace

GpuElementalIrEmitter::GpuElementalIrEmitter(
    const HloModuleConfig& hlo_module_config, llvm::Module* module,
    llvm::IRBuilder<>* b, NestedComputer compute_nested)
    : ElementalIrEmitter(hlo_module_config, module, b),
      hlo_module_config_(hlo_module_config),
      compute_nested_(std::move(compute_nested)) {}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall(
    const string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
  // The libdevice math functions differentiate between "double" and "float"
  // by appending an 'f' to the function's name. libdevice doesn't have f16
  // math functions, so we convert the operands to f32 before calling the
  // function and then convert the result back to f16.
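  //
  // For example, a callee_name of "__nv_exp" called on an f16 operand becomes
  // a call to "__nv_expf" on the f32-widened operand, and the f32 result is
  // truncated back to f16.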
  string munged_callee = callee_name;
  bool cast_result_to_fp16 = false;
  std::vector<llvm::Value*> converted_operands(operands.begin(),
                                               operands.end());
  std::vector<PrimitiveType> converted_input_types(input_types.begin(),
                                                   input_types.end());
  switch (output_type) {
    case F16:
      cast_result_to_fp16 = true;
      for (int64 i = 0; i < operands.size(); ++i) {
        if (input_types[i] == F16) {
          converted_operands[i] =
              FPCast(converted_operands[i], b_->getFloatTy());
          converted_input_types[i] = F32;
        }
      }
      output_type = F32;
      TF_FALLTHROUGH_INTENDED;
    case F32:
      StrAppend(&munged_callee, "f");
      break;
    case F64:
      break;
    default:
      return Unimplemented("Bad type for libdevice math call: %s",
                           PrimitiveType_Name(output_type));
  }
  llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
                                     converted_input_types, output_type)
                            .ValueOrDie();
  if (cast_result_to_fp16) {
    result = FPCast(result, b_->getHalfTy());
  }
  return result;
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall(
    const string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
  // LLVM intrinsics differentiate between half/float/double functions via
  // the suffixes ".f16", ".f32" and ".f64".
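  // For example, a callee_name of "llvm.sqrt" with an F32 output becomes
  // "llvm.sqrt.f32".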
  string munged_callee = callee_name;
  switch (output_type) {
    case F16:
      StrAppend(&munged_callee, ".f16");
      break;
    case F32:
      StrAppend(&munged_callee, ".f32");
      break;
    case F64:
      StrAppend(&munged_callee, ".f64");
      break;
    default:
      return Unimplemented("Bad type for llvm intrinsic math call: %s",
                           PrimitiveType_Name(output_type));
  }
  return EmitMathCall(munged_callee, operands, input_types, output_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
    const string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
  // These math functions are of type [T] -> T, so every input type must
  // match the output type.
  for (PrimitiveType input_type : input_types) {
    if (output_type != input_type) {
      return Unimplemented("Input type ≠ output type: %s ≠ %s",
                           PrimitiveType_Name(input_type),
                           PrimitiveType_Name(output_type));
    }
  }

  return EmitDeviceFunctionCall(
      callee_name, operands, input_types, output_type,
      {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind});
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
  PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
  PrimitiveType output_type = op->shape().element_type();
  HloOpcode opcode = op->opcode();

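  // When fast min/max is enabled, lower kMaximum and kMinimum directly to the
  // llvm.maxnum/llvm.minnum intrinsics. These follow IEEE maxNum/minNum
  // semantics, so (unlike the default lowering) they may not propagate NaNs.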
  if (hlo_module_config_.debug_options().xla_gpu_enable_fast_min_max() &&
      (opcode == HloOpcode::kMaximum || opcode == HloOpcode::kMinimum)) {
    return llvm_ir::EmitCallToIntrinsic(
        opcode == HloOpcode::kMaximum ? llvm::Intrinsic::maxnum
                                      : llvm::Intrinsic::minnum,
        {lhs_value, rhs_value}, {lhs_value->getType()}, b_);
  }

  switch (op->opcode()) {
    case HloOpcode::kRemainder: {
      return EmitLibdeviceMathCall("__nv_fmod", {lhs_value, rhs_value},
                                   {lhs_input_type, rhs_input_type},
                                   output_type);
    }
    case HloOpcode::kPower: {
      return EmitPowerOp(op, lhs_value, rhs_value);
    }
    default:
      return ElementalIrEmitter::EmitFloatBinaryOp(op, lhs_value, rhs_value);
  }
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  CHECK_EQ(op->opcode(), HloOpcode::kPower);
  PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
  PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
  PrimitiveType output_type = op->shape().element_type();
  return EmitLibdeviceMathCall("__nv_pow", {lhs_value, rhs_value},
                               {lhs_input_type, rhs_input_type}, output_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitErfcInv(
    PrimitiveType prim_type, llvm::Value* value) {
  return EmitLibdeviceMathCall("__nv_erfcinv", {value}, {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitLibdeviceMathCall("__nv_log", {value}, {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
                                                        llvm::Value* value) {
  return EmitLibdeviceMathCall("__nv_log1p", {value}, {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSin(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitLibdeviceMathCall("__nv_sin", {value}, {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitCos(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitLibdeviceMathCall("__nv_cos", {value}, {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitLibdeviceMathCall("__nv_exp", {value}, {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
                                                        llvm::Value* value) {
  return EmitLibdeviceMathCall("__nv_expm1", {value}, {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type,
                                                      llvm::Value* lhs,
                                                      llvm::Value* rhs) {
  return EmitLibdeviceMathCall("__nv_pow", {lhs, rhs}, {prim_type, prim_type},
                               prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSqrt(PrimitiveType prim_type,
                                                       llvm::Value* value) {
  return EmitLibdeviceMathCall("__nv_sqrt", {value}, {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type,
                                                        llvm::Value* value) {
  return EmitLibdeviceMathCall("__nv_rsqrt", {value}, {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
                                                        llvm::Value* lhs,
                                                        llvm::Value* rhs) {
  return EmitLibdeviceMathCall("__nv_atan2", {lhs, rhs}, {prim_type, prim_type},
                               prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
                                                       llvm::Value* value) {
  // Emit a fast approximation of tanh instead of calling __nv_tanh.
  // __nv_tanh is particularly bad because it contains branches, thus
  // preventing LLVM's load-store vectorizer from working its magic across a
  // function which contains tanh calls.
  //
  // This routine isn't numerically precise, but it's good enough for ML.

  // Upcast F16 to F32 if necessary.
  llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType();
  llvm::Value* input = FPCast(value, type);
  llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
  return FPCast(fast_tanh, value->getType());
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitRoundNearestAfz(
    PrimitiveType prim_type, llvm::Value* value) {
  // Use libdevice __nv_round instead of llvm.round. This works around a bug
  // in the PTX backend, which implements llvm.round with PTX cvt.rni (round
  // to nearest, ties to even), whereas this op must round ties away from
  // zero. Even once llvm.round is fixed, we may still want __nv_round here:
  // expanding its non-trivial implementation early, while inlining, allows
  // better optimizations.
  return EmitLibdeviceMathCall("__nv_round", {value}, {prim_type}, prim_type);
}

llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
    const string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
    absl::Span<const llvm::Attribute::AttrKind> attributes) {
  std::vector<llvm::Type*> ir_input_types;
  for (PrimitiveType input_type : input_types) {
    ir_input_types.push_back(
        llvm_ir::PrimitiveTypeToIrType(input_type, module_));
  }
  llvm::FunctionType* callee_type = llvm::FunctionType::get(
      llvm_ir::PrimitiveTypeToIrType(output_type, module_),  // Return type.
      ir_input_types,                                        // Parameter types.
      false);  // No variadic arguments.
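  // For example, with callee_name "__nv_expf" and one F32 input, callee_type
  // is "float (float)", i.e. the signature of "declare float @__nv_expf(float)".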

  // Declares the callee if it is not declared already.
  llvm::Function* callee = llvm::dyn_cast<llvm::Function>(
      b_->GetInsertBlock()
          ->getModule()
          ->getOrInsertFunction(callee_name, callee_type)
          .getCallee());

  for (auto attribute : attributes) {
    callee->addFnAttr(attribute);
  }

  return Call(callee, llvm_ir::AsArrayRef(operands));
}

llvm::Value* GpuElementalIrEmitter::EmitThreadId() {
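  // The linear thread id is blockIdx.x * blockDim.x + threadIdx.x. The PTX
  // special registers are 32-bit; we widen them to i128 so the multiply-add
  // below cannot overflow.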
  llvm::Value* block_id =
      IntCast(llvm_ir::EmitCallToIntrinsic(
                  llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b_),
              b_->getIntNTy(128), /*isSigned=*/true, "block.id");
  llvm::Value* thread_id_in_block =
      IntCast(llvm_ir::EmitCallToIntrinsic(
                  llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_),
              b_->getIntNTy(128), /*isSigned=*/true, "thread.id");
  llvm::Value* threads_per_block =
      IntCast(llvm_ir::EmitCallToIntrinsic(
                  llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x, {}, {}, b_),
              b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
  return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
}

llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
    const HloInstruction* hlo,
    const HloToElementGeneratorMap& operand_to_generator) {
  switch (hlo->opcode()) {
    case HloOpcode::kMap:
      return [=, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        TF_RET_CHECK(!hlo->operands().empty())
            << "Zero operand map not implemented in GPU backend.";
        TF_RET_CHECK(hlo->to_apply()->num_parameters() > 0);
        std::vector<llvm::Value*> operand_elements;
        for (HloInstruction* operand : hlo->operands()) {
          TF_ASSIGN_OR_RETURN(llvm::Value * value,
                              operand_to_generator.at(operand)(index));
          operand_elements.push_back(value);
        }
        return compute_nested_(*hlo->to_apply(), operand_elements);
      };
    case HloOpcode::kReduceWindow:
      // Pseudocode:
      // for each output index O
      //   value = init_value
      //   for each window index W
      //     for each dimension i from 0 to rank - 1
      //       (input index I)[i] = O[i] * stride[i] + W[i] - pad_low[i]
      //     if I is in bounds of input
      //       value = function(value, input[I])
      //   output[O] = value
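      //
      // The loop nest emitted below additionally handles window dilation and
      // base dilation, which the pseudocode above omits.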
      return [=, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        const HloInstruction* operand = hlo->operand(0);
        const Window& window = hlo->window();

        PrimitiveType operand_element_type = operand->shape().element_type();
        llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
            llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
            "reduce_window_accum_ptr", b_);
        {
          TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
                              operand_to_generator.at(hlo->operand(1))(
                                  IrArray::Index(index.GetType())));
          Store(init_value, accum_ptr);
        }

        llvm::Type* index_type = index.GetType();
        auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
          return index.GetConstantWithIndexType(c);
        };

        llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type);
        std::vector<int64> window_size;
        for (const auto& dim : window.dimensions()) {
          window_size.push_back(dim.size());
        }
        const IrArray::Index window_index = loops.AddLoopsForShape(
            ShapeUtil::MakeShape(operand_element_type, window_size), "window");
        CHECK_EQ(window_index.size(), index.size());

        SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);

        std::vector<llvm::Value*> input_multi_index(index.size());
        llvm::Value* in_bounds = b_->getInt1(true);
        for (size_t i = 0; i < index.size(); ++i) {
          llvm::Value* strided_index = NSWMul(
              index[i], index_typed_const(window.dimensions(i).stride()));
          input_multi_index[i] = NSWSub(
              NSWAdd(strided_index,
                     NSWMul(window_index[i],
                            index_typed_const(
                                window.dimensions(i).window_dilation()))),
              index_typed_const(window.dimensions(i).padding_low()));

          // We need to verify that we are not in the dilated base area.
          llvm::Value* dilation_condition = ICmpEQ(
              SRem(input_multi_index[i],
                   index_typed_const(window.dimensions(i).base_dilation())),
              index_typed_const(0));
          in_bounds = And(in_bounds, dilation_condition);

          // Apply base dilation to the index.
          input_multi_index[i] =
              SDiv(input_multi_index[i],
                   index_typed_const(window.dimensions(i).base_dilation()));

          // We must check whether 0 ≤ input_multi_index[i] < bound, as
          // otherwise we are in the pad and so can skip the computation. This
          // comparison is equivalent to the unsigned comparison
          // input_multi_index[i] < bound, as a negative value wraps to a large
          // positive value.
          in_bounds =
              And(in_bounds,
                  ICmpULT(input_multi_index[i],
                          index_typed_const(operand->shape().dimensions(i))));
        }

        llvm_ir::LlvmIfData if_data =
            llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
        SetToFirstInsertPoint(if_data.true_block, b_);

        // We are not in the pad, so do the computation.
        IrArray::Index input_index(input_multi_index, operand->shape(),
                                   index_type);
        TF_ASSIGN_OR_RETURN(llvm::Value * input_value,
                            operand_to_generator.at(operand)(input_index));
        TF_ASSIGN_OR_RETURN(
            llvm::Value * accum_value,
            compute_nested_(*hlo->to_apply(), {Load(accum_ptr), input_value}));
        Store(accum_value, accum_ptr);

        SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
        return Load(accum_ptr);
      };
    case HloOpcode::kReduce:
      // TODO(b/118332391): This should be supported.
      CHECK_EQ(hlo->operand_count(), 2) << "Did not expect variadic reduce";
      return [=, &operand_to_generator](
                 const IrArray::Index& output_index) -> StatusOr<llvm::Value*> {
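        // Pseudocode:
        //   accum = init_value
        //   for each index R over the reduced dimensions
        //     accum = function(accum, operand[combine(output_index, R)])
        //   return accum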
        const HloInstruction* operand = hlo->operand(0);
        llvm::Value* accum_ptr =
            b()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType(
                hlo->shape().element_type(), module_));
        llvm::Type* index_type = output_index.GetType();
        TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
                            operand_to_generator.at(hlo->operand(1))(
                                IrArray::Index(index_type)));
        b()->CreateStore(init_value, accum_ptr);

        llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type);
        std::vector<llvm::Value*> input_multi_index =
            loops.AddLoopsForShapeOnDimensions(
                operand->shape(), hlo->dimensions(), "reduction_dim");
        if (!ShapeUtil::IsScalar(hlo->shape())) {
          // Only the reduction dimensions of input_multi_index are non-null
          // here; fill in the remaining dimensions from the output index.
          size_t j = 0;
          for (auto& i : input_multi_index) {
            if (i == nullptr) {
              i = output_index[j++];
            }
          }
          CHECK_EQ(output_index.size(), j);
        }
        llvm_ir::IrArray::Index input_index(
            input_multi_index, hlo->operand(0)->shape(), index_type);

        SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b());
        TF_ASSIGN_OR_RETURN(
            llvm::Value * input_value,
            operand_to_generator.at(hlo->operand(0))(input_index));
        TF_ASSIGN_OR_RETURN(
            llvm::Value * accum_value,
            compute_nested_(*hlo->to_apply(),
                            {b()->CreateLoad(accum_ptr), input_value}));
        b()->CreateStore(accum_value, accum_ptr);
        SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b());
        return b()->CreateLoad(accum_ptr);
      };
    default:
      return ElementalIrEmitter::MakeElementGenerator(hlo,
                                                      operand_to_generator);
  }
}

}  // namespace gpu
}  // namespace xla