1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/transpose_folding.h"
17 
18 #include <memory>
19 #include <unordered_set>
20 #include <vector>
21 
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/literal.h"
24 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
28 #include "tensorflow/compiler/xla/service/hlo_module.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/service/hlo_parser.h"
31 #include "tensorflow/compiler/xla/service/shape_inference.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/test.h"
34 #include "tensorflow/compiler/xla/test_helpers.h"
35 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
36 #include "tensorflow/compiler/xla/xla_data.pb.h"
37 #include "tensorflow/core/platform/logging.h"
38 
39 namespace op = xla::testing::opcode_matchers;
40 
41 namespace xla {
42 namespace {
43 
44 class TransposeFoldingTest : public HloTestBase {
45  protected:
FoldTranspose(HloModule * module)46   void FoldTranspose(HloModule* module) {
47     TransposeFolding transpose_folding(
48         [](const HloInstruction& dot,
49            const TransposeFolding::OperandIndices& candidate_operands) {
50           return candidate_operands;
51         },
52         [](const HloInstruction& convolution,
53            const TransposeFolding::OperandIndices& candidate_operands) {
54           return candidate_operands;
55         });
56     EXPECT_IS_OK(transpose_folding.Run(module).status());
57   }
58 };
59 
// Folding the rank-2 transpose that feeds the dot's RHS should leave the dot
// reading parameter 1 directly, with the RHS contracting dimension moved from
// 0 (on the transposed value) to 1 (on the original parameter).
TEST_F(TransposeFoldingTest, FoldDotTranspose) {
  const string hlo_text = R"(
HloModule FoldDotTranspose

ENTRY entry_computation {
  x = f32[2,3]{1,0} parameter(0)
  y = f32[2,3]{1,0} parameter(1)
  transpose = f32[3,2]{1,0} transpose(y), dimensions={1,0}
  ROOT dot = f32[2,2]{1,0} dot(x, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_text));
  FoldTranspose(module.get());

  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Dot(op::Parameter(0), op::Parameter(1),
                            /*lhs_contracting_dim=*/1,
                            /*rhs_contracting_dim=*/1));
}
80 
// The dot below has a batch dimension on both operands; this test expects the
// pass to leave the transpose feeding its RHS alone and report no change.
TEST_F(TransposeFoldingTest, DontFoldTransposeOfBatchDim) {
  const string hlo_text = R"(
HloModule FoldDotTranspose

ENTRY entry_computation {
  x = f32[2,3] parameter(0)
  y = f32[3,2] parameter(1)
  transpose = f32[2,3] transpose(y), dimensions={1,0}
  ROOT dot = f32[2] dot(x, transpose), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_text));

  // Accept every candidate for both dots and convolutions, so a refusal to
  // fold can only come from the pass's own legality checks.
  auto accept_all_candidates =
      [](const HloInstruction& /*instruction*/,
         const TransposeFolding::OperandIndices& candidate_operands) {
        return candidate_operands;
      };
  TransposeFolding transpose_folding(accept_all_candidates,
                                     accept_all_candidates);

  TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get()));
  EXPECT_FALSE(changed);
}
108 
// A dot whose LHS is rank 1 (and whose result is rank 1) is expected to be
// left untouched: the transpose feeding its RHS must not be folded and the
// pass must report no change.
//
// The dot has no batch dimensions on either side. The LHS and RHS batch
// dimension lists must have the same length for a dot to be well formed, and
// a rank-1 LHS has no batch dimension to pair with, so both lists are empty.
// (The module is parsed without verification, so a mismatched list would not
// be caught here.)
TEST_F(TransposeFoldingTest, DontFoldTransposeOfRank1Dot) {
  string hlo_string = R"(
HloModule FoldDotTranspose

ENTRY entry_computation {
  x = f32[3] parameter(0)
  y = f32[3,2] parameter(1)
  transpose = f32[2,3] transpose(y), dimensions={1,0}
  ROOT dot = f32[2] dot(x, transpose), lhs_batch_dims={}, rhs_batch_dims={}, lhs_contracting_dims={0}, rhs_contracting_dims={1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_string));

  // Accept every candidate so any refusal comes from the pass itself.
  TransposeFolding transpose_folding(
      [](const HloInstruction& /*dot*/,
         const TransposeFolding::OperandIndices& candidate_operands) {
        return candidate_operands;
      },
      [](const HloInstruction& /*convolution*/,
         const TransposeFolding::OperandIndices& candidate_operands) {
        return candidate_operands;
      });
  TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get()));
  EXPECT_FALSE(changed);
}
136 
// Both dot operands are transposes of constants. After folding, the dot
// should read the two constants directly, with the LHS contracting dimension
// moved from 1 to 0 and the RHS contracting dimension moved from 0 to 1.
TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) {
  const string hlo_text = R"(
HloModule FoldDotTransposeConstant

ENTRY entry_computation {
  constant = f32[2,1]{1,0} constant({ { 1 }, { 2 } })
  transpose = f32[1,2]{1,0} transpose(constant), dimensions={1,0}
  constant.1 = f32[3,2]{1,0} constant({ { 1, 2 }, { 3, 4 }, { 5, 6 } })
  transpose.1 = f32[2,3]{1,0} transpose(constant.1), dimensions={1,0}
  ROOT dot = f32[1,3]{1,0} dot(transpose, transpose.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_text));
  FoldTranspose(module.get());

  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Dot(op::Constant(), op::Constant(),
                            /*lhs_contracting_dim=*/0,
                            /*rhs_contracting_dim=*/1));
}
158 
// Outlines the expression (1.0 + 2.0) * (2.0 - 3.0) into a call and checks
// that the call becomes the entry root, that its operands are exactly the
// three scalar constants, and that the callee holds six instructions (one
// parameter per constant plus the add, subtract and multiply).
// Note: despite its name, this test builds no dot or transpose and does not
// run the transpose-folding pass.
TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) {
  auto builder = HloComputation::Builder("entry");

  // Appends a scalar f32 constant to the builder.
  auto add_scalar_constant = [&builder](float value) {
    return builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(value)));
  };
  // Appends `lhs <opcode> rhs`, typed with the shape of `lhs`.
  auto add_binary = [&builder](HloOpcode opcode, HloInstruction* lhs,
                               HloInstruction* rhs) {
    return builder.AddInstruction(
        HloInstruction::CreateBinary(lhs->shape(), opcode, lhs, rhs));
  };

  HloInstruction* one = add_scalar_constant(1.0f);
  HloInstruction* two = add_scalar_constant(2.0f);
  HloInstruction* three = add_scalar_constant(3.0f);
  HloInstruction* sum = add_binary(HloOpcode::kAdd, one, two);
  HloInstruction* difference = add_binary(HloOpcode::kSubtract, two, three);
  HloInstruction* product = add_binary(HloOpcode::kMultiply, sum, difference);

  auto module = CreateNewVerifiedModule("fuse_with_constant_operands");
  HloComputation* entry_comp =
      module->AddEntryComputation(builder.Build(product));
  HloInstruction* call = module->OutlineExpressionFromComputation(
      {sum, difference, product}, "entry", entry_comp);
  EXPECT_EQ(call, entry_comp->root_instruction());

  HloComputation* callee = call->to_apply();
  // The call's arguments are exactly the three constants, in any order.
  EXPECT_THAT(call->operands(),
              ::testing::UnorderedElementsAre(one, two, three));

  // Three parameters plus the three binary operators.
  EXPECT_EQ(6, callee->instruction_count());
}
189 
// The transpose/dot pair lives in a called computation rather than in the
// entry. Folding should still reach it, leaving the callee's root dot reading
// both parameters directly with the RHS contracting dimension moved to 1.
TEST_F(TransposeFoldingTest, FoldDotTransposeInCall) {
  const string hlo_text = R"(
HloModule FoldDotTransposeInCall

callee {
  name.0 = f32[2,3]{1,0} parameter(0)
  name.1 = f32[2,3]{1,0} parameter(1)
  transpose.clone = f32[3,2]{1,0} transpose(name.0), dimensions={1,0}
  ROOT dot.clone = f32[2,2]{1,0} dot(name.1, transpose.clone), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}

ENTRY entry_computation {
  y = f32[2,3]{1,0} parameter(1)
  x = f32[2,3]{1,0} parameter(0)
  ROOT call = f32[2,2]{1,0} call(y, x), to_apply=callee
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_text));
  FoldTranspose(module.get());

  const HloComputation* const callee =
      module->GetComputationWithName("callee");
  ASSERT_NE(nullptr, callee);

  const HloInstruction* callee_root = callee->root_instruction();
  EXPECT_THAT(callee_root, op::Dot(op::Parameter(1), op::Parameter(0),
                                   /*lhs_contracting_dim=*/1,
                                   /*rhs_contracting_dim=*/1));
}
217 
218 // Test that a two dimension swap of the kernel gets folded into convolution.
TEST_F(TransposeFoldingTest,FoldConvDimSwapTransposeRhs)219 TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) {
220   auto builder = HloComputation::Builder("entry_computation");
221   HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
222       /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}),
223       /*name=*/"x"));
224   HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
225       /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}),
226       /*name=*/"y"));
227   HloInstruction* transpose_y =
228       builder.AddInstruction(HloInstruction::CreateTranspose(
229           ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 0, 2, 3}));
230   auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers();
231   Window window;
232   for (int i = 0; i < 2; ++i) {
233     WindowDimension* dim = window.add_dimensions();
234     dim->set_padding_low(0);
235     dim->set_padding_high(0);
236     dim->set_base_dilation(1);
237     dim->set_window_dilation(1);
238     dim->set_stride(1);
239     dim->set_size(
240         transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
241   }
242   StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
243       x->shape(), transpose_y->shape(), /*feature_group_count=*/1,
244       /*batch_group_count=*/1, window, dnums);
245   EXPECT_IS_OK(conv_shape);
246   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
247       conv_shape.ValueOrDie(), x, transpose_y,
248       /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
249       DefaultPrecisionConfig(2)));
250 
251   auto module = CreateNewVerifiedModule("test_module");
252   HloComputation* entry_computation =
253       module->AddEntryComputation(builder.Build(conv));
254   FoldTranspose(module.get());
255 
256   // Instructions after folding: x, y, and the convolution.
257   std::unordered_set<HloInstruction*> instruction_set(
258       entry_computation->instructions().begin(),
259       entry_computation->instructions().end());
260   CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
261   CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
262   CHECK_EQ(1, instruction_set.size())
263       << "entry_computation should contain exactly 3 instructions.";
264   HloInstruction* new_conv = *instruction_set.begin();
265   EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
266   EXPECT_EQ(dnums.kernel_input_feature_dimension(),
267             new_conv->convolution_dimension_numbers()
268                 .kernel_output_feature_dimension());
269   EXPECT_EQ(dnums.kernel_output_feature_dimension(),
270             new_conv->convolution_dimension_numbers()
271                 .kernel_input_feature_dimension());
272 }
273 
274 // Test that a complex transpose of the kernel gets folded into convolution.
TEST_F(TransposeFoldingTest,FoldConvComplexTransposeRhs)275 TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) {
276   auto builder = HloComputation::Builder("entry_computation");
277   HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
278       /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}),
279       /*name=*/"x"));
280   HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
281       /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {1, 2, 1, 3}),
282       /*name=*/"y"));
283   HloInstruction* transpose_y =
284       builder.AddInstruction(HloInstruction::CreateTranspose(
285           ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 3, 0, 2}));
286   auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers();
287   Window window;
288   for (int i = 0; i < 2; ++i) {
289     WindowDimension* dim = window.add_dimensions();
290     dim->set_padding_low(0);
291     dim->set_padding_high(0);
292     dim->set_base_dilation(1);
293     dim->set_window_dilation(1);
294     dim->set_stride(1);
295     dim->set_size(
296         transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
297   }
298   StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
299       x->shape(), transpose_y->shape(), /*feature_group_count=*/1,
300       /*batch_group_count=*/1, window, dnums);
301   EXPECT_IS_OK(conv_shape);
302   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
303       conv_shape.ValueOrDie(), x, transpose_y,
304       /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
305       DefaultPrecisionConfig(2)));
306 
307   auto module = CreateNewVerifiedModule("test_module");
308   HloComputation* entry_computation =
309       module->AddEntryComputation(builder.Build(conv));
310   FoldTranspose(module.get());
311 
312   // Instructions after folding: x, y, and the convolution.
313   std::unordered_set<HloInstruction*> instruction_set(
314       entry_computation->instructions().begin(),
315       entry_computation->instructions().end());
316   CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
317   CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
318   CHECK_EQ(1, instruction_set.size())
319       << "entry_computation should contain exactly 3 instructions.";
320   HloInstruction* new_conv = *instruction_set.begin();
321   EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
322   EXPECT_EQ(dnums.kernel_input_feature_dimension(),
323             new_conv->convolution_dimension_numbers()
324                 .kernel_output_feature_dimension());
325   EXPECT_EQ(dnums.kernel_spatial_dimensions(1),
326             new_conv->convolution_dimension_numbers()
327                 .kernel_input_feature_dimension());
328   EXPECT_EQ(
329       dnums.kernel_output_feature_dimension(),
330       new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(0));
331   EXPECT_EQ(
332       dnums.kernel_spatial_dimensions(0),
333       new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(1));
334 }
335 
336 // Test that a transpose of the activations gets folded into convolution.
TEST_F(TransposeFoldingTest,FoldConvTransposeLhs)337 TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
338   auto builder = HloComputation::Builder("entry_computation");
339   HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
340       /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}),
341       /*name=*/"x"));
342   HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
343       /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}),
344       /*name=*/"y"));
345   HloInstruction* transpose_x =
346       builder.AddInstruction(HloInstruction::CreateTranspose(
347           ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 2, 3}));
348   auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers();
349   Window window;
350   for (int i = 0; i < 2; ++i) {
351     WindowDimension* dim = window.add_dimensions();
352     dim->set_padding_low(0);
353     dim->set_padding_high(0);
354     dim->set_base_dilation(1);
355     dim->set_window_dilation(1);
356     dim->set_stride(1);
357     dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
358   }
359   StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
360       transpose_x->shape(), y->shape(), /*feature_group_count=*/1,
361       /*batch_group_count=*/1, window, dnums);
362   EXPECT_IS_OK(conv_shape);
363   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
364       conv_shape.ValueOrDie(), transpose_x, y,
365       /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
366       DefaultPrecisionConfig(2)));
367 
368   auto module = CreateNewVerifiedModule("test_module");
369   HloComputation* entry_computation =
370       module->AddEntryComputation(builder.Build(conv));
371   FoldTranspose(module.get());
372 
373   // Instructions after folding: x, y, and the convolution.
374   std::unordered_set<HloInstruction*> instruction_set(
375       entry_computation->instructions().begin(),
376       entry_computation->instructions().end());
377   EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
378   EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
379   EXPECT_EQ(1, instruction_set.size())
380       << "entry_computation should contain exactly 3 instructions.";
381   HloInstruction* new_conv = *instruction_set.begin();
382   EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
383   EXPECT_EQ(dnums.input_feature_dimension(),
384             new_conv->convolution_dimension_numbers().input_batch_dimension());
385   EXPECT_EQ(
386       dnums.input_batch_dimension(),
387       new_conv->convolution_dimension_numbers().input_feature_dimension());
388   EXPECT_EQ(
389       dnums.input_spatial_dimensions(0),
390       new_conv->convolution_dimension_numbers().input_spatial_dimensions(0));
391   EXPECT_EQ(
392       dnums.input_spatial_dimensions(1),
393       new_conv->convolution_dimension_numbers().input_spatial_dimensions(1));
394   EXPECT_EQ(
395       dnums.output_spatial_dimensions(0),
396       new_conv->convolution_dimension_numbers().output_spatial_dimensions(0));
397   EXPECT_EQ(
398       dnums.output_spatial_dimensions(1),
399       new_conv->convolution_dimension_numbers().output_spatial_dimensions(1));
400 }
401 
402 // Test that a transpose of every dimension in the activations gets folded into
403 // convolution.
TEST_F(TransposeFoldingTest,FoldConvComplexTransposeLhs)404 TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) {
405   auto builder = HloComputation::Builder("entry_computation");
406   HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
407       /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}),
408       /*name=*/"x"));
409   HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
410       /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}),
411       /*name=*/"y"));
412   HloInstruction* transpose_x =
413       builder.AddInstruction(HloInstruction::CreateTranspose(
414           ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 3, 2}));
415   auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers();
416   Window window;
417   for (int i = 0; i < 2; ++i) {
418     WindowDimension* dim = window.add_dimensions();
419     dim->set_padding_low(0);
420     dim->set_padding_high(0);
421     dim->set_base_dilation(1);
422     dim->set_window_dilation(1);
423     dim->set_stride(1);
424     dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
425   }
426   StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
427       transpose_x->shape(), y->shape(), /*feature_group_count=*/1,
428       /*batch_group_count=*/1, window, dnums);
429   EXPECT_IS_OK(conv_shape);
430   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
431       conv_shape.ValueOrDie(), transpose_x, y,
432       /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
433       DefaultPrecisionConfig(2)));
434 
435   auto module = CreateNewVerifiedModule("test_module");
436   HloComputation* entry_computation =
437       module->AddEntryComputation(builder.Build(conv));
438   FoldTranspose(module.get());
439 
440   // Instructions after folding: x, y, and the convolution.
441   std::unordered_set<HloInstruction*> instruction_set(
442       entry_computation->instructions().begin(),
443       entry_computation->instructions().end());
444   EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
445   EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
446   EXPECT_EQ(1, instruction_set.size())
447       << "entry_computation should contain exactly 3 instructions.";
448   HloInstruction* new_conv = *instruction_set.begin();
449   EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
450   EXPECT_EQ(dnums.input_feature_dimension(),
451             new_conv->convolution_dimension_numbers().input_batch_dimension());
452   EXPECT_EQ(
453       dnums.input_batch_dimension(),
454       new_conv->convolution_dimension_numbers().input_feature_dimension());
455   EXPECT_EQ(
456       dnums.input_spatial_dimensions(0),
457       new_conv->convolution_dimension_numbers().input_spatial_dimensions(1));
458   EXPECT_EQ(
459       dnums.input_spatial_dimensions(1),
460       new_conv->convolution_dimension_numbers().input_spatial_dimensions(0));
461   EXPECT_EQ(
462       dnums.output_spatial_dimensions(0),
463       new_conv->convolution_dimension_numbers().output_spatial_dimensions(0));
464   EXPECT_EQ(
465       dnums.output_spatial_dimensions(1),
466       new_conv->convolution_dimension_numbers().output_spatial_dimensions(1));
467 }
468 
469 }  // namespace
470 }  // namespace xla
471