1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
17
18 #include <algorithm>
19 #include <deque>
20 #include <ostream>
21 #include <set>
22 #include <unordered_set>
23 #include <utility>
24
25 #include "tensorflow/compiler/xla/layout_util.h"
26 #include "tensorflow/compiler/xla/literal_util.h"
27 #include "tensorflow/compiler/xla/protobuf_util.h"
28 #include "tensorflow/compiler/xla/ptr_util.h"
29 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
30 #include "tensorflow/compiler/xla/service/hlo_computation.h"
31 #include "tensorflow/compiler/xla/service/hlo_module.h"
32 #include "tensorflow/compiler/xla/service/name_uniquer.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/status_macros.h"
35 #include "tensorflow/compiler/xla/types.h"
36 #include "tensorflow/compiler/xla/util.h"
37 #include "tensorflow/compiler/xla/window_util.h"
38 #include "tensorflow/core/lib/core/errors.h"
39 #include "tensorflow/core/lib/gtl/flatmap.h"
40 #include "tensorflow/core/lib/strings/str_util.h"
41 #include "tensorflow/core/lib/strings/strcat.h"
42 #include "tensorflow/core/platform/logging.h"
43
44 namespace xla {
45
46 using tensorflow::str_util::CEscape;
47 using ::tensorflow::str_util::Join;
48 using ::tensorflow::strings::StrAppend;
49 using ::tensorflow::strings::StrCat;
50
51 /* static */
CreateFromProto(HloModule * module,const HloInstructionProto & proto,const tensorflow::gtl::FlatMap<string,HloInstruction * > & instruction_map,const tensorflow::gtl::FlatMap<string,HloComputation * > & computation_map,const std::function<void (std::unique_ptr<HloComputation>)> & add_fused_computation)52 StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
53 HloModule* module, const HloInstructionProto& proto,
54 const tensorflow::gtl::FlatMap<string, HloInstruction*>& instruction_map,
55 const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
56 const std::function<void(std::unique_ptr<HloComputation>)>&
57 add_fused_computation) {
58 TF_RET_CHECK(!proto.opcode().empty());
59 TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode()));
60 TF_RET_CHECK(proto.has_shape());
61
62 auto instruction = WrapUnique(new HloInstruction(opcode, proto.shape()));
63 for (const string& operand_name : proto.operand_names()) {
64 TF_RET_CHECK(ContainsKey(instruction_map, operand_name))
65 << "No instruction named " << operand_name;
66 instruction->AppendOperand(instruction_map.at(operand_name));
67 }
68 for (const string& predecessor_name : proto.control_predecessor_names()) {
69 TF_RET_CHECK(ContainsKey(instruction_map, predecessor_name))
70 << "No instruction named " << predecessor_name;
71 TF_RETURN_IF_ERROR(instruction_map.at(predecessor_name)
72 ->AddControlDependencyTo(instruction.get()));
73 }
74
75 // In the proto, fused computations are held exclusively within the
76 // HloInstructionProto and do not appear as an HloComputationProto within the
77 // HloModuleProto.
78 if (instruction->opcode() == HloOpcode::kFusion) {
79 TF_RET_CHECK(proto.has_fused_instructions_computation());
80 TF_RET_CHECK(!proto.fusion_kind().empty());
81 TF_ASSIGN_OR_RETURN(instruction->fusion_kind_,
82 StringToFusionKind(proto.fusion_kind()));
83 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> fused_computation,
84 HloComputation::CreateFromProto(
85 module, proto.fused_instructions_computation(),
86 computation_map, add_fused_computation,
87 /*fusion_instruction=*/instruction.get()));
88 instruction->called_computations_.push_back(fused_computation.get());
89 add_fused_computation(std::move(fused_computation));
90 } else {
91 for (const string& computation_name : proto.called_computation_names()) {
92 TF_RET_CHECK(ContainsKey(computation_map, computation_name))
93 << "No computation named " << computation_name;
94 instruction->called_computations_.push_back(
95 computation_map.at(computation_name));
96 }
97 }
98
99 TF_RET_CHECK(!proto.name().empty());
100 instruction->name_ = proto.name();
101
102 instruction->metadata_ = proto.metadata();
103 if (proto.has_literal()) {
104 TF_ASSIGN_OR_RETURN(instruction->literal_,
105 Literal::CreateFromProto(proto.literal()));
106 }
107 instruction->parameter_number_ = proto.parameter_number();
108
109 instruction->tuple_index_ = proto.tuple_index();
110 for (int64 dimension : proto.dimensions()) {
111 instruction->dimensions_.push_back(dimension);
112 }
113 if (proto.has_window()) {
114 instruction->window_ = MakeUnique<Window>(proto.window());
115 }
116 if (proto.has_convolution_dimension_numbers()) {
117 instruction->convolution_dimension_numbers_ =
118 MakeUnique<ConvolutionDimensionNumbers>(
119 proto.convolution_dimension_numbers());
120 }
121 if (proto.has_dot_dimension_numbers()) {
122 instruction->dot_dimension_numbers_ =
123 MakeUnique<DotDimensionNumbers>(proto.dot_dimension_numbers());
124 }
125 for (const HloInstructionProto::SliceDimensions& slice_dimensions :
126 proto.slice_dimensions()) {
127 instruction->slice_starts_.push_back(slice_dimensions.start());
128 instruction->slice_limits_.push_back(slice_dimensions.limit());
129 instruction->slice_strides_.push_back(slice_dimensions.stride());
130 }
131 instruction->exponent_bits_ = proto.exponent_bits();
132 instruction->mantissa_bits_ = proto.mantissa_bits();
133 for (int64 dynamic_slice_size : proto.dynamic_slice_sizes()) {
134 instruction->dynamic_slice_sizes_.push_back(dynamic_slice_size);
135 }
136 if (proto.has_padding_config()) {
137 instruction->padding_config_ =
138 MakeUnique<PaddingConfig>(proto.padding_config());
139 }
140 instruction->outfeed_config_ = proto.outfeed_config();
141 instruction->distribution_ = proto.distribution();
142 instruction->epsilon_ = proto.epsilon();
143 instruction->feature_index_ = proto.feature_index();
144 instruction->channel_id_ = proto.channel_id();
145 instruction->infeed_config_ = proto.infeed_config();
146 instruction->custom_call_target_ = proto.custom_call_target();
147 instruction->outfeed_shape_ = proto.outfeed_shape();
148 instruction->fft_type_ = proto.fft_type();
149 for (int64 fft_len : proto.fft_length()) {
150 instruction->fft_length_.push_back(fft_len);
151 }
152
153 return std::move(instruction);
154 }
155
CreateParameter(int64 parameter_number,const Shape & shape,const string & name)156 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter(
157 int64 parameter_number, const Shape& shape, const string& name) {
158 auto instruction =
159 WrapUnique(new HloInstruction(HloOpcode::kParameter, shape));
160 instruction->parameter_number_ = parameter_number;
161 instruction->name_ = name;
162 return instruction;
163 }
164
CreateTrace(const string & tag,HloInstruction * operand)165 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTrace(
166 const string& tag, HloInstruction* operand) {
167 auto instruction =
168 WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()));
169 instruction->operands_.push_back(operand);
170 instruction->literal_ = Literal::CreateR1U8(tag);
171 return instruction;
172 }
173
CreateConstant(std::unique_ptr<Literal> literal)174 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
175 std::unique_ptr<Literal> literal) {
176 auto instruction =
177 WrapUnique(new HloInstruction(HloOpcode::kConstant, literal->shape()));
178 instruction->literal_ = std::move(literal);
179 return instruction;
180 }
181
182 /* static */ std::unique_ptr<HloInstruction>
CreateGetTupleElement(const Shape & shape,HloInstruction * operand,int64 index)183 HloInstruction::CreateGetTupleElement(const Shape& shape,
184 HloInstruction* operand, int64 index) {
185 auto instruction =
186 WrapUnique(new HloInstruction(HloOpcode::kGetTupleElement, shape));
187 instruction->tuple_index_ = index;
188 instruction->AppendOperand(operand);
189 return instruction;
190 }
191
CreateRng(const Shape & shape,RandomDistribution distribution,tensorflow::gtl::ArraySlice<HloInstruction * > parameters)192 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng(
193 const Shape& shape, RandomDistribution distribution,
194 tensorflow::gtl::ArraySlice<HloInstruction*> parameters) {
195 auto instruction = WrapUnique(new HloInstruction(HloOpcode::kRng, shape));
196 instruction->distribution_ = distribution;
197 instruction->shape_ = shape;
198 for (HloInstruction* param : parameters) {
199 instruction->AppendOperand(param);
200 }
201 return instruction;
202 }
203
CreateNary(const Shape & shape,HloOpcode opcode,tensorflow::gtl::ArraySlice<HloInstruction * > operands)204 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
205 const Shape& shape, HloOpcode opcode,
206 tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
207 if (opcode == HloOpcode::kCopy) {
208 // It is impossible to copy an opaque shape, we don't know how big it is.
209 CHECK(!ShapeUtil::IsOpaque(shape));
210 }
211 auto instruction = WrapUnique(new HloInstruction(opcode, shape));
212 for (auto operand : operands) {
213 instruction->AppendOperand(operand);
214 }
215 return instruction;
216 }
217
CreateUnary(const Shape & shape,HloOpcode opcode,HloInstruction * operand)218 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateUnary(
219 const Shape& shape, HloOpcode opcode, HloInstruction* operand) {
220 // Only certain opcodes are supported with CreateUnary: opcodes of unary
221 // instructions with no auxiliary fields.
222 switch (opcode) {
223 case HloOpcode::kAbs:
224 case HloOpcode::kRoundNearestAfz:
225 case HloOpcode::kBitcast:
226 case HloOpcode::kCeil:
227 case HloOpcode::kCopy:
228 case HloOpcode::kCos:
229 case HloOpcode::kExp:
230 case HloOpcode::kFloor:
231 case HloOpcode::kImag:
232 case HloOpcode::kIsFinite:
233 case HloOpcode::kLog:
234 case HloOpcode::kNot:
235 case HloOpcode::kNegate:
236 case HloOpcode::kReal:
237 case HloOpcode::kSign:
238 case HloOpcode::kSin:
239 case HloOpcode::kSort:
240 case HloOpcode::kTanh:
241 break;
242 default:
243 LOG(FATAL) << "Invalid unary instruction opcode "
244 << HloOpcodeString(opcode);
245 }
246 return CreateNary(shape, opcode, {operand});
247 }
248
CreateBinary(const Shape & shape,HloOpcode opcode,HloInstruction * lhs,HloInstruction * rhs)249 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBinary(
250 const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
251 HloInstruction* rhs) {
252 // Only certain opcodes are supported with CreateBinary: opcodes of binary
253 // instructions with no auxiliary fields.
254 switch (opcode) {
255 case HloOpcode::kAdd:
256 case HloOpcode::kAtan2:
257 case HloOpcode::kDivide:
258 case HloOpcode::kComplex:
259 case HloOpcode::kDot:
260 case HloOpcode::kEq:
261 case HloOpcode::kGe:
262 case HloOpcode::kGt:
263 case HloOpcode::kLe:
264 case HloOpcode::kLt:
265 case HloOpcode::kMaximum:
266 case HloOpcode::kMinimum:
267 case HloOpcode::kMultiply:
268 case HloOpcode::kNe:
269 case HloOpcode::kPower:
270 case HloOpcode::kRemainder:
271 case HloOpcode::kSubtract:
272 case HloOpcode::kAnd:
273 case HloOpcode::kOr:
274 case HloOpcode::kShiftLeft:
275 case HloOpcode::kShiftRightArithmetic:
276 case HloOpcode::kShiftRightLogical:
277 break;
278 default:
279 LOG(FATAL) << "Invalid binary instruction opcode "
280 << HloOpcodeString(opcode);
281 }
282 return CreateNary(shape, opcode, {lhs, rhs});
283 }
284
CreateTernary(const Shape & shape,HloOpcode opcode,HloInstruction * lhs,HloInstruction * rhs,HloInstruction * ehs)285 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTernary(
286 const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
287 HloInstruction* rhs, HloInstruction* ehs) {
288 // Only certain opcodes are supported with CreateTernary: opcodes of ternary
289 // instructions with no auxiliary fields.
290 switch (opcode) {
291 case (HloOpcode::kClamp):
292 case (HloOpcode::kSelect):
293 break;
294 default:
295 LOG(FATAL) << "Invalid ternary instruction opcode "
296 << HloOpcodeString(opcode);
297 }
298 return CreateNary(shape, opcode, {lhs, rhs, ehs});
299 }
300
CreateVariadic(const Shape & shape,HloOpcode opcode,tensorflow::gtl::ArraySlice<HloInstruction * > operands)301 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateVariadic(
302 const Shape& shape, HloOpcode opcode,
303 tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
304 CHECK_EQ(HloOpcode::kTuple, opcode);
305 return CreateNary(shape, opcode, operands);
306 }
307
CreateMap(const Shape & shape,tensorflow::gtl::ArraySlice<HloInstruction * > operands,HloComputation * map_computation,tensorflow::gtl::ArraySlice<HloInstruction * > static_operands)308 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap(
309 const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
310 HloComputation* map_computation,
311 tensorflow::gtl::ArraySlice<HloInstruction*> static_operands) {
312 CHECK(static_operands.empty()) << "static_operands not yet supported";
313 auto instruction = WrapUnique(new HloInstruction(HloOpcode::kMap, shape));
314 for (auto operand : operands) {
315 instruction->AppendOperand(operand);
316 }
317 instruction->called_computations_.push_back(map_computation);
318 return instruction;
319 }
320
CreateConvolve(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,const Window & window,const ConvolutionDimensionNumbers & dimension_numbers)321 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
322 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
323 const Window& window,
324 const ConvolutionDimensionNumbers& dimension_numbers) {
325 auto instruction =
326 WrapUnique(new HloInstruction(HloOpcode::kConvolution, shape));
327 if (window_util::HasBaseDilation(window)) {
328 instruction->name_ = instruction->name() + "-base-dilated";
329 }
330 if (window_util::HasWindowDilation(window)) {
331 instruction->name_ = instruction->name() + "-window-dilated";
332 }
333 instruction->AppendOperand(lhs);
334 instruction->AppendOperand(rhs);
335 instruction->window_ = MakeUnique<Window>(window);
336 instruction->convolution_dimension_numbers_ =
337 MakeUnique<ConvolutionDimensionNumbers>(dimension_numbers);
338 return instruction;
339 }
340
CreateFft(const Shape & shape,HloInstruction * operand,FftType fft_type,tensorflow::gtl::ArraySlice<int64> fft_length)341 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
342 const Shape& shape, HloInstruction* operand, FftType fft_type,
343 tensorflow::gtl::ArraySlice<int64> fft_length) {
344 auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFft, shape));
345 instruction->AppendOperand(operand);
346 instruction->fft_type_ = fft_type;
347 instruction->fft_length_.assign(fft_length.begin(), fft_length.end());
348 return instruction;
349 }
350
CreateDot(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,const DotDimensionNumbers & dimension_numbers)351 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
352 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
353 const DotDimensionNumbers& dimension_numbers) {
354 auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
355 instruction->AppendOperand(lhs);
356 instruction->AppendOperand(rhs);
357 instruction->dot_dimension_numbers_ =
358 MakeUnique<DotDimensionNumbers>(dimension_numbers);
359 return instruction;
360 }
361
CreateCanonicalDot(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs)362 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCanonicalDot(
363 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) {
364 CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2);
365 CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2);
366
367 auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
368 instruction->AppendOperand(lhs);
369 instruction->AppendOperand(rhs);
370 instruction->dot_dimension_numbers_ = MakeUnique<DotDimensionNumbers>();
371 instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1);
372 instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0);
373 return instruction;
374 }
375
376 /* static */ std::unique_ptr<HloInstruction>
CreateReducePrecision(const Shape & shape,HloInstruction * operand,const int exponent_bits,const int mantissa_bits)377 HloInstruction::CreateReducePrecision(const Shape& shape,
378 HloInstruction* operand,
379 const int exponent_bits,
380 const int mantissa_bits) {
381 auto instruction =
382 WrapUnique(new HloInstruction(HloOpcode::kReducePrecision, shape));
383 instruction->AppendOperand(operand);
384 instruction->exponent_bits_ = exponent_bits;
385 instruction->mantissa_bits_ = mantissa_bits;
386 return instruction;
387 }
388
389 /* static */ std::unique_ptr<HloInstruction>
CreateCrossReplicaSum(const Shape & shape,tensorflow::gtl::ArraySlice<HloInstruction * > operands)390 HloInstruction::CreateCrossReplicaSum(
391 const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
392 return CreateNary(shape, HloOpcode::kCrossReplicaSum, operands);
393 }
394
CreateInfeed(const Shape & shape,const string & config)395 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
396 const Shape& shape, const string& config) {
397 auto instruction = WrapUnique(new HloInstruction(HloOpcode::kInfeed, shape));
398 instruction->set_infeed_config(config);
399 return instruction;
400 }
401
CreateOutfeed(const Shape & shape,HloInstruction * operand,tensorflow::StringPiece outfeed_config)402 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
403 const Shape& shape, HloInstruction* operand,
404 tensorflow::StringPiece outfeed_config) {
405 std::unique_ptr<HloInstruction> instruction =
406 WrapUnique(new HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeNil()));
407 CHECK(ShapeUtil::Compatible(operand->shape(), shape))
408 << "Outfeed shape " << shape << " must be compatible with operand shape "
409 << operand->shape();
410 instruction->AppendOperand(operand);
411 instruction->outfeed_config_ = outfeed_config.ToString();
412 instruction->outfeed_shape_ = shape;
413 return instruction;
414 }
415
CreateSend(HloInstruction * operand,int64 channel_id)416 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
417 HloInstruction* operand, int64 channel_id) {
418 // Send instruction produces a tuple of {aliased operand, U32 context}.
419 Shape output_shape = ShapeUtil::MakeTupleShape(
420 {operand->shape(), ShapeUtil::MakeShape(U32, {})});
421 auto instruction =
422 WrapUnique(new HloInstruction(HloOpcode::kSend, output_shape));
423 instruction->AppendOperand(operand);
424 instruction->channel_id_ = channel_id;
425 return instruction;
426 }
427
CreateSendDone(HloInstruction * operand)428 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
429 HloInstruction* operand) {
430 CHECK(operand->opcode() == HloOpcode::kSend)
431 << "SendDone must take the context operand from Send";
432 auto instruction = WrapUnique(
433 new HloInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil()));
434 instruction->AppendOperand(operand);
435 instruction->channel_id_ = operand->channel_id();
436 return instruction;
437 }
438
CreateRecv(const Shape & shape,int64 channel_id)439 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
440 const Shape& shape, int64 channel_id) {
441 // Recv instruction produces a tuple of {receive buffer, U32 context}.
442 Shape output_shape =
443 ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})});
444 auto instruction =
445 WrapUnique(new HloInstruction(HloOpcode::kRecv, output_shape));
446 instruction->channel_id_ = channel_id;
447 return instruction;
448 }
449
CreateRecvDone(HloInstruction * operand)450 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
451 HloInstruction* operand) {
452 CHECK(operand->opcode() == HloOpcode::kRecv)
453 << "RecvDone must take the context operand from Recv";
454 Shape output_shape = ShapeUtil::GetTupleElementShape(operand->shape(), 0);
455 auto instruction =
456 WrapUnique(new HloInstruction(HloOpcode::kRecvDone, output_shape));
457 instruction->AppendOperand(operand);
458 instruction->channel_id_ = operand->channel_id();
459 return instruction;
460 }
461
CreateReverse(const Shape & shape,HloInstruction * operand,tensorflow::gtl::ArraySlice<int64> dimensions)462 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
463 const Shape& shape, HloInstruction* operand,
464 tensorflow::gtl::ArraySlice<int64> dimensions) {
465 auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReverse, shape));
466 instruction->AppendOperand(operand);
467 instruction->dimensions_.assign(dimensions.begin(), dimensions.end());
468 return instruction;
469 }
470
CreateWhile(const Shape & shape,HloComputation * condition,HloComputation * body,HloInstruction * init)471 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
472 const Shape& shape, HloComputation* condition, HloComputation* body,
473 HloInstruction* init) {
474 auto instruction = WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
475 instruction->AppendOperand(init);
476 // Body comes before condition computation in the vector.
477 instruction->called_computations_.push_back(body);
478 instruction->called_computations_.push_back(condition);
479 return instruction;
480 }
481
CreateConditional(const Shape & shape,HloInstruction * pred,HloInstruction * true_computation_arg,HloComputation * true_computation,HloInstruction * false_computation_arg,HloComputation * false_computation)482 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
483 const Shape& shape, HloInstruction* pred,
484 HloInstruction* true_computation_arg, HloComputation* true_computation,
485 HloInstruction* false_computation_arg, HloComputation* false_computation) {
486 auto instruction =
487 WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
488 instruction->AppendOperand(pred);
489 instruction->AppendOperand(true_computation_arg);
490 instruction->AppendOperand(false_computation_arg);
491 // In called_computations_, the index of true_computation must be 0 and that
492 // of false computation must be 1, as defined by kTrueComputationIndex and
493 // kFalseComputationIndex.
494 instruction->called_computations_.push_back(true_computation);
495 instruction->called_computations_.push_back(false_computation);
496 return instruction;
497 }
498
CreateSlice(const Shape & shape,HloInstruction * operand,tensorflow::gtl::ArraySlice<int64> start_indices,tensorflow::gtl::ArraySlice<int64> limit_indices,tensorflow::gtl::ArraySlice<int64> strides)499 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
500 const Shape& shape, HloInstruction* operand,
501 tensorflow::gtl::ArraySlice<int64> start_indices,
502 tensorflow::gtl::ArraySlice<int64> limit_indices,
503 tensorflow::gtl::ArraySlice<int64> strides) {
504 auto instruction = WrapUnique(new HloInstruction(HloOpcode::kSlice, shape));
505 instruction->AppendOperand(operand);
506 instruction->slice_starts_.assign(start_indices.begin(), start_indices.end());
507 instruction->slice_limits_.assign(limit_indices.begin(), limit_indices.end());
508 instruction->slice_strides_.assign(strides.begin(), strides.end());
509 // For backward compatibility with old serialized computations: if there are
510 // no strides, assume all strides are 1.
511 // TODO(b/63317920): remove this code.
512 if (instruction->slice_strides_.empty()) {
513 instruction->slice_strides_ = std::vector<int64>(start_indices.size(), 1LL);
514 }
515 return instruction;
516 }
517
CreateDynamicSlice(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,tensorflow::gtl::ArraySlice<int64> slice_sizes)518 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice(
519 const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
520 tensorflow::gtl::ArraySlice<int64> slice_sizes) {
521 auto instruction =
522 WrapUnique(new HloInstruction(HloOpcode::kDynamicSlice, shape));
523 instruction->AppendOperand(operand);
524 instruction->AppendOperand(start_indices);
525 instruction->dynamic_slice_sizes_.assign(slice_sizes.begin(),
526 slice_sizes.end());
527 return instruction;
528 }
529
530 /* static */ std::unique_ptr<HloInstruction>
CreateDynamicUpdateSlice(const Shape & shape,HloInstruction * operand,HloInstruction * update,HloInstruction * start_indices)531 HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
532 HloInstruction* operand,
533 HloInstruction* update,
534 HloInstruction* start_indices) {
535 auto instruction =
536 WrapUnique(new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape));
537 instruction->AppendOperand(operand);
538 instruction->AppendOperand(update);
539 instruction->AppendOperand(start_indices);
540 return instruction;
541 }
542
CreateConcatenate(const Shape & shape,tensorflow::gtl::ArraySlice<HloInstruction * > operands,int64 dimension)543 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
544 const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
545 int64 dimension) {
546 auto instruction =
547 WrapUnique(new HloInstruction(HloOpcode::kConcatenate, shape));
548 for (auto operand : operands) {
549 instruction->AppendOperand(operand);
550 }
551 instruction->dimensions_.push_back(dimension);
552 return instruction;
553 }
554
CreateConvert(const Shape & shape,HloInstruction * operand)555 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvert(
556 const Shape& shape, HloInstruction* operand) {
557 auto instruction = WrapUnique(new HloInstruction(HloOpcode::kConvert, shape));
558 instruction->AppendOperand(operand);
559 return instruction;
560 }
561
562 /* static */ std::unique_ptr<HloInstruction>
CreateBitcastConvert(const Shape & shape,HloInstruction * operand)563 HloInstruction::CreateBitcastConvert(const Shape& shape,
564 HloInstruction* operand) {
565 auto instruction =
566 WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape));
567 instruction->AppendOperand(operand);
568 return instruction;
569 }
570
CreateReduce(const Shape & shape,HloInstruction * arg,HloInstruction * init_value,tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,HloComputation * reduce_computation)571 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
572 const Shape& shape, HloInstruction* arg, HloInstruction* init_value,
573 tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
574 HloComputation* reduce_computation) {
575 auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReduce, shape));
576 instruction->AppendOperand(arg);
577 instruction->AppendOperand(init_value);
578 instruction->dimensions_.assign(dimensions_to_reduce.begin(),
579 dimensions_to_reduce.end());
580 instruction->called_computations_.push_back(reduce_computation);
581 return instruction;
582 }
583
CreateReduceWindow(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,const Window & window,HloComputation * reduce_computation)584 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
585 const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
586 const Window& window, HloComputation* reduce_computation) {
587 auto instruction =
588 WrapUnique(new HloInstruction(HloOpcode::kReduceWindow, shape));
589 instruction->AppendOperand(operand);
590 instruction->AppendOperand(init_value);
591 instruction->called_computations_.push_back(reduce_computation);
592 instruction->window_ = MakeUnique<Window>(window);
593 return instruction;
594 }
595
596 /* static */ std::unique_ptr<HloInstruction>
CreateBatchNormTraining(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,float epsilon,int64 feature_index)597 HloInstruction::CreateBatchNormTraining(const Shape& shape,
598 HloInstruction* operand,
599 HloInstruction* scale,
600 HloInstruction* offset, float epsilon,
601 int64 feature_index) {
602 auto instruction =
603 WrapUnique(new HloInstruction(HloOpcode::kBatchNormTraining, shape));
604 instruction->AppendOperand(operand);
605 instruction->AppendOperand(scale);
606 instruction->AppendOperand(offset);
607 instruction->epsilon_ = epsilon;
608 instruction->feature_index_ = feature_index;
609 return instruction;
610 }
611
612 /* static */ std::unique_ptr<HloInstruction>
CreateBatchNormInference(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,HloInstruction * mean,HloInstruction * variance,float epsilon,int64 feature_index)613 HloInstruction::CreateBatchNormInference(
614 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
615 HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
616 float epsilon, int64 feature_index) {
617 auto instruction =
618 WrapUnique(new HloInstruction(HloOpcode::kBatchNormInference, shape));
619 instruction->AppendOperand(operand);
620 instruction->AppendOperand(scale);
621 instruction->AppendOperand(offset);
622 instruction->AppendOperand(mean);
623 instruction->AppendOperand(variance);
624 instruction->epsilon_ = epsilon;
625 instruction->feature_index_ = feature_index;
626 return instruction;
627 }
628
629 /* static */ std::unique_ptr<HloInstruction>
CreateBatchNormGrad(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * mean,HloInstruction * variance,HloInstruction * grad_output,float epsilon,int64 feature_index)630 HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
631 HloInstruction* scale, HloInstruction* mean,
632 HloInstruction* variance,
633 HloInstruction* grad_output, float epsilon,
634 int64 feature_index) {
635 auto instruction =
636 WrapUnique(new HloInstruction(HloOpcode::kBatchNormGrad, shape));
637 instruction->AppendOperand(operand);
638 instruction->AppendOperand(scale);
639 instruction->AppendOperand(mean);
640 instruction->AppendOperand(variance);
641 instruction->AppendOperand(grad_output);
642 instruction->epsilon_ = epsilon;
643 instruction->feature_index_ = feature_index;
644 return instruction;
645 }
646
647 /* static */ std::unique_ptr<HloInstruction>
CreateSelectAndScatter(const Shape & shape,HloInstruction * operand,HloComputation * select,const Window & window,HloInstruction * source,HloInstruction * init_value,HloComputation * scatter)648 HloInstruction::CreateSelectAndScatter(
649 const Shape& shape, HloInstruction* operand, HloComputation* select,
650 const Window& window, HloInstruction* source, HloInstruction* init_value,
651 HloComputation* scatter) {
652 auto instruction =
653 WrapUnique(new HloInstruction(HloOpcode::kSelectAndScatter, shape));
654 instruction->AppendOperand(operand);
655 instruction->AppendOperand(source);
656 instruction->AppendOperand(init_value);
657 // Select comes before scatter in the vector.
658 instruction->called_computations_.push_back(select);
659 instruction->called_computations_.push_back(scatter);
660 instruction->window_ = MakeUnique<Window>(window);
661 return instruction;
662 }
663
CreateBroadcast(const Shape & shape,HloInstruction * operand,tensorflow::gtl::ArraySlice<int64> broadcast_dimensions)664 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast(
665 const Shape& shape, HloInstruction* operand,
666 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
667 auto instruction =
668 WrapUnique(new HloInstruction(HloOpcode::kBroadcast, shape));
669 instruction->AppendOperand(operand);
670 instruction->dimensions_.assign(broadcast_dimensions.begin(),
671 broadcast_dimensions.end());
672 return instruction;
673 }
674
675 /* static */ std::unique_ptr<HloInstruction>
CreateBroadcastSequence(const Shape & output_shape,HloInstruction * operand,const std::function<HloInstruction * (std::unique_ptr<HloInstruction>)> & adder)676 HloInstruction::CreateBroadcastSequence(
677 const Shape& output_shape, HloInstruction* operand,
678 const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
679 adder) {
680 CHECK(ShapeUtil::IsScalar(operand->shape()) ||
681 ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape));
682 Shape broadcast_shape = ShapeUtil::ChangeElementType(
683 output_shape, operand->shape().element_type());
684 // Do explicit broadcast for scalar.
685 if (ShapeUtil::IsScalar(operand->shape())) {
686 auto broadcast =
687 HloInstruction::CreateBroadcast(broadcast_shape, operand, {});
688 broadcast->set_metadata(operand->metadata());
689 if (operand->has_sharding()) {
690 broadcast->set_sharding(operand->sharding());
691 }
692 return broadcast;
693 }
694 // Do explicit broadcast for degenerate broadcast.
695 std::vector<int64> broadcast_dimensions;
696 std::vector<int64> reshaped_dimensions;
697 for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) {
698 if (operand->shape().dimensions(i) == output_shape.dimensions(i)) {
699 broadcast_dimensions.push_back(i);
700 reshaped_dimensions.push_back(operand->shape().dimensions(i));
701 } else {
702 CHECK_EQ(operand->shape().dimensions(i), 1)
703 << "An explicit broadcast sequence requires the broadcasted "
704 "dimensions to be trivial; operand: "
705 << operand->ToString() << "; output_shape: " << output_shape;
706 }
707 }
708 // Eliminate the size one dimensions.
709 HloInstruction* reshaped_operand = adder(HloInstruction::CreateReshape(
710 ShapeUtil::MakeShape(operand->shape().element_type(),
711 reshaped_dimensions),
712 operand));
713 reshaped_operand->set_metadata(operand->metadata());
714 if (operand->has_sharding()) {
715 reshaped_operand->set_sharding(operand->sharding());
716 }
717 // Broadcast 'reshape' up to the larger size.
718 auto broadcast = HloInstruction::CreateBroadcast(
719 broadcast_shape, reshaped_operand, broadcast_dimensions);
720 broadcast->set_metadata(operand->metadata());
721 if (operand->has_sharding()) {
722 broadcast->set_sharding(operand->sharding());
723 }
724 return broadcast;
725 }
726
CreatePad(const Shape & shape,HloInstruction * operand,HloInstruction * padding_value,const PaddingConfig & padding_config)727 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePad(
728 const Shape& shape, HloInstruction* operand, HloInstruction* padding_value,
729 const PaddingConfig& padding_config) {
730 auto instruction = WrapUnique(new HloInstruction(HloOpcode::kPad, shape));
731 instruction->AppendOperand(operand);
732 instruction->AppendOperand(padding_value);
733 instruction->padding_config_ = MakeUnique<PaddingConfig>(padding_config);
734 return instruction;
735 }
736
CreateReshape(const Shape & shape,HloInstruction * operand)737 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
738 const Shape& shape, HloInstruction* operand) {
739 CHECK_EQ(ShapeUtil::ElementsIn(shape),
740 ShapeUtil::ElementsIn(operand->shape()))
741 << "shape: " << ShapeUtil::HumanString(shape)
742 << " operand: " << ShapeUtil::HumanString(operand->shape());
743 auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReshape, shape));
744 instruction->AppendOperand(operand);
745 return instruction;
746 }
747
CreateTranspose(const Shape & shape,HloInstruction * operand,tensorflow::gtl::ArraySlice<int64> dimensions)748 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
749 const Shape& shape, HloInstruction* operand,
750 tensorflow::gtl::ArraySlice<int64> dimensions) {
751 CHECK_EQ(shape.dimensions().size(), dimensions.size());
752 CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size());
753 CHECK(std::equal(operand->shape().dimensions().begin(),
754 operand->shape().dimensions().end(),
755 Permute(dimensions, shape.dimensions()).begin()))
756 << "shape: " << ShapeUtil::HumanString(shape)
757 << ", operand->shape(): " << ShapeUtil::HumanString(shape)
758 << ", dimensions: {" << Join(dimensions, ", ") << "}";
759 auto instruction =
760 WrapUnique(new HloInstruction(HloOpcode::kTranspose, shape));
761 instruction->AppendOperand(operand);
762 instruction->dimensions_.assign(dimensions.begin(), dimensions.end());
763 return instruction;
764 }
765
766 // We put the fusion kind into the instruction's name for transpose-dot fusions,
767 // since those fusions are really just describing a type of dot rather than
768 // generating a novel computation.
FusionNodeName(HloInstruction::FusionKind fusion_kind)769 static string FusionNodeName(HloInstruction::FusionKind fusion_kind) {
770 switch (fusion_kind) {
771 case HloInstruction::FusionKind::kTransposeDot:
772 return "dot_fusion";
773 default:
774 return "fusion";
775 }
776 }
777
CreateFusion(const Shape & shape,FusionKind fusion_kind,HloInstruction * fused_root)778 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
779 const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
780 auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
781 instruction->fusion_kind_ = fusion_kind;
782 instruction->name_ = FusionNodeName(fusion_kind);
783 instruction->set_parent(fused_root->parent());
784 instruction->set_metadata(fused_root->metadata());
785 instruction->CloneAndFuseInternal(fused_root);
786 return instruction;
787 }
788
CreateFusion(const Shape & shape,FusionKind fusion_kind,tensorflow::gtl::ArraySlice<HloInstruction * > operands,HloComputation * fusion_computation)789 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
790 const Shape& shape, FusionKind fusion_kind,
791 tensorflow::gtl::ArraySlice<HloInstruction*> operands,
792 HloComputation* fusion_computation) {
793 auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
794 for (auto operand : operands) {
795 instruction->AppendOperand(operand);
796 }
797 instruction->fusion_kind_ = fusion_kind;
798 instruction->name_ = FusionNodeName(fusion_kind);
799 instruction->called_computations_.push_back(fusion_computation);
800 fusion_computation->SetFusionInstruction(instruction.get());
801 return instruction;
802 }
803
AddFusionOperand(HloInstruction * new_operand)804 HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) {
805 CHECK_EQ(opcode(), HloOpcode::kFusion);
806 CHECK_EQ(operand_count(),
807 fused_instructions_computation()->parameter_instructions().size());
808 const int64 param_no = operand_count();
809 // Name the parameter after the instruction it represents in the outer
810 // (non-fusion) computation.
811 string param_name = StrCat(new_operand->name(), ".param_", param_no);
812 HloInstruction* fused_parameter =
813 fused_instructions_computation()->AddParameter(
814 HloInstruction::CreateParameter(param_no, new_operand->shape(),
815 param_name));
816 AppendOperand(new_operand);
817 return fused_parameter;
818 }
819
// Merges the fusion instruction `instruction_to_merge`, which must be one of
// this fusion instruction's operands, into this fusion instruction.
//
// `instruction_to_merge` is cloned. In the clone's fused computation, each
// fused parameter is wired directly to the clone's corresponding operand. The
// remaining (non-parameter) fused instructions are then fused one by one into
// `this`. Finally, the clone's fused computation is removed from the module
// (parent()->parent()).
void HloInstruction::MergeFusionInstruction(
    HloInstruction* instruction_to_merge) {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  CHECK_EQ(instruction_to_merge->opcode(), HloOpcode::kFusion);
  // `instruction_to_merge` must currently be an operand of this fusion.
  CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) !=
        operands().end());
  // Clone the instruction from which to merge fused instructions.
  std::unique_ptr<HloInstruction> clone = instruction_to_merge->Clone();
  // Replace uses of fused parameters with the corresponding operand of the
  // fusion. Add all non-parameter fused instructions to 'unfused_instructions'
  // to be merged into 'this'. This is done in reverse post order.
  std::vector<HloInstruction*> unfused_instructions;
  auto fused_instructions =
      clone->fused_instructions_computation()->MakeInstructionPostOrder();
  for (auto fused_it = fused_instructions.rbegin();
       fused_it != fused_instructions.rend(); ++fused_it) {
    auto fused_instruction = *fused_it;
    if (fused_instruction->opcode() == HloOpcode::kParameter) {
      TF_CHECK_OK(fused_instruction->ReplaceAllUsesWith(
          clone->mutable_operand(fused_instruction->parameter_number())));
    } else {
      unfused_instructions.push_back(fused_instruction);
    }
  }
  // The first non-parameter instruction in reverse post order must be the
  // clone's fused root. NOTE(review): front() is undefined behavior if
  // unfused_instructions is empty (a fused computation of only parameters).
  CHECK(unfused_instructions.front() == clone->fused_expression_root());
  // Where 'this' used instruction_to_merge, make it use the unfused root
  // instead.
  TF_CHECK_OK(
      instruction_to_merge->ReplaceUseWith(this, unfused_instructions.front()));
  // Fuse 'unfused_instructions' into 'this', starting from the root. Each one
  // is detached from its operands once it has been fused.
  for (auto& instruction : unfused_instructions) {
    FuseInstruction(instruction);
    instruction->DetachFromOperands();
  }
  // Nothing may still use the clone. Detach it, and drop its fused computation
  // from the module.
  CHECK_EQ(0, clone->user_count());
  clone->DetachFromOperands();
  TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation(
      clone->fused_instructions_computation()));
}
858
// Merges the fusion instruction `instruction_to_merge` into this fusion
// instruction. The merged fusion's root is fused with
// FuseInstructionIntoMultiOutput, which generates multiple outputs; its other
// instructions are fused normally.
//
// Unlike MergeFusionInstruction, no CHECK here requires `instruction_to_merge`
// to be an operand of `this`. The merged instruction is removed from its
// computation, and its fused computation is removed from the module when
// GetModule() is non-null.
void HloInstruction::MergeFusionInstructionIntoMultiOutput(
    HloInstruction* instruction_to_merge) {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  CHECK_EQ(instruction_to_merge->opcode(), HloOpcode::kFusion);
  // Add all non-parameter fused instructions to 'unfused_instructions' to be
  // merged into 'this'. `old_to_new` maps the instructions in the fused node
  // to the disassembled fusion instructions; fused parameters map to the
  // corresponding operands of instruction_to_merge.
  // Note that we add the unfused instructions to this->parent_ computation.
  // This is necessary because an instruction needs a unique_id, and the
  // unique_id is only assigned when the instruction is inserted into a
  // computation.
  tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> old_to_new;
  std::vector<HloInstruction*> unfused_instructions;
  auto computation_to_merge =
      instruction_to_merge->fused_instructions_computation();
  auto post_order = computation_to_merge->MakeInstructionPostOrder();
  for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) {
    auto fused_instruction = *rit;
    if (fused_instruction->opcode() == HloOpcode::kParameter) {
      InsertOrDie(&old_to_new, fused_instruction,
                  instruction_to_merge->mutable_operand(
                      fused_instruction->parameter_number()));
      continue;
    }

    // Here we clone the instruction and call FuseInstructionIntoMultiOutput()
    // which clones again. This can be improved.
    auto cloned_instruction =
        parent_->AddInstruction(fused_instruction->Clone());
    unfused_instructions.push_back(cloned_instruction);
    InsertOrDie(&old_to_new, fused_instruction, cloned_instruction);
  }
  // Each clone still points at the instructions of computation_to_merge; rewire
  // every operand to its counterpart in old_to_new.
  for (auto unfused_instruction : unfused_instructions) {
    for (int64 index = 0; index < unfused_instruction->operand_count();
         index++) {
      auto new_operand =
          FindOrDie(old_to_new, unfused_instruction->mutable_operand(index));
      TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand));
    }
  }

  // Presumably the fused root: the first non-parameter instruction in reverse
  // post order. NOTE(review): unlike MergeFusionInstruction there is no CHECK
  // for this, and front() is undefined behavior on an empty vector.
  HloInstruction* unfused_root = unfused_instructions.front();
  TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root));

  TF_CHECK_OK(
      instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge));
  if (GetModule()) {
    TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge));
  }

  // Fuse the root instruction and generate multiple outputs.
  FuseInstructionIntoMultiOutput(unfused_root);
  TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root));
  // The remaining instructions are fused normally, in reverse post order, and
  // each temporary clone is removed from the outer computation once fused.
  for (int64 i = 1; i < unfused_instructions.size(); i++) {
    auto instruction = unfused_instructions[i];
    FuseInstruction(instruction);
    TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
  }
}
918
FuseInstructionInternal(HloInstruction * instruction_to_fuse,bool add_output)919 HloInstruction* HloInstruction::FuseInstructionInternal(
920 HloInstruction* instruction_to_fuse, bool add_output) {
921 CHECK_EQ(opcode_, HloOpcode::kFusion);
922
923 // When add_output is false, this fusion instruction must be a user of
924 // instruction_to_fuse.
925 if (!add_output) {
926 CHECK(IsUserOf(instruction_to_fuse));
927 }
928 HloInstruction* fused_instruction =
929 CloneAndFuseInternal(instruction_to_fuse, add_output);
930 return fused_instruction;
931 }
932
// Clones `instruction_to_fuse` into this fusion instruction's fused computation
// and returns the clone.
//
// - If this fusion has no called computation yet, a new fused computation
//   containing only the clone is built and added to the module.
// - Otherwise the clone is added to the existing fused computation. If
//   `instruction_to_fuse` is an operand of this fusion, that operand and its
//   fused parameter are removed, and the parameter's uses are redirected to the
//   clone.
// - Each operand of the clone is then rewired to a fused parameter. A new
//   fusion operand/parameter pair is created via AddFusionOperand when the
//   operand is not already an operand of this fusion.
// - When `add_output` is true, this becomes a multi-output fusion, unless
//   `instruction_to_fuse` has no users left once this fusion stops using it.
//   The fused root becomes (or is extended as) a tuple that also holds the
//   clone's value(s), and this instruction's shape becomes that tuple shape.
//   Outside users are redirected to get-tuple-element instructions on `this`.
HloInstruction* HloInstruction::CloneAndFuseInternal(
    HloInstruction* instruction_to_fuse, bool add_output) {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString();
  VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString();
  HloInstruction* clone = nullptr;
  if (called_computations_.empty()) {
    // New fusion instruction. It should not be a multioutput instruction.
    CHECK(!add_output);
    auto builder = HloComputation::Builder("fused_computation", this);
    builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/""));
    called_computations_.push_back(
        CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()));
    clone = fused_expression_root();
  } else {
    clone = fused_instructions_computation()->AddInstruction(
        instruction_to_fuse->Clone(/*suffix=*/""));
    // When add_output is false, instruction_to_fuse is necessarily an operand
    // of the fusion instruction. After fusion this will no longer be the case.
    // Remove the operand from the operand list and remove its corresponding
    // fused parameter instruction. Renumber parameters as necessary to make
    // parameter numbers consistent with their index in the fused computation's
    // parameter list.
    bool in_operand_list = std::find(operands_.begin(), operands_.end(),
                                     instruction_to_fuse) != operands_.end();
    CHECK(add_output || in_operand_list);
    const std::vector<HloInstruction*>& fused_parameters =
        fused_instructions_computation()->parameter_instructions();
    for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
      if (instruction_to_fuse == operands_[operand_num]) {
        // Replace the fused parameter instruction's uses with the clone.
        HloInstruction* fused_parameter = fused_parameters[operand_num];
        TF_CHECK_OK(fused_parameter->ReplaceAllUsesWith(clone));

        // Remove the corresponding fused parameter and operand from their
        // respective vectors.
        TF_CHECK_OK(
            fused_instructions_computation()->RemoveParameter(operand_num));
        operands_.erase(operands_.begin() + operand_num);
        break;
      }
    }
    // We've cloned instruction_to_fuse into this fusion instruction, so this
    // fusion instruction is no longer a use of instruction_to_fuse.
    if (in_operand_list) {
      instruction_to_fuse->RemoveUser(this);
      // When the instruction_to_fuse does not have other users, we don't need
      // to generate a multioutput fusion instruction.
      if (instruction_to_fuse->user_count() == 0) {
        add_output = false;
      }
    }
  }

  // Reread the parameters in the computation. This is a reference, presumably
  // to the computation's own parameter list, so the CHECK_EQ below also sees
  // parameters added by AddFusionOperand inside the loop.
  const std::vector<HloInstruction*>& fused_parameters =
      fused_instructions_computation()->parameter_instructions();

  // Add each operand of the clone as an operand of the fusion instruction. A
  // complication is that some clone operands may already be operands of the
  // fusion instruction.
  for (int64 operand_num = 0; operand_num < clone->operand_count();
       ++operand_num) {
    HloInstruction* operand = clone->mutable_operand(operand_num);

    // See if this operand is already an operand of the fusion node.
    CHECK_EQ(operands_.size(), fused_parameters.size());
    HloInstruction* fused_param = nullptr;
    for (int64 i = 0; i < operands_.size(); ++i) {
      if (operands_[i] == operand) {
        fused_param = fused_parameters[i];
        break;
      }
    }

    if (fused_param == nullptr) {
      // Clone's operand was not already an operand of the fusion
      // instruction. Add it as an operand and add a corresponding fused
      // parameter instruction.
      fused_param = AddFusionOperand(operand);
    }
    TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param));
  }

  if (add_output) {
    CHECK_GT(instruction_to_fuse->user_count(), 0);
    // If this is already a multioutput fusion instruction, expand the root
    // tuple by 1 (or by the clone's arity when the clone is itself a tuple).
    HloInstruction* fused_root = fused_expression_root();
    HloInstruction::InstructionVector tuple_elements;
    bool newly_created_tuple_instr = false;
    if (fused_root->opcode() == HloOpcode::kTuple) {
      tuple_elements = fused_root->operands();
    } else {
      tuple_elements.push_back(fused_root);
      newly_created_tuple_instr = true;
    }
    if (clone->opcode() == HloOpcode::kTuple) {
      for (auto inst : clone->operands()) {
        tuple_elements.push_back(inst);
      }
    } else {
      tuple_elements.push_back(clone);
    }
    HloInstruction* new_root = fused_instructions_computation()->AddInstruction(
        HloInstruction::CreateTuple(tuple_elements));
    fused_instructions_computation()->set_root_instruction(new_root);
    shape_ = new_root->shape();
    if (fused_root->opcode() == HloOpcode::kTuple) {
      TF_CHECK_OK(
          fused_instructions_computation()->RemoveInstruction(fused_root));
    }

    // If this is a newly created multioutput instruction, we need to update
    // the use of the original fusion instruction: former users now read
    // element 0, which is the old (non-tuple) root's value.
    if (newly_created_tuple_instr) {
      HloInstruction* new_instr = parent_->AddInstruction(
          HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0));
      TF_CHECK_OK(ReplaceAllUsesWith(new_instr));
    }
    // The clone's value(s) were appended last, so they start at
    // tuple_elements.size() minus the number of values appended.
    int64 index = tuple_elements.size();
    if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
      index -= instruction_to_fuse->operand_count();
      std::vector<HloInstruction*> to_be_removed;
      for (auto old_gte : instruction_to_fuse->users()) {
        CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement);
        int64 old_tuple_index = old_gte->tuple_index();
        HloInstruction* new_gte =
            parent_->AddInstruction(HloInstruction::CreateGetTupleElement(
                old_gte->shape(), this, index + old_tuple_index));
        TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte));
        to_be_removed.push_back(old_gte);
      }
      for (auto old_gte : to_be_removed) {
        TF_CHECK_OK(parent_->RemoveInstruction(old_gte));
      }
      // NOTE(review): `clone` is removed from the fused computation here, but
      // it is still logged and returned below. If RemoveInstruction destroys
      // the instruction (verify in HloComputation), both uses dangle.
      TF_CHECK_OK(fused_instructions_computation()->RemoveInstruction(clone));
    } else {
      HloInstruction* new_gte =
          parent_->AddInstruction(HloInstruction::CreateGetTupleElement(
              clone->shape(), this, index - 1));
      TF_CHECK_OK(instruction_to_fuse->ReplaceAllUsesWith(new_gte));
    }
  }

  VLOG(2) << "New clone:\n" << clone->ToString();
  return clone;
}
1081
random_distribution() const1082 RandomDistribution HloInstruction::random_distribution() const {
1083 CHECK_EQ(opcode_, HloOpcode::kRng);
1084 return distribution_;
1085 }
1086
HasSideEffect() const1087 bool HloInstruction::HasSideEffect() const {
1088 switch (opcode_) {
1089 case HloOpcode::kSend:
1090 case HloOpcode::kSendDone:
1091 case HloOpcode::kRecv:
1092 case HloOpcode::kRecvDone:
1093 case HloOpcode::kRng:
1094 case HloOpcode::kInfeed:
1095 case HloOpcode::kOutfeed:
1096 case HloOpcode::kTrace:
1097 case HloOpcode::kHostCompute:
1098 return true;
1099 default: {
1100 // Check if any of the called computations has a side effect.
1101 for (const auto& computation : called_computations()) {
1102 if (computation->HasSideEffect()) {
1103 return true;
1104 }
1105 }
1106 return false;
1107 }
1108 }
1109 }
1110
CreateCall(const Shape & shape,tensorflow::gtl::ArraySlice<HloInstruction * > operands,HloComputation * computation)1111 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
1112 const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
1113 HloComputation* computation) {
1114 std::unique_ptr<HloInstruction> instruction =
1115 WrapUnique(new HloInstruction(HloOpcode::kCall, shape));
1116 for (auto operand : operands) {
1117 instruction->AppendOperand(operand);
1118 }
1119 instruction->called_computations_.push_back(computation);
1120 return instruction;
1121 }
1122
CreateCustomCall(const Shape & shape,tensorflow::gtl::ArraySlice<HloInstruction * > operands,tensorflow::StringPiece custom_call_target)1123 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1124 const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
1125 tensorflow::StringPiece custom_call_target) {
1126 std::unique_ptr<HloInstruction> instruction =
1127 WrapUnique(new HloInstruction(HloOpcode::kCustomCall, shape));
1128 for (auto operand : operands) {
1129 instruction->AppendOperand(operand);
1130 }
1131 instruction->custom_call_target_ = custom_call_target.ToString();
1132 return instruction;
1133 }
1134
CreateHostCompute(const Shape & shape,tensorflow::gtl::ArraySlice<HloInstruction * > operands,tensorflow::StringPiece channel_name,const int64 cost_estimate_ns)1135 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateHostCompute(
1136 const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
1137 tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) {
1138 std::unique_ptr<HloInstruction> instruction =
1139 WrapUnique(new HloInstruction(HloOpcode::kHostCompute, shape));
1140 for (auto operand : operands) {
1141 instruction->AppendOperand(operand);
1142 }
1143 instruction->channel_name_ = channel_name.ToString();
1144 instruction->cost_estimate_ns_ = cost_estimate_ns;
1145 return instruction;
1146 }
1147
CreateTuple(tensorflow::gtl::ArraySlice<HloInstruction * > elements)1148 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
1149 tensorflow::gtl::ArraySlice<HloInstruction*> elements) {
1150 std::vector<Shape> element_shapes;
1151 for (auto element : elements) {
1152 element_shapes.push_back(element->shape());
1153 }
1154 Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes);
1155 return CreateVariadic(tuple_shape, HloOpcode::kTuple, elements);
1156 }
1157
CreateGather(const Shape & shape,HloInstruction * operand,HloInstruction * gather_indices,const GatherDimensionNumbers & gather_dim_numbers,tensorflow::gtl::ArraySlice<int64> window_bounds)1158 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather(
1159 const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices,
1160 const GatherDimensionNumbers& gather_dim_numbers,
1161 tensorflow::gtl::ArraySlice<int64> window_bounds) {
1162 std::unique_ptr<HloInstruction> instruction =
1163 WrapUnique(new HloInstruction(HloOpcode::kGather, shape));
1164 instruction->AppendOperand(operand);
1165 instruction->AppendOperand(gather_indices);
1166 instruction->gather_dimension_numbers_ =
1167 MakeUnique<GatherDimensionNumbers>(gather_dim_numbers);
1168 c_copy(window_bounds, std::back_inserter(instruction->gather_window_bounds_));
1169 return instruction;
1170 }
1171
MakeGatherDimNumbers(tensorflow::gtl::ArraySlice<int64> output_window_dims,tensorflow::gtl::ArraySlice<int64> elided_window_dims,tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims)1172 /* static */ GatherDimensionNumbers HloInstruction::MakeGatherDimNumbers(
1173 tensorflow::gtl::ArraySlice<int64> output_window_dims,
1174 tensorflow::gtl::ArraySlice<int64> elided_window_dims,
1175 tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims) {
1176 GatherDimensionNumbers gather_dim_numbers;
1177 for (int64 output_window_dim : output_window_dims) {
1178 gather_dim_numbers.add_output_window_dims(output_window_dim);
1179 }
1180 for (int64 elided_window_dim : elided_window_dims) {
1181 gather_dim_numbers.add_elided_window_dims(elided_window_dim);
1182 }
1183 for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) {
1184 gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim);
1185 }
1186
1187 return gather_dim_numbers;
1188 }
1189
CloneWithNewOperands(const Shape & shape,tensorflow::gtl::ArraySlice<HloInstruction * > new_operands,HloModule * module) const1190 std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
1191 const Shape& shape,
1192 tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
1193 HloModule* module) const {
1194 VLOG(3) << "CloneWithNewOperands:\n " << ToString();
1195 VLOG(3) << " new operands:";
1196 for (const HloInstruction* new_operand : new_operands) {
1197 VLOG(3) << " %" << new_operand->name();
1198 }
1199
1200 std::unique_ptr<HloInstruction> clone;
1201
1202 // Explicitly call the factory for the instruction type. This is more robust
1203 // in the face of code changes than copying fields explicitly. This also
1204 // properly sets the user fields of the operands.
1205 switch (opcode_) {
1206 // Unary ops.
1207 case HloOpcode::kAbs:
1208 case HloOpcode::kRoundNearestAfz:
1209 case HloOpcode::kBitcast:
1210 case HloOpcode::kCeil:
1211 case HloOpcode::kCopy:
1212 case HloOpcode::kCos:
1213 case HloOpcode::kExp:
1214 case HloOpcode::kImag:
1215 case HloOpcode::kIsFinite:
1216 case HloOpcode::kFloor:
1217 case HloOpcode::kLog:
1218 case HloOpcode::kNot:
1219 case HloOpcode::kNegate:
1220 case HloOpcode::kReal:
1221 case HloOpcode::kSign:
1222 case HloOpcode::kSin:
1223 case HloOpcode::kSort:
1224 case HloOpcode::kTanh:
1225 CHECK_EQ(new_operands.size(), 1);
1226 clone = CreateUnary(shape, opcode_, new_operands[0]);
1227 break;
1228 // Binary ops.
1229 case HloOpcode::kAdd:
1230 case HloOpcode::kAtan2:
1231 case HloOpcode::kComplex:
1232 case HloOpcode::kDivide:
1233 case HloOpcode::kMultiply:
1234 case HloOpcode::kSubtract:
1235 case HloOpcode::kEq:
1236 case HloOpcode::kGe:
1237 case HloOpcode::kGt:
1238 case HloOpcode::kLe:
1239 case HloOpcode::kLt:
1240 case HloOpcode::kNe:
1241 case HloOpcode::kMaximum:
1242 case HloOpcode::kMinimum:
1243 case HloOpcode::kPower:
1244 case HloOpcode::kRemainder:
1245 case HloOpcode::kAnd:
1246 case HloOpcode::kOr:
1247 case HloOpcode::kShiftLeft:
1248 case HloOpcode::kShiftRightArithmetic:
1249 case HloOpcode::kShiftRightLogical:
1250 CHECK_EQ(new_operands.size(), 2);
1251 clone = CreateBinary(shape, opcode_, new_operands[0], new_operands[1]);
1252 break;
1253 // Ternary ops.
1254 case HloOpcode::kClamp:
1255 case HloOpcode::kSelect:
1256 CHECK_EQ(new_operands.size(), 3);
1257 clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1],
1258 new_operands[2]);
1259 break;
1260 // Other supported ops.
1261 case HloOpcode::kBroadcast:
1262 CHECK_EQ(new_operands.size(), 1);
1263 clone = CreateBroadcast(shape, new_operands[0], dimensions_);
1264 break;
1265 case HloOpcode::kCall:
1266 clone = CreateCall(shape, new_operands, to_apply());
1267 break;
1268 case HloOpcode::kCustomCall:
1269 clone = CreateCustomCall(shape, new_operands, custom_call_target_);
1270 break;
1271 case HloOpcode::kHostCompute:
1272 clone = CreateHostCompute(shape, new_operands, channel_name_,
1273 cost_estimate_ns_);
1274 break;
1275 case HloOpcode::kConcatenate:
1276 clone = CreateConcatenate(shape, new_operands, dimensions(0));
1277 break;
1278 case HloOpcode::kConvert:
1279 CHECK_EQ(new_operands.size(), 1);
1280 clone = CreateConvert(shape, new_operands[0]);
1281 break;
1282 case HloOpcode::kBitcastConvert:
1283 CHECK_EQ(new_operands.size(), 1);
1284 clone = CreateBitcastConvert(shape, new_operands[0]);
1285 break;
1286 case HloOpcode::kReducePrecision:
1287 CHECK_EQ(new_operands.size(), 1);
1288 clone = CreateReducePrecision(shape, new_operands[0], exponent_bits_,
1289 mantissa_bits_);
1290 break;
1291 case HloOpcode::kConvolution:
1292 CHECK_EQ(new_operands.size(), 2);
1293 clone = CreateConvolve(shape, new_operands[0], new_operands[1], *window_,
1294 *convolution_dimension_numbers_);
1295 break;
1296 case HloOpcode::kDot:
1297 CHECK_EQ(new_operands.size(), 2);
1298 clone = CreateDot(shape, new_operands[0], new_operands[1],
1299 *dot_dimension_numbers_);
1300 break;
1301 case HloOpcode::kFft:
1302 CHECK_EQ(new_operands.size(), 1);
1303 return CreateFft(shape, new_operands[0], fft_type_, fft_length_);
1304 case HloOpcode::kCrossReplicaSum:
1305 clone = CreateCrossReplicaSum(shape, new_operands);
1306 break;
1307 case HloOpcode::kGetTupleElement:
1308 CHECK_EQ(new_operands.size(), 1);
1309 clone = CreateGetTupleElement(shape, new_operands[0], tuple_index());
1310 break;
1311 case HloOpcode::kMap:
1312 clone = CreateMap(shape, new_operands, to_apply());
1313 break;
1314 case HloOpcode::kPad:
1315 CHECK_EQ(new_operands.size(), 2);
1316 clone =
1317 CreatePad(shape, new_operands[0], new_operands[1], *padding_config_);
1318 break;
1319 case HloOpcode::kReduce:
1320 CHECK_EQ(new_operands.size(), 2);
1321 clone = CreateReduce(shape, new_operands[0], new_operands[1], dimensions_,
1322 to_apply());
1323 break;
1324 case HloOpcode::kReduceWindow:
1325 CHECK_EQ(new_operands.size(), 2);
1326 clone = CreateReduceWindow(shape, new_operands[0], new_operands[1],
1327 *window_, to_apply());
1328 break;
1329 case HloOpcode::kSelectAndScatter:
1330 CHECK_EQ(new_operands.size(), 3);
1331 clone =
1332 CreateSelectAndScatter(shape, new_operands[0], select(), *window_,
1333 new_operands[1], new_operands[2], scatter());
1334 break;
1335 case HloOpcode::kReverse:
1336 CHECK_EQ(new_operands.size(), 1);
1337 clone = CreateReverse(shape, new_operands[0], dimensions_);
1338 break;
1339 case HloOpcode::kRng:
1340 clone = CreateRng(shape, distribution_, new_operands);
1341 break;
1342 case HloOpcode::kReshape:
1343 CHECK_EQ(new_operands.size(), 1);
1344 clone = CreateReshape(shape, new_operands[0]);
1345 break;
1346 case HloOpcode::kSlice:
1347 CHECK_EQ(new_operands.size(), 1);
1348 clone = CreateSlice(shape, new_operands[0], slice_starts_, slice_limits_,
1349 slice_strides_);
1350 break;
1351 case HloOpcode::kDynamicSlice:
1352 clone = CreateDynamicSlice(shape, new_operands[0], new_operands[1],
1353 dynamic_slice_sizes_);
1354 break;
1355 case HloOpcode::kDynamicUpdateSlice:
1356 CHECK_EQ(new_operands.size(), 3);
1357 clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1],
1358 new_operands[2]);
1359 break;
1360 case HloOpcode::kTranspose:
1361 CHECK_EQ(new_operands.size(), 1);
1362 clone = CreateTranspose(shape, new_operands[0], dimensions_);
1363 break;
1364 case HloOpcode::kTuple:
1365 clone = CreateTuple(new_operands);
1366 *clone->mutable_shape() = shape;
1367 break;
1368 case HloOpcode::kWhile:
1369 CHECK_EQ(new_operands.size(), 1);
1370 clone =
1371 CreateWhile(shape, while_condition(), while_body(), new_operands[0]);
1372 break;
1373 case HloOpcode::kConstant:
1374 clone = CreateConstant(literal_->CloneToUnique());
1375 break;
1376 case HloOpcode::kFusion:
1377 clone = CloneFusionWithNewOperands(shape, new_operands, module);
1378 break;
1379 case HloOpcode::kParameter:
1380 clone = CreateParameter(parameter_number_, shape, name_);
1381 break;
1382 case HloOpcode::kBatchNormTraining:
1383 CHECK_EQ(new_operands.size(), 3);
1384 clone =
1385 CreateBatchNormTraining(shape, new_operands[0], new_operands[1],
1386 new_operands[2], epsilon(), feature_index());
1387 break;
1388 case HloOpcode::kBatchNormInference:
1389 CHECK_EQ(new_operands.size(), 5);
1390 clone = CreateBatchNormInference(
1391 shape, new_operands[0], new_operands[1], new_operands[2],
1392 new_operands[3], new_operands[4], epsilon(), feature_index());
1393 break;
1394 case HloOpcode::kInfeed:
1395 CHECK_EQ(new_operands.size(), 0);
1396 clone = CreateInfeed(shape, infeed_config());
1397 break;
1398 case HloOpcode::kOutfeed:
1399 CHECK_EQ(new_operands.size(), 1);
1400 clone = CreateOutfeed(outfeed_shape_, new_operands[0], outfeed_config());
1401 break;
1402 case HloOpcode::kBatchNormGrad:
1403 CHECK_EQ(new_operands.size(), 5);
1404 clone = CreateBatchNormGrad(shape, new_operands[0], new_operands[1],
1405 new_operands[2], new_operands[3],
1406 new_operands[4], epsilon(), feature_index());
1407 break;
1408 case HloOpcode::kConditional:
1409 CHECK_EQ(new_operands.size(), 3);
1410 clone = CreateConditional(shape, new_operands[0], new_operands[1],
1411 true_computation(), new_operands[2],
1412 false_computation());
1413 break;
1414 case HloOpcode::kSend:
1415 CHECK_EQ(new_operands.size(), 1);
1416 clone = CreateSend(new_operands[0], channel_id());
1417 break;
1418 case HloOpcode::kSendDone:
1419 CHECK_EQ(new_operands.size(), 1);
1420 clone = CreateSendDone(new_operands[0]);
1421 break;
1422 case HloOpcode::kRecv:
1423 CHECK_EQ(new_operands.size(), 0);
1424 // The shape is a tuple, but CreateRecv() wants the raw data shape.
1425 clone =
1426 CreateRecv(ShapeUtil::GetTupleElementShape(shape, 0), channel_id());
1427 break;
1428 case HloOpcode::kRecvDone:
1429 CHECK_EQ(new_operands.size(), 1);
1430 clone = CreateRecvDone(new_operands[0]);
1431 break;
1432 case HloOpcode::kGather:
1433 CHECK_EQ(new_operands.size(), 2);
1434 clone = CreateGather(shape, new_operands[0], new_operands[1],
1435 *gather_dimension_numbers_, gather_window_bounds_);
1436 break;
1437 case HloOpcode::kTrace:
1438 LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_);
1439 }
1440 clone->set_metadata(metadata_);
1441 if (has_sharding()) {
1442 clone->set_sharding(sharding());
1443 }
1444 clone->set_parent(parent_);
1445 return clone;
1446 }
1447
~HloInstruction()1448 HloInstruction::~HloInstruction() {}
1449
Clone(const string & suffix,HloModule * module) const1450 std::unique_ptr<HloInstruction> HloInstruction::Clone(const string& suffix,
1451 HloModule* module) const {
1452 std::unique_ptr<HloInstruction> clone =
1453 CloneWithNewOperands(shape_, operands_, module);
1454 if (suffix.empty()) {
1455 clone->name_ = name();
1456 } else {
1457 // If an instruction is cloned multiple times avoid names like
1458 // foo.suffix.suffix.suffix. Instead of repeating the suffix add a numeric
1459 // suffix. Specifically, the clone of foo.suffix is named foo.suffix2, the
1460 // clone of foo.suffix2 is named foo.suffix3 and so on.
1461 const string dot_suffix = "." + suffix;
1462 size_t index = name().rfind(dot_suffix);
1463 if (index == string::npos) {
1464 // Existing name does not include ".suffix".
1465 clone->name_ = name() + dot_suffix;
1466 } else {
1467 // Existing name includes ".suffix". Determine if substring after
1468 // ".suffix" is numeric and should be replaced with an incremented number.
1469 string after_suffix = name().substr(index + dot_suffix.size());
1470 if (after_suffix.empty()) {
1471 // Existing name ends in ".suffix". New name should end in ".suffix2".
1472 clone->name_ = name() + "2";
1473 } else {
1474 // If names ends with .suffix[0-9]+ then replace with a suffix with the
1475 // numeric value incremented.
1476 int64 numeric_suffix;
1477 if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) {
1478 clone->name_ =
1479 StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1);
1480 } else {
1481 // Substring after ".suffix" is non-numeric.
1482 clone->name_ = name() + dot_suffix;
1483 }
1484 }
1485 }
1486 }
1487 return clone;
1488 }
1489
CloneFusionWithNewOperands(const Shape & shape,tensorflow::gtl::ArraySlice<HloInstruction * > operands,HloModule * module) const1490 std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
1491 const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
1492 HloModule* module) const {
1493 CHECK_EQ(opcode_, HloOpcode::kFusion);
1494 CHECK(parent() != nullptr);
1495
1496 auto new_instruction =
1497 WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
1498 // Add the operands to our new fusion instruction.
1499 for (HloInstruction* new_operand : operands) {
1500 new_instruction->AppendOperand(new_operand);
1501 }
1502 // Clone all the fused instructions for the new fusion instruction.
1503 HloInstructionMap<HloInstruction*> old_to_new;
1504 std::list<std::unique_ptr<HloInstruction>> new_fused_instructions;
1505 // Create the list of fused parameters by mapping through the cloned,
1506 // fused instructions.
1507 for (HloInstruction* old_fused_parameter :
1508 fused_instructions_computation()->parameter_instructions()) {
1509 new_fused_instructions.push_back(
1510 old_fused_parameter->Clone("clone", module));
1511 HloInstruction* new_fusion_parameter = new_fused_instructions.back().get();
1512 InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter);
1513 }
1514 for (auto old_fused_instruction :
1515 fused_instructions_computation()->MakeInstructionPostOrder()) {
1516 if (old_fused_instruction->opcode() == HloOpcode::kParameter) {
1517 FindOrDie(old_to_new, old_fused_instruction);
1518 continue;
1519 }
1520 std::vector<HloInstruction*> new_operands;
1521 for (int64 operand_idx = 0;
1522 operand_idx < old_fused_instruction->operand_count(); ++operand_idx) {
1523 HloInstruction* old_operand =
1524 old_fused_instruction->mutable_operand(operand_idx);
1525 new_operands.push_back(FindOrDie(old_to_new, old_operand));
1526 }
1527 new_fused_instructions.push_back(
1528 old_fused_instruction->CloneWithNewOperands(
1529 old_fused_instruction->shape(), new_operands, module));
1530 HloInstruction* new_fused_instruction = new_fused_instructions.back().get();
1531 new_fused_instruction->set_parent(parent_);
1532 InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction);
1533 }
1534 new_instruction->fusion_kind_ = fusion_kind_;
1535 auto computation_builder = HloComputation::Builder(
1536 fused_instructions_computation()->name() + ".clone",
1537 new_instruction.get());
1538 // We iterated the fusion instructions in reverse post order which means
1539 // that we must reverse our new list of fusion instructions.
1540 for (auto new_fused_instruction_iter = new_fused_instructions.rbegin();
1541 new_fused_instruction_iter != new_fused_instructions.rend();
1542 ++new_fused_instruction_iter) {
1543 computation_builder.AddInstruction(std::move(*new_fused_instruction_iter));
1544 }
1545 if (module == nullptr) {
1546 module = GetModule();
1547 }
1548 auto fused_root_ = fused_expression_root();
1549 new_instruction->called_computations_.push_back(
1550 CHECK_NOTNULL(module)->AddEmbeddedComputation(
1551 computation_builder.Build(FindOrDie(old_to_new, fused_root_))));
1552 return new_instruction;
1553 }
1554
1555 std::pair<const HloInstruction*, ShapeIndex>
LatestNonGteAncestorAndIndex() const1556 HloInstruction::LatestNonGteAncestorAndIndex() const {
1557 const HloInstruction* hlo = this;
1558 ShapeIndex index;
1559 while (hlo->opcode() == HloOpcode::kGetTupleElement) {
1560 index.push_back(hlo->tuple_index());
1561 hlo = hlo->operand(0);
1562 }
1563
1564 // We built up index in the reverse order from what we want.
1565 std::reverse(index.begin(), index.end());
1566
1567 return {hlo, index};
1568 }
1569
LatestNonGteAncestor() const1570 const HloInstruction* HloInstruction::LatestNonGteAncestor() const {
1571 const HloInstruction* hlo = this;
1572 while (hlo->opcode() == HloOpcode::kGetTupleElement) {
1573 hlo = hlo->operand(0);
1574 }
1575 return hlo;
1576 }
1577
literal() const1578 const Literal& HloInstruction::literal() const {
1579 CHECK_EQ(HloOpcode::kConstant, opcode_);
1580 return *literal_;
1581 }
1582
CanHaveDimensionsField() const1583 bool HloInstruction::CanHaveDimensionsField() const {
1584 return (opcode() == HloOpcode::kReverse ||
1585 opcode() == HloOpcode::kConcatenate ||
1586 opcode() == HloOpcode::kReduce || opcode() == HloOpcode::kBroadcast ||
1587 opcode() == HloOpcode::kTranspose);
1588 }
1589
dimensions() const1590 const std::vector<int64>& HloInstruction::dimensions() const {
1591 CHECK(CanHaveDimensionsField());
1592 return dimensions_;
1593 }
1594
dimensions(int64 index) const1595 int64 HloInstruction::dimensions(int64 index) const {
1596 return dimensions()[index];
1597 }
1598
concatenate_dimension() const1599 int64 HloInstruction::concatenate_dimension() const {
1600 CHECK(opcode() == HloOpcode::kConcatenate);
1601 CHECK_EQ(1, dimensions_.size());
1602 return dimensions(0);
1603 }
1604
tuple_index() const1605 int64 HloInstruction::tuple_index() const {
1606 CHECK_EQ(HloOpcode::kGetTupleElement, opcode_);
1607 return tuple_index_;
1608 }
1609
operand(int64 i) const1610 const HloInstruction* HloInstruction::operand(int64 i) const {
1611 return operands_[i];
1612 }
1613
mutable_operand(int64 i)1614 HloInstruction* HloInstruction::mutable_operand(int64 i) {
1615 CHECK(operands_[i] != nullptr);
1616 return operands_[i];
1617 }
1618
operand_index(const HloInstruction * target) const1619 int64 HloInstruction::operand_index(const HloInstruction* target) const {
1620 for (int64 i = 0; i < operand_count(); ++i) {
1621 if (target == operand(i)) {
1622 return i;
1623 }
1624 }
1625 LOG(FATAL) << "target was not an operand: " << target->ToString();
1626 }
1627
AddControlDependencyTo(HloInstruction * instruction)1628 Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) {
1629 TF_RET_CHECK(instruction->parent() == parent());
1630 if (std::find(control_successors_.begin(), control_successors_.end(),
1631 instruction) == control_successors_.end()) {
1632 control_successors_.push_back(instruction);
1633 TF_RET_CHECK(std::find(instruction->control_predecessors_.begin(),
1634 instruction->control_predecessors_.end(),
1635 this) == instruction->control_predecessors_.end());
1636 instruction->control_predecessors_.push_back(this);
1637 }
1638 return Status::OK();
1639 }
1640
RemoveControlDependencyTo(HloInstruction * instruction)1641 Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) {
1642 auto succ_it = std::find(control_successors_.begin(),
1643 control_successors_.end(), instruction);
1644 TF_RET_CHECK(succ_it != control_successors_.end());
1645 control_successors_.erase(succ_it);
1646 auto pred_it = std::find(instruction->control_predecessors_.begin(),
1647 instruction->control_predecessors_.end(), this);
1648 TF_RET_CHECK(pred_it != instruction->control_predecessors_.end());
1649 instruction->control_predecessors_.erase(pred_it);
1650
1651 return Status::OK();
1652 }
1653
AppendOperand(HloInstruction * operand)1654 void HloInstruction::AppendOperand(HloInstruction* operand) {
1655 operands_.push_back(operand);
1656 operand->AddUser(this);
1657 }
1658
AddUser(HloInstruction * user)1659 void HloInstruction::AddUser(HloInstruction* user) {
1660 if (!ContainsKey(user_set_, user)) {
1661 user_set_.insert(user);
1662 users_.push_back(user);
1663 }
1664 }
1665
IsConstant() const1666 bool HloInstruction::IsConstant() const {
1667 return opcode_ == HloOpcode::kConstant;
1668 }
1669
HasConstantOperand() const1670 bool HloInstruction::HasConstantOperand() const {
1671 for (const HloInstruction* operand : operands_) {
1672 if (operand->IsConstant()) {
1673 return true;
1674 }
1675 }
1676 return false;
1677 }
1678
IdenticalSlowPath(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations,const std::function<bool (const Shape &,const Shape &)> & eq_shapes) const1679 bool HloInstruction::IdenticalSlowPath(
1680 const HloInstruction& other,
1681 const std::function<bool(const HloComputation*, const HloComputation*)>&
1682 eq_computations,
1683 const std::function<bool(const Shape&, const Shape&)>& eq_shapes) const {
1684 // Perform opcode specific checks.
1685 switch (opcode()) {
1686 // The result of these instructions only depend upon their opcode and
1687 // operands.
1688 case HloOpcode::kAbs:
1689 case HloOpcode::kAtan2:
1690 case HloOpcode::kRoundNearestAfz:
1691 case HloOpcode::kAdd:
1692 case HloOpcode::kCeil:
1693 case HloOpcode::kClamp:
1694 case HloOpcode::kComplex:
1695 case HloOpcode::kCopy:
1696 case HloOpcode::kCos:
1697 case HloOpcode::kCrossReplicaSum:
1698 case HloOpcode::kDivide:
1699 case HloOpcode::kEq:
1700 case HloOpcode::kExp:
1701 case HloOpcode::kFloor:
1702 case HloOpcode::kGe:
1703 case HloOpcode::kGt:
1704 case HloOpcode::kImag:
1705 case HloOpcode::kIsFinite:
1706 case HloOpcode::kLe:
1707 case HloOpcode::kLog:
1708 case HloOpcode::kAnd:
1709 case HloOpcode::kNot:
1710 case HloOpcode::kOr:
1711 case HloOpcode::kLt:
1712 case HloOpcode::kMaximum:
1713 case HloOpcode::kMinimum:
1714 case HloOpcode::kMultiply:
1715 case HloOpcode::kNe:
1716 case HloOpcode::kNegate:
1717 case HloOpcode::kPower:
1718 case HloOpcode::kReal:
1719 case HloOpcode::kRemainder:
1720 case HloOpcode::kSelect:
1721 case HloOpcode::kShiftLeft:
1722 case HloOpcode::kShiftRightArithmetic:
1723 case HloOpcode::kShiftRightLogical:
1724 case HloOpcode::kSign:
1725 case HloOpcode::kSin:
1726 case HloOpcode::kSubtract:
1727 case HloOpcode::kTanh:
1728 case HloOpcode::kTuple:
1729 return true;
1730
1731 case HloOpcode::kFusion:
1732 return fusion_kind() == other.fusion_kind() &&
1733 eq_computations(fused_instructions_computation(),
1734 other.fused_instructions_computation());
1735
1736 // These opcodes have complex or special behavior so just return false.
1737 case HloOpcode::kRng:
1738 case HloOpcode::kTrace:
1739 case HloOpcode::kWhile:
1740 return false;
1741
1742 case HloOpcode::kParameter:
1743 return parameter_number() == other.parameter_number() &&
1744 // Check the shape too because `this` and `other` may be in
1745 // different HloComputations.
1746 eq_shapes(shape(), other.shape());
1747
1748 case HloOpcode::kBatchNormTraining:
1749 case HloOpcode::kBatchNormInference:
1750 case HloOpcode::kBatchNormGrad:
1751 return feature_index() == other.feature_index() &&
1752 epsilon() == other.epsilon();
1753
1754 // A constant is defined by the value in the literal.
1755 case HloOpcode::kConstant:
1756 return literal() == other.literal();
1757
1758 // A convert result is determined by the primitive type that the operand is
1759 // converted into.
1760 case HloOpcode::kConvert:
1761 case HloOpcode::kBitcastConvert:
1762 return shape().element_type() == other.shape().element_type();
1763
1764 // A reduce-precision operation is determined by the bit sizes.
1765 case HloOpcode::kReducePrecision:
1766 return exponent_bits() == other.exponent_bits() &&
1767 mantissa_bits() == other.mantissa_bits();
1768
1769 // Convolution has a window and dimensions.
1770 case HloOpcode::kConvolution:
1771 return protobuf_util::ProtobufEquals(window(), other.window()) &&
1772 protobuf_util::ProtobufEquals(
1773 convolution_dimension_numbers(),
1774 other.convolution_dimension_numbers());
1775 // Check dot dimension numbers.
1776 case HloOpcode::kDot:
1777 return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
1778 other.dot_dimension_numbers());
1779
1780 case HloOpcode::kGather:
1781 return protobuf_util::ProtobufEquals(gather_dimension_numbers(),
1782 other.gather_dimension_numbers()) &&
1783 gather_window_bounds() == other.gather_window_bounds();
1784
1785 // FFT has various types & lengths.
1786 case HloOpcode::kFft:
1787 return fft_type() == other.fft_type() &&
1788 fft_length() == other.fft_length();
1789
1790 // Reduction results are determined by the reduction dimension and the
1791 // reduction computation.
1792 case HloOpcode::kReduce:
1793 return dimensions() == other.dimensions() &&
1794 eq_computations(to_apply(), other.to_apply());
1795 case HloOpcode::kReduceWindow:
1796 return eq_computations(to_apply(), other.to_apply()) &&
1797 protobuf_util::ProtobufEquals(window(), other.window());
1798
1799 // SelectAndScatter is determined by both select and scatter
1800 // computation as well as the window configuration.
1801 case HloOpcode::kSelectAndScatter:
1802 return eq_computations(select(), other.select()) &&
1803 eq_computations(scatter(), other.scatter()) &&
1804 protobuf_util::ProtobufEquals(window(), other.window());
1805
1806 case HloOpcode::kReshape:
1807 return eq_shapes(shape(), other.shape());
1808
1809 // Transpose result is determined by the final shape and the permutation.
1810 case HloOpcode::kTranspose:
1811 return eq_shapes(shape(), other.shape()) &&
1812 dimensions() == other.dimensions();
1813
1814 // Remaining instructions with special values.
1815 case HloOpcode::kBitcast:
1816 return eq_shapes(shape(), other.shape());
1817 case HloOpcode::kBroadcast:
1818 return eq_shapes(shape(), other.shape()) &&
1819 dimensions() == other.dimensions();
1820 case HloOpcode::kConcatenate:
1821 return dimensions() == other.dimensions();
1822 case HloOpcode::kGetTupleElement:
1823 return tuple_index() == other.tuple_index();
1824 case HloOpcode::kPad:
1825 return protobuf_util::ProtobufEquals(padding_config(),
1826 other.padding_config());
1827 case HloOpcode::kSlice:
1828 return slice_starts_ == other.slice_starts_ &&
1829 slice_limits_ == other.slice_limits_ &&
1830 slice_strides_ == other.slice_strides_;
1831 case HloOpcode::kDynamicSlice:
1832 return eq_shapes(shape(), other.shape()) &&
1833 dynamic_slice_sizes_ == other.dynamic_slice_sizes_;
1834 case HloOpcode::kDynamicUpdateSlice:
1835 return eq_shapes(shape(), other.shape());
1836 case HloOpcode::kCall:
1837 case HloOpcode::kMap:
1838 return eq_computations(to_apply(), other.to_apply());
1839 case HloOpcode::kCustomCall:
1840 return custom_call_target_ == other.custom_call_target_;
1841 case HloOpcode::kReverse:
1842 return dimensions() == other.dimensions();
1843 case HloOpcode::kConditional:
1844 return eq_computations(true_computation(), other.true_computation()) &&
1845 eq_computations(false_computation(), other.false_computation());
1846
1847 // These opcodes are not yet supported.
1848 case HloOpcode::kInfeed:
1849 case HloOpcode::kOutfeed:
1850 case HloOpcode::kSort:
1851 case HloOpcode::kRecv:
1852 case HloOpcode::kRecvDone:
1853 case HloOpcode::kSend:
1854 case HloOpcode::kSendDone:
1855 case HloOpcode::kHostCompute:
1856 return false;
1857 }
1858 }
1859
IsRank2Transpose() const1860 bool HloInstruction::IsRank2Transpose() const {
1861 return (opcode_ == HloOpcode::kTranspose) &&
1862 dimensions_ == std::vector<int64>({1, 0}) &&
1863 shape_.dimensions_size() == 2 &&
1864 std::equal(shape_.dimensions().begin(), shape_.dimensions().end(),
1865 operands_[0]->shape_.dimensions().rbegin());
1866 }
1867
RemoveUser(HloInstruction * user)1868 void HloInstruction::RemoveUser(HloInstruction* user) {
1869 auto set_it = user_set_.find(user);
1870 CHECK(set_it != user_set_.end());
1871 user_set_.erase(set_it);
1872 // This is linear in the number of the users, but a vector provides a stable
1873 // iteration order and much faster traversal.
1874 auto vec_it = std::find(users_.begin(), users_.end(), user);
1875 CHECK(vec_it != users_.end());
1876 users_.erase(vec_it);
1877 }
1878
ReplaceUseWith(HloInstruction * user,HloInstruction * new_producer)1879 Status HloInstruction::ReplaceUseWith(HloInstruction* user,
1880 HloInstruction* new_producer) {
1881 TF_RET_CHECK(
1882 ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
1883 << "this shape: " << ShapeUtil::HumanString(shape())
1884 << ", replacement shape: "
1885 << ShapeUtil::HumanString(new_producer->shape());
1886
1887 VLOG(3) << "Replacing uses of " << name() << " in " << user->name()
1888 << " with " << new_producer->name();
1889
1890 RemoveUser(user);
1891
1892 TF_RET_CHECK(
1893 std::count(user->operands_.begin(), user->operands_.end(), this) >= 0);
1894 std::replace(user->operands_.begin(), user->operands_.end(), this,
1895 new_producer);
1896 new_producer->AddUser(user);
1897 return Status::OK();
1898 }
1899
ReplaceOperandWith(int64 operand_num,HloInstruction * new_operand)1900 Status HloInstruction::ReplaceOperandWith(int64 operand_num,
1901 HloInstruction* new_operand) {
1902 TF_RET_CHECK(operand_num >= 0);
1903 TF_RET_CHECK(operand_num < operand_count());
1904 HloInstruction* old_operand = mutable_operand(operand_num);
1905 TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(),
1906 new_operand->shape()))
1907 << old_operand->shape().ShortDebugString() << " is not compatible with "
1908 << new_operand->shape().ShortDebugString();
1909 operands_[operand_num] = new_operand;
1910
1911 VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with "
1912 << new_operand->name() << ", was " << old_operand->name();
1913
1914 if (std::find(operands_.begin(), operands_.end(), old_operand) ==
1915 operands_.end()) {
1916 old_operand->RemoveUser(this);
1917 }
1918 new_operand->AddUser(this);
1919 return Status::OK();
1920 }
1921
ReplaceAllUsesWith(HloInstruction * new_producer)1922 Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
1923 bool new_producer_is_user = false;
1924 for (HloInstruction* user : users()) {
1925 if (user == new_producer) {
1926 // It's possible that new_producer is a user of this instruction as might
1927 // be the case when replacing an instruction with a kCopy of itself. In
1928 // this case, don't do the replacement to avoid creating a cycle in the
1929 // graph. new_producer remains the only user of this instruction.
1930 new_producer_is_user = true;
1931 } else {
1932 std::replace(user->operands_.begin(), user->operands_.end(), this,
1933 new_producer);
1934 new_producer->AddUser(user);
1935 }
1936 }
1937 users_.clear();
1938 user_set_.clear();
1939 if (new_producer_is_user) {
1940 AddUser(new_producer);
1941 }
1942 if (parent_ && parent_->root_instruction() == this) {
1943 parent_->set_root_instruction(new_producer);
1944 }
1945
1946 return Status::OK();
1947 }
1948
DetachFromOperands()1949 void HloInstruction::DetachFromOperands() {
1950 VLOG(3) << "DetachFromOperands:\n " << ToString();
1951 CHECK_EQ(0, user_count());
1952 // An instruction may be repeated as an operand. To avoid calling RemoveUser
1953 // twice on the same operand, keep a set of already detached operands.
1954 std::set<HloInstruction*> detached_operands;
1955 for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
1956 HloInstruction* operand = operands_[operand_num];
1957 if (!ContainsKey(detached_operands, operand)) {
1958 operand->RemoveUser(this);
1959 detached_operands.insert(operand);
1960 }
1961 operands_[operand_num] = nullptr;
1962 }
1963 }
1964
to_apply() const1965 HloComputation* HloInstruction::to_apply() const {
1966 switch (opcode_) {
1967 case HloOpcode::kCall:
1968 case HloOpcode::kMap:
1969 case HloOpcode::kReduceWindow:
1970 case HloOpcode::kReduce:
1971 CHECK_EQ(called_computations_.size(), 1);
1972 return called_computations_[0];
1973 default:
1974 LOG(FATAL) << "Invalid opcode for to_apply(): "
1975 << HloOpcodeString(opcode());
1976 }
1977 }
1978
set_to_apply(HloComputation * computation)1979 void HloInstruction::set_to_apply(HloComputation* computation) {
1980 // Don't allow changing the computation for fused instructions so we don't
1981 // have to recompute called_instructions for the entire fusion instruction.
1982 CHECK(!IsFused());
1983 switch (opcode_) {
1984 case HloOpcode::kCall:
1985 case HloOpcode::kMap:
1986 case HloOpcode::kReduceWindow:
1987 case HloOpcode::kReduce:
1988 CHECK_EQ(called_computations_.size(), 1);
1989 called_computations_[0] = computation;
1990 break;
1991 default:
1992 LOG(FATAL) << "Invalid opcode for to_apply(): "
1993 << HloOpcodeString(opcode());
1994 }
1995 }
1996
custom_call_target() const1997 const string& HloInstruction::custom_call_target() const {
1998 CHECK_EQ(opcode_, HloOpcode::kCustomCall);
1999 return custom_call_target_;
2000 }
2001
outfeed_config() const2002 const string& HloInstruction::outfeed_config() const {
2003 CHECK_EQ(opcode_, HloOpcode::kOutfeed);
2004 return outfeed_config_;
2005 }
2006
while_condition() const2007 HloComputation* HloInstruction::while_condition() const {
2008 CHECK_EQ(HloOpcode::kWhile, opcode_);
2009 return called_computations_[kConditionComputationIndex];
2010 }
2011
while_body() const2012 HloComputation* HloInstruction::while_body() const {
2013 CHECK_EQ(HloOpcode::kWhile, opcode_);
2014 return called_computations_[kBodyComputationIndex];
2015 }
2016
set_while_condition(HloComputation * computation)2017 void HloInstruction::set_while_condition(HloComputation* computation) {
2018 // Don't allow changing the computation for fused instructions so we don't
2019 // have to recompute called_instructions for the entire fusion instruction.
2020 CHECK(!IsFused());
2021 CHECK_EQ(HloOpcode::kWhile, opcode_);
2022 called_computations_[kConditionComputationIndex] = computation;
2023 }
2024
set_while_body(HloComputation * computation)2025 void HloInstruction::set_while_body(HloComputation* computation) {
2026 // Don't allow changing the computation for fused instructions so we don't
2027 // have to recompute called_instructions for the entire fusion instruction.
2028 CHECK(!IsFused());
2029 CHECK_EQ(HloOpcode::kWhile, opcode_);
2030 called_computations_[kBodyComputationIndex] = computation;
2031 }
2032
select() const2033 HloComputation* HloInstruction::select() const {
2034 CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
2035 return called_computations_[kSelectComputationIndex];
2036 }
2037
scatter() const2038 HloComputation* HloInstruction::scatter() const {
2039 CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
2040 return called_computations_[kScatterComputationIndex];
2041 }
2042
set_select(HloComputation * computation)2043 void HloInstruction::set_select(HloComputation* computation) {
2044 // Don't allow changing the computation for fused instructions so we don't
2045 // have to recompute called_instructions for the entire fusion instruction.
2046 CHECK(!IsFused());
2047 CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
2048 called_computations_[kSelectComputationIndex] = computation;
2049 }
2050
set_scatter(HloComputation * computation)2051 void HloInstruction::set_scatter(HloComputation* computation) {
2052 // Don't allow changing the computation for fused instructions so we don't
2053 // have to recompute called_instructions for the entire fusion instruction.
2054 CHECK(!IsFused());
2055 CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
2056 called_computations_[kScatterComputationIndex] = computation;
2057 }
2058
true_computation() const2059 HloComputation* HloInstruction::true_computation() const {
2060 CHECK_EQ(HloOpcode::kConditional, opcode_);
2061 return called_computations_[kTrueComputationIndex];
2062 }
2063
false_computation() const2064 HloComputation* HloInstruction::false_computation() const {
2065 CHECK_EQ(HloOpcode::kConditional, opcode_);
2066 return called_computations_[kFalseComputationIndex];
2067 }
2068
set_true_computation(HloComputation * true_computation)2069 void HloInstruction::set_true_computation(HloComputation* true_computation) {
2070 // Don't allow changing the computation for fused instructions so we don't
2071 // have to recompute called_instructions for the entire fusion instruction.
2072 CHECK(!IsFused());
2073 CHECK_EQ(HloOpcode::kConditional, opcode_);
2074 called_computations_[kTrueComputationIndex] = true_computation;
2075 }
2076
set_false_computation(HloComputation * false_computation)2077 void HloInstruction::set_false_computation(HloComputation* false_computation) {
2078 // Don't allow changing the computation for fused instructions so we don't
2079 // have to recompute called_instructions for the entire fusion instruction.
2080 CHECK(!IsFused());
2081 CHECK_EQ(HloOpcode::kConditional, opcode_);
2082 called_computations_[kFalseComputationIndex] = false_computation;
2083 }
2084
SignatureString() const2085 string HloInstruction::SignatureString() const {
2086 string operands =
2087 Join(operands_, ", ", [](string* out, HloInstruction* operand) {
2088 StrAppend(out, ShapeUtil::HumanString(operand->shape()));
2089 });
2090 return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape()));
2091 }
2092
2093 namespace {
2094
PrintName(const string & name,const HloPrintOptions & options)2095 string PrintName(const string& name, const HloPrintOptions& options) {
2096 return StrCat(options.print_percent() ? "%" : "", name);
2097 }
2098
2099 } // namespace
2100
ToString(const HloPrintOptions & options) const2101 string HloInstruction::ToString(const HloPrintOptions& options) const {
2102 string result =
2103 StrCat(PrintName(name(), options), " = ",
2104 ShapeUtil::HumanStringWithLayout(shape()), " ",
2105 HloOpcodeString(opcode()), "(", OperandsToString(options), ")");
2106 for (const string& extra : ExtraAttributesToString(options)) {
2107 StrAppend(&result, ", ", extra);
2108 }
2109 if (options.print_metadata() &&
2110 (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
2111 !metadata_.source_file().empty())) {
2112 StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}");
2113 }
2114 return result;
2115 }
2116
OperandsToString(const HloPrintOptions & options) const2117 string HloInstruction::OperandsToString(const HloPrintOptions& options) const {
2118 string operands;
2119 if (opcode() == HloOpcode::kConstant) {
2120 // For constants, show the actual value in place of an empty operand list.
2121 if ((!ShapeUtil::IsTuple(shape()) &&
2122 ShapeUtil::ElementsIn(shape()) <= 10) ||
2123 options.print_large_constants()) {
2124 // Literal::ToString emits multidimensional arrays over multiple
2125 // lines. Compact this into one line by stripping out white space.
2126 string tmp = literal().ToString();
2127 std::replace(tmp.begin(), tmp.end(), '\n', ' ');
2128 std::vector<string> v = tensorflow::str_util::Split(tmp, ' ');
2129 bool first = true;
2130 // Concatenate elements in "v" with spaces separating them, but ignoring
2131 // empty entries.
2132 for (const auto& s : v) {
2133 if (s.empty()) {
2134 continue;
2135 }
2136 StrAppend(&operands, (first ? "" : " "), s);
2137 first = false;
2138 }
2139 } else {
2140 // Do not show large constants or tuples.
2141 operands = "{...}";
2142 }
2143 } else if (opcode() == HloOpcode::kParameter) {
2144 StrAppend(&operands, parameter_number_);
2145 } else {
2146 tensorflow::gtl::ArraySlice<HloInstruction*> slice(operands_);
2147 const int64 kMaxOperandsToShowIfCompact = 4;
2148 if (options.compact_operands() &&
2149 slice.size() > kMaxOperandsToShowIfCompact) {
2150 slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
2151 }
2152 operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) {
2153 std::vector<string> str;
2154 if (options.print_operand_shape()) {
2155 str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape()));
2156 }
2157 if (!options.compact_operands()) {
2158 str.push_back(PrintName(operand->name(), options));
2159 }
2160 StrAppend(out, Join(str, " "));
2161 });
2162 const int64 remaining = operands_.size() - slice.size();
2163 if (slice.size() != operands_.size()) {
2164 StrAppend(&operands, ", ...(+", remaining, ")");
2165 }
2166 }
2167 return operands;
2168 }
2169
// Returns the "key=value" attribute strings printed after an instruction's
// operand list (for example "dimensions={0,1}", "slice={[0:4]}" or
// "to_apply=%add"). Attributes are appended in the fixed order of the checks
// below, so the printed form is deterministic for a given instruction.
//
// References to called computations (condition=, body=, select=, scatter=,
// true_computation=, false_computation=, to_apply=, calls=) are emitted only
// when options.print_subcomputation_references() is true. custom_call_target
// is emitted for kCustomCall regardless of that option.
std::vector<string> HloInstruction::ExtraAttributesToString(
    const HloPrintOptions& options) const {
  std::vector<string> extra;
  if (opcode() == HloOpcode::kFusion) {
    extra.push_back(StrCat("kind=", xla::ToString(fusion_kind())));
  }
  if (CanHaveDimensionsField()) {
    extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}"));
  }
  // A window with zero dimensions is not printed at all.
  if (window_ != nullptr && window_->dimensions_size() != 0) {
    extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
  }
  if (padding_config_ != nullptr) {
    extra.push_back(
        StrCat("padding=", xla::PaddingConfigToString(*padding_config_)));
  }
  if (opcode() == HloOpcode::kSlice) {
    // One "[start:limit]" entry per dimension. If any stride differs from 1,
    // every entry instead gets a ":stride" suffix ("[start:limit:stride]").
    std::vector<string> bounds;
    bounds.reserve(slice_starts_.size());
    const bool omit_stride =
        std::all_of(slice_strides_.begin(), slice_strides_.end(),
                    [](int64 stride) { return stride == 1; });
    for (int i = 0; i < slice_starts_.size(); ++i) {
      string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]);
      bounds.push_back(StrCat("[", slice_starts_[i], ":", slice_limits_[i],
                              stride_str, "]"));
    }
    extra.push_back(StrCat("slice={", Join(bounds, ", "), "}"));
  }
  if (opcode() == HloOpcode::kDynamicSlice) {
    extra.push_back(
        StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}"));
  }
  if (opcode() == HloOpcode::kBatchNormTraining ||
      opcode() == HloOpcode::kBatchNormInference ||
      opcode() == HloOpcode::kBatchNormGrad) {
    extra.push_back(StrCat("epsilon=", epsilon()));
    extra.push_back(StrCat("feature_index=", feature_index()));
  }

  // Dimension-number attributes are printed whenever the corresponding
  // member is set, independent of the opcode.
  if (convolution_dimension_numbers_ != nullptr) {
    extra.push_back(ConvolutionDimensionNumbersToString());
  }
  if (dot_dimension_numbers_ != nullptr) {
    extra.push_back(DotDimensionNumbersToString());
  }
  if (gather_dimension_numbers_ != nullptr) {
    extra.push_back(GatherDimensionNumbersToString());
    extra.push_back(
        StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}"));
  }
  if (opcode() == HloOpcode::kFft) {
    extra.push_back(StrCat("fft_type=", FftType_Name(fft_type())));
    extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}"));
  }

  // Called computations. Opcodes with named computation slots use their
  // specific keys; any other instruction that calls computations falls back
  // to a generic "calls=" list.
  if (options.print_subcomputation_references()) {
    if (opcode() == HloOpcode::kWhile) {
      extra.push_back(
          StrCat("condition=", PrintName(while_condition()->name(), options)));
      extra.push_back(
          StrCat("body=", PrintName(while_body()->name(), options)));
    } else if (opcode() == HloOpcode::kSelectAndScatter) {
      extra.push_back(StrCat("select=", PrintName(select()->name(), options)));
      extra.push_back(
          StrCat("scatter=", PrintName(scatter()->name(), options)));
    } else if (opcode() == HloOpcode::kConditional) {
      extra.push_back(StrCat("true_computation=",
                             PrintName(true_computation()->name(), options)));
      extra.push_back(StrCat("false_computation=",
                             PrintName(false_computation()->name(), options)));
    } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap ||
               opcode() == HloOpcode::kReduceWindow ||
               opcode() == HloOpcode::kReduce) {
      extra.push_back(
          StrCat("to_apply=", PrintName(to_apply()->name(), options)));
    } else if (!called_computations().empty()) {
      extra.push_back(StrCat(
          "calls=", Join(called_computations(), ", ",
                         [&](string* out, const HloComputation* computation) {
                           StrAppend(out,
                                     PrintName(computation->name(), options));
                         })));
    }
  }

  if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv ||
      opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) {
    extra.push_back(StrCat("channel_id=", channel_id_));
  }

  if (opcode() == HloOpcode::kGetTupleElement) {
    extra.push_back(StrCat("index=", tuple_index()));
  }
  if (has_sharding()) {
    extra.push_back(StrCat("sharding=", sharding().ToString()));
  }
  if (!control_predecessors_.empty()) {
    extra.push_back(StrCat("control-predecessors={",
                           Join(control_predecessors_, ", ",
                                [&](string* out, HloInstruction* pre) {
                                  StrAppend(out,
                                            PrintName(pre->name(), options));
                                }),
                           "}"));
  }
  // Config strings are C-escaped and quoted; empty configs are omitted.
  if (opcode() == HloOpcode::kInfeed && !infeed_config_.empty()) {
    extra.push_back(StrCat("infeed_config=\"", CEscape(infeed_config_), "\""));
  }
  if (opcode() == HloOpcode::kOutfeed && !outfeed_config_.empty()) {
    extra.push_back(
        StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\""));
  }
  if (opcode() == HloOpcode::kRng) {
    extra.push_back(
        StrCat("distribution=", RandomDistributionToString(distribution_)));
  }
  if (opcode() == HloOpcode::kReducePrecision) {
    extra.push_back(StrCat("exponent_bits=", exponent_bits_));
    extra.push_back(StrCat("mantissa_bits=", mantissa_bits_));
  }

  // By contract, we print the custom call target even if
  // !options.print_subcomputation_references(), because the call target is not
  // an HloComputation.
  if (opcode() == HloOpcode::kCustomCall) {
    extra.push_back(
        StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\""));
  }
  return extra;
}
2301
ToShortString() const2302 string HloInstruction::ToShortString() const {
2303 return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(",
2304 Join(operands_, ", ",
2305 [](string* out, HloInstruction* operand) {
2306 StrAppend(out, "%", operand->name());
2307 }),
2308 ")");
2309 }
2310
// Serializes this instruction into an HloInstructionProto.
//
// Operands, control predecessors and called computations are recorded by name
// only. For kFusion, the fused computation is instead embedded in full via
// HloComputation::ToProto(), and called_computation_names is left empty. Many
// scalar fields (tuple index, exponent/mantissa bits, epsilon, channel id,
// ...) are written unconditionally, whether or not they are meaningful for
// this opcode.
HloInstructionProto HloInstruction::ToProto() const {
  HloInstructionProto proto;
  proto.set_name(name_);
  proto.set_opcode(HloOpcodeString(opcode_));
  *proto.mutable_shape() = shape_;
  for (const HloInstruction* operand : operands_) {
    *proto.add_operand_names() = operand->name();
  }
  for (const HloInstruction* control : control_predecessors_) {
    *proto.add_control_predecessor_names() = control->name();
  }

  *proto.mutable_metadata() = metadata_;
  if (literal_ != nullptr) {
    *proto.mutable_literal() = literal_->ToProto();
  }
  proto.set_parameter_number(parameter_number_);
  if (opcode() == HloOpcode::kFusion) {
    proto.set_fusion_kind(xla::ToString(fusion_kind()));
    *proto.mutable_fused_instructions_computation() =
        fused_instructions_computation()->ToProto();
  } else {
    for (const HloComputation* computation : called_computations_) {
      *proto.add_called_computation_names() = computation->name();
    }
  }

  proto.set_tuple_index(tuple_index_);
  for (int64 dimension : dimensions_) {
    proto.add_dimensions(dimension);
  }
  if (window_ != nullptr) {
    *proto.mutable_window() = *window_;
  }
  if (convolution_dimension_numbers_ != nullptr) {
    *proto.mutable_convolution_dimension_numbers() =
        *convolution_dimension_numbers_;
  }
  if (dot_dimension_numbers_ != nullptr) {
    *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_;
  }
  if (gather_dimension_numbers_ != nullptr) {
    *proto.mutable_gather_dimension_numbers() = *gather_dimension_numbers_;
  }
  if (opcode() == HloOpcode::kGather) {
    for (int64 bound : gather_window_bounds()) {
      proto.add_gather_window_bounds(bound);
    }
  }
  // slice_starts_, slice_limits_ and slice_strides_ are read in lockstep;
  // this assumes all three have the same length.
  for (int i = 0; i < slice_starts_.size(); ++i) {
    auto* slice_dimension = proto.add_slice_dimensions();
    slice_dimension->set_start(slice_starts_[i]);
    slice_dimension->set_limit(slice_limits_[i]);
    slice_dimension->set_stride(slice_strides_[i]);
  }
  proto.set_exponent_bits(exponent_bits_);
  proto.set_mantissa_bits(mantissa_bits_);
  for (int64 slice_size : dynamic_slice_sizes_) {
    proto.add_dynamic_slice_sizes(slice_size);
  }
  if (padding_config_ != nullptr) {
    *proto.mutable_padding_config() = *padding_config_;
  }
  proto.set_outfeed_config(outfeed_config_);
  // The RNG distribution is only written for kRng.
  if (opcode() == HloOpcode::kRng) {
    proto.set_distribution(distribution_);
  }
  proto.set_epsilon(epsilon_);
  proto.set_feature_index(feature_index_);
  proto.set_channel_id(channel_id_);
  proto.set_infeed_config(infeed_config_);
  proto.set_custom_call_target(custom_call_target_);
  *proto.mutable_outfeed_shape() = outfeed_shape_;
  proto.set_fft_type(fft_type_);
  for (int64 fft_len : fft_length_) {
    proto.add_fft_length(fft_len);
  }

  return proto;
}
2391
ToCategory() const2392 string HloInstruction::ToCategory() const {
2393 if (opcode() == HloOpcode::kTranspose || opcode() == HloOpcode::kCopy ||
2394 opcode() == HloOpcode::kReshape) {
2395 return "data formatting";
2396 }
2397
2398 if (opcode() == HloOpcode::kConvolution) {
2399 string category = "convolution";
2400 if (window_util::HasBaseDilation(window())) {
2401 category += " base-dilated";
2402 }
2403 if (window_util::HasWindowDilation(window())) {
2404 category += " window-dilated";
2405 }
2406 return category;
2407 }
2408
2409 // Give transpose-dot and backwards-conv fusions the categories "dot" and
2410 // "convolution" so they match the categories of proper kDot and kConvolution
2411 // ops. These fusion categories are really just a way of expressing a
2412 // particular kind of dot or conv, so they should have the same category as a
2413 // vanilla dot/conv.
2414 if (opcode() == HloOpcode::kFusion) {
2415 switch (fusion_kind()) {
2416 case FusionKind::kLoop:
2417 return "loop fusion";
2418 case FusionKind::kInput:
2419 return "input fusion";
2420 case FusionKind::kOutput:
2421 return "output fusion";
2422 case FusionKind::kTransposeDot:
2423 return "dot";
2424 case FusionKind::kCustom:
2425 return "custom fusion";
2426 }
2427 }
2428
2429 if (IsElementwise()) {
2430 return "non-fusion elementwise";
2431 }
2432
2433 return HloOpcodeString(opcode());
2434 }
2435
tracing() const2436 HloInstruction* HloInstruction::tracing() const { return trace_instruction_; }
2437
set_tracing(HloInstruction * trace_instruction)2438 void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
2439 trace_instruction_ = trace_instruction;
2440 }
2441
TracingTag() const2442 string HloInstruction::TracingTag() const {
2443 CHECK_EQ(HloOpcode::kTrace, opcode());
2444 CHECK(literal_ != nullptr);
2445 return literal_->GetR1U8AsString();
2446 }
2447
IsFused() const2448 bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }
2449
IsFusable() const2450 bool HloInstruction::IsFusable() const {
2451 // Instructions which are traced should not be fused.
2452 if (tracing()) {
2453 return false;
2454 }
2455 // Some kinds of instructions don't make sense to fuse.
2456 switch (opcode_) {
2457 case HloOpcode::kParameter:
2458 return false;
2459 // Side effecting instrutions cannot be fused.
2460 default:
2461 return !HasSideEffect();
2462 }
2463 }
2464
fused_instructions_computation() const2465 HloComputation* HloInstruction::fused_instructions_computation() const {
2466 CHECK_EQ(opcode_, HloOpcode::kFusion);
2467 CHECK(!called_computations_.empty());
2468 auto* fused_instructions_computation = called_computations_.front();
2469 CHECK(fused_instructions_computation->IsFusionComputation());
2470 return fused_instructions_computation;
2471 }
2472
fused_expression_root() const2473 HloInstruction* HloInstruction::fused_expression_root() const {
2474 CHECK_EQ(opcode_, HloOpcode::kFusion);
2475 return fused_instructions_computation()->root_instruction();
2476 }
2477
fused_parameter(int64 parameter_number) const2478 HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const {
2479 CHECK_EQ(opcode_, HloOpcode::kFusion);
2480 return fused_instructions_computation()->parameter_instruction(
2481 parameter_number);
2482 }
2483
fused_parameters() const2484 const std::vector<HloInstruction*>& HloInstruction::fused_parameters() const {
2485 CHECK_EQ(opcode_, HloOpcode::kFusion);
2486 return fused_instructions_computation()->parameter_instructions();
2487 }
2488
// Returns a read-only range over every instruction in this fusion's fused
// computation. CHECK-fails if this is not a kFusion instruction.
const tensorflow::gtl::iterator_range<UnwrappingIterator<
    std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
HloInstruction::fused_instructions() const {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  // Calling instructions() through a const pointer presumably selects the
  // const overload, matching the const_iterator range returned here.
  const HloComputation* subcomp = fused_instructions_computation();
  return subcomp->instructions();
}

// Returns a mutable range over every instruction in this fusion's fused
// computation. CHECK-fails if this is not a kFusion instruction.
const tensorflow::gtl::iterator_range<
    UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
HloInstruction::fused_instructions() {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  return fused_instructions_computation()->instructions();
}

// Returns the number of instructions in the fused computation. The kFusion
// opcode check happens inside fused_instructions_computation().
int64 HloInstruction::fused_instruction_count() const {
  return fused_instructions_computation()->instruction_count();
}
2507
// Constructs an instruction with the given opcode and result shape. The name
// defaults to the opcode string. unique_id_ starts at -1; PushDFSChild and
// PostOrderDFS treat a negative id as "may not have a parent computation".
// Debug builds validate the shape.
HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
    : unique_id_(-1),
      opcode_(opcode),
      shape_(shape),
      name_(HloOpcodeString(opcode)) {
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
}
2515
// Dispatches to the Handle* method of `visitor` that matches this
// instruction's opcode and returns its status.
//
// Some opcodes share a handler: every comparison opcode (kEq, kGe, kGt, kLe,
// kLt, kNe) goes to HandleCompare, and kRoundNearestAfz goes to HandleRound.
// kTrace has no handler and yields an Unimplemented status.
template <typename HloInstructionPtr>
Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
  switch (opcode_) {
    case HloOpcode::kAbs:
      return visitor->HandleAbs(this);
    case HloOpcode::kAtan2:
      return visitor->HandleAtan2(this);
    case HloOpcode::kRoundNearestAfz:
      return visitor->HandleRound(this);
    case HloOpcode::kBatchNormTraining:
      return visitor->HandleBatchNormTraining(this);
    case HloOpcode::kBatchNormInference:
      return visitor->HandleBatchNormInference(this);
    case HloOpcode::kBatchNormGrad:
      return visitor->HandleBatchNormGrad(this);
    case HloOpcode::kSign:
      return visitor->HandleSign(this);
    case HloOpcode::kConstant:
      return visitor->HandleConstant(this);
    case HloOpcode::kGetTupleElement:
      return visitor->HandleGetTupleElement(this);
    case HloOpcode::kParameter:
      return visitor->HandleParameter(this);
    // All comparisons share a single handler.
    case HloOpcode::kEq:
    case HloOpcode::kGe:
    case HloOpcode::kGt:
    case HloOpcode::kLe:
    case HloOpcode::kLt:
    case HloOpcode::kNe:
      return visitor->HandleCompare(this);
    case HloOpcode::kComplex:
      return visitor->HandleComplex(this);
    case HloOpcode::kAdd:
      return visitor->HandleAdd(this);
    case HloOpcode::kDivide:
      return visitor->HandleDivide(this);
    case HloOpcode::kSubtract:
      return visitor->HandleSubtract(this);
    case HloOpcode::kMaximum:
      return visitor->HandleMaximum(this);
    case HloOpcode::kMinimum:
      return visitor->HandleMinimum(this);
    case HloOpcode::kAnd:
      return visitor->HandleAnd(this);
    case HloOpcode::kOr:
      return visitor->HandleOr(this);
    case HloOpcode::kShiftLeft:
      return visitor->HandleShiftLeft(this);
    case HloOpcode::kShiftRightArithmetic:
      return visitor->HandleShiftRightArithmetic(this);
    case HloOpcode::kShiftRightLogical:
      return visitor->HandleShiftRightLogical(this);
    case HloOpcode::kConcatenate:
      return visitor->HandleConcatenate(this);
    case HloOpcode::kConvert:
      return visitor->HandleConvert(this);
    case HloOpcode::kBitcastConvert:
      return visitor->HandleBitcastConvert(this);
    case HloOpcode::kCopy:
      return visitor->HandleCopy(this);
    case HloOpcode::kMultiply:
      return visitor->HandleMultiply(this);
    case HloOpcode::kDot:
      return visitor->HandleDot(this);
    case HloOpcode::kPower:
      return visitor->HandlePower(this);
    case HloOpcode::kRemainder:
      return visitor->HandleRemainder(this);
    case HloOpcode::kSelect:
      return visitor->HandleSelect(this);
    case HloOpcode::kConvolution:
      return visitor->HandleConvolution(this);
    case HloOpcode::kFft:
      return visitor->HandleFft(this);
    case HloOpcode::kCrossReplicaSum:
      return visitor->HandleCrossReplicaSum(this);
    case HloOpcode::kTuple:
      return visitor->HandleTuple(this);
    case HloOpcode::kMap:
      return visitor->HandleMap(this);
    case HloOpcode::kClamp:
      return visitor->HandleClamp(this);
    case HloOpcode::kReduce:
      return visitor->HandleReduce(this);
    case HloOpcode::kReduceWindow:
      return visitor->HandleReduceWindow(this);
    case HloOpcode::kSelectAndScatter:
      return visitor->HandleSelectAndScatter(this);
    case HloOpcode::kNegate:
      return visitor->HandleNegate(this);
    case HloOpcode::kExp:
      return visitor->HandleExp(this);
    case HloOpcode::kFloor:
      return visitor->HandleFloor(this);
    case HloOpcode::kCeil:
      return visitor->HandleCeil(this);
    case HloOpcode::kLog:
      return visitor->HandleLog(this);
    case HloOpcode::kTanh:
      return visitor->HandleTanh(this);
    case HloOpcode::kCos:
      return visitor->HandleCos(this);
    case HloOpcode::kSin:
      return visitor->HandleSin(this);
    case HloOpcode::kReal:
      return visitor->HandleReal(this);
    case HloOpcode::kImag:
      return visitor->HandleImag(this);
    case HloOpcode::kIsFinite:
      return visitor->HandleIsFinite(this);
    case HloOpcode::kNot:
      return visitor->HandleNot(this);
    case HloOpcode::kBitcast:
      return visitor->HandleBitcast(this);
    case HloOpcode::kBroadcast:
      return visitor->HandleBroadcast(this);
    case HloOpcode::kPad:
      return visitor->HandlePad(this);
    case HloOpcode::kReshape:
      return visitor->HandleReshape(this);
    case HloOpcode::kTranspose:
      return visitor->HandleTranspose(this);
    case HloOpcode::kReverse:
      return visitor->HandleReverse(this);
    case HloOpcode::kReducePrecision:
      return visitor->HandleReducePrecision(this);
    case HloOpcode::kSlice:
      return visitor->HandleSlice(this);
    case HloOpcode::kDynamicSlice:
      return visitor->HandleDynamicSlice(this);
    case HloOpcode::kDynamicUpdateSlice:
      return visitor->HandleDynamicUpdateSlice(this);
    case HloOpcode::kSort:
      return visitor->HandleSort(this);
    case HloOpcode::kInfeed:
      return visitor->HandleInfeed(this);
    case HloOpcode::kOutfeed:
      return visitor->HandleOutfeed(this);
    case HloOpcode::kHostCompute:
      return visitor->HandleHostCompute(this);
    case HloOpcode::kRng:
      return visitor->HandleRng(this);
    case HloOpcode::kWhile:
      return visitor->HandleWhile(this);
    case HloOpcode::kFusion:
      return visitor->HandleFusion(this);
    case HloOpcode::kCall:
      return visitor->HandleCall(this);
    case HloOpcode::kConditional:
      return visitor->HandleConditional(this);
    case HloOpcode::kCustomCall:
      return visitor->HandleCustomCall(this);
    case HloOpcode::kRecv:
      return visitor->HandleRecv(this);
    case HloOpcode::kRecvDone:
      return visitor->HandleRecvDone(this);
    case HloOpcode::kSend:
      return visitor->HandleSend(this);
    case HloOpcode::kSendDone:
      return visitor->HandleSendDone(this);
    case HloOpcode::kGather:
      return visitor->HandleGather(this);

    // These opcodes are not handled here.
    case HloOpcode::kTrace:
      break;
  }
  return Unimplemented("unhandled HloOpcode for DfsHloVisitor: %s",
                       HloOpcodeString(opcode_).c_str());
}

// Explicit instantiations for the mutable and const visitor flavors.
template Status HloInstruction::Visit(DfsHloVisitor* visitor);
template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
2690
2691 using DFSStack =
2692 tensorflow::gtl::InlinedVector<std::pair<int, HloInstruction*>, 16>;
2693
2694 // Push "child" onto the dfs_stack if not already visited. Returns false if a
2695 // cycle was detected, and true otherwise.
2696 template <typename Visitor>
PushDFSChild(Visitor * visitor,DFSStack * dfs_stack,HloInstruction * child)2697 inline bool PushDFSChild(Visitor* visitor, DFSStack* dfs_stack,
2698 HloInstruction* child) {
2699 CHECK(child != nullptr);
2700 const int id = child->unique_id();
2701 CHECK_GE(id, 0) << "instruction may not have a parent computation";
2702 switch (visitor->GetVisitState(id)) {
2703 case Visitor::kVisiting:
2704 return false;
2705
2706 case Visitor::kVisited:
2707 // Nothing to do
2708 return true;
2709
2710 case Visitor::kNotVisited:
2711 dfs_stack->push_back(std::make_pair(id, child));
2712 return true;
2713 }
2714 }
2715
// Orders two DFS-stack entries (<unique_id, instruction> pairs). PostOrderDFS
// uses it to sort the children pushed for a single node.
using InternalCompareFunction =
    std::function<bool(std::pair<int, const HloInstruction*>,
                       std::pair<int, const HloInstruction*>)>;

// Iterative post-order depth-first traversal rooted at `root`.
//
// Each reachable instruction is visited exactly once, with
// Preprocess -> Visit -> Postprocess on `visitor`, and only after all of its
// operands (and, unless ignore_control_predecessors is set, its control
// predecessors) have been visited. Visit states are tracked in `visitor` by
// unique id. When operand_order is non-null, the children of each node are
// visited in ascending comparator order; otherwise they are visited in
// operand order followed by control-predecessor order.
//
// Returns FailedPrecondition if a cycle is detected, or the first error
// returned by the visitor.
template <typename Visitor>
static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
                           const InternalCompareFunction* operand_order,
                           bool ignore_control_predecessors) {
  visitor->ReserveVisitStates(root->GetModule()->NumUniqueInstructionIds());

  // dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>.
  //
  // We need to keep track of both the id and the instruction because
  // instructions can get deleted while they are on the stack, so we
  // can't always use the (potentially dead) instruction object to grab
  // its id.
  DFSStack dfs_stack;
  dfs_stack.emplace_back(root->unique_id(), root);

  do {
    DCHECK(!dfs_stack.empty());

    int current_id = dfs_stack.back().first;
    HloInstruction* current_node = dfs_stack.back().second;
    CHECK_GE(current_id, 0) << current_id << ": " << current_node
                            << ": instruction may not have parent computation";
    typename Visitor::VisitState visit_state =
        visitor->GetVisitState(current_id);
    // A node can be pushed more than once (by different users) before it is
    // visited; the stale duplicates are simply dropped.
    if (visit_state == Visitor::kVisited) {
      dfs_stack.pop_back();
      VLOG(3) << "Not visiting HLO %" << current_node->name()
              << " as it was already visited.";
      continue;
    }

    // Second time at the top of the stack: all children have been handled,
    // so this node can now be visited (post order).
    if (visit_state == Visitor::kVisiting) {
      dfs_stack.pop_back();

      TF_RETURN_IF_ERROR(visitor->Preprocess(current_node));
      VLOG(2) << "Visiting HLO %" << current_node->name();
      TF_RETURN_IF_ERROR(current_node->Visit(visitor));
      visitor->SetVisitState(current_id, Visitor::kVisited);
      TF_RETURN_IF_ERROR(visitor->Postprocess(current_node));
      continue;
    }

    // First time at the top of the stack: mark as in-progress and leave the
    // node on the stack underneath its children.
    visitor->SetVisitState(current_id, Visitor::kVisiting);

    const size_t old_dfs_stack_size = dfs_stack.size();
    for (HloInstruction* child : current_node->operands()) {
      if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
        return FailedPrecondition(
            "A cycle is detected while visiting instruction %s",
            current_node->ToString().c_str());
      }
    }

    if (!ignore_control_predecessors) {
      for (HloInstruction* child : current_node->control_predecessors()) {
        if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
          return FailedPrecondition(
              "A cycle is detected while visiting instruction %s",
              current_node->ToString().c_str());
        }
      }
    }

    // Sort only the children just pushed for this node.
    if (operand_order != nullptr) {
      std::sort(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end(),
                *operand_order);
    }

    // This makes the traversal order the same as what you'd expect
    // out of a recursive algorithm.
    std::reverse(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end());
  } while (!dfs_stack.empty());

  return Status::OK();
}
2794
2795 template <typename HloInstructionPtr>
Accept(DfsHloVisitorBase<HloInstructionPtr> * visitor,bool call_finish_visit,bool ignore_control_predecessors)2796 Status HloInstruction::Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
2797 bool call_finish_visit,
2798 bool ignore_control_predecessors) {
2799 VLOG(3) << "HloInstruction::Accept(%" << name() << ")";
2800 TF_RETURN_IF_ERROR(
2801 PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors));
2802 if (call_finish_visit) {
2803 TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
2804 }
2805 return Status::OK();
2806 }
2807
2808 // Explicit instantiations.
2809 template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
2810 template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
2811
AcceptWithOperandOrder(DfsHloVisitor * visitor,const CompareFunction & operand_order,bool call_finish_visit)2812 Status HloInstruction::AcceptWithOperandOrder(
2813 DfsHloVisitor* visitor, const CompareFunction& operand_order,
2814 bool call_finish_visit) {
2815 VLOG(2) << "HloInstruction::AcceptWithOperandOrder(%" << name() << ")";
2816 InternalCompareFunction func = [&operand_order](
2817 std::pair<int, const HloInstruction*> a,
2818 std::pair<int, const HloInstruction*> b) {
2819 // Call the client's comparison function on the actual HloInstruction*
2820 // objects (ignoring the internal ids we also have in our stack entries)
2821 return operand_order(a.second, b.second);
2822 };
2823 TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &func,
2824 /*ignore_control_predecessors=*/false));
2825 if (call_finish_visit) {
2826 VLOG(3) << "HloInstruction::AcceptWithOperandOrder BEFORE FINISH VISIT";
2827 TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
2828 VLOG(3) << "HloInstruction::AcceptWithOperandOrder AFTER FINISH VISIT";
2829 }
2830 VLOG(2) << "HloInstruction::AcceptWithOperandOrder EXIT";
2831 return Status::OK();
2832 }
2833
2834 namespace {
2835
2836 // Returns true if the given order is a topological sort of the instructions
2837 // it contains.
OrderIsTopologicalSort(const std::vector<const HloInstruction * > & order)2838 bool OrderIsTopologicalSort(const std::vector<const HloInstruction*>& order) {
2839 // Create a map from instruction to its position in 'order'.
2840 std::unordered_map<const HloInstruction*, int> order_position;
2841 for (int i = 0; i < order.size(); i++) {
2842 if (!order_position.insert({order[i], i}).second) {
2843 // Instruction order[i] is duplicated in the order.
2844 return false;
2845 }
2846 }
2847 // Verify that the operand of each instruction in the order is also in the
2848 // order *and* the operand's position is earlier (defs are before uses for
2849 // all ops).
2850 for (auto* instruction : order) {
2851 for (auto* operand : instruction->operands()) {
2852 if (!ContainsKey(order_position, operand) ||
2853 order_position.at(operand) >= order_position.at(instruction)) {
2854 return false;
2855 }
2856 }
2857 }
2858
2859 return true;
2860 }
2861
2862 } // namespace
2863
Accept(const std::function<Status (HloInstruction *)> & visitor_func)2864 Status HloInstruction::Accept(
2865 const std::function<Status(HloInstruction*)>& visitor_func) {
2866 FunctionVisitor visitor(visitor_func);
2867 return this->Accept(&visitor);
2868 }
2869
Accept(const std::function<Status (const HloInstruction *)> & visitor_func) const2870 Status HloInstruction::Accept(
2871 const std::function<Status(const HloInstruction*)>& visitor_func) const {
2872 ConstFunctionVisitor visitor(visitor_func);
2873 return this->Accept(&visitor);
2874 }
2875
AcceptOrdered(DfsHloVisitor * visitor,const std::vector<const HloInstruction * > & order)2876 Status HloInstruction::AcceptOrdered(
2877 DfsHloVisitor* visitor, const std::vector<const HloInstruction*>& order) {
2878 VLOG(2) << "HloInstruction::AcceptOrdered(%" << name() << ")";
2879 TF_RET_CHECK(OrderIsTopologicalSort(order));
2880
2881 // Compute the predecessors of this instruction.
2882 std::unordered_set<const HloInstruction*> predecessors;
2883 TF_RETURN_IF_ERROR(this->Accept([&predecessors](HloInstruction* instruction) {
2884 predecessors.insert(instruction);
2885 return Status::OK();
2886 }));
2887
2888 for (auto* const_instruction : order) {
2889 if (!ContainsKey(predecessors, const_instruction)) {
2890 // Instruction is not a predecessors of 'this'.
2891 continue;
2892 }
2893
2894 // The visitor can mark instructions as visited to skip particular
2895 // instructions.
2896 if (visitor->DidVisit(*const_instruction)) {
2897 VLOG(3) << "Not visiting HLO %" << const_instruction->name()
2898 << " as it was already visited.";
2899 continue;
2900 }
2901
2902 HloInstruction* instruction =
2903 const_cast<HloInstruction*>(const_instruction);
2904
2905 TF_RETURN_IF_ERROR(visitor->Preprocess(instruction));
2906 VLOG(2) << "Visiting HLO %" << instruction->name();
2907 TF_RETURN_IF_ERROR(instruction->Visit(visitor));
2908 visitor->SetVisited(*instruction);
2909 TF_RETURN_IF_ERROR(visitor->Postprocess(instruction));
2910 }
2911
2912 return visitor->FinishVisit(this);
2913 }
2914
outfeed_shape() const2915 const Shape& HloInstruction::outfeed_shape() const {
2916 DCHECK_EQ(opcode_, HloOpcode::kOutfeed);
2917 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
2918 return outfeed_shape_;
2919 }
2920
shape() const2921 const Shape& HloInstruction::shape() const {
2922 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
2923 return shape_;
2924 }
2925
OperandIndices(const HloInstruction * operand) const2926 std::vector<int64> HloInstruction::OperandIndices(
2927 const HloInstruction* operand) const {
2928 std::vector<int64> result;
2929 for (int64 i = 0; i < operand_count(); ++i) {
2930 if (this->operand(i) == operand) {
2931 result.push_back(i);
2932 }
2933 }
2934 return result;
2935 }
2936
IsElementwiseBinary() const2937 bool HloInstruction::IsElementwiseBinary() const {
2938 return IsElementwise() && operand_count() == 2;
2939 }
2940
// Returns whether this instruction is elementwise, based on an opcode table.
// Fixed-arity unary and binary opcodes also CHECK their operand count.
// kSelect counts as elementwise only when its result is not a tuple. A kLoop
// fusion is elementwise when every non-parameter fused instruction is.
// Opcodes not listed here are not elementwise.
bool HloInstruction::IsElementwise() const {
  switch (opcode_) {
    // Nullary elementwise operations.
    case HloOpcode::kConstant:
      return true;

    // Unary elementwise operations.
    case HloOpcode::kAbs:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kCeil:
    case HloOpcode::kConvert:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCopy:
    case HloOpcode::kCos:
    case HloOpcode::kExp:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kNot:
    case HloOpcode::kNegate:
    case HloOpcode::kReal:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kTanh:
      CHECK_EQ(1, operand_count());
      return true;

    // Binary elementwise operations. IsElementwiseBinary() is derived from
    // this switch (IsElementwise() && operand_count() == 2).
    case HloOpcode::kAdd:
    case HloOpcode::kAtan2:
    case HloOpcode::kComplex:
    case HloOpcode::kDivide:
    case HloOpcode::kEq:
    case HloOpcode::kGe:
    case HloOpcode::kGt:
    case HloOpcode::kLe:
    case HloOpcode::kLt:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNe:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kSubtract:
    case HloOpcode::kAnd:
    case HloOpcode::kOr:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
      CHECK_EQ(2, operand_count());
      return true;

    // Ternary elementwise operations. A tuple-shaped select is not
    // elementwise.
    case HloOpcode::kSelect:
      return !ShapeUtil::IsTuple(shape_);
    case HloOpcode::kClamp:
      return true;

    // Other operations.
    case HloOpcode::kRng:
    case HloOpcode::kMap:
      return true;
    // Only loop fusions qualify, and only if every fused instruction other
    // than the parameters is itself elementwise.
    case HloOpcode::kFusion:
      if (fusion_kind() != FusionKind::kLoop) {
        return false;
      }
      for (auto* fused : fused_instructions()) {
        if (fused->opcode() != HloOpcode::kParameter &&
            !fused->IsElementwise()) {
          return false;
        }
      }
      return true;

    default:
      return false;
  }
}
3021
ImplicitlyBroadcastsOperand(int64 operand_idx) const3022 bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const {
3023 CHECK(IsElementwise());
3024 return !ShapeUtil::Equal(shape(), operand(operand_idx)->shape());
3025 }
3026
3027 namespace {
IsInstructionElementwiseOnOperand(const HloInstruction * instruction,const HloInstruction * operand)3028 bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
3029 const HloInstruction* operand) {
3030 std::vector<int64> operand_indices = instruction->OperandIndices(operand);
3031 return std::all_of(
3032 operand_indices.begin(), operand_indices.end(),
3033 [instruction](int64 operand_index) {
3034 return instruction->IsElementwiseOnOperand(operand_index);
3035 });
3036 }
3037 } // namespace
3038
IsElementwiseOnOperand(int64 operand_idx) const3039 bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const {
3040 // For all instructions other than kFusion, being elementwise on one of the
3041 // operands is equivalent to being elementwise on all the operands.
3042 if (opcode() != HloOpcode::kFusion) {
3043 return IsElementwise();
3044 }
3045
3046 CHECK_EQ(HloOpcode::kFusion, opcode());
3047 if (fusion_kind() != FusionKind::kLoop) {
3048 return false;
3049 }
3050
3051 // A loop-fusion is elementwise on an operand if all operations (computed
3052 // using BFS) between the operand and the fused root are elementwise.
3053 std::deque<HloInstruction*> worklist;
3054 std::unordered_set<const HloInstruction*> visited;
3055 worklist.push_back(fused_parameter(operand_idx));
3056 visited.insert(fused_parameter(operand_idx));
3057 while (!worklist.empty()) {
3058 HloInstruction* operand = worklist.front();
3059 worklist.pop_front();
3060 for (HloInstruction* user : operand->users()) {
3061 CHECK_GE(user->unique_id(), 0);
3062 if (ContainsKey(visited, user)) {
3063 continue;
3064 }
3065 if (user->IsElementwise() ||
3066 IsInstructionElementwiseOnOperand(user, operand)) {
3067 worklist.push_back(user);
3068 visited.insert(user);
3069 } else {
3070 return false;
3071 }
3072 }
3073 }
3074 return true;
3075 }
3076
3077 // A helper class for memoized, recursive computation of HloOpcode::kFusion
3078 // in HloInstruction::OperandElementUse below.
3079 class HloInstruction::FusionReusesParamElements {
3080 public:
3081 using UseKind = HloInstruction::UseKind;
3082
3083 // We could rather iterate backwards through fused_instructions_ here, as it
3084 // is in reverse postorder, and compute whether each fused instruction reuses
3085 // the value of this parameter, which would save stack space but not allow us
3086 // to finish early if we find a reuse.
Compute(int64 i,const HloInstruction & hlo)3087 static UseKind Compute(int64 i, const HloInstruction& hlo) {
3088 tensorflow::gtl::FlatMap<const HloInstruction*, UseKind> memoization_cache;
3089 return ComputeInternal(i, hlo, &memoization_cache);
3090 }
3091
3092 private:
ComputeInternal(int64 i,const HloInstruction & hlo,tensorflow::gtl::FlatMap<const HloInstruction *,UseKind> * cache)3093 static UseKind ComputeInternal(
3094 int64 i, const HloInstruction& hlo,
3095 tensorflow::gtl::FlatMap<const HloInstruction*, UseKind>* cache) {
3096 if (hlo.opcode_ == HloOpcode::kParameter && hlo.parameter_number_ == i) {
3097 return UseKind::kUse;
3098 }
3099
3100 auto p = cache->emplace(&hlo, UseKind{});
3101 auto value_it = p.first;
3102 const bool key_is_new = p.second;
3103
3104 if (key_is_new) {
3105 for (int64 j = 0; j < hlo.operands_.size(); ++j) {
3106 UseKind old_val = value_it->second;
3107
3108 // The next operation invalidates iterators.
3109 UseKind new_val =
3110 Plus(old_val, std::min(hlo.OperandElementUse(j),
3111 ComputeInternal(i, *hlo.operand(j), cache)));
3112
3113 // Re-acquire the iterator. We could work harder to do this only if
3114 // absolutely necessary, but this code is not hot enough to warrant
3115 // that.
3116 value_it = cache->find(&hlo);
3117 value_it->second = new_val;
3118 }
3119 }
3120 return value_it->second;
3121 }
3122
3123 // Fold operation for UseKinds.
Plus(UseKind a,UseKind b)3124 static UseKind Plus(UseKind a, UseKind b) {
3125 if (a == UseKind::kNoUse) {
3126 return b;
3127 } else if (b == UseKind::kNoUse) {
3128 return a;
3129 } else if (a == UseKind::kReuse || b == UseKind::kReuse) {
3130 return UseKind::kReuse;
3131 } else if (a == UseKind::kUsePermutingElements ||
3132 b == UseKind::kUsePermutingElements) {
3133 return UseKind::kReuse;
3134 } else {
3135 CHECK(a == UseKind::kUse && b == UseKind::kUse);
3136 return UseKind::kUse;
3137 }
3138 }
3139 };
3140
OperandElementUse(int64 i) const3141 HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
3142 switch (opcode_) {
3143 case HloOpcode::kBitcast:
3144 case HloOpcode::kConcatenate:
3145 case HloOpcode::kReshape:
3146 case HloOpcode::kReverse:
3147 case HloOpcode::kSlice:
3148 case HloOpcode::kTranspose:
3149 return UseKind::kUsePermutingElements;
3150 case HloOpcode::kPad:
3151 case HloOpcode::kReduce:
3152 // Pad reuses the padding value but not the padded array elements.
3153 // Reduce reuses the init value but not the operand array elements.
3154 return i > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements;
3155 case HloOpcode::kFusion:
3156 // Uses the memoizing, recursive computation defined above.
3157 return FusionReusesParamElements::Compute(i, *fused_expression_root());
3158 case HloOpcode::kDot:
3159 // Dot operations with inputs [A,B] * [B,1] do not re-use
3160 // elements on their left operand.
3161 // Dot operations with inputs [1,A] * [A,B] do not re-use
3162 // elements on their right operand.
3163 if (shape().dimensions_size() == 2) {
3164 if ((i == 0 && shape().dimensions(1) == 1) ||
3165 (i == 1 && shape().dimensions(0) == 1)) {
3166 return UseKind::kUse;
3167 }
3168 }
3169 return UseKind::kReuse;
3170 case HloOpcode::kDynamicUpdateSlice:
3171 // Dynamic-update-slice reuses only operand 2 (start_indices).
3172 if (i == 0 || i == 1) {
3173 return UseKind::kUse;
3174 }
3175 return UseKind::kReuse;
3176 default:
3177 return IsElementwise() && !ImplicitlyBroadcastsOperand(i)
3178 ? UseKind::kUse
3179 : UseKind::kReuse;
3180 }
3181 }
3182
3183 std::tuple<bool, std::vector<int64>, std::vector<int64>>
ReshapeMerelyInsertsOrDeletes1SizedDimensions() const3184 HloInstruction::ReshapeMerelyInsertsOrDeletes1SizedDimensions() const {
3185 if (HloOpcode::kReshape != opcode_) {
3186 return std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
3187 }
3188 return ShapeUtil::InsertedOrDeleted1SizedDimensions(operand(0)->shape_,
3189 shape_);
3190 }
3191
ToString(HloInstruction::FusionKind kind)3192 string ToString(HloInstruction::FusionKind kind) {
3193 switch (kind) {
3194 case HloInstruction::FusionKind::kLoop:
3195 return "kLoop";
3196 case HloInstruction::FusionKind::kInput:
3197 return "kInput";
3198 case HloInstruction::FusionKind::kOutput:
3199 return "kOutput";
3200 case HloInstruction::FusionKind::kTransposeDot:
3201 return "kTransposeDot";
3202 case HloInstruction::FusionKind::kCustom:
3203 return "kCustom";
3204 }
3205 }
3206
StringToFusionKind(const string & kind_name)3207 StatusOr<HloInstruction::FusionKind> StringToFusionKind(
3208 const string& kind_name) {
3209 if (kind_name == "kLoop") {
3210 return HloInstruction::FusionKind::kLoop;
3211 }
3212 if (kind_name == "kInput") {
3213 return HloInstruction::FusionKind::kInput;
3214 }
3215 if (kind_name == "kOutput") {
3216 return HloInstruction::FusionKind::kOutput;
3217 }
3218 if (kind_name == "kTransposeDot") {
3219 return HloInstruction::FusionKind::kTransposeDot;
3220 }
3221 if (kind_name == "kCustom") {
3222 return HloInstruction::FusionKind::kCustom;
3223 }
3224 return InvalidArgument("Unknown fusion kind: %s", kind_name.c_str());
3225 }
3226
PaddingConfigToString(const PaddingConfig & padding)3227 string PaddingConfigToString(const PaddingConfig& padding) {
3228 bool has_interior_padding =
3229 std::any_of(padding.dimensions().begin(), padding.dimensions().end(),
3230 [](const PaddingConfig::PaddingConfigDimension& dim) {
3231 return dim.interior_padding() != 0;
3232 });
3233 return Join(
3234 padding.dimensions(), "x",
3235 [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) {
3236 StrAppend(
3237 out, dim.edge_padding_low(), "_", dim.edge_padding_high(),
3238 has_interior_padding ? StrCat("_", dim.interior_padding()) : "");
3239 });
3240 }
3241
OpMetadataToString(const OpMetadata & metadata)3242 string OpMetadataToString(const OpMetadata& metadata) {
3243 std::vector<string> result;
3244 if (!metadata.op_type().empty()) {
3245 result.push_back(StrCat("op_type=\"", CEscape(metadata.op_type()), "\""));
3246 }
3247 if (!metadata.op_name().empty()) {
3248 result.push_back(StrCat("op_name=\"", CEscape(metadata.op_name()), "\""));
3249 }
3250 if (!metadata.source_file().empty()) {
3251 result.push_back(
3252 StrCat("source_file=\"", CEscape(metadata.source_file()), "\""));
3253 }
3254 if (metadata.source_line() != 0) {
3255 result.push_back(StrCat("source_line=", metadata.source_line()));
3256 }
3257 return Join(result, " ");
3258 }
3259
RandomDistributionToString(const RandomDistribution & distribution)3260 string RandomDistributionToString(const RandomDistribution& distribution) {
3261 return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution));
3262 }
3263
StringToRandomDistribution(const string & name)3264 StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
3265 static std::unordered_map<string, RandomDistribution>* map = [] {
3266 static auto* map = new std::unordered_map<string, RandomDistribution>;
3267 for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) {
3268 if (RandomDistribution_IsValid(i)) {
3269 auto value = static_cast<RandomDistribution>(i);
3270 (*map)[RandomDistributionToString(value)] = value;
3271 }
3272 }
3273 return map;
3274 }();
3275 auto found = map->find(tensorflow::str_util::Lowercase(name));
3276 if (found == map->end()) {
3277 return InvalidArgument("Unknown distribution");
3278 }
3279 return found->second;
3280 }
3281
operator <<(std::ostream & os,HloInstruction::FusionKind kind)3282 std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
3283 return os << ToString(kind);
3284 }
3285
ConvolutionDimensionNumbersToString() const3286 string HloInstruction::ConvolutionDimensionNumbersToString() const {
3287 string result;
3288 if (convolution_dimension_numbers_ == nullptr) {
3289 return result;
3290 }
3291 const ConvolutionDimensionNumbers& dnums = *convolution_dimension_numbers_;
3292 // Show the given dimension labels in order of major to minor based on the
3293 // shape's layout.
3294 const auto append_dims = [&](const std::vector<string>& dims,
3295 const Shape& shape) {
3296 CHECK_EQ(dims.size(), ShapeUtil::Rank(shape));
3297 StrAppend(&result, Join(dims, ""));
3298 };
3299
3300 // lhs_dims[i] is the symbol of the logical dimension i for the lhs
3301 // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b".
3302 std::vector<string> lhs_dims(2 + dnums.input_spatial_dimensions().size());
3303 lhs_dims[dnums.input_batch_dimension()] = 'b';
3304 lhs_dims[dnums.input_feature_dimension()] = 'f';
3305 for (int64 i = 0; i < dnums.input_spatial_dimensions().size(); ++i) {
3306 lhs_dims[dnums.input_spatial_dimensions(i)] = StrCat(i);
3307 }
3308
3309 std::vector<string> rhs_dims(2 + dnums.kernel_spatial_dimensions().size());
3310 rhs_dims[dnums.kernel_input_feature_dimension()] = "i";
3311 rhs_dims[dnums.kernel_output_feature_dimension()] = "o";
3312 for (int64 i = 0; i < dnums.kernel_spatial_dimensions().size(); ++i) {
3313 rhs_dims[dnums.kernel_spatial_dimensions(i)] = StrCat(i);
3314 }
3315
3316 std::vector<string> output_dims(2 + dnums.output_spatial_dimensions().size());
3317 output_dims[dnums.output_batch_dimension()] = 'b';
3318 output_dims[dnums.output_feature_dimension()] = 'f';
3319 for (int64 i = 0; i < dnums.output_spatial_dimensions().size(); ++i) {
3320 output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i);
3321 }
3322
3323 result += "dim_labels=";
3324 append_dims(lhs_dims, operand(0)->shape());
3325 result += "_";
3326 append_dims(rhs_dims, operand(1)->shape());
3327 result += "->";
3328
3329 // A convolution can be represented as a kConvolution HLO or as a CustomCall
3330 // that returns a tuple, the first element of which is the result of the
3331 // convolution.
3332 Shape this_shape =
3333 ShapeUtil::IsTuple(shape()) ? shape().tuple_shapes(0) : shape();
3334 append_dims(output_dims, this_shape);
3335 return result;
3336 }
3337
DotDimensionNumbersToString() const3338 string HloInstruction::DotDimensionNumbersToString() const {
3339 std::vector<string> result;
3340 if (dot_dimension_numbers_ == nullptr) {
3341 return "";
3342 }
3343 const DotDimensionNumbers& dnums = *dot_dimension_numbers_;
3344 if (!dnums.lhs_batch_dimensions().empty()) {
3345 result.push_back(StrCat("lhs_batch_dims={",
3346 Join(dnums.lhs_batch_dimensions(), ","), "}"));
3347 }
3348 result.push_back(StrCat("lhs_contracting_dims={",
3349 Join(dnums.lhs_contracting_dimensions(), ","), "}"));
3350
3351 if (!dnums.rhs_batch_dimensions().empty()) {
3352 result.push_back(StrCat("rhs_batch_dims={",
3353 Join(dnums.rhs_batch_dimensions(), ","), "}"));
3354 }
3355 result.push_back(StrCat("rhs_contracting_dims={",
3356 Join(dnums.rhs_contracting_dimensions(), ","), "}"));
3357
3358 return Join(result, ", ");
3359 }
3360
GatherDimensionNumbersToString() const3361 string HloInstruction::GatherDimensionNumbersToString() const {
3362 CHECK_NE(gather_dimension_numbers_.get(), nullptr);
3363 string output_window_dims =
3364 StrCat("output_window_dims={",
3365 Join(gather_dimension_numbers_->output_window_dims(), ","), "}");
3366 string elided_window_dims =
3367 StrCat("elided_window_dims={",
3368 Join(gather_dimension_numbers_->elided_window_dims(), ","), "}");
3369 string gather_dims_to_operand_dims = StrCat(
3370 "gather_dims_to_operand_dims={",
3371 Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}");
3372
3373 return Join<std::initializer_list<string>>(
3374 {output_window_dims, elided_window_dims, gather_dims_to_operand_dims},
3375 ", ");
3376 }
3377
CouldBeBitcast() const3378 bool HloInstruction::CouldBeBitcast() const {
3379 switch (opcode_) {
3380 case HloOpcode::kTranspose:
3381 return true;
3382 case HloOpcode::kReshape:
3383 return std::get<0>(ReshapeMerelyInsertsOrDeletes1SizedDimensions());
3384 default:
3385 return false;
3386 }
3387 }
3388
GetModule() const3389 HloModule* HloInstruction::GetModule() const {
3390 if (parent_) {
3391 return parent_->parent();
3392 }
3393 return nullptr;
3394 }
3395
UniquifyName(NameUniquer * name_uniquer)3396 void HloInstruction::UniquifyName(NameUniquer* name_uniquer) {
3397 string parent_str = parent() == nullptr ? "noparent" : parent()->name();
3398 name_ = name_uniquer->GetUniqueName(name_);
3399 }
3400
set_outer_dimension_partitions(const std::vector<int64> & outer_dimension_partitions)3401 void HloInstruction::set_outer_dimension_partitions(
3402 const std::vector<int64>& outer_dimension_partitions) {
3403 outer_dimension_partitions_ = outer_dimension_partitions;
3404 }
3405
RelayoutConstant(const Layout & new_layout,const ShapeIndex & shape_index)3406 void HloInstruction::RelayoutConstant(const Layout& new_layout,
3407 const ShapeIndex& shape_index) {
3408 CHECK_EQ(opcode(), HloOpcode::kConstant);
3409 Shape* mutable_array_subshape =
3410 ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index);
3411 CHECK(ShapeUtil::IsArray(*mutable_array_subshape));
3412
3413 // Normally array_subshape will always have a layout, but this invariant is
3414 // temporarily broken in LayoutAssignment::AssignLayouts.
3415
3416 if (!mutable_array_subshape->has_layout() ||
3417 !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
3418 literal_ = literal_->Relayout(new_layout, shape_index);
3419 *mutable_array_subshape->mutable_layout() = new_layout;
3420 }
3421 }
3422
3423 } // namespace xla
3424