/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h"

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {
namespace gpu {
namespace {

// Describes matched patterns:
//   max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias));
//   for floating point types, or
//   max(0, alpha1 * conv<float>(int8_x, int8_w) +
//   alpha2 * side_input + broadcast(bias));
//   for int8.
// Here side_input has the shape of the output buffer, and bias is a 1D array
// whose length is the number of output features.
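// For example (illustrative only), an HLO expression of the shape
//   maximum(broadcast(0),
//           add(add(multiply(broadcast(alpha1), get-tuple-element(conv, 0)),
//                   multiply(broadcast(alpha2), side_input)),
//               broadcast(bias)))
// matches, and the struct fields below are bound to the corresponding nodes.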
struct ConvWithRelu {
  HloInstruction* maximum;
  HloCustomCallInstruction* conv;
  HloInstruction* bias;
  HloInstruction* side_input;
  HloConstantInstruction* alpha_conv;
  HloConstantInstruction* alpha_side_input;
};

// The pattern we want to match:
//   max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias));
//   or
//   max(0, alpha1 * conv<float>(int8_x, int8_w) +
//   alpha2 * side_input + broadcast(bias));
// with its variants involving commutation/reassociation of adds, multiplies,
// and max, and omission of alpha1, side_input, alpha2, or bias.
absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
  using match::Add;
  using match::AddAnyOrder;
  using match::AnyOf;
  using match::Broadcast;
  using match::ConstantScalar;
  using match::GetTupleElement;
  using match::Maximum;
  using match::MultiplyAnyOrder;
  using match::Op;

  HloInstruction* relu_input;

  // Match max(0, relu_input).
  auto zero_pattern = Broadcast(ConstantScalar(0));
  if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) &&
      !Match(instr, Maximum(Op(&relu_input), zero_pattern))) {
    return absl::nullopt;
  }
  HloInstruction* conv_instr = nullptr;
  HloInstruction* alpha_conv_instr = nullptr;
  HloInstruction* alpha_side_input_instr = nullptr;
  HloInstruction* bias_broadcast_instr = nullptr;
  HloInstruction* bias = nullptr;
  HloInstruction* side_input = nullptr;

  // These nodes will not be in the returned value, but we need to check them
  // for single use.
  HloInstruction *gte = nullptr, *add1 = nullptr, *add2 = nullptr,
                 *mul1 = nullptr, *mul2 = nullptr;

  const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias));
  const auto conv_pattern = [&] {
    auto alpha_pattern = Broadcast(ConstantScalar(&alpha_conv_instr));
    auto conv_pattern = GetTupleElement(
        &gte, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0);
    return AnyOf<HloInstruction>(
        MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern);
  }();
  const auto side_input_pattern = [&] {
    auto alpha_pattern = Broadcast(ConstantScalar(&alpha_side_input_instr));
    // If bias is already matched, match arbitrary additional input as side
    // input. Note this may force a cheap operation (e.g. broadcast) to be
    // materialized into a large buffer, as large as the output buffer.
    //
    // TODO(timshen): If in practice there are significant false positives, we
    // should fix it.
    auto side_input_pattern = Op(&side_input);
    return AnyOf<HloInstruction>(
        MultiplyAnyOrder(&mul2, alpha_pattern, side_input_pattern),
        side_input_pattern);
  }();

  {
    // Try to match any of the following forms of add, in any association:
    //   addends[0]
    //   addends[0] + addends[1]
    //   addends[0] + addends[1] + addends[2]
    //
    // Then try to match each addend with one of the three patterns: bias,
    // conv, or side_input. Note that side_input matching must go last, as it
    // also matches a conv or a bias.
    HloInstruction* addends[3] = {nullptr, nullptr, nullptr};
    auto add3_pattern = [&] {
      auto add2_pattern = Add(&add1, Op(&addends[0]), Op(&addends[1]));
      return AnyOf<HloInstruction>(
          AddAnyOrder(&add2, add2_pattern, Op(&addends[2])), add2_pattern,
          Op(&addends[0]));
    }();
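    // add3_pattern always matches, because its last alternative,
    // Op(&addends[0]), accepts any instruction; the CHECK below only
    // documents that invariant.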
    CHECK(Match(relu_input, add3_pattern));
    for (auto addend : addends) {
      if (addend) {
        if (bias == nullptr && Match(addend, bias_pattern)) {
          CHECK(bias);
        } else if (conv_instr == nullptr && Match(addend, conv_pattern)) {
          CHECK(conv_instr);
        } else if (side_input == nullptr && Match(addend, side_input_pattern)) {
          CHECK(side_input);
        } else {
          return absl::nullopt;
        }
      }
    }
  }

  if (conv_instr == nullptr) {
    return absl::nullopt;
  }

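  // Every intermediate node absorbed by the fusion must have exactly one
  // user; otherwise the original expression would still have to be computed
  // and the convolution work would be duplicated.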
  for (HloInstruction* instr :
       {conv_instr, bias_broadcast_instr, gte, add1, add2, mul1, mul2}) {
    if (instr && instr->user_count() > 1) {
      return absl::nullopt;
    }
  }

  auto conv = Cast<HloCustomCallInstruction>(conv_instr);
  auto bias_broadcast =
      CastOrNull<HloBroadcastInstruction>(bias_broadcast_instr);

  if (conv->custom_call_target() != kCudnnConvForwardCallTarget) {
    return absl::nullopt;
  }

  // To map to cudnnConvolutionBiasActivationForward for int8, the convolution
  // output must be float, i.e. conv<float>(int8_x, int8_w).
  if (conv->operand(0)->shape().element_type() == xla::S8) {
    if (conv->shape().tuple_shapes(0).element_type() != xla::F32) {
      return absl::nullopt;
    }
  }

  if (bias_broadcast) {
    // TODO(timshen): handle bias_broadcast_instr->dimensions() == {}.
    if (bias_broadcast_instr->dimensions().size() != 1) {
      return absl::nullopt;
    }
    if (bias_broadcast_instr->dimensions(0) !=
        conv->convolution_dimension_numbers().output_feature_dimension()) {
      return absl::nullopt;
    }
  }

  return ConvWithRelu{
      instr,
      conv,
      bias,
      side_input,
      CastOrNull<HloConstantInstruction>(alpha_conv_instr),
      CastOrNull<HloConstantInstruction>(alpha_side_input_instr)};
}

StatusOr<std::unique_ptr<HloInstruction>> TryRewriteToCudnnForwardRelu(
    ConvWithRelu match) {
  auto conv = match.conv;

  HloComputation* computation = conv->parent();

  const auto get_alpha_value =
      [](HloConstantInstruction* instr) -> StatusOr<double> {
    TF_ASSIGN_OR_RETURN(
        auto alpha,
        Cast<HloConstantInstruction>(instr)->literal().Convert(F64));
    return alpha.GetFirstElement<double>();
  };

  double alpha_conv = 1;
  if (match.alpha_conv) {
    TF_ASSIGN_OR_RETURN(alpha_conv, get_alpha_value(match.alpha_conv));
  }

  double alpha_side_input;
  if (match.side_input) {
    if (match.alpha_side_input) {
      TF_ASSIGN_OR_RETURN(alpha_side_input,
                          get_alpha_value(match.alpha_side_input));
    } else {
      alpha_side_input = 1;
    }
  } else {
    CHECK(match.alpha_side_input == nullptr);
    alpha_side_input = 0;
  }

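  // The cuDNN fused convolution requires a bias operand, so when the pattern
  // did not capture one, synthesize an all-zero bias with one element per
  // output feature.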
  auto bias = match.bias;
  if (!bias) {
    PrimitiveType conv_output_type =
        conv->shape().tuple_shapes(0).element_type();
    auto zero = computation->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::Zero(conv_output_type)));

    int64 num_output_feature = conv->shape().tuple_shapes(0).dimensions(
        conv->convolution_dimension_numbers().output_feature_dimension());
    bias = computation->AddInstruction(HloInstruction::CreateBroadcast(
        ShapeUtil::MakeShapeWithDescendingLayout(conv_output_type,
                                                 {num_output_feature}),
        zero, {}));
  }

  CHECK(bias);
  std::vector<HloInstruction*> args = {conv->mutable_operand(0),
                                       conv->mutable_operand(1), bias};
  if (match.side_input) {
    args.push_back(match.side_input);
  }
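  // Build the fused custom call with the original conv's tuple shape, copy
  // over its window, dimension numbers, and metadata, and record the
  // activation mode and scales in the backend config.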
  auto new_conv = computation->AddInstruction(HloInstruction::CreateCustomCall(
      conv->shape(), args, kCudnnConvBiasActivationForwardCallTarget));
  new_conv->set_feature_group_count(conv->feature_group_count());
  new_conv->set_window(conv->window());
  new_conv->set_convolution_dimension_numbers(
      conv->convolution_dimension_numbers());
  new_conv->set_metadata(conv->metadata());
  TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config,
                      conv->backend_config<CudnnConvBackendConfig>());
  config.set_activation_mode(
      static_cast<int64>(se::dnn::ActivationMode::kRelu));
  config.set_conv_result_scale(alpha_conv);
  config.set_side_input_scale(alpha_side_input);
  TF_RETURN_IF_ERROR(new_conv->set_backend_config(config));

  VLOG(1) << "Replacing convolution " << conv->ToString() << " with "
          << new_conv->ToString();
  return HloInstruction::CreateGetTupleElement(conv->shape().tuple_shapes(0),
                                               new_conv, 0);
}

// Fuses bias/scaling/ReLU into a convolution custom call that has
// floating-point output.
StatusOr<bool> RunFuseBiasSideActivation(HloModule* module) {
  bool changed = false;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    std::vector<ConvWithRelu> matches;
    int num_forward_convs = 0;
    for (auto instr : computation->instructions()) {
      auto match = FindConvWithRelu(instr);
      if (match.has_value()) {
        matches.push_back(*match);
      }
      if (auto call = DynCast<HloCustomCallInstruction>(instr)) {
        if (call->custom_call_target() == kCudnnConvForwardCallTarget) {
          num_forward_convs++;
        }
      }
    }
    VLOG(1) << "Identified cuDNN forward conv + relu: " << matches.size()
            << " out of " << num_forward_convs << " forward convs.";
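    // Build a replacement for each match first, then apply all of the
    // replacements to the computation in a second loop.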
    std::vector<std::pair<HloInstruction*, std::unique_ptr<HloInstruction>>>
        replacements;
    for (const ConvWithRelu& match : matches) {
      TF_ASSIGN_OR_RETURN(auto new_instr, TryRewriteToCudnnForwardRelu(match));
      replacements.push_back({match.maximum, std::move(new_instr)});
      changed = true;
    }
    for (auto& replacement : replacements) {
      TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
          replacement.first, std::move(replacement.second)));
    }
  }
  return changed;
}

// Describes a matched pattern:
//   convert_or_clamp(get_tuple_element(custom_call(x, w, ...)));
// where the custom_call targets a CuDNN convolution (either pure convolution
// or fused convolution).
struct ConvWithConvertOrClamp {
  HloInstruction* convert_or_clamp;
  HloInstruction* gte;
  HloCustomCallInstruction* conv;
};

// The pattern we want to match:
//   convert<int8>(clamp(broadcast(-128),
//                       get_tuple_element(custom_call(int8_x, int8_w, ...)),
//                       broadcast(127)));
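// The clamp bounds are exactly the int8 value range, so the clamp + convert
// pair amounts to a saturating cast to int8 and can be folded into the
// convolution's output type.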
absl::optional<ConvWithConvertOrClamp> FindConvWithClampAndConvertToInt8(
    HloInstruction* instr) {
  using match::Broadcast;
  using match::Clamp;
  using match::Convert;
  using match::GetTupleElement;
  using match::Op;

  HloInstruction* gte = nullptr;
  HloInstruction* conv_instr = nullptr;
  auto lower_pattern = Broadcast(match::ConstantScalar(-128));
  auto upper_pattern = Broadcast(match::ConstantScalar(127));
  auto pattern = Convert(
      Clamp(lower_pattern,
            GetTupleElement(
                &gte, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0),
            upper_pattern));

  if (Match(instr, pattern)) {
    if (conv_instr->operand(0)->shape().element_type() == xla::S8 &&
        instr->shape().element_type() == xla::S8) {
      HloCustomCallInstruction* conv =
          CastOrNull<HloCustomCallInstruction>(conv_instr);
      return ConvWithConvertOrClamp{instr, gte, conv};
    }
  }
  return absl::nullopt;
}

// A helper function to rewrite convert_or_clamp_or_other<new_type>(gte(conv()))
// to gte<new_type>(conv<new_type>()).  It bypasses convert_or_clamp_or_other
// and sets the output data type on gte and conv.
Status RewriteForConvertOrClampImpl(ConvWithConvertOrClamp match) {
  auto conv = match.conv;
  auto gte = match.gte;
  auto convert_or_clamp = match.convert_or_clamp;

  // Change the output type on conv and gte.
  auto convert_out_type = convert_or_clamp->shape().element_type();
  conv->mutable_shape()->mutable_tuple_shapes(0)->set_element_type(
      convert_out_type);
  gte->mutable_shape()->set_element_type(convert_out_type);

  // Remove the clamp/convert and so on, keeping just
  // get_tuple_element(custom_call(x, w, ...)).
  TF_RETURN_IF_ERROR(convert_or_clamp->ReplaceAllUsesWithDifferentShape(gte));
  TF_RETURN_IF_ERROR(
      conv->parent()->RemoveInstructionAndUnusedOperands(convert_or_clamp));
  return Status::OK();
}

Status RewriteForFinalOutput(ConvWithConvertOrClamp match) {
  // When the matched clamp has a single user, which is convert<int8>, we
  // will absorb it, if
  // 1. the side_input matches a convert<float>(int8_side_input), or
  // 2. there is no side input
  const auto is_one_to_one_X_to_Y_cast = [](const HloInstruction* instr,
                                            PrimitiveType X,
                                            PrimitiveType Y) -> bool {
    return (instr->opcode() == HloOpcode::kConvert &&
            instr->shape().element_type() == Y && instr->operand_count() == 1 &&
            instr->operand(0)->user_count() == 1 &&
            instr->operand(0)->shape().element_type() == X);
  };

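  // Only rewrite when there is no side input, or when the side input is an
  // int8 value that was converted to float just for this convolution; in the
  // latter case that convert is absorbed as well.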
  if (match.conv->operand_count() < 4) {
    // Conv input #3 (zero based) is side_input, after x, w, and bias.
    // Side input doesn't exist in this case.
    TF_RETURN_IF_ERROR(RewriteForConvertOrClampImpl(match));
  } else if (is_one_to_one_X_to_Y_cast(match.conv->operand(3), S8, F32)) {
    // If side_input is itself a convert<float>(int8_side_input), absorb that
    // convert as well.
    auto side_converter = match.conv->mutable_operand(3);
    TF_RETURN_IF_ERROR(side_converter->ReplaceAllUsesWithDifferentShape(
        side_converter->mutable_operand(0)));
    TF_RETURN_IF_ERROR(
        side_converter->parent()->RemoveInstructionAndUnusedOperands(
            side_converter));

    TF_RETURN_IF_ERROR(RewriteForConvertOrClampImpl(match));
  }
  return Status::OK();
}

// Fuses the clamp/convert pattern into the int8 convolution custom call
// (either pure or fused) so that the convolution produces int8 output.
StatusOr<bool> RunFuseClamp(HloModule* module) {
  bool changed = false;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    std::vector<ConvWithConvertOrClamp> matches;
    for (auto instr : computation->instructions()) {
      auto match = FindConvWithClampAndConvertToInt8(instr);
      if (match.has_value()) {
        matches.push_back(*match);
      }
    }
    for (const ConvWithConvertOrClamp& match : matches) {
      TF_RETURN_IF_ERROR(RewriteForFinalOutput(match));
      changed = true;
    }

    // Report an error for any convolution that still has int32 output.
    // Although an int32-output convolution would trigger other sanity-check
    // errors later, we want to give a specific error message here.
    for (auto instr : computation->instructions()) {
      if (auto call = DynCast<HloCustomCallInstruction>(instr)) {
        if ((call->custom_call_target() == kCudnnConvForwardCallTarget ||
             call->custom_call_target() ==
                 kCudnnConvBiasActivationForwardCallTarget) &&
            call->shape().tuple_shapes(0).element_type() == xla::S32) {
          return Unimplemented(
              "Integer convolutions for CuDNN must have float or int8 output.  "
              "Use convert to cast output to float or the following pattern to "
              "int8: "
              "clamp(broadcast(-128), conv(int8_x, int8_w, ...), "
              "broadcast(127)).");
        }
      }
    }
  }
  return changed;
}

// The pattern we want to match:
//   convert<float>(get_tuple_element<int32>(custom_call()));
absl::optional<ConvWithConvertOrClamp> FindConvWithConvertToFloat(
    HloInstruction* instr) {
  using match::Convert;
  using match::GetTupleElement;
  using match::Op;

  HloInstruction* gte = nullptr;
  HloInstruction* conv_instr = nullptr;
  auto pattern =
      Convert(GetTupleElement(
                  &gte,
                  Op(&conv_instr)
                      .WithOpcode(HloOpcode::kCustomCall)
                      .WithCustomCallTarget(kCudnnConvForwardCallTarget),
                  0)
                  .WithShape(match::Shape().WithElementType(xla::S32)))
          .WithShape(match::Shape().WithElementType(xla::F32));
  if (Match(instr, pattern)) {
    HloCustomCallInstruction* conv =
        CastOrNull<HloCustomCallInstruction>(conv_instr);
    return ConvWithConvertOrClamp{instr, gte, conv};
  }
  return absl::nullopt;
}

// Transform
// convert<float>(GetTupleElement<int32>(custom_call<int32>(int8_x, int8_w)))
// to
// GetTupleElement<float>(custom_call<float>(int8_x, int8_w))
StatusOr<bool> RunFuseConvertToFloat(HloModule* module) {
  bool changed = false;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    std::vector<ConvWithConvertOrClamp> matches;
    for (auto instr : computation->instructions()) {
      auto match = FindConvWithConvertToFloat(instr);
      if (match.has_value()) {
        matches.push_back(*match);
      }
    }

    for (const ConvWithConvertOrClamp& match : matches) {
      TF_RETURN_IF_ERROR(RewriteForConvertOrClampImpl(match));
      changed = true;
    }
  }
  return changed;
}
}  // namespace

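// The three sub-rewrites run in this order because they feed each other:
// fusing convert-to-float first gives an int8 convolution a float result
// type, which the bias/side-input/ReLU fusion above requires; the clamp
// fusion runs last so it can absorb the final int8 clamp/convert on either a
// pure or an already-fused convolution.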
StatusOr<bool> CudnnFusedConvRewriter::Run(HloModule* module) {
  TF_ASSIGN_OR_RETURN(bool fused_for_convert_to_float,
                      RunFuseConvertToFloat(module));

  TF_ASSIGN_OR_RETURN(bool fused_for_bias, RunFuseBiasSideActivation(module));

  TF_ASSIGN_OR_RETURN(bool fused_for_clamp, RunFuseClamp(module));

  return fused_for_convert_to_float || fused_for_bias || fused_for_clamp;
}
}  // namespace gpu
}  // namespace xla