/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"

#include <stddef.h>
#include <unordered_map>
#include <vector>

#include "llvm/IR/DerivedTypes.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
// IWYU pragma: no_include "llvm/IR/Attributes.gen.inc"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace gpu {

using absl::StrAppend;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;

namespace {
// Returns whether operand is a floating-point literal with the given value.
bool IsFPLiteralWithValue(const HloInstruction* operand, float value) {
  if (operand->opcode() == HloOpcode::kConstant &&
      operand->literal().IsAllFloat(value)) {
    return true;
  }
  return operand->opcode() == HloOpcode::kBroadcast &&
         IsFPLiteralWithValue(operand->operand(0), value);
}
}  // namespace

GpuElementalIrEmitter::GpuElementalIrEmitter(
    const HloModuleConfig& hlo_module_config, llvm::Module* module,
    llvm::IRBuilder<>* b, NestedComputer compute_nested)
    : ElementalIrEmitter(hlo_module_config, module, b),
      hlo_module_config_(hlo_module_config),
      compute_nested_(std::move(compute_nested)) {}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall(
    const string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
  // The libdevice math functions differentiate between "double" and "float" by
  // appending an 'f' to the function's name. libdevice doesn't have f16 math
  // functions, so we convert the operands to f32 before calling the function
  // and then convert the result back to f16.
  string munged_callee = callee_name;
  bool cast_result_to_fp16 = false;
  std::vector<llvm::Value*> converted_operands(operands.begin(),
                                               operands.end());
  std::vector<PrimitiveType> converted_input_types(input_types.begin(),
                                                   input_types.end());
  switch (output_type) {
    case F16:
      cast_result_to_fp16 = true;
      for (int64 i = 0; i < operands.size(); ++i) {
        if (input_types[i] == F16) {
          converted_operands[i] =
              FPCast(converted_operands[i], b_->getFloatTy());
          converted_input_types[i] = F32;
        }
      }
      output_type = F32;
      TF_FALLTHROUGH_INTENDED;
    case F32:
      StrAppend(&munged_callee, "f");
      break;
    case F64:
      break;
    default:
      return Unimplemented("Bad type for libdevice math call: %s",
                           PrimitiveType_Name(output_type));
  }
  llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
                                     converted_input_types, output_type)
                            .ValueOrDie();
  if (cast_result_to_fp16) {
    result = FPCast(result, b_->getHalfTy());
  }
  return result;
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall(
    const string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
  // LLVM intrinsics differentiate between half/float/double functions via
  // the suffixes ".f16", ".f32" and ".f64".
  string munged_callee = callee_name;
  switch (output_type) {
    case F16:
      StrAppend(&munged_callee, ".f16");
      break;
    case F32:
      StrAppend(&munged_callee, ".f32");
      break;
    case F64:
      StrAppend(&munged_callee, ".f64");
      break;
    default:
      return Unimplemented("Bad type for llvm intrinsic math call: %s",
                           PrimitiveType_Name(output_type));
  }
  return EmitMathCall(munged_callee, operands, input_types, output_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
    const string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
  // Math functions are of type [T] -> T: every input type must match the
  // output type.
  for (PrimitiveType input_type : input_types) {
    if (output_type != input_type) {
      return Unimplemented("Input type ≠ output type: %s ≠ %s",
                           PrimitiveType_Name(input_type),
                           PrimitiveType_Name(output_type));
    }
  }

  return EmitDeviceFunctionCall(
      callee_name, operands, input_types, output_type,
      {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind});
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
  PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
  PrimitiveType output_type = op->shape().element_type();
  HloOpcode opcode = op->opcode();

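  // When --xla_gpu_enable_fast_min_max is set, lower kMaximum/kMinimum
  // directly to the llvm.maxnum/llvm.minnum intrinsics instead of going
  // through the default elemental lowering. These intrinsics follow IEEE
  // maxNum/minNum semantics, so a NaN operand is dropped in favor of the
  // other operand.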
  if (hlo_module_config_.debug_options().xla_gpu_enable_fast_min_max() &&
      (opcode == HloOpcode::kMaximum || opcode == HloOpcode::kMinimum)) {
    return llvm_ir::EmitCallToIntrinsic(
        opcode == HloOpcode::kMaximum ? llvm::Intrinsic::maxnum
                                      : llvm::Intrinsic::minnum,
        {lhs_value, rhs_value}, {lhs_value->getType()}, b_);
  }

  switch (op->opcode()) {
    case HloOpcode::kRemainder: {
      return EmitLibdeviceMathCall("__nv_fmod", {lhs_value, rhs_value},
                                   {lhs_input_type, rhs_input_type},
                                   output_type);
    }
    case HloOpcode::kPower: {
      return EmitPowerOp(op, lhs_value, rhs_value);
    }
    default:
      return ElementalIrEmitter::EmitFloatBinaryOp(op, lhs_value, rhs_value);
  }
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  CHECK_EQ(op->opcode(), HloOpcode::kPower);
  PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
  PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
  PrimitiveType output_type = op->shape().element_type();
  return EmitLibdeviceMathCall("__nv_pow", {lhs_value, rhs_value},
                               {lhs_input_type, rhs_input_type}, output_type);
}

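// The emitters below simply forward to the corresponding libdevice functions,
// letting EmitLibdeviceMathCall handle the f16 -> f32 conversion and the "f"
// suffix for the f32 variants.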
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitErfcInv(
    PrimitiveType prim_type, llvm::Value* value) {
  return EmitLibdeviceMathCall("__nv_erfcinv", {value}, {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitLibdeviceMathCall("__nv_log", {value}, {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
                                                        llvm::Value* value) {
  return EmitLibdeviceMathCall("__nv_log1p", {value}, {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSin(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitLibdeviceMathCall("__nv_sin", {value}, {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitCos(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitLibdeviceMathCall("__nv_cos", {value}, {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitLibdeviceMathCall("__nv_exp", {value}, {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
                                                        llvm::Value* value) {
  return EmitLibdeviceMathCall("__nv_expm1", {value}, {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type,
                                                      llvm::Value* lhs,
                                                      llvm::Value* rhs) {
  return EmitLibdeviceMathCall("__nv_pow", {lhs, rhs}, {prim_type, prim_type},
                               prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSqrt(PrimitiveType prim_type,
                                                       llvm::Value* value) {
  return EmitLibdeviceMathCall("__nv_sqrt", {value}, {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type,
                                                        llvm::Value* value) {
  return EmitLibdeviceMathCall("__nv_rsqrt", {value}, {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
                                                        llvm::Value* lhs,
                                                        llvm::Value* rhs) {
  return EmitLibdeviceMathCall("__nv_atan2", {lhs, rhs}, {prim_type, prim_type},
                               prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
                                                       llvm::Value* value) {
  // Emit a fast approximation of tanh instead of calling __nv_tanh.
  // __nv_tanh is particularly bad because it contains branches, thus
  // preventing LLVM's load-store vectorizer from working its magic across a
  // function which contains tanh calls.
  //
  // This routine isn't numerically precise, but it's good enough for ML.

  // Upcast F16 to F32 if necessary.
  llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType();
  llvm::Value* input = FPCast(value, type);
  llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
  return FPCast(fast_tanh, value->getType());
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitRoundNearestAfz(
    PrimitiveType prim_type, llvm::Value* value) {
  // Use libdevice __nv_round instead of llvm.round to work around a bug in
  // the PTX backend, which implements llvm.round with PTX cvt.rni. Even once
  // llvm.round is fixed, we may still want to use __nv_round here: expanding
  // its non-trivial implementation early, while inlining, allows better
  // optimizations.
  return EmitLibdeviceMathCall("__nv_round", {value}, {prim_type}, prim_type);
}

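// Declares `callee_name` in the current module if it is not declared already,
// attaches the given function attributes to the declaration, and emits a call
// to it with `operands`.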
llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
    const string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
    absl::Span<const llvm::Attribute::AttrKind> attributes) {
  std::vector<llvm::Type*> ir_input_types;
  for (PrimitiveType input_type : input_types) {
    ir_input_types.push_back(
        llvm_ir::PrimitiveTypeToIrType(input_type, module_));
  }
  llvm::FunctionType* callee_type = llvm::FunctionType::get(
      llvm_ir::PrimitiveTypeToIrType(output_type, module_),  // Return type.
      ir_input_types,                                        // Parameter types.
      false);  // No variadic arguments.

  // Declares the callee if it is not declared already.
  llvm::Function* callee = llvm::dyn_cast<llvm::Function>(
      b_->GetInsertBlock()
          ->getModule()
          ->getOrInsertFunction(callee_name, callee_type)
          .getCallee());

  for (auto attribute : attributes) {
    callee->addFnAttr(attribute);
  }

  return Call(callee, llvm_ir::AsArrayRef(operands));
}

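// Returns the linear thread id along the x dimension: the ctaid.x, tid.x and
// ntid.x special registers are read, widened to 128-bit integers, and combined
// as block_id * threads_per_block + thread_id_in_block.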
llvm::Value* GpuElementalIrEmitter::EmitThreadId() {
  llvm::Value* block_id =
      IntCast(llvm_ir::EmitCallToIntrinsic(
                  llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b_),
              b_->getIntNTy(128), /*isSigned=*/true, "block.id");
  llvm::Value* thread_id_in_block =
      IntCast(llvm_ir::EmitCallToIntrinsic(
                  llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_),
              b_->getIntNTy(128), /*isSigned=*/true, "thread.id");
  llvm::Value* threads_per_block =
      IntCast(llvm_ir::EmitCallToIntrinsic(
                  llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x, {}, {}, b_),
              b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
  return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
}

llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
    const HloInstruction* hlo,
    const HloToElementGeneratorMap& operand_to_generator) {
  switch (hlo->opcode()) {
    case HloOpcode::kMap:
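      // Evaluate each operand's generator at `index` and apply the mapped
      // computation to the resulting scalars.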
      return [=, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        TF_RET_CHECK(!hlo->operands().empty())
            << "Zero operand map not implemented in GPU backend.";
        TF_RET_CHECK(hlo->to_apply()->num_parameters() > 0);
        std::vector<llvm::Value*> operand_elements;
        for (HloInstruction* operand : hlo->operands()) {
          TF_ASSIGN_OR_RETURN(llvm::Value * value,
                              operand_to_generator.at(operand)(index));
          operand_elements.push_back(value);
        }
        return compute_nested_(*hlo->to_apply(), operand_elements);
      };
    case HloOpcode::kReduceWindow:
      // Pseudocode:
      // for each output index O
      //   value = init_value
      //   for each window index W
      //     for each dimension i from 0 to rank - 1
      //       (input index I)[i] = O[i] * stride[i] + W[i] - pad_low[i]
      //     if I is in bounds of input
      //       value = function(value, input[I])
      //   output[O] = value
      return [=, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        const HloInstruction* operand = hlo->operand(0);
        const Window& window = hlo->window();

        PrimitiveType operand_element_type = operand->shape().element_type();
        llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
            llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
            "reduce_window_accum_ptr", b_);
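        // Seed the accumulator with the reduction's init value before entering
        // the window loops.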
        {
          TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
                              operand_to_generator.at(hlo->operand(1))(
                                  IrArray::Index(index.GetType())));
          Store(init_value, accum_ptr);
        }

        llvm::Type* index_type = index.GetType();
        auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
          return index.GetConstantWithIndexType(c);
        };

        llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type);
        std::vector<int64> window_size;
        for (const auto& dim : window.dimensions()) {
          window_size.push_back(dim.size());
        }
        const IrArray::Index window_index = loops.AddLoopsForShape(
            ShapeUtil::MakeShape(operand_element_type, window_size), "window");
        CHECK_EQ(window_index.size(), index.size());

        SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);

        std::vector<llvm::Value*> input_multi_index(index.size());
        llvm::Value* in_bounds = b_->getInt1(true);
        for (size_t i = 0; i < index.size(); ++i) {
          llvm::Value* stridden_index = NSWMul(
              index[i], index_typed_const(window.dimensions(i).stride()));
          input_multi_index[i] = NSWSub(
              NSWAdd(stridden_index,
                     NSWMul(window_index[i],
                            index_typed_const(
                                window.dimensions(i).window_dilation()))),
              index_typed_const(window.dimensions(i).padding_low()));

          // We need to verify that we are not in the dilated base area.
          llvm::Value* dilation_condition = ICmpEQ(
              SRem(input_multi_index[i],
                   index_typed_const(window.dimensions(i).base_dilation())),
              index_typed_const(0));
          in_bounds = And(in_bounds, dilation_condition);

          // Apply base dilation to the index.
          input_multi_index[i] =
              SDiv(input_multi_index[i],
                   index_typed_const(window.dimensions(i).base_dilation()));

          // We must check whether 0 ≤ input_multi_index[i] < bound, as
          // otherwise we are in the pad and so can skip the computation. This
          // comparison is equivalent to the unsigned comparison
          // input_multi_index[i] < bound, as a negative value wraps to a large
          // positive value.
          in_bounds =
              And(in_bounds,
                  ICmpULT(input_multi_index[i],
                          index_typed_const(operand->shape().dimensions(i))));
        }

        llvm_ir::LlvmIfData if_data =
            llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
        SetToFirstInsertPoint(if_data.true_block, b_);

        // We are not in pad, so do the computation.
        IrArray::Index input_index(input_multi_index, operand->shape(),
                                   index_type);
        TF_ASSIGN_OR_RETURN(llvm::Value * input_value,
                            operand_to_generator.at(operand)(input_index));
        TF_ASSIGN_OR_RETURN(
            llvm::Value * accum_value,
            compute_nested_(*hlo->to_apply(), {Load(accum_ptr), input_value}));
        Store(accum_value, accum_ptr);

        SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
        return Load(accum_ptr);
      };
    case HloOpcode::kReduce:
      // TODO(b/118332391): This should be supported.
      CHECK_EQ(hlo->operand_count(), 2) << "Did not expect variadic reduce";
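      // Accumulate over the dimensions being reduced: the accumulator alloca
      // is initialized with init_value, and each element of the operand
      // (indexed by the reduction loops for the reduced dimensions and by
      // output_index for the rest) is combined into it via hlo->to_apply().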
      return [=, &operand_to_generator](
                 const IrArray::Index& output_index) -> StatusOr<llvm::Value*> {
        const HloInstruction* operand = hlo->operand(0);
        llvm::Value* accum_ptr =
            b()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType(
                hlo->shape().element_type(), module_));
        llvm::Type* index_type = output_index.GetType();
        TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
                            operand_to_generator.at(hlo->operand(1))(
                                IrArray::Index(index_type)));
        b()->CreateStore(init_value, accum_ptr);

        llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type);
        std::vector<llvm::Value*> input_multi_index =
            loops.AddLoopsForShapeOnDimensions(
                operand->shape(), hlo->dimensions(), "reduction_dim");
        if (!ShapeUtil::IsScalar(hlo->shape())) {
          // Here only input_multi_index[hlo->dimensions()] are non-null, so we
          // must set the rest.
          size_t j = 0;
          for (auto& i : input_multi_index) {
            if (i == nullptr) {
              i = output_index[j++];
            }
          }
          CHECK_EQ(output_index.size(), j);
        }
        llvm_ir::IrArray::Index input_index(
            input_multi_index, hlo->operand(0)->shape(), index_type);

        SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b());
        TF_ASSIGN_OR_RETURN(
            llvm::Value * input_value,
            operand_to_generator.at(hlo->operand(0))(input_index));
        TF_ASSIGN_OR_RETURN(
            llvm::Value * accum_value,
            compute_nested_(*hlo->to_apply(),
                            {b()->CreateLoad(accum_ptr), input_value}));
        b()->CreateStore(accum_value, accum_ptr);
        SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b());
        return b()->CreateLoad(accum_ptr);
      };
    default:
      return ElementalIrEmitter::MakeElementGenerator(hlo,
                                                      operand_to_generator);
  }
}

}  // namespace gpu
}  // namespace xla