/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h"

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {
namespace gpu {
namespace {

// Describes matched patterns:
//   max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias));
// for floating point types, or
//   max(0, alpha1 * conv<float>(int8_x, int8_w) + alpha2 * side_input +
//          broadcast(bias));
// for int8, where side_input has the shape of the output buffer, and bias is a
// 1D array whose length is the number of output features.
struct ConvWithRelu {
  HloInstruction* maximum;
  HloCustomCallInstruction* conv;
  HloInstruction* bias;
  HloInstruction* side_input;
  HloConstantInstruction* alpha_conv;
  HloConstantInstruction* alpha_side_input;
};

// The pattern we want to match:
//   max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias));
// or
//   max(0, alpha1 * conv<float>(int8_x, int8_w) + alpha2 * side_input +
//          broadcast(bias));
// with its variants involving commute/reassociation of adds, multiplies, and
// max, and omission of alpha1, side_input, alpha2, or bias.
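//
// For illustration only, a simplified instance of the basic float form (bias
// only, no scaling or side input) might look roughly like the sketch below;
// the shapes are made up and this is not exact HLO syntax:
//   %conv = (f32[1,64,56,56], u8[0]) custom-call(%x, %w),
//       custom_call_target="__cudnn$convForward"
//   %gte  = f32[1,64,56,56] get-tuple-element(%conv), index=0
//   %sum  = f32[1,64,56,56] add(%gte, f32[1,64,56,56] broadcast(%bias))
//   %relu = f32[1,64,56,56] maximum(f32[1,64,56,56] broadcast(0), %sum)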
absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
  using match::Add;
  using match::AddAnyOrder;
  using match::AnyOf;
  using match::Broadcast;
  using match::ConstantScalar;
  using match::GetTupleElement;
  using match::Maximum;
  using match::MultiplyAnyOrder;
  using match::Op;

  HloInstruction* relu_input;

  // Match max(0, relu_input).
  auto zero_pattern = Broadcast(ConstantScalar(0));
  if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) &&
      !Match(instr, Maximum(Op(&relu_input), zero_pattern))) {
    return absl::nullopt;
  }
  HloInstruction* conv_instr = nullptr;
  HloInstruction* alpha_conv_instr = nullptr;
  HloInstruction* alpha_side_input_instr = nullptr;
  HloInstruction* bias_broadcast_instr = nullptr;
  HloInstruction* bias = nullptr;
  HloInstruction* side_input = nullptr;

  // These nodes will not be in the returned value, but we need to check them
  // for single use.
  HloInstruction *gte = nullptr, *add1 = nullptr, *add2 = nullptr,
                 *mul1 = nullptr, *mul2 = nullptr;

  const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias));
  const auto conv_pattern = [&] {
    auto alpha_pattern = Broadcast(ConstantScalar(&alpha_conv_instr));
    auto conv_pattern = GetTupleElement(
        &gte, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0);
    return AnyOf<HloInstruction>(
        MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern);
  }();
  const auto side_input_pattern = [&] {
    auto alpha_pattern = Broadcast(ConstantScalar(&alpha_side_input_instr));
    // If bias is already matched, match arbitrary additional input as side
    // input. Note this may force a cheap operation (e.g. broadcast) to be
    // materialized into a large buffer, as large as the output buffer.
    //
    // TODO(timshen): If in practice there are significant false positives, we
    // should fix it.
    auto side_input_pattern = Op(&side_input);
    return AnyOf<HloInstruction>(
        MultiplyAnyOrder(&mul2, alpha_pattern, side_input_pattern),
        side_input_pattern);
  }();

  {
    // Try to match any of the following forms of add, in any association:
    //   addends[0]
    //   addends[0] + addends[1]
    //   addends[0] + addends[1] + addends[2]
    //
    // Then try to match each addend with one of the three patterns: bias,
    // conv, or side_input. Notice that side_input matching must go last, as
    // it also matches a conv or a bias.
    HloInstruction* addends[3] = {nullptr, nullptr, nullptr};
    auto add3_pattern = [&] {
      auto add2_pattern = Add(&add1, Op(&addends[0]), Op(&addends[1]));
      return AnyOf<HloInstruction>(
          AddAnyOrder(&add2, add2_pattern, Op(&addends[2])), add2_pattern,
          Op(&addends[0]));
    }();
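    // The last alternative above, Op(&addends[0]), matches any single
    // instruction, so this pattern cannot fail to match; the CHECK below
    // documents that invariant.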
    CHECK(Match(relu_input, add3_pattern));
    for (auto addend : addends) {
      if (addend) {
        if (bias == nullptr && Match(addend, bias_pattern)) {
          CHECK(bias);
        } else if (conv_instr == nullptr && Match(addend, conv_pattern)) {
          CHECK(conv_instr);
        } else if (side_input == nullptr && Match(addend, side_input_pattern)) {
          CHECK(side_input);
        } else {
          return absl::nullopt;
        }
      }
    }
  }

  if (conv_instr == nullptr) {
    return absl::nullopt;
  }

  for (HloInstruction* instr :
       {conv_instr, bias_broadcast_instr, gte, add1, add2, mul1, mul2}) {
    if (instr && instr->user_count() > 1) {
      return absl::nullopt;
    }
  }

  auto conv = Cast<HloCustomCallInstruction>(conv_instr);
  auto bias_broadcast =
      CastOrNull<HloBroadcastInstruction>(bias_broadcast_instr);

  if (conv->custom_call_target() != kCudnnConvForwardCallTarget) {
    return absl::nullopt;
  }

  // In order to map to cudnnConvolutionBiasActivationForward for int8, the
  // convolution output must be float, i.e. conv<float>(int8_x, int8_w).
  if (conv->operand(0)->shape().element_type() == xla::S8) {
    if (conv->shape().tuple_shapes(0).element_type() != xla::F32) {
      return absl::nullopt;
    }
  }

  if (bias_broadcast) {
    // TODO(timshen): handle bias_broadcast_instr->dimensions() == {}.
    if (bias_broadcast_instr->dimensions().size() != 1) {
      return absl::nullopt;
    }
    if (bias_broadcast_instr->dimensions(0) !=
        conv->convolution_dimension_numbers().output_feature_dimension()) {
      return absl::nullopt;
    }
  }

  return ConvWithRelu{
      instr,
      conv,
      bias,
      side_input,
      CastOrNull<HloConstantInstruction>(alpha_conv_instr),
      CastOrNull<HloConstantInstruction>(alpha_side_input_instr)};
}

StatusOr<std::unique_ptr<HloInstruction>> TryRewriteToCudnnForwardRelu(
    ConvWithRelu match) {
  auto conv = match.conv;

  HloComputation* computation = conv->parent();

  const auto get_alpha_value =
      [](HloConstantInstruction* instr) -> StatusOr<double> {
    TF_ASSIGN_OR_RETURN(
        auto alpha,
        Cast<HloConstantInstruction>(instr)->literal().Convert(F64));
    return alpha.GetFirstElement<double>();
  };

  double alpha_conv = 1;
  if (match.alpha_conv) {
    TF_ASSIGN_OR_RETURN(alpha_conv, get_alpha_value(match.alpha_conv));
  }

  double alpha_side_input;
  if (match.side_input) {
    if (match.alpha_side_input) {
      TF_ASSIGN_OR_RETURN(alpha_side_input,
                          get_alpha_value(match.alpha_side_input));
    } else {
      alpha_side_input = 1;
    }
  } else {
    CHECK(match.alpha_side_input == nullptr);
    alpha_side_input = 0;
  }

  auto bias = match.bias;
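  // The fused custom call always takes a bias operand, so when the matched
  // pattern has no bias, synthesize an all-zero bias broadcast over the
  // output-feature dimension.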
  if (!bias) {
    PrimitiveType conv_output_type =
        conv->shape().tuple_shapes(0).element_type();
    auto zero = computation->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::Zero(conv_output_type)));

    int64 num_output_feature = conv->shape().tuple_shapes(0).dimensions(
        conv->convolution_dimension_numbers().output_feature_dimension());
    bias = computation->AddInstruction(HloInstruction::CreateBroadcast(
        ShapeUtil::MakeShapeWithDescendingLayout(conv_output_type,
                                                 {num_output_feature}),
        zero, {}));
  }

  CHECK(bias);
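  // Operands of the fused call: input, filter, bias, and (optionally) the
  // side input.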
  std::vector<HloInstruction*> args = {conv->mutable_operand(0),
                                       conv->mutable_operand(1), bias};
  if (match.side_input) {
    args.push_back(match.side_input);
  }
  auto new_conv = computation->AddInstruction(HloInstruction::CreateCustomCall(
      conv->shape(), args, kCudnnConvBiasActivationForwardCallTarget));
  new_conv->set_feature_group_count(conv->feature_group_count());
  new_conv->set_window(conv->window());
  new_conv->set_convolution_dimension_numbers(
      conv->convolution_dimension_numbers());
  new_conv->set_metadata(conv->metadata());
  TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config,
                      conv->backend_config<CudnnConvBackendConfig>());
  config.set_activation_mode(
      static_cast<int64>(se::dnn::ActivationMode::kRelu));
  config.set_conv_result_scale(alpha_conv);
  config.set_side_input_scale(alpha_side_input);
  TF_RETURN_IF_ERROR(new_conv->set_backend_config(config));

  VLOG(1) << "Replacing convolution " << conv->ToString() << " with "
          << new_conv->ToString();
  return HloInstruction::CreateGetTupleElement(conv->shape().tuple_shapes(0),
                                               new_conv, 0);
}

// Fuses bias/scaling/ReLU into a convolution custom call with floating-point
// output.
StatusOr<bool> RunFuseBiasSideActivation(HloModule* module) {
  bool changed = false;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    std::vector<ConvWithRelu> matches;
    int num_forward_convs = 0;
    for (auto instr : computation->instructions()) {
      auto match = FindConvWithRelu(instr);
      if (match.has_value()) {
        matches.push_back(*match);
      }
      if (auto call = DynCast<HloCustomCallInstruction>(instr)) {
        if (call->custom_call_target() == kCudnnConvForwardCallTarget) {
          num_forward_convs++;
        }
      }
    }
    VLOG(1) << "Identified cuDNN forward conv + relu: " << matches.size()
            << " out of " << num_forward_convs << " forward convs.";
    std::vector<std::pair<HloInstruction*, std::unique_ptr<HloInstruction>>>
        replacements;
    for (const ConvWithRelu& match : matches) {
      TF_ASSIGN_OR_RETURN(auto new_instr, TryRewriteToCudnnForwardRelu(match));
      replacements.push_back({match.maximum, std::move(new_instr)});
      changed = true;
    }
    for (auto& replacement : replacements) {
      TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
          replacement.first, std::move(replacement.second)));
    }
  }
  return changed;
}

// Describes a matched pattern:
//   convert_or_clamp(get_tuple_element(custom_call(x, w, ...)));
// where the custom_call targets a cuDNN convolution (either a pure or a fused
// convolution).
struct ConvWithConvertOrClamp {
  HloInstruction* convert_or_clamp;
  HloInstruction* gte;
  HloCustomCallInstruction* conv;
};

// The pattern we want to match:
//   convert<int8>(clamp(broadcast(-128),
//                       get_tuple_element(custom_call(int8_x, int8_w, ...)),
//                       broadcast(127)));
absl::optional<ConvWithConvertOrClamp> FindConvWithClampAndConvertToInt8(
    HloInstruction* instr) {
  using match::Broadcast;
  using match::Clamp;
  using match::Convert;
  using match::GetTupleElement;
  using match::Op;

  HloInstruction* gte = nullptr;
  HloInstruction* conv_instr = nullptr;
  auto lower_pattern = Broadcast(match::ConstantScalar(-128));
  auto upper_pattern = Broadcast(match::ConstantScalar(127));
  auto pattern = Convert(
      Clamp(lower_pattern,
            GetTupleElement(
                &gte, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0),
            upper_pattern));

  if (Match(instr, pattern)) {
    if (conv_instr->operand(0)->shape().element_type() == xla::S8 &&
        instr->shape().element_type() == xla::S8) {
      HloCustomCallInstruction* conv =
          CastOrNull<HloCustomCallInstruction>(conv_instr);
      return ConvWithConvertOrClamp{instr, gte, conv};
    }
  }
  return absl::nullopt;
}

// A helper function to rewrite convert_or_clamp_or_other<new_type>(gte(conv()))
// to gte<new_type>(conv<new_type>()). It bypasses convert_or_clamp_or_other
// and sets the output data type on gte and conv.
Status RewriteForConvertOrClampImpl(ConvWithConvertOrClamp match) {
  auto conv = match.conv;
  auto gte = match.gte;
  auto convert_or_clamp = match.convert_or_clamp;

  // Change the output type on conv and gte.
  auto convert_out_type = convert_or_clamp->shape().element_type();
  conv->mutable_shape()->mutable_tuple_shapes(0)->set_element_type(
      convert_out_type);
  gte->mutable_shape()->set_element_type(convert_out_type);

  // Remove the convert/clamp and keep just
  //   get_tuple_element(custom_call(x, w, ...)).
  TF_RETURN_IF_ERROR(convert_or_clamp->ReplaceAllUsesWithDifferentShape(gte));
  TF_RETURN_IF_ERROR(
      conv->parent()->RemoveInstructionAndUnusedOperands(convert_or_clamp));
  return Status::OK();
}

Status RewriteForFinalOutput(ConvWithConvertOrClamp match) {
  // The matched clamp/convert<int8>, which has the convolution as its single
  // (indirect) user, is absorbed into the convolution if
  //   1. the side_input matches a convert<float>(int8_side_input), or
  //   2. there is no side input.
  const auto is_one_to_one_X_to_Y_cast = [](const HloInstruction* instr,
                                            PrimitiveType X,
                                            PrimitiveType Y) -> bool {
    return (instr->opcode() == HloOpcode::kConvert &&
            instr->shape().element_type() == Y && instr->operand_count() == 1 &&
            instr->operand(0)->user_count() == 1 &&
            instr->operand(0)->shape().element_type() == X);
  };

  if (match.conv->operand_count() < 4) {
    // Conv input #3 (zero based) is side_input, after x, w, and bias.
    // Side input doesn't exist in this case.
    TF_RETURN_IF_ERROR(RewriteForConvertOrClampImpl(match));
  } else if (is_one_to_one_X_to_Y_cast(match.conv->operand(3), S8, F32)) {
    // If side_input has a convert_float_to_int8, absorb it as well.
    auto side_converter = match.conv->mutable_operand(3);
    TF_RETURN_IF_ERROR(side_converter->ReplaceAllUsesWithDifferentShape(
        side_converter->mutable_operand(0)));
    TF_RETURN_IF_ERROR(
        side_converter->parent()->RemoveInstructionAndUnusedOperands(
            side_converter));

    TF_RETURN_IF_ERROR(RewriteForConvertOrClampImpl(match));
  }
  return Status::OK();
}

// Fuses the clamp/convert pattern into the int8 convolution custom call
// (either pure or fused) to produce int8 output.
StatusOr<bool> RunFuseClamp(HloModule* module) {
  bool changed = false;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    std::vector<ConvWithConvertOrClamp> matches;
    for (auto instr : computation->instructions()) {
      auto match = FindConvWithClampAndConvertToInt8(instr);
      if (match.has_value()) {
        matches.push_back(*match);
      }
    }
    for (const ConvWithConvertOrClamp& match : matches) {
      TF_RETURN_IF_ERROR(RewriteForFinalOutput(match));
      changed = true;
    }

    // Report an error for any convolution that still has int32 output.
    // Although an int32-output convolution would trigger other sanity-check
    // errors later, we want to give a specific error message here.
    for (auto instr : computation->instructions()) {
      if (auto call = DynCast<HloCustomCallInstruction>(instr)) {
        if ((call->custom_call_target() == kCudnnConvForwardCallTarget ||
             call->custom_call_target() ==
                 kCudnnConvBiasActivationForwardCallTarget) &&
            call->shape().tuple_shapes(0).element_type() == xla::S32) {
          return Unimplemented(
              "Integer convolutions for CuDNN must have float or int8 output. "
              "Use convert to cast output to float or the following pattern to "
              "int8: "
              "clamp(broadcast(-128), conv(int8_x, int8_w, ...), "
              "broadcast(127)).");
        }
      }
    }
  }
  return changed;
}

// The pattern we want to match:
//   convert<float>(get_tuple_element<int32>(custom_call()));
absl::optional<ConvWithConvertOrClamp> FindConvWithConvertToFloat(
    HloInstruction* instr) {
  using match::Convert;
  using match::GetTupleElement;
  using match::Op;

  HloInstruction* gte = nullptr;
  HloInstruction* conv_instr = nullptr;
  auto pattern =
      Convert(GetTupleElement(
                  &gte,
                  Op(&conv_instr)
                      .WithOpcode(HloOpcode::kCustomCall)
                      .WithCustomCallTarget(kCudnnConvForwardCallTarget),
                  0)
                  .WithShape(match::Shape().WithElementType(xla::S32)))
          .WithShape(match::Shape().WithElementType(xla::F32));
  if (Match(instr, pattern)) {
    HloCustomCallInstruction* conv =
        CastOrNull<HloCustomCallInstruction>(conv_instr);
    return ConvWithConvertOrClamp{instr, gte, conv};
  }
  return absl::nullopt;
}

// Transforms
//   convert<float>(GetTupleElement<int32>(custom_call<int32>(int8_x, int8_w)))
// to
//   GetTupleElement<float>(custom_call<float>(int8_x, int8_w)).
StatusOr<bool> RunFuseConvertToFloat(HloModule* module) {
  bool changed = false;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    std::vector<ConvWithConvertOrClamp> matches;
    for (auto instr : computation->instructions()) {
      auto match = FindConvWithConvertToFloat(instr);
      if (match.has_value()) {
        matches.push_back(*match);
      }
    }

    for (const ConvWithConvertOrClamp& match : matches) {
      TF_RETURN_IF_ERROR(RewriteForConvertOrClampImpl(match));
      changed = true;
    }
  }
  return changed;
}
}  // namespace

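// Note on ordering: RunFuseConvertToFloat runs first so that int8 convolutions
// with int32 output are rewritten to produce float, which is what
// RunFuseBiasSideActivation expects for int8 inputs. RunFuseClamp runs last so
// it can absorb the final clamp/convert<int8> around either a pure or an
// already-fused convolution.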
StatusOr<bool> CudnnFusedConvRewriter::Run(HloModule* module) {
  TF_ASSIGN_OR_RETURN(bool fused_for_convert_to_float,
                      RunFuseConvertToFloat(module));

  TF_ASSIGN_OR_RETURN(bool fused_for_bias, RunFuseBiasSideActivation(module));

  TF_ASSIGN_OR_RETURN(bool fused_for_clamp, RunFuseClamp(module));

  return fused_for_convert_to_float || fused_for_bias || fused_for_clamp;
}
}  // namespace gpu
}  // namespace xla