1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
17 
18 #include "absl/strings/str_cat.h"
19 #include "absl/strings/string_view.h"
20 #include "tensorflow/compiler/xla/layout_util.h"
21 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
22 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 #include "tensorflow/compiler/xla/status_macros.h"
25 #include "tensorflow/compiler/xla/util.h"
26 
27 namespace xla {
28 namespace gpu {
29 namespace {
30 
31 using se::DeviceMemory;
32 using se::DeviceMemoryBase;
33 using se::Stream;
34 using se::dnn::AlgorithmConfig;
35 using se::dnn::BatchDescriptor;
36 using se::dnn::ConvolutionDescriptor;
37 using se::dnn::DataLayout;
38 using se::dnn::DimIndex;
39 using se::dnn::FilterDescriptor;
40 using se::dnn::FilterLayout;
41 using se::dnn::ProfileResult;
42 
43 // A StreamExecutor ScratchAllocator that wraps a single XLA allocation,
44 // returning it (in its entirety) the first time Allocate() is called.
45 class ScratchBufAllocator : public se::ScratchAllocator {
46  public:
ScratchBufAllocator(se::DeviceMemoryBase scratch)47   explicit ScratchBufAllocator(se::DeviceMemoryBase scratch)
48       : scratch_(scratch) {}
49 
50   ~ScratchBufAllocator() override = default;
51 
GetMemoryLimitInBytes()52   int64 GetMemoryLimitInBytes() override { return scratch_.size(); }
53 
AllocateBytes(int64 byte_size)54   se::port::StatusOr<DeviceMemory<uint8>> AllocateBytes(
55       int64 byte_size) override {
56     if (allocated_) {
57       return se::port::InternalError(
58           "Can't allocate twice from a ScratchBufAllocator.");
59     }
60     if (byte_size > scratch_.size()) {
61       return se::port::InternalError(absl::StrCat(
62           "Can't allocate ", byte_size,
63           " bytes from a ScratchBufAllocator of size ", scratch_.size()));
64     }
65 
66     allocated_ = true;
67     return se::DeviceMemory<uint8>(scratch_);
68   }
69 
70  private:
71   se::DeviceMemoryBase scratch_;
72   bool allocated_ = false;
73 };
74 
75 template <typename ElementType, typename OutputType>
RunGpuConvForward(GpuConvParams params,se::ScratchAllocator * scratch_allocator,se::Stream * stream,RunConvOptions options,DeviceMemory<ElementType> input_buf,DeviceMemory<ElementType> filter_buf,DeviceMemory<OutputType> output_buf,AlgorithmConfig algorithm)76 Status RunGpuConvForward(GpuConvParams params,
77                          se::ScratchAllocator* scratch_allocator,
78                          se::Stream* stream, RunConvOptions options,
79                          DeviceMemory<ElementType> input_buf,
80                          DeviceMemory<ElementType> filter_buf,
81                          DeviceMemory<OutputType> output_buf,
82                          AlgorithmConfig algorithm) {
83   if (params.config.conv_result_scale != 1) {
84     return InternalError(
85         "StreamExecutor doesn't support scaled convolution: %lf.",
86         params.config.conv_result_scale);
87   }
88   return stream->ConvolveWithAlgorithm(
89       params.config.input_descriptor, input_buf,
90       params.config.filter_descriptor, filter_buf, params.config.conv_desc,
91       params.config.output_descriptor, &output_buf, scratch_allocator,
92       algorithm, options.profile_result);
93 }
94 
95 template <typename ElementType, typename BiasType, typename OutputType>
RunGpuConvForwardActivation(GpuConvParams params,se::ScratchAllocator * scratch_allocator,se::Stream * stream,RunConvOptions options,DeviceMemory<ElementType> input_buf,DeviceMemory<ElementType> filter_buf,DeviceMemory<OutputType> output_buf,AlgorithmConfig algorithm)96 Status RunGpuConvForwardActivation(GpuConvParams params,
97                                    se::ScratchAllocator* scratch_allocator,
98                                    se::Stream* stream, RunConvOptions options,
99                                    DeviceMemory<ElementType> input_buf,
100                                    DeviceMemory<ElementType> filter_buf,
101                                    DeviceMemory<OutputType> output_buf,
102                                    AlgorithmConfig algorithm) {
103   BatchDescriptor bias_desc;
104   bias_desc.set_count(1)
105       .set_height(1)
106       .set_width(1)
107       .set_feature_map_count(
108           params.config.output_descriptor.feature_map_count())
109       .set_layout(params.config.output_descriptor.layout());
110 
111   se::DeviceMemory<OutputType> side_input(params.fusion->side_input_buf);
112   // If there is no side input, use output as the side input.
113   if (side_input.is_null()) {
114     if (params.config.fusion->side_input_scale != 0) {
115       return InternalError(
116           "Side input scale is not 0, yet no side input buffer is "
117           "provided");
118     }
119     // Since side-input scale is 0, the values in the side input don't
120     // matter.  The simplest thing to do would be to pass in a null buffer
121     // for the side input, but cudnn doesn't allow this.  cudnn does promise
122     // that if side-input-scale is 0 the side input won't be read, so we
123     // just pass in the output buffer, since it's handy and has the correct
124     // size.
125     side_input = output_buf;
126   }
127 
128   return stream->FusedConvolveWithAlgorithm(
129       params.config.input_descriptor, input_buf,
130       params.config.conv_result_scale, params.config.filter_descriptor,
131       filter_buf, params.config.conv_desc, side_input,
132       params.config.fusion->side_input_scale, bias_desc,
133       DeviceMemory<BiasType>(params.fusion->bias_buf),
134       params.config.fusion->mode, params.config.output_descriptor, &output_buf,
135       scratch_allocator, algorithm, options.profile_result);
136 }
137 
138 // StreamExecutor supports various data types via overloading, and the support
139 // is maintained on-demand. To avoid calling into non-exist overloads, we have
140 // to carefully not call into them by using enable_if.
141 // TODO(timshen): Ideally, to avoid such complication in the runner, we can turn
142 // StreamExecutor overloadings to template functions, and for unsupported data
143 // types return runtime errors.
144 // This is the specialization for double, float, and half types.  All kinds of
145 // convolutions are supported here.
146 template <typename ElementType, typename BiasType, typename OutputType,
147           typename std::enable_if<
148               !std::is_integral<ElementType>::value>::type* = nullptr>
RunGpuConvInternalImpl(GpuConvParams params,se::ScratchAllocator * scratch_allocator,se::Stream * stream,RunConvOptions options,DeviceMemory<ElementType> input_buf,DeviceMemory<ElementType> filter_buf,DeviceMemory<OutputType> output_buf,AlgorithmConfig algorithm)149 Status RunGpuConvInternalImpl(GpuConvParams params,
150                               se::ScratchAllocator* scratch_allocator,
151                               se::Stream* stream, RunConvOptions options,
152                               DeviceMemory<ElementType> input_buf,
153                               DeviceMemory<ElementType> filter_buf,
154                               DeviceMemory<OutputType> output_buf,
155                               AlgorithmConfig algorithm) {
156   switch (params.config.kind) {
157     case CudnnConvKind::kForward:
158       return RunGpuConvForward(params, scratch_allocator, stream, options,
159                                input_buf, filter_buf, output_buf, algorithm);
160     case CudnnConvKind::kBackwardInput:
161       if (params.config.conv_result_scale != 1) {
162         return InternalError(
163             "StreamExecutor doesn't support scaled convolution: %lf.",
164             params.config.conv_result_scale);
165       }
166       return stream->ConvolveBackwardDataWithAlgorithm(
167           params.config.filter_descriptor, filter_buf,
168           params.config.output_descriptor, output_buf, params.config.conv_desc,
169           params.config.input_descriptor, &input_buf, scratch_allocator,
170           algorithm, options.profile_result);
171       break;
172     case CudnnConvKind::kBackwardFilter:
173       if (params.config.conv_result_scale != 1) {
174         return InternalError(
175             "StreamExecutor doesn't support scaled convolution: %lf.",
176             params.config.conv_result_scale);
177       }
178       return stream->ConvolveBackwardFilterWithAlgorithm(
179           params.config.input_descriptor, input_buf,
180           params.config.output_descriptor, output_buf, params.config.conv_desc,
181           params.config.filter_descriptor, &filter_buf, scratch_allocator,
182           algorithm, options.profile_result);
183       break;
184     case CudnnConvKind::kForwardActivation: {
185       return RunGpuConvForwardActivation<ElementType, BiasType, OutputType>(
186           params, scratch_allocator, stream, options, input_buf, filter_buf,
187           output_buf, algorithm);
188     }
189   }
190   return Status::OK();
191 }
192 
193 // Specialization for integer types.  Only two forward convolutions are allowed.
194 template <typename ElementType, typename BiasType, typename OutputType,
195           typename std::enable_if<std::is_integral<ElementType>::value>::type* =
196               nullptr>
RunGpuConvInternalImpl(GpuConvParams params,se::ScratchAllocator * scratch_allocator,se::Stream * stream,RunConvOptions options,DeviceMemory<ElementType> input_buf,DeviceMemory<ElementType> filter_buf,DeviceMemory<OutputType> output_buf,AlgorithmConfig algorithm)197 Status RunGpuConvInternalImpl(GpuConvParams params,
198                               se::ScratchAllocator* scratch_allocator,
199                               se::Stream* stream, RunConvOptions options,
200                               DeviceMemory<ElementType> input_buf,
201                               DeviceMemory<ElementType> filter_buf,
202                               DeviceMemory<OutputType> output_buf,
203                               AlgorithmConfig algorithm) {
204   switch (params.config.kind) {
205     case CudnnConvKind::kForward:
206       return RunGpuConvForward(params, scratch_allocator, stream, options,
207                                input_buf, filter_buf, output_buf, algorithm);
208     case CudnnConvKind::kForwardActivation:
209       return RunGpuConvForwardActivation<ElementType, BiasType, OutputType>(
210           params, scratch_allocator, stream, options, input_buf, filter_buf,
211           output_buf, algorithm);
212     default:
213       return InternalError(
214           "Only convolution kinds kForward and kForwardActivation are "
215           "supported for integer types");
216   }
217   return Status::OK();
218 }
219 
220 template <typename ElementType, typename BiasType, typename OutputType>
RunGpuConvImpl(const GpuConvParams & params,se::ScratchAllocator * scratch_allocator,se::Stream * stream,RunConvOptions options)221 Status RunGpuConvImpl(const GpuConvParams& params,
222                       se::ScratchAllocator* scratch_allocator,
223                       se::Stream* stream, RunConvOptions options) {
224   auto input_buf = se::DeviceMemory<ElementType>(params.input_buf);
225   auto filter_buf = se::DeviceMemory<ElementType>(params.filter_buf);
226   auto output_buf = se::DeviceMemory<OutputType>(params.output_buf);
227   AlgorithmConfig algorithm = params.config.algorithm;
228 
229   if (options.algo_override.has_value()) {
230     algorithm = AlgorithmConfig(*options.algo_override);
231     if (options.scratch_size_override.has_value()) {
232       algorithm.set_scratch_size(*options.scratch_size_override);
233     }
234   }
235 
236   Status run_status = RunGpuConvInternalImpl<ElementType, BiasType, OutputType>(
237       params, scratch_allocator, stream, options, input_buf, filter_buf,
238       output_buf, algorithm);
239 
240   if (run_status != Status::OK()) {
241     return run_status;
242   }
243 
244   if (!stream->ok()) {
245     return InternalError(
246         "Unable to launch convolution with type %s and algorithm (%d, %s)",
247         CudnnConvKindToString(params.config.kind),
248         algorithm.algorithm()->algo_id(),
249         algorithm.algorithm_no_scratch().has_value()
250             ? absl::StrCat(algorithm.algorithm_no_scratch()->algo_id())
251             : "none");
252   }
253   return Status::OK();
254 }
255 
256 }  // anonymous namespace
257 
GetGpuConvConfig(const GpuConvDescriptor & desc,const absl::string_view inst_as_string)258 StatusOr<GpuConvConfig> GetGpuConvConfig(
259     const GpuConvDescriptor& desc, const absl::string_view inst_as_string) {
260   GpuConvConfig config;
261 
262   const Shape& operand0_shape = desc.operand0_shape;
263   const Shape& operand1_shape = desc.operand1_shape;
264   const Shape& result_shape = desc.result_shape;
265   const CudnnConvBackendConfig& backend_config = desc.backend_config;
266 
267   config.input_type = operand0_shape.element_type();
268   config.output_type = result_shape.element_type();
269   config.kind = desc.kind;
270 
271   // The third field is scratch size stored from conv_algorithm_picker
272   // The operand is added to the shape field of the conv instruction
273   // in GpuConvAlgorithmPicker::RunOnInstruction() call.
274   config.algorithm = se::dnn::AlgorithmConfig(
275       se::dnn::AlgorithmDesc(backend_config.algorithm(),
276                              backend_config.tensor_ops_enabled()),
277       desc.scratch_size);
278   config.conv_result_scale = backend_config.conv_result_scale();
279 
280   switch (config.kind) {
281     case CudnnConvKind::kForward:
282     case CudnnConvKind::kForwardActivation:
283       config.input_shape = operand0_shape;
284       config.filter_shape = operand1_shape;
285       config.output_shape = result_shape;
286       break;
287     case CudnnConvKind::kBackwardInput:
288       config.input_shape = result_shape;
289       config.filter_shape = operand1_shape;
290       config.output_shape = operand0_shape;
291       break;
292     case CudnnConvKind::kBackwardFilter:
293       config.input_shape = operand0_shape;
294       config.filter_shape = result_shape;
295       config.output_shape = operand1_shape;
296       break;
297     default:
298       return InternalError("Unknown convolution kind");
299   }
300 
301   if (config.kind == CudnnConvKind::kForwardActivation) {
302     config.fusion.emplace();
303     GpuConvConfig::FusionConfig& fusion = *config.fusion;
304     if (!se::dnn::ActivationMode_IsValid(backend_config.activation_mode())) {
305       return InternalError("Bad activation mode: %s",
306                            backend_config.ShortDebugString());
307     }
308     fusion.mode =
309         static_cast<se::dnn::ActivationMode>(backend_config.activation_mode());
310     fusion.side_input_scale = backend_config.side_input_scale();
311   }
312 
313   const Window& window = desc.window;
314   const ConvolutionDimensionNumbers& dnums = desc.dnums;
315 
316   VLOG(3) << "Convolution Algorithm: "
317           << config.algorithm.algorithm()->algo_id();
318   VLOG(3) << "tensor_ops_enabled: "
319           << config.algorithm.algorithm()->tensor_ops_enabled();
320   VLOG(3) << "Convolution kind: " << CudnnConvKindToString(config.kind);
321   VLOG(3) << "input shape: "
322           << ShapeUtil::HumanStringWithLayout(config.input_shape);
323   VLOG(3) << "filter shape: "
324           << ShapeUtil::HumanStringWithLayout(config.filter_shape);
325   VLOG(3) << "Output shape: "
326           << ShapeUtil::HumanStringWithLayout(config.output_shape);
327   VLOG(3) << "Window: { " << window.ShortDebugString() << " }";
328   VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }";
329 
330   const int num_dimensions = window.dimensions_size();
331   CHECK_LE(num_dimensions, 3) << inst_as_string;
332 
333   // cuDNN does not support 1D convolutions. We therefore express 1D
334   // convolutions as 2D convolutions where the first spatial dimension is 1.
335   // This matches the behavior of TF (see definition of conv1d in
336   // tensorflow/python/ops/nn_ops.py).
337   const int effective_num_dimensions = std::max(2, num_dimensions);
338 
339   // If one dimension is reversed, we need to have all dimensions reversed (so
340   // we're doing convolution not cross correlation).
341   const bool dims_reversed =
342       window.dimensions_size() > 0 && window.dimensions()[0].window_reversal();
343 
344   CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size())
345       << inst_as_string;
346   CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size())
347       << inst_as_string;
348   CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size())
349       << inst_as_string;
350   for (const WindowDimension& dim : window.dimensions()) {
351     CHECK_EQ(dims_reversed, dim.window_reversal()) << inst_as_string;
352     CHECK_EQ(dim.padding_low(), dim.padding_high()) << inst_as_string;
353     CHECK_EQ(dim.base_dilation(), 1)
354         << "cudnn does not support base dilation; it "
355            "must be made explicit with a kPad: "
356         << inst_as_string;
357   }
358 
359   // cuDNN's convolution APIs support the BDYX layout for activations/output and
360   // the OIYX layout for weights.
361   DataLayout input_dl;
362   FilterLayout filter_dl;
363   DataLayout output_dl;
364 
365   const Shape& input_shape = config.input_shape;
366   const Shape& filter_shape = config.filter_shape;
367   const Shape& output_shape = config.output_shape;
368 
369   TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl),
370                       XlaConvLayoutsToStreamExecutorLayouts(
371                           dnums, input_shape.layout(), filter_shape.layout(),
372                           output_shape.layout()));
373 
374   BatchDescriptor& input_descriptor = config.input_descriptor;
375   input_descriptor = BatchDescriptor(effective_num_dimensions);
376   input_descriptor.set_layout(input_dl)
377       .set_feature_map_count(
378           input_shape.dimensions(dnums.input_feature_dimension()))
379       .set_count(input_shape.dimensions(dnums.input_batch_dimension()));
380   for (int dim = 0; dim < num_dimensions; ++dim) {
381     // Note that the dimensions are reversed. The same holds below.
382     input_descriptor.set_spatial_dim(
383         static_cast<DimIndex>(effective_num_dimensions - dim - 1),
384         input_shape.dimensions(dnums.input_spatial_dimensions(dim)));
385   }
386 
387   FilterDescriptor& filter_descriptor = config.filter_descriptor;
388   filter_descriptor = FilterDescriptor(effective_num_dimensions);
389   filter_descriptor.set_layout(filter_dl)
390       .set_input_feature_map_count(
391           filter_shape.dimensions(dnums.kernel_input_feature_dimension()))
392       .set_output_feature_map_count(
393           filter_shape.dimensions(dnums.kernel_output_feature_dimension()));
394   for (int dim = 0; dim < num_dimensions; ++dim) {
395     filter_descriptor.set_spatial_dim(
396         static_cast<DimIndex>(effective_num_dimensions - dim - 1),
397         filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim)));
398   }
399 
400   config.conv_desc = ConvolutionDescriptor(effective_num_dimensions);
401   config.conv_desc.set_group_count(desc.feature_group_count);
402   config.conv_desc.set_convolution_not_crosscorr(dims_reversed);
403   for (int dim = 0; dim < num_dimensions; ++dim) {
404     config.conv_desc
405         .set_zero_padding(
406             static_cast<DimIndex>(effective_num_dimensions - dim - 1),
407             window.dimensions(dim).padding_low())
408         .set_filter_stride(
409             static_cast<DimIndex>(effective_num_dimensions - dim - 1),
410             window.dimensions(dim).stride())
411         .set_dilation_rate(
412             static_cast<DimIndex>(effective_num_dimensions - dim - 1),
413             window.dimensions(dim).window_dilation());
414   }
415 
416   BatchDescriptor& output_descriptor = config.output_descriptor;
417   output_descriptor = BatchDescriptor(effective_num_dimensions);
418   output_descriptor.set_layout(output_dl)
419       .set_feature_map_count(
420           output_shape.dimensions(dnums.output_feature_dimension()))
421       .set_count(output_shape.dimensions(dnums.output_batch_dimension()));
422   for (int dim = 0; dim < num_dimensions; ++dim) {
423     output_descriptor.set_spatial_dim(
424         static_cast<DimIndex>(effective_num_dimensions - dim - 1),
425         output_shape.dimensions(dnums.output_spatial_dimensions(dim)));
426   }
427 
428   // Add a singleton dimension in the 1D convolution case.
429   for (int dim = 0; dim < effective_num_dimensions - num_dimensions; dim++) {
430     input_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
431     output_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
432     filter_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
433     config.conv_desc.set_zero_padding(static_cast<DimIndex>(dim), 0)
434         .set_filter_stride(static_cast<DimIndex>(dim), 1);
435   }
436 
437   return config;
438 }
439 
GetGpuConvConfig(const HloCustomCallInstruction * cudnn_call)440 StatusOr<GpuConvConfig> GetGpuConvConfig(
441     const HloCustomCallInstruction* cudnn_call) {
442   GpuConvDescriptor descriptor;
443 
444   TF_ASSIGN_OR_RETURN(descriptor.kind, GetCudnnConvKind(cudnn_call));
445   TF_ASSIGN_OR_RETURN(descriptor.backend_config,
446                       cudnn_call->backend_config<CudnnConvBackendConfig>());
447   descriptor.operand0_shape = cudnn_call->operand(0)->shape();
448   descriptor.operand1_shape = cudnn_call->operand(1)->shape();
449   descriptor.result_shape = cudnn_call->shape().tuple_shapes(0);
450   descriptor.scratch_size = cudnn_call->shape().tuple_shapes(1).dimensions(0);
451   descriptor.window = cudnn_call->window();
452   descriptor.dnums = cudnn_call->convolution_dimension_numbers();
453   descriptor.feature_group_count = cudnn_call->feature_group_count();
454   return GetGpuConvConfig(descriptor, cudnn_call->ToString());
455 }
456 
GetGpuConvParams(const GpuConvConfig & config,absl::Span<se::DeviceMemoryBase> operand_buffers,se::DeviceMemoryBase result_buffer)457 StatusOr<GpuConvParams> GetGpuConvParams(
458     const GpuConvConfig& config,
459     absl::Span<se::DeviceMemoryBase> operand_buffers,
460     se::DeviceMemoryBase result_buffer) {
461   GpuConvParams params;
462   params.config = config;
463 
464   switch (config.kind) {
465     case CudnnConvKind::kForward:
466     case CudnnConvKind::kForwardActivation:
467       params.input_buf = operand_buffers[0];
468       params.filter_buf = operand_buffers[1];
469       params.output_buf = result_buffer;
470       break;
471     case CudnnConvKind::kBackwardInput:
472       params.input_buf = result_buffer;
473       params.filter_buf = operand_buffers[1];
474       params.output_buf = operand_buffers[0];
475       break;
476     case CudnnConvKind::kBackwardFilter:
477       params.input_buf = operand_buffers[0];
478       params.filter_buf = result_buffer;
479       params.output_buf = operand_buffers[1];
480       break;
481   }
482 
483   if (config.kind == CudnnConvKind::kForwardActivation) {
484     params.fusion.emplace();
485     GpuConvParams::FusionParams& fusion = *params.fusion;
486     fusion.bias_buf = operand_buffers[2];
487     if (operand_buffers.size() >= 4) {
488       fusion.side_input_buf = operand_buffers[3];
489     }
490   }
491 
492   return params;
493 }
494 
RunGpuConv(const gpu::GpuConvConfig & config,absl::Span<se::DeviceMemoryBase> operand_buffers,se::DeviceMemoryBase result_buffer,se::DeviceMemoryBase scratch_buf,se::Stream * stream,RunConvOptions options)495 Status RunGpuConv(const gpu::GpuConvConfig& config,
496                   absl::Span<se::DeviceMemoryBase> operand_buffers,
497                   se::DeviceMemoryBase result_buffer,
498                   se::DeviceMemoryBase scratch_buf, se::Stream* stream,
499                   RunConvOptions options) {
500   ScratchBufAllocator scratch_allocator(scratch_buf);
501   return RunGpuConv(config, operand_buffers, result_buffer, &scratch_allocator,
502                     stream, options);
503 }
504 
RunGpuConv(const gpu::GpuConvConfig & config,absl::Span<se::DeviceMemoryBase> operand_buffers,se::DeviceMemoryBase result_buffer,se::ScratchAllocator * scratch_allocator,se::Stream * stream,RunConvOptions options)505 Status RunGpuConv(const gpu::GpuConvConfig& config,
506                   absl::Span<se::DeviceMemoryBase> operand_buffers,
507                   se::DeviceMemoryBase result_buffer,
508                   se::ScratchAllocator* scratch_allocator, se::Stream* stream,
509                   RunConvOptions options) {
510   TF_ASSIGN_OR_RETURN(GpuConvParams params,
511                       GetGpuConvParams(config, operand_buffers, result_buffer));
512 
513   PrimitiveType input_primitive_type = config.input_type;
514   switch (input_primitive_type) {
515     case F16:
516       return RunGpuConvImpl<Eigen::half, Eigen::half, Eigen::half>(
517           params, scratch_allocator, stream, options);
518     case F32:
519       return RunGpuConvImpl<float, float, float>(params, scratch_allocator,
520                                                  stream, options);
521     case F64:
522       return RunGpuConvImpl<double, double, double>(params, scratch_allocator,
523                                                     stream, options);
524     case S8: {
525       PrimitiveType output_primitive_type = config.output_type;
526       switch (output_primitive_type) {
527         case F32:
528           return RunGpuConvImpl<int8, float, float>(params, scratch_allocator,
529                                                     stream, options);
530         case S8:
531           return RunGpuConvImpl<int8, float, int8>(params, scratch_allocator,
532                                                    stream, options);
533         default:
534           return Unimplemented("Unimplemented convolution");
535       }
536     }
537     default:
538       return Unimplemented("Unimplemented convolution");
539   }
540 }
541 
542 }  // namespace gpu
543 }  // namespace xla
544