1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/service/convolution_group_converter.h"

#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
36 
37 namespace xla {
38 
39 namespace {
40 
41 // ConvolutionVisitor traverses the HLO computation and rewrites Convolution
42 // operations with feature_group_count > 1 into convolutions with
43 // feature_group_count = 1.
44 class ConvolutionVisitor : public DfsHloVisitorWithDefault {
45  public:
46   // Default visitor action is to do nothing and return OK.
DefaultAction(HloInstruction *)47   Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
48     return Status::OK();
49   }
50 
51   Status HandleConvolution(HloInstruction* convolution) override;
52 
53   Status HandleBatchGroupCount(HloInstruction* convolution);
54 
55   // Runs the visitor on a computation.
56   static bool Run(HloComputation* computation,
57                   std::function<bool(HloInstruction*)> is_cost_viable,
58                   bool convert_batch_groups_only,
59                   bool canonicalize_depthwise_filter);
60 
61   // Returns whether any convolution ops were rewritten.
changed() const62   const bool changed() const { return changed_; }
63 
64   ~ConvolutionVisitor() override = default;
65 
66  private:
ConvolutionVisitor(HloComputation * computation,std::function<bool (HloInstruction *)> is_cost_viable,bool convert_batch_groups_only,bool canonicalize_depthwise_filter=false)67   explicit ConvolutionVisitor(
68       HloComputation* computation,
69       std::function<bool(HloInstruction*)> is_cost_viable,
70       bool convert_batch_groups_only,
71       bool canonicalize_depthwise_filter = false)
72       : computation_(computation),
73         filter_expansion_(!canonicalize_depthwise_filter),
74         convert_batch_groups_only_(convert_batch_groups_only),
75         is_cost_viable_(is_cost_viable) {}
76 
77   // Current HloComputation instance the ConvolutionVisitor is traversing.
78   HloComputation* computation_;
79 
80   // Whether rewrite has occurred.
81   bool changed_ = false;
82 
83   // Whether filter expansion is required.
84   bool filter_expansion_;
85 
86   // Decides whether to convert batch groups or feature groups.
87   bool convert_batch_groups_only_;
88 
89   // std::function<std::vector<LloValue*>(int64, int64)> chunk_fetcher
90   std::function<bool(HloInstruction*)> is_cost_viable_;
91 };
92 
Run(HloComputation * computation,std::function<bool (HloInstruction *)> is_cost_viable,bool convert_batch_groups_only,bool canonicalize_depthwise_filter)93 bool ConvolutionVisitor::Run(
94     HloComputation* computation,
95     std::function<bool(HloInstruction*)> is_cost_viable,
96     bool convert_batch_groups_only, bool canonicalize_depthwise_filter) {
97   ConvolutionVisitor visitor(computation, is_cost_viable,
98                              convert_batch_groups_only,
99                              canonicalize_depthwise_filter);
100   TF_CHECK_OK(computation->Accept(&visitor));
101   return visitor.changed_;
102 }
103 
ExpandedFilterShape(const Shape & shape,int64 group_count,int64 input_feature_dim)104 Shape ExpandedFilterShape(const Shape& shape, int64 group_count,
105                           int64 input_feature_dim) {
106   int64 num_dims = shape.dimensions_size();
107   CHECK_GE(num_dims, 2);
108   Shape expanded_shape = shape;
109   expanded_shape.set_dimensions(
110       input_feature_dim, shape.dimensions(input_feature_dim) * group_count);
111   return expanded_shape;
112 }
113 
114 // Returns a vector with 'group_count' many groups, where the i-th group
115 // consists of 'group_size' times the value i.
GetMaskIds(int64 group_size,int64 group_count)116 std::vector<int32> GetMaskIds(int64 group_size, int64 group_count) {
117   std::vector<int32> values;
118   for (int i = 0; i < group_count; ++i) {
119     for (int j = 0; j < group_size; ++j) {
120       values.push_back(i);
121     }
122   }
123   return values;
124 }
125 
126 // Create a mask for grouped convolution that will make a normal convolution
127 // produce the same results as a grouped convolution. For a [2, 1, 6]
128 // filter this returns a [2, 3, 6] mask
129 //   1 1 0 0 0 0
130 //   0 0 1 1 0 0
131 //   0 0 0 0 1 1
132 //
133 //   1 1 0 0 0 0
134 //   0 0 1 1 0 0
135 //   0 0 0 0 1 1
136 //
137 // The first step is to create a rank 1 constant:
138 //   0 1 2
139 //
140 // This is broadcasted to
141 //   0 0 0 0 0 0
142 //   1 1 1 1 1 1
143 //   2 2 2 2 2 2
144 //
145 //   0 0 0 0 0 0
146 //   1 1 1 1 1 1
147 //   2 2 2 2 2 2
148 //
149 // Then we create another rank 1 constant
150 //   0 0 1 1 2 2
151 //
152 // This is broadcasted to
153 //   0 0 1 1 2 2
154 //   0 0 1 1 2 2
155 //   0 0 1 1 2 2
156 //
157 //   0 0 1 1 2 2
158 //   0 0 1 1 2 2
159 //   0 0 1 1 2 2
160 //
161 // Finally we use the Eq op of these two broadcasted constants and get the
162 // desired mask.
GetExpandedFilterMask(const Shape & filter_shape,int64 kernel_input_feature_dim,int64 kernel_output_feature_dim,int64 group_count,const std::function<HloInstruction * (std::unique_ptr<HloInstruction>)> & add_instruction)163 HloInstruction* GetExpandedFilterMask(
164     const Shape& filter_shape, int64 kernel_input_feature_dim,
165     int64 kernel_output_feature_dim, int64 group_count,
166     const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
167         add_instruction) {
168   Shape expanded_filter_shape =
169       ExpandedFilterShape(filter_shape, group_count, kernel_input_feature_dim);
170   Shape mask_shape = ShapeUtil::MakeShape(
171       S32, AsInt64Slice(expanded_filter_shape.dimensions()));
172   int64 output_feature = filter_shape.dimensions(kernel_output_feature_dim);
173   int64 group_size = filter_shape.dimensions(kernel_input_feature_dim);
174 
175   // Create a 'input_feature' sized linspace and 'output_feature' sized linspace
176   // that will be broadcasted into perpendicular dimensions and compared.
177   const std::vector<int32> input_feature_filter_mask =
178       GetMaskIds(group_size, group_count);
179   const std::vector<int32> output_feature_filter_mask =
180       GetMaskIds(output_feature / group_count, group_count);
181   auto mask1 = add_instruction(HloInstruction::CreateConstant(
182       LiteralUtil::CreateR1<int32>(input_feature_filter_mask)));
183   auto broadcasted_mask1 = add_instruction(HloInstruction::CreateBroadcast(
184       mask_shape, mask1, {kernel_input_feature_dim}));
185   auto mask2 = add_instruction(HloInstruction::CreateConstant(
186       LiteralUtil::CreateR1<int32>(output_feature_filter_mask)));
187   auto broadcasted_mask2 = add_instruction(HloInstruction::CreateBroadcast(
188       mask_shape, mask2, {kernel_output_feature_dim}));
189 
190   // Compare the broadcasted output feature linspace to the input feature
191   // linspace to create a diagonal predicate.
192   Shape predicate_shape = ShapeUtil::MakeShape(
193       PRED, AsInt64Slice(expanded_filter_shape.dimensions()));
194   return add_instruction(HloInstruction::CreateCompare(
195       predicate_shape, broadcasted_mask1, broadcasted_mask2,
196       ComparisonDirection::kEq));
197 }
198 
// This function handles batch_group_counts which are relevant only for
// depthwise backprop filter convolutions. It rewrites such a convolution into
// an ungrouped convolution followed by a masked select and a reduce-window
// over the output batch dimension.
Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) {
  auto dim_numbers = convolution->convolution_dimension_numbers();
  auto activation = convolution->mutable_operand(0);
  auto filter = convolution->mutable_operand(1);
  int64 batch_group_count = convolution->batch_group_count();

  // Ungrouped convolutions need no rewrite.
  if (batch_group_count == 1) {
    return Status::OK();
  }

  VLOG(2) << "Dealing with batch_group_count " << batch_group_count
          << " for convolution " << convolution->ToString() << "\n";

  // Shorthand for adding an instruction to the current computation.
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    return computation_->AddInstruction(std::move(inst));
  };

  int64 input_batch_dimension = dim_numbers.input_batch_dimension();
  int64 output_batch_dimension = dim_numbers.output_batch_dimension();
  int64 output_feature_dimension = dim_numbers.output_feature_dimension();

  int64 input_batch = activation->shape().dimensions(input_batch_dimension);

  // We are not yet supporting batch_group of sizes greater than 1.
  TF_RET_CHECK(input_batch == batch_group_count);

  // Only rewrite when the backend deems the grouped form not cost-viable, or
  // filter expansion was explicitly requested; otherwise leave it untouched.
  if (!is_cost_viable_(convolution) || filter_expansion_) {
    // We first obtain the expanded filter (which is the convolution
    // output). The batch dimension is the expanded one (which originally
    // represents kernel input feature dimension). We mask the filter to zero
    // out the expanded regions. Next we reduce the filter in the batch
    // dimension to obtain the original filter size.

    HloInstruction* filter_mask =
        GetExpandedFilterMask(convolution->shape(), output_batch_dimension,
                              output_feature_dimension, batch_group_count, add);
    auto expanded_filter_shape = ExpandedFilterShape(
        convolution->shape(), batch_group_count, output_batch_dimension);

    // Ungrouped convolution whose output is the expanded filter.
    auto new_convolution = add(HloInstruction::CreateConvolve(
        expanded_filter_shape, activation, filter,
        /*feature_group_count=*/1, /*batch_group_count=*/1,
        convolution->window(), dim_numbers, convolution->precision_config()));

    // Zero out everything outside the mask's block diagonal.
    auto zero = add(HloInstruction::CreateConstant(
        LiteralUtil::Zero(expanded_filter_shape.element_type())));
    auto zero_filter =
        add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));

    auto new_filter = add(HloInstruction::CreateTernary(
        expanded_filter_shape, HloOpcode::kSelect, filter_mask, new_convolution,
        zero_filter));

    PrimitiveType reduce_type = new_filter->shape().element_type();
    auto reduce_window_shape = new_convolution->shape();
    reduce_window_shape.set_dimensions(output_batch_dimension, 1);

    // Ensure that data input to reduce window uses at least 32 bits.
    if (primitive_util::BitWidth(reduce_type) < primitive_util::BitWidth(F32)) {
      reduce_type = F32;
      reduce_window_shape.set_element_type(F32);
      Shape convert_shape = new_filter->shape();
      convert_shape.set_element_type(F32);
      new_filter =
          add(HloInstruction::CreateConvert(convert_shape, new_filter));
    }

    auto zero_literal = LiteralUtil::Zero(reduce_type);
    auto zero_scalar =
        add(HloInstruction::CreateConstant(std::move(zero_literal)));

    // Builds the scalar `add` computation used by the reduce-window below.
    auto reduce_function = [&]() -> HloComputation* {
      HloComputation::Builder b("add_computation");
      Shape shape = ShapeUtil::MakeShape(reduce_type, {});
      auto lhs =
          b.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs"));
      auto rhs =
          b.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs"));
      auto scalar_op = b.AddInstruction(
          HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs, rhs));
      return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
    };

    // Create the reduce window. Only the output batch dimension is reduced
    // (window size == stride == batch_group_count); all other dimensions use
    // a trivial size-1 window.
    Window window;
    for (int64 i = 0; i < new_convolution->shape().dimensions_size(); ++i) {
      auto* dim = window.add_dimensions();
      dim->set_padding_low(0);
      dim->set_padding_high(0);
      dim->set_window_dilation(1);
      dim->set_base_dilation(1);
      if (i == output_batch_dimension) {
        dim->set_stride(batch_group_count);
        dim->set_size(batch_group_count);
      } else {
        dim->set_stride(1);
        dim->set_size(1);
      }
    }
    auto reduce_window = add(HloInstruction::CreateReduceWindow(
        reduce_window_shape, new_filter, zero_scalar, window,
        reduce_function()));

    Shape convert_back_shape = reduce_window->shape();
    convert_back_shape.set_element_type(activation->shape().element_type());

    // Convert reduced data back to the original data type.
    auto reduce_window_converted =
        HloInstruction::CreateConvert(convert_back_shape, reduce_window);

    TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
        convolution, std::move(reduce_window_converted)));
    changed_ = true;
  }

  return Status::OK();
}
318 
HandleConvolution(HloInstruction * convolution)319 Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
320   if (convert_batch_groups_only_) {
321     return HandleBatchGroupCount(convolution);
322   }
323 
324   auto add = [&](std::unique_ptr<HloInstruction> inst) {
325     return computation_->AddInstruction(std::move(inst));
326   };
327 
328   int64 group_count = convolution->feature_group_count();
329   if (group_count == 1) {
330     return Status::OK();
331   }
332 
333   changed_ = true;
334   auto dim_numbers = convolution->convolution_dimension_numbers();
335   auto filter = convolution->mutable_operand(1);
336   int64 kernel_input_feature_dim = dim_numbers.kernel_input_feature_dimension();
337   int64 group_size = filter->shape().dimensions(kernel_input_feature_dim);
338   int64 kernel_output_feature_dim =
339       dim_numbers.kernel_output_feature_dimension();
340   auto expanded_filter_shape = ExpandedFilterShape(filter->shape(), group_count,
341                                                    kernel_input_feature_dim);
342   HloInstruction* filter_mask =
343       GetExpandedFilterMask(filter->shape(), kernel_input_feature_dim,
344                             kernel_output_feature_dim, group_count, add);
345   HloInstruction* expanded_filter;
346 
347   if (group_size == 1) {
348     bool depthwise_separable =
349         (group_count == filter->shape().dimensions(kernel_output_feature_dim));
350     // If the code generator handles depthwise separable convolutions
351     // inherently, then no filter expansion is needed.
352     if (!filter_expansion_ && depthwise_separable) {
353       changed_ = false;
354       return Status::OK();
355     }
356     // We want to repeat 'filter' in the 'input_feature_dim' dimension
357     // 'group_count' times.
358     Shape reshaped_filter_shape =
359         ShapeUtil::DeleteDimension(kernel_input_feature_dim, filter->shape());
360     auto reshaped_filter =
361         add(HloInstruction::CreateReshape(reshaped_filter_shape, filter));
362     std::vector<int64> broadcast_dims;
363     for (int64 i = 0; i < filter->shape().dimensions_size(); ++i) {
364       if (i == kernel_input_feature_dim) {
365         continue;
366       }
367       broadcast_dims.push_back(i);
368     }
369     expanded_filter = add(HloInstruction::CreateBroadcast(
370         expanded_filter_shape, reshaped_filter, broadcast_dims));
371 
372     auto zero = add(HloInstruction::CreateConstant(
373         LiteralUtil::Zero(expanded_filter_shape.element_type())));
374     auto zero_filter =
375         add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));
376     auto new_filter = add(HloInstruction::CreateTernary(
377         expanded_filter_shape, HloOpcode::kSelect, filter_mask, expanded_filter,
378         zero_filter));
379 
380     auto new_convolution = HloInstruction::CreateConvolve(
381         convolution->shape(), convolution->mutable_operand(0), new_filter,
382         /*feature_group_count=*/1, /*batch_group_count=*/1,
383         convolution->window(), dim_numbers, convolution->precision_config());
384     TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
385         convolution, std::move(new_convolution)));
386   } else {
387     int64 activation_input_feature_dim = dim_numbers.input_feature_dimension();
388 
389     int64 output_feature =
390         filter->shape().dimensions(kernel_output_feature_dim);
391 
392     // If group_count == output_feature, then we map those grouped convolutions
393     // onto depthwise convolution. This is done by adding an additional spatial
394     // dimension to the activations, kernel, and the output.
395     // E.g., we would turn
396     // [2, 12]{B, IF} conv [3, 4]{IF, OF} into
397     // [3, 2, 4]{S, B, IF} depth conv [3, 1, 4]{S, IF, OF}, where S is the
398     // additional spatial dimension. The generated convolution output will be
399     // [1, 2, 4]{S, B, OF} and then reshape the output back to [2, 4] {B, OF}.
400 
401     if (group_count == output_feature && !filter_expansion_) {
402       auto filter = convolution->mutable_operand(1);
403       auto activation = convolution->mutable_operand(0);
404 
405       // Add spatial dimension to the activation, and reshape.
406       Shape reshaped_activation_shape = activation->shape();
407       ShapeUtil::AppendMajorDimension(group_size, &reshaped_activation_shape);
408 
409       int64 new_spatial_dim = reshaped_activation_shape.dimensions().size() - 1;
410 
411       reshaped_activation_shape.set_dimensions(activation_input_feature_dim,
412                                                group_count);
413       activation = add(
414           HloInstruction::CreateReshape(reshaped_activation_shape, activation));
415 
416       // Add spatial dimension to the filter, and reshape.
417       Shape reshaped_filter_shape = filter->shape();
418       ShapeUtil::AppendMajorDimension(1, &reshaped_filter_shape);
419 
420       filter =
421           add(HloInstruction::CreateReshape(reshaped_filter_shape, filter));
422 
423       Shape new_output_shape = convolution->shape();
424       ShapeUtil::AppendMajorDimension(1, &new_output_shape);
425 
426       // Edit convolution dimension numbers. Note that kernel_input_feature_dim
427       // now becomes a spatial dimension, and the newly added dimension of size
428       // 1 is the new kernel_input_feature_dim.
429       dim_numbers.add_input_spatial_dimensions(new_spatial_dim);
430       dim_numbers.add_kernel_spatial_dimensions(kernel_input_feature_dim);
431       dim_numbers.set_kernel_input_feature_dimension(new_spatial_dim);
432       dim_numbers.add_output_spatial_dimensions(new_spatial_dim);
433 
434       // Add window for the new spatial dimension.
435       Window new_window = convolution->window();
436       auto* dim = new_window.add_dimensions();
437       dim->set_window_dilation(1);
438       dim->set_base_dilation(1);
439       dim->set_stride(1);
440       dim->set_size(group_size);
441 
442       auto new_convolution = add(HloInstruction::CreateConvolve(
443           new_output_shape, activation, filter, group_count,
444           /*batch_group_count=*/1, new_window, dim_numbers,
445           convolution->precision_config()));
446 
447       // Delete the extra spatial dimension, and reshape.
448       Shape reshaped_convolution_shape =
449           ShapeUtil::DeleteDimension(new_spatial_dim, new_convolution->shape());
450       auto reshaped_convolution = HloInstruction::CreateReshape(
451           reshaped_convolution_shape, new_convolution);
452 
453       TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
454           convolution, std::move(reshaped_convolution)));
455 
456     } else {
457       // The filter expansion mechanism adds zeroes in the kernel.
458       // For an OF = 12, IF = 6, and kernel IF = 2, the expanded filter mask
459       // would look like (IF on the Y-axis, OF on the X-axis)
460       // 1 1 1 1 0 0 0 0 0 0 0 0
461       // 1 1 1 1 0 0 0 0 0 0 0 0
462       // 0 0 0 0 1 1 1 1 0 0 0 0
463       // 0 0 0 0 1 1 1 1 0 0 0 0
464       // 0 0 0 0 0 0 0 0 1 1 1 1
465       // 0 0 0 0 0 0 0 0 1 1 1 1
466       //
467       // Instead of convolving the above with the input, we instead slice the
468       // kernel into three kernels, each containing islands of 1s from the
469       // filter above. We also slice the activations in the IF dimension with
470       // each slice of size = group_size. For each slice, we perform
471       // convolutions, and concatenate the generated outputs in the output OF
472       // dimension.
473 
474       std::vector<HloInstruction*> sliced_convolutions;
475       auto activation = convolution->mutable_operand(0);
476       std::vector<int64> slice_strides(filter->shape().dimensions_size(), 1);
477       std::vector<int64> filter_slice_starts(filter->shape().dimensions_size(),
478                                              0);
479       std::vector<int64> filter_slice_limits(
480           filter->shape().dimensions().begin(),
481           filter->shape().dimensions().end());
482       std::vector<int64> activation_slice_starts(
483           activation->shape().dimensions_size(), 0);
484       std::vector<int64> activation_slice_limits(
485           activation->shape().dimensions().begin(),
486           activation->shape().dimensions().end());
487 
488       int64 output_feature =
489           filter->shape().dimensions(kernel_output_feature_dim);
490       auto output_feature_dim = dim_numbers.output_feature_dimension();
491       int64 filter_slice_width = output_feature / group_count;
492 
493       int64 activation_input_feature_dim =
494           dim_numbers.input_feature_dimension();
495 
496       for (int64 i = 0; i < group_count; i++) {
497         filter_slice_starts[kernel_output_feature_dim] = i * filter_slice_width;
498         filter_slice_limits[kernel_output_feature_dim] =
499             (i + 1) * filter_slice_width;
500         auto filter_sliced_shape = filter->shape();
501         filter_sliced_shape.set_dimensions(kernel_output_feature_dim,
502                                            filter_slice_width);
503         auto filter_slice = add(HloInstruction::CreateSlice(
504             filter_sliced_shape, filter, filter_slice_starts,
505             filter_slice_limits, slice_strides));
506 
507         activation_slice_starts[activation_input_feature_dim] = i * group_size;
508         activation_slice_limits[activation_input_feature_dim] =
509             (i + 1) * group_size;
510         auto activation_sliced_shape = activation->shape();
511         activation_sliced_shape.set_dimensions(activation_input_feature_dim,
512                                                group_size);
513         auto activation_slice = add(HloInstruction::CreateSlice(
514             activation_sliced_shape, activation, activation_slice_starts,
515             activation_slice_limits, slice_strides));
516 
517         auto conv_slice_shape = convolution->shape();
518         conv_slice_shape.set_dimensions(output_feature_dim, filter_slice_width);
519 
520         auto new_convolution = add(HloInstruction::CreateConvolve(
521             conv_slice_shape, activation_slice, filter_slice,
522             /*feature_group_count=*/1, /*batch_group_count=*/1,
523             convolution->window(), dim_numbers,
524             convolution->precision_config()));
525 
526         sliced_convolutions.push_back(new_convolution);
527       }
528 
529       auto new_conv = HloInstruction::CreateConcatenate(
530           convolution->shape(), sliced_convolutions, output_feature_dim);
531       TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
532           convolution, std::move(new_conv)));
533     }
534   }
535 
536   return Status::OK();
537 }
538 
539 }  // namespace
540 
Run(HloModule * module)541 StatusOr<bool> ConvolutionGroupConverter::Run(HloModule* module) {
542   XLA_VLOG_LINES(
543       2, "ConvolutionGroupConverter::Run(), before:\n" + module->ToString());
544   bool changed = false;
545   for (auto* comp : module->MakeNonfusionComputations()) {
546     if (ConvolutionVisitor::Run(comp, is_cost_viable_,
547                                 convert_batch_groups_only_,
548                                 filter_expansion_)) {
549       changed = true;
550     }
551   }
552   XLA_VLOG_LINES(
553       2, "ConvolutionGroupConverter::Run(), after:\n" + module->ToString());
554   return changed;
555 }
556 
557 }  // namespace xla
558