1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/ar_crs_combiner.h"
17 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
18 #include "tensorflow/compiler/xla/statusor.h"
19 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
20 #include "tensorflow/core/lib/core/status_test_util.h"
21 
22 namespace xla {
23 namespace {
24 
// Short alias for the HLO opcode matchers (op::Tuple, op::AllReduce, ...) used
// in the EXPECT_THAT checks of the rewrite tests.
namespace op = xla::testing::opcode_matchers;

// Test fixture for ArCrsCombiner. It adds nothing on top of HloTestBase; the
// tests build modules from HLO text via ParseAndReturnVerifiedModule
// (presumably inherited from HloTestBase — it is called unqualified).
class ArCrsCombinerTest : public HloTestBase {};
28 
// Two distinct constant instructions holding identical literals compute the
// same value, while a constant and the entry parameter do not.
TEST_F(ArCrsCombinerTest, SameValueTestBasecase) {
  const char* module_str = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
  %p = f32[2,2] parameter(0)
  %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}})
  %constant.f32.2 = f32[2,2] constant({{1, 2}, {3, 4}})
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  HloComputation* entry = module->entry_computation();
  const auto& tuple_operands = entry->root_instruction()->operands();
  HloInstruction* first_constant = tuple_operands[0];
  HloInstruction* second_constant = tuple_operands[1];
  HloInstruction* param = entry->parameter_instruction(0);

  EXPECT_FALSE(
      ArCrsCombiner::TestInstructionsComputeSameValue(first_constant, param));
  EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(first_constant,
                                                               second_constant));
}
50 
// An instruction used twice as a tuple operand trivially computes the same
// value as itself.
TEST_F(ArCrsCombinerTest, SameValueTestBasecase2) {
  const char* module_str = R"(
HloModule foobar

ENTRY %entrycomp (x: f32[]) -> (f32[], f32[]) {
  %x = f32[] parameter(0)
  ROOT %tuple = (f32[], f32[]) tuple(%x, %x)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  const auto& tuple_operands =
      module->entry_computation()->root_instruction()->operands();
  EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(
      tuple_operands[0], tuple_operands[1]));
}
68 
// Two different entry parameters are not treated as computing the same value.
TEST_F(ArCrsCombinerTest, SameValueTestBasecase3) {
  const char* module_str = R"(
HloModule foobar

ENTRY %entrycomp (x: f32[], y: f32[]) -> (f32[], f32[]) {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %tuple = (f32[], f32[]) tuple(%x, %y)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  HloInstruction* root = module->entry_computation()->root_instruction();
  HloInstruction* x_param = root->operands()[0];
  HloInstruction* y_param = root->operands()[1];
  EXPECT_FALSE(
      ArCrsCombiner::TestInstructionsComputeSameValue(x_param, y_param));
}
87 
// Tuples with a different number of operands are not equal, even though every
// operand is the same constant.
TEST_F(ArCrsCombinerTest, SameValueTestNumOperands) {
  const char* module_str = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> ((f32[2,2]), (f32[2,2], f32[2,2])) {
  %p = f32[2,2] parameter(0)
  %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
  %tuple1 = (f32[2,2]) tuple(%constant.f32)
  %tuple2 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
  ROOT %tuple = ((f32[2,2]), (f32[2,2], f32[2,2])) tuple(%tuple1, %tuple2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  const auto& nested_tuples =
      module->entry_computation()->root_instruction()->operands();
  HloInstruction* one_element_tuple = nested_tuples[0];
  HloInstruction* two_element_tuple = nested_tuples[1];
  EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(
      one_element_tuple, two_element_tuple));
}
108 
// Two slices of the same operand with identical bounds compute the same value.
TEST_F(ArCrsCombinerTest, SameValueTestSliceIndicesMatch) {
  const char* module_str = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2]) -> (f32[1], f32[1]) {
  %p = f32[2] parameter(0)
  %slice.1 = f32[1] slice(f32[2] %p), slice={[0:1]}
  %slice.2 = f32[1] slice(f32[2] %p), slice={[0:1]}
  ROOT %tuple = (f32[1], f32[1]) tuple(%slice.1, %slice.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  const auto& slices =
      module->entry_computation()->root_instruction()->operands();
  EXPECT_TRUE(
      ArCrsCombiner::TestInstructionsComputeSameValue(slices[0], slices[1]));
}
128 
// Two slices of the same operand with different bounds do not compute the
// same value.
TEST_F(ArCrsCombinerTest, SameValueTestSliceIndicesDontMatch) {
  const char* module_str = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2]) -> (f32[1], f32[1]) {
  %p = f32[2] parameter(0)
  %slice.1 = f32[1] slice(f32[2] %p), slice={[0:1]}
  %slice.2 = f32[1] slice(f32[2] %p), slice={[1:2]}
  ROOT %tuple = (f32[1], f32[1]) tuple(%slice.1, %slice.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  HloInstruction* root = module->entry_computation()->root_instruction();
  HloInstruction* lower_slice = root->operands()[0];
  HloInstruction* upper_slice = root->operands()[1];
  EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(lower_slice,
                                                                upper_slice));
}
148 
// Two get-tuple-elements reading the same index of the same tuple compute the
// same value.
TEST_F(ArCrsCombinerTest, SameValueTestTupleElementSameIndex) {
  const char* module_str = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
  %p = f32[2,2] parameter(0)
  %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
  %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=0
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  const auto& gtes =
      module->entry_computation()->root_instruction()->operands();
  EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(gtes[0], gtes[1]));
}
170 
// Different indices into a tuple whose elements are the very same constant
// still compute the same value.
TEST_F(ArCrsCombinerTest, SameValueTestTupleElementDifferentIndex1) {
  const char* module_str = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
  %p = f32[2,2] parameter(0)
  %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
  %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  HloInstruction* root = module->entry_computation()->root_instruction();
  HloInstruction* gte_index0 = root->operands()[0];
  HloInstruction* gte_index1 = root->operands()[1];
  EXPECT_TRUE(
      ArCrsCombiner::TestInstructionsComputeSameValue(gte_index0, gte_index1));
}
192 
// Different indices into a tuple built from two different constants do not
// compute the same value.
TEST_F(ArCrsCombinerTest, SameValueTestTupleElementDifferentIndex2) {
  const char* module_str = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
  %p = f32[2,2] parameter(0)
  %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}})
  %constant.f32.2 = f32[2,2] constant({{2, 3}, {4, 5}})
  %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2)
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  const auto& gtes =
      module->entry_computation()->root_instruction()->operands();
  EXPECT_FALSE(
      ArCrsCombiner::TestInstructionsComputeSameValue(gtes[0], gtes[1]));
}
215 
// Inside a while body, both loop outputs start from the same init value and
// add the same constant each iteration, so they compute the same value.
TEST_F(ArCrsCombinerTest, SameValueTestWhile1) {
  const char* module_str = R"(
HloModule foobar

%condition (x: (f32[2,2], f32[2,2])) -> pred[] {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.0 = s32[] constant(0)
  %constant.1 = s32[] constant(1)
  ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT
}

%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1
  %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32)
  %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32)
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2)
}

ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) {
  %constant.f32 = f32[2,2] constant({{3, 4}, {5, 6}})
  %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
  ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  HloInstruction* while_op = module->entry_computation()->root_instruction();
  const auto& body_outputs =
      while_op->while_body()->root_instruction()->operands();
  EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(
      body_outputs[0], body_outputs[1]));
}
252 
// The while body treats both elements identically, but the init tuple holds
// two different constants, so the body outputs do not compute the same value.
TEST_F(ArCrsCombinerTest, SameValueTestWhile2) {
  const char* module_str = R"(
HloModule foobar

%condition (x: (f32[2,2], f32[2,2])) -> pred[] {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.0 = s32[] constant(0)
  %constant.1 = s32[] constant(1)
  ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT
}

%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1
  %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32)
  %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32)
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2)
}

ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) {
  %constant.f32.1 = f32[2,2] constant({{3, 4}, {5, 6}})
  %constant.f32.2 = f32[2,2] constant({{3, 4}, {7, 8}})
  %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2)
  ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  HloComputation* body =
      module->entry_computation()->root_instruction()->while_body();
  HloInstruction* body_root = body->root_instruction();
  HloInstruction* add_first = body_root->operands()[0];
  HloInstruction* add_second = body_root->operands()[1];
  EXPECT_FALSE(
      ArCrsCombiner::TestInstructionsComputeSameValue(add_first, add_second));
}
290 
// The init tuple holds the same constant twice, but the body adds a different
// constant to each element. The loop-carried get-tuple-elements therefore
// differ after an iteration and must not be considered equal.
TEST_F(ArCrsCombinerTest, SameValueTestWhile3) {
  const char* module_str = R"(
HloModule foobar

%condition (x: (f32[2,2], f32[2,2])) -> pred[] {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.0 = s32[] constant(0)
  %constant.1 = s32[] constant(1)
  ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT
}

%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}})
  %constant.f32.2 = f32[2,2] constant({{3, 4}, {1, 2}})
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1
  %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32.1)
  %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32.2)
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2)
}

ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) {
  %constant.f32 = f32[2,2] constant({{3, 4}, {5, 6}})
  %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
  ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  HloInstruction* while_op = module->entry_computation()->root_instruction();
  const auto& body_outputs =
      while_op->while_body()->root_instruction()->operands();
  // The first operand of %add.1 / %add.2 is %get-tuple-element.1 / .2.
  HloInstruction* gte_first = body_outputs[0]->operands()[0];
  HloInstruction* gte_second = body_outputs[1]->operands()[0];
  EXPECT_FALSE(
      ArCrsCombiner::TestInstructionsComputeSameValue(gte_first, gte_second));
}
328 
CompareReplicaGroups(const std::vector<ReplicaGroup> & groups_before,const std::vector<ReplicaGroup> & groups_after)329 void CompareReplicaGroups(const std::vector<ReplicaGroup>& groups_before,
330                           const std::vector<ReplicaGroup>& groups_after) {
331   ASSERT_EQ(groups_before.size(), groups_after.size());
332   for (int i = 0; i < groups_before.size(); ++i) {
333     // Somewhat verbose way to compare the replica_ids, because EqualsProto
334     // is not available in the open-source build.
335     auto group_before = groups_before[i];
336     std::vector<int64> ids_before(group_before.replica_ids().begin(),
337                                   group_before.replica_ids().end());
338     auto group_after = groups_after[i];
339     std::vector<int64> ids_after(group_after.replica_ids().begin(),
340                                  group_after.replica_ids().end());
341     EXPECT_EQ(ids_before, ids_after);
342   }
343 }
344 
// Each of two devices runs an all-reduce (groups {{0},{1}}), then a convert to
// f32, then a cross-replica all-reduce (groups {{0,1}}). After the pass, each
// root operand is a single all-reduce over the convert of the original AR
// input, and the CRS replica groups are unchanged.
TEST_F(ArCrsCombinerTest, RewriteArConvertCrs) {
  const char* module_str = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
  %p = bf16[] parameter(0)
  %constant.bf16 = bf16[] constant(1)

  %all-reduce.ar.1 = bf16[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=0}
  %convert.1 = f32[]
      convert(%all-reduce.ar.1),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%convert.1),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=1}
  %convert.2 = f32[]
      convert(%all-reduce.ar.2),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%convert.2),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  // Copy the replica groups of the first root operand (the CRS) before the
  // pass rewrites the graph.
  auto replica_groups_before = module->entry_computation()
                                   ->root_instruction()
                                   ->operands()[0]
                                   ->replica_groups();
  ArCrsCombiner combiner(2);
  // Report a non-OK status as a test failure instead of aborting the whole
  // test binary through ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::AllReduce(op::Convert(op::Parameter())),
                        op::AllReduce(op::Convert(op::Constant()))));
  auto replica_groups_after = module->entry_computation()
                                  ->root_instruction()
                                  ->operands()[0]
                                  ->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
417 
// Each device runs an f32[2,1] all-reduce, then a bitcast to f32[2], then a
// cross-replica all-reduce. After the pass, each root operand is a single
// all-reduce over the bitcast parameter, and the CRS replica groups are
// unchanged.
TEST_F(ArCrsCombinerTest, RewriteArBitcastCrs) {
  // The ROOT tuple's declared shape must match its f32[2] operands and the
  // entry signature (f32[2], f32[2]).
  const char* module_str = R"(
HloModule foobar

%sum.1 (a: f32[2,1], b: f32[2,1]) -> f32[2,1] {
  %a = f32[2,1] parameter(0)
  %b = f32[2,1] parameter(1)
  ROOT %add = f32[2,1] add(%a, %b)
}

%sum.2 (x: f32[2], y: f32[2]) -> f32[2] {
  %x = f32[2] parameter(0)
  %y = f32[2] parameter(1)
  ROOT %add = f32[2] add(%x, %y)
}

ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) {
  %p = f32[2,1] parameter(0)

  %all-reduce.ar.1 = f32[2,1]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.1,
      sharding={maximal device=0}
  %bitcast.1 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.1)
  %all-reduce.1 = f32[2]
      all-reduce(%bitcast.1),
      replica_groups={{0,1}},
      to_apply=%sum.2,
      sharding={maximal device=0}

  %all-reduce.ar.2 = f32[2,1]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.1,
      sharding={maximal device=1}
  %bitcast.2 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.2)
  %all-reduce.2 = f32[2]
      all-reduce(%bitcast.2),
      replica_groups={{0,1}},
      to_apply=%sum.2,
      sharding={maximal device=1}

  ROOT %tuple = (f32[2], f32[2])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  // Copy the replica groups of the first root operand (the CRS) before the
  // pass rewrites the graph.
  auto replica_groups_before = module->entry_computation()
                                   ->root_instruction()
                                   ->operands()[0]
                                   ->replica_groups();
  ArCrsCombiner combiner(2);
  // Report a non-OK status as a test failure instead of aborting the whole
  // test binary through ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::AllReduce(op::Bitcast(op::Parameter())),
                        op::AllReduce(op::Bitcast(op::Parameter()))));
  auto replica_groups_after = module->entry_computation()
                                  ->root_instruction()
                                  ->operands()[0]
                                  ->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
485 
// Each device runs an all-reduce, then a multiply by a constant, then a
// cross-replica all-reduce. After the pass, each root operand is a single
// all-reduce over multiply(parameter, constant), and the CRS replica groups
// are unchanged.
TEST_F(ArCrsCombinerTest, RewriteArMultiplyCrs) {
  const char* module_str = R"(
HloModule foobar

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.f32 = f32[] constant(123)

  %all-reduce.ar.1 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.f32,
      sharding={maximal device=0}
  %multiply.1 = f32[]
      multiply(%all-reduce.ar.1, %constant.f32),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%multiply.1),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.f32,
      sharding={maximal device=1}
  %multiply.2 = f32[]
      multiply(%all-reduce.ar.2, %constant.f32),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%multiply.2),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  // Copy the replica groups of the first root operand (the CRS) before the
  // pass rewrites the graph.
  auto replica_groups_before = module->entry_computation()
                                   ->root_instruction()
                                   ->operands()[0]
                                   ->replica_groups();
  ArCrsCombiner combiner(2);
  // Report a non-OK status as a test failure instead of aborting the whole
  // test binary through ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      op::Tuple(op::AllReduce(op::Multiply(op::Parameter(), op::Constant())),
                op::AllReduce(op::Multiply(op::Parameter(), op::Constant()))));
  auto replica_groups_after = module->entry_computation()
                                  ->root_instruction()
                                  ->operands()[0]
                                  ->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
553 
// Each device runs an all-reduce, then a convert, then an add with a constant
// that is the same on both devices, then a cross-replica all-reduce. After the
// pass, the other summand appears as divide(constant, constant) inside the
// surviving all-reduce (presumably a scaling so the final sum is unchanged —
// TODO confirm against the pass).
TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrs) {
  const char* module_str = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.bf16 = bf16[] constant(1)
  %constant.f32 = f32[] constant(2)

  %all-reduce.ar.1 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=0}
  %convert.1 = f32[]
      convert(%all-reduce.ar.1),
      sharding={maximal device=0}
  %add.1 = f32[]
      add(%constant.f32, %convert.1),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%add.1),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=1}
  %convert.2 = f32[]
      convert(%all-reduce.ar.2),
      sharding={maximal device=1}
  %add.2 = f32[]
      add(%constant.f32, %convert.2),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%add.2),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  // Copy the replica groups of the first root operand (the CRS) before the
  // pass rewrites the graph.
  auto replica_groups_before = module->entry_computation()
                                   ->root_instruction()
                                   ->operands()[0]
                                   ->replica_groups();
  ArCrsCombiner combiner(2);
  // Report a non-OK status as a test failure instead of aborting the whole
  // test binary through ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      op::Tuple(
          op::AllReduce(op::Add(op::Divide(op::Constant(), op::Constant()),
                                op::Convert())),
          op::AllReduce(op::Add(op::Divide(op::Constant(), op::Constant()),
                                op::Convert()))));
  auto replica_groups_after = module->entry_computation()
                                  ->root_instruction()
                                  ->operands()[0]
                                  ->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
637 
// The add's other summand differs between the two devices (constants 2 and
// 3), so the pass must leave the module unchanged.
TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewrite) {
  const char* module_str = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.bf16 = bf16[] constant(1)
  %constant.f32.1 = f32[] constant(2)
  %constant.f32.2 = f32[] constant(3)

  %all-reduce.ar.1 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=0}
  %convert.1 = f32[]
      convert(%all-reduce.ar.1),
      sharding={maximal device=0}
  %add.1 = f32[]
      add(%constant.f32.1, %convert.1),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%add.1),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=1}
  %convert.2 = f32[]
      convert(%all-reduce.ar.2),
      sharding={maximal device=1}
  %add.2 = f32[]
      add(%constant.f32.2, %convert.2),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%add.2),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  ArCrsCombiner combiner(2);
  // Report a non-OK status as a test failure instead of aborting the whole
  // test binary through ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_FALSE(changed);
}
708 
// An all-reduce feeds a cross-replica all-reduce directly, with no op in
// between. The %multiply instructions are not reachable from the ROOT tuple.
// The pass must not crash and should collapse each pair into one all-reduce of
// the parameter.
TEST_F(ArCrsCombinerTest, ArThenCrsDontCrash) {
  const char* module_str = R"(
HloModule foobar

%sum.1 (a: f32[], b: f32[]) -> f32[] {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.f32 = f32[] constant(123)

  %all-reduce.ar.1 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.1,
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%all-reduce.ar.1),
      replica_groups={{0,1}},
      to_apply=%sum.1,
      sharding={maximal device=0}
  %multiply.1 = f32[]
      multiply(%all-reduce.1, %constant.f32),
      sharding={maximal device=0}

  %all-reduce.ar.2 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.1,
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%all-reduce.ar.2),
      replica_groups={{0,1}},
      to_apply=%sum.1,
      sharding={maximal device=1}
  %multiply.2 = f32[]
      multiply(%all-reduce.2, %constant.f32),
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  // Copy the replica groups of the first root operand (the CRS) before the
  // pass rewrites the graph.
  auto replica_groups_before = module->entry_computation()
                                   ->root_instruction()
                                   ->operands()[0]
                                   ->replica_groups();
  ArCrsCombiner combiner(2);
  // Report a non-OK status as a test failure instead of aborting the whole
  // test binary through ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::AllReduce(op::Parameter()),
                        op::AllReduce(op::Parameter())));
  auto replica_groups_after = module->entry_computation()
                                  ->root_instruction()
                                  ->operands()[0]
                                  ->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
775 
// Each device runs an all-reduce, then two chained adds with constants, then a
// cross-replica all-reduce. After the pass, both constant summands appear as
// divide(constant, constant) around the parameter inside a single all-reduce.
// The second chain lives on device 1, matching the ROOT tuple's sharding and
// the other rewrite tests.
TEST_F(ArCrsCombinerTest, RewriteMultipleAdds) {
  const char* module_str = R"(
HloModule foobar

%sum (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.1 = f32[] constant(1)
  %constant.2 = f32[] constant(2)

  %all-reduce.ar.1 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum,
      sharding={maximal device=0}
  %add.11 = f32[]
      add(%constant.1, %all-reduce.ar.1),
      sharding={maximal device=0}
  %add.12 = f32[]
      add(%constant.2, %add.11),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%add.12),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=0}

  %all-reduce.ar.2 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum,
      sharding={maximal device=1}
  %add.21 = f32[]
      add(%constant.1, %all-reduce.ar.2),
      sharding={maximal device=1}
  %add.22 = f32[]
      add(%constant.2, %add.21),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%add.22),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  // Copy the replica groups of the first root operand (the CRS) before the
  // pass rewrites the graph.
  auto replica_groups_before = module->entry_computation()
                                   ->root_instruction()
                                   ->operands()[0]
                                   ->replica_groups();
  ArCrsCombiner combiner(2);
  // Report a non-OK status as a test failure instead of aborting the whole
  // test binary through ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::AllReduce(op::Add(
                            op::Divide(op::Constant(), op::Constant()),
                            op::Add(op::Divide(op::Constant(), op::Constant()),
                                    op::Parameter()))),
                        op::AllReduce(op::Add(
                            op::Divide(op::Constant(), op::Constant()),
                            op::Add(op::Divide(op::Constant(), op::Constant()),
                                    op::Parameter())))));
  auto replica_groups_after = module->entry_computation()
                                  ->root_instruction()
                                  ->operands()[0]
                                  ->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
855 
// Subtracts a cross-module all-reduce (all_reduce_id=1) from a constant and
// then sums the result across replicas, once on device 0 and once on device 1.
// After the pass, each tuple element is expected to be a single all-reduce of
// (constant / constant) - parameter. The first CRS's replica groups are checked
// before/after via CompareReplicaGroups.
TEST_F(ArCrsCombinerTest, RewriteArSubtractCrs) {
  constexpr char kHloText[] = R"(
HloModule foobar

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.f32 = f32[] constant(123)

  %all-reduce.ar.1 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.f32,
      sharding={maximal device=0}
  %sub.1 = f32[]
      subtract(%constant.f32, %all-reduce.ar.1),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%sub.1),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.f32,
      sharding={maximal device=1}
  %sub.2 = f32[]
      subtract(%constant.f32, %all-reduce.ar.2),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%sub.2),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloText));

  // Re-reads the entry root on every call so post-pass checks see the
  // rewritten graph.
  auto entry_root = [&module] {
    return module->entry_computation()->root_instruction();
  };

  // Snapshot (by value) the replica groups of the first CRS.
  auto groups_before = entry_root()->operands()[0]->replica_groups();

  ArCrsCombiner combiner(2);
  EXPECT_TRUE(combiner.Run(module.get()).ValueOrDie());

  // Both tuple elements are expected to have the same rewritten shape.
  const auto rewritten_element = op::AllReduce(op::Subtract(
      op::Divide(op::Constant(), op::Constant()), op::Parameter()));
  EXPECT_THAT(entry_root(), op::Tuple(rewritten_element, rewritten_element));

  auto groups_after = entry_root()->operands()[0]->replica_groups();
  CompareReplicaGroups(groups_before, groups_after);
}
926 
// Two cross-module all-reduces (all_reduce_id 1 and 2) feed one CRS per device
// through a left-nested add chain: crs(add(add(ar1(p), const1), ar2(p))).
// After the pass, each tuple element is expected to be a single all-reduce of
// add(add(parameter, constant / constant), parameter). The first CRS's replica
// groups are checked before/after via CompareReplicaGroups.
TEST_F(ArCrsCombinerTest, RewriteMultipleARsLeft) {
  const char* module_str = R"(
HloModule foobar

%sum (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %const1 = f32[] constant(1)

  %ar11 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum,
      sharding={maximal device=0}
  %add11 = f32[]
      add(%ar11, %const1),
      sharding={maximal device=0}
  %ar12 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=2,
      to_apply=%sum,
      sharding={maximal device=0}
  %add12 = f32[]
      add(%add11, %ar12),
      sharding={maximal device=0}
  %crs1 = f32[]
      all-reduce(%add12),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=0}

  %ar21 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum,
      sharding={maximal device=1}
  %add21 = f32[]
      add(%ar21, %const1),
      sharding={maximal device=1}
  %ar22 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=2,
      to_apply=%sum,
      sharding={maximal device=1}
  %add22 = f32[]
      add(%add21, %ar22),
      sharding={maximal device=1}
  %crs2 = f32[]
      all-reduce(%add22),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%crs1, %crs2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  auto crs_before =
      module->entry_computation()->root_instruction()->operands()[0];
  // Copied by value so the post-pass comparison sees the original groups.
  auto replica_groups_before = crs_before->replica_groups();
  ArCrsCombiner combiner(2);
  auto changed = combiner.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::AllReduce(op::Add(
                            op::Add(op::Parameter(),
                                    op::Divide(op::Constant(), op::Constant())),
                            op::Parameter())),
                        op::AllReduce(op::Add(
                            op::Add(op::Parameter(),
                                    op::Divide(op::Constant(), op::Constant())),
                            op::Parameter()))));
  auto crs_after =
      module->entry_computation()->root_instruction()->operands()[0];
  auto replica_groups_after = crs_after->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
1018 
// Two cross-module all-reduces (all_reduce_id 1 and 2) feed one CRS per device
// through a right-nested add chain: crs(add(ar1(p), add(ar2(p), const1))).
// After the pass, each tuple element is expected to be a single all-reduce of
// add(parameter, add(parameter, constant / constant)). The first CRS's replica
// groups are checked before/after via CompareReplicaGroups.
TEST_F(ArCrsCombinerTest, RewriteMultipleARsRight) {
  const char* module_str = R"(
HloModule foobar

%sum (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %const1 = f32[] constant(1)

  %ar11 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum,
      sharding={maximal device=0}
  %ar12 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=2,
      to_apply=%sum,
      sharding={maximal device=0}
  %add11 = f32[]
      add(%ar12, %const1),
      sharding={maximal device=0}
  %add12 = f32[]
      add(%ar11, %add11),
      sharding={maximal device=0}
  %crs1 = f32[]
      all-reduce(%add12),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=0}

  %ar21 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum,
      sharding={maximal device=1}
  %ar22 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=2,
      to_apply=%sum,
      sharding={maximal device=1}
  %add21 = f32[]
      add(%ar22, %const1),
      sharding={maximal device=1}
  %add22 = f32[]
      add(%ar21, %add21),
      sharding={maximal device=1}
  %crs2 = f32[]
      all-reduce(%add22),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%crs1, %crs2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  auto crs_before =
      module->entry_computation()->root_instruction()->operands()[0];
  // Copied by value so the post-pass comparison sees the original groups.
  auto replica_groups_before = crs_before->replica_groups();
  ArCrsCombiner combiner(2);
  auto changed = combiner.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      op::Tuple(op::AllReduce(op::Add(
                    op::Parameter(),
                    op::Add(op::Parameter(),
                            op::Divide(op::Constant(), op::Constant())))),
                op::AllReduce(op::Add(
                    op::Parameter(),
                    op::Add(op::Parameter(),
                            op::Divide(op::Constant(), op::Constant()))))));

  auto crs_after =
      module->entry_computation()->root_instruction()->operands()[0];
  auto replica_groups_after = crs_after->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
1112 
// Every all-reduce in this module uses replica_groups={{0}} (a single
// replica). The test expects ArCrsCombiner::Run to report that it made no
// change.
TEST_F(ArCrsCombinerTest, OneReplicaDontRewrite) {
  constexpr char kHloText[] = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
  %p = bf16[] parameter(0)
  %constant.bf16 = bf16[] constant(1)

  %all-reduce.ar.1 = bf16[]
      all-reduce(%p),
      replica_groups={{0}},
      all_reduce_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=0}
  %convert.1 = f32[]
      convert(%all-reduce.ar.1),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%convert.1),
      replica_groups={{0}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0}},
      all_reduce_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=1}
  %convert.2 = f32[]
      convert(%all-reduce.ar.2),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%convert.2),
      replica_groups={{0}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloText));

  ArCrsCombiner combiner(2);
  const bool module_was_modified = combiner.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(module_was_modified);
}
1175 
1176 }  // namespace
1177 }  // namespace xla
1178