1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "absl/memory/memory.h"
17 #include "tensorflow/compiler/xla/debug_options_flags.h"
18 #include "tensorflow/compiler/xla/service/hlo_domain_isolator.h"
19 #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
20 #include "tensorflow/compiler/xla/service/hlo_domain_remover.h"
21 #include "tensorflow/compiler/xla/service/hlo_parser.h"
22 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
23 #include "tensorflow/compiler/xla/test.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 
27 namespace xla {
28 namespace {
29 
30 class HloDomainTest : public HloTestBase {
31  protected:
FindUserViaDomainPath(HloInstruction * instruction,HloInstruction * operand) const32   bool FindUserViaDomainPath(HloInstruction* instruction,
33                              HloInstruction* operand) const {
34     for (HloInstruction* user : operand->users()) {
35       if (user == instruction) {
36         return true;
37       }
38       if (user->opcode() == HloOpcode::kDomain &&
39           FindUserViaDomainPath(instruction, user)) {
40         return true;
41       }
42     }
43     return false;
44   }
45 
46   // Checks whether there is a kDomain instruction in the edge between the
47   // instruction and the operand.
HasDomainEdge(HloModule * module,absl::string_view instruction_name,absl::string_view operand_name)48   bool HasDomainEdge(HloModule* module, absl::string_view instruction_name,
49                      absl::string_view operand_name) {
50     HloInstruction* instruction = FindInstruction(module, instruction_name);
51     HloInstruction* operand = FindInstruction(module, operand_name);
52     CHECK_NE(instruction, nullptr);
53     CHECK_NE(operand, nullptr);
54     if (!instruction->IsUserOf(operand)) {
55       // If instruction is not an immediate user, we must find a path from
56       // operand to instruction anyway, otherwise there is a corruption.
57       if (FindUserViaDomainPath(instruction, operand)) {
58         return true;
59       }
60       LOG(FATAL) << "Bad HLO module generated across the '" << instruction_name
61                  << "' and '" << operand_name << "' instructions:\n"
62                  << module->ToString();
63     }
64     return false;
65   }
66 };
67 
// Dummy DomainMetadata implementation which creates kDomain boundaries around
// HLO instructions with the same metadata().op_name() values.
70 class OpNameMetadata : public DomainMetadata {
71  public:
OpNameMetadata(string opname)72   explicit OpNameMetadata(string opname) : opname_(std::move(opname)) {}
73 
Clone() const74   std::unique_ptr<DomainMetadata> Clone() const override {
75     return absl::make_unique<OpNameMetadata>(opname_);
76   }
77 
Kind() const78   absl::string_view Kind() const override { return KindName(); }
79 
Matches(const DomainMetadata & other) const80   bool Matches(const DomainMetadata& other) const override {
81     const OpNameMetadata* other_ptr =
82         dynamic_cast<const OpNameMetadata*>(&other);
83     if (other_ptr == nullptr) {
84       // If other is not a OpNameMetadata, then it is clearly a no match.
85       return false;
86     }
87     return opname_ == other_ptr->opname_;
88   }
89 
ToString() const90   string ToString() const override { return opname_; }
91 
KindName()92   static absl::string_view KindName() { return "opname"; }
93 
Hash() const94   size_t Hash() const override { return std::hash<string>()(opname_); }
95 
96  private:
97   string opname_;
98 };
99 
100 // Creator function for OpNameMetadata domains.
101 class OpNameDomainCreator {
102  public:
operator ()(HloInstruction * instruction,HloInstruction * root,HloInstruction * operand)103   HloInstruction* operator()(HloInstruction* instruction, HloInstruction* root,
104                              HloInstruction* operand) {
105     if (instruction->metadata().op_name() == root->metadata().op_name()) {
106       return nullptr;
107     }
108     std::unique_ptr<DomainMetadata> operand_side_metadata =
109         absl::make_unique<OpNameMetadata>(root->metadata().op_name());
110     std::unique_ptr<DomainMetadata> user_side_metadata =
111         absl::make_unique<OpNameMetadata>(instruction->metadata().op_name());
112     return operand->parent()->AddInstruction(HloInstruction::CreateDomain(
113         operand->shape(), operand, std::move(operand_side_metadata),
114         std::move(user_side_metadata)));
115   }
116 };
117 
OpNameDomainNormalizer(const DomainMetadata::Domain & domain,const DomainMetadata * metadata)118 Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain,
119                               const DomainMetadata* metadata) {
120   // Nothing to do for the particular use this test make of the OpName domains.
121   return Status::OK();
122 }
123 
TEST_F(HloDomainTest, CheckDomainLinks) {
  const char* const hlo_string = R"(
HloModule Module

ENTRY entry {
  p0 = (f32[4], f32[4]) parameter(0)
  a = f32[4] get-tuple-element(p0), index=0
  b = f32[4] get-tuple-element(p0), index=1
  c = f32[4] add(f32[4] a, f32[4] b), sharding={maximal device=1}
  d = f32[4] subtract(a, b), sharding={maximal device=1}
  e = f32[4] multiply(c, d), sharding={maximal device=1}
  ROOT f = (f32[4], f32[4], f32[4]) tuple(c, d, e)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  LOG(INFO) << "Original module:\n" << module->ToString();
  HloModule* const m = module.get();

  // c and d are pinned to device 1 but read the unsharded a and b, so those
  // edges cross a sharding boundary; e reads c and d from the same domain.
  HloDomainIsolator sharding_isolator(
      []() { return ShardingDomainCreator{}; });
  TF_ASSERT_OK_AND_ASSIGN(bool inserted, sharding_isolator.Run(m));
  EXPECT_TRUE(inserted);

  EXPECT_TRUE(HasDomainEdge(m, "c", "a"));
  EXPECT_TRUE(HasDomainEdge(m, "c", "b"));
  EXPECT_TRUE(HasDomainEdge(m, "d", "a"));
  EXPECT_TRUE(HasDomainEdge(m, "d", "b"));
  EXPECT_FALSE(HasDomainEdge(m, "e", "c"));
  EXPECT_FALSE(HasDomainEdge(m, "e", "d"));

  // Once the sharding domains are removed, every checked edge is direct again.
  HloDomainRemover sharding_remover(ShardingMetadata::KindName(),
                                    ShardingMetadata::NormalizeShardingDomain);
  TF_ASSERT_OK_AND_ASSIGN(bool removed, sharding_remover.Run(m));
  EXPECT_TRUE(removed);

  EXPECT_FALSE(HasDomainEdge(m, "c", "a"));
  EXPECT_FALSE(HasDomainEdge(m, "c", "b"));
  EXPECT_FALSE(HasDomainEdge(m, "d", "a"));
  EXPECT_FALSE(HasDomainEdge(m, "d", "b"));
  EXPECT_FALSE(HasDomainEdge(m, "e", "c"));
  EXPECT_FALSE(HasDomainEdge(m, "e", "d"));
}
166 
TEST_F(HloDomainTest, CheckNoDomainAddedIfNoSharding) {
  const char* const hlo_string = R"(
HloModule Module

ENTRY entry {
  p0 = (f32[4], f32[4]) parameter(0)
  a = f32[4] get-tuple-element(p0), index=0
  b = f32[4] get-tuple-element(p0), index=1
  c = f32[4] add(f32[4] a, f32[4] b)
  d = f32[4] subtract(a, b)
  e = f32[4] multiply(c, d)
  ROOT f = (f32[4], f32[4], f32[4]) tuple(c, d, e)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  LOG(INFO) << "Original module:\n" << module->ToString();

  // No instruction carries a sharding, so the isolator is expected to leave
  // the module untouched.
  HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; });
  TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get()));
  // EXPECT_FALSE (rather than EXPECT_TRUE(!x)) gives a clearer failure report.
  EXPECT_FALSE(isolator_changed);
}
190 
TEST_F(HloDomainTest, CheckDomainAroundIO) {
  const char* const hlo_string = R"(
HloModule Module

ENTRY entry {
  p0 = (f32[4]) parameter(0)
  a = f32[4] get-tuple-element(p0), index=0
  token0 = token[] after-all()
  b = (f32[4], u32[], token[]) send(a, token0), channel_id=1, sharding={maximal device=0}
  c = token[] send-done(b), channel_id=1, sharding={maximal device=0}
  d = (f32[4], u32[], token[]) recv(token0), channel_id=2, sharding={maximal device=0}
  e = (f32[4], token[]) recv-done(d), channel_id=2, sharding={maximal device=0}
  e_element = f32[4] get-tuple-element(e), index=0, sharding={maximal device=0}
  f = f32[4] add(a, e_element)
  g = f32[4] subtract(a, e_element)
  ROOT h = (f32[4], f32[4]) tuple(f, g)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  LOG(INFO) << "Original module:\n" << module->ToString();
  HloModule* const m = module.get();

  // The send/recv chain is pinned to device 0 while a, f and g are unsharded:
  // the edges crossing that boundary get a kDomain, edges within a side don't.
  HloDomainIsolator sharding_isolator(
      []() { return ShardingDomainCreator{}; });
  TF_ASSERT_OK_AND_ASSIGN(bool inserted, sharding_isolator.Run(m));
  EXPECT_TRUE(inserted);

  EXPECT_TRUE(HasDomainEdge(m, "b", "a"));
  EXPECT_TRUE(HasDomainEdge(m, "f", "e_element"));
  EXPECT_FALSE(HasDomainEdge(m, "a", "p0"));
  EXPECT_FALSE(HasDomainEdge(m, "c", "b"));
  EXPECT_FALSE(HasDomainEdge(m, "e", "d"));

  // Removing the sharding domains restores the direct boundary edges.
  HloDomainRemover sharding_remover(ShardingMetadata::KindName(),
                                    ShardingMetadata::NormalizeShardingDomain);
  TF_ASSERT_OK_AND_ASSIGN(bool removed, sharding_remover.Run(m));
  EXPECT_TRUE(removed);

  EXPECT_FALSE(HasDomainEdge(m, "b", "a"));
  EXPECT_FALSE(HasDomainEdge(m, "f", "e_element"));
}
232 
TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) {
  const char* const hlo_string = R"(
HloModule Module

ENTRY entry {
  token0 = token[] after-all(), sharding={maximal device=-1}
  a = (f32[4], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=-1}
  b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=-1}
  b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=-1}
  c = f32[4] add(b_element, b_element), sharding={maximal device=-1}
  d = (f32[4], u32[], token[]) send(c, token0), channel_id=2, sharding={maximal device=-1}
  ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=-1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  LOG(INFO) << "Original module:\n" << module->ToString();

  // Every instruction carries the same sharding (device=-1), so no boundary
  // exists and the isolator must report no change.
  HloDomainIsolator sharding_isolator(
      []() { return ShardingDomainCreator{}; });
  TF_ASSERT_OK_AND_ASSIGN(bool inserted, sharding_isolator.Run(module.get()));
  EXPECT_FALSE(inserted);
}
256 
TEST_F(HloDomainTest, CheckNormalizationOnPureIOComputation) {
  const char* const hlo_string = R"(
HloModule Module

ENTRY entry {
  token0 = token[] after-all(), sharding={maximal device=0}
  a = (f32[4], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=0}
  b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=0}
  b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=0}
  c = f32[4] add(b_element, b_element)
  d = (f32[4], u32[], token[]) send(c, token0), channel_id=2, sharding={maximal device=0}
  ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=0}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  LOG(INFO) << "Original module:\n" << module->ToString();

  // The module has no kDomain instructions, so the remover is expected to
  // report no change. Normalization must still assign the unsharded `c` to
  // device 0.
  HloDomainRemover remover(ShardingMetadata::KindName(),
                           ShardingMetadata::NormalizeShardingDomain);
  TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get()));
  EXPECT_FALSE(remover_changed);

  HloInstruction* add = FindInstruction(module.get(), "c");
  ASSERT_NE(add, nullptr);
  auto device = add->sharding_unique_device();
  // ASSERT (not EXPECT): dereferencing an empty optional below would be UB.
  ASSERT_TRUE(device.has_value());
  EXPECT_EQ(*device, 0);
}
287 
TEST_F(HloDomainTest, CheckMultiDomainLinks) {
  const char* const hlo_string = R"(
HloModule Module

ENTRY entry {
  p0 = (f32[4], f32[4]) parameter(0)
  a = f32[4] get-tuple-element(p0), index=0
  b = f32[4] get-tuple-element(p0), index=1
  c = f32[4] add(a, b), sharding={maximal device=1}
  d = f32[4] subtract(a, c), sharding={maximal device=1}, metadata={op_name="D"}
  e = f32[4] multiply(c, d), sharding={maximal device=1}, metadata={op_name="D"}
  f = f32[4] add(e, c), sharding={maximal device=1}
  ROOT g = (f32[4], f32[4], f32[4]) tuple(c, d, f)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  LOG(INFO) << "Original module:\n" << module->ToString();
  HloModule* const m = module.get();

  // First layer: sharding domains.
  HloDomainIsolator sharding_isolator(
      []() { return ShardingDomainCreator{}; });
  TF_ASSERT_OK_AND_ASSIGN(bool sharding_inserted, sharding_isolator.Run(m));
  EXPECT_TRUE(sharding_inserted);

  // Second layer: op_name domains on top (d and e carry op_name "D").
  HloDomainIsolator opname_isolator([]() { return OpNameDomainCreator{}; });
  TF_ASSERT_OK_AND_ASSIGN(bool opname_inserted, opname_isolator.Run(m));
  EXPECT_TRUE(opname_inserted);

  EXPECT_TRUE(HasDomainEdge(m, "c", "a"));
  EXPECT_TRUE(HasDomainEdge(m, "c", "b"));
  EXPECT_TRUE(HasDomainEdge(m, "d", "a"));
  EXPECT_TRUE(HasDomainEdge(m, "d", "c"));
  EXPECT_FALSE(HasDomainEdge(m, "e", "d"));

  // Strip both layers, sharding first, then op_name.
  HloDomainRemover sharding_remover(ShardingMetadata::KindName(),
                                    ShardingMetadata::NormalizeShardingDomain);
  TF_ASSERT_OK_AND_ASSIGN(bool sharding_removed, sharding_remover.Run(m));
  EXPECT_TRUE(sharding_removed);

  HloDomainRemover opname_remover(OpNameMetadata::KindName(),
                                  OpNameDomainNormalizer);
  TF_ASSERT_OK_AND_ASSIGN(bool opname_removed, opname_remover.Run(m));
  EXPECT_TRUE(opname_removed);

  EXPECT_FALSE(HasDomainEdge(m, "c", "a"));
  EXPECT_FALSE(HasDomainEdge(m, "c", "b"));
  EXPECT_FALSE(HasDomainEdge(m, "d", "a"));
  EXPECT_FALSE(HasDomainEdge(m, "d", "c"));
}
341 
TEST_F(HloDomainTest,CheckNormalizationOnInfeedTuple)342 TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) {
343   const char* const hlo_string = R"(
344 HloModule Module
345 
346 ENTRY entry {
347   token0 = token[] after-all()
348   infeed = ((f32[4], f32[4]), token[]) infeed(token0),
349     sharding={{maximal device=1}, {maximal device=0}, {maximal device=0}}
350   infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0,
351     sharding={{maximal device=1}, {maximal device=0}}
352   gte0 = f32[4] get-tuple-element(infeed.data), index=0
353   gte1 = f32[4] get-tuple-element(infeed.data), index=1
354   copy0 = f32[4] copy(gte0)
355   copy1 = f32[4] copy(gte1)
356   ROOT add = f32[4] add(copy0, copy1)
357 }
358 )";
359 
360   TF_ASSERT_OK_AND_ASSIGN(auto module,
361                           ParseAndReturnVerifiedModule(hlo_string));
362   LOG(INFO) << "Original module:\n" << module->ToString();
363 
364   HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; });
365   TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get()));
366   EXPECT_TRUE(isolator_changed);
367 
368   EXPECT_TRUE(HasDomainEdge(module.get(), "infeed.data", "infeed"));
369   EXPECT_FALSE(HasDomainEdge(module.get(), "copy0", "gte0"));
370   EXPECT_FALSE(HasDomainEdge(module.get(), "copy1", "gte1"));
371 
372   // Inject unassigned tuple/gte within the infeed domain, to simulate the
373   // HLO passes adding unexpected instructions.
374   //
375   //            infeed
376   //              |
377   //          infeed.data (tuple element 0 of infeed)
378   //           /      \
379   //         GTE0    GTE1
380   //         /          \
381   //       COPY0       COPY1
382   //          \         /
383   //           \       /
384   //             TUPLE
385   //               |
386   HloInstruction* infeed_data = FindInstruction(module.get(), "infeed.data");
387   ASSERT_NE(infeed_data, nullptr);
388 
389   auto infeed_data_users = infeed_data->users();
390   HloInstruction* new_gte0 = infeed_data->parent()->AddInstruction(
391       HloInstruction::CreateGetTupleElement(
392           ShapeUtil::GetTupleElementShape(infeed_data->shape(), 0), infeed_data,
393           0));
394   HloInstruction* new_copy0 =
395       infeed_data->parent()->AddInstruction(HloInstruction::CreateUnary(
396           new_gte0->shape(), HloOpcode::kCopy, new_gte0));
397   HloInstruction* new_gte1 = infeed_data->parent()->AddInstruction(
398       HloInstruction::CreateGetTupleElement(
399           ShapeUtil::GetTupleElementShape(infeed_data->shape(), 1), infeed_data,
400           1));
401   HloInstruction* new_copy1 =
402       infeed_data->parent()->AddInstruction(HloInstruction::CreateUnary(
403           new_gte1->shape(), HloOpcode::kCopy, new_gte1));
404   HloInstruction* new_tuple = infeed_data->parent()->AddInstruction(
405       HloInstruction::CreateTuple({new_copy0, new_copy1}));
406   for (HloInstruction* user : infeed_data_users) {
407     TF_EXPECT_OK(infeed_data->ReplaceUseWith(user, new_tuple));
408   }
409 
410   HloDomainRemover remover(ShardingMetadata::KindName(),
411                            ShardingMetadata::NormalizeShardingDomain);
412   TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get()));
413   EXPECT_TRUE(remover_changed);
414 
415   struct Assignment {
416     HloInstruction* instruction;
417     int64 device;
418   } assignments[] = {
419       {new_gte0, 1},
420       {new_copy0, 1},
421       {new_gte1, 0},
422       {new_copy1, 0},
423   };
424   for (auto& assignment : assignments) {
425     auto device = assignment.instruction->sharding_unique_device();
426     ASSERT_TRUE(device.has_value());
427     EXPECT_EQ(*device, assignment.device);
428   }
429   EXPECT_TRUE(new_tuple->has_sharding());
430   EXPECT_EQ(
431       new_tuple->sharding(),
432       HloSharding::Tuple(new_tuple->shape(), {HloSharding::AssignDevice(1),
433                                               HloSharding::AssignDevice(0)}));
434 }
435 
TEST_F(HloDomainTest, EmptyRootDomain) {
  const char* const hlo_string = R"(
HloModule Module

ENTRY entry {
  %param = f32[1] parameter(0), sharding={maximal device=0}
  %tuple = (f32[1]) tuple(%param),
    sharding={maximal device=1}
  ROOT %gte = f32[1] get-tuple-element(%tuple), index=0,
    sharding={maximal device=1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; });
  TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get()));
  EXPECT_TRUE(isolator_changed);

  EXPECT_TRUE(HasDomainEdge(module.get(), "tuple", "param"));
  EXPECT_FALSE(HasDomainEdge(module.get(), "gte", "tuple"));

  // Remove %tuple and %gte (tuple simplification). If the domain edge checked
  // above holds, operand 0 of %tuple is the kDomain in front of %param, so
  // that kDomain becomes the root and its domain ends up empty.
  HloInstruction* gte = FindInstruction(module.get(), "gte");
  HloInstruction* tuple = FindInstruction(module.get(), "tuple");
  ASSERT_NE(gte, nullptr);
  ASSERT_NE(tuple, nullptr);
  module->entry_computation()->set_root_instruction(tuple->mutable_operand(0));
  TF_EXPECT_OK(module->entry_computation()->RemoveInstruction(gte));
  TF_EXPECT_OK(module->entry_computation()->RemoveInstruction(tuple));

  HloDomainRemover remover(ShardingMetadata::KindName(),
                           ShardingMetadata::NormalizeShardingDomain);
  TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get()));
  EXPECT_TRUE(remover_changed);

  const HloInstruction* root = module->entry_computation()->root_instruction();
  // ASSERT (not EXPECT): sharding() must not be called when none is set.
  ASSERT_TRUE(root->has_sharding());
  EXPECT_EQ(root->sharding(), HloSharding::AssignDevice(1));
}
474 
475 // Tests that text dumps of domain instructions can be parsed back, in the
476 // specific case of null shardings.
TEST_F(HloDomainTest, DumpParseNullSharding) {
  auto builder = HloComputation::Builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {});
  // Both sides of the domain carry a null sharding: the case under test.
  auto sharding_md_0 = absl::make_unique<ShardingMetadata>(nullptr);
  auto sharding_md_1 = absl::make_unique<ShardingMetadata>(nullptr);
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p"));
  HloInstruction* domain = builder.AddInstruction(HloInstruction::CreateDomain(
      shape, param, std::move(sharding_md_0), std::move(sharding_md_1)));
  builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, domain, domain));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const string hlo_string = module->ToString();
  // TF_ASSERT_OK reports the parser's error status on failure, unlike
  // ASSERT_TRUE(status.ok()).
  TF_ASSERT_OK(ParseAndReturnVerifiedModule(hlo_string).status());
}
495 
496 // Tuple inputs are domain instructions.
TEST_F(HloDomainTest, DomainTuple) {
  const char* const hlo_string = R"(
HloModule Module

ENTRY entry {
  p0 = f32[4] parameter(0), sharding={maximal device=0}
  cst = u32[] constant(0), sharding={maximal device=1}
  tpl = (u32[], f32[4]) tuple(cst, p0),
    sharding={{maximal device=1}, {maximal device=0}}
  ROOT gte = f32[4] get-tuple-element(tpl), index=1, sharding={maximal device=0}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; });
  TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get()));
  EXPECT_TRUE(isolator_changed);

  // Clear sharding of tpl instruction, in order to test domain sharding
  // application.
  HloInstruction* tpl = FindInstruction(module.get(), "tpl");
  ASSERT_NE(tpl, nullptr);
  tpl->clear_sharding();

  HloDomainRemover remover(ShardingMetadata::KindName(),
                           ShardingMetadata::NormalizeShardingDomain);
  TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get()));
  EXPECT_TRUE(remover_changed);

  // The remover must have restored a tuple sharding on tpl. ASSERT first so
  // sharding() is never called on an unsharded instruction.
  ASSERT_TRUE(tpl->has_sharding());
  EXPECT_EQ(HloSharding::Tuple(tpl->shape(), {HloSharding::AssignDevice(1),
                                              HloSharding::AssignDevice(0)}),
            tpl->sharding());
}
531 
TEST_F(HloDomainTest, MultiDomainMultiUser) {
  const char* const hlo_string = R"(
  HloModule Module

ENTRY %entry (p0: (f32[4], f32[4])) -> (f32[4], f32[4], f32[4]) {
  %p0 = (f32[4], f32[4]) parameter(0)
  %a = f32[4]{0} get-tuple-element(%p0), index=0
  %domain = f32[4] domain(%a),
    domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
  %b = f32[4] get-tuple-element(%p0), index=1
  %domain.1 = f32[4] domain(%b),
    domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
  %c = f32[4] add(%domain, %domain.1), sharding={maximal device=1}
  %domain.2 = f32[4] domain(%c),
    domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
  %d = f32[4] subtract(%domain, %c),
    sharding={maximal device=1}, metadata={op_name="D"}
  %domain.3 = f32[4] domain(%d),
    domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
  %e = f32[4] multiply(%c, %d),
    sharding={maximal device=1}, metadata={op_name="D"}
  %f = f32[4] add(f32[4]{0} %e, f32[4]{0} %c), sharding={maximal device=1}
  %domain.4 = f32[4]{0} domain(%f),
    domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
  ROOT %g = (f32[4], f32[4], f32[4]) tuple(%domain.2, %domain.3, %domain.4)
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  LOG(INFO) << "Original module:\n" << module->ToString();
  HloModule* const m = module.get();

  // The module already contains sharding kDomain instructions; layer op_name
  // domains on top of them.
  HloDomainIsolator opname_isolator([]() { return OpNameDomainCreator{}; });
  TF_ASSERT_OK_AND_ASSIGN(bool opname_inserted, opname_isolator.Run(m));
  EXPECT_TRUE(opname_inserted);

  EXPECT_TRUE(HasDomainEdge(m, "c", "a"));
  EXPECT_TRUE(HasDomainEdge(m, "c", "b"));
  EXPECT_TRUE(HasDomainEdge(m, "d", "a"));
  EXPECT_TRUE(HasDomainEdge(m, "d", "c"));
  EXPECT_FALSE(HasDomainEdge(m, "e", "d"));

  // Strip both kinds of domains, sharding first, then op_name.
  HloDomainRemover sharding_remover(ShardingMetadata::KindName(),
                                    ShardingMetadata::NormalizeShardingDomain);
  TF_ASSERT_OK_AND_ASSIGN(bool sharding_removed, sharding_remover.Run(m));
  EXPECT_TRUE(sharding_removed);

  HloDomainRemover opname_remover(OpNameMetadata::KindName(),
                                  OpNameDomainNormalizer);
  TF_ASSERT_OK_AND_ASSIGN(bool opname_removed, opname_remover.Run(m));
  EXPECT_TRUE(opname_removed);

  EXPECT_FALSE(HasDomainEdge(m, "c", "a"));
  EXPECT_FALSE(HasDomainEdge(m, "c", "b"));
  EXPECT_FALSE(HasDomainEdge(m, "d", "a"));
  EXPECT_FALSE(HasDomainEdge(m, "d", "c"));
}
591 
592 // Emulate instructions inserted at top and bottom within nested tuple domain.
TEST_F(HloDomainTest,DomainTupleTopBottomInsert)593 TEST_F(HloDomainTest, DomainTupleTopBottomInsert) {
594   const char* const hlo_string = R"(
595 HloModule Module
596 
597 ENTRY entry {
598   p0 = f32[4] parameter(0), sharding={maximal device=1}
599   p1 = (f32[5], f32[6]) parameter(1),
600     sharding={{maximal device=1}, {maximal device=0}}
601   tuple.0 = (f32[4], (f32[5], f32[6])) tuple(p0, p1),
602     sharding={{maximal device=1}, {maximal device=1}, {maximal device=0}}
603   ROOT res = (f32[5], f32[6]) get-tuple-element(tuple.0), index=1,
604     sharding={{maximal device=1}, {maximal device=0}}
605 }
606 )";
607 
608   TF_ASSERT_OK_AND_ASSIGN(auto module,
609                           ParseAndReturnVerifiedModule(hlo_string));
610 
611   HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; });
612   TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get()));
613   EXPECT_TRUE(isolator_changed);
614 
615   // Clear sharding of tuple.0 instruction, in order to test domain sharding
616   // application.
617   auto tuple0 = FindInstruction(module.get(), "tuple.0");
618   tuple0->clear_sharding();
619 
620   // Insert the following instructons above and below tuple.0, to emulate other
621   // passes effects:
622   //                 COPY.0
623   //             \    /
624   //            TUPLE.0
625   //              /    \
626   //           COPY.1   \
627   //            /        \
628   //         GTE.0      GTE.1
629   //           |          |
630   //           |        COPY.2
631   //            \       /
632   //             \     /
633   //             TUPLE.1
634   //                |
635   auto tuple0_users = tuple0->users();
636   auto computation = tuple0->parent();
637   HloInstruction* copy0 = computation->AddInstruction(
638       HloInstruction::CreateUnary(tuple0->operand(1)->shape(), HloOpcode::kCopy,
639                                   tuple0->mutable_operand(1)));
640   TF_EXPECT_OK(tuple0->ReplaceOperandWith(1, copy0));
641 
642   HloInstruction* copy1 = computation->AddInstruction(
643       HloInstruction::CreateUnary(tuple0->shape(), HloOpcode::kCopy, tuple0));
644   HloInstruction* gte0 =
645       computation->AddInstruction(HloInstruction::CreateGetTupleElement(
646           ShapeUtil::GetTupleElementShape(copy1->shape(), 0), copy1, 0));
647   HloInstruction* gte1 =
648       computation->AddInstruction(HloInstruction::CreateGetTupleElement(
649           ShapeUtil::GetTupleElementShape(tuple0->shape(), 1), tuple0, 1));
650   HloInstruction* copy2 = computation->AddInstruction(
651       HloInstruction::CreateUnary(gte1->shape(), HloOpcode::kCopy, gte1));
652   HloInstruction* tuple1 =
653       computation->AddInstruction(HloInstruction::CreateTuple({gte0, copy2}));
654 
655   for (HloInstruction* user : tuple0_users) {
656     TF_EXPECT_OK(tuple0->ReplaceUseWith(user, tuple1));
657   }
658 
659   HloDomainRemover remover(ShardingMetadata::KindName(),
660                            ShardingMetadata::NormalizeShardingDomain);
661   TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get()));
662   EXPECT_TRUE(remover_changed);
663 
664   EXPECT_TRUE(tuple0->has_sharding());
665   EXPECT_EQ(HloSharding::Tuple(tuple0->shape(), {HloSharding::AssignDevice(1),
666                                                  HloSharding::AssignDevice(1),
667                                                  HloSharding::AssignDevice(0)}),
668             tuple0->sharding());
669 
670   EXPECT_TRUE(copy0->has_sharding());
671   EXPECT_EQ(HloSharding::Tuple(copy0->shape(), {HloSharding::AssignDevice(1),
672                                                 HloSharding::AssignDevice(0)}),
673             copy0->sharding());
674 
675   // copy1 has partial information only from gte.0, so in the end it gets no
676   // sharding at all. During propagation it does propagate the information from
677   // gte.0 though, enabling Tuple.0 to be fully sharded.
678   EXPECT_FALSE(copy1->has_sharding());
679 
680   EXPECT_TRUE(gte0->has_sharding());
681   EXPECT_EQ(HloSharding::AssignDevice(1), gte0->sharding());
682 
683   EXPECT_TRUE(gte1->has_sharding());
684   EXPECT_EQ(HloSharding::Tuple(gte1->shape(), {HloSharding::AssignDevice(1),
685                                                HloSharding::AssignDevice(0)}),
686             gte1->sharding());
687 
688   EXPECT_TRUE(copy2->has_sharding());
689   EXPECT_EQ(HloSharding::Tuple(copy2->shape(), {HloSharding::AssignDevice(1),
690                                                 HloSharding::AssignDevice(0)}),
691             copy2->sharding());
692 
693   EXPECT_TRUE(tuple1->has_sharding());
694   EXPECT_EQ(tuple0->sharding(), tuple1->sharding());
695 }
696 
697 }  // namespace
698 }  // namespace xla
699