1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/convolution_group_converter.h"
17
#include <functional>
#include <memory>
#include <utility>
#include <vector>
20
21 #include "absl/memory/memory.h"
22 #include "tensorflow/compiler/xla/literal.h"
23 #include "tensorflow/compiler/xla/literal_util.h"
24 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/status_macros.h"
30 #include "tensorflow/compiler/xla/types.h"
31 #include "tensorflow/compiler/xla/util.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/core/status.h"
35 #include "tensorflow/core/platform/logging.h"
36
37 namespace xla {
38
39 namespace {
40
41 // ConvolutionVisitor traverses the HLO computation and rewrites Convolution
42 // operations with feature_group_count > 1 into convolutions with
43 // feature_group_count = 1.
44 class ConvolutionVisitor : public DfsHloVisitorWithDefault {
45 public:
46 // Default visitor action is to do nothing and return OK.
DefaultAction(HloInstruction *)47 Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
48 return Status::OK();
49 }
50
51 Status HandleConvolution(HloInstruction* convolution) override;
52
53 Status HandleBatchGroupCount(HloInstruction* convolution);
54
55 // Runs the visitor on a computation.
56 static bool Run(HloComputation* computation,
57 std::function<bool(HloInstruction*)> is_cost_viable,
58 bool convert_batch_groups_only,
59 bool canonicalize_depthwise_filter);
60
61 // Returns whether any convolution ops were rewritten.
changed() const62 const bool changed() const { return changed_; }
63
64 ~ConvolutionVisitor() override = default;
65
66 private:
ConvolutionVisitor(HloComputation * computation,std::function<bool (HloInstruction *)> is_cost_viable,bool convert_batch_groups_only,bool canonicalize_depthwise_filter=false)67 explicit ConvolutionVisitor(
68 HloComputation* computation,
69 std::function<bool(HloInstruction*)> is_cost_viable,
70 bool convert_batch_groups_only,
71 bool canonicalize_depthwise_filter = false)
72 : computation_(computation),
73 filter_expansion_(!canonicalize_depthwise_filter),
74 convert_batch_groups_only_(convert_batch_groups_only),
75 is_cost_viable_(is_cost_viable) {}
76
77 // Current HloComputation instance the ConvolutionVisitor is traversing.
78 HloComputation* computation_;
79
80 // Whether rewrite has occurred.
81 bool changed_ = false;
82
83 // Whether filter expansion is required.
84 bool filter_expansion_;
85
86 // Decides whether to convert batch groups or feature groups.
87 bool convert_batch_groups_only_;
88
89 // std::function<std::vector<LloValue*>(int64, int64)> chunk_fetcher
90 std::function<bool(HloInstruction*)> is_cost_viable_;
91 };
92
Run(HloComputation * computation,std::function<bool (HloInstruction *)> is_cost_viable,bool convert_batch_groups_only,bool canonicalize_depthwise_filter)93 bool ConvolutionVisitor::Run(
94 HloComputation* computation,
95 std::function<bool(HloInstruction*)> is_cost_viable,
96 bool convert_batch_groups_only, bool canonicalize_depthwise_filter) {
97 ConvolutionVisitor visitor(computation, is_cost_viable,
98 convert_batch_groups_only,
99 canonicalize_depthwise_filter);
100 TF_CHECK_OK(computation->Accept(&visitor));
101 return visitor.changed_;
102 }
103
// Returns `shape` with the size of `input_feature_dim` scaled up by
// `group_count`; this is the shape of the filter after group expansion.
Shape ExpandedFilterShape(const Shape& shape, int64 group_count,
                          int64 input_feature_dim) {
  CHECK_GE(shape.dimensions_size(), 2);
  Shape result = shape;
  result.set_dimensions(input_feature_dim,
                        group_count * shape.dimensions(input_feature_dim));
  return result;
}
113
114 // Returns a vector with 'group_count' many groups, where the i-th group
115 // consists of 'group_size' times the value i.
GetMaskIds(int64 group_size,int64 group_count)116 std::vector<int32> GetMaskIds(int64 group_size, int64 group_count) {
117 std::vector<int32> values;
118 for (int i = 0; i < group_count; ++i) {
119 for (int j = 0; j < group_size; ++j) {
120 values.push_back(i);
121 }
122 }
123 return values;
124 }
125
126 // Create a mask for grouped convolution that will make a normal convolution
127 // produce the same results as a grouped convolution. For a [2, 1, 6]
128 // filter this returns a [2, 3, 6] mask
129 // 1 1 0 0 0 0
130 // 0 0 1 1 0 0
131 // 0 0 0 0 1 1
132 //
133 // 1 1 0 0 0 0
134 // 0 0 1 1 0 0
135 // 0 0 0 0 1 1
136 //
137 // The first step is to create a rank 1 constant:
138 // 0 1 2
139 //
140 // This is broadcasted to
141 // 0 0 0 0 0 0
142 // 1 1 1 1 1 1
143 // 2 2 2 2 2 2
144 //
145 // 0 0 0 0 0 0
146 // 1 1 1 1 1 1
147 // 2 2 2 2 2 2
148 //
149 // Then we create another rank 1 constant
150 // 0 0 1 1 2 2
151 //
152 // This is broadcasted to
153 // 0 0 1 1 2 2
154 // 0 0 1 1 2 2
155 // 0 0 1 1 2 2
156 //
157 // 0 0 1 1 2 2
158 // 0 0 1 1 2 2
159 // 0 0 1 1 2 2
160 //
161 // Finally we use the Eq op of these two broadcasted constants and get the
162 // desired mask.
HloInstruction* GetExpandedFilterMask(
    const Shape& filter_shape, int64 kernel_input_feature_dim,
    int64 kernel_output_feature_dim, int64 group_count,
    const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
        add_instruction) {
  // The mask has the filter's shape with the input feature dimension expanded
  // by group_count; an S32-typed shape is needed for the two id broadcasts.
  Shape expanded_filter_shape =
      ExpandedFilterShape(filter_shape, group_count, kernel_input_feature_dim);
  Shape mask_shape = ShapeUtil::MakeShape(
      S32, AsInt64Slice(expanded_filter_shape.dimensions()));
  int64 output_feature = filter_shape.dimensions(kernel_output_feature_dim);
  int64 group_size = filter_shape.dimensions(kernel_input_feature_dim);

  // Create a 'input_feature' sized linspace and 'output_feature' sized linspace
  // that will be broadcasted into perpendicular dimensions and compared.
  const std::vector<int32> input_feature_filter_mask =
      GetMaskIds(group_size, group_count);
  const std::vector<int32> output_feature_filter_mask =
      GetMaskIds(output_feature / group_count, group_count);
  // Group ids broadcast along the expanded input feature dimension.
  auto mask1 = add_instruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<int32>(input_feature_filter_mask)));
  auto broadcasted_mask1 = add_instruction(HloInstruction::CreateBroadcast(
      mask_shape, mask1, {kernel_input_feature_dim}));
  // Group ids broadcast along the kernel output feature dimension.
  auto mask2 = add_instruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<int32>(output_feature_filter_mask)));
  auto broadcasted_mask2 = add_instruction(HloInstruction::CreateBroadcast(
      mask_shape, mask2, {kernel_output_feature_dim}));

  // Compare the broadcasted output feature linspace to the input feature
  // linspace to create a diagonal predicate.
  Shape predicate_shape = ShapeUtil::MakeShape(
      PRED, AsInt64Slice(expanded_filter_shape.dimensions()));
  return add_instruction(HloInstruction::CreateCompare(
      predicate_shape, broadcasted_mask1, broadcasted_mask2,
      ComparisonDirection::kEq));
}
198
199 // This function handles batch_group_counts which are relevant only for
200 // depthwise backprop filter convolutions.
Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) {
  auto dim_numbers = convolution->convolution_dimension_numbers();
  auto activation = convolution->mutable_operand(0);
  auto filter = convolution->mutable_operand(1);
  int64 batch_group_count = convolution->batch_group_count();

  // Ungrouped batch: nothing to rewrite.
  if (batch_group_count == 1) {
    return Status::OK();
  }

  VLOG(2) << "Dealing with batch_group_count " << batch_group_count
          << " for convolution " << convolution->ToString() << "\n";

  // Helper that adds `inst` to the current computation and returns the
  // resulting HloInstruction*.
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    return computation_->AddInstruction(std::move(inst));
  };

  int64 input_batch_dimension = dim_numbers.input_batch_dimension();
  int64 output_batch_dimension = dim_numbers.output_batch_dimension();
  int64 output_feature_dimension = dim_numbers.output_feature_dimension();

  int64 input_batch = activation->shape().dimensions(input_batch_dimension);

  // We are not yet supporting batch_group of sizes greater than 1.
  TF_RET_CHECK(input_batch == batch_group_count);

  // Only rewrite when the backend deems the grouped form not cost viable, or
  // when filter expansion was explicitly requested.
  if (!is_cost_viable_(convolution) || filter_expansion_) {
    // We first obtain the expanded the filter (which is the convolution
    // output). The batch dimension is the expanded one (which originally
    // represents kernel input feature dimension). We mask the filter to zero
    // out the expanded regions. Next we reduce the filter in the batch
    // dimension to obtain the original filter size.

    // Diagonal 0/1 mask selecting, per group, the valid region of the expanded
    // output (here the output of the convolution IS the filter being computed).
    HloInstruction* filter_mask =
        GetExpandedFilterMask(convolution->shape(), output_batch_dimension,
                              output_feature_dimension, batch_group_count, add);
    auto expanded_filter_shape = ExpandedFilterShape(
        convolution->shape(), batch_group_count, output_batch_dimension);

    // Same convolution but with all grouping removed; it produces the expanded
    // (batch_group_count times larger in the batch dim) result.
    auto new_convolution = add(HloInstruction::CreateConvolve(
        expanded_filter_shape, activation, filter,
        /*feature_group_count=*/1, /*batch_group_count=*/1,
        convolution->window(), dim_numbers, convolution->precision_config()));

    auto zero = add(HloInstruction::CreateConstant(
        LiteralUtil::Zero(expanded_filter_shape.element_type())));
    auto zero_filter =
        add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));

    // Keep only the masked-in entries; everything else becomes zero so the
    // reduction below sums exactly one contributing group element per output.
    auto new_filter = add(HloInstruction::CreateTernary(
        expanded_filter_shape, HloOpcode::kSelect, filter_mask, new_convolution,
        zero_filter));

    PrimitiveType reduce_type = new_filter->shape().element_type();
    auto reduce_window_shape = new_convolution->shape();
    reduce_window_shape.set_dimensions(output_batch_dimension, 1);

    // Ensure that data input to reduce window uses at least 32 bits.
    if (primitive_util::BitWidth(reduce_type) < primitive_util::BitWidth(F32)) {
      reduce_type = F32;
      reduce_window_shape.set_element_type(F32);
      Shape convert_shape = new_filter->shape();
      convert_shape.set_element_type(F32);
      new_filter =
          add(HloInstruction::CreateConvert(convert_shape, new_filter));
    }

    auto zero_literal = LiteralUtil::Zero(reduce_type);
    auto zero_scalar =
        add(HloInstruction::CreateConstant(std::move(zero_literal)));

    // Builds the scalar `add` computation used as the reduce-window reducer.
    auto reduce_function = [&]() -> HloComputation* {
      HloComputation::Builder b("add_computation");
      Shape shape = ShapeUtil::MakeShape(reduce_type, {});
      auto lhs =
          b.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs"));
      auto rhs =
          b.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs"));
      auto scalar_op = b.AddInstruction(
          HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs, rhs));
      return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
    };

    // Create the reduce window. The window covers batch_group_count elements
    // (stride == size, so windows don't overlap) along the output batch
    // dimension and is the identity (size 1) elsewhere — collapsing the
    // expanded batch dimension back to size 1.
    Window window;
    for (int64 i = 0; i < new_convolution->shape().dimensions_size(); ++i) {
      auto* dim = window.add_dimensions();
      dim->set_padding_low(0);
      dim->set_padding_high(0);
      dim->set_window_dilation(1);
      dim->set_base_dilation(1);
      if (i == output_batch_dimension) {
        dim->set_stride(batch_group_count);
        dim->set_size(batch_group_count);
      } else {
        dim->set_stride(1);
        dim->set_size(1);
      }
    }
    auto reduce_window = add(HloInstruction::CreateReduceWindow(
        reduce_window_shape, new_filter, zero_scalar, window,
        reduce_function()));

    Shape convert_back_shape = reduce_window->shape();
    convert_back_shape.set_element_type(activation->shape().element_type());

    // Convert reduced data back to the original data type.
    auto reduce_window_converted =
        HloInstruction::CreateConvert(convert_back_shape, reduce_window);

    TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
        convolution, std::move(reduce_window_converted)));
    changed_ = true;
  }

  return Status::OK();
}
318
// Rewrites a convolution with feature_group_count > 1 into either a depthwise
// form, a masked expanded-filter convolution, or a concatenation of per-group
// sliced convolutions, depending on group shape and filter_expansion_.
Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
  // In batch-group mode this visitor only rewrites batch groups.
  if (convert_batch_groups_only_) {
    return HandleBatchGroupCount(convolution);
  }

  // Helper that adds `inst` to the current computation and returns the
  // resulting HloInstruction*.
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    return computation_->AddInstruction(std::move(inst));
  };

  int64 group_count = convolution->feature_group_count();
  if (group_count == 1) {
    return Status::OK();
  }

  // Optimistically mark as changed; reset below if the depthwise-separable
  // shortcut decides to leave the op untouched.
  changed_ = true;
  auto dim_numbers = convolution->convolution_dimension_numbers();
  auto filter = convolution->mutable_operand(1);
  int64 kernel_input_feature_dim = dim_numbers.kernel_input_feature_dimension();
  int64 group_size = filter->shape().dimensions(kernel_input_feature_dim);
  int64 kernel_output_feature_dim =
      dim_numbers.kernel_output_feature_dimension();
  auto expanded_filter_shape = ExpandedFilterShape(filter->shape(), group_count,
                                                   kernel_input_feature_dim);
  HloInstruction* filter_mask =
      GetExpandedFilterMask(filter->shape(), kernel_input_feature_dim,
                            kernel_output_feature_dim, group_count, add);
  HloInstruction* expanded_filter;

  if (group_size == 1) {
    bool depthwise_separable =
        (group_count == filter->shape().dimensions(kernel_output_feature_dim));
    // If the code generator handles depthwise separable convolutions
    // inherently, then no filter expansion is needed.
    if (!filter_expansion_ && depthwise_separable) {
      changed_ = false;
      return Status::OK();
    }
    // We want to repeat 'filter' in the 'input_feature_dim' dimension
    // 'group_count' times.
    Shape reshaped_filter_shape =
        ShapeUtil::DeleteDimension(kernel_input_feature_dim, filter->shape());
    auto reshaped_filter =
        add(HloInstruction::CreateReshape(reshaped_filter_shape, filter));
    // Broadcast along every dimension except the (dropped) input feature
    // dimension, which re-materializes with size group_count.
    std::vector<int64> broadcast_dims;
    for (int64 i = 0; i < filter->shape().dimensions_size(); ++i) {
      if (i == kernel_input_feature_dim) {
        continue;
      }
      broadcast_dims.push_back(i);
    }
    expanded_filter = add(HloInstruction::CreateBroadcast(
        expanded_filter_shape, reshaped_filter, broadcast_dims));

    // Zero out the off-diagonal (cross-group) entries of the repeated filter
    // so the ungrouped convolution computes the grouped result.
    auto zero = add(HloInstruction::CreateConstant(
        LiteralUtil::Zero(expanded_filter_shape.element_type())));
    auto zero_filter =
        add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));
    auto new_filter = add(HloInstruction::CreateTernary(
        expanded_filter_shape, HloOpcode::kSelect, filter_mask, expanded_filter,
        zero_filter));

    auto new_convolution = HloInstruction::CreateConvolve(
        convolution->shape(), convolution->mutable_operand(0), new_filter,
        /*feature_group_count=*/1, /*batch_group_count=*/1,
        convolution->window(), dim_numbers, convolution->precision_config());
    TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
        convolution, std::move(new_convolution)));
  } else {
    int64 activation_input_feature_dim = dim_numbers.input_feature_dimension();

    int64 output_feature =
        filter->shape().dimensions(kernel_output_feature_dim);

    // If group_count == output_feature, then we map those grouped convolutions
    // onto depthwise convolution. This is done by adding an additional spatial
    // dimension to the activations, kernel, and the output.
    // E.g., we would turn
    // [2, 12]{B, IF} conv [3, 4]{IF, OF} into
    // [3, 2, 4]{S, B, IF} depth conv [3, 1, 4]{S, IF, OF}, where S is the
    // additional spatial dimension. The generated convolution output will be
    // [1, 2, 4]{S, B, OF} and then reshape the output back to [2, 4] {B, OF}.

    if (group_count == output_feature && !filter_expansion_) {
      auto filter = convolution->mutable_operand(1);
      auto activation = convolution->mutable_operand(0);

      // Add spatial dimension to the activation, and reshape.
      Shape reshaped_activation_shape = activation->shape();
      ShapeUtil::AppendMajorDimension(group_size, &reshaped_activation_shape);

      // Index of the freshly appended (major) dimension.
      int64 new_spatial_dim = reshaped_activation_shape.dimensions().size() - 1;

      reshaped_activation_shape.set_dimensions(activation_input_feature_dim,
                                               group_count);
      activation = add(
          HloInstruction::CreateReshape(reshaped_activation_shape, activation));

      // Add spatial dimension to the filter, and reshape.
      Shape reshaped_filter_shape = filter->shape();
      ShapeUtil::AppendMajorDimension(1, &reshaped_filter_shape);

      filter =
          add(HloInstruction::CreateReshape(reshaped_filter_shape, filter));

      Shape new_output_shape = convolution->shape();
      ShapeUtil::AppendMajorDimension(1, &new_output_shape);

      // Edit convolution dimension numbers. Note that kernel_input_feature_dim
      // now becomes a spatial dimension, and the newly added dimension of size
      // 1 is the new kernel_input_feature_dim.
      dim_numbers.add_input_spatial_dimensions(new_spatial_dim);
      dim_numbers.add_kernel_spatial_dimensions(kernel_input_feature_dim);
      dim_numbers.set_kernel_input_feature_dimension(new_spatial_dim);
      dim_numbers.add_output_spatial_dimensions(new_spatial_dim);

      // Add window for the new spatial dimension.
      Window new_window = convolution->window();
      auto* dim = new_window.add_dimensions();
      dim->set_window_dilation(1);
      dim->set_base_dilation(1);
      dim->set_stride(1);
      dim->set_size(group_size);

      auto new_convolution = add(HloInstruction::CreateConvolve(
          new_output_shape, activation, filter, group_count,
          /*batch_group_count=*/1, new_window, dim_numbers,
          convolution->precision_config()));

      // Delete the extra spatial dimension, and reshape.
      Shape reshaped_convolution_shape =
          ShapeUtil::DeleteDimension(new_spatial_dim, new_convolution->shape());
      auto reshaped_convolution = HloInstruction::CreateReshape(
          reshaped_convolution_shape, new_convolution);

      TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
          convolution, std::move(reshaped_convolution)));

    } else {
      // The filter expansion mechanism adds zeroes in the kernel.
      // For an OF = 12, IF = 6, and kernel IF = 2, the expanded filter mask
      // would look like (IF on the Y-axis, OF on the X-axis)
      // 1 1 1 1 0 0 0 0 0 0 0 0
      // 1 1 1 1 0 0 0 0 0 0 0 0
      // 0 0 0 0 1 1 1 1 0 0 0 0
      // 0 0 0 0 1 1 1 1 0 0 0 0
      // 0 0 0 0 0 0 0 0 1 1 1 1
      // 0 0 0 0 0 0 0 0 1 1 1 1
      //
      // Instead of convolving the above with the input, we instead slice the
      // kernel into three kernels, each containing islands of 1s from the
      // filter above. We also slice the activations in the IF dimension with
      // each slice of size = group_size. For each slice, we perform
      // convolutions, and concatenate the generated outputs in the output OF
      // dimension.

      std::vector<HloInstruction*> sliced_convolutions;
      auto activation = convolution->mutable_operand(0);
      // All slices use stride 1; start/limit vectors are mutated per group in
      // the loop below, only along the feature dimensions.
      std::vector<int64> slice_strides(filter->shape().dimensions_size(), 1);
      std::vector<int64> filter_slice_starts(filter->shape().dimensions_size(),
                                             0);
      std::vector<int64> filter_slice_limits(
          filter->shape().dimensions().begin(),
          filter->shape().dimensions().end());
      std::vector<int64> activation_slice_starts(
          activation->shape().dimensions_size(), 0);
      std::vector<int64> activation_slice_limits(
          activation->shape().dimensions().begin(),
          activation->shape().dimensions().end());

      int64 output_feature =
          filter->shape().dimensions(kernel_output_feature_dim);
      auto output_feature_dim = dim_numbers.output_feature_dimension();
      int64 filter_slice_width = output_feature / group_count;

      int64 activation_input_feature_dim =
          dim_numbers.input_feature_dimension();

      for (int64 i = 0; i < group_count; i++) {
        // Slice the i-th group's output features out of the filter.
        filter_slice_starts[kernel_output_feature_dim] = i * filter_slice_width;
        filter_slice_limits[kernel_output_feature_dim] =
            (i + 1) * filter_slice_width;
        auto filter_sliced_shape = filter->shape();
        filter_sliced_shape.set_dimensions(kernel_output_feature_dim,
                                           filter_slice_width);
        auto filter_slice = add(HloInstruction::CreateSlice(
            filter_sliced_shape, filter, filter_slice_starts,
            filter_slice_limits, slice_strides));

        // Slice the i-th group's input features out of the activations.
        activation_slice_starts[activation_input_feature_dim] = i * group_size;
        activation_slice_limits[activation_input_feature_dim] =
            (i + 1) * group_size;
        auto activation_sliced_shape = activation->shape();
        activation_sliced_shape.set_dimensions(activation_input_feature_dim,
                                               group_size);
        auto activation_slice = add(HloInstruction::CreateSlice(
            activation_sliced_shape, activation, activation_slice_starts,
            activation_slice_limits, slice_strides));

        auto conv_slice_shape = convolution->shape();
        conv_slice_shape.set_dimensions(output_feature_dim, filter_slice_width);

        // Per-group ungrouped convolution over the matching slices.
        auto new_convolution = add(HloInstruction::CreateConvolve(
            conv_slice_shape, activation_slice, filter_slice,
            /*feature_group_count=*/1, /*batch_group_count=*/1,
            convolution->window(), dim_numbers,
            convolution->precision_config()));

        sliced_convolutions.push_back(new_convolution);
      }

      // Stitch the per-group results back together along the output feature
      // dimension.
      auto new_conv = HloInstruction::CreateConcatenate(
          convolution->shape(), sliced_convolutions, output_feature_dim);
      TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
          convolution, std::move(new_conv)));
    }
  }

  return Status::OK();
}
538
539 } // namespace
540
Run(HloModule * module)541 StatusOr<bool> ConvolutionGroupConverter::Run(HloModule* module) {
542 XLA_VLOG_LINES(
543 2, "ConvolutionGroupConverter::Run(), before:\n" + module->ToString());
544 bool changed = false;
545 for (auto* comp : module->MakeNonfusionComputations()) {
546 if (ConvolutionVisitor::Run(comp, is_cost_viable_,
547 convert_batch_groups_only_,
548 filter_expansion_)) {
549 changed = true;
550 }
551 }
552 XLA_VLOG_LINES(
553 2, "ConvolutionGroupConverter::Run(), after:\n" + module->ToString());
554 return changed;
555 }
556
557 } // namespace xla
558