/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/dynamic_padder.h"

#include <algorithm>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"

#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"

#include "tensorflow/core/lib/core/errors.h"

namespace xla {

namespace {

38 // ChooseIdentityValue looks at the instruction and returns a identity value
39 // which, when padded, doesn't change the result of the instruction.
40 //
41 // nullopt is returned if padding doesn't need to be reset.
ChooseIdentityValue(HloInstruction * inst)42 StatusOr<HloInstruction*> ChooseIdentityValue(HloInstruction* inst) {
43   HloComputation* comp = inst->parent();
44   // Padding on elementwise operation doesn't affect the result of the effective
45   // data.
46   if (inst->IsElementwise()) {
47     return nullptr;
48   }
49 
50   switch (inst->opcode()) {
51     case HloOpcode::kReduce:
52     case HloOpcode::kReduceWindow: {
53       // Because of the way we do reduce, we already require the `init` operand
54       // of hlo reduce instruction to be identity value. Here we reuse the
55       // operand.
56       return inst->mutable_operand(1);
57     }
58 
59     case HloOpcode::kConvolution:
60     case HloOpcode::kDot: {
61       // Use 0 as padding value for convolution and dot.
62       PrimitiveType ptype = inst->shape().element_type();
63       return comp->AddInstruction(
64           HloInstruction::CreateConstant(LiteralUtil::Zero(ptype)));
65     }
66 
67     case HloOpcode::kPad: {
68       return inst->mutable_operand(1);
69     }
70 
71     case HloOpcode::kSelectAndScatter: {
72       return inst->mutable_operand(2);
73     }
74     case HloOpcode::kParameter:
75     case HloOpcode::kGetDimensionSize:
76     case HloOpcode::kReshape:
77     case HloOpcode::kTuple:
78     case HloOpcode::kAllReduce:
79     case HloOpcode::kBroadcast:
80     case HloOpcode::kTranspose:
81     case HloOpcode::kSlice:
82       return nullptr;
83     default:
84       return UnimplementedStrCat("Unimplimented padding for instruction: ",
85                                  inst->ToString());
86   }
87 }
ShouldSkipPadOnOperand(const HloInstruction * inst,int64 operand_num,int64 dimension)89 bool ShouldSkipPadOnOperand(const HloInstruction* inst, int64 operand_num,
90                             int64 dimension) {
91   if ((inst->opcode() == HloOpcode::kReduceWindow ||
92        inst->opcode() == HloOpcode::kSelectAndScatter) &&
93       operand_num == 0 && inst->window().dimensions(dimension).size() == 1) {
94     return true;
95   }
96 
97   if (operand_num == 0 && inst->opcode() == HloOpcode::kConvolution &&
98       inst->convolution_dimension_numbers().input_batch_dimension() ==
99           dimension) {
100     return true;
101   }
102   return false;
103 }

}  // namespace

Run(HloModule * module)107 StatusOr<bool> DynamicPadder::Run(HloModule* module) {
108   bool changed = false;
109   VLOG(2) << "Pre DynamicPadder HLO:";
110   XLA_VLOG_LINES(2, module->ToString());
111   TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference,
112                       DynamicDimensionInference::Run(module));
113 
114   for (HloComputation* computation : module->computations()) {
115     for (HloInstruction* inst : computation->instructions()) {
116       for (int64 operand_num = 0; operand_num < inst->operand_count();
117            ++operand_num) {
118         HloInstruction* operand = inst->mutable_operand(operand_num);
119         if (!operand->shape().IsArray()) {
120           continue;
121         }
122         for (int64 dim = 0; dim < operand->shape().rank(); ++dim) {
123           HloInstruction* dynamic_size =
124               dynamic_dimension_inference.GetDynamicSize(operand, {}, dim);
125           if (dynamic_size == nullptr) {
126             continue;
127           }
128           VLOG(1) << "Has dynamic dimension of operand" << operand_num << " @"
129                   << dim;
130 
131           if (ShouldSkipPadOnOperand(inst, operand_num, dim)) {
132             continue;
133           }
134 
135           TF_ASSIGN_OR_RETURN(HloInstruction * identity_value,
136                               ChooseIdentityValue(inst));
137           if (identity_value == nullptr) {
138             continue;
139           }
140 
141           // For each dimension, first generates a mask representing the
142           // effective area of data and padded area of data using iota and
143           // dynamic_size. For example, given a dimension of 7 elements and 5
144           // effective elements:
145           //
146           // iota = [0, 1, 2, 3, 4, 5, 6]
147           // broadcast_dynamic_size = [5, 5, 5, 5, 5, 5, 5]
148           // mask = lt(iota, broadcast_dynamic_size) = [t, t, t, t, t, f, f]
149           //
150           // Once the mask is generated, the input data is then padded using the
151           // mask and pad value.
152           //
153           const Shape mask_shape =
154               ShapeUtil::ChangeElementType(operand->shape(), xla::U32);
155           const Shape pred_shape =
156               ShapeUtil::ChangeElementType(operand->shape(), xla::PRED);
157           HloInstruction* iota = computation->AddInstruction(
158               HloInstruction::CreateIota(mask_shape, dim));
159 
160           HloInstruction* broadcasted_effective_size =
161               computation->AddInstruction(HloInstruction::CreateBroadcast(
162                   mask_shape, dynamic_size, {}));
163           HloInstruction* pred =
164               computation->AddInstruction(HloInstruction::CreateCompare(
165                   pred_shape, iota, broadcasted_effective_size,
166                   ComparisonDirection::kLt));
167 
168           HloInstruction* broadcasted_identity_value =
169               computation->AddInstruction(HloInstruction::CreateBroadcast(
170                   operand->shape(), identity_value, {}));
171           HloInstruction* padded =
172               computation->AddInstruction(HloInstruction::CreateTernary(
173                   operand->shape(), HloOpcode::kSelect, pred, operand,
174                   broadcasted_identity_value));
175           TF_RETURN_IF_ERROR(inst->ReplaceOperandWith(operand_num, padded));
176           operand = inst->mutable_operand(operand_num);
177           changed = true;
178         }
179       }
180     }
181   }
182   HloDCE dce;
183   TF_ASSIGN_OR_RETURN(changed, dce.Run(module));
184   VLOG(2) << "Post DynamicPadder HLO:";
185   XLA_VLOG_LINES(2, module->ToString());
186   return changed;
187 }

}  // namespace xla