/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/compiler/xla/service/dynamic_padder.h"
16
17 #include <algorithm>
18 #include <vector>
19
20 #include "absl/algorithm/container.h"
21 #include "absl/container/flat_hash_map.h"
22
23 #include "absl/container/flat_hash_set.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/literal_util.h"
26 #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
27 #include "tensorflow/compiler/xla/service/hlo_dce.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/status_macros.h"
30 #include "tensorflow/compiler/xla/util.h"
31
32 #include "tensorflow/core/lib/core/errors.h"
33
34 namespace xla {
35
36 namespace {
37
38 // ChooseIdentityValue looks at the instruction and returns a identity value
39 // which, when padded, doesn't change the result of the instruction.
40 //
41 // nullopt is returned if padding doesn't need to be reset.
ChooseIdentityValue(HloInstruction * inst)42 StatusOr<HloInstruction*> ChooseIdentityValue(HloInstruction* inst) {
43 HloComputation* comp = inst->parent();
44 // Padding on elementwise operation doesn't affect the result of the effective
45 // data.
46 if (inst->IsElementwise()) {
47 return nullptr;
48 }
49
50 switch (inst->opcode()) {
51 case HloOpcode::kReduce:
52 case HloOpcode::kReduceWindow: {
53 // Because of the way we do reduce, we already require the `init` operand
54 // of hlo reduce instruction to be identity value. Here we reuse the
55 // operand.
56 return inst->mutable_operand(1);
57 }
58
59 case HloOpcode::kConvolution:
60 case HloOpcode::kDot: {
61 // Use 0 as padding value for convolution and dot.
62 PrimitiveType ptype = inst->shape().element_type();
63 return comp->AddInstruction(
64 HloInstruction::CreateConstant(LiteralUtil::Zero(ptype)));
65 }
66
67 case HloOpcode::kPad: {
68 return inst->mutable_operand(1);
69 }
70
71 case HloOpcode::kSelectAndScatter: {
72 return inst->mutable_operand(2);
73 }
74 case HloOpcode::kParameter:
75 case HloOpcode::kGetDimensionSize:
76 case HloOpcode::kReshape:
77 case HloOpcode::kTuple:
78 case HloOpcode::kAllReduce:
79 case HloOpcode::kBroadcast:
80 case HloOpcode::kTranspose:
81 case HloOpcode::kSlice:
82 return nullptr;
83 default:
84 return UnimplementedStrCat("Unimplimented padding for instruction: ",
85 inst->ToString());
86 }
87 }
88
ShouldSkipPadOnOperand(const HloInstruction * inst,int64 operand_num,int64 dimension)89 bool ShouldSkipPadOnOperand(const HloInstruction* inst, int64 operand_num,
90 int64 dimension) {
91 if ((inst->opcode() == HloOpcode::kReduceWindow ||
92 inst->opcode() == HloOpcode::kSelectAndScatter) &&
93 operand_num == 0 && inst->window().dimensions(dimension).size() == 1) {
94 return true;
95 }
96
97 if (operand_num == 0 && inst->opcode() == HloOpcode::kConvolution &&
98 inst->convolution_dimension_numbers().input_batch_dimension() ==
99 dimension) {
100 return true;
101 }
102 return false;
103 }
104
105 } // namespace
106
Run(HloModule * module)107 StatusOr<bool> DynamicPadder::Run(HloModule* module) {
108 bool changed = false;
109 VLOG(2) << "Pre DynamicPadder HLO:";
110 XLA_VLOG_LINES(2, module->ToString());
111 TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference,
112 DynamicDimensionInference::Run(module));
113
114 for (HloComputation* computation : module->computations()) {
115 for (HloInstruction* inst : computation->instructions()) {
116 for (int64 operand_num = 0; operand_num < inst->operand_count();
117 ++operand_num) {
118 HloInstruction* operand = inst->mutable_operand(operand_num);
119 if (!operand->shape().IsArray()) {
120 continue;
121 }
122 for (int64 dim = 0; dim < operand->shape().rank(); ++dim) {
123 HloInstruction* dynamic_size =
124 dynamic_dimension_inference.GetDynamicSize(operand, {}, dim);
125 if (dynamic_size == nullptr) {
126 continue;
127 }
128 VLOG(1) << "Has dynamic dimension of operand" << operand_num << " @"
129 << dim;
130
131 if (ShouldSkipPadOnOperand(inst, operand_num, dim)) {
132 continue;
133 }
134
135 TF_ASSIGN_OR_RETURN(HloInstruction * identity_value,
136 ChooseIdentityValue(inst));
137 if (identity_value == nullptr) {
138 continue;
139 }
140
141 // For each dimension, first generates a mask representing the
142 // effective area of data and padded area of data using iota and
143 // dynamic_size. For example, given a dimension of 7 elements and 5
144 // effective elements:
145 //
146 // iota = [0, 1, 2, 3, 4, 5, 6]
147 // broadcast_dynamic_size = [5, 5, 5, 5, 5, 5, 5]
148 // mask = lt(iota, broadcast_dynamic_size) = [t, t, t, t, t, f, f]
149 //
150 // Once the mask is generated, the input data is then padded using the
151 // mask and pad value.
152 //
153 const Shape mask_shape =
154 ShapeUtil::ChangeElementType(operand->shape(), xla::U32);
155 const Shape pred_shape =
156 ShapeUtil::ChangeElementType(operand->shape(), xla::PRED);
157 HloInstruction* iota = computation->AddInstruction(
158 HloInstruction::CreateIota(mask_shape, dim));
159
160 HloInstruction* broadcasted_effective_size =
161 computation->AddInstruction(HloInstruction::CreateBroadcast(
162 mask_shape, dynamic_size, {}));
163 HloInstruction* pred =
164 computation->AddInstruction(HloInstruction::CreateCompare(
165 pred_shape, iota, broadcasted_effective_size,
166 ComparisonDirection::kLt));
167
168 HloInstruction* broadcasted_identity_value =
169 computation->AddInstruction(HloInstruction::CreateBroadcast(
170 operand->shape(), identity_value, {}));
171 HloInstruction* padded =
172 computation->AddInstruction(HloInstruction::CreateTernary(
173 operand->shape(), HloOpcode::kSelect, pred, operand,
174 broadcasted_identity_value));
175 TF_RETURN_IF_ERROR(inst->ReplaceOperandWith(operand_num, padded));
176 operand = inst->mutable_operand(operand_num);
177 changed = true;
178 }
179 }
180 }
181 }
182 HloDCE dce;
183 TF_ASSIGN_OR_RETURN(changed, dce.Run(module));
184 VLOG(2) << "Post DynamicPadder HLO:";
185 XLA_VLOG_LINES(2, module->ToString());
186 return changed;
187 }
188
189 } // namespace xla
190