/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/transpose_folding.h"

#include <memory>
#include <unordered_set>
#include <vector>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

namespace op = xla::testing::opcode_matchers;

namespace xla {
namespace {

class TransposeFoldingTest : public HloTestBase {
 protected:
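  // Runs TransposeFolding on `module` with callbacks that report every
  // candidate operand of a dot or convolution as foldable.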
  void FoldTranspose(HloModule* module) {
    TransposeFolding transpose_folding(
        [](const HloInstruction& dot,
           const TransposeFolding::OperandIndices& candidate_operands) {
          return candidate_operands;
        },
        [](const HloInstruction& convolution,
           const TransposeFolding::OperandIndices& candidate_operands) {
          return candidate_operands;
        });
    EXPECT_IS_OK(transpose_folding.Run(module).status());
  }
};
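
// Test that a transpose feeding the rhs of a dot is folded away by updating
// the dot's contracting dimensions.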
TEST_F(TransposeFoldingTest, FoldDotTranspose) {
  string hlo_string = R"(
HloModule FoldDotTranspose

ENTRY entry_computation {
  x = f32[2,3]{1,0} parameter(0)
  y = f32[2,3]{1,0} parameter(1)
  transpose = f32[3,2]{1,0} transpose(y), dimensions={1,0}
  ROOT dot = f32[2,2]{1,0} dot(x, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_string));

  FoldTranspose(module.get());

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Dot(op::Parameter(0), op::Parameter(1),
                      /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1));
}
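
// Test that a transpose of a batch dimension is not folded into the dot.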
TEST_F(TransposeFoldingTest, DontFoldTransposeOfBatchDim) {
  string hlo_string = R"(
HloModule FoldDotTranspose

ENTRY entry_computation {
  x = f32[2,3] parameter(0)
  y = f32[3,2] parameter(1)
  transpose = f32[2,3] transpose(y), dimensions={1,0}
  ROOT dot = f32[2] dot(x, transpose), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_string));

  TransposeFolding transpose_folding(
      [](const HloInstruction& dot,
         const TransposeFolding::OperandIndices& candidate_operands) {
        return candidate_operands;
      },
      [](const HloInstruction& convolution,
         const TransposeFolding::OperandIndices& candidate_operands) {
        return candidate_operands;
      });
  TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get()));
  EXPECT_FALSE(changed);
}
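
// Test that a transpose is not folded into a dot that produces a rank-1
// result.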
TEST_F(TransposeFoldingTest, DontFoldTransposeOfRank1Dot) {
  string hlo_string = R"(
HloModule FoldDotTranspose

ENTRY entry_computation {
  x = f32[3] parameter(0)
  y = f32[3,2] parameter(1)
  transpose = f32[2,3] transpose(y), dimensions={1,0}
  ROOT dot = f32[2] dot(x, transpose), lhs_batch_dims={}, rhs_batch_dims={0}, lhs_contracting_dims={0}, rhs_contracting_dims={1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_string));

  TransposeFolding transpose_folding(
      [](const HloInstruction& dot,
         const TransposeFolding::OperandIndices& candidate_operands) {
        return candidate_operands;
      },
      [](const HloInstruction& convolution,
         const TransposeFolding::OperandIndices& candidate_operands) {
        return candidate_operands;
      });
  TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get()));
  EXPECT_FALSE(changed);
}
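
// Test that transposes of constant operands are folded into the dot's
// dimension numbers.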
TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) {
  string hlo_string = R"(
HloModule FoldDotTransposeConstant

ENTRY entry_computation {
  constant = f32[2,1]{1,0} constant({ { 1 }, { 2 } })
  transpose = f32[1,2]{1,0} transpose(constant), dimensions={1,0}
  constant.1 = f32[3,2]{1,0} constant({ { 1, 2 }, { 3, 4 }, { 5, 6 } })
  transpose.1 = f32[2,3]{1,0} transpose(constant.1), dimensions={1,0}
  ROOT dot = f32[1,3]{1,0} dot(transpose, transpose.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_string));

  FoldTranspose(module.get());

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Dot(op::Constant(), op::Constant(),
                      /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/1));
}
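
// Test that an expression over constant operands can be outlined into a
// called computation, which becomes the new root of the entry computation.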
TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) {
  auto builder = HloComputation::Builder("entry");
  // (1.0 + 2.0) * (2.0 - 3.0)
  HloInstruction* const1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* const2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* const3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
      const1->shape(), HloOpcode::kAdd, const1, const2));
  HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
      const2->shape(), HloOpcode::kSubtract, const2, const3));
  HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
      add->shape(), HloOpcode::kMultiply, add, sub));

  auto module = CreateNewVerifiedModule("fuse_with_constant_operands");
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build(mul));
  HloInstruction* call = module->OutlineExpressionFromComputation(
      {add, sub, mul}, "entry", entry_computation);
  EXPECT_EQ(call, entry_computation->root_instruction());
  HloComputation* callee_computation = call->to_apply();
  // The arguments to the call should be const1, const2, and const3.
  EXPECT_THAT(call->operands(),
              ::testing::UnorderedElementsAre(const1, const2, const3));

  // The callee should contain 3 parameters and 3 binary operators.
  EXPECT_EQ(6, callee_computation->instruction_count());
}
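
// Test that a transpose inside a called computation is folded into the dot
// within that computation.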
TEST_F(TransposeFoldingTest, FoldDotTransposeInCall) {
  string hlo_string = R"(
HloModule FoldDotTransposeInCall

callee {
  name.0 = f32[2,3]{1,0} parameter(0)
  name.1 = f32[2,3]{1,0} parameter(1)
  transpose.clone = f32[3,2]{1,0} transpose(name.0), dimensions={1,0}
  ROOT dot.clone = f32[2,2]{1,0} dot(name.1, transpose.clone), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}

ENTRY entry_computation {
  y = f32[2,3]{1,0} parameter(1)
  x = f32[2,3]{1,0} parameter(0)
  ROOT call = f32[2,2]{1,0} call(y, x), to_apply=callee
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_string));
  FoldTranspose(module.get());

  const HloComputation* callee = module->GetComputationWithName("callee");
  ASSERT_NE(callee, nullptr);
  EXPECT_THAT(callee->root_instruction(),
              op::Dot(op::Parameter(1), op::Parameter(0),
                      /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1));
}

// Test that a two dimension swap of the kernel gets folded into convolution.
TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) {
  auto builder = HloComputation::Builder("entry_computation");
  HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}),
      /*name=*/"x"));
  HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}),
      /*name=*/"y"));
  HloInstruction* transpose_y =
      builder.AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 0, 2, 3}));
  auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers();
  Window window;
  for (int i = 0; i < 2; ++i) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_padding_low(0);
    dim->set_padding_high(0);
    dim->set_base_dilation(1);
    dim->set_window_dilation(1);
    dim->set_stride(1);
    dim->set_size(
        transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
  }
  StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
      x->shape(), transpose_y->shape(), /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums);
  EXPECT_IS_OK(conv_shape);
  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      conv_shape.ValueOrDie(), x, transpose_y,
      /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
      DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule("test_module");
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build(conv));
  FoldTranspose(module.get());

  // Instructions after folding: x, y, and the convolution.
  std::unordered_set<HloInstruction*> instruction_set(
      entry_computation->instructions().begin(),
      entry_computation->instructions().end());
  CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
  CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
  CHECK_EQ(1, instruction_set.size())
      << "entry_computation should contain exactly 3 instructions.";
  HloInstruction* new_conv = *instruction_set.begin();
  EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
  EXPECT_EQ(dnums.kernel_input_feature_dimension(),
            new_conv->convolution_dimension_numbers()
                .kernel_output_feature_dimension());
  EXPECT_EQ(dnums.kernel_output_feature_dimension(),
            new_conv->convolution_dimension_numbers()
                .kernel_input_feature_dimension());
}

// Test that a complex transpose of the kernel gets folded into convolution.
TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) {
  auto builder = HloComputation::Builder("entry_computation");
  HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}),
      /*name=*/"x"));
  HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {1, 2, 1, 3}),
      /*name=*/"y"));
  HloInstruction* transpose_y =
      builder.AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 3, 0, 2}));
  auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers();
  Window window;
  for (int i = 0; i < 2; ++i) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_padding_low(0);
    dim->set_padding_high(0);
    dim->set_base_dilation(1);
    dim->set_window_dilation(1);
    dim->set_stride(1);
    dim->set_size(
        transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
  }
  StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
      x->shape(), transpose_y->shape(), /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums);
  EXPECT_IS_OK(conv_shape);
  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      conv_shape.ValueOrDie(), x, transpose_y,
      /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
      DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule("test_module");
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build(conv));
  FoldTranspose(module.get());

  // Instructions after folding: x, y, and the convolution.
  std::unordered_set<HloInstruction*> instruction_set(
      entry_computation->instructions().begin(),
      entry_computation->instructions().end());
  CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
  CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
  CHECK_EQ(1, instruction_set.size())
      << "entry_computation should contain exactly 3 instructions.";
  HloInstruction* new_conv = *instruction_set.begin();
  EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
  EXPECT_EQ(dnums.kernel_input_feature_dimension(),
            new_conv->convolution_dimension_numbers()
                .kernel_output_feature_dimension());
  EXPECT_EQ(dnums.kernel_spatial_dimensions(1),
            new_conv->convolution_dimension_numbers()
                .kernel_input_feature_dimension());
  EXPECT_EQ(
      dnums.kernel_output_feature_dimension(),
      new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(0));
  EXPECT_EQ(
      dnums.kernel_spatial_dimensions(0),
      new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(1));
}

// Test that a transpose of the activations gets folded into convolution.
TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
  auto builder = HloComputation::Builder("entry_computation");
  HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}),
      /*name=*/"x"));
  HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}),
      /*name=*/"y"));
  HloInstruction* transpose_x =
      builder.AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 2, 3}));
  auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers();
  Window window;
  for (int i = 0; i < 2; ++i) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_padding_low(0);
    dim->set_padding_high(0);
    dim->set_base_dilation(1);
    dim->set_window_dilation(1);
    dim->set_stride(1);
    dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
  }
  StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
      transpose_x->shape(), y->shape(), /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums);
  EXPECT_IS_OK(conv_shape);
  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      conv_shape.ValueOrDie(), transpose_x, y,
      /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
      DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule("test_module");
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build(conv));
  FoldTranspose(module.get());

  // Instructions after folding: x, y, and the convolution.
  std::unordered_set<HloInstruction*> instruction_set(
      entry_computation->instructions().begin(),
      entry_computation->instructions().end());
  EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
  EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
  EXPECT_EQ(1, instruction_set.size())
      << "entry_computation should contain exactly 3 instructions.";
  HloInstruction* new_conv = *instruction_set.begin();
  EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
  EXPECT_EQ(dnums.input_feature_dimension(),
            new_conv->convolution_dimension_numbers().input_batch_dimension());
  EXPECT_EQ(
      dnums.input_batch_dimension(),
      new_conv->convolution_dimension_numbers().input_feature_dimension());
  EXPECT_EQ(
      dnums.input_spatial_dimensions(0),
      new_conv->convolution_dimension_numbers().input_spatial_dimensions(0));
  EXPECT_EQ(
      dnums.input_spatial_dimensions(1),
      new_conv->convolution_dimension_numbers().input_spatial_dimensions(1));
  EXPECT_EQ(
      dnums.output_spatial_dimensions(0),
      new_conv->convolution_dimension_numbers().output_spatial_dimensions(0));
  EXPECT_EQ(
      dnums.output_spatial_dimensions(1),
      new_conv->convolution_dimension_numbers().output_spatial_dimensions(1));
}

// Test that a transpose of every dimension in the activations gets folded into
// convolution.
TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) {
  auto builder = HloComputation::Builder("entry_computation");
  HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}),
      /*name=*/"x"));
  HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}),
      /*name=*/"y"));
  HloInstruction* transpose_x =
      builder.AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 3, 2}));
  auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers();
  Window window;
  for (int i = 0; i < 2; ++i) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_padding_low(0);
    dim->set_padding_high(0);
    dim->set_base_dilation(1);
    dim->set_window_dilation(1);
    dim->set_stride(1);
    dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
  }
  StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
      transpose_x->shape(), y->shape(), /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums);
  EXPECT_IS_OK(conv_shape);
  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      conv_shape.ValueOrDie(), transpose_x, y,
      /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
      DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule("test_module");
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build(conv));
  FoldTranspose(module.get());

  // Instructions after folding: x, y, and the convolution.
  std::unordered_set<HloInstruction*> instruction_set(
      entry_computation->instructions().begin(),
      entry_computation->instructions().end());
  EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
  EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
  EXPECT_EQ(1, instruction_set.size())
      << "entry_computation should contain exactly 3 instructions.";
  HloInstruction* new_conv = *instruction_set.begin();
  EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
  EXPECT_EQ(dnums.input_feature_dimension(),
            new_conv->convolution_dimension_numbers().input_batch_dimension());
  EXPECT_EQ(
      dnums.input_batch_dimension(),
      new_conv->convolution_dimension_numbers().input_feature_dimension());
  EXPECT_EQ(
      dnums.input_spatial_dimensions(0),
      new_conv->convolution_dimension_numbers().input_spatial_dimensions(1));
  EXPECT_EQ(
      dnums.input_spatial_dimensions(1),
      new_conv->convolution_dimension_numbers().input_spatial_dimensions(0));
  EXPECT_EQ(
      dnums.output_spatial_dimensions(0),
      new_conv->convolution_dimension_numbers().output_spatial_dimensions(0));
  EXPECT_EQ(
      dnums.output_spatial_dimensions(1),
      new_conv->convolution_dimension_numbers().output_spatial_dimensions(1));
}

}  // namespace
}  // namespace xla