1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "absl/memory/memory.h"
17 #include "tensorflow/compiler/xla/debug_options_flags.h"
18 #include "tensorflow/compiler/xla/service/hlo_domain_isolator.h"
19 #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
20 #include "tensorflow/compiler/xla/service/hlo_domain_remover.h"
21 #include "tensorflow/compiler/xla/service/hlo_parser.h"
22 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
23 #include "tensorflow/compiler/xla/test.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 #include "tensorflow/core/lib/core/status_test_util.h"
26
27 namespace xla {
28 namespace {
29
30 class HloDomainTest : public HloTestBase {
31 protected:
FindUserViaDomainPath(HloInstruction * instruction,HloInstruction * operand) const32 bool FindUserViaDomainPath(HloInstruction* instruction,
33 HloInstruction* operand) const {
34 for (HloInstruction* user : operand->users()) {
35 if (user == instruction) {
36 return true;
37 }
38 if (user->opcode() == HloOpcode::kDomain &&
39 FindUserViaDomainPath(instruction, user)) {
40 return true;
41 }
42 }
43 return false;
44 }
45
46 // Checks whether there is a kDomain instruction in the edge between the
47 // instruction and the operand.
HasDomainEdge(HloModule * module,absl::string_view instruction_name,absl::string_view operand_name)48 bool HasDomainEdge(HloModule* module, absl::string_view instruction_name,
49 absl::string_view operand_name) {
50 HloInstruction* instruction = FindInstruction(module, instruction_name);
51 HloInstruction* operand = FindInstruction(module, operand_name);
52 CHECK_NE(instruction, nullptr);
53 CHECK_NE(operand, nullptr);
54 if (!instruction->IsUserOf(operand)) {
55 // If instruction is not an immediate user, we must find a path from
56 // operand to instruction anyway, otherwise there is a corruption.
57 if (FindUserViaDomainPath(instruction, operand)) {
58 return true;
59 }
60 LOG(FATAL) << "Bad HLO module generated across the '" << instruction_name
61 << "' and '" << operand_name << "' instructions:\n"
62 << module->ToString();
63 }
64 return false;
65 }
66 };
67
68 // Dummy DomainMetadata implementation which create kDomain boundaries around
69 // HLO instructions with the same metadata().op_name() values.
70 class OpNameMetadata : public DomainMetadata {
71 public:
OpNameMetadata(string opname)72 explicit OpNameMetadata(string opname) : opname_(std::move(opname)) {}
73
Clone() const74 std::unique_ptr<DomainMetadata> Clone() const override {
75 return absl::make_unique<OpNameMetadata>(opname_);
76 }
77
Kind() const78 absl::string_view Kind() const override { return KindName(); }
79
Matches(const DomainMetadata & other) const80 bool Matches(const DomainMetadata& other) const override {
81 const OpNameMetadata* other_ptr =
82 dynamic_cast<const OpNameMetadata*>(&other);
83 if (other_ptr == nullptr) {
84 // If other is not a OpNameMetadata, then it is clearly a no match.
85 return false;
86 }
87 return opname_ == other_ptr->opname_;
88 }
89
ToString() const90 string ToString() const override { return opname_; }
91
KindName()92 static absl::string_view KindName() { return "opname"; }
93
Hash() const94 size_t Hash() const override { return std::hash<string>()(opname_); }
95
96 private:
97 string opname_;
98 };
99
100 // Creator function for OpNameMetadata domains.
101 class OpNameDomainCreator {
102 public:
operator ()(HloInstruction * instruction,HloInstruction * root,HloInstruction * operand)103 HloInstruction* operator()(HloInstruction* instruction, HloInstruction* root,
104 HloInstruction* operand) {
105 if (instruction->metadata().op_name() == root->metadata().op_name()) {
106 return nullptr;
107 }
108 std::unique_ptr<DomainMetadata> operand_side_metadata =
109 absl::make_unique<OpNameMetadata>(root->metadata().op_name());
110 std::unique_ptr<DomainMetadata> user_side_metadata =
111 absl::make_unique<OpNameMetadata>(instruction->metadata().op_name());
112 return operand->parent()->AddInstruction(HloInstruction::CreateDomain(
113 operand->shape(), operand, std::move(operand_side_metadata),
114 std::move(user_side_metadata)));
115 }
116 };
117
OpNameDomainNormalizer(const DomainMetadata::Domain & domain,const DomainMetadata * metadata)118 Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain,
119 const DomainMetadata* metadata) {
120 // Nothing to do for the particular use this test make of the OpName domains.
121 return Status::OK();
122 }
123
TEST_F(HloDomainTest, CheckDomainLinks) {
  const char* const hlo_string = R"(
HloModule Module

ENTRY entry {
  p0 = (f32[4], f32[4]) parameter(0)
  a = f32[4] get-tuple-element(p0), index=0
  b = f32[4] get-tuple-element(p0), index=1
  c = f32[4] add(f32[4] a, f32[4] b), sharding={maximal device=1}
  d = f32[4] subtract(a, b), sharding={maximal device=1}
  e = f32[4] multiply(c, d), sharding={maximal device=1}
  ROOT f = (f32[4], f32[4], f32[4]) tuple(c, d, e)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  LOG(INFO) << "Original module:\n" << module->ToString();

  // Shorthand: does the `user` <- `operand` edge go through a kDomain?
  auto has_domain = [&](absl::string_view user, absl::string_view operand) {
    return HasDomainEdge(module.get(), user, operand);
  };

  HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; });
  TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get()));
  EXPECT_TRUE(isolator_changed);

  // c and d are sharded while a and b are not, so those edges are fenced.
  // e consumes c and d under the same sharding, so no fence is needed there.
  EXPECT_TRUE(has_domain("c", "a"));
  EXPECT_TRUE(has_domain("c", "b"));
  EXPECT_TRUE(has_domain("d", "a"));
  EXPECT_TRUE(has_domain("d", "b"));
  EXPECT_FALSE(has_domain("e", "c"));
  EXPECT_FALSE(has_domain("e", "d"));

  HloDomainRemover remover(ShardingMetadata::KindName(),
                           ShardingMetadata::NormalizeShardingDomain);
  TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get()));
  EXPECT_TRUE(remover_changed);

  // After removal every edge is a direct use again.
  const char* const kEdges[][2] = {{"c", "a"}, {"c", "b"}, {"d", "a"},
                                   {"d", "b"}, {"e", "c"}, {"e", "d"}};
  for (const auto& edge : kEdges) {
    EXPECT_FALSE(has_domain(edge[0], edge[1])) << edge[0] << " <- " << edge[1];
  }
}
166
// With no sharding on any instruction there is no boundary to isolate, so the
// isolator must leave the module unchanged.
TEST_F(HloDomainTest, CheckNoDomainAddedIfNoSharding) {
  const char* const hlo_string = R"(
HloModule Module

ENTRY entry {
  p0 = (f32[4], f32[4]) parameter(0)
  a = f32[4] get-tuple-element(p0), index=0
  b = f32[4] get-tuple-element(p0), index=1
  c = f32[4] add(f32[4] a, f32[4] b)
  d = f32[4] subtract(a, b)
  e = f32[4] multiply(c, d)
  ROOT f = (f32[4], f32[4], f32[4]) tuple(c, d, e)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  LOG(INFO) << "Original module:\n" << module->ToString();

  HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; });
  TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get()));
  // EXPECT_FALSE reports the actual value on failure; EXPECT_TRUE(!x) does not.
  EXPECT_FALSE(isolator_changed);
}
190
TEST_F(HloDomainTest, CheckDomainAroundIO) {
  const char* const hlo_string = R"(
HloModule Module

ENTRY entry {
  p0 = (f32[4]) parameter(0)
  a = f32[4] get-tuple-element(p0), index=0
  token0 = token[] after-all()
  b = (f32[4], u32[], token[]) send(a, token0), channel_id=1, sharding={maximal device=0}
  c = token[] send-done(b), channel_id=1, sharding={maximal device=0}
  d = (f32[4], u32[], token[]) recv(token0), channel_id=2, sharding={maximal device=0}
  e = (f32[4], token[]) recv-done(d), channel_id=2, sharding={maximal device=0}
  e_element = f32[4] get-tuple-element(e), index=0, sharding={maximal device=0}
  f = f32[4] add(a, e_element)
  g = f32[4] subtract(a, e_element)
  ROOT h = (f32[4], f32[4]) tuple(f, g)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloModule* const hlo_module = module.get();
  LOG(INFO) << "Original module:\n" << hlo_module->ToString();

  // Isolation: the sharded send/recv chain gets fenced off from the
  // unsharded instructions around it, but not internally.
  HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; });
  TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(hlo_module));
  EXPECT_TRUE(isolator_changed);

  EXPECT_TRUE(HasDomainEdge(hlo_module, "b", "a"));
  EXPECT_TRUE(HasDomainEdge(hlo_module, "f", "e_element"));
  EXPECT_FALSE(HasDomainEdge(hlo_module, "a", "p0"));
  EXPECT_FALSE(HasDomainEdge(hlo_module, "c", "b"));
  EXPECT_FALSE(HasDomainEdge(hlo_module, "e", "d"));

  // Removal: the fences around the I/O instructions disappear again.
  HloDomainRemover remover(ShardingMetadata::KindName(),
                           ShardingMetadata::NormalizeShardingDomain);
  TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(hlo_module));
  EXPECT_TRUE(remover_changed);

  EXPECT_FALSE(HasDomainEdge(hlo_module, "b", "a"));
  EXPECT_FALSE(HasDomainEdge(hlo_module, "f", "e_element"));
}
232
TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) {
  const char* const hlo_string = R"(
HloModule Module

ENTRY entry {
  token0 = token[] after-all(), sharding={maximal device=-1}
  a = (f32[4], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=-1}
  b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=-1}
  b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=-1}
  c = f32[4] add(b_element, b_element), sharding={maximal device=-1}
  d = (f32[4], u32[], token[]) send(c, token0), channel_id=2, sharding={maximal device=-1}
  ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=-1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  LOG(INFO) << "Original module:\n" << module->ToString();

  // Every instruction carries the same sharding, so there is no boundary at
  // which the isolator could insert a kDomain.
  HloDomainIsolator sharding_isolator(
      []() { return ShardingDomainCreator{}; });
  TF_ASSERT_OK_AND_ASSIGN(bool changed, sharding_isolator.Run(module.get()));
  EXPECT_FALSE(changed);
}
256
// Runs the sharding remover on a module containing no kDomain instructions.
// The remover reports no change, but the test expects the unsharded `c` to end
// up assigned to device 0 like the rest of the computation.
TEST_F(HloDomainTest, CheckNormalizationOnPureIOComputation) {
  const char* const hlo_string = R"(
HloModule Module

ENTRY entry {
  token0 = token[] after-all(), sharding={maximal device=0}
  a = (f32[4], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=0}
  b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=0}
  b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=0}
  c = f32[4] add(b_element, b_element)
  d = (f32[4], u32[], token[]) send(c, token0), channel_id=2, sharding={maximal device=0}
  ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=0}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  LOG(INFO) << "Original module:\n" << module->ToString();

  HloDomainRemover remover(ShardingMetadata::KindName(),
                           ShardingMetadata::NormalizeShardingDomain);
  TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get()));
  EXPECT_FALSE(remover_changed);

  HloInstruction* add = FindInstruction(module.get(), "c");
  ASSERT_NE(add, nullptr);
  auto device = add->sharding_unique_device();
  // Fatal assertion: dereferencing an empty optional below would be undefined
  // behavior rather than a clean test failure.
  ASSERT_TRUE(device.has_value());
  EXPECT_EQ(*device, 0);
}
287
TEST_F(HloDomainTest, CheckMultiDomainLinks) {
  const char* const hlo_string = R"(
HloModule Module

ENTRY entry {
  p0 = (f32[4], f32[4]) parameter(0)
  a = f32[4] get-tuple-element(p0), index=0
  b = f32[4] get-tuple-element(p0), index=1
  c = f32[4] add(a, b), sharding={maximal device=1}
  d = f32[4] subtract(a, c), sharding={maximal device=1}, metadata={op_name="D"}
  e = f32[4] multiply(c, d), sharding={maximal device=1}, metadata={op_name="D"}
  f = f32[4] add(e, c), sharding={maximal device=1}
  ROOT g = (f32[4], f32[4], f32[4]) tuple(c, d, f)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloModule* const m = module.get();
  LOG(INFO) << "Original module:\n" << m->ToString();

  // Layer two domain kinds on the same module: sharding first, then op_name.
  HloDomainIsolator sharding_isolator([]() { return ShardingDomainCreator{}; });
  TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed,
                          sharding_isolator.Run(m));
  EXPECT_TRUE(sharding_isolator_changed);

  HloDomainIsolator opname_isolator([]() { return OpNameDomainCreator{}; });
  TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed,
                          opname_isolator.Run(m));
  EXPECT_TRUE(opname_isolator_changed);

  // c <- a/b cross the sharding boundary; d <- a/c cross the op_name "D"
  // boundary; e <- d stays within op_name "D".
  EXPECT_TRUE(HasDomainEdge(m, "c", "a"));
  EXPECT_TRUE(HasDomainEdge(m, "c", "b"));
  EXPECT_TRUE(HasDomainEdge(m, "d", "a"));
  EXPECT_TRUE(HasDomainEdge(m, "d", "c"));
  EXPECT_FALSE(HasDomainEdge(m, "e", "d"));

  // Remove both domain kinds, in the same order they were added.
  HloDomainRemover sharding_remover(ShardingMetadata::KindName(),
                                    ShardingMetadata::NormalizeShardingDomain);
  TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed,
                          sharding_remover.Run(m));
  EXPECT_TRUE(sharding_remover_changed);

  HloDomainRemover opname_remover(OpNameMetadata::KindName(),
                                  OpNameDomainNormalizer);
  TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed, opname_remover.Run(m));
  EXPECT_TRUE(opname_remover_changed);

  const char* const kFormerlyFencedEdges[][2] = {
      {"c", "a"}, {"c", "b"}, {"d", "a"}, {"d", "c"}};
  for (const auto& edge : kFormerlyFencedEdges) {
    EXPECT_FALSE(HasDomainEdge(m, edge[0], edge[1]))
        << edge[0] << " <- " << edge[1];
  }
}
341
// Checks that sharding normalization assigns devices to unsharded
// tuple/GTE/copy instructions injected inside an infeed's sharding domain.
TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) {
  const char* const hlo_string = R"(
HloModule Module

ENTRY entry {
  token0 = token[] after-all()
  infeed = ((f32[4], f32[4]), token[]) infeed(token0),
    sharding={{maximal device=1}, {maximal device=0}, {maximal device=0}}
  infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0,
    sharding={{maximal device=1}, {maximal device=0}}
  gte0 = f32[4] get-tuple-element(infeed.data), index=0
  gte1 = f32[4] get-tuple-element(infeed.data), index=1
  copy0 = f32[4] copy(gte0)
  copy1 = f32[4] copy(gte1)
  ROOT add = f32[4] add(copy0, copy1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  LOG(INFO) << "Original module:\n" << module->ToString();

  HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; });
  TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get()));
  EXPECT_TRUE(isolator_changed);

  EXPECT_TRUE(HasDomainEdge(module.get(), "infeed.data", "infeed"));
  EXPECT_FALSE(HasDomainEdge(module.get(), "copy0", "gte0"));
  EXPECT_FALSE(HasDomainEdge(module.get(), "copy1", "gte1"));

  // Inject unassigned tuple/gte within the infeed domain, to simulate HLO
  // passes adding unexpected instructions. The injected edges are:
  //
  //   infeed.data -> GTE0 -> COPY0 -> TUPLE (operand 0)
  //   infeed.data -> GTE1 -> COPY1 -> TUPLE (operand 1)
  //
  // TUPLE then replaces infeed.data in all of infeed.data's previous users.
  HloInstruction* infeed_data = FindInstruction(module.get(), "infeed.data");
  ASSERT_NE(infeed_data, nullptr);
  HloComputation* computation = infeed_data->parent();

  // Snapshot the users now: the GTEs added below also become users of
  // infeed_data and must not be rewired to the new tuple.
  auto infeed_data_users = infeed_data->users();
  HloInstruction* new_gte0 =
      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
          ShapeUtil::GetTupleElementShape(infeed_data->shape(), 0),
          infeed_data, 0));
  HloInstruction* new_copy0 =
      computation->AddInstruction(HloInstruction::CreateUnary(
          new_gte0->shape(), HloOpcode::kCopy, new_gte0));
  HloInstruction* new_gte1 =
      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
          ShapeUtil::GetTupleElementShape(infeed_data->shape(), 1),
          infeed_data, 1));
  HloInstruction* new_copy1 =
      computation->AddInstruction(HloInstruction::CreateUnary(
          new_gte1->shape(), HloOpcode::kCopy, new_gte1));
  HloInstruction* new_tuple = computation->AddInstruction(
      HloInstruction::CreateTuple({new_copy0, new_copy1}));
  for (HloInstruction* user : infeed_data_users) {
    TF_EXPECT_OK(infeed_data->ReplaceUseWith(user, new_tuple));
  }

  HloDomainRemover remover(ShardingMetadata::KindName(),
                           ShardingMetadata::NormalizeShardingDomain);
  TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get()));
  EXPECT_TRUE(remover_changed);

  struct Assignment {
    HloInstruction* instruction;
    int64 device;
  } assignments[] = {
      {new_gte0, 1},
      {new_copy0, 1},
      {new_gte1, 0},
      {new_copy1, 0},
  };
  for (auto& assignment : assignments) {
    auto device = assignment.instruction->sharding_unique_device();
    ASSERT_TRUE(device.has_value());
    EXPECT_EQ(*device, assignment.device);
  }
  // Fatal assertion: sharding() must not be called on an unsharded
  // instruction.
  ASSERT_TRUE(new_tuple->has_sharding());
  EXPECT_EQ(
      new_tuple->sharding(),
      HloSharding::Tuple(new_tuple->shape(), {HloSharding::AssignDevice(1),
                                              HloSharding::AssignDevice(0)}));
}
435
// After isolation, the tuple/gte pair is simplified away so that the root
// becomes tuple's (domain) operand. The remover must then still give the new
// root the sharding of the removed region (device 1).
TEST_F(HloDomainTest, EmptyRootDomain) {
  const char* const hlo_string = R"(
HloModule Module

ENTRY entry {
  %param = f32[1] parameter(0), sharding={maximal device=0}
  %tuple = (f32[1]) tuple(%param),
    sharding={maximal device=1}
  ROOT %gte = f32[1] get-tuple-element(%tuple), index=0,
    sharding={maximal device=1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; });
  TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get()));
  EXPECT_TRUE(isolator_changed);

  EXPECT_TRUE(HasDomainEdge(module.get(), "tuple", "param"));
  EXPECT_FALSE(HasDomainEdge(module.get(), "gte", "tuple"));

  // Remove %tuple and %gte (tuple simplification): %tuple's operand becomes
  // the new root.
  HloInstruction* gte = FindInstruction(module.get(), "gte");
  HloInstruction* tuple = FindInstruction(module.get(), "tuple");
  ASSERT_NE(gte, nullptr);
  ASSERT_NE(tuple, nullptr);
  HloComputation* entry = module->entry_computation();
  entry->set_root_instruction(tuple->mutable_operand(0));
  TF_EXPECT_OK(entry->RemoveInstruction(gte));
  TF_EXPECT_OK(entry->RemoveInstruction(tuple));

  HloDomainRemover remover(ShardingMetadata::KindName(),
                           ShardingMetadata::NormalizeShardingDomain);
  TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get()));
  EXPECT_TRUE(remover_changed);

  const HloInstruction* root = module->entry_computation()->root_instruction();
  // Fatal assertion: sharding() must not be called on an unsharded root.
  ASSERT_TRUE(root->has_sharding());
  EXPECT_EQ(root->sharding(), HloSharding::AssignDevice(1));
}
474
475 // Tests that text dumps of domain instructions can be parsed back, in the
476 // specific case of null shardings.
TEST_F(HloDomainTest,DumpParseNullSharding)477 TEST_F(HloDomainTest, DumpParseNullSharding) {
478 auto builder = HloComputation::Builder(TestName());
479 Shape shape = ShapeUtil::MakeShape(F32, {});
480 auto sharding_md_0 = absl::make_unique<ShardingMetadata>(nullptr);
481 auto sharding_md_1 = absl::make_unique<ShardingMetadata>(nullptr);
482 HloInstruction* param =
483 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p"));
484 HloInstruction* domain = builder.AddInstruction(HloInstruction::CreateDomain(
485 shape, param, std::move(sharding_md_0), std::move(sharding_md_1)));
486 builder.AddInstruction(
487 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, domain, domain));
488
489 auto module = CreateNewVerifiedModule();
490 module->AddEntryComputation(builder.Build());
491
492 auto hlo_string = module->ToString();
493 ASSERT_TRUE(ParseAndReturnVerifiedModule(hlo_string).status().ok());
494 }
495
496 // Tuple inputs are domain instructions.
TEST_F(HloDomainTest,DomainTuple)497 TEST_F(HloDomainTest, DomainTuple) {
498 const char* const hlo_string = R"(
499 HloModule Module
500
501 ENTRY entry {
502 p0 = f32[4] parameter(0), sharding={maximal device=0}
503 cst = u32[] constant(0), sharding={maximal device=1}
504 tpl = (u32[], f32[4]) tuple(cst, p0),
505 sharding={{maximal device=1}, {maximal device=0}}
506 ROOT gte = f32[4] get-tuple-element(tpl), index=1, sharding={maximal device=0}
507 }
508 )";
509
510 TF_ASSERT_OK_AND_ASSIGN(auto module,
511 ParseAndReturnVerifiedModule(hlo_string));
512
513 HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; });
514 TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get()));
515 EXPECT_TRUE(isolator_changed);
516
517 // Clear sharding of tpl instruction, in order to test domain sharding
518 // application.
519 auto tpl = FindInstruction(module.get(), "tpl");
520 tpl->clear_sharding();
521
522 HloDomainRemover remover(ShardingMetadata::KindName(),
523 ShardingMetadata::NormalizeShardingDomain);
524 TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get()));
525 EXPECT_TRUE(remover_changed);
526
527 EXPECT_EQ(HloSharding::Tuple(tpl->shape(), {HloSharding::AssignDevice(1),
528 HloSharding::AssignDevice(0)}),
529 tpl->sharding());
530 }
531
TEST_F(HloDomainTest, MultiDomainMultiUser) {
  const char* const hlo_string = R"(
HloModule Module

ENTRY %entry (p0: (f32[4], f32[4])) -> (f32[4], f32[4], f32[4]) {
  %p0 = (f32[4], f32[4]) parameter(0)
  %a = f32[4]{0} get-tuple-element(%p0), index=0
  %domain = f32[4] domain(%a),
    domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
  %b = f32[4] get-tuple-element(%p0), index=1
  %domain.1 = f32[4] domain(%b),
    domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
  %c = f32[4] add(%domain, %domain.1), sharding={maximal device=1}
  %domain.2 = f32[4] domain(%c),
    domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
  %d = f32[4] subtract(%domain, %c),
    sharding={maximal device=1}, metadata={op_name="D"}
  %domain.3 = f32[4] domain(%d),
    domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
  %e = f32[4] multiply(%c, %d),
    sharding={maximal device=1}, metadata={op_name="D"}
  %f = f32[4] add(f32[4]{0} %e, f32[4]{0} %c), sharding={maximal device=1}
  %domain.4 = f32[4]{0} domain(%f),
    domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
  ROOT %g = (f32[4], f32[4], f32[4]) tuple(%domain.2, %domain.3, %domain.4)
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  LOG(INFO) << "Original module:\n" << module->ToString();

  // Shorthand: does the `user` <- `operand` edge go through a kDomain?
  auto fenced = [&](absl::string_view user, absl::string_view operand) {
    return HasDomainEdge(module.get(), user, operand);
  };

  // Sharding domains already exist in the parsed text (%domain is shared by
  // c and d). Layer op_name domains on top of them.
  HloDomainIsolator opname_isolator([]() { return OpNameDomainCreator{}; });
  TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed,
                          opname_isolator.Run(module.get()));
  EXPECT_TRUE(opname_isolator_changed);

  EXPECT_TRUE(fenced("c", "a"));
  EXPECT_TRUE(fenced("c", "b"));
  EXPECT_TRUE(fenced("d", "a"));
  EXPECT_TRUE(fenced("d", "c"));
  EXPECT_FALSE(fenced("e", "d"));

  HloDomainRemover sharding_remover(ShardingMetadata::KindName(),
                                    ShardingMetadata::NormalizeShardingDomain);
  TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed,
                          sharding_remover.Run(module.get()));
  EXPECT_TRUE(sharding_remover_changed);

  HloDomainRemover opname_remover(OpNameMetadata::KindName(),
                                  OpNameDomainNormalizer);
  TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed,
                          opname_remover.Run(module.get()));
  EXPECT_TRUE(opname_remover_changed);

  // With both domain kinds removed, all of these are direct uses again.
  EXPECT_FALSE(fenced("c", "a"));
  EXPECT_FALSE(fenced("c", "b"));
  EXPECT_FALSE(fenced("d", "a"));
  EXPECT_FALSE(fenced("d", "c"));
}
591
592 // Emulate instructions inserted at top and bottom within nested tuple domain.
TEST_F(HloDomainTest,DomainTupleTopBottomInsert)593 TEST_F(HloDomainTest, DomainTupleTopBottomInsert) {
594 const char* const hlo_string = R"(
595 HloModule Module
596
597 ENTRY entry {
598 p0 = f32[4] parameter(0), sharding={maximal device=1}
599 p1 = (f32[5], f32[6]) parameter(1),
600 sharding={{maximal device=1}, {maximal device=0}}
601 tuple.0 = (f32[4], (f32[5], f32[6])) tuple(p0, p1),
602 sharding={{maximal device=1}, {maximal device=1}, {maximal device=0}}
603 ROOT res = (f32[5], f32[6]) get-tuple-element(tuple.0), index=1,
604 sharding={{maximal device=1}, {maximal device=0}}
605 }
606 )";
607
608 TF_ASSERT_OK_AND_ASSIGN(auto module,
609 ParseAndReturnVerifiedModule(hlo_string));
610
611 HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; });
612 TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get()));
613 EXPECT_TRUE(isolator_changed);
614
615 // Clear sharding of tuple.0 instruction, in order to test domain sharding
616 // application.
617 auto tuple0 = FindInstruction(module.get(), "tuple.0");
618 tuple0->clear_sharding();
619
620 // Insert the following instructons above and below tuple.0, to emulate other
621 // passes effects:
622 // COPY.0
623 // \ /
624 // TUPLE.0
625 // / \
626 // COPY.1 \
627 // / \
628 // GTE.0 GTE.1
629 // | |
630 // | COPY.2
631 // \ /
632 // \ /
633 // TUPLE.1
634 // |
635 auto tuple0_users = tuple0->users();
636 auto computation = tuple0->parent();
637 HloInstruction* copy0 = computation->AddInstruction(
638 HloInstruction::CreateUnary(tuple0->operand(1)->shape(), HloOpcode::kCopy,
639 tuple0->mutable_operand(1)));
640 TF_EXPECT_OK(tuple0->ReplaceOperandWith(1, copy0));
641
642 HloInstruction* copy1 = computation->AddInstruction(
643 HloInstruction::CreateUnary(tuple0->shape(), HloOpcode::kCopy, tuple0));
644 HloInstruction* gte0 =
645 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
646 ShapeUtil::GetTupleElementShape(copy1->shape(), 0), copy1, 0));
647 HloInstruction* gte1 =
648 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
649 ShapeUtil::GetTupleElementShape(tuple0->shape(), 1), tuple0, 1));
650 HloInstruction* copy2 = computation->AddInstruction(
651 HloInstruction::CreateUnary(gte1->shape(), HloOpcode::kCopy, gte1));
652 HloInstruction* tuple1 =
653 computation->AddInstruction(HloInstruction::CreateTuple({gte0, copy2}));
654
655 for (HloInstruction* user : tuple0_users) {
656 TF_EXPECT_OK(tuple0->ReplaceUseWith(user, tuple1));
657 }
658
659 HloDomainRemover remover(ShardingMetadata::KindName(),
660 ShardingMetadata::NormalizeShardingDomain);
661 TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get()));
662 EXPECT_TRUE(remover_changed);
663
664 EXPECT_TRUE(tuple0->has_sharding());
665 EXPECT_EQ(HloSharding::Tuple(tuple0->shape(), {HloSharding::AssignDevice(1),
666 HloSharding::AssignDevice(1),
667 HloSharding::AssignDevice(0)}),
668 tuple0->sharding());
669
670 EXPECT_TRUE(copy0->has_sharding());
671 EXPECT_EQ(HloSharding::Tuple(copy0->shape(), {HloSharding::AssignDevice(1),
672 HloSharding::AssignDevice(0)}),
673 copy0->sharding());
674
675 // copy1 has partial information only from gte.0, so in the end it gets no
676 // sharding at all. During propagation it does propagate the information from
677 // gte.0 though, enabling Tuple.0 to be fully sharded.
678 EXPECT_FALSE(copy1->has_sharding());
679
680 EXPECT_TRUE(gte0->has_sharding());
681 EXPECT_EQ(HloSharding::AssignDevice(1), gte0->sharding());
682
683 EXPECT_TRUE(gte1->has_sharding());
684 EXPECT_EQ(HloSharding::Tuple(gte1->shape(), {HloSharding::AssignDevice(1),
685 HloSharding::AssignDevice(0)}),
686 gte1->sharding());
687
688 EXPECT_TRUE(copy2->has_sharding());
689 EXPECT_EQ(HloSharding::Tuple(copy2->shape(), {HloSharding::AssignDevice(1),
690 HloSharding::AssignDevice(0)}),
691 copy2->sharding());
692
693 EXPECT_TRUE(tuple1->has_sharding());
694 EXPECT_EQ(tuple0->sharding(), tuple1->sharding());
695 }
696
697 } // namespace
698 } // namespace xla
699